Merge pull request #39 from CJackHwang/dev

同步
This commit is contained in:
CJACK.
2026-02-22 01:45:01 +08:00
committed by GitHub
60 changed files with 3723 additions and 228 deletions

View File

@@ -13,12 +13,8 @@
#### 🔀 变更说明 | Description of Change
<!-- Thank you for your Pull Request. Please provide a description above. -->
#### 📝 补充信息 | Additional Information
<!-- Add any other context about the Pull Request here. -->
---
> 💡 **提示**:如果修改了 `webui/` 目录下的文件PR 合并后 CI 会自动构建并提交产物,无需手动构建。

View File

@@ -12,6 +12,9 @@ permissions:
jobs:
build-and-upload:
runs-on: ubuntu-latest
env:
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -87,10 +90,11 @@ jobs:
password: ${{ secrets.GITHUB_TOKEN }}
- name: Log in to Docker Hub
if: "${{ env.DOCKERHUB_USERNAME != '' }}"
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
username: ${{ env.DOCKERHUB_USERNAME }}
password: ${{ env.DOCKERHUB_TOKEN }}
- name: Extract Docker metadata
id: meta_release
@@ -98,7 +102,7 @@ jobs:
with:
images: |
ghcr.io/${{ github.repository }}
cjackhwang/ds2api
${{ env.DOCKERHUB_USERNAME || 'cjackhwang' }}/ds2api
tags: |
type=raw,value=${{ github.event.release.tag_name }}
type=raw,value=latest

View File

@@ -86,7 +86,7 @@ Manually build WebUI to `static/admin/`:
go test ./...
# End-to-end live tests (real accounts)
./scripts/testsuite/run-live.sh
./tests/scripts/run-live.sh
```
## Project Structure

View File

@@ -86,7 +86,7 @@ docker-compose -f docker-compose.dev.yml up
go test ./...
# 端到端全链路测试(真实账号)
./scripts/testsuite/run-live.sh
./tests/scripts/run-live.sh
```
## 项目结构

View File

@@ -518,7 +518,7 @@ curl http://127.0.0.1:5001/v1/chat/completions \
Run the full live testsuite before release (real account tests):
```bash
./scripts/testsuite/run-live.sh
./tests/scripts/run-live.sh
```
With custom flags:

View File

@@ -518,7 +518,7 @@ curl http://127.0.0.1:5001/v1/chat/completions \
建议在发布前执行完整的端到端测试集(使用真实账号):
```bash
./scripts/testsuite/run-live.sh
./tests/scripts/run-live.sh
```
可自定义参数:

View File

@@ -283,6 +283,9 @@ cp opencode.json.example opencode.json
| `DS2API_ACCOUNT_QUEUE_SIZE` | 同上(兼容旧名) | — |
| `DS2API_VERCEL_INTERNAL_SECRET` | Vercel 混合流式内部鉴权密钥 | 回退用 `DS2API_ADMIN_KEY` |
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | 流式 lease 过期秒数 | `900` |
| `DS2API_DEV_PACKET_CAPTURE` | 本地开发抓包开关(记录最近会话请求/响应体) | 本地非 Vercel 默认开启 |
| `DS2API_DEV_PACKET_CAPTURE_LIMIT` | 本地抓包保留条数(超出自动淘汰) | `5` |
| `DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES` | 单条响应体最大记录字节数 | `2097152` |
| `VERCEL_TOKEN` | Vercel 同步 token | — |
| `VERCEL_PROJECT_ID` | Vercel 项目 ID | — |
| `VERCEL_TEAM_ID` | Vercel 团队 ID | — |
@@ -321,6 +324,29 @@ cp opencode.json.example opencode.json
3. 已确认的 toolcall JSON 片段不会泄漏到 `delta.content`
4. 前文/后文自然语言保持顺序透传,支持混合文本与增量参数输出
## 本地开发抓包工具
用于定位「responses 思考流/工具调用」等问题。开启后会自动记录最近 N 条 DeepSeek 对话上游请求体与响应体(默认 5 条,超出自动淘汰)。
启用示例:
```bash
DS2API_DEV_PACKET_CAPTURE=true \
DS2API_DEV_PACKET_CAPTURE_LIMIT=5 \
go run ./cmd/ds2api
```
查询/清空(需 Admin JWT
- `GET /admin/dev/captures`:查看抓包列表(最新在前)
- `DELETE /admin/dev/captures`:清空抓包
返回字段包含:
- `request_body`:发送给 DeepSeek 的完整请求体
- `response_body`:上游返回的原始流式内容拼接文本
- `response_truncated`:是否触发单条大小截断
## 项目结构
```text
@@ -350,8 +376,10 @@ ds2api/
│ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage
│ └── locales/ # 中英文语言包zh.json / en.json
├── scripts/
── build-webui.sh # WebUI 手动构建脚本
│ └── testsuite/ # 测试集运行脚本
── build-webui.sh # WebUI 手动构建脚本
── tests/
│ ├── compat/ # 兼容性测试夹具与期望输出
│ └── scripts/ # 统一测试脚本入口unit/e2e
├── static/admin/ # WebUI 构建产物(不提交到 Git
├── .github/
│ ├── workflows/ # GitHub ActionsRelease 自动构建)
@@ -379,11 +407,11 @@ ds2api/
## 测试
```bash
# 单元测试
go test ./...
# 单元测试Go + Node
./tests/scripts/run-unit-all.sh
# 一键端到端全链路测试(真实账号,生成完整请求/响应日志)
./scripts/testsuite/run-live.sh
./tests/scripts/run-live.sh
# 或自定义参数
go run ./cmd/ds2api-tests \

View File

@@ -283,6 +283,9 @@ cp opencode.json.example opencode.json
| `DS2API_ACCOUNT_QUEUE_SIZE` | Alias (legacy compat) | — |
| `DS2API_VERCEL_INTERNAL_SECRET` | Vercel hybrid streaming internal auth | Falls back to `DS2API_ADMIN_KEY` |
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | Stream lease TTL seconds | `900` |
| `DS2API_DEV_PACKET_CAPTURE` | Local dev packet capture switch (record recent request/response bodies) | Enabled by default on non-Vercel local runtime |
| `DS2API_DEV_PACKET_CAPTURE_LIMIT` | Number of captured sessions to retain (auto-evict overflow) | `5` |
| `DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES` | Max recorded bytes per captured response body | `2097152` |
| `VERCEL_TOKEN` | Vercel sync token | — |
| `VERCEL_PROJECT_ID` | Vercel project ID | — |
| `VERCEL_TEAM_ID` | Vercel team ID | — |
@@ -321,6 +324,29 @@ When `tools` is present in the request, DS2API performs anti-leak handling:
3. Confirmed toolcall JSON fragments are never leaked into `delta.content`
4. Natural language before/after toolcalls keeps original order, with incremental argument output supported
## Local Dev Packet Capture
This is for debugging issues such as Responses reasoning streaming and tool-call handoff. When enabled, DS2API stores the latest N DeepSeek conversation payload pairs (request body + upstream response body), defaulting to 5 entries with auto-eviction.
Enable example:
```bash
DS2API_DEV_PACKET_CAPTURE=true \
DS2API_DEV_PACKET_CAPTURE_LIMIT=5 \
go run ./cmd/ds2api
```
Inspect/clear (Admin JWT required):
- `GET /admin/dev/captures`: list captured items (newest first)
- `DELETE /admin/dev/captures`: clear captured items
Response fields include:
- `request_body`: full payload sent to DeepSeek
- `response_body`: concatenated raw upstream stream body text
- `response_truncated`: whether body-size truncation happened
## Project Structure
```text
@@ -350,8 +376,10 @@ ds2api/
│ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage
│ └── locales/ # Language packs (zh.json / en.json)
├── scripts/
── build-webui.sh # Manual WebUI build script
│ └── testsuite/ # Testsuite runner scripts
── build-webui.sh # Manual WebUI build script
── tests/
│ ├── compat/ # Compatibility fixtures and expected outputs
│ └── scripts/ # Unified test script entrypoints (unit/e2e)
├── static/admin/ # WebUI build output (not committed to Git)
├── .github/
│ ├── workflows/ # GitHub Actions (Release artifact automation)
@@ -379,11 +407,11 @@ ds2api/
## Testing
```bash
# Unit tests
go test ./...
# Unit tests (Go + Node)
./tests/scripts/run-unit-all.sh
# One-command live end-to-end tests (real accounts, full request/response logs)
./scripts/testsuite/run-live.sh
./tests/scripts/run-live.sh
# Or with custom flags
go run ./cmd/ds2api-tests \

View File

@@ -8,8 +8,10 @@ DS2API 提供两个层级的测试:
| 层级 | 命令 | 说明 |
| --- | --- | --- |
| 单元测试 | `go test ./...` | 不需要真实账号 |
| 端到端测试 | `./scripts/testsuite/run-live.sh` | 使用真实账号执行全链路测试 |
| 单元测试Go | `./tests/scripts/run-unit-go.sh` | 不需要真实账号 |
| 单元测试Node | `./tests/scripts/run-unit-node.sh` | 不需要真实账号 |
| 单元测试(全部) | `./tests/scripts/run-unit-all.sh` | 不需要真实账号 |
| 端到端测试 | `./tests/scripts/run-live.sh` | 使用真实账号执行全链路测试 |
端到端测试集会录制完整的请求/响应日志,用于故障排查。
@@ -20,17 +22,19 @@ DS2API 提供两个层级的测试:
### 单元测试 | Unit Tests
```bash
go test ./...
./tests/scripts/run-unit-all.sh
```
```bash
node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js
# 或按语言拆分执行
./tests/scripts/run-unit-go.sh
./tests/scripts/run-unit-node.sh
```
### 端到端测试 | End-to-End Tests
```bash
./scripts/testsuite/run-live.sh
./tests/scripts/run-live.sh
```
**默认行为**
@@ -179,7 +183,7 @@ go run ./cmd/ds2api-tests \
```bash
# 确保 config.json 存在且包含有效测试账号
./scripts/testsuite/run-live.sh
./tests/scripts/run-live.sh
exit_code=$?
if [ $exit_code -ne 0 ]; then
echo "Tests failed! Check artifacts for details."

View File

@@ -38,6 +38,10 @@ func RegisterRoutes(r chi.Router, h *Handler) {
r.Get("/anthropic/v1/models", h.ListModels)
r.Post("/anthropic/v1/messages", h.Messages)
r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens)
r.Post("/v1/messages", h.Messages)
r.Post("/messages", h.Messages)
r.Post("/v1/messages/count_tokens", h.CountTokens)
r.Post("/messages/count_tokens", h.CountTokens)
}
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
@@ -167,7 +171,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
rc := http.NewResponseController(w)
canFlush := rc.Flush() == nil
_, canFlush := w.(http.Flusher)
if !canFlush {
config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered")
}
@@ -250,7 +254,7 @@ func normalizeClaudeMessages(messages []any) []any {
}
}
if typeStr == "tool_result" {
parts = append(parts, fmt.Sprintf("%v", b["content"]))
parts = append(parts, formatClaudeToolResultForPrompt(b))
}
}
copied["content"] = strings.Join(parts, "\n")
@@ -272,10 +276,36 @@ func buildClaudeToolPrompt(tools []any) string {
schema, _ := json.Marshal(m["input_schema"])
parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
}
parts = append(parts, "When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}")
parts = append(parts,
"When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}",
"History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.",
"After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.",
)
return strings.Join(parts, "\n\n")
}
func formatClaudeToolResultForPrompt(block map[string]any) string {
if block == nil {
return ""
}
toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
if toolCallID == "" {
toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"]))
}
if toolCallID == "" {
toolCallID = "unknown"
}
name := strings.TrimSpace(fmt.Sprintf("%v", block["name"]))
if name == "" {
name = "unknown"
}
content := strings.TrimSpace(fmt.Sprintf("%v", block["content"]))
if content == "" {
content = "null"
}
return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content)
}
func hasSystemMessage(messages []any) bool {
for _, m := range messages {
msg, ok := m.(map[string]any)

View File

@@ -1,6 +1,7 @@
package claude
import (
"strings"
"testing"
)
@@ -48,8 +49,9 @@ func TestNormalizeClaudeMessagesToolResult(t *testing.T) {
}
got := normalizeClaudeMessages(msgs)
m := got[0].(map[string]any)
if m["content"] != "tool output" {
t.Fatalf("expected 'tool output', got %q", m["content"])
content, _ := m["content"].(string)
if !strings.Contains(content, "[TOOL_RESULT_HISTORY]") || !strings.Contains(content, "content: tool output") {
t.Fatalf("expected serialized tool result marker, got %q", content)
}
}

View File

@@ -0,0 +1,44 @@
package claude
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"ds2api/internal/auth"
)
type routeAliasAuthStub struct{}
func (routeAliasAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) {
return nil, auth.ErrUnauthorized
}
func (routeAliasAuthStub) Release(_ *auth.RequestAuth) {}
func TestClaudeRouteAliasesDoNot404(t *testing.T) {
h := &Handler{
Auth: routeAliasAuthStub{},
}
r := chi.NewRouter()
RegisterRoutes(r, h)
paths := []string{
"/anthropic/v1/messages",
"/v1/messages",
"/messages",
"/anthropic/v1/messages/count_tokens",
"/v1/messages/count_tokens",
"/messages/count_tokens",
}
for _, path := range paths {
req := httptest.NewRequest(http.MethodPost, path, nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code == http.StatusNotFound {
t.Fatalf("expected route %s to be registered, got 404", path)
}
}
}

View File

@@ -27,9 +27,7 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma
payload := cloneMap(req)
payload["messages"] = normalizedMessages
toolsRequested, _ := req["tools"].([]any)
if len(toolsRequested) > 0 && !hasSystemMessage(normalizedMessages) {
payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalizedMessages...)
}
payload["messages"] = injectClaudeToolPrompt(payload, normalizedMessages, toolsRequested)
dsPayload := convertClaudeToDeepSeek(payload, store)
dsModel, _ := dsPayload["model"].(string)
@@ -57,3 +55,59 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma
NormalizedMessages: normalizedMessages,
}, nil
}
func injectClaudeToolPrompt(payload map[string]any, normalizedMessages []any, tools []any) []any {
if len(tools) == 0 {
return normalizedMessages
}
toolPrompt := strings.TrimSpace(buildClaudeToolPrompt(tools))
if toolPrompt == "" {
return normalizedMessages
}
// Prefer top-level Anthropic-style system prompt when available.
if systemText, ok := payload["system"].(string); ok && strings.TrimSpace(systemText) != "" {
payload["system"] = mergeSystemPrompt(systemText, toolPrompt)
return normalizedMessages
}
messages := cloneAnySlice(normalizedMessages)
for i := range messages {
msg, ok := messages[i].(map[string]any)
if !ok {
continue
}
role, _ := msg["role"].(string)
if !strings.EqualFold(strings.TrimSpace(role), "system") {
continue
}
copied := cloneMap(msg)
copied["content"] = mergeSystemPrompt(strings.TrimSpace(fmt.Sprintf("%v", copied["content"])), toolPrompt)
messages[i] = copied
return messages
}
return append([]any{map[string]any{"role": "system", "content": toolPrompt}}, messages...)
}
func mergeSystemPrompt(base, extra string) string {
base = strings.TrimSpace(base)
extra = strings.TrimSpace(extra)
switch {
case base == "":
return extra
case extra == "":
return base
default:
return base + "\n\n" + extra
}
}
func cloneAnySlice(in []any) []any {
if len(in) == 0 {
return nil
}
out := make([]any, len(in))
copy(out, in)
return out
}

View File

@@ -36,3 +36,57 @@ func TestNormalizeClaudeRequest(t *testing.T) {
t.Fatalf("expected non-empty final prompt")
}
}
func TestNormalizeClaudeRequestInjectsToolsIntoExistingSystemMessage(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{}`)
store := config.LoadStore()
req := map[string]any{
"model": "claude-sonnet-4-5",
"messages": []any{
map[string]any{"role": "system", "content": "baseline rule"},
map[string]any{"role": "user", "content": "hello"},
},
"tools": []any{
map[string]any{"name": "search", "description": "Search"},
},
}
norm, err := normalizeClaudeRequest(store, req)
if err != nil {
t.Fatalf("normalize failed: %v", err)
}
if !containsStr(norm.Standard.FinalPrompt, "You have access to these tools") {
t.Fatalf("expected tool prompt injected into final prompt, got=%q", norm.Standard.FinalPrompt)
}
if !containsStr(norm.Standard.FinalPrompt, "baseline rule") {
t.Fatalf("expected existing system message preserved, got=%q", norm.Standard.FinalPrompt)
}
}
func TestNormalizeClaudeRequestInjectsToolsIntoTopLevelSystem(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{}`)
store := config.LoadStore()
req := map[string]any{
"model": "claude-sonnet-4-5",
"system": "top-level system",
"messages": []any{
map[string]any{"role": "user", "content": "hello"},
},
"tools": []any{
map[string]any{"name": "search", "description": "Search"},
},
}
norm, err := normalizeClaudeRequest(store, req)
if err != nil {
t.Fatalf("normalize failed: %v", err)
}
if !containsStr(norm.Standard.FinalPrompt, "top-level system") {
t.Fatalf("expected top-level system preserved, got=%q", norm.Standard.FinalPrompt)
}
if !containsStr(norm.Standard.FinalPrompt, "You have access to these tools") {
t.Fatalf("expected tool prompt injected, got=%q", norm.Standard.FinalPrompt)
}
}

View File

@@ -0,0 +1,100 @@
package claude
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/go-chi/chi/v5"
chimw "github.com/go-chi/chi/v5/middleware"
"ds2api/internal/auth"
)
type streamStatusClaudeAuthStub struct{}
func (streamStatusClaudeAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) {
return &auth.RequestAuth{
UseConfigToken: false,
DeepSeekToken: "direct-token",
CallerID: "caller:test",
TriedAccounts: map[string]bool{},
}, nil
}
func (streamStatusClaudeAuthStub) Release(_ *auth.RequestAuth) {}
type streamStatusClaudeDSStub struct{}
func (streamStatusClaudeDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
return "session-id", nil
}
func (streamStatusClaudeDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
return "pow", nil
}
func (streamStatusClaudeDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
body := "data: {\"p\":\"response/content\",\"v\":\"hello\"}\n" + "data: [DONE]\n"
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: ioNopCloser{strings.NewReader(body)},
}, nil
}
type ioNopCloser struct {
*strings.Reader
}
func (ioNopCloser) Close() error { return nil }
type streamStatusClaudeStoreStub struct{}
func (streamStatusClaudeStoreStub) ClaudeMapping() map[string]string {
return map[string]string{
"fast": "deepseek-chat",
"slow": "deepseek-reasoner",
}
}
func captureClaudeStatusMiddleware(statuses *[]int) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ww := chimw.NewWrapResponseWriter(w, r.ProtoMajor)
next.ServeHTTP(ww, r)
*statuses = append(*statuses, ww.Status())
})
}
}
func TestClaudeMessagesStreamStatusCapturedAs200(t *testing.T) {
statuses := make([]int, 0, 1)
h := &Handler{
Store: streamStatusClaudeStoreStub{},
Auth: streamStatusClaudeAuthStub{},
DS: streamStatusClaudeDSStub{},
}
r := chi.NewRouter()
r.Use(captureClaudeStatusMiddleware(&statuses))
RegisterRoutes(r, h)
reqBody := `{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":true}`
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(reqBody))
req.Header.Set("Authorization", "Bearer direct-token")
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
}
if len(statuses) != 1 {
t.Fatalf("expected one captured status, got %d", len(statuses))
}
if statuses[0] != http.StatusOK {
t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0])
}
}

