Mirror of https://github.com/CJackHwang/ds2api.git · synced 2026-05-09 18:57:43 +08:00
.github/PULL_REQUEST_TEMPLATE.md (vendored) · 6 lines changed

@@ -13,12 +13,8 @@

#### 🔀 变更说明 | Description of Change

<!-- Thank you for your Pull Request. Please provide a description above. -->

#### 📝 补充信息 | Additional Information

<!-- Add any other context about the Pull Request here. -->

---

> 💡 **Tip**: If files under `webui/` were modified, CI will automatically build and commit the artifacts after the PR is merged; no manual build is needed.
.github/workflows/release-artifacts.yml (vendored) · 10 lines changed

@@ -12,6 +12,9 @@ permissions:
 jobs:
   build-and-upload:
     runs-on: ubuntu-latest
+    env:
+      DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
+      DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
     steps:
       - name: Checkout
         uses: actions/checkout@v4

@@ -87,10 +90,11 @@ jobs:
           password: ${{ secrets.GITHUB_TOKEN }}

       - name: Log in to Docker Hub
+        if: "${{ env.DOCKERHUB_USERNAME != '' }}"
         uses: docker/login-action@v3
         with:
-          username: ${{ secrets.DOCKERHUB_USERNAME }}
-          password: ${{ secrets.DOCKERHUB_TOKEN }}
+          username: ${{ env.DOCKERHUB_USERNAME }}
+          password: ${{ env.DOCKERHUB_TOKEN }}

       - name: Extract Docker metadata
         id: meta_release

@@ -98,7 +102,7 @@ jobs:
         with:
           images: |
             ghcr.io/${{ github.repository }}
-            cjackhwang/ds2api
+            ${{ env.DOCKERHUB_USERNAME || 'cjackhwang' }}/ds2api
           tags: |
             type=raw,value=${{ github.event.release.tag_name }}
             type=raw,value=latest
@@ -86,7 +86,7 @@ Manually build WebUI to `static/admin/`:
|
||||
go test ./...
|
||||
|
||||
# End-to-end live tests (real accounts)
|
||||
./scripts/testsuite/run-live.sh
|
||||
./tests/scripts/run-live.sh
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
@@ -86,7 +86,7 @@ docker-compose -f docker-compose.dev.yml up
 go test ./...

 # End-to-end full-chain tests (real accounts)
-./scripts/testsuite/run-live.sh
+./tests/scripts/run-live.sh
 ```

 ## Project Structure
@@ -518,7 +518,7 @@ curl http://127.0.0.1:5001/v1/chat/completions \
 Run the full live testsuite before release (real account tests):

 ```bash
-./scripts/testsuite/run-live.sh
+./tests/scripts/run-live.sh
 ```

 With custom flags:
@@ -518,7 +518,7 @@ curl http://127.0.0.1:5001/v1/chat/completions \
 Run the full end-to-end test suite before release (uses real accounts):

 ```bash
-./scripts/testsuite/run-live.sh
+./tests/scripts/run-live.sh
 ```

 Custom flags are supported:
38
README.MD
38
README.MD
@@ -283,6 +283,9 @@ cp opencode.json.example opencode.json
|
||||
| `DS2API_ACCOUNT_QUEUE_SIZE` | 同上(兼容旧名) | — |
|
||||
| `DS2API_VERCEL_INTERNAL_SECRET` | Vercel 混合流式内部鉴权密钥 | 回退用 `DS2API_ADMIN_KEY` |
|
||||
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | 流式 lease 过期秒数 | `900` |
|
||||
| `DS2API_DEV_PACKET_CAPTURE` | 本地开发抓包开关(记录最近会话请求/响应体) | 本地非 Vercel 默认开启 |
|
||||
| `DS2API_DEV_PACKET_CAPTURE_LIMIT` | 本地抓包保留条数(超出自动淘汰) | `5` |
|
||||
| `DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES` | 单条响应体最大记录字节数 | `2097152` |
|
||||
| `VERCEL_TOKEN` | Vercel 同步 token | — |
|
||||
| `VERCEL_PROJECT_ID` | Vercel 项目 ID | — |
|
||||
| `VERCEL_TEAM_ID` | Vercel 团队 ID | — |
|
||||
@@ -321,6 +324,29 @@ cp opencode.json.example opencode.json
|
||||
3. 已确认的 toolcall JSON 片段不会泄漏到 `delta.content`
|
||||
4. 前文/后文自然语言保持顺序透传,支持混合文本与增量参数输出
|
||||
|
||||
## 本地开发抓包工具
|
||||
|
||||
用于定位「responses 思考流/工具调用」等问题。开启后会自动记录最近 N 条 DeepSeek 对话上游请求体与响应体(默认 5 条,超出自动淘汰)。
|
||||
|
||||
启用示例:
|
||||
|
||||
```bash
|
||||
DS2API_DEV_PACKET_CAPTURE=true \
|
||||
DS2API_DEV_PACKET_CAPTURE_LIMIT=5 \
|
||||
go run ./cmd/ds2api
|
||||
```
|
||||
|
||||
查询/清空(需 Admin JWT):
|
||||
|
||||
- `GET /admin/dev/captures`:查看抓包列表(最新在前)
|
||||
- `DELETE /admin/dev/captures`:清空抓包
|
||||
|
||||
返回字段包含:
|
||||
|
||||
- `request_body`:发送给 DeepSeek 的完整请求体
|
||||
- `response_body`:上游返回的原始流式内容拼接文本
|
||||
- `response_truncated`:是否触发单条大小截断
|
||||
|
||||
## 项目结构
|
||||
|
||||
```text
|
||||
@@ -350,8 +376,10 @@ ds2api/
|
||||
│ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage
|
||||
│ └── locales/ # 中英文语言包(zh.json / en.json)
|
||||
├── scripts/
|
||||
│ ├── build-webui.sh # WebUI 手动构建脚本
|
||||
│ └── testsuite/ # 测试集运行脚本
|
||||
│ └── build-webui.sh # WebUI 手动构建脚本
|
||||
├── tests/
|
||||
│ ├── compat/ # 兼容性测试夹具与期望输出
|
||||
│ └── scripts/ # 统一测试脚本入口(unit/e2e)
|
||||
├── static/admin/ # WebUI 构建产物(不提交到 Git)
|
||||
├── .github/
|
||||
│ ├── workflows/ # GitHub Actions(Release 自动构建)
|
||||
@@ -379,11 +407,11 @@ ds2api/
|
||||
## 测试
|
||||
|
||||
```bash
|
||||
# 单元测试
|
||||
go test ./...
|
||||
# 单元测试(Go + Node)
|
||||
./tests/scripts/run-unit-all.sh
|
||||
|
||||
# 一键端到端全链路测试(真实账号,生成完整请求/响应日志)
|
||||
./scripts/testsuite/run-live.sh
|
||||
./tests/scripts/run-live.sh
|
||||
|
||||
# 或自定义参数
|
||||
go run ./cmd/ds2api-tests \
|
||||
|
||||
README.en.md · 38 lines changed

@@ -283,6 +283,9 @@ cp opencode.json.example opencode.json
 | `DS2API_ACCOUNT_QUEUE_SIZE` | Alias (legacy compat) | — |
 | `DS2API_VERCEL_INTERNAL_SECRET` | Vercel hybrid streaming internal auth | Falls back to `DS2API_ADMIN_KEY` |
 | `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | Stream lease TTL seconds | `900` |
+| `DS2API_DEV_PACKET_CAPTURE` | Local dev packet capture switch (record recent request/response bodies) | Enabled by default on non-Vercel local runtime |
+| `DS2API_DEV_PACKET_CAPTURE_LIMIT` | Number of captured sessions to retain (auto-evict overflow) | `5` |
+| `DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES` | Max recorded bytes per captured response body | `2097152` |
 | `VERCEL_TOKEN` | Vercel sync token | — |
 | `VERCEL_PROJECT_ID` | Vercel project ID | — |
 | `VERCEL_TEAM_ID` | Vercel team ID | — |

@@ -321,6 +324,29 @@ When `tools` is present in the request, DS2API performs anti-leak handling:
 3. Confirmed toolcall JSON fragments are never leaked into `delta.content`
 4. Natural language before/after toolcalls keeps original order, with incremental argument output supported
+
+## Local Dev Packet Capture
+
+This is for debugging issues such as Responses reasoning streaming and tool-call handoff. When enabled, DS2API stores the latest N DeepSeek conversation payload pairs (request body + upstream response body), defaulting to 5 entries with auto-eviction.
+
+Enable example:
+
+```bash
+DS2API_DEV_PACKET_CAPTURE=true \
+DS2API_DEV_PACKET_CAPTURE_LIMIT=5 \
+go run ./cmd/ds2api
+```
+
+Inspect/clear (Admin JWT required):
+
+- `GET /admin/dev/captures`: list captured items (newest first)
+- `DELETE /admin/dev/captures`: clear captured items
+
+Response fields include:
+
+- `request_body`: full payload sent to DeepSeek
+- `response_body`: concatenated raw upstream stream body text
+- `response_truncated`: whether body-size truncation happened
 ## Project Structure

 ```text
@@ -350,8 +376,10 @@ ds2api/
 │   ├── components/      # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage
 │   └── locales/         # Language packs (zh.json / en.json)
 ├── scripts/
-│   ├── build-webui.sh   # Manual WebUI build script
-│   └── testsuite/       # Testsuite runner scripts
+│   └── build-webui.sh   # Manual WebUI build script
+├── tests/
+│   ├── compat/          # Compatibility fixtures and expected outputs
+│   └── scripts/         # Unified test script entrypoints (unit/e2e)
 ├── static/admin/        # WebUI build output (not committed to Git)
 ├── .github/
 │   ├── workflows/       # GitHub Actions (Release artifact automation)

@@ -379,11 +407,11 @@ ds2api/
 ## Testing

 ```bash
-# Unit tests
-go test ./...
+# Unit tests (Go + Node)
+./tests/scripts/run-unit-all.sh

 # One-command live end-to-end tests (real accounts, full request/response logs)
-./scripts/testsuite/run-live.sh
+./tests/scripts/run-live.sh

 # Or with custom flags
 go run ./cmd/ds2api-tests \
TESTING.md · 16 lines changed

@@ -8,8 +8,10 @@ DS2API provides two tiers of tests:

 | Tier | Command | Notes |
 | --- | --- | --- |
-| Unit tests | `go test ./...` | No real accounts needed |
-| End-to-end tests | `./scripts/testsuite/run-live.sh` | Full-chain tests using real accounts |
+| Unit tests (Go) | `./tests/scripts/run-unit-go.sh` | No real accounts needed |
+| Unit tests (Node) | `./tests/scripts/run-unit-node.sh` | No real accounts needed |
+| Unit tests (all) | `./tests/scripts/run-unit-all.sh` | No real accounts needed |
+| End-to-end tests | `./tests/scripts/run-live.sh` | Full-chain tests using real accounts |

 The end-to-end suite records complete request/response logs for troubleshooting.

@@ -20,17 +22,19 @@ DS2API provides two tiers of tests:

 ### 单元测试 | Unit Tests

 ```bash
-go test ./...
+./tests/scripts/run-unit-all.sh
 ```

 ```bash
 node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js
+# Or run split by language
+./tests/scripts/run-unit-go.sh
+./tests/scripts/run-unit-node.sh
 ```

 ### 端到端测试 | End-to-End Tests

 ```bash
-./scripts/testsuite/run-live.sh
+./tests/scripts/run-live.sh
 ```

 **Default behavior**:

@@ -179,7 +183,7 @@ go run ./cmd/ds2api-tests \

 ```bash
 # Ensure config.json exists and contains valid test accounts
-./scripts/testsuite/run-live.sh
+./tests/scripts/run-live.sh
 exit_code=$?
 if [ $exit_code -ne 0 ]; then
 	echo "Tests failed! Check artifacts for details."
@@ -38,6 +38,10 @@ func RegisterRoutes(r chi.Router, h *Handler) {
 	r.Get("/anthropic/v1/models", h.ListModels)
 	r.Post("/anthropic/v1/messages", h.Messages)
 	r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens)
+	r.Post("/v1/messages", h.Messages)
+	r.Post("/messages", h.Messages)
+	r.Post("/v1/messages/count_tokens", h.CountTokens)
+	r.Post("/messages/count_tokens", h.CountTokens)
 }

 func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
@@ -167,7 +171,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
 	w.Header().Set("Connection", "keep-alive")
 	w.Header().Set("X-Accel-Buffering", "no")
-	rc := http.NewResponseController(w)
-	canFlush := rc.Flush() == nil
+	_, canFlush := w.(http.Flusher)
 	if !canFlush {
 		config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered")
 	}
@@ -250,7 +254,7 @@ func normalizeClaudeMessages(messages []any) []any {
 			}
 		}
 		if typeStr == "tool_result" {
-			parts = append(parts, fmt.Sprintf("%v", b["content"]))
+			parts = append(parts, formatClaudeToolResultForPrompt(b))
 		}
 	}
 	copied["content"] = strings.Join(parts, "\n")
@@ -272,10 +276,36 @@ func buildClaudeToolPrompt(tools []any) string {
 		schema, _ := json.Marshal(m["input_schema"])
 		parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
 	}
-	parts = append(parts, "When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}")
+	parts = append(parts,
+		"When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}",
+		"History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.",
+		"After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.",
+	)
 	return strings.Join(parts, "\n\n")
 }

+func formatClaudeToolResultForPrompt(block map[string]any) string {
+	if block == nil {
+		return ""
+	}
+	toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
+	if toolCallID == "" {
+		toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"]))
+	}
+	if toolCallID == "" {
+		toolCallID = "unknown"
+	}
+	name := strings.TrimSpace(fmt.Sprintf("%v", block["name"]))
+	if name == "" {
+		name = "unknown"
+	}
+	content := strings.TrimSpace(fmt.Sprintf("%v", block["content"]))
+	if content == "" {
+		content = "null"
+	}
+	return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content)
+}
+
 func hasSystemMessage(messages []any) bool {
 	for _, m := range messages {
 		msg, ok := m.(map[string]any)
@@ -1,6 +1,7 @@
 package claude

 import (
+	"strings"
 	"testing"
 )

@@ -48,8 +49,9 @@ func TestNormalizeClaudeMessagesToolResult(t *testing.T) {
 	}
 	got := normalizeClaudeMessages(msgs)
 	m := got[0].(map[string]any)
-	if m["content"] != "tool output" {
-		t.Fatalf("expected 'tool output', got %q", m["content"])
+	content, _ := m["content"].(string)
+	if !strings.Contains(content, "[TOOL_RESULT_HISTORY]") || !strings.Contains(content, "content: tool output") {
+		t.Fatalf("expected serialized tool result marker, got %q", content)
 	}
 }
internal/adapter/claude/route_alias_test.go · new file (44 lines)

package claude

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/go-chi/chi/v5"

	"ds2api/internal/auth"
)

type routeAliasAuthStub struct{}

func (routeAliasAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) {
	return nil, auth.ErrUnauthorized
}

func (routeAliasAuthStub) Release(_ *auth.RequestAuth) {}

func TestClaudeRouteAliasesDoNot404(t *testing.T) {
	h := &Handler{
		Auth: routeAliasAuthStub{},
	}
	r := chi.NewRouter()
	RegisterRoutes(r, h)

	paths := []string{
		"/anthropic/v1/messages",
		"/v1/messages",
		"/messages",
		"/anthropic/v1/messages/count_tokens",
		"/v1/messages/count_tokens",
		"/messages/count_tokens",
	}
	for _, path := range paths {
		req := httptest.NewRequest(http.MethodPost, path, nil)
		rec := httptest.NewRecorder()
		r.ServeHTTP(rec, req)
		if rec.Code == http.StatusNotFound {
			t.Fatalf("expected route %s to be registered, got 404", path)
		}
	}
}
@@ -27,9 +27,7 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma
 	payload := cloneMap(req)
 	payload["messages"] = normalizedMessages
 	toolsRequested, _ := req["tools"].([]any)
-	if len(toolsRequested) > 0 && !hasSystemMessage(normalizedMessages) {
-		payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalizedMessages...)
-	}
+	payload["messages"] = injectClaudeToolPrompt(payload, normalizedMessages, toolsRequested)

 	dsPayload := convertClaudeToDeepSeek(payload, store)
 	dsModel, _ := dsPayload["model"].(string)

@@ -57,3 +55,59 @@
 		NormalizedMessages: normalizedMessages,
 	}, nil
 }
+
+func injectClaudeToolPrompt(payload map[string]any, normalizedMessages []any, tools []any) []any {
+	if len(tools) == 0 {
+		return normalizedMessages
+	}
+	toolPrompt := strings.TrimSpace(buildClaudeToolPrompt(tools))
+	if toolPrompt == "" {
+		return normalizedMessages
+	}
+
+	// Prefer top-level Anthropic-style system prompt when available.
+	if systemText, ok := payload["system"].(string); ok && strings.TrimSpace(systemText) != "" {
+		payload["system"] = mergeSystemPrompt(systemText, toolPrompt)
+		return normalizedMessages
+	}
+
+	messages := cloneAnySlice(normalizedMessages)
+	for i := range messages {
+		msg, ok := messages[i].(map[string]any)
+		if !ok {
+			continue
+		}
+		role, _ := msg["role"].(string)
+		if !strings.EqualFold(strings.TrimSpace(role), "system") {
+			continue
+		}
+		copied := cloneMap(msg)
+		copied["content"] = mergeSystemPrompt(strings.TrimSpace(fmt.Sprintf("%v", copied["content"])), toolPrompt)
+		messages[i] = copied
+		return messages
+	}
+
+	return append([]any{map[string]any{"role": "system", "content": toolPrompt}}, messages...)
+}
+
+func mergeSystemPrompt(base, extra string) string {
+	base = strings.TrimSpace(base)
+	extra = strings.TrimSpace(extra)
+	switch {
+	case base == "":
+		return extra
+	case extra == "":
+		return base
+	default:
+		return base + "\n\n" + extra
+	}
+}
+
+func cloneAnySlice(in []any) []any {
+	if len(in) == 0 {
+		return nil
+	}
+	out := make([]any, len(in))
+	copy(out, in)
+	return out
+}
@@ -36,3 +36,57 @@ func TestNormalizeClaudeRequest(t *testing.T) {
 		t.Fatalf("expected non-empty final prompt")
 	}
 }
+
+func TestNormalizeClaudeRequestInjectsToolsIntoExistingSystemMessage(t *testing.T) {
+	t.Setenv("DS2API_CONFIG_JSON", `{}`)
+	store := config.LoadStore()
+	req := map[string]any{
+		"model": "claude-sonnet-4-5",
+		"messages": []any{
+			map[string]any{"role": "system", "content": "baseline rule"},
+			map[string]any{"role": "user", "content": "hello"},
+		},
+		"tools": []any{
+			map[string]any{"name": "search", "description": "Search"},
+		},
+	}
+
+	norm, err := normalizeClaudeRequest(store, req)
+	if err != nil {
+		t.Fatalf("normalize failed: %v", err)
+	}
+
+	if !containsStr(norm.Standard.FinalPrompt, "You have access to these tools") {
+		t.Fatalf("expected tool prompt injected into final prompt, got=%q", norm.Standard.FinalPrompt)
+	}
+	if !containsStr(norm.Standard.FinalPrompt, "baseline rule") {
+		t.Fatalf("expected existing system message preserved, got=%q", norm.Standard.FinalPrompt)
+	}
+}
+
+func TestNormalizeClaudeRequestInjectsToolsIntoTopLevelSystem(t *testing.T) {
+	t.Setenv("DS2API_CONFIG_JSON", `{}`)
+	store := config.LoadStore()
+	req := map[string]any{
+		"model":  "claude-sonnet-4-5",
+		"system": "top-level system",
+		"messages": []any{
+			map[string]any{"role": "user", "content": "hello"},
+		},
+		"tools": []any{
+			map[string]any{"name": "search", "description": "Search"},
+		},
+	}
+
+	norm, err := normalizeClaudeRequest(store, req)
+	if err != nil {
+		t.Fatalf("normalize failed: %v", err)
+	}
+
+	if !containsStr(norm.Standard.FinalPrompt, "top-level system") {
+		t.Fatalf("expected top-level system preserved, got=%q", norm.Standard.FinalPrompt)
+	}
+	if !containsStr(norm.Standard.FinalPrompt, "You have access to these tools") {
+		t.Fatalf("expected tool prompt injected, got=%q", norm.Standard.FinalPrompt)
+	}
+}
internal/adapter/claude/stream_status_test.go · new file (100 lines)

package claude

import (
	"context"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/go-chi/chi/v5"
	chimw "github.com/go-chi/chi/v5/middleware"

	"ds2api/internal/auth"
)

type streamStatusClaudeAuthStub struct{}

func (streamStatusClaudeAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) {
	return &auth.RequestAuth{
		UseConfigToken: false,
		DeepSeekToken:  "direct-token",
		CallerID:       "caller:test",
		TriedAccounts:  map[string]bool{},
	}, nil
}

func (streamStatusClaudeAuthStub) Release(_ *auth.RequestAuth) {}

type streamStatusClaudeDSStub struct{}

func (streamStatusClaudeDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
	return "session-id", nil
}

func (streamStatusClaudeDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
	return "pow", nil
}

func (streamStatusClaudeDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
	body := "data: {\"p\":\"response/content\",\"v\":\"hello\"}\n" + "data: [DONE]\n"
	return &http.Response{
		StatusCode: http.StatusOK,
		Header:     make(http.Header),
		Body:       ioNopCloser{strings.NewReader(body)},
	}, nil
}

type ioNopCloser struct {
	*strings.Reader
}

func (ioNopCloser) Close() error { return nil }

type streamStatusClaudeStoreStub struct{}

func (streamStatusClaudeStoreStub) ClaudeMapping() map[string]string {
	return map[string]string{
		"fast": "deepseek-chat",
		"slow": "deepseek-reasoner",
	}
}

func captureClaudeStatusMiddleware(statuses *[]int) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ww := chimw.NewWrapResponseWriter(w, r.ProtoMajor)
			next.ServeHTTP(ww, r)
			*statuses = append(*statuses, ww.Status())
		})
	}
}

func TestClaudeMessagesStreamStatusCapturedAs200(t *testing.T) {
	statuses := make([]int, 0, 1)
	h := &Handler{
		Store: streamStatusClaudeStoreStub{},
		Auth:  streamStatusClaudeAuthStub{},
		DS:    streamStatusClaudeDSStub{},
	}
	r := chi.NewRouter()
	r.Use(captureClaudeStatusMiddleware(&statuses))
	RegisterRoutes(r, h)

	reqBody := `{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":true}`
	req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(reqBody))
	req.Header.Set("Authorization", "Bearer direct-token")
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	r.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
	}
	if len(statuses) != 1 {
		t.Fatalf("expected one captured status, got %d", len(statuses))
	}
	if statuses[0] != http.StatusOK {
		t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0])
	}
}
internal/adapter/gemini/convert.go · new file (313 lines)

