feat: Improve OpenAI tool call handling by passing unknown tool calls as content and filtering streamed tool calls by schema.

This commit is contained in:
CJACK
2026-02-22 19:33:52 +08:00
parent 312728c8b6
commit ae7dce0b32
26 changed files with 1109 additions and 501 deletions

View File

@@ -286,8 +286,10 @@ OpenAI Responses-style endpoint, accepting either `input` or `messages`.
| `instructions` | string | ❌ | Prepended as a system message |
| `stream` | boolean | ❌ | Default `false` |
| `tools` | array | ❌ | Same tool detection/translation policy as chat |
| `tool_choice` | string/object | ❌ | Supports `auto`/`none`/`required` and forced function selection (`{"type":"function","name":"..."}`) |
**Non-stream**: Returns a standard `response` object with an ID like `resp_xxx`, and stores it in in-memory TTL cache.
If `tool_choice=required` and no valid tool call is produced, DS2API returns HTTP `422` (`error.code=tool_choice_violation`).
**Stream (SSE)**: minimal event sequence:
@@ -295,11 +297,26 @@ OpenAI Responses-style endpoint, accepting either `input` or `messages`.
event: response.created
data: {"type":"response.created","id":"resp_xxx","status":"in_progress",...}
event: response.output_item.added
data: {"type":"response.output_item.added","response_id":"resp_xxx","item":{"type":"message|function_call",...},...}
event: response.content_part.added
data: {"type":"response.content_part.added","response_id":"resp_xxx","part":{"type":"output_text",...},...}
event: response.output_text.delta
data: {"type":"response.output_text.delta","id":"resp_xxx","delta":"..."}
event: response.output_tool_call.delta
data: {"type":"response.output_tool_call.delta","id":"resp_xxx","tool_calls":[...]}
event: response.function_call_arguments.delta
data: {"type":"response.function_call_arguments.delta","response_id":"resp_xxx","call_id":"call_xxx","delta":"..."}
event: response.function_call_arguments.done
data: {"type":"response.function_call_arguments.done","response_id":"resp_xxx","call_id":"call_xxx","name":"tool","arguments":"{...}"}
event: response.content_part.done
data: {"type":"response.content_part.done","response_id":"resp_xxx",...}
event: response.output_item.done
data: {"type":"response.output_item.done","response_id":"resp_xxx","item":{"type":"message|function_call",...},...}
event: response.completed
data: {"type":"response.completed","response":{...}}
@@ -307,6 +324,9 @@ data: {"type":"response.completed","response":{...}}
data: [DONE]
```
If `tool_choice=required` is violated in stream mode, DS2API emits `response.failed` then `[DONE]` (no `response.completed`).
Unknown tool names (outside declared `tools`) are rejected and will not be emitted as valid tool calls.
### `GET /v1/responses/{response_id}`
Business auth required. Fetches cached responses created by `POST /v1/responses` (caller-scoped; only the same key/token can read).

24
API.md
View File

@@ -286,8 +286,10 @@ OpenAI Responses 风格接口,兼容 `input` 或 `messages`。
| `instructions` | string | ❌ | 自动前置为 system 消息 |
| `stream` | boolean | ❌ | 默认 `false` |
| `tools` | array | ❌ | 与 chat 同样的工具识别与转译策略 |
| `tool_choice` | string/object | ❌ | 支持 `auto`/`none`/`required` 与强制函数(`{"type":"function","name":"..."}` |
**非流式响应**:返回标准 `response` 对象,`id` 形如 `resp_xxx`,并写入内存 TTL 存储。
`tool_choice=required` 且未产出有效工具调用时,返回 HTTP `422``error.code=tool_choice_violation`)。
**流式响应SSE**:最小事件序列如下。
@@ -295,11 +297,26 @@ OpenAI Responses 风格接口,兼容 `input` 或 `messages`。
event: response.created
data: {"type":"response.created","id":"resp_xxx","status":"in_progress",...}
event: response.output_item.added
data: {"type":"response.output_item.added","response_id":"resp_xxx","item":{"type":"message|function_call",...},...}
event: response.content_part.added
data: {"type":"response.content_part.added","response_id":"resp_xxx","part":{"type":"output_text",...},...}
event: response.output_text.delta
data: {"type":"response.output_text.delta","id":"resp_xxx","delta":"..."}
event: response.output_tool_call.delta
data: {"type":"response.output_tool_call.delta","id":"resp_xxx","tool_calls":[...]}
event: response.function_call_arguments.delta
data: {"type":"response.function_call_arguments.delta","response_id":"resp_xxx","call_id":"call_xxx","delta":"..."}
event: response.function_call_arguments.done
data: {"type":"response.function_call_arguments.done","response_id":"resp_xxx","call_id":"call_xxx","name":"tool","arguments":"{...}"}
event: response.content_part.done
data: {"type":"response.content_part.done","response_id":"resp_xxx",...}
event: response.output_item.done
data: {"type":"response.output_item.done","response_id":"resp_xxx","item":{"type":"message|function_call",...},...}
event: response.completed
data: {"type":"response.completed","response":{...}}
@@ -307,6 +324,9 @@ data: {"type":"response.completed","response":{...}}
data: [DONE]
```
流式场景下若 `tool_choice=required` 违规,会返回 `response.failed` 后结束(不再发送 `response.completed`)。
未在 `tools` 声明中的工具名会被严格拒绝,不会作为有效 tool call 下发。
### `GET /v1/responses/{response_id}`
需要业务鉴权。查询 `POST /v1/responses` 生成并缓存的 response 对象(按调用方鉴权隔离,仅同一 key/token 可读取)。

View File