View File

@@ -0,0 +1,313 @@
package gemini
import (
"encoding/json"
"fmt"
"strings"
"ds2api/internal/adapter/openai"
"ds2api/internal/config"
"ds2api/internal/util"
)
func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[string]any, stream bool) (util.StandardRequest, error) {
requestedModel := strings.TrimSpace(routeModel)
if requestedModel == "" {
return util.StandardRequest{}, fmt.Errorf("model is required in request path")
}
resolvedModel, ok := config.ResolveModel(store, requestedModel)
if !ok {
return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", requestedModel)
}
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
messagesRaw := geminiMessagesFromRequest(req)
if len(messagesRaw) == 0 {
return util.StandardRequest{}, fmt.Errorf("Request must include non-empty contents.")
}
toolsRaw := convertGeminiTools(req["tools"])
finalPrompt, toolNames := openai.BuildPromptForAdapter(messagesRaw, toolsRaw, "")
passThrough := collectGeminiPassThrough(req)
return util.StandardRequest{
Surface: "google_gemini",
RequestedModel: requestedModel,
ResolvedModel: resolvedModel,
ResponseModel: requestedModel,
Messages: messagesRaw,
FinalPrompt: finalPrompt,
ToolNames: toolNames,
Stream: stream,
Thinking: thinkingEnabled,
Search: searchEnabled,
PassThrough: passThrough,
}, nil
}
func geminiMessagesFromRequest(req map[string]any) []any {
out := make([]any, 0, 8)
if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" {
out = append(out, map[string]any{
"role": "system",
"content": sys,
})
}
contents, _ := req["contents"].([]any)
for _, item := range contents {
content, ok := item.(map[string]any)
if !ok {
continue
}
role := mapGeminiRole(content["role"])
if role == "" {
role = "user"
}
parts, _ := content["parts"].([]any)
if len(parts) == 0 {
if text := strings.TrimSpace(asString(content["text"])); text != "" {
out = append(out, map[string]any{
"role": role,
"content": text,
})
}
continue
}
textParts := make([]string, 0, len(parts))
flushText := func() {
if len(textParts) == 0 {
return
}
out = append(out, map[string]any{
"role": role,
"content": strings.Join(textParts, "\n"),
})
textParts = textParts[:0]
}
for _, rawPart := range parts {
part, ok := rawPart.(map[string]any)
if !ok {
continue
}
if text := strings.TrimSpace(asString(part["text"])); text != "" {
textParts = append(textParts, text)
continue
}
if fnCall, ok := part["functionCall"].(map[string]any); ok {
flushText()
if name := strings.TrimSpace(asString(fnCall["name"])); name != "" {
callID := strings.TrimSpace(asString(fnCall["id"]))
if callID == "" {
callID = "call_gemini"
}
out = append(out, map[string]any{
"role": "assistant",
"tool_calls": []any{
map[string]any{
"id": callID,
"type": "function",
"function": map[string]any{
"name": name,
"arguments": stringifyJSON(fnCall["args"]),
},
},
},
})
}
continue
}
if fnResp, ok := part["functionResponse"].(map[string]any); ok {
flushText()
name := strings.TrimSpace(asString(fnResp["name"]))
callID := strings.TrimSpace(asString(fnResp["id"]))
if callID == "" {
callID = strings.TrimSpace(asString(fnResp["callId"]))
}
if callID == "" {
callID = strings.TrimSpace(asString(fnResp["tool_call_id"]))
}
if callID == "" {
callID = "call_gemini"
}
content := fnResp["response"]
if content == nil {
content = fnResp["output"]
}
if content == nil {
content = ""
}
msg := map[string]any{
"role": "tool",
"tool_call_id": callID,
"content": content,
}
if name != "" {
msg["name"] = name
}
out = append(out, msg)
}
}
flushText()
}
return out
}
func normalizeGeminiSystemInstruction(raw any) string {
switch v := raw.(type) {
case string:
return strings.TrimSpace(v)
case map[string]any:
if parts, ok := v["parts"].([]any); ok {
texts := make([]string, 0, len(parts))
for _, item := range parts {
part, ok := item.(map[string]any)
if !ok {
continue
}
if text := strings.TrimSpace(asString(part["text"])); text != "" {
texts = append(texts, text)
}
}
return strings.Join(texts, "\n")
}
if text := strings.TrimSpace(asString(v["text"])); text != "" {
return text
}
}
return ""
}
func mapGeminiRole(v any) string {
switch strings.ToLower(strings.TrimSpace(asString(v))) {
case "user":
return "user"
case "model", "assistant":
return "assistant"
case "system":
return "system"
default:
return ""
}
}
func convertGeminiTools(raw any) []any {
tools, _ := raw.([]any)
if len(tools) == 0 {
return nil
}
out := make([]any, 0, len(tools))
for _, item := range tools {
tool, ok := item.(map[string]any)
if !ok {
continue
}
if fnDecls, ok := tool["functionDeclarations"].([]any); ok && len(fnDecls) > 0 {
for _, declRaw := range fnDecls {
decl, ok := declRaw.(map[string]any)
if !ok {
continue
}
name := strings.TrimSpace(asString(decl["name"]))
if name == "" {
continue
}
function := map[string]any{
"name": name,
}
if desc := strings.TrimSpace(asString(decl["description"])); desc != "" {
function["description"] = desc
}
if params, ok := decl["parameters"].(map[string]any); ok {
function["parameters"] = params
}
out = append(out, map[string]any{
"type": "function",
"function": function,
})
}
continue
}
// OpenAI-style passthrough fallback.
if _, ok := tool["function"].(map[string]any); ok {
out = append(out, tool)
continue
}
// Loose fallback for flattened function schema objects.
name := strings.TrimSpace(asString(tool["name"]))
if name == "" {
continue
}
fn := map[string]any{"name": name}
if desc := strings.TrimSpace(asString(tool["description"])); desc != "" {
fn["description"] = desc
}
if params, ok := tool["parameters"].(map[string]any); ok {
fn["parameters"] = params
}
out = append(out, map[string]any{
"type": "function",
"function": fn,
})
}
if len(out) == 0 {
return nil
}
return out
}
func collectGeminiPassThrough(req map[string]any) map[string]any {
cfg, _ := req["generationConfig"].(map[string]any)
if len(cfg) == 0 {
return nil
}
out := map[string]any{}
if v, ok := cfg["temperature"]; ok {
out["temperature"] = v
}
if v, ok := cfg["topP"]; ok {
out["top_p"] = v
}
if v, ok := cfg["maxOutputTokens"]; ok {
out["max_tokens"] = v
}
if v, ok := cfg["stopSequences"]; ok {
out["stop"] = v
}
if len(out) == 0 {
return nil
}
return out
}
func asString(v any) string {
s, _ := v.(string)
return s
}
func stringifyJSON(v any) string {
switch x := v.(type) {
case nil:
return "{}"
case string:
s := strings.TrimSpace(x)
if s == "" {
return "{}"
}
return s
default:
b, err := json.Marshal(x)
if err != nil || len(b) == 0 {
return "{}"
}
return string(b)
}
}

View File

@@ -0,0 +1,29 @@
package gemini
import (
"context"
"net/http"
"ds2api/internal/auth"
"ds2api/internal/config"
"ds2api/internal/deepseek"
)
type AuthResolver interface {
Determine(req *http.Request) (*auth.RequestAuth, error)
Release(a *auth.RequestAuth)
}
type DeepSeekCaller interface {
CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
}
type ConfigReader interface {
ModelAliases() map[string]string
}
var _ AuthResolver = (*auth.Resolver)(nil)
var _ DeepSeekCaller = (*deepseek.Client)(nil)
var _ ConfigReader = (*config.Store)(nil)

View File

@@ -0,0 +1,348 @@
package gemini
import (
"encoding/json"
"io"
"net/http"
"strings"
"time"
"github.com/go-chi/chi/v5"
"ds2api/internal/auth"
"ds2api/internal/deepseek"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
"ds2api/internal/util"
)
var writeJSON = util.WriteJSON
type Handler struct {
Store ConfigReader
Auth AuthResolver
DS DeepSeekCaller
}
func RegisterRoutes(r chi.Router, h *Handler) {
r.Post("/v1beta/models/{model}:generateContent", h.GenerateContent)
r.Post("/v1beta/models/{model}:streamGenerateContent", h.StreamGenerateContent)
r.Post("/v1/models/{model}:generateContent", h.GenerateContent)
r.Post("/v1/models/{model}:streamGenerateContent", h.StreamGenerateContent)
}
func (h *Handler) GenerateContent(w http.ResponseWriter, r *http.Request) {
h.handleGenerateContent(w, r, false)
}
func (h *Handler) StreamGenerateContent(w http.ResponseWriter, r *http.Request) {
h.handleGenerateContent(w, r, true)
}
func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) {
a, err := h.Auth.Determine(r)
if err != nil {
status := http.StatusUnauthorized
detail := err.Error()
if err == auth.ErrNoAccount {
status = http.StatusTooManyRequests
}
writeGeminiError(w, status, detail)
return
}
defer h.Auth.Release(a)
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeGeminiError(w, http.StatusBadRequest, "invalid json")
return
}
routeModel := strings.TrimSpace(chi.URLParam(r, "model"))
stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream)
if err != nil {
writeGeminiError(w, http.StatusBadRequest, err.Error())
return
}
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
if a.UseConfigToken {
writeGeminiError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
} else {
writeGeminiError(w, http.StatusUnauthorized, "Invalid token.")
}
return
}
pow, err := h.DS.GetPow(r.Context(), a, 3)
if err != nil {
writeGeminiError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
return
}
payload := stdReq.CompletionPayload(sessionID)
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
if err != nil {
writeGeminiError(w, http.StatusInternalServerError, "Failed to get completion.")
return
}
if stream {
h.handleStreamGenerateContent(w, r, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
return
}
h.handleNonStreamGenerateContent(w, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
}
func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *http.Response, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
return
}
result := sse.CollectStream(resp, thinkingEnabled, true)
writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames))
}
func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache, no-transform")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
rc := http.NewResponseController(w)
_, canFlush := w.(http.Flusher)
runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
initialType := "text"
if thinkingEnabled {
initialType = "thinking"
}
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
Context: r.Context(),
Body: resp.Body,
ThinkingEnabled: thinkingEnabled,
InitialType: initialType,
KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
}, streamengine.ConsumeHooks{
OnParsed: runtime.onParsed,
OnFinalize: func(_ streamengine.StopReason, _ error) {
runtime.finalize()
},
})
}
func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames)
usage := buildGeminiUsage(finalPrompt, finalThinking, finalText)
return map[string]any{
"candidates": []map[string]any{
{
"index": 0,
"content": map[string]any{
"role": "model",
"parts": parts,
},
"finishReason": "STOP",
},
},
"modelVersion": model,
"usageMetadata": usage,
}
}
func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any {
promptTokens := util.EstimateTokens(finalPrompt)
reasoningTokens := util.EstimateTokens(finalThinking)
completionTokens := util.EstimateTokens(finalText)
return map[string]any{
"promptTokenCount": promptTokens,
"candidatesTokenCount": reasoningTokens + completionTokens,
"totalTokenCount": promptTokens + reasoningTokens + completionTokens,
}
}
func buildGeminiPartsFromFinal(finalText, finalThinking string, toolNames []string) []map[string]any {
detected := util.ParseToolCalls(finalText, toolNames)
if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
detected = util.ParseToolCalls(finalThinking, toolNames)
}
if len(detected) > 0 {
parts := make([]map[string]any, 0, len(detected))
for _, tc := range detected {
parts = append(parts, map[string]any{
"functionCall": map[string]any{
"name": tc.Name,
"args": tc.Input,
},
})
}
return parts
}
text := finalText
if strings.TrimSpace(text) == "" {
text = finalThinking
}
return []map[string]any{{"text": text}}
}
type geminiStreamRuntime struct {
w http.ResponseWriter
rc *http.ResponseController
canFlush bool
model string
finalPrompt string
thinkingEnabled bool
searchEnabled bool
bufferContent bool
toolNames []string
thinking strings.Builder
text strings.Builder
}
func newGeminiStreamRuntime(
w http.ResponseWriter,
rc *http.ResponseController,
canFlush bool,
model string,
finalPrompt string,
thinkingEnabled bool,
searchEnabled bool,
toolNames []string,
) *geminiStreamRuntime {
return &geminiStreamRuntime{
w: w,
rc: rc,
canFlush: canFlush,
model: model,
finalPrompt: finalPrompt,
thinkingEnabled: thinkingEnabled,
searchEnabled: searchEnabled,
bufferContent: len(toolNames) > 0,
toolNames: toolNames,
}
}
func (s *geminiStreamRuntime) sendChunk(payload map[string]any) {
b, _ := json.Marshal(payload)
_, _ = s.w.Write([]byte("data: "))
_, _ = s.w.Write(b)
_, _ = s.w.Write([]byte("\n\n"))
if s.canFlush {
_ = s.rc.Flush()
}
}
func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
if !parsed.Parsed {
return streamengine.ParsedDecision{}
}
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
return streamengine.ParsedDecision{Stop: true}
}
contentSeen := false
for _, p := range parsed.Parts {
if p.Text == "" {
continue
}
if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
continue
}
contentSeen = true
if p.Type == "thinking" {
if s.thinkingEnabled {
s.thinking.WriteString(p.Text)
}
continue
}
s.text.WriteString(p.Text)
if s.bufferContent {
continue
}
s.sendChunk(map[string]any{
"candidates": []map[string]any{
{
"index": 0,
"content": map[string]any{
"role": "model",
"parts": []map[string]any{{"text": p.Text}},
},
},
},
"modelVersion": s.model,
})
}
return streamengine.ParsedDecision{ContentSeen: contentSeen}
}
func (s *geminiStreamRuntime) finalize() {
finalThinking := s.thinking.String()
finalText := s.text.String()
if s.bufferContent {
parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames)
s.sendChunk(map[string]any{
"candidates": []map[string]any{
{
"index": 0,
"content": map[string]any{
"role": "model",
"parts": parts,
},
},
},
"modelVersion": s.model,
})
}
s.sendChunk(map[string]any{
"candidates": []map[string]any{
{
"index": 0,
"finishReason": "STOP",
},
},
"modelVersion": s.model,
"usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText),
})
}
func writeGeminiError(w http.ResponseWriter, status int, message string) {
errorStatus := "INVALID_ARGUMENT"
switch status {
case http.StatusUnauthorized:
errorStatus = "UNAUTHENTICATED"
case http.StatusForbidden:
errorStatus = "PERMISSION_DENIED"
case http.StatusTooManyRequests:
errorStatus = "RESOURCE_EXHAUSTED"
case http.StatusNotFound:
errorStatus = "NOT_FOUND"
default:
if status >= 500 {
errorStatus = "INTERNAL"
}
}
writeJSON(w, status, map[string]any{
"error": map[string]any{
"code": status,
"message": message,
"status": errorStatus,
},
})
}

