Mirror of https://github.com/CJackHwang/ds2api.git (synced 2026-05-03 16:05:26 +08:00)

Compare commits: `v2.5.1_bet...v3.1.0_bet` · 89 commits
Commits in range (author and date columns were empty in the capture; SHAs listed oldest-first as shown):

068f4b0df6, 5a51045ba4, 3497d5d019, 95a9d16843, 0847091864, c6340354ec, 6bf08e00cd, 35221002d5, 4b1f1ea550, 0258f83d10, da912f87bf, 6b32d84222, e1df5c8636, f23382ff5f, fabdba48c3, a28e833f33, ec1be468ca, fe43f1e6ee, 440d759584, a6a9863fc3, f787e25641, 5722f21cdd, ca3c16c424, 8b86f1c903, b758ce9234, 1cf28101d6, c1bdb6776b, 47544fb385, 2a05c96f5f, cbc68f7e92, 5576043106, 287e8d5a60, 8a2c500806, e958bf7e40, 443fa4ad8e, 2d62c658f8, 6a632ad9ef, cd2f5ad3b0, 1457b63a76, 24655342a7, 39f6e066d6, 02d64c192e, 283aa304df, 02fe3e4bfc, 15bf77e044, add0d0cc06, a87ec3fd68, 50ce88ca3f, 48a5f1c39e, 07578f9c56, 5ebc33c347, cc74397edc, 1289e8afd8, e60738b084, f6cd541c6f, 1eb47147c2, da3fafb79a, 3900aaec47, 8a74dbff9c, bfca84c2c7, 1cdfa9c05d, fe8232bfc1, 063599678a, f55aa7564a, 3b60e3c8f9, efebe9ebad, b54b418f96, 1c5f022b06, 836eaf5290, 958e7a0d04, f3555ae9b0, d50d39e2e5, 01393837be, 1fe1240240, c07736fbea, 775bf3b578, ab3943ebeb, 6efba7b2e4, 765d0231cd, aebf3e9119, 535d9298a7, b790545d82, 034c00f10e, 390f7580e5, 586d31e556, c4a73e871a, 25b3292497, 11f66db87d, 7131b06e26
**API.en.md** · 24 changed lines (+/- markers reconstructed from the hunk headers)

````diff
@@ -31,6 +31,13 @@ This document describes the actual behavior of the current Go codebase.
 | Health probes | `GET /healthz`, `GET /readyz` |
 | CORS | Enabled (`Access-Control-Allow-Origin: *`, allows `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Vercel-Protection-Bypass`) |
+
+### 3.0 Adapter-Layer Notes
+
+- OpenAI / Claude / Gemini protocols are now mounted on one shared `chi` router tree assembled in `internal/server/router.go`.
+- Adapter responsibilities are streamlined to: **request normalization → DeepSeek invocation → protocol-shaped rendering**, reducing legacy split-logic paths.
+- Tool-calling semantics are aligned between the Go and Node runtimes: structured parsing first (JSON/XML/invoke/markup), plus stream-time anti-leak filtering.
+- The `Admin API` separates static config from runtime policy: `/admin/config*` for configuration state, `/admin/settings*` for runtime behavior.
 
 ---
 
 ## Configuration Best Practice
@@ -91,7 +98,9 @@ Gemini-compatible clients can also send `x-goog-api-key`, `?key=`, or `?api_key=`
 | Method | Path | Auth | Description |
 | --- | --- | --- | --- |
 | GET | `/healthz` | None | Liveness probe |
+| HEAD | `/healthz` | None | Liveness probe (no body) |
 | GET | `/readyz` | None | Readiness probe |
+| HEAD | `/readyz` | None | Readiness probe (no body) |
 | GET | `/v1/models` | None | OpenAI model list |
 | GET | `/v1/models/{id}` | None | OpenAI single-model query (alias accepted) |
 | POST | `/v1/chat/completions` | Business | OpenAI chat completions |
@@ -587,6 +596,9 @@ Returns sanitized config.
 {
   "keys": ["k1", "k2"],
+  "env_backed": false,
+  "env_source_present": true,
+  "env_writeback_enabled": true,
   "config_path": "/data/config.json",
   "accounts": [
     {
       "identifier": "user@example.com",
@@ -928,15 +940,15 @@ Checks the current build version and the latest GitHub Release:
 ```json
 {
   "success": true,
-  "current_version": "2.3.5",
-  "current_tag": "v2.3.5",
+  "current_version": "3.0.0",
+  "current_tag": "v3.0.0",
   "source": "file:VERSION",
   "checked_at": "2026-03-29T00:00:00Z",
-  "latest_tag": "v2.3.6",
-  "latest_version": "2.3.6",
-  "release_url": "https://github.com/CJackHwang/ds2api/releases/tag/v2.3.6",
+  "latest_tag": "v3.0.0",
+  "latest_version": "3.0.0",
+  "release_url": "https://github.com/CJackHwang/ds2api/releases/tag/v3.0.0",
   "published_at": "2026-03-28T12:00:00Z",
-  "has_update": true
+  "has_update": false
 }
 ```
````
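The version-check payload above reports `has_update` by comparing the current tag with the latest release tag. A minimal sketch of that comparison under stated assumptions: the helper names here are hypothetical, and the project's real logic lives in `internal/version` ("version parsing/comparison and tag normalization"), which is not reproduced.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// normalize strips a leading "v" so "v3.0.0" and "3.0.0" compare equal.
func normalize(tag string) string {
	return strings.TrimPrefix(strings.TrimSpace(tag), "v")
}

// compare returns -1, 0, or 1 for dotted numeric versions.
func compare(a, b string) int {
	as, bs := strings.Split(normalize(a), "."), strings.Split(normalize(b), ".")
	for i := 0; i < len(as) || i < len(bs); i++ {
		var x, y int
		if i < len(as) {
			x, _ = strconv.Atoi(as[i])
		}
		if i < len(bs) {
			y, _ = strconv.Atoi(bs[i])
		}
		if x != y {
			if x < y {
				return -1
			}
			return 1
		}
	}
	return 0
}

// hasUpdate mirrors the has_update field: latest strictly newer than current.
func hasUpdate(current, latest string) bool { return compare(current, latest) < 0 }

func main() {
	fmt.Println(hasUpdate("v2.3.5", "v2.3.6")) // true: an update exists
	fmt.Println(hasUpdate("3.0.0", "v3.0.0"))  // false: same version
}
```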
**API.md** (Chinese counterpart; content translated) · 24 changed lines

````diff
@@ -31,6 +31,13 @@
 | Health check | `GET /healthz`, `GET /readyz` |
 | CORS | Enabled (`Access-Control-Allow-Origin: *`, allows `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Vercel-Protection-Bypass`) |
+
+### 3.0 Adapter-Layer Notes
+
+- The OpenAI / Claude / Gemini protocols are mounted on one shared `chi` router tree, assembled by `internal/server/router.go`.
+- Adapter-layer responsibilities converge to: **request normalization → DeepSeek invocation → protocol-shaped rendering**, removing the "same capability implemented in several places" forks of earlier versions.
+- Tool-calling parsing stays consistent between the Go and Node runtimes: structured parsing first (JSON/XML/invoke/markup), with anti-leak sieving in streaming scenarios.
+- The `Admin API` separates configuration from runtime policy: `/admin/config*` manages static config, `/admin/settings*` manages runtime behavior.
 
 ---
 
 ## Configuration Best Practice
@@ -91,7 +98,9 @@ Gemini-compatible clients can also use `x-goog-api-key`, `?key=`, or `?api_key=`
 | Method | Path | Auth | Description |
 | --- | --- | --- | --- |
 | GET | `/healthz` | None | Liveness probe |
+| HEAD | `/healthz` | None | Liveness probe (no body) |
 | GET | `/readyz` | None | Readiness probe |
+| HEAD | `/readyz` | None | Readiness probe (no body) |
 | GET | `/v1/models` | None | OpenAI model list |
 | GET | `/v1/models/{id}` | None | OpenAI single-model query (alias accepted) |
 | POST | `/v1/chat/completions` | Business | OpenAI chat completion |
@@ -596,6 +605,9 @@ data: {"type":"message_stop"}
 {
   "keys": ["k1", "k2"],
+  "env_backed": false,
+  "env_source_present": true,
+  "env_writeback_enabled": true,
   "config_path": "/data/config.json",
   "accounts": [
     {
       "identifier": "user@example.com",
@@ -934,15 +946,15 @@ data: {"type":"message_stop"}
 ```json
 {
   "success": true,
-  "current_version": "2.3.5",
-  "current_tag": "v2.3.5",
+  "current_version": "3.0.0",
+  "current_tag": "v3.0.0",
   "source": "file:VERSION",
   "checked_at": "2026-03-29T00:00:00Z",
-  "latest_tag": "v2.3.6",
-  "latest_version": "2.3.6",
-  "release_url": "https://github.com/CJackHwang/ds2api/releases/tag/v2.3.6",
+  "latest_tag": "v3.0.0",
+  "latest_version": "3.0.0",
+  "release_url": "https://github.com/CJackHwang/ds2api/releases/tag/v3.0.0",
   "published_at": "2026-03-28T12:00:00Z",
-  "has_update": true
+  "has_update": false
 }
 ```
````
**Dockerfile** (file header missing in the capture; the hunks match the multi-stage `Dockerfile`)

````diff
@@ -1,4 +1,4 @@
-FROM node:20 AS webui-builder
+FROM node:24 AS webui-builder
 
 WORKDIR /app/webui
 COPY webui/package.json webui/package-lock.json ./
@@ -6,7 +6,7 @@ RUN npm ci
 COPY webui ./
 RUN npm run build
 
-FROM golang:1.24 AS go-builder
+FROM golang:1.26 AS go-builder
 WORKDIR /app
 ARG TARGETOS
 ARG TARGETARCH
````
**README.MD** (Chinese; content translated) · 84 changed lines

````diff
@@ -28,43 +28,64 @@
 ```mermaid
 flowchart LR
-    Client["🖥️ Client\n(OpenAI / Claude / Gemini compatible)"]
+    Client["🖥️ Client / SDK\n(OpenAI / Claude / Gemini)"]
+    Upstream["☁️ DeepSeek API"]
 
-    subgraph DS2API["DS2API Service"]
-        direction TB
-        CORS["CORS middleware"]
-        Auth["🔐 Auth middleware"]
+    subgraph DS2API["DS2API 3.x (unified OpenAI core)"]
+        Router["chi Router + middleware\n(RequestID / RealIP / Logger / Recoverer / CORS)"]
 
-        subgraph Adapters["Adapter layer"]
-            OA["OpenAI adapter\n/v1/*"]
-            CA["Claude adapter\n/anthropic/*"]
-            GA["Gemini adapter\n/v1beta/models/*"]
+        subgraph Adapters["Protocol adapter layer"]
+            OA["OpenAI\n/v1/*"]
+            CA["Claude\n/anthropic/* + /v1/messages"]
+            GA["Gemini\n/v1beta/models/* + /v1/models/*"]
+            Admin["Admin API\n/admin/*"]
+            WebUI["WebUI\n/admin (static hosting)"]
         end
 
-        subgraph Support["Support modules"]
-            Pool["📦 Account pool / concurrency queue"]
-            PoW["⚙️ PoW WASM\n(wazero)"]
+        subgraph Runtime["Runtime core capabilities"]
+            Bridge["CLIProxy bridge\n(multi-protocol <-> OpenAI)"]
+            OAEngine["OpenAI ChatCompletions\n(unified tool calling and stream semantics)"]
+            Auth["Auth resolver\n(API key / bearer / x-goog-api-key)"]
+            Pool["Account pool + queue\n(concurrency slots + wait queue)"]
+            DSClient["DeepSeek client\n(Session / Auth / HTTP)"]
+            Pow["PoW WASM\n(wazero preload)"]
+            Tool["Tool sieve\n(Go/Node semantic parity)"]
         end
-
-        Admin["🛠️ Admin API\n/admin/*"]
-        WebUI["🌐 WebUI\n(/admin)"]
     end
 
-    DS["☁️ DeepSeek API"]
+    Client --> Router
+    Router --> OA & CA & GA
+    Router --> Admin
+    Router --> WebUI
 
-    Client -- "request" --> CORS --> Auth
-    Auth --> OA & CA & GA
-    OA & CA & GA -- "call" --> DS
-    Auth --> Admin
-    OA & CA & GA -. "pick account (round-robin)" .-> Pool
-    OA & CA & GA -. "compute PoW" .-> PoW
-    DS -- "response" --> Client
+    OA --> OAEngine
+    CA & GA --> Bridge
+    Bridge --> OAEngine
+    OAEngine --> Auth
+    OAEngine -.account rotation.-> Pool
+    OAEngine -.tool-call parsing.-> Tool
+    OAEngine -.PoW solving.-> Pow
+    Auth --> DSClient
+    DSClient --> Upstream
+    Upstream --> DSClient
+    OAEngine --> Bridge
+    Bridge --> Client
 ```
 
 - **Backend**: Go (`cmd/ds2api/`, `api/`, `internal/`), no Python runtime dependency
 - **Frontend**: React admin panel (`webui/`), served as a static build at runtime
 - **Deployment**: local run, Docker, Vercel Serverless, Linux systemd
+
+### 3.0 Underlying Architecture Changes (vs older versions)
+
+- **Unified routing core**: all protocol entries converge on `internal/server/router.go`, registering OpenAI / Claude / Gemini / Admin / WebUI routes in one tree to avoid multi-entry behavioral drift.
+- **Unified execution chain**: Claude / Gemini entries first pass through `internal/translatorcliproxy` for protocol conversion, then run through `openai.ChatCompletions` for shared tool-calling and stream semantics, and are finally converted back to the original protocol's response.
+- **Cleaner adapter layering**: `internal/adapter/{claude,gemini}` handles inbound/outbound protocol wrapping, `internal/adapter/openai` is the execution core, and DeepSeek-side calls live only in the OpenAI core.
+- **Tool-calling parity across both runtimes**: the Go side (`internal/util`) and the Vercel Node side (`internal/js/helpers/stream-tool-sieve`) keep consistent parsing/anti-leak semantics, covering JSON / XML / invoke / text-kv input styles.
+- **Config and runtime settings decoupled**: static config (`config`) and runtime policy (`settings`) are managed separately through the Admin API, supporting hot updates and password rotation that invalidates old JWTs.
+- **Streaming upgrades**: `/v1/responses` and `/v1/chat/completions` share a more consistent incremental tool-call output strategy, reducing behavioral differences across SDKs.
+- **Observability and operability**: `/healthz`, `/readyz`, `/admin/version`, and `/admin/dev/captures` form a closed troubleshooting loop for post-release verification.
 
 ## Key Capabilities
 
 | Capability | Details |
@@ -144,7 +165,7 @@ cp config.example.json config.json
 ### Option 1: Local Run
 
-**Prerequisites**: Go 1.24+, Node.js 20+ (only when building the WebUI)
+**Prerequisites**: Go 1.26+, Node.js 20+ (only when building the WebUI)
 
 ```bash
 # 1. Clone the repository
@@ -166,8 +187,9 @@ go run ./cmd/ds2api
 ### Option 2: Docker
 
 ```bash
-# 1. Prepare the env file
+# 1. Prepare the env file and config file
 cp .env.example .env
+cp config.example.json config.json
 
 # 2. Edit .env (at least set DS2API_ADMIN_KEY)
 # DS2API_ADMIN_KEY=replace-with-a-strong-password
@@ -320,6 +342,7 @@ cp opencode.json.example opencode.json
 | `DS2API_CONFIG_PATH` | Config file path | `config.json` |
 | `DS2API_CONFIG_JSON` | Inline config injection (JSON or Base64) | — |
 | `CONFIG_JSON` | Legacy-compatible config injection | — |
+| `DS2API_ENV_WRITEBACK` | In env-backed mode, auto-write the config file back and switch to file mode (`1/true/yes/on`) | Disabled |
 | `DS2API_WASM_PATH` | PoW WASM file path | Auto-detect |
 | `DS2API_STATIC_ADMIN_DIR` | Admin static assets directory | `static/admin` |
 | `DS2API_AUTO_BUILD_WEBUI` | Auto-build WebUI at startup | On locally, off on Vercel |
@@ -342,6 +365,8 @@ cp opencode.json.example opencode.json
 | `VERCEL_TEAM_ID` | Vercel team ID | — |
 | `DS2API_VERCEL_PROTECTION_BYPASS` | Vercel deployment-protection bypass key (internal Node→Go calls) | — |
 
+> Tip: when `DS2API_CONFIG_JSON/CONFIG_JSON` is detected, the admin panel shows the current mode's risks and auto-persistence status (including the `DS2API_CONFIG_PATH` path and mode-transition notes).
+
 ## Authentication Modes
 
 Business endpoints (`/v1/*`, `/anthropic/*`, Gemini routes) support two modes:
@@ -408,6 +433,7 @@ go run ./cmd/ds2api
 ```text
 ds2api/
+├── app/                  # Unified HTTP handler assembly (shared by local and serverless)
 ├── cmd/
 │   ├── ds2api/           # Local / container entrypoint
 │   └── ds2api-tests/     # End-to-end testsuite entrypoint
@@ -424,8 +450,8 @@ ds2api/
 │   ├── admin/            # Admin API handlers (incl. Settings hot-reload)
 │   ├── auth/             # Auth and JWT
 │   ├── claudeconv/       # Claude message format conversion
-│   ├── compat/           # Compatibility helpers
-│   ├── config/           # Config loading and hot-reload
+│   ├── compat/           # Go-version compatibility and regression-test helpers
+│   ├── config/           # Config loading, validation, and hot-reload
 │   ├── deepseek/         # DeepSeek API client, PoW WASM
 │   ├── js/               # Node runtime streaming and compatibility logic
 │   ├── devcapture/       # Dev packet-capture module
@@ -434,7 +460,10 @@ ds2api/
 │   ├── server/           # HTTP routing and middleware (chi router)
 │   ├── sse/              # SSE parsing utilities
 │   ├── stream/           # Unified stream consumption engine
+│   ├── testsuite/        # End-to-end test framework and case orchestration
+│   ├── translatorcliproxy/ # CLIProxy bridge and stream-writer components
 │   ├── util/             # Common utilities
+│   ├── version/          # Version parsing / comparison and tag normalization
 │   └── webui/            # WebUI static hosting and auto-build
 ├── webui/                # React WebUI source (Vite + Tailwind)
 │   └── src/
@@ -446,6 +475,7 @@ ds2api/
 │   └── build-webui.sh    # Manual WebUI build script
 ├── tests/
+│   ├── compat/           # Compatibility test fixtures and expected outputs
 │   ├── node/             # Node-side unit tests (chat-stream / tool-sieve)
 │   └── scripts/          # Unified test-script entrypoints (unit/e2e)
 ├── docs/                 # Deployment / contributing / testing docs
 ├── static/admin/         # WebUI build output (not committed to Git)
````
**README.en.md** · 84 changed lines

````diff
@@ -28,43 +28,64 @@ DS2API converts DeepSeek Web chat capability into OpenAI-compatible, Claude-comp
 ```mermaid
 flowchart LR
-    Client["🖥️ Clients\n(OpenAI / Claude / Gemini compat)"]
+    Client["🖥️ Clients / SDKs\n(OpenAI / Claude / Gemini)"]
+    Upstream["☁️ DeepSeek API"]
 
-    subgraph DS2API["DS2API Service"]
-        direction TB
-        CORS["CORS Middleware"]
-        Auth["🔐 Auth Middleware"]
+    subgraph DS2API["DS2API 3.x (Unified OpenAI Core)"]
+        Router["chi Router + Middleware\n(RequestID / RealIP / Logger / Recoverer / CORS)"]
 
-        subgraph Adapters["Adapter Layer"]
-            OA["OpenAI Adapter\n/v1/*"]
-            CA["Claude Adapter\n/anthropic/*"]
-            GA["Gemini Adapter\n/v1beta/models/*"]
+        subgraph Adapters["Protocol Adapters"]
+            OA["OpenAI\n/v1/*"]
+            CA["Claude\n/anthropic/* + /v1/messages"]
+            GA["Gemini\n/v1beta/models/* + /v1/models/*"]
+            Admin["Admin API\n/admin/*"]
+            WebUI["WebUI\n/admin (static hosting)"]
         end
 
-        subgraph Support["Support Modules"]
-            Pool["📦 Account Pool / Queue"]
-            PoW["⚙️ PoW WASM\n(wazero)"]
+        subgraph Runtime["Runtime + Core Capabilities"]
+            Bridge["CLIProxy Bridge\n(multi-protocol <-> OpenAI)"]
+            OAEngine["OpenAI ChatCompletions\n(unified tools + stream semantics)"]
+            Auth["Auth Resolver\n(API key / bearer / x-goog-api-key)"]
+            Pool["Account Pool + Queue\n(in-flight slots + wait queue)"]
+            DSClient["DeepSeek Client\n(session / auth / HTTP)"]
+            Pow["PoW WASM\n(wazero preload)"]
+            Tool["Tool Sieve\n(Go/Node semantic parity)"]
         end
-
-        Admin["🛠️ Admin API\n/admin/*"]
-        WebUI["🌐 WebUI\n(/admin)"]
     end
 
-    DS["☁️ DeepSeek API"]
+    Client --> Router
+    Router --> OA & CA & GA
+    Router --> Admin
+    Router --> WebUI
 
-    Client -- "Request" --> CORS --> Auth
-    Auth --> OA & CA & GA
-    OA & CA & GA -- "Call" --> DS
-    Auth --> Admin
-    OA & CA & GA -. "Rotate accounts" .-> Pool
-    OA & CA & GA -. "Compute PoW" .-> PoW
-    DS -- "Response" --> Client
+    OA --> OAEngine
+    CA & GA --> Bridge
+    Bridge --> OAEngine
+    OAEngine --> Auth
+    OAEngine -.account rotation.-> Pool
+    OAEngine -.tool-call parsing.-> Tool
+    OAEngine -.PoW solving.-> Pow
+    Auth --> DSClient
+    DSClient --> Upstream
+    Upstream --> DSClient
+    OAEngine --> Bridge
+    Bridge --> Client
 ```
 
 - **Backend**: Go (`cmd/ds2api/`, `api/`, `internal/`), no Python runtime
 - **Frontend**: React admin panel (`webui/`), served as static build at runtime
 - **Deployment**: local run, Docker, Vercel serverless, Linux systemd
+
+### 3.0 Architecture Changes (vs older releases)
+
+- **Unified routing core**: all protocol entries are now centralized through `internal/server/router.go`, with OpenAI / Claude / Gemini / Admin / WebUI routes registered in one tree to avoid multi-entry drift.
+- **Unified execution chain**: Claude/Gemini entries are translated by `internal/translatorcliproxy`, then executed through `openai.ChatCompletions` for shared tool-calling and stream semantics, then translated back to the client protocol.
+- **Cleaner adapter boundaries**: `internal/adapter/{claude,gemini}` handles protocol wrappers, while `internal/adapter/openai` remains the execution core; upstream DeepSeek calls are retained only in the OpenAI core.
+- **Tool-calling parity across runtimes**: Go (`internal/util`) and Vercel Node (`internal/js/helpers/stream-tool-sieve`) follow aligned parsing/anti-leak semantics across JSON / XML / invoke / text-kv inputs.
+- **Config/runtime separation**: static config (`config`) and runtime policy (`settings`) are managed independently via Admin APIs, enabling hot updates and password rotation with JWT invalidation.
+- **Streaming behavior upgrade**: `/v1/responses` and `/v1/chat/completions` now share a more consistent incremental tool-call emission strategy across SDK ecosystems.
+- **Improved operability**: `/healthz`, `/readyz`, `/admin/version`, and `/admin/dev/captures` form a tighter post-deploy diagnostics loop.
 
 ## Key Capabilities
 
 | Capability | Details |
@@ -144,7 +165,7 @@ Recommended per deployment mode:
 ### Option 1: Local Run
 
-**Prerequisites**: Go 1.24+, Node.js 20+ (only if building WebUI locally)
+**Prerequisites**: Go 1.26+, Node.js 20+ (only if building WebUI locally)
 
 ```bash
 # 1. Clone
@@ -166,8 +187,9 @@ Default URL: `http://localhost:5001`
 ### Option 2: Docker
 
 ```bash
-# 1. Prepare env file
+# 1. Prepare env file and config file
 cp .env.example .env
+cp config.example.json config.json
 
 # 2. Edit .env (at least set DS2API_ADMIN_KEY)
 # DS2API_ADMIN_KEY=replace-with-a-strong-secret
@@ -320,6 +342,7 @@ cp opencode.json.example opencode.json
 | `DS2API_CONFIG_PATH` | Config file path | `config.json` |
 | `DS2API_CONFIG_JSON` | Inline config (JSON or Base64) | — |
 | `CONFIG_JSON` | Legacy compatibility config input | — |
+| `DS2API_ENV_WRITEBACK` | Auto-write env-backed config to file and transition to file mode (`1/true/yes/on`) | Disabled |
 | `DS2API_WASM_PATH` | PoW WASM file path | Auto-detect |
 | `DS2API_STATIC_ADMIN_DIR` | Admin static assets dir | `static/admin` |
 | `DS2API_AUTO_BUILD_WEBUI` | Auto-build WebUI on startup | Enabled locally, disabled on Vercel |
@@ -339,6 +362,8 @@ cp opencode.json.example opencode.json
 | `VERCEL_TEAM_ID` | Vercel team ID | — |
 | `DS2API_VERCEL_PROTECTION_BYPASS` | Vercel deployment protection bypass for internal Node→Go calls | — |
 
+> Note: when `DS2API_CONFIG_JSON/CONFIG_JSON` is detected, the Admin UI shows mode risk and auto-persistence status (including `DS2API_CONFIG_PATH` and mode-transition hints).
+
 ## Authentication Modes
 
 For business endpoints (`/v1/*`, `/anthropic/*`, Gemini routes), DS2API supports two modes:
@@ -402,6 +427,7 @@ Response fields include:
 ```text
 ds2api/
+├── app/                  # Unified HTTP handler assembly (shared by local + serverless)
 ├── cmd/
 │   ├── ds2api/           # Local / container entrypoint
 │   └── ds2api-tests/     # End-to-end testsuite entrypoint
@@ -418,8 +444,8 @@ ds2api/
 │   ├── admin/            # Admin API handlers (incl. Settings hot-reload)
 │   ├── auth/             # Auth and JWT
 │   ├── claudeconv/       # Claude message format conversion
-│   ├── compat/           # Compatibility helpers
-│   ├── config/           # Config loading and hot-reload
+│   ├── compat/           # Go-version compatibility and regression helpers
+│   ├── config/           # Config loading, validation, and hot-reload
 │   ├── deepseek/         # DeepSeek API client, PoW WASM
 │   ├── js/               # Node runtime stream/compat logic
 │   ├── devcapture/       # Dev packet capture module
@@ -428,7 +454,10 @@ ds2api/
 │   ├── server/           # HTTP routing and middleware (chi router)
 │   ├── sse/              # SSE parsing utilities
 │   ├── stream/           # Unified stream consumption engine
+│   ├── testsuite/        # End-to-end testsuite framework and case orchestration
+│   ├── translatorcliproxy/ # CLIProxy bridge and stream writer components
 │   ├── util/             # Common utilities
+│   ├── version/          # Version parsing/comparison and tag normalization
 │   └── webui/            # WebUI static file serving and auto-build
 ├── webui/                # React WebUI source (Vite + Tailwind)
 │   └── src/
@@ -440,6 +469,7 @@ ds2api/
 │   └── build-webui.sh    # Manual WebUI build script
 ├── tests/
+│   ├── compat/           # Compatibility fixtures and expected outputs
 │   ├── node/             # Node-side unit tests (chat-stream / tool-sieve)
 │   └── scripts/          # Unified test script entrypoints (unit/e2e)
 ├── docs/                 # Deployment / contributing / testing docs
 ├── static/admin/         # WebUI build output (not committed to Git)
````
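The diagram's "Account Pool + Queue (in-flight slots + wait queue)" node can be approximated with a counting semaphore: a buffered channel whose capacity is the in-flight limit, where blocked senders form the wait queue. This is a hedged sketch under that assumption; the type and method names are illustrative, not the project's actual pool (which also rotates accounts).

```go
package main

import "fmt"

// slotPool limits concurrent in-flight requests; callers that cannot
// get a slot block on the channel send, forming the wait queue.
type slotPool struct{ slots chan struct{} }

func newSlotPool(n int) *slotPool {
	return &slotPool{slots: make(chan struct{}, n)}
}

// acquire blocks until an in-flight slot is free.
func (p *slotPool) acquire() { p.slots <- struct{}{} }

// release frees a slot, waking one queued waiter if any.
func (p *slotPool) release() { <-p.slots }

func main() {
	p := newSlotPool(2) // at most 2 requests in flight at once
	done := make(chan int)
	for i := 0; i < 4; i++ {
		go func(id int) {
			p.acquire()
			defer p.release()
			done <- id // simulate finishing an upstream call
		}(i)
	}
	for i := 0; i < 4; i++ {
		<-done
	}
	fmt.Println("all 4 requests completed through 2 slots")
}
```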
**Contributing guide (English)** · file header missing in the capture

````diff
@@ -8,7 +8,7 @@ Thanks for your interest in contributing to DS2API!
 ### Prerequisites
 
-- Go 1.24+
+- Go 1.26+
 - Node.js 20+ (for WebUI development)
 - npm (bundled with Node.js)
@@ -94,6 +94,7 @@ Manually build WebUI to `static/admin/`:
 ```text
 ds2api/
+├── app/                  # Shared HTTP handler assembly (local + serverless)
 ├── cmd/
 │   ├── ds2api/           # Local/container entrypoint
 │   └── ds2api-tests/     # End-to-end testsuite entrypoint
@@ -110,8 +111,8 @@ ds2api/
 │   ├── admin/            # Admin API handlers
 │   ├── auth/             # Auth and JWT
 │   ├── claudeconv/       # Claude message conversion
-│   ├── compat/           # Compatibility helpers
-│   ├── config/           # Config loading and hot-reload
+│   ├── compat/           # Go-version compatibility and regression helpers
+│   ├── config/           # Config loading, validation, and hot-reload
 │   ├── deepseek/         # DeepSeek client, PoW WASM
 │   ├── js/               # Node runtime stream/compat logic
 │   ├── devcapture/       # Dev packet capture
@@ -120,8 +121,10 @@ ds2api/
 │   ├── server/           # HTTP routing (chi router)
 │   ├── sse/              # SSE parsing utilities
 │   ├── stream/           # Unified stream consumption engine
-│   ├── testsuite/        # Testsuite core logic
+│   ├── testsuite/        # Testsuite framework and scenario orchestration
+│   ├── translatorcliproxy/ # CLIProxy bridge and stream writer
 │   ├── util/             # Common utilities
 │   ├── version/          # Version parsing and comparison
 │   └── webui/            # WebUI static hosting
 ├── webui/                # React WebUI source
 │   └── src/
@@ -130,7 +133,10 @@ ds2api/
 │   ├── components/       # Shared components
 │   └── locales/          # Language packs
 ├── scripts/              # Build and test scripts
-├── tests/                # Unit tests, Node tests, and end-to-end tests
+├── tests/
+│   ├── compat/           # Compatibility fixtures and expected outputs
+│   ├── node/             # Node-side unit tests
+│   └── scripts/          # Test script entrypoints (unit/e2e)
 ├── plans/                # Plans, gates, and manual smoke-test records
 ├── static/admin/         # WebUI build output (not committed)
 ├── Dockerfile            # Multi-stage build
````
**Contributing guide (Chinese; content translated)** · file header missing in the capture

````diff
@@ -8,7 +8,7 @@
 ### Prerequisites
 
-- Go 1.24+
+- Go 1.26+
 - Node.js 20+ (for WebUI development)
 - npm (bundled with Node.js)
@@ -94,6 +94,7 @@ docker-compose -f docker-compose.dev.yml up
 ```text
 ds2api/
+├── app/                  # Unified HTTP handler assembly (local + serverless)
 ├── cmd/
 │   ├── ds2api/           # Local/container entrypoint
 │   └── ds2api-tests/     # End-to-end testsuite entrypoint
@@ -110,8 +111,8 @@ ds2api/
 │   ├── admin/            # Admin API handlers
 │   ├── auth/             # Auth and JWT
 │   ├── claudeconv/       # Claude message format conversion
-│   ├── compat/           # Compatibility helpers
-│   ├── config/           # Config loading and hot-reload
+│   ├── compat/           # Go-version compatibility and regression-test helpers
+│   ├── config/           # Config loading, validation, and hot-reload
 │   ├── deepseek/         # DeepSeek client, PoW WASM
 │   ├── js/               # Node runtime streaming/compat logic
 │   ├── devcapture/       # Dev packet capture
@@ -120,8 +121,10 @@ ds2api/
 │   ├── server/           # HTTP routing (chi router)
 │   ├── sse/              # SSE parsing utilities
 │   ├── stream/           # Unified stream consumption engine
-│   ├── testsuite/        # Testsuite core logic
+│   ├── testsuite/        # Testsuite framework and scenario orchestration
+│   ├── translatorcliproxy/ # CLIProxy bridge and stream writing
 │   ├── util/             # Common utilities
 │   ├── version/          # Version parsing and comparison
 │   └── webui/            # WebUI static hosting
 ├── webui/                # React WebUI source
 │   └── src/
@@ -130,7 +133,10 @@ ds2api/
 │   ├── components/       # Shared components
 │   └── locales/          # Language packs
 ├── scripts/              # Build and test scripts
-├── tests/                # Unit tests, Node tests, and end-to-end tests
+├── tests/
+│   ├── compat/           # Compatibility fixtures and expected outputs
+│   ├── node/             # Node-side unit tests
+│   └── scripts/          # Test script entrypoints (unit/e2e)
 ├── plans/                # Plans, gates, and manual smoke-test records
 ├── static/admin/         # WebUI build output (not committed)
 ├── Dockerfile            # Multi-stage build
````
**Deployment guide (English)** · file header missing in the capture

````diff
@@ -24,7 +24,7 @@ This guide covers all deployment methods for the current Go-based codebase.
 | Dependency | Minimum Version | Notes |
 | --- | --- | --- |
-| Go | 1.24+ | Build backend |
+| Go | 1.26+ | Build backend |
 | Node.js | 20+ | Only needed to build WebUI locally |
 | npm | Bundled with Node.js | Install WebUI dependencies |
@@ -111,8 +111,9 @@ go build -o ds2api ./cmd/ds2api
 ### 2.1 Basic Steps
 
 ```bash
-# Copy env template
+# Copy env template and config file
 cp .env.example .env
+cp config.example.json config.json
 
 # Edit .env and set at least:
 # DS2API_ADMIN_KEY=your-admin-key
@@ -248,6 +249,7 @@ VERCEL_TEAM_ID=team_xxxxxxxxxxxx # optional for personal accounts
 | `DS2API_ACCOUNT_QUEUE_SIZE` | Alias (legacy compat) | — |
 | `DS2API_GLOBAL_MAX_INFLIGHT` | Global inflight limit | `recommended_concurrency` |
 | `DS2API_MAX_INFLIGHT` | Alias (legacy compat) | — |
+| `DS2API_ENV_WRITEBACK` | When `DS2API_CONFIG_JSON` is present, auto-write to `DS2API_CONFIG_PATH` and switch to file-backed mode after success (`1/true/yes/on`) | Disabled |
 | `DS2API_VERCEL_INTERNAL_SECRET` | Hybrid streaming internal auth | Falls back to `DS2API_ADMIN_KEY` |
 | `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | Stream lease TTL | `900` |
 | `VERCEL_TOKEN` | Vercel sync token | — |
@@ -399,7 +401,7 @@ cp config.example.json config.json
 docker pull ghcr.io/cjackhwang/ds2api:latest
 
 # specific version (example)
-docker pull ghcr.io/cjackhwang/ds2api:v2.1.2
+docker pull ghcr.io/cjackhwang/ds2api:v3.0.0
 ```
 
 ---
````
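Several boolean flags in the tables above (e.g. `DS2API_ENV_WRITEBACK`) document their accepted values as `1/true/yes/on`. A small sketch of that parsing rule, assuming case-insensitive matching with surrounding whitespace ignored; the helper name is an illustration, not the project's API.

```go
package main

import (
	"fmt"
	"strings"
)

// truthy reports whether an env value enables a flag, following the
// documented accepted set: 1 / true / yes / on.
func truthy(v string) bool {
	switch strings.ToLower(strings.TrimSpace(v)) {
	case "1", "true", "yes", "on":
		return true
	}
	return false
}

func main() {
	fmt.Println(truthy("YES")) // true
	fmt.Println(truthy("0"))   // false
	fmt.Println(truthy(""))    // false: unset means disabled
}
```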
**Deployment guide (Chinese; content translated)** · file header missing in the capture

````diff
@@ -24,7 +24,7 @@
 | Dependency | Minimum Version | Notes |
 | --- | --- | --- |
-| Go | 1.24+ | Build backend |
+| Go | 1.26+ | Build backend |
 | Node.js | 20+ | Only when building the WebUI locally |
 | npm | Bundled with Node.js | Install WebUI dependencies |
@@ -111,8 +111,9 @@ go build -o ds2api ./cmd/ds2api
 ### 2.1 Basic Steps
 
 ```bash
-# Copy the env template
+# Copy the env template and config file
 cp .env.example .env
+cp config.example.json config.json
 
 # Edit .env (use your own strong password) and set at least:
 # DS2API_ADMIN_KEY=your-admin-key
@@ -248,6 +249,7 @@ VERCEL_TEAM_ID=team_xxxxxxxxxxxx # may be left empty for personal accounts
 | `DS2API_ACCOUNT_QUEUE_SIZE` | Same as above (legacy alias) | — |
 | `DS2API_GLOBAL_MAX_INFLIGHT` | Global concurrency limit | `recommended_concurrency` |
 | `DS2API_MAX_INFLIGHT` | Same as above (legacy alias) | — |
+| `DS2API_ENV_WRITEBACK` | When `DS2API_CONFIG_JSON` is detected, auto-write to `DS2API_CONFIG_PATH` and switch to file mode on success (`1/true/yes/on`) | Disabled |
 | `DS2API_VERCEL_INTERNAL_SECRET` | Hybrid streaming internal auth | Falls back to `DS2API_ADMIN_KEY` |
 | `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | Stream lease TTL | `900` |
 | `VERCEL_TOKEN` | Vercel sync token | — |
@@ -399,7 +401,7 @@ cp config.example.json config.json
 docker pull ghcr.io/cjackhwang/ds2api:latest
 
 # specific version (example)
-docker pull ghcr.io/cjackhwang/ds2api:v2.1.2
+docker pull ghcr.io/cjackhwang/ds2api:v3.0.0
 ```
 
 ---
````
**docs/DeepSeekSSE流格式字段分析-2026-04-03.md** · new file, 82 lines (content translated from Chinese)

````markdown
# DeepSeek SSE Stream Format Field Analysis (2026-04-03)

> Date: 2026-04-03 (UTC)
>
> Sample: `tests/raw_stream_samples/guangzhou-weather-reasoner-search-20260403/upstream.stream.sse`
>
> Model: `deepseek-reasoner-search` (search + thinking)

## 1. SSE Event-Layer Structure

The raw stream consists of standard SSE frames, typically:

```text
event: <type>
data: <json or text>

```

Main `event` types in the sample:

- `ready`: returns the request/response message IDs once the stream is established.
- `update_session`: session timestamp update.
- `finish`: end of the streaming phase.
- (no `event` field): defaults to a message event; `data:` carries the main incremental payload.

## 2. Common Fields in the `data` JSON

Upstream increments are mostly JSON-Patch-style objects:

- `p` (path): field path, e.g. `response/fragments/-1/content`.
- `o` (op, optional): operation type; common values are `SET` / `APPEND` / `BATCH`.
- `v` (value): the value (may be a string, boolean, object, or array).

Examples (semantic):

- `{"p":"response/fragments/-1/content","o":"APPEND","v":"..."}`
- `{"p":"response/fragments/-16/status","v":"FINISHED"}`
- `{"p":"response/status","o":"SET","v":"FINISHED"}`

## 3. Key Paths in Search + Thinking Scenarios

### 3.1 Text content

- `response/fragments/<idx>/content`
- `response/content`
- `response/thinking_content`
- `response/fragments` (`APPEND` + fragment array)

### 3.2 Search-related

- `response/fragments/<idx>/results` (retrieval result array)
- `response/search_status` (retrieval status; recommended to skip in display)

### 3.3 Status-related (important)

- `response/status = FINISHED`: the **final end-of-stream signal** (must be kept for termination detection)
- `response/fragments/<idx>/status = FINISHED`: **fragment-level status** (high-frequency; recommended to skip in output)
- `response/quasi_status`: intermediate status (recommended to skip in output)

## 4. Root Cause of the Leak (repeated FINISHED)

With search + thinking models, `response/fragments/<idx>/status` emits `FINISHED` for many different `<idx>` values (e.g. `-1/-2/-3/-16...`).

If only a fixed small set of indices is filtered (e.g. only `-1/-2/-3`), the statuses for other indices pass through as plain text, and the frontend shows:

- `FINISHEDFINISHEDFINISHED...`

## 5. Adaptation (already implemented)

1. Skip all `response/fragments/-?\d+/status`.
2. Keep `response/status=FINISHED` as the true termination signal.
3. Continuously replay all samples through a standalone simulation tool as a regression gate:

```bash
./tests/scripts/run-raw-stream-sim.sh
```

## 6. Suggested Follow-ups

- Add samples for other models (`deepseek-chat-search` / non-search / non-thinking).
- Add abnormal samples (rate limiting, interruption, content_filter, empty results).
- Add field-coverage statistics to the simulation report (path frequency, event frequency, termination-path hit rate).
````
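The skip/keep rules described above (sections 3–5) can be sketched as a small patch consumer. This is a simplified illustration, not the project's implementation: helper names are hypothetical, and real `APPEND`/`BATCH` handling is richer than shown.

```go
package main

import (
	"encoding/json"
	"fmt"
	"regexp"
	"strings"
)

// patch mirrors the p/o/v objects described in the analysis.
type patch struct {
	P string          `json:"p"`
	O string          `json:"o"`
	V json.RawMessage `json:"v"`
}

// fragStatus matches fragment-level status paths such as
// response/fragments/-16/status, which must never reach the output text.
var fragStatus = regexp.MustCompile(`^response/fragments/-?\d+/status$`)

// consume applies a slice of data payloads, accumulating visible text
// and detecting the true end-of-stream signal.
func consume(datas []string) (text string, finished bool) {
	var b strings.Builder
	for _, d := range datas {
		var p patch
		if err := json.Unmarshal([]byte(d), &p); err != nil {
			continue // non-JSON frames (e.g. ready payloads) are ignored here
		}
		switch {
		case fragStatus.MatchString(p.P):
			// fragment-level FINISHED: drop, never render as text
		case p.P == "response/status":
			var s string
			if json.Unmarshal(p.V, &s) == nil && s == "FINISHED" {
				finished = true // the real termination signal
			}
		case strings.HasSuffix(p.P, "/content"):
			var s string
			if json.Unmarshal(p.V, &s) == nil {
				b.WriteString(s) // content increments are concatenated
			}
		}
	}
	return b.String(), finished
}

func main() {
	text, done := consume([]string{
		`{"p":"response/fragments/-1/content","o":"APPEND","v":"Hello"}`,
		`{"p":"response/fragments/-16/status","v":"FINISHED"}`,
		`{"p":"response/fragments/-1/content","o":"APPEND","v":", world"}`,
		`{"p":"response/status","o":"SET","v":"FINISHED"}`,
	})
	fmt.Printf("%s | finished=%v\n", text, done) // Hello, world | finished=true
}
```

Note how the fragment-level `FINISHED` at index `-16` is dropped instead of leaking into the text, while the top-level `response/status` still terminates the stream.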
**Testing guide (Chinese; content translated)** · file header missing in the capture

````diff
@@ -226,6 +226,17 @@ node --test tests/node/stream-tool-sieve.test.js
 go run ./cmd/ds2api-tests --no-preflight
 ```
 
+### Run the raw-stream simulation (standalone tool)
+
+```bash
+./tests/scripts/run-raw-stream-sim.sh
+```
+
+Notes:
+- The tool replays every sample under `tests/raw_stream_samples`, parsing them 1:1 in upstream SSE order.
+- By default it asserts that no `FINISHED` text leaks and that a termination signal is present.
+- Results are written to `artifacts/raw-stream-sim/*.json` for reuse by other test scripts and troubleshooting flows.
+
 ### Specify the output directory and timeout
 
 ```bash
````
**go.mod** · 16 changed lines

````diff
@@ -1,17 +1,25 @@
 module ds2api
 
-go 1.24
+go 1.26.0
 
 require (
 	github.com/andybalholm/brotli v1.0.6
 	github.com/go-chi/chi/v5 v5.2.3
 	github.com/google/uuid v1.6.0
-	github.com/refraction-networking/utls v1.8.1
+	github.com/refraction-networking/utls v1.8.2
 	github.com/tetratelabs/wazero v1.9.0
 )
 
 require (
 	github.com/klauspost/compress v1.17.4 // indirect
-	golang.org/x/crypto v0.36.0 // indirect
-	golang.org/x/sys v0.31.0 // indirect
+	github.com/router-for-me/CLIProxyAPI/v6 v6.9.8 // indirect
+	github.com/sirupsen/logrus v1.9.3 // indirect
+	github.com/tidwall/gjson v1.18.0 // indirect
+	github.com/tidwall/match v1.1.1 // indirect
+	github.com/tidwall/pretty v1.2.0 // indirect
+	github.com/tidwall/sjson v1.2.5 // indirect
+	golang.org/x/crypto v0.45.0 // indirect
+	golang.org/x/net v0.47.0 // indirect
+	golang.org/x/sys v0.38.0 // indirect
+	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
````
31 go.sum
@@ -1,16 +1,47 @@
github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI=
github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE=
github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/router-for-me/CLIProxyAPI/v6 v6.9.8 h1:O65R38THenp8E1IK0paQlOfop3Y6UYlfqSdLlepidSY=
github.com/router-for-me/CLIProxyAPI/v6 v6.9.8/go.mod h1:P1jsIPFXorYGuS2N/3BlZYkpRKi/z7+oR3+1tdG0u4k=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I=
github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
@@ -60,16 +60,10 @@ func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Acc
        return acc, true
    }

    if acc, ok := p.tryAcquire(exclude, true); ok {
        return acc, true
    }
    if acc, ok := p.tryAcquire(exclude, false); ok {
        return acc, true
    }
    return config.Account{}, false
    return p.tryAcquire(exclude)
}

func (p *Pool) tryAcquire(exclude map[string]bool, requireToken bool) (config.Account, bool) {
func (p *Pool) tryAcquire(exclude map[string]bool) (config.Account, bool) {
    for i := 0; i < len(p.queue); i++ {
        id := p.queue[i]
        if exclude[id] || !p.canAcquireIDLocked(id) {
@@ -79,9 +73,6 @@ func (p *Pool) tryAcquire(exclude map[string]bool, requireToken bool) (config.Ac
        if !ok {
            continue
        }
        if requireToken && acc.Token == "" {
            continue
        }
        p.inUse[id]++
        p.bumpQueue(id)
        return acc, true
@@ -215,6 +215,33 @@ func TestPoolDropsLegacyTokenOnlyAccountOnLoad(t *testing.T) {
    }
}

func TestPoolAcquireRotatesIntoTokenlessAccounts(t *testing.T) {
    t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1")
    t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "")
    t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "")
    t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "")
    t.Setenv("DS2API_CONFIG_JSON", `{
        "keys":["k1"],
        "accounts":[
            {"email":"acc1@example.com","token":"token1"},
            {"email":"acc2@example.com","token":""},
            {"email":"acc3@example.com","token":""}
        ]
    }`)

    pool := NewPool(config.LoadStore())
    for i, want := range []string{"acc1@example.com", "acc2@example.com", "acc3@example.com"} {
        acc, ok := pool.Acquire("", nil)
        if !ok {
            t.Fatalf("expected acquire success at step %d", i+1)
        }
        if got := acc.Identifier(); got != want {
            t.Fatalf("unexpected account at step %d: got %q want %q", i+1, got, want)
        }
        pool.Release(acc.Identifier())
    }
}

func TestPoolAcquireWaitQueuesAndSucceedsAfterRelease(t *testing.T) {
    pool := newSingleAccountPoolForTest(t, "1")
    first, ok := pool.Acquire("", nil)
@@ -24,6 +24,10 @@ type ConfigReader interface {
    ClaudeMapping() map[string]string
}

type OpenAIChatRunner interface {
    ChatCompletions(w http.ResponseWriter, r *http.Request)
}

var _ AuthResolver = (*auth.Resolver)(nil)
var _ DeepSeekCaller = (*deepseek.Client)(nil)
var _ ConfigReader = (*config.Store)(nil)
97 internal/adapter/claude/handler_helpers_misc.go Normal file
@@ -0,0 +1,97 @@
package claude

import (
    "fmt"
    "strings"
)

func hasSystemMessage(messages []any) bool {
    for _, m := range messages {
        msg, ok := m.(map[string]any)
        if ok && msg["role"] == "system" {
            return true
        }
    }
    return false
}

func extractClaudeToolNames(tools []any) []string {
    out := make([]string, 0, len(tools))
    for _, t := range tools {
        m, ok := t.(map[string]any)
        if !ok {
            continue
        }
        name, _, _ := extractClaudeToolMeta(m)
        if name != "" {
            out = append(out, name)
        }
    }
    return out
}

func extractClaudeToolMeta(m map[string]any) (string, string, any) {
    name, _ := m["name"].(string)
    desc, _ := m["description"].(string)
    schemaObj := m["input_schema"]
    if schemaObj == nil {
        schemaObj = m["parameters"]
    }

    if fn, ok := m["function"].(map[string]any); ok {
        if strings.TrimSpace(name) == "" {
            name, _ = fn["name"].(string)
        }
        if strings.TrimSpace(desc) == "" {
            desc, _ = fn["description"].(string)
        }
        if schemaObj == nil {
            if v, ok := fn["input_schema"]; ok {
                schemaObj = v
            }
        }
        if schemaObj == nil {
            if v, ok := fn["parameters"]; ok {
                schemaObj = v
            }
        }
    }
    return strings.TrimSpace(name), strings.TrimSpace(desc), schemaObj
}

func toMessageMaps(v any) []map[string]any {
    arr, ok := v.([]any)
    if !ok {
        return nil
    }
    out := make([]map[string]any, 0, len(arr))
    for _, item := range arr {
        if m, ok := item.(map[string]any); ok {
            out = append(out, m)
        }
    }
    return out
}

func extractMessageContent(v any) string {
    switch x := v.(type) {
    case string:
        return x
    case []any:
        parts := make([]string, 0, len(x))
        for _, it := range x {
            parts = append(parts, fmt.Sprintf("%v", it))
        }
        return strings.Join(parts, "\n")
    default:
        return fmt.Sprintf("%v", x)
    }
}

func cloneMap(in map[string]any) map[string]any {
    out := make(map[string]any, len(in))
    for k, v := range in {
        out[k] = v
    }
    return out
}
@@ -1,85 +1,126 @@
package claude

import (
    "bytes"
    "encoding/json"
    "fmt"
    "io"
    "net/http"
    "net/http/httptest"
    "strings"
    "time"

    "ds2api/internal/auth"
    "ds2api/internal/config"
    claudefmt "ds2api/internal/format/claude"
    "ds2api/internal/sse"
    streamengine "ds2api/internal/stream"
    "ds2api/internal/translatorcliproxy"
    "ds2api/internal/util"

    sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
)

func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
    if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" {
        r.Header.Set("anthropic-version", "2023-06-01")
    }
    a, err := h.Auth.Determine(r)
    if err != nil {
        status := http.StatusUnauthorized
        detail := err.Error()
        if err == auth.ErrNoAccount {
            status = http.StatusTooManyRequests
        }
        writeClaudeError(w, status, detail)
    if h.OpenAI == nil {
        writeClaudeError(w, http.StatusInternalServerError, "OpenAI proxy backend unavailable.")
        return
    }
    defer h.Auth.Release(a)
    if h.proxyViaOpenAI(w, r, h.Store) {
        return
    }
    writeClaudeError(w, http.StatusBadGateway, "Failed to proxy Claude request.")
}

func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, store ConfigReader) bool {
    raw, err := io.ReadAll(r.Body)
    if err != nil {
        writeClaudeError(w, http.StatusBadRequest, "invalid body")
        return true
    }
    var req map[string]any
    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
    if err := json.Unmarshal(raw, &req); err != nil {
        writeClaudeError(w, http.StatusBadRequest, "invalid json")
        return
        return true
    }
    norm, err := normalizeClaudeRequest(h.Store, req)
    if err != nil {
        writeClaudeError(w, http.StatusBadRequest, err.Error())
        return
    }
    stdReq := norm.Standard
    model, _ := req["model"].(string)
    stream := util.ToBool(req["stream"])

    sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
    if err != nil {
        writeClaudeError(w, http.StatusUnauthorized, "invalid token.")
        return
    // Preserve claude_mapping (fast/slow/opus routing) while proxying via OpenAI.
    translateModel := model
    if store != nil {
        if norm, normErr := normalizeClaudeRequest(store, cloneMap(req)); normErr == nil && strings.TrimSpace(norm.Standard.ResolvedModel) != "" {
            translateModel = strings.TrimSpace(norm.Standard.ResolvedModel)
        }
    }
    pow, err := h.DS.GetPow(r.Context(), a, 3)
    if err != nil {
        writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW")
        return
    }
    requestPayload := stdReq.CompletionPayload(sessionID)
    resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3)
    if err != nil {
        writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.")
        return
    }
    if resp.StatusCode != http.StatusOK {
        defer resp.Body.Close()
        body, _ := io.ReadAll(resp.Body)
        writeClaudeError(w, http.StatusInternalServerError, string(body))
        return
    translatedReq := translatorcliproxy.ToOpenAI(sdktranslator.FormatClaude, translateModel, raw, stream)

    isVercelPrepare := strings.TrimSpace(r.URL.Query().Get("__stream_prepare")) == "1"
    isVercelRelease := strings.TrimSpace(r.URL.Query().Get("__stream_release")) == "1"

    if isVercelRelease {
        proxyReq := r.Clone(r.Context())
        proxyReq.URL.Path = "/v1/chat/completions"
        proxyReq.Body = io.NopCloser(bytes.NewReader(raw))
        proxyReq.ContentLength = int64(len(raw))
        rec := httptest.NewRecorder()
        h.OpenAI.ChatCompletions(rec, proxyReq)
        res := rec.Result()
        defer res.Body.Close()
        body, _ := io.ReadAll(res.Body)
        for k, vv := range res.Header {
            for _, v := range vv {
                w.Header().Add(k, v)
            }
        }
        w.WriteHeader(res.StatusCode)
        _, _ = w.Write(body)
        return true
    }

    if stdReq.Stream {
        h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
        return
    proxyReq := r.Clone(r.Context())
    proxyReq.URL.Path = "/v1/chat/completions"
    proxyReq.Body = io.NopCloser(bytes.NewReader(translatedReq))
    proxyReq.ContentLength = int64(len(translatedReq))

    if stream && !isVercelPrepare {
        w.Header().Set("Content-Type", "text/event-stream")
        w.Header().Set("Cache-Control", "no-cache, no-transform")
        w.Header().Set("Connection", "keep-alive")
        w.Header().Set("X-Accel-Buffering", "no")
        streamWriter := translatorcliproxy.NewOpenAIStreamTranslatorWriter(w, sdktranslator.FormatClaude, model, raw, translatedReq)
        h.OpenAI.ChatCompletions(streamWriter, proxyReq)
        return true
    }
    result := sse.CollectStream(resp, stdReq.Thinking, true)
    respBody := claudefmt.BuildMessageResponse(
        fmt.Sprintf("msg_%d", time.Now().UnixNano()),
        stdReq.ResponseModel,
        norm.NormalizedMessages,
        result.Thinking,
        result.Text,
        stdReq.ToolNames,
    )
    writeJSON(w, http.StatusOK, respBody)

    rec := httptest.NewRecorder()
    h.OpenAI.ChatCompletions(rec, proxyReq)
    res := rec.Result()
    defer res.Body.Close()
    body, _ := io.ReadAll(res.Body)
    if res.StatusCode < 200 || res.StatusCode >= 300 {
        for k, vv := range res.Header {
            for _, v := range vv {
                w.Header().Add(k, v)
            }
        }
        w.WriteHeader(res.StatusCode)
        _, _ = w.Write(body)
        return true
    }
    if isVercelPrepare {
        for k, vv := range res.Header {
            for _, v := range vv {
                w.Header().Add(k, v)
            }
        }
        w.WriteHeader(res.StatusCode)
        _, _ = w.Write(body)
        return true
    }
    converted := translatorcliproxy.FromOpenAINonStream(sdktranslator.FormatClaude, model, raw, translatedReq, body)
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(http.StatusOK)
    _, _ = w.Write(converted)
    return true
}

func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) {
@@ -15,9 +15,10 @@ import (
var writeJSON = util.WriteJSON

type Handler struct {
    Store ConfigReader
    Auth  AuthResolver
    DS    DeepSeekCaller
    Store  ConfigReader
    Auth   AuthResolver
    DS     DeepSeekCaller
    OpenAI OpenAIChatRunner
}

var (
@@ -225,6 +225,47 @@ func TestNormalizeClaudeMessagesToolResultNonTextPayloadStringified(t *testing.T
    }
}

func TestNormalizeClaudeMessagesBackfillsToolResultCallIDByName(t *testing.T) {
    msgs := []any{
        map[string]any{
            "role": "assistant",
            "content": []any{
                map[string]any{
                    "type":  "tool_use",
                    "name":  "search_web",
                    "input": map[string]any{"query": "latest"},
                },
            },
        },
        map[string]any{
            "role": "user",
            "content": []any{
                map[string]any{
                    "type":    "tool_result",
                    "name":    "search_web",
                    "content": "ok",
                },
            },
        },
    }

    got := normalizeClaudeMessages(msgs)
    if len(got) != 2 {
        t.Fatalf("expected 2 messages, got %#v", got)
    }
    assistant, _ := got[0].(map[string]any)
    tc, _ := assistant["tool_calls"].([]any)
    call, _ := tc[0].(map[string]any)
    callID, _ := call["id"].(string)
    if !strings.HasPrefix(callID, "call_claude_") {
        t.Fatalf("expected generated call id, got %#v", call)
    }
    toolMsg, _ := got[1].(map[string]any)
    if toolMsg["tool_call_id"] != callID {
        t.Fatalf("expected tool_result to reuse generated id, got %#v", toolMsg)
    }
}

// ─── buildClaudeToolPrompt ───────────────────────────────────────────

func TestBuildClaudeToolPromptSingleTool(t *testing.T) {
@@ -11,6 +11,11 @@ import (

func normalizeClaudeMessages(messages []any) []any {
    out := make([]any, 0, len(messages))
    state := &claudeToolCallState{
        nameByID:       map[string]string{},
        lastIDByName:   map[string]string{},
        callIDSequence: 0,
    }
    for _, m := range messages {
        msg, ok := m.(map[string]any)
        if !ok {
@@ -44,7 +49,7 @@ func normalizeClaudeMessages(messages []any) []any {
        case "tool_use":
            if role == "assistant" {
                flushText()
                if toolMsg := normalizeClaudeToolUseToAssistant(b); toolMsg != nil {
                if toolMsg := normalizeClaudeToolUseToAssistant(b, state); toolMsg != nil {
                    out = append(out, toolMsg)
                }
                continue
@@ -54,7 +59,7 @@ func normalizeClaudeMessages(messages []any) []any {
            }
        case "tool_result":
            flushText()
            if toolMsg := normalizeClaudeToolResultToToolMessage(b); toolMsg != nil {
            if toolMsg := normalizeClaudeToolResultToToolMessage(b, state); toolMsg != nil {
                out = append(out, toolMsg)
            }
        default:
@@ -119,7 +124,7 @@ func formatClaudeToolResultForPrompt(block map[string]any) string {
    return string(b)
}

func normalizeClaudeToolUseToAssistant(block map[string]any) map[string]any {
func normalizeClaudeToolUseToAssistant(block map[string]any, state *claudeToolCallState) map[string]any {
    if block == nil {
        return nil
    }
@@ -127,13 +132,15 @@ func normalizeClaudeToolUseToAssistant(block map[string]any) map[string]any {
    if name == "" {
        return nil
    }
    callID := strings.TrimSpace(fmt.Sprintf("%v", block["id"]))
    callID := safeStringValue(block["id"])
    if callID == "" {
        callID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
        callID = safeStringValue(block["tool_use_id"])
    }
    if callID == "" {
        callID = "call_claude"
        callID = state.nextID()
    }
    state.nameByID[callID] = name
    state.lastIDByName[strings.ToLower(name)] = callID
    arguments := block["input"]
    if arguments == nil {
        arguments = map[string]any{}
@@ -159,24 +166,34 @@
    }
}

func normalizeClaudeToolResultToToolMessage(block map[string]any) map[string]any {
func normalizeClaudeToolResultToToolMessage(block map[string]any, state *claudeToolCallState) map[string]any {
    if block == nil {
        return nil
    }
    toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
    name := safeStringValue(block["name"])
    toolCallID := safeStringValue(block["tool_use_id"])
    if toolCallID == "" {
        toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"]))
        toolCallID = safeStringValue(block["tool_call_id"])
    }
    if toolCallID == "" {
        toolCallID = "call_claude"
        if name != "" {
            toolCallID = strings.TrimSpace(state.lastIDByName[strings.ToLower(name)])
        }
    }
    if toolCallID == "" {
        toolCallID = state.nextID()
    }
    out := map[string]any{
        "role":         "tool",
        "tool_call_id": toolCallID,
        "content":      normalizeClaudeToolResultContent(block["content"]),
    }
    if name := strings.TrimSpace(fmt.Sprintf("%v", block["name"])); name != "" {
    if name != "" {
        out["name"] = name
        state.nameByID[toolCallID] = name
        state.lastIDByName[strings.ToLower(name)] = toolCallID
    } else if inferred := strings.TrimSpace(state.nameByID[toolCallID]); inferred != "" {
        out["name"] = inferred
    }
    return out
}
@@ -206,94 +223,3 @@ func formatClaudeBlockRaw(block map[string]any) string {
    }
    return string(b)
}

func hasSystemMessage(messages []any) bool {
    for _, m := range messages {
        msg, ok := m.(map[string]any)
        if ok && msg["role"] == "system" {
            return true
        }
    }
    return false
}

func extractClaudeToolNames(tools []any) []string {
    out := make([]string, 0, len(tools))
    for _, t := range tools {
        m, ok := t.(map[string]any)
        if !ok {
            continue
        }
        name, _, _ := extractClaudeToolMeta(m)
        if name != "" {
            out = append(out, name)
        }
    }
    return out
}

func extractClaudeToolMeta(m map[string]any) (string, string, any) {
    name, _ := m["name"].(string)
    desc, _ := m["description"].(string)
    schemaObj := m["input_schema"]
    if schemaObj == nil {
        schemaObj = m["parameters"]
    }

    if fn, ok := m["function"].(map[string]any); ok {
        if strings.TrimSpace(name) == "" {
            name, _ = fn["name"].(string)
        }
        if strings.TrimSpace(desc) == "" {
            desc, _ = fn["description"].(string)
        }
        if schemaObj == nil {
            if v, ok := fn["input_schema"]; ok {
                schemaObj = v
            }
        }
        if schemaObj == nil {
            if v, ok := fn["parameters"]; ok {
                schemaObj = v
            }
        }
    }
    return strings.TrimSpace(name), strings.TrimSpace(desc), schemaObj
}

func toMessageMaps(v any) []map[string]any {
    arr, ok := v.([]any)
    if !ok {
        return nil
    }
    out := make([]map[string]any, 0, len(arr))
    for _, item := range arr {
        if m, ok := item.(map[string]any); ok {
            out = append(out, m)
        }
    }
    return out
}

func extractMessageContent(v any) string {
    switch x := v.(type) {
    case string:
        return x
    case []any:
        parts := make([]string, 0, len(x))
        for _, it := range x {
            parts = append(parts, fmt.Sprintf("%v", it))
        }
        return strings.Join(parts, "\n")
    default:
        return fmt.Sprintf("%v", x)
    }
}

func cloneMap(in map[string]any) map[string]any {
    out := make(map[string]any, len(in))
    for k, v := range in {
        out[k] = v
    }
    return out
}
84 internal/adapter/claude/proxy_vercel_test.go Normal file
@@ -0,0 +1,84 @@
package claude

import (
    "encoding/json"
    "net/http"
    "net/http/httptest"
    "strings"
    "testing"
)

type claudeProxyStoreStub struct {
    mapping map[string]string
}

func (s claudeProxyStoreStub) ClaudeMapping() map[string]string {
    return s.mapping
}

type openAIProxyStub struct {
    status int
    body   string
}

func (s openAIProxyStub) ChatCompletions(w http.ResponseWriter, _ *http.Request) {
    if s.status == 0 {
        s.status = http.StatusOK
    }
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(s.status)
    _, _ = w.Write([]byte(s.body))
}

type openAIProxyCaptureStub struct {
    seenModel string
}

func (s *openAIProxyCaptureStub) ChatCompletions(w http.ResponseWriter, r *http.Request) {
    var req map[string]any
    _ = json.NewDecoder(r.Body).Decode(&req)
    if m, ok := req["model"].(string); ok {
        s.seenModel = m
    }
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(http.StatusOK)
    _, _ = w.Write([]byte(`{"id":"ok","choices":[{"message":{"role":"assistant","content":"ok"}}]}`))
}

func TestClaudeProxyViaOpenAIVercelPreparePassthrough(t *testing.T) {
    h := &Handler{OpenAI: openAIProxyStub{status: 200, body: `{"lease_id":"lease_123","payload":{"a":1}}`}}
    req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages?__stream_prepare=1", strings.NewReader(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":true}`))
    rec := httptest.NewRecorder()

    h.Messages(rec, req)

    if rec.Code != http.StatusOK {
        t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
    }
    var out map[string]any
    if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
        t.Fatalf("expected json response, got err=%v body=%s", err, rec.Body.String())
    }
    if _, ok := out["lease_id"]; !ok {
        t.Fatalf("expected lease_id in prepare passthrough, got=%v", out)
    }
}

func TestClaudeProxyViaOpenAIPreservesClaudeMapping(t *testing.T) {
    openAI := &openAIProxyCaptureStub{}
    h := &Handler{
        Store:  claudeProxyStoreStub{mapping: map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"}},
        OpenAI: openAI,
    }
    req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(`{"model":"claude-3-opus","messages":[{"role":"user","content":"hi"}],"stream":false}`))
    rec := httptest.NewRecorder()

    h.Messages(rec, req)

    if rec.Code != http.StatusOK {
        t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
    }
    if got := strings.TrimSpace(openAI.seenModel); got != "deepseek-reasoner" {
        t.Fatalf("expected mapped proxy model deepseek-reasoner, got %q", got)
    }
}
@@ -26,6 +26,7 @@ type claudeStreamRuntime struct {
    messageID    string
    thinking     strings.Builder
    text         strings.Builder
    outputTokens int

    nextBlockIndex    int
    thinkingBlockOpen bool
@@ -66,6 +67,9 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
    if !parsed.Parsed {
        return streamengine.ParsedDecision{}
    }
    if parsed.OutputTokens > 0 {
        s.outputTokens = parsed.OutputTokens
    }
    if parsed.ErrorMessage != "" {
        s.upstreamErr = parsed.ErrorMessage
        return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("upstream_error")}

@@ -108,6 +108,9 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
    }

    outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
    if s.outputTokens > 0 {
        outputTokens = s.outputTokens
    }
    s.send("message_delta", map[string]any{
        "type": "message_delta",
        "delta": map[string]any{
@@ -1,7 +1,6 @@
package claude

import (
    "context"
    "net/http"
    "net/http/httptest"
    "strings"
@@ -9,48 +8,17 @@ import (

    "github.com/go-chi/chi/v5"
    chimw "github.com/go-chi/chi/v5/middleware"

    "ds2api/internal/auth"
)

type streamStatusClaudeAuthStub struct{}
type streamStatusClaudeOpenAIStub struct{}

func (streamStatusClaudeAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) {
    return &auth.RequestAuth{
        UseConfigToken: false,
        DeepSeekToken:  "direct-token",
        CallerID:       "caller:test",
        TriedAccounts:  map[string]bool{},
    }, nil
func (streamStatusClaudeOpenAIStub) ChatCompletions(w http.ResponseWriter, _ *http.Request) {
    w.Header().Set("Content-Type", "text/event-stream")
    w.WriteHeader(http.StatusOK)
    _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-1\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hello\"},\"finish_reason\":null}]}\n\n"))
    _, _ = w.Write([]byte("data: [DONE]\n\n"))
}

func (streamStatusClaudeAuthStub) Release(_ *auth.RequestAuth) {}

type streamStatusClaudeDSStub struct{}

func (streamStatusClaudeDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
    return "session-id", nil
}

func (streamStatusClaudeDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
    return "pow", nil
}

func (streamStatusClaudeDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
    body := "data: {\"p\":\"response/content\",\"v\":\"hello\"}\n" + "data: [DONE]\n"
    return &http.Response{
        StatusCode: http.StatusOK,
        Header:     make(http.Header),
        Body:       ioNopCloser{strings.NewReader(body)},
    }, nil
}

type ioNopCloser struct {
    *strings.Reader
}

func (ioNopCloser) Close() error { return nil }

type streamStatusClaudeStoreStub struct{}

func (streamStatusClaudeStoreStub) ClaudeMapping() map[string]string {
@@ -73,9 +41,8 @@ func captureClaudeStatusMiddleware(statuses *[]int) func(http.Handler) http.Hand
func TestClaudeMessagesStreamStatusCapturedAs200(t *testing.T) {
    statuses := make([]int, 0, 1)
    h := &Handler{
        Store: streamStatusClaudeStoreStub{},
        Auth:  streamStatusClaudeAuthStub{},
        DS:    streamStatusClaudeDSStub{},
        Store:  streamStatusClaudeStoreStub{},
        OpenAI: streamStatusClaudeOpenAIStub{},
    }
    r := chi.NewRouter()
    r.Use(captureClaudeStatusMiddleware(&statuses))
@@ -83,7 +50,6 @@ func TestClaudeMessagesStreamStatusCapturedAs200(t *testing.T) {

    reqBody := `{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":true}`
    req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(reqBody))
    req.Header.Set("Authorization", "Bearer direct-token")
    req.Header.Set("Content-Type", "application/json")
    rec := httptest.NewRecorder()
    r.ServeHTTP(rec, req)
internal/adapter/claude/tool_call_state.go (new file, 25 lines)
@@ -0,0 +1,25 @@
package claude

import (
    "fmt"
    "strings"
)

type claudeToolCallState struct {
    nameByID       map[string]string
    lastIDByName   map[string]string
    callIDSequence int
}

func (s *claudeToolCallState) nextID() string {
    s.callIDSequence++
    return fmt.Sprintf("call_claude_%d", s.callIDSequence)
}

func safeStringValue(v any) string {
    s, ok := v.(string)
    if !ok {
        return ""
    }
    return strings.TrimSpace(s)
}
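The `nextID` method above hands out deterministic, monotonically increasing call IDs so repeated tool calls in one conversation never collide. A minimal standalone sketch of the same pattern (the `callIDState` name and `call_demo_` prefix are illustrative, not from the repo):

```go
package main

import "fmt"

// callIDState mirrors the sequential-ID idea: a counter bumped once
// per generated call, so each ID within a conversation is unique.
type callIDState struct {
	seq int
}

func (s *callIDState) nextID() string {
	s.seq++
	return fmt.Sprintf("call_demo_%d", s.seq)
}

func main() {
	var s callIDState
	fmt.Println(s.nextID()) // call_demo_1
	fmt.Println(s.nextID()) // call_demo_2
}
```

Because the counter lives in per-request state rather than a global, concurrent requests each get their own sequence without locking.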
@@ -1,11 +1,20 @@
package gemini

import "strings"
import (
    "fmt"
    "strings"
)

const maxGeminiRawPromptChars = 1024

func geminiMessagesFromRequest(req map[string]any) []any {
    out := make([]any, 0, 8)
    toolCallCounter := 0
    nextToolCallID := func() string {
        toolCallCounter++
        return fmt.Sprintf("call_gemini_%d", toolCallCounter)
    }
    lastToolCallIDByName := map[string]string{}
    if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" {
        out = append(out, map[string]any{
            "role": "system",
@@ -61,8 +70,11 @@ func geminiMessagesFromRequest(req map[string]any) []any {
    if name := strings.TrimSpace(asString(fnCall["name"])); name != "" {
        callID := strings.TrimSpace(asString(fnCall["id"]))
        if callID == "" {
            callID = "call_gemini"
            if callID = strings.TrimSpace(asString(fnCall["call_id"])); callID == "" {
                callID = nextToolCallID()
            }
        }
        lastToolCallIDByName[strings.ToLower(name)] = callID
        out = append(out, map[string]any{
            "role": "assistant",
            "tool_calls": []any{
@@ -91,7 +103,10 @@ func geminiMessagesFromRequest(req map[string]any) []any {
        callID = strings.TrimSpace(asString(fnResp["tool_call_id"]))
    }
    if callID == "" {
        callID = "call_gemini"
        callID = strings.TrimSpace(lastToolCallIDByName[strings.ToLower(name)])
    }
    if callID == "" {
        callID = nextToolCallID()
    }
    content := fnResp["response"]
    if content == nil {

@@ -82,3 +82,48 @@ func TestGeminiMessagesFromRequestPreservesUnknownPartAsRawJSONText(t *testing.T
        t.Fatalf("expected raw base64 payload not to be embedded, got %q", content)
    }
}

func TestGeminiMessagesFromRequestBackfillsFunctionResponseCallIDByName(t *testing.T) {
    req := map[string]any{
        "contents": []any{
            map[string]any{
                "role": "model",
                "parts": []any{
                    map[string]any{
                        "functionCall": map[string]any{
                            "name": "search_web",
                            "args": map[string]any{"query": "docs"},
                        },
                    },
                },
            },
            map[string]any{
                "role": "user",
                "parts": []any{
                    map[string]any{
                        "functionResponse": map[string]any{
                            "name":     "search_web",
                            "response": map[string]any{"ok": true},
                        },
                    },
                },
            },
        },
    }

    got := geminiMessagesFromRequest(req)
    if len(got) != 2 {
        t.Fatalf("expected two normalized messages, got %#v", got)
    }
    assistant, _ := got[0].(map[string]any)
    tc, _ := assistant["tool_calls"].([]any)
    call, _ := tc[0].(map[string]any)
    callID, _ := call["id"].(string)
    if !strings.HasPrefix(callID, "call_gemini_") {
        t.Fatalf("expected generated call id prefix, got %#v", call)
    }
    toolMsg, _ := got[1].(map[string]any)
    if toolMsg["tool_call_id"] != callID {
        t.Fatalf("expected tool response to inherit generated call id, tool=%#v call=%#v", toolMsg, call)
    }
}

@@ -24,6 +24,10 @@ type ConfigReader interface {
    ModelAliases() map[string]string
}

type OpenAIChatRunner interface {
    ChatCompletions(w http.ResponseWriter, r *http.Request)
}

var _ AuthResolver = (*auth.Resolver)(nil)
var _ DeepSeekCaller = (*deepseek.Client)(nil)
var _ ConfigReader = (*config.Store)(nil)

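The `var _ AuthResolver = (*auth.Resolver)(nil)` lines above are compile-time assertions that a concrete type still satisfies an interface. A self-contained sketch of the idiom, using an invented `ChatRunner`/`stubRunner` pair rather than the repo's types:

```go
package main

import (
	"fmt"
	"net/http"
)

// ChatRunner is a narrow, single-method interface in the spirit of
// OpenAIChatRunner above: callers depend only on the call they make.
type ChatRunner interface {
	ChatCompletions(w http.ResponseWriter, r *http.Request)
}

type stubRunner struct{}

func (stubRunner) ChatCompletions(w http.ResponseWriter, _ *http.Request) {
	w.WriteHeader(http.StatusOK)
}

// The blank-identifier assignment costs nothing at runtime but fails
// to compile if stubRunner ever stops implementing ChatRunner.
var _ ChatRunner = stubRunner{}

func main() {
	fmt.Println("stubRunner satisfies ChatRunner")
}
```

Keeping the interface on the consumer side (the gemini package) instead of the producer side lets tests substitute stubs without importing the real OpenAI handler.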
@@ -1,70 +1,134 @@
package gemini

import (
    "bytes"
    "encoding/json"
    "io"
    "net/http"
    "net/http/httptest"
    "strings"

    "github.com/go-chi/chi/v5"

    "ds2api/internal/auth"
    "ds2api/internal/sse"
    "ds2api/internal/translatorcliproxy"
    "ds2api/internal/util"

    sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
)

func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) {
    a, err := h.Auth.Determine(r)
    if h.OpenAI == nil {
        writeGeminiError(w, http.StatusInternalServerError, "OpenAI proxy backend unavailable.")
        return
    }
    if h.proxyViaOpenAI(w, r, stream) {
        return
    }
    writeGeminiError(w, http.StatusBadGateway, "Failed to proxy Gemini request.")
}

func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, stream bool) bool {
    raw, err := io.ReadAll(r.Body)
    if err != nil {
        status := http.StatusUnauthorized
        detail := err.Error()
        if err == auth.ErrNoAccount {
            status = http.StatusTooManyRequests
        }
        writeGeminiError(w, status, detail)
        return
        writeGeminiError(w, http.StatusBadRequest, "invalid body")
        return true
    }
    defer h.Auth.Release(a)

    var req map[string]any
    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        writeGeminiError(w, http.StatusBadRequest, "invalid json")
        return
    }

    routeModel := strings.TrimSpace(chi.URLParam(r, "model"))
    stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream)
    if err != nil {
        writeGeminiError(w, http.StatusBadRequest, err.Error())
        return
    }

    sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
    if err != nil {
        if a.UseConfigToken {
            writeGeminiError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
        } else {
            writeGeminiError(w, http.StatusUnauthorized, "Invalid token.")
    translatedReq := translatorcliproxy.ToOpenAI(sdktranslator.FormatGemini, routeModel, raw, stream)
    if !strings.Contains(string(translatedReq), `"stream"`) {
        var reqMap map[string]any
        if json.Unmarshal(translatedReq, &reqMap) == nil {
            reqMap["stream"] = stream
            if b, e := json.Marshal(reqMap); e == nil {
                translatedReq = b
            }
        }
        return
    }
    pow, err := h.DS.GetPow(r.Context(), a, 3)
    if err != nil {
        writeGeminiError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
        return
    }
    payload := stdReq.CompletionPayload(sessionID)
    resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
    if err != nil {
        writeGeminiError(w, http.StatusInternalServerError, "Failed to get completion.")
        return
    }

    if stream {
        h.handleStreamGenerateContent(w, r, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
        return
    isVercelPrepare := strings.TrimSpace(r.URL.Query().Get("__stream_prepare")) == "1"
    isVercelRelease := strings.TrimSpace(r.URL.Query().Get("__stream_release")) == "1"

    if isVercelRelease {
        proxyReq := r.Clone(r.Context())
        proxyReq.URL.Path = "/v1/chat/completions"
        proxyReq.Body = io.NopCloser(bytes.NewReader(raw))
        proxyReq.ContentLength = int64(len(raw))
        rec := httptest.NewRecorder()
        h.OpenAI.ChatCompletions(rec, proxyReq)
        res := rec.Result()
        defer res.Body.Close()
        body, _ := io.ReadAll(res.Body)
        for k, vv := range res.Header {
            for _, v := range vv {
                w.Header().Add(k, v)
            }
        }
        w.WriteHeader(res.StatusCode)
        _, _ = w.Write(body)
        return true
    }
    h.handleNonStreamGenerateContent(w, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)

    proxyReq := r.Clone(r.Context())
    proxyReq.URL.Path = "/v1/chat/completions"
    proxyReq.Body = io.NopCloser(bytes.NewReader(translatedReq))
    proxyReq.ContentLength = int64(len(translatedReq))

    if stream && !isVercelPrepare {
        w.Header().Set("Content-Type", "text/event-stream")
        w.Header().Set("Cache-Control", "no-cache, no-transform")
        w.Header().Set("Connection", "keep-alive")
        w.Header().Set("X-Accel-Buffering", "no")
        streamWriter := translatorcliproxy.NewOpenAIStreamTranslatorWriter(w, sdktranslator.FormatGemini, routeModel, raw, translatedReq)
        h.OpenAI.ChatCompletions(streamWriter, proxyReq)
        return true
    }

    rec := httptest.NewRecorder()
    h.OpenAI.ChatCompletions(rec, proxyReq)
    res := rec.Result()
    defer res.Body.Close()
    body, _ := io.ReadAll(res.Body)
    if res.StatusCode < 200 || res.StatusCode >= 300 {
        for k, vv := range res.Header {
            for _, v := range vv {
                w.Header().Add(k, v)
            }
        }
        writeGeminiErrorFromOpenAI(w, res.StatusCode, body)
        return true
    }
    if isVercelPrepare {
        for k, vv := range res.Header {
            for _, v := range vv {
                w.Header().Add(k, v)
            }
        }
        w.WriteHeader(res.StatusCode)
        _, _ = w.Write(body)
        return true
    }
    converted := translatorcliproxy.FromOpenAINonStream(sdktranslator.FormatGemini, routeModel, raw, translatedReq, body)
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(http.StatusOK)
    _, _ = w.Write(converted)
    return true
}

func writeGeminiErrorFromOpenAI(w http.ResponseWriter, status int, raw []byte) {
    message := strings.TrimSpace(string(raw))
    var parsed map[string]any
    if err := json.Unmarshal(raw, &parsed); err == nil {
        if errObj, ok := parsed["error"].(map[string]any); ok {
            if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
                message = strings.TrimSpace(msg)
            }
        }
    }
    if message == "" {
        message = http.StatusText(status)
    }
    writeGeminiError(w, status, message)
}

func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *http.Response, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {

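`writeGeminiErrorFromOpenAI` above digs `error.message` out of an OpenAI-style error body before re-wrapping it in a Gemini envelope. The extraction step can be sketched on its own (the `extractErrorMessage` name is an illustrative stand-in):

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
)

// extractErrorMessage pulls error.message out of an OpenAI-style error
// body, falling back to the raw body, and finally to the HTTP status text.
func extractErrorMessage(status int, raw []byte) string {
	message := strings.TrimSpace(string(raw))
	var parsed map[string]any
	if err := json.Unmarshal(raw, &parsed); err == nil {
		if errObj, ok := parsed["error"].(map[string]any); ok {
			if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
				message = strings.TrimSpace(msg)
			}
		}
	}
	if message == "" {
		message = http.StatusText(status)
	}
	return message
}

func main() {
	fmt.Println(extractErrorMessage(401, []byte(`{"error":{"message":"invalid api key"}}`))) // invalid api key
	fmt.Println(extractErrorMessage(502, nil))                                              // Bad Gateway
}
```

The layered fallback means a non-JSON upstream body still produces a readable message instead of an empty error field.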
@@ -76,12 +140,12 @@ func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *ht
    }

    result := sse.CollectStream(resp, thinkingEnabled, true)
    writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames))
    writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames, result.OutputTokens))
}

func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string, outputTokens int) map[string]any {
    parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames)
    usage := buildGeminiUsage(finalPrompt, finalThinking, finalText)
    usage := buildGeminiUsage(finalPrompt, finalThinking, finalText, outputTokens)
    return map[string]any{
        "candidates": []map[string]any{
            {
@@ -98,10 +162,14 @@ func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, final
    }
}

func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any {
func buildGeminiUsage(finalPrompt, finalThinking, finalText string, outputTokens int) map[string]any {
    promptTokens := util.EstimateTokens(finalPrompt)
    reasoningTokens := util.EstimateTokens(finalThinking)
    completionTokens := util.EstimateTokens(finalText)
    if outputTokens > 0 {
        completionTokens = outputTokens
        reasoningTokens = 0
    }
    return map[string]any{
        "promptTokenCount":     promptTokens,
        "candidatesTokenCount": reasoningTokens + completionTokens,

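The `buildGeminiUsage` change above prefers the exact output-token count reported by the upstream stream and only falls back to estimation when the upstream reported nothing. The decision can be isolated like this (the 4-characters-per-token estimator is a crude stand-in, not the real `util.EstimateTokens`):

```go
package main

import "fmt"

// estimateTokens is a rough stand-in: about one token per 4 characters.
func estimateTokens(s string) int {
	return (len(s) + 3) / 4
}

// completionTokens prefers an authoritative upstream count; a zero
// value means "not reported", so we estimate from the final text.
func completionTokens(finalText string, upstreamOutputTokens int) int {
	if upstreamOutputTokens > 0 {
		return upstreamOutputTokens
	}
	return estimateTokens(finalText)
}

func main() {
	fmt.Println(completionTokens("hello world!", 0))  // estimated: 3
	fmt.Println(completionTokens("hello world!", 42)) // upstream wins: 42
}
```

Zeroing `reasoningTokens` alongside, as the diff does, avoids double-counting when the upstream figure already covers the whole output.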
@@ -11,9 +11,10 @@ import (
var writeJSON = util.WriteJSON

type Handler struct {
    Store ConfigReader
    Auth AuthResolver
    DS DeepSeekCaller
    Store  ConfigReader
    Auth   AuthResolver
    DS     DeepSeekCaller
    OpenAI OpenAIChatRunner
}

func RegisterRoutes(r chi.Router, h *Handler) {

@@ -64,6 +64,7 @@ type geminiStreamRuntime struct {

    thinking     strings.Builder
    text         strings.Builder
    outputTokens int
}

func newGeminiStreamRuntime(
@@ -103,6 +104,9 @@ func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
    if !parsed.Parsed {
        return streamengine.ParsedDecision{}
    }
    if parsed.OutputTokens > 0 {
        s.outputTokens = parsed.OutputTokens
    }
    if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
        return streamengine.ParsedDecision{Stop: true}
    }
@@ -176,6 +180,6 @@ func (s *geminiStreamRuntime) finalize() {
            },
        },
        "modelVersion":  s.model,
        "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText),
        "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText, s.outputTokens),
    })
}

@@ -61,6 +61,44 @@ func (m testGeminiDS) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ m
    return m.resp, nil
}

type geminiOpenAIErrorStub struct {
    status  int
    body    string
    headers map[string]string
}

func (s geminiOpenAIErrorStub) ChatCompletions(w http.ResponseWriter, _ *http.Request) {
    for k, v := range s.headers {
        w.Header().Set(k, v)
    }
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(s.status)
    _, _ = w.Write([]byte(s.body))
}

type geminiOpenAISuccessStub struct {
    stream bool
    body   string
}

func (s geminiOpenAISuccessStub) ChatCompletions(w http.ResponseWriter, _ *http.Request) {
    if s.stream {
        w.Header().Set("Content-Type", "text/event-stream")
        w.WriteHeader(http.StatusOK)
        _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-1\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hello \"},\"finish_reason\":null}]}\n\n"))
        _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-1\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"world\"},\"finish_reason\":\"stop\"}]}\n\n"))
        _, _ = w.Write([]byte("data: [DONE]\n\n"))
        return
    }
    out := s.body
    if strings.TrimSpace(out) == "" {
        out = `{"id":"chatcmpl-1","object":"chat.completion","choices":[{"index":0,"message":{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"eval_javascript","arguments":"{\"code\":\"1+1\"}"}}]},"finish_reason":"tool_calls"}]}`
    }
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(http.StatusOK)
    _, _ = w.Write([]byte(out))
}

func makeGeminiUpstreamResponse(lines ...string) *http.Response {
    body := strings.Join(lines, "\n")
    if !strings.HasSuffix(body, "\n") {
@@ -98,14 +136,11 @@ func TestGeminiRoutesRegistered(t *testing.T) {
}

func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
    upstream := makeGeminiUpstreamResponse(
        `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
        `data: [DONE]`,
    )
    h := &Handler{
        Store: testGeminiConfig{},
        Auth: testGeminiAuth{},
        DS: testGeminiDS{resp: upstream},
        OpenAI: geminiOpenAISuccessStub{
            body: `{"id":"chatcmpl-1","object":"chat.completion","choices":[{"index":0,"message":{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"eval_javascript","arguments":"{\"code\":\"1+1\"}"}}]},"finish_reason":"tool_calls"}]}`,
        },
    }
    r := chi.NewRouter()
    RegisterRoutes(r, h)
@@ -115,7 +150,6 @@ func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
        "tools":[{"functionDeclarations":[{"name":"eval_javascript","description":"eval","parameters":{"type":"object","properties":{"code":{"type":"string"}}}}]}]
    }`
    req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent", strings.NewReader(body))
    req.Header.Set("Authorization", "Bearer direct-token")
    rec := httptest.NewRecorder()
    r.ServeHTTP(rec, req)
    if rec.Code != http.StatusOK {
@@ -144,11 +178,7 @@ func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
}

func TestGenerateContentMixedToolSnippetAlsoTriggersFunctionCall(t *testing.T) {
    upstream := makeGeminiUpstreamResponse(
        `data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
        `data: [DONE]`,
    )
    h := &Handler{Store: testGeminiConfig{}, Auth: testGeminiAuth{}, DS: testGeminiDS{resp: upstream}}
    h := &Handler{Store: testGeminiConfig{}, OpenAI: geminiOpenAISuccessStub{}}
    r := chi.NewRouter()
    RegisterRoutes(r, h)

@@ -157,7 +187,6 @@ func TestGenerateContentMixedToolSnippetAlsoTriggersFunctionCall(t *testing.T) {
        "tools":[{"functionDeclarations":[{"name":"eval_javascript","description":"eval","parameters":{"type":"object","properties":{"code":{"type":"string"}}}}]}]
    }`
    req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent", strings.NewReader(body))
    req.Header.Set("Authorization", "Bearer direct-token")
    rec := httptest.NewRecorder()
    r.ServeHTTP(rec, req)

@@ -180,38 +209,25 @@ func TestGenerateContentMixedToolSnippetAlsoTriggersFunctionCall(t *testing.T) {
}

func TestStreamGenerateContentEmitsSSE(t *testing.T) {
    upstream := makeGeminiUpstreamResponse(
        `data: {"p":"response/content","v":"hello "}`,
        `data: {"p":"response/content","v":"world"}`,
        `data: [DONE]`,
    )
    h := &Handler{
        Store: testGeminiConfig{},
        Auth: testGeminiAuth{},
        DS: testGeminiDS{resp: upstream},
        Store:  testGeminiConfig{},
        OpenAI: geminiOpenAISuccessStub{stream: true},
    }
    r := chi.NewRouter()
    RegisterRoutes(r, h)

    body := `{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`
    req := httptest.NewRequest(http.MethodPost, "/v1/models/gemini-2.5-pro:streamGenerateContent?alt=sse", strings.NewReader(body))
    req.Header.Set("Authorization", "Bearer direct-token")
    rec := httptest.NewRecorder()
    r.ServeHTTP(rec, req)

    if rec.Code != http.StatusOK {
        t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
    }
    if !strings.Contains(rec.Body.String(), "data: ") {
        t.Fatalf("expected SSE data frames, got body=%s", rec.Body.String())
    }
    if !strings.Contains(rec.Body.String(), `"finishReason":"STOP"`) {
        t.Fatalf("expected stream finish frame, got body=%s", rec.Body.String())
    }

    frames := extractGeminiSSEFrames(t, rec.Body.String())
    if len(frames) == 0 {
        t.Fatalf("expected non-empty sse frames, body=%s", rec.Body.String())
        t.Fatalf("expected non-empty stream frames, body=%s", rec.Body.String())
    }
    last := frames[len(frames)-1]
    candidates, _ := last["candidates"].([]any)
@@ -229,16 +245,61 @@ func TestStreamGenerateContentEmitsSSE(t *testing.T) {
    }
}

func TestGenerateContentOpenAIProxyErrorUsesGeminiEnvelope(t *testing.T) {
    h := &Handler{
        Store: testGeminiConfig{},
        OpenAI: geminiOpenAIErrorStub{
            status: http.StatusUnauthorized,
            body:   `{"error":{"message":"invalid api key"}}`,
            headers: map[string]string{
                "WWW-Authenticate":      `Bearer realm="example"`,
                "Retry-After":           "30",
                "X-RateLimit-Remaining": "0",
            },
        },
    }
    r := chi.NewRouter()
    RegisterRoutes(r, h)

    req := httptest.NewRequest(http.MethodPost, "/v1/models/gemini-2.5-pro:generateContent", strings.NewReader(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`))
    rec := httptest.NewRecorder()
    r.ServeHTTP(rec, req)

    if rec.Code != http.StatusUnauthorized {
        t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String())
    }
    var out map[string]any
    if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
        t.Fatalf("expected json body: %v", err)
    }
    errObj, _ := out["error"].(map[string]any)
    if errObj["status"] != "UNAUTHENTICATED" {
        t.Fatalf("expected Gemini status UNAUTHENTICATED, got=%v", errObj["status"])
    }
    if errObj["message"] != "invalid api key" {
        t.Fatalf("expected parsed error message, got=%v", errObj["message"])
    }
    if got := rec.Header().Get("WWW-Authenticate"); got == "" {
        t.Fatalf("expected WWW-Authenticate header to be preserved")
    }
    if got := rec.Header().Get("Retry-After"); got != "30" {
        t.Fatalf("expected Retry-After header 30, got=%q", got)
    }
    if got := rec.Header().Get("X-RateLimit-Remaining"); got != "0" {
        t.Fatalf("expected X-RateLimit-Remaining header 0, got=%q", got)
    }
}

func extractGeminiSSEFrames(t *testing.T, body string) []map[string]any {
    t.Helper()
    scanner := bufio.NewScanner(strings.NewReader(body))
    out := make([]map[string]any, 0, 4)
    for scanner.Scan() {
        line := strings.TrimSpace(scanner.Text())
        if !strings.HasPrefix(line, "data: ") {
            continue
        raw := line
        if strings.HasPrefix(line, "data: ") {
            raw = strings.TrimSpace(strings.TrimPrefix(line, "data: "))
        }
        raw := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
        if raw == "" {
            continue
        }

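`extractGeminiSSEFrames` above scans an SSE body line by line and decodes every `data: ` payload into a map. A self-contained version of that parsing loop, hedged as a sketch (the `parseSSEFrames` name is invented, and unlike the test helper it also skips the `[DONE]` sentinel):

```go
package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"strings"
)

// parseSSEFrames collects every JSON object carried on "data: " lines,
// skipping blank payloads and the terminal [DONE] sentinel.
func parseSSEFrames(body string) []map[string]any {
	out := make([]map[string]any, 0, 4)
	scanner := bufio.NewScanner(strings.NewReader(body))
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		raw := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
		if raw == "" || raw == "[DONE]" {
			continue
		}
		var frame map[string]any
		if err := json.Unmarshal([]byte(raw), &frame); err == nil {
			out = append(out, frame)
		}
	}
	return out
}

func main() {
	body := "data: {\"v\":\"hello\"}\n\ndata: {\"v\":\"world\"}\n\ndata: [DONE]\n"
	frames := parseSSEFrames(body)
	fmt.Println(len(frames))    // 2
	fmt.Println(frames[1]["v"]) // world
}
```

Non-`data:` lines (comments, `event:` fields, blank keep-alives) fall through the prefix check, which is what makes the loop robust against mixed SSE output.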
internal/adapter/gemini/proxy_vercel_test.go (new file, 42 lines)
@@ -0,0 +1,42 @@
package gemini

import (
    "encoding/json"
    "net/http"
    "net/http/httptest"
    "strings"
    "testing"
)

type openAIProxyStub struct {
    status int
    body   string
}

func (s openAIProxyStub) ChatCompletions(w http.ResponseWriter, _ *http.Request) {
    if s.status == 0 {
        s.status = http.StatusOK
    }
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(s.status)
    _, _ = w.Write([]byte(s.body))
}

func TestGeminiProxyViaOpenAIVercelReleasePassthrough(t *testing.T) {
    h := &Handler{OpenAI: openAIProxyStub{status: 200, body: `{"success":true}`}}
    req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:streamGenerateContent?__stream_release=1", strings.NewReader(`{"lease_id":"lease_123"}`))
    rec := httptest.NewRecorder()

    h.StreamGenerateContent(rec, req)

    if rec.Code != http.StatusOK {
        t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
    }
    var out map[string]any
    if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
        t.Fatalf("expected json response, got err=%v body=%s", err, rec.Body.String())
    }
    if v, ok := out["success"].(bool); !ok || !v {
        t.Fatalf("expected success=true passthrough, got=%v", out)
    }
}
@@ -36,6 +36,7 @@ type chatStreamRuntime struct {
    streamToolNames map[int]string
    thinking        strings.Builder
    text            strings.Builder
    outputTokens    int
}

func newChatStreamRuntime(
@@ -165,12 +166,19 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
    if len(detected.Calls) > 0 || s.toolCallsEmitted {
        finishReason = "tool_calls"
    }
    usage := openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText)
    if s.outputTokens > 0 {
        usage["completion_tokens"] = s.outputTokens
        if prompt, ok := usage["prompt_tokens"].(int); ok {
            usage["total_tokens"] = prompt + s.outputTokens
        }
    }
    s.sendChunk(openaifmt.BuildChatStreamChunk(
        s.completionID,
        s.created,
        s.model,
        []map[string]any{openaifmt.BuildChatStreamFinishChoice(0, finishReason)},
        openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText),
        usage,
    ))
    s.sendDone()
}
@@ -179,7 +187,13 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
    if !parsed.Parsed {
        return streamengine.ParsedDecision{}
    }
    if parsed.ContentFilter || parsed.ErrorMessage != "" {
    if parsed.OutputTokens > 0 {
        s.outputTokens = parsed.OutputTokens
    }
    if parsed.ContentFilter {
        return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReasonHandlerRequested}
    }
    if parsed.ErrorMessage != "" {
        return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("content_filter")}
    }
    if parsed.Stop {

@@ -106,7 +106,18 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re

    finalThinking := result.Thinking
    finalText := sanitizeLeakedOutput(result.Text)
    if writeUpstreamEmptyOutputError(w, result) {
        return
    }
    respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
    if result.OutputTokens > 0 {
        if usage, ok := respBody["usage"].(map[string]any); ok {
            usage["completion_tokens"] = result.OutputTokens
            if prompt, ok := usage["prompt_tokens"].(int); ok {
                usage["total_tokens"] = prompt + result.OutputTokens
            }
        }
    }
    writeJSON(w, http.StatusOK, respBody)
}

@@ -275,6 +275,44 @@ func TestHandleNonStreamFencedToolCallExamplePromotesToolCall(t *testing.T) {
    TestHandleNonStreamFencedToolCallExampleDoesNotPromoteToolCall(t)
}

func TestHandleNonStreamReturns502WhenUpstreamOutputEmpty(t *testing.T) {
    h := &Handler{}
    resp := makeSSEHTTPResponse(
        `data: {"p":"response/content","v":""}`,
        `data: [DONE]`,
    )
    rec := httptest.NewRecorder()

    h.handleNonStream(rec, context.Background(), resp, "cid-empty", "deepseek-chat", "prompt", false, nil)
    if rec.Code != http.StatusBadGateway {
        t.Fatalf("expected status 502 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
    }
    out := decodeJSONBody(t, rec.Body.String())
    errObj, _ := out["error"].(map[string]any)
    if asString(errObj["code"]) != "upstream_empty_output" {
        t.Fatalf("expected code=upstream_empty_output, got %#v", out)
    }
}

func TestHandleNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWithoutOutput(t *testing.T) {
    h := &Handler{}
    resp := makeSSEHTTPResponse(
        `data: {"code":"content_filter"}`,
        `data: [DONE]`,
    )
    rec := httptest.NewRecorder()

    h.handleNonStream(rec, context.Background(), resp, "cid-empty-filtered", "deepseek-chat", "prompt", false, nil)
    if rec.Code != http.StatusBadRequest {
        t.Fatalf("expected status 400 for filtered upstream output, got %d body=%s", rec.Code, rec.Body.String())
    }
    out := decodeJSONBody(t, rec.Body.String())
    errObj, _ := out["error"].(map[string]any)
    if asString(errObj["code"]) != "content_filter" {
        t.Fatalf("expected code=content_filter, got %#v", out)
    }
}

func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
    h := &Handler{}
    resp := makeSSEHTTPResponse(

@@ -23,6 +23,9 @@ var leakedAgentXMLBlockPatterns = []*regexp.Regexp{
    regexp.MustCompile(`(?is)<new_task\b[^>]*>(.*?)</new_task>`),
}

var leakedAgentWrapperTagPattern = regexp.MustCompile(`(?is)</?(?:attempt_completion|ask_followup_question|new_task)\b[^>]*>`)
var leakedAgentWrapperPlusResultOpenPattern = regexp.MustCompile(`(?is)<(?:attempt_completion|ask_followup_question|new_task)\b[^>]*>\s*<result>`)
var leakedAgentResultPlusWrapperClosePattern = regexp.MustCompile(`(?is)</result>\s*</(?:attempt_completion|ask_followup_question|new_task)\b[^>]*>`)
var leakedAgentResultTagPattern = regexp.MustCompile(`(?is)</?result>`)

func sanitizeLeakedOutput(text string) string {
@@ -50,5 +53,18 @@ func sanitizeLeakedAgentXMLBlocks(text string) string {
        return leakedAgentResultTagPattern.ReplaceAllString(submatches[1], "")
    })
}
    // Fallback for truncated output streams: strip any dangling wrapper tags
    // that were not part of a complete block replacement. If we detect leaked
    // wrapper tags, strip only adjacent <result> tags to avoid exposing agent
    // markup without altering unrelated user-visible <result> examples.
    if leakedAgentWrapperTagPattern.MatchString(out) {
        out = leakedAgentWrapperPlusResultOpenPattern.ReplaceAllStringFunc(out, func(match string) string {
            return leakedAgentResultTagPattern.ReplaceAllString(match, "")
        })
        out = leakedAgentResultPlusWrapperClosePattern.ReplaceAllStringFunc(out, func(match string) string {
            return leakedAgentResultTagPattern.ReplaceAllString(match, "")
        })
        out = leakedAgentWrapperTagPattern.ReplaceAllString(out, "")
    }
    return out
}

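The fallback above removes `<result>` tags only where they sit directly against a leaked wrapper tag, then drops the wrapper tags themselves, so standalone `<result>` examples in user-visible text survive. A trimmed standalone sketch of the same three-pass approach, handling only `attempt_completion` for brevity (the `stripDangling` name is illustrative):

```go
package main

import (
	"fmt"
	"regexp"
)

var (
	wrapperTag        = regexp.MustCompile(`(?is)</?attempt_completion\b[^>]*>`)
	wrapperPlusResult = regexp.MustCompile(`(?is)<attempt_completion\b[^>]*>\s*<result>`)
	resultPlusWrapper = regexp.MustCompile(`(?is)</result>\s*</attempt_completion\b[^>]*>`)
	resultTag         = regexp.MustCompile(`(?is)</?result>`)
)

// stripDangling removes leaked wrapper tags; <result> tags are dropped
// only when adjacent to a wrapper, so unrelated examples are preserved.
func stripDangling(text string) string {
	if !wrapperTag.MatchString(text) {
		return text // no leak detected: leave the text untouched
	}
	out := wrapperPlusResult.ReplaceAllStringFunc(text, func(m string) string {
		return resultTag.ReplaceAllString(m, "")
	})
	out = resultPlusWrapper.ReplaceAllStringFunc(out, func(m string) string {
		return resultTag.ReplaceAllString(m, "")
	})
	return wrapperTag.ReplaceAllString(out, "")
}

func main() {
	fmt.Println(stripDangling("Done.<attempt_completion><result>answer"))
	// Done.answer
	fmt.Println(stripDangling("Done.<attempt_completion><result>answer <result>x</result>"))
	// Done.answer <result>x</result>
}
```

Guarding the whole pass on `wrapperTag.MatchString` is what keeps documents with ordinary `<result>` markup, but no leaked agent wrapper, completely unmodified.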
@@ -41,3 +41,28 @@ func TestSanitizeLeakedOutputPreservesStandaloneResultTags(t *testing.T) {
        t.Fatalf("unexpected sanitize result for standalone result tag: %q", got)
    }
}

func TestSanitizeLeakedOutputRemovesDanglingAgentXMLOpeningTags(t *testing.T) {
    raw := "Done.<attempt_completion><result>Some final answer"
    got := sanitizeLeakedOutput(raw)
    if got != "Done.Some final answer" {
        t.Fatalf("unexpected sanitize result for dangling opening tags: %q", got)
    }
}

func TestSanitizeLeakedOutputRemovesDanglingAgentXMLClosingTags(t *testing.T) {
    raw := "Done.Some final answer</result></attempt_completion>"
    got := sanitizeLeakedOutput(raw)
    if got != "Done.Some final answer" {
        t.Fatalf("unexpected sanitize result for dangling closing tags: %q", got)
    }
}

func TestSanitizeLeakedOutputPreservesUnrelatedResultTagsWhenWrapperLeaks(t *testing.T) {
    raw := "Done.<attempt_completion><result>Some final answer\nExample XML: <result>value</result>"
    got := sanitizeLeakedOutput(raw)
    want := "Done.Some final answer\nExample XML: <result>value</result>"
    if got != want {
        t.Fatalf("unexpected sanitize result for mixed leaked wrapper + xml example: %q", got)
    }
}

@@ -114,6 +114,9 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
	}
	result := sse.CollectStream(resp, thinkingEnabled, true)
	sanitizedText := sanitizeLeakedOutput(result.Text)
	if writeUpstreamEmptyOutputError(w, result) {
		return
	}
	textParsed := util.ParseStandaloneToolCallsDetailed(sanitizedText, toolNames)
	logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text")

@@ -124,6 +127,14 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
	}

	responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, result.Thinking, sanitizedText, toolNames)
	if result.OutputTokens > 0 {
		if usage, ok := responseObj["usage"].(map[string]any); ok {
			usage["output_tokens"] = result.OutputTokens
			if input, ok := usage["input_tokens"].(int); ok {
				usage["total_tokens"] = input + result.OutputTokens
			}
		}
	}
	h.getResponseStore().put(owner, responseID, responseObj)
	writeJSON(w, http.StatusOK, responseObj)
}
@@ -49,6 +49,7 @@ type responsesStreamRuntime struct {
	messagePartAdded bool
	sequence         int
	failed           bool
	outputTokens     int

	persistResponse func(obj map[string]any)
}
@@ -144,6 +145,14 @@ func (s *responsesStreamRuntime) finalize() {
	s.closeIncompleteFunctionItems()

	obj := s.buildCompletedResponseObject(finalThinking, finalText, detected)
	if s.outputTokens > 0 {
		if usage, ok := obj["usage"].(map[string]any); ok {
			usage["output_tokens"] = s.outputTokens
			if input, ok := usage["input_tokens"].(int); ok {
				usage["total_tokens"] = input + s.outputTokens
			}
		}
	}
	if s.persistResponse != nil {
		s.persistResponse(obj)
	}
@@ -172,6 +181,9 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
	if !parsed.Parsed {
		return streamengine.ParsedDecision{}
	}
	if parsed.OutputTokens > 0 {
		s.outputTokens = parsed.OutputTokens
	}
	if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
		return streamengine.ParsedDecision{Stop: true}
	}
@@ -627,6 +627,50 @@ func TestHandleResponsesNonStreamToolChoiceNoneStillAllowsFunctionCall(t *testin
	}
}

func TestHandleResponsesNonStreamReturns502WhenUpstreamOutputEmpty(t *testing.T) {
	h := &Handler{}
	rec := httptest.NewRecorder()
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body: io.NopCloser(strings.NewReader(
			`data: {"p":"response/content","v":""}` + "\n" +
				`data: [DONE]` + "\n",
		)),
	}

	h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, util.DefaultToolChoicePolicy(), "")
	if rec.Code != http.StatusBadGateway {
		t.Fatalf("expected 502 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
	}
	out := decodeJSONBody(t, rec.Body.String())
	errObj, _ := out["error"].(map[string]any)
	if asString(errObj["code"]) != "upstream_empty_output" {
		t.Fatalf("expected code=upstream_empty_output, got %#v", out)
	}
}

func TestHandleResponsesNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWithoutOutput(t *testing.T) {
	h := &Handler{}
	rec := httptest.NewRecorder()
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body: io.NopCloser(strings.NewReader(
			`data: {"code":"content_filter"}` + "\n" +
				`data: [DONE]` + "\n",
		)),
	}

	h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, util.DefaultToolChoicePolicy(), "")
	if rec.Code != http.StatusBadRequest {
		t.Fatalf("expected 400 for filtered empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
	}
	out := decodeJSONBody(t, rec.Body.String())
	errObj, _ := out["error"].(map[string]any)
	if asString(errObj["code"]) != "content_filter" {
		t.Fatalf("expected code=content_filter, got %#v", out)
	}
}

func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
	scanner := bufio.NewScanner(strings.NewReader(body))
	matched := false
@@ -183,3 +183,53 @@ func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) {
		t.Fatalf("expected function_call output item, got %#v", output)
	}
}

func TestChatCompletionsStreamContentFilterStopsNormallyWithoutLeak(t *testing.T) {
	statuses := make([]int, 0, 1)
	h := &Handler{
		Store: mockOpenAIConfig{wideInput: true},
		Auth:  streamStatusAuthStub{},
		DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(
			`data: {"p":"response/content","v":"合法前缀"}`,
			`data: {"p":"response/status","v":"CONTENT_FILTER","accumulated_token_usage":77}`,
			`data: {"p":"response/content","v":"CONTENT_FILTER你好,这个问题我暂时无法回答,让我们换个话题再聊聊吧。"}`,
		)},
	}
	r := chi.NewRouter()
	r.Use(captureStatusMiddleware(&statuses))
	RegisterRoutes(r, h)

	reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":"hi"}],"stream":true}`
	req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
	req.Header.Set("Authorization", "Bearer direct-token")
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	r.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
	}
	if len(statuses) != 1 || statuses[0] != http.StatusOK {
		t.Fatalf("expected captured status 200, got %#v", statuses)
	}
	if strings.Contains(rec.Body.String(), "这个问题我暂时无法回答") {
		t.Fatalf("expected leaked content-filter suffix to be hidden, body=%s", rec.Body.String())
	}

	frames, done := parseSSEDataFrames(t, rec.Body.String())
	if !done {
		t.Fatalf("expected [DONE], body=%s", rec.Body.String())
	}
	if len(frames) == 0 {
		t.Fatalf("expected at least one json frame, body=%s", rec.Body.String())
	}
	last := frames[len(frames)-1]
	choices, _ := last["choices"].([]any)
	if len(choices) != 1 {
		t.Fatalf("expected one choice in final frame, got %#v", last)
	}
	choice, _ := choices[0].(map[string]any)
	if choice["finish_reason"] != "stop" {
		t.Fatalf("expected finish_reason=stop for content-filter upstream stop, got %#v", choice["finish_reason"])
	}
}
@@ -183,7 +183,7 @@ func findToolSegmentStart(s string) int {
		return -1
	}
	lower := strings.ToLower(s)
	keywords := []string{"tool_calls", "\"function\"", "function.name:"}
	keywords := []string{"tool_calls", "\"function\"", "function.name:", "\"tool_use\""}
	bestKeyIdx := -1
	for _, kw := range keywords {
		idx := strings.Index(lower, kw)
@@ -191,6 +191,9 @@ func findToolSegmentStart(s string) int {
			bestKeyIdx = idx
		}
	}
	if fnKeyIdx := findQuotedFunctionCallKeyStart(s); fnKeyIdx >= 0 && (bestKeyIdx < 0 || fnKeyIdx < bestKeyIdx) {
		bestKeyIdx = fnKeyIdx
	}
	// Also detect XML tool call tags.
	for _, tag := range xmlToolTagsToDetect {
		idx := strings.Index(lower, tag)
@@ -240,13 +243,16 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix

	lower := strings.ToLower(captured)
	keyIdx := -1
	keywords := []string{"tool_calls", "\"function\"", "function.name:"}
	keywords := []string{"tool_calls", "\"function\"", "function.name:", "\"tool_use\""}
	for _, kw := range keywords {
		idx := strings.Index(lower, kw)
		if idx >= 0 && (keyIdx < 0 || idx < keyIdx) {
			keyIdx = idx
		}
	}
	if fnKeyIdx := findQuotedFunctionCallKeyStart(captured); fnKeyIdx >= 0 && (keyIdx < 0 || fnKeyIdx < keyIdx) {
		keyIdx = fnKeyIdx
	}

	if keyIdx < 0 {
		return "", nil, "", false
100 internal/adapter/openai/tool_sieve_functioncall.go (new file)
@@ -0,0 +1,100 @@
package openai

import "strings"

func findQuotedFunctionCallKeyStart(s string) int {
	lower := strings.ToLower(s)
	quotedIdx := findFunctionCallKeyStart(lower, `"functioncall"`)
	bareIdx := findFunctionCallKeyStart(lower, "functioncall")

	// Prefer the quoted JSON key whenever we have a structural match.
	// Bare-key detection is only for loose payloads where the quoted form
	// is absent.
	if quotedIdx >= 0 {
		return quotedIdx
	}
	return bareIdx
}

func findFunctionCallKeyStart(lower, key string) int {
	for from := 0; from < len(lower); {
		rel := strings.Index(lower[from:], key)
		if rel < 0 {
			return -1
		}
		idx := from + rel
		if isInsideJSONString(lower, idx) {
			from = idx + 1
			continue
		}
		if !hasJSONObjectContextPrefix(lower[:idx]) {
			from = idx + 1
			continue
		}
		if !hasJSONKeyBoundary(lower, idx, len(key)) {
			from = idx + 1
			continue
		}
		j := idx + len(key)
		for j < len(lower) && (lower[j] == ' ' || lower[j] == '\t' || lower[j] == '\r' || lower[j] == '\n') {
			j++
		}
		if j < len(lower) && lower[j] == ':' {
			k := j + 1
			for k < len(lower) && (lower[k] == ' ' || lower[k] == '\t' || lower[k] == '\r' || lower[k] == '\n') {
				k++
			}
			if k < len(lower) && lower[k] != '{' {
				from = idx + 1
				continue
			}
			return idx
		}
		from = idx + 1
	}
	return -1
}

func isInsideJSONString(s string, idx int) bool {
	inString := false
	escaped := false
	for i := 0; i < idx; i++ {
		c := s[i]
		if escaped {
			escaped = false
			continue
		}
		if c == '\\' && inString {
			escaped = true
			continue
		}
		if c == '"' {
			inString = !inString
		}
	}
	return inString
}

func hasJSONObjectContextPrefix(prefix string) bool {
	return strings.LastIndex(prefix, "{") >= 0
}

func hasJSONKeyBoundary(s string, idx, keyLen int) bool {
	if idx > 0 {
		prev := s[idx-1]
		if isLowerAlphaNumeric(prev) {
			return false
		}
	}
	if end := idx + keyLen; end < len(s) {
		next := s[end]
		if isLowerAlphaNumeric(next) {
			return false
		}
	}
	return true
}

func isLowerAlphaNumeric(b byte) bool {
	return (b >= 'a' && b <= 'z') || (b >= '0' && b <= '9') || b == '_'
}
23 internal/adapter/openai/tool_sieve_functioncall_test.go (new file)
@@ -0,0 +1,23 @@
package openai

import "testing"

func TestFindQuotedFunctionCallKeyStart_PrefersEarlierBareKey(t *testing.T) {
	input := `{functionCall:{"name":"a","arguments":"{}"},"message":"literal text: \"functionCall\": not a key"}`

	got := findQuotedFunctionCallKeyStart(input)
	want := 1
	if got != want {
		t.Fatalf("findQuotedFunctionCallKeyStart() = %d, want %d", got, want)
	}
}

func TestFindQuotedFunctionCallKeyStart_PrefersEarlierQuotedKey(t *testing.T) {
	input := `{"functionCall":{"name":"a","arguments":"{}"},"note":"functionCall appears in prose"}`

	got := findQuotedFunctionCallKeyStart(input)
	want := 1
	if got != want {
		t.Fatalf("findQuotedFunctionCallKeyStart() = %d, want %d", got, want)
	}
}
@@ -34,7 +34,8 @@ type toolCallDelta struct {
	Arguments string
}

const toolSieveContextTailLimit = 256
// Keep in sync with JS TOOL_SIEVE_CONTEXT_TAIL_LIMIT.
const toolSieveContextTailLimit = 2048

func (s *toolStreamSieveState) resetIncrementalToolState() {
	s.disableDeltas = false
@@ -71,12 +71,31 @@ func consumeXMLToolCapture(captured string, toolNames []string) (prefix string,
			prefixPart, suffixPart = trimWrappingJSONFence(prefixPart, suffixPart)
			return prefixPart, parsed, suffixPart, true
		}
		// If this block does not look like an executable tool-call payload,
		// pass it through as normal content (e.g. user-requested XML snippets).
		if !looksLikeExecutableXMLToolCallBlock(xmlBlock, pair.open) {
			return prefixPart + xmlBlock, nil, suffixPart, true
		}
		// Looks like XML tool syntax but failed to parse — consume it to avoid leak.
		return prefixPart, nil, suffixPart, true
	}
	return "", nil, "", false
}

func looksLikeExecutableXMLToolCallBlock(xmlBlock, openTag string) bool {
	lower := strings.ToLower(xmlBlock)
	// Agent wrapper tags are always treated as internal tool-call wrappers.
	switch openTag {
	case "<attempt_completion", "<ask_followup_question", "<new_task":
		return true
	}
	return strings.Contains(lower, "<tool_name") ||
		strings.Contains(lower, "<parameters") ||
		strings.Contains(lower, `"tool"`) ||
		strings.Contains(lower, `"tool_name"`) ||
		strings.Contains(lower, `"name"`)
}

// hasOpenXMLToolTag returns true if captured text contains an XML tool opening tag
// whose SPECIFIC closing tag has not appeared yet.
func hasOpenXMLToolTag(captured string) bool {
@@ -78,6 +78,49 @@ func TestProcessToolSieveXMLWithLeadingText(t *testing.T) {
	}
}

func TestProcessToolSievePassesThroughNonToolXMLBlock(t *testing.T) {
	var state toolStreamSieveState
	chunk := `<tool_call><title>示例 XML</title><body>plain text xml payload</body></tool_call>`
	events := processToolSieveChunk(&state, chunk, []string{"read_file"})
	events = append(events, flushToolSieve(&state, []string{"read_file"})...)

	var textContent strings.Builder
	toolCalls := 0
	for _, evt := range events {
		textContent.WriteString(evt.Content)
		toolCalls += len(evt.ToolCalls)
	}
	if toolCalls != 0 {
		t.Fatalf("expected no tool calls for plain XML payload, got %d events=%#v", toolCalls, events)
	}
	if textContent.String() != chunk {
		t.Fatalf("expected XML payload to pass through unchanged, got %q", textContent.String())
	}
}

func TestProcessToolSieveNonToolXMLKeepsSuffixForToolParsing(t *testing.T) {
	var state toolStreamSieveState
	chunk := `<tool_call><title>plain xml</title></tool_call><invoke name="read_file"><parameters>{"path":"README.MD"}</parameters></invoke>`
	events := processToolSieveChunk(&state, chunk, []string{"read_file"})
	events = append(events, flushToolSieve(&state, []string{"read_file"})...)

	var textContent strings.Builder
	toolCalls := 0
	for _, evt := range events {
		textContent.WriteString(evt.Content)
		toolCalls += len(evt.ToolCalls)
	}
	if !strings.Contains(textContent.String(), `<tool_call><title>plain xml</title></tool_call>`) {
		t.Fatalf("expected leading non-tool XML to be preserved, got %q", textContent.String())
	}
	if strings.Contains(textContent.String(), `<invoke name="read_file">`) {
		t.Fatalf("expected invoke tool XML to be intercepted, got %q", textContent.String())
	}
	if toolCalls != 1 {
		t.Fatalf("expected exactly one parsed tool call from suffix, got %d events=%#v", toolCalls, events)
	}
}

func TestProcessToolSievePartialXMLTagHeldBack(t *testing.T) {
	var state toolStreamSieveState
	// Chunk ends with a partial XML tool tag.
@@ -104,6 +147,7 @@ func TestFindToolSegmentStartDetectsXMLToolCalls(t *testing.T) {
		want int
	}{
		{"tool_calls_tag", "some text <tool_calls>\n", 10},
		{"gemini_function_call_json", `some text {"functionCall":{"name":"search","args":{"q":"latest"}}}`, 10},
		{"tool_call_tag", "prefix <tool_call>\n", 7},
		{"invoke_tag", "text <invoke name=\"foo\">body</invoke>", 5},
		{"function_call_tag", "<function_call name=\"foo\">body</function_call>", 0},
@@ -119,6 +163,81 @@ func TestFindToolSegmentStartDetectsXMLToolCalls(t *testing.T) {
	}
}

func TestFindToolSegmentStartIgnoresFunctionCallProse(t *testing.T) {
	input := "Please explain the functionCall API field and how clients should parse it."
	if got := findToolSegmentStart(input); got != -1 {
		t.Fatalf("expected no tool segment start for prose, got %d", got)
	}
}

func TestFindToolSegmentStartDetectsQuotedFunctionCallKey(t *testing.T) {
	input := `prefix {"functionCall": {"name":"search_web","args":{"query":"x"}}}`
	want := strings.Index(input, "{")
	if got := findToolSegmentStart(input); got != want {
		t.Fatalf("expected JSON object start %d, got %d", want, got)
	}
}

func TestFindToolSegmentStartDetectsLooseFunctionCallKey(t *testing.T) {
	input := `prefix {functionCall: {"name":"search_web","args":{"query":"x"}}}`
	want := strings.Index(input, "{")
	if got := findToolSegmentStart(input); got != want {
		t.Fatalf("expected JSON object start %d, got %d", want, got)
	}
}

func TestFindToolSegmentStartPrefersQuotedFunctionCallOverEarlierBareProse(t *testing.T) {
	input := `prefix {note} functionCall: docs hint {"functionCall":{"name":"search_web","args":{"query":"x"}}}`
	want := strings.Index(input, `{"functionCall"`)
	if got := findToolSegmentStart(input); got != want {
		t.Fatalf("expected quoted functionCall JSON start %d, got %d", want, got)
	}
}

func TestFindToolSegmentStartIgnoresLooseFunctionCallProse(t *testing.T) {
	input := "Please explain why functionCall: is used in documentation examples."
	if got := findToolSegmentStart(input); got != -1 {
		t.Fatalf("expected no tool segment start for prose, got %d", got)
	}
}

func TestProcessToolSieveDoesNotBufferFunctionCallProse(t *testing.T) {
	var state toolStreamSieveState
	chunk := "Please explain the functionCall API field and keep streaming this sentence."
	events := processToolSieveChunk(&state, chunk, []string{"search_web"})
	var text string
	for _, evt := range events {
		text += evt.Content
		if len(evt.ToolCalls) > 0 {
			t.Fatalf("expected no tool calls for prose, got %#v", evt.ToolCalls)
		}
	}
	if text != chunk {
		t.Fatalf("expected prose to pass through immediately, got %q", text)
	}
}

func TestProcessToolSieveDetectsGeminiFunctionCallPayload(t *testing.T) {
	var state toolStreamSieveState
	events := processToolSieveChunk(&state, `{"functionCall":{"name":"search_web","args":{"query":"latest"}}}`, []string{"search_web"})
	events = append(events, flushToolSieve(&state, []string{"search_web"})...)

	var textContent string
	var toolCalls int
	for _, evt := range events {
		if evt.Content != "" {
			textContent += evt.Content
		}
		toolCalls += len(evt.ToolCalls)
	}
	if toolCalls != 1 {
		t.Fatalf("expected one tool call from functionCall payload, got events=%#v", events)
	}
	if strings.Contains(strings.ToLower(textContent), "functioncall") {
		t.Fatalf("functionCall json leaked into text content: %q", textContent)
	}
}

func TestFindPartialXMLToolTagStart(t *testing.T) {
	cases := []struct {
		name string
@@ -288,7 +407,7 @@ func TestOpeningXMLTagNotLeakedAsContent(t *testing.T) {

func TestProcessToolSieveInterceptsAttemptCompletionLeak(t *testing.T) {
	var state toolStreamSieveState
	// Simulate an agent outputting attempt_completion XML tag
	// Simulate an agent outputting attempt_completion XML tag
	// which shouldn't leak to text output, even if it fails to parse as a valid tool.
	chunks := []string{
		"Done with task.\n",
20 internal/adapter/openai/upstream_empty.go (new file)
@@ -0,0 +1,20 @@
package openai

import (
	"net/http"
	"strings"

	"ds2api/internal/sse"
)

func writeUpstreamEmptyOutputError(w http.ResponseWriter, result sse.CollectResult) bool {
	if strings.TrimSpace(result.Thinking) != "" || strings.TrimSpace(sanitizeLeakedOutput(result.Text)) != "" {
		return false
	}
	if result.ContentFilter {
		writeOpenAIErrorWithCode(w, http.StatusBadRequest, "Upstream content filtered the response and returned no output.", "content_filter")
		return true
	}
	writeOpenAIErrorWithCode(w, http.StatusBadGateway, "Upstream model returned empty output.", "upstream_empty_output")
	return true
}
@@ -21,6 +21,9 @@ type ConfigStore interface {
	Update(mutator func(*config.Config) error) error
	ExportJSONAndBase64() (string, string, error)
	IsEnvBacked() bool
	IsEnvWritebackEnabled() bool
	HasEnvConfigSource() bool
	ConfigPath() string
	SetVercelSync(hash string, ts int64) error
	AdminPasswordHash() string
	AdminJWTExpireHours() int
@@ -8,9 +8,12 @@ import (
func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) {
	snap := h.Store.Snapshot()
	safe := map[string]any{
		"keys":       snap.Keys,
		"accounts":   []map[string]any{},
		"env_backed": h.Store.IsEnvBacked(),
		"keys":                  snap.Keys,
		"accounts":              []map[string]any{},
		"env_backed":            h.Store.IsEnvBacked(),
		"env_source_present":    h.Store.HasEnvConfigSource(),
		"env_writeback_enabled": h.Store.IsEnvWritebackEnabled(),
		"config_path":           h.Store.ConfigPath(),
		"claude_mapping": func() map[string]string {
			if len(snap.ClaudeMapping) > 0 {
				return snap.ClaudeMapping
@@ -204,6 +204,45 @@ func TestSwitchAccountNilTriedAccounts(t *testing.T) {
	r.Release(a)
}

func TestSwitchAccountSkipsLoginFailureAndContinues(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{
		"keys":["managed-key"],
		"accounts":[
			{"email":"acc1@test.com","password":"pwd","token":"t1"},
			{"email":"acc2@test.com","password":"pwd"},
			{"email":"acc3@test.com","password":"pwd","token":"t3"}
		]
	}`)
	store := config.LoadStore()
	pool := account.NewPool(store)
	r := NewResolver(store, pool, func(_ context.Context, acc config.Account) (string, error) {
		if acc.Email == "acc2@test.com" {
			return "", errors.New("login failed")
		}
		return "new-token", nil
	})

	req, _ := http.NewRequest("POST", "/", nil)
	req.Header.Set("Authorization", "Bearer managed-key")
	a, err := r.Determine(req)
	if err != nil {
		t.Fatalf("determine failed: %v", err)
	}
	defer r.Release(a)
	if a.AccountID != "acc1@test.com" {
		t.Fatalf("expected first account, got %q", a.AccountID)
	}
	if !r.SwitchAccount(context.Background(), a) {
		t.Fatal("expected switch to succeed after skipping failed account")
	}
	if a.AccountID != "acc3@test.com" {
		t.Fatalf("expected fallback to third account, got %q", a.AccountID)
	}
	if !a.TriedAccounts["acc2@test.com"] {
		t.Fatalf("expected failed account to be marked as tried")
	}
}

// ─── Release edge cases ─────────────────────────────────────────────

func TestReleaseNilAuth(t *testing.T) {
@@ -70,25 +70,53 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
		}, nil
	}
	target := strings.TrimSpace(req.Header.Get("X-Ds2-Target-Account"))
	acc, ok := r.Pool.AcquireWait(ctx, target, nil)
	if !ok {
		return nil, ErrNoAccount
	}
	a := &RequestAuth{
		UseConfigToken: true,
		CallerID:       callerID,
		AccountID:      acc.Identifier(),
		Account:        acc,
		TriedAccounts:  map[string]bool{},
		resolver:       r,
	}
	if err := r.ensureManagedToken(ctx, a); err != nil {
		r.Pool.Release(a.AccountID)
	a, err := r.acquireManagedRequestAuth(ctx, callerID, target)
	if err != nil {
		return nil, err
	}
	return a, nil
}

func (r *Resolver) acquireManagedRequestAuth(ctx context.Context, callerID, target string) (*RequestAuth, error) {
	tried := map[string]bool{}
	var lastEnsureErr error
	for {
		if target == "" && len(tried) >= len(r.Store.Accounts()) {
			if lastEnsureErr != nil {
				return nil, lastEnsureErr
			}
			return nil, ErrNoAccount
		}
		acc, ok := r.Pool.AcquireWait(ctx, target, tried)
		if !ok {
			if lastEnsureErr != nil {
				return nil, lastEnsureErr
			}
			return nil, ErrNoAccount
		}

		a := &RequestAuth{
			UseConfigToken: true,
			CallerID:       callerID,
			AccountID:      acc.Identifier(),
			Account:        acc,
			TriedAccounts:  tried,
			resolver:       r,
		}

		if err := r.ensureManagedToken(ctx, a); err != nil {
			lastEnsureErr = err
			tried[a.AccountID] = true
			r.Pool.Release(a.AccountID)
			if target != "" {
				return nil, err
			}
			continue
		}
		return a, nil
	}
}

// DetermineCaller resolves caller identity without acquiring any pooled account.
// Use this for local-cache lookup routes that only need tenant isolation.
func (r *Resolver) DetermineCaller(req *http.Request) (*RequestAuth, error) {
@@ -164,16 +192,20 @@ func (r *Resolver) SwitchAccount(ctx context.Context, a *RequestAuth) bool {
		a.TriedAccounts[a.AccountID] = true
		r.Pool.Release(a.AccountID)
	}
	acc, ok := r.Pool.Acquire("", a.TriedAccounts)
	if !ok {
		return false
	for {
		acc, ok := r.Pool.Acquire("", a.TriedAccounts)
		if !ok {
			return false
		}
		a.Account = acc
		a.AccountID = acc.Identifier()
		if err := r.ensureManagedToken(ctx, a); err != nil {
			a.TriedAccounts[a.AccountID] = true
			r.Pool.Release(a.AccountID)
			continue
		}
		return true
	}
	a.Account = acc
	a.AccountID = acc.Identifier()
	if err := r.ensureManagedToken(ctx, a); err != nil {
		return false
	}
	return true
}

func (r *Resolver) Release(a *RequestAuth) {
@@ -2,6 +2,7 @@ package auth

import (
	"context"
	"errors"
	"net/http"
	"sync/atomic"
	"testing"
@@ -301,3 +302,96 @@ func TestDetermineManagedAccountUsesUpdatedRefreshInterval(t *testing.T) {
		t.Fatalf("expected exactly one login after runtime update, got %d", got)
	}
}

func TestDetermineManagedAccountRetriesOtherAccountOnLoginFailure(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{
		"keys":["managed-key"],
		"accounts":[
			{"email":"bad@example.com","password":"pwd"},
			{"email":"good@example.com","password":"pwd","token":"good-token"}
		]
	}`)
	store := config.LoadStore()
	pool := account.NewPool(store)
	resolver := NewResolver(store, pool, func(_ context.Context, acc config.Account) (string, error) {
		if acc.Email == "bad@example.com" {
			return "", errors.New("stale account")
		}
		return "fresh-good-token", nil
	})

	req, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
	req.Header.Set("x-api-key", "managed-key")

	a, err := resolver.Determine(req)
	if err != nil {
		t.Fatalf("determine failed: %v", err)
	}
	defer resolver.Release(a)
	if a.AccountID != "good@example.com" {
		t.Fatalf("expected fallback to good account, got %q", a.AccountID)
	}
	if a.DeepSeekToken == "" {
		t.Fatal("expected non-empty token from fallback account")
	}
	if !a.TriedAccounts["bad@example.com"] {
		t.Fatalf("expected bad account to be tracked as tried")
	}
}

func TestDetermineTargetAccountDoesNotFallbackOnLoginFailure(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{
		"keys":["managed-key"],
		"accounts":[
			{"email":"bad@example.com","password":"pwd"},
			{"email":"good@example.com","password":"pwd","token":"good-token"}
		]
	}`)
	store := config.LoadStore()
	pool := account.NewPool(store)
	resolver := NewResolver(store, pool, func(_ context.Context, acc config.Account) (string, error) {
		if acc.Email == "bad@example.com" {
			return "", errors.New("stale account")
		}
		return "fresh-good-token", nil
	})

	req, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
	req.Header.Set("x-api-key", "managed-key")
	req.Header.Set("X-Ds2-Target-Account", "bad@example.com")

	_, err := resolver.Determine(req)
	if err == nil {
		t.Fatal("expected determine to fail for broken target account")
	}
}

func TestDetermineManagedAccountReturnsLastEnsureErrorWhenAllFail(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{
		"keys":["managed-key"],
		"accounts":[
			{"email":"bad1@example.com","password":"pwd"},
			{"email":"bad2@example.com","password":"pwd"}
		]
	}`)
	store := config.LoadStore()
	pool := account.NewPool(store)
	ensureErr := errors.New("all credentials stale")
	resolver := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
		return "", ensureErr
	})

	req, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
	req.Header.Set("x-api-key", "managed-key")

	_, err := resolver.Determine(req)
	if err == nil {
		t.Fatal("expected determine to fail")
	}
	if !errors.Is(err, ensureErr) {
		t.Fatalf("expected ensure error, got %v", err)
	}
	if errors.Is(err, ErrNoAccount) {
		t.Fatalf("expected auth-style ensure error, got ErrNoAccount")
	}
}
@@ -2,6 +2,7 @@ package config

import (
    "encoding/base64"
    "errors"
    "os"
    "strings"
    "testing"
@@ -79,6 +80,111 @@ func TestLoadStorePreservesFileBackedTokensForRuntime(t *testing.T) {
    }
}

func TestEnvBackedStoreWritebackBootstrapsMissingConfigFile(t *testing.T) {
    tmp, err := os.CreateTemp(t.TempDir(), "config-*.json")
    if err != nil {
        t.Fatalf("create temp config: %v", err)
    }
    path := tmp.Name()
    _ = tmp.Close()
    _ = os.Remove(path)

    t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"seed@example.com","password":"p"}]}`)
    t.Setenv("CONFIG_JSON", "")
    t.Setenv("DS2API_CONFIG_PATH", path)
    t.Setenv("DS2API_ENV_WRITEBACK", "1")

    store := LoadStore()
    if store.IsEnvBacked() {
        t.Fatalf("expected writeback bootstrap to become file-backed immediately")
    }
    if err := store.Update(func(c *Config) error {
        c.Accounts = append(c.Accounts, Account{Email: "new@example.com", Password: "p2"})
        return nil
    }); err != nil {
        t.Fatalf("update failed: %v", err)
    }
    content, err := os.ReadFile(path)
    if err != nil {
        t.Fatalf("read written config: %v", err)
    }
    if !strings.Contains(string(content), "seed@example.com") {
        t.Fatalf("expected bootstrapped config to contain seed account, got: %s", content)
    }
    if !strings.Contains(string(content), "new@example.com") {
        t.Fatalf("expected persisted config to contain added account, got: %s", content)
    }

    reloaded := LoadStore()
    if reloaded.IsEnvBacked() {
        t.Fatalf("expected reloaded store to prefer persisted config file")
    }
    accounts := reloaded.Accounts()
    if len(accounts) != 2 {
        t.Fatalf("expected 2 accounts after reload, got %d", len(accounts))
    }
}

func TestEnvBackedStoreWritebackDoesNotBootstrapOnInvalidEnvJSON(t *testing.T) {
    tmp, err := os.CreateTemp(t.TempDir(), "config-*.json")
    if err != nil {
        t.Fatalf("create temp config: %v", err)
    }
    path := tmp.Name()
    _ = tmp.Close()
    _ = os.Remove(path)

    t.Setenv("DS2API_CONFIG_JSON", "{invalid-json")
    t.Setenv("CONFIG_JSON", "")
    t.Setenv("DS2API_CONFIG_PATH", path)
    t.Setenv("DS2API_ENV_WRITEBACK", "1")

    cfg, fromEnv, loadErr := loadConfig()
    if loadErr == nil {
        t.Fatalf("expected loadConfig error for invalid env json")
    }
    if !fromEnv {
        t.Fatalf("expected fromEnv=true when parsing env config fails")
    }
    if len(cfg.Keys) != 0 || len(cfg.Accounts) != 0 {
        t.Fatalf("expected empty config on parse failure, got keys=%d accounts=%d", len(cfg.Keys), len(cfg.Accounts))
    }
    if _, statErr := os.Stat(path); !errors.Is(statErr, os.ErrNotExist) {
        t.Fatalf("expected no bootstrapped config file, stat err=%v", statErr)
    }
}

func TestEnvBackedStoreWritebackFallsBackToPersistedFileOnInvalidEnvJSON(t *testing.T) {
    tmp, err := os.CreateTemp(t.TempDir(), "config-*.json")
    if err != nil {
        t.Fatalf("create temp config: %v", err)
    }
    path := tmp.Name()
    if _, err := tmp.WriteString(`{"keys":["file-key"],"accounts":[{"email":"persisted@example.com","password":"p"}]}`); err != nil {
        t.Fatalf("write temp config: %v", err)
    }
    _ = tmp.Close()

    t.Setenv("DS2API_CONFIG_JSON", "{invalid-json")
    t.Setenv("CONFIG_JSON", "")
    t.Setenv("DS2API_CONFIG_PATH", path)
    t.Setenv("DS2API_ENV_WRITEBACK", "1")

    cfg, fromEnv, loadErr := loadConfig()
    if loadErr != nil {
        t.Fatalf("expected fallback to persisted file, got error: %v", loadErr)
    }
    if fromEnv {
        t.Fatalf("expected fallback to file-backed mode")
    }
    if len(cfg.Keys) != 1 || cfg.Keys[0] != "file-key" {
        t.Fatalf("unexpected keys after fallback: %#v", cfg.Keys)
    }
    if len(cfg.Accounts) != 1 || cfg.Accounts[0].Email != "persisted@example.com" {
        t.Fatalf("unexpected accounts after fallback: %#v", cfg.Accounts)
    }
}

func TestRuntimeTokenRefreshIntervalHoursDefaultsToSix(t *testing.T) {
    t.Setenv("DS2API_CONFIG_JSON", `{
        "keys":["k1"],

@@ -40,12 +40,38 @@ func loadConfig() (Config, bool, error) {
    }
    if rawCfg != "" {
        cfg, err := parseConfigString(rawCfg)
        if err != nil {
            if !IsVercel() && envWritebackEnabled() {
                if fileCfg, fileErr := loadConfigFromFile(ConfigPath()); fileErr == nil {
                    return fileCfg, false, nil
                }
            }
            return cfg, true, err
        }
        cfg.ClearAccountTokens()
        cfg.DropInvalidAccounts()
        if IsVercel() || !envWritebackEnabled() {
            return cfg, true, err
        }
        content, fileErr := os.ReadFile(ConfigPath())
        if fileErr == nil {
            var fileCfg Config
            if unmarshalErr := json.Unmarshal(content, &fileCfg); unmarshalErr == nil {
                fileCfg.DropInvalidAccounts()
                return fileCfg, false, err
            }
        }
        if errors.Is(fileErr, os.ErrNotExist) {
            if writeErr := writeConfigFile(ConfigPath(), cfg.Clone()); writeErr == nil {
                return cfg, false, err
            } else {
                Logger.Warn("[config] env writeback bootstrap failed", "error", writeErr)
            }
        }
        return cfg, true, err
    }

    content, err := os.ReadFile(ConfigPath())
    cfg, err := loadConfigFromFile(ConfigPath())
    if err != nil {
        if IsVercel() {
            // Vercel one-click deploy may start without a writable/present config file.
@@ -54,16 +80,6 @@ func loadConfig() (Config, bool, error) {
        }
        return Config{}, false, err
    }
    var cfg Config
    if err := json.Unmarshal(content, &cfg); err != nil {
        return Config{}, false, err
    }
    cfg.DropInvalidAccounts()
    if strings.Contains(string(content), `"test_status"`) && !IsVercel() {
        if b, err := json.MarshalIndent(cfg, "", " "); err == nil {
            _ = os.WriteFile(ConfigPath(), b, 0o644)
        }
    }
    if IsVercel() {
        // Vercel filesystem is ephemeral/read-only for runtime writes; avoid save errors.
        return cfg, true, nil
@@ -71,6 +87,24 @@ func loadConfig() (Config, bool, error) {
    return cfg, false, nil
}

func loadConfigFromFile(path string) (Config, error) {
    content, err := os.ReadFile(path)
    if err != nil {
        return Config{}, err
    }
    var cfg Config
    if err := json.Unmarshal(content, &cfg); err != nil {
        return Config{}, err
    }
    cfg.DropInvalidAccounts()
    if strings.Contains(string(content), `"test_status"`) && !IsVercel() {
        if b, err := json.MarshalIndent(cfg, "", " "); err == nil {
            _ = os.WriteFile(path, b, 0o644)
        }
    }
    return cfg, nil
}

func (s *Store) Snapshot() Config {
    s.mu.RLock()
    defer s.mu.RUnlock()
@@ -177,7 +211,7 @@ func (s *Store) Update(mutator func(*Config) error) error {
func (s *Store) Save() error {
    s.mu.Lock()
    defer s.mu.Unlock()
    if s.fromEnv {
    if s.fromEnv && (IsVercel() || !envWritebackEnabled()) {
        Logger.Info("[save_config] source from env, skip write")
        return nil
    }
@@ -187,11 +221,15 @@ func (s *Store) Save() error {
    if err != nil {
        return err
    }
    return os.WriteFile(s.path, b, 0o644)
    if err := writeConfigBytes(s.path, b); err != nil {
        return err
    }
    s.fromEnv = false
    return nil
}

func (s *Store) saveLocked() error {
    if s.fromEnv {
    if s.fromEnv && (IsVercel() || !envWritebackEnabled()) {
        Logger.Info("[save_config] source from env, skip write")
        return nil
    }
@@ -201,7 +239,11 @@ func (s *Store) saveLocked() error {
    if err != nil {
        return err
    }
    return os.WriteFile(s.path, b, 0o644)
    if err := writeConfigBytes(s.path, b); err != nil {
        return err
    }
    s.fromEnv = false
    return nil
}

func (s *Store) IsEnvBacked() bool {

51 internal/config/store_env_writeback.go (new file)
@@ -0,0 +1,51 @@
package config

import (
    "encoding/json"
    "fmt"
    "os"
    "path/filepath"
    "strings"
)

func envWritebackEnabled() bool {
    v := strings.ToLower(strings.TrimSpace(os.Getenv("DS2API_ENV_WRITEBACK")))
    return v == "1" || v == "true" || v == "yes" || v == "on"
}

func (s *Store) IsEnvWritebackEnabled() bool {
    return envWritebackEnabled()
}

func (s *Store) HasEnvConfigSource() bool {
    rawCfg := strings.TrimSpace(os.Getenv("DS2API_CONFIG_JSON"))
    if rawCfg == "" {
        rawCfg = strings.TrimSpace(os.Getenv("CONFIG_JSON"))
    }
    return rawCfg != ""
}

func (s *Store) ConfigPath() string {
    return s.path
}

func writeConfigFile(path string, cfg Config) error {
    persistCfg := cfg.Clone()
    persistCfg.ClearAccountTokens()
    b, err := json.MarshalIndent(persistCfg, "", " ")
    if err != nil {
        return err
    }
    return writeConfigBytes(path, b)
}

func writeConfigBytes(path string, b []byte) error {
    dir := filepath.Dir(path)
    if dir == "." || dir == "" {
        return os.WriteFile(path, b, 0o644)
    }
    if err := os.MkdirAll(dir, 0o755); err != nil {
        return fmt.Errorf("mkdir config dir: %w", err)
    }
    return os.WriteFile(path, b, 0o644)
}
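The writeback flag above accepts only four truthy spellings after trimming and lowercasing; anything else, including "enabled" or an empty string, leaves writeback off. A standalone sketch of that parsing:

```go
package main

import (
	"fmt"
	"strings"
)

// truthy mirrors the DS2API_ENV_WRITEBACK parsing shown above: trim
// whitespace, lowercase, then accept only the four literal enable values.
func truthy(raw string) bool {
	v := strings.ToLower(strings.TrimSpace(raw))
	return v == "1" || v == "true" || v == "yes" || v == "on"
}

func main() {
	for _, s := range []string{" TRUE ", "on", "0", "enabled", ""} {
		fmt.Printf("%q => %v\n", s, truthy(s))
	}
}
```

This closed allow-list means a typo in the env var fails safe (writeback stays disabled) rather than silently enabling disk writes.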
@@ -59,8 +59,9 @@ async function handler(req, res) {
        return;
    }

    // Keep all non-stream behavior on Go side to avoid compatibility regressions.
    if (!toBool(payload.stream)) {
    // Keep all non-stream behavior and non-OpenAI-chat paths on Go side to avoid
    // protocol-shape regressions (e.g. Gemini/Claude clients expecting their own formats).
    if (!toBool(payload.stream) || !isNodeStreamSupportedPath(req.url || '')) {
        await proxyToGo(req, res, rawBody);
        return;
    }
@@ -76,6 +77,23 @@ function isVercelRuntime() {
    return asString(process.env.VERCEL) !== '' || asString(process.env.NOW_REGION) !== '';
}

function isNodeStreamSupportedPath(rawURL) {
    const path = extractPathname(rawURL);
    return path === '/v1/chat/completions';
}

function extractPathname(rawURL) {
    const text = asString(rawURL);
    if (!text) {
        return '';
    }
    const q = text.indexOf('?');
    if (q >= 0) {
        return text.slice(0, q);
    }
    return text;
}

module.exports = handler;

module.exports.__test = {
@@ -89,4 +107,6 @@ module.exports.__test = {
    boolDefaultTrue,
    filterIncrementalToolCallDeltasByAllowed,
    estimateTokens,
    isNodeStreamSupportedPath,
    extractPathname,
};

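The routing gate above keys off the bare pathname, so the query string must be stripped before comparing against `/v1/chat/completions`. The same helper can be sketched in Go (illustrative, not the repo's Go code, which serves these routes differently):

```go
package main

import (
	"fmt"
	"strings"
)

// pathname mirrors the JS extractPathname helper: everything before the
// first '?' is the path used for the routing decision.
func pathname(rawURL string) string {
	if i := strings.Index(rawURL, "?"); i >= 0 {
		return rawURL[:i]
	}
	return rawURL
}

func main() {
	fmt.Println(pathname("/v1/chat/completions?stream=true")) // /v1/chat/completions
}
```

Without this step, a request like `/v1/chat/completions?stream=true` would fail the exact-match check and be proxied to Go even when Node streaming is intended.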
@@ -193,6 +193,9 @@ function extractContentRecursive(items, defaultType) {
}

function shouldSkipPath(pathValue) {
    if (isFragmentStatusPath(pathValue)) {
        return true;
    }
    if (SKIP_EXACT_PATHS.has(pathValue)) {
        return true;
    }
@@ -204,6 +207,13 @@ function shouldSkipPath(pathValue) {
    return false;
}

function isFragmentStatusPath(pathValue) {
    if (!pathValue || pathValue === 'response/status') {
        return false;
    }
    return /^response\/fragments\/-?\d+\/status$/i.test(pathValue);
}

function isCitation(text) {
    return asString(text).trim().startsWith('[citation:');
}
@@ -225,5 +235,6 @@ module.exports = {
    parseChunkForContent,
    extractContentRecursive,
    shouldSkipPath,
    isFragmentStatusPath,
    isCitation,
};

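The Go side of the same filter (shown later in this diff) validates the fragment index with a manual digit loop; the JS regex version can be mirrored in Go with `regexp`, keeping the `/i` flag as `(?i)`. An illustrative standalone sketch:

```go
package main

import (
	"fmt"
	"regexp"
)

// fragmentStatusRe mirrors the JS regex: per-fragment status updates
// (possibly negative-indexed) are skipped, while the top-level
// response/status path is left for the finish logic.
var fragmentStatusRe = regexp.MustCompile(`(?i)^response/fragments/-?\d+/status$`)

func isFragmentStatusPath(p string) bool {
	return p != "" && p != "response/status" && fragmentStatusRe.MatchString(p)
}

func main() {
	fmt.Println(isFragmentStatusPath("response/fragments/-3/status")) // true
	fmt.Println(isFragmentStatusPath("response/status"))              // false
}
```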
@@ -237,7 +237,10 @@ function isLikelyJSONToolPayloadCandidate(text) {
        return false;
    }
    const lower = trimmed.toLowerCase();
    return lower.includes('tool_calls') || lower.includes('"function"');
    return lower.includes('tool_calls')
        || lower.includes('"function"')
        || lower.includes('functioncall')
        || lower.includes('"tool_use"');
}

module.exports = {

@@ -85,6 +85,8 @@ function extractToolCallObjects(text) {
    while (true) {
        const idxToolCalls = lower.indexOf('tool_calls', offset);
        const idxFunction = lower.indexOf('"function"', offset);
        const idxFunctionCall = lower.indexOf('functioncall', offset);
        const idxToolUse = lower.indexOf('"tool_use"', offset);
        let idx = -1;
        let matched = '';
        if (idxToolCalls >= 0 && (idxFunction < 0 || idxToolCalls <= idxFunction)) {
@@ -94,6 +96,14 @@ function extractToolCallObjects(text) {
            idx = idxFunction;
            matched = '"function"';
        }
        if (idxFunctionCall >= 0 && (idx < 0 || idxFunctionCall < idx)) {
            idx = idxFunctionCall;
            matched = 'functioncall';
        }
        if (idxToolUse >= 0 && (idx < 0 || idxToolUse < idx)) {
            idx = idxToolUse;
            matched = '"tool_use"';
        }
        if (idx < 0) {
            break;
        }
@@ -102,7 +112,10 @@ function extractToolCallObjects(text) {
            const obj = extractJSONObjectFrom(raw, start);
            if (obj.ok) {
                out.push(raw.slice(start, obj.end).trim());
                offset = obj.end;
                // Ensure forward progress even when the matched keyword is outside
                // the extracted JSON object (e.g. closing XML wrapper tags containing
                // "tool_calls" after an earlier JSON arguments object).
                offset = Math.max(obj.end, idx + matched.length);
                idx = -1;
                break;
            }
@@ -324,6 +337,20 @@ function parseToolCallItem(m) {
    let name = toStringSafe(m.name);
    let inputRaw = m.input;
    let hasInput = Object.prototype.hasOwnProperty.call(m, 'input');
    const fnCall = m.functionCall && typeof m.functionCall === 'object' ? m.functionCall : null;
    if (fnCall) {
        if (!name) {
            name = toStringSafe(fnCall.name);
        }
        if (!hasInput && Object.prototype.hasOwnProperty.call(fnCall, 'args')) {
            inputRaw = fnCall.args;
            hasInput = true;
        }
        if (!hasInput && Object.prototype.hasOwnProperty.call(fnCall, 'arguments')) {
            inputRaw = fnCall.arguments;
            hasInput = true;
        }
    }
    const fn = m.function && typeof m.function === 'object' ? m.function : null;

    if (fn) {

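The `Math.max(obj.end, idx + matched.length)` fix above guards against an infinite scan loop: if the matched keyword lies past the end of the extracted JSON object, advancing only to the object end would re-find the same keyword on the next iteration. The offset rule in isolation, as a Go sketch (function names are illustrative):

```go
package main

import "fmt"

// nextOffset mirrors the forward-progress rule: resume scanning after
// whichever ends later, the extracted JSON object or the matched keyword.
// Otherwise a keyword sitting after the object (e.g. inside a closing XML
// wrapper tag) would be rediscovered forever.
func nextOffset(objEnd, keywordIdx, keywordLen int) int {
	if end := keywordIdx + keywordLen; end > objEnd {
		return end
	}
	return objEnd
}

func main() {
	// Keyword "tool_calls" (10 chars) found at index 40; object ended at 20.
	fmt.Println(nextOffset(20, 40, 10)) // 50, not 20
}
```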
@@ -1,6 +1,7 @@
'use strict';

const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 4096;
// Keep in sync with Go toolSieveContextTailLimit.
const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 2048;

function createToolSieveState() {
    return {

@@ -4,6 +4,8 @@ const TOOL_SEGMENT_KEYWORDS = [
    'tool_calls',
    '"function"',
    'function.name:',
    'functioncall',
    '"tool_use"',
];

const XML_TOOL_SEGMENT_TAGS = [

@@ -5,6 +5,12 @@ import (
    "strings"
)

var promptXMLTextEscaper = strings.NewReplacer(
    "&", "&amp;",
    "<", "&lt;",
    ">", "&gt;",
)

// FormatToolCallsForPrompt renders a tool_calls slice into the canonical
// prompt-visible history block used across adapters.
func FormatToolCallsForPrompt(raw any) string {
@@ -82,8 +88,8 @@ func formatToolCallForPrompt(call map[string]any) string {
    }

    return " <tool_call>\n" +
        " <tool_name>" + name + "</tool_name>\n" +
        " <parameters>" + StringifyToolCallArguments(argsRaw) + "</parameters>\n" +
        " <tool_name>" + escapeXMLText(name) + "</tool_name>\n" +
        " <parameters>" + escapeXMLText(StringifyToolCallArguments(argsRaw)) + "</parameters>\n" +
        " </tool_call>"
}

@@ -122,3 +128,10 @@ func asString(v any) string {
    }
    return ""
}

func escapeXMLText(v string) string {
    if v == "" {
        return ""
    }
    return promptXMLTextEscaper.Replace(v)
}

@@ -26,3 +26,16 @@ func TestFormatToolCallsForPromptXML(t *testing.T) {
        t.Fatalf("unexpected formatted tool call XML: %q", got)
    }
}

func TestFormatToolCallsForPromptEscapesXMLEntities(t *testing.T) {
    got := FormatToolCallsForPrompt([]any{
        map[string]any{
            "name": "search<&>",
            "arguments": `{"q":"a < b && c > d"}`,
        },
    })
    want := "<tool_calls>\n <tool_call>\n <tool_name>search&lt;&amp;&gt;</tool_name>\n <parameters>{\"q\":\"a &lt; b &amp;&amp; c &gt; d\"}</parameters>\n </tool_call>\n</tool_calls>"
    if got != want {
        t.Fatalf("unexpected escaped tool call XML: %q", got)
    }
}

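The escaping hunk above hinges on `strings.NewReplacer` performing a single left-to-right pass, so the `&` produced by `&lt;` is never re-escaped. A self-contained demonstration of that behavior:

```go
package main

import (
	"fmt"
	"strings"
)

// escaper mirrors the promptXMLTextEscaper pattern above: &, <, > become
// entities so tool names and arguments cannot break out of the
// <tool_call> wrapper structure.
var escaper = strings.NewReplacer("&", "&amp;", "<", "&lt;", ">", "&gt;")

func main() {
	fmt.Println(escaper.Replace(`{"q":"a < b && c > d"}`))
	// {"q":"a &lt; b &amp;&amp; c &gt; d"}
}
```

`NewReplacer` replaces non-overlapping matches in one scan of the input, which is exactly why it is safe for entity escaping where a naive sequence of `strings.ReplaceAll` calls (with `&` done last) would double-escape.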
@@ -44,8 +44,8 @@ func NewApp() *App {
    }

    openaiHandler := &openai.Handler{Store: store, Auth: resolver, DS: dsClient}
    claudeHandler := &claude.Handler{Store: store, Auth: resolver, DS: dsClient}
    geminiHandler := &gemini.Handler{Store: store, Auth: resolver, DS: dsClient}
    claudeHandler := &claude.Handler{Store: store, Auth: resolver, DS: dsClient, OpenAI: openaiHandler}
    geminiHandler := &gemini.Handler{Store: store, Auth: resolver, DS: dsClient, OpenAI: openaiHandler}
    adminHandler := &admin.Handler{Store: store, Pool: pool, DS: dsClient}
    webuiHandler := webui.NewHandler()

@@ -10,8 +10,10 @@ import (
// CollectResult holds the aggregated text and thinking content from a
// DeepSeek SSE stream, consumed to completion (non-streaming use case).
type CollectResult struct {
    Text string
    Thinking string
    Text string
    Thinking string
    OutputTokens int
    ContentFilter bool
}

// CollectStream fully consumes a DeepSeek SSE response and separates
@@ -26,6 +28,8 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
    }
    text := strings.Builder{}
    thinking := strings.Builder{}
    outputTokens := 0
    contentFilter := false
    currentType := "text"
    if thinkingEnabled {
        currentType = "thinking"
@@ -37,8 +41,17 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
            return true
        }
        if result.Stop {
            if result.ContentFilter {
                contentFilter = true
            }
            if result.OutputTokens > 0 {
                outputTokens = result.OutputTokens
            }
            return false
        }
        if result.OutputTokens > 0 {
            outputTokens = result.OutputTokens
        }
        for _, p := range result.Parts {
            if p.Type == "thinking" {
                thinking.WriteString(p.Text)
@@ -48,5 +61,10 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
        }
        return true
    })
    return CollectResult{Text: text.String(), Thinking: thinking.String()}
    return CollectResult{
        Text: text.String(),
        Thinking: thinking.String(),
        OutputTokens: outputTokens,
        ContentFilter: contentFilter,
    }
}

@@ -138,3 +138,15 @@ func TestCollectStreamStatusFinished(t *testing.T) {
        t.Fatalf("expected 'Hello', got %q", result.Text)
    }
}

func TestCollectStreamStopsOnContentFilterStatus(t *testing.T) {
    resp := makeHTTPResponse(
        "data: {\"p\":\"response/content\",\"v\":\"safe\"}\n" +
            "data: {\"p\":\"response/status\",\"v\":\"CONTENT_FILTER\"}\n" +
            "data: {\"p\":\"response/content\",\"v\":\"blocked\"}\n",
    )
    result := CollectStream(resp, false, false)
    if result.Text != "safe" {
        t.Fatalf("expected stream to stop before blocked tail, got %q", result.Text)
    }
}

45 internal/sse/content_filter_leak.go (new file)
@@ -0,0 +1,45 @@
package sse

import "strings"

func filterLeakedContentFilterParts(parts []ContentPart) []ContentPart {
    if len(parts) == 0 {
        return parts
    }
    out := make([]ContentPart, 0, len(parts))
    for _, p := range parts {
        cleaned := stripLeakedContentFilterSuffix(p.Text)
        if shouldDropCleanedLeakedChunk(cleaned) {
            continue
        }
        p.Text = cleaned
        out = append(out, p)
    }
    return out
}

func stripLeakedContentFilterSuffix(text string) string {
    if text == "" {
        return text
    }
    upperText := strings.ToUpper(text)
    idx := strings.Index(upperText, "CONTENT_FILTER")
    if idx < 0 {
        return text
    }
    // Keep "\n" so we don't collapse line structure when the upstream model
    // appends leaked CONTENT_FILTER markers after a line break.
    return strings.TrimRight(text[:idx], " \t\r")
}

func shouldDropCleanedLeakedChunk(cleaned string) bool {
    if cleaned == "" {
        return true
    }
    // Preserve newline-only chunks to avoid dropping legitimate line breaks
    // before a leaked CONTENT_FILTER suffix.
    if strings.Contains(cleaned, "\n") {
        return false
    }
    return strings.TrimSpace(cleaned) == ""
}
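The stripping rule above is simple but easy to get subtly wrong: cut at the first case-insensitive `CONTENT_FILTER` marker, trim trailing spaces/tabs/CR, but keep `\n`. A self-contained sketch of just that function, runnable on its own:

```go
package main

import (
	"fmt"
	"strings"
)

// stripLeak mirrors stripLeakedContentFilterSuffix above: drop everything
// from the first (case-insensitive) CONTENT_FILTER marker onward, trimming
// trailing spaces/tabs/CR but preserving "\n" so line structure survives.
func stripLeak(text string) string {
	idx := strings.Index(strings.ToUpper(text), "CONTENT_FILTER")
	if idx < 0 {
		return text
	}
	return strings.TrimRight(text[:idx], " \t\r")
}

func main() {
	fmt.Printf("%q\n", stripLeak("line1\nCONTENT_FILTERblocked")) // "line1\n"
}
```

Note the trade-off the tests later in this diff pin down: legitimate prose that merely mentions `CONTENT_FILTER` is also truncated, which the author accepts in exchange for never leaking the blocked tail.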
@@ -10,6 +10,7 @@ type LineResult struct {
    ErrorMessage string
    Parts []ContentPart
    NextType string
    OutputTokens int
}

// ParseDeepSeekContentLine centralizes one-line DeepSeek SSE parsing for both
@@ -35,15 +36,26 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
            Parsed: true,
            Stop: true,
            ContentFilter: true,
            ErrorMessage: "content filtered by upstream",
            NextType: currentType,
            OutputTokens: extractAccumulatedTokenUsage(chunk),
        }
    }
    if hasContentFilterStatus(chunk) {
        return LineResult{
            Parsed: true,
            Stop: true,
            ContentFilter: true,
            NextType: currentType,
            OutputTokens: extractAccumulatedTokenUsage(chunk),
        }
    }
    parts, finished, nextType := ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
    parts = filterLeakedContentFilterParts(parts)
    return LineResult{
        Parsed: true,
        Stop: finished,
        Parts: parts,
        NextType: nextType,
        OutputTokens: extractAccumulatedTokenUsage(chunk),
    }
}

@@ -40,8 +40,8 @@ func TestParseDeepSeekContentLineContentFilterMessage(t *testing.T) {
    if !res.ContentFilter {
        t.Fatal("expected content filter flag")
    }
    if res.ErrorMessage == "" {
        t.Fatal("expected error message on content filter")
    if res.ErrorMessage != "" {
        t.Fatalf("expected empty error message on content filter, got %q", res.ErrorMessage)
    }
}

@@ -26,6 +26,33 @@ func TestParseDeepSeekContentLineContentFilter(t *testing.T) {
    }
}

func TestParseDeepSeekContentLineContentFilterCodeIncludesOutputTokens(t *testing.T) {
    res := ParseDeepSeekContentLine(
        []byte(`data: {"code":"content_filter","accumulated_token_usage":99}`),
        false, "text",
    )
    if !res.Parsed || !res.Stop || !res.ContentFilter {
        t.Fatalf("expected content-filter stop result: %#v", res)
    }
    if res.OutputTokens != 99 {
        t.Fatalf("expected output token usage 99, got %d", res.OutputTokens)
    }
}

func TestParseDeepSeekContentLineContentFilterStatus(t *testing.T) {
    res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/status","v":"CONTENT_FILTER"}`), false, "text")
    if !res.Parsed || !res.Stop || !res.ContentFilter {
        t.Fatalf("expected status-based content-filter stop result: %#v", res)
    }
}

func TestParseDeepSeekContentLineCapturesAccumulatedTokenUsage(t *testing.T) {
    res := ParseDeepSeekContentLine([]byte(`data: {"p":"response","o":"BATCH","v":[{"p":"accumulated_token_usage","v":1383},{"p":"quasi_status","v":"FINISHED"}]}`), false, "text")
    if res.OutputTokens != 1383 {
        t.Fatalf("expected output token usage 1383, got %d", res.OutputTokens)
    }
}

func TestParseDeepSeekContentLineContent(t *testing.T) {
    res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/content","v":"hi"}`), false, "text")
    if !res.Parsed || res.Stop {
@@ -35,3 +62,63 @@ func TestParseDeepSeekContentLineContent(t *testing.T) {
        t.Fatalf("unexpected parts: %#v", res.Parts)
    }
}

func TestParseDeepSeekContentLineStripsLeakedContentFilterSuffix(t *testing.T) {
    res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/content","v":"正常输出CONTENT_FILTER你好,这个问题我暂时无法回答"}`), false, "text")
    if !res.Parsed || res.Stop {
        t.Fatalf("expected parsed non-stop result: %#v", res)
    }
    if len(res.Parts) != 1 || res.Parts[0].Text != "正常输出" {
        t.Fatalf("unexpected parts after filter: %#v", res.Parts)
    }
}

func TestParseDeepSeekContentLineDropsPureLeakedContentFilterChunk(t *testing.T) {
    res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/content","v":"CONTENT_FILTER你好,这个问题我暂时无法回答"}`), false, "text")
    if !res.Parsed || res.Stop {
        t.Fatalf("expected parsed non-stop result: %#v", res)
    }
    if len(res.Parts) != 0 {
        t.Fatalf("expected empty parts, got %#v", res.Parts)
    }
}

func TestParseDeepSeekContentLineTrimsFromContentFilterKeyword(t *testing.T) {
    res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/content","v":"模型会在命中 CONTENT_FILTER 时返回拒绝原因。"}`), false, "text")
    if !res.Parsed || res.Stop {
        t.Fatalf("expected parsed non-stop result: %#v", res)
    }
    if len(res.Parts) != 1 || res.Parts[0].Text != "模型会在命中" {
        t.Fatalf("unexpected parts after filter: %#v", res.Parts)
    }
}

func TestParseDeepSeekContentLineContentTextEqualContentFilterDoesNotStop(t *testing.T) {
    res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/content","v":"content_filter"}`), false, "text")
    if !res.Parsed {
        t.Fatalf("expected parsed result: %#v", res)
    }
    if res.Stop || res.ContentFilter {
        t.Fatalf("did not expect content-filter stop for content text: %#v", res)
    }
}

func TestParseDeepSeekContentLinePreservesTrailingNewlineBeforeLeakedContentFilter(t *testing.T) {
    res := ParseDeepSeekContentLine([]byte("data: {\"p\":\"response/content\",\"v\":\"line1\\nCONTENT_FILTERblocked\"}"), false, "text")
    if !res.Parsed || res.Stop {
        t.Fatalf("expected parsed non-stop result: %#v", res)
    }
    if len(res.Parts) != 1 || res.Parts[0].Text != "line1\n" {
        t.Fatalf("expected trailing newline preserved, got %#v", res.Parts)
    }
}

func TestParseDeepSeekContentLineKeepsNewlineOnlyChunkBeforeLeakedContentFilter(t *testing.T) {
    res := ParseDeepSeekContentLine([]byte("data: {\"p\":\"response/content\",\"v\":\"\\nCONTENT_FILTERblocked\"}"), false, "text")
    if !res.Parsed || res.Stop {
        t.Fatalf("expected parsed non-stop result: %#v", res)
    }
    if len(res.Parts) != 1 || res.Parts[0].Text != "\n" {
        t.Fatalf("expected newline-only chunk preserved, got %#v", res.Parts)
    }
}

@@ -3,6 +3,7 @@ package sse
import (
    "bytes"
    "encoding/json"
    "math"
    "strings"

    "ds2api/internal/deepseek"
@@ -30,6 +31,9 @@ func ParseDeepSeekSSELine(raw []byte) (map[string]any, bool, bool) {
}

func shouldSkipPath(path string) bool {
    if isFragmentStatusPath(path) {
        return true
    }
    if _, ok := deepseek.SkipExactPathSet[path]; ok {
        return true
    }
@@ -41,6 +45,31 @@ func shouldSkipPath(path string) bool {
    return false
}

func isFragmentStatusPath(path string) bool {
    if path == "" || path == "response/status" {
        return false
    }
    if !strings.HasPrefix(path, "response/fragments/") || !strings.HasSuffix(path, "/status") {
        return false
    }
    mid := strings.TrimSuffix(strings.TrimPrefix(path, "response/fragments/"), "/status")
    if mid == "" {
        return false
    }
    if strings.HasPrefix(mid, "-") {
        mid = mid[1:]
    }
    if mid == "" {
        return false
    }
    for _, r := range mid {
        if r < '0' || r > '9' {
            return false
        }
    }
    return true
}

func ParseSSEChunkForContent(chunk map[string]any, thinkingEnabled bool, currentFragmentType string) ([]ContentPart, bool, string) {
    v, ok := chunk["v"]
    if !ok {
@@ -287,3 +316,90 @@ func extractContentRecursive(items []any, defaultType string) ([]ContentPart, bo
func IsCitation(text string) bool {
    return bytes.HasPrefix([]byte(strings.TrimSpace(text)), []byte("[citation:"))
}

func hasContentFilterStatus(chunk map[string]any) bool {
    if code, _ := chunk["code"].(string); strings.EqualFold(strings.TrimSpace(code), "content_filter") {
        return true
    }
    return hasContentFilterStatusValue(chunk)
}

func hasContentFilterStatusValue(v any) bool {
    switch x := v.(type) {
    case []any:
        for _, item := range x {
            if hasContentFilterStatusValue(item) {
                return true
            }
        }
    case map[string]any:
        if p, _ := x["p"].(string); strings.Contains(strings.ToLower(p), "status") {
            if s, _ := x["v"].(string); strings.EqualFold(strings.TrimSpace(s), "content_filter") {
                return true
            }
        }
        if code, _ := x["code"].(string); strings.EqualFold(strings.TrimSpace(code), "content_filter") {
            return true
        }
        for _, vv := range x {
            if hasContentFilterStatusValue(vv) {
                return true
            }
        }
    }
    return false
}

func extractAccumulatedTokenUsage(chunk map[string]any) int {
    return findAccumulatedTokenUsage(chunk)
}

func findAccumulatedTokenUsage(v any) int {
    switch x := v.(type) {
    case map[string]any:
        if p, _ := x["p"].(string); strings.Contains(strings.ToLower(p), "accumulated_token_usage") {
            if n, ok := toInt(x["v"]); ok && n > 0 {
                return n
            }
        }
        if n, ok := toInt(x["accumulated_token_usage"]); ok && n > 0 {
            return n
        }
        for _, vv := range x {
            if n := findAccumulatedTokenUsage(vv); n > 0 {
                return n
            }
        }
    case []any:
        for _, item := range x {
            if n := findAccumulatedTokenUsage(item); n > 0 {
                return n
            }
        }
    }
    return 0
}

func toInt(v any) (int, bool) {
    switch x := v.(type) {
    case int:
        return x, true
    case int32:
        return int(x), true
    case int64:
        return int(x), true
    case float64:
        if math.IsNaN(x) || math.IsInf(x, 0) {
            return 0, false
        }
        return int(x), true
    case json.Number:
        i, err := x.Int64()
        if err != nil {
            return 0, false
        }
        return int(i), true
    default:
        return 0, false
    }
}

@@ -90,6 +90,15 @@ func TestShouldSkipPathFragmentStatus(t *testing.T) {
|
||||
if !shouldSkipPath("response/fragments/-3/status") {
|
||||
t.Fatal("expected skip for fragment -3 status")
|
||||
}
|
||||
if !shouldSkipPath("response/fragments/-16/status") {
|
||||
t.Fatal("expected skip for fragment -16 status")
|
||||
}
|
||||
if !shouldSkipPath("response/fragments/7/status") {
|
||||
t.Fatal("expected skip for fragment 7 status")
|
||||
}
|
||||
if shouldSkipPath("response/status") {
|
||||
t.Fatal("expected response/status to be handled by finish logic, not skipped")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSkipPathRegularContent(t *testing.T) {
|
||||
|
||||
67	internal/translatorcliproxy/bridge.go	Normal file
@@ -0,0 +1,67 @@
package translatorcliproxy

import (
	"bytes"
	"context"
	"strings"

	sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
	_ "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator/builtin"
)

func ToOpenAI(from sdktranslator.Format, model string, raw []byte, stream bool) []byte {
	return sdktranslator.TranslateRequest(from, sdktranslator.FormatOpenAI, model, raw, stream)
}

func FromOpenAINonStream(to sdktranslator.Format, model string, originalReq, translatedReq, raw []byte) []byte {
	var param any
	return sdktranslator.TranslateNonStream(context.Background(), sdktranslator.FormatOpenAI, to, model, originalReq, translatedReq, raw, &param)
}

func FromOpenAIStream(to sdktranslator.Format, model string, originalReq, translatedReq, streamBody []byte) []byte {
	var out bytes.Buffer
	var param any
	for _, line := range bytes.Split(streamBody, []byte("\n")) {
		trimmed := strings.TrimSpace(string(line))
		if trimmed == "" {
			continue
		}
		payload := append([]byte(nil), line...)
		if !bytes.HasPrefix(payload, []byte("data:")) {
			continue
		}
		chunks := sdktranslator.TranslateStream(context.Background(), sdktranslator.FormatOpenAI, to, model, originalReq, translatedReq, payload, &param)
		for i := range chunks {
			out.Write(chunks[i])
			if !bytes.HasSuffix(chunks[i], []byte("\n")) {
				out.WriteByte('\n')
			}
		}
	}
	return out.Bytes()
}

func ParseFormat(name string) sdktranslator.Format {
	switch strings.ToLower(strings.TrimSpace(name)) {
	case "openai", "openai-chat", "chat", "chat-completions":
		return sdktranslator.FormatOpenAI
	case "openai-response", "responses", "openai-responses":
		return sdktranslator.FormatOpenAIResponse
	case "claude", "anthropic":
		return sdktranslator.FormatClaude
	case "gemini", "google":
		return sdktranslator.FormatGemini
	case "gemini-cli", "geminicli":
		return sdktranslator.FormatGeminiCLI
	case "codex", "openai-codex":
		return sdktranslator.FormatCodex
	case "antigravity":
		return sdktranslator.FormatAntigravity
	default:
		return sdktranslator.FromString(name)
	}
}

func ToOpenAIByName(formatName, model string, raw []byte, stream bool) []byte {
	return ToOpenAI(ParseFormat(formatName), model, raw, stream)
}
72	internal/translatorcliproxy/bridge_test.go	Normal file
@@ -0,0 +1,72 @@
package translatorcliproxy

import (
	"strings"
	"testing"

	sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
)

func TestToOpenAIClaude(t *testing.T) {
	raw := []byte(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":false}`)
	got := ToOpenAI(sdktranslator.FormatClaude, "claude-sonnet-4-5", raw, false)
	s := string(got)
	if !strings.Contains(s, `"messages"`) || !strings.Contains(s, `"model"`) {
		t.Fatalf("unexpected translated request: %s", s)
	}
}

func TestFromOpenAINonStreamClaude(t *testing.T) {
	original := []byte(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":false}`)
	translatedReq := []byte(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":false}`)
	openaibody := []byte(`{"id":"chatcmpl_1","object":"chat.completion","created":1,"model":"claude-sonnet-4-5","choices":[{"index":0,"message":{"role":"assistant","content":"hello"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`)
	got := FromOpenAINonStream(sdktranslator.FormatClaude, "claude-sonnet-4-5", original, translatedReq, openaibody)
	if !strings.Contains(string(got), `"type":"message"`) {
		t.Fatalf("expected claude response format, got: %s", string(got))
	}
}

func TestParseFormatAliases(t *testing.T) {
	cases := map[string]sdktranslator.Format{
		"responses":        sdktranslator.FormatOpenAIResponse,
		"anthropic":        sdktranslator.FormatClaude,
		"geminicli":        sdktranslator.FormatGeminiCLI,
		"openai-codex":     sdktranslator.FormatCodex,
		"antigravity":      sdktranslator.FormatAntigravity,
		"chat-completions": sdktranslator.FormatOpenAI,
	}
	for in, want := range cases {
		if got := ParseFormat(in); got != want {
			t.Fatalf("ParseFormat(%q)=%q want %q", in, got, want)
		}
	}
}

func TestToOpenAIByNameAllSupportedFormats(t *testing.T) {
	tests := []struct {
		name   string
		format string
		model  string
		body   string
	}{
		{name: "openai", format: "openai", model: "gpt-4.1", body: `{"model":"gpt-4.1","messages":[{"role":"user","content":"hi"}],"stream":false}`},
		{name: "responses", format: "responses", model: "gpt-4.1", body: `{"model":"gpt-4.1","input":"hello","stream":false}`},
		{name: "claude", format: "claude", model: "claude-sonnet-4-5", body: `{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hello"}],"stream":false}`},
		{name: "gemini", format: "gemini", model: "gemini-2.5-pro", body: `{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`},
		{name: "gemini-cli", format: "gemini-cli", model: "gemini-2.5-pro", body: `{"model":"gemini-2.5-pro","messages":[{"role":"user","content":"hello"}],"stream":false}`},
		{name: "codex", format: "codex", model: "gpt-5-codex", body: `{"model":"gpt-5-codex","messages":[{"role":"user","content":"hello"}],"stream":false}`},
		{name: "antigravity", format: "antigravity", model: "gpt-4.1", body: `{"model":"gpt-4.1","messages":[{"role":"user","content":"hello"}],"stream":false}`},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got := ToOpenAIByName(tc.format, tc.model, []byte(tc.body), false)
			if len(got) == 0 {
				t.Fatalf("expected non-empty conversion result for format=%s", tc.format)
			}
			if !strings.Contains(string(got), `"model"`) {
				t.Fatalf("expected model field in converted payload, got=%s", string(got))
			}
		})
	}
}
120	internal/translatorcliproxy/stream_writer.go	Normal file
@@ -0,0 +1,120 @@
package translatorcliproxy

import (
	"bytes"
	"context"
	"net/http"

	sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
)

// OpenAIStreamTranslatorWriter translates OpenAI SSE output to another client format in real-time.
type OpenAIStreamTranslatorWriter struct {
	dst           http.ResponseWriter
	target        sdktranslator.Format
	model         string
	originalReq   []byte
	translatedReq []byte
	param         any
	statusCode    int
	headersSent   bool
	lineBuf       bytes.Buffer
}

func NewOpenAIStreamTranslatorWriter(dst http.ResponseWriter, target sdktranslator.Format, model string, originalReq, translatedReq []byte) *OpenAIStreamTranslatorWriter {
	return &OpenAIStreamTranslatorWriter{
		dst:           dst,
		target:        target,
		model:         model,
		originalReq:   originalReq,
		translatedReq: translatedReq,
		statusCode:    http.StatusOK,
	}
}

func (w *OpenAIStreamTranslatorWriter) Header() http.Header {
	return w.dst.Header()
}

func (w *OpenAIStreamTranslatorWriter) WriteHeader(statusCode int) {
	if w.headersSent {
		return
	}
	w.statusCode = statusCode
	w.headersSent = true
	w.dst.WriteHeader(statusCode)
}

func (w *OpenAIStreamTranslatorWriter) Write(p []byte) (int, error) {
	if !w.headersSent {
		w.WriteHeader(http.StatusOK)
	}
	if w.statusCode < 200 || w.statusCode >= 300 {
		return w.dst.Write(p)
	}
	w.lineBuf.Write(p)
	for {
		line, ok := w.readOneLine()
		if !ok {
			break
		}
		trimmed := bytes.TrimSpace(line)
		if len(trimmed) == 0 {
			continue
		}
		if bytes.HasPrefix(trimmed, []byte(":")) {
			if _, err := w.dst.Write(trimmed); err != nil {
				return len(p), err
			}
			if _, err := w.dst.Write([]byte("\n\n")); err != nil {
				return len(p), err
			}
			if f, ok := w.dst.(http.Flusher); ok {
				f.Flush()
			}
			continue
		}
		if !bytes.HasPrefix(trimmed, []byte("data:")) {
			continue
		}
		chunks := sdktranslator.TranslateStream(context.Background(), sdktranslator.FormatOpenAI, w.target, w.model, w.originalReq, w.translatedReq, trimmed, &w.param)
		for i := range chunks {
			if len(chunks[i]) == 0 {
				continue
			}
			if _, err := w.dst.Write(chunks[i]); err != nil {
				return len(p), err
			}
			if !bytes.HasSuffix(chunks[i], []byte("\n")) {
				if _, err := w.dst.Write([]byte("\n")); err != nil {
					return len(p), err
				}
			}
		}
		if f, ok := w.dst.(http.Flusher); ok {
			f.Flush()
		}
	}
	return len(p), nil
}

func (w *OpenAIStreamTranslatorWriter) Flush() {
	if f, ok := w.dst.(http.Flusher); ok {
		f.Flush()
	}
}

func (w *OpenAIStreamTranslatorWriter) Unwrap() http.ResponseWriter {
	return w.dst
}

func (w *OpenAIStreamTranslatorWriter) readOneLine() ([]byte, bool) {
	b := w.lineBuf.Bytes()
	idx := bytes.IndexByte(b, '\n')
	if idx < 0 {
		return nil, false
	}
	line := append([]byte(nil), b[:idx]...)
	w.lineBuf.Next(idx + 1)
	return line, true
}
57	internal/translatorcliproxy/stream_writer_test.go	Normal file
@@ -0,0 +1,57 @@
package translatorcliproxy

import (
	"net/http/httptest"
	"strings"
	"testing"

	sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
)

func TestOpenAIStreamTranslatorWriterClaude(t *testing.T) {
	original := []byte(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":true}`)
	translated := []byte(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":true}`)

	rec := httptest.NewRecorder()
	w := NewOpenAIStreamTranslatorWriter(rec, sdktranslator.FormatClaude, "claude-sonnet-4-5", original, translated)
	w.Header().Set("Content-Type", "text/event-stream")
	w.WriteHeader(200)
	_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl_1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"claude-sonnet-4-5\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\"},\"finish_reason\":null}]}\n\n"))
	_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl_1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"claude-sonnet-4-5\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"},\"finish_reason\":null}]}\n\n"))
	_, _ = w.Write([]byte("data: [DONE]\n\n"))

	body := rec.Body.String()
	if !strings.Contains(body, "event: message_start") {
		t.Fatalf("expected claude message_start event, got: %s", body)
	}
}

func TestOpenAIStreamTranslatorWriterGemini(t *testing.T) {
	original := []byte(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`)
	translated := []byte(`{"model":"gemini-2.5-pro","messages":[{"role":"user","content":"hi"}],"stream":true}`)

	rec := httptest.NewRecorder()
	w := NewOpenAIStreamTranslatorWriter(rec, sdktranslator.FormatGemini, "gemini-2.5-pro", original, translated)
	w.Header().Set("Content-Type", "text/event-stream")
	w.WriteHeader(200)
	_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl_1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gemini-2.5-pro\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"},\"finish_reason\":null}]}\n\n"))
	_, _ = w.Write([]byte("data: [DONE]\n\n"))

	body := rec.Body.String()
	if !strings.Contains(body, "candidates") {
		t.Fatalf("expected gemini stream payload, got: %s", body)
	}
}

func TestOpenAIStreamTranslatorWriterPreservesKeepAliveComment(t *testing.T) {
	rec := httptest.NewRecorder()
	w := NewOpenAIStreamTranslatorWriter(rec, sdktranslator.FormatGemini, "gemini-2.5-pro", []byte(`{}`), []byte(`{}`))
	w.Header().Set("Content-Type", "text/event-stream")
	w.WriteHeader(200)
	_, _ = w.Write([]byte(": keep-alive\n\n"))

	body := rec.Body.String()
	if !strings.Contains(body, ": keep-alive\n\n") {
		t.Fatalf("expected keep-alive comment passthrough, got %q", body)
	}
}
@@ -1,5 +1,7 @@
 package util

+import "strings"
+
 // BuildToolCallInstructions generates the unified tool-calling instruction block
 // used by all adapters (OpenAI, Claude, Gemini). It uses attention-optimized
 // structure: rules → negative examples → positive examples → anchor.

@@ -19,7 +21,7 @@ func BuildToolCallInstructions(toolNames []string) string {
 			ex1 = n
 			used["ex1"] = true
 		// Write/execute-type tools
-		case !used["ex2"] && matchAny(n, "write_to_file", "apply_diff", "execute_command", "Write", "Edit", "MultiEdit", "Bash"):
+		case !used["ex2"] && matchAny(n, "write_to_file", "apply_diff", "execute_command", "exec_command", "Write", "Edit", "MultiEdit", "Bash"):
 			ex2 = n
 			used["ex2"] = true
 		// Interactive/meta tools

@@ -28,10 +30,13 @@ func BuildToolCallInstructions(toolNames []string) string {
 			used["ex3"] = true
 		}
 	}
+	ex1Params := exampleReadParams(ex1)
+	ex2Params := exampleWriteOrExecParams(ex2)
+	ex3Params := exampleInteractiveParams(ex3)

 	return `TOOL CALL FORMAT — FOLLOW EXACTLY:

-When calling tools, emit ONLY raw XML. No text before, no text after, no markdown fences.
+When calling tools, emit ONLY raw XML at the very end of your response. No text before, no text after, no markdown fences.

 <tool_calls>
 <tool_call>

@@ -47,6 +52,7 @@ RULES:
 4) Do NOT wrap the XML in markdown code fences (no triple backticks).
 5) After receiving a tool result, use it directly. Only call another tool if the result is insufficient.
 6) If you want to say something AND call a tool, output text first, then the XML block on its own.
+7) Parameters MUST use the exact field names from the selected tool schema.

 ❌ WRONG — Do NOT do these:
 Wrong 1 — mixed text and XML:

@@ -62,7 +68,7 @@ Example A — Single tool:
 <tool_calls>
 <tool_call>
 <tool_name>` + ex1 + `</tool_name>
-<parameters>{"path":"src/main.go"}</parameters>
+<parameters>` + ex1Params + `</parameters>
 </tool_call>
 </tool_calls>

@@ -70,11 +76,11 @@ Example B — Two tools in parallel:
 <tool_calls>
 <tool_call>
 <tool_name>` + ex1 + `</tool_name>
-<parameters>{"path":"config.json"}</parameters>
+<parameters>` + ex1Params + `</parameters>
 </tool_call>
 <tool_call>
 <tool_name>` + ex2 + `</tool_name>
-<parameters>{"path":"output.txt","content":"Hello world"}</parameters>
+<parameters>` + ex2Params + `</parameters>
 </tool_call>
 </tool_calls>

@@ -82,7 +88,7 @@ Example C — Tool with complex nested JSON parameters:
 <tool_calls>
 <tool_call>
 <tool_name>` + ex3 + `</tool_name>
-<parameters>{"question":"Which approach do you prefer?","follow_up":[{"text":"Option A"},{"text":"Option B"}]}</parameters>
+<parameters>` + ex3Params + `</parameters>
 </tool_call>
 </tool_calls>

@@ -97,3 +103,38 @@ func matchAny(name string, candidates ...string) bool {
 	}
 	return false
 }
+
+func exampleReadParams(name string) string {
+	switch strings.TrimSpace(name) {
+	case "Read":
+		return `{"file_path":"README.md"}`
+	case "Glob":
+		return `{"pattern":"**/*.go","path":"."}`
+	default:
+		return `{"path":"src/main.go"}`
+	}
+}
+
+func exampleWriteOrExecParams(name string) string {
+	switch strings.TrimSpace(name) {
+	case "Bash", "execute_command":
+		return `{"command":"pwd"}`
+	case "exec_command":
+		return `{"cmd":"pwd"}`
+	case "Edit":
+		return `{"file_path":"README.md","old_string":"foo","new_string":"bar"}`
+	case "MultiEdit":
+		return `{"file_path":"README.md","edits":[{"old_string":"foo","new_string":"bar"}]}`
+	default:
+		return `{"path":"output.txt","content":"Hello world"}`
+	}
+}
+
+func exampleInteractiveParams(name string) string {
+	switch strings.TrimSpace(name) {
+	case "Task":
+		return `{"description":"Investigate flaky tests","prompt":"Run targeted tests and summarize failures"}`
+	default:
+		return `{"question":"Which approach do you prefer?","follow_up":[{"text":"Option A"},{"text":"Option B"}]}`
+	}
+}
26	internal/util/tool_prompt_test.go	Normal file
@@ -0,0 +1,26 @@
package util

import (
	"strings"
	"testing"
)

func TestBuildToolCallInstructions_ExecCommandUsesCmdExample(t *testing.T) {
	out := BuildToolCallInstructions([]string{"exec_command"})
	if !strings.Contains(out, `<tool_name>exec_command</tool_name>`) {
		t.Fatalf("expected exec_command in examples, got: %s", out)
	}
	if !strings.Contains(out, `<parameters>{"cmd":"pwd"}</parameters>`) {
		t.Fatalf("expected cmd parameter example for exec_command, got: %s", out)
	}
}

func TestBuildToolCallInstructions_ExecuteCommandUsesCommandExample(t *testing.T) {
	out := BuildToolCallInstructions([]string{"execute_command"})
	if !strings.Contains(out, `<tool_name>execute_command</tool_name>`) {
		t.Fatalf("expected execute_command in examples, got: %s", out)
	}
	if !strings.Contains(out, `<parameters>{"command":"pwd"}</parameters>`) {
		t.Fatalf("expected command parameter example for execute_command, got: %s", out)
	}
}
@@ -64,7 +64,7 @@ func extractToolCallObjects(text string) []string {
 	lower := strings.ToLower(text)
 	out := []string{}
 	offset := 0
-	keywords := []string{"tool_calls", "\"function\"", "function.name:"}
+	keywords := []string{"tool_calls", "\"function\"", "function.name:", "functioncall", "\"tool_use\""}
 	for {
 		bestIdx := -1
 		matchedKeyword := ""

@@ -196,18 +196,6 @@ func parseToolCallsPayload(payload string) []ParsedToolCall {
 	return nil
 }

-func isLikelyJSONToolPayloadCandidate(candidate string) bool {
-	trimmed := strings.TrimSpace(candidate)
-	if trimmed == "" {
-		return false
-	}
-	if !(strings.HasPrefix(trimmed, "{") || strings.HasPrefix(trimmed, "[")) {
-		return false
-	}
-	lower := strings.ToLower(trimmed)
-	return strings.Contains(lower, "tool_calls") || strings.Contains(lower, "\"function\"")
-}
-
 func isLikelyChatMessageEnvelope(v map[string]any) bool {
 	if v == nil {
 		return false

@@ -234,62 +222,11 @@ func looksLikeToolCallSyntax(text string) bool {
 	lower := strings.ToLower(text)
 	return strings.Contains(lower, "tool_calls") ||
 		strings.Contains(lower, "\"function\"") ||
+		strings.Contains(lower, "functioncall") ||
+		strings.Contains(lower, "\"tool_use\"") ||
 		strings.Contains(lower, "<tool_call") ||
 		strings.Contains(lower, "<function_call") ||
 		strings.Contains(lower, "<function_name") ||
 		strings.Contains(lower, "<invoke") ||
 		strings.Contains(lower, "function.name:")
 }

-func parseToolCallList(v any) []ParsedToolCall {
-	items, ok := v.([]any)
-	if !ok {
-		return nil
-	}
-	out := make([]ParsedToolCall, 0, len(items))
-	for _, item := range items {
-		m, ok := item.(map[string]any)
-		if !ok {
-			continue
-		}
-		if tc, ok := parseToolCallItem(m); ok {
-			out = append(out, tc)
-		}
-	}
-	if len(out) == 0 {
-		return nil
-	}
-	return out
-}
-
-func parseToolCallItem(m map[string]any) (ParsedToolCall, bool) {
-	name, _ := m["name"].(string)
-	inputRaw, hasInput := m["input"]
-	if fn, ok := m["function"].(map[string]any); ok {
-		if name == "" {
-			name, _ = fn["name"].(string)
-		}
-		if !hasInput {
-			if v, ok := fn["arguments"]; ok {
-				inputRaw = v
-				hasInput = true
-			}
-		}
-	}
-	if !hasInput {
-		for _, key := range []string{"arguments", "args", "parameters", "params"} {
-			if v, ok := m[key]; ok {
-				inputRaw = v
-				hasInput = true
-				break
-			}
-		}
-	}
-	if strings.TrimSpace(name) == "" {
-		return ParsedToolCall{}, false
-	}
-	return ParsedToolCall{
-		Name:  strings.TrimSpace(name),
-		Input: parseToolCallInput(inputRaw),
-	}, true
-}
88	internal/util/toolcalls_parse_item.go	Normal file
@@ -0,0 +1,88 @@
package util

import "strings"

func isLikelyJSONToolPayloadCandidate(candidate string) bool {
	trimmed := strings.TrimSpace(candidate)
	if trimmed == "" {
		return false
	}
	if !(strings.HasPrefix(trimmed, "{") || strings.HasPrefix(trimmed, "[")) {
		return false
	}
	lower := strings.ToLower(trimmed)
	return strings.Contains(lower, "tool_calls") ||
		strings.Contains(lower, "\"function\"") ||
		strings.Contains(lower, "functioncall") ||
		strings.Contains(lower, "\"tool_use\"")
}

func parseToolCallList(v any) []ParsedToolCall {
	items, ok := v.([]any)
	if !ok {
		return nil
	}
	out := make([]ParsedToolCall, 0, len(items))
	for _, item := range items {
		m, ok := item.(map[string]any)
		if !ok {
			continue
		}
		if tc, ok := parseToolCallItem(m); ok {
			out = append(out, tc)
		}
	}
	if len(out) == 0 {
		return nil
	}
	return out
}

func parseToolCallItem(m map[string]any) (ParsedToolCall, bool) {
	name, _ := m["name"].(string)
	inputRaw, hasInput := m["input"]
	if fnCall, ok := m["functionCall"].(map[string]any); ok {
		if name == "" {
			name, _ = fnCall["name"].(string)
		}
		if !hasInput {
			if v, ok := fnCall["args"]; ok {
				inputRaw = v
				hasInput = true
			}
		}
		if !hasInput {
			if v, ok := fnCall["arguments"]; ok {
				inputRaw = v
				hasInput = true
			}
		}
	}
	if fn, ok := m["function"].(map[string]any); ok {
		if name == "" {
			name, _ = fn["name"].(string)
		}
		if !hasInput {
			if v, ok := fn["arguments"]; ok {
				inputRaw = v
				hasInput = true
			}
		}
	}
	if !hasInput {
		for _, key := range []string{"arguments", "args", "parameters", "params"} {
			if v, ok := m[key]; ok {
				inputRaw = v
				hasInput = true
				break
			}
		}
	}
	if strings.TrimSpace(name) == "" {
		return ParsedToolCall{}, false
	}
	return ParsedToolCall{
		Name:  strings.TrimSpace(name),
		Input: parseToolCallInput(inputRaw),
	}, true
}
@@ -19,6 +19,10 @@ var toolUseFunctionPattern = regexp.MustCompile(`(?is)<tool_use>\s*<function\s+n
 var toolUseNameParametersPattern = regexp.MustCompile(`(?is)<tool_use>\s*<tool_name>\s*([^<]+?)\s*</tool_name>\s*<parameters>\s*(.*?)\s*</parameters>\s*</tool_use>`)
 var toolUseFunctionNameParametersPattern = regexp.MustCompile(`(?is)<tool_use>\s*<function_name>\s*([^<]+?)\s*</function_name>\s*<parameters>\s*(.*?)\s*</parameters>\s*</tool_use>`)
+var toolUseToolNameBodyPattern = regexp.MustCompile(`(?is)<tool_use>\s*<tool_name>\s*([^<]+?)\s*</tool_name>\s*(.*?)\s*</tool_use>`)
+var xmlToolNamePatterns = []*regexp.Regexp{
+	regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?tool_name\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?tool_name>`),
+	regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?function_name\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?function_name>`),
+}

 func parseXMLToolCalls(text string) []ParsedToolCall {
 	matches := xmlToolCallPattern.FindAllString(text, -1)

@@ -81,9 +85,9 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
 		}
 	}

-	dec := xml.NewDecoder(strings.NewReader(block))
 	name := ""
-	params := map[string]any{}
+	params := extractXMLToolParamsByRegex(inner)
+	dec := xml.NewDecoder(strings.NewReader(block))
 	inParams := false
 	inTool := false
 	for {

@@ -132,9 +136,13 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
 				}
 			}
 			inParams = false
-		case "tool_name", "name":
+		case "tool_name", "function_name", "name":
 			var v string
 			if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" {
+				if inParams {
+					params[t.Name.Local] = strings.TrimSpace(v)
+					break
+				}
 				name = strings.TrimSpace(v)
 			}
 		case "input", "arguments", "argument", "args", "params":

@@ -164,12 +172,60 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
 			}
 		}
 	}
+	if strings.TrimSpace(name) == "" {
+		name = strings.TrimSpace(extractXMLToolNameByRegex(stripTopLevelXMLParameters(inner)))
+	}
 	if strings.TrimSpace(name) == "" {
 		return ParsedToolCall{}, false
 	}
 	return ParsedToolCall{Name: strings.TrimSpace(name), Input: params}, true
 }

+func stripTopLevelXMLParameters(inner string) string {
+	out := strings.TrimSpace(inner)
+	for {
+		idx := strings.Index(strings.ToLower(out), "<parameters")
+		if idx < 0 {
+			return out
+		}
+		segment := out[idx:]
+		segmentLower := strings.ToLower(segment)
+		openEnd := strings.Index(segmentLower, ">")
+		if openEnd < 0 {
+			return out
+		}
+		closeIdx := strings.Index(segmentLower, "</parameters>")
+		if closeIdx < 0 {
+			return out[:idx]
+		}
+		end := idx + closeIdx + len("</parameters>")
+		out = out[:idx] + out[end:]
+	}
+}
+
+func extractXMLToolNameByRegex(inner string) string {
+	for _, pattern := range xmlToolNamePatterns {
+		if m := pattern.FindStringSubmatch(inner); len(m) >= 2 {
+			if v := strings.TrimSpace(stripTagText(m[1])); v != "" {
+				return v
+			}
+		}
+	}
+	return ""
+}
+
+func extractXMLToolParamsByRegex(inner string) map[string]any {
+	raw := findMarkupTagValue(inner, toolCallMarkupArgsTagNames, toolCallMarkupArgsPatternByTag)
+	if raw == "" {
+		return map[string]any{}
+	}
+	parsed := parseMarkupInput(raw)
+	if parsed == nil {
+		return map[string]any{}
+	}
+	return parsed
+}

 func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) {
 	m := functionCallPattern.FindStringSubmatch(text)
 	if len(m) < 2 {
@@ -176,6 +176,35 @@ func TestParseToolCallsSupportsCanonicalXMLParametersJSON(t *testing.T) {
 	}
 }

+func TestParseToolCallsSupportsXMLParametersJSONWithAmpersandCommand(t *testing.T) {
+	text := `<tool_calls><tool_call><tool_name>execute_command</tool_name><parameters>{"command":"sshpass -p 'xxx' ssh -o StrictHostKeyChecking=no -p 1111 root@111.111.111.111 'cd /root && git clone https://github.com/ericc-ch/copilot-api.git'","cwd":null,"timeout":null}</parameters></tool_call></tool_calls>`
+	calls := ParseToolCalls(text, []string{"execute_command"})
+	if len(calls) != 1 {
+		t.Fatalf("expected 1 call, got %#v", calls)
+	}
+	if calls[0].Name != "execute_command" {
+		t.Fatalf("expected tool name execute_command, got %q", calls[0].Name)
+	}
+	cmd, _ := calls[0].Input["command"].(string)
+	if !strings.Contains(cmd, "&& git clone") {
+		t.Fatalf("expected command to keep && segment, got %#v", calls[0].Input)
+	}
+}
+
+func TestParseToolCallsDoesNotTreatParameterNameTagAsToolName(t *testing.T) {
+	text := `<tool_call><tool name="execute_command"><parameters><name>file.txt</name><command>pwd</command></parameters></tool></tool_call>`
+	calls := ParseToolCalls(text, []string{"execute_command"})
+	if len(calls) != 1 {
+		t.Fatalf("expected 1 call, got %#v", calls)
+	}
+	if calls[0].Name != "execute_command" {
+		t.Fatalf("expected tool name execute_command, got %q", calls[0].Name)
+	}
+	if calls[0].Input["name"] != "file.txt" {
+		t.Fatalf("expected parameter name preserved, got %#v", calls[0].Input)
+	}
+}
+
 func TestParseToolCallsPrefersJSONPayloadOverIncidentalXMLInString(t *testing.T) {
 	text := `{"tool_calls":[{"name":"search","input":{"q":"latest <tool_call><tool_name>wrong</tool_name><parameters>{\"x\":1}</parameters></tool_call>"}}]}`
 	calls := ParseToolCallsDetailed(text, []string{"search"}).Calls

@@ -271,6 +300,34 @@ func TestParseToolCallsSupportsInvokeFunctionCallStyle(t *testing.T) {
 	}
 }

+func TestParseToolCallsSupportsGeminiFunctionCallJSON(t *testing.T) {
+	text := `{"functionCall":{"name":"search_web","args":{"query":"latest"}}}`
+	calls := ParseToolCalls(text, []string{"search_web"})
+	if len(calls) != 1 {
+		t.Fatalf("expected 1 call, got %#v", calls)
+	}
+	if calls[0].Name != "search_web" {
+		t.Fatalf("expected search_web, got %q", calls[0].Name)
+	}
+	if calls[0].Input["query"] != "latest" {
+		t.Fatalf("expected query argument, got %#v", calls[0].Input)
+	}
+}
+
+func TestParseToolCallsSupportsClaudeToolUseJSON(t *testing.T) {
+	text := `{"type":"tool_use","name":"read_file","input":{"path":"README.md"}}`
+	calls := ParseToolCalls(text, []string{"read_file"})
+	if len(calls) != 1 {
+		t.Fatalf("expected 1 call, got %#v", calls)
+	}
+	if calls[0].Name != "read_file" {
+		t.Fatalf("expected read_file, got %q", calls[0].Name)
+	}
+	if calls[0].Input["path"] != "README.md" {
+		t.Fatalf("expected path argument, got %#v", calls[0].Input)
+	}
+}
+
 func TestParseToolCallsSupportsToolUseFunctionParameterStyle(t *testing.T) {
 	text := `<tool_use><function name="search_web"><parameter name="query">test</parameter></function></tool_use>`
 	calls := ParseToolCalls(text, []string{"search_web"})

@@ -374,6 +431,14 @@ func TestParseToolCallsDoesNotAcceptMismatchedMarkupTags(t *testing.T) {
 	}
 }

+func TestParseToolCallsDoesNotTreatParametersFunctionNameAsToolName(t *testing.T) {
+	text := `<tool_call><parameters><function_name>data_only</function_name><path>README.md</path></parameters></tool_call>`
+	calls := ParseToolCalls(text, []string{"read_file"})
+	if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool call when function_name appears only under parameters, got %#v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairInvalidJSONBackslashes(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
|
||||
@@ -17,6 +17,9 @@ const {
  normalizePreparedToolNames,
  boolDefaultTrue,
  filterIncrementalToolCallDeltasByAllowed,
  shouldSkipPath,
  isNodeStreamSupportedPath,
  extractPathname,
} = handler.__test;

test('chat-stream exposes parser test hooks', () => {
@@ -218,3 +221,21 @@ test('parseChunkForContent supports wrapped response.fragments object shape', ()
  assert.equal(parsed.finished, false);
  assert.equal(parsed.parts.map((p) => p.text).join(''), 'AB');
});

test('shouldSkipPath skips dynamic response/fragments/*/status paths only', () => {
  assert.equal(shouldSkipPath('response/fragments/-16/status'), true);
  assert.equal(shouldSkipPath('response/fragments/8/status'), true);
  assert.equal(shouldSkipPath('response/status'), false);
});

test('node stream path guard only allows /v1/chat/completions', () => {
  assert.equal(isNodeStreamSupportedPath('/v1/chat/completions'), true);
  assert.equal(isNodeStreamSupportedPath('/v1/chat/completions?x=1'), true);
  assert.equal(isNodeStreamSupportedPath('/v1beta/models/gemini-2.5-flash:streamGenerateContent'), false);
  assert.equal(isNodeStreamSupportedPath('/anthropic/v1/messages'), false);
});

test('extractPathname strips query only', () => {
  assert.equal(extractPathname('/v1/chat/completions?stream=true'), '/v1/chat/completions');
  assert.equal(extractPathname('/v1beta/models/gemini-2.5-flash:streamGenerateContent?key=1'), '/v1beta/models/gemini-2.5-flash:streamGenerateContent');
});

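The three path guards exercised by these tests can be sketched as follows. This is a minimal sketch inferred only from the assertions; the real implementations live in `api/chat-stream.js` and are reached through `handler.__test`, so the bodies here are assumptions, not the shipped code:

```javascript
// Matches response/fragments/<index>/status for any integer index,
// including negative ones like -16; top-level response/status does NOT match.
const SKIP_STATUS_RE = /^response\/fragments\/-?\d+\/status$/;

function shouldSkipPath(path) {
  // Drop per-fragment status updates so their "FINISHED" values never
  // reach the output text; keep response/status for end-of-stream detection.
  return SKIP_STATUS_RE.test(path);
}

function extractPathname(rawPath) {
  // Strip only the query string; colon-suffixed Gemini action paths stay intact.
  const q = rawPath.indexOf('?');
  return q === -1 ? rawPath : rawPath.slice(0, q);
}

function isNodeStreamSupportedPath(rawPath) {
  // The Node stream transport is only used for the OpenAI-compatible route.
  return extractPathname(rawPath) === '/v1/chat/completions';
}
```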
@@ -108,6 +108,24 @@ test('parseToolCalls parses text-kv fallback payload', () => {
  assert.equal(calls[0].input.command, 'cd scripts && python check_syntax.py example.py');
});

test('parseToolCalls supports Gemini functionCall JSON payload', () => {
  const payload = JSON.stringify({
    functionCall: { name: 'search_web', args: { query: 'latest' } },
  });
  const calls = parseToolCalls(payload, ['search_web']);
  assert.deepEqual(calls, [{ name: 'search_web', input: { query: 'latest' } }]);
});

test('parseToolCalls supports Claude tool_use JSON payload', () => {
  const payload = JSON.stringify({
    type: 'tool_use',
    name: 'read_file',
    input: { path: 'README.md' },
  });
  const calls = parseToolCalls(payload, ['read_file']);
  assert.deepEqual(calls, [{ name: 'read_file', input: { path: 'README.md' } }]);
});

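The two vendor payload shapes accepted above (Gemini `functionCall` and Claude `tool_use`) can be normalized into the common `{ name, input }` form with a small helper. This is a hypothetical sketch: `normalizeToolPayload` is not part of the codebase, and the real logic sits inside `parseToolCalls`:

```javascript
// Normalize a parsed JSON object into { name, input } or null if it is
// neither a Gemini functionCall nor a Claude tool_use payload.
function normalizeToolPayload(obj) {
  if (obj && obj.functionCall && typeof obj.functionCall.name === 'string') {
    // Gemini style: {"functionCall":{"name":...,"args":{...}}}
    return { name: obj.functionCall.name, input: obj.functionCall.args || {} };
  }
  if (obj && obj.type === 'tool_use' && typeof obj.name === 'string') {
    // Claude style: {"type":"tool_use","name":...,"input":{...}}
    return { name: obj.name, input: obj.input || {} };
  }
  return null;
}
```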
test('parseToolCalls parses multiple text-kv fallback payloads', () => {
  const text = [
    'function.name: read_file',
@@ -227,6 +245,24 @@ test('sieve flushes incomplete captured XML tool blocks without leaking raw tags
  assert.equal(leakedText.includes('<tool_call'), false);
});

test('sieve captures XML wrapper tags with attributes without leaking wrapper text', () => {
  const events = runSieve(
    [
      '前置正文H。',
      '<tool_calls id="x"><tool_call><tool_name>read_file</tool_name><parameters>{"path":"README.MD"}</parameters></tool_call></tool_calls>',
      '后置正文I。',
    ],
    ['read_file'],
  );
  const leakedText = collectText(events);
  const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && evt.calls?.length > 0);
  assert.equal(hasToolCall, true);
  assert.equal(leakedText.includes('前置正文H。'), true);
  assert.equal(leakedText.includes('后置正文I。'), true);
  assert.equal(leakedText.includes('<tool_calls id=\"x\">'), false);
  assert.equal(leakedText.includes('</tool_calls>'), false);
});

test('sieve still intercepts large tool json payloads over previous capture limit', () => {
  const large = 'a'.repeat(9000);
  const payload = `{"tool_calls":[{"name":"read_file","input":{"path":"${large}"}}]}`;
@@ -252,6 +288,46 @@ test('sieve keeps plain text intact in tool mode when no tool call appears', ()
  assert.equal(leakedText, '你好,这是普通文本回复。请继续。');
});

test('sieve keeps plain "tool_calls" prose as text when no valid payload follows', () => {
  const events = runSieve(
    ['前置。', '这里提到 tool_calls 只是解释,不是调用。', '后置。'],
    ['read_file'],
  );
  const leakedText = collectText(events);
  const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && evt.calls?.length > 0);
  assert.equal(hasToolCall, false);
  assert.equal(leakedText.includes('tool_calls'), true);
  assert.equal(leakedText, '前置。这里提到 tool_calls 只是解释,不是调用。后置。');
});

test('sieve keeps numbered planning prose before a real tool payload (mobile-chat style)', () => {
  const events = runSieve(
    [
      '好的,我会依次测试每个工具,先把所有工具都调用一遍,然后汇总结果给你看。\n\n1. 获取当前时间\n',
      '{"tool_calls":[{"name":"get_current_time","input":{}}]}',
    ],
    ['get_current_time'],
  );
  const leakedText = collectText(events);
  const finalCalls = events.filter((evt) => evt.type === 'tool_calls').flatMap((evt) => evt.calls || []);
  assert.equal(finalCalls.length, 1);
  assert.equal(finalCalls[0].name, 'get_current_time');
  assert.equal(leakedText.includes('先把所有工具都调用一遍'), true);
  assert.equal(leakedText.includes('1. 获取当前时间'), true);
  assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
});

test('sieve keeps numbered planning prose when no tool payload follows', () => {
  const events = runSieve(
    ['好的,我会依次测试每个工具。\n\n1. 获取当前时间'],
    ['get_current_time'],
  );
  const leakedText = collectText(events);
  const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && evt.calls?.length > 0);
  assert.equal(hasToolCall, false);
  assert.equal(leakedText, '好的,我会依次测试每个工具。\n\n1. 获取当前时间');
});

test('sieve emits unknown tool payload (no args) as executable tool call', () => {
  const events = runSieve(
    ['{"tool_calls":[{"name":"not_in_schema"}]}', '后置正文G。'],

28	tests/raw_stream_samples/README.md	Normal file
@@ -0,0 +1,28 @@
# Raw stream sample directory

This directory stores samples of **real upstream raw SSE streams**, used for local simulation tests and parser adaptation.

## Directory layout

One subdirectory per sample:

- `meta.json`: sample metadata (question, model, capture time, notes)
- `upstream.stream.sse`: the complete raw SSE text (`event:` / `data:` lines)

## Adding samples

1. Capture one real request (enabling `DS2API_DEV_PACKET_CAPTURE=1` is recommended).
2. Create a `<sample-id>/` directory and place `meta.json` + `upstream.stream.sse` in it.
3. Run the standalone simulation tool (it can also be invoked from other test scripts):

```bash
./tests/scripts/run-raw-stream-sim.sh
```

The tool walks every sample in this directory, replays each one in real stream order, and verifies that:

- upstream `status=FINISHED` fragments are never emitted as body text (leak prevention);
- the `response/status=FINISHED` end-of-stream signal is detected correctly;
- an archivable JSON report is generated under `artifacts/raw-stream-sim/`.

> Note: samples may contain search-result text and citation data; never include sensitive accounts or keys.

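For orientation, `upstream.stream.sse` is plain SSE text. A hypothetical excerpt is shown below; the `p`/`v` field names and all values are illustrative assumptions, not captured data:

```
event: ready
data: {}

data: {"p": "response/fragments/-1/content", "v": "…"}

data: {"p": "response/fragments/-1/status", "v": "FINISHED"}

event: finish
data: {}
```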
@@ -0,0 +1,55 @@
# Sample analysis (Guangzhou weather / deepseek-reasoner-search)

- Sample source: raw upstream SSE capture from `/admin/dev/captures`
- Capture time (UTC): 2026-04-03 01:28:50
- Raw bytes: 41043
- Occurrences of the string `FINISHED`: 24
- JSON `data:` chunks: 420

## Event distribution

- `ready`: 1
- `update_session`: 2
- `finish`: 1

## Most frequent paths (top 12)

- `response/fragments/-1/content`: 13
- `response/fragments/-1`: 9
- `response`: 5
- `response/has_pending_fragment`: 4
- `response/fragments/-1/elapsed_secs`: 3
- `response/fragments/-5/status`: 2
- `response/fragments/-6/status`: 2
- `response/fragments/-3/status`: 2
- `response/fragments/-1/status`: 2
- `response/fragments/-4/status`: 2
- `response/fragments/-2/status`: 2
- `response/fragments/-5/results`: 1

## Key leak sources

The following status paths carry `v=FINISHED` at high frequency; if the parser passed them through as ordinary text, the output would leak `FINISHEDFINISHED...`:

- `response/fragments/-5/status`: 2
- `response/fragments/-6/status`: 2
- `response/fragments/-3/status`: 2
- `response/fragments/-1/status`: 2
- `response/fragments/-4/status`: 2
- `response/fragments/-2/status`: 2
- `response/fragments/-14/status`: 1
- `response/fragments/-12/status`: 1
- `response/fragments/-10/status`: 1
- `response/fragments/-9/status`: 1
- `response/fragments/-8/status`: 1
- `response/fragments/-7/status`: 1
- `response/fragments/-11/status`: 1
- `response/fragments/-16/status`: 1
- `response/fragments/-13/status`: 1
- `response/fragments/-15/status`: 1

## Adaptation recommendations

1. Skip `response/fragments/<index>/status` for every index, not just `-1/-2/-3`.
2. Keep `response/status=FINISHED` for end-of-stream detection only; it must never be emitted as body text.
3. In the sample simulation tests, assert "must not output `FINISHED`" against every sample.

@@ -0,0 +1,25 @@
{
  "sample_id": "guangzhou-weather-reasoner-search-20260403",
  "captured_at_utc": "2026-04-03T01:28:50Z",
  "request": {
    "model": "deepseek-reasoner-search",
    "stream": true,
    "messages": [
      {
        "role": "user",
        "content": "广州天气"
      }
    ],
    "thinking_enabled": true,
    "search_enabled": true
  },
  "capture": {
    "label": "deepseek_completion",
    "url": "https://chat.deepseek.com/api/v0/chat/completion",
    "status_code": 200,
    "response_bytes": 41043,
    "contains_finished_token": true,
    "finished_token_count": 24
  },
  "notes": "Captured from upstream DeepSeek SSE via /admin/dev/captures with packet capture enabled. Account ID removed."
}

File diff suppressed because one or more lines are too long
98	tests/scripts/capture-raw-stream-sample.sh	Executable file
@@ -0,0 +1,98 @@
#!/usr/bin/env bash
set -euo pipefail

ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
cd "$ROOT_DIR"

CONFIG_PATH="${1:-config.json}"
SAMPLE_ID="${2:-sample-$(date -u +%Y%m%dT%H%M%SZ)}"
QUESTION="${3:-广州天气}"
MODEL="${4:-deepseek-reasoner-search}"
API_KEY="${5:-}"
ADMIN_KEY="${DS2API_ADMIN_KEY:-admin}"

if [[ -z "$API_KEY" ]]; then
  API_KEY="$(python3 - <<'PY' "$CONFIG_PATH"
import json,sys
cfg=json.load(open(sys.argv[1]))
keys=cfg.get('keys') or []
print(keys[0] if keys else '')
PY
)"
fi

if [[ -z "$API_KEY" ]]; then
  echo "[capture] missing API key (pass as arg5 or set config.keys[0])" >&2
  exit 1
fi

OUT_DIR="tests/raw_stream_samples/${SAMPLE_ID}"
mkdir -p "$OUT_DIR"

cleanup() {
  pkill -f "cmd/ds2api" >/dev/null 2>&1 || true
}
trap cleanup EXIT

DS2API_CONFIG_PATH="$CONFIG_PATH" \
DS2API_ADMIN_KEY="$ADMIN_KEY" \
DS2API_DEV_PACKET_CAPTURE=1 \
DS2API_DEV_PACKET_CAPTURE_LIMIT=20 \
go run ./cmd/ds2api >/tmp/ds2api_capture_server.log 2>&1 &

for _ in $(seq 1 120); do
  if curl -sSf http://127.0.0.1:5001/healthz >/dev/null 2>&1; then
    break
  fi
  sleep 1
done

REQUEST_BODY="$(python3 - <<'PY' "$MODEL" "$QUESTION"
import json,sys
model,question=sys.argv[1:3]
payload={
    'model':model,
    'stream':True,
    'messages':[{'role':'user','content':question}],
}
print(json.dumps(payload, ensure_ascii=False))
PY
)"

curl -sS http://127.0.0.1:5001/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -H "Authorization: Bearer ${API_KEY}" \
  --data-binary "${REQUEST_BODY}" \
  >"${OUT_DIR}/openai.stream.sse"

curl -sS http://127.0.0.1:5001/admin/dev/captures \
  -H "Authorization: Bearer ${ADMIN_KEY}" \
  >"${OUT_DIR}/captures.json"

python3 - <<'PY' "$OUT_DIR" "$SAMPLE_ID" "$QUESTION" "$MODEL"
import json,sys,pathlib,datetime
out=pathlib.Path(sys.argv[1])
sample_id,question,model=sys.argv[2:5]
captures=json.loads((out/'captures.json').read_text())
items=captures.get('items') or []
if not items:
    raise SystemExit('no captured upstream stream found')
best=max(items,key=lambda x:len((x.get('response_body') or '')))
raw=best.get('response_body') or ''
(out/'upstream.stream.sse').write_text(raw)
meta={
    'sample_id':sample_id,
    'captured_at_utc':datetime.datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ'),
    'request':{'model':model,'stream':True,'messages':[{'role':'user','content':question}]},
    'capture':{
        'label':best.get('label'),'url':best.get('url'),'status_code':best.get('status_code'),
        'response_bytes':len(raw),'contains_finished_token':('FINISHED' in raw),'finished_token_count':raw.count('FINISHED')
    }
}
(out/'meta.json').write_text(json.dumps(meta,ensure_ascii=False,indent=2))
print(f'[capture] wrote sample to {out}')
print(f'[capture] upstream bytes={len(raw)} finished_count={raw.count("FINISHED")}')
PY

rm -f "${OUT_DIR}/captures.json"
echo "[capture] done: ${OUT_DIR}"

16	tests/scripts/run-raw-stream-sim.sh	Executable file
@@ -0,0 +1,16 @@
#!/usr/bin/env bash
set -euo pipefail

ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
cd "$ROOT_DIR"

REPORT_DIR="artifacts/raw-stream-sim"
mkdir -p "$REPORT_DIR"
REPORT_PATH="$REPORT_DIR/report-$(date -u +%Y%m%dT%H%M%SZ).json"

node tests/tools/deepseek-sse-simulator.mjs \
  --samples-root tests/raw_stream_samples \
  --report "$REPORT_PATH" \
  "$@"

echo "[run-raw-stream-sim] report: $REPORT_PATH"

158	tests/tools/deepseek-sse-simulator.mjs	Executable file
@@ -0,0 +1,158 @@
#!/usr/bin/env node
import fs from 'node:fs';
import path from 'node:path';
import process from 'node:process';
import { createRequire } from 'node:module';

const require = createRequire(import.meta.url);
const chatStream = require('../../api/chat-stream.js');
const { parseChunkForContent } = chatStream.__test;

function parseArgs(argv) {
  const out = {
    samplesRoot: 'tests/raw_stream_samples',
    reportPath: '',
    failOnLeak: true,
    failOnMissingFinish: true,
  };
  for (let i = 2; i < argv.length; i += 1) {
    const a = argv[i];
    if (a === '--samples-root' && argv[i + 1]) {
      out.samplesRoot = argv[++i];
    } else if (a === '--report' && argv[i + 1]) {
      out.reportPath = argv[++i];
    } else if (a === '--no-fail-on-leak') {
      out.failOnLeak = false;
    } else if (a === '--no-fail-on-missing-finish') {
      out.failOnMissingFinish = false;
    }
  }
  return out;
}

function findSampleDirs(root) {
  if (!fs.existsSync(root)) {
    return [];
  }
  return fs.readdirSync(root)
    .map((name) => path.join(root, name))
    .filter((p) => fs.statSync(p).isDirectory())
    .filter((p) => fs.existsSync(path.join(p, 'upstream.stream.sse')))
    .sort();
}

function parseSSE(raw) {
  const events = [];
  for (const block of raw.split(/\r?\n\r?\n/)) {
    if (!block.trim()) {
      continue;
    }
    let eventType = 'message';
    const dataLines = [];
    for (const line of block.split(/\r?\n/)) {
      if (line.startsWith('event:')) {
        eventType = line.slice(6).trim() || 'message';
      } else if (line.startsWith('data:')) {
        dataLines.push(line.slice(5).trimStart());
      }
    }
    if (dataLines.length === 0) {
      continue;
    }
    const payload = dataLines.join('\n').trim();
    events.push({ event: eventType, payload });
  }
  return events;
}

function replaySample(raw) {
  const events = parseSSE(raw);
  let currentType = 'thinking';
  let sawFinish = false;
  let outputText = '';
  let parsedChunks = 0;

  for (const evt of events) {
    if (evt.event === 'finish') {
      sawFinish = true;
    }
    if (!evt.payload || evt.payload === '[DONE]' || evt.payload[0] !== '{') {
      continue;
    }
    let obj;
    try {
      obj = JSON.parse(evt.payload);
    } catch {
      continue;
    }
    parsedChunks += 1;
    const parsed = parseChunkForContent(obj, true, currentType);
    currentType = parsed.newType;
    if (parsed.finished) {
      sawFinish = true;
    }
    for (const part of parsed.parts) {
      outputText += part.text;
    }
  }

  return {
    events: events.length,
    parsedChunks,
    sawFinish,
    leakedFinishedText: outputText.includes('FINISHED'),
    outputChars: outputText.length,
  };
}

function main() {
  const opts = parseArgs(process.argv);
  const dirs = findSampleDirs(opts.samplesRoot);
  if (dirs.length === 0) {
    console.error(`[sim] no samples found: ${opts.samplesRoot}`);
    process.exit(1);
  }

  const report = {
    generated_at: new Date().toISOString(),
    samples_root: opts.samplesRoot,
    total: dirs.length,
    failed: 0,
    samples: [],
  };

  for (const dir of dirs) {
    const sampleID = path.basename(dir);
    const raw = fs.readFileSync(path.join(dir, 'upstream.stream.sse'), 'utf8');
    const r = replaySample(raw);
    const errors = [];
    if (opts.failOnMissingFinish && !r.sawFinish) {
      errors.push('missing finish signal');
    }
    if (opts.failOnLeak && r.leakedFinishedText) {
      errors.push('FINISHED leaked into output text');
    }
    if (errors.length > 0) {
      report.failed += 1;
    }
    report.samples.push({ sample_id: sampleID, ...r, ok: errors.length === 0, errors });
  }

  if (opts.reportPath) {
    fs.writeFileSync(opts.reportPath, JSON.stringify(report, null, 2));
  }

  for (const s of report.samples) {
    const status = s.ok ? 'OK' : 'FAIL';
    const note = s.errors.length > 0 ? ` errors=${s.errors.join(';')}` : '';
    console.log(`[sim] ${status} ${s.sample_id} events=${s.events} parsed=${s.parsedChunks} chars=${s.outputChars}${note}`);
  }

  if (report.failed > 0) {
    console.error(`[sim] ${report.failed}/${report.total} samples failed`);
    process.exit(2);
  }
  console.log(`[sim] all ${report.total} samples passed`);
}

main();

@@ -64,6 +64,27 @@ export default function AccountManagerContainer({ config, onRefresh, onMessage,

  return (
    <div className="space-y-6">
      {Boolean(config?.env_source_present) && (
        <div className={`rounded-xl border px-4 py-3 text-sm ${
          config?.env_writeback_enabled
            ? (config?.env_backed ? 'border-amber-500/30 bg-amber-500/10 text-amber-600' : 'border-emerald-500/30 bg-emerald-500/10 text-emerald-600')
            : 'border-amber-500/30 bg-amber-500/10 text-amber-600'
        }`}>
          <p className="font-medium">
            {config?.env_writeback_enabled
              ? (config?.env_backed
                ? t('accountManager.envModeWritebackPendingTitle')
                : t('accountManager.envModeWritebackActiveTitle'))
              : t('accountManager.envModeRiskTitle')}
          </p>
          <p className="mt-1 text-xs opacity-90">
            {config?.env_writeback_enabled
              ? t('accountManager.envModeWritebackDesc', { path: config?.config_path || 'config.json' })
              : t('accountManager.envModeRiskDesc')}
          </p>
        </div>
      )}

      <QueueCards queueStatus={queueStatus} t={t} />

      <ApiKeysPanel

@@ -1,6 +1,31 @@
import { useState } from 'react'
import { Check, ChevronDown, Copy, Plus, Trash2 } from 'lucide-react'
import clsx from 'clsx'

function fallbackCopyText(text) {
  const textArea = document.createElement('textarea')
  textArea.value = text
  textArea.setAttribute('readonly', '')
  textArea.style.position = 'fixed'
  textArea.style.top = '-9999px'
  textArea.style.left = '-9999px'

  document.body.appendChild(textArea)
  textArea.focus()
  textArea.select()

  let copied = false
  try {
    copied = document.execCommand('copy')
  } finally {
    document.body.removeChild(textArea)
  }

  if (!copied) {
    throw new Error('copy failed')
  }
}

export default function ApiKeysPanel({
  t,
  config,
@@ -11,6 +36,31 @@ export default function ApiKeysPanel({
  setCopiedKey,
  onDeleteKey,
}) {
  const [failedKey, setFailedKey] = useState(null)

  const handleCopyKey = async (key) => {
    try {
      if (navigator.clipboard?.writeText) {
        await navigator.clipboard.writeText(key)
      } else {
        fallbackCopyText(key)
      }
      setCopiedKey(key)
      setFailedKey(null)
      setTimeout(() => setCopiedKey(null), 2000)
    } catch {
      try {
        fallbackCopyText(key)
        setCopiedKey(key)
        setFailedKey(null)
        setTimeout(() => setCopiedKey(null), 2000)
      } catch {
        setFailedKey(key)
        setTimeout(() => setFailedKey(null), 2500)
      }
    }
  }

  return (
    <div className="bg-card border border-border rounded-xl overflow-hidden shadow-sm">
      <div
@@ -42,28 +92,31 @@
        config.keys.map((key, i) => (
          <div key={i} className="p-4 flex items-center justify-between hover:bg-muted/50 transition-colors group">
            <div className="flex items-center gap-2">
              <div className="font-mono text-sm bg-muted/50 px-3 py-1 rounded inline-block">
              <button
                onClick={() => handleCopyKey(key)}
                className="font-mono text-sm bg-muted/50 px-3 py-1 rounded inline-block hover:bg-muted transition-colors"
                title={t('accountManager.copyKeyTitle')}
              >
                {key.slice(0, 16)}****
              </div>
              </button>
              {copiedKey === key && (
                <span className="text-xs text-green-500 animate-pulse">{t('accountManager.copied')}</span>
              )}
              {failedKey === key && (
                <span className="text-xs text-destructive">{t('accountManager.copyFailed')}</span>
              )}
            </div>
            <div className="flex items-center gap-1">
              <button
                onClick={() => {
                  navigator.clipboard.writeText(key)
                  setCopiedKey(key)
                  setTimeout(() => setCopiedKey(null), 2000)
                }}
                className="p-2 text-muted-foreground hover:text-primary hover:bg-primary/10 rounded-md transition-colors opacity-0 group-hover:opacity-100"
                onClick={() => handleCopyKey(key)}
                className="p-2 text-muted-foreground hover:text-primary hover:bg-primary/10 rounded-md transition-colors"
                title={t('accountManager.copyKeyTitle')}
              >
                {copiedKey === key ? <Check className="w-4 h-4 text-green-500" /> : <Copy className="w-4 h-4" />}
              </button>
              <button
                onClick={() => onDeleteKey(key)}
                className="p-2 text-muted-foreground hover:text-destructive hover:bg-destructive/10 rounded-md transition-colors opacity-0 group-hover:opacity-100"
                className="p-2 text-muted-foreground hover:text-destructive hover:bg-destructive/10 rounded-md transition-colors"
                title={t('accountManager.deleteKeyTitle')}
              >
                <Trash2 className="w-4 h-4" />

@@ -105,6 +105,7 @@
    "apiKeysDesc": "Manage the API access key pool",
    "addKey": "Add key",
    "copied": "Copied",
    "copyFailed": "Copy failed",
    "copyKeyTitle": "Copy key",
    "deleteKeyTitle": "Delete key",
    "noApiKeys": "No API keys found.",
@@ -138,7 +139,12 @@
    "sessionCount": "Sessions: {count}",
    "deleteAllSessions": "Delete all sessions",
    "deleteAllSessionsConfirm": "Are you sure you want to delete all sessions for this account? This action cannot be undone.",
    "deleteAllSessionsSuccess": "Successfully deleted all sessions"
    "deleteAllSessionsSuccess": "Successfully deleted all sessions",
    "envModeRiskTitle": "Environment-variable config mode detected (persistence risk)",
    "envModeRiskDesc": "Detected DS2API_CONFIG_JSON/CONFIG_JSON. If DS2API_ENV_WRITEBACK is not enabled, Admin UI edits are in-memory only and may be lost after restart.",
    "envModeWritebackPendingTitle": "Env mode + auto-persistence enabled (pending file handoff)",
    "envModeWritebackActiveTitle": "Env mode + auto-persistence active",
    "envModeWritebackDesc": "The app will auto-create/write the config file and transition to file-backed mode. Current persistence path: {path}"
  },
  "apiTester": {
    "defaultMessage": "Hello, please introduce yourself in one sentence.",

@@ -105,6 +105,7 @@
    "apiKeysDesc": "管理 API 访问密钥池",
    "addKey": "添加密钥",
    "copied": "已复制",
    "copyFailed": "复制失败",
    "copyKeyTitle": "复制密钥",
    "deleteKeyTitle": "删除密钥",
    "noApiKeys": "未找到 API 密钥",
@@ -138,7 +139,12 @@
    "sessionCount": "会话: {count}",
    "deleteAllSessions": "删除所有会话",
    "deleteAllSessionsConfirm": "确定要删除该账号的所有会话吗?此操作不可恢复。",
    "deleteAllSessionsSuccess": "删除成功"
    "deleteAllSessionsSuccess": "删除成功",
    "envModeRiskTitle": "当前为环境变量配置模式(有持久化风险)",
    "envModeRiskDesc": "检测到 DS2API_CONFIG_JSON/CONFIG_JSON。若未开启 DS2API_ENV_WRITEBACK,管理台改动仅在内存生效,重启可能丢失。",
    "envModeWritebackPendingTitle": "环境变量模式 + 自动持久化已开启(等待落盘)",
    "envModeWritebackActiveTitle": "环境变量模式 + 自动持久化已生效",
    "envModeWritebackDesc": "程序会自动创建/写入配置文件并在后续切换为文件模式。当前持久化路径:{path}"
  },
  "apiTester": {
    "defaultMessage": "你好,请用一句话介绍你自己。",

Some files were not shown because too many files have changed in this diff.