mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-05 00:45:29 +08:00
feat: Improve OpenAI tool call handling by passing unknown tool calls as content and filtering streamed tool calls by schema.
This commit is contained in:
24
API.en.md
24
API.en.md
@@ -286,8 +286,10 @@ OpenAI Responses-style endpoint, accepting either `input` or `messages`.
|
||||
| `instructions` | string | ❌ | Prepended as a system message |
|
||||
| `stream` | boolean | ❌ | Default `false` |
|
||||
| `tools` | array | ❌ | Same tool detection/translation policy as chat |
|
||||
| `tool_choice` | string/object | ❌ | Supports `auto`/`none`/`required` and forced function selection (`{"type":"function","name":"..."}`) |
|
||||
|
||||
**Non-stream**: Returns a standard `response` object with an ID like `resp_xxx`, and stores it in in-memory TTL cache.
|
||||
If `tool_choice=required` and no valid tool call is produced, DS2API returns HTTP `422` (`error.code=tool_choice_violation`).
|
||||
|
||||
**Stream (SSE)**: minimal event sequence:
|
||||
|
||||
@@ -295,11 +297,26 @@ OpenAI Responses-style endpoint, accepting either `input` or `messages`.
|
||||
event: response.created
|
||||
data: {"type":"response.created","id":"resp_xxx","status":"in_progress",...}
|
||||
|
||||
event: response.output_item.added
|
||||
data: {"type":"response.output_item.added","response_id":"resp_xxx","item":{"type":"message|function_call",...},...}
|
||||
|
||||
event: response.content_part.added
|
||||
data: {"type":"response.content_part.added","response_id":"resp_xxx","part":{"type":"output_text",...},...}
|
||||
|
||||
event: response.output_text.delta
|
||||
data: {"type":"response.output_text.delta","id":"resp_xxx","delta":"..."}
|
||||
|
||||
event: response.output_tool_call.delta
|
||||
data: {"type":"response.output_tool_call.delta","id":"resp_xxx","tool_calls":[...]}
|
||||
event: response.function_call_arguments.delta
|
||||
data: {"type":"response.function_call_arguments.delta","response_id":"resp_xxx","call_id":"call_xxx","delta":"..."}
|
||||
|
||||
event: response.function_call_arguments.done
|
||||
data: {"type":"response.function_call_arguments.done","response_id":"resp_xxx","call_id":"call_xxx","name":"tool","arguments":"{...}"}
|
||||
|
||||
event: response.content_part.done
|
||||
data: {"type":"response.content_part.done","response_id":"resp_xxx",...}
|
||||
|
||||
event: response.output_item.done
|
||||
data: {"type":"response.output_item.done","response_id":"resp_xxx","item":{"type":"message|function_call",...},...}
|
||||
|
||||
event: response.completed
|
||||
data: {"type":"response.completed","response":{...}}
|
||||
@@ -307,6 +324,9 @@ data: {"type":"response.completed","response":{...}}
|
||||
data: [DONE]
|
||||
```
|
||||
|
||||
If `tool_choice=required` is violated in stream mode, DS2API emits `response.failed` then `[DONE]` (no `response.completed`).
|
||||
Unknown tool names (outside declared `tools`) are rejected and will not be emitted as valid tool calls.
|
||||
|
||||
### `GET /v1/responses/{response_id}`
|
||||
|
||||
Business auth required. Fetches cached responses created by `POST /v1/responses` (caller-scoped; only the same key/token can read).
|
||||
|
||||
24
API.md
24
API.md
@@ -286,8 +286,10 @@ OpenAI Responses 风格接口,兼容 `input` 或 `messages`。
|
||||
| `instructions` | string | ❌ | 自动前置为 system 消息 |
|
||||
| `stream` | boolean | ❌ | 默认 `false` |
|
||||
| `tools` | array | ❌ | 与 chat 同样的工具识别与转译策略 |
|
||||
| `tool_choice` | string/object | ❌ | 支持 `auto`/`none`/`required` 与强制函数(`{"type":"function","name":"..."}`) |
|
||||
|
||||
**非流式响应**:返回标准 `response` 对象,`id` 形如 `resp_xxx`,并写入内存 TTL 存储。
|
||||
当 `tool_choice=required` 且未产出有效工具调用时,返回 HTTP `422`(`error.code=tool_choice_violation`)。
|
||||
|
||||
**流式响应(SSE)**:最小事件序列如下。
|
||||
|
||||
@@ -295,11 +297,26 @@ OpenAI Responses 风格接口,兼容 `input` 或 `messages`。
|
||||
event: response.created
|
||||
data: {"type":"response.created","id":"resp_xxx","status":"in_progress",...}
|
||||
|
||||
event: response.output_item.added
|
||||
data: {"type":"response.output_item.added","response_id":"resp_xxx","item":{"type":"message|function_call",...},...}
|
||||
|
||||
event: response.content_part.added
|
||||
data: {"type":"response.content_part.added","response_id":"resp_xxx","part":{"type":"output_text",...},...}
|
||||
|
||||
event: response.output_text.delta
|
||||
data: {"type":"response.output_text.delta","id":"resp_xxx","delta":"..."}
|
||||
|
||||
event: response.output_tool_call.delta
|
||||
data: {"type":"response.output_tool_call.delta","id":"resp_xxx","tool_calls":[...]}
|
||||
event: response.function_call_arguments.delta
|
||||
data: {"type":"response.function_call_arguments.delta","response_id":"resp_xxx","call_id":"call_xxx","delta":"..."}
|
||||
|
||||
event: response.function_call_arguments.done
|
||||
data: {"type":"response.function_call_arguments.done","response_id":"resp_xxx","call_id":"call_xxx","name":"tool","arguments":"{...}"}
|
||||
|
||||
event: response.content_part.done
|
||||
data: {"type":"response.content_part.done","response_id":"resp_xxx",...}
|
||||
|
||||
event: response.output_item.done
|
||||
data: {"type":"response.output_item.done","response_id":"resp_xxx","item":{"type":"message|function_call",...},...}
|
||||
|
||||
event: response.completed
|
||||
data: {"type":"response.completed","response":{...}}
|
||||
@@ -307,6 +324,9 @@ data: {"type":"response.completed","response":{...}}
|
||||
data: [DONE]
|
||||
```
|
||||
|
||||
流式场景下若 `tool_choice=required` 违规,会返回 `response.failed` 后结束(不再发送 `response.completed`)。
|
||||
未在 `tools` 声明中的工具名会被严格拒绝,不会作为有效 tool call 下发。
|
||||
|
||||
### `GET /v1/responses/{response_id}`
|
||||
|
||||
需要业务鉴权。查询 `POST /v1/responses` 生成并缓存的 response 对象(按调用方鉴权隔离,仅同一 key/token 可读取)。
|
||||
|
||||
53
README.MD
53
README.MD
@@ -8,13 +8,13 @@
|
||||
|
||||
语言 / Language: [中文](README.MD) | [English](README.en.md)
|
||||
|
||||
将 DeepSeek Web 对话能力转换为 OpenAI 与 Claude 兼容 API。后端为 **Go 全量实现**,前端为 React WebUI 管理台(源码在 `webui/`,部署时自动构建到 `static/admin`)。
|
||||
将 DeepSeek Web 对话能力转换为 OpenAI、Claude 与 Gemini 兼容 API。后端为 **Go 全量实现**,前端为 React WebUI 管理台(源码在 `webui/`,部署时自动构建到 `static/admin`)。
|
||||
|
||||
## 架构概览
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
Client["🖥️ 客户端\n(OpenAI / Claude 兼容)"]
|
||||
Client["🖥️ 客户端\n(OpenAI / Claude / Gemini 兼容)"]
|
||||
|
||||
subgraph DS2API["DS2API 服务"]
|
||||
direction TB
|
||||
@@ -24,6 +24,7 @@ flowchart LR
|
||||
subgraph Adapters["适配器层"]
|
||||
OA["OpenAI 适配器\n/v1/*"]
|
||||
CA["Claude 适配器\n/anthropic/*"]
|
||||
GA["Gemini 适配器\n/v1beta/models/*"]
|
||||
end
|
||||
|
||||
subgraph Support["支撑模块"]
|
||||
@@ -38,11 +39,11 @@ flowchart LR
|
||||
DS["☁️ DeepSeek API"]
|
||||
|
||||
Client -- "请求" --> CORS --> Auth
|
||||
Auth --> OA & CA
|
||||
OA & CA -- "调用" --> DS
|
||||
Auth --> OA & CA & GA
|
||||
OA & CA & GA -- "调用" --> DS
|
||||
Auth --> Admin
|
||||
OA & CA -. "轮询选账号" .-> Pool
|
||||
OA & CA -. "计算 PoW" .-> PoW
|
||||
OA & CA & GA -. "轮询选账号" .-> Pool
|
||||
OA & CA & GA -. "计算 PoW" .-> PoW
|
||||
DS -- "响应" --> Client
|
||||
```
|
||||
|
||||
@@ -55,12 +56,13 @@ flowchart LR
|
||||
| 能力 | 说明 |
|
||||
| --- | --- |
|
||||
| OpenAI 兼容 | `GET /v1/models`、`GET /v1/models/{id}`、`POST /v1/chat/completions`、`POST /v1/responses`、`GET /v1/responses/{response_id}`、`POST /v1/embeddings` |
|
||||
| Claude 兼容 | `GET /anthropic/v1/models`、`POST /anthropic/v1/messages`、`POST /anthropic/v1/messages/count_tokens` |
|
||||
| Claude 兼容 | `GET /anthropic/v1/models`、`POST /anthropic/v1/messages`、`POST /anthropic/v1/messages/count_tokens`(及快捷路径 `/v1/messages`、`/messages`) |
|
||||
| Gemini 兼容 | `POST /v1beta/models/{model}:generateContent`、`POST /v1beta/models/{model}:streamGenerateContent`(及 `/v1/models/{model}:*` 路径) |
|
||||
| 多账号轮询 | 自动 token 刷新、邮箱/手机号双登录方式 |
|
||||
| 并发队列控制 | 每账号 in-flight 上限 + 等待队列,动态计算建议并发值 |
|
||||
| DeepSeek PoW | WASM 计算(`wazero`),无需外部 Node.js 依赖 |
|
||||
| Tool Calling | 防泄漏处理:非代码块高置信特征识别、`delta.tool_calls` 早发、结构化增量输出 |
|
||||
| Admin API | 配置管理、账号测试 / 批量测试、导入导出、Vercel 同步 |
|
||||
| Admin API | 配置管理、运行时设置热更新、账号测试 / 批量测试、导入导出、Vercel 同步 |
|
||||
| WebUI 管理台 | `/admin` 单页应用(中英文双语、深色模式) |
|
||||
| 运维探针 | `GET /healthz`(存活)、`GET /readyz`(就绪) |
|
||||
|
||||
@@ -72,6 +74,7 @@ flowchart LR
|
||||
| P0 | OpenAI SDK(JS/Python,chat + responses) | ✅ |
|
||||
| P0 | Vercel AI SDK(openai-compatible) | ✅ |
|
||||
| P0 | Anthropic SDK(messages) | ✅ |
|
||||
| P0 | Google Gemini SDK(generateContent) | ✅ |
|
||||
| P1 | LangChain / LlamaIndex / OpenWebUI(OpenAI 兼容接入) | ✅ |
|
||||
| P2 | MCP 独立桥接层 | 规划中 |
|
||||
|
||||
@@ -97,6 +100,10 @@ flowchart LR
|
||||
可通过配置中的 `claude_mapping` 或 `claude_model_mapping` 覆盖映射关系。
|
||||
另外,`/anthropic/v1/models` 现已包含 Claude 1.x/2.x/3.x/4.x 历史模型 ID 与常见别名,便于旧客户端直接兼容。
|
||||
|
||||
### Gemini 接口
|
||||
|
||||
Gemini 适配器将模型名通过 `model_aliases` 或内置规则映射到 DeepSeek 原生模型,支持 `generateContent` 和 `streamGenerateContent` 两种调用方式,并完整支持 Tool Calling(`functionDeclarations` → `functionCall` 输出)。
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 通用第一步(所有部署方式)
|
||||
@@ -249,6 +256,14 @@ cp opencode.json.example opencode.json
|
||||
"claude_model_mapping": {
|
||||
"fast": "deepseek-chat",
|
||||
"slow": "deepseek-reasoner"
|
||||
},
|
||||
"admin": {
|
||||
"jwt_expire_hours": 24
|
||||
},
|
||||
"runtime": {
|
||||
"account_max_inflight": 2,
|
||||
"account_max_queue": 0,
|
||||
"global_max_inflight": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -262,6 +277,8 @@ cp opencode.json.example opencode.json
|
||||
- `responses.store_ttl_seconds`:`/v1/responses/{id}` 的内存缓存 TTL
|
||||
- `embeddings.provider`:embedding 提供方(当前内置 `deterministic/mock/builtin`)
|
||||
- `claude_model_mapping`:字典中 `fast`/`slow` 后缀映射到对应 DeepSeek 模型
|
||||
- `admin`:管理后台设置(JWT 过期时间、密码哈希等),可通过 Admin Settings API 热更新
|
||||
- `runtime`:运行时参数(并发限制、队列大小),可通过 Admin Settings API 热更新
|
||||
|
||||
### 环境变量
|
||||
|
||||
@@ -293,7 +310,7 @@ cp opencode.json.example opencode.json
|
||||
|
||||
## 鉴权模式
|
||||
|
||||
调用业务接口(`/v1/*`、`/anthropic/*`)时支持两种模式:
|
||||
调用业务接口(`/v1/*`、`/anthropic/*`、Gemini 路由)时支持两种模式:
|
||||
|
||||
| 模式 | 说明 |
|
||||
| --- | --- |
|
||||
@@ -320,9 +337,10 @@ cp opencode.json.example opencode.json
|
||||
当请求中带 `tools` 时,DS2API 会做防泄漏处理:
|
||||
|
||||
1. 只在**非代码块上下文**启用 toolcall 特征识别(代码块示例不会触发)
|
||||
2. 一旦命中高置信特征(`tool_calls` + `name` + `arguments/input` 起始)就立即输出 `delta.tool_calls`
|
||||
3. 已确认的 toolcall JSON 片段不会泄漏到 `delta.content`
|
||||
4. 前文/后文自然语言保持顺序透传,支持混合文本与增量参数输出
|
||||
2. `responses` 流式严格使用官方 item 生命周期事件(`response.output_item.*`、`response.content_part.*`、`response.function_call_arguments.*`)
|
||||
3. 未在 `tools` 声明中的工具名会被严格拒绝,不会下发为有效 tool call
|
||||
4. `responses` 支持并执行 `tool_choice`(`auto`/`none`/`required`/强制函数);`required` 违规时非流式返回 `422`,流式返回 `response.failed`
|
||||
5. 仅在通过策略校验后才会发出有效工具调用事件,避免错误工具名进入客户端执行链
|
||||
|
||||
## 本地开发抓包工具
|
||||
|
||||
@@ -362,13 +380,20 @@ ds2api/
|
||||
│ ├── account/ # 账号池与并发队列
|
||||
│ ├── adapter/
|
||||
│ │ ├── openai/ # OpenAI 兼容适配器(含 Tool Call 解析、Vercel 流式 prepare/release)
|
||||
│ │ └── claude/ # Claude 兼容适配器
|
||||
│ ├── admin/ # Admin API handlers
|
||||
│ │ ├── claude/ # Claude 兼容适配器
|
||||
│ │ └── gemini/ # Gemini 兼容适配器(generateContent / streamGenerateContent)
|
||||
│ ├── admin/ # Admin API handlers(含 Settings 热更新)
|
||||
│ ├── auth/ # 鉴权与 JWT
|
||||
│ ├── claudeconv/ # Claude 消息格式转换
|
||||
│ ├── compat/ # 兼容性辅助
|
||||
│ ├── config/ # 配置加载与热更新
|
||||
│ ├── deepseek/ # DeepSeek API 客户端、PoW WASM
|
||||
│ ├── devcapture/ # 开发抓包模块
|
||||
│ ├── format/ # 输出格式化
|
||||
│ ├── prompt/ # Prompt 构建
|
||||
│ ├── server/ # HTTP 路由与中间件(chi router)
|
||||
│ ├── sse/ # SSE 解析工具
|
||||
│ ├── stream/ # 统一流式消费引擎
|
||||
│ ├── util/ # 通用工具函数
|
||||
│ └── webui/ # WebUI 静态文件托管与自动构建
|
||||
├── webui/ # React WebUI 源码(Vite + Tailwind)
|
||||
|
||||
@@ -320,9 +320,10 @@ Queue limit = DS2API_ACCOUNT_MAX_QUEUE (default = recommended concurrency)
|
||||
When `tools` is present in the request, DS2API performs anti-leak handling:
|
||||
|
||||
1. Toolcall feature matching is enabled only in **non-code-block context** (fenced examples are ignored)
|
||||
2. Once high-confidence features are matched (`tool_calls` + `name` + `arguments/input` start), `delta.tool_calls` is emitted immediately
|
||||
3. Confirmed toolcall JSON fragments are never leaked into `delta.content`
|
||||
4. Natural language before/after toolcalls keeps original order, with incremental argument output supported
|
||||
2. In `responses` stream mode, tool calls follow official item lifecycle events (`response.output_item.*`, `response.content_part.*`, `response.function_call_arguments.*`)
|
||||
3. Unknown tool names (outside declared `tools`) are rejected and are not emitted as valid tool calls
|
||||
4. `tool_choice` is enforced on `responses` (`auto`/`none`/`required`/forced function); required violations return HTTP `422` (non-stream) or `response.failed` (stream)
|
||||
5. Confirmed toolcall JSON fragments are never emitted as valid tool call events unless they pass policy checks
|
||||
|
||||
## Local Dev Packet Capture
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ type chatStreamRuntime struct {
|
||||
|
||||
toolSieve toolStreamSieveState
|
||||
streamToolCallIDs map[int]string
|
||||
streamToolNames map[int]string
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
}
|
||||
@@ -65,6 +66,7 @@ func newChatStreamRuntime(
|
||||
bufferToolContent: bufferToolContent,
|
||||
emitEarlyToolDeltas: emitEarlyToolDeltas,
|
||||
streamToolCallIDs: map[int]string{},
|
||||
streamToolNames: map[int]string{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -211,7 +213,11 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
|
||||
if !s.emitEarlyToolDeltas {
|
||||
continue
|
||||
}
|
||||
formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
|
||||
filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.toolNames, s.streamToolNames)
|
||||
if len(filtered) == 0 {
|
||||
continue
|
||||
}
|
||||
formatted := formatIncrementalStreamToolCallDeltas(filtered, s.streamToolCallIDs)
|
||||
if len(formatted) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -3,11 +3,18 @@ package openai
|
||||
import "net/http"
|
||||
|
||||
func writeOpenAIError(w http.ResponseWriter, status int, message string) {
|
||||
writeOpenAIErrorWithCode(w, status, message, "")
|
||||
}
|
||||
|
||||
func writeOpenAIErrorWithCode(w http.ResponseWriter, status int, message, code string) {
|
||||
if code == "" {
|
||||
code = openAIErrorCode(status)
|
||||
}
|
||||
writeJSON(w, status, map[string]any{
|
||||
"error": map[string]any{
|
||||
"message": message,
|
||||
"type": openAIErrorType(status),
|
||||
"code": openAIErrorCode(status),
|
||||
"code": code,
|
||||
"param": nil,
|
||||
},
|
||||
})
|
||||
|
||||
@@ -10,9 +10,23 @@ import (
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) {
|
||||
func injectToolPrompt(messages []map[string]any, tools []any, policy util.ToolChoicePolicy) ([]map[string]any, []string) {
|
||||
if policy.IsNone() {
|
||||
return messages, nil
|
||||
}
|
||||
toolSchemas := make([]string, 0, len(tools))
|
||||
names := make([]string, 0, len(tools))
|
||||
isAllowed := func(name string) bool {
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return false
|
||||
}
|
||||
if len(policy.Allowed) == 0 {
|
||||
return true
|
||||
}
|
||||
_, ok := policy.Allowed[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
for _, t := range tools {
|
||||
tool, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
@@ -25,8 +39,9 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any,
|
||||
name, _ := fn["name"].(string)
|
||||
desc, _ := fn["description"].(string)
|
||||
schema, _ := fn["parameters"].(map[string]any)
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
name = strings.TrimSpace(name)
|
||||
if !isAllowed(name) {
|
||||
continue
|
||||
}
|
||||
names = append(names, name)
|
||||
if desc == "" {
|
||||
@@ -39,6 +54,13 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any,
|
||||
return messages, names
|
||||
}
|
||||
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block."
|
||||
if policy.Mode == util.ToolChoiceRequired {
|
||||
toolPrompt += "\n5) For this response, you MUST call at least one tool from the allowed list."
|
||||
}
|
||||
if policy.Mode == util.ToolChoiceForced && strings.TrimSpace(policy.ForcedName) != "" {
|
||||
toolPrompt += "\n5) For this response, you MUST call exactly this tool name: " + strings.TrimSpace(policy.ForcedName)
|
||||
toolPrompt += "\n6) Do not call any other tool."
|
||||
}
|
||||
|
||||
for i := range messages {
|
||||
if messages[i]["role"] == "system" {
|
||||
@@ -85,6 +107,33 @@ func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]s
|
||||
return out
|
||||
}
|
||||
|
||||
func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, allowedNames []string, seenNames map[int]string) []toolCallDelta {
|
||||
if len(deltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
allowed := namesToSet(allowedNames)
|
||||
out := make([]toolCallDelta, 0, len(deltas))
|
||||
for _, d := range deltas {
|
||||
if d.Name != "" {
|
||||
if len(allowed) > 0 {
|
||||
if _, ok := allowed[d.Name]; !ok {
|
||||
seenNames[d.Index] = "__blocked__"
|
||||
continue
|
||||
}
|
||||
}
|
||||
seenNames[d.Index] = d.Name
|
||||
out = append(out, d)
|
||||
continue
|
||||
}
|
||||
name := strings.TrimSpace(seenNames[d.Index])
|
||||
if name == "" || name == "__blocked__" {
|
||||
continue
|
||||
}
|
||||
out = append(out, d)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any {
|
||||
if len(calls) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -181,7 +181,7 @@ func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) {
|
||||
func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
@@ -197,16 +197,16 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) {
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
|
||||
if choice["finish_reason"] != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
if msg["content"] != nil {
|
||||
t.Fatalf("expected content nil, got %#v", msg["content"])
|
||||
if _, ok := msg["tool_calls"]; ok {
|
||||
t.Fatalf("did not expect tool_calls for unknown schema name, got %#v", msg["tool_calls"])
|
||||
}
|
||||
toolCalls, _ := msg["tool_calls"].([]any)
|
||||
if len(toolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %#v", msg["tool_calls"])
|
||||
content, _ := msg["content"].(string)
|
||||
if !strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected unknown tool json to pass through as text, got %#v", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -375,7 +375,7 @@ func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamUnknownToolStillIntercepted(t *testing.T) {
|
||||
func TestHandleStreamUnknownToolNotIntercepted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
@@ -390,29 +390,14 @@ func TestHandleStreamUnknownToolStillIntercepted(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta for unknown schema name, body=%s", rec.Body.String())
|
||||
}
|
||||
foundToolIndex := false
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
toolCalls, _ := delta["tool_calls"].([]any)
|
||||
for _, tc := range toolCalls {
|
||||
tcm, _ := tc.(map[string]any)
|
||||
if _, ok := tcm["index"].(float64); ok {
|
||||
foundToolIndex = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if !streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("expected raw tool_calls json to remain in content for unknown schema name: %s", rec.Body.String())
|
||||
}
|
||||
if !foundToolIndex {
|
||||
t.Fatalf("expected stream tool_calls item with index, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,13 +2,18 @@ package openai
|
||||
|
||||
import (
|
||||
"ds2api/internal/deepseek"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any, traceID string) (string, []string) {
|
||||
return buildOpenAIFinalPromptWithPolicy(messagesRaw, toolsRaw, traceID, util.DefaultToolChoicePolicy())
|
||||
}
|
||||
|
||||
func buildOpenAIFinalPromptWithPolicy(messagesRaw []any, toolsRaw any, traceID string, toolPolicy util.ToolChoicePolicy) (string, []string) {
|
||||
messages := normalizeOpenAIMessagesForPrompt(messagesRaw, traceID)
|
||||
toolNames := []string{}
|
||||
if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 {
|
||||
messages, toolNames = injectToolPrompt(messages, tools)
|
||||
messages, toolNames = injectToolPrompt(messages, tools, toolPolicy)
|
||||
}
|
||||
return deepseek.MessagesPrepare(messages), toolNames
|
||||
}
|
||||
|
||||
@@ -73,6 +73,32 @@ func TestNormalizeResponsesInputAsMessagesFunctionCallOutput(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesInputAsMessagesBackfillsToolResultNameFromCallID(t *testing.T) {
|
||||
msgs := normalizeResponsesInputAsMessages([]any{
|
||||
map[string]any{
|
||||
"type": "function_call",
|
||||
"call_id": "call_999",
|
||||
"name": "search",
|
||||
"arguments": `{"q":"golang"}`,
|
||||
},
|
||||
map[string]any{
|
||||
"type": "function_call_output",
|
||||
"call_id": "call_999",
|
||||
"output": map[string]any{"ok": true},
|
||||
},
|
||||
})
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("expected two messages, got %d", len(msgs))
|
||||
}
|
||||
toolMsg, _ := msgs[1].(map[string]any)
|
||||
if toolMsg["role"] != "tool" {
|
||||
t.Fatalf("expected tool role, got %#v", toolMsg)
|
||||
}
|
||||
if toolMsg["name"] != "search" {
|
||||
t.Fatalf("expected tool name backfilled from call_id, got %#v", toolMsg["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) {
|
||||
msgs := normalizeResponsesInputAsMessages([]any{
|
||||
map[string]any{
|
||||
|
||||
@@ -11,10 +11,12 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -67,7 +69,8 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req, requestTraceID(r))
|
||||
traceID := requestTraceID(r)
|
||||
stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req, traceID)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
@@ -96,13 +99,13 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
if stdReq.Stream {
|
||||
h.handleResponsesStream(w, r, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
|
||||
h.handleResponsesStream(w, r, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID)
|
||||
return
|
||||
}
|
||||
h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
|
||||
h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames, stdReq.ToolChoice, traceID)
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
|
||||
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string, toolChoice util.ToolChoicePolicy, traceID string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -110,12 +113,26 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
|
||||
return
|
||||
}
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
textParsed := util.ParseToolCallsDetailed(result.Text, toolNames)
|
||||
thinkingParsed := util.ParseToolCallsDetailed(result.Thinking, toolNames)
|
||||
logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text")
|
||||
logResponsesToolPolicyRejection(traceID, toolChoice, thinkingParsed, "thinking")
|
||||
|
||||
callCount := len(textParsed.Calls)
|
||||
if callCount == 0 {
|
||||
callCount = len(thinkingParsed.Calls)
|
||||
}
|
||||
if toolChoice.IsRequired() && callCount == 0 {
|
||||
writeOpenAIErrorWithCode(w, http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation")
|
||||
return
|
||||
}
|
||||
|
||||
responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames)
|
||||
h.getResponseStore().put(owner, responseID, responseObj)
|
||||
writeJSON(w, http.StatusOK, responseObj)
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice util.ToolChoicePolicy, traceID string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -148,6 +165,8 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
|
||||
toolNames,
|
||||
bufferToolContent,
|
||||
emitEarlyToolDeltas,
|
||||
toolChoice,
|
||||
traceID,
|
||||
func(obj map[string]any) {
|
||||
h.getResponseStore().put(owner, responseID, obj)
|
||||
},
|
||||
@@ -169,3 +188,16 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func logResponsesToolPolicyRejection(traceID string, policy util.ToolChoicePolicy, parsed util.ToolCallParseResult, channel string) {
|
||||
if !parsed.RejectedByPolicy || len(parsed.RejectedToolNames) == 0 {
|
||||
return
|
||||
}
|
||||
config.Logger.Warn(
|
||||
"[responses] rejected tool calls by policy",
|
||||
"trace_id", strings.TrimSpace(traceID),
|
||||
"channel", channel,
|
||||
"tool_choice_mode", policy.Mode,
|
||||
"rejected_tool_names", strings.Join(parsed.RejectedToolNames, ","),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -4,9 +4,15 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func normalizeResponsesInputItem(m map[string]any) map[string]any {
|
||||
return normalizeResponsesInputItemWithState(m, nil)
|
||||
}
|
||||
|
||||
func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[string]string) map[string]any {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -69,6 +75,15 @@ func normalizeResponsesInputItem(m map[string]any) map[string]any {
|
||||
out["name"] = name
|
||||
} else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" {
|
||||
out["name"] = name
|
||||
} else if callID := strings.TrimSpace(asString(out["tool_call_id"])); callID != "" {
|
||||
if inferred := strings.TrimSpace(callNameByID[callID]); inferred != "" {
|
||||
out["name"] = inferred
|
||||
} else {
|
||||
config.Logger.Warn(
|
||||
"[responses] unable to backfill tool result name from call_id",
|
||||
"call_id", callID,
|
||||
)
|
||||
}
|
||||
}
|
||||
return out
|
||||
case "function_call", "tool_call":
|
||||
@@ -111,6 +126,9 @@ func normalizeResponsesInputItem(m map[string]any) map[string]any {
|
||||
} else if callID = strings.TrimSpace(asString(m["id"])); callID != "" {
|
||||
call["id"] = callID
|
||||
}
|
||||
if callID := strings.TrimSpace(asString(call["id"])); callID != "" && callNameByID != nil {
|
||||
callNameByID[callID] = name
|
||||
}
|
||||
return map[string]any{
|
||||
"role": "assistant",
|
||||
"tool_calls": []any{call},
|
||||
|
||||
@@ -59,6 +59,7 @@ func normalizeResponsesInputArray(items []any) []any {
|
||||
return nil
|
||||
}
|
||||
out := make([]any, 0, len(items))
|
||||
callNameByID := map[string]string{}
|
||||
fallbackParts := make([]string, 0, len(items))
|
||||
flushFallback := func() {
|
||||
if len(fallbackParts) == 0 {
|
||||
@@ -71,7 +72,7 @@ func normalizeResponsesInputArray(items []any) []any {
|
||||
for _, item := range items {
|
||||
switch x := item.(type) {
|
||||
case map[string]any:
|
||||
if msg := normalizeResponsesInputItem(x); msg != nil {
|
||||
if msg := normalizeResponsesInputItemWithState(x, callNameByID); msg != nil {
|
||||
flushFallback()
|
||||
out = append(out, msg)
|
||||
continue
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
@@ -19,6 +20,8 @@ type responsesStreamRuntime struct {
|
||||
model string
|
||||
finalPrompt string
|
||||
toolNames []string
|
||||
traceID string
|
||||
toolChoice util.ToolChoicePolicy
|
||||
|
||||
thinkingEnabled bool
|
||||
searchEnabled bool
|
||||
@@ -32,11 +35,19 @@ type responsesStreamRuntime struct {
|
||||
thinkingSieve toolStreamSieveState
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
visibleText strings.Builder
|
||||
streamToolCallIDs map[int]string
|
||||
streamFunctionIDs map[int]string
|
||||
functionDone map[int]bool
|
||||
functionAdded map[int]bool
|
||||
functionNames map[int]string
|
||||
toolCallsDoneSigs map[string]bool
|
||||
reasoningItemID string
|
||||
messageItemID string
|
||||
messageAdded bool
|
||||
messagePartAdded bool
|
||||
sequence int
|
||||
failed bool
|
||||
|
||||
persistResponse func(obj map[string]any)
|
||||
}
|
||||
@@ -53,6 +64,8 @@ func newResponsesStreamRuntime(
|
||||
toolNames []string,
|
||||
bufferToolContent bool,
|
||||
emitEarlyToolDeltas bool,
|
||||
toolChoice util.ToolChoicePolicy,
|
||||
traceID string,
|
||||
persistResponse func(obj map[string]any),
|
||||
) *responsesStreamRuntime {
|
||||
return &responsesStreamRuntime{
|
||||
@@ -70,7 +83,11 @@ func newResponsesStreamRuntime(
|
||||
streamToolCallIDs: map[int]string{},
|
||||
streamFunctionIDs: map[int]string{},
|
||||
functionDone: map[int]bool{},
|
||||
functionAdded: map[int]bool{},
|
||||
functionNames: map[int]string{},
|
||||
toolCallsDoneSigs: map[string]bool{},
|
||||
toolChoice: toolChoice,
|
||||
traceID: traceID,
|
||||
persistResponse: persistResponse,
|
||||
}
|
||||
}
|
||||
@@ -78,36 +95,59 @@ func newResponsesStreamRuntime(
|
||||
func (s *responsesStreamRuntime) finalize() {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := s.text.String()
|
||||
if strings.TrimSpace(finalThinking) != "" {
|
||||
s.sendEvent("response.reasoning_text.done", openaifmt.BuildResponsesReasoningTextDonePayload(s.responseID, s.ensureReasoningItemID(), 0, 0, finalThinking))
|
||||
}
|
||||
|
||||
if s.bufferToolContent {
|
||||
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
|
||||
s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false)
|
||||
}
|
||||
// Compatibility fallback: some streams only emit incremental tool deltas.
|
||||
// Ensure final function_call_arguments.done is emitted at least once.
|
||||
if s.toolCallsEmitted {
|
||||
detected := util.ParseToolCalls(finalText, s.toolNames)
|
||||
if len(detected) == 0 {
|
||||
detected = util.ParseToolCalls(finalThinking, s.toolNames)
|
||||
|
||||
textParsed := util.ParseToolCallsDetailed(finalText, s.toolNames)
|
||||
thinkingParsed := util.ParseToolCallsDetailed(finalThinking, s.toolNames)
|
||||
detected := textParsed.Calls
|
||||
if len(detected) == 0 {
|
||||
detected = thinkingParsed.Calls
|
||||
}
|
||||
s.logToolPolicyRejections(textParsed, thinkingParsed)
|
||||
|
||||
if len(detected) > 0 {
|
||||
s.toolCallsEmitted = true
|
||||
if !s.toolCallsDoneEmitted {
|
||||
s.emitFunctionCallDoneEvents(detected)
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
if !s.toolCallsDoneEmitted {
|
||||
s.emitToolCallsDone(detected)
|
||||
} else {
|
||||
s.emitFunctionCallDoneEvents(detected)
|
||||
}
|
||||
}
|
||||
|
||||
s.closeMessageItem()
|
||||
|
||||
if s.toolChoice.IsRequired() && !s.hasFunctionCallDone() {
|
||||
s.failed = true
|
||||
message := "tool_choice requires at least one valid tool call."
|
||||
failedResp := map[string]any{
|
||||
"id": s.responseID,
|
||||
"type": "response",
|
||||
"object": "response",
|
||||
"model": s.model,
|
||||
"status": "failed",
|
||||
"output": []any{},
|
||||
"output_text": "",
|
||||
"error": map[string]any{
|
||||
"message": message,
|
||||
"type": "invalid_request_error",
|
||||
"code": "tool_choice_violation",
|
||||
"param": nil,
|
||||
},
|
||||
}
|
||||
if s.persistResponse != nil {
|
||||
s.persistResponse(failedResp)
|
||||
}
|
||||
s.sendEvent("response.failed", openaifmt.BuildResponsesFailedPayload(s.responseID, s.model, message, "tool_choice_violation"))
|
||||
s.sendDone()
|
||||
return
|
||||
}
|
||||
|
||||
obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames)
|
||||
if s.toolCallsEmitted {
|
||||
s.alignCompletedOutputCallIDs(obj)
|
||||
}
|
||||
if s.toolCallsEmitted {
|
||||
obj["status"] = "completed"
|
||||
}
|
||||
if s.persistResponse != nil {
|
||||
s.persistResponse(obj)
|
||||
}
|
||||
@@ -115,6 +155,32 @@ func (s *responsesStreamRuntime) finalize() {
|
||||
s.sendDone()
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingParsed util.ToolCallParseResult) {
|
||||
logRejected := func(parsed util.ToolCallParseResult, channel string) {
|
||||
if !parsed.RejectedByPolicy || len(parsed.RejectedToolNames) == 0 {
|
||||
return
|
||||
}
|
||||
config.Logger.Warn(
|
||||
"[responses] rejected tool calls by policy",
|
||||
"trace_id", strings.TrimSpace(s.traceID),
|
||||
"channel", channel,
|
||||
"tool_choice_mode", s.toolChoice.Mode,
|
||||
"rejected_tool_names", strings.Join(parsed.RejectedToolNames, ","),
|
||||
)
|
||||
}
|
||||
logRejected(textParsed, "text")
|
||||
logRejected(thinkingParsed, "thinking")
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) hasFunctionCallDone() bool {
|
||||
for _, done := range s.functionDone {
|
||||
if done {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
@@ -138,7 +204,6 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
|
||||
}
|
||||
s.thinking.WriteString(p.Text)
|
||||
s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
|
||||
s.sendEvent("response.reasoning_text.delta", openaifmt.BuildResponsesReasoningTextDeltaPayload(s.responseID, s.ensureReasoningItemID(), 0, 0, p.Text))
|
||||
if s.bufferToolContent {
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false)
|
||||
}
|
||||
@@ -147,7 +212,7 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
|
||||
|
||||
s.text.WriteString(p.Text)
|
||||
if !s.bufferToolContent {
|
||||
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text))
|
||||
s.emitTextDelta(p.Text)
|
||||
continue
|
||||
}
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true)
|
||||
|
||||
@@ -6,7 +6,18 @@ import (
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
)
|
||||
|
||||
func (s *responsesStreamRuntime) nextSequence() int {
|
||||
s.sequence++
|
||||
return s.sequence
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) sendEvent(event string, payload map[string]any) {
|
||||
if payload == nil {
|
||||
payload = map[string]any{}
|
||||
}
|
||||
if _, ok := payload["sequence_number"]; !ok {
|
||||
payload["sequence_number"] = s.nextSequence()
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
_, _ = s.w.Write([]byte("event: " + event + "\n"))
|
||||
_, _ = s.w.Write([]byte("data: "))
|
||||
@@ -31,22 +42,20 @@ func (s *responsesStreamRuntime) sendDone() {
|
||||
func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
|
||||
for _, evt := range events {
|
||||
if emitContent && evt.Content != "" {
|
||||
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
|
||||
s.emitTextDelta(evt.Content)
|
||||
}
|
||||
if len(evt.ToolCallDeltas) > 0 {
|
||||
if !s.emitEarlyToolDeltas {
|
||||
continue
|
||||
}
|
||||
formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
|
||||
if len(formatted) == 0 {
|
||||
filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.toolNames, s.functionNames)
|
||||
if len(filtered) == 0 {
|
||||
continue
|
||||
}
|
||||
s.toolCallsEmitted = true
|
||||
s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatted))
|
||||
s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
|
||||
s.emitFunctionCallDeltaEvents(filtered)
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
s.emitToolCallsDone(evt.ToolCalls)
|
||||
s.emitFunctionCallDoneEvents(evt.ToolCalls)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,25 +11,101 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func (s *responsesStreamRuntime) emitToolCallsDone(calls []util.ParsedToolCall) {
|
||||
if len(calls) == 0 {
|
||||
func (s *responsesStreamRuntime) ensureMessageItemID() string {
|
||||
if strings.TrimSpace(s.messageItemID) != "" {
|
||||
return s.messageItemID
|
||||
}
|
||||
s.messageItemID = "msg_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
return s.messageItemID
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) messageOutputIndex() int {
|
||||
if strings.TrimSpace(s.thinking.String()) != "" {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureMessageItemAdded() {
|
||||
if s.messageAdded {
|
||||
return
|
||||
}
|
||||
sig := toolCallListSignature(calls)
|
||||
if sig != "" && s.toolCallsDoneSigs[sig] {
|
||||
itemID := s.ensureMessageItemID()
|
||||
item := map[string]any{
|
||||
"id": itemID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "in_progress",
|
||||
}
|
||||
s.sendEvent(
|
||||
"response.output_item.added",
|
||||
openaifmt.BuildResponsesOutputItemAddedPayload(s.responseID, itemID, s.messageOutputIndex(), item),
|
||||
)
|
||||
s.messageAdded = true
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureMessageContentPartAdded() {
|
||||
if s.messagePartAdded {
|
||||
return
|
||||
}
|
||||
if sig != "" {
|
||||
s.toolCallsDoneSigs[sig] = true
|
||||
}
|
||||
formatted := formatFinalStreamToolCallsWithStableIDs(calls, s.streamToolCallIDs)
|
||||
if len(formatted) == 0 {
|
||||
s.ensureMessageItemAdded()
|
||||
s.sendEvent(
|
||||
"response.content_part.added",
|
||||
openaifmt.BuildResponsesContentPartAddedPayload(
|
||||
s.responseID,
|
||||
s.ensureMessageItemID(),
|
||||
s.messageOutputIndex(),
|
||||
0,
|
||||
map[string]any{"type": "output_text", "text": ""},
|
||||
),
|
||||
)
|
||||
s.messagePartAdded = true
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) emitTextDelta(content string) {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return
|
||||
}
|
||||
s.toolCallsEmitted = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatted))
|
||||
s.emitFunctionCallDoneEvents(calls)
|
||||
s.ensureMessageContentPartAdded()
|
||||
s.visibleText.WriteString(content)
|
||||
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, content))
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) closeMessageItem() {
|
||||
if !s.messageAdded {
|
||||
return
|
||||
}
|
||||
itemID := s.ensureMessageItemID()
|
||||
text := s.visibleText.String()
|
||||
if s.messagePartAdded {
|
||||
s.sendEvent(
|
||||
"response.content_part.done",
|
||||
openaifmt.BuildResponsesContentPartDonePayload(
|
||||
s.responseID,
|
||||
itemID,
|
||||
s.messageOutputIndex(),
|
||||
0,
|
||||
map[string]any{"type": "output_text", "text": text},
|
||||
),
|
||||
)
|
||||
s.messagePartAdded = false
|
||||
}
|
||||
item := map[string]any{
|
||||
"id": itemID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "completed",
|
||||
"content": []map[string]any{
|
||||
{
|
||||
"type": "output_text",
|
||||
"text": text,
|
||||
},
|
||||
},
|
||||
}
|
||||
s.sendEvent(
|
||||
"response.output_item.done",
|
||||
openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, s.messageOutputIndex(), item),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureReasoningItemID() string {
|
||||
@@ -65,12 +141,47 @@ func (s *responsesStreamRuntime) functionOutputBaseIndex() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) functionOutputIndex(callIndex int) int {
|
||||
return s.functionOutputBaseIndex() + callIndex
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureFunctionItemAdded(callIndex int, name string) {
|
||||
if strings.TrimSpace(name) != "" {
|
||||
s.functionNames[callIndex] = strings.TrimSpace(name)
|
||||
}
|
||||
if s.functionAdded[callIndex] {
|
||||
return
|
||||
}
|
||||
fnName := strings.TrimSpace(s.functionNames[callIndex])
|
||||
if fnName == "" {
|
||||
return
|
||||
}
|
||||
outputIndex := s.functionOutputIndex(callIndex)
|
||||
itemID := s.ensureFunctionItemID(outputIndex)
|
||||
callID := s.ensureToolCallID(callIndex)
|
||||
item := map[string]any{
|
||||
"id": itemID,
|
||||
"type": "function_call",
|
||||
"call_id": callID,
|
||||
"name": fnName,
|
||||
"arguments": "{}",
|
||||
"status": "in_progress",
|
||||
}
|
||||
s.sendEvent(
|
||||
"response.output_item.added",
|
||||
openaifmt.BuildResponsesOutputItemAddedPayload(s.responseID, itemID, outputIndex, item),
|
||||
)
|
||||
s.functionAdded[callIndex] = true
|
||||
s.toolCallsEmitted = true
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDelta) {
|
||||
for _, d := range deltas {
|
||||
s.ensureFunctionItemAdded(d.Index, d.Name)
|
||||
if strings.TrimSpace(d.Arguments) == "" {
|
||||
continue
|
||||
}
|
||||
outputIndex := s.functionOutputBaseIndex() + d.Index
|
||||
outputIndex := s.functionOutputIndex(d.Index)
|
||||
itemID := s.ensureFunctionItemID(outputIndex)
|
||||
callID := s.ensureToolCallID(d.Index)
|
||||
s.sendEvent(
|
||||
@@ -86,6 +197,8 @@ func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedT
|
||||
if strings.TrimSpace(tc.Name) == "" {
|
||||
continue
|
||||
}
|
||||
s.ensureFunctionItemAdded(idx, tc.Name)
|
||||
|
||||
outputIndex := base + idx
|
||||
if s.functionDone[outputIndex] {
|
||||
continue
|
||||
@@ -93,11 +206,25 @@ func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedT
|
||||
itemID := s.ensureFunctionItemID(outputIndex)
|
||||
callID := s.ensureToolCallID(idx)
|
||||
argsBytes, _ := json.Marshal(tc.Input)
|
||||
args := string(argsBytes)
|
||||
s.sendEvent(
|
||||
"response.function_call_arguments.done",
|
||||
openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, string(argsBytes)),
|
||||
openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, args),
|
||||
)
|
||||
item := map[string]any{
|
||||
"id": itemID,
|
||||
"type": "function_call",
|
||||
"call_id": callID,
|
||||
"name": tc.Name,
|
||||
"arguments": args,
|
||||
"status": "completed",
|
||||
}
|
||||
s.sendEvent(
|
||||
"response.output_item.done",
|
||||
openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, outputIndex, item),
|
||||
)
|
||||
s.functionDone[outputIndex] = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,41 +259,12 @@ func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any)
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
typ, _ := m["type"].(string)
|
||||
switch typ {
|
||||
case "function_call":
|
||||
if functionIdx < len(ordered) {
|
||||
m["call_id"] = ordered[functionIdx]
|
||||
functionIdx++
|
||||
}
|
||||
case "tool_calls":
|
||||
tcArr, _ := m["tool_calls"].([]any)
|
||||
for i, raw := range tcArr {
|
||||
tc, _ := raw.(map[string]any)
|
||||
if tc == nil {
|
||||
continue
|
||||
}
|
||||
if i < len(ordered) {
|
||||
tc["id"] = ordered[i]
|
||||
}
|
||||
}
|
||||
if m["type"] != "function_call" {
|
||||
continue
|
||||
}
|
||||
if functionIdx < len(ordered) {
|
||||
m["call_id"] = ordered[functionIdx]
|
||||
functionIdx++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func toolCallListSignature(calls []util.ParsedToolCall) string {
|
||||
if len(calls) == 0 {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for i, tc := range calls {
|
||||
if i > 0 {
|
||||
b.WriteString("|")
|
||||
}
|
||||
b.WriteString(strings.TrimSpace(tc.Name))
|
||||
b.WriteString(":")
|
||||
args, _ := json.Marshal(tc.Input)
|
||||
b.Write(args)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T) {
|
||||
@@ -30,7 +32,7 @@ func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
|
||||
completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
|
||||
if !ok {
|
||||
@@ -45,8 +47,8 @@ func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T
|
||||
if len(output) == 0 {
|
||||
t.Fatalf("expected structured output entries, got %#v", responseObj["output"])
|
||||
}
|
||||
var firstToolWrapper map[string]any
|
||||
hasFunctionCall := false
|
||||
hasLegacyWrapper := false
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m == nil {
|
||||
@@ -55,96 +57,22 @@ func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T
|
||||
if m["type"] == "function_call" {
|
||||
hasFunctionCall = true
|
||||
}
|
||||
if m["type"] == "tool_calls" && firstToolWrapper == nil {
|
||||
firstToolWrapper = m
|
||||
if m["type"] == "tool_calls" {
|
||||
hasLegacyWrapper = true
|
||||
}
|
||||
}
|
||||
if !hasFunctionCall {
|
||||
t.Fatalf("expected at least one function_call item for responses compatibility, got %#v", responseObj["output"])
|
||||
t.Fatalf("expected function_call item, got %#v", responseObj["output"])
|
||||
}
|
||||
if firstToolWrapper == nil {
|
||||
t.Fatalf("expected a tool_calls wrapper item, got %#v", responseObj["output"])
|
||||
}
|
||||
toolCalls, _ := firstToolWrapper["tool_calls"].([]any)
|
||||
if len(toolCalls) == 0 {
|
||||
t.Fatalf("expected at least one tool_call in output, got %#v", firstToolWrapper["tool_calls"])
|
||||
}
|
||||
call0, _ := toolCalls[0].(map[string]any)
|
||||
if call0["type"] != "function" {
|
||||
t.Fatalf("unexpected tool call type: %#v", call0["type"])
|
||||
}
|
||||
fn, _ := call0["function"].(map[string]any)
|
||||
if fn["name"] != "read_file" {
|
||||
t.Fatalf("unexpected tool call name: %#v", fn["name"])
|
||||
if hasLegacyWrapper {
|
||||
t.Fatalf("did not expect legacy tool_calls wrapper, got %#v", responseObj["output"])
|
||||
}
|
||||
if strings.Contains(outputText, `"tool_calls"`) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in output_text: %q", outputText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamIncompleteTailNotDuplicatedInCompletedOutputText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
tail := `{"tool_calls":[{"name":"read_file","input":`
|
||||
streamBody := sseLine("Before ") + sseLine(tail) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
|
||||
|
||||
completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
|
||||
if !ok {
|
||||
t.Fatalf("expected response.completed event, body=%s", rec.Body.String())
|
||||
}
|
||||
responseObj, _ := completed["response"].(map[string]any)
|
||||
outputText, _ := responseObj["output_text"].(string)
|
||||
if strings.Count(outputText, tail) > 1 {
|
||||
t.Fatalf("expected incomplete tail not to be duplicated, got output_text=%q", outputText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamEmitsReasoningCompatEvents(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/thinking_content",
|
||||
"v": "thought",
|
||||
})
|
||||
streamBody := "data: " + string(b) + "\n" + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, nil)
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.reasoning.delta") {
|
||||
t.Fatalf("expected response.reasoning.delta event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.reasoning_text.delta") {
|
||||
t.Fatalf("expected response.reasoning_text.delta compatibility event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.reasoning_text.done") {
|
||||
t.Fatalf("expected response.reasoning_text.done compatibility event, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamEmitsFunctionCallCompatEvents(t *testing.T) {
|
||||
func TestHandleResponsesStreamUsesOfficialOutputItemEvents(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -163,24 +91,28 @@ func TestHandleResponsesStreamEmitsFunctionCallCompatEvents(t *testing.T) {
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.output_item.added") {
|
||||
t.Fatalf("expected response.output_item.added event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.output_item.done") {
|
||||
t.Fatalf("expected response.output_item.done event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.delta") {
|
||||
t.Fatalf("expected response.function_call_arguments.delta compatibility event, body=%s", body)
|
||||
t.Fatalf("expected response.function_call_arguments.delta event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("expected response.function_call_arguments.done compatibility event, body=%s", body)
|
||||
t.Fatalf("expected response.function_call_arguments.done event, body=%s", body)
|
||||
}
|
||||
if strings.Contains(body, "event: response.output_tool_call.delta") || strings.Contains(body, "event: response.output_tool_call.done") {
|
||||
t.Fatalf("legacy response.output_tool_call.* event must not appear, body=%s", body)
|
||||
}
|
||||
|
||||
donePayload, ok := extractSSEEventPayload(body, "response.function_call_arguments.done")
|
||||
if !ok {
|
||||
t.Fatalf("expected to parse response.function_call_arguments.done payload, body=%s", body)
|
||||
}
|
||||
if strings.TrimSpace(asString(donePayload["call_id"])) == "" {
|
||||
t.Fatalf("expected call_id in response.function_call_arguments.done payload, payload=%#v", donePayload)
|
||||
}
|
||||
if strings.TrimSpace(asString(donePayload["response_id"])) == "" {
|
||||
t.Fatalf("expected response_id in response.function_call_arguments.done payload, payload=%#v", donePayload)
|
||||
}
|
||||
doneCallID := strings.TrimSpace(asString(donePayload["call_id"]))
|
||||
if doneCallID == "" {
|
||||
t.Fatalf("expected non-empty call_id in done payload, payload=%#v", donePayload)
|
||||
@@ -191,9 +123,6 @@ func TestHandleResponsesStreamEmitsFunctionCallCompatEvents(t *testing.T) {
|
||||
}
|
||||
responseObj, _ := completed["response"].(map[string]any)
|
||||
output, _ := responseObj["output"].([]any)
|
||||
if len(output) == 0 {
|
||||
t.Fatalf("expected non-empty output in response.completed, response=%#v", responseObj)
|
||||
}
|
||||
var completedCallID string
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
@@ -213,36 +142,29 @@ func TestHandleResponsesStreamEmitsFunctionCallCompatEvents(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamDetectsToolCallsFromThinkingChannel(t *testing.T) {
|
||||
func TestHandleResponsesStreamDoesNotEmitReasoningTextCompatEvents(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(path, v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": path,
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n"
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/thinking_content",
|
||||
"v": "thought",
|
||||
})
|
||||
streamBody := "data: " + string(b) + "\n" + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"})
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, nil, util.DefaultToolChoicePolicy(), "")
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.reasoning_text.delta") {
|
||||
t.Fatalf("expected response.reasoning_text.delta event, body=%s", body)
|
||||
if !strings.Contains(body, "event: response.reasoning.delta") {
|
||||
t.Fatalf("expected response.reasoning.delta event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("expected response.function_call_arguments.done event from thinking channel, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.output_tool_call.done") {
|
||||
t.Fatalf("expected response.output_tool_call.done event from thinking channel, body=%s", body)
|
||||
if strings.Contains(body, "event: response.reasoning_text.delta") || strings.Contains(body, "event: response.reasoning_text.done") {
|
||||
t.Fatalf("did not expect response.reasoning_text.* compatibility events, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -267,121 +189,31 @@ func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"})
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"}, util.DefaultToolChoicePolicy(), "")
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.output_tool_call.done") {
|
||||
t.Fatalf("expected response.output_tool_call.done event, body=%s", body)
|
||||
}
|
||||
donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
|
||||
if len(donePayloads) != 2 {
|
||||
t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
|
||||
}
|
||||
|
||||
seenNames := map[string]string{}
|
||||
for _, payload := range donePayloads {
|
||||
name := strings.TrimSpace(asString(payload["name"]))
|
||||
callID := strings.TrimSpace(asString(payload["call_id"]))
|
||||
args := strings.TrimSpace(asString(payload["arguments"]))
|
||||
if callID == "" {
|
||||
t.Fatalf("expected non-empty call_id in done payload: %#v", payload)
|
||||
}
|
||||
if strings.Contains(args, `}{"`) {
|
||||
t.Fatalf("unexpected concatenated arguments in done payload: %#v", payload)
|
||||
}
|
||||
if name == "search_webeval_javascript" {
|
||||
t.Fatalf("unexpected merged tool name in done payload: %#v", payload)
|
||||
}
|
||||
if name != "search_web" && name != "eval_javascript" {
|
||||
t.Fatalf("unexpected tool name in done payload: %#v", payload)
|
||||
}
|
||||
if callID == "" {
|
||||
t.Fatalf("expected non-empty call_id in done payload: %#v", payload)
|
||||
}
|
||||
seenNames[name] = callID
|
||||
}
|
||||
if seenNames["search_web"] == "" || seenNames["eval_javascript"] == "" {
|
||||
t.Fatalf("expected done events for both tools, got %#v", seenNames)
|
||||
}
|
||||
if seenNames["search_web"] == seenNames["eval_javascript"] {
|
||||
t.Fatalf("expected distinct call_id per tool, got %#v", seenNames)
|
||||
}
|
||||
|
||||
completed, ok := extractSSEEventPayload(body, "response.completed")
|
||||
if !ok {
|
||||
t.Fatalf("expected response.completed event, body=%s", body)
|
||||
}
|
||||
responseObj, _ := completed["response"].(map[string]any)
|
||||
output, _ := responseObj["output"].([]any)
|
||||
functionCallIDs := map[string]string{}
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m == nil || m["type"] != "function_call" {
|
||||
continue
|
||||
}
|
||||
name := strings.TrimSpace(asString(m["name"]))
|
||||
callID := strings.TrimSpace(asString(m["call_id"]))
|
||||
if name != "" && callID != "" {
|
||||
functionCallIDs[name] = callID
|
||||
}
|
||||
}
|
||||
if functionCallIDs["search_web"] != seenNames["search_web"] {
|
||||
t.Fatalf("search_web call_id mismatch between done and completed: done=%q completed=%q", seenNames["search_web"], functionCallIDs["search_web"])
|
||||
}
|
||||
if functionCallIDs["eval_javascript"] != seenNames["eval_javascript"] {
|
||||
t.Fatalf("eval_javascript call_id mismatch between done and completed: done=%q completed=%q", seenNames["eval_javascript"], functionCallIDs["eval_javascript"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamMultiToolCallFromThinkingChannel(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(path, v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": path,
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) +
|
||||
sseLine("response/thinking_content", `{"name":"eval_javascript","input":{"code":"1+1"}}]}`) +
|
||||
"data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"search_web", "eval_javascript"})
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.reasoning_text.delta") {
|
||||
t.Fatalf("expected reasoning stream events, body=%s", body)
|
||||
}
|
||||
donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
|
||||
if len(donePayloads) != 2 {
|
||||
t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
|
||||
}
|
||||
seen := map[string]bool{}
|
||||
for _, payload := range donePayloads {
|
||||
name := strings.TrimSpace(asString(payload["name"]))
|
||||
if name == "search_webeval_javascript" {
|
||||
t.Fatalf("unexpected merged tool name in thinking channel done payload: %#v", payload)
|
||||
}
|
||||
if name != "search_web" && name != "eval_javascript" {
|
||||
t.Fatalf("unexpected tool name in thinking channel done payload: %#v", payload)
|
||||
}
|
||||
args := strings.TrimSpace(asString(payload["arguments"]))
|
||||
if strings.Contains(args, `}{"`) {
|
||||
t.Fatalf("unexpected concatenated arguments in thinking channel done payload: %#v", payload)
|
||||
}
|
||||
seen[name] = true
|
||||
}
|
||||
if !seen["search_web"] || !seen["eval_javascript"] {
|
||||
t.Fatalf("expected both tools in thinking channel done events, got %#v", seen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamCompletedFollowsChatToolCallSemantics(t *testing.T) {
|
||||
func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -394,32 +226,76 @@ func TestHandleResponsesStreamCompletedFollowsChatToolCallSemantics(t *testing.T
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine("我来调用工具\n") +
|
||||
sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) +
|
||||
"data: [DONE]\n"
|
||||
streamBody := sseLine("plain text only") + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
|
||||
policy := util.ToolChoicePolicy{
|
||||
Mode: util.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, policy, "")
|
||||
|
||||
completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
|
||||
if !ok {
|
||||
t.Fatalf("expected response.completed event, body=%s", rec.Body.String())
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.failed") {
|
||||
t.Fatalf("expected response.failed event for required tool_choice violation, body=%s", body)
|
||||
}
|
||||
responseObj, _ := completed["response"].(map[string]any)
|
||||
output, _ := responseObj["output"].([]any)
|
||||
hasFunctionCall := false
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m != nil && m["type"] == "function_call" {
|
||||
hasFunctionCall = true
|
||||
break
|
||||
}
|
||||
if strings.Contains(body, "event: response.completed") {
|
||||
t.Fatalf("did not expect response.completed after failure, body=%s", body)
|
||||
}
|
||||
if !hasFunctionCall {
|
||||
t.Fatalf("expected completed output to include function_call when mixed prose contains tool_calls payload, output=%#v", output)
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamRejectsUnknownToolName(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"not_in_schema","input":{"q":"go"}}]}`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("did not expect function_call events for unknown tool, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) {
|
||||
h := &Handler{}
|
||||
rec := httptest.NewRecorder()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`data: {"p":"response/content","v":"plain text only"}` + "\n" +
|
||||
`data: [DONE]` + "\n",
|
||||
)),
|
||||
}
|
||||
policy := util.ToolChoicePolicy{
|
||||
Mode: util.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, []string{"read_file"}, policy, "")
|
||||
if rec.Code != http.StatusUnprocessableEntity {
|
||||
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
errObj, _ := out["error"].(map[string]any)
|
||||
if asString(errObj["code"]) != "tool_choice_violation" {
|
||||
t.Fatalf("expected code=tool_choice_violation, got %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -23,7 +23,8 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID
|
||||
if responseModel == "" {
|
||||
responseModel = resolvedModel
|
||||
}
|
||||
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID)
|
||||
toolPolicy := util.DefaultToolChoicePolicy()
|
||||
finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy)
|
||||
passThrough := collectOpenAIChatPassThrough(req)
|
||||
|
||||
return util.StandardRequest{
|
||||
@@ -34,6 +35,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID
|
||||
Messages: messagesRaw,
|
||||
FinalPrompt: finalPrompt,
|
||||
ToolNames: toolNames,
|
||||
ToolChoice: toolPolicy,
|
||||
Stream: util.ToBool(req["stream"]),
|
||||
Thinking: thinkingEnabled,
|
||||
Search: searchEnabled,
|
||||
@@ -67,7 +69,17 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra
|
||||
if len(messagesRaw) == 0 {
|
||||
return util.StandardRequest{}, fmt.Errorf("Request must include 'input' or 'messages'.")
|
||||
}
|
||||
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID)
|
||||
toolPolicy, err := parseToolChoicePolicy(req["tool_choice"], req["tools"])
|
||||
if err != nil {
|
||||
return util.StandardRequest{}, err
|
||||
}
|
||||
finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy)
|
||||
if toolPolicy.IsNone() {
|
||||
toolNames = nil
|
||||
toolPolicy.Allowed = nil
|
||||
} else {
|
||||
toolPolicy.Allowed = namesToSet(toolNames)
|
||||
}
|
||||
passThrough := collectOpenAIChatPassThrough(req)
|
||||
|
||||
return util.StandardRequest{
|
||||
@@ -78,6 +90,7 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra
|
||||
Messages: messagesRaw,
|
||||
FinalPrompt: finalPrompt,
|
||||
ToolNames: toolNames,
|
||||
ToolChoice: toolPolicy,
|
||||
Stream: util.ToBool(req["stream"]),
|
||||
Thinking: thinkingEnabled,
|
||||
Search: searchEnabled,
|
||||
@@ -102,3 +115,212 @@ func collectOpenAIChatPassThrough(req map[string]any) map[string]any {
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseToolChoicePolicy(toolChoiceRaw any, toolsRaw any) (util.ToolChoicePolicy, error) {
|
||||
policy := util.DefaultToolChoicePolicy()
|
||||
declaredNames := extractDeclaredToolNames(toolsRaw)
|
||||
declaredSet := namesToSet(declaredNames)
|
||||
if len(declaredNames) > 0 {
|
||||
policy.Allowed = declaredSet
|
||||
}
|
||||
|
||||
if toolChoiceRaw == nil {
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
switch v := toolChoiceRaw.(type) {
|
||||
case string:
|
||||
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||
case "", "auto":
|
||||
policy.Mode = util.ToolChoiceAuto
|
||||
case "none":
|
||||
policy.Mode = util.ToolChoiceNone
|
||||
policy.Allowed = nil
|
||||
case "required":
|
||||
policy.Mode = util.ToolChoiceRequired
|
||||
default:
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("Unsupported tool_choice: %q", v)
|
||||
}
|
||||
case map[string]any:
|
||||
allowedOverride, hasAllowedOverride, err := parseAllowedToolNames(v["allowed_tools"])
|
||||
if err != nil {
|
||||
return util.ToolChoicePolicy{}, err
|
||||
}
|
||||
if hasAllowedOverride {
|
||||
filtered := make([]string, 0, len(allowedOverride))
|
||||
for _, name := range allowedOverride {
|
||||
if _, ok := declaredSet[name]; !ok {
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice.allowed_tools contains undeclared tool %q", name)
|
||||
}
|
||||
filtered = append(filtered, name)
|
||||
}
|
||||
policy.Allowed = namesToSet(filtered)
|
||||
}
|
||||
|
||||
typ := strings.ToLower(strings.TrimSpace(asString(v["type"])))
|
||||
switch typ {
|
||||
case "", "auto":
|
||||
if hasFunctionSelector(v) {
|
||||
name, err := parseForcedToolName(v)
|
||||
if err != nil {
|
||||
return util.ToolChoicePolicy{}, err
|
||||
}
|
||||
policy.Mode = util.ToolChoiceForced
|
||||
policy.ForcedName = name
|
||||
policy.Allowed = namesToSet([]string{name})
|
||||
} else {
|
||||
policy.Mode = util.ToolChoiceAuto
|
||||
}
|
||||
case "none":
|
||||
policy.Mode = util.ToolChoiceNone
|
||||
policy.Allowed = nil
|
||||
case "required":
|
||||
policy.Mode = util.ToolChoiceRequired
|
||||
case "function":
|
||||
name, err := parseForcedToolName(v)
|
||||
if err != nil {
|
||||
return util.ToolChoicePolicy{}, err
|
||||
}
|
||||
policy.Mode = util.ToolChoiceForced
|
||||
policy.ForcedName = name
|
||||
policy.Allowed = namesToSet([]string{name})
|
||||
default:
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("Unsupported tool_choice.type: %q", typ)
|
||||
}
|
||||
default:
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice must be a string or object")
|
||||
}
|
||||
|
||||
if policy.Mode == util.ToolChoiceRequired || policy.Mode == util.ToolChoiceForced {
|
||||
if len(declaredNames) == 0 {
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice=%s requires non-empty tools.", policy.Mode)
|
||||
}
|
||||
}
|
||||
if policy.Mode == util.ToolChoiceForced {
|
||||
if _, ok := declaredSet[policy.ForcedName]; !ok {
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice forced function %q is not declared in tools", policy.ForcedName)
|
||||
}
|
||||
}
|
||||
if len(policy.Allowed) == 0 && (policy.Mode == util.ToolChoiceRequired || policy.Mode == util.ToolChoiceForced) {
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice policy resolved to empty allowed tool set")
|
||||
}
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
func parseForcedToolName(v map[string]any) (string, error) {
|
||||
if name := strings.TrimSpace(asString(v["name"])); name != "" {
|
||||
return name, nil
|
||||
}
|
||||
if fn, ok := v["function"].(map[string]any); ok {
|
||||
if name := strings.TrimSpace(asString(fn["name"])); name != "" {
|
||||
return name, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("tool_choice function requires name")
|
||||
}
|
||||
|
||||
func parseAllowedToolNames(raw any) ([]string, bool, error) {
|
||||
if raw == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
collectName := func(v any) string {
|
||||
if name := strings.TrimSpace(asString(v)); name != "" {
|
||||
return name
|
||||
}
|
||||
if m, ok := v.(map[string]any); ok {
|
||||
if name := strings.TrimSpace(asString(m["name"])); name != "" {
|
||||
return name
|
||||
}
|
||||
if fn, ok := m["function"].(map[string]any); ok {
|
||||
if name := strings.TrimSpace(asString(fn["name"])); name != "" {
|
||||
return name
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
names := []string{}
|
||||
switch x := raw.(type) {
|
||||
case []any:
|
||||
for _, item := range x {
|
||||
name := collectName(item)
|
||||
if name == "" {
|
||||
return nil, true, fmt.Errorf("tool_choice.allowed_tools contains invalid item")
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
case []string:
|
||||
for _, item := range x {
|
||||
name := strings.TrimSpace(item)
|
||||
if name == "" {
|
||||
return nil, true, fmt.Errorf("tool_choice.allowed_tools contains empty name")
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
default:
|
||||
return nil, true, fmt.Errorf("tool_choice.allowed_tools must be an array")
|
||||
}
|
||||
|
||||
if len(names) == 0 {
|
||||
return nil, true, fmt.Errorf("tool_choice.allowed_tools must not be empty")
|
||||
}
|
||||
return names, true, nil
|
||||
}
|
||||
|
||||
func hasFunctionSelector(v map[string]any) bool {
|
||||
if strings.TrimSpace(asString(v["name"])) != "" {
|
||||
return true
|
||||
}
|
||||
if fn, ok := v["function"].(map[string]any); ok {
|
||||
return strings.TrimSpace(asString(fn["name"])) != ""
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func extractDeclaredToolNames(toolsRaw any) []string {
|
||||
tools, ok := toolsRaw.([]any)
|
||||
if !ok || len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(tools))
|
||||
seen := map[string]struct{}{}
|
||||
for _, t := range tools {
|
||||
tool, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
fn, _ := tool["function"].(map[string]any)
|
||||
if len(fn) == 0 {
|
||||
fn = tool
|
||||
}
|
||||
name := strings.TrimSpace(asString(fn["name"]))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[name]; ok {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
out = append(out, name)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func namesToSet(names []string) map[string]struct{} {
|
||||
if len(names) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]struct{}, len(names))
|
||||
for _, name := range names {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
out[trimmed] = struct{}{}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func newEmptyStoreForNormalizeTest(t *testing.T) *config.Store {
|
||||
@@ -58,3 +59,95 @@ func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) {
|
||||
t.Fatalf("expected 2 normalized messages, got %d", len(n.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesRequestToolChoiceRequired(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": "ping",
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "search",
|
||||
"parameters": map[string]any{
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"tool_choice": "required",
|
||||
}
|
||||
n, err := normalizeOpenAIResponsesRequest(store, req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
if n.ToolChoice.Mode != util.ToolChoiceRequired {
|
||||
t.Fatalf("expected tool choice mode required, got %q", n.ToolChoice.Mode)
|
||||
}
|
||||
if len(n.ToolNames) != 1 || n.ToolNames[0] != "search" {
|
||||
t.Fatalf("unexpected tool names: %#v", n.ToolNames)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesRequestToolChoiceForcedFunction(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": "ping",
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "search",
|
||||
},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "read_file",
|
||||
},
|
||||
},
|
||||
},
|
||||
"tool_choice": map[string]any{
|
||||
"type": "function",
|
||||
"name": "read_file",
|
||||
},
|
||||
}
|
||||
n, err := normalizeOpenAIResponsesRequest(store, req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
if n.ToolChoice.Mode != util.ToolChoiceForced {
|
||||
t.Fatalf("expected tool choice mode forced, got %q", n.ToolChoice.Mode)
|
||||
}
|
||||
if n.ToolChoice.ForcedName != "read_file" {
|
||||
t.Fatalf("expected forced tool name read_file, got %q", n.ToolChoice.ForcedName)
|
||||
}
|
||||
if len(n.ToolNames) != 1 || n.ToolNames[0] != "read_file" {
|
||||
t.Fatalf("expected filtered tool names [read_file], got %#v", n.ToolNames)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesRequestToolChoiceForcedUndeclaredFails(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": "ping",
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "search",
|
||||
},
|
||||
},
|
||||
},
|
||||
"tool_choice": map[string]any{
|
||||
"type": "function",
|
||||
"name": "read_file",
|
||||
},
|
||||
}
|
||||
if _, err := normalizeOpenAIResponsesRequest(store, req, ""); err == nil {
|
||||
t.Fatalf("expected forced undeclared tool to fail")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,12 +27,7 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex
|
||||
"text": finalThinking,
|
||||
})
|
||||
}
|
||||
formatted := util.FormatOpenAIToolCalls(detected)
|
||||
output = append(output, toResponsesFunctionCallItems(formatted)...)
|
||||
output = append(output, map[string]any{
|
||||
"type": "tool_calls",
|
||||
"tool_calls": formatted,
|
||||
})
|
||||
output = append(output, toResponsesFunctionCallItems(detected)...)
|
||||
} else {
|
||||
content := make([]any, 0, 2)
|
||||
if finalThinking != "" {
|
||||
@@ -70,32 +65,23 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex
|
||||
}
|
||||
}
|
||||
|
||||
func toResponsesFunctionCallItems(toolCalls []map[string]any) []any {
|
||||
func toResponsesFunctionCallItems(toolCalls []util.ParsedToolCall) []any {
|
||||
if len(toolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]any, 0, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
callID, _ := tc["id"].(string)
|
||||
if strings.TrimSpace(callID) == "" {
|
||||
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
}
|
||||
name := ""
|
||||
args := "{}"
|
||||
if fn, ok := tc["function"].(map[string]any); ok {
|
||||
if n, _ := fn["name"].(string); strings.TrimSpace(n) != "" {
|
||||
name = n
|
||||
}
|
||||
if a, _ := fn["arguments"].(string); strings.TrimSpace(a) != "" {
|
||||
args = a
|
||||
}
|
||||
if strings.TrimSpace(tc.Name) == "" {
|
||||
continue
|
||||
}
|
||||
argsBytes, _ := json.Marshal(tc.Input)
|
||||
args := normalizeJSONString(string(argsBytes))
|
||||
out = append(out, map[string]any{
|
||||
"id": "fc_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
|
||||
"type": "function_call",
|
||||
"call_id": callID,
|
||||
"name": name,
|
||||
"arguments": normalizeJSONString(args),
|
||||
"call_id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
|
||||
"name": tc.Name,
|
||||
"arguments": args,
|
||||
"status": "completed",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package openai
|
||||
|
||||
import "strings"
|
||||
|
||||
func BuildResponsesCreatedPayload(responseID, model string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.created",
|
||||
@@ -11,6 +13,52 @@ func BuildResponsesCreatedPayload(responseID, model string) map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesOutputItemAddedPayload(responseID, itemID string, outputIndex int, item map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_item.added",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"output_index": outputIndex,
|
||||
"item_id": itemID,
|
||||
"item": item,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesOutputItemDonePayload(responseID, itemID string, outputIndex int, item map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_item.done",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"output_index": outputIndex,
|
||||
"item_id": itemID,
|
||||
"item": item,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesContentPartAddedPayload(responseID, itemID string, outputIndex, contentIndex int, part map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.content_part.added",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"item_id": itemID,
|
||||
"output_index": outputIndex,
|
||||
"content_index": contentIndex,
|
||||
"part": part,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesContentPartDonePayload(responseID, itemID string, outputIndex, contentIndex int, part map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.content_part.done",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"item_id": itemID,
|
||||
"output_index": outputIndex,
|
||||
"content_index": contentIndex,
|
||||
"part": part,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_text.delta",
|
||||
@@ -29,48 +77,6 @@ func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]an
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesReasoningTextDeltaPayload(responseID, itemID string, outputIndex, contentIndex int, delta string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.reasoning_text.delta",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"item_id": itemID,
|
||||
"output_index": outputIndex,
|
||||
"content_index": contentIndex,
|
||||
"delta": delta,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesReasoningTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.reasoning_text.done",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"item_id": itemID,
|
||||
"output_index": outputIndex,
|
||||
"content_index": contentIndex,
|
||||
"text": text,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_tool_call.delta",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"tool_calls": toolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_tool_call.done",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"tool_calls": toolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesFunctionCallArgumentsDeltaPayload(responseID, itemID string, outputIndex int, callID, delta string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.function_call_arguments.delta",
|
||||
@@ -96,6 +102,27 @@ func BuildResponsesFunctionCallArgumentsDonePayload(responseID, itemID string, o
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesFailedPayload(responseID, model, message, code string) map[string]any {
|
||||
code = strings.TrimSpace(code)
|
||||
if code == "" {
|
||||
code = "api_error"
|
||||
}
|
||||
return map[string]any{
|
||||
"type": "response.failed",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"object": "response",
|
||||
"model": model,
|
||||
"status": "failed",
|
||||
"error": map[string]any{
|
||||
"message": message,
|
||||
"type": "invalid_request_error",
|
||||
"code": code,
|
||||
"param": nil,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesCompletedPayload(response map[string]any) map[string]any {
|
||||
responseID, _ := response["id"].(string)
|
||||
return map[string]any{
|
||||
|
||||
@@ -21,8 +21,8 @@ func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
|
||||
}
|
||||
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) != 2 {
|
||||
t.Fatalf("expected function_call + tool_calls wrapper, got %#v", obj["output"])
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected function_call output only, got %#v", obj["output"])
|
||||
}
|
||||
|
||||
first, _ := output[0].(map[string]any)
|
||||
@@ -32,35 +32,10 @@ func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
|
||||
if first["call_id"] == "" {
|
||||
t.Fatalf("expected function_call item to have call_id, got %#v", first)
|
||||
}
|
||||
second, _ := output[1].(map[string]any)
|
||||
if second["type"] != "tool_calls" {
|
||||
t.Fatalf("expected second output item type tool_calls, got %#v", second["type"])
|
||||
if first["name"] != "search" {
|
||||
t.Fatalf("unexpected function name: %#v", first["name"])
|
||||
}
|
||||
var toolCalls []map[string]any
|
||||
switch v := second["tool_calls"].(type) {
|
||||
case []map[string]any:
|
||||
toolCalls = v
|
||||
case []any:
|
||||
toolCalls = make([]map[string]any, 0, len(v))
|
||||
for _, item := range v {
|
||||
m, _ := item.(map[string]any)
|
||||
if m != nil {
|
||||
toolCalls = append(toolCalls, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(toolCalls) != 1 {
|
||||
t.Fatalf("expected one tool call, got %#v", second["tool_calls"])
|
||||
}
|
||||
tc := toolCalls[0]
|
||||
if tc["type"] != "function" || tc["id"] == "" {
|
||||
t.Fatalf("unexpected tool call shape: %#v", tc)
|
||||
}
|
||||
fn, _ := tc["function"].(map[string]any)
|
||||
if fn["name"] != "search" {
|
||||
t.Fatalf("unexpected function name: %#v", fn["name"])
|
||||
}
|
||||
argsRaw, _ := fn["arguments"].(string)
|
||||
argsRaw, _ := first["arguments"].(string)
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(argsRaw), &args); err != nil {
|
||||
t.Fatalf("arguments should be valid json string, got=%q err=%v", argsRaw, err)
|
||||
@@ -86,8 +61,8 @@ func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T)
|
||||
}
|
||||
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) != 2 {
|
||||
t.Fatalf("expected function_call + tool_calls wrapper, got %#v", obj["output"])
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected function_call output only, got %#v", obj["output"])
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "function_call" {
|
||||
@@ -163,8 +138,8 @@ func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
|
||||
)
|
||||
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) != 3 {
|
||||
t.Fatalf("expected reasoning + function_call + tool_calls outputs, got %#v", obj["output"])
|
||||
if len(output) != 2 {
|
||||
t.Fatalf("expected reasoning + function_call outputs, got %#v", obj["output"])
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "reasoning" {
|
||||
@@ -174,8 +149,4 @@ func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
|
||||
if second["type"] != "function_call" {
|
||||
t.Fatalf("expected second output function_call, got %#v", second["type"])
|
||||
}
|
||||
third, _ := output[2].(map[string]any)
|
||||
if third["type"] != "tool_calls" {
|
||||
t.Fatalf("expected third output tool_calls, got %#v", third["type"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,12 +8,48 @@ type StandardRequest struct {
|
||||
Messages []any
|
||||
FinalPrompt string
|
||||
ToolNames []string
|
||||
ToolChoice ToolChoicePolicy
|
||||
Stream bool
|
||||
Thinking bool
|
||||
Search bool
|
||||
PassThrough map[string]any
|
||||
}
|
||||
|
||||
type ToolChoiceMode string
|
||||
|
||||
const (
|
||||
ToolChoiceAuto ToolChoiceMode = "auto"
|
||||
ToolChoiceNone ToolChoiceMode = "none"
|
||||
ToolChoiceRequired ToolChoiceMode = "required"
|
||||
ToolChoiceForced ToolChoiceMode = "forced"
|
||||
)
|
||||
|
||||
type ToolChoicePolicy struct {
|
||||
Mode ToolChoiceMode
|
||||
ForcedName string
|
||||
Allowed map[string]struct{}
|
||||
}
|
||||
|
||||
func DefaultToolChoicePolicy() ToolChoicePolicy {
|
||||
return ToolChoicePolicy{Mode: ToolChoiceAuto}
|
||||
}
|
||||
|
||||
func (p ToolChoicePolicy) IsNone() bool {
|
||||
return p.Mode == ToolChoiceNone
|
||||
}
|
||||
|
||||
func (p ToolChoicePolicy) IsRequired() bool {
|
||||
return p.Mode == ToolChoiceRequired || p.Mode == ToolChoiceForced
|
||||
}
|
||||
|
||||
func (p ToolChoicePolicy) Allows(name string) bool {
|
||||
if len(p.Allowed) == 0 {
|
||||
return true
|
||||
}
|
||||
_, ok := p.Allowed[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r StandardRequest) CompletionPayload(sessionID string) map[string]any {
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
|
||||
@@ -10,38 +10,62 @@ type ParsedToolCall struct {
|
||||
Input map[string]any `json:"input"`
|
||||
}
|
||||
|
||||
type ToolCallParseResult struct {
|
||||
Calls []ParsedToolCall
|
||||
SawToolCallSyntax bool
|
||||
RejectedByPolicy bool
|
||||
RejectedToolNames []string
|
||||
}
|
||||
|
||||
func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
|
||||
return ParseToolCallsDetailed(text, availableToolNames).Calls
|
||||
}
|
||||
|
||||
func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallParseResult {
|
||||
result := ToolCallParseResult{}
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return nil
|
||||
return result
|
||||
}
|
||||
text = stripFencedCodeBlocks(text)
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return nil
|
||||
return result
|
||||
}
|
||||
result.SawToolCallSyntax = strings.Contains(strings.ToLower(text), "tool_calls")
|
||||
|
||||
candidates := buildToolCallCandidates(text)
|
||||
var parsed []ParsedToolCall
|
||||
for _, candidate := range candidates {
|
||||
if tc := parseToolCallsPayload(candidate); len(tc) > 0 {
|
||||
parsed = tc
|
||||
result.SawToolCallSyntax = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(parsed) == 0 {
|
||||
return nil
|
||||
return result
|
||||
}
|
||||
|
||||
return filterToolCalls(parsed, availableToolNames)
|
||||
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
|
||||
result.Calls = calls
|
||||
result.RejectedToolNames = rejectedNames
|
||||
result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
|
||||
return result
|
||||
}
|
||||
|
||||
func ParseStandaloneToolCalls(text string, availableToolNames []string) []ParsedToolCall {
|
||||
return ParseStandaloneToolCallsDetailed(text, availableToolNames).Calls
|
||||
}
|
||||
|
||||
func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string) ToolCallParseResult {
|
||||
result := ToolCallParseResult{}
|
||||
trimmed := strings.TrimSpace(text)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
return result
|
||||
}
|
||||
if looksLikeToolExampleContext(trimmed) {
|
||||
return nil
|
||||
return result
|
||||
}
|
||||
result.SawToolCallSyntax = strings.Contains(strings.ToLower(trimmed), "tool_calls")
|
||||
candidates := []string{trimmed}
|
||||
for _, candidate := range candidates {
|
||||
candidate = strings.TrimSpace(candidate)
|
||||
@@ -52,24 +76,31 @@ func ParseStandaloneToolCalls(text string, availableToolNames []string) []Parsed
|
||||
continue
|
||||
}
|
||||
if parsed := parseToolCallsPayload(candidate); len(parsed) > 0 {
|
||||
return filterToolCalls(parsed, availableToolNames)
|
||||
result.SawToolCallSyntax = true
|
||||
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
|
||||
result.Calls = calls
|
||||
result.RejectedToolNames = rejectedNames
|
||||
result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
|
||||
return result
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return result
|
||||
}
|
||||
|
||||
func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []ParsedToolCall {
|
||||
func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []string) ([]ParsedToolCall, []string) {
|
||||
allowed := map[string]struct{}{}
|
||||
for _, name := range availableToolNames {
|
||||
allowed[name] = struct{}{}
|
||||
}
|
||||
out := make([]ParsedToolCall, 0, len(parsed))
|
||||
rejectedSet := map[string]struct{}{}
|
||||
for _, tc := range parsed {
|
||||
if tc.Name == "" {
|
||||
continue
|
||||
}
|
||||
if len(allowed) > 0 {
|
||||
if _, ok := allowed[tc.Name]; !ok {
|
||||
rejectedSet[tc.Name] = struct{}{}
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -78,21 +109,11 @@ func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []Par
|
||||
}
|
||||
out = append(out, tc)
|
||||
}
|
||||
// If the model clearly emitted tool_calls JSON but all names are outside the
|
||||
// declared set, keep the parsed calls as a fallback so upper layers can still
|
||||
// intercept structured tool output instead of leaking raw JSON to users.
|
||||
if len(out) == 0 && len(parsed) > 0 {
|
||||
for _, tc := range parsed {
|
||||
if tc.Name == "" {
|
||||
continue
|
||||
}
|
||||
if tc.Input == nil {
|
||||
tc.Input = map[string]any{}
|
||||
}
|
||||
out = append(out, tc)
|
||||
}
|
||||
rejected := make([]string, 0, len(rejectedSet))
|
||||
for name := range rejectedSet {
|
||||
rejected = append(rejected, name)
|
||||
}
|
||||
return out
|
||||
return out, rejected
|
||||
}
|
||||
|
||||
func parseToolCallsPayload(payload string) []ParsedToolCall {
|
||||
|
||||
@@ -38,14 +38,25 @@ func TestParseToolCallsWithFunctionArgumentsString(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsKeepsUnknownAsFallback(t *testing.T) {
|
||||
func TestParseToolCallsRejectsUnknownToolName(t *testing.T) {
|
||||
text := `{"tool_calls":[{"name":"unknown","input":{}}]}`
|
||||
calls := ParseToolCalls(text, []string{"search"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected fallback 1 call, got %d", len(calls))
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected unknown tool to be rejected, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "unknown" {
|
||||
t.Fatalf("unexpected name: %s", calls[0].Name)
|
||||
}
|
||||
|
||||
func TestParseToolCallsDetailedMarksPolicyRejection(t *testing.T) {
|
||||
text := `{"tool_calls":[{"name":"unknown","input":{}}]}`
|
||||
res := ParseToolCallsDetailed(text, []string{"search"})
|
||||
if !res.SawToolCallSyntax {
|
||||
t.Fatalf("expected SawToolCallSyntax=true, got %#v", res)
|
||||
}
|
||||
if !res.RejectedByPolicy {
|
||||
t.Fatalf("expected RejectedByPolicy=true, got %#v", res)
|
||||
}
|
||||
if len(res.Calls) != 0 {
|
||||
t.Fatalf("expected no calls after policy rejection, got %#v", res.Calls)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
{
|
||||
"calls": [
|
||||
{"name": "unknown_tool", "input": {"x": 1}}
|
||||
]
|
||||
"calls": []
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user