package gemini

import (
	"encoding/json"
	"fmt"
	"strings"

	"ds2api/internal/adapter/openai"
	"ds2api/internal/config"
	"ds2api/internal/util"
)

func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[string]any, stream bool) (util.StandardRequest, error) {
	requestedModel := strings.TrimSpace(routeModel)
	if requestedModel == "" {
		return util.StandardRequest{}, fmt.Errorf("model is required in request path")
	}

	resolvedModel, ok := config.ResolveModel(store, requestedModel)
	if !ok {
		return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", requestedModel)
	}
	thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)

	messagesRaw := geminiMessagesFromRequest(req)
	if len(messagesRaw) == 0 {
		return util.StandardRequest{}, fmt.Errorf("Request must include non-empty contents.")
	}

	toolsRaw := convertGeminiTools(req["tools"])
	finalPrompt, toolNames := openai.BuildPromptForAdapter(messagesRaw, toolsRaw, "")
	passThrough := collectGeminiPassThrough(req)

	return util.StandardRequest{
		Surface:        "google_gemini",
		RequestedModel: requestedModel,
		ResolvedModel:  resolvedModel,
		ResponseModel:  requestedModel,
		Messages:       messagesRaw,
		FinalPrompt:    finalPrompt,
		ToolNames:      toolNames,
		Stream:         stream,
		Thinking:       thinkingEnabled,
		Search:         searchEnabled,
		PassThrough:    passThrough,
	}, nil
}

func geminiMessagesFromRequest(req map[string]any) []any {
	out := make([]any, 0, 8)
	if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" {
		out = append(out, map[string]any{
			"role":    "system",
			"content": sys,
		})
	}

	contents, _ := req["contents"].([]any)
	for _, item := range contents {
		content, ok := item.(map[string]any)
		if !ok {
			continue
		}
		role := mapGeminiRole(content["role"])
		if role == "" {
			role = "user"
		}
		parts, _ := content["parts"].([]any)
		if len(parts) == 0 {
			if text := strings.TrimSpace(asString(content["text"])); text != "" {
				out = append(out, map[string]any{
					"role":    role,
					"content": text,
				})
			}
			continue
		}

		textParts := make([]string, 0, len(parts))
		flushText := func() {
			if len(textParts) == 0 {
				return
			}
			out = append(out, map[string]any{
				"role":    role,
				"content": strings.Join(textParts, "\n"),
			})
			textParts = textParts[:0]
		}

		for _, rawPart := range parts {
			part, ok := rawPart.(map[string]any)
			if !ok {
				continue
			}
			if text := strings.TrimSpace(asString(part["text"])); text != "" {
				textParts = append(textParts, text)
				continue
			}

			if fnCall, ok := part["functionCall"].(map[string]any); ok {
				flushText()
				if name := strings.TrimSpace(asString(fnCall["name"])); name != "" {
					callID := strings.TrimSpace(asString(fnCall["id"]))
					if callID == "" {
						callID = "call_gemini"
					}
					out = append(out, map[string]any{
						"role": "assistant",
						"tool_calls": []any{
							map[string]any{
								"id":   callID,
								"type": "function",
								"function": map[string]any{
									"name":      name,
									"arguments": stringifyJSON(fnCall["args"]),
								},
							},
						},
					})
				}
				continue
			}

			if fnResp, ok := part["functionResponse"].(map[string]any); ok {
				flushText()
				name := strings.TrimSpace(asString(fnResp["name"]))
				callID := strings.TrimSpace(asString(fnResp["id"]))
				if callID == "" {
					callID = strings.TrimSpace(asString(fnResp["callId"]))
				}
				if callID == "" {
					callID = strings.TrimSpace(asString(fnResp["tool_call_id"]))
				}
				if callID == "" {
					callID = "call_gemini"
				}
				content := fnResp["response"]
				if content == nil {
					content = fnResp["output"]
				}
				if content == nil {
					content = ""
				}
				msg := map[string]any{
					"role":         "tool",
					"tool_call_id": callID,
					"content":      content,
				}
				if name != "" {
					msg["name"] = name
				}
				out = append(out, msg)
			}
		}
		flushText()
	}
	return out
}

func normalizeGeminiSystemInstruction(raw any) string {
	switch v := raw.(type) {
	case string:
		return strings.TrimSpace(v)
	case map[string]any:
		if parts, ok := v["parts"].([]any); ok {
			texts := make([]string, 0, len(parts))
			for _, item := range parts {
				part, ok := item.(map[string]any)
				if !ok {
					continue
				}
				if text := strings.TrimSpace(asString(part["text"])); text != "" {
					texts = append(texts, text)
				}
			}
			return strings.Join(texts, "\n")
		}
		if text := strings.TrimSpace(asString(v["text"])); text != "" {
			return text
		}
	}
	return ""
}

func mapGeminiRole(v any) string {
	switch strings.ToLower(strings.TrimSpace(asString(v))) {
	case "user":
		return "user"
	case "model", "assistant":
		return "assistant"
	case "system":
		return "system"
	default:
		return ""
	}
}

func convertGeminiTools(raw any) []any {
	tools, _ := raw.([]any)
	if len(tools) == 0 {
		return nil
	}
	out := make([]any, 0, len(tools))
	for _, item := range tools {
		tool, ok := item.(map[string]any)
		if !ok {
			continue
		}

		if fnDecls, ok := tool["functionDeclarations"].([]any); ok && len(fnDecls) > 0 {
			for _, declRaw := range fnDecls {
				decl, ok := declRaw.(map[string]any)
				if !ok {
					continue
				}
				name := strings.TrimSpace(asString(decl["name"]))
				if name == "" {
					continue
				}
				function := map[string]any{
					"name": name,
				}
				if desc := strings.TrimSpace(asString(decl["description"])); desc != "" {
					function["description"] = desc
				}
				if params, ok := decl["parameters"].(map[string]any); ok {
					function["parameters"] = params
				}
				out = append(out, map[string]any{
					"type":     "function",
					"function": function,
				})
			}
			continue
		}

		// OpenAI-style passthrough fallback.
		if _, ok := tool["function"].(map[string]any); ok {
			out = append(out, tool)
			continue
		}

		// Loose fallback for flattened function schema objects.
		name := strings.TrimSpace(asString(tool["name"]))
		if name == "" {
			continue
		}
		fn := map[string]any{"name": name}
		if desc := strings.TrimSpace(asString(tool["description"])); desc != "" {
			fn["description"] = desc
		}
		if params, ok := tool["parameters"].(map[string]any); ok {
			fn["parameters"] = params
		}
		out = append(out, map[string]any{
			"type":     "function",
			"function": fn,
		})
	}
	if len(out) == 0 {
		return nil
	}
	return out
}

func collectGeminiPassThrough(req map[string]any) map[string]any {
	cfg, _ := req["generationConfig"].(map[string]any)
	if len(cfg) == 0 {
		return nil
	}
	out := map[string]any{}
	if v, ok := cfg["temperature"]; ok {
		out["temperature"] = v
	}
	if v, ok := cfg["topP"]; ok {
		out["top_p"] = v
	}
	if v, ok := cfg["maxOutputTokens"]; ok {
		out["max_tokens"] = v
	}
	if v, ok := cfg["stopSequences"]; ok {
		out["stop"] = v
	}
	if len(out) == 0 {
		return nil
	}
	return out
}

func asString(v any) string {
	s, _ := v.(string)
	return s
}

func stringifyJSON(v any) string {
	switch x := v.(type) {
	case nil:
		return "{}"
	case string:
		s := strings.TrimSpace(x)
		if s == "" {
			return "{}"
		}
		return s
	default:
		b, err := json.Marshal(x)
		if err != nil || len(b) == 0 {
			return "{}"
		}
		return string(b)
	}
}
29
internal/adapter/gemini/deps.go
Normal file
@@ -0,0 +1,29 @@
package gemini

import (
	"context"
	"net/http"

	"ds2api/internal/auth"
	"ds2api/internal/config"
	"ds2api/internal/deepseek"
)

type AuthResolver interface {
	Determine(req *http.Request) (*auth.RequestAuth, error)
	Release(a *auth.RequestAuth)
}

type DeepSeekCaller interface {
	CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
	GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
	CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
}

type ConfigReader interface {
	ModelAliases() map[string]string
}

var _ AuthResolver = (*auth.Resolver)(nil)
var _ DeepSeekCaller = (*deepseek.Client)(nil)
var _ ConfigReader = (*config.Store)(nil)
348
internal/adapter/gemini/handler.go
Normal file
@@ -0,0 +1,348 @@
package gemini

import (
	"encoding/json"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/go-chi/chi/v5"

	"ds2api/internal/auth"
	"ds2api/internal/deepseek"
	"ds2api/internal/sse"
	streamengine "ds2api/internal/stream"
	"ds2api/internal/util"
)

var writeJSON = util.WriteJSON

type Handler struct {
	Store ConfigReader
	Auth  AuthResolver
	DS    DeepSeekCaller
}

func RegisterRoutes(r chi.Router, h *Handler) {
	r.Post("/v1beta/models/{model}:generateContent", h.GenerateContent)
	r.Post("/v1beta/models/{model}:streamGenerateContent", h.StreamGenerateContent)
	r.Post("/v1/models/{model}:generateContent", h.GenerateContent)
	r.Post("/v1/models/{model}:streamGenerateContent", h.StreamGenerateContent)
}

func (h *Handler) GenerateContent(w http.ResponseWriter, r *http.Request) {
	h.handleGenerateContent(w, r, false)
}

func (h *Handler) StreamGenerateContent(w http.ResponseWriter, r *http.Request) {
	h.handleGenerateContent(w, r, true)
}

func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) {
	a, err := h.Auth.Determine(r)
	if err != nil {
		status := http.StatusUnauthorized
		detail := err.Error()
		if err == auth.ErrNoAccount {
			status = http.StatusTooManyRequests
		}
		writeGeminiError(w, status, detail)
		return
	}
	defer h.Auth.Release(a)

	var req map[string]any
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeGeminiError(w, http.StatusBadRequest, "invalid json")
		return
	}

	routeModel := strings.TrimSpace(chi.URLParam(r, "model"))
	stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream)
	if err != nil {
		writeGeminiError(w, http.StatusBadRequest, err.Error())
		return
	}

	sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
	if err != nil {
		if a.UseConfigToken {
			writeGeminiError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
		} else {
			writeGeminiError(w, http.StatusUnauthorized, "Invalid token.")
		}
		return
	}
	pow, err := h.DS.GetPow(r.Context(), a, 3)
	if err != nil {
		writeGeminiError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
		return
	}
	payload := stdReq.CompletionPayload(sessionID)
	resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
	if err != nil {
		writeGeminiError(w, http.StatusInternalServerError, "Failed to get completion.")
		return
	}

	if stream {
		h.handleStreamGenerateContent(w, r, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
		return
	}
	h.handleNonStreamGenerateContent(w, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
}

func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *http.Response, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
		return
	}

	result := sse.CollectStream(resp, thinkingEnabled, true)
	writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames))
}

func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
		return
	}

	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache, no-transform")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no")

	rc := http.NewResponseController(w)
	_, canFlush := w.(http.Flusher)
	runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)

	initialType := "text"
	if thinkingEnabled {
		initialType = "thinking"
	}
	streamengine.ConsumeSSE(streamengine.ConsumeConfig{
		Context:             r.Context(),
		Body:                resp.Body,
		ThinkingEnabled:     thinkingEnabled,
		InitialType:         initialType,
		KeepAliveInterval:   time.Duration(deepseek.KeepAliveTimeout) * time.Second,
		IdleTimeout:         time.Duration(deepseek.StreamIdleTimeout) * time.Second,
		MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
	}, streamengine.ConsumeHooks{
		OnParsed: runtime.onParsed,
		OnFinalize: func(_ streamengine.StopReason, _ error) {
			runtime.finalize()
		},
	})
}

func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
	parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames)
	usage := buildGeminiUsage(finalPrompt, finalThinking, finalText)
	return map[string]any{
		"candidates": []map[string]any{
			{
				"index": 0,
				"content": map[string]any{
					"role":  "model",
					"parts": parts,
				},
				"finishReason": "STOP",
			},
		},
		"modelVersion":  model,
		"usageMetadata": usage,
	}
}

func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any {
	promptTokens := util.EstimateTokens(finalPrompt)
	reasoningTokens := util.EstimateTokens(finalThinking)
	completionTokens := util.EstimateTokens(finalText)
	return map[string]any{
		"promptTokenCount":     promptTokens,
		"candidatesTokenCount": reasoningTokens + completionTokens,
		"totalTokenCount":      promptTokens + reasoningTokens + completionTokens,
	}
}

func buildGeminiPartsFromFinal(finalText, finalThinking string, toolNames []string) []map[string]any {
	detected := util.ParseToolCalls(finalText, toolNames)
	if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
		detected = util.ParseToolCalls(finalThinking, toolNames)
	}
	if len(detected) > 0 {
		parts := make([]map[string]any, 0, len(detected))
		for _, tc := range detected {
			parts = append(parts, map[string]any{
				"functionCall": map[string]any{
					"name": tc.Name,
					"args": tc.Input,
				},
			})
		}
		return parts
	}

	text := finalText
	if strings.TrimSpace(text) == "" {
		text = finalThinking
	}
	return []map[string]any{{"text": text}}
}

type geminiStreamRuntime struct {
	w        http.ResponseWriter
	rc       *http.ResponseController
	canFlush bool

	model       string
	finalPrompt string

	thinkingEnabled bool
	searchEnabled   bool
	bufferContent   bool
	toolNames       []string

	thinking strings.Builder
	text     strings.Builder
}

func newGeminiStreamRuntime(
	w http.ResponseWriter,
	rc *http.ResponseController,
	canFlush bool,
	model string,
	finalPrompt string,
	thinkingEnabled bool,
	searchEnabled bool,
	toolNames []string,
) *geminiStreamRuntime {
	return &geminiStreamRuntime{
		w:               w,
		rc:              rc,
		canFlush:        canFlush,
		model:           model,
		finalPrompt:     finalPrompt,
		thinkingEnabled: thinkingEnabled,
		searchEnabled:   searchEnabled,
		bufferContent:   len(toolNames) > 0,
		toolNames:       toolNames,
	}
}

func (s *geminiStreamRuntime) sendChunk(payload map[string]any) {
	b, _ := json.Marshal(payload)
	_, _ = s.w.Write([]byte("data: "))
	_, _ = s.w.Write(b)
	_, _ = s.w.Write([]byte("\n\n"))
	if s.canFlush {
		_ = s.rc.Flush()
	}
}

func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
	if !parsed.Parsed {
		return streamengine.ParsedDecision{}
	}
	if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
		return streamengine.ParsedDecision{Stop: true}
	}

	contentSeen := false
	for _, p := range parsed.Parts {
		if p.Text == "" {
			continue
		}
		if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
			continue
		}
		contentSeen = true
		if p.Type == "thinking" {
			if s.thinkingEnabled {
				s.thinking.WriteString(p.Text)
			}
			continue
		}
		s.text.WriteString(p.Text)
		if s.bufferContent {
			continue
		}
		s.sendChunk(map[string]any{
			"candidates": []map[string]any{
				{
					"index": 0,
					"content": map[string]any{
						"role":  "model",
						"parts": []map[string]any{{"text": p.Text}},
					},
				},
			},
			"modelVersion": s.model,
		})
	}
	return streamengine.ParsedDecision{ContentSeen: contentSeen}
}