View File

@@ -0,0 +1,174 @@
package gemini
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/go-chi/chi/v5"
"ds2api/internal/auth"
)
type testGeminiConfig struct{}
func (testGeminiConfig) ModelAliases() map[string]string { return nil }
type testGeminiAuth struct {
a *auth.RequestAuth
err error
}
func (m testGeminiAuth) Determine(_ *http.Request) (*auth.RequestAuth, error) {
if m.err != nil {
return nil, m.err
}
if m.a != nil {
return m.a, nil
}
return &auth.RequestAuth{
UseConfigToken: false,
DeepSeekToken: "direct-token",
CallerID: "caller:test",
TriedAccounts: map[string]bool{},
}, nil
}
func (testGeminiAuth) Release(_ *auth.RequestAuth) {}
type testGeminiDS struct {
resp *http.Response
err error
}
func (m testGeminiDS) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
return "session-id", nil
}
func (m testGeminiDS) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
return "pow", nil
}
func (m testGeminiDS) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
if m.err != nil {
return nil, m.err
}
return m.resp, nil
}
func makeGeminiUpstreamResponse(lines ...string) *http.Response {
body := strings.Join(lines, "\n")
if !strings.HasSuffix(body, "\n") {
body += "\n"
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}
}
func TestGeminiRoutesRegistered(t *testing.T) {
h := &Handler{
Store: testGeminiConfig{},
Auth: testGeminiAuth{err: auth.ErrUnauthorized},
}
r := chi.NewRouter()
RegisterRoutes(r, h)
paths := []string{
"/v1beta/models/gemini-2.5-pro:generateContent",
"/v1beta/models/gemini-2.5-pro:streamGenerateContent",
"/v1/models/gemini-2.5-pro:generateContent",
"/v1/models/gemini-2.5-pro:streamGenerateContent",
}
for _, path := range paths {
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`))
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code == http.StatusNotFound {
t.Fatalf("expected route %s to be registered, got 404", path)
}
}
}
func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
upstream := makeGeminiUpstreamResponse(
`data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
`data: [DONE]`,
)
h := &Handler{
Store: testGeminiConfig{},
Auth: testGeminiAuth{},
DS: testGeminiDS{resp: upstream},
}
r := chi.NewRouter()
RegisterRoutes(r, h)
body := `{
"contents":[{"role":"user","parts":[{"text":"call tool"}]}],
"tools":[{"functionDeclarations":[{"name":"eval_javascript","description":"eval","parameters":{"type":"object","properties":{"code":{"type":"string"}}}}]}]
}`
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent", strings.NewReader(body))
req.Header.Set("Authorization", "Bearer direct-token")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
}
var out map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
t.Fatalf("decode response failed: %v", err)
}
candidates, _ := out["candidates"].([]any)
if len(candidates) == 0 {
t.Fatalf("expected non-empty candidates: %#v", out)
}
c0, _ := candidates[0].(map[string]any)
content, _ := c0["content"].(map[string]any)
parts, _ := content["parts"].([]any)
if len(parts) == 0 {
t.Fatalf("expected non-empty parts: %#v", content)
}
part0, _ := parts[0].(map[string]any)
functionCall, _ := part0["functionCall"].(map[string]any)
if functionCall["name"] != "eval_javascript" {
t.Fatalf("expected functionCall name eval_javascript, got %#v", functionCall)
}
}
func TestStreamGenerateContentEmitsSSE(t *testing.T) {
upstream := makeGeminiUpstreamResponse(
`data: {"p":"response/content","v":"hello "}`,
`data: {"p":"response/content","v":"world"}`,
`data: [DONE]`,
)
h := &Handler{
Store: testGeminiConfig{},
Auth: testGeminiAuth{},
DS: testGeminiDS{resp: upstream},
}
r := chi.NewRouter()
RegisterRoutes(r, h)
body := `{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`
req := httptest.NewRequest(http.MethodPost, "/v1/models/gemini-2.5-pro:streamGenerateContent?alt=sse", strings.NewReader(body))
req.Header.Set("Authorization", "Bearer direct-token")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
}
if !strings.Contains(rec.Body.String(), "data: ") {
t.Fatalf("expected SSE data frames, got body=%s", rec.Body.String())
}
if !strings.Contains(rec.Body.String(), `"finishReason":"STOP"`) {
t.Fatalf("expected stream finish frame, got body=%s", rec.Body.String())
}
}

View File

@@ -25,10 +25,11 @@ type chatStreamRuntime struct {
thinkingEnabled bool
searchEnabled bool
firstChunkSent bool
bufferToolContent bool
emitEarlyToolDeltas bool
toolCallsEmitted bool
firstChunkSent bool
bufferToolContent bool
emitEarlyToolDeltas bool
toolCallsEmitted bool
toolCallsDoneEmitted bool
toolSieve toolStreamSieveState
streamToolCallIDs map[int]string
@@ -96,10 +97,10 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
finalThinking := s.thinking.String()
finalText := s.text.String()
detected := util.ParseToolCalls(finalText, s.toolNames)
if len(detected) > 0 && !s.toolCallsEmitted {
if len(detected) > 0 && !s.toolCallsDoneEmitted {
finishReason = "tool_calls"
delta := map[string]any{
"tool_calls": util.FormatOpenAIStreamToolCalls(detected),
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected, s.streamToolCallIDs),
}
if !s.firstChunkSent {
delta["role"] = "assistant"
@@ -112,8 +113,29 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, delta)},
nil,
))
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
} else if s.bufferToolContent {
for _, evt := range flushToolSieve(&s.toolSieve, s.toolNames) {
if len(evt.ToolCalls) > 0 {
finishReason = "tool_calls"
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
tcDelta := map[string]any{
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
}
if !s.firstChunkSent {
tcDelta["role"] = "assistant"
s.firstChunkSent = true
}
s.sendChunk(openaifmt.BuildChatStreamChunk(
s.completionID,
s.created,
s.model,
[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, tcDelta)},
nil,
))
}
if evt.Content == "" {
continue
}
@@ -189,10 +211,14 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
if !s.emitEarlyToolDeltas {
continue
}
s.toolCallsEmitted = true
tcDelta := map[string]any{
"tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs),
formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
if len(formatted) == 0 {
continue
}
tcDelta := map[string]any{
"tool_calls": formatted,
}
s.toolCallsEmitted = true
if !s.firstChunkSent {
tcDelta["role"] = "assistant"
s.firstChunkSent = true
@@ -202,8 +228,9 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
}
if len(evt.ToolCalls) > 0 {
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
tcDelta := map[string]any{
"tool_calls": util.FormatOpenAIStreamToolCalls(evt.ToolCalls),
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
}
if !s.firstChunkSent {
tcDelta["role"] = "assistant"

View File

@@ -31,7 +31,7 @@ func TestNormalizeOpenAIChatRequestWithConfigInterface(t *testing.T) {
"model": "my-model",
"messages": []any{map[string]any{"role": "user", "content": "hello"}},
}
out, err := normalizeOpenAIChatRequest(cfg, req)
out, err := normalizeOpenAIChatRequest(cfg, req, "")
if err != nil {
t.Fatalf("normalizeOpenAIChatRequest error: %v", err)
}
@@ -52,7 +52,7 @@ func TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface(t *testing.
_, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{
aliases: map[string]string{},
wideInput: false,
}, req)
}, req, "")
if err == nil {
t.Fatal("expected error when wide input is disabled and only input is provided")
}
@@ -60,7 +60,7 @@ func TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface(t *testing.
out, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{
aliases: map[string]string{},
wideInput: true,
}, req)
}, req, "")
if err != nil {
t.Fatalf("unexpected error when wide input is enabled: %v", err)
}

View File

@@ -93,7 +93,7 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
return
}
stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return
@@ -154,7 +154,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
rc := http.NewResponseController(w)
canFlush := rc.Flush() == nil
_, canFlush := w.(http.Flusher)
if !canFlush {
config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered")
}
@@ -233,7 +233,7 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any,
if len(toolSchemas) == 0 {
return messages, names
}
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error."
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block."
for i := range messages {
if messages[i]["role"] == "system" {
@@ -280,6 +280,36 @@ func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]s
return out
}
func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any {
if len(calls) == 0 {
return nil
}
out := make([]map[string]any, 0, len(calls))
for i, c := range calls {
callID := ""
if ids != nil {
callID = strings.TrimSpace(ids[i])
}
if callID == "" {
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
if ids != nil {
ids[i] = callID
}
}
args, _ := json.Marshal(c.Input)
out = append(out, map[string]any{
"index": i,
"id": callID,
"type": "function",
"function": map[string]any{
"name": c.Name,
"arguments": string(args),
},
})
}
return out
}
func writeOpenAIError(w http.ResponseWriter, status int, message string) {
writeJSON(w, status, map[string]any{
"error": map[string]any{

View File

@@ -735,3 +735,73 @@ func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}
func TestHandleStreamMultiToolCallDoesNotMergeNamesOrArguments(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search_web\",\"input\":{\"query\":\"latest ai news\"}},{"}`,
`data: {"p":"response/content","v":"\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid12", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
}
foundSearch := false
foundEval := false
foundIndex1 := false
toolCallsDeltaLens := make([]int, 0, 2)
for _, frame := range frames {
choices, _ := frame["choices"].([]any)
for _, item := range choices {
choice, _ := item.(map[string]any)
delta, _ := choice["delta"].(map[string]any)
rawToolCalls, hasToolCalls := delta["tool_calls"]
if !hasToolCalls {
continue
}
toolCalls, _ := rawToolCalls.([]any)
toolCallsDeltaLens = append(toolCallsDeltaLens, len(toolCalls))
for _, tc := range toolCalls {
tcm, _ := tc.(map[string]any)
if idx, ok := tcm["index"].(float64); ok && int(idx) == 1 {
foundIndex1 = true
}
fn, _ := tcm["function"].(map[string]any)
name, _ := fn["name"].(string)
switch name {
case "search_web":
foundSearch = true
case "eval_javascript":
foundEval = true
case "search_webeval_javascript":
t.Fatalf("unexpected merged tool name: %s, body=%s", name, rec.Body.String())
}
if args, ok := fn["arguments"].(string); ok && strings.Contains(args, `}{"`) {
t.Fatalf("unexpected concatenated tool arguments: %q, body=%s", args, rec.Body.String())
}
}
}
}
if !foundSearch || !foundEval {
t.Fatalf("expected both tool names in stream deltas, foundSearch=%v foundEval=%v body=%s", foundSearch, foundEval, rec.Body.String())
}
if len(toolCallsDeltaLens) != 1 || toolCallsDeltaLens[0] != 2 {
t.Fatalf("expected exactly one tool_calls delta with two calls, got lens=%v body=%s", toolCallsDeltaLens, rec.Body.String())
}
if !foundIndex1 {
t.Fatalf("expected second tool call index in stream deltas, body=%s", rec.Body.String())
}
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}

View File

@@ -4,9 +4,11 @@ import (
"encoding/json"
"fmt"
"strings"
"ds2api/internal/config"
)
func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any {
func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any {
out := make([]map[string]any, 0, len(raw))
for _, item := range raw {
msg, ok := item.(map[string]any)
@@ -17,7 +19,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any {
switch role {
case "assistant":
content := normalizeOpenAIContentForPrompt(msg["content"])
toolCalls := formatAssistantToolCallsForPrompt(msg)
toolCalls := formatAssistantToolCallsForPrompt(msg, traceID)
combined := joinNonEmpty(content, toolCalls)
if combined == "" {
continue
@@ -53,7 +55,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any {
return out
}
func formatAssistantToolCallsForPrompt(msg map[string]any) string {
func formatAssistantToolCallsForPrompt(msg map[string]any, traceID string) string {
entries := make([]string, 0)
if calls, ok := msg["tool_calls"].([]any); ok {
for i, item := range calls {
@@ -86,7 +88,8 @@ func formatAssistantToolCallsForPrompt(msg map[string]any) string {
if args == "" {
args = "{}"
}
entries = append(entries, fmt.Sprintf("Tool call:\n- tool_call_id: %s\n- function.name: %s\n- function.arguments: %s", id, name, args))
maybeWarnSuspiciousToolHistory(traceID, id, name, args)
entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: %s\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", id, name, args))
}
}
@@ -99,7 +102,8 @@ func formatAssistantToolCallsForPrompt(msg map[string]any) string {
if args == "" {
args = "{}"
}
entries = append(entries, fmt.Sprintf("Tool call:\n- tool_call_id: call_legacy\n- function.name: %s\n- function.arguments: %s", name, args))
maybeWarnSuspiciousToolHistory(traceID, "call_legacy", name, args)
entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: call_legacy\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", name, args))
}
return strings.Join(entries, "\n\n")
@@ -124,7 +128,7 @@ func formatToolResultForPrompt(msg map[string]any) string {
content = "null"
}
return fmt.Sprintf("Tool result:\n- tool_call_id: %s\n- name: %s\n- content: %s", toolCallID, name, content)
return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content)
}
func normalizeOpenAIContentForPrompt(v any) string {
@@ -190,3 +194,45 @@ func joinNonEmpty(parts ...string) string {
}
return strings.Join(nonEmpty, "\n\n")
}
func maybeWarnSuspiciousToolHistory(traceID, callID, name, args string) {
if !looksLikeConcatenatedJSON(args) {
return
}
traceID = strings.TrimSpace(traceID)
if traceID == "" {
traceID = "unknown"
}
config.Logger.Warn(
"[openai] suspicious tool call history payload detected",
"trace_id", traceID,
"tool_call_id", strings.TrimSpace(callID),
"name", strings.TrimSpace(name),
"arguments_preview", previewToolArgs(args, 160),
)
}
func looksLikeConcatenatedJSON(raw string) bool {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return false
}
if strings.Contains(trimmed, "}{") || strings.Contains(trimmed, "][") {
return true
}
dec := json.NewDecoder(strings.NewReader(trimmed))
var first any
if err := dec.Decode(&first); err != nil {
return false
}
var second any
return dec.Decode(&second) == nil
}
func previewToolArgs(raw string, max int) string {
trimmed := strings.TrimSpace(raw)
if max <= 0 || len(trimmed) <= max {
return trimmed
}
return trimmed[:max]
}

View File

@@ -33,23 +33,24 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *tes
},
}
normalized := normalizeOpenAIMessagesForPrompt(raw)
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
if len(normalized) != 4 {
t.Fatalf("expected 4 normalized messages, got %d", len(normalized))
}
assistantContent, _ := normalized[2]["content"].(string)
if !strings.Contains(assistantContent, "tool_call_id: call_1") ||
if !strings.Contains(assistantContent, "[TOOL_CALL_HISTORY]") ||
!strings.Contains(assistantContent, "tool_call_id: call_1") ||
!strings.Contains(assistantContent, "function.name: get_weather") ||
!strings.Contains(assistantContent, "function.arguments: {\"city\":\"beijing\"}") {
t.Fatalf("assistant tool call not serialized correctly: %q", assistantContent)
}
toolContent, _ := normalized[3]["content"].(string)
if !strings.Contains(toolContent, "Tool result:") || !strings.Contains(toolContent, "name: get_weather") {
if !strings.Contains(toolContent, "[TOOL_RESULT_HISTORY]") || !strings.Contains(toolContent, "name: get_weather") {
t.Fatalf("tool result not serialized correctly: %q", toolContent)
}
prompt := util.MessagesPrepare(normalized)
if !strings.Contains(prompt, "tool_call_id: call_1") || !strings.Contains(prompt, "Tool result:") {
if !strings.Contains(prompt, "tool_call_id: call_1") || !strings.Contains(prompt, "[TOOL_RESULT_HISTORY]") {
t.Fatalf("expected prompt to include tool call + result semantics: %q", prompt)
}
}
@@ -67,7 +68,7 @@ func TestNormalizeOpenAIMessagesForPrompt_ToolObjectContentPreserved(t *testing.
},
}
normalized := normalizeOpenAIMessagesForPrompt(raw)
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
got, _ := normalized[0]["content"].(string)
if !strings.Contains(got, `"temp":18`) || !strings.Contains(got, `"condition":"sunny"`) {
t.Fatalf("expected serialized object in tool content, got %q", got)
@@ -88,7 +89,7 @@ func TestNormalizeOpenAIMessagesForPrompt_ToolArrayBlocksJoined(t *testing.T) {
},
}
normalized := normalizeOpenAIMessagesForPrompt(raw)
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
got, _ := normalized[0]["content"].(string)
if !strings.Contains(got, "line-1\nline-2") {
t.Fatalf("expected joined text blocks, got %q", got)
@@ -107,7 +108,7 @@ func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) {
},
}
normalized := normalizeOpenAIMessagesForPrompt(raw)
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
if len(normalized) != 1 {
t.Fatalf("expected one normalized message, got %d", len(normalized))
}
@@ -119,3 +120,50 @@ func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) {
t.Fatalf("unexpected normalized function-role content: %q", got)
}
}
func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSeparated(t *testing.T) {
raw := []any{
map[string]any{
"role": "assistant",
"tool_calls": []any{
map[string]any{
"id": "call_search",
"type": "function",
"function": map[string]any{
"name": "search_web",
"arguments": `{"query":"latest ai news"}`,
},
},
map[string]any{
"id": "call_eval",
"type": "function",
"function": map[string]any{
"name": "eval_javascript",
"arguments": `{"code":"1+1"}`,
},
},
},
},
}
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
if len(normalized) != 1 {
t.Fatalf("expected one normalized assistant message, got %d", len(normalized))
}
content, _ := normalized[0]["content"].(string)
if strings.Count(content, "[TOOL_CALL_HISTORY]") != 2 {
t.Fatalf("expected two TOOL_CALL_HISTORY blocks, got %q", content)
}
if !strings.Contains(content, "tool_call_id: call_search") || !strings.Contains(content, "function.name: search_web") {
t.Fatalf("missing first tool call block, got %q", content)
}
if !strings.Contains(content, "tool_call_id: call_eval") || !strings.Contains(content, "function.name: eval_javascript") {
t.Fatalf("missing second tool call block, got %q", content)
}
if strings.Contains(content, "search_webeval_javascript") {
t.Fatalf("unexpected merged function name detected: %q", content)
}
if strings.Contains(content, `}{"`) {
t.Fatalf("unexpected concatenated function arguments detected: %q", content)
}
}

View File

@@ -4,11 +4,18 @@ import (
"ds2api/internal/deepseek"
)
func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any) (string, []string) {
messages := normalizeOpenAIMessagesForPrompt(messagesRaw)
func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any, traceID string) (string, []string) {
messages := normalizeOpenAIMessagesForPrompt(messagesRaw, traceID)
toolNames := []string{}
if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 {
messages, toolNames = injectToolPrompt(messages, tools)
}
return deepseek.MessagesPrepare(messages), toolNames
}
// BuildPromptForAdapter exposes the OpenAI-compatible prompt building flow so
// other protocol adapters (for example Gemini) can reuse the same tool/history
// normalization logic and remain behavior-compatible with chat/completions.
func BuildPromptForAdapter(messagesRaw []any, toolsRaw any, traceID string) (string, []string) {
return buildOpenAIFinalPrompt(messagesRaw, toolsRaw, traceID)
}

View File

@@ -40,13 +40,13 @@ func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *tes
},
}
finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools)
finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools, "")
if len(toolNames) != 1 || toolNames[0] != "get_weather" {
t.Fatalf("unexpected tool names: %#v", toolNames)
}
if !strings.Contains(finalPrompt, "tool_call_id: call_1") ||
!strings.Contains(finalPrompt, "function.name: get_weather") ||
!strings.Contains(finalPrompt, "Tool result:") ||
!strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") ||
!strings.Contains(finalPrompt, `"condition":"sunny"`) {
t.Fatalf("handler finalPrompt missing tool roundtrip semantics: %q", finalPrompt)
}
@@ -70,11 +70,14 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *
},
}
finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools)
finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "")
if !strings.Contains(finalPrompt, "After receiving a tool result, you MUST use it to produce the final answer.") {
t.Fatalf("vercel prepare finalPrompt missing final-answer instruction: %q", finalPrompt)
}
if !strings.Contains(finalPrompt, "Only call another tool when the previous result is missing required data or returned an error.") {
t.Fatalf("vercel prepare finalPrompt missing retry guard instruction: %q", finalPrompt)
}
if !strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") {
t.Fatalf("vercel prepare finalPrompt missing history marker instruction: %q", finalPrompt)
}
}

