mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-07 18:05:30 +08:00
@@ -52,6 +52,9 @@ DS2API_ADMIN_KEY=admin
|
||||
|
||||
# Option C: Base64 encoded JSON (recommended for Vercel env var)
|
||||
# DS2API_CONFIG_JSON=eyJrZXlzIjpbInlvdXItYXBpLWtleSJdLCJhY2NvdW50cyI6W3siZW1haWwiOiJ1c2VyQGV4YW1wbGUuY29tIiwicGFzc3dvcmQiOiJ4eHgiLCJ0b2tlbiI6IiJ9XX0=
|
||||
#
|
||||
# Generate from local config.json:
|
||||
# DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Paths (optional)
|
||||
|
||||
57
.github/workflows/release-artifacts.yml
vendored
57
.github/workflows/release-artifacts.yml
vendored
@@ -73,16 +73,6 @@ jobs:
|
||||
rm -rf "${STAGE}"
|
||||
done
|
||||
|
||||
(cd dist && sha256sum *.tar.gz *.zip > sha256sums.txt)
|
||||
|
||||
- name: Upload Release Assets
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
files: |
|
||||
dist/*.tar.gz
|
||||
dist/*.zip
|
||||
dist/sha256sums.txt
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
@@ -96,11 +86,19 @@ jobs:
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Extract Docker metadata
|
||||
id: meta
|
||||
id: meta_release
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ghcr.io/${{ github.repository }}
|
||||
images: |
|
||||
ghcr.io/${{ github.repository }}
|
||||
cjackhwang/ds2api
|
||||
tags: |
|
||||
type=raw,value=${{ github.event.release.tag_name }}
|
||||
type=raw,value=latest
|
||||
@@ -112,5 +110,36 @@ jobs:
|
||||
file: ./Dockerfile
|
||||
push: true
|
||||
platforms: linux/amd64,linux/arm64
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
tags: ${{ steps.meta_release.outputs.tags }}
|
||||
labels: ${{ steps.meta_release.outputs.labels }}
|
||||
|
||||
- name: Export Docker image archives for release assets
|
||||
run: |
|
||||
set -euo pipefail
|
||||
TAG="${{ github.event.release.tag_name }}"
|
||||
|
||||
docker buildx build \
|
||||
--platform linux/amd64 \
|
||||
--output type=docker,dest="dist/ds2api_${TAG}_docker_linux_amd64.tar" \
|
||||
.
|
||||
|
||||
docker buildx build \
|
||||
--platform linux/arm64 \
|
||||
--output type=docker,dest="dist/ds2api_${TAG}_docker_linux_arm64.tar" \
|
||||
.
|
||||
|
||||
gzip -f "dist/ds2api_${TAG}_docker_linux_amd64.tar"
|
||||
gzip -f "dist/ds2api_${TAG}_docker_linux_arm64.tar"
|
||||
|
||||
- name: Generate checksums
|
||||
run: |
|
||||
set -euo pipefail
|
||||
(cd dist && sha256sum *.tar.gz *.zip > sha256sums.txt)
|
||||
|
||||
- name: Upload Release Assets
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
files: |
|
||||
dist/*.tar.gz
|
||||
dist/*.zip
|
||||
dist/sha256sums.txt
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -81,6 +81,9 @@ ds2api-tests
|
||||
htmlcov/
|
||||
.pytest_cache/
|
||||
.tox/
|
||||
*.coverprofile
|
||||
coverage*.out
|
||||
cover/
|
||||
|
||||
# Misc
|
||||
*.pyc
|
||||
|
||||
157
API.en.md
157
API.en.md
@@ -9,6 +9,7 @@ This document describes the actual behavior of the current Go codebase.
|
||||
## Table of Contents
|
||||
|
||||
- [Basics](#basics)
|
||||
- [Configuration Best Practice](#configuration-best-practice)
|
||||
- [Authentication](#authentication)
|
||||
- [Route Index](#route-index)
|
||||
- [Health Endpoints](#health-endpoints)
|
||||
@@ -27,7 +28,29 @@ This document describes the actual behavior of the current Go codebase.
|
||||
| Base URL | `http://localhost:5001` or your deployment domain |
|
||||
| Default Content-Type | `application/json` |
|
||||
| Health probes | `GET /healthz`, `GET /readyz` |
|
||||
| CORS | Enabled (`Access-Control-Allow-Origin: *`, allows `Content-Type`, `Authorization`) |
|
||||
| CORS | Enabled (`Access-Control-Allow-Origin: *`, allows `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Vercel-Protection-Bypass`) |
|
||||
|
||||
---
|
||||
|
||||
## Configuration Best Practice
|
||||
|
||||
Use `config.json` as the single source of truth:
|
||||
|
||||
```bash
|
||||
cp config.example.json config.json
|
||||
# Edit config.json (keys/accounts)
|
||||
```
|
||||
|
||||
Use it per deployment mode:
|
||||
|
||||
- Local run: read `config.json` directly
|
||||
- Docker / Vercel: generate Base64 from `config.json`, then set `DS2API_CONFIG_JSON`
|
||||
|
||||
```bash
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
```
|
||||
|
||||
For Vercel one-click bootstrap, you can set only `DS2API_ADMIN_KEY` first, then import config at `/admin` and sync env vars from the "Vercel Sync" page.
|
||||
|
||||
---
|
||||
|
||||
@@ -66,7 +89,11 @@ Two header formats accepted:
|
||||
| GET | `/healthz` | None | Liveness probe |
|
||||
| GET | `/readyz` | None | Readiness probe |
|
||||
| GET | `/v1/models` | None | OpenAI model list |
|
||||
| GET | `/v1/models/{id}` | None | OpenAI single-model query (alias accepted) |
|
||||
| POST | `/v1/chat/completions` | Business | OpenAI chat completions |
|
||||
| POST | `/v1/responses` | Business | OpenAI Responses API (stream/non-stream) |
|
||||
| GET | `/v1/responses/{response_id}` | Business | Query stored response (in-memory TTL) |
|
||||
| POST | `/v1/embeddings` | Business | OpenAI Embeddings API |
|
||||
| GET | `/anthropic/v1/models` | None | Claude model list |
|
||||
| POST | `/anthropic/v1/messages` | Business | Claude messages |
|
||||
| POST | `/anthropic/v1/messages/count_tokens` | Business | Claude token counting |
|
||||
@@ -127,6 +154,15 @@ No auth required. Returns supported models.
|
||||
}
|
||||
```
|
||||
|
||||
### Model Alias Resolution
|
||||
|
||||
For `chat` / `responses` / `embeddings`, DS2API follows a wide-input/strict-output policy:
|
||||
|
||||
1. Match DeepSeek native model IDs first.
|
||||
2. Then match exact keys in `model_aliases`.
|
||||
3. If still unmatched, fall back by known family heuristics (`o*`, `gpt-*`, `claude-*`, etc.).
|
||||
4. If still unmatched, return `invalid_request_error`.
|
||||
|
||||
### `POST /v1/chat/completions`
|
||||
|
||||
**Headers**:
|
||||
@@ -140,7 +176,7 @@ Content-Type: application/json
|
||||
|
||||
| Field | Type | Required | Notes |
|
||||
| --- | --- | --- | --- |
|
||||
| `model` | string | ✅ | `deepseek-chat` / `deepseek-reasoner` / `deepseek-chat-search` / `deepseek-reasoner-search` |
|
||||
| `model` | string | ✅ | DeepSeek native models + common aliases (`gpt-4o`, `gpt-5-codex`, `o3`, `claude-sonnet-4-5`, etc.) |
|
||||
| `messages` | array | ✅ | OpenAI-style messages |
|
||||
| `stream` | boolean | ❌ | Default `false` |
|
||||
| `tools` | array | ❌ | Function calling schema |
|
||||
@@ -230,7 +266,63 @@ When `tools` is present, DS2API performs anti-leak handling:
|
||||
}
|
||||
```
|
||||
|
||||
**Stream**: DS2API buffers text first. If tool call detected → only structured `delta.tool_calls` (each with `index`); otherwise emits buffered text at once.
|
||||
**Stream**: Once high-confidence toolcall features are matched, DS2API emits `delta.tool_calls` immediately (without waiting for full JSON closure), then keeps sending argument deltas; confirmed raw tool JSON is never forwarded as `delta.content`.
|
||||
|
||||
---
|
||||
|
||||
### `GET /v1/models/{id}`
|
||||
|
||||
No auth required. Alias values are accepted as path params (for example `gpt-4o`), and the returned object is the mapped DeepSeek model.
|
||||
|
||||
### `POST /v1/responses`
|
||||
|
||||
OpenAI Responses-style endpoint, accepting either `input` or `messages`.
|
||||
|
||||
| Field | Type | Required | Notes |
|
||||
| --- | --- | --- | --- |
|
||||
| `model` | string | ✅ | Supports native models + alias mapping |
|
||||
| `input` | string/array/object | ❌ | One of `input` or `messages` is required |
|
||||
| `messages` | array | ❌ | One of `input` or `messages` is required |
|
||||
| `instructions` | string | ❌ | Prepended as a system message |
|
||||
| `stream` | boolean | ❌ | Default `false` |
|
||||
| `tools` | array | ❌ | Same tool detection/translation policy as chat |
|
||||
|
||||
**Non-stream**: Returns a standard `response` object with an ID like `resp_xxx`, and stores it in in-memory TTL cache.
|
||||
|
||||
**Stream (SSE)**: minimal event sequence:
|
||||
|
||||
```text
|
||||
event: response.created
|
||||
data: {"type":"response.created","id":"resp_xxx","status":"in_progress",...}
|
||||
|
||||
event: response.output_text.delta
|
||||
data: {"type":"response.output_text.delta","id":"resp_xxx","delta":"..."}
|
||||
|
||||
event: response.output_tool_call.delta
|
||||
data: {"type":"response.output_tool_call.delta","id":"resp_xxx","tool_calls":[...]}
|
||||
|
||||
event: response.completed
|
||||
data: {"type":"response.completed","response":{...}}
|
||||
|
||||
data: [DONE]
|
||||
```
|
||||
|
||||
### `GET /v1/responses/{response_id}`
|
||||
|
||||
Business auth required. Fetches cached responses created by `POST /v1/responses` (caller-scoped; only the same key/token can read).
|
||||
|
||||
> Backed by in-memory TTL store. Default TTL is `900s` (configurable via `responses.store_ttl_seconds`).
|
||||
|
||||
### `POST /v1/embeddings`
|
||||
|
||||
Business auth required. Returns OpenAI-compatible embeddings shape.
|
||||
|
||||
| Field | Type | Required | Notes |
|
||||
| --- | --- | --- | --- |
|
||||
| `model` | string | ✅ | Supports native models + alias mapping |
|
||||
| `input` | string/array | ✅ | Supports string, string array, token array |
|
||||
|
||||
> Requires `embeddings.provider`. Current supported values: `mock` / `deterministic` / `builtin`. If missing/unsupported, returns standard error shape with HTTP 501.
|
||||
|
||||
---
|
||||
|
||||
@@ -249,7 +341,10 @@ No auth required.
|
||||
{"id": "claude-sonnet-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"},
|
||||
{"id": "claude-haiku-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"},
|
||||
{"id": "claude-opus-4-6", "object": "model", "created": 1715635200, "owned_by": "anthropic"}
|
||||
]
|
||||
],
|
||||
"first_id": "claude-opus-4-6",
|
||||
"last_id": "claude-instant-1.0",
|
||||
"has_more": false
|
||||
}
|
||||
```
|
||||
|
||||
@@ -265,13 +360,15 @@ Content-Type: application/json
|
||||
anthropic-version: 2023-06-01
|
||||
```
|
||||
|
||||
> `anthropic-version` is optional; DS2API auto-fills `2023-06-01` when absent.
|
||||
|
||||
**Request body**:
|
||||
|
||||
| Field | Type | Required | Notes |
|
||||
| --- | --- | --- | --- |
|
||||
| `model` | string | ✅ | For example `claude-sonnet-4-5` / `claude-opus-4-6` / `claude-haiku-4-5` (compatible with `claude-3-5-haiku-latest`), plus historical Claude model IDs |
|
||||
| `messages` | array | ✅ | Claude-style messages |
|
||||
| `max_tokens` | number | ❌ | Not strictly enforced by upstream bridge |
|
||||
| `max_tokens` | number | ❌ | Auto-filled to `8192` when omitted; not strictly enforced by upstream bridge |
|
||||
| `stream` | boolean | ❌ | Default `false` |
|
||||
| `system` | string | ❌ | Optional system prompt |
|
||||
| `tools` | array | ❌ | Claude tool schema |
|
||||
@@ -416,6 +513,7 @@ Returns sanitized config.
|
||||
"keys": ["k1", "k2"],
|
||||
"accounts": [
|
||||
{
|
||||
"identifier": "user@example.com",
|
||||
"email": "user@example.com",
|
||||
"mobile": "",
|
||||
"has_password": true,
|
||||
@@ -476,6 +574,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`.
|
||||
{
|
||||
"items": [
|
||||
{
|
||||
"identifier": "user@example.com",
|
||||
"email": "user@example.com",
|
||||
"mobile": "",
|
||||
"has_password": true,
|
||||
@@ -500,7 +599,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`.
|
||||
|
||||
### `DELETE /admin/accounts/{identifier}`
|
||||
|
||||
`identifier` is email or mobile.
|
||||
`identifier` can be email, mobile, or the synthetic id for token-only accounts (`token:<hash>`).
|
||||
|
||||
**Response**: `{"success": true, "total_accounts": 5}`
|
||||
|
||||
@@ -530,7 +629,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`.
|
||||
|
||||
| Field | Required | Notes |
|
||||
| --- | --- | --- |
|
||||
| `identifier` | ✅ | email or mobile |
|
||||
| `identifier` | ✅ | email / mobile / token-only synthetic id |
|
||||
| `model` | ❌ | default `deepseek-chat` |
|
||||
| `message` | ❌ | if empty, only session creation is tested |
|
||||
|
||||
@@ -659,13 +758,20 @@ Or manual deploy required:
|
||||
|
||||
## Error Payloads
|
||||
|
||||
Error formats vary by module:
|
||||
Compatible routes (`/v1/*`, `/anthropic/*`) use the same error envelope:
|
||||
|
||||
| Module | Format |
|
||||
| --- | --- |
|
||||
| OpenAI routes | `{"error": {"message": "...", "type": "..."}}` |
|
||||
| Claude routes | `{"error": {"type": "...", "message": "..."}}` |
|
||||
| Admin routes | `{"detail": "..."}` |
|
||||
```json
|
||||
{
|
||||
"error": {
|
||||
"message": "...",
|
||||
"type": "invalid_request_error",
|
||||
"code": "invalid_request",
|
||||
"param": null
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Admin routes keep `{"detail":"..."}`.
|
||||
|
||||
Clients should handle HTTP status code plus `error` / `detail` fields.
|
||||
|
||||
@@ -707,6 +813,31 @@ curl http://localhost:5001/v1/chat/completions \
|
||||
}'
|
||||
```
|
||||
|
||||
### OpenAI Responses (Stream)
|
||||
|
||||
```bash
|
||||
curl http://localhost:5001/v1/responses \
|
||||
-H "Authorization: Bearer your-api-key" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-5-codex",
|
||||
"input": "Write a hello world in golang",
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
### OpenAI Embeddings
|
||||
|
||||
```bash
|
||||
curl http://localhost:5001/v1/embeddings \
|
||||
-H "Authorization: Bearer your-api-key" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-4o",
|
||||
"input": ["first text", "second text"]
|
||||
}'
|
||||
```
|
||||
|
||||
### OpenAI with Search
|
||||
|
||||
```bash
|
||||
|
||||
157
API.md
157
API.md
@@ -9,6 +9,7 @@
|
||||
## 目录
|
||||
|
||||
- [基础信息](#基础信息)
|
||||
- [配置最佳实践](#配置最佳实践)
|
||||
- [鉴权规则](#鉴权规则)
|
||||
- [路由总览](#路由总览)
|
||||
- [健康检查](#健康检查)
|
||||
@@ -27,7 +28,29 @@
|
||||
| Base URL | `http://localhost:5001` 或你的部署域名 |
|
||||
| 默认 Content-Type | `application/json` |
|
||||
| 健康检查 | `GET /healthz`、`GET /readyz` |
|
||||
| CORS | 已启用(`Access-Control-Allow-Origin: *`,允许 `Content-Type`, `Authorization`) |
|
||||
| CORS | 已启用(`Access-Control-Allow-Origin: *`,允许 `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Vercel-Protection-Bypass`) |
|
||||
|
||||
---
|
||||
|
||||
## 配置最佳实践
|
||||
|
||||
推荐把 `config.json` 作为唯一配置源:
|
||||
|
||||
```bash
|
||||
cp config.example.json config.json
|
||||
# 编辑 config.json(keys/accounts)
|
||||
```
|
||||
|
||||
按部署方式使用:
|
||||
|
||||
- 本地运行:直接读取 `config.json`
|
||||
- Docker / Vercel:从 `config.json` 生成 Base64,填入 `DS2API_CONFIG_JSON`
|
||||
|
||||
```bash
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
```
|
||||
|
||||
Vercel 一键部署可先只填 `DS2API_ADMIN_KEY`,部署后在 `/admin` 导入配置,再通过 “Vercel 同步” 写回环境变量。
|
||||
|
||||
---
|
||||
|
||||
@@ -66,7 +89,11 @@
|
||||
| GET | `/healthz` | 无 | 存活探针 |
|
||||
| GET | `/readyz` | 无 | 就绪探针 |
|
||||
| GET | `/v1/models` | 无 | OpenAI 模型列表 |
|
||||
| GET | `/v1/models/{id}` | 无 | OpenAI 单模型查询(支持 alias 入参) |
|
||||
| POST | `/v1/chat/completions` | 业务 | OpenAI 对话补全 |
|
||||
| POST | `/v1/responses` | 业务 | OpenAI Responses 接口(流式/非流式) |
|
||||
| GET | `/v1/responses/{response_id}` | 业务 | 查询已生成 response(内存 TTL) |
|
||||
| POST | `/v1/embeddings` | 业务 | OpenAI Embeddings 接口 |
|
||||
| GET | `/anthropic/v1/models` | 无 | Claude 模型列表 |
|
||||
| POST | `/anthropic/v1/messages` | 业务 | Claude 消息接口 |
|
||||
| POST | `/anthropic/v1/messages/count_tokens` | 业务 | Claude token 计数 |
|
||||
@@ -127,6 +154,15 @@
|
||||
}
|
||||
```
|
||||
|
||||
### 模型 alias 解析策略
|
||||
|
||||
对 `chat` / `responses` / `embeddings` 的 `model` 字段采用“宽进严出”:
|
||||
|
||||
1. 先匹配 DeepSeek 原生模型。
|
||||
2. 再匹配 `model_aliases` 精确映射。
|
||||
3. 未命中时按模型家族规则回退(如 `o*`、`gpt-*`、`claude-*`)。
|
||||
4. 仍未命中则返回 `invalid_request_error`。
|
||||
|
||||
### `POST /v1/chat/completions`
|
||||
|
||||
**请求头**:
|
||||
@@ -140,7 +176,7 @@ Content-Type: application/json
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
| --- | --- | --- | --- |
|
||||
| `model` | string | ✅ | `deepseek-chat` / `deepseek-reasoner` / `deepseek-chat-search` / `deepseek-reasoner-search` |
|
||||
| `model` | string | ✅ | 支持 DeepSeek 原生模型 + 常见 alias(如 `gpt-4o`、`gpt-5-codex`、`o3`、`claude-sonnet-4-5`) |
|
||||
| `messages` | array | ✅ | OpenAI 风格消息数组 |
|
||||
| `stream` | boolean | ❌ | 默认 `false` |
|
||||
| `tools` | array | ❌ | Function Calling 定义 |
|
||||
@@ -230,7 +266,63 @@ data: [DONE]
|
||||
}
|
||||
```
|
||||
|
||||
**流式**:先缓冲正文片段。识别到工具调用 → 仅输出结构化 `delta.tool_calls`(每个 tool call 带 `index`);否则一次性输出普通文本。
|
||||
**流式**:命中高置信特征后立即输出 `delta.tool_calls`(不等待完整 JSON 闭合),并持续发送 arguments 增量;已确认的 toolcall 原始 JSON 不会回流到 `delta.content`。
|
||||
|
||||
---
|
||||
|
||||
### `GET /v1/models/{id}`
|
||||
|
||||
无需鉴权。入参支持 alias(例如 `gpt-4o`),返回的是映射后的 DeepSeek 模型对象。
|
||||
|
||||
### `POST /v1/responses`
|
||||
|
||||
OpenAI Responses 风格接口,兼容 `input` 或 `messages`。
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
| --- | --- | --- | --- |
|
||||
| `model` | string | ✅ | 支持原生模型 + alias 自动映射 |
|
||||
| `input` | string/array/object | ❌ | 与 `messages` 二选一 |
|
||||
| `messages` | array | ❌ | 与 `input` 二选一 |
|
||||
| `instructions` | string | ❌ | 自动前置为 system 消息 |
|
||||
| `stream` | boolean | ❌ | 默认 `false` |
|
||||
| `tools` | array | ❌ | 与 chat 同样的工具识别与转译策略 |
|
||||
|
||||
**非流式响应**:返回标准 `response` 对象,`id` 形如 `resp_xxx`,并写入内存 TTL 存储。
|
||||
|
||||
**流式响应(SSE)**:最小事件序列如下。
|
||||
|
||||
```text
|
||||
event: response.created
|
||||
data: {"type":"response.created","id":"resp_xxx","status":"in_progress",...}
|
||||
|
||||
event: response.output_text.delta
|
||||
data: {"type":"response.output_text.delta","id":"resp_xxx","delta":"..."}
|
||||
|
||||
event: response.output_tool_call.delta
|
||||
data: {"type":"response.output_tool_call.delta","id":"resp_xxx","tool_calls":[...]}
|
||||
|
||||
event: response.completed
|
||||
data: {"type":"response.completed","response":{...}}
|
||||
|
||||
data: [DONE]
|
||||
```
|
||||
|
||||
### `GET /v1/responses/{response_id}`
|
||||
|
||||
需要业务鉴权。查询 `POST /v1/responses` 生成并缓存的 response 对象(按调用方鉴权隔离,仅同一 key/token 可读取)。
|
||||
|
||||
> 当前为内存 TTL 存储,默认过期时间 `900s`(可用 `responses.store_ttl_seconds` 调整)。
|
||||
|
||||
### `POST /v1/embeddings`
|
||||
|
||||
需要业务鉴权。返回 OpenAI Embeddings 兼容结构。
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
| --- | --- | --- | --- |
|
||||
| `model` | string | ✅ | 支持原生模型 + alias 自动映射 |
|
||||
| `input` | string/array | ✅ | 支持字符串、字符串数组、token 数组 |
|
||||
|
||||
> 需配置 `embeddings.provider`。当前支持:`mock` / `deterministic` / `builtin`。未配置或不支持时返回标准错误结构(HTTP 501)。
|
||||
|
||||
---
|
||||
|
||||
@@ -249,7 +341,10 @@ data: [DONE]
|
||||
{"id": "claude-sonnet-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"},
|
||||
{"id": "claude-haiku-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"},
|
||||
{"id": "claude-opus-4-6", "object": "model", "created": 1715635200, "owned_by": "anthropic"}
|
||||
]
|
||||
],
|
||||
"first_id": "claude-opus-4-6",
|
||||
"last_id": "claude-instant-1.0",
|
||||
"has_more": false
|
||||
}
|
||||
```
|
||||
|
||||
@@ -265,13 +360,15 @@ Content-Type: application/json
|
||||
anthropic-version: 2023-06-01
|
||||
```
|
||||
|
||||
> `anthropic-version` 可省略,服务端会自动补为 `2023-06-01`。
|
||||
|
||||
**请求体**:
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
| --- | --- | --- | --- |
|
||||
| `model` | string | ✅ | 例如 `claude-sonnet-4-5` / `claude-opus-4-6` / `claude-haiku-4-5`(兼容 `claude-3-5-haiku-latest`),并支持历史 Claude 模型 ID |
|
||||
| `messages` | array | ✅ | Claude 风格消息数组 |
|
||||
| `max_tokens` | number | ❌ | 当前实现不会硬性截断上游输出 |
|
||||
| `max_tokens` | number | ❌ | 缺省自动补 `8192`;当前实现不会硬性截断上游输出 |
|
||||
| `stream` | boolean | ❌ | 默认 `false` |
|
||||
| `system` | string | ❌ | 可选系统提示 |
|
||||
| `tools` | array | ❌ | Claude tool 定义 |
|
||||
@@ -416,6 +513,7 @@ data: {"type":"message_stop"}
|
||||
"keys": ["k1", "k2"],
|
||||
"accounts": [
|
||||
{
|
||||
"identifier": "user@example.com",
|
||||
"email": "user@example.com",
|
||||
"mobile": "",
|
||||
"has_password": true,
|
||||
@@ -476,6 +574,7 @@ data: {"type":"message_stop"}
|
||||
{
|
||||
"items": [
|
||||
{
|
||||
"identifier": "user@example.com",
|
||||
"email": "user@example.com",
|
||||
"mobile": "",
|
||||
"has_password": true,
|
||||
@@ -500,7 +599,7 @@ data: {"type":"message_stop"}
|
||||
|
||||
### `DELETE /admin/accounts/{identifier}`
|
||||
|
||||
`identifier` 为 email 或 mobile。
|
||||
`identifier` 可为 email、mobile,或 token-only 账号的合成标识(`token:<hash>`)。
|
||||
|
||||
**响应**:`{"success": true, "total_accounts": 5}`
|
||||
|
||||
@@ -530,7 +629,7 @@ data: {"type":"message_stop"}
|
||||
|
||||
| 字段 | 必填 | 说明 |
|
||||
| --- | --- | --- |
|
||||
| `identifier` | ✅ | email 或 mobile |
|
||||
| `identifier` | ✅ | email / mobile / token-only 合成标识 |
|
||||
| `model` | ❌ | 默认 `deepseek-chat` |
|
||||
| `message` | ❌ | 空字符串时仅测试会话创建 |
|
||||
|
||||
@@ -659,13 +758,20 @@ data: {"type":"message_stop"}
|
||||
|
||||
## 错误响应格式
|
||||
|
||||
不同模块的错误格式略有差异:
|
||||
兼容路由(`/v1/*`、`/anthropic/*`)统一使用以下结构:
|
||||
|
||||
| 模块 | 格式 |
|
||||
| --- | --- |
|
||||
| OpenAI 接口 | `{"error": {"message": "...", "type": "..."}}` |
|
||||
| Claude 接口 | `{"error": {"type": "...", "message": "..."}}` |
|
||||
| Admin 接口 | `{"detail": "..."}` |
|
||||
```json
|
||||
{
|
||||
"error": {
|
||||
"message": "...",
|
||||
"type": "invalid_request_error",
|
||||
"code": "invalid_request",
|
||||
"param": null
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Admin 接口保持 `{"detail":"..."}`。
|
||||
|
||||
建议客户端处理逻辑:检查 HTTP 状态码 + 解析 `error` 或 `detail` 字段。
|
||||
|
||||
@@ -707,6 +813,31 @@ curl http://localhost:5001/v1/chat/completions \
|
||||
}'
|
||||
```
|
||||
|
||||
### OpenAI Responses(流式)
|
||||
|
||||
```bash
|
||||
curl http://localhost:5001/v1/responses \
|
||||
-H "Authorization: Bearer your-api-key" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-5-codex",
|
||||
"input": "写一个 golang 的 hello world",
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
### OpenAI Embeddings
|
||||
|
||||
```bash
|
||||
curl http://localhost:5001/v1/embeddings \
|
||||
-H "Authorization: Bearer your-api-key" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-4o",
|
||||
"input": ["第一段文本", "第二段文本"]
|
||||
}'
|
||||
```
|
||||
|
||||
### OpenAI 带搜索
|
||||
|
||||
```bash
|
||||
|
||||
59
DEPLOY.en.md
59
DEPLOY.en.md
@@ -33,6 +33,17 @@ Config source (choose one):
|
||||
- **File**: `config.json` (recommended for local/Docker)
|
||||
- **Environment variable**: `DS2API_CONFIG_JSON` (recommended for Vercel; supports raw JSON or Base64)
|
||||
|
||||
Unified recommendation (best practice):
|
||||
|
||||
```bash
|
||||
cp config.example.json config.json
|
||||
# Edit config.json
|
||||
```
|
||||
|
||||
Use `config.json` as the single source of truth:
|
||||
- Local run: read `config.json` directly
|
||||
- Docker / Vercel: generate `DS2API_CONFIG_JSON` (Base64) from `config.json` and inject it
|
||||
|
||||
---
|
||||
|
||||
## 1. Local Run
|
||||
@@ -99,11 +110,15 @@ go build -o ds2api ./cmd/ds2api
|
||||
### 2.1 Basic Steps
|
||||
|
||||
```bash
|
||||
# Copy and edit environment
|
||||
# Copy env template
|
||||
cp .env.example .env
|
||||
# Edit .env, at minimum set:
|
||||
|
||||
# Generate single-line Base64 from config.json
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
|
||||
# Edit .env and set:
|
||||
# DS2API_ADMIN_KEY=your-admin-key
|
||||
# DS2API_CONFIG_JSON={"keys":[...],"accounts":[...]}
|
||||
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
|
||||
|
||||
# Start
|
||||
docker-compose up -d
|
||||
@@ -167,15 +182,49 @@ If container logs look normal but the admin panel is unreachable, check these fi
|
||||
|
||||
1. **Fork** the repo to your GitHub account
|
||||
2. **Import** the project on Vercel
|
||||
3. **Set environment variables** (at minimum):
|
||||
3. **Set environment variables** (minimum required: one variable):
|
||||
|
||||
| Variable | Description |
|
||||
| --- | --- |
|
||||
| `DS2API_ADMIN_KEY` | Admin key (required) |
|
||||
| `DS2API_CONFIG_JSON` | Config content, raw JSON or Base64 (required) |
|
||||
| `DS2API_CONFIG_JSON` | Config content, raw JSON or Base64 (optional, recommended) |
|
||||
|
||||
4. **Deploy**
|
||||
|
||||
### 3.1.1 Recommended Input (avoid `DS2API_CONFIG_JSON` mistakes)
|
||||
|
||||
If you prefer faster one-click bootstrap, you can leave `DS2API_CONFIG_JSON` empty first, then open `/admin` after deployment, import config, and sync it back to Vercel env vars from the "Vercel Sync" page.
|
||||
|
||||
Recommended: in repo root, copy the template first and fill your real accounts:
|
||||
|
||||
```bash
|
||||
cp config.example.json config.json
|
||||
# Edit config.json
|
||||
```
|
||||
|
||||
Do not hand-edit large JSON directly in Vercel. Generate Base64 locally and paste it:
|
||||
|
||||
```bash
|
||||
# Run in repo root
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
echo "$DS2API_CONFIG_JSON"
|
||||
```
|
||||
|
||||
If you choose to preconfigure before first deploy, set these vars in Vercel Project Settings -> Environment Variables:
|
||||
|
||||
```text
|
||||
DS2API_ADMIN_KEY=replace-with-a-strong-secret
|
||||
DS2API_CONFIG_JSON=<the single-line Base64 output above>
|
||||
```
|
||||
|
||||
Optional but recommended (for WebUI one-click Vercel sync):
|
||||
|
||||
```text
|
||||
VERCEL_TOKEN=your-vercel-token
|
||||
VERCEL_PROJECT_ID=prj_xxxxxxxxxxxx
|
||||
VERCEL_TEAM_ID=team_xxxxxxxxxxxx # optional for personal accounts
|
||||
```
|
||||
|
||||
### 3.2 Optional Environment Variables
|
||||
|
||||
| Variable | Description | Default |
|
||||
|
||||
59
DEPLOY.md
59
DEPLOY.md
@@ -33,6 +33,17 @@
|
||||
- **文件方式**:`config.json`(推荐本地/Docker 使用)
|
||||
- **环境变量方式**:`DS2API_CONFIG_JSON`(推荐 Vercel 使用,支持 JSON 字符串或 Base64 编码)
|
||||
|
||||
统一建议(最优实践):
|
||||
|
||||
```bash
|
||||
cp config.example.json config.json
|
||||
# 编辑 config.json
|
||||
```
|
||||
|
||||
建议把 `config.json` 作为唯一配置源:
|
||||
- 本地运行:直接读 `config.json`
|
||||
- Docker / Vercel:从 `config.json` 生成 `DS2API_CONFIG_JSON`(Base64)注入环境变量
|
||||
|
||||
---
|
||||
|
||||
## 一、本地运行
|
||||
@@ -99,11 +110,15 @@ go build -o ds2api ./cmd/ds2api
|
||||
### 2.1 基本步骤
|
||||
|
||||
```bash
|
||||
# 复制并编辑环境变量
|
||||
# 复制环境变量模板
|
||||
cp .env.example .env
|
||||
# 编辑 .env,至少设置:
|
||||
|
||||
# 从 config.json 生成单行 Base64
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
|
||||
# 编辑 .env(请改成你的强密码),设置:
|
||||
# DS2API_ADMIN_KEY=your-admin-key
|
||||
# DS2API_CONFIG_JSON={"keys":[...],"accounts":[...]}
|
||||
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
|
||||
|
||||
# 启动
|
||||
docker-compose up -d
|
||||
@@ -167,15 +182,49 @@ healthcheck:
|
||||
|
||||
1. **Fork 仓库**到你的 GitHub 账号
|
||||
2. **在 Vercel 上导入项目**
|
||||
3. **配置环境变量**(至少设置以下两项):
|
||||
3. **配置环境变量**(最少只需设置以下一项):
|
||||
|
||||
| 变量 | 说明 |
|
||||
| --- | --- |
|
||||
| `DS2API_ADMIN_KEY` | 管理密钥(必填) |
|
||||
| `DS2API_CONFIG_JSON` | 配置内容,JSON 字符串或 Base64 编码(必填) |
|
||||
| `DS2API_CONFIG_JSON` | 配置内容,JSON 字符串或 Base64 编码(可选,建议) |
|
||||
|
||||
4. **部署**
|
||||
|
||||
### 3.1.1 推荐填写方式(避免 `DS2API_CONFIG_JSON` 填错)
|
||||
|
||||
如果你想先完成一键部署,也可以先不填 `DS2API_CONFIG_JSON`,部署后进入 `/admin` 导入配置,再在「Vercel 同步」里写回环境变量。
|
||||
|
||||
建议先在仓库目录复制示例配置,再按实际账号填写:
|
||||
|
||||
```bash
|
||||
cp config.example.json config.json
|
||||
# 编辑 config.json
|
||||
```
|
||||
|
||||
不要在 Vercel 面板里手写复杂 JSON,建议本地生成 Base64 后粘贴:
|
||||
|
||||
```bash
|
||||
# 在仓库根目录执行
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
echo "$DS2API_CONFIG_JSON"
|
||||
```
|
||||
|
||||
如果你选择在部署前就预置配置,请在 Vercel Project Settings -> Environment Variables 配置:
|
||||
|
||||
```text
|
||||
DS2API_ADMIN_KEY=请替换为强密码
|
||||
DS2API_CONFIG_JSON=上一步生成的一整行 Base64
|
||||
```
|
||||
|
||||
可选但推荐(用于 WebUI 一键同步 Vercel 配置):
|
||||
|
||||
```text
|
||||
VERCEL_TOKEN=你的 Vercel Token
|
||||
VERCEL_PROJECT_ID=prj_xxxxxxxxxxxx
|
||||
VERCEL_TEAM_ID=team_xxxxxxxxxxxx # 个人账号可留空
|
||||
```
|
||||
|
||||
### 3.2 可选环境变量
|
||||
|
||||
| 变量 | 说明 | 默认值 |
|
||||
|
||||
89
README.MD
89
README.MD
@@ -54,16 +54,27 @@ flowchart LR
|
||||
|
||||
| 能力 | 说明 |
|
||||
| --- | --- |
|
||||
| OpenAI 兼容 | `GET /v1/models`、`POST /v1/chat/completions`(流式/非流式) |
|
||||
| OpenAI 兼容 | `GET /v1/models`、`GET /v1/models/{id}`、`POST /v1/chat/completions`、`POST /v1/responses`、`GET /v1/responses/{response_id}`、`POST /v1/embeddings` |
|
||||
| Claude 兼容 | `GET /anthropic/v1/models`、`POST /anthropic/v1/messages`、`POST /anthropic/v1/messages/count_tokens` |
|
||||
| 多账号轮询 | 自动 token 刷新、邮箱/手机号双登录方式 |
|
||||
| 并发队列控制 | 每账号 in-flight 上限 + 等待队列,动态计算建议并发值 |
|
||||
| DeepSeek PoW | WASM 计算(`wazero`),无需外部 Node.js 依赖 |
|
||||
| Tool Calling | 防泄漏处理:自动缓冲、识别、结构化输出 |
|
||||
| Tool Calling | 防泄漏处理:非代码块高置信特征识别、`delta.tool_calls` 早发、结构化增量输出 |
|
||||
| Admin API | 配置管理、账号测试 / 批量测试、导入导出、Vercel 同步 |
|
||||
| WebUI 管理台 | `/admin` 单页应用(中英文双语、深色模式) |
|
||||
| 运维探针 | `GET /healthz`(存活)、`GET /readyz`(就绪) |
|
||||
|
||||
## 平台兼容矩阵
|
||||
|
||||
| 级别 | 平台 | 当前状态 |
|
||||
| --- | --- | --- |
|
||||
| P0 | Codex CLI/SDK(`wire_api=chat` / `wire_api=responses`) | ✅ |
|
||||
| P0 | OpenAI SDK(JS/Python,chat + responses) | ✅ |
|
||||
| P0 | Vercel AI SDK(openai-compatible) | ✅ |
|
||||
| P0 | Anthropic SDK(messages) | ✅ |
|
||||
| P1 | LangChain / LlamaIndex / OpenWebUI(OpenAI 兼容接入) | ✅ |
|
||||
| P2 | MCP 独立桥接层 | 规划中 |
|
||||
|
||||
## 模型支持
|
||||
|
||||
### OpenAI 接口
|
||||
@@ -88,6 +99,19 @@ flowchart LR
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 通用第一步(所有部署方式)
|
||||
|
||||
把 `config.json` 作为唯一配置源(推荐做法):
|
||||
|
||||
```bash
|
||||
cp config.example.json config.json
|
||||
# 编辑 config.json
|
||||
```
|
||||
|
||||
后续部署建议:
|
||||
- 本地运行:直接读取 `config.json`
|
||||
- Docker / Vercel:由 `config.json` 生成 `DS2API_CONFIG_JSON`(Base64)注入环境变量
|
||||
|
||||
### 方式一:本地运行
|
||||
|
||||
**前置要求**:Go 1.24+,Node.js 20+(仅在需要构建 WebUI 时)
|
||||
@@ -112,14 +136,20 @@ go run ./cmd/ds2api
|
||||
### 方式二:Docker 运行
|
||||
|
||||
```bash
|
||||
# 1. 配置环境变量
|
||||
# 1. 准备环境变量文件
|
||||
cp .env.example .env
|
||||
# 编辑 .env
|
||||
|
||||
# 2. 启动
|
||||
# 2. 从 config.json 生成 DS2API_CONFIG_JSON(单行 Base64)
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
|
||||
# 3. 编辑 .env,设置:
|
||||
# DS2API_ADMIN_KEY=请替换为强密码
|
||||
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
|
||||
|
||||
# 4. 启动
|
||||
docker-compose up -d
|
||||
|
||||
# 3. 查看日志
|
||||
# 5. 查看日志
|
||||
docker-compose logs -f
|
||||
```
|
||||
|
||||
@@ -129,9 +159,22 @@ docker-compose logs -f
|
||||
|
||||
1. Fork 仓库到自己的 GitHub
|
||||
2. 在 Vercel 上导入项目
|
||||
3. 配置环境变量(至少设置 `DS2API_ADMIN_KEY` 和 `DS2API_CONFIG_JSON`)
|
||||
3. 配置环境变量(最少设置 `DS2API_ADMIN_KEY`;推荐同时设置 `DS2API_CONFIG_JSON`)
|
||||
4. 部署
|
||||
|
||||
建议先在仓库目录复制模板并填写:
|
||||
|
||||
```bash
|
||||
cp config.example.json config.json
|
||||
# 编辑 config.json
|
||||
```
|
||||
|
||||
推荐:先本地把 `config.json` 转成 Base64,再粘贴到 `DS2API_CONFIG_JSON`,避免 JSON 格式错误:
|
||||
|
||||
```bash
|
||||
base64 < config.json | tr -d '\n'
|
||||
```
|
||||
|
||||
> **流式说明**:`/v1/chat/completions` 在 Vercel 上默认走 `api/chat-stream.js`(Node Runtime)以保证实时 SSE。鉴权、账号选择、会话/PoW 准备仍由 Go 内部 prepare 接口完成;流式响应(含 `tools`)在 Node 侧执行与 Go 对齐的输出组装与防泄漏处理。
|
||||
|
||||
详细部署说明请参阅 [部署指南](DEPLOY.md)。
|
||||
@@ -164,6 +207,7 @@ cp opencode.json.example opencode.json
|
||||
3. 在项目目录启动 OpenCode CLI(按你的安装方式运行 `opencode`)。
|
||||
|
||||
> 建议优先使用 OpenAI 兼容路径(`/v1/*`),即示例里的 `@ai-sdk/openai-compatible` provider。
|
||||
> 若客户端支持 `wire_api`,可分别测试 `responses` 与 `chat`,DS2API 两条链路都兼容。
|
||||
|
||||
## 配置说明
|
||||
|
||||
@@ -184,6 +228,24 @@ cp opencode.json.example opencode.json
|
||||
"token": ""
|
||||
}
|
||||
],
|
||||
"model_aliases": {
|
||||
"gpt-4o": "deepseek-chat",
|
||||
"gpt-5-codex": "deepseek-reasoner",
|
||||
"o3": "deepseek-reasoner"
|
||||
},
|
||||
"compat": {
|
||||
"wide_input_strict_output": true
|
||||
},
|
||||
"toolcall": {
|
||||
"mode": "feature_match",
|
||||
"early_emit_confidence": "high"
|
||||
},
|
||||
"responses": {
|
||||
"store_ttl_seconds": 900
|
||||
},
|
||||
"embeddings": {
|
||||
"provider": "deterministic"
|
||||
},
|
||||
"claude_model_mapping": {
|
||||
"fast": "deepseek-chat",
|
||||
"slow": "deepseek-reasoner"
|
||||
@@ -194,6 +256,11 @@ cp opencode.json.example opencode.json
|
||||
- `keys`:API 访问密钥列表,客户端通过 `Authorization: Bearer <key>` 鉴权
|
||||
- `accounts`:DeepSeek 账号列表,支持 `email` 或 `mobile` 登录
|
||||
- `token`:留空则首次请求时自动登录获取;也可预填已有 token
|
||||
- `model_aliases`:常见模型名(如 GPT/Codex/Claude)到 DeepSeek 模型的映射
|
||||
- `compat.wide_input_strict_output`:建议保持 `true`(当前实现默认宽进严出)
|
||||
- `toolcall`:固定采用特征匹配 + 高置信早发策略
|
||||
- `responses.store_ttl_seconds`:`/v1/responses/{id}` 的内存缓存 TTL
|
||||
- `embeddings.provider`:embedding 提供方(当前内置 `deterministic/mock/builtin`)
|
||||
- `claude_model_mapping`:字典中 `fast`/`slow` 后缀映射到对应 DeepSeek 模型
|
||||
|
||||
### 环境变量
|
||||
@@ -249,10 +316,10 @@ cp opencode.json.example opencode.json
|
||||
|
||||
当请求中带 `tools` 时,DS2API 会做防泄漏处理:
|
||||
|
||||
1. `stream=true` 时先**缓冲**正文片段
|
||||
2. 若识别到工具调用 → 仅输出结构化 `tool_calls`,不透传原始 JSON 文本
|
||||
3. 若最终不是工具调用 → 一次性输出普通文本
|
||||
4. 解析器支持混合文本、fenced JSON、`function.arguments` 字符串等格式
|
||||
1. 只在**非代码块上下文**启用 toolcall 特征识别(代码块示例不会触发)
|
||||
2. 一旦命中高置信特征(`tool_calls` + `name` + `arguments/input` 起始)就立即输出 `delta.tool_calls`
|
||||
3. 已确认的 toolcall JSON 片段不会泄漏到 `delta.content`
|
||||
4. 前文/后文自然语言保持顺序透传,支持混合文本与增量参数输出
|
||||
|
||||
## 项目结构
|
||||
|
||||
|
||||
89
README.en.md
89
README.en.md
@@ -54,16 +54,27 @@ flowchart LR
|
||||
|
||||
| Capability | Details |
|
||||
| --- | --- |
|
||||
| OpenAI compatible | `GET /v1/models`, `POST /v1/chat/completions` (stream/non-stream) |
|
||||
| OpenAI compatible | `GET /v1/models`, `GET /v1/models/{id}`, `POST /v1/chat/completions`, `POST /v1/responses`, `GET /v1/responses/{response_id}`, `POST /v1/embeddings` |
|
||||
| Claude compatible | `GET /anthropic/v1/models`, `POST /anthropic/v1/messages`, `POST /anthropic/v1/messages/count_tokens` |
|
||||
| Multi-account rotation | Auto token refresh, email/mobile dual login |
|
||||
| Concurrency control | Per-account in-flight limit + waiting queue, dynamic recommended concurrency |
|
||||
| DeepSeek PoW | WASM solving via `wazero`, no external Node.js dependency |
|
||||
| Tool Calling | Anti-leak handling: auto buffer, detect, structured output |
|
||||
| Tool Calling | Anti-leak handling: non-code-block feature match, early `delta.tool_calls`, structured incremental output |
|
||||
| Admin API | Config management, account testing/batch test, import/export, Vercel sync |
|
||||
| WebUI Admin Panel | SPA at `/admin` (bilingual Chinese/English, dark mode) |
|
||||
| Health Probes | `GET /healthz` (liveness), `GET /readyz` (readiness) |
|
||||
|
||||
## Platform Compatibility Matrix
|
||||
|
||||
| Tier | Platform | Status |
|
||||
| --- | --- | --- |
|
||||
| P0 | Codex CLI/SDK (`wire_api=chat` / `wire_api=responses`) | ✅ |
|
||||
| P0 | OpenAI SDK (JS/Python, chat + responses) | ✅ |
|
||||
| P0 | Vercel AI SDK (openai-compatible) | ✅ |
|
||||
| P0 | Anthropic SDK (messages) | ✅ |
|
||||
| P1 | LangChain / LlamaIndex / OpenWebUI (OpenAI-compatible integration) | ✅ |
|
||||
| P2 | MCP standalone bridge | Planned |
|
||||
|
||||
## Model Support
|
||||
|
||||
### OpenAI Endpoint
|
||||
@@ -88,6 +99,19 @@ In addition, `/anthropic/v1/models` now includes historical Claude 1.x/2.x/3.x/4
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Universal First Step (all deployment modes)
|
||||
|
||||
Use `config.json` as the single source of truth (recommended):
|
||||
|
||||
```bash
|
||||
cp config.example.json config.json
|
||||
# Edit config.json
|
||||
```
|
||||
|
||||
Recommended per deployment mode:
|
||||
- Local run: read `config.json` directly
|
||||
- Docker / Vercel: generate Base64 from `config.json` and inject as `DS2API_CONFIG_JSON`
|
||||
|
||||
### Option 1: Local Run
|
||||
|
||||
**Prerequisites**: Go 1.24+, Node.js 20+ (only if building WebUI locally)
|
||||
@@ -112,14 +136,20 @@ Default URL: `http://localhost:5001`
|
||||
### Option 2: Docker
|
||||
|
||||
```bash
|
||||
# 1. Configure environment
|
||||
# 1. Prepare env file
|
||||
cp .env.example .env
|
||||
# Edit .env
|
||||
|
||||
# 2. Start
|
||||
# 2. Generate DS2API_CONFIG_JSON from config.json (single-line Base64)
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
|
||||
# 3. Edit .env and set:
|
||||
# DS2API_ADMIN_KEY=replace-with-a-strong-secret
|
||||
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
|
||||
|
||||
# 4. Start
|
||||
docker-compose up -d
|
||||
|
||||
# 3. View logs
|
||||
# 5. View logs
|
||||
docker-compose logs -f
|
||||
```
|
||||
|
||||
@@ -129,9 +159,22 @@ Rebuild after updates: `docker-compose up -d --build`
|
||||
|
||||
1. Fork this repo to your GitHub account
|
||||
2. Import the project on Vercel
|
||||
3. Set environment variables (minimum: `DS2API_ADMIN_KEY` and `DS2API_CONFIG_JSON`)
|
||||
3. Set environment variables (minimum: `DS2API_ADMIN_KEY`; recommended to also set `DS2API_CONFIG_JSON`)
|
||||
4. Deploy
|
||||
|
||||
Recommended first step in repo root:
|
||||
|
||||
```bash
|
||||
cp config.example.json config.json
|
||||
# Edit config.json
|
||||
```
|
||||
|
||||
Recommended: convert `config.json` to Base64 locally, then paste into `DS2API_CONFIG_JSON` to avoid JSON formatting mistakes:
|
||||
|
||||
```bash
|
||||
base64 < config.json | tr -d '\n'
|
||||
```
|
||||
|
||||
> **Streaming note**: `/v1/chat/completions` on Vercel is routed to `api/chat-stream.js` (Node Runtime) for real-time SSE. Auth, account selection, and session/PoW preparation are still handled by the Go internal prepare endpoint; streaming output (including `tools`) is assembled on Node with Go-aligned anti-leak handling.
|
||||
|
||||
For detailed deployment instructions, see the [Deployment Guide](DEPLOY.en.md).
|
||||
@@ -164,6 +207,7 @@ cp opencode.json.example opencode.json
|
||||
3. Start OpenCode CLI in the project directory (run `opencode` using your installed method).
|
||||
|
||||
> Recommended: use the OpenAI-compatible path (`/v1/*`) via `@ai-sdk/openai-compatible` as shown in the example.
|
||||
> If your client supports `wire_api`, test both `responses` and `chat`; DS2API supports both paths.
|
||||
|
||||
## Configuration
|
||||
|
||||
@@ -184,6 +228,24 @@ cp opencode.json.example opencode.json
|
||||
"token": ""
|
||||
}
|
||||
],
|
||||
"model_aliases": {
|
||||
"gpt-4o": "deepseek-chat",
|
||||
"gpt-5-codex": "deepseek-reasoner",
|
||||
"o3": "deepseek-reasoner"
|
||||
},
|
||||
"compat": {
|
||||
"wide_input_strict_output": true
|
||||
},
|
||||
"toolcall": {
|
||||
"mode": "feature_match",
|
||||
"early_emit_confidence": "high"
|
||||
},
|
||||
"responses": {
|
||||
"store_ttl_seconds": 900
|
||||
},
|
||||
"embeddings": {
|
||||
"provider": "deterministic"
|
||||
},
|
||||
"claude_model_mapping": {
|
||||
"fast": "deepseek-chat",
|
||||
"slow": "deepseek-reasoner"
|
||||
@@ -194,6 +256,11 @@ cp opencode.json.example opencode.json
|
||||
- `keys`: API access keys; clients authenticate via `Authorization: Bearer <key>`
|
||||
- `accounts`: DeepSeek account list, supports `email` or `mobile` login
|
||||
- `token`: Leave empty for auto-login on first request; or pre-fill an existing token
|
||||
- `model_aliases`: Map common model names (GPT/Codex/Claude) to DeepSeek models
|
||||
- `compat.wide_input_strict_output`: Keep `true` (current default policy)
|
||||
- `toolcall`: Fixed to feature matching + high-confidence early emit
|
||||
- `responses.store_ttl_seconds`: In-memory TTL for `/v1/responses/{id}`
|
||||
- `embeddings.provider`: Embeddings provider (`deterministic/mock/builtin` built-in)
|
||||
- `claude_model_mapping`: Maps `fast`/`slow` suffixes to corresponding DeepSeek models
|
||||
|
||||
### Environment Variables
|
||||
@@ -249,10 +316,10 @@ Queue limit = DS2API_ACCOUNT_MAX_QUEUE (default = recommended concurrency)
|
||||
|
||||
When `tools` is present in the request, DS2API performs anti-leak handling:
|
||||
|
||||
1. With `stream=true`, DS2API **buffers** text deltas first
|
||||
2. If a tool call is detected → only structured `tool_calls` are emitted, raw JSON is not leaked
|
||||
3. If no tool call → buffered text is emitted at once
|
||||
4. Parser supports mixed text, fenced JSON, and `function.arguments` payloads
|
||||
1. Toolcall feature matching is enabled only in **non-code-block context** (fenced examples are ignored)
|
||||
2. Once high-confidence features are matched (`tool_calls` + `name` + `arguments/input` start), `delta.tool_calls` is emitted immediately
|
||||
3. Confirmed toolcall JSON fragments are never leaked into `delta.content`
|
||||
4. Natural language before/after toolcalls keeps original order, with incremental argument output supported
|
||||
|
||||
## Project Structure
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
'use strict';
|
||||
|
||||
const crypto = require('crypto');
|
||||
|
||||
const {
|
||||
extractToolNames,
|
||||
createToolSieveState,
|
||||
@@ -83,23 +85,57 @@ module.exports = async function handler(req, res) {
|
||||
const finalPrompt = asString(prep.body.final_prompt);
|
||||
const thinkingEnabled = toBool(prep.body.thinking_enabled);
|
||||
const searchEnabled = toBool(prep.body.search_enabled);
|
||||
const toolNames = extractToolNames(payload.tools);
|
||||
const toolPolicy = resolveToolcallPolicy(prep.body, payload.tools);
|
||||
const toolNames = toolPolicy.toolNames;
|
||||
|
||||
if (!model || !leaseID || !deepseekToken || !powHeader || !completionPayload) {
|
||||
writeOpenAIError(res, 500, 'invalid vercel prepare response');
|
||||
return;
|
||||
}
|
||||
const releaseLease = createLeaseReleaser(req, leaseID);
|
||||
const upstreamController = new AbortController();
|
||||
let clientClosed = false;
|
||||
let reader = null;
|
||||
const markClientClosed = () => {
|
||||
if (clientClosed) {
|
||||
return;
|
||||
}
|
||||
clientClosed = true;
|
||||
upstreamController.abort();
|
||||
if (reader && typeof reader.cancel === 'function') {
|
||||
Promise.resolve(reader.cancel()).catch(() => {});
|
||||
}
|
||||
};
|
||||
const onReqAborted = () => markClientClosed();
|
||||
const onResClose = () => {
|
||||
if (!res.writableEnded) {
|
||||
markClientClosed();
|
||||
}
|
||||
};
|
||||
req.on('aborted', onReqAborted);
|
||||
res.on('close', onResClose);
|
||||
try {
|
||||
const completionRes = await fetch(DEEPSEEK_COMPLETION_URL, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
...BASE_HEADERS,
|
||||
authorization: `Bearer ${deepseekToken}`,
|
||||
'x-ds-pow-response': powHeader,
|
||||
},
|
||||
body: JSON.stringify(completionPayload),
|
||||
});
|
||||
let completionRes;
|
||||
try {
|
||||
completionRes = await fetch(DEEPSEEK_COMPLETION_URL, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
...BASE_HEADERS,
|
||||
authorization: `Bearer ${deepseekToken}`,
|
||||
'x-ds-pow-response': powHeader,
|
||||
},
|
||||
body: JSON.stringify(completionPayload),
|
||||
signal: upstreamController.signal,
|
||||
});
|
||||
} catch (err) {
|
||||
if (clientClosed || isAbortError(err)) {
|
||||
return;
|
||||
}
|
||||
throw err;
|
||||
}
|
||||
if (clientClosed) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!completionRes.ok || !completionRes.body) {
|
||||
const detail = await safeReadText(completionRes);
|
||||
@@ -121,15 +157,20 @@ module.exports = async function handler(req, res) {
|
||||
let currentType = thinkingEnabled ? 'thinking' : 'text';
|
||||
let thinkingText = '';
|
||||
let outputText = '';
|
||||
const toolSieveEnabled = toolNames.length > 0;
|
||||
const toolSieveEnabled = toolPolicy.toolSieveEnabled;
|
||||
const emitEarlyToolDeltas = toolPolicy.emitEarlyToolDeltas;
|
||||
const toolSieveState = createToolSieveState();
|
||||
let toolCallsEmitted = false;
|
||||
const streamToolCallIDs = new Map();
|
||||
const decoder = new TextDecoder();
|
||||
const reader = completionRes.body.getReader();
|
||||
reader = completionRes.body.getReader();
|
||||
let buffered = '';
|
||||
let ended = false;
|
||||
|
||||
const sendFrame = (obj) => {
|
||||
if (clientClosed || res.writableEnded || res.destroyed) {
|
||||
return;
|
||||
}
|
||||
res.write(`data: ${JSON.stringify(obj)}\n\n`);
|
||||
if (typeof res.flush === 'function') {
|
||||
res.flush();
|
||||
@@ -156,6 +197,10 @@ module.exports = async function handler(req, res) {
|
||||
return;
|
||||
}
|
||||
ended = true;
|
||||
if (clientClosed || res.writableEnded || res.destroyed) {
|
||||
await releaseLease();
|
||||
return;
|
||||
}
|
||||
const detected = parseToolCalls(outputText, toolNames);
|
||||
if (detected.length > 0 && !toolCallsEmitted) {
|
||||
toolCallsEmitted = true;
|
||||
@@ -179,14 +224,22 @@ module.exports = async function handler(req, res) {
|
||||
choices: [{ delta: {}, index: 0, finish_reason: reason }],
|
||||
usage: buildUsage(finalPrompt, thinkingText, outputText),
|
||||
});
|
||||
res.write('data: [DONE]\n\n');
|
||||
if (!res.writableEnded && !res.destroyed) {
|
||||
res.write('data: [DONE]\n\n');
|
||||
}
|
||||
await releaseLease();
|
||||
res.end();
|
||||
if (!res.writableEnded && !res.destroyed) {
|
||||
res.end();
|
||||
}
|
||||
};
|
||||
|
||||
try {
|
||||
// eslint-disable-next-line no-constant-condition
|
||||
while (true) {
|
||||
if (clientClosed) {
|
||||
await finish('stop');
|
||||
return;
|
||||
}
|
||||
const { value, done } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
@@ -245,6 +298,14 @@ module.exports = async function handler(req, res) {
|
||||
}
|
||||
const events = processToolSieveChunk(toolSieveState, p.text, toolNames);
|
||||
for (const evt of events) {
|
||||
if (evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0) {
|
||||
if (!emitEarlyToolDeltas) {
|
||||
continue;
|
||||
}
|
||||
toolCallsEmitted = true;
|
||||
sendDeltaFrame({ tool_calls: formatIncrementalToolCallDeltas(evt.deltas, streamToolCallIDs) });
|
||||
continue;
|
||||
}
|
||||
if (evt.type === 'tool_calls') {
|
||||
toolCallsEmitted = true;
|
||||
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls) });
|
||||
@@ -259,10 +320,16 @@ module.exports = async function handler(req, res) {
|
||||
}
|
||||
}
|
||||
await finish('stop');
|
||||
} catch (_err) {
|
||||
} catch (err) {
|
||||
if (clientClosed || isAbortError(err)) {
|
||||
await finish('stop');
|
||||
return;
|
||||
}
|
||||
await finish('stop');
|
||||
}
|
||||
} finally {
|
||||
req.removeListener('aborted', onReqAborted);
|
||||
res.removeListener('close', onResClose);
|
||||
await releaseLease();
|
||||
}
|
||||
};
|
||||
@@ -345,6 +412,37 @@ function relayPreparedFailure(res, prep) {
|
||||
writeOpenAIError(res, prep.status || 500, 'vercel prepare failed');
|
||||
}
|
||||
|
||||
function resolveToolcallPolicy(prepBody, payloadTools) {
|
||||
const preparedToolNames = normalizePreparedToolNames(prepBody && prepBody.tool_names);
|
||||
const toolNames = preparedToolNames.length > 0 ? preparedToolNames : extractToolNames(payloadTools);
|
||||
const featureMatchEnabled = boolDefaultTrue(prepBody && prepBody.toolcall_feature_match);
|
||||
const emitEarlyToolDeltas = boolDefaultTrue(prepBody && prepBody.toolcall_early_emit_high);
|
||||
return {
|
||||
toolNames,
|
||||
toolSieveEnabled: toolNames.length > 0 && featureMatchEnabled,
|
||||
emitEarlyToolDeltas,
|
||||
};
|
||||
}
|
||||
|
||||
function normalizePreparedToolNames(v) {
|
||||
if (!Array.isArray(v) || v.length === 0) {
|
||||
return [];
|
||||
}
|
||||
const out = [];
|
||||
for (const item of v) {
|
||||
const name = asString(item);
|
||||
if (!name) {
|
||||
continue;
|
||||
}
|
||||
out.push(name);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function boolDefaultTrue(v) {
|
||||
return v !== false;
|
||||
}
|
||||
|
||||
async function safeReadText(resp) {
|
||||
if (!resp) {
|
||||
return '';
|
||||
@@ -656,6 +754,55 @@ function buildUsage(prompt, thinking, output) {
|
||||
};
|
||||
}
|
||||
|
||||
function formatIncrementalToolCallDeltas(deltas, idStore) {
|
||||
if (!Array.isArray(deltas) || deltas.length === 0) {
|
||||
return [];
|
||||
}
|
||||
const out = [];
|
||||
for (const d of deltas) {
|
||||
if (!d || typeof d !== 'object') {
|
||||
continue;
|
||||
}
|
||||
const index = Number.isInteger(d.index) ? d.index : 0;
|
||||
const id = ensureStreamToolCallID(idStore, index);
|
||||
const item = {
|
||||
index,
|
||||
id,
|
||||
type: 'function',
|
||||
};
|
||||
const fn = {};
|
||||
if (asString(d.name)) {
|
||||
fn.name = asString(d.name);
|
||||
}
|
||||
if (typeof d.arguments === 'string' && d.arguments !== '') {
|
||||
fn.arguments = d.arguments;
|
||||
}
|
||||
if (Object.keys(fn).length > 0) {
|
||||
item.function = fn;
|
||||
}
|
||||
out.push(item);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function ensureStreamToolCallID(idStore, index) {
|
||||
const key = Number.isInteger(index) ? index : 0;
|
||||
const existing = idStore.get(key);
|
||||
if (existing) {
|
||||
return existing;
|
||||
}
|
||||
const next = `call_${newCallID()}`;
|
||||
idStore.set(key, next);
|
||||
return next;
|
||||
}
|
||||
|
||||
function newCallID() {
|
||||
if (typeof crypto.randomUUID === 'function') {
|
||||
return crypto.randomUUID().replace(/-/g, '');
|
||||
}
|
||||
return `${Date.now()}${Math.floor(Math.random() * 1e9)}`;
|
||||
}
|
||||
|
||||
function estimateTokens(text) {
|
||||
const t = asString(text);
|
||||
if (!t) {
|
||||
@@ -667,44 +814,92 @@ function estimateTokens(text) {
|
||||
|
||||
async function proxyToGo(req, res, rawBody) {
|
||||
const url = buildInternalGoURL(req);
|
||||
|
||||
const upstream = await fetch(url.toString(), {
|
||||
method: 'POST',
|
||||
headers: buildInternalGoHeaders(req, { withContentType: true }),
|
||||
body: rawBody,
|
||||
});
|
||||
|
||||
res.statusCode = upstream.status;
|
||||
upstream.headers.forEach((value, key) => {
|
||||
if (key.toLowerCase() === 'content-length') {
|
||||
const controller = new AbortController();
|
||||
let clientClosed = false;
|
||||
const markClientClosed = () => {
|
||||
if (clientClosed) {
|
||||
return;
|
||||
}
|
||||
res.setHeader(key, value);
|
||||
});
|
||||
clientClosed = true;
|
||||
controller.abort();
|
||||
};
|
||||
const onReqAborted = () => markClientClosed();
|
||||
const onResClose = () => {
|
||||
if (!res.writableEnded) {
|
||||
markClientClosed();
|
||||
}
|
||||
};
|
||||
req.on('aborted', onReqAborted);
|
||||
res.on('close', onResClose);
|
||||
|
||||
if (!upstream.body || typeof upstream.body.getReader !== 'function') {
|
||||
const bytes = Buffer.from(await upstream.arrayBuffer());
|
||||
res.end(bytes);
|
||||
return;
|
||||
}
|
||||
|
||||
const reader = upstream.body.getReader();
|
||||
try {
|
||||
// eslint-disable-next-line no-constant-condition
|
||||
while (true) {
|
||||
const { value, done } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
let upstream;
|
||||
try {
|
||||
upstream = await fetch(url.toString(), {
|
||||
method: 'POST',
|
||||
headers: buildInternalGoHeaders(req, { withContentType: true }),
|
||||
body: rawBody,
|
||||
signal: controller.signal,
|
||||
});
|
||||
} catch (err) {
|
||||
if (clientClosed || isAbortError(err)) {
|
||||
if (!res.writableEnded) {
|
||||
res.end();
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (value && value.length > 0) {
|
||||
res.write(Buffer.from(value));
|
||||
if (typeof res.flush === 'function') {
|
||||
res.flush();
|
||||
throw err;
|
||||
}
|
||||
if (clientClosed) {
|
||||
if (!res.writableEnded) {
|
||||
res.end();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
res.statusCode = upstream.status;
|
||||
upstream.headers.forEach((value, key) => {
|
||||
if (key.toLowerCase() === 'content-length') {
|
||||
return;
|
||||
}
|
||||
res.setHeader(key, value);
|
||||
});
|
||||
|
||||
if (!upstream.body || typeof upstream.body.getReader !== 'function') {
|
||||
const bytes = Buffer.from(await upstream.arrayBuffer());
|
||||
res.end(bytes);
|
||||
return;
|
||||
}
|
||||
|
||||
const reader = upstream.body.getReader();
|
||||
try {
|
||||
// eslint-disable-next-line no-constant-condition
|
||||
while (true) {
|
||||
if (clientClosed) {
|
||||
break;
|
||||
}
|
||||
const { value, done } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
if (value && value.length > 0) {
|
||||
res.write(Buffer.from(value));
|
||||
if (typeof res.flush === 'function') {
|
||||
res.flush();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!res.writableEnded) {
|
||||
res.end();
|
||||
}
|
||||
} catch (err) {
|
||||
if (!isAbortError(err) && !res.writableEnded) {
|
||||
res.end();
|
||||
}
|
||||
}
|
||||
res.end();
|
||||
} catch (_err) {
|
||||
} finally {
|
||||
req.removeListener('aborted', onReqAborted);
|
||||
res.removeListener('close', onResClose);
|
||||
if (!res.writableEnded) {
|
||||
res.end();
|
||||
}
|
||||
@@ -762,9 +957,19 @@ function asString(v) {
|
||||
return String(v).trim();
|
||||
}
|
||||
|
||||
function isAbortError(err) {
|
||||
if (!err || typeof err !== 'object') {
|
||||
return false;
|
||||
}
|
||||
return err.name === 'AbortError' || err.code === 'ABORT_ERR';
|
||||
}
|
||||
|
||||
module.exports.__test = {
|
||||
parseChunkForContent,
|
||||
extractContentRecursive,
|
||||
shouldSkipPath,
|
||||
asString,
|
||||
resolveToolcallPolicy,
|
||||
normalizePreparedToolNames,
|
||||
boolDefaultTrue,
|
||||
};
|
||||
|
||||
@@ -10,10 +10,50 @@ const {
|
||||
flushToolSieve,
|
||||
} = require('./helpers/stream-tool-sieve');
|
||||
|
||||
const { parseChunkForContent } = handler.__test;
|
||||
const {
|
||||
parseChunkForContent,
|
||||
resolveToolcallPolicy,
|
||||
normalizePreparedToolNames,
|
||||
boolDefaultTrue,
|
||||
} = handler.__test;
|
||||
|
||||
test('chat-stream exposes parser test hooks', () => {
|
||||
assert.equal(typeof parseChunkForContent, 'function');
|
||||
assert.equal(typeof resolveToolcallPolicy, 'function');
|
||||
});
|
||||
|
||||
test('resolveToolcallPolicy defaults to feature-match + early emit when prepare flags missing', () => {
|
||||
const policy = resolveToolcallPolicy(
|
||||
{},
|
||||
[{ type: 'function', function: { name: 'read_file', parameters: { type: 'object' } } }],
|
||||
);
|
||||
assert.deepEqual(policy.toolNames, ['read_file']);
|
||||
assert.equal(policy.toolSieveEnabled, true);
|
||||
assert.equal(policy.emitEarlyToolDeltas, true);
|
||||
});
|
||||
|
||||
test('resolveToolcallPolicy respects prepare flags and prepared tool names', () => {
|
||||
const policy = resolveToolcallPolicy(
|
||||
{
|
||||
tool_names: [' prepped_tool ', '', null],
|
||||
toolcall_feature_match: false,
|
||||
toolcall_early_emit_high: false,
|
||||
},
|
||||
[{ type: 'function', function: { name: 'fallback_tool', parameters: { type: 'object' } } }],
|
||||
);
|
||||
assert.deepEqual(policy.toolNames, ['prepped_tool']);
|
||||
assert.equal(policy.toolSieveEnabled, false);
|
||||
assert.equal(policy.emitEarlyToolDeltas, false);
|
||||
});
|
||||
|
||||
test('normalizePreparedToolNames filters empty values', () => {
|
||||
assert.deepEqual(normalizePreparedToolNames([' a ', '', null, 'b']), ['a', 'b']);
|
||||
});
|
||||
|
||||
test('boolDefaultTrue keeps false only when explicitly false', () => {
|
||||
assert.equal(boolDefaultTrue(false), false);
|
||||
assert.equal(boolDefaultTrue(true), true);
|
||||
assert.equal(boolDefaultTrue(undefined), true);
|
||||
});
|
||||
|
||||
test('parseChunkForContent keeps split response/content fragments inside response array', () => {
|
||||
@@ -49,12 +89,13 @@ test('parseChunkForContent + sieve does not leak suspicious prefix in split tool
|
||||
events.push(...flushToolSieve(state, ['read_file']));
|
||||
|
||||
const hasToolCalls = events.some((evt) => evt.type === 'tool_calls' && evt.calls && evt.calls.length > 0);
|
||||
const hasToolDeltas = events.some((evt) => evt.type === 'tool_call_deltas' && evt.deltas && evt.deltas.length > 0);
|
||||
const leakedText = events
|
||||
.filter((evt) => evt.type === 'text' && evt.text)
|
||||
.map((evt) => evt.text)
|
||||
.join('');
|
||||
|
||||
assert.equal(hasToolCalls, true);
|
||||
assert.equal(hasToolCalls || hasToolDeltas, true);
|
||||
assert.equal(leakedText.includes('{'), false);
|
||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
||||
});
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
const crypto = require('crypto');
|
||||
const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s;
|
||||
const TOOL_SIEVE_CAPTURE_LIMIT = 8 * 1024;
|
||||
const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 256;
|
||||
|
||||
function extractToolNames(tools) {
|
||||
if (!Array.isArray(tools) || tools.length === 0) {
|
||||
@@ -26,9 +28,25 @@ function createToolSieveState() {
|
||||
pending: '',
|
||||
capture: '',
|
||||
capturing: false,
|
||||
recentTextTail: '',
|
||||
toolNameSent: false,
|
||||
toolName: '',
|
||||
toolArgsStart: -1,
|
||||
toolArgsSent: -1,
|
||||
toolArgsString: false,
|
||||
toolArgsDone: false,
|
||||
};
|
||||
}
|
||||
|
||||
function resetIncrementalToolState(state) {
|
||||
state.toolNameSent = false;
|
||||
state.toolName = '';
|
||||
state.toolArgsStart = -1;
|
||||
state.toolArgsSent = -1;
|
||||
state.toolArgsString = false;
|
||||
state.toolArgsDone = false;
|
||||
}
|
||||
|
||||
function processToolSieveChunk(state, chunk, toolNames) {
|
||||
if (!state) {
|
||||
return [];
|
||||
@@ -44,13 +62,27 @@ function processToolSieveChunk(state, chunk, toolNames) {
|
||||
state.capture += state.pending;
|
||||
state.pending = '';
|
||||
}
|
||||
const consumed = consumeToolCapture(state.capture, toolNames);
|
||||
const deltas = buildIncrementalToolDeltas(state);
|
||||
if (deltas.length > 0) {
|
||||
events.push({ type: 'tool_call_deltas', deltas });
|
||||
}
|
||||
const consumed = consumeToolCapture(state, toolNames);
|
||||
if (!consumed.ready) {
|
||||
if (state.capture.length > TOOL_SIEVE_CAPTURE_LIMIT) {
|
||||
noteText(state, state.capture);
|
||||
events.push({ type: 'text', text: state.capture });
|
||||
state.capture = '';
|
||||
state.capturing = false;
|
||||
resetIncrementalToolState(state);
|
||||
continue;
|
||||
}
|
||||
break;
|
||||
}
|
||||
state.capture = '';
|
||||
state.capturing = false;
|
||||
resetIncrementalToolState(state);
|
||||
if (consumed.prefix) {
|
||||
noteText(state, consumed.prefix);
|
||||
events.push({ type: 'text', text: consumed.prefix });
|
||||
}
|
||||
if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
|
||||
@@ -70,11 +102,13 @@ function processToolSieveChunk(state, chunk, toolNames) {
|
||||
if (start >= 0) {
|
||||
const prefix = state.pending.slice(0, start);
|
||||
if (prefix) {
|
||||
noteText(state, prefix);
|
||||
events.push({ type: 'text', text: prefix });
|
||||
}
|
||||
state.capture = state.pending.slice(start);
|
||||
state.pending = '';
|
||||
state.capturing = true;
|
||||
resetIncrementalToolState(state);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -83,6 +117,7 @@ function processToolSieveChunk(state, chunk, toolNames) {
|
||||
break;
|
||||
}
|
||||
state.pending = hold;
|
||||
noteText(state, safe);
|
||||
events.push({ type: 'text', text: safe });
|
||||
}
|
||||
return events;
|
||||
@@ -94,24 +129,29 @@ function flushToolSieve(state, toolNames) {
|
||||
}
|
||||
const events = processToolSieveChunk(state, '', toolNames);
|
||||
if (state.capturing) {
|
||||
const consumed = consumeToolCapture(state.capture, toolNames);
|
||||
const consumed = consumeToolCapture(state, toolNames);
|
||||
if (consumed.ready) {
|
||||
if (consumed.prefix) {
|
||||
noteText(state, consumed.prefix);
|
||||
events.push({ type: 'text', text: consumed.prefix });
|
||||
}
|
||||
if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
|
||||
events.push({ type: 'tool_calls', calls: consumed.calls });
|
||||
}
|
||||
if (consumed.suffix) {
|
||||
noteText(state, consumed.suffix);
|
||||
events.push({ type: 'text', text: consumed.suffix });
|
||||
}
|
||||
} else if (state.capture) {
|
||||
// Incomplete captured tool JSON at stream end: suppress raw capture.
|
||||
noteText(state, state.capture);
|
||||
events.push({ type: 'text', text: state.capture });
|
||||
}
|
||||
state.capture = '';
|
||||
state.capturing = false;
|
||||
resetIncrementalToolState(state);
|
||||
}
|
||||
if (state.pending) {
|
||||
noteText(state, state.pending);
|
||||
events.push({ type: 'text', text: state.pending });
|
||||
state.pending = '';
|
||||
}
|
||||
@@ -151,15 +191,25 @@ function findToolSegmentStart(s) {
|
||||
return -1;
|
||||
}
|
||||
const lower = s.toLowerCase();
|
||||
const keyIdx = lower.indexOf('tool_calls');
|
||||
if (keyIdx < 0) {
|
||||
return -1;
|
||||
let offset = 0;
|
||||
// eslint-disable-next-line no-constant-condition
|
||||
while (true) {
|
||||
const keyRel = lower.indexOf('tool_calls', offset);
|
||||
if (keyRel < 0) {
|
||||
return -1;
|
||||
}
|
||||
const keyIdx = keyRel;
|
||||
const start = s.slice(0, keyIdx).lastIndexOf('{');
|
||||
const candidateStart = start >= 0 ? start : keyIdx;
|
||||
if (!insideCodeFence(s.slice(0, candidateStart))) {
|
||||
return candidateStart;
|
||||
}
|
||||
offset = keyIdx + 'tool_calls'.length;
|
||||
}
|
||||
const start = s.slice(0, keyIdx).lastIndexOf('{');
|
||||
return start >= 0 ? start : keyIdx;
|
||||
}
|
||||
|
||||
function consumeToolCapture(captured, toolNames) {
|
||||
function consumeToolCapture(state, toolNames) {
|
||||
const captured = state.capture;
|
||||
if (!captured) {
|
||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||
}
|
||||
@@ -176,25 +226,367 @@ function consumeToolCapture(captured, toolNames) {
|
||||
if (!obj.ok) {
|
||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||
}
|
||||
const parsed = parseToolCalls(captured.slice(start, obj.end), toolNames);
|
||||
if (parsed.length === 0) {
|
||||
// `tool_calls` key exists but strict JSON parse failed.
|
||||
// Drop the captured object body to avoid leaking raw tool JSON.
|
||||
const prefixPart = captured.slice(0, start);
|
||||
const suffixPart = captured.slice(obj.end);
|
||||
if (insideCodeFence((state.recentTextTail || '') + prefixPart)) {
|
||||
return {
|
||||
ready: true,
|
||||
prefix: captured.slice(0, start),
|
||||
prefix: captured,
|
||||
calls: [],
|
||||
suffix: captured.slice(obj.end),
|
||||
suffix: '',
|
||||
};
|
||||
}
|
||||
const parsed = parseStandaloneToolCalls(captured.slice(start, obj.end), toolNames);
|
||||
if (parsed.length === 0) {
|
||||
if (state.toolNameSent) {
|
||||
return {
|
||||
ready: true,
|
||||
prefix: prefixPart,
|
||||
calls: [],
|
||||
suffix: suffixPart,
|
||||
};
|
||||
}
|
||||
return {
|
||||
ready: true,
|
||||
prefix: captured,
|
||||
calls: [],
|
||||
suffix: '',
|
||||
};
|
||||
}
|
||||
if (state.toolNameSent) {
|
||||
if (parsed.length > 1) {
|
||||
return {
|
||||
ready: true,
|
||||
prefix: prefixPart,
|
||||
calls: parsed.slice(1),
|
||||
suffix: suffixPart,
|
||||
};
|
||||
}
|
||||
return {
|
||||
ready: true,
|
||||
prefix: prefixPart,
|
||||
calls: [],
|
||||
suffix: suffixPart,
|
||||
};
|
||||
}
|
||||
return {
|
||||
ready: true,
|
||||
prefix: captured.slice(0, start),
|
||||
prefix: prefixPart,
|
||||
calls: parsed,
|
||||
suffix: captured.slice(obj.end),
|
||||
suffix: suffixPart,
|
||||
};
|
||||
}
|
||||
|
||||
function buildIncrementalToolDeltas(state) {
|
||||
const captured = state.capture || '';
|
||||
if (!captured) {
|
||||
return [];
|
||||
}
|
||||
if (looksLikeToolExampleContext(state.recentTextTail)) {
|
||||
return [];
|
||||
}
|
||||
const lower = captured.toLowerCase();
|
||||
const keyIdx = lower.indexOf('tool_calls');
|
||||
if (keyIdx < 0) {
|
||||
return [];
|
||||
}
|
||||
const start = captured.slice(0, keyIdx).lastIndexOf('{');
|
||||
if (start < 0) {
|
||||
return [];
|
||||
}
|
||||
if (insideCodeFence((state.recentTextTail || '') + captured.slice(0, start))) {
|
||||
return [];
|
||||
}
|
||||
const callStart = findFirstToolCallObjectStart(captured, keyIdx);
|
||||
if (callStart < 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const deltas = [];
|
||||
if (!state.toolName) {
|
||||
const name = extractToolCallName(captured, callStart);
|
||||
if (!name) {
|
||||
return [];
|
||||
}
|
||||
state.toolName = name;
|
||||
}
|
||||
|
||||
if (state.toolArgsStart < 0) {
|
||||
const args = findToolCallArgsStart(captured, callStart);
|
||||
if (args) {
|
||||
state.toolArgsString = Boolean(args.stringMode);
|
||||
state.toolArgsStart = state.toolArgsString ? args.start + 1 : args.start;
|
||||
state.toolArgsSent = state.toolArgsStart;
|
||||
}
|
||||
}
|
||||
if (!state.toolNameSent) {
|
||||
if (state.toolArgsStart < 0) {
|
||||
return [];
|
||||
}
|
||||
state.toolNameSent = true;
|
||||
deltas.push({ index: 0, name: state.toolName });
|
||||
}
|
||||
if (state.toolArgsStart < 0 || state.toolArgsDone) {
|
||||
return deltas;
|
||||
}
|
||||
const progress = scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString);
|
||||
if (!progress) {
|
||||
return deltas;
|
||||
}
|
||||
if (progress.end > state.toolArgsSent) {
|
||||
deltas.push({
|
||||
index: 0,
|
||||
arguments: captured.slice(state.toolArgsSent, progress.end),
|
||||
});
|
||||
state.toolArgsSent = progress.end;
|
||||
}
|
||||
if (progress.complete) {
|
||||
state.toolArgsDone = true;
|
||||
}
|
||||
return deltas;
|
||||
}
|
||||
|
||||
function findFirstToolCallObjectStart(text, keyIdx) {
|
||||
const arrStart = findToolCallsArrayStart(text, keyIdx);
|
||||
if (arrStart < 0) {
|
||||
return -1;
|
||||
}
|
||||
const i = skipSpaces(text, arrStart + 1);
|
||||
if (i >= text.length || text[i] !== '{') {
|
||||
return -1;
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
function findToolCallsArrayStart(text, keyIdx) {
|
||||
let i = keyIdx + 'tool_calls'.length;
|
||||
while (i < text.length && text[i] !== ':') {
|
||||
i += 1;
|
||||
}
|
||||
if (i >= text.length) {
|
||||
return -1;
|
||||
}
|
||||
i = skipSpaces(text, i + 1);
|
||||
if (i >= text.length || text[i] !== '[') {
|
||||
return -1;
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
function extractToolCallName(text, callStart) {
|
||||
let valueStart = findObjectFieldValueStart(text, callStart, ['name']);
|
||||
if (valueStart < 0 || text[valueStart] !== '"') {
|
||||
const fnStart = findFunctionObjectStart(text, callStart);
|
||||
if (fnStart < 0) {
|
||||
return '';
|
||||
}
|
||||
valueStart = findObjectFieldValueStart(text, fnStart, ['name']);
|
||||
if (valueStart < 0 || text[valueStart] !== '"') {
|
||||
return '';
|
||||
}
|
||||
}
|
||||
const parsed = parseJSONStringLiteral(text, valueStart);
|
||||
if (!parsed) {
|
||||
return '';
|
||||
}
|
||||
return parsed.value;
|
||||
}
|
||||
|
||||
function findToolCallArgsStart(text, callStart) {
|
||||
const keys = ['input', 'arguments', 'args', 'parameters', 'params'];
|
||||
let valueStart = findObjectFieldValueStart(text, callStart, keys);
|
||||
if (valueStart < 0) {
|
||||
const fnStart = findFunctionObjectStart(text, callStart);
|
||||
if (fnStart < 0) {
|
||||
return null;
|
||||
}
|
||||
valueStart = findObjectFieldValueStart(text, fnStart, keys);
|
||||
if (valueStart < 0) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
if (valueStart >= text.length) {
|
||||
return null;
|
||||
}
|
||||
const ch = text[valueStart];
|
||||
if (ch === '{' || ch === '[') {
|
||||
return { start: valueStart, stringMode: false };
|
||||
}
|
||||
if (ch === '"') {
|
||||
return { start: valueStart, stringMode: true };
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function scanToolCallArgsProgress(text, start, stringMode) {
|
||||
if (start < 0 || start > text.length) {
|
||||
return null;
|
||||
}
|
||||
if (stringMode) {
|
||||
let escaped = false;
|
||||
for (let i = start; i < text.length; i += 1) {
|
||||
const ch = text[i];
|
||||
if (escaped) {
|
||||
escaped = false;
|
||||
continue;
|
||||
}
|
||||
if (ch === '\\') {
|
||||
escaped = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === '"') {
|
||||
return { end: i, complete: true };
|
||||
}
|
||||
}
|
||||
return { end: text.length, complete: false };
|
||||
}
|
||||
if (start >= text.length || (text[start] !== '{' && text[start] !== '[')) {
|
||||
return null;
|
||||
}
|
||||
let depth = 0;
|
||||
let quote = '';
|
||||
let escaped = false;
|
||||
for (let i = start; i < text.length; i += 1) {
|
||||
const ch = text[i];
|
||||
if (quote) {
|
||||
if (escaped) {
|
||||
escaped = false;
|
||||
continue;
|
||||
}
|
||||
if (ch === '\\') {
|
||||
escaped = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === quote) {
|
||||
quote = '';
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (ch === '"' || ch === "'") {
|
||||
quote = ch;
|
||||
continue;
|
||||
}
|
||||
if (ch === '{' || ch === '[') {
|
||||
depth += 1;
|
||||
continue;
|
||||
}
|
||||
if (ch === '}' || ch === ']') {
|
||||
depth -= 1;
|
||||
if (depth === 0) {
|
||||
return { end: i + 1, complete: true };
|
||||
}
|
||||
}
|
||||
}
|
||||
return { end: text.length, complete: false };
|
||||
}
|
||||
|
||||
function findObjectFieldValueStart(text, objStart, keys) {
|
||||
if (!text || objStart < 0 || objStart >= text.length || text[objStart] !== '{') {
|
||||
return -1;
|
||||
}
|
||||
let depth = 0;
|
||||
let quote = '';
|
||||
let escaped = false;
|
||||
for (let i = objStart; i < text.length; i += 1) {
|
||||
const ch = text[i];
|
||||
if (quote) {
|
||||
if (escaped) {
|
||||
escaped = false;
|
||||
continue;
|
||||
}
|
||||
if (ch === '\\') {
|
||||
escaped = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === quote) {
|
||||
quote = '';
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (ch === '"' || ch === "'") {
|
||||
if (depth === 1) {
|
||||
const parsed = parseJSONStringLiteral(text, i);
|
||||
if (!parsed) {
|
||||
return -1;
|
||||
}
|
||||
let j = skipSpaces(text, parsed.end);
|
||||
if (j >= text.length || text[j] !== ':') {
|
||||
i = parsed.end - 1;
|
||||
continue;
|
||||
}
|
||||
j = skipSpaces(text, j + 1);
|
||||
if (j >= text.length) {
|
||||
return -1;
|
||||
}
|
||||
if (keys.includes(parsed.value)) {
|
||||
return j;
|
||||
}
|
||||
i = j - 1;
|
||||
continue;
|
||||
}
|
||||
quote = ch;
|
||||
continue;
|
||||
}
|
||||
if (ch === '{') {
|
||||
depth += 1;
|
||||
continue;
|
||||
}
|
||||
if (ch === '}') {
|
||||
depth -= 1;
|
||||
if (depth === 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
function findFunctionObjectStart(text, callStart) {
|
||||
const valueStart = findObjectFieldValueStart(text, callStart, ['function']);
|
||||
if (valueStart < 0 || valueStart >= text.length || text[valueStart] !== '{') {
|
||||
return -1;
|
||||
}
|
||||
return valueStart;
|
||||
}
|
||||
|
||||
function parseJSONStringLiteral(text, start) {
|
||||
if (!text || start < 0 || start >= text.length || text[start] !== '"') {
|
||||
return null;
|
||||
}
|
||||
let out = '';
|
||||
let escaped = false;
|
||||
for (let i = start + 1; i < text.length; i += 1) {
|
||||
const ch = text[i];
|
||||
if (escaped) {
|
||||
out += ch;
|
||||
escaped = false;
|
||||
continue;
|
||||
}
|
||||
if (ch === '\\') {
|
||||
escaped = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === '"') {
|
||||
return { value: out, end: i + 1 };
|
||||
}
|
||||
out += ch;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function skipSpaces(text, i) {
|
||||
let idx = i;
|
||||
while (idx < text.length) {
|
||||
const ch = text[idx];
|
||||
if (ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r') {
|
||||
idx += 1;
|
||||
continue;
|
||||
}
|
||||
break;
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
function extractJSONObjectFrom(text, start) {
|
||||
if (!text || start < 0 || start >= text.length || text[start] !== '{') {
|
||||
return { ok: false, end: 0 };
|
||||
@@ -240,7 +632,11 @@ function parseToolCalls(text, toolNames) {
|
||||
if (!toStringSafe(text)) {
|
||||
return [];
|
||||
}
|
||||
const candidates = buildToolCallCandidates(text);
|
||||
const sanitized = stripFencedCodeBlocks(text);
|
||||
if (!toStringSafe(sanitized)) {
|
||||
return [];
|
||||
}
|
||||
const candidates = buildToolCallCandidates(sanitized);
|
||||
let parsed = [];
|
||||
for (const c of candidates) {
|
||||
parsed = parseToolCallsPayload(c);
|
||||
@@ -251,26 +647,49 @@ function parseToolCalls(text, toolNames) {
|
||||
if (parsed.length === 0) {
|
||||
return [];
|
||||
}
|
||||
const allowed = new Set((toolNames || []).filter(Boolean));
|
||||
const out = [];
|
||||
for (const tc of parsed) {
|
||||
if (!tc || !tc.name) {
|
||||
continue;
|
||||
}
|
||||
if (allowed.size > 0 && !allowed.has(tc.name)) {
|
||||
continue;
|
||||
}
|
||||
out.push({ name: tc.name, input: tc.input || {} });
|
||||
return filterToolCalls(parsed, toolNames);
|
||||
}
|
||||
|
||||
function stripFencedCodeBlocks(text) {
|
||||
const t = typeof text === 'string' ? text : '';
|
||||
if (!t) {
|
||||
return '';
|
||||
}
|
||||
if (out.length === 0 && parsed.length > 0) {
|
||||
for (const tc of parsed) {
|
||||
if (!tc || !tc.name) {
|
||||
continue;
|
||||
}
|
||||
out.push({ name: tc.name, input: tc.input || {} });
|
||||
return t.replace(/```[\s\S]*?```/g, ' ');
|
||||
}
|
||||
|
||||
function parseStandaloneToolCalls(text, toolNames) {
|
||||
const trimmed = toStringSafe(text);
|
||||
if (!trimmed) {
|
||||
return [];
|
||||
}
|
||||
if ((trimmed.startsWith('```') && trimmed.endsWith('```')) || trimmed.includes('```')) {
|
||||
return [];
|
||||
}
|
||||
if (looksLikeToolExampleContext(trimmed)) {
|
||||
return [];
|
||||
}
|
||||
const candidates = [trimmed];
|
||||
if (trimmed.startsWith('```') && trimmed.endsWith('```')) {
|
||||
const m = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
|
||||
if (m && m[1]) {
|
||||
candidates.push(toStringSafe(m[1]));
|
||||
}
|
||||
}
|
||||
return out;
|
||||
for (const candidate of candidates) {
|
||||
const c = toStringSafe(candidate);
|
||||
if (!c) {
|
||||
continue;
|
||||
}
|
||||
if (!c.startsWith('{') && !c.startsWith('[')) {
|
||||
continue;
|
||||
}
|
||||
const parsed = parseToolCallsPayload(c);
|
||||
if (parsed.length > 0) {
|
||||
return filterToolCalls(parsed, toolNames);
|
||||
}
|
||||
}
|
||||
return [];
|
||||
}
|
||||
|
||||
function buildToolCallCandidates(text) {
|
||||
@@ -432,6 +851,66 @@ function parseToolCallInput(v) {
|
||||
return {};
|
||||
}
|
||||
|
||||
function filterToolCalls(parsed, toolNames) {
|
||||
const allowed = new Set((toolNames || []).filter(Boolean));
|
||||
const out = [];
|
||||
for (const tc of parsed) {
|
||||
if (!tc || !tc.name) {
|
||||
continue;
|
||||
}
|
||||
if (allowed.size > 0 && !allowed.has(tc.name)) {
|
||||
continue;
|
||||
}
|
||||
out.push({ name: tc.name, input: tc.input || {} });
|
||||
}
|
||||
if (out.length === 0 && parsed.length > 0) {
|
||||
for (const tc of parsed) {
|
||||
if (!tc || !tc.name) {
|
||||
continue;
|
||||
}
|
||||
out.push({ name: tc.name, input: tc.input || {} });
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function noteText(state, text) {
|
||||
if (!state || !hasMeaningfulText(text)) {
|
||||
return;
|
||||
}
|
||||
state.recentTextTail = appendTail(state.recentTextTail, text, TOOL_SIEVE_CONTEXT_TAIL_LIMIT);
|
||||
}
|
||||
|
||||
function appendTail(prev, next, max) {
|
||||
const left = typeof prev === 'string' ? prev : '';
|
||||
const right = typeof next === 'string' ? next : '';
|
||||
if (!Number.isFinite(max) || max <= 0) {
|
||||
return '';
|
||||
}
|
||||
const combined = left + right;
|
||||
if (combined.length <= max) {
|
||||
return combined;
|
||||
}
|
||||
return combined.slice(combined.length - max);
|
||||
}
|
||||
|
||||
function looksLikeToolExampleContext(text) {
|
||||
return insideCodeFence(text);
|
||||
}
|
||||
|
||||
function insideCodeFence(text) {
|
||||
const t = typeof text === 'string' ? text : '';
|
||||
if (!t) {
|
||||
return false;
|
||||
}
|
||||
const ticks = (t.match(/```/g) || []).length;
|
||||
return ticks % 2 === 1;
|
||||
}
|
||||
|
||||
function hasMeaningfulText(text) {
|
||||
return toStringSafe(text) !== '';
|
||||
}
|
||||
|
||||
function formatOpenAIStreamToolCalls(calls) {
|
||||
if (!Array.isArray(calls) || calls.length === 0) {
|
||||
return [];
|
||||
@@ -473,5 +952,6 @@ module.exports = {
|
||||
processToolSieveChunk,
|
||||
flushToolSieve,
|
||||
parseToolCalls,
|
||||
parseStandaloneToolCalls,
|
||||
formatOpenAIStreamToolCalls,
|
||||
};
|
||||
|
||||
@@ -9,6 +9,7 @@ const {
|
||||
processToolSieveChunk,
|
||||
flushToolSieve,
|
||||
parseToolCalls,
|
||||
parseStandaloneToolCalls,
|
||||
} = require('./stream-tool-sieve');
|
||||
|
||||
function runSieve(chunks, toolNames) {
|
||||
@@ -68,9 +69,22 @@ test('parseToolCalls supports fenced json and function.arguments string payload'
|
||||
'```',
|
||||
].join('\n');
|
||||
const calls = parseToolCalls(text, ['read_file']);
|
||||
assert.equal(calls.length, 1);
|
||||
assert.equal(calls[0].name, 'read_file');
|
||||
assert.deepEqual(calls[0].input, { path: 'README.md' });
|
||||
assert.equal(calls.length, 0);
|
||||
});
|
||||
|
||||
test('parseStandaloneToolCalls only matches standalone payload and ignores mixed prose', () => {
|
||||
const mixed = '这里是示例:{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]},请勿执行。';
|
||||
const standalone = '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}';
|
||||
const mixedCalls = parseStandaloneToolCalls(mixed, ['read_file']);
|
||||
const standaloneCalls = parseStandaloneToolCalls(standalone, ['read_file']);
|
||||
assert.equal(mixedCalls.length, 0);
|
||||
assert.equal(standaloneCalls.length, 1);
|
||||
});
|
||||
|
||||
test('parseStandaloneToolCalls ignores fenced code block tool_call examples', () => {
|
||||
const fenced = ['```json', '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}', '```'].join('\n');
|
||||
const calls = parseStandaloneToolCalls(fenced, ['read_file']);
|
||||
assert.equal(calls.length, 0);
|
||||
});
|
||||
|
||||
test('sieve emits tool_calls and does not leak suspicious prefix on late key convergence', () => {
|
||||
@@ -84,13 +98,14 @@ test('sieve emits tool_calls and does not leak suspicious prefix on late key con
|
||||
);
|
||||
const leakedText = collectText(events);
|
||||
const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0);
|
||||
assert.equal(hasToolCall, true);
|
||||
const hasToolDelta = events.some((evt) => evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0);
|
||||
assert.equal(hasToolCall || hasToolDelta, true);
|
||||
assert.equal(leakedText.includes('{'), false);
|
||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
||||
assert.equal(leakedText.includes('后置正文C。'), true);
|
||||
});
|
||||
|
||||
test('sieve drops invalid tool json body while preserving surrounding text', () => {
|
||||
test('sieve keeps embedded invalid tool-like json as normal text to avoid stream stalls', () => {
|
||||
const events = runSieve(
|
||||
[
|
||||
'前置正文D。',
|
||||
@@ -104,18 +119,18 @@ test('sieve drops invalid tool json body while preserving surrounding text', ()
|
||||
assert.equal(hasToolCall, false);
|
||||
assert.equal(leakedText.includes('前置正文D。'), true);
|
||||
assert.equal(leakedText.includes('后置正文E。'), true);
|
||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), true);
|
||||
});
|
||||
|
||||
test('sieve suppresses incomplete captured tool json on stream finalize', () => {
|
||||
test('sieve flushes incomplete captured tool json as text on stream finalize', () => {
|
||||
const events = runSieve(
|
||||
['前置正文F。', '{"tool_calls":[{"name":"read_file"'],
|
||||
['read_file'],
|
||||
);
|
||||
const leakedText = collectText(events);
|
||||
assert.equal(leakedText.includes('前置正文F。'), true);
|
||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
||||
assert.equal(leakedText.includes('{'), false);
|
||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), true);
|
||||
assert.equal(leakedText.includes('{'), true);
|
||||
});
|
||||
|
||||
test('sieve keeps plain text intact in tool mode when no tool call appears', () => {
|
||||
@@ -128,3 +143,53 @@ test('sieve keeps plain text intact in tool mode when no tool call appears', ()
|
||||
assert.equal(hasToolCall, false);
|
||||
assert.equal(leakedText, '你好,这是普通文本回复。请继续。');
|
||||
});
|
||||
|
||||
test('sieve emits incremental tool_call_deltas for split arguments payload', () => {
|
||||
const state = createToolSieveState();
|
||||
const first = processToolSieveChunk(
|
||||
state,
|
||||
'{"tool_calls":[{"name":"read_file","input":{"path":"READ',
|
||||
['read_file'],
|
||||
);
|
||||
const second = processToolSieveChunk(
|
||||
state,
|
||||
'ME.MD","mode":"head"}}]}',
|
||||
['read_file'],
|
||||
);
|
||||
const tail = flushToolSieve(state, ['read_file']);
|
||||
const events = [...first, ...second, ...tail];
|
||||
const deltaEvents = events.filter((evt) => evt.type === 'tool_call_deltas');
|
||||
assert.equal(deltaEvents.length > 0, true);
|
||||
const merged = deltaEvents.flatMap((evt) => evt.deltas || []);
|
||||
const hasName = merged.some((d) => d.name === 'read_file');
|
||||
const argsJoined = merged
|
||||
.map((d) => d.arguments || '')
|
||||
.join('');
|
||||
assert.equal(hasName, true);
|
||||
assert.equal(argsJoined.includes('"path":"README.MD"'), true);
|
||||
assert.equal(argsJoined.includes('"mode":"head"'), true);
|
||||
});
|
||||
|
||||
test('sieve still intercepts tool call after leading plain text without suffix', () => {
|
||||
const events = runSieve(
|
||||
['我将调用工具。', '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}'],
|
||||
['read_file'],
|
||||
);
|
||||
const hasTool = events.some((evt) => (evt.type === 'tool_calls' && evt.calls?.length > 0) || (evt.type === 'tool_call_deltas' && evt.deltas?.length > 0));
|
||||
const leakedText = collectText(events);
|
||||
assert.equal(hasTool, true);
|
||||
assert.equal(leakedText.includes('我将调用工具。'), true);
|
||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
||||
});
|
||||
|
||||
test('sieve intercepts tool call and preserves trailing same-chunk text', () => {
|
||||
const events = runSieve(
|
||||
['{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}然后继续解释。'],
|
||||
['read_file'],
|
||||
);
|
||||
const hasTool = events.some((evt) => (evt.type === 'tool_calls' && evt.calls?.length > 0) || (evt.type === 'tool_call_deltas' && evt.deltas?.length > 0));
|
||||
const leakedText = collectText(events);
|
||||
assert.equal(hasTool, true);
|
||||
assert.equal(leakedText.includes('然后继续解释。'), true);
|
||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
||||
});
|
||||
|
||||
@@ -24,5 +24,27 @@
|
||||
"password": "your-password-3",
|
||||
"token": ""
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"model_aliases": {
|
||||
"gpt-4o": "deepseek-chat",
|
||||
"gpt-5-codex": "deepseek-reasoner",
|
||||
"o3": "deepseek-reasoner"
|
||||
},
|
||||
"compat": {
|
||||
"wide_input_strict_output": true
|
||||
},
|
||||
"toolcall": {
|
||||
"mode": "feature_match",
|
||||
"early_emit_confidence": "high"
|
||||
},
|
||||
"responses": {
|
||||
"store_ttl_seconds": 900
|
||||
},
|
||||
"embeddings": {
|
||||
"provider": "deterministic"
|
||||
},
|
||||
"claude_model_mapping": {
|
||||
"fast": "deepseek-chat",
|
||||
"slow": "deepseek-reasoner"
|
||||
}
|
||||
}
|
||||
|
||||
249
internal/account/pool_edge_test.go
Normal file
249
internal/account/pool_edge_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
// ─── Pool edge cases ─────────────────────────────────────────────────
|
||||
|
||||
func TestPoolEmptyNoAccounts(t *testing.T) {
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "2")
|
||||
t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "")
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "")
|
||||
t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "")
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`)
|
||||
pool := NewPool(config.LoadStore())
|
||||
if _, ok := pool.Acquire("", nil); ok {
|
||||
t.Fatal("expected acquire to fail with no accounts")
|
||||
}
|
||||
status := pool.Status()
|
||||
if total, ok := status["total"].(int); !ok || total != 0 {
|
||||
t.Fatalf("unexpected total: %#v", status["total"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolReleaseNonExistentAccount(t *testing.T) {
|
||||
pool := newPoolForTest(t, "2")
|
||||
pool.Release("nonexistent@example.com") // should not panic
|
||||
}
|
||||
|
||||
func TestPoolReleaseAlreadyReleased(t *testing.T) {
|
||||
pool := newPoolForTest(t, "2")
|
||||
acc, ok := pool.Acquire("", nil)
|
||||
if !ok {
|
||||
t.Fatal("expected acquire success")
|
||||
}
|
||||
pool.Release(acc.Identifier())
|
||||
pool.Release(acc.Identifier()) // double release should not panic
|
||||
}
|
||||
|
||||
func TestPoolAcquireTargetNotFound(t *testing.T) {
|
||||
pool := newPoolForTest(t, "2")
|
||||
if _, ok := pool.Acquire("nonexistent@example.com", nil); ok {
|
||||
t.Fatal("expected acquire to fail for non-existent target")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolAcquireWithExclusionList(t *testing.T) {
|
||||
pool := newPoolForTest(t, "2")
|
||||
acc, ok := pool.Acquire("", map[string]bool{"acc1@example.com": true})
|
||||
if !ok {
|
||||
t.Fatal("expected acquire success with exclusion")
|
||||
}
|
||||
if acc.Identifier() != "acc2@example.com" {
|
||||
t.Fatalf("expected acc2 when acc1 excluded, got %q", acc.Identifier())
|
||||
}
|
||||
pool.Release(acc.Identifier())
|
||||
}
|
||||
|
||||
func TestPoolAcquireAllExcluded(t *testing.T) {
|
||||
pool := newPoolForTest(t, "2")
|
||||
if _, ok := pool.Acquire("", map[string]bool{
|
||||
"acc1@example.com": true,
|
||||
"acc2@example.com": true,
|
||||
}); ok {
|
||||
t.Fatal("expected acquire to fail when all accounts excluded")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolStatusFields(t *testing.T) {
|
||||
pool := newPoolForTest(t, "2")
|
||||
status := pool.Status()
|
||||
|
||||
// Check all expected fields are present
|
||||
for _, key := range []string{"total", "available", "max_inflight_per_account", "recommended_concurrency", "available_accounts", "in_use_accounts", "waiting", "max_queue_size"} {
|
||||
if _, ok := status[key]; !ok {
|
||||
t.Fatalf("missing status field: %s", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolStatusAccountDetails(t *testing.T) {
|
||||
pool := newPoolForTest(t, "2")
|
||||
acc, _ := pool.Acquire("acc1@example.com", nil)
|
||||
|
||||
status := pool.Status()
|
||||
inUseAccounts, ok := status["in_use_accounts"].([]string)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected in_use_accounts type: %T", status["in_use_accounts"])
|
||||
}
|
||||
found := false
|
||||
for _, id := range inUseAccounts {
|
||||
if id == "acc1@example.com" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("expected acc1 in in_use_accounts, got %v", inUseAccounts)
|
||||
}
|
||||
if status["in_use"] != 1 {
|
||||
t.Fatalf("expected 1 in_use, got %v", status["in_use"])
|
||||
}
|
||||
|
||||
pool.Release(acc.Identifier())
|
||||
}
|
||||
|
||||
func TestPoolAcquireWaitContextCancelled(t *testing.T) {
|
||||
pool := newSingleAccountPoolForTest(t, "1")
|
||||
// Exhaust the pool
|
||||
first, ok := pool.Acquire("", nil)
|
||||
if !ok {
|
||||
t.Fatal("expected first acquire to succeed")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
var waitOK bool
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, waitOK = pool.AcquireWait(ctx, "", nil)
|
||||
}()
|
||||
|
||||
// Wait until queued
|
||||
waitForWaitingCount(t, pool, 1)
|
||||
|
||||
// Cancel context
|
||||
cancel()
|
||||
|
||||
wg.Wait()
|
||||
if waitOK {
|
||||
t.Fatal("expected acquire to fail after context cancellation")
|
||||
}
|
||||
|
||||
pool.Release(first.Identifier())
|
||||
}
|
||||
|
||||
func TestPoolAcquireWaitTargetAccount(t *testing.T) {
|
||||
pool := newPoolForTest(t, "1")
|
||||
// Exhaust acc1
|
||||
acc1, ok := pool.Acquire("acc1@example.com", nil)
|
||||
if !ok {
|
||||
t.Fatal("expected acquire acc1 success")
|
||||
}
|
||||
|
||||
// Acquire acc2 directly (should succeed since acc2 is free)
|
||||
ctx := context.Background()
|
||||
acc2, ok := pool.AcquireWait(ctx, "acc2@example.com", nil)
|
||||
if !ok {
|
||||
t.Fatal("expected acquire acc2 success via AcquireWait")
|
||||
}
|
||||
if acc2.Identifier() != "acc2@example.com" {
|
||||
t.Fatalf("expected acc2, got %q", acc2.Identifier())
|
||||
}
|
||||
|
||||
pool.Release(acc1.Identifier())
|
||||
pool.Release(acc2.Identifier())
|
||||
}
|
||||
|
||||
func TestPoolMaxQueueSizeOverride(t *testing.T) {
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1")
|
||||
t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "")
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "5")
|
||||
t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "")
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"acc1@example.com","token":"t1"}]}`)
|
||||
pool := NewPool(config.LoadStore())
|
||||
status := pool.Status()
|
||||
if got, ok := status["max_queue_size"].(int); !ok || got != 5 {
|
||||
t.Fatalf("expected max_queue_size=5, got %#v", status["max_queue_size"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolQueueSizeAliasEnv(t *testing.T) {
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1")
|
||||
t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "")
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "")
|
||||
t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "7")
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"acc1@example.com","token":"t1"}]}`)
|
||||
pool := NewPool(config.LoadStore())
|
||||
status := pool.Status()
|
||||
if got, ok := status["max_queue_size"].(int); !ok || got != 7 {
|
||||
t.Fatalf("expected max_queue_size=7, got %#v", status["max_queue_size"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolMultipleAcquireReleaseCycles(t *testing.T) {
|
||||
pool := newSingleAccountPoolForTest(t, "1")
|
||||
for i := 0; i < 10; i++ {
|
||||
acc, ok := pool.Acquire("", nil)
|
||||
if !ok {
|
||||
t.Fatalf("acquire failed at cycle %d", i)
|
||||
}
|
||||
pool.Release(acc.Identifier())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolConcurrentAcquireWait(t *testing.T) {
|
||||
pool := newSingleAccountPoolForTest(t, "1")
|
||||
first, ok := pool.Acquire("", nil)
|
||||
if !ok {
|
||||
t.Fatal("expected first acquire success")
|
||||
}
|
||||
|
||||
const waiters = 3
|
||||
results := make(chan bool, waiters)
|
||||
|
||||
for i := 0; i < waiters; i++ {
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_, ok := pool.AcquireWait(ctx, "", nil)
|
||||
results <- ok
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all to be queued (only 1 can queue)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Release and allow queued requests to proceed
|
||||
pool.Release(first.Identifier())
|
||||
|
||||
successCount := 0
|
||||
timeoutCount := 0
|
||||
for i := 0; i < waiters; i++ {
|
||||
select {
|
||||
case ok := <-results:
|
||||
if ok {
|
||||
successCount++
|
||||
// Release for next waiter
|
||||
pool.Release("acc1@example.com")
|
||||
} else {
|
||||
timeoutCount++
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timed out waiting for results")
|
||||
}
|
||||
}
|
||||
|
||||
// At least 1 should succeed; 2 may fail due to queue limit
|
||||
if successCount < 1 {
|
||||
t.Fatalf("expected at least 1 success, got success=%d timeout=%d", successCount, timeoutCount)
|
||||
}
|
||||
}
|
||||
35
internal/adapter/claude/error_shape_test.go
Normal file
35
internal/adapter/claude/error_shape_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWriteClaudeErrorIncludesUnifiedFields(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
writeClaudeError(rec, http.StatusUnauthorized, "bad token")
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", rec.Code)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("decode body: %v", err)
|
||||
}
|
||||
errObj, _ := body["error"].(map[string]any)
|
||||
if errObj["message"] != "bad token" {
|
||||
t.Fatalf("unexpected message: %v", errObj["message"])
|
||||
}
|
||||
if errObj["type"] != "invalid_request_error" {
|
||||
t.Fatalf("unexpected type: %v", errObj["type"])
|
||||
}
|
||||
if errObj["code"] != "authentication_failed" {
|
||||
t.Fatalf("unexpected code: %v", errObj["code"])
|
||||
}
|
||||
if _, ok := errObj["param"]; !ok {
|
||||
t.Fatal("expected param field")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,6 +43,9 @@ func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
|
||||
}
|
||||
|
||||
func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" {
|
||||
r.Header.Set("anthropic-version", "2023-06-01")
|
||||
}
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
status := http.StatusUnauthorized
|
||||
@@ -50,132 +53,79 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||
if err == auth.ErrNoAccount {
|
||||
status = http.StatusTooManyRequests
|
||||
}
|
||||
writeJSON(w, status, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": detail}})
|
||||
writeClaudeError(w, status, detail)
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "invalid json"}})
|
||||
writeClaudeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
messagesRaw, _ := req["messages"].([]any)
|
||||
if model == "" || len(messagesRaw) == 0 {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "Request must include 'model' and 'messages'."}})
|
||||
norm, err := normalizeClaudeRequest(h.Store, req)
|
||||
if err != nil {
|
||||
writeClaudeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
normalized := normalizeClaudeMessages(messagesRaw)
|
||||
payload := cloneMap(req)
|
||||
payload["messages"] = normalized
|
||||
toolsRequested, _ := req["tools"].([]any)
|
||||
if len(toolsRequested) > 0 && !hasSystemMessage(normalized) {
|
||||
payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalized...)
|
||||
}
|
||||
|
||||
dsPayload := util.ConvertClaudeToDeepSeek(payload, h.Store)
|
||||
dsModel, _ := dsPayload["model"].(string)
|
||||
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel)
|
||||
if !ok {
|
||||
thinkingEnabled = false
|
||||
searchEnabled = false
|
||||
}
|
||||
finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
|
||||
stdReq := norm.Standard
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "invalid token."}})
|
||||
writeClaudeError(w, http.StatusUnauthorized, "invalid token.")
|
||||
return
|
||||
}
|
||||
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get PoW"}})
|
||||
writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW")
|
||||
return
|
||||
}
|
||||
requestPayload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
"parent_message_id": nil,
|
||||
"prompt": finalPrompt,
|
||||
"ref_file_ids": []any{},
|
||||
"thinking_enabled": thinkingEnabled,
|
||||
"search_enabled": searchEnabled,
|
||||
}
|
||||
requestPayload := stdReq.CompletionPayload(sessionID)
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get Claude response."}})
|
||||
writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.")
|
||||
return
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}})
|
||||
writeClaudeError(w, http.StatusInternalServerError, string(body))
|
||||
return
|
||||
}
|
||||
|
||||
toolNames := extractClaudeToolNames(toolsRequested)
|
||||
if util.ToBool(req["stream"]) {
|
||||
h.handleClaudeStreamRealtime(w, r, resp, model, normalized, thinkingEnabled, searchEnabled, toolNames)
|
||||
if stdReq.Stream {
|
||||
h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
|
||||
return
|
||||
}
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
fullText := result.Text
|
||||
fullThinking := result.Thinking
|
||||
detected := util.ParseToolCalls(fullText, toolNames)
|
||||
content := make([]map[string]any, 0, 4)
|
||||
if fullThinking != "" {
|
||||
content = append(content, map[string]any{"type": "thinking", "thinking": fullThinking})
|
||||
}
|
||||
stopReason := "end_turn"
|
||||
if len(detected) > 0 {
|
||||
stopReason = "tool_use"
|
||||
for i, tc := range detected {
|
||||
content = append(content, map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i),
|
||||
"name": tc.Name,
|
||||
"input": tc.Input,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
if fullText == "" {
|
||||
fullText = "抱歉,没有生成有效的响应内容。"
|
||||
}
|
||||
content = append(content, map[string]any{"type": "text", "text": fullText})
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"id": fmt.Sprintf("msg_%d", time.Now().UnixNano()),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": content,
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]any{
|
||||
"input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalized)),
|
||||
"output_tokens": util.EstimateTokens(fullThinking) + util.EstimateTokens(fullText),
|
||||
},
|
||||
})
|
||||
result := sse.CollectStream(resp, stdReq.Thinking, true)
|
||||
respBody := util.BuildClaudeMessageResponse(
|
||||
fmt.Sprintf("msg_%d", time.Now().UnixNano()),
|
||||
stdReq.ResponseModel,
|
||||
norm.NormalizedMessages,
|
||||
result.Thinking,
|
||||
result.Text,
|
||||
stdReq.ToolNames,
|
||||
)
|
||||
writeJSON(w, http.StatusOK, respBody)
|
||||
}
|
||||
|
||||
func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": err.Error()})
|
||||
writeClaudeError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "invalid json"})
|
||||
writeClaudeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
messages, _ := req["messages"].([]any)
|
||||
if model == "" || len(messages) == 0 {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "Request must include 'model' and 'messages'."})
|
||||
writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
|
||||
return
|
||||
}
|
||||
inputTokens := 0
|
||||
@@ -206,7 +156,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}})
|
||||
writeClaudeError(w, http.StatusInternalServerError, string(body))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -241,6 +191,8 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
|
||||
"error": map[string]any{
|
||||
"type": "api_error",
|
||||
"message": msg,
|
||||
"code": "internal_error",
|
||||
"param": nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -492,6 +444,28 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
|
||||
}
|
||||
}
|
||||
|
||||
func writeClaudeError(w http.ResponseWriter, status int, message string) {
|
||||
code := "invalid_request"
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
code = "authentication_failed"
|
||||
case http.StatusTooManyRequests:
|
||||
code = "rate_limit_exceeded"
|
||||
case http.StatusNotFound:
|
||||
code = "not_found"
|
||||
case http.StatusInternalServerError:
|
||||
code = "internal_error"
|
||||
}
|
||||
writeJSON(w, status, map[string]any{
|
||||
"error": map[string]any{
|
||||
"type": "invalid_request_error",
|
||||
"message": message,
|
||||
"code": code,
|
||||
"param": nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func normalizeClaudeMessages(messages []any) []any {
|
||||
out := make([]any, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
|
||||
348
internal/adapter/claude/handler_util_test.go
Normal file
348
internal/adapter/claude/handler_util_test.go
Normal file
@@ -0,0 +1,348 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// ─── normalizeClaudeMessages ─────────────────────────────────────────
|
||||
|
||||
func TestNormalizeClaudeMessagesSimpleString(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{"role": "user", "content": "Hello"},
|
||||
}
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(got))
|
||||
}
|
||||
m := got[0].(map[string]any)
|
||||
if m["content"] != "Hello" {
|
||||
t.Fatalf("expected 'Hello', got %v", m["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesArrayContent(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "line1"},
|
||||
map[string]any{"type": "text", "text": "line2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
m := got[0].(map[string]any)
|
||||
if m["content"] != "line1\nline2" {
|
||||
t.Fatalf("expected joined text, got %q", m["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesToolResult(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "tool_result", "content": "tool output"},
|
||||
},
|
||||
},
|
||||
}
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
m := got[0].(map[string]any)
|
||||
if m["content"] != "tool output" {
|
||||
t.Fatalf("expected 'tool output', got %q", m["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesSkipsNonMap(t *testing.T) {
|
||||
msgs := []any{"not a map", 42}
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected 0 messages for non-map items, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesEmpty(t *testing.T) {
|
||||
got := normalizeClaudeMessages(nil)
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected 0, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesPreservesRole(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{"role": "assistant", "content": "response"},
|
||||
}
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
m := got[0].(map[string]any)
|
||||
if m["role"] != "assistant" {
|
||||
t.Fatalf("expected 'assistant', got %q", m["role"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesMixedContentBlocks(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "Hello"},
|
||||
map[string]any{"type": "image", "source": "data:..."},
|
||||
map[string]any{"type": "text", "text": "World"},
|
||||
},
|
||||
},
|
||||
}
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
m := got[0].(map[string]any)
|
||||
if m["content"] != "Hello\nWorld" {
|
||||
t.Fatalf("expected only text parts joined, got %q", m["content"])
|
||||
}
|
||||
}
|
||||
|
||||
// ─── buildClaudeToolPrompt ───────────────────────────────────────────
|
||||
|
||||
func TestBuildClaudeToolPromptSingleTool(t *testing.T) {
|
||||
tools := []any{
|
||||
map[string]any{
|
||||
"name": "search",
|
||||
"description": "Search the web",
|
||||
"input_schema": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
prompt := buildClaudeToolPrompt(tools)
|
||||
if prompt == "" {
|
||||
t.Fatal("expected non-empty prompt")
|
||||
}
|
||||
// Should contain tool name and description
|
||||
if !containsStr(prompt, "search") {
|
||||
t.Fatalf("expected 'search' in prompt")
|
||||
}
|
||||
if !containsStr(prompt, "Search the web") {
|
||||
t.Fatalf("expected description in prompt")
|
||||
}
|
||||
if !containsStr(prompt, "tool_calls") {
|
||||
t.Fatalf("expected tool_calls instruction in prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClaudeToolPromptMultipleTools(t *testing.T) {
|
||||
tools := []any{
|
||||
map[string]any{"name": "tool1", "description": "desc1"},
|
||||
map[string]any{"name": "tool2", "description": "desc2"},
|
||||
}
|
||||
prompt := buildClaudeToolPrompt(tools)
|
||||
if !containsStr(prompt, "tool1") || !containsStr(prompt, "tool2") {
|
||||
t.Fatalf("expected both tools in prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClaudeToolPromptSkipsNonMap(t *testing.T) {
|
||||
tools := []any{"not a map"}
|
||||
prompt := buildClaudeToolPrompt(tools)
|
||||
if prompt == "" {
|
||||
t.Fatal("expected non-empty prompt even with invalid tools")
|
||||
}
|
||||
// Should still contain the intro and instruction
|
||||
if !containsStr(prompt, "You are Claude") {
|
||||
t.Fatalf("expected intro in prompt")
|
||||
}
|
||||
}
|
||||
|
||||
// ─── hasSystemMessage ────────────────────────────────────────────────
|
||||
|
||||
func TestHasSystemMessageTrue(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{"role": "system", "content": "You are a helper"},
|
||||
map[string]any{"role": "user", "content": "Hi"},
|
||||
}
|
||||
if !hasSystemMessage(msgs) {
|
||||
t.Fatal("expected true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasSystemMessageFalse(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{"role": "user", "content": "Hi"},
|
||||
map[string]any{"role": "assistant", "content": "Hello"},
|
||||
}
|
||||
if hasSystemMessage(msgs) {
|
||||
t.Fatal("expected false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasSystemMessageEmpty(t *testing.T) {
|
||||
if hasSystemMessage(nil) {
|
||||
t.Fatal("expected false for nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasSystemMessageNonMap(t *testing.T) {
|
||||
msgs := []any{"not a map"}
|
||||
if hasSystemMessage(msgs) {
|
||||
t.Fatal("expected false for non-map")
|
||||
}
|
||||
}
|
||||
|
||||
// ─── extractClaudeToolNames ──────────────────────────────────────────
|
||||
|
||||
func TestExtractClaudeToolNamesSingle(t *testing.T) {
|
||||
tools := []any{
|
||||
map[string]any{"name": "search"},
|
||||
}
|
||||
names := extractClaudeToolNames(tools)
|
||||
if len(names) != 1 || names[0] != "search" {
|
||||
t.Fatalf("expected [search], got %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClaudeToolNamesMultiple(t *testing.T) {
|
||||
tools := []any{
|
||||
map[string]any{"name": "search"},
|
||||
map[string]any{"name": "calculate"},
|
||||
}
|
||||
names := extractClaudeToolNames(tools)
|
||||
if len(names) != 2 {
|
||||
t.Fatalf("expected 2 names, got %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClaudeToolNamesSkipsEmptyName(t *testing.T) {
|
||||
tools := []any{
|
||||
map[string]any{"name": ""},
|
||||
map[string]any{"name": "valid"},
|
||||
}
|
||||
names := extractClaudeToolNames(tools)
|
||||
if len(names) != 1 || names[0] != "valid" {
|
||||
t.Fatalf("expected [valid], got %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClaudeToolNamesSkipsNonMap(t *testing.T) {
|
||||
tools := []any{"not a map", 42}
|
||||
names := extractClaudeToolNames(tools)
|
||||
if len(names) != 0 {
|
||||
t.Fatalf("expected 0, got %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClaudeToolNamesNil(t *testing.T) {
|
||||
names := extractClaudeToolNames(nil)
|
||||
if len(names) != 0 {
|
||||
t.Fatalf("expected 0, got %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── toMessageMaps ───────────────────────────────────────────────────
|
||||
|
||||
func TestToMessageMapsNormal(t *testing.T) {
|
||||
input := []any{
|
||||
map[string]any{"role": "user", "content": "Hello"},
|
||||
}
|
||||
got := toMessageMaps(input)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessageMapsNonSlice(t *testing.T) {
|
||||
got := toMessageMaps("not a slice")
|
||||
if got != nil {
|
||||
t.Fatalf("expected nil, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessageMapsSkipsNonMap(t *testing.T) {
|
||||
input := []any{"string", map[string]any{"role": "user"}, 42}
|
||||
got := toMessageMaps(input)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 map, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessageMapsNil(t *testing.T) {
|
||||
got := toMessageMaps(nil)
|
||||
if got != nil {
|
||||
t.Fatalf("expected nil, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── extractMessageContent ──────────────────────────────────────────
|
||||
|
||||
func TestExtractMessageContentString(t *testing.T) {
|
||||
if got := extractMessageContent("hello"); got != "hello" {
|
||||
t.Fatalf("expected 'hello', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractMessageContentArray(t *testing.T) {
|
||||
input := []any{"part1", "part2"}
|
||||
got := extractMessageContent(input)
|
||||
if got != "part1\npart2" {
|
||||
t.Fatalf("expected joined, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractMessageContentOther(t *testing.T) {
|
||||
got := extractMessageContent(42)
|
||||
if got != "42" {
|
||||
t.Fatalf("expected '42', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractMessageContentNil(t *testing.T) {
|
||||
got := extractMessageContent(nil)
|
||||
if got != "<nil>" {
|
||||
t.Fatalf("expected '<nil>', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── cloneMap ────────────────────────────────────────────────────────
|
||||
|
||||
func TestCloneMapBasic(t *testing.T) {
|
||||
original := map[string]any{"a": 1, "b": "hello"}
|
||||
clone := cloneMap(original)
|
||||
original["a"] = 999
|
||||
if clone["a"] != 1 {
|
||||
t.Fatalf("expected 1, got %v", clone["a"])
|
||||
}
|
||||
if clone["b"] != "hello" {
|
||||
t.Fatalf("expected 'hello', got %v", clone["b"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloneMapEmpty(t *testing.T) {
|
||||
clone := cloneMap(map[string]any{})
|
||||
if len(clone) != 0 {
|
||||
t.Fatalf("expected empty, got %v", clone)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloneMapNested(t *testing.T) {
|
||||
// cloneMap is shallow, so nested maps share references
|
||||
inner := map[string]any{"key": "value"}
|
||||
original := map[string]any{"nested": inner}
|
||||
clone := cloneMap(original)
|
||||
// Shallow clone means inner is shared
|
||||
inner["key"] = "modified"
|
||||
cloneNested := clone["nested"].(map[string]any)
|
||||
if cloneNested["key"] != "modified" {
|
||||
t.Fatal("expected shallow clone to share nested references")
|
||||
}
|
||||
}
|
||||
|
||||
// helper
|
||||
func containsStr(s, sub string) bool {
|
||||
return len(s) >= len(sub) && (s == sub || len(s) > 0 && findSubstring(s, sub))
|
||||
}
|
||||
|
||||
func findSubstring(s, sub string) bool {
|
||||
for i := 0; i <= len(s)-len(sub); i++ {
|
||||
if s[i:i+len(sub)] == sub {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
58
internal/adapter/claude/standard_request.go
Normal file
58
internal/adapter/claude/standard_request.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
type claudeNormalizedRequest struct {
|
||||
Standard util.StandardRequest
|
||||
NormalizedMessages []any
|
||||
}
|
||||
|
||||
func normalizeClaudeRequest(store *config.Store, req map[string]any) (claudeNormalizedRequest, error) {
|
||||
model, _ := req["model"].(string)
|
||||
messagesRaw, _ := req["messages"].([]any)
|
||||
if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 {
|
||||
return claudeNormalizedRequest{}, fmt.Errorf("Request must include 'model' and 'messages'.")
|
||||
}
|
||||
if _, ok := req["max_tokens"]; !ok {
|
||||
req["max_tokens"] = 8192
|
||||
}
|
||||
normalizedMessages := normalizeClaudeMessages(messagesRaw)
|
||||
payload := cloneMap(req)
|
||||
payload["messages"] = normalizedMessages
|
||||
toolsRequested, _ := req["tools"].([]any)
|
||||
if len(toolsRequested) > 0 && !hasSystemMessage(normalizedMessages) {
|
||||
payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalizedMessages...)
|
||||
}
|
||||
|
||||
dsPayload := util.ConvertClaudeToDeepSeek(payload, store)
|
||||
dsModel, _ := dsPayload["model"].(string)
|
||||
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel)
|
||||
if !ok {
|
||||
thinkingEnabled = false
|
||||
searchEnabled = false
|
||||
}
|
||||
finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
|
||||
toolNames := extractClaudeToolNames(toolsRequested)
|
||||
|
||||
return claudeNormalizedRequest{
|
||||
Standard: util.StandardRequest{
|
||||
Surface: "anthropic_messages",
|
||||
RequestedModel: strings.TrimSpace(model),
|
||||
ResolvedModel: dsModel,
|
||||
ResponseModel: strings.TrimSpace(model),
|
||||
Messages: payload["messages"].([]any),
|
||||
FinalPrompt: finalPrompt,
|
||||
ToolNames: toolNames,
|
||||
Stream: util.ToBool(req["stream"]),
|
||||
Thinking: thinkingEnabled,
|
||||
Search: searchEnabled,
|
||||
},
|
||||
NormalizedMessages: normalizedMessages,
|
||||
}, nil
|
||||
}
|
||||
38
internal/adapter/claude/standard_request_test.go
Normal file
38
internal/adapter/claude/standard_request_test.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func TestNormalizeClaudeRequest(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{}`)
|
||||
store := config.LoadStore()
|
||||
req := map[string]any{
|
||||
"model": "claude-opus-4-6",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
"stream": true,
|
||||
"tools": []any{
|
||||
map[string]any{"name": "search", "description": "Search"},
|
||||
},
|
||||
}
|
||||
norm, err := normalizeClaudeRequest(store, req)
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
if norm.Standard.ResolvedModel == "" {
|
||||
t.Fatalf("expected resolved model")
|
||||
}
|
||||
if !norm.Standard.Stream {
|
||||
t.Fatalf("expected stream=true")
|
||||
}
|
||||
if len(norm.Standard.ToolNames) == 0 {
|
||||
t.Fatalf("expected tool names")
|
||||
}
|
||||
if norm.Standard.FinalPrompt == "" {
|
||||
t.Fatalf("expected non-empty final prompt")
|
||||
}
|
||||
}
|
||||
138
internal/adapter/openai/embeddings_handler.go
Normal file
138
internal/adapter/openai/embeddings_handler.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func (h *Handler) Embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
status := http.StatusUnauthorized
|
||||
detail := err.Error()
|
||||
if err == auth.ErrNoAccount {
|
||||
status = http.StatusTooManyRequests
|
||||
}
|
||||
writeOpenAIError(w, status, detail)
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model'.")
|
||||
return
|
||||
}
|
||||
if _, ok := config.ResolveModel(h.Store, model); !ok {
|
||||
writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model))
|
||||
return
|
||||
}
|
||||
|
||||
inputs := extractEmbeddingInputs(req["input"])
|
||||
if len(inputs) == 0 {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "Request must include non-empty 'input'.")
|
||||
return
|
||||
}
|
||||
|
||||
provider := ""
|
||||
if h.Store != nil {
|
||||
provider = strings.ToLower(strings.TrimSpace(h.Store.EmbeddingsProvider()))
|
||||
}
|
||||
if provider == "" {
|
||||
writeOpenAIError(w, http.StatusNotImplemented, "Embeddings provider is not configured. Set embeddings.provider in config.")
|
||||
return
|
||||
}
|
||||
switch provider {
|
||||
case "mock", "deterministic", "builtin":
|
||||
// supported local deterministic provider
|
||||
default:
|
||||
writeOpenAIError(w, http.StatusNotImplemented, fmt.Sprintf("Embeddings provider '%s' is not supported.", provider))
|
||||
return
|
||||
}
|
||||
|
||||
data := make([]map[string]any, 0, len(inputs))
|
||||
totalTokens := 0
|
||||
for i, input := range inputs {
|
||||
totalTokens += util.EstimateTokens(input)
|
||||
data = append(data, map[string]any{
|
||||
"object": "embedding",
|
||||
"index": i,
|
||||
"embedding": deterministicEmbedding(input),
|
||||
})
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"object": "list",
|
||||
"data": data,
|
||||
"model": model,
|
||||
"usage": map[string]any{
|
||||
"prompt_tokens": totalTokens,
|
||||
"total_tokens": totalTokens,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func extractEmbeddingInputs(raw any) []string {
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
s := strings.TrimSpace(v)
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{s}
|
||||
case []any:
|
||||
out := make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
switch iv := item.(type) {
|
||||
case string:
|
||||
s := strings.TrimSpace(iv)
|
||||
if s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
case []any:
|
||||
// Token array input support: convert to stable string form.
|
||||
out = append(out, fmt.Sprintf("%v", iv))
|
||||
default:
|
||||
s := strings.TrimSpace(fmt.Sprintf("%v", iv))
|
||||
if s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func deterministicEmbedding(input string) []float64 {
|
||||
// Keep response shape stable without external dependencies.
|
||||
const dims = 64
|
||||
out := make([]float64, dims)
|
||||
seed := sha256.Sum256([]byte(input))
|
||||
buf := seed[:]
|
||||
for i := 0; i < dims; i++ {
|
||||
if len(buf) < 4 {
|
||||
next := sha256.Sum256(buf)
|
||||
buf = next[:]
|
||||
}
|
||||
v := binary.BigEndian.Uint32(buf[:4])
|
||||
buf = buf[4:]
|
||||
// map [0, 2^32) -> [-1, 1]
|
||||
out[i] = (float64(v)/2147483647.5 - 1.0)
|
||||
}
|
||||
return out
|
||||
}
|
||||
96
internal/adapter/openai/embeddings_route_test.go
Normal file
96
internal/adapter/openai/embeddings_route_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/account"
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func newResolverWithConfigJSON(t *testing.T, cfgJSON string) (*config.Store, *auth.Resolver) {
|
||||
t.Helper()
|
||||
t.Setenv("DS2API_CONFIG_JSON", cfgJSON)
|
||||
store := config.LoadStore()
|
||||
pool := account.NewPool(store)
|
||||
resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
|
||||
return "unused", nil
|
||||
})
|
||||
return store, resolver
|
||||
}
|
||||
|
||||
func TestEmbeddingsRouteContract(t *testing.T) {
|
||||
store, resolver := newResolverWithConfigJSON(t, `{"embeddings":{"provider":"deterministic"}}`)
|
||||
h := &Handler{Store: store, Auth: resolver}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
t.Run("unauthorized", func(t *testing.T) {
|
||||
body := bytes.NewBufferString(`{"model":"gpt-4o","input":"hello"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ok", func(t *testing.T) {
|
||||
body := bytes.NewBufferString(`{"model":"gpt-4o","input":["a","b"]}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body)
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("decode response failed: %v", err)
|
||||
}
|
||||
if out["object"] != "list" {
|
||||
t.Fatalf("unexpected object: %#v", out["object"])
|
||||
}
|
||||
data, _ := out["data"].([]any)
|
||||
if len(data) != 2 {
|
||||
t.Fatalf("expected 2 embeddings, got %d", len(data))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmbeddingsRouteProviderMissing(t *testing.T) {
|
||||
store, resolver := newResolverWithConfigJSON(t, `{}`)
|
||||
h := &Handler{Store: store, Auth: resolver}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
body := bytes.NewBufferString(`{"model":"gpt-4o","input":"hello"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body)
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusNotImplemented {
|
||||
t.Fatalf("expected 501, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("decode response failed: %v", err)
|
||||
}
|
||||
errObj, _ := out["error"].(map[string]any)
|
||||
if _, ok := errObj["code"]; !ok {
|
||||
t.Fatalf("expected error.code in response: %#v", out)
|
||||
}
|
||||
if _, ok := errObj["param"]; !ok {
|
||||
t.Fatalf("expected error.param in response: %#v", out)
|
||||
}
|
||||
}
|
||||
35
internal/adapter/openai/error_shape_test.go
Normal file
35
internal/adapter/openai/error_shape_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWriteOpenAIErrorIncludesUnifiedFields(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
writeOpenAIError(rec, http.StatusBadRequest, "invalid input")
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", rec.Code)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("decode body: %v", err)
|
||||
}
|
||||
errObj, _ := body["error"].(map[string]any)
|
||||
if errObj["message"] != "invalid input" {
|
||||
t.Fatalf("unexpected message: %v", errObj["message"])
|
||||
}
|
||||
if errObj["type"] != "invalid_request_error" {
|
||||
t.Fatalf("unexpected type: %v", errObj["type"])
|
||||
}
|
||||
if errObj["code"] != "invalid_request" {
|
||||
t.Fatalf("unexpected code: %v", errObj["code"])
|
||||
}
|
||||
if _, ok := errObj["param"]; !ok {
|
||||
t.Fatal("expected param field")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
@@ -30,6 +31,8 @@ type Handler struct {
|
||||
|
||||
leaseMu sync.Mutex
|
||||
streamLeases map[string]streamLease
|
||||
responsesMu sync.Mutex
|
||||
responses *responseStore
|
||||
}
|
||||
|
||||
type streamLease struct {
|
||||
@@ -39,13 +42,27 @@ type streamLease struct {
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
r.Get("/v1/models", h.ListModels)
|
||||
r.Get("/v1/models/{model_id}", h.GetModel)
|
||||
r.Post("/v1/chat/completions", h.ChatCompletions)
|
||||
r.Post("/v1/responses", h.Responses)
|
||||
r.Get("/v1/responses/{response_id}", h.GetResponseByID)
|
||||
r.Post("/v1/embeddings", h.Embeddings)
|
||||
}
|
||||
|
||||
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
|
||||
writeJSON(w, http.StatusOK, config.OpenAIModelsResponse())
|
||||
}
|
||||
|
||||
func (h *Handler) GetModel(w http.ResponseWriter, r *http.Request) {
|
||||
modelID := strings.TrimSpace(chi.URLParam(r, "model_id"))
|
||||
model, ok := config.OpenAIModelByID(h.Store, modelID)
|
||||
if !ok {
|
||||
writeOpenAIError(w, http.StatusNotFound, "Model not found.")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, model)
|
||||
}
|
||||
|
||||
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
if isVercelStreamReleaseRequest(r) {
|
||||
h.handleVercelStreamRelease(w, r)
|
||||
@@ -74,24 +91,11 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
messagesRaw, _ := req["messages"].([]any)
|
||||
if model == "" || len(messagesRaw) == 0 {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
|
||||
stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model)
|
||||
if !ok {
|
||||
writeOpenAIError(w, http.StatusServiceUnavailable, fmt.Sprintf("Model '%s' is not available.", model))
|
||||
return
|
||||
}
|
||||
|
||||
messages := normalizeMessages(messagesRaw)
|
||||
toolNames := []string{}
|
||||
if tools, ok := req["tools"].([]any); ok && len(tools) > 0 {
|
||||
messages, toolNames = injectToolPrompt(messages, tools)
|
||||
}
|
||||
finalPrompt := util.MessagesPrepare(messages)
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
@@ -107,27 +111,20 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
|
||||
return
|
||||
}
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
"parent_message_id": nil,
|
||||
"prompt": finalPrompt,
|
||||
"ref_file_ids": []any{},
|
||||
"thinking_enabled": thinkingEnabled,
|
||||
"search_enabled": searchEnabled,
|
||||
}
|
||||
payload := stdReq.CompletionPayload(sessionID)
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
|
||||
return
|
||||
}
|
||||
if util.ToBool(req["stream"]) {
|
||||
h.handleStream(w, r, resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
|
||||
if stdReq.Stream {
|
||||
h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
|
||||
return
|
||||
}
|
||||
h.handleNonStream(w, r.Context(), resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
|
||||
h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -139,36 +136,8 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re
|
||||
|
||||
finalThinking := result.Thinking
|
||||
finalText := result.Text
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
finishReason := "stop"
|
||||
messageObj := map[string]any{"role": "assistant", "content": finalText}
|
||||
if thinkingEnabled && finalThinking != "" {
|
||||
messageObj["reasoning_content"] = finalThinking
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected)
|
||||
messageObj["content"] = nil
|
||||
}
|
||||
promptTokens := util.EstimateTokens(finalPrompt)
|
||||
reasoningTokens := util.EstimateTokens(finalThinking)
|
||||
completionTokens := util.EstimateTokens(finalText)
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}},
|
||||
"usage": map[string]any{
|
||||
"prompt_tokens": promptTokens,
|
||||
"completion_tokens": reasoningTokens + completionTokens,
|
||||
"total_tokens": promptTokens + reasoningTokens + completionTokens,
|
||||
"completion_tokens_details": map[string]any{
|
||||
"reasoning_tokens": reasoningTokens,
|
||||
},
|
||||
},
|
||||
})
|
||||
respBody := util.BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
|
||||
writeJSON(w, http.StatusOK, respBody)
|
||||
}
|
||||
|
||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
@@ -190,9 +159,11 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
|
||||
created := time.Now().Unix()
|
||||
firstChunkSent := false
|
||||
bufferToolContent := len(toolNames) > 0
|
||||
bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
|
||||
emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
|
||||
var toolSieve toolStreamSieveState
|
||||
toolCallsEmitted := false
|
||||
streamToolCallIDs := map[int]string{}
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
@@ -235,13 +206,13 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
delta["role"] = "assistant"
|
||||
firstChunkSent = true
|
||||
}
|
||||
sendChunk(map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]any{{"delta": delta, "index": 0}},
|
||||
})
|
||||
sendChunk(util.BuildOpenAIChatStreamChunk(
|
||||
completionID,
|
||||
created,
|
||||
model,
|
||||
[]map[string]any{util.BuildOpenAIChatStreamDeltaChoice(0, delta)},
|
||||
nil,
|
||||
))
|
||||
} else if bufferToolContent {
|
||||
for _, evt := range flushToolSieve(&toolSieve, toolNames) {
|
||||
if evt.Content == "" {
|
||||
@@ -254,36 +225,25 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
delta["role"] = "assistant"
|
||||
firstChunkSent = true
|
||||
}
|
||||
sendChunk(map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]any{{"delta": delta, "index": 0}},
|
||||
})
|
||||
sendChunk(util.BuildOpenAIChatStreamChunk(
|
||||
completionID,
|
||||
created,
|
||||
model,
|
||||
[]map[string]any{util.BuildOpenAIChatStreamDeltaChoice(0, delta)},
|
||||
nil,
|
||||
))
|
||||
}
|
||||
}
|
||||
if len(detected) > 0 || toolCallsEmitted {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
promptTokens := util.EstimateTokens(finalPrompt)
|
||||
reasoningTokens := util.EstimateTokens(finalThinking)
|
||||
completionTokens := util.EstimateTokens(finalText)
|
||||
sendChunk(map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]any{{"delta": map[string]any{}, "index": 0, "finish_reason": finishReason}},
|
||||
"usage": map[string]any{
|
||||
"prompt_tokens": promptTokens,
|
||||
"completion_tokens": reasoningTokens + completionTokens,
|
||||
"total_tokens": promptTokens + reasoningTokens + completionTokens,
|
||||
"completion_tokens_details": map[string]any{
|
||||
"reasoning_tokens": reasoningTokens,
|
||||
},
|
||||
},
|
||||
})
|
||||
sendChunk(util.BuildOpenAIChatStreamChunk(
|
||||
completionID,
|
||||
created,
|
||||
model,
|
||||
[]map[string]any{util.BuildOpenAIChatStreamFinishChoice(0, finishReason)},
|
||||
util.BuildOpenAIChatUsage(finalPrompt, finalThinking, finalText),
|
||||
))
|
||||
sendDone()
|
||||
}
|
||||
|
||||
@@ -357,6 +317,21 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
// Keep thinking delta only frame.
|
||||
}
|
||||
for _, evt := range events {
|
||||
if len(evt.ToolCallDeltas) > 0 {
|
||||
if !emitEarlyToolDeltas {
|
||||
continue
|
||||
}
|
||||
toolCallsEmitted = true
|
||||
tcDelta := map[string]any{
|
||||
"tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs),
|
||||
}
|
||||
if !firstChunkSent {
|
||||
tcDelta["role"] = "assistant"
|
||||
firstChunkSent = true
|
||||
}
|
||||
newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, tcDelta))
|
||||
continue
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
toolCallsEmitted = true
|
||||
tcDelta := map[string]any{
|
||||
@@ -366,10 +341,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
tcDelta["role"] = "assistant"
|
||||
firstChunkSent = true
|
||||
}
|
||||
newChoices = append(newChoices, map[string]any{
|
||||
"delta": tcDelta,
|
||||
"index": 0,
|
||||
})
|
||||
newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, tcDelta))
|
||||
continue
|
||||
}
|
||||
if evt.Content != "" {
|
||||
@@ -380,42 +352,22 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
contentDelta["role"] = "assistant"
|
||||
firstChunkSent = true
|
||||
}
|
||||
newChoices = append(newChoices, map[string]any{
|
||||
"delta": contentDelta,
|
||||
"index": 0,
|
||||
})
|
||||
newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, contentDelta))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(delta) > 0 {
|
||||
newChoices = append(newChoices, map[string]any{"delta": delta, "index": 0})
|
||||
newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, delta))
|
||||
}
|
||||
}
|
||||
if len(newChoices) > 0 {
|
||||
sendChunk(map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": newChoices,
|
||||
})
|
||||
sendChunk(util.BuildOpenAIChatStreamChunk(completionID, created, model, newChoices, nil))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeMessages(raw []any) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
m, ok := item.(map[string]any)
|
||||
if ok {
|
||||
out = append(out, m)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) {
|
||||
toolSchemas := make([]string, 0, len(tools))
|
||||
names := make([]string, 0, len(tools))
|
||||
@@ -444,7 +396,7 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any,
|
||||
if len(toolSchemas) == 0 {
|
||||
return messages, names
|
||||
}
|
||||
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nIMPORTANT: If calling tools, output ONLY the JSON. The response must start with { and end with }"
|
||||
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error."
|
||||
|
||||
for i := range messages {
|
||||
if messages[i]["role"] == "system" {
|
||||
@@ -457,11 +409,47 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any,
|
||||
return messages, names
|
||||
}
|
||||
|
||||
func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]string) []map[string]any {
|
||||
if len(deltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]map[string]any, 0, len(deltas))
|
||||
for _, d := range deltas {
|
||||
if d.Name == "" && d.Arguments == "" {
|
||||
continue
|
||||
}
|
||||
callID, ok := ids[d.Index]
|
||||
if !ok || callID == "" {
|
||||
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
ids[d.Index] = callID
|
||||
}
|
||||
item := map[string]any{
|
||||
"index": d.Index,
|
||||
"id": callID,
|
||||
"type": "function",
|
||||
}
|
||||
fn := map[string]any{}
|
||||
if d.Name != "" {
|
||||
fn["name"] = d.Name
|
||||
}
|
||||
if d.Arguments != "" {
|
||||
fn["arguments"] = d.Arguments
|
||||
}
|
||||
if len(fn) > 0 {
|
||||
item["function"] = fn
|
||||
}
|
||||
out = append(out, item)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func writeOpenAIError(w http.ResponseWriter, status int, message string) {
|
||||
writeJSON(w, status, map[string]any{
|
||||
"error": map[string]any{
|
||||
"message": message,
|
||||
"type": openAIErrorType(status),
|
||||
"code": openAIErrorCode(status),
|
||||
"param": nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -485,3 +473,47 @@ func openAIErrorType(status int) string {
|
||||
return "invalid_request_error"
|
||||
}
|
||||
}
|
||||
|
||||
func openAIErrorCode(status int) string {
|
||||
switch status {
|
||||
case http.StatusBadRequest:
|
||||
return "invalid_request"
|
||||
case http.StatusUnauthorized:
|
||||
return "authentication_failed"
|
||||
case http.StatusForbidden:
|
||||
return "forbidden"
|
||||
case http.StatusTooManyRequests:
|
||||
return "rate_limit_exceeded"
|
||||
case http.StatusNotFound:
|
||||
return "not_found"
|
||||
case http.StatusServiceUnavailable:
|
||||
return "service_unavailable"
|
||||
default:
|
||||
if status >= 500 {
|
||||
return "internal_error"
|
||||
}
|
||||
return "invalid_request"
|
||||
}
|
||||
}
|
||||
|
||||
func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) {
|
||||
for k, v := range collectOpenAIChatPassThrough(req) {
|
||||
payload[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) toolcallFeatureMatchEnabled() bool {
|
||||
if h == nil || h.Store == nil {
|
||||
return true
|
||||
}
|
||||
mode := strings.TrimSpace(strings.ToLower(h.Store.ToolcallMode()))
|
||||
return mode == "" || mode == "feature_match"
|
||||
}
|
||||
|
||||
func (h *Handler) toolcallEarlyEmitHighConfidence() bool {
|
||||
if h == nil || h.Store == nil {
|
||||
return true
|
||||
}
|
||||
level := strings.TrimSpace(strings.ToLower(h.Store.ToolcallEarlyEmitConfidence()))
|
||||
return level == "" || level == "high"
|
||||
}
|
||||
|
||||
@@ -100,6 +100,26 @@ func streamFinishReason(frames []map[string]any) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func streamToolCallArgumentChunks(frames []map[string]any) []string {
|
||||
out := make([]string, 0, 4)
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
toolCalls, _ := delta["tool_calls"].([]any)
|
||||
for _, tc := range toolCalls {
|
||||
tcm, _ := tc.(map[string]any)
|
||||
fn, _ := tcm["function"].(map[string]any)
|
||||
if args, ok := fn["arguments"].(string); ok && args != "" {
|
||||
out = append(out, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
@@ -108,7 +128,7 @@ func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) {
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
@@ -141,7 +161,7 @@ func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) {
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, false, []string{"search"})
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
@@ -169,7 +189,7 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) {
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
@@ -190,6 +210,66 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamEmbeddedToolCallExampleIntercepted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"下面是示例:"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: {"p":"response/content","v":"请勿执行。"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2c", "deepseek-chat", "prompt", false, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
toolCalls, _ := msg["tool_calls"].([]any)
|
||||
if len(toolCalls) == 0 {
|
||||
t.Fatalf("expected tool_calls field for embedded example: %#v", msg["tool_calls"])
|
||||
}
|
||||
if msg["content"] != nil {
|
||||
t.Fatalf("expected content nil when tool_calls detected, got %#v", msg["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamFencedToolCallExampleNotIntercepted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
"data: {\"p\":\"response/content\",\"v\":\"```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"search\\\",\\\"input\\\":{\\\"q\\\":\\\"go\\\"}}]}\\n```\"}",
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2d", "deepseek-chat", "prompt", false, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
if _, ok := msg["tool_calls"]; ok {
|
||||
t.Fatalf("did not expect tool_calls field for fenced example: %#v", msg["tool_calls"])
|
||||
}
|
||||
content, _ := msg["content"].(string)
|
||||
if !strings.Contains(content, "```json") || !strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected fenced tool example to pass through as text, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
@@ -377,9 +457,9 @@ func TestHandleStreamToolsPlainTextStreamsBeforeFinish(t *testing.T) {
|
||||
func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"前置正文A。"}`,
|
||||
`data: {"p":"response/content","v":"下面是示例:"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: {"p":"response/content","v":"后置正文B。"}`,
|
||||
`data: {"p":"response/content","v":"请勿执行。"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -392,10 +472,7 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta in mixed stream, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in mixed stream: %s", rec.Body.String())
|
||||
t.Fatalf("expected tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -409,9 +486,95 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "前置正文A。") || !strings.Contains(got, "后置正文B。") {
|
||||
if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") {
|
||||
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(got), `"tool_calls"`) {
|
||||
t.Fatalf("expected no raw tool_calls json leak in content, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"我将调用工具。"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7b", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "我将调用工具。") {
|
||||
t.Fatalf("expected leading text to keep streaming, got=%q", got)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("unexpected raw tool json leak, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}接下来我会继续说明。"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7c", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "接下来我会继续说明。") {
|
||||
t.Fatalf("expected trailing plain text to be preserved, got=%q", got)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("unexpected raw tool json leak, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
@@ -495,16 +658,16 @@ func TestHandleStreamInvalidToolJSONDoesNotLeakRawObject(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
got := strings.ToLower(content.String())
|
||||
if strings.Contains(got, "tool_calls") {
|
||||
t.Fatalf("unexpected raw tool_calls leak in content: %q", content.String())
|
||||
}
|
||||
if !strings.Contains(content.String(), "前置正文D。") || !strings.Contains(content.String(), "后置正文E。") {
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "前置正文D。") || !strings.Contains(got, "后置正文E。") {
|
||||
t.Fatalf("expected pre/post plain text to remain, got=%q", content.String())
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected invalid embedded tool-like json to pass through as text, got=%q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamIncompleteCapturedToolJSONDoesNotLeakOnFinalize(t *testing.T) {
|
||||
func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
|
||||
@@ -533,7 +696,42 @@ func TestHandleStreamIncompleteCapturedToolJSONDoesNotLeakOnFinalize(t *testing.
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.Contains(strings.ToLower(content.String()), "tool_calls") || strings.Contains(content.String(), "{") {
|
||||
t.Fatalf("unexpected incomplete tool json leak in content: %q", content.String())
|
||||
if !strings.Contains(strings.ToLower(content.String()), "tool_calls") || !strings.Contains(content.String(), "{") {
|
||||
t.Fatalf("expected incomplete capture to flush as plain text instead of stalling, got=%q", content.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`,
|
||||
`data: {"p":"response/content","v":"lang\",\"page\":1}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid11", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
}
|
||||
argChunks := streamToolCallArgumentChunks(frames)
|
||||
if len(argChunks) < 2 {
|
||||
t.Fatalf("expected incremental arguments chunks, got=%v body=%s", argChunks, rec.Body.String())
|
||||
}
|
||||
joined := strings.Join(argChunks, "")
|
||||
if !strings.Contains(joined, `"q":"golang"`) || !strings.Contains(joined, `"page":1`) {
|
||||
t.Fatalf("unexpected merged arguments stream: %q", joined)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
192
internal/adapter/openai/message_normalize.go
Normal file
192
internal/adapter/openai/message_normalize.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
msg, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(asString(msg["role"])))
|
||||
switch role {
|
||||
case "assistant":
|
||||
content := normalizeOpenAIContentForPrompt(msg["content"])
|
||||
toolCalls := formatAssistantToolCallsForPrompt(msg)
|
||||
combined := joinNonEmpty(content, toolCalls)
|
||||
if combined == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"role": "assistant",
|
||||
"content": combined,
|
||||
})
|
||||
case "tool", "function":
|
||||
out = append(out, map[string]any{
|
||||
"role": "user",
|
||||
"content": formatToolResultForPrompt(msg),
|
||||
})
|
||||
case "user", "system":
|
||||
out = append(out, map[string]any{
|
||||
"role": role,
|
||||
"content": normalizeOpenAIContentForPrompt(msg["content"]),
|
||||
})
|
||||
default:
|
||||
content := normalizeOpenAIContentForPrompt(msg["content"])
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"role": role,
|
||||
"content": content,
|
||||
})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func formatAssistantToolCallsForPrompt(msg map[string]any) string {
|
||||
entries := make([]string, 0)
|
||||
if calls, ok := msg["tool_calls"].([]any); ok {
|
||||
for i, item := range calls {
|
||||
call, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
id := strings.TrimSpace(asString(call["id"]))
|
||||
if id == "" {
|
||||
id = fmt.Sprintf("call_%d", i+1)
|
||||
}
|
||||
name := strings.TrimSpace(asString(call["name"]))
|
||||
args := ""
|
||||
|
||||
if fn, ok := call["function"].(map[string]any); ok {
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(asString(fn["name"]))
|
||||
}
|
||||
args = normalizeOpenAIArgumentsForPrompt(fn["arguments"])
|
||||
}
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
}
|
||||
if args == "" {
|
||||
args = normalizeOpenAIArgumentsForPrompt(call["arguments"])
|
||||
}
|
||||
if args == "" {
|
||||
args = normalizeOpenAIArgumentsForPrompt(call["input"])
|
||||
}
|
||||
if args == "" {
|
||||
args = "{}"
|
||||
}
|
||||
entries = append(entries, fmt.Sprintf("Tool call:\n- tool_call_id: %s\n- function.name: %s\n- function.arguments: %s", id, name, args))
|
||||
}
|
||||
}
|
||||
|
||||
if legacy, ok := msg["function_call"].(map[string]any); ok {
|
||||
name := strings.TrimSpace(asString(legacy["name"]))
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
}
|
||||
args := normalizeOpenAIArgumentsForPrompt(legacy["arguments"])
|
||||
if args == "" {
|
||||
args = "{}"
|
||||
}
|
||||
entries = append(entries, fmt.Sprintf("Tool call:\n- tool_call_id: call_legacy\n- function.name: %s\n- function.arguments: %s", name, args))
|
||||
}
|
||||
|
||||
return strings.Join(entries, "\n\n")
|
||||
}
|
||||
|
||||
func formatToolResultForPrompt(msg map[string]any) string {
|
||||
toolCallID := strings.TrimSpace(asString(msg["tool_call_id"]))
|
||||
if toolCallID == "" {
|
||||
toolCallID = strings.TrimSpace(asString(msg["id"]))
|
||||
}
|
||||
if toolCallID == "" {
|
||||
toolCallID = "unknown"
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(asString(msg["name"]))
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
}
|
||||
|
||||
content := normalizeOpenAIContentForPrompt(msg["content"])
|
||||
if content == "" {
|
||||
content = "null"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Tool result:\n- tool_call_id: %s\n- name: %s\n- content: %s", toolCallID, name, content)
|
||||
}
|
||||
|
||||
func normalizeOpenAIContentForPrompt(v any) string {
|
||||
switch x := v.(type) {
|
||||
case string:
|
||||
return x
|
||||
case []any:
|
||||
parts := make([]string, 0, len(x))
|
||||
for _, item := range x {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
t := strings.ToLower(strings.TrimSpace(asString(m["type"])))
|
||||
if t != "text" && t != "output_text" && t != "input_text" {
|
||||
continue
|
||||
}
|
||||
if text := asString(m["text"]); text != "" {
|
||||
parts = append(parts, text)
|
||||
continue
|
||||
}
|
||||
if text := asString(m["content"]); text != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
default:
|
||||
return marshalToPromptString(v)
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeOpenAIArgumentsForPrompt(v any) string {
|
||||
switch x := v.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(x)
|
||||
default:
|
||||
return marshalToPromptString(v)
|
||||
}
|
||||
}
|
||||
|
||||
// marshalToPromptString JSON-encodes v; when marshalling fails (e.g. channels
// or cyclic values) it falls back to the trimmed fmt representation.
func marshalToPromptString(v any) string {
	if b, err := json.Marshal(v); err == nil {
		return string(b)
	}
	return strings.TrimSpace(fmt.Sprintf("%v", v))
}
|
||||
|
||||
// asString returns v when it is a string, otherwise "".
func asString(v any) string {
	s, _ := v.(string)
	return s
}
|
||||
|
||||
// joinNonEmpty joins the arguments with blank lines, skipping parts that are
// empty or whitespace-only. Kept parts are joined verbatim (untrimmed).
func joinNonEmpty(parts ...string) string {
	kept := make([]string, 0, len(parts))
	for _, part := range parts {
		if strings.TrimSpace(part) != "" {
			kept = append(kept, part)
		}
	}
	return strings.Join(kept, "\n\n")
}
|
||||
121
internal/adapter/openai/message_normalize_test.go
Normal file
121
internal/adapter/openai/message_normalize_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// Verifies that an assistant message carrying tool_calls (with nil content)
// and the following tool-result message are serialized into prompt text, and
// that those semantics survive util.MessagesPrepare.
func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *testing.T) {
	raw := []any{
		map[string]any{"role": "system", "content": "You are helpful"},
		map[string]any{"role": "user", "content": "查北京天气"},
		map[string]any{
			"role":    "assistant",
			"content": nil,
			"tool_calls": []any{
				map[string]any{
					"id":   "call_1",
					"type": "function",
					"function": map[string]any{
						"name":      "get_weather",
						"arguments": "{\"city\":\"beijing\"}",
					},
				},
			},
		},
		map[string]any{
			"role":         "tool",
			"tool_call_id": "call_1",
			"name":         "get_weather",
			"content":      "{\"temp\":18}",
		},
	}

	normalized := normalizeOpenAIMessagesForPrompt(raw)
	if len(normalized) != 4 {
		t.Fatalf("expected 4 normalized messages, got %d", len(normalized))
	}
	assistantContent, _ := normalized[2]["content"].(string)
	if !strings.Contains(assistantContent, "tool_call_id: call_1") ||
		!strings.Contains(assistantContent, "function.name: get_weather") ||
		!strings.Contains(assistantContent, "function.arguments: {\"city\":\"beijing\"}") {
		t.Fatalf("assistant tool call not serialized correctly: %q", assistantContent)
	}
	toolContent, _ := normalized[3]["content"].(string)
	if !strings.Contains(toolContent, "Tool result:") || !strings.Contains(toolContent, "name: get_weather") {
		t.Fatalf("tool result not serialized correctly: %q", toolContent)
	}

	// The serialized tool semantics must also survive final prompt assembly.
	prompt := util.MessagesPrepare(normalized)
	if !strings.Contains(prompt, "tool_call_id: call_1") || !strings.Contains(prompt, "Tool result:") {
		t.Fatalf("expected prompt to include tool call + result semantics: %q", prompt)
	}
}
|
||||
|
||||
// Verifies that a map-valued tool-result content is JSON-serialized into the
// normalized message rather than dropped.
func TestNormalizeOpenAIMessagesForPrompt_ToolObjectContentPreserved(t *testing.T) {
	raw := []any{
		map[string]any{
			"role":         "tool",
			"tool_call_id": "call_2",
			"name":         "get_weather",
			"content": map[string]any{
				"temp":      18,
				"condition": "sunny",
			},
		},
	}

	normalized := normalizeOpenAIMessagesForPrompt(raw)
	got, _ := normalized[0]["content"].(string)
	if !strings.Contains(got, `"temp":18`) || !strings.Contains(got, `"condition":"sunny"`) {
		t.Fatalf("expected serialized object in tool content, got %q", got)
	}
}
|
||||
|
||||
// Verifies that textual content blocks in a tool result are joined with
// newlines and that non-textual blocks (e.g. image_url) are skipped.
func TestNormalizeOpenAIMessagesForPrompt_ToolArrayBlocksJoined(t *testing.T) {
	raw := []any{
		map[string]any{
			"role":         "tool",
			"tool_call_id": "call_3",
			"name":         "read_file",
			"content": []any{
				map[string]any{"type": "input_text", "text": "line-1"},
				map[string]any{"type": "output_text", "text": "line-2"},
				map[string]any{"type": "image_url", "image_url": "https://example.com/a.png"},
			},
		},
	}

	normalized := normalizeOpenAIMessagesForPrompt(raw)
	got, _ := normalized[0]["content"].(string)
	if !strings.Contains(got, "line-1\nline-2") {
		t.Fatalf("expected joined text blocks, got %q", got)
	}
}
|
||||
|
||||
// Verifies that the legacy "function" role is treated like "tool": the
// message is mapped to a user turn with serialized result content.
func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) {
	raw := []any{
		map[string]any{
			"role":         "function",
			"tool_call_id": "call_4",
			"name":         "legacy_tool",
			"content": map[string]any{
				"ok": true,
			},
		},
	}

	normalized := normalizeOpenAIMessagesForPrompt(raw)
	if len(normalized) != 1 {
		t.Fatalf("expected one normalized message, got %d", len(normalized))
	}
	if normalized[0]["role"] != "user" {
		t.Fatalf("expected function role mapped to user, got %#v", normalized[0]["role"])
	}
	got, _ := normalized[0]["content"].(string)
	if !strings.Contains(got, "name: legacy_tool") || !strings.Contains(got, `"ok":true`) {
		t.Fatalf("unexpected normalized function-role content: %q", got)
	}
}
|
||||
46
internal/adapter/openai/models_route_test.go
Normal file
46
internal/adapter/openai/models_route_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
// Verifies GET /v1/models/{id} returns 200 for both a native model id and a
// configured alias.
func TestGetModelRouteDirectAndAlias(t *testing.T) {
	h := &Handler{}
	r := chi.NewRouter()
	RegisterRoutes(r, h)

	t.Run("direct", func(t *testing.T) {
		req := httptest.NewRequest(http.MethodGet, "/v1/models/deepseek-chat", nil)
		rec := httptest.NewRecorder()
		r.ServeHTTP(rec, req)
		if rec.Code != http.StatusOK {
			t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
		}
	})

	t.Run("alias", func(t *testing.T) {
		req := httptest.NewRequest(http.MethodGet, "/v1/models/gpt-4.1", nil)
		rec := httptest.NewRecorder()
		r.ServeHTTP(rec, req)
		if rec.Code != http.StatusOK {
			t.Fatalf("expected 200 for alias, got %d body=%s", rec.Code, rec.Body.String())
		}
	})
}
|
||||
|
||||
// Verifies an unknown model id yields 404.
func TestGetModelRouteNotFound(t *testing.T) {
	h := &Handler{}
	r := chi.NewRouter()
	RegisterRoutes(r, h)

	req := httptest.NewRequest(http.MethodGet, "/v1/models/not-exists", nil)
	rec := httptest.NewRecorder()
	r.ServeHTTP(rec, req)
	if rec.Code != http.StatusNotFound {
		t.Fatalf("expected 404, got %d body=%s", rec.Code, rec.Body.String())
	}
}
|
||||
12
internal/adapter/openai/prompt_build.go
Normal file
12
internal/adapter/openai/prompt_build.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package openai
|
||||
|
||||
import "ds2api/internal/util"
|
||||
|
||||
func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any) (string, []string) {
|
||||
messages := normalizeOpenAIMessagesForPrompt(messagesRaw)
|
||||
toolNames := []string{}
|
||||
if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 {
|
||||
messages, toolNames = injectToolPrompt(messages, tools)
|
||||
}
|
||||
return util.MessagesPrepare(messages), toolNames
|
||||
}
|
||||
80
internal/adapter/openai/prompt_build_test.go
Normal file
80
internal/adapter/openai/prompt_build_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// End-to-end prompt build: tool definitions plus a prior tool-call roundtrip
// (assistant tool_calls + tool result) must all surface in the final prompt.
func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *testing.T) {
	messages := []any{
		map[string]any{"role": "user", "content": "查北京天气"},
		map[string]any{
			"role": "assistant",
			"tool_calls": []any{
				map[string]any{
					"id": "call_1",
					"function": map[string]any{
						"name":      "get_weather",
						"arguments": "{\"city\":\"beijing\"}",
					},
				},
			},
		},
		map[string]any{
			"role":         "tool",
			"tool_call_id": "call_1",
			"name":         "get_weather",
			"content":      map[string]any{"temp": 18, "condition": "sunny"},
		},
	}
	tools := []any{
		map[string]any{
			"type": "function",
			"function": map[string]any{
				"name":        "get_weather",
				"description": "Get weather",
				"parameters": map[string]any{
					"type": "object",
				},
			},
		},
	}

	finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools)
	if len(toolNames) != 1 || toolNames[0] != "get_weather" {
		t.Fatalf("unexpected tool names: %#v", toolNames)
	}
	if !strings.Contains(finalPrompt, "tool_call_id: call_1") ||
		!strings.Contains(finalPrompt, "function.name: get_weather") ||
		!strings.Contains(finalPrompt, "Tool result:") ||
		!strings.Contains(finalPrompt, `"condition":"sunny"`) {
		t.Fatalf("handler finalPrompt missing tool roundtrip semantics: %q", finalPrompt)
	}
}
|
||||
|
||||
// Verifies the injected tool prompt carries the final-answer and retry-guard
// instruction sentences verbatim.
func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *testing.T) {
	messages := []any{
		map[string]any{"role": "system", "content": "You are helpful"},
		map[string]any{"role": "user", "content": "请调用工具"},
	}
	tools := []any{
		map[string]any{
			"type": "function",
			"function": map[string]any{
				"name":        "search",
				"description": "search docs",
				"parameters": map[string]any{
					"type": "object",
				},
			},
		},
	}

	finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools)
	if !strings.Contains(finalPrompt, "After receiving a tool result, you MUST use it to produce the final answer.") {
		t.Fatalf("vercel prepare finalPrompt missing final-answer instruction: %q", finalPrompt)
	}
	if !strings.Contains(finalPrompt, "Only call another tool when the previous result is missing required data or returned an error.") {
		t.Fatalf("vercel prepare finalPrompt missing retry guard instruction: %q", finalPrompt)
	}
}
|
||||
109
internal/adapter/openai/response_store.go
Normal file
109
internal/adapter/openai/response_store.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
)
|
||||
|
||||
// storedResponse is one cached /v1/responses object together with the caller
// that created it and its absolute expiry deadline.
type storedResponse struct {
	Owner     string         // caller id that owns this response
	Value     map[string]any // shallow-cloned response object
	ExpiresAt time.Time      // entries past this instant are swept lazily
}

// responseStore is an in-memory, mutex-guarded TTL cache of response objects
// keyed by (owner, response id).
type responseStore struct {
	mu    sync.Mutex
	ttl   time.Duration
	items map[string]storedResponse
}
|
||||
|
||||
func newResponseStore(ttl time.Duration) *responseStore {
|
||||
if ttl <= 0 {
|
||||
ttl = 15 * time.Minute
|
||||
}
|
||||
return &responseStore{
|
||||
ttl: ttl,
|
||||
items: make(map[string]storedResponse),
|
||||
}
|
||||
}
|
||||
|
||||
// responseStoreKey namespaces a response id by its owner using a NUL byte
// separator, which cannot appear in either component, so distinct
// (owner, id) pairs never collide.
func responseStoreKey(owner, id string) string {
	return owner + "\x00" + id
}
|
||||
|
||||
func responseStoreOwner(a *auth.RequestAuth) string {
|
||||
if a == nil {
|
||||
return ""
|
||||
}
|
||||
return a.CallerID
|
||||
}
|
||||
|
||||
func (s *responseStore) put(owner, id string, value map[string]any) {
|
||||
if s == nil || owner == "" || id == "" || value == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sweepLocked(now)
|
||||
s.items[responseStoreKey(owner, id)] = storedResponse{
|
||||
Owner: owner,
|
||||
Value: cloneAnyMap(value),
|
||||
ExpiresAt: now.Add(s.ttl),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responseStore) get(owner, id string) (map[string]any, bool) {
|
||||
if s == nil || owner == "" || id == "" {
|
||||
return nil, false
|
||||
}
|
||||
now := time.Now()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sweepLocked(now)
|
||||
item, ok := s.items[responseStoreKey(owner, id)]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if item.Owner != owner {
|
||||
return nil, false
|
||||
}
|
||||
return cloneAnyMap(item.Value), true
|
||||
}
|
||||
|
||||
func (s *responseStore) sweepLocked(now time.Time) {
|
||||
for k, v := range s.items {
|
||||
if now.After(v.ExpiresAt) {
|
||||
delete(s.items, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cloneAnyMap returns a shallow copy of in; nil stays nil. Nested values are
// shared with the original.
func cloneAnyMap(in map[string]any) map[string]any {
	if in == nil {
		return nil
	}
	out := make(map[string]any, len(in))
	for key, val := range in {
		out[key] = val
	}
	return out
}
|
||||
|
||||
// getResponseStore lazily initializes the handler's response cache under
// responsesMu, sizing its TTL from config when a Store is wired in (the
// 15-minute default applies otherwise; newResponseStore also guards against a
// non-positive configured TTL). Returns nil only on a nil handler.
func (h *Handler) getResponseStore() *responseStore {
	if h == nil {
		return nil
	}
	h.responsesMu.Lock()
	defer h.responsesMu.Unlock()
	if h.responses == nil {
		ttl := 15 * time.Minute // default when no config store is present
		if h.Store != nil {
			ttl = time.Duration(h.Store.ResponsesStoreTTLSeconds()) * time.Second
		}
		h.responses = newResponseStore(ttl)
	}
	return h.responses
}
|
||||
73
internal/adapter/openai/responses_embeddings_test.go
Normal file
73
internal/adapter/openai/responses_embeddings_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Verifies a plain string input becomes a single user message.
func TestNormalizeResponsesInputAsMessagesString(t *testing.T) {
	msgs := normalizeResponsesInputAsMessages("hello")
	if len(msgs) != 1 {
		t.Fatalf("expected one message, got %d", len(msgs))
	}
	m, _ := msgs[0].(map[string]any)
	if m["role"] != "user" || m["content"] != "hello" {
		t.Fatalf("unexpected message: %#v", m)
	}
}
|
||||
|
||||
// Verifies "instructions" is prepended as a system message ahead of the
// normalized input.
func TestResponsesMessagesFromRequestWithInstructions(t *testing.T) {
	req := map[string]any{
		"model":        "gpt-4.1",
		"input":        "ping",
		"instructions": "system text",
	}
	msgs := responsesMessagesFromRequest(req)
	if len(msgs) != 2 {
		t.Fatalf("expected two messages, got %d", len(msgs))
	}
	sys, _ := msgs[0].(map[string]any)
	if sys["role"] != "system" {
		t.Fatalf("unexpected first message: %#v", sys)
	}
}
|
||||
|
||||
// Verifies string inputs are extracted in order.
func TestExtractEmbeddingInputs(t *testing.T) {
	got := extractEmbeddingInputs([]any{"a", "b"})
	if len(got) != 2 || got[0] != "a" || got[1] != "b" {
		t.Fatalf("unexpected inputs: %#v", got)
	}
}
|
||||
|
||||
// Verifies the fallback embedding is 64-dimensional and deterministic for
// identical input.
func TestDeterministicEmbeddingStable(t *testing.T) {
	a := deterministicEmbedding("hello")
	b := deterministicEmbedding("hello")
	if len(a) != 64 || len(b) != 64 {
		t.Fatalf("expected 64 dims, got %d and %d", len(a), len(b))
	}
	for i := range a {
		if a[i] != b[i] {
			t.Fatalf("expected stable embedding at %d: %v != %v", i, a[i], b[i])
		}
	}
}
|
||||
|
||||
// Verifies a stored response is retrievable by its owner before expiry.
func TestResponseStorePutGet(t *testing.T) {
	st := newResponseStore(100 * time.Millisecond)
	st.put("owner_1", "resp_1", map[string]any{"id": "resp_1"})
	got, ok := st.get("owner_1", "resp_1")
	if !ok {
		t.Fatal("expected stored response")
	}
	if got["id"] != "resp_1" {
		t.Fatalf("unexpected response payload: %#v", got)
	}
}
|
||||
|
||||
// Verifies one tenant cannot read another tenant's stored response.
func TestResponseStoreTenantIsolation(t *testing.T) {
	st := newResponseStore(100 * time.Millisecond)
	st.put("owner_a", "resp_1", map[string]any{"id": "resp_1"})
	if _, ok := st.get("owner_b", "resp_1"); ok {
		t.Fatal("expected owner_b to be isolated from owner_a response")
	}
}
|
||||
308
internal/adapter/openai/responses_handler.go
Normal file
308
internal/adapter/openai/responses_handler.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/sse"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// GetResponseByID serves GET /v1/responses/{response_id}: it authenticates
// the caller, then returns that caller's cached response object, or 404 when
// it is absent, expired, or belongs to another tenant.
func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) {
	a, err := h.Auth.DetermineCaller(r)
	if err != nil {
		writeOpenAIError(w, http.StatusUnauthorized, err.Error())
		return
	}

	id := strings.TrimSpace(chi.URLParam(r, "response_id"))
	if id == "" {
		writeOpenAIError(w, http.StatusBadRequest, "response_id is required.")
		return
	}
	owner := responseStoreOwner(a)
	if owner == "" {
		writeOpenAIError(w, http.StatusUnauthorized, "unauthorized")
		return
	}
	st := h.getResponseStore()
	item, ok := st.get(owner, id)
	if !ok {
		// Cross-tenant lookups are indistinguishable from missing ids by design.
		writeOpenAIError(w, http.StatusNotFound, "Response not found.")
		return
	}
	writeJSON(w, http.StatusOK, item)
}
|
||||
|
||||
// Responses serves POST /v1/responses: it resolves and leases an account,
// normalizes the request, performs the upstream session + PoW + completion
// sequence, then dispatches to the streaming or non-streaming writer.
func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
	a, err := h.Auth.Determine(r)
	if err != nil {
		status := http.StatusUnauthorized
		detail := err.Error()
		if err == auth.ErrNoAccount {
			// Pool exhaustion maps to 429 so clients back off instead of re-authing.
			status = http.StatusTooManyRequests
		}
		writeOpenAIError(w, status, detail)
		return
	}
	defer h.Auth.Release(a)
	r = r.WithContext(auth.WithAuth(r.Context(), a))
	owner := responseStoreOwner(a)
	if owner == "" {
		writeOpenAIError(w, http.StatusUnauthorized, "unauthorized")
		return
	}

	var req map[string]any
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeOpenAIError(w, http.StatusBadRequest, "invalid json")
		return
	}
	stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req)
	if err != nil {
		writeOpenAIError(w, http.StatusBadRequest, err.Error())
		return
	}

	// Upstream call sequence (each with 3 retries): session, PoW, completion.
	sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
	if err != nil {
		if a.UseConfigToken {
			writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
		} else {
			writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
		}
		return
	}
	pow, err := h.DS.GetPow(r.Context(), a, 3)
	if err != nil {
		writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
		return
	}
	payload := stdReq.CompletionPayload(sessionID)
	resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
	if err != nil {
		writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
		return
	}

	responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
	if stdReq.Stream {
		h.handleResponsesStream(w, r, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
		return
	}
	h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
}
|
||||
|
||||
// handleResponsesNonStream drains the upstream SSE stream to completion,
// builds a single OpenAI response object, caches it for later GET
// /v1/responses/{id}, and writes it back. Always closes resp.Body.
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		// Propagate the upstream error body verbatim under its status code.
		body, _ := io.ReadAll(resp.Body)
		writeOpenAIError(w, resp.StatusCode, strings.TrimSpace(string(body)))
		return
	}
	result := sse.CollectStream(resp, thinkingEnabled, true)
	responseObj := util.BuildOpenAIResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames)
	h.getResponseStore().put(owner, responseID, responseObj)
	writeJSON(w, http.StatusOK, responseObj)
}
|
||||
|
||||
// handleResponsesStream relays the upstream completion stream as OpenAI
// Responses SSE events (response.created / *.delta / response.completed),
// optionally sieving tool-call markup out of the text stream, and caches the
// final response object for later GET. Always closes resp.Body.
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		writeOpenAIError(w, resp.StatusCode, strings.TrimSpace(string(body)))
		return
	}
	// Standard SSE headers; X-Accel-Buffering disables reverse-proxy buffering.
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache, no-transform")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no")
	rc := http.NewResponseController(w)
	canFlush := rc.Flush() == nil // probe once; some writers cannot flush

	// sendEvent writes one named SSE event and flushes when supported.
	sendEvent := func(event string, payload map[string]any) {
		b, _ := json.Marshal(payload)
		_, _ = w.Write([]byte("event: " + event + "\n"))
		_, _ = w.Write([]byte("data: "))
		_, _ = w.Write(b)
		_, _ = w.Write([]byte("\n\n"))
		if canFlush {
			_ = rc.Flush()
		}
	}

	sendEvent("response.created", util.BuildOpenAIResponsesCreatedPayload(responseID, model))

	initialType := "text"
	if thinkingEnabled {
		initialType = "thinking"
	}
	parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType)
	// Tool-call sieving engages only when tools were advertised and the
	// feature flag allows it; early deltas additionally need their own flag.
	bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
	emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
	var sieve toolStreamSieveState
	thinking := strings.Builder{}
	text := strings.Builder{}
	toolCallsEmitted := false
	streamToolCallIDs := map[int]string{}

	// finalize flushes any buffered sieve output, caches the completed
	// response object, and terminates the SSE stream with [DONE].
	finalize := func() {
		finalThinking := thinking.String()
		finalText := text.String()
		if bufferToolContent {
			for _, evt := range flushToolSieve(&sieve, toolNames) {
				if evt.Content != "" {
					sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, evt.Content))
				}
				if len(evt.ToolCalls) > 0 {
					toolCallsEmitted = true
					sendEvent("response.output_tool_call.done", util.BuildOpenAIResponsesToolCallDonePayload(responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls)))
				}
			}
		}
		obj := util.BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, finalText, toolNames)
		if toolCallsEmitted {
			obj["status"] = "completed"
		}
		h.getResponseStore().put(owner, responseID, obj)
		sendEvent("response.completed", util.BuildOpenAIResponsesCompletedPayload(obj))
		_, _ = w.Write([]byte("data: [DONE]\n\n"))
		if canFlush {
			_ = rc.Flush()
		}
	}

	for {
		select {
		case <-r.Context().Done():
			// Client went away; drop the stream without finalizing.
			return
		case parsed, ok := <-parsedLines:
			if !ok {
				// Pump channel closed: wait for the pump goroutine, then emit the tail.
				_ = <-done
				finalize()
				return
			}
			if !parsed.Parsed {
				continue
			}
			if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
				finalize()
				return
			}
			for _, p := range parsed.Parts {
				if p.Text == "" {
					continue
				}
				// Drop citation markers injected by search mode from visible text.
				if p.Type != "thinking" && searchEnabled && sse.IsCitation(p.Text) {
					continue
				}
				if p.Type == "thinking" {
					if !thinkingEnabled {
						continue
					}
					thinking.WriteString(p.Text)
					sendEvent("response.reasoning.delta", util.BuildOpenAIResponsesReasoningDeltaPayload(responseID, p.Text))
					continue
				}
				text.WriteString(p.Text)
				if !bufferToolContent {
					sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, p.Text))
					continue
				}
				for _, evt := range processToolSieveChunk(&sieve, p.Text, toolNames) {
					if evt.Content != "" {
						sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, evt.Content))
					}
					if len(evt.ToolCallDeltas) > 0 {
						if !emitEarlyToolDeltas {
							continue
						}
						toolCallsEmitted = true
						sendEvent("response.output_tool_call.delta", util.BuildOpenAIResponsesToolCallDeltaPayload(responseID, formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs)))
					}
					if len(evt.ToolCalls) > 0 {
						toolCallsEmitted = true
						sendEvent("response.output_tool_call.done", util.BuildOpenAIResponsesToolCallDonePayload(responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls)))
					}
				}
			}
		}
	}
}
|
||||
|
||||
func responsesMessagesFromRequest(req map[string]any) []any {
|
||||
if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 {
|
||||
return prependInstructionMessage(msgs, req["instructions"])
|
||||
}
|
||||
if rawInput, ok := req["input"]; ok {
|
||||
if msgs := normalizeResponsesInputAsMessages(rawInput); len(msgs) > 0 {
|
||||
return prependInstructionMessage(msgs, req["instructions"])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// prependInstructionMessage puts instructions (when it is a non-blank string)
// in front of messages as a system message; otherwise messages are returned
// unchanged.
func prependInstructionMessage(messages []any, instructions any) []any {
	sys, _ := instructions.(string)
	sys = strings.TrimSpace(sys)
	if sys == "" {
		return messages
	}
	result := make([]any, 0, len(messages)+1)
	result = append(result, map[string]any{"role": "system", "content": sys})
	return append(result, messages...)
}
|
||||
|
||||
// normalizeResponsesInputAsMessages converts the flexible "input" field of a
// /v1/responses request into chat messages. A non-blank string becomes one
// user message; an array is kept verbatim when its items are already
// role-shaped, otherwise its textual parts are joined into a single user
// message; a map contributes its "text" or string "content" field. Returns
// nil when nothing usable is found.
func normalizeResponsesInputAsMessages(input any) []any {
	switch typed := input.(type) {
	case string:
		if strings.TrimSpace(typed) == "" {
			return nil
		}
		return []any{map[string]any{"role": "user", "content": typed}}
	case []any:
		if len(typed) == 0 {
			return nil
		}
		// Role-shaped items pass through untouched.
		if head, ok := typed[0].(map[string]any); ok {
			if _, hasRole := head["role"]; hasRole {
				return typed
			}
		}
		segments := make([]string, 0, len(typed))
		for _, elem := range typed {
			if part, ok := elem.(map[string]any); ok {
				kind, _ := part["type"].(string)
				if strings.EqualFold(strings.TrimSpace(kind), "input_text") {
					if text, _ := part["text"].(string); strings.TrimSpace(text) != "" {
						segments = append(segments, text)
						continue
					}
				}
			}
			// Fallback: stringify whatever the item is, skipping blanks.
			if rendered := strings.TrimSpace(fmt.Sprintf("%v", elem)); rendered != "" {
				segments = append(segments, rendered)
			}
		}
		if len(segments) == 0 {
			return nil
		}
		return []any{map[string]any{"role": "user", "content": strings.Join(segments, "\n")}}
	case map[string]any:
		if text, _ := typed["text"].(string); strings.TrimSpace(text) != "" {
			return []any{map[string]any{"role": "user", "content": text}}
		}
		if content, ok := typed["content"].(string); ok && strings.TrimSpace(content) != "" {
			return []any{map[string]any{"role": "user", "content": content}}
		}
	}
	return nil
}
|
||||
176
internal/adapter/openai/responses_route_test.go
Normal file
176
internal/adapter/openai/responses_route_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/account"
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
// newDirectTokenResolver builds a config store with no managed keys or
// accounts, so any bearer token resolves as a direct upstream token.
func newDirectTokenResolver(t *testing.T) (*config.Store, *auth.Resolver) {
	t.Helper()
	t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`)
	store := config.LoadStore()
	pool := account.NewPool(store)
	resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
		return "unused", nil // login callback is never exercised in these tests
	})
	return store, resolver
}
|
||||
|
||||
// newManagedKeyResolver builds a config store with one managed API key backed
// by one account, with the pool limited to a single in-flight request and no
// queueing so exhaustion paths can be tested deterministically.
func newManagedKeyResolver(t *testing.T) (*config.Store, *auth.Resolver) {
	t.Helper()
	t.Setenv("DS2API_CONFIG_JSON", `{
  "keys":["managed-key"],
  "accounts":[{"email":"acc@example.com","password":"pwd","token":"account-token"}]
}`)
	t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1")
	t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "0")
	store := config.LoadStore()
	pool := account.NewPool(store)
	resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
		return "unused", nil // login callback is never exercised in these tests
	})
	return store, resolver
}
|
||||
|
||||
// authForToken resolves a RequestAuth for the given bearer token, failing the
// test on error.
func authForToken(t *testing.T, resolver *auth.Resolver, token string) *auth.RequestAuth {
	t.Helper()
	req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
	req.Header.Set("Authorization", "Bearer "+token)
	a, err := resolver.Determine(req)
	if err != nil {
		t.Fatalf("determine auth failed: %v", err)
	}
	return a
}
|
||||
|
||||
func TestGetResponseByIDRequiresAuthAndIsTenantIsolated(t *testing.T) {
|
||||
store, resolver := newDirectTokenResolver(t)
|
||||
h := &Handler{Store: store, Auth: resolver}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
ownerA := responseStoreOwner(authForToken(t, resolver, "token-a"))
|
||||
h.getResponseStore().put(ownerA, "resp_test", map[string]any{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
})
|
||||
|
||||
t.Run("unauthorized", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cross-tenant-not-found", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
|
||||
req.Header.Set("Authorization", "Bearer token-b")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Fatalf("expected 404, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("same-tenant-ok", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
|
||||
req.Header.Set("Authorization", "Bearer token-a")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("decode body failed: %v", err)
|
||||
}
|
||||
if body["id"] != "resp_test" {
|
||||
t.Fatalf("unexpected body: %#v", body)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestResponsesRouteValidationContract(t *testing.T) {
|
||||
store, resolver := newDirectTokenResolver(t)
|
||||
h := &Handler{Store: store, Auth: resolver}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
}{
|
||||
{name: "missing_model", body: `{"input":"hello"}`},
|
||||
{name: "missing_input_and_messages", body: `{"model":"gpt-4o"}`},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewBufferString(tc.body))
|
||||
req.Header.Set("Authorization", "Bearer token-a")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("decode response failed: %v", err)
|
||||
}
|
||||
errObj, _ := out["error"].(map[string]any)
|
||||
if _, ok := errObj["code"]; !ok {
|
||||
t.Fatalf("expected error.code: %#v", out)
|
||||
}
|
||||
if _, ok := errObj["param"]; !ok {
|
||||
t.Fatalf("expected error.param: %#v", out)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetResponseByIDManagedKeySkipsAccountPoolPressure(t *testing.T) {
|
||||
store, resolver := newManagedKeyResolver(t)
|
||||
h := &Handler{Store: store, Auth: resolver}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
ownerReq := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
|
||||
ownerReq.Header.Set("Authorization", "Bearer managed-key")
|
||||
ownerAuth, err := resolver.DetermineCaller(ownerReq)
|
||||
if err != nil {
|
||||
t.Fatalf("determine caller failed: %v", err)
|
||||
}
|
||||
owner := responseStoreOwner(ownerAuth)
|
||||
h.getResponseStore().put(owner, "resp_test", map[string]any{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
})
|
||||
|
||||
occupyReq := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
occupyReq.Header.Set("Authorization", "Bearer managed-key")
|
||||
occupied, err := resolver.Determine(occupyReq)
|
||||
if err != nil {
|
||||
t.Fatalf("expected first acquire to succeed: %v", err)
|
||||
}
|
||||
defer resolver.Release(occupied)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
|
||||
req.Header.Set("Authorization", "Bearer managed-key")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 under pool pressure, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
}
|
||||
122
internal/adapter/openai/responses_stream_test.go
Normal file
122
internal/adapter/openai/responses_stream_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
rawToolJSON := `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`
|
||||
streamBody := sseLine(rawToolJSON) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
|
||||
|
||||
completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
|
||||
if !ok {
|
||||
t.Fatalf("expected response.completed event, body=%s", rec.Body.String())
|
||||
}
|
||||
responseObj, _ := completed["response"].(map[string]any)
|
||||
outputText, _ := responseObj["output_text"].(string)
|
||||
if outputText != "" {
|
||||
t.Fatalf("expected empty output_text for tool_calls response, got output_text=%q", outputText)
|
||||
}
|
||||
output, _ := responseObj["output"].([]any)
|
||||
if len(output) == 0 {
|
||||
t.Fatalf("expected structured output entries, got %#v", responseObj["output"])
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "tool_calls" {
|
||||
t.Fatalf("expected first output type tool_calls, got %#v", first["type"])
|
||||
}
|
||||
toolCalls, _ := first["tool_calls"].([]any)
|
||||
if len(toolCalls) == 0 {
|
||||
t.Fatalf("expected at least one tool_call in output, got %#v", first["tool_calls"])
|
||||
}
|
||||
call0, _ := toolCalls[0].(map[string]any)
|
||||
if call0["name"] != "read_file" {
|
||||
t.Fatalf("unexpected tool call name: %#v", call0["name"])
|
||||
}
|
||||
if strings.Contains(outputText, `"tool_calls"`) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in output_text: %q", outputText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamIncompleteTailNotDuplicatedInCompletedOutputText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
tail := `{"tool_calls":[{"name":"read_file","input":`
|
||||
streamBody := sseLine("Before ") + sseLine(tail) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"})
|
||||
|
||||
completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
|
||||
if !ok {
|
||||
t.Fatalf("expected response.completed event, body=%s", rec.Body.String())
|
||||
}
|
||||
responseObj, _ := completed["response"].(map[string]any)
|
||||
outputText, _ := responseObj["output_text"].(string)
|
||||
if strings.Count(outputText, tail) > 1 {
|
||||
t.Fatalf("expected incomplete tail not to be duplicated, got output_text=%q", outputText)
|
||||
}
|
||||
}
|
||||
|
||||
func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
|
||||
scanner := bufio.NewScanner(strings.NewReader(body))
|
||||
matched := false
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
evt := strings.TrimSpace(strings.TrimPrefix(line, "event: "))
|
||||
matched = evt == targetEvent
|
||||
continue
|
||||
}
|
||||
if !matched || !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
raw := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
|
||||
if raw == "" || raw == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return payload, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
104
internal/adapter/openai/standard_request.go
Normal file
104
internal/adapter/openai/standard_request.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func normalizeOpenAIChatRequest(store *config.Store, req map[string]any) (util.StandardRequest, error) {
|
||||
model, _ := req["model"].(string)
|
||||
messagesRaw, _ := req["messages"].([]any)
|
||||
if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 {
|
||||
return util.StandardRequest{}, fmt.Errorf("Request must include 'model' and 'messages'.")
|
||||
}
|
||||
resolvedModel, ok := config.ResolveModel(store, model)
|
||||
if !ok {
|
||||
return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", model)
|
||||
}
|
||||
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
|
||||
responseModel := strings.TrimSpace(model)
|
||||
if responseModel == "" {
|
||||
responseModel = resolvedModel
|
||||
}
|
||||
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
|
||||
passThrough := collectOpenAIChatPassThrough(req)
|
||||
|
||||
return util.StandardRequest{
|
||||
Surface: "openai_chat",
|
||||
RequestedModel: strings.TrimSpace(model),
|
||||
ResolvedModel: resolvedModel,
|
||||
ResponseModel: responseModel,
|
||||
Messages: messagesRaw,
|
||||
FinalPrompt: finalPrompt,
|
||||
ToolNames: toolNames,
|
||||
Stream: util.ToBool(req["stream"]),
|
||||
Thinking: thinkingEnabled,
|
||||
Search: searchEnabled,
|
||||
PassThrough: passThrough,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeOpenAIResponsesRequest(store *config.Store, req map[string]any) (util.StandardRequest, error) {
|
||||
model, _ := req["model"].(string)
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return util.StandardRequest{}, fmt.Errorf("Request must include 'model'.")
|
||||
}
|
||||
resolvedModel, ok := config.ResolveModel(store, model)
|
||||
if !ok {
|
||||
return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", model)
|
||||
}
|
||||
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
|
||||
|
||||
// Keep width-control as an explicit policy hook even if current default is true.
|
||||
allowWideInput := true
|
||||
if store != nil {
|
||||
allowWideInput = store.CompatWideInputStrictOutput()
|
||||
}
|
||||
var messagesRaw []any
|
||||
if allowWideInput {
|
||||
messagesRaw = responsesMessagesFromRequest(req)
|
||||
} else if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 {
|
||||
messagesRaw = msgs
|
||||
}
|
||||
if len(messagesRaw) == 0 {
|
||||
return util.StandardRequest{}, fmt.Errorf("Request must include 'input' or 'messages'.")
|
||||
}
|
||||
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
|
||||
passThrough := collectOpenAIChatPassThrough(req)
|
||||
|
||||
return util.StandardRequest{
|
||||
Surface: "openai_responses",
|
||||
RequestedModel: model,
|
||||
ResolvedModel: resolvedModel,
|
||||
ResponseModel: model,
|
||||
Messages: messagesRaw,
|
||||
FinalPrompt: finalPrompt,
|
||||
ToolNames: toolNames,
|
||||
Stream: util.ToBool(req["stream"]),
|
||||
Thinking: thinkingEnabled,
|
||||
Search: searchEnabled,
|
||||
PassThrough: passThrough,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func collectOpenAIChatPassThrough(req map[string]any) map[string]any {
|
||||
out := map[string]any{}
|
||||
for _, k := range []string{
|
||||
"temperature",
|
||||
"top_p",
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"stop",
|
||||
} {
|
||||
if v, ok := req[k]; ok {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
60
internal/adapter/openai/standard_request_test.go
Normal file
60
internal/adapter/openai/standard_request_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func newEmptyStoreForNormalizeTest(t *testing.T) *config.Store {
|
||||
t.Helper()
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{}`)
|
||||
return config.LoadStore()
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIChatRequest(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
"model": "gpt-5-codex",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
"temperature": 0.3,
|
||||
"stream": true,
|
||||
}
|
||||
n, err := normalizeOpenAIChatRequest(store, req)
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
if n.ResolvedModel != "deepseek-reasoner" {
|
||||
t.Fatalf("unexpected resolved model: %s", n.ResolvedModel)
|
||||
}
|
||||
if !n.Stream {
|
||||
t.Fatalf("expected stream=true")
|
||||
}
|
||||
if _, ok := n.PassThrough["temperature"]; !ok {
|
||||
t.Fatalf("expected temperature passthrough")
|
||||
}
|
||||
if n.FinalPrompt == "" {
|
||||
t.Fatalf("expected non-empty final prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": "ping",
|
||||
"instructions": "system",
|
||||
}
|
||||
n, err := normalizeOpenAIResponsesRequest(store, req)
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
if n.ResolvedModel != "deepseek-chat" {
|
||||
t.Fatalf("unexpected resolved model: %s", n.ResolvedModel)
|
||||
}
|
||||
if len(n.Messages) != 2 {
|
||||
t.Fatalf("expected 2 normalized messages, got %d", len(n.Messages))
|
||||
}
|
||||
}
|
||||
@@ -7,14 +7,40 @@ import (
|
||||
)
|
||||
|
||||
type toolStreamSieveState struct {
|
||||
pending strings.Builder
|
||||
capture strings.Builder
|
||||
capturing bool
|
||||
pending strings.Builder
|
||||
capture strings.Builder
|
||||
capturing bool
|
||||
recentTextTail string
|
||||
toolNameSent bool
|
||||
toolName string
|
||||
toolArgsStart int
|
||||
toolArgsSent int
|
||||
toolArgsString bool
|
||||
toolArgsDone bool
|
||||
}
|
||||
|
||||
type toolStreamEvent struct {
|
||||
Content string
|
||||
ToolCalls []util.ParsedToolCall
|
||||
Content string
|
||||
ToolCalls []util.ParsedToolCall
|
||||
ToolCallDeltas []toolCallDelta
|
||||
}
|
||||
|
||||
type toolCallDelta struct {
|
||||
Index int
|
||||
Name string
|
||||
Arguments string
|
||||
}
|
||||
|
||||
const toolSieveCaptureLimit = 8 * 1024
|
||||
const toolSieveContextTailLimit = 256
|
||||
|
||||
func (s *toolStreamSieveState) resetIncrementalToolState() {
|
||||
s.toolNameSent = false
|
||||
s.toolName = ""
|
||||
s.toolArgsStart = -1
|
||||
s.toolArgsSent = -1
|
||||
s.toolArgsString = false
|
||||
s.toolArgsDone = false
|
||||
}
|
||||
|
||||
func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent {
|
||||
@@ -32,13 +58,27 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
state.capture.WriteString(state.pending.String())
|
||||
state.pending.Reset()
|
||||
}
|
||||
prefix, calls, suffix, ready := consumeToolCapture(state.capture.String(), toolNames)
|
||||
if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 {
|
||||
events = append(events, toolStreamEvent{ToolCallDeltas: deltas})
|
||||
}
|
||||
prefix, calls, suffix, ready := consumeToolCapture(state, toolNames)
|
||||
if !ready {
|
||||
if state.capture.Len() > toolSieveCaptureLimit {
|
||||
content := state.capture.String()
|
||||
state.capture.Reset()
|
||||
state.capturing = false
|
||||
state.resetIncrementalToolState()
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
state.capture.Reset()
|
||||
state.capturing = false
|
||||
state.resetIncrementalToolState()
|
||||
if prefix != "" {
|
||||
state.noteText(prefix)
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
if len(calls) > 0 {
|
||||
@@ -58,11 +98,13 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
if start >= 0 {
|
||||
prefix := pending[:start]
|
||||
if prefix != "" {
|
||||
state.noteText(prefix)
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
state.pending.Reset()
|
||||
state.capture.WriteString(pending[start:])
|
||||
state.capturing = true
|
||||
state.resetIncrementalToolState()
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -72,6 +114,7 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
}
|
||||
state.pending.Reset()
|
||||
state.pending.WriteString(hold)
|
||||
state.noteText(safe)
|
||||
events = append(events, toolStreamEvent{Content: safe})
|
||||
}
|
||||
|
||||
@@ -84,25 +127,34 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea
|
||||
}
|
||||
events := processToolSieveChunk(state, "", toolNames)
|
||||
if state.capturing {
|
||||
consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state.capture.String(), toolNames)
|
||||
consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames)
|
||||
if ready {
|
||||
if consumedPrefix != "" {
|
||||
state.noteText(consumedPrefix)
|
||||
events = append(events, toolStreamEvent{Content: consumedPrefix})
|
||||
}
|
||||
if len(consumedCalls) > 0 {
|
||||
events = append(events, toolStreamEvent{ToolCalls: consumedCalls})
|
||||
}
|
||||
if consumedSuffix != "" {
|
||||
state.noteText(consumedSuffix)
|
||||
events = append(events, toolStreamEvent{Content: consumedSuffix})
|
||||
}
|
||||
} else {
|
||||
// Incomplete captured tool JSON at stream end: suppress raw capture.
|
||||
content := state.capture.String()
|
||||
if content != "" {
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
}
|
||||
}
|
||||
state.capture.Reset()
|
||||
state.capturing = false
|
||||
state.resetIncrementalToolState()
|
||||
}
|
||||
if state.pending.Len() > 0 {
|
||||
events = append(events, toolStreamEvent{Content: state.pending.String()})
|
||||
content := state.pending.String()
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
state.pending.Reset()
|
||||
}
|
||||
return events
|
||||
@@ -144,17 +196,26 @@ func findToolSegmentStart(s string) int {
|
||||
return -1
|
||||
}
|
||||
lower := strings.ToLower(s)
|
||||
keyIdx := strings.Index(lower, "tool_calls")
|
||||
if keyIdx < 0 {
|
||||
return -1
|
||||
offset := 0
|
||||
for {
|
||||
keyRel := strings.Index(lower[offset:], "tool_calls")
|
||||
if keyRel < 0 {
|
||||
return -1
|
||||
}
|
||||
keyIdx := offset + keyRel
|
||||
start := strings.LastIndex(s[:keyIdx], "{")
|
||||
if start < 0 {
|
||||
start = keyIdx
|
||||
}
|
||||
if !insideCodeFence(s[:start]) {
|
||||
return start
|
||||
}
|
||||
offset = keyIdx + len("tool_calls")
|
||||
}
|
||||
if start := strings.LastIndex(s[:keyIdx], "{"); start >= 0 {
|
||||
return start
|
||||
}
|
||||
return keyIdx
|
||||
}
|
||||
|
||||
func consumeToolCapture(captured string, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
|
||||
func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
|
||||
captured := state.capture.String()
|
||||
if captured == "" {
|
||||
return "", nil, "", false
|
||||
}
|
||||
@@ -171,13 +232,25 @@ func consumeToolCapture(captured string, toolNames []string) (prefix string, cal
|
||||
if !ok {
|
||||
return "", nil, "", false
|
||||
}
|
||||
parsed := util.ParseToolCalls(obj, toolNames)
|
||||
if len(parsed) == 0 {
|
||||
// `tool_calls` key exists but strict JSON parse failed.
|
||||
// Drop the captured object body to avoid leaking raw tool JSON.
|
||||
return captured[:start], nil, captured[end:], true
|
||||
prefixPart := captured[:start]
|
||||
suffixPart := captured[end:]
|
||||
if insideCodeFence(state.recentTextTail + prefixPart) {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
return captured[:start], parsed, captured[end:], true
|
||||
parsed := util.ParseStandaloneToolCalls(obj, toolNames)
|
||||
if len(parsed) == 0 {
|
||||
if state.toolNameSent {
|
||||
return prefixPart, nil, suffixPart, true
|
||||
}
|
||||
return captured, nil, "", true
|
||||
}
|
||||
if state.toolNameSent {
|
||||
if len(parsed) > 1 {
|
||||
return prefixPart, parsed[1:], suffixPart, true
|
||||
}
|
||||
return prefixPart, nil, suffixPart, true
|
||||
}
|
||||
return prefixPart, parsed, suffixPart, true
|
||||
}
|
||||
|
||||
func extractJSONObjectFrom(text string, start int) (string, int, bool) {
|
||||
@@ -221,3 +294,352 @@ func extractJSONObjectFrom(text string, start int) (string, int, bool) {
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
|
||||
captured := state.capture.String()
|
||||
if captured == "" {
|
||||
return nil
|
||||
}
|
||||
lower := strings.ToLower(captured)
|
||||
keyIdx := strings.Index(lower, "tool_calls")
|
||||
if keyIdx < 0 {
|
||||
return nil
|
||||
}
|
||||
start := strings.LastIndex(captured[:keyIdx], "{")
|
||||
if start < 0 {
|
||||
return nil
|
||||
}
|
||||
if insideCodeFence(state.recentTextTail + captured[:start]) {
|
||||
return nil
|
||||
}
|
||||
callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
deltas := make([]toolCallDelta, 0, 2)
|
||||
if state.toolName == "" {
|
||||
name, ok := extractToolCallName(captured, callStart)
|
||||
if !ok || name == "" {
|
||||
return nil
|
||||
}
|
||||
state.toolName = name
|
||||
}
|
||||
if state.toolArgsStart < 0 {
|
||||
argsStart, stringMode, ok := findToolCallArgsStart(captured, callStart)
|
||||
if ok {
|
||||
state.toolArgsString = stringMode
|
||||
if stringMode {
|
||||
state.toolArgsStart = argsStart + 1
|
||||
} else {
|
||||
state.toolArgsStart = argsStart
|
||||
}
|
||||
state.toolArgsSent = state.toolArgsStart
|
||||
}
|
||||
}
|
||||
if !state.toolNameSent {
|
||||
if state.toolArgsStart < 0 {
|
||||
return nil
|
||||
}
|
||||
state.toolNameSent = true
|
||||
deltas = append(deltas, toolCallDelta{Index: 0, Name: state.toolName})
|
||||
}
|
||||
if state.toolArgsStart < 0 || state.toolArgsDone {
|
||||
return deltas
|
||||
}
|
||||
end, complete, ok := scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString)
|
||||
if !ok {
|
||||
return deltas
|
||||
}
|
||||
if end > state.toolArgsSent {
|
||||
deltas = append(deltas, toolCallDelta{
|
||||
Index: 0,
|
||||
Arguments: captured[state.toolArgsSent:end],
|
||||
})
|
||||
state.toolArgsSent = end
|
||||
}
|
||||
if complete {
|
||||
state.toolArgsDone = true
|
||||
}
|
||||
return deltas
|
||||
}
|
||||
|
||||
func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) {
|
||||
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
|
||||
if !ok {
|
||||
return -1, false
|
||||
}
|
||||
i := skipSpaces(text, arrStart+1)
|
||||
if i >= len(text) || text[i] != '{' {
|
||||
return -1, false
|
||||
}
|
||||
return i, true
|
||||
}
|
||||
|
||||
func findToolCallsArrayStart(text string, keyIdx int) (int, bool) {
|
||||
i := keyIdx + len("tool_calls")
|
||||
for i < len(text) && text[i] != ':' {
|
||||
i++
|
||||
}
|
||||
if i >= len(text) {
|
||||
return -1, false
|
||||
}
|
||||
i = skipSpaces(text, i+1)
|
||||
if i >= len(text) || text[i] != '[' {
|
||||
return -1, false
|
||||
}
|
||||
return i, true
|
||||
}
|
||||
|
||||
func extractToolCallName(text string, callStart int) (string, bool) {
|
||||
valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"name"})
|
||||
if !ok || valueStart >= len(text) || text[valueStart] != '"' {
|
||||
fnStart, fnOK := findFunctionObjectStart(text, callStart)
|
||||
if !fnOK {
|
||||
return "", false
|
||||
}
|
||||
valueStart, ok = findObjectFieldValueStart(text, fnStart, []string{"name"})
|
||||
if !ok || valueStart >= len(text) || text[valueStart] != '"' {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
name, _, ok := parseJSONStringLiteral(text, valueStart)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return name, true
|
||||
}
|
||||
|
||||
func findToolCallArgsStart(text string, callStart int) (int, bool, bool) {
|
||||
keys := []string{"input", "arguments", "args", "parameters", "params"}
|
||||
valueStart, ok := findObjectFieldValueStart(text, callStart, keys)
|
||||
if !ok {
|
||||
fnStart, fnOK := findFunctionObjectStart(text, callStart)
|
||||
if !fnOK {
|
||||
return -1, false, false
|
||||
}
|
||||
valueStart, ok = findObjectFieldValueStart(text, fnStart, keys)
|
||||
if !ok {
|
||||
return -1, false, false
|
||||
}
|
||||
}
|
||||
if valueStart >= len(text) {
|
||||
return -1, false, false
|
||||
}
|
||||
ch := text[valueStart]
|
||||
if ch == '{' || ch == '[' {
|
||||
return valueStart, false, true
|
||||
}
|
||||
if ch == '"' {
|
||||
return valueStart, true, true
|
||||
}
|
||||
return -1, false, false
|
||||
}
|
||||
|
||||
func scanToolCallArgsProgress(text string, start int, stringMode bool) (int, bool, bool) {
|
||||
if start < 0 || start > len(text) {
|
||||
return 0, false, false
|
||||
}
|
||||
if stringMode {
|
||||
escaped := false
|
||||
for i := start; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == '"' {
|
||||
return i, true, true
|
||||
}
|
||||
}
|
||||
return len(text), false, true
|
||||
}
|
||||
if start >= len(text) {
|
||||
return start, false, false
|
||||
}
|
||||
if text[start] != '{' && text[start] != '[' {
|
||||
return 0, false, false
|
||||
}
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
for i := start; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '{' || ch == '[' {
|
||||
depth++
|
||||
continue
|
||||
}
|
||||
if ch == '}' || ch == ']' {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return i + 1, true, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return len(text), false, true
|
||||
}
|
||||
|
||||
func findObjectFieldValueStart(text string, objStart int, keys []string) (int, bool) {
|
||||
if objStart < 0 || objStart >= len(text) || text[objStart] != '{' {
|
||||
return 0, false
|
||||
}
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
for i := objStart; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
if depth == 1 {
|
||||
key, end, ok := parseJSONStringLiteral(text, i)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
j := skipSpaces(text, end)
|
||||
if j >= len(text) || text[j] != ':' {
|
||||
i = end - 1
|
||||
continue
|
||||
}
|
||||
j = skipSpaces(text, j+1)
|
||||
if j >= len(text) {
|
||||
return 0, false
|
||||
}
|
||||
if containsKey(keys, key) {
|
||||
return j, true
|
||||
}
|
||||
i = j - 1
|
||||
continue
|
||||
}
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '{' {
|
||||
depth++
|
||||
continue
|
||||
}
|
||||
if ch == '}' {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func findFunctionObjectStart(text string, callStart int) (int, bool) {
|
||||
valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"function"})
|
||||
if !ok || valueStart >= len(text) || text[valueStart] != '{' {
|
||||
return -1, false
|
||||
}
|
||||
return valueStart, true
|
||||
}
|
||||
|
||||
func parseJSONStringLiteral(text string, start int) (string, int, bool) {
|
||||
if start < 0 || start >= len(text) || text[start] != '"' {
|
||||
return "", 0, false
|
||||
}
|
||||
var b strings.Builder
|
||||
escaped := false
|
||||
for i := start + 1; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if escaped {
|
||||
b.WriteByte(ch)
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == '"' {
|
||||
return b.String(), i + 1, true
|
||||
}
|
||||
b.WriteByte(ch)
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func containsKey(keys []string, value string) bool {
|
||||
for _, k := range keys {
|
||||
if k == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func skipSpaces(text string, i int) int {
|
||||
for i < len(text) {
|
||||
switch text[i] {
|
||||
case ' ', '\t', '\n', '\r':
|
||||
i++
|
||||
default:
|
||||
return i
|
||||
}
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func (s *toolStreamSieveState) noteText(content string) {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return
|
||||
}
|
||||
s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit)
|
||||
}
|
||||
|
||||
func appendTail(prev, next string, max int) string {
|
||||
if max <= 0 {
|
||||
return ""
|
||||
}
|
||||
combined := prev + next
|
||||
if len(combined) <= max {
|
||||
return combined
|
||||
}
|
||||
return combined[len(combined)-max:]
|
||||
}
|
||||
|
||||
func looksLikeToolExampleContext(text string) bool {
|
||||
return insideCodeFence(text)
|
||||
}
|
||||
|
||||
func insideCodeFence(text string) bool {
|
||||
if text == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Count(text, "```")%2 == 1
|
||||
}
|
||||
|
||||
@@ -56,24 +56,16 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
|
||||
writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
|
||||
return
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
messagesRaw, _ := req["messages"].([]any)
|
||||
if model == "" || len(messagesRaw) == 0 {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
|
||||
stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model)
|
||||
if !ok {
|
||||
writeOpenAIError(w, http.StatusServiceUnavailable, fmt.Sprintf("Model '%s' is not available.", model))
|
||||
if !stdReq.Stream {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
|
||||
return
|
||||
}
|
||||
|
||||
messages := normalizeMessages(messagesRaw)
|
||||
if tools, ok := req["tools"].([]any); ok && len(tools) > 0 {
|
||||
messages, _ = injectToolPrompt(messages, tools)
|
||||
}
|
||||
finalPrompt := util.MessagesPrepare(messages)
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
if a.UseConfigToken {
|
||||
@@ -93,14 +85,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
|
||||
return
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
"parent_message_id": nil,
|
||||
"prompt": finalPrompt,
|
||||
"ref_file_ids": []any{},
|
||||
"thinking_enabled": thinkingEnabled,
|
||||
"search_enabled": searchEnabled,
|
||||
}
|
||||
payload := stdReq.CompletionPayload(sessionID)
|
||||
leaseID := h.holdStreamLease(a)
|
||||
if leaseID == "" {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "failed to create stream lease")
|
||||
@@ -108,15 +93,18 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
|
||||
}
|
||||
leased = true
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"session_id": sessionID,
|
||||
"lease_id": leaseID,
|
||||
"model": model,
|
||||
"final_prompt": finalPrompt,
|
||||
"thinking_enabled": thinkingEnabled,
|
||||
"search_enabled": searchEnabled,
|
||||
"deepseek_token": a.DeepSeekToken,
|
||||
"pow_header": powHeader,
|
||||
"payload": payload,
|
||||
"session_id": sessionID,
|
||||
"lease_id": leaseID,
|
||||
"model": stdReq.ResponseModel,
|
||||
"final_prompt": stdReq.FinalPrompt,
|
||||
"thinking_enabled": stdReq.Thinking,
|
||||
"search_enabled": stdReq.Search,
|
||||
"tool_names": stdReq.ToolNames,
|
||||
"toolcall_feature_match": h.toolcallFeatureMatchEnabled(),
|
||||
"toolcall_early_emit_high": h.toolcallEarlyEmitHighConfidence(),
|
||||
"deepseek_token": a.DeepSeekToken,
|
||||
"pow_header": powHeader,
|
||||
"payload": payload,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -56,7 +56,14 @@ func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
preview = token
|
||||
}
|
||||
}
|
||||
items = append(items, map[string]any{"email": acc.Email, "mobile": acc.Mobile, "has_password": acc.Password != "", "has_token": token != "", "token_preview": preview})
|
||||
items = append(items, map[string]any{
|
||||
"identifier": acc.Identifier(),
|
||||
"email": acc.Email,
|
||||
"mobile": acc.Mobile,
|
||||
"has_password": acc.Password != "",
|
||||
"has_token": token != "",
|
||||
"token_preview": preview,
|
||||
})
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages})
|
||||
}
|
||||
@@ -94,7 +101,7 @@ func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
idx := -1
|
||||
for i, a := range c.Accounts {
|
||||
if a.Email == identifier || a.Mobile == identifier {
|
||||
if accountMatchesIdentifier(a, identifier) {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
@@ -122,10 +129,10 @@ func (h *Handler) testSingleAccount(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
identifier, _ := req["identifier"].(string)
|
||||
if strings.TrimSpace(identifier) == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(email 或 mobile)"})
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(identifier / email / mobile)"})
|
||||
return
|
||||
}
|
||||
acc, ok := h.Store.FindAccount(identifier)
|
||||
acc, ok := findAccountByIdentifier(h.Store, identifier)
|
||||
if !ok {
|
||||
writeJSON(w, http.StatusNotFound, map[string]any{"detail": "账号不存在"})
|
||||
return
|
||||
|
||||
138
internal/admin/handler_accounts_identifier_test.go
Normal file
138
internal/admin/handler_accounts_identifier_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/account"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func newAdminTestHandler(t *testing.T, raw string) *Handler {
|
||||
t.Helper()
|
||||
t.Setenv("DS2API_CONFIG_JSON", raw)
|
||||
t.Setenv("CONFIG_JSON", "")
|
||||
store := config.LoadStore()
|
||||
return &Handler{
|
||||
Store: store,
|
||||
Pool: account.NewPool(store),
|
||||
}
|
||||
}
|
||||
|
||||
func TestListAccountsIncludesTokenOnlyIdentifier(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"accounts":[{"token":"token-only-account"}]
|
||||
}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/accounts?page=1&page_size=10", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
h.listAccounts(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("decode response failed: %v", err)
|
||||
}
|
||||
items, _ := payload["items"].([]any)
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("expected 1 item, got %d", len(items))
|
||||
}
|
||||
first, _ := items[0].(map[string]any)
|
||||
identifier, _ := first["identifier"].(string)
|
||||
if identifier == "" {
|
||||
t.Fatalf("expected non-empty identifier: %#v", first)
|
||||
}
|
||||
if !strings.HasPrefix(identifier, "token:") {
|
||||
t.Fatalf("expected token synthetic identifier, got %q", identifier)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAccountSupportsTokenOnlyIdentifier(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"accounts":[{"token":"token-only-account"}]
|
||||
}`)
|
||||
accounts := h.Store.Accounts()
|
||||
if len(accounts) != 1 {
|
||||
t.Fatalf("expected 1 account, got %d", len(accounts))
|
||||
}
|
||||
id := accounts[0].Identifier()
|
||||
if id == "" {
|
||||
t.Fatal("expected token-only synthetic identifier")
|
||||
}
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Delete("/admin/accounts/{identifier}", h.deleteAccount)
|
||||
req := httptest.NewRequest(http.MethodDelete, "/admin/accounts/"+url.PathEscape(id), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if got := len(h.Store.Accounts()); got != 0 {
|
||||
t.Fatalf("expected account removed, remaining=%d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAccountSupportsMobileAlias(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"accounts":[{"email":"u@example.com","mobile":"13800138000","password":"pwd"}]
|
||||
}`)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Delete("/admin/accounts/{identifier}", h.deleteAccount)
|
||||
req := httptest.NewRequest(http.MethodDelete, "/admin/accounts/13800138000", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if got := len(h.Store.Accounts()); got != 0 {
|
||||
t.Fatalf("expected account removed, remaining=%d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"accounts":[
|
||||
{"email":"u@example.com","mobile":"13800138000","password":"pwd"},
|
||||
{"token":"token-only-account"}
|
||||
]
|
||||
}`)
|
||||
|
||||
accByMobile, ok := findAccountByIdentifier(h.Store, "13800138000")
|
||||
if !ok {
|
||||
t.Fatal("expected find by mobile")
|
||||
}
|
||||
if accByMobile.Email != "u@example.com" {
|
||||
t.Fatalf("unexpected account by mobile: %#v", accByMobile)
|
||||
}
|
||||
|
||||
tokenOnlyID := ""
|
||||
for _, acc := range h.Store.Accounts() {
|
||||
if strings.TrimSpace(acc.Email) == "" && strings.TrimSpace(acc.Mobile) == "" {
|
||||
tokenOnlyID = acc.Identifier()
|
||||
break
|
||||
}
|
||||
}
|
||||
if tokenOnlyID == "" {
|
||||
t.Fatal("expected token-only account identifier")
|
||||
}
|
||||
accByTokenOnly, ok := findAccountByIdentifier(h.Store, tokenOnlyID)
|
||||
if !ok {
|
||||
t.Fatalf("expected find by token-only id=%q", tokenOnlyID)
|
||||
}
|
||||
if accByTokenOnly.Token != "token-only-account" {
|
||||
t.Fatalf("unexpected token-only account: %#v", accByTokenOnly)
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,7 @@ func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) {
|
||||
}
|
||||
}
|
||||
accounts = append(accounts, map[string]any{
|
||||
"identifier": acc.Identifier(),
|
||||
"email": acc.Email,
|
||||
"mobile": acc.Mobile,
|
||||
"has_password": strings.TrimSpace(acc.Password) != "",
|
||||
|
||||
@@ -81,3 +81,34 @@ func statusOr(v int, d int) int {
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func accountMatchesIdentifier(acc config.Account, identifier string) bool {
|
||||
id := strings.TrimSpace(identifier)
|
||||
if id == "" {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(acc.Email) == id {
|
||||
return true
|
||||
}
|
||||
if strings.TrimSpace(acc.Mobile) == id {
|
||||
return true
|
||||
}
|
||||
return acc.Identifier() == id
|
||||
}
|
||||
|
||||
func findAccountByIdentifier(store *config.Store, identifier string) (config.Account, bool) {
|
||||
id := strings.TrimSpace(identifier)
|
||||
if id == "" {
|
||||
return config.Account{}, false
|
||||
}
|
||||
if acc, ok := store.FindAccount(id); ok {
|
||||
return acc, true
|
||||
}
|
||||
accounts := store.Snapshot().Accounts
|
||||
for _, acc := range accounts {
|
||||
if accountMatchesIdentifier(acc, id) {
|
||||
return acc, true
|
||||
}
|
||||
}
|
||||
return config.Account{}, false
|
||||
}
|
||||
|
||||
240
internal/admin/helpers_edge_test.go
Normal file
240
internal/admin/helpers_edge_test.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
// ─── reverseAccounts ─────────────────────────────────────────────────
|
||||
|
||||
func TestReverseAccountsEmpty(t *testing.T) {
|
||||
a := []config.Account{}
|
||||
reverseAccounts(a)
|
||||
if len(a) != 0 {
|
||||
t.Fatal("expected empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseAccountsTwoElements(t *testing.T) {
|
||||
a := []config.Account{
|
||||
{Email: "a@test.com"},
|
||||
{Email: "b@test.com"},
|
||||
}
|
||||
reverseAccounts(a)
|
||||
if a[0].Email != "b@test.com" || a[1].Email != "a@test.com" {
|
||||
t.Fatalf("unexpected order after reverse: %v", a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseAccountsThreeElements(t *testing.T) {
|
||||
a := []config.Account{
|
||||
{Email: "1@test.com"},
|
||||
{Email: "2@test.com"},
|
||||
{Email: "3@test.com"},
|
||||
}
|
||||
reverseAccounts(a)
|
||||
if a[0].Email != "3@test.com" || a[1].Email != "2@test.com" || a[2].Email != "1@test.com" {
|
||||
t.Fatalf("unexpected order: %v", a)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── intFromQuery edge cases ─────────────────────────────────────────
|
||||
|
||||
func TestIntFromQueryPresent(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/?limit=5", nil)
|
||||
if got := intFromQuery(req, "limit", 10); got != 5 {
|
||||
t.Fatalf("expected 5, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntFromQueryMissing(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
if got := intFromQuery(req, "limit", 10); got != 10 {
|
||||
t.Fatalf("expected default 10, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntFromQueryInvalid(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/?limit=abc", nil)
|
||||
if got := intFromQuery(req, "limit", 10); got != 10 {
|
||||
t.Fatalf("expected default 10 for invalid, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntFromQueryNegative(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/?limit=-3", nil)
|
||||
if got := intFromQuery(req, "limit", 10); got != -3 {
|
||||
t.Fatalf("expected -3, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntFromQueryZero(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/?limit=0", nil)
|
||||
if got := intFromQuery(req, "limit", 10); got != 0 {
|
||||
t.Fatalf("expected 0, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── nilIfEmpty ──────────────────────────────────────────────────────
|
||||
|
||||
func TestNilIfEmptyEmpty(t *testing.T) {
|
||||
if nilIfEmpty("") != nil {
|
||||
t.Fatal("expected nil for empty string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilIfEmptyNonEmpty(t *testing.T) {
|
||||
if nilIfEmpty("hello") != "hello" {
|
||||
t.Fatal("expected 'hello'")
|
||||
}
|
||||
}
|
||||
|
||||
// ─── nilIfZero ───────────────────────────────────────────────────────
|
||||
|
||||
func TestNilIfZeroZero(t *testing.T) {
|
||||
if nilIfZero(0) != nil {
|
||||
t.Fatal("expected nil for zero")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilIfZeroNonZero(t *testing.T) {
|
||||
if nilIfZero(42) != int64(42) {
|
||||
t.Fatal("expected 42")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilIfZeroNegative(t *testing.T) {
|
||||
if nilIfZero(-1) != int64(-1) {
|
||||
t.Fatal("expected -1")
|
||||
}
|
||||
}
|
||||
|
||||
// ─── toStringSlice ───────────────────────────────────────────────────
|
||||
|
||||
func TestToStringSliceFromAnySlice(t *testing.T) {
|
||||
input := []any{"a", "b", "c"}
|
||||
got, ok := toStringSlice(input)
|
||||
if !ok || len(got) != 3 {
|
||||
t.Fatalf("expected 3 strings, got %#v ok=%v", got, ok)
|
||||
}
|
||||
if got[0] != "a" || got[1] != "b" || got[2] != "c" {
|
||||
t.Fatalf("unexpected values: %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToStringSliceFromMixed(t *testing.T) {
|
||||
input := []any{"hello", 42, true}
|
||||
got, ok := toStringSlice(input)
|
||||
if !ok {
|
||||
t.Fatal("expected ok for mixed types")
|
||||
}
|
||||
if got[0] != "hello" || got[1] != "42" || got[2] != "true" {
|
||||
t.Fatalf("unexpected values: %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToStringSliceFromNonSlice(t *testing.T) {
|
||||
_, ok := toStringSlice("not a slice")
|
||||
if ok {
|
||||
t.Fatal("expected not ok for string input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToStringSliceFromNil(t *testing.T) {
|
||||
_, ok := toStringSlice(nil)
|
||||
if ok {
|
||||
t.Fatal("expected not ok for nil input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToStringSliceEmpty(t *testing.T) {
|
||||
got, ok := toStringSlice([]any{})
|
||||
if !ok {
|
||||
t.Fatal("expected ok for empty slice")
|
||||
}
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected empty result, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToStringSliceTrimsWhitespace(t *testing.T) {
|
||||
got, ok := toStringSlice([]any{" hello ", " world "})
|
||||
if !ok {
|
||||
t.Fatal("expected ok")
|
||||
}
|
||||
if got[0] != "hello" || got[1] != "world" {
|
||||
t.Fatalf("expected trimmed values, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── toAccount edge cases ────────────────────────────────────────────
|
||||
|
||||
func TestToAccountAllFields(t *testing.T) {
|
||||
acc := toAccount(map[string]any{
|
||||
"email": "user@test.com",
|
||||
"mobile": "13800138000",
|
||||
"password": "secret",
|
||||
"token": "tok123",
|
||||
})
|
||||
if acc.Email != "user@test.com" {
|
||||
t.Fatalf("unexpected email: %q", acc.Email)
|
||||
}
|
||||
if acc.Mobile != "13800138000" {
|
||||
t.Fatalf("unexpected mobile: %q", acc.Mobile)
|
||||
}
|
||||
if acc.Password != "secret" {
|
||||
t.Fatalf("unexpected password: %q", acc.Password)
|
||||
}
|
||||
if acc.Token != "tok123" {
|
||||
t.Fatalf("unexpected token: %q", acc.Token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToAccountNumericValues(t *testing.T) {
|
||||
acc := toAccount(map[string]any{
|
||||
"email": 12345,
|
||||
})
|
||||
if acc.Email != "12345" {
|
||||
t.Fatalf("expected numeric converted to string, got %q", acc.Email)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── fieldString edge cases ──────────────────────────────────────────
|
||||
|
||||
func TestFieldStringNonString(t *testing.T) {
|
||||
got := fieldString(map[string]any{"key": 42}, "key")
|
||||
if got != "42" {
|
||||
t.Fatalf("expected '42' for int, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldStringBool(t *testing.T) {
|
||||
got := fieldString(map[string]any{"key": true}, "key")
|
||||
if got != "true" {
|
||||
t.Fatalf("expected 'true', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldStringWhitespace(t *testing.T) {
|
||||
got := fieldString(map[string]any{"key": " hello "}, "key")
|
||||
if got != "hello" {
|
||||
t.Fatalf("expected trimmed 'hello', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── statusOr ────────────────────────────────────────────────────────
|
||||
|
||||
func TestStatusOrZeroReturnsDefault(t *testing.T) {
|
||||
if got := statusOr(0, http.StatusOK); got != http.StatusOK {
|
||||
t.Fatalf("expected %d, got %d", http.StatusOK, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusOrNonZeroReturnsValue(t *testing.T) {
|
||||
if got := statusOr(http.StatusBadRequest, http.StatusOK); got != http.StatusBadRequest {
|
||||
t.Fatalf("expected %d, got %d", http.StatusBadRequest, got)
|
||||
}
|
||||
}
|
||||
375
internal/auth/auth_edge_test.go
Normal file
375
internal/auth/auth_edge_test.go
Normal file
@@ -0,0 +1,375 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/account"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
// ─── extractCallerToken edge cases ───────────────────────────────────
|
||||
|
||||
func TestExtractCallerTokenBearerPrefix(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer my-token")
|
||||
if got := extractCallerToken(req); got != "my-token" {
|
||||
t.Fatalf("expected my-token, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCallerTokenBearerCaseInsensitive(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/", nil)
|
||||
req.Header.Set("Authorization", "BEARER My-Token")
|
||||
if got := extractCallerToken(req); got != "My-Token" {
|
||||
t.Fatalf("expected My-Token, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCallerTokenBearerEmpty(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer ")
|
||||
if got := extractCallerToken(req); got != "" {
|
||||
t.Fatalf("expected empty for 'Bearer ', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCallerTokenXAPIKey(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/", nil)
|
||||
req.Header.Set("x-api-key", "x-api-key-token")
|
||||
if got := extractCallerToken(req); got != "x-api-key-token" {
|
||||
t.Fatalf("expected x-api-key-token, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCallerTokenBearerPreferredOverXAPIKey(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer bearer-token")
|
||||
req.Header.Set("x-api-key", "x-api-key-token")
|
||||
if got := extractCallerToken(req); got != "bearer-token" {
|
||||
t.Fatalf("expected bearer-token, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCallerTokenMissingHeaders(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/", nil)
|
||||
if got := extractCallerToken(req); got != "" {
|
||||
t.Fatalf("expected empty for missing headers, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCallerTokenNonBearerAuth(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/", nil)
|
||||
req.Header.Set("Authorization", "Basic abc123")
|
||||
if got := extractCallerToken(req); got != "" {
|
||||
t.Fatalf("expected empty for Basic auth, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Context helpers ─────────────────────────────────────────────────
|
||||
|
||||
func TestWithAuthAndFromContext(t *testing.T) {
|
||||
a := &RequestAuth{DeepSeekToken: "test-token"}
|
||||
ctx := WithAuth(context.Background(), a)
|
||||
got, ok := FromContext(ctx)
|
||||
if !ok || got.DeepSeekToken != "test-token" {
|
||||
t.Fatalf("expected token from context, got ok=%v token=%q", ok, got.DeepSeekToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromContextMissing(t *testing.T) {
|
||||
_, ok := FromContext(context.Background())
|
||||
if ok {
|
||||
t.Fatal("expected not ok from empty context")
|
||||
}
|
||||
}
|
||||
|
||||
// ─── RefreshToken edge cases ─────────────────────────────────────────
|
||||
|
||||
func TestRefreshTokenNotConfigToken(t *testing.T) {
|
||||
r := newTestResolver(t)
|
||||
a := &RequestAuth{UseConfigToken: false, resolver: r}
|
||||
if r.RefreshToken(context.Background(), a) {
|
||||
t.Fatal("expected false for non-config token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshTokenEmptyAccountID(t *testing.T) {
|
||||
r := newTestResolver(t)
|
||||
a := &RequestAuth{UseConfigToken: true, AccountID: "", resolver: r}
|
||||
if r.RefreshToken(context.Background(), a) {
|
||||
t.Fatal("expected false for empty account ID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshTokenSuccess(t *testing.T) {
|
||||
r := newTestResolver(t)
|
||||
// First acquire an account
|
||||
req, _ := http.NewRequest("POST", "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer managed-key")
|
||||
a, err := r.Determine(req)
|
||||
if err != nil {
|
||||
t.Fatalf("determine failed: %v", err)
|
||||
}
|
||||
defer r.Release(a)
|
||||
|
||||
if !r.RefreshToken(context.Background(), a) {
|
||||
t.Fatal("expected refresh to succeed")
|
||||
}
|
||||
if a.DeepSeekToken != "fresh-token" {
|
||||
t.Fatalf("expected fresh-token after refresh, got %q", a.DeepSeekToken)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── MarkTokenInvalid edge cases ─────────────────────────────────────
|
||||
|
||||
func TestMarkTokenInvalidNotConfigToken(t *testing.T) {
|
||||
r := newTestResolver(t)
|
||||
a := &RequestAuth{UseConfigToken: false, DeepSeekToken: "direct", resolver: r}
|
||||
r.MarkTokenInvalid(a)
|
||||
// Should not panic, token should be unchanged for non-config
|
||||
if a.DeepSeekToken != "" {
|
||||
// Actually it does clear it; that's fine - let's check behavior
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkTokenInvalidEmptyAccountID(t *testing.T) {
|
||||
r := newTestResolver(t)
|
||||
a := &RequestAuth{UseConfigToken: true, AccountID: "", DeepSeekToken: "tok", resolver: r}
|
||||
r.MarkTokenInvalid(a)
|
||||
// Should not panic
|
||||
}
|
||||
|
||||
func TestMarkTokenInvalidClearsToken(t *testing.T) {
|
||||
r := newTestResolver(t)
|
||||
req, _ := http.NewRequest("POST", "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer managed-key")
|
||||
a, err := r.Determine(req)
|
||||
if err != nil {
|
||||
t.Fatalf("determine failed: %v", err)
|
||||
}
|
||||
defer r.Release(a)
|
||||
|
||||
r.MarkTokenInvalid(a)
|
||||
if a.DeepSeekToken != "" {
|
||||
t.Fatalf("expected empty token after invalidation, got %q", a.DeepSeekToken)
|
||||
}
|
||||
if a.Account.Token != "" {
|
||||
t.Fatalf("expected empty account token after invalidation, got %q", a.Account.Token)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── SwitchAccount edge cases ────────────────────────────────────────
|
||||
|
||||
func TestSwitchAccountNotConfigToken(t *testing.T) {
|
||||
r := newTestResolver(t)
|
||||
a := &RequestAuth{UseConfigToken: false, resolver: r}
|
||||
if r.SwitchAccount(context.Background(), a) {
|
||||
t.Fatal("expected false for non-config token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchAccountNilTriedAccounts(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["managed-key"],
|
||||
"accounts":[
|
||||
{"email":"acc1@test.com","token":"t1"},
|
||||
{"email":"acc2@test.com","token":"t2"}
|
||||
]
|
||||
}`)
|
||||
store := config.LoadStore()
|
||||
pool := account.NewPool(store)
|
||||
r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
|
||||
return "new-token", nil
|
||||
})
|
||||
|
||||
// First acquire
|
||||
req, _ := http.NewRequest("POST", "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer managed-key")
|
||||
a, err := r.Determine(req)
|
||||
if err != nil {
|
||||
t.Fatalf("determine failed: %v", err)
|
||||
}
|
||||
|
||||
oldID := a.AccountID
|
||||
a.TriedAccounts = nil // test nil initialization in SwitchAccount
|
||||
if !r.SwitchAccount(context.Background(), a) {
|
||||
t.Fatal("expected switch to succeed")
|
||||
}
|
||||
if a.AccountID == oldID {
|
||||
t.Fatalf("expected different account after switch")
|
||||
}
|
||||
r.Release(a)
|
||||
}
|
||||
|
||||
// ─── Release edge cases ─────────────────────────────────────────────
|
||||
|
||||
func TestReleaseNilAuth(t *testing.T) {
|
||||
r := newTestResolver(t)
|
||||
r.Release(nil) // should not panic
|
||||
}
|
||||
|
||||
func TestReleaseNonConfigToken(t *testing.T) {
|
||||
r := newTestResolver(t)
|
||||
a := &RequestAuth{UseConfigToken: false}
|
||||
r.Release(a) // should not panic
|
||||
}
|
||||
|
||||
func TestReleaseEmptyAccountID(t *testing.T) {
|
||||
r := newTestResolver(t)
|
||||
a := &RequestAuth{UseConfigToken: true, AccountID: ""}
|
||||
r.Release(a) // should not panic
|
||||
}
|
||||
|
||||
// ─── JWT edge cases ──────────────────────────────────────────────────
|
||||
|
||||
func TestVerifyJWTInvalidFormat(t *testing.T) {
|
||||
_, err := VerifyJWT("not-a-jwt")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JWT format")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyJWTInvalidSignature(t *testing.T) {
|
||||
token, _ := CreateJWT(1)
|
||||
// Tamper with the signature
|
||||
parts := splitJWT(token)
|
||||
if len(parts) == 3 {
|
||||
tampered := parts[0] + "." + parts[1] + ".invalid_signature"
|
||||
_, err := VerifyJWT(tampered)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for tampered signature")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyJWTExpired(t *testing.T) {
|
||||
// Create a token with 0 hours expiry - will use default, so we can't easily test
|
||||
// Instead test with bad payload
|
||||
_, err := VerifyJWT("eyJhbGciOiJIUzI1NiJ9.eyJleHAiOjF9.invalid")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for expired/invalid JWT")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateJWTDefaultExpiry(t *testing.T) {
|
||||
token, err := CreateJWT(0) // should use default
|
||||
if err != nil {
|
||||
t.Fatalf("create jwt failed: %v", err)
|
||||
}
|
||||
_, err = VerifyJWT(token)
|
||||
if err != nil {
|
||||
t.Fatalf("verify jwt failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── VerifyAdminRequest edge cases ───────────────────────────────────
|
||||
|
||||
func TestVerifyAdminRequestNoHeader(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/admin/config", nil)
|
||||
if err := VerifyAdminRequest(req); err == nil {
|
||||
t.Fatal("expected error for missing auth")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyAdminRequestEmptyBearer(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/admin/config", nil)
|
||||
req.Header.Set("Authorization", "Bearer ")
|
||||
if err := VerifyAdminRequest(req); err == nil {
|
||||
t.Fatal("expected error for empty bearer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyAdminRequestWithAdminKey(t *testing.T) {
|
||||
t.Setenv("DS2API_ADMIN_KEY", "test-admin-key")
|
||||
req, _ := http.NewRequest("GET", "/admin/config", nil)
|
||||
req.Header.Set("Authorization", "Bearer test-admin-key")
|
||||
if err := VerifyAdminRequest(req); err != nil {
|
||||
t.Fatalf("expected admin key accepted: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyAdminRequestInvalidCredentials(t *testing.T) {
|
||||
t.Setenv("DS2API_ADMIN_KEY", "correct-key")
|
||||
req, _ := http.NewRequest("GET", "/admin/config", nil)
|
||||
req.Header.Set("Authorization", "Bearer wrong-key")
|
||||
if err := VerifyAdminRequest(req); err == nil {
|
||||
t.Fatal("expected error for wrong key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyAdminRequestBasicAuth(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "/admin/config", nil)
|
||||
req.Header.Set("Authorization", "Basic abc123")
|
||||
if err := VerifyAdminRequest(req); err == nil {
|
||||
t.Fatal("expected error for Basic auth")
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Determine with login failure ────────────────────────────────────
|
||||
|
||||
func TestDetermineWithLoginFailure(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["managed-key"],
|
||||
"accounts":[{"email":"acc@test.com","password":"pwd"}]
|
||||
}`)
|
||||
store := config.LoadStore()
|
||||
pool := account.NewPool(store)
|
||||
r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
|
||||
return "", errors.New("login failed")
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("POST", "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer managed-key")
|
||||
_, err := r.Determine(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when login fails")
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Determine with target account ───────────────────────────────────
|
||||
|
||||
func TestDetermineWithTargetAccount(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["managed-key"],
|
||||
"accounts":[
|
||||
{"email":"acc1@test.com","token":"t1"},
|
||||
{"email":"acc2@test.com","token":"t2"}
|
||||
]
|
||||
}`)
|
||||
store := config.LoadStore()
|
||||
pool := account.NewPool(store)
|
||||
r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
|
||||
return "fresh-token", nil
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("POST", "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer managed-key")
|
||||
req.Header.Set("X-Ds2-Target-Account", "acc2@test.com")
|
||||
a, err := r.Determine(req)
|
||||
if err != nil {
|
||||
t.Fatalf("determine failed: %v", err)
|
||||
}
|
||||
defer r.Release(a)
|
||||
if a.AccountID != "acc2@test.com" {
|
||||
t.Fatalf("expected target account acc2, got %q", a.AccountID)
|
||||
}
|
||||
}
|
||||
|
||||
// helper
|
||||
func splitJWT(token string) []string {
|
||||
result := make([]string, 0, 3)
|
||||
start := 0
|
||||
count := 0
|
||||
for i := 0; i < len(token); i++ {
|
||||
if token[i] == '.' {
|
||||
result = append(result, token[start:i])
|
||||
start = i + 1
|
||||
count++
|
||||
}
|
||||
}
|
||||
result = append(result, token[start:])
|
||||
return result
|
||||
}
|
||||
@@ -2,6 +2,8 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -22,6 +24,7 @@ var (
|
||||
type RequestAuth struct {
|
||||
UseConfigToken bool
|
||||
DeepSeekToken string
|
||||
CallerID string
|
||||
AccountID string
|
||||
Account config.Account
|
||||
TriedAccounts map[string]bool
|
||||
@@ -45,9 +48,16 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
|
||||
if callerKey == "" {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
callerID := callerTokenID(callerKey)
|
||||
ctx := req.Context()
|
||||
if !r.Store.HasAPIKey(callerKey) {
|
||||
return &RequestAuth{UseConfigToken: false, DeepSeekToken: callerKey, resolver: r, TriedAccounts: map[string]bool{}}, nil
|
||||
return &RequestAuth{
|
||||
UseConfigToken: false,
|
||||
DeepSeekToken: callerKey,
|
||||
CallerID: callerID,
|
||||
resolver: r,
|
||||
TriedAccounts: map[string]bool{},
|
||||
}, nil
|
||||
}
|
||||
target := strings.TrimSpace(req.Header.Get("X-Ds2-Target-Account"))
|
||||
acc, ok := r.Pool.AcquireWait(ctx, target, nil)
|
||||
@@ -56,6 +66,7 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
|
||||
}
|
||||
a := &RequestAuth{
|
||||
UseConfigToken: true,
|
||||
CallerID: callerID,
|
||||
AccountID: acc.Identifier(),
|
||||
Account: acc,
|
||||
TriedAccounts: map[string]bool{},
|
||||
@@ -72,6 +83,26 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// DetermineCaller resolves caller identity without acquiring any pooled account.
|
||||
// Use this for local-cache lookup routes that only need tenant isolation.
|
||||
func (r *Resolver) DetermineCaller(req *http.Request) (*RequestAuth, error) {
|
||||
callerKey := extractCallerToken(req)
|
||||
if callerKey == "" {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
callerID := callerTokenID(callerKey)
|
||||
a := &RequestAuth{
|
||||
UseConfigToken: false,
|
||||
CallerID: callerID,
|
||||
resolver: r,
|
||||
TriedAccounts: map[string]bool{},
|
||||
}
|
||||
if r == nil || r.Store == nil || !r.Store.HasAPIKey(callerKey) {
|
||||
a.DeepSeekToken = callerKey
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func WithAuth(ctx context.Context, a *RequestAuth) context.Context {
|
||||
return context.WithValue(ctx, authCtxKey, a)
|
||||
}
|
||||
@@ -158,3 +189,12 @@ func extractCallerToken(req *http.Request) string {
|
||||
}
|
||||
return strings.TrimSpace(req.Header.Get("x-api-key"))
|
||||
}
|
||||
|
||||
func callerTokenID(token string) string {
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return "caller:" + hex.EncodeToString(sum[:8])
|
||||
}
|
||||
|
||||
@@ -37,6 +37,9 @@ func TestDetermineWithXAPIKeyUsesDirectToken(t *testing.T) {
|
||||
if auth.DeepSeekToken != "direct-token" {
|
||||
t.Fatalf("unexpected token: %q", auth.DeepSeekToken)
|
||||
}
|
||||
if auth.CallerID == "" {
|
||||
t.Fatalf("expected caller id to be populated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) {
|
||||
@@ -58,6 +61,44 @@ func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) {
|
||||
if auth.DeepSeekToken != "account-token" {
|
||||
t.Fatalf("unexpected account token: %q", auth.DeepSeekToken)
|
||||
}
|
||||
if auth.CallerID == "" {
|
||||
t.Fatalf("expected caller id to be populated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetermineCallerWithManagedKeySkipsAccountAcquire(t *testing.T) {
|
||||
r := newTestResolver(t)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil)
|
||||
req.Header.Set("x-api-key", "managed-key")
|
||||
|
||||
a, err := r.DetermineCaller(req)
|
||||
if err != nil {
|
||||
t.Fatalf("determine caller failed: %v", err)
|
||||
}
|
||||
if a.CallerID == "" {
|
||||
t.Fatalf("expected caller id to be populated")
|
||||
}
|
||||
if a.UseConfigToken {
|
||||
t.Fatalf("expected no config-token lease for caller-only auth")
|
||||
}
|
||||
if a.AccountID != "" {
|
||||
t.Fatalf("expected empty account id, got %q", a.AccountID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallerTokenIDStable(t *testing.T) {
|
||||
a := callerTokenID("token-a")
|
||||
b := callerTokenID("token-a")
|
||||
c := callerTokenID("token-b")
|
||||
if a == "" || b == "" || c == "" {
|
||||
t.Fatalf("expected non-empty caller ids")
|
||||
}
|
||||
if a != b {
|
||||
t.Fatalf("expected stable caller id, got %q and %q", a, b)
|
||||
}
|
||||
if a == c {
|
||||
t.Fatalf("expected different caller id for different tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetermineMissingToken(t *testing.T) {
|
||||
@@ -72,3 +113,16 @@ func TestDetermineMissingToken(t *testing.T) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetermineCallerMissingToken(t *testing.T) {
|
||||
r := newTestResolver(t)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil)
|
||||
|
||||
_, err := r.DetermineCaller(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected unauthorized error")
|
||||
}
|
||||
if err != ErrUnauthorized {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -61,11 +62,33 @@ type Config struct {
|
||||
Accounts []Account `json:"accounts,omitempty"`
|
||||
ClaudeMapping map[string]string `json:"claude_mapping,omitempty"`
|
||||
ClaudeModelMap map[string]string `json:"claude_model_mapping,omitempty"`
|
||||
ModelAliases map[string]string `json:"model_aliases,omitempty"`
|
||||
Compat CompatConfig `json:"compat,omitempty"`
|
||||
Toolcall ToolcallConfig `json:"toolcall,omitempty"`
|
||||
Responses ResponsesConfig `json:"responses,omitempty"`
|
||||
Embeddings EmbeddingsConfig `json:"embeddings,omitempty"`
|
||||
VercelSyncHash string `json:"_vercel_sync_hash,omitempty"`
|
||||
VercelSyncTime int64 `json:"_vercel_sync_time,omitempty"`
|
||||
AdditionalFields map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
type CompatConfig struct {
|
||||
WideInputStrictOutput *bool `json:"wide_input_strict_output,omitempty"`
|
||||
}
|
||||
|
||||
type ToolcallConfig struct {
|
||||
Mode string `json:"mode,omitempty"`
|
||||
EarlyEmitConfidence string `json:"early_emit_confidence,omitempty"`
|
||||
}
|
||||
|
||||
type ResponsesConfig struct {
|
||||
StoreTTLSeconds int `json:"store_ttl_seconds,omitempty"`
|
||||
}
|
||||
|
||||
type EmbeddingsConfig struct {
|
||||
Provider string `json:"provider,omitempty"`
|
||||
}
|
||||
|
||||
func (c Config) MarshalJSON() ([]byte, error) {
|
||||
m := map[string]any{}
|
||||
for k, v := range c.AdditionalFields {
|
||||
@@ -83,6 +106,21 @@ func (c Config) MarshalJSON() ([]byte, error) {
|
||||
if len(c.ClaudeModelMap) > 0 {
|
||||
m["claude_model_mapping"] = c.ClaudeModelMap
|
||||
}
|
||||
if len(c.ModelAliases) > 0 {
|
||||
m["model_aliases"] = c.ModelAliases
|
||||
}
|
||||
if c.Compat.WideInputStrictOutput != nil {
|
||||
m["compat"] = c.Compat
|
||||
}
|
||||
if strings.TrimSpace(c.Toolcall.Mode) != "" || strings.TrimSpace(c.Toolcall.EarlyEmitConfidence) != "" {
|
||||
m["toolcall"] = c.Toolcall
|
||||
}
|
||||
if c.Responses.StoreTTLSeconds > 0 {
|
||||
m["responses"] = c.Responses
|
||||
}
|
||||
if strings.TrimSpace(c.Embeddings.Provider) != "" {
|
||||
m["embeddings"] = c.Embeddings
|
||||
}
|
||||
if c.VercelSyncHash != "" {
|
||||
m["_vercel_sync_hash"] = c.VercelSyncHash
|
||||
}
|
||||
@@ -101,17 +139,49 @@ func (c *Config) UnmarshalJSON(b []byte) error {
|
||||
for k, v := range raw {
|
||||
switch k {
|
||||
case "keys":
|
||||
_ = json.Unmarshal(v, &c.Keys)
|
||||
if err := json.Unmarshal(v, &c.Keys); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "accounts":
|
||||
_ = json.Unmarshal(v, &c.Accounts)
|
||||
if err := json.Unmarshal(v, &c.Accounts); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "claude_mapping":
|
||||
_ = json.Unmarshal(v, &c.ClaudeMapping)
|
||||
if err := json.Unmarshal(v, &c.ClaudeMapping); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "claude_model_mapping":
|
||||
_ = json.Unmarshal(v, &c.ClaudeModelMap)
|
||||
if err := json.Unmarshal(v, &c.ClaudeModelMap); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "model_aliases":
|
||||
if err := json.Unmarshal(v, &c.ModelAliases); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "compat":
|
||||
if err := json.Unmarshal(v, &c.Compat); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "toolcall":
|
||||
if err := json.Unmarshal(v, &c.Toolcall); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "responses":
|
||||
if err := json.Unmarshal(v, &c.Responses); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "embeddings":
|
||||
if err := json.Unmarshal(v, &c.Embeddings); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "_vercel_sync_hash":
|
||||
_ = json.Unmarshal(v, &c.VercelSyncHash)
|
||||
if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "_vercel_sync_time":
|
||||
_ = json.Unmarshal(v, &c.VercelSyncTime)
|
||||
if err := json.Unmarshal(v, &c.VercelSyncTime); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
default:
|
||||
var anyVal any
|
||||
if err := json.Unmarshal(v, &anyVal); err == nil {
|
||||
@@ -124,10 +194,17 @@ func (c *Config) UnmarshalJSON(b []byte) error {
|
||||
|
||||
func (c Config) Clone() Config {
|
||||
clone := Config{
|
||||
Keys: slices.Clone(c.Keys),
|
||||
Accounts: slices.Clone(c.Accounts),
|
||||
ClaudeMapping: cloneStringMap(c.ClaudeMapping),
|
||||
ClaudeModelMap: cloneStringMap(c.ClaudeModelMap),
|
||||
Keys: slices.Clone(c.Keys),
|
||||
Accounts: slices.Clone(c.Accounts),
|
||||
ClaudeMapping: cloneStringMap(c.ClaudeMapping),
|
||||
ClaudeModelMap: cloneStringMap(c.ClaudeModelMap),
|
||||
ModelAliases: cloneStringMap(c.ModelAliases),
|
||||
Compat: CompatConfig{
|
||||
WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput),
|
||||
},
|
||||
Toolcall: c.Toolcall,
|
||||
Responses: c.Responses,
|
||||
Embeddings: c.Embeddings,
|
||||
VercelSyncHash: c.VercelSyncHash,
|
||||
VercelSyncTime: c.VercelSyncTime,
|
||||
AdditionalFields: map[string]any{},
|
||||
@@ -149,6 +226,14 @@ func cloneStringMap(in map[string]string) map[string]string {
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneBoolPtr(in *bool) *bool {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
v := *in
|
||||
return &v
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
mu sync.RWMutex
|
||||
cfg Config
|
||||
@@ -233,30 +318,94 @@ func loadConfig() (Config, bool, error) {
|
||||
|
||||
content, err := os.ReadFile(ConfigPath())
|
||||
if err != nil {
|
||||
if IsVercel() {
|
||||
// Vercel one-click deploy may start without a writable/present config file.
|
||||
// Keep an in-memory config so users can bootstrap via WebUI then sync env.
|
||||
return Config{}, true, nil
|
||||
}
|
||||
return Config{}, false, err
|
||||
}
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(content, &cfg); err != nil {
|
||||
return Config{}, false, err
|
||||
}
|
||||
if IsVercel() {
|
||||
// Vercel filesystem is ephemeral/read-only for runtime writes; avoid save errors.
|
||||
return cfg, true, nil
|
||||
}
|
||||
return cfg, false, nil
|
||||
}
|
||||
|
||||
func parseConfigString(raw string) (Config, error) {
|
||||
var cfg Config
|
||||
if err := json.Unmarshal([]byte(raw), &cfg); err == nil {
|
||||
return cfg, nil
|
||||
candidates := []string{raw}
|
||||
if normalized := normalizeConfigInput(raw); normalized != raw {
|
||||
candidates = append(candidates, normalized)
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(raw)
|
||||
for _, candidate := range candidates {
|
||||
if err := json.Unmarshal([]byte(candidate), &cfg); err == nil {
|
||||
return cfg, nil
|
||||
}
|
||||
}
|
||||
|
||||
base64Input := candidates[len(candidates)-1]
|
||||
decoded, err := decodeConfigBase64(base64Input)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(decoded, &cfg); err != nil {
|
||||
return Config{}, err
|
||||
return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON decoded JSON: %w", err)
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func normalizeConfigInput(raw string) string {
|
||||
normalized := strings.TrimSpace(raw)
|
||||
if normalized == "" {
|
||||
return normalized
|
||||
}
|
||||
for {
|
||||
changed := false
|
||||
if len(normalized) >= 2 {
|
||||
first := normalized[0]
|
||||
last := normalized[len(normalized)-1]
|
||||
if (first == '"' && last == '"') || (first == '\'' && last == '\'') {
|
||||
normalized = strings.TrimSpace(normalized[1 : len(normalized)-1])
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(normalized), "base64:") {
|
||||
normalized = strings.TrimSpace(normalized[len("base64:"):])
|
||||
changed = true
|
||||
}
|
||||
if !changed {
|
||||
break
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(normalized)
|
||||
}
|
||||
|
||||
func decodeConfigBase64(raw string) ([]byte, error) {
|
||||
encodings := []*base64.Encoding{
|
||||
base64.StdEncoding,
|
||||
base64.RawStdEncoding,
|
||||
base64.URLEncoding,
|
||||
base64.RawURLEncoding,
|
||||
}
|
||||
var lastErr error
|
||||
for _, enc := range encodings {
|
||||
decoded, err := enc.DecodeString(raw)
|
||||
if err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, errors.New("base64 decode failed")
|
||||
}
|
||||
|
||||
func (s *Store) Snapshot() Config {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
@@ -413,3 +562,62 @@ func (s *Store) ClaudeMapping() map[string]string {
|
||||
}
|
||||
return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"}
|
||||
}
|
||||
|
||||
func (s *Store) ModelAliases() map[string]string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
out := DefaultModelAliases()
|
||||
for k, v := range s.cfg.ModelAliases {
|
||||
key := strings.TrimSpace(lower(k))
|
||||
val := strings.TrimSpace(lower(v))
|
||||
if key == "" || val == "" {
|
||||
continue
|
||||
}
|
||||
out[key] = val
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *Store) CompatWideInputStrictOutput() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.cfg.Compat.WideInputStrictOutput == nil {
|
||||
return true
|
||||
}
|
||||
return *s.cfg.Compat.WideInputStrictOutput
|
||||
}
|
||||
|
||||
func (s *Store) ToolcallMode() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
mode := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.Mode))
|
||||
if mode == "" {
|
||||
return "feature_match"
|
||||
}
|
||||
return mode
|
||||
}
|
||||
|
||||
func (s *Store) ToolcallEarlyEmitConfidence() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
level := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.EarlyEmitConfidence))
|
||||
if level == "" {
|
||||
return "high"
|
||||
}
|
||||
return level
|
||||
}
|
||||
|
||||
func (s *Store) ResponsesStoreTTLSeconds() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.cfg.Responses.StoreTTLSeconds > 0 {
|
||||
return s.cfg.Responses.StoreTTLSeconds
|
||||
}
|
||||
return 900
|
||||
}
|
||||
|
||||
func (s *Store) EmbeddingsProvider() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return strings.TrimSpace(s.cfg.Embeddings.Provider)
|
||||
}
|
||||
|
||||
478
internal/config/config_edge_test.go
Normal file
478
internal/config/config_edge_test.go
Normal file
@@ -0,0 +1,478 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// ─── GetModelConfig edge cases ───────────────────────────────────────
|
||||
|
||||
func TestGetModelConfigDeepSeekChat(t *testing.T) {
|
||||
thinking, search, ok := GetModelConfig("deepseek-chat")
|
||||
if !ok {
|
||||
t.Fatal("expected ok for deepseek-chat")
|
||||
}
|
||||
if thinking || search {
|
||||
t.Fatalf("expected no thinking/search for deepseek-chat, got thinking=%v search=%v", thinking, search)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelConfigDeepSeekReasoner(t *testing.T) {
|
||||
thinking, search, ok := GetModelConfig("deepseek-reasoner")
|
||||
if !ok {
|
||||
t.Fatal("expected ok for deepseek-reasoner")
|
||||
}
|
||||
if !thinking || search {
|
||||
t.Fatalf("expected thinking=true search=false, got thinking=%v search=%v", thinking, search)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelConfigDeepSeekChatSearch(t *testing.T) {
|
||||
thinking, search, ok := GetModelConfig("deepseek-chat-search")
|
||||
if !ok {
|
||||
t.Fatal("expected ok for deepseek-chat-search")
|
||||
}
|
||||
if thinking || !search {
|
||||
t.Fatalf("expected thinking=false search=true, got thinking=%v search=%v", thinking, search)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelConfigDeepSeekReasonerSearch(t *testing.T) {
|
||||
thinking, search, ok := GetModelConfig("deepseek-reasoner-search")
|
||||
if !ok {
|
||||
t.Fatal("expected ok for deepseek-reasoner-search")
|
||||
}
|
||||
if !thinking || !search {
|
||||
t.Fatalf("expected both true, got thinking=%v search=%v", thinking, search)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelConfigCaseInsensitive(t *testing.T) {
|
||||
thinking, search, ok := GetModelConfig("DeepSeek-Chat")
|
||||
if !ok {
|
||||
t.Fatal("expected ok for case-insensitive deepseek-chat")
|
||||
}
|
||||
if thinking || search {
|
||||
t.Fatalf("expected no thinking/search for case-insensitive deepseek-chat")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelConfigUnknownModel(t *testing.T) {
|
||||
_, _, ok := GetModelConfig("gpt-4")
|
||||
if ok {
|
||||
t.Fatal("expected not ok for unknown model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelConfigEmpty(t *testing.T) {
|
||||
_, _, ok := GetModelConfig("")
|
||||
if ok {
|
||||
t.Fatal("expected not ok for empty model")
|
||||
}
|
||||
}
|
||||
|
||||
// ─── lower function ──────────────────────────────────────────────────
|
||||
|
||||
func TestLowerFunction(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"Hello", "hello"},
|
||||
{"ALLCAPS", "allcaps"},
|
||||
{"already-lower", "already-lower"},
|
||||
{"Mixed-CASE-123", "mixed-case-123"},
|
||||
{"", ""},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got := lower(tc.input)
|
||||
if got != tc.expected {
|
||||
t.Errorf("lower(%q) = %q, want %q", tc.input, got, tc.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Config.MarshalJSON / UnmarshalJSON roundtrip ────────────────────
|
||||
|
||||
func TestConfigJSONRoundtrip(t *testing.T) {
|
||||
cfg := Config{
|
||||
Keys: []string{"key1", "key2"},
|
||||
Accounts: []Account{{Email: "user@example.com", Password: "pass", Token: "tok"}},
|
||||
ClaudeMapping: map[string]string{
|
||||
"fast": "deepseek-chat",
|
||||
"slow": "deepseek-reasoner",
|
||||
},
|
||||
VercelSyncHash: "hash123",
|
||||
VercelSyncTime: 1234567890,
|
||||
AdditionalFields: map[string]any{
|
||||
"custom_field": "custom_value",
|
||||
},
|
||||
}
|
||||
|
||||
data, err := cfg.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("marshal error: %v", err)
|
||||
}
|
||||
|
||||
var decoded Config
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
|
||||
if len(decoded.Keys) != 2 || decoded.Keys[0] != "key1" {
|
||||
t.Fatalf("unexpected keys: %#v", decoded.Keys)
|
||||
}
|
||||
if len(decoded.Accounts) != 1 || decoded.Accounts[0].Email != "user@example.com" {
|
||||
t.Fatalf("unexpected accounts: %#v", decoded.Accounts)
|
||||
}
|
||||
if decoded.ClaudeMapping["fast"] != "deepseek-chat" {
|
||||
t.Fatalf("unexpected claude mapping: %#v", decoded.ClaudeMapping)
|
||||
}
|
||||
if decoded.VercelSyncHash != "hash123" {
|
||||
t.Fatalf("unexpected vercel sync hash: %q", decoded.VercelSyncHash)
|
||||
}
|
||||
if decoded.AdditionalFields["custom_field"] != "custom_value" {
|
||||
t.Fatalf("unexpected additional fields: %#v", decoded.AdditionalFields)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigUnmarshalJSONPreservesUnknownFields(t *testing.T) {
|
||||
raw := `{"keys":["k1"],"accounts":[],"my_custom_field":"hello","number_field":42}`
|
||||
var cfg Config
|
||||
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
if cfg.AdditionalFields["my_custom_field"] != "hello" {
|
||||
t.Fatalf("expected custom field preserved, got %#v", cfg.AdditionalFields)
|
||||
}
|
||||
// number_field should also be preserved
|
||||
if cfg.AdditionalFields["number_field"] != float64(42) {
|
||||
t.Fatalf("expected number field preserved, got %#v", cfg.AdditionalFields["number_field"])
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Config.Clone ────────────────────────────────────────────────────
|
||||
|
||||
func TestConfigCloneIsDeepCopy(t *testing.T) {
|
||||
cfg := Config{
|
||||
Keys: []string{"key1"},
|
||||
Accounts: []Account{{Email: "user@test.com", Token: "token"}},
|
||||
ClaudeMapping: map[string]string{
|
||||
"fast": "deepseek-chat",
|
||||
},
|
||||
AdditionalFields: map[string]any{"custom": "value"},
|
||||
}
|
||||
|
||||
cloned := cfg.Clone()
|
||||
|
||||
// Modify original
|
||||
cfg.Keys[0] = "modified"
|
||||
cfg.Accounts[0].Email = "modified@test.com"
|
||||
cfg.ClaudeMapping["fast"] = "modified-model"
|
||||
|
||||
// Cloned should not be affected
|
||||
if cloned.Keys[0] != "key1" {
|
||||
t.Fatalf("clone keys was affected by original change: %#v", cloned.Keys)
|
||||
}
|
||||
if cloned.Accounts[0].Email != "user@test.com" {
|
||||
t.Fatalf("clone accounts was affected: %#v", cloned.Accounts)
|
||||
}
|
||||
if cloned.ClaudeMapping["fast"] != "deepseek-chat" {
|
||||
t.Fatalf("clone claude mapping was affected: %#v", cloned.ClaudeMapping)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigCloneNilMaps(t *testing.T) {
|
||||
cfg := Config{
|
||||
Keys: []string{"k"},
|
||||
Accounts: nil,
|
||||
}
|
||||
cloned := cfg.Clone()
|
||||
if len(cloned.Keys) != 1 {
|
||||
t.Fatalf("unexpected keys length: %d", len(cloned.Keys))
|
||||
}
|
||||
if cloned.Accounts != nil {
|
||||
t.Fatalf("expected nil accounts in clone, got %#v", cloned.Accounts)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Account.Identifier edge cases ───────────────────────────────────
|
||||
|
||||
func TestAccountIdentifierPreferenceMobileOverToken(t *testing.T) {
|
||||
acc := Account{Mobile: "13800138000", Token: "tok"}
|
||||
if acc.Identifier() != "13800138000" {
|
||||
t.Fatalf("expected mobile identifier, got %q", acc.Identifier())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountIdentifierPreferenceEmailOverMobile(t *testing.T) {
|
||||
acc := Account{Email: "user@test.com", Mobile: "13800138000"}
|
||||
if acc.Identifier() != "user@test.com" {
|
||||
t.Fatalf("expected email identifier, got %q", acc.Identifier())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountIdentifierEmptyAccount(t *testing.T) {
|
||||
acc := Account{}
|
||||
if acc.Identifier() != "" {
|
||||
t.Fatalf("expected empty identifier for empty account, got %q", acc.Identifier())
|
||||
}
|
||||
}
|
||||
|
||||
// ─── normalizeConfigInput ────────────────────────────────────────────
|
||||
|
||||
func TestNormalizeConfigInputStripsQuotes(t *testing.T) {
|
||||
got := normalizeConfigInput(`"base64:abc"`)
|
||||
if strings.HasPrefix(got, `"`) || strings.HasSuffix(got, `"`) {
|
||||
t.Fatalf("expected quotes stripped, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeConfigInputStripsSingleQuotes(t *testing.T) {
|
||||
got := normalizeConfigInput("'some-value'")
|
||||
if strings.HasPrefix(got, "'") || strings.HasSuffix(got, "'") {
|
||||
t.Fatalf("expected single quotes stripped, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeConfigInputTrimsWhitespace(t *testing.T) {
|
||||
got := normalizeConfigInput(" hello ")
|
||||
if got != "hello" {
|
||||
t.Fatalf("expected trimmed, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── parseConfigString edge cases ────────────────────────────────────
|
||||
|
||||
func TestParseConfigStringPlainJSON(t *testing.T) {
|
||||
cfg, err := parseConfigString(`{"keys":["k1"],"accounts":[]}`)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(cfg.Keys) != 1 || cfg.Keys[0] != "k1" {
|
||||
t.Fatalf("unexpected keys: %#v", cfg.Keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigStringBase64Prefix(t *testing.T) {
|
||||
rawJSON := `{"keys":["base64-key"],"accounts":[]}`
|
||||
b64 := base64.StdEncoding.EncodeToString([]byte(rawJSON))
|
||||
cfg, err := parseConfigString("base64:" + b64)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(cfg.Keys) != 1 || cfg.Keys[0] != "base64-key" {
|
||||
t.Fatalf("unexpected keys: %#v", cfg.Keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigStringInvalidBase64(t *testing.T) {
|
||||
_, err := parseConfigString("base64:!!!invalid!!!")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid base64")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigStringEmptyString(t *testing.T) {
|
||||
_, err := parseConfigString("")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty string")
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Store methods ───────────────────────────────────────────────────
|
||||
|
||||
func TestStoreSnapshotReturnsClone(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"u@test.com","token":"t1"}]}`)
|
||||
store := LoadStore()
|
||||
snap := store.Snapshot()
|
||||
snap.Keys[0] = "modified"
|
||||
if store.Keys()[0] != "k1" {
|
||||
t.Fatal("snapshot modification should not affect store")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreHasAPIKeyMultipleKeys(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["key1","key2","key3"],"accounts":[]}`)
|
||||
store := LoadStore()
|
||||
if !store.HasAPIKey("key1") {
|
||||
t.Fatal("expected key1 found")
|
||||
}
|
||||
if !store.HasAPIKey("key2") {
|
||||
t.Fatal("expected key2 found")
|
||||
}
|
||||
if !store.HasAPIKey("key3") {
|
||||
t.Fatal("expected key3 found")
|
||||
}
|
||||
if store.HasAPIKey("nonexistent") {
|
||||
t.Fatal("expected nonexistent key not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreFindAccountNotFound(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"u@test.com"}]}`)
|
||||
store := LoadStore()
|
||||
_, ok := store.FindAccount("nonexistent@test.com")
|
||||
if ok {
|
||||
t.Fatal("expected account not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreCompatWideInputStrictOutputDefaultTrue(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`)
|
||||
store := LoadStore()
|
||||
if !store.CompatWideInputStrictOutput() {
|
||||
t.Fatal("expected default wide_input_strict_output=true when unset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreCompatWideInputStrictOutputCanDisable(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[],"compat":{"wide_input_strict_output":false}}`)
|
||||
store := LoadStore()
|
||||
if store.CompatWideInputStrictOutput() {
|
||||
t.Fatal("expected wide_input_strict_output=false when explicitly configured")
|
||||
}
|
||||
|
||||
snap := store.Snapshot()
|
||||
data, err := snap.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("decode failed: %v", err)
|
||||
}
|
||||
rawCompat, ok := out["compat"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected compat in marshaled output, got %#v", out)
|
||||
}
|
||||
if rawCompat["wide_input_strict_output"] != false {
|
||||
t.Fatalf("expected explicit false in compat, got %#v", rawCompat)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreIsEnvBacked(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`)
|
||||
store := LoadStore()
|
||||
if !store.IsEnvBacked() {
|
||||
t.Fatal("expected env-backed store")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreReplace(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`)
|
||||
store := LoadStore()
|
||||
newCfg := Config{
|
||||
Keys: []string{"new-key"},
|
||||
Accounts: []Account{{Email: "new@test.com"}},
|
||||
}
|
||||
if err := store.Replace(newCfg); err != nil {
|
||||
t.Fatalf("replace error: %v", err)
|
||||
}
|
||||
if !store.HasAPIKey("new-key") {
|
||||
t.Fatal("expected new key after replace")
|
||||
}
|
||||
if store.HasAPIKey("k1") {
|
||||
t.Fatal("expected old key removed after replace")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreUpdate(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`)
|
||||
store := LoadStore()
|
||||
err := store.Update(func(cfg *Config) error {
|
||||
cfg.Keys = append(cfg.Keys, "k2")
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("update error: %v", err)
|
||||
}
|
||||
if !store.HasAPIKey("k2") {
|
||||
t.Fatal("expected k2 after update")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreClaudeMapping(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[],"claude_mapping":{"fast":"deepseek-chat","slow":"deepseek-reasoner"}}`)
|
||||
store := LoadStore()
|
||||
mapping := store.ClaudeMapping()
|
||||
if mapping["fast"] != "deepseek-chat" {
|
||||
t.Fatalf("unexpected fast mapping: %q", mapping["fast"])
|
||||
}
|
||||
if mapping["slow"] != "deepseek-reasoner" {
|
||||
t.Fatalf("unexpected slow mapping: %q", mapping["slow"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreClaudeMappingEmpty(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`)
|
||||
store := LoadStore()
|
||||
mapping := store.ClaudeMapping()
|
||||
// Even without config mapping, there are defaults
|
||||
if mapping == nil {
|
||||
t.Fatal("expected non-nil mapping (may contain defaults)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreSetVercelSync(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`)
|
||||
store := LoadStore()
|
||||
if err := store.SetVercelSync("hash123", 1234567890); err != nil {
|
||||
t.Fatalf("setVercelSync error: %v", err)
|
||||
}
|
||||
snap := store.Snapshot()
|
||||
if snap.VercelSyncHash != "hash123" || snap.VercelSyncTime != 1234567890 {
|
||||
t.Fatalf("unexpected vercel sync: hash=%q time=%d", snap.VercelSyncHash, snap.VercelSyncTime)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreExportJSONAndBase64(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["export-key"],"accounts":[]}`)
|
||||
store := LoadStore()
|
||||
jsonStr, b64Str, err := store.ExportJSONAndBase64()
|
||||
if err != nil {
|
||||
t.Fatalf("export error: %v", err)
|
||||
}
|
||||
if !strings.Contains(jsonStr, "export-key") {
|
||||
t.Fatalf("expected JSON to contain key: %q", jsonStr)
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(b64Str)
|
||||
if err != nil {
|
||||
t.Fatalf("base64 decode error: %v", err)
|
||||
}
|
||||
if !strings.Contains(string(decoded), "export-key") {
|
||||
t.Fatalf("expected base64-decoded to contain key: %q", string(decoded))
|
||||
}
|
||||
}
|
||||
|
||||
// ─── OpenAIModelsResponse / ClaudeModelsResponse ─────────────────────
|
||||
|
||||
func TestOpenAIModelsResponse(t *testing.T) {
|
||||
resp := OpenAIModelsResponse()
|
||||
if resp["object"] != "list" {
|
||||
t.Fatalf("unexpected object: %v", resp["object"])
|
||||
}
|
||||
data, ok := resp["data"].([]ModelInfo)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected data type: %T", resp["data"])
|
||||
}
|
||||
if len(data) == 0 {
|
||||
t.Fatal("expected non-empty models list")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeModelsResponse(t *testing.T) {
|
||||
resp := ClaudeModelsResponse()
|
||||
if resp["object"] != "list" {
|
||||
t.Fatalf("unexpected object: %v", resp["object"])
|
||||
}
|
||||
data, ok := resp["data"].([]ModelInfo)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected data type: %T", resp["data"])
|
||||
}
|
||||
if len(data) == 0 {
|
||||
t.Fatal("expected non-empty models list")
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
@@ -70,3 +71,53 @@ func TestStoreUpdateAccountTokenKeepsOldAndNewIdentifierResolvable(t *testing.T)
|
||||
t.Fatalf("expected find by old identifier alias")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadStoreRejectsInvalidFieldType(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":"not-array","accounts":[]}`)
|
||||
store := LoadStore()
|
||||
if len(store.Keys()) != 0 || len(store.Accounts()) != 0 {
|
||||
t.Fatalf("expected empty store when config type is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigStringSupportsQuotedBase64Prefix(t *testing.T) {
|
||||
rawJSON := `{"keys":["k1"],"accounts":[{"email":"u@example.com","password":"p"}]}`
|
||||
b64 := base64.StdEncoding.EncodeToString([]byte(rawJSON))
|
||||
cfg, err := parseConfigString(`"base64:` + b64 + `"`)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected parse error: %v", err)
|
||||
}
|
||||
if len(cfg.Keys) != 1 || cfg.Keys[0] != "k1" {
|
||||
t.Fatalf("unexpected keys: %#v", cfg.Keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigStringSupportsRawURLBase64(t *testing.T) {
|
||||
rawJSON := `{"keys":["k-url"],"accounts":[]}`
|
||||
b64 := base64.RawURLEncoding.EncodeToString([]byte(rawJSON))
|
||||
cfg, err := parseConfigString(b64)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected parse error: %v", err)
|
||||
}
|
||||
if len(cfg.Keys) != 1 || cfg.Keys[0] != "k-url" {
|
||||
t.Fatalf("unexpected keys: %#v", cfg.Keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigOnVercelWithoutConfigFileFallsBackToMemory(t *testing.T) {
|
||||
t.Setenv("VERCEL", "1")
|
||||
t.Setenv("DS2API_CONFIG_JSON", "")
|
||||
t.Setenv("CONFIG_JSON", "")
|
||||
t.Setenv("DS2API_CONFIG_PATH", "testdata/does-not-exist.json")
|
||||
|
||||
cfg, fromEnv, err := loadConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !fromEnv {
|
||||
t.Fatalf("expected fromEnv=true for vercel fallback")
|
||||
}
|
||||
if len(cfg.Keys) != 0 || len(cfg.Accounts) != 0 {
|
||||
t.Fatalf("expected empty bootstrap config, got keys=%d accounts=%d", len(cfg.Keys), len(cfg.Accounts))
|
||||
}
|
||||
}
|
||||
|
||||
44
internal/config/model_alias_test.go
Normal file
44
internal/config/model_alias_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package config
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestResolveModelDirectDeepSeek(t *testing.T) {
|
||||
got, ok := ResolveModel(nil, "deepseek-chat")
|
||||
if !ok || got != "deepseek-chat" {
|
||||
t.Fatalf("expected deepseek-chat, got ok=%v model=%q", ok, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveModelAlias(t *testing.T) {
|
||||
got, ok := ResolveModel(nil, "gpt-4.1")
|
||||
if !ok || got != "deepseek-chat" {
|
||||
t.Fatalf("expected alias gpt-4.1 -> deepseek-chat, got ok=%v model=%q", ok, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveModelHeuristicReasoner(t *testing.T) {
|
||||
got, ok := ResolveModel(nil, "o3-super")
|
||||
if !ok || got != "deepseek-reasoner" {
|
||||
t.Fatalf("expected heuristic reasoner, got ok=%v model=%q", ok, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveModelUnknown(t *testing.T) {
|
||||
_, ok := ResolveModel(nil, "totally-custom-model")
|
||||
if ok {
|
||||
t.Fatal("expected unknown model to fail resolve")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeModelsResponsePaginationFields(t *testing.T) {
|
||||
resp := ClaudeModelsResponse()
|
||||
if _, ok := resp["first_id"]; !ok {
|
||||
t.Fatalf("expected first_id in response: %#v", resp)
|
||||
}
|
||||
if _, ok := resp["last_id"]; !ok {
|
||||
t.Fatalf("expected last_id in response: %#v", resp)
|
||||
}
|
||||
if _, ok := resp["has_more"]; !ok {
|
||||
t.Fatalf("expected has_more in response: %#v", resp)
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
package config
|
||||
|
||||
import "strings"
|
||||
|
||||
type ModelInfo struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
@@ -71,6 +73,91 @@ func GetModelConfig(model string) (thinking bool, search bool, ok bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func IsSupportedDeepSeekModel(model string) bool {
|
||||
_, _, ok := GetModelConfig(model)
|
||||
return ok
|
||||
}
|
||||
|
||||
// DefaultModelAliases returns the built-in mapping from popular
// third-party model names to DeepSeek model IDs. A fresh map is
// returned on every call, so callers may mutate the result safely.
func DefaultModelAliases() map[string]string {
	aliases := make(map[string]string, 21)
	// OpenAI chat family → deepseek-chat
	for _, id := range []string{"gpt-4o", "gpt-4.1", "gpt-4.1-mini", "gpt-4.1-nano", "gpt-5", "gpt-5-mini"} {
		aliases[id] = "deepseek-chat"
	}
	// OpenAI reasoning family → deepseek-reasoner
	for _, id := range []string{"gpt-5-codex", "o1", "o1-mini", "o3", "o3-mini"} {
		aliases[id] = "deepseek-reasoner"
	}
	// Anthropic chat-tier models → deepseek-chat
	for _, id := range []string{"claude-sonnet-4-5", "claude-haiku-4-5", "claude-3-5-sonnet", "claude-3-5-haiku"} {
		aliases[id] = "deepseek-chat"
	}
	// Anthropic opus-tier models → deepseek-reasoner
	for _, id := range []string{"claude-opus-4-6", "claude-3-opus"} {
		aliases[id] = "deepseek-reasoner"
	}
	// Google / Meta / Alibaba → deepseek-chat
	for _, id := range []string{"gemini-2.5-pro", "gemini-2.5-flash", "llama-3.1-70b-instruct", "qwen-max"} {
		aliases[id] = "deepseek-chat"
	}
	return aliases
}
|
||||
|
||||
func ResolveModel(store *Store, requested string) (string, bool) {
|
||||
model := lower(strings.TrimSpace(requested))
|
||||
if model == "" {
|
||||
return "", false
|
||||
}
|
||||
if IsSupportedDeepSeekModel(model) {
|
||||
return model, true
|
||||
}
|
||||
aliases := DefaultModelAliases()
|
||||
if store != nil {
|
||||
for k, v := range store.ModelAliases() {
|
||||
aliases[lower(strings.TrimSpace(k))] = lower(strings.TrimSpace(v))
|
||||
}
|
||||
}
|
||||
if mapped, ok := aliases[model]; ok && IsSupportedDeepSeekModel(mapped) {
|
||||
return mapped, true
|
||||
}
|
||||
if strings.HasPrefix(model, "deepseek-") {
|
||||
return "", false
|
||||
}
|
||||
|
||||
knownFamily := false
|
||||
for _, prefix := range []string{
|
||||
"gpt-", "o1", "o3", "claude-", "gemini-", "llama-", "qwen-", "mistral-", "command-",
|
||||
} {
|
||||
if strings.HasPrefix(model, prefix) {
|
||||
knownFamily = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !knownFamily {
|
||||
return "", false
|
||||
}
|
||||
|
||||
useReasoner := strings.Contains(model, "reason") ||
|
||||
strings.Contains(model, "reasoner") ||
|
||||
strings.HasPrefix(model, "o1") ||
|
||||
strings.HasPrefix(model, "o3") ||
|
||||
strings.Contains(model, "opus") ||
|
||||
strings.Contains(model, "r1")
|
||||
useSearch := strings.Contains(model, "search")
|
||||
|
||||
switch {
|
||||
case useReasoner && useSearch:
|
||||
return "deepseek-reasoner-search", true
|
||||
case useReasoner:
|
||||
return "deepseek-reasoner", true
|
||||
case useSearch:
|
||||
return "deepseek-chat-search", true
|
||||
default:
|
||||
return "deepseek-chat", true
|
||||
}
|
||||
}
|
||||
|
||||
func lower(s string) string {
|
||||
b := []byte(s)
|
||||
for i, c := range b {
|
||||
@@ -85,6 +172,28 @@ func OpenAIModelsResponse() map[string]any {
|
||||
return map[string]any{"object": "list", "data": DeepSeekModels}
|
||||
}
|
||||
|
||||
func ClaudeModelsResponse() map[string]any {
|
||||
return map[string]any{"object": "list", "data": ClaudeModels}
|
||||
func OpenAIModelByID(store *Store, id string) (ModelInfo, bool) {
|
||||
canonical, ok := ResolveModel(store, id)
|
||||
if !ok {
|
||||
return ModelInfo{}, false
|
||||
}
|
||||
for _, model := range DeepSeekModels {
|
||||
if model.ID == canonical {
|
||||
return model, true
|
||||
}
|
||||
}
|
||||
return ModelInfo{}, false
|
||||
}
|
||||
|
||||
func ClaudeModelsResponse() map[string]any {
|
||||
resp := map[string]any{"object": "list", "data": ClaudeModels}
|
||||
if len(ClaudeModels) > 0 {
|
||||
resp["first_id"] = ClaudeModels[0].ID
|
||||
resp["last_id"] = ClaudeModels[len(ClaudeModels)-1].ID
|
||||
} else {
|
||||
resp["first_id"] = nil
|
||||
resp["last_id"] = nil
|
||||
}
|
||||
resp["has_more"] = false
|
||||
return resp
|
||||
}
|
||||
|
||||
165
internal/deepseek/deepseek_edge_test.go
Normal file
165
internal/deepseek/deepseek_edge_test.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// ─── toFloat64 edge cases ────────────────────────────────────────────
|
||||
|
||||
func TestToFloat64FromFloat64(t *testing.T) {
|
||||
if got := toFloat64(float64(3.14), 0); got != 3.14 {
|
||||
t.Fatalf("expected 3.14, got %f", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToFloat64FromInt(t *testing.T) {
|
||||
if got := toFloat64(42, 0); got != 42.0 {
|
||||
t.Fatalf("expected 42.0, got %f", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToFloat64FromInt64(t *testing.T) {
|
||||
if got := toFloat64(int64(100), 0); got != 100.0 {
|
||||
t.Fatalf("expected 100.0, got %f", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToFloat64FromStringDefault(t *testing.T) {
|
||||
if got := toFloat64("42", 99.0); got != 99.0 {
|
||||
t.Fatalf("expected default 99.0, got %f", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToFloat64FromNilDefault(t *testing.T) {
|
||||
if got := toFloat64(nil, 5.5); got != 5.5 {
|
||||
t.Fatalf("expected default 5.5, got %f", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToFloat64FromBoolDefault(t *testing.T) {
|
||||
if got := toFloat64(true, 1.0); got != 1.0 {
|
||||
t.Fatalf("expected default 1.0, got %f", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── toInt64 edge cases ──────────────────────────────────────────────
|
||||
|
||||
func TestToInt64FromFloat64(t *testing.T) {
|
||||
if got := toInt64(float64(42.9), 0); got != 42 {
|
||||
t.Fatalf("expected 42, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToInt64FromInt(t *testing.T) {
|
||||
if got := toInt64(42, 0); got != 42 {
|
||||
t.Fatalf("expected 42, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToInt64FromInt64(t *testing.T) {
|
||||
if got := toInt64(int64(100), 0); got != 100 {
|
||||
t.Fatalf("expected 100, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToInt64FromStringDefault(t *testing.T) {
|
||||
if got := toInt64("42", 99); got != 99 {
|
||||
t.Fatalf("expected default 99, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToInt64FromNilDefault(t *testing.T) {
|
||||
if got := toInt64(nil, 7); got != 7 {
|
||||
t.Fatalf("expected default 7, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── BuildPowHeader edge cases ───────────────────────────────────────
|
||||
|
||||
func TestBuildPowHeaderBasicChallenge(t *testing.T) {
|
||||
challenge := map[string]any{
|
||||
"algorithm": "DeepSeekHashV1",
|
||||
"challenge": "abc123",
|
||||
"salt": "salt456",
|
||||
"signature": "sig789",
|
||||
"target_path": "/path",
|
||||
}
|
||||
result, err := BuildPowHeader(challenge, 42)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result == "" {
|
||||
t.Fatal("expected non-empty result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPowHeaderEmptyChallenge(t *testing.T) {
|
||||
result, err := BuildPowHeader(map[string]any{}, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// Should produce a base64 encoded JSON with nil values
|
||||
if result == "" {
|
||||
t.Fatal("expected non-empty result for empty challenge")
|
||||
}
|
||||
}
|
||||
|
||||
// ─── PowSolver pool size ─────────────────────────────────────────────
|
||||
|
||||
func TestPowPoolSizeFromEnvDefault(t *testing.T) {
|
||||
t.Setenv("DS2API_POW_POOL_SIZE", "")
|
||||
got := powPoolSizeFromEnv()
|
||||
if got < 1 {
|
||||
t.Fatalf("expected positive default pool size, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPowPoolSizeFromEnvInvalid(t *testing.T) {
|
||||
t.Setenv("DS2API_POW_POOL_SIZE", "abc")
|
||||
got := powPoolSizeFromEnv()
|
||||
if got < 1 {
|
||||
t.Fatalf("expected positive default for invalid, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPowPoolSizeFromEnvSpecificValue(t *testing.T) {
|
||||
t.Setenv("DS2API_POW_POOL_SIZE", "5")
|
||||
got := powPoolSizeFromEnv()
|
||||
if got != 5 {
|
||||
t.Fatalf("expected 5, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── NewClient ───────────────────────────────────────────────────────
|
||||
|
||||
func TestNewClientInitialState(t *testing.T) {
|
||||
client := NewClient(nil, nil)
|
||||
if client.powSolver == nil {
|
||||
t.Fatal("expected powSolver to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClientPreloadPowIdempotent(t *testing.T) {
|
||||
t.Setenv("DS2API_POW_POOL_SIZE", "1")
|
||||
client := NewClient(nil, nil)
|
||||
if err := client.PreloadPow(context.Background()); err != nil {
|
||||
t.Fatalf("first preload failed: %v", err)
|
||||
}
|
||||
if err := client.PreloadPow(context.Background()); err != nil {
|
||||
t.Fatalf("second preload failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── PowSolver init and module pool ──────────────────────────────────
|
||||
|
||||
func TestPowSolverPoolSizeMatchesEnv(t *testing.T) {
|
||||
t.Setenv("DS2API_POW_POOL_SIZE", "2")
|
||||
solver := NewPowSolver("test.wasm")
|
||||
if err := solver.init(context.Background()); err != nil {
|
||||
t.Fatalf("init failed: %v", err)
|
||||
}
|
||||
if cap(solver.pool) != 2 {
|
||||
t.Fatalf("expected pool capacity 2, got %d", cap(solver.pool))
|
||||
}
|
||||
}
|
||||
@@ -92,7 +92,7 @@ func cors(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key, X-Ds2-Target-Account, X-Vercel-Protection-Bypass")
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
|
||||
140
internal/sse/consumer_edge_test.go
Normal file
140
internal/sse/consumer_edge_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package sse
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// ─── CollectStream edge cases ────────────────────────────────────────
|
||||
|
||||
func makeHTTPResponse(body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamEmpty(t *testing.T) {
|
||||
resp := makeHTTPResponse("")
|
||||
result := CollectStream(resp, false, false)
|
||||
if result.Text != "" || result.Thinking != "" {
|
||||
t.Fatalf("expected empty result, got text=%q think=%q", result.Text, result.Thinking)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamTextOnly(t *testing.T) {
|
||||
resp := makeHTTPResponse(
|
||||
"data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" +
|
||||
"data: {\"p\":\"response/content\",\"v\":\" World\"}\n" +
|
||||
"data: [DONE]\n",
|
||||
)
|
||||
result := CollectStream(resp, false, false)
|
||||
if result.Text != "Hello World" {
|
||||
t.Fatalf("expected 'Hello World', got %q", result.Text)
|
||||
}
|
||||
if result.Thinking != "" {
|
||||
t.Fatalf("expected no thinking, got %q", result.Thinking)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamThinkingAndText(t *testing.T) {
|
||||
resp := makeHTTPResponse(
|
||||
"data: {\"p\":\"response/thinking_content\",\"v\":\"Thinking...\"}\n" +
|
||||
"data: {\"p\":\"response/content\",\"v\":\"Answer\"}\n" +
|
||||
"data: [DONE]\n",
|
||||
)
|
||||
result := CollectStream(resp, true, true)
|
||||
if result.Thinking != "Thinking..." {
|
||||
t.Fatalf("expected 'Thinking...', got %q", result.Thinking)
|
||||
}
|
||||
if result.Text != "Answer" {
|
||||
t.Fatalf("expected 'Answer', got %q", result.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamOnlyThinking(t *testing.T) {
|
||||
resp := makeHTTPResponse(
|
||||
"data: {\"p\":\"response/thinking_content\",\"v\":\"Only thinking\"}\n" +
|
||||
"data: [DONE]\n",
|
||||
)
|
||||
result := CollectStream(resp, true, true)
|
||||
if result.Thinking != "Only thinking" {
|
||||
t.Fatalf("expected 'Only thinking', got %q", result.Thinking)
|
||||
}
|
||||
if result.Text != "" {
|
||||
t.Fatalf("expected empty text, got %q", result.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamSkipsInvalidLines(t *testing.T) {
|
||||
resp := makeHTTPResponse(
|
||||
"event: comment\n" +
|
||||
"data: invalid_json\n" +
|
||||
"data: {\"p\":\"response/content\",\"v\":\"valid\"}\n" +
|
||||
"data: [DONE]\n",
|
||||
)
|
||||
result := CollectStream(resp, false, false)
|
||||
if result.Text != "valid" {
|
||||
t.Fatalf("expected 'valid', got %q", result.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamWithFragments(t *testing.T) {
|
||||
resp := makeHTTPResponse(
|
||||
"data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"THINK\",\"content\":\"Think\"}]}\n" +
|
||||
"data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"RESPONSE\",\"content\":\"Done\"}]}\n" +
|
||||
"data: [DONE]\n",
|
||||
)
|
||||
result := CollectStream(resp, true, true)
|
||||
if result.Thinking != "Think" {
|
||||
t.Fatalf("expected 'Think' thinking, got %q", result.Thinking)
|
||||
}
|
||||
if result.Text != "Done" {
|
||||
t.Fatalf("expected 'Done' text, got %q", result.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamWithCitation(t *testing.T) {
|
||||
resp := makeHTTPResponse(
|
||||
"data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" +
|
||||
"data: {\"p\":\"response/content\",\"v\":\"[citation:1] cited text\"}\n" +
|
||||
"data: {\"p\":\"response/content\",\"v\":\" more\"}\n" +
|
||||
"data: [DONE]\n",
|
||||
)
|
||||
result := CollectStream(resp, false, false)
|
||||
// CollectStream does NOT filter citations (that's done by the adapters)
|
||||
// So citations are passed through as-is
|
||||
if !strings.Contains(result.Text, "[citation:1]") {
|
||||
t.Fatalf("expected citations to be passed through, got %q", result.Text)
|
||||
}
|
||||
if result.Text != "Hello[citation:1] cited text more" {
|
||||
t.Fatalf("expected full text with citation, got %q", result.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamMultipleThinkingChunks(t *testing.T) {
|
||||
resp := makeHTTPResponse(
|
||||
"data: {\"p\":\"response/thinking_content\",\"v\":\"part1\"}\n" +
|
||||
"data: {\"p\":\"response/thinking_content\",\"v\":\" part2\"}\n" +
|
||||
"data: {\"p\":\"response/content\",\"v\":\"answer\"}\n" +
|
||||
"data: [DONE]\n",
|
||||
)
|
||||
result := CollectStream(resp, true, true)
|
||||
if result.Thinking != "part1 part2" {
|
||||
t.Fatalf("expected 'part1 part2', got %q", result.Thinking)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamStatusFinished(t *testing.T) {
|
||||
resp := makeHTTPResponse(
|
||||
"data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" +
|
||||
"data: {\"p\":\"response/status\",\"v\":\"FINISHED\"}\n",
|
||||
)
|
||||
result := CollectStream(resp, false, false)
|
||||
if result.Text != "Hello" {
|
||||
t.Fatalf("expected 'Hello', got %q", result.Text)
|
||||
}
|
||||
}
|
||||
70
internal/sse/line_edge_test.go
Normal file
70
internal/sse/line_edge_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package sse
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseDeepSeekContentLineNotParsed(t *testing.T) {
|
||||
res := ParseDeepSeekContentLine([]byte("not a data line"), false, "text")
|
||||
if res.Parsed {
|
||||
t.Fatal("expected not parsed")
|
||||
}
|
||||
if res.NextType != "text" {
|
||||
t.Fatalf("expected nextType preserved, got %q", res.NextType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekContentLinePreservesNextType(t *testing.T) {
|
||||
res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/thinking_content","v":"think"}`), true, "thinking")
|
||||
if !res.Parsed || res.Stop {
|
||||
t.Fatalf("expected parsed non-stop: %#v", res)
|
||||
}
|
||||
if len(res.Parts) != 1 || res.Parts[0].Type != "thinking" {
|
||||
t.Fatalf("unexpected parts: %#v", res.Parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekContentLineFragmentSwitchType(t *testing.T) {
|
||||
res := ParseDeepSeekContentLine(
|
||||
[]byte(`data: {"p":"response/fragments","o":"APPEND","v":[{"type":"RESPONSE","content":"hi"}]}`),
|
||||
true, "thinking",
|
||||
)
|
||||
if !res.Parsed || res.Stop {
|
||||
t.Fatalf("expected parsed non-stop: %#v", res)
|
||||
}
|
||||
if res.NextType != "text" {
|
||||
t.Fatalf("expected nextType text after RESPONSE fragment, got %q", res.NextType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekContentLineContentFilterMessage(t *testing.T) {
|
||||
res := ParseDeepSeekContentLine([]byte(`data: {"code":"content_filter"}`), false, "text")
|
||||
if !res.ContentFilter {
|
||||
t.Fatal("expected content filter flag")
|
||||
}
|
||||
if res.ErrorMessage == "" {
|
||||
t.Fatal("expected error message on content filter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekContentLineErrorObjectFormat(t *testing.T) {
|
||||
res := ParseDeepSeekContentLine([]byte(`data: {"error":{"message":"rate limit","code":429}}`), false, "text")
|
||||
if !res.Parsed || !res.Stop {
|
||||
t.Fatalf("expected parsed stop: %#v", res)
|
||||
}
|
||||
if res.ErrorMessage == "" {
|
||||
t.Fatal("expected non-empty error message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekContentLineInvalidJSON(t *testing.T) {
|
||||
res := ParseDeepSeekContentLine([]byte("data: {broken"), false, "text")
|
||||
if res.Parsed {
|
||||
t.Fatal("expected not parsed for broken JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekContentLineEmptyBytes(t *testing.T) {
|
||||
res := ParseDeepSeekContentLine([]byte{}, false, "text")
|
||||
if res.Parsed {
|
||||
t.Fatal("expected not parsed for empty bytes")
|
||||
}
|
||||
}
|
||||
631
internal/sse/parser_edge_test.go
Normal file
631
internal/sse/parser_edge_test.go
Normal file
@@ -0,0 +1,631 @@
|
||||
package sse
|
||||
|
||||
import "testing"
|
||||
|
||||
// ─── ParseDeepSeekSSELine edge cases ─────────────────────────────────
|
||||
|
||||
func TestParseDeepSeekSSELineEmptyLine(t *testing.T) {
|
||||
_, _, ok := ParseDeepSeekSSELine([]byte(""))
|
||||
if ok {
|
||||
t.Fatal("expected not parsed for empty line")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekSSELineNoDataPrefix(t *testing.T) {
|
||||
_, _, ok := ParseDeepSeekSSELine([]byte("event: message"))
|
||||
if ok {
|
||||
t.Fatal("expected not parsed for non-data line")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekSSELineInvalidJSON(t *testing.T) {
|
||||
_, _, ok := ParseDeepSeekSSELine([]byte("data: {invalid json"))
|
||||
if ok {
|
||||
t.Fatal("expected not parsed for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekSSELineWhitespaceOnly(t *testing.T) {
|
||||
_, _, ok := ParseDeepSeekSSELine([]byte(" "))
|
||||
if ok {
|
||||
t.Fatal("expected not parsed for whitespace-only line")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekSSELineDataWithExtraSpaces(t *testing.T) {
|
||||
chunk, done, ok := ParseDeepSeekSSELine([]byte(`data: {"v":"hello"} `))
|
||||
if !ok || done {
|
||||
t.Fatalf("expected parsed chunk for spaced data line")
|
||||
}
|
||||
if chunk["v"] != "hello" {
|
||||
t.Fatalf("unexpected chunk: %#v", chunk)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── shouldSkipPath edge cases ───────────────────────────────────────
|
||||
|
||||
func TestShouldSkipPathQuasiStatus(t *testing.T) {
|
||||
if !shouldSkipPath("response/quasi_status") {
|
||||
t.Fatal("expected skip for quasi_status path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSkipPathElapsedSecs(t *testing.T) {
|
||||
if !shouldSkipPath("response/elapsed_secs") {
|
||||
t.Fatal("expected skip for elapsed_secs path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSkipPathTokenUsage(t *testing.T) {
|
||||
if !shouldSkipPath("response/token_usage") {
|
||||
t.Fatal("expected skip for token_usage path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSkipPathPendingFragment(t *testing.T) {
|
||||
if !shouldSkipPath("response/pending_fragment") {
|
||||
t.Fatal("expected skip for pending_fragment path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSkipPathConversationMode(t *testing.T) {
|
||||
if !shouldSkipPath("response/conversation_mode") {
|
||||
t.Fatal("expected skip for conversation_mode path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSkipPathSearchStatus(t *testing.T) {
|
||||
if !shouldSkipPath("response/search_status") {
|
||||
t.Fatal("expected skip for search_status path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSkipPathFragmentStatus(t *testing.T) {
|
||||
if !shouldSkipPath("response/fragments/-1/status") {
|
||||
t.Fatal("expected skip for fragment -1 status")
|
||||
}
|
||||
if !shouldSkipPath("response/fragments/-2/status") {
|
||||
t.Fatal("expected skip for fragment -2 status")
|
||||
}
|
||||
if !shouldSkipPath("response/fragments/-3/status") {
|
||||
t.Fatal("expected skip for fragment -3 status")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSkipPathRegularContent(t *testing.T) {
|
||||
if shouldSkipPath("response/content") {
|
||||
t.Fatal("expected not skip for content path")
|
||||
}
|
||||
if shouldSkipPath("response/thinking_content") {
|
||||
t.Fatal("expected not skip for thinking_content path")
|
||||
}
|
||||
}
|
||||
|
||||
// ─── ParseSSEChunkForContent edge cases ──────────────────────────────
|
||||
|
||||
func TestParseSSEChunkForContentNoVField(t *testing.T) {
|
||||
parts, finished, nextType := ParseSSEChunkForContent(map[string]any{"p": "response/content"}, false, "text")
|
||||
if finished {
|
||||
t.Fatal("expected not finished")
|
||||
}
|
||||
if len(parts) != 0 {
|
||||
t.Fatalf("expected no parts when v is missing, got %#v", parts)
|
||||
}
|
||||
if nextType != "text" {
|
||||
t.Fatalf("expected type preserved, got %q", nextType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentSkippedPath(t *testing.T) {
|
||||
parts, finished, nextType := ParseSSEChunkForContent(map[string]any{
|
||||
"p": "response/token_usage",
|
||||
"v": "some data",
|
||||
}, false, "text")
|
||||
if finished || len(parts) > 0 {
|
||||
t.Fatalf("expected skipped path to produce no output")
|
||||
}
|
||||
if nextType != "text" {
|
||||
t.Fatalf("expected type preserved for skipped path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentFinishedStatus(t *testing.T) {
|
||||
parts, finished, _ := ParseSSEChunkForContent(map[string]any{
|
||||
"p": "response/status",
|
||||
"v": "FINISHED",
|
||||
}, false, "text")
|
||||
if !finished {
|
||||
t.Fatal("expected finished on status FINISHED")
|
||||
}
|
||||
if len(parts) != 0 {
|
||||
t.Fatalf("expected no parts on finished, got %d", len(parts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentStatusNotFinished(t *testing.T) {
|
||||
parts, finished, _ := ParseSSEChunkForContent(map[string]any{
|
||||
"p": "response/status",
|
||||
"v": "IN_PROGRESS",
|
||||
}, false, "text")
|
||||
if finished {
|
||||
t.Fatal("expected not finished for non-FINISHED status")
|
||||
}
|
||||
if len(parts) != 1 || parts[0].Text != "IN_PROGRESS" {
|
||||
t.Fatalf("expected content for non-FINISHED status, got %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentEmptyStringV(t *testing.T) {
|
||||
parts, finished, _ := ParseSSEChunkForContent(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": "",
|
||||
}, false, "text")
|
||||
if finished {
|
||||
t.Fatal("expected not finished")
|
||||
}
|
||||
if len(parts) != 0 {
|
||||
t.Fatalf("expected no parts for empty string v, got %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentFinishedOnEmptyPath(t *testing.T) {
|
||||
parts, finished, _ := ParseSSEChunkForContent(map[string]any{
|
||||
"p": "",
|
||||
"v": "FINISHED",
|
||||
}, false, "text")
|
||||
if !finished {
|
||||
t.Fatal("expected finished on empty path with FINISHED value")
|
||||
}
|
||||
if len(parts) != 0 {
|
||||
t.Fatalf("expected no parts on finished")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentFinishedOnStatusPath(t *testing.T) {
|
||||
_, finished, _ := ParseSSEChunkForContent(map[string]any{
|
||||
"p": "status",
|
||||
"v": "FINISHED",
|
||||
}, false, "text")
|
||||
if !finished {
|
||||
t.Fatal("expected finished on status path with FINISHED value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentThinkingPathEmptyPath(t *testing.T) {
|
||||
parts, _, nextType := ParseSSEChunkForContent(map[string]any{
|
||||
"v": "some thought",
|
||||
}, true, "thinking")
|
||||
if len(parts) != 1 || parts[0].Type != "thinking" {
|
||||
t.Fatalf("expected thinking part on empty path, got %#v", parts)
|
||||
}
|
||||
if nextType != "thinking" {
|
||||
t.Fatalf("expected nextType thinking, got %q", nextType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentThinkingEnabledTextType(t *testing.T) {
|
||||
parts, _, nextType := ParseSSEChunkForContent(map[string]any{
|
||||
"v": "text content",
|
||||
}, true, "text")
|
||||
if len(parts) != 1 || parts[0].Type != "text" {
|
||||
t.Fatalf("expected text part when currentType=text, got %#v", parts)
|
||||
}
|
||||
if nextType != "text" {
|
||||
t.Fatalf("expected nextType text, got %q", nextType)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── ParseSSEChunkForContent: fragments path with THINK type ─────────
|
||||
|
||||
func TestParseSSEChunkForContentFragmentsAppendThink(t *testing.T) {
|
||||
chunk := map[string]any{
|
||||
"p": "response/fragments",
|
||||
"o": "APPEND",
|
||||
"v": []any{
|
||||
map[string]any{
|
||||
"type": "THINK",
|
||||
"content": "深入思考...",
|
||||
},
|
||||
},
|
||||
}
|
||||
parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "text")
|
||||
if finished {
|
||||
t.Fatal("expected not finished")
|
||||
}
|
||||
if nextType != "thinking" {
|
||||
t.Fatalf("expected nextType thinking, got %q", nextType)
|
||||
}
|
||||
if len(parts) != 1 || parts[0].Type != "thinking" || parts[0].Text != "深入思考..." {
|
||||
t.Fatalf("unexpected parts: %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentFragmentsAppendEmptyContent(t *testing.T) {
|
||||
chunk := map[string]any{
|
||||
"p": "response/fragments",
|
||||
"o": "APPEND",
|
||||
"v": []any{
|
||||
map[string]any{
|
||||
"type": "RESPONSE",
|
||||
"content": "",
|
||||
},
|
||||
},
|
||||
}
|
||||
parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "thinking")
|
||||
if finished {
|
||||
t.Fatal("expected not finished")
|
||||
}
|
||||
if nextType != "text" {
|
||||
t.Fatalf("expected nextType text, got %q", nextType)
|
||||
}
|
||||
if len(parts) != 0 {
|
||||
t.Fatalf("expected no parts for empty content, got %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentFragmentsAppendDefaultType(t *testing.T) {
|
||||
chunk := map[string]any{
|
||||
"p": "response/fragments",
|
||||
"o": "APPEND",
|
||||
"v": []any{
|
||||
map[string]any{
|
||||
"type": "UNKNOWN",
|
||||
"content": "some text",
|
||||
},
|
||||
},
|
||||
}
|
||||
parts, _, _ := ParseSSEChunkForContent(chunk, true, "text")
|
||||
if len(parts) != 1 || parts[0].Type != "text" {
|
||||
t.Fatalf("expected text type for unknown fragment type, got %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentFragmentsAppendNonArray(t *testing.T) {
|
||||
chunk := map[string]any{
|
||||
"p": "response/fragments",
|
||||
"o": "APPEND",
|
||||
"v": "not an array",
|
||||
}
|
||||
parts, finished, _ := ParseSSEChunkForContent(chunk, true, "text")
|
||||
if finished {
|
||||
t.Fatal("expected not finished")
|
||||
}
|
||||
// "not an array" should be treated as string value at the end
|
||||
if len(parts) != 1 || parts[0].Text != "not an array" {
|
||||
t.Fatalf("unexpected parts: %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentFragmentsAppendNonMap(t *testing.T) {
|
||||
chunk := map[string]any{
|
||||
"p": "response/fragments",
|
||||
"o": "APPEND",
|
||||
"v": []any{"string item"},
|
||||
}
|
||||
parts, _, _ := ParseSSEChunkForContent(chunk, false, "text")
|
||||
// Non-map items in fragment array are skipped; the []any itself is handled later
|
||||
_ = parts // just checking it doesn't panic
|
||||
}
|
||||
|
||||
// ─── ParseSSEChunkForContent: response path with nested fragment ─────
|
||||
|
||||
func TestParseSSEChunkForContentResponsePathFragmentsAppend(t *testing.T) {
|
||||
chunk := map[string]any{
|
||||
"p": "response",
|
||||
"v": []any{
|
||||
map[string]any{
|
||||
"p": "fragments",
|
||||
"o": "APPEND",
|
||||
"v": []any{
|
||||
map[string]any{
|
||||
"type": "THINKING",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
_, _, nextType := ParseSSEChunkForContent(chunk, true, "text")
|
||||
if nextType != "thinking" {
|
||||
t.Fatalf("expected nextType thinking from response path fragments, got %q", nextType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentResponsePathResponseFragment(t *testing.T) {
|
||||
chunk := map[string]any{
|
||||
"p": "response",
|
||||
"v": []any{
|
||||
map[string]any{
|
||||
"p": "fragments",
|
||||
"o": "APPEND",
|
||||
"v": []any{
|
||||
map[string]any{
|
||||
"type": "RESPONSE",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
_, _, nextType := ParseSSEChunkForContent(chunk, true, "thinking")
|
||||
if nextType != "text" {
|
||||
t.Fatalf("expected nextType text for RESPONSE fragment, got %q", nextType)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── ParseSSEChunkForContent: map value with wrapped response ────────
|
||||
|
||||
func TestParseSSEChunkForContentMapValueWithFragments(t *testing.T) {
|
||||
chunk := map[string]any{
|
||||
"v": map[string]any{
|
||||
"response": map[string]any{
|
||||
"fragments": []any{
|
||||
map[string]any{
|
||||
"type": "THINK",
|
||||
"content": "思考...",
|
||||
},
|
||||
map[string]any{
|
||||
"type": "RESPONSE",
|
||||
"content": "回答...",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "text")
|
||||
if finished {
|
||||
t.Fatal("expected not finished")
|
||||
}
|
||||
if nextType != "text" {
|
||||
t.Fatalf("expected nextType text after RESPONSE, got %q", nextType)
|
||||
}
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("expected 2 parts, got %d: %#v", len(parts), parts)
|
||||
}
|
||||
if parts[0].Type != "thinking" || parts[0].Text != "思考..." {
|
||||
t.Fatalf("first part mismatch: %#v", parts[0])
|
||||
}
|
||||
if parts[1].Type != "text" || parts[1].Text != "回答..." {
|
||||
t.Fatalf("second part mismatch: %#v", parts[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentMapValueDirectFragments(t *testing.T) {
|
||||
chunk := map[string]any{
|
||||
"v": map[string]any{
|
||||
"fragments": []any{
|
||||
map[string]any{
|
||||
"type": "RESPONSE",
|
||||
"content": "直接回答",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
parts, _, _ := ParseSSEChunkForContent(chunk, false, "text")
|
||||
if len(parts) != 1 || parts[0].Text != "直接回答" || parts[0].Type != "text" {
|
||||
t.Fatalf("unexpected parts for direct fragments: %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentMapValueUnknownType(t *testing.T) {
|
||||
chunk := map[string]any{
|
||||
"v": map[string]any{
|
||||
"fragments": []any{
|
||||
map[string]any{
|
||||
"type": "CUSTOM",
|
||||
"content": "custom content",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
parts, _, _ := ParseSSEChunkForContent(chunk, false, "text")
|
||||
if len(parts) != 1 || parts[0].Type != "text" {
|
||||
t.Fatalf("expected partType fallback for unknown type, got %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentMapValueEmptyFragmentContent(t *testing.T) {
|
||||
chunk := map[string]any{
|
||||
"v": map[string]any{
|
||||
"fragments": []any{
|
||||
map[string]any{
|
||||
"type": "RESPONSE",
|
||||
"content": "",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
parts, _, _ := ParseSSEChunkForContent(chunk, false, "text")
|
||||
if len(parts) != 0 {
|
||||
t.Fatalf("expected no parts for empty fragment content, got %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── ParseSSEChunkForContent: fragments/-1/content path ──────────────
|
||||
|
||||
func TestParseSSEChunkForContentFragmentContentPathInheritsType(t *testing.T) {
|
||||
chunk := map[string]any{
|
||||
"p": "response/fragments/-1/content",
|
||||
"v": "继续思考",
|
||||
}
|
||||
parts, _, _ := ParseSSEChunkForContent(chunk, true, "thinking")
|
||||
if len(parts) != 1 || parts[0].Type != "thinking" {
|
||||
t.Fatalf("expected inherited thinking type, got %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── IsCitation edge cases ───────────────────────────────────────────
|
||||
|
||||
func TestIsCitationWithLeadingWhitespace(t *testing.T) {
|
||||
if !IsCitation(" [citation:2] text") {
|
||||
t.Fatal("expected citation true with leading whitespace")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCitationEmpty(t *testing.T) {
|
||||
if IsCitation("") {
|
||||
t.Fatal("expected citation false for empty string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCitationSimilarPrefix(t *testing.T) {
|
||||
if IsCitation("[cite:1] text") {
|
||||
t.Fatal("expected citation false for [cite: prefix")
|
||||
}
|
||||
}
|
||||
|
||||
// ─── extractContentRecursive edge cases ──────────────────────────────
|
||||
|
||||
func TestExtractContentRecursiveFinishedStatus(t *testing.T) {
|
||||
items := []any{
|
||||
map[string]any{"p": "status", "v": "FINISHED"},
|
||||
}
|
||||
parts, finished := extractContentRecursive(items, "text")
|
||||
if !finished {
|
||||
t.Fatal("expected finished on status FINISHED")
|
||||
}
|
||||
if len(parts) != 0 {
|
||||
t.Fatalf("expected no parts, got %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContentRecursiveSkipsPath(t *testing.T) {
|
||||
items := []any{
|
||||
map[string]any{"p": "token_usage", "v": "data"},
|
||||
}
|
||||
parts, finished := extractContentRecursive(items, "text")
|
||||
if finished {
|
||||
t.Fatal("expected not finished")
|
||||
}
|
||||
if len(parts) != 0 {
|
||||
t.Fatalf("expected no parts for skipped path, got %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContentRecursiveContentField(t *testing.T) {
|
||||
items := []any{
|
||||
map[string]any{"p": "x", "v": "val", "content": "actual content", "type": "RESPONSE"},
|
||||
}
|
||||
parts, _ := extractContentRecursive(items, "text")
|
||||
if len(parts) != 1 || parts[0].Text != "actual content" || parts[0].Type != "text" {
|
||||
t.Fatalf("unexpected parts: %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContentRecursiveContentFieldThinkType(t *testing.T) {
|
||||
items := []any{
|
||||
map[string]any{"p": "x", "v": "val", "content": "think text", "type": "THINK"},
|
||||
}
|
||||
parts, _ := extractContentRecursive(items, "text")
|
||||
if len(parts) != 1 || parts[0].Type != "thinking" {
|
||||
t.Fatalf("expected thinking type for THINK content, got %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContentRecursiveThinkingPath(t *testing.T) {
|
||||
items := []any{
|
||||
map[string]any{"p": "thinking_content", "v": "deep thought"},
|
||||
}
|
||||
parts, _ := extractContentRecursive(items, "text")
|
||||
if len(parts) != 1 || parts[0].Type != "thinking" || parts[0].Text != "deep thought" {
|
||||
t.Fatalf("unexpected parts for thinking path: %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContentRecursiveContentPath(t *testing.T) {
|
||||
items := []any{
|
||||
map[string]any{"p": "content", "v": "text content"},
|
||||
}
|
||||
parts, _ := extractContentRecursive(items, "thinking")
|
||||
if len(parts) != 1 || parts[0].Type != "text" {
|
||||
t.Fatalf("expected text type for content path, got %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContentRecursiveResponsePath(t *testing.T) {
|
||||
items := []any{
|
||||
map[string]any{"p": "response", "v": "text content"},
|
||||
}
|
||||
parts, _ := extractContentRecursive(items, "thinking")
|
||||
if len(parts) != 1 || parts[0].Type != "text" {
|
||||
t.Fatalf("expected text type for response path, got %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContentRecursiveFragmentsPath(t *testing.T) {
|
||||
items := []any{
|
||||
map[string]any{"p": "fragments", "v": "fragment text"},
|
||||
}
|
||||
parts, _ := extractContentRecursive(items, "thinking")
|
||||
if len(parts) != 1 || parts[0].Type != "text" {
|
||||
t.Fatalf("expected text type for fragments path, got %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContentRecursiveNestedArrayWithTypes(t *testing.T) {
|
||||
items := []any{
|
||||
map[string]any{
|
||||
"p": "fragments",
|
||||
"v": []any{
|
||||
map[string]any{"content": "thought", "type": "THINKING"},
|
||||
map[string]any{"content": "answer", "type": "RESPONSE"},
|
||||
"raw string",
|
||||
},
|
||||
},
|
||||
}
|
||||
parts, _ := extractContentRecursive(items, "text")
|
||||
if len(parts) != 3 {
|
||||
t.Fatalf("expected 3 parts, got %d: %#v", len(parts), parts)
|
||||
}
|
||||
if parts[0].Type != "thinking" || parts[0].Text != "thought" {
|
||||
t.Fatalf("first part mismatch: %#v", parts[0])
|
||||
}
|
||||
if parts[1].Type != "text" || parts[1].Text != "answer" {
|
||||
t.Fatalf("second part mismatch: %#v", parts[1])
|
||||
}
|
||||
if parts[2].Type != "text" || parts[2].Text != "raw string" {
|
||||
t.Fatalf("third part mismatch: %#v", parts[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContentRecursiveEmptyContentSkipped(t *testing.T) {
|
||||
items := []any{
|
||||
map[string]any{
|
||||
"p": "fragments",
|
||||
"v": []any{
|
||||
map[string]any{"content": "", "type": "RESPONSE"},
|
||||
},
|
||||
},
|
||||
}
|
||||
parts, _ := extractContentRecursive(items, "text")
|
||||
if len(parts) != 0 {
|
||||
t.Fatalf("expected no parts for empty nested content, got %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContentRecursiveFinishedString(t *testing.T) {
|
||||
items := []any{
|
||||
map[string]any{"p": "content", "v": "FINISHED"},
|
||||
}
|
||||
parts, _ := extractContentRecursive(items, "text")
|
||||
// "FINISHED" string value on non-status path should be skipped
|
||||
if len(parts) != 0 {
|
||||
t.Fatalf("expected FINISHED string to be skipped, got %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContentRecursiveNoVField(t *testing.T) {
|
||||
items := []any{
|
||||
map[string]any{"p": "content"},
|
||||
}
|
||||
parts, _ := extractContentRecursive(items, "text")
|
||||
if len(parts) != 0 {
|
||||
t.Fatalf("expected no parts for missing v field, got %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractContentRecursiveNonMapItem(t *testing.T) {
|
||||
items := []any{"just a string", 42}
|
||||
parts, _ := extractContentRecursive(items, "text")
|
||||
if len(parts) != 0 {
|
||||
t.Fatalf("expected no parts for non-map items, got %#v", parts)
|
||||
}
|
||||
}
|
||||
177
internal/sse/stream_edge_test.go
Normal file
177
internal/sse/stream_edge_test.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package sse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStartParsedLinePumpEmptyBody(t *testing.T) {
|
||||
body := strings.NewReader("")
|
||||
results, done := StartParsedLinePump(context.Background(), body, false, "text")
|
||||
|
||||
collected := make([]LineResult, 0)
|
||||
for r := range results {
|
||||
collected = append(collected, r)
|
||||
}
|
||||
if err := <-done; err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(collected) != 0 {
|
||||
t.Fatalf("expected no results for empty body, got %d", len(collected))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartParsedLinePumpMultipleLines(t *testing.T) {
|
||||
body := strings.NewReader(
|
||||
"data: {\"p\":\"response/thinking_content\",\"v\":\"think\"}\n" +
|
||||
"data: {\"p\":\"response/content\",\"v\":\"text\"}\n" +
|
||||
"data: [DONE]\n",
|
||||
)
|
||||
results, done := StartParsedLinePump(context.Background(), body, true, "thinking")
|
||||
|
||||
collected := make([]LineResult, 0)
|
||||
for r := range results {
|
||||
collected = append(collected, r)
|
||||
}
|
||||
if err := <-done; err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(collected) < 3 {
|
||||
t.Fatalf("expected at least 3 results, got %d", len(collected))
|
||||
}
|
||||
// First should be thinking
|
||||
if collected[0].Parts[0].Type != "thinking" {
|
||||
t.Fatalf("expected first part thinking, got %q", collected[0].Parts[0].Type)
|
||||
}
|
||||
// Last should be stop
|
||||
last := collected[len(collected)-1]
|
||||
if !last.Stop {
|
||||
t.Fatal("expected last result to be stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartParsedLinePumpTypeTracking(t *testing.T) {
|
||||
body := strings.NewReader(
|
||||
"data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"THINK\",\"content\":\"思\"}]}\n" +
|
||||
"data: {\"p\":\"response/fragments/-1/content\",\"v\":\"考\"}\n" +
|
||||
"data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"RESPONSE\",\"content\":\"答\"}]}\n" +
|
||||
"data: {\"p\":\"response/fragments/-1/content\",\"v\":\"案\"}\n" +
|
||||
"data: [DONE]\n",
|
||||
)
|
||||
results, done := StartParsedLinePump(context.Background(), body, true, "text")
|
||||
|
||||
types := make([]string, 0)
|
||||
for r := range results {
|
||||
for _, p := range r.Parts {
|
||||
types = append(types, p.Type)
|
||||
}
|
||||
}
|
||||
<-done
|
||||
|
||||
// Should have: thinking, thinking, text, text
|
||||
expected := []string{"thinking", "thinking", "text", "text"}
|
||||
if len(types) != len(expected) {
|
||||
t.Fatalf("expected types %v, got %v", expected, types)
|
||||
}
|
||||
for i, want := range expected {
|
||||
if types[i] != want {
|
||||
t.Fatalf("type[%d] mismatch: want %q got %q (all=%v)", i, want, types[i], types)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartParsedLinePumpContextCancellation(t *testing.T) {
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
results, done := StartParsedLinePump(ctx, pr, false, "text")
|
||||
|
||||
// Write one line to allow it to start
|
||||
go func() {
|
||||
_, _ = io.WriteString(pw, "data: {\"p\":\"response/content\",\"v\":\"hello\"}\n")
|
||||
// Don't close yet - wait for context cancel
|
||||
}()
|
||||
|
||||
// Read first result
|
||||
r := <-results
|
||||
if !r.Parsed || len(r.Parts) == 0 {
|
||||
t.Fatalf("expected first parsed result, got %#v", r)
|
||||
}
|
||||
|
||||
// Cancel context - this will cause the pump to exit on next send
|
||||
cancel()
|
||||
// Close the pipe to unblock scanner.Scan()
|
||||
pw.Close()
|
||||
|
||||
// Drain remaining results
|
||||
for range results {
|
||||
}
|
||||
|
||||
err := <-done
|
||||
// Error may be context.Canceled or nil (if pipe closed first)
|
||||
if err != nil && err != context.Canceled {
|
||||
t.Fatalf("expected context.Canceled or nil error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartParsedLinePumpOnlyDONE(t *testing.T) {
|
||||
body := strings.NewReader("data: [DONE]\n")
|
||||
results, done := StartParsedLinePump(context.Background(), body, false, "text")
|
||||
|
||||
collected := make([]LineResult, 0)
|
||||
for r := range results {
|
||||
collected = append(collected, r)
|
||||
}
|
||||
<-done
|
||||
|
||||
if len(collected) != 1 {
|
||||
t.Fatalf("expected 1 result, got %d", len(collected))
|
||||
}
|
||||
if !collected[0].Stop {
|
||||
t.Fatal("expected stop on [DONE]")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartParsedLinePumpNonSSELines(t *testing.T) {
|
||||
body := strings.NewReader(
|
||||
"event: update\n" +
|
||||
": comment line\n" +
|
||||
"data: {\"p\":\"response/content\",\"v\":\"valid\"}\n" +
|
||||
"data: [DONE]\n",
|
||||
)
|
||||
results, done := StartParsedLinePump(context.Background(), body, false, "text")
|
||||
|
||||
var validCount int
|
||||
for r := range results {
|
||||
if r.Parsed && len(r.Parts) > 0 {
|
||||
validCount++
|
||||
}
|
||||
}
|
||||
<-done
|
||||
|
||||
if validCount != 1 {
|
||||
t.Fatalf("expected 1 valid result, got %d", validCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartParsedLinePumpThinkingDisabled(t *testing.T) {
|
||||
body := strings.NewReader(
|
||||
"data: {\"p\":\"response/thinking_content\",\"v\":\"thought\"}\n" +
|
||||
"data: {\"p\":\"response/content\",\"v\":\"response\"}\n" +
|
||||
"data: [DONE]\n",
|
||||
)
|
||||
// With thinking disabled, thinking content should still be emitted but marked differently
|
||||
results, done := StartParsedLinePump(context.Background(), body, false, "text")
|
||||
|
||||
var parts []ContentPart
|
||||
for r := range results {
|
||||
parts = append(parts, r.Parts...)
|
||||
}
|
||||
<-done
|
||||
|
||||
if len(parts) < 1 {
|
||||
t.Fatalf("expected at least 1 part, got %d", len(parts))
|
||||
}
|
||||
}
|
||||
@@ -755,11 +755,15 @@ func (r *Runner) cases() []caseDef {
|
||||
{ID: "healthz_ok", Run: r.caseHealthz},
|
||||
{ID: "readyz_ok", Run: r.caseReadyz},
|
||||
{ID: "models_openai", Run: r.caseModelsOpenAI},
|
||||
{ID: "model_openai_by_id", Run: r.caseModelOpenAIByID},
|
||||
{ID: "models_claude", Run: r.caseModelsClaude},
|
||||
{ID: "admin_login_verify", Run: r.caseAdminLoginVerify},
|
||||
{ID: "admin_queue_status", Run: r.caseAdminQueueStatus},
|
||||
{ID: "chat_nonstream_basic", Run: r.caseChatNonstream},
|
||||
{ID: "chat_stream_basic", Run: r.caseChatStream},
|
||||
{ID: "responses_nonstream_basic", Run: r.caseResponsesNonstream},
|
||||
{ID: "responses_stream_basic", Run: r.caseResponsesStream},
|
||||
{ID: "embeddings_contract", Run: r.caseEmbeddings},
|
||||
{ID: "reasoner_stream", Run: r.caseReasonerStream},
|
||||
{ID: "toolcall_nonstream", Run: r.caseToolcallNonstream},
|
||||
{ID: "toolcall_stream", Run: r.caseToolcallStream},
|
||||
@@ -817,6 +821,19 @@ func (r *Runner) caseModelsOpenAI(ctx context.Context, cc *caseContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseModelOpenAIByID(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/v1/models/gpt-4o", Retryable: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(resp.Body, &m)
|
||||
cc.assert("object_model", asString(m["object"]) == "model", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
cc.assert("id_deepseek_chat", asString(m["id"]) == "deepseek-chat", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseModelsClaude(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/anthropic/v1/models", Retryable: true})
|
||||
if err != nil {
|
||||
@@ -942,6 +959,115 @@ func (r *Runner) caseChatStream(ctx context.Context, cc *caseContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseResponsesNonstream(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/v1/responses",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.apiKey,
|
||||
},
|
||||
Body: map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": "请简要回答 hello",
|
||||
},
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(resp.Body, &m)
|
||||
cc.assert("object_response", asString(m["object"]) == "response", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
responseID := asString(m["id"])
|
||||
cc.assert("response_id_present", responseID != "", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
if responseID != "" {
|
||||
getResp, getErr := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodGet,
|
||||
Path: "/v1/responses/" + responseID,
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.apiKey,
|
||||
},
|
||||
Retryable: true,
|
||||
})
|
||||
if getErr != nil {
|
||||
return getErr
|
||||
}
|
||||
cc.assert("get_status_200", getResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", getResp.StatusCode))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseResponsesStream(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/v1/responses",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.apiKey,
|
||||
},
|
||||
Body: map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": "请流式回答 hello",
|
||||
"stream": true,
|
||||
},
|
||||
Stream: true,
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
frames, done := parseSSEFrames(resp.Body)
|
||||
cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames)))
|
||||
hasCreated := false
|
||||
hasCompleted := false
|
||||
for _, f := range frames {
|
||||
switch asString(f["type"]) {
|
||||
case "response.created":
|
||||
hasCreated = true
|
||||
case "response.completed":
|
||||
hasCompleted = true
|
||||
}
|
||||
}
|
||||
cc.assert("has_response_created", hasCreated, fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
cc.assert("has_response_completed", hasCompleted, fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
cc.assert("done_terminated", done, "expected [DONE]")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseEmbeddings(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
Path: "/v1/embeddings",
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer " + r.apiKey,
|
||||
},
|
||||
Body: map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": []string{"hello", "world"},
|
||||
},
|
||||
Retryable: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200_or_501", resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusNotImplemented, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(resp.Body, &m)
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
cc.assert("object_list", asString(m["object"]) == "list", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
data, _ := m["data"].([]any)
|
||||
cc.assert("data_non_empty", len(data) > 0, fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
return nil
|
||||
}
|
||||
errObj, _ := m["error"].(map[string]any)
|
||||
_, hasCode := errObj["code"]
|
||||
_, hasParam := errObj["param"]
|
||||
cc.assert("error_has_code", hasCode, fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
cc.assert("error_has_param", hasParam, fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseReasonerStream(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
@@ -68,15 +70,25 @@ func normalizeContent(v any) string {
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if m["type"] == "text" {
|
||||
typeStr, _ := m["type"].(string)
|
||||
typeStr = strings.ToLower(strings.TrimSpace(typeStr))
|
||||
if typeStr == "text" || typeStr == "output_text" || typeStr == "input_text" {
|
||||
if txt, ok := m["text"].(string); ok {
|
||||
parts = append(parts, txt)
|
||||
continue
|
||||
}
|
||||
if txt, ok := m["content"].(string); ok {
|
||||
parts = append(parts, txt)
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
default:
|
||||
return ""
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -33,6 +33,33 @@ func TestMessagesPrepareRoles(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesPrepareObjectContent(t *testing.T) {
|
||||
messages := []map[string]any{
|
||||
{"role": "user", "content": map[string]any{"temp": 18, "ok": true}},
|
||||
}
|
||||
got := MessagesPrepare(messages)
|
||||
if !contains(got, `"temp":18`) || !contains(got, `"ok":true`) {
|
||||
t.Fatalf("expected serialized object content, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesPrepareArrayTextVariants(t *testing.T) {
|
||||
messages := []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "output_text", "text": "line1"},
|
||||
map[string]any{"type": "input_text", "text": "line2"},
|
||||
map[string]any{"type": "image_url", "image_url": "https://example.com/a.png"},
|
||||
},
|
||||
},
|
||||
}
|
||||
got := MessagesPrepare(messages)
|
||||
if got != "line1\nline2" {
|
||||
t.Fatalf("unexpected content from text variants: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeToDeepSeek(t *testing.T) {
|
||||
store := config.LoadStore()
|
||||
req := map[string]any{
|
||||
|
||||
140
internal/util/render.go
Normal file
140
internal/util/render.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
detected := ParseToolCalls(finalText, toolNames)
|
||||
finishReason := "stop"
|
||||
messageObj := map[string]any{"role": "assistant", "content": finalText}
|
||||
if strings.TrimSpace(finalThinking) != "" {
|
||||
messageObj["reasoning_content"] = finalThinking
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
messageObj["tool_calls"] = FormatOpenAIToolCalls(detected)
|
||||
messageObj["content"] = nil
|
||||
}
|
||||
promptTokens := EstimateTokens(finalPrompt)
|
||||
reasoningTokens := EstimateTokens(finalThinking)
|
||||
completionTokens := EstimateTokens(finalText)
|
||||
|
||||
return map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}},
|
||||
"usage": map[string]any{
|
||||
"prompt_tokens": promptTokens,
|
||||
"completion_tokens": reasoningTokens + completionTokens,
|
||||
"total_tokens": promptTokens + reasoningTokens + completionTokens,
|
||||
"completion_tokens_details": map[string]any{
|
||||
"reasoning_tokens": reasoningTokens,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
detected := ParseToolCalls(finalText, toolNames)
|
||||
exposedOutputText := finalText
|
||||
output := make([]any, 0, 2)
|
||||
if len(detected) > 0 {
|
||||
// Keep structured tool output only; avoid leaking raw tool-call JSON
|
||||
// into response.output_text for clients reading completed responses.
|
||||
exposedOutputText = ""
|
||||
toolCalls := make([]any, 0, len(detected))
|
||||
for _, tc := range detected {
|
||||
toolCalls = append(toolCalls, map[string]any{
|
||||
"type": "tool_call",
|
||||
"name": tc.Name,
|
||||
"arguments": tc.Input,
|
||||
})
|
||||
}
|
||||
output = append(output, map[string]any{
|
||||
"type": "tool_calls",
|
||||
"tool_calls": toolCalls,
|
||||
})
|
||||
} else {
|
||||
content := []any{
|
||||
map[string]any{
|
||||
"type": "output_text",
|
||||
"text": finalText,
|
||||
},
|
||||
}
|
||||
if finalThinking != "" {
|
||||
content = append([]any{map[string]any{
|
||||
"type": "reasoning",
|
||||
"text": finalThinking,
|
||||
}}, content...)
|
||||
}
|
||||
output = append(output, map[string]any{
|
||||
"type": "message",
|
||||
"id": "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
})
|
||||
}
|
||||
promptTokens := EstimateTokens(finalPrompt)
|
||||
reasoningTokens := EstimateTokens(finalThinking)
|
||||
completionTokens := EstimateTokens(finalText)
|
||||
return map[string]any{
|
||||
"id": responseID,
|
||||
"type": "response",
|
||||
"object": "response",
|
||||
"created_at": time.Now().Unix(),
|
||||
"status": "completed",
|
||||
"model": model,
|
||||
"output": output,
|
||||
"output_text": exposedOutputText,
|
||||
"usage": map[string]any{
|
||||
"input_tokens": promptTokens,
|
||||
"output_tokens": reasoningTokens + completionTokens,
|
||||
"total_tokens": promptTokens + reasoningTokens + completionTokens,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func BuildClaudeMessageResponse(messageID, model string, normalizedMessages []any, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
detected := ParseToolCalls(finalText, toolNames)
|
||||
content := make([]map[string]any, 0, 4)
|
||||
if finalThinking != "" {
|
||||
content = append(content, map[string]any{"type": "thinking", "thinking": finalThinking})
|
||||
}
|
||||
stopReason := "end_turn"
|
||||
if len(detected) > 0 {
|
||||
stopReason = "tool_use"
|
||||
for i, tc := range detected {
|
||||
content = append(content, map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i),
|
||||
"name": tc.Name,
|
||||
"input": tc.Input,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
if finalText == "" {
|
||||
finalText = "抱歉,没有生成有效的响应内容。"
|
||||
}
|
||||
content = append(content, map[string]any{"type": "text", "text": finalText})
|
||||
}
|
||||
return map[string]any{
|
||||
"id": messageID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": content,
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]any{
|
||||
"input_tokens": EstimateTokens(fmt.Sprintf("%v", normalizedMessages)),
|
||||
"output_tokens": EstimateTokens(finalThinking) + EstimateTokens(finalText),
|
||||
},
|
||||
}
|
||||
}
|
||||
93
internal/util/render_stream.go
Normal file
93
internal/util/render_stream.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package util
|
||||
|
||||
func BuildOpenAIChatStreamDeltaChoice(index int, delta map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"delta": delta,
|
||||
"index": index,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildOpenAIChatStreamFinishChoice(index int, finishReason string) map[string]any {
|
||||
return map[string]any{
|
||||
"delta": map[string]any{},
|
||||
"index": index,
|
||||
"finish_reason": finishReason,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildOpenAIChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any {
|
||||
out := map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": choices,
|
||||
}
|
||||
if len(usage) > 0 {
|
||||
out["usage"] = usage
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func BuildOpenAIChatUsage(finalPrompt, finalThinking, finalText string) map[string]any {
|
||||
promptTokens := EstimateTokens(finalPrompt)
|
||||
reasoningTokens := EstimateTokens(finalThinking)
|
||||
completionTokens := EstimateTokens(finalText)
|
||||
return map[string]any{
|
||||
"prompt_tokens": promptTokens,
|
||||
"completion_tokens": reasoningTokens + completionTokens,
|
||||
"total_tokens": promptTokens + reasoningTokens + completionTokens,
|
||||
"completion_tokens_details": map[string]any{
|
||||
"reasoning_tokens": reasoningTokens,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func BuildOpenAIResponsesCreatedPayload(responseID, model string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.created",
|
||||
"id": responseID,
|
||||
"object": "response",
|
||||
"model": model,
|
||||
"status": "in_progress",
|
||||
}
|
||||
}
|
||||
|
||||
func BuildOpenAIResponsesTextDeltaPayload(responseID, delta string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_text.delta",
|
||||
"id": responseID,
|
||||
"delta": delta,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildOpenAIResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.reasoning.delta",
|
||||
"id": responseID,
|
||||
"delta": delta,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildOpenAIResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_tool_call.delta",
|
||||
"id": responseID,
|
||||
"tool_calls": toolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildOpenAIResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_tool_call.done",
|
||||
"id": responseID,
|
||||
"tool_calls": toolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildOpenAIResponsesCompletedPayload(response map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.completed",
|
||||
"response": response,
|
||||
}
|
||||
}
|
||||
48
internal/util/render_stream_test.go
Normal file
48
internal/util/render_stream_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package util
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBuildOpenAIChatStreamChunk(t *testing.T) {
|
||||
chunk := BuildOpenAIChatStreamChunk(
|
||||
"cid",
|
||||
123,
|
||||
"deepseek-chat",
|
||||
[]map[string]any{BuildOpenAIChatStreamDeltaChoice(0, map[string]any{"role": "assistant"})},
|
||||
nil,
|
||||
)
|
||||
if chunk["object"] != "chat.completion.chunk" {
|
||||
t.Fatalf("unexpected object: %#v", chunk["object"])
|
||||
}
|
||||
choices, _ := chunk["choices"].([]map[string]any)
|
||||
if len(choices) == 0 {
|
||||
rawChoices, _ := chunk["choices"].([]any)
|
||||
if len(rawChoices) == 0 {
|
||||
t.Fatalf("expected choices")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIChatUsage(t *testing.T) {
|
||||
usage := BuildOpenAIChatUsage("prompt", "think", "answer")
|
||||
if _, ok := usage["prompt_tokens"]; !ok {
|
||||
t.Fatalf("expected prompt_tokens")
|
||||
}
|
||||
if _, ok := usage["completion_tokens_details"]; !ok {
|
||||
t.Fatalf("expected completion_tokens_details")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIResponsesEventPayloads(t *testing.T) {
|
||||
created := BuildOpenAIResponsesCreatedPayload("resp_1", "gpt-4o")
|
||||
if created["type"] != "response.created" {
|
||||
t.Fatalf("unexpected type: %#v", created["type"])
|
||||
}
|
||||
done := BuildOpenAIResponsesToolCallDonePayload("resp_1", []map[string]any{{"index": 0}})
|
||||
if done["type"] != "response.output_tool_call.done" {
|
||||
t.Fatalf("unexpected type: %#v", done["type"])
|
||||
}
|
||||
completed := BuildOpenAIResponsesCompletedPayload(map[string]any{"id": "resp_1"})
|
||||
if completed["type"] != "response.completed" {
|
||||
t.Fatalf("unexpected type: %#v", completed["type"])
|
||||
}
|
||||
}
|
||||
94
internal/util/render_test.go
Normal file
94
internal/util/render_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package util
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBuildOpenAIChatCompletionWithToolCalls(t *testing.T) {
|
||||
out := BuildOpenAIChatCompletion(
|
||||
"cid1",
|
||||
"deepseek-chat",
|
||||
"prompt",
|
||||
"",
|
||||
`{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`,
|
||||
[]string{"search"},
|
||||
)
|
||||
if out["object"] != "chat.completion" {
|
||||
t.Fatalf("unexpected object: %#v", out["object"])
|
||||
}
|
||||
choices, _ := out["choices"].([]map[string]any)
|
||||
if len(choices) == 0 {
|
||||
// json-like map from generic marshalling may be []any in some paths
|
||||
rawChoices, _ := out["choices"].([]any)
|
||||
if len(rawChoices) == 0 {
|
||||
t.Fatalf("expected choices")
|
||||
}
|
||||
c0, _ := rawChoices[0].(map[string]any)
|
||||
if c0["finish_reason"] != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", c0["finish_reason"])
|
||||
}
|
||||
return
|
||||
}
|
||||
if choices[0]["finish_reason"] != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", choices[0]["finish_reason"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIResponseObjectWithText(t *testing.T) {
|
||||
out := BuildOpenAIResponseObject(
|
||||
"resp_1",
|
||||
"gpt-4o",
|
||||
"prompt",
|
||||
"reasoning",
|
||||
"text",
|
||||
nil,
|
||||
)
|
||||
if out["object"] != "response" {
|
||||
t.Fatalf("unexpected object: %#v", out["object"])
|
||||
}
|
||||
output, _ := out["output"].([]any)
|
||||
if len(output) == 0 {
|
||||
t.Fatalf("expected output entries")
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "message" {
|
||||
t.Fatalf("expected first output type message, got %#v", first["type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIResponseObjectToolCallsHidesRawOutputText(t *testing.T) {
|
||||
out := BuildOpenAIResponseObject(
|
||||
"resp_2",
|
||||
"gpt-4o",
|
||||
"prompt",
|
||||
"",
|
||||
`{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`,
|
||||
[]string{"search"},
|
||||
)
|
||||
if out["output_text"] != "" {
|
||||
t.Fatalf("expected empty output_text for tool_calls, got %#v", out["output_text"])
|
||||
}
|
||||
output, _ := out["output"].([]any)
|
||||
if len(output) == 0 {
|
||||
t.Fatalf("expected output entries")
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "tool_calls" {
|
||||
t.Fatalf("expected first output type tool_calls, got %#v", first["type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClaudeMessageResponseToolUse(t *testing.T) {
|
||||
out := BuildClaudeMessageResponse(
|
||||
"msg_1",
|
||||
"claude-sonnet-4-5",
|
||||
[]any{map[string]any{"role": "user", "content": "hi"}},
|
||||
"",
|
||||
`{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`,
|
||||
[]string{"search"},
|
||||
)
|
||||
if out["type"] != "message" {
|
||||
t.Fatalf("unexpected type: %#v", out["type"])
|
||||
}
|
||||
if out["stop_reason"] != "tool_use" {
|
||||
t.Fatalf("expected stop_reason=tool_use, got %#v", out["stop_reason"])
|
||||
}
|
||||
}
|
||||
30
internal/util/standard_request.go
Normal file
30
internal/util/standard_request.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package util
|
||||
|
||||
type StandardRequest struct {
|
||||
Surface string
|
||||
RequestedModel string
|
||||
ResolvedModel string
|
||||
ResponseModel string
|
||||
Messages []any
|
||||
FinalPrompt string
|
||||
ToolNames []string
|
||||
Stream bool
|
||||
Thinking bool
|
||||
Search bool
|
||||
PassThrough map[string]any
|
||||
}
|
||||
|
||||
func (r StandardRequest) CompletionPayload(sessionID string) map[string]any {
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
"parent_message_id": nil,
|
||||
"prompt": r.FinalPrompt,
|
||||
"ref_file_ids": []any{},
|
||||
"thinking_enabled": r.Thinking,
|
||||
"search_enabled": r.Search,
|
||||
}
|
||||
for k, v := range r.PassThrough {
|
||||
payload[k] = v
|
||||
}
|
||||
return payload
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`)
|
||||
var fencedJSONPattern = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```")
|
||||
var fencedBlockPattern = regexp.MustCompile("(?s)```.*?```")
|
||||
|
||||
type ParsedToolCall struct {
|
||||
Name string `json:"name"`
|
||||
@@ -20,6 +21,10 @@ func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return nil
|
||||
}
|
||||
text = stripFencedCodeBlocks(text)
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
candidates := buildToolCallCandidates(text)
|
||||
var parsed []ParsedToolCall
|
||||
@@ -33,6 +38,34 @@ func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
|
||||
return nil
|
||||
}
|
||||
|
||||
return filterToolCalls(parsed, availableToolNames)
|
||||
}
|
||||
|
||||
func ParseStandaloneToolCalls(text string, availableToolNames []string) []ParsedToolCall {
|
||||
trimmed := strings.TrimSpace(text)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
if looksLikeToolExampleContext(trimmed) {
|
||||
return nil
|
||||
}
|
||||
candidates := []string{trimmed}
|
||||
for _, candidate := range candidates {
|
||||
candidate = strings.TrimSpace(candidate)
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(candidate, "{") && !strings.HasPrefix(candidate, "[") {
|
||||
continue
|
||||
}
|
||||
if parsed := parseToolCallsPayload(candidate); len(parsed) > 0 {
|
||||
return filterToolCalls(parsed, availableToolNames)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []ParsedToolCall {
|
||||
allowed := map[string]struct{}{}
|
||||
for _, name := range availableToolNames {
|
||||
allowed[name] = struct{}{}
|
||||
@@ -283,6 +316,21 @@ func extractJSONObject(text string, start int) (string, int, bool) {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func looksLikeToolExampleContext(text string) bool {
|
||||
t := strings.ToLower(strings.TrimSpace(text))
|
||||
if t == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(t, "```")
|
||||
}
|
||||
|
||||
func stripFencedCodeBlocks(text string) string {
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return ""
|
||||
}
|
||||
return fencedBlockPattern.ReplaceAllString(text, " ")
|
||||
}
|
||||
|
||||
func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(calls))
|
||||
for _, c := range calls {
|
||||
|
||||
@@ -19,11 +19,8 @@ func TestParseToolCalls(t *testing.T) {
|
||||
func TestParseToolCallsFromFencedJSON(t *testing.T) {
|
||||
text := "I will call tools now\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"news\"}}]}\n```"
|
||||
calls := ParseToolCalls(text, []string{"search"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Input["q"] != "news" {
|
||||
t.Fatalf("unexpected args: %#v", calls[0].Input)
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,3 +59,23 @@ func TestFormatOpenAIToolCalls(t *testing.T) {
|
||||
t.Fatalf("unexpected function name: %#v", fn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStandaloneToolCallsOnlyMatchesStandalonePayload(t *testing.T) {
|
||||
mixed := `这里是示例:{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
|
||||
if calls := ParseStandaloneToolCalls(mixed, []string{"search"}); len(calls) != 0 {
|
||||
t.Fatalf("expected standalone parser to ignore mixed prose, got %#v", calls)
|
||||
}
|
||||
|
||||
standalone := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
|
||||
calls := ParseStandaloneToolCalls(standalone, []string{"search"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected standalone parser to match, got %#v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStandaloneToolCallsIgnoresFencedCodeBlock(t *testing.T) {
|
||||
fenced := "```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```"
|
||||
if calls := ParseStandaloneToolCalls(fenced, []string{"search"}); len(calls) != 0 {
|
||||
t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
429
internal/util/util_edge_test.go
Normal file
429
internal/util/util_edge_test.go
Normal file
@@ -0,0 +1,429 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
// ─── EstimateTokens edge cases ───────────────────────────────────────
|
||||
|
||||
func TestEstimateTokensEmpty(t *testing.T) {
|
||||
if got := EstimateTokens(""); got != 0 {
|
||||
t.Fatalf("expected 0 for empty string, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokensShortASCII(t *testing.T) {
|
||||
got := EstimateTokens("ab")
|
||||
if got != 1 {
|
||||
t.Fatalf("expected 1 for 2 ascii chars, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokensLongASCII(t *testing.T) {
|
||||
got := EstimateTokens(strings.Repeat("x", 100))
|
||||
if got != 25 {
|
||||
t.Fatalf("expected 25 for 100 ascii chars, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokensChinese(t *testing.T) {
|
||||
got := EstimateTokens("你好世界")
|
||||
if got < 1 {
|
||||
t.Fatalf("expected at least 1 token for Chinese text, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokensMixed(t *testing.T) {
|
||||
got := EstimateTokens("Hello 你好世界")
|
||||
if got < 2 {
|
||||
t.Fatalf("expected at least 2 tokens for mixed text, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokensSingleByte(t *testing.T) {
|
||||
got := EstimateTokens("x")
|
||||
if got != 1 {
|
||||
t.Fatalf("expected 1 for single char (minimum), got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokensSingleChinese(t *testing.T) {
|
||||
got := EstimateTokens("你")
|
||||
if got != 1 {
|
||||
t.Fatalf("expected 1 for single Chinese char, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── ToBool edge cases ───────────────────────────────────────────────
|
||||
|
||||
func TestToBoolTrue(t *testing.T) {
|
||||
if !ToBool(true) {
|
||||
t.Fatal("expected true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToBoolFalse(t *testing.T) {
|
||||
if ToBool(false) {
|
||||
t.Fatal("expected false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToBoolNonBool(t *testing.T) {
|
||||
if ToBool("true") {
|
||||
t.Fatal("expected false for string 'true'")
|
||||
}
|
||||
if ToBool(1) {
|
||||
t.Fatal("expected false for int 1")
|
||||
}
|
||||
if ToBool(nil) {
|
||||
t.Fatal("expected false for nil")
|
||||
}
|
||||
}
|
||||
|
||||
// ─── IntFrom edge cases ─────────────────────────────────────────────
|
||||
|
||||
func TestIntFromFloat64(t *testing.T) {
|
||||
if got := IntFrom(float64(42.5)); got != 42 {
|
||||
t.Fatalf("expected 42 for float64(42.5), got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntFromInt(t *testing.T) {
|
||||
if got := IntFrom(int(42)); got != 42 {
|
||||
t.Fatalf("expected 42, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntFromInt64(t *testing.T) {
|
||||
if got := IntFrom(int64(42)); got != 42 {
|
||||
t.Fatalf("expected 42, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntFromString(t *testing.T) {
|
||||
if got := IntFrom("42"); got != 0 {
|
||||
t.Fatalf("expected 0 for string, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntFromNil(t *testing.T) {
|
||||
if got := IntFrom(nil); got != 0 {
|
||||
t.Fatalf("expected 0 for nil, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── WriteJSON ───────────────────────────────────────────────────────
|
||||
|
||||
func TestWriteJSON(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
WriteJSON(rec, 200, map[string]any{"key": "value"})
|
||||
if rec.Code != 200 {
|
||||
t.Fatalf("expected 200, got %d", rec.Code)
|
||||
}
|
||||
if ct := rec.Header().Get("Content-Type"); ct != "application/json" {
|
||||
t.Fatalf("expected application/json content type, got %q", ct)
|
||||
}
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("decode error: %v", err)
|
||||
}
|
||||
if body["key"] != "value" {
|
||||
t.Fatalf("unexpected body: %#v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteJSONStatusCodes(t *testing.T) {
|
||||
for _, code := range []int{200, 201, 400, 404, 500} {
|
||||
rec := httptest.NewRecorder()
|
||||
WriteJSON(rec, code, map[string]any{"status": code})
|
||||
if rec.Code != code {
|
||||
t.Fatalf("expected %d, got %d", code, rec.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─── MessagesPrepare edge cases ──────────────────────────────────────
|
||||
|
||||
func TestMessagesPrepareEmpty(t *testing.T) {
|
||||
got := MessagesPrepare(nil)
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty for nil messages, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesPrepareMergesConsecutiveSameRole(t *testing.T) {
|
||||
messages := []map[string]any{
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "user", "content": "World"},
|
||||
}
|
||||
got := MessagesPrepare(messages)
|
||||
if !strings.Contains(got, "Hello") || !strings.Contains(got, "World") {
|
||||
t.Fatalf("expected both messages, got %q", got)
|
||||
}
|
||||
// Should be merged without <|User|> between them
|
||||
count := strings.Count(got, "<|User|>")
|
||||
if count != 0 {
|
||||
t.Fatalf("expected no User marker for first message pair, got %d occurrences", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesPrepareAssistantMarkers(t *testing.T) {
|
||||
messages := []map[string]any{
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": "Hello!"},
|
||||
}
|
||||
got := MessagesPrepare(messages)
|
||||
if !strings.Contains(got, "<|Assistant|>") {
|
||||
t.Fatalf("expected assistant marker, got %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "<|end▁of▁sentence|>") {
|
||||
t.Fatalf("expected end of sentence marker, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesPrepareUnknownRole(t *testing.T) {
|
||||
messages := []map[string]any{
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "unknown_role", "content": "Unknown"},
|
||||
}
|
||||
got := MessagesPrepare(messages)
|
||||
if !strings.Contains(got, "Unknown") {
|
||||
t.Fatalf("expected unknown role content, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesPrepareMarkdownImageReplaced(t *testing.T) {
|
||||
messages := []map[string]any{
|
||||
{"role": "user", "content": "Look at this: "},
|
||||
}
|
||||
got := MessagesPrepare(messages)
|
||||
if strings.Contains(got, "![alt]") {
|
||||
t.Fatalf("expected markdown image to be replaced, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesPrepareNilContent(t *testing.T) {
|
||||
messages := []map[string]any{
|
||||
{"role": "user", "content": nil},
|
||||
}
|
||||
got := MessagesPrepare(messages)
|
||||
if got != "null" {
|
||||
t.Logf("nil content handled as: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── normalizeContent edge cases ─────────────────────────────────────
|
||||
|
||||
func TestNormalizeContentString(t *testing.T) {
|
||||
got := normalizeContent("hello")
|
||||
if got != "hello" {
|
||||
t.Fatalf("expected 'hello', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeContentArray(t *testing.T) {
|
||||
got := normalizeContent([]any{
|
||||
map[string]any{"type": "text", "text": "line1"},
|
||||
map[string]any{"type": "text", "text": "line2"},
|
||||
})
|
||||
if got != "line1\nline2" {
|
||||
t.Fatalf("expected 'line1\\nline2', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeContentArrayWithContentField(t *testing.T) {
|
||||
got := normalizeContent([]any{
|
||||
map[string]any{"type": "text", "content": "from-content"},
|
||||
})
|
||||
if got != "from-content" {
|
||||
t.Fatalf("expected 'from-content', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeContentArraySkipsImage(t *testing.T) {
|
||||
got := normalizeContent([]any{
|
||||
map[string]any{"type": "image_url", "image_url": "https://example.com/img.png"},
|
||||
map[string]any{"type": "text", "text": "caption"},
|
||||
})
|
||||
if strings.Contains(got, "image") {
|
||||
t.Fatalf("expected image skipped, got %q", got)
|
||||
}
|
||||
if got != "caption" {
|
||||
t.Fatalf("expected 'caption', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeContentArrayNonMapItems(t *testing.T) {
|
||||
got := normalizeContent([]any{"string item", 42})
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty for non-map items, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeContentJSON(t *testing.T) {
|
||||
got := normalizeContent(map[string]any{"key": "value"})
|
||||
if !strings.Contains(got, `"key":"value"`) {
|
||||
t.Fatalf("expected JSON serialized, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── ConvertClaudeToDeepSeek edge cases ──────────────────────────────
|
||||
|
||||
func TestConvertClaudeToDeepSeekDefaultModel(t *testing.T) {
|
||||
store := config.LoadStore()
|
||||
req := map[string]any{
|
||||
"messages": []any{map[string]any{"role": "user", "content": "Hi"}},
|
||||
}
|
||||
out := ConvertClaudeToDeepSeek(req, store)
|
||||
if out["model"] == "" {
|
||||
t.Fatal("expected default model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeToDeepSeekWithStopSequences(t *testing.T) {
|
||||
store := config.LoadStore()
|
||||
req := map[string]any{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": []any{map[string]any{"role": "user", "content": "Hi"}},
|
||||
"stop_sequences": []any{"\n\n"},
|
||||
}
|
||||
out := ConvertClaudeToDeepSeek(req, store)
|
||||
if out["stop"] == nil {
|
||||
t.Fatal("expected stop field from stop_sequences")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeToDeepSeekWithTemperature(t *testing.T) {
|
||||
store := config.LoadStore()
|
||||
req := map[string]any{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": []any{map[string]any{"role": "user", "content": "Hi"}},
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
}
|
||||
out := ConvertClaudeToDeepSeek(req, store)
|
||||
if out["temperature"] != 0.7 {
|
||||
t.Fatalf("expected temperature 0.7, got %v", out["temperature"])
|
||||
}
|
||||
if out["top_p"] != 0.9 {
|
||||
t.Fatalf("expected top_p 0.9, got %v", out["top_p"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeToDeepSeekNoSystem(t *testing.T) {
|
||||
store := config.LoadStore()
|
||||
req := map[string]any{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": []any{map[string]any{"role": "user", "content": "Hi"}},
|
||||
}
|
||||
out := ConvertClaudeToDeepSeek(req, store)
|
||||
msgs, _ := out["messages"].([]any)
|
||||
if len(msgs) != 1 {
|
||||
t.Fatalf("expected 1 message without system, got %d", len(msgs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeToDeepSeekOpusUsesSlowMapping(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[],"claude_mapping":{"fast":"deepseek-chat","slow":"deepseek-reasoner"}}`)
|
||||
store := config.LoadStore()
|
||||
req := map[string]any{
|
||||
"model": "claude-opus-4-6",
|
||||
"messages": []any{map[string]any{"role": "user", "content": "Hi"}},
|
||||
}
|
||||
out := ConvertClaudeToDeepSeek(req, store)
|
||||
if out["model"] != "deepseek-reasoner" {
|
||||
t.Fatalf("expected opus to use slow mapping, got %q", out["model"])
|
||||
}
|
||||
}
|
||||
|
||||
// ─── FormatOpenAIStreamToolCalls ─────────────────────────────────────
|
||||
|
||||
func TestFormatOpenAIStreamToolCalls(t *testing.T) {
|
||||
formatted := FormatOpenAIStreamToolCalls([]ParsedToolCall{
|
||||
{Name: "search", Input: map[string]any{"q": "test"}},
|
||||
})
|
||||
if len(formatted) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(formatted))
|
||||
}
|
||||
fn, _ := formatted[0]["function"].(map[string]any)
|
||||
if fn["name"] != "search" {
|
||||
t.Fatalf("unexpected function name: %#v", fn)
|
||||
}
|
||||
if formatted[0]["index"] != 0 {
|
||||
t.Fatalf("expected index 0, got %v", formatted[0]["index"])
|
||||
}
|
||||
}
|
||||
|
||||
// ─── ParseToolCalls more edge cases ──────────────────────────────────
|
||||
|
||||
func TestParseToolCallsNoToolNames(t *testing.T) {
|
||||
text := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
|
||||
calls := ParseToolCalls(text, nil)
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call with nil tool names, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsEmptyText(t *testing.T) {
|
||||
calls := ParseToolCalls("", []string{"search"})
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected 0 calls for empty text, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsMultipleTools(t *testing.T) {
|
||||
text := `{"tool_calls":[{"name":"search","input":{"q":"go"}},{"name":"get_weather","input":{"city":"beijing"}}]}`
|
||||
calls := ParseToolCalls(text, []string{"search", "get_weather"})
|
||||
if len(calls) != 2 {
|
||||
t.Fatalf("expected 2 calls, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsInputAsString(t *testing.T) {
|
||||
text := `{"tool_calls":[{"name":"search","input":"{\"q\":\"golang\"}"}]}`
|
||||
calls := ParseToolCalls(text, []string{"search"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Input["q"] != "golang" {
|
||||
t.Fatalf("expected parsed string input, got %#v", calls[0].Input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsWithFunctionWrapper(t *testing.T) {
|
||||
text := `{"tool_calls":[{"function":{"name":"calc","arguments":{"x":1,"y":2}}}]}`
|
||||
calls := ParseToolCalls(text, []string{"calc"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Name != "calc" {
|
||||
t.Fatalf("expected calc, got %q", calls[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStandaloneToolCallsFencedCodeBlock(t *testing.T) {
|
||||
fenced := "Here's an example:\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\nDon't execute this."
|
||||
calls := ParseStandaloneToolCalls(fenced, []string{"search"})
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected fenced code block ignored, got %d calls", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
// ─── looksLikeToolExampleContext ─────────────────────────────────────
|
||||
|
||||
func TestLooksLikeToolExampleContextNone(t *testing.T) {
|
||||
if looksLikeToolExampleContext("I will call the tool now") {
|
||||
t.Fatal("expected false for non-example context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeToolExampleContextFenced(t *testing.T) {
|
||||
if !looksLikeToolExampleContext("```json") {
|
||||
t.Fatal("expected true for fenced code block context")
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,12 @@
|
||||
"apiKey": "your-api-key"
|
||||
},
|
||||
"models": {
|
||||
"gpt-4o": {
|
||||
"name": "GPT-4o (aliased to deepseek-chat)"
|
||||
},
|
||||
"gpt-5-codex": {
|
||||
"name": "GPT-5 Codex (aliased to deepseek-reasoner)"
|
||||
},
|
||||
"deepseek-chat": {
|
||||
"name": "DeepSeek Chat (DS2API)"
|
||||
},
|
||||
@@ -18,5 +24,5 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"model": "ds2api/deepseek-chat"
|
||||
"model": "ds2api/gpt-5-codex"
|
||||
}
|
||||
|
||||
@@ -39,6 +39,10 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch
|
||||
const [loadingAccounts, setLoadingAccounts] = useState(false)
|
||||
|
||||
const apiFetch = authFetch || fetch
|
||||
const resolveAccountIdentifier = (acc) => {
|
||||
if (!acc || typeof acc !== 'object') return ''
|
||||
return String(acc.identifier || acc.email || acc.mobile || '').trim()
|
||||
}
|
||||
|
||||
const fetchAccounts = async (targetPage = page) => {
|
||||
setLoadingAccounts(true)
|
||||
@@ -147,9 +151,14 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch
|
||||
}
|
||||
|
||||
const deleteAccount = async (id) => {
|
||||
const identifier = String(id || '').trim()
|
||||
if (!identifier) {
|
||||
onMessage('error', t('accountManager.invalidIdentifier'))
|
||||
return
|
||||
}
|
||||
if (!confirm(t('accountManager.deleteAccountConfirm'))) return
|
||||
try {
|
||||
const res = await apiFetch(`/admin/accounts/${encodeURIComponent(id)}`, { method: 'DELETE' })
|
||||
const res = await apiFetch(`/admin/accounts/${encodeURIComponent(identifier)}`, { method: 'DELETE' })
|
||||
if (res.ok) {
|
||||
onMessage('success', t('messages.deleted'))
|
||||
fetchAccounts() // 刷新当前页
|
||||
@@ -163,24 +172,29 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch
|
||||
}
|
||||
|
||||
const testAccount = async (identifier) => {
|
||||
setTesting(prev => ({ ...prev, [identifier]: true }))
|
||||
const accountID = String(identifier || '').trim()
|
||||
if (!accountID) {
|
||||
onMessage('error', t('accountManager.invalidIdentifier'))
|
||||
return
|
||||
}
|
||||
setTesting(prev => ({ ...prev, [accountID]: true }))
|
||||
try {
|
||||
const res = await apiFetch('/admin/accounts/test', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ identifier }),
|
||||
body: JSON.stringify({ identifier: accountID }),
|
||||
})
|
||||
const data = await res.json()
|
||||
const statusMessage = data.success
|
||||
? t('apiTester.testSuccess', { account: identifier, time: data.response_time })
|
||||
: `${identifier}: ${data.message}`
|
||||
? t('apiTester.testSuccess', { account: accountID, time: data.response_time })
|
||||
: `${accountID}: ${data.message}`
|
||||
onMessage(data.success ? 'success' : 'error', statusMessage)
|
||||
fetchAccounts() // 刷新当前页
|
||||
onRefresh()
|
||||
} catch (e) {
|
||||
onMessage('error', t('accountManager.testFailed', { error: e.message }))
|
||||
} finally {
|
||||
setTesting(prev => ({ ...prev, [identifier]: false }))
|
||||
setTesting(prev => ({ ...prev, [accountID]: false }))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -197,7 +211,12 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch
|
||||
|
||||
for (let i = 0; i < allAccounts.length; i++) {
|
||||
const acc = allAccounts[i]
|
||||
const id = acc.email || acc.mobile
|
||||
const id = resolveAccountIdentifier(acc)
|
||||
if (!id) {
|
||||
results.push({ id: '-', success: false, message: t('accountManager.invalidIdentifier') })
|
||||
setBatchProgress({ current: i + 1, total: allAccounts.length, results: [...results] })
|
||||
continue
|
||||
}
|
||||
|
||||
try {
|
||||
const res = await apiFetch('/admin/accounts/test', {
|
||||
@@ -387,7 +406,7 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch
|
||||
<div className="p-8 text-center text-muted-foreground">{t('actions.loading')}</div>
|
||||
) : accounts.length > 0 ? (
|
||||
accounts.map((acc, i) => {
|
||||
const id = acc.email || acc.mobile
|
||||
const id = resolveAccountIdentifier(acc)
|
||||
return (
|
||||
<div key={i} className="p-4 flex flex-col md:flex-row md:items-center justify-between gap-4 hover:bg-muted/50 transition-colors">
|
||||
<div className="flex items-center gap-3 min-w-0">
|
||||
@@ -396,7 +415,7 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch
|
||||
acc.has_token ? "bg-emerald-500 shadow-[0_0_8px_rgba(16,185,129,0.5)]" : "bg-amber-500"
|
||||
)} />
|
||||
<div className="min-w-0">
|
||||
<div className="font-medium truncate">{id}</div>
|
||||
<div className="font-medium truncate">{id || '-'}</div>
|
||||
<div className="flex items-center gap-2 text-xs text-muted-foreground mt-0.5">
|
||||
<span>{acc.has_token ? t('accountManager.sessionActive') : t('accountManager.reauthRequired')}</span>
|
||||
{acc.token_preview && (
|
||||
@@ -419,7 +438,7 @@ export default function AccountManager({ config, onRefresh, onMessage, authFetch
|
||||
onClick={() => deleteAccount(id)}
|
||||
className="p-1 lg:p-1.5 text-muted-foreground hover:text-destructive hover:bg-destructive/10 rounded-md transition-colors"
|
||||
>
|
||||
<Trash2 className="w-3.5 h-3.5 lg:w-4 h-4" />
|
||||
<Trash2 className="w-3.5 h-3.5 lg:w-4 lg:h-4" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -42,6 +42,10 @@ export default function ApiTester({ config, onMessage, authFetch }) {
|
||||
|
||||
const apiFetch = authFetch || fetch
|
||||
const accounts = config.accounts || []
|
||||
const resolveAccountIdentifier = (acc) => {
|
||||
if (!acc || typeof acc !== 'object') return ''
|
||||
return String(acc.identifier || acc.email || acc.mobile || '').trim()
|
||||
}
|
||||
const configuredKeys = config.keys || []
|
||||
const trimmedApiKey = apiKey.trim()
|
||||
const defaultKey = configuredKeys[0] || ''
|
||||
@@ -297,11 +301,15 @@ return (
|
||||
onChange={e => setSelectedAccount(e.target.value)}
|
||||
>
|
||||
<option value="" className="bg-popover text-popover-foreground">{t('apiTester.autoRandom')}</option>
|
||||
{accounts.map((acc, i) => (
|
||||
<option key={i} value={acc.email || acc.mobile} className="bg-popover text-popover-foreground">
|
||||
👤 {acc.email || acc.mobile}
|
||||
</option>
|
||||
))}
|
||||
{accounts.map((acc, i) => {
|
||||
const id = resolveAccountIdentifier(acc)
|
||||
if (!id) return null
|
||||
return (
|
||||
<option key={i} value={id} className="bg-popover text-popover-foreground">
|
||||
👤 {id}
|
||||
</option>
|
||||
)
|
||||
})}
|
||||
</select>
|
||||
<ChevronDown className="absolute right-2.5 top-3 w-4 h-4 text-muted-foreground pointer-events-none" />
|
||||
</div>
|
||||
|
||||
@@ -86,6 +86,7 @@
|
||||
"requiredFields": "Password and email/mobile are required.",
|
||||
"deleteKeyConfirm": "Are you sure you want to delete this API key?",
|
||||
"deleteAccountConfirm": "Are you sure you want to delete this account?",
|
||||
"invalidIdentifier": "Invalid account identifier. Operation aborted.",
|
||||
"testAllConfirm": "Test API connectivity for all accounts?",
|
||||
"testAllCompleted": "Completed: {success}/{total} available",
|
||||
"testFailed": "Test failed: {error}",
|
||||
|
||||
@@ -86,6 +86,7 @@
|
||||
"requiredFields": "需要填写密码以及邮箱或手机号",
|
||||
"deleteKeyConfirm": "确定要删除此 API 密钥吗?",
|
||||
"deleteAccountConfirm": "确定要删除此账号吗?",
|
||||
"invalidIdentifier": "账号标识无效,无法执行操作",
|
||||
"testAllConfirm": "测试所有账号的 API 连通性?",
|
||||
"testAllCompleted": "完成:{success}/{total} 可用",
|
||||
"testFailed": "测试失败: {error}",
|
||||
|
||||
Reference in New Issue
Block a user