From ae7dce0b3201bedbf814efbc75896157614ab825 Mon Sep 17 00:00:00 2001 From: CJACK Date: Sun, 22 Feb 2026 19:33:52 +0800 Subject: [PATCH] feat: Improve OpenAI tool call handling by passing unknown tool calls as content and filtering streamed tool calls by schema. --- API.en.md | 24 +- API.md | 24 +- README.MD | 53 ++- README.en.md | 7 +- .../adapter/openai/chat_stream_runtime.go | 8 +- internal/adapter/openai/handler_errors.go | 9 +- .../adapter/openai/handler_toolcall_format.go | 55 ++- .../adapter/openai/handler_toolcall_test.go | 45 +-- internal/adapter/openai/prompt_build.go | 7 +- .../openai/responses_embeddings_test.go | 26 ++ internal/adapter/openai/responses_handler.go | 42 ++- .../adapter/openai/responses_input_items.go | 18 + .../openai/responses_input_normalize.go | 3 +- .../openai/responses_stream_runtime_core.go | 105 ++++-- .../openai/responses_stream_runtime_events.go | 23 +- .../responses_stream_runtime_toolcalls.go | 198 ++++++++--- .../adapter/openai/responses_stream_test.go | 326 ++++++------------ internal/adapter/openai/standard_request.go | 226 +++++++++++- .../adapter/openai/standard_request_test.go | 93 +++++ internal/format/openai/render_responses.go | 32 +- .../format/openai/render_stream_events.go | 111 +++--- internal/format/openai/render_test.go | 47 +-- internal/util/standard_request.go | 36 ++ internal/util/toolcalls_parse.go | 67 ++-- internal/util/toolcalls_test.go | 21 +- .../expected/toolcalls_unknown_name.json | 4 +- 26 files changed, 1109 insertions(+), 501 deletions(-) diff --git a/API.en.md b/API.en.md index ef1a6f3..b910723 100644 --- a/API.en.md +++ b/API.en.md @@ -286,8 +286,10 @@ OpenAI Responses-style endpoint, accepting either `input` or `messages`. | `instructions` | string | ❌ | Prepended as a system message | | `stream` | boolean | ❌ | Default `false` | | `tools` | array | ❌ | Same tool detection/translation policy as chat | +| `tool_choice` | string/object | ❌ | Supports `auto`/`none`/`required` and forced function selection (`{"type":"function","name":"..."}`) | **Non-stream**: Returns a standard `response` object with an ID like `resp_xxx`, and stores it in in-memory TTL cache. +If `tool_choice=required` and no valid tool call is produced, DS2API returns HTTP `422` (`error.code=tool_choice_violation`). **Stream (SSE)**: minimal event sequence: @@ -295,11 +297,26 @@ OpenAI Responses-style endpoint, accepting either `input` or `messages`. event: response.created data: {"type":"response.created","id":"resp_xxx","status":"in_progress",...} +event: response.output_item.added +data: {"type":"response.output_item.added","response_id":"resp_xxx","item":{"type":"message|function_call",...},...} + +event: response.content_part.added +data: {"type":"response.content_part.added","response_id":"resp_xxx","part":{"type":"output_text",...},...} + event: response.output_text.delta data: {"type":"response.output_text.delta","id":"resp_xxx","delta":"..."} -event: response.output_tool_call.delta -data: {"type":"response.output_tool_call.delta","id":"resp_xxx","tool_calls":[...]} +event: response.function_call_arguments.delta +data: {"type":"response.function_call_arguments.delta","response_id":"resp_xxx","call_id":"call_xxx","delta":"..."} + +event: response.function_call_arguments.done +data: {"type":"response.function_call_arguments.done","response_id":"resp_xxx","call_id":"call_xxx","name":"tool","arguments":"{...}"} + +event: response.content_part.done +data: {"type":"response.content_part.done","response_id":"resp_xxx",...} + +event: response.output_item.done +data: {"type":"response.output_item.done","response_id":"resp_xxx","item":{"type":"message|function_call",...},...} event: response.completed data: {"type":"response.completed","response":{...}} @@ -307,6 +324,9 @@ data: {"type":"response.completed","response":{...}} data: [DONE] ``` +If `tool_choice=required` is violated in stream mode, DS2API emits `response.failed` then `[DONE]` (no `response.completed`). +Unknown tool names (outside declared `tools`) are rejected and will not be emitted as valid tool calls. + ### `GET /v1/responses/{response_id}` Business auth required. Fetches cached responses created by `POST /v1/responses` (caller-scoped; only the same key/token can read). diff --git a/API.md b/API.md index 3770924..6ec117e 100644 --- a/API.md +++ b/API.md @@ -286,8 +286,10 @@ OpenAI Responses 风格接口,兼容 `input` 或 `messages`。 | `instructions` | string | ❌ | 自动前置为 system 消息 | | `stream` | boolean | ❌ | 默认 `false` | | `tools` | array | ❌ | 与 chat 同样的工具识别与转译策略 | +| `tool_choice` | string/object | ❌ | 支持 `auto`/`none`/`required` 与强制函数(`{"type":"function","name":"..."}`) | **非流式响应**:返回标准 `response` 对象,`id` 形如 `resp_xxx`,并写入内存 TTL 存储。 +当 `tool_choice=required` 且未产出有效工具调用时,返回 HTTP `422`(`error.code=tool_choice_violation`)。 **流式响应(SSE)**:最小事件序列如下。 @@ -295,11 +297,26 @@ OpenAI Responses 风格接口,兼容 `input` 或 `messages`。 event: response.created data: {"type":"response.created","id":"resp_xxx","status":"in_progress",...} +event: response.output_item.added +data: {"type":"response.output_item.added","response_id":"resp_xxx","item":{"type":"message|function_call",...},...} + +event: response.content_part.added +data: {"type":"response.content_part.added","response_id":"resp_xxx","part":{"type":"output_text",...},...} + event: response.output_text.delta data: {"type":"response.output_text.delta","id":"resp_xxx","delta":"..."} -event: response.output_tool_call.delta -data: {"type":"response.output_tool_call.delta","id":"resp_xxx","tool_calls":[...]} +event: response.function_call_arguments.delta +data: {"type":"response.function_call_arguments.delta","response_id":"resp_xxx","call_id":"call_xxx","delta":"..."} + +event: response.function_call_arguments.done +data: {"type":"response.function_call_arguments.done","response_id":"resp_xxx","call_id":"call_xxx","name":"tool","arguments":"{...}"} + +event: response.content_part.done +data: {"type":"response.content_part.done","response_id":"resp_xxx",...} + +event: response.output_item.done +data: {"type":"response.output_item.done","response_id":"resp_xxx","item":{"type":"message|function_call",...},...} event: response.completed data: {"type":"response.completed","response":{...}} @@ -307,6 +324,9 @@ data: {"type":"response.completed","response":{...}} data: [DONE] ``` +流式场景下若 `tool_choice=required` 违规,会返回 `response.failed` 后结束(不再发送 `response.completed`)。 +未在 `tools` 声明中的工具名会被严格拒绝,不会作为有效 tool call 下发。 + ### `GET /v1/responses/{response_id}` 需要业务鉴权。查询 `POST /v1/responses` 生成并缓存的 response 对象(按调用方鉴权隔离,仅同一 key/token 可读取)。 diff --git a/README.MD b/README.MD index a0a1932..69dde0a 100644 --- a/README.MD +++ b/README.MD @@ -8,13 +8,13 @@ 语言 / Language: [中文](README.MD) | [English](README.en.md) -将 DeepSeek Web 对话能力转换为 OpenAI 与 Claude 兼容 API。后端为 **Go 全量实现**,前端为 React WebUI 管理台(源码在 `webui/`,部署时自动构建到 `static/admin`)。 +将 DeepSeek Web 对话能力转换为 OpenAI、Claude 与 Gemini 兼容 API。后端为 **Go 全量实现**,前端为 React WebUI 管理台(源码在 `webui/`,部署时自动构建到 `static/admin`)。 ## 架构概览 ```mermaid flowchart LR - Client["🖥️ 客户端\n(OpenAI / Claude 兼容)"] + Client["🖥️ 客户端\n(OpenAI / Claude / Gemini 兼容)"] subgraph DS2API["DS2API 服务"] direction TB @@ -24,6 +24,7 @@ flowchart LR subgraph Adapters["适配器层"] OA["OpenAI 适配器\n/v1/*"] CA["Claude 适配器\n/anthropic/*"] + GA["Gemini 适配器\n/v1beta/models/*"] end subgraph Support["支撑模块"] @@ -38,11 +39,11 @@ flowchart LR DS["☁️ DeepSeek API"] Client -- "请求" --> CORS --> Auth - Auth --> OA & CA - OA & CA -- "调用" --> DS + Auth --> OA & CA & GA + OA & CA & GA -- "调用" --> DS Auth --> Admin - OA & CA -. "轮询选账号" .-> Pool - OA & CA -. "计算 PoW" .-> PoW + OA & CA & GA -. "轮询选账号" .-> Pool + OA & CA & GA -. "计算 PoW" .-> PoW DS -- "响应" --> Client ``` @@ -55,12 +56,13 @@ flowchart LR | 能力 | 说明 | | --- | --- | | OpenAI 兼容 | `GET /v1/models`、`GET /v1/models/{id}`、`POST /v1/chat/completions`、`POST /v1/responses`、`GET /v1/responses/{response_id}`、`POST /v1/embeddings` | -| Claude 兼容 | `GET /anthropic/v1/models`、`POST /anthropic/v1/messages`、`POST /anthropic/v1/messages/count_tokens` | +| Claude 兼容 | `GET /anthropic/v1/models`、`POST /anthropic/v1/messages`、`POST /anthropic/v1/messages/count_tokens`(及快捷路径 `/v1/messages`、`/messages`) | +| Gemini 兼容 | `POST /v1beta/models/{model}:generateContent`、`POST /v1beta/models/{model}:streamGenerateContent`(及 `/v1/models/{model}:*` 路径) | | 多账号轮询 | 自动 token 刷新、邮箱/手机号双登录方式 | | 并发队列控制 | 每账号 in-flight 上限 + 等待队列,动态计算建议并发值 | | DeepSeek PoW | WASM 计算(`wazero`),无需外部 Node.js 依赖 | | Tool Calling | 防泄漏处理:非代码块高置信特征识别、`delta.tool_calls` 早发、结构化增量输出 | -| Admin API | 配置管理、账号测试 / 批量测试、导入导出、Vercel 同步 | +| Admin API | 配置管理、运行时设置热更新、账号测试 / 批量测试、导入导出、Vercel 同步 | | WebUI 管理台 | `/admin` 单页应用(中英文双语、深色模式) | | 运维探针 | `GET /healthz`(存活)、`GET /readyz`(就绪) | @@ -72,6 +74,7 @@ flowchart LR | P0 | OpenAI SDK(JS/Python,chat + responses) | ✅ | | P0 | Vercel AI SDK(openai-compatible) | ✅ | | P0 | Anthropic SDK(messages) | ✅ | +| P0 | Google Gemini SDK(generateContent) | ✅ | | P1 | LangChain / LlamaIndex / OpenWebUI(OpenAI 兼容接入) | ✅ | | P2 | MCP 独立桥接层 | 规划中 | @@ -97,6 +100,10 @@ flowchart LR 可通过配置中的 `claude_mapping` 或 `claude_model_mapping` 覆盖映射关系。 另外,`/anthropic/v1/models` 现已包含 Claude 1.x/2.x/3.x/4.x 历史模型 ID 与常见别名,便于旧客户端直接兼容。 +### Gemini 接口 + +Gemini 适配器将模型名通过 `model_aliases` 或内置规则映射到 DeepSeek 原生模型,支持 `generateContent` 和 `streamGenerateContent` 两种调用方式,并完整支持 Tool Calling(`functionDeclarations` → `functionCall` 输出)。 + ## 快速开始 ### 通用第一步(所有部署方式) @@ -249,6 +256,14 @@ cp opencode.json.example opencode.json "claude_model_mapping": { "fast": "deepseek-chat", "slow": "deepseek-reasoner" + }, + "admin": { + "jwt_expire_hours": 24 + }, + "runtime": { + "account_max_inflight": 2, + "account_max_queue": 0, + "global_max_inflight": 0 } } ``` @@ -262,6 +277,8 @@ cp opencode.json.example opencode.json - `responses.store_ttl_seconds`:`/v1/responses/{id}` 的内存缓存 TTL - `embeddings.provider`:embedding 提供方(当前内置 `deterministic/mock/builtin`) - `claude_model_mapping`:字典中 `fast`/`slow` 后缀映射到对应 DeepSeek 模型 +- `admin`:管理后台设置(JWT 过期时间、密码哈希等),可通过 Admin Settings API 热更新 +- `runtime`:运行时参数(并发限制、队列大小),可通过 Admin Settings API 热更新 ### 环境变量 @@ -293,7 +310,7 @@ cp opencode.json.example opencode.json ## 鉴权模式 -调用业务接口(`/v1/*`、`/anthropic/*`)时支持两种模式: +调用业务接口(`/v1/*`、`/anthropic/*`、Gemini 路由)时支持两种模式: | 模式 | 说明 | | --- | --- | @@ -320,9 +337,10 @@ cp opencode.json.example opencode.json 当请求中带 `tools` 时,DS2API 会做防泄漏处理: 1. 只在**非代码块上下文**启用 toolcall 特征识别(代码块示例不会触发) -2. 一旦命中高置信特征(`tool_calls` + `name` + `arguments/input` 起始)就立即输出 `delta.tool_calls` -3. 已确认的 toolcall JSON 片段不会泄漏到 `delta.content` -4. 前文/后文自然语言保持顺序透传,支持混合文本与增量参数输出 +2. `responses` 流式严格使用官方 item 生命周期事件(`response.output_item.*`、`response.content_part.*`、`response.function_call_arguments.*`) +3. 未在 `tools` 声明中的工具名会被严格拒绝,不会下发为有效 tool call +4. `responses` 支持并执行 `tool_choice`(`auto`/`none`/`required`/强制函数);`required` 违规时非流式返回 `422`,流式返回 `response.failed` +5. 仅在通过策略校验后才会发出有效工具调用事件,避免错误工具名进入客户端执行链 ## 本地开发抓包工具 @@ -362,13 +380,20 @@ ds2api/ │ ├── account/ # 账号池与并发队列 │ ├── adapter/ │ │ ├── openai/ # OpenAI 兼容适配器(含 Tool Call 解析、Vercel 流式 prepare/release) -│ │ └── claude/ # Claude 兼容适配器 -│ ├── admin/ # Admin API handlers +│ │ ├── claude/ # Claude 兼容适配器 +│ │ └── gemini/ # Gemini 兼容适配器(generateContent / streamGenerateContent) +│ ├── admin/ # Admin API handlers(含 Settings 热更新) │ ├── auth/ # 鉴权与 JWT +│ ├── claudeconv/ # Claude 消息格式转换 +│ ├── compat/ # 兼容性辅助 │ ├── config/ # 配置加载与热更新 │ ├── deepseek/ # DeepSeek API 客户端、PoW WASM +│ ├── devcapture/ # 开发抓包模块 +│ ├── format/ # 输出格式化 +│ ├── prompt/ # Prompt 构建 │ ├── server/ # HTTP 路由与中间件(chi router) │ ├── sse/ # SSE 解析工具 +│ ├── stream/ # 统一流式消费引擎 │ ├── util/ # 通用工具函数 │ └── webui/ # WebUI 静态文件托管与自动构建 ├── webui/ # React WebUI 源码(Vite + Tailwind) diff --git a/README.en.md b/README.en.md index 1445db6..1de40a7 100644 --- a/README.en.md +++ b/README.en.md @@ -320,9 +320,10 @@ Queue limit = DS2API_ACCOUNT_MAX_QUEUE (default = recommended concurrency) When `tools` is present in the request, DS2API performs anti-leak handling: 1. Toolcall feature matching is enabled only in **non-code-block context** (fenced examples are ignored) -2. Once high-confidence features are matched (`tool_calls` + `name` + `arguments/input` start), `delta.tool_calls` is emitted immediately -3. Confirmed toolcall JSON fragments are never leaked into `delta.content` -4. Natural language before/after toolcalls keeps original order, with incremental argument output supported +2. In `responses` stream mode, tool calls follow official item lifecycle events (`response.output_item.*`, `response.content_part.*`, `response.function_call_arguments.*`) +3. Unknown tool names (outside declared `tools`) are rejected and are not emitted as valid tool calls +4. `tool_choice` is enforced on `responses` (`auto`/`none`/`required`/forced function); required violations return HTTP `422` (non-stream) or `response.failed` (stream) +5. Confirmed toolcall JSON fragments are never emitted as valid tool call events unless they pass policy checks ## Local Dev Packet Capture diff --git a/internal/adapter/openai/chat_stream_runtime.go b/internal/adapter/openai/chat_stream_runtime.go index d9a1ba4..a5ecbd6 100644 --- a/internal/adapter/openai/chat_stream_runtime.go +++ b/internal/adapter/openai/chat_stream_runtime.go @@ -33,6 +33,7 @@ type chatStreamRuntime struct { toolSieve toolStreamSieveState streamToolCallIDs map[int]string + streamToolNames map[int]string thinking strings.Builder text strings.Builder } @@ -65,6 +66,7 @@ func newChatStreamRuntime( bufferToolContent: bufferToolContent, emitEarlyToolDeltas: emitEarlyToolDeltas, streamToolCallIDs: map[int]string{}, + streamToolNames: map[int]string{}, } } @@ -211,7 +213,11 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD if !s.emitEarlyToolDeltas { continue } - formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs) + filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.toolNames, s.streamToolNames) + if len(filtered) == 0 { + continue + } + formatted := formatIncrementalStreamToolCallDeltas(filtered, s.streamToolCallIDs) if len(formatted) == 0 { continue } diff --git a/internal/adapter/openai/handler_errors.go b/internal/adapter/openai/handler_errors.go index 62249d2..2e60d73 100644 --- a/internal/adapter/openai/handler_errors.go +++ b/internal/adapter/openai/handler_errors.go @@ -3,11 +3,18 @@ package openai import "net/http" func writeOpenAIError(w http.ResponseWriter, status int, message string) { + writeOpenAIErrorWithCode(w, status, message, "") +} + +func writeOpenAIErrorWithCode(w http.ResponseWriter, status int, message, code string) { + if code == "" { + code = openAIErrorCode(status) + } writeJSON(w, status, map[string]any{ "error": map[string]any{ "message": message, "type": openAIErrorType(status), - "code": openAIErrorCode(status), + "code": code, "param": nil, }, }) diff --git a/internal/adapter/openai/handler_toolcall_format.go b/internal/adapter/openai/handler_toolcall_format.go index d939c68..1e7e500 100644 --- a/internal/adapter/openai/handler_toolcall_format.go +++ b/internal/adapter/openai/handler_toolcall_format.go @@ -10,9 +10,23 @@ import ( "ds2api/internal/util" ) -func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) { +func injectToolPrompt(messages []map[string]any, tools []any, policy util.ToolChoicePolicy) ([]map[string]any, []string) { + if policy.IsNone() { + return messages, nil + } toolSchemas := make([]string, 0, len(tools)) names := make([]string, 0, len(tools)) + isAllowed := func(name string) bool { + if strings.TrimSpace(name) == "" { + return false + } + if len(policy.Allowed) == 0 { + return true + } + _, ok := policy.Allowed[name] + return ok + } + for _, t := range tools { tool, ok := t.(map[string]any) if !ok { @@ -25,8 +39,9 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, name, _ := fn["name"].(string) desc, _ := fn["description"].(string) schema, _ := fn["parameters"].(map[string]any) - if name == "" { - name = "unknown" + name = strings.TrimSpace(name) + if !isAllowed(name) { + continue } names = append(names, name) if desc == "" { @@ -39,6 +54,13 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, return messages, names } toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block." + if policy.Mode == util.ToolChoiceRequired { + toolPrompt += "\n5) For this response, you MUST call at least one tool from the allowed list." + } + if policy.Mode == util.ToolChoiceForced && strings.TrimSpace(policy.ForcedName) != "" { + toolPrompt += "\n5) For this response, you MUST call exactly this tool name: " + strings.TrimSpace(policy.ForcedName) + toolPrompt += "\n6) Do not call any other tool." + } for i := range messages { if messages[i]["role"] == "system" { @@ -85,6 +107,33 @@ func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]s return out } +func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, allowedNames []string, seenNames map[int]string) []toolCallDelta { + if len(deltas) == 0 { + return nil + } + allowed := namesToSet(allowedNames) + out := make([]toolCallDelta, 0, len(deltas)) + for _, d := range deltas { + if d.Name != "" { + if len(allowed) > 0 { + if _, ok := allowed[d.Name]; !ok { + seenNames[d.Index] = "__blocked__" + continue + } + } + seenNames[d.Index] = d.Name + out = append(out, d) + continue + } + name := strings.TrimSpace(seenNames[d.Index]) + if name == "" || name == "__blocked__" { + continue + } + out = append(out, d) + } + return out +} + func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any { if len(calls) == 0 { return nil diff --git a/internal/adapter/openai/handler_toolcall_test.go b/internal/adapter/openai/handler_toolcall_test.go index 2027729..9236b8b 100644 --- a/internal/adapter/openai/handler_toolcall_test.go +++ b/internal/adapter/openai/handler_toolcall_test.go @@ -181,7 +181,7 @@ func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) { } } -func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) { +func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`, @@ -197,16 +197,16 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) { out := decodeJSONBody(t, rec.Body.String()) choices, _ := out["choices"].([]any) choice, _ := choices[0].(map[string]any) - if choice["finish_reason"] != "tool_calls" { - t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"]) + if choice["finish_reason"] != "stop" { + t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"]) } msg, _ := choice["message"].(map[string]any) - if msg["content"] != nil { - t.Fatalf("expected content nil, got %#v", msg["content"]) + if _, ok := msg["tool_calls"]; ok { + t.Fatalf("did not expect tool_calls for unknown schema name, got %#v", msg["tool_calls"]) } - toolCalls, _ := msg["tool_calls"].([]any) - if len(toolCalls) != 1 { - t.Fatalf("expected 1 tool call, got %#v", msg["tool_calls"]) + content, _ := msg["content"].(string) + if !strings.Contains(content, `"tool_calls"`) { + t.Fatalf("expected unknown tool json to pass through as text, got %#v", content) } } @@ -375,7 +375,7 @@ func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing. } } -func TestHandleStreamUnknownToolStillIntercepted(t *testing.T) { +func TestHandleStreamUnknownToolNotIntercepted(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`, @@ -390,29 +390,14 @@ func TestHandleStreamUnknownToolStillIntercepted(t *testing.T) { if !done { t.Fatalf("expected [DONE], body=%s", rec.Body.String()) } - if !streamHasToolCallsDelta(frames) { - t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) + if streamHasToolCallsDelta(frames) { + t.Fatalf("did not expect tool_calls delta for unknown schema name, body=%s", rec.Body.String()) } - foundToolIndex := false - for _, frame := range frames { - choices, _ := frame["choices"].([]any) - for _, item := range choices { - choice, _ := item.(map[string]any) - delta, _ := choice["delta"].(map[string]any) - toolCalls, _ := delta["tool_calls"].([]any) - for _, tc := range toolCalls { - tcm, _ := tc.(map[string]any) - if _, ok := tcm["index"].(float64); ok { - foundToolIndex = true - } - } - } + if !streamHasRawToolJSONContent(frames) { + t.Fatalf("expected raw tool_calls json to remain in content for unknown schema name: %s", rec.Body.String()) } - if !foundToolIndex { - t.Fatalf("expected stream tool_calls item with index, body=%s", rec.Body.String()) - } - if streamHasRawToolJSONContent(frames) { - t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String()) + if streamFinishReason(frames) != "stop" { + t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String()) } } diff --git a/internal/adapter/openai/prompt_build.go b/internal/adapter/openai/prompt_build.go index 76739ed..d6823b2 100644 --- a/internal/adapter/openai/prompt_build.go +++ b/internal/adapter/openai/prompt_build.go @@ -2,13 +2,18 @@ package openai import ( "ds2api/internal/deepseek" + "ds2api/internal/util" ) func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any, traceID string) (string, []string) { + return buildOpenAIFinalPromptWithPolicy(messagesRaw, toolsRaw, traceID, util.DefaultToolChoicePolicy()) +} + +func buildOpenAIFinalPromptWithPolicy(messagesRaw []any, toolsRaw any, traceID string, toolPolicy util.ToolChoicePolicy) (string, []string) { messages := normalizeOpenAIMessagesForPrompt(messagesRaw, traceID) toolNames := []string{} if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 { - messages, toolNames = injectToolPrompt(messages, tools) + messages, toolNames = injectToolPrompt(messages, tools, toolPolicy) } return deepseek.MessagesPrepare(messages), toolNames } diff --git a/internal/adapter/openai/responses_embeddings_test.go b/internal/adapter/openai/responses_embeddings_test.go index a5e2b72..0f58c70 100644 --- a/internal/adapter/openai/responses_embeddings_test.go +++ b/internal/adapter/openai/responses_embeddings_test.go @@ -73,6 +73,32 @@ func TestNormalizeResponsesInputAsMessagesFunctionCallOutput(t *testing.T) { } } +func TestNormalizeResponsesInputAsMessagesBackfillsToolResultNameFromCallID(t *testing.T) { + msgs := normalizeResponsesInputAsMessages([]any{ + map[string]any{ + "type": "function_call", + "call_id": "call_999", + "name": "search", + "arguments": `{"q":"golang"}`, + }, + map[string]any{ + "type": "function_call_output", + "call_id": "call_999", + "output": map[string]any{"ok": true}, + }, + }) + if len(msgs) != 2 { + t.Fatalf("expected two messages, got %d", len(msgs)) + } + toolMsg, _ := msgs[1].(map[string]any) + if toolMsg["role"] != "tool" { + t.Fatalf("expected tool role, got %#v", toolMsg) + } + if toolMsg["name"] != "search" { + t.Fatalf("expected tool name backfilled from call_id, got %#v", toolMsg["name"]) + } +} + func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) { msgs := normalizeResponsesInputAsMessages([]any{ map[string]any{ diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index 7b35e0c..d8c59bb 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -11,10 +11,12 @@ import ( "github.com/google/uuid" "ds2api/internal/auth" + "ds2api/internal/config" "ds2api/internal/deepseek" openaifmt "ds2api/internal/format/openai" "ds2api/internal/sse" streamengine "ds2api/internal/stream" + "ds2api/internal/util" ) func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) { @@ -67,7 +69,8 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) { writeOpenAIError(w, http.StatusBadRequest, "invalid json") return } - stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req, requestTraceID(r)) + traceID := requestTraceID(r) + stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req, traceID) if err != nil { writeOpenAIError(w, http.StatusBadRequest, err.Error()) return @@ -96,13 +99,13 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) { responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "") if stdReq.Stream { - h.handleResponsesStream(w, r, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) + h.handleResponsesStream(w, r, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID) return } - h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames) + h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames, stdReq.ToolChoice, traceID) } -func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) { +func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string, toolChoice util.ToolChoicePolicy, traceID string) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) @@ -110,12 +113,26 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res return } result := sse.CollectStream(resp, thinkingEnabled, true) + textParsed := util.ParseToolCallsDetailed(result.Text, toolNames) + thinkingParsed := util.ParseToolCallsDetailed(result.Thinking, toolNames) + logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text") + logResponsesToolPolicyRejection(traceID, toolChoice, thinkingParsed, "thinking") + + callCount := len(textParsed.Calls) + if callCount == 0 { + callCount = len(thinkingParsed.Calls) + } + if toolChoice.IsRequired() && callCount == 0 { + writeOpenAIErrorWithCode(w, http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation") + return + } + responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames) h.getResponseStore().put(owner, responseID, responseObj) writeJSON(w, http.StatusOK, responseObj) } -func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { +func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice util.ToolChoicePolicy, traceID string) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) @@ -148,6 +165,8 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, toolNames, bufferToolContent, emitEarlyToolDeltas, + toolChoice, + traceID, func(obj map[string]any) { h.getResponseStore().put(owner, responseID, obj) }, @@ -169,3 +188,16 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, }, }) } + +func logResponsesToolPolicyRejection(traceID string, policy util.ToolChoicePolicy, parsed util.ToolCallParseResult, channel string) { + if !parsed.RejectedByPolicy || len(parsed.RejectedToolNames) == 0 { + return + } + config.Logger.Warn( + "[responses] rejected tool calls by policy", + "trace_id", strings.TrimSpace(traceID), + "channel", channel, + "tool_choice_mode", policy.Mode, + "rejected_tool_names", strings.Join(parsed.RejectedToolNames, ","), + ) +} diff --git a/internal/adapter/openai/responses_input_items.go b/internal/adapter/openai/responses_input_items.go index 2a2dfc4..81f29e8 100644 --- a/internal/adapter/openai/responses_input_items.go +++ b/internal/adapter/openai/responses_input_items.go @@ -4,9 +4,15 @@ import ( "encoding/json" "fmt" "strings" + + "ds2api/internal/config" ) func normalizeResponsesInputItem(m map[string]any) map[string]any { + return normalizeResponsesInputItemWithState(m, nil) +} + +func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[string]string) map[string]any { if m == nil { return nil } @@ -69,6 +75,15 @@ func normalizeResponsesInputItem(m map[string]any) map[string]any { out["name"] = name } else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" { out["name"] = name + } else if callID := strings.TrimSpace(asString(out["tool_call_id"])); callID != "" { + if inferred := strings.TrimSpace(callNameByID[callID]); inferred != "" { + out["name"] = inferred + } else { + config.Logger.Warn( + "[responses] unable to backfill tool result name from call_id", + "call_id", callID, + ) + } } return out case "function_call", "tool_call": @@ -111,6 +126,9 @@ func normalizeResponsesInputItem(m map[string]any) map[string]any { } else if callID = strings.TrimSpace(asString(m["id"])); callID != "" { call["id"] = callID } + if callID := strings.TrimSpace(asString(call["id"])); callID != "" && callNameByID != nil { + callNameByID[callID] = name + } return map[string]any{ "role": "assistant", "tool_calls": []any{call}, diff --git a/internal/adapter/openai/responses_input_normalize.go b/internal/adapter/openai/responses_input_normalize.go index 13f9e1a..6514669 100644 --- a/internal/adapter/openai/responses_input_normalize.go +++ b/internal/adapter/openai/responses_input_normalize.go @@ -59,6 +59,7 @@ func normalizeResponsesInputArray(items []any) []any { return nil } out := make([]any, 0, len(items)) + callNameByID := map[string]string{} fallbackParts := make([]string, 0, len(items)) flushFallback := func() { if len(fallbackParts) == 0 { @@ -71,7 +72,7 @@ func normalizeResponsesInputArray(items []any) []any { for _, item := range items { switch x := item.(type) { case map[string]any: - if msg := normalizeResponsesInputItem(x); msg != nil { + if msg := normalizeResponsesInputItemWithState(x, callNameByID); msg != nil { flushFallback() out = append(out, msg) continue diff --git a/internal/adapter/openai/responses_stream_runtime_core.go b/internal/adapter/openai/responses_stream_runtime_core.go index 5aad11e..7ab8600 100644 --- a/internal/adapter/openai/responses_stream_runtime_core.go +++ b/internal/adapter/openai/responses_stream_runtime_core.go @@ -4,6 +4,7 @@ import ( "net/http" "strings" + "ds2api/internal/config" openaifmt "ds2api/internal/format/openai" "ds2api/internal/sse" streamengine "ds2api/internal/stream" @@ -19,6 +20,8 @@ type responsesStreamRuntime struct { model string finalPrompt string toolNames []string + traceID string + toolChoice util.ToolChoicePolicy thinkingEnabled bool searchEnabled bool @@ -32,11 +35,19 @@ type responsesStreamRuntime struct { thinkingSieve toolStreamSieveState thinking strings.Builder text strings.Builder + visibleText strings.Builder streamToolCallIDs map[int]string streamFunctionIDs map[int]string functionDone map[int]bool + functionAdded map[int]bool + functionNames map[int]string toolCallsDoneSigs map[string]bool reasoningItemID string + messageItemID string + messageAdded bool + messagePartAdded bool + sequence int + failed bool persistResponse func(obj map[string]any) } @@ -53,6 +64,8 @@ func newResponsesStreamRuntime( toolNames []string, bufferToolContent bool, emitEarlyToolDeltas bool, + toolChoice util.ToolChoicePolicy, + traceID string, persistResponse func(obj map[string]any), ) *responsesStreamRuntime { return &responsesStreamRuntime{ @@ -70,7 +83,11 @@ func newResponsesStreamRuntime( streamToolCallIDs: map[int]string{}, streamFunctionIDs: map[int]string{}, functionDone: map[int]bool{}, + functionAdded: map[int]bool{}, + functionNames: map[int]string{}, toolCallsDoneSigs: map[string]bool{}, + toolChoice: toolChoice, + traceID: traceID, persistResponse: persistResponse, } } @@ -78,36 +95,59 @@ func newResponsesStreamRuntime( func (s *responsesStreamRuntime) finalize() { finalThinking := s.thinking.String() finalText := s.text.String() - if strings.TrimSpace(finalThinking) != "" { - s.sendEvent("response.reasoning_text.done", openaifmt.BuildResponsesReasoningTextDonePayload(s.responseID, s.ensureReasoningItemID(), 0, 0, finalThinking)) - } + if s.bufferToolContent { s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true) s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false) } - // Compatibility fallback: some streams only emit incremental tool deltas. - // Ensure final function_call_arguments.done is emitted at least once. - if s.toolCallsEmitted { - detected := util.ParseToolCalls(finalText, s.toolNames) - if len(detected) == 0 { - detected = util.ParseToolCalls(finalThinking, s.toolNames) + + textParsed := util.ParseToolCallsDetailed(finalText, s.toolNames) + thinkingParsed := util.ParseToolCallsDetailed(finalThinking, s.toolNames) + detected := textParsed.Calls + if len(detected) == 0 { + detected = thinkingParsed.Calls + } + s.logToolPolicyRejections(textParsed, thinkingParsed) + + if len(detected) > 0 { + s.toolCallsEmitted = true + if !s.toolCallsDoneEmitted { + s.emitFunctionCallDoneEvents(detected) } - if len(detected) > 0 { - if !s.toolCallsDoneEmitted { - s.emitToolCallsDone(detected) - } else { - s.emitFunctionCallDoneEvents(detected) - } + } + + s.closeMessageItem() + + if s.toolChoice.IsRequired() && !s.hasFunctionCallDone() { + s.failed = true + message := "tool_choice requires at least one valid tool call." + failedResp := map[string]any{ + "id": s.responseID, + "type": "response", + "object": "response", + "model": s.model, + "status": "failed", + "output": []any{}, + "output_text": "", + "error": map[string]any{ + "message": message, + "type": "invalid_request_error", + "code": "tool_choice_violation", + "param": nil, + }, } + if s.persistResponse != nil { + s.persistResponse(failedResp) + } + s.sendEvent("response.failed", openaifmt.BuildResponsesFailedPayload(s.responseID, s.model, message, "tool_choice_violation")) + s.sendDone() + return } obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames) if s.toolCallsEmitted { s.alignCompletedOutputCallIDs(obj) } - if s.toolCallsEmitted { - obj["status"] = "completed" - } if s.persistResponse != nil { s.persistResponse(obj) } @@ -115,6 +155,32 @@ func (s *responsesStreamRuntime) finalize() { s.sendDone() } +func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingParsed util.ToolCallParseResult) { + logRejected := func(parsed util.ToolCallParseResult, channel string) { + if !parsed.RejectedByPolicy || len(parsed.RejectedToolNames) == 0 { + return + } + config.Logger.Warn( + "[responses] rejected tool calls by policy", + "trace_id", strings.TrimSpace(s.traceID), + "channel", channel, + "tool_choice_mode", s.toolChoice.Mode, + "rejected_tool_names", strings.Join(parsed.RejectedToolNames, ","), + ) + } + logRejected(textParsed, "text") + logRejected(thinkingParsed, "thinking") +} + +func (s *responsesStreamRuntime) hasFunctionCallDone() bool { + for _, done := range s.functionDone { + if done { + return true + } + } + return false +} + func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision { if !parsed.Parsed { return streamengine.ParsedDecision{} @@ -138,7 +204,6 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa } s.thinking.WriteString(p.Text) s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text)) - s.sendEvent("response.reasoning_text.delta", openaifmt.BuildResponsesReasoningTextDeltaPayload(s.responseID, s.ensureReasoningItemID(), 0, 0, p.Text)) if s.bufferToolContent { s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false) } @@ -147,7 +212,7 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa s.text.WriteString(p.Text) if !s.bufferToolContent { - s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text)) + s.emitTextDelta(p.Text) continue } s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true) diff --git a/internal/adapter/openai/responses_stream_runtime_events.go b/internal/adapter/openai/responses_stream_runtime_events.go index fd36b6a..792d0ce 100644 --- a/internal/adapter/openai/responses_stream_runtime_events.go +++ b/internal/adapter/openai/responses_stream_runtime_events.go @@ -6,7 +6,18 @@ import ( openaifmt "ds2api/internal/format/openai" ) +func (s *responsesStreamRuntime) nextSequence() int { + s.sequence++ + return s.sequence +} + func (s *responsesStreamRuntime) sendEvent(event string, payload map[string]any) { + if payload == nil { + payload = map[string]any{} + } + if _, ok := payload["sequence_number"]; !ok { + payload["sequence_number"] = s.nextSequence() + } b, _ := json.Marshal(payload) _, _ = s.w.Write([]byte("event: " + event + "\n")) _, _ = s.w.Write([]byte("data: ")) @@ -31,22 +42,20 @@ func (s *responsesStreamRuntime) sendDone() { func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) { for _, evt := range events { if emitContent && evt.Content != "" { - s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content)) + s.emitTextDelta(evt.Content) } if len(evt.ToolCallDeltas) > 0 { if !s.emitEarlyToolDeltas { continue } - formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs) - if len(formatted) == 0 { + filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.toolNames, s.functionNames) + if len(filtered) == 0 { continue } - s.toolCallsEmitted = true - s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatted)) - s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas) + s.emitFunctionCallDeltaEvents(filtered) } if len(evt.ToolCalls) > 0 { - s.emitToolCallsDone(evt.ToolCalls) + s.emitFunctionCallDoneEvents(evt.ToolCalls) } } } diff --git a/internal/adapter/openai/responses_stream_runtime_toolcalls.go b/internal/adapter/openai/responses_stream_runtime_toolcalls.go index 7891425..94859cd 100644 --- a/internal/adapter/openai/responses_stream_runtime_toolcalls.go +++ b/internal/adapter/openai/responses_stream_runtime_toolcalls.go @@ -11,25 +11,101 @@ import ( "github.com/google/uuid" ) -func (s *responsesStreamRuntime) emitToolCallsDone(calls []util.ParsedToolCall) { - if len(calls) == 0 { +func (s *responsesStreamRuntime) ensureMessageItemID() string { + if strings.TrimSpace(s.messageItemID) != "" { + return s.messageItemID + } + s.messageItemID = "msg_" + strings.ReplaceAll(uuid.NewString(), "-", "") + return s.messageItemID +} + +func (s *responsesStreamRuntime) messageOutputIndex() int { + if strings.TrimSpace(s.thinking.String()) != "" { + return 1 + } + return 0 +} + +func (s *responsesStreamRuntime) ensureMessageItemAdded() { + if s.messageAdded { return } - sig := toolCallListSignature(calls) - if sig != "" && s.toolCallsDoneSigs[sig] { + itemID := s.ensureMessageItemID() + item := map[string]any{ + "id": itemID, + "type": "message", + "role": "assistant", + "status": "in_progress", + } + s.sendEvent( + "response.output_item.added", + openaifmt.BuildResponsesOutputItemAddedPayload(s.responseID, itemID, s.messageOutputIndex(), item), + ) + s.messageAdded = true +} + +func (s *responsesStreamRuntime) ensureMessageContentPartAdded() { + if s.messagePartAdded { return } - if sig != "" { - s.toolCallsDoneSigs[sig] = true - } - formatted := formatFinalStreamToolCallsWithStableIDs(calls, s.streamToolCallIDs) - if len(formatted) == 0 { + s.ensureMessageItemAdded() + s.sendEvent( + "response.content_part.added", + openaifmt.BuildResponsesContentPartAddedPayload( + s.responseID, + s.ensureMessageItemID(), + s.messageOutputIndex(), + 0, + map[string]any{"type": "output_text", "text": ""}, + ), + ) + s.messagePartAdded = true +} + +func (s *responsesStreamRuntime) emitTextDelta(content string) { + if strings.TrimSpace(content) == "" { return } - s.toolCallsEmitted = true - s.toolCallsDoneEmitted = true - s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatted)) - s.emitFunctionCallDoneEvents(calls) + s.ensureMessageContentPartAdded() + s.visibleText.WriteString(content) + s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, content)) +} + +func (s *responsesStreamRuntime) closeMessageItem() { + if !s.messageAdded { + return + } + itemID := s.ensureMessageItemID() + text := s.visibleText.String() + if s.messagePartAdded { + s.sendEvent( + "response.content_part.done", + openaifmt.BuildResponsesContentPartDonePayload( + s.responseID, + itemID, + s.messageOutputIndex(), + 0, + map[string]any{"type": "output_text", "text": text}, + ), + ) + s.messagePartAdded = false + } + item := map[string]any{ + "id": itemID, + "type": "message", + "role": "assistant", + "status": "completed", + "content": []map[string]any{ + { + "type": "output_text", + "text": text, + }, + }, + } + s.sendEvent( + "response.output_item.done", + openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, s.messageOutputIndex(), item), + ) } func (s *responsesStreamRuntime) ensureReasoningItemID() string { @@ -65,12 +141,47 @@ func (s *responsesStreamRuntime) functionOutputBaseIndex() int { return 0 } +func (s *responsesStreamRuntime) functionOutputIndex(callIndex int) int { + return s.functionOutputBaseIndex() + callIndex +} + +func (s *responsesStreamRuntime) ensureFunctionItemAdded(callIndex int, name string) { + if strings.TrimSpace(name) != "" { + s.functionNames[callIndex] = strings.TrimSpace(name) + } + if s.functionAdded[callIndex] { + return + } + fnName := strings.TrimSpace(s.functionNames[callIndex]) + if fnName == "" { + return + } + outputIndex := s.functionOutputIndex(callIndex) + itemID := s.ensureFunctionItemID(outputIndex) + callID := s.ensureToolCallID(callIndex) + item := map[string]any{ + "id": itemID, + "type": "function_call", + "call_id": callID, + "name": fnName, + "arguments": "{}", + "status": "in_progress", + } + s.sendEvent( + "response.output_item.added", + openaifmt.BuildResponsesOutputItemAddedPayload(s.responseID, itemID, outputIndex, item), + ) + s.functionAdded[callIndex] = true + s.toolCallsEmitted = true +} + func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDelta) { for _, d := range deltas { + s.ensureFunctionItemAdded(d.Index, d.Name) if strings.TrimSpace(d.Arguments) == "" { continue } - outputIndex := s.functionOutputBaseIndex() + d.Index + outputIndex := s.functionOutputIndex(d.Index) itemID := s.ensureFunctionItemID(outputIndex) callID := s.ensureToolCallID(d.Index) s.sendEvent( @@ -86,6 +197,8 @@ func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedT if strings.TrimSpace(tc.Name) == "" { continue } + s.ensureFunctionItemAdded(idx, tc.Name) + outputIndex := base + idx if s.functionDone[outputIndex] { continue @@ -93,11 +206,25 @@ func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedT itemID := s.ensureFunctionItemID(outputIndex) callID := s.ensureToolCallID(idx) argsBytes, _ := json.Marshal(tc.Input) + args := string(argsBytes) s.sendEvent( "response.function_call_arguments.done", - openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, string(argsBytes)), + openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, args), + ) + item := map[string]any{ + "id": itemID, + "type": "function_call", + "call_id": callID, + "name": tc.Name, + "arguments": args, + "status": "completed", + } + s.sendEvent( + "response.output_item.done", + openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, outputIndex, item), ) s.functionDone[outputIndex] = true + s.toolCallsDoneEmitted = true } } @@ -132,41 +259,12 @@ func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any) if m == nil { continue } - typ, _ := m["type"].(string) - switch typ { - case "function_call": - if functionIdx < len(ordered) { - m["call_id"] = ordered[functionIdx] - functionIdx++ - } - case "tool_calls": - tcArr, _ := m["tool_calls"].([]any) - for i, raw := range tcArr { - tc, _ := raw.(map[string]any) - if tc == nil { - continue - } - if i < len(ordered) { - tc["id"] = ordered[i] - } - } + if m["type"] != "function_call" { + continue + } + if functionIdx < len(ordered) { + m["call_id"] = ordered[functionIdx] + functionIdx++ } } } - -func toolCallListSignature(calls []util.ParsedToolCall) string { - if len(calls) == 0 { - return "" - } - var b strings.Builder - for i, tc := range calls { - if i > 0 { - b.WriteString("|") - } - b.WriteString(strings.TrimSpace(tc.Name)) - b.WriteString(":") - args, _ := json.Marshal(tc.Input) - b.Write(args) - } - return b.String() -} diff --git a/internal/adapter/openai/responses_stream_test.go b/internal/adapter/openai/responses_stream_test.go index a513e6f..fd6a7a0 100644 --- a/internal/adapter/openai/responses_stream_test.go +++ b/internal/adapter/openai/responses_stream_test.go @@ -8,6 +8,8 @@ import ( "net/http/httptest" "strings" "testing" + + "ds2api/internal/util" ) func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T) { @@ -30,7 +32,7 @@ func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T Body: io.NopCloser(strings.NewReader(streamBody)), } - h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}) + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "") completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed") if !ok { @@ -45,8 +47,8 @@ func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T if len(output) == 0 { t.Fatalf("expected structured output entries, got %#v", responseObj["output"]) } - var firstToolWrapper map[string]any hasFunctionCall := false + hasLegacyWrapper := false for _, item := range output { m, _ := item.(map[string]any) if m == nil { @@ -55,96 +57,22 @@ func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T if m["type"] == "function_call" { hasFunctionCall = true } - if m["type"] == "tool_calls" && firstToolWrapper == nil { - firstToolWrapper = m + if m["type"] == "tool_calls" { + hasLegacyWrapper = true } } if !hasFunctionCall { - t.Fatalf("expected at least one function_call item for responses compatibility, got %#v", responseObj["output"]) + t.Fatalf("expected function_call item, got %#v", responseObj["output"]) } - if firstToolWrapper == nil { - t.Fatalf("expected a tool_calls wrapper item, got %#v", responseObj["output"]) - } - toolCalls, _ := firstToolWrapper["tool_calls"].([]any) - if len(toolCalls) == 0 { - t.Fatalf("expected at least one tool_call in output, got %#v", firstToolWrapper["tool_calls"]) - } - call0, _ := toolCalls[0].(map[string]any) - if call0["type"] != "function" { - t.Fatalf("unexpected tool call type: %#v", call0["type"]) - } - fn, _ := call0["function"].(map[string]any) - if fn["name"] != "read_file" { - t.Fatalf("unexpected tool call name: %#v", fn["name"]) + if hasLegacyWrapper { + t.Fatalf("did not expect legacy tool_calls wrapper, got %#v", responseObj["output"]) } if strings.Contains(outputText, `"tool_calls"`) { t.Fatalf("raw tool_calls JSON leaked in output_text: %q", outputText) } } -func TestHandleResponsesStreamIncompleteTailNotDuplicatedInCompletedOutputText(t *testing.T) { - h := &Handler{} - req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) - rec := httptest.NewRecorder() - - sseLine := func(v string) string { - b, _ := json.Marshal(map[string]any{ - "p": "response/content", - "v": v, - }) - return "data: " + string(b) + "\n" - } - - tail := `{"tool_calls":[{"name":"read_file","input":` - streamBody := sseLine("Before ") + sseLine(tail) + "data: [DONE]\n" - resp := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(streamBody)), - } - - h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}) - - completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed") - if !ok { - t.Fatalf("expected response.completed event, body=%s", rec.Body.String()) - } - responseObj, _ := completed["response"].(map[string]any) - outputText, _ := responseObj["output_text"].(string) - if strings.Count(outputText, tail) > 1 { - t.Fatalf("expected incomplete tail not to be duplicated, got output_text=%q", outputText) - } -} - -func TestHandleResponsesStreamEmitsReasoningCompatEvents(t *testing.T) { - h := &Handler{} - req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) - rec := httptest.NewRecorder() - - b, _ := json.Marshal(map[string]any{ - "p": "response/thinking_content", - "v": "thought", - }) - streamBody := "data: " + string(b) + "\n" + "data: [DONE]\n" - resp := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(streamBody)), - } - - h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, nil) - - body := rec.Body.String() - if !strings.Contains(body, "event: response.reasoning.delta") { - t.Fatalf("expected response.reasoning.delta event, body=%s", body) - } - if !strings.Contains(body, "event: response.reasoning_text.delta") { - t.Fatalf("expected response.reasoning_text.delta compatibility event, body=%s", body) - } - if !strings.Contains(body, "event: response.reasoning_text.done") { - t.Fatalf("expected response.reasoning_text.done compatibility event, body=%s", body) - } -} - -func TestHandleResponsesStreamEmitsFunctionCallCompatEvents(t *testing.T) { +func TestHandleResponsesStreamUsesOfficialOutputItemEvents(t *testing.T) { h := &Handler{} req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) rec := httptest.NewRecorder() @@ -163,24 +91,28 @@ func TestHandleResponsesStreamEmitsFunctionCallCompatEvents(t *testing.T) { Body: io.NopCloser(strings.NewReader(streamBody)), } - h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}) + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "") body := rec.Body.String() + if !strings.Contains(body, "event: response.output_item.added") { + t.Fatalf("expected response.output_item.added event, body=%s", body) + } + if !strings.Contains(body, "event: response.output_item.done") { + t.Fatalf("expected response.output_item.done event, body=%s", body) + } if !strings.Contains(body, "event: response.function_call_arguments.delta") { - t.Fatalf("expected response.function_call_arguments.delta compatibility event, body=%s", body) + t.Fatalf("expected response.function_call_arguments.delta event, body=%s", body) } if !strings.Contains(body, "event: response.function_call_arguments.done") { - t.Fatalf("expected response.function_call_arguments.done compatibility event, body=%s", body) + t.Fatalf("expected response.function_call_arguments.done event, body=%s", body) } + if strings.Contains(body, "event: response.output_tool_call.delta") || strings.Contains(body, "event: response.output_tool_call.done") { + t.Fatalf("legacy response.output_tool_call.* event must not appear, body=%s", body) + } + donePayload, ok := extractSSEEventPayload(body, "response.function_call_arguments.done") if !ok { t.Fatalf("expected to parse response.function_call_arguments.done payload, body=%s", body) } - if strings.TrimSpace(asString(donePayload["call_id"])) == "" { - t.Fatalf("expected call_id in response.function_call_arguments.done payload, payload=%#v", donePayload) - } - if strings.TrimSpace(asString(donePayload["response_id"])) == "" { - t.Fatalf("expected response_id in response.function_call_arguments.done payload, payload=%#v", donePayload) - } doneCallID := strings.TrimSpace(asString(donePayload["call_id"])) if doneCallID == "" { t.Fatalf("expected non-empty call_id in done payload, payload=%#v", donePayload) @@ -191,9 +123,6 @@ func TestHandleResponsesStreamEmitsFunctionCallCompatEvents(t *testing.T) { } responseObj, _ := completed["response"].(map[string]any) output, _ := responseObj["output"].([]any) - if len(output) == 0 { - t.Fatalf("expected non-empty output in response.completed, response=%#v", responseObj) - } var completedCallID string for _, item := range output { m, _ := item.(map[string]any) @@ -213,36 +142,29 @@ func TestHandleResponsesStreamEmitsFunctionCallCompatEvents(t *testing.T) { } } -func TestHandleResponsesStreamDetectsToolCallsFromThinkingChannel(t *testing.T) { +func TestHandleResponsesStreamDoesNotEmitReasoningTextCompatEvents(t *testing.T) { h := &Handler{} req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) rec := httptest.NewRecorder() - sseLine := func(path, v string) string { - b, _ := json.Marshal(map[string]any{ - "p": path, - "v": v, - }) - return "data: " + string(b) + "\n" - } - - streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n" + b, _ := json.Marshal(map[string]any{ + "p": "response/thinking_content", + "v": "thought", + }) + streamBody := "data: " + string(b) + "\n" + "data: [DONE]\n" resp := &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(streamBody)), } - h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}) + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, nil, util.DefaultToolChoicePolicy(), "") body := rec.Body.String() - if !strings.Contains(body, "event: response.reasoning_text.delta") { - t.Fatalf("expected response.reasoning_text.delta event, body=%s", body) + if !strings.Contains(body, "event: response.reasoning.delta") { + t.Fatalf("expected response.reasoning.delta event, body=%s", body) } - if !strings.Contains(body, "event: response.function_call_arguments.done") { - t.Fatalf("expected response.function_call_arguments.done event from thinking channel, body=%s", body) - } - if !strings.Contains(body, "event: response.output_tool_call.done") { - t.Fatalf("expected response.output_tool_call.done event from thinking channel, body=%s", body) + if strings.Contains(body, "event: response.reasoning_text.delta") || strings.Contains(body, "event: response.reasoning_text.done") { + t.Fatalf("did not expect response.reasoning_text.* compatibility events, body=%s", body) } } @@ -267,121 +189,31 @@ func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing. Body: io.NopCloser(strings.NewReader(streamBody)), } - h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"}) + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"}, util.DefaultToolChoicePolicy(), "") body := rec.Body.String() - if !strings.Contains(body, "event: response.output_tool_call.done") { - t.Fatalf("expected response.output_tool_call.done event, body=%s", body) - } donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done") if len(donePayloads) != 2 { t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body) } - seenNames := map[string]string{} for _, payload := range donePayloads { name := strings.TrimSpace(asString(payload["name"])) callID := strings.TrimSpace(asString(payload["call_id"])) - args := strings.TrimSpace(asString(payload["arguments"])) - if callID == "" { - t.Fatalf("expected non-empty call_id in done payload: %#v", payload) - } - if strings.Contains(args, `}{"`) { - t.Fatalf("unexpected concatenated arguments in done payload: %#v", payload) - } - if name == "search_webeval_javascript" { - t.Fatalf("unexpected merged tool name in done payload: %#v", payload) - } if name != "search_web" && name != "eval_javascript" { t.Fatalf("unexpected tool name in done payload: %#v", payload) } + if callID == "" { + t.Fatalf("expected non-empty call_id in done payload: %#v", payload) + } seenNames[name] = callID } - if seenNames["search_web"] == "" || seenNames["eval_javascript"] == "" { - t.Fatalf("expected done events for both tools, got %#v", seenNames) - } if seenNames["search_web"] == seenNames["eval_javascript"] { t.Fatalf("expected distinct call_id per tool, got %#v", seenNames) } - - completed, ok := extractSSEEventPayload(body, "response.completed") - if !ok { - t.Fatalf("expected response.completed event, body=%s", body) - } - responseObj, _ := completed["response"].(map[string]any) - output, _ := responseObj["output"].([]any) - functionCallIDs := map[string]string{} - for _, item := range output { - m, _ := item.(map[string]any) - if m == nil || m["type"] != "function_call" { - continue - } - name := strings.TrimSpace(asString(m["name"])) - callID := strings.TrimSpace(asString(m["call_id"])) - if name != "" && callID != "" { - functionCallIDs[name] = callID - } - } - if functionCallIDs["search_web"] != seenNames["search_web"] { - t.Fatalf("search_web call_id mismatch between done and completed: done=%q completed=%q", seenNames["search_web"], functionCallIDs["search_web"]) - } - if functionCallIDs["eval_javascript"] != seenNames["eval_javascript"] { - t.Fatalf("eval_javascript call_id mismatch between done and completed: done=%q completed=%q", seenNames["eval_javascript"], functionCallIDs["eval_javascript"]) - } } -func TestHandleResponsesStreamMultiToolCallFromThinkingChannel(t *testing.T) { - h := &Handler{} - req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) - rec := httptest.NewRecorder() - - sseLine := func(path, v string) string { - b, _ := json.Marshal(map[string]any{ - "p": path, - "v": v, - }) - return "data: " + string(b) + "\n" - } - - streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) + - sseLine("response/thinking_content", `{"name":"eval_javascript","input":{"code":"1+1"}}]}`) + - "data: [DONE]\n" - resp := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(streamBody)), - } - - h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"search_web", "eval_javascript"}) - - body := rec.Body.String() - if !strings.Contains(body, "event: response.reasoning_text.delta") { - t.Fatalf("expected reasoning stream events, body=%s", body) - } - donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done") - if len(donePayloads) != 2 { - t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body) - } - seen := map[string]bool{} - for _, payload := range donePayloads { - name := strings.TrimSpace(asString(payload["name"])) - if name == "search_webeval_javascript" { - t.Fatalf("unexpected merged tool name in thinking channel done payload: %#v", payload) - } - if name != "search_web" && name != "eval_javascript" { - t.Fatalf("unexpected tool name in thinking channel done payload: %#v", payload) - } - args := strings.TrimSpace(asString(payload["arguments"])) - if strings.Contains(args, `}{"`) { - t.Fatalf("unexpected concatenated arguments in thinking channel done payload: %#v", payload) - } - seen[name] = true - } - if !seen["search_web"] || !seen["eval_javascript"] { - t.Fatalf("expected both tools in thinking channel done events, got %#v", seen) - } -} - -func TestHandleResponsesStreamCompletedFollowsChatToolCallSemantics(t *testing.T) { +func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) { h := &Handler{} req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) rec := httptest.NewRecorder() @@ -394,32 +226,76 @@ func TestHandleResponsesStreamCompletedFollowsChatToolCallSemantics(t *testing.T return "data: " + string(b) + "\n" } - streamBody := sseLine("我来调用工具\n") + - sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + - "data: [DONE]\n" + streamBody := sseLine("plain text only") + "data: [DONE]\n" resp := &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(streamBody)), } - h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}) + policy := util.ToolChoicePolicy{ + Mode: util.ToolChoiceRequired, + Allowed: map[string]struct{}{"read_file": {}}, + } + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, policy, "") - completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed") - if !ok { - t.Fatalf("expected response.completed event, body=%s", rec.Body.String()) + body := rec.Body.String() + if !strings.Contains(body, "event: response.failed") { + t.Fatalf("expected response.failed event for required tool_choice violation, body=%s", body) } - responseObj, _ := completed["response"].(map[string]any) - output, _ := responseObj["output"].([]any) - hasFunctionCall := false - for _, item := range output { - m, _ := item.(map[string]any) - if m != nil && m["type"] == "function_call" { - hasFunctionCall = true - break - } + if strings.Contains(body, "event: response.completed") { + t.Fatalf("did not expect response.completed after failure, body=%s", body) } - if !hasFunctionCall { - t.Fatalf("expected completed output to include function_call when mixed prose contains tool_calls payload, output=%#v", output) +} + +func TestHandleResponsesStreamRejectsUnknownToolName(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine(`{"tool_calls":[{"name":"not_in_schema","input":{"q":"go"}}]}`) + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "") + body := rec.Body.String() + if strings.Contains(body, "event: response.function_call_arguments.done") { + t.Fatalf("did not expect function_call events for unknown tool, body=%s", body) + } +} + +func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) { + h := &Handler{} + rec := httptest.NewRecorder() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader( + `data: {"p":"response/content","v":"plain text only"}` + "\n" + + `data: [DONE]` + "\n", + )), + } + policy := util.ToolChoicePolicy{ + Mode: util.ToolChoiceRequired, + Allowed: map[string]struct{}{"read_file": {}}, + } + + h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, []string{"read_file"}, policy, "") + if rec.Code != http.StatusUnprocessableEntity { + t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String()) + } + out := decodeJSONBody(t, rec.Body.String()) + errObj, _ := out["error"].(map[string]any) + if asString(errObj["code"]) != "tool_choice_violation" { + t.Fatalf("expected code=tool_choice_violation, got %#v", out) } } diff --git a/internal/adapter/openai/standard_request.go b/internal/adapter/openai/standard_request.go index 7683ee7..1ba957c 100644 --- a/internal/adapter/openai/standard_request.go +++ b/internal/adapter/openai/standard_request.go @@ -23,7 +23,8 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID if responseModel == "" { responseModel = resolvedModel } - finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID) + toolPolicy := util.DefaultToolChoicePolicy() + finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy) passThrough := collectOpenAIChatPassThrough(req) return util.StandardRequest{ @@ -34,6 +35,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID Messages: messagesRaw, FinalPrompt: finalPrompt, ToolNames: toolNames, + ToolChoice: toolPolicy, Stream: util.ToBool(req["stream"]), Thinking: thinkingEnabled, Search: searchEnabled, @@ -67,7 +69,17 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra if len(messagesRaw) == 0 { return util.StandardRequest{}, fmt.Errorf("Request must include 'input' or 'messages'.") } - finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID) + toolPolicy, err := parseToolChoicePolicy(req["tool_choice"], req["tools"]) + if err != nil { + return util.StandardRequest{}, err + } + finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy) + if toolPolicy.IsNone() { + toolNames = nil + toolPolicy.Allowed = nil + } else { + toolPolicy.Allowed = namesToSet(toolNames) + } passThrough := collectOpenAIChatPassThrough(req) return util.StandardRequest{ @@ -78,6 +90,7 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra Messages: messagesRaw, FinalPrompt: finalPrompt, ToolNames: toolNames, + ToolChoice: toolPolicy, Stream: util.ToBool(req["stream"]), Thinking: thinkingEnabled, Search: searchEnabled, @@ -102,3 +115,212 @@ func collectOpenAIChatPassThrough(req map[string]any) map[string]any { } return out } + +func parseToolChoicePolicy(toolChoiceRaw any, toolsRaw any) (util.ToolChoicePolicy, error) { + policy := util.DefaultToolChoicePolicy() + declaredNames := extractDeclaredToolNames(toolsRaw) + declaredSet := namesToSet(declaredNames) + if len(declaredNames) > 0 { + policy.Allowed = declaredSet + } + + if toolChoiceRaw == nil { + return policy, nil + } + + switch v := toolChoiceRaw.(type) { + case string: + switch strings.ToLower(strings.TrimSpace(v)) { + case "", "auto": + policy.Mode = util.ToolChoiceAuto + case "none": + policy.Mode = util.ToolChoiceNone + policy.Allowed = nil + case "required": + policy.Mode = util.ToolChoiceRequired + default: + return util.ToolChoicePolicy{}, fmt.Errorf("Unsupported tool_choice: %q", v) + } + case map[string]any: + allowedOverride, hasAllowedOverride, err := parseAllowedToolNames(v["allowed_tools"]) + if err != nil { + return util.ToolChoicePolicy{}, err + } + if hasAllowedOverride { + filtered := make([]string, 0, len(allowedOverride)) + for _, name := range allowedOverride { + if _, ok := declaredSet[name]; !ok { + return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice.allowed_tools contains undeclared tool %q", name) + } + filtered = append(filtered, name) + } + policy.Allowed = namesToSet(filtered) + } + + typ := strings.ToLower(strings.TrimSpace(asString(v["type"]))) + switch typ { + case "", "auto": + if hasFunctionSelector(v) { + name, err := parseForcedToolName(v) + if err != nil { + return util.ToolChoicePolicy{}, err + } + policy.Mode = util.ToolChoiceForced + policy.ForcedName = name + policy.Allowed = namesToSet([]string{name}) + } else { + policy.Mode = util.ToolChoiceAuto + } + case "none": + policy.Mode = util.ToolChoiceNone + policy.Allowed = nil + case "required": + policy.Mode = util.ToolChoiceRequired + case "function": + name, err := parseForcedToolName(v) + if err != nil { + return util.ToolChoicePolicy{}, err + } + policy.Mode = util.ToolChoiceForced + policy.ForcedName = name + policy.Allowed = namesToSet([]string{name}) + default: + return util.ToolChoicePolicy{}, fmt.Errorf("Unsupported tool_choice.type: %q", typ) + } + default: + return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice must be a string or object") + } + + if policy.Mode == util.ToolChoiceRequired || policy.Mode == util.ToolChoiceForced { + if len(declaredNames) == 0 { + return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice=%s requires non-empty tools.", policy.Mode) + } + } + if policy.Mode == util.ToolChoiceForced { + if _, ok := declaredSet[policy.ForcedName]; !ok { + return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice forced function %q is not declared in tools", policy.ForcedName) + } + } + if len(policy.Allowed) == 0 && (policy.Mode == util.ToolChoiceRequired || policy.Mode == util.ToolChoiceForced) { + return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice policy resolved to empty allowed tool set") + } + return policy, nil +} + +func parseForcedToolName(v map[string]any) (string, error) { + if name := strings.TrimSpace(asString(v["name"])); name != "" { + return name, nil + } + if fn, ok := v["function"].(map[string]any); ok { + if name := strings.TrimSpace(asString(fn["name"])); name != "" { + return name, nil + } + } + return "", fmt.Errorf("tool_choice function requires name") +} + +func parseAllowedToolNames(raw any) ([]string, bool, error) { + if raw == nil { + return nil, false, nil + } + collectName := func(v any) string { + if name := strings.TrimSpace(asString(v)); name != "" { + return name + } + if m, ok := v.(map[string]any); ok { + if name := strings.TrimSpace(asString(m["name"])); name != "" { + return name + } + if fn, ok := m["function"].(map[string]any); ok { + if name := strings.TrimSpace(asString(fn["name"])); name != "" { + return name + } + } + } + return "" + } + + names := []string{} + switch x := raw.(type) { + case []any: + for _, item := range x { + name := collectName(item) + if name == "" { + return nil, true, fmt.Errorf("tool_choice.allowed_tools contains invalid item") + } + names = append(names, name) + } + case []string: + for _, item := range x { + name := strings.TrimSpace(item) + if name == "" { + return nil, true, fmt.Errorf("tool_choice.allowed_tools contains empty name") + } + names = append(names, name) + } + default: + return nil, true, fmt.Errorf("tool_choice.allowed_tools must be an array") + } + + if len(names) == 0 { + return nil, true, fmt.Errorf("tool_choice.allowed_tools must not be empty") + } + return names, true, nil +} + +func hasFunctionSelector(v map[string]any) bool { + if strings.TrimSpace(asString(v["name"])) != "" { + return true + } + if fn, ok := v["function"].(map[string]any); ok { + return strings.TrimSpace(asString(fn["name"])) != "" + } + return false +} + +func extractDeclaredToolNames(toolsRaw any) []string { + tools, ok := toolsRaw.([]any) + if !ok || len(tools) == 0 { + return nil + } + out := make([]string, 0, len(tools)) + seen := map[string]struct{}{} + for _, t := range tools { + tool, ok := t.(map[string]any) + if !ok { + continue + } + fn, _ := tool["function"].(map[string]any) + if len(fn) == 0 { + fn = tool + } + name := strings.TrimSpace(asString(fn["name"])) + if name == "" { + continue + } + if _, ok := seen[name]; ok { + continue + } + seen[name] = struct{}{} + out = append(out, name) + } + return out +} + +func namesToSet(names []string) map[string]struct{} { + if len(names) == 0 { + return nil + } + out := make(map[string]struct{}, len(names)) + for _, name := range names { + trimmed := strings.TrimSpace(name) + if trimmed == "" { + continue + } + out[trimmed] = struct{}{} + } + if len(out) == 0 { + return nil + } + return out +} diff --git a/internal/adapter/openai/standard_request_test.go b/internal/adapter/openai/standard_request_test.go index a876364..60b0922 100644 --- a/internal/adapter/openai/standard_request_test.go +++ b/internal/adapter/openai/standard_request_test.go @@ -4,6 +4,7 @@ import ( "testing" "ds2api/internal/config" + "ds2api/internal/util" ) func newEmptyStoreForNormalizeTest(t *testing.T) *config.Store { @@ -58,3 +59,95 @@ func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) { t.Fatalf("expected 2 normalized messages, got %d", len(n.Messages)) } } + +func TestNormalizeOpenAIResponsesRequestToolChoiceRequired(t *testing.T) { + store := newEmptyStoreForNormalizeTest(t) + req := map[string]any{ + "model": "gpt-4o", + "input": "ping", + "tools": []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "search", + "parameters": map[string]any{ + "type": "object", + }, + }, + }, + }, + "tool_choice": "required", + } + n, err := normalizeOpenAIResponsesRequest(store, req, "") + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + if n.ToolChoice.Mode != util.ToolChoiceRequired { + t.Fatalf("expected tool choice mode required, got %q", n.ToolChoice.Mode) + } + if len(n.ToolNames) != 1 || n.ToolNames[0] != "search" { + t.Fatalf("unexpected tool names: %#v", n.ToolNames) + } +} + +func TestNormalizeOpenAIResponsesRequestToolChoiceForcedFunction(t *testing.T) { + store := newEmptyStoreForNormalizeTest(t) + req := map[string]any{ + "model": "gpt-4o", + "input": "ping", + "tools": []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "search", + }, + }, + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "read_file", + }, + }, + }, + "tool_choice": map[string]any{ + "type": "function", + "name": "read_file", + }, + } + n, err := normalizeOpenAIResponsesRequest(store, req, "") + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + if n.ToolChoice.Mode != util.ToolChoiceForced { + t.Fatalf("expected tool choice mode forced, got %q", n.ToolChoice.Mode) + } + if n.ToolChoice.ForcedName != "read_file" { + t.Fatalf("expected forced tool name read_file, got %q", n.ToolChoice.ForcedName) + } + if len(n.ToolNames) != 1 || n.ToolNames[0] != "read_file" { + t.Fatalf("expected filtered tool names [read_file], got %#v", n.ToolNames) + } +} + +func TestNormalizeOpenAIResponsesRequestToolChoiceForcedUndeclaredFails(t *testing.T) { + store := newEmptyStoreForNormalizeTest(t) + req := map[string]any{ + "model": "gpt-4o", + "input": "ping", + "tools": []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "search", + }, + }, + }, + "tool_choice": map[string]any{ + "type": "function", + "name": "read_file", + }, + } + if _, err := normalizeOpenAIResponsesRequest(store, req, ""); err == nil { + t.Fatalf("expected forced undeclared tool to fail") + } +} diff --git a/internal/format/openai/render_responses.go b/internal/format/openai/render_responses.go index 4fd17c3..1839977 100644 --- a/internal/format/openai/render_responses.go +++ b/internal/format/openai/render_responses.go @@ -27,12 +27,7 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex "text": finalThinking, }) } - formatted := util.FormatOpenAIToolCalls(detected) - output = append(output, toResponsesFunctionCallItems(formatted)...) - output = append(output, map[string]any{ - "type": "tool_calls", - "tool_calls": formatted, - }) + output = append(output, toResponsesFunctionCallItems(detected)...) } else { content := make([]any, 0, 2) if finalThinking != "" { @@ -70,32 +65,23 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex } } -func toResponsesFunctionCallItems(toolCalls []map[string]any) []any { +func toResponsesFunctionCallItems(toolCalls []util.ParsedToolCall) []any { if len(toolCalls) == 0 { return nil } out := make([]any, 0, len(toolCalls)) for _, tc := range toolCalls { - callID, _ := tc["id"].(string) - if strings.TrimSpace(callID) == "" { - callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "") - } - name := "" - args := "{}" - if fn, ok := tc["function"].(map[string]any); ok { - if n, _ := fn["name"].(string); strings.TrimSpace(n) != "" { - name = n - } - if a, _ := fn["arguments"].(string); strings.TrimSpace(a) != "" { - args = a - } + if strings.TrimSpace(tc.Name) == "" { + continue } + argsBytes, _ := json.Marshal(tc.Input) + args := normalizeJSONString(string(argsBytes)) out = append(out, map[string]any{ "id": "fc_" + strings.ReplaceAll(uuid.NewString(), "-", ""), "type": "function_call", - "call_id": callID, - "name": name, - "arguments": normalizeJSONString(args), + "call_id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + "name": tc.Name, + "arguments": args, "status": "completed", }) } diff --git a/internal/format/openai/render_stream_events.go b/internal/format/openai/render_stream_events.go index cc62604..40e8c2c 100644 --- a/internal/format/openai/render_stream_events.go +++ b/internal/format/openai/render_stream_events.go @@ -1,5 +1,7 @@ package openai +import "strings" + func BuildResponsesCreatedPayload(responseID, model string) map[string]any { return map[string]any{ "type": "response.created", @@ -11,6 +13,52 @@ func BuildResponsesCreatedPayload(responseID, model string) map[string]any { } } +func BuildResponsesOutputItemAddedPayload(responseID, itemID string, outputIndex int, item map[string]any) map[string]any { + return map[string]any{ + "type": "response.output_item.added", + "id": responseID, + "response_id": responseID, + "output_index": outputIndex, + "item_id": itemID, + "item": item, + } +} + +func BuildResponsesOutputItemDonePayload(responseID, itemID string, outputIndex int, item map[string]any) map[string]any { + return map[string]any{ + "type": "response.output_item.done", + "id": responseID, + "response_id": responseID, + "output_index": outputIndex, + "item_id": itemID, + "item": item, + } +} + +func BuildResponsesContentPartAddedPayload(responseID, itemID string, outputIndex, contentIndex int, part map[string]any) map[string]any { + return map[string]any{ + "type": "response.content_part.added", + "id": responseID, + "response_id": responseID, + "item_id": itemID, + "output_index": outputIndex, + "content_index": contentIndex, + "part": part, + } +} + +func BuildResponsesContentPartDonePayload(responseID, itemID string, outputIndex, contentIndex int, part map[string]any) map[string]any { + return map[string]any{ + "type": "response.content_part.done", + "id": responseID, + "response_id": responseID, + "item_id": itemID, + "output_index": outputIndex, + "content_index": contentIndex, + "part": part, + } +} + func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any { return map[string]any{ "type": "response.output_text.delta", @@ -29,48 +77,6 @@ func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]an } } -func BuildResponsesReasoningTextDeltaPayload(responseID, itemID string, outputIndex, contentIndex int, delta string) map[string]any { - return map[string]any{ - "type": "response.reasoning_text.delta", - "id": responseID, - "response_id": responseID, - "item_id": itemID, - "output_index": outputIndex, - "content_index": contentIndex, - "delta": delta, - } -} - -func BuildResponsesReasoningTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any { - return map[string]any{ - "type": "response.reasoning_text.done", - "id": responseID, - "response_id": responseID, - "item_id": itemID, - "output_index": outputIndex, - "content_index": contentIndex, - "text": text, - } -} - -func BuildResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any { - return map[string]any{ - "type": "response.output_tool_call.delta", - "id": responseID, - "response_id": responseID, - "tool_calls": toolCalls, - } -} - -func BuildResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any { - return map[string]any{ - "type": "response.output_tool_call.done", - "id": responseID, - "response_id": responseID, - "tool_calls": toolCalls, - } -} - func BuildResponsesFunctionCallArgumentsDeltaPayload(responseID, itemID string, outputIndex int, callID, delta string) map[string]any { return map[string]any{ "type": "response.function_call_arguments.delta", @@ -96,6 +102,27 @@ func BuildResponsesFunctionCallArgumentsDonePayload(responseID, itemID string, o } } +func BuildResponsesFailedPayload(responseID, model, message, code string) map[string]any { + code = strings.TrimSpace(code) + if code == "" { + code = "api_error" + } + return map[string]any{ + "type": "response.failed", + "id": responseID, + "response_id": responseID, + "object": "response", + "model": model, + "status": "failed", + "error": map[string]any{ + "message": message, + "type": "invalid_request_error", + "code": code, + "param": nil, + }, + } +} + func BuildResponsesCompletedPayload(response map[string]any) map[string]any { responseID, _ := response["id"].(string) return map[string]any{ diff --git a/internal/format/openai/render_test.go b/internal/format/openai/render_test.go index e3bf0dd..b95e739 100644 --- a/internal/format/openai/render_test.go +++ b/internal/format/openai/render_test.go @@ -21,8 +21,8 @@ func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) { } output, _ := obj["output"].([]any) - if len(output) != 2 { - t.Fatalf("expected function_call + tool_calls wrapper, got %#v", obj["output"]) + if len(output) != 1 { + t.Fatalf("expected function_call output only, got %#v", obj["output"]) } first, _ := output[0].(map[string]any) @@ -32,35 +32,10 @@ func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) { if first["call_id"] == "" { t.Fatalf("expected function_call item to have call_id, got %#v", first) } - second, _ := output[1].(map[string]any) - if second["type"] != "tool_calls" { - t.Fatalf("expected second output item type tool_calls, got %#v", second["type"]) + if first["name"] != "search" { + t.Fatalf("unexpected function name: %#v", first["name"]) } - var toolCalls []map[string]any - switch v := second["tool_calls"].(type) { - case []map[string]any: - toolCalls = v - case []any: - toolCalls = make([]map[string]any, 0, len(v)) - for _, item := range v { - m, _ := item.(map[string]any) - if m != nil { - toolCalls = append(toolCalls, m) - } - } - } - if len(toolCalls) != 1 { - t.Fatalf("expected one tool call, got %#v", second["tool_calls"]) - } - tc := toolCalls[0] - if tc["type"] != "function" || tc["id"] == "" { - t.Fatalf("unexpected tool call shape: %#v", tc) - } - fn, _ := tc["function"].(map[string]any) - if fn["name"] != "search" { - t.Fatalf("unexpected function name: %#v", fn["name"]) - } - argsRaw, _ := fn["arguments"].(string) + argsRaw, _ := first["arguments"].(string) var args map[string]any if err := json.Unmarshal([]byte(argsRaw), &args); err != nil { t.Fatalf("arguments should be valid json string, got=%q err=%v", argsRaw, err) @@ -86,8 +61,8 @@ func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T) } output, _ := obj["output"].([]any) - if len(output) != 2 { - t.Fatalf("expected function_call + tool_calls wrapper, got %#v", obj["output"]) + if len(output) != 1 { + t.Fatalf("expected function_call output only, got %#v", obj["output"]) } first, _ := output[0].(map[string]any) if first["type"] != "function_call" { @@ -163,8 +138,8 @@ func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) { ) output, _ := obj["output"].([]any) - if len(output) != 3 { - t.Fatalf("expected reasoning + function_call + tool_calls outputs, got %#v", obj["output"]) + if len(output) != 2 { + t.Fatalf("expected reasoning + function_call outputs, got %#v", obj["output"]) } first, _ := output[0].(map[string]any) if first["type"] != "reasoning" { @@ -174,8 +149,4 @@ func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) { if second["type"] != "function_call" { t.Fatalf("expected second output function_call, got %#v", second["type"]) } - third, _ := output[2].(map[string]any) - if third["type"] != "tool_calls" { - t.Fatalf("expected third output tool_calls, got %#v", third["type"]) - } } diff --git a/internal/util/standard_request.go b/internal/util/standard_request.go index af73acf..84a4c98 100644 --- a/internal/util/standard_request.go +++ b/internal/util/standard_request.go @@ -8,12 +8,48 @@ type StandardRequest struct { Messages []any FinalPrompt string ToolNames []string + ToolChoice ToolChoicePolicy Stream bool Thinking bool Search bool PassThrough map[string]any } +type ToolChoiceMode string + +const ( + ToolChoiceAuto ToolChoiceMode = "auto" + ToolChoiceNone ToolChoiceMode = "none" + ToolChoiceRequired ToolChoiceMode = "required" + ToolChoiceForced ToolChoiceMode = "forced" +) + +type ToolChoicePolicy struct { + Mode ToolChoiceMode + ForcedName string + Allowed map[string]struct{} +} + +func DefaultToolChoicePolicy() ToolChoicePolicy { + return ToolChoicePolicy{Mode: ToolChoiceAuto} +} + +func (p ToolChoicePolicy) IsNone() bool { + return p.Mode == ToolChoiceNone +} + +func (p ToolChoicePolicy) IsRequired() bool { + return p.Mode == ToolChoiceRequired || p.Mode == ToolChoiceForced +} + +func (p ToolChoicePolicy) Allows(name string) bool { + if len(p.Allowed) == 0 { + return true + } + _, ok := p.Allowed[name] + return ok +} + func (r StandardRequest) CompletionPayload(sessionID string) map[string]any { payload := map[string]any{ "chat_session_id": sessionID, diff --git a/internal/util/toolcalls_parse.go b/internal/util/toolcalls_parse.go index ab9fe84..2b8610e 100644 --- a/internal/util/toolcalls_parse.go +++ b/internal/util/toolcalls_parse.go @@ -10,38 +10,62 @@ type ParsedToolCall struct { Input map[string]any `json:"input"` } +type ToolCallParseResult struct { + Calls []ParsedToolCall + SawToolCallSyntax bool + RejectedByPolicy bool + RejectedToolNames []string +} + func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall { + return ParseToolCallsDetailed(text, availableToolNames).Calls +} + +func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallParseResult { + result := ToolCallParseResult{} if strings.TrimSpace(text) == "" { - return nil + return result } text = stripFencedCodeBlocks(text) if strings.TrimSpace(text) == "" { - return nil + return result } + result.SawToolCallSyntax = strings.Contains(strings.ToLower(text), "tool_calls") candidates := buildToolCallCandidates(text) var parsed []ParsedToolCall for _, candidate := range candidates { if tc := parseToolCallsPayload(candidate); len(tc) > 0 { parsed = tc + result.SawToolCallSyntax = true break } } if len(parsed) == 0 { - return nil + return result } - return filterToolCalls(parsed, availableToolNames) + calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames) + result.Calls = calls + result.RejectedToolNames = rejectedNames + result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0 + return result } func ParseStandaloneToolCalls(text string, availableToolNames []string) []ParsedToolCall { + return ParseStandaloneToolCallsDetailed(text, availableToolNames).Calls +} + +func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string) ToolCallParseResult { + result := ToolCallParseResult{} trimmed := strings.TrimSpace(text) if trimmed == "" { - return nil + return result } if looksLikeToolExampleContext(trimmed) { - return nil + return result } + result.SawToolCallSyntax = strings.Contains(strings.ToLower(trimmed), "tool_calls") candidates := []string{trimmed} for _, candidate := range candidates { candidate = strings.TrimSpace(candidate) @@ -52,24 +76,31 @@ func ParseStandaloneToolCalls(text string, availableToolNames []string) []Parsed continue } if parsed := parseToolCallsPayload(candidate); len(parsed) > 0 { - return filterToolCalls(parsed, availableToolNames) + result.SawToolCallSyntax = true + calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames) + result.Calls = calls + result.RejectedToolNames = rejectedNames + result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0 + return result } } - return nil + return result } -func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []ParsedToolCall { +func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []string) ([]ParsedToolCall, []string) { allowed := map[string]struct{}{} for _, name := range availableToolNames { allowed[name] = struct{}{} } out := make([]ParsedToolCall, 0, len(parsed)) + rejectedSet := map[string]struct{}{} for _, tc := range parsed { if tc.Name == "" { continue } if len(allowed) > 0 { if _, ok := allowed[tc.Name]; !ok { + rejectedSet[tc.Name] = struct{}{} continue } } @@ -78,21 +109,11 @@ func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []Par } out = append(out, tc) } - // If the model clearly emitted tool_calls JSON but all names are outside the - // declared set, keep the parsed calls as a fallback so upper layers can still - // intercept structured tool output instead of leaking raw JSON to users. - if len(out) == 0 && len(parsed) > 0 { - for _, tc := range parsed { - if tc.Name == "" { - continue - } - if tc.Input == nil { - tc.Input = map[string]any{} - } - out = append(out, tc) - } + rejected := make([]string, 0, len(rejectedSet)) + for name := range rejectedSet { + rejected = append(rejected, name) } - return out + return out, rejected } func parseToolCallsPayload(payload string) []ParsedToolCall { diff --git a/internal/util/toolcalls_test.go b/internal/util/toolcalls_test.go index f7c82d2..b102b41 100644 --- a/internal/util/toolcalls_test.go +++ b/internal/util/toolcalls_test.go @@ -38,14 +38,25 @@ func TestParseToolCallsWithFunctionArgumentsString(t *testing.T) { } } -func TestParseToolCallsKeepsUnknownAsFallback(t *testing.T) { +func TestParseToolCallsRejectsUnknownToolName(t *testing.T) { text := `{"tool_calls":[{"name":"unknown","input":{}}]}` calls := ParseToolCalls(text, []string{"search"}) - if len(calls) != 1 { - t.Fatalf("expected fallback 1 call, got %d", len(calls)) + if len(calls) != 0 { + t.Fatalf("expected unknown tool to be rejected, got %#v", calls) } - if calls[0].Name != "unknown" { - t.Fatalf("unexpected name: %s", calls[0].Name) +} + +func TestParseToolCallsDetailedMarksPolicyRejection(t *testing.T) { + text := `{"tool_calls":[{"name":"unknown","input":{}}]}` + res := ParseToolCallsDetailed(text, []string{"search"}) + if !res.SawToolCallSyntax { + t.Fatalf("expected SawToolCallSyntax=true, got %#v", res) + } + if !res.RejectedByPolicy { + t.Fatalf("expected RejectedByPolicy=true, got %#v", res) + } + if len(res.Calls) != 0 { + t.Fatalf("expected no calls after policy rejection, got %#v", res.Calls) } } diff --git a/tests/compat/expected/toolcalls_unknown_name.json b/tests/compat/expected/toolcalls_unknown_name.json index 8f79875..97646bf 100644 --- a/tests/compat/expected/toolcalls_unknown_name.json +++ b/tests/compat/expected/toolcalls_unknown_name.json @@ -1,5 +1,3 @@ { - "calls": [ - {"name": "unknown_tool", "input": {"x": 1}} - ] + "calls": [] }