func (s *geminiStreamRuntime) finalize() {
	finalThinking := s.thinking.String()
	finalText := s.text.String()

	if s.bufferContent {
		parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames)
		s.sendChunk(map[string]any{
			"candidates": []map[string]any{
				{
					"index": 0,
					"content": map[string]any{
						"role":  "model",
						"parts": parts,
					},
				},
			},
			"modelVersion": s.model,
		})
	}

	s.sendChunk(map[string]any{
		"candidates": []map[string]any{
			{
				"index":        0,
				"finishReason": "STOP",
			},
		},
		"modelVersion":  s.model,
		"usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText),
	})
}

func writeGeminiError(w http.ResponseWriter, status int, message string) {
	errorStatus := "INVALID_ARGUMENT"
	switch status {
	case http.StatusUnauthorized:
		errorStatus = "UNAUTHENTICATED"
	case http.StatusForbidden:
		errorStatus = "PERMISSION_DENIED"
	case http.StatusTooManyRequests:
		errorStatus = "RESOURCE_EXHAUSTED"
	case http.StatusNotFound:
		errorStatus = "NOT_FOUND"
	default:
		if status >= 500 {
			errorStatus = "INTERNAL"
		}
	}
	writeJSON(w, status, map[string]any{
		"error": map[string]any{
			"code":    status,
			"message": message,
			"status":  errorStatus,
		},
	})
}

174
internal/adapter/gemini/handler_test.go
Normal file
@@ -0,0 +1,174 @@
package gemini

import (
	"context"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/go-chi/chi/v5"

	"ds2api/internal/auth"
)

type testGeminiConfig struct{}

func (testGeminiConfig) ModelAliases() map[string]string { return nil }

type testGeminiAuth struct {
	a   *auth.RequestAuth
	err error
}

func (m testGeminiAuth) Determine(_ *http.Request) (*auth.RequestAuth, error) {
	if m.err != nil {
		return nil, m.err
	}
	if m.a != nil {
		return m.a, nil
	}
	return &auth.RequestAuth{
		UseConfigToken: false,
		DeepSeekToken:  "direct-token",
		CallerID:       "caller:test",
		TriedAccounts:  map[string]bool{},
	}, nil
}

func (testGeminiAuth) Release(_ *auth.RequestAuth) {}

type testGeminiDS struct {
	resp *http.Response
	err  error
}

func (m testGeminiDS) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
	return "session-id", nil
}

func (m testGeminiDS) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
	return "pow", nil
}

func (m testGeminiDS) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
	if m.err != nil {
		return nil, m.err
	}
	return m.resp, nil
}

func makeGeminiUpstreamResponse(lines ...string) *http.Response {
	body := strings.Join(lines, "\n")
	if !strings.HasSuffix(body, "\n") {
		body += "\n"
	}
	return &http.Response{
		StatusCode: http.StatusOK,
		Header:     make(http.Header),
		Body:       io.NopCloser(strings.NewReader(body)),
	}
}

func TestGeminiRoutesRegistered(t *testing.T) {
	h := &Handler{
		Store: testGeminiConfig{},
		Auth:  testGeminiAuth{err: auth.ErrUnauthorized},
	}
	r := chi.NewRouter()
	RegisterRoutes(r, h)

	paths := []string{
		"/v1beta/models/gemini-2.5-pro:generateContent",
		"/v1beta/models/gemini-2.5-pro:streamGenerateContent",
		"/v1/models/gemini-2.5-pro:generateContent",
		"/v1/models/gemini-2.5-pro:streamGenerateContent",
	}
	for _, path := range paths {
		req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`))
		rec := httptest.NewRecorder()
		r.ServeHTTP(rec, req)
		if rec.Code == http.StatusNotFound {
			t.Fatalf("expected route %s to be registered, got 404", path)
		}
	}
}

func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
	upstream := makeGeminiUpstreamResponse(
		`data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
		`data: [DONE]`,
	)
	h := &Handler{
		Store: testGeminiConfig{},
		Auth:  testGeminiAuth{},
		DS:    testGeminiDS{resp: upstream},
	}
	r := chi.NewRouter()
	RegisterRoutes(r, h)

	body := `{
		"contents":[{"role":"user","parts":[{"text":"call tool"}]}],
		"tools":[{"functionDeclarations":[{"name":"eval_javascript","description":"eval","parameters":{"type":"object","properties":{"code":{"type":"string"}}}}]}]
	}`
	req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent", strings.NewReader(body))
	req.Header.Set("Authorization", "Bearer direct-token")
	rec := httptest.NewRecorder()
	r.ServeHTTP(rec, req)
	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
	}

	var out map[string]any
	if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
		t.Fatalf("decode response failed: %v", err)
	}
	candidates, _ := out["candidates"].([]any)
	if len(candidates) == 0 {
		t.Fatalf("expected non-empty candidates: %#v", out)
	}
	c0, _ := candidates[0].(map[string]any)
	content, _ := c0["content"].(map[string]any)
	parts, _ := content["parts"].([]any)
	if len(parts) == 0 {
		t.Fatalf("expected non-empty parts: %#v", content)
	}
	part0, _ := parts[0].(map[string]any)
	functionCall, _ := part0["functionCall"].(map[string]any)
	if functionCall["name"] != "eval_javascript" {
		t.Fatalf("expected functionCall name eval_javascript, got %#v", functionCall)
	}
}

func TestStreamGenerateContentEmitsSSE(t *testing.T) {
	upstream := makeGeminiUpstreamResponse(
		`data: {"p":"response/content","v":"hello "}`,
		`data: {"p":"response/content","v":"world"}`,
		`data: [DONE]`,
	)
	h := &Handler{
		Store: testGeminiConfig{},
		Auth:  testGeminiAuth{},
		DS:    testGeminiDS{resp: upstream},
	}
	r := chi.NewRouter()
	RegisterRoutes(r, h)

	body := `{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`
	req := httptest.NewRequest(http.MethodPost, "/v1/models/gemini-2.5-pro:streamGenerateContent?alt=sse", strings.NewReader(body))
	req.Header.Set("Authorization", "Bearer direct-token")
	rec := httptest.NewRecorder()
	r.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
	}
	if !strings.Contains(rec.Body.String(), "data: ") {
		t.Fatalf("expected SSE data frames, got body=%s", rec.Body.String())
	}
	if !strings.Contains(rec.Body.String(), `"finishReason":"STOP"`) {
		t.Fatalf("expected stream finish frame, got body=%s", rec.Body.String())
	}
}

@@ -25,10 +25,11 @@ type chatStreamRuntime struct {
	thinkingEnabled bool
	searchEnabled   bool

	firstChunkSent      bool
	bufferToolContent   bool
	emitEarlyToolDeltas bool
	toolCallsEmitted    bool
	firstChunkSent       bool
	bufferToolContent    bool
	emitEarlyToolDeltas  bool
	toolCallsEmitted     bool
	toolCallsDoneEmitted bool

	toolSieve         toolStreamSieveState
	streamToolCallIDs map[int]string
@@ -96,10 +97,10 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
	finalThinking := s.thinking.String()
	finalText := s.text.String()
	detected := util.ParseToolCalls(finalText, s.toolNames)
	if len(detected) > 0 && !s.toolCallsEmitted {
	if len(detected) > 0 && !s.toolCallsDoneEmitted {
		finishReason = "tool_calls"
		delta := map[string]any{
			"tool_calls": util.FormatOpenAIStreamToolCalls(detected),
			"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected, s.streamToolCallIDs),
		}
		if !s.firstChunkSent {
			delta["role"] = "assistant"
@@ -112,8 +113,29 @@
			[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, delta)},
			nil,
		))
		s.toolCallsEmitted = true
		s.toolCallsDoneEmitted = true
	} else if s.bufferToolContent {
		for _, evt := range flushToolSieve(&s.toolSieve, s.toolNames) {
			if len(evt.ToolCalls) > 0 {
				finishReason = "tool_calls"
				s.toolCallsEmitted = true
				s.toolCallsDoneEmitted = true
				tcDelta := map[string]any{
					"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
				}
				if !s.firstChunkSent {
					tcDelta["role"] = "assistant"
					s.firstChunkSent = true
				}
				s.sendChunk(openaifmt.BuildChatStreamChunk(
					s.completionID,
					s.created,
					s.model,
					[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, tcDelta)},
					nil,
				))
			}
			if evt.Content == "" {
				continue
			}
@@ -189,10 +211,14 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
			if !s.emitEarlyToolDeltas {
				continue
			}
			s.toolCallsEmitted = true
			tcDelta := map[string]any{
				"tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs),
			formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
			if len(formatted) == 0 {
				continue
			}
			tcDelta := map[string]any{
				"tool_calls": formatted,
			}
			s.toolCallsEmitted = true
			if !s.firstChunkSent {
				tcDelta["role"] = "assistant"
				s.firstChunkSent = true
@@ -202,8 +228,9 @@
			}
			if len(evt.ToolCalls) > 0 {
				s.toolCallsEmitted = true
				s.toolCallsDoneEmitted = true
				tcDelta := map[string]any{
					"tool_calls": util.FormatOpenAIStreamToolCalls(evt.ToolCalls),
					"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
				}
				if !s.firstChunkSent {
					tcDelta["role"] = "assistant"

@@ -31,7 +31,7 @@ func TestNormalizeOpenAIChatRequestWithConfigInterface(t *testing.T) {
		"model":    "my-model",
		"messages": []any{map[string]any{"role": "user", "content": "hello"}},
	}
	out, err := normalizeOpenAIChatRequest(cfg, req)
	out, err := normalizeOpenAIChatRequest(cfg, req, "")
	if err != nil {
		t.Fatalf("normalizeOpenAIChatRequest error: %v", err)
	}
@@ -52,7 +52,7 @@ func TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface(t *testing.
	_, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{
		aliases:   map[string]string{},
		wideInput: false,
	}, req)
	}, req, "")
	if err == nil {
		t.Fatal("expected error when wide input is disabled and only input is provided")
	}
@@ -60,7 +60,7 @@ func TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface(t *testing.
	out, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{
		aliases:   map[string]string{},
		wideInput: true,
	}, req)
	}, req, "")
	if err != nil {
		t.Fatalf("unexpected error when wide input is enabled: %v", err)
	}

@@ -93,7 +93,7 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
		writeOpenAIError(w, http.StatusBadRequest, "invalid json")
		return
	}
	stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
	stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
	if err != nil {
		writeOpenAIError(w, http.StatusBadRequest, err.Error())
		return
@@ -154,7 +154,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no")
	rc := http.NewResponseController(w)
	canFlush := rc.Flush() == nil
	_, canFlush := w.(http.Flusher)
	if !canFlush {
		config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered")
	}
@@ -233,7 +233,7 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any,
	if len(toolSchemas) == 0 {
		return messages, names
	}
	toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error."
	toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block."

	for i := range messages {
		if messages[i]["role"] == "system" {
@@ -280,6 +280,36 @@ func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]s
	return out
}

func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any {
	if len(calls) == 0 {
		return nil
	}
	out := make([]map[string]any, 0, len(calls))
	for i, c := range calls {
		callID := ""
		if ids != nil {
			callID = strings.TrimSpace(ids[i])
		}
		if callID == "" {
			callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
			if ids != nil {
				ids[i] = callID
			}
		}
		args, _ := json.Marshal(c.Input)
		out = append(out, map[string]any{
			"index": i,
			"id":    callID,
			"type":  "function",
			"function": map[string]any{
				"name":      c.Name,
				"arguments": string(args),
			},
		})
	}
	return out
}

func writeOpenAIError(w http.ResponseWriter, status int, message string) {
	writeJSON(w, status, map[string]any{
		"error": map[string]any{

@@ -735,3 +735,73 @@ func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) {
		t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
	}
}

func TestHandleStreamMultiToolCallDoesNotMergeNamesOrArguments(t *testing.T) {
	h := &Handler{}
	resp := makeSSEHTTPResponse(
		`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search_web\",\"input\":{\"query\":\"latest ai news\"}},{"}`,
		`data: {"p":"response/content","v":"\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
		`data: [DONE]`,
	)
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)

	h.handleStream(rec, req, resp, "cid12", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"})

	frames, done := parseSSEDataFrames(t, rec.Body.String())
	if !done {
		t.Fatalf("expected [DONE], body=%s", rec.Body.String())
	}
	if !streamHasToolCallsDelta(frames) {
		t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
	}

	foundSearch := false
	foundEval := false
	foundIndex1 := false
	toolCallsDeltaLens := make([]int, 0, 2)
	for _, frame := range frames {
		choices, _ := frame["choices"].([]any)
		for _, item := range choices {
			choice, _ := item.(map[string]any)
			delta, _ := choice["delta"].(map[string]any)
			rawToolCalls, hasToolCalls := delta["tool_calls"]
			if !hasToolCalls {
				continue
			}
			toolCalls, _ := rawToolCalls.([]any)
			toolCallsDeltaLens = append(toolCallsDeltaLens, len(toolCalls))
			for _, tc := range toolCalls {
				tcm, _ := tc.(map[string]any)
				if idx, ok := tcm["index"].(float64); ok && int(idx) == 1 {
					foundIndex1 = true
				}
				fn, _ := tcm["function"].(map[string]any)
				name, _ := fn["name"].(string)
				switch name {
				case "search_web":
					foundSearch = true
				case "eval_javascript":
					foundEval = true
				case "search_webeval_javascript":
					t.Fatalf("unexpected merged tool name: %s, body=%s", name, rec.Body.String())
				}
				if args, ok := fn["arguments"].(string); ok && strings.Contains(args, `}{"`) {
					t.Fatalf("unexpected concatenated tool arguments: %q, body=%s", args, rec.Body.String())
				}
			}
		}
	}
	if !foundSearch || !foundEval {
		t.Fatalf("expected both tool names in stream deltas, foundSearch=%v foundEval=%v body=%s", foundSearch, foundEval, rec.Body.String())
	}
	if len(toolCallsDeltaLens) != 1 || toolCallsDeltaLens[0] != 2 {
		t.Fatalf("expected exactly one tool_calls delta with two calls, got lens=%v body=%s", toolCallsDeltaLens, rec.Body.String())
	}
	if !foundIndex1 {
		t.Fatalf("expected second tool call index in stream deltas, body=%s", rec.Body.String())
	}
	if streamFinishReason(frames) != "tool_calls" {
		t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
	}
}

@@ -4,9 +4,11 @@ import (
    "encoding/json"
    "fmt"
    "strings"

    "ds2api/internal/config"
)