View File

@@ -1,6 +1,7 @@
package openai
import (
"strings"
"testing"
"time"
)
@@ -32,6 +33,82 @@ func TestResponsesMessagesFromRequestWithInstructions(t *testing.T) {
}
}
func TestNormalizeResponsesInputAsMessagesObjectRoleContentBlocks(t *testing.T) {
msgs := normalizeResponsesInputAsMessages(map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "input_text", "text": "line-1"},
map[string]any{"type": "input_text", "text": "line-2"},
},
})
if len(msgs) != 1 {
t.Fatalf("expected one message, got %d", len(msgs))
}
m, _ := msgs[0].(map[string]any)
if m["role"] != "user" {
t.Fatalf("unexpected role: %#v", m)
}
if strings.TrimSpace(normalizeOpenAIContentForPrompt(m["content"])) != "line-1\nline-2" {
t.Fatalf("unexpected content: %#v", m["content"])
}
}
func TestNormalizeResponsesInputAsMessagesFunctionCallOutput(t *testing.T) {
msgs := normalizeResponsesInputAsMessages([]any{
map[string]any{
"type": "function_call_output",
"call_id": "call_123",
"output": map[string]any{"ok": true},
},
})
if len(msgs) != 1 {
t.Fatalf("expected one message, got %d", len(msgs))
}
m, _ := msgs[0].(map[string]any)
if m["role"] != "tool" {
t.Fatalf("expected tool role, got %#v", m)
}
if m["tool_call_id"] != "call_123" {
t.Fatalf("expected tool_call_id propagated, got %#v", m)
}
}
func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) {
msgs := normalizeResponsesInputAsMessages([]any{
map[string]any{
"type": "function_call",
"call_id": "call_456",
"name": "search",
"arguments": `{"q":"golang"}`,
},
})
if len(msgs) != 1 {
t.Fatalf("expected one message, got %d", len(msgs))
}
m, _ := msgs[0].(map[string]any)
if m["role"] != "assistant" {
t.Fatalf("expected assistant role, got %#v", m["role"])
}
toolCalls, _ := m["tool_calls"].([]any)
if len(toolCalls) != 1 {
t.Fatalf("expected one tool_call, got %#v", m["tool_calls"])
}
call, _ := toolCalls[0].(map[string]any)
if call["id"] != "call_456" {
t.Fatalf("expected call id preserved, got %#v", call)
}
if call["type"] != "function" {
t.Fatalf("expected function type, got %#v", call)
}
fn, _ := call["function"].(map[string]any)
if fn["name"] != "search" {
t.Fatalf("expected call name preserved, got %#v", call)
}
if fn["arguments"] != `{"q":"golang"}` {
t.Fatalf("expected call arguments preserved, got %#v", call)
}
}
func TestExtractEmbeddingInputs(t *testing.T) {
got := extractEmbeddingInputs([]any{"a", "b"})
if len(got) != 2 || got[0] != "a" || got[1] != "b" {

View File

@@ -68,7 +68,7 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
return
}
stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req)
stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req, requestTraceID(r))
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return
@@ -128,7 +128,7 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
rc := http.NewResponseController(w)
canFlush := rc.Flush() == nil
_, canFlush := w.(http.Flusher)
initialType := "text"
if thinkingEnabled {
@@ -203,40 +203,231 @@ func normalizeResponsesInputAsMessages(input any) []any {
}
return []any{map[string]any{"role": "user", "content": v}}
case []any:
if len(v) == 0 {
return nil
}
// If caller already provides role-shaped items, keep as-is.
if first, ok := v[0].(map[string]any); ok {
if _, hasRole := first["role"]; hasRole {
return v
}
}
parts := make([]string, 0, len(v))
for _, item := range v {
if m, ok := item.(map[string]any); ok {
if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") {
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
parts = append(parts, txt)
continue
}
}
}
if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" {
parts = append(parts, s)
}
}
if len(parts) == 0 {
return nil
}
return []any{map[string]any{"role": "user", "content": strings.Join(parts, "\n")}}
return normalizeResponsesInputArray(v)
case map[string]any:
if msg := normalizeResponsesInputItem(v); msg != nil {
return []any{msg}
}
if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" {
return []any{map[string]any{"role": "user", "content": txt}}
}
if content, ok := v["content"].(string); ok && strings.TrimSpace(content) != "" {
return []any{map[string]any{"role": "user", "content": content}}
if content, ok := v["content"]; ok {
if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
return []any{map[string]any{"role": "user", "content": content}}
}
}
}
return nil
}
func normalizeResponsesInputArray(items []any) []any {
if len(items) == 0 {
return nil
}
out := make([]any, 0, len(items))
fallbackParts := make([]string, 0, len(items))
flushFallback := func() {
if len(fallbackParts) == 0 {
return
}
out = append(out, map[string]any{"role": "user", "content": strings.Join(fallbackParts, "\n")})
fallbackParts = fallbackParts[:0]
}
for _, item := range items {
switch x := item.(type) {
case map[string]any:
if msg := normalizeResponsesInputItem(x); msg != nil {
flushFallback()
out = append(out, msg)
continue
}
if s := normalizeResponsesFallbackPart(x); s != "" {
fallbackParts = append(fallbackParts, s)
}
default:
if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" {
fallbackParts = append(fallbackParts, s)
}
}
}
flushFallback()
if len(out) == 0 {
return nil
}
return out
}
func normalizeResponsesInputItem(m map[string]any) map[string]any {
if m == nil {
return nil
}
role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
if role != "" {
content := m["content"]
if content == nil {
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
content = txt
}
}
if content == nil {
return nil
}
return map[string]any{
"role": role,
"content": content,
}
}
itemType := strings.ToLower(strings.TrimSpace(asString(m["type"])))
switch itemType {
case "message", "input_message":
content := m["content"]
if content == nil {
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
content = txt
}
}
if content == nil {
return nil
}
role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
if role == "" {
role = "user"
}
return map[string]any{
"role": role,
"content": content,
}
case "function_call_output", "tool_result":
content := m["output"]
if content == nil {
content = m["content"]
}
if content == nil {
content = ""
}
out := map[string]any{
"role": "tool",
"content": content,
}
if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
out["tool_call_id"] = callID
} else if callID = strings.TrimSpace(asString(m["tool_call_id"])); callID != "" {
out["tool_call_id"] = callID
}
if name := strings.TrimSpace(asString(m["name"])); name != "" {
out["name"] = name
} else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" {
out["name"] = name
}
return out
case "function_call", "tool_call":
name := strings.TrimSpace(asString(m["name"]))
var fn map[string]any
if rawFn, ok := m["function"].(map[string]any); ok {
fn = rawFn
if name == "" {
name = strings.TrimSpace(asString(fn["name"]))
}
}
if name == "" {
return nil
}
var argsRaw any
if v, ok := m["arguments"]; ok {
argsRaw = v
} else if v, ok := m["input"]; ok {
argsRaw = v
}
if argsRaw == nil && fn != nil {
if v, ok := fn["arguments"]; ok {
argsRaw = v
} else if v, ok := fn["input"]; ok {
argsRaw = v
}
}
functionPayload := map[string]any{
"name": name,
"arguments": stringifyToolCallArguments(argsRaw),
}
call := map[string]any{
"type": "function",
"function": functionPayload,
}
if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
call["id"] = callID
} else if callID = strings.TrimSpace(asString(m["id"])); callID != "" {
call["id"] = callID
}
return map[string]any{
"role": "assistant",
"tool_calls": []any{call},
}
case "input_text":
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
return map[string]any{
"role": "user",
"content": txt,
}
}
}
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
return map[string]any{
"role": "user",
"content": txt,
}
}
if content, ok := m["content"]; ok {
if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
return map[string]any{
"role": "user",
"content": content,
}
}
}
return nil
}
func normalizeResponsesFallbackPart(m map[string]any) string {
if m == nil {
return ""
}
if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") {
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
return txt
}
}
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
return txt
}
if content, ok := m["content"]; ok {
if normalized := strings.TrimSpace(normalizeOpenAIContentForPrompt(content)); normalized != "" {
return normalized
}
}
return strings.TrimSpace(fmt.Sprintf("%v", m))
}
func stringifyToolCallArguments(v any) string {
switch x := v.(type) {
case nil:
return "{}"
case string:
s := strings.TrimSpace(x)
if s == "" {
return "{}"
}
return s
default:
b, err := json.Marshal(x)
if err != nil || len(b) == 0 {
return "{}"
}
return string(b)
}
}

View File