@@ -8,13 +8,13 @@
语言 / Language: [中文](README.MD) | [English](README.en.md)
将 DeepSeek Web 对话能力转换为 OpenAIClaude 兼容 API。后端为 **Go 全量实现**,前端为 React WebUI 管理台(源码在 `webui/`,部署时自动构建到 `static/admin`)。
将 DeepSeek Web 对话能力转换为 OpenAIClaude 与 Gemini 兼容 API。后端为 **Go 全量实现**,前端为 React WebUI 管理台(源码在 `webui/`,部署时自动构建到 `static/admin`)。
## 架构概览
```mermaid
flowchart LR
Client["🖥️ 客户端\n(OpenAI / Claude 兼容)"]
Client["🖥️ 客户端\n(OpenAI / Claude / Gemini 兼容)"]
subgraph DS2API["DS2API 服务"]
direction TB
@@ -24,6 +24,7 @@ flowchart LR
subgraph Adapters["适配器层"]
OA["OpenAI 适配器\n/v1/*"]
CA["Claude 适配器\n/anthropic/*"]
GA["Gemini 适配器\n/v1beta/models/*"]
end
subgraph Support["支撑模块"]
@@ -38,11 +39,11 @@ flowchart LR
DS["☁️ DeepSeek API"]
Client -- "请求" --> CORS --> Auth
Auth --> OA & CA
OA & CA -- "调用" --> DS
Auth --> OA & CA & GA
OA & CA & GA -- "调用" --> DS
Auth --> Admin
OA & CA -. "轮询选账号" .-> Pool
OA & CA -. "计算 PoW" .-> PoW
OA & CA & GA -. "轮询选账号" .-> Pool
OA & CA & GA -. "计算 PoW" .-> PoW
DS -- "响应" --> Client
```
@@ -55,12 +56,13 @@ flowchart LR
| 能力 | 说明 |
| --- | --- |
| OpenAI 兼容 | `GET /v1/models`、`GET /v1/models/{id}`、`POST /v1/chat/completions`、`POST /v1/responses`、`GET /v1/responses/{response_id}`、`POST /v1/embeddings` |
| Claude 兼容 | `GET /anthropic/v1/models`、`POST /anthropic/v1/messages`、`POST /anthropic/v1/messages/count_tokens` |
| Claude 兼容 | `GET /anthropic/v1/models`、`POST /anthropic/v1/messages`、`POST /anthropic/v1/messages/count_tokens`(及快捷路径 `/v1/messages`、`/messages` |
| Gemini 兼容 | `POST /v1beta/models/{model}:generateContent`、`POST /v1beta/models/{model}:streamGenerateContent`(及 `/v1/models/{model}:*` 路径) |
| 多账号轮询 | 自动 token 刷新、邮箱/手机号双登录方式 |
| 并发队列控制 | 每账号 in-flight 上限 + 等待队列,动态计算建议并发值 |
| DeepSeek PoW | WASM 计算(`wazero`),无需外部 Node.js 依赖 |
| Tool Calling | 防泄漏处理:非代码块高置信特征识别、`delta.tool_calls` 早发、结构化增量输出 |
| Admin API | 配置管理、账号测试 / 批量测试、导入导出、Vercel 同步 |
| Admin API | 配置管理、运行时设置热更新、账号测试 / 批量测试、导入导出、Vercel 同步 |
| WebUI 管理台 | `/admin` 单页应用(中英文双语、深色模式) |
| 运维探针 | `GET /healthz`(存活)、`GET /readyz`(就绪) |
@@ -72,6 +74,7 @@ flowchart LR
| P0 | OpenAI SDKJS/Pythonchat + responses | ✅ |
| P0 | Vercel AI SDKopenai-compatible | ✅ |
| P0 | Anthropic SDKmessages | ✅ |
| P0 | Google Gemini SDKgenerateContent | ✅ |
| P1 | LangChain / LlamaIndex / OpenWebUIOpenAI 兼容接入) | ✅ |
| P2 | MCP 独立桥接层 | 规划中 |
@@ -97,6 +100,10 @@ flowchart LR
可通过配置中的 `claude_mapping` 或 `claude_model_mapping` 覆盖映射关系。
另外,`/anthropic/v1/models` 现已包含 Claude 1.x/2.x/3.x/4.x 历史模型 ID 与常见别名,便于旧客户端直接兼容。
### Gemini 接口
Gemini 适配器将模型名通过 `model_aliases` 或内置规则映射到 DeepSeek 原生模型,支持 `generateContent` 和 `streamGenerateContent` 两种调用方式,并完整支持 Tool Calling`functionDeclarations` → `functionCall` 输出)。
## 快速开始
### 通用第一步(所有部署方式)
@@ -249,6 +256,14 @@ cp opencode.json.example opencode.json
"claude_model_mapping": {
"fast": "deepseek-chat",
"slow": "deepseek-reasoner"
},
"admin": {
"jwt_expire_hours": 24
},
"runtime": {
"account_max_inflight": 2,
"account_max_queue": 0,
"global_max_inflight": 0
}
}
```
@@ -262,6 +277,8 @@ cp opencode.json.example opencode.json
- `responses.store_ttl_seconds``/v1/responses/{id}` 的内存缓存 TTL
- `embeddings.provider`embedding 提供方(当前内置 `deterministic/mock/builtin`
- `claude_model_mapping`:字典中 `fast`/`slow` 后缀映射到对应 DeepSeek 模型
- `admin`管理后台设置JWT 过期时间、密码哈希等),可通过 Admin Settings API 热更新
- `runtime`:运行时参数(并发限制、队列大小),可通过 Admin Settings API 热更新
### 环境变量
@@ -293,7 +310,7 @@ cp opencode.json.example opencode.json
## 鉴权模式
调用业务接口(`/v1/*`、`/anthropic/*`)时支持两种模式:
调用业务接口(`/v1/*`、`/anthropic/*`、Gemini 路由)时支持两种模式:
| 模式 | 说明 |
| --- | --- |
@@ -320,9 +337,10 @@ cp opencode.json.example opencode.json
当请求中带 `tools` 时DS2API 会做防泄漏处理:
1. 只在**非代码块上下文**启用 toolcall 特征识别(代码块示例不会触发)
2. 一旦命中高置信特征(`tool_calls` + `name` + `arguments/input` 起始)就立即输出 `delta.tool_calls`
3. 已确认的 toolcall JSON 片段不会泄漏到 `delta.content`
4. 前文/后文自然语言保持顺序透传,支持混合文本与增量参数输出
2. `responses` 流式严格使用官方 item 生命周期事件(`response.output_item.*`、`response.content_part.*`、`response.function_call_arguments.*`
3. 未在 `tools` 声明中的工具名会被严格拒绝,不会下发为有效 tool call
4. `responses` 支持并执行 `tool_choice``auto`/`none`/`required`/强制函数);`required` 违规时非流式返回 `422`,流式返回 `response.failed`
5. 仅在通过策略校验后才会发出有效工具调用事件,避免错误工具名进入客户端执行链
## 本地开发抓包工具
@@ -362,13 +380,20 @@ ds2api/
│ ├── account/ # 账号池与并发队列
│ ├── adapter/
│ │ ├── openai/ # OpenAI 兼容适配器(含 Tool Call 解析、Vercel 流式 prepare/release
│ │ ── claude/ # Claude 兼容适配器
── admin/ # Admin API handlers
│ │ ── claude/ # Claude 兼容适配器
│ └── gemini/ # Gemini 兼容适配器generateContent / streamGenerateContent
│ ├── admin/ # Admin API handlers含 Settings 热更新)
│ ├── auth/ # 鉴权与 JWT
│ ├── claudeconv/ # Claude 消息格式转换
│ ├── compat/ # 兼容性辅助
│ ├── config/ # 配置加载与热更新
│ ├── deepseek/ # DeepSeek API 客户端、PoW WASM
│ ├── devcapture/ # 开发抓包模块
│ ├── format/ # 输出格式化
│ ├── prompt/ # Prompt 构建
│ ├── server/ # HTTP 路由与中间件chi router
│ ├── sse/ # SSE 解析工具
│ ├── stream/ # 统一流式消费引擎
│ ├── util/ # 通用工具函数
│ └── webui/ # WebUI 静态文件托管与自动构建
├── webui/ # React WebUI 源码Vite + Tailwind

View File

@@ -320,9 +320,10 @@ Queue limit = DS2API_ACCOUNT_MAX_QUEUE (default = recommended concurrency)
When `tools` is present in the request, DS2API performs anti-leak handling:
1. Toolcall feature matching is enabled only in **non-code-block context** (fenced examples are ignored)
2. Once high-confidence features are matched (`tool_calls` + `name` + `arguments/input` start), `delta.tool_calls` is emitted immediately
3. Confirmed toolcall JSON fragments are never leaked into `delta.content`
4. Natural language before/after toolcalls keeps original order, with incremental argument output supported
2. In `responses` stream mode, tool calls follow official item lifecycle events (`response.output_item.*`, `response.content_part.*`, `response.function_call_arguments.*`)
3. Unknown tool names (outside declared `tools`) are rejected and are not emitted as valid tool calls
4. `tool_choice` is enforced on `responses` (`auto`/`none`/`required`/forced function); required violations return HTTP `422` (non-stream) or `response.failed` (stream)
5. Confirmed toolcall JSON fragments are never emitted as valid tool call events unless they pass policy checks
## Local Dev Packet Capture

View File

@@ -33,6 +33,7 @@ type chatStreamRuntime struct {
toolSieve toolStreamSieveState
streamToolCallIDs map[int]string
streamToolNames map[int]string
thinking strings.Builder
text strings.Builder
}
@@ -65,6 +66,7 @@ func newChatStreamRuntime(
bufferToolContent: bufferToolContent,
emitEarlyToolDeltas: emitEarlyToolDeltas,
streamToolCallIDs: map[int]string{},
streamToolNames: map[int]string{},
}
}
@@ -211,7 +213,11 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
if !s.emitEarlyToolDeltas {
continue
}
formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.toolNames, s.streamToolNames)
if len(filtered) == 0 {
continue
}
formatted := formatIncrementalStreamToolCallDeltas(filtered, s.streamToolCallIDs)
if len(formatted) == 0 {
continue
}

View File

@@ -3,11 +3,18 @@ package openai
import "net/http"
func writeOpenAIError(w http.ResponseWriter, status int, message string) {
writeOpenAIErrorWithCode(w, status, message, "")
}
func writeOpenAIErrorWithCode(w http.ResponseWriter, status int, message, code string) {
if code == "" {
code = openAIErrorCode(status)
}
writeJSON(w, status, map[string]any{
"error": map[string]any{
"message": message,
"type": openAIErrorType(status),
"code": openAIErrorCode(status),
"code": code,
"param": nil,
},
})

View File

@@ -10,9 +10,23 @@ import (
"ds2api/internal/util"
)
func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) {
func injectToolPrompt(messages []map[string]any, tools []any, policy util.ToolChoicePolicy) ([]map[string]any, []string) {
if policy.IsNone() {
return messages, nil
}
toolSchemas := make([]string, 0, len(tools))
names := make([]string, 0, len(tools))
isAllowed := func(name string) bool {
if strings.TrimSpace(name) == "" {
return false
}
if len(policy.Allowed) == 0 {
return true
}
_, ok := policy.Allowed[name]
return ok
}
for _, t := range tools {
tool, ok := t.(map[string]any)
if !ok {
@@ -25,8 +39,9 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any,
name, _ := fn["name"].(string)
desc, _ := fn["description"].(string)
schema, _ := fn["parameters"].(map[string]any)
if name == "" {
name = "unknown"
name = strings.TrimSpace(name)
if !isAllowed(name) {
continue
}
names = append(names, name)
if desc == "" {
@@ -39,6 +54,13 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any,
return messages, names
}
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block."
if policy.Mode == util.ToolChoiceRequired {
toolPrompt += "\n5) For this response, you MUST call at least one tool from the allowed list."
}
if policy.Mode == util.ToolChoiceForced && strings.TrimSpace(policy.ForcedName) != "" {
toolPrompt += "\n5) For this response, you MUST call exactly this tool name: " + strings.TrimSpace(policy.ForcedName)
toolPrompt += "\n6) Do not call any other tool."
}
for i := range messages {
if messages[i]["role"] == "system" {
@@ -85,6 +107,33 @@ func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]s
return out
}
func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, allowedNames []string, seenNames map[int]string) []toolCallDelta {
if len(deltas) == 0 {
return nil
}
allowed := namesToSet(allowedNames)
out := make([]toolCallDelta, 0, len(deltas))
for _, d := range deltas {
if d.Name != "" {
if len(allowed) > 0 {
if _, ok := allowed[d.Name]; !ok {
seenNames[d.Index] = "__blocked__"
continue
}
}
seenNames[d.Index] = d.Name
out = append(out, d)
continue
}
name := strings.TrimSpace(seenNames[d.Index])
if name == "" || name == "__blocked__" {
continue
}
out = append(out, d)
}
return out
}
func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any {
if len(calls) == 0 {
return nil

View File

@@ -181,7 +181,7 @@ func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) {
}
}
func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) {
func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
@@ -197,16 +197,16 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) {
out := decodeJSONBody(t, rec.Body.String())
choices, _ := out["choices"].([]any)
choice, _ := choices[0].(map[string]any)
if choice["finish_reason"] != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
if choice["finish_reason"] != "stop" {
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
}
msg, _ := choice["message"].(map[string]any)
if msg["content"] != nil {
t.Fatalf("expected content nil, got %#v", msg["content"])
if _, ok := msg["tool_calls"]; ok {
t.Fatalf("did not expect tool_calls for unknown schema name, got %#v", msg["tool_calls"])
}
toolCalls, _ := msg["tool_calls"].([]any)
if len(toolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %#v", msg["tool_calls"])
content, _ := msg["content"].(string)
if !strings.Contains(content, `"tool_calls"`) {
t.Fatalf("expected unknown tool json to pass through as text, got %#v", content)
}
}
@@ -375,7 +375,7 @@ func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing.
}
}
func TestHandleStreamUnknownToolStillIntercepted(t *testing.T) {
func TestHandleStreamUnknownToolNotIntercepted(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
@@ -390,29 +390,14 @@ func TestHandleStreamUnknownToolStillIntercepted(t *testing.T) {
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
if streamHasToolCallsDelta(frames) {
t.Fatalf("did not expect tool_calls delta for unknown schema name, body=%s", rec.Body.String())
}
foundToolIndex := false
for _, frame := range frames {
choices, _ := frame["choices"].([]any)
for _, item := range choices {
choice, _ := item.(map[string]any)
delta, _ := choice["delta"].(map[string]any)
toolCalls, _ := delta["tool_calls"].([]any)
for _, tc := range toolCalls {
tcm, _ := tc.(map[string]any)
if _, ok := tcm["index"].(float64); ok {
foundToolIndex = true
}
}
}
if !streamHasRawToolJSONContent(frames) {
t.Fatalf("expected raw tool_calls json to remain in content for unknown schema name: %s", rec.Body.String())
}
if !foundToolIndex {
t.Fatalf("expected stream tool_calls item with index, body=%s", rec.Body.String())
}
if streamHasRawToolJSONContent(frames) {
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
if streamFinishReason(frames) != "stop" {
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
}
}

View File

@@ -2,13 +2,18 @@ package openai
import (
"ds2api/internal/deepseek"
"ds2api/internal/util"
)
func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any, traceID string) (string, []string) {
return buildOpenAIFinalPromptWithPolicy(messagesRaw, toolsRaw, traceID, util.DefaultToolChoicePolicy())
}
func buildOpenAIFinalPromptWithPolicy(messagesRaw []any, toolsRaw any, traceID string, toolPolicy util.ToolChoicePolicy) (string, []string) {
messages := normalizeOpenAIMessagesForPrompt(messagesRaw, traceID)
toolNames := []string{}
if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 {
messages, toolNames = injectToolPrompt(messages, tools)
messages, toolNames = injectToolPrompt(messages, tools, toolPolicy)
}
return deepseek.MessagesPrepare(messages), toolNames
}

View File

@@ -73,6 +73,32 @@ func TestNormalizeResponsesInputAsMessagesFunctionCallOutput(t *testing.T) {
}
}
func TestNormalizeResponsesInputAsMessagesBackfillsToolResultNameFromCallID(t *testing.T) {
msgs := normalizeResponsesInputAsMessages([]any{
map[string]any{
"type": "function_call",
"call_id": "call_999",
"name": "search",
"arguments": `{"q":"golang"}`,
},
map[string]any{
"type": "function_call_output",
"call_id": "call_999",
"output": map[string]any{"ok": true},
},
})
if len(msgs) != 2 {
t.Fatalf("expected two messages, got %d", len(msgs))
}
toolMsg, _ := msgs[1].(map[string]any)
if toolMsg["role"] != "tool" {
t.Fatalf("expected tool role, got %#v", toolMsg)
}
if toolMsg["name"] != "search" {
t.Fatalf("expected tool name backfilled from call_id, got %#v", toolMsg["name"])
}
}
func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) {
msgs := normalizeResponsesInputAsMessages([]any{
map[string]any{

View File

@@ -11,10 +11,12 @@ import (
"github.com/google/uuid"
"ds2api/internal/auth"
"ds2api/internal/config"
"ds2api/internal/deepseek"
openaifmt "ds2api/internal/format/openai"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
"ds2api/internal/util"
)
func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) {
@@ -67,7 +69,8 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
return
}
stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req, requestTraceID(r))
traceID := requestTraceID(r)
stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req, traceID)
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return
@@ -96,13 +99,13 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
if stdReq.Stream {
h.handleResponsesStream(w, r, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
h.handleResponsesStream(w, r, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID)
return
}
h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames, stdReq.ToolChoice, traceID)
}
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string, toolChoice util.ToolChoicePolicy, traceID string) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
@@ -110,12 +113,26 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
return
}
result := sse.CollectStream(resp, thinkingEnabled, true)
textParsed := util.ParseToolCallsDetailed(result.Text, toolNames)
thinkingParsed := util.ParseToolCallsDetailed(result.Thinking, toolNames)
logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text")
logResponsesToolPolicyRejection(traceID, toolChoice, thinkingParsed, "thinking")
callCount := len(textParsed.Calls)
if callCount == 0 {
callCount = len(thinkingParsed.Calls)
}
if toolChoice.IsRequired() && callCount == 0 {
writeOpenAIErrorWithCode(w, http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation")
return
}
responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames)
h.getResponseStore().put(owner, responseID, responseObj)
writeJSON(w, http.StatusOK, responseObj)
}
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice util.ToolChoicePolicy, traceID string) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
@@ -148,6 +165,8 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
toolNames,
bufferToolContent,
emitEarlyToolDeltas,
toolChoice,
traceID,
func(obj map[string]any) {
h.getResponseStore().put(owner, responseID, obj)
},
@@ -169,3 +188,16 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
},
})
}
func logResponsesToolPolicyRejection(traceID string, policy util.ToolChoicePolicy, parsed util.ToolCallParseResult, channel string) {
if !parsed.RejectedByPolicy || len(parsed.RejectedToolNames) == 0 {
return
}
config.Logger.Warn(
"[responses] rejected tool calls by policy",
"trace_id", strings.TrimSpace(traceID),
"channel", channel,
"tool_choice_mode", policy.Mode,
"rejected_tool_names", strings.Join(parsed.RejectedToolNames, ","),
)
}

View File

@@ -4,9 +4,15 @@ import (
"encoding/json"
"fmt"
"strings"
"ds2api/internal/config"
)
func normalizeResponsesInputItem(m map[string]any) map[string]any {
return normalizeResponsesInputItemWithState(m, nil)
}
func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[string]string) map[string]any {
if m == nil {
return nil
}
@@ -69,6 +75,15 @@ func normalizeResponsesInputItem(m map[string]any) map[string]any {
out["name"] = name
} else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" {
out["name"] = name
} else if callID := strings.TrimSpace(asString(out["tool_call_id"])); callID != "" {
if inferred := strings.TrimSpace(callNameByID[callID]); inferred != "" {
out["name"] = inferred
} else {
config.Logger.Warn(
"[responses] unable to backfill tool result name from call_id",
"call_id", callID,
)
}
}
return out
case "function_call", "tool_call":
@@ -111,6 +126,9 @@ func normalizeResponsesInputItem(m map[string]any) map[string]any {
} else if callID = strings.TrimSpace(asString(m["id"])); callID != "" {
call["id"] = callID
}
if callID := strings.TrimSpace(asString(call["id"])); callID != "" && callNameByID != nil {
callNameByID[callID] = name
}
return map[string]any{
"role": "assistant",
"tool_calls": []any{call},

View File

@@ -59,6 +59,7 @@ func normalizeResponsesInputArray(items []any) []any {
return nil
}
out := make([]any, 0, len(items))
callNameByID := map[string]string{}
fallbackParts := make([]string, 0, len(items))
flushFallback := func() {
if len(fallbackParts) == 0 {
@@ -71,7 +72,7 @@ func normalizeResponsesInputArray(items []any) []any {
for _, item := range items {
switch x := item.(type) {
case map[string]any:
if msg := normalizeResponsesInputItem(x); msg != nil {
if msg := normalizeResponsesInputItemWithState(x, callNameByID); msg != nil {
flushFallback()
out = append(out, msg)
continue

View File

@@ -4,6 +4,7 @@ import (
"net/http"
"strings"
"ds2api/internal/config"
openaifmt "ds2api/internal/format/openai"
"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
@@ -19,6 +20,8 @@ type responsesStreamRuntime struct {
model string
finalPrompt string
toolNames []string
traceID string
toolChoice util.ToolChoicePolicy
thinkingEnabled bool
searchEnabled bool
@@ -32,11 +35,19 @@ type responsesStreamRuntime struct {
thinkingSieve toolStreamSieveState
thinking strings.Builder
text strings.Builder
visibleText strings.Builder
streamToolCallIDs map[int]string
streamFunctionIDs map[int]string
functionDone map[int]bool
functionAdded map[int]bool
functionNames map[int]string
toolCallsDoneSigs map[string]bool
reasoningItemID string
messageItemID string
messageAdded bool
messagePartAdded bool
sequence int
failed bool
persistResponse func(obj map[string]any)
}
@@ -53,6 +64,8 @@ func newResponsesStreamRuntime(
toolNames []string,
bufferToolContent bool,
emitEarlyToolDeltas bool,
toolChoice util.ToolChoicePolicy,
traceID string,
persistResponse func(obj map[string]any),
) *responsesStreamRuntime {
return &responsesStreamRuntime{
@@ -70,7 +83,11 @@ func newResponsesStreamRuntime(
streamToolCallIDs: map[int]string{},
streamFunctionIDs: map[int]string{},
functionDone: map[int]bool{},
functionAdded: map[int]bool{},
functionNames: map[int]string{},
toolCallsDoneSigs: map[string]bool{},
toolChoice: toolChoice,
traceID: traceID,
persistResponse: persistResponse,
}
}
@@ -78,36 +95,59 @@ func newResponsesStreamRuntime(
func (s *responsesStreamRuntime) finalize() {
finalThinking := s.thinking.String()
finalText := s.text.String()
if strings.TrimSpace(finalThinking) != "" {
s.sendEvent("response.reasoning_text.done", openaifmt.BuildResponsesReasoningTextDonePayload(s.responseID, s.ensureReasoningItemID(), 0, 0, finalThinking))
}
if s.bufferToolContent {
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false)
}
// Compatibility fallback: some streams only emit incremental tool deltas.
// Ensure final function_call_arguments.done is emitted at least once.
if s.toolCallsEmitted {
detected := util.ParseToolCalls(finalText, s.toolNames)
if len(detected) == 0 {
detected = util.ParseToolCalls(finalThinking, s.toolNames)
textParsed := util.ParseToolCallsDetailed(finalText, s.toolNames)
thinkingParsed := util.ParseToolCallsDetailed(finalThinking, s.toolNames)
detected := textParsed.Calls
if len(detected) == 0 {
detected = thinkingParsed.Calls
}
s.logToolPolicyRejections(textParsed, thinkingParsed)
if len(detected) > 0 {
s.toolCallsEmitted = true
if !s.toolCallsDoneEmitted {
s.emitFunctionCallDoneEvents(detected)
}
if len(detected) > 0 {
if !s.toolCallsDoneEmitted {
s.emitToolCallsDone(detected)
} else {
s.emitFunctionCallDoneEvents(detected)
}
}
s.closeMessageItem()
if s.toolChoice.IsRequired() && !s.hasFunctionCallDone() {
s.failed = true
message := "tool_choice requires at least one valid tool call."
failedResp := map[string]any{
"id": s.responseID,
"type": "response",
"object": "response",
"model": s.model,
"status": "failed",
"output": []any{},
"output_text": "",
"error": map[string]any{
"message": message,
"type": "invalid_request_error",
"code": "tool_choice_violation",
"param": nil,
},
}
if s.persistResponse != nil {
s.persistResponse(failedResp)
}
s.sendEvent("response.failed", openaifmt.BuildResponsesFailedPayload(s.responseID, s.model, message, "tool_choice_violation"))
s.sendDone()
return
}
obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames)
if s.toolCallsEmitted {
s.alignCompletedOutputCallIDs(obj)
}
if s.toolCallsEmitted {
obj["status"] = "completed"
}
if s.persistResponse != nil {
s.persistResponse(obj)
}
@@ -115,6 +155,32 @@ func (s *responsesStreamRuntime) finalize() {
s.sendDone()
}
func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingParsed util.ToolCallParseResult) {
logRejected := func(parsed util.ToolCallParseResult, channel string) {
if !parsed.RejectedByPolicy || len(parsed.RejectedToolNames) == 0 {
return
}
config.Logger.Warn(
"[responses] rejected tool calls by policy",
"trace_id", strings.TrimSpace(s.traceID),
"channel", channel,
"tool_choice_mode", s.toolChoice.Mode,
"rejected_tool_names", strings.Join(parsed.RejectedToolNames, ","),
)
}
logRejected(textParsed, "text")
logRejected(thinkingParsed, "thinking")
}
func (s *responsesStreamRuntime) hasFunctionCallDone() bool {
for _, done := range s.functionDone {
if done {
return true
}
}
return false
}
func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
if !parsed.Parsed {
return streamengine.ParsedDecision{}
@@ -138,7 +204,6 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
}
s.thinking.WriteString(p.Text)
s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
s.sendEvent("response.reasoning_text.delta", openaifmt.BuildResponsesReasoningTextDeltaPayload(s.responseID, s.ensureReasoningItemID(), 0, 0, p.Text))
if s.bufferToolContent {
s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false)
}
@@ -147,7 +212,7 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
s.text.WriteString(p.Text)
if !s.bufferToolContent {
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text))
s.emitTextDelta(p.Text)
continue
}
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true)

View File

@@ -6,7 +6,18 @@ import (
openaifmt "ds2api/internal/format/openai"
)
func (s *responsesStreamRuntime) nextSequence() int {
s.sequence++
return s.sequence
}
func (s *responsesStreamRuntime) sendEvent(event string, payload map[string]any) {
if payload == nil {
payload = map[string]any{}
}
if _, ok := payload["sequence_number"]; !ok {
payload["sequence_number"] = s.nextSequence()
}
b, _ := json.Marshal(payload)
_, _ = s.w.Write([]byte("event: " + event + "\n"))
_, _ = s.w.Write([]byte("data: "))
@@ -31,22 +42,20 @@ func (s *responsesStreamRuntime) sendDone() {
func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
for _, evt := range events {
if emitContent && evt.Content != "" {
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
s.emitTextDelta(evt.Content)
}
if len(evt.ToolCallDeltas) > 0 {
if !s.emitEarlyToolDeltas {
continue
}
formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
if len(formatted) == 0 {
filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.toolNames, s.functionNames)
if len(filtered) == 0 {
continue
}
s.toolCallsEmitted = true
s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatted))
s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
s.emitFunctionCallDeltaEvents(filtered)
}
if len(evt.ToolCalls) > 0 {
s.emitToolCallsDone(evt.ToolCalls)
s.emitFunctionCallDoneEvents(evt.ToolCalls)
}
}
}

View File

@@ -11,25 +11,101 @@ import (
"github.com/google/uuid"
)
func (s *responsesStreamRuntime) emitToolCallsDone(calls []util.ParsedToolCall) {
if len(calls) == 0 {
func (s *responsesStreamRuntime) ensureMessageItemID() string {
if strings.TrimSpace(s.messageItemID) != "" {
return s.messageItemID
}
s.messageItemID = "msg_" + strings.ReplaceAll(uuid.NewString(), "-", "")
return s.messageItemID
}
func (s *responsesStreamRuntime) messageOutputIndex() int {
if strings.TrimSpace(s.thinking.String()) != "" {
return 1
}
return 0
}
func (s *responsesStreamRuntime) ensureMessageItemAdded() {
if s.messageAdded {
return
}
sig := toolCallListSignature(calls)
if sig != "" && s.toolCallsDoneSigs[sig] {
itemID := s.ensureMessageItemID()
item := map[string]any{
"id": itemID,
"type": "message",
"role": "assistant",
"status": "in_progress",
}
s.sendEvent(
"response.output_item.added",
openaifmt.BuildResponsesOutputItemAddedPayload(s.responseID, itemID, s.messageOutputIndex(), item),
)
s.messageAdded = true
}
func (s *responsesStreamRuntime) ensureMessageContentPartAdded() {
if s.messagePartAdded {
return
}
if sig != "" {
s.toolCallsDoneSigs[sig] = true
}
formatted := formatFinalStreamToolCallsWithStableIDs(calls, s.streamToolCallIDs)
if len(formatted) == 0 {
s.ensureMessageItemAdded()
s.sendEvent(
"response.content_part.added",
openaifmt.BuildResponsesContentPartAddedPayload(
s.responseID,
s.ensureMessageItemID(),
s.messageOutputIndex(),
0,
map[string]any{"type": "output_text", "text": ""},
),
)
s.messagePartAdded = true
}
func (s *responsesStreamRuntime) emitTextDelta(content string) {
if strings.TrimSpace(content) == "" {
return
}
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatted))
s.emitFunctionCallDoneEvents(calls)
s.ensureMessageContentPartAdded()
s.visibleText.WriteString(content)
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, content))
}
func (s *responsesStreamRuntime) closeMessageItem() {
if !s.messageAdded {
return
}
itemID := s.ensureMessageItemID()
text := s.visibleText.String()
if s.messagePartAdded {
s.sendEvent(
"response.content_part.done",
openaifmt.BuildResponsesContentPartDonePayload(
s.responseID,
itemID,
s.messageOutputIndex(),
0,
map[string]any{"type": "output_text", "text": text},
),
)
s.messagePartAdded = false
}
item := map[string]any{
"id": itemID,
"type": "message",
"role": "assistant",
"status": "completed",
"content": []map[string]any{
{
"type": "output_text",
"text": text,
},
},
}
s.sendEvent(
"response.output_item.done",
openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, s.messageOutputIndex(), item),
)
}
func (s *responsesStreamRuntime) ensureReasoningItemID() string {
@@ -65,12 +141,47 @@ func (s *responsesStreamRuntime) functionOutputBaseIndex() int {
return 0
}
func (s *responsesStreamRuntime) functionOutputIndex(callIndex int) int {
return s.functionOutputBaseIndex() + callIndex
}
func (s *responsesStreamRuntime) ensureFunctionItemAdded(callIndex int, name string) {
if strings.TrimSpace(name) != "" {
s.functionNames[callIndex] = strings.TrimSpace(name)
}
if s.functionAdded[callIndex] {
return
}
fnName := strings.TrimSpace(s.functionNames[callIndex])
if fnName == "" {
return
}
outputIndex := s.functionOutputIndex(callIndex)
itemID := s.ensureFunctionItemID(outputIndex)
callID := s.ensureToolCallID(callIndex)
item := map[string]any{
"id": itemID,
"type": "function_call",
"call_id": callID,
"name": fnName,
"arguments": "{}",
"status": "in_progress",
}
s.sendEvent(
"response.output_item.added",
openaifmt.BuildResponsesOutputItemAddedPayload(s.responseID, itemID, outputIndex, item),
)
s.functionAdded[callIndex] = true
s.toolCallsEmitted = true
}
func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDelta) {
for _, d := range deltas {
s.ensureFunctionItemAdded(d.Index, d.Name)
if strings.TrimSpace(d.Arguments) == "" {
continue
}
outputIndex := s.functionOutputBaseIndex() + d.Index
outputIndex := s.functionOutputIndex(d.Index)
itemID := s.ensureFunctionItemID(outputIndex)
callID := s.ensureToolCallID(d.Index)
s.sendEvent(
@@ -86,6 +197,8 @@ func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedT
if strings.TrimSpace(tc.Name) == "" {
continue
}
s.ensureFunctionItemAdded(idx, tc.Name)
outputIndex := base + idx
if s.functionDone[outputIndex] {
continue
@@ -93,11 +206,25 @@ func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedT
itemID := s.ensureFunctionItemID(outputIndex)
callID := s.ensureToolCallID(idx)
argsBytes, _ := json.Marshal(tc.Input)
args := string(argsBytes)
s.sendEvent(
"response.function_call_arguments.done",
openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, string(argsBytes)),
openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, args),
)
item := map[string]any{
"id": itemID,
"type": "function_call",
"call_id": callID,
"name": tc.Name,
"arguments": args,
"status": "completed",
}
s.sendEvent(
"response.output_item.done",
openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, outputIndex, item),
)
s.functionDone[outputIndex] = true
s.toolCallsDoneEmitted = true
}
}
@@ -132,41 +259,12 @@ func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any)
if m == nil {
continue
}
typ, _ := m["type"].(string)
switch typ {
case "function_call":
if functionIdx < len(ordered) {
m["call_id"] = ordered[functionIdx]
functionIdx++
}
case "tool_calls":
tcArr, _ := m["tool_calls"].([]any)
for i, raw := range tcArr {
tc, _ := raw.(map[string]any)
if tc == nil {
continue
}
if i < len(ordered) {
tc["id"] = ordered[i]
}
}
if m["type"] != "function_call" {
continue
}
if functionIdx < len(ordered) {
m["call_id"] = ordered[functionIdx]
functionIdx++
}
}
}
func toolCallListSignature(calls []util.ParsedToolCall) string {
if len(calls) == 0 {
return ""
}
var b strings.Builder
for i, tc := range calls {
if i > 0 {
b.WriteString("|")
}
b.WriteString(strings.TrimSpace(tc.Name))
b.WriteString(":")
args, _ := json.Marshal(tc.Input)
b.Write(args)
}
return b.String()
}

View File

@@ -8,6 +8,8 @@ import (
"net/http/httptest"
"strings"
"testing"
"ds2api/internal/util"
)
func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T) {
@@ -30,7 +32,7 @@ func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
if !ok {
@@ -45,8 +47,8 @@ func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T
if len(output) == 0 {
t.Fatalf("expected structured output entries, got %#v", responseObj["output"])
}
var firstToolWrapper map[string]any
hasFunctionCall := false
hasLegacyWrapper := false
for _, item := range output {
m, _ := item.(map[string]any)
if m == nil {
@@ -55,96 +57,22 @@ func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T
if m["type"] == "function_call" {
hasFunctionCall = true
}
if m["type"] == "tool_calls" && firstToolWrapper == nil {
firstToolWrapper = m
if m["type"] == "tool_calls" {
hasLegacyWrapper = true
}
}
if !hasFunctionCall {
t.Fatalf("expected at least one function_call item for responses compatibility, got %#v", responseObj["output"])
t.Fatalf("expected function_call item, got %#v", responseObj["output"])
}
if firstToolWrapper == nil {
t.Fatalf("expected a tool_calls wrapper item, got %#v", responseObj["output"])
}
toolCalls, _ := firstToolWrapper["tool_calls"].([]any)
if len(toolCalls) == 0 {
t.Fatalf("expected at least one tool_call in output, got %#v", firstToolWrapper["tool_calls"])
}
call0, _ := toolCalls[0].(map[string]any)
if call0["type"] != "function" {
t.Fatalf("unexpected tool call type: %#v", call0["type"])
}
fn, _ := call0["function"].(map[string]any)
if fn["name"] != "read_file" {
t.Fatalf("unexpected tool call name: %#v", fn["name"])
if hasLegacyWrapper {
t.Fatalf("did not expect legacy tool_calls wrapper, got %#v", responseObj["output"])
}
if strings.Contains(outputText, `"tool_calls"`) {
t.Fatalf("raw tool_calls JSON leaked in output_text: %q", outputText)
}
}
func TestHandleResponsesStreamIncompleteTailNotDuplicatedInCompletedOutputText(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(v string) string {
b, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": v,
})
return "data: " + string(b) + "\n"
}
tail := `{"tool_calls":[{"name":"read_file","input":`
streamBody := sseLine("Before ") + sseLine(tail) + "data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
if !ok {
t.Fatalf("expected response.completed event, body=%s", rec.Body.String())
}
responseObj, _ := completed["response"].(map[string]any)
outputText, _ := responseObj["output_text"].(string)
if strings.Count(outputText, tail) > 1 {
t.Fatalf("expected incomplete tail not to be duplicated, got output_text=%q", outputText)
}
}
func TestHandleResponsesStreamEmitsReasoningCompatEvents(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
b, _ := json.Marshal(map[string]any{
"p": "response/thinking_content",
"v": "thought",
})
streamBody := "data: " + string(b) + "\n" + "data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, nil)
body := rec.Body.String()
if !strings.Contains(body, "event: response.reasoning.delta") {
t.Fatalf("expected response.reasoning.delta event, body=%s", body)
}
if !strings.Contains(body, "event: response.reasoning_text.delta") {
t.Fatalf("expected response.reasoning_text.delta compatibility event, body=%s", body)
}
if !strings.Contains(body, "event: response.reasoning_text.done") {
t.Fatalf("expected response.reasoning_text.done compatibility event, body=%s", body)
}
}
func TestHandleResponsesStreamEmitsFunctionCallCompatEvents(t *testing.T) {
func TestHandleResponsesStreamUsesOfficialOutputItemEvents(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
@@ -163,24 +91,28 @@ func TestHandleResponsesStreamEmitsFunctionCallCompatEvents(t *testing.T) {
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
body := rec.Body.String()
if !strings.Contains(body, "event: response.output_item.added") {
t.Fatalf("expected response.output_item.added event, body=%s", body)
}
if !strings.Contains(body, "event: response.output_item.done") {
t.Fatalf("expected response.output_item.done event, body=%s", body)
}
if !strings.Contains(body, "event: response.function_call_arguments.delta") {
t.Fatalf("expected response.function_call_arguments.delta compatibility event, body=%s", body)
t.Fatalf("expected response.function_call_arguments.delta event, body=%s", body)
}
if !strings.Contains(body, "event: response.function_call_arguments.done") {
t.Fatalf("expected response.function_call_arguments.done compatibility event, body=%s", body)
t.Fatalf("expected response.function_call_arguments.done event, body=%s", body)
}
if strings.Contains(body, "event: response.output_tool_call.delta") || strings.Contains(body, "event: response.output_tool_call.done") {
t.Fatalf("legacy response.output_tool_call.* event must not appear, body=%s", body)
}
donePayload, ok := extractSSEEventPayload(body, "response.function_call_arguments.done")
if !ok {
t.Fatalf("expected to parse response.function_call_arguments.done payload, body=%s", body)
}
if strings.TrimSpace(asString(donePayload["call_id"])) == "" {
t.Fatalf("expected call_id in response.function_call_arguments.done payload, payload=%#v", donePayload)
}
if strings.TrimSpace(asString(donePayload["response_id"])) == "" {
t.Fatalf("expected response_id in response.function_call_arguments.done payload, payload=%#v", donePayload)
}
doneCallID := strings.TrimSpace(asString(donePayload["call_id"]))
if doneCallID == "" {
t.Fatalf("expected non-empty call_id in done payload, payload=%#v", donePayload)
@@ -191,9 +123,6 @@ func TestHandleResponsesStreamEmitsFunctionCallCompatEvents(t *testing.T) {
}
responseObj, _ := completed["response"].(map[string]any)
output, _ := responseObj["output"].([]any)
if len(output) == 0 {
t.Fatalf("expected non-empty output in response.completed, response=%#v", responseObj)
}
var completedCallID string
for _, item := range output {
m, _ := item.(map[string]any)
@@ -213,36 +142,29 @@ func TestHandleResponsesStreamEmitsFunctionCallCompatEvents(t *testing.T) {
}
}
func TestHandleResponsesStreamDetectsToolCallsFromThinkingChannel(t *testing.T) {
func TestHandleResponsesStreamDoesNotEmitReasoningTextCompatEvents(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(path, v string) string {
b, _ := json.Marshal(map[string]any{
"p": path,
"v": v,
})
return "data: " + string(b) + "\n"
}
streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n"
b, _ := json.Marshal(map[string]any{
"p": "response/thinking_content",
"v": "thought",
})
streamBody := "data: " + string(b) + "\n" + "data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"})
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, nil, util.DefaultToolChoicePolicy(), "")
body := rec.Body.String()
if !strings.Contains(body, "event: response.reasoning_text.delta") {
t.Fatalf("expected response.reasoning_text.delta event, body=%s", body)
if !strings.Contains(body, "event: response.reasoning.delta") {
t.Fatalf("expected response.reasoning.delta event, body=%s", body)
}
if !strings.Contains(body, "event: response.function_call_arguments.done") {
t.Fatalf("expected response.function_call_arguments.done event from thinking channel, body=%s", body)
}
if !strings.Contains(body, "event: response.output_tool_call.done") {
t.Fatalf("expected response.output_tool_call.done event from thinking channel, body=%s", body)
if strings.Contains(body, "event: response.reasoning_text.delta") || strings.Contains(body, "event: response.reasoning_text.done") {
t.Fatalf("did not expect response.reasoning_text.* compatibility events, body=%s", body)
}
}
@@ -267,121 +189,31 @@ func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"})
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"}, util.DefaultToolChoicePolicy(), "")
body := rec.Body.String()
if !strings.Contains(body, "event: response.output_tool_call.done") {
t.Fatalf("expected response.output_tool_call.done event, body=%s", body)
}
donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
if len(donePayloads) != 2 {
t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
}
seenNames := map[string]string{}
for _, payload := range donePayloads {
name := strings.TrimSpace(asString(payload["name"]))
callID := strings.TrimSpace(asString(payload["call_id"]))
args := strings.TrimSpace(asString(payload["arguments"]))
if callID == "" {
t.Fatalf("expected non-empty call_id in done payload: %#v", payload)
}
if strings.Contains(args, `}{"`) {
t.Fatalf("unexpected concatenated arguments in done payload: %#v", payload)
}
if name == "search_webeval_javascript" {
t.Fatalf("unexpected merged tool name in done payload: %#v", payload)
}
if name != "search_web" && name != "eval_javascript" {
t.Fatalf("unexpected tool name in done payload: %#v", payload)
}
if callID == "" {
t.Fatalf("expected non-empty call_id in done payload: %#v", payload)
}
seenNames[name] = callID
}
if seenNames["search_web"] == "" || seenNames["eval_javascript"] == "" {
t.Fatalf("expected done events for both tools, got %#v", seenNames)
}
if seenNames["search_web"] == seenNames["eval_javascript"] {
t.Fatalf("expected distinct call_id per tool, got %#v", seenNames)
}
completed, ok := extractSSEEventPayload(body, "response.completed")
if !ok {
t.Fatalf("expected response.completed event, body=%s", body)
}
responseObj, _ := completed["response"].(map[string]any)
output, _ := responseObj["output"].([]any)
functionCallIDs := map[string]string{}
for _, item := range output {
m, _ := item.(map[string]any)
if m == nil || m["type"] != "function_call" {
continue
}
name := strings.TrimSpace(asString(m["name"]))
callID := strings.TrimSpace(asString(m["call_id"]))
if name != "" && callID != "" {
functionCallIDs[name] = callID
}
}
if functionCallIDs["search_web"] != seenNames["search_web"] {
t.Fatalf("search_web call_id mismatch between done and completed: done=%q completed=%q", seenNames["search_web"], functionCallIDs["search_web"])
}
if functionCallIDs["eval_javascript"] != seenNames["eval_javascript"] {
t.Fatalf("eval_javascript call_id mismatch between done and completed: done=%q completed=%q", seenNames["eval_javascript"], functionCallIDs["eval_javascript"])
}
}
func TestHandleResponsesStreamMultiToolCallFromThinkingChannel(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(path, v string) string {
b, _ := json.Marshal(map[string]any{
"p": path,
"v": v,
})
return "data: " + string(b) + "\n"
}
streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) +
sseLine("response/thinking_content", `{"name":"eval_javascript","input":{"code":"1+1"}}]}`) +
"data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"search_web", "eval_javascript"})
body := rec.Body.String()
if !strings.Contains(body, "event: response.reasoning_text.delta") {
t.Fatalf("expected reasoning stream events, body=%s", body)
}
donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
if len(donePayloads) != 2 {
t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
}
seen := map[string]bool{}
for _, payload := range donePayloads {
name := strings.TrimSpace(asString(payload["name"]))
if name == "search_webeval_javascript" {
t.Fatalf("unexpected merged tool name in thinking channel done payload: %#v", payload)
}
if name != "search_web" && name != "eval_javascript" {
t.Fatalf("unexpected tool name in thinking channel done payload: %#v", payload)
}
args := strings.TrimSpace(asString(payload["arguments"]))
if strings.Contains(args, `}{"`) {
t.Fatalf("unexpected concatenated arguments in thinking channel done payload: %#v", payload)
}
seen[name] = true
}
if !seen["search_web"] || !seen["eval_javascript"] {
t.Fatalf("expected both tools in thinking channel done events, got %#v", seen)
}
}
func TestHandleResponsesStreamCompletedFollowsChatToolCallSemantics(t *testing.T) {
func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
@@ -394,32 +226,76 @@ func TestHandleResponsesStreamCompletedFollowsChatToolCallSemantics(t *testing.T
return "data: " + string(b) + "\n"
}
streamBody := sseLine("我来调用工具\n") +
sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) +
"data: [DONE]\n"
streamBody := sseLine("plain text only") + "data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
policy := util.ToolChoicePolicy{
Mode: util.ToolChoiceRequired,
Allowed: map[string]struct{}{"read_file": {}},
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, policy, "")
completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
if !ok {
t.Fatalf("expected response.completed event, body=%s", rec.Body.String())
body := rec.Body.String()
if !strings.Contains(body, "event: response.failed") {
t.Fatalf("expected response.failed event for required tool_choice violation, body=%s", body)
}
responseObj, _ := completed["response"].(map[string]any)
output, _ := responseObj["output"].([]any)
hasFunctionCall := false
for _, item := range output {
m, _ := item.(map[string]any)
if m != nil && m["type"] == "function_call" {
hasFunctionCall = true
break
}
if strings.Contains(body, "event: response.completed") {
t.Fatalf("did not expect response.completed after failure, body=%s", body)
}
if !hasFunctionCall {
t.Fatalf("expected completed output to include function_call when mixed prose contains tool_calls payload, output=%#v", output)
}
func TestHandleResponsesStreamRejectsUnknownToolName(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(v string) string {
b, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": v,
})
return "data: " + string(b) + "\n"
}
streamBody := sseLine(`{"tool_calls":[{"name":"not_in_schema","input":{"q":"go"}}]}`) + "data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
body := rec.Body.String()
if strings.Contains(body, "event: response.function_call_arguments.done") {
t.Fatalf("did not expect function_call events for unknown tool, body=%s", body)
}
}
func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) {
h := &Handler{}
rec := httptest.NewRecorder()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(
`data: {"p":"response/content","v":"plain text only"}` + "\n" +
`data: [DONE]` + "\n",
)),
}
policy := util.ToolChoicePolicy{
Mode: util.ToolChoiceRequired,
Allowed: map[string]struct{}{"read_file": {}},
}
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, []string{"read_file"}, policy, "")
if rec.Code != http.StatusUnprocessableEntity {
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
}
out := decodeJSONBody(t, rec.Body.String())
errObj, _ := out["error"].(map[string]any)
if asString(errObj["code"]) != "tool_choice_violation" {
t.Fatalf("expected code=tool_choice_violation, got %#v", out)
}
}

View File

@@ -23,7 +23,8 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID
if responseModel == "" {
responseModel = resolvedModel
}
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID)
toolPolicy := util.DefaultToolChoicePolicy()
finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy)
passThrough := collectOpenAIChatPassThrough(req)
return util.StandardRequest{
@@ -34,6 +35,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID
Messages: messagesRaw,
FinalPrompt: finalPrompt,
ToolNames: toolNames,
ToolChoice: toolPolicy,
Stream: util.ToBool(req["stream"]),
Thinking: thinkingEnabled,
Search: searchEnabled,
@@ -67,7 +69,17 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra
if len(messagesRaw) == 0 {
return util.StandardRequest{}, fmt.Errorf("Request must include 'input' or 'messages'.")
}
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID)
toolPolicy, err := parseToolChoicePolicy(req["tool_choice"], req["tools"])
if err != nil {
return util.StandardRequest{}, err
}
finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy)
if toolPolicy.IsNone() {
toolNames = nil
toolPolicy.Allowed = nil
} else {
toolPolicy.Allowed = namesToSet(toolNames)
}
passThrough := collectOpenAIChatPassThrough(req)
return util.StandardRequest{
@@ -78,6 +90,7 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra
Messages: messagesRaw,
FinalPrompt: finalPrompt,
ToolNames: toolNames,
ToolChoice: toolPolicy,
Stream: util.ToBool(req["stream"]),
Thinking: thinkingEnabled,
Search: searchEnabled,
@@ -102,3 +115,212 @@ func collectOpenAIChatPassThrough(req map[string]any) map[string]any {
}
return out
}
func parseToolChoicePolicy(toolChoiceRaw any, toolsRaw any) (util.ToolChoicePolicy, error) {
policy := util.DefaultToolChoicePolicy()
declaredNames := extractDeclaredToolNames(toolsRaw)
declaredSet := namesToSet(declaredNames)
if len(declaredNames) > 0 {
policy.Allowed = declaredSet
}
if toolChoiceRaw == nil {
return policy, nil
}
switch v := toolChoiceRaw.(type) {
case string:
switch strings.ToLower(strings.TrimSpace(v)) {
case "", "auto":
policy.Mode = util.ToolChoiceAuto
case "none":
policy.Mode = util.ToolChoiceNone
policy.Allowed = nil
case "required":
policy.Mode = util.ToolChoiceRequired
default:
return util.ToolChoicePolicy{}, fmt.Errorf("Unsupported tool_choice: %q", v)
}
case map[string]any:
allowedOverride, hasAllowedOverride, err := parseAllowedToolNames(v["allowed_tools"])
if err != nil {
return util.ToolChoicePolicy{}, err
}
if hasAllowedOverride {
filtered := make([]string, 0, len(allowedOverride))
for _, name := range allowedOverride {
if _, ok := declaredSet[name]; !ok {
return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice.allowed_tools contains undeclared tool %q", name)
}
filtered = append(filtered, name)
}
policy.Allowed = namesToSet(filtered)
}
typ := strings.ToLower(strings.TrimSpace(asString(v["type"])))
switch typ {
case "", "auto":
if hasFunctionSelector(v) {
name, err := parseForcedToolName(v)
if err != nil {
return util.ToolChoicePolicy{}, err
}
policy.Mode = util.ToolChoiceForced
policy.ForcedName = name
policy.Allowed = namesToSet([]string{name})
} else {
policy.Mode = util.ToolChoiceAuto
}
case "none":
policy.Mode = util.ToolChoiceNone
policy.Allowed = nil
case "required":
policy.Mode = util.ToolChoiceRequired
case "function":
name, err := parseForcedToolName(v)
if err != nil {
return util.ToolChoicePolicy{}, err
}
policy.Mode = util.ToolChoiceForced
policy.ForcedName = name
policy.Allowed = namesToSet([]string{name})
default:
return util.ToolChoicePolicy{}, fmt.Errorf("Unsupported tool_choice.type: %q", typ)
}
default:
return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice must be a string or object")
}
if policy.Mode == util.ToolChoiceRequired || policy.Mode == util.ToolChoiceForced {
if len(declaredNames) == 0 {
return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice=%s requires non-empty tools.", policy.Mode)
}
}
if policy.Mode == util.ToolChoiceForced {
if _, ok := declaredSet[policy.ForcedName]; !ok {
return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice forced function %q is not declared in tools", policy.ForcedName)
}
}
if len(policy.Allowed) == 0 && (policy.Mode == util.ToolChoiceRequired || policy.Mode == util.ToolChoiceForced) {
return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice policy resolved to empty allowed tool set")
}
return policy, nil
}
func parseForcedToolName(v map[string]any) (string, error) {
if name := strings.TrimSpace(asString(v["name"])); name != "" {
return name, nil
}
if fn, ok := v["function"].(map[string]any); ok {
if name := strings.TrimSpace(asString(fn["name"])); name != "" {
return name, nil
}
}
return "", fmt.Errorf("tool_choice function requires name")
}
func parseAllowedToolNames(raw any) ([]string, bool, error) {
if raw == nil {
return nil, false, nil
}
collectName := func(v any) string {
if name := strings.TrimSpace(asString(v)); name != "" {
return name
}
if m, ok := v.(map[string]any); ok {
if name := strings.TrimSpace(asString(m["name"])); name != "" {
return name
}
if fn, ok := m["function"].(map[string]any); ok {
if name := strings.TrimSpace(asString(fn["name"])); name != "" {
return name
}
}
}
return ""
}
names := []string{}
switch x := raw.(type) {
case []any:
for _, item := range x {
name := collectName(item)
if name == "" {
return nil, true, fmt.Errorf("tool_choice.allowed_tools contains invalid item")
}
names = append(names, name)
}
case []string:
for _, item := range x {
name := strings.TrimSpace(item)
if name == "" {
return nil, true, fmt.Errorf("tool_choice.allowed_tools contains empty name")
}
names = append(names, name)
}
default:
return nil, true, fmt.Errorf("tool_choice.allowed_tools must be an array")
}
if len(names) == 0 {
return nil, true, fmt.Errorf("tool_choice.allowed_tools must not be empty")
}
return names, true, nil
}
func hasFunctionSelector(v map[string]any) bool {
if strings.TrimSpace(asString(v["name"])) != "" {
return true
}
if fn, ok := v["function"].(map[string]any); ok {
return strings.TrimSpace(asString(fn["name"])) != ""
}
return false
}
func extractDeclaredToolNames(toolsRaw any) []string {
tools, ok := toolsRaw.([]any)
if !ok || len(tools) == 0 {
return nil
}
out := make([]string, 0, len(tools))
seen := map[string]struct{}{}
for _, t := range tools {
tool, ok := t.(map[string]any)
if !ok {
continue
}
fn, _ := tool["function"].(map[string]any)
if len(fn) == 0 {
fn = tool
}
name := strings.TrimSpace(asString(fn["name"]))
if name == "" {
continue
}
if _, ok := seen[name]; ok {
continue
}
seen[name] = struct{}{}
out = append(out, name)
}
return out
}
func namesToSet(names []string) map[string]struct{} {
if len(names) == 0 {
return nil
}
out := make(map[string]struct{}, len(names))
for _, name := range names {
trimmed := strings.TrimSpace(name)
if trimmed == "" {
continue
}
out[trimmed] = struct{}{}
}
if len(out) == 0 {
return nil
}
return out
}

View File

@@ -4,6 +4,7 @@ import (
"testing"
"ds2api/internal/config"
"ds2api/internal/util"
)
func newEmptyStoreForNormalizeTest(t *testing.T) *config.Store {
@@ -58,3 +59,95 @@ func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) {
t.Fatalf("expected 2 normalized messages, got %d", len(n.Messages))
}
}
func TestNormalizeOpenAIResponsesRequestToolChoiceRequired(t *testing.T) {
store := newEmptyStoreForNormalizeTest(t)
req := map[string]any{
"model": "gpt-4o",
"input": "ping",
"tools": []any{
map[string]any{
"type": "function",
"function": map[string]any{
"name": "search",
"parameters": map[string]any{
"type": "object",
},
},
},
},
"tool_choice": "required",
}
n, err := normalizeOpenAIResponsesRequest(store, req, "")
if err != nil {
t.Fatalf("normalize failed: %v", err)
}
if n.ToolChoice.Mode != util.ToolChoiceRequired {
t.Fatalf("expected tool choice mode required, got %q", n.ToolChoice.Mode)
}
if len(n.ToolNames) != 1 || n.ToolNames[0] != "search" {
t.Fatalf("unexpected tool names: %#v", n.ToolNames)
}
}
func TestNormalizeOpenAIResponsesRequestToolChoiceForcedFunction(t *testing.T) {
store := newEmptyStoreForNormalizeTest(t)
req := map[string]any{
"model": "gpt-4o",
"input": "ping",
"tools": []any{
map[string]any{
"type": "function",
"function": map[string]any{
"name": "search",
},
},
map[string]any{
"type": "function",
"function": map[string]any{
"name": "read_file",
},
},
},
"tool_choice": map[string]any{
"type": "function",
"name": "read_file",
},
}
n, err := normalizeOpenAIResponsesRequest(store, req, "")
if err != nil {
t.Fatalf("normalize failed: %v", err)
}
if n.ToolChoice.Mode != util.ToolChoiceForced {
t.Fatalf("expected tool choice mode forced, got %q", n.ToolChoice.Mode)
}
if n.ToolChoice.ForcedName != "read_file" {
t.Fatalf("expected forced tool name read_file, got %q", n.ToolChoice.ForcedName)
}
if len(n.ToolNames) != 1 || n.ToolNames[0] != "read_file" {
t.Fatalf("expected filtered tool names [read_file], got %#v", n.ToolNames)
}
}
func TestNormalizeOpenAIResponsesRequestToolChoiceForcedUndeclaredFails(t *testing.T) {
store := newEmptyStoreForNormalizeTest(t)
req := map[string]any{
"model": "gpt-4o",
"input": "ping",
"tools": []any{
map[string]any{
"type": "function",
"function": map[string]any{
"name": "search",
},
},
},
"tool_choice": map[string]any{
"type": "function",
"name": "read_file",
},
}
if _, err := normalizeOpenAIResponsesRequest(store, req, ""); err == nil {
t.Fatalf("expected forced undeclared tool to fail")
}
}

View File

@@ -27,12 +27,7 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex
"text": finalThinking,
})
}
formatted := util.FormatOpenAIToolCalls(detected)
output = append(output, toResponsesFunctionCallItems(formatted)...)
output = append(output, map[string]any{
"type": "tool_calls",
"tool_calls": formatted,
})
output = append(output, toResponsesFunctionCallItems(detected)...)
} else {
content := make([]any, 0, 2)
if finalThinking != "" {
@@ -70,32 +65,23 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex
}
}
func toResponsesFunctionCallItems(toolCalls []map[string]any) []any {
func toResponsesFunctionCallItems(toolCalls []util.ParsedToolCall) []any {
if len(toolCalls) == 0 {
return nil
}
out := make([]any, 0, len(toolCalls))
for _, tc := range toolCalls {
callID, _ := tc["id"].(string)
if strings.TrimSpace(callID) == "" {
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
}
name := ""
args := "{}"
if fn, ok := tc["function"].(map[string]any); ok {
if n, _ := fn["name"].(string); strings.TrimSpace(n) != "" {
name = n
}
if a, _ := fn["arguments"].(string); strings.TrimSpace(a) != "" {
args = a
}
if strings.TrimSpace(tc.Name) == "" {
continue
}
argsBytes, _ := json.Marshal(tc.Input)
args := normalizeJSONString(string(argsBytes))
out = append(out, map[string]any{
"id": "fc_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
"type": "function_call",
"call_id": callID,
"name": name,
"arguments": normalizeJSONString(args),
"call_id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
"name": tc.Name,
"arguments": args,
"status": "completed",
})
}

View File

@@ -1,5 +1,7 @@
package openai
import "strings"
func BuildResponsesCreatedPayload(responseID, model string) map[string]any {
return map[string]any{
"type": "response.created",
@@ -11,6 +13,52 @@ func BuildResponsesCreatedPayload(responseID, model string) map[string]any {
}
}
func BuildResponsesOutputItemAddedPayload(responseID, itemID string, outputIndex int, item map[string]any) map[string]any {
return map[string]any{
"type": "response.output_item.added",
"id": responseID,
"response_id": responseID,
"output_index": outputIndex,
"item_id": itemID,
"item": item,
}
}
func BuildResponsesOutputItemDonePayload(responseID, itemID string, outputIndex int, item map[string]any) map[string]any {
return map[string]any{
"type": "response.output_item.done",
"id": responseID,
"response_id": responseID,
"output_index": outputIndex,
"item_id": itemID,
"item": item,
}
}
func BuildResponsesContentPartAddedPayload(responseID, itemID string, outputIndex, contentIndex int, part map[string]any) map[string]any {
return map[string]any{
"type": "response.content_part.added",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"content_index": contentIndex,
"part": part,
}
}
func BuildResponsesContentPartDonePayload(responseID, itemID string, outputIndex, contentIndex int, part map[string]any) map[string]any {
return map[string]any{
"type": "response.content_part.done",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"content_index": contentIndex,
"part": part,
}
}
func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any {
return map[string]any{
"type": "response.output_text.delta",
@@ -29,48 +77,6 @@ func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]an
}
}
func BuildResponsesReasoningTextDeltaPayload(responseID, itemID string, outputIndex, contentIndex int, delta string) map[string]any {
return map[string]any{
"type": "response.reasoning_text.delta",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"content_index": contentIndex,
"delta": delta,
}
}
func BuildResponsesReasoningTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any {
return map[string]any{
"type": "response.reasoning_text.done",
"id": responseID,
"response_id": responseID,
"item_id": itemID,
"output_index": outputIndex,
"content_index": contentIndex,
"text": text,
}
}
func BuildResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any {
return map[string]any{
"type": "response.output_tool_call.delta",
"id": responseID,
"response_id": responseID,
"tool_calls": toolCalls,
}
}
func BuildResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any {
return map[string]any{
"type": "response.output_tool_call.done",
"id": responseID,
"response_id": responseID,
"tool_calls": toolCalls,
}
}
func BuildResponsesFunctionCallArgumentsDeltaPayload(responseID, itemID string, outputIndex int, callID, delta string) map[string]any {
return map[string]any{
"type": "response.function_call_arguments.delta",
@@ -96,6 +102,27 @@ func BuildResponsesFunctionCallArgumentsDonePayload(responseID, itemID string, o
}
}
func BuildResponsesFailedPayload(responseID, model, message, code string) map[string]any {
code = strings.TrimSpace(code)
if code == "" {
code = "api_error"
}
return map[string]any{
"type": "response.failed",
"id": responseID,
"response_id": responseID,
"object": "response",
"model": model,
"status": "failed",
"error": map[string]any{
"message": message,
"type": "invalid_request_error",
"code": code,
"param": nil,
},
}
}
func BuildResponsesCompletedPayload(response map[string]any) map[string]any {
responseID, _ := response["id"].(string)
return map[string]any{

View File

@@ -21,8 +21,8 @@ func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
}
output, _ := obj["output"].([]any)
if len(output) != 2 {
t.Fatalf("expected function_call + tool_calls wrapper, got %#v", obj["output"])
if len(output) != 1 {
t.Fatalf("expected function_call output only, got %#v", obj["output"])
}
first, _ := output[0].(map[string]any)
@@ -32,35 +32,10 @@ func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
if first["call_id"] == "" {
t.Fatalf("expected function_call item to have call_id, got %#v", first)
}
second, _ := output[1].(map[string]any)
if second["type"] != "tool_calls" {
t.Fatalf("expected second output item type tool_calls, got %#v", second["type"])
if first["name"] != "search" {
t.Fatalf("unexpected function name: %#v", first["name"])
}
var toolCalls []map[string]any
switch v := second["tool_calls"].(type) {
case []map[string]any:
toolCalls = v
case []any:
toolCalls = make([]map[string]any, 0, len(v))
for _, item := range v {
m, _ := item.(map[string]any)
if m != nil {
toolCalls = append(toolCalls, m)
}
}
}
if len(toolCalls) != 1 {
t.Fatalf("expected one tool call, got %#v", second["tool_calls"])
}
tc := toolCalls[0]
if tc["type"] != "function" || tc["id"] == "" {
t.Fatalf("unexpected tool call shape: %#v", tc)
}
fn, _ := tc["function"].(map[string]any)
if fn["name"] != "search" {
t.Fatalf("unexpected function name: %#v", fn["name"])
}
argsRaw, _ := fn["arguments"].(string)
argsRaw, _ := first["arguments"].(string)
var args map[string]any
if err := json.Unmarshal([]byte(argsRaw), &args); err != nil {
t.Fatalf("arguments should be valid json string, got=%q err=%v", argsRaw, err)
@@ -86,8 +61,8 @@ func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T)
}
output, _ := obj["output"].([]any)
if len(output) != 2 {
t.Fatalf("expected function_call + tool_calls wrapper, got %#v", obj["output"])
if len(output) != 1 {
t.Fatalf("expected function_call output only, got %#v", obj["output"])
}
first, _ := output[0].(map[string]any)
if first["type"] != "function_call" {
@@ -163,8 +138,8 @@ func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
)
output, _ := obj["output"].([]any)
if len(output) != 3 {
t.Fatalf("expected reasoning + function_call + tool_calls outputs, got %#v", obj["output"])
if len(output) != 2 {
t.Fatalf("expected reasoning + function_call outputs, got %#v", obj["output"])
}
first, _ := output[0].(map[string]any)
if first["type"] != "reasoning" {
@@ -174,8 +149,4 @@ func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
if second["type"] != "function_call" {
t.Fatalf("expected second output function_call, got %#v", second["type"])
}
third, _ := output[2].(map[string]any)
if third["type"] != "tool_calls" {
t.Fatalf("expected third output tool_calls, got %#v", third["type"])
}
}

View File

@@ -8,12 +8,48 @@ type StandardRequest struct {
Messages []any
FinalPrompt string
ToolNames []string
ToolChoice ToolChoicePolicy
Stream bool
Thinking bool
Search bool
PassThrough map[string]any
}
type ToolChoiceMode string
const (
ToolChoiceAuto ToolChoiceMode = "auto"
ToolChoiceNone ToolChoiceMode = "none"
ToolChoiceRequired ToolChoiceMode = "required"
ToolChoiceForced ToolChoiceMode = "forced"
)
type ToolChoicePolicy struct {
Mode ToolChoiceMode
ForcedName string
Allowed map[string]struct{}
}
func DefaultToolChoicePolicy() ToolChoicePolicy {
return ToolChoicePolicy{Mode: ToolChoiceAuto}
}
func (p ToolChoicePolicy) IsNone() bool {
return p.Mode == ToolChoiceNone
}
func (p ToolChoicePolicy) IsRequired() bool {
return p.Mode == ToolChoiceRequired || p.Mode == ToolChoiceForced
}
func (p ToolChoicePolicy) Allows(name string) bool {
if len(p.Allowed) == 0 {
return true
}
_, ok := p.Allowed[name]
return ok
}
func (r StandardRequest) CompletionPayload(sessionID string) map[string]any {
payload := map[string]any{
"chat_session_id": sessionID,

View File

@@ -10,38 +10,62 @@ type ParsedToolCall struct {
Input map[string]any `json:"input"`
}
type ToolCallParseResult struct {
Calls []ParsedToolCall
SawToolCallSyntax bool
RejectedByPolicy bool
RejectedToolNames []string
}
func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
return ParseToolCallsDetailed(text, availableToolNames).Calls
}
func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallParseResult {
result := ToolCallParseResult{}
if strings.TrimSpace(text) == "" {
return nil
return result
}
text = stripFencedCodeBlocks(text)
if strings.TrimSpace(text) == "" {
return nil
return result
}
result.SawToolCallSyntax = strings.Contains(strings.ToLower(text), "tool_calls")
candidates := buildToolCallCandidates(text)
var parsed []ParsedToolCall
for _, candidate := range candidates {
if tc := parseToolCallsPayload(candidate); len(tc) > 0 {
parsed = tc
result.SawToolCallSyntax = true
break
}
}
if len(parsed) == 0 {
return nil
return result
}
return filterToolCalls(parsed, availableToolNames)
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
result.Calls = calls
result.RejectedToolNames = rejectedNames
result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
return result
}
func ParseStandaloneToolCalls(text string, availableToolNames []string) []ParsedToolCall {
return ParseStandaloneToolCallsDetailed(text, availableToolNames).Calls
}
func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string) ToolCallParseResult {
result := ToolCallParseResult{}
trimmed := strings.TrimSpace(text)
if trimmed == "" {
return nil
return result
}
if looksLikeToolExampleContext(trimmed) {
return nil
return result
}
result.SawToolCallSyntax = strings.Contains(strings.ToLower(trimmed), "tool_calls")
candidates := []string{trimmed}
for _, candidate := range candidates {
candidate = strings.TrimSpace(candidate)
@@ -52,24 +76,31 @@ func ParseStandaloneToolCalls(text string, availableToolNames []string) []Parsed
continue
}
if parsed := parseToolCallsPayload(candidate); len(parsed) > 0 {
return filterToolCalls(parsed, availableToolNames)
result.SawToolCallSyntax = true
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
result.Calls = calls
result.RejectedToolNames = rejectedNames
result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
return result
}
}
return nil
return result
}
func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []ParsedToolCall {
func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []string) ([]ParsedToolCall, []string) {
allowed := map[string]struct{}{}
for _, name := range availableToolNames {
allowed[name] = struct{}{}
}
out := make([]ParsedToolCall, 0, len(parsed))
rejectedSet := map[string]struct{}{}
for _, tc := range parsed {
if tc.Name == "" {
continue
}
if len(allowed) > 0 {
if _, ok := allowed[tc.Name]; !ok {
rejectedSet[tc.Name] = struct{}{}
continue
}
}
@@ -78,21 +109,11 @@ func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []Par
}
out = append(out, tc)
}
// If the model clearly emitted tool_calls JSON but all names are outside the
// declared set, keep the parsed calls as a fallback so upper layers can still
// intercept structured tool output instead of leaking raw JSON to users.
if len(out) == 0 && len(parsed) > 0 {
for _, tc := range parsed {
if tc.Name == "" {
continue
}
if tc.Input == nil {
tc.Input = map[string]any{}
}
out = append(out, tc)
}
rejected := make([]string, 0, len(rejectedSet))
for name := range rejectedSet {
rejected = append(rejected, name)
}
return out
return out, rejected
}
func parseToolCallsPayload(payload string) []ParsedToolCall {

View File

@@ -38,14 +38,25 @@ func TestParseToolCallsWithFunctionArgumentsString(t *testing.T) {
}
}
func TestParseToolCallsKeepsUnknownAsFallback(t *testing.T) {
func TestParseToolCallsRejectsUnknownToolName(t *testing.T) {
text := `{"tool_calls":[{"name":"unknown","input":{}}]}`
calls := ParseToolCalls(text, []string{"search"})
if len(calls) != 1 {
t.Fatalf("expected fallback 1 call, got %d", len(calls))
if len(calls) != 0 {
t.Fatalf("expected unknown tool to be rejected, got %#v", calls)
}
if calls[0].Name != "unknown" {
t.Fatalf("unexpected name: %s", calls[0].Name)
}
func TestParseToolCallsDetailedMarksPolicyRejection(t *testing.T) {
text := `{"tool_calls":[{"name":"unknown","input":{}}]}`
res := ParseToolCallsDetailed(text, []string{"search"})
if !res.SawToolCallSyntax {
t.Fatalf("expected SawToolCallSyntax=true, got %#v", res)
}
if !res.RejectedByPolicy {
t.Fatalf("expected RejectedByPolicy=true, got %#v", res)
}
if len(res.Calls) != 0 {
t.Fatalf("expected no calls after policy rejection, got %#v", res.Calls)
}
}

View File

@@ -1,5 +1,3 @@
{
"calls": [
{"name": "unknown_tool", "input": {"x": 1}}
]
"calls": []
}