func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any {
func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any {
    out := make([]map[string]any, 0, len(raw))
    for _, item := range raw {
        msg, ok := item.(map[string]any)
@@ -17,7 +19,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any {
        switch role {
        case "assistant":
            content := normalizeOpenAIContentForPrompt(msg["content"])
            toolCalls := formatAssistantToolCallsForPrompt(msg)
            toolCalls := formatAssistantToolCallsForPrompt(msg, traceID)
            combined := joinNonEmpty(content, toolCalls)
            if combined == "" {
                continue
@@ -53,7 +55,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any {
    return out
}

func formatAssistantToolCallsForPrompt(msg map[string]any) string {
func formatAssistantToolCallsForPrompt(msg map[string]any, traceID string) string {
    entries := make([]string, 0)
    if calls, ok := msg["tool_calls"].([]any); ok {
        for i, item := range calls {
@@ -86,7 +88,8 @@ func formatAssistantToolCallsForPrompt(msg map[string]any) string {
            if args == "" {
                args = "{}"
            }
            entries = append(entries, fmt.Sprintf("Tool call:\n- tool_call_id: %s\n- function.name: %s\n- function.arguments: %s", id, name, args))
            maybeWarnSuspiciousToolHistory(traceID, id, name, args)
            entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: %s\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", id, name, args))
        }
    }

@@ -99,7 +102,8 @@ func formatAssistantToolCallsForPrompt(msg map[string]any) string {
        if args == "" {
            args = "{}"
        }
        entries = append(entries, fmt.Sprintf("Tool call:\n- tool_call_id: call_legacy\n- function.name: %s\n- function.arguments: %s", name, args))
        maybeWarnSuspiciousToolHistory(traceID, "call_legacy", name, args)
        entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: call_legacy\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", name, args))
    }

    return strings.Join(entries, "\n\n")
@@ -124,7 +128,7 @@ func formatToolResultForPrompt(msg map[string]any) string {
        content = "null"
    }

    return fmt.Sprintf("Tool result:\n- tool_call_id: %s\n- name: %s\n- content: %s", toolCallID, name, content)
    return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content)
}

func normalizeOpenAIContentForPrompt(v any) string {
@@ -190,3 +194,45 @@ func joinNonEmpty(parts ...string) string {
    }
    return strings.Join(nonEmpty, "\n\n")
}

func maybeWarnSuspiciousToolHistory(traceID, callID, name, args string) {
    if !looksLikeConcatenatedJSON(args) {
        return
    }
    traceID = strings.TrimSpace(traceID)
    if traceID == "" {
        traceID = "unknown"
    }
    config.Logger.Warn(
        "[openai] suspicious tool call history payload detected",
        "trace_id", traceID,
        "tool_call_id", strings.TrimSpace(callID),
        "name", strings.TrimSpace(name),
        "arguments_preview", previewToolArgs(args, 160),
    )
}

func looksLikeConcatenatedJSON(raw string) bool {
    trimmed := strings.TrimSpace(raw)
    if trimmed == "" {
        return false
    }
    if strings.Contains(trimmed, "}{") || strings.Contains(trimmed, "][") {
        return true
    }
    dec := json.NewDecoder(strings.NewReader(trimmed))
    var first any
    if err := dec.Decode(&first); err != nil {
        return false
    }
    var second any
    return dec.Decode(&second) == nil
}

func previewToolArgs(raw string, max int) string {
    trimmed := strings.TrimSpace(raw)
    if max <= 0 || len(trimmed) <= max {
        return trimmed
    }
    return trimmed[:max]
}

@@ -33,23 +33,24 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *tes
        },
    }

    normalized := normalizeOpenAIMessagesForPrompt(raw)
    normalized := normalizeOpenAIMessagesForPrompt(raw, "")
    if len(normalized) != 4 {
        t.Fatalf("expected 4 normalized messages, got %d", len(normalized))
    }
    assistantContent, _ := normalized[2]["content"].(string)
    if !strings.Contains(assistantContent, "tool_call_id: call_1") ||
    if !strings.Contains(assistantContent, "[TOOL_CALL_HISTORY]") ||
        !strings.Contains(assistantContent, "tool_call_id: call_1") ||
        !strings.Contains(assistantContent, "function.name: get_weather") ||
        !strings.Contains(assistantContent, "function.arguments: {\"city\":\"beijing\"}") {
        t.Fatalf("assistant tool call not serialized correctly: %q", assistantContent)
    }
    toolContent, _ := normalized[3]["content"].(string)
    if !strings.Contains(toolContent, "Tool result:") || !strings.Contains(toolContent, "name: get_weather") {
    if !strings.Contains(toolContent, "[TOOL_RESULT_HISTORY]") || !strings.Contains(toolContent, "name: get_weather") {
        t.Fatalf("tool result not serialized correctly: %q", toolContent)
    }

    prompt := util.MessagesPrepare(normalized)
    if !strings.Contains(prompt, "tool_call_id: call_1") || !strings.Contains(prompt, "Tool result:") {
    if !strings.Contains(prompt, "tool_call_id: call_1") || !strings.Contains(prompt, "[TOOL_RESULT_HISTORY]") {
        t.Fatalf("expected prompt to include tool call + result semantics: %q", prompt)
    }
}
@@ -67,7 +68,7 @@ func TestNormalizeOpenAIMessagesForPrompt_ToolObjectContentPreserved(t *testing.
        },
    }

    normalized := normalizeOpenAIMessagesForPrompt(raw)
    normalized := normalizeOpenAIMessagesForPrompt(raw, "")
    got, _ := normalized[0]["content"].(string)
    if !strings.Contains(got, `"temp":18`) || !strings.Contains(got, `"condition":"sunny"`) {
        t.Fatalf("expected serialized object in tool content, got %q", got)
@@ -88,7 +89,7 @@ func TestNormalizeOpenAIMessagesForPrompt_ToolArrayBlocksJoined(t *testing.T) {
        },
    }

    normalized := normalizeOpenAIMessagesForPrompt(raw)
    normalized := normalizeOpenAIMessagesForPrompt(raw, "")
    got, _ := normalized[0]["content"].(string)
    if !strings.Contains(got, "line-1\nline-2") {
        t.Fatalf("expected joined text blocks, got %q", got)
@@ -107,7 +108,7 @@ func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) {
        },
    }

    normalized := normalizeOpenAIMessagesForPrompt(raw)
    normalized := normalizeOpenAIMessagesForPrompt(raw, "")
    if len(normalized) != 1 {
        t.Fatalf("expected one normalized message, got %d", len(normalized))
    }
@@ -119,3 +120,50 @@ func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) {
        t.Fatalf("unexpected normalized function-role content: %q", got)
    }
}

func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSeparated(t *testing.T) {
    raw := []any{
        map[string]any{
            "role": "assistant",
            "tool_calls": []any{
                map[string]any{
                    "id": "call_search",
                    "type": "function",
                    "function": map[string]any{
                        "name": "search_web",
                        "arguments": `{"query":"latest ai news"}`,
                    },
                },
                map[string]any{
                    "id": "call_eval",
                    "type": "function",
                    "function": map[string]any{
                        "name": "eval_javascript",
                        "arguments": `{"code":"1+1"}`,
                    },
                },
            },
        },
    }

    normalized := normalizeOpenAIMessagesForPrompt(raw, "")
    if len(normalized) != 1 {
        t.Fatalf("expected one normalized assistant message, got %d", len(normalized))
    }
    content, _ := normalized[0]["content"].(string)
    if strings.Count(content, "[TOOL_CALL_HISTORY]") != 2 {
        t.Fatalf("expected two TOOL_CALL_HISTORY blocks, got %q", content)
    }
    if !strings.Contains(content, "tool_call_id: call_search") || !strings.Contains(content, "function.name: search_web") {
        t.Fatalf("missing first tool call block, got %q", content)
    }
    if !strings.Contains(content, "tool_call_id: call_eval") || !strings.Contains(content, "function.name: eval_javascript") {
        t.Fatalf("missing second tool call block, got %q", content)
    }
    if strings.Contains(content, "search_webeval_javascript") {
        t.Fatalf("unexpected merged function name detected: %q", content)
    }
    if strings.Contains(content, `}{"`) {
        t.Fatalf("unexpected concatenated function arguments detected: %q", content)
    }
}

@@ -4,11 +4,18 @@ import (
    "ds2api/internal/deepseek"
)

func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any) (string, []string) {
    messages := normalizeOpenAIMessagesForPrompt(messagesRaw)
func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any, traceID string) (string, []string) {
    messages := normalizeOpenAIMessagesForPrompt(messagesRaw, traceID)
    toolNames := []string{}
    if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 {
        messages, toolNames = injectToolPrompt(messages, tools)
    }
    return deepseek.MessagesPrepare(messages), toolNames
}

// BuildPromptForAdapter exposes the OpenAI-compatible prompt building flow so
// other protocol adapters (for example Gemini) can reuse the same tool/history
// normalization logic and remain behavior-compatible with chat/completions.
func BuildPromptForAdapter(messagesRaw []any, toolsRaw any, traceID string) (string, []string) {
    return buildOpenAIFinalPrompt(messagesRaw, toolsRaw, traceID)
}

@@ -40,13 +40,13 @@ func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *tes
        },
    }

    finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools)
    finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools, "")
    if len(toolNames) != 1 || toolNames[0] != "get_weather" {
        t.Fatalf("unexpected tool names: %#v", toolNames)
    }
    if !strings.Contains(finalPrompt, "tool_call_id: call_1") ||
        !strings.Contains(finalPrompt, "function.name: get_weather") ||
        !strings.Contains(finalPrompt, "Tool result:") ||
        !strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") ||
        !strings.Contains(finalPrompt, `"condition":"sunny"`) {
        t.Fatalf("handler finalPrompt missing tool roundtrip semantics: %q", finalPrompt)
    }
@@ -70,11 +70,14 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *
        },
    }

    finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools)
    finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "")
    if !strings.Contains(finalPrompt, "After receiving a tool result, you MUST use it to produce the final answer.") {
        t.Fatalf("vercel prepare finalPrompt missing final-answer instruction: %q", finalPrompt)
    }
    if !strings.Contains(finalPrompt, "Only call another tool when the previous result is missing required data or returned an error.") {
        t.Fatalf("vercel prepare finalPrompt missing retry guard instruction: %q", finalPrompt)
    }
    if !strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") {
        t.Fatalf("vercel prepare finalPrompt missing history marker instruction: %q", finalPrompt)
    }
}

@@ -1,6 +1,7 @@
package openai

import (
    "strings"
    "testing"
    "time"
)
@@ -32,6 +33,82 @@ func TestResponsesMessagesFromRequestWithInstructions(t *testing.T) {
    }
}

func TestNormalizeResponsesInputAsMessagesObjectRoleContentBlocks(t *testing.T) {
    msgs := normalizeResponsesInputAsMessages(map[string]any{
        "role": "user",
        "content": []any{
            map[string]any{"type": "input_text", "text": "line-1"},
            map[string]any{"type": "input_text", "text": "line-2"},
        },
    })
    if len(msgs) != 1 {
        t.Fatalf("expected one message, got %d", len(msgs))
    }
    m, _ := msgs[0].(map[string]any)
    if m["role"] != "user" {
        t.Fatalf("unexpected role: %#v", m)
    }
    if strings.TrimSpace(normalizeOpenAIContentForPrompt(m["content"])) != "line-1\nline-2" {
        t.Fatalf("unexpected content: %#v", m["content"])
    }
}

func TestNormalizeResponsesInputAsMessagesFunctionCallOutput(t *testing.T) {
    msgs := normalizeResponsesInputAsMessages([]any{
        map[string]any{
            "type": "function_call_output",
            "call_id": "call_123",
            "output": map[string]any{"ok": true},
        },
    })
    if len(msgs) != 1 {
        t.Fatalf("expected one message, got %d", len(msgs))
    }
    m, _ := msgs[0].(map[string]any)
    if m["role"] != "tool" {
        t.Fatalf("expected tool role, got %#v", m)
    }
    if m["tool_call_id"] != "call_123" {
        t.Fatalf("expected tool_call_id propagated, got %#v", m)
    }
}

func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) {
    msgs := normalizeResponsesInputAsMessages([]any{
        map[string]any{
            "type": "function_call",
            "call_id": "call_456",
            "name": "search",
            "arguments": `{"q":"golang"}`,
        },
    })
    if len(msgs) != 1 {
        t.Fatalf("expected one message, got %d", len(msgs))
    }
    m, _ := msgs[0].(map[string]any)
    if m["role"] != "assistant" {
        t.Fatalf("expected assistant role, got %#v", m["role"])
    }
    toolCalls, _ := m["tool_calls"].([]any)
    if len(toolCalls) != 1 {
        t.Fatalf("expected one tool_call, got %#v", m["tool_calls"])
    }
    call, _ := toolCalls[0].(map[string]any)
    if call["id"] != "call_456" {
        t.Fatalf("expected call id preserved, got %#v", call)
    }
    if call["type"] != "function" {
        t.Fatalf("expected function type, got %#v", call)
    }
    fn, _ := call["function"].(map[string]any)
    if fn["name"] != "search" {
        t.Fatalf("expected call name preserved, got %#v", call)
    }
    if fn["arguments"] != `{"q":"golang"}` {
        t.Fatalf("expected call arguments preserved, got %#v", call)
    }
}

func TestExtractEmbeddingInputs(t *testing.T) {
    got := extractEmbeddingInputs([]any{"a", "b"})
    if len(got) != 2 || got[0] != "a" || got[1] != "b" {

@@ -68,7 +68,7 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
        writeOpenAIError(w, http.StatusBadRequest, "invalid json")
        return
    }
    stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req)
    stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req, requestTraceID(r))
    if err != nil {
        writeOpenAIError(w, http.StatusBadRequest, err.Error())
        return
@@ -128,7 +128,7 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
    w.Header().Set("Connection", "keep-alive")
    w.Header().Set("X-Accel-Buffering", "no")
    rc := http.NewResponseController(w)
    canFlush := rc.Flush() == nil
    _, canFlush := w.(http.Flusher)

    initialType := "text"
    if thinkingEnabled {
@@ -203,40 +203,231 @@ func normalizeResponsesInputAsMessages(input any) []any {
        }
        return []any{map[string]any{"role": "user", "content": v}}
    case []any:
        if len(v) == 0 {
            return nil
        }
        // If caller already provides role-shaped items, keep as-is.
        if first, ok := v[0].(map[string]any); ok {
            if _, hasRole := first["role"]; hasRole {
                return v
            }
        }
        parts := make([]string, 0, len(v))
        for _, item := range v {
            if m, ok := item.(map[string]any); ok {
                if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") {
                    if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
                        parts = append(parts, txt)
                        continue
                    }
                }
            }
            if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" {
                parts = append(parts, s)
            }
        }
        if len(parts) == 0 {
            return nil
        }
        return []any{map[string]any{"role": "user", "content": strings.Join(parts, "\n")}}
        return normalizeResponsesInputArray(v)
    case map[string]any:
        if msg := normalizeResponsesInputItem(v); msg != nil {
            return []any{msg}
        }
        if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" {
            return []any{map[string]any{"role": "user", "content": txt}}
        }
        if content, ok := v["content"].(string); ok && strings.TrimSpace(content) != "" {
            return []any{map[string]any{"role": "user", "content": content}}
        if content, ok := v["content"]; ok {
            if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
                return []any{map[string]any{"role": "user", "content": content}}
            }
        }
    }
    return nil
}

func normalizeResponsesInputArray(items []any) []any {
    if len(items) == 0 {
        return nil
    }
    out := make([]any, 0, len(items))
    fallbackParts := make([]string, 0, len(items))
    flushFallback := func() {
        if len(fallbackParts) == 0 {
            return
        }
        out = append(out, map[string]any{"role": "user", "content": strings.Join(fallbackParts, "\n")})
        fallbackParts = fallbackParts[:0]
    }

    for _, item := range items {
        switch x := item.(type) {
        case map[string]any:
            if msg := normalizeResponsesInputItem(x); msg != nil {
                flushFallback()
                out = append(out, msg)
                continue
            }
            if s := normalizeResponsesFallbackPart(x); s != "" {
                fallbackParts = append(fallbackParts, s)
            }
        default:
            if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" {
                fallbackParts = append(fallbackParts, s)
            }
        }
    }
    flushFallback()
    if len(out) == 0 {
        return nil
    }
    return out
}

func normalizeResponsesInputItem(m map[string]any) map[string]any {
    if m == nil {
        return nil
    }

    role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
    if role != "" {
        content := m["content"]
        if content == nil {
            if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
                content = txt
            }
        }
        if content == nil {
            return nil
        }
        return map[string]any{
            "role": role,
            "content": content,
        }
    }

    itemType := strings.ToLower(strings.TrimSpace(asString(m["type"])))
    switch itemType {
    case "message", "input_message":
        content := m["content"]
        if content == nil {
            if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
                content = txt
            }
        }
        if content == nil {
            return nil
        }
        role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
        if role == "" {
            role = "user"
        }
        return map[string]any{
            "role": role,
            "content": content,
        }
    case "function_call_output", "tool_result":
        content := m["output"]
        if content == nil {
            content = m["content"]
        }
        if content == nil {
            content = ""
        }
        out := map[string]any{
            "role": "tool",
            "content": content,
        }
        if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
            out["tool_call_id"] = callID
        } else if callID = strings.TrimSpace(asString(m["tool_call_id"])); callID != "" {
            out["tool_call_id"] = callID
        }
        if name := strings.TrimSpace(asString(m["name"])); name != "" {
            out["name"] = name
        } else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" {
            out["name"] = name
        }
        return out
    case "function_call", "tool_call":
        name := strings.TrimSpace(asString(m["name"]))
        var fn map[string]any
        if rawFn, ok := m["function"].(map[string]any); ok {
            fn = rawFn
            if name == "" {
                name = strings.TrimSpace(asString(fn["name"]))
            }
        }
        if name == "" {
            return nil
        }

        var argsRaw any
        if v, ok := m["arguments"]; ok {
            argsRaw = v
        } else if v, ok := m["input"]; ok {
            argsRaw = v
        }
        if argsRaw == nil && fn != nil {
            if v, ok := fn["arguments"]; ok {
                argsRaw = v
            } else if v, ok := fn["input"]; ok {
                argsRaw = v
            }
        }

        functionPayload := map[string]any{
            "name": name,
            "arguments": stringifyToolCallArguments(argsRaw),
        }
        call := map[string]any{
            "type": "function",
            "function": functionPayload,
        }
        if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
            call["id"] = callID
        } else if callID = strings.TrimSpace(asString(m["id"])); callID != "" {
            call["id"] = callID
        }
        return map[string]any{
            "role": "assistant",
            "tool_calls": []any{call},
        }
    case "input_text":
        if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
            return map[string]any{
                "role": "user",
                "content": txt,
            }
        }
    }

    if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
        return map[string]any{
            "role": "user",
            "content": txt,
        }
    }
    if content, ok := m["content"]; ok {
        if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
            return map[string]any{
                "role": "user",
                "content": content,
            }
        }
    }
    return nil
}

func normalizeResponsesFallbackPart(m map[string]any) string {
    if m == nil {
        return ""
    }
    if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") {
        if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
            return txt
        }
    }
    if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
        return txt
    }
    if content, ok := m["content"]; ok {
        if normalized := strings.TrimSpace(normalizeOpenAIContentForPrompt(content)); normalized != "" {
            return normalized
        }
    }
    return strings.TrimSpace(fmt.Sprintf("%v", m))
}

func stringifyToolCallArguments(v any) string {
    switch x := v.(type) {
    case nil:
        return "{}"
    case string:
        s := strings.TrimSpace(x)
        if s == "" {
            return "{}"
        }
        return s
    default:
        b, err := json.Marshal(x)
        if err != nil || len(b) == 0 {
            return "{}"
        }
        return string(b)
    }
}

@@ -3,12 +3,15 @@ package openai
import (
    "encoding/json"
    "net/http"
    "sort"
    "strings"

    openaifmt "ds2api/internal/format/openai"
    "ds2api/internal/sse"
    streamengine "ds2api/internal/stream"
    "ds2api/internal/util"

    "github.com/google/uuid"
)

type responsesStreamRuntime struct {
@@ -24,14 +27,20 @@ type responsesStreamRuntime struct {
    thinkingEnabled bool
    searchEnabled bool

    bufferToolContent bool
    emitEarlyToolDeltas bool
    toolCallsEmitted bool
    bufferToolContent bool
    emitEarlyToolDeltas bool
    toolCallsEmitted bool
    toolCallsDoneEmitted bool

    sieve toolStreamSieveState
    thinkingSieve toolStreamSieveState
    thinking strings.Builder
    text strings.Builder
    streamToolCallIDs map[int]string
    streamFunctionIDs map[int]string
    functionDone map[int]bool
    toolCallsDoneSigs map[string]bool
    reasoningItemID string

    persistResponse func(obj map[string]any)
}
@@ -63,6 +72,9 @@ func newResponsesStreamRuntime(
        bufferToolContent: bufferToolContent,
        emitEarlyToolDeltas: emitEarlyToolDeltas,
        streamToolCallIDs: map[int]string{},
        streamFunctionIDs: map[int]string{},
        functionDone: map[int]bool{},
        toolCallsDoneSigs: map[string]bool{},
        persistResponse: persistResponse,
    }
}
@@ -92,19 +104,33 @@ func (s *responsesStreamRuntime) sendDone() {
func (s *responsesStreamRuntime) finalize() {
    finalThinking := s.thinking.String()
    finalText := s.text.String()
    if strings.TrimSpace(finalThinking) != "" {
        s.sendEvent("response.reasoning_text.done", openaifmt.BuildResponsesReasoningTextDonePayload(s.responseID, s.ensureReasoningItemID(), 0, 0, finalThinking))
    }
    if s.bufferToolContent {
        for _, evt := range flushToolSieve(&s.sieve, s.toolNames) {
            if evt.Content != "" {
                s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
            }
            if len(evt.ToolCalls) > 0 {
                s.toolCallsEmitted = true
                s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls)))
        s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
        s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false)
    }
    // Compatibility fallback: some streams only emit incremental tool deltas.
    // Ensure final function_call_arguments.done is emitted at least once.
    if s.toolCallsEmitted {
        detected := util.ParseToolCalls(finalText, s.toolNames)
        if len(detected) == 0 {
            detected = util.ParseToolCalls(finalThinking, s.toolNames)
        }
        if len(detected) > 0 {
            if !s.toolCallsDoneEmitted {
                s.emitToolCallsDone(detected)
            } else {
                s.emitFunctionCallDoneEvents(detected)
            }
        }
    }

    obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames)
    if s.toolCallsEmitted {
        s.alignCompletedOutputCallIDs(obj)
    }
    if s.toolCallsEmitted {
        obj["status"] = "completed"
    }
@@ -138,6 +164,10 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
            }
            s.thinking.WriteString(p.Text)
            s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
            s.sendEvent("response.reasoning_text.delta", openaifmt.BuildResponsesReasoningTextDeltaPayload(s.responseID, s.ensureReasoningItemID(), 0, 0, p.Text))
            if s.bufferToolContent {
                s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false)
            }
            continue
        }

@@ -146,23 +176,191 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
            s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text))
            continue
        }
        for _, evt := range processToolSieveChunk(&s.sieve, p.Text, s.toolNames) {
            if evt.Content != "" {
                s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
            }
            if len(evt.ToolCallDeltas) > 0 {
                if !s.emitEarlyToolDeltas {
                    continue
                }
                s.toolCallsEmitted = true
                s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)))
            }
            if len(evt.ToolCalls) > 0 {
                s.toolCallsEmitted = true
                s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls)))
            }
        }
        s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true)
    }

    return streamengine.ParsedDecision{ContentSeen: contentSeen}
}

