diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index cdc141f..c4bd04b 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -13,12 +13,8 @@ #### 🔀 变更说明 | Description of Change - + #### 📝 补充信息 | Additional Information - ---- - -> 💡 **提示**:如果修改了 `webui/` 目录下的文件,PR 合并后 CI 会自动构建并提交产物,无需手动构建。 \ No newline at end of file diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml index 67689cc..00cecee 100644 --- a/.github/workflows/release-artifacts.yml +++ b/.github/workflows/release-artifacts.yml @@ -12,6 +12,9 @@ permissions: jobs: build-and-upload: runs-on: ubuntu-latest + env: + DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} + DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }} steps: - name: Checkout uses: actions/checkout@v4 @@ -87,10 +90,11 @@ jobs: password: ${{ secrets.GITHUB_TOKEN }} - name: Log in to Docker Hub + if: "${{ env.DOCKERHUB_USERNAME != '' }}" uses: docker/login-action@v3 with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} + username: ${{ env.DOCKERHUB_USERNAME }} + password: ${{ env.DOCKERHUB_TOKEN }} - name: Extract Docker metadata id: meta_release @@ -98,7 +102,7 @@ jobs: with: images: | ghcr.io/${{ github.repository }} - cjackhwang/ds2api + ${{ env.DOCKERHUB_USERNAME || 'cjackhwang' }}/ds2api tags: | type=raw,value=${{ github.event.release.tag_name }} type=raw,value=latest diff --git a/CONTRIBUTING.en.md b/CONTRIBUTING.en.md index baf5eae..212cbb7 100644 --- a/CONTRIBUTING.en.md +++ b/CONTRIBUTING.en.md @@ -86,7 +86,7 @@ Manually build WebUI to `static/admin/`: go test ./... # End-to-end live tests (real accounts) -./scripts/testsuite/run-live.sh +./tests/scripts/run-live.sh ``` ## Project Structure diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c75d450..3bbd5b3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -86,7 +86,7 @@ docker-compose -f docker-compose.dev.yml up go test ./... # 端到端全链路测试(真实账号) -./scripts/testsuite/run-live.sh +./tests/scripts/run-live.sh ``` ## 项目结构 diff --git a/DEPLOY.en.md b/DEPLOY.en.md index 8a62c98..046df83 100644 --- a/DEPLOY.en.md +++ b/DEPLOY.en.md @@ -518,7 +518,7 @@ curl http://127.0.0.1:5001/v1/chat/completions \ Run the full live testsuite before release (real account tests): ```bash -./scripts/testsuite/run-live.sh +./tests/scripts/run-live.sh ``` With custom flags: diff --git a/DEPLOY.md b/DEPLOY.md index e5b0630..4a1efcb 100644 --- a/DEPLOY.md +++ b/DEPLOY.md @@ -518,7 +518,7 @@ curl http://127.0.0.1:5001/v1/chat/completions \ 建议在发布前执行完整的端到端测试集(使用真实账号): ```bash -./scripts/testsuite/run-live.sh +./tests/scripts/run-live.sh ``` 可自定义参数: diff --git a/README.MD b/README.MD index 261e34a..b26cc72 100644 --- a/README.MD +++ b/README.MD @@ -283,6 +283,9 @@ cp opencode.json.example opencode.json | `DS2API_ACCOUNT_QUEUE_SIZE` | 同上(兼容旧名) | — | | `DS2API_VERCEL_INTERNAL_SECRET` | Vercel 混合流式内部鉴权密钥 | 回退用 `DS2API_ADMIN_KEY` | | `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | 流式 lease 过期秒数 | `900` | +| `DS2API_DEV_PACKET_CAPTURE` | 本地开发抓包开关(记录最近会话请求/响应体) | 本地非 Vercel 默认开启 | +| `DS2API_DEV_PACKET_CAPTURE_LIMIT` | 本地抓包保留条数(超出自动淘汰) | `5` | +| `DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES` | 单条响应体最大记录字节数 | `2097152` | | `VERCEL_TOKEN` | Vercel 同步 token | — | | `VERCEL_PROJECT_ID` | Vercel 项目 ID | — | | `VERCEL_TEAM_ID` | Vercel 团队 ID | — | @@ -321,6 +324,29 @@ cp opencode.json.example opencode.json 3. 已确认的 toolcall JSON 片段不会泄漏到 `delta.content` 4. 前文/后文自然语言保持顺序透传,支持混合文本与增量参数输出 +## 本地开发抓包工具 + +用于定位「responses 思考流/工具调用」等问题。开启后会自动记录最近 N 条 DeepSeek 对话上游请求体与响应体(默认 5 条,超出自动淘汰)。 + +启用示例: + +```bash +DS2API_DEV_PACKET_CAPTURE=true \ +DS2API_DEV_PACKET_CAPTURE_LIMIT=5 \ +go run ./cmd/ds2api +``` + +查询/清空(需 Admin JWT): + +- `GET /admin/dev/captures`:查看抓包列表(最新在前) +- `DELETE /admin/dev/captures`:清空抓包 + +返回字段包含: + +- `request_body`:发送给 DeepSeek 的完整请求体 +- `response_body`:上游返回的原始流式内容拼接文本 +- `response_truncated`:是否触发单条大小截断 + ## 项目结构 ```text @@ -350,8 +376,10 @@ ds2api/ │ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage │ └── locales/ # 中英文语言包(zh.json / en.json) ├── scripts/ -│ ├── build-webui.sh # WebUI 手动构建脚本 -│ └── testsuite/ # 测试集运行脚本 +│ └── build-webui.sh # WebUI 手动构建脚本 +├── tests/ +│ ├── compat/ # 兼容性测试夹具与期望输出 +│ └── scripts/ # 统一测试脚本入口(unit/e2e) ├── static/admin/ # WebUI 构建产物(不提交到 Git) ├── .github/ │ ├── workflows/ # GitHub Actions(Release 自动构建) @@ -379,11 +407,11 @@ ds2api/ ## 测试 ```bash -# 单元测试 -go test ./... +# 单元测试(Go + Node) +./tests/scripts/run-unit-all.sh # 一键端到端全链路测试(真实账号,生成完整请求/响应日志) -./scripts/testsuite/run-live.sh +./tests/scripts/run-live.sh # 或自定义参数 go run ./cmd/ds2api-tests \ diff --git a/README.en.md b/README.en.md index 5d2f326..69b47bd 100644 --- a/README.en.md +++ b/README.en.md @@ -283,6 +283,9 @@ cp opencode.json.example opencode.json | `DS2API_ACCOUNT_QUEUE_SIZE` | Alias (legacy compat) | — | | `DS2API_VERCEL_INTERNAL_SECRET` | Vercel hybrid streaming internal auth | Falls back to `DS2API_ADMIN_KEY` | | `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | Stream lease TTL seconds | `900` | +| `DS2API_DEV_PACKET_CAPTURE` | Local dev packet capture switch (record recent request/response bodies) | Enabled by default on non-Vercel local runtime | +| `DS2API_DEV_PACKET_CAPTURE_LIMIT` | Number of captured sessions to retain (auto-evict overflow) | `5` | +| `DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES` | Max recorded bytes per captured response body | `2097152` | | `VERCEL_TOKEN` | Vercel sync token | — | | `VERCEL_PROJECT_ID` | Vercel project ID | — | | `VERCEL_TEAM_ID` | Vercel team ID | — | @@ -321,6 +324,29 @@ When `tools` is present in the request, DS2API performs anti-leak handling: 3. Confirmed toolcall JSON fragments are never leaked into `delta.content` 4. Natural language before/after toolcalls keeps original order, with incremental argument output supported +## Local Dev Packet Capture + +This is for debugging issues such as Responses reasoning streaming and tool-call handoff. When enabled, DS2API stores the latest N DeepSeek conversation payload pairs (request body + upstream response body), defaulting to 5 entries with auto-eviction. + +Enable example: + +```bash +DS2API_DEV_PACKET_CAPTURE=true \ +DS2API_DEV_PACKET_CAPTURE_LIMIT=5 \ +go run ./cmd/ds2api +``` + +Inspect/clear (Admin JWT required): + +- `GET /admin/dev/captures`: list captured items (newest first) +- `DELETE /admin/dev/captures`: clear captured items + +Response fields include: + +- `request_body`: full payload sent to DeepSeek +- `response_body`: concatenated raw upstream stream body text +- `response_truncated`: whether body-size truncation happened + ## Project Structure ```text @@ -350,8 +376,10 @@ ds2api/ │ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage │ └── locales/ # Language packs (zh.json / en.json) ├── scripts/ -│ ├── build-webui.sh # Manual WebUI build script -│ └── testsuite/ # Testsuite runner scripts +│ └── build-webui.sh # Manual WebUI build script +├── tests/ +│ ├── compat/ # Compatibility fixtures and expected outputs +│ └── scripts/ # Unified test script entrypoints (unit/e2e) ├── static/admin/ # WebUI build output (not committed to Git) ├── .github/ │ ├── workflows/ # GitHub Actions (Release artifact automation) @@ -379,11 +407,11 @@ ds2api/ ## Testing ```bash -# Unit tests -go test ./... +# Unit tests (Go + Node) +./tests/scripts/run-unit-all.sh # One-command live end-to-end tests (real accounts, full request/response logs) -./scripts/testsuite/run-live.sh +./tests/scripts/run-live.sh # Or with custom flags go run ./cmd/ds2api-tests \ diff --git a/TESTING.md b/TESTING.md index ce349ec..f8e532a 100644 --- a/TESTING.md +++ b/TESTING.md @@ -8,8 +8,10 @@ DS2API 提供两个层级的测试: | 层级 | 命令 | 说明 | | --- | --- | --- | -| 单元测试 | `go test ./...` | 不需要真实账号 | -| 端到端测试 | `./scripts/testsuite/run-live.sh` | 使用真实账号执行全链路测试 | +| 单元测试(Go) | `./tests/scripts/run-unit-go.sh` | 不需要真实账号 | +| 单元测试(Node) | `./tests/scripts/run-unit-node.sh` | 不需要真实账号 | +| 单元测试(全部) | `./tests/scripts/run-unit-all.sh` | 不需要真实账号 | +| 端到端测试 | `./tests/scripts/run-live.sh` | 使用真实账号执行全链路测试 | 端到端测试集会录制完整的请求/响应日志,用于故障排查。 @@ -20,17 +22,19 @@ DS2API 提供两个层级的测试: ### 单元测试 | Unit Tests ```bash -go test ./... +./tests/scripts/run-unit-all.sh ``` ```bash -node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js +# 或按语言拆分执行 +./tests/scripts/run-unit-go.sh +./tests/scripts/run-unit-node.sh ``` ### 端到端测试 | End-to-End Tests ```bash -./scripts/testsuite/run-live.sh +./tests/scripts/run-live.sh ``` **默认行为**: @@ -179,7 +183,7 @@ go run ./cmd/ds2api-tests \ ```bash # 确保 config.json 存在且包含有效测试账号 -./scripts/testsuite/run-live.sh +./tests/scripts/run-live.sh exit_code=$? if [ $exit_code -ne 0 ]; then echo "Tests failed! Check artifacts for details." diff --git a/internal/adapter/claude/handler.go b/internal/adapter/claude/handler.go index 282b569..2fa6796 100644 --- a/internal/adapter/claude/handler.go +++ b/internal/adapter/claude/handler.go @@ -38,6 +38,10 @@ func RegisterRoutes(r chi.Router, h *Handler) { r.Get("/anthropic/v1/models", h.ListModels) r.Post("/anthropic/v1/messages", h.Messages) r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens) + r.Post("/v1/messages", h.Messages) + r.Post("/messages", h.Messages) + r.Post("/v1/messages/count_tokens", h.CountTokens) + r.Post("/messages/count_tokens", h.CountTokens) } func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) { @@ -167,7 +171,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") rc := http.NewResponseController(w) - canFlush := rc.Flush() == nil + _, canFlush := w.(http.Flusher) if !canFlush { config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered") } @@ -250,7 +254,7 @@ func normalizeClaudeMessages(messages []any) []any { } } if typeStr == "tool_result" { - parts = append(parts, fmt.Sprintf("%v", b["content"])) + parts = append(parts, formatClaudeToolResultForPrompt(b)) } } copied["content"] = strings.Join(parts, "\n") @@ -272,10 +276,36 @@ func buildClaudeToolPrompt(tools []any) string { schema, _ := json.Marshal(m["input_schema"]) parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema)) } - parts = append(parts, "When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}") + parts = append(parts, + "When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}", + "History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.", + "After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.", + ) return strings.Join(parts, "\n\n") } +func formatClaudeToolResultForPrompt(block map[string]any) string { + if block == nil { + return "" + } + toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"])) + if toolCallID == "" { + toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"])) + } + if toolCallID == "" { + toolCallID = "unknown" + } + name := strings.TrimSpace(fmt.Sprintf("%v", block["name"])) + if name == "" { + name = "unknown" + } + content := strings.TrimSpace(fmt.Sprintf("%v", block["content"])) + if content == "" { + content = "null" + } + return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content) +} + func hasSystemMessage(messages []any) bool { for _, m := range messages { msg, ok := m.(map[string]any) diff --git a/internal/adapter/claude/handler_util_test.go b/internal/adapter/claude/handler_util_test.go index 73d2fab..ae75d8e 100644 --- a/internal/adapter/claude/handler_util_test.go +++ b/internal/adapter/claude/handler_util_test.go @@ -1,6 +1,7 @@ package claude import ( + "strings" "testing" ) @@ -48,8 +49,9 @@ func TestNormalizeClaudeMessagesToolResult(t *testing.T) { } got := normalizeClaudeMessages(msgs) m := got[0].(map[string]any) - if m["content"] != "tool output" { - t.Fatalf("expected 'tool output', got %q", m["content"]) + content, _ := m["content"].(string) + if !strings.Contains(content, "[TOOL_RESULT_HISTORY]") || !strings.Contains(content, "content: tool output") { + t.Fatalf("expected serialized tool result marker, got %q", content) } } diff --git a/internal/adapter/claude/route_alias_test.go b/internal/adapter/claude/route_alias_test.go new file mode 100644 index 0000000..f01e5e3 --- /dev/null +++ b/internal/adapter/claude/route_alias_test.go @@ -0,0 +1,44 @@ +package claude + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/auth" +) + +type routeAliasAuthStub struct{} + +func (routeAliasAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) { + return nil, auth.ErrUnauthorized +} + +func (routeAliasAuthStub) Release(_ *auth.RequestAuth) {} + +func TestClaudeRouteAliasesDoNot404(t *testing.T) { + h := &Handler{ + Auth: routeAliasAuthStub{}, + } + r := chi.NewRouter() + RegisterRoutes(r, h) + + paths := []string{ + "/anthropic/v1/messages", + "/v1/messages", + "/messages", + "/anthropic/v1/messages/count_tokens", + "/v1/messages/count_tokens", + "/messages/count_tokens", + } + for _, path := range paths { + req := httptest.NewRequest(http.MethodPost, path, nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code == http.StatusNotFound { + t.Fatalf("expected route %s to be registered, got 404", path) + } + } +} diff --git a/internal/adapter/claude/standard_request.go b/internal/adapter/claude/standard_request.go index cdbb675..23520c0 100644 --- a/internal/adapter/claude/standard_request.go +++ b/internal/adapter/claude/standard_request.go @@ -27,9 +27,7 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma payload := cloneMap(req) payload["messages"] = normalizedMessages toolsRequested, _ := req["tools"].([]any) - if len(toolsRequested) > 0 && !hasSystemMessage(normalizedMessages) { - payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalizedMessages...) - } + payload["messages"] = injectClaudeToolPrompt(payload, normalizedMessages, toolsRequested) dsPayload := convertClaudeToDeepSeek(payload, store) dsModel, _ := dsPayload["model"].(string) @@ -57,3 +55,59 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma NormalizedMessages: normalizedMessages, }, nil } + +func injectClaudeToolPrompt(payload map[string]any, normalizedMessages []any, tools []any) []any { + if len(tools) == 0 { + return normalizedMessages + } + toolPrompt := strings.TrimSpace(buildClaudeToolPrompt(tools)) + if toolPrompt == "" { + return normalizedMessages + } + + // Prefer top-level Anthropic-style system prompt when available. + if systemText, ok := payload["system"].(string); ok && strings.TrimSpace(systemText) != "" { + payload["system"] = mergeSystemPrompt(systemText, toolPrompt) + return normalizedMessages + } + + messages := cloneAnySlice(normalizedMessages) + for i := range messages { + msg, ok := messages[i].(map[string]any) + if !ok { + continue + } + role, _ := msg["role"].(string) + if !strings.EqualFold(strings.TrimSpace(role), "system") { + continue + } + copied := cloneMap(msg) + copied["content"] = mergeSystemPrompt(strings.TrimSpace(fmt.Sprintf("%v", copied["content"])), toolPrompt) + messages[i] = copied + return messages + } + + return append([]any{map[string]any{"role": "system", "content": toolPrompt}}, messages...) +} + +func mergeSystemPrompt(base, extra string) string { + base = strings.TrimSpace(base) + extra = strings.TrimSpace(extra) + switch { + case base == "": + return extra + case extra == "": + return base + default: + return base + "\n\n" + extra + } +} + +func cloneAnySlice(in []any) []any { + if len(in) == 0 { + return nil + } + out := make([]any, len(in)) + copy(out, in) + return out +} diff --git a/internal/adapter/claude/standard_request_test.go b/internal/adapter/claude/standard_request_test.go index 7ffdfb8..6110124 100644 --- a/internal/adapter/claude/standard_request_test.go +++ b/internal/adapter/claude/standard_request_test.go @@ -36,3 +36,57 @@ func TestNormalizeClaudeRequest(t *testing.T) { t.Fatalf("expected non-empty final prompt") } } + +func TestNormalizeClaudeRequestInjectsToolsIntoExistingSystemMessage(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{}`) + store := config.LoadStore() + req := map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []any{ + map[string]any{"role": "system", "content": "baseline rule"}, + map[string]any{"role": "user", "content": "hello"}, + }, + "tools": []any{ + map[string]any{"name": "search", "description": "Search"}, + }, + } + + norm, err := normalizeClaudeRequest(store, req) + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + + if !containsStr(norm.Standard.FinalPrompt, "You have access to these tools") { + t.Fatalf("expected tool prompt injected into final prompt, got=%q", norm.Standard.FinalPrompt) + } + if !containsStr(norm.Standard.FinalPrompt, "baseline rule") { + t.Fatalf("expected existing system message preserved, got=%q", norm.Standard.FinalPrompt) + } +} + +func TestNormalizeClaudeRequestInjectsToolsIntoTopLevelSystem(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{}`) + store := config.LoadStore() + req := map[string]any{ + "model": "claude-sonnet-4-5", + "system": "top-level system", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + "tools": []any{ + map[string]any{"name": "search", "description": "Search"}, + }, + } + + norm, err := normalizeClaudeRequest(store, req) + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + + if !containsStr(norm.Standard.FinalPrompt, "top-level system") { + t.Fatalf("expected top-level system preserved, got=%q", norm.Standard.FinalPrompt) + } + if !containsStr(norm.Standard.FinalPrompt, "You have access to these tools") { + t.Fatalf("expected tool prompt injected, got=%q", norm.Standard.FinalPrompt) + } +} diff --git a/internal/adapter/claude/stream_status_test.go b/internal/adapter/claude/stream_status_test.go new file mode 100644 index 0000000..c3936de --- /dev/null +++ b/internal/adapter/claude/stream_status_test.go @@ -0,0 +1,100 @@ +package claude + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + chimw "github.com/go-chi/chi/v5/middleware" + + "ds2api/internal/auth" +) + +type streamStatusClaudeAuthStub struct{} + +func (streamStatusClaudeAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) { + return &auth.RequestAuth{ + UseConfigToken: false, + DeepSeekToken: "direct-token", + CallerID: "caller:test", + TriedAccounts: map[string]bool{}, + }, nil +} + +func (streamStatusClaudeAuthStub) Release(_ *auth.RequestAuth) {} + +type streamStatusClaudeDSStub struct{} + +func (streamStatusClaudeDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) { + return "session-id", nil +} + +func (streamStatusClaudeDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) { + return "pow", nil +} + +func (streamStatusClaudeDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) { + body := "data: {\"p\":\"response/content\",\"v\":\"hello\"}\n" + "data: [DONE]\n" + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: ioNopCloser{strings.NewReader(body)}, + }, nil +} + +type ioNopCloser struct { + *strings.Reader +} + +func (ioNopCloser) Close() error { return nil } + +type streamStatusClaudeStoreStub struct{} + +func (streamStatusClaudeStoreStub) ClaudeMapping() map[string]string { + return map[string]string{ + "fast": "deepseek-chat", + "slow": "deepseek-reasoner", + } +} + +func captureClaudeStatusMiddleware(statuses *[]int) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ww := chimw.NewWrapResponseWriter(w, r.ProtoMajor) + next.ServeHTTP(ww, r) + *statuses = append(*statuses, ww.Status()) + }) + } +} + +func TestClaudeMessagesStreamStatusCapturedAs200(t *testing.T) { + statuses := make([]int, 0, 1) + h := &Handler{ + Store: streamStatusClaudeStoreStub{}, + Auth: streamStatusClaudeAuthStub{}, + DS: streamStatusClaudeDSStub{}, + } + r := chi.NewRouter() + r.Use(captureClaudeStatusMiddleware(&statuses)) + RegisterRoutes(r, h) + + reqBody := `{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":true}` + req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(reqBody)) + req.Header.Set("Authorization", "Bearer direct-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if len(statuses) != 1 { + t.Fatalf("expected one captured status, got %d", len(statuses)) + } + if statuses[0] != http.StatusOK { + t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0]) + } +} diff --git a/internal/adapter/gemini/convert.go b/internal/adapter/gemini/convert.go new file mode 100644 index 0000000..3f63579 --- /dev/null +++ b/internal/adapter/gemini/convert.go @@ -0,0 +1,313 @@ +package gemini + +import ( + "encoding/json" + "fmt" + "strings" + + "ds2api/internal/adapter/openai" + "ds2api/internal/config" + "ds2api/internal/util" +) + +func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[string]any, stream bool) (util.StandardRequest, error) { + requestedModel := strings.TrimSpace(routeModel) + if requestedModel == "" { + return util.StandardRequest{}, fmt.Errorf("model is required in request path") + } + + resolvedModel, ok := config.ResolveModel(store, requestedModel) + if !ok { + return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", requestedModel) + } + thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) + + messagesRaw := geminiMessagesFromRequest(req) + if len(messagesRaw) == 0 { + return util.StandardRequest{}, fmt.Errorf("Request must include non-empty contents.") + } + + toolsRaw := convertGeminiTools(req["tools"]) + finalPrompt, toolNames := openai.BuildPromptForAdapter(messagesRaw, toolsRaw, "") + passThrough := collectGeminiPassThrough(req) + + return util.StandardRequest{ + Surface: "google_gemini", + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + ResponseModel: requestedModel, + Messages: messagesRaw, + FinalPrompt: finalPrompt, + ToolNames: toolNames, + Stream: stream, + Thinking: thinkingEnabled, + Search: searchEnabled, + PassThrough: passThrough, + }, nil +} + +func geminiMessagesFromRequest(req map[string]any) []any { + out := make([]any, 0, 8) + if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" { + out = append(out, map[string]any{ + "role": "system", + "content": sys, + }) + } + + contents, _ := req["contents"].([]any) + for _, item := range contents { + content, ok := item.(map[string]any) + if !ok { + continue + } + role := mapGeminiRole(content["role"]) + if role == "" { + role = "user" + } + parts, _ := content["parts"].([]any) + if len(parts) == 0 { + if text := strings.TrimSpace(asString(content["text"])); text != "" { + out = append(out, map[string]any{ + "role": role, + "content": text, + }) + } + continue + } + + textParts := make([]string, 0, len(parts)) + flushText := func() { + if len(textParts) == 0 { + return + } + out = append(out, map[string]any{ + "role": role, + "content": strings.Join(textParts, "\n"), + }) + textParts = textParts[:0] + } + + for _, rawPart := range parts { + part, ok := rawPart.(map[string]any) + if !ok { + continue + } + if text := strings.TrimSpace(asString(part["text"])); text != "" { + textParts = append(textParts, text) + continue + } + + if fnCall, ok := part["functionCall"].(map[string]any); ok { + flushText() + if name := strings.TrimSpace(asString(fnCall["name"])); name != "" { + callID := strings.TrimSpace(asString(fnCall["id"])) + if callID == "" { + callID = "call_gemini" + } + out = append(out, map[string]any{ + "role": "assistant", + "tool_calls": []any{ + map[string]any{ + "id": callID, + "type": "function", + "function": map[string]any{ + "name": name, + "arguments": stringifyJSON(fnCall["args"]), + }, + }, + }, + }) + } + continue + } + + if fnResp, ok := part["functionResponse"].(map[string]any); ok { + flushText() + name := strings.TrimSpace(asString(fnResp["name"])) + callID := strings.TrimSpace(asString(fnResp["id"])) + if callID == "" { + callID = strings.TrimSpace(asString(fnResp["callId"])) + } + if callID == "" { + callID = strings.TrimSpace(asString(fnResp["tool_call_id"])) + } + if callID == "" { + callID = "call_gemini" + } + content := fnResp["response"] + if content == nil { + content = fnResp["output"] + } + if content == nil { + content = "" + } + msg := map[string]any{ + "role": "tool", + "tool_call_id": callID, + "content": content, + } + if name != "" { + msg["name"] = name + } + out = append(out, msg) + } + } + flushText() + } + return out +} + +func normalizeGeminiSystemInstruction(raw any) string { + switch v := raw.(type) { + case string: + return strings.TrimSpace(v) + case map[string]any: + if parts, ok := v["parts"].([]any); ok { + texts := make([]string, 0, len(parts)) + for _, item := range parts { + part, ok := item.(map[string]any) + if !ok { + continue + } + if text := strings.TrimSpace(asString(part["text"])); text != "" { + texts = append(texts, text) + } + } + return strings.Join(texts, "\n") + } + if text := strings.TrimSpace(asString(v["text"])); text != "" { + return text + } + } + return "" +} + +func mapGeminiRole(v any) string { + switch strings.ToLower(strings.TrimSpace(asString(v))) { + case "user": + return "user" + case "model", "assistant": + return "assistant" + case "system": + return "system" + default: + return "" + } +} + +func convertGeminiTools(raw any) []any { + tools, _ := raw.([]any) + if len(tools) == 0 { + return nil + } + out := make([]any, 0, len(tools)) + for _, item := range tools { + tool, ok := item.(map[string]any) + if !ok { + continue + } + + if fnDecls, ok := tool["functionDeclarations"].([]any); ok && len(fnDecls) > 0 { + for _, declRaw := range fnDecls { + decl, ok := declRaw.(map[string]any) + if !ok { + continue + } + name := strings.TrimSpace(asString(decl["name"])) + if name == "" { + continue + } + function := map[string]any{ + "name": name, + } + if desc := strings.TrimSpace(asString(decl["description"])); desc != "" { + function["description"] = desc + } + if params, ok := decl["parameters"].(map[string]any); ok { + function["parameters"] = params + } + out = append(out, map[string]any{ + "type": "function", + "function": function, + }) + } + continue + } + + // OpenAI-style passthrough fallback. + if _, ok := tool["function"].(map[string]any); ok { + out = append(out, tool) + continue + } + + // Loose fallback for flattened function schema objects. + name := strings.TrimSpace(asString(tool["name"])) + if name == "" { + continue + } + fn := map[string]any{"name": name} + if desc := strings.TrimSpace(asString(tool["description"])); desc != "" { + fn["description"] = desc + } + if params, ok := tool["parameters"].(map[string]any); ok { + fn["parameters"] = params + } + out = append(out, map[string]any{ + "type": "function", + "function": fn, + }) + } + if len(out) == 0 { + return nil + } + return out +} + +func collectGeminiPassThrough(req map[string]any) map[string]any { + cfg, _ := req["generationConfig"].(map[string]any) + if len(cfg) == 0 { + return nil + } + out := map[string]any{} + if v, ok := cfg["temperature"]; ok { + out["temperature"] = v + } + if v, ok := cfg["topP"]; ok { + out["top_p"] = v + } + if v, ok := cfg["maxOutputTokens"]; ok { + out["max_tokens"] = v + } + if v, ok := cfg["stopSequences"]; ok { + out["stop"] = v + } + if len(out) == 0 { + return nil + } + return out +} + +func asString(v any) string { + s, _ := v.(string) + return s +} + +func stringifyJSON(v any) string { + switch x := v.(type) { + case nil: + return "{}" + case string: + s := strings.TrimSpace(x) + if s == "" { + return "{}" + } + return s + default: + b, err := json.Marshal(x) + if err != nil || len(b) == 0 { + return "{}" + } + return string(b) + } +} diff --git a/internal/adapter/gemini/deps.go b/internal/adapter/gemini/deps.go new file mode 100644 index 0000000..312114a --- /dev/null +++ b/internal/adapter/gemini/deps.go @@ -0,0 +1,29 @@ +package gemini + +import ( + "context" + "net/http" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/deepseek" +) + +type AuthResolver interface { + Determine(req *http.Request) (*auth.RequestAuth, error) + Release(a *auth.RequestAuth) +} + +type DeepSeekCaller interface { + CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) +} + +type ConfigReader interface { + ModelAliases() map[string]string +} + +var _ AuthResolver = (*auth.Resolver)(nil) +var _ DeepSeekCaller = (*deepseek.Client)(nil) +var _ ConfigReader = (*config.Store)(nil) diff --git a/internal/adapter/gemini/handler.go b/internal/adapter/gemini/handler.go new file mode 100644 index 0000000..8daaeda --- /dev/null +++ b/internal/adapter/gemini/handler.go @@ -0,0 +1,348 @@ +package gemini + +import ( + "encoding/json" + "io" + "net/http" + "strings" + "time" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/auth" + "ds2api/internal/deepseek" + "ds2api/internal/sse" + streamengine "ds2api/internal/stream" + "ds2api/internal/util" +) + +var writeJSON = util.WriteJSON + +type Handler struct { + Store ConfigReader + Auth AuthResolver + DS DeepSeekCaller +} + +func RegisterRoutes(r chi.Router, h *Handler) { + r.Post("/v1beta/models/{model}:generateContent", h.GenerateContent) + r.Post("/v1beta/models/{model}:streamGenerateContent", h.StreamGenerateContent) + r.Post("/v1/models/{model}:generateContent", h.GenerateContent) + r.Post("/v1/models/{model}:streamGenerateContent", h.StreamGenerateContent) +} + +func (h *Handler) GenerateContent(w http.ResponseWriter, r *http.Request) { + h.handleGenerateContent(w, r, false) +} + +func (h *Handler) StreamGenerateContent(w http.ResponseWriter, r *http.Request) { + h.handleGenerateContent(w, r, true) +} + +func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) { + a, err := h.Auth.Determine(r) + if err != nil { + status := http.StatusUnauthorized + detail := err.Error() + if err == auth.ErrNoAccount { + status = http.StatusTooManyRequests + } + writeGeminiError(w, status, detail) + return + } + defer h.Auth.Release(a) + + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeGeminiError(w, http.StatusBadRequest, "invalid json") + return + } + + routeModel := strings.TrimSpace(chi.URLParam(r, "model")) + stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream) + if err != nil { + writeGeminiError(w, http.StatusBadRequest, err.Error()) + return + } + + sessionID, err := h.DS.CreateSession(r.Context(), a, 3) + if err != nil { + if a.UseConfigToken { + writeGeminiError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.") + } else { + writeGeminiError(w, http.StatusUnauthorized, "Invalid token.") + } + return + } + pow, err := h.DS.GetPow(r.Context(), a, 3) + if err != nil { + writeGeminiError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") + return + } + payload := stdReq.CompletionPayload(sessionID) + resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) + if err != nil { + writeGeminiError(w, http.StatusInternalServerError, "Failed to get completion.") + return + } + + if stream { + h.handleStreamGenerateContent(w, r, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) + return + } + h.handleNonStreamGenerateContent(w, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames) +} + +func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *http.Response, model, finalPrompt string, thinkingEnabled bool, toolNames []string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body))) + return + } + + result := sse.CollectStream(resp, thinkingEnabled, true) + writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames)) +} + +func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body))) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + + rc := http.NewResponseController(w) + _, canFlush := w.(http.Flusher) + runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) + + initialType := "text" + if thinkingEnabled { + initialType = "thinking" + } + streamengine.ConsumeSSE(streamengine.ConsumeConfig{ + Context: r.Context(), + Body: resp.Body, + ThinkingEnabled: thinkingEnabled, + InitialType: initialType, + KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second, + IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second, + MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount, + }, streamengine.ConsumeHooks{ + OnParsed: runtime.onParsed, + OnFinalize: func(_ streamengine.StopReason, _ error) { + runtime.finalize() + }, + }) +} + +func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { + parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames) + usage := buildGeminiUsage(finalPrompt, finalThinking, finalText) + return map[string]any{ + "candidates": []map[string]any{ + { + "index": 0, + "content": map[string]any{ + "role": "model", + "parts": parts, + }, + "finishReason": "STOP", + }, + }, + "modelVersion": model, + "usageMetadata": usage, + } +} + +func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any { + promptTokens := util.EstimateTokens(finalPrompt) + reasoningTokens := util.EstimateTokens(finalThinking) + completionTokens := util.EstimateTokens(finalText) + return map[string]any{ + "promptTokenCount": promptTokens, + "candidatesTokenCount": reasoningTokens + completionTokens, + "totalTokenCount": promptTokens + reasoningTokens + completionTokens, + } +} + +func buildGeminiPartsFromFinal(finalText, finalThinking string, toolNames []string) []map[string]any { + detected := util.ParseToolCalls(finalText, toolNames) + if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" { + detected = util.ParseToolCalls(finalThinking, toolNames) + } + if len(detected) > 0 { + parts := make([]map[string]any, 0, len(detected)) + for _, tc := range detected { + parts = append(parts, map[string]any{ + "functionCall": map[string]any{ + "name": tc.Name, + "args": tc.Input, + }, + }) + } + return parts + } + + text := finalText + if strings.TrimSpace(text) == "" { + text = finalThinking + } + return []map[string]any{{"text": text}} +} + +type geminiStreamRuntime struct { + w http.ResponseWriter + rc *http.ResponseController + canFlush bool + + model string + finalPrompt string + + thinkingEnabled bool + searchEnabled bool + bufferContent bool + toolNames []string + + thinking strings.Builder + text strings.Builder +} + +func newGeminiStreamRuntime( + w http.ResponseWriter, + rc *http.ResponseController, + canFlush bool, + model string, + finalPrompt string, + thinkingEnabled bool, + searchEnabled bool, + toolNames []string, +) *geminiStreamRuntime { + return &geminiStreamRuntime{ + w: w, + rc: rc, + canFlush: canFlush, + model: model, + finalPrompt: finalPrompt, + thinkingEnabled: thinkingEnabled, + searchEnabled: searchEnabled, + bufferContent: len(toolNames) > 0, + toolNames: toolNames, + } +} + +func (s *geminiStreamRuntime) sendChunk(payload map[string]any) { + b, _ := json.Marshal(payload) + _, _ = s.w.Write([]byte("data: ")) + _, _ = s.w.Write(b) + _, _ = s.w.Write([]byte("\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision { + if !parsed.Parsed { + return streamengine.ParsedDecision{} + } + if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop { + return streamengine.ParsedDecision{Stop: true} + } + + contentSeen := false + for _, p := range parsed.Parts { + if p.Text == "" { + continue + } + if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) { + continue + } + contentSeen = true + if p.Type == "thinking" { + if s.thinkingEnabled { + s.thinking.WriteString(p.Text) + } + continue + } + s.text.WriteString(p.Text) + if s.bufferContent { + continue + } + s.sendChunk(map[string]any{ + "candidates": []map[string]any{ + { + "index": 0, + "content": map[string]any{ + "role": "model", + "parts": []map[string]any{{"text": p.Text}}, + }, + }, + }, + "modelVersion": s.model, + }) + } + return streamengine.ParsedDecision{ContentSeen: contentSeen} +} + +func (s *geminiStreamRuntime) finalize() { + finalThinking := s.thinking.String() + finalText := s.text.String() + + if s.bufferContent { + parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames) + s.sendChunk(map[string]any{ + "candidates": []map[string]any{ + { + "index": 0, + "content": map[string]any{ + "role": "model", + "parts": parts, + }, + }, + }, + "modelVersion": s.model, + }) + } + + s.sendChunk(map[string]any{ + "candidates": []map[string]any{ + { + "index": 0, + "finishReason": "STOP", + }, + }, + "modelVersion": s.model, + "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText), + }) +} + +func writeGeminiError(w http.ResponseWriter, status int, message string) { + errorStatus := "INVALID_ARGUMENT" + switch status { + case http.StatusUnauthorized: + errorStatus = "UNAUTHENTICATED" + case http.StatusForbidden: + errorStatus = "PERMISSION_DENIED" + case http.StatusTooManyRequests: + errorStatus = "RESOURCE_EXHAUSTED" + case http.StatusNotFound: + errorStatus = "NOT_FOUND" + default: + if status >= 500 { + errorStatus = "INTERNAL" + } + } + writeJSON(w, status, map[string]any{ + "error": map[string]any{ + "code": status, + "message": message, + "status": errorStatus, + }, + }) +} diff --git a/internal/adapter/gemini/handler_test.go b/internal/adapter/gemini/handler_test.go new file mode 100644 index 0000000..862750a --- /dev/null +++ b/internal/adapter/gemini/handler_test.go @@ -0,0 +1,174 @@ +package gemini + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/auth" +) + +type testGeminiConfig struct{} + +func (testGeminiConfig) ModelAliases() map[string]string { return nil } + +type testGeminiAuth struct { + a *auth.RequestAuth + err error +} + +func (m testGeminiAuth) Determine(_ *http.Request) (*auth.RequestAuth, error) { + if m.err != nil { + return nil, m.err + } + if m.a != nil { + return m.a, nil + } + return &auth.RequestAuth{ + UseConfigToken: false, + DeepSeekToken: "direct-token", + CallerID: "caller:test", + TriedAccounts: map[string]bool{}, + }, nil +} + +func (testGeminiAuth) Release(_ *auth.RequestAuth) {} + +type testGeminiDS struct { + resp *http.Response + err error +} + +func (m testGeminiDS) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) { + return "session-id", nil +} + +func (m testGeminiDS) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) { + return "pow", nil +} + +func (m testGeminiDS) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) { + if m.err != nil { + return nil, m.err + } + return m.resp, nil +} + +func makeGeminiUpstreamResponse(lines ...string) *http.Response { + body := strings.Join(lines, "\n") + if !strings.HasSuffix(body, "\n") { + body += "\n" + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func TestGeminiRoutesRegistered(t *testing.T) { + h := &Handler{ + Store: testGeminiConfig{}, + Auth: testGeminiAuth{err: auth.ErrUnauthorized}, + } + r := chi.NewRouter() + RegisterRoutes(r, h) + + paths := []string{ + "/v1beta/models/gemini-2.5-pro:generateContent", + "/v1beta/models/gemini-2.5-pro:streamGenerateContent", + "/v1/models/gemini-2.5-pro:generateContent", + "/v1/models/gemini-2.5-pro:streamGenerateContent", + } + for _, path := range paths { + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`)) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code == http.StatusNotFound { + t.Fatalf("expected route %s to be registered, got 404", path) + } + } +} + +func TestGenerateContentReturnsFunctionCallParts(t *testing.T) { + upstream := makeGeminiUpstreamResponse( + `data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`, + `data: [DONE]`, + ) + h := &Handler{ + Store: testGeminiConfig{}, + Auth: testGeminiAuth{}, + DS: testGeminiDS{resp: upstream}, + } + r := chi.NewRouter() + RegisterRoutes(r, h) + + body := `{ + "contents":[{"role":"user","parts":[{"text":"call tool"}]}], + "tools":[{"functionDeclarations":[{"name":"eval_javascript","description":"eval","parameters":{"type":"object","properties":{"code":{"type":"string"}}}}]}] + }` + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer direct-token") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode response failed: %v", err) + } + candidates, _ := out["candidates"].([]any) + if len(candidates) == 0 { + t.Fatalf("expected non-empty candidates: %#v", out) + } + c0, _ := candidates[0].(map[string]any) + content, _ := c0["content"].(map[string]any) + parts, _ := content["parts"].([]any) + if len(parts) == 0 { + t.Fatalf("expected non-empty parts: %#v", content) + } + part0, _ := parts[0].(map[string]any) + functionCall, _ := part0["functionCall"].(map[string]any) + if functionCall["name"] != "eval_javascript" { + t.Fatalf("expected functionCall name eval_javascript, got %#v", functionCall) + } +} + +func TestStreamGenerateContentEmitsSSE(t *testing.T) { + upstream := makeGeminiUpstreamResponse( + `data: {"p":"response/content","v":"hello "}`, + `data: {"p":"response/content","v":"world"}`, + `data: [DONE]`, + ) + h := &Handler{ + Store: testGeminiConfig{}, + Auth: testGeminiAuth{}, + DS: testGeminiDS{resp: upstream}, + } + r := chi.NewRouter() + RegisterRoutes(r, h) + + body := `{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}` + req := httptest.NewRequest(http.MethodPost, "/v1/models/gemini-2.5-pro:streamGenerateContent?alt=sse", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer direct-token") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "data: ") { + t.Fatalf("expected SSE data frames, got body=%s", rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), `"finishReason":"STOP"`) { + t.Fatalf("expected stream finish frame, got body=%s", rec.Body.String()) + } +} diff --git a/internal/adapter/openai/chat_stream_runtime.go b/internal/adapter/openai/chat_stream_runtime.go index 0e64bc5..d9a1ba4 100644 --- a/internal/adapter/openai/chat_stream_runtime.go +++ b/internal/adapter/openai/chat_stream_runtime.go @@ -25,10 +25,11 @@ type chatStreamRuntime struct { thinkingEnabled bool searchEnabled bool - firstChunkSent bool - bufferToolContent bool - emitEarlyToolDeltas bool - toolCallsEmitted bool + firstChunkSent bool + bufferToolContent bool + emitEarlyToolDeltas bool + toolCallsEmitted bool + toolCallsDoneEmitted bool toolSieve toolStreamSieveState streamToolCallIDs map[int]string @@ -96,10 +97,10 @@ func (s *chatStreamRuntime) finalize(finishReason string) { finalThinking := s.thinking.String() finalText := s.text.String() detected := util.ParseToolCalls(finalText, s.toolNames) - if len(detected) > 0 && !s.toolCallsEmitted { + if len(detected) > 0 && !s.toolCallsDoneEmitted { finishReason = "tool_calls" delta := map[string]any{ - "tool_calls": util.FormatOpenAIStreamToolCalls(detected), + "tool_calls": formatFinalStreamToolCallsWithStableIDs(detected, s.streamToolCallIDs), } if !s.firstChunkSent { delta["role"] = "assistant" @@ -112,8 +113,29 @@ func (s *chatStreamRuntime) finalize(finishReason string) { []map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, delta)}, nil, )) + s.toolCallsEmitted = true + s.toolCallsDoneEmitted = true } else if s.bufferToolContent { for _, evt := range flushToolSieve(&s.toolSieve, s.toolNames) { + if len(evt.ToolCalls) > 0 { + finishReason = "tool_calls" + s.toolCallsEmitted = true + s.toolCallsDoneEmitted = true + tcDelta := map[string]any{ + "tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs), + } + if !s.firstChunkSent { + tcDelta["role"] = "assistant" + s.firstChunkSent = true + } + s.sendChunk(openaifmt.BuildChatStreamChunk( + s.completionID, + s.created, + s.model, + []map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, tcDelta)}, + nil, + )) + } if evt.Content == "" { continue } @@ -189,10 +211,14 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD if !s.emitEarlyToolDeltas { continue } - s.toolCallsEmitted = true - tcDelta := map[string]any{ - "tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs), + formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs) + if len(formatted) == 0 { + continue } + tcDelta := map[string]any{ + "tool_calls": formatted, + } + s.toolCallsEmitted = true if !s.firstChunkSent { tcDelta["role"] = "assistant" s.firstChunkSent = true @@ -202,8 +228,9 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD } if len(evt.ToolCalls) > 0 { s.toolCallsEmitted = true + s.toolCallsDoneEmitted = true tcDelta := map[string]any{ - "tool_calls": util.FormatOpenAIStreamToolCalls(evt.ToolCalls), + "tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs), } if !s.firstChunkSent { tcDelta["role"] = "assistant" diff --git a/internal/adapter/openai/deps_injection_test.go b/internal/adapter/openai/deps_injection_test.go index baa0c11..6286c0c 100644 --- a/internal/adapter/openai/deps_injection_test.go +++ b/internal/adapter/openai/deps_injection_test.go @@ -31,7 +31,7 @@ func TestNormalizeOpenAIChatRequestWithConfigInterface(t *testing.T) { "model": "my-model", "messages": []any{map[string]any{"role": "user", "content": "hello"}}, } - out, err := normalizeOpenAIChatRequest(cfg, req) + out, err := normalizeOpenAIChatRequest(cfg, req, "") if err != nil { t.Fatalf("normalizeOpenAIChatRequest error: %v", err) } @@ -52,7 +52,7 @@ func TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface(t *testing. _, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{ aliases: map[string]string{}, wideInput: false, - }, req) + }, req, "") if err == nil { t.Fatal("expected error when wide input is disabled and only input is provided") } @@ -60,7 +60,7 @@ func TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface(t *testing. out, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{ aliases: map[string]string{}, wideInput: true, - }, req) + }, req, "") if err != nil { t.Fatalf("unexpected error when wide input is enabled: %v", err) } diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go index 28a451c..391a035 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -93,7 +93,7 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { writeOpenAIError(w, http.StatusBadRequest, "invalid json") return } - stdReq, err := normalizeOpenAIChatRequest(h.Store, req) + stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r)) if err != nil { writeOpenAIError(w, http.StatusBadRequest, err.Error()) return @@ -154,7 +154,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") rc := http.NewResponseController(w) - canFlush := rc.Flush() == nil + _, canFlush := w.(http.Flusher) if !canFlush { config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered") } @@ -233,7 +233,7 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, if len(toolSchemas) == 0 { return messages, names } - toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error." + toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block." for i := range messages { if messages[i]["role"] == "system" { @@ -280,6 +280,36 @@ func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]s return out } +func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any { + if len(calls) == 0 { + return nil + } + out := make([]map[string]any, 0, len(calls)) + for i, c := range calls { + callID := "" + if ids != nil { + callID = strings.TrimSpace(ids[i]) + } + if callID == "" { + callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "") + if ids != nil { + ids[i] = callID + } + } + args, _ := json.Marshal(c.Input) + out = append(out, map[string]any{ + "index": i, + "id": callID, + "type": "function", + "function": map[string]any{ + "name": c.Name, + "arguments": string(args), + }, + }) + } + return out +} + func writeOpenAIError(w http.ResponseWriter, status int, message string) { writeJSON(w, status, map[string]any{ "error": map[string]any{ diff --git a/internal/adapter/openai/handler_toolcall_test.go b/internal/adapter/openai/handler_toolcall_test.go index dd2bb0f..2027729 100644 --- a/internal/adapter/openai/handler_toolcall_test.go +++ b/internal/adapter/openai/handler_toolcall_test.go @@ -735,3 +735,73 @@ func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) { t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) } } + +func TestHandleStreamMultiToolCallDoesNotMergeNamesOrArguments(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search_web\",\"input\":{\"query\":\"latest ai news\"}},{"}`, + `data: {"p":"response/content","v":"\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + h.handleStream(rec, req, resp, "cid12", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"}) + + frames, done := parseSSEDataFrames(t, rec.Body.String()) + if !done { + t.Fatalf("expected [DONE], body=%s", rec.Body.String()) + } + if !streamHasToolCallsDelta(frames) { + t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) + } + + foundSearch := false + foundEval := false + foundIndex1 := false + toolCallsDeltaLens := make([]int, 0, 2) + for _, frame := range frames { + choices, _ := frame["choices"].([]any) + for _, item := range choices { + choice, _ := item.(map[string]any) + delta, _ := choice["delta"].(map[string]any) + rawToolCalls, hasToolCalls := delta["tool_calls"] + if !hasToolCalls { + continue + } + toolCalls, _ := rawToolCalls.([]any) + toolCallsDeltaLens = append(toolCallsDeltaLens, len(toolCalls)) + for _, tc := range toolCalls { + tcm, _ := tc.(map[string]any) + if idx, ok := tcm["index"].(float64); ok && int(idx) == 1 { + foundIndex1 = true + } + fn, _ := tcm["function"].(map[string]any) + name, _ := fn["name"].(string) + switch name { + case "search_web": + foundSearch = true + case "eval_javascript": + foundEval = true + case "search_webeval_javascript": + t.Fatalf("unexpected merged tool name: %s, body=%s", name, rec.Body.String()) + } + if args, ok := fn["arguments"].(string); ok && strings.Contains(args, `}{"`) { + t.Fatalf("unexpected concatenated tool arguments: %q, body=%s", args, rec.Body.String()) + } + } + } + } + if !foundSearch || !foundEval { + t.Fatalf("expected both tool names in stream deltas, foundSearch=%v foundEval=%v body=%s", foundSearch, foundEval, rec.Body.String()) + } + if len(toolCallsDeltaLens) != 1 || toolCallsDeltaLens[0] != 2 { + t.Fatalf("expected exactly one tool_calls delta with two calls, got lens=%v body=%s", toolCallsDeltaLens, rec.Body.String()) + } + if !foundIndex1 { + t.Fatalf("expected second tool call index in stream deltas, body=%s", rec.Body.String()) + } + if streamFinishReason(frames) != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) + } +} diff --git a/internal/adapter/openai/message_normalize.go b/internal/adapter/openai/message_normalize.go index 3ebd1e7..a767960 100644 --- a/internal/adapter/openai/message_normalize.go +++ b/internal/adapter/openai/message_normalize.go @@ -4,9 +4,11 @@ import ( "encoding/json" "fmt" "strings" + + "ds2api/internal/config" ) -func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any { +func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any { out := make([]map[string]any, 0, len(raw)) for _, item := range raw { msg, ok := item.(map[string]any) @@ -17,7 +19,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any { switch role { case "assistant": content := normalizeOpenAIContentForPrompt(msg["content"]) - toolCalls := formatAssistantToolCallsForPrompt(msg) + toolCalls := formatAssistantToolCallsForPrompt(msg, traceID) combined := joinNonEmpty(content, toolCalls) if combined == "" { continue @@ -53,7 +55,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any { return out } -func formatAssistantToolCallsForPrompt(msg map[string]any) string { +func formatAssistantToolCallsForPrompt(msg map[string]any, traceID string) string { entries := make([]string, 0) if calls, ok := msg["tool_calls"].([]any); ok { for i, item := range calls { @@ -86,7 +88,8 @@ func formatAssistantToolCallsForPrompt(msg map[string]any) string { if args == "" { args = "{}" } - entries = append(entries, fmt.Sprintf("Tool call:\n- tool_call_id: %s\n- function.name: %s\n- function.arguments: %s", id, name, args)) + maybeWarnSuspiciousToolHistory(traceID, id, name, args) + entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: %s\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", id, name, args)) } } @@ -99,7 +102,8 @@ func formatAssistantToolCallsForPrompt(msg map[string]any) string { if args == "" { args = "{}" } - entries = append(entries, fmt.Sprintf("Tool call:\n- tool_call_id: call_legacy\n- function.name: %s\n- function.arguments: %s", name, args)) + maybeWarnSuspiciousToolHistory(traceID, "call_legacy", name, args) + entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: call_legacy\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", name, args)) } return strings.Join(entries, "\n\n") @@ -124,7 +128,7 @@ func formatToolResultForPrompt(msg map[string]any) string { content = "null" } - return fmt.Sprintf("Tool result:\n- tool_call_id: %s\n- name: %s\n- content: %s", toolCallID, name, content) + return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content) } func normalizeOpenAIContentForPrompt(v any) string { @@ -190,3 +194,45 @@ func joinNonEmpty(parts ...string) string { } return strings.Join(nonEmpty, "\n\n") } + +func maybeWarnSuspiciousToolHistory(traceID, callID, name, args string) { + if !looksLikeConcatenatedJSON(args) { + return + } + traceID = strings.TrimSpace(traceID) + if traceID == "" { + traceID = "unknown" + } + config.Logger.Warn( + "[openai] suspicious tool call history payload detected", + "trace_id", traceID, + "tool_call_id", strings.TrimSpace(callID), + "name", strings.TrimSpace(name), + "arguments_preview", previewToolArgs(args, 160), + ) +} + +func looksLikeConcatenatedJSON(raw string) bool { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return false + } + if strings.Contains(trimmed, "}{") || strings.Contains(trimmed, "][") { + return true + } + dec := json.NewDecoder(strings.NewReader(trimmed)) + var first any + if err := dec.Decode(&first); err != nil { + return false + } + var second any + return dec.Decode(&second) == nil +} + +func previewToolArgs(raw string, max int) string { + trimmed := strings.TrimSpace(raw) + if max <= 0 || len(trimmed) <= max { + return trimmed + } + return trimmed[:max] +} diff --git a/internal/adapter/openai/message_normalize_test.go b/internal/adapter/openai/message_normalize_test.go index bb648d3..30403bc 100644 --- a/internal/adapter/openai/message_normalize_test.go +++ b/internal/adapter/openai/message_normalize_test.go @@ -33,23 +33,24 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *tes }, } - normalized := normalizeOpenAIMessagesForPrompt(raw) + normalized := normalizeOpenAIMessagesForPrompt(raw, "") if len(normalized) != 4 { t.Fatalf("expected 4 normalized messages, got %d", len(normalized)) } assistantContent, _ := normalized[2]["content"].(string) - if !strings.Contains(assistantContent, "tool_call_id: call_1") || + if !strings.Contains(assistantContent, "[TOOL_CALL_HISTORY]") || + !strings.Contains(assistantContent, "tool_call_id: call_1") || !strings.Contains(assistantContent, "function.name: get_weather") || !strings.Contains(assistantContent, "function.arguments: {\"city\":\"beijing\"}") { t.Fatalf("assistant tool call not serialized correctly: %q", assistantContent) } toolContent, _ := normalized[3]["content"].(string) - if !strings.Contains(toolContent, "Tool result:") || !strings.Contains(toolContent, "name: get_weather") { + if !strings.Contains(toolContent, "[TOOL_RESULT_HISTORY]") || !strings.Contains(toolContent, "name: get_weather") { t.Fatalf("tool result not serialized correctly: %q", toolContent) } prompt := util.MessagesPrepare(normalized) - if !strings.Contains(prompt, "tool_call_id: call_1") || !strings.Contains(prompt, "Tool result:") { + if !strings.Contains(prompt, "tool_call_id: call_1") || !strings.Contains(prompt, "[TOOL_RESULT_HISTORY]") { t.Fatalf("expected prompt to include tool call + result semantics: %q", prompt) } } @@ -67,7 +68,7 @@ func TestNormalizeOpenAIMessagesForPrompt_ToolObjectContentPreserved(t *testing. }, } - normalized := normalizeOpenAIMessagesForPrompt(raw) + normalized := normalizeOpenAIMessagesForPrompt(raw, "") got, _ := normalized[0]["content"].(string) if !strings.Contains(got, `"temp":18`) || !strings.Contains(got, `"condition":"sunny"`) { t.Fatalf("expected serialized object in tool content, got %q", got) @@ -88,7 +89,7 @@ func TestNormalizeOpenAIMessagesForPrompt_ToolArrayBlocksJoined(t *testing.T) { }, } - normalized := normalizeOpenAIMessagesForPrompt(raw) + normalized := normalizeOpenAIMessagesForPrompt(raw, "") got, _ := normalized[0]["content"].(string) if !strings.Contains(got, "line-1\nline-2") { t.Fatalf("expected joined text blocks, got %q", got) @@ -107,7 +108,7 @@ func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) { }, } - normalized := normalizeOpenAIMessagesForPrompt(raw) + normalized := normalizeOpenAIMessagesForPrompt(raw, "") if len(normalized) != 1 { t.Fatalf("expected one normalized message, got %d", len(normalized)) } @@ -119,3 +120,50 @@ func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) { t.Fatalf("unexpected normalized function-role content: %q", got) } } + +func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSeparated(t *testing.T) { + raw := []any{ + map[string]any{ + "role": "assistant", + "tool_calls": []any{ + map[string]any{ + "id": "call_search", + "type": "function", + "function": map[string]any{ + "name": "search_web", + "arguments": `{"query":"latest ai news"}`, + }, + }, + map[string]any{ + "id": "call_eval", + "type": "function", + "function": map[string]any{ + "name": "eval_javascript", + "arguments": `{"code":"1+1"}`, + }, + }, + }, + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw, "") + if len(normalized) != 1 { + t.Fatalf("expected one normalized assistant message, got %d", len(normalized)) + } + content, _ := normalized[0]["content"].(string) + if strings.Count(content, "[TOOL_CALL_HISTORY]") != 2 { + t.Fatalf("expected two TOOL_CALL_HISTORY blocks, got %q", content) + } + if !strings.Contains(content, "tool_call_id: call_search") || !strings.Contains(content, "function.name: search_web") { + t.Fatalf("missing first tool call block, got %q", content) + } + if !strings.Contains(content, "tool_call_id: call_eval") || !strings.Contains(content, "function.name: eval_javascript") { + t.Fatalf("missing second tool call block, got %q", content) + } + if strings.Contains(content, "search_webeval_javascript") { + t.Fatalf("unexpected merged function name detected: %q", content) + } + if strings.Contains(content, `}{"`) { + t.Fatalf("unexpected concatenated function arguments detected: %q", content) + } +} diff --git a/internal/adapter/openai/prompt_build.go b/internal/adapter/openai/prompt_build.go index f83963f..76739ed 100644 --- a/internal/adapter/openai/prompt_build.go +++ b/internal/adapter/openai/prompt_build.go @@ -4,11 +4,18 @@ import ( "ds2api/internal/deepseek" ) -func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any) (string, []string) { - messages := normalizeOpenAIMessagesForPrompt(messagesRaw) +func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any, traceID string) (string, []string) { + messages := normalizeOpenAIMessagesForPrompt(messagesRaw, traceID) toolNames := []string{} if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 { messages, toolNames = injectToolPrompt(messages, tools) } return deepseek.MessagesPrepare(messages), toolNames } + +// BuildPromptForAdapter exposes the OpenAI-compatible prompt building flow so +// other protocol adapters (for example Gemini) can reuse the same tool/history +// normalization logic and remain behavior-compatible with chat/completions. +func BuildPromptForAdapter(messagesRaw []any, toolsRaw any, traceID string) (string, []string) { + return buildOpenAIFinalPrompt(messagesRaw, toolsRaw, traceID) +} diff --git a/internal/adapter/openai/prompt_build_test.go b/internal/adapter/openai/prompt_build_test.go index 1833860..bd6223e 100644 --- a/internal/adapter/openai/prompt_build_test.go +++ b/internal/adapter/openai/prompt_build_test.go @@ -40,13 +40,13 @@ func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *tes }, } - finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools) + finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools, "") if len(toolNames) != 1 || toolNames[0] != "get_weather" { t.Fatalf("unexpected tool names: %#v", toolNames) } if !strings.Contains(finalPrompt, "tool_call_id: call_1") || !strings.Contains(finalPrompt, "function.name: get_weather") || - !strings.Contains(finalPrompt, "Tool result:") || + !strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") || !strings.Contains(finalPrompt, `"condition":"sunny"`) { t.Fatalf("handler finalPrompt missing tool roundtrip semantics: %q", finalPrompt) } @@ -70,11 +70,14 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t * }, } - finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools) + finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "") if !strings.Contains(finalPrompt, "After receiving a tool result, you MUST use it to produce the final answer.") { t.Fatalf("vercel prepare finalPrompt missing final-answer instruction: %q", finalPrompt) } if !strings.Contains(finalPrompt, "Only call another tool when the previous result is missing required data or returned an error.") { t.Fatalf("vercel prepare finalPrompt missing retry guard instruction: %q", finalPrompt) } + if !strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") { + t.Fatalf("vercel prepare finalPrompt missing history marker instruction: %q", finalPrompt) + } } diff --git a/internal/adapter/openai/responses_embeddings_test.go b/internal/adapter/openai/responses_embeddings_test.go index d270e1a..a5e2b72 100644 --- a/internal/adapter/openai/responses_embeddings_test.go +++ b/internal/adapter/openai/responses_embeddings_test.go @@ -1,6 +1,7 @@ package openai import ( + "strings" "testing" "time" ) @@ -32,6 +33,82 @@ func TestResponsesMessagesFromRequestWithInstructions(t *testing.T) { } } +func TestNormalizeResponsesInputAsMessagesObjectRoleContentBlocks(t *testing.T) { + msgs := normalizeResponsesInputAsMessages(map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "input_text", "text": "line-1"}, + map[string]any{"type": "input_text", "text": "line-2"}, + }, + }) + if len(msgs) != 1 { + t.Fatalf("expected one message, got %d", len(msgs)) + } + m, _ := msgs[0].(map[string]any) + if m["role"] != "user" { + t.Fatalf("unexpected role: %#v", m) + } + if strings.TrimSpace(normalizeOpenAIContentForPrompt(m["content"])) != "line-1\nline-2" { + t.Fatalf("unexpected content: %#v", m["content"]) + } +} + +func TestNormalizeResponsesInputAsMessagesFunctionCallOutput(t *testing.T) { + msgs := normalizeResponsesInputAsMessages([]any{ + map[string]any{ + "type": "function_call_output", + "call_id": "call_123", + "output": map[string]any{"ok": true}, + }, + }) + if len(msgs) != 1 { + t.Fatalf("expected one message, got %d", len(msgs)) + } + m, _ := msgs[0].(map[string]any) + if m["role"] != "tool" { + t.Fatalf("expected tool role, got %#v", m) + } + if m["tool_call_id"] != "call_123" { + t.Fatalf("expected tool_call_id propagated, got %#v", m) + } +} + +func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) { + msgs := normalizeResponsesInputAsMessages([]any{ + map[string]any{ + "type": "function_call", + "call_id": "call_456", + "name": "search", + "arguments": `{"q":"golang"}`, + }, + }) + if len(msgs) != 1 { + t.Fatalf("expected one message, got %d", len(msgs)) + } + m, _ := msgs[0].(map[string]any) + if m["role"] != "assistant" { + t.Fatalf("expected assistant role, got %#v", m["role"]) + } + toolCalls, _ := m["tool_calls"].([]any) + if len(toolCalls) != 1 { + t.Fatalf("expected one tool_call, got %#v", m["tool_calls"]) + } + call, _ := toolCalls[0].(map[string]any) + if call["id"] != "call_456" { + t.Fatalf("expected call id preserved, got %#v", call) + } + if call["type"] != "function" { + t.Fatalf("expected function type, got %#v", call) + } + fn, _ := call["function"].(map[string]any) + if fn["name"] != "search" { + t.Fatalf("expected call name preserved, got %#v", call) + } + if fn["arguments"] != `{"q":"golang"}` { + t.Fatalf("expected call arguments preserved, got %#v", call) + } +} + func TestExtractEmbeddingInputs(t *testing.T) { got := extractEmbeddingInputs([]any{"a", "b"}) if len(got) != 2 || got[0] != "a" || got[1] != "b" { diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index e767b2b..e71cafe 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -68,7 +68,7 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) { writeOpenAIError(w, http.StatusBadRequest, "invalid json") return } - stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req) + stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req, requestTraceID(r)) if err != nil { writeOpenAIError(w, http.StatusBadRequest, err.Error()) return @@ -128,7 +128,7 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") rc := http.NewResponseController(w) - canFlush := rc.Flush() == nil + _, canFlush := w.(http.Flusher) initialType := "text" if thinkingEnabled { @@ -203,40 +203,231 @@ func normalizeResponsesInputAsMessages(input any) []any { } return []any{map[string]any{"role": "user", "content": v}} case []any: - if len(v) == 0 { - return nil - } - // If caller already provides role-shaped items, keep as-is. - if first, ok := v[0].(map[string]any); ok { - if _, hasRole := first["role"]; hasRole { - return v - } - } - parts := make([]string, 0, len(v)) - for _, item := range v { - if m, ok := item.(map[string]any); ok { - if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") { - if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { - parts = append(parts, txt) - continue - } - } - } - if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" { - parts = append(parts, s) - } - } - if len(parts) == 0 { - return nil - } - return []any{map[string]any{"role": "user", "content": strings.Join(parts, "\n")}} + return normalizeResponsesInputArray(v) case map[string]any: + if msg := normalizeResponsesInputItem(v); msg != nil { + return []any{msg} + } if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" { return []any{map[string]any{"role": "user", "content": txt}} } - if content, ok := v["content"].(string); ok && strings.TrimSpace(content) != "" { - return []any{map[string]any{"role": "user", "content": content}} + if content, ok := v["content"]; ok { + if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" { + return []any{map[string]any{"role": "user", "content": content}} + } } } return nil } + +func normalizeResponsesInputArray(items []any) []any { + if len(items) == 0 { + return nil + } + out := make([]any, 0, len(items)) + fallbackParts := make([]string, 0, len(items)) + flushFallback := func() { + if len(fallbackParts) == 0 { + return + } + out = append(out, map[string]any{"role": "user", "content": strings.Join(fallbackParts, "\n")}) + fallbackParts = fallbackParts[:0] + } + + for _, item := range items { + switch x := item.(type) { + case map[string]any: + if msg := normalizeResponsesInputItem(x); msg != nil { + flushFallback() + out = append(out, msg) + continue + } + if s := normalizeResponsesFallbackPart(x); s != "" { + fallbackParts = append(fallbackParts, s) + } + default: + if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" { + fallbackParts = append(fallbackParts, s) + } + } + } + flushFallback() + if len(out) == 0 { + return nil + } + return out +} + +func normalizeResponsesInputItem(m map[string]any) map[string]any { + if m == nil { + return nil + } + + role := strings.ToLower(strings.TrimSpace(asString(m["role"]))) + if role != "" { + content := m["content"] + if content == nil { + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + content = txt + } + } + if content == nil { + return nil + } + return map[string]any{ + "role": role, + "content": content, + } + } + + itemType := strings.ToLower(strings.TrimSpace(asString(m["type"]))) + switch itemType { + case "message", "input_message": + content := m["content"] + if content == nil { + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + content = txt + } + } + if content == nil { + return nil + } + role := strings.ToLower(strings.TrimSpace(asString(m["role"]))) + if role == "" { + role = "user" + } + return map[string]any{ + "role": role, + "content": content, + } + case "function_call_output", "tool_result": + content := m["output"] + if content == nil { + content = m["content"] + } + if content == nil { + content = "" + } + out := map[string]any{ + "role": "tool", + "content": content, + } + if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" { + out["tool_call_id"] = callID + } else if callID = strings.TrimSpace(asString(m["tool_call_id"])); callID != "" { + out["tool_call_id"] = callID + } + if name := strings.TrimSpace(asString(m["name"])); name != "" { + out["name"] = name + } else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" { + out["name"] = name + } + return out + case "function_call", "tool_call": + name := strings.TrimSpace(asString(m["name"])) + var fn map[string]any + if rawFn, ok := m["function"].(map[string]any); ok { + fn = rawFn + if name == "" { + name = strings.TrimSpace(asString(fn["name"])) + } + } + if name == "" { + return nil + } + + var argsRaw any + if v, ok := m["arguments"]; ok { + argsRaw = v + } else if v, ok := m["input"]; ok { + argsRaw = v + } + if argsRaw == nil && fn != nil { + if v, ok := fn["arguments"]; ok { + argsRaw = v + } else if v, ok := fn["input"]; ok { + argsRaw = v + } + } + + functionPayload := map[string]any{ + "name": name, + "arguments": stringifyToolCallArguments(argsRaw), + } + call := map[string]any{ + "type": "function", + "function": functionPayload, + } + if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" { + call["id"] = callID + } else if callID = strings.TrimSpace(asString(m["id"])); callID != "" { + call["id"] = callID + } + return map[string]any{ + "role": "assistant", + "tool_calls": []any{call}, + } + case "input_text": + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + return map[string]any{ + "role": "user", + "content": txt, + } + } + } + + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + return map[string]any{ + "role": "user", + "content": txt, + } + } + if content, ok := m["content"]; ok { + if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" { + return map[string]any{ + "role": "user", + "content": content, + } + } + } + return nil +} + +func normalizeResponsesFallbackPart(m map[string]any) string { + if m == nil { + return "" + } + if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") { + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + return txt + } + } + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + return txt + } + if content, ok := m["content"]; ok { + if normalized := strings.TrimSpace(normalizeOpenAIContentForPrompt(content)); normalized != "" { + return normalized + } + } + return strings.TrimSpace(fmt.Sprintf("%v", m)) +} + +func stringifyToolCallArguments(v any) string { + switch x := v.(type) { + case nil: + return "{}" + case string: + s := strings.TrimSpace(x) + if s == "" { + return "{}" + } + return s + default: + b, err := json.Marshal(x) + if err != nil || len(b) == 0 { + return "{}" + } + return string(b) + } +} diff --git a/internal/adapter/openai/responses_stream_runtime.go b/internal/adapter/openai/responses_stream_runtime.go index f7e8b20..050965c 100644 --- a/internal/adapter/openai/responses_stream_runtime.go +++ b/internal/adapter/openai/responses_stream_runtime.go @@ -3,12 +3,15 @@ package openai import ( "encoding/json" "net/http" + "sort" "strings" openaifmt "ds2api/internal/format/openai" "ds2api/internal/sse" streamengine "ds2api/internal/stream" "ds2api/internal/util" + + "github.com/google/uuid" ) type responsesStreamRuntime struct { @@ -24,14 +27,20 @@ type responsesStreamRuntime struct { thinkingEnabled bool searchEnabled bool - bufferToolContent bool - emitEarlyToolDeltas bool - toolCallsEmitted bool + bufferToolContent bool + emitEarlyToolDeltas bool + toolCallsEmitted bool + toolCallsDoneEmitted bool sieve toolStreamSieveState + thinkingSieve toolStreamSieveState thinking strings.Builder text strings.Builder streamToolCallIDs map[int]string + streamFunctionIDs map[int]string + functionDone map[int]bool + toolCallsDoneSigs map[string]bool + reasoningItemID string persistResponse func(obj map[string]any) } @@ -63,6 +72,9 @@ func newResponsesStreamRuntime( bufferToolContent: bufferToolContent, emitEarlyToolDeltas: emitEarlyToolDeltas, streamToolCallIDs: map[int]string{}, + streamFunctionIDs: map[int]string{}, + functionDone: map[int]bool{}, + toolCallsDoneSigs: map[string]bool{}, persistResponse: persistResponse, } } @@ -92,19 +104,33 @@ func (s *responsesStreamRuntime) sendDone() { func (s *responsesStreamRuntime) finalize() { finalThinking := s.thinking.String() finalText := s.text.String() + if strings.TrimSpace(finalThinking) != "" { + s.sendEvent("response.reasoning_text.done", openaifmt.BuildResponsesReasoningTextDonePayload(s.responseID, s.ensureReasoningItemID(), 0, 0, finalThinking)) + } if s.bufferToolContent { - for _, evt := range flushToolSieve(&s.sieve, s.toolNames) { - if evt.Content != "" { - s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content)) - } - if len(evt.ToolCalls) > 0 { - s.toolCallsEmitted = true - s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls))) + s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true) + s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false) + } + // Compatibility fallback: some streams only emit incremental tool deltas. + // Ensure final function_call_arguments.done is emitted at least once. + if s.toolCallsEmitted { + detected := util.ParseToolCalls(finalText, s.toolNames) + if len(detected) == 0 { + detected = util.ParseToolCalls(finalThinking, s.toolNames) + } + if len(detected) > 0 { + if !s.toolCallsDoneEmitted { + s.emitToolCallsDone(detected) + } else { + s.emitFunctionCallDoneEvents(detected) } } } obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames) + if s.toolCallsEmitted { + s.alignCompletedOutputCallIDs(obj) + } if s.toolCallsEmitted { obj["status"] = "completed" } @@ -138,6 +164,10 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa } s.thinking.WriteString(p.Text) s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text)) + s.sendEvent("response.reasoning_text.delta", openaifmt.BuildResponsesReasoningTextDeltaPayload(s.responseID, s.ensureReasoningItemID(), 0, 0, p.Text)) + if s.bufferToolContent { + s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false) + } continue } @@ -146,23 +176,191 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text)) continue } - for _, evt := range processToolSieveChunk(&s.sieve, p.Text, s.toolNames) { - if evt.Content != "" { - s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content)) - } - if len(evt.ToolCallDeltas) > 0 { - if !s.emitEarlyToolDeltas { - continue - } - s.toolCallsEmitted = true - s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs))) - } - if len(evt.ToolCalls) > 0 { - s.toolCallsEmitted = true - s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls))) - } - } + s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true) } return streamengine.ParsedDecision{ContentSeen: contentSeen} } + +func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) { + for _, evt := range events { + if emitContent && evt.Content != "" { + s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content)) + } + if len(evt.ToolCallDeltas) > 0 { + if !s.emitEarlyToolDeltas { + continue + } + formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs) + if len(formatted) == 0 { + continue + } + s.toolCallsEmitted = true + s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatted)) + s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas) + } + if len(evt.ToolCalls) > 0 { + s.emitToolCallsDone(evt.ToolCalls) + } + } +} + +func (s *responsesStreamRuntime) emitToolCallsDone(calls []util.ParsedToolCall) { + if len(calls) == 0 { + return + } + sig := toolCallListSignature(calls) + if sig != "" && s.toolCallsDoneSigs[sig] { + return + } + if sig != "" { + s.toolCallsDoneSigs[sig] = true + } + formatted := formatFinalStreamToolCallsWithStableIDs(calls, s.streamToolCallIDs) + if len(formatted) == 0 { + return + } + s.toolCallsEmitted = true + s.toolCallsDoneEmitted = true + s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatted)) + s.emitFunctionCallDoneEvents(calls) +} + +func (s *responsesStreamRuntime) ensureReasoningItemID() string { + if strings.TrimSpace(s.reasoningItemID) != "" { + return s.reasoningItemID + } + s.reasoningItemID = "rs_" + strings.ReplaceAll(uuid.NewString(), "-", "") + return s.reasoningItemID +} + +func (s *responsesStreamRuntime) ensureFunctionItemID(index int) string { + if id, ok := s.streamFunctionIDs[index]; ok && strings.TrimSpace(id) != "" { + return id + } + id := "fc_" + strings.ReplaceAll(uuid.NewString(), "-", "") + s.streamFunctionIDs[index] = id + return id +} + +func (s *responsesStreamRuntime) ensureToolCallID(index int) string { + if id, ok := s.streamToolCallIDs[index]; ok && strings.TrimSpace(id) != "" { + return id + } + id := "call_" + strings.ReplaceAll(uuid.NewString(), "-", "") + s.streamToolCallIDs[index] = id + return id +} + +func (s *responsesStreamRuntime) functionOutputBaseIndex() int { + if strings.TrimSpace(s.thinking.String()) != "" { + return 1 + } + return 0 +} + +func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDelta) { + for _, d := range deltas { + if strings.TrimSpace(d.Arguments) == "" { + continue + } + outputIndex := s.functionOutputBaseIndex() + d.Index + itemID := s.ensureFunctionItemID(outputIndex) + callID := s.ensureToolCallID(d.Index) + s.sendEvent( + "response.function_call_arguments.delta", + openaifmt.BuildResponsesFunctionCallArgumentsDeltaPayload(s.responseID, itemID, outputIndex, callID, d.Arguments), + ) + } +} + +func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedToolCall) { + base := s.functionOutputBaseIndex() + for idx, tc := range calls { + if strings.TrimSpace(tc.Name) == "" { + continue + } + outputIndex := base + idx + if s.functionDone[outputIndex] { + continue + } + itemID := s.ensureFunctionItemID(outputIndex) + callID := s.ensureToolCallID(idx) + argsBytes, _ := json.Marshal(tc.Input) + s.sendEvent( + "response.function_call_arguments.done", + openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, string(argsBytes)), + ) + s.functionDone[outputIndex] = true + } +} + +func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any) { + if obj == nil || len(s.streamToolCallIDs) == 0 { + return + } + output, _ := obj["output"].([]any) + if len(output) == 0 { + return + } + indices := make([]int, 0, len(s.streamToolCallIDs)) + for idx := range s.streamToolCallIDs { + indices = append(indices, idx) + } + sort.Ints(indices) + ordered := make([]string, 0, len(indices)) + for _, idx := range indices { + id := strings.TrimSpace(s.streamToolCallIDs[idx]) + if id == "" { + continue + } + ordered = append(ordered, id) + } + if len(ordered) == 0 { + return + } + + functionIdx := 0 + for _, item := range output { + m, _ := item.(map[string]any) + if m == nil { + continue + } + typ, _ := m["type"].(string) + switch typ { + case "function_call": + if functionIdx < len(ordered) { + m["call_id"] = ordered[functionIdx] + functionIdx++ + } + case "tool_calls": + tcArr, _ := m["tool_calls"].([]any) + for i, raw := range tcArr { + tc, _ := raw.(map[string]any) + if tc == nil { + continue + } + if i < len(ordered) { + tc["id"] = ordered[i] + } + } + } + } +} + +func toolCallListSignature(calls []util.ParsedToolCall) string { + if len(calls) == 0 { + return "" + } + var b strings.Builder + for i, tc := range calls { + if i > 0 { + b.WriteString("|") + } + b.WriteString(strings.TrimSpace(tc.Name)) + b.WriteString(":") + args, _ := json.Marshal(tc.Input) + b.Write(args) + } + return b.String() +} diff --git a/internal/adapter/openai/responses_stream_test.go b/internal/adapter/openai/responses_stream_test.go index 9b0a5ac..a513e6f 100644 --- a/internal/adapter/openai/responses_stream_test.go +++ b/internal/adapter/openai/responses_stream_test.go @@ -45,17 +45,37 @@ func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T if len(output) == 0 { t.Fatalf("expected structured output entries, got %#v", responseObj["output"]) } - first, _ := output[0].(map[string]any) - if first["type"] != "tool_calls" { - t.Fatalf("expected first output type tool_calls, got %#v", first["type"]) + var firstToolWrapper map[string]any + hasFunctionCall := false + for _, item := range output { + m, _ := item.(map[string]any) + if m == nil { + continue + } + if m["type"] == "function_call" { + hasFunctionCall = true + } + if m["type"] == "tool_calls" && firstToolWrapper == nil { + firstToolWrapper = m + } } - toolCalls, _ := first["tool_calls"].([]any) + if !hasFunctionCall { + t.Fatalf("expected at least one function_call item for responses compatibility, got %#v", responseObj["output"]) + } + if firstToolWrapper == nil { + t.Fatalf("expected a tool_calls wrapper item, got %#v", responseObj["output"]) + } + toolCalls, _ := firstToolWrapper["tool_calls"].([]any) if len(toolCalls) == 0 { - t.Fatalf("expected at least one tool_call in output, got %#v", first["tool_calls"]) + t.Fatalf("expected at least one tool_call in output, got %#v", firstToolWrapper["tool_calls"]) } call0, _ := toolCalls[0].(map[string]any) - if call0["name"] != "read_file" { - t.Fatalf("unexpected tool call name: %#v", call0["name"]) + if call0["type"] != "function" { + t.Fatalf("unexpected tool call type: %#v", call0["type"]) + } + fn, _ := call0["function"].(map[string]any) + if fn["name"] != "read_file" { + t.Fatalf("unexpected tool call name: %#v", fn["name"]) } if strings.Contains(outputText, `"tool_calls"`) { t.Fatalf("raw tool_calls JSON leaked in output_text: %q", outputText) @@ -95,6 +115,314 @@ func TestHandleResponsesStreamIncompleteTailNotDuplicatedInCompletedOutputText(t } } +func TestHandleResponsesStreamEmitsReasoningCompatEvents(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + b, _ := json.Marshal(map[string]any{ + "p": "response/thinking_content", + "v": "thought", + }) + streamBody := "data: " + string(b) + "\n" + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, nil) + + body := rec.Body.String() + if !strings.Contains(body, "event: response.reasoning.delta") { + t.Fatalf("expected response.reasoning.delta event, body=%s", body) + } + if !strings.Contains(body, "event: response.reasoning_text.delta") { + t.Fatalf("expected response.reasoning_text.delta compatibility event, body=%s", body) + } + if !strings.Contains(body, "event: response.reasoning_text.done") { + t.Fatalf("expected response.reasoning_text.done compatibility event, body=%s", body) + } +} + +func TestHandleResponsesStreamEmitsFunctionCallCompatEvents(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}) + body := rec.Body.String() + if !strings.Contains(body, "event: response.function_call_arguments.delta") { + t.Fatalf("expected response.function_call_arguments.delta compatibility event, body=%s", body) + } + if !strings.Contains(body, "event: response.function_call_arguments.done") { + t.Fatalf("expected response.function_call_arguments.done compatibility event, body=%s", body) + } + donePayload, ok := extractSSEEventPayload(body, "response.function_call_arguments.done") + if !ok { + t.Fatalf("expected to parse response.function_call_arguments.done payload, body=%s", body) + } + if strings.TrimSpace(asString(donePayload["call_id"])) == "" { + t.Fatalf("expected call_id in response.function_call_arguments.done payload, payload=%#v", donePayload) + } + if strings.TrimSpace(asString(donePayload["response_id"])) == "" { + t.Fatalf("expected response_id in response.function_call_arguments.done payload, payload=%#v", donePayload) + } + doneCallID := strings.TrimSpace(asString(donePayload["call_id"])) + if doneCallID == "" { + t.Fatalf("expected non-empty call_id in done payload, payload=%#v", donePayload) + } + completed, ok := extractSSEEventPayload(body, "response.completed") + if !ok { + t.Fatalf("expected response.completed payload, body=%s", body) + } + responseObj, _ := completed["response"].(map[string]any) + output, _ := responseObj["output"].([]any) + if len(output) == 0 { + t.Fatalf("expected non-empty output in response.completed, response=%#v", responseObj) + } + var completedCallID string + for _, item := range output { + m, _ := item.(map[string]any) + if m == nil || m["type"] != "function_call" { + continue + } + completedCallID = strings.TrimSpace(asString(m["call_id"])) + if completedCallID != "" { + break + } + } + if completedCallID == "" { + t.Fatalf("expected function_call.call_id in completed output, output=%#v", output) + } + if completedCallID != doneCallID { + t.Fatalf("expected completed call_id to match stream done call_id, done=%q completed=%q", doneCallID, completedCallID) + } +} + +func TestHandleResponsesStreamDetectsToolCallsFromThinkingChannel(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(path, v string) string { + b, _ := json.Marshal(map[string]any{ + "p": path, + "v": v, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}) + + body := rec.Body.String() + if !strings.Contains(body, "event: response.reasoning_text.delta") { + t.Fatalf("expected response.reasoning_text.delta event, body=%s", body) + } + if !strings.Contains(body, "event: response.function_call_arguments.done") { + t.Fatalf("expected response.function_call_arguments.done event from thinking channel, body=%s", body) + } + if !strings.Contains(body, "event: response.output_tool_call.done") { + t.Fatalf("expected response.output_tool_call.done event from thinking channel, body=%s", body) + } +} + +func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine(`{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) + + sseLine(`{"name":"eval_javascript","input":{"code":"1+1"}}]}`) + + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"}) + + body := rec.Body.String() + if !strings.Contains(body, "event: response.output_tool_call.done") { + t.Fatalf("expected response.output_tool_call.done event, body=%s", body) + } + donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done") + if len(donePayloads) != 2 { + t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body) + } + + seenNames := map[string]string{} + for _, payload := range donePayloads { + name := strings.TrimSpace(asString(payload["name"])) + callID := strings.TrimSpace(asString(payload["call_id"])) + args := strings.TrimSpace(asString(payload["arguments"])) + if callID == "" { + t.Fatalf("expected non-empty call_id in done payload: %#v", payload) + } + if strings.Contains(args, `}{"`) { + t.Fatalf("unexpected concatenated arguments in done payload: %#v", payload) + } + if name == "search_webeval_javascript" { + t.Fatalf("unexpected merged tool name in done payload: %#v", payload) + } + if name != "search_web" && name != "eval_javascript" { + t.Fatalf("unexpected tool name in done payload: %#v", payload) + } + seenNames[name] = callID + } + if seenNames["search_web"] == "" || seenNames["eval_javascript"] == "" { + t.Fatalf("expected done events for both tools, got %#v", seenNames) + } + if seenNames["search_web"] == seenNames["eval_javascript"] { + t.Fatalf("expected distinct call_id per tool, got %#v", seenNames) + } + + completed, ok := extractSSEEventPayload(body, "response.completed") + if !ok { + t.Fatalf("expected response.completed event, body=%s", body) + } + responseObj, _ := completed["response"].(map[string]any) + output, _ := responseObj["output"].([]any) + functionCallIDs := map[string]string{} + for _, item := range output { + m, _ := item.(map[string]any) + if m == nil || m["type"] != "function_call" { + continue + } + name := strings.TrimSpace(asString(m["name"])) + callID := strings.TrimSpace(asString(m["call_id"])) + if name != "" && callID != "" { + functionCallIDs[name] = callID + } + } + if functionCallIDs["search_web"] != seenNames["search_web"] { + t.Fatalf("search_web call_id mismatch between done and completed: done=%q completed=%q", seenNames["search_web"], functionCallIDs["search_web"]) + } + if functionCallIDs["eval_javascript"] != seenNames["eval_javascript"] { + t.Fatalf("eval_javascript call_id mismatch between done and completed: done=%q completed=%q", seenNames["eval_javascript"], functionCallIDs["eval_javascript"]) + } +} + +func TestHandleResponsesStreamMultiToolCallFromThinkingChannel(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(path, v string) string { + b, _ := json.Marshal(map[string]any{ + "p": path, + "v": v, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) + + sseLine("response/thinking_content", `{"name":"eval_javascript","input":{"code":"1+1"}}]}`) + + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"search_web", "eval_javascript"}) + + body := rec.Body.String() + if !strings.Contains(body, "event: response.reasoning_text.delta") { + t.Fatalf("expected reasoning stream events, body=%s", body) + } + donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done") + if len(donePayloads) != 2 { + t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body) + } + seen := map[string]bool{} + for _, payload := range donePayloads { + name := strings.TrimSpace(asString(payload["name"])) + if name == "search_webeval_javascript" { + t.Fatalf("unexpected merged tool name in thinking channel done payload: %#v", payload) + } + if name != "search_web" && name != "eval_javascript" { + t.Fatalf("unexpected tool name in thinking channel done payload: %#v", payload) + } + args := strings.TrimSpace(asString(payload["arguments"])) + if strings.Contains(args, `}{"`) { + t.Fatalf("unexpected concatenated arguments in thinking channel done payload: %#v", payload) + } + seen[name] = true + } + if !seen["search_web"] || !seen["eval_javascript"] { + t.Fatalf("expected both tools in thinking channel done events, got %#v", seen) + } +} + +func TestHandleResponsesStreamCompletedFollowsChatToolCallSemantics(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine("我来调用工具\n") + + sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}) + + completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed") + if !ok { + t.Fatalf("expected response.completed event, body=%s", rec.Body.String()) + } + responseObj, _ := completed["response"].(map[string]any) + output, _ := responseObj["output"].([]any) + hasFunctionCall := false + for _, item := range output { + m, _ := item.(map[string]any) + if m != nil && m["type"] == "function_call" { + hasFunctionCall = true + break + } + } + if !hasFunctionCall { + t.Fatalf("expected completed output to include function_call when mixed prose contains tool_calls payload, output=%#v", output) + } +} + func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) { scanner := bufio.NewScanner(strings.NewReader(body)) matched := false @@ -120,3 +448,30 @@ func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) { } return nil, false } + +func extractAllSSEEventPayloads(body, targetEvent string) []map[string]any { + scanner := bufio.NewScanner(strings.NewReader(body)) + matched := false + out := make([]map[string]any, 0, 2) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "event: ") { + evt := strings.TrimSpace(strings.TrimPrefix(line, "event: ")) + matched = evt == targetEvent + continue + } + if !matched || !strings.HasPrefix(line, "data: ") { + continue + } + raw := strings.TrimSpace(strings.TrimPrefix(line, "data: ")) + if raw == "" || raw == "[DONE]" { + continue + } + var payload map[string]any + if err := json.Unmarshal([]byte(raw), &payload); err != nil { + continue + } + out = append(out, payload) + } + return out +} diff --git a/internal/adapter/openai/standard_request.go b/internal/adapter/openai/standard_request.go index 5883d03..7683ee7 100644 --- a/internal/adapter/openai/standard_request.go +++ b/internal/adapter/openai/standard_request.go @@ -8,7 +8,7 @@ import ( "ds2api/internal/util" ) -func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.StandardRequest, error) { +func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID string) (util.StandardRequest, error) { model, _ := req["model"].(string) messagesRaw, _ := req["messages"].([]any) if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 { @@ -23,7 +23,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.St if responseModel == "" { responseModel = resolvedModel } - finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"]) + finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID) passThrough := collectOpenAIChatPassThrough(req) return util.StandardRequest{ @@ -41,7 +41,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.St }, nil } -func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any) (util.StandardRequest, error) { +func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, traceID string) (util.StandardRequest, error) { model, _ := req["model"].(string) model = strings.TrimSpace(model) if model == "" { @@ -67,7 +67,7 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any) (ut if len(messagesRaw) == 0 { return util.StandardRequest{}, fmt.Errorf("Request must include 'input' or 'messages'.") } - finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"]) + finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID) passThrough := collectOpenAIChatPassThrough(req) return util.StandardRequest{ diff --git a/internal/adapter/openai/standard_request_test.go b/internal/adapter/openai/standard_request_test.go index f3453a2..a876364 100644 --- a/internal/adapter/openai/standard_request_test.go +++ b/internal/adapter/openai/standard_request_test.go @@ -22,7 +22,7 @@ func TestNormalizeOpenAIChatRequest(t *testing.T) { "temperature": 0.3, "stream": true, } - n, err := normalizeOpenAIChatRequest(store, req) + n, err := normalizeOpenAIChatRequest(store, req, "") if err != nil { t.Fatalf("normalize failed: %v", err) } @@ -47,7 +47,7 @@ func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) { "input": "ping", "instructions": "system", } - n, err := normalizeOpenAIResponsesRequest(store, req) + n, err := normalizeOpenAIResponsesRequest(store, req, "") if err != nil { t.Fatalf("normalize failed: %v", err) } diff --git a/internal/adapter/openai/stream_status_test.go b/internal/adapter/openai/stream_status_test.go new file mode 100644 index 0000000..4f8305a --- /dev/null +++ b/internal/adapter/openai/stream_status_test.go @@ -0,0 +1,185 @@ +package openai + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + chimw "github.com/go-chi/chi/v5/middleware" + + "ds2api/internal/auth" +) + +type streamStatusAuthStub struct{} + +func (streamStatusAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) { + return &auth.RequestAuth{ + UseConfigToken: false, + DeepSeekToken: "direct-token", + CallerID: "caller:test", + TriedAccounts: map[string]bool{}, + }, nil +} + +func (streamStatusAuthStub) DetermineCaller(_ *http.Request) (*auth.RequestAuth, error) { + return &auth.RequestAuth{ + UseConfigToken: false, + DeepSeekToken: "direct-token", + CallerID: "caller:test", + TriedAccounts: map[string]bool{}, + }, nil +} + +func (streamStatusAuthStub) Release(_ *auth.RequestAuth) {} + +type streamStatusDSStub struct { + resp *http.Response +} + +func (m streamStatusDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) { + return "session-id", nil +} + +func (m streamStatusDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) { + return "pow", nil +} + +func (m streamStatusDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) { + return m.resp, nil +} + +func makeOpenAISSEHTTPResponse(lines ...string) *http.Response { + body := strings.Join(lines, "\n") + if !strings.HasSuffix(body, "\n") { + body += "\n" + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func captureStatusMiddleware(statuses *[]int) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ww := chimw.NewWrapResponseWriter(w, r.ProtoMajor) + next.ServeHTTP(ww, r) + *statuses = append(*statuses, ww.Status()) + }) + } +} + +func TestChatCompletionsStreamStatusCapturedAs200(t *testing.T) { + statuses := make([]int, 0, 1) + h := &Handler{ + Store: mockOpenAIConfig{wideInput: true}, + Auth: streamStatusAuthStub{}, + DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, "data: [DONE]")}, + } + r := chi.NewRouter() + r.Use(captureStatusMiddleware(&statuses)) + RegisterRoutes(r, h) + + reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":"hi"}],"stream":true}` + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody)) + req.Header.Set("Authorization", "Bearer direct-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if len(statuses) != 1 { + t.Fatalf("expected one captured status, got %d", len(statuses)) + } + if statuses[0] != http.StatusOK { + t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0]) + } +} + +func TestResponsesStreamStatusCapturedAs200(t *testing.T) { + statuses := make([]int, 0, 1) + h := &Handler{ + Store: mockOpenAIConfig{wideInput: true}, + Auth: streamStatusAuthStub{}, + DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, "data: [DONE]")}, + } + r := chi.NewRouter() + r.Use(captureStatusMiddleware(&statuses)) + RegisterRoutes(r, h) + + reqBody := `{"model":"deepseek-chat","input":"hi","stream":true}` + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody)) + req.Header.Set("Authorization", "Bearer direct-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if len(statuses) != 1 { + t.Fatalf("expected one captured status, got %d", len(statuses)) + } + if statuses[0] != http.StatusOK { + t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0]) + } +} + +func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) { + statuses := make([]int, 0, 1) + content, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": "我来调用工具\n{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}", + }) + h := &Handler{ + Store: mockOpenAIConfig{wideInput: true}, + Auth: streamStatusAuthStub{}, + DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse("data: "+string(content), "data: [DONE]")}, + } + r := chi.NewRouter() + r.Use(captureStatusMiddleware(&statuses)) + RegisterRoutes(r, h) + + reqBody := `{"model":"deepseek-chat","input":"请调用工具","tools":[{"type":"function","function":{"name":"read_file","description":"read","parameters":{"type":"object","properties":{"path":{"type":"string"}}}}}],"stream":false}` + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody)) + req.Header.Set("Authorization", "Bearer direct-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if len(statuses) != 1 || statuses[0] != http.StatusOK { + t.Fatalf("expected captured status 200, got %#v", statuses) + } + + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String()) + } + outputText, _ := out["output_text"].(string) + if outputText != "" { + t.Fatalf("expected output_text hidden for tool call payload, got %q", outputText) + } + output, _ := out["output"].([]any) + hasFunctionCall := false + for _, item := range output { + m, _ := item.(map[string]any) + if m != nil && m["type"] == "function_call" { + hasFunctionCall = true + break + } + } + if !hasFunctionCall { + t.Fatalf("expected function_call output item, got %#v", output) + } +} diff --git a/internal/adapter/openai/tool_sieve.go b/internal/adapter/openai/tool_sieve.go index fd7222b..9c46649 100644 --- a/internal/adapter/openai/tool_sieve.go +++ b/internal/adapter/openai/tool_sieve.go @@ -11,6 +11,7 @@ type toolStreamSieveState struct { capture strings.Builder capturing bool recentTextTail string + disableDeltas bool toolNameSent bool toolName string toolArgsStart int @@ -35,6 +36,7 @@ const toolSieveCaptureLimit = 8 * 1024 const toolSieveContextTailLimit = 256 func (s *toolStreamSieveState) resetIncrementalToolState() { + s.disableDeltas = false s.toolNameSent = false s.toolName = "" s.toolArgsStart = -1 @@ -239,17 +241,8 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix } parsed := util.ParseStandaloneToolCalls(obj, toolNames) if len(parsed) == 0 { - if state.toolNameSent { - return prefixPart, nil, suffixPart, true - } return captured, nil, "", true } - if state.toolNameSent { - if len(parsed) > 1 { - return prefixPart, parsed[1:], suffixPart, true - } - return prefixPart, nil, suffixPart, true - } return prefixPart, parsed, suffixPart, true } @@ -296,6 +289,9 @@ func extractJSONObjectFrom(text string, start int) (string, int, bool) { } func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta { + if state.disableDeltas { + return nil + } captured := state.capture.String() if captured == "" { return nil @@ -312,6 +308,16 @@ func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta { if insideCodeFence(state.recentTextTail + captured[:start]) { return nil } + certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx) + if hasMultiple { + state.disableDeltas = true + return nil + } + if !certainSingle { + // In uncertain phases (e.g. first call arrived but array not closed yet), + // avoid speculative deltas and wait for final parsed tool_calls payload. + return nil + } callStart, ok := findFirstToolCallObjectStart(captured, keyIdx) if !ok { return nil @@ -363,6 +369,68 @@ func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta { return deltas } +func classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) { + arrStart, ok := findToolCallsArrayStart(text, keyIdx) + if !ok { + return false, false + } + i := skipSpaces(text, arrStart+1) + if i >= len(text) || text[i] != '{' { + return false, false + } + count := 0 + depth := 0 + quote := byte(0) + escaped := false + for ; i < len(text); i++ { + ch := text[i] + if quote != 0 { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '"' || ch == '\'' { + quote = ch + continue + } + if ch == '{' { + if depth == 0 { + count++ + if count > 1 { + return false, true + } + } + depth++ + continue + } + if ch == '}' { + if depth > 0 { + depth-- + } + continue + } + if ch == ',' && depth == 0 { + // top-level separator means at least one more tool call exists + // (or is expected). Treat as multi-call and stop incremental deltas. + return false, true + } + if ch == ']' && depth == 0 { + return count == 1, false + } + } + // array not closed yet: still uncertain whether more calls will appear + return false, false +} + func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) { arrStart, ok := findToolCallsArrayStart(text, keyIdx) if !ok { diff --git a/internal/adapter/openai/trace.go b/internal/adapter/openai/trace.go new file mode 100644 index 0000000..8ea58f0 --- /dev/null +++ b/internal/adapter/openai/trace.go @@ -0,0 +1,21 @@ +package openai + +import ( + "net/http" + "strings" + + "github.com/go-chi/chi/v5/middleware" +) + +func requestTraceID(r *http.Request) string { + if r == nil { + return "" + } + if q := strings.TrimSpace(r.URL.Query().Get("__trace_id")); q != "" { + return q + } + if h := strings.TrimSpace(r.Header.Get("X-Ds2-Test-Trace")); h != "" { + return h + } + return strings.TrimSpace(middleware.GetReqID(r.Context())) +} diff --git a/internal/adapter/openai/trace_test.go b/internal/adapter/openai/trace_test.go new file mode 100644 index 0000000..cbacbf3 --- /dev/null +++ b/internal/adapter/openai/trace_test.go @@ -0,0 +1,47 @@ +package openai + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5/middleware" +) + +func traceIDViaMiddleware(req *http.Request) string { + if req == nil { + return requestTraceID(nil) + } + var got string + h := middleware.RequestID(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + got = requestTraceID(r) + })) + h.ServeHTTP(httptest.NewRecorder(), req) + return got +} + +func TestRequestTraceIDPriority(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions?__trace_id=query-trace", nil) + req.Header.Set("X-Ds2-Test-Trace", "header-trace") + got := traceIDViaMiddleware(req) + if got != "query-trace" { + t.Fatalf("expected query trace id to win, got %q", got) + } +} + +func TestRequestTraceIDHeaderFallback(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil) + req.Header.Set("X-Ds2-Test-Trace", "header-trace") + got := traceIDViaMiddleware(req) + if got != "header-trace" { + t.Fatalf("expected header trace id to win when query missing, got %q", got) + } +} + +func TestRequestTraceIDReqIDFallback(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil) + got := traceIDViaMiddleware(req) + if got == "" { + t.Fatal("expected middleware request id fallback to be non-empty") + } +} diff --git a/internal/adapter/openai/vercel_stream.go b/internal/adapter/openai/vercel_stream.go index 65006c4..f34ea8b 100644 --- a/internal/adapter/openai/vercel_stream.go +++ b/internal/adapter/openai/vercel_stream.go @@ -56,7 +56,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque writeOpenAIError(w, http.StatusBadRequest, "stream must be true") return } - stdReq, err := normalizeOpenAIChatRequest(h.Store, req) + stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r)) if err != nil { writeOpenAIError(w, http.StatusBadRequest, err.Error()) return diff --git a/internal/admin/handler.go b/internal/admin/handler.go index 829b657..c8f7702 100644 --- a/internal/admin/handler.go +++ b/internal/admin/handler.go @@ -36,5 +36,7 @@ func RegisterRoutes(r chi.Router, h *Handler) { pr.Post("/vercel/sync", h.syncVercel) pr.Get("/vercel/status", h.vercelStatus) pr.Get("/export", h.exportConfig) + pr.Get("/dev/captures", h.getDevCaptures) + pr.Delete("/dev/captures", h.clearDevCaptures) }) } diff --git a/internal/admin/handler_dev_capture.go b/internal/admin/handler_dev_capture.go new file mode 100644 index 0000000..9b3615c --- /dev/null +++ b/internal/admin/handler_dev_capture.go @@ -0,0 +1,26 @@ +package admin + +import ( + "net/http" + + "ds2api/internal/devcapture" +) + +func (h *Handler) getDevCaptures(w http.ResponseWriter, _ *http.Request) { + store := devcapture.Global() + writeJSON(w, http.StatusOK, map[string]any{ + "enabled": store.Enabled(), + "limit": store.Limit(), + "max_body_bytes": store.MaxBodyBytes(), + "items": store.Snapshot(), + }) +} + +func (h *Handler) clearDevCaptures(w http.ResponseWriter, _ *http.Request) { + store := devcapture.Global() + store.Clear() + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "detail": "capture logs cleared", + }) +} diff --git a/internal/admin/handler_dev_capture_test.go b/internal/admin/handler_dev_capture_test.go new file mode 100644 index 0000000..90ced8b --- /dev/null +++ b/internal/admin/handler_dev_capture_test.go @@ -0,0 +1,45 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestGetDevCapturesShape(t *testing.T) { + h := &Handler{} + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/admin/dev/captures", nil) + h.getDevCaptures(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode failed: %v", err) + } + if _, ok := out["enabled"]; !ok { + t.Fatalf("expected enabled field, got %#v", out) + } + if _, ok := out["items"]; !ok { + t.Fatalf("expected items field, got %#v", out) + } +} + +func TestClearDevCapturesShape(t *testing.T) { + h := &Handler{} + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, "/admin/dev/captures", nil) + h.clearDevCaptures(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode failed: %v", err) + } + if out["success"] != true { + t.Fatalf("expected success=true, got %#v", out) + } +} diff --git a/internal/auth/request.go b/internal/auth/request.go index 25980cf..23acb40 100644 --- a/internal/auth/request.go +++ b/internal/auth/request.go @@ -187,7 +187,12 @@ func extractCallerToken(req *http.Request) string { return token } } - return strings.TrimSpace(req.Header.Get("x-api-key")) + if key := strings.TrimSpace(req.Header.Get("x-api-key")); key != "" { + return key + } + // Gemini AI Studio compatibility: allow query key fallback only when no + // header-based credential is present. + return strings.TrimSpace(req.URL.Query().Get("key")) } func callerTokenID(token string) string { diff --git a/internal/auth/request_test.go b/internal/auth/request_test.go index c292856..2eca44b 100644 --- a/internal/auth/request_test.go +++ b/internal/auth/request_test.go @@ -114,6 +114,40 @@ func TestDetermineMissingToken(t *testing.T) { } } +func TestDetermineWithQueryKeyUsesDirectToken(t *testing.T) { + r := newTestResolver(t) + req, _ := http.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent?key=direct-query-key", nil) + + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + if a.UseConfigToken { + t.Fatalf("expected direct token mode") + } + if a.DeepSeekToken != "direct-query-key" { + t.Fatalf("unexpected token: %q", a.DeepSeekToken) + } +} + +func TestDetermineHeaderTokenPrecedenceOverQueryKey(t *testing.T) { + r := newTestResolver(t) + req, _ := http.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent?key=query-key", nil) + req.Header.Set("x-api-key", "managed-key") + + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + defer r.Release(a) + if !a.UseConfigToken { + t.Fatalf("expected managed key mode from header token") + } + if a.AccountID == "" { + t.Fatalf("expected managed account to be acquired") + } +} + func TestDetermineCallerMissingToken(t *testing.T) { r := newTestResolver(t) req, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil) diff --git a/internal/deepseek/client.go b/internal/deepseek/client.go index 0523435..2ffe05d 100644 --- a/internal/deepseek/client.go +++ b/internal/deepseek/client.go @@ -16,6 +16,7 @@ import ( "ds2api/internal/auth" "ds2api/internal/config" trans "ds2api/internal/deepseek/transport" + "ds2api/internal/devcapture" "ds2api/internal/util" "github.com/andybalholm/brotli" @@ -27,6 +28,7 @@ var intFrom = util.IntFrom type Client struct { Store *config.Store Auth *auth.Resolver + capture *devcapture.Store regular trans.Doer stream trans.Doer fallback *http.Client @@ -39,6 +41,7 @@ func NewClient(store *config.Store, resolver *auth.Resolver) *Client { return &Client{ Store: store, Auth: resolver, + capture: devcapture.Global(), regular: trans.New(60 * time.Second), stream: trans.New(0), fallback: &http.Client{Timeout: 60 * time.Second}, @@ -179,6 +182,7 @@ func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payloa } headers := c.authHeaders(a.DeepSeekToken) headers["x-ds-pow-response"] = powResp + captureSession := c.capture.Start("deepseek_completion", DeepSeekCompletionURL, a.AccountID, payload) attempts := 0 for attempts < maxAttempts { resp, err := c.streamPost(ctx, DeepSeekCompletionURL, headers, payload) @@ -188,8 +192,14 @@ func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payloa continue } if resp.StatusCode == http.StatusOK { + if captureSession != nil { + resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode) + } return resp, nil } + if captureSession != nil { + resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode) + } _ = resp.Body.Close() attempts++ time.Sleep(time.Second) diff --git a/internal/devcapture/store.go b/internal/devcapture/store.go new file mode 100644 index 0000000..6d0d8cd --- /dev/null +++ b/internal/devcapture/store.go @@ -0,0 +1,259 @@ +package devcapture + +import ( + "encoding/json" + "fmt" + "io" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/google/uuid" +) + +const ( + defaultLimit = 5 + defaultMaxBodyBytes = 2 * 1024 * 1024 + maxLimit = 50 +) + +type Entry struct { + ID string `json:"id"` + CreatedAt int64 `json:"created_at"` + Label string `json:"label"` + URL string `json:"url"` + AccountID string `json:"account_id,omitempty"` + StatusCode int `json:"status_code"` + RequestBody string `json:"request_body"` + ResponseBody string `json:"response_body"` + ResponseTruncated bool `json:"response_truncated"` +} + +type Store struct { + mu sync.Mutex + enabled bool + limit int + maxBodyBytes int + items []Entry +} + +type Session struct { + store *Store + id string + createdAt int64 + label string + url string + accountID string + requestRaw string +} + +type captureBody struct { + rc io.ReadCloser + s *Session + statusCode int + buf strings.Builder + truncated bool + finalized bool +} + +var ( + globalOnce sync.Once + globalInst *Store +) + +func Global() *Store { + globalOnce.Do(func() { + globalInst = NewFromEnv() + }) + return globalInst +} + +func NewFromEnv() *Store { + enabled := !isVercelRuntime() + if raw, ok := os.LookupEnv("DS2API_DEV_PACKET_CAPTURE"); ok { + enabled = parseBool(raw) + } + limit := parseIntWithDefault(os.Getenv("DS2API_DEV_PACKET_CAPTURE_LIMIT"), defaultLimit) + if limit < 1 { + limit = defaultLimit + } + if limit > maxLimit { + limit = maxLimit + } + maxBodyBytes := parseIntWithDefault(os.Getenv("DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES"), defaultMaxBodyBytes) + if maxBodyBytes < 1024 { + maxBodyBytes = defaultMaxBodyBytes + } + return &Store{ + enabled: enabled, + limit: limit, + maxBodyBytes: maxBodyBytes, + items: make([]Entry, 0, limit), + } +} + +func isVercelRuntime() bool { + return strings.TrimSpace(os.Getenv("VERCEL")) != "" || strings.TrimSpace(os.Getenv("NOW_REGION")) != "" +} + +func (s *Store) Enabled() bool { + if s == nil { + return false + } + return s.enabled +} + +func (s *Store) Limit() int { + if s == nil { + return defaultLimit + } + return s.limit +} + +func (s *Store) MaxBodyBytes() int { + if s == nil { + return defaultMaxBodyBytes + } + return s.maxBodyBytes +} + +func (s *Store) Snapshot() []Entry { + if s == nil { + return nil + } + s.mu.Lock() + defer s.mu.Unlock() + out := make([]Entry, len(s.items)) + copy(out, s.items) + return out +} + +func (s *Store) Clear() { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.items = s.items[:0] +} + +func (s *Store) Start(label, url, accountID string, requestPayload any) *Session { + if s == nil || !s.enabled { + return nil + } + return &Session{ + store: s, + id: "cap_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + createdAt: time.Now().Unix(), + label: strings.TrimSpace(label), + url: strings.TrimSpace(url), + accountID: strings.TrimSpace(accountID), + requestRaw: marshalPayload(requestPayload), + } +} + +func (s *Session) WrapBody(rc io.ReadCloser, statusCode int) io.ReadCloser { + if s == nil || rc == nil { + return rc + } + return &captureBody{ + rc: rc, + s: s, + statusCode: statusCode, + } +} + +func (c *captureBody) Read(p []byte) (int, error) { + n, err := c.rc.Read(p) + if n > 0 { + c.append(string(p[:n])) + } + if err == io.EOF { + c.finalize() + } + return n, err +} + +func (c *captureBody) Close() error { + err := c.rc.Close() + c.finalize() + return err +} + +func (c *captureBody) append(chunk string) { + if chunk == "" || c.s == nil || c.s.store == nil { + return + } + maxLen := c.s.store.maxBodyBytes + current := c.buf.Len() + if current >= maxLen { + c.truncated = true + return + } + remain := maxLen - current + if len(chunk) > remain { + c.buf.WriteString(chunk[:remain]) + c.truncated = true + return + } + c.buf.WriteString(chunk) +} + +func (c *captureBody) finalize() { + if c.finalized || c.s == nil || c.s.store == nil { + return + } + c.finalized = true + entry := Entry{ + ID: c.s.id, + CreatedAt: c.s.createdAt, + Label: c.s.label, + URL: c.s.url, + AccountID: c.s.accountID, + StatusCode: c.statusCode, + RequestBody: c.s.requestRaw, + ResponseBody: c.buf.String(), + ResponseTruncated: c.truncated, + } + c.s.store.push(entry) +} + +func (s *Store) push(entry Entry) { + s.mu.Lock() + defer s.mu.Unlock() + s.items = append([]Entry{entry}, s.items...) + if len(s.items) > s.limit { + s.items = s.items[:s.limit] + } +} + +func marshalPayload(v any) string { + b, err := json.Marshal(v) + if err != nil { + return fmt.Sprintf("%v", v) + } + return string(b) +} + +func parseBool(v string) bool { + switch strings.ToLower(strings.TrimSpace(v)) { + case "1", "true", "yes", "on": + return true + default: + return false + } +} + +func parseIntWithDefault(raw string, d int) int { + raw = strings.TrimSpace(raw) + if raw == "" { + return d + } + n, err := strconv.Atoi(raw) + if err != nil { + return d + } + return n +} diff --git a/internal/devcapture/store_test.go b/internal/devcapture/store_test.go new file mode 100644 index 0000000..1dd58b4 --- /dev/null +++ b/internal/devcapture/store_test.go @@ -0,0 +1,55 @@ +package devcapture + +import ( + "io" + "strings" + "testing" +) + +func TestStorePushKeepsNewestWithinLimit(t *testing.T) { + s := &Store{enabled: true, limit: 2, maxBodyBytes: 1024} + for i := 0; i < 3; i++ { + session := s.Start("test", "http://x", "", map[string]any{"seq": i}) + if session == nil { + t.Fatal("expected session") + } + rc := session.WrapBody(io.NopCloser(strings.NewReader("ok")), 200) + _, _ = io.ReadAll(rc) + _ = rc.Close() + } + items := s.Snapshot() + if len(items) != 2 { + t.Fatalf("expected 2 items, got %d", len(items)) + } + if !strings.Contains(items[0].RequestBody, `"seq":2`) { + t.Fatalf("expected newest first, got %#v", items[0].RequestBody) + } + if !strings.Contains(items[1].RequestBody, `"seq":1`) { + t.Fatalf("expected second newest, got %#v", items[1].RequestBody) + } +} + +func TestWrapBodyTruncatesByLimit(t *testing.T) { + s := &Store{enabled: true, limit: 5, maxBodyBytes: 4} + session := s.Start("test", "http://x", "acc1", map[string]any{"x": 1}) + if session == nil { + t.Fatal("expected session") + } + rc := session.WrapBody(io.NopCloser(strings.NewReader("abcdef")), 200) + _, _ = io.ReadAll(rc) + _ = rc.Close() + + items := s.Snapshot() + if len(items) != 1 { + t.Fatalf("expected 1 item, got %d", len(items)) + } + if items[0].ResponseBody != "abcd" { + t.Fatalf("expected truncated body, got %q", items[0].ResponseBody) + } + if !items[0].ResponseTruncated { + t.Fatal("expected truncated flag true") + } + if items[0].AccountID != "acc1" { + t.Fatalf("expected account id, got %q", items[0].AccountID) + } +} diff --git a/internal/format/openai/render.go b/internal/format/openai/render.go index fc7473f..2107d4e 100644 --- a/internal/format/openai/render.go +++ b/internal/format/openai/render.go @@ -1,6 +1,7 @@ package openai import ( + "encoding/json" "strings" "time" @@ -43,36 +44,45 @@ func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalT } func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { + // Align responses tool-call semantics with chat/completions: + // mixed prose + tool_call payloads should still be interpreted as tool calls. detected := util.ParseToolCalls(finalText, toolNames) + if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" { + detected = util.ParseToolCalls(finalThinking, toolNames) + } exposedOutputText := finalText output := make([]any, 0, 2) if len(detected) > 0 { exposedOutputText = "" - toolCalls := make([]any, 0, len(detected)) - for _, tc := range detected { - toolCalls = append(toolCalls, map[string]any{ - "type": "tool_call", - "name": tc.Name, - "arguments": tc.Input, + if strings.TrimSpace(finalThinking) != "" { + output = append(output, map[string]any{ + "type": "reasoning", + "text": finalThinking, }) } + formatted := util.FormatOpenAIToolCalls(detected) + output = append(output, toResponsesFunctionCallItems(formatted)...) output = append(output, map[string]any{ "type": "tool_calls", - "tool_calls": toolCalls, + "tool_calls": formatted, }) } else { - content := []any{ - map[string]any{ - "type": "output_text", - "text": finalText, - }, - } + content := make([]any, 0, 2) if finalThinking != "" { content = append([]any{map[string]any{ "type": "reasoning", "text": finalThinking, }}, content...) } + if strings.TrimSpace(finalText) != "" { + content = append(content, map[string]any{ + "type": "output_text", + "text": finalText, + }) + } + if strings.TrimSpace(finalText) == "" && strings.TrimSpace(finalThinking) != "" { + exposedOutputText = finalThinking + } output = append(output, map[string]any{ "type": "message", "id": "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""), @@ -100,6 +110,54 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex } } +func toResponsesFunctionCallItems(toolCalls []map[string]any) []any { + if len(toolCalls) == 0 { + return nil + } + out := make([]any, 0, len(toolCalls)) + for _, tc := range toolCalls { + callID, _ := tc["id"].(string) + if strings.TrimSpace(callID) == "" { + callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "") + } + name := "" + args := "{}" + if fn, ok := tc["function"].(map[string]any); ok { + if n, _ := fn["name"].(string); strings.TrimSpace(n) != "" { + name = n + } + if a, _ := fn["arguments"].(string); strings.TrimSpace(a) != "" { + args = a + } + } + out = append(out, map[string]any{ + "id": "fc_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + "type": "function_call", + "call_id": callID, + "name": name, + "arguments": normalizeJSONString(args), + "status": "completed", + }) + } + return out +} + +func normalizeJSONString(raw string) string { + s := strings.TrimSpace(raw) + if s == "" { + return "{}" + } + var v any + if err := json.Unmarshal([]byte(s), &v); err != nil { + return raw + } + b, err := json.Marshal(v) + if err != nil { + return raw + } + return string(b) +} + func BuildChatStreamDeltaChoice(index int, delta map[string]any) map[string]any { return map[string]any{ "delta": delta, @@ -145,49 +203,105 @@ func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any func BuildResponsesCreatedPayload(responseID, model string) map[string]any { return map[string]any{ - "type": "response.created", - "id": responseID, - "object": "response", - "model": model, - "status": "in_progress", + "type": "response.created", + "id": responseID, + "response_id": responseID, + "object": "response", + "model": model, + "status": "in_progress", } } func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any { return map[string]any{ - "type": "response.output_text.delta", - "id": responseID, - "delta": delta, + "type": "response.output_text.delta", + "id": responseID, + "response_id": responseID, + "delta": delta, } } func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any { return map[string]any{ - "type": "response.reasoning.delta", - "id": responseID, - "delta": delta, + "type": "response.reasoning.delta", + "id": responseID, + "response_id": responseID, + "delta": delta, + } +} + +func BuildResponsesReasoningTextDeltaPayload(responseID, itemID string, outputIndex, contentIndex int, delta string) map[string]any { + return map[string]any{ + "type": "response.reasoning_text.delta", + "id": responseID, + "response_id": responseID, + "item_id": itemID, + "output_index": outputIndex, + "content_index": contentIndex, + "delta": delta, + } +} + +func BuildResponsesReasoningTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any { + return map[string]any{ + "type": "response.reasoning_text.done", + "id": responseID, + "response_id": responseID, + "item_id": itemID, + "output_index": outputIndex, + "content_index": contentIndex, + "text": text, } } func BuildResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any { return map[string]any{ - "type": "response.output_tool_call.delta", - "id": responseID, - "tool_calls": toolCalls, + "type": "response.output_tool_call.delta", + "id": responseID, + "response_id": responseID, + "tool_calls": toolCalls, } } func BuildResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any { return map[string]any{ - "type": "response.output_tool_call.done", - "id": responseID, - "tool_calls": toolCalls, + "type": "response.output_tool_call.done", + "id": responseID, + "response_id": responseID, + "tool_calls": toolCalls, + } +} + +func BuildResponsesFunctionCallArgumentsDeltaPayload(responseID, itemID string, outputIndex int, callID, delta string) map[string]any { + return map[string]any{ + "type": "response.function_call_arguments.delta", + "id": responseID, + "response_id": responseID, + "item_id": itemID, + "output_index": outputIndex, + "call_id": callID, + "delta": delta, + } +} + +func BuildResponsesFunctionCallArgumentsDonePayload(responseID, itemID string, outputIndex int, callID, name, arguments string) map[string]any { + return map[string]any{ + "type": "response.function_call_arguments.done", + "id": responseID, + "response_id": responseID, + "item_id": itemID, + "output_index": outputIndex, + "call_id": callID, + "name": name, + "arguments": normalizeJSONString(arguments), } } func BuildResponsesCompletedPayload(response map[string]any) map[string]any { + responseID, _ := response["id"].(string) return map[string]any{ - "type": "response.completed", - "response": response, + "type": "response.completed", + "response_id": responseID, + "response": response, } } diff --git a/internal/format/openai/render_test.go b/internal/format/openai/render_test.go new file mode 100644 index 0000000..e3bf0dd --- /dev/null +++ b/internal/format/openai/render_test.go @@ -0,0 +1,181 @@ +package openai + +import ( + "encoding/json" + "testing" +) + +func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) { + obj := BuildResponseObject( + "resp_test", + "gpt-4o", + "prompt", + "", + `{"tool_calls":[{"name":"search","input":{"q":"golang"}}]}`, + []string{"search"}, + ) + + outputText, _ := obj["output_text"].(string) + if outputText != "" { + t.Fatalf("expected output_text to be hidden for tool calls, got %q", outputText) + } + + output, _ := obj["output"].([]any) + if len(output) != 2 { + t.Fatalf("expected function_call + tool_calls wrapper, got %#v", obj["output"]) + } + + first, _ := output[0].(map[string]any) + if first["type"] != "function_call" { + t.Fatalf("expected first output item type function_call, got %#v", first["type"]) + } + if first["call_id"] == "" { + t.Fatalf("expected function_call item to have call_id, got %#v", first) + } + second, _ := output[1].(map[string]any) + if second["type"] != "tool_calls" { + t.Fatalf("expected second output item type tool_calls, got %#v", second["type"]) + } + var toolCalls []map[string]any + switch v := second["tool_calls"].(type) { + case []map[string]any: + toolCalls = v + case []any: + toolCalls = make([]map[string]any, 0, len(v)) + for _, item := range v { + m, _ := item.(map[string]any) + if m != nil { + toolCalls = append(toolCalls, m) + } + } + } + if len(toolCalls) != 1 { + t.Fatalf("expected one tool call, got %#v", second["tool_calls"]) + } + tc := toolCalls[0] + if tc["type"] != "function" || tc["id"] == "" { + t.Fatalf("unexpected tool call shape: %#v", tc) + } + fn, _ := tc["function"].(map[string]any) + if fn["name"] != "search" { + t.Fatalf("unexpected function name: %#v", fn["name"]) + } + argsRaw, _ := fn["arguments"].(string) + var args map[string]any + if err := json.Unmarshal([]byte(argsRaw), &args); err != nil { + t.Fatalf("arguments should be valid json string, got=%q err=%v", argsRaw, err) + } + if args["q"] != "golang" { + t.Fatalf("unexpected arguments: %#v", args) + } +} + +func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T) { + obj := BuildResponseObject( + "resp_test", + "gpt-4o", + "prompt", + "", + `示例格式:{"tool_calls":[{"name":"search","input":{"q":"golang"}}]},但这条是普通回答。`, + []string{"search"}, + ) + + outputText, _ := obj["output_text"].(string) + if outputText != "" { + t.Fatalf("expected output_text hidden once tool calls are detected, got %q", outputText) + } + + output, _ := obj["output"].([]any) + if len(output) != 2 { + t.Fatalf("expected function_call + tool_calls wrapper, got %#v", obj["output"]) + } + first, _ := output[0].(map[string]any) + if first["type"] != "function_call" { + t.Fatalf("expected first output type function_call, got %#v", first["type"]) + } +} + +func TestBuildResponseObjectFencedToolPayloadRemainsText(t *testing.T) { + obj := BuildResponseObject( + "resp_test", + "gpt-4o", + "prompt", + "", + "```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"golang\"}}]}\n```", + []string{"search"}, + ) + + outputText, _ := obj["output_text"].(string) + if outputText == "" { + t.Fatalf("expected output_text preserved for fenced example") + } + output, _ := obj["output"].([]any) + if len(output) != 1 { + t.Fatalf("expected one message output item, got %#v", obj["output"]) + } + first, _ := output[0].(map[string]any) + if first["type"] != "message" { + t.Fatalf("expected message output type, got %#v", first["type"]) + } +} + +func TestBuildResponseObjectReasoningOnlyFallsBackToOutputText(t *testing.T) { + obj := BuildResponseObject( + "resp_test", + "gpt-4o", + "prompt", + "internal thinking content", + "", + nil, + ) + + outputText, _ := obj["output_text"].(string) + if outputText == "" { + t.Fatalf("expected output_text fallback from reasoning when final text is empty") + } + + output, _ := obj["output"].([]any) + if len(output) != 1 { + t.Fatalf("expected one output item, got %#v", obj["output"]) + } + first, _ := output[0].(map[string]any) + if first["type"] != "message" { + t.Fatalf("expected output type message, got %#v", first["type"]) + } + content, _ := first["content"].([]any) + if len(content) == 0 { + t.Fatalf("expected reasoning content, got %#v", first["content"]) + } + block0, _ := content[0].(map[string]any) + if block0["type"] != "reasoning" { + t.Fatalf("expected first content block reasoning, got %#v", block0["type"]) + } +} + +func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) { + obj := BuildResponseObject( + "resp_test", + "gpt-4o", + "prompt", + `{"tool_calls":[{"name":"search","input":{"q":"from-thinking"}}]}`, + "", + []string{"search"}, + ) + + output, _ := obj["output"].([]any) + if len(output) != 3 { + t.Fatalf("expected reasoning + function_call + tool_calls outputs, got %#v", obj["output"]) + } + first, _ := output[0].(map[string]any) + if first["type"] != "reasoning" { + t.Fatalf("expected first output reasoning, got %#v", first["type"]) + } + second, _ := output[1].(map[string]any) + if second["type"] != "function_call" { + t.Fatalf("expected second output function_call, got %#v", second["type"]) + } + third, _ := output[2].(map[string]any) + if third["type"] != "tool_calls" { + t.Fatalf("expected third output tool_calls, got %#v", third["type"]) + } +} diff --git a/internal/server/router.go b/internal/server/router.go index a81f0cb..ae3108e 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -12,6 +12,7 @@ import ( "ds2api/internal/account" "ds2api/internal/adapter/claude" + "ds2api/internal/adapter/gemini" "ds2api/internal/adapter/openai" "ds2api/internal/admin" "ds2api/internal/auth" @@ -44,6 +45,7 @@ func NewApp() *App { openaiHandler := &openai.Handler{Store: store, Auth: resolver, DS: dsClient} claudeHandler := &claude.Handler{Store: store, Auth: resolver, DS: dsClient} + geminiHandler := &gemini.Handler{Store: store, Auth: resolver, DS: dsClient} adminHandler := &admin.Handler{Store: store, Pool: pool, DS: dsClient} webuiHandler := webui.NewHandler() @@ -67,6 +69,7 @@ func NewApp() *App { }) openai.RegisterRoutes(r, openaiHandler) claude.RegisterRoutes(r, claudeHandler) + gemini.RegisterRoutes(r, geminiHandler) r.Route("/admin", func(ar chi.Router) { admin.RegisterRoutes(ar, adminHandler) }) diff --git a/scripts/testsuite/run-live.sh b/tests/scripts/run-live.sh similarity index 100% rename from scripts/testsuite/run-live.sh rename to tests/scripts/run-live.sh diff --git a/tests/scripts/run-unit-all.sh b/tests/scripts/run-unit-all.sh new file mode 100755 index 0000000..59b202c --- /dev/null +++ b/tests/scripts/run-unit-all.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)" +cd "$ROOT_DIR" + +./tests/scripts/run-unit-go.sh +./tests/scripts/run-unit-node.sh diff --git a/tests/scripts/run-unit-go.sh b/tests/scripts/run-unit-go.sh new file mode 100755 index 0000000..38a11b8 --- /dev/null +++ b/tests/scripts/run-unit-go.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)" +cd "$ROOT_DIR" + +go test ./... "$@" diff --git a/tests/scripts/run-unit-node.sh b/tests/scripts/run-unit-node.sh new file mode 100755 index 0000000..95f11e0 --- /dev/null +++ b/tests/scripts/run-unit-node.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)" +cd "$ROOT_DIR" + +node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js "$@" diff --git a/vercel.json b/vercel.json index 2e68a94..bad49e0 100644 --- a/vercel.json +++ b/vercel.json @@ -38,6 +38,18 @@ "source": "/admin/config", "destination": "/api/index" }, + { + "source": "/admin/config/(.*)", + "destination": "/api/index" + }, + { + "source": "/admin/settings", + "destination": "/api/index" + }, + { + "source": "/admin/settings/(.*)", + "destination": "/api/index" + }, { "source": "/admin/keys(.*)", "destination": "/api/index" diff --git a/webui/src/App.jsx b/webui/src/App.jsx index 3f6ad27..2c3f099 100644 --- a/webui/src/App.jsx +++ b/webui/src/App.jsx @@ -1,4 +1,4 @@ -import { useState, useEffect } from 'react' +import { useState, useEffect, useCallback, useMemo } from 'react' import { Routes, Route, @@ -29,12 +29,12 @@ import Login from './components/Login' import LandingPage from './components/LandingPage' import LanguageToggle from './components/LanguageToggle' import { useI18n } from './i18n' +import { detectRuntimeEnv } from './utils/runtimeEnv' -function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message, onForceLogout }) { +function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message, onForceLogout, isVercel }) { const { t } = useI18n() const [activeTab, setActiveTab] = useState('accounts') const [sidebarOpen, setSidebarOpen] = useState(false) - const [loading, setLoading] = useState(false) const navItems = [ { id: 'accounts', label: t('nav.accounts.label'), icon: Users, description: t('nav.accounts.desc') }, @@ -44,7 +44,7 @@ function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message, { id: 'settings', label: t('nav.settings.label'), icon: SettingsIcon, description: t('nav.settings.desc') }, ] - const authFetch = async (url, options = {}) => { + const authFetch = useCallback(async (url, options = {}) => { const headers = { ...options.headers, 'Authorization': `Bearer ${token}` @@ -56,7 +56,7 @@ function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message, throw new Error(t('auth.expired')) } return res - } + }, [onLogout, t, token]) const renderTab = () => { switch (activeTab) { @@ -67,9 +67,9 @@ function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message, case 'import': return case 'vercel': - return + return case 'settings': - return + return default: return null } @@ -213,13 +213,27 @@ export default function App() { const navigate = useNavigate() const location = useLocation() const [config, setConfig] = useState({ keys: [], accounts: [] }) - const [loading, setLoading] = useState(true) const [message, setMessage] = useState(null) const [token, setToken] = useState(null) const [authChecking, setAuthChecking] = useState(true) const isProduction = import.meta.env.MODE === 'production' const isAdminRoute = location.pathname.startsWith('/admin') || isProduction + const runtimeEnv = useMemo(() => detectRuntimeEnv(), []) + const isVercel = runtimeEnv.isVercel + + const showMessage = useCallback((type, text) => { + setMessage({ type, text }) + setTimeout(() => setMessage(null), 5000) + }, []) + + const handleLogout = useCallback(() => { + setToken(null) + localStorage.removeItem('ds2api_token') + localStorage.removeItem('ds2api_token_expires') + sessionStorage.removeItem('ds2api_token') + sessionStorage.removeItem('ds2api_token_expires') + }, []) useEffect(() => { // Only check auth status on admin routes. @@ -249,12 +263,11 @@ export default function App() { setAuthChecking(false) } checkAuth() - }, [isAdminRoute]) + }, [handleLogout, isAdminRoute]) - const fetchConfig = async () => { + const fetchConfig = useCallback(async () => { if (!token) return try { - setLoading(true) const res = await fetch('/admin/config', { headers: { 'Authorization': `Bearer ${token}` } }) @@ -265,34 +278,19 @@ export default function App() { } catch (e) { console.error('Failed to fetch config:', e) showMessage('error', t('errors.fetchConfig', { error: e.message })) - } finally { - setLoading(false) } - } + }, [showMessage, t, token]) useEffect(() => { if (token) { fetchConfig() } - }, [token]) - - const showMessage = (type, text) => { - setMessage({ type, text }) - setTimeout(() => setMessage(null), 5000) - } + }, [fetchConfig, token]) const handleLogin = (newToken) => { setToken(newToken) } - const handleLogout = () => { - setToken(null) - localStorage.removeItem('ds2api_token') - localStorage.removeItem('ds2api_token_expires') - sessionStorage.removeItem('ds2api_token') - sessionStorage.removeItem('ds2api_token_expires') - } - // Wait for auth checks on admin routes. if (isAdminRoute && authChecking) { return ( @@ -320,6 +318,7 @@ export default function App() { showMessage={showMessage} message={message} onForceLogout={handleLogout} + isVercel={isVercel} /> ) : (
diff --git a/webui/src/components/Settings.jsx b/webui/src/components/Settings.jsx index b257ed5..927804e 100644 --- a/webui/src/components/Settings.jsx +++ b/webui/src/components/Settings.jsx @@ -2,7 +2,9 @@ import { useCallback, useEffect, useMemo, useState } from 'react' import { AlertTriangle, Download, Lock, Save, Upload } from 'lucide-react' import { useI18n } from '../i18n' -export default function Settings({ onRefresh, onMessage, authFetch, onForceLogout }) { +const MAX_AUTO_FETCH_FAILURES = 3 + +export default function Settings({ onRefresh, onMessage, authFetch, onForceLogout, isVercel = false }) { const { t } = useI18n() const apiFetch = authFetch || fetch @@ -14,6 +16,9 @@ export default function Settings({ onRefresh, onMessage, authFetch, onForceLogou const [importMode, setImportMode] = useState('merge') const [importText, setImportText] = useState('') const [newPassword, setNewPassword] = useState('') + const [consecutiveFailures, setConsecutiveFailures] = useState(0) + const [autoFetchPaused, setAutoFetchPaused] = useState(false) + const [lastError, setLastError] = useState('') const [settingsMeta, setSettingsMeta] = useState({ default_password_warning: false, env_backed: false, needs_vercel_sync: false }) const [form, setForm] = useState({ @@ -43,15 +48,38 @@ export default function Settings({ onRefresh, onMessage, authFetch, onForceLogou return parsed } - const loadSettings = useCallback(async () => { + const parseJSONResponse = useCallback(async (res) => { + const contentType = String(res.headers.get('content-type') || '').toLowerCase() + if (!contentType.includes('application/json')) { + throw new Error(t('settings.nonJsonResponse', { status: res.status })) + } + return res.json() + }, [t]) + + const loadSettings = useCallback(async ({ manual = false } = {}) => { + if (isVercel && autoFetchPaused && !manual) { + return + } setLoading(true) try { const res = await apiFetch('/admin/settings') - const data = await res.json() + const data = await parseJSONResponse(res) if (!res.ok) { - onMessage('error', data.detail || t('settings.loadFailed')) + const detail = data.detail || t('settings.loadFailed') + setLastError(detail) + onMessage('error', detail) + setConsecutiveFailures((prev) => { + const next = prev + 1 + if (isVercel && next >= MAX_AUTO_FETCH_FAILURES) { + setAutoFetchPaused(true) + } + return next + }) return } + setConsecutiveFailures(0) + setAutoFetchPaused(false) + setLastError('') setSettingsMeta({ default_password_warning: Boolean(data.admin?.default_password_warning), env_backed: Boolean(data.env_backed), @@ -78,18 +106,32 @@ export default function Settings({ onRefresh, onMessage, authFetch, onForceLogou model_aliases_text: JSON.stringify(data.model_aliases || {}, null, 2), }) } catch (e) { - onMessage('error', t('settings.loadFailed')) + const detail = e?.message || t('settings.loadFailed') + setLastError(detail) + onMessage('error', detail) + setConsecutiveFailures((prev) => { + const next = prev + 1 + if (isVercel && next >= MAX_AUTO_FETCH_FAILURES) { + setAutoFetchPaused(true) + } + return next + }) // eslint-disable-next-line no-console console.error(e) } finally { setLoading(false) } - }, [apiFetch, onMessage, t]) + }, [apiFetch, autoFetchPaused, isVercel, onMessage, parseJSONResponse, t]) useEffect(() => { loadSettings() }, [loadSettings]) + const retryLoadSettings = () => { + setAutoFetchPaused(false) + loadSettings({ manual: true }) + } + const saveSettings = async () => { let claudeMapping = {} let modelAliases = {} @@ -228,6 +270,23 @@ export default function Settings({ onRefresh, onMessage, authFetch, onForceLogou return (
+ {autoFetchPaused && ( +
+
+ + + {t('settings.autoFetchPaused', { count: consecutiveFailures, error: lastError || t('settings.loadFailed') })} + +
+ +
+ )} {settingsMeta.default_password_warning && (
diff --git a/webui/src/components/VercelSync.jsx b/webui/src/components/VercelSync.jsx index 2f8a548..a50714e 100644 --- a/webui/src/components/VercelSync.jsx +++ b/webui/src/components/VercelSync.jsx @@ -3,7 +3,15 @@ import { Cloud, ArrowRight, ExternalLink, Info, CheckCircle2, XCircle, RefreshCw import clsx from 'clsx' import { useI18n } from '../i18n' -export default function VercelSync({ onMessage, authFetch }) { +const MAX_POLL_FAILURES = 3 + +function pollDelayMs(attempt) { + if (attempt <= 0) return 15000 + if (attempt === 1) return 30000 + return 60000 +} + +export default function VercelSync({ onMessage, authFetch, isVercel = false }) { const { t } = useI18n() const [vercelToken, setVercelToken] = useState('') const [projectId, setProjectId] = useState('') @@ -12,20 +20,42 @@ export default function VercelSync({ onMessage, authFetch }) { const [result, setResult] = useState(null) const [preconfig, setPreconfig] = useState(null) const [syncStatus, setSyncStatus] = useState(null) + const [pollPaused, setPollPaused] = useState(false) + const [pollFailures, setPollFailures] = useState(0) + const [nextRetryAt, setNextRetryAt] = useState(null) const apiFetch = authFetch || fetch - const fetchSyncStatus = useCallback(async () => { + const fetchSyncStatus = useCallback(async ({ manual = false } = {}) => { try { const res = await apiFetch('/admin/vercel/status') - if (res.ok) { - const data = await res.json() - setSyncStatus(data) + if (!res.ok) { + throw new Error(`status ${res.status}`) } + const data = await res.json() + setSyncStatus(data) + setPollFailures(0) + setPollPaused(false) + setNextRetryAt(null) } catch (e) { + setPollFailures((prev) => { + const next = prev + 1 + if (isVercel) { + if (next >= MAX_POLL_FAILURES) { + setPollPaused(true) + setNextRetryAt(null) + } else { + setNextRetryAt(Date.now() + pollDelayMs(next)) + } + } + return next + }) + if (manual) { + onMessage('error', t('vercel.networkError')) + } console.error('Failed to fetch sync status:', e) } - }, [apiFetch]) + }, [apiFetch, isVercel, onMessage, t]) useEffect(() => { const loadPreconfig = async () => { @@ -43,11 +73,32 @@ export default function VercelSync({ onMessage, authFetch }) { } loadPreconfig() fetchSyncStatus() - // Poll every 15s to detect config changes - const interval = setInterval(fetchSyncStatus, 15000) - return () => clearInterval(interval) }, [fetchSyncStatus]) + useEffect(() => { + if (!isVercel) { + const interval = setInterval(() => { + fetchSyncStatus() + }, 15000) + return () => clearInterval(interval) + } + if (pollPaused) { + return undefined + } + const delay = nextRetryAt && nextRetryAt > Date.now() ? nextRetryAt - Date.now() : pollDelayMs(pollFailures) + const timer = setTimeout(() => { + fetchSyncStatus() + }, Math.max(1000, delay)) + return () => clearTimeout(timer) + }, [fetchSyncStatus, isVercel, nextRetryAt, pollFailures, pollPaused]) + + const handleManualRefresh = () => { + setPollPaused(false) + setPollFailures(0) + setNextRetryAt(null) + fetchSyncStatus({ manual: true }) + } + const handleSync = async () => { const tokenToUse = preconfig?.has_token && !vercelToken ? '__USE_PRECONFIG__' : vercelToken @@ -122,6 +173,20 @@ export default function VercelSync({ onMessage, authFetch }) {

{t('vercel.description')}

+ {pollPaused && ( +
+

+ {t('vercel.pollPaused', { count: pollFailures })} +

+ +
+ )} {syncStatus?.last_sync_time && (

diff --git a/webui/src/locales/en.json b/webui/src/locales/en.json index d5be108..c06a86d 100644 --- a/webui/src/locales/en.json +++ b/webui/src/locales/en.json @@ -198,6 +198,7 @@ }, "settings": { "loadFailed": "Failed to load settings.", + "nonJsonResponse": "Unexpected non-JSON response from server (status: {status}).", "save": "Save settings", "saving": "Saving...", "saveSuccess": "Settings saved and hot reloaded.", @@ -239,7 +240,9 @@ "exportJson": "Export JSON", "invalidJsonField": "{field} is not a valid JSON object.", "defaultPasswordWarning": "You are using the default admin password \"admin\". Please change it.", - "vercelSyncHint": "Configuration changed. For Vercel deployments, sync manually in Vercel Sync and redeploy." + "vercelSyncHint": "Configuration changed. For Vercel deployments, sync manually in Vercel Sync and redeploy.", + "autoFetchPaused": "Auto loading paused after {count} failures: {error}", + "retryLoad": "Retry now" }, "login": { "welcome": "Welcome back", @@ -278,6 +281,8 @@ "statusNotSynced": "Not synced", "statusNeverSynced": "Never synced", "lastSyncTime": "Last sync: {time}", + "pollPaused": "Status polling paused after {count} failures.", + "manualRefresh": "Refresh manually", "howItWorks": "How it works", "steps": { "one": "The current configuration (keys and accounts) is exported as JSON.", diff --git a/webui/src/locales/zh.json b/webui/src/locales/zh.json index fb31188..7d587d7 100644 --- a/webui/src/locales/zh.json +++ b/webui/src/locales/zh.json @@ -198,6 +198,7 @@ }, "settings": { "loadFailed": "加载设置失败", + "nonJsonResponse": "服务端返回了非 JSON 响应(状态码:{status})", "save": "保存设置", "saving": "保存中...", "saveSuccess": "设置已保存并热更新生效", @@ -239,7 +240,9 @@ "exportJson": "导出 JSON", "invalidJsonField": "{field} 不是有效 JSON 对象", "defaultPasswordWarning": "当前使用默认密码 admin,请尽快在此修改。", - "vercelSyncHint": "当前配置已更新。Vercel 部署请到 Vercel 同步页面手动同步并重部署。" + "vercelSyncHint": "当前配置已更新。Vercel 部署请到 Vercel 同步页面手动同步并重部署。", + "autoFetchPaused": "自动加载已暂停:连续失败 {count} 次({error})", + "retryLoad": "立即重试" }, "login": { "welcome": "欢迎回来", @@ -278,6 +281,8 @@ "statusNotSynced": "未同步", "statusNeverSynced": "从未同步", "lastSyncTime": "上次同步: {time}", + "pollPaused": "状态轮询已暂停:连续失败 {count} 次。", + "manualRefresh": "手动刷新", "howItWorks": "工作原理", "steps": { "one": "当前配置 (密钥和账号) 被导出为 JSON 字符串。", diff --git a/webui/src/utils/runtimeEnv.js b/webui/src/utils/runtimeEnv.js new file mode 100644 index 0000000..d7b7889 --- /dev/null +++ b/webui/src/utils/runtimeEnv.js @@ -0,0 +1,13 @@ +export function detectRuntimeEnv() { + const deployTarget = String(import.meta.env.VITE_DEPLOY_TARGET || '').trim().toLowerCase() + if (deployTarget === 'vercel') { + return { isVercel: true, source: 'vite_env' } + } + + const host = typeof window !== 'undefined' ? String(window.location.hostname || '').toLowerCase() : '' + if (host.includes('vercel.app')) { + return { isVercel: true, source: 'hostname' } + } + + return { isVercel: false, source: 'default' } +}