diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index cdc141f..c4bd04b 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -13,12 +13,8 @@
#### 🔀 变更说明 | Description of Change
-
+
#### 📝 补充信息 | Additional Information
-
----
-
-> 💡 **提示**:如果修改了 `webui/` 目录下的文件,PR 合并后 CI 会自动构建并提交产物,无需手动构建。
\ No newline at end of file
diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml
index 67689cc..00cecee 100644
--- a/.github/workflows/release-artifacts.yml
+++ b/.github/workflows/release-artifacts.yml
@@ -12,6 +12,9 @@ permissions:
jobs:
build-and-upload:
runs-on: ubuntu-latest
+ env:
+ DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
+ DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -87,10 +90,11 @@ jobs:
password: ${{ secrets.GITHUB_TOKEN }}
- name: Log in to Docker Hub
+ if: "${{ env.DOCKERHUB_USERNAME != '' }}"
uses: docker/login-action@v3
with:
- username: ${{ secrets.DOCKERHUB_USERNAME }}
- password: ${{ secrets.DOCKERHUB_TOKEN }}
+ username: ${{ env.DOCKERHUB_USERNAME }}
+ password: ${{ env.DOCKERHUB_TOKEN }}
- name: Extract Docker metadata
id: meta_release
@@ -98,7 +102,7 @@ jobs:
with:
images: |
ghcr.io/${{ github.repository }}
- cjackhwang/ds2api
+ ${{ env.DOCKERHUB_USERNAME || 'cjackhwang' }}/ds2api
tags: |
type=raw,value=${{ github.event.release.tag_name }}
type=raw,value=latest
diff --git a/CONTRIBUTING.en.md b/CONTRIBUTING.en.md
index baf5eae..212cbb7 100644
--- a/CONTRIBUTING.en.md
+++ b/CONTRIBUTING.en.md
@@ -86,7 +86,7 @@ Manually build WebUI to `static/admin/`:
go test ./...
# End-to-end live tests (real accounts)
-./scripts/testsuite/run-live.sh
+./tests/scripts/run-live.sh
```
## Project Structure
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index c75d450..3bbd5b3 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -86,7 +86,7 @@ docker-compose -f docker-compose.dev.yml up
go test ./...
# 端到端全链路测试(真实账号)
-./scripts/testsuite/run-live.sh
+./tests/scripts/run-live.sh
```
## 项目结构
diff --git a/DEPLOY.en.md b/DEPLOY.en.md
index 8a62c98..046df83 100644
--- a/DEPLOY.en.md
+++ b/DEPLOY.en.md
@@ -518,7 +518,7 @@ curl http://127.0.0.1:5001/v1/chat/completions \
Run the full live testsuite before release (real account tests):
```bash
-./scripts/testsuite/run-live.sh
+./tests/scripts/run-live.sh
```
With custom flags:
diff --git a/DEPLOY.md b/DEPLOY.md
index e5b0630..4a1efcb 100644
--- a/DEPLOY.md
+++ b/DEPLOY.md
@@ -518,7 +518,7 @@ curl http://127.0.0.1:5001/v1/chat/completions \
建议在发布前执行完整的端到端测试集(使用真实账号):
```bash
-./scripts/testsuite/run-live.sh
+./tests/scripts/run-live.sh
```
可自定义参数:
diff --git a/README.MD b/README.MD
index 261e34a..b26cc72 100644
--- a/README.MD
+++ b/README.MD
@@ -283,6 +283,9 @@ cp opencode.json.example opencode.json
| `DS2API_ACCOUNT_QUEUE_SIZE` | 同上(兼容旧名) | — |
| `DS2API_VERCEL_INTERNAL_SECRET` | Vercel 混合流式内部鉴权密钥 | 回退用 `DS2API_ADMIN_KEY` |
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | 流式 lease 过期秒数 | `900` |
+| `DS2API_DEV_PACKET_CAPTURE` | 本地开发抓包开关(记录最近会话请求/响应体) | 本地非 Vercel 默认开启 |
+| `DS2API_DEV_PACKET_CAPTURE_LIMIT` | 本地抓包保留条数(超出自动淘汰) | `5` |
+| `DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES` | 单条响应体最大记录字节数 | `2097152` |
| `VERCEL_TOKEN` | Vercel 同步 token | — |
| `VERCEL_PROJECT_ID` | Vercel 项目 ID | — |
| `VERCEL_TEAM_ID` | Vercel 团队 ID | — |
@@ -321,6 +324,29 @@ cp opencode.json.example opencode.json
3. 已确认的 toolcall JSON 片段不会泄漏到 `delta.content`
4. 前文/后文自然语言保持顺序透传,支持混合文本与增量参数输出
+## 本地开发抓包工具
+
+用于定位「responses 思考流/工具调用」等问题。开启后会自动记录最近 N 条 DeepSeek 对话上游请求体与响应体(默认 5 条,超出自动淘汰)。
+
+启用示例:
+
+```bash
+DS2API_DEV_PACKET_CAPTURE=true \
+DS2API_DEV_PACKET_CAPTURE_LIMIT=5 \
+go run ./cmd/ds2api
+```
+
+查询/清空(需 Admin JWT):
+
+- `GET /admin/dev/captures`:查看抓包列表(最新在前)
+- `DELETE /admin/dev/captures`:清空抓包
+
+返回字段包含:
+
+- `request_body`:发送给 DeepSeek 的完整请求体
+- `response_body`:上游返回的原始流式内容拼接文本
+- `response_truncated`:是否触发单条大小截断
+
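The `DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES` cap and the `response_truncated` field described above can be sketched as follows (an illustrative Go sketch, not DS2API's actual implementation; the function name is an assumption):

```go
package main

import "fmt"

// truncateBody caps a captured response body at maxBytes and reports whether
// truncation happened, mirroring the documented response_truncated field.
// Illustrative only — not the real capture code.
func truncateBody(body []byte, maxBytes int) ([]byte, bool) {
	if maxBytes <= 0 || len(body) <= maxBytes {
		return body, false
	}
	return body[:maxBytes], true
}

func main() {
	body := make([]byte, 3*1024*1024) // a 3 MiB upstream stream
	kept, truncated := truncateBody(body, 2097152)
	fmt.Println(len(kept), truncated) // prints: 2097152 true
}
```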
## 项目结构
```text
@@ -350,8 +376,10 @@ ds2api/
│ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage
│ └── locales/ # 中英文语言包(zh.json / en.json)
├── scripts/
-│ ├── build-webui.sh # WebUI 手动构建脚本
-│ └── testsuite/ # 测试集运行脚本
+│ └── build-webui.sh # WebUI 手动构建脚本
+├── tests/
+│ ├── compat/ # 兼容性测试夹具与期望输出
+│ └── scripts/ # 统一测试脚本入口(unit/e2e)
├── static/admin/ # WebUI 构建产物(不提交到 Git)
├── .github/
│ ├── workflows/ # GitHub Actions(Release 自动构建)
@@ -379,11 +407,11 @@ ds2api/
## 测试
```bash
-# 单元测试
-go test ./...
+# 单元测试(Go + Node)
+./tests/scripts/run-unit-all.sh
# 一键端到端全链路测试(真实账号,生成完整请求/响应日志)
-./scripts/testsuite/run-live.sh
+./tests/scripts/run-live.sh
# 或自定义参数
go run ./cmd/ds2api-tests \
diff --git a/README.en.md b/README.en.md
index 5d2f326..69b47bd 100644
--- a/README.en.md
+++ b/README.en.md
@@ -283,6 +283,9 @@ cp opencode.json.example opencode.json
| `DS2API_ACCOUNT_QUEUE_SIZE` | Alias (legacy compat) | — |
| `DS2API_VERCEL_INTERNAL_SECRET` | Vercel hybrid streaming internal auth | Falls back to `DS2API_ADMIN_KEY` |
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | Stream lease TTL seconds | `900` |
+| `DS2API_DEV_PACKET_CAPTURE` | Local dev packet capture switch (record recent request/response bodies) | Enabled by default on non-Vercel local runtime |
+| `DS2API_DEV_PACKET_CAPTURE_LIMIT` | Number of captured sessions to retain (auto-evict overflow) | `5` |
+| `DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES` | Max recorded bytes per captured response body | `2097152` |
| `VERCEL_TOKEN` | Vercel sync token | — |
| `VERCEL_PROJECT_ID` | Vercel project ID | — |
| `VERCEL_TEAM_ID` | Vercel team ID | — |
@@ -321,6 +324,29 @@ When `tools` is present in the request, DS2API performs anti-leak handling:
3. Confirmed toolcall JSON fragments are never leaked into `delta.content`
4. Natural language before/after toolcalls keeps original order, with incremental argument output supported
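The anti-leak split described above — natural language passes through while a trailing `{"tool_calls":...}` JSON payload is parsed out instead of leaking into `delta.content` — can be sketched as (a deliberately simplified illustration, not DS2API's actual sieve):

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// splitToolCalls separates a model reply into natural-language text and a
// parsed {"tool_calls":[...]} payload, if the reply contains one.
// Illustration of the anti-leak idea only; the real sieve works incrementally.
func splitToolCalls(reply string) (string, []map[string]any) {
	idx := strings.Index(reply, `{"tool_calls"`)
	if idx < 0 {
		return reply, nil
	}
	var payload struct {
		ToolCalls []map[string]any `json:"tool_calls"`
	}
	if err := json.Unmarshal([]byte(reply[idx:]), &payload); err != nil {
		return reply, nil // malformed JSON: pass through untouched
	}
	return strings.TrimSpace(reply[:idx]), payload.ToolCalls
}

func main() {
	text, calls := splitToolCalls(`Let me search. {"tool_calls":[{"name":"search","input":{"q":"go"}}]}`)
	fmt.Println(text)                  // prints: Let me search.
	fmt.Println(len(calls), calls[0]["name"]) // prints: 1 search
}
```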
+## Local Dev Packet Capture
+
+This is for debugging issues such as Responses reasoning streaming and tool-call handoff. When enabled, DS2API stores the latest N DeepSeek conversation payload pairs (request body + upstream response body), defaulting to 5 entries with auto-eviction.
+
+Enable example:
+
+```bash
+DS2API_DEV_PACKET_CAPTURE=true \
+DS2API_DEV_PACKET_CAPTURE_LIMIT=5 \
+go run ./cmd/ds2api
+```
+
+Inspect/clear (Admin JWT required):
+
+- `GET /admin/dev/captures`: list captured items (newest first)
+- `DELETE /admin/dev/captures`: clear captured items
+
+Response fields include:
+
+- `request_body`: full payload sent to DeepSeek
+- `response_body`: concatenated raw upstream stream body text
+- `response_truncated`: whether body-size truncation happened
+
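The capture-limit semantics above (`DS2API_DEV_PACKET_CAPTURE_LIMIT` entries retained, newest first, oldest auto-evicted) can be sketched as (illustrative only; type and field names are assumptions, not the real implementation):

```go
package main

import "fmt"

// captureRing keeps at most limit entries, newest first, mirroring the
// documented DS2API_DEV_PACKET_CAPTURE_LIMIT behavior.
type captureRing struct {
	limit   int
	entries []string
}

func (r *captureRing) add(id string) {
	r.entries = append([]string{id}, r.entries...) // newest first
	if len(r.entries) > r.limit {
		r.entries = r.entries[:r.limit] // evict the oldest beyond the limit
	}
}

func main() {
	r := &captureRing{limit: 5}
	for i := 1; i <= 7; i++ {
		r.add(fmt.Sprintf("session-%d", i))
	}
	fmt.Println(r.entries) // the two oldest sessions were evicted
}
```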
## Project Structure
```text
@@ -350,8 +376,10 @@ ds2api/
│ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage
│ └── locales/ # Language packs (zh.json / en.json)
├── scripts/
-│ ├── build-webui.sh # Manual WebUI build script
-│ └── testsuite/ # Testsuite runner scripts
+│ └── build-webui.sh # Manual WebUI build script
+├── tests/
+│ ├── compat/ # Compatibility fixtures and expected outputs
+│ └── scripts/ # Unified test script entrypoints (unit/e2e)
├── static/admin/ # WebUI build output (not committed to Git)
├── .github/
│ ├── workflows/ # GitHub Actions (Release artifact automation)
@@ -379,11 +407,11 @@ ds2api/
## Testing
```bash
-# Unit tests
-go test ./...
+# Unit tests (Go + Node)
+./tests/scripts/run-unit-all.sh
# One-command live end-to-end tests (real accounts, full request/response logs)
-./scripts/testsuite/run-live.sh
+./tests/scripts/run-live.sh
# Or with custom flags
go run ./cmd/ds2api-tests \
diff --git a/TESTING.md b/TESTING.md
index ce349ec..f8e532a 100644
--- a/TESTING.md
+++ b/TESTING.md
@@ -8,8 +8,10 @@ DS2API 提供两个层级的测试:
| 层级 | 命令 | 说明 |
| --- | --- | --- |
-| 单元测试 | `go test ./...` | 不需要真实账号 |
-| 端到端测试 | `./scripts/testsuite/run-live.sh` | 使用真实账号执行全链路测试 |
+| 单元测试(Go) | `./tests/scripts/run-unit-go.sh` | 不需要真实账号 |
+| 单元测试(Node) | `./tests/scripts/run-unit-node.sh` | 不需要真实账号 |
+| 单元测试(全部) | `./tests/scripts/run-unit-all.sh` | 不需要真实账号 |
+| 端到端测试 | `./tests/scripts/run-live.sh` | 使用真实账号执行全链路测试 |
端到端测试集会录制完整的请求/响应日志,用于故障排查。
@@ -20,17 +22,19 @@ DS2API 提供两个层级的测试:
### 单元测试 | Unit Tests
```bash
-go test ./...
+./tests/scripts/run-unit-all.sh
```
```bash
-node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js
+# 或按语言拆分执行
+./tests/scripts/run-unit-go.sh
+./tests/scripts/run-unit-node.sh
```
### 端到端测试 | End-to-End Tests
```bash
-./scripts/testsuite/run-live.sh
+./tests/scripts/run-live.sh
```
**默认行为**:
@@ -179,7 +183,7 @@ go run ./cmd/ds2api-tests \
```bash
# 确保 config.json 存在且包含有效测试账号
-./scripts/testsuite/run-live.sh
+./tests/scripts/run-live.sh
exit_code=$?
if [ $exit_code -ne 0 ]; then
echo "Tests failed! Check artifacts for details."
diff --git a/internal/adapter/claude/handler.go b/internal/adapter/claude/handler.go
index 282b569..2fa6796 100644
--- a/internal/adapter/claude/handler.go
+++ b/internal/adapter/claude/handler.go
@@ -38,6 +38,10 @@ func RegisterRoutes(r chi.Router, h *Handler) {
r.Get("/anthropic/v1/models", h.ListModels)
r.Post("/anthropic/v1/messages", h.Messages)
r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens)
+ r.Post("/v1/messages", h.Messages)
+ r.Post("/messages", h.Messages)
+ r.Post("/v1/messages/count_tokens", h.CountTokens)
+ r.Post("/messages/count_tokens", h.CountTokens)
}
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
@@ -167,7 +171,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
rc := http.NewResponseController(w)
- canFlush := rc.Flush() == nil
+ _, canFlush := w.(http.Flusher)
if !canFlush {
config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered")
}
@@ -250,7 +254,7 @@ func normalizeClaudeMessages(messages []any) []any {
}
}
if typeStr == "tool_result" {
- parts = append(parts, fmt.Sprintf("%v", b["content"]))
+ parts = append(parts, formatClaudeToolResultForPrompt(b))
}
}
copied["content"] = strings.Join(parts, "\n")
@@ -272,10 +276,36 @@ func buildClaudeToolPrompt(tools []any) string {
schema, _ := json.Marshal(m["input_schema"])
parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
}
- parts = append(parts, "When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}")
+ parts = append(parts,
+ "When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}",
+ "History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.",
+ "After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.",
+ )
return strings.Join(parts, "\n\n")
}
+func formatClaudeToolResultForPrompt(block map[string]any) string {
+ if block == nil {
+ return ""
+ }
+	// Missing keys yield nil, and fmt.Sprintf("%v", nil) is "<nil>", so guard
+	// against absent values before formatting.
+	str := func(key string) string {
+		if v, ok := block[key]; ok && v != nil {
+			return strings.TrimSpace(fmt.Sprintf("%v", v))
+		}
+		return ""
+	}
+	toolCallID := str("tool_use_id")
+	if toolCallID == "" {
+		toolCallID = str("tool_call_id")
+	}
+	if toolCallID == "" {
+		toolCallID = "unknown"
+	}
+	name := str("name")
+	if name == "" {
+		name = "unknown"
+	}
+	content := str("content")
+	if content == "" {
+		content = "null"
+	}
+ return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content)
+}
+
func hasSystemMessage(messages []any) bool {
for _, m := range messages {
msg, ok := m.(map[string]any)
diff --git a/internal/adapter/claude/handler_util_test.go b/internal/adapter/claude/handler_util_test.go
index 73d2fab..ae75d8e 100644
--- a/internal/adapter/claude/handler_util_test.go
+++ b/internal/adapter/claude/handler_util_test.go
@@ -1,6 +1,7 @@
package claude
import (
+ "strings"
"testing"
)
@@ -48,8 +49,9 @@ func TestNormalizeClaudeMessagesToolResult(t *testing.T) {
}
got := normalizeClaudeMessages(msgs)
m := got[0].(map[string]any)
- if m["content"] != "tool output" {
- t.Fatalf("expected 'tool output', got %q", m["content"])
+ content, _ := m["content"].(string)
+ if !strings.Contains(content, "[TOOL_RESULT_HISTORY]") || !strings.Contains(content, "content: tool output") {
+ t.Fatalf("expected serialized tool result marker, got %q", content)
}
}
diff --git a/internal/adapter/claude/route_alias_test.go b/internal/adapter/claude/route_alias_test.go
new file mode 100644
index 0000000..f01e5e3
--- /dev/null
+++ b/internal/adapter/claude/route_alias_test.go
@@ -0,0 +1,44 @@
+package claude
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/go-chi/chi/v5"
+
+ "ds2api/internal/auth"
+)
+
+type routeAliasAuthStub struct{}
+
+func (routeAliasAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) {
+ return nil, auth.ErrUnauthorized
+}
+
+func (routeAliasAuthStub) Release(_ *auth.RequestAuth) {}
+
+func TestClaudeRouteAliasesDoNot404(t *testing.T) {
+ h := &Handler{
+ Auth: routeAliasAuthStub{},
+ }
+ r := chi.NewRouter()
+ RegisterRoutes(r, h)
+
+ paths := []string{
+ "/anthropic/v1/messages",
+ "/v1/messages",
+ "/messages",
+ "/anthropic/v1/messages/count_tokens",
+ "/v1/messages/count_tokens",
+ "/messages/count_tokens",
+ }
+ for _, path := range paths {
+ req := httptest.NewRequest(http.MethodPost, path, nil)
+ rec := httptest.NewRecorder()
+ r.ServeHTTP(rec, req)
+ if rec.Code == http.StatusNotFound {
+ t.Fatalf("expected route %s to be registered, got 404", path)
+ }
+ }
+}
diff --git a/internal/adapter/claude/standard_request.go b/internal/adapter/claude/standard_request.go
index cdbb675..23520c0 100644
--- a/internal/adapter/claude/standard_request.go
+++ b/internal/adapter/claude/standard_request.go
@@ -27,9 +27,7 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma
payload := cloneMap(req)
payload["messages"] = normalizedMessages
toolsRequested, _ := req["tools"].([]any)
- if len(toolsRequested) > 0 && !hasSystemMessage(normalizedMessages) {
- payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalizedMessages...)
- }
+ payload["messages"] = injectClaudeToolPrompt(payload, normalizedMessages, toolsRequested)
dsPayload := convertClaudeToDeepSeek(payload, store)
dsModel, _ := dsPayload["model"].(string)
@@ -57,3 +55,59 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma
NormalizedMessages: normalizedMessages,
}, nil
}
+
+func injectClaudeToolPrompt(payload map[string]any, normalizedMessages []any, tools []any) []any {
+ if len(tools) == 0 {
+ return normalizedMessages
+ }
+ toolPrompt := strings.TrimSpace(buildClaudeToolPrompt(tools))
+ if toolPrompt == "" {
+ return normalizedMessages
+ }
+
+ // Prefer top-level Anthropic-style system prompt when available.
+ if systemText, ok := payload["system"].(string); ok && strings.TrimSpace(systemText) != "" {
+ payload["system"] = mergeSystemPrompt(systemText, toolPrompt)
+ return normalizedMessages
+ }
+
+ messages := cloneAnySlice(normalizedMessages)
+ for i := range messages {
+ msg, ok := messages[i].(map[string]any)
+ if !ok {
+ continue
+ }
+ role, _ := msg["role"].(string)
+ if !strings.EqualFold(strings.TrimSpace(role), "system") {
+ continue
+ }
+ copied := cloneMap(msg)
+ copied["content"] = mergeSystemPrompt(strings.TrimSpace(fmt.Sprintf("%v", copied["content"])), toolPrompt)
+ messages[i] = copied
+ return messages
+ }
+
+ return append([]any{map[string]any{"role": "system", "content": toolPrompt}}, messages...)
+}
+
+func mergeSystemPrompt(base, extra string) string {
+ base = strings.TrimSpace(base)
+ extra = strings.TrimSpace(extra)
+ switch {
+ case base == "":
+ return extra
+ case extra == "":
+ return base
+ default:
+ return base + "\n\n" + extra
+ }
+}
+
+func cloneAnySlice(in []any) []any {
+ if len(in) == 0 {
+ return nil
+ }
+ out := make([]any, len(in))
+ copy(out, in)
+ return out
+}
diff --git a/internal/adapter/claude/standard_request_test.go b/internal/adapter/claude/standard_request_test.go
index 7ffdfb8..6110124 100644
--- a/internal/adapter/claude/standard_request_test.go
+++ b/internal/adapter/claude/standard_request_test.go
@@ -36,3 +36,57 @@ func TestNormalizeClaudeRequest(t *testing.T) {
t.Fatalf("expected non-empty final prompt")
}
}
+
+func TestNormalizeClaudeRequestInjectsToolsIntoExistingSystemMessage(t *testing.T) {
+ t.Setenv("DS2API_CONFIG_JSON", `{}`)
+ store := config.LoadStore()
+ req := map[string]any{
+ "model": "claude-sonnet-4-5",
+ "messages": []any{
+ map[string]any{"role": "system", "content": "baseline rule"},
+ map[string]any{"role": "user", "content": "hello"},
+ },
+ "tools": []any{
+ map[string]any{"name": "search", "description": "Search"},
+ },
+ }
+
+ norm, err := normalizeClaudeRequest(store, req)
+ if err != nil {
+ t.Fatalf("normalize failed: %v", err)
+ }
+
+ if !containsStr(norm.Standard.FinalPrompt, "You have access to these tools") {
+ t.Fatalf("expected tool prompt injected into final prompt, got=%q", norm.Standard.FinalPrompt)
+ }
+ if !containsStr(norm.Standard.FinalPrompt, "baseline rule") {
+ t.Fatalf("expected existing system message preserved, got=%q", norm.Standard.FinalPrompt)
+ }
+}
+
+func TestNormalizeClaudeRequestInjectsToolsIntoTopLevelSystem(t *testing.T) {
+ t.Setenv("DS2API_CONFIG_JSON", `{}`)
+ store := config.LoadStore()
+ req := map[string]any{
+ "model": "claude-sonnet-4-5",
+ "system": "top-level system",
+ "messages": []any{
+ map[string]any{"role": "user", "content": "hello"},
+ },
+ "tools": []any{
+ map[string]any{"name": "search", "description": "Search"},
+ },
+ }
+
+ norm, err := normalizeClaudeRequest(store, req)
+ if err != nil {
+ t.Fatalf("normalize failed: %v", err)
+ }
+
+ if !containsStr(norm.Standard.FinalPrompt, "top-level system") {
+ t.Fatalf("expected top-level system preserved, got=%q", norm.Standard.FinalPrompt)
+ }
+ if !containsStr(norm.Standard.FinalPrompt, "You have access to these tools") {
+ t.Fatalf("expected tool prompt injected, got=%q", norm.Standard.FinalPrompt)
+ }
+}
diff --git a/internal/adapter/claude/stream_status_test.go b/internal/adapter/claude/stream_status_test.go
new file mode 100644
index 0000000..c3936de
--- /dev/null
+++ b/internal/adapter/claude/stream_status_test.go
@@ -0,0 +1,100 @@
+package claude
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/go-chi/chi/v5"
+ chimw "github.com/go-chi/chi/v5/middleware"
+
+ "ds2api/internal/auth"
+)
+
+type streamStatusClaudeAuthStub struct{}
+
+func (streamStatusClaudeAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) {
+ return &auth.RequestAuth{
+ UseConfigToken: false,
+ DeepSeekToken: "direct-token",
+ CallerID: "caller:test",
+ TriedAccounts: map[string]bool{},
+ }, nil
+}
+
+func (streamStatusClaudeAuthStub) Release(_ *auth.RequestAuth) {}
+
+type streamStatusClaudeDSStub struct{}
+
+func (streamStatusClaudeDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
+ return "session-id", nil
+}
+
+func (streamStatusClaudeDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
+ return "pow", nil
+}
+
+func (streamStatusClaudeDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
+ body := "data: {\"p\":\"response/content\",\"v\":\"hello\"}\n" + "data: [DONE]\n"
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Header: make(http.Header),
+ Body: ioNopCloser{strings.NewReader(body)},
+ }, nil
+}
+
+type ioNopCloser struct {
+ *strings.Reader
+}
+
+func (ioNopCloser) Close() error { return nil }
+
+type streamStatusClaudeStoreStub struct{}
+
+func (streamStatusClaudeStoreStub) ClaudeMapping() map[string]string {
+ return map[string]string{
+ "fast": "deepseek-chat",
+ "slow": "deepseek-reasoner",
+ }
+}
+
+func captureClaudeStatusMiddleware(statuses *[]int) func(http.Handler) http.Handler {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ ww := chimw.NewWrapResponseWriter(w, r.ProtoMajor)
+ next.ServeHTTP(ww, r)
+ *statuses = append(*statuses, ww.Status())
+ })
+ }
+}
+
+func TestClaudeMessagesStreamStatusCapturedAs200(t *testing.T) {
+ statuses := make([]int, 0, 1)
+ h := &Handler{
+ Store: streamStatusClaudeStoreStub{},
+ Auth: streamStatusClaudeAuthStub{},
+ DS: streamStatusClaudeDSStub{},
+ }
+ r := chi.NewRouter()
+ r.Use(captureClaudeStatusMiddleware(&statuses))
+ RegisterRoutes(r, h)
+
+ reqBody := `{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":true}`
+ req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(reqBody))
+ req.Header.Set("Authorization", "Bearer direct-token")
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ r.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
+ }
+ if len(statuses) != 1 {
+ t.Fatalf("expected one captured status, got %d", len(statuses))
+ }
+ if statuses[0] != http.StatusOK {
+ t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0])
+ }
+}
diff --git a/internal/adapter/gemini/convert.go b/internal/adapter/gemini/convert.go
new file mode 100644
index 0000000..3f63579
--- /dev/null
+++ b/internal/adapter/gemini/convert.go
@@ -0,0 +1,313 @@
+package gemini
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+
+ "ds2api/internal/adapter/openai"
+ "ds2api/internal/config"
+ "ds2api/internal/util"
+)
+
+func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[string]any, stream bool) (util.StandardRequest, error) {
+ requestedModel := strings.TrimSpace(routeModel)
+ if requestedModel == "" {
+ return util.StandardRequest{}, fmt.Errorf("model is required in request path")
+ }
+
+ resolvedModel, ok := config.ResolveModel(store, requestedModel)
+ if !ok {
+ return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", requestedModel)
+ }
+ thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
+
+ messagesRaw := geminiMessagesFromRequest(req)
+ if len(messagesRaw) == 0 {
+ return util.StandardRequest{}, fmt.Errorf("Request must include non-empty contents.")
+ }
+
+ toolsRaw := convertGeminiTools(req["tools"])
+ finalPrompt, toolNames := openai.BuildPromptForAdapter(messagesRaw, toolsRaw, "")
+ passThrough := collectGeminiPassThrough(req)
+
+ return util.StandardRequest{
+ Surface: "google_gemini",
+ RequestedModel: requestedModel,
+ ResolvedModel: resolvedModel,
+ ResponseModel: requestedModel,
+ Messages: messagesRaw,
+ FinalPrompt: finalPrompt,
+ ToolNames: toolNames,
+ Stream: stream,
+ Thinking: thinkingEnabled,
+ Search: searchEnabled,
+ PassThrough: passThrough,
+ }, nil
+}
+
+func geminiMessagesFromRequest(req map[string]any) []any {
+ out := make([]any, 0, 8)
+ if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" {
+ out = append(out, map[string]any{
+ "role": "system",
+ "content": sys,
+ })
+ }
+
+ contents, _ := req["contents"].([]any)
+ for _, item := range contents {
+ content, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ role := mapGeminiRole(content["role"])
+ if role == "" {
+ role = "user"
+ }
+ parts, _ := content["parts"].([]any)
+ if len(parts) == 0 {
+ if text := strings.TrimSpace(asString(content["text"])); text != "" {
+ out = append(out, map[string]any{
+ "role": role,
+ "content": text,
+ })
+ }
+ continue
+ }
+
+ textParts := make([]string, 0, len(parts))
+ flushText := func() {
+ if len(textParts) == 0 {
+ return
+ }
+ out = append(out, map[string]any{
+ "role": role,
+ "content": strings.Join(textParts, "\n"),
+ })
+ textParts = textParts[:0]
+ }
+
+ for _, rawPart := range parts {
+ part, ok := rawPart.(map[string]any)
+ if !ok {
+ continue
+ }
+ if text := strings.TrimSpace(asString(part["text"])); text != "" {
+ textParts = append(textParts, text)
+ continue
+ }
+
+ if fnCall, ok := part["functionCall"].(map[string]any); ok {
+ flushText()
+ if name := strings.TrimSpace(asString(fnCall["name"])); name != "" {
+ callID := strings.TrimSpace(asString(fnCall["id"]))
+ if callID == "" {
+ callID = "call_gemini"
+ }
+ out = append(out, map[string]any{
+ "role": "assistant",
+ "tool_calls": []any{
+ map[string]any{
+ "id": callID,
+ "type": "function",
+ "function": map[string]any{
+ "name": name,
+ "arguments": stringifyJSON(fnCall["args"]),
+ },
+ },
+ },
+ })
+ }
+ continue
+ }
+
+ if fnResp, ok := part["functionResponse"].(map[string]any); ok {
+ flushText()
+ name := strings.TrimSpace(asString(fnResp["name"]))
+ callID := strings.TrimSpace(asString(fnResp["id"]))
+ if callID == "" {
+ callID = strings.TrimSpace(asString(fnResp["callId"]))
+ }
+ if callID == "" {
+ callID = strings.TrimSpace(asString(fnResp["tool_call_id"]))
+ }
+ if callID == "" {
+ callID = "call_gemini"
+ }
+ content := fnResp["response"]
+ if content == nil {
+ content = fnResp["output"]
+ }
+ if content == nil {
+ content = ""
+ }
+ msg := map[string]any{
+ "role": "tool",
+ "tool_call_id": callID,
+ "content": content,
+ }
+ if name != "" {
+ msg["name"] = name
+ }
+ out = append(out, msg)
+ }
+ }
+ flushText()
+ }
+ return out
+}
+
+func normalizeGeminiSystemInstruction(raw any) string {
+ switch v := raw.(type) {
+ case string:
+ return strings.TrimSpace(v)
+ case map[string]any:
+ if parts, ok := v["parts"].([]any); ok {
+ texts := make([]string, 0, len(parts))
+ for _, item := range parts {
+ part, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ if text := strings.TrimSpace(asString(part["text"])); text != "" {
+ texts = append(texts, text)
+ }
+ }
+ return strings.Join(texts, "\n")
+ }
+ if text := strings.TrimSpace(asString(v["text"])); text != "" {
+ return text
+ }
+ }
+ return ""
+}
+
+func mapGeminiRole(v any) string {
+ switch strings.ToLower(strings.TrimSpace(asString(v))) {
+ case "user":
+ return "user"
+ case "model", "assistant":
+ return "assistant"
+ case "system":
+ return "system"
+ default:
+ return ""
+ }
+}
+
+func convertGeminiTools(raw any) []any {
+ tools, _ := raw.([]any)
+ if len(tools) == 0 {
+ return nil
+ }
+ out := make([]any, 0, len(tools))
+ for _, item := range tools {
+ tool, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+
+ if fnDecls, ok := tool["functionDeclarations"].([]any); ok && len(fnDecls) > 0 {
+ for _, declRaw := range fnDecls {
+ decl, ok := declRaw.(map[string]any)
+ if !ok {
+ continue
+ }
+ name := strings.TrimSpace(asString(decl["name"]))
+ if name == "" {
+ continue
+ }
+ function := map[string]any{
+ "name": name,
+ }
+ if desc := strings.TrimSpace(asString(decl["description"])); desc != "" {
+ function["description"] = desc
+ }
+ if params, ok := decl["parameters"].(map[string]any); ok {
+ function["parameters"] = params
+ }
+ out = append(out, map[string]any{
+ "type": "function",
+ "function": function,
+ })
+ }
+ continue
+ }
+
+ // OpenAI-style passthrough fallback.
+ if _, ok := tool["function"].(map[string]any); ok {
+ out = append(out, tool)
+ continue
+ }
+
+ // Loose fallback for flattened function schema objects.
+ name := strings.TrimSpace(asString(tool["name"]))
+ if name == "" {
+ continue
+ }
+ fn := map[string]any{"name": name}
+ if desc := strings.TrimSpace(asString(tool["description"])); desc != "" {
+ fn["description"] = desc
+ }
+ if params, ok := tool["parameters"].(map[string]any); ok {
+ fn["parameters"] = params
+ }
+ out = append(out, map[string]any{
+ "type": "function",
+ "function": fn,
+ })
+ }
+ if len(out) == 0 {
+ return nil
+ }
+ return out
+}
+
+func collectGeminiPassThrough(req map[string]any) map[string]any {
+ cfg, _ := req["generationConfig"].(map[string]any)
+ if len(cfg) == 0 {
+ return nil
+ }
+ out := map[string]any{}
+ if v, ok := cfg["temperature"]; ok {
+ out["temperature"] = v
+ }
+ if v, ok := cfg["topP"]; ok {
+ out["top_p"] = v
+ }
+ if v, ok := cfg["maxOutputTokens"]; ok {
+ out["max_tokens"] = v
+ }
+ if v, ok := cfg["stopSequences"]; ok {
+ out["stop"] = v
+ }
+ if len(out) == 0 {
+ return nil
+ }
+ return out
+}
+
+func asString(v any) string {
+ s, _ := v.(string)
+ return s
+}
+
+func stringifyJSON(v any) string {
+ switch x := v.(type) {
+ case nil:
+ return "{}"
+ case string:
+ s := strings.TrimSpace(x)
+ if s == "" {
+ return "{}"
+ }
+ return s
+ default:
+ b, err := json.Marshal(x)
+ if err != nil || len(b) == 0 {
+ return "{}"
+ }
+ return string(b)
+ }
+}
diff --git a/internal/adapter/gemini/deps.go b/internal/adapter/gemini/deps.go
new file mode 100644
index 0000000..312114a
--- /dev/null
+++ b/internal/adapter/gemini/deps.go
@@ -0,0 +1,29 @@
+package gemini
+
+import (
+ "context"
+ "net/http"
+
+ "ds2api/internal/auth"
+ "ds2api/internal/config"
+ "ds2api/internal/deepseek"
+)
+
+type AuthResolver interface {
+ Determine(req *http.Request) (*auth.RequestAuth, error)
+ Release(a *auth.RequestAuth)
+}
+
+type DeepSeekCaller interface {
+ CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
+ GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
+ CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
+}
+
+type ConfigReader interface {
+ ModelAliases() map[string]string
+}
+
+var _ AuthResolver = (*auth.Resolver)(nil)
+var _ DeepSeekCaller = (*deepseek.Client)(nil)
+var _ ConfigReader = (*config.Store)(nil)
diff --git a/internal/adapter/gemini/handler.go b/internal/adapter/gemini/handler.go
new file mode 100644
index 0000000..8daaeda
--- /dev/null
+++ b/internal/adapter/gemini/handler.go
@@ -0,0 +1,348 @@
+package gemini
+
+import (
+ "encoding/json"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/go-chi/chi/v5"
+
+ "ds2api/internal/auth"
+ "ds2api/internal/deepseek"
+ "ds2api/internal/sse"
+ streamengine "ds2api/internal/stream"
+ "ds2api/internal/util"
+)
+
+var writeJSON = util.WriteJSON
+
+type Handler struct {
+ Store ConfigReader
+ Auth AuthResolver
+ DS DeepSeekCaller
+}
+
+func RegisterRoutes(r chi.Router, h *Handler) {
+ r.Post("/v1beta/models/{model}:generateContent", h.GenerateContent)
+ r.Post("/v1beta/models/{model}:streamGenerateContent", h.StreamGenerateContent)
+ r.Post("/v1/models/{model}:generateContent", h.GenerateContent)
+ r.Post("/v1/models/{model}:streamGenerateContent", h.StreamGenerateContent)
+}
+
+func (h *Handler) GenerateContent(w http.ResponseWriter, r *http.Request) {
+ h.handleGenerateContent(w, r, false)
+}
+
+func (h *Handler) StreamGenerateContent(w http.ResponseWriter, r *http.Request) {
+ h.handleGenerateContent(w, r, true)
+}
+
+func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) {
+ a, err := h.Auth.Determine(r)
+ if err != nil {
+ status := http.StatusUnauthorized
+ detail := err.Error()
+ if err == auth.ErrNoAccount {
+ status = http.StatusTooManyRequests
+ }
+ writeGeminiError(w, status, detail)
+ return
+ }
+ defer h.Auth.Release(a)
+
+ var req map[string]any
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ writeGeminiError(w, http.StatusBadRequest, "invalid json")
+ return
+ }
+
+ routeModel := strings.TrimSpace(chi.URLParam(r, "model"))
+ stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream)
+ if err != nil {
+ writeGeminiError(w, http.StatusBadRequest, err.Error())
+ return
+ }
+
+ sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
+ if err != nil {
+ if a.UseConfigToken {
+ writeGeminiError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
+ } else {
+ writeGeminiError(w, http.StatusUnauthorized, "Invalid token.")
+ }
+ return
+ }
+ pow, err := h.DS.GetPow(r.Context(), a, 3)
+ if err != nil {
+ writeGeminiError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
+ return
+ }
+ payload := stdReq.CompletionPayload(sessionID)
+ resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
+ if err != nil {
+ writeGeminiError(w, http.StatusInternalServerError, "Failed to get completion.")
+ return
+ }
+
+ if stream {
+ h.handleStreamGenerateContent(w, r, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
+ return
+ }
+ h.handleNonStreamGenerateContent(w, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
+}
+
+func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *http.Response, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
+ return
+ }
+
+ result := sse.CollectStream(resp, thinkingEnabled, true)
+ writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames))
+}
+
+func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
+ return
+ }
+
+ w.Header().Set("Content-Type", "text/event-stream")
+ w.Header().Set("Cache-Control", "no-cache, no-transform")
+ w.Header().Set("Connection", "keep-alive")
+ w.Header().Set("X-Accel-Buffering", "no")
+
+ rc := http.NewResponseController(w)
+ _, canFlush := w.(http.Flusher)
+ runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
+
+ initialType := "text"
+ if thinkingEnabled {
+ initialType = "thinking"
+ }
+ streamengine.ConsumeSSE(streamengine.ConsumeConfig{
+ Context: r.Context(),
+ Body: resp.Body,
+ ThinkingEnabled: thinkingEnabled,
+ InitialType: initialType,
+ KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
+ IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
+ MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
+ }, streamengine.ConsumeHooks{
+ OnParsed: runtime.onParsed,
+ OnFinalize: func(_ streamengine.StopReason, _ error) {
+ runtime.finalize()
+ },
+ })
+}
+
+func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
+ parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames)
+ usage := buildGeminiUsage(finalPrompt, finalThinking, finalText)
+ return map[string]any{
+ "candidates": []map[string]any{
+ {
+ "index": 0,
+ "content": map[string]any{
+ "role": "model",
+ "parts": parts,
+ },
+ "finishReason": "STOP",
+ },
+ },
+ "modelVersion": model,
+ "usageMetadata": usage,
+ }
+}
+
+// buildGeminiUsage estimates usageMetadata token counts from the prompt,
+// reasoning, and completion text.
+func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any {
+ promptTokens := util.EstimateTokens(finalPrompt)
+ reasoningTokens := util.EstimateTokens(finalThinking)
+ completionTokens := util.EstimateTokens(finalText)
+ return map[string]any{
+ "promptTokenCount": promptTokens,
+ "candidatesTokenCount": reasoningTokens + completionTokens,
+ "totalTokenCount": promptTokens + reasoningTokens + completionTokens,
+ }
+}
+
+// buildGeminiPartsFromFinal converts detected tool calls into functionCall
+// parts, falling back to plain text (or thinking text) when none are found.
+func buildGeminiPartsFromFinal(finalText, finalThinking string, toolNames []string) []map[string]any {
+ detected := util.ParseToolCalls(finalText, toolNames)
+ if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
+ detected = util.ParseToolCalls(finalThinking, toolNames)
+ }
+ if len(detected) > 0 {
+ parts := make([]map[string]any, 0, len(detected))
+ for _, tc := range detected {
+ parts = append(parts, map[string]any{
+ "functionCall": map[string]any{
+ "name": tc.Name,
+ "args": tc.Input,
+ },
+ })
+ }
+ return parts
+ }
+
+ text := finalText
+ if strings.TrimSpace(text) == "" {
+ text = finalThinking
+ }
+ return []map[string]any{{"text": text}}
+}
+
+// geminiStreamRuntime buffers thinking/text output and writes Gemini-style
+// SSE chunks as the upstream stream is consumed.
+type geminiStreamRuntime struct {
+ w http.ResponseWriter
+ rc *http.ResponseController
+ canFlush bool
+
+ model string
+ finalPrompt string
+
+ thinkingEnabled bool
+ searchEnabled bool
+ bufferContent bool
+ toolNames []string
+
+ thinking strings.Builder
+ text strings.Builder
+}
+
+func newGeminiStreamRuntime(
+ w http.ResponseWriter,
+ rc *http.ResponseController,
+ canFlush bool,
+ model string,
+ finalPrompt string,
+ thinkingEnabled bool,
+ searchEnabled bool,
+ toolNames []string,
+) *geminiStreamRuntime {
+ return &geminiStreamRuntime{
+ w: w,
+ rc: rc,
+ canFlush: canFlush,
+ model: model,
+ finalPrompt: finalPrompt,
+ thinkingEnabled: thinkingEnabled,
+ searchEnabled: searchEnabled,
+ bufferContent: len(toolNames) > 0,
+ toolNames: toolNames,
+ }
+}
+
+func (s *geminiStreamRuntime) sendChunk(payload map[string]any) {
+ b, _ := json.Marshal(payload)
+ _, _ = s.w.Write([]byte("data: "))
+ _, _ = s.w.Write(b)
+ _, _ = s.w.Write([]byte("\n\n"))
+ if s.canFlush {
+ _ = s.rc.Flush()
+ }
+}
+
+func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
+ if !parsed.Parsed {
+ return streamengine.ParsedDecision{}
+ }
+ if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
+ return streamengine.ParsedDecision{Stop: true}
+ }
+
+ contentSeen := false
+ for _, p := range parsed.Parts {
+ if p.Text == "" {
+ continue
+ }
+ if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
+ continue
+ }
+ contentSeen = true
+ if p.Type == "thinking" {
+ if s.thinkingEnabled {
+ s.thinking.WriteString(p.Text)
+ }
+ continue
+ }
+ s.text.WriteString(p.Text)
+ if s.bufferContent {
+ continue
+ }
+ s.sendChunk(map[string]any{
+ "candidates": []map[string]any{
+ {
+ "index": 0,
+ "content": map[string]any{
+ "role": "model",
+ "parts": []map[string]any{{"text": p.Text}},
+ },
+ },
+ },
+ "modelVersion": s.model,
+ })
+ }
+ return streamengine.ParsedDecision{ContentSeen: contentSeen}
+}
+
+func (s *geminiStreamRuntime) finalize() {
+ finalThinking := s.thinking.String()
+ finalText := s.text.String()
+
+ if s.bufferContent {
+ parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames)
+ s.sendChunk(map[string]any{
+ "candidates": []map[string]any{
+ {
+ "index": 0,
+ "content": map[string]any{
+ "role": "model",
+ "parts": parts,
+ },
+ },
+ },
+ "modelVersion": s.model,
+ })
+ }
+
+ s.sendChunk(map[string]any{
+ "candidates": []map[string]any{
+ {
+ "index": 0,
+ "finishReason": "STOP",
+ },
+ },
+ "modelVersion": s.model,
+ "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText),
+ })
+}
+
+// writeGeminiError writes a Gemini error envelope ({code, message, status}),
+// mapping the HTTP status to a canonical status string.
+func writeGeminiError(w http.ResponseWriter, status int, message string) {
+ errorStatus := "INVALID_ARGUMENT"
+ switch status {
+ case http.StatusUnauthorized:
+ errorStatus = "UNAUTHENTICATED"
+ case http.StatusForbidden:
+ errorStatus = "PERMISSION_DENIED"
+ case http.StatusTooManyRequests:
+ errorStatus = "RESOURCE_EXHAUSTED"
+ case http.StatusNotFound:
+ errorStatus = "NOT_FOUND"
+ default:
+ if status >= 500 {
+ errorStatus = "INTERNAL"
+ }
+ }
+ writeJSON(w, status, map[string]any{
+ "error": map[string]any{
+ "code": status,
+ "message": message,
+ "status": errorStatus,
+ },
+ })
+}
diff --git a/internal/adapter/gemini/handler_test.go b/internal/adapter/gemini/handler_test.go
new file mode 100644
index 0000000..862750a
--- /dev/null
+++ b/internal/adapter/gemini/handler_test.go
@@ -0,0 +1,174 @@
+package gemini
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/go-chi/chi/v5"
+
+ "ds2api/internal/auth"
+)
+
+type testGeminiConfig struct{}
+
+func (testGeminiConfig) ModelAliases() map[string]string { return nil }
+
+type testGeminiAuth struct {
+ a *auth.RequestAuth
+ err error
+}
+
+func (m testGeminiAuth) Determine(_ *http.Request) (*auth.RequestAuth, error) {
+ if m.err != nil {
+ return nil, m.err
+ }
+ if m.a != nil {
+ return m.a, nil
+ }
+ return &auth.RequestAuth{
+ UseConfigToken: false,
+ DeepSeekToken: "direct-token",
+ CallerID: "caller:test",
+ TriedAccounts: map[string]bool{},
+ }, nil
+}
+
+func (testGeminiAuth) Release(_ *auth.RequestAuth) {}
+
+type testGeminiDS struct {
+ resp *http.Response
+ err error
+}
+
+func (m testGeminiDS) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
+ return "session-id", nil
+}
+
+func (m testGeminiDS) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
+ return "pow", nil
+}
+
+func (m testGeminiDS) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
+ if m.err != nil {
+ return nil, m.err
+ }
+ return m.resp, nil
+}
+
+func makeGeminiUpstreamResponse(lines ...string) *http.Response {
+ body := strings.Join(lines, "\n")
+ if !strings.HasSuffix(body, "\n") {
+ body += "\n"
+ }
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Header: make(http.Header),
+ Body: io.NopCloser(strings.NewReader(body)),
+ }
+}
+
+func TestGeminiRoutesRegistered(t *testing.T) {
+ h := &Handler{
+ Store: testGeminiConfig{},
+ Auth: testGeminiAuth{err: auth.ErrUnauthorized},
+ }
+ r := chi.NewRouter()
+ RegisterRoutes(r, h)
+
+ paths := []string{
+ "/v1beta/models/gemini-2.5-pro:generateContent",
+ "/v1beta/models/gemini-2.5-pro:streamGenerateContent",
+ "/v1/models/gemini-2.5-pro:generateContent",
+ "/v1/models/gemini-2.5-pro:streamGenerateContent",
+ }
+ for _, path := range paths {
+ req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`))
+ rec := httptest.NewRecorder()
+ r.ServeHTTP(rec, req)
+ if rec.Code == http.StatusNotFound {
+ t.Fatalf("expected route %s to be registered, got 404", path)
+ }
+ }
+}
+
+func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
+ upstream := makeGeminiUpstreamResponse(
+ `data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
+ `data: [DONE]`,
+ )
+ h := &Handler{
+ Store: testGeminiConfig{},
+ Auth: testGeminiAuth{},
+ DS: testGeminiDS{resp: upstream},
+ }
+ r := chi.NewRouter()
+ RegisterRoutes(r, h)
+
+ body := `{
+ "contents":[{"role":"user","parts":[{"text":"call tool"}]}],
+ "tools":[{"functionDeclarations":[{"name":"eval_javascript","description":"eval","parameters":{"type":"object","properties":{"code":{"type":"string"}}}}]}]
+ }`
+ req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent", strings.NewReader(body))
+ req.Header.Set("Authorization", "Bearer direct-token")
+ rec := httptest.NewRecorder()
+ r.ServeHTTP(rec, req)
+ if rec.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
+ }
+
+ var out map[string]any
+ if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
+ t.Fatalf("decode response failed: %v", err)
+ }
+ candidates, _ := out["candidates"].([]any)
+ if len(candidates) == 0 {
+ t.Fatalf("expected non-empty candidates: %#v", out)
+ }
+ c0, _ := candidates[0].(map[string]any)
+ content, _ := c0["content"].(map[string]any)
+ parts, _ := content["parts"].([]any)
+ if len(parts) == 0 {
+ t.Fatalf("expected non-empty parts: %#v", content)
+ }
+ part0, _ := parts[0].(map[string]any)
+ functionCall, _ := part0["functionCall"].(map[string]any)
+ if functionCall["name"] != "eval_javascript" {
+ t.Fatalf("expected functionCall name eval_javascript, got %#v", functionCall)
+ }
+}
+
+func TestStreamGenerateContentEmitsSSE(t *testing.T) {
+ upstream := makeGeminiUpstreamResponse(
+ `data: {"p":"response/content","v":"hello "}`,
+ `data: {"p":"response/content","v":"world"}`,
+ `data: [DONE]`,
+ )
+ h := &Handler{
+ Store: testGeminiConfig{},
+ Auth: testGeminiAuth{},
+ DS: testGeminiDS{resp: upstream},
+ }
+ r := chi.NewRouter()
+ RegisterRoutes(r, h)
+
+ body := `{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`
+ req := httptest.NewRequest(http.MethodPost, "/v1/models/gemini-2.5-pro:streamGenerateContent?alt=sse", strings.NewReader(body))
+ req.Header.Set("Authorization", "Bearer direct-token")
+ rec := httptest.NewRecorder()
+ r.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
+ }
+ if !strings.Contains(rec.Body.String(), "data: ") {
+ t.Fatalf("expected SSE data frames, got body=%s", rec.Body.String())
+ }
+ if !strings.Contains(rec.Body.String(), `"finishReason":"STOP"`) {
+ t.Fatalf("expected stream finish frame, got body=%s", rec.Body.String())
+ }
+}
diff --git a/internal/adapter/openai/chat_stream_runtime.go b/internal/adapter/openai/chat_stream_runtime.go
index 0e64bc5..d9a1ba4 100644
--- a/internal/adapter/openai/chat_stream_runtime.go
+++ b/internal/adapter/openai/chat_stream_runtime.go
@@ -25,10 +25,11 @@ type chatStreamRuntime struct {
thinkingEnabled bool
searchEnabled bool
- firstChunkSent bool
- bufferToolContent bool
- emitEarlyToolDeltas bool
- toolCallsEmitted bool
+ firstChunkSent bool
+ bufferToolContent bool
+ emitEarlyToolDeltas bool
+ toolCallsEmitted bool
+ toolCallsDoneEmitted bool
toolSieve toolStreamSieveState
streamToolCallIDs map[int]string
@@ -96,10 +97,10 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
finalThinking := s.thinking.String()
finalText := s.text.String()
detected := util.ParseToolCalls(finalText, s.toolNames)
- if len(detected) > 0 && !s.toolCallsEmitted {
+ if len(detected) > 0 && !s.toolCallsDoneEmitted {
finishReason = "tool_calls"
delta := map[string]any{
- "tool_calls": util.FormatOpenAIStreamToolCalls(detected),
+ "tool_calls": formatFinalStreamToolCallsWithStableIDs(detected, s.streamToolCallIDs),
}
if !s.firstChunkSent {
delta["role"] = "assistant"
@@ -112,8 +113,29 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, delta)},
nil,
))
+ s.toolCallsEmitted = true
+ s.toolCallsDoneEmitted = true
} else if s.bufferToolContent {
for _, evt := range flushToolSieve(&s.toolSieve, s.toolNames) {
+ if len(evt.ToolCalls) > 0 {
+ finishReason = "tool_calls"
+ s.toolCallsEmitted = true
+ s.toolCallsDoneEmitted = true
+ tcDelta := map[string]any{
+ "tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
+ }
+ if !s.firstChunkSent {
+ tcDelta["role"] = "assistant"
+ s.firstChunkSent = true
+ }
+ s.sendChunk(openaifmt.BuildChatStreamChunk(
+ s.completionID,
+ s.created,
+ s.model,
+ []map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, tcDelta)},
+ nil,
+ ))
+ }
if evt.Content == "" {
continue
}
@@ -189,10 +211,14 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
if !s.emitEarlyToolDeltas {
continue
}
- s.toolCallsEmitted = true
- tcDelta := map[string]any{
- "tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs),
+ formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
+ if len(formatted) == 0 {
+ continue
}
+ tcDelta := map[string]any{
+ "tool_calls": formatted,
+ }
+ s.toolCallsEmitted = true
if !s.firstChunkSent {
tcDelta["role"] = "assistant"
s.firstChunkSent = true
@@ -202,8 +228,9 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
}
if len(evt.ToolCalls) > 0 {
s.toolCallsEmitted = true
+ s.toolCallsDoneEmitted = true
tcDelta := map[string]any{
- "tool_calls": util.FormatOpenAIStreamToolCalls(evt.ToolCalls),
+ "tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
}
if !s.firstChunkSent {
tcDelta["role"] = "assistant"
diff --git a/internal/adapter/openai/deps_injection_test.go b/internal/adapter/openai/deps_injection_test.go
index baa0c11..6286c0c 100644
--- a/internal/adapter/openai/deps_injection_test.go
+++ b/internal/adapter/openai/deps_injection_test.go
@@ -31,7 +31,7 @@ func TestNormalizeOpenAIChatRequestWithConfigInterface(t *testing.T) {
"model": "my-model",
"messages": []any{map[string]any{"role": "user", "content": "hello"}},
}
- out, err := normalizeOpenAIChatRequest(cfg, req)
+ out, err := normalizeOpenAIChatRequest(cfg, req, "")
if err != nil {
t.Fatalf("normalizeOpenAIChatRequest error: %v", err)
}
@@ -52,7 +52,7 @@ func TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface(t *testing.
_, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{
aliases: map[string]string{},
wideInput: false,
- }, req)
+ }, req, "")
if err == nil {
t.Fatal("expected error when wide input is disabled and only input is provided")
}
@@ -60,7 +60,7 @@ func TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface(t *testing.
out, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{
aliases: map[string]string{},
wideInput: true,
- }, req)
+ }, req, "")
if err != nil {
t.Fatalf("unexpected error when wide input is enabled: %v", err)
}
diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go
index 28a451c..391a035 100644
--- a/internal/adapter/openai/handler.go
+++ b/internal/adapter/openai/handler.go
@@ -93,7 +93,7 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
return
}
- stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
+ stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return
@@ -154,7 +154,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
rc := http.NewResponseController(w)
- canFlush := rc.Flush() == nil
+ _, canFlush := w.(http.Flusher)
if !canFlush {
config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered")
}
@@ -233,7 +233,7 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any,
if len(toolSchemas) == 0 {
return messages, names
}
- toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error."
+ toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block."
for i := range messages {
if messages[i]["role"] == "system" {
@@ -280,6 +280,36 @@ func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]s
return out
}
+// formatFinalStreamToolCallsWithStableIDs formats final tool calls for a stream
+// delta, reusing call IDs already assigned to earlier incremental deltas.
+func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any {
+ if len(calls) == 0 {
+ return nil
+ }
+ out := make([]map[string]any, 0, len(calls))
+ for i, c := range calls {
+ callID := ""
+ if ids != nil {
+ callID = strings.TrimSpace(ids[i])
+ }
+ if callID == "" {
+ callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
+ if ids != nil {
+ ids[i] = callID
+ }
+ }
+ args, _ := json.Marshal(c.Input)
+ out = append(out, map[string]any{
+ "index": i,
+ "id": callID,
+ "type": "function",
+ "function": map[string]any{
+ "name": c.Name,
+ "arguments": string(args),
+ },
+ })
+ }
+ return out
+}
+
func writeOpenAIError(w http.ResponseWriter, status int, message string) {
writeJSON(w, status, map[string]any{
"error": map[string]any{
diff --git a/internal/adapter/openai/handler_toolcall_test.go b/internal/adapter/openai/handler_toolcall_test.go
index dd2bb0f..2027729 100644
--- a/internal/adapter/openai/handler_toolcall_test.go
+++ b/internal/adapter/openai/handler_toolcall_test.go
@@ -735,3 +735,73 @@ func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}
+
+func TestHandleStreamMultiToolCallDoesNotMergeNamesOrArguments(t *testing.T) {
+ h := &Handler{}
+ resp := makeSSEHTTPResponse(
+ `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search_web\",\"input\":{\"query\":\"latest ai news\"}},{"}`,
+ `data: {"p":"response/content","v":"\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
+ `data: [DONE]`,
+ )
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+ h.handleStream(rec, req, resp, "cid12", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"})
+
+ frames, done := parseSSEDataFrames(t, rec.Body.String())
+ if !done {
+ t.Fatalf("expected [DONE], body=%s", rec.Body.String())
+ }
+ if !streamHasToolCallsDelta(frames) {
+ t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
+ }
+
+ foundSearch := false
+ foundEval := false
+ foundIndex1 := false
+ toolCallsDeltaLens := make([]int, 0, 2)
+ for _, frame := range frames {
+ choices, _ := frame["choices"].([]any)
+ for _, item := range choices {
+ choice, _ := item.(map[string]any)
+ delta, _ := choice["delta"].(map[string]any)
+ rawToolCalls, hasToolCalls := delta["tool_calls"]
+ if !hasToolCalls {
+ continue
+ }
+ toolCalls, _ := rawToolCalls.([]any)
+ toolCallsDeltaLens = append(toolCallsDeltaLens, len(toolCalls))
+ for _, tc := range toolCalls {
+ tcm, _ := tc.(map[string]any)
+ if idx, ok := tcm["index"].(float64); ok && int(idx) == 1 {
+ foundIndex1 = true
+ }
+ fn, _ := tcm["function"].(map[string]any)
+ name, _ := fn["name"].(string)
+ switch name {
+ case "search_web":
+ foundSearch = true
+ case "eval_javascript":
+ foundEval = true
+ case "search_webeval_javascript":
+ t.Fatalf("unexpected merged tool name: %s, body=%s", name, rec.Body.String())
+ }
+ if args, ok := fn["arguments"].(string); ok && strings.Contains(args, `}{"`) {
+ t.Fatalf("unexpected concatenated tool arguments: %q, body=%s", args, rec.Body.String())
+ }
+ }
+ }
+ }
+ if !foundSearch || !foundEval {
+ t.Fatalf("expected both tool names in stream deltas, foundSearch=%v foundEval=%v body=%s", foundSearch, foundEval, rec.Body.String())
+ }
+ if len(toolCallsDeltaLens) != 1 || toolCallsDeltaLens[0] != 2 {
+ t.Fatalf("expected exactly one tool_calls delta with two calls, got lens=%v body=%s", toolCallsDeltaLens, rec.Body.String())
+ }
+ if !foundIndex1 {
+ t.Fatalf("expected second tool call index in stream deltas, body=%s", rec.Body.String())
+ }
+ if streamFinishReason(frames) != "tool_calls" {
+ t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
+ }
+}
diff --git a/internal/adapter/openai/message_normalize.go b/internal/adapter/openai/message_normalize.go
index 3ebd1e7..a767960 100644
--- a/internal/adapter/openai/message_normalize.go
+++ b/internal/adapter/openai/message_normalize.go
@@ -4,9 +4,11 @@ import (
"encoding/json"
"fmt"
"strings"
+
+ "ds2api/internal/config"
)
-func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any {
+func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any {
out := make([]map[string]any, 0, len(raw))
for _, item := range raw {
msg, ok := item.(map[string]any)
@@ -17,7 +19,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any {
switch role {
case "assistant":
content := normalizeOpenAIContentForPrompt(msg["content"])
- toolCalls := formatAssistantToolCallsForPrompt(msg)
+ toolCalls := formatAssistantToolCallsForPrompt(msg, traceID)
combined := joinNonEmpty(content, toolCalls)
if combined == "" {
continue
@@ -53,7 +55,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any {
return out
}
-func formatAssistantToolCallsForPrompt(msg map[string]any) string {
+func formatAssistantToolCallsForPrompt(msg map[string]any, traceID string) string {
entries := make([]string, 0)
if calls, ok := msg["tool_calls"].([]any); ok {
for i, item := range calls {
@@ -86,7 +88,8 @@ func formatAssistantToolCallsForPrompt(msg map[string]any) string {
if args == "" {
args = "{}"
}
- entries = append(entries, fmt.Sprintf("Tool call:\n- tool_call_id: %s\n- function.name: %s\n- function.arguments: %s", id, name, args))
+ maybeWarnSuspiciousToolHistory(traceID, id, name, args)
+ entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: %s\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", id, name, args))
}
}
@@ -99,7 +102,8 @@ func formatAssistantToolCallsForPrompt(msg map[string]any) string {
if args == "" {
args = "{}"
}
- entries = append(entries, fmt.Sprintf("Tool call:\n- tool_call_id: call_legacy\n- function.name: %s\n- function.arguments: %s", name, args))
+ maybeWarnSuspiciousToolHistory(traceID, "call_legacy", name, args)
+ entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: call_legacy\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", name, args))
}
return strings.Join(entries, "\n\n")
@@ -124,7 +128,7 @@ func formatToolResultForPrompt(msg map[string]any) string {
content = "null"
}
- return fmt.Sprintf("Tool result:\n- tool_call_id: %s\n- name: %s\n- content: %s", toolCallID, name, content)
+ return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content)
}
func normalizeOpenAIContentForPrompt(v any) string {
@@ -190,3 +194,45 @@ func joinNonEmpty(parts ...string) string {
}
return strings.Join(nonEmpty, "\n\n")
}
+
+// maybeWarnSuspiciousToolHistory logs a warning when tool-call history
+// arguments look like multiple JSON documents concatenated together.
+func maybeWarnSuspiciousToolHistory(traceID, callID, name, args string) {
+ if !looksLikeConcatenatedJSON(args) {
+ return
+ }
+ traceID = strings.TrimSpace(traceID)
+ if traceID == "" {
+ traceID = "unknown"
+ }
+ config.Logger.Warn(
+ "[openai] suspicious tool call history payload detected",
+ "trace_id", traceID,
+ "tool_call_id", strings.TrimSpace(callID),
+ "name", strings.TrimSpace(name),
+ "arguments_preview", previewToolArgs(args, 160),
+ )
+}
+
+func looksLikeConcatenatedJSON(raw string) bool {
+ trimmed := strings.TrimSpace(raw)
+ if trimmed == "" {
+ return false
+ }
+ if strings.Contains(trimmed, "}{") || strings.Contains(trimmed, "][") {
+ return true
+ }
+ dec := json.NewDecoder(strings.NewReader(trimmed))
+ var first any
+ if err := dec.Decode(&first); err != nil {
+ return false
+ }
+ var second any
+ return dec.Decode(&second) == nil
+}
+
+func previewToolArgs(raw string, max int) string {
+ trimmed := strings.TrimSpace(raw)
+ if max <= 0 || len(trimmed) <= max {
+ return trimmed
+ }
+ return trimmed[:max]
+}
diff --git a/internal/adapter/openai/message_normalize_test.go b/internal/adapter/openai/message_normalize_test.go
index bb648d3..30403bc 100644
--- a/internal/adapter/openai/message_normalize_test.go
+++ b/internal/adapter/openai/message_normalize_test.go
@@ -33,23 +33,24 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *tes
},
}
- normalized := normalizeOpenAIMessagesForPrompt(raw)
+ normalized := normalizeOpenAIMessagesForPrompt(raw, "")
if len(normalized) != 4 {
t.Fatalf("expected 4 normalized messages, got %d", len(normalized))
}
assistantContent, _ := normalized[2]["content"].(string)
- if !strings.Contains(assistantContent, "tool_call_id: call_1") ||
+ if !strings.Contains(assistantContent, "[TOOL_CALL_HISTORY]") ||
+ !strings.Contains(assistantContent, "tool_call_id: call_1") ||
!strings.Contains(assistantContent, "function.name: get_weather") ||
!strings.Contains(assistantContent, "function.arguments: {\"city\":\"beijing\"}") {
t.Fatalf("assistant tool call not serialized correctly: %q", assistantContent)
}
toolContent, _ := normalized[3]["content"].(string)
- if !strings.Contains(toolContent, "Tool result:") || !strings.Contains(toolContent, "name: get_weather") {
+ if !strings.Contains(toolContent, "[TOOL_RESULT_HISTORY]") || !strings.Contains(toolContent, "name: get_weather") {
t.Fatalf("tool result not serialized correctly: %q", toolContent)
}
prompt := util.MessagesPrepare(normalized)
- if !strings.Contains(prompt, "tool_call_id: call_1") || !strings.Contains(prompt, "Tool result:") {
+ if !strings.Contains(prompt, "tool_call_id: call_1") || !strings.Contains(prompt, "[TOOL_RESULT_HISTORY]") {
t.Fatalf("expected prompt to include tool call + result semantics: %q", prompt)
}
}
@@ -67,7 +68,7 @@ func TestNormalizeOpenAIMessagesForPrompt_ToolObjectContentPreserved(t *testing.
},
}
- normalized := normalizeOpenAIMessagesForPrompt(raw)
+ normalized := normalizeOpenAIMessagesForPrompt(raw, "")
got, _ := normalized[0]["content"].(string)
if !strings.Contains(got, `"temp":18`) || !strings.Contains(got, `"condition":"sunny"`) {
t.Fatalf("expected serialized object in tool content, got %q", got)
@@ -88,7 +89,7 @@ func TestNormalizeOpenAIMessagesForPrompt_ToolArrayBlocksJoined(t *testing.T) {
},
}
- normalized := normalizeOpenAIMessagesForPrompt(raw)
+ normalized := normalizeOpenAIMessagesForPrompt(raw, "")
got, _ := normalized[0]["content"].(string)
if !strings.Contains(got, "line-1\nline-2") {
t.Fatalf("expected joined text blocks, got %q", got)
@@ -107,7 +108,7 @@ func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) {
},
}
- normalized := normalizeOpenAIMessagesForPrompt(raw)
+ normalized := normalizeOpenAIMessagesForPrompt(raw, "")
if len(normalized) != 1 {
t.Fatalf("expected one normalized message, got %d", len(normalized))
}
@@ -119,3 +120,50 @@ func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) {
t.Fatalf("unexpected normalized function-role content: %q", got)
}
}
+
+func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSeparated(t *testing.T) {
+ raw := []any{
+ map[string]any{
+ "role": "assistant",
+ "tool_calls": []any{
+ map[string]any{
+ "id": "call_search",
+ "type": "function",
+ "function": map[string]any{
+ "name": "search_web",
+ "arguments": `{"query":"latest ai news"}`,
+ },
+ },
+ map[string]any{
+ "id": "call_eval",
+ "type": "function",
+ "function": map[string]any{
+ "name": "eval_javascript",
+ "arguments": `{"code":"1+1"}`,
+ },
+ },
+ },
+ },
+ }
+
+ normalized := normalizeOpenAIMessagesForPrompt(raw, "")
+ if len(normalized) != 1 {
+ t.Fatalf("expected one normalized assistant message, got %d", len(normalized))
+ }
+ content, _ := normalized[0]["content"].(string)
+ if strings.Count(content, "[TOOL_CALL_HISTORY]") != 2 {
+ t.Fatalf("expected two TOOL_CALL_HISTORY blocks, got %q", content)
+ }
+ if !strings.Contains(content, "tool_call_id: call_search") || !strings.Contains(content, "function.name: search_web") {
+ t.Fatalf("missing first tool call block, got %q", content)
+ }
+ if !strings.Contains(content, "tool_call_id: call_eval") || !strings.Contains(content, "function.name: eval_javascript") {
+ t.Fatalf("missing second tool call block, got %q", content)
+ }
+ if strings.Contains(content, "search_webeval_javascript") {
+ t.Fatalf("unexpected merged function name detected: %q", content)
+ }
+ if strings.Contains(content, `}{"`) {
+ t.Fatalf("unexpected concatenated function arguments detected: %q", content)
+ }
+}
diff --git a/internal/adapter/openai/prompt_build.go b/internal/adapter/openai/prompt_build.go
index f83963f..76739ed 100644
--- a/internal/adapter/openai/prompt_build.go
+++ b/internal/adapter/openai/prompt_build.go
@@ -4,11 +4,18 @@ import (
"ds2api/internal/deepseek"
)
-func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any) (string, []string) {
- messages := normalizeOpenAIMessagesForPrompt(messagesRaw)
+func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any, traceID string) (string, []string) {
+ messages := normalizeOpenAIMessagesForPrompt(messagesRaw, traceID)
toolNames := []string{}
if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 {
messages, toolNames = injectToolPrompt(messages, tools)
}
return deepseek.MessagesPrepare(messages), toolNames
}
+
+// BuildPromptForAdapter exposes the OpenAI-compatible prompt building flow so
+// other protocol adapters (for example Gemini) can reuse the same tool/history
+// normalization logic and remain behavior-compatible with chat/completions.
+func BuildPromptForAdapter(messagesRaw []any, toolsRaw any, traceID string) (string, []string) {
+ return buildOpenAIFinalPrompt(messagesRaw, toolsRaw, traceID)
+}
diff --git a/internal/adapter/openai/prompt_build_test.go b/internal/adapter/openai/prompt_build_test.go
index 1833860..bd6223e 100644
--- a/internal/adapter/openai/prompt_build_test.go
+++ b/internal/adapter/openai/prompt_build_test.go
@@ -40,13 +40,13 @@ func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *tes
},
}
- finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools)
+ finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools, "")
if len(toolNames) != 1 || toolNames[0] != "get_weather" {
t.Fatalf("unexpected tool names: %#v", toolNames)
}
if !strings.Contains(finalPrompt, "tool_call_id: call_1") ||
!strings.Contains(finalPrompt, "function.name: get_weather") ||
- !strings.Contains(finalPrompt, "Tool result:") ||
+ !strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") ||
!strings.Contains(finalPrompt, `"condition":"sunny"`) {
t.Fatalf("handler finalPrompt missing tool roundtrip semantics: %q", finalPrompt)
}
@@ -70,11 +70,14 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *
},
}
- finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools)
+ finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "")
if !strings.Contains(finalPrompt, "After receiving a tool result, you MUST use it to produce the final answer.") {
t.Fatalf("vercel prepare finalPrompt missing final-answer instruction: %q", finalPrompt)
}
if !strings.Contains(finalPrompt, "Only call another tool when the previous result is missing required data or returned an error.") {
t.Fatalf("vercel prepare finalPrompt missing retry guard instruction: %q", finalPrompt)
}
+ if !strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") {
+ t.Fatalf("vercel prepare finalPrompt missing history marker instruction: %q", finalPrompt)
+ }
}
diff --git a/internal/adapter/openai/responses_embeddings_test.go b/internal/adapter/openai/responses_embeddings_test.go
index d270e1a..a5e2b72 100644
--- a/internal/adapter/openai/responses_embeddings_test.go
+++ b/internal/adapter/openai/responses_embeddings_test.go
@@ -1,6 +1,7 @@
package openai
import (
+ "strings"
"testing"
"time"
)
@@ -32,6 +33,82 @@ func TestResponsesMessagesFromRequestWithInstructions(t *testing.T) {
}
}
+func TestNormalizeResponsesInputAsMessagesObjectRoleContentBlocks(t *testing.T) {
+ msgs := normalizeResponsesInputAsMessages(map[string]any{
+ "role": "user",
+ "content": []any{
+ map[string]any{"type": "input_text", "text": "line-1"},
+ map[string]any{"type": "input_text", "text": "line-2"},
+ },
+ })
+ if len(msgs) != 1 {
+ t.Fatalf("expected one message, got %d", len(msgs))
+ }
+ m, _ := msgs[0].(map[string]any)
+ if m["role"] != "user" {
+ t.Fatalf("unexpected role: %#v", m)
+ }
+ if strings.TrimSpace(normalizeOpenAIContentForPrompt(m["content"])) != "line-1\nline-2" {
+ t.Fatalf("unexpected content: %#v", m["content"])
+ }
+}
+
+func TestNormalizeResponsesInputAsMessagesFunctionCallOutput(t *testing.T) {
+ msgs := normalizeResponsesInputAsMessages([]any{
+ map[string]any{
+ "type": "function_call_output",
+ "call_id": "call_123",
+ "output": map[string]any{"ok": true},
+ },
+ })
+ if len(msgs) != 1 {
+ t.Fatalf("expected one message, got %d", len(msgs))
+ }
+ m, _ := msgs[0].(map[string]any)
+ if m["role"] != "tool" {
+ t.Fatalf("expected tool role, got %#v", m)
+ }
+ if m["tool_call_id"] != "call_123" {
+ t.Fatalf("expected tool_call_id propagated, got %#v", m)
+ }
+}
+
+func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) {
+ msgs := normalizeResponsesInputAsMessages([]any{
+ map[string]any{
+ "type": "function_call",
+ "call_id": "call_456",
+ "name": "search",
+ "arguments": `{"q":"golang"}`,
+ },
+ })
+ if len(msgs) != 1 {
+ t.Fatalf("expected one message, got %d", len(msgs))
+ }
+ m, _ := msgs[0].(map[string]any)
+ if m["role"] != "assistant" {
+ t.Fatalf("expected assistant role, got %#v", m["role"])
+ }
+ toolCalls, _ := m["tool_calls"].([]any)
+ if len(toolCalls) != 1 {
+ t.Fatalf("expected one tool_call, got %#v", m["tool_calls"])
+ }
+ call, _ := toolCalls[0].(map[string]any)
+ if call["id"] != "call_456" {
+ t.Fatalf("expected call id preserved, got %#v", call)
+ }
+ if call["type"] != "function" {
+ t.Fatalf("expected function type, got %#v", call)
+ }
+ fn, _ := call["function"].(map[string]any)
+	if fn["name"] != "search" {
+		t.Fatalf("expected function name preserved, got %#v", fn)
+	}
+	if fn["arguments"] != `{"q":"golang"}` {
+		t.Fatalf("expected function arguments preserved, got %#v", fn)
+	}
+}
+
func TestExtractEmbeddingInputs(t *testing.T) {
got := extractEmbeddingInputs([]any{"a", "b"})
if len(got) != 2 || got[0] != "a" || got[1] != "b" {
diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go
index e767b2b..e71cafe 100644
--- a/internal/adapter/openai/responses_handler.go
+++ b/internal/adapter/openai/responses_handler.go
@@ -68,7 +68,7 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
return
}
- stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req)
+ stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req, requestTraceID(r))
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return
@@ -128,7 +128,7 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
rc := http.NewResponseController(w)
- canFlush := rc.Flush() == nil
+ _, canFlush := w.(http.Flusher)
initialType := "text"
if thinkingEnabled {
@@ -203,40 +203,231 @@ func normalizeResponsesInputAsMessages(input any) []any {
}
return []any{map[string]any{"role": "user", "content": v}}
case []any:
- if len(v) == 0 {
- return nil
- }
- // If caller already provides role-shaped items, keep as-is.
- if first, ok := v[0].(map[string]any); ok {
- if _, hasRole := first["role"]; hasRole {
- return v
- }
- }
- parts := make([]string, 0, len(v))
- for _, item := range v {
- if m, ok := item.(map[string]any); ok {
- if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") {
- if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
- parts = append(parts, txt)
- continue
- }
- }
- }
- if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" {
- parts = append(parts, s)
- }
- }
- if len(parts) == 0 {
- return nil
- }
- return []any{map[string]any{"role": "user", "content": strings.Join(parts, "\n")}}
+ return normalizeResponsesInputArray(v)
case map[string]any:
+ if msg := normalizeResponsesInputItem(v); msg != nil {
+ return []any{msg}
+ }
if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" {
return []any{map[string]any{"role": "user", "content": txt}}
}
- if content, ok := v["content"].(string); ok && strings.TrimSpace(content) != "" {
- return []any{map[string]any{"role": "user", "content": content}}
+ if content, ok := v["content"]; ok {
+ if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
+ return []any{map[string]any{"role": "user", "content": content}}
+ }
}
}
return nil
}
+
+func normalizeResponsesInputArray(items []any) []any {
+ if len(items) == 0 {
+ return nil
+ }
+ out := make([]any, 0, len(items))
+ fallbackParts := make([]string, 0, len(items))
+ flushFallback := func() {
+ if len(fallbackParts) == 0 {
+ return
+ }
+ out = append(out, map[string]any{"role": "user", "content": strings.Join(fallbackParts, "\n")})
+ fallbackParts = fallbackParts[:0]
+ }
+
+ for _, item := range items {
+ switch x := item.(type) {
+ case map[string]any:
+ if msg := normalizeResponsesInputItem(x); msg != nil {
+ flushFallback()
+ out = append(out, msg)
+ continue
+ }
+ if s := normalizeResponsesFallbackPart(x); s != "" {
+ fallbackParts = append(fallbackParts, s)
+ }
+ default:
+ if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" {
+ fallbackParts = append(fallbackParts, s)
+ }
+ }
+ }
+ flushFallback()
+ if len(out) == 0 {
+ return nil
+ }
+ return out
+}
+
+func normalizeResponsesInputItem(m map[string]any) map[string]any {
+ if m == nil {
+ return nil
+ }
+
+ role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
+ if role != "" {
+ content := m["content"]
+ if content == nil {
+ if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
+ content = txt
+ }
+ }
+ if content == nil {
+ return nil
+ }
+ return map[string]any{
+ "role": role,
+ "content": content,
+ }
+ }
+
+ itemType := strings.ToLower(strings.TrimSpace(asString(m["type"])))
+ switch itemType {
+ case "message", "input_message":
+ content := m["content"]
+ if content == nil {
+ if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
+ content = txt
+ }
+ }
+ if content == nil {
+ return nil
+ }
+		// role is always empty here: the role-shaped branch above already returned.
+		role := "user"
+ return map[string]any{
+ "role": role,
+ "content": content,
+ }
+ case "function_call_output", "tool_result":
+ content := m["output"]
+ if content == nil {
+ content = m["content"]
+ }
+ if content == nil {
+ content = ""
+ }
+ out := map[string]any{
+ "role": "tool",
+ "content": content,
+ }
+ if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
+ out["tool_call_id"] = callID
+ } else if callID = strings.TrimSpace(asString(m["tool_call_id"])); callID != "" {
+ out["tool_call_id"] = callID
+ }
+ if name := strings.TrimSpace(asString(m["name"])); name != "" {
+ out["name"] = name
+ } else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" {
+ out["name"] = name
+ }
+ return out
+ case "function_call", "tool_call":
+ name := strings.TrimSpace(asString(m["name"]))
+ var fn map[string]any
+ if rawFn, ok := m["function"].(map[string]any); ok {
+ fn = rawFn
+ if name == "" {
+ name = strings.TrimSpace(asString(fn["name"]))
+ }
+ }
+ if name == "" {
+ return nil
+ }
+
+ var argsRaw any
+ if v, ok := m["arguments"]; ok {
+ argsRaw = v
+ } else if v, ok := m["input"]; ok {
+ argsRaw = v
+ }
+ if argsRaw == nil && fn != nil {
+ if v, ok := fn["arguments"]; ok {
+ argsRaw = v
+ } else if v, ok := fn["input"]; ok {
+ argsRaw = v
+ }
+ }
+
+ functionPayload := map[string]any{
+ "name": name,
+ "arguments": stringifyToolCallArguments(argsRaw),
+ }
+ call := map[string]any{
+ "type": "function",
+ "function": functionPayload,
+ }
+ if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
+ call["id"] = callID
+ } else if callID = strings.TrimSpace(asString(m["id"])); callID != "" {
+ call["id"] = callID
+ }
+ return map[string]any{
+ "role": "assistant",
+ "tool_calls": []any{call},
+ }
+ case "input_text":
+ if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
+ return map[string]any{
+ "role": "user",
+ "content": txt,
+ }
+ }
+ }
+
+ if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
+ return map[string]any{
+ "role": "user",
+ "content": txt,
+ }
+ }
+ if content, ok := m["content"]; ok {
+ if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
+ return map[string]any{
+ "role": "user",
+ "content": content,
+ }
+ }
+ }
+ return nil
+}
+
+func normalizeResponsesFallbackPart(m map[string]any) string {
+ if m == nil {
+ return ""
+ }
+ if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") {
+ if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
+ return txt
+ }
+ }
+ if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
+ return txt
+ }
+ if content, ok := m["content"]; ok {
+ if normalized := strings.TrimSpace(normalizeOpenAIContentForPrompt(content)); normalized != "" {
+ return normalized
+ }
+ }
+ return strings.TrimSpace(fmt.Sprintf("%v", m))
+}
+
+func stringifyToolCallArguments(v any) string {
+ switch x := v.(type) {
+ case nil:
+ return "{}"
+ case string:
+ s := strings.TrimSpace(x)
+ if s == "" {
+ return "{}"
+ }
+ return s
+ default:
+ b, err := json.Marshal(x)
+ if err != nil || len(b) == 0 {
+ return "{}"
+ }
+ return string(b)
+ }
+}
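The `stringifyToolCallArguments` helper added above collapses three possible argument shapes (absent, JSON string, decoded object) into a single JSON string with `"{}"` as the fallback. Extracted as a standalone sketch (the function body is copied from the patch; the `main` driver is illustrative only), it behaves like this:

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// stringifyToolCallArguments normalizes tool-call arguments to a JSON string:
// nil and blank strings become "{}", JSON strings pass through trimmed, and
// any other decoded value is re-marshalled.
func stringifyToolCallArguments(v any) string {
	switch x := v.(type) {
	case nil:
		return "{}"
	case string:
		s := strings.TrimSpace(x)
		if s == "" {
			return "{}"
		}
		return s
	default:
		b, err := json.Marshal(x)
		if err != nil || len(b) == 0 {
			return "{}"
		}
		return string(b)
	}
}

func main() {
	fmt.Println(stringifyToolCallArguments(nil))                            // {}
	fmt.Println(stringifyToolCallArguments(`{"q":"golang"}`))               // {"q":"golang"}
	fmt.Println(stringifyToolCallArguments(map[string]any{"path": "a.md"})) // {"path":"a.md"}
}
```

This is why `TestNormalizeResponsesInputAsMessagesFunctionCallItem` can assert the exact argument string: whatever shape the Responses item carried, the normalized `tool_calls` entry always holds a JSON string.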
diff --git a/internal/adapter/openai/responses_stream_runtime.go b/internal/adapter/openai/responses_stream_runtime.go
index f7e8b20..050965c 100644
--- a/internal/adapter/openai/responses_stream_runtime.go
+++ b/internal/adapter/openai/responses_stream_runtime.go
@@ -3,12 +3,15 @@ package openai
import (
"encoding/json"
"net/http"
+ "sort"
"strings"
openaifmt "ds2api/internal/format/openai"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
"ds2api/internal/util"
+
+ "github.com/google/uuid"
)
type responsesStreamRuntime struct {
@@ -24,14 +27,20 @@ type responsesStreamRuntime struct {
thinkingEnabled bool
searchEnabled bool
- bufferToolContent bool
- emitEarlyToolDeltas bool
- toolCallsEmitted bool
+ bufferToolContent bool
+ emitEarlyToolDeltas bool
+ toolCallsEmitted bool
+ toolCallsDoneEmitted bool
sieve toolStreamSieveState
+ thinkingSieve toolStreamSieveState
thinking strings.Builder
text strings.Builder
streamToolCallIDs map[int]string
+ streamFunctionIDs map[int]string
+ functionDone map[int]bool
+ toolCallsDoneSigs map[string]bool
+ reasoningItemID string
persistResponse func(obj map[string]any)
}
@@ -63,6 +72,9 @@ func newResponsesStreamRuntime(
bufferToolContent: bufferToolContent,
emitEarlyToolDeltas: emitEarlyToolDeltas,
streamToolCallIDs: map[int]string{},
+ streamFunctionIDs: map[int]string{},
+ functionDone: map[int]bool{},
+ toolCallsDoneSigs: map[string]bool{},
persistResponse: persistResponse,
}
}
@@ -92,19 +104,33 @@ func (s *responsesStreamRuntime) sendDone() {
func (s *responsesStreamRuntime) finalize() {
finalThinking := s.thinking.String()
finalText := s.text.String()
+ if strings.TrimSpace(finalThinking) != "" {
+ s.sendEvent("response.reasoning_text.done", openaifmt.BuildResponsesReasoningTextDonePayload(s.responseID, s.ensureReasoningItemID(), 0, 0, finalThinking))
+ }
if s.bufferToolContent {
- for _, evt := range flushToolSieve(&s.sieve, s.toolNames) {
- if evt.Content != "" {
- s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
- }
- if len(evt.ToolCalls) > 0 {
- s.toolCallsEmitted = true
- s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls)))
+ s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
+ s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false)
+ }
+ // Compatibility fallback: some streams only emit incremental tool deltas.
+ // Ensure final function_call_arguments.done is emitted at least once.
+ if s.toolCallsEmitted {
+ detected := util.ParseToolCalls(finalText, s.toolNames)
+ if len(detected) == 0 {
+ detected = util.ParseToolCalls(finalThinking, s.toolNames)
+ }
+ if len(detected) > 0 {
+ if !s.toolCallsDoneEmitted {
+ s.emitToolCallsDone(detected)
+ } else {
+ s.emitFunctionCallDoneEvents(detected)
}
}
}
obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames)
+ if s.toolCallsEmitted {
+ s.alignCompletedOutputCallIDs(obj)
+ }
if s.toolCallsEmitted {
obj["status"] = "completed"
}
@@ -138,6 +164,10 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
}
s.thinking.WriteString(p.Text)
s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
+ s.sendEvent("response.reasoning_text.delta", openaifmt.BuildResponsesReasoningTextDeltaPayload(s.responseID, s.ensureReasoningItemID(), 0, 0, p.Text))
+ if s.bufferToolContent {
+ s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false)
+ }
continue
}
@@ -146,23 +176,191 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text))
continue
}
- for _, evt := range processToolSieveChunk(&s.sieve, p.Text, s.toolNames) {
- if evt.Content != "" {
- s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
- }
- if len(evt.ToolCallDeltas) > 0 {
- if !s.emitEarlyToolDeltas {
- continue
- }
- s.toolCallsEmitted = true
- s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)))
- }
- if len(evt.ToolCalls) > 0 {
- s.toolCallsEmitted = true
- s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls)))
- }
- }
+ s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true)
}
return streamengine.ParsedDecision{ContentSeen: contentSeen}
}
+
+func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
+ for _, evt := range events {
+ if emitContent && evt.Content != "" {
+ s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
+ }
+ if len(evt.ToolCallDeltas) > 0 {
+ if !s.emitEarlyToolDeltas {
+ continue
+ }
+ formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
+ if len(formatted) == 0 {
+ continue
+ }
+ s.toolCallsEmitted = true
+ s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatted))
+ s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
+ }
+ if len(evt.ToolCalls) > 0 {
+ s.emitToolCallsDone(evt.ToolCalls)
+ }
+ }
+}
+
+func (s *responsesStreamRuntime) emitToolCallsDone(calls []util.ParsedToolCall) {
+ if len(calls) == 0 {
+ return
+ }
+ sig := toolCallListSignature(calls)
+ if sig != "" && s.toolCallsDoneSigs[sig] {
+ return
+ }
+ if sig != "" {
+ s.toolCallsDoneSigs[sig] = true
+ }
+ formatted := formatFinalStreamToolCallsWithStableIDs(calls, s.streamToolCallIDs)
+ if len(formatted) == 0 {
+ return
+ }
+ s.toolCallsEmitted = true
+ s.toolCallsDoneEmitted = true
+ s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatted))
+ s.emitFunctionCallDoneEvents(calls)
+}
+
+func (s *responsesStreamRuntime) ensureReasoningItemID() string {
+ if strings.TrimSpace(s.reasoningItemID) != "" {
+ return s.reasoningItemID
+ }
+ s.reasoningItemID = "rs_" + strings.ReplaceAll(uuid.NewString(), "-", "")
+ return s.reasoningItemID
+}
+
+func (s *responsesStreamRuntime) ensureFunctionItemID(index int) string {
+ if id, ok := s.streamFunctionIDs[index]; ok && strings.TrimSpace(id) != "" {
+ return id
+ }
+ id := "fc_" + strings.ReplaceAll(uuid.NewString(), "-", "")
+ s.streamFunctionIDs[index] = id
+ return id
+}
+
+func (s *responsesStreamRuntime) ensureToolCallID(index int) string {
+ if id, ok := s.streamToolCallIDs[index]; ok && strings.TrimSpace(id) != "" {
+ return id
+ }
+ id := "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
+ s.streamToolCallIDs[index] = id
+ return id
+}
+
+func (s *responsesStreamRuntime) functionOutputBaseIndex() int {
+ if strings.TrimSpace(s.thinking.String()) != "" {
+ return 1
+ }
+ return 0
+}
+
+func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDelta) {
+ for _, d := range deltas {
+ if strings.TrimSpace(d.Arguments) == "" {
+ continue
+ }
+ outputIndex := s.functionOutputBaseIndex() + d.Index
+ itemID := s.ensureFunctionItemID(outputIndex)
+ callID := s.ensureToolCallID(d.Index)
+ s.sendEvent(
+ "response.function_call_arguments.delta",
+ openaifmt.BuildResponsesFunctionCallArgumentsDeltaPayload(s.responseID, itemID, outputIndex, callID, d.Arguments),
+ )
+ }
+}
+
+func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedToolCall) {
+ base := s.functionOutputBaseIndex()
+ for idx, tc := range calls {
+ if strings.TrimSpace(tc.Name) == "" {
+ continue
+ }
+ outputIndex := base + idx
+ if s.functionDone[outputIndex] {
+ continue
+ }
+ itemID := s.ensureFunctionItemID(outputIndex)
+ callID := s.ensureToolCallID(idx)
+ argsBytes, _ := json.Marshal(tc.Input)
+ s.sendEvent(
+ "response.function_call_arguments.done",
+ openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, string(argsBytes)),
+ )
+ s.functionDone[outputIndex] = true
+ }
+}
+
+func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any) {
+ if obj == nil || len(s.streamToolCallIDs) == 0 {
+ return
+ }
+ output, _ := obj["output"].([]any)
+ if len(output) == 0 {
+ return
+ }
+ indices := make([]int, 0, len(s.streamToolCallIDs))
+ for idx := range s.streamToolCallIDs {
+ indices = append(indices, idx)
+ }
+ sort.Ints(indices)
+ ordered := make([]string, 0, len(indices))
+ for _, idx := range indices {
+ id := strings.TrimSpace(s.streamToolCallIDs[idx])
+ if id == "" {
+ continue
+ }
+ ordered = append(ordered, id)
+ }
+ if len(ordered) == 0 {
+ return
+ }
+
+ functionIdx := 0
+ for _, item := range output {
+ m, _ := item.(map[string]any)
+ if m == nil {
+ continue
+ }
+ typ, _ := m["type"].(string)
+ switch typ {
+ case "function_call":
+ if functionIdx < len(ordered) {
+ m["call_id"] = ordered[functionIdx]
+ functionIdx++
+ }
+ case "tool_calls":
+ tcArr, _ := m["tool_calls"].([]any)
+ for i, raw := range tcArr {
+ tc, _ := raw.(map[string]any)
+ if tc == nil {
+ continue
+ }
+ if i < len(ordered) {
+ tc["id"] = ordered[i]
+ }
+ }
+ }
+ }
+}
+
+func toolCallListSignature(calls []util.ParsedToolCall) string {
+ if len(calls) == 0 {
+ return ""
+ }
+ var b strings.Builder
+ for i, tc := range calls {
+ if i > 0 {
+ b.WriteString("|")
+ }
+ b.WriteString(strings.TrimSpace(tc.Name))
+ b.WriteString(":")
+ args, _ := json.Marshal(tc.Input)
+ b.Write(args)
+ }
+ return b.String()
+}
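The signature-based dedup above exists because the same tool-call batch can surface twice: once when the sieve flushes and again in the finalize fallback that re-parses the accumulated text. A standalone sketch of that guard (`ParsedToolCall` here is a minimal stand-in for `util.ParsedToolCall`; Go's `json.Marshal` sorts map keys, which keeps the signature stable):

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// ParsedToolCall is a stand-in for util.ParsedToolCall: a tool name plus its
// decoded input arguments.
type ParsedToolCall struct {
	Name  string
	Input any
}

// toolCallListSignature builds a stable "name:args|name:args" key for a batch
// of tool calls so a duplicate done-event for the same batch can be dropped.
func toolCallListSignature(calls []ParsedToolCall) string {
	if len(calls) == 0 {
		return ""
	}
	var b strings.Builder
	for i, tc := range calls {
		if i > 0 {
			b.WriteString("|")
		}
		b.WriteString(strings.TrimSpace(tc.Name))
		b.WriteString(":")
		args, _ := json.Marshal(tc.Input)
		b.Write(args)
	}
	return b.String()
}

func main() {
	batch := []ParsedToolCall{{Name: "read_file", Input: map[string]any{"path": "README.MD"}}}
	seen := map[string]bool{}
	// The sieve flush and the finalize fallback may both report the batch.
	for i := 0; i < 2; i++ {
		sig := toolCallListSignature(batch)
		if seen[sig] {
			fmt.Println("duplicate batch suppressed")
			continue
		}
		seen[sig] = true
		fmt.Println("emit done for:", sig)
	}
}
```

Note the signature is keyed on name plus marshalled arguments, so two calls to the same tool with different inputs still produce distinct done events.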
diff --git a/internal/adapter/openai/responses_stream_test.go b/internal/adapter/openai/responses_stream_test.go
index 9b0a5ac..a513e6f 100644
--- a/internal/adapter/openai/responses_stream_test.go
+++ b/internal/adapter/openai/responses_stream_test.go
@@ -45,17 +45,37 @@ func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T
if len(output) == 0 {
t.Fatalf("expected structured output entries, got %#v", responseObj["output"])
}
- first, _ := output[0].(map[string]any)
- if first["type"] != "tool_calls" {
- t.Fatalf("expected first output type tool_calls, got %#v", first["type"])
+ var firstToolWrapper map[string]any
+ hasFunctionCall := false
+ for _, item := range output {
+ m, _ := item.(map[string]any)
+ if m == nil {
+ continue
+ }
+ if m["type"] == "function_call" {
+ hasFunctionCall = true
+ }
+ if m["type"] == "tool_calls" && firstToolWrapper == nil {
+ firstToolWrapper = m
+ }
}
- toolCalls, _ := first["tool_calls"].([]any)
+ if !hasFunctionCall {
+ t.Fatalf("expected at least one function_call item for responses compatibility, got %#v", responseObj["output"])
+ }
+ if firstToolWrapper == nil {
+ t.Fatalf("expected a tool_calls wrapper item, got %#v", responseObj["output"])
+ }
+ toolCalls, _ := firstToolWrapper["tool_calls"].([]any)
if len(toolCalls) == 0 {
- t.Fatalf("expected at least one tool_call in output, got %#v", first["tool_calls"])
+ t.Fatalf("expected at least one tool_call in output, got %#v", firstToolWrapper["tool_calls"])
}
call0, _ := toolCalls[0].(map[string]any)
- if call0["name"] != "read_file" {
- t.Fatalf("unexpected tool call name: %#v", call0["name"])
+ if call0["type"] != "function" {
+ t.Fatalf("unexpected tool call type: %#v", call0["type"])
+ }
+ fn, _ := call0["function"].(map[string]any)
+ if fn["name"] != "read_file" {
+ t.Fatalf("unexpected tool call name: %#v", fn["name"])
}
if strings.Contains(outputText, `"tool_calls"`) {
t.Fatalf("raw tool_calls JSON leaked in output_text: %q", outputText)
@@ -95,6 +115,314 @@ func TestHandleResponsesStreamIncompleteTailNotDuplicatedInCompletedOutputText(t
}
}
+func TestHandleResponsesStreamEmitsReasoningCompatEvents(t *testing.T) {
+ h := &Handler{}
+ req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
+ rec := httptest.NewRecorder()
+
+ b, _ := json.Marshal(map[string]any{
+ "p": "response/thinking_content",
+ "v": "thought",
+ })
+ streamBody := "data: " + string(b) + "\n" + "data: [DONE]\n"
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(streamBody)),
+ }
+
+ h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, nil)
+
+ body := rec.Body.String()
+ if !strings.Contains(body, "event: response.reasoning.delta") {
+ t.Fatalf("expected response.reasoning.delta event, body=%s", body)
+ }
+ if !strings.Contains(body, "event: response.reasoning_text.delta") {
+ t.Fatalf("expected response.reasoning_text.delta compatibility event, body=%s", body)
+ }
+ if !strings.Contains(body, "event: response.reasoning_text.done") {
+ t.Fatalf("expected response.reasoning_text.done compatibility event, body=%s", body)
+ }
+}
+
+func TestHandleResponsesStreamEmitsFunctionCallCompatEvents(t *testing.T) {
+ h := &Handler{}
+ req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
+ rec := httptest.NewRecorder()
+
+ sseLine := func(v string) string {
+ b, _ := json.Marshal(map[string]any{
+ "p": "response/content",
+ "v": v,
+ })
+ return "data: " + string(b) + "\n"
+ }
+
+ streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n"
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(streamBody)),
+ }
+
+ h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
+ body := rec.Body.String()
+ if !strings.Contains(body, "event: response.function_call_arguments.delta") {
+ t.Fatalf("expected response.function_call_arguments.delta compatibility event, body=%s", body)
+ }
+ if !strings.Contains(body, "event: response.function_call_arguments.done") {
+ t.Fatalf("expected response.function_call_arguments.done compatibility event, body=%s", body)
+ }
+ donePayload, ok := extractSSEEventPayload(body, "response.function_call_arguments.done")
+ if !ok {
+ t.Fatalf("expected to parse response.function_call_arguments.done payload, body=%s", body)
+ }
+	doneCallID := strings.TrimSpace(asString(donePayload["call_id"]))
+	if doneCallID == "" {
+		t.Fatalf("expected non-empty call_id in response.function_call_arguments.done payload, payload=%#v", donePayload)
+	}
+	if strings.TrimSpace(asString(donePayload["response_id"])) == "" {
+		t.Fatalf("expected response_id in response.function_call_arguments.done payload, payload=%#v", donePayload)
+	}
+ completed, ok := extractSSEEventPayload(body, "response.completed")
+ if !ok {
+ t.Fatalf("expected response.completed payload, body=%s", body)
+ }
+ responseObj, _ := completed["response"].(map[string]any)
+ output, _ := responseObj["output"].([]any)
+ if len(output) == 0 {
+ t.Fatalf("expected non-empty output in response.completed, response=%#v", responseObj)
+ }
+ var completedCallID string
+ for _, item := range output {
+ m, _ := item.(map[string]any)
+ if m == nil || m["type"] != "function_call" {
+ continue
+ }
+ completedCallID = strings.TrimSpace(asString(m["call_id"]))
+ if completedCallID != "" {
+ break
+ }
+ }
+ if completedCallID == "" {
+ t.Fatalf("expected function_call.call_id in completed output, output=%#v", output)
+ }
+ if completedCallID != doneCallID {
+ t.Fatalf("expected completed call_id to match stream done call_id, done=%q completed=%q", doneCallID, completedCallID)
+ }
+}
+
+func TestHandleResponsesStreamDetectsToolCallsFromThinkingChannel(t *testing.T) {
+ h := &Handler{}
+ req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
+ rec := httptest.NewRecorder()
+
+ sseLine := func(path, v string) string {
+ b, _ := json.Marshal(map[string]any{
+ "p": path,
+ "v": v,
+ })
+ return "data: " + string(b) + "\n"
+ }
+
+ streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n"
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(streamBody)),
+ }
+
+ h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"})
+
+ body := rec.Body.String()
+ if !strings.Contains(body, "event: response.reasoning_text.delta") {
+ t.Fatalf("expected response.reasoning_text.delta event, body=%s", body)
+ }
+ if !strings.Contains(body, "event: response.function_call_arguments.done") {
+ t.Fatalf("expected response.function_call_arguments.done event from thinking channel, body=%s", body)
+ }
+ if !strings.Contains(body, "event: response.output_tool_call.done") {
+ t.Fatalf("expected response.output_tool_call.done event from thinking channel, body=%s", body)
+ }
+}
+
+func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.T) {
+ h := &Handler{}
+ req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
+ rec := httptest.NewRecorder()
+
+ sseLine := func(v string) string {
+ b, _ := json.Marshal(map[string]any{
+ "p": "response/content",
+ "v": v,
+ })
+ return "data: " + string(b) + "\n"
+ }
+
+ streamBody := sseLine(`{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) +
+ sseLine(`{"name":"eval_javascript","input":{"code":"1+1"}}]}`) +
+ "data: [DONE]\n"
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(streamBody)),
+ }
+
+ h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"})
+
+ body := rec.Body.String()
+ if !strings.Contains(body, "event: response.output_tool_call.done") {
+ t.Fatalf("expected response.output_tool_call.done event, body=%s", body)
+ }
+ donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
+ if len(donePayloads) != 2 {
+ t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
+ }
+
+ seenNames := map[string]string{}
+ for _, payload := range donePayloads {
+ name := strings.TrimSpace(asString(payload["name"]))
+ callID := strings.TrimSpace(asString(payload["call_id"]))
+ args := strings.TrimSpace(asString(payload["arguments"]))
+ if callID == "" {
+ t.Fatalf("expected non-empty call_id in done payload: %#v", payload)
+ }
+ if strings.Contains(args, `}{"`) {
+ t.Fatalf("unexpected concatenated arguments in done payload: %#v", payload)
+ }
+ if name == "search_webeval_javascript" {
+ t.Fatalf("unexpected merged tool name in done payload: %#v", payload)
+ }
+ if name != "search_web" && name != "eval_javascript" {
+ t.Fatalf("unexpected tool name in done payload: %#v", payload)
+ }
+ seenNames[name] = callID
+ }
+ if seenNames["search_web"] == "" || seenNames["eval_javascript"] == "" {
+ t.Fatalf("expected done events for both tools, got %#v", seenNames)
+ }
+ if seenNames["search_web"] == seenNames["eval_javascript"] {
+ t.Fatalf("expected distinct call_id per tool, got %#v", seenNames)
+ }
+
+ completed, ok := extractSSEEventPayload(body, "response.completed")
+ if !ok {
+ t.Fatalf("expected response.completed event, body=%s", body)
+ }
+ responseObj, _ := completed["response"].(map[string]any)
+ output, _ := responseObj["output"].([]any)
+ functionCallIDs := map[string]string{}
+ for _, item := range output {
+ m, _ := item.(map[string]any)
+ if m == nil || m["type"] != "function_call" {
+ continue
+ }
+ name := strings.TrimSpace(asString(m["name"]))
+ callID := strings.TrimSpace(asString(m["call_id"]))
+ if name != "" && callID != "" {
+ functionCallIDs[name] = callID
+ }
+ }
+ if functionCallIDs["search_web"] != seenNames["search_web"] {
+ t.Fatalf("search_web call_id mismatch between done and completed: done=%q completed=%q", seenNames["search_web"], functionCallIDs["search_web"])
+ }
+ if functionCallIDs["eval_javascript"] != seenNames["eval_javascript"] {
+ t.Fatalf("eval_javascript call_id mismatch between done and completed: done=%q completed=%q", seenNames["eval_javascript"], functionCallIDs["eval_javascript"])
+ }
+}
+
+func TestHandleResponsesStreamMultiToolCallFromThinkingChannel(t *testing.T) {
+ h := &Handler{}
+ req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
+ rec := httptest.NewRecorder()
+
+ sseLine := func(path, v string) string {
+ b, _ := json.Marshal(map[string]any{
+ "p": path,
+ "v": v,
+ })
+ return "data: " + string(b) + "\n"
+ }
+
+ streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) +
+ sseLine("response/thinking_content", `{"name":"eval_javascript","input":{"code":"1+1"}}]}`) +
+ "data: [DONE]\n"
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(streamBody)),
+ }
+
+ h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"search_web", "eval_javascript"})
+
+ body := rec.Body.String()
+ if !strings.Contains(body, "event: response.reasoning_text.delta") {
+ t.Fatalf("expected reasoning stream events, body=%s", body)
+ }
+ donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
+ if len(donePayloads) != 2 {
+ t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
+ }
+ seen := map[string]bool{}
+ for _, payload := range donePayloads {
+ name := strings.TrimSpace(asString(payload["name"]))
+ if name == "search_webeval_javascript" {
+ t.Fatalf("unexpected merged tool name in thinking channel done payload: %#v", payload)
+ }
+ if name != "search_web" && name != "eval_javascript" {
+ t.Fatalf("unexpected tool name in thinking channel done payload: %#v", payload)
+ }
+ args := strings.TrimSpace(asString(payload["arguments"]))
+ if strings.Contains(args, `}{"`) {
+ t.Fatalf("unexpected concatenated arguments in thinking channel done payload: %#v", payload)
+ }
+ seen[name] = true
+ }
+ if !seen["search_web"] || !seen["eval_javascript"] {
+ t.Fatalf("expected both tools in thinking channel done events, got %#v", seen)
+ }
+}
+
+func TestHandleResponsesStreamCompletedFollowsChatToolCallSemantics(t *testing.T) {
+ h := &Handler{}
+ req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
+ rec := httptest.NewRecorder()
+
+ sseLine := func(v string) string {
+ b, _ := json.Marshal(map[string]any{
+ "p": "response/content",
+ "v": v,
+ })
+ return "data: " + string(b) + "\n"
+ }
+
+ streamBody := sseLine("我来调用工具\n") +
+ sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) +
+ "data: [DONE]\n"
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(streamBody)),
+ }
+
+ h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
+
+ completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
+ if !ok {
+ t.Fatalf("expected response.completed event, body=%s", rec.Body.String())
+ }
+ responseObj, _ := completed["response"].(map[string]any)
+ output, _ := responseObj["output"].([]any)
+ hasFunctionCall := false
+ for _, item := range output {
+ m, _ := item.(map[string]any)
+ if m != nil && m["type"] == "function_call" {
+ hasFunctionCall = true
+ break
+ }
+ }
+ if !hasFunctionCall {
+ t.Fatalf("expected completed output to include function_call when mixed prose contains tool_calls payload, output=%#v", output)
+ }
+}
+
func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
scanner := bufio.NewScanner(strings.NewReader(body))
matched := false
@@ -120,3 +448,30 @@ func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
}
return nil, false
}
+
+func extractAllSSEEventPayloads(body, targetEvent string) []map[string]any {
+ scanner := bufio.NewScanner(strings.NewReader(body))
+ matched := false
+ out := make([]map[string]any, 0, 2)
+ for scanner.Scan() {
+ line := strings.TrimSpace(scanner.Text())
+ if strings.HasPrefix(line, "event: ") {
+ evt := strings.TrimSpace(strings.TrimPrefix(line, "event: "))
+ matched = evt == targetEvent
+ continue
+ }
+ if !matched || !strings.HasPrefix(line, "data: ") {
+ continue
+ }
+ raw := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
+ if raw == "" || raw == "[DONE]" {
+ continue
+ }
+ var payload map[string]any
+ if err := json.Unmarshal([]byte(raw), &payload); err != nil {
+ continue
+ }
+ out = append(out, payload)
+ }
+ return out
+}
diff --git a/internal/adapter/openai/standard_request.go b/internal/adapter/openai/standard_request.go
index 5883d03..7683ee7 100644
--- a/internal/adapter/openai/standard_request.go
+++ b/internal/adapter/openai/standard_request.go
@@ -8,7 +8,7 @@ import (
"ds2api/internal/util"
)
-func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.StandardRequest, error) {
+func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID string) (util.StandardRequest, error) {
model, _ := req["model"].(string)
messagesRaw, _ := req["messages"].([]any)
if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 {
@@ -23,7 +23,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.St
if responseModel == "" {
responseModel = resolvedModel
}
- finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
+ finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID)
passThrough := collectOpenAIChatPassThrough(req)
return util.StandardRequest{
@@ -41,7 +41,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.St
}, nil
}
-func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any) (util.StandardRequest, error) {
+func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, traceID string) (util.StandardRequest, error) {
model, _ := req["model"].(string)
model = strings.TrimSpace(model)
if model == "" {
@@ -67,7 +67,7 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any) (ut
if len(messagesRaw) == 0 {
return util.StandardRequest{}, fmt.Errorf("Request must include 'input' or 'messages'.")
}
- finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
+ finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID)
passThrough := collectOpenAIChatPassThrough(req)
return util.StandardRequest{
diff --git a/internal/adapter/openai/standard_request_test.go b/internal/adapter/openai/standard_request_test.go
index f3453a2..a876364 100644
--- a/internal/adapter/openai/standard_request_test.go
+++ b/internal/adapter/openai/standard_request_test.go
@@ -22,7 +22,7 @@ func TestNormalizeOpenAIChatRequest(t *testing.T) {
"temperature": 0.3,
"stream": true,
}
- n, err := normalizeOpenAIChatRequest(store, req)
+ n, err := normalizeOpenAIChatRequest(store, req, "")
if err != nil {
t.Fatalf("normalize failed: %v", err)
}
@@ -47,7 +47,7 @@ func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) {
"input": "ping",
"instructions": "system",
}
- n, err := normalizeOpenAIResponsesRequest(store, req)
+ n, err := normalizeOpenAIResponsesRequest(store, req, "")
if err != nil {
t.Fatalf("normalize failed: %v", err)
}
diff --git a/internal/adapter/openai/stream_status_test.go b/internal/adapter/openai/stream_status_test.go
new file mode 100644
index 0000000..4f8305a
--- /dev/null
+++ b/internal/adapter/openai/stream_status_test.go
@@ -0,0 +1,185 @@
+package openai
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/go-chi/chi/v5"
+ chimw "github.com/go-chi/chi/v5/middleware"
+
+ "ds2api/internal/auth"
+)
+
+type streamStatusAuthStub struct{}
+
+func (streamStatusAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) {
+ return &auth.RequestAuth{
+ UseConfigToken: false,
+ DeepSeekToken: "direct-token",
+ CallerID: "caller:test",
+ TriedAccounts: map[string]bool{},
+ }, nil
+}
+
+func (streamStatusAuthStub) DetermineCaller(_ *http.Request) (*auth.RequestAuth, error) {
+ return &auth.RequestAuth{
+ UseConfigToken: false,
+ DeepSeekToken: "direct-token",
+ CallerID: "caller:test",
+ TriedAccounts: map[string]bool{},
+ }, nil
+}
+
+func (streamStatusAuthStub) Release(_ *auth.RequestAuth) {}
+
+type streamStatusDSStub struct {
+ resp *http.Response
+}
+
+func (m streamStatusDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
+ return "session-id", nil
+}
+
+func (m streamStatusDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
+ return "pow", nil
+}
+
+func (m streamStatusDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
+ return m.resp, nil
+}
+
+func makeOpenAISSEHTTPResponse(lines ...string) *http.Response {
+ body := strings.Join(lines, "\n")
+ if !strings.HasSuffix(body, "\n") {
+ body += "\n"
+ }
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Header: make(http.Header),
+ Body: io.NopCloser(strings.NewReader(body)),
+ }
+}
+
+func captureStatusMiddleware(statuses *[]int) func(http.Handler) http.Handler {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ ww := chimw.NewWrapResponseWriter(w, r.ProtoMajor)
+ next.ServeHTTP(ww, r)
+ *statuses = append(*statuses, ww.Status())
+ })
+ }
+}
+
+func TestChatCompletionsStreamStatusCapturedAs200(t *testing.T) {
+ statuses := make([]int, 0, 1)
+ h := &Handler{
+ Store: mockOpenAIConfig{wideInput: true},
+ Auth: streamStatusAuthStub{},
+ DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, "data: [DONE]")},
+ }
+ r := chi.NewRouter()
+ r.Use(captureStatusMiddleware(&statuses))
+ RegisterRoutes(r, h)
+
+ reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":"hi"}],"stream":true}`
+ req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
+ req.Header.Set("Authorization", "Bearer direct-token")
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ r.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
+ }
+ if len(statuses) != 1 {
+ t.Fatalf("expected one captured status, got %d", len(statuses))
+ }
+ if statuses[0] != http.StatusOK {
+ t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0])
+ }
+}
+
+func TestResponsesStreamStatusCapturedAs200(t *testing.T) {
+ statuses := make([]int, 0, 1)
+ h := &Handler{
+ Store: mockOpenAIConfig{wideInput: true},
+ Auth: streamStatusAuthStub{},
+ DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, "data: [DONE]")},
+ }
+ r := chi.NewRouter()
+ r.Use(captureStatusMiddleware(&statuses))
+ RegisterRoutes(r, h)
+
+ reqBody := `{"model":"deepseek-chat","input":"hi","stream":true}`
+ req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody))
+ req.Header.Set("Authorization", "Bearer direct-token")
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ r.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
+ }
+ if len(statuses) != 1 {
+ t.Fatalf("expected one captured status, got %d", len(statuses))
+ }
+ if statuses[0] != http.StatusOK {
+ t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0])
+ }
+}
+
+func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) {
+ statuses := make([]int, 0, 1)
+ content, _ := json.Marshal(map[string]any{
+ "p": "response/content",
+ "v": "我来调用工具\n{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}",
+ })
+ h := &Handler{
+ Store: mockOpenAIConfig{wideInput: true},
+ Auth: streamStatusAuthStub{},
+ DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse("data: "+string(content), "data: [DONE]")},
+ }
+ r := chi.NewRouter()
+ r.Use(captureStatusMiddleware(&statuses))
+ RegisterRoutes(r, h)
+
+ reqBody := `{"model":"deepseek-chat","input":"请调用工具","tools":[{"type":"function","function":{"name":"read_file","description":"read","parameters":{"type":"object","properties":{"path":{"type":"string"}}}}}],"stream":false}`
+ req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody))
+ req.Header.Set("Authorization", "Bearer direct-token")
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ r.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
+ }
+ if len(statuses) != 1 || statuses[0] != http.StatusOK {
+ t.Fatalf("expected captured status 200, got %#v", statuses)
+ }
+
+ var out map[string]any
+ if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
+ t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String())
+ }
+ outputText, _ := out["output_text"].(string)
+ if outputText != "" {
+ t.Fatalf("expected output_text hidden for tool call payload, got %q", outputText)
+ }
+ output, _ := out["output"].([]any)
+ hasFunctionCall := false
+ for _, item := range output {
+ m, _ := item.(map[string]any)
+ if m != nil && m["type"] == "function_call" {
+ hasFunctionCall = true
+ break
+ }
+ }
+ if !hasFunctionCall {
+ t.Fatalf("expected function_call output item, got %#v", output)
+ }
+}
diff --git a/internal/adapter/openai/tool_sieve.go b/internal/adapter/openai/tool_sieve.go
index fd7222b..9c46649 100644
--- a/internal/adapter/openai/tool_sieve.go
+++ b/internal/adapter/openai/tool_sieve.go
@@ -11,6 +11,7 @@ type toolStreamSieveState struct {
capture strings.Builder
capturing bool
recentTextTail string
+ disableDeltas bool
toolNameSent bool
toolName string
toolArgsStart int
@@ -35,6 +36,7 @@ const toolSieveCaptureLimit = 8 * 1024
const toolSieveContextTailLimit = 256
func (s *toolStreamSieveState) resetIncrementalToolState() {
+ s.disableDeltas = false
s.toolNameSent = false
s.toolName = ""
s.toolArgsStart = -1
@@ -239,17 +241,8 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
}
parsed := util.ParseStandaloneToolCalls(obj, toolNames)
if len(parsed) == 0 {
- if state.toolNameSent {
- return prefixPart, nil, suffixPart, true
- }
return captured, nil, "", true
}
- if state.toolNameSent {
- if len(parsed) > 1 {
- return prefixPart, parsed[1:], suffixPart, true
- }
- return prefixPart, nil, suffixPart, true
- }
return prefixPart, parsed, suffixPart, true
}
@@ -296,6 +289,9 @@ func extractJSONObjectFrom(text string, start int) (string, int, bool) {
}
func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
+ if state.disableDeltas {
+ return nil
+ }
captured := state.capture.String()
if captured == "" {
return nil
@@ -312,6 +308,16 @@ func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
if insideCodeFence(state.recentTextTail + captured[:start]) {
return nil
}
+ certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx)
+ if hasMultiple {
+ state.disableDeltas = true
+ return nil
+ }
+ if !certainSingle {
+		// In uncertain phases (e.g. the first call arrived but the array is not closed yet),
+		// avoid speculative deltas and wait for the final parsed tool_calls payload.
+ return nil
+ }
callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
if !ok {
return nil
@@ -363,6 +369,68 @@ func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
return deltas
}
+func classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) {
+ arrStart, ok := findToolCallsArrayStart(text, keyIdx)
+ if !ok {
+ return false, false
+ }
+ i := skipSpaces(text, arrStart+1)
+ if i >= len(text) || text[i] != '{' {
+ return false, false
+ }
+ count := 0
+ depth := 0
+ quote := byte(0)
+ escaped := false
+ for ; i < len(text); i++ {
+ ch := text[i]
+ if quote != 0 {
+ if escaped {
+ escaped = false
+ continue
+ }
+ if ch == '\\' {
+ escaped = true
+ continue
+ }
+ if ch == quote {
+ quote = 0
+ }
+ continue
+ }
+ if ch == '"' || ch == '\'' {
+ quote = ch
+ continue
+ }
+ if ch == '{' {
+ if depth == 0 {
+ count++
+ if count > 1 {
+ return false, true
+ }
+ }
+ depth++
+ continue
+ }
+ if ch == '}' {
+ if depth > 0 {
+ depth--
+ }
+ continue
+ }
+ if ch == ',' && depth == 0 {
+ // top-level separator means at least one more tool call exists
+ // (or is expected). Treat as multi-call and stop incremental deltas.
+ return false, true
+ }
+ if ch == ']' && depth == 0 {
+ return count == 1, false
+ }
+ }
+ // array not closed yet: still uncertain whether more calls will appear
+ return false, false
+}
+
func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) {
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
if !ok {
diff --git a/internal/adapter/openai/trace.go b/internal/adapter/openai/trace.go
new file mode 100644
index 0000000..8ea58f0
--- /dev/null
+++ b/internal/adapter/openai/trace.go
@@ -0,0 +1,21 @@
+package openai
+
+import (
+ "net/http"
+ "strings"
+
+ "github.com/go-chi/chi/v5/middleware"
+)
+
+func requestTraceID(r *http.Request) string {
+ if r == nil {
+ return ""
+ }
+ if q := strings.TrimSpace(r.URL.Query().Get("__trace_id")); q != "" {
+ return q
+ }
+ if h := strings.TrimSpace(r.Header.Get("X-Ds2-Test-Trace")); h != "" {
+ return h
+ }
+ return strings.TrimSpace(middleware.GetReqID(r.Context()))
+}
diff --git a/internal/adapter/openai/trace_test.go b/internal/adapter/openai/trace_test.go
new file mode 100644
index 0000000..cbacbf3
--- /dev/null
+++ b/internal/adapter/openai/trace_test.go
@@ -0,0 +1,47 @@
+package openai
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/go-chi/chi/v5/middleware"
+)
+
+func traceIDViaMiddleware(req *http.Request) string {
+ if req == nil {
+ return requestTraceID(nil)
+ }
+ var got string
+ h := middleware.RequestID(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
+ got = requestTraceID(r)
+ }))
+ h.ServeHTTP(httptest.NewRecorder(), req)
+ return got
+}
+
+func TestRequestTraceIDPriority(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions?__trace_id=query-trace", nil)
+ req.Header.Set("X-Ds2-Test-Trace", "header-trace")
+ got := traceIDViaMiddleware(req)
+ if got != "query-trace" {
+ t.Fatalf("expected query trace id to win, got %q", got)
+ }
+}
+
+func TestRequestTraceIDHeaderFallback(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil)
+ req.Header.Set("X-Ds2-Test-Trace", "header-trace")
+ got := traceIDViaMiddleware(req)
+ if got != "header-trace" {
+ t.Fatalf("expected header trace id to win when query missing, got %q", got)
+ }
+}
+
+func TestRequestTraceIDReqIDFallback(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil)
+ got := traceIDViaMiddleware(req)
+ if got == "" {
+ t.Fatal("expected middleware request id fallback to be non-empty")
+ }
+}
diff --git a/internal/adapter/openai/vercel_stream.go b/internal/adapter/openai/vercel_stream.go
index 65006c4..f34ea8b 100644
--- a/internal/adapter/openai/vercel_stream.go
+++ b/internal/adapter/openai/vercel_stream.go
@@ -56,7 +56,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
return
}
- stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
+ stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return
diff --git a/internal/admin/handler.go b/internal/admin/handler.go
index 829b657..c8f7702 100644
--- a/internal/admin/handler.go
+++ b/internal/admin/handler.go
@@ -36,5 +36,7 @@ func RegisterRoutes(r chi.Router, h *Handler) {
pr.Post("/vercel/sync", h.syncVercel)
pr.Get("/vercel/status", h.vercelStatus)
pr.Get("/export", h.exportConfig)
+ pr.Get("/dev/captures", h.getDevCaptures)
+ pr.Delete("/dev/captures", h.clearDevCaptures)
})
}
diff --git a/internal/admin/handler_dev_capture.go b/internal/admin/handler_dev_capture.go
new file mode 100644
index 0000000..9b3615c
--- /dev/null
+++ b/internal/admin/handler_dev_capture.go
@@ -0,0 +1,26 @@
+package admin
+
+import (
+ "net/http"
+
+ "ds2api/internal/devcapture"
+)
+
+func (h *Handler) getDevCaptures(w http.ResponseWriter, _ *http.Request) {
+ store := devcapture.Global()
+ writeJSON(w, http.StatusOK, map[string]any{
+ "enabled": store.Enabled(),
+ "limit": store.Limit(),
+ "max_body_bytes": store.MaxBodyBytes(),
+ "items": store.Snapshot(),
+ })
+}
+
+func (h *Handler) clearDevCaptures(w http.ResponseWriter, _ *http.Request) {
+ store := devcapture.Global()
+ store.Clear()
+ writeJSON(w, http.StatusOK, map[string]any{
+ "success": true,
+ "detail": "capture logs cleared",
+ })
+}
diff --git a/internal/admin/handler_dev_capture_test.go b/internal/admin/handler_dev_capture_test.go
new file mode 100644
index 0000000..90ced8b
--- /dev/null
+++ b/internal/admin/handler_dev_capture_test.go
@@ -0,0 +1,45 @@
+package admin
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+func TestGetDevCapturesShape(t *testing.T) {
+ h := &Handler{}
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/admin/dev/captures", nil)
+ h.getDevCaptures(rec, req)
+ if rec.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
+ }
+ var out map[string]any
+ if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
+ t.Fatalf("decode failed: %v", err)
+ }
+ if _, ok := out["enabled"]; !ok {
+ t.Fatalf("expected enabled field, got %#v", out)
+ }
+ if _, ok := out["items"]; !ok {
+ t.Fatalf("expected items field, got %#v", out)
+ }
+}
+
+func TestClearDevCapturesShape(t *testing.T) {
+ h := &Handler{}
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodDelete, "/admin/dev/captures", nil)
+ h.clearDevCaptures(rec, req)
+ if rec.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
+ }
+ var out map[string]any
+ if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
+ t.Fatalf("decode failed: %v", err)
+ }
+ if out["success"] != true {
+ t.Fatalf("expected success=true, got %#v", out)
+ }
+}
diff --git a/internal/auth/request.go b/internal/auth/request.go
index 25980cf..23acb40 100644
--- a/internal/auth/request.go
+++ b/internal/auth/request.go
@@ -187,7 +187,12 @@ func extractCallerToken(req *http.Request) string {
return token
}
}
- return strings.TrimSpace(req.Header.Get("x-api-key"))
+ if key := strings.TrimSpace(req.Header.Get("x-api-key")); key != "" {
+ return key
+ }
+ // Gemini AI Studio compatibility: allow query key fallback only when no
+ // header-based credential is present.
+ return strings.TrimSpace(req.URL.Query().Get("key"))
}
func callerTokenID(token string) string {
diff --git a/internal/auth/request_test.go b/internal/auth/request_test.go
index c292856..2eca44b 100644
--- a/internal/auth/request_test.go
+++ b/internal/auth/request_test.go
@@ -114,6 +114,40 @@ func TestDetermineMissingToken(t *testing.T) {
}
}
+func TestDetermineWithQueryKeyUsesDirectToken(t *testing.T) {
+ r := newTestResolver(t)
+ req, _ := http.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent?key=direct-query-key", nil)
+
+ a, err := r.Determine(req)
+ if err != nil {
+ t.Fatalf("determine failed: %v", err)
+ }
+ if a.UseConfigToken {
+ t.Fatalf("expected direct token mode")
+ }
+ if a.DeepSeekToken != "direct-query-key" {
+ t.Fatalf("unexpected token: %q", a.DeepSeekToken)
+ }
+}
+
+func TestDetermineHeaderTokenPrecedenceOverQueryKey(t *testing.T) {
+ r := newTestResolver(t)
+ req, _ := http.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent?key=query-key", nil)
+ req.Header.Set("x-api-key", "managed-key")
+
+ a, err := r.Determine(req)
+ if err != nil {
+ t.Fatalf("determine failed: %v", err)
+ }
+ defer r.Release(a)
+ if !a.UseConfigToken {
+ t.Fatalf("expected managed key mode from header token")
+ }
+ if a.AccountID == "" {
+ t.Fatalf("expected managed account to be acquired")
+ }
+}
+
func TestDetermineCallerMissingToken(t *testing.T) {
r := newTestResolver(t)
req, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil)
diff --git a/internal/deepseek/client.go b/internal/deepseek/client.go
index 0523435..2ffe05d 100644
--- a/internal/deepseek/client.go
+++ b/internal/deepseek/client.go
@@ -16,6 +16,7 @@ import (
"ds2api/internal/auth"
"ds2api/internal/config"
trans "ds2api/internal/deepseek/transport"
+ "ds2api/internal/devcapture"
"ds2api/internal/util"
"github.com/andybalholm/brotli"
@@ -27,6 +28,7 @@ var intFrom = util.IntFrom
type Client struct {
Store *config.Store
Auth *auth.Resolver
+ capture *devcapture.Store
regular trans.Doer
stream trans.Doer
fallback *http.Client
@@ -39,6 +41,7 @@ func NewClient(store *config.Store, resolver *auth.Resolver) *Client {
return &Client{
Store: store,
Auth: resolver,
+ capture: devcapture.Global(),
regular: trans.New(60 * time.Second),
stream: trans.New(0),
fallback: &http.Client{Timeout: 60 * time.Second},
@@ -179,6 +182,7 @@ func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payloa
}
headers := c.authHeaders(a.DeepSeekToken)
headers["x-ds-pow-response"] = powResp
+ captureSession := c.capture.Start("deepseek_completion", DeepSeekCompletionURL, a.AccountID, payload)
attempts := 0
for attempts < maxAttempts {
resp, err := c.streamPost(ctx, DeepSeekCompletionURL, headers, payload)
@@ -188,8 +192,14 @@ func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payloa
continue
}
if resp.StatusCode == http.StatusOK {
+ if captureSession != nil {
+ resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
+ }
return resp, nil
}
+ if captureSession != nil {
+ resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
+ }
_ = resp.Body.Close()
attempts++
time.Sleep(time.Second)
diff --git a/internal/devcapture/store.go b/internal/devcapture/store.go
new file mode 100644
index 0000000..6d0d8cd
--- /dev/null
+++ b/internal/devcapture/store.go
@@ -0,0 +1,259 @@
+package devcapture
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "os"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/google/uuid"
+)
+
+const (
+ defaultLimit = 5
+ defaultMaxBodyBytes = 2 * 1024 * 1024
+ maxLimit = 50
+)
+
+type Entry struct {
+ ID string `json:"id"`
+ CreatedAt int64 `json:"created_at"`
+ Label string `json:"label"`
+ URL string `json:"url"`
+ AccountID string `json:"account_id,omitempty"`
+ StatusCode int `json:"status_code"`
+ RequestBody string `json:"request_body"`
+ ResponseBody string `json:"response_body"`
+ ResponseTruncated bool `json:"response_truncated"`
+}
+
+type Store struct {
+ mu sync.Mutex
+ enabled bool
+ limit int
+ maxBodyBytes int
+ items []Entry
+}
+
+type Session struct {
+ store *Store
+ id string
+ createdAt int64
+ label string
+ url string
+ accountID string
+ requestRaw string
+}
+
+type captureBody struct {
+ rc io.ReadCloser
+ s *Session
+ statusCode int
+ buf strings.Builder
+ truncated bool
+ finalized bool
+}
+
+var (
+ globalOnce sync.Once
+ globalInst *Store
+)
+
+func Global() *Store {
+ globalOnce.Do(func() {
+ globalInst = NewFromEnv()
+ })
+ return globalInst
+}
+
+func NewFromEnv() *Store {
+ enabled := !isVercelRuntime()
+ if raw, ok := os.LookupEnv("DS2API_DEV_PACKET_CAPTURE"); ok {
+ enabled = parseBool(raw)
+ }
+ limit := parseIntWithDefault(os.Getenv("DS2API_DEV_PACKET_CAPTURE_LIMIT"), defaultLimit)
+ if limit < 1 {
+ limit = defaultLimit
+ }
+ if limit > maxLimit {
+ limit = maxLimit
+ }
+ maxBodyBytes := parseIntWithDefault(os.Getenv("DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES"), defaultMaxBodyBytes)
+ if maxBodyBytes < 1024 {
+ maxBodyBytes = defaultMaxBodyBytes
+ }
+ return &Store{
+ enabled: enabled,
+ limit: limit,
+ maxBodyBytes: maxBodyBytes,
+ items: make([]Entry, 0, limit),
+ }
+}
+
+func isVercelRuntime() bool {
+ return strings.TrimSpace(os.Getenv("VERCEL")) != "" || strings.TrimSpace(os.Getenv("NOW_REGION")) != ""
+}
+
+func (s *Store) Enabled() bool {
+ if s == nil {
+ return false
+ }
+ return s.enabled
+}
+
+func (s *Store) Limit() int {
+ if s == nil {
+ return defaultLimit
+ }
+ return s.limit
+}
+
+func (s *Store) MaxBodyBytes() int {
+ if s == nil {
+ return defaultMaxBodyBytes
+ }
+ return s.maxBodyBytes
+}
+
+func (s *Store) Snapshot() []Entry {
+ if s == nil {
+ return nil
+ }
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ out := make([]Entry, len(s.items))
+ copy(out, s.items)
+ return out
+}
+
+func (s *Store) Clear() {
+ if s == nil {
+ return
+ }
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.items = s.items[:0]
+}
+
+func (s *Store) Start(label, url, accountID string, requestPayload any) *Session {
+ if s == nil || !s.enabled {
+ return nil
+ }
+ return &Session{
+ store: s,
+ id: "cap_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
+ createdAt: time.Now().Unix(),
+ label: strings.TrimSpace(label),
+ url: strings.TrimSpace(url),
+ accountID: strings.TrimSpace(accountID),
+ requestRaw: marshalPayload(requestPayload),
+ }
+}
+
+func (s *Session) WrapBody(rc io.ReadCloser, statusCode int) io.ReadCloser {
+ if s == nil || rc == nil {
+ return rc
+ }
+ return &captureBody{
+ rc: rc,
+ s: s,
+ statusCode: statusCode,
+ }
+}
+
+func (c *captureBody) Read(p []byte) (int, error) {
+ n, err := c.rc.Read(p)
+ if n > 0 {
+ c.append(string(p[:n]))
+ }
+ if err == io.EOF {
+ c.finalize()
+ }
+ return n, err
+}
+
+func (c *captureBody) Close() error {
+ err := c.rc.Close()
+ c.finalize()
+ return err
+}
+
+func (c *captureBody) append(chunk string) {
+ if chunk == "" || c.s == nil || c.s.store == nil {
+ return
+ }
+ maxLen := c.s.store.maxBodyBytes
+ current := c.buf.Len()
+ if current >= maxLen {
+ c.truncated = true
+ return
+ }
+ remain := maxLen - current
+ if len(chunk) > remain {
+ c.buf.WriteString(chunk[:remain])
+ c.truncated = true
+ return
+ }
+ c.buf.WriteString(chunk)
+}
+
+func (c *captureBody) finalize() {
+ if c.finalized || c.s == nil || c.s.store == nil {
+ return
+ }
+ c.finalized = true
+ entry := Entry{
+ ID: c.s.id,
+ CreatedAt: c.s.createdAt,
+ Label: c.s.label,
+ URL: c.s.url,
+ AccountID: c.s.accountID,
+ StatusCode: c.statusCode,
+ RequestBody: c.s.requestRaw,
+ ResponseBody: c.buf.String(),
+ ResponseTruncated: c.truncated,
+ }
+ c.s.store.push(entry)
+}
+
+func (s *Store) push(entry Entry) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.items = append([]Entry{entry}, s.items...)
+ if len(s.items) > s.limit {
+ s.items = s.items[:s.limit]
+ }
+}
+
+func marshalPayload(v any) string {
+ b, err := json.Marshal(v)
+ if err != nil {
+ return fmt.Sprintf("%v", v)
+ }
+ return string(b)
+}
+
+func parseBool(v string) bool {
+ switch strings.ToLower(strings.TrimSpace(v)) {
+ case "1", "true", "yes", "on":
+ return true
+ default:
+ return false
+ }
+}
+
+func parseIntWithDefault(raw string, d int) int {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return d
+ }
+ n, err := strconv.Atoi(raw)
+ if err != nil {
+ return d
+ }
+ return n
+}
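The `Store.push` semantics above (prepend newest, trim to `limit`) can be sketched as a standalone program; `ring` and `entry` here are illustrative stand-ins, not the real `devcapture` types:

```go
package main

import (
	"fmt"
	"sync"
)

// entry is a minimal stand-in for devcapture.Entry.
type entry struct{ ID string }

// ring mirrors Store.push from the diff: newest entries are
// prepended and the slice is trimmed to the configured limit.
type ring struct {
	mu    sync.Mutex
	limit int
	items []entry
}

func (r *ring) push(e entry) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.items = append([]entry{e}, r.items...)
	if len(r.items) > r.limit {
		r.items = r.items[:r.limit]
	}
}

func main() {
	r := &ring{limit: 2}
	for _, id := range []string{"a", "b", "c"} {
		r.push(entry{ID: id})
	}
	// Oldest entry "a" has been evicted; "c" is newest-first.
	fmt.Println(r.items[0].ID, r.items[1].ID, len(r.items))
}
```

Prepending keeps `Snapshot` trivially newest-first at the cost of an O(n) copy per push, which is acceptable for a small debug-capture buffer.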
diff --git a/internal/devcapture/store_test.go b/internal/devcapture/store_test.go
new file mode 100644
index 0000000..1dd58b4
--- /dev/null
+++ b/internal/devcapture/store_test.go
@@ -0,0 +1,55 @@
+package devcapture
+
+import (
+ "io"
+ "strings"
+ "testing"
+)
+
+func TestStorePushKeepsNewestWithinLimit(t *testing.T) {
+ s := &Store{enabled: true, limit: 2, maxBodyBytes: 1024}
+ for i := 0; i < 3; i++ {
+ session := s.Start("test", "http://x", "", map[string]any{"seq": i})
+ if session == nil {
+ t.Fatal("expected session")
+ }
+ rc := session.WrapBody(io.NopCloser(strings.NewReader("ok")), 200)
+ _, _ = io.ReadAll(rc)
+ _ = rc.Close()
+ }
+ items := s.Snapshot()
+ if len(items) != 2 {
+ t.Fatalf("expected 2 items, got %d", len(items))
+ }
+ if !strings.Contains(items[0].RequestBody, `"seq":2`) {
+ t.Fatalf("expected newest first, got %#v", items[0].RequestBody)
+ }
+ if !strings.Contains(items[1].RequestBody, `"seq":1`) {
+ t.Fatalf("expected second newest, got %#v", items[1].RequestBody)
+ }
+}
+
+func TestWrapBodyTruncatesByLimit(t *testing.T) {
+ s := &Store{enabled: true, limit: 5, maxBodyBytes: 4}
+ session := s.Start("test", "http://x", "acc1", map[string]any{"x": 1})
+ if session == nil {
+ t.Fatal("expected session")
+ }
+ rc := session.WrapBody(io.NopCloser(strings.NewReader("abcdef")), 200)
+ _, _ = io.ReadAll(rc)
+ _ = rc.Close()
+
+ items := s.Snapshot()
+ if len(items) != 1 {
+ t.Fatalf("expected 1 item, got %d", len(items))
+ }
+ if items[0].ResponseBody != "abcd" {
+ t.Fatalf("expected truncated body, got %q", items[0].ResponseBody)
+ }
+ if !items[0].ResponseTruncated {
+ t.Fatal("expected truncated flag true")
+ }
+ if items[0].AccountID != "acc1" {
+ t.Fatalf("expected account id, got %q", items[0].AccountID)
+ }
+}
diff --git a/internal/format/openai/render.go b/internal/format/openai/render.go
index fc7473f..2107d4e 100644
--- a/internal/format/openai/render.go
+++ b/internal/format/openai/render.go
@@ -1,6 +1,7 @@
package openai
import (
+ "encoding/json"
"strings"
"time"
@@ -43,36 +44,45 @@ func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalT
}
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
+ // Align responses tool-call semantics with chat/completions:
+ // mixed prose + tool_call payloads should still be interpreted as tool calls.
detected := util.ParseToolCalls(finalText, toolNames)
+ if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
+ detected = util.ParseToolCalls(finalThinking, toolNames)
+ }
exposedOutputText := finalText
output := make([]any, 0, 2)
if len(detected) > 0 {
exposedOutputText = ""
- toolCalls := make([]any, 0, len(detected))
- for _, tc := range detected {
- toolCalls = append(toolCalls, map[string]any{
- "type": "tool_call",
- "name": tc.Name,
- "arguments": tc.Input,
+ if strings.TrimSpace(finalThinking) != "" {
+ output = append(output, map[string]any{
+ "type": "reasoning",
+ "text": finalThinking,
})
}
+ formatted := util.FormatOpenAIToolCalls(detected)
+ output = append(output, toResponsesFunctionCallItems(formatted)...)
output = append(output, map[string]any{
"type": "tool_calls",
- "tool_calls": toolCalls,
+ "tool_calls": formatted,
})
} else {
- content := []any{
- map[string]any{
- "type": "output_text",
- "text": finalText,
- },
- }
+ content := make([]any, 0, 2)
if finalThinking != "" {
content = append([]any{map[string]any{
"type": "reasoning",
"text": finalThinking,
}}, content...)
}
+ if strings.TrimSpace(finalText) != "" {
+ content = append(content, map[string]any{
+ "type": "output_text",
+ "text": finalText,
+ })
+ }
+ if strings.TrimSpace(finalText) == "" && strings.TrimSpace(finalThinking) != "" {
+ exposedOutputText = finalThinking
+ }
output = append(output, map[string]any{
"type": "message",
"id": "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
@@ -100,6 +110,54 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex
}
}
+func toResponsesFunctionCallItems(toolCalls []map[string]any) []any {
+ if len(toolCalls) == 0 {
+ return nil
+ }
+ out := make([]any, 0, len(toolCalls))
+ for _, tc := range toolCalls {
+ callID, _ := tc["id"].(string)
+ if strings.TrimSpace(callID) == "" {
+ callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
+ }
+ name := ""
+ args := "{}"
+ if fn, ok := tc["function"].(map[string]any); ok {
+ if n, _ := fn["name"].(string); strings.TrimSpace(n) != "" {
+ name = n
+ }
+ if a, _ := fn["arguments"].(string); strings.TrimSpace(a) != "" {
+ args = a
+ }
+ }
+ out = append(out, map[string]any{
+ "id": "fc_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
+ "type": "function_call",
+ "call_id": callID,
+ "name": name,
+ "arguments": normalizeJSONString(args),
+ "status": "completed",
+ })
+ }
+ return out
+}
+
+func normalizeJSONString(raw string) string {
+ s := strings.TrimSpace(raw)
+ if s == "" {
+ return "{}"
+ }
+ var v any
+ if err := json.Unmarshal([]byte(s), &v); err != nil {
+ return raw
+ }
+ b, err := json.Marshal(v)
+ if err != nil {
+ return raw
+ }
+ return string(b)
+}
+
func BuildChatStreamDeltaChoice(index int, delta map[string]any) map[string]any {
return map[string]any{
"delta": delta,
@@ -145,49 +203,105 @@ func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any
func BuildResponsesCreatedPayload(responseID, model string) map[string]any {
return map[string]any{
- "type": "response.created",
- "id": responseID,
- "object": "response",
- "model": model,
- "status": "in_progress",
+ "type": "response.created",
+ "id": responseID,
+ "response_id": responseID,
+ "object": "response",
+ "model": model,
+ "status": "in_progress",
}
}
func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any {
return map[string]any{
- "type": "response.output_text.delta",
- "id": responseID,
- "delta": delta,
+ "type": "response.output_text.delta",
+ "id": responseID,
+ "response_id": responseID,
+ "delta": delta,
}
}
func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
return map[string]any{
- "type": "response.reasoning.delta",
- "id": responseID,
- "delta": delta,
+ "type": "response.reasoning.delta",
+ "id": responseID,
+ "response_id": responseID,
+ "delta": delta,
+ }
+}
+
+func BuildResponsesReasoningTextDeltaPayload(responseID, itemID string, outputIndex, contentIndex int, delta string) map[string]any {
+ return map[string]any{
+ "type": "response.reasoning_text.delta",
+ "id": responseID,
+ "response_id": responseID,
+ "item_id": itemID,
+ "output_index": outputIndex,
+ "content_index": contentIndex,
+ "delta": delta,
+ }
+}
+
+func BuildResponsesReasoningTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any {
+ return map[string]any{
+ "type": "response.reasoning_text.done",
+ "id": responseID,
+ "response_id": responseID,
+ "item_id": itemID,
+ "output_index": outputIndex,
+ "content_index": contentIndex,
+ "text": text,
}
}
func BuildResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any {
return map[string]any{
- "type": "response.output_tool_call.delta",
- "id": responseID,
- "tool_calls": toolCalls,
+ "type": "response.output_tool_call.delta",
+ "id": responseID,
+ "response_id": responseID,
+ "tool_calls": toolCalls,
}
}
func BuildResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any {
return map[string]any{
- "type": "response.output_tool_call.done",
- "id": responseID,
- "tool_calls": toolCalls,
+ "type": "response.output_tool_call.done",
+ "id": responseID,
+ "response_id": responseID,
+ "tool_calls": toolCalls,
+ }
+}
+
+func BuildResponsesFunctionCallArgumentsDeltaPayload(responseID, itemID string, outputIndex int, callID, delta string) map[string]any {
+ return map[string]any{
+ "type": "response.function_call_arguments.delta",
+ "id": responseID,
+ "response_id": responseID,
+ "item_id": itemID,
+ "output_index": outputIndex,
+ "call_id": callID,
+ "delta": delta,
+ }
+}
+
+func BuildResponsesFunctionCallArgumentsDonePayload(responseID, itemID string, outputIndex int, callID, name, arguments string) map[string]any {
+ return map[string]any{
+ "type": "response.function_call_arguments.done",
+ "id": responseID,
+ "response_id": responseID,
+ "item_id": itemID,
+ "output_index": outputIndex,
+ "call_id": callID,
+ "name": name,
+ "arguments": normalizeJSONString(arguments),
}
}
func BuildResponsesCompletedPayload(response map[string]any) map[string]any {
+ responseID, _ := response["id"].(string)
return map[string]any{
- "type": "response.completed",
- "response": response,
+ "type": "response.completed",
+ "response_id": responseID,
+ "response": response,
}
}
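The `normalizeJSONString` helper added in render.go compacts valid JSON and passes anything else through unchanged. A self-contained sketch of that behavior, copied out of the package for illustration:

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// normalizeJSONString: valid JSON is re-marshaled into compact
// form; invalid input is returned as-is; empty input becomes "{}".
func normalizeJSONString(raw string) string {
	s := strings.TrimSpace(raw)
	if s == "" {
		return "{}"
	}
	var v any
	if err := json.Unmarshal([]byte(s), &v); err != nil {
		return raw
	}
	b, err := json.Marshal(v)
	if err != nil {
		return raw
	}
	return string(b)
}

func main() {
	fmt.Println(normalizeJSONString(`{ "q" : "golang" }`)) // compacted
	fmt.Println(normalizeJSONString("not-json"))           // passed through untouched
	fmt.Println(normalizeJSONString("  "))                 // falls back to "{}"
}
```

Falling back to the raw string on unmarshal failure means downstream consumers still see the model's original arguments payload rather than a silently emptied one.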
diff --git a/internal/format/openai/render_test.go b/internal/format/openai/render_test.go
new file mode 100644
index 0000000..e3bf0dd
--- /dev/null
+++ b/internal/format/openai/render_test.go
@@ -0,0 +1,181 @@
+package openai
+
+import (
+ "encoding/json"
+ "testing"
+)
+
+func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
+ obj := BuildResponseObject(
+ "resp_test",
+ "gpt-4o",
+ "prompt",
+ "",
+ `{"tool_calls":[{"name":"search","input":{"q":"golang"}}]}`,
+ []string{"search"},
+ )
+
+ outputText, _ := obj["output_text"].(string)
+ if outputText != "" {
+ t.Fatalf("expected output_text to be hidden for tool calls, got %q", outputText)
+ }
+
+ output, _ := obj["output"].([]any)
+ if len(output) != 2 {
+ t.Fatalf("expected function_call + tool_calls wrapper, got %#v", obj["output"])
+ }
+
+ first, _ := output[0].(map[string]any)
+ if first["type"] != "function_call" {
+ t.Fatalf("expected first output item type function_call, got %#v", first["type"])
+ }
+ if first["call_id"] == "" {
+ t.Fatalf("expected function_call item to have call_id, got %#v", first)
+ }
+ second, _ := output[1].(map[string]any)
+ if second["type"] != "tool_calls" {
+ t.Fatalf("expected second output item type tool_calls, got %#v", second["type"])
+ }
+ var toolCalls []map[string]any
+ switch v := second["tool_calls"].(type) {
+ case []map[string]any:
+ toolCalls = v
+ case []any:
+ toolCalls = make([]map[string]any, 0, len(v))
+ for _, item := range v {
+ m, _ := item.(map[string]any)
+ if m != nil {
+ toolCalls = append(toolCalls, m)
+ }
+ }
+ }
+ if len(toolCalls) != 1 {
+ t.Fatalf("expected one tool call, got %#v", second["tool_calls"])
+ }
+ tc := toolCalls[0]
+ if tc["type"] != "function" || tc["id"] == "" {
+ t.Fatalf("unexpected tool call shape: %#v", tc)
+ }
+ fn, _ := tc["function"].(map[string]any)
+ if fn["name"] != "search" {
+ t.Fatalf("unexpected function name: %#v", fn["name"])
+ }
+ argsRaw, _ := fn["arguments"].(string)
+ var args map[string]any
+ if err := json.Unmarshal([]byte(argsRaw), &args); err != nil {
+ t.Fatalf("arguments should be valid json string, got=%q err=%v", argsRaw, err)
+ }
+ if args["q"] != "golang" {
+ t.Fatalf("unexpected arguments: %#v", args)
+ }
+}
+
+func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T) {
+ obj := BuildResponseObject(
+ "resp_test",
+ "gpt-4o",
+ "prompt",
+ "",
+ `示例格式:{"tool_calls":[{"name":"search","input":{"q":"golang"}}]},但这条是普通回答。`,
+ []string{"search"},
+ )
+
+ outputText, _ := obj["output_text"].(string)
+ if outputText != "" {
+ t.Fatalf("expected output_text hidden once tool calls are detected, got %q", outputText)
+ }
+
+ output, _ := obj["output"].([]any)
+ if len(output) != 2 {
+ t.Fatalf("expected function_call + tool_calls wrapper, got %#v", obj["output"])
+ }
+ first, _ := output[0].(map[string]any)
+ if first["type"] != "function_call" {
+ t.Fatalf("expected first output type function_call, got %#v", first["type"])
+ }
+}
+
+func TestBuildResponseObjectFencedToolPayloadRemainsText(t *testing.T) {
+ obj := BuildResponseObject(
+ "resp_test",
+ "gpt-4o",
+ "prompt",
+ "",
+ "```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"golang\"}}]}\n```",
+ []string{"search"},
+ )
+
+ outputText, _ := obj["output_text"].(string)
+ if outputText == "" {
+ t.Fatalf("expected output_text preserved for fenced example")
+ }
+ output, _ := obj["output"].([]any)
+ if len(output) != 1 {
+ t.Fatalf("expected one message output item, got %#v", obj["output"])
+ }
+ first, _ := output[0].(map[string]any)
+ if first["type"] != "message" {
+ t.Fatalf("expected message output type, got %#v", first["type"])
+ }
+}
+
+func TestBuildResponseObjectReasoningOnlyFallsBackToOutputText(t *testing.T) {
+ obj := BuildResponseObject(
+ "resp_test",
+ "gpt-4o",
+ "prompt",
+ "internal thinking content",
+ "",
+ nil,
+ )
+
+ outputText, _ := obj["output_text"].(string)
+ if outputText == "" {
+ t.Fatalf("expected output_text fallback from reasoning when final text is empty")
+ }
+
+ output, _ := obj["output"].([]any)
+ if len(output) != 1 {
+ t.Fatalf("expected one output item, got %#v", obj["output"])
+ }
+ first, _ := output[0].(map[string]any)
+ if first["type"] != "message" {
+ t.Fatalf("expected output type message, got %#v", first["type"])
+ }
+ content, _ := first["content"].([]any)
+ if len(content) == 0 {
+ t.Fatalf("expected reasoning content, got %#v", first["content"])
+ }
+ block0, _ := content[0].(map[string]any)
+ if block0["type"] != "reasoning" {
+ t.Fatalf("expected first content block reasoning, got %#v", block0["type"])
+ }
+}
+
+func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
+ obj := BuildResponseObject(
+ "resp_test",
+ "gpt-4o",
+ "prompt",
+ `{"tool_calls":[{"name":"search","input":{"q":"from-thinking"}}]}`,
+ "",
+ []string{"search"},
+ )
+
+ output, _ := obj["output"].([]any)
+ if len(output) != 3 {
+ t.Fatalf("expected reasoning + function_call + tool_calls outputs, got %#v", obj["output"])
+ }
+ first, _ := output[0].(map[string]any)
+ if first["type"] != "reasoning" {
+ t.Fatalf("expected first output reasoning, got %#v", first["type"])
+ }
+ second, _ := output[1].(map[string]any)
+ if second["type"] != "function_call" {
+ t.Fatalf("expected second output function_call, got %#v", second["type"])
+ }
+ third, _ := output[2].(map[string]any)
+ if third["type"] != "tool_calls" {
+ t.Fatalf("expected third output tool_calls, got %#v", third["type"])
+ }
+}
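Assuming the builder signatures introduced in render.go, this sketch shows the wire shape of the `response.function_call_arguments.done` event (the builder body is reproduced without the package plumbing; field values are illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// buildDone mirrors BuildResponsesFunctionCallArgumentsDonePayload
// from the diff, showing the event map that gets serialized.
func buildDone(responseID, itemID string, outputIndex int, callID, name, arguments string) map[string]any {
	return map[string]any{
		"type":         "response.function_call_arguments.done",
		"id":           responseID,
		"response_id":  responseID,
		"item_id":      itemID,
		"output_index": outputIndex,
		"call_id":      callID,
		"name":         name,
		"arguments":    arguments,
	}
}

func main() {
	payload := buildDone("resp_1", "fc_1", 0, "call_1", "search", `{"q":"golang"}`)
	b, _ := json.Marshal(payload)
	fmt.Println(string(b))
}
```

Carrying both `id` and `response_id` on every event, as the diff does, lets clients that expect either field name correlate deltas to their response.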
diff --git a/internal/server/router.go b/internal/server/router.go
index a81f0cb..ae3108e 100644
--- a/internal/server/router.go
+++ b/internal/server/router.go
@@ -12,6 +12,7 @@ import (
"ds2api/internal/account"
"ds2api/internal/adapter/claude"
+ "ds2api/internal/adapter/gemini"
"ds2api/internal/adapter/openai"
"ds2api/internal/admin"
"ds2api/internal/auth"
@@ -44,6 +45,7 @@ func NewApp() *App {
openaiHandler := &openai.Handler{Store: store, Auth: resolver, DS: dsClient}
claudeHandler := &claude.Handler{Store: store, Auth: resolver, DS: dsClient}
+ geminiHandler := &gemini.Handler{Store: store, Auth: resolver, DS: dsClient}
adminHandler := &admin.Handler{Store: store, Pool: pool, DS: dsClient}
webuiHandler := webui.NewHandler()
@@ -67,6 +69,7 @@ func NewApp() *App {
})
openai.RegisterRoutes(r, openaiHandler)
claude.RegisterRoutes(r, claudeHandler)
+ gemini.RegisterRoutes(r, geminiHandler)
r.Route("/admin", func(ar chi.Router) {
admin.RegisterRoutes(ar, adminHandler)
})
diff --git a/scripts/testsuite/run-live.sh b/tests/scripts/run-live.sh
similarity index 100%
rename from scripts/testsuite/run-live.sh
rename to tests/scripts/run-live.sh
diff --git a/tests/scripts/run-unit-all.sh b/tests/scripts/run-unit-all.sh
new file mode 100755
index 0000000..59b202c
--- /dev/null
+++ b/tests/scripts/run-unit-all.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
+cd "$ROOT_DIR"
+
+./tests/scripts/run-unit-go.sh
+./tests/scripts/run-unit-node.sh
diff --git a/tests/scripts/run-unit-go.sh b/tests/scripts/run-unit-go.sh
new file mode 100755
index 0000000..38a11b8
--- /dev/null
+++ b/tests/scripts/run-unit-go.sh
@@ -0,0 +1,7 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
+cd "$ROOT_DIR"
+
+go test ./... "$@"
diff --git a/tests/scripts/run-unit-node.sh b/tests/scripts/run-unit-node.sh
new file mode 100755
index 0000000..95f11e0
--- /dev/null
+++ b/tests/scripts/run-unit-node.sh
@@ -0,0 +1,7 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
+cd "$ROOT_DIR"
+
+node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js "$@"
diff --git a/vercel.json b/vercel.json
index 2e68a94..bad49e0 100644
--- a/vercel.json
+++ b/vercel.json
@@ -38,6 +38,18 @@
"source": "/admin/config",
"destination": "/api/index"
},
+ {
+ "source": "/admin/config/(.*)",
+ "destination": "/api/index"
+ },
+ {
+ "source": "/admin/settings",
+ "destination": "/api/index"
+ },
+ {
+ "source": "/admin/settings/(.*)",
+ "destination": "/api/index"
+ },
{
"source": "/admin/keys(.*)",
"destination": "/api/index"
diff --git a/webui/src/App.jsx b/webui/src/App.jsx
index 3f6ad27..2c3f099 100644
--- a/webui/src/App.jsx
+++ b/webui/src/App.jsx
@@ -1,4 +1,4 @@
-import { useState, useEffect } from 'react'
+import { useState, useEffect, useCallback, useMemo } from 'react'
import {
Routes,
Route,
@@ -29,12 +29,12 @@ import Login from './components/Login'
import LandingPage from './components/LandingPage'
import LanguageToggle from './components/LanguageToggle'
import { useI18n } from './i18n'
+import { detectRuntimeEnv } from './utils/runtimeEnv'
-function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message, onForceLogout }) {
+function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message, onForceLogout, isVercel }) {
const { t } = useI18n()
const [activeTab, setActiveTab] = useState('accounts')
const [sidebarOpen, setSidebarOpen] = useState(false)
- const [loading, setLoading] = useState(false)
const navItems = [
{ id: 'accounts', label: t('nav.accounts.label'), icon: Users, description: t('nav.accounts.desc') },
@@ -44,7 +44,7 @@ function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message,
{ id: 'settings', label: t('nav.settings.label'), icon: SettingsIcon, description: t('nav.settings.desc') },
]
- const authFetch = async (url, options = {}) => {
+ const authFetch = useCallback(async (url, options = {}) => {
const headers = {
...options.headers,
'Authorization': `Bearer ${token}`
@@ -56,7 +56,7 @@ function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message,
throw new Error(t('auth.expired'))
}
return res
- }
+ }, [onLogout, t, token])
const renderTab = () => {
switch (activeTab) {
@@ -67,9 +67,9 @@ function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message,
case 'import':
return
           {t('vercel.description')}
+          {pollPaused && (
+            <div>{t('vercel.pollPaused', { count: pollFailures })}</div>
+          )}
+ +