func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
    for _, evt := range events {
        if emitContent && evt.Content != "" {
            s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
        }
        if len(evt.ToolCallDeltas) > 0 {
            if !s.emitEarlyToolDeltas {
                continue
            }
            formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
            if len(formatted) == 0 {
                continue
            }
            s.toolCallsEmitted = true
            s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatted))
            s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
        }
        if len(evt.ToolCalls) > 0 {
            s.emitToolCallsDone(evt.ToolCalls)
        }
    }
}

func (s *responsesStreamRuntime) emitToolCallsDone(calls []util.ParsedToolCall) {
    if len(calls) == 0 {
        return
    }
    sig := toolCallListSignature(calls)
    if sig != "" && s.toolCallsDoneSigs[sig] {
        return
    }
    if sig != "" {
        s.toolCallsDoneSigs[sig] = true
    }
    formatted := formatFinalStreamToolCallsWithStableIDs(calls, s.streamToolCallIDs)
    if len(formatted) == 0 {
        return
    }
    s.toolCallsEmitted = true
    s.toolCallsDoneEmitted = true
    s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatted))
    s.emitFunctionCallDoneEvents(calls)
}

func (s *responsesStreamRuntime) ensureReasoningItemID() string {
    if strings.TrimSpace(s.reasoningItemID) != "" {
        return s.reasoningItemID
    }
    s.reasoningItemID = "rs_" + strings.ReplaceAll(uuid.NewString(), "-", "")
    return s.reasoningItemID
}

func (s *responsesStreamRuntime) ensureFunctionItemID(index int) string {
    if id, ok := s.streamFunctionIDs[index]; ok && strings.TrimSpace(id) != "" {
        return id
    }
    id := "fc_" + strings.ReplaceAll(uuid.NewString(), "-", "")
    s.streamFunctionIDs[index] = id
    return id
}

func (s *responsesStreamRuntime) ensureToolCallID(index int) string {
    if id, ok := s.streamToolCallIDs[index]; ok && strings.TrimSpace(id) != "" {
        return id
    }
    id := "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
    s.streamToolCallIDs[index] = id
    return id
}

func (s *responsesStreamRuntime) functionOutputBaseIndex() int {
    if strings.TrimSpace(s.thinking.String()) != "" {
        return 1
    }
    return 0
}

func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDelta) {
    for _, d := range deltas {
        if strings.TrimSpace(d.Arguments) == "" {
            continue
        }
        outputIndex := s.functionOutputBaseIndex() + d.Index
        itemID := s.ensureFunctionItemID(outputIndex)
        callID := s.ensureToolCallID(d.Index)
        s.sendEvent(
            "response.function_call_arguments.delta",
            openaifmt.BuildResponsesFunctionCallArgumentsDeltaPayload(s.responseID, itemID, outputIndex, callID, d.Arguments),
        )
    }
}

func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedToolCall) {
    base := s.functionOutputBaseIndex()
    for idx, tc := range calls {
        if strings.TrimSpace(tc.Name) == "" {
            continue
        }
        outputIndex := base + idx
        if s.functionDone[outputIndex] {
            continue
        }
        itemID := s.ensureFunctionItemID(outputIndex)
        callID := s.ensureToolCallID(idx)
        argsBytes, _ := json.Marshal(tc.Input)
        s.sendEvent(
            "response.function_call_arguments.done",
            openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, string(argsBytes)),
        )
        s.functionDone[outputIndex] = true
    }
}

func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any) {
    if obj == nil || len(s.streamToolCallIDs) == 0 {
        return
    }
    output, _ := obj["output"].([]any)
    if len(output) == 0 {
        return
    }
    indices := make([]int, 0, len(s.streamToolCallIDs))
    for idx := range s.streamToolCallIDs {
        indices = append(indices, idx)
    }
    sort.Ints(indices)
    ordered := make([]string, 0, len(indices))
    for _, idx := range indices {
        id := strings.TrimSpace(s.streamToolCallIDs[idx])
        if id == "" {
            continue
        }
        ordered = append(ordered, id)
    }
    if len(ordered) == 0 {
        return
    }

    functionIdx := 0
    for _, item := range output {
        m, _ := item.(map[string]any)
        if m == nil {
            continue
        }
        typ, _ := m["type"].(string)
        switch typ {
        case "function_call":
            if functionIdx < len(ordered) {
                m["call_id"] = ordered[functionIdx]
                functionIdx++
            }
        case "tool_calls":
            tcArr, _ := m["tool_calls"].([]any)
            for i, raw := range tcArr {
                tc, _ := raw.(map[string]any)
                if tc == nil {
                    continue
                }
                if i < len(ordered) {
                    tc["id"] = ordered[i]
                }
            }
        }
    }
}

func toolCallListSignature(calls []util.ParsedToolCall) string {
    if len(calls) == 0 {
        return ""
    }
    var b strings.Builder
    for i, tc := range calls {
        if i > 0 {
            b.WriteString("|")
        }
        b.WriteString(strings.TrimSpace(tc.Name))
        b.WriteString(":")
        args, _ := json.Marshal(tc.Input)
        b.Write(args)
    }
    return b.String()
}

@@ -45,17 +45,37 @@ func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T
    if len(output) == 0 {
        t.Fatalf("expected structured output entries, got %#v", responseObj["output"])
    }
    first, _ := output[0].(map[string]any)
    if first["type"] != "tool_calls" {
        t.Fatalf("expected first output type tool_calls, got %#v", first["type"])
    var firstToolWrapper map[string]any
    hasFunctionCall := false
    for _, item := range output {
        m, _ := item.(map[string]any)
        if m == nil {
            continue
        }
        if m["type"] == "function_call" {
            hasFunctionCall = true
        }
        if m["type"] == "tool_calls" && firstToolWrapper == nil {
            firstToolWrapper = m
        }
    }
    toolCalls, _ := first["tool_calls"].([]any)
    if !hasFunctionCall {
        t.Fatalf("expected at least one function_call item for responses compatibility, got %#v", responseObj["output"])
    }
    if firstToolWrapper == nil {
        t.Fatalf("expected a tool_calls wrapper item, got %#v", responseObj["output"])
    }
    toolCalls, _ := firstToolWrapper["tool_calls"].([]any)
    if len(toolCalls) == 0 {
        t.Fatalf("expected at least one tool_call in output, got %#v", first["tool_calls"])
        t.Fatalf("expected at least one tool_call in output, got %#v", firstToolWrapper["tool_calls"])
    }
    call0, _ := toolCalls[0].(map[string]any)
    if call0["name"] != "read_file" {
        t.Fatalf("unexpected tool call name: %#v", call0["name"])
    if call0["type"] != "function" {
        t.Fatalf("unexpected tool call type: %#v", call0["type"])
    }
    fn, _ := call0["function"].(map[string]any)
    if fn["name"] != "read_file" {
        t.Fatalf("unexpected tool call name: %#v", fn["name"])
    }
    if strings.Contains(outputText, `"tool_calls"`) {
        t.Fatalf("raw tool_calls JSON leaked in output_text: %q", outputText)
@@ -95,6 +115,314 @@ func TestHandleResponsesStreamIncompleteTailNotDuplicatedInCompletedOutputText(t
    }
}

func TestHandleResponsesStreamEmitsReasoningCompatEvents(t *testing.T) {
    h := &Handler{}
    req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
    rec := httptest.NewRecorder()

    b, _ := json.Marshal(map[string]any{
        "p": "response/thinking_content",
        "v": "thought",
    })
    streamBody := "data: " + string(b) + "\n" + "data: [DONE]\n"
    resp := &http.Response{
        StatusCode: http.StatusOK,
        Body:       io.NopCloser(strings.NewReader(streamBody)),
    }

    h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, nil)

    body := rec.Body.String()
    if !strings.Contains(body, "event: response.reasoning.delta") {
        t.Fatalf("expected response.reasoning.delta event, body=%s", body)
    }
    if !strings.Contains(body, "event: response.reasoning_text.delta") {
        t.Fatalf("expected response.reasoning_text.delta compatibility event, body=%s", body)
    }
    if !strings.Contains(body, "event: response.reasoning_text.done") {
        t.Fatalf("expected response.reasoning_text.done compatibility event, body=%s", body)
    }
}

func TestHandleResponsesStreamEmitsFunctionCallCompatEvents(t *testing.T) {
    h := &Handler{}
    req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
    rec := httptest.NewRecorder()

    sseLine := func(v string) string {
        b, _ := json.Marshal(map[string]any{
            "p": "response/content",
            "v": v,
        })
        return "data: " + string(b) + "\n"
    }

    streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n"
    resp := &http.Response{
        StatusCode: http.StatusOK,
        Body:       io.NopCloser(strings.NewReader(streamBody)),
    }

    h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
    body := rec.Body.String()
    if !strings.Contains(body, "event: response.function_call_arguments.delta") {
        t.Fatalf("expected response.function_call_arguments.delta compatibility event, body=%s", body)
    }
    if !strings.Contains(body, "event: response.function_call_arguments.done") {
        t.Fatalf("expected response.function_call_arguments.done compatibility event, body=%s", body)
    }
    donePayload, ok := extractSSEEventPayload(body, "response.function_call_arguments.done")
    if !ok {
        t.Fatalf("expected to parse response.function_call_arguments.done payload, body=%s", body)
    }
    if strings.TrimSpace(asString(donePayload["call_id"])) == "" {
        t.Fatalf("expected call_id in response.function_call_arguments.done payload, payload=%#v", donePayload)
    }
    if strings.TrimSpace(asString(donePayload["response_id"])) == "" {
        t.Fatalf("expected response_id in response.function_call_arguments.done payload, payload=%#v", donePayload)
    }
    doneCallID := strings.TrimSpace(asString(donePayload["call_id"]))
    if doneCallID == "" {
        t.Fatalf("expected non-empty call_id in done payload, payload=%#v", donePayload)
    }
    completed, ok := extractSSEEventPayload(body, "response.completed")
    if !ok {
        t.Fatalf("expected response.completed payload, body=%s", body)
    }
    responseObj, _ := completed["response"].(map[string]any)
    output, _ := responseObj["output"].([]any)
    if len(output) == 0 {
        t.Fatalf("expected non-empty output in response.completed, response=%#v", responseObj)
    }
    var completedCallID string
    for _, item := range output {
        m, _ := item.(map[string]any)
        if m == nil || m["type"] != "function_call" {
            continue
        }
        completedCallID = strings.TrimSpace(asString(m["call_id"]))
        if completedCallID != "" {
            break
        }
    }
    if completedCallID == "" {
        t.Fatalf("expected function_call.call_id in completed output, output=%#v", output)
    }
    if completedCallID != doneCallID {
        t.Fatalf("expected completed call_id to match stream done call_id, done=%q completed=%q", doneCallID, completedCallID)
    }
}

func TestHandleResponsesStreamDetectsToolCallsFromThinkingChannel(t *testing.T) {
    h := &Handler{}
    req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
    rec := httptest.NewRecorder()

    sseLine := func(path, v string) string {
        b, _ := json.Marshal(map[string]any{
            "p": path,
            "v": v,
        })
        return "data: " + string(b) + "\n"
    }

    streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n"
    resp := &http.Response{
        StatusCode: http.StatusOK,
        Body:       io.NopCloser(strings.NewReader(streamBody)),
    }

    h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"})

    body := rec.Body.String()
    if !strings.Contains(body, "event: response.reasoning_text.delta") {
        t.Fatalf("expected response.reasoning_text.delta event, body=%s", body)
    }
    if !strings.Contains(body, "event: response.function_call_arguments.done") {
        t.Fatalf("expected response.function_call_arguments.done event from thinking channel, body=%s", body)
    }
    if !strings.Contains(body, "event: response.output_tool_call.done") {
        t.Fatalf("expected response.output_tool_call.done event from thinking channel, body=%s", body)
    }
}

func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.T) {
    h := &Handler{}
    req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
    rec := httptest.NewRecorder()

    sseLine := func(v string) string {
        b, _ := json.Marshal(map[string]any{
            "p": "response/content",
            "v": v,
        })
        return "data: " + string(b) + "\n"
    }

    streamBody := sseLine(`{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) +
        sseLine(`{"name":"eval_javascript","input":{"code":"1+1"}}]}`) +
        "data: [DONE]\n"
    resp := &http.Response{
        StatusCode: http.StatusOK,
        Body:       io.NopCloser(strings.NewReader(streamBody)),
    }

    h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"})

    body := rec.Body.String()
    if !strings.Contains(body, "event: response.output_tool_call.done") {
        t.Fatalf("expected response.output_tool_call.done event, body=%s", body)
    }
    donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
    if len(donePayloads) != 2 {
        t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
    }

    seenNames := map[string]string{}
    for _, payload := range donePayloads {
        name := strings.TrimSpace(asString(payload["name"]))
        callID := strings.TrimSpace(asString(payload["call_id"]))
        args := strings.TrimSpace(asString(payload["arguments"]))
        if callID == "" {
            t.Fatalf("expected non-empty call_id in done payload: %#v", payload)
        }
        if strings.Contains(args, `}{"`) {
            t.Fatalf("unexpected concatenated arguments in done payload: %#v", payload)
        }
        if name == "search_webeval_javascript" {
            t.Fatalf("unexpected merged tool name in done payload: %#v", payload)
        }
        if name != "search_web" && name != "eval_javascript" {
            t.Fatalf("unexpected tool name in done payload: %#v", payload)
        }
        seenNames[name] = callID
    }
    if seenNames["search_web"] == "" || seenNames["eval_javascript"] == "" {
        t.Fatalf("expected done events for both tools, got %#v", seenNames)
    }
    if seenNames["search_web"] == seenNames["eval_javascript"] {
        t.Fatalf("expected distinct call_id per tool, got %#v", seenNames)
    }

    completed, ok := extractSSEEventPayload(body, "response.completed")
    if !ok {
        t.Fatalf("expected response.completed event, body=%s", body)
    }
    responseObj, _ := completed["response"].(map[string]any)
    output, _ := responseObj["output"].([]any)
    functionCallIDs := map[string]string{}
    for _, item := range output {
        m, _ := item.(map[string]any)
        if m == nil || m["type"] != "function_call" {
            continue
        }
        name := strings.TrimSpace(asString(m["name"]))
        callID := strings.TrimSpace(asString(m["call_id"]))
        if name != "" && callID != "" {
            functionCallIDs[name] = callID
        }
    }
    if functionCallIDs["search_web"] != seenNames["search_web"] {
        t.Fatalf("search_web call_id mismatch between done and completed: done=%q completed=%q", seenNames["search_web"], functionCallIDs["search_web"])
    }
    if functionCallIDs["eval_javascript"] != seenNames["eval_javascript"] {
        t.Fatalf("eval_javascript call_id mismatch between done and completed: done=%q completed=%q", seenNames["eval_javascript"], functionCallIDs["eval_javascript"])
    }
}

func TestHandleResponsesStreamMultiToolCallFromThinkingChannel(t *testing.T) {
    h := &Handler{}
    req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
    rec := httptest.NewRecorder()

    sseLine := func(path, v string) string {
        b, _ := json.Marshal(map[string]any{
            "p": path,
            "v": v,
        })
        return "data: " + string(b) + "\n"
    }

    streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) +
        sseLine("response/thinking_content", `{"name":"eval_javascript","input":{"code":"1+1"}}]}`) +
        "data: [DONE]\n"
    resp := &http.Response{
        StatusCode: http.StatusOK,
        Body:       io.NopCloser(strings.NewReader(streamBody)),
    }

    h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"search_web", "eval_javascript"})

    body := rec.Body.String()
    if !strings.Contains(body, "event: response.reasoning_text.delta") {
        t.Fatalf("expected reasoning stream events, body=%s", body)
    }
    donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
    if len(donePayloads) != 2 {
        t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
    }
    seen := map[string]bool{}
    for _, payload := range donePayloads {
        name := strings.TrimSpace(asString(payload["name"]))
        if name == "search_webeval_javascript" {
            t.Fatalf("unexpected merged tool name in thinking channel done payload: %#v", payload)
        }
        if name != "search_web" && name != "eval_javascript" {
            t.Fatalf("unexpected tool name in thinking channel done payload: %#v", payload)
        }
        args := strings.TrimSpace(asString(payload["arguments"]))
        if strings.Contains(args, `}{"`) {
            t.Fatalf("unexpected concatenated arguments in thinking channel done payload: %#v", payload)
        }
        seen[name] = true
    }
    if !seen["search_web"] || !seen["eval_javascript"] {
        t.Fatalf("expected both tools in thinking channel done events, got %#v", seen)
    }
}

func TestHandleResponsesStreamCompletedFollowsChatToolCallSemantics(t *testing.T) {
    h := &Handler{}
    req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
    rec := httptest.NewRecorder()

    sseLine := func(v string) string {
        b, _ := json.Marshal(map[string]any{
            "p": "response/content",
            "v": v,
        })
        return "data: " + string(b) + "\n"
    }

    streamBody := sseLine("我来调用工具\n") +
        sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) +
        "data: [DONE]\n"
    resp := &http.Response{
        StatusCode: http.StatusOK,
        Body:       io.NopCloser(strings.NewReader(streamBody)),
    }

    h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})

    completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
    if !ok {
        t.Fatalf("expected response.completed event, body=%s", rec.Body.String())
    }
    responseObj, _ := completed["response"].(map[string]any)
    output, _ := responseObj["output"].([]any)
    hasFunctionCall := false
    for _, item := range output {
        m, _ := item.(map[string]any)
        if m != nil && m["type"] == "function_call" {
            hasFunctionCall = true
            break
        }
    }
    if !hasFunctionCall {
        t.Fatalf("expected completed output to include function_call when mixed prose contains tool_calls payload, output=%#v", output)
    }
}

func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
    scanner := bufio.NewScanner(strings.NewReader(body))
    matched := false
@@ -120,3 +448,30 @@ func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
    }
    return nil, false
}

func extractAllSSEEventPayloads(body, targetEvent string) []map[string]any {
    scanner := bufio.NewScanner(strings.NewReader(body))
    matched := false
    out := make([]map[string]any, 0, 2)
    for scanner.Scan() {
        line := strings.TrimSpace(scanner.Text())
        if strings.HasPrefix(line, "event: ") {
            evt := strings.TrimSpace(strings.TrimPrefix(line, "event: "))
            matched = evt == targetEvent
            continue
        }
        if !matched || !strings.HasPrefix(line, "data: ") {
            continue
        }
        raw := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
        if raw == "" || raw == "[DONE]" {
            continue
        }
        var payload map[string]any
        if err := json.Unmarshal([]byte(raw), &payload); err != nil {
            continue
        }
        out = append(out, payload)
    }
    return out
}
@@ -8,7 +8,7 @@ import (
    "ds2api/internal/util"
)

func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.StandardRequest, error) {
func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID string) (util.StandardRequest, error) {
    model, _ := req["model"].(string)
    messagesRaw, _ := req["messages"].([]any)
    if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 {
@@ -23,7 +23,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.St
    if responseModel == "" {
        responseModel = resolvedModel
    }
    finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
    finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID)
    passThrough := collectOpenAIChatPassThrough(req)

    return util.StandardRequest{
@@ -41,7 +41,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.St
    }, nil
}