@@ -3,12 +3,15 @@ package openai
import (
"encoding/json"
"net/http"
"sort"
"strings"
openaifmt "ds2api/internal/format/openai"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
"ds2api/internal/util"
"github.com/google/uuid"
)
type responsesStreamRuntime struct {
@@ -24,14 +27,20 @@ type responsesStreamRuntime struct {
thinkingEnabled bool
searchEnabled bool
bufferToolContent bool
emitEarlyToolDeltas bool
toolCallsEmitted bool
bufferToolContent bool
emitEarlyToolDeltas bool
toolCallsEmitted bool
toolCallsDoneEmitted bool
sieve toolStreamSieveState
thinkingSieve toolStreamSieveState
thinking strings.Builder
text strings.Builder
streamToolCallIDs map[int]string
streamFunctionIDs map[int]string
functionDone map[int]bool
toolCallsDoneSigs map[string]bool
reasoningItemID string
persistResponse func(obj map[string]any)
}
@@ -63,6 +72,9 @@ func newResponsesStreamRuntime(
bufferToolContent: bufferToolContent,
emitEarlyToolDeltas: emitEarlyToolDeltas,
streamToolCallIDs: map[int]string{},
streamFunctionIDs: map[int]string{},
functionDone: map[int]bool{},
toolCallsDoneSigs: map[string]bool{},
persistResponse: persistResponse,
}
}
@@ -92,19 +104,33 @@ func (s *responsesStreamRuntime) sendDone() {
func (s *responsesStreamRuntime) finalize() {
finalThinking := s.thinking.String()
finalText := s.text.String()
if strings.TrimSpace(finalThinking) != "" {
s.sendEvent("response.reasoning_text.done", openaifmt.BuildResponsesReasoningTextDonePayload(s.responseID, s.ensureReasoningItemID(), 0, 0, finalThinking))
}
if s.bufferToolContent {
for _, evt := range flushToolSieve(&s.sieve, s.toolNames) {
if evt.Content != "" {
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
}
if len(evt.ToolCalls) > 0 {
s.toolCallsEmitted = true
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls)))
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false)
}
// Compatibility fallback: some streams only emit incremental tool deltas.
// Ensure final function_call_arguments.done is emitted at least once.
if s.toolCallsEmitted {
detected := util.ParseToolCalls(finalText, s.toolNames)
if len(detected) == 0 {
detected = util.ParseToolCalls(finalThinking, s.toolNames)
}
if len(detected) > 0 {
if !s.toolCallsDoneEmitted {
s.emitToolCallsDone(detected)
} else {
s.emitFunctionCallDoneEvents(detected)
}
}
}
obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames)
if s.toolCallsEmitted {
s.alignCompletedOutputCallIDs(obj)
}
if s.toolCallsEmitted {
obj["status"] = "completed"
}
@@ -138,6 +164,10 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
}
s.thinking.WriteString(p.Text)
s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
s.sendEvent("response.reasoning_text.delta", openaifmt.BuildResponsesReasoningTextDeltaPayload(s.responseID, s.ensureReasoningItemID(), 0, 0, p.Text))
if s.bufferToolContent {
s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false)
}
continue
}
@@ -146,23 +176,191 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text))
continue
}
for _, evt := range processToolSieveChunk(&s.sieve, p.Text, s.toolNames) {
if evt.Content != "" {
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
}
if len(evt.ToolCallDeltas) > 0 {
if !s.emitEarlyToolDeltas {
continue
}
s.toolCallsEmitted = true
s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)))
}
if len(evt.ToolCalls) > 0 {
s.toolCallsEmitted = true
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls)))
}
}
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true)
}
return streamengine.ParsedDecision{ContentSeen: contentSeen}
}
func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
for _, evt := range events {
if emitContent && evt.Content != "" {
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
}
if len(evt.ToolCallDeltas) > 0 {
if !s.emitEarlyToolDeltas {
continue
}
formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
if len(formatted) == 0 {
continue
}
s.toolCallsEmitted = true
s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatted))
s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
}
if len(evt.ToolCalls) > 0 {
s.emitToolCallsDone(evt.ToolCalls)
}
}
}
func (s *responsesStreamRuntime) emitToolCallsDone(calls []util.ParsedToolCall) {
if len(calls) == 0 {
return
}
sig := toolCallListSignature(calls)
if sig != "" && s.toolCallsDoneSigs[sig] {
return
}
if sig != "" {
s.toolCallsDoneSigs[sig] = true
}
formatted := formatFinalStreamToolCallsWithStableIDs(calls, s.streamToolCallIDs)
if len(formatted) == 0 {
return
}
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatted))
s.emitFunctionCallDoneEvents(calls)
}
func (s *responsesStreamRuntime) ensureReasoningItemID() string {
if strings.TrimSpace(s.reasoningItemID) != "" {
return s.reasoningItemID
}
s.reasoningItemID = "rs_" + strings.ReplaceAll(uuid.NewString(), "-", "")
return s.reasoningItemID
}
func (s *responsesStreamRuntime) ensureFunctionItemID(index int) string {
if id, ok := s.streamFunctionIDs[index]; ok && strings.TrimSpace(id) != "" {
return id
}
id := "fc_" + strings.ReplaceAll(uuid.NewString(), "-", "")
s.streamFunctionIDs[index] = id
return id
}
func (s *responsesStreamRuntime) ensureToolCallID(index int) string {
if id, ok := s.streamToolCallIDs[index]; ok && strings.TrimSpace(id) != "" {
return id
}
id := "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
s.streamToolCallIDs[index] = id
return id
}
func (s *responsesStreamRuntime) functionOutputBaseIndex() int {
if strings.TrimSpace(s.thinking.String()) != "" {
return 1
}
return 0
}
func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDelta) {
for _, d := range deltas {
if strings.TrimSpace(d.Arguments) == "" {
continue
}
outputIndex := s.functionOutputBaseIndex() + d.Index
itemID := s.ensureFunctionItemID(outputIndex)
callID := s.ensureToolCallID(d.Index)
s.sendEvent(
"response.function_call_arguments.delta",
openaifmt.BuildResponsesFunctionCallArgumentsDeltaPayload(s.responseID, itemID, outputIndex, callID, d.Arguments),
)
}
}
func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedToolCall) {
base := s.functionOutputBaseIndex()
for idx, tc := range calls {
if strings.TrimSpace(tc.Name) == "" {
continue
}
outputIndex := base + idx
if s.functionDone[outputIndex] {
continue
}
itemID := s.ensureFunctionItemID(outputIndex)
callID := s.ensureToolCallID(idx)
argsBytes, _ := json.Marshal(tc.Input)
s.sendEvent(
"response.function_call_arguments.done",
openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, string(argsBytes)),
)
s.functionDone[outputIndex] = true
}
}
func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any) {
if obj == nil || len(s.streamToolCallIDs) == 0 {
return
}
output, _ := obj["output"].([]any)
if len(output) == 0 {
return
}
indices := make([]int, 0, len(s.streamToolCallIDs))
for idx := range s.streamToolCallIDs {
indices = append(indices, idx)
}
sort.Ints(indices)
ordered := make([]string, 0, len(indices))
for _, idx := range indices {
id := strings.TrimSpace(s.streamToolCallIDs[idx])
if id == "" {
continue
}
ordered = append(ordered, id)
}
if len(ordered) == 0 {
return
}
functionIdx := 0
for _, item := range output {
m, _ := item.(map[string]any)
if m == nil {
continue
}
typ, _ := m["type"].(string)
switch typ {
case "function_call":
if functionIdx < len(ordered) {
m["call_id"] = ordered[functionIdx]
functionIdx++
}
case "tool_calls":
tcArr, _ := m["tool_calls"].([]any)
for i, raw := range tcArr {
tc, _ := raw.(map[string]any)
if tc == nil {
continue
}
if i < len(ordered) {
tc["id"] = ordered[i]
}
}
}
}
}
func toolCallListSignature(calls []util.ParsedToolCall) string {
if len(calls) == 0 {
return ""
}
var b strings.Builder
for i, tc := range calls {
if i > 0 {
b.WriteString("|")
}
b.WriteString(strings.TrimSpace(tc.Name))
b.WriteString(":")
args, _ := json.Marshal(tc.Input)
b.Write(args)
}
return b.String()
}

View File

@@ -45,17 +45,37 @@ func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T
if len(output) == 0 {
t.Fatalf("expected structured output entries, got %#v", responseObj["output"])
}
first, _ := output[0].(map[string]any)
if first["type"] != "tool_calls" {
t.Fatalf("expected first output type tool_calls, got %#v", first["type"])
var firstToolWrapper map[string]any
hasFunctionCall := false
for _, item := range output {
m, _ := item.(map[string]any)
if m == nil {
continue
}
if m["type"] == "function_call" {
hasFunctionCall = true
}
if m["type"] == "tool_calls" && firstToolWrapper == nil {
firstToolWrapper = m
}
}
toolCalls, _ := first["tool_calls"].([]any)
if !hasFunctionCall {
t.Fatalf("expected at least one function_call item for responses compatibility, got %#v", responseObj["output"])
}
if firstToolWrapper == nil {
t.Fatalf("expected a tool_calls wrapper item, got %#v", responseObj["output"])
}
toolCalls, _ := firstToolWrapper["tool_calls"].([]any)
if len(toolCalls) == 0 {
t.Fatalf("expected at least one tool_call in output, got %#v", first["tool_calls"])
t.Fatalf("expected at least one tool_call in output, got %#v", firstToolWrapper["tool_calls"])
}
call0, _ := toolCalls[0].(map[string]any)
if call0["name"] != "read_file" {
t.Fatalf("unexpected tool call name: %#v", call0["name"])
if call0["type"] != "function" {
t.Fatalf("unexpected tool call type: %#v", call0["type"])
}
fn, _ := call0["function"].(map[string]any)
if fn["name"] != "read_file" {
t.Fatalf("unexpected tool call name: %#v", fn["name"])
}
if strings.Contains(outputText, `"tool_calls"`) {
t.Fatalf("raw tool_calls JSON leaked in output_text: %q", outputText)
@@ -95,6 +115,314 @@ func TestHandleResponsesStreamIncompleteTailNotDuplicatedInCompletedOutputText(t
}
}
func TestHandleResponsesStreamEmitsReasoningCompatEvents(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
b, _ := json.Marshal(map[string]any{
"p": "response/thinking_content",
"v": "thought",
})
streamBody := "data: " + string(b) + "\n" + "data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, nil)
body := rec.Body.String()
if !strings.Contains(body, "event: response.reasoning.delta") {
t.Fatalf("expected response.reasoning.delta event, body=%s", body)
}
if !strings.Contains(body, "event: response.reasoning_text.delta") {
t.Fatalf("expected response.reasoning_text.delta compatibility event, body=%s", body)
}
if !strings.Contains(body, "event: response.reasoning_text.done") {
t.Fatalf("expected response.reasoning_text.done compatibility event, body=%s", body)
}
}
func TestHandleResponsesStreamEmitsFunctionCallCompatEvents(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(v string) string {
b, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": v,
})
return "data: " + string(b) + "\n"
}
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
body := rec.Body.String()
if !strings.Contains(body, "event: response.function_call_arguments.delta") {
t.Fatalf("expected response.function_call_arguments.delta compatibility event, body=%s", body)
}
if !strings.Contains(body, "event: response.function_call_arguments.done") {
t.Fatalf("expected response.function_call_arguments.done compatibility event, body=%s", body)
}
donePayload, ok := extractSSEEventPayload(body, "response.function_call_arguments.done")
if !ok {
t.Fatalf("expected to parse response.function_call_arguments.done payload, body=%s", body)
}
if strings.TrimSpace(asString(donePayload["call_id"])) == "" {
t.Fatalf("expected call_id in response.function_call_arguments.done payload, payload=%#v", donePayload)
}
if strings.TrimSpace(asString(donePayload["response_id"])) == "" {
t.Fatalf("expected response_id in response.function_call_arguments.done payload, payload=%#v", donePayload)
}
doneCallID := strings.TrimSpace(asString(donePayload["call_id"]))
if doneCallID == "" {
t.Fatalf("expected non-empty call_id in done payload, payload=%#v", donePayload)
}
completed, ok := extractSSEEventPayload(body, "response.completed")
if !ok {
t.Fatalf("expected response.completed payload, body=%s", body)
}
responseObj, _ := completed["response"].(map[string]any)
output, _ := responseObj["output"].([]any)
if len(output) == 0 {
t.Fatalf("expected non-empty output in response.completed, response=%#v", responseObj)
}
var completedCallID string
for _, item := range output {
m, _ := item.(map[string]any)
if m == nil || m["type"] != "function_call" {
continue
}
completedCallID = strings.TrimSpace(asString(m["call_id"]))
if completedCallID != "" {
break
}
}
if completedCallID == "" {
t.Fatalf("expected function_call.call_id in completed output, output=%#v", output)
}
if completedCallID != doneCallID {
t.Fatalf("expected completed call_id to match stream done call_id, done=%q completed=%q", doneCallID, completedCallID)
}
}
func TestHandleResponsesStreamDetectsToolCallsFromThinkingChannel(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(path, v string) string {
b, _ := json.Marshal(map[string]any{
"p": path,
"v": v,
})
return "data: " + string(b) + "\n"
}
streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"})
body := rec.Body.String()
if !strings.Contains(body, "event: response.reasoning_text.delta") {
t.Fatalf("expected response.reasoning_text.delta event, body=%s", body)
}
if !strings.Contains(body, "event: response.function_call_arguments.done") {
t.Fatalf("expected response.function_call_arguments.done event from thinking channel, body=%s", body)
}
if !strings.Contains(body, "event: response.output_tool_call.done") {
t.Fatalf("expected response.output_tool_call.done event from thinking channel, body=%s", body)
}
}
func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(v string) string {
b, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": v,
})
return "data: " + string(b) + "\n"
}
streamBody := sseLine(`{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) +
sseLine(`{"name":"eval_javascript","input":{"code":"1+1"}}]}`) +
"data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"})
body := rec.Body.String()
if !strings.Contains(body, "event: response.output_tool_call.done") {
t.Fatalf("expected response.output_tool_call.done event, body=%s", body)
}
donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
if len(donePayloads) != 2 {
t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
}
seenNames := map[string]string{}
for _, payload := range donePayloads {
name := strings.TrimSpace(asString(payload["name"]))
callID := strings.TrimSpace(asString(payload["call_id"]))
args := strings.TrimSpace(asString(payload["arguments"]))
if callID == "" {
t.Fatalf("expected non-empty call_id in done payload: %#v", payload)
}
if strings.Contains(args, `}{"`) {
t.Fatalf("unexpected concatenated arguments in done payload: %#v", payload)
}
if name == "search_webeval_javascript" {
t.Fatalf("unexpected merged tool name in done payload: %#v", payload)
}
if name != "search_web" && name != "eval_javascript" {
t.Fatalf("unexpected tool name in done payload: %#v", payload)
}
seenNames[name] = callID
}
if seenNames["search_web"] == "" || seenNames["eval_javascript"] == "" {
t.Fatalf("expected done events for both tools, got %#v", seenNames)
}
if seenNames["search_web"] == seenNames["eval_javascript"] {
t.Fatalf("expected distinct call_id per tool, got %#v", seenNames)
}
completed, ok := extractSSEEventPayload(body, "response.completed")
if !ok {
t.Fatalf("expected response.completed event, body=%s", body)
}
responseObj, _ := completed["response"].(map[string]any)
output, _ := responseObj["output"].([]any)
functionCallIDs := map[string]string{}
for _, item := range output {
m, _ := item.(map[string]any)
if m == nil || m["type"] != "function_call" {
continue
}
name := strings.TrimSpace(asString(m["name"]))
callID := strings.TrimSpace(asString(m["call_id"]))
if name != "" && callID != "" {
functionCallIDs[name] = callID
}
}
if functionCallIDs["search_web"] != seenNames["search_web"] {
t.Fatalf("search_web call_id mismatch between done and completed: done=%q completed=%q", seenNames["search_web"], functionCallIDs["search_web"])
}
if functionCallIDs["eval_javascript"] != seenNames["eval_javascript"] {
t.Fatalf("eval_javascript call_id mismatch between done and completed: done=%q completed=%q", seenNames["eval_javascript"], functionCallIDs["eval_javascript"])
}
}
func TestHandleResponsesStreamMultiToolCallFromThinkingChannel(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(path, v string) string {
b, _ := json.Marshal(map[string]any{
"p": path,
"v": v,
})
return "data: " + string(b) + "\n"
}
streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) +
sseLine("response/thinking_content", `{"name":"eval_javascript","input":{"code":"1+1"}}]}`) +
"data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"search_web", "eval_javascript"})
body := rec.Body.String()
if !strings.Contains(body, "event: response.reasoning_text.delta") {
t.Fatalf("expected reasoning stream events, body=%s", body)
}
donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
if len(donePayloads) != 2 {
t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
}
seen := map[string]bool{}
for _, payload := range donePayloads {
name := strings.TrimSpace(asString(payload["name"]))
if name == "search_webeval_javascript" {
t.Fatalf("unexpected merged tool name in thinking channel done payload: %#v", payload)
}
if name != "search_web" && name != "eval_javascript" {
t.Fatalf("unexpected tool name in thinking channel done payload: %#v", payload)
}
args := strings.TrimSpace(asString(payload["arguments"]))
if strings.Contains(args, `}{"`) {
t.Fatalf("unexpected concatenated arguments in thinking channel done payload: %#v", payload)
}
seen[name] = true
}
if !seen["search_web"] || !seen["eval_javascript"] {
t.Fatalf("expected both tools in thinking channel done events, got %#v", seen)
}
}
func TestHandleResponsesStreamCompletedFollowsChatToolCallSemantics(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(v string) string {
b, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": v,
})
return "data: " + string(b) + "\n"
}
streamBody := sseLine("我来调用工具\n") +
sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) +
"data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
if !ok {
t.Fatalf("expected response.completed event, body=%s", rec.Body.String())
}
responseObj, _ := completed["response"].(map[string]any)
output, _ := responseObj["output"].([]any)
hasFunctionCall := false
for _, item := range output {
m, _ := item.(map[string]any)
if m != nil && m["type"] == "function_call" {
hasFunctionCall = true
break
}
}
if !hasFunctionCall {
t.Fatalf("expected completed output to include function_call when mixed prose contains tool_calls payload, output=%#v", output)
}
}
func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
scanner := bufio.NewScanner(strings.NewReader(body))
matched := false
@@ -120,3 +448,30 @@ func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
}
return nil, false
}
func extractAllSSEEventPayloads(body, targetEvent string) []map[string]any {
scanner := bufio.NewScanner(strings.NewReader(body))
matched := false
out := make([]map[string]any, 0, 2)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, "event: ") {
evt := strings.TrimSpace(strings.TrimPrefix(line, "event: "))
matched = evt == targetEvent
continue
}
if !matched || !strings.HasPrefix(line, "data: ") {
continue
}
raw := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
if raw == "" || raw == "[DONE]" {
continue
}
var payload map[string]any
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
continue
}
out = append(out, payload)
}
return out
}

View File

@@ -8,7 +8,7 @@ import (
"ds2api/internal/util"
)
func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.StandardRequest, error) {
func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID string) (util.StandardRequest, error) {
model, _ := req["model"].(string)
messagesRaw, _ := req["messages"].([]any)
if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 {
@@ -23,7 +23,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.St
if responseModel == "" {
responseModel = resolvedModel
}
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID)
passThrough := collectOpenAIChatPassThrough(req)
return util.StandardRequest{
@@ -41,7 +41,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.St
}, nil
}
func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any) (util.StandardRequest, error) {
func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, traceID string) (util.StandardRequest, error) {
model, _ := req["model"].(string)
model = strings.TrimSpace(model)
if model == "" {
@@ -67,7 +67,7 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any) (ut
if len(messagesRaw) == 0 {
return util.StandardRequest{}, fmt.Errorf("Request must include 'input' or 'messages'.")
}
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID)
passThrough := collectOpenAIChatPassThrough(req)
return util.StandardRequest{

View File

@@ -22,7 +22,7 @@ func TestNormalizeOpenAIChatRequest(t *testing.T) {
"temperature": 0.3,
"stream": true,
}
n, err := normalizeOpenAIChatRequest(store, req)
n, err := normalizeOpenAIChatRequest(store, req, "")
if err != nil {
t.Fatalf("normalize failed: %v", err)
}
@@ -47,7 +47,7 @@ func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) {
"input": "ping",
"instructions": "system",
}
n, err := normalizeOpenAIResponsesRequest(store, req)
n, err := normalizeOpenAIResponsesRequest(store, req, "")
if err != nil {
t.Fatalf("normalize failed: %v", err)
}

View File

@@ -0,0 +1,185 @@
package openai
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/go-chi/chi/v5"
chimw "github.com/go-chi/chi/v5/middleware"
"ds2api/internal/auth"
)
type streamStatusAuthStub struct{}
func (streamStatusAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) {
return &auth.RequestAuth{
UseConfigToken: false,
DeepSeekToken: "direct-token",
CallerID: "caller:test",
TriedAccounts: map[string]bool{},
}, nil
}
func (streamStatusAuthStub) DetermineCaller(_ *http.Request) (*auth.RequestAuth, error) {
return &auth.RequestAuth{
UseConfigToken: false,
DeepSeekToken: "direct-token",
CallerID: "caller:test",
TriedAccounts: map[string]bool{},
}, nil
}
func (streamStatusAuthStub) Release(_ *auth.RequestAuth) {}
type streamStatusDSStub struct {
resp *http.Response
}
func (m streamStatusDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
return "session-id", nil
}
func (m streamStatusDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
return "pow", nil
}
func (m streamStatusDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
return m.resp, nil
}
func makeOpenAISSEHTTPResponse(lines ...string) *http.Response {
body := strings.Join(lines, "\n")
if !strings.HasSuffix(body, "\n") {
body += "\n"
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}
}
func captureStatusMiddleware(statuses *[]int) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ww := chimw.NewWrapResponseWriter(w, r.ProtoMajor)
next.ServeHTTP(ww, r)
*statuses = append(*statuses, ww.Status())
})
}
}
func TestChatCompletionsStreamStatusCapturedAs200(t *testing.T) {
statuses := make([]int, 0, 1)
h := &Handler{
Store: mockOpenAIConfig{wideInput: true},
Auth: streamStatusAuthStub{},
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, "data: [DONE]")},
}
r := chi.NewRouter()
r.Use(captureStatusMiddleware(&statuses))
RegisterRoutes(r, h)
reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":"hi"}],"stream":true}`
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
req.Header.Set("Authorization", "Bearer direct-token")
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
}
if len(statuses) != 1 {
t.Fatalf("expected one captured status, got %d", len(statuses))
}
if statuses[0] != http.StatusOK {
t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0])
}
}
func TestResponsesStreamStatusCapturedAs200(t *testing.T) {
statuses := make([]int, 0, 1)
h := &Handler{
Store: mockOpenAIConfig{wideInput: true},
Auth: streamStatusAuthStub{},
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, "data: [DONE]")},
}
r := chi.NewRouter()
r.Use(captureStatusMiddleware(&statuses))
RegisterRoutes(r, h)
reqBody := `{"model":"deepseek-chat","input":"hi","stream":true}`
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody))
req.Header.Set("Authorization", "Bearer direct-token")
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
}
if len(statuses) != 1 {
t.Fatalf("expected one captured status, got %d", len(statuses))
}
if statuses[0] != http.StatusOK {
t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0])
}
}
func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) {
statuses := make([]int, 0, 1)
content, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": "我来调用工具\n{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}",
})
h := &Handler{
Store: mockOpenAIConfig{wideInput: true},
Auth: streamStatusAuthStub{},
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse("data: "+string(content), "data: [DONE]")},
}
r := chi.NewRouter()
r.Use(captureStatusMiddleware(&statuses))
RegisterRoutes(r, h)
reqBody := `{"model":"deepseek-chat","input":"请调用工具","tools":[{"type":"function","function":{"name":"read_file","description":"read","parameters":{"type":"object","properties":{"path":{"type":"string"}}}}}],"stream":false}`
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody))
req.Header.Set("Authorization", "Bearer direct-token")
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
}
if len(statuses) != 1 || statuses[0] != http.StatusOK {
t.Fatalf("expected captured status 200, got %#v", statuses)
}
var out map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String())
}
outputText, _ := out["output_text"].(string)
if outputText != "" {
t.Fatalf("expected output_text hidden for tool call payload, got %q", outputText)
}
output, _ := out["output"].([]any)
hasFunctionCall := false
for _, item := range output {
m, _ := item.(map[string]any)
if m != nil && m["type"] == "function_call" {
hasFunctionCall = true
break
}
}
if !hasFunctionCall {
t.Fatalf("expected function_call output item, got %#v", output)
}
}

View File

@@ -11,6 +11,7 @@ type toolStreamSieveState struct {
capture strings.Builder
capturing bool
recentTextTail string
disableDeltas bool
toolNameSent bool
toolName string
toolArgsStart int
@@ -35,6 +36,7 @@ const toolSieveCaptureLimit = 8 * 1024
const toolSieveContextTailLimit = 256
func (s *toolStreamSieveState) resetIncrementalToolState() {
s.disableDeltas = false
s.toolNameSent = false
s.toolName = ""
s.toolArgsStart = -1
@@ -239,17 +241,8 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
}
parsed := util.ParseStandaloneToolCalls(obj, toolNames)
if len(parsed) == 0 {
if state.toolNameSent {
return prefixPart, nil, suffixPart, true
}
return captured, nil, "", true
}
if state.toolNameSent {
if len(parsed) > 1 {
return prefixPart, parsed[1:], suffixPart, true
}
return prefixPart, nil, suffixPart, true
}
return prefixPart, parsed, suffixPart, true
}
@@ -296,6 +289,9 @@ func extractJSONObjectFrom(text string, start int) (string, int, bool) {
}
func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
if state.disableDeltas {
return nil
}
captured := state.capture.String()
if captured == "" {
return nil
@@ -312,6 +308,16 @@ func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
if insideCodeFence(state.recentTextTail + captured[:start]) {
return nil
}
certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx)
if hasMultiple {
state.disableDeltas = true
return nil
}
if !certainSingle {
// In uncertain phases (e.g. first call arrived but array not closed yet),
// avoid speculative deltas and wait for final parsed tool_calls payload.
return nil
}
callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
if !ok {
return nil
@@ -363,6 +369,68 @@ func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
return deltas
}
func classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) {
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
if !ok {
return false, false
}
i := skipSpaces(text, arrStart+1)
if i >= len(text) || text[i] != '{' {
return false, false
}
count := 0
depth := 0
quote := byte(0)
escaped := false
for ; i < len(text); i++ {
ch := text[i]
if quote != 0 {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == quote {
quote = 0
}
continue
}
if ch == '"' || ch == '\'' {
quote = ch
continue
}
if ch == '{' {
if depth == 0 {
count++
if count > 1 {
return false, true
}
}
depth++
continue
}
if ch == '}' {
if depth > 0 {
depth--
}
continue
}
if ch == ',' && depth == 0 {
// top-level separator means at least one more tool call exists
// (or is expected). Treat as multi-call and stop incremental deltas.
return false, true
}
if ch == ']' && depth == 0 {
return count == 1, false
}
}
// array not closed yet: still uncertain whether more calls will appear
return false, false
}
func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) {
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
if !ok {

View File

@@ -0,0 +1,21 @@
package openai
import (
"net/http"
"strings"
"github.com/go-chi/chi/v5/middleware"
)
func requestTraceID(r *http.Request) string {
if r == nil {
return ""
}
if q := strings.TrimSpace(r.URL.Query().Get("__trace_id")); q != "" {
return q
}
if h := strings.TrimSpace(r.Header.Get("X-Ds2-Test-Trace")); h != "" {
return h
}
return strings.TrimSpace(middleware.GetReqID(r.Context()))
}

View File

@@ -0,0 +1,47 @@
package openai
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5/middleware"
)
func traceIDViaMiddleware(req *http.Request) string {
if req == nil {
return requestTraceID(nil)
}
var got string
h := middleware.RequestID(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
got = requestTraceID(r)
}))
h.ServeHTTP(httptest.NewRecorder(), req)
return got
}
func TestRequestTraceIDPriority(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions?__trace_id=query-trace", nil)
req.Header.Set("X-Ds2-Test-Trace", "header-trace")
got := traceIDViaMiddleware(req)
if got != "query-trace" {
t.Fatalf("expected query trace id to win, got %q", got)
}
}
func TestRequestTraceIDHeaderFallback(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil)
req.Header.Set("X-Ds2-Test-Trace", "header-trace")
got := traceIDViaMiddleware(req)
if got != "header-trace" {
t.Fatalf("expected header trace id to win when query missing, got %q", got)
}
}
func TestRequestTraceIDReqIDFallback(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil)
got := traceIDViaMiddleware(req)
if got == "" {
t.Fatal("expected middleware request id fallback to be non-empty")
}
}

View File

@@ -56,7 +56,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
return
}
stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return

View File

@@ -36,5 +36,7 @@ func RegisterRoutes(r chi.Router, h *Handler) {
pr.Post("/vercel/sync", h.syncVercel)
pr.Get("/vercel/status", h.vercelStatus)
pr.Get("/export", h.exportConfig)
pr.Get("/dev/captures", h.getDevCaptures)
pr.Delete("/dev/captures", h.clearDevCaptures)
})
}

View File

@@ -0,0 +1,26 @@
package admin
import (
"net/http"
"ds2api/internal/devcapture"
)
func (h *Handler) getDevCaptures(w http.ResponseWriter, _ *http.Request) {
store := devcapture.Global()
writeJSON(w, http.StatusOK, map[string]any{
"enabled": store.Enabled(),
"limit": store.Limit(),
"max_body_bytes": store.MaxBodyBytes(),
"items": store.Snapshot(),
})
}
func (h *Handler) clearDevCaptures(w http.ResponseWriter, _ *http.Request) {
store := devcapture.Global()
store.Clear()
writeJSON(w, http.StatusOK, map[string]any{
"success": true,
"detail": "capture logs cleared",
})
}

View File

@@ -0,0 +1,45 @@
package admin
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func TestGetDevCapturesShape(t *testing.T) {
h := &Handler{}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/admin/dev/captures", nil)
h.getDevCaptures(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
}
var out map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
t.Fatalf("decode failed: %v", err)
}
if _, ok := out["enabled"]; !ok {
t.Fatalf("expected enabled field, got %#v", out)
}
if _, ok := out["items"]; !ok {
t.Fatalf("expected items field, got %#v", out)
}
}
func TestClearDevCapturesShape(t *testing.T) {
h := &Handler{}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodDelete, "/admin/dev/captures", nil)
h.clearDevCaptures(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
}
var out map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
t.Fatalf("decode failed: %v", err)
}
if out["success"] != true {
t.Fatalf("expected success=true, got %#v", out)
}
}

View File

@@ -187,7 +187,12 @@ func extractCallerToken(req *http.Request) string {
return token
}
}
return strings.TrimSpace(req.Header.Get("x-api-key"))
if key := strings.TrimSpace(req.Header.Get("x-api-key")); key != "" {
return key
}
// Gemini AI Studio compatibility: allow query key fallback only when no
// header-based credential is present.
return strings.TrimSpace(req.URL.Query().Get("key"))
}
func callerTokenID(token string) string {

View File

@@ -114,6 +114,40 @@ func TestDetermineMissingToken(t *testing.T) {
}
}
func TestDetermineWithQueryKeyUsesDirectToken(t *testing.T) {
r := newTestResolver(t)
req, _ := http.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent?key=direct-query-key", nil)
a, err := r.Determine(req)
if err != nil {
t.Fatalf("determine failed: %v", err)
}
if a.UseConfigToken {
t.Fatalf("expected direct token mode")
}
if a.DeepSeekToken != "direct-query-key" {
t.Fatalf("unexpected token: %q", a.DeepSeekToken)
}
}
func TestDetermineHeaderTokenPrecedenceOverQueryKey(t *testing.T) {
r := newTestResolver(t)
req, _ := http.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent?key=query-key", nil)
req.Header.Set("x-api-key", "managed-key")
a, err := r.Determine(req)
if err != nil {
t.Fatalf("determine failed: %v", err)
}
defer r.Release(a)
if !a.UseConfigToken {
t.Fatalf("expected managed key mode from header token")
}
if a.AccountID == "" {
t.Fatalf("expected managed account to be acquired")
}
}
func TestDetermineCallerMissingToken(t *testing.T) {
r := newTestResolver(t)
req, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil)

View File

@@ -16,6 +16,7 @@ import (
"ds2api/internal/auth"
"ds2api/internal/config"
trans "ds2api/internal/deepseek/transport"
"ds2api/internal/devcapture"
"ds2api/internal/util"
"github.com/andybalholm/brotli"
@@ -27,6 +28,7 @@ var intFrom = util.IntFrom
type Client struct {
Store *config.Store
Auth *auth.Resolver
capture *devcapture.Store
regular trans.Doer
stream trans.Doer
fallback *http.Client
@@ -39,6 +41,7 @@ func NewClient(store *config.Store, resolver *auth.Resolver) *Client {
return &Client{
Store: store,
Auth: resolver,
capture: devcapture.Global(),
regular: trans.New(60 * time.Second),
stream: trans.New(0),
fallback: &http.Client{Timeout: 60 * time.Second},
@@ -179,6 +182,7 @@ func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payloa
}
headers := c.authHeaders(a.DeepSeekToken)
headers["x-ds-pow-response"] = powResp
captureSession := c.capture.Start("deepseek_completion", DeepSeekCompletionURL, a.AccountID, payload)
attempts := 0
for attempts < maxAttempts {
resp, err := c.streamPost(ctx, DeepSeekCompletionURL, headers, payload)
@@ -188,8 +192,14 @@ func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payloa
continue
}
if resp.StatusCode == http.StatusOK {
if captureSession != nil {
resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
}
return resp, nil
}
if captureSession != nil {
resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
}
_ = resp.Body.Close()
attempts++
time.Sleep(time.Second)

View File

@@ -0,0 +1,259 @@
package devcapture
import (
"encoding/json"
"fmt"
"io"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
const (
defaultLimit = 5
defaultMaxBodyBytes = 2 * 1024 * 1024
maxLimit = 50
)
type Entry struct {
ID string `json:"id"`
CreatedAt int64 `json:"created_at"`
Label string `json:"label"`
URL string `json:"url"`
AccountID string `json:"account_id,omitempty"`
StatusCode int `json:"status_code"`
RequestBody string `json:"request_body"`
ResponseBody string `json:"response_body"`
ResponseTruncated bool `json:"response_truncated"`
}
type Store struct {
mu sync.Mutex
enabled bool
limit int
maxBodyBytes int
items []Entry
}
type Session struct {
store *Store
id string
createdAt int64
label string
url string
accountID string
requestRaw string
}
type captureBody struct {
rc io.ReadCloser
s *Session
statusCode int
buf strings.Builder
truncated bool
finalized bool
}
var (
globalOnce sync.Once
globalInst *Store
)
func Global() *Store {
globalOnce.Do(func() {
globalInst = NewFromEnv()
})
return globalInst
}
func NewFromEnv() *Store {
enabled := !isVercelRuntime()
if raw, ok := os.LookupEnv("DS2API_DEV_PACKET_CAPTURE"); ok {
enabled = parseBool(raw)
}
limit := parseIntWithDefault(os.Getenv("DS2API_DEV_PACKET_CAPTURE_LIMIT"), defaultLimit)
if limit < 1 {
limit = defaultLimit
}
if limit > maxLimit {
limit = maxLimit
}
maxBodyBytes := parseIntWithDefault(os.Getenv("DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES"), defaultMaxBodyBytes)
if maxBodyBytes < 1024 {
maxBodyBytes = defaultMaxBodyBytes
}
return &Store{
enabled: enabled,
limit: limit,
maxBodyBytes: maxBodyBytes,
items: make([]Entry, 0, limit),
}
}
func isVercelRuntime() bool {
return strings.TrimSpace(os.Getenv("VERCEL")) != "" || strings.TrimSpace(os.Getenv("NOW_REGION")) != ""
}
func (s *Store) Enabled() bool {
if s == nil {
return false
}
return s.enabled
}
func (s *Store) Limit() int {
if s == nil {
return defaultLimit
}
return s.limit
}
func (s *Store) MaxBodyBytes() int {
if s == nil {
return defaultMaxBodyBytes
}
return s.maxBodyBytes
}
func (s *Store) Snapshot() []Entry {
if s == nil {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
out := make([]Entry, len(s.items))
copy(out, s.items)
return out
}
func (s *Store) Clear() {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.items = s.items[:0]
}
func (s *Store) Start(label, url, accountID string, requestPayload any) *Session {
if s == nil || !s.enabled {
return nil
}
return &Session{
store: s,
id: "cap_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
createdAt: time.Now().Unix(),
label: strings.TrimSpace(label),
url: strings.TrimSpace(url),
accountID: strings.TrimSpace(accountID),
requestRaw: marshalPayload(requestPayload),
}
}
func (s *Session) WrapBody(rc io.ReadCloser, statusCode int) io.ReadCloser {
if s == nil || rc == nil {
return rc
}
return &captureBody{
rc: rc,
s: s,
statusCode: statusCode,
}
}
func (c *captureBody) Read(p []byte) (int, error) {
n, err := c.rc.Read(p)
if n > 0 {
c.append(string(p[:n]))
}
if err == io.EOF {
c.finalize()
}
return n, err
}
func (c *captureBody) Close() error {
err := c.rc.Close()
c.finalize()
return err
}
func (c *captureBody) append(chunk string) {
if chunk == "" || c.s == nil || c.s.store == nil {
return
}
maxLen := c.s.store.maxBodyBytes
current := c.buf.Len()
if current >= maxLen {
c.truncated = true
return
}
remain := maxLen - current
if len(chunk) > remain {
c.buf.WriteString(chunk[:remain])
c.truncated = true
return
}
c.buf.WriteString(chunk)
}
func (c *captureBody) finalize() {
if c.finalized || c.s == nil || c.s.store == nil {
return
}
c.finalized = true
entry := Entry{
ID: c.s.id,
CreatedAt: c.s.createdAt,
Label: c.s.label,
URL: c.s.url,
AccountID: c.s.accountID,
StatusCode: c.statusCode,
RequestBody: c.s.requestRaw,
ResponseBody: c.buf.String(),
ResponseTruncated: c.truncated,
}
c.s.store.push(entry)
}
func (s *Store) push(entry Entry) {
s.mu.Lock()
defer s.mu.Unlock()
s.items = append([]Entry{entry}, s.items...)
if len(s.items) > s.limit {
s.items = s.items[:s.limit]
}
}
func marshalPayload(v any) string {
b, err := json.Marshal(v)
if err != nil {
return fmt.Sprintf("%v", v)
}
return string(b)
}
func parseBool(v string) bool {
switch strings.ToLower(strings.TrimSpace(v)) {
case "1", "true", "yes", "on":
return true
default:
return false
}
}
func parseIntWithDefault(raw string, d int) int {
raw = strings.TrimSpace(raw)
if raw == "" {
return d
}
n, err := strconv.Atoi(raw)
if err != nil {
return d
}
return n
}

View File

@@ -0,0 +1,55 @@
package devcapture
import (
"io"
"strings"
"testing"
)
func TestStorePushKeepsNewestWithinLimit(t *testing.T) {
s := &Store{enabled: true, limit: 2, maxBodyBytes: 1024}
for i := 0; i < 3; i++ {
session := s.Start("test", "http://x", "", map[string]any{"seq": i})
if session == nil {
t.Fatal("expected session")
}
rc := session.WrapBody(io.NopCloser(strings.NewReader("ok")), 200)
_, _ = io.ReadAll(rc)
_ = rc.Close()
}
items := s.Snapshot()
if len(items) != 2 {
t.Fatalf("expected 2 items, got %d", len(items))
}
if !strings.Contains(items[0].RequestBody, `"seq":2`) {
t.Fatalf("expected newest first, got %#v", items[0].RequestBody)
}
if !strings.Contains(items[1].RequestBody, `"seq":1`) {
t.Fatalf("expected second newest, got %#v", items[1].RequestBody)
}
}
func TestWrapBodyTruncatesByLimit(t *testing.T) {
s := &Store{enabled: true, limit: 5, maxBodyBytes: 4}
session := s.Start("test", "http://x", "acc1", map[string]any{"x": 1})
if session == nil {
t.Fatal("expected session")
}
rc := session.WrapBody(io.NopCloser(strings.NewReader("abcdef")), 200)
_, _ = io.ReadAll(rc)
_ = rc.Close()
items := s.Snapshot()
if len(items) != 1 {
t.Fatalf("expected 1 item, got %d", len(items))
}
if items[0].ResponseBody != "abcd" {
t.Fatalf("expected truncated body, got %q", items[0].ResponseBody)
}
if !items[0].ResponseTruncated {
t.Fatal("expected truncated flag true")
}
if items[0].AccountID != "acc1" {
t.Fatalf("expected account id, got %q", items[0].AccountID)
}
}

View File

@@ -1,6 +1,7 @@
package openai
import (
"encoding/json"
"strings"
"time"
@@ -43,36 +44,45 @@ func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalT
}
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
// Align responses tool-call semantics with chat/completions:
// mixed prose + tool_call payloads should still be interpreted as tool calls.
detected := util.ParseToolCalls(finalText, toolNames)
if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
detected = util.ParseToolCalls(finalThinking, toolNames)
}
exposedOutputText := finalText
output := make([]any, 0, 2)
if len(detected) > 0 {
exposedOutputText = ""
toolCalls := make([]any, 0, len(detected))
for _, tc := range detected {
toolCalls = append(toolCalls, map[string]any{
"type": "tool_call",
"name": tc.Name,
"arguments": tc.Input,
if strings.TrimSpace(finalThinking) != "" {
output = append(output, map[string]any{
"type": "reasoning",
"text": finalThinking,
})
}
formatted := util.FormatOpenAIToolCalls(detected)
output = append(output, toResponsesFunctionCallItems(formatted)...)
output = append(output, map[string]any{
"type": "tool_calls",
"tool_calls": toolCalls,
"tool_calls": formatted,
})
} else {
content := []any{
map[string]any{
"type": "output_text",
"text": finalText,
},
}
content := make([]any, 0, 2)
if finalThinking != "" {
content = append([]any{map[string]any{
"type": "reasoning",
"text": finalThinking,
}}, content...)
}
if strings.TrimSpace(finalText) != "" {
content = append(content, map[string]any{
"type": "output_text",
"text": finalText,
})
}
if strings.TrimSpace(finalText) == "" && strings.TrimSpace(finalThinking) != "" {
exposedOutputText = finalThinking
}
output = append(output, map[string]any{
"type": "message",
"id": "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
@@ -100,6 +110,54 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex
}
}
func toResponsesFunctionCallItems(toolCalls []map[string]any) []any {
if len(toolCalls) == 0 {
return nil
}
out := make([]any, 0, len(toolCalls))
for _, tc := range toolCalls {
callID, _ := tc["id"].(string)
if strings.TrimSpace(callID) == "" {
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
}
name := ""
args := "{}"
if fn, ok := tc["function"].(map[string]any); ok {
if n, _ := fn["name"].(string); strings.TrimSpace(n) != "" {
name = n
}
if a, _ := fn["arguments"].(string); strings.TrimSpace(a) != "" {
args = a
}
}
out = append(out, map[string]any{
"id": "fc_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
"type": "function_call",
"call_id": callID,
"name": name,
"arguments": normalizeJSONString(args),
"status": "completed",
})
}
return out
}
func normalizeJSONString(raw string) string {
s := strings.TrimSpace(raw)
if s == "" {
return "{}"
}
var v any
if err := json.Unmarshal([]byte(s), &v); err != nil {
return raw
}
b, err := json.Marshal(v)
if err != nil {
return raw
}
return string(b)
}
func BuildChatStreamDeltaChoice(index int, delta map[string]any) map[string]any {
return map[string]any{
"delta": delta,
@@ -145,49 +203,105 @@ func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any
func BuildResponsesCreatedPayload(responseID, model string) map[string]any {
return map[string]any{
"type": "response.created",
"id": responseID,
"object": "response",
"model": model,
"status": "in_progress",
"type": "response.created",
"id": responseID,
"response_id": responseID,
"object": "response",
"model": model,
"status": "in_progress",
}
}
func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any {
return map[string]any{
"type": "response.output_text.delta",
"id": responseID,
"delta": delta,
"type": "response.output_text.delta",
"id": responseID,
"response_id": responseID,
"delta": delta,
}
}
func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
return map[string]any{
"type": "response.reasoning.delta",
"id": responseID,
"delta": delta,
"type": "response.reasoning.delta",
"id": responseID,
"response_id": responseID,
"delta": delta,
}
}
func BuildResponsesReasoningTextDeltaPayload(responseID, itemID string, outputIndex, contentIndex int, delta string) map[string]any {
return map[string]any{
"type": "response.reasoning_text.delta",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"content_index": contentIndex,
"delta": delta,
}
}
func BuildResponsesReasoningTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any {
return map[string]any{
"type": "response.reasoning_text.done",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"content_index": contentIndex,
"text": text,
}
}
func BuildResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any {
return map[string]any{
"type": "response.output_tool_call.delta",
"id": responseID,
"tool_calls": toolCalls,
"type": "response.output_tool_call.delta",
"id": responseID,
"response_id": responseID,
"tool_calls": toolCalls,
}
}
func BuildResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any {
return map[string]any{
"type": "response.output_tool_call.done",
"id": responseID,
"tool_calls": toolCalls,
"type": "response.output_tool_call.done",
"id": responseID,
"response_id": responseID,
"tool_calls": toolCalls,
}
}
func BuildResponsesFunctionCallArgumentsDeltaPayload(responseID, itemID string, outputIndex int, callID, delta string) map[string]any {
return map[string]any{
"type": "response.function_call_arguments.delta",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"call_id": callID,
"delta": delta,
}
}
func BuildResponsesFunctionCallArgumentsDonePayload(responseID, itemID string, outputIndex int, callID, name, arguments string) map[string]any {
return map[string]any{
"type": "response.function_call_arguments.done",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"call_id": callID,
"name": name,
"arguments": normalizeJSONString(arguments),
}
}
func BuildResponsesCompletedPayload(response map[string]any) map[string]any {
responseID, _ := response["id"].(string)
return map[string]any{
"type": "response.completed",
"response": response,
"type": "response.completed",
"response_id": responseID,
"response": response,
}
}

View File

@@ -0,0 +1,181 @@
package openai
import (
"encoding/json"
"testing"
)
func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
obj := BuildResponseObject(
"resp_test",
"gpt-4o",
"prompt",
"",
`{"tool_calls":[{"name":"search","input":{"q":"golang"}}]}`,
[]string{"search"},
)
outputText, _ := obj["output_text"].(string)
if outputText != "" {
t.Fatalf("expected output_text to be hidden for tool calls, got %q", outputText)
}
output, _ := obj["output"].([]any)
if len(output) != 2 {
t.Fatalf("expected function_call + tool_calls wrapper, got %#v", obj["output"])
}
first, _ := output[0].(map[string]any)
if first["type"] != "function_call" {
t.Fatalf("expected first output item type function_call, got %#v", first["type"])
}
if first["call_id"] == "" {
t.Fatalf("expected function_call item to have call_id, got %#v", first)
}
second, _ := output[1].(map[string]any)
if second["type"] != "tool_calls" {
t.Fatalf("expected second output item type tool_calls, got %#v", second["type"])
}
var toolCalls []map[string]any
switch v := second["tool_calls"].(type) {
case []map[string]any:
toolCalls = v
case []any:
toolCalls = make([]map[string]any, 0, len(v))
for _, item := range v {
m, _ := item.(map[string]any)
if m != nil {
toolCalls = append(toolCalls, m)
}
}
}
if len(toolCalls) != 1 {
t.Fatalf("expected one tool call, got %#v", second["tool_calls"])
}
tc := toolCalls[0]
if tc["type"] != "function" || tc["id"] == "" {
t.Fatalf("unexpected tool call shape: %#v", tc)
}
fn, _ := tc["function"].(map[string]any)
if fn["name"] != "search" {
t.Fatalf("unexpected function name: %#v", fn["name"])
}
argsRaw, _ := fn["arguments"].(string)
var args map[string]any
if err := json.Unmarshal([]byte(argsRaw), &args); err != nil {
t.Fatalf("arguments should be valid json string, got=%q err=%v", argsRaw, err)
}
if args["q"] != "golang" {
t.Fatalf("unexpected arguments: %#v", args)
}
}
func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T) {
obj := BuildResponseObject(
"resp_test",
"gpt-4o",
"prompt",
"",
`示例格式:{"tool_calls":[{"name":"search","input":{"q":"golang"}}]},但这条是普通回答。`,
[]string{"search"},
)
outputText, _ := obj["output_text"].(string)
if outputText != "" {
t.Fatalf("expected output_text hidden once tool calls are detected, got %q", outputText)
}
output, _ := obj["output"].([]any)
if len(output) != 2 {
t.Fatalf("expected function_call + tool_calls wrapper, got %#v", obj["output"])
}
first, _ := output[0].(map[string]any)
if first["type"] != "function_call" {
t.Fatalf("expected first output type function_call, got %#v", first["type"])
}
}
func TestBuildResponseObjectFencedToolPayloadRemainsText(t *testing.T) {
obj := BuildResponseObject(
"resp_test",
"gpt-4o",
"prompt",
"",
"```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"golang\"}}]}\n```",
[]string{"search"},
)
outputText, _ := obj["output_text"].(string)
if outputText == "" {
t.Fatalf("expected output_text preserved for fenced example")
}
output, _ := obj["output"].([]any)
if len(output) != 1 {
t.Fatalf("expected one message output item, got %#v", obj["output"])
}
first, _ := output[0].(map[string]any)
if first["type"] != "message" {
t.Fatalf("expected message output type, got %#v", first["type"])
}
}
func TestBuildResponseObjectReasoningOnlyFallsBackToOutputText(t *testing.T) {
obj := BuildResponseObject(
"resp_test",
"gpt-4o",
"prompt",
"internal thinking content",
"",
nil,
)
outputText, _ := obj["output_text"].(string)
if outputText == "" {
t.Fatalf("expected output_text fallback from reasoning when final text is empty")
}
output, _ := obj["output"].([]any)
if len(output) != 1 {
t.Fatalf("expected one output item, got %#v", obj["output"])
}
first, _ := output[0].(map[string]any)
if first["type"] != "message" {
t.Fatalf("expected output type message, got %#v", first["type"])
}
content, _ := first["content"].([]any)
if len(content) == 0 {
t.Fatalf("expected reasoning content, got %#v", first["content"])
}
block0, _ := content[0].(map[string]any)
if block0["type"] != "reasoning" {
t.Fatalf("expected first content block reasoning, got %#v", block0["type"])
}
}
func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
obj := BuildResponseObject(
"resp_test",
"gpt-4o",
"prompt",
`{"tool_calls":[{"name":"search","input":{"q":"from-thinking"}}]}`,
"",
[]string{"search"},
)
output, _ := obj["output"].([]any)
if len(output) != 3 {
t.Fatalf("expected reasoning + function_call + tool_calls outputs, got %#v", obj["output"])
}
first, _ := output[0].(map[string]any)
if first["type"] != "reasoning" {
t.Fatalf("expected first output reasoning, got %#v", first["type"])
}
second, _ := output[1].(map[string]any)
if second["type"] != "function_call" {
t.Fatalf("expected second output function_call, got %#v", second["type"])
}
third, _ := output[2].(map[string]any)
if third["type"] != "tool_calls" {
t.Fatalf("expected third output tool_calls, got %#v", third["type"])
}
}

View File

@@ -12,6 +12,7 @@ import (
"ds2api/internal/account"
"ds2api/internal/adapter/claude"
"ds2api/internal/adapter/gemini"
"ds2api/internal/adapter/openai"
"ds2api/internal/admin"
"ds2api/internal/auth"
@@ -44,6 +45,7 @@ func NewApp() *App {
openaiHandler := &openai.Handler{Store: store, Auth: resolver, DS: dsClient}
claudeHandler := &claude.Handler{Store: store, Auth: resolver, DS: dsClient}
geminiHandler := &gemini.Handler{Store: store, Auth: resolver, DS: dsClient}
adminHandler := &admin.Handler{Store: store, Pool: pool, DS: dsClient}
webuiHandler := webui.NewHandler()
@@ -67,6 +69,7 @@ func NewApp() *App {
})
openai.RegisterRoutes(r, openaiHandler)
claude.RegisterRoutes(r, claudeHandler)
gemini.RegisterRoutes(r, geminiHandler)
r.Route("/admin", func(ar chi.Router) {
admin.RegisterRoutes(ar, adminHandler)
})

8
tests/scripts/run-unit-all.sh Executable file
View File

@@ -0,0 +1,8 @@
#!/usr/bin/env bash
set -euo pipefail
ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
cd "$ROOT_DIR"
./tests/scripts/run-unit-go.sh
./tests/scripts/run-unit-node.sh

7
tests/scripts/run-unit-go.sh Executable file
View File

@@ -0,0 +1,7 @@
#!/usr/bin/env bash
set -euo pipefail
ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
cd "$ROOT_DIR"
go test ./... "$@"

7
tests/scripts/run-unit-node.sh Executable file
View File

@@ -0,0 +1,7 @@
#!/usr/bin/env bash
set -euo pipefail
ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
cd "$ROOT_DIR"
node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js "$@"

View File

@@ -38,6 +38,18 @@
"source": "/admin/config",
"destination": "/api/index"
},
{
"source": "/admin/config/(.*)",
"destination": "/api/index"
},
{
"source": "/admin/settings",
"destination": "/api/index"
},
{
"source": "/admin/settings/(.*)",
"destination": "/api/index"
},
{
"source": "/admin/keys(.*)",
"destination": "/api/index"

View File

@@ -1,4 +1,4 @@
import { useState, useEffect } from 'react'
import { useState, useEffect, useCallback, useMemo } from 'react'
import {
Routes,
Route,
@@ -29,12 +29,12 @@ import Login from './components/Login'
import LandingPage from './components/LandingPage'
import LanguageToggle from './components/LanguageToggle'
import { useI18n } from './i18n'
import { detectRuntimeEnv } from './utils/runtimeEnv'
function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message, onForceLogout }) {
function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message, onForceLogout, isVercel }) {
const { t } = useI18n()
const [activeTab, setActiveTab] = useState('accounts')
const [sidebarOpen, setSidebarOpen] = useState(false)
const [loading, setLoading] = useState(false)
const navItems = [
{ id: 'accounts', label: t('nav.accounts.label'), icon: Users, description: t('nav.accounts.desc') },
@@ -44,7 +44,7 @@ function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message,
{ id: 'settings', label: t('nav.settings.label'), icon: SettingsIcon, description: t('nav.settings.desc') },
]
const authFetch = async (url, options = {}) => {
const authFetch = useCallback(async (url, options = {}) => {
const headers = {
...options.headers,
'Authorization': `Bearer ${token}`
@@ -56,7 +56,7 @@ function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message,
throw new Error(t('auth.expired'))
}
return res
}
}, [onLogout, t, token])
const renderTab = () => {
switch (activeTab) {
@@ -67,9 +67,9 @@ function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message,
case 'import':
return <BatchImport onRefresh={fetchConfig} onMessage={showMessage} authFetch={authFetch} />
case 'vercel':
return <VercelSync onMessage={showMessage} authFetch={authFetch} />
return <VercelSync onMessage={showMessage} authFetch={authFetch} isVercel={isVercel} />
case 'settings':
return <Settings onRefresh={fetchConfig} onMessage={showMessage} authFetch={authFetch} onForceLogout={onForceLogout} />
return <Settings onRefresh={fetchConfig} onMessage={showMessage} authFetch={authFetch} onForceLogout={onForceLogout} isVercel={isVercel} />
default:
return null
}
@@ -213,13 +213,27 @@ export default function App() {
const navigate = useNavigate()
const location = useLocation()
const [config, setConfig] = useState({ keys: [], accounts: [] })
const [loading, setLoading] = useState(true)
const [message, setMessage] = useState(null)
const [token, setToken] = useState(null)
const [authChecking, setAuthChecking] = useState(true)
const isProduction = import.meta.env.MODE === 'production'
const isAdminRoute = location.pathname.startsWith('/admin') || isProduction
const runtimeEnv = useMemo(() => detectRuntimeEnv(), [])
const isVercel = runtimeEnv.isVercel
const showMessage = useCallback((type, text) => {
setMessage({ type, text })
setTimeout(() => setMessage(null), 5000)
}, [])
const handleLogout = useCallback(() => {
setToken(null)
localStorage.removeItem('ds2api_token')
localStorage.removeItem('ds2api_token_expires')
sessionStorage.removeItem('ds2api_token')
sessionStorage.removeItem('ds2api_token_expires')
}, [])
useEffect(() => {
// Only check auth status on admin routes.
@@ -249,12 +263,11 @@ export default function App() {
setAuthChecking(false)
}
checkAuth()
}, [isAdminRoute])
}, [handleLogout, isAdminRoute])
const fetchConfig = async () => {
const fetchConfig = useCallback(async () => {
if (!token) return
try {
setLoading(true)
const res = await fetch('/admin/config', {
headers: { 'Authorization': `Bearer ${token}` }
})
@@ -265,34 +278,19 @@ export default function App() {
} catch (e) {
console.error('Failed to fetch config:', e)
showMessage('error', t('errors.fetchConfig', { error: e.message }))
} finally {
setLoading(false)
}
}
}, [showMessage, t, token])
useEffect(() => {
if (token) {
fetchConfig()
}
}, [token])
const showMessage = (type, text) => {
setMessage({ type, text })
setTimeout(() => setMessage(null), 5000)
}
}, [fetchConfig, token])
const handleLogin = (newToken) => {
setToken(newToken)
}
const handleLogout = () => {
setToken(null)
localStorage.removeItem('ds2api_token')
localStorage.removeItem('ds2api_token_expires')
sessionStorage.removeItem('ds2api_token')
sessionStorage.removeItem('ds2api_token_expires')
}
// Wait for auth checks on admin routes.
if (isAdminRoute && authChecking) {
return (
@@ -320,6 +318,7 @@ export default function App() {
showMessage={showMessage}
message={message}
onForceLogout={handleLogout}
isVercel={isVercel}
/>
) : (
<div className="min-h-screen flex flex-col bg-background relative overflow-hidden">

View File

@@ -2,7 +2,9 @@ import { useCallback, useEffect, useMemo, useState } from 'react'
import { AlertTriangle, Download, Lock, Save, Upload } from 'lucide-react'
import { useI18n } from '../i18n'
export default function Settings({ onRefresh, onMessage, authFetch, onForceLogout }) {
const MAX_AUTO_FETCH_FAILURES = 3
export default function Settings({ onRefresh, onMessage, authFetch, onForceLogout, isVercel = false }) {
const { t } = useI18n()
const apiFetch = authFetch || fetch
@@ -14,6 +16,9 @@ export default function Settings({ onRefresh, onMessage, authFetch, onForceLogou
const [importMode, setImportMode] = useState('merge')
const [importText, setImportText] = useState('')
const [newPassword, setNewPassword] = useState('')
const [consecutiveFailures, setConsecutiveFailures] = useState(0)
const [autoFetchPaused, setAutoFetchPaused] = useState(false)
const [lastError, setLastError] = useState('')
const [settingsMeta, setSettingsMeta] = useState({ default_password_warning: false, env_backed: false, needs_vercel_sync: false })
const [form, setForm] = useState({
@@ -43,15 +48,38 @@ export default function Settings({ onRefresh, onMessage, authFetch, onForceLogou
return parsed
}
const loadSettings = useCallback(async () => {
const parseJSONResponse = useCallback(async (res) => {
const contentType = String(res.headers.get('content-type') || '').toLowerCase()
if (!contentType.includes('application/json')) {
throw new Error(t('settings.nonJsonResponse', { status: res.status }))
}
return res.json()
}, [t])
const loadSettings = useCallback(async ({ manual = false } = {}) => {
if (isVercel && autoFetchPaused && !manual) {
return
}
setLoading(true)
try {
const res = await apiFetch('/admin/settings')
const data = await res.json()
const data = await parseJSONResponse(res)
if (!res.ok) {
onMessage('error', data.detail || t('settings.loadFailed'))
const detail = data.detail || t('settings.loadFailed')
setLastError(detail)
onMessage('error', detail)
setConsecutiveFailures((prev) => {
const next = prev + 1
if (isVercel && next >= MAX_AUTO_FETCH_FAILURES) {
setAutoFetchPaused(true)
}
return next
})
return
}
setConsecutiveFailures(0)
setAutoFetchPaused(false)
setLastError('')
setSettingsMeta({
default_password_warning: Boolean(data.admin?.default_password_warning),
env_backed: Boolean(data.env_backed),
@@ -78,18 +106,32 @@ export default function Settings({ onRefresh, onMessage, authFetch, onForceLogou
model_aliases_text: JSON.stringify(data.model_aliases || {}, null, 2),
})
} catch (e) {
onMessage('error', t('settings.loadFailed'))
const detail = e?.message || t('settings.loadFailed')
setLastError(detail)
onMessage('error', detail)
setConsecutiveFailures((prev) => {
const next = prev + 1
if (isVercel && next >= MAX_AUTO_FETCH_FAILURES) {
setAutoFetchPaused(true)
}
return next
})
// eslint-disable-next-line no-console
console.error(e)
} finally {
setLoading(false)
}
}, [apiFetch, onMessage, t])
}, [apiFetch, autoFetchPaused, isVercel, onMessage, parseJSONResponse, t])
useEffect(() => {
loadSettings()
}, [loadSettings])
const retryLoadSettings = () => {
setAutoFetchPaused(false)
loadSettings({ manual: true })
}
const saveSettings = async () => {
let claudeMapping = {}
let modelAliases = {}
@@ -228,6 +270,23 @@ export default function Settings({ onRefresh, onMessage, authFetch, onForceLogou
return (
<div className="space-y-6">
{autoFetchPaused && (
<div className="p-4 rounded-lg border border-destructive/30 bg-destructive/10 text-destructive flex items-center justify-between gap-4">
<div className="flex items-center gap-2">
<AlertTriangle className="w-4 h-4" />
<span className="text-sm">
{t('settings.autoFetchPaused', { count: consecutiveFailures, error: lastError || t('settings.loadFailed') })}
</span>
</div>
<button
type="button"
onClick={retryLoadSettings}
className="px-3 py-1.5 text-xs rounded-md border border-destructive/40 hover:bg-destructive/10"
>
{t('settings.retryLoad')}
</button>
</div>
)}
{settingsMeta.default_password_warning && (
<div className="p-4 rounded-lg border border-amber-300/30 bg-amber-500/10 text-amber-700 flex items-center gap-2">
<AlertTriangle className="w-4 h-4" />

View File

@@ -3,7 +3,15 @@ import { Cloud, ArrowRight, ExternalLink, Info, CheckCircle2, XCircle, RefreshCw
import clsx from 'clsx'
import { useI18n } from '../i18n'
export default function VercelSync({ onMessage, authFetch }) {
const MAX_POLL_FAILURES = 3
function pollDelayMs(attempt) {
if (attempt <= 0) return 15000
if (attempt === 1) return 30000
return 60000
}
export default function VercelSync({ onMessage, authFetch, isVercel = false }) {
const { t } = useI18n()
const [vercelToken, setVercelToken] = useState('')
const [projectId, setProjectId] = useState('')
@@ -12,20 +20,42 @@ export default function VercelSync({ onMessage, authFetch }) {
const [result, setResult] = useState(null)
const [preconfig, setPreconfig] = useState(null)
const [syncStatus, setSyncStatus] = useState(null)
const [pollPaused, setPollPaused] = useState(false)
const [pollFailures, setPollFailures] = useState(0)
const [nextRetryAt, setNextRetryAt] = useState(null)
const apiFetch = authFetch || fetch
const fetchSyncStatus = useCallback(async () => {
const fetchSyncStatus = useCallback(async ({ manual = false } = {}) => {
try {
const res = await apiFetch('/admin/vercel/status')
if (res.ok) {
const data = await res.json()
setSyncStatus(data)
if (!res.ok) {
throw new Error(`status ${res.status}`)
}
const data = await res.json()
setSyncStatus(data)
setPollFailures(0)
setPollPaused(false)
setNextRetryAt(null)
} catch (e) {
setPollFailures((prev) => {
const next = prev + 1
if (isVercel) {
if (next >= MAX_POLL_FAILURES) {
setPollPaused(true)
setNextRetryAt(null)
} else {
setNextRetryAt(Date.now() + pollDelayMs(next))
}
}
return next
})
if (manual) {
onMessage('error', t('vercel.networkError'))
}
console.error('Failed to fetch sync status:', e)
}
}, [apiFetch])
}, [apiFetch, isVercel, onMessage, t])
useEffect(() => {
const loadPreconfig = async () => {
@@ -43,11 +73,32 @@ export default function VercelSync({ onMessage, authFetch }) {
}
loadPreconfig()
fetchSyncStatus()
// Poll every 15s to detect config changes
const interval = setInterval(fetchSyncStatus, 15000)
return () => clearInterval(interval)
}, [fetchSyncStatus])
useEffect(() => {
if (!isVercel) {
const interval = setInterval(() => {
fetchSyncStatus()
}, 15000)
return () => clearInterval(interval)
}
if (pollPaused) {
return undefined
}
const delay = nextRetryAt && nextRetryAt > Date.now() ? nextRetryAt - Date.now() : pollDelayMs(pollFailures)
const timer = setTimeout(() => {
fetchSyncStatus()
}, Math.max(1000, delay))
return () => clearTimeout(timer)
}, [fetchSyncStatus, isVercel, nextRetryAt, pollFailures, pollPaused])
const handleManualRefresh = () => {
setPollPaused(false)
setPollFailures(0)
setNextRetryAt(null)
fetchSyncStatus({ manual: true })
}
const handleSync = async () => {
const tokenToUse = preconfig?.has_token && !vercelToken ? '__USE_PRECONFIG__' : vercelToken
@@ -122,6 +173,20 @@ export default function VercelSync({ onMessage, authFetch }) {
<p className="text-muted-foreground text-sm mt-1">
{t('vercel.description')}
</p>
{pollPaused && (
<div className="mt-2 flex flex-wrap items-center gap-2">
<p className="text-xs text-destructive">
{t('vercel.pollPaused', { count: pollFailures })}
</p>
<button
type="button"
onClick={handleManualRefresh}
className="px-2 py-1 text-xs rounded border border-border hover:bg-secondary/50"
>
{t('vercel.manualRefresh')}
</button>
</div>
)}
{syncStatus?.last_sync_time && (
<p className="text-xs text-muted-foreground/60 mt-1.5 flex items-center gap-1">
<RefreshCw className="w-3 h-3" />

View File

@@ -198,6 +198,7 @@
},
"settings": {
"loadFailed": "Failed to load settings.",
"nonJsonResponse": "Unexpected non-JSON response from server (status: {status}).",
"save": "Save settings",
"saving": "Saving...",
"saveSuccess": "Settings saved and hot reloaded.",
@@ -239,7 +240,9 @@
"exportJson": "Export JSON",
"invalidJsonField": "{field} is not a valid JSON object.",
"defaultPasswordWarning": "You are using the default admin password \"admin\". Please change it.",
"vercelSyncHint": "Configuration changed. For Vercel deployments, sync manually in Vercel Sync and redeploy."
"vercelSyncHint": "Configuration changed. For Vercel deployments, sync manually in Vercel Sync and redeploy.",
"autoFetchPaused": "Auto loading paused after {count} failures: {error}",
"retryLoad": "Retry now"
},
"login": {
"welcome": "Welcome back",
@@ -278,6 +281,8 @@
"statusNotSynced": "Not synced",
"statusNeverSynced": "Never synced",
"lastSyncTime": "Last sync: {time}",
"pollPaused": "Status polling paused after {count} failures.",
"manualRefresh": "Refresh manually",
"howItWorks": "How it works",
"steps": {
"one": "The current configuration (keys and accounts) is exported as JSON.",

View File

@@ -198,6 +198,7 @@
},
"settings": {
"loadFailed": "加载设置失败",
"nonJsonResponse": "服务端返回了非 JSON 响应(状态码:{status}",
"save": "保存设置",
"saving": "保存中...",
"saveSuccess": "设置已保存并热更新生效",
@@ -239,7 +240,9 @@
"exportJson": "导出 JSON",
"invalidJsonField": "{field} 不是有效 JSON 对象",
"defaultPasswordWarning": "当前使用默认密码 admin请尽快在此修改。",
"vercelSyncHint": "当前配置已更新。Vercel 部署请到 Vercel 同步页面手动同步并重部署。"
"vercelSyncHint": "当前配置已更新。Vercel 部署请到 Vercel 同步页面手动同步并重部署。",
"autoFetchPaused": "自动加载已暂停:连续失败 {count} 次({error}",
"retryLoad": "立即重试"
},
"login": {
"welcome": "欢迎回来",
@@ -278,6 +281,8 @@
"statusNotSynced": "未同步",
"statusNeverSynced": "从未同步",
"lastSyncTime": "上次同步: {time}",
"pollPaused": "状态轮询已暂停:连续失败 {count} 次。",
"manualRefresh": "手动刷新",
"howItWorks": "工作原理",
"steps": {
"one": "当前配置 (密钥和账号) 被导出为 JSON 字符串。",

View File

@@ -0,0 +1,13 @@
export function detectRuntimeEnv() {
const deployTarget = String(import.meta.env.VITE_DEPLOY_TARGET || '').trim().toLowerCase()
if (deployTarget === 'vercel') {
return { isVercel: true, source: 'vite_env' }
}
const host = typeof window !== 'undefined' ? String(window.location.hostname || '').toLowerCase() : ''
if (host.includes('vercel.app')) {
return { isVercel: true, source: 'hostname' }
}
return { isVercel: false, source: 'default' }
}