diff --git a/API.en.md b/API.en.md index b910723..0767f49 100644 --- a/API.en.md +++ b/API.en.md @@ -304,7 +304,7 @@ event: response.content_part.added data: {"type":"response.content_part.added","response_id":"resp_xxx","part":{"type":"output_text",...},...} event: response.output_text.delta -data: {"type":"response.output_text.delta","id":"resp_xxx","delta":"..."} +data: {"type":"response.output_text.delta","response_id":"resp_xxx","item_id":"msg_xxx","output_index":0,"content_index":0,"delta":"..."} event: response.function_call_arguments.delta data: {"type":"response.function_call_arguments.delta","response_id":"resp_xxx","call_id":"call_xxx","delta":"..."} diff --git a/API.md b/API.md index 6ec117e..c102b53 100644 --- a/API.md +++ b/API.md @@ -304,7 +304,7 @@ event: response.content_part.added data: {"type":"response.content_part.added","response_id":"resp_xxx","part":{"type":"output_text",...},...} event: response.output_text.delta -data: {"type":"response.output_text.delta","id":"resp_xxx","delta":"..."} +data: {"type":"response.output_text.delta","response_id":"resp_xxx","item_id":"msg_xxx","output_index":0,"content_index":0,"delta":"..."} event: response.function_call_arguments.delta data: {"type":"response.function_call_arguments.delta","response_id":"resp_xxx","call_id":"call_xxx","delta":"..."} diff --git a/README.en.md b/README.en.md index 1de40a7..34208e8 100644 --- a/README.en.md +++ b/README.en.md @@ -8,13 +8,13 @@ Language: [中文](README.MD) | [English](README.en.md) -DS2API converts DeepSeek Web chat capability into OpenAI-compatible and Claude-compatible APIs. The backend is a **pure Go implementation**, with a React WebUI admin panel (source in `webui/`, build output auto-generated to `static/admin` during deployment). +DS2API converts DeepSeek Web chat capability into OpenAI-compatible, Claude-compatible, and Gemini-compatible APIs. 
The backend is a **pure Go implementation**, with a React WebUI admin panel (source in `webui/`, build output auto-generated to `static/admin` during deployment). ## Architecture Overview ```mermaid flowchart LR - Client["🖥️ Clients\n(OpenAI / Claude compat)"] + Client["🖥️ Clients\n(OpenAI / Claude / Gemini compat)"] subgraph DS2API["DS2API Service"] direction TB @@ -24,6 +24,7 @@ flowchart LR subgraph Adapters["Adapter Layer"] OA["OpenAI Adapter\n/v1/*"] CA["Claude Adapter\n/anthropic/*"] + GA["Gemini Adapter\n/v1beta/models/*"] end subgraph Support["Support Modules"] @@ -38,11 +39,11 @@ flowchart LR DS["☁️ DeepSeek API"] Client -- "Request" --> CORS --> Auth - Auth --> OA & CA - OA & CA -- "Call" --> DS + Auth --> OA & CA & GA + OA & CA & GA -- "Call" --> DS Auth --> Admin - OA & CA -. "Rotate accounts" .-> Pool - OA & CA -. "Compute PoW" .-> PoW + OA & CA & GA -. "Rotate accounts" .-> Pool + OA & CA & GA -. "Compute PoW" .-> PoW DS -- "Response" --> Client ``` @@ -55,12 +56,13 @@ flowchart LR | Capability | Details | | --- | --- | | OpenAI compatible | `GET /v1/models`, `GET /v1/models/{id}`, `POST /v1/chat/completions`, `POST /v1/responses`, `GET /v1/responses/{response_id}`, `POST /v1/embeddings` | -| Claude compatible | `GET /anthropic/v1/models`, `POST /anthropic/v1/messages`, `POST /anthropic/v1/messages/count_tokens` | +| Claude compatible | `GET /anthropic/v1/models`, `POST /anthropic/v1/messages`, `POST /anthropic/v1/messages/count_tokens` (plus shortcut paths `/v1/messages`, `/messages`) | +| Gemini compatible | `POST /v1beta/models/{model}:generateContent`, `POST /v1beta/models/{model}:streamGenerateContent` (plus `/v1/models/{model}:*` paths) | | Multi-account rotation | Auto token refresh, email/mobile dual login | | Concurrency control | Per-account in-flight limit + waiting queue, dynamic recommended concurrency | | DeepSeek PoW | WASM solving via `wazero`, no external Node.js dependency | | Tool Calling | Anti-leak handling: non-code-block 
feature match, early `delta.tool_calls`, structured incremental output | -| Admin API | Config management, account testing/batch test, import/export, Vercel sync | +| Admin API | Config management, runtime settings hot-reload, account testing/batch test, import/export, Vercel sync | | WebUI Admin Panel | SPA at `/admin` (bilingual Chinese/English, dark mode) | | Health Probes | `GET /healthz` (liveness), `GET /readyz` (readiness) | @@ -72,6 +74,7 @@ flowchart LR | P0 | OpenAI SDK (JS/Python, chat + responses) | ✅ | | P0 | Vercel AI SDK (openai-compatible) | ✅ | | P0 | Anthropic SDK (messages) | ✅ | +| P0 | Google Gemini SDK (generateContent) | ✅ | | P1 | LangChain / LlamaIndex / OpenWebUI (OpenAI-compatible integration) | ✅ | | P2 | MCP standalone bridge | Planned | @@ -97,6 +100,10 @@ flowchart LR Override mapping via `claude_mapping` or `claude_model_mapping` in config. In addition, `/anthropic/v1/models` now includes historical Claude 1.x/2.x/3.x/4.x IDs and common aliases for legacy client compatibility. +### Gemini Endpoint + +The Gemini adapter maps model names to DeepSeek native models via `model_aliases` or built-in heuristics, supporting both `generateContent` and `streamGenerateContent` call patterns with full Tool Calling support (`functionDeclarations` → `functionCall` output). 
+ ## Quick Start ### Universal First Step (all deployment modes) @@ -249,6 +256,14 @@ cp opencode.json.example opencode.json "claude_model_mapping": { "fast": "deepseek-chat", "slow": "deepseek-reasoner" + }, + "admin": { + "jwt_expire_hours": 24 + }, + "runtime": { + "account_max_inflight": 2, + "account_max_queue": 0, + "global_max_inflight": 0 } } ``` @@ -262,6 +277,8 @@ cp opencode.json.example opencode.json - `responses.store_ttl_seconds`: In-memory TTL for `/v1/responses/{id}` - `embeddings.provider`: Embeddings provider (`deterministic/mock/builtin` built-in) - `claude_model_mapping`: Maps `fast`/`slow` suffixes to corresponding DeepSeek models +- `admin`: Admin panel settings (JWT expiry, password hash, etc.), hot-reloadable via Admin Settings API +- `runtime`: Runtime parameters (concurrency limits, queue sizes), hot-reloadable via Admin Settings API ### Environment Variables @@ -293,7 +310,7 @@ cp opencode.json.example opencode.json ## Authentication Modes -For business endpoints (`/v1/*`, `/anthropic/*`), DS2API supports two modes: +For business endpoints (`/v1/*`, `/anthropic/*`, Gemini routes), DS2API supports two modes: | Mode | Description | | --- | --- | @@ -320,10 +337,10 @@ Queue limit = DS2API_ACCOUNT_MAX_QUEUE (default = recommended concurrency) When `tools` is present in the request, DS2API performs anti-leak handling: 1. Toolcall feature matching is enabled only in **non-code-block context** (fenced examples are ignored) -2. In `responses` stream mode, tool calls follow official item lifecycle events (`response.output_item.*`, `response.content_part.*`, `response.function_call_arguments.*`) -3. Unknown tool names (outside declared `tools`) are rejected and are not emitted as valid tool calls -4. `tool_choice` is enforced on `responses` (`auto`/`none`/`required`/forced function); required violations return HTTP `422` (non-stream) or `response.failed` (stream) -5. 
Confirmed toolcall JSON fragments are never emitted as valid tool call events unless they pass policy checks +2. `responses` streaming strictly uses official item lifecycle events (`response.output_item.*`, `response.content_part.*`, `response.function_call_arguments.*`) +3. Tool names not declared in the `tools` schema are strictly rejected and will not be emitted as valid tool calls +4. `responses` supports and enforces `tool_choice` (`auto`/`none`/`required`/forced function); `required` violations return `422` for non-stream and `response.failed` for stream +5. Valid tool call events are only emitted after passing policy validation, preventing invalid tool names from entering the client execution chain ## Local Dev Packet Capture @@ -363,13 +380,20 @@ ds2api/ │ ├── account/ # Account pool and concurrency queue │ ├── adapter/ │ │ ├── openai/ # OpenAI adapter (incl. tool call parsing, Vercel stream prepare/release) -│ │ └── claude/ # Claude adapter -│ ├── admin/ # Admin API handlers +│ │ ├── claude/ # Claude adapter +│ │ └── gemini/ # Gemini adapter (generateContent / streamGenerateContent) +│ ├── admin/ # Admin API handlers (incl. 
Settings hot-reload) │ ├── auth/ # Auth and JWT +│ ├── claudeconv/ # Claude message format conversion +│ ├── compat/ # Compatibility helpers │ ├── config/ # Config loading and hot-reload │ ├── deepseek/ # DeepSeek API client, PoW WASM +│ ├── devcapture/ # Dev packet capture module +│ ├── format/ # Output formatting +│ ├── prompt/ # Prompt construction │ ├── server/ # HTTP routing and middleware (chi router) │ ├── sse/ # SSE parsing utilities +│ ├── stream/ # Unified stream consumption engine │ ├── util/ # Common utilities │ └── webui/ # WebUI static file serving and auto-build ├── webui/ # React WebUI source (Vite + Tailwind) diff --git a/internal/adapter/openai/handler_toolcall_format.go b/internal/adapter/openai/handler_toolcall_format.go index 1e7e500..37ebaf9 100644 --- a/internal/adapter/openai/handler_toolcall_format.go +++ b/internal/adapter/openai/handler_toolcall_format.go @@ -112,14 +112,20 @@ func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, allowedNam return nil } allowed := namesToSet(allowedNames) + if len(allowed) == 0 { + for _, d := range deltas { + if d.Name != "" { + seenNames[d.Index] = "__blocked__" + } + } + return nil + } out := make([]toolCallDelta, 0, len(deltas)) for _, d := range deltas { if d.Name != "" { - if len(allowed) > 0 { - if _, ok := allowed[d.Name]; !ok { - seenNames[d.Index] = "__blocked__" - continue - } + if _, ok := allowed[d.Name]; !ok { + seenNames[d.Index] = "__blocked__" + continue } seenNames[d.Index] = d.Name out = append(out, d) diff --git a/internal/adapter/openai/message_normalize.go b/internal/adapter/openai/message_normalize.go index a767960..94b2339 100644 --- a/internal/adapter/openai/message_normalize.go +++ b/internal/adapter/openai/message_normalize.go @@ -3,6 +3,7 @@ package openai import ( "encoding/json" "fmt" + "io" "strings" "ds2api/internal/config" @@ -163,12 +164,43 @@ func 
normalizeOpenAIContentForPrompt(v any) string { func normalizeOpenAIArgumentsForPrompt(v any) string { switch x := v.(type) { case string: - return strings.TrimSpace(x) + return normalizeToolArgumentString(x) default: return marshalToPromptString(v) } } +func normalizeToolArgumentString(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + if !looksLikeConcatenatedJSON(trimmed) { + return trimmed + } + dec := json.NewDecoder(strings.NewReader(trimmed)) + values := make([]any, 0, 2) + for { + var v any + if err := dec.Decode(&v); err != nil { + if err == io.EOF { + break + } + return trimmed + } + values = append(values, v) + } + if len(values) < 2 { + return trimmed + } + last := values[len(values)-1] + b, err := json.Marshal(last) + if err != nil || len(b) == 0 { + return trimmed + } + return string(b) +} + func marshalToPromptString(v any) string { b, err := json.Marshal(v) if err != nil { diff --git a/internal/adapter/openai/message_normalize_test.go b/internal/adapter/openai/message_normalize_test.go index 30403bc..ff36bd9 100644 --- a/internal/adapter/openai/message_normalize_test.go +++ b/internal/adapter/openai/message_normalize_test.go @@ -167,3 +167,32 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSepara t.Fatalf("unexpected concatenated function arguments detected: %q", content) } } + +func TestNormalizeOpenAIMessagesForPrompt_RepairsConcatenatedToolArguments(t *testing.T) { + raw := []any{ + map[string]any{ + "role": "assistant", + "tool_calls": []any{ + map[string]any{ + "id": "call_1", + "function": map[string]any{ + "name": "search_web", + "arguments": `{}{"query":"ζ΅‹θ―•ε·₯具调用"}`, + }, + }, + }, + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw, "") + if len(normalized) != 1 { + t.Fatalf("expected one normalized message, got %d", len(normalized)) + } + content, _ := normalized[0]["content"].(string) + if !strings.Contains(content, `function.arguments: 
{"query":"ζ΅‹θ―•ε·₯具调用"}`) { + t.Fatalf("expected repaired arguments in tool history, got %q", content) + } + if strings.Contains(content, `{}{"query":"ζ΅‹θ―•ε·₯具调用"}`) { + t.Fatalf("expected concatenated JSON to be repaired, got %q", content) + } +} diff --git a/internal/adapter/openai/responses_embeddings_test.go b/internal/adapter/openai/responses_embeddings_test.go index 0f58c70..a586682 100644 --- a/internal/adapter/openai/responses_embeddings_test.go +++ b/internal/adapter/openai/responses_embeddings_test.go @@ -135,6 +135,27 @@ func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) { } } +func TestNormalizeResponsesInputAsMessagesFunctionCallItemRepairsConcatenatedArguments(t *testing.T) { + msgs := normalizeResponsesInputAsMessages([]any{ + map[string]any{ + "type": "function_call", + "call_id": "call_456", + "name": "search", + "arguments": `{}{"q":"golang"}`, + }, + }) + if len(msgs) != 1 { + t.Fatalf("expected one message, got %d", len(msgs)) + } + m, _ := msgs[0].(map[string]any) + toolCalls, _ := m["tool_calls"].([]any) + call, _ := toolCalls[0].(map[string]any) + fn, _ := call["function"].(map[string]any) + if fn["arguments"] != `{"q":"golang"}` { + t.Fatalf("expected concatenated call arguments repaired, got %#v", fn["arguments"]) + } +} + func TestExtractEmbeddingInputs(t *testing.T) { got := extractEmbeddingInputs([]any{"a", "b"}) if len(got) != 2 || got[0] != "a" || got[1] != "b" { diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index d8c59bb..81da92d 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -190,7 +190,8 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, } func logResponsesToolPolicyRejection(traceID string, policy util.ToolChoicePolicy, parsed util.ToolCallParseResult, channel string) { - if !parsed.RejectedByPolicy || len(parsed.RejectedToolNames) == 0 { + rejected := 
filteredRejectedToolNamesForLog(parsed.RejectedToolNames) + if !parsed.RejectedByPolicy || len(rejected) == 0 { return } config.Logger.Warn( @@ -198,6 +199,23 @@ func logResponsesToolPolicyRejection(traceID string, policy util.ToolChoicePolic "trace_id", strings.TrimSpace(traceID), "channel", channel, "tool_choice_mode", policy.Mode, - "rejected_tool_names", strings.Join(parsed.RejectedToolNames, ","), + "rejected_tool_names", strings.Join(rejected, ","), ) } + +func filteredRejectedToolNamesForLog(names []string) []string { + if len(names) == 0 { + return nil + } + out := make([]string, 0, len(names)) + for _, name := range names { + trimmed := strings.TrimSpace(name) + switch strings.ToLower(trimmed) { + case "", "tool_name": + continue + default: + out = append(out, trimmed) + } + } + return out +} diff --git a/internal/adapter/openai/responses_input_items.go b/internal/adapter/openai/responses_input_items.go index 81f29e8..e0eea09 100644 --- a/internal/adapter/openai/responses_input_items.go +++ b/internal/adapter/openai/responses_input_items.go @@ -188,6 +188,10 @@ func stringifyToolCallArguments(v any) string { if s == "" { return "{}" } + s = normalizeToolArgumentString(s) + if s == "" { + return "{}" + } return s default: b, err := json.Marshal(x) diff --git a/internal/adapter/openai/responses_stream_runtime_core.go b/internal/adapter/openai/responses_stream_runtime_core.go index 7ab8600..226c94d 100644 --- a/internal/adapter/openai/responses_stream_runtime_core.go +++ b/internal/adapter/openai/responses_stream_runtime_core.go @@ -37,13 +37,14 @@ type responsesStreamRuntime struct { text strings.Builder visibleText strings.Builder streamToolCallIDs map[int]string - streamFunctionIDs map[int]string + functionItemIDs map[int]string + functionOutputIDs map[int]int functionDone map[int]bool functionAdded map[int]bool functionNames map[int]string - toolCallsDoneSigs map[string]bool - reasoningItemID string messageItemID string + messageOutputID int + 
nextOutputID int messageAdded bool messagePartAdded bool sequence int @@ -81,11 +82,12 @@ func newResponsesStreamRuntime( bufferToolContent: bufferToolContent, emitEarlyToolDeltas: emitEarlyToolDeltas, streamToolCallIDs: map[int]string{}, - streamFunctionIDs: map[int]string{}, + functionItemIDs: map[int]string{}, + functionOutputIDs: map[int]int{}, functionDone: map[int]bool{}, functionAdded: map[int]bool{}, functionNames: map[int]string{}, - toolCallsDoneSigs: map[string]bool{}, + messageOutputID: -1, toolChoice: toolChoice, traceID: traceID, persistResponse: persistResponse, @@ -144,10 +146,7 @@ func (s *responsesStreamRuntime) finalize() { return } - obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames) - if s.toolCallsEmitted { - s.alignCompletedOutputCallIDs(obj) - } + obj := s.buildCompletedResponseObject(finalThinking, finalText, detected) if s.persistResponse != nil { s.persistResponse(obj) } @@ -157,7 +156,8 @@ func (s *responsesStreamRuntime) finalize() { func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingParsed util.ToolCallParseResult) { logRejected := func(parsed util.ToolCallParseResult, channel string) { - if !parsed.RejectedByPolicy || len(parsed.RejectedToolNames) == 0 { + rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames) + if !parsed.RejectedByPolicy || len(rejected) == 0 { return } config.Logger.Warn( @@ -165,7 +165,7 @@ func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingPar "trace_id", strings.TrimSpace(s.traceID), "channel", channel, "tool_choice_mode", s.toolChoice.Mode, - "rejected_tool_names", strings.Join(parsed.RejectedToolNames, ","), + "rejected_tool_names", strings.Join(rejected, ","), ) } logRejected(textParsed, "text") diff --git a/internal/adapter/openai/responses_stream_runtime_toolcalls.go b/internal/adapter/openai/responses_stream_runtime_toolcalls.go index 94859cd..c0b3057 100644 --- 
a/internal/adapter/openai/responses_stream_runtime_toolcalls.go +++ b/internal/adapter/openai/responses_stream_runtime_toolcalls.go @@ -11,6 +11,12 @@ import ( "github.com/google/uuid" ) +func (s *responsesStreamRuntime) allocateOutputIndex() int { + idx := s.nextOutputID + s.nextOutputID++ + return idx +} + func (s *responsesStreamRuntime) ensureMessageItemID() string { if strings.TrimSpace(s.messageItemID) != "" { return s.messageItemID @@ -19,11 +25,12 @@ func (s *responsesStreamRuntime) ensureMessageItemID() string { return s.messageItemID } -func (s *responsesStreamRuntime) messageOutputIndex() int { - if strings.TrimSpace(s.thinking.String()) != "" { - return 1 +func (s *responsesStreamRuntime) ensureMessageOutputIndex() int { + if s.messageOutputID >= 0 { + return s.messageOutputID } - return 0 + s.messageOutputID = s.allocateOutputIndex() + return s.messageOutputID } func (s *responsesStreamRuntime) ensureMessageItemAdded() { @@ -39,7 +46,7 @@ func (s *responsesStreamRuntime) ensureMessageItemAdded() { } s.sendEvent( "response.output_item.added", - openaifmt.BuildResponsesOutputItemAddedPayload(s.responseID, itemID, s.messageOutputIndex(), item), + openaifmt.BuildResponsesOutputItemAddedPayload(s.responseID, itemID, s.ensureMessageOutputIndex(), item), ) s.messageAdded = true } @@ -54,7 +61,7 @@ func (s *responsesStreamRuntime) ensureMessageContentPartAdded() { openaifmt.BuildResponsesContentPartAddedPayload( s.responseID, s.ensureMessageItemID(), - s.messageOutputIndex(), + s.ensureMessageOutputIndex(), 0, map[string]any{"type": "output_text", "text": ""}, ), @@ -68,7 +75,16 @@ func (s *responsesStreamRuntime) emitTextDelta(content string) { } s.ensureMessageContentPartAdded() s.visibleText.WriteString(content) - s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, content)) + s.sendEvent( + "response.output_text.delta", + openaifmt.BuildResponsesTextDeltaPayload( + s.responseID, + s.ensureMessageItemID(), + 
s.ensureMessageOutputIndex(), + 0, + content, + ), + ) } func (s *responsesStreamRuntime) closeMessageItem() { @@ -76,6 +92,7 @@ func (s *responsesStreamRuntime) closeMessageItem() { return } itemID := s.ensureMessageItemID() + outputIndex := s.ensureMessageOutputIndex() text := s.visibleText.String() if s.messagePartAdded { s.sendEvent( @@ -83,7 +100,7 @@ func (s *responsesStreamRuntime) closeMessageItem() { openaifmt.BuildResponsesContentPartDonePayload( s.responseID, itemID, - s.messageOutputIndex(), + outputIndex, 0, map[string]any{"type": "output_text", "text": text}, ), @@ -104,45 +121,35 @@ func (s *responsesStreamRuntime) closeMessageItem() { } s.sendEvent( "response.output_item.done", - openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, s.messageOutputIndex(), item), + openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, outputIndex, item), ) } -func (s *responsesStreamRuntime) ensureReasoningItemID() string { - if strings.TrimSpace(s.reasoningItemID) != "" { - return s.reasoningItemID - } - s.reasoningItemID = "rs_" + strings.ReplaceAll(uuid.NewString(), "-", "") - return s.reasoningItemID -} - -func (s *responsesStreamRuntime) ensureFunctionItemID(index int) string { - if id, ok := s.streamFunctionIDs[index]; ok && strings.TrimSpace(id) != "" { +func (s *responsesStreamRuntime) ensureFunctionItemID(callIndex int) string { + if id, ok := s.functionItemIDs[callIndex]; ok && strings.TrimSpace(id) != "" { return id } id := "fc_" + strings.ReplaceAll(uuid.NewString(), "-", "") - s.streamFunctionIDs[index] = id + s.functionItemIDs[callIndex] = id return id } -func (s *responsesStreamRuntime) ensureToolCallID(index int) string { - if id, ok := s.streamToolCallIDs[index]; ok && strings.TrimSpace(id) != "" { +func (s *responsesStreamRuntime) ensureToolCallID(callIndex int) string { + if id, ok := s.streamToolCallIDs[callIndex]; ok && strings.TrimSpace(id) != "" { return id } id := "call_" + strings.ReplaceAll(uuid.NewString(), "-", 
"") - s.streamToolCallIDs[index] = id + s.streamToolCallIDs[callIndex] = id return id } -func (s *responsesStreamRuntime) functionOutputBaseIndex() int { - if strings.TrimSpace(s.thinking.String()) != "" { - return 1 +func (s *responsesStreamRuntime) ensureFunctionOutputIndex(callIndex int) int { + if idx, ok := s.functionOutputIDs[callIndex]; ok { + return idx } - return 0 -} - -func (s *responsesStreamRuntime) functionOutputIndex(callIndex int) int { - return s.functionOutputBaseIndex() + callIndex + idx := s.allocateOutputIndex() + s.functionOutputIDs[callIndex] = idx + return idx } func (s *responsesStreamRuntime) ensureFunctionItemAdded(callIndex int, name string) { @@ -156,15 +163,15 @@ func (s *responsesStreamRuntime) ensureFunctionItemAdded(callIndex int, name str if fnName == "" { return } - outputIndex := s.functionOutputIndex(callIndex) - itemID := s.ensureFunctionItemID(outputIndex) + outputIndex := s.ensureFunctionOutputIndex(callIndex) + itemID := s.ensureFunctionItemID(callIndex) callID := s.ensureToolCallID(callIndex) item := map[string]any{ "id": itemID, "type": "function_call", "call_id": callID, "name": fnName, - "arguments": "{}", + "arguments": "", "status": "in_progress", } s.sendEvent( @@ -181,8 +188,8 @@ func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDe if strings.TrimSpace(d.Arguments) == "" { continue } - outputIndex := s.functionOutputIndex(d.Index) - itemID := s.ensureFunctionItemID(outputIndex) + outputIndex := s.ensureFunctionOutputIndex(d.Index) + itemID := s.ensureFunctionItemID(d.Index) callID := s.ensureToolCallID(d.Index) s.sendEvent( "response.function_call_arguments.delta", @@ -192,18 +199,16 @@ func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDe } func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedToolCall) { - base := s.functionOutputBaseIndex() for idx, tc := range calls { if strings.TrimSpace(tc.Name) == "" { continue } 
s.ensureFunctionItemAdded(idx, tc.Name) - - outputIndex := base + idx - if s.functionDone[outputIndex] { + if s.functionDone[idx] { continue } - itemID := s.ensureFunctionItemID(outputIndex) + outputIndex := s.ensureFunctionOutputIndex(idx) + itemID := s.ensureFunctionItemID(idx) callID := s.ensureToolCallID(idx) argsBytes, _ := json.Marshal(tc.Input) args := string(argsBytes) @@ -223,48 +228,105 @@ func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedT "response.output_item.done", openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, outputIndex, item), ) - s.functionDone[outputIndex] = true + s.functionDone[idx] = true s.toolCallsDoneEmitted = true } } -func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any) { - if obj == nil || len(s.streamToolCallIDs) == 0 { - return +func (s *responsesStreamRuntime) buildCompletedResponseObject(finalThinking, finalText string, calls []util.ParsedToolCall) map[string]any { + type indexedItem struct { + index int + item map[string]any } - output, _ := obj["output"].([]any) - if len(output) == 0 { - return - } - indices := make([]int, 0, len(s.streamToolCallIDs)) - for idx := range s.streamToolCallIDs { - indices = append(indices, idx) - } - sort.Ints(indices) - ordered := make([]string, 0, len(indices)) - for _, idx := range indices { - id := strings.TrimSpace(s.streamToolCallIDs[idx]) - if id == "" { - continue + indexed := make([]indexedItem, 0, len(calls)+1) + + if s.messageAdded { + text := s.visibleText.String() + indexed = append(indexed, indexedItem{ + index: s.ensureMessageOutputIndex(), + item: map[string]any{ + "id": s.ensureMessageItemID(), + "type": "message", + "role": "assistant", + "status": "completed", + "content": []map[string]any{ + { + "type": "output_text", + "text": text, + }, + }, + }, + }) + } else if len(calls) == 0 { + content := make([]map[string]any, 0, 2) + if strings.TrimSpace(finalThinking) != "" { + content = append(content, 
map[string]any{ + "type": "reasoning", + "text": finalThinking, + }) + } + if strings.TrimSpace(finalText) != "" { + content = append(content, map[string]any{ + "type": "output_text", + "text": finalText, + }) + } + if len(content) > 0 { + indexed = append(indexed, indexedItem{ + index: s.ensureMessageOutputIndex(), + item: map[string]any{ + "id": s.ensureMessageItemID(), + "type": "message", + "role": "assistant", + "status": "completed", + "content": content, + }, + }) } - ordered = append(ordered, id) - } - if len(ordered) == 0 { - return } - functionIdx := 0 - for _, item := range output { - m, _ := item.(map[string]any) - if m == nil { + for idx, tc := range calls { + if strings.TrimSpace(tc.Name) == "" { continue } - if m["type"] != "function_call" { - continue - } - if functionIdx < len(ordered) { - m["call_id"] = ordered[functionIdx] - functionIdx++ + argsBytes, _ := json.Marshal(tc.Input) + indexed = append(indexed, indexedItem{ + index: s.ensureFunctionOutputIndex(idx), + item: map[string]any{ + "id": s.ensureFunctionItemID(idx), + "type": "function_call", + "call_id": s.ensureToolCallID(idx), + "name": tc.Name, + "arguments": string(argsBytes), + "status": "completed", + }, + }) + } + + sort.SliceStable(indexed, func(i, j int) bool { + return indexed[i].index < indexed[j].index + }) + output := make([]any, 0, len(indexed)) + for _, it := range indexed { + output = append(output, it.item) + } + + outputText := s.visibleText.String() + if strings.TrimSpace(outputText) == "" && len(calls) == 0 { + if strings.TrimSpace(finalText) != "" { + outputText = finalText + } else if strings.TrimSpace(finalThinking) != "" { + outputText = finalThinking } } + + return openaifmt.BuildResponseObjectFromItems( + s.responseID, + s.model, + s.finalPrompt, + finalThinking, + finalText, + output, + outputText, + ) } diff --git a/internal/adapter/openai/responses_stream_test.go b/internal/adapter/openai/responses_stream_test.go index fd6a7a0..9dccecb 100644 --- 
a/internal/adapter/openai/responses_stream_test.go +++ b/internal/adapter/openai/responses_stream_test.go @@ -109,6 +109,22 @@ func TestHandleResponsesStreamUsesOfficialOutputItemEvents(t *testing.T) { t.Fatalf("legacy response.output_tool_call.* event must not appear, body=%s", body) } + addedPayloads := extractAllSSEEventPayloads(body, "response.output_item.added") + hasFunctionCallAdded := false + for _, payload := range addedPayloads { + item, _ := payload["item"].(map[string]any) + if item == nil || asString(item["type"]) != "function_call" { + continue + } + hasFunctionCallAdded = true + if asString(item["arguments"]) != "" { + t.Fatalf("expected in-progress function_call.arguments to start empty string, got %#v", item["arguments"]) + } + } + if !hasFunctionCallAdded { + t.Fatalf("expected function_call output_item.added payload, body=%s", body) + } + donePayload, ok := extractSSEEventPayload(body, "response.function_call_arguments.done") if !ok { t.Fatalf("expected to parse response.function_call_arguments.done payload, body=%s", body) @@ -213,6 +229,137 @@ func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing. 
} } +func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine("hello") + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, util.DefaultToolChoicePolicy(), "") + body := rec.Body.String() + + deltaPayload, ok := extractSSEEventPayload(body, "response.output_text.delta") + if !ok { + t.Fatalf("expected response.output_text.delta payload, body=%s", body) + } + if strings.TrimSpace(asString(deltaPayload["item_id"])) == "" { + t.Fatalf("expected non-empty item_id in output_text.delta, payload=%#v", deltaPayload) + } + if _, ok := deltaPayload["output_index"]; !ok { + t.Fatalf("expected output_index in output_text.delta, payload=%#v", deltaPayload) + } + if _, ok := deltaPayload["content_index"]; !ok { + t.Fatalf("expected content_index in output_text.delta, payload=%#v", deltaPayload) + } +} + +func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(path, value string) string { + b, _ := json.Marshal(map[string]any{ + "p": path, + "v": value, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine("response/thinking_content", "thinking...") + + sseLine("response/content", "ε…ˆθ―»ε–ζ–‡δ»Άγ€‚") + + sseLine("response/content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + 
Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "") + + addedPayloads := extractAllSSEEventPayloads(rec.Body.String(), "response.output_item.added") + if len(addedPayloads) < 2 { + t.Fatalf("expected message + function_call output_item.added events, got %d body=%s", len(addedPayloads), rec.Body.String()) + } + + indexes := map[int]struct{}{} + typeByIndex := map[int]string{} + addedIDs := map[string]string{} + for _, payload := range addedPayloads { + item, _ := payload["item"].(map[string]any) + itemType := strings.TrimSpace(asString(item["type"])) + outputIndex := int(asFloat(payload["output_index"])) + if _, exists := indexes[outputIndex]; exists { + t.Fatalf("found duplicated output_index=%d for item types=%q and %q payload=%#v", outputIndex, typeByIndex[outputIndex], itemType, payload) + } + indexes[outputIndex] = struct{}{} + typeByIndex[outputIndex] = itemType + addedIDs[itemType] = strings.TrimSpace(asString(payload["item_id"])) + } + + completedPayload, ok := extractSSEEventPayload(rec.Body.String(), "response.completed") + if !ok { + t.Fatalf("expected response.completed payload, body=%s", rec.Body.String()) + } + responseObj, _ := completedPayload["response"].(map[string]any) + output, _ := responseObj["output"].([]any) + found := map[string]bool{} + for _, item := range output { + m, _ := item.(map[string]any) + itemType := strings.TrimSpace(asString(m["type"])) + itemID := strings.TrimSpace(asString(m["id"])) + if itemType == "" || itemID == "" { + continue + } + if wantID := strings.TrimSpace(addedIDs[itemType]); wantID != "" && wantID == itemID { + found[itemType] = true + } + } + if !found["message"] || !found["function_call"] { + t.Fatalf("expected completed output to contain streamed message/function_call item ids, found=%#v output=%#v", found, output) + } +} + +func 
TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + policy := util.ToolChoicePolicy{Mode: util.ToolChoiceNone} + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, policy, "") + body := rec.Body.String() + if strings.Contains(body, "event: response.function_call_arguments.done") { + t.Fatalf("did not expect function_call events for tool_choice=none, body=%s", body) + } +} + func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) { h := &Handler{} req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) @@ -299,6 +446,32 @@ func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) { } } +func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) { + h := &Handler{} + rec := httptest.NewRecorder() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader( + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}"}` + "\n" + + `data: [DONE]` + "\n", + )), + } + policy := util.ToolChoicePolicy{Mode: util.ToolChoiceNone} + + h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, policy, "") + if rec.Code != http.StatusOK { + t.Fatalf("expected 200 for tool_choice=none passthrough text, got %d body=%s", rec.Code, rec.Body.String()) + } + out := decodeJSONBody(t, rec.Body.String()) + output, 
_ := out["output"].([]any) + for _, item := range output { + m, _ := item.(map[string]any) + if m != nil && m["type"] == "function_call" { + t.Fatalf("did not expect function_call output item for tool_choice=none, got %#v", output) + } + } +} + func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) { scanner := bufio.NewScanner(strings.NewReader(body)) matched := false @@ -351,3 +524,18 @@ func extractAllSSEEventPayloads(body, targetEvent string) []map[string]any { } return out } + +func asFloat(v any) float64 { + switch x := v.(type) { + case float64: + return x + case float32: + return float64(x) + case int: + return float64(x) + case int64: + return float64(x) + default: + return 0 + } +} diff --git a/internal/adapter/openai/standard_request_test.go b/internal/adapter/openai/standard_request_test.go index 60b0922..e8d1225 100644 --- a/internal/adapter/openai/standard_request_test.go +++ b/internal/adapter/openai/standard_request_test.go @@ -151,3 +151,30 @@ func TestNormalizeOpenAIResponsesRequestToolChoiceForcedUndeclaredFails(t *testi t.Fatalf("expected forced undeclared tool to fail") } } + +func TestNormalizeOpenAIResponsesRequestToolChoiceNoneDisablesTools(t *testing.T) { + store := newEmptyStoreForNormalizeTest(t) + req := map[string]any{ + "model": "gpt-4o", + "input": "ping", + "tools": []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "search", + }, + }, + }, + "tool_choice": "none", + } + n, err := normalizeOpenAIResponsesRequest(store, req, "") + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + if n.ToolChoice.Mode != util.ToolChoiceNone { + t.Fatalf("expected tool choice mode none, got %q", n.ToolChoice.Mode) + } + if len(n.ToolNames) != 0 { + t.Fatalf("expected no tool names when tool_choice=none, got %#v", n.ToolNames) + } +} diff --git a/internal/format/openai/render_responses.go b/internal/format/openai/render_responses.go index 1839977..f55ee9f 100644 --- 
a/internal/format/openai/render_responses.go +++ b/internal/format/openai/render_responses.go @@ -21,12 +21,6 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex output := make([]any, 0, 2) if len(detected) > 0 { exposedOutputText = "" - if strings.TrimSpace(finalThinking) != "" { - output = append(output, map[string]any{ - "type": "reasoning", - "text": finalThinking, - }) - } output = append(output, toResponsesFunctionCallItems(detected)...) } else { content := make([]any, 0, 2) @@ -52,6 +46,21 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex "content": content, }) } + return BuildResponseObjectFromItems( + responseID, + model, + finalPrompt, + finalThinking, + finalText, + output, + exposedOutputText, + ) +} + +func BuildResponseObjectFromItems(responseID, model, finalPrompt, finalThinking, finalText string, output []any, outputText string) map[string]any { + if output == nil { + output = []any{} + } return map[string]any{ "id": responseID, "type": "response", @@ -60,7 +69,7 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex "status": "completed", "model": model, "output": output, - "output_text": exposedOutputText, + "output_text": outputText, "usage": BuildResponsesUsage(finalPrompt, finalThinking, finalText), } } diff --git a/internal/format/openai/render_stream_events.go b/internal/format/openai/render_stream_events.go index 40e8c2c..dc13231 100644 --- a/internal/format/openai/render_stream_events.go +++ b/internal/format/openai/render_stream_events.go @@ -59,12 +59,15 @@ func BuildResponsesContentPartDonePayload(responseID, itemID string, outputIndex } } -func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any { +func BuildResponsesTextDeltaPayload(responseID, itemID string, outputIndex, contentIndex int, delta string) map[string]any { return map[string]any{ - "type": "response.output_text.delta", - "id": responseID, - "response_id": 
responseID, - "delta": delta, + "type": "response.output_text.delta", + "id": responseID, + "response_id": responseID, + "item_id": itemID, + "output_index": outputIndex, + "content_index": contentIndex, + "delta": delta, } } diff --git a/internal/format/openai/render_test.go b/internal/format/openai/render_test.go index b95e739..df792ed 100644 --- a/internal/format/openai/render_test.go +++ b/internal/format/openai/render_test.go @@ -138,15 +138,11 @@ func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) { ) output, _ := obj["output"].([]any) - if len(output) != 2 { - t.Fatalf("expected reasoning + function_call outputs, got %#v", obj["output"]) + if len(output) != 1 { + t.Fatalf("expected function_call output only, got %#v", obj["output"]) } first, _ := output[0].(map[string]any) - if first["type"] != "reasoning" { - t.Fatalf("expected first output reasoning, got %#v", first["type"]) - } - second, _ := output[1].(map[string]any) - if second["type"] != "function_call" { - t.Fatalf("expected second output function_call, got %#v", second["type"]) + if first["type"] != "function_call" { + t.Fatalf("expected output function_call, got %#v", first["type"]) } } diff --git a/internal/util/render_stream.go b/internal/util/render_stream.go deleted file mode 100644 index b5699ba..0000000 --- a/internal/util/render_stream.go +++ /dev/null @@ -1,113 +0,0 @@ -package util - -// BuildOpenAIChatStreamDeltaChoice is kept for backward compatibility. -// Prefer internal/format/openai.BuildChatStreamDeltaChoice for new code. -func BuildOpenAIChatStreamDeltaChoice(index int, delta map[string]any) map[string]any { - return map[string]any{ - "delta": delta, - "index": index, - } -} - -// BuildOpenAIChatStreamFinishChoice is kept for backward compatibility. -// Prefer internal/format/openai.BuildChatStreamFinishChoice for new code. 
-func BuildOpenAIChatStreamFinishChoice(index int, finishReason string) map[string]any { - return map[string]any{ - "delta": map[string]any{}, - "index": index, - "finish_reason": finishReason, - } -} - -// BuildOpenAIChatStreamChunk is kept for backward compatibility. -// Prefer internal/format/openai.BuildChatStreamChunk for new code. -func BuildOpenAIChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any { - out := map[string]any{ - "id": completionID, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": choices, - } - if len(usage) > 0 { - out["usage"] = usage - } - return out -} - -// BuildOpenAIChatUsage is kept for backward compatibility. -// Prefer internal/format/openai.BuildChatUsage for new code. -func BuildOpenAIChatUsage(finalPrompt, finalThinking, finalText string) map[string]any { - promptTokens := EstimateTokens(finalPrompt) - reasoningTokens := EstimateTokens(finalThinking) - completionTokens := EstimateTokens(finalText) - return map[string]any{ - "prompt_tokens": promptTokens, - "completion_tokens": reasoningTokens + completionTokens, - "total_tokens": promptTokens + reasoningTokens + completionTokens, - "completion_tokens_details": map[string]any{ - "reasoning_tokens": reasoningTokens, - }, - } -} - -// BuildOpenAIResponsesCreatedPayload is kept for backward compatibility. -// Prefer internal/format/openai.BuildResponsesCreatedPayload for new code. -func BuildOpenAIResponsesCreatedPayload(responseID, model string) map[string]any { - return map[string]any{ - "type": "response.created", - "id": responseID, - "object": "response", - "model": model, - "status": "in_progress", - } -} - -// BuildOpenAIResponsesTextDeltaPayload is kept for backward compatibility. -// Prefer internal/format/openai.BuildResponsesTextDeltaPayload for new code. 
-func BuildOpenAIResponsesTextDeltaPayload(responseID, delta string) map[string]any { - return map[string]any{ - "type": "response.output_text.delta", - "id": responseID, - "delta": delta, - } -} - -// BuildOpenAIResponsesReasoningDeltaPayload is kept for backward compatibility. -// Prefer internal/format/openai.BuildResponsesReasoningDeltaPayload for new code. -func BuildOpenAIResponsesReasoningDeltaPayload(responseID, delta string) map[string]any { - return map[string]any{ - "type": "response.reasoning.delta", - "id": responseID, - "delta": delta, - } -} - -// BuildOpenAIResponsesToolCallDeltaPayload is kept for backward compatibility. -// Prefer internal/format/openai.BuildResponsesToolCallDeltaPayload for new code. -func BuildOpenAIResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any { - return map[string]any{ - "type": "response.output_tool_call.delta", - "id": responseID, - "tool_calls": toolCalls, - } -} - -// BuildOpenAIResponsesToolCallDonePayload is kept for backward compatibility. -// Prefer internal/format/openai.BuildResponsesToolCallDonePayload for new code. -func BuildOpenAIResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any { - return map[string]any{ - "type": "response.output_tool_call.done", - "id": responseID, - "tool_calls": toolCalls, - } -} - -// BuildOpenAIResponsesCompletedPayload is kept for backward compatibility. -// Prefer internal/format/openai.BuildResponsesCompletedPayload for new code. 
-func BuildOpenAIResponsesCompletedPayload(response map[string]any) map[string]any { - return map[string]any{ - "type": "response.completed", - "response": response, - } -} diff --git a/internal/util/render_stream_test.go b/internal/util/render_stream_test.go deleted file mode 100644 index 420a311..0000000 --- a/internal/util/render_stream_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package util - -import "testing" - -func TestBuildOpenAIChatStreamChunk(t *testing.T) { - chunk := BuildOpenAIChatStreamChunk( - "cid", - 123, - "deepseek-chat", - []map[string]any{BuildOpenAIChatStreamDeltaChoice(0, map[string]any{"role": "assistant"})}, - nil, - ) - if chunk["object"] != "chat.completion.chunk" { - t.Fatalf("unexpected object: %#v", chunk["object"]) - } - choices, _ := chunk["choices"].([]map[string]any) - if len(choices) == 0 { - rawChoices, _ := chunk["choices"].([]any) - if len(rawChoices) == 0 { - t.Fatalf("expected choices") - } - } -} - -func TestBuildOpenAIChatUsage(t *testing.T) { - usage := BuildOpenAIChatUsage("prompt", "think", "answer") - if _, ok := usage["prompt_tokens"]; !ok { - t.Fatalf("expected prompt_tokens") - } - if _, ok := usage["completion_tokens_details"]; !ok { - t.Fatalf("expected completion_tokens_details") - } -} - -func TestBuildOpenAIResponsesEventPayloads(t *testing.T) { - created := BuildOpenAIResponsesCreatedPayload("resp_1", "gpt-4o") - if created["type"] != "response.created" { - t.Fatalf("unexpected type: %#v", created["type"]) - } - done := BuildOpenAIResponsesToolCallDonePayload("resp_1", []map[string]any{{"index": 0}}) - if done["type"] != "response.output_tool_call.done" { - t.Fatalf("unexpected type: %#v", done["type"]) - } - completed := BuildOpenAIResponsesCompletedPayload(map[string]any{"id": "resp_1"}) - if completed["type"] != "response.completed" { - t.Fatalf("unexpected type: %#v", completed["type"]) - } -} diff --git a/internal/util/toolcalls_parse.go b/internal/util/toolcalls_parse.go index 2b8610e..5b386c2 100644 --- 
a/internal/util/toolcalls_parse.go +++ b/internal/util/toolcalls_parse.go @@ -92,17 +92,29 @@ func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []strin for _, name := range availableToolNames { allowed[name] = struct{}{} } + if len(allowed) == 0 { + rejectedSet := map[string]struct{}{} + for _, tc := range parsed { + if tc.Name == "" { + continue + } + rejectedSet[tc.Name] = struct{}{} + } + rejected := make([]string, 0, len(rejectedSet)) + for name := range rejectedSet { + rejected = append(rejected, name) + } + return nil, rejected + } out := make([]ParsedToolCall, 0, len(parsed)) rejectedSet := map[string]struct{}{} for _, tc := range parsed { if tc.Name == "" { continue } - if len(allowed) > 0 { - if _, ok := allowed[tc.Name]; !ok { - rejectedSet[tc.Name] = struct{}{} - continue - } + if _, ok := allowed[tc.Name]; !ok { + rejectedSet[tc.Name] = struct{}{} + continue } if tc.Input == nil { tc.Input = map[string]any{} diff --git a/internal/util/toolcalls_test.go b/internal/util/toolcalls_test.go index b102b41..0e823c0 100644 --- a/internal/util/toolcalls_test.go +++ b/internal/util/toolcalls_test.go @@ -60,6 +60,20 @@ func TestParseToolCallsDetailedMarksPolicyRejection(t *testing.T) { } } +func TestParseToolCallsDetailedRejectsWhenAllowListEmpty(t *testing.T) { + text := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}` + res := ParseToolCallsDetailed(text, nil) + if !res.SawToolCallSyntax { + t.Fatalf("expected SawToolCallSyntax=true, got %#v", res) + } + if !res.RejectedByPolicy { + t.Fatalf("expected RejectedByPolicy=true, got %#v", res) + } + if len(res.Calls) != 0 { + t.Fatalf("expected no calls when allow-list is empty, got %#v", res.Calls) + } +} + func TestFormatOpenAIToolCalls(t *testing.T) { formatted := FormatOpenAIToolCalls([]ParsedToolCall{{Name: "search", Input: map[string]any{"q": "x"}}}) if len(formatted) != 1 { diff --git a/internal/util/util_edge_test.go b/internal/util/util_edge_test.go index cba0ceb..876cd04 100644 
--- a/internal/util/util_edge_test.go +++ b/internal/util/util_edge_test.go @@ -364,8 +364,8 @@ func TestFormatOpenAIStreamToolCalls(t *testing.T) { func TestParseToolCallsNoToolNames(t *testing.T) { text := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}` calls := ParseToolCalls(text, nil) - if len(calls) != 1 { - t.Fatalf("expected 1 call with nil tool names, got %d", len(calls)) + if len(calls) != 0 { + t.Fatalf("expected 0 call with nil tool names, got %d", len(calls)) } }