func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any) (util.StandardRequest, error) {
func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, traceID string) (util.StandardRequest, error) {
    model, _ := req["model"].(string)
    model = strings.TrimSpace(model)
    if model == "" {
@@ -67,7 +67,7 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any) (ut
    if len(messagesRaw) == 0 {
        return util.StandardRequest{}, fmt.Errorf("Request must include 'input' or 'messages'.")
    }
    finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
    finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"], traceID)
    passThrough := collectOpenAIChatPassThrough(req)

    return util.StandardRequest{

@@ -22,7 +22,7 @@ func TestNormalizeOpenAIChatRequest(t *testing.T) {
        "temperature": 0.3,
        "stream":      true,
    }
    n, err := normalizeOpenAIChatRequest(store, req)
    n, err := normalizeOpenAIChatRequest(store, req, "")
    if err != nil {
        t.Fatalf("normalize failed: %v", err)
    }
@@ -47,7 +47,7 @@ func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) {
        "input":        "ping",
        "instructions": "system",
    }
    n, err := normalizeOpenAIResponsesRequest(store, req)
    n, err := normalizeOpenAIResponsesRequest(store, req, "")
    if err != nil {
        t.Fatalf("normalize failed: %v", err)
    }
185 internal/adapter/openai/stream_status_test.go Normal file
@@ -0,0 +1,185 @@
package openai

import (
    "context"
    "encoding/json"
    "io"
    "net/http"
    "net/http/httptest"
    "strings"
    "testing"

    "github.com/go-chi/chi/v5"
    chimw "github.com/go-chi/chi/v5/middleware"

    "ds2api/internal/auth"
)

type streamStatusAuthStub struct{}

func (streamStatusAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) {
    return &auth.RequestAuth{
        UseConfigToken: false,
        DeepSeekToken:  "direct-token",
        CallerID:       "caller:test",
        TriedAccounts:  map[string]bool{},
    }, nil
}

func (streamStatusAuthStub) DetermineCaller(_ *http.Request) (*auth.RequestAuth, error) {
    return &auth.RequestAuth{
        UseConfigToken: false,
        DeepSeekToken:  "direct-token",
        CallerID:       "caller:test",
        TriedAccounts:  map[string]bool{},
    }, nil
}

func (streamStatusAuthStub) Release(_ *auth.RequestAuth) {}

type streamStatusDSStub struct {
    resp *http.Response
}

func (m streamStatusDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
    return "session-id", nil
}

func (m streamStatusDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
    return "pow", nil
}

func (m streamStatusDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
    return m.resp, nil
}

func makeOpenAISSEHTTPResponse(lines ...string) *http.Response {
    body := strings.Join(lines, "\n")
    if !strings.HasSuffix(body, "\n") {
        body += "\n"
    }
    return &http.Response{
        StatusCode: http.StatusOK,
        Header:     make(http.Header),
        Body:       io.NopCloser(strings.NewReader(body)),
    }
}

func captureStatusMiddleware(statuses *[]int) func(http.Handler) http.Handler {
    return func(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            ww := chimw.NewWrapResponseWriter(w, r.ProtoMajor)
            next.ServeHTTP(ww, r)
            *statuses = append(*statuses, ww.Status())
        })
    }
}

func TestChatCompletionsStreamStatusCapturedAs200(t *testing.T) {
    statuses := make([]int, 0, 1)
    h := &Handler{
        Store: mockOpenAIConfig{wideInput: true},
        Auth:  streamStatusAuthStub{},
        DS:    streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, "data: [DONE]")},
    }
    r := chi.NewRouter()
    r.Use(captureStatusMiddleware(&statuses))
    RegisterRoutes(r, h)

    reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":"hi"}],"stream":true}`
    req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
    req.Header.Set("Authorization", "Bearer direct-token")
    req.Header.Set("Content-Type", "application/json")
    rec := httptest.NewRecorder()
    r.ServeHTTP(rec, req)

    if rec.Code != http.StatusOK {
        t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
    }
    if len(statuses) != 1 {
        t.Fatalf("expected one captured status, got %d", len(statuses))
    }
    if statuses[0] != http.StatusOK {
        t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0])
    }
}

func TestResponsesStreamStatusCapturedAs200(t *testing.T) {
    statuses := make([]int, 0, 1)
    h := &Handler{
        Store: mockOpenAIConfig{wideInput: true},
        Auth:  streamStatusAuthStub{},
        DS:    streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, "data: [DONE]")},
    }
    r := chi.NewRouter()
    r.Use(captureStatusMiddleware(&statuses))
    RegisterRoutes(r, h)

    reqBody := `{"model":"deepseek-chat","input":"hi","stream":true}`
    req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody))
    req.Header.Set("Authorization", "Bearer direct-token")
    req.Header.Set("Content-Type", "application/json")
    rec := httptest.NewRecorder()
    r.ServeHTTP(rec, req)

    if rec.Code != http.StatusOK {
        t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
    }
    if len(statuses) != 1 {
        t.Fatalf("expected one captured status, got %d", len(statuses))
    }
    if statuses[0] != http.StatusOK {
        t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0])
    }
}

func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) {
    statuses := make([]int, 0, 1)
    content, _ := json.Marshal(map[string]any{
        "p": "response/content",
        "v": "我来调用工具\n{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}",
    })
    h := &Handler{
        Store: mockOpenAIConfig{wideInput: true},
        Auth:  streamStatusAuthStub{},
        DS:    streamStatusDSStub{resp: makeOpenAISSEHTTPResponse("data: "+string(content), "data: [DONE]")},
    }
    r := chi.NewRouter()
    r.Use(captureStatusMiddleware(&statuses))
    RegisterRoutes(r, h)

    reqBody := `{"model":"deepseek-chat","input":"请调用工具","tools":[{"type":"function","function":{"name":"read_file","description":"read","parameters":{"type":"object","properties":{"path":{"type":"string"}}}}}],"stream":false}`
    req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody))
    req.Header.Set("Authorization", "Bearer direct-token")
    req.Header.Set("Content-Type", "application/json")
    rec := httptest.NewRecorder()
    r.ServeHTTP(rec, req)

    if rec.Code != http.StatusOK {
        t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
    }
    if len(statuses) != 1 || statuses[0] != http.StatusOK {
        t.Fatalf("expected captured status 200, got %#v", statuses)
    }

    var out map[string]any
    if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
        t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String())
    }
    outputText, _ := out["output_text"].(string)
    if outputText != "" {
        t.Fatalf("expected output_text hidden for tool call payload, got %q", outputText)
    }
    output, _ := out["output"].([]any)
    hasFunctionCall := false
    for _, item := range output {
        m, _ := item.(map[string]any)
        if m != nil && m["type"] == "function_call" {
            hasFunctionCall = true
            break
        }
    }
    if !hasFunctionCall {
        t.Fatalf("expected function_call output item, got %#v", output)
    }
}
@@ -11,6 +11,7 @@ type toolStreamSieveState struct {
    capture        strings.Builder
    capturing      bool
    recentTextTail string
    disableDeltas  bool
    toolNameSent   bool
    toolName       string
    toolArgsStart  int
@@ -35,6 +36,7 @@ const toolSieveCaptureLimit = 8 * 1024
const toolSieveContextTailLimit = 256

func (s *toolStreamSieveState) resetIncrementalToolState() {
    s.disableDeltas = false
    s.toolNameSent = false
    s.toolName = ""
    s.toolArgsStart = -1
@@ -239,17 +241,8 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
    }
    parsed := util.ParseStandaloneToolCalls(obj, toolNames)
    if len(parsed) == 0 {
        if state.toolNameSent {
            return prefixPart, nil, suffixPart, true
        }
        return captured, nil, "", true
    }
    if state.toolNameSent {
        if len(parsed) > 1 {
            return prefixPart, parsed[1:], suffixPart, true
        }
        return prefixPart, nil, suffixPart, true
    }
    return prefixPart, parsed, suffixPart, true
}

@@ -296,6 +289,9 @@ func extractJSONObjectFrom(text string, start int) (string, int, bool) {
}

func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
    if state.disableDeltas {
        return nil
    }
    captured := state.capture.String()
    if captured == "" {
        return nil
@@ -312,6 +308,16 @@ func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
    if insideCodeFence(state.recentTextTail + captured[:start]) {
        return nil
    }
    certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx)
    if hasMultiple {
        state.disableDeltas = true
        return nil
    }
    if !certainSingle {
        // In uncertain phases (e.g. first call arrived but array not closed yet),
        // avoid speculative deltas and wait for final parsed tool_calls payload.
        return nil
    }
    callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
    if !ok {
        return nil
@@ -363,6 +369,68 @@ func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
    return deltas
}

func classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) {
    arrStart, ok := findToolCallsArrayStart(text, keyIdx)
    if !ok {
        return false, false
    }
    i := skipSpaces(text, arrStart+1)
    if i >= len(text) || text[i] != '{' {
        return false, false
    }
    count := 0
    depth := 0
    quote := byte(0)
    escaped := false
    for ; i < len(text); i++ {
        ch := text[i]
        if quote != 0 {
            if escaped {
                escaped = false
                continue
            }
            if ch == '\\' {
                escaped = true
                continue
            }
            if ch == quote {
                quote = 0
            }
            continue
        }
        if ch == '"' || ch == '\'' {
            quote = ch
            continue
        }
        if ch == '{' {
            if depth == 0 {
                count++
                if count > 1 {
                    return false, true
                }
            }
            depth++
            continue
        }
        if ch == '}' {
            if depth > 0 {
                depth--
            }
            continue
        }
        if ch == ',' && depth == 0 {
            // top-level separator means at least one more tool call exists
            // (or is expected). Treat as multi-call and stop incremental deltas.
            return false, true
        }
        if ch == ']' && depth == 0 {
            return count == 1, false
        }
    }
    // array not closed yet: still uncertain whether more calls will appear
    return false, false
}

func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) {
    arrStart, ok := findToolCallsArrayStart(text, keyIdx)
    if !ok {
21
internal/adapter/openai/trace.go
Normal file
@@ -0,0 +1,21 @@
package openai

import (
	"net/http"
	"strings"

	"github.com/go-chi/chi/v5/middleware"
)

func requestTraceID(r *http.Request) string {
	if r == nil {
		return ""
	}
	if q := strings.TrimSpace(r.URL.Query().Get("__trace_id")); q != "" {
		return q
	}
	if h := strings.TrimSpace(r.Header.Get("X-Ds2-Test-Trace")); h != "" {
		return h
	}
	return strings.TrimSpace(middleware.GetReqID(r.Context()))
}
47
internal/adapter/openai/trace_test.go
Normal file
@@ -0,0 +1,47 @@
package openai

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/go-chi/chi/v5/middleware"
)

func traceIDViaMiddleware(req *http.Request) string {
	if req == nil {
		return requestTraceID(nil)
	}
	var got string
	h := middleware.RequestID(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
		got = requestTraceID(r)
	}))
	h.ServeHTTP(httptest.NewRecorder(), req)
	return got
}

func TestRequestTraceIDPriority(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions?__trace_id=query-trace", nil)
	req.Header.Set("X-Ds2-Test-Trace", "header-trace")
	got := traceIDViaMiddleware(req)
	if got != "query-trace" {
		t.Fatalf("expected query trace id to win, got %q", got)
	}
}

func TestRequestTraceIDHeaderFallback(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil)
	req.Header.Set("X-Ds2-Test-Trace", "header-trace")
	got := traceIDViaMiddleware(req)
	if got != "header-trace" {
		t.Fatalf("expected header trace id to win when query missing, got %q", got)
	}
}

func TestRequestTraceIDReqIDFallback(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil)
	got := traceIDViaMiddleware(req)
	if got == "" {
		t.Fatal("expected middleware request id fallback to be non-empty")
	}
}
@@ -56,7 +56,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Request) {
		writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
		return
	}
	stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
	stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
	if err != nil {
		writeOpenAIError(w, http.StatusBadRequest, err.Error())
		return
@@ -36,5 +36,7 @@ func RegisterRoutes(r chi.Router, h *Handler) {
		pr.Post("/vercel/sync", h.syncVercel)
		pr.Get("/vercel/status", h.vercelStatus)
		pr.Get("/export", h.exportConfig)
		pr.Get("/dev/captures", h.getDevCaptures)
		pr.Delete("/dev/captures", h.clearDevCaptures)
	})
}
26
internal/admin/handler_dev_capture.go
Normal file
@@ -0,0 +1,26 @@
package admin

import (
	"net/http"

	"ds2api/internal/devcapture"
)

func (h *Handler) getDevCaptures(w http.ResponseWriter, _ *http.Request) {
	store := devcapture.Global()
	writeJSON(w, http.StatusOK, map[string]any{
		"enabled":        store.Enabled(),
		"limit":          store.Limit(),
		"max_body_bytes": store.MaxBodyBytes(),
		"items":          store.Snapshot(),
	})
}

func (h *Handler) clearDevCaptures(w http.ResponseWriter, _ *http.Request) {
	store := devcapture.Global()
	store.Clear()
	writeJSON(w, http.StatusOK, map[string]any{
		"success": true,
		"detail":  "capture logs cleared",
	})
}
45
internal/admin/handler_dev_capture_test.go
Normal file
@@ -0,0 +1,45 @@
package admin

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"
)

func TestGetDevCapturesShape(t *testing.T) {
	h := &Handler{}
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/admin/dev/captures", nil)
	h.getDevCaptures(rec, req)
	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
	}
	var out map[string]any
	if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
		t.Fatalf("decode failed: %v", err)
	}
	if _, ok := out["enabled"]; !ok {
		t.Fatalf("expected enabled field, got %#v", out)
	}
	if _, ok := out["items"]; !ok {
		t.Fatalf("expected items field, got %#v", out)
	}
}

func TestClearDevCapturesShape(t *testing.T) {
	h := &Handler{}
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodDelete, "/admin/dev/captures", nil)
	h.clearDevCaptures(rec, req)
	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
	}
	var out map[string]any
	if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
		t.Fatalf("decode failed: %v", err)
	}
	if out["success"] != true {
		t.Fatalf("expected success=true, got %#v", out)
	}
}
@@ -187,7 +187,12 @@ func extractCallerToken(req *http.Request) string {
			return token
		}
	}
	return strings.TrimSpace(req.Header.Get("x-api-key"))
	if key := strings.TrimSpace(req.Header.Get("x-api-key")); key != "" {
		return key
	}
	// Gemini AI Studio compatibility: allow query key fallback only when no
	// header-based credential is present.
	return strings.TrimSpace(req.URL.Query().Get("key"))
}

func callerTokenID(token string) string {
@@ -114,6 +114,40 @@ func TestDetermineMissingToken(t *testing.T) {
	}
}

func TestDetermineWithQueryKeyUsesDirectToken(t *testing.T) {
	r := newTestResolver(t)
	req, _ := http.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent?key=direct-query-key", nil)

	a, err := r.Determine(req)
	if err != nil {
		t.Fatalf("determine failed: %v", err)
	}
	if a.UseConfigToken {
		t.Fatalf("expected direct token mode")
	}
	if a.DeepSeekToken != "direct-query-key" {
		t.Fatalf("unexpected token: %q", a.DeepSeekToken)
	}
}

func TestDetermineHeaderTokenPrecedenceOverQueryKey(t *testing.T) {
	r := newTestResolver(t)
	req, _ := http.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent?key=query-key", nil)
	req.Header.Set("x-api-key", "managed-key")

	a, err := r.Determine(req)
	if err != nil {
		t.Fatalf("determine failed: %v", err)
	}
	defer r.Release(a)
	if !a.UseConfigToken {
		t.Fatalf("expected managed key mode from header token")
	}
	if a.AccountID == "" {
		t.Fatalf("expected managed account to be acquired")
	}
}

func TestDetermineCallerMissingToken(t *testing.T) {
	r := newTestResolver(t)
	req, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil)
@@ -16,6 +16,7 @@ import (
	"ds2api/internal/auth"
	"ds2api/internal/config"
	trans "ds2api/internal/deepseek/transport"
	"ds2api/internal/devcapture"
	"ds2api/internal/util"

	"github.com/andybalholm/brotli"
@@ -27,6 +28,7 @@ var intFrom = util.IntFrom
type Client struct {
	Store    *config.Store
	Auth     *auth.Resolver
	capture  *devcapture.Store
	regular  trans.Doer
	stream   trans.Doer
	fallback *http.Client
@@ -39,6 +41,7 @@ func NewClient(store *config.Store, resolver *auth.Resolver) *Client {
	return &Client{
		Store:    store,
		Auth:     resolver,
		capture:  devcapture.Global(),
		regular:  trans.New(60 * time.Second),
		stream:   trans.New(0),
		fallback: &http.Client{Timeout: 60 * time.Second},
@@ -179,6 +182,7 @@ func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payload
	}
	headers := c.authHeaders(a.DeepSeekToken)
	headers["x-ds-pow-response"] = powResp
	captureSession := c.capture.Start("deepseek_completion", DeepSeekCompletionURL, a.AccountID, payload)
	attempts := 0
	for attempts < maxAttempts {
		resp, err := c.streamPost(ctx, DeepSeekCompletionURL, headers, payload)
@@ -188,8 +192,14 @@ func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payload
			continue
		}
		if resp.StatusCode == http.StatusOK {
			if captureSession != nil {
				resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
			}
			return resp, nil
		}
		if captureSession != nil {
			resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
		}
		_ = resp.Body.Close()
		attempts++
		time.Sleep(time.Second)
259
internal/devcapture/store.go
Normal file
@@ -0,0 +1,259 @@
package devcapture

import (
	"encoding/json"
	"fmt"
	"io"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/google/uuid"
)

const (
	defaultLimit        = 5
	defaultMaxBodyBytes = 2 * 1024 * 1024
	maxLimit            = 50
)

type Entry struct {
	ID                string `json:"id"`
	CreatedAt         int64  `json:"created_at"`
	Label             string `json:"label"`
	URL               string `json:"url"`
	AccountID         string `json:"account_id,omitempty"`
	StatusCode        int    `json:"status_code"`
	RequestBody       string `json:"request_body"`
	ResponseBody      string `json:"response_body"`
	ResponseTruncated bool   `json:"response_truncated"`
}

type Store struct {
	mu           sync.Mutex
	enabled      bool
	limit        int
	maxBodyBytes int
	items        []Entry
}

type Session struct {
	store      *Store
	id         string
	createdAt  int64
	label      string
	url        string
	accountID  string
	requestRaw string
}

type captureBody struct {
	rc         io.ReadCloser
	s          *Session
	statusCode int
	buf        strings.Builder
	truncated  bool
	finalized  bool
}

var (
	globalOnce sync.Once
	globalInst *Store
)

func Global() *Store {
	globalOnce.Do(func() {
		globalInst = NewFromEnv()
	})
	return globalInst
}

func NewFromEnv() *Store {
	enabled := !isVercelRuntime()
	if raw, ok := os.LookupEnv("DS2API_DEV_PACKET_CAPTURE"); ok {
		enabled = parseBool(raw)
	}
	limit := parseIntWithDefault(os.Getenv("DS2API_DEV_PACKET_CAPTURE_LIMIT"), defaultLimit)
	if limit < 1 {
		limit = defaultLimit
	}
	if limit > maxLimit {
		limit = maxLimit
	}
	maxBodyBytes := parseIntWithDefault(os.Getenv("DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES"), defaultMaxBodyBytes)
	if maxBodyBytes < 1024 {
		maxBodyBytes = defaultMaxBodyBytes
	}
	return &Store{
		enabled:      enabled,
		limit:        limit,
		maxBodyBytes: maxBodyBytes,
		items:        make([]Entry, 0, limit),
	}
}

func isVercelRuntime() bool {
	return strings.TrimSpace(os.Getenv("VERCEL")) != "" || strings.TrimSpace(os.Getenv("NOW_REGION")) != ""
}

func (s *Store) Enabled() bool {
	if s == nil {
		return false
	}
	return s.enabled
}

func (s *Store) Limit() int {
	if s == nil {
		return defaultLimit
	}
	return s.limit
}

func (s *Store) MaxBodyBytes() int {
	if s == nil {
		return defaultMaxBodyBytes
	}
	return s.maxBodyBytes
}

func (s *Store) Snapshot() []Entry {
	if s == nil {
		return nil
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	out := make([]Entry, len(s.items))
	copy(out, s.items)
	return out
}

func (s *Store) Clear() {
	if s == nil {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.items = s.items[:0]
}

func (s *Store) Start(label, url, accountID string, requestPayload any) *Session {
	if s == nil || !s.enabled {
		return nil
	}
	return &Session{
		store:      s,
		id:         "cap_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
		createdAt:  time.Now().Unix(),
		label:      strings.TrimSpace(label),
		url:        strings.TrimSpace(url),
		accountID:  strings.TrimSpace(accountID),
		requestRaw: marshalPayload(requestPayload),
	}
}

func (s *Session) WrapBody(rc io.ReadCloser, statusCode int) io.ReadCloser {
	if s == nil || rc == nil {
		return rc
	}
	return &captureBody{
		rc:         rc,
		s:          s,
		statusCode: statusCode,
	}
}

func (c *captureBody) Read(p []byte) (int, error) {
	n, err := c.rc.Read(p)
	if n > 0 {
		c.append(string(p[:n]))
	}
	if err == io.EOF {
		c.finalize()
	}
	return n, err
}

func (c *captureBody) Close() error {
	err := c.rc.Close()
	c.finalize()
	return err
}

func (c *captureBody) append(chunk string) {
	if chunk == "" || c.s == nil || c.s.store == nil {
		return
	}
	maxLen := c.s.store.maxBodyBytes
	current := c.buf.Len()
	if current >= maxLen {
		c.truncated = true
		return
	}
	remain := maxLen - current
	if len(chunk) > remain {
		c.buf.WriteString(chunk[:remain])
		c.truncated = true
		return
	}
	c.buf.WriteString(chunk)
}

func (c *captureBody) finalize() {
	if c.finalized || c.s == nil || c.s.store == nil {
		return
	}
	c.finalized = true
	entry := Entry{
		ID:                c.s.id,
		CreatedAt:         c.s.createdAt,
		Label:             c.s.label,
		URL:               c.s.url,
		AccountID:         c.s.accountID,
		StatusCode:        c.statusCode,
		RequestBody:       c.s.requestRaw,
		ResponseBody:      c.buf.String(),
		ResponseTruncated: c.truncated,
	}
	c.s.store.push(entry)
}

func (s *Store) push(entry Entry) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.items = append([]Entry{entry}, s.items...)
	if len(s.items) > s.limit {
		s.items = s.items[:s.limit]
	}
}

func marshalPayload(v any) string {
	b, err := json.Marshal(v)
	if err != nil {
		return fmt.Sprintf("%v", v)
	}
	return string(b)
}

func parseBool(v string) bool {
	switch strings.ToLower(strings.TrimSpace(v)) {
	case "1", "true", "yes", "on":
		return true
	default:
		return false
	}
}

func parseIntWithDefault(raw string, d int) int {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return d
	}
	n, err := strconv.Atoi(raw)
	if err != nil {
		return d
	}
	return n
}
55
internal/devcapture/store_test.go
Normal file
@@ -0,0 +1,55 @@
package devcapture

import (
	"io"
	"strings"
	"testing"
)

func TestStorePushKeepsNewestWithinLimit(t *testing.T) {
	s := &Store{enabled: true, limit: 2, maxBodyBytes: 1024}
	for i := 0; i < 3; i++ {
		session := s.Start("test", "http://x", "", map[string]any{"seq": i})
		if session == nil {
			t.Fatal("expected session")
		}
		rc := session.WrapBody(io.NopCloser(strings.NewReader("ok")), 200)
		_, _ = io.ReadAll(rc)
		_ = rc.Close()
	}
	items := s.Snapshot()
	if len(items) != 2 {
		t.Fatalf("expected 2 items, got %d", len(items))
	}
	if !strings.Contains(items[0].RequestBody, `"seq":2`) {
		t.Fatalf("expected newest first, got %#v", items[0].RequestBody)
	}
	if !strings.Contains(items[1].RequestBody, `"seq":1`) {
		t.Fatalf("expected second newest, got %#v", items[1].RequestBody)
	}
}

func TestWrapBodyTruncatesByLimit(t *testing.T) {
	s := &Store{enabled: true, limit: 5, maxBodyBytes: 4}
	session := s.Start("test", "http://x", "acc1", map[string]any{"x": 1})
	if session == nil {
		t.Fatal("expected session")
	}
	rc := session.WrapBody(io.NopCloser(strings.NewReader("abcdef")), 200)
	_, _ = io.ReadAll(rc)
	_ = rc.Close()

	items := s.Snapshot()
	if len(items) != 1 {
		t.Fatalf("expected 1 item, got %d", len(items))
	}
	if items[0].ResponseBody != "abcd" {
		t.Fatalf("expected truncated body, got %q", items[0].ResponseBody)
	}
	if !items[0].ResponseTruncated {
		t.Fatal("expected truncated flag true")
	}
	if items[0].AccountID != "acc1" {
		t.Fatalf("expected account id, got %q", items[0].AccountID)
	}
}
@@ -1,6 +1,7 @@
package openai

import (
	"encoding/json"
	"strings"
	"time"

@@ -43,36 +44,45 @@ func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
}

func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
	// Align responses tool-call semantics with chat/completions:
	// mixed prose + tool_call payloads should still be interpreted as tool calls.
	detected := util.ParseToolCalls(finalText, toolNames)
	if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
		detected = util.ParseToolCalls(finalThinking, toolNames)
	}
	exposedOutputText := finalText
	output := make([]any, 0, 2)
	if len(detected) > 0 {
		exposedOutputText = ""
		toolCalls := make([]any, 0, len(detected))
		for _, tc := range detected {
			toolCalls = append(toolCalls, map[string]any{
				"type":      "tool_call",
				"name":      tc.Name,
				"arguments": tc.Input,
		if strings.TrimSpace(finalThinking) != "" {
			output = append(output, map[string]any{
				"type": "reasoning",
				"text": finalThinking,
			})
		}
		formatted := util.FormatOpenAIToolCalls(detected)
		output = append(output, toResponsesFunctionCallItems(formatted)...)
		output = append(output, map[string]any{
			"type":       "tool_calls",
			"tool_calls": toolCalls,
			"tool_calls": formatted,
		})
	} else {
		content := []any{
			map[string]any{
				"type": "output_text",
				"text": finalText,
			},
		}
		content := make([]any, 0, 2)
		if finalThinking != "" {
			content = append([]any{map[string]any{
				"type": "reasoning",
				"text": finalThinking,
			}}, content...)
		}
		if strings.TrimSpace(finalText) != "" {
			content = append(content, map[string]any{
				"type": "output_text",
				"text": finalText,
			})
		}
		if strings.TrimSpace(finalText) == "" && strings.TrimSpace(finalThinking) != "" {
			exposedOutputText = finalThinking
		}
		output = append(output, map[string]any{
			"type": "message",
			"id":   "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
@@ -100,6 +110,54 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
	}
}

func toResponsesFunctionCallItems(toolCalls []map[string]any) []any {
	if len(toolCalls) == 0 {
		return nil
	}
	out := make([]any, 0, len(toolCalls))
	for _, tc := range toolCalls {
		callID, _ := tc["id"].(string)
		if strings.TrimSpace(callID) == "" {
			callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
		}
		name := ""
		args := "{}"
		if fn, ok := tc["function"].(map[string]any); ok {
			if n, _ := fn["name"].(string); strings.TrimSpace(n) != "" {
				name = n
			}
			if a, _ := fn["arguments"].(string); strings.TrimSpace(a) != "" {
				args = a
			}
		}
		out = append(out, map[string]any{
			"id":        "fc_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
			"type":      "function_call",
			"call_id":   callID,
			"name":      name,
			"arguments": normalizeJSONString(args),
			"status":    "completed",
		})
	}
	return out
}

func normalizeJSONString(raw string) string {
	s := strings.TrimSpace(raw)
	if s == "" {
		return "{}"
	}
	var v any
	if err := json.Unmarshal([]byte(s), &v); err != nil {
		return raw
	}
	b, err := json.Marshal(v)
	if err != nil {
		return raw
	}
	return string(b)
}

func BuildChatStreamDeltaChoice(index int, delta map[string]any) map[string]any {
	return map[string]any{
		"delta": delta,
@@ -145,49 +203,105 @@ func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any

func BuildResponsesCreatedPayload(responseID, model string) map[string]any {
	return map[string]any{
		"type":   "response.created",
		"id":     responseID,
		"object": "response",
		"model":  model,
		"status": "in_progress",
		"type":        "response.created",
		"id":          responseID,
		"response_id": responseID,
		"object":      "response",
		"model":       model,
		"status":      "in_progress",
	}
}

func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any {
	return map[string]any{
		"type":  "response.output_text.delta",
		"id":    responseID,
		"delta": delta,
		"type":        "response.output_text.delta",
		"id":          responseID,
		"response_id": responseID,
		"delta":       delta,
	}
}

func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
	return map[string]any{
		"type":  "response.reasoning.delta",
		"id":    responseID,
		"delta": delta,
		"type":        "response.reasoning.delta",
		"id":          responseID,
		"response_id": responseID,
		"delta":       delta,
	}
}

func BuildResponsesReasoningTextDeltaPayload(responseID, itemID string, outputIndex, contentIndex int, delta string) map[string]any {
	return map[string]any{
		"type":          "response.reasoning_text.delta",
		"id":            responseID,
		"response_id":   responseID,
		"item_id":       itemID,
		"output_index":  outputIndex,
		"content_index": contentIndex,
		"delta":         delta,
	}
}

func BuildResponsesReasoningTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any {
	return map[string]any{
		"type":          "response.reasoning_text.done",
		"id":            responseID,
		"response_id":   responseID,
		"item_id":       itemID,
		"output_index":  outputIndex,
		"content_index": contentIndex,
		"text":          text,
	}
}

func BuildResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any {
	return map[string]any{
		"type":       "response.output_tool_call.delta",
		"id":         responseID,
		"tool_calls": toolCalls,
		"type":        "response.output_tool_call.delta",
		"id":          responseID,
		"response_id": responseID,
		"tool_calls":  toolCalls,
	}
}

func BuildResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any {
	return map[string]any{
		"type":       "response.output_tool_call.done",
		"id":         responseID,
		"tool_calls": toolCalls,
		"type":        "response.output_tool_call.done",
		"id":          responseID,
		"response_id": responseID,
		"tool_calls":  toolCalls,
	}
}

func BuildResponsesFunctionCallArgumentsDeltaPayload(responseID, itemID string, outputIndex int, callID, delta string) map[string]any {
	return map[string]any{
		"type":         "response.function_call_arguments.delta",
		"id":           responseID,
		"response_id":  responseID,
		"item_id":      itemID,
		"output_index": outputIndex,
		"call_id":      callID,
		"delta":        delta,
	}
}

func BuildResponsesFunctionCallArgumentsDonePayload(responseID, itemID string, outputIndex int, callID, name, arguments string) map[string]any {
	return map[string]any{
		"type":         "response.function_call_arguments.done",
		"id":           responseID,
		"response_id":  responseID,
		"item_id":      itemID,
		"output_index": outputIndex,
		"call_id":      callID,
		"name":         name,
		"arguments":    normalizeJSONString(arguments),
	}
}

func BuildResponsesCompletedPayload(response map[string]any) map[string]any {
	responseID, _ := response["id"].(string)
	return map[string]any{
		"type":     "response.completed",
		"response": response,
		"type":        "response.completed",
		"response_id": responseID,
		"response":    response,
	}
}
181
internal/format/openai/render_test.go
Normal file
@@ -0,0 +1,181 @@
package openai

import (
	"encoding/json"
	"testing"
)

func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
	obj := BuildResponseObject(
		"resp_test",
		"gpt-4o",
		"prompt",
		"",
		`{"tool_calls":[{"name":"search","input":{"q":"golang"}}]}`,
		[]string{"search"},
	)

	outputText, _ := obj["output_text"].(string)
	if outputText != "" {
		t.Fatalf("expected output_text to be hidden for tool calls, got %q", outputText)
	}

	output, _ := obj["output"].([]any)
	if len(output) != 2 {
		t.Fatalf("expected function_call + tool_calls wrapper, got %#v", obj["output"])
	}

	first, _ := output[0].(map[string]any)
	if first["type"] != "function_call" {
		t.Fatalf("expected first output item type function_call, got %#v", first["type"])
	}
	if first["call_id"] == "" {
		t.Fatalf("expected function_call item to have call_id, got %#v", first)
	}
	second, _ := output[1].(map[string]any)
	if second["type"] != "tool_calls" {
		t.Fatalf("expected second output item type tool_calls, got %#v", second["type"])
	}
	var toolCalls []map[string]any
	switch v := second["tool_calls"].(type) {
	case []map[string]any:
		toolCalls = v
	case []any:
		toolCalls = make([]map[string]any, 0, len(v))
		for _, item := range v {
			m, _ := item.(map[string]any)
			if m != nil {
				toolCalls = append(toolCalls, m)
			}
		}
	}
	if len(toolCalls) != 1 {
		t.Fatalf("expected one tool call, got %#v", second["tool_calls"])
	}
	tc := toolCalls[0]
	if tc["type"] != "function" || tc["id"] == "" {
		t.Fatalf("unexpected tool call shape: %#v", tc)
	}
	fn, _ := tc["function"].(map[string]any)
	if fn["name"] != "search" {
		t.Fatalf("unexpected function name: %#v", fn["name"])
	}
	argsRaw, _ := fn["arguments"].(string)
	var args map[string]any
	if err := json.Unmarshal([]byte(argsRaw), &args); err != nil {
		t.Fatalf("arguments should be valid json string, got=%q err=%v", argsRaw, err)
	}
	if args["q"] != "golang" {
		t.Fatalf("unexpected arguments: %#v", args)
	}
}

func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T) {
	obj := BuildResponseObject(
		"resp_test",
		"gpt-4o",
		"prompt",
		"",
		`示例格式:{"tool_calls":[{"name":"search","input":{"q":"golang"}}]},但这条是普通回答。`,
		[]string{"search"},
	)

	outputText, _ := obj["output_text"].(string)
	if outputText != "" {
		t.Fatalf("expected output_text hidden once tool calls are detected, got %q", outputText)
	}

	output, _ := obj["output"].([]any)
	if len(output) != 2 {
		t.Fatalf("expected function_call + tool_calls wrapper, got %#v", obj["output"])
	}
	first, _ := output[0].(map[string]any)
	if first["type"] != "function_call" {
		t.Fatalf("expected first output type function_call, got %#v", first["type"])
	}
}

func TestBuildResponseObjectFencedToolPayloadRemainsText(t *testing.T) {
	obj := BuildResponseObject(
		"resp_test",
		"gpt-4o",
		"prompt",
		"",
		"```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"golang\"}}]}\n```",
		[]string{"search"},
	)

	outputText, _ := obj["output_text"].(string)
	if outputText == "" {
		t.Fatalf("expected output_text preserved for fenced example")
	}
	output, _ := obj["output"].([]any)
	if len(output) != 1 {
		t.Fatalf("expected one message output item, got %#v", obj["output"])
	}
	first, _ := output[0].(map[string]any)
	if first["type"] != "message" {
		t.Fatalf("expected message output type, got %#v", first["type"])
	}
}

func TestBuildResponseObjectReasoningOnlyFallsBackToOutputText(t *testing.T) {
	obj := BuildResponseObject(
		"resp_test",
		"gpt-4o",
		"prompt",
		"internal thinking content",
		"",
		nil,
	)

	outputText, _ := obj["output_text"].(string)
	if outputText == "" {
		t.Fatalf("expected output_text fallback from reasoning when final text is empty")
	}

	output, _ := obj["output"].([]any)
	if len(output) != 1 {
		t.Fatalf("expected one output item, got %#v", obj["output"])
	}
	first, _ := output[0].(map[string]any)
	if first["type"] != "message" {
		t.Fatalf("expected output type message, got %#v", first["type"])
	}
	content, _ := first["content"].([]any)
	if len(content) == 0 {
		t.Fatalf("expected reasoning content, got %#v", first["content"])
	}
	block0, _ := content[0].(map[string]any)
	if block0["type"] != "reasoning" {
		t.Fatalf("expected first content block reasoning, got %#v", block0["type"])
	}
}

func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
	obj := BuildResponseObject(
		"resp_test",
		"gpt-4o",
		"prompt",
		`{"tool_calls":[{"name":"search","input":{"q":"from-thinking"}}]}`,
		"",
		[]string{"search"},
	)

	output, _ := obj["output"].([]any)
	if len(output) != 3 {
		t.Fatalf("expected reasoning + function_call + tool_calls outputs, got %#v", obj["output"])
	}
	first, _ := output[0].(map[string]any)
	if first["type"] != "reasoning" {
		t.Fatalf("expected first output reasoning, got %#v", first["type"])
	}
	second, _ := output[1].(map[string]any)
	if second["type"] != "function_call" {
		t.Fatalf("expected second output function_call, got %#v", second["type"])
	}
	third, _ := output[2].(map[string]any)
	if third["type"] != "tool_calls" {
		t.Fatalf("expected third output tool_calls, got %#v", third["type"])
	}
}
@@ -12,6 +12,7 @@ import (
 
 	"ds2api/internal/account"
+	"ds2api/internal/adapter/claude"
 	"ds2api/internal/adapter/gemini"
 	"ds2api/internal/adapter/openai"
 	"ds2api/internal/admin"
 	"ds2api/internal/auth"
@@ -44,6 +45,7 @@ func NewApp() *App {
 
 	openaiHandler := &openai.Handler{Store: store, Auth: resolver, DS: dsClient}
+	claudeHandler := &claude.Handler{Store: store, Auth: resolver, DS: dsClient}
 	geminiHandler := &gemini.Handler{Store: store, Auth: resolver, DS: dsClient}
 	adminHandler := &admin.Handler{Store: store, Pool: pool, DS: dsClient}
 	webuiHandler := webui.NewHandler()
 
@@ -67,6 +69,7 @@ func NewApp() *App {
 	})
 	openai.RegisterRoutes(r, openaiHandler)
+	claude.RegisterRoutes(r, claudeHandler)
 	gemini.RegisterRoutes(r, geminiHandler)
 	r.Route("/admin", func(ar chi.Router) {
 		admin.RegisterRoutes(ar, adminHandler)
 	})
8
tests/scripts/run-unit-all.sh
Executable file
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
+cd "$ROOT_DIR"
+
+./tests/scripts/run-unit-go.sh
+./tests/scripts/run-unit-node.sh
7
tests/scripts/run-unit-go.sh
Executable file
@@ -0,0 +1,7 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
+cd "$ROOT_DIR"
+
+go test ./... "$@"
7
tests/scripts/run-unit-node.sh
Executable file
@@ -0,0 +1,7 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
+cd "$ROOT_DIR"
+
+node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js "$@"
12
vercel.json
@@ -38,6 +38,18 @@
       "source": "/admin/config",
       "destination": "/api/index"
     },
+    {
+      "source": "/admin/config/(.*)",
+      "destination": "/api/index"
+    },
+    {
+      "source": "/admin/settings",
+      "destination": "/api/index"
+    },
+    {
+      "source": "/admin/settings/(.*)",
+      "destination": "/api/index"
+    },
     {
       "source": "/admin/keys(.*)",
       "destination": "/api/index"
@@ -1,4 +1,4 @@
-import { useState, useEffect } from 'react'
+import { useState, useEffect, useCallback, useMemo } from 'react'
 import {
   Routes,
   Route,
@@ -29,12 +29,12 @@ import Login from './components/Login'
 import LandingPage from './components/LandingPage'
 import LanguageToggle from './components/LanguageToggle'
 import { useI18n } from './i18n'
+import { detectRuntimeEnv } from './utils/runtimeEnv'
 
-function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message, onForceLogout }) {
+function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message, onForceLogout, isVercel }) {
   const { t } = useI18n()
   const [activeTab, setActiveTab] = useState('accounts')
   const [sidebarOpen, setSidebarOpen] = useState(false)
   const [loading, setLoading] = useState(false)
 
   const navItems = [
     { id: 'accounts', label: t('nav.accounts.label'), icon: Users, description: t('nav.accounts.desc') },
@@ -44,7 +44,7 @@ function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message,
     { id: 'settings', label: t('nav.settings.label'), icon: SettingsIcon, description: t('nav.settings.desc') },
   ]
 
-  const authFetch = async (url, options = {}) => {
+  const authFetch = useCallback(async (url, options = {}) => {
     const headers = {
       ...options.headers,
       'Authorization': `Bearer ${token}`
@@ -56,7 +56,7 @@ function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message,
       throw new Error(t('auth.expired'))
     }
     return res
-  }
+  }, [onLogout, t, token])
 
   const renderTab = () => {
     switch (activeTab) {
@@ -67,9 +67,9 @@ function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message,
       case 'import':
         return <BatchImport onRefresh={fetchConfig} onMessage={showMessage} authFetch={authFetch} />
       case 'vercel':
-        return <VercelSync onMessage={showMessage} authFetch={authFetch} />
+        return <VercelSync onMessage={showMessage} authFetch={authFetch} isVercel={isVercel} />
       case 'settings':
-        return <Settings onRefresh={fetchConfig} onMessage={showMessage} authFetch={authFetch} onForceLogout={onForceLogout} />
+        return <Settings onRefresh={fetchConfig} onMessage={showMessage} authFetch={authFetch} onForceLogout={onForceLogout} isVercel={isVercel} />
       default:
         return null
     }
@@ -213,13 +213,27 @@ export default function App() {
   const navigate = useNavigate()
   const location = useLocation()
   const [config, setConfig] = useState({ keys: [], accounts: [] })
   const [loading, setLoading] = useState(true)
   const [message, setMessage] = useState(null)
   const [token, setToken] = useState(null)
   const [authChecking, setAuthChecking] = useState(true)
 
   const isProduction = import.meta.env.MODE === 'production'
   const isAdminRoute = location.pathname.startsWith('/admin') || isProduction
+  const runtimeEnv = useMemo(() => detectRuntimeEnv(), [])
+  const isVercel = runtimeEnv.isVercel
+
+  const showMessage = useCallback((type, text) => {
+    setMessage({ type, text })
+    setTimeout(() => setMessage(null), 5000)
+  }, [])
+
+  const handleLogout = useCallback(() => {
+    setToken(null)
+    localStorage.removeItem('ds2api_token')
+    localStorage.removeItem('ds2api_token_expires')
+    sessionStorage.removeItem('ds2api_token')
+    sessionStorage.removeItem('ds2api_token_expires')
+  }, [])
 
   useEffect(() => {
     // Only check auth status on admin routes.
@@ -249,12 +263,11 @@ export default function App() {
       setAuthChecking(false)
     }
     checkAuth()
-  }, [isAdminRoute])
+  }, [handleLogout, isAdminRoute])
 
-  const fetchConfig = async () => {
+  const fetchConfig = useCallback(async () => {
     if (!token) return
     try {
       setLoading(true)
       const res = await fetch('/admin/config', {
         headers: { 'Authorization': `Bearer ${token}` }
       })
@@ -265,34 +278,19 @@ export default function App() {
     } catch (e) {
       console.error('Failed to fetch config:', e)
       showMessage('error', t('errors.fetchConfig', { error: e.message }))
     } finally {
       setLoading(false)
     }
-  }
+  }, [showMessage, t, token])
 
   useEffect(() => {
     if (token) {
       fetchConfig()
     }
-  }, [token])
-
-  const showMessage = (type, text) => {
-    setMessage({ type, text })
-    setTimeout(() => setMessage(null), 5000)
-  }
+  }, [fetchConfig, token])
 
   const handleLogin = (newToken) => {
     setToken(newToken)
   }
 
-  const handleLogout = () => {
-    setToken(null)
-    localStorage.removeItem('ds2api_token')
-    localStorage.removeItem('ds2api_token_expires')
-    sessionStorage.removeItem('ds2api_token')
-    sessionStorage.removeItem('ds2api_token_expires')
-  }
-
   // Wait for auth checks on admin routes.
   if (isAdminRoute && authChecking) {
     return (
@@ -320,6 +318,7 @@ export default function App() {
           showMessage={showMessage}
           message={message}
           onForceLogout={handleLogout}
+          isVercel={isVercel}
         />
       ) : (
         <div className="min-h-screen flex flex-col bg-background relative overflow-hidden">
@@ -2,7 +2,9 @@ import { useCallback, useEffect, useMemo, useState } from 'react'
 import { AlertTriangle, Download, Lock, Save, Upload } from 'lucide-react'
 import { useI18n } from '../i18n'
 
-export default function Settings({ onRefresh, onMessage, authFetch, onForceLogout }) {
+const MAX_AUTO_FETCH_FAILURES = 3
+
+export default function Settings({ onRefresh, onMessage, authFetch, onForceLogout, isVercel = false }) {
   const { t } = useI18n()
   const apiFetch = authFetch || fetch
 
@@ -14,6 +16,9 @@ export default function Settings({ onRefresh, onMessage, authFetch, onForceLogou
   const [importMode, setImportMode] = useState('merge')
   const [importText, setImportText] = useState('')
   const [newPassword, setNewPassword] = useState('')
+  const [consecutiveFailures, setConsecutiveFailures] = useState(0)
+  const [autoFetchPaused, setAutoFetchPaused] = useState(false)
+  const [lastError, setLastError] = useState('')
   const [settingsMeta, setSettingsMeta] = useState({ default_password_warning: false, env_backed: false, needs_vercel_sync: false })
 
   const [form, setForm] = useState({
@@ -43,15 +48,38 @@ export default function Settings({ onRefresh, onMessage, authFetch, onForceLogou
     return parsed
   }
 
-  const loadSettings = useCallback(async () => {
+  const parseJSONResponse = useCallback(async (res) => {
+    const contentType = String(res.headers.get('content-type') || '').toLowerCase()
+    if (!contentType.includes('application/json')) {
+      throw new Error(t('settings.nonJsonResponse', { status: res.status }))
+    }
+    return res.json()
+  }, [t])
+
+  const loadSettings = useCallback(async ({ manual = false } = {}) => {
+    if (isVercel && autoFetchPaused && !manual) {
+      return
+    }
     setLoading(true)
     try {
       const res = await apiFetch('/admin/settings')
-      const data = await res.json()
+      const data = await parseJSONResponse(res)
       if (!res.ok) {
-        onMessage('error', data.detail || t('settings.loadFailed'))
+        const detail = data.detail || t('settings.loadFailed')
+        setLastError(detail)
+        onMessage('error', detail)
+        setConsecutiveFailures((prev) => {
+          const next = prev + 1
+          if (isVercel && next >= MAX_AUTO_FETCH_FAILURES) {
+            setAutoFetchPaused(true)
+          }
+          return next
+        })
         return
       }
+      setConsecutiveFailures(0)
+      setAutoFetchPaused(false)
+      setLastError('')
       setSettingsMeta({
         default_password_warning: Boolean(data.admin?.default_password_warning),
         env_backed: Boolean(data.env_backed),
@@ -78,18 +106,32 @@ export default function Settings({ onRefresh, onMessage, authFetch, onForceLogou
         model_aliases_text: JSON.stringify(data.model_aliases || {}, null, 2),
       })
     } catch (e) {
-      onMessage('error', t('settings.loadFailed'))
+      const detail = e?.message || t('settings.loadFailed')
+      setLastError(detail)
+      onMessage('error', detail)
+      setConsecutiveFailures((prev) => {
+        const next = prev + 1
+        if (isVercel && next >= MAX_AUTO_FETCH_FAILURES) {
+          setAutoFetchPaused(true)
+        }
+        return next
+      })
       // eslint-disable-next-line no-console
       console.error(e)
     } finally {
       setLoading(false)
     }
-  }, [apiFetch, onMessage, t])
+  }, [apiFetch, autoFetchPaused, isVercel, onMessage, parseJSONResponse, t])
 
   useEffect(() => {
     loadSettings()
   }, [loadSettings])
 
+  const retryLoadSettings = () => {
+    setAutoFetchPaused(false)
+    loadSettings({ manual: true })
+  }
+
   const saveSettings = async () => {
     let claudeMapping = {}
     let modelAliases = {}
@@ -228,6 +270,23 @@ export default function Settings({ onRefresh, onMessage, authFetch, onForceLogou
 
   return (
     <div className="space-y-6">
+      {autoFetchPaused && (
+        <div className="p-4 rounded-lg border border-destructive/30 bg-destructive/10 text-destructive flex items-center justify-between gap-4">
+          <div className="flex items-center gap-2">
+            <AlertTriangle className="w-4 h-4" />
+            <span className="text-sm">
+              {t('settings.autoFetchPaused', { count: consecutiveFailures, error: lastError || t('settings.loadFailed') })}
+            </span>
+          </div>
+          <button
+            type="button"
+            onClick={retryLoadSettings}
+            className="px-3 py-1.5 text-xs rounded-md border border-destructive/40 hover:bg-destructive/10"
+          >
+            {t('settings.retryLoad')}
+          </button>
+        </div>
+      )}
       {settingsMeta.default_password_warning && (
         <div className="p-4 rounded-lg border border-amber-300/30 bg-amber-500/10 text-amber-700 flex items-center gap-2">
           <AlertTriangle className="w-4 h-4" />
@@ -3,7 +3,15 @@ import { Cloud, ArrowRight, ExternalLink, Info, CheckCircle2, XCircle, RefreshCw
 import clsx from 'clsx'
 import { useI18n } from '../i18n'
 
-export default function VercelSync({ onMessage, authFetch }) {
+const MAX_POLL_FAILURES = 3
+
+function pollDelayMs(attempt) {
+  if (attempt <= 0) return 15000
+  if (attempt === 1) return 30000
+  return 60000
+}
+
+export default function VercelSync({ onMessage, authFetch, isVercel = false }) {
   const { t } = useI18n()
   const [vercelToken, setVercelToken] = useState('')
   const [projectId, setProjectId] = useState('')
@@ -12,20 +20,42 @@ export default function VercelSync({ onMessage, authFetch }) {
   const [result, setResult] = useState(null)
   const [preconfig, setPreconfig] = useState(null)
   const [syncStatus, setSyncStatus] = useState(null)
+  const [pollPaused, setPollPaused] = useState(false)
+  const [pollFailures, setPollFailures] = useState(0)
+  const [nextRetryAt, setNextRetryAt] = useState(null)
 
   const apiFetch = authFetch || fetch
 
-  const fetchSyncStatus = useCallback(async () => {
+  const fetchSyncStatus = useCallback(async ({ manual = false } = {}) => {
     try {
       const res = await apiFetch('/admin/vercel/status')
-      if (res.ok) {
-        const data = await res.json()
-        setSyncStatus(data)
-      }
+      if (!res.ok) {
+        throw new Error(`status ${res.status}`)
+      }
+      const data = await res.json()
+      setSyncStatus(data)
+      setPollFailures(0)
+      setPollPaused(false)
+      setNextRetryAt(null)
     } catch (e) {
+      setPollFailures((prev) => {
+        const next = prev + 1
+        if (isVercel) {
+          if (next >= MAX_POLL_FAILURES) {
+            setPollPaused(true)
+            setNextRetryAt(null)
+          } else {
+            setNextRetryAt(Date.now() + pollDelayMs(next))
+          }
+        }
+        return next
+      })
+      if (manual) {
+        onMessage('error', t('vercel.networkError'))
+      }
       console.error('Failed to fetch sync status:', e)
     }
-  }, [apiFetch])
+  }, [apiFetch, isVercel, onMessage, t])
 
   useEffect(() => {
     const loadPreconfig = async () => {
@@ -43,11 +73,32 @@ export default function VercelSync({ onMessage, authFetch }) {
     }
     loadPreconfig()
     fetchSyncStatus()
-    // Poll every 15s to detect config changes
-    const interval = setInterval(fetchSyncStatus, 15000)
-    return () => clearInterval(interval)
   }, [fetchSyncStatus])
 
+  useEffect(() => {
+    if (!isVercel) {
+      const interval = setInterval(() => {
+        fetchSyncStatus()
+      }, 15000)
+      return () => clearInterval(interval)
+    }
+    if (pollPaused) {
+      return undefined
+    }
+    const delay = nextRetryAt && nextRetryAt > Date.now() ? nextRetryAt - Date.now() : pollDelayMs(pollFailures)
+    const timer = setTimeout(() => {
+      fetchSyncStatus()
+    }, Math.max(1000, delay))
+    return () => clearTimeout(timer)
+  }, [fetchSyncStatus, isVercel, nextRetryAt, pollFailures, pollPaused])
+
+  const handleManualRefresh = () => {
+    setPollPaused(false)
+    setPollFailures(0)
+    setNextRetryAt(null)
+    fetchSyncStatus({ manual: true })
+  }
+
   const handleSync = async () => {
     const tokenToUse = preconfig?.has_token && !vercelToken ? '__USE_PRECONFIG__' : vercelToken
 
@@ -122,6 +173,20 @@ export default function VercelSync({ onMessage, authFetch }) {
       <p className="text-muted-foreground text-sm mt-1">
         {t('vercel.description')}
       </p>
+      {pollPaused && (
+        <div className="mt-2 flex flex-wrap items-center gap-2">
+          <p className="text-xs text-destructive">
+            {t('vercel.pollPaused', { count: pollFailures })}
+          </p>
+          <button
+            type="button"
+            onClick={handleManualRefresh}
+            className="px-2 py-1 text-xs rounded border border-border hover:bg-secondary/50"
+          >
+            {t('vercel.manualRefresh')}
+          </button>
+        </div>
+      )}
       {syncStatus?.last_sync_time && (
        <p className="text-xs text-muted-foreground/60 mt-1.5 flex items-center gap-1">
          <RefreshCw className="w-3 h-3" />
@@ -198,6 +198,7 @@
   },
   "settings": {
     "loadFailed": "Failed to load settings.",
+    "nonJsonResponse": "Unexpected non-JSON response from server (status: {status}).",
     "save": "Save settings",
     "saving": "Saving...",
     "saveSuccess": "Settings saved and hot reloaded.",
@@ -239,7 +240,9 @@
     "exportJson": "Export JSON",
     "invalidJsonField": "{field} is not a valid JSON object.",
     "defaultPasswordWarning": "You are using the default admin password \"admin\". Please change it.",
-    "vercelSyncHint": "Configuration changed. For Vercel deployments, sync manually in Vercel Sync and redeploy."
+    "vercelSyncHint": "Configuration changed. For Vercel deployments, sync manually in Vercel Sync and redeploy.",
+    "autoFetchPaused": "Auto loading paused after {count} failures: {error}",
+    "retryLoad": "Retry now"
   },
   "login": {
     "welcome": "Welcome back",
@@ -278,6 +281,8 @@
     "statusNotSynced": "Not synced",
     "statusNeverSynced": "Never synced",
     "lastSyncTime": "Last sync: {time}",
+    "pollPaused": "Status polling paused after {count} failures.",
+    "manualRefresh": "Refresh manually",
     "howItWorks": "How it works",
     "steps": {
       "one": "The current configuration (keys and accounts) is exported as JSON.",
@@ -198,6 +198,7 @@
   },
   "settings": {
     "loadFailed": "加载设置失败",
+    "nonJsonResponse": "服务端返回了非 JSON 响应(状态码:{status})",
     "save": "保存设置",
     "saving": "保存中...",
     "saveSuccess": "设置已保存并热更新生效",
@@ -239,7 +240,9 @@
     "exportJson": "导出 JSON",
     "invalidJsonField": "{field} 不是有效 JSON 对象",
     "defaultPasswordWarning": "当前使用默认密码 admin,请尽快在此修改。",
-    "vercelSyncHint": "当前配置已更新。Vercel 部署请到 Vercel 同步页面手动同步并重部署。"
+    "vercelSyncHint": "当前配置已更新。Vercel 部署请到 Vercel 同步页面手动同步并重部署。",
+    "autoFetchPaused": "自动加载已暂停:连续失败 {count} 次({error})",
+    "retryLoad": "立即重试"
   },
   "login": {
     "welcome": "欢迎回来",
@@ -278,6 +281,8 @@
     "statusNotSynced": "未同步",
     "statusNeverSynced": "从未同步",
     "lastSyncTime": "上次同步: {time}",
+    "pollPaused": "状态轮询已暂停:连续失败 {count} 次。",
+    "manualRefresh": "手动刷新",
     "howItWorks": "工作原理",
     "steps": {
       "one": "当前配置 (密钥和账号) 被导出为 JSON 字符串。",
13
webui/src/utils/runtimeEnv.js
Normal file
@@ -0,0 +1,13 @@
+export function detectRuntimeEnv() {
+  const deployTarget = String(import.meta.env.VITE_DEPLOY_TARGET || '').trim().toLowerCase()
+  if (deployTarget === 'vercel') {
+    return { isVercel: true, source: 'vite_env' }
+  }
+
+  const host = typeof window !== 'undefined' ? String(window.location.hostname || '').toLowerCase() : ''
+  if (host.includes('vercel.app')) {
+    return { isVercel: true, source: 'hostname' }
+  }
+
+  return { isVercel: false, source: 'default' }
+}