diff --git a/.dockerignore b/.dockerignore index a33226a..fecc5f7 100644 --- a/.dockerignore +++ b/.dockerignore @@ -10,7 +10,9 @@ __pycache__ .Python build/ develop-eggs/ -dist/ +dist/* +!dist/docker-input/ +!dist/docker-input/*.tar.gz downloads/ eggs/ .eggs/ diff --git a/.env.example b/.env.example index 21a4d2a..d63f133 100644 --- a/.env.example +++ b/.env.example @@ -52,6 +52,9 @@ DS2API_ADMIN_KEY=admin # Option C: Base64 encoded JSON (recommended for Vercel env var) # DS2API_CONFIG_JSON=eyJrZXlzIjpbInlvdXItYXBpLWtleSJdLCJhY2NvdW50cyI6W3siZW1haWwiOiJ1c2VyQGV4YW1wbGUuY29tIiwicGFzc3dvcmQiOiJ4eHgiLCJ0b2tlbiI6IiJ9XX0= +# +# Generate from local config.json: +# DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" # --------------------------------------------------------------- # Paths (optional) diff --git a/.github/workflows/quality-gates.yml b/.github/workflows/quality-gates.yml new file mode 100644 index 0000000..3d7c9a1 --- /dev/null +++ b/.github/workflows/quality-gates.yml @@ -0,0 +1,40 @@ +name: Quality Gates + +on: + pull_request: + push: + branches: + - dev + +permissions: + contents: read + +jobs: + quality-gates: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: "1.24.x" + + - name: Setup Node + uses: actions/setup-node@v4 + with: + node-version: "20" + cache: "npm" + cache-dependency-path: webui/package-lock.json + + - name: Refactor Line Gate + run: ./tests/scripts/check-refactor-line-gate.sh + + - name: Unit Gates (Go + Node) + run: ./tests/scripts/run-unit-all.sh + + - name: WebUI Build Gate + run: | + npm ci --prefix webui + npm run build --prefix webui diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml index 4ed0cd0..3293fb8 100644 --- a/.github/workflows/release-artifacts.yml +++ b/.github/workflows/release-artifacts.yml @@ -4,6 +4,12 @@ on: release: types: - published + workflow_dispatch: + inputs: 
+ release_tag: + description: "Release tag to build/publish (e.g. v2.1.6)" + required: true + type: string permissions: contents: write @@ -12,6 +18,8 @@ permissions: jobs: build-and-upload: runs-on: ubuntu-latest + env: + RELEASE_TAG: ${{ github.event.release.tag_name || github.event.inputs.release_tag }} steps: - name: Checkout uses: actions/checkout@v4 @@ -28,6 +36,12 @@ jobs: cache: "npm" cache-dependency-path: webui/package-lock.json + - name: Release Blocking Gates + run: | + ./tests/scripts/check-stage6-manual-smoke.sh + ./tests/scripts/check-refactor-line-gate.sh + ./tests/scripts/run-unit-all.sh + - name: Build WebUI run: | npm ci --prefix webui @@ -36,7 +50,7 @@ jobs: - name: Build Multi-Platform Archives run: | set -euo pipefail - TAG="${{ github.event.release.tag_name }}" + TAG="${RELEASE_TAG}" mkdir -p dist targets=( @@ -73,15 +87,13 @@ jobs: rm -rf "${STAGE}" done - (cd dist && sha256sum *.tar.gz *.zip > sha256sums.txt) - - - name: Upload Release Assets - uses: softprops/action-gh-release@v2 - with: - files: | - dist/*.tar.gz - dist/*.zip - dist/sha256sums.txt + - name: Prepare Docker release inputs + run: | + set -euo pipefail + TAG="${RELEASE_TAG}" + mkdir -p dist/docker-input + cp "dist/ds2api_${TAG}_linux_amd64.tar.gz" "dist/docker-input/linux_amd64.tar.gz" + cp "dist/ds2api_${TAG}_linux_arm64.tar.gz" "dist/docker-input/linux_arm64.tar.gz" - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -89,28 +101,103 @@ jobs: - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - - name: Log in to GHCR - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} + - name: Wait for GHCR endpoint + run: | + set -euo pipefail + for i in {1..6}; do + code="$(curl -sS -o /dev/null -w '%{http_code}' --max-time 15 https://ghcr.io/v2/ || true)" + if [ "${code}" = "200" ] || [ "${code}" = "401" ] || [ "${code}" = "405" ]; then + exit 0 + fi + sleep "$((i * 10))" + done + echo "GHCR 
endpoint is unreachable after multiple retries (last status: ${code:-unknown})." >&2 + exit 1 + + - name: Log in to GHCR (with retry) + run: | + set -euo pipefail + for i in {1..6}; do + if echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin; then + exit 0 + fi + sleep "$((i * 10))" + done + echo "Failed to login to GHCR after multiple retries." >&2 + exit 1 - name: Extract Docker metadata - id: meta + id: meta_release uses: docker/metadata-action@v5 with: - images: ghcr.io/${{ github.repository }} + images: | + ghcr.io/${{ github.repository }} tags: | - type=raw,value=${{ github.event.release.tag_name }} + type=raw,value=${{ env.RELEASE_TAG }} type=raw,value=latest - name: Build and Push Docker Image uses: docker/build-push-action@v6 + env: + DOCKER_BUILD_RECORD_UPLOAD: "false" + DOCKER_BUILD_SUMMARY: "false" with: context: . file: ./Dockerfile + target: runtime-from-dist push: true platforms: linux/amd64,linux/arm64 - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} + tags: ${{ steps.meta_release.outputs.tags }} + labels: ${{ steps.meta_release.outputs.labels }} + + - name: Export Docker image archives for release assets + run: | + set -euo pipefail + TAG="${RELEASE_TAG}" + + docker buildx build \ + --platform linux/amd64 \ + --target runtime-from-dist \ + --output type=docker,dest="dist/ds2api_${TAG}_docker_linux_amd64.tar" \ + . + + docker buildx build \ + --platform linux/arm64 \ + --target runtime-from-dist \ + --output type=docker,dest="dist/ds2api_${TAG}_docker_linux_arm64.tar" \ + . 
+ + gzip -f "dist/ds2api_${TAG}_docker_linux_amd64.tar" + gzip -f "dist/ds2api_${TAG}_docker_linux_arm64.tar" + + - name: Generate checksums + run: | + set -euo pipefail + (cd dist && sha256sum *.tar.gz *.zip > sha256sums.txt) + + - name: Validate release tag + run: | + set -euo pipefail + TAG="${RELEASE_TAG}" + if [ -z "${TAG}" ]; then + echo "release tag is empty; set release_tag when using workflow_dispatch." >&2 + exit 1 + fi + + - name: Upload Release Assets + env: + GH_TOKEN: ${{ github.token }} + run: | + set -euo pipefail + TAG="${RELEASE_TAG}" + FILES=( + dist/*.tar.gz + dist/*.zip + dist/sha256sums.txt + ) + + if gh release view "${TAG}" >/dev/null 2>&1; then + gh release upload "${TAG}" "${FILES[@]}" --clobber + else + gh release create "${TAG}" "${FILES[@]}" --title "${TAG}" --notes "" + fi diff --git a/.gitignore b/.gitignore index 2221ddd..d096b58 100644 --- a/.gitignore +++ b/.gitignore @@ -42,6 +42,15 @@ ds2api-tests .env.local .env.*.local +# Testing +.coverage +htmlcov/ +.pytest_cache/ +.tox/ +*.coverprofile +coverage*.out +cover/ + # Misc .git/ Thumbs.db diff --git a/API.en.md b/API.en.md index e570dee..2b18d89 100644 --- a/API.en.md +++ b/API.en.md @@ -9,11 +9,13 @@ This document describes the actual behavior of the current Go codebase. ## Table of Contents - [Basics](#basics) +- [Configuration Best Practice](#configuration-best-practice) - [Authentication](#authentication) - [Route Index](#route-index) - [Health Endpoints](#health-endpoints) - [OpenAI-Compatible API](#openai-compatible-api) - [Claude-Compatible API](#claude-compatible-api) +- [Gemini-Compatible API](#gemini-compatible-api) - [Admin API](#admin-api) - [Error Payloads](#error-payloads) - [cURL Examples](#curl-examples) @@ -27,13 +29,35 @@ This document describes the actual behavior of the current Go codebase. 
| Base URL | `http://localhost:5001` or your deployment domain | | Default Content-Type | `application/json` | | Health probes | `GET /healthz`, `GET /readyz` | -| CORS | Enabled (`Access-Control-Allow-Origin: *`, allows `Content-Type`, `Authorization`) | +| CORS | Enabled (`Access-Control-Allow-Origin: *`, allows `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Vercel-Protection-Bypass`) | + +--- + +## Configuration Best Practice + +Use `config.json` as the single source of truth: + +```bash +cp config.example.json config.json +# Edit config.json (keys/accounts) +``` + +Use it per deployment mode: + +- Local run: read `config.json` directly +- Docker / Vercel: generate Base64 from `config.json`, then set `DS2API_CONFIG_JSON` + +```bash +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" +``` + +For Vercel one-click bootstrap, you can set only `DS2API_ADMIN_KEY` first, then import config at `/admin` and sync env vars from the "Vercel Sync" page. --- ## Authentication -### Business Endpoints (`/v1/*`, `/anthropic/*`) +### Business Endpoints (`/v1/*`, `/anthropic/*`, `/v1beta/models/*`) Two header formats accepted: @@ -66,15 +90,32 @@ Two header formats accepted: | GET | `/healthz` | None | Liveness probe | | GET | `/readyz` | None | Readiness probe | | GET | `/v1/models` | None | OpenAI model list | +| GET | `/v1/models/{id}` | None | OpenAI single-model query (alias accepted) | | POST | `/v1/chat/completions` | Business | OpenAI chat completions | +| POST | `/v1/responses` | Business | OpenAI Responses API (stream/non-stream) | +| GET | `/v1/responses/{response_id}` | Business | Query stored response (in-memory TTL) | +| POST | `/v1/embeddings` | Business | OpenAI Embeddings API | | GET | `/anthropic/v1/models` | None | Claude model list | | POST | `/anthropic/v1/messages` | Business | Claude messages | | POST | `/anthropic/v1/messages/count_tokens` | Business | Claude token counting | +| POST | `/v1/messages` | Business | Claude 
shortcut path | +| POST | `/messages` | Business | Claude shortcut path | +| POST | `/v1/messages/count_tokens` | Business | Claude token counting shortcut | +| POST | `/messages/count_tokens` | Business | Claude token counting shortcut | +| POST | `/v1beta/models/{model}:generateContent` | Business | Gemini non-stream | +| POST | `/v1beta/models/{model}:streamGenerateContent` | Business | Gemini stream | +| POST | `/v1/models/{model}:generateContent` | Business | Gemini non-stream compat path | +| POST | `/v1/models/{model}:streamGenerateContent` | Business | Gemini stream compat path | | POST | `/admin/login` | None | Admin login | | GET | `/admin/verify` | JWT | Verify admin JWT | | GET | `/admin/vercel/config` | Admin | Read preconfigured Vercel creds | | GET | `/admin/config` | Admin | Read sanitized config | | POST | `/admin/config` | Admin | Update config | +| GET | `/admin/settings` | Admin | Read runtime settings | +| PUT | `/admin/settings` | Admin | Update runtime settings (hot reload) | +| POST | `/admin/settings/password` | Admin | Update admin password and invalidate old JWTs | +| POST | `/admin/config/import` | Admin | Import config (merge/replace) | +| GET | `/admin/config/export` | Admin | Export full config (`config`/`json`/`base64`) | | POST | `/admin/keys` | Admin | Add API key | | DELETE | `/admin/keys/{key}` | Admin | Delete API key | | GET | `/admin/accounts` | Admin | Paginated account list | @@ -88,6 +129,8 @@ Two header formats accepted: | POST | `/admin/vercel/sync` | Admin | Sync config to Vercel | | GET | `/admin/vercel/status` | Admin | Vercel sync status | | GET | `/admin/export` | Admin | Export config JSON/Base64 | +| GET | `/admin/dev/captures` | Admin | Read local packet-capture entries | +| DELETE | `/admin/dev/captures` | Admin | Clear local packet-capture entries | --- @@ -127,6 +170,15 @@ No auth required. Returns supported models. 
} ``` +### Model Alias Resolution + +For `chat` / `responses` / `embeddings`, DS2API follows a wide-input/strict-output policy: + +1. Match DeepSeek native model IDs first. +2. Then match exact keys in `model_aliases`. +3. If still unmatched, fall back by known family heuristics (`o*`, `gpt-*`, `claude-*`, etc.). +4. If still unmatched, return `invalid_request_error`. + ### `POST /v1/chat/completions` **Headers**: @@ -140,7 +192,7 @@ Content-Type: application/json | Field | Type | Required | Notes | | --- | --- | --- | --- | -| `model` | string | ✅ | `deepseek-chat` / `deepseek-reasoner` / `deepseek-chat-search` / `deepseek-reasoner-search` | +| `model` | string | ✅ | DeepSeek native models + common aliases (`gpt-4o`, `gpt-5-codex`, `o3`, `claude-sonnet-4-5`, etc.) | | `messages` | array | ✅ | OpenAI-style messages | | `stream` | boolean | ❌ | Default `false` | | `tools` | array | ❌ | Function calling schema | @@ -230,12 +282,90 @@ When `tools` is present, DS2API performs anti-leak handling: } ``` -**Stream**: DS2API buffers text first. If tool call detected → only structured `delta.tool_calls` (each with `index`); otherwise emits buffered text at once. +**Stream**: Once high-confidence toolcall features are matched, DS2API emits `delta.tool_calls` immediately (without waiting for full JSON closure), then keeps sending argument deltas; confirmed raw tool JSON is never forwarded as `delta.content`. + +--- + +### `GET /v1/models/{id}` + +No auth required. Alias values are accepted as path params (for example `gpt-4o`), and the returned object is the mapped DeepSeek model. + +### `POST /v1/responses` + +OpenAI Responses-style endpoint, accepting either `input` or `messages`. 
+ +| Field | Type | Required | Notes | +| --- | --- | --- | --- | +| `model` | string | ✅ | Supports native models + alias mapping | +| `input` | string/array/object | ❌ | One of `input` or `messages` is required | +| `messages` | array | ❌ | One of `input` or `messages` is required | +| `instructions` | string | ❌ | Prepended as a system message | +| `stream` | boolean | ❌ | Default `false` | +| `tools` | array | ❌ | Same tool detection/translation policy as chat | +| `tool_choice` | string/object | ❌ | Supports `auto`/`none`/`required` and forced function selection (`{"type":"function","name":"..."}`) | + +**Non-stream**: Returns a standard `response` object with an ID like `resp_xxx`, and stores it in in-memory TTL cache. +If `tool_choice=required` and no valid tool call is produced, DS2API returns HTTP `422` (`error.code=tool_choice_violation`). + +**Stream (SSE)**: minimal event sequence: + +```text +event: response.created +data: {"type":"response.created","id":"resp_xxx","status":"in_progress",...} + +event: response.output_item.added +data: {"type":"response.output_item.added","response_id":"resp_xxx","item":{"type":"message|function_call",...},...} + +event: response.content_part.added +data: {"type":"response.content_part.added","response_id":"resp_xxx","part":{"type":"output_text",...},...} + +event: response.output_text.delta +data: {"type":"response.output_text.delta","response_id":"resp_xxx","item_id":"msg_xxx","output_index":0,"content_index":0,"delta":"..."} + +event: response.function_call_arguments.delta +data: {"type":"response.function_call_arguments.delta","response_id":"resp_xxx","call_id":"call_xxx","delta":"..."} + +event: response.function_call_arguments.done +data: {"type":"response.function_call_arguments.done","response_id":"resp_xxx","call_id":"call_xxx","name":"tool","arguments":"{...}"} + +event: response.content_part.done +data: {"type":"response.content_part.done","response_id":"resp_xxx",...} + +event: response.output_item.done 
+data: {"type":"response.output_item.done","response_id":"resp_xxx","item":{"type":"message|function_call",...},...} + +event: response.completed +data: {"type":"response.completed","response":{...}} + +data: [DONE] +``` + +If `tool_choice=required` is violated in stream mode, DS2API emits `response.failed` then `[DONE]` (no `response.completed`). +Unknown tool names (outside declared `tools`) are rejected and will not be emitted as valid tool calls. + +### `GET /v1/responses/{response_id}` + +Business auth required. Fetches cached responses created by `POST /v1/responses` (caller-scoped; only the same key/token can read). + +> Backed by in-memory TTL store. Default TTL is `900s` (configurable via `responses.store_ttl_seconds`). + +### `POST /v1/embeddings` + +Business auth required. Returns OpenAI-compatible embeddings shape. + +| Field | Type | Required | Notes | +| --- | --- | --- | --- | +| `model` | string | ✅ | Supports native models + alias mapping | +| `input` | string/array | ✅ | Supports string, string array, token array | + +> Requires `embeddings.provider`. Current supported values: `mock` / `deterministic` / `builtin`. If missing/unsupported, returns standard error shape with HTTP 501. --- ## Claude-Compatible API +Besides `/anthropic/v1/*`, DS2API also supports shortcut paths: `/v1/messages`, `/messages`, `/v1/messages/count_tokens`, `/messages/count_tokens`. + ### `GET /anthropic/v1/models` No auth required. @@ -249,7 +379,10 @@ No auth required. 
{"id": "claude-sonnet-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"}, {"id": "claude-haiku-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"}, {"id": "claude-opus-4-6", "object": "model", "created": 1715635200, "owned_by": "anthropic"} - ] + ], + "first_id": "claude-opus-4-6", + "last_id": "claude-instant-1.0", + "has_more": false } ``` @@ -265,13 +398,15 @@ Content-Type: application/json anthropic-version: 2023-06-01 ``` +> `anthropic-version` is optional; DS2API auto-fills `2023-06-01` when absent. + **Request body**: | Field | Type | Required | Notes | | --- | --- | --- | --- | | `model` | string | ✅ | For example `claude-sonnet-4-5` / `claude-opus-4-6` / `claude-haiku-4-5` (compatible with `claude-3-5-haiku-latest`), plus historical Claude model IDs | | `messages` | array | ✅ | Claude-style messages | -| `max_tokens` | number | ❌ | Not strictly enforced by upstream bridge | +| `max_tokens` | number | ❌ | Auto-filled to `8192` when omitted; not strictly enforced by upstream bridge | | `stream` | boolean | ❌ | Default `false` | | `system` | string | ❌ | Optional system prompt | | `tools` | array | ❌ | Claude tool schema | @@ -354,6 +489,37 @@ data: {"type":"message_stop"} --- +## Gemini-Compatible API + +Supported paths: + +- `/v1beta/models/{model}:generateContent` +- `/v1beta/models/{model}:streamGenerateContent` +- `/v1/models/{model}:generateContent` (compat path) +- `/v1/models/{model}:streamGenerateContent` (compat path) + +Authentication is the same as other business routes (`Authorization: Bearer ` or `x-api-key`). + +### `POST /v1beta/models/{model}:generateContent` + +Request body accepts Gemini-style `contents` / `tools`. Model names can use aliases and are mapped to DeepSeek models. 
+ +Response uses Gemini-compatible fields, including: + +- `candidates[].content.parts[].text` +- `candidates[].content.parts[].functionCall` (when tool call is produced) +- `usageMetadata` (`promptTokenCount` / `candidatesTokenCount` / `totalTokenCount`) + +### `POST /v1beta/models/{model}:streamGenerateContent` + +Returns SSE (`text/event-stream`), each chunk as `data: `: + +- regular text: incremental text chunks +- `tools` mode: buffered and emitted as `functionCall` at finalize phase +- final chunk: includes `finishReason: "STOP"` and `usageMetadata` + +--- + ## Admin API ### `POST /admin/login` @@ -416,6 +582,7 @@ Returns sanitized config. "keys": ["k1", "k2"], "accounts": [ { + "identifier": "user@example.com", "email": "user@example.com", "mobile": "", "has_password": true, @@ -449,6 +616,51 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`. } ``` +### `GET /admin/settings` + +Reads runtime settings and status, including: + +- `admin` (JWT expiry, default-password warning, etc.) +- `runtime` (`account_max_inflight`, `account_max_queue`, `global_max_inflight`) +- `toolcall` / `responses` / `embeddings` +- `claude_mapping` / `model_aliases` +- `env_backed`, `needs_vercel_sync` + +### `PUT /admin/settings` + +Hot-updates runtime settings. Supported fields: + +- `admin.jwt_expire_hours` +- `runtime.account_max_inflight` / `runtime.account_max_queue` / `runtime.global_max_inflight` +- `toolcall.mode` / `toolcall.early_emit_confidence` +- `responses.store_ttl_seconds` +- `embeddings.provider` +- `claude_mapping` +- `model_aliases` + +### `POST /admin/settings/password` + +Updates admin password and invalidates existing JWTs. + +Request example: + +```json +{"new_password":"your-new-password"} +``` + +### `POST /admin/config/import` + +Imports full config with: + +- `mode=merge` (default) +- `mode=replace` + +The request can send config directly, or wrapped as `{"config": {...}, "mode":"merge"}`. 
+ +### `GET /admin/config/export` + +Exports full config in three forms: `config`, `json`, and `base64`. + ### `POST /admin/keys` ```json @@ -476,6 +688,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`. { "items": [ { + "identifier": "user@example.com", "email": "user@example.com", "mobile": "", "has_password": true, @@ -500,7 +713,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`. ### `DELETE /admin/accounts/{identifier}` -`identifier` is email or mobile. +`identifier` can be email, mobile, or the synthetic id for token-only accounts (`token:`). **Response**: `{"success": true, "total_accounts": 5}` @@ -530,7 +743,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`. | Field | Required | Notes | | --- | --- | --- | -| `identifier` | ✅ | email or mobile | +| `identifier` | ✅ | email / mobile / token-only synthetic id | | `model` | ❌ | default `deepseek-chat` | | `message` | ❌ | if empty, only session creation is tested | @@ -655,17 +868,53 @@ Or manual deploy required: } ``` +### `GET /admin/dev/captures` + +Reads local packet-capture status and recent entries (Admin auth required): + +- `enabled` +- `limit` +- `max_body_bytes` +- `items` + +### `DELETE /admin/dev/captures` + +Clears packet-capture entries: + +```json +{"success":true,"detail":"capture logs cleared"} +``` + --- ## Error Payloads -Error formats vary by module: +Compatible routes (`/v1/*`, `/anthropic/*`) use the same error envelope: -| Module | Format | -| --- | --- | -| OpenAI routes | `{"error": {"message": "...", "type": "..."}}` | -| Claude routes | `{"error": {"type": "...", "message": "..."}}` | -| Admin routes | `{"detail": "..."}` | +```json +{ + "error": { + "message": "...", + "type": "invalid_request_error", + "code": "invalid_request", + "param": null + } +} +``` + +Admin routes keep `{"detail":"..."}`. 
+ +Gemini routes use Google-style errors: + +```json +{ + "error": { + "code": 400, + "message": "invalid json", + "status": "INVALID_ARGUMENT" + } +} +``` Clients should handle HTTP status code plus `error` / `detail` fields. @@ -707,6 +956,31 @@ curl http://localhost:5001/v1/chat/completions \ }' ``` +### OpenAI Responses (Stream) + +```bash +curl http://localhost:5001/v1/responses \ + -H "Authorization: Bearer your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-5-codex", + "input": "Write a hello world in golang", + "stream": true + }' +``` + +### OpenAI Embeddings + +```bash +curl http://localhost:5001/v1/embeddings \ + -H "Authorization: Bearer your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-4o", + "input": ["first text", "second text"] + }' +``` + ### OpenAI with Search ```bash @@ -748,6 +1022,38 @@ curl http://localhost:5001/v1/chat/completions \ }' ``` +### Gemini Non-Stream + +```bash +curl "http://localhost:5001/v1beta/models/gemini-2.5-pro:generateContent" \ + -H "Authorization: Bearer your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [ + { + "role": "user", + "parts": [{"text": "Introduce Go in three sentences"}] + } + ] + }' +``` + +### Gemini Stream + +```bash +curl "http://localhost:5001/v1beta/models/gemini-2.5-flash:streamGenerateContent" \ + -H "Authorization: Bearer your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [ + { + "role": "user", + "parts": [{"text": "Write a short summary"}] + } + ] + }' +``` + ### Claude Non-Stream ```bash diff --git a/API.md b/API.md index 6be7f65..0647c4e 100644 --- a/API.md +++ b/API.md @@ -9,11 +9,13 @@ ## 目录 - [基础信息](#基础信息) +- [配置最佳实践](#配置最佳实践) - [鉴权规则](#鉴权规则) - [路由总览](#路由总览) - [健康检查](#健康检查) - [OpenAI 兼容接口](#openai-兼容接口) - [Claude 兼容接口](#claude-兼容接口) +- [Gemini 兼容接口](#gemini-兼容接口) - [Admin 接口](#admin-接口) - [错误响应格式](#错误响应格式) - [cURL 示例](#curl-示例) @@ -27,13 +29,35 @@ | Base URL | 
`http://localhost:5001` 或你的部署域名 | | 默认 Content-Type | `application/json` | | 健康检查 | `GET /healthz`、`GET /readyz` | -| CORS | 已启用(`Access-Control-Allow-Origin: *`,允许 `Content-Type`, `Authorization`) | +| CORS | 已启用(`Access-Control-Allow-Origin: *`,允许 `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Vercel-Protection-Bypass`) | + +--- + +## 配置最佳实践 + +推荐把 `config.json` 作为唯一配置源: + +```bash +cp config.example.json config.json +# 编辑 config.json(keys/accounts) +``` + +按部署方式使用: + +- 本地运行:直接读取 `config.json` +- Docker / Vercel:从 `config.json` 生成 Base64,填入 `DS2API_CONFIG_JSON` + +```bash +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" +``` + +Vercel 一键部署可先只填 `DS2API_ADMIN_KEY`,部署后在 `/admin` 导入配置,再通过 “Vercel 同步” 写回环境变量。 --- ## 鉴权规则 -### 业务接口(`/v1/*`、`/anthropic/*`) +### 业务接口(`/v1/*`、`/anthropic/*`、`/v1beta/models/*`) 支持两种传参方式: @@ -66,15 +90,32 @@ | GET | `/healthz` | 无 | 存活探针 | | GET | `/readyz` | 无 | 就绪探针 | | GET | `/v1/models` | 无 | OpenAI 模型列表 | +| GET | `/v1/models/{id}` | 无 | OpenAI 单模型查询(支持 alias 入参) | | POST | `/v1/chat/completions` | 业务 | OpenAI 对话补全 | +| POST | `/v1/responses` | 业务 | OpenAI Responses 接口(流式/非流式) | +| GET | `/v1/responses/{response_id}` | 业务 | 查询已生成 response(内存 TTL) | +| POST | `/v1/embeddings` | 业务 | OpenAI Embeddings 接口 | | GET | `/anthropic/v1/models` | 无 | Claude 模型列表 | | POST | `/anthropic/v1/messages` | 业务 | Claude 消息接口 | | POST | `/anthropic/v1/messages/count_tokens` | 业务 | Claude token 计数 | +| POST | `/v1/messages` | 业务 | Claude 消息快捷路径 | +| POST | `/messages` | 业务 | Claude 消息快捷路径 | +| POST | `/v1/messages/count_tokens` | 业务 | Claude token 计数快捷路径 | +| POST | `/messages/count_tokens` | 业务 | Claude token 计数快捷路径 | +| POST | `/v1beta/models/{model}:generateContent` | 业务 | Gemini 非流式 | +| POST | `/v1beta/models/{model}:streamGenerateContent` | 业务 | Gemini 流式 | +| POST | `/v1/models/{model}:generateContent` | 业务 | Gemini 非流式兼容路径 | +| POST | `/v1/models/{model}:streamGenerateContent` | 业务 | Gemini 流式兼容路径 | | POST | 
`/admin/login` | 无 | 管理登录 | | GET | `/admin/verify` | JWT | 校验管理 JWT | | GET | `/admin/vercel/config` | Admin | 读取 Vercel 预配置 | | GET | `/admin/config` | Admin | 读取配置(脱敏) | | POST | `/admin/config` | Admin | 更新配置 | +| GET | `/admin/settings` | Admin | 读取运行时设置 | +| PUT | `/admin/settings` | Admin | 更新运行时设置(热更新) | +| POST | `/admin/settings/password` | Admin | 更新 Admin 密码并使旧 JWT 失效 | +| POST | `/admin/config/import` | Admin | 导入配置(merge/replace) | +| GET | `/admin/config/export` | Admin | 导出完整配置(含 `config`/`json`/`base64`) | | POST | `/admin/keys` | Admin | 添加 API key | | DELETE | `/admin/keys/{key}` | Admin | 删除 API key | | GET | `/admin/accounts` | Admin | 分页账号列表 | @@ -88,6 +129,8 @@ | POST | `/admin/vercel/sync` | Admin | 同步配置到 Vercel | | GET | `/admin/vercel/status` | Admin | Vercel 同步状态 | | GET | `/admin/export` | Admin | 导出配置 JSON/Base64 | +| GET | `/admin/dev/captures` | Admin | 查看本地抓包记录 | +| DELETE | `/admin/dev/captures` | Admin | 清空本地抓包记录 | --- @@ -127,6 +170,15 @@ } ``` +### 模型 alias 解析策略 + +对 `chat` / `responses` / `embeddings` 的 `model` 字段采用“宽进严出”: + +1. 先匹配 DeepSeek 原生模型。 +2. 再匹配 `model_aliases` 精确映射。 +3. 未命中时按模型家族规则回退(如 `o*`、`gpt-*`、`claude-*`)。 +4. 
仍未命中则返回 `invalid_request_error`。 + ### `POST /v1/chat/completions` **请求头**: @@ -140,7 +192,7 @@ Content-Type: application/json | 字段 | 类型 | 必填 | 说明 | | --- | --- | --- | --- | -| `model` | string | ✅ | `deepseek-chat` / `deepseek-reasoner` / `deepseek-chat-search` / `deepseek-reasoner-search` | +| `model` | string | ✅ | 支持 DeepSeek 原生模型 + 常见 alias(如 `gpt-4o`、`gpt-5-codex`、`o3`、`claude-sonnet-4-5`) | | `messages` | array | ✅ | OpenAI 风格消息数组 | | `stream` | boolean | ❌ | 默认 `false` | | `tools` | array | ❌ | Function Calling 定义 | @@ -230,12 +282,90 @@ data: [DONE] } ``` -**流式**:先缓冲正文片段。识别到工具调用 → 仅输出结构化 `delta.tool_calls`(每个 tool call 带 `index`);否则一次性输出普通文本。 +**流式**:命中高置信特征后立即输出 `delta.tool_calls`(不等待完整 JSON 闭合),并持续发送 arguments 增量;已确认的 toolcall 原始 JSON 不会回流到 `delta.content`。 + +--- + +### `GET /v1/models/{id}` + +无需鉴权。入参支持 alias(例如 `gpt-4o`),返回的是映射后的 DeepSeek 模型对象。 + +### `POST /v1/responses` + +OpenAI Responses 风格接口,兼容 `input` 或 `messages`。 + +| 字段 | 类型 | 必填 | 说明 | +| --- | --- | --- | --- | +| `model` | string | ✅ | 支持原生模型 + alias 自动映射 | +| `input` | string/array/object | ❌ | 与 `messages` 二选一 | +| `messages` | array | ❌ | 与 `input` 二选一 | +| `instructions` | string | ❌ | 自动前置为 system 消息 | +| `stream` | boolean | ❌ | 默认 `false` | +| `tools` | array | ❌ | 与 chat 同样的工具识别与转译策略 | +| `tool_choice` | string/object | ❌ | 支持 `auto`/`none`/`required` 与强制函数(`{"type":"function","name":"..."}`) | + +**非流式响应**:返回标准 `response` 对象,`id` 形如 `resp_xxx`,并写入内存 TTL 存储。 +当 `tool_choice=required` 且未产出有效工具调用时,返回 HTTP `422`(`error.code=tool_choice_violation`)。 + +**流式响应(SSE)**:最小事件序列如下。 + +```text +event: response.created +data: {"type":"response.created","id":"resp_xxx","status":"in_progress",...} + +event: response.output_item.added +data: {"type":"response.output_item.added","response_id":"resp_xxx","item":{"type":"message|function_call",...},...} + +event: response.content_part.added +data: {"type":"response.content_part.added","response_id":"resp_xxx","part":{"type":"output_text",...},...} 
+ +event: response.output_text.delta +data: {"type":"response.output_text.delta","response_id":"resp_xxx","item_id":"msg_xxx","output_index":0,"content_index":0,"delta":"..."} + +event: response.function_call_arguments.delta +data: {"type":"response.function_call_arguments.delta","response_id":"resp_xxx","call_id":"call_xxx","delta":"..."} + +event: response.function_call_arguments.done +data: {"type":"response.function_call_arguments.done","response_id":"resp_xxx","call_id":"call_xxx","name":"tool","arguments":"{...}"} + +event: response.content_part.done +data: {"type":"response.content_part.done","response_id":"resp_xxx",...} + +event: response.output_item.done +data: {"type":"response.output_item.done","response_id":"resp_xxx","item":{"type":"message|function_call",...},...} + +event: response.completed +data: {"type":"response.completed","response":{...}} + +data: [DONE] +``` + +流式场景下若 `tool_choice=required` 违规,会返回 `response.failed` 后结束(不再发送 `response.completed`)。 +未在 `tools` 声明中的工具名会被严格拒绝,不会作为有效 tool call 下发。 + +### `GET /v1/responses/{response_id}` + +需要业务鉴权。查询 `POST /v1/responses` 生成并缓存的 response 对象(按调用方鉴权隔离,仅同一 key/token 可读取)。 + +> 当前为内存 TTL 存储,默认过期时间 `900s`(可用 `responses.store_ttl_seconds` 调整)。 + +### `POST /v1/embeddings` + +需要业务鉴权。返回 OpenAI Embeddings 兼容结构。 + +| 字段 | 类型 | 必填 | 说明 | +| --- | --- | --- | --- | +| `model` | string | ✅ | 支持原生模型 + alias 自动映射 | +| `input` | string/array | ✅ | 支持字符串、字符串数组、token 数组 | + +> 需配置 `embeddings.provider`。当前支持:`mock` / `deterministic` / `builtin`。未配置或不支持时返回标准错误结构(HTTP 501)。 --- ## Claude 兼容接口 +除标准路径 `/anthropic/v1/*` 外,还支持快捷路径 `/v1/messages`、`/messages`、`/v1/messages/count_tokens`、`/messages/count_tokens`。 + ### `GET /anthropic/v1/models` 无需鉴权。 @@ -249,7 +379,10 @@ data: [DONE] {"id": "claude-sonnet-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"}, {"id": "claude-haiku-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"}, {"id": "claude-opus-4-6", "object": "model", 
"created": 1715635200, "owned_by": "anthropic"} - ] + ], + "first_id": "claude-opus-4-6", + "last_id": "claude-instant-1.0", + "has_more": false } ``` @@ -265,13 +398,15 @@ Content-Type: application/json anthropic-version: 2023-06-01 ``` +> `anthropic-version` 可省略,服务端会自动补为 `2023-06-01`。 + **请求体**: | 字段 | 类型 | 必填 | 说明 | | --- | --- | --- | --- | | `model` | string | ✅ | 例如 `claude-sonnet-4-5` / `claude-opus-4-6` / `claude-haiku-4-5`(兼容 `claude-3-5-haiku-latest`),并支持历史 Claude 模型 ID | | `messages` | array | ✅ | Claude 风格消息数组 | -| `max_tokens` | number | ❌ | 当前实现不会硬性截断上游输出 | +| `max_tokens` | number | ❌ | 缺省自动补 `8192`;当前实现不会硬性截断上游输出 | | `stream` | boolean | ❌ | 默认 `false` | | `system` | string | ❌ | 可选系统提示 | | `tools` | array | ❌ | Claude tool 定义 | @@ -354,6 +489,37 @@ data: {"type":"message_stop"} --- +## Gemini 兼容接口 + +支持路径: + +- `/v1beta/models/{model}:generateContent` +- `/v1beta/models/{model}:streamGenerateContent` +- `/v1/models/{model}:generateContent`(兼容路径) +- `/v1/models/{model}:streamGenerateContent`(兼容路径) + +鉴权方式同业务接口(`Authorization: Bearer ` 或 `x-api-key`)。 + +### `POST /v1beta/models/{model}:generateContent` + +请求体兼容 Gemini `contents` / `tools` 字段,模型名可用 alias 自动映射到 DeepSeek 模型。 + +响应为 Gemini 兼容结构,核心字段包括: + +- `candidates[].content.parts[].text` +- `candidates[].content.parts[].functionCall`(工具调用时) +- `usageMetadata`(`promptTokenCount` / `candidatesTokenCount` / `totalTokenCount`) + +### `POST /v1beta/models/{model}:streamGenerateContent` + +返回 SSE(`text/event-stream`),每个 chunk 为一条 `data: `: + +- 常规文本:持续返回增量文本 chunk +- `tools` 场景:会缓冲并在结束时输出 `functionCall` 结构 +- 结束 chunk:包含 `finishReason: "STOP"` 与 `usageMetadata` + +--- + ## Admin 接口 ### `POST /admin/login` @@ -416,6 +582,7 @@ data: {"type":"message_stop"} "keys": ["k1", "k2"], "accounts": [ { + "identifier": "user@example.com", "email": "user@example.com", "mobile": "", "has_password": true, @@ -449,6 +616,51 @@ data: {"type":"message_stop"} } ``` +### `GET /admin/settings` + +读取运行时设置与状态,返回: + +- 
`admin`(JWT 过期、默认密码告警等) +- `runtime`(`account_max_inflight`、`account_max_queue`、`global_max_inflight`) +- `toolcall` / `responses` / `embeddings` +- `claude_mapping` / `model_aliases` +- `env_backed`、`needs_vercel_sync` + +### `PUT /admin/settings` + +热更新运行时设置。支持更新: + +- `admin.jwt_expire_hours` +- `runtime.account_max_inflight` / `runtime.account_max_queue` / `runtime.global_max_inflight` +- `toolcall.mode` / `toolcall.early_emit_confidence` +- `responses.store_ttl_seconds` +- `embeddings.provider` +- `claude_mapping` +- `model_aliases` + +### `POST /admin/settings/password` + +更新管理密码并使旧 JWT 失效。 + +请求示例: + +```json +{"new_password":"your-new-password"} +``` + +### `POST /admin/config/import` + +导入完整配置,支持: + +- `mode=merge`(默认) +- `mode=replace` + +请求可直接传配置对象,或使用 `{"config": {...}, "mode":"merge"}` 包裹格式。 + +### `GET /admin/config/export` + +导出完整配置,返回 `config`、`json`、`base64` 三种格式。 + ### `POST /admin/keys` ```json @@ -476,6 +688,7 @@ data: {"type":"message_stop"} { "items": [ { + "identifier": "user@example.com", "email": "user@example.com", "mobile": "", "has_password": true, @@ -500,7 +713,7 @@ data: {"type":"message_stop"} ### `DELETE /admin/accounts/{identifier}` -`identifier` 为 email 或 mobile。 +`identifier` 可为 email、mobile,或 token-only 账号的合成标识(`token:`)。 **响应**:`{"success": true, "total_accounts": 5}` @@ -530,7 +743,7 @@ data: {"type":"message_stop"} | 字段 | 必填 | 说明 | | --- | --- | --- | -| `identifier` | ✅ | email 或 mobile | +| `identifier` | ✅ | email / mobile / token-only 合成标识 | | `model` | ❌ | 默认 `deepseek-chat` | | `message` | ❌ | 空字符串时仅测试会话创建 | @@ -655,17 +868,53 @@ data: {"type":"message_stop"} } ``` +### `GET /admin/dev/captures` + +查看本地抓包状态与最近记录(需 Admin 鉴权): + +- `enabled` +- `limit` +- `max_body_bytes` +- `items` + +### `DELETE /admin/dev/captures` + +清空抓包记录,返回: + +```json +{"success":true,"detail":"capture logs cleared"} +``` + --- ## 错误响应格式 -不同模块的错误格式略有差异: +兼容路由(`/v1/*`、`/anthropic/*`)统一使用以下结构: -| 模块 | 格式 | -| --- | --- | -| OpenAI 接口 | `{"error": 
{"message": "...", "type": "..."}}` | -| Claude 接口 | `{"error": {"type": "...", "message": "..."}}` | -| Admin 接口 | `{"detail": "..."}` | +```json +{ + "error": { + "message": "...", + "type": "invalid_request_error", + "code": "invalid_request", + "param": null + } +} +``` + +Admin 接口保持 `{"detail":"..."}`。 + +Gemini 路由使用 Google 风格错误结构: + +```json +{ + "error": { + "code": 400, + "message": "invalid json", + "status": "INVALID_ARGUMENT" + } +} +``` 建议客户端处理逻辑:检查 HTTP 状态码 + 解析 `error` 或 `detail` 字段。 @@ -707,6 +956,31 @@ curl http://localhost:5001/v1/chat/completions \ }' ``` +### OpenAI Responses(流式) + +```bash +curl http://localhost:5001/v1/responses \ + -H "Authorization: Bearer your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-5-codex", + "input": "写一个 golang 的 hello world", + "stream": true + }' +``` + +### OpenAI Embeddings + +```bash +curl http://localhost:5001/v1/embeddings \ + -H "Authorization: Bearer your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-4o", + "input": ["第一段文本", "第二段文本"] + }' +``` + ### OpenAI 带搜索 ```bash @@ -748,6 +1022,38 @@ curl http://localhost:5001/v1/chat/completions \ }' ``` +### Gemini 非流式 + +```bash +curl "http://localhost:5001/v1beta/models/gemini-2.5-pro:generateContent" \ + -H "Authorization: Bearer your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [ + { + "role": "user", + "parts": [{"text": "用三句话介绍 Go 语言"}] + } + ] + }' +``` + +### Gemini 流式 + +```bash +curl "http://localhost:5001/v1beta/models/gemini-2.5-flash:streamGenerateContent" \ + -H "Authorization: Bearer your-api-key" \ + -H "Content-Type: application/json" \ + -d '{ + "contents": [ + { + "role": "user", + "parts": [{"text": "写一个简短摘要"}] + } + ] + }' +``` + ### Claude 非流式 ```bash diff --git a/CONTRIBUTING.en.md b/CONTRIBUTING.en.md index baf5eae..156b025 100644 --- a/CONTRIBUTING.en.md +++ b/CONTRIBUTING.en.md @@ -82,11 +82,11 @@ Manually build WebUI to `static/admin/`: ## Running Tests 
```bash -# Go unit tests -go test ./... +# Go + Node unit tests (recommended) +./tests/scripts/run-unit-all.sh # End-to-end live tests (real accounts) -./scripts/testsuite/run-live.sh +./tests/scripts/run-live.sh ``` ## Project Structure @@ -104,13 +104,20 @@ ds2api/ │ ├── account/ # Account pool and concurrency queue │ ├── adapter/ │ │ ├── openai/ # OpenAI adapter -│ │ └── claude/ # Claude adapter +│ │ ├── claude/ # Claude adapter +│ │ └── gemini/ # Gemini adapter │ ├── admin/ # Admin API handlers │ ├── auth/ # Auth and JWT +│ ├── claudeconv/ # Claude message conversion +│ ├── compat/ # Compatibility helpers │ ├── config/ # Config loading and hot-reload │ ├── deepseek/ # DeepSeek client, PoW WASM +│ ├── devcapture/ # Dev packet capture +│ ├── format/ # Output formatting +│ ├── prompt/ # Prompt building │ ├── server/ # HTTP routing (chi router) │ ├── sse/ # SSE parsing utilities +│ ├── stream/ # Unified stream consumption engine │ ├── testsuite/ # Testsuite core logic │ ├── util/ # Common utilities │ └── webui/ # WebUI static hosting diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c75d450..fd44b9a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -82,11 +82,11 @@ docker-compose -f docker-compose.dev.yml up ## 运行测试 ```bash -# Go 单元测试 -go test ./... 
+# Go + Node 单元测试(推荐) +./tests/scripts/run-unit-all.sh # 端到端全链路测试(真实账号) -./scripts/testsuite/run-live.sh +./tests/scripts/run-live.sh ``` ## 项目结构 @@ -104,13 +104,20 @@ ds2api/ │ ├── account/ # 账号池与并发队列 │ ├── adapter/ │ │ ├── openai/ # OpenAI 兼容适配器 -│ │ └── claude/ # Claude 兼容适配器 +│ │ ├── claude/ # Claude 兼容适配器 +│ │ └── gemini/ # Gemini 兼容适配器 │ ├── admin/ # Admin API handlers │ ├── auth/ # 鉴权与 JWT +│ ├── claudeconv/ # Claude 消息格式转换 +│ ├── compat/ # 兼容性辅助 │ ├── config/ # 配置加载与热更新 │ ├── deepseek/ # DeepSeek 客户端、PoW WASM +│ ├── devcapture/ # 开发抓包 +│ ├── format/ # 输出格式化 +│ ├── prompt/ # Prompt 构建 │ ├── server/ # HTTP 路由(chi router) │ ├── sse/ # SSE 解析工具 +│ ├── stream/ # 统一流式消费引擎 │ ├── testsuite/ # 测试集核心逻辑 │ ├── util/ # 通用工具 │ └── webui/ # WebUI 静态托管 diff --git a/DEPLOY.en.md b/DEPLOY.en.md index b7caf8c..2ff3250 100644 --- a/DEPLOY.en.md +++ b/DEPLOY.en.md @@ -33,6 +33,17 @@ Config source (choose one): - **File**: `config.json` (recommended for local/Docker) - **Environment variable**: `DS2API_CONFIG_JSON` (recommended for Vercel; supports raw JSON or Base64) +Unified recommendation (best practice): + +```bash +cp config.example.json config.json +# Edit config.json +``` + +Use `config.json` as the single source of truth: +- Local run: read `config.json` directly +- Docker / Vercel: generate `DS2API_CONFIG_JSON` (Base64) from `config.json` and inject it + --- ## 1. 
Local Run @@ -99,11 +110,15 @@ go build -o ds2api ./cmd/ds2api ### 2.1 Basic Steps ```bash -# Copy and edit environment +# Copy env template cp .env.example .env -# Edit .env, at minimum set: + +# Generate single-line Base64 from config.json +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" + +# Edit .env and set: # DS2API_ADMIN_KEY=your-admin-key -# DS2API_CONFIG_JSON={"keys":[...],"accounts":[...]} +# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON} # Start docker-compose up -d @@ -120,11 +135,12 @@ docker-compose up -d --build ### 2.3 Docker Architecture -The `Dockerfile` uses a three-stage build: +The `Dockerfile` now provides two image paths: -1. **WebUI build stage**: `node:20` image, runs `npm ci && npm run build` -2. **Go build stage**: `golang:1.24` image, compiles the binary -3. **Runtime stage**: `debian:bookworm-slim` minimal image +1. **Default local/dev path (`runtime-from-source`)**: a three-stage build (WebUI build + Go build + runtime). +2. **Release path (`runtime-from-dist`)**: CI first creates `dist/ds2api__linux_.tar.gz`, then Docker directly reuses the binary and `static/admin` assets from those release archives, without running `npm build`/`go build` again. + +The release path keeps Docker images aligned with release archives and reduces duplicate build work. Container entry command: `/usr/local/bin/ds2api`, default exposed port: `5001`. @@ -145,7 +161,7 @@ Docker Compose includes a built-in health check: ```yaml healthcheck: - test: ["CMD", "wget", "-qO-", "http://localhost:${PORT:-5001}/healthz"] + test: ["CMD", "/usr/local/bin/busybox", "wget", "-qO-", "http://localhost:${PORT:-5001}/healthz"] interval: 30s timeout: 10s retries: 3 @@ -167,15 +183,49 @@ If container logs look normal but the admin panel is unreachable, check these fi 1. **Fork** the repo to your GitHub account 2. **Import** the project on Vercel -3. **Set environment variables** (at minimum): +3. 
**Set environment variables** (minimum required: one variable): | Variable | Description | | --- | --- | | `DS2API_ADMIN_KEY` | Admin key (required) | - | `DS2API_CONFIG_JSON` | Config content, raw JSON or Base64 (required) | + | `DS2API_CONFIG_JSON` | Config content, raw JSON or Base64 (optional, recommended) | 4. **Deploy** +### 3.1.1 Recommended Input (avoid `DS2API_CONFIG_JSON` mistakes) + +If you prefer faster one-click bootstrap, you can leave `DS2API_CONFIG_JSON` empty first, then open `/admin` after deployment, import config, and sync it back to Vercel env vars from the "Vercel Sync" page. + +Recommended: in repo root, copy the template first and fill your real accounts: + +```bash +cp config.example.json config.json +# Edit config.json +``` + +Do not hand-edit large JSON directly in Vercel. Generate Base64 locally and paste it: + +```bash +# Run in repo root +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" +echo "$DS2API_CONFIG_JSON" +``` + +If you choose to preconfigure before first deploy, set these vars in Vercel Project Settings -> Environment Variables: + +```text +DS2API_ADMIN_KEY=replace-with-a-strong-secret +DS2API_CONFIG_JSON= +``` + +Optional but recommended (for WebUI one-click Vercel sync): + +```text +VERCEL_TOKEN=your-vercel-token +VERCEL_PROJECT_ID=prj_xxxxxxxxxxxx +VERCEL_TEAM_ID=team_xxxxxxxxxxxx # optional for personal accounts +``` + ### 3.2 Optional Environment Variables | Variable | Description | Default | @@ -184,6 +234,8 @@ If container logs look normal but the admin panel is unreachable, check these fi | `DS2API_ACCOUNT_CONCURRENCY` | Alias (legacy compat) | — | | `DS2API_ACCOUNT_MAX_QUEUE` | Waiting queue limit | `recommended_concurrency` | | `DS2API_ACCOUNT_QUEUE_SIZE` | Alias (legacy compat) | — | +| `DS2API_GLOBAL_MAX_INFLIGHT` | Global inflight limit | `recommended_concurrency` | +| `DS2API_MAX_INFLIGHT` | Alias (legacy compat) | — | | `DS2API_VERCEL_INTERNAL_SECRET` | Hybrid streaming internal auth | Falls back to 
`DS2API_ADMIN_KEY` | | `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | Stream lease TTL | `900` | | `VERCEL_TOKEN` | Vercel sync token | — | @@ -290,6 +342,7 @@ Built-in GitHub Actions workflow: `.github/workflows/release-artifacts.yml` - **Trigger**: only on Release `published` (no build on normal push) - **Outputs**: multi-platform binary archives + `sha256sums.txt` +- **Container publishing**: GHCR only (`ghcr.io/cjackhwang/ds2api`) | Platform | Architecture | Format | | --- | --- | --- | @@ -310,8 +363,8 @@ Each archive includes: ```bash # 1. Download the archive for your platform # 2. Extract -tar -xzf ds2api_v1.7.0_linux_amd64.tar.gz -cd ds2api_v1.7.0_linux_amd64 +tar -xzf ds2api__linux_amd64.tar.gz +cd ds2api__linux_amd64 # 3. Configure cp config.example.json config.json @@ -323,10 +376,20 @@ cp config.example.json config.json ### Maintainer Release Flow -1. Create and publish a GitHub Release (with tag, e.g. `v1.7.0`) +1. Create and publish a GitHub Release (with tag, for example `vX.Y.Z`) 2. Wait for the `Release Artifacts` workflow to complete 3. Download the matching archive from Release Assets +### Pull from GHCR (Optional) + +```bash +# latest +docker pull ghcr.io/cjackhwang/ds2api:latest + +# specific version (example) +docker pull ghcr.io/cjackhwang/ds2api:v2.1.2 +``` + --- ## 5. 
Reverse Proxy (Nginx) @@ -469,7 +532,7 @@ curl http://127.0.0.1:5001/v1/chat/completions \ Run the full live testsuite before release (real account tests): ```bash -./scripts/testsuite/run-live.sh +./tests/scripts/run-live.sh ``` With custom flags: diff --git a/DEPLOY.md b/DEPLOY.md index b7fbf9a..d7d74d3 100644 --- a/DEPLOY.md +++ b/DEPLOY.md @@ -33,6 +33,17 @@ - **文件方式**:`config.json`(推荐本地/Docker 使用) - **环境变量方式**:`DS2API_CONFIG_JSON`(推荐 Vercel 使用,支持 JSON 字符串或 Base64 编码) +统一建议(最优实践): + +```bash +cp config.example.json config.json +# 编辑 config.json +``` + +建议把 `config.json` 作为唯一配置源: +- 本地运行:直接读 `config.json` +- Docker / Vercel:从 `config.json` 生成 `DS2API_CONFIG_JSON`(Base64)注入环境变量 + --- ## 一、本地运行 @@ -99,11 +110,15 @@ go build -o ds2api ./cmd/ds2api ### 2.1 基本步骤 ```bash -# 复制并编辑环境变量 +# 复制环境变量模板 cp .env.example .env -# 编辑 .env,至少设置: + +# 从 config.json 生成单行 Base64 +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" + +# 编辑 .env(请改成你的强密码),设置: # DS2API_ADMIN_KEY=your-admin-key -# DS2API_CONFIG_JSON={"keys":[...],"accounts":[...]} +# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON} # 启动 docker-compose up -d @@ -120,11 +135,12 @@ docker-compose up -d --build ### 2.3 Docker 架构说明 -`Dockerfile` 使用三阶段构建: +`Dockerfile` 提供两条构建路径: -1. **WebUI 构建阶段**:`node:20` 镜像,执行 `npm ci && npm run build` -2. **Go 构建阶段**:`golang:1.24` 镜像,编译二进制文件 -3. **运行阶段**:`debian:bookworm-slim` 精简镜像 +1. **本地/开发默认路径(`runtime-from-source`)**:三阶段构建(WebUI 构建 + Go 构建 + 运行阶段)。 +2. 
**Release 路径(`runtime-from-dist`)**:CI 先生成 `dist/ds2api__linux_.tar.gz`,再由 Docker 直接复用该发布包内的二进制和 `static/admin` 产物组装运行镜像,不再重复执行 `npm build`/`go build`。 + +Release 路径可确保 Docker 镜像与 release 压缩包使用同一套产物,减少重复构建带来的差异。 容器内启动命令:`/usr/local/bin/ds2api`,默认暴露端口 `5001`。 @@ -145,7 +161,7 @@ Docker Compose 已配置内置健康检查: ```yaml healthcheck: - test: ["CMD", "wget", "-qO-", "http://localhost:${PORT:-5001}/healthz"] + test: ["CMD", "/usr/local/bin/busybox", "wget", "-qO-", "http://localhost:${PORT:-5001}/healthz"] interval: 30s timeout: 10s retries: 3 @@ -167,15 +183,49 @@ healthcheck: 1. **Fork 仓库**到你的 GitHub 账号 2. **在 Vercel 上导入项目** -3. **配置环境变量**(至少设置以下两项): +3. **配置环境变量**(最少只需设置以下一项): | 变量 | 说明 | | --- | --- | | `DS2API_ADMIN_KEY` | 管理密钥(必填) | - | `DS2API_CONFIG_JSON` | 配置内容,JSON 字符串或 Base64 编码(必填) | + | `DS2API_CONFIG_JSON` | 配置内容,JSON 字符串或 Base64 编码(可选,建议) | 4. **部署** +### 3.1.1 推荐填写方式(避免 `DS2API_CONFIG_JSON` 填错) + +如果你想先完成一键部署,也可以先不填 `DS2API_CONFIG_JSON`,部署后进入 `/admin` 导入配置,再在「Vercel 同步」里写回环境变量。 + +建议先在仓库目录复制示例配置,再按实际账号填写: + +```bash +cp config.example.json config.json +# 编辑 config.json +``` + +不要在 Vercel 面板里手写复杂 JSON,建议本地生成 Base64 后粘贴: + +```bash +# 在仓库根目录执行 +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" +echo "$DS2API_CONFIG_JSON" +``` + +如果你选择在部署前就预置配置,请在 Vercel Project Settings -> Environment Variables 配置: + +```text +DS2API_ADMIN_KEY=请替换为强密码 +DS2API_CONFIG_JSON=上一步生成的一整行 Base64 +``` + +可选但推荐(用于 WebUI 一键同步 Vercel 配置): + +```text +VERCEL_TOKEN=你的 Vercel Token +VERCEL_PROJECT_ID=prj_xxxxxxxxxxxx +VERCEL_TEAM_ID=team_xxxxxxxxxxxx # 个人账号可留空 +``` + ### 3.2 可选环境变量 | 变量 | 说明 | 默认值 | @@ -184,6 +234,8 @@ healthcheck: | `DS2API_ACCOUNT_CONCURRENCY` | 同上(兼容别名) | — | | `DS2API_ACCOUNT_MAX_QUEUE` | 等待队列上限 | `recommended_concurrency` | | `DS2API_ACCOUNT_QUEUE_SIZE` | 同上(兼容别名) | — | +| `DS2API_GLOBAL_MAX_INFLIGHT` | 全局并发上限 | `recommended_concurrency` | +| `DS2API_MAX_INFLIGHT` | 同上(兼容别名) | — | | `DS2API_VERCEL_INTERNAL_SECRET` | 混合流式内部鉴权 | 回退用 `DS2API_ADMIN_KEY` | | 
`DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | 流式 lease TTL | `900` | | `VERCEL_TOKEN` | Vercel 同步 token | — | @@ -290,6 +342,7 @@ No Output Directory named "public" found after the Build completed. - **触发条件**:仅在 Release `published` 时触发(普通 push 不会构建) - **构建产物**:多平台二进制压缩包 + `sha256sums.txt` +- **容器镜像发布**:仅发布到 GHCR(`ghcr.io/cjackhwang/ds2api`) | 平台 | 架构 | 文件格式 | | --- | --- | --- | @@ -310,8 +363,8 @@ No Output Directory named "public" found after the Build completed. ```bash # 1. 下载对应平台的压缩包 # 2. 解压 -tar -xzf ds2api_v1.7.0_linux_amd64.tar.gz -cd ds2api_v1.7.0_linux_amd64 +tar -xzf ds2api__linux_amd64.tar.gz +cd ds2api__linux_amd64 # 3. 配置 cp config.example.json config.json @@ -323,10 +376,20 @@ cp config.example.json config.json ### 维护者发布步骤 -1. 在 GitHub 创建并发布 Release(带 tag,如 `v1.7.0`) +1. 在 GitHub 创建并发布 Release(带 tag,如 `vX.Y.Z`) 2. 等待 Actions 工作流 `Release Artifacts` 完成 3. 在 Release 的 Assets 下载对应平台压缩包 +### 拉取 GHCR 镜像(可选) + +```bash +# latest +docker pull ghcr.io/cjackhwang/ds2api:latest + +# 指定版本(示例) +docker pull ghcr.io/cjackhwang/ds2api:v2.1.2 +``` + --- ## 五、反向代理(Nginx) @@ -469,7 +532,7 @@ curl http://127.0.0.1:5001/v1/chat/completions \ 建议在发布前执行完整的端到端测试集(使用真实账号): ```bash -./scripts/testsuite/run-live.sh +./tests/scripts/run-live.sh ``` 可自定义参数: diff --git a/Dockerfile b/Dockerfile index a67dfd1..c86f82d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,12 +15,44 @@ RUN go mod download COPY . . 
RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -o /out/ds2api ./cmd/ds2api -FROM debian:bookworm-slim +FROM busybox:1.36.1-musl AS busybox-tools + +FROM debian:bookworm-slim AS runtime-base WORKDIR /app -RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates wget && rm -rf /var/lib/apt/lists/* +COPY --from=go-builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt +COPY --from=busybox-tools /bin/busybox /usr/local/bin/busybox +EXPOSE 5001 +CMD ["/usr/local/bin/ds2api"] + +FROM runtime-base AS runtime-from-source COPY --from=go-builder /out/ds2api /usr/local/bin/ds2api COPY --from=go-builder /app/sha3_wasm_bg.7b9ca65ddd.wasm /app/sha3_wasm_bg.7b9ca65ddd.wasm COPY --from=go-builder /app/config.example.json /app/config.example.json COPY --from=webui-builder /app/static/admin /app/static/admin -EXPOSE 5001 -CMD ["/usr/local/bin/ds2api"] + +FROM busybox-tools AS dist-extract +ARG TARGETARCH +COPY dist/docker-input/linux_amd64.tar.gz /tmp/ds2api_linux_amd64.tar.gz +COPY dist/docker-input/linux_arm64.tar.gz /tmp/ds2api_linux_arm64.tar.gz +RUN set -eux; \ + case "${TARGETARCH}" in \ + amd64) ARCHIVE="/tmp/ds2api_linux_amd64.tar.gz" ;; \ + arm64) ARCHIVE="/tmp/ds2api_linux_arm64.tar.gz" ;; \ + *) echo "unsupported TARGETARCH: ${TARGETARCH}" >&2; exit 1 ;; \ + esac; \ + tar -xzf "${ARCHIVE}" -C /tmp; \ + PKG_DIR="$(find /tmp -maxdepth 1 -type d -name "ds2api_*_linux_${TARGETARCH}" | head -n1)"; \ + test -n "${PKG_DIR}"; \ + mkdir -p /out/static; \ + cp "${PKG_DIR}/ds2api" /out/ds2api; \ + cp "${PKG_DIR}/sha3_wasm_bg.7b9ca65ddd.wasm" /out/sha3_wasm_bg.7b9ca65ddd.wasm; \ + cp "${PKG_DIR}/config.example.json" /out/config.example.json; \ + cp -R "${PKG_DIR}/static/admin" /out/static/admin + +FROM runtime-base AS runtime-from-dist +COPY --from=dist-extract /out/ds2api /usr/local/bin/ds2api +COPY --from=dist-extract /out/sha3_wasm_bg.7b9ca65ddd.wasm /app/sha3_wasm_bg.7b9ca65ddd.wasm +COPY --from=dist-extract 
/out/config.example.json /app/config.example.json +COPY --from=dist-extract /out/static/admin /app/static/admin + +FROM runtime-from-source AS final diff --git a/README.MD b/README.MD index b438b75..b8c3be0 100644 --- a/README.MD +++ b/README.MD @@ -3,18 +3,18 @@ [![License](https://img.shields.io/github/license/CJackHwang/ds2api.svg)](LICENSE) ![Stars](https://img.shields.io/github/stars/CJackHwang/ds2api.svg) ![Forks](https://img.shields.io/github/forks/CJackHwang/ds2api.svg) -[![Version](https://img.shields.io/badge/version-1.6.11-blue.svg)](version.txt) +[![Release](https://img.shields.io/github/v/release/CJackHwang/ds2api?display_name=tag)](https://github.com/CJackHwang/ds2api/releases) [![Docker](https://img.shields.io/badge/docker-ready-blue.svg)](DEPLOY.md) 语言 / Language: [中文](README.MD) | [English](README.en.md) -将 DeepSeek Web 对话能力转换为 OpenAI 与 Claude 兼容 API。后端为 **Go 全量实现**,前端为 React WebUI 管理台(源码在 `webui/`,部署时自动构建到 `static/admin`)。 +将 DeepSeek Web 对话能力转换为 OpenAI、Claude 与 Gemini 兼容 API。后端为 **Go 全量实现**,前端为 React WebUI 管理台(源码在 `webui/`,部署时自动构建到 `static/admin`)。 ## 架构概览 ```mermaid flowchart LR - Client["🖥️ 客户端\n(OpenAI / Claude 兼容)"] + Client["🖥️ 客户端\n(OpenAI / Claude / Gemini 兼容)"] subgraph DS2API["DS2API 服务"] direction TB @@ -24,6 +24,7 @@ flowchart LR subgraph Adapters["适配器层"] OA["OpenAI 适配器\n/v1/*"] CA["Claude 适配器\n/anthropic/*"] + GA["Gemini 适配器\n/v1beta/models/*"] end subgraph Support["支撑模块"] @@ -38,11 +39,11 @@ flowchart LR DS["☁️ DeepSeek API"] Client -- "请求" --> CORS --> Auth - Auth --> OA & CA - OA & CA -- "调用" --> DS + Auth --> OA & CA & GA + OA & CA & GA -- "调用" --> DS Auth --> Admin - OA & CA -. "轮询选账号" .-> Pool - OA & CA -. "计算 PoW" .-> PoW + OA & CA & GA -. "轮询选账号" .-> Pool + OA & CA & GA -. 
"计算 PoW" .-> PoW DS -- "响应" --> Client ``` @@ -54,16 +55,29 @@ flowchart LR | 能力 | 说明 | | --- | --- | -| OpenAI 兼容 | `GET /v1/models`、`POST /v1/chat/completions`(流式/非流式) | -| Claude 兼容 | `GET /anthropic/v1/models`、`POST /anthropic/v1/messages`、`POST /anthropic/v1/messages/count_tokens` | +| OpenAI 兼容 | `GET /v1/models`、`GET /v1/models/{id}`、`POST /v1/chat/completions`、`POST /v1/responses`、`GET /v1/responses/{response_id}`、`POST /v1/embeddings` | +| Claude 兼容 | `GET /anthropic/v1/models`、`POST /anthropic/v1/messages`、`POST /anthropic/v1/messages/count_tokens`(及快捷路径 `/v1/messages`、`/messages`) | +| Gemini 兼容 | `POST /v1beta/models/{model}:generateContent`、`POST /v1beta/models/{model}:streamGenerateContent`(及 `/v1/models/{model}:*` 路径) | | 多账号轮询 | 自动 token 刷新、邮箱/手机号双登录方式 | | 并发队列控制 | 每账号 in-flight 上限 + 等待队列,动态计算建议并发值 | | DeepSeek PoW | WASM 计算(`wazero`),无需外部 Node.js 依赖 | -| Tool Calling | 防泄漏处理:自动缓冲、识别、结构化输出 | -| Admin API | 配置管理、账号测试 / 批量测试、导入导出、Vercel 同步 | +| Tool Calling | 防泄漏处理:非代码块高置信特征识别、`delta.tool_calls` 早发、结构化增量输出 | +| Admin API | 配置管理、运行时设置热更新、账号测试 / 批量测试、导入导出、Vercel 同步 | | WebUI 管理台 | `/admin` 单页应用(中英文双语、深色模式) | | 运维探针 | `GET /healthz`(存活)、`GET /readyz`(就绪) | +## 平台兼容矩阵 + +| 级别 | 平台 | 当前状态 | +| --- | --- | --- | +| P0 | Codex CLI/SDK(`wire_api=chat` / `wire_api=responses`) | ✅ | +| P0 | OpenAI SDK(JS/Python,chat + responses) | ✅ | +| P0 | Vercel AI SDK(openai-compatible) | ✅ | +| P0 | Anthropic SDK(messages) | ✅ | +| P0 | Google Gemini SDK(generateContent) | ✅ | +| P1 | LangChain / LlamaIndex / OpenWebUI(OpenAI 兼容接入) | ✅ | +| P2 | MCP 独立桥接层 | 规划中 | + ## 模型支持 ### OpenAI 接口 @@ -86,8 +100,25 @@ flowchart LR 可通过配置中的 `claude_mapping` 或 `claude_model_mapping` 覆盖映射关系。 另外,`/anthropic/v1/models` 现已包含 Claude 1.x/2.x/3.x/4.x 历史模型 ID 与常见别名,便于旧客户端直接兼容。 +### Gemini 接口 + +Gemini 适配器将模型名通过 `model_aliases` 或内置规则映射到 DeepSeek 原生模型,支持 `generateContent` 和 `streamGenerateContent` 两种调用方式,并完整支持 Tool Calling(`functionDeclarations` → `functionCall` 输出)。 + ## 快速开始 +### 通用第一步(所有部署方式) 
+ +把 `config.json` 作为唯一配置源(推荐做法): + +```bash +cp config.example.json config.json +# 编辑 config.json +``` + +后续部署建议: +- 本地运行:直接读取 `config.json` +- Docker / Vercel:由 `config.json` 生成 `DS2API_CONFIG_JSON`(Base64)注入环境变量 + ### 方式一:本地运行 **前置要求**:Go 1.24+,Node.js 20+(仅在需要构建 WebUI 时) @@ -112,14 +143,20 @@ go run ./cmd/ds2api ### 方式二:Docker 运行 ```bash -# 1. 配置环境变量 +# 1. 准备环境变量文件 cp .env.example .env -# 编辑 .env -# 2. 启动 +# 2. 从 config.json 生成 DS2API_CONFIG_JSON(单行 Base64) +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" + +# 3. 编辑 .env,设置: +# DS2API_ADMIN_KEY=请替换为强密码 +# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON} + +# 4. 启动 docker-compose up -d -# 3. 查看日志 +# 5. 查看日志 docker-compose logs -f ``` @@ -129,9 +166,22 @@ docker-compose logs -f 1. Fork 仓库到自己的 GitHub 2. 在 Vercel 上导入项目 -3. 配置环境变量(至少设置 `DS2API_ADMIN_KEY` 和 `DS2API_CONFIG_JSON`) +3. 配置环境变量(最少设置 `DS2API_ADMIN_KEY`;推荐同时设置 `DS2API_CONFIG_JSON`) 4. 部署 +建议先在仓库目录复制模板并填写: + +```bash +cp config.example.json config.json +# 编辑 config.json +``` + +推荐:先本地把 `config.json` 转成 Base64,再粘贴到 `DS2API_CONFIG_JSON`,避免 JSON 格式错误: + +```bash +base64 < config.json | tr -d '\n' +``` + > **流式说明**:`/v1/chat/completions` 在 Vercel 上默认走 `api/chat-stream.js`(Node Runtime)以保证实时 SSE。鉴权、账号选择、会话/PoW 准备仍由 Go 内部 prepare 接口完成;流式响应(含 `tools`)在 Node 侧执行与 Go 对齐的输出组装与防泄漏处理。 详细部署说明请参阅 [部署指南](DEPLOY.md)。 @@ -142,8 +192,8 @@ docker-compose logs -f ```bash # 下载对应平台的压缩包后 -tar -xzf ds2api_v1.7.0_linux_amd64.tar.gz -cd ds2api_v1.7.0_linux_amd64 +tar -xzf ds2api__linux_amd64.tar.gz +cd ds2api__linux_amd64 cp config.example.json config.json # 编辑 config.json ./ds2api @@ -164,6 +214,7 @@ cp opencode.json.example opencode.json 3. 
在项目目录启动 OpenCode CLI(按你的安装方式运行 `opencode`)。 > 建议优先使用 OpenAI 兼容路径(`/v1/*`),即示例里的 `@ai-sdk/openai-compatible` provider。 +> 若客户端支持 `wire_api`,可分别测试 `responses` 与 `chat`,DS2API 两条链路都兼容。 ## 配置说明 @@ -184,9 +235,35 @@ cp opencode.json.example opencode.json "token": "" } ], + "model_aliases": { + "gpt-4o": "deepseek-chat", + "gpt-5-codex": "deepseek-reasoner", + "o3": "deepseek-reasoner" + }, + "compat": { + "wide_input_strict_output": true + }, + "toolcall": { + "mode": "feature_match", + "early_emit_confidence": "high" + }, + "responses": { + "store_ttl_seconds": 900 + }, + "embeddings": { + "provider": "deterministic" + }, "claude_model_mapping": { "fast": "deepseek-chat", "slow": "deepseek-reasoner" + }, + "admin": { + "jwt_expire_hours": 24 + }, + "runtime": { + "account_max_inflight": 2, + "account_max_queue": 0, + "global_max_inflight": 0 } } ``` @@ -194,7 +271,14 @@ cp opencode.json.example opencode.json - `keys`:API 访问密钥列表,客户端通过 `Authorization: Bearer ` 鉴权 - `accounts`:DeepSeek 账号列表,支持 `email` 或 `mobile` 登录 - `token`:留空则首次请求时自动登录获取;也可预填已有 token +- `model_aliases`:常见模型名(如 GPT/Codex/Claude)到 DeepSeek 模型的映射 +- `compat.wide_input_strict_output`:建议保持 `true`(当前实现默认宽进严出) +- `toolcall`:固定采用特征匹配 + 高置信早发策略 +- `responses.store_ttl_seconds`:`/v1/responses/{id}` 的内存缓存 TTL +- `embeddings.provider`:embedding 提供方(当前内置 `deterministic/mock/builtin`) - `claude_model_mapping`:字典中 `fast`/`slow` 后缀映射到对应 DeepSeek 模型 +- `admin`:管理后台设置(JWT 过期时间、密码哈希等),可通过 Admin Settings API 热更新 +- `runtime`:运行时参数(并发限制、队列大小),可通过 Admin Settings API 热更新 ### 环境变量 @@ -214,8 +298,13 @@ cp opencode.json.example opencode.json | `DS2API_ACCOUNT_CONCURRENCY` | 同上(兼容旧名) | — | | `DS2API_ACCOUNT_MAX_QUEUE` | 等待队列上限 | `recommended_concurrency` | | `DS2API_ACCOUNT_QUEUE_SIZE` | 同上(兼容旧名) | — | +| `DS2API_GLOBAL_MAX_INFLIGHT` | 全局最大 in-flight 请求数 | `recommended_concurrency` | +| `DS2API_MAX_INFLIGHT` | 同上(兼容旧名) | — | | `DS2API_VERCEL_INTERNAL_SECRET` | Vercel 混合流式内部鉴权密钥 | 回退用 `DS2API_ADMIN_KEY` | | 
`DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | 流式 lease 过期秒数 | `900` | +| `DS2API_DEV_PACKET_CAPTURE` | 本地开发抓包开关(记录最近会话请求/响应体) | 本地非 Vercel 默认开启 | +| `DS2API_DEV_PACKET_CAPTURE_LIMIT` | 本地抓包保留条数(超出自动淘汰) | `5` | +| `DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES` | 单条响应体最大记录字节数 | `2097152` | | `VERCEL_TOKEN` | Vercel 同步 token | — | | `VERCEL_PROJECT_ID` | Vercel 项目 ID | — | | `VERCEL_TEAM_ID` | Vercel 团队 ID | — | @@ -223,7 +312,7 @@ cp opencode.json.example opencode.json ## 鉴权模式 -调用业务接口(`/v1/*`、`/anthropic/*`)时支持两种模式: +调用业务接口(`/v1/*`、`/anthropic/*`、Gemini 路由)时支持两种模式: | 模式 | 说明 | | --- | --- | @@ -249,10 +338,34 @@ cp opencode.json.example opencode.json 当请求中带 `tools` 时,DS2API 会做防泄漏处理: -1. `stream=true` 时先**缓冲**正文片段 -2. 若识别到工具调用 → 仅输出结构化 `tool_calls`,不透传原始 JSON 文本 -3. 若最终不是工具调用 → 一次性输出普通文本 -4. 解析器支持混合文本、fenced JSON、`function.arguments` 字符串等格式 +1. 只在**非代码块上下文**启用 toolcall 特征识别(代码块示例不会触发) +2. `responses` 流式严格使用官方 item 生命周期事件(`response.output_item.*`、`response.content_part.*`、`response.function_call_arguments.*`) +3. 未在 `tools` 声明中的工具名会被严格拒绝,不会下发为有效 tool call +4. `responses` 支持并执行 `tool_choice`(`auto`/`none`/`required`/强制函数);`required` 违规时非流式返回 `422`,流式返回 `response.failed` +5. 
仅在通过策略校验后才会发出有效工具调用事件,避免错误工具名进入客户端执行链 + +## 本地开发抓包工具 + +用于定位「responses 思考流/工具调用」等问题。开启后会自动记录最近 N 条 DeepSeek 对话上游请求体与响应体(默认 5 条,超出自动淘汰)。 + +启用示例: + +```bash +DS2API_DEV_PACKET_CAPTURE=true \ +DS2API_DEV_PACKET_CAPTURE_LIMIT=5 \ +go run ./cmd/ds2api +``` + +查询/清空(需 Admin JWT): + +- `GET /admin/dev/captures`:查看抓包列表(最新在前) +- `DELETE /admin/dev/captures`:清空抓包 + +返回字段包含: + +- `request_body`:发送给 DeepSeek 的完整请求体 +- `response_body`:上游返回的原始流式内容拼接文本 +- `response_truncated`:是否触发单条大小截断 ## 项目结构 @@ -269,13 +382,20 @@ ds2api/ │ ├── account/ # 账号池与并发队列 │ ├── adapter/ │ │ ├── openai/ # OpenAI 兼容适配器(含 Tool Call 解析、Vercel 流式 prepare/release) -│ │ └── claude/ # Claude 兼容适配器 -│ ├── admin/ # Admin API handlers +│ │ ├── claude/ # Claude 兼容适配器 +│ │ └── gemini/ # Gemini 兼容适配器(generateContent / streamGenerateContent) +│ ├── admin/ # Admin API handlers(含 Settings 热更新) │ ├── auth/ # 鉴权与 JWT +│ ├── claudeconv/ # Claude 消息格式转换 +│ ├── compat/ # 兼容性辅助 │ ├── config/ # 配置加载与热更新 │ ├── deepseek/ # DeepSeek API 客户端、PoW WASM +│ ├── devcapture/ # 开发抓包模块 +│ ├── format/ # 输出格式化 +│ ├── prompt/ # Prompt 构建 │ ├── server/ # HTTP 路由与中间件(chi router) │ ├── sse/ # SSE 解析工具 +│ ├── stream/ # 统一流式消费引擎 │ ├── util/ # 通用工具函数 │ └── webui/ # WebUI 静态文件托管与自动构建 ├── webui/ # React WebUI 源码(Vite + Tailwind) @@ -283,11 +403,13 @@ ds2api/ │ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage │ └── locales/ # 中英文语言包(zh.json / en.json) ├── scripts/ -│ ├── build-webui.sh # WebUI 手动构建脚本 -│ └── testsuite/ # 测试集运行脚本 +│ └── build-webui.sh # WebUI 手动构建脚本 +├── tests/ +│ ├── compat/ # 兼容性测试夹具与期望输出 +│ └── scripts/ # 统一测试脚本入口(unit/e2e) ├── static/admin/ # WebUI 构建产物(不提交到 Git) ├── .github/ -│ ├── workflows/ # GitHub Actions(Release 自动构建) +│ ├── workflows/ # GitHub Actions(质量门禁 + Release 自动构建) │ ├── ISSUE_TEMPLATE/ # Issue 模板 │ └── PULL_REQUEST_TEMPLATE.md ├── config.example.json # 配置文件示例 @@ -296,8 +418,7 @@ ds2api/ ├── docker-compose.yml # 生产环境 Docker Compose ├── docker-compose.dev.yml # 开发环境 Docker 
Compose ├── vercel.json # Vercel 路由与构建配置 -├── go.mod / go.sum # Go 模块依赖 -└── version.txt # 版本号 +└── go.mod / go.sum # Go 模块依赖 ``` ## 文档索引 @@ -312,11 +433,11 @@ ds2api/ ## 测试 ```bash -# 单元测试 -go test ./... +# 单元测试(Go + Node) +./tests/scripts/run-unit-all.sh # 一键端到端全链路测试(真实账号,生成完整请求/响应日志) -./scripts/testsuite/run-live.sh +./tests/scripts/run-live.sh # 或自定义参数 go run ./cmd/ds2api-tests \ @@ -327,12 +448,21 @@ go run ./cmd/ds2api-tests \ --retries 2 ``` +```bash +# 发布前阻断门禁 +./tests/scripts/check-stage6-manual-smoke.sh +./tests/scripts/check-refactor-line-gate.sh +./tests/scripts/run-unit-all.sh +npm ci --prefix webui && npm run build --prefix webui +``` + ## Release 自动构建(GitHub Actions) 工作流文件:`.github/workflows/release-artifacts.yml` - **触发条件**:仅在 GitHub Release `published` 时触发(普通 push 不会触发) - **构建产物**:多平台二进制包(`linux/amd64`、`linux/arm64`、`darwin/amd64`、`darwin/arm64`、`windows/amd64`)+ `sha256sums.txt` +- **容器镜像发布**:仅推送到 GHCR(`ghcr.io/cjackhwang/ds2api`) - **每个压缩包包含**:`ds2api` 可执行文件、`static/admin`、WASM 文件、配置示例、README、LICENSE ## 免责声明 diff --git a/README.en.md b/README.en.md index bbad73b..4e872ad 100644 --- a/README.en.md +++ b/README.en.md @@ -3,18 +3,18 @@ [![License](https://img.shields.io/github/license/CJackHwang/ds2api.svg)](LICENSE) ![Stars](https://img.shields.io/github/stars/CJackHwang/ds2api.svg) ![Forks](https://img.shields.io/github/forks/CJackHwang/ds2api.svg) -[![Version](https://img.shields.io/badge/version-1.6.11-blue.svg)](version.txt) +[![Release](https://img.shields.io/github/v/release/CJackHwang/ds2api?display_name=tag)](https://github.com/CJackHwang/ds2api/releases) [![Docker](https://img.shields.io/badge/docker-ready-blue.svg)](DEPLOY.en.md) Language: [中文](README.MD) | [English](README.en.md) -DS2API converts DeepSeek Web chat capability into OpenAI-compatible and Claude-compatible APIs. The backend is a **pure Go implementation**, with a React WebUI admin panel (source in `webui/`, build output auto-generated to `static/admin` during deployment). 
+DS2API converts DeepSeek Web chat capability into OpenAI-compatible, Claude-compatible, and Gemini-compatible APIs. The backend is a **pure Go implementation**, with a React WebUI admin panel (source in `webui/`, build output auto-generated to `static/admin` during deployment). ## Architecture Overview ```mermaid flowchart LR - Client["🖥️ Clients\n(OpenAI / Claude compat)"] + Client["🖥️ Clients\n(OpenAI / Claude / Gemini compat)"] subgraph DS2API["DS2API Service"] direction TB @@ -24,6 +24,7 @@ flowchart LR subgraph Adapters["Adapter Layer"] OA["OpenAI Adapter\n/v1/*"] CA["Claude Adapter\n/anthropic/*"] + GA["Gemini Adapter\n/v1beta/models/*"] end subgraph Support["Support Modules"] @@ -38,11 +39,11 @@ flowchart LR DS["☁️ DeepSeek API"] Client -- "Request" --> CORS --> Auth - Auth --> OA & CA - OA & CA -- "Call" --> DS + Auth --> OA & CA & GA + OA & CA & GA -- "Call" --> DS Auth --> Admin - OA & CA -. "Rotate accounts" .-> Pool - OA & CA -. "Compute PoW" .-> PoW + OA & CA & GA -. "Rotate accounts" .-> Pool + OA & CA & GA -. 
"Compute PoW" .-> PoW DS -- "Response" --> Client ``` @@ -54,16 +55,29 @@ flowchart LR | Capability | Details | | --- | --- | -| OpenAI compatible | `GET /v1/models`, `POST /v1/chat/completions` (stream/non-stream) | -| Claude compatible | `GET /anthropic/v1/models`, `POST /anthropic/v1/messages`, `POST /anthropic/v1/messages/count_tokens` | +| OpenAI compatible | `GET /v1/models`, `GET /v1/models/{id}`, `POST /v1/chat/completions`, `POST /v1/responses`, `GET /v1/responses/{response_id}`, `POST /v1/embeddings` | +| Claude compatible | `GET /anthropic/v1/models`, `POST /anthropic/v1/messages`, `POST /anthropic/v1/messages/count_tokens` (plus shortcut paths `/v1/messages`, `/messages`) | +| Gemini compatible | `POST /v1beta/models/{model}:generateContent`, `POST /v1beta/models/{model}:streamGenerateContent` (plus `/v1/models/{model}:*` paths) | | Multi-account rotation | Auto token refresh, email/mobile dual login | | Concurrency control | Per-account in-flight limit + waiting queue, dynamic recommended concurrency | | DeepSeek PoW | WASM solving via `wazero`, no external Node.js dependency | -| Tool Calling | Anti-leak handling: auto buffer, detect, structured output | -| Admin API | Config management, account testing/batch test, import/export, Vercel sync | +| Tool Calling | Anti-leak handling: non-code-block feature match, early `delta.tool_calls`, structured incremental output | +| Admin API | Config management, runtime settings hot-reload, account testing/batch test, import/export, Vercel sync | | WebUI Admin Panel | SPA at `/admin` (bilingual Chinese/English, dark mode) | | Health Probes | `GET /healthz` (liveness), `GET /readyz` (readiness) | +## Platform Compatibility Matrix + +| Tier | Platform | Status | +| --- | --- | --- | +| P0 | Codex CLI/SDK (`wire_api=chat` / `wire_api=responses`) | ✅ | +| P0 | OpenAI SDK (JS/Python, chat + responses) | ✅ | +| P0 | Vercel AI SDK (openai-compatible) | ✅ | +| P0 | Anthropic SDK (messages) | ✅ | +| P0 | Google Gemini SDK 
(generateContent) | ✅ | +| P1 | LangChain / LlamaIndex / OpenWebUI (OpenAI-compatible integration) | ✅ | +| P2 | MCP standalone bridge | Planned | + ## Model Support ### OpenAI Endpoint @@ -86,8 +100,25 @@ flowchart LR Override mapping via `claude_mapping` or `claude_model_mapping` in config. In addition, `/anthropic/v1/models` now includes historical Claude 1.x/2.x/3.x/4.x IDs and common aliases for legacy client compatibility. +### Gemini Endpoint + +The Gemini adapter maps model names to DeepSeek native models via `model_aliases` or built-in heuristics, supporting both `generateContent` and `streamGenerateContent` call patterns with full Tool Calling support (`functionDeclarations` → `functionCall` output). + ## Quick Start +### Universal First Step (all deployment modes) + +Use `config.json` as the single source of truth (recommended): + +```bash +cp config.example.json config.json +# Edit config.json +``` + +Recommended per deployment mode: +- Local run: read `config.json` directly +- Docker / Vercel: generate Base64 from `config.json` and inject as `DS2API_CONFIG_JSON` + ### Option 1: Local Run **Prerequisites**: Go 1.24+, Node.js 20+ (only if building WebUI locally) @@ -112,14 +143,20 @@ Default URL: `http://localhost:5001` ### Option 2: Docker ```bash -# 1. Configure environment +# 1. Prepare env file cp .env.example .env -# Edit .env -# 2. Start +# 2. Generate DS2API_CONFIG_JSON from config.json (single-line Base64) +DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')" + +# 3. Edit .env and set: +# DS2API_ADMIN_KEY=replace-with-a-strong-secret +# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON} + +# 4. Start docker-compose up -d -# 3. View logs +# 5. View logs docker-compose logs -f ``` @@ -129,9 +166,22 @@ Rebuild after updates: `docker-compose up -d --build` 1. Fork this repo to your GitHub account 2. Import the project on Vercel -3. Set environment variables (minimum: `DS2API_ADMIN_KEY` and `DS2API_CONFIG_JSON`) +3. 
Set environment variables (minimum: `DS2API_ADMIN_KEY`; recommended to also set `DS2API_CONFIG_JSON`) 4. Deploy +Recommended first step in repo root: + +```bash +cp config.example.json config.json +# Edit config.json +``` + +Recommended: convert `config.json` to Base64 locally, then paste into `DS2API_CONFIG_JSON` to avoid JSON formatting mistakes: + +```bash +base64 < config.json | tr -d '\n' +``` + > **Streaming note**: `/v1/chat/completions` on Vercel is routed to `api/chat-stream.js` (Node Runtime) for real-time SSE. Auth, account selection, and session/PoW preparation are still handled by the Go internal prepare endpoint; streaming output (including `tools`) is assembled on Node with Go-aligned anti-leak handling. For detailed deployment instructions, see the [Deployment Guide](DEPLOY.en.md). @@ -142,8 +192,8 @@ GitHub Actions automatically builds multi-platform archives on each Release: ```bash # After downloading the archive for your platform -tar -xzf ds2api_v1.7.0_linux_amd64.tar.gz -cd ds2api_v1.7.0_linux_amd64 +tar -xzf ds2api_<version>_linux_amd64.tar.gz +cd ds2api_<version>_linux_amd64 cp config.example.json config.json # Edit config.json ./ds2api @@ -164,6 +214,7 @@ cp opencode.json.example opencode.json 3. Start OpenCode CLI in the project directory (run `opencode` using your installed method). > Recommended: use the OpenAI-compatible path (`/v1/*`) via `@ai-sdk/openai-compatible` as shown in the example. +> If your client supports `wire_api`, test both `responses` and `chat`; DS2API supports both paths. 
## Configuration @@ -184,9 +235,35 @@ cp opencode.json.example opencode.json "token": "" } ], + "model_aliases": { + "gpt-4o": "deepseek-chat", + "gpt-5-codex": "deepseek-reasoner", + "o3": "deepseek-reasoner" + }, + "compat": { + "wide_input_strict_output": true + }, + "toolcall": { + "mode": "feature_match", + "early_emit_confidence": "high" + }, + "responses": { + "store_ttl_seconds": 900 + }, + "embeddings": { + "provider": "deterministic" + }, "claude_model_mapping": { "fast": "deepseek-chat", "slow": "deepseek-reasoner" + }, + "admin": { + "jwt_expire_hours": 24 + }, + "runtime": { + "account_max_inflight": 2, + "account_max_queue": 0, + "global_max_inflight": 0 } } ``` @@ -194,7 +271,14 @@ cp opencode.json.example opencode.json - `keys`: API access keys; clients authenticate via `Authorization: Bearer ` - `accounts`: DeepSeek account list, supports `email` or `mobile` login - `token`: Leave empty for auto-login on first request; or pre-fill an existing token +- `model_aliases`: Map common model names (GPT/Codex/Claude) to DeepSeek models +- `compat.wide_input_strict_output`: Keep `true` (current default policy) +- `toolcall`: Fixed to feature matching + high-confidence early emit +- `responses.store_ttl_seconds`: In-memory TTL for `/v1/responses/{id}` +- `embeddings.provider`: Embeddings provider (`deterministic/mock/builtin` built-in) - `claude_model_mapping`: Maps `fast`/`slow` suffixes to corresponding DeepSeek models +- `admin`: Admin panel settings (JWT expiry, password hash, etc.), hot-reloadable via Admin Settings API +- `runtime`: Runtime parameters (concurrency limits, queue sizes), hot-reloadable via Admin Settings API ### Environment Variables @@ -214,8 +298,13 @@ cp opencode.json.example opencode.json | `DS2API_ACCOUNT_CONCURRENCY` | Alias (legacy compat) | — | | `DS2API_ACCOUNT_MAX_QUEUE` | Waiting queue limit | `recommended_concurrency` | | `DS2API_ACCOUNT_QUEUE_SIZE` | Alias (legacy compat) | — | +| `DS2API_GLOBAL_MAX_INFLIGHT` | Global max 
in-flight requests | `recommended_concurrency` | +| `DS2API_MAX_INFLIGHT` | Alias (legacy compat) | — | | `DS2API_VERCEL_INTERNAL_SECRET` | Vercel hybrid streaming internal auth | Falls back to `DS2API_ADMIN_KEY` | | `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | Stream lease TTL seconds | `900` | +| `DS2API_DEV_PACKET_CAPTURE` | Local dev packet capture switch (record recent request/response bodies) | Enabled by default on non-Vercel local runtime | +| `DS2API_DEV_PACKET_CAPTURE_LIMIT` | Number of captured sessions to retain (auto-evict overflow) | `5` | +| `DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES` | Max recorded bytes per captured response body | `2097152` | | `VERCEL_TOKEN` | Vercel sync token | — | | `VERCEL_PROJECT_ID` | Vercel project ID | — | | `VERCEL_TEAM_ID` | Vercel team ID | — | @@ -223,7 +312,7 @@ cp opencode.json.example opencode.json ## Authentication Modes -For business endpoints (`/v1/*`, `/anthropic/*`), DS2API supports two modes: +For business endpoints (`/v1/*`, `/anthropic/*`, Gemini routes), DS2API supports two modes: | Mode | Description | | --- | --- | @@ -249,10 +338,34 @@ Queue limit = DS2API_ACCOUNT_MAX_QUEUE (default = recommended concurrency) When `tools` is present in the request, DS2API performs anti-leak handling: -1. With `stream=true`, DS2API **buffers** text deltas first -2. If a tool call is detected → only structured `tool_calls` are emitted, raw JSON is not leaked -3. If no tool call → buffered text is emitted at once -4. Parser supports mixed text, fenced JSON, and `function.arguments` payloads +1. Toolcall feature matching is enabled only in **non-code-block context** (fenced examples are ignored) +2. `responses` streaming strictly uses official item lifecycle events (`response.output_item.*`, `response.content_part.*`, `response.function_call_arguments.*`) +3. Tool names not declared in the `tools` schema are strictly rejected and will not be emitted as valid tool calls +4. 
`responses` supports and enforces `tool_choice` (`auto`/`none`/`required`/forced function); `required` violations return `422` for non-stream and `response.failed` for stream +5. Valid tool call events are only emitted after passing policy validation, preventing invalid tool names from entering the client execution chain + +## Local Dev Packet Capture + +This is for debugging issues such as Responses reasoning streaming and tool-call handoff. When enabled, DS2API stores the latest N DeepSeek conversation payload pairs (request body + upstream response body), defaulting to 5 entries with auto-eviction. + +Enable example: + +```bash +DS2API_DEV_PACKET_CAPTURE=true \ +DS2API_DEV_PACKET_CAPTURE_LIMIT=5 \ +go run ./cmd/ds2api +``` + +Inspect/clear (Admin JWT required): + +- `GET /admin/dev/captures`: list captured items (newest first) +- `DELETE /admin/dev/captures`: clear captured items + +Response fields include: + +- `request_body`: full payload sent to DeepSeek +- `response_body`: concatenated raw upstream stream body text +- `response_truncated`: whether body-size truncation happened ## Project Structure @@ -269,13 +382,20 @@ ds2api/ │ ├── account/ # Account pool and concurrency queue │ ├── adapter/ │ │ ├── openai/ # OpenAI adapter (incl. tool call parsing, Vercel stream prepare/release) -│ │ └── claude/ # Claude adapter -│ ├── admin/ # Admin API handlers +│ │ ├── claude/ # Claude adapter +│ │ └── gemini/ # Gemini adapter (generateContent / streamGenerateContent) +│ ├── admin/ # Admin API handlers (incl. 
Settings hot-reload) │ ├── auth/ # Auth and JWT +│ ├── claudeconv/ # Claude message format conversion +│ ├── compat/ # Compatibility helpers │ ├── config/ # Config loading and hot-reload │ ├── deepseek/ # DeepSeek API client, PoW WASM +│ ├── devcapture/ # Dev packet capture module +│ ├── format/ # Output formatting +│ ├── prompt/ # Prompt construction │ ├── server/ # HTTP routing and middleware (chi router) │ ├── sse/ # SSE parsing utilities +│ ├── stream/ # Unified stream consumption engine │ ├── util/ # Common utilities │ └── webui/ # WebUI static file serving and auto-build ├── webui/ # React WebUI source (Vite + Tailwind) @@ -283,11 +403,13 @@ ds2api/ │ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage │ └── locales/ # Language packs (zh.json / en.json) ├── scripts/ -│ ├── build-webui.sh # Manual WebUI build script -│ └── testsuite/ # Testsuite runner scripts +│ └── build-webui.sh # Manual WebUI build script +├── tests/ +│ ├── compat/ # Compatibility fixtures and expected outputs +│ └── scripts/ # Unified test script entrypoints (unit/e2e) ├── static/admin/ # WebUI build output (not committed to Git) ├── .github/ -│ ├── workflows/ # GitHub Actions (Release artifact automation) +│ ├── workflows/ # GitHub Actions (quality gates + release automation) │ ├── ISSUE_TEMPLATE/ # Issue templates │ └── PULL_REQUEST_TEMPLATE.md ├── config.example.json # Config file template @@ -296,8 +418,7 @@ ds2api/ ├── docker-compose.yml # Production Docker Compose ├── docker-compose.dev.yml # Development Docker Compose ├── vercel.json # Vercel routing and build config -├── go.mod / go.sum # Go module dependencies -└── version.txt # Version number +└── go.mod / go.sum # Go module dependencies ``` ## Documentation Index @@ -312,11 +433,11 @@ ds2api/ ## Testing ```bash -# Unit tests -go test ./... 
+# Unit tests (Go + Node) +./tests/scripts/run-unit-all.sh # One-command live end-to-end tests (real accounts, full request/response logs) -./scripts/testsuite/run-live.sh +./tests/scripts/run-live.sh # Or with custom flags go run ./cmd/ds2api-tests \ @@ -327,12 +448,21 @@ go run ./cmd/ds2api-tests \ --retries 2 ``` +```bash +# Release-blocking gates +./tests/scripts/check-stage6-manual-smoke.sh +./tests/scripts/check-refactor-line-gate.sh +./tests/scripts/run-unit-all.sh +npm ci --prefix webui && npm run build --prefix webui +``` + ## Release Artifact Automation (GitHub Actions) Workflow: `.github/workflows/release-artifacts.yml` - **Trigger**: only on GitHub Release `published` (normal pushes do not trigger builds) - **Outputs**: multi-platform archives (`linux/amd64`, `linux/arm64`, `darwin/amd64`, `darwin/arm64`, `windows/amd64`) + `sha256sums.txt` +- **Container publishing**: GHCR only (`ghcr.io/cjackhwang/ds2api`) - **Each archive includes**: `ds2api` executable, `static/admin`, WASM file, config template, README, LICENSE ## Disclaimer diff --git a/TESTING.md b/TESTING.md index 5540592..e617181 100644 --- a/TESTING.md +++ b/TESTING.md @@ -8,8 +8,10 @@ DS2API 提供两个层级的测试: | 层级 | 命令 | 说明 | | --- | --- | --- | -| 单元测试 | `go test ./...` | 不需要真实账号 | -| 端到端测试 | `./scripts/testsuite/run-live.sh` | 使用真实账号执行全链路测试 | +| 单元测试(Go) | `./tests/scripts/run-unit-go.sh` | 不需要真实账号 | +| 单元测试(Node) | `./tests/scripts/run-unit-node.sh` | 不需要真实账号 | +| 单元测试(全部) | `./tests/scripts/run-unit-all.sh` | 不需要真实账号 | +| 端到端测试 | `./tests/scripts/run-live.sh` | 使用真实账号执行全链路测试 | 端到端测试集会录制完整的请求/响应日志,用于故障排查。 @@ -20,26 +22,36 @@ DS2API 提供两个层级的测试: ### 单元测试 | Unit Tests ```bash -go test ./... 
+./tests/scripts/run-unit-all.sh ``` ```bash -node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js +# 或按语言拆分执行 +./tests/scripts/run-unit-go.sh +./tests/scripts/run-unit-node.sh +``` + +```bash +# 结构与流程门禁 +./tests/scripts/check-refactor-line-gate.sh +./tests/scripts/check-node-split-syntax.sh + +# 发布阻断:阶段 6 手工烟测签字检查(默认读取 plans/stage6-manual-smoke.md) +./tests/scripts/check-stage6-manual-smoke.sh ``` ### 端到端测试 | End-to-End Tests ```bash -./scripts/testsuite/run-live.sh +./tests/scripts/run-live.sh ``` **默认行为**: 1. **Preflight 检查**: - `go test ./... -count=1`(单元测试) - - `node --check api/chat-stream.js`(语法检查) - - `node --check api/helpers/stream-tool-sieve.js`(语法检查) - - `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js`(Node 流式拦截单测) + - `./tests/scripts/check-node-split-syntax.sh`(Node 拆分模块语法门禁) + - `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js`(Node 流式拦截 + compat 单测) - `npm run build --prefix webui`(WebUI 构建检查) 2. **隔离启动**:复制 `config.json` 到临时目录,启动独立服务进程 @@ -179,7 +191,7 @@ go run ./cmd/ds2api-tests \ ```bash # 确保 config.json 存在且包含有效测试账号 -./scripts/testsuite/run-live.sh +./tests/scripts/run-live.sh exit_code=$? if [ $exit_code -ne 0 ]; then echo "Tests failed! Check artifacts for details." 
diff --git a/api/chat-stream.js b/api/chat-stream.js index aa92b17..9241b04 100644 --- a/api/chat-stream.js +++ b/api/chat-stream.js @@ -1,770 +1,3 @@ 'use strict'; -const { - extractToolNames, - createToolSieveState, - processToolSieveChunk, - flushToolSieve, - parseToolCalls, - formatOpenAIStreamToolCalls, -} = require('./helpers/stream-tool-sieve'); - -const DEEPSEEK_COMPLETION_URL = 'https://chat.deepseek.com/api/v0/chat/completion'; - -const BASE_HEADERS = { - Host: 'chat.deepseek.com', - 'User-Agent': 'DeepSeek/1.6.11 Android/35', - Accept: 'application/json', - 'Content-Type': 'application/json', - 'x-client-platform': 'android', - 'x-client-version': '1.6.11', - 'x-client-locale': 'zh_CN', - 'accept-charset': 'UTF-8', -}; - -const SKIP_PATTERNS = [ - 'quasi_status', - 'elapsed_secs', - 'token_usage', - 'pending_fragment', - 'conversation_mode', - 'fragments/-1/status', - 'fragments/-2/status', - 'fragments/-3/status', -]; - -module.exports = async function handler(req, res) { - setCorsHeaders(res); - if (req.method === 'OPTIONS') { - res.statusCode = 204; - res.end(); - return; - } - if (req.method !== 'POST') { - writeOpenAIError(res, 405, 'method not allowed'); - return; - } - - const rawBody = await readRawBody(req); - - // Hard guard: only use Node data path for streaming on Vercel runtime. - // Any non-Vercel runtime always falls back to Go for full behavior parity. - if (!isVercelRuntime()) { - await proxyToGo(req, res, rawBody); - return; - } - - let payload; - try { - payload = JSON.parse(rawBody.toString('utf8') || '{}'); - } catch (_err) { - writeOpenAIError(res, 400, 'invalid json'); - return; - } - - // Keep all non-stream behavior on Go side to avoid compatibility regressions. 
- if (!toBool(payload.stream)) { - await proxyToGo(req, res, rawBody); - return; - } - - const prep = await fetchStreamPrepare(req, rawBody); - if (!prep.ok) { - relayPreparedFailure(res, prep); - return; - } - - const model = asString(prep.body.model) || asString(payload.model); - const sessionID = asString(prep.body.session_id) || `chatcmpl-${Date.now()}`; - const leaseID = asString(prep.body.lease_id); - const deepseekToken = asString(prep.body.deepseek_token); - const powHeader = asString(prep.body.pow_header); - const completionPayload = prep.body.payload && typeof prep.body.payload === 'object' ? prep.body.payload : null; - const finalPrompt = asString(prep.body.final_prompt); - const thinkingEnabled = toBool(prep.body.thinking_enabled); - const searchEnabled = toBool(prep.body.search_enabled); - const toolNames = extractToolNames(payload.tools); - - if (!model || !leaseID || !deepseekToken || !powHeader || !completionPayload) { - writeOpenAIError(res, 500, 'invalid vercel prepare response'); - return; - } - const releaseLease = createLeaseReleaser(req, leaseID); - try { - const completionRes = await fetch(DEEPSEEK_COMPLETION_URL, { - method: 'POST', - headers: { - ...BASE_HEADERS, - authorization: `Bearer ${deepseekToken}`, - 'x-ds-pow-response': powHeader, - }, - body: JSON.stringify(completionPayload), - }); - - if (!completionRes.ok || !completionRes.body) { - const detail = await safeReadText(completionRes); - writeOpenAIError(res, 500, detail ? `Failed to get completion: ${detail}` : 'Failed to get completion.'); - return; - } - - res.statusCode = 200; - res.setHeader('Content-Type', 'text/event-stream'); - res.setHeader('Cache-Control', 'no-cache, no-transform'); - res.setHeader('Connection', 'keep-alive'); - res.setHeader('X-Accel-Buffering', 'no'); - if (typeof res.flushHeaders === 'function') { - res.flushHeaders(); - } - - const created = Math.floor(Date.now() / 1000); - let firstChunkSent = false; - let currentType = thinkingEnabled ? 
'thinking' : 'text'; - let thinkingText = ''; - let outputText = ''; - const toolSieveEnabled = toolNames.length > 0; - const toolSieveState = createToolSieveState(); - let toolCallsEmitted = false; - const decoder = new TextDecoder(); - const reader = completionRes.body.getReader(); - let buffered = ''; - let ended = false; - - const sendFrame = (obj) => { - res.write(`data: ${JSON.stringify(obj)}\n\n`); - if (typeof res.flush === 'function') { - res.flush(); - } - }; - - const sendDeltaFrame = (delta) => { - const payloadDelta = { ...delta }; - if (!firstChunkSent) { - payloadDelta.role = 'assistant'; - firstChunkSent = true; - } - sendFrame({ - id: sessionID, - object: 'chat.completion.chunk', - created, - model, - choices: [{ delta: payloadDelta, index: 0 }], - }); - }; - - const finish = async (reason) => { - if (ended) { - return; - } - ended = true; - const detected = parseToolCalls(outputText, toolNames); - if (detected.length > 0 && !toolCallsEmitted) { - toolCallsEmitted = true; - sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(detected) }); - } else if (toolSieveEnabled) { - const tailEvents = flushToolSieve(toolSieveState, toolNames); - for (const evt of tailEvents) { - if (evt.text) { - sendDeltaFrame({ content: evt.text }); - } - } - } - if (detected.length > 0 || toolCallsEmitted) { - reason = 'tool_calls'; - } - sendFrame({ - id: sessionID, - object: 'chat.completion.chunk', - created, - model, - choices: [{ delta: {}, index: 0, finish_reason: reason }], - usage: buildUsage(finalPrompt, thinkingText, outputText), - }); - res.write('data: [DONE]\n\n'); - await releaseLease(); - res.end(); - }; - - try { - // eslint-disable-next-line no-constant-condition - while (true) { - const { value, done } = await reader.read(); - if (done) { - break; - } - buffered += decoder.decode(value, { stream: true }); - const lines = buffered.split('\n'); - buffered = lines.pop() || ''; - - for (const rawLine of lines) { - const line = rawLine.trim(); - if 
(!line.startsWith('data:')) { - continue; - } - const dataStr = line.slice(5).trim(); - if (!dataStr) { - continue; - } - if (dataStr === '[DONE]') { - await finish('stop'); - return; - } - let chunk; - try { - chunk = JSON.parse(dataStr); - } catch (_err) { - continue; - } - if (chunk.error || chunk.code === 'content_filter') { - await finish('content_filter'); - return; - } - const parsed = parseChunkForContent(chunk, thinkingEnabled, currentType); - currentType = parsed.newType; - if (parsed.finished) { - await finish('stop'); - return; - } - - for (const p of parsed.parts) { - if (!p.text) { - continue; - } - if (searchEnabled && isCitation(p.text)) { - continue; - } - if (p.type === 'thinking') { - if (thinkingEnabled) { - thinkingText += p.text; - sendDeltaFrame({ reasoning_content: p.text }); - } - } else { - outputText += p.text; - if (!toolSieveEnabled) { - sendDeltaFrame({ content: p.text }); - continue; - } - const events = processToolSieveChunk(toolSieveState, p.text, toolNames); - for (const evt of events) { - if (evt.type === 'tool_calls') { - toolCallsEmitted = true; - sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls) }); - continue; - } - if (evt.text) { - sendDeltaFrame({ content: evt.text }); - } - } - } - } - } - } - await finish('stop'); - } catch (_err) { - await finish('stop'); - } - } finally { - await releaseLease(); - } -}; - -function setCorsHeaders(res) { - res.setHeader('Access-Control-Allow-Origin', '*'); - res.setHeader('Access-Control-Allow-Methods', 'GET, POST, OPTIONS, PUT, DELETE'); - res.setHeader( - 'Access-Control-Allow-Headers', - 'Content-Type, Authorization, X-API-Key, X-Ds2-Target-Account, X-Vercel-Protection-Bypass', - ); -} - -function header(req, key) { - if (!req || !req.headers) { - return ''; - } - return asString(req.headers[key.toLowerCase()]); -} - -async function readRawBody(req) { - if (Buffer.isBuffer(req.body)) { - return req.body; - } - if (typeof req.body === 'string') { - return 
Buffer.from(req.body); - } - if (req.body && typeof req.body === 'object') { - return Buffer.from(JSON.stringify(req.body)); - } - const chunks = []; - for await (const chunk of req) { - chunks.push(Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk)); - } - return Buffer.concat(chunks); -} - -async function fetchStreamPrepare(req, rawBody) { - const url = buildInternalGoURL(req); - url.searchParams.set('__stream_prepare', '1'); - - const upstream = await fetch(url.toString(), { - method: 'POST', - headers: buildInternalGoHeaders(req, { withInternalToken: true, withContentType: true }), - body: rawBody, - }); - - const text = await upstream.text(); - let body = {}; - try { - body = JSON.parse(text || '{}'); - } catch (_err) { - body = {}; - } - - return { - ok: upstream.ok, - status: upstream.status, - contentType: upstream.headers.get('content-type') || 'application/json', - text, - body, - }; -} - -function relayPreparedFailure(res, prep) { - if (prep.status === 401 && looksLikeVercelAuthPage(prep.text)) { - writeOpenAIError( - res, - 401, - 'Vercel Deployment Protection blocked internal prepare request. 
Disable protection for this deployment or set VERCEL_AUTOMATION_BYPASS_SECRET.', - ); - return; - } - res.statusCode = prep.status || 500; - res.setHeader('Content-Type', prep.contentType || 'application/json'); - if (prep.text) { - res.end(prep.text); - return; - } - writeOpenAIError(res, prep.status || 500, 'vercel prepare failed'); -} - -async function safeReadText(resp) { - if (!resp) { - return ''; - } - try { - const text = await resp.text(); - return text.trim(); - } catch (_err) { - return ''; - } -} - -function internalSecret() { - return asString(process.env.DS2API_VERCEL_INTERNAL_SECRET) || asString(process.env.DS2API_ADMIN_KEY) || 'admin'; -} - -function buildInternalGoURL(req) { - const proto = asString(header(req, 'x-forwarded-proto')) || 'https'; - const host = asString(header(req, 'host')); - const url = new URL(`${proto}://${host}${req.url || '/v1/chat/completions'}`); - url.searchParams.set('__go', '1'); - const protectionBypass = resolveProtectionBypass(req); - if (protectionBypass) { - url.searchParams.set('x-vercel-protection-bypass', protectionBypass); - } - return url; -} - -function buildInternalGoHeaders(req, opts = {}) { - const headers = { - authorization: asString(header(req, 'authorization')), - 'x-api-key': asString(header(req, 'x-api-key')), - 'x-ds2-target-account': asString(header(req, 'x-ds2-target-account')), - 'x-vercel-protection-bypass': resolveProtectionBypass(req), - }; - if (opts.withInternalToken) { - headers['x-ds2-internal-token'] = internalSecret(); - } - if (opts.withContentType) { - headers['content-type'] = asString(header(req, 'content-type')) || 'application/json'; - } - return headers; -} - -function createLeaseReleaser(req, leaseID) { - let released = false; - return async () => { - if (released || !leaseID) { - return; - } - released = true; - try { - await releaseStreamLease(req, leaseID); - } catch (_err) { - // Ignore release errors. Lease TTL cleanup on Go side still prevents permanent leaks. 
- } - }; -} - -async function releaseStreamLease(req, leaseID) { - const url = buildInternalGoURL(req); - url.searchParams.set('__stream_release', '1'); - const body = Buffer.from(JSON.stringify({ lease_id: leaseID })); - - const controller = new AbortController(); - const timeout = setTimeout(() => controller.abort(), 1500); - try { - await fetch(url.toString(), { - method: 'POST', - headers: buildInternalGoHeaders(req, { withInternalToken: true, withContentType: true }), - body, - signal: controller.signal, - }); - } finally { - clearTimeout(timeout); - } -} - -function resolveProtectionBypass(req) { - const fromHeader = asString(header(req, 'x-vercel-protection-bypass')); - if (fromHeader) { - return fromHeader; - } - return asString(process.env.VERCEL_AUTOMATION_BYPASS_SECRET) || asString(process.env.DS2API_VERCEL_PROTECTION_BYPASS); -} - -function looksLikeVercelAuthPage(text) { - const body = asString(text).toLowerCase(); - if (!body) { - return false; - } - return body.includes('authentication required') && body.includes('vercel'); -} - -function parseChunkForContent(chunk, thinkingEnabled, currentType) { - if (!chunk || typeof chunk !== 'object' || !Object.prototype.hasOwnProperty.call(chunk, 'v')) { - return { parts: [], finished: false, newType: currentType }; - } - const pathValue = asString(chunk.p); - if (shouldSkipPath(pathValue)) { - return { parts: [], finished: false, newType: currentType }; - } - if (pathValue === 'response/status' && asString(chunk.v) === 'FINISHED') { - return { parts: [], finished: true, newType: currentType }; - } - - let newType = currentType; - const parts = []; - - if (pathValue === 'response/fragments' && asString(chunk.o).toUpperCase() === 'APPEND' && Array.isArray(chunk.v)) { - for (const frag of chunk.v) { - if (!frag || typeof frag !== 'object') { - continue; - } - const fragType = asString(frag.type).toUpperCase(); - const content = asString(frag.content); - if (!content) { - continue; - } - if (fragType === 'THINK' 
|| fragType === 'THINKING') { - newType = 'thinking'; - parts.push({ text: content, type: 'thinking' }); - } else if (fragType === 'RESPONSE') { - newType = 'text'; - parts.push({ text: content, type: 'text' }); - } else { - parts.push({ text: content, type: 'text' }); - } - } - } - - if (pathValue === 'response' && Array.isArray(chunk.v)) { - for (const item of chunk.v) { - if (!item || typeof item !== 'object') { - continue; - } - if (item.p === 'fragments' && item.o === 'APPEND' && Array.isArray(item.v)) { - for (const frag of item.v) { - const fragType = asString(frag && frag.type).toUpperCase(); - if (fragType === 'THINK' || fragType === 'THINKING') { - newType = 'thinking'; - } else if (fragType === 'RESPONSE') { - newType = 'text'; - } - } - } - } - } - - let partType = 'text'; - if (pathValue === 'response/thinking_content') { - partType = 'thinking'; - } else if (pathValue === 'response/content') { - partType = 'text'; - } else if (pathValue.includes('response/fragments') && pathValue.includes('/content')) { - partType = newType; - } else if (!pathValue && thinkingEnabled) { - partType = newType; - } - - const val = chunk.v; - if (typeof val === 'string') { - if (val === 'FINISHED' && (!pathValue || pathValue === 'status')) { - return { parts: [], finished: true, newType }; - } - if (val) { - parts.push({ text: val, type: partType }); - } - return { parts, finished: false, newType }; - } - - if (Array.isArray(val)) { - const extracted = extractContentRecursive(val, partType); - if (extracted.finished) { - return { parts: [], finished: true, newType }; - } - parts.push(...extracted.parts); - return { parts, finished: false, newType }; - } - - if (val && typeof val === 'object') { - const resp = val.response && typeof val.response === 'object' ? 
val.response : val; - if (Array.isArray(resp.fragments)) { - for (const frag of resp.fragments) { - if (!frag || typeof frag !== 'object') { - continue; - } - const content = asString(frag.content); - if (!content) { - continue; - } - const t = asString(frag.type).toUpperCase(); - if (t === 'THINK' || t === 'THINKING') { - newType = 'thinking'; - parts.push({ text: content, type: 'thinking' }); - } else if (t === 'RESPONSE') { - newType = 'text'; - parts.push({ text: content, type: 'text' }); - } else { - parts.push({ text: content, type: partType }); - } - } - } - } - return { parts, finished: false, newType }; -} - -function extractContentRecursive(items, defaultType) { - const parts = []; - for (const it of items) { - if (!it || typeof it !== 'object') { - continue; - } - if (!Object.prototype.hasOwnProperty.call(it, 'v')) { - continue; - } - const itemPath = asString(it.p); - const itemV = it.v; - if (itemPath === 'status' && asString(itemV) === 'FINISHED') { - return { parts: [], finished: true }; - } - if (shouldSkipPath(itemPath)) { - continue; - } - const content = asString(it.content); - if (content) { - const typeName = asString(it.type).toUpperCase(); - if (typeName === 'THINK' || typeName === 'THINKING') { - parts.push({ text: content, type: 'thinking' }); - } else if (typeName === 'RESPONSE') { - parts.push({ text: content, type: 'text' }); - } else { - parts.push({ text: content, type: defaultType }); - } - continue; - } - - let partType = defaultType; - if (itemPath.includes('thinking')) { - partType = 'thinking'; - } else if (itemPath.includes('content') || itemPath === 'response' || itemPath === 'fragments') { - partType = 'text'; - } - - if (typeof itemV === 'string') { - if (itemV && itemV !== 'FINISHED') { - parts.push({ text: itemV, type: partType }); - } - continue; - } - - if (!Array.isArray(itemV)) { - continue; - } - for (const inner of itemV) { - if (typeof inner === 'string') { - if (inner) { - parts.push({ text: inner, type: partType }); 
- } - continue; - } - if (!inner || typeof inner !== 'object') { - continue; - } - const ct = asString(inner.content); - if (!ct) { - continue; - } - const typeName = asString(inner.type).toUpperCase(); - if (typeName === 'THINK' || typeName === 'THINKING') { - parts.push({ text: ct, type: 'thinking' }); - } else if (typeName === 'RESPONSE') { - parts.push({ text: ct, type: 'text' }); - } else { - parts.push({ text: ct, type: partType }); - } - } - } - return { parts, finished: false }; -} - -function shouldSkipPath(pathValue) { - if (pathValue === 'response/search_status') { - return true; - } - for (const p of SKIP_PATTERNS) { - if (pathValue.includes(p)) { - return true; - } - } - return false; -} - -function isCitation(text) { - return asString(text).trim().startsWith('[citation:'); -} - -function buildUsage(prompt, thinking, output) { - const promptTokens = estimateTokens(prompt); - const reasoningTokens = estimateTokens(thinking); - const completionTokens = estimateTokens(output); - return { - prompt_tokens: promptTokens, - completion_tokens: reasoningTokens + completionTokens, - total_tokens: promptTokens + reasoningTokens + completionTokens, - completion_tokens_details: { - reasoning_tokens: reasoningTokens, - }, - }; -} - -function estimateTokens(text) { - const t = asString(text); - if (!t) { - return 0; - } - const n = Math.floor(Array.from(t).length / 4); - return n < 1 ? 
1 : n; -} - -async function proxyToGo(req, res, rawBody) { - const url = buildInternalGoURL(req); - - const upstream = await fetch(url.toString(), { - method: 'POST', - headers: buildInternalGoHeaders(req, { withContentType: true }), - body: rawBody, - }); - - res.statusCode = upstream.status; - upstream.headers.forEach((value, key) => { - if (key.toLowerCase() === 'content-length') { - return; - } - res.setHeader(key, value); - }); - - if (!upstream.body || typeof upstream.body.getReader !== 'function') { - const bytes = Buffer.from(await upstream.arrayBuffer()); - res.end(bytes); - return; - } - - const reader = upstream.body.getReader(); - try { - // eslint-disable-next-line no-constant-condition - while (true) { - const { value, done } = await reader.read(); - if (done) { - break; - } - if (value && value.length > 0) { - res.write(Buffer.from(value)); - if (typeof res.flush === 'function') { - res.flush(); - } - } - } - res.end(); - } catch (_err) { - if (!res.writableEnded) { - res.end(); - } - } -} - -function writeOpenAIError(res, status, message) { - res.statusCode = status; - res.setHeader('Content-Type', 'application/json'); - res.end( - JSON.stringify({ - error: { - message, - type: openAIErrorType(status), - }, - }), - ); -} - -function openAIErrorType(status) { - switch (status) { - case 400: - return 'invalid_request_error'; - case 401: - return 'authentication_error'; - case 403: - return 'permission_error'; - case 429: - return 'rate_limit_error'; - case 503: - return 'service_unavailable_error'; - default: - return status >= 500 ? 
'api_error' : 'invalid_request_error'; - } -} - -function toBool(v) { - return v === true; -} - -function isVercelRuntime() { - return asString(process.env.VERCEL) !== '' || asString(process.env.NOW_REGION) !== ''; -} - -function asString(v) { - if (typeof v === 'string') { - return v.trim(); - } - if (Array.isArray(v)) { - return asString(v[0]); - } - if (v == null) { - return ''; - } - return String(v).trim(); -} - -module.exports.__test = { - parseChunkForContent, - extractContentRecursive, - shouldSkipPath, - asString, -}; +module.exports = require('../internal/js/chat-stream/index.js'); diff --git a/api/helpers/stream-tool-sieve.js b/api/helpers/stream-tool-sieve.js deleted file mode 100644 index 3ced63d..0000000 --- a/api/helpers/stream-tool-sieve.js +++ /dev/null @@ -1,477 +0,0 @@ -'use strict'; - -const crypto = require('crypto'); -const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s; - -function extractToolNames(tools) { - if (!Array.isArray(tools) || tools.length === 0) { - return []; - } - const out = []; - for (const t of tools) { - if (!t || typeof t !== 'object') { - continue; - } - const fn = t.function && typeof t.function === 'object' ? t.function : t; - const name = toStringSafe(fn.name); - // Keep parity with Go injectToolPrompt: object tools without name still - // enter tool mode via fallback name "unknown". 
- out.push(name || 'unknown'); - } - return out; -} - -function createToolSieveState() { - return { - pending: '', - capture: '', - capturing: false, - }; -} - -function processToolSieveChunk(state, chunk, toolNames) { - if (!state) { - return []; - } - if (chunk) { - state.pending += chunk; - } - const events = []; - // eslint-disable-next-line no-constant-condition - while (true) { - if (state.capturing) { - if (state.pending) { - state.capture += state.pending; - state.pending = ''; - } - const consumed = consumeToolCapture(state.capture, toolNames); - if (!consumed.ready) { - break; - } - state.capture = ''; - state.capturing = false; - if (consumed.prefix) { - events.push({ type: 'text', text: consumed.prefix }); - } - if (Array.isArray(consumed.calls) && consumed.calls.length > 0) { - events.push({ type: 'tool_calls', calls: consumed.calls }); - } - if (consumed.suffix) { - state.pending += consumed.suffix; - } - continue; - } - - if (!state.pending) { - break; - } - - const start = findToolSegmentStart(state.pending); - if (start >= 0) { - const prefix = state.pending.slice(0, start); - if (prefix) { - events.push({ type: 'text', text: prefix }); - } - state.capture = state.pending.slice(start); - state.pending = ''; - state.capturing = true; - continue; - } - - const [safe, hold] = splitSafeContentForToolDetection(state.pending); - if (!safe) { - break; - } - state.pending = hold; - events.push({ type: 'text', text: safe }); - } - return events; -} - -function flushToolSieve(state, toolNames) { - if (!state) { - return []; - } - const events = processToolSieveChunk(state, '', toolNames); - if (state.capturing) { - const consumed = consumeToolCapture(state.capture, toolNames); - if (consumed.ready) { - if (consumed.prefix) { - events.push({ type: 'text', text: consumed.prefix }); - } - if (Array.isArray(consumed.calls) && consumed.calls.length > 0) { - events.push({ type: 'tool_calls', calls: consumed.calls }); - } - if (consumed.suffix) { - events.push({ 
type: 'text', text: consumed.suffix }); - } - } else if (state.capture) { - // Incomplete captured tool JSON at stream end: suppress raw capture. - } - state.capture = ''; - state.capturing = false; - } - if (state.pending) { - events.push({ type: 'text', text: state.pending }); - state.pending = ''; - } - return events; -} - -function splitSafeContentForToolDetection(s) { - const text = s || ''; - if (!text) { - return ['', '']; - } - const suspiciousStart = findSuspiciousPrefixStart(text); - if (suspiciousStart < 0) { - return [text, '']; - } - if (suspiciousStart > 0) { - return [text.slice(0, suspiciousStart), text.slice(suspiciousStart)]; - } - // If suspicious content starts at the beginning, keep holding until we can - // either parse a full tool JSON block or reach stream flush. - return ['', text]; -} - -function findSuspiciousPrefixStart(s) { - let start = -1; - for (const needle of ['{', '[', '```']) { - const idx = s.lastIndexOf(needle); - if (idx > start) { - start = idx; - } - } - return start; -} - -function findToolSegmentStart(s) { - if (!s) { - return -1; - } - const lower = s.toLowerCase(); - const keyIdx = lower.indexOf('tool_calls'); - if (keyIdx < 0) { - return -1; - } - const start = s.slice(0, keyIdx).lastIndexOf('{'); - return start >= 0 ? 
start : keyIdx; -} - -function consumeToolCapture(captured, toolNames) { - if (!captured) { - return { ready: false, prefix: '', calls: [], suffix: '' }; - } - const lower = captured.toLowerCase(); - const keyIdx = lower.indexOf('tool_calls'); - if (keyIdx < 0) { - return { ready: false, prefix: '', calls: [], suffix: '' }; - } - const start = captured.slice(0, keyIdx).lastIndexOf('{'); - if (start < 0) { - return { ready: false, prefix: '', calls: [], suffix: '' }; - } - const obj = extractJSONObjectFrom(captured, start); - if (!obj.ok) { - return { ready: false, prefix: '', calls: [], suffix: '' }; - } - const parsed = parseToolCalls(captured.slice(start, obj.end), toolNames); - if (parsed.length === 0) { - // `tool_calls` key exists but strict JSON parse failed. - // Drop the captured object body to avoid leaking raw tool JSON. - return { - ready: true, - prefix: captured.slice(0, start), - calls: [], - suffix: captured.slice(obj.end), - }; - } - return { - ready: true, - prefix: captured.slice(0, start), - calls: parsed, - suffix: captured.slice(obj.end), - }; -} - -function extractJSONObjectFrom(text, start) { - if (!text || start < 0 || start >= text.length || text[start] !== '{') { - return { ok: false, end: 0 }; - } - let depth = 0; - let quote = ''; - let escaped = false; - for (let i = start; i < text.length; i += 1) { - const ch = text[i]; - if (quote) { - if (escaped) { - escaped = false; - continue; - } - if (ch === '\\') { - escaped = true; - continue; - } - if (ch === quote) { - quote = ''; - } - continue; - } - if (ch === '"' || ch === "'") { - quote = ch; - continue; - } - if (ch === '{') { - depth += 1; - continue; - } - if (ch === '}') { - depth -= 1; - if (depth === 0) { - return { ok: true, end: i + 1 }; - } - } - } - return { ok: false, end: 0 }; -} - -function parseToolCalls(text, toolNames) { - if (!toStringSafe(text)) { - return []; - } - const candidates = buildToolCallCandidates(text); - let parsed = []; - for (const c of candidates) { - 
parsed = parseToolCallsPayload(c); - if (parsed.length > 0) { - break; - } - } - if (parsed.length === 0) { - return []; - } - const allowed = new Set((toolNames || []).filter(Boolean)); - const out = []; - for (const tc of parsed) { - if (!tc || !tc.name) { - continue; - } - if (allowed.size > 0 && !allowed.has(tc.name)) { - continue; - } - out.push({ name: tc.name, input: tc.input || {} }); - } - if (out.length === 0 && parsed.length > 0) { - for (const tc of parsed) { - if (!tc || !tc.name) { - continue; - } - out.push({ name: tc.name, input: tc.input || {} }); - } - } - return out; -} - -function buildToolCallCandidates(text) { - const trimmed = toStringSafe(text); - const candidates = [trimmed]; - const fenced = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/gi) || []; - for (const block of fenced) { - const m = block.match(/```(?:json)?\s*([\s\S]*?)\s*```/i); - if (m && m[1]) { - candidates.push(toStringSafe(m[1])); - } - } - for (const candidate of extractToolCallObjects(trimmed)) { - candidates.push(toStringSafe(candidate)); - } - const first = trimmed.indexOf('{'); - const last = trimmed.lastIndexOf('}'); - if (first >= 0 && last > first) { - candidates.push(toStringSafe(trimmed.slice(first, last + 1))); - } - const m = trimmed.match(TOOL_CALL_PATTERN); - if (m && m[1]) { - candidates.push(`{"tool_calls":[${m[1]}]}`); - } - return [...new Set(candidates.filter(Boolean))]; -} - -function extractToolCallObjects(text) { - const raw = toStringSafe(text); - if (!raw) { - return []; - } - const lower = raw.toLowerCase(); - const out = []; - let offset = 0; - // eslint-disable-next-line no-constant-condition - while (true) { - let idx = lower.indexOf('tool_calls', offset); - if (idx < 0) { - break; - } - let start = raw.slice(0, idx).lastIndexOf('{'); - while (start >= 0) { - const obj = extractJSONObjectFrom(raw, start); - if (obj.ok) { - out.push(raw.slice(start, obj.end).trim()); - offset = obj.end; - idx = -1; - break; - } - start = raw.slice(0, 
start).lastIndexOf('{'); - } - if (idx >= 0) { - offset = idx + 'tool_calls'.length; - } - } - return out; -} - -function parseToolCallsPayload(payload) { - let decoded; - try { - decoded = JSON.parse(payload); - } catch (_err) { - return []; - } - if (Array.isArray(decoded)) { - return parseToolCallList(decoded); - } - if (!decoded || typeof decoded !== 'object') { - return []; - } - if (decoded.tool_calls) { - return parseToolCallList(decoded.tool_calls); - } - const one = parseToolCallItem(decoded); - return one ? [one] : []; -} - -function parseToolCallList(v) { - if (!Array.isArray(v)) { - return []; - } - const out = []; - for (const item of v) { - if (!item || typeof item !== 'object') { - continue; - } - const one = parseToolCallItem(item); - if (one) { - out.push(one); - } - } - return out; -} - -function parseToolCallItem(m) { - let name = toStringSafe(m.name); - let inputRaw = m.input; - let hasInput = Object.prototype.hasOwnProperty.call(m, 'input'); - const fn = m.function && typeof m.function === 'object' ? 
m.function : null; - if (fn) { - if (!name) { - name = toStringSafe(fn.name); - } - if (!hasInput && Object.prototype.hasOwnProperty.call(fn, 'arguments')) { - inputRaw = fn.arguments; - hasInput = true; - } - } - if (!hasInput) { - for (const k of ['arguments', 'args', 'parameters', 'params']) { - if (Object.prototype.hasOwnProperty.call(m, k)) { - inputRaw = m[k]; - hasInput = true; - break; - } - } - } - if (!name) { - return null; - } - return { - name, - input: parseToolCallInput(inputRaw), - }; -} - -function parseToolCallInput(v) { - if (v == null) { - return {}; - } - if (typeof v === 'string') { - const raw = toStringSafe(v); - if (!raw) { - return {}; - } - try { - const parsed = JSON.parse(raw); - if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { - return parsed; - } - return { _raw: raw }; - } catch (_err) { - return { _raw: raw }; - } - } - if (typeof v === 'object' && !Array.isArray(v)) { - return v; - } - try { - const parsed = JSON.parse(JSON.stringify(v)); - if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { - return parsed; - } - } catch (_err) { - return {}; - } - return {}; -} - -function formatOpenAIStreamToolCalls(calls) { - if (!Array.isArray(calls) || calls.length === 0) { - return []; - } - return calls.map((c, idx) => ({ - index: idx, - id: `call_${newCallID()}`, - type: 'function', - function: { - name: c.name, - arguments: JSON.stringify(c.input || {}), - }, - })); -} - -function newCallID() { - if (typeof crypto.randomUUID === 'function') { - return crypto.randomUUID().replace(/-/g, ''); - } - return `${Date.now()}${Math.floor(Math.random() * 1e9)}`; -} - -function toStringSafe(v) { - if (typeof v === 'string') { - return v.trim(); - } - if (Array.isArray(v)) { - return toStringSafe(v[0]); - } - if (v == null) { - return ''; - } - return String(v).trim(); -} - -module.exports = { - extractToolNames, - createToolSieveState, - processToolSieveChunk, - flushToolSieve, - parseToolCalls, - 
formatOpenAIStreamToolCalls, -}; diff --git a/api/helpers/stream-tool-sieve.test.js b/api/helpers/stream-tool-sieve.test.js deleted file mode 100644 index 47b3100..0000000 --- a/api/helpers/stream-tool-sieve.test.js +++ /dev/null @@ -1,130 +0,0 @@ -'use strict'; - -const test = require('node:test'); -const assert = require('node:assert/strict'); - -const { - extractToolNames, - createToolSieveState, - processToolSieveChunk, - flushToolSieve, - parseToolCalls, -} = require('./stream-tool-sieve'); - -function runSieve(chunks, toolNames) { - const state = createToolSieveState(); - const events = []; - for (const chunk of chunks) { - events.push(...processToolSieveChunk(state, chunk, toolNames)); - } - events.push(...flushToolSieve(state, toolNames)); - return events; -} - -function collectText(events) { - return events - .filter((evt) => evt.type === 'text' && evt.text) - .map((evt) => evt.text) - .join(''); -} - -test('extractToolNames keeps tool mode enabled with unknown fallback', () => { - const names = extractToolNames([ - { function: { description: 'no name tool' } }, - { function: { name: ' read_file ' } }, - {}, - ]); - assert.deepEqual(names, ['unknown', 'read_file', 'unknown']); -}); - -test('parseToolCalls keeps non-object argument strings as _raw (Go parity)', () => { - const payload = JSON.stringify({ - tool_calls: [ - { name: 'read_file', input: '123' }, - { name: 'list_dir', input: '[1,2,3]' }, - ], - }); - const calls = parseToolCalls(payload, ['read_file', 'list_dir']); - assert.deepEqual(calls, [ - { name: 'read_file', input: { _raw: '123' } }, - { name: 'list_dir', input: { _raw: '[1,2,3]' } }, - ]); -}); - -test('parseToolCalls still intercepts unknown schema names to avoid leaks', () => { - const payload = JSON.stringify({ - tool_calls: [{ name: 'not_in_schema', input: { q: 'go' } }], - }); - const calls = parseToolCalls(payload, ['search']); - assert.equal(calls.length, 1); - assert.equal(calls[0].name, 'not_in_schema'); -}); - 
-test('parseToolCalls supports fenced json and function.arguments string payload', () => { - const text = [ - 'I will call a tool now.', - '```json', - '{"tool_calls":[{"function":{"name":"read_file","arguments":"{\\"path\\":\\"README.md\\"}"}}]}', - '```', - ].join('\n'); - const calls = parseToolCalls(text, ['read_file']); - assert.equal(calls.length, 1); - assert.equal(calls[0].name, 'read_file'); - assert.deepEqual(calls[0].input, { path: 'README.md' }); -}); - -test('sieve emits tool_calls and does not leak suspicious prefix on late key convergence', () => { - const events = runSieve( - [ - '{"', - 'tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}', - '后置正文C。', - ], - ['read_file'], - ); - const leakedText = collectText(events); - const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0); - assert.equal(hasToolCall, true); - assert.equal(leakedText.includes('{'), false); - assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); - assert.equal(leakedText.includes('后置正文C。'), true); -}); - -test('sieve drops invalid tool json body while preserving surrounding text', () => { - const events = runSieve( - [ - '前置正文D。', - "{'tool_calls':[{'name':'read_file','input':{'path':'README.MD'}}]}", - '后置正文E。', - ], - ['read_file'], - ); - const leakedText = collectText(events); - const hasToolCall = events.some((evt) => evt.type === 'tool_calls'); - assert.equal(hasToolCall, false); - assert.equal(leakedText.includes('前置正文D。'), true); - assert.equal(leakedText.includes('后置正文E。'), true); - assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); -}); - -test('sieve suppresses incomplete captured tool json on stream finalize', () => { - const events = runSieve( - ['前置正文F。', '{"tool_calls":[{"name":"read_file"'], - ['read_file'], - ); - const leakedText = collectText(events); - assert.equal(leakedText.includes('前置正文F。'), true); - 
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); - assert.equal(leakedText.includes('{'), false); -}); - -test('sieve keeps plain text intact in tool mode when no tool call appears', () => { - const events = runSieve( - ['你好,', '这是普通文本回复。', '请继续。'], - ['read_file'], - ); - const leakedText = collectText(events); - const hasToolCall = events.some((evt) => evt.type === 'tool_calls'); - assert.equal(hasToolCall, false); - assert.equal(leakedText, '你好,这是普通文本回复。请继续。'); -}); diff --git a/cmd/ds2api/main.go b/cmd/ds2api/main.go index 8e83008..1466d1c 100644 --- a/cmd/ds2api/main.go +++ b/cmd/ds2api/main.go @@ -2,6 +2,8 @@ package main import ( "context" + "fmt" + "net" "net/http" "os" "os/signal" @@ -28,10 +30,21 @@ func main() { Addr: "0.0.0.0:" + port, Handler: app.Router, } + localURL := fmt.Sprintf("http://127.0.0.1:%s", port) + lanIP := detectLANIPv4() + lanURL := "" + if lanIP != "" { + lanURL = fmt.Sprintf("http://%s:%s", lanIP, port) + } // Start server in a goroutine so we can listen for shutdown signals. 
go func() { - config.Logger.Info("starting ds2api", "port", port) + if lanURL != "" { + config.Logger.Info("starting ds2api", "bind", srv.Addr, "port", port, "local_url", localURL, "lan_url", lanURL, "lan_ip", lanIP) + } else { + config.Logger.Info("starting ds2api", "bind", srv.Addr, "port", port, "local_url", localURL) + config.Logger.Warn("lan ip not detected; check active network interfaces") + } if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { config.Logger.Error("server stopped unexpectedly", "error", err) os.Exit(1) @@ -54,3 +67,36 @@ func main() { } config.Logger.Info("server gracefully stopped") } + +func detectLANIPv4() string { + ifaces, err := net.Interfaces() + if err != nil { + return "" + } + for _, iface := range ifaces { + if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 { + continue + } + addrs, err := iface.Addrs() + if err != nil { + continue + } + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + default: + continue + } + ip = ip.To4() + if ip == nil || !ip.IsPrivate() { + continue + } + return ip.String() + } + } + return "" +} diff --git a/config.example.json b/config.example.json index 7614e77..97161f7 100644 --- a/config.example.json +++ b/config.example.json @@ -24,5 +24,27 @@ "password": "your-password-3", "token": "" } - ] -} \ No newline at end of file + ], + "model_aliases": { + "gpt-4o": "deepseek-chat", + "gpt-5-codex": "deepseek-reasoner", + "o3": "deepseek-reasoner" + }, + "compat": { + "wide_input_strict_output": true + }, + "toolcall": { + "mode": "feature_match", + "early_emit_confidence": "high" + }, + "responses": { + "store_ttl_seconds": 900 + }, + "embeddings": { + "provider": "deterministic" + }, + "claude_model_mapping": { + "fast": "deepseek-chat", + "slow": "deepseek-reasoner" + } +} diff --git a/internal/account/pool.go b/internal/account/pool.go deleted file mode 100644 index 
665bcee..0000000 --- a/internal/account/pool.go +++ /dev/null @@ -1,302 +0,0 @@ -package account - -import ( - "context" - "os" - "sort" - "strconv" - "strings" - "sync" - - "ds2api/internal/config" -) - -type Pool struct { - store *config.Store - mu sync.Mutex - queue []string - inUse map[string]int - waiters []chan struct{} - maxInflightPerAccount int - recommendedConcurrency int - maxQueueSize int -} - -func NewPool(store *config.Store) *Pool { - p := &Pool{ - store: store, - inUse: map[string]int{}, - maxInflightPerAccount: maxInflightFromEnv(), - } - p.Reset() - return p -} - -func (p *Pool) Reset() { - accounts := p.store.Accounts() - sort.SliceStable(accounts, func(i, j int) bool { - iHas := accounts[i].Token != "" - jHas := accounts[j].Token != "" - if iHas == jHas { - return i < j - } - return iHas - }) - ids := make([]string, 0, len(accounts)) - for _, a := range accounts { - id := a.Identifier() - if id != "" { - ids = append(ids, id) - } - } - recommended := defaultRecommendedConcurrency(len(ids), p.maxInflightPerAccount) - queueLimit := maxQueueFromEnv(recommended) - p.mu.Lock() - defer p.mu.Unlock() - p.drainWaitersLocked() - p.queue = ids - p.inUse = map[string]int{} - p.recommendedConcurrency = recommended - p.maxQueueSize = queueLimit - config.Logger.Info( - "[init_account_queue] initialized", - "total", len(ids), - "max_inflight_per_account", p.maxInflightPerAccount, - "recommended_concurrency", p.recommendedConcurrency, - "max_queue_size", p.maxQueueSize, - ) -} - -func (p *Pool) Acquire(target string, exclude map[string]bool) (config.Account, bool) { - p.mu.Lock() - defer p.mu.Unlock() - return p.acquireLocked(target, normalizeExclude(exclude)) -} - -func (p *Pool) AcquireWait(ctx context.Context, target string, exclude map[string]bool) (config.Account, bool) { - if ctx == nil { - ctx = context.Background() - } - exclude = normalizeExclude(exclude) - for { - if ctx.Err() != nil { - return config.Account{}, false - } - - p.mu.Lock() - if acc, ok 
:= p.acquireLocked(target, exclude); ok { - p.mu.Unlock() - return acc, true - } - if !p.canQueueLocked(target, exclude) { - p.mu.Unlock() - return config.Account{}, false - } - waiter := make(chan struct{}) - p.waiters = append(p.waiters, waiter) - p.mu.Unlock() - - select { - case <-ctx.Done(): - p.mu.Lock() - p.removeWaiterLocked(waiter) - p.mu.Unlock() - return config.Account{}, false - case <-waiter: - } - } -} - -func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Account, bool) { - if target != "" { - if exclude[target] || p.inUse[target] >= p.maxInflightPerAccount { - return config.Account{}, false - } - acc, ok := p.store.FindAccount(target) - if !ok { - return config.Account{}, false - } - p.inUse[target]++ - p.bumpQueue(target) - return acc, true - } - - if acc, ok := p.tryAcquire(exclude, true); ok { - return acc, true - } - if acc, ok := p.tryAcquire(exclude, false); ok { - return acc, true - } - return config.Account{}, false -} - -func (p *Pool) tryAcquire(exclude map[string]bool, requireToken bool) (config.Account, bool) { - for i := 0; i < len(p.queue); i++ { - id := p.queue[i] - if exclude[id] || p.inUse[id] >= p.maxInflightPerAccount { - continue - } - acc, ok := p.store.FindAccount(id) - if !ok { - continue - } - if requireToken && acc.Token == "" { - continue - } - p.inUse[id]++ - p.bumpQueue(id) - return acc, true - } - return config.Account{}, false -} - -func (p *Pool) bumpQueue(accountID string) { - for i, id := range p.queue { - if id != accountID { - continue - } - p.queue = append(p.queue[:i], p.queue[i+1:]...) 
- p.queue = append(p.queue, accountID) - return - } -} - -func (p *Pool) Release(accountID string) { - if accountID == "" { - return - } - p.mu.Lock() - defer p.mu.Unlock() - count := p.inUse[accountID] - if count <= 0 { - return - } - if count == 1 { - delete(p.inUse, accountID) - p.notifyWaiterLocked() - return - } - p.inUse[accountID] = count - 1 - p.notifyWaiterLocked() -} - -func (p *Pool) Status() map[string]any { - p.mu.Lock() - defer p.mu.Unlock() - available := make([]string, 0, len(p.queue)) - inUseAccounts := make([]string, 0, len(p.inUse)) - inUseSlots := 0 - for _, id := range p.queue { - if p.inUse[id] < p.maxInflightPerAccount { - available = append(available, id) - } - } - for id, count := range p.inUse { - if count > 0 { - inUseAccounts = append(inUseAccounts, id) - inUseSlots += count - } - } - sort.Strings(inUseAccounts) - return map[string]any{ - "available": len(available), - "in_use": inUseSlots, - "total": len(p.store.Accounts()), - "available_accounts": available, - "in_use_accounts": inUseAccounts, - "max_inflight_per_account": p.maxInflightPerAccount, - "recommended_concurrency": p.recommendedConcurrency, - "waiting": len(p.waiters), - "max_queue_size": p.maxQueueSize, - } -} - -func maxInflightFromEnv() int { - for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} { - raw := strings.TrimSpace(os.Getenv(key)) - if raw == "" { - continue - } - n, err := strconv.Atoi(raw) - if err == nil && n > 0 { - return n - } - } - return 2 -} - -func defaultRecommendedConcurrency(accountCount, maxInflightPerAccount int) int { - if accountCount <= 0 { - return 0 - } - if maxInflightPerAccount <= 0 { - maxInflightPerAccount = 2 - } - return accountCount * maxInflightPerAccount -} - -func normalizeExclude(exclude map[string]bool) map[string]bool { - if exclude == nil { - return map[string]bool{} - } - return exclude -} - -func (p *Pool) canQueueLocked(target string, exclude map[string]bool) bool { - if target != "" { - 
if exclude[target] { - return false - } - if _, ok := p.store.FindAccount(target); !ok { - return false - } - } - if p.maxQueueSize <= 0 { - return false - } - return len(p.waiters) < p.maxQueueSize -} - -func (p *Pool) notifyWaiterLocked() { - if len(p.waiters) == 0 { - return - } - waiter := p.waiters[0] - p.waiters = p.waiters[1:] - close(waiter) -} - -func (p *Pool) removeWaiterLocked(waiter chan struct{}) bool { - for i, w := range p.waiters { - if w != waiter { - continue - } - p.waiters = append(p.waiters[:i], p.waiters[i+1:]...) - return true - } - return false -} - -func (p *Pool) drainWaitersLocked() { - for _, waiter := range p.waiters { - close(waiter) - } - p.waiters = nil -} - -func maxQueueFromEnv(defaultSize int) int { - for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} { - raw := strings.TrimSpace(os.Getenv(key)) - if raw == "" { - continue - } - n, err := strconv.Atoi(raw) - if err == nil && n >= 0 { - return n - } - } - if defaultSize < 0 { - return 0 - } - return defaultSize -} diff --git a/internal/account/pool_acquire.go b/internal/account/pool_acquire.go new file mode 100644 index 0000000..b0c548c --- /dev/null +++ b/internal/account/pool_acquire.go @@ -0,0 +1,108 @@ +package account + +import ( + "context" + + "ds2api/internal/config" +) + +func (p *Pool) Acquire(target string, exclude map[string]bool) (config.Account, bool) { + p.mu.Lock() + defer p.mu.Unlock() + return p.acquireLocked(target, normalizeExclude(exclude)) +} + +func (p *Pool) AcquireWait(ctx context.Context, target string, exclude map[string]bool) (config.Account, bool) { + if ctx == nil { + ctx = context.Background() + } + exclude = normalizeExclude(exclude) + for { + if ctx.Err() != nil { + return config.Account{}, false + } + + p.mu.Lock() + if acc, ok := p.acquireLocked(target, exclude); ok { + p.mu.Unlock() + return acc, true + } + if !p.canQueueLocked(target, exclude) { + p.mu.Unlock() + return config.Account{}, false + } + waiter := 
make(chan struct{}) + p.waiters = append(p.waiters, waiter) + p.mu.Unlock() + + select { + case <-ctx.Done(): + p.mu.Lock() + p.removeWaiterLocked(waiter) + p.mu.Unlock() + return config.Account{}, false + case <-waiter: + } + } +} + +func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Account, bool) { + if target != "" { + if exclude[target] || !p.canAcquireIDLocked(target) { + return config.Account{}, false + } + acc, ok := p.store.FindAccount(target) + if !ok { + return config.Account{}, false + } + p.inUse[target]++ + p.bumpQueue(target) + return acc, true + } + + if acc, ok := p.tryAcquire(exclude, true); ok { + return acc, true + } + if acc, ok := p.tryAcquire(exclude, false); ok { + return acc, true + } + return config.Account{}, false +} + +func (p *Pool) tryAcquire(exclude map[string]bool, requireToken bool) (config.Account, bool) { + for i := 0; i < len(p.queue); i++ { + id := p.queue[i] + if exclude[id] || !p.canAcquireIDLocked(id) { + continue + } + acc, ok := p.store.FindAccount(id) + if !ok { + continue + } + if requireToken && acc.Token == "" { + continue + } + p.inUse[id]++ + p.bumpQueue(id) + return acc, true + } + return config.Account{}, false +} + +func (p *Pool) bumpQueue(accountID string) { + for i, id := range p.queue { + if id != accountID { + continue + } + p.queue = append(p.queue[:i], p.queue[i+1:]...) 
+ p.queue = append(p.queue, accountID) + return + } +} + +func normalizeExclude(exclude map[string]bool) map[string]bool { + if exclude == nil { + return map[string]bool{} + } + return exclude +} diff --git a/internal/account/pool_core.go b/internal/account/pool_core.go new file mode 100644 index 0000000..90e2594 --- /dev/null +++ b/internal/account/pool_core.go @@ -0,0 +1,132 @@ +package account + +import ( + "sort" + "sync" + + "ds2api/internal/config" +) + +type Pool struct { + store *config.Store + mu sync.Mutex + queue []string + inUse map[string]int + waiters []chan struct{} + maxInflightPerAccount int + recommendedConcurrency int + maxQueueSize int + globalMaxInflight int +} + +func NewPool(store *config.Store) *Pool { + maxPer := 2 + if store != nil { + maxPer = store.RuntimeAccountMaxInflight() + } + p := &Pool{ + store: store, + inUse: map[string]int{}, + maxInflightPerAccount: maxPer, + } + p.Reset() + return p +} + +func (p *Pool) Reset() { + accounts := p.store.Accounts() + sort.SliceStable(accounts, func(i, j int) bool { + iHas := accounts[i].Token != "" + jHas := accounts[j].Token != "" + if iHas == jHas { + return i < j + } + return iHas + }) + ids := make([]string, 0, len(accounts)) + for _, a := range accounts { + id := a.Identifier() + if id != "" { + ids = append(ids, id) + } + } + if p.store != nil { + p.maxInflightPerAccount = p.store.RuntimeAccountMaxInflight() + } else { + p.maxInflightPerAccount = maxInflightFromEnv() + } + recommended := defaultRecommendedConcurrency(len(ids), p.maxInflightPerAccount) + queueLimit := maxQueueFromEnv(recommended) + globalLimit := recommended + if p.store != nil { + queueLimit = p.store.RuntimeAccountMaxQueue(recommended) + globalLimit = p.store.RuntimeGlobalMaxInflight(recommended) + } + p.mu.Lock() + defer p.mu.Unlock() + p.drainWaitersLocked() + p.queue = ids + p.inUse = map[string]int{} + p.recommendedConcurrency = recommended + p.maxQueueSize = queueLimit + p.globalMaxInflight = globalLimit + 
config.Logger.Info( + "[init_account_queue] initialized", + "total", len(ids), + "max_inflight_per_account", p.maxInflightPerAccount, + "global_max_inflight", p.globalMaxInflight, + "recommended_concurrency", p.recommendedConcurrency, + "max_queue_size", p.maxQueueSize, + ) +} + +func (p *Pool) Release(accountID string) { + if accountID == "" { + return + } + p.mu.Lock() + defer p.mu.Unlock() + count := p.inUse[accountID] + if count <= 0 { + return + } + if count == 1 { + delete(p.inUse, accountID) + p.notifyWaiterLocked() + return + } + p.inUse[accountID] = count - 1 + p.notifyWaiterLocked() +} + +func (p *Pool) Status() map[string]any { + p.mu.Lock() + defer p.mu.Unlock() + available := make([]string, 0, len(p.queue)) + inUseAccounts := make([]string, 0, len(p.inUse)) + inUseSlots := 0 + for _, id := range p.queue { + if p.inUse[id] < p.maxInflightPerAccount { + available = append(available, id) + } + } + for id, count := range p.inUse { + if count > 0 { + inUseAccounts = append(inUseAccounts, id) + inUseSlots += count + } + } + sort.Strings(inUseAccounts) + return map[string]any{ + "available": len(available), + "in_use": inUseSlots, + "total": len(p.store.Accounts()), + "available_accounts": available, + "in_use_accounts": inUseAccounts, + "max_inflight_per_account": p.maxInflightPerAccount, + "global_max_inflight": p.globalMaxInflight, + "recommended_concurrency": p.recommendedConcurrency, + "waiting": len(p.waiters), + "max_queue_size": p.maxQueueSize, + } +} diff --git a/internal/account/pool_edge_test.go b/internal/account/pool_edge_test.go new file mode 100644 index 0000000..6e90823 --- /dev/null +++ b/internal/account/pool_edge_test.go @@ -0,0 +1,249 @@ +package account + +import ( + "context" + "sync" + "testing" + "time" + + "ds2api/internal/config" +) + +// ─── Pool edge cases ───────────────────────────────────────────────── + +func TestPoolEmptyNoAccounts(t *testing.T) { + t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "2") + 
t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "") + t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "") + t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "") + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) + pool := NewPool(config.LoadStore()) + if _, ok := pool.Acquire("", nil); ok { + t.Fatal("expected acquire to fail with no accounts") + } + status := pool.Status() + if total, ok := status["total"].(int); !ok || total != 0 { + t.Fatalf("unexpected total: %#v", status["total"]) + } +} + +func TestPoolReleaseNonExistentAccount(t *testing.T) { + pool := newPoolForTest(t, "2") + pool.Release("nonexistent@example.com") // should not panic +} + +func TestPoolReleaseAlreadyReleased(t *testing.T) { + pool := newPoolForTest(t, "2") + acc, ok := pool.Acquire("", nil) + if !ok { + t.Fatal("expected acquire success") + } + pool.Release(acc.Identifier()) + pool.Release(acc.Identifier()) // double release should not panic +} + +func TestPoolAcquireTargetNotFound(t *testing.T) { + pool := newPoolForTest(t, "2") + if _, ok := pool.Acquire("nonexistent@example.com", nil); ok { + t.Fatal("expected acquire to fail for non-existent target") + } +} + +func TestPoolAcquireWithExclusionList(t *testing.T) { + pool := newPoolForTest(t, "2") + acc, ok := pool.Acquire("", map[string]bool{"acc1@example.com": true}) + if !ok { + t.Fatal("expected acquire success with exclusion") + } + if acc.Identifier() != "acc2@example.com" { + t.Fatalf("expected acc2 when acc1 excluded, got %q", acc.Identifier()) + } + pool.Release(acc.Identifier()) +} + +func TestPoolAcquireAllExcluded(t *testing.T) { + pool := newPoolForTest(t, "2") + if _, ok := pool.Acquire("", map[string]bool{ + "acc1@example.com": true, + "acc2@example.com": true, + }); ok { + t.Fatal("expected acquire to fail when all accounts excluded") + } +} + +func TestPoolStatusFields(t *testing.T) { + pool := newPoolForTest(t, "2") + status := pool.Status() + + // Check all expected fields are present + for _, key := range []string{"total", 
"available", "max_inflight_per_account", "recommended_concurrency", "available_accounts", "in_use_accounts", "waiting", "max_queue_size"} { + if _, ok := status[key]; !ok { + t.Fatalf("missing status field: %s", key) + } + } +} + +func TestPoolStatusAccountDetails(t *testing.T) { + pool := newPoolForTest(t, "2") + acc, _ := pool.Acquire("acc1@example.com", nil) + + status := pool.Status() + inUseAccounts, ok := status["in_use_accounts"].([]string) + if !ok { + t.Fatalf("unexpected in_use_accounts type: %T", status["in_use_accounts"]) + } + found := false + for _, id := range inUseAccounts { + if id == "acc1@example.com" { + found = true + break + } + } + if !found { + t.Fatalf("expected acc1 in in_use_accounts, got %v", inUseAccounts) + } + if status["in_use"] != 1 { + t.Fatalf("expected 1 in_use, got %v", status["in_use"]) + } + + pool.Release(acc.Identifier()) +} + +func TestPoolAcquireWaitContextCancelled(t *testing.T) { + pool := newSingleAccountPoolForTest(t, "1") + // Exhaust the pool + first, ok := pool.Acquire("", nil) + if !ok { + t.Fatal("expected first acquire to succeed") + } + + ctx, cancel := context.WithCancel(context.Background()) + + var wg sync.WaitGroup + wg.Add(1) + var waitOK bool + go func() { + defer wg.Done() + _, waitOK = pool.AcquireWait(ctx, "", nil) + }() + + // Wait until queued + waitForWaitingCount(t, pool, 1) + + // Cancel context + cancel() + + wg.Wait() + if waitOK { + t.Fatal("expected acquire to fail after context cancellation") + } + + pool.Release(first.Identifier()) +} + +func TestPoolAcquireWaitTargetAccount(t *testing.T) { + pool := newPoolForTest(t, "1") + // Exhaust acc1 + acc1, ok := pool.Acquire("acc1@example.com", nil) + if !ok { + t.Fatal("expected acquire acc1 success") + } + + // Acquire acc2 directly (should succeed since acc2 is free) + ctx := context.Background() + acc2, ok := pool.AcquireWait(ctx, "acc2@example.com", nil) + if !ok { + t.Fatal("expected acquire acc2 success via AcquireWait") + } + if 
acc2.Identifier() != "acc2@example.com" { + t.Fatalf("expected acc2, got %q", acc2.Identifier()) + } + + pool.Release(acc1.Identifier()) + pool.Release(acc2.Identifier()) +} + +func TestPoolMaxQueueSizeOverride(t *testing.T) { + t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1") + t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "") + t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "5") + t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "") + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"acc1@example.com","token":"t1"}]}`) + pool := NewPool(config.LoadStore()) + status := pool.Status() + if got, ok := status["max_queue_size"].(int); !ok || got != 5 { + t.Fatalf("expected max_queue_size=5, got %#v", status["max_queue_size"]) + } +} + +func TestPoolQueueSizeAliasEnv(t *testing.T) { + t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1") + t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "") + t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "") + t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "7") + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"acc1@example.com","token":"t1"}]}`) + pool := NewPool(config.LoadStore()) + status := pool.Status() + if got, ok := status["max_queue_size"].(int); !ok || got != 7 { + t.Fatalf("expected max_queue_size=7, got %#v", status["max_queue_size"]) + } +} + +func TestPoolMultipleAcquireReleaseCycles(t *testing.T) { + pool := newSingleAccountPoolForTest(t, "1") + for i := 0; i < 10; i++ { + acc, ok := pool.Acquire("", nil) + if !ok { + t.Fatalf("acquire failed at cycle %d", i) + } + pool.Release(acc.Identifier()) + } +} + +func TestPoolConcurrentAcquireWait(t *testing.T) { + pool := newSingleAccountPoolForTest(t, "1") + first, ok := pool.Acquire("", nil) + if !ok { + t.Fatal("expected first acquire success") + } + + const waiters = 3 + results := make(chan bool, waiters) + + for i := 0; i < waiters; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _, ok := pool.AcquireWait(ctx, "", nil) + results <- ok + 
}() + } + + // Wait for all to be queued (only 1 can queue) + time.Sleep(50 * time.Millisecond) + + // Release and allow queued requests to proceed + pool.Release(first.Identifier()) + + successCount := 0 + timeoutCount := 0 + for i := 0; i < waiters; i++ { + select { + case ok := <-results: + if ok { + successCount++ + // Release for next waiter + pool.Release("acc1@example.com") + } else { + timeoutCount++ + } + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for results") + } + } + + // At least 1 should succeed; 2 may fail due to queue limit + if successCount < 1 { + t.Fatalf("expected at least 1 success, got success=%d timeout=%d", successCount, timeoutCount) + } +} diff --git a/internal/account/pool_limits.go b/internal/account/pool_limits.go new file mode 100644 index 0000000..0f0854f --- /dev/null +++ b/internal/account/pool_limits.go @@ -0,0 +1,91 @@ +package account + +import ( + "os" + "strconv" + "strings" +) + +func (p *Pool) ApplyRuntimeLimits(maxInflightPerAccount, maxQueueSize, globalMaxInflight int) { + if maxInflightPerAccount <= 0 { + maxInflightPerAccount = 1 + } + if maxQueueSize < 0 { + maxQueueSize = 0 + } + if globalMaxInflight <= 0 { + globalMaxInflight = maxInflightPerAccount * len(p.store.Accounts()) + if globalMaxInflight <= 0 { + globalMaxInflight = maxInflightPerAccount + } + } + p.mu.Lock() + defer p.mu.Unlock() + p.maxInflightPerAccount = maxInflightPerAccount + p.maxQueueSize = maxQueueSize + p.globalMaxInflight = globalMaxInflight + p.recommendedConcurrency = defaultRecommendedConcurrency(len(p.queue), p.maxInflightPerAccount) + p.notifyWaiterLocked() +} + +func maxInflightFromEnv() int { + for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} { + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + continue + } + n, err := strconv.Atoi(raw) + if err == nil && n > 0 { + return n + } + } + return 2 +} + +func defaultRecommendedConcurrency(accountCount, maxInflightPerAccount 
int) int { + if accountCount <= 0 { + return 0 + } + if maxInflightPerAccount <= 0 { + maxInflightPerAccount = 2 + } + return accountCount * maxInflightPerAccount +} + +func maxQueueFromEnv(defaultSize int) int { + for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} { + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + continue + } + n, err := strconv.Atoi(raw) + if err == nil && n >= 0 { + return n + } + } + if defaultSize < 0 { + return 0 + } + return defaultSize +} + +func (p *Pool) canAcquireIDLocked(accountID string) bool { + if accountID == "" { + return false + } + if p.inUse[accountID] >= p.maxInflightPerAccount { + return false + } + if p.globalMaxInflight > 0 && p.currentInUseLocked() >= p.globalMaxInflight { + return false + } + return true +} + +func (p *Pool) currentInUseLocked() int { + total := 0 + for _, n := range p.inUse { + total += n + } + return total +} diff --git a/internal/account/pool_waiters.go b/internal/account/pool_waiters.go new file mode 100644 index 0000000..40bd146 --- /dev/null +++ b/internal/account/pool_waiters.go @@ -0,0 +1,43 @@ +package account + +func (p *Pool) canQueueLocked(target string, exclude map[string]bool) bool { + if target != "" { + if exclude[target] { + return false + } + if _, ok := p.store.FindAccount(target); !ok { + return false + } + } + if p.maxQueueSize <= 0 { + return false + } + return len(p.waiters) < p.maxQueueSize +} + +func (p *Pool) notifyWaiterLocked() { + if len(p.waiters) == 0 { + return + } + waiter := p.waiters[0] + p.waiters = p.waiters[1:] + close(waiter) +} + +func (p *Pool) removeWaiterLocked(waiter chan struct{}) bool { + for i, w := range p.waiters { + if w != waiter { + continue + } + p.waiters = append(p.waiters[:i], p.waiters[i+1:]...) 
+ return true + } + return false +} + +func (p *Pool) drainWaitersLocked() { + for _, waiter := range p.waiters { + close(waiter) + } + p.waiters = nil +} diff --git a/internal/adapter/claude/convert.go b/internal/adapter/claude/convert.go new file mode 100644 index 0000000..dbb5e1a --- /dev/null +++ b/internal/adapter/claude/convert.go @@ -0,0 +1,11 @@ +package claude + +import ( + "ds2api/internal/claudeconv" +) + +const defaultClaudeModel = "claude-sonnet-4-5" + +func convertClaudeToDeepSeek(claudeReq map[string]any, store ConfigReader) map[string]any { + return claudeconv.ConvertClaudeToDeepSeek(claudeReq, store, defaultClaudeModel) +} diff --git a/internal/adapter/claude/deps.go b/internal/adapter/claude/deps.go new file mode 100644 index 0000000..73203b2 --- /dev/null +++ b/internal/adapter/claude/deps.go @@ -0,0 +1,29 @@ +package claude + +import ( + "context" + "net/http" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/deepseek" +) + +type AuthResolver interface { + Determine(req *http.Request) (*auth.RequestAuth, error) + Release(a *auth.RequestAuth) +} + +type DeepSeekCaller interface { + CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) +} + +type ConfigReader interface { + ClaudeMapping() map[string]string +} + +var _ AuthResolver = (*auth.Resolver)(nil) +var _ DeepSeekCaller = (*deepseek.Client)(nil) +var _ ConfigReader = (*config.Store)(nil) diff --git a/internal/adapter/claude/deps_injection_test.go b/internal/adapter/claude/deps_injection_test.go new file mode 100644 index 0000000..39dfc2f --- /dev/null +++ b/internal/adapter/claude/deps_injection_test.go @@ -0,0 +1,33 @@ +package claude + +import "testing" + +type mockClaudeConfig struct { + m map[string]string +} + 
+func (m mockClaudeConfig) ClaudeMapping() map[string]string { return m.m } + +func TestNormalizeClaudeRequestUsesConfigInterfaceMapping(t *testing.T) { + req := map[string]any{ + "model": "claude-opus-4-6", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + out, err := normalizeClaudeRequest(mockClaudeConfig{ + m: map[string]string{ + "fast": "deepseek-chat", + "slow": "deepseek-reasoner-search", + }, + }, req) + if err != nil { + t.Fatalf("normalizeClaudeRequest error: %v", err) + } + if out.Standard.ResolvedModel != "deepseek-reasoner-search" { + t.Fatalf("resolved model mismatch: got=%q", out.Standard.ResolvedModel) + } + if !out.Standard.Thinking || !out.Standard.Search { + t.Fatalf("unexpected flags: thinking=%v search=%v", out.Standard.Thinking, out.Standard.Search) + } +} diff --git a/internal/adapter/claude/error_shape_test.go b/internal/adapter/claude/error_shape_test.go new file mode 100644 index 0000000..b9dc469 --- /dev/null +++ b/internal/adapter/claude/error_shape_test.go @@ -0,0 +1,34 @@ +package claude + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestWriteClaudeErrorIncludesUnifiedFields(t *testing.T) { + rec := httptest.NewRecorder() + writeClaudeError(rec, http.StatusUnauthorized, "bad token") + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", rec.Code) + } + + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("decode body: %v", err) + } + errObj, _ := body["error"].(map[string]any) + if errObj["message"] != "bad token" { + t.Fatalf("unexpected message: %v", errObj["message"]) + } + if errObj["type"] != "invalid_request_error" { + t.Fatalf("unexpected type: %v", errObj["type"]) + } + if errObj["code"] != "authentication_failed" { + t.Fatalf("unexpected code: %v", errObj["code"]) + } + if _, ok := errObj["param"]; !ok { + t.Fatal("expected param field") + } +} diff --git 
a/internal/adapter/claude/handler.go b/internal/adapter/claude/handler.go deleted file mode 100644 index b9ecd27..0000000 --- a/internal/adapter/claude/handler.go +++ /dev/null @@ -1,603 +0,0 @@ -package claude - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - "time" - - "github.com/go-chi/chi/v5" - - "ds2api/internal/auth" - "ds2api/internal/config" - "ds2api/internal/deepseek" - "ds2api/internal/sse" - "ds2api/internal/util" -) - -// writeJSON is a package-internal alias to avoid mass-renaming all call-sites. -var writeJSON = util.WriteJSON - -type Handler struct { - Store *config.Store - Auth *auth.Resolver - DS *deepseek.Client -} - -var ( - claudeStreamPingInterval = time.Duration(deepseek.KeepAliveTimeout) * time.Second - claudeStreamIdleTimeout = time.Duration(deepseek.StreamIdleTimeout) * time.Second - claudeStreamMaxKeepaliveCnt = deepseek.MaxKeepaliveCount -) - -func RegisterRoutes(r chi.Router, h *Handler) { - r.Get("/anthropic/v1/models", h.ListModels) - r.Post("/anthropic/v1/messages", h.Messages) - r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens) -} - -func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) { - writeJSON(w, http.StatusOK, config.ClaudeModelsResponse()) -} - -func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { - a, err := h.Auth.Determine(r) - if err != nil { - status := http.StatusUnauthorized - detail := err.Error() - if err == auth.ErrNoAccount { - status = http.StatusTooManyRequests - } - writeJSON(w, status, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": detail}}) - return - } - defer h.Auth.Release(a) - - var req map[string]any - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "invalid json"}}) - return - } - model, _ := req["model"].(string) - messagesRaw, _ := req["messages"].([]any) - if model == 
"" || len(messagesRaw) == 0 { - writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "Request must include 'model' and 'messages'."}}) - return - } - - normalized := normalizeClaudeMessages(messagesRaw) - payload := cloneMap(req) - payload["messages"] = normalized - toolsRequested, _ := req["tools"].([]any) - if len(toolsRequested) > 0 && !hasSystemMessage(normalized) { - payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalized...) - } - - dsPayload := util.ConvertClaudeToDeepSeek(payload, h.Store) - dsModel, _ := dsPayload["model"].(string) - thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel) - if !ok { - thinkingEnabled = false - searchEnabled = false - } - finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"])) - - sessionID, err := h.DS.CreateSession(r.Context(), a, 3) - if err != nil { - writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "invalid token."}}) - return - } - pow, err := h.DS.GetPow(r.Context(), a, 3) - if err != nil { - writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get PoW"}}) - return - } - requestPayload := map[string]any{ - "chat_session_id": sessionID, - "parent_message_id": nil, - "prompt": finalPrompt, - "ref_file_ids": []any{}, - "thinking_enabled": thinkingEnabled, - "search_enabled": searchEnabled, - } - resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3) - if err != nil { - writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get Claude response."}}) - return - } - if resp.StatusCode != http.StatusOK { - defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) - writeJSON(w, http.StatusInternalServerError, map[string]any{"error": 
map[string]any{"type": "api_error", "message": string(body)}}) - return - } - - toolNames := extractClaudeToolNames(toolsRequested) - if util.ToBool(req["stream"]) { - h.handleClaudeStreamRealtime(w, r, resp, model, normalized, thinkingEnabled, searchEnabled, toolNames) - return - } - result := sse.CollectStream(resp, thinkingEnabled, true) - fullText := result.Text - fullThinking := result.Thinking - detected := util.ParseToolCalls(fullText, toolNames) - content := make([]map[string]any, 0, 4) - if fullThinking != "" { - content = append(content, map[string]any{"type": "thinking", "thinking": fullThinking}) - } - stopReason := "end_turn" - if len(detected) > 0 { - stopReason = "tool_use" - for i, tc := range detected { - content = append(content, map[string]any{ - "type": "tool_use", - "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i), - "name": tc.Name, - "input": tc.Input, - }) - } - } else { - if fullText == "" { - fullText = "抱歉,没有生成有效的响应内容。" - } - content = append(content, map[string]any{"type": "text", "text": fullText}) - } - writeJSON(w, http.StatusOK, map[string]any{ - "id": fmt.Sprintf("msg_%d", time.Now().UnixNano()), - "type": "message", - "role": "assistant", - "model": model, - "content": content, - "stop_reason": stopReason, - "stop_sequence": nil, - "usage": map[string]any{ - "input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalized)), - "output_tokens": util.EstimateTokens(fullThinking) + util.EstimateTokens(fullText), - }, - }) -} - -func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) { - a, err := h.Auth.Determine(r) - if err != nil { - writeJSON(w, http.StatusUnauthorized, map[string]any{"error": err.Error()}) - return - } - defer h.Auth.Release(a) - - var req map[string]any - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"error": "invalid json"}) - return - } - model, _ := req["model"].(string) - messages, _ := req["messages"].([]any) - if model 
== "" || len(messages) == 0 { - writeJSON(w, http.StatusBadRequest, map[string]any{"error": "Request must include 'model' and 'messages'."}) - return - } - inputTokens := 0 - if sys, ok := req["system"].(string); ok { - inputTokens += util.EstimateTokens(sys) - } - for _, item := range messages { - msg, ok := item.(map[string]any) - if !ok { - continue - } - inputTokens += 2 - inputTokens += util.EstimateTokens(extractMessageContent(msg["content"])) - } - if tools, ok := req["tools"].([]any); ok { - for _, t := range tools { - b, _ := json.Marshal(t) - inputTokens += util.EstimateTokens(string(b)) - } - } - if inputTokens < 1 { - inputTokens = 1 - } - writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens}) -} - -func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) { - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}}) - return - } - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache, no-transform") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("X-Accel-Buffering", "no") - rc := http.NewResponseController(w) - canFlush := rc.Flush() == nil - if !canFlush { - config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered") - } - send := func(event string, v any) { - b, _ := json.Marshal(v) - _, _ = w.Write([]byte("event: ")) - _, _ = w.Write([]byte(event)) - _, _ = w.Write([]byte("\n")) - _, _ = w.Write([]byte("data: ")) - _, _ = w.Write(b) - _, _ = w.Write([]byte("\n\n")) - if canFlush { - _ = rc.Flush() - } - } - sendError := func(message string) { - msg := strings.TrimSpace(message) - if msg == "" { - msg = "upstream stream error" 
- } - send("error", map[string]any{ - "type": "error", - "error": map[string]any{ - "type": "api_error", - "message": msg, - }, - }) - } - - messageID := fmt.Sprintf("msg_%d", time.Now().UnixNano()) - inputTokens := util.EstimateTokens(fmt.Sprintf("%v", messages)) - send("message_start", map[string]any{ - "type": "message_start", - "message": map[string]any{ - "id": messageID, - "type": "message", - "role": "assistant", - "model": model, - "content": []any{}, - "stop_reason": nil, - "stop_sequence": nil, - "usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0}, - }, - }) - - initialType := "text" - if thinkingEnabled { - initialType = "thinking" - } - parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType) - bufferToolContent := len(toolNames) > 0 - hasContent := false - lastContent := time.Now() - keepaliveCount := 0 - - thinking := strings.Builder{} - text := strings.Builder{} - - nextBlockIndex := 0 - thinkingBlockOpen := false - thinkingBlockIndex := -1 - textBlockOpen := false - textBlockIndex := -1 - ended := false - - closeThinkingBlock := func() { - if !thinkingBlockOpen { - return - } - send("content_block_stop", map[string]any{ - "type": "content_block_stop", - "index": thinkingBlockIndex, - }) - thinkingBlockOpen = false - thinkingBlockIndex = -1 - } - closeTextBlock := func() { - if !textBlockOpen { - return - } - send("content_block_stop", map[string]any{ - "type": "content_block_stop", - "index": textBlockIndex, - }) - textBlockOpen = false - textBlockIndex = -1 - } - - finalize := func(stopReason string) { - if ended { - return - } - ended = true - - closeThinkingBlock() - closeTextBlock() - - finalThinking := thinking.String() - finalText := text.String() - - if bufferToolContent { - detected := util.ParseToolCalls(finalText, toolNames) - if len(detected) > 0 { - stopReason = "tool_use" - for i, tc := range detected { - idx := nextBlockIndex + i - send("content_block_start", 
map[string]any{ - "type": "content_block_start", - "index": idx, - "content_block": map[string]any{ - "type": "tool_use", - "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx), - "name": tc.Name, - "input": tc.Input, - }, - }) - send("content_block_stop", map[string]any{ - "type": "content_block_stop", - "index": idx, - }) - } - nextBlockIndex += len(detected) - } else if finalText != "" { - idx := nextBlockIndex - nextBlockIndex++ - send("content_block_start", map[string]any{ - "type": "content_block_start", - "index": idx, - "content_block": map[string]any{ - "type": "text", - "text": "", - }, - }) - send("content_block_delta", map[string]any{ - "type": "content_block_delta", - "index": idx, - "delta": map[string]any{ - "type": "text_delta", - "text": finalText, - }, - }) - send("content_block_stop", map[string]any{ - "type": "content_block_stop", - "index": idx, - }) - } - } - - outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText) - send("message_delta", map[string]any{ - "type": "message_delta", - "delta": map[string]any{ - "stop_reason": stopReason, - "stop_sequence": nil, - }, - "usage": map[string]any{ - "output_tokens": outputTokens, - }, - }) - send("message_stop", map[string]any{"type": "message_stop"}) - } - - pingTicker := time.NewTicker(claudeStreamPingInterval) - defer pingTicker.Stop() - - for { - select { - case <-r.Context().Done(): - return - case <-pingTicker.C: - if !hasContent { - keepaliveCount++ - if keepaliveCount >= claudeStreamMaxKeepaliveCnt { - finalize("end_turn") - return - } - } - if hasContent && time.Since(lastContent) > claudeStreamIdleTimeout { - finalize("end_turn") - return - } - send("ping", map[string]any{"type": "ping"}) - case parsed, ok := <-parsedLines: - if !ok { - if err := <-done; err != nil { - sendError(err.Error()) - return - } - finalize("end_turn") - return - } - if !parsed.Parsed { - continue - } - if parsed.ErrorMessage != "" { - sendError(parsed.ErrorMessage) - return - } - if 
parsed.Stop { - finalize("end_turn") - return - } - - for _, p := range parsed.Parts { - if p.Text == "" { - continue - } - if p.Type != "thinking" && searchEnabled && sse.IsCitation(p.Text) { - continue - } - - hasContent = true - lastContent = time.Now() - keepaliveCount = 0 - - if p.Type == "thinking" { - if !thinkingEnabled { - continue - } - thinking.WriteString(p.Text) - closeTextBlock() - if !thinkingBlockOpen { - thinkingBlockIndex = nextBlockIndex - nextBlockIndex++ - send("content_block_start", map[string]any{ - "type": "content_block_start", - "index": thinkingBlockIndex, - "content_block": map[string]any{ - "type": "thinking", - "thinking": "", - }, - }) - thinkingBlockOpen = true - } - send("content_block_delta", map[string]any{ - "type": "content_block_delta", - "index": thinkingBlockIndex, - "delta": map[string]any{ - "type": "thinking_delta", - "thinking": p.Text, - }, - }) - continue - } - - text.WriteString(p.Text) - if bufferToolContent { - continue - } - closeThinkingBlock() - if !textBlockOpen { - textBlockIndex = nextBlockIndex - nextBlockIndex++ - send("content_block_start", map[string]any{ - "type": "content_block_start", - "index": textBlockIndex, - "content_block": map[string]any{ - "type": "text", - "text": "", - }, - }) - textBlockOpen = true - } - send("content_block_delta", map[string]any{ - "type": "content_block_delta", - "index": textBlockIndex, - "delta": map[string]any{ - "type": "text_delta", - "text": p.Text, - }, - }) - } - } - } -} - -func normalizeClaudeMessages(messages []any) []any { - out := make([]any, 0, len(messages)) - for _, m := range messages { - msg, ok := m.(map[string]any) - if !ok { - continue - } - copied := cloneMap(msg) - switch content := msg["content"].(type) { - case []any: - parts := make([]string, 0, len(content)) - for _, block := range content { - b, ok := block.(map[string]any) - if !ok { - continue - } - typeStr, _ := b["type"].(string) - if typeStr == "text" { - if t, ok := b["text"].(string); ok { 
- parts = append(parts, t) - } - } - if typeStr == "tool_result" { - parts = append(parts, fmt.Sprintf("%v", b["content"])) - } - } - copied["content"] = strings.Join(parts, "\n") - } - out = append(out, copied) - } - return out -} - -func buildClaudeToolPrompt(tools []any) string { - parts := []string{"You are Claude, a helpful AI assistant. You have access to these tools:"} - for _, t := range tools { - m, ok := t.(map[string]any) - if !ok { - continue - } - name, _ := m["name"].(string) - desc, _ := m["description"].(string) - schema, _ := json.Marshal(m["input_schema"]) - parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema)) - } - parts = append(parts, "When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}") - return strings.Join(parts, "\n\n") -} - -func hasSystemMessage(messages []any) bool { - for _, m := range messages { - msg, ok := m.(map[string]any) - if ok && msg["role"] == "system" { - return true - } - } - return false -} - -func extractClaudeToolNames(tools []any) []string { - out := make([]string, 0, len(tools)) - for _, t := range tools { - m, ok := t.(map[string]any) - if !ok { - continue - } - if name, ok := m["name"].(string); ok && name != "" { - out = append(out, name) - } - } - return out -} - -func toMessageMaps(v any) []map[string]any { - arr, ok := v.([]any) - if !ok { - return nil - } - out := make([]map[string]any, 0, len(arr)) - for _, item := range arr { - if m, ok := item.(map[string]any); ok { - out = append(out, m) - } - } - return out -} - -func extractMessageContent(v any) string { - switch x := v.(type) { - case string: - return x - case []any: - parts := make([]string, 0, len(x)) - for _, it := range x { - parts = append(parts, fmt.Sprintf("%v", it)) - } - return strings.Join(parts, "\n") - default: - return fmt.Sprintf("%v", x) - } -} - -func cloneMap(in map[string]any) map[string]any { - 
out := make(map[string]any, len(in)) - for k, v := range in { - out[k] = v - } - return out -} diff --git a/internal/adapter/claude/handler_errors.go b/internal/adapter/claude/handler_errors.go new file mode 100644 index 0000000..f1188d6 --- /dev/null +++ b/internal/adapter/claude/handler_errors.go @@ -0,0 +1,25 @@ +package claude + +import "net/http" + +func writeClaudeError(w http.ResponseWriter, status int, message string) { + code := "invalid_request" + switch status { + case http.StatusUnauthorized: + code = "authentication_failed" + case http.StatusTooManyRequests: + code = "rate_limit_exceeded" + case http.StatusNotFound: + code = "not_found" + case http.StatusInternalServerError: + code = "internal_error" + } + writeJSON(w, status, map[string]any{ + "error": map[string]any{ + "type": "invalid_request_error", + "message": message, + "code": code, + "param": nil, + }, + }) +} diff --git a/internal/adapter/claude/handler_messages.go b/internal/adapter/claude/handler_messages.go new file mode 100644 index 0000000..1c4272b --- /dev/null +++ b/internal/adapter/claude/handler_messages.go @@ -0,0 +1,134 @@ +package claude + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "ds2api/internal/auth" + "ds2api/internal/config" + claudefmt "ds2api/internal/format/claude" + "ds2api/internal/sse" + streamengine "ds2api/internal/stream" +) + +func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { + if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" { + r.Header.Set("anthropic-version", "2023-06-01") + } + a, err := h.Auth.Determine(r) + if err != nil { + status := http.StatusUnauthorized + detail := err.Error() + if err == auth.ErrNoAccount { + status = http.StatusTooManyRequests + } + writeClaudeError(w, status, detail) + return + } + defer h.Auth.Release(a) + + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeClaudeError(w, http.StatusBadRequest, "invalid json") + return + 
} + norm, err := normalizeClaudeRequest(h.Store, req) + if err != nil { + writeClaudeError(w, http.StatusBadRequest, err.Error()) + return + } + stdReq := norm.Standard + + sessionID, err := h.DS.CreateSession(r.Context(), a, 3) + if err != nil { + writeClaudeError(w, http.StatusUnauthorized, "invalid token.") + return + } + pow, err := h.DS.GetPow(r.Context(), a, 3) + if err != nil { + writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW") + return + } + requestPayload := stdReq.CompletionPayload(sessionID) + resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3) + if err != nil { + writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.") + return + } + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + writeClaudeError(w, http.StatusInternalServerError, string(body)) + return + } + + if stdReq.Stream { + h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) + return + } + result := sse.CollectStream(resp, stdReq.Thinking, true) + respBody := claudefmt.BuildMessageResponse( + fmt.Sprintf("msg_%d", time.Now().UnixNano()), + stdReq.ResponseModel, + norm.NormalizedMessages, + result.Thinking, + result.Text, + stdReq.ToolNames, + ) + writeJSON(w, http.StatusOK, respBody) +} + +func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + writeClaudeError(w, http.StatusInternalServerError, string(body)) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + rc := 
http.NewResponseController(w) + _, canFlush := w.(http.Flusher) + if !canFlush { + config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered") + } + + streamRuntime := newClaudeStreamRuntime( + w, + rc, + canFlush, + model, + messages, + thinkingEnabled, + searchEnabled, + toolNames, + ) + streamRuntime.sendMessageStart() + + initialType := "text" + if thinkingEnabled { + initialType = "thinking" + } + streamengine.ConsumeSSE(streamengine.ConsumeConfig{ + Context: r.Context(), + Body: resp.Body, + ThinkingEnabled: thinkingEnabled, + InitialType: initialType, + KeepAliveInterval: claudeStreamPingInterval, + IdleTimeout: claudeStreamIdleTimeout, + MaxKeepAliveNoInput: claudeStreamMaxKeepaliveCnt, + }, streamengine.ConsumeHooks{ + OnKeepAlive: func() { + streamRuntime.sendPing() + }, + OnParsed: streamRuntime.onParsed, + OnFinalize: streamRuntime.onFinalize, + }) +} diff --git a/internal/adapter/claude/handler_routes.go b/internal/adapter/claude/handler_routes.go new file mode 100644 index 0000000..0376b2c --- /dev/null +++ b/internal/adapter/claude/handler_routes.go @@ -0,0 +1,41 @@ +package claude + +import ( + "net/http" + "time" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/config" + "ds2api/internal/deepseek" + "ds2api/internal/util" +) + +// writeJSON is a package-internal alias to avoid mass-renaming all call-sites. 
+var writeJSON = util.WriteJSON + +type Handler struct { + Store ConfigReader + Auth AuthResolver + DS DeepSeekCaller +} + +var ( + claudeStreamPingInterval = time.Duration(deepseek.KeepAliveTimeout) * time.Second + claudeStreamIdleTimeout = time.Duration(deepseek.StreamIdleTimeout) * time.Second + claudeStreamMaxKeepaliveCnt = deepseek.MaxKeepaliveCount +) + +func RegisterRoutes(r chi.Router, h *Handler) { + r.Get("/anthropic/v1/models", h.ListModels) + r.Post("/anthropic/v1/messages", h.Messages) + r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens) + r.Post("/v1/messages", h.Messages) + r.Post("/messages", h.Messages) + r.Post("/v1/messages/count_tokens", h.CountTokens) + r.Post("/messages/count_tokens", h.CountTokens) +} + +func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, config.ClaudeModelsResponse()) +} diff --git a/internal/adapter/claude/handler_tokens.go b/internal/adapter/claude/handler_tokens.go new file mode 100644 index 0000000..a369345 --- /dev/null +++ b/internal/adapter/claude/handler_tokens.go @@ -0,0 +1,51 @@ +package claude + +import ( + "encoding/json" + "net/http" + + "ds2api/internal/util" +) + +func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) { + a, err := h.Auth.Determine(r) + if err != nil { + writeClaudeError(w, http.StatusUnauthorized, err.Error()) + return + } + defer h.Auth.Release(a) + + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeClaudeError(w, http.StatusBadRequest, "invalid json") + return + } + model, _ := req["model"].(string) + messages, _ := req["messages"].([]any) + if model == "" || len(messages) == 0 { + writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") + return + } + inputTokens := 0 + if sys, ok := req["system"].(string); ok { + inputTokens += util.EstimateTokens(sys) + } + for _, item := range messages { + msg, ok := item.(map[string]any) + if !ok { + 
continue + } + inputTokens += 2 + inputTokens += util.EstimateTokens(extractMessageContent(msg["content"])) + } + if tools, ok := req["tools"].([]any); ok { + for _, t := range tools { + b, _ := json.Marshal(t) + inputTokens += util.EstimateTokens(string(b)) + } + } + if inputTokens < 1 { + inputTokens = 1 + } + writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens}) +} diff --git a/internal/adapter/claude/handler_util_test.go b/internal/adapter/claude/handler_util_test.go new file mode 100644 index 0000000..ae75d8e --- /dev/null +++ b/internal/adapter/claude/handler_util_test.go @@ -0,0 +1,350 @@ +package claude + +import ( + "strings" + "testing" +) + +// ─── normalizeClaudeMessages ───────────────────────────────────────── + +func TestNormalizeClaudeMessagesSimpleString(t *testing.T) { + msgs := []any{ + map[string]any{"role": "user", "content": "Hello"}, + } + got := normalizeClaudeMessages(msgs) + if len(got) != 1 { + t.Fatalf("expected 1 message, got %d", len(got)) + } + m := got[0].(map[string]any) + if m["content"] != "Hello" { + t.Fatalf("expected 'Hello', got %v", m["content"]) + } +} + +func TestNormalizeClaudeMessagesArrayContent(t *testing.T) { + msgs := []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "line1"}, + map[string]any{"type": "text", "text": "line2"}, + }, + }, + } + got := normalizeClaudeMessages(msgs) + m := got[0].(map[string]any) + if m["content"] != "line1\nline2" { + t.Fatalf("expected joined text, got %q", m["content"]) + } +} + +func TestNormalizeClaudeMessagesToolResult(t *testing.T) { + msgs := []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "tool_result", "content": "tool output"}, + }, + }, + } + got := normalizeClaudeMessages(msgs) + m := got[0].(map[string]any) + content, _ := m["content"].(string) + if !strings.Contains(content, "[TOOL_RESULT_HISTORY]") || !strings.Contains(content, "content: tool output") { + 
t.Fatalf("expected serialized tool result marker, got %q", content) + } +} + +func TestNormalizeClaudeMessagesSkipsNonMap(t *testing.T) { + msgs := []any{"not a map", 42} + got := normalizeClaudeMessages(msgs) + if len(got) != 0 { + t.Fatalf("expected 0 messages for non-map items, got %d", len(got)) + } +} + +func TestNormalizeClaudeMessagesEmpty(t *testing.T) { + got := normalizeClaudeMessages(nil) + if len(got) != 0 { + t.Fatalf("expected 0, got %d", len(got)) + } +} + +func TestNormalizeClaudeMessagesPreservesRole(t *testing.T) { + msgs := []any{ + map[string]any{"role": "assistant", "content": "response"}, + } + got := normalizeClaudeMessages(msgs) + m := got[0].(map[string]any) + if m["role"] != "assistant" { + t.Fatalf("expected 'assistant', got %q", m["role"]) + } +} + +func TestNormalizeClaudeMessagesMixedContentBlocks(t *testing.T) { + msgs := []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "Hello"}, + map[string]any{"type": "image", "source": "data:..."}, + map[string]any{"type": "text", "text": "World"}, + }, + }, + } + got := normalizeClaudeMessages(msgs) + m := got[0].(map[string]any) + if m["content"] != "Hello\nWorld" { + t.Fatalf("expected only text parts joined, got %q", m["content"]) + } +} + +// ─── buildClaudeToolPrompt ─────────────────────────────────────────── + +func TestBuildClaudeToolPromptSingleTool(t *testing.T) { + tools := []any{ + map[string]any{ + "name": "search", + "description": "Search the web", + "input_schema": map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{"type": "string"}, + }, + }, + }, + } + prompt := buildClaudeToolPrompt(tools) + if prompt == "" { + t.Fatal("expected non-empty prompt") + } + // Should contain tool name and description + if !containsStr(prompt, "search") { + t.Fatalf("expected 'search' in prompt") + } + if !containsStr(prompt, "Search the web") { + t.Fatalf("expected description in prompt") + } + if 
!containsStr(prompt, "tool_calls") { + t.Fatalf("expected tool_calls instruction in prompt") + } +} + +func TestBuildClaudeToolPromptMultipleTools(t *testing.T) { + tools := []any{ + map[string]any{"name": "tool1", "description": "desc1"}, + map[string]any{"name": "tool2", "description": "desc2"}, + } + prompt := buildClaudeToolPrompt(tools) + if !containsStr(prompt, "tool1") || !containsStr(prompt, "tool2") { + t.Fatalf("expected both tools in prompt") + } +} + +func TestBuildClaudeToolPromptSkipsNonMap(t *testing.T) { + tools := []any{"not a map"} + prompt := buildClaudeToolPrompt(tools) + if prompt == "" { + t.Fatal("expected non-empty prompt even with invalid tools") + } + // Should still contain the intro and instruction + if !containsStr(prompt, "You are Claude") { + t.Fatalf("expected intro in prompt") + } +} + +// ─── hasSystemMessage ──────────────────────────────────────────────── + +func TestHasSystemMessageTrue(t *testing.T) { + msgs := []any{ + map[string]any{"role": "system", "content": "You are a helper"}, + map[string]any{"role": "user", "content": "Hi"}, + } + if !hasSystemMessage(msgs) { + t.Fatal("expected true") + } +} + +func TestHasSystemMessageFalse(t *testing.T) { + msgs := []any{ + map[string]any{"role": "user", "content": "Hi"}, + map[string]any{"role": "assistant", "content": "Hello"}, + } + if hasSystemMessage(msgs) { + t.Fatal("expected false") + } +} + +func TestHasSystemMessageEmpty(t *testing.T) { + if hasSystemMessage(nil) { + t.Fatal("expected false for nil") + } +} + +func TestHasSystemMessageNonMap(t *testing.T) { + msgs := []any{"not a map"} + if hasSystemMessage(msgs) { + t.Fatal("expected false for non-map") + } +} + +// ─── extractClaudeToolNames ────────────────────────────────────────── + +func TestExtractClaudeToolNamesSingle(t *testing.T) { + tools := []any{ + map[string]any{"name": "search"}, + } + names := extractClaudeToolNames(tools) + if len(names) != 1 || names[0] != "search" { + t.Fatalf("expected [search], got 
%v", names) + } +} + +func TestExtractClaudeToolNamesMultiple(t *testing.T) { + tools := []any{ + map[string]any{"name": "search"}, + map[string]any{"name": "calculate"}, + } + names := extractClaudeToolNames(tools) + if len(names) != 2 { + t.Fatalf("expected 2 names, got %v", names) + } +} + +func TestExtractClaudeToolNamesSkipsEmptyName(t *testing.T) { + tools := []any{ + map[string]any{"name": ""}, + map[string]any{"name": "valid"}, + } + names := extractClaudeToolNames(tools) + if len(names) != 1 || names[0] != "valid" { + t.Fatalf("expected [valid], got %v", names) + } +} + +func TestExtractClaudeToolNamesSkipsNonMap(t *testing.T) { + tools := []any{"not a map", 42} + names := extractClaudeToolNames(tools) + if len(names) != 0 { + t.Fatalf("expected 0, got %v", names) + } +} + +func TestExtractClaudeToolNamesNil(t *testing.T) { + names := extractClaudeToolNames(nil) + if len(names) != 0 { + t.Fatalf("expected 0, got %v", names) + } +} + +// ─── toMessageMaps ─────────────────────────────────────────────────── + +func TestToMessageMapsNormal(t *testing.T) { + input := []any{ + map[string]any{"role": "user", "content": "Hello"}, + } + got := toMessageMaps(input) + if len(got) != 1 { + t.Fatalf("expected 1, got %d", len(got)) + } +} + +func TestToMessageMapsNonSlice(t *testing.T) { + got := toMessageMaps("not a slice") + if got != nil { + t.Fatalf("expected nil, got %v", got) + } +} + +func TestToMessageMapsSkipsNonMap(t *testing.T) { + input := []any{"string", map[string]any{"role": "user"}, 42} + got := toMessageMaps(input) + if len(got) != 1 { + t.Fatalf("expected 1 map, got %d", len(got)) + } +} + +func TestToMessageMapsNil(t *testing.T) { + got := toMessageMaps(nil) + if got != nil { + t.Fatalf("expected nil, got %v", got) + } +} + +// ─── extractMessageContent ────────────────────────────────────────── + +func TestExtractMessageContentString(t *testing.T) { + if got := extractMessageContent("hello"); got != "hello" { + t.Fatalf("expected 'hello', got %q", 
got) + } +} + +func TestExtractMessageContentArray(t *testing.T) { + input := []any{"part1", "part2"} + got := extractMessageContent(input) + if got != "part1\npart2" { + t.Fatalf("expected joined, got %q", got) + } +} + +func TestExtractMessageContentOther(t *testing.T) { + got := extractMessageContent(42) + if got != "42" { + t.Fatalf("expected '42', got %q", got) + } +} + +func TestExtractMessageContentNil(t *testing.T) { + got := extractMessageContent(nil) + if got != "" { + t.Fatalf("expected '', got %q", got) + } +} + +// ─── cloneMap ──────────────────────────────────────────────────────── + +func TestCloneMapBasic(t *testing.T) { + original := map[string]any{"a": 1, "b": "hello"} + clone := cloneMap(original) + original["a"] = 999 + if clone["a"] != 1 { + t.Fatalf("expected 1, got %v", clone["a"]) + } + if clone["b"] != "hello" { + t.Fatalf("expected 'hello', got %v", clone["b"]) + } +} + +func TestCloneMapEmpty(t *testing.T) { + clone := cloneMap(map[string]any{}) + if len(clone) != 0 { + t.Fatalf("expected empty, got %v", clone) + } +} + +func TestCloneMapNested(t *testing.T) { + // cloneMap is shallow, so nested maps share references + inner := map[string]any{"key": "value"} + original := map[string]any{"nested": inner} + clone := cloneMap(original) + // Shallow clone means inner is shared + inner["key"] = "modified" + cloneNested := clone["nested"].(map[string]any) + if cloneNested["key"] != "modified" { + t.Fatal("expected shallow clone to share nested references") + } +} + +// helper +func containsStr(s, sub string) bool { + return len(s) >= len(sub) && (s == sub || len(s) > 0 && findSubstring(s, sub)) +} + +func findSubstring(s, sub string) bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} diff --git a/internal/adapter/claude/handler_utils.go b/internal/adapter/claude/handler_utils.go new file mode 100644 index 0000000..df4c6b2 --- /dev/null +++ 
b/internal/adapter/claude/handler_utils.go @@ -0,0 +1,143 @@ +package claude + +import ( + "encoding/json" + "fmt" + "strings" +) + +func normalizeClaudeMessages(messages []any) []any { + out := make([]any, 0, len(messages)) + for _, m := range messages { + msg, ok := m.(map[string]any) + if !ok { + continue + } + copied := cloneMap(msg) + switch content := msg["content"].(type) { + case []any: + parts := make([]string, 0, len(content)) + for _, block := range content { + b, ok := block.(map[string]any) + if !ok { + continue + } + typeStr, _ := b["type"].(string) + if typeStr == "text" { + if t, ok := b["text"].(string); ok { + parts = append(parts, t) + } + } + if typeStr == "tool_result" { + parts = append(parts, formatClaudeToolResultForPrompt(b)) + } + } + copied["content"] = strings.Join(parts, "\n") + } + out = append(out, copied) + } + return out +} + +func buildClaudeToolPrompt(tools []any) string { + parts := []string{"You are Claude, a helpful AI assistant. You have access to these tools:"} + for _, t := range tools { + m, ok := t.(map[string]any) + if !ok { + continue + } + name, _ := m["name"].(string) + desc, _ := m["description"].(string) + schema, _ := json.Marshal(m["input_schema"]) + parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema)) + } + parts = append(parts, + "When you need to use tools, you can call multiple tools in one response. 
Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}", + "History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.", + "After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.", + ) + return strings.Join(parts, "\n\n") +} + +func formatClaudeToolResultForPrompt(block map[string]any) string { + if block == nil { + return "" + } + toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"])) + if toolCallID == "" { + toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"])) + } + if toolCallID == "" { + toolCallID = "unknown" + } + name := strings.TrimSpace(fmt.Sprintf("%v", block["name"])) + if name == "" { + name = "unknown" + } + content := strings.TrimSpace(fmt.Sprintf("%v", block["content"])) + if content == "" { + content = "null" + } + return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content) +} + +func hasSystemMessage(messages []any) bool { + for _, m := range messages { + msg, ok := m.(map[string]any) + if ok && msg["role"] == "system" { + return true + } + } + return false +} + +func extractClaudeToolNames(tools []any) []string { + out := make([]string, 0, len(tools)) + for _, t := range tools { + m, ok := t.(map[string]any) + if !ok { + continue + } + if name, ok := m["name"].(string); ok && name != "" { + out = append(out, name) + } + } + return out +} + +func toMessageMaps(v any) []map[string]any { + arr, ok := v.([]any) + if !ok { + return nil + } + out := make([]map[string]any, 0, len(arr)) + for _, item := range arr { + if m, ok := item.(map[string]any); ok { + out = append(out, m) + } + } + return out +} + +func extractMessageContent(v 
any) string { + switch x := v.(type) { + case string: + return x + case []any: + parts := make([]string, 0, len(x)) + for _, it := range x { + parts = append(parts, fmt.Sprintf("%v", it)) + } + return strings.Join(parts, "\n") + default: + return fmt.Sprintf("%v", x) + } +} + +func cloneMap(in map[string]any) map[string]any { + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} diff --git a/internal/adapter/claude/route_alias_test.go b/internal/adapter/claude/route_alias_test.go new file mode 100644 index 0000000..f01e5e3 --- /dev/null +++ b/internal/adapter/claude/route_alias_test.go @@ -0,0 +1,44 @@ +package claude + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/auth" +) + +type routeAliasAuthStub struct{} + +func (routeAliasAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) { + return nil, auth.ErrUnauthorized +} + +func (routeAliasAuthStub) Release(_ *auth.RequestAuth) {} + +func TestClaudeRouteAliasesDoNot404(t *testing.T) { + h := &Handler{ + Auth: routeAliasAuthStub{}, + } + r := chi.NewRouter() + RegisterRoutes(r, h) + + paths := []string{ + "/anthropic/v1/messages", + "/v1/messages", + "/messages", + "/anthropic/v1/messages/count_tokens", + "/v1/messages/count_tokens", + "/messages/count_tokens", + } + for _, path := range paths { + req := httptest.NewRequest(http.MethodPost, path, nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code == http.StatusNotFound { + t.Fatalf("expected route %s to be registered, got 404", path) + } + } +} diff --git a/internal/adapter/claude/standard_request.go b/internal/adapter/claude/standard_request.go new file mode 100644 index 0000000..23520c0 --- /dev/null +++ b/internal/adapter/claude/standard_request.go @@ -0,0 +1,113 @@ +package claude + +import ( + "fmt" + "strings" + + "ds2api/internal/config" + "ds2api/internal/deepseek" + "ds2api/internal/util" +) + +type 
claudeNormalizedRequest struct { + Standard util.StandardRequest + NormalizedMessages []any +} + +func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNormalizedRequest, error) { + model, _ := req["model"].(string) + messagesRaw, _ := req["messages"].([]any) + if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 { + return claudeNormalizedRequest{}, fmt.Errorf("Request must include 'model' and 'messages'.") + } + if _, ok := req["max_tokens"]; !ok { + req["max_tokens"] = 8192 + } + normalizedMessages := normalizeClaudeMessages(messagesRaw) + payload := cloneMap(req) + payload["messages"] = normalizedMessages + toolsRequested, _ := req["tools"].([]any) + payload["messages"] = injectClaudeToolPrompt(payload, normalizedMessages, toolsRequested) + + dsPayload := convertClaudeToDeepSeek(payload, store) + dsModel, _ := dsPayload["model"].(string) + thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel) + if !ok { + thinkingEnabled = false + searchEnabled = false + } + finalPrompt := deepseek.MessagesPrepare(toMessageMaps(dsPayload["messages"])) + toolNames := extractClaudeToolNames(toolsRequested) + + return claudeNormalizedRequest{ + Standard: util.StandardRequest{ + Surface: "anthropic_messages", + RequestedModel: strings.TrimSpace(model), + ResolvedModel: dsModel, + ResponseModel: strings.TrimSpace(model), + Messages: payload["messages"].([]any), + FinalPrompt: finalPrompt, + ToolNames: toolNames, + Stream: util.ToBool(req["stream"]), + Thinking: thinkingEnabled, + Search: searchEnabled, + }, + NormalizedMessages: normalizedMessages, + }, nil +} + +func injectClaudeToolPrompt(payload map[string]any, normalizedMessages []any, tools []any) []any { + if len(tools) == 0 { + return normalizedMessages + } + toolPrompt := strings.TrimSpace(buildClaudeToolPrompt(tools)) + if toolPrompt == "" { + return normalizedMessages + } + + // Prefer top-level Anthropic-style system prompt when available. 
+ if systemText, ok := payload["system"].(string); ok && strings.TrimSpace(systemText) != "" { + payload["system"] = mergeSystemPrompt(systemText, toolPrompt) + return normalizedMessages + } + + messages := cloneAnySlice(normalizedMessages) + for i := range messages { + msg, ok := messages[i].(map[string]any) + if !ok { + continue + } + role, _ := msg["role"].(string) + if !strings.EqualFold(strings.TrimSpace(role), "system") { + continue + } + copied := cloneMap(msg) + copied["content"] = mergeSystemPrompt(strings.TrimSpace(fmt.Sprintf("%v", copied["content"])), toolPrompt) + messages[i] = copied + return messages + } + + return append([]any{map[string]any{"role": "system", "content": toolPrompt}}, messages...) +} + +func mergeSystemPrompt(base, extra string) string { + base = strings.TrimSpace(base) + extra = strings.TrimSpace(extra) + switch { + case base == "": + return extra + case extra == "": + return base + default: + return base + "\n\n" + extra + } +} + +func cloneAnySlice(in []any) []any { + if len(in) == 0 { + return nil + } + out := make([]any, len(in)) + copy(out, in) + return out +} diff --git a/internal/adapter/claude/standard_request_test.go b/internal/adapter/claude/standard_request_test.go new file mode 100644 index 0000000..6110124 --- /dev/null +++ b/internal/adapter/claude/standard_request_test.go @@ -0,0 +1,92 @@ +package claude + +import ( + "testing" + + "ds2api/internal/config" +) + +func TestNormalizeClaudeRequest(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{}`) + store := config.LoadStore() + req := map[string]any{ + "model": "claude-opus-4-6", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + "stream": true, + "tools": []any{ + map[string]any{"name": "search", "description": "Search"}, + }, + } + norm, err := normalizeClaudeRequest(store, req) + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + if norm.Standard.ResolvedModel == "" { + t.Fatalf("expected resolved model") + } + if 
!norm.Standard.Stream { + t.Fatalf("expected stream=true") + } + if len(norm.Standard.ToolNames) == 0 { + t.Fatalf("expected tool names") + } + if norm.Standard.FinalPrompt == "" { + t.Fatalf("expected non-empty final prompt") + } +} + +func TestNormalizeClaudeRequestInjectsToolsIntoExistingSystemMessage(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{}`) + store := config.LoadStore() + req := map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []any{ + map[string]any{"role": "system", "content": "baseline rule"}, + map[string]any{"role": "user", "content": "hello"}, + }, + "tools": []any{ + map[string]any{"name": "search", "description": "Search"}, + }, + } + + norm, err := normalizeClaudeRequest(store, req) + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + + if !containsStr(norm.Standard.FinalPrompt, "You have access to these tools") { + t.Fatalf("expected tool prompt injected into final prompt, got=%q", norm.Standard.FinalPrompt) + } + if !containsStr(norm.Standard.FinalPrompt, "baseline rule") { + t.Fatalf("expected existing system message preserved, got=%q", norm.Standard.FinalPrompt) + } +} + +func TestNormalizeClaudeRequestInjectsToolsIntoTopLevelSystem(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{}`) + store := config.LoadStore() + req := map[string]any{ + "model": "claude-sonnet-4-5", + "system": "top-level system", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + "tools": []any{ + map[string]any{"name": "search", "description": "Search"}, + }, + } + + norm, err := normalizeClaudeRequest(store, req) + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + + if !containsStr(norm.Standard.FinalPrompt, "top-level system") { + t.Fatalf("expected top-level system preserved, got=%q", norm.Standard.FinalPrompt) + } + if !containsStr(norm.Standard.FinalPrompt, "You have access to these tools") { + t.Fatalf("expected tool prompt injected, got=%q", norm.Standard.FinalPrompt) + } +} diff --git 
a/internal/adapter/claude/stream_runtime_core.go b/internal/adapter/claude/stream_runtime_core.go new file mode 100644 index 0000000..cb24bdd --- /dev/null +++ b/internal/adapter/claude/stream_runtime_core.go @@ -0,0 +1,146 @@ +package claude + +import ( + "fmt" + "net/http" + "strings" + "time" + + "ds2api/internal/sse" + streamengine "ds2api/internal/stream" +) + +type claudeStreamRuntime struct { + w http.ResponseWriter + rc *http.ResponseController + canFlush bool + + model string + toolNames []string + messages []any + + thinkingEnabled bool + searchEnabled bool + bufferToolContent bool + + messageID string + thinking strings.Builder + text strings.Builder + + nextBlockIndex int + thinkingBlockOpen bool + thinkingBlockIndex int + textBlockOpen bool + textBlockIndex int + ended bool + upstreamErr string +} + +func newClaudeStreamRuntime( + w http.ResponseWriter, + rc *http.ResponseController, + canFlush bool, + model string, + messages []any, + thinkingEnabled bool, + searchEnabled bool, + toolNames []string, +) *claudeStreamRuntime { + return &claudeStreamRuntime{ + w: w, + rc: rc, + canFlush: canFlush, + model: model, + messages: messages, + thinkingEnabled: thinkingEnabled, + searchEnabled: searchEnabled, + bufferToolContent: len(toolNames) > 0, + toolNames: toolNames, + messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()), + thinkingBlockIndex: -1, + textBlockIndex: -1, + } +} + +func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision { + if !parsed.Parsed { + return streamengine.ParsedDecision{} + } + if parsed.ErrorMessage != "" { + s.upstreamErr = parsed.ErrorMessage + return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("upstream_error")} + } + if parsed.Stop { + return streamengine.ParsedDecision{Stop: true} + } + + contentSeen := false + for _, p := range parsed.Parts { + if p.Text == "" { + continue + } + if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) { + 
continue + } + contentSeen = true + + if p.Type == "thinking" { + if !s.thinkingEnabled { + continue + } + s.thinking.WriteString(p.Text) + s.closeTextBlock() + if !s.thinkingBlockOpen { + s.thinkingBlockIndex = s.nextBlockIndex + s.nextBlockIndex++ + s.send("content_block_start", map[string]any{ + "type": "content_block_start", + "index": s.thinkingBlockIndex, + "content_block": map[string]any{ + "type": "thinking", + "thinking": "", + }, + }) + s.thinkingBlockOpen = true + } + s.send("content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": s.thinkingBlockIndex, + "delta": map[string]any{ + "type": "thinking_delta", + "thinking": p.Text, + }, + }) + continue + } + + s.text.WriteString(p.Text) + if s.bufferToolContent { + continue + } + s.closeThinkingBlock() + if !s.textBlockOpen { + s.textBlockIndex = s.nextBlockIndex + s.nextBlockIndex++ + s.send("content_block_start", map[string]any{ + "type": "content_block_start", + "index": s.textBlockIndex, + "content_block": map[string]any{ + "type": "text", + "text": "", + }, + }) + s.textBlockOpen = true + } + s.send("content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": s.textBlockIndex, + "delta": map[string]any{ + "type": "text_delta", + "text": p.Text, + }, + }) + } + + return streamengine.ParsedDecision{ContentSeen: contentSeen} +} diff --git a/internal/adapter/claude/stream_runtime_emit.go b/internal/adapter/claude/stream_runtime_emit.go new file mode 100644 index 0000000..c2fba19 --- /dev/null +++ b/internal/adapter/claude/stream_runtime_emit.go @@ -0,0 +1,59 @@ +package claude + +import ( + "encoding/json" + "fmt" + "strings" + + "ds2api/internal/util" +) + +func (s *claudeStreamRuntime) send(event string, v any) { + b, _ := json.Marshal(v) + _, _ = s.w.Write([]byte("event: ")) + _, _ = s.w.Write([]byte(event)) + _, _ = s.w.Write([]byte("\n")) + _, _ = s.w.Write([]byte("data: ")) + _, _ = s.w.Write(b) + _, _ = s.w.Write([]byte("\n\n")) + if s.canFlush { + _ = 
s.rc.Flush() + } +} + +func (s *claudeStreamRuntime) sendError(message string) { + msg := strings.TrimSpace(message) + if msg == "" { + msg = "upstream stream error" + } + s.send("error", map[string]any{ + "type": "error", + "error": map[string]any{ + "type": "api_error", + "message": msg, + "code": "internal_error", + "param": nil, + }, + }) +} + +func (s *claudeStreamRuntime) sendPing() { + s.send("ping", map[string]any{"type": "ping"}) +} + +func (s *claudeStreamRuntime) sendMessageStart() { + inputTokens := util.EstimateTokens(fmt.Sprintf("%v", s.messages)) + s.send("message_start", map[string]any{ + "type": "message_start", + "message": map[string]any{ + "id": s.messageID, + "type": "message", + "role": "assistant", + "model": s.model, + "content": []any{}, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0}, + }, + }) +} diff --git a/internal/adapter/claude/stream_runtime_finalize.go b/internal/adapter/claude/stream_runtime_finalize.go new file mode 100644 index 0000000..f957ba1 --- /dev/null +++ b/internal/adapter/claude/stream_runtime_finalize.go @@ -0,0 +1,119 @@ +package claude + +import ( + "fmt" + "time" + + streamengine "ds2api/internal/stream" + "ds2api/internal/util" +) + +func (s *claudeStreamRuntime) closeThinkingBlock() { + if !s.thinkingBlockOpen { + return + } + s.send("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": s.thinkingBlockIndex, + }) + s.thinkingBlockOpen = false + s.thinkingBlockIndex = -1 +} + +func (s *claudeStreamRuntime) closeTextBlock() { + if !s.textBlockOpen { + return + } + s.send("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": s.textBlockIndex, + }) + s.textBlockOpen = false + s.textBlockIndex = -1 +} + +func (s *claudeStreamRuntime) finalize(stopReason string) { + if s.ended { + return + } + s.ended = true + + s.closeThinkingBlock() + s.closeTextBlock() + + finalThinking := s.thinking.String() 
+ finalText := s.text.String() + + if s.bufferToolContent { + detected := util.ParseToolCalls(finalText, s.toolNames) + if len(detected) > 0 { + stopReason = "tool_use" + for i, tc := range detected { + idx := s.nextBlockIndex + i + s.send("content_block_start", map[string]any{ + "type": "content_block_start", + "index": idx, + "content_block": map[string]any{ + "type": "tool_use", + "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx), + "name": tc.Name, + "input": tc.Input, + }, + }) + s.send("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": idx, + }) + } + s.nextBlockIndex += len(detected) + } else if finalText != "" { + idx := s.nextBlockIndex + s.nextBlockIndex++ + s.send("content_block_start", map[string]any{ + "type": "content_block_start", + "index": idx, + "content_block": map[string]any{ + "type": "text", + "text": "", + }, + }) + s.send("content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": idx, + "delta": map[string]any{ + "type": "text_delta", + "text": finalText, + }, + }) + s.send("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": idx, + }) + } + } + + outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText) + s.send("message_delta", map[string]any{ + "type": "message_delta", + "delta": map[string]any{ + "stop_reason": stopReason, + "stop_sequence": nil, + }, + "usage": map[string]any{ + "output_tokens": outputTokens, + }, + }) + s.send("message_stop", map[string]any{"type": "message_stop"}) +} + +func (s *claudeStreamRuntime) onFinalize(reason streamengine.StopReason, scannerErr error) { + if string(reason) == "upstream_error" { + s.sendError(s.upstreamErr) + return + } + if scannerErr != nil { + s.sendError(scannerErr.Error()) + return + } + s.finalize("end_turn") +} diff --git a/internal/adapter/claude/stream_status_test.go b/internal/adapter/claude/stream_status_test.go new file mode 100644 index 0000000..c3936de --- /dev/null 
+++ b/internal/adapter/claude/stream_status_test.go @@ -0,0 +1,100 @@ +package claude + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + chimw "github.com/go-chi/chi/v5/middleware" + + "ds2api/internal/auth" +) + +type streamStatusClaudeAuthStub struct{} + +func (streamStatusClaudeAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) { + return &auth.RequestAuth{ + UseConfigToken: false, + DeepSeekToken: "direct-token", + CallerID: "caller:test", + TriedAccounts: map[string]bool{}, + }, nil +} + +func (streamStatusClaudeAuthStub) Release(_ *auth.RequestAuth) {} + +type streamStatusClaudeDSStub struct{} + +func (streamStatusClaudeDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) { + return "session-id", nil +} + +func (streamStatusClaudeDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) { + return "pow", nil +} + +func (streamStatusClaudeDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) { + body := "data: {\"p\":\"response/content\",\"v\":\"hello\"}\n" + "data: [DONE]\n" + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: ioNopCloser{strings.NewReader(body)}, + }, nil +} + +type ioNopCloser struct { + *strings.Reader +} + +func (ioNopCloser) Close() error { return nil } + +type streamStatusClaudeStoreStub struct{} + +func (streamStatusClaudeStoreStub) ClaudeMapping() map[string]string { + return map[string]string{ + "fast": "deepseek-chat", + "slow": "deepseek-reasoner", + } +} + +func captureClaudeStatusMiddleware(statuses *[]int) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ww := chimw.NewWrapResponseWriter(w, r.ProtoMajor) + next.ServeHTTP(ww, r) + *statuses = append(*statuses, ww.Status()) + }) + } +} + +func 
TestClaudeMessagesStreamStatusCapturedAs200(t *testing.T) { + statuses := make([]int, 0, 1) + h := &Handler{ + Store: streamStatusClaudeStoreStub{}, + Auth: streamStatusClaudeAuthStub{}, + DS: streamStatusClaudeDSStub{}, + } + r := chi.NewRouter() + r.Use(captureClaudeStatusMiddleware(&statuses)) + RegisterRoutes(r, h) + + reqBody := `{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":true}` + req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(reqBody)) + req.Header.Set("Authorization", "Bearer direct-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if len(statuses) != 1 { + t.Fatalf("expected one captured status, got %d", len(statuses)) + } + if statuses[0] != http.StatusOK { + t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0]) + } +} diff --git a/internal/adapter/gemini/convert_messages.go b/internal/adapter/gemini/convert_messages.go new file mode 100644 index 0000000..1148a7a --- /dev/null +++ b/internal/adapter/gemini/convert_messages.go @@ -0,0 +1,153 @@ +package gemini + +import "strings" + +func geminiMessagesFromRequest(req map[string]any) []any { + out := make([]any, 0, 8) + if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" { + out = append(out, map[string]any{ + "role": "system", + "content": sys, + }) + } + + contents, _ := req["contents"].([]any) + for _, item := range contents { + content, ok := item.(map[string]any) + if !ok { + continue + } + role := mapGeminiRole(content["role"]) + if role == "" { + role = "user" + } + parts, _ := content["parts"].([]any) + if len(parts) == 0 { + if text := strings.TrimSpace(asString(content["text"])); text != "" { + out = append(out, map[string]any{ + "role": role, + "content": text, + }) + } + 
continue + } + + textParts := make([]string, 0, len(parts)) + flushText := func() { + if len(textParts) == 0 { + return + } + out = append(out, map[string]any{ + "role": role, + "content": strings.Join(textParts, "\n"), + }) + textParts = textParts[:0] + } + + for _, rawPart := range parts { + part, ok := rawPart.(map[string]any) + if !ok { + continue + } + if text := strings.TrimSpace(asString(part["text"])); text != "" { + textParts = append(textParts, text) + continue + } + + if fnCall, ok := part["functionCall"].(map[string]any); ok { + flushText() + if name := strings.TrimSpace(asString(fnCall["name"])); name != "" { + callID := strings.TrimSpace(asString(fnCall["id"])) + if callID == "" { + callID = "call_gemini" + } + out = append(out, map[string]any{ + "role": "assistant", + "tool_calls": []any{ + map[string]any{ + "id": callID, + "type": "function", + "function": map[string]any{ + "name": name, + "arguments": stringifyJSON(fnCall["args"]), + }, + }, + }, + }) + } + continue + } + + if fnResp, ok := part["functionResponse"].(map[string]any); ok { + flushText() + name := strings.TrimSpace(asString(fnResp["name"])) + callID := strings.TrimSpace(asString(fnResp["id"])) + if callID == "" { + callID = strings.TrimSpace(asString(fnResp["callId"])) + } + if callID == "" { + callID = strings.TrimSpace(asString(fnResp["tool_call_id"])) + } + if callID == "" { + callID = "call_gemini" + } + content := fnResp["response"] + if content == nil { + content = fnResp["output"] + } + if content == nil { + content = "" + } + msg := map[string]any{ + "role": "tool", + "tool_call_id": callID, + "content": content, + } + if name != "" { + msg["name"] = name + } + out = append(out, msg) + } + } + flushText() + } + return out +} + +func normalizeGeminiSystemInstruction(raw any) string { + switch v := raw.(type) { + case string: + return strings.TrimSpace(v) + case map[string]any: + if parts, ok := v["parts"].([]any); ok { + texts := make([]string, 0, len(parts)) + for _, item := 
range parts { + part, ok := item.(map[string]any) + if !ok { + continue + } + if text := strings.TrimSpace(asString(part["text"])); text != "" { + texts = append(texts, text) + } + } + return strings.Join(texts, "\n") + } + if text := strings.TrimSpace(asString(v["text"])); text != "" { + return text + } + } + return "" +} + +func mapGeminiRole(v any) string { + switch strings.ToLower(strings.TrimSpace(asString(v))) { + case "user": + return "user" + case "model", "assistant": + return "assistant" + case "system": + return "system" + default: + return "" + } +} diff --git a/internal/adapter/gemini/convert_passthrough.go b/internal/adapter/gemini/convert_passthrough.go new file mode 100644 index 0000000..05cd6cd --- /dev/null +++ b/internal/adapter/gemini/convert_passthrough.go @@ -0,0 +1,54 @@ +package gemini + +import ( + "encoding/json" + "strings" +) + +func collectGeminiPassThrough(req map[string]any) map[string]any { + cfg, _ := req["generationConfig"].(map[string]any) + if len(cfg) == 0 { + return nil + } + out := map[string]any{} + if v, ok := cfg["temperature"]; ok { + out["temperature"] = v + } + if v, ok := cfg["topP"]; ok { + out["top_p"] = v + } + if v, ok := cfg["maxOutputTokens"]; ok { + out["max_tokens"] = v + } + if v, ok := cfg["stopSequences"]; ok { + out["stop"] = v + } + if len(out) == 0 { + return nil + } + return out +} + +func asString(v any) string { + s, _ := v.(string) + return s +} + +func stringifyJSON(v any) string { + switch x := v.(type) { + case nil: + return "{}" + case string: + s := strings.TrimSpace(x) + if s == "" { + return "{}" + } + return s + default: + b, err := json.Marshal(x) + if err != nil || len(b) == 0 { + return "{}" + } + return string(b) + } +} diff --git a/internal/adapter/gemini/convert_request.go b/internal/adapter/gemini/convert_request.go new file mode 100644 index 0000000..2eca687 --- /dev/null +++ b/internal/adapter/gemini/convert_request.go @@ -0,0 +1,46 @@ +package gemini + +import ( + "fmt" + "strings" + 
+ "ds2api/internal/adapter/openai" + "ds2api/internal/config" + "ds2api/internal/util" +) + +func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[string]any, stream bool) (util.StandardRequest, error) { + requestedModel := strings.TrimSpace(routeModel) + if requestedModel == "" { + return util.StandardRequest{}, fmt.Errorf("model is required in request path") + } + + resolvedModel, ok := config.ResolveModel(store, requestedModel) + if !ok { + return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", requestedModel) + } + thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) + + messagesRaw := geminiMessagesFromRequest(req) + if len(messagesRaw) == 0 { + return util.StandardRequest{}, fmt.Errorf("Request must include non-empty contents.") + } + + toolsRaw := convertGeminiTools(req["tools"]) + finalPrompt, toolNames := openai.BuildPromptForAdapter(messagesRaw, toolsRaw, "") + passThrough := collectGeminiPassThrough(req) + + return util.StandardRequest{ + Surface: "google_gemini", + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + ResponseModel: requestedModel, + Messages: messagesRaw, + FinalPrompt: finalPrompt, + ToolNames: toolNames, + Stream: stream, + Thinking: thinkingEnabled, + Search: searchEnabled, + PassThrough: passThrough, + }, nil +} diff --git a/internal/adapter/gemini/convert_tools.go b/internal/adapter/gemini/convert_tools.go new file mode 100644 index 0000000..4611f85 --- /dev/null +++ b/internal/adapter/gemini/convert_tools.go @@ -0,0 +1,71 @@ +package gemini + +import "strings" + +func convertGeminiTools(raw any) []any { + tools, _ := raw.([]any) + if len(tools) == 0 { + return nil + } + out := make([]any, 0, len(tools)) + for _, item := range tools { + tool, ok := item.(map[string]any) + if !ok { + continue + } + + if fnDecls, ok := tool["functionDeclarations"].([]any); ok && len(fnDecls) > 0 { + for _, declRaw := range fnDecls { + decl, ok := declRaw.(map[string]any) + if 
!ok { + continue + } + name := strings.TrimSpace(asString(decl["name"])) + if name == "" { + continue + } + function := map[string]any{ + "name": name, + } + if desc := strings.TrimSpace(asString(decl["description"])); desc != "" { + function["description"] = desc + } + if params, ok := decl["parameters"].(map[string]any); ok { + function["parameters"] = params + } + out = append(out, map[string]any{ + "type": "function", + "function": function, + }) + } + continue + } + + // OpenAI-style passthrough fallback. + if _, ok := tool["function"].(map[string]any); ok { + out = append(out, tool) + continue + } + + // Loose fallback for flattened function schema objects. + name := strings.TrimSpace(asString(tool["name"])) + if name == "" { + continue + } + fn := map[string]any{"name": name} + if desc := strings.TrimSpace(asString(tool["description"])); desc != "" { + fn["description"] = desc + } + if params, ok := tool["parameters"].(map[string]any); ok { + fn["parameters"] = params + } + out = append(out, map[string]any{ + "type": "function", + "function": fn, + }) + } + if len(out) == 0 { + return nil + } + return out +} diff --git a/internal/adapter/gemini/deps.go b/internal/adapter/gemini/deps.go new file mode 100644 index 0000000..312114a --- /dev/null +++ b/internal/adapter/gemini/deps.go @@ -0,0 +1,29 @@ +package gemini + +import ( + "context" + "net/http" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/deepseek" +) + +type AuthResolver interface { + Determine(req *http.Request) (*auth.RequestAuth, error) + Release(a *auth.RequestAuth) +} + +type DeepSeekCaller interface { + CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) +} + +type ConfigReader interface { + ModelAliases() 
map[string]string +} + +var _ AuthResolver = (*auth.Resolver)(nil) +var _ DeepSeekCaller = (*deepseek.Client)(nil) +var _ ConfigReader = (*config.Store)(nil) diff --git a/internal/adapter/gemini/handler_errors.go b/internal/adapter/gemini/handler_errors.go new file mode 100644 index 0000000..09df09b --- /dev/null +++ b/internal/adapter/gemini/handler_errors.go @@ -0,0 +1,28 @@ +package gemini + +import "net/http" + +func writeGeminiError(w http.ResponseWriter, status int, message string) { + errorStatus := "INVALID_ARGUMENT" + switch status { + case http.StatusUnauthorized: + errorStatus = "UNAUTHENTICATED" + case http.StatusForbidden: + errorStatus = "PERMISSION_DENIED" + case http.StatusTooManyRequests: + errorStatus = "RESOURCE_EXHAUSTED" + case http.StatusNotFound: + errorStatus = "NOT_FOUND" + default: + if status >= 500 { + errorStatus = "INTERNAL" + } + } + writeJSON(w, status, map[string]any{ + "error": map[string]any{ + "code": status, + "message": message, + "status": errorStatus, + }, + }) +} diff --git a/internal/adapter/gemini/handler_generate.go b/internal/adapter/gemini/handler_generate.go new file mode 100644 index 0000000..9144a42 --- /dev/null +++ b/internal/adapter/gemini/handler_generate.go @@ -0,0 +1,135 @@ +package gemini + +import ( + "encoding/json" + "io" + "net/http" + "strings" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/auth" + "ds2api/internal/sse" + "ds2api/internal/util" +) + +func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) { + a, err := h.Auth.Determine(r) + if err != nil { + status := http.StatusUnauthorized + detail := err.Error() + if err == auth.ErrNoAccount { + status = http.StatusTooManyRequests + } + writeGeminiError(w, status, detail) + return + } + defer h.Auth.Release(a) + + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeGeminiError(w, http.StatusBadRequest, "invalid json") + return + } + + routeModel := 
strings.TrimSpace(chi.URLParam(r, "model")) + stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream) + if err != nil { + writeGeminiError(w, http.StatusBadRequest, err.Error()) + return + } + + sessionID, err := h.DS.CreateSession(r.Context(), a, 3) + if err != nil { + if a.UseConfigToken { + writeGeminiError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.") + } else { + writeGeminiError(w, http.StatusUnauthorized, "Invalid token.") + } + return + } + pow, err := h.DS.GetPow(r.Context(), a, 3) + if err != nil { + writeGeminiError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") + return + } + payload := stdReq.CompletionPayload(sessionID) + resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) + if err != nil { + writeGeminiError(w, http.StatusInternalServerError, "Failed to get completion.") + return + } + + if stream { + h.handleStreamGenerateContent(w, r, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) + return + } + h.handleNonStreamGenerateContent(w, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames) +} + +func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *http.Response, model, finalPrompt string, thinkingEnabled bool, toolNames []string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body))) + return + } + + result := sse.CollectStream(resp, thinkingEnabled, true) + writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames)) +} + +func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { + parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames) + usage := 
buildGeminiUsage(finalPrompt, finalThinking, finalText) + return map[string]any{ + "candidates": []map[string]any{ + { + "index": 0, + "content": map[string]any{ + "role": "model", + "parts": parts, + }, + "finishReason": "STOP", + }, + }, + "modelVersion": model, + "usageMetadata": usage, + } +} + +func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any { + promptTokens := util.EstimateTokens(finalPrompt) + reasoningTokens := util.EstimateTokens(finalThinking) + completionTokens := util.EstimateTokens(finalText) + return map[string]any{ + "promptTokenCount": promptTokens, + "candidatesTokenCount": reasoningTokens + completionTokens, + "totalTokenCount": promptTokens + reasoningTokens + completionTokens, + } +} + +func buildGeminiPartsFromFinal(finalText, finalThinking string, toolNames []string) []map[string]any { + detected := util.ParseToolCalls(finalText, toolNames) + if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" { + detected = util.ParseToolCalls(finalThinking, toolNames) + } + if len(detected) > 0 { + parts := make([]map[string]any, 0, len(detected)) + for _, tc := range detected { + parts = append(parts, map[string]any{ + "functionCall": map[string]any{ + "name": tc.Name, + "args": tc.Input, + }, + }) + } + return parts + } + + text := finalText + if strings.TrimSpace(text) == "" { + text = finalThinking + } + return []map[string]any{{"text": text}} +} diff --git a/internal/adapter/gemini/handler_routes.go b/internal/adapter/gemini/handler_routes.go new file mode 100644 index 0000000..6850b51 --- /dev/null +++ b/internal/adapter/gemini/handler_routes.go @@ -0,0 +1,32 @@ +package gemini + +import ( + "net/http" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/util" +) + +var writeJSON = util.WriteJSON + +type Handler struct { + Store ConfigReader + Auth AuthResolver + DS DeepSeekCaller +} + +func RegisterRoutes(r chi.Router, h *Handler) { + r.Post("/v1beta/models/{model}:generateContent", h.GenerateContent) + 
r.Post("/v1beta/models/{model}:streamGenerateContent", h.StreamGenerateContent) + r.Post("/v1/models/{model}:generateContent", h.GenerateContent) + r.Post("/v1/models/{model}:streamGenerateContent", h.StreamGenerateContent) +} + +func (h *Handler) GenerateContent(w http.ResponseWriter, r *http.Request) { + h.handleGenerateContent(w, r, false) +} + +func (h *Handler) StreamGenerateContent(w http.ResponseWriter, r *http.Request) { + h.handleGenerateContent(w, r, true) +} diff --git a/internal/adapter/gemini/handler_stream_runtime.go b/internal/adapter/gemini/handler_stream_runtime.go new file mode 100644 index 0000000..c6a6bcd --- /dev/null +++ b/internal/adapter/gemini/handler_stream_runtime.go @@ -0,0 +1,181 @@ +package gemini + +import ( + "encoding/json" + "io" + "net/http" + "strings" + "time" + + "ds2api/internal/deepseek" + "ds2api/internal/sse" + streamengine "ds2api/internal/stream" +) + +func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body))) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + + rc := http.NewResponseController(w) + _, canFlush := w.(http.Flusher) + runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) + + initialType := "text" + if thinkingEnabled { + initialType = "thinking" + } + streamengine.ConsumeSSE(streamengine.ConsumeConfig{ + Context: r.Context(), + Body: resp.Body, + ThinkingEnabled: thinkingEnabled, + InitialType: initialType, + KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * 
time.Second, + IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second, + MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount, + }, streamengine.ConsumeHooks{ + OnParsed: runtime.onParsed, + OnFinalize: func(_ streamengine.StopReason, _ error) { + runtime.finalize() + }, + }) +} + +type geminiStreamRuntime struct { + w http.ResponseWriter + rc *http.ResponseController + canFlush bool + + model string + finalPrompt string + + thinkingEnabled bool + searchEnabled bool + bufferContent bool + toolNames []string + + thinking strings.Builder + text strings.Builder +} + +func newGeminiStreamRuntime( + w http.ResponseWriter, + rc *http.ResponseController, + canFlush bool, + model string, + finalPrompt string, + thinkingEnabled bool, + searchEnabled bool, + toolNames []string, +) *geminiStreamRuntime { + return &geminiStreamRuntime{ + w: w, + rc: rc, + canFlush: canFlush, + model: model, + finalPrompt: finalPrompt, + thinkingEnabled: thinkingEnabled, + searchEnabled: searchEnabled, + bufferContent: len(toolNames) > 0, + toolNames: toolNames, + } +} + +func (s *geminiStreamRuntime) sendChunk(payload map[string]any) { + b, _ := json.Marshal(payload) + _, _ = s.w.Write([]byte("data: ")) + _, _ = s.w.Write(b) + _, _ = s.w.Write([]byte("\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision { + if !parsed.Parsed { + return streamengine.ParsedDecision{} + } + if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop { + return streamengine.ParsedDecision{Stop: true} + } + + contentSeen := false + for _, p := range parsed.Parts { + if p.Text == "" { + continue + } + if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) { + continue + } + contentSeen = true + if p.Type == "thinking" { + if s.thinkingEnabled { + s.thinking.WriteString(p.Text) + } + continue + } + s.text.WriteString(p.Text) + if s.bufferContent { + continue + } + s.sendChunk(map[string]any{ + 
"candidates": []map[string]any{ + { + "index": 0, + "content": map[string]any{ + "role": "model", + "parts": []map[string]any{{"text": p.Text}}, + }, + }, + }, + "modelVersion": s.model, + }) + } + return streamengine.ParsedDecision{ContentSeen: contentSeen} +} + +func (s *geminiStreamRuntime) finalize() { + finalThinking := s.thinking.String() + finalText := s.text.String() + + if s.bufferContent { + parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames) + s.sendChunk(map[string]any{ + "candidates": []map[string]any{ + { + "index": 0, + "content": map[string]any{ + "role": "model", + "parts": parts, + }, + }, + }, + "modelVersion": s.model, + }) + } + + s.sendChunk(map[string]any{ + "candidates": []map[string]any{ + { + "index": 0, + "content": map[string]any{ + "role": "model", + "parts": []map[string]any{ + {"text": ""}, + }, + }, + "finishReason": "STOP", + }, + }, + "modelVersion": s.model, + "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText), + }) +} diff --git a/internal/adapter/gemini/handler_test.go b/internal/adapter/gemini/handler_test.go new file mode 100644 index 0000000..8095417 --- /dev/null +++ b/internal/adapter/gemini/handler_test.go @@ -0,0 +1,216 @@ +package gemini + +import ( + "bufio" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/auth" +) + +type testGeminiConfig struct{} + +func (testGeminiConfig) ModelAliases() map[string]string { return nil } + +type testGeminiAuth struct { + a *auth.RequestAuth + err error +} + +func (m testGeminiAuth) Determine(_ *http.Request) (*auth.RequestAuth, error) { + if m.err != nil { + return nil, m.err + } + if m.a != nil { + return m.a, nil + } + return &auth.RequestAuth{ + UseConfigToken: false, + DeepSeekToken: "direct-token", + CallerID: "caller:test", + TriedAccounts: map[string]bool{}, + }, nil +} + +func (testGeminiAuth) Release(_ *auth.RequestAuth) {} + +type 
testGeminiDS struct { + resp *http.Response + err error +} + +func (m testGeminiDS) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) { + return "session-id", nil +} + +func (m testGeminiDS) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) { + return "pow", nil +} + +func (m testGeminiDS) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) { + if m.err != nil { + return nil, m.err + } + return m.resp, nil +} + +func makeGeminiUpstreamResponse(lines ...string) *http.Response { + body := strings.Join(lines, "\n") + if !strings.HasSuffix(body, "\n") { + body += "\n" + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func TestGeminiRoutesRegistered(t *testing.T) { + h := &Handler{ + Store: testGeminiConfig{}, + Auth: testGeminiAuth{err: auth.ErrUnauthorized}, + } + r := chi.NewRouter() + RegisterRoutes(r, h) + + paths := []string{ + "/v1beta/models/gemini-2.5-pro:generateContent", + "/v1beta/models/gemini-2.5-pro:streamGenerateContent", + "/v1/models/gemini-2.5-pro:generateContent", + "/v1/models/gemini-2.5-pro:streamGenerateContent", + } + for _, path := range paths { + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`)) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code == http.StatusNotFound { + t.Fatalf("expected route %s to be registered, got 404", path) + } + } +} + +func TestGenerateContentReturnsFunctionCallParts(t *testing.T) { + upstream := makeGeminiUpstreamResponse( + `data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`, + `data: [DONE]`, + ) + h := &Handler{ + Store: testGeminiConfig{}, + Auth: testGeminiAuth{}, + DS: testGeminiDS{resp: upstream}, + } + r := chi.NewRouter() + RegisterRoutes(r, 
h) + + body := `{ + "contents":[{"role":"user","parts":[{"text":"call tool"}]}], + "tools":[{"functionDeclarations":[{"name":"eval_javascript","description":"eval","parameters":{"type":"object","properties":{"code":{"type":"string"}}}}]}] + }` + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer direct-token") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode response failed: %v", err) + } + candidates, _ := out["candidates"].([]any) + if len(candidates) == 0 { + t.Fatalf("expected non-empty candidates: %#v", out) + } + c0, _ := candidates[0].(map[string]any) + content, _ := c0["content"].(map[string]any) + parts, _ := content["parts"].([]any) + if len(parts) == 0 { + t.Fatalf("expected non-empty parts: %#v", content) + } + part0, _ := parts[0].(map[string]any) + functionCall, _ := part0["functionCall"].(map[string]any) + if functionCall["name"] != "eval_javascript" { + t.Fatalf("expected functionCall name eval_javascript, got %#v", functionCall) + } +} + +func TestStreamGenerateContentEmitsSSE(t *testing.T) { + upstream := makeGeminiUpstreamResponse( + `data: {"p":"response/content","v":"hello "}`, + `data: {"p":"response/content","v":"world"}`, + `data: [DONE]`, + ) + h := &Handler{ + Store: testGeminiConfig{}, + Auth: testGeminiAuth{}, + DS: testGeminiDS{resp: upstream}, + } + r := chi.NewRouter() + RegisterRoutes(r, h) + + body := `{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}` + req := httptest.NewRequest(http.MethodPost, "/v1/models/gemini-2.5-pro:streamGenerateContent?alt=sse", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer direct-token") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + 
if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "data: ") { + t.Fatalf("expected SSE data frames, got body=%s", rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), `"finishReason":"STOP"`) { + t.Fatalf("expected stream finish frame, got body=%s", rec.Body.String()) + } + + frames := extractGeminiSSEFrames(t, rec.Body.String()) + if len(frames) == 0 { + t.Fatalf("expected non-empty sse frames, body=%s", rec.Body.String()) + } + last := frames[len(frames)-1] + candidates, _ := last["candidates"].([]any) + if len(candidates) == 0 { + t.Fatalf("expected finish frame candidates, got %#v", last) + } + c0, _ := candidates[0].(map[string]any) + content, _ := c0["content"].(map[string]any) + if content == nil { + t.Fatalf("expected non-null content in finish frame, got %#v", c0) + } + parts, _ := content["parts"].([]any) + if len(parts) == 0 { + t.Fatalf("expected non-empty parts in finish frame content, got %#v", content) + } +} + +func extractGeminiSSEFrames(t *testing.T, body string) []map[string]any { + t.Helper() + scanner := bufio.NewScanner(strings.NewReader(body)) + out := make([]map[string]any, 0, 4) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if !strings.HasPrefix(line, "data: ") { + continue + } + raw := strings.TrimSpace(strings.TrimPrefix(line, "data: ")) + if raw == "" { + continue + } + var frame map[string]any + if err := json.Unmarshal([]byte(raw), &frame); err != nil { + continue + } + out = append(out, frame) + } + return out +} diff --git a/internal/adapter/openai/chat_stream_runtime.go b/internal/adapter/openai/chat_stream_runtime.go new file mode 100644 index 0000000..a5ecbd6 --- /dev/null +++ b/internal/adapter/openai/chat_stream_runtime.go @@ -0,0 +1,270 @@ +package openai + +import ( + "encoding/json" + "net/http" + "strings" + + openaifmt "ds2api/internal/format/openai" + "ds2api/internal/sse" + streamengine 
"ds2api/internal/stream" + "ds2api/internal/util" +) + +type chatStreamRuntime struct { + w http.ResponseWriter + rc *http.ResponseController + canFlush bool + + completionID string + created int64 + model string + finalPrompt string + toolNames []string + + thinkingEnabled bool + searchEnabled bool + + firstChunkSent bool + bufferToolContent bool + emitEarlyToolDeltas bool + toolCallsEmitted bool + toolCallsDoneEmitted bool + + toolSieve toolStreamSieveState + streamToolCallIDs map[int]string + streamToolNames map[int]string + thinking strings.Builder + text strings.Builder +} + +func newChatStreamRuntime( + w http.ResponseWriter, + rc *http.ResponseController, + canFlush bool, + completionID string, + created int64, + model string, + finalPrompt string, + thinkingEnabled bool, + searchEnabled bool, + toolNames []string, + bufferToolContent bool, + emitEarlyToolDeltas bool, +) *chatStreamRuntime { + return &chatStreamRuntime{ + w: w, + rc: rc, + canFlush: canFlush, + completionID: completionID, + created: created, + model: model, + finalPrompt: finalPrompt, + toolNames: toolNames, + thinkingEnabled: thinkingEnabled, + searchEnabled: searchEnabled, + bufferToolContent: bufferToolContent, + emitEarlyToolDeltas: emitEarlyToolDeltas, + streamToolCallIDs: map[int]string{}, + streamToolNames: map[int]string{}, + } +} + +func (s *chatStreamRuntime) sendKeepAlive() { + if !s.canFlush { + return + } + _, _ = s.w.Write([]byte(": keep-alive\n\n")) + _ = s.rc.Flush() +} + +func (s *chatStreamRuntime) sendChunk(v any) { + b, _ := json.Marshal(v) + _, _ = s.w.Write([]byte("data: ")) + _, _ = s.w.Write(b) + _, _ = s.w.Write([]byte("\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *chatStreamRuntime) sendDone() { + _, _ = s.w.Write([]byte("data: [DONE]\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *chatStreamRuntime) finalize(finishReason string) { + finalThinking := s.thinking.String() + finalText := s.text.String() + detected := 
util.ParseToolCalls(finalText, s.toolNames) + if len(detected) > 0 && !s.toolCallsDoneEmitted { + finishReason = "tool_calls" + delta := map[string]any{ + "tool_calls": formatFinalStreamToolCallsWithStableIDs(detected, s.streamToolCallIDs), + } + if !s.firstChunkSent { + delta["role"] = "assistant" + s.firstChunkSent = true + } + s.sendChunk(openaifmt.BuildChatStreamChunk( + s.completionID, + s.created, + s.model, + []map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, delta)}, + nil, + )) + s.toolCallsEmitted = true + s.toolCallsDoneEmitted = true + } else if s.bufferToolContent { + for _, evt := range flushToolSieve(&s.toolSieve, s.toolNames) { + if len(evt.ToolCalls) > 0 { + finishReason = "tool_calls" + s.toolCallsEmitted = true + s.toolCallsDoneEmitted = true + tcDelta := map[string]any{ + "tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs), + } + if !s.firstChunkSent { + tcDelta["role"] = "assistant" + s.firstChunkSent = true + } + s.sendChunk(openaifmt.BuildChatStreamChunk( + s.completionID, + s.created, + s.model, + []map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, tcDelta)}, + nil, + )) + } + if evt.Content == "" { + continue + } + delta := map[string]any{ + "content": evt.Content, + } + if !s.firstChunkSent { + delta["role"] = "assistant" + s.firstChunkSent = true + } + s.sendChunk(openaifmt.BuildChatStreamChunk( + s.completionID, + s.created, + s.model, + []map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, delta)}, + nil, + )) + } + } + + if len(detected) > 0 || s.toolCallsEmitted { + finishReason = "tool_calls" + } + s.sendChunk(openaifmt.BuildChatStreamChunk( + s.completionID, + s.created, + s.model, + []map[string]any{openaifmt.BuildChatStreamFinishChoice(0, finishReason)}, + openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText), + )) + s.sendDone() +} + +func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision { + if !parsed.Parsed { + return 
streamengine.ParsedDecision{} + } + if parsed.ContentFilter || parsed.ErrorMessage != "" { + return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("content_filter")} + } + if parsed.Stop { + return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReasonHandlerRequested} + } + + newChoices := make([]map[string]any, 0, len(parsed.Parts)) + contentSeen := false + for _, p := range parsed.Parts { + if s.searchEnabled && sse.IsCitation(p.Text) { + continue + } + if p.Text == "" { + continue + } + contentSeen = true + delta := map[string]any{} + if !s.firstChunkSent { + delta["role"] = "assistant" + s.firstChunkSent = true + } + if p.Type == "thinking" { + if s.thinkingEnabled { + s.thinking.WriteString(p.Text) + delta["reasoning_content"] = p.Text + } + } else { + s.text.WriteString(p.Text) + if !s.bufferToolContent { + delta["content"] = p.Text + } else { + events := processToolSieveChunk(&s.toolSieve, p.Text, s.toolNames) + for _, evt := range events { + if len(evt.ToolCallDeltas) > 0 { + if !s.emitEarlyToolDeltas { + continue + } + filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.toolNames, s.streamToolNames) + if len(filtered) == 0 { + continue + } + formatted := formatIncrementalStreamToolCallDeltas(filtered, s.streamToolCallIDs) + if len(formatted) == 0 { + continue + } + tcDelta := map[string]any{ + "tool_calls": formatted, + } + s.toolCallsEmitted = true + if !s.firstChunkSent { + tcDelta["role"] = "assistant" + s.firstChunkSent = true + } + newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, tcDelta)) + continue + } + if len(evt.ToolCalls) > 0 { + s.toolCallsEmitted = true + s.toolCallsDoneEmitted = true + tcDelta := map[string]any{ + "tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs), + } + if !s.firstChunkSent { + tcDelta["role"] = "assistant" + s.firstChunkSent = true + } + newChoices = append(newChoices, 
openaifmt.BuildChatStreamDeltaChoice(0, tcDelta)) + continue + } + if evt.Content != "" { + contentDelta := map[string]any{ + "content": evt.Content, + } + if !s.firstChunkSent { + contentDelta["role"] = "assistant" + s.firstChunkSent = true + } + newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, contentDelta)) + } + } + } + } + if len(delta) > 0 { + newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, delta)) + } + } + + if len(newChoices) > 0 { + s.sendChunk(openaifmt.BuildChatStreamChunk(s.completionID, s.created, s.model, newChoices, nil)) + } + return streamengine.ParsedDecision{ContentSeen: contentSeen} +} diff --git a/internal/adapter/openai/deps.go b/internal/adapter/openai/deps.go new file mode 100644 index 0000000..6688756 --- /dev/null +++ b/internal/adapter/openai/deps.go @@ -0,0 +1,35 @@ +package openai + +import ( + "context" + "net/http" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/deepseek" +) + +type AuthResolver interface { + Determine(req *http.Request) (*auth.RequestAuth, error) + DetermineCaller(req *http.Request) (*auth.RequestAuth, error) + Release(a *auth.RequestAuth) +} + +type DeepSeekCaller interface { + CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) +} + +type ConfigReader interface { + ModelAliases() map[string]string + CompatWideInputStrictOutput() bool + ToolcallMode() string + ToolcallEarlyEmitConfidence() string + ResponsesStoreTTLSeconds() int + EmbeddingsProvider() string +} + +var _ AuthResolver = (*auth.Resolver)(nil) +var _ DeepSeekCaller = (*deepseek.Client)(nil) +var _ ConfigReader = (*config.Store)(nil) diff --git a/internal/adapter/openai/deps_injection_test.go 
b/internal/adapter/openai/deps_injection_test.go new file mode 100644 index 0000000..6286c0c --- /dev/null +++ b/internal/adapter/openai/deps_injection_test.go @@ -0,0 +1,70 @@ +package openai + +import "testing" + +type mockOpenAIConfig struct { + aliases map[string]string + wideInput bool + toolMode string + earlyEmit string + responsesTTL int + embedProv string +} + +func (m mockOpenAIConfig) ModelAliases() map[string]string { return m.aliases } +func (m mockOpenAIConfig) CompatWideInputStrictOutput() bool { + return m.wideInput +} +func (m mockOpenAIConfig) ToolcallMode() string { return m.toolMode } +func (m mockOpenAIConfig) ToolcallEarlyEmitConfidence() string { return m.earlyEmit } +func (m mockOpenAIConfig) ResponsesStoreTTLSeconds() int { return m.responsesTTL } +func (m mockOpenAIConfig) EmbeddingsProvider() string { return m.embedProv } + +func TestNormalizeOpenAIChatRequestWithConfigInterface(t *testing.T) { + cfg := mockOpenAIConfig{ + aliases: map[string]string{ + "my-model": "deepseek-chat-search", + }, + wideInput: true, + } + req := map[string]any{ + "model": "my-model", + "messages": []any{map[string]any{"role": "user", "content": "hello"}}, + } + out, err := normalizeOpenAIChatRequest(cfg, req, "") + if err != nil { + t.Fatalf("normalizeOpenAIChatRequest error: %v", err) + } + if out.ResolvedModel != "deepseek-chat-search" { + t.Fatalf("resolved model mismatch: got=%q", out.ResolvedModel) + } + if !out.Search || out.Thinking { + t.Fatalf("unexpected model flags: thinking=%v search=%v", out.Thinking, out.Search) + } +} + +func TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface(t *testing.T) { + req := map[string]any{ + "model": "deepseek-chat", + "input": "hi", + } + + _, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{ + aliases: map[string]string{}, + wideInput: false, + }, req, "") + if err == nil { + t.Fatal("expected error when wide input is disabled and only input is provided") + } + + out, err := 
normalizeOpenAIResponsesRequest(mockOpenAIConfig{ + aliases: map[string]string{}, + wideInput: true, + }, req, "") + if err != nil { + t.Fatalf("unexpected error when wide input is enabled: %v", err) + } + if out.Surface != "openai_responses" { + t.Fatalf("unexpected surface: %q", out.Surface) + } +} diff --git a/internal/adapter/openai/embeddings_handler.go b/internal/adapter/openai/embeddings_handler.go new file mode 100644 index 0000000..ff61be0 --- /dev/null +++ b/internal/adapter/openai/embeddings_handler.go @@ -0,0 +1,138 @@ +package openai + +import ( + "crypto/sha256" + "encoding/binary" + "encoding/json" + "fmt" + "net/http" + "strings" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/util" +) + +func (h *Handler) Embeddings(w http.ResponseWriter, r *http.Request) { + a, err := h.Auth.Determine(r) + if err != nil { + status := http.StatusUnauthorized + detail := err.Error() + if err == auth.ErrNoAccount { + status = http.StatusTooManyRequests + } + writeOpenAIError(w, status, detail) + return + } + defer h.Auth.Release(a) + + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeOpenAIError(w, http.StatusBadRequest, "invalid json") + return + } + model, _ := req["model"].(string) + model = strings.TrimSpace(model) + if model == "" { + writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model'.") + return + } + if _, ok := config.ResolveModel(h.Store, model); !ok { + writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model)) + return + } + + inputs := extractEmbeddingInputs(req["input"]) + if len(inputs) == 0 { + writeOpenAIError(w, http.StatusBadRequest, "Request must include non-empty 'input'.") + return + } + + provider := "" + if h.Store != nil { + provider = strings.ToLower(strings.TrimSpace(h.Store.EmbeddingsProvider())) + } + if provider == "" { + writeOpenAIError(w, http.StatusNotImplemented, "Embeddings provider is not configured. 
Set embeddings.provider in config.") + return + } + switch provider { + case "mock", "deterministic", "builtin": + // supported local deterministic provider + default: + writeOpenAIError(w, http.StatusNotImplemented, fmt.Sprintf("Embeddings provider '%s' is not supported.", provider)) + return + } + + data := make([]map[string]any, 0, len(inputs)) + totalTokens := 0 + for i, input := range inputs { + totalTokens += util.EstimateTokens(input) + data = append(data, map[string]any{ + "object": "embedding", + "index": i, + "embedding": deterministicEmbedding(input), + }) + } + writeJSON(w, http.StatusOK, map[string]any{ + "object": "list", + "data": data, + "model": model, + "usage": map[string]any{ + "prompt_tokens": totalTokens, + "total_tokens": totalTokens, + }, + }) +} + +func extractEmbeddingInputs(raw any) []string { + switch v := raw.(type) { + case string: + s := strings.TrimSpace(v) + if s == "" { + return nil + } + return []string{s} + case []any: + out := make([]string, 0, len(v)) + for _, item := range v { + switch iv := item.(type) { + case string: + s := strings.TrimSpace(iv) + if s != "" { + out = append(out, s) + } + case []any: + // Token array input support: convert to stable string form. + out = append(out, fmt.Sprintf("%v", iv)) + default: + s := strings.TrimSpace(fmt.Sprintf("%v", iv)) + if s != "" { + out = append(out, s) + } + } + } + return out + default: + return nil + } +} + +func deterministicEmbedding(input string) []float64 { + // Keep response shape stable without external dependencies. 
+ const dims = 64 + out := make([]float64, dims) + seed := sha256.Sum256([]byte(input)) + buf := seed[:] + for i := 0; i < dims; i++ { + if len(buf) < 4 { + next := sha256.Sum256(buf) + buf = next[:] + } + v := binary.BigEndian.Uint32(buf[:4]) + buf = buf[4:] + // map [0, 2^32) -> [-1, 1] + out[i] = (float64(v)/2147483647.5 - 1.0) + } + return out +} diff --git a/internal/adapter/openai/embeddings_route_test.go b/internal/adapter/openai/embeddings_route_test.go new file mode 100644 index 0000000..4395d16 --- /dev/null +++ b/internal/adapter/openai/embeddings_route_test.go @@ -0,0 +1,96 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/account" + "ds2api/internal/auth" + "ds2api/internal/config" +) + +func newResolverWithConfigJSON(t *testing.T, cfgJSON string) (*config.Store, *auth.Resolver) { + t.Helper() + t.Setenv("DS2API_CONFIG_JSON", cfgJSON) + store := config.LoadStore() + pool := account.NewPool(store) + resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "unused", nil + }) + return store, resolver +} + +func TestEmbeddingsRouteContract(t *testing.T) { + store, resolver := newResolverWithConfigJSON(t, `{"embeddings":{"provider":"deterministic"}}`) + h := &Handler{Store: store, Auth: resolver} + r := chi.NewRouter() + RegisterRoutes(r, h) + + t.Run("unauthorized", func(t *testing.T) { + body := bytes.NewBufferString(`{"model":"gpt-4o","input":"hello"}`) + req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String()) + } + }) + + t.Run("ok", func(t *testing.T) { + body := bytes.NewBufferString(`{"model":"gpt-4o","input":["a","b"]}`) + req := 
httptest.NewRequest(http.MethodPost, "/v1/embeddings", body) + req.Header.Set("Authorization", "Bearer test-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode response failed: %v", err) + } + if out["object"] != "list" { + t.Fatalf("unexpected object: %#v", out["object"]) + } + data, _ := out["data"].([]any) + if len(data) != 2 { + t.Fatalf("expected 2 embeddings, got %d", len(data)) + } + }) +} + +func TestEmbeddingsRouteProviderMissing(t *testing.T) { + store, resolver := newResolverWithConfigJSON(t, `{}`) + h := &Handler{Store: store, Auth: resolver} + r := chi.NewRouter() + RegisterRoutes(r, h) + + body := bytes.NewBufferString(`{"model":"gpt-4o","input":"hello"}`) + req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body) + req.Header.Set("Authorization", "Bearer test-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusNotImplemented { + t.Fatalf("expected 501, got %d body=%s", rec.Code, rec.Body.String()) + } + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode response failed: %v", err) + } + errObj, _ := out["error"].(map[string]any) + if _, ok := errObj["code"]; !ok { + t.Fatalf("expected error.code in response: %#v", out) + } + if _, ok := errObj["param"]; !ok { + t.Fatalf("expected error.param in response: %#v", out) + } +} diff --git a/internal/adapter/openai/error_shape_test.go b/internal/adapter/openai/error_shape_test.go new file mode 100644 index 0000000..8c73e4b --- /dev/null +++ b/internal/adapter/openai/error_shape_test.go @@ -0,0 +1,34 @@ +package openai + +import ( + "encoding/json" + "net/http" + 
"net/http/httptest" + "testing" +) + +func TestWriteOpenAIErrorIncludesUnifiedFields(t *testing.T) { + rec := httptest.NewRecorder() + writeOpenAIError(rec, http.StatusBadRequest, "invalid input") + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", rec.Code) + } + + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("decode body: %v", err) + } + errObj, _ := body["error"].(map[string]any) + if errObj["message"] != "invalid input" { + t.Fatalf("unexpected message: %v", errObj["message"]) + } + if errObj["type"] != "invalid_request_error" { + t.Fatalf("unexpected type: %v", errObj["type"]) + } + if errObj["code"] != "invalid_request" { + t.Fatalf("unexpected code: %v", errObj["code"]) + } + if _, ok := errObj["param"]; !ok { + t.Fatal("expected param field") + } +} diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go deleted file mode 100644 index d0a2f1d..0000000 --- a/internal/adapter/openai/handler.go +++ /dev/null @@ -1,487 +0,0 @@ -package openai - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - "sync" - "time" - - "github.com/go-chi/chi/v5" - - "ds2api/internal/auth" - "ds2api/internal/config" - "ds2api/internal/deepseek" - "ds2api/internal/sse" - "ds2api/internal/util" -) - -// writeJSON is a package-internal alias kept to avoid mass-renaming across -// every call-site in this file. It delegates to the shared util version. 
-var writeJSON = util.WriteJSON - -type Handler struct { - Store *config.Store - Auth *auth.Resolver - DS *deepseek.Client - - leaseMu sync.Mutex - streamLeases map[string]streamLease -} - -type streamLease struct { - Auth *auth.RequestAuth - ExpiresAt time.Time -} - -func RegisterRoutes(r chi.Router, h *Handler) { - r.Get("/v1/models", h.ListModels) - r.Post("/v1/chat/completions", h.ChatCompletions) -} - -func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) { - writeJSON(w, http.StatusOK, config.OpenAIModelsResponse()) -} - -func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { - if isVercelStreamReleaseRequest(r) { - h.handleVercelStreamRelease(w, r) - return - } - if isVercelStreamPrepareRequest(r) { - h.handleVercelStreamPrepare(w, r) - return - } - - a, err := h.Auth.Determine(r) - if err != nil { - status := http.StatusUnauthorized - detail := err.Error() - if err == auth.ErrNoAccount { - status = http.StatusTooManyRequests - } - writeOpenAIError(w, status, detail) - return - } - defer h.Auth.Release(a) - r = r.WithContext(auth.WithAuth(r.Context(), a)) - - var req map[string]any - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeOpenAIError(w, http.StatusBadRequest, "invalid json") - return - } - model, _ := req["model"].(string) - messagesRaw, _ := req["messages"].([]any) - if model == "" || len(messagesRaw) == 0 { - writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") - return - } - thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model) - if !ok { - writeOpenAIError(w, http.StatusServiceUnavailable, fmt.Sprintf("Model '%s' is not available.", model)) - return - } - - messages := normalizeMessages(messagesRaw) - toolNames := []string{} - if tools, ok := req["tools"].([]any); ok && len(tools) > 0 { - messages, toolNames = injectToolPrompt(messages, tools) - } - finalPrompt := util.MessagesPrepare(messages) - - sessionID, err := 
h.DS.CreateSession(r.Context(), a, 3) - if err != nil { - if a.UseConfigToken { - writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.") - } else { - writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.") - } - return - } - pow, err := h.DS.GetPow(r.Context(), a, 3) - if err != nil { - writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") - return - } - payload := map[string]any{ - "chat_session_id": sessionID, - "parent_message_id": nil, - "prompt": finalPrompt, - "ref_file_ids": []any{}, - "thinking_enabled": thinkingEnabled, - "search_enabled": searchEnabled, - } - resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) - if err != nil { - writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.") - return - } - if util.ToBool(req["stream"]) { - h.handleStream(w, r, resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) - return - } - h.handleNonStream(w, r.Context(), resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) -} - -func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { - if resp.StatusCode != http.StatusOK { - defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) - writeOpenAIError(w, resp.StatusCode, string(body)) - return - } - _ = ctx - result := sse.CollectStream(resp, thinkingEnabled, true) - - finalThinking := result.Thinking - finalText := result.Text - detected := util.ParseToolCalls(finalText, toolNames) - finishReason := "stop" - messageObj := map[string]any{"role": "assistant", "content": finalText} - if thinkingEnabled && finalThinking != "" { - messageObj["reasoning_content"] = finalThinking - } - if len(detected) > 0 { - finishReason = 
"tool_calls" - messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected) - messageObj["content"] = nil - } - promptTokens := util.EstimateTokens(finalPrompt) - reasoningTokens := util.EstimateTokens(finalThinking) - completionTokens := util.EstimateTokens(finalText) - - writeJSON(w, http.StatusOK, map[string]any{ - "id": completionID, - "object": "chat.completion", - "created": time.Now().Unix(), - "model": model, - "choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}}, - "usage": map[string]any{ - "prompt_tokens": promptTokens, - "completion_tokens": reasoningTokens + completionTokens, - "total_tokens": promptTokens + reasoningTokens + completionTokens, - "completion_tokens_details": map[string]any{ - "reasoning_tokens": reasoningTokens, - }, - }, - }) -} - -func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - writeOpenAIError(w, resp.StatusCode, string(body)) - return - } - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache, no-transform") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("X-Accel-Buffering", "no") - rc := http.NewResponseController(w) - canFlush := rc.Flush() == nil - if !canFlush { - config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered") - } - - created := time.Now().Unix() - firstChunkSent := false - bufferToolContent := len(toolNames) > 0 - var toolSieve toolStreamSieveState - toolCallsEmitted := false - initialType := "text" - if thinkingEnabled { - initialType = "thinking" - } - parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType) - thinking := strings.Builder{} - text := strings.Builder{} - lastContent := 
time.Now() - hasContent := false - keepaliveTicker := time.NewTicker(time.Duration(deepseek.KeepAliveTimeout) * time.Second) - defer keepaliveTicker.Stop() - keepaliveCountWithoutContent := 0 - - sendChunk := func(v any) { - b, _ := json.Marshal(v) - _, _ = w.Write([]byte("data: ")) - _, _ = w.Write(b) - _, _ = w.Write([]byte("\n\n")) - if canFlush { - _ = rc.Flush() - } - } - sendDone := func() { - _, _ = w.Write([]byte("data: [DONE]\n\n")) - if canFlush { - _ = rc.Flush() - } - } - - finalize := func(finishReason string) { - finalThinking := thinking.String() - finalText := text.String() - detected := util.ParseToolCalls(finalText, toolNames) - if len(detected) > 0 && !toolCallsEmitted { - finishReason = "tool_calls" - delta := map[string]any{ - "tool_calls": util.FormatOpenAIStreamToolCalls(detected), - } - if !firstChunkSent { - delta["role"] = "assistant" - firstChunkSent = true - } - sendChunk(map[string]any{ - "id": completionID, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": []map[string]any{{"delta": delta, "index": 0}}, - }) - } else if bufferToolContent { - for _, evt := range flushToolSieve(&toolSieve, toolNames) { - if evt.Content == "" { - continue - } - delta := map[string]any{ - "content": evt.Content, - } - if !firstChunkSent { - delta["role"] = "assistant" - firstChunkSent = true - } - sendChunk(map[string]any{ - "id": completionID, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": []map[string]any{{"delta": delta, "index": 0}}, - }) - } - } - if len(detected) > 0 || toolCallsEmitted { - finishReason = "tool_calls" - } - promptTokens := util.EstimateTokens(finalPrompt) - reasoningTokens := util.EstimateTokens(finalThinking) - completionTokens := util.EstimateTokens(finalText) - sendChunk(map[string]any{ - "id": completionID, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": []map[string]any{{"delta": map[string]any{}, "index": 
0, "finish_reason": finishReason}}, - "usage": map[string]any{ - "prompt_tokens": promptTokens, - "completion_tokens": reasoningTokens + completionTokens, - "total_tokens": promptTokens + reasoningTokens + completionTokens, - "completion_tokens_details": map[string]any{ - "reasoning_tokens": reasoningTokens, - }, - }, - }) - sendDone() - } - - for { - select { - case <-r.Context().Done(): - return - case <-keepaliveTicker.C: - if !hasContent { - keepaliveCountWithoutContent++ - if keepaliveCountWithoutContent >= deepseek.MaxKeepaliveCount { - finalize("stop") - return - } - } - if hasContent && time.Since(lastContent) > time.Duration(deepseek.StreamIdleTimeout)*time.Second { - finalize("stop") - return - } - if canFlush { - _, _ = w.Write([]byte(": keep-alive\n\n")) - _ = rc.Flush() - } - case parsed, ok := <-parsedLines: - if !ok { - // Ensure scanner completion is observed only after all queued - // SSE lines are drained, avoiding early finalize races. - _ = <-done - finalize("stop") - return - } - if !parsed.Parsed { - continue - } - if parsed.ContentFilter || parsed.ErrorMessage != "" { - finalize("content_filter") - return - } - if parsed.Stop { - finalize("stop") - return - } - newChoices := make([]map[string]any, 0, len(parsed.Parts)) - for _, p := range parsed.Parts { - if searchEnabled && sse.IsCitation(p.Text) { - continue - } - if p.Text == "" { - continue - } - hasContent = true - lastContent = time.Now() - keepaliveCountWithoutContent = 0 - delta := map[string]any{} - if !firstChunkSent { - delta["role"] = "assistant" - firstChunkSent = true - } - if p.Type == "thinking" { - if thinkingEnabled { - thinking.WriteString(p.Text) - delta["reasoning_content"] = p.Text - } - } else { - text.WriteString(p.Text) - if !bufferToolContent { - delta["content"] = p.Text - } else { - events := processToolSieveChunk(&toolSieve, p.Text, toolNames) - if len(events) == 0 { - // Keep thinking delta only frame. 
- } - for _, evt := range events { - if len(evt.ToolCalls) > 0 { - toolCallsEmitted = true - tcDelta := map[string]any{ - "tool_calls": util.FormatOpenAIStreamToolCalls(evt.ToolCalls), - } - if !firstChunkSent { - tcDelta["role"] = "assistant" - firstChunkSent = true - } - newChoices = append(newChoices, map[string]any{ - "delta": tcDelta, - "index": 0, - }) - continue - } - if evt.Content != "" { - contentDelta := map[string]any{ - "content": evt.Content, - } - if !firstChunkSent { - contentDelta["role"] = "assistant" - firstChunkSent = true - } - newChoices = append(newChoices, map[string]any{ - "delta": contentDelta, - "index": 0, - }) - } - } - } - } - if len(delta) > 0 { - newChoices = append(newChoices, map[string]any{"delta": delta, "index": 0}) - } - } - if len(newChoices) > 0 { - sendChunk(map[string]any{ - "id": completionID, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": newChoices, - }) - } - } - } -} - -func normalizeMessages(raw []any) []map[string]any { - out := make([]map[string]any, 0, len(raw)) - for _, item := range raw { - m, ok := item.(map[string]any) - if ok { - out = append(out, m) - } - } - return out -} - -func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) { - toolSchemas := make([]string, 0, len(tools)) - names := make([]string, 0, len(tools)) - for _, t := range tools { - tool, ok := t.(map[string]any) - if !ok { - continue - } - fn, _ := tool["function"].(map[string]any) - if len(fn) == 0 { - fn = tool - } - name, _ := fn["name"].(string) - desc, _ := fn["description"].(string) - schema, _ := fn["parameters"].(map[string]any) - if name == "" { - name = "unknown" - } - names = append(names, name) - if desc == "" { - desc = "No description available" - } - b, _ := json.Marshal(schema) - toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b))) - } - if len(toolSchemas) == 0 { - return messages, names 
- } - toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nIMPORTANT: If calling tools, output ONLY the JSON. The response must start with { and end with }" - - for i := range messages { - if messages[i]["role"] == "system" { - old, _ := messages[i]["content"].(string) - messages[i]["content"] = strings.TrimSpace(old + "\n\n" + toolPrompt) - return messages, names - } - } - messages = append([]map[string]any{{"role": "system", "content": toolPrompt}}, messages...) - return messages, names -} - -func writeOpenAIError(w http.ResponseWriter, status int, message string) { - writeJSON(w, status, map[string]any{ - "error": map[string]any{ - "message": message, - "type": openAIErrorType(status), - }, - }) -} - -func openAIErrorType(status int) string { - switch status { - case http.StatusBadRequest: - return "invalid_request_error" - case http.StatusUnauthorized: - return "authentication_error" - case http.StatusForbidden: - return "permission_error" - case http.StatusTooManyRequests: - return "rate_limit_error" - case http.StatusServiceUnavailable: - return "service_unavailable_error" - default: - if status >= 500 { - return "api_error" - } - return "invalid_request_error" - } -} diff --git a/internal/adapter/openai/handler_chat.go b/internal/adapter/openai/handler_chat.go new file mode 100644 index 0000000..26a4bf2 --- /dev/null +++ b/internal/adapter/openai/handler_chat.go @@ -0,0 +1,156 @@ +package openai + +import ( + "context" + "encoding/json" + "io" + "net/http" + "time" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/deepseek" + openaifmt "ds2api/internal/format/openai" + "ds2api/internal/sse" + streamengine "ds2api/internal/stream" +) + +func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { + if 
isVercelStreamReleaseRequest(r) { + h.handleVercelStreamRelease(w, r) + return + } + if isVercelStreamPrepareRequest(r) { + h.handleVercelStreamPrepare(w, r) + return + } + + a, err := h.Auth.Determine(r) + if err != nil { + status := http.StatusUnauthorized + detail := err.Error() + if err == auth.ErrNoAccount { + status = http.StatusTooManyRequests + } + writeOpenAIError(w, status, detail) + return + } + defer h.Auth.Release(a) + r = r.WithContext(auth.WithAuth(r.Context(), a)) + + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeOpenAIError(w, http.StatusBadRequest, "invalid json") + return + } + stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r)) + if err != nil { + writeOpenAIError(w, http.StatusBadRequest, err.Error()) + return + } + + sessionID, err := h.DS.CreateSession(r.Context(), a, 3) + if err != nil { + if a.UseConfigToken { + writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.") + } else { + writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. 
If this should be a DS2API key, add it to config.keys first.") + } + return + } + pow, err := h.DS.GetPow(r.Context(), a, 3) + if err != nil { + writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") + return + } + payload := stdReq.CompletionPayload(sessionID) + resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) + if err != nil { + writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.") + return + } + if stdReq.Stream { + h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) + return + } + h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames) +} + +func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) { + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + writeOpenAIError(w, resp.StatusCode, string(body)) + return + } + _ = ctx + result := sse.CollectStream(resp, thinkingEnabled, true) + + finalThinking := result.Thinking + finalText := result.Text + respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames) + writeJSON(w, http.StatusOK, respBody) +} + +func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + writeOpenAIError(w, resp.StatusCode, string(body)) + return + } + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + 
rc := http.NewResponseController(w) + _, canFlush := w.(http.Flusher) + if !canFlush { + config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered") + } + + created := time.Now().Unix() + bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled() + emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence() + initialType := "text" + if thinkingEnabled { + initialType = "thinking" + } + + streamRuntime := newChatStreamRuntime( + w, + rc, + canFlush, + completionID, + created, + model, + finalPrompt, + thinkingEnabled, + searchEnabled, + toolNames, + bufferToolContent, + emitEarlyToolDeltas, + ) + + streamengine.ConsumeSSE(streamengine.ConsumeConfig{ + Context: r.Context(), + Body: resp.Body, + ThinkingEnabled: thinkingEnabled, + InitialType: initialType, + KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second, + IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second, + MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount, + }, streamengine.ConsumeHooks{ + OnKeepAlive: func() { + streamRuntime.sendKeepAlive() + }, + OnParsed: streamRuntime.onParsed, + OnFinalize: func(reason streamengine.StopReason, _ error) { + if string(reason) == "content_filter" { + streamRuntime.finalize("content_filter") + return + } + streamRuntime.finalize("stop") + }, + }) +} diff --git a/internal/adapter/openai/handler_errors.go b/internal/adapter/openai/handler_errors.go new file mode 100644 index 0000000..2e60d73 --- /dev/null +++ b/internal/adapter/openai/handler_errors.go @@ -0,0 +1,63 @@ +package openai + +import "net/http" + +func writeOpenAIError(w http.ResponseWriter, status int, message string) { + writeOpenAIErrorWithCode(w, status, message, "") +} + +func writeOpenAIErrorWithCode(w http.ResponseWriter, status int, message, code string) { + if code == "" { + code = openAIErrorCode(status) + } + writeJSON(w, status, map[string]any{ + "error": map[string]any{ + "message": message, + "type": 
openAIErrorType(status), + "code": code, + "param": nil, + }, + }) +} + +func openAIErrorType(status int) string { + switch status { + case http.StatusBadRequest: + return "invalid_request_error" + case http.StatusUnauthorized: + return "authentication_error" + case http.StatusForbidden: + return "permission_error" + case http.StatusTooManyRequests: + return "rate_limit_error" + case http.StatusServiceUnavailable: + return "service_unavailable_error" + default: + if status >= 500 { + return "api_error" + } + return "invalid_request_error" + } +} + +func openAIErrorCode(status int) string { + switch status { + case http.StatusBadRequest: + return "invalid_request" + case http.StatusUnauthorized: + return "authentication_failed" + case http.StatusForbidden: + return "forbidden" + case http.StatusTooManyRequests: + return "rate_limit_exceeded" + case http.StatusNotFound: + return "not_found" + case http.StatusServiceUnavailable: + return "service_unavailable" + default: + if status >= 500 { + return "internal_error" + } + return "invalid_request" + } +} diff --git a/internal/adapter/openai/handler_routes.go b/internal/adapter/openai/handler_routes.go new file mode 100644 index 0000000..a0cfcd6 --- /dev/null +++ b/internal/adapter/openai/handler_routes.go @@ -0,0 +1,57 @@ +package openai + +import ( + "net/http" + "strings" + "sync" + "time" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/util" +) + +// writeJSON is a package-internal alias kept to avoid mass-renaming across +// every call-site in this package. 
+var writeJSON = util.WriteJSON + +type Handler struct { + Store ConfigReader + Auth AuthResolver + DS DeepSeekCaller + + leaseMu sync.Mutex + streamLeases map[string]streamLease + responsesMu sync.Mutex + responses *responseStore +} + +type streamLease struct { + Auth *auth.RequestAuth + ExpiresAt time.Time +} + +func RegisterRoutes(r chi.Router, h *Handler) { + r.Get("/v1/models", h.ListModels) + r.Get("/v1/models/{model_id}", h.GetModel) + r.Post("/v1/chat/completions", h.ChatCompletions) + r.Post("/v1/responses", h.Responses) + r.Get("/v1/responses/{response_id}", h.GetResponseByID) + r.Post("/v1/embeddings", h.Embeddings) +} + +func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, config.OpenAIModelsResponse()) +} + +func (h *Handler) GetModel(w http.ResponseWriter, r *http.Request) { + modelID := strings.TrimSpace(chi.URLParam(r, "model_id")) + model, ok := config.OpenAIModelByID(h.Store, modelID) + if !ok { + writeOpenAIError(w, http.StatusNotFound, "Model not found.") + return + } + writeJSON(w, http.StatusOK, model) +} diff --git a/internal/adapter/openai/handler_toolcall_format.go b/internal/adapter/openai/handler_toolcall_format.go new file mode 100644 index 0000000..37ebaf9 --- /dev/null +++ b/internal/adapter/openai/handler_toolcall_format.go @@ -0,0 +1,171 @@ +package openai + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/google/uuid" + + "ds2api/internal/util" +) + +func injectToolPrompt(messages []map[string]any, tools []any, policy util.ToolChoicePolicy) ([]map[string]any, []string) { + if policy.IsNone() { + return messages, nil + } + toolSchemas := make([]string, 0, len(tools)) + names := make([]string, 0, len(tools)) + isAllowed := func(name string) bool { + if strings.TrimSpace(name) == "" { + return false + } + if len(policy.Allowed) == 0 { + return true + } + _, ok := policy.Allowed[name] + return ok + } + + for _, t := range tools { + tool, ok := t.(map[string]any) + if !ok { + 
continue + } + fn, _ := tool["function"].(map[string]any) + if len(fn) == 0 { + fn = tool + } + name, _ := fn["name"].(string) + desc, _ := fn["description"].(string) + schema, _ := fn["parameters"].(map[string]any) + name = strings.TrimSpace(name) + if !isAllowed(name) { + continue + } + names = append(names, name) + if desc == "" { + desc = "No description available" + } + b, _ := json.Marshal(schema) + toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b))) + } + if len(toolSchemas) == 0 { + return messages, names + } + toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block." + if policy.Mode == util.ToolChoiceRequired { + toolPrompt += "\n5) For this response, you MUST call at least one tool from the allowed list." + } + if policy.Mode == util.ToolChoiceForced && strings.TrimSpace(policy.ForcedName) != "" { + toolPrompt += "\n5) For this response, you MUST call exactly this tool name: " + strings.TrimSpace(policy.ForcedName) + toolPrompt += "\n6) Do not call any other tool." 
+ } + + for i := range messages { + if messages[i]["role"] == "system" { + old, _ := messages[i]["content"].(string) + messages[i]["content"] = strings.TrimSpace(old + "\n\n" + toolPrompt) + return messages, names + } + } + messages = append([]map[string]any{{"role": "system", "content": toolPrompt}}, messages...) + return messages, names +} + +func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]string) []map[string]any { + if len(deltas) == 0 { + return nil + } + out := make([]map[string]any, 0, len(deltas)) + for _, d := range deltas { + if d.Name == "" && d.Arguments == "" { + continue + } + callID, ok := ids[d.Index] + if !ok || callID == "" { + callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "") + ids[d.Index] = callID + } + item := map[string]any{ + "index": d.Index, + "id": callID, + "type": "function", + } + fn := map[string]any{} + if d.Name != "" { + fn["name"] = d.Name + } + if d.Arguments != "" { + fn["arguments"] = d.Arguments + } + if len(fn) > 0 { + item["function"] = fn + } + out = append(out, item) + } + return out +} + +func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, allowedNames []string, seenNames map[int]string) []toolCallDelta { + if len(deltas) == 0 { + return nil + } + allowed := namesToSet(allowedNames) + if len(allowed) == 0 { + for _, d := range deltas { + if d.Name != "" { + seenNames[d.Index] = "__blocked__" + } + } + return nil + } + out := make([]toolCallDelta, 0, len(deltas)) + for _, d := range deltas { + if d.Name != "" { + if _, ok := allowed[d.Name]; !ok { + seenNames[d.Index] = "__blocked__" + continue + } + seenNames[d.Index] = d.Name + out = append(out, d) + continue + } + name := strings.TrimSpace(seenNames[d.Index]) + if name == "" || name == "__blocked__" { + continue + } + out = append(out, d) + } + return out +} + +func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any { + if len(calls) == 0 { + return nil 
+ } + out := make([]map[string]any, 0, len(calls)) + for i, c := range calls { + callID := "" + if ids != nil { + callID = strings.TrimSpace(ids[i]) + } + if callID == "" { + callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "") + if ids != nil { + ids[i] = callID + } + } + args, _ := json.Marshal(c.Input) + out = append(out, map[string]any{ + "index": i, + "id": callID, + "type": "function", + "function": map[string]any{ + "name": c.Name, + "arguments": string(args), + }, + }) + } + return out +} diff --git a/internal/adapter/openai/handler_toolcall_policy.go b/internal/adapter/openai/handler_toolcall_policy.go new file mode 100644 index 0000000..9f0e839 --- /dev/null +++ b/internal/adapter/openai/handler_toolcall_policy.go @@ -0,0 +1,25 @@ +package openai + +import "strings" + +func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) { + for k, v := range collectOpenAIChatPassThrough(req) { + payload[k] = v + } +} + +func (h *Handler) toolcallFeatureMatchEnabled() bool { + if h == nil || h.Store == nil { + return true + } + mode := strings.TrimSpace(strings.ToLower(h.Store.ToolcallMode())) + return mode == "" || mode == "feature_match" +} + +func (h *Handler) toolcallEarlyEmitHighConfidence() bool { + if h == nil || h.Store == nil { + return true + } + level := strings.TrimSpace(strings.ToLower(h.Store.ToolcallEarlyEmitConfidence())) + return level == "" || level == "high" +} diff --git a/internal/adapter/openai/handler_toolcall_test.go b/internal/adapter/openai/handler_toolcall_test.go index f9c44dd..895605f 100644 --- a/internal/adapter/openai/handler_toolcall_test.go +++ b/internal/adapter/openai/handler_toolcall_test.go @@ -100,6 +100,26 @@ func streamFinishReason(frames []map[string]any) string { return "" } +func streamToolCallArgumentChunks(frames []map[string]any) []string { + out := make([]string, 0, 4) + for _, frame := range frames { + choices, _ := frame["choices"].([]any) + for _, item := range choices { + choice, _ := 
item.(map[string]any) + delta, _ := choice["delta"].(map[string]any) + toolCalls, _ := delta["tool_calls"].([]any) + for _, tc := range toolCalls { + tcm, _ := tc.(map[string]any) + fn, _ := tcm["function"].(map[string]any) + if args, ok := fn["arguments"].(string); ok && args != "" { + out = append(out, args) + } + } + } + } + return out +} + func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( @@ -108,7 +128,7 @@ func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) { ) rec := httptest.NewRecorder() - h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, false, []string{"search"}) + h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, []string{"search"}) if rec.Code != http.StatusOK { t.Fatalf("unexpected status: %d", rec.Code) } @@ -141,7 +161,7 @@ func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) { ) rec := httptest.NewRecorder() - h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, false, []string{"search"}) + h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, []string{"search"}) if rec.Code != http.StatusOK { t.Fatalf("unexpected status: %d", rec.Code) } @@ -161,7 +181,7 @@ func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) { } } -func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) { +func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`, @@ -169,7 +189,38 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) { ) rec := httptest.NewRecorder() - h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, false, []string{"search"}) + h.handleNonStream(rec, 
context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, []string{"search"}) + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: %d", rec.Code) + } + + out := decodeJSONBody(t, rec.Body.String()) + choices, _ := out["choices"].([]any) + choice, _ := choices[0].(map[string]any) + if choice["finish_reason"] != "stop" { + t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"]) + } + msg, _ := choice["message"].(map[string]any) + if _, ok := msg["tool_calls"]; ok { + t.Fatalf("did not expect tool_calls for unknown schema name, got %#v", msg["tool_calls"]) + } + content, _ := msg["content"].(string) + if !strings.Contains(content, `"tool_calls"`) { + t.Fatalf("expected unknown tool json to pass through as text, got %#v", content) + } +} + +func TestHandleNonStreamEmbeddedToolCallExampleIntercepted(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"下面是示例:"}`, + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`, + `data: {"p":"response/content","v":"请勿执行。"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + + h.handleNonStream(rec, context.Background(), resp, "cid2c", "deepseek-chat", "prompt", false, []string{"search"}) if rec.Code != http.StatusOK { t.Fatalf("unexpected status: %d", rec.Code) } @@ -181,12 +232,41 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) { t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"]) } msg, _ := choice["message"].(map[string]any) - if msg["content"] != nil { - t.Fatalf("expected content nil, got %#v", msg["content"]) - } toolCalls, _ := msg["tool_calls"].([]any) - if len(toolCalls) != 1 { - t.Fatalf("expected 1 tool call, got %#v", msg["tool_calls"]) + if len(toolCalls) == 0 { + t.Fatalf("expected tool_calls field for embedded example: %#v", msg["tool_calls"]) + } + if msg["content"] != nil { + t.Fatalf("expected content nil when 
tool_calls detected, got %#v", msg["content"]) + } +} + +func TestHandleNonStreamFencedToolCallExampleNotIntercepted(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + "data: {\"p\":\"response/content\",\"v\":\"```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"search\\\",\\\"input\\\":{\\\"q\\\":\\\"go\\\"}}]}\\n```\"}", + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + + h.handleNonStream(rec, context.Background(), resp, "cid2d", "deepseek-chat", "prompt", false, []string{"search"}) + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: %d", rec.Code) + } + + out := decodeJSONBody(t, rec.Body.String()) + choices, _ := out["choices"].([]any) + choice, _ := choices[0].(map[string]any) + if choice["finish_reason"] != "stop" { + t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"]) + } + msg, _ := choice["message"].(map[string]any) + if _, ok := msg["tool_calls"]; ok { + t.Fatalf("did not expect tool_calls field for fenced example: %#v", msg["tool_calls"]) + } + content, _ := msg["content"].(string) + if !strings.Contains(content, "```json") || !strings.Contains(content, `"tool_calls"`) { + t.Fatalf("expected fenced tool example to pass through as text, got %q", content) } } @@ -295,7 +375,7 @@ func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing. 
} } -func TestHandleStreamUnknownToolStillIntercepted(t *testing.T) { +func TestHandleStreamUnknownToolDoesNotLeakRawPayload(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`, @@ -310,29 +390,40 @@ func TestHandleStreamUnknownToolStillIntercepted(t *testing.T) { if !done { t.Fatalf("expected [DONE], body=%s", rec.Body.String()) } - if !streamHasToolCallsDelta(frames) { - t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) - } - foundToolIndex := false - for _, frame := range frames { - choices, _ := frame["choices"].([]any) - for _, item := range choices { - choice, _ := item.(map[string]any) - delta, _ := choice["delta"].(map[string]any) - toolCalls, _ := delta["tool_calls"].([]any) - for _, tc := range toolCalls { - tcm, _ := tc.(map[string]any) - if _, ok := tcm["index"].(float64); ok { - foundToolIndex = true - } - } - } - } - if !foundToolIndex { - t.Fatalf("expected stream tool_calls item with index, body=%s", rec.Body.String()) + if streamHasToolCallsDelta(frames) { + t.Fatalf("did not expect tool_calls delta for unknown schema name, body=%s", rec.Body.String()) } if streamHasRawToolJSONContent(frames) { - t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String()) + t.Fatalf("did not expect raw tool_calls json leak for unknown schema name: %s", rec.Body.String()) + } + if streamFinishReason(frames) != "stop" { + t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String()) + } +} + +func TestHandleStreamUnknownToolNoArgsDoesNotLeakRawPayload(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\"}]}"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + h.handleStream(rec, req, resp, "cid5b", "deepseek-chat", "prompt", false, 
false, []string{"search"}) + + frames, done := parseSSEDataFrames(t, rec.Body.String()) + if !done { + t.Fatalf("expected [DONE], body=%s", rec.Body.String()) + } + if streamHasToolCallsDelta(frames) { + t.Fatalf("did not expect tool_calls delta for unknown schema name (no args), body=%s", rec.Body.String()) + } + if streamHasRawToolJSONContent(frames) { + t.Fatalf("did not expect raw tool_calls json leak for unknown schema name (no args): %s", rec.Body.String()) + } + if streamFinishReason(frames) != "stop" { + t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String()) } } @@ -377,9 +468,9 @@ func TestHandleStreamToolsPlainTextStreamsBeforeFinish(t *testing.T) { func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( - `data: {"p":"response/content","v":"前置正文A。"}`, + `data: {"p":"response/content","v":"下面是示例:"}`, `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`, - `data: {"p":"response/content","v":"后置正文B。"}`, + `data: {"p":"response/content","v":"请勿执行。"}`, `data: [DONE]`, ) rec := httptest.NewRecorder() @@ -392,10 +483,7 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) { t.Fatalf("expected [DONE], body=%s", rec.Body.String()) } if !streamHasToolCallsDelta(frames) { - t.Fatalf("expected tool_calls delta in mixed stream, body=%s", rec.Body.String()) - } - if streamHasRawToolJSONContent(frames) { - t.Fatalf("raw tool_calls JSON leaked in mixed stream: %s", rec.Body.String()) + t.Fatalf("expected tool_calls delta in mixed prose stream, body=%s", rec.Body.String()) } content := strings.Builder{} for _, frame := range frames { @@ -409,9 +497,95 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) { } } got := content.String() - if !strings.Contains(got, "前置正文A。") || !strings.Contains(got, "后置正文B。") { + if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") { t.Fatalf("expected pre/post 
plain text to pass sieve, got=%q", got) } + if strings.Contains(strings.ToLower(got), `"tool_calls"`) { + t.Fatalf("expected no raw tool_calls json leak in content, got=%q", got) + } + if streamFinishReason(frames) != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String()) + } +} + +func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"我将调用工具。"}`, + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + h.handleStream(rec, req, resp, "cid7b", "deepseek-chat", "prompt", false, false, []string{"search"}) + + frames, done := parseSSEDataFrames(t, rec.Body.String()) + if !done { + t.Fatalf("expected [DONE], body=%s", rec.Body.String()) + } + if !streamHasToolCallsDelta(frames) { + t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) + } + content := strings.Builder{} + for _, frame := range frames { + choices, _ := frame["choices"].([]any) + for _, item := range choices { + choice, _ := item.(map[string]any) + delta, _ := choice["delta"].(map[string]any) + if c, ok := delta["content"].(string); ok { + content.WriteString(c) + } + } + } + got := content.String() + if !strings.Contains(got, "我将调用工具。") { + t.Fatalf("expected leading text to keep streaming, got=%q", got) + } + if strings.Contains(strings.ToLower(got), "tool_calls") { + t.Fatalf("unexpected raw tool json leak, got=%q", got) + } + if streamFinishReason(frames) != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) + } +} + +func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: 
{"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}接下来我会继续说明。"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + h.handleStream(rec, req, resp, "cid7c", "deepseek-chat", "prompt", false, false, []string{"search"}) + + frames, done := parseSSEDataFrames(t, rec.Body.String()) + if !done { + t.Fatalf("expected [DONE], body=%s", rec.Body.String()) + } + if !streamHasToolCallsDelta(frames) { + t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) + } + content := strings.Builder{} + for _, frame := range frames { + choices, _ := frame["choices"].([]any) + for _, item := range choices { + choice, _ := item.(map[string]any) + delta, _ := choice["delta"].(map[string]any) + if c, ok := delta["content"].(string); ok { + content.WriteString(c) + } + } + } + got := content.String() + if !strings.Contains(got, "接下来我会继续说明。") { + t.Fatalf("expected trailing plain text to be preserved, got=%q", got) + } + if strings.Contains(strings.ToLower(got), "tool_calls") { + t.Fatalf("unexpected raw tool json leak, got=%q", got) + } if streamFinishReason(frames) != "tool_calls" { t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) } @@ -495,16 +669,16 @@ func TestHandleStreamInvalidToolJSONDoesNotLeakRawObject(t *testing.T) { } } } - got := strings.ToLower(content.String()) - if strings.Contains(got, "tool_calls") { - t.Fatalf("unexpected raw tool_calls leak in content: %q", content.String()) - } - if !strings.Contains(content.String(), "前置正文D。") || !strings.Contains(content.String(), "后置正文E。") { + got := content.String() + if !strings.Contains(got, "前置正文D。") || !strings.Contains(got, "后置正文E。") { t.Fatalf("expected pre/post plain text to remain, got=%q", content.String()) } + if !strings.Contains(strings.ToLower(got), "tool_calls") { + t.Fatalf("expected invalid embedded tool-like json to pass through as text, got=%q", got) + } } 
-func TestHandleStreamIncompleteCapturedToolJSONDoesNotLeakOnFinalize(t *testing.T) { +func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\""}`, @@ -533,7 +707,112 @@ func TestHandleStreamIncompleteCapturedToolJSONDoesNotLeakOnFinalize(t *testing. } } } - if strings.Contains(strings.ToLower(content.String()), "tool_calls") || strings.Contains(content.String(), "{") { - t.Fatalf("unexpected incomplete tool json leak in content: %q", content.String()) + if !strings.Contains(strings.ToLower(content.String()), "tool_calls") || !strings.Contains(content.String(), "{") { + t.Fatalf("expected incomplete capture to flush as plain text instead of stalling, got=%q", content.String()) + } +} + +func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`, + `data: {"p":"response/content","v":"lang\",\"page\":1}}]}"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + h.handleStream(rec, req, resp, "cid11", "deepseek-chat", "prompt", false, false, []string{"search"}) + + frames, done := parseSSEDataFrames(t, rec.Body.String()) + if !done { + t.Fatalf("expected [DONE], body=%s", rec.Body.String()) + } + if !streamHasToolCallsDelta(frames) { + t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) + } + if streamHasRawToolJSONContent(frames) { + t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String()) + } + argChunks := streamToolCallArgumentChunks(frames) + if len(argChunks) < 2 { + t.Fatalf("expected incremental arguments chunks, got=%v body=%s", argChunks, rec.Body.String()) + } + joined := strings.Join(argChunks, "") + if !strings.Contains(joined, 
`"q":"golang"`) || !strings.Contains(joined, `"page":1`) { + t.Fatalf("unexpected merged arguments stream: %q", joined) + } + if streamFinishReason(frames) != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) + } +} + +func TestHandleStreamMultiToolCallDoesNotMergeNamesOrArguments(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search_web\",\"input\":{\"query\":\"latest ai news\"}},{"}`, + `data: {"p":"response/content","v":"\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + h.handleStream(rec, req, resp, "cid12", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"}) + + frames, done := parseSSEDataFrames(t, rec.Body.String()) + if !done { + t.Fatalf("expected [DONE], body=%s", rec.Body.String()) + } + if !streamHasToolCallsDelta(frames) { + t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) + } + + foundSearch := false + foundEval := false + foundIndex1 := false + toolCallsDeltaLens := make([]int, 0, 2) + for _, frame := range frames { + choices, _ := frame["choices"].([]any) + for _, item := range choices { + choice, _ := item.(map[string]any) + delta, _ := choice["delta"].(map[string]any) + rawToolCalls, hasToolCalls := delta["tool_calls"] + if !hasToolCalls { + continue + } + toolCalls, _ := rawToolCalls.([]any) + toolCallsDeltaLens = append(toolCallsDeltaLens, len(toolCalls)) + for _, tc := range toolCalls { + tcm, _ := tc.(map[string]any) + if idx, ok := tcm["index"].(float64); ok && int(idx) == 1 { + foundIndex1 = true + } + fn, _ := tcm["function"].(map[string]any) + name, _ := fn["name"].(string) + switch name { + case "search_web": + foundSearch = true + case "eval_javascript": + foundEval = true + case "search_webeval_javascript": + 
t.Fatalf("unexpected merged tool name: %s, body=%s", name, rec.Body.String()) + } + if args, ok := fn["arguments"].(string); ok && strings.Contains(args, `}{"`) { + t.Fatalf("unexpected concatenated tool arguments: %q, body=%s", args, rec.Body.String()) + } + } + } + } + if !foundSearch || !foundEval { + t.Fatalf("expected both tool names in stream deltas, foundSearch=%v foundEval=%v body=%s", foundSearch, foundEval, rec.Body.String()) + } + if len(toolCallsDeltaLens) != 1 || toolCallsDeltaLens[0] != 2 { + t.Fatalf("expected exactly one tool_calls delta with two calls, got lens=%v body=%s", toolCallsDeltaLens, rec.Body.String()) + } + if !foundIndex1 { + t.Fatalf("expected second tool call index in stream deltas, body=%s", rec.Body.String()) + } + if streamFinishReason(frames) != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) } } diff --git a/internal/adapter/openai/message_normalize.go b/internal/adapter/openai/message_normalize.go new file mode 100644 index 0000000..94b2339 --- /dev/null +++ b/internal/adapter/openai/message_normalize.go @@ -0,0 +1,270 @@ +package openai + +import ( + "encoding/json" + "fmt" + "io" + "strings" + + "ds2api/internal/config" +) + +func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any { + out := make([]map[string]any, 0, len(raw)) + for _, item := range raw { + msg, ok := item.(map[string]any) + if !ok { + continue + } + role := strings.ToLower(strings.TrimSpace(asString(msg["role"]))) + switch role { + case "assistant": + content := normalizeOpenAIContentForPrompt(msg["content"]) + toolCalls := formatAssistantToolCallsForPrompt(msg, traceID) + combined := joinNonEmpty(content, toolCalls) + if combined == "" { + continue + } + out = append(out, map[string]any{ + "role": "assistant", + "content": combined, + }) + case "tool", "function": + out = append(out, map[string]any{ + "role": "user", + "content": formatToolResultForPrompt(msg), + }) + case "user", "system": 
+ out = append(out, map[string]any{ + "role": role, + "content": normalizeOpenAIContentForPrompt(msg["content"]), + }) + default: + content := normalizeOpenAIContentForPrompt(msg["content"]) + if content == "" { + continue + } + if role == "" { + role = "user" + } + out = append(out, map[string]any{ + "role": role, + "content": content, + }) + } + } + return out +} + +func formatAssistantToolCallsForPrompt(msg map[string]any, traceID string) string { + entries := make([]string, 0) + if calls, ok := msg["tool_calls"].([]any); ok { + for i, item := range calls { + call, ok := item.(map[string]any) + if !ok { + continue + } + id := strings.TrimSpace(asString(call["id"])) + if id == "" { + id = fmt.Sprintf("call_%d", i+1) + } + name := strings.TrimSpace(asString(call["name"])) + args := "" + + if fn, ok := call["function"].(map[string]any); ok { + if name == "" { + name = strings.TrimSpace(asString(fn["name"])) + } + args = normalizeOpenAIArgumentsForPrompt(fn["arguments"]) + } + if name == "" { + name = "unknown" + } + if args == "" { + args = normalizeOpenAIArgumentsForPrompt(call["arguments"]) + } + if args == "" { + args = normalizeOpenAIArgumentsForPrompt(call["input"]) + } + if args == "" { + args = "{}" + } + maybeWarnSuspiciousToolHistory(traceID, id, name, args) + entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: %s\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", id, name, args)) + } + } + + if legacy, ok := msg["function_call"].(map[string]any); ok { + name := strings.TrimSpace(asString(legacy["name"])) + if name == "" { + name = "unknown" + } + args := normalizeOpenAIArgumentsForPrompt(legacy["arguments"]) + if args == "" { + args = "{}" + } + maybeWarnSuspiciousToolHistory(traceID, "call_legacy", name, args) + entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: 
call_legacy\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", name, args)) + } + + return strings.Join(entries, "\n\n") +} + +func formatToolResultForPrompt(msg map[string]any) string { + toolCallID := strings.TrimSpace(asString(msg["tool_call_id"])) + if toolCallID == "" { + toolCallID = strings.TrimSpace(asString(msg["id"])) + } + if toolCallID == "" { + toolCallID = "unknown" + } + + name := strings.TrimSpace(asString(msg["name"])) + if name == "" { + name = "unknown" + } + + content := normalizeOpenAIContentForPrompt(msg["content"]) + if content == "" { + content = "null" + } + + return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content) +} + +func normalizeOpenAIContentForPrompt(v any) string { + switch x := v.(type) { + case string: + return x + case []any: + parts := make([]string, 0, len(x)) + for _, item := range x { + m, ok := item.(map[string]any) + if !ok { + continue + } + t := strings.ToLower(strings.TrimSpace(asString(m["type"]))) + if t != "text" && t != "output_text" && t != "input_text" { + continue + } + if text := asString(m["text"]); text != "" { + parts = append(parts, text) + continue + } + if text := asString(m["content"]); text != "" { + parts = append(parts, text) + } + } + return strings.Join(parts, "\n") + default: + return marshalToPromptString(v) + } +} + +func normalizeOpenAIArgumentsForPrompt(v any) string { + switch x := v.(type) { + case string: + return normalizeToolArgumentString(x) + default: + return marshalToPromptString(v) + } +} + +func normalizeToolArgumentString(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + if !looksLikeConcatenatedJSON(trimmed) { + return trimmed + } + dec := json.NewDecoder(strings.NewReader(trimmed)) + values := make([]any, 0, 2) + for { + var v any + if err := dec.Decode(&v); err != nil { + 
if err == io.EOF { + break + } + return trimmed + } + values = append(values, v) + } + if len(values) < 2 { + return trimmed + } + last := values[len(values)-1] + b, err := json.Marshal(last) + if err != nil || len(b) == 0 { + return trimmed + } + return string(b) +} + +func marshalToPromptString(v any) string { + b, err := json.Marshal(v) + if err != nil { + return strings.TrimSpace(fmt.Sprintf("%v", v)) + } + return string(b) +} + +func asString(v any) string { + if s, ok := v.(string); ok { + return s + } + return "" +} + +func joinNonEmpty(parts ...string) string { + nonEmpty := make([]string, 0, len(parts)) + for _, p := range parts { + if strings.TrimSpace(p) == "" { + continue + } + nonEmpty = append(nonEmpty, p) + } + return strings.Join(nonEmpty, "\n\n") +} + +func maybeWarnSuspiciousToolHistory(traceID, callID, name, args string) { + if !looksLikeConcatenatedJSON(args) { + return + } + traceID = strings.TrimSpace(traceID) + if traceID == "" { + traceID = "unknown" + } + config.Logger.Warn( + "[openai] suspicious tool call history payload detected", + "trace_id", traceID, + "tool_call_id", strings.TrimSpace(callID), + "name", strings.TrimSpace(name), + "arguments_preview", previewToolArgs(args, 160), + ) +} + +func looksLikeConcatenatedJSON(raw string) bool { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return false + } + if strings.Contains(trimmed, "}{") || strings.Contains(trimmed, "][") { + return true + } + dec := json.NewDecoder(strings.NewReader(trimmed)) + var first any + if err := dec.Decode(&first); err != nil { + return false + } + var second any + return dec.Decode(&second) == nil +} + +func previewToolArgs(raw string, max int) string { + trimmed := strings.TrimSpace(raw) + if max <= 0 || len(trimmed) <= max { + return trimmed + } + return trimmed[:max] +} diff --git a/internal/adapter/openai/message_normalize_test.go b/internal/adapter/openai/message_normalize_test.go new file mode 100644 index 0000000..ff36bd9 --- /dev/null +++ 
b/internal/adapter/openai/message_normalize_test.go @@ -0,0 +1,198 @@ +package openai + +import ( + "strings" + "testing" + + "ds2api/internal/util" +) + +func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *testing.T) { + raw := []any{ + map[string]any{"role": "system", "content": "You are helpful"}, + map[string]any{"role": "user", "content": "查北京天气"}, + map[string]any{ + "role": "assistant", + "content": nil, + "tool_calls": []any{ + map[string]any{ + "id": "call_1", + "type": "function", + "function": map[string]any{ + "name": "get_weather", + "arguments": "{\"city\":\"beijing\"}", + }, + }, + }, + }, + map[string]any{ + "role": "tool", + "tool_call_id": "call_1", + "name": "get_weather", + "content": "{\"temp\":18}", + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw, "") + if len(normalized) != 4 { + t.Fatalf("expected 4 normalized messages, got %d", len(normalized)) + } + assistantContent, _ := normalized[2]["content"].(string) + if !strings.Contains(assistantContent, "[TOOL_CALL_HISTORY]") || + !strings.Contains(assistantContent, "tool_call_id: call_1") || + !strings.Contains(assistantContent, "function.name: get_weather") || + !strings.Contains(assistantContent, "function.arguments: {\"city\":\"beijing\"}") { + t.Fatalf("assistant tool call not serialized correctly: %q", assistantContent) + } + toolContent, _ := normalized[3]["content"].(string) + if !strings.Contains(toolContent, "[TOOL_RESULT_HISTORY]") || !strings.Contains(toolContent, "name: get_weather") { + t.Fatalf("tool result not serialized correctly: %q", toolContent) + } + + prompt := util.MessagesPrepare(normalized) + if !strings.Contains(prompt, "tool_call_id: call_1") || !strings.Contains(prompt, "[TOOL_RESULT_HISTORY]") { + t.Fatalf("expected prompt to include tool call + result semantics: %q", prompt) + } +} + +func TestNormalizeOpenAIMessagesForPrompt_ToolObjectContentPreserved(t *testing.T) { + raw := []any{ + map[string]any{ + "role": "tool", + 
"tool_call_id": "call_2", + "name": "get_weather", + "content": map[string]any{ + "temp": 18, + "condition": "sunny", + }, + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw, "") + got, _ := normalized[0]["content"].(string) + if !strings.Contains(got, `"temp":18`) || !strings.Contains(got, `"condition":"sunny"`) { + t.Fatalf("expected serialized object in tool content, got %q", got) + } +} + +func TestNormalizeOpenAIMessagesForPrompt_ToolArrayBlocksJoined(t *testing.T) { + raw := []any{ + map[string]any{ + "role": "tool", + "tool_call_id": "call_3", + "name": "read_file", + "content": []any{ + map[string]any{"type": "input_text", "text": "line-1"}, + map[string]any{"type": "output_text", "text": "line-2"}, + map[string]any{"type": "image_url", "image_url": "https://example.com/a.png"}, + }, + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw, "") + got, _ := normalized[0]["content"].(string) + if !strings.Contains(got, "line-1\nline-2") { + t.Fatalf("expected joined text blocks, got %q", got) + } +} + +func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) { + raw := []any{ + map[string]any{ + "role": "function", + "tool_call_id": "call_4", + "name": "legacy_tool", + "content": map[string]any{ + "ok": true, + }, + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw, "") + if len(normalized) != 1 { + t.Fatalf("expected one normalized message, got %d", len(normalized)) + } + if normalized[0]["role"] != "user" { + t.Fatalf("expected function role mapped to user, got %#v", normalized[0]["role"]) + } + got, _ := normalized[0]["content"].(string) + if !strings.Contains(got, "name: legacy_tool") || !strings.Contains(got, `"ok":true`) { + t.Fatalf("unexpected normalized function-role content: %q", got) + } +} + +func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSeparated(t *testing.T) { + raw := []any{ + map[string]any{ + "role": "assistant", + "tool_calls": []any{ + map[string]any{ + "id": 
"call_search", + "type": "function", + "function": map[string]any{ + "name": "search_web", + "arguments": `{"query":"latest ai news"}`, + }, + }, + map[string]any{ + "id": "call_eval", + "type": "function", + "function": map[string]any{ + "name": "eval_javascript", + "arguments": `{"code":"1+1"}`, + }, + }, + }, + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw, "") + if len(normalized) != 1 { + t.Fatalf("expected one normalized assistant message, got %d", len(normalized)) + } + content, _ := normalized[0]["content"].(string) + if strings.Count(content, "[TOOL_CALL_HISTORY]") != 2 { + t.Fatalf("expected two TOOL_CALL_HISTORY blocks, got %q", content) + } + if !strings.Contains(content, "tool_call_id: call_search") || !strings.Contains(content, "function.name: search_web") { + t.Fatalf("missing first tool call block, got %q", content) + } + if !strings.Contains(content, "tool_call_id: call_eval") || !strings.Contains(content, "function.name: eval_javascript") { + t.Fatalf("missing second tool call block, got %q", content) + } + if strings.Contains(content, "search_webeval_javascript") { + t.Fatalf("unexpected merged function name detected: %q", content) + } + if strings.Contains(content, `}{"`) { + t.Fatalf("unexpected concatenated function arguments detected: %q", content) + } +} + +func TestNormalizeOpenAIMessagesForPrompt_RepairsConcatenatedToolArguments(t *testing.T) { + raw := []any{ + map[string]any{ + "role": "assistant", + "tool_calls": []any{ + map[string]any{ + "id": "call_1", + "function": map[string]any{ + "name": "search_web", + "arguments": `{}{"query":"测试工具调用"}`, + }, + }, + }, + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw, "") + if len(normalized) != 1 { + t.Fatalf("expected one normalized message, got %d", len(normalized)) + } + content, _ := normalized[0]["content"].(string) + if !strings.Contains(content, `function.arguments: {"query":"测试工具调用"}`) { + t.Fatalf("expected repaired arguments in tool history, got %q", 
content) + } + if strings.Contains(content, `{}{"query":"测试工具调用"}`) { + t.Fatalf("expected concatenated JSON to be repaired, got %q", content) + } +} diff --git a/internal/adapter/openai/models_route_test.go b/internal/adapter/openai/models_route_test.go new file mode 100644 index 0000000..1ba3382 --- /dev/null +++ b/internal/adapter/openai/models_route_test.go @@ -0,0 +1,46 @@ +package openai + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" +) + +func TestGetModelRouteDirectAndAlias(t *testing.T) { + h := &Handler{} + r := chi.NewRouter() + RegisterRoutes(r, h) + + t.Run("direct", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/models/deepseek-chat", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + }) + + t.Run("alias", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/models/gpt-4.1", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200 for alias, got %d body=%s", rec.Code, rec.Body.String()) + } + }) +} + +func TestGetModelRouteNotFound(t *testing.T) { + h := &Handler{} + r := chi.NewRouter() + RegisterRoutes(r, h) + + req := httptest.NewRequest(http.MethodGet, "/v1/models/not-exists", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusNotFound { + t.Fatalf("expected 404, got %d body=%s", rec.Code, rec.Body.String()) + } +} diff --git a/internal/adapter/openai/prompt_build.go b/internal/adapter/openai/prompt_build.go new file mode 100644 index 0000000..d6823b2 --- /dev/null +++ b/internal/adapter/openai/prompt_build.go @@ -0,0 +1,26 @@ +package openai + +import ( + "ds2api/internal/deepseek" + "ds2api/internal/util" +) + +func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any, traceID string) (string, []string) { + return 
buildOpenAIFinalPromptWithPolicy(messagesRaw, toolsRaw, traceID, util.DefaultToolChoicePolicy()) +} + +func buildOpenAIFinalPromptWithPolicy(messagesRaw []any, toolsRaw any, traceID string, toolPolicy util.ToolChoicePolicy) (string, []string) { + messages := normalizeOpenAIMessagesForPrompt(messagesRaw, traceID) + toolNames := []string{} + if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 { + messages, toolNames = injectToolPrompt(messages, tools, toolPolicy) + } + return deepseek.MessagesPrepare(messages), toolNames +} + +// BuildPromptForAdapter exposes the OpenAI-compatible prompt building flow so +// other protocol adapters (for example Gemini) can reuse the same tool/history +// normalization logic and remain behavior-compatible with chat/completions. +func BuildPromptForAdapter(messagesRaw []any, toolsRaw any, traceID string) (string, []string) { + return buildOpenAIFinalPrompt(messagesRaw, toolsRaw, traceID) +} diff --git a/internal/adapter/openai/prompt_build_test.go b/internal/adapter/openai/prompt_build_test.go new file mode 100644 index 0000000..bd6223e --- /dev/null +++ b/internal/adapter/openai/prompt_build_test.go @@ -0,0 +1,83 @@ +package openai + +import ( + "strings" + "testing" +) + +func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *testing.T) { + messages := []any{ + map[string]any{"role": "user", "content": "查北京天气"}, + map[string]any{ + "role": "assistant", + "tool_calls": []any{ + map[string]any{ + "id": "call_1", + "function": map[string]any{ + "name": "get_weather", + "arguments": "{\"city\":\"beijing\"}", + }, + }, + }, + }, + map[string]any{ + "role": "tool", + "tool_call_id": "call_1", + "name": "get_weather", + "content": map[string]any{"temp": 18, "condition": "sunny"}, + }, + } + tools := []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "get_weather", + "description": "Get weather", + "parameters": map[string]any{ + "type": "object", + }, + }, + }, + } + + 
finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools, "") + if len(toolNames) != 1 || toolNames[0] != "get_weather" { + t.Fatalf("unexpected tool names: %#v", toolNames) + } + if !strings.Contains(finalPrompt, "tool_call_id: call_1") || + !strings.Contains(finalPrompt, "function.name: get_weather") || + !strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") || + !strings.Contains(finalPrompt, `"condition":"sunny"`) { + t.Fatalf("handler finalPrompt missing tool roundtrip semantics: %q", finalPrompt) + } +} + +func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *testing.T) { + messages := []any{ + map[string]any{"role": "system", "content": "You are helpful"}, + map[string]any{"role": "user", "content": "请调用工具"}, + } + tools := []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "search", + "description": "search docs", + "parameters": map[string]any{ + "type": "object", + }, + }, + }, + } + + finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "") + if !strings.Contains(finalPrompt, "After receiving a tool result, you MUST use it to produce the final answer.") { + t.Fatalf("vercel prepare finalPrompt missing final-answer instruction: %q", finalPrompt) + } + if !strings.Contains(finalPrompt, "Only call another tool when the previous result is missing required data or returned an error.") { + t.Fatalf("vercel prepare finalPrompt missing retry guard instruction: %q", finalPrompt) + } + if !strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") { + t.Fatalf("vercel prepare finalPrompt missing history marker instruction: %q", finalPrompt) + } +} diff --git a/internal/adapter/openai/response_store.go b/internal/adapter/openai/response_store.go new file mode 100644 index 0000000..63ebbaa --- /dev/null +++ b/internal/adapter/openai/response_store.go @@ -0,0 +1,109 @@ +package openai + +import ( + "sync" + "time" + + "ds2api/internal/auth" +) + +type storedResponse struct { + Owner string + 
Value map[string]any + ExpiresAt time.Time +} + +type responseStore struct { + mu sync.Mutex + ttl time.Duration + items map[string]storedResponse +} + +func newResponseStore(ttl time.Duration) *responseStore { + if ttl <= 0 { + ttl = 15 * time.Minute + } + return &responseStore{ + ttl: ttl, + items: make(map[string]storedResponse), + } +} + +func responseStoreKey(owner, id string) string { + return owner + "\x00" + id +} + +func responseStoreOwner(a *auth.RequestAuth) string { + if a == nil { + return "" + } + return a.CallerID +} + +func (s *responseStore) put(owner, id string, value map[string]any) { + if s == nil || owner == "" || id == "" || value == nil { + return + } + now := time.Now() + s.mu.Lock() + defer s.mu.Unlock() + s.sweepLocked(now) + s.items[responseStoreKey(owner, id)] = storedResponse{ + Owner: owner, + Value: cloneAnyMap(value), + ExpiresAt: now.Add(s.ttl), + } +} + +func (s *responseStore) get(owner, id string) (map[string]any, bool) { + if s == nil || owner == "" || id == "" { + return nil, false + } + now := time.Now() + s.mu.Lock() + defer s.mu.Unlock() + s.sweepLocked(now) + item, ok := s.items[responseStoreKey(owner, id)] + if !ok { + return nil, false + } + if item.Owner != owner { + return nil, false + } + return cloneAnyMap(item.Value), true +} + +func (s *responseStore) sweepLocked(now time.Time) { + for k, v := range s.items { + if now.After(v.ExpiresAt) { + delete(s.items, k) + } + } +} + +func cloneAnyMap(in map[string]any) map[string]any { + if in == nil { + return nil + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func (h *Handler) getResponseStore() *responseStore { + if h == nil { + return nil + } + h.responsesMu.Lock() + defer h.responsesMu.Unlock() + if h.responses == nil { + ttl := 15 * time.Minute + if h.Store != nil { + ttl = time.Duration(h.Store.ResponsesStoreTTLSeconds()) * time.Second + } + h.responses = newResponseStore(ttl) + } + return h.responses +} diff 
--git a/internal/adapter/openai/responses_embeddings_test.go b/internal/adapter/openai/responses_embeddings_test.go new file mode 100644 index 0000000..a586682 --- /dev/null +++ b/internal/adapter/openai/responses_embeddings_test.go @@ -0,0 +1,197 @@ +package openai + +import ( + "strings" + "testing" + "time" +) + +func TestNormalizeResponsesInputAsMessagesString(t *testing.T) { + msgs := normalizeResponsesInputAsMessages("hello") + if len(msgs) != 1 { + t.Fatalf("expected one message, got %d", len(msgs)) + } + m, _ := msgs[0].(map[string]any) + if m["role"] != "user" || m["content"] != "hello" { + t.Fatalf("unexpected message: %#v", m) + } +} + +func TestResponsesMessagesFromRequestWithInstructions(t *testing.T) { + req := map[string]any{ + "model": "gpt-4.1", + "input": "ping", + "instructions": "system text", + } + msgs := responsesMessagesFromRequest(req) + if len(msgs) != 2 { + t.Fatalf("expected two messages, got %d", len(msgs)) + } + sys, _ := msgs[0].(map[string]any) + if sys["role"] != "system" { + t.Fatalf("unexpected first message: %#v", sys) + } +} + +func TestNormalizeResponsesInputAsMessagesObjectRoleContentBlocks(t *testing.T) { + msgs := normalizeResponsesInputAsMessages(map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "input_text", "text": "line-1"}, + map[string]any{"type": "input_text", "text": "line-2"}, + }, + }) + if len(msgs) != 1 { + t.Fatalf("expected one message, got %d", len(msgs)) + } + m, _ := msgs[0].(map[string]any) + if m["role"] != "user" { + t.Fatalf("unexpected role: %#v", m) + } + if strings.TrimSpace(normalizeOpenAIContentForPrompt(m["content"])) != "line-1\nline-2" { + t.Fatalf("unexpected content: %#v", m["content"]) + } +} + +func TestNormalizeResponsesInputAsMessagesFunctionCallOutput(t *testing.T) { + msgs := normalizeResponsesInputAsMessages([]any{ + map[string]any{ + "type": "function_call_output", + "call_id": "call_123", + "output": map[string]any{"ok": true}, + }, + }) + if len(msgs) != 1 
{ + t.Fatalf("expected one message, got %d", len(msgs)) + } + m, _ := msgs[0].(map[string]any) + if m["role"] != "tool" { + t.Fatalf("expected tool role, got %#v", m) + } + if m["tool_call_id"] != "call_123" { + t.Fatalf("expected tool_call_id propagated, got %#v", m) + } +} + +func TestNormalizeResponsesInputAsMessagesBackfillsToolResultNameFromCallID(t *testing.T) { + msgs := normalizeResponsesInputAsMessages([]any{ + map[string]any{ + "type": "function_call", + "call_id": "call_999", + "name": "search", + "arguments": `{"q":"golang"}`, + }, + map[string]any{ + "type": "function_call_output", + "call_id": "call_999", + "output": map[string]any{"ok": true}, + }, + }) + if len(msgs) != 2 { + t.Fatalf("expected two messages, got %d", len(msgs)) + } + toolMsg, _ := msgs[1].(map[string]any) + if toolMsg["role"] != "tool" { + t.Fatalf("expected tool role, got %#v", toolMsg) + } + if toolMsg["name"] != "search" { + t.Fatalf("expected tool name backfilled from call_id, got %#v", toolMsg["name"]) + } +} + +func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) { + msgs := normalizeResponsesInputAsMessages([]any{ + map[string]any{ + "type": "function_call", + "call_id": "call_456", + "name": "search", + "arguments": `{"q":"golang"}`, + }, + }) + if len(msgs) != 1 { + t.Fatalf("expected one message, got %d", len(msgs)) + } + m, _ := msgs[0].(map[string]any) + if m["role"] != "assistant" { + t.Fatalf("expected assistant role, got %#v", m["role"]) + } + toolCalls, _ := m["tool_calls"].([]any) + if len(toolCalls) != 1 { + t.Fatalf("expected one tool_call, got %#v", m["tool_calls"]) + } + call, _ := toolCalls[0].(map[string]any) + if call["id"] != "call_456" { + t.Fatalf("expected call id preserved, got %#v", call) + } + if call["type"] != "function" { + t.Fatalf("expected function type, got %#v", call) + } + fn, _ := call["function"].(map[string]any) + if fn["name"] != "search" { + t.Fatalf("expected call name preserved, got %#v", call) + } + if 
fn["arguments"] != `{"q":"golang"}` { + t.Fatalf("expected call arguments preserved, got %#v", call) + } +} + +func TestNormalizeResponsesInputAsMessagesFunctionCallItemRepairsConcatenatedArguments(t *testing.T) { + msgs := normalizeResponsesInputAsMessages([]any{ + map[string]any{ + "type": "function_call", + "call_id": "call_456", + "name": "search", + "arguments": `{}{"q":"golang"}`, + }, + }) + if len(msgs) != 1 { + t.Fatalf("expected one message, got %d", len(msgs)) + } + m, _ := msgs[0].(map[string]any) + toolCalls, _ := m["tool_calls"].([]any) + call, _ := toolCalls[0].(map[string]any) + fn, _ := call["function"].(map[string]any) + if fn["arguments"] != `{"q":"golang"}` { + t.Fatalf("expected concatenated call arguments repaired, got %#v", fn["arguments"]) + } +} + +func TestExtractEmbeddingInputs(t *testing.T) { + got := extractEmbeddingInputs([]any{"a", "b"}) + if len(got) != 2 || got[0] != "a" || got[1] != "b" { + t.Fatalf("unexpected inputs: %#v", got) + } +} + +func TestDeterministicEmbeddingStable(t *testing.T) { + a := deterministicEmbedding("hello") + b := deterministicEmbedding("hello") + if len(a) != 64 || len(b) != 64 { + t.Fatalf("expected 64 dims, got %d and %d", len(a), len(b)) + } + for i := range a { + if a[i] != b[i] { + t.Fatalf("expected stable embedding at %d: %v != %v", i, a[i], b[i]) + } + } +} + +func TestResponseStorePutGet(t *testing.T) { + st := newResponseStore(100 * time.Millisecond) + st.put("owner_1", "resp_1", map[string]any{"id": "resp_1"}) + got, ok := st.get("owner_1", "resp_1") + if !ok { + t.Fatal("expected stored response") + } + if got["id"] != "resp_1" { + t.Fatalf("unexpected response payload: %#v", got) + } +} + +func TestResponseStoreTenantIsolation(t *testing.T) { + st := newResponseStore(100 * time.Millisecond) + st.put("owner_a", "resp_1", map[string]any{"id": "resp_1"}) + if _, ok := st.get("owner_b", "resp_1"); ok { + t.Fatal("expected owner_b to be isolated from owner_a response") + } +} diff --git 
a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go new file mode 100644 index 0000000..81da92d --- /dev/null +++ b/internal/adapter/openai/responses_handler.go @@ -0,0 +1,221 @@ +package openai + +import ( + "encoding/json" + "io" + "net/http" + "strings" + "time" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/deepseek" + openaifmt "ds2api/internal/format/openai" + "ds2api/internal/sse" + streamengine "ds2api/internal/stream" + "ds2api/internal/util" +) + +func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) { + a, err := h.Auth.DetermineCaller(r) + if err != nil { + writeOpenAIError(w, http.StatusUnauthorized, err.Error()) + return + } + + id := strings.TrimSpace(chi.URLParam(r, "response_id")) + if id == "" { + writeOpenAIError(w, http.StatusBadRequest, "response_id is required.") + return + } + owner := responseStoreOwner(a) + if owner == "" { + writeOpenAIError(w, http.StatusUnauthorized, "unauthorized") + return + } + st := h.getResponseStore() + item, ok := st.get(owner, id) + if !ok { + writeOpenAIError(w, http.StatusNotFound, "Response not found.") + return + } + writeJSON(w, http.StatusOK, item) +} + +func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) { + a, err := h.Auth.Determine(r) + if err != nil { + status := http.StatusUnauthorized + detail := err.Error() + if err == auth.ErrNoAccount { + status = http.StatusTooManyRequests + } + writeOpenAIError(w, status, detail) + return + } + defer h.Auth.Release(a) + r = r.WithContext(auth.WithAuth(r.Context(), a)) + owner := responseStoreOwner(a) + if owner == "" { + writeOpenAIError(w, http.StatusUnauthorized, "unauthorized") + return + } + + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeOpenAIError(w, http.StatusBadRequest, "invalid json") + return + } + traceID := requestTraceID(r) + stdReq, err := 
normalizeOpenAIResponsesRequest(h.Store, req, traceID) + if err != nil { + writeOpenAIError(w, http.StatusBadRequest, err.Error()) + return + } + + sessionID, err := h.DS.CreateSession(r.Context(), a, 3) + if err != nil { + if a.UseConfigToken { + writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.") + } else { + writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.") + } + return + } + pow, err := h.DS.GetPow(r.Context(), a, 3) + if err != nil { + writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") + return + } + payload := stdReq.CompletionPayload(sessionID) + resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) + if err != nil { + writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.") + return + } + + responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "") + if stdReq.Stream { + h.handleResponsesStream(w, r, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID) + return + } + h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames, stdReq.ToolChoice, traceID) +} + +func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string, toolChoice util.ToolChoicePolicy, traceID string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + writeOpenAIError(w, resp.StatusCode, strings.TrimSpace(string(body))) + return + } + result := sse.CollectStream(resp, thinkingEnabled, true) + textParsed := util.ParseToolCallsDetailed(result.Text, toolNames) + thinkingParsed := util.ParseToolCallsDetailed(result.Thinking, toolNames) + 
logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text") + logResponsesToolPolicyRejection(traceID, toolChoice, thinkingParsed, "thinking") + + callCount := len(textParsed.Calls) + if callCount == 0 { + callCount = len(thinkingParsed.Calls) + } + if toolChoice.IsRequired() && callCount == 0 { + writeOpenAIErrorWithCode(w, http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation") + return + } + + responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames) + h.getResponseStore().put(owner, responseID, responseObj) + writeJSON(w, http.StatusOK, responseObj) +} + +func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice util.ToolChoicePolicy, traceID string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + writeOpenAIError(w, resp.StatusCode, strings.TrimSpace(string(body))) + return + } + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + rc := http.NewResponseController(w) + _, canFlush := w.(http.Flusher) + + initialType := "text" + if thinkingEnabled { + initialType = "thinking" + } + bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled() + emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence() + + streamRuntime := newResponsesStreamRuntime( + w, + rc, + canFlush, + responseID, + model, + finalPrompt, + thinkingEnabled, + searchEnabled, + toolNames, + bufferToolContent, + emitEarlyToolDeltas, + toolChoice, + traceID, + func(obj map[string]any) { + h.getResponseStore().put(owner, responseID, obj) + }, + ) + streamRuntime.sendCreated() + + 
streamengine.ConsumeSSE(streamengine.ConsumeConfig{ + Context: r.Context(), + Body: resp.Body, + ThinkingEnabled: thinkingEnabled, + InitialType: initialType, + KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second, + IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second, + MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount, + }, streamengine.ConsumeHooks{ + OnParsed: streamRuntime.onParsed, + OnFinalize: func(_ streamengine.StopReason, _ error) { + streamRuntime.finalize() + }, + }) +} + +func logResponsesToolPolicyRejection(traceID string, policy util.ToolChoicePolicy, parsed util.ToolCallParseResult, channel string) { + rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames) + if !parsed.RejectedByPolicy || len(rejected) == 0 { + return + } + config.Logger.Warn( + "[responses] rejected tool calls by policy", + "trace_id", strings.TrimSpace(traceID), + "channel", channel, + "tool_choice_mode", policy.Mode, + "rejected_tool_names", strings.Join(rejected, ","), + ) +} + +func filteredRejectedToolNamesForLog(names []string) []string { + if len(names) == 0 { + return nil + } + out := make([]string, 0, len(names)) + for _, name := range names { + trimmed := strings.TrimSpace(name) + switch strings.ToLower(trimmed) { + case "", "tool_name": + continue + default: + out = append(out, trimmed) + } + } + return out +} diff --git a/internal/adapter/openai/responses_input_items.go b/internal/adapter/openai/responses_input_items.go new file mode 100644 index 0000000..e0eea09 --- /dev/null +++ b/internal/adapter/openai/responses_input_items.go @@ -0,0 +1,203 @@ +package openai + +import ( + "encoding/json" + "fmt" + "strings" + + "ds2api/internal/config" +) + +func normalizeResponsesInputItem(m map[string]any) map[string]any { + return normalizeResponsesInputItemWithState(m, nil) +} + +func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[string]string) map[string]any { + if m == nil { + return nil + } + + 
role := strings.ToLower(strings.TrimSpace(asString(m["role"]))) + if role != "" { + content := m["content"] + if content == nil { + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + content = txt + } + } + if content == nil { + return nil + } + return map[string]any{ + "role": role, + "content": content, + } + } + + itemType := strings.ToLower(strings.TrimSpace(asString(m["type"]))) + switch itemType { + case "message", "input_message": + content := m["content"] + if content == nil { + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + content = txt + } + } + if content == nil { + return nil + } + role := strings.ToLower(strings.TrimSpace(asString(m["role"]))) + if role == "" { + role = "user" + } + return map[string]any{ + "role": role, + "content": content, + } + case "function_call_output", "tool_result": + content := m["output"] + if content == nil { + content = m["content"] + } + if content == nil { + content = "" + } + out := map[string]any{ + "role": "tool", + "content": content, + } + if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" { + out["tool_call_id"] = callID + } else if callID = strings.TrimSpace(asString(m["tool_call_id"])); callID != "" { + out["tool_call_id"] = callID + } + if name := strings.TrimSpace(asString(m["name"])); name != "" { + out["name"] = name + } else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" { + out["name"] = name + } else if callID := strings.TrimSpace(asString(out["tool_call_id"])); callID != "" { + if inferred := strings.TrimSpace(callNameByID[callID]); inferred != "" { + out["name"] = inferred + } else { + config.Logger.Warn( + "[responses] unable to backfill tool result name from call_id", + "call_id", callID, + ) + } + } + return out + case "function_call", "tool_call": + name := strings.TrimSpace(asString(m["name"])) + var fn map[string]any + if rawFn, ok := m["function"].(map[string]any); ok { + fn = rawFn + if name == "" { + name = 
strings.TrimSpace(asString(fn["name"])) + } + } + if name == "" { + return nil + } + + var argsRaw any + if v, ok := m["arguments"]; ok { + argsRaw = v + } else if v, ok := m["input"]; ok { + argsRaw = v + } + if argsRaw == nil && fn != nil { + if v, ok := fn["arguments"]; ok { + argsRaw = v + } else if v, ok := fn["input"]; ok { + argsRaw = v + } + } + + functionPayload := map[string]any{ + "name": name, + "arguments": stringifyToolCallArguments(argsRaw), + } + call := map[string]any{ + "type": "function", + "function": functionPayload, + } + if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" { + call["id"] = callID + } else if callID = strings.TrimSpace(asString(m["id"])); callID != "" { + call["id"] = callID + } + if callID := strings.TrimSpace(asString(call["id"])); callID != "" && callNameByID != nil { + callNameByID[callID] = name + } + return map[string]any{ + "role": "assistant", + "tool_calls": []any{call}, + } + case "input_text": + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + return map[string]any{ + "role": "user", + "content": txt, + } + } + } + + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + return map[string]any{ + "role": "user", + "content": txt, + } + } + if content, ok := m["content"]; ok { + if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" { + return map[string]any{ + "role": "user", + "content": content, + } + } + } + return nil +} + +func normalizeResponsesFallbackPart(m map[string]any) string { + if m == nil { + return "" + } + if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") { + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + return txt + } + } + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + return txt + } + if content, ok := m["content"]; ok { + if normalized := strings.TrimSpace(normalizeOpenAIContentForPrompt(content)); normalized != "" { + return normalized + } + } + return 
strings.TrimSpace(fmt.Sprintf("%v", m)) +} + +func stringifyToolCallArguments(v any) string { + switch x := v.(type) { + case nil: + return "{}" + case string: + s := strings.TrimSpace(x) + if s == "" { + return "{}" + } + s = normalizeToolArgumentString(s) + if s == "" { + return "{}" + } + return s + default: + b, err := json.Marshal(x) + if err != nil || len(b) == 0 { + return "{}" + } + return string(b) + } +} diff --git a/internal/adapter/openai/responses_input_normalize.go b/internal/adapter/openai/responses_input_normalize.go new file mode 100644 index 0000000..6514669 --- /dev/null +++ b/internal/adapter/openai/responses_input_normalize.go @@ -0,0 +1,94 @@ +package openai + +import ( + "fmt" + "strings" +) + +func responsesMessagesFromRequest(req map[string]any) []any { + if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 { + return prependInstructionMessage(msgs, req["instructions"]) + } + if rawInput, ok := req["input"]; ok { + if msgs := normalizeResponsesInputAsMessages(rawInput); len(msgs) > 0 { + return prependInstructionMessage(msgs, req["instructions"]) + } + } + return nil +} + +func prependInstructionMessage(messages []any, instructions any) []any { + sys, _ := instructions.(string) + sys = strings.TrimSpace(sys) + if sys == "" { + return messages + } + out := make([]any, 0, len(messages)+1) + out = append(out, map[string]any{"role": "system", "content": sys}) + out = append(out, messages...) 
+ return out +} + +func normalizeResponsesInputAsMessages(input any) []any { + switch v := input.(type) { + case string: + if strings.TrimSpace(v) == "" { + return nil + } + return []any{map[string]any{"role": "user", "content": v}} + case []any: + return normalizeResponsesInputArray(v) + case map[string]any: + if msg := normalizeResponsesInputItem(v); msg != nil { + return []any{msg} + } + if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" { + return []any{map[string]any{"role": "user", "content": txt}} + } + if content, ok := v["content"]; ok { + if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" { + return []any{map[string]any{"role": "user", "content": content}} + } + } + } + return nil +} + +func normalizeResponsesInputArray(items []any) []any { + if len(items) == 0 { + return nil + } + out := make([]any, 0, len(items)) + callNameByID := map[string]string{} + fallbackParts := make([]string, 0, len(items)) + flushFallback := func() { + if len(fallbackParts) == 0 { + return + } + out = append(out, map[string]any{"role": "user", "content": strings.Join(fallbackParts, "\n")}) + fallbackParts = fallbackParts[:0] + } + + for _, item := range items { + switch x := item.(type) { + case map[string]any: + if msg := normalizeResponsesInputItemWithState(x, callNameByID); msg != nil { + flushFallback() + out = append(out, msg) + continue + } + if s := normalizeResponsesFallbackPart(x); s != "" { + fallbackParts = append(fallbackParts, s) + } + default: + if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" { + fallbackParts = append(fallbackParts, s) + } + } + } + flushFallback() + if len(out) == 0 { + return nil + } + return out +} diff --git a/internal/adapter/openai/responses_route_test.go b/internal/adapter/openai/responses_route_test.go new file mode 100644 index 0000000..574c6fa --- /dev/null +++ b/internal/adapter/openai/responses_route_test.go @@ -0,0 +1,176 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + 
"net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/account" + "ds2api/internal/auth" + "ds2api/internal/config" +) + +func newDirectTokenResolver(t *testing.T) (*config.Store, *auth.Resolver) { + t.Helper() + t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`) + store := config.LoadStore() + pool := account.NewPool(store) + resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "unused", nil + }) + return store, resolver +} + +func newManagedKeyResolver(t *testing.T) (*config.Store, *auth.Resolver) { + t.Helper() + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[{"email":"acc@example.com","password":"pwd","token":"account-token"}] + }`) + t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1") + t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "0") + store := config.LoadStore() + pool := account.NewPool(store) + resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "unused", nil + }) + return store, resolver +} + +func authForToken(t *testing.T, resolver *auth.Resolver, token string) *auth.RequestAuth { + t.Helper() + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + req.Header.Set("Authorization", "Bearer "+token) + a, err := resolver.Determine(req) + if err != nil { + t.Fatalf("determine auth failed: %v", err) + } + return a +} + +func TestGetResponseByIDRequiresAuthAndIsTenantIsolated(t *testing.T) { + store, resolver := newDirectTokenResolver(t) + h := &Handler{Store: store, Auth: resolver} + r := chi.NewRouter() + RegisterRoutes(r, h) + + ownerA := responseStoreOwner(authForToken(t, resolver, "token-a")) + h.getResponseStore().put(ownerA, "resp_test", map[string]any{ + "id": "resp_test", + "object": "response", + }) + + t.Run("unauthorized", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + rec := 
httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String()) + } + }) + + t.Run("cross-tenant-not-found", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + req.Header.Set("Authorization", "Bearer token-b") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusNotFound { + t.Fatalf("expected 404, got %d body=%s", rec.Code, rec.Body.String()) + } + }) + + t.Run("same-tenant-ok", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + req.Header.Set("Authorization", "Bearer token-a") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("decode body failed: %v", err) + } + if body["id"] != "resp_test" { + t.Fatalf("unexpected body: %#v", body) + } + }) +} + +func TestResponsesRouteValidationContract(t *testing.T) { + store, resolver := newDirectTokenResolver(t) + h := &Handler{Store: store, Auth: resolver} + r := chi.NewRouter() + RegisterRoutes(r, h) + + tests := []struct { + name string + body string + }{ + {name: "missing_model", body: `{"input":"hello"}`}, + {name: "missing_input_and_messages", body: `{"model":"gpt-4o"}`}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewBufferString(tc.body)) + req.Header.Set("Authorization", "Bearer token-a") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String()) + } + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); 
err != nil { + t.Fatalf("decode response failed: %v", err) + } + errObj, _ := out["error"].(map[string]any) + if _, ok := errObj["code"]; !ok { + t.Fatalf("expected error.code: %#v", out) + } + if _, ok := errObj["param"]; !ok { + t.Fatalf("expected error.param: %#v", out) + } + }) + } +} + +func TestGetResponseByIDManagedKeySkipsAccountPoolPressure(t *testing.T) { + store, resolver := newManagedKeyResolver(t) + h := &Handler{Store: store, Auth: resolver} + r := chi.NewRouter() + RegisterRoutes(r, h) + + ownerReq := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + ownerReq.Header.Set("Authorization", "Bearer managed-key") + ownerAuth, err := resolver.DetermineCaller(ownerReq) + if err != nil { + t.Fatalf("determine caller failed: %v", err) + } + owner := responseStoreOwner(ownerAuth) + h.getResponseStore().put(owner, "resp_test", map[string]any{ + "id": "resp_test", + "object": "response", + }) + + occupyReq := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + occupyReq.Header.Set("Authorization", "Bearer managed-key") + occupied, err := resolver.Determine(occupyReq) + if err != nil { + t.Fatalf("expected first acquire to succeed: %v", err) + } + defer resolver.Release(occupied) + + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + req.Header.Set("Authorization", "Bearer managed-key") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200 under pool pressure, got %d body=%s", rec.Code, rec.Body.String()) + } +} diff --git a/internal/adapter/openai/responses_stream_runtime_core.go b/internal/adapter/openai/responses_stream_runtime_core.go new file mode 100644 index 0000000..02303d0 --- /dev/null +++ b/internal/adapter/openai/responses_stream_runtime_core.go @@ -0,0 +1,225 @@ +package openai + +import ( + "net/http" + "strings" + + "ds2api/internal/config" + openaifmt "ds2api/internal/format/openai" + "ds2api/internal/sse" + streamengine 
"ds2api/internal/stream" + "ds2api/internal/util" +) + +type responsesStreamRuntime struct { + w http.ResponseWriter + rc *http.ResponseController + canFlush bool + + responseID string + model string + finalPrompt string + toolNames []string + traceID string + toolChoice util.ToolChoicePolicy + + thinkingEnabled bool + searchEnabled bool + + bufferToolContent bool + emitEarlyToolDeltas bool + toolCallsEmitted bool + toolCallsDoneEmitted bool + + sieve toolStreamSieveState + thinkingSieve toolStreamSieveState + thinking strings.Builder + text strings.Builder + visibleText strings.Builder + streamToolCallIDs map[int]string + functionItemIDs map[int]string + functionOutputIDs map[int]int + functionArgs map[int]string + functionDone map[int]bool + functionAdded map[int]bool + functionNames map[int]string + messageItemID string + messageOutputID int + nextOutputID int + messageAdded bool + messagePartAdded bool + sequence int + failed bool + + persistResponse func(obj map[string]any) +} + +func newResponsesStreamRuntime( + w http.ResponseWriter, + rc *http.ResponseController, + canFlush bool, + responseID string, + model string, + finalPrompt string, + thinkingEnabled bool, + searchEnabled bool, + toolNames []string, + bufferToolContent bool, + emitEarlyToolDeltas bool, + toolChoice util.ToolChoicePolicy, + traceID string, + persistResponse func(obj map[string]any), +) *responsesStreamRuntime { + return &responsesStreamRuntime{ + w: w, + rc: rc, + canFlush: canFlush, + responseID: responseID, + model: model, + finalPrompt: finalPrompt, + thinkingEnabled: thinkingEnabled, + searchEnabled: searchEnabled, + toolNames: toolNames, + bufferToolContent: bufferToolContent, + emitEarlyToolDeltas: emitEarlyToolDeltas, + streamToolCallIDs: map[int]string{}, + functionItemIDs: map[int]string{}, + functionOutputIDs: map[int]int{}, + functionArgs: map[int]string{}, + functionDone: map[int]bool{}, + functionAdded: map[int]bool{}, + functionNames: map[int]string{}, + messageOutputID: 
-1, + toolChoice: toolChoice, + traceID: traceID, + persistResponse: persistResponse, + } +} + +func (s *responsesStreamRuntime) finalize() { + finalThinking := s.thinking.String() + finalText := s.text.String() + + if s.bufferToolContent { + s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true) + s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false) + } + + textParsed := util.ParseToolCallsDetailed(finalText, s.toolNames) + thinkingParsed := util.ParseToolCallsDetailed(finalThinking, s.toolNames) + detected := textParsed.Calls + if len(detected) == 0 { + detected = thinkingParsed.Calls + } + s.logToolPolicyRejections(textParsed, thinkingParsed) + + if len(detected) > 0 { + s.toolCallsEmitted = true + if !s.toolCallsDoneEmitted { + s.emitFunctionCallDoneEvents(detected) + } + } + + s.closeMessageItem() + + if s.toolChoice.IsRequired() && len(detected) == 0 { + s.failed = true + message := "tool_choice requires at least one valid tool call." + failedResp := map[string]any{ + "id": s.responseID, + "type": "response", + "object": "response", + "model": s.model, + "status": "failed", + "output": []any{}, + "output_text": "", + "error": map[string]any{ + "message": message, + "type": "invalid_request_error", + "code": "tool_choice_violation", + "param": nil, + }, + } + if s.persistResponse != nil { + s.persistResponse(failedResp) + } + s.sendEvent("response.failed", openaifmt.BuildResponsesFailedPayload(s.responseID, s.model, message, "tool_choice_violation")) + s.sendDone() + return + } + s.closeIncompleteFunctionItems() + + obj := s.buildCompletedResponseObject(finalThinking, finalText, detected) + if s.persistResponse != nil { + s.persistResponse(obj) + } + s.sendEvent("response.completed", openaifmt.BuildResponsesCompletedPayload(obj)) + s.sendDone() +} + +func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingParsed util.ToolCallParseResult) { + logRejected := func(parsed util.ToolCallParseResult, 
channel string) { + rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames) + if !parsed.RejectedByPolicy || len(rejected) == 0 { + return + } + config.Logger.Warn( + "[responses] rejected tool calls by policy", + "trace_id", strings.TrimSpace(s.traceID), + "channel", channel, + "tool_choice_mode", s.toolChoice.Mode, + "rejected_tool_names", strings.Join(rejected, ","), + ) + } + logRejected(textParsed, "text") + logRejected(thinkingParsed, "thinking") +} + +func (s *responsesStreamRuntime) hasFunctionCallDone() bool { + for _, done := range s.functionDone { + if done { + return true + } + } + return false +} + +func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision { + if !parsed.Parsed { + return streamengine.ParsedDecision{} + } + if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop { + return streamengine.ParsedDecision{Stop: true} + } + + contentSeen := false + for _, p := range parsed.Parts { + if p.Text == "" { + continue + } + if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) { + continue + } + contentSeen = true + if p.Type == "thinking" { + if !s.thinkingEnabled { + continue + } + s.thinking.WriteString(p.Text) + s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text)) + if s.bufferToolContent { + s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false) + } + continue + } + + s.text.WriteString(p.Text) + if !s.bufferToolContent { + s.emitTextDelta(p.Text) + continue + } + s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true) + } + + return streamengine.ParsedDecision{ContentSeen: contentSeen} +} diff --git a/internal/adapter/openai/responses_stream_runtime_events.go b/internal/adapter/openai/responses_stream_runtime_events.go new file mode 100644 index 0000000..792d0ce --- /dev/null +++ b/internal/adapter/openai/responses_stream_runtime_events.go @@ 
-0,0 +1,61 @@ +package openai + +import ( + "encoding/json" + + openaifmt "ds2api/internal/format/openai" +) + +func (s *responsesStreamRuntime) nextSequence() int { + s.sequence++ + return s.sequence +} + +func (s *responsesStreamRuntime) sendEvent(event string, payload map[string]any) { + if payload == nil { + payload = map[string]any{} + } + if _, ok := payload["sequence_number"]; !ok { + payload["sequence_number"] = s.nextSequence() + } + b, _ := json.Marshal(payload) + _, _ = s.w.Write([]byte("event: " + event + "\n")) + _, _ = s.w.Write([]byte("data: ")) + _, _ = s.w.Write(b) + _, _ = s.w.Write([]byte("\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *responsesStreamRuntime) sendCreated() { + s.sendEvent("response.created", openaifmt.BuildResponsesCreatedPayload(s.responseID, s.model)) +} + +func (s *responsesStreamRuntime) sendDone() { + _, _ = s.w.Write([]byte("data: [DONE]\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) { + for _, evt := range events { + if emitContent && evt.Content != "" { + s.emitTextDelta(evt.Content) + } + if len(evt.ToolCallDeltas) > 0 { + if !s.emitEarlyToolDeltas { + continue + } + filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.toolNames, s.functionNames) + if len(filtered) == 0 { + continue + } + s.emitFunctionCallDeltaEvents(filtered) + } + if len(evt.ToolCalls) > 0 { + s.emitFunctionCallDoneEvents(evt.ToolCalls) + } + } +} diff --git a/internal/adapter/openai/responses_stream_runtime_toolcalls.go b/internal/adapter/openai/responses_stream_runtime_toolcalls.go new file mode 100644 index 0000000..9947cbd --- /dev/null +++ b/internal/adapter/openai/responses_stream_runtime_toolcalls.go @@ -0,0 +1,235 @@ +package openai + +import ( + "encoding/json" + "strings" + + openaifmt "ds2api/internal/format/openai" + "ds2api/internal/util" + + "github.com/google/uuid" +) + +func (s 
*responsesStreamRuntime) allocateOutputIndex() int { + idx := s.nextOutputID + s.nextOutputID++ + return idx +} + +func (s *responsesStreamRuntime) ensureMessageItemID() string { + if strings.TrimSpace(s.messageItemID) != "" { + return s.messageItemID + } + s.messageItemID = "msg_" + strings.ReplaceAll(uuid.NewString(), "-", "") + return s.messageItemID +} + +func (s *responsesStreamRuntime) ensureMessageOutputIndex() int { + if s.messageOutputID >= 0 { + return s.messageOutputID + } + s.messageOutputID = s.allocateOutputIndex() + return s.messageOutputID +} + +func (s *responsesStreamRuntime) ensureMessageItemAdded() { + if s.messageAdded { + return + } + itemID := s.ensureMessageItemID() + item := map[string]any{ + "id": itemID, + "type": "message", + "role": "assistant", + "status": "in_progress", + } + s.sendEvent( + "response.output_item.added", + openaifmt.BuildResponsesOutputItemAddedPayload(s.responseID, itemID, s.ensureMessageOutputIndex(), item), + ) + s.messageAdded = true +} + +func (s *responsesStreamRuntime) ensureMessageContentPartAdded() { + if s.messagePartAdded { + return + } + s.ensureMessageItemAdded() + s.sendEvent( + "response.content_part.added", + openaifmt.BuildResponsesContentPartAddedPayload( + s.responseID, + s.ensureMessageItemID(), + s.ensureMessageOutputIndex(), + 0, + map[string]any{"type": "output_text", "text": ""}, + ), + ) + s.messagePartAdded = true +} + +func (s *responsesStreamRuntime) emitTextDelta(content string) { + if strings.TrimSpace(content) == "" { + return + } + s.ensureMessageContentPartAdded() + s.visibleText.WriteString(content) + s.sendEvent( + "response.output_text.delta", + openaifmt.BuildResponsesTextDeltaPayload( + s.responseID, + s.ensureMessageItemID(), + s.ensureMessageOutputIndex(), + 0, + content, + ), + ) +} + +func (s *responsesStreamRuntime) closeMessageItem() { + if !s.messageAdded { + return + } + itemID := s.ensureMessageItemID() + outputIndex := s.ensureMessageOutputIndex() + text := 
s.visibleText.String() + if s.messagePartAdded { + s.sendEvent( + "response.content_part.done", + openaifmt.BuildResponsesContentPartDonePayload( + s.responseID, + itemID, + outputIndex, + 0, + map[string]any{"type": "output_text", "text": text}, + ), + ) + s.messagePartAdded = false + } + item := map[string]any{ + "id": itemID, + "type": "message", + "role": "assistant", + "status": "completed", + "content": []map[string]any{ + { + "type": "output_text", + "text": text, + }, + }, + } + s.sendEvent( + "response.output_item.done", + openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, outputIndex, item), + ) +} + +func (s *responsesStreamRuntime) ensureFunctionItemID(callIndex int) string { + if id, ok := s.functionItemIDs[callIndex]; ok && strings.TrimSpace(id) != "" { + return id + } + id := "fc_" + strings.ReplaceAll(uuid.NewString(), "-", "") + s.functionItemIDs[callIndex] = id + return id +} + +func (s *responsesStreamRuntime) ensureToolCallID(callIndex int) string { + if id, ok := s.streamToolCallIDs[callIndex]; ok && strings.TrimSpace(id) != "" { + return id + } + id := "call_" + strings.ReplaceAll(uuid.NewString(), "-", "") + s.streamToolCallIDs[callIndex] = id + return id +} + +func (s *responsesStreamRuntime) ensureFunctionOutputIndex(callIndex int) int { + if idx, ok := s.functionOutputIDs[callIndex]; ok { + return idx + } + idx := s.allocateOutputIndex() + s.functionOutputIDs[callIndex] = idx + return idx +} + +func (s *responsesStreamRuntime) ensureFunctionItemAdded(callIndex int, name string) { + if strings.TrimSpace(name) != "" { + s.functionNames[callIndex] = strings.TrimSpace(name) + } + if s.functionAdded[callIndex] { + return + } + fnName := strings.TrimSpace(s.functionNames[callIndex]) + if fnName == "" { + return + } + outputIndex := s.ensureFunctionOutputIndex(callIndex) + itemID := s.ensureFunctionItemID(callIndex) + callID := s.ensureToolCallID(callIndex) + item := map[string]any{ + "id": itemID, + "type": "function_call", + 
"call_id": callID, + "name": fnName, + "arguments": "", + "status": "in_progress", + } + s.sendEvent( + "response.output_item.added", + openaifmt.BuildResponsesOutputItemAddedPayload(s.responseID, itemID, outputIndex, item), + ) + s.functionAdded[callIndex] = true + s.toolCallsEmitted = true +} + +func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDelta) { + for _, d := range deltas { + s.ensureFunctionItemAdded(d.Index, d.Name) + if strings.TrimSpace(d.Arguments) == "" { + continue + } + s.functionArgs[d.Index] += d.Arguments + outputIndex := s.ensureFunctionOutputIndex(d.Index) + itemID := s.ensureFunctionItemID(d.Index) + callID := s.ensureToolCallID(d.Index) + s.sendEvent( + "response.function_call_arguments.delta", + openaifmt.BuildResponsesFunctionCallArgumentsDeltaPayload(s.responseID, itemID, outputIndex, callID, d.Arguments), + ) + } +} + +func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedToolCall) { + for idx, tc := range calls { + if strings.TrimSpace(tc.Name) == "" { + continue + } + s.ensureFunctionItemAdded(idx, tc.Name) + if s.functionDone[idx] { + continue + } + outputIndex := s.ensureFunctionOutputIndex(idx) + itemID := s.ensureFunctionItemID(idx) + callID := s.ensureToolCallID(idx) + argsBytes, _ := json.Marshal(tc.Input) + args := string(argsBytes) + s.functionArgs[idx] = args + s.sendEvent( + "response.function_call_arguments.done", + openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, args), + ) + item := map[string]any{ + "id": itemID, + "type": "function_call", + "call_id": callID, + "name": tc.Name, + "arguments": args, + "status": "completed", + } + s.sendEvent( + "response.output_item.done", + openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, outputIndex, item), + ) + s.functionDone[idx] = true + s.toolCallsDoneEmitted = true + } +} diff --git 
a/internal/adapter/openai/responses_stream_runtime_toolcalls_finalize.go b/internal/adapter/openai/responses_stream_runtime_toolcalls_finalize.go new file mode 100644 index 0000000..46104a1 --- /dev/null +++ b/internal/adapter/openai/responses_stream_runtime_toolcalls_finalize.go @@ -0,0 +1,156 @@ +package openai + +import ( + "encoding/json" + "sort" + "strings" + + openaifmt "ds2api/internal/format/openai" + "ds2api/internal/util" +) + +func (s *responsesStreamRuntime) closeIncompleteFunctionItems() { + if len(s.functionAdded) == 0 { + return + } + indices := make([]int, 0, len(s.functionAdded)) + for idx, added := range s.functionAdded { + if !added || s.functionDone[idx] { + continue + } + indices = append(indices, idx) + } + if len(indices) == 0 { + return + } + sort.Ints(indices) + for _, idx := range indices { + name := strings.TrimSpace(s.functionNames[idx]) + if name == "" { + continue + } + args := strings.TrimSpace(s.functionArgs[idx]) + if args == "" { + args = "{}" + } + outputIndex := s.ensureFunctionOutputIndex(idx) + itemID := s.ensureFunctionItemID(idx) + callID := s.ensureToolCallID(idx) + s.sendEvent( + "response.function_call_arguments.done", + openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, name, args), + ) + item := map[string]any{ + "id": itemID, + "type": "function_call", + "call_id": callID, + "name": name, + "arguments": args, + "status": "completed", + } + s.sendEvent( + "response.output_item.done", + openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, outputIndex, item), + ) + s.functionDone[idx] = true + s.toolCallsDoneEmitted = true + } +} + +func (s *responsesStreamRuntime) buildCompletedResponseObject(finalThinking, finalText string, calls []util.ParsedToolCall) map[string]any { + type indexedItem struct { + index int + item map[string]any + } + indexed := make([]indexedItem, 0, len(calls)+1) + + if s.messageAdded { + text := s.visibleText.String() + indexed = 
append(indexed, indexedItem{ + index: s.ensureMessageOutputIndex(), + item: map[string]any{ + "id": s.ensureMessageItemID(), + "type": "message", + "role": "assistant", + "status": "completed", + "content": []map[string]any{ + { + "type": "output_text", + "text": text, + }, + }, + }, + }) + } else if len(calls) == 0 { + content := make([]map[string]any, 0, 2) + if strings.TrimSpace(finalThinking) != "" { + content = append(content, map[string]any{ + "type": "reasoning", + "text": finalThinking, + }) + } + if strings.TrimSpace(finalText) != "" { + content = append(content, map[string]any{ + "type": "output_text", + "text": finalText, + }) + } + if len(content) > 0 { + indexed = append(indexed, indexedItem{ + index: s.ensureMessageOutputIndex(), + item: map[string]any{ + "id": s.ensureMessageItemID(), + "type": "message", + "role": "assistant", + "status": "completed", + "content": content, + }, + }) + } + } + + for idx, tc := range calls { + if strings.TrimSpace(tc.Name) == "" { + continue + } + argsBytes, _ := json.Marshal(tc.Input) + indexed = append(indexed, indexedItem{ + index: s.ensureFunctionOutputIndex(idx), + item: map[string]any{ + "id": s.ensureFunctionItemID(idx), + "type": "function_call", + "call_id": s.ensureToolCallID(idx), + "name": tc.Name, + "arguments": string(argsBytes), + "status": "completed", + }, + }) + } + + sort.SliceStable(indexed, func(i, j int) bool { + return indexed[i].index < indexed[j].index + }) + output := make([]any, 0, len(indexed)) + for _, it := range indexed { + output = append(output, it.item) + } + + outputText := s.visibleText.String() + if strings.TrimSpace(outputText) == "" && len(calls) == 0 { + if strings.TrimSpace(finalText) != "" { + outputText = finalText + } else if strings.TrimSpace(finalThinking) != "" { + outputText = finalThinking + } + } + + return openaifmt.BuildResponseObjectFromItems( + s.responseID, + s.model, + s.finalPrompt, + finalThinking, + finalText, + output, + outputText, + ) +} diff --git 
a/internal/adapter/openai/responses_stream_test.go b/internal/adapter/openai/responses_stream_test.go new file mode 100644 index 0000000..ca3c4a3 --- /dev/null +++ b/internal/adapter/openai/responses_stream_test.go @@ -0,0 +1,611 @@ +package openai + +import ( + "bufio" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "ds2api/internal/util" +) + +func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + rawToolJSON := `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}` + streamBody := sseLine(rawToolJSON) + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "") + + completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed") + if !ok { + t.Fatalf("expected response.completed event, body=%s", rec.Body.String()) + } + responseObj, _ := completed["response"].(map[string]any) + outputText, _ := responseObj["output_text"].(string) + if outputText != "" { + t.Fatalf("expected empty output_text for tool_calls response, got output_text=%q", outputText) + } + output, _ := responseObj["output"].([]any) + if len(output) == 0 { + t.Fatalf("expected structured output entries, got %#v", responseObj["output"]) + } + hasFunctionCall := false + hasLegacyWrapper := false + for _, item := range output { + m, _ := item.(map[string]any) + if m == nil { + continue + } + if m["type"] == "function_call" { + hasFunctionCall = true + } + if m["type"] == "tool_calls" { + 
hasLegacyWrapper = true + } + } + if !hasFunctionCall { + t.Fatalf("expected function_call item, got %#v", responseObj["output"]) + } + if hasLegacyWrapper { + t.Fatalf("did not expect legacy tool_calls wrapper, got %#v", responseObj["output"]) + } + if strings.Contains(outputText, `"tool_calls"`) { + t.Fatalf("raw tool_calls JSON leaked in output_text: %q", outputText) + } +} + +func TestHandleResponsesStreamUsesOfficialOutputItemEvents(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "") + body := rec.Body.String() + if !strings.Contains(body, "event: response.output_item.added") { + t.Fatalf("expected response.output_item.added event, body=%s", body) + } + if !strings.Contains(body, "event: response.output_item.done") { + t.Fatalf("expected response.output_item.done event, body=%s", body) + } + if !strings.Contains(body, "event: response.function_call_arguments.delta") { + t.Fatalf("expected response.function_call_arguments.delta event, body=%s", body) + } + if !strings.Contains(body, "event: response.function_call_arguments.done") { + t.Fatalf("expected response.function_call_arguments.done event, body=%s", body) + } + if strings.Contains(body, "event: response.output_tool_call.delta") || strings.Contains(body, "event: response.output_tool_call.done") { + t.Fatalf("legacy response.output_tool_call.* event must not appear, body=%s", 
body) + } + + addedPayloads := extractAllSSEEventPayloads(body, "response.output_item.added") + hasFunctionCallAdded := false + for _, payload := range addedPayloads { + item, _ := payload["item"].(map[string]any) + if item == nil || asString(item["type"]) != "function_call" { + continue + } + hasFunctionCallAdded = true + if asString(item["arguments"]) != "" { + t.Fatalf("expected in-progress function_call.arguments to start empty string, got %#v", item["arguments"]) + } + } + if !hasFunctionCallAdded { + t.Fatalf("expected function_call output_item.added payload, body=%s", body) + } + + donePayload, ok := extractSSEEventPayload(body, "response.function_call_arguments.done") + if !ok { + t.Fatalf("expected to parse response.function_call_arguments.done payload, body=%s", body) + } + doneCallID := strings.TrimSpace(asString(donePayload["call_id"])) + if doneCallID == "" { + t.Fatalf("expected non-empty call_id in done payload, payload=%#v", donePayload) + } + completed, ok := extractSSEEventPayload(body, "response.completed") + if !ok { + t.Fatalf("expected response.completed payload, body=%s", body) + } + responseObj, _ := completed["response"].(map[string]any) + output, _ := responseObj["output"].([]any) + var completedCallID string + for _, item := range output { + m, _ := item.(map[string]any) + if m == nil || m["type"] != "function_call" { + continue + } + completedCallID = strings.TrimSpace(asString(m["call_id"])) + if completedCallID != "" { + break + } + } + if completedCallID == "" { + t.Fatalf("expected function_call.call_id in completed output, output=%#v", output) + } + if completedCallID != doneCallID { + t.Fatalf("expected completed call_id to match stream done call_id, done=%q completed=%q", doneCallID, completedCallID) + } +} + +func TestHandleResponsesStreamDoesNotEmitReasoningTextCompatEvents(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + b, _ := 
json.Marshal(map[string]any{ + "p": "response/thinking_content", + "v": "thought", + }) + streamBody := "data: " + string(b) + "\n" + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, nil, util.DefaultToolChoicePolicy(), "") + + body := rec.Body.String() + if !strings.Contains(body, "event: response.reasoning.delta") { + t.Fatalf("expected response.reasoning.delta event, body=%s", body) + } + if strings.Contains(body, "event: response.reasoning_text.delta") || strings.Contains(body, "event: response.reasoning_text.done") { + t.Fatalf("did not expect response.reasoning_text.* compatibility events, body=%s", body) + } +} + +func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine(`{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) + + sseLine(`{"name":"eval_javascript","input":{"code":"1+1"}}]}`) + + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"}, util.DefaultToolChoicePolicy(), "") + + body := rec.Body.String() + donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done") + if len(donePayloads) != 2 { + t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body) + } + seenNames := map[string]string{} + for _, payload := 
range donePayloads { + name := strings.TrimSpace(asString(payload["name"])) + callID := strings.TrimSpace(asString(payload["call_id"])) + if name != "search_web" && name != "eval_javascript" { + t.Fatalf("unexpected tool name in done payload: %#v", payload) + } + if callID == "" { + t.Fatalf("expected non-empty call_id in done payload: %#v", payload) + } + seenNames[name] = callID + } + if seenNames["search_web"] == seenNames["eval_javascript"] { + t.Fatalf("expected distinct call_id per tool, got %#v", seenNames) + } +} + +func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine("hello") + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, util.DefaultToolChoicePolicy(), "") + body := rec.Body.String() + + deltaPayload, ok := extractSSEEventPayload(body, "response.output_text.delta") + if !ok { + t.Fatalf("expected response.output_text.delta payload, body=%s", body) + } + if strings.TrimSpace(asString(deltaPayload["item_id"])) == "" { + t.Fatalf("expected non-empty item_id in output_text.delta, payload=%#v", deltaPayload) + } + if _, ok := deltaPayload["output_index"]; !ok { + t.Fatalf("expected output_index in output_text.delta, payload=%#v", deltaPayload) + } + if _, ok := deltaPayload["content_index"]; !ok { + t.Fatalf("expected content_index in output_text.delta, payload=%#v", deltaPayload) + } +} + +func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *testing.T) { + h := &Handler{} + req := 
httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(path, value string) string { + b, _ := json.Marshal(map[string]any{ + "p": path, + "v": value, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine("response/thinking_content", "thinking...") + + sseLine("response/content", "先读取文件。") + + sseLine("response/content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "") + + addedPayloads := extractAllSSEEventPayloads(rec.Body.String(), "response.output_item.added") + if len(addedPayloads) < 2 { + t.Fatalf("expected message + function_call output_item.added events, got %d body=%s", len(addedPayloads), rec.Body.String()) + } + + indexes := map[int]struct{}{} + typeByIndex := map[int]string{} + addedIDs := map[string]string{} + for _, payload := range addedPayloads { + item, _ := payload["item"].(map[string]any) + itemType := strings.TrimSpace(asString(item["type"])) + outputIndex := int(asFloat(payload["output_index"])) + if _, exists := indexes[outputIndex]; exists { + t.Fatalf("found duplicated output_index=%d for item types=%q and %q payload=%#v", outputIndex, typeByIndex[outputIndex], itemType, payload) + } + indexes[outputIndex] = struct{}{} + typeByIndex[outputIndex] = itemType + addedIDs[itemType] = strings.TrimSpace(asString(payload["item_id"])) + } + + completedPayload, ok := extractSSEEventPayload(rec.Body.String(), "response.completed") + if !ok { + t.Fatalf("expected response.completed payload, body=%s", rec.Body.String()) + } + responseObj, _ := completedPayload["response"].(map[string]any) + output, _ := responseObj["output"].([]any) + found := map[string]bool{} + 
for _, item := range output { + m, _ := item.(map[string]any) + itemType := strings.TrimSpace(asString(m["type"])) + itemID := strings.TrimSpace(asString(m["id"])) + if itemType == "" || itemID == "" { + continue + } + if wantID := strings.TrimSpace(addedIDs[itemType]); wantID != "" && wantID == itemID { + found[itemType] = true + } + } + if !found["message"] || !found["function_call"] { + t.Fatalf("expected completed output to contain streamed message/function_call item ids, found=%#v output=%#v", found, output) + } +} + +func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + policy := util.ToolChoicePolicy{Mode: util.ToolChoiceNone} + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, policy, "") + body := rec.Body.String() + if strings.Contains(body, "event: response.function_call_arguments.done") { + t.Fatalf("did not expect function_call events for tool_choice=none, body=%s", body) + } +} + +func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + // invalid JSON (NaN) can still trigger incremental tool deltas before final parse rejects it + streamBody := 
sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "") + body := rec.Body.String() + if !strings.Contains(body, "event: response.function_call_arguments.delta") { + t.Fatalf("expected response.function_call_arguments.delta event for malformed payload, body=%s", body) + } + if !strings.Contains(body, "event: response.function_call_arguments.done") { + t.Fatalf("expected runtime to close in-progress function_call with done event, body=%s", body) + } + if !strings.Contains(body, "event: response.output_item.done") { + t.Fatalf("expected runtime to close function output item, body=%s", body) + } + if !strings.Contains(body, "event: response.completed") { + t.Fatalf("expected response.completed event, body=%s", body) + } +} + +func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine("plain text only") + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + policy := util.ToolChoicePolicy{ + Mode: util.ToolChoiceRequired, + Allowed: map[string]struct{}{"read_file": {}}, + } + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, policy, "") + + body := rec.Body.String() + if !strings.Contains(body, "event: response.failed") { + t.Fatalf("expected response.failed event for required 
tool_choice violation, body=%s", body) + } + if strings.Contains(body, "event: response.completed") { + t.Fatalf("did not expect response.completed after failure, body=%s", body) + } +} + +func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + policy := util.ToolChoicePolicy{ + Mode: util.ToolChoiceRequired, + Allowed: map[string]struct{}{"read_file": {}}, + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, policy, "") + + body := rec.Body.String() + if !strings.Contains(body, "event: response.failed") { + t.Fatalf("expected response.failed event, body=%s", body) + } + if strings.Contains(body, "event: response.completed") { + t.Fatalf("did not expect response.completed, body=%s", body) + } +} + +func TestHandleResponsesStreamRejectsUnknownToolName(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine(`{"tool_calls":[{"name":"not_in_schema","input":{"q":"go"}}]}`) + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", 
false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "") + body := rec.Body.String() + if strings.Contains(body, "event: response.function_call_arguments.done") { + t.Fatalf("did not expect function_call events for unknown tool, body=%s", body) + } +} + +func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) { + h := &Handler{} + rec := httptest.NewRecorder() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader( + `data: {"p":"response/content","v":"plain text only"}` + "\n" + + `data: [DONE]` + "\n", + )), + } + policy := util.ToolChoicePolicy{ + Mode: util.ToolChoiceRequired, + Allowed: map[string]struct{}{"read_file": {}}, + } + + h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, []string{"read_file"}, policy, "") + if rec.Code != http.StatusUnprocessableEntity { + t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String()) + } + out := decodeJSONBody(t, rec.Body.String()) + errObj, _ := out["error"].(map[string]any) + if asString(errObj["code"]) != "tool_choice_violation" { + t.Fatalf("expected code=tool_choice_violation, got %#v", out) + } +} + +func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) { + h := &Handler{} + rec := httptest.NewRecorder() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader( + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}"}` + "\n" + + `data: [DONE]` + "\n", + )), + } + policy := util.ToolChoicePolicy{Mode: util.ToolChoiceNone} + + h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, policy, "") + if rec.Code != http.StatusOK { + t.Fatalf("expected 200 for tool_choice=none passthrough text, got %d body=%s", rec.Code, rec.Body.String()) + } + out := decodeJSONBody(t, rec.Body.String()) + 
output, _ := out["output"].([]any) + for _, item := range output { + m, _ := item.(map[string]any) + if m != nil && m["type"] == "function_call" { + t.Fatalf("did not expect function_call output item for tool_choice=none, got %#v", output) + } + } +} + +func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) { + scanner := bufio.NewScanner(strings.NewReader(body)) + matched := false + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "event: ") { + evt := strings.TrimSpace(strings.TrimPrefix(line, "event: ")) + matched = evt == targetEvent + continue + } + if !matched || !strings.HasPrefix(line, "data: ") { + continue + } + raw := strings.TrimSpace(strings.TrimPrefix(line, "data: ")) + if raw == "" || raw == "[DONE]" { + continue + } + var payload map[string]any + if err := json.Unmarshal([]byte(raw), &payload); err != nil { + return nil, false + } + return payload, true + } + return nil, false +} + +func extractAllSSEEventPayloads(body, targetEvent string) []map[string]any { + scanner := bufio.NewScanner(strings.NewReader(body)) + matched := false + out := make([]map[string]any, 0, 2) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "event: ") { + evt := strings.TrimSpace(strings.TrimPrefix(line, "event: ")) + matched = evt == targetEvent + continue + } + if !matched || !strings.HasPrefix(line, "data: ") { + continue + } + raw := strings.TrimSpace(strings.TrimPrefix(line, "data: ")) + if raw == "" || raw == "[DONE]" { + continue + } + var payload map[string]any + if err := json.Unmarshal([]byte(raw), &payload); err != nil { + continue + } + out = append(out, payload) + } + return out +} + +func asFloat(v any) float64 { + switch x := v.(type) { + case float64: + return x + case float32: + return float64(x) + case int: + return float64(x) + case int64: + return float64(x) + default: + return 0 + } +} diff --git a/internal/adapter/openai/standard_request.go 
b/internal/adapter/openai/standard_request.go new file mode 100644 index 0000000..1ba957c --- /dev/null +++ b/internal/adapter/openai/standard_request.go @@ -0,0 +1,326 @@ +package openai + +import ( + "fmt" + "strings" + + "ds2api/internal/config" + "ds2api/internal/util" +) + +func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID string) (util.StandardRequest, error) { + model, _ := req["model"].(string) + messagesRaw, _ := req["messages"].([]any) + if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 { + return util.StandardRequest{}, fmt.Errorf("Request must include 'model' and 'messages'.") + } + resolvedModel, ok := config.ResolveModel(store, model) + if !ok { + return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", model) + } + thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) + responseModel := strings.TrimSpace(model) + if responseModel == "" { + responseModel = resolvedModel + } + toolPolicy := util.DefaultToolChoicePolicy() + finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy) + passThrough := collectOpenAIChatPassThrough(req) + + return util.StandardRequest{ + Surface: "openai_chat", + RequestedModel: strings.TrimSpace(model), + ResolvedModel: resolvedModel, + ResponseModel: responseModel, + Messages: messagesRaw, + FinalPrompt: finalPrompt, + ToolNames: toolNames, + ToolChoice: toolPolicy, + Stream: util.ToBool(req["stream"]), + Thinking: thinkingEnabled, + Search: searchEnabled, + PassThrough: passThrough, + }, nil +} + +func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, traceID string) (util.StandardRequest, error) { + model, _ := req["model"].(string) + model = strings.TrimSpace(model) + if model == "" { + return util.StandardRequest{}, fmt.Errorf("Request must include 'model'.") + } + resolvedModel, ok := config.ResolveModel(store, model) + if !ok { + return util.StandardRequest{}, 
fmt.Errorf("Model '%s' is not available.", model) + } + thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) + + // Keep width-control as an explicit policy hook even if current default is true. + allowWideInput := true + if store != nil { + allowWideInput = store.CompatWideInputStrictOutput() + } + var messagesRaw []any + if allowWideInput { + messagesRaw = responsesMessagesFromRequest(req) + } else if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 { + messagesRaw = msgs + } + if len(messagesRaw) == 0 { + return util.StandardRequest{}, fmt.Errorf("Request must include 'input' or 'messages'.") + } + toolPolicy, err := parseToolChoicePolicy(req["tool_choice"], req["tools"]) + if err != nil { + return util.StandardRequest{}, err + } + finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy) + if toolPolicy.IsNone() { + toolNames = nil + toolPolicy.Allowed = nil + } else { + toolPolicy.Allowed = namesToSet(toolNames) + } + passThrough := collectOpenAIChatPassThrough(req) + + return util.StandardRequest{ + Surface: "openai_responses", + RequestedModel: model, + ResolvedModel: resolvedModel, + ResponseModel: model, + Messages: messagesRaw, + FinalPrompt: finalPrompt, + ToolNames: toolNames, + ToolChoice: toolPolicy, + Stream: util.ToBool(req["stream"]), + Thinking: thinkingEnabled, + Search: searchEnabled, + PassThrough: passThrough, + }, nil +} + +func collectOpenAIChatPassThrough(req map[string]any) map[string]any { + out := map[string]any{} + for _, k := range []string{ + "temperature", + "top_p", + "max_tokens", + "max_completion_tokens", + "presence_penalty", + "frequency_penalty", + "stop", + } { + if v, ok := req[k]; ok { + out[k] = v + } + } + return out +} + +func parseToolChoicePolicy(toolChoiceRaw any, toolsRaw any) (util.ToolChoicePolicy, error) { + policy := util.DefaultToolChoicePolicy() + declaredNames := extractDeclaredToolNames(toolsRaw) + declaredSet := 
namesToSet(declaredNames) + if len(declaredNames) > 0 { + policy.Allowed = declaredSet + } + + if toolChoiceRaw == nil { + return policy, nil + } + + switch v := toolChoiceRaw.(type) { + case string: + switch strings.ToLower(strings.TrimSpace(v)) { + case "", "auto": + policy.Mode = util.ToolChoiceAuto + case "none": + policy.Mode = util.ToolChoiceNone + policy.Allowed = nil + case "required": + policy.Mode = util.ToolChoiceRequired + default: + return util.ToolChoicePolicy{}, fmt.Errorf("Unsupported tool_choice: %q", v) + } + case map[string]any: + allowedOverride, hasAllowedOverride, err := parseAllowedToolNames(v["allowed_tools"]) + if err != nil { + return util.ToolChoicePolicy{}, err + } + if hasAllowedOverride { + filtered := make([]string, 0, len(allowedOverride)) + for _, name := range allowedOverride { + if _, ok := declaredSet[name]; !ok { + return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice.allowed_tools contains undeclared tool %q", name) + } + filtered = append(filtered, name) + } + policy.Allowed = namesToSet(filtered) + } + + typ := strings.ToLower(strings.TrimSpace(asString(v["type"]))) + switch typ { + case "", "auto": + if hasFunctionSelector(v) { + name, err := parseForcedToolName(v) + if err != nil { + return util.ToolChoicePolicy{}, err + } + policy.Mode = util.ToolChoiceForced + policy.ForcedName = name + policy.Allowed = namesToSet([]string{name}) + } else { + policy.Mode = util.ToolChoiceAuto + } + case "none": + policy.Mode = util.ToolChoiceNone + policy.Allowed = nil + case "required": + policy.Mode = util.ToolChoiceRequired + case "function": + name, err := parseForcedToolName(v) + if err != nil { + return util.ToolChoicePolicy{}, err + } + policy.Mode = util.ToolChoiceForced + policy.ForcedName = name + policy.Allowed = namesToSet([]string{name}) + default: + return util.ToolChoicePolicy{}, fmt.Errorf("Unsupported tool_choice.type: %q", typ) + } + default: + return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice must be a string or 
object") + } + + if policy.Mode == util.ToolChoiceRequired || policy.Mode == util.ToolChoiceForced { + if len(declaredNames) == 0 { + return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice=%s requires non-empty tools.", policy.Mode) + } + } + if policy.Mode == util.ToolChoiceForced { + if _, ok := declaredSet[policy.ForcedName]; !ok { + return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice forced function %q is not declared in tools", policy.ForcedName) + } + } + if len(policy.Allowed) == 0 && (policy.Mode == util.ToolChoiceRequired || policy.Mode == util.ToolChoiceForced) { + return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice policy resolved to empty allowed tool set") + } + return policy, nil +} + +func parseForcedToolName(v map[string]any) (string, error) { + if name := strings.TrimSpace(asString(v["name"])); name != "" { + return name, nil + } + if fn, ok := v["function"].(map[string]any); ok { + if name := strings.TrimSpace(asString(fn["name"])); name != "" { + return name, nil + } + } + return "", fmt.Errorf("tool_choice function requires name") +} + +func parseAllowedToolNames(raw any) ([]string, bool, error) { + if raw == nil { + return nil, false, nil + } + collectName := func(v any) string { + if name := strings.TrimSpace(asString(v)); name != "" { + return name + } + if m, ok := v.(map[string]any); ok { + if name := strings.TrimSpace(asString(m["name"])); name != "" { + return name + } + if fn, ok := m["function"].(map[string]any); ok { + if name := strings.TrimSpace(asString(fn["name"])); name != "" { + return name + } + } + } + return "" + } + + names := []string{} + switch x := raw.(type) { + case []any: + for _, item := range x { + name := collectName(item) + if name == "" { + return nil, true, fmt.Errorf("tool_choice.allowed_tools contains invalid item") + } + names = append(names, name) + } + case []string: + for _, item := range x { + name := strings.TrimSpace(item) + if name == "" { + return nil, true, 
fmt.Errorf("tool_choice.allowed_tools contains empty name") + } + names = append(names, name) + } + default: + return nil, true, fmt.Errorf("tool_choice.allowed_tools must be an array") + } + + if len(names) == 0 { + return nil, true, fmt.Errorf("tool_choice.allowed_tools must not be empty") + } + return names, true, nil +} + +func hasFunctionSelector(v map[string]any) bool { + if strings.TrimSpace(asString(v["name"])) != "" { + return true + } + if fn, ok := v["function"].(map[string]any); ok { + return strings.TrimSpace(asString(fn["name"])) != "" + } + return false +} + +func extractDeclaredToolNames(toolsRaw any) []string { + tools, ok := toolsRaw.([]any) + if !ok || len(tools) == 0 { + return nil + } + out := make([]string, 0, len(tools)) + seen := map[string]struct{}{} + for _, t := range tools { + tool, ok := t.(map[string]any) + if !ok { + continue + } + fn, _ := tool["function"].(map[string]any) + if len(fn) == 0 { + fn = tool + } + name := strings.TrimSpace(asString(fn["name"])) + if name == "" { + continue + } + if _, ok := seen[name]; ok { + continue + } + seen[name] = struct{}{} + out = append(out, name) + } + return out +} + +func namesToSet(names []string) map[string]struct{} { + if len(names) == 0 { + return nil + } + out := make(map[string]struct{}, len(names)) + for _, name := range names { + trimmed := strings.TrimSpace(name) + if trimmed == "" { + continue + } + out[trimmed] = struct{}{} + } + if len(out) == 0 { + return nil + } + return out +} diff --git a/internal/adapter/openai/standard_request_test.go b/internal/adapter/openai/standard_request_test.go new file mode 100644 index 0000000..e8d1225 --- /dev/null +++ b/internal/adapter/openai/standard_request_test.go @@ -0,0 +1,180 @@ +package openai + +import ( + "testing" + + "ds2api/internal/config" + "ds2api/internal/util" +) + +func newEmptyStoreForNormalizeTest(t *testing.T) *config.Store { + t.Helper() + t.Setenv("DS2API_CONFIG_JSON", `{}`) + return config.LoadStore() +} + +func 
TestNormalizeOpenAIChatRequest(t *testing.T) { + store := newEmptyStoreForNormalizeTest(t) + req := map[string]any{ + "model": "gpt-5-codex", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + "temperature": 0.3, + "stream": true, + } + n, err := normalizeOpenAIChatRequest(store, req, "") + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + if n.ResolvedModel != "deepseek-reasoner" { + t.Fatalf("unexpected resolved model: %s", n.ResolvedModel) + } + if !n.Stream { + t.Fatalf("expected stream=true") + } + if _, ok := n.PassThrough["temperature"]; !ok { + t.Fatalf("expected temperature passthrough") + } + if n.FinalPrompt == "" { + t.Fatalf("expected non-empty final prompt") + } +} + +func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) { + store := newEmptyStoreForNormalizeTest(t) + req := map[string]any{ + "model": "gpt-4o", + "input": "ping", + "instructions": "system", + } + n, err := normalizeOpenAIResponsesRequest(store, req, "") + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + if n.ResolvedModel != "deepseek-chat" { + t.Fatalf("unexpected resolved model: %s", n.ResolvedModel) + } + if len(n.Messages) != 2 { + t.Fatalf("expected 2 normalized messages, got %d", len(n.Messages)) + } +} + +func TestNormalizeOpenAIResponsesRequestToolChoiceRequired(t *testing.T) { + store := newEmptyStoreForNormalizeTest(t) + req := map[string]any{ + "model": "gpt-4o", + "input": "ping", + "tools": []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "search", + "parameters": map[string]any{ + "type": "object", + }, + }, + }, + }, + "tool_choice": "required", + } + n, err := normalizeOpenAIResponsesRequest(store, req, "") + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + if n.ToolChoice.Mode != util.ToolChoiceRequired { + t.Fatalf("expected tool choice mode required, got %q", n.ToolChoice.Mode) + } + if len(n.ToolNames) != 1 || n.ToolNames[0] != "search" { + 
t.Fatalf("unexpected tool names: %#v", n.ToolNames) + } +} + +func TestNormalizeOpenAIResponsesRequestToolChoiceForcedFunction(t *testing.T) { + store := newEmptyStoreForNormalizeTest(t) + req := map[string]any{ + "model": "gpt-4o", + "input": "ping", + "tools": []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "search", + }, + }, + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "read_file", + }, + }, + }, + "tool_choice": map[string]any{ + "type": "function", + "name": "read_file", + }, + } + n, err := normalizeOpenAIResponsesRequest(store, req, "") + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + if n.ToolChoice.Mode != util.ToolChoiceForced { + t.Fatalf("expected tool choice mode forced, got %q", n.ToolChoice.Mode) + } + if n.ToolChoice.ForcedName != "read_file" { + t.Fatalf("expected forced tool name read_file, got %q", n.ToolChoice.ForcedName) + } + if len(n.ToolNames) != 1 || n.ToolNames[0] != "read_file" { + t.Fatalf("expected filtered tool names [read_file], got %#v", n.ToolNames) + } +} + +func TestNormalizeOpenAIResponsesRequestToolChoiceForcedUndeclaredFails(t *testing.T) { + store := newEmptyStoreForNormalizeTest(t) + req := map[string]any{ + "model": "gpt-4o", + "input": "ping", + "tools": []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "search", + }, + }, + }, + "tool_choice": map[string]any{ + "type": "function", + "name": "read_file", + }, + } + if _, err := normalizeOpenAIResponsesRequest(store, req, ""); err == nil { + t.Fatalf("expected forced undeclared tool to fail") + } +} + +func TestNormalizeOpenAIResponsesRequestToolChoiceNoneDisablesTools(t *testing.T) { + store := newEmptyStoreForNormalizeTest(t) + req := map[string]any{ + "model": "gpt-4o", + "input": "ping", + "tools": []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "search", + }, + }, + }, + "tool_choice": "none", + } + n, err 
:= normalizeOpenAIResponsesRequest(store, req, "") + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + if n.ToolChoice.Mode != util.ToolChoiceNone { + t.Fatalf("expected tool choice mode none, got %q", n.ToolChoice.Mode) + } + if len(n.ToolNames) != 0 { + t.Fatalf("expected no tool names when tool_choice=none, got %#v", n.ToolNames) + } +} diff --git a/internal/adapter/openai/stream_status_test.go b/internal/adapter/openai/stream_status_test.go new file mode 100644 index 0000000..4f8305a --- /dev/null +++ b/internal/adapter/openai/stream_status_test.go @@ -0,0 +1,185 @@ +package openai + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + chimw "github.com/go-chi/chi/v5/middleware" + + "ds2api/internal/auth" +) + +type streamStatusAuthStub struct{} + +func (streamStatusAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) { + return &auth.RequestAuth{ + UseConfigToken: false, + DeepSeekToken: "direct-token", + CallerID: "caller:test", + TriedAccounts: map[string]bool{}, + }, nil +} + +func (streamStatusAuthStub) DetermineCaller(_ *http.Request) (*auth.RequestAuth, error) { + return &auth.RequestAuth{ + UseConfigToken: false, + DeepSeekToken: "direct-token", + CallerID: "caller:test", + TriedAccounts: map[string]bool{}, + }, nil +} + +func (streamStatusAuthStub) Release(_ *auth.RequestAuth) {} + +type streamStatusDSStub struct { + resp *http.Response +} + +func (m streamStatusDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) { + return "session-id", nil +} + +func (m streamStatusDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) { + return "pow", nil +} + +func (m streamStatusDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) { + return m.resp, nil +} + +func makeOpenAISSEHTTPResponse(lines ...string) *http.Response { + body := 
strings.Join(lines, "\n") + if !strings.HasSuffix(body, "\n") { + body += "\n" + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func captureStatusMiddleware(statuses *[]int) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ww := chimw.NewWrapResponseWriter(w, r.ProtoMajor) + next.ServeHTTP(ww, r) + *statuses = append(*statuses, ww.Status()) + }) + } +} + +func TestChatCompletionsStreamStatusCapturedAs200(t *testing.T) { + statuses := make([]int, 0, 1) + h := &Handler{ + Store: mockOpenAIConfig{wideInput: true}, + Auth: streamStatusAuthStub{}, + DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, "data: [DONE]")}, + } + r := chi.NewRouter() + r.Use(captureStatusMiddleware(&statuses)) + RegisterRoutes(r, h) + + reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":"hi"}],"stream":true}` + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody)) + req.Header.Set("Authorization", "Bearer direct-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if len(statuses) != 1 { + t.Fatalf("expected one captured status, got %d", len(statuses)) + } + if statuses[0] != http.StatusOK { + t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0]) + } +} + +func TestResponsesStreamStatusCapturedAs200(t *testing.T) { + statuses := make([]int, 0, 1) + h := &Handler{ + Store: mockOpenAIConfig{wideInput: true}, + Auth: streamStatusAuthStub{}, + DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, "data: [DONE]")}, + } + r := chi.NewRouter() + 
r.Use(captureStatusMiddleware(&statuses)) + RegisterRoutes(r, h) + + reqBody := `{"model":"deepseek-chat","input":"hi","stream":true}` + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody)) + req.Header.Set("Authorization", "Bearer direct-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if len(statuses) != 1 { + t.Fatalf("expected one captured status, got %d", len(statuses)) + } + if statuses[0] != http.StatusOK { + t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0]) + } +} + +func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) { + statuses := make([]int, 0, 1) + content, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": "我来调用工具\n{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}", + }) + h := &Handler{ + Store: mockOpenAIConfig{wideInput: true}, + Auth: streamStatusAuthStub{}, + DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse("data: "+string(content), "data: [DONE]")}, + } + r := chi.NewRouter() + r.Use(captureStatusMiddleware(&statuses)) + RegisterRoutes(r, h) + + reqBody := `{"model":"deepseek-chat","input":"请调用工具","tools":[{"type":"function","function":{"name":"read_file","description":"read","parameters":{"type":"object","properties":{"path":{"type":"string"}}}}}],"stream":false}` + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody)) + req.Header.Set("Authorization", "Bearer direct-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if len(statuses) != 1 || statuses[0] != http.StatusOK { + t.Fatalf("expected captured status 200, got %#v", 
statuses) + } + + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String()) + } + outputText, _ := out["output_text"].(string) + if outputText != "" { + t.Fatalf("expected output_text hidden for tool call payload, got %q", outputText) + } + output, _ := out["output"].([]any) + hasFunctionCall := false + for _, item := range output { + m, _ := item.(map[string]any) + if m != nil && m["type"] == "function_call" { + hasFunctionCall = true + break + } + } + if !hasFunctionCall { + t.Fatalf("expected function_call output item, got %#v", output) + } +} diff --git a/internal/adapter/openai/tool_sieve.go b/internal/adapter/openai/tool_sieve_core.go similarity index 60% rename from internal/adapter/openai/tool_sieve.go rename to internal/adapter/openai/tool_sieve_core.go index d1a9014..5ed9b90 100644 --- a/internal/adapter/openai/tool_sieve.go +++ b/internal/adapter/openai/tool_sieve_core.go @@ -6,17 +6,6 @@ import ( "ds2api/internal/util" ) -type toolStreamSieveState struct { - pending strings.Builder - capture strings.Builder - capturing bool -} - -type toolStreamEvent struct { - Content string - ToolCalls []util.ParsedToolCall -} - func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent { if state == nil { return nil @@ -32,13 +21,27 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames state.capture.WriteString(state.pending.String()) state.pending.Reset() } - prefix, calls, suffix, ready := consumeToolCapture(state.capture.String(), toolNames) + if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 { + events = append(events, toolStreamEvent{ToolCallDeltas: deltas}) + } + prefix, calls, suffix, ready := consumeToolCapture(state, toolNames) if !ready { + if state.capture.Len() > toolSieveCaptureLimit { + content := state.capture.String() + state.capture.Reset() + state.capturing 
= false + state.resetIncrementalToolState() + state.noteText(content) + events = append(events, toolStreamEvent{Content: content}) + continue + } break } state.capture.Reset() state.capturing = false + state.resetIncrementalToolState() if prefix != "" { + state.noteText(prefix) events = append(events, toolStreamEvent{Content: prefix}) } if len(calls) > 0 { @@ -58,11 +61,13 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames if start >= 0 { prefix := pending[:start] if prefix != "" { + state.noteText(prefix) events = append(events, toolStreamEvent{Content: prefix}) } state.pending.Reset() state.capture.WriteString(pending[start:]) state.capturing = true + state.resetIncrementalToolState() continue } @@ -72,6 +77,7 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames } state.pending.Reset() state.pending.WriteString(hold) + state.noteText(safe) events = append(events, toolStreamEvent{Content: safe}) } @@ -84,25 +90,34 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea } events := processToolSieveChunk(state, "", toolNames) if state.capturing { - consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state.capture.String(), toolNames) + consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames) if ready { if consumedPrefix != "" { + state.noteText(consumedPrefix) events = append(events, toolStreamEvent{Content: consumedPrefix}) } if len(consumedCalls) > 0 { events = append(events, toolStreamEvent{ToolCalls: consumedCalls}) } if consumedSuffix != "" { + state.noteText(consumedSuffix) events = append(events, toolStreamEvent{Content: consumedSuffix}) } } else { - // Incomplete captured tool JSON at stream end: suppress raw capture. 
+ content := state.capture.String() + if content != "" { + state.noteText(content) + events = append(events, toolStreamEvent{Content: content}) + } } state.capture.Reset() state.capturing = false + state.resetIncrementalToolState() } if state.pending.Len() > 0 { - events = append(events, toolStreamEvent{Content: state.pending.String()}) + content := state.pending.String() + state.noteText(content) + events = append(events, toolStreamEvent{Content: content}) state.pending.Reset() } return events @@ -144,17 +159,26 @@ func findToolSegmentStart(s string) int { return -1 } lower := strings.ToLower(s) - keyIdx := strings.Index(lower, "tool_calls") - if keyIdx < 0 { - return -1 + offset := 0 + for { + keyRel := strings.Index(lower[offset:], "tool_calls") + if keyRel < 0 { + return -1 + } + keyIdx := offset + keyRel + start := strings.LastIndex(s[:keyIdx], "{") + if start < 0 { + start = keyIdx + } + if !insideCodeFence(s[:start]) { + return start + } + offset = keyIdx + len("tool_calls") } - if start := strings.LastIndex(s[:keyIdx], "{"); start >= 0 { - return start - } - return keyIdx } -func consumeToolCapture(captured string, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) { +func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) { + captured := state.capture.String() if captured == "" { return "", nil, "", false } @@ -171,53 +195,19 @@ func consumeToolCapture(captured string, toolNames []string) (prefix string, cal if !ok { return "", nil, "", false } - parsed := util.ParseToolCalls(obj, toolNames) - if len(parsed) == 0 { - // `tool_calls` key exists but strict JSON parse failed. - // Drop the captured object body to avoid leaking raw tool JSON. 
- return captured[:start], nil, captured[end:], true + prefixPart := captured[:start] + suffixPart := captured[end:] + if insideCodeFence(state.recentTextTail + prefixPart) { + return captured, nil, "", true } - return captured[:start], parsed, captured[end:], true -} - -func extractJSONObjectFrom(text string, start int) (string, int, bool) { - if start < 0 || start >= len(text) || text[start] != '{' { - return "", 0, false - } - depth := 0 - quote := byte(0) - escaped := false - for i := start; i < len(text); i++ { - ch := text[i] - if quote != 0 { - if escaped { - escaped = false - continue - } - if ch == '\\' { - escaped = true - continue - } - if ch == quote { - quote = 0 - } - continue - } - if ch == '"' || ch == '\'' { - quote = ch - continue - } - if ch == '{' { - depth++ - continue - } - if ch == '}' { - depth-- - if depth == 0 { - end := i + 1 - return text[start:end], end, true - } - } - } - return "", 0, false + parsed := util.ParseStandaloneToolCallsDetailed(obj, toolNames) + if len(parsed.Calls) == 0 { + if parsed.SawToolCallSyntax && parsed.RejectedByPolicy { + // Parsed as tool-call payload but rejected by schema/policy: + // consume it to avoid leaking raw tool_calls JSON to user content. 
+ return prefixPart, nil, suffixPart, true + } + return captured, nil, "", true + } + return prefixPart, parsed.Calls, suffixPart, true } diff --git a/internal/adapter/openai/tool_sieve_incremental.go b/internal/adapter/openai/tool_sieve_incremental.go new file mode 100644 index 0000000..ad0f901 --- /dev/null +++ b/internal/adapter/openai/tool_sieve_incremental.go @@ -0,0 +1,291 @@ +package openai + +import "strings" + +func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta { + if state.disableDeltas { + return nil + } + captured := state.capture.String() + if captured == "" { + return nil + } + lower := strings.ToLower(captured) + keyIdx := strings.Index(lower, "tool_calls") + if keyIdx < 0 { + return nil + } + start := strings.LastIndex(captured[:keyIdx], "{") + if start < 0 { + return nil + } + if insideCodeFence(state.recentTextTail + captured[:start]) { + return nil + } + certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx) + if hasMultiple { + state.disableDeltas = true + return nil + } + if !certainSingle { + // In uncertain phases (e.g. first call arrived but array not closed yet), + // avoid speculative deltas and wait for final parsed tool_calls payload. 
+ return nil + } + callStart, ok := findFirstToolCallObjectStart(captured, keyIdx) + if !ok { + return nil + } + deltas := make([]toolCallDelta, 0, 2) + if state.toolName == "" { + name, ok := extractToolCallName(captured, callStart) + if !ok || name == "" { + return nil + } + state.toolName = name + } + if state.toolArgsStart < 0 { + argsStart, stringMode, ok := findToolCallArgsStart(captured, callStart) + if ok { + state.toolArgsString = stringMode + if stringMode { + state.toolArgsStart = argsStart + 1 + } else { + state.toolArgsStart = argsStart + } + state.toolArgsSent = state.toolArgsStart + } + } + if !state.toolNameSent { + if state.toolArgsStart < 0 { + return nil + } + state.toolNameSent = true + deltas = append(deltas, toolCallDelta{Index: 0, Name: state.toolName}) + } + if state.toolArgsStart < 0 || state.toolArgsDone { + return deltas + } + end, complete, ok := scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString) + if !ok { + return deltas + } + if end > state.toolArgsSent { + deltas = append(deltas, toolCallDelta{ + Index: 0, + Arguments: captured[state.toolArgsSent:end], + }) + state.toolArgsSent = end + } + if complete { + state.toolArgsDone = true + } + return deltas +} + +func classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) { + arrStart, ok := findToolCallsArrayStart(text, keyIdx) + if !ok { + return false, false + } + i := skipSpaces(text, arrStart+1) + if i >= len(text) || text[i] != '{' { + return false, false + } + count := 0 + depth := 0 + quote := byte(0) + escaped := false + for ; i < len(text); i++ { + ch := text[i] + if quote != 0 { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '"' || ch == '\'' { + quote = ch + continue + } + if ch == '{' { + if depth == 0 { + count++ + if count > 1 { + return false, true + } + } + depth++ + continue + } + if ch == '}' 
{ + if depth > 0 { + depth-- + } + continue + } + if ch == ',' && depth == 0 { + // top-level separator means at least one more tool call exists + // (or is expected). Treat as multi-call and stop incremental deltas. + return false, true + } + if ch == ']' && depth == 0 { + return count == 1, false + } + } + // array not closed yet: still uncertain whether more calls will appear + return false, false +} + +func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) { + arrStart, ok := findToolCallsArrayStart(text, keyIdx) + if !ok { + return -1, false + } + i := skipSpaces(text, arrStart+1) + if i >= len(text) || text[i] != '{' { + return -1, false + } + return i, true +} + +func findToolCallsArrayStart(text string, keyIdx int) (int, bool) { + i := keyIdx + len("tool_calls") + for i < len(text) && text[i] != ':' { + i++ + } + if i >= len(text) { + return -1, false + } + i = skipSpaces(text, i+1) + if i >= len(text) || text[i] != '[' { + return -1, false + } + return i, true +} + +func extractToolCallName(text string, callStart int) (string, bool) { + valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"name"}) + if !ok || valueStart >= len(text) || text[valueStart] != '"' { + fnStart, fnOK := findFunctionObjectStart(text, callStart) + if !fnOK { + return "", false + } + valueStart, ok = findObjectFieldValueStart(text, fnStart, []string{"name"}) + if !ok || valueStart >= len(text) || text[valueStart] != '"' { + return "", false + } + } + name, _, ok := parseJSONStringLiteral(text, valueStart) + if !ok { + return "", false + } + return name, true +} + +func findToolCallArgsStart(text string, callStart int) (int, bool, bool) { + keys := []string{"input", "arguments", "args", "parameters", "params"} + valueStart, ok := findObjectFieldValueStart(text, callStart, keys) + if !ok { + fnStart, fnOK := findFunctionObjectStart(text, callStart) + if !fnOK { + return -1, false, false + } + valueStart, ok = findObjectFieldValueStart(text, fnStart, 
keys) + if !ok { + return -1, false, false + } + } + if valueStart >= len(text) { + return -1, false, false + } + ch := text[valueStart] + if ch == '{' || ch == '[' { + return valueStart, false, true + } + if ch == '"' { + return valueStart, true, true + } + return -1, false, false +} + +func scanToolCallArgsProgress(text string, start int, stringMode bool) (int, bool, bool) { + if start < 0 || start > len(text) { + return 0, false, false + } + if stringMode { + escaped := false + for i := start; i < len(text); i++ { + ch := text[i] + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == '"' { + return i, true, true + } + } + return len(text), false, true + } + if start >= len(text) { + return start, false, false + } + if text[start] != '{' && text[start] != '[' { + return 0, false, false + } + depth := 0 + quote := byte(0) + escaped := false + for i := start; i < len(text); i++ { + ch := text[i] + if quote != 0 { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '"' || ch == '\'' { + quote = ch + continue + } + if ch == '{' || ch == '[' { + depth++ + continue + } + if ch == '}' || ch == ']' { + depth-- + if depth == 0 { + return i + 1, true, true + } + } + } + return len(text), false, true +} + +func findFunctionObjectStart(text string, callStart int) (int, bool) { + valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"function"}) + if !ok || valueStart >= len(text) || text[valueStart] != '{' { + return -1, false + } + return valueStart, true +} diff --git a/internal/adapter/openai/tool_sieve_jsonscan.go b/internal/adapter/openai/tool_sieve_jsonscan.go new file mode 100644 index 0000000..d3abcc5 --- /dev/null +++ b/internal/adapter/openai/tool_sieve_jsonscan.go @@ -0,0 +1,152 @@ +package openai + +import "strings" + +func extractJSONObjectFrom(text string, start int) (string, int, bool) 
{ + if start < 0 || start >= len(text) || text[start] != '{' { + return "", 0, false + } + depth := 0 + quote := byte(0) + escaped := false + for i := start; i < len(text); i++ { + ch := text[i] + if quote != 0 { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '"' || ch == '\'' { + quote = ch + continue + } + if ch == '{' { + depth++ + continue + } + if ch == '}' { + depth-- + if depth == 0 { + end := i + 1 + return text[start:end], end, true + } + } + } + return "", 0, false +} + +func findObjectFieldValueStart(text string, objStart int, keys []string) (int, bool) { + if objStart < 0 || objStart >= len(text) || text[objStart] != '{' { + return 0, false + } + depth := 0 + quote := byte(0) + escaped := false + for i := objStart; i < len(text); i++ { + ch := text[i] + if quote != 0 { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '"' || ch == '\'' { + if depth == 1 { + key, end, ok := parseJSONStringLiteral(text, i) + if !ok { + return 0, false + } + j := skipSpaces(text, end) + if j >= len(text) || text[j] != ':' { + i = end - 1 + continue + } + j = skipSpaces(text, j+1) + if j >= len(text) { + return 0, false + } + if containsKey(keys, key) { + return j, true + } + i = j - 1 + continue + } + quote = ch + continue + } + if ch == '{' { + depth++ + continue + } + if ch == '}' { + depth-- + if depth == 0 { + break + } + } + } + return 0, false +} + +func parseJSONStringLiteral(text string, start int) (string, int, bool) { + if start < 0 || start >= len(text) || text[start] != '"' { + return "", 0, false + } + var b strings.Builder + escaped := false + for i := start + 1; i < len(text); i++ { + ch := text[i] + if escaped { + b.WriteByte(ch) + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == '"' { + return 
b.String(), i + 1, true + } + b.WriteByte(ch) + } + return "", 0, false +} + +func containsKey(keys []string, value string) bool { + for _, k := range keys { + if k == value { + return true + } + } + return false +} + +func skipSpaces(text string, i int) int { + for i < len(text) { + switch text[i] { + case ' ', '\t', '\n', '\r': + i++ + default: + return i + } + } + return i +} diff --git a/internal/adapter/openai/tool_sieve_state.go b/internal/adapter/openai/tool_sieve_state.go new file mode 100644 index 0000000..04699e6 --- /dev/null +++ b/internal/adapter/openai/tool_sieve_state.go @@ -0,0 +1,75 @@ +package openai + +import ( + "strings" + + "ds2api/internal/util" +) + +type toolStreamSieveState struct { + pending strings.Builder + capture strings.Builder + capturing bool + recentTextTail string + disableDeltas bool + toolNameSent bool + toolName string + toolArgsStart int + toolArgsSent int + toolArgsString bool + toolArgsDone bool +} + +type toolStreamEvent struct { + Content string + ToolCalls []util.ParsedToolCall + ToolCallDeltas []toolCallDelta +} + +type toolCallDelta struct { + Index int + Name string + Arguments string +} + +const toolSieveCaptureLimit = 8 * 1024 +const toolSieveContextTailLimit = 256 + +func (s *toolStreamSieveState) resetIncrementalToolState() { + s.disableDeltas = false + s.toolNameSent = false + s.toolName = "" + s.toolArgsStart = -1 + s.toolArgsSent = -1 + s.toolArgsString = false + s.toolArgsDone = false +} + +func (s *toolStreamSieveState) noteText(content string) { + if strings.TrimSpace(content) == "" { + return + } + s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit) +} + +func appendTail(prev, next string, max int) string { + if max <= 0 { + return "" + } + combined := prev + next + if len(combined) <= max { + return combined + } + return combined[len(combined)-max:] +} + +func looksLikeToolExampleContext(text string) bool { + return insideCodeFence(text) +} + +func insideCodeFence(text 
string) bool { + if text == "" { + return false + } + return strings.Count(text, "```")%2 == 1 +} diff --git a/internal/adapter/openai/trace.go b/internal/adapter/openai/trace.go new file mode 100644 index 0000000..8ea58f0 --- /dev/null +++ b/internal/adapter/openai/trace.go @@ -0,0 +1,21 @@ +package openai + +import ( + "net/http" + "strings" + + "github.com/go-chi/chi/v5/middleware" +) + +func requestTraceID(r *http.Request) string { + if r == nil { + return "" + } + if q := strings.TrimSpace(r.URL.Query().Get("__trace_id")); q != "" { + return q + } + if h := strings.TrimSpace(r.Header.Get("X-Ds2-Test-Trace")); h != "" { + return h + } + return strings.TrimSpace(middleware.GetReqID(r.Context())) +} diff --git a/internal/adapter/openai/trace_test.go b/internal/adapter/openai/trace_test.go new file mode 100644 index 0000000..cbacbf3 --- /dev/null +++ b/internal/adapter/openai/trace_test.go @@ -0,0 +1,47 @@ +package openai + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5/middleware" +) + +func traceIDViaMiddleware(req *http.Request) string { + if req == nil { + return requestTraceID(nil) + } + var got string + h := middleware.RequestID(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + got = requestTraceID(r) + })) + h.ServeHTTP(httptest.NewRecorder(), req) + return got +} + +func TestRequestTraceIDPriority(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions?__trace_id=query-trace", nil) + req.Header.Set("X-Ds2-Test-Trace", "header-trace") + got := traceIDViaMiddleware(req) + if got != "query-trace" { + t.Fatalf("expected query trace id to win, got %q", got) + } +} + +func TestRequestTraceIDHeaderFallback(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil) + req.Header.Set("X-Ds2-Test-Trace", "header-trace") + got := traceIDViaMiddleware(req) + if got != "header-trace" { + t.Fatalf("expected header trace id to win when query missing, got %q", 
got) + } +} + +func TestRequestTraceIDReqIDFallback(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil) + got := traceIDViaMiddleware(req) + if got == "" { + t.Fatal("expected middleware request id fallback to be non-empty") + } +} diff --git a/internal/adapter/openai/vercel_stream.go b/internal/adapter/openai/vercel_stream.go index 653f3cf..f34ea8b 100644 --- a/internal/adapter/openai/vercel_stream.go +++ b/internal/adapter/openai/vercel_stream.go @@ -56,24 +56,16 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque writeOpenAIError(w, http.StatusBadRequest, "stream must be true") return } - model, _ := req["model"].(string) - messagesRaw, _ := req["messages"].([]any) - if model == "" || len(messagesRaw) == 0 { - writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") + stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r)) + if err != nil { + writeOpenAIError(w, http.StatusBadRequest, err.Error()) return } - thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model) - if !ok { - writeOpenAIError(w, http.StatusServiceUnavailable, fmt.Sprintf("Model '%s' is not available.", model)) + if !stdReq.Stream { + writeOpenAIError(w, http.StatusBadRequest, "stream must be true") return } - messages := normalizeMessages(messagesRaw) - if tools, ok := req["tools"].([]any); ok && len(tools) > 0 { - messages, _ = injectToolPrompt(messages, tools) - } - finalPrompt := util.MessagesPrepare(messages) - sessionID, err := h.DS.CreateSession(r.Context(), a, 3) if err != nil { if a.UseConfigToken { @@ -93,14 +85,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque return } - payload := map[string]any{ - "chat_session_id": sessionID, - "parent_message_id": nil, - "prompt": finalPrompt, - "ref_file_ids": []any{}, - "thinking_enabled": thinkingEnabled, - "search_enabled": searchEnabled, - } + payload := 
stdReq.CompletionPayload(sessionID) leaseID := h.holdStreamLease(a) if leaseID == "" { writeOpenAIError(w, http.StatusInternalServerError, "failed to create stream lease") @@ -108,15 +93,18 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque } leased = true writeJSON(w, http.StatusOK, map[string]any{ - "session_id": sessionID, - "lease_id": leaseID, - "model": model, - "final_prompt": finalPrompt, - "thinking_enabled": thinkingEnabled, - "search_enabled": searchEnabled, - "deepseek_token": a.DeepSeekToken, - "pow_header": powHeader, - "payload": payload, + "session_id": sessionID, + "lease_id": leaseID, + "model": stdReq.ResponseModel, + "final_prompt": stdReq.FinalPrompt, + "thinking_enabled": stdReq.Thinking, + "search_enabled": stdReq.Search, + "tool_names": stdReq.ToolNames, + "toolcall_feature_match": h.toolcallFeatureMatchEnabled(), + "toolcall_early_emit_high": h.toolcallEarlyEmitHighConfidence(), + "deepseek_token": a.DeepSeekToken, + "pow_header": powHeader, + "payload": payload, }) } diff --git a/internal/admin/deps.go b/internal/admin/deps.go new file mode 100644 index 0000000..e92c37b --- /dev/null +++ b/internal/admin/deps.go @@ -0,0 +1,46 @@ +package admin + +import ( + "context" + "net/http" + + "ds2api/internal/account" + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/deepseek" +) + +type ConfigStore interface { + Snapshot() config.Config + Keys() []string + Accounts() []config.Account + FindAccount(identifier string) (config.Account, bool) + UpdateAccountToken(identifier, token string) error + Update(mutator func(*config.Config) error) error + ExportJSONAndBase64() (string, string, error) + IsEnvBacked() bool + SetVercelSync(hash string, ts int64) error + AdminPasswordHash() string + AdminJWTExpireHours() int + AdminJWTValidAfterUnix() int64 + RuntimeAccountMaxInflight() int + RuntimeAccountMaxQueue(defaultSize int) int + RuntimeGlobalMaxInflight(defaultSize int) int +} + +type PoolController 
interface { + Reset() + Status() map[string]any + ApplyRuntimeLimits(maxInflightPerAccount, maxQueueSize, globalMaxInflight int) +} + +type DeepSeekCaller interface { + Login(ctx context.Context, acc config.Account) (string, error) + CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) +} + +var _ ConfigStore = (*config.Store)(nil) +var _ PoolController = (*account.Pool)(nil) +var _ DeepSeekCaller = (*deepseek.Client)(nil) diff --git a/internal/admin/handler.go b/internal/admin/handler.go index 9d6151e..c8f7702 100644 --- a/internal/admin/handler.go +++ b/internal/admin/handler.go @@ -2,16 +2,12 @@ package admin import ( "github.com/go-chi/chi/v5" - - "ds2api/internal/account" - "ds2api/internal/config" - "ds2api/internal/deepseek" ) type Handler struct { - Store *config.Store - Pool *account.Pool - DS *deepseek.Client + Store ConfigStore + Pool PoolController + DS DeepSeekCaller } func RegisterRoutes(r chi.Router, h *Handler) { @@ -22,6 +18,11 @@ func RegisterRoutes(r chi.Router, h *Handler) { pr.Get("/vercel/config", h.getVercelConfig) pr.Get("/config", h.getConfig) pr.Post("/config", h.updateConfig) + pr.Get("/settings", h.getSettings) + pr.Put("/settings", h.updateSettings) + pr.Post("/settings/password", h.updateSettingsPassword) + pr.Post("/config/import", h.configImport) + pr.Get("/config/export", h.configExport) pr.Post("/keys", h.addKey) pr.Delete("/keys/{key}", h.deleteKey) pr.Get("/accounts", h.listAccounts) @@ -35,5 +36,7 @@ func RegisterRoutes(r chi.Router, h *Handler) { pr.Post("/vercel/sync", h.syncVercel) pr.Get("/vercel/status", h.vercelStatus) pr.Get("/export", h.exportConfig) + pr.Get("/dev/captures", h.getDevCaptures) + pr.Delete("/dev/captures", h.clearDevCaptures) }) } diff --git 
a/internal/admin/handler_accounts_crud.go b/internal/admin/handler_accounts_crud.go new file mode 100644 index 0000000..daaa434 --- /dev/null +++ b/internal/admin/handler_accounts_crud.go @@ -0,0 +1,114 @@ +package admin + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/config" +) + +func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) { + page := intFromQuery(r, "page", 1) + pageSize := intFromQuery(r, "page_size", 10) + if page < 1 { + page = 1 + } + if pageSize < 1 { + pageSize = 1 + } + if pageSize > 100 { + pageSize = 100 + } + accounts := h.Store.Snapshot().Accounts + total := len(accounts) + reverseAccounts(accounts) + totalPages := 1 + if total > 0 { + totalPages = (total + pageSize - 1) / pageSize + } + start := (page - 1) * pageSize + if start > total { + start = total + } + end := start + pageSize + if end > total { + end = total + } + items := make([]map[string]any, 0, end-start) + for _, acc := range accounts[start:end] { + token := strings.TrimSpace(acc.Token) + preview := "" + if token != "" { + if len(token) > 20 { + preview = token[:20] + "..." 
+ } else { + preview = token + } + } + items = append(items, map[string]any{ + "identifier": acc.Identifier(), + "email": acc.Email, + "mobile": acc.Mobile, + "has_password": acc.Password != "", + "has_token": token != "", + "token_preview": preview, + }) + } + writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages}) +} + +func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) { + var req map[string]any + _ = json.NewDecoder(r.Body).Decode(&req) + acc := toAccount(req) + if acc.Identifier() == "" { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"}) + return + } + err := h.Store.Update(func(c *config.Config) error { + for _, a := range c.Accounts { + if acc.Email != "" && a.Email == acc.Email { + return fmt.Errorf("邮箱已存在") + } + if acc.Mobile != "" && a.Mobile == acc.Mobile { + return fmt.Errorf("手机号已存在") + } + } + c.Accounts = append(c.Accounts, acc) + return nil + }) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + h.Pool.Reset() + writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)}) +} + +func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) { + identifier := chi.URLParam(r, "identifier") + err := h.Store.Update(func(c *config.Config) error { + idx := -1 + for i, a := range c.Accounts { + if accountMatchesIdentifier(a, identifier) { + idx = i + break + } + } + if idx < 0 { + return fmt.Errorf("账号不存在") + } + c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...) 
+ return nil + }) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()}) + return + } + h.Pool.Reset() + writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)}) +} diff --git a/internal/admin/handler_accounts_identifier_test.go b/internal/admin/handler_accounts_identifier_test.go new file mode 100644 index 0000000..591d43a --- /dev/null +++ b/internal/admin/handler_accounts_identifier_test.go @@ -0,0 +1,138 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/account" + "ds2api/internal/config" +) + +func newAdminTestHandler(t *testing.T, raw string) *Handler { + t.Helper() + t.Setenv("DS2API_CONFIG_JSON", raw) + t.Setenv("CONFIG_JSON", "") + store := config.LoadStore() + return &Handler{ + Store: store, + Pool: account.NewPool(store), + } +} + +func TestListAccountsIncludesTokenOnlyIdentifier(t *testing.T) { + h := newAdminTestHandler(t, `{ + "accounts":[{"token":"token-only-account"}] + }`) + + req := httptest.NewRequest(http.MethodGet, "/admin/accounts?page=1&page_size=10", nil) + rec := httptest.NewRecorder() + h.listAccounts(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String()) + } + + var payload map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("decode response failed: %v", err) + } + items, _ := payload["items"].([]any) + if len(items) != 1 { + t.Fatalf("expected 1 item, got %d", len(items)) + } + first, _ := items[0].(map[string]any) + identifier, _ := first["identifier"].(string) + if identifier == "" { + t.Fatalf("expected non-empty identifier: %#v", first) + } + if !strings.HasPrefix(identifier, "token:") { + t.Fatalf("expected token synthetic identifier, got %q", identifier) + } +} + +func TestDeleteAccountSupportsTokenOnlyIdentifier(t 
*testing.T) { + h := newAdminTestHandler(t, `{ + "accounts":[{"token":"token-only-account"}] + }`) + accounts := h.Store.Accounts() + if len(accounts) != 1 { + t.Fatalf("expected 1 account, got %d", len(accounts)) + } + id := accounts[0].Identifier() + if id == "" { + t.Fatal("expected token-only synthetic identifier") + } + + r := chi.NewRouter() + r.Delete("/admin/accounts/{identifier}", h.deleteAccount) + req := httptest.NewRequest(http.MethodDelete, "/admin/accounts/"+url.PathEscape(id), nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String()) + } + if got := len(h.Store.Accounts()); got != 0 { + t.Fatalf("expected account removed, remaining=%d", got) + } +} + +func TestDeleteAccountSupportsMobileAlias(t *testing.T) { + h := newAdminTestHandler(t, `{ + "accounts":[{"email":"u@example.com","mobile":"13800138000","password":"pwd"}] + }`) + + r := chi.NewRouter() + r.Delete("/admin/accounts/{identifier}", h.deleteAccount) + req := httptest.NewRequest(http.MethodDelete, "/admin/accounts/13800138000", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String()) + } + if got := len(h.Store.Accounts()); got != 0 { + t.Fatalf("expected account removed, remaining=%d", got) + } +} + +func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) { + h := newAdminTestHandler(t, `{ + "accounts":[ + {"email":"u@example.com","mobile":"13800138000","password":"pwd"}, + {"token":"token-only-account"} + ] + }`) + + accByMobile, ok := findAccountByIdentifier(h.Store, "13800138000") + if !ok { + t.Fatal("expected find by mobile") + } + if accByMobile.Email != "u@example.com" { + t.Fatalf("unexpected account by mobile: %#v", accByMobile) + } + + tokenOnlyID := "" + for _, acc := range h.Store.Accounts() { + if strings.TrimSpace(acc.Email) == "" && 
strings.TrimSpace(acc.Mobile) == "" { + tokenOnlyID = acc.Identifier() + break + } + } + if tokenOnlyID == "" { + t.Fatal("expected token-only account identifier") + } + accByTokenOnly, ok := findAccountByIdentifier(h.Store, tokenOnlyID) + if !ok { + t.Fatalf("expected find by token-only id=%q", tokenOnlyID) + } + if accByTokenOnly.Token != "token-only-account" { + t.Fatalf("unexpected token-only account: %#v", accByTokenOnly) + } +} diff --git a/internal/admin/handler_accounts_queue.go b/internal/admin/handler_accounts_queue.go new file mode 100644 index 0000000..108f802 --- /dev/null +++ b/internal/admin/handler_accounts_queue.go @@ -0,0 +1,7 @@ +package admin + +import "net/http" + +func (h *Handler) queueStatus(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, h.Pool.Status()) +} diff --git a/internal/admin/handler_accounts.go b/internal/admin/handler_accounts_testing.go similarity index 68% rename from internal/admin/handler_accounts.go rename to internal/admin/handler_accounts_testing.go index b95077d..2bd7706 100644 --- a/internal/admin/handler_accounts.go +++ b/internal/admin/handler_accounts_testing.go @@ -11,121 +11,20 @@ import ( "sync" "time" - "github.com/go-chi/chi/v5" - authn "ds2api/internal/auth" "ds2api/internal/config" "ds2api/internal/sse" ) -func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) { - page := intFromQuery(r, "page", 1) - pageSize := intFromQuery(r, "page_size", 10) - if page < 1 { - page = 1 - } - if pageSize < 1 { - pageSize = 1 - } - if pageSize > 100 { - pageSize = 100 - } - accounts := h.Store.Snapshot().Accounts - total := len(accounts) - reverseAccounts(accounts) - totalPages := 1 - if total > 0 { - totalPages = (total + pageSize - 1) / pageSize - } - start := (page - 1) * pageSize - if start > total { - start = total - } - end := start + pageSize - if end > total { - end = total - } - items := make([]map[string]any, 0, end-start) - for _, acc := range accounts[start:end] { - token := 
strings.TrimSpace(acc.Token) - preview := "" - if token != "" { - if len(token) > 20 { - preview = token[:20] + "..." - } else { - preview = token - } - } - items = append(items, map[string]any{"email": acc.Email, "mobile": acc.Mobile, "has_password": acc.Password != "", "has_token": token != "", "token_preview": preview}) - } - writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages}) -} - -func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) { - var req map[string]any - _ = json.NewDecoder(r.Body).Decode(&req) - acc := toAccount(req) - if acc.Identifier() == "" { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"}) - return - } - err := h.Store.Update(func(c *config.Config) error { - for _, a := range c.Accounts { - if acc.Email != "" && a.Email == acc.Email { - return fmt.Errorf("邮箱已存在") - } - if acc.Mobile != "" && a.Mobile == acc.Mobile { - return fmt.Errorf("手机号已存在") - } - } - c.Accounts = append(c.Accounts, acc) - return nil - }) - if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) - return - } - h.Pool.Reset() - writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)}) -} - -func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) { - identifier := chi.URLParam(r, "identifier") - err := h.Store.Update(func(c *config.Config) error { - idx := -1 - for i, a := range c.Accounts { - if a.Email == identifier || a.Mobile == identifier { - idx = i - break - } - } - if idx < 0 { - return fmt.Errorf("账号不存在") - } - c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...) 
- return nil - }) - if err != nil { - writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()}) - return - } - h.Pool.Reset() - writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)}) -} - -func (h *Handler) queueStatus(w http.ResponseWriter, _ *http.Request) { - writeJSON(w, http.StatusOK, h.Pool.Status()) -} - func (h *Handler) testSingleAccount(w http.ResponseWriter, r *http.Request) { var req map[string]any _ = json.NewDecoder(r.Body).Decode(&req) identifier, _ := req["identifier"].(string) if strings.TrimSpace(identifier) == "" { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(email 或 mobile)"}) + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(identifier / email / mobile)"}) return } - acc, ok := h.Store.FindAccount(identifier) + acc, ok := findAccountByIdentifier(h.Store, identifier) if !ok { writeJSON(w, http.StatusNotFound, map[string]any{"detail": "账号不存在"}) return diff --git a/internal/admin/handler_auth.go b/internal/admin/handler_auth.go index 0d3ec1f..9b96b2f 100644 --- a/internal/admin/handler_auth.go +++ b/internal/admin/handler_auth.go @@ -12,7 +12,7 @@ import ( func (h *Handler) requireAdmin(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if err := authn.VerifyAdminRequest(r); err != nil { + if err := authn.VerifyAdminRequestWithStore(r, h.Store); err != nil { writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": err.Error()}) return } @@ -25,18 +25,18 @@ func (h *Handler) login(w http.ResponseWriter, r *http.Request) { _ = json.NewDecoder(r.Body).Decode(&req) adminKey, _ := req["admin_key"].(string) expireHours := intFrom(req["expire_hours"]) - if expireHours <= 0 { - expireHours = 24 - } - if adminKey != authn.AdminKey() { + if !authn.VerifyAdminCredential(adminKey, h.Store) { writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": "Invalid admin key"}) 
return } - token, err := authn.CreateJWT(expireHours) + token, err := authn.CreateJWTWithStore(expireHours, h.Store) if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) return } + if expireHours <= 0 { + expireHours = h.Store.AdminJWTExpireHours() + } writeJSON(w, http.StatusOK, map[string]any{"success": true, "token": token, "expires_in": expireHours * 3600}) } @@ -47,7 +47,7 @@ func (h *Handler) verify(w http.ResponseWriter, r *http.Request) { return } token := strings.TrimSpace(header[7:]) - payload, err := authn.VerifyJWT(token) + payload, err := authn.VerifyJWTWithStore(token, h.Store) if err != nil { writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": err.Error()}) return diff --git a/internal/admin/handler_config_import.go b/internal/admin/handler_config_import.go new file mode 100644 index 0000000..674d8b2 --- /dev/null +++ b/internal/admin/handler_config_import.go @@ -0,0 +1,182 @@ +package admin + +import ( + "crypto/md5" + "encoding/json" + "fmt" + "net/http" + "strings" + + "ds2api/internal/config" +) + +func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) + return + } + + mode := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("mode"))) + if mode == "" { + mode = strings.TrimSpace(strings.ToLower(fieldString(req, "mode"))) + } + if mode == "" { + mode = "merge" + } + if mode != "merge" && mode != "replace" { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "mode must be merge or replace"}) + return + } + + payload := req + if raw, ok := req["config"].(map[string]any); ok && len(raw) > 0 { + payload = raw + } + rawJSON, err := json.Marshal(payload) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid config payload"}) + return + } + var incoming 
config.Config + if err := json.Unmarshal(rawJSON, &incoming); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + + importedKeys, importedAccounts := 0, 0 + err = h.Store.Update(func(c *config.Config) error { + next := c.Clone() + if mode == "replace" { + next = incoming.Clone() + next.VercelSyncHash = c.VercelSyncHash + next.VercelSyncTime = c.VercelSyncTime + importedKeys = len(next.Keys) + importedAccounts = len(next.Accounts) + } else { + existingKeys := map[string]struct{}{} + for _, k := range next.Keys { + existingKeys[k] = struct{}{} + } + for _, k := range incoming.Keys { + key := strings.TrimSpace(k) + if key == "" { + continue + } + if _, ok := existingKeys[key]; ok { + continue + } + existingKeys[key] = struct{}{} + next.Keys = append(next.Keys, key) + importedKeys++ + } + + existingAccounts := map[string]struct{}{} + for _, acc := range next.Accounts { + existingAccounts[acc.Identifier()] = struct{}{} + } + for _, acc := range incoming.Accounts { + id := acc.Identifier() + if id == "" { + continue + } + if _, ok := existingAccounts[id]; ok { + continue + } + existingAccounts[id] = struct{}{} + next.Accounts = append(next.Accounts, acc) + importedAccounts++ + } + + if len(incoming.ClaudeMapping) > 0 { + if next.ClaudeMapping == nil { + next.ClaudeMapping = map[string]string{} + } + for k, v := range incoming.ClaudeMapping { + next.ClaudeMapping[k] = v + } + } + if len(incoming.ClaudeModelMap) > 0 { + if next.ClaudeModelMap == nil { + next.ClaudeModelMap = map[string]string{} + } + for k, v := range incoming.ClaudeModelMap { + next.ClaudeModelMap[k] = v + } + } + + if len(incoming.ModelAliases) > 0 { + if next.ModelAliases == nil { + next.ModelAliases = map[string]string{} + } + for k, v := range incoming.ModelAliases { + next.ModelAliases[k] = v + } + } + if strings.TrimSpace(incoming.Toolcall.Mode) != "" { + next.Toolcall.Mode = incoming.Toolcall.Mode + } + if 
strings.TrimSpace(incoming.Toolcall.EarlyEmitConfidence) != "" { + next.Toolcall.EarlyEmitConfidence = incoming.Toolcall.EarlyEmitConfidence + } + if incoming.Responses.StoreTTLSeconds > 0 { + next.Responses.StoreTTLSeconds = incoming.Responses.StoreTTLSeconds + } + if strings.TrimSpace(incoming.Embeddings.Provider) != "" { + next.Embeddings.Provider = incoming.Embeddings.Provider + } + if strings.TrimSpace(incoming.Admin.PasswordHash) != "" { + next.Admin.PasswordHash = incoming.Admin.PasswordHash + } + if incoming.Admin.JWTExpireHours > 0 { + next.Admin.JWTExpireHours = incoming.Admin.JWTExpireHours + } + if incoming.Admin.JWTValidAfterUnix > 0 { + next.Admin.JWTValidAfterUnix = incoming.Admin.JWTValidAfterUnix + } + if incoming.Runtime.AccountMaxInflight > 0 { + next.Runtime.AccountMaxInflight = incoming.Runtime.AccountMaxInflight + } + if incoming.Runtime.AccountMaxQueue > 0 { + next.Runtime.AccountMaxQueue = incoming.Runtime.AccountMaxQueue + } + if incoming.Runtime.GlobalMaxInflight > 0 { + next.Runtime.GlobalMaxInflight = incoming.Runtime.GlobalMaxInflight + } + } + + normalizeSettingsConfig(&next) + if err := validateSettingsConfig(next); err != nil { + return newRequestError(err.Error()) + } + + *c = next + return nil + }) + if err != nil { + if detail, ok := requestErrorDetail(err); ok { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": detail}) + return + } + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + + h.Pool.Reset() + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "mode": mode, + "imported_keys": importedKeys, + "imported_accounts": importedAccounts, + "message": "config imported", + }) +} + +func (h *Handler) computeSyncHash() string { + snap := h.Store.Snapshot().Clone() + snap.VercelSyncHash = "" + snap.VercelSyncTime = 0 + b, _ := json.Marshal(snap) + sum := md5.Sum(b) + return fmt.Sprintf("%x", sum) +} diff --git a/internal/admin/handler_config_read.go 
b/internal/admin/handler_config_read.go new file mode 100644 index 0000000..e32aabd --- /dev/null +++ b/internal/admin/handler_config_read.go @@ -0,0 +1,61 @@ +package admin + +import ( + "net/http" + "strings" +) + +func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) { + snap := h.Store.Snapshot() + safe := map[string]any{ + "keys": snap.Keys, + "accounts": []map[string]any{}, + "claude_mapping": func() map[string]string { + if len(snap.ClaudeMapping) > 0 { + return snap.ClaudeMapping + } + return snap.ClaudeModelMap + }(), + } + accounts := make([]map[string]any, 0, len(snap.Accounts)) + for _, acc := range snap.Accounts { + token := strings.TrimSpace(acc.Token) + preview := "" + if token != "" { + if len(token) > 20 { + preview = token[:20] + "..." + } else { + preview = token + } + } + accounts = append(accounts, map[string]any{ + "identifier": acc.Identifier(), + "email": acc.Email, + "mobile": acc.Mobile, + "has_password": strings.TrimSpace(acc.Password) != "", + "has_token": token != "", + "token_preview": preview, + }) + } + safe["accounts"] = accounts + writeJSON(w, http.StatusOK, safe) +} + +func (h *Handler) exportConfig(w http.ResponseWriter, _ *http.Request) { + h.configExport(w, nil) +} + +func (h *Handler) configExport(w http.ResponseWriter, _ *http.Request) { + snap := h.Store.Snapshot() + jsonStr, b64, err := h.Store.ExportJSONAndBase64() + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "config": snap, + "json": jsonStr, + "base64": b64, + }) +} diff --git a/internal/admin/handler_config.go b/internal/admin/handler_config_write.go similarity index 68% rename from internal/admin/handler_config.go rename to internal/admin/handler_config_write.go index 7627602..792e696 100644 --- a/internal/admin/handler_config.go +++ b/internal/admin/handler_config_write.go @@ -1,11 +1,9 @@ package admin import ( - 
"crypto/md5" "encoding/json" "fmt" "net/http" - "sort" "strings" "github.com/go-chi/chi/v5" @@ -13,41 +11,6 @@ import ( "ds2api/internal/config" ) -func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) { - snap := h.Store.Snapshot() - safe := map[string]any{ - "keys": snap.Keys, - "accounts": []map[string]any{}, - "claude_mapping": func() map[string]string { - if len(snap.ClaudeMapping) > 0 { - return snap.ClaudeMapping - } - return snap.ClaudeModelMap - }(), - } - accounts := make([]map[string]any, 0, len(snap.Accounts)) - for _, acc := range snap.Accounts { - token := strings.TrimSpace(acc.Token) - preview := "" - if token != "" { - if len(token) > 20 { - preview = token[:20] + "..." - } else { - preview = token - } - } - accounts = append(accounts, map[string]any{ - "email": acc.Email, - "mobile": acc.Mobile, - "has_password": strings.TrimSpace(acc.Password) != "", - "has_token": token != "", - "token_preview": preview, - }) - } - safe["accounts"] = accounts - writeJSON(w, http.StatusOK, safe) -} - func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) { var req map[string]any if err := json.NewDecoder(r.Body).Decode(&req); err != nil { @@ -201,40 +164,3 @@ func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) { h.Pool.Reset() writeJSON(w, http.StatusOK, map[string]any{"success": true, "imported_keys": importedKeys, "imported_accounts": importedAccounts}) } - -func (h *Handler) exportConfig(w http.ResponseWriter, _ *http.Request) { - jsonStr, b64, err := h.Store.ExportJSONAndBase64() - if err != nil { - writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) - return - } - writeJSON(w, http.StatusOK, map[string]any{"json": jsonStr, "base64": b64}) -} - -func (h *Handler) computeSyncHash() string { - snap := h.Store.Snapshot() - syncable := map[string]any{"keys": snap.Keys, "accounts": []map[string]any{}} - accounts := make([]map[string]any, 0, len(snap.Accounts)) - for _, a := range 
snap.Accounts { - m := map[string]any{} - if a.Email != "" { - m["email"] = a.Email - } - if a.Mobile != "" { - m["mobile"] = a.Mobile - } - if a.Password != "" { - m["password"] = a.Password - } - accounts = append(accounts, m) - } - sort.Slice(accounts, func(i, j int) bool { - ai := fmt.Sprintf("%v%v", accounts[i]["email"], accounts[i]["mobile"]) - aj := fmt.Sprintf("%v%v", accounts[j]["email"], accounts[j]["mobile"]) - return ai < aj - }) - syncable["accounts"] = accounts - b, _ := json.Marshal(syncable) - sum := md5.Sum(b) - return fmt.Sprintf("%x", sum) -} diff --git a/internal/admin/handler_dev_capture.go b/internal/admin/handler_dev_capture.go new file mode 100644 index 0000000..9b3615c --- /dev/null +++ b/internal/admin/handler_dev_capture.go @@ -0,0 +1,26 @@ +package admin + +import ( + "net/http" + + "ds2api/internal/devcapture" +) + +func (h *Handler) getDevCaptures(w http.ResponseWriter, _ *http.Request) { + store := devcapture.Global() + writeJSON(w, http.StatusOK, map[string]any{ + "enabled": store.Enabled(), + "limit": store.Limit(), + "max_body_bytes": store.MaxBodyBytes(), + "items": store.Snapshot(), + }) +} + +func (h *Handler) clearDevCaptures(w http.ResponseWriter, _ *http.Request) { + store := devcapture.Global() + store.Clear() + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "detail": "capture logs cleared", + }) +} diff --git a/internal/admin/handler_dev_capture_test.go b/internal/admin/handler_dev_capture_test.go new file mode 100644 index 0000000..90ced8b --- /dev/null +++ b/internal/admin/handler_dev_capture_test.go @@ -0,0 +1,45 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestGetDevCapturesShape(t *testing.T) { + h := &Handler{} + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/admin/dev/captures", nil) + h.getDevCaptures(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, 
rec.Body.String()) + } + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode failed: %v", err) + } + if _, ok := out["enabled"]; !ok { + t.Fatalf("expected enabled field, got %#v", out) + } + if _, ok := out["items"]; !ok { + t.Fatalf("expected items field, got %#v", out) + } +} + +func TestClearDevCapturesShape(t *testing.T) { + h := &Handler{} + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, "/admin/dev/captures", nil) + h.clearDevCaptures(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode failed: %v", err) + } + if out["success"] != true { + t.Fatalf("expected success=true, got %#v", out) + } +} diff --git a/internal/admin/handler_settings_parse.go b/internal/admin/handler_settings_parse.go new file mode 100644 index 0000000..6c5b7ee --- /dev/null +++ b/internal/admin/handler_settings_parse.go @@ -0,0 +1,134 @@ +package admin + +import ( + "fmt" + "strings" + + "ds2api/internal/config" +) + +func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ToolcallConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, map[string]string, map[string]string, error) { + var ( + adminCfg *config.AdminConfig + runtimeCfg *config.RuntimeConfig + toolcallCfg *config.ToolcallConfig + respCfg *config.ResponsesConfig + embCfg *config.EmbeddingsConfig + claudeMap map[string]string + aliasMap map[string]string + ) + + if raw, ok := req["admin"].(map[string]any); ok { + cfg := &config.AdminConfig{} + if v, exists := raw["jwt_expire_hours"]; exists { + n := intFrom(v) + if n < 1 || n > 720 { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720") + } + cfg.JWTExpireHours = n + } + adminCfg = cfg + } + + if raw, ok := 
req["runtime"].(map[string]any); ok { + cfg := &config.RuntimeConfig{} + if v, exists := raw["account_max_inflight"]; exists { + n := intFrom(v) + if n < 1 || n > 256 { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_inflight must be between 1 and 256") + } + cfg.AccountMaxInflight = n + } + if v, exists := raw["account_max_queue"]; exists { + n := intFrom(v) + if n < 1 || n > 200000 { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_queue must be between 1 and 200000") + } + cfg.AccountMaxQueue = n + } + if v, exists := raw["global_max_inflight"]; exists { + n := intFrom(v) + if n < 1 || n > 200000 { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000") + } + cfg.GlobalMaxInflight = n + } + if cfg.AccountMaxInflight > 0 && cfg.GlobalMaxInflight > 0 && cfg.GlobalMaxInflight < cfg.AccountMaxInflight { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight") + } + runtimeCfg = cfg + } + + if raw, ok := req["toolcall"].(map[string]any); ok { + cfg := &config.ToolcallConfig{} + if v, exists := raw["mode"]; exists { + mode := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v))) + switch mode { + case "feature_match", "off": + cfg.Mode = mode + default: + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.mode must be feature_match or off") + } + } + if v, exists := raw["early_emit_confidence"]; exists { + level := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v))) + switch level { + case "high", "low", "off": + cfg.EarlyEmitConfidence = level + default: + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.early_emit_confidence must be high, low or off") + } + } + toolcallCfg = cfg + } + + if raw, ok := req["responses"].(map[string]any); ok { + cfg := &config.ResponsesConfig{} + if v, exists := raw["store_ttl_seconds"]; exists { + n := intFrom(v) + if n < 
30 || n > 86400 { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400") + } + cfg.StoreTTLSeconds = n + } + respCfg = cfg + } + + if raw, ok := req["embeddings"].(map[string]any); ok { + cfg := &config.EmbeddingsConfig{} + if v, exists := raw["provider"]; exists { + p := strings.TrimSpace(fmt.Sprintf("%v", v)) + if p == "" { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("embeddings.provider cannot be empty") + } + cfg.Provider = p + } + embCfg = cfg + } + + if raw, ok := req["claude_mapping"].(map[string]any); ok { + claudeMap = map[string]string{} + for k, v := range raw { + key := strings.TrimSpace(k) + val := strings.TrimSpace(fmt.Sprintf("%v", v)) + if key == "" || val == "" { + continue + } + claudeMap[key] = val + } + } + + if raw, ok := req["model_aliases"].(map[string]any); ok { + aliasMap = map[string]string{} + for k, v := range raw { + key := strings.TrimSpace(k) + val := strings.TrimSpace(fmt.Sprintf("%v", v)) + if key == "" || val == "" { + continue + } + aliasMap[key] = val + } + } + + return adminCfg, runtimeCfg, toolcallCfg, respCfg, embCfg, claudeMap, aliasMap, nil +} diff --git a/internal/admin/handler_settings_read.go b/internal/admin/handler_settings_read.go new file mode 100644 index 0000000..565519f --- /dev/null +++ b/internal/admin/handler_settings_read.go @@ -0,0 +1,36 @@ +package admin + +import ( + "net/http" + "strings" + + authn "ds2api/internal/auth" + "ds2api/internal/config" +) + +func (h *Handler) getSettings(w http.ResponseWriter, _ *http.Request) { + snap := h.Store.Snapshot() + recommended := defaultRuntimeRecommended(len(snap.Accounts), h.Store.RuntimeAccountMaxInflight()) + needsSync := config.IsVercel() && snap.VercelSyncHash != "" && snap.VercelSyncHash != h.computeSyncHash() + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "admin": map[string]any{ + "has_password_hash": strings.TrimSpace(snap.Admin.PasswordHash) != "", + 
"jwt_expire_hours": h.Store.AdminJWTExpireHours(), + "jwt_valid_after_unix": snap.Admin.JWTValidAfterUnix, + "default_password_warning": authn.UsingDefaultAdminKey(h.Store), + }, + "runtime": map[string]any{ + "account_max_inflight": h.Store.RuntimeAccountMaxInflight(), + "account_max_queue": h.Store.RuntimeAccountMaxQueue(recommended), + "global_max_inflight": h.Store.RuntimeGlobalMaxInflight(recommended), + }, + "toolcall": snap.Toolcall, + "responses": snap.Responses, + "embeddings": snap.Embeddings, + "claude_mapping": settingsClaudeMapping(snap), + "model_aliases": snap.ModelAliases, + "env_backed": h.Store.IsEnvBacked(), + "needs_vercel_sync": needsSync, + }) +} diff --git a/internal/admin/handler_settings_runtime.go b/internal/admin/handler_settings_runtime.go new file mode 100644 index 0000000..6ff6902 --- /dev/null +++ b/internal/admin/handler_settings_runtime.go @@ -0,0 +1,51 @@ +package admin + +import "ds2api/internal/config" + +func validateMergedRuntimeSettings(current config.RuntimeConfig, incoming *config.RuntimeConfig) error { + merged := current + if incoming != nil { + if incoming.AccountMaxInflight > 0 { + merged.AccountMaxInflight = incoming.AccountMaxInflight + } + if incoming.AccountMaxQueue > 0 { + merged.AccountMaxQueue = incoming.AccountMaxQueue + } + if incoming.GlobalMaxInflight > 0 { + merged.GlobalMaxInflight = incoming.GlobalMaxInflight + } + } + return validateRuntimeSettings(merged) +} + +func (h *Handler) applyRuntimeSettings() { + if h == nil || h.Store == nil || h.Pool == nil { + return + } + accountCount := len(h.Store.Accounts()) + maxPer := h.Store.RuntimeAccountMaxInflight() + recommended := defaultRuntimeRecommended(accountCount, maxPer) + maxQueue := h.Store.RuntimeAccountMaxQueue(recommended) + global := h.Store.RuntimeGlobalMaxInflight(recommended) + h.Pool.ApplyRuntimeLimits(maxPer, maxQueue, global) +} + +func defaultRuntimeRecommended(accountCount, maxPer int) int { + if maxPer <= 0 { + maxPer = 1 + } + if accountCount 
<= 0 { + return maxPer + } + return accountCount * maxPer +} + +func settingsClaudeMapping(c config.Config) map[string]string { + if len(c.ClaudeMapping) > 0 { + return c.ClaudeMapping + } + if len(c.ClaudeModelMap) > 0 { + return c.ClaudeModelMap + } + return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"} +} diff --git a/internal/admin/handler_settings_test.go b/internal/admin/handler_settings_test.go new file mode 100644 index 0000000..3eb5114 --- /dev/null +++ b/internal/admin/handler_settings_test.go @@ -0,0 +1,267 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + authn "ds2api/internal/auth" +) + +func TestGetSettingsDefaultPasswordWarning(t *testing.T) { + t.Setenv("DS2API_ADMIN_KEY", "") + h := newAdminTestHandler(t, `{"keys":["k1"]}`) + req := httptest.NewRequest(http.MethodGet, "/admin/settings", nil) + rec := httptest.NewRecorder() + h.getSettings(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String()) + } + var body map[string]any + _ = json.Unmarshal(rec.Body.Bytes(), &body) + admin, _ := body["admin"].(map[string]any) + warn, _ := admin["default_password_warning"].(bool) + if !warn { + t.Fatalf("expected default password warning true, body=%v", body) + } +} + +func TestUpdateSettingsValidation(t *testing.T) { + h := newAdminTestHandler(t, `{"keys":["k1"]}`) + payload := map[string]any{ + "runtime": map[string]any{ + "account_max_inflight": 0, + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.updateSettings(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String()) + } +} + +func TestUpdateSettingsValidationWithMergedRuntimeSnapshot(t *testing.T) { + h := newAdminTestHandler(t, `{ + "keys":["k1"], + "runtime":{ + "account_max_inflight":8, + 
"global_max_inflight":8 + } + }`) + payload := map[string]any{ + "runtime": map[string]any{ + "account_max_inflight": 16, + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.updateSettings(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String()) + } + if !bytes.Contains(rec.Body.Bytes(), []byte("runtime.global_max_inflight")) { + t.Fatalf("expected merged runtime validation detail, got %s", rec.Body.String()) + } +} + +func TestUpdateSettingsWithoutRuntimeSkipsMergedRuntimeValidation(t *testing.T) { + h := newAdminTestHandler(t, `{ + "keys":["k1"], + "runtime":{ + "account_max_inflight":8, + "global_max_inflight":4 + } + }`) + payload := map[string]any{ + "responses": map[string]any{ + "store_ttl_seconds": 600, + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.updateSettings(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if got := h.Store.Snapshot().Responses.StoreTTLSeconds; got != 600 { + t.Fatalf("store_ttl_seconds=%d want=600", got) + } +} + +func TestUpdateSettingsHotReloadRuntime(t *testing.T) { + h := newAdminTestHandler(t, `{ + "keys":["k1"], + "accounts":[{"email":"a@test.com","token":"t1"},{"email":"b@test.com","token":"t2"}] + }`) + + payload := map[string]any{ + "runtime": map[string]any{ + "account_max_inflight": 3, + "account_max_queue": 20, + "global_max_inflight": 5, + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.updateSettings(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String()) + } + status := h.Pool.Status() + if got := 
intFrom(status["max_inflight_per_account"]); got != 3 { + t.Fatalf("max_inflight_per_account=%d want=3", got) + } + if got := intFrom(status["max_queue_size"]); got != 20 { + t.Fatalf("max_queue_size=%d want=20", got) + } + if got := intFrom(status["global_max_inflight"]); got != 5 { + t.Fatalf("global_max_inflight=%d want=5", got) + } +} + +func TestUpdateSettingsPasswordInvalidatesOldJWT(t *testing.T) { + hash := authn.HashAdminPassword("old-password") + h := newAdminTestHandler(t, `{"admin":{"password_hash":"`+hash+`"}}`) + + token, err := authn.CreateJWTWithStore(1, h.Store) + if err != nil { + t.Fatalf("create jwt failed: %v", err) + } + if _, err := authn.VerifyJWTWithStore(token, h.Store); err != nil { + t.Fatalf("verify before update failed: %v", err) + } + + body := map[string]any{"new_password": "new-password"} + b, _ := json.Marshal(body) + req := httptest.NewRequest(http.MethodPost, "/admin/settings/password", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.updateSettingsPassword(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String()) + } + + if _, err := authn.VerifyJWTWithStore(token, h.Store); err == nil { + t.Fatal("expected old token to be invalid after password update") + } + if !authn.VerifyAdminCredential("new-password", h.Store) { + t.Fatal("expected new password credential to be accepted") + } +} + +func TestConfigImportMergeAndReplace(t *testing.T) { + h := newAdminTestHandler(t, `{ + "keys":["k1"], + "accounts":[{"email":"a@test.com","password":"p1"}] + }`) + + merge := map[string]any{ + "mode": "merge", + "config": map[string]any{ + "keys": []any{"k1", "k2"}, + "accounts": []any{ + map[string]any{"email": "a@test.com", "password": "p1"}, + map[string]any{"email": "b@test.com", "password": "p2"}, + }, + }, + } + mergeBytes, _ := json.Marshal(merge) + mergeReq := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=merge", bytes.NewReader(mergeBytes)) + mergeRec := 
httptest.NewRecorder() + h.configImport(mergeRec, mergeReq) + if mergeRec.Code != http.StatusOK { + t.Fatalf("merge status=%d body=%s", mergeRec.Code, mergeRec.Body.String()) + } + if got := len(h.Store.Keys()); got != 2 { + t.Fatalf("keys after merge=%d want=2", got) + } + if got := len(h.Store.Accounts()); got != 2 { + t.Fatalf("accounts after merge=%d want=2", got) + } + + replace := map[string]any{ + "mode": "replace", + "config": map[string]any{ + "keys": []any{"k9"}, + }, + } + replaceBytes, _ := json.Marshal(replace) + replaceReq := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=replace", bytes.NewReader(replaceBytes)) + replaceRec := httptest.NewRecorder() + h.configImport(replaceRec, replaceReq) + if replaceRec.Code != http.StatusOK { + t.Fatalf("replace status=%d body=%s", replaceRec.Code, replaceRec.Body.String()) + } + keys := h.Store.Keys() + if len(keys) != 1 || keys[0] != "k9" { + t.Fatalf("unexpected keys after replace: %#v", keys) + } + if got := len(h.Store.Accounts()); got != 0 { + t.Fatalf("accounts after replace=%d want=0", got) + } +} + +func TestConfigImportRejectsInvalidRuntimeBounds(t *testing.T) { + h := newAdminTestHandler(t, `{"keys":["k1"]}`) + payload := map[string]any{ + "mode": "replace", + "config": map[string]any{ + "keys": []any{"k2"}, + "runtime": map[string]any{ + "account_max_inflight": 300, + }, + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=replace", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.configImport(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String()) + } + if !bytes.Contains(rec.Body.Bytes(), []byte("runtime.account_max_inflight")) { + t.Fatalf("expected runtime bound detail, got %s", rec.Body.String()) + } + keys := h.Store.Keys() + if len(keys) != 1 || keys[0] != "k1" { + t.Fatalf("store should remain unchanged, keys=%v", keys) + } +} + +func 
TestConfigImportRejectsMergedRuntimeConflict(t *testing.T) { + h := newAdminTestHandler(t, `{ + "keys":["k1"], + "runtime":{ + "account_max_inflight":8, + "global_max_inflight":8 + } + }`) + payload := map[string]any{ + "mode": "merge", + "config": map[string]any{ + "runtime": map[string]any{ + "account_max_inflight": 16, + }, + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=merge", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.configImport(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String()) + } + if !bytes.Contains(rec.Body.Bytes(), []byte("runtime.global_max_inflight")) { + t.Fatalf("expected merged runtime validation detail, got %s", rec.Body.String()) + } + snap := h.Store.Snapshot() + if snap.Runtime.AccountMaxInflight != 8 || snap.Runtime.GlobalMaxInflight != 8 { + t.Fatalf("runtime should remain unchanged, runtime=%+v", snap.Runtime) + } +} diff --git a/internal/admin/handler_settings_write.go b/internal/admin/handler_settings_write.go new file mode 100644 index 0000000..c0076ea --- /dev/null +++ b/internal/admin/handler_settings_write.go @@ -0,0 +1,119 @@ +package admin + +import ( + "encoding/json" + "net/http" + "strings" + "time" + + authn "ds2api/internal/auth" + "ds2api/internal/config" +) + +func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) + return + } + + adminCfg, runtimeCfg, toolcallCfg, responsesCfg, embeddingsCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + if runtimeCfg != nil { + if err := validateMergedRuntimeSettings(h.Store.Snapshot().Runtime, runtimeCfg); err != nil { + writeJSON(w, 
http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + } + + if err := h.Store.Update(func(c *config.Config) error { + if adminCfg != nil { + if adminCfg.JWTExpireHours > 0 { + c.Admin.JWTExpireHours = adminCfg.JWTExpireHours + } + } + if runtimeCfg != nil { + if runtimeCfg.AccountMaxInflight > 0 { + c.Runtime.AccountMaxInflight = runtimeCfg.AccountMaxInflight + } + if runtimeCfg.AccountMaxQueue > 0 { + c.Runtime.AccountMaxQueue = runtimeCfg.AccountMaxQueue + } + if runtimeCfg.GlobalMaxInflight > 0 { + c.Runtime.GlobalMaxInflight = runtimeCfg.GlobalMaxInflight + } + } + if toolcallCfg != nil { + if strings.TrimSpace(toolcallCfg.Mode) != "" { + c.Toolcall.Mode = strings.TrimSpace(toolcallCfg.Mode) + } + if strings.TrimSpace(toolcallCfg.EarlyEmitConfidence) != "" { + c.Toolcall.EarlyEmitConfidence = strings.TrimSpace(toolcallCfg.EarlyEmitConfidence) + } + } + if responsesCfg != nil && responsesCfg.StoreTTLSeconds > 0 { + c.Responses.StoreTTLSeconds = responsesCfg.StoreTTLSeconds + } + if embeddingsCfg != nil && strings.TrimSpace(embeddingsCfg.Provider) != "" { + c.Embeddings.Provider = strings.TrimSpace(embeddingsCfg.Provider) + } + if claudeMap != nil { + c.ClaudeMapping = claudeMap + c.ClaudeModelMap = nil + } + if aliasMap != nil { + c.ModelAliases = aliasMap + } + return nil + }); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + + h.applyRuntimeSettings() + needsSync := config.IsVercel() || h.Store.IsEnvBacked() + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "message": "settings updated and hot reloaded", + "env_backed": h.Store.IsEnvBacked(), + "needs_vercel_sync": needsSync, + "manual_sync_message": "配置已保存。Vercel 部署请在 Vercel Sync 页面手动同步。", + }) +} + +func (h *Handler) updateSettingsPassword(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, 
map[string]any{"detail": "invalid json"}) + return + } + newPassword := strings.TrimSpace(fieldString(req, "new_password")) + if newPassword == "" { + newPassword = strings.TrimSpace(fieldString(req, "password")) + } + if len(newPassword) < 4 { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "new password must be at least 4 characters"}) + return + } + + now := time.Now().Unix() + hash := authn.HashAdminPassword(newPassword) + if err := h.Store.Update(func(c *config.Config) error { + c.Admin.PasswordHash = hash + c.Admin.JWTValidAfterUnix = now + return nil + }); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "message": "password updated", + "force_relogin": true, + "jwt_valid_after_unix": now, + }) +} diff --git a/internal/admin/handler_vercel.go b/internal/admin/handler_vercel.go index 189d8cc..2c6356c 100644 --- a/internal/admin/handler_vercel.go +++ b/internal/admin/handler_vercel.go @@ -3,8 +3,8 @@ package admin import ( "bytes" "context" - "encoding/base64" "encoding/json" + "fmt" "io" "net/http" "net/url" @@ -19,6 +19,62 @@ func (h *Handler) syncVercel(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) return } + opts, err := parseVercelSyncOptions(req) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + validated, failed := h.validateAccountsForVercelSync(r.Context(), opts.AutoValidate) + _, cfgB64, err := h.Store.ExportJSONAndBase64() + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + client := &http.Client{Timeout: 30 * time.Second} + params := buildVercelParams(opts.TeamID) + headers := map[string]string{"Authorization": "Bearer " + opts.VercelToken} + + envResp, status, err := vercelRequest(r.Context(), client, 
http.MethodGet, "https://api.vercel.com/v9/projects/"+opts.ProjectID+"/env", params, headers, nil) + if err != nil || status != http.StatusOK { + writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "获取环境变量失败"}) + return + } + envs, _ := envResp["envs"].([]any) + status, err = upsertVercelEnv(r.Context(), client, opts.ProjectID, params, headers, envs, "DS2API_CONFIG_JSON", cfgB64) + if err != nil || (status != http.StatusOK && status != http.StatusCreated) { + writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "更新环境变量失败"}) + return + } + savedCreds := h.saveVercelProjectCredentials(r.Context(), client, opts, params, headers, envs) + manual, deployURL := triggerVercelDeployment(r.Context(), client, opts.ProjectID, params, headers) + _ = h.Store.SetVercelSync(h.computeSyncHash(), time.Now().Unix()) + result := map[string]any{"success": true, "validated_accounts": validated} + if manual { + result["message"] = "配置已同步到 Vercel,请手动触发重新部署" + result["manual_deploy_required"] = true + } else { + result["message"] = "配置已同步,正在重新部署..." 
+ result["deployment_url"] = deployURL + } + if len(failed) > 0 { + result["failed_accounts"] = failed + } + if len(savedCreds) > 0 { + result["saved_credentials"] = savedCreds + } + writeJSON(w, http.StatusOK, result) +} + +type vercelSyncOptions struct { + VercelToken string + ProjectID string + TeamID string + AutoValidate bool + SaveCreds bool + UsePreconfig bool +} + +func parseVercelSyncOptions(req map[string]any) (vercelSyncOptions, error) { vercelToken, _ := req["vercel_token"].(string) projectID, _ := req["project_id"].(string) teamID, _ := req["team_id"].(string) @@ -40,108 +96,117 @@ func (h *Handler) syncVercel(w http.ResponseWriter, r *http.Request) { if strings.TrimSpace(teamID) == "" { teamID = strings.TrimSpace(os.Getenv("VERCEL_TEAM_ID")) } + vercelToken = strings.TrimSpace(vercelToken) + projectID = strings.TrimSpace(projectID) + teamID = strings.TrimSpace(teamID) if vercelToken == "" || projectID == "" { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 Vercel Token 和 Project ID"}) - return + return vercelSyncOptions{}, fmt.Errorf("需要 Vercel Token 和 Project ID") + } + return vercelSyncOptions{ + VercelToken: vercelToken, + ProjectID: projectID, + TeamID: teamID, + AutoValidate: autoValidate, + SaveCreds: saveCreds, + UsePreconfig: usePreconfig, + }, nil +} + +func buildVercelParams(teamID string) url.Values { + params := url.Values{} + if strings.TrimSpace(teamID) != "" { + params.Set("teamId", strings.TrimSpace(teamID)) + } + return params +} + +func (h *Handler) validateAccountsForVercelSync(ctx context.Context, enabled bool) (int, []string) { + if !enabled { + return 0, nil } validated, failed := 0, []string{} - if autoValidate { - for _, acc := range h.Store.Snapshot().Accounts { - if strings.TrimSpace(acc.Token) != "" { - continue - } - token, err := h.DS.Login(r.Context(), acc) - if err != nil { - failed = append(failed, acc.Identifier()) - } else { - validated++ - _ = h.Store.UpdateAccountToken(acc.Identifier(), token) - } 
- time.Sleep(500 * time.Millisecond) + for _, acc := range h.Store.Snapshot().Accounts { + if strings.TrimSpace(acc.Token) != "" { + continue } + token, err := h.DS.Login(ctx, acc) + if err != nil { + failed = append(failed, acc.Identifier()) + } else { + validated++ + _ = h.Store.UpdateAccountToken(acc.Identifier(), token) + } + time.Sleep(500 * time.Millisecond) } + return validated, failed +} - cfgJSON, _, err := h.Store.ExportJSONAndBase64() - if err != nil { - writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) - return +func upsertVercelEnv(ctx context.Context, client *http.Client, projectID string, params url.Values, headers map[string]string, envs []any, key, value string) (int, error) { + existingID := findEnvID(envs, key) + if existingID != "" { + _, status, err := vercelRequest(ctx, client, http.MethodPatch, "https://api.vercel.com/v9/projects/"+projectID+"/env/"+existingID, params, headers, map[string]any{"value": value}) + return status, err } - cfgB64 := base64.StdEncoding.EncodeToString([]byte(cfgJSON)) - client := &http.Client{Timeout: 30 * time.Second} - params := url.Values{} - if teamID != "" { - params.Set("teamId", teamID) + _, status, err := vercelRequest(ctx, client, http.MethodPost, "https://api.vercel.com/v10/projects/"+projectID+"/env", params, headers, map[string]any{ + "key": key, + "value": value, + "type": "encrypted", + "target": []string{"production", "preview"}, + }) + return status, err +} + +func (h *Handler) saveVercelProjectCredentials(ctx context.Context, client *http.Client, opts vercelSyncOptions, params url.Values, headers map[string]string, envs []any) []string { + if !opts.SaveCreds || opts.UsePreconfig { + return nil } - headers := map[string]string{"Authorization": "Bearer " + vercelToken} - envResp, status, err := vercelRequest(r.Context(), client, http.MethodGet, "https://api.vercel.com/v9/projects/"+projectID+"/env", params, headers, nil) - if err != nil || status != http.StatusOK { - 
writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "获取环境变量失败"}) - return + saved := []string{} + creds := [][2]string{{"VERCEL_TOKEN", opts.VercelToken}, {"VERCEL_PROJECT_ID", opts.ProjectID}} + if opts.TeamID != "" { + creds = append(creds, [2]string{"VERCEL_TEAM_ID", opts.TeamID}) } - envs, _ := envResp["envs"].([]any) - existingEnvID := findEnvID(envs, "DS2API_CONFIG_JSON") - if existingEnvID != "" { - _, status, err = vercelRequest(r.Context(), client, http.MethodPatch, "https://api.vercel.com/v9/projects/"+projectID+"/env/"+existingEnvID, params, headers, map[string]any{"value": cfgB64}) - } else { - _, status, err = vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v10/projects/"+projectID+"/env", params, headers, map[string]any{"key": "DS2API_CONFIG_JSON", "value": cfgB64, "type": "encrypted", "target": []string{"production", "preview"}}) - } - if err != nil || (status != http.StatusOK && status != http.StatusCreated) { - writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "更新环境变量失败"}) - return - } - savedCreds := []string{} - if saveCreds && !usePreconfig { - creds := [][2]string{{"VERCEL_TOKEN", vercelToken}, {"VERCEL_PROJECT_ID", projectID}} - if teamID != "" { - creds = append(creds, [2]string{"VERCEL_TEAM_ID", teamID}) - } - for _, kv := range creds { - id := findEnvID(envs, kv[0]) - if id != "" { - _, status, _ = vercelRequest(r.Context(), client, http.MethodPatch, "https://api.vercel.com/v9/projects/"+projectID+"/env/"+id, params, headers, map[string]any{"value": kv[1]}) - } else { - _, status, _ = vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v10/projects/"+projectID+"/env", params, headers, map[string]any{"key": kv[0], "value": kv[1], "type": "encrypted", "target": []string{"production", "preview"}}) - } - if status == http.StatusOK || status == http.StatusCreated { - savedCreds = append(savedCreds, kv[0]) - } + for _, kv 
:= range creds { + status, _ := upsertVercelEnv(ctx, client, opts.ProjectID, params, headers, envs, kv[0], kv[1]) + if status == http.StatusOK || status == http.StatusCreated { + saved = append(saved, kv[0]) } } - projectResp, status, _ := vercelRequest(r.Context(), client, http.MethodGet, "https://api.vercel.com/v9/projects/"+projectID, params, headers, nil) - manual := true - deployURL := "" - if status == http.StatusOK { - if link, ok := projectResp["link"].(map[string]any); ok { - if linkType, _ := link["type"].(string); linkType == "github" { - repoID := intFrom(link["repoId"]) - ref, _ := link["productionBranch"].(string) - if ref == "" { - ref = "main" - } - depResp, depStatus, _ := vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v13/deployments", params, headers, map[string]any{"name": projectID, "project": projectID, "target": "production", "gitSource": map[string]any{"type": "github", "repoId": repoID, "ref": ref}}) - if depStatus == http.StatusOK || depStatus == http.StatusCreated { - deployURL, _ = depResp["url"].(string) - manual = false - } - } - } + return saved +} + +func triggerVercelDeployment(ctx context.Context, client *http.Client, projectID string, params url.Values, headers map[string]string) (bool, string) { + projectResp, status, _ := vercelRequest(ctx, client, http.MethodGet, "https://api.vercel.com/v9/projects/"+projectID, params, headers, nil) + if status != http.StatusOK { + return true, "" } - _ = h.Store.SetVercelSync(h.computeSyncHash(), time.Now().Unix()) - result := map[string]any{"success": true, "validated_accounts": validated} - if manual { - result["message"] = "配置已同步到 Vercel,请手动触发重新部署" - result["manual_deploy_required"] = true - } else { - result["message"] = "配置已同步,正在重新部署..." 
- result["deployment_url"] = deployURL + link, ok := projectResp["link"].(map[string]any) + if !ok { + return true, "" } - if len(failed) > 0 { - result["failed_accounts"] = failed + linkType, _ := link["type"].(string) + if linkType != "github" { + return true, "" } - if len(savedCreds) > 0 { - result["saved_credentials"] = savedCreds + repoID := intFrom(link["repoId"]) + ref, _ := link["productionBranch"].(string) + if ref == "" { + ref = "main" } - writeJSON(w, http.StatusOK, result) + depResp, depStatus, _ := vercelRequest(ctx, client, http.MethodPost, "https://api.vercel.com/v13/deployments", params, headers, map[string]any{ + "name": projectID, + "project": projectID, + "target": "production", + "gitSource": map[string]any{ + "type": "github", + "repoId": repoID, + "ref": ref, + }, + }) + if depStatus != http.StatusOK && depStatus != http.StatusCreated { + return true, "" + } + deployURL, _ := depResp["url"].(string) + return false, deployURL } func (h *Handler) vercelStatus(w http.ResponseWriter, _ *http.Request) { diff --git a/internal/admin/helpers.go b/internal/admin/helpers.go index fa75b59..2e00323 100644 --- a/internal/admin/helpers.go +++ b/internal/admin/helpers.go @@ -81,3 +81,34 @@ func statusOr(v int, d int) int { } return v } + +func accountMatchesIdentifier(acc config.Account, identifier string) bool { + id := strings.TrimSpace(identifier) + if id == "" { + return false + } + if strings.TrimSpace(acc.Email) == id { + return true + } + if strings.TrimSpace(acc.Mobile) == id { + return true + } + return acc.Identifier() == id +} + +func findAccountByIdentifier(store ConfigStore, identifier string) (config.Account, bool) { + id := strings.TrimSpace(identifier) + if id == "" { + return config.Account{}, false + } + if acc, ok := store.FindAccount(id); ok { + return acc, true + } + accounts := store.Snapshot().Accounts + for _, acc := range accounts { + if accountMatchesIdentifier(acc, id) { + return acc, true + } + } + return config.Account{}, false 
+} diff --git a/internal/admin/helpers_edge_test.go b/internal/admin/helpers_edge_test.go new file mode 100644 index 0000000..2a0bf20 --- /dev/null +++ b/internal/admin/helpers_edge_test.go @@ -0,0 +1,240 @@ +package admin + +import ( + "net/http" + "net/http/httptest" + "testing" + + "ds2api/internal/config" +) + +// ─── reverseAccounts ───────────────────────────────────────────────── + +func TestReverseAccountsEmpty(t *testing.T) { + a := []config.Account{} + reverseAccounts(a) + if len(a) != 0 { + t.Fatal("expected empty") + } +} + +func TestReverseAccountsTwoElements(t *testing.T) { + a := []config.Account{ + {Email: "a@test.com"}, + {Email: "b@test.com"}, + } + reverseAccounts(a) + if a[0].Email != "b@test.com" || a[1].Email != "a@test.com" { + t.Fatalf("unexpected order after reverse: %v", a) + } +} + +func TestReverseAccountsThreeElements(t *testing.T) { + a := []config.Account{ + {Email: "1@test.com"}, + {Email: "2@test.com"}, + {Email: "3@test.com"}, + } + reverseAccounts(a) + if a[0].Email != "3@test.com" || a[1].Email != "2@test.com" || a[2].Email != "1@test.com" { + t.Fatalf("unexpected order: %v", a) + } +} + +// ─── intFromQuery edge cases ───────────────────────────────────────── + +func TestIntFromQueryPresent(t *testing.T) { + req := httptest.NewRequest("GET", "/?limit=5", nil) + if got := intFromQuery(req, "limit", 10); got != 5 { + t.Fatalf("expected 5, got %d", got) + } +} + +func TestIntFromQueryMissing(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + if got := intFromQuery(req, "limit", 10); got != 10 { + t.Fatalf("expected default 10, got %d", got) + } +} + +func TestIntFromQueryInvalid(t *testing.T) { + req := httptest.NewRequest("GET", "/?limit=abc", nil) + if got := intFromQuery(req, "limit", 10); got != 10 { + t.Fatalf("expected default 10 for invalid, got %d", got) + } +} + +func TestIntFromQueryNegative(t *testing.T) { + req := httptest.NewRequest("GET", "/?limit=-3", nil) + if got := intFromQuery(req, "limit", 10); got 
!= -3 { + t.Fatalf("expected -3, got %d", got) + } +} + +func TestIntFromQueryZero(t *testing.T) { + req := httptest.NewRequest("GET", "/?limit=0", nil) + if got := intFromQuery(req, "limit", 10); got != 0 { + t.Fatalf("expected 0, got %d", got) + } +} + +// ─── nilIfEmpty ────────────────────────────────────────────────────── + +func TestNilIfEmptyEmpty(t *testing.T) { + if nilIfEmpty("") != nil { + t.Fatal("expected nil for empty string") + } +} + +func TestNilIfEmptyNonEmpty(t *testing.T) { + if nilIfEmpty("hello") != "hello" { + t.Fatal("expected 'hello'") + } +} + +// ─── nilIfZero ─────────────────────────────────────────────────────── + +func TestNilIfZeroZero(t *testing.T) { + if nilIfZero(0) != nil { + t.Fatal("expected nil for zero") + } +} + +func TestNilIfZeroNonZero(t *testing.T) { + if nilIfZero(42) != int64(42) { + t.Fatal("expected 42") + } +} + +func TestNilIfZeroNegative(t *testing.T) { + if nilIfZero(-1) != int64(-1) { + t.Fatal("expected -1") + } +} + +// ─── toStringSlice ─────────────────────────────────────────────────── + +func TestToStringSliceFromAnySlice(t *testing.T) { + input := []any{"a", "b", "c"} + got, ok := toStringSlice(input) + if !ok || len(got) != 3 { + t.Fatalf("expected 3 strings, got %#v ok=%v", got, ok) + } + if got[0] != "a" || got[1] != "b" || got[2] != "c" { + t.Fatalf("unexpected values: %#v", got) + } +} + +func TestToStringSliceFromMixed(t *testing.T) { + input := []any{"hello", 42, true} + got, ok := toStringSlice(input) + if !ok { + t.Fatal("expected ok for mixed types") + } + if got[0] != "hello" || got[1] != "42" || got[2] != "true" { + t.Fatalf("unexpected values: %#v", got) + } +} + +func TestToStringSliceFromNonSlice(t *testing.T) { + _, ok := toStringSlice("not a slice") + if ok { + t.Fatal("expected not ok for string input") + } +} + +func TestToStringSliceFromNil(t *testing.T) { + _, ok := toStringSlice(nil) + if ok { + t.Fatal("expected not ok for nil input") + } +} + +func TestToStringSliceEmpty(t 
*testing.T) { + got, ok := toStringSlice([]any{}) + if !ok { + t.Fatal("expected ok for empty slice") + } + if len(got) != 0 { + t.Fatalf("expected empty result, got %#v", got) + } +} + +func TestToStringSliceTrimsWhitespace(t *testing.T) { + got, ok := toStringSlice([]any{" hello ", " world "}) + if !ok { + t.Fatal("expected ok") + } + if got[0] != "hello" || got[1] != "world" { + t.Fatalf("expected trimmed values, got %#v", got) + } +} + +// ─── toAccount edge cases ──────────────────────────────────────────── + +func TestToAccountAllFields(t *testing.T) { + acc := toAccount(map[string]any{ + "email": "user@test.com", + "mobile": "13800138000", + "password": "secret", + "token": "tok123", + }) + if acc.Email != "user@test.com" { + t.Fatalf("unexpected email: %q", acc.Email) + } + if acc.Mobile != "13800138000" { + t.Fatalf("unexpected mobile: %q", acc.Mobile) + } + if acc.Password != "secret" { + t.Fatalf("unexpected password: %q", acc.Password) + } + if acc.Token != "tok123" { + t.Fatalf("unexpected token: %q", acc.Token) + } +} + +func TestToAccountNumericValues(t *testing.T) { + acc := toAccount(map[string]any{ + "email": 12345, + }) + if acc.Email != "12345" { + t.Fatalf("expected numeric converted to string, got %q", acc.Email) + } +} + +// ─── fieldString edge cases ────────────────────────────────────────── + +func TestFieldStringNonString(t *testing.T) { + got := fieldString(map[string]any{"key": 42}, "key") + if got != "42" { + t.Fatalf("expected '42' for int, got %q", got) + } +} + +func TestFieldStringBool(t *testing.T) { + got := fieldString(map[string]any{"key": true}, "key") + if got != "true" { + t.Fatalf("expected 'true', got %q", got) + } +} + +func TestFieldStringWhitespace(t *testing.T) { + got := fieldString(map[string]any{"key": " hello "}, "key") + if got != "hello" { + t.Fatalf("expected trimmed 'hello', got %q", got) + } +} + +// ─── statusOr ──────────────────────────────────────────────────────── + +func TestStatusOrZeroReturnsDefault(t 
*testing.T) { + if got := statusOr(0, http.StatusOK); got != http.StatusOK { + t.Fatalf("expected %d, got %d", http.StatusOK, got) + } +} + +func TestStatusOrNonZeroReturnsValue(t *testing.T) { + if got := statusOr(http.StatusBadRequest, http.StatusOK); got != http.StatusBadRequest { + t.Fatalf("expected %d, got %d", http.StatusBadRequest, got) + } +} diff --git a/internal/admin/request_error.go b/internal/admin/request_error.go new file mode 100644 index 0000000..5431a3d --- /dev/null +++ b/internal/admin/request_error.go @@ -0,0 +1,23 @@ +package admin + +import "errors" + +type requestError struct { + detail string +} + +func (e *requestError) Error() string { + return e.detail +} + +func newRequestError(detail string) error { + return &requestError{detail: detail} +} + +func requestErrorDetail(err error) (string, bool) { + var reqErr *requestError + if errors.As(err, &reqErr) { + return reqErr.detail, true + } + return "", false +} diff --git a/internal/admin/settings_validation.go b/internal/admin/settings_validation.go new file mode 100644 index 0000000..f9d4c2f --- /dev/null +++ b/internal/admin/settings_validation.go @@ -0,0 +1,64 @@ +package admin + +import ( + "fmt" + "strings" + + "ds2api/internal/config" +) + +func normalizeSettingsConfig(c *config.Config) { + if c == nil { + return + } + c.Admin.PasswordHash = strings.TrimSpace(c.Admin.PasswordHash) + c.Toolcall.Mode = strings.ToLower(strings.TrimSpace(c.Toolcall.Mode)) + c.Toolcall.EarlyEmitConfidence = strings.ToLower(strings.TrimSpace(c.Toolcall.EarlyEmitConfidence)) + c.Embeddings.Provider = strings.TrimSpace(c.Embeddings.Provider) +} + +func validateSettingsConfig(c config.Config) error { + if c.Admin.JWTExpireHours != 0 && (c.Admin.JWTExpireHours < 1 || c.Admin.JWTExpireHours > 720) { + return fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720") + } + if err := validateRuntimeSettings(c.Runtime); err != nil { + return err + } + if c.Responses.StoreTTLSeconds != 0 && 
(c.Responses.StoreTTLSeconds < 30 || c.Responses.StoreTTLSeconds > 86400) { + return fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400") + } + if mode := strings.TrimSpace(c.Toolcall.Mode); mode != "" { + switch mode { + case "feature_match", "off": + default: + return fmt.Errorf("toolcall.mode must be feature_match or off") + } + } + if level := strings.TrimSpace(c.Toolcall.EarlyEmitConfidence); level != "" { + switch level { + case "high", "low", "off": + default: + return fmt.Errorf("toolcall.early_emit_confidence must be high, low or off") + } + } + if c.Embeddings.Provider != "" && strings.TrimSpace(c.Embeddings.Provider) == "" { + return fmt.Errorf("embeddings.provider cannot be empty") + } + return nil +} + +func validateRuntimeSettings(runtime config.RuntimeConfig) error { + if runtime.AccountMaxInflight != 0 && (runtime.AccountMaxInflight < 1 || runtime.AccountMaxInflight > 256) { + return fmt.Errorf("runtime.account_max_inflight must be between 1 and 256") + } + if runtime.AccountMaxQueue != 0 && (runtime.AccountMaxQueue < 1 || runtime.AccountMaxQueue > 200000) { + return fmt.Errorf("runtime.account_max_queue must be between 1 and 200000") + } + if runtime.GlobalMaxInflight != 0 && (runtime.GlobalMaxInflight < 1 || runtime.GlobalMaxInflight > 200000) { + return fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000") + } + if runtime.AccountMaxInflight > 0 && runtime.GlobalMaxInflight > 0 && runtime.GlobalMaxInflight < runtime.AccountMaxInflight { + return fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight") + } + return nil +} diff --git a/internal/auth/admin.go b/internal/auth/admin.go index 3a52f6c..8f1d276 100644 --- a/internal/auth/admin.go +++ b/internal/auth/admin.go @@ -3,7 +3,9 @@ package auth import ( "crypto/hmac" "crypto/sha256" + "crypto/subtle" "encoding/base64" + "encoding/hex" "encoding/json" "errors" "log/slog" @@ -17,7 +19,22 @@ import ( var warnOnce sync.Once +type 
AdminConfigReader interface { + AdminPasswordHash() string + AdminJWTExpireHours() int + AdminJWTValidAfterUnix() int64 +} + func AdminKey() string { + return effectiveAdminKey(nil) +} + +func effectiveAdminKey(store AdminConfigReader) string { + if store != nil { + if hash := strings.TrimSpace(store.AdminPasswordHash()); hash != "" { + return "" + } + } if v := strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")); v != "" { return v } @@ -27,14 +44,24 @@ func AdminKey() string { return "admin" } -func jwtSecret() string { +func jwtSecret(store AdminConfigReader) string { if v := strings.TrimSpace(os.Getenv("DS2API_JWT_SECRET")); v != "" { return v } - return AdminKey() + if store != nil { + if hash := strings.TrimSpace(store.AdminPasswordHash()); hash != "" { + return hash + } + } + return effectiveAdminKey(store) } -func jwtExpireHours() int { +func jwtExpireHours(store AdminConfigReader) int { + if store != nil { + if n := store.AdminJWTExpireHours(); n > 0 { + return n + } + } if v := strings.TrimSpace(os.Getenv("DS2API_JWT_EXPIRE_HOURS")); v != "" { if n, err := strconv.Atoi(v); err == nil && n > 0 { return n @@ -44,27 +71,44 @@ func jwtExpireHours() int { } func CreateJWT(expireHours int) (string, error) { + return CreateJWTWithStore(expireHours, nil) +} + +func CreateJWTWithStore(expireHours int, store AdminConfigReader) (string, error) { if expireHours <= 0 { - expireHours = jwtExpireHours() + expireHours = jwtExpireHours(store) } + issuedAt := time.Now().Unix() + // If sessions were invalidated in this same second, move iat forward by + // one second so newly minted tokens remain valid with strict cutoff checks. 
+ if store != nil { + if validAfter := store.AdminJWTValidAfterUnix(); validAfter >= issuedAt { + issuedAt = validAfter + 1 + } + } + expireAt := time.Unix(issuedAt, 0).Add(time.Duration(expireHours) * time.Hour).Unix() header := map[string]any{"alg": "HS256", "typ": "JWT"} - payload := map[string]any{"iat": time.Now().Unix(), "exp": time.Now().Add(time.Duration(expireHours) * time.Hour).Unix(), "role": "admin"} + payload := map[string]any{"iat": issuedAt, "exp": expireAt, "role": "admin"} h, _ := json.Marshal(header) p, _ := json.Marshal(payload) headerB64 := rawB64Encode(h) payloadB64 := rawB64Encode(p) msg := headerB64 + "." + payloadB64 - sig := signHS256(msg) + sig := signHS256(msg, store) return msg + "." + rawB64Encode(sig), nil } func VerifyJWT(token string) (map[string]any, error) { + return VerifyJWTWithStore(token, nil) +} + +func VerifyJWTWithStore(token string, store AdminConfigReader) (map[string]any, error) { parts := strings.Split(token, ".") if len(parts) != 3 { return nil, errors.New("invalid token format") } msg := parts[0] + "." 
+ parts[1] - expected := signHS256(msg) + expected := signHS256(msg, store) actual, err := rawB64Decode(parts[2]) if err != nil { return nil, errors.New("invalid signature") @@ -84,10 +128,23 @@ func VerifyJWT(token string) (map[string]any, error) { if int64(exp) < time.Now().Unix() { return nil, errors.New("token expired") } + if store != nil { + validAfter := store.AdminJWTValidAfterUnix() + if validAfter > 0 { + iat, _ := payload["iat"].(float64) + if int64(iat) <= validAfter { + return nil, errors.New("token expired") + } + } + } return payload, nil } func VerifyAdminRequest(r *http.Request) error { + return VerifyAdminRequestWithStore(r, nil) +} + +func VerifyAdminRequestWithStore(r *http.Request, store AdminConfigReader) error { authHeader := strings.TrimSpace(r.Header.Get("Authorization")) if !strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { return errors.New("authentication required") @@ -96,17 +153,65 @@ func VerifyAdminRequest(r *http.Request) error { if token == "" { return errors.New("authentication required") } - if token == AdminKey() { + if VerifyAdminCredential(token, store) { return nil } - if _, err := VerifyJWT(token); err == nil { + if _, err := VerifyJWTWithStore(token, store); err == nil { return nil } return errors.New("invalid credentials") } -func signHS256(msg string) []byte { - h := hmac.New(sha256.New, []byte(jwtSecret())) +func VerifyAdminCredential(candidate string, store AdminConfigReader) bool { + candidate = strings.TrimSpace(candidate) + if candidate == "" { + return false + } + if store != nil { + hash := strings.TrimSpace(store.AdminPasswordHash()) + if hash != "" { + return verifyAdminPasswordHash(candidate, hash) + } + } + key := effectiveAdminKey(store) + if key == "" { + return false + } + return subtle.ConstantTimeCompare([]byte(candidate), []byte(key)) == 1 +} + +func UsingDefaultAdminKey(store AdminConfigReader) bool { + if store != nil && strings.TrimSpace(store.AdminPasswordHash()) != "" { + return false + } + 
return strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")) == "" +} + +func HashAdminPassword(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + sum := sha256.Sum256([]byte(raw)) + return "sha256:" + hex.EncodeToString(sum[:]) +} + +func verifyAdminPasswordHash(candidate, encoded string) bool { + encoded = strings.TrimSpace(strings.ToLower(encoded)) + if encoded == "" { + return false + } + if strings.HasPrefix(encoded, "sha256:") { + want := strings.TrimPrefix(encoded, "sha256:") + sum := sha256.Sum256([]byte(candidate)) + got := hex.EncodeToString(sum[:]) + return subtle.ConstantTimeCompare([]byte(got), []byte(want)) == 1 + } + return subtle.ConstantTimeCompare([]byte(candidate), []byte(encoded)) == 1 +} + +func signHS256(msg string, store AdminConfigReader) []byte { + h := hmac.New(sha256.New, []byte(jwtSecret(store))) _, _ = h.Write([]byte(msg)) return h.Sum(nil) } diff --git a/internal/auth/admin_test.go b/internal/auth/admin_test.go index 7489074..bfbd4c3 100644 --- a/internal/auth/admin_test.go +++ b/internal/auth/admin_test.go @@ -3,6 +3,8 @@ package auth import ( "net/http" "testing" + + "ds2api/internal/config" ) func TestJWTCreateVerify(t *testing.T) { @@ -27,3 +29,58 @@ func TestVerifyAdminRequest(t *testing.T) { t.Fatalf("expected token accepted: %v", err) } } + +func TestVerifyJWTWithStoreValidAfter(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"admin":{"password_hash":"`+HashAdminPassword("oldpass")+`"}}`) + store := config.LoadStore() + token, err := CreateJWTWithStore(1, store) + if err != nil { + t.Fatalf("create jwt failed: %v", err) + } + if _, err := VerifyJWTWithStore(token, store); err != nil { + t.Fatalf("verify before invalidation failed: %v", err) + } + if err := store.Update(func(c *config.Config) error { + c.Admin.JWTValidAfterUnix = 1<<62 - 1 + return nil + }); err != nil { + t.Fatalf("set valid-after failed: %v", err) + } + if _, err := VerifyJWTWithStore(token, store); err == nil { + 
t.Fatal("expected token invalid after valid-after update") + } +} + +func TestVerifyJWTWithStoreSameSecondInvalidationAndRelogin(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"admin":{"password_hash":"`+HashAdminPassword("oldpass")+`"}}`) + store := config.LoadStore() + + oldToken, err := CreateJWTWithStore(1, store) + if err != nil { + t.Fatalf("create old jwt failed: %v", err) + } + oldPayload, err := VerifyJWTWithStore(oldToken, store) + if err != nil { + t.Fatalf("verify old jwt before invalidation failed: %v", err) + } + oldIAT, _ := oldPayload["iat"].(float64) + + if err := store.Update(func(c *config.Config) error { + c.Admin.JWTValidAfterUnix = int64(oldIAT) + return nil + }); err != nil { + t.Fatalf("set valid-after failed: %v", err) + } + + if _, err := VerifyJWTWithStore(oldToken, store); err == nil { + t.Fatal("expected old token invalid when iat == valid-after") + } + + newToken, err := CreateJWTWithStore(1, store) + if err != nil { + t.Fatalf("create new jwt failed: %v", err) + } + if _, err := VerifyJWTWithStore(newToken, store); err != nil { + t.Fatalf("expected new token valid after invalidation cutoff: %v", err) + } +} diff --git a/internal/auth/auth_edge_test.go b/internal/auth/auth_edge_test.go new file mode 100644 index 0000000..55c46ef --- /dev/null +++ b/internal/auth/auth_edge_test.go @@ -0,0 +1,375 @@ +package auth + +import ( + "context" + "errors" + "net/http" + "testing" + + "ds2api/internal/account" + "ds2api/internal/config" +) + +// ─── extractCallerToken edge cases ─────────────────────────────────── + +func TestExtractCallerTokenBearerPrefix(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer my-token") + if got := extractCallerToken(req); got != "my-token" { + t.Fatalf("expected my-token, got %q", got) + } +} + +func TestExtractCallerTokenBearerCaseInsensitive(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "BEARER My-Token") + 
if got := extractCallerToken(req); got != "My-Token" { + t.Fatalf("expected My-Token, got %q", got) + } +} + +func TestExtractCallerTokenBearerEmpty(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer ") + if got := extractCallerToken(req); got != "" { + t.Fatalf("expected empty for 'Bearer ', got %q", got) + } +} + +func TestExtractCallerTokenXAPIKey(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("x-api-key", "x-api-key-token") + if got := extractCallerToken(req); got != "x-api-key-token" { + t.Fatalf("expected x-api-key-token, got %q", got) + } +} + +func TestExtractCallerTokenBearerPreferredOverXAPIKey(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer bearer-token") + req.Header.Set("x-api-key", "x-api-key-token") + if got := extractCallerToken(req); got != "bearer-token" { + t.Fatalf("expected bearer-token, got %q", got) + } +} + +func TestExtractCallerTokenMissingHeaders(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + if got := extractCallerToken(req); got != "" { + t.Fatalf("expected empty for missing headers, got %q", got) + } +} + +func TestExtractCallerTokenNonBearerAuth(t *testing.T) { + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Basic abc123") + if got := extractCallerToken(req); got != "" { + t.Fatalf("expected empty for Basic auth, got %q", got) + } +} + +// ─── Context helpers ───────────────────────────────────────────────── + +func TestWithAuthAndFromContext(t *testing.T) { + a := &RequestAuth{DeepSeekToken: "test-token"} + ctx := WithAuth(context.Background(), a) + got, ok := FromContext(ctx) + if !ok || got.DeepSeekToken != "test-token" { + t.Fatalf("expected token from context, got ok=%v token=%q", ok, got.DeepSeekToken) + } +} + +func TestFromContextMissing(t *testing.T) { + _, ok := FromContext(context.Background()) + if ok { + t.Fatal("expected not ok from 
empty context") + } +} + +// ─── RefreshToken edge cases ───────────────────────────────────────── + +func TestRefreshTokenNotConfigToken(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: false, resolver: r} + if r.RefreshToken(context.Background(), a) { + t.Fatal("expected false for non-config token") + } +} + +func TestRefreshTokenEmptyAccountID(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: true, AccountID: "", resolver: r} + if r.RefreshToken(context.Background(), a) { + t.Fatal("expected false for empty account ID") + } +} + +func TestRefreshTokenSuccess(t *testing.T) { + r := newTestResolver(t) + // First acquire an account + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + defer r.Release(a) + + if !r.RefreshToken(context.Background(), a) { + t.Fatal("expected refresh to succeed") + } + if a.DeepSeekToken != "fresh-token" { + t.Fatalf("expected fresh-token after refresh, got %q", a.DeepSeekToken) + } +} + +// ─── MarkTokenInvalid edge cases ───────────────────────────────────── + +func TestMarkTokenInvalidNotConfigToken(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: false, DeepSeekToken: "direct", resolver: r} + r.MarkTokenInvalid(a) + // Should not panic, token should be unchanged for non-config + if a.DeepSeekToken != "" { + // Actually it does clear it; that's fine - let's check behavior + } +} + +func TestMarkTokenInvalidEmptyAccountID(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: true, AccountID: "", DeepSeekToken: "tok", resolver: r} + r.MarkTokenInvalid(a) + // Should not panic +} + +func TestMarkTokenInvalidClearsToken(t *testing.T) { + r := newTestResolver(t) + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + a, err := r.Determine(req) 
+ if err != nil { + t.Fatalf("determine failed: %v", err) + } + defer r.Release(a) + + r.MarkTokenInvalid(a) + if a.DeepSeekToken != "" { + t.Fatalf("expected empty token after invalidation, got %q", a.DeepSeekToken) + } + if a.Account.Token != "" { + t.Fatalf("expected empty account token after invalidation, got %q", a.Account.Token) + } +} + +// ─── SwitchAccount edge cases ──────────────────────────────────────── + +func TestSwitchAccountNotConfigToken(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: false, resolver: r} + if r.SwitchAccount(context.Background(), a) { + t.Fatal("expected false for non-config token") + } +} + +func TestSwitchAccountNilTriedAccounts(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[ + {"email":"acc1@test.com","token":"t1"}, + {"email":"acc2@test.com","token":"t2"} + ] + }`) + store := config.LoadStore() + pool := account.NewPool(store) + r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "new-token", nil + }) + + // First acquire + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + + oldID := a.AccountID + a.TriedAccounts = nil // test nil initialization in SwitchAccount + if !r.SwitchAccount(context.Background(), a) { + t.Fatal("expected switch to succeed") + } + if a.AccountID == oldID { + t.Fatalf("expected different account after switch") + } + r.Release(a) +} + +// ─── Release edge cases ───────────────────────────────────────────── + +func TestReleaseNilAuth(t *testing.T) { + r := newTestResolver(t) + r.Release(nil) // should not panic +} + +func TestReleaseNonConfigToken(t *testing.T) { + r := newTestResolver(t) + a := &RequestAuth{UseConfigToken: false} + r.Release(a) // should not panic +} + +func TestReleaseEmptyAccountID(t *testing.T) { + r := newTestResolver(t) + 
a := &RequestAuth{UseConfigToken: true, AccountID: ""} + r.Release(a) // should not panic +} + +// ─── JWT edge cases ────────────────────────────────────────────────── + +func TestVerifyJWTInvalidFormat(t *testing.T) { + _, err := VerifyJWT("not-a-jwt") + if err == nil { + t.Fatal("expected error for invalid JWT format") + } +} + +func TestVerifyJWTInvalidSignature(t *testing.T) { + token, _ := CreateJWT(1) + // Tamper with the signature + parts := splitJWT(token) + if len(parts) == 3 { + tampered := parts[0] + "." + parts[1] + ".invalid_signature" + _, err := VerifyJWT(tampered) + if err == nil { + t.Fatal("expected error for tampered signature") + } + } +} + +func TestVerifyJWTExpired(t *testing.T) { + // Create a token with 0 hours expiry - will use default, so we can't easily test + // Instead test with bad payload + _, err := VerifyJWT("eyJhbGciOiJIUzI1NiJ9.eyJleHAiOjF9.invalid") + if err == nil { + t.Fatal("expected error for expired/invalid JWT") + } +} + +func TestCreateJWTDefaultExpiry(t *testing.T) { + token, err := CreateJWT(0) // should use default + if err != nil { + t.Fatalf("create jwt failed: %v", err) + } + _, err = VerifyJWT(token) + if err != nil { + t.Fatalf("verify jwt failed: %v", err) + } +} + +// ─── VerifyAdminRequest edge cases ─────────────────────────────────── + +func TestVerifyAdminRequestNoHeader(t *testing.T) { + req, _ := http.NewRequest("GET", "/admin/config", nil) + if err := VerifyAdminRequest(req); err == nil { + t.Fatal("expected error for missing auth") + } +} + +func TestVerifyAdminRequestEmptyBearer(t *testing.T) { + req, _ := http.NewRequest("GET", "/admin/config", nil) + req.Header.Set("Authorization", "Bearer ") + if err := VerifyAdminRequest(req); err == nil { + t.Fatal("expected error for empty bearer") + } +} + +func TestVerifyAdminRequestWithAdminKey(t *testing.T) { + t.Setenv("DS2API_ADMIN_KEY", "test-admin-key") + req, _ := http.NewRequest("GET", "/admin/config", nil) + req.Header.Set("Authorization", "Bearer 
test-admin-key") + if err := VerifyAdminRequest(req); err != nil { + t.Fatalf("expected admin key accepted: %v", err) + } +} + +func TestVerifyAdminRequestInvalidCredentials(t *testing.T) { + t.Setenv("DS2API_ADMIN_KEY", "correct-key") + req, _ := http.NewRequest("GET", "/admin/config", nil) + req.Header.Set("Authorization", "Bearer wrong-key") + if err := VerifyAdminRequest(req); err == nil { + t.Fatal("expected error for wrong key") + } +} + +func TestVerifyAdminRequestBasicAuth(t *testing.T) { + req, _ := http.NewRequest("GET", "/admin/config", nil) + req.Header.Set("Authorization", "Basic abc123") + if err := VerifyAdminRequest(req); err == nil { + t.Fatal("expected error for Basic auth") + } +} + +// ─── Determine with login failure ──────────────────────────────────── + +func TestDetermineWithLoginFailure(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[{"email":"acc@test.com","password":"pwd"}] + }`) + store := config.LoadStore() + pool := account.NewPool(store) + r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "", errors.New("login failed") + }) + + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + _, err := r.Determine(req) + if err == nil { + t.Fatal("expected error when login fails") + } +} + +// ─── Determine with target account ─────────────────────────────────── + +func TestDetermineWithTargetAccount(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[ + {"email":"acc1@test.com","token":"t1"}, + {"email":"acc2@test.com","token":"t2"} + ] + }`) + store := config.LoadStore() + pool := account.NewPool(store) + r := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "fresh-token", nil + }) + + req, _ := http.NewRequest("POST", "/", nil) + req.Header.Set("Authorization", "Bearer managed-key") + 
req.Header.Set("X-Ds2-Target-Account", "acc2@test.com") + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + defer r.Release(a) + if a.AccountID != "acc2@test.com" { + t.Fatalf("expected target account acc2, got %q", a.AccountID) + } +} + +// helper +func splitJWT(token string) []string { + result := make([]string, 0, 3) + start := 0 + count := 0 + for i := 0; i < len(token); i++ { + if token[i] == '.' { + result = append(result, token[start:i]) + start = i + 1 + count++ + } + } + result = append(result, token[start:]) + return result +} diff --git a/internal/auth/request.go b/internal/auth/request.go index ea3d7f1..c0cdd52 100644 --- a/internal/auth/request.go +++ b/internal/auth/request.go @@ -2,6 +2,8 @@ package auth import ( "context" + "crypto/sha256" + "encoding/hex" "errors" "net/http" "strings" @@ -22,6 +24,7 @@ var ( type RequestAuth struct { UseConfigToken bool DeepSeekToken string + CallerID string AccountID string Account config.Account TriedAccounts map[string]bool @@ -45,9 +48,16 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) { if callerKey == "" { return nil, ErrUnauthorized } + callerID := callerTokenID(callerKey) ctx := req.Context() if !r.Store.HasAPIKey(callerKey) { - return &RequestAuth{UseConfigToken: false, DeepSeekToken: callerKey, resolver: r, TriedAccounts: map[string]bool{}}, nil + return &RequestAuth{ + UseConfigToken: false, + DeepSeekToken: callerKey, + CallerID: callerID, + resolver: r, + TriedAccounts: map[string]bool{}, + }, nil } target := strings.TrimSpace(req.Header.Get("X-Ds2-Target-Account")) acc, ok := r.Pool.AcquireWait(ctx, target, nil) @@ -56,6 +66,7 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) { } a := &RequestAuth{ UseConfigToken: true, + CallerID: callerID, AccountID: acc.Identifier(), Account: acc, TriedAccounts: map[string]bool{}, @@ -72,6 +83,26 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) { return 
a, nil } +// DetermineCaller resolves caller identity without acquiring any pooled account. +// Use this for local-cache lookup routes that only need tenant isolation. +func (r *Resolver) DetermineCaller(req *http.Request) (*RequestAuth, error) { + callerKey := extractCallerToken(req) + if callerKey == "" { + return nil, ErrUnauthorized + } + callerID := callerTokenID(callerKey) + a := &RequestAuth{ + UseConfigToken: false, + CallerID: callerID, + resolver: r, + TriedAccounts: map[string]bool{}, + } + if r == nil || r.Store == nil || !r.Store.HasAPIKey(callerKey) { + a.DeepSeekToken = callerKey + } + return a, nil +} + func WithAuth(ctx context.Context, a *RequestAuth) context.Context { return context.WithValue(ctx, authCtxKey, a) } @@ -156,5 +187,26 @@ func extractCallerToken(req *http.Request) string { return token } } - return strings.TrimSpace(req.Header.Get("x-api-key")) + if key := strings.TrimSpace(req.Header.Get("x-api-key")); key != "" { + return key + } + // Gemini/Google clients commonly send API key via x-goog-api-key. + if key := strings.TrimSpace(req.Header.Get("x-goog-api-key")); key != "" { + return key + } + // Gemini AI Studio compatibility: allow query key fallback only when no + // header-based credential is present. 
+ if key := strings.TrimSpace(req.URL.Query().Get("key")); key != "" { + return key + } + return strings.TrimSpace(req.URL.Query().Get("api_key")) +} + +func callerTokenID(token string) string { + token = strings.TrimSpace(token) + if token == "" { + return "" + } + sum := sha256.Sum256([]byte(token)) + return "caller:" + hex.EncodeToString(sum[:8]) } diff --git a/internal/auth/request_test.go b/internal/auth/request_test.go index 1d568f3..2f70e3f 100644 --- a/internal/auth/request_test.go +++ b/internal/auth/request_test.go @@ -37,6 +37,9 @@ func TestDetermineWithXAPIKeyUsesDirectToken(t *testing.T) { if auth.DeepSeekToken != "direct-token" { t.Fatalf("unexpected token: %q", auth.DeepSeekToken) } + if auth.CallerID == "" { + t.Fatalf("expected caller id to be populated") + } } func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) { @@ -58,6 +61,44 @@ func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) { if auth.DeepSeekToken != "account-token" { t.Fatalf("unexpected account token: %q", auth.DeepSeekToken) } + if auth.CallerID == "" { + t.Fatalf("expected caller id to be populated") + } +} + +func TestDetermineCallerWithManagedKeySkipsAccountAcquire(t *testing.T) { + r := newTestResolver(t) + req, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil) + req.Header.Set("x-api-key", "managed-key") + + a, err := r.DetermineCaller(req) + if err != nil { + t.Fatalf("determine caller failed: %v", err) + } + if a.CallerID == "" { + t.Fatalf("expected caller id to be populated") + } + if a.UseConfigToken { + t.Fatalf("expected no config-token lease for caller-only auth") + } + if a.AccountID != "" { + t.Fatalf("expected empty account id, got %q", a.AccountID) + } +} + +func TestCallerTokenIDStable(t *testing.T) { + a := callerTokenID("token-a") + b := callerTokenID("token-a") + c := callerTokenID("token-b") + if a == "" || b == "" || c == "" { + t.Fatalf("expected non-empty caller ids") + } + if a != b { + t.Fatalf("expected 
stable caller id, got %q and %q", a, b) + } + if a == c { + t.Fatalf("expected different caller id for different tokens") + } } func TestDetermineMissingToken(t *testing.T) { @@ -72,3 +113,83 @@ func TestDetermineMissingToken(t *testing.T) { t.Fatalf("unexpected error: %v", err) } } + +func TestDetermineWithQueryKeyUsesDirectToken(t *testing.T) { + r := newTestResolver(t) + req, _ := http.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent?key=direct-query-key", nil) + + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + if a.UseConfigToken { + t.Fatalf("expected direct token mode") + } + if a.DeepSeekToken != "direct-query-key" { + t.Fatalf("unexpected token: %q", a.DeepSeekToken) + } +} + +func TestDetermineWithXGoogAPIKeyUsesDirectToken(t *testing.T) { + r := newTestResolver(t) + req, _ := http.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:streamGenerateContent?alt=sse", nil) + req.Header.Set("x-goog-api-key", "goog-header-key") + + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + if a.UseConfigToken { + t.Fatalf("expected direct token mode") + } + if a.DeepSeekToken != "goog-header-key" { + t.Fatalf("unexpected token: %q", a.DeepSeekToken) + } +} + +func TestDetermineWithAPIKeyQueryParamUsesDirectToken(t *testing.T) { + r := newTestResolver(t) + req, _ := http.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent?api_key=direct-api-key", nil) + + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + if a.UseConfigToken { + t.Fatalf("expected direct token mode") + } + if a.DeepSeekToken != "direct-api-key" { + t.Fatalf("unexpected token: %q", a.DeepSeekToken) + } +} + +func TestDetermineHeaderTokenPrecedenceOverQueryKey(t *testing.T) { + r := newTestResolver(t) + req, _ := http.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent?key=query-key", nil) + 
req.Header.Set("x-api-key", "managed-key") + + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + defer r.Release(a) + if !a.UseConfigToken { + t.Fatalf("expected managed key mode from header token") + } + if a.AccountID == "" { + t.Fatalf("expected managed account to be acquired") + } +} + +func TestDetermineCallerMissingToken(t *testing.T) { + r := newTestResolver(t) + req, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil) + + _, err := r.DetermineCaller(req) + if err == nil { + t.Fatal("expected unauthorized error") + } + if err != ErrUnauthorized { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/internal/claudeconv/convert.go b/internal/claudeconv/convert.go new file mode 100644 index 0000000..1ce1f01 --- /dev/null +++ b/internal/claudeconv/convert.go @@ -0,0 +1,48 @@ +package claudeconv + +import "strings" + +type ClaudeMappingProvider interface { + ClaudeMapping() map[string]string +} + +func ConvertClaudeToDeepSeek(claudeReq map[string]any, mappingProvider ClaudeMappingProvider, defaultClaudeModel string) map[string]any { + messages, _ := claudeReq["messages"].([]any) + model, _ := claudeReq["model"].(string) + if model == "" { + model = defaultClaudeModel + } + + mapping := map[string]string{} + if mappingProvider != nil { + mapping = mappingProvider.ClaudeMapping() + } + dsModel := mapping["fast"] + if dsModel == "" { + dsModel = "deepseek-chat" + } + + modelLower := strings.ToLower(model) + if strings.Contains(modelLower, "opus") || strings.Contains(modelLower, "reasoner") || strings.Contains(modelLower, "slow") { + if slow := mapping["slow"]; slow != "" { + dsModel = slow + } + } + + convertedMessages := make([]any, 0, len(messages)+1) + if system, ok := claudeReq["system"].(string); ok && system != "" { + convertedMessages = append(convertedMessages, map[string]any{"role": "system", "content": system}) + } + convertedMessages = append(convertedMessages, messages...) 
+ + out := map[string]any{"model": dsModel, "messages": convertedMessages} + for _, k := range []string{"temperature", "top_p", "stream"} { + if v, ok := claudeReq[k]; ok { + out[k] = v + } + } + if stopSeq, ok := claudeReq["stop_sequences"]; ok { + out["stop"] = stopSeq + } + return out +} diff --git a/internal/compat/go_compat_test.go b/internal/compat/go_compat_test.go new file mode 100644 index 0000000..024e7ba --- /dev/null +++ b/internal/compat/go_compat_test.go @@ -0,0 +1,142 @@ +package compat + +import ( + "encoding/json" + "os" + "path/filepath" + "reflect" + "testing" + + "ds2api/internal/sse" + "ds2api/internal/util" +) + +func TestGoCompatSSEFixtures(t *testing.T) { + files, err := filepath.Glob(compatPath("fixtures", "sse_chunks", "*.json")) + if err != nil { + t.Fatalf("glob fixtures failed: %v", err) + } + if len(files) == 0 { + t.Fatal("no sse fixtures found") + } + for _, fixturePath := range files { + name := trimExt(filepath.Base(fixturePath)) + expectedPath := compatPath("expected", "sse_"+name+".json") + + var fixture struct { + Chunk map[string]any `json:"chunk"` + ThinkingEnable bool `json:"thinking_enabled"` + CurrentType string `json:"current_type"` + } + mustLoadJSON(t, fixturePath, &fixture) + + var expected struct { + Parts []map[string]any `json:"parts"` + Finished bool `json:"finished"` + NewType string `json:"new_type"` + } + mustLoadJSON(t, expectedPath, &expected) + + parts, finished, newType := sse.ParseSSEChunkForContent(fixture.Chunk, fixture.ThinkingEnable, fixture.CurrentType) + gotParts := make([]map[string]any, 0, len(parts)) + for _, p := range parts { + gotParts = append(gotParts, map[string]any{ + "text": p.Text, + "type": p.Type, + }) + } + if !reflect.DeepEqual(gotParts, expected.Parts) || finished != expected.Finished || newType != expected.NewType { + t.Fatalf("fixture %s mismatch:\n got parts=%#v finished=%v newType=%q\nwant parts=%#v finished=%v newType=%q", + name, gotParts, finished, newType, expected.Parts, 
expected.Finished, expected.NewType) + } + } +} + +func TestGoCompatToolcallFixtures(t *testing.T) { + files, err := filepath.Glob(compatPath("fixtures", "toolcalls", "*.json")) + if err != nil { + t.Fatalf("glob toolcall fixtures failed: %v", err) + } + if len(files) == 0 { + t.Fatal("no toolcall fixtures found") + } + for _, fixturePath := range files { + name := trimExt(filepath.Base(fixturePath)) + expectedPath := compatPath("expected", "toolcalls_"+name+".json") + + var fixture struct { + Text string `json:"text"` + ToolNames []string `json:"tool_names"` + } + mustLoadJSON(t, fixturePath, &fixture) + + var expected struct { + Calls []util.ParsedToolCall `json:"calls"` + } + mustLoadJSON(t, expectedPath, &expected) + + got := util.ParseToolCalls(fixture.Text, fixture.ToolNames) + if len(got) == 0 && len(expected.Calls) == 0 { + continue + } + if !reflect.DeepEqual(got, expected.Calls) { + t.Fatalf("toolcall fixture %s mismatch:\n got=%#v\nwant=%#v", name, got, expected.Calls) + } + } +} + +func TestGoCompatTokenFixtures(t *testing.T) { + var fixture struct { + Cases []struct { + Name string `json:"name"` + Text string `json:"text"` + } `json:"cases"` + } + mustLoadJSON(t, compatPath("fixtures", "token_cases.json"), &fixture) + + var expected struct { + Cases []struct { + Name string `json:"name"` + Tokens int `json:"tokens"` + } `json:"cases"` + } + mustLoadJSON(t, compatPath("expected", "token_cases.json"), &expected) + + expectByName := map[string]int{} + for _, c := range expected.Cases { + expectByName[c.Name] = c.Tokens + } + for _, c := range fixture.Cases { + want, ok := expectByName[c.Name] + if !ok { + t.Fatalf("missing expected token case: %s", c.Name) + } + got := util.EstimateTokens(c.Text) + if got != want { + t.Fatalf("token fixture %s mismatch: got=%d want=%d", c.Name, got, want) + } + } +} + +func mustLoadJSON(t *testing.T, path string, out any) { + t.Helper() + b, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read %s failed: %v", path, 
err) + } + if err := json.Unmarshal(b, out); err != nil { + t.Fatalf("decode %s failed: %v", path, err) + } +} + +func trimExt(name string) string { + if len(name) > 5 && name[len(name)-5:] == ".json" { + return name[:len(name)-5] + } + return name +} + +func compatPath(parts ...string) string { + prefix := []string{"..", "..", "tests", "compat"} + return filepath.Join(append(prefix, parts...)...) +} diff --git a/internal/config/account.go b/internal/config/account.go new file mode 100644 index 0000000..29a4947 --- /dev/null +++ b/internal/config/account.go @@ -0,0 +1,24 @@ +package config + +import ( + "crypto/sha256" + "encoding/hex" + "strings" +) + +func (a Account) Identifier() string { + if strings.TrimSpace(a.Email) != "" { + return strings.TrimSpace(a.Email) + } + if strings.TrimSpace(a.Mobile) != "" { + return strings.TrimSpace(a.Mobile) + } + // Backward compatibility: old configs may contain token-only accounts. + // Use a stable non-sensitive synthetic id so they can still join the pool. 
+ token := strings.TrimSpace(a.Token) + if token == "" { + return "" + } + sum := sha256.Sum256([]byte(token)) + return "token:" + hex.EncodeToString(sum[:8]) +} diff --git a/internal/config/codec.go b/internal/config/codec.go new file mode 100644 index 0000000..2a23e20 --- /dev/null +++ b/internal/config/codec.go @@ -0,0 +1,241 @@ +package config + +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "slices" + "strings" +) + +func (c Config) MarshalJSON() ([]byte, error) { + m := map[string]any{} + for k, v := range c.AdditionalFields { + m[k] = v + } + if len(c.Keys) > 0 { + m["keys"] = c.Keys + } + if len(c.Accounts) > 0 { + m["accounts"] = c.Accounts + } + if len(c.ClaudeMapping) > 0 { + m["claude_mapping"] = c.ClaudeMapping + } + if len(c.ClaudeModelMap) > 0 { + m["claude_model_mapping"] = c.ClaudeModelMap + } + if len(c.ModelAliases) > 0 { + m["model_aliases"] = c.ModelAliases + } + if strings.TrimSpace(c.Admin.PasswordHash) != "" || c.Admin.JWTExpireHours > 0 || c.Admin.JWTValidAfterUnix > 0 { + m["admin"] = c.Admin + } + if c.Runtime.AccountMaxInflight > 0 || c.Runtime.AccountMaxQueue > 0 || c.Runtime.GlobalMaxInflight > 0 { + m["runtime"] = c.Runtime + } + if c.Compat.WideInputStrictOutput != nil { + m["compat"] = c.Compat + } + if strings.TrimSpace(c.Toolcall.Mode) != "" || strings.TrimSpace(c.Toolcall.EarlyEmitConfidence) != "" { + m["toolcall"] = c.Toolcall + } + if c.Responses.StoreTTLSeconds > 0 { + m["responses"] = c.Responses + } + if strings.TrimSpace(c.Embeddings.Provider) != "" { + m["embeddings"] = c.Embeddings + } + if c.VercelSyncHash != "" { + m["_vercel_sync_hash"] = c.VercelSyncHash + } + if c.VercelSyncTime != 0 { + m["_vercel_sync_time"] = c.VercelSyncTime + } + return json.Marshal(m) +} + +func (c *Config) UnmarshalJSON(b []byte) error { + raw := map[string]json.RawMessage{} + if err := json.Unmarshal(b, &raw); err != nil { + return err + } + c.AdditionalFields = map[string]any{} + for k, v := range raw { + switch k { + 
case "keys": + if err := json.Unmarshal(v, &c.Keys); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "accounts": + if err := json.Unmarshal(v, &c.Accounts); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "claude_mapping": + if err := json.Unmarshal(v, &c.ClaudeMapping); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "claude_model_mapping": + if err := json.Unmarshal(v, &c.ClaudeModelMap); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "model_aliases": + if err := json.Unmarshal(v, &c.ModelAliases); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "admin": + if err := json.Unmarshal(v, &c.Admin); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "runtime": + if err := json.Unmarshal(v, &c.Runtime); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "compat": + if err := json.Unmarshal(v, &c.Compat); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "toolcall": + if err := json.Unmarshal(v, &c.Toolcall); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "responses": + if err := json.Unmarshal(v, &c.Responses); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "embeddings": + if err := json.Unmarshal(v, &c.Embeddings); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "_vercel_sync_hash": + if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "_vercel_sync_time": + if err := json.Unmarshal(v, &c.VercelSyncTime); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + default: + var anyVal any + if err := json.Unmarshal(v, &anyVal); err == nil { + c.AdditionalFields[k] = anyVal + } + } + } + return nil +} + +func (c Config) Clone() Config { + clone := Config{ + Keys: slices.Clone(c.Keys), + 
Accounts: slices.Clone(c.Accounts), + ClaudeMapping: cloneStringMap(c.ClaudeMapping), + ClaudeModelMap: cloneStringMap(c.ClaudeModelMap), + ModelAliases: cloneStringMap(c.ModelAliases), + Admin: c.Admin, + Runtime: c.Runtime, + Compat: CompatConfig{ + WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput), + }, + Toolcall: c.Toolcall, + Responses: c.Responses, + Embeddings: c.Embeddings, + VercelSyncHash: c.VercelSyncHash, + VercelSyncTime: c.VercelSyncTime, + AdditionalFields: map[string]any{}, + } + for k, v := range c.AdditionalFields { + clone.AdditionalFields[k] = v + } + return clone +} + +func cloneStringMap(in map[string]string) map[string]string { + if len(in) == 0 { + return nil + } + out := make(map[string]string, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func cloneBoolPtr(in *bool) *bool { + if in == nil { + return nil + } + v := *in + return &v +} + +func parseConfigString(raw string) (Config, error) { + var cfg Config + candidates := []string{raw} + if normalized := normalizeConfigInput(raw); normalized != raw { + candidates = append(candidates, normalized) + } + for _, candidate := range candidates { + if err := json.Unmarshal([]byte(candidate), &cfg); err == nil { + return cfg, nil + } + } + + base64Input := candidates[len(candidates)-1] + decoded, err := decodeConfigBase64(base64Input) + if err != nil { + return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON: %w", err) + } + if err := json.Unmarshal(decoded, &cfg); err != nil { + return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON decoded JSON: %w", err) + } + return cfg, nil +} + +func normalizeConfigInput(raw string) string { + normalized := strings.TrimSpace(raw) + if normalized == "" { + return normalized + } + for { + changed := false + if len(normalized) >= 2 { + first := normalized[0] + last := normalized[len(normalized)-1] + if (first == '"' && last == '"') || (first == '\'' && last == '\'') { + normalized = strings.TrimSpace(normalized[1 : 
len(normalized)-1]) + changed = true + } + } + if strings.HasPrefix(strings.ToLower(normalized), "base64:") { + normalized = strings.TrimSpace(normalized[len("base64:"):]) + changed = true + } + if !changed { + break + } + } + return strings.TrimSpace(normalized) +} + +func decodeConfigBase64(raw string) ([]byte, error) { + encodings := []*base64.Encoding{ + base64.StdEncoding, + base64.RawStdEncoding, + base64.URLEncoding, + base64.RawURLEncoding, + } + var lastErr error + for _, enc := range encodings { + decoded, err := enc.DecodeString(raw) + if err == nil { + return decoded, nil + } + lastErr = err + } + if lastErr != nil { + return nil, lastErr + } + return nil, errors.New("base64 decode failed") +} diff --git a/internal/config/config.go b/internal/config/config.go index 691df6d..4b281a2 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,35 +1,20 @@ package config -import ( - "crypto/sha256" - "encoding/base64" - "encoding/hex" - "encoding/json" - "errors" - "log/slog" - "os" - "path/filepath" - "slices" - "strings" - "sync" -) - -var Logger = newLogger() - -func newLogger() *slog.Logger { - level := new(slog.LevelVar) - switch strings.ToUpper(strings.TrimSpace(os.Getenv("LOG_LEVEL"))) { - case "DEBUG": - level.Set(slog.LevelDebug) - case "WARN": - level.Set(slog.LevelWarn) - case "ERROR": - level.Set(slog.LevelError) - default: - level.Set(slog.LevelInfo) - } - h := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level}) - return slog.New(h) +type Config struct { + Keys []string `json:"keys,omitempty"` + Accounts []Account `json:"accounts,omitempty"` + ClaudeMapping map[string]string `json:"claude_mapping,omitempty"` + ClaudeModelMap map[string]string `json:"claude_model_mapping,omitempty"` + ModelAliases map[string]string `json:"model_aliases,omitempty"` + Admin AdminConfig `json:"admin,omitempty"` + Runtime RuntimeConfig `json:"runtime,omitempty"` + Compat CompatConfig `json:"compat,omitempty"` + Toolcall ToolcallConfig 
`json:"toolcall,omitempty"` + Responses ResponsesConfig `json:"responses,omitempty"` + Embeddings EmbeddingsConfig `json:"embeddings,omitempty"` + VercelSyncHash string `json:"_vercel_sync_hash,omitempty"` + VercelSyncTime int64 `json:"_vercel_sync_time,omitempty"` + AdditionalFields map[string]any `json:"-"` } type Account struct { @@ -39,377 +24,31 @@ type Account struct { Token string `json:"token,omitempty"` } -func (a Account) Identifier() string { - if strings.TrimSpace(a.Email) != "" { - return strings.TrimSpace(a.Email) - } - if strings.TrimSpace(a.Mobile) != "" { - return strings.TrimSpace(a.Mobile) - } - // Backward compatibility: old configs may contain token-only accounts. - // Use a stable non-sensitive synthetic id so they can still join the pool. - token := strings.TrimSpace(a.Token) - if token == "" { - return "" - } - sum := sha256.Sum256([]byte(token)) - return "token:" + hex.EncodeToString(sum[:8]) +type CompatConfig struct { + WideInputStrictOutput *bool `json:"wide_input_strict_output,omitempty"` } -type Config struct { - Keys []string `json:"keys,omitempty"` - Accounts []Account `json:"accounts,omitempty"` - ClaudeMapping map[string]string `json:"claude_mapping,omitempty"` - ClaudeModelMap map[string]string `json:"claude_model_mapping,omitempty"` - VercelSyncHash string `json:"_vercel_sync_hash,omitempty"` - VercelSyncTime int64 `json:"_vercel_sync_time,omitempty"` - AdditionalFields map[string]any `json:"-"` +type AdminConfig struct { + PasswordHash string `json:"password_hash,omitempty"` + JWTExpireHours int `json:"jwt_expire_hours,omitempty"` + JWTValidAfterUnix int64 `json:"jwt_valid_after_unix,omitempty"` } -func (c Config) MarshalJSON() ([]byte, error) { - m := map[string]any{} - for k, v := range c.AdditionalFields { - m[k] = v - } - if len(c.Keys) > 0 { - m["keys"] = c.Keys - } - if len(c.Accounts) > 0 { - m["accounts"] = c.Accounts - } - if len(c.ClaudeMapping) > 0 { - m["claude_mapping"] = c.ClaudeMapping - } - if 
len(c.ClaudeModelMap) > 0 { - m["claude_model_mapping"] = c.ClaudeModelMap - } - if c.VercelSyncHash != "" { - m["_vercel_sync_hash"] = c.VercelSyncHash - } - if c.VercelSyncTime != 0 { - m["_vercel_sync_time"] = c.VercelSyncTime - } - return json.Marshal(m) +type RuntimeConfig struct { + AccountMaxInflight int `json:"account_max_inflight,omitempty"` + AccountMaxQueue int `json:"account_max_queue,omitempty"` + GlobalMaxInflight int `json:"global_max_inflight,omitempty"` } -func (c *Config) UnmarshalJSON(b []byte) error { - raw := map[string]json.RawMessage{} - if err := json.Unmarshal(b, &raw); err != nil { - return err - } - c.AdditionalFields = map[string]any{} - for k, v := range raw { - switch k { - case "keys": - _ = json.Unmarshal(v, &c.Keys) - case "accounts": - _ = json.Unmarshal(v, &c.Accounts) - case "claude_mapping": - _ = json.Unmarshal(v, &c.ClaudeMapping) - case "claude_model_mapping": - _ = json.Unmarshal(v, &c.ClaudeModelMap) - case "_vercel_sync_hash": - _ = json.Unmarshal(v, &c.VercelSyncHash) - case "_vercel_sync_time": - _ = json.Unmarshal(v, &c.VercelSyncTime) - default: - var anyVal any - if err := json.Unmarshal(v, &anyVal); err == nil { - c.AdditionalFields[k] = anyVal - } - } - } - return nil +type ToolcallConfig struct { + Mode string `json:"mode,omitempty"` + EarlyEmitConfidence string `json:"early_emit_confidence,omitempty"` } -func (c Config) Clone() Config { - clone := Config{ - Keys: slices.Clone(c.Keys), - Accounts: slices.Clone(c.Accounts), - ClaudeMapping: cloneStringMap(c.ClaudeMapping), - ClaudeModelMap: cloneStringMap(c.ClaudeModelMap), - VercelSyncHash: c.VercelSyncHash, - VercelSyncTime: c.VercelSyncTime, - AdditionalFields: map[string]any{}, - } - for k, v := range c.AdditionalFields { - clone.AdditionalFields[k] = v - } - return clone +type ResponsesConfig struct { + StoreTTLSeconds int `json:"store_ttl_seconds,omitempty"` } -func cloneStringMap(in map[string]string) map[string]string { - if len(in) == 0 { - return nil - } - 
out := make(map[string]string, len(in)) - for k, v := range in { - out[k] = v - } - return out -} - -type Store struct { - mu sync.RWMutex - cfg Config - path string - fromEnv bool - keyMap map[string]struct{} // O(1) API key lookup index - accMap map[string]int // O(1) account lookup: identifier -> slice index -} - -func BaseDir() string { - cwd, err := os.Getwd() - if err != nil { - return "." - } - return cwd -} - -func IsVercel() bool { - return strings.TrimSpace(os.Getenv("VERCEL")) != "" || strings.TrimSpace(os.Getenv("NOW_REGION")) != "" -} - -func ResolvePath(envKey, defaultRel string) string { - raw := strings.TrimSpace(os.Getenv(envKey)) - if raw != "" { - if filepath.IsAbs(raw) { - return raw - } - return filepath.Join(BaseDir(), raw) - } - return filepath.Join(BaseDir(), defaultRel) -} - -func ConfigPath() string { - return ResolvePath("DS2API_CONFIG_PATH", "config.json") -} - -func WASMPath() string { - return ResolvePath("DS2API_WASM_PATH", "sha3_wasm_bg.7b9ca65ddd.wasm") -} - -func StaticAdminDir() string { - return ResolvePath("DS2API_STATIC_ADMIN_DIR", "static/admin") -} - -func LoadStore() *Store { - cfg, fromEnv, err := loadConfig() - if err != nil { - Logger.Warn("[config] load failed", "error", err) - } - if len(cfg.Keys) == 0 && len(cfg.Accounts) == 0 { - Logger.Warn("[config] empty config loaded") - } - s := &Store{cfg: cfg, path: ConfigPath(), fromEnv: fromEnv} - s.rebuildIndexes() - return s -} - -// rebuildIndexes must be called with the lock already held (or during init). 
-func (s *Store) rebuildIndexes() { - s.keyMap = make(map[string]struct{}, len(s.cfg.Keys)) - for _, k := range s.cfg.Keys { - s.keyMap[k] = struct{}{} - } - s.accMap = make(map[string]int, len(s.cfg.Accounts)) - for i, acc := range s.cfg.Accounts { - id := acc.Identifier() - if id != "" { - s.accMap[id] = i - } - } -} - -func loadConfig() (Config, bool, error) { - rawCfg := strings.TrimSpace(os.Getenv("DS2API_CONFIG_JSON")) - if rawCfg == "" { - rawCfg = strings.TrimSpace(os.Getenv("CONFIG_JSON")) - } - if rawCfg != "" { - cfg, err := parseConfigString(rawCfg) - return cfg, true, err - } - - content, err := os.ReadFile(ConfigPath()) - if err != nil { - return Config{}, false, err - } - var cfg Config - if err := json.Unmarshal(content, &cfg); err != nil { - return Config{}, false, err - } - return cfg, false, nil -} - -func parseConfigString(raw string) (Config, error) { - var cfg Config - if err := json.Unmarshal([]byte(raw), &cfg); err == nil { - return cfg, nil - } - decoded, err := base64.StdEncoding.DecodeString(raw) - if err != nil { - return Config{}, err - } - if err := json.Unmarshal(decoded, &cfg); err != nil { - return Config{}, err - } - return cfg, nil -} - -func (s *Store) Snapshot() Config { - s.mu.RLock() - defer s.mu.RUnlock() - return s.cfg.Clone() -} - -func (s *Store) HasAPIKey(k string) bool { - s.mu.RLock() - defer s.mu.RUnlock() - _, ok := s.keyMap[k] - return ok -} - -func (s *Store) Keys() []string { - s.mu.RLock() - defer s.mu.RUnlock() - return slices.Clone(s.cfg.Keys) -} - -func (s *Store) Accounts() []Account { - s.mu.RLock() - defer s.mu.RUnlock() - return slices.Clone(s.cfg.Accounts) -} - -func (s *Store) FindAccount(identifier string) (Account, bool) { - identifier = strings.TrimSpace(identifier) - s.mu.RLock() - defer s.mu.RUnlock() - if idx, ok := s.findAccountIndexLocked(identifier); ok { - return s.cfg.Accounts[idx], true - } - return Account{}, false -} - -func (s *Store) UpdateAccountToken(identifier, token string) error { - 
identifier = strings.TrimSpace(identifier) - s.mu.Lock() - defer s.mu.Unlock() - idx, ok := s.findAccountIndexLocked(identifier) - if !ok { - return errors.New("account not found") - } - oldID := s.cfg.Accounts[idx].Identifier() - s.cfg.Accounts[idx].Token = token - newID := s.cfg.Accounts[idx].Identifier() - // Keep historical aliases usable for long-lived queues while also adding - // the latest identifier after token refresh. - if identifier != "" { - s.accMap[identifier] = idx - } - if oldID != "" { - s.accMap[oldID] = idx - } - if newID != "" { - s.accMap[newID] = idx - } - return s.saveLocked() -} - -func (s *Store) Replace(cfg Config) error { - s.mu.Lock() - defer s.mu.Unlock() - s.cfg = cfg.Clone() - s.rebuildIndexes() - return s.saveLocked() -} - -func (s *Store) Update(mutator func(*Config) error) error { - s.mu.Lock() - defer s.mu.Unlock() - cfg := s.cfg.Clone() - if err := mutator(&cfg); err != nil { - return err - } - s.cfg = cfg - s.rebuildIndexes() - return s.saveLocked() -} - -func (s *Store) Save() error { - s.mu.Lock() - defer s.mu.Unlock() - if s.fromEnv { - Logger.Info("[save_config] source from env, skip write") - return nil - } - b, err := json.MarshalIndent(s.cfg, "", " ") - if err != nil { - return err - } - return os.WriteFile(s.path, b, 0o644) -} - -func (s *Store) saveLocked() error { - if s.fromEnv { - Logger.Info("[save_config] source from env, skip write") - return nil - } - b, err := json.MarshalIndent(s.cfg, "", " ") - if err != nil { - return err - } - return os.WriteFile(s.path, b, 0o644) -} - -// findAccountIndexLocked expects the store lock to already be held. -func (s *Store) findAccountIndexLocked(identifier string) (int, bool) { - if idx, ok := s.accMap[identifier]; ok && idx >= 0 && idx < len(s.cfg.Accounts) { - return idx, true - } - // Fallback for token-only accounts whose derived identifier changed after - // a token refresh; this preserves correctness on map misses. 
- for i, acc := range s.cfg.Accounts { - if acc.Identifier() == identifier { - return i, true - } - } - return -1, false -} - -func (s *Store) IsEnvBacked() bool { - s.mu.RLock() - defer s.mu.RUnlock() - return s.fromEnv -} - -func (s *Store) SetVercelSync(hash string, ts int64) error { - return s.Update(func(c *Config) error { - c.VercelSyncHash = hash - c.VercelSyncTime = ts - return nil - }) -} - -func (s *Store) ExportJSONAndBase64() (string, string, error) { - s.mu.RLock() - defer s.mu.RUnlock() - b, err := json.Marshal(s.cfg) - if err != nil { - return "", "", err - } - return string(b), base64.StdEncoding.EncodeToString(b), nil -} - -func (s *Store) ClaudeMapping() map[string]string { - s.mu.RLock() - defer s.mu.RUnlock() - if len(s.cfg.ClaudeModelMap) > 0 { - return cloneStringMap(s.cfg.ClaudeModelMap) - } - if len(s.cfg.ClaudeMapping) > 0 { - return cloneStringMap(s.cfg.ClaudeMapping) - } - return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"} +type EmbeddingsConfig struct { + Provider string `json:"provider,omitempty"` } diff --git a/internal/config/config_edge_test.go b/internal/config/config_edge_test.go new file mode 100644 index 0000000..1138867 --- /dev/null +++ b/internal/config/config_edge_test.go @@ -0,0 +1,478 @@ +package config + +import ( + "encoding/base64" + "encoding/json" + "strings" + "testing" +) + +// ─── GetModelConfig edge cases ─────────────────────────────────────── + +func TestGetModelConfigDeepSeekChat(t *testing.T) { + thinking, search, ok := GetModelConfig("deepseek-chat") + if !ok { + t.Fatal("expected ok for deepseek-chat") + } + if thinking || search { + t.Fatalf("expected no thinking/search for deepseek-chat, got thinking=%v search=%v", thinking, search) + } +} + +func TestGetModelConfigDeepSeekReasoner(t *testing.T) { + thinking, search, ok := GetModelConfig("deepseek-reasoner") + if !ok { + t.Fatal("expected ok for deepseek-reasoner") + } + if !thinking || search { + t.Fatalf("expected thinking=true 
search=false, got thinking=%v search=%v", thinking, search) + } +} + +func TestGetModelConfigDeepSeekChatSearch(t *testing.T) { + thinking, search, ok := GetModelConfig("deepseek-chat-search") + if !ok { + t.Fatal("expected ok for deepseek-chat-search") + } + if thinking || !search { + t.Fatalf("expected thinking=false search=true, got thinking=%v search=%v", thinking, search) + } +} + +func TestGetModelConfigDeepSeekReasonerSearch(t *testing.T) { + thinking, search, ok := GetModelConfig("deepseek-reasoner-search") + if !ok { + t.Fatal("expected ok for deepseek-reasoner-search") + } + if !thinking || !search { + t.Fatalf("expected both true, got thinking=%v search=%v", thinking, search) + } +} + +func TestGetModelConfigCaseInsensitive(t *testing.T) { + thinking, search, ok := GetModelConfig("DeepSeek-Chat") + if !ok { + t.Fatal("expected ok for case-insensitive deepseek-chat") + } + if thinking || search { + t.Fatalf("expected no thinking/search for case-insensitive deepseek-chat") + } +} + +func TestGetModelConfigUnknownModel(t *testing.T) { + _, _, ok := GetModelConfig("gpt-4") + if ok { + t.Fatal("expected not ok for unknown model") + } +} + +func TestGetModelConfigEmpty(t *testing.T) { + _, _, ok := GetModelConfig("") + if ok { + t.Fatal("expected not ok for empty model") + } +} + +// ─── lower function ────────────────────────────────────────────────── + +func TestLowerFunction(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"Hello", "hello"}, + {"ALLCAPS", "allcaps"}, + {"already-lower", "already-lower"}, + {"Mixed-CASE-123", "mixed-case-123"}, + {"", ""}, + } + for _, tc := range tests { + got := lower(tc.input) + if got != tc.expected { + t.Errorf("lower(%q) = %q, want %q", tc.input, got, tc.expected) + } + } +} + +// ─── Config.MarshalJSON / UnmarshalJSON roundtrip ──────────────────── + +func TestConfigJSONRoundtrip(t *testing.T) { + cfg := Config{ + Keys: []string{"key1", "key2"}, + Accounts: []Account{{Email: 
"user@example.com", Password: "pass", Token: "tok"}}, + ClaudeMapping: map[string]string{ + "fast": "deepseek-chat", + "slow": "deepseek-reasoner", + }, + VercelSyncHash: "hash123", + VercelSyncTime: 1234567890, + AdditionalFields: map[string]any{ + "custom_field": "custom_value", + }, + } + + data, err := cfg.MarshalJSON() + if err != nil { + t.Fatalf("marshal error: %v", err) + } + + var decoded Config + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + + if len(decoded.Keys) != 2 || decoded.Keys[0] != "key1" { + t.Fatalf("unexpected keys: %#v", decoded.Keys) + } + if len(decoded.Accounts) != 1 || decoded.Accounts[0].Email != "user@example.com" { + t.Fatalf("unexpected accounts: %#v", decoded.Accounts) + } + if decoded.ClaudeMapping["fast"] != "deepseek-chat" { + t.Fatalf("unexpected claude mapping: %#v", decoded.ClaudeMapping) + } + if decoded.VercelSyncHash != "hash123" { + t.Fatalf("unexpected vercel sync hash: %q", decoded.VercelSyncHash) + } + if decoded.AdditionalFields["custom_field"] != "custom_value" { + t.Fatalf("unexpected additional fields: %#v", decoded.AdditionalFields) + } +} + +func TestConfigUnmarshalJSONPreservesUnknownFields(t *testing.T) { + raw := `{"keys":["k1"],"accounts":[],"my_custom_field":"hello","number_field":42}` + var cfg Config + if err := json.Unmarshal([]byte(raw), &cfg); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if cfg.AdditionalFields["my_custom_field"] != "hello" { + t.Fatalf("expected custom field preserved, got %#v", cfg.AdditionalFields) + } + // number_field should also be preserved + if cfg.AdditionalFields["number_field"] != float64(42) { + t.Fatalf("expected number field preserved, got %#v", cfg.AdditionalFields["number_field"]) + } +} + +// ─── Config.Clone ──────────────────────────────────────────────────── + +func TestConfigCloneIsDeepCopy(t *testing.T) { + cfg := Config{ + Keys: []string{"key1"}, + Accounts: []Account{{Email: "user@test.com", Token: 
"token"}}, + ClaudeMapping: map[string]string{ + "fast": "deepseek-chat", + }, + AdditionalFields: map[string]any{"custom": "value"}, + } + + cloned := cfg.Clone() + + // Modify original + cfg.Keys[0] = "modified" + cfg.Accounts[0].Email = "modified@test.com" + cfg.ClaudeMapping["fast"] = "modified-model" + + // Cloned should not be affected + if cloned.Keys[0] != "key1" { + t.Fatalf("clone keys was affected by original change: %#v", cloned.Keys) + } + if cloned.Accounts[0].Email != "user@test.com" { + t.Fatalf("clone accounts was affected: %#v", cloned.Accounts) + } + if cloned.ClaudeMapping["fast"] != "deepseek-chat" { + t.Fatalf("clone claude mapping was affected: %#v", cloned.ClaudeMapping) + } +} + +func TestConfigCloneNilMaps(t *testing.T) { + cfg := Config{ + Keys: []string{"k"}, + Accounts: nil, + } + cloned := cfg.Clone() + if len(cloned.Keys) != 1 { + t.Fatalf("unexpected keys length: %d", len(cloned.Keys)) + } + if cloned.Accounts != nil { + t.Fatalf("expected nil accounts in clone, got %#v", cloned.Accounts) + } +} + +// ─── Account.Identifier edge cases ─────────────────────────────────── + +func TestAccountIdentifierPreferenceMobileOverToken(t *testing.T) { + acc := Account{Mobile: "13800138000", Token: "tok"} + if acc.Identifier() != "13800138000" { + t.Fatalf("expected mobile identifier, got %q", acc.Identifier()) + } +} + +func TestAccountIdentifierPreferenceEmailOverMobile(t *testing.T) { + acc := Account{Email: "user@test.com", Mobile: "13800138000"} + if acc.Identifier() != "user@test.com" { + t.Fatalf("expected email identifier, got %q", acc.Identifier()) + } +} + +func TestAccountIdentifierEmptyAccount(t *testing.T) { + acc := Account{} + if acc.Identifier() != "" { + t.Fatalf("expected empty identifier for empty account, got %q", acc.Identifier()) + } +} + +// ─── normalizeConfigInput ──────────────────────────────────────────── + +func TestNormalizeConfigInputStripsQuotes(t *testing.T) { + got := normalizeConfigInput(`"base64:abc"`) + if 
strings.HasPrefix(got, `"`) || strings.HasSuffix(got, `"`) { + t.Fatalf("expected quotes stripped, got %q", got) + } +} + +func TestNormalizeConfigInputStripsSingleQuotes(t *testing.T) { + got := normalizeConfigInput("'some-value'") + if strings.HasPrefix(got, "'") || strings.HasSuffix(got, "'") { + t.Fatalf("expected single quotes stripped, got %q", got) + } +} + +func TestNormalizeConfigInputTrimsWhitespace(t *testing.T) { + got := normalizeConfigInput(" hello ") + if got != "hello" { + t.Fatalf("expected trimmed, got %q", got) + } +} + +// ─── parseConfigString edge cases ──────────────────────────────────── + +func TestParseConfigStringPlainJSON(t *testing.T) { + cfg, err := parseConfigString(`{"keys":["k1"],"accounts":[]}`) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Keys) != 1 || cfg.Keys[0] != "k1" { + t.Fatalf("unexpected keys: %#v", cfg.Keys) + } +} + +func TestParseConfigStringBase64Prefix(t *testing.T) { + rawJSON := `{"keys":["base64-key"],"accounts":[]}` + b64 := base64.StdEncoding.EncodeToString([]byte(rawJSON)) + cfg, err := parseConfigString("base64:" + b64) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Keys) != 1 || cfg.Keys[0] != "base64-key" { + t.Fatalf("unexpected keys: %#v", cfg.Keys) + } +} + +func TestParseConfigStringInvalidBase64(t *testing.T) { + _, err := parseConfigString("base64:!!!invalid!!!") + if err == nil { + t.Fatal("expected error for invalid base64") + } +} + +func TestParseConfigStringEmptyString(t *testing.T) { + _, err := parseConfigString("") + if err == nil { + t.Fatal("expected error for empty string") + } +} + +// ─── Store methods ─────────────────────────────────────────────────── + +func TestStoreSnapshotReturnsClone(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"u@test.com","token":"t1"}]}`) + store := LoadStore() + snap := store.Snapshot() + snap.Keys[0] = "modified" + if store.Keys()[0] != "k1" { + t.Fatal("snapshot 
modification should not affect store") + } +} + +func TestStoreHasAPIKeyMultipleKeys(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["key1","key2","key3"],"accounts":[]}`) + store := LoadStore() + if !store.HasAPIKey("key1") { + t.Fatal("expected key1 found") + } + if !store.HasAPIKey("key2") { + t.Fatal("expected key2 found") + } + if !store.HasAPIKey("key3") { + t.Fatal("expected key3 found") + } + if store.HasAPIKey("nonexistent") { + t.Fatal("expected nonexistent key not found") + } +} + +func TestStoreFindAccountNotFound(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"u@test.com"}]}`) + store := LoadStore() + _, ok := store.FindAccount("nonexistent@test.com") + if ok { + t.Fatal("expected account not found") + } +} + +func TestStoreCompatWideInputStrictOutputDefaultTrue(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) + store := LoadStore() + if !store.CompatWideInputStrictOutput() { + t.Fatal("expected default wide_input_strict_output=true when unset") + } +} + +func TestStoreCompatWideInputStrictOutputCanDisable(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[],"compat":{"wide_input_strict_output":false}}`) + store := LoadStore() + if store.CompatWideInputStrictOutput() { + t.Fatal("expected wide_input_strict_output=false when explicitly configured") + } + + snap := store.Snapshot() + data, err := snap.MarshalJSON() + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + var out map[string]any + if err := json.Unmarshal(data, &out); err != nil { + t.Fatalf("decode failed: %v", err) + } + rawCompat, ok := out["compat"].(map[string]any) + if !ok { + t.Fatalf("expected compat in marshaled output, got %#v", out) + } + if rawCompat["wide_input_strict_output"] != false { + t.Fatalf("expected explicit false in compat, got %#v", rawCompat) + } +} + +func TestStoreIsEnvBacked(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", 
`{"keys":["k1"],"accounts":[]}`) + store := LoadStore() + if !store.IsEnvBacked() { + t.Fatal("expected env-backed store") + } +} + +func TestStoreReplace(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) + store := LoadStore() + newCfg := Config{ + Keys: []string{"new-key"}, + Accounts: []Account{{Email: "new@test.com"}}, + } + if err := store.Replace(newCfg); err != nil { + t.Fatalf("replace error: %v", err) + } + if !store.HasAPIKey("new-key") { + t.Fatal("expected new key after replace") + } + if store.HasAPIKey("k1") { + t.Fatal("expected old key removed after replace") + } +} + +func TestStoreUpdate(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) + store := LoadStore() + err := store.Update(func(cfg *Config) error { + cfg.Keys = append(cfg.Keys, "k2") + return nil + }) + if err != nil { + t.Fatalf("update error: %v", err) + } + if !store.HasAPIKey("k2") { + t.Fatal("expected k2 after update") + } +} + +func TestStoreClaudeMapping(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[],"claude_mapping":{"fast":"deepseek-chat","slow":"deepseek-reasoner"}}`) + store := LoadStore() + mapping := store.ClaudeMapping() + if mapping["fast"] != "deepseek-chat" { + t.Fatalf("unexpected fast mapping: %q", mapping["fast"]) + } + if mapping["slow"] != "deepseek-reasoner" { + t.Fatalf("unexpected slow mapping: %q", mapping["slow"]) + } +} + +func TestStoreClaudeMappingEmpty(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`) + store := LoadStore() + mapping := store.ClaudeMapping() + // Even without config mapping, there are defaults + if mapping == nil { + t.Fatal("expected non-nil mapping (may contain defaults)") + } +} + +func TestStoreSetVercelSync(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`) + store := LoadStore() + if err := store.SetVercelSync("hash123", 1234567890); err != nil { + t.Fatalf("setVercelSync error: %v", err) + 
} + snap := store.Snapshot() + if snap.VercelSyncHash != "hash123" || snap.VercelSyncTime != 1234567890 { + t.Fatalf("unexpected vercel sync: hash=%q time=%d", snap.VercelSyncHash, snap.VercelSyncTime) + } +} + +func TestStoreExportJSONAndBase64(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["export-key"],"accounts":[]}`) + store := LoadStore() + jsonStr, b64Str, err := store.ExportJSONAndBase64() + if err != nil { + t.Fatalf("export error: %v", err) + } + if !strings.Contains(jsonStr, "export-key") { + t.Fatalf("expected JSON to contain key: %q", jsonStr) + } + decoded, err := base64.StdEncoding.DecodeString(b64Str) + if err != nil { + t.Fatalf("base64 decode error: %v", err) + } + if !strings.Contains(string(decoded), "export-key") { + t.Fatalf("expected base64-decoded to contain key: %q", string(decoded)) + } +} + +// ─── OpenAIModelsResponse / ClaudeModelsResponse ───────────────────── + +func TestOpenAIModelsResponse(t *testing.T) { + resp := OpenAIModelsResponse() + if resp["object"] != "list" { + t.Fatalf("unexpected object: %v", resp["object"]) + } + data, ok := resp["data"].([]ModelInfo) + if !ok { + t.Fatalf("unexpected data type: %T", resp["data"]) + } + if len(data) == 0 { + t.Fatal("expected non-empty models list") + } +} + +func TestClaudeModelsResponse(t *testing.T) { + resp := ClaudeModelsResponse() + if resp["object"] != "list" { + t.Fatalf("unexpected object: %v", resp["object"]) + } + data, ok := resp["data"].([]ModelInfo) + if !ok { + t.Fatalf("unexpected data type: %T", resp["data"]) + } + if len(data) == 0 { + t.Fatal("expected non-empty models list") + } +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 58a8a2a..a409fd7 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1,6 +1,7 @@ package config import ( + "encoding/base64" "strings" "testing" ) @@ -70,3 +71,53 @@ func TestStoreUpdateAccountTokenKeepsOldAndNewIdentifierResolvable(t *testing.T) 
t.Fatalf("expected find by old identifier alias") } } + +func TestLoadStoreRejectsInvalidFieldType(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":"not-array","accounts":[]}`) + store := LoadStore() + if len(store.Keys()) != 0 || len(store.Accounts()) != 0 { + t.Fatalf("expected empty store when config type is invalid") + } +} + +func TestParseConfigStringSupportsQuotedBase64Prefix(t *testing.T) { + rawJSON := `{"keys":["k1"],"accounts":[{"email":"u@example.com","password":"p"}]}` + b64 := base64.StdEncoding.EncodeToString([]byte(rawJSON)) + cfg, err := parseConfigString(`"base64:` + b64 + `"`) + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + if len(cfg.Keys) != 1 || cfg.Keys[0] != "k1" { + t.Fatalf("unexpected keys: %#v", cfg.Keys) + } +} + +func TestParseConfigStringSupportsRawURLBase64(t *testing.T) { + rawJSON := `{"keys":["k-url"],"accounts":[]}` + b64 := base64.RawURLEncoding.EncodeToString([]byte(rawJSON)) + cfg, err := parseConfigString(b64) + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + if len(cfg.Keys) != 1 || cfg.Keys[0] != "k-url" { + t.Fatalf("unexpected keys: %#v", cfg.Keys) + } +} + +func TestLoadConfigOnVercelWithoutConfigFileFallsBackToMemory(t *testing.T) { + t.Setenv("VERCEL", "1") + t.Setenv("DS2API_CONFIG_JSON", "") + t.Setenv("CONFIG_JSON", "") + t.Setenv("DS2API_CONFIG_PATH", "testdata/does-not-exist.json") + + cfg, fromEnv, err := loadConfig() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !fromEnv { + t.Fatalf("expected fromEnv=true for vercel fallback") + } + if len(cfg.Keys) != 0 || len(cfg.Accounts) != 0 { + t.Fatalf("expected empty bootstrap config, got keys=%d accounts=%d", len(cfg.Keys), len(cfg.Accounts)) + } +} diff --git a/internal/config/logger.go b/internal/config/logger.go new file mode 100644 index 0000000..8b2de91 --- /dev/null +++ b/internal/config/logger.go @@ -0,0 +1,25 @@ +package config + +import ( + "log/slog" + "os" + "strings" +) + +var Logger = 
newLogger() + +func newLogger() *slog.Logger { + level := new(slog.LevelVar) + switch strings.ToUpper(strings.TrimSpace(os.Getenv("LOG_LEVEL"))) { + case "DEBUG": + level.Set(slog.LevelDebug) + case "WARN": + level.Set(slog.LevelWarn) + case "ERROR": + level.Set(slog.LevelError) + default: + level.Set(slog.LevelInfo) + } + h := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level}) + return slog.New(h) +} diff --git a/internal/config/model_alias_test.go b/internal/config/model_alias_test.go new file mode 100644 index 0000000..89e74b0 --- /dev/null +++ b/internal/config/model_alias_test.go @@ -0,0 +1,44 @@ +package config + +import "testing" + +func TestResolveModelDirectDeepSeek(t *testing.T) { + got, ok := ResolveModel(nil, "deepseek-chat") + if !ok || got != "deepseek-chat" { + t.Fatalf("expected deepseek-chat, got ok=%v model=%q", ok, got) + } +} + +func TestResolveModelAlias(t *testing.T) { + got, ok := ResolveModel(nil, "gpt-4.1") + if !ok || got != "deepseek-chat" { + t.Fatalf("expected alias gpt-4.1 -> deepseek-chat, got ok=%v model=%q", ok, got) + } +} + +func TestResolveModelHeuristicReasoner(t *testing.T) { + got, ok := ResolveModel(nil, "o3-super") + if !ok || got != "deepseek-reasoner" { + t.Fatalf("expected heuristic reasoner, got ok=%v model=%q", ok, got) + } +} + +func TestResolveModelUnknown(t *testing.T) { + _, ok := ResolveModel(nil, "totally-custom-model") + if ok { + t.Fatal("expected unknown model to fail resolve") + } +} + +func TestClaudeModelsResponsePaginationFields(t *testing.T) { + resp := ClaudeModelsResponse() + if _, ok := resp["first_id"]; !ok { + t.Fatalf("expected first_id in response: %#v", resp) + } + if _, ok := resp["last_id"]; !ok { + t.Fatalf("expected last_id in response: %#v", resp) + } + if _, ok := resp["has_more"]; !ok { + t.Fatalf("expected has_more in response: %#v", resp) + } +} diff --git a/internal/config/models.go b/internal/config/models.go index 13fa63d..a2ec899 100644 --- a/internal/config/models.go 
+++ b/internal/config/models.go @@ -1,5 +1,7 @@ package config +import "strings" + type ModelInfo struct { ID string `json:"id"` Object string `json:"object"` @@ -8,6 +10,10 @@ type ModelInfo struct { Permission []any `json:"permission,omitempty"` } +type ModelAliasReader interface { + ModelAliases() map[string]string +} + var DeepSeekModels = []ModelInfo{ {ID: "deepseek-chat", Object: "model", Created: 1677610602, OwnedBy: "deepseek", Permission: []any{}}, {ID: "deepseek-reasoner", Object: "model", Created: 1677610602, OwnedBy: "deepseek", Permission: []any{}}, @@ -71,6 +77,91 @@ func GetModelConfig(model string) (thinking bool, search bool, ok bool) { } } +func IsSupportedDeepSeekModel(model string) bool { + _, _, ok := GetModelConfig(model) + return ok +} + +func DefaultModelAliases() map[string]string { + return map[string]string{ + "gpt-4o": "deepseek-chat", + "gpt-4.1": "deepseek-chat", + "gpt-4.1-mini": "deepseek-chat", + "gpt-4.1-nano": "deepseek-chat", + "gpt-5": "deepseek-chat", + "gpt-5-mini": "deepseek-chat", + "gpt-5-codex": "deepseek-reasoner", + "o1": "deepseek-reasoner", + "o1-mini": "deepseek-reasoner", + "o3": "deepseek-reasoner", + "o3-mini": "deepseek-reasoner", + "claude-sonnet-4-5": "deepseek-chat", + "claude-haiku-4-5": "deepseek-chat", + "claude-opus-4-6": "deepseek-reasoner", + "claude-3-5-sonnet": "deepseek-chat", + "claude-3-5-haiku": "deepseek-chat", + "claude-3-opus": "deepseek-reasoner", + "gemini-2.5-pro": "deepseek-chat", + "gemini-2.5-flash": "deepseek-chat", + "llama-3.1-70b-instruct": "deepseek-chat", + "qwen-max": "deepseek-chat", + } +} + +func ResolveModel(store ModelAliasReader, requested string) (string, bool) { + model := lower(strings.TrimSpace(requested)) + if model == "" { + return "", false + } + if IsSupportedDeepSeekModel(model) { + return model, true + } + aliases := DefaultModelAliases() + if store != nil { + for k, v := range store.ModelAliases() { + aliases[lower(strings.TrimSpace(k))] = lower(strings.TrimSpace(v)) 
+ } + } + if mapped, ok := aliases[model]; ok && IsSupportedDeepSeekModel(mapped) { + return mapped, true + } + if strings.HasPrefix(model, "deepseek-") { + return "", false + } + + knownFamily := false + for _, prefix := range []string{ + "gpt-", "o1", "o3", "claude-", "gemini-", "llama-", "qwen-", "mistral-", "command-", + } { + if strings.HasPrefix(model, prefix) { + knownFamily = true + break + } + } + if !knownFamily { + return "", false + } + + useReasoner := strings.Contains(model, "reason") || + strings.Contains(model, "reasoner") || + strings.HasPrefix(model, "o1") || + strings.HasPrefix(model, "o3") || + strings.Contains(model, "opus") || + strings.Contains(model, "r1") + useSearch := strings.Contains(model, "search") + + switch { + case useReasoner && useSearch: + return "deepseek-reasoner-search", true + case useReasoner: + return "deepseek-reasoner", true + case useSearch: + return "deepseek-chat-search", true + default: + return "deepseek-chat", true + } +} + func lower(s string) string { b := []byte(s) for i, c := range b { @@ -85,6 +176,28 @@ func OpenAIModelsResponse() map[string]any { return map[string]any{"object": "list", "data": DeepSeekModels} } -func ClaudeModelsResponse() map[string]any { - return map[string]any{"object": "list", "data": ClaudeModels} +func OpenAIModelByID(store ModelAliasReader, id string) (ModelInfo, bool) { + canonical, ok := ResolveModel(store, id) + if !ok { + return ModelInfo{}, false + } + for _, model := range DeepSeekModels { + if model.ID == canonical { + return model, true + } + } + return ModelInfo{}, false +} + +func ClaudeModelsResponse() map[string]any { + resp := map[string]any{"object": "list", "data": ClaudeModels} + if len(ClaudeModels) > 0 { + resp["first_id"] = ClaudeModels[0].ID + resp["last_id"] = ClaudeModels[len(ClaudeModels)-1].ID + } else { + resp["first_id"] = nil + resp["last_id"] = nil + } + resp["has_more"] = false + return resp } diff --git a/internal/config/paths.go b/internal/config/paths.go 
new file mode 100644 index 0000000..23dfe54 --- /dev/null +++ b/internal/config/paths.go @@ -0,0 +1,42 @@ +package config + +import ( + "os" + "path/filepath" + "strings" +) + +func BaseDir() string { + cwd, err := os.Getwd() + if err != nil { + return "." + } + return cwd +} + +func IsVercel() bool { + return strings.TrimSpace(os.Getenv("VERCEL")) != "" || strings.TrimSpace(os.Getenv("NOW_REGION")) != "" +} + +func ResolvePath(envKey, defaultRel string) string { + raw := strings.TrimSpace(os.Getenv(envKey)) + if raw != "" { + if filepath.IsAbs(raw) { + return raw + } + return filepath.Join(BaseDir(), raw) + } + return filepath.Join(BaseDir(), defaultRel) +} + +func ConfigPath() string { + return ResolvePath("DS2API_CONFIG_PATH", "config.json") +} + +func WASMPath() string { + return ResolvePath("DS2API_WASM_PATH", "sha3_wasm_bg.7b9ca65ddd.wasm") +} + +func StaticAdminDir() string { + return ResolvePath("DS2API_STATIC_ADMIN_DIR", "static/admin") +} diff --git a/internal/config/store.go b/internal/config/store.go new file mode 100644 index 0000000..2e6fcaf --- /dev/null +++ b/internal/config/store.go @@ -0,0 +1,193 @@ +package config + +import ( + "encoding/base64" + "encoding/json" + "errors" + "os" + "slices" + "strings" + "sync" +) + +type Store struct { + mu sync.RWMutex + cfg Config + path string + fromEnv bool + keyMap map[string]struct{} // O(1) API key lookup index + accMap map[string]int // O(1) account lookup: identifier -> slice index +} + +func LoadStore() *Store { + cfg, fromEnv, err := loadConfig() + if err != nil { + Logger.Warn("[config] load failed", "error", err) + } + if len(cfg.Keys) == 0 && len(cfg.Accounts) == 0 { + Logger.Warn("[config] empty config loaded") + } + s := &Store{cfg: cfg, path: ConfigPath(), fromEnv: fromEnv} + s.rebuildIndexes() + return s +} + +func loadConfig() (Config, bool, error) { + rawCfg := strings.TrimSpace(os.Getenv("DS2API_CONFIG_JSON")) + if rawCfg == "" { + rawCfg = strings.TrimSpace(os.Getenv("CONFIG_JSON")) + } + 
if rawCfg != "" { + cfg, err := parseConfigString(rawCfg) + return cfg, true, err + } + + content, err := os.ReadFile(ConfigPath()) + if err != nil { + if IsVercel() { + // Vercel one-click deploy may start without a writable/present config file. + // Keep an in-memory config so users can bootstrap via WebUI then sync env. + return Config{}, true, nil + } + return Config{}, false, err + } + var cfg Config + if err := json.Unmarshal(content, &cfg); err != nil { + return Config{}, false, err + } + if IsVercel() { + // Vercel filesystem is ephemeral/read-only for runtime writes; avoid save errors. + return cfg, true, nil + } + return cfg, false, nil +} + +func (s *Store) Snapshot() Config { + s.mu.RLock() + defer s.mu.RUnlock() + return s.cfg.Clone() +} + +func (s *Store) HasAPIKey(k string) bool { + s.mu.RLock() + defer s.mu.RUnlock() + _, ok := s.keyMap[k] + return ok +} + +func (s *Store) Keys() []string { + s.mu.RLock() + defer s.mu.RUnlock() + return slices.Clone(s.cfg.Keys) +} + +func (s *Store) Accounts() []Account { + s.mu.RLock() + defer s.mu.RUnlock() + return slices.Clone(s.cfg.Accounts) +} + +func (s *Store) FindAccount(identifier string) (Account, bool) { + identifier = strings.TrimSpace(identifier) + s.mu.RLock() + defer s.mu.RUnlock() + if idx, ok := s.findAccountIndexLocked(identifier); ok { + return s.cfg.Accounts[idx], true + } + return Account{}, false +} + +func (s *Store) UpdateAccountToken(identifier, token string) error { + identifier = strings.TrimSpace(identifier) + s.mu.Lock() + defer s.mu.Unlock() + idx, ok := s.findAccountIndexLocked(identifier) + if !ok { + return errors.New("account not found") + } + oldID := s.cfg.Accounts[idx].Identifier() + s.cfg.Accounts[idx].Token = token + newID := s.cfg.Accounts[idx].Identifier() + // Keep historical aliases usable for long-lived queues while also adding + // the latest identifier after token refresh. 
+ if identifier != "" { + s.accMap[identifier] = idx + } + if oldID != "" { + s.accMap[oldID] = idx + } + if newID != "" { + s.accMap[newID] = idx + } + return s.saveLocked() +} + +func (s *Store) Replace(cfg Config) error { + s.mu.Lock() + defer s.mu.Unlock() + s.cfg = cfg.Clone() + s.rebuildIndexes() + return s.saveLocked() +} + +func (s *Store) Update(mutator func(*Config) error) error { + s.mu.Lock() + defer s.mu.Unlock() + cfg := s.cfg.Clone() + if err := mutator(&cfg); err != nil { + return err + } + s.cfg = cfg + s.rebuildIndexes() + return s.saveLocked() +} + +func (s *Store) Save() error { + s.mu.Lock() + defer s.mu.Unlock() + if s.fromEnv { + Logger.Info("[save_config] source from env, skip write") + return nil + } + b, err := json.MarshalIndent(s.cfg, "", " ") + if err != nil { + return err + } + return os.WriteFile(s.path, b, 0o644) +} + +func (s *Store) saveLocked() error { + if s.fromEnv { + Logger.Info("[save_config] source from env, skip write") + return nil + } + b, err := json.MarshalIndent(s.cfg, "", " ") + if err != nil { + return err + } + return os.WriteFile(s.path, b, 0o644) +} + +func (s *Store) IsEnvBacked() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.fromEnv +} + +func (s *Store) SetVercelSync(hash string, ts int64) error { + return s.Update(func(c *Config) error { + c.VercelSyncHash = hash + c.VercelSyncTime = ts + return nil + }) +} + +func (s *Store) ExportJSONAndBase64() (string, string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + b, err := json.Marshal(s.cfg) + if err != nil { + return "", "", err + } + return string(b), base64.StdEncoding.EncodeToString(b), nil +} diff --git a/internal/config/store_accessors.go b/internal/config/store_accessors.go new file mode 100644 index 0000000..f0c5938 --- /dev/null +++ b/internal/config/store_accessors.go @@ -0,0 +1,167 @@ +package config + +import ( + "os" + "strconv" + "strings" +) + +func (s *Store) ClaudeMapping() map[string]string { + s.mu.RLock() + defer s.mu.RUnlock() + 
if len(s.cfg.ClaudeModelMap) > 0 { + return cloneStringMap(s.cfg.ClaudeModelMap) + } + if len(s.cfg.ClaudeMapping) > 0 { + return cloneStringMap(s.cfg.ClaudeMapping) + } + return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"} +} + +func (s *Store) ModelAliases() map[string]string { + s.mu.RLock() + defer s.mu.RUnlock() + out := DefaultModelAliases() + for k, v := range s.cfg.ModelAliases { + key := strings.TrimSpace(lower(k)) + val := strings.TrimSpace(lower(v)) + if key == "" || val == "" { + continue + } + out[key] = val + } + return out +} + +func (s *Store) CompatWideInputStrictOutput() bool { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Compat.WideInputStrictOutput == nil { + return true + } + return *s.cfg.Compat.WideInputStrictOutput +} + +func (s *Store) ToolcallMode() string { + s.mu.RLock() + defer s.mu.RUnlock() + mode := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.Mode)) + if mode == "" { + return "feature_match" + } + return mode +} + +func (s *Store) ToolcallEarlyEmitConfidence() string { + s.mu.RLock() + defer s.mu.RUnlock() + level := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.EarlyEmitConfidence)) + if level == "" { + return "high" + } + return level +} + +func (s *Store) ResponsesStoreTTLSeconds() int { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Responses.StoreTTLSeconds > 0 { + return s.cfg.Responses.StoreTTLSeconds + } + return 900 +} + +func (s *Store) EmbeddingsProvider() string { + s.mu.RLock() + defer s.mu.RUnlock() + return strings.TrimSpace(s.cfg.Embeddings.Provider) +} + +func (s *Store) AdminPasswordHash() string { + s.mu.RLock() + defer s.mu.RUnlock() + return strings.TrimSpace(s.cfg.Admin.PasswordHash) +} + +func (s *Store) AdminJWTExpireHours() int { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Admin.JWTExpireHours > 0 { + return s.cfg.Admin.JWTExpireHours + } + if raw := strings.TrimSpace(os.Getenv("DS2API_JWT_EXPIRE_HOURS")); raw != "" { + if n, err := strconv.Atoi(raw); err == nil 
&& n > 0 { + return n + } + } + return 24 +} + +func (s *Store) AdminJWTValidAfterUnix() int64 { + s.mu.RLock() + defer s.mu.RUnlock() + return s.cfg.Admin.JWTValidAfterUnix +} + +func (s *Store) RuntimeAccountMaxInflight() int { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Runtime.AccountMaxInflight > 0 { + return s.cfg.Runtime.AccountMaxInflight + } + for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} { + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + continue + } + n, err := strconv.Atoi(raw) + if err == nil && n > 0 { + return n + } + } + return 2 +} + +func (s *Store) RuntimeAccountMaxQueue(defaultSize int) int { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Runtime.AccountMaxQueue > 0 { + return s.cfg.Runtime.AccountMaxQueue + } + for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} { + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + continue + } + n, err := strconv.Atoi(raw) + if err == nil && n >= 0 { + return n + } + } + if defaultSize < 0 { + return 0 + } + return defaultSize +} + +func (s *Store) RuntimeGlobalMaxInflight(defaultSize int) int { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Runtime.GlobalMaxInflight > 0 { + return s.cfg.Runtime.GlobalMaxInflight + } + for _, key := range []string{"DS2API_GLOBAL_MAX_INFLIGHT", "DS2API_MAX_INFLIGHT"} { + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + continue + } + n, err := strconv.Atoi(raw) + if err == nil && n > 0 { + return n + } + } + if defaultSize < 0 { + return 0 + } + return defaultSize +} diff --git a/internal/config/store_index.go b/internal/config/store_index.go new file mode 100644 index 0000000..7d0f62a --- /dev/null +++ b/internal/config/store_index.go @@ -0,0 +1,31 @@ +package config + +// rebuildIndexes must be called with the lock already held (or during init). 
+func (s *Store) rebuildIndexes() { + s.keyMap = make(map[string]struct{}, len(s.cfg.Keys)) + for _, k := range s.cfg.Keys { + s.keyMap[k] = struct{}{} + } + s.accMap = make(map[string]int, len(s.cfg.Accounts)) + for i, acc := range s.cfg.Accounts { + id := acc.Identifier() + if id != "" { + s.accMap[id] = i + } + } +} + +// findAccountIndexLocked expects the store lock to already be held. +func (s *Store) findAccountIndexLocked(identifier string) (int, bool) { + if idx, ok := s.accMap[identifier]; ok && idx >= 0 && idx < len(s.cfg.Accounts) { + return idx, true + } + // Fallback for token-only accounts whose derived identifier changed after + // a token refresh; this preserves correctness on map misses. + for i, acc := range s.cfg.Accounts { + if acc.Identifier() == identifier { + return i, true + } + } + return -1, false +} diff --git a/internal/deepseek/client.go b/internal/deepseek/client.go deleted file mode 100644 index 0523435..0000000 --- a/internal/deepseek/client.go +++ /dev/null @@ -1,337 +0,0 @@ -package deepseek - -import ( - "bufio" - "bytes" - "compress/gzip" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "strings" - "time" - - "ds2api/internal/auth" - "ds2api/internal/config" - trans "ds2api/internal/deepseek/transport" - "ds2api/internal/util" - - "github.com/andybalholm/brotli" -) - -// intFrom is a package-internal alias for the shared util version. 
-var intFrom = util.IntFrom - -type Client struct { - Store *config.Store - Auth *auth.Resolver - regular trans.Doer - stream trans.Doer - fallback *http.Client - fallbackS *http.Client - powSolver *PowSolver - maxRetries int -} - -func NewClient(store *config.Store, resolver *auth.Resolver) *Client { - return &Client{ - Store: store, - Auth: resolver, - regular: trans.New(60 * time.Second), - stream: trans.New(0), - fallback: &http.Client{Timeout: 60 * time.Second}, - fallbackS: &http.Client{Timeout: 0}, - powSolver: NewPowSolver(config.WASMPath()), - maxRetries: 3, - } -} - -func (c *Client) PreloadPow(ctx context.Context) error { - return c.powSolver.init(ctx) -} - -func (c *Client) Login(ctx context.Context, acc config.Account) (string, error) { - payload := map[string]any{ - "password": strings.TrimSpace(acc.Password), - "device_id": "deepseek_to_api", - "os": "android", - } - if email := strings.TrimSpace(acc.Email); email != "" { - payload["email"] = email - } else if mobile := strings.TrimSpace(acc.Mobile); mobile != "" { - payload["mobile"] = mobile - payload["area_code"] = nil - } else { - return "", errors.New("missing email/mobile") - } - resp, err := c.postJSON(ctx, c.regular, DeepSeekLoginURL, BaseHeaders, payload) - if err != nil { - return "", err - } - code := intFrom(resp["code"]) - if code != 0 { - return "", fmt.Errorf("login failed: %v", resp["msg"]) - } - data, _ := resp["data"].(map[string]any) - if intFrom(data["biz_code"]) != 0 { - return "", fmt.Errorf("login failed: %v", data["biz_msg"]) - } - bizData, _ := data["biz_data"].(map[string]any) - user, _ := bizData["user"].(map[string]any) - token, _ := user["token"].(string) - if strings.TrimSpace(token) == "" { - return "", errors.New("missing login token") - } - return token, nil -} - -func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) { - if maxAttempts <= 0 { - maxAttempts = c.maxRetries - } - attempts := 0 - refreshed := false - for 
attempts < maxAttempts { - headers := c.authHeaders(a.DeepSeekToken) - resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreateSessionURL, headers, map[string]any{"agent": "chat"}) - if err != nil { - config.Logger.Warn("[create_session] request error", "error", err, "account", a.AccountID) - attempts++ - continue - } - code := intFrom(resp["code"]) - if status == http.StatusOK && code == 0 { - data, _ := resp["data"].(map[string]any) - bizData, _ := data["biz_data"].(map[string]any) - sessionID, _ := bizData["id"].(string) - if sessionID != "" { - return sessionID, nil - } - } - msg, _ := resp["msg"].(string) - config.Logger.Warn("[create_session] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID) - if a.UseConfigToken { - if isTokenInvalid(status, code, msg) && !refreshed { - if c.Auth.RefreshToken(ctx, a) { - refreshed = true - continue - } - } - if c.Auth.SwitchAccount(ctx, a) { - refreshed = false - attempts++ - continue - } - } - attempts++ - } - return "", errors.New("create session failed") -} - -func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) { - if maxAttempts <= 0 { - maxAttempts = c.maxRetries - } - attempts := 0 - for attempts < maxAttempts { - headers := c.authHeaders(a.DeepSeekToken) - resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"}) - if err != nil { - config.Logger.Warn("[get_pow] request error", "error", err, "account", a.AccountID) - attempts++ - continue - } - code := intFrom(resp["code"]) - if status == http.StatusOK && code == 0 { - data, _ := resp["data"].(map[string]any) - bizData, _ := data["biz_data"].(map[string]any) - challenge, _ := bizData["challenge"].(map[string]any) - answer, err := c.powSolver.Compute(ctx, challenge) - if err != nil { - attempts++ - continue - } - return BuildPowHeader(challenge, answer) - 
} - msg, _ := resp["msg"].(string) - config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID) - if a.UseConfigToken { - if isTokenInvalid(status, code, msg) { - if c.Auth.RefreshToken(ctx, a) { - continue - } - } - if c.Auth.SwitchAccount(ctx, a) { - attempts++ - continue - } - } - attempts++ - } - return "", errors.New("get pow failed") -} - -func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) { - if maxAttempts <= 0 { - maxAttempts = c.maxRetries - } - headers := c.authHeaders(a.DeepSeekToken) - headers["x-ds-pow-response"] = powResp - attempts := 0 - for attempts < maxAttempts { - resp, err := c.streamPost(ctx, DeepSeekCompletionURL, headers, payload) - if err != nil { - attempts++ - time.Sleep(time.Second) - continue - } - if resp.StatusCode == http.StatusOK { - return resp, nil - } - _ = resp.Body.Close() - attempts++ - time.Sleep(time.Second) - } - return nil, errors.New("completion failed") -} - -func (c *Client) postJSON(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, error) { - body, status, err := c.postJSONWithStatus(ctx, doer, url, headers, payload) - if err != nil { - return nil, err - } - if status == 0 { - return nil, errors.New("request failed") - } - return body, nil -} - -func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, int, error) { - b, err := json.Marshal(payload) - if err != nil { - return nil, 0, err - } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) - if err != nil { - return nil, 0, err - } - for k, v := range headers { - req.Header.Set(k, v) - } - resp, err := doer.Do(req) - if err != nil { - config.Logger.Warn("[deepseek] fingerprint request failed, fallback to 
std transport", "url", url, "error", err) - req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) - if reqErr != nil { - return nil, 0, err - } - for k, v := range headers { - req2.Header.Set(k, v) - } - resp, err = c.fallback.Do(req2) - if err != nil { - return nil, 0, err - } - } - defer resp.Body.Close() - payloadBytes, err := readResponseBody(resp) - if err != nil { - return nil, resp.StatusCode, err - } - out := map[string]any{} - if len(payloadBytes) > 0 { - if err := json.Unmarshal(payloadBytes, &out); err != nil { - config.Logger.Warn("[deepseek] json parse failed", "url", url, "status", resp.StatusCode, "content_encoding", resp.Header.Get("Content-Encoding"), "preview", preview(payloadBytes)) - } - } - return out, resp.StatusCode, nil -} - -func (c *Client) streamPost(ctx context.Context, url string, headers map[string]string, payload any) (*http.Response, error) { - b, err := json.Marshal(payload) - if err != nil { - return nil, err - } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) - if err != nil { - return nil, err - } - for k, v := range headers { - req.Header.Set(k, v) - } - resp, err := c.stream.Do(req) - if err != nil { - config.Logger.Warn("[deepseek] fingerprint stream request failed, fallback to std transport", "url", url, "error", err) - req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) - if reqErr != nil { - return nil, err - } - for k, v := range headers { - req2.Header.Set(k, v) - } - return c.fallbackS.Do(req2) - } - return resp, nil -} - -func (c *Client) authHeaders(token string) map[string]string { - headers := make(map[string]string, len(BaseHeaders)+1) - for k, v := range BaseHeaders { - headers[k] = v - } - headers["authorization"] = "Bearer " + token - return headers -} - -func isTokenInvalid(status int, code int, msg string) bool { - msg = strings.ToLower(msg) - if status == http.StatusUnauthorized || status == 
http.StatusForbidden { - return true - } - if code == 40001 || code == 40002 || code == 40003 { - return true - } - return strings.Contains(msg, "token") || strings.Contains(msg, "unauthorized") -} - -func readResponseBody(resp *http.Response) ([]byte, error) { - encoding := strings.ToLower(strings.TrimSpace(resp.Header.Get("Content-Encoding"))) - var reader io.Reader = resp.Body - switch encoding { - case "gzip": - gz, err := gzip.NewReader(resp.Body) - if err != nil { - return nil, err - } - defer gz.Close() - reader = gz - case "br": - reader = brotli.NewReader(resp.Body) - } - return io.ReadAll(reader) -} - -func preview(b []byte) string { - s := strings.TrimSpace(string(b)) - if len(s) > 160 { - return s[:160] - } - return s -} - -func ScanSSELines(resp *http.Response, onLine func([]byte) bool) error { - scanner := bufio.NewScanner(resp.Body) - buf := make([]byte, 0, 64*1024) - scanner.Buffer(buf, 2*1024*1024) - for scanner.Scan() { - if !onLine(scanner.Bytes()) { - break - } - } - if err := scanner.Err(); err != nil { - return err - } - return nil -} diff --git a/internal/deepseek/client_auth.go b/internal/deepseek/client_auth.go new file mode 100644 index 0000000..820acaf --- /dev/null +++ b/internal/deepseek/client_auth.go @@ -0,0 +1,153 @@ +package deepseek + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + + "ds2api/internal/auth" + "ds2api/internal/config" +) + +func (c *Client) Login(ctx context.Context, acc config.Account) (string, error) { + payload := map[string]any{ + "password": strings.TrimSpace(acc.Password), + "device_id": "deepseek_to_api", + "os": "android", + } + if email := strings.TrimSpace(acc.Email); email != "" { + payload["email"] = email + } else if mobile := strings.TrimSpace(acc.Mobile); mobile != "" { + payload["mobile"] = mobile + payload["area_code"] = nil + } else { + return "", errors.New("missing email/mobile") + } + resp, err := c.postJSON(ctx, c.regular, DeepSeekLoginURL, BaseHeaders, payload) + if err != 
nil { + return "", err + } + code := intFrom(resp["code"]) + if code != 0 { + return "", fmt.Errorf("login failed: %v", resp["msg"]) + } + data, _ := resp["data"].(map[string]any) + if intFrom(data["biz_code"]) != 0 { + return "", fmt.Errorf("login failed: %v", data["biz_msg"]) + } + bizData, _ := data["biz_data"].(map[string]any) + user, _ := bizData["user"].(map[string]any) + token, _ := user["token"].(string) + if strings.TrimSpace(token) == "" { + return "", errors.New("missing login token") + } + return token, nil +} + +func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) { + if maxAttempts <= 0 { + maxAttempts = c.maxRetries + } + attempts := 0 + refreshed := false + for attempts < maxAttempts { + headers := c.authHeaders(a.DeepSeekToken) + resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreateSessionURL, headers, map[string]any{"agent": "chat"}) + if err != nil { + config.Logger.Warn("[create_session] request error", "error", err, "account", a.AccountID) + attempts++ + continue + } + code := intFrom(resp["code"]) + if status == http.StatusOK && code == 0 { + data, _ := resp["data"].(map[string]any) + bizData, _ := data["biz_data"].(map[string]any) + sessionID, _ := bizData["id"].(string) + if sessionID != "" { + return sessionID, nil + } + } + msg, _ := resp["msg"].(string) + config.Logger.Warn("[create_session] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID) + if a.UseConfigToken { + if isTokenInvalid(status, code, msg) && !refreshed { + if c.Auth.RefreshToken(ctx, a) { + refreshed = true + continue + } + } + if c.Auth.SwitchAccount(ctx, a) { + refreshed = false + attempts++ + continue + } + } + attempts++ + } + return "", errors.New("create session failed") +} + +func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) { + if maxAttempts <= 0 { + maxAttempts = c.maxRetries + } + 
attempts := 0 + for attempts < maxAttempts { + headers := c.authHeaders(a.DeepSeekToken) + resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"}) + if err != nil { + config.Logger.Warn("[get_pow] request error", "error", err, "account", a.AccountID) + attempts++ + continue + } + code := intFrom(resp["code"]) + if status == http.StatusOK && code == 0 { + data, _ := resp["data"].(map[string]any) + bizData, _ := data["biz_data"].(map[string]any) + challenge, _ := bizData["challenge"].(map[string]any) + answer, err := c.powSolver.Compute(ctx, challenge) + if err != nil { + attempts++ + continue + } + return BuildPowHeader(challenge, answer) + } + msg, _ := resp["msg"].(string) + config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID) + if a.UseConfigToken { + if isTokenInvalid(status, code, msg) { + if c.Auth.RefreshToken(ctx, a) { + continue + } + } + if c.Auth.SwitchAccount(ctx, a) { + attempts++ + continue + } + } + attempts++ + } + return "", errors.New("get pow failed") +} + +func (c *Client) authHeaders(token string) map[string]string { + headers := make(map[string]string, len(BaseHeaders)+1) + for k, v := range BaseHeaders { + headers[k] = v + } + headers["authorization"] = "Bearer " + token + return headers +} + +func isTokenInvalid(status int, code int, msg string) bool { + msg = strings.ToLower(msg) + if status == http.StatusUnauthorized || status == http.StatusForbidden { + return true + } + if code == 40001 || code == 40002 || code == 40003 { + return true + } + return strings.Contains(msg, "token") || strings.Contains(msg, "unauthorized") +} diff --git a/internal/deepseek/client_completion.go b/internal/deepseek/client_completion.go new file mode 100644 index 0000000..051bffe --- /dev/null +++ b/internal/deepseek/client_completion.go @@ -0,0 +1,71 @@ +package deepseek + +import 
( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "time" + + "ds2api/internal/auth" + "ds2api/internal/config" +) + +func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) { + if maxAttempts <= 0 { + maxAttempts = c.maxRetries + } + headers := c.authHeaders(a.DeepSeekToken) + headers["x-ds-pow-response"] = powResp + captureSession := c.capture.Start("deepseek_completion", DeepSeekCompletionURL, a.AccountID, payload) + attempts := 0 + for attempts < maxAttempts { + resp, err := c.streamPost(ctx, DeepSeekCompletionURL, headers, payload) + if err != nil { + attempts++ + time.Sleep(time.Second) + continue + } + if resp.StatusCode == http.StatusOK { + if captureSession != nil { + resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode) + } + return resp, nil + } + if captureSession != nil { + resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode) + } + _ = resp.Body.Close() + attempts++ + time.Sleep(time.Second) + } + return nil, errors.New("completion failed") +} + +func (c *Client) streamPost(ctx context.Context, url string, headers map[string]string, payload any) (*http.Response, error) { + b, err := json.Marshal(payload) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) + if err != nil { + return nil, err + } + for k, v := range headers { + req.Header.Set(k, v) + } + resp, err := c.stream.Do(req) + if err != nil { + config.Logger.Warn("[deepseek] fingerprint stream request failed, fallback to std transport", "url", url, "error", err) + req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) + if reqErr != nil { + return nil, err + } + for k, v := range headers { + req2.Header.Set(k, v) + } + return c.fallbackS.Do(req2) + } + return resp, nil +} diff --git a/internal/deepseek/client_core.go b/internal/deepseek/client_core.go 
new file mode 100644 index 0000000..cda6edc --- /dev/null +++ b/internal/deepseek/client_core.go @@ -0,0 +1,46 @@ +package deepseek + +import ( + "context" + "net/http" + "time" + + "ds2api/internal/auth" + "ds2api/internal/config" + trans "ds2api/internal/deepseek/transport" + "ds2api/internal/devcapture" + "ds2api/internal/util" +) + +// intFrom is a package-internal alias for the shared util version. +var intFrom = util.IntFrom + +type Client struct { + Store *config.Store + Auth *auth.Resolver + capture *devcapture.Store + regular trans.Doer + stream trans.Doer + fallback *http.Client + fallbackS *http.Client + powSolver *PowSolver + maxRetries int +} + +func NewClient(store *config.Store, resolver *auth.Resolver) *Client { + return &Client{ + Store: store, + Auth: resolver, + capture: devcapture.Global(), + regular: trans.New(60 * time.Second), + stream: trans.New(0), + fallback: &http.Client{Timeout: 60 * time.Second}, + fallbackS: &http.Client{Timeout: 0}, + powSolver: NewPowSolver(config.WASMPath()), + maxRetries: 3, + } +} + +func (c *Client) PreloadPow(ctx context.Context) error { + return c.powSolver.init(ctx) +} diff --git a/internal/deepseek/client_http_helpers.go b/internal/deepseek/client_http_helpers.go new file mode 100644 index 0000000..05de224 --- /dev/null +++ b/internal/deepseek/client_http_helpers.go @@ -0,0 +1,51 @@ +package deepseek + +import ( + "bufio" + "compress/gzip" + "io" + "net/http" + "strings" + + "github.com/andybalholm/brotli" +) + +func readResponseBody(resp *http.Response) ([]byte, error) { + encoding := strings.ToLower(strings.TrimSpace(resp.Header.Get("Content-Encoding"))) + var reader io.Reader = resp.Body + switch encoding { + case "gzip": + gz, err := gzip.NewReader(resp.Body) + if err != nil { + return nil, err + } + defer gz.Close() + reader = gz + case "br": + reader = brotli.NewReader(resp.Body) + } + return io.ReadAll(reader) +} + +func preview(b []byte) string { + s := strings.TrimSpace(string(b)) + if len(s) > 160 { 
+ return s[:160] + } + return s +} + +func ScanSSELines(resp *http.Response, onLine func([]byte) bool) error { + scanner := bufio.NewScanner(resp.Body) + buf := make([]byte, 0, 64*1024) + scanner.Buffer(buf, 2*1024*1024) + for scanner.Scan() { + if !onLine(scanner.Bytes()) { + break + } + } + if err := scanner.Err(); err != nil { + return err + } + return nil +} diff --git a/internal/deepseek/client_http_json.go b/internal/deepseek/client_http_json.go new file mode 100644 index 0000000..6d3599d --- /dev/null +++ b/internal/deepseek/client_http_json.go @@ -0,0 +1,64 @@ +package deepseek + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + + "ds2api/internal/config" + trans "ds2api/internal/deepseek/transport" +) + +func (c *Client) postJSON(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, error) { + body, status, err := c.postJSONWithStatus(ctx, doer, url, headers, payload) + if err != nil { + return nil, err + } + if status == 0 { + return nil, errors.New("request failed") + } + return body, nil +} + +func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, int, error) { + b, err := json.Marshal(payload) + if err != nil { + return nil, 0, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) + if err != nil { + return nil, 0, err + } + for k, v := range headers { + req.Header.Set(k, v) + } + resp, err := doer.Do(req) + if err != nil { + config.Logger.Warn("[deepseek] fingerprint request failed, fallback to std transport", "url", url, "error", err) + req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) + if reqErr != nil { + return nil, 0, err + } + for k, v := range headers { + req2.Header.Set(k, v) + } + resp, err = c.fallback.Do(req2) + if err != nil { + return nil, 0, err + } + } + defer resp.Body.Close() + 
payloadBytes, err := readResponseBody(resp) + if err != nil { + return nil, resp.StatusCode, err + } + out := map[string]any{} + if len(payloadBytes) > 0 { + if err := json.Unmarshal(payloadBytes, &out); err != nil { + config.Logger.Warn("[deepseek] json parse failed", "url", url, "status", resp.StatusCode, "content_encoding", resp.Header.Get("Content-Encoding"), "preview", preview(payloadBytes)) + } + } + return out, resp.StatusCode, nil +} diff --git a/internal/deepseek/constants.go b/internal/deepseek/constants.go index 1e7d25f..042ec29 100644 --- a/internal/deepseek/constants.go +++ b/internal/deepseek/constants.go @@ -1,5 +1,10 @@ package deepseek +import ( + _ "embed" + "encoding/json" +) + const ( DeepSeekHost = "chat.deepseek.com" DeepSeekLoginURL = "https://chat.deepseek.com/api/v0/users/login" @@ -8,7 +13,7 @@ const ( DeepSeekCompletionURL = "https://chat.deepseek.com/api/v0/chat/completion" ) -var BaseHeaders = map[string]string{ +var defaultBaseHeaders = map[string]string{ "Host": "chat.deepseek.com", "User-Agent": "DeepSeek/1.6.11 Android/35", "Accept": "application/json", @@ -19,6 +24,75 @@ var BaseHeaders = map[string]string{ "accept-charset": "UTF-8", } +var defaultSkipContainsPatterns = []string{ + "quasi_status", + "elapsed_secs", + "token_usage", + "pending_fragment", + "conversation_mode", + "fragments/-1/status", + "fragments/-2/status", + "fragments/-3/status", +} + +var defaultSkipExactPaths = []string{ + "response/search_status", +} + +var BaseHeaders = cloneStringMap(defaultBaseHeaders) +var SkipContainsPatterns = cloneStringSlice(defaultSkipContainsPatterns) +var SkipExactPathSet = toStringSet(defaultSkipExactPaths) + +type sharedConstants struct { + BaseHeaders map[string]string `json:"base_headers"` + SkipContainsPattern []string `json:"skip_contains_patterns"` + SkipExactPaths []string `json:"skip_exact_paths"` +} + +//go:embed constants_shared.json +var sharedConstantsJSON []byte + +func init() { + cfg := sharedConstants{} + if err := 
json.Unmarshal(sharedConstantsJSON, &cfg); err != nil { + return + } + if len(cfg.BaseHeaders) > 0 { + BaseHeaders = cloneStringMap(cfg.BaseHeaders) + } + if len(cfg.SkipContainsPattern) > 0 { + SkipContainsPatterns = cloneStringSlice(cfg.SkipContainsPattern) + } + if len(cfg.SkipExactPaths) > 0 { + SkipExactPathSet = toStringSet(cfg.SkipExactPaths) + } +} + +func cloneStringMap(in map[string]string) map[string]string { + out := make(map[string]string, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func cloneStringSlice(in []string) []string { + out := make([]string, len(in)) + copy(out, in) + return out +} + +func toStringSet(in []string) map[string]struct{} { + out := make(map[string]struct{}, len(in)) + for _, v := range in { + if v == "" { + continue + } + out[v] = struct{}{} + } + return out +} + const ( KeepAliveTimeout = 5 StreamIdleTimeout = 30 diff --git a/internal/deepseek/constants_shared.json b/internal/deepseek/constants_shared.json new file mode 100644 index 0000000..a71ca02 --- /dev/null +++ b/internal/deepseek/constants_shared.json @@ -0,0 +1,25 @@ +{ + "base_headers": { + "Host": "chat.deepseek.com", + "User-Agent": "DeepSeek/1.6.11 Android/35", + "Accept": "application/json", + "Content-Type": "application/json", + "x-client-platform": "android", + "x-client-version": "1.6.11", + "x-client-locale": "zh_CN", + "accept-charset": "UTF-8" + }, + "skip_contains_patterns": [ + "quasi_status", + "elapsed_secs", + "token_usage", + "pending_fragment", + "conversation_mode", + "fragments/-1/status", + "fragments/-2/status", + "fragments/-3/status" + ], + "skip_exact_paths": [ + "response/search_status" + ] +} diff --git a/internal/deepseek/constants_test.go b/internal/deepseek/constants_test.go new file mode 100644 index 0000000..03c6788 --- /dev/null +++ b/internal/deepseek/constants_test.go @@ -0,0 +1,15 @@ +package deepseek + +import "testing" + +func TestSharedConstantsLoaded(t *testing.T) { + if BaseHeaders["x-client-platform"] 
!= "android" { + t.Fatalf("unexpected base header x-client-platform=%q", BaseHeaders["x-client-platform"]) + } + if len(SkipContainsPatterns) == 0 { + t.Fatal("expected skip contains patterns to be loaded") + } + if _, ok := SkipExactPathSet["response/search_status"]; !ok { + t.Fatal("expected response/search_status in exact skip path set") + } +} diff --git a/internal/deepseek/deepseek_edge_test.go b/internal/deepseek/deepseek_edge_test.go new file mode 100644 index 0000000..92e6952 --- /dev/null +++ b/internal/deepseek/deepseek_edge_test.go @@ -0,0 +1,165 @@ +package deepseek + +import ( + "context" + "testing" +) + +// ─── toFloat64 edge cases ──────────────────────────────────────────── + +func TestToFloat64FromFloat64(t *testing.T) { + if got := toFloat64(float64(3.14), 0); got != 3.14 { + t.Fatalf("expected 3.14, got %f", got) + } +} + +func TestToFloat64FromInt(t *testing.T) { + if got := toFloat64(42, 0); got != 42.0 { + t.Fatalf("expected 42.0, got %f", got) + } +} + +func TestToFloat64FromInt64(t *testing.T) { + if got := toFloat64(int64(100), 0); got != 100.0 { + t.Fatalf("expected 100.0, got %f", got) + } +} + +func TestToFloat64FromStringDefault(t *testing.T) { + if got := toFloat64("42", 99.0); got != 99.0 { + t.Fatalf("expected default 99.0, got %f", got) + } +} + +func TestToFloat64FromNilDefault(t *testing.T) { + if got := toFloat64(nil, 5.5); got != 5.5 { + t.Fatalf("expected default 5.5, got %f", got) + } +} + +func TestToFloat64FromBoolDefault(t *testing.T) { + if got := toFloat64(true, 1.0); got != 1.0 { + t.Fatalf("expected default 1.0, got %f", got) + } +} + +// ─── toInt64 edge cases ────────────────────────────────────────────── + +func TestToInt64FromFloat64(t *testing.T) { + if got := toInt64(float64(42.9), 0); got != 42 { + t.Fatalf("expected 42, got %d", got) + } +} + +func TestToInt64FromInt(t *testing.T) { + if got := toInt64(42, 0); got != 42 { + t.Fatalf("expected 42, got %d", got) + } +} + +func TestToInt64FromInt64(t *testing.T) { 
+ if got := toInt64(int64(100), 0); got != 100 { + t.Fatalf("expected 100, got %d", got) + } +} + +func TestToInt64FromStringDefault(t *testing.T) { + if got := toInt64("42", 99); got != 99 { + t.Fatalf("expected default 99, got %d", got) + } +} + +func TestToInt64FromNilDefault(t *testing.T) { + if got := toInt64(nil, 7); got != 7 { + t.Fatalf("expected default 7, got %d", got) + } +} + +// ─── BuildPowHeader edge cases ─────────────────────────────────────── + +func TestBuildPowHeaderBasicChallenge(t *testing.T) { + challenge := map[string]any{ + "algorithm": "DeepSeekHashV1", + "challenge": "abc123", + "salt": "salt456", + "signature": "sig789", + "target_path": "/path", + } + result, err := BuildPowHeader(challenge, 42) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result == "" { + t.Fatal("expected non-empty result") + } +} + +func TestBuildPowHeaderEmptyChallenge(t *testing.T) { + result, err := BuildPowHeader(map[string]any{}, 0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Should produce a base64 encoded JSON with nil values + if result == "" { + t.Fatal("expected non-empty result for empty challenge") + } +} + +// ─── PowSolver pool size ───────────────────────────────────────────── + +func TestPowPoolSizeFromEnvDefault(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "") + got := powPoolSizeFromEnv() + if got < 1 { + t.Fatalf("expected positive default pool size, got %d", got) + } +} + +func TestPowPoolSizeFromEnvInvalid(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "abc") + got := powPoolSizeFromEnv() + if got < 1 { + t.Fatalf("expected positive default for invalid, got %d", got) + } +} + +func TestPowPoolSizeFromEnvSpecificValue(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "5") + got := powPoolSizeFromEnv() + if got != 5 { + t.Fatalf("expected 5, got %d", got) + } +} + +// ─── NewClient ─────────────────────────────────────────────────────── + +func TestNewClientInitialState(t *testing.T) { + 
client := NewClient(nil, nil) + if client.powSolver == nil { + t.Fatal("expected powSolver to be initialized") + } +} + +func TestNewClientPreloadPowIdempotent(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "1") + client := NewClient(nil, nil) + if err := client.PreloadPow(context.Background()); err != nil { + t.Fatalf("first preload failed: %v", err) + } + if err := client.PreloadPow(context.Background()); err != nil { + t.Fatalf("second preload failed: %v", err) + } +} + +// ─── PowSolver init and module pool ────────────────────────────────── + +func TestPowSolverPoolSizeMatchesEnv(t *testing.T) { + t.Setenv("DS2API_POW_POOL_SIZE", "2") + solver := NewPowSolver("test.wasm") + if err := solver.init(context.Background()); err != nil { + t.Fatalf("init failed: %v", err) + } + if cap(solver.pool) != 2 { + t.Fatalf("expected pool capacity 2, got %d", cap(solver.pool)) + } +} diff --git a/internal/deepseek/prompt.go b/internal/deepseek/prompt.go new file mode 100644 index 0000000..2410390 --- /dev/null +++ b/internal/deepseek/prompt.go @@ -0,0 +1,7 @@ +package deepseek + +import "ds2api/internal/prompt" + +func MessagesPrepare(messages []map[string]any) string { + return prompt.MessagesPrepare(messages) +} diff --git a/internal/devcapture/store.go b/internal/devcapture/store.go new file mode 100644 index 0000000..6d0d8cd --- /dev/null +++ b/internal/devcapture/store.go @@ -0,0 +1,259 @@ +package devcapture + +import ( + "encoding/json" + "fmt" + "io" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/google/uuid" +) + +const ( + defaultLimit = 5 + defaultMaxBodyBytes = 2 * 1024 * 1024 + maxLimit = 50 +) + +type Entry struct { + ID string `json:"id"` + CreatedAt int64 `json:"created_at"` + Label string `json:"label"` + URL string `json:"url"` + AccountID string `json:"account_id,omitempty"` + StatusCode int `json:"status_code"` + RequestBody string `json:"request_body"` + ResponseBody string `json:"response_body"` + ResponseTruncated bool 
`json:"response_truncated"` +} + +type Store struct { + mu sync.Mutex + enabled bool + limit int + maxBodyBytes int + items []Entry +} + +type Session struct { + store *Store + id string + createdAt int64 + label string + url string + accountID string + requestRaw string +} + +type captureBody struct { + rc io.ReadCloser + s *Session + statusCode int + buf strings.Builder + truncated bool + finalized bool +} + +var ( + globalOnce sync.Once + globalInst *Store +) + +func Global() *Store { + globalOnce.Do(func() { + globalInst = NewFromEnv() + }) + return globalInst +} + +func NewFromEnv() *Store { + enabled := !isVercelRuntime() + if raw, ok := os.LookupEnv("DS2API_DEV_PACKET_CAPTURE"); ok { + enabled = parseBool(raw) + } + limit := parseIntWithDefault(os.Getenv("DS2API_DEV_PACKET_CAPTURE_LIMIT"), defaultLimit) + if limit < 1 { + limit = defaultLimit + } + if limit > maxLimit { + limit = maxLimit + } + maxBodyBytes := parseIntWithDefault(os.Getenv("DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES"), defaultMaxBodyBytes) + if maxBodyBytes < 1024 { + maxBodyBytes = defaultMaxBodyBytes + } + return &Store{ + enabled: enabled, + limit: limit, + maxBodyBytes: maxBodyBytes, + items: make([]Entry, 0, limit), + } +} + +func isVercelRuntime() bool { + return strings.TrimSpace(os.Getenv("VERCEL")) != "" || strings.TrimSpace(os.Getenv("NOW_REGION")) != "" +} + +func (s *Store) Enabled() bool { + if s == nil { + return false + } + return s.enabled +} + +func (s *Store) Limit() int { + if s == nil { + return defaultLimit + } + return s.limit +} + +func (s *Store) MaxBodyBytes() int { + if s == nil { + return defaultMaxBodyBytes + } + return s.maxBodyBytes +} + +func (s *Store) Snapshot() []Entry { + if s == nil { + return nil + } + s.mu.Lock() + defer s.mu.Unlock() + out := make([]Entry, len(s.items)) + copy(out, s.items) + return out +} + +func (s *Store) Clear() { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.items = s.items[:0] +} + +func (s *Store) 
Start(label, url, accountID string, requestPayload any) *Session { + if s == nil || !s.enabled { + return nil + } + return &Session{ + store: s, + id: "cap_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + createdAt: time.Now().Unix(), + label: strings.TrimSpace(label), + url: strings.TrimSpace(url), + accountID: strings.TrimSpace(accountID), + requestRaw: marshalPayload(requestPayload), + } +} + +func (s *Session) WrapBody(rc io.ReadCloser, statusCode int) io.ReadCloser { + if s == nil || rc == nil { + return rc + } + return &captureBody{ + rc: rc, + s: s, + statusCode: statusCode, + } +} + +func (c *captureBody) Read(p []byte) (int, error) { + n, err := c.rc.Read(p) + if n > 0 { + c.append(string(p[:n])) + } + if err == io.EOF { + c.finalize() + } + return n, err +} + +func (c *captureBody) Close() error { + err := c.rc.Close() + c.finalize() + return err +} + +func (c *captureBody) append(chunk string) { + if chunk == "" || c.s == nil || c.s.store == nil { + return + } + maxLen := c.s.store.maxBodyBytes + current := c.buf.Len() + if current >= maxLen { + c.truncated = true + return + } + remain := maxLen - current + if len(chunk) > remain { + c.buf.WriteString(chunk[:remain]) + c.truncated = true + return + } + c.buf.WriteString(chunk) +} + +func (c *captureBody) finalize() { + if c.finalized || c.s == nil || c.s.store == nil { + return + } + c.finalized = true + entry := Entry{ + ID: c.s.id, + CreatedAt: c.s.createdAt, + Label: c.s.label, + URL: c.s.url, + AccountID: c.s.accountID, + StatusCode: c.statusCode, + RequestBody: c.s.requestRaw, + ResponseBody: c.buf.String(), + ResponseTruncated: c.truncated, + } + c.s.store.push(entry) +} + +func (s *Store) push(entry Entry) { + s.mu.Lock() + defer s.mu.Unlock() + s.items = append([]Entry{entry}, s.items...) 
+ if len(s.items) > s.limit { + s.items = s.items[:s.limit] + } +} + +func marshalPayload(v any) string { + b, err := json.Marshal(v) + if err != nil { + return fmt.Sprintf("%v", v) + } + return string(b) +} + +func parseBool(v string) bool { + switch strings.ToLower(strings.TrimSpace(v)) { + case "1", "true", "yes", "on": + return true + default: + return false + } +} + +func parseIntWithDefault(raw string, d int) int { + raw = strings.TrimSpace(raw) + if raw == "" { + return d + } + n, err := strconv.Atoi(raw) + if err != nil { + return d + } + return n +} diff --git a/internal/devcapture/store_test.go b/internal/devcapture/store_test.go new file mode 100644 index 0000000..1dd58b4 --- /dev/null +++ b/internal/devcapture/store_test.go @@ -0,0 +1,55 @@ +package devcapture + +import ( + "io" + "strings" + "testing" +) + +func TestStorePushKeepsNewestWithinLimit(t *testing.T) { + s := &Store{enabled: true, limit: 2, maxBodyBytes: 1024} + for i := 0; i < 3; i++ { + session := s.Start("test", "http://x", "", map[string]any{"seq": i}) + if session == nil { + t.Fatal("expected session") + } + rc := session.WrapBody(io.NopCloser(strings.NewReader("ok")), 200) + _, _ = io.ReadAll(rc) + _ = rc.Close() + } + items := s.Snapshot() + if len(items) != 2 { + t.Fatalf("expected 2 items, got %d", len(items)) + } + if !strings.Contains(items[0].RequestBody, `"seq":2`) { + t.Fatalf("expected newest first, got %#v", items[0].RequestBody) + } + if !strings.Contains(items[1].RequestBody, `"seq":1`) { + t.Fatalf("expected second newest, got %#v", items[1].RequestBody) + } +} + +func TestWrapBodyTruncatesByLimit(t *testing.T) { + s := &Store{enabled: true, limit: 5, maxBodyBytes: 4} + session := s.Start("test", "http://x", "acc1", map[string]any{"x": 1}) + if session == nil { + t.Fatal("expected session") + } + rc := session.WrapBody(io.NopCloser(strings.NewReader("abcdef")), 200) + _, _ = io.ReadAll(rc) + _ = rc.Close() + + items := s.Snapshot() + if len(items) != 1 { + 
t.Fatalf("expected 1 item, got %d", len(items)) + } + if items[0].ResponseBody != "abcd" { + t.Fatalf("expected truncated body, got %q", items[0].ResponseBody) + } + if !items[0].ResponseTruncated { + t.Fatal("expected truncated flag true") + } + if items[0].AccountID != "acc1" { + t.Fatalf("expected account id, got %q", items[0].AccountID) + } +} diff --git a/internal/format/claude/render.go b/internal/format/claude/render.go new file mode 100644 index 0000000..fdba055 --- /dev/null +++ b/internal/format/claude/render.go @@ -0,0 +1,46 @@ +package claude + +import ( + "fmt" + "time" + + "ds2api/internal/util" +) + +func BuildMessageResponse(messageID, model string, normalizedMessages []any, finalThinking, finalText string, toolNames []string) map[string]any { + detected := util.ParseToolCalls(finalText, toolNames) + content := make([]map[string]any, 0, 4) + if finalThinking != "" { + content = append(content, map[string]any{"type": "thinking", "thinking": finalThinking}) + } + stopReason := "end_turn" + if len(detected) > 0 { + stopReason = "tool_use" + for i, tc := range detected { + content = append(content, map[string]any{ + "type": "tool_use", + "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i), + "name": tc.Name, + "input": tc.Input, + }) + } + } else { + if finalText == "" { + finalText = "抱歉,没有生成有效的响应内容。" + } + content = append(content, map[string]any{"type": "text", "text": finalText}) + } + return map[string]any{ + "id": messageID, + "type": "message", + "role": "assistant", + "model": model, + "content": content, + "stop_reason": stopReason, + "stop_sequence": nil, + "usage": map[string]any{ + "input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalizedMessages)), + "output_tokens": util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText), + }, + } +} diff --git a/internal/format/openai/render_chat.go b/internal/format/openai/render_chat.go new file mode 100644 index 0000000..1e58fbd --- /dev/null +++ 
b/internal/format/openai/render_chat.go @@ -0,0 +1,60 @@ +package openai + +import ( + "strings" + "time" + + "ds2api/internal/util" +) + +func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { + detected := util.ParseToolCalls(finalText, toolNames) + finishReason := "stop" + messageObj := map[string]any{"role": "assistant", "content": finalText} + if strings.TrimSpace(finalThinking) != "" { + messageObj["reasoning_content"] = finalThinking + } + if len(detected) > 0 { + finishReason = "tool_calls" + messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected) + messageObj["content"] = nil + } + + return map[string]any{ + "id": completionID, + "object": "chat.completion", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}}, + "usage": BuildChatUsage(finalPrompt, finalThinking, finalText), + } +} + +func BuildChatStreamDeltaChoice(index int, delta map[string]any) map[string]any { + return map[string]any{ + "delta": delta, + "index": index, + } +} + +func BuildChatStreamFinishChoice(index int, finishReason string) map[string]any { + return map[string]any{ + "delta": map[string]any{}, + "index": index, + "finish_reason": finishReason, + } +} + +func BuildChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any { + out := map[string]any{ + "id": completionID, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": choices, + } + if len(usage) > 0 { + out["usage"] = usage + } + return out +} diff --git a/internal/format/openai/render_responses.go b/internal/format/openai/render_responses.go new file mode 100644 index 0000000..f55ee9f --- /dev/null +++ b/internal/format/openai/render_responses.go @@ -0,0 +1,114 @@ +package openai + +import ( + "encoding/json" + "strings" + "time" + + 
"github.com/google/uuid" + + "ds2api/internal/util" +) + +func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { + // Align responses tool-call semantics with chat/completions: + // mixed prose + tool_call payloads should still be interpreted as tool calls. + detected := util.ParseToolCalls(finalText, toolNames) + if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" { + detected = util.ParseToolCalls(finalThinking, toolNames) + } + exposedOutputText := finalText + output := make([]any, 0, 2) + if len(detected) > 0 { + exposedOutputText = "" + output = append(output, toResponsesFunctionCallItems(detected)...) + } else { + content := make([]any, 0, 2) + if finalThinking != "" { + content = append([]any{map[string]any{ + "type": "reasoning", + "text": finalThinking, + }}, content...) + } + if strings.TrimSpace(finalText) != "" { + content = append(content, map[string]any{ + "type": "output_text", + "text": finalText, + }) + } + if strings.TrimSpace(finalText) == "" && strings.TrimSpace(finalThinking) != "" { + exposedOutputText = finalThinking + } + output = append(output, map[string]any{ + "type": "message", + "id": "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + "role": "assistant", + "content": content, + }) + } + return BuildResponseObjectFromItems( + responseID, + model, + finalPrompt, + finalThinking, + finalText, + output, + exposedOutputText, + ) +} + +func BuildResponseObjectFromItems(responseID, model, finalPrompt, finalThinking, finalText string, output []any, outputText string) map[string]any { + if output == nil { + output = []any{} + } + return map[string]any{ + "id": responseID, + "type": "response", + "object": "response", + "created_at": time.Now().Unix(), + "status": "completed", + "model": model, + "output": output, + "output_text": outputText, + "usage": BuildResponsesUsage(finalPrompt, finalThinking, finalText), + } +} + +func 
toResponsesFunctionCallItems(toolCalls []util.ParsedToolCall) []any { + if len(toolCalls) == 0 { + return nil + } + out := make([]any, 0, len(toolCalls)) + for _, tc := range toolCalls { + if strings.TrimSpace(tc.Name) == "" { + continue + } + argsBytes, _ := json.Marshal(tc.Input) + args := normalizeJSONString(string(argsBytes)) + out = append(out, map[string]any{ + "id": "fc_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + "type": "function_call", + "call_id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + "name": tc.Name, + "arguments": args, + "status": "completed", + }) + } + return out +} + +func normalizeJSONString(raw string) string { + s := strings.TrimSpace(raw) + if s == "" { + return "{}" + } + var v any + if err := json.Unmarshal([]byte(s), &v); err != nil { + return raw + } + b, err := json.Marshal(v) + if err != nil { + return raw + } + return string(b) +} diff --git a/internal/format/openai/render_stream_events.go b/internal/format/openai/render_stream_events.go new file mode 100644 index 0000000..dc13231 --- /dev/null +++ b/internal/format/openai/render_stream_events.go @@ -0,0 +1,136 @@ +package openai + +import "strings" + +func BuildResponsesCreatedPayload(responseID, model string) map[string]any { + return map[string]any{ + "type": "response.created", + "id": responseID, + "response_id": responseID, + "object": "response", + "model": model, + "status": "in_progress", + } +} + +func BuildResponsesOutputItemAddedPayload(responseID, itemID string, outputIndex int, item map[string]any) map[string]any { + return map[string]any{ + "type": "response.output_item.added", + "id": responseID, + "response_id": responseID, + "output_index": outputIndex, + "item_id": itemID, + "item": item, + } +} + +func BuildResponsesOutputItemDonePayload(responseID, itemID string, outputIndex int, item map[string]any) map[string]any { + return map[string]any{ + "type": "response.output_item.done", + "id": responseID, + "response_id": responseID, + 
"output_index": outputIndex, + "item_id": itemID, + "item": item, + } +} + +func BuildResponsesContentPartAddedPayload(responseID, itemID string, outputIndex, contentIndex int, part map[string]any) map[string]any { + return map[string]any{ + "type": "response.content_part.added", + "id": responseID, + "response_id": responseID, + "item_id": itemID, + "output_index": outputIndex, + "content_index": contentIndex, + "part": part, + } +} + +func BuildResponsesContentPartDonePayload(responseID, itemID string, outputIndex, contentIndex int, part map[string]any) map[string]any { + return map[string]any{ + "type": "response.content_part.done", + "id": responseID, + "response_id": responseID, + "item_id": itemID, + "output_index": outputIndex, + "content_index": contentIndex, + "part": part, + } +} + +func BuildResponsesTextDeltaPayload(responseID, itemID string, outputIndex, contentIndex int, delta string) map[string]any { + return map[string]any{ + "type": "response.output_text.delta", + "id": responseID, + "response_id": responseID, + "item_id": itemID, + "output_index": outputIndex, + "content_index": contentIndex, + "delta": delta, + } +} + +func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any { + return map[string]any{ + "type": "response.reasoning.delta", + "id": responseID, + "response_id": responseID, + "delta": delta, + } +} + +func BuildResponsesFunctionCallArgumentsDeltaPayload(responseID, itemID string, outputIndex int, callID, delta string) map[string]any { + return map[string]any{ + "type": "response.function_call_arguments.delta", + "id": responseID, + "response_id": responseID, + "item_id": itemID, + "output_index": outputIndex, + "call_id": callID, + "delta": delta, + } +} + +func BuildResponsesFunctionCallArgumentsDonePayload(responseID, itemID string, outputIndex int, callID, name, arguments string) map[string]any { + return map[string]any{ + "type": "response.function_call_arguments.done", + "id": responseID, + 
"response_id": responseID, + "item_id": itemID, + "output_index": outputIndex, + "call_id": callID, + "name": name, + "arguments": normalizeJSONString(arguments), + } +} + +func BuildResponsesFailedPayload(responseID, model, message, code string) map[string]any { + code = strings.TrimSpace(code) + if code == "" { + code = "api_error" + } + return map[string]any{ + "type": "response.failed", + "id": responseID, + "response_id": responseID, + "object": "response", + "model": model, + "status": "failed", + "error": map[string]any{ + "message": message, + "type": "invalid_request_error", + "code": code, + "param": nil, + }, + } +} + +func BuildResponsesCompletedPayload(response map[string]any) map[string]any { + responseID, _ := response["id"].(string) + return map[string]any{ + "type": "response.completed", + "response_id": responseID, + "response": response, + } +} diff --git a/internal/format/openai/render_test.go b/internal/format/openai/render_test.go new file mode 100644 index 0000000..df792ed --- /dev/null +++ b/internal/format/openai/render_test.go @@ -0,0 +1,148 @@ +package openai + +import ( + "encoding/json" + "testing" +) + +func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) { + obj := BuildResponseObject( + "resp_test", + "gpt-4o", + "prompt", + "", + `{"tool_calls":[{"name":"search","input":{"q":"golang"}}]}`, + []string{"search"}, + ) + + outputText, _ := obj["output_text"].(string) + if outputText != "" { + t.Fatalf("expected output_text to be hidden for tool calls, got %q", outputText) + } + + output, _ := obj["output"].([]any) + if len(output) != 1 { + t.Fatalf("expected function_call output only, got %#v", obj["output"]) + } + + first, _ := output[0].(map[string]any) + if first["type"] != "function_call" { + t.Fatalf("expected first output item type function_call, got %#v", first["type"]) + } + if first["call_id"] == "" { + t.Fatalf("expected function_call item to have call_id, got %#v", first) + } + if first["name"] != "search" { + 
t.Fatalf("unexpected function name: %#v", first["name"]) + } + argsRaw, _ := first["arguments"].(string) + var args map[string]any + if err := json.Unmarshal([]byte(argsRaw), &args); err != nil { + t.Fatalf("arguments should be valid json string, got=%q err=%v", argsRaw, err) + } + if args["q"] != "golang" { + t.Fatalf("unexpected arguments: %#v", args) + } +} + +func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T) { + obj := BuildResponseObject( + "resp_test", + "gpt-4o", + "prompt", + "", + `示例格式:{"tool_calls":[{"name":"search","input":{"q":"golang"}}]},但这条是普通回答。`, + []string{"search"}, + ) + + outputText, _ := obj["output_text"].(string) + if outputText != "" { + t.Fatalf("expected output_text hidden once tool calls are detected, got %q", outputText) + } + + output, _ := obj["output"].([]any) + if len(output) != 1 { + t.Fatalf("expected function_call output only, got %#v", obj["output"]) + } + first, _ := output[0].(map[string]any) + if first["type"] != "function_call" { + t.Fatalf("expected first output type function_call, got %#v", first["type"]) + } +} + +func TestBuildResponseObjectFencedToolPayloadRemainsText(t *testing.T) { + obj := BuildResponseObject( + "resp_test", + "gpt-4o", + "prompt", + "", + "```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"golang\"}}]}\n```", + []string{"search"}, + ) + + outputText, _ := obj["output_text"].(string) + if outputText == "" { + t.Fatalf("expected output_text preserved for fenced example") + } + output, _ := obj["output"].([]any) + if len(output) != 1 { + t.Fatalf("expected one message output item, got %#v", obj["output"]) + } + first, _ := output[0].(map[string]any) + if first["type"] != "message" { + t.Fatalf("expected message output type, got %#v", first["type"]) + } +} + +func TestBuildResponseObjectReasoningOnlyFallsBackToOutputText(t *testing.T) { + obj := BuildResponseObject( + "resp_test", + "gpt-4o", + "prompt", + "internal thinking content", + "", + nil, + ) + + 
outputText, _ := obj["output_text"].(string) + if outputText == "" { + t.Fatalf("expected output_text fallback from reasoning when final text is empty") + } + + output, _ := obj["output"].([]any) + if len(output) != 1 { + t.Fatalf("expected one output item, got %#v", obj["output"]) + } + first, _ := output[0].(map[string]any) + if first["type"] != "message" { + t.Fatalf("expected output type message, got %#v", first["type"]) + } + content, _ := first["content"].([]any) + if len(content) == 0 { + t.Fatalf("expected reasoning content, got %#v", first["content"]) + } + block0, _ := content[0].(map[string]any) + if block0["type"] != "reasoning" { + t.Fatalf("expected first content block reasoning, got %#v", block0["type"]) + } +} + +func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) { + obj := BuildResponseObject( + "resp_test", + "gpt-4o", + "prompt", + `{"tool_calls":[{"name":"search","input":{"q":"from-thinking"}}]}`, + "", + []string{"search"}, + ) + + output, _ := obj["output"].([]any) + if len(output) != 1 { + t.Fatalf("expected function_call output only, got %#v", obj["output"]) + } + first, _ := output[0].(map[string]any) + if first["type"] != "function_call" { + t.Fatalf("expected output function_call, got %#v", first["type"]) + } +} diff --git a/internal/format/openai/render_usage.go b/internal/format/openai/render_usage.go new file mode 100644 index 0000000..b328d20 --- /dev/null +++ b/internal/format/openai/render_usage.go @@ -0,0 +1,28 @@ +package openai + +import "ds2api/internal/util" + +func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any { + promptTokens := util.EstimateTokens(finalPrompt) + reasoningTokens := util.EstimateTokens(finalThinking) + completionTokens := util.EstimateTokens(finalText) + return map[string]any{ + "prompt_tokens": promptTokens, + "completion_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + 
"completion_tokens_details": map[string]any{ + "reasoning_tokens": reasoningTokens, + }, + } +} + +func BuildResponsesUsage(finalPrompt, finalThinking, finalText string) map[string]any { + promptTokens := util.EstimateTokens(finalPrompt) + reasoningTokens := util.EstimateTokens(finalThinking) + completionTokens := util.EstimateTokens(finalText) + return map[string]any{ + "input_tokens": promptTokens, + "output_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + } +} diff --git a/internal/js/chat-stream/error_shape.js b/internal/js/chat-stream/error_shape.js new file mode 100644 index 0000000..18aeedb --- /dev/null +++ b/internal/js/chat-stream/error_shape.js @@ -0,0 +1,36 @@ +'use strict'; + +function writeOpenAIError(res, status, message) { + res.statusCode = status; + res.setHeader('Content-Type', 'application/json'); + res.end( + JSON.stringify({ + error: { + message, + type: openAIErrorType(status), + }, + }), + ); +} + +function openAIErrorType(status) { + switch (status) { + case 400: + return 'invalid_request_error'; + case 401: + return 'authentication_error'; + case 403: + return 'permission_error'; + case 429: + return 'rate_limit_error'; + case 503: + return 'service_unavailable_error'; + default: + return status >= 500 ? 
'api_error' : 'invalid_request_error'; + } +} + +module.exports = { + writeOpenAIError, + openAIErrorType, +}; diff --git a/internal/js/chat-stream/http_internal.js b/internal/js/chat-stream/http_internal.js new file mode 100644 index 0000000..20f24c8 --- /dev/null +++ b/internal/js/chat-stream/http_internal.js @@ -0,0 +1,214 @@ +'use strict'; + +const { + writeOpenAIError, +} = require('./error_shape'); + +function setCorsHeaders(res) { + res.setHeader('Access-Control-Allow-Origin', '*'); + res.setHeader('Access-Control-Allow-Methods', 'GET, POST, OPTIONS, PUT, DELETE'); + res.setHeader( + 'Access-Control-Allow-Headers', + 'Content-Type, Authorization, X-API-Key, X-Ds2-Target-Account, X-Vercel-Protection-Bypass', + ); +} + +function header(req, key) { + if (!req || !req.headers) { + return ''; + } + return asString(req.headers[key.toLowerCase()]); +} + +async function readRawBody(req) { + if (Buffer.isBuffer(req.body)) { + return req.body; + } + if (typeof req.body === 'string') { + return Buffer.from(req.body); + } + if (req.body && typeof req.body === 'object') { + return Buffer.from(JSON.stringify(req.body)); + } + const chunks = []; + for await (const chunk of req) { + chunks.push(Buffer.isBuffer(chunk) ? 
chunk : Buffer.from(chunk)); + } + return Buffer.concat(chunks); +} + +async function fetchStreamPrepare(req, rawBody) { + const url = buildInternalGoURL(req); + url.searchParams.set('__stream_prepare', '1'); + + const upstream = await fetch(url.toString(), { + method: 'POST', + headers: buildInternalGoHeaders(req, { withInternalToken: true, withContentType: true }), + body: rawBody, + }); + + const text = await upstream.text(); + let body = {}; + try { + body = JSON.parse(text || '{}'); + } catch (_err) { + body = {}; + } + + return { + ok: upstream.ok, + status: upstream.status, + contentType: upstream.headers.get('content-type') || 'application/json', + text, + body, + }; +} + +function relayPreparedFailure(res, prep) { + if (prep.status === 401 && looksLikeVercelAuthPage(prep.text)) { + writeOpenAIError( + res, + 401, + 'Vercel Deployment Protection blocked internal prepare request. Disable protection for this deployment or set VERCEL_AUTOMATION_BYPASS_SECRET.', + ); + return; + } + res.statusCode = prep.status || 500; + res.setHeader('Content-Type', prep.contentType || 'application/json'); + if (prep.text) { + res.end(prep.text); + return; + } + writeOpenAIError(res, prep.status || 500, 'vercel prepare failed'); +} + +async function safeReadText(resp) { + if (!resp) { + return ''; + } + try { + const text = await resp.text(); + return text.trim(); + } catch (_err) { + return ''; + } +} + +function internalSecret() { + return asString(process.env.DS2API_VERCEL_INTERNAL_SECRET) || asString(process.env.DS2API_ADMIN_KEY) || 'admin'; +} + +function buildInternalGoURL(req) { + const proto = asString(header(req, 'x-forwarded-proto')) || 'https'; + const host = asString(header(req, 'host')); + const url = new URL(`${proto}://${host}${req.url || '/v1/chat/completions'}`); + url.searchParams.set('__go', '1'); + const protectionBypass = resolveProtectionBypass(req); + if (protectionBypass) { + url.searchParams.set('x-vercel-protection-bypass', protectionBypass); + } + 
return url; +} + +function buildInternalGoHeaders(req, opts = {}) { + const headers = { + authorization: asString(header(req, 'authorization')), + 'x-api-key': asString(header(req, 'x-api-key')), + 'x-ds2-target-account': asString(header(req, 'x-ds2-target-account')), + 'x-vercel-protection-bypass': resolveProtectionBypass(req), + }; + if (opts.withInternalToken) { + headers['x-ds2-internal-token'] = internalSecret(); + } + if (opts.withContentType) { + headers['content-type'] = asString(header(req, 'content-type')) || 'application/json'; + } + return headers; +} + +function createLeaseReleaser(req, leaseID) { + let released = false; + return async () => { + if (released || !leaseID) { + return; + } + released = true; + try { + await releaseStreamLease(req, leaseID); + } catch (_err) { + // Ignore release errors. Lease TTL cleanup on Go side still prevents permanent leaks. + } + }; +} + +async function releaseStreamLease(req, leaseID) { + const url = buildInternalGoURL(req); + url.searchParams.set('__stream_release', '1'); + const body = Buffer.from(JSON.stringify({ lease_id: leaseID })); + + const controller = new AbortController(); + const timeout = setTimeout(() => controller.abort(), 1500); + try { + await fetch(url.toString(), { + method: 'POST', + headers: buildInternalGoHeaders(req, { withInternalToken: true, withContentType: true }), + body, + signal: controller.signal, + }); + } finally { + clearTimeout(timeout); + } +} + +function resolveProtectionBypass(req) { + const fromHeader = asString(header(req, 'x-vercel-protection-bypass')); + if (fromHeader) { + return fromHeader; + } + return asString(process.env.VERCEL_AUTOMATION_BYPASS_SECRET) || asString(process.env.DS2API_VERCEL_PROTECTION_BYPASS); +} + +function looksLikeVercelAuthPage(text) { + const body = asString(text).toLowerCase(); + if (!body) { + return false; + } + return body.includes('authentication required') && body.includes('vercel'); +} + +function asString(v) { + if (typeof v === 'string') 
{ + return v.trim(); + } + if (Array.isArray(v)) { + return asString(v[0]); + } + if (v == null) { + return ''; + } + return String(v).trim(); +} + +function isAbortError(err) { + if (!err || typeof err !== 'object') { + return false; + } + return err.name === 'AbortError' || err.code === 'ABORT_ERR'; +} + +module.exports = { + setCorsHeaders, + header, + readRawBody, + fetchStreamPrepare, + relayPreparedFailure, + safeReadText, + buildInternalGoURL, + buildInternalGoHeaders, + createLeaseReleaser, + releaseStreamLease, + resolveProtectionBypass, + looksLikeVercelAuthPage, + asString, + isAbortError, +}; diff --git a/internal/js/chat-stream/index.js b/internal/js/chat-stream/index.js new file mode 100644 index 0000000..4528924 --- /dev/null +++ b/internal/js/chat-stream/index.js @@ -0,0 +1,88 @@ +'use strict'; + +const { + writeOpenAIError, +} = require('./error_shape'); +const { + parseChunkForContent, + extractContentRecursive, + shouldSkipPath, +} = require('./sse_parse'); +const { + resolveToolcallPolicy, + normalizePreparedToolNames, + boolDefaultTrue, +} = require('./toolcall_policy'); +const { + estimateTokens, +} = require('./token_usage'); +const { + setCorsHeaders, + readRawBody, + asString, +} = require('./http_internal'); +const { + proxyToGo, +} = require('./proxy_go'); +const { + handleVercelStream, +} = require('./vercel_stream'); + +async function handler(req, res) { + setCorsHeaders(res); + if (req.method === 'OPTIONS') { + res.statusCode = 204; + res.end(); + return; + } + if (req.method !== 'POST') { + writeOpenAIError(res, 405, 'method not allowed'); + return; + } + + const rawBody = await readRawBody(req); + + // Hard guard: only use Node data path for streaming on Vercel runtime. + // Any non-Vercel runtime always falls back to Go for full behavior parity. 
+ if (!isVercelRuntime()) { + await proxyToGo(req, res, rawBody); + return; + } + + let payload; + try { + payload = JSON.parse(rawBody.toString('utf8') || '{}'); + } catch (_err) { + writeOpenAIError(res, 400, 'invalid json'); + return; + } + + // Keep all non-stream behavior on Go side to avoid compatibility regressions. + if (!toBool(payload.stream)) { + await proxyToGo(req, res, rawBody); + return; + } + + await handleVercelStream(req, res, rawBody, payload); +} + +function toBool(v) { + return v === true; +} + +function isVercelRuntime() { + return asString(process.env.VERCEL) !== '' || asString(process.env.NOW_REGION) !== ''; +} + +module.exports = handler; + +module.exports.__test = { + parseChunkForContent, + extractContentRecursive, + shouldSkipPath, + asString, + resolveToolcallPolicy, + normalizePreparedToolNames, + boolDefaultTrue, + estimateTokens, +}; diff --git a/internal/js/chat-stream/proxy_go.js b/internal/js/chat-stream/proxy_go.js new file mode 100644 index 0000000..5218df0 --- /dev/null +++ b/internal/js/chat-stream/proxy_go.js @@ -0,0 +1,105 @@ +'use strict'; + +const { + buildInternalGoURL, + buildInternalGoHeaders, + isAbortError, +} = require('./http_internal'); + +async function proxyToGo(req, res, rawBody) { + const url = buildInternalGoURL(req); + const controller = new AbortController(); + let clientClosed = false; + const markClientClosed = () => { + if (clientClosed) { + return; + } + clientClosed = true; + controller.abort(); + }; + const onReqAborted = () => markClientClosed(); + const onResClose = () => { + if (!res.writableEnded) { + markClientClosed(); + } + }; + req.on('aborted', onReqAborted); + res.on('close', onResClose); + + try { + let upstream; + try { + upstream = await fetch(url.toString(), { + method: 'POST', + headers: buildInternalGoHeaders(req, { withContentType: true }), + body: rawBody, + signal: controller.signal, + }); + } catch (err) { + if (clientClosed || isAbortError(err)) { + if (!res.writableEnded) { + 
res.end(); + } + return; + } + throw err; + } + if (clientClosed) { + if (!res.writableEnded) { + res.end(); + } + return; + } + + res.statusCode = upstream.status; + upstream.headers.forEach((value, key) => { + if (key.toLowerCase() === 'content-length') { + return; + } + res.setHeader(key, value); + }); + + if (!upstream.body || typeof upstream.body.getReader !== 'function') { + const bytes = Buffer.from(await upstream.arrayBuffer()); + res.end(bytes); + return; + } + + const reader = upstream.body.getReader(); + try { + // eslint-disable-next-line no-constant-condition + while (true) { + if (clientClosed) { + break; + } + const { value, done } = await reader.read(); + if (done) { + break; + } + if (value && value.length > 0) { + res.write(Buffer.from(value)); + if (typeof res.flush === 'function') { + res.flush(); + } + } + } + if (!res.writableEnded) { + res.end(); + } + } catch (err) { + if (!isAbortError(err) && !res.writableEnded) { + res.end(); + } + } + } finally { + req.removeListener('aborted', onReqAborted); + res.removeListener('close', onResClose); + if (!res.writableEnded) { + res.end(); + } + } +} + +module.exports = { + proxyToGo, +}; diff --git a/internal/js/chat-stream/sse_parse.js b/internal/js/chat-stream/sse_parse.js new file mode 100644 index 0000000..1774430 --- /dev/null +++ b/internal/js/chat-stream/sse_parse.js @@ -0,0 +1,229 @@ +'use strict'; + +const { + SKIP_PATTERNS, + SKIP_EXACT_PATHS, +} = require('../shared/deepseek-constants'); + +function parseChunkForContent(chunk, thinkingEnabled, currentType) { + if (!chunk || typeof chunk !== 'object' || !Object.prototype.hasOwnProperty.call(chunk, 'v')) { + return { parts: [], finished: false, newType: currentType }; + } + const pathValue = asString(chunk.p); + if (shouldSkipPath(pathValue)) { + return { parts: [], finished: false, newType: currentType }; + } + if (pathValue === 'response/status' && asString(chunk.v) === 'FINISHED') { + return { parts: [], finished: true, newType: 
currentType }; + } + + let newType = currentType; + const parts = []; + + if (pathValue === 'response/fragments' && asString(chunk.o).toUpperCase() === 'APPEND' && Array.isArray(chunk.v)) { + for (const frag of chunk.v) { + if (!frag || typeof frag !== 'object') { + continue; + } + const fragType = asString(frag.type).toUpperCase(); + const content = asString(frag.content); + if (!content) { + continue; + } + if (fragType === 'THINK' || fragType === 'THINKING') { + newType = 'thinking'; + parts.push({ text: content, type: 'thinking' }); + } else if (fragType === 'RESPONSE') { + newType = 'text'; + parts.push({ text: content, type: 'text' }); + } else { + parts.push({ text: content, type: 'text' }); + } + } + } + + if (pathValue === 'response' && Array.isArray(chunk.v)) { + for (const item of chunk.v) { + if (!item || typeof item !== 'object') { + continue; + } + if (item.p === 'fragments' && item.o === 'APPEND' && Array.isArray(item.v)) { + for (const frag of item.v) { + const fragType = asString(frag && frag.type).toUpperCase(); + if (fragType === 'THINK' || fragType === 'THINKING') { + newType = 'thinking'; + } else if (fragType === 'RESPONSE') { + newType = 'text'; + } + } + } + } + } + + let partType = 'text'; + if (pathValue === 'response/thinking_content') { + partType = 'thinking'; + } else if (pathValue === 'response/content') { + partType = 'text'; + } else if (pathValue.includes('response/fragments') && pathValue.includes('/content')) { + partType = newType; + } else if (!pathValue && thinkingEnabled) { + partType = newType; + } + + const val = chunk.v; + if (typeof val === 'string') { + if (val === 'FINISHED' && (!pathValue || pathValue === 'status')) { + return { parts: [], finished: true, newType }; + } + if (val) { + parts.push({ text: val, type: partType }); + } + return { parts, finished: false, newType }; + } + + if (Array.isArray(val)) { + const extracted = extractContentRecursive(val, partType); + if (extracted.finished) { + return { parts: [], 
finished: true, newType }; + } + parts.push(...extracted.parts); + return { parts, finished: false, newType }; + } + + if (val && typeof val === 'object') { + const resp = val.response && typeof val.response === 'object' ? val.response : val; + if (Array.isArray(resp.fragments)) { + for (const frag of resp.fragments) { + if (!frag || typeof frag !== 'object') { + continue; + } + const content = asString(frag.content); + if (!content) { + continue; + } + const t = asString(frag.type).toUpperCase(); + if (t === 'THINK' || t === 'THINKING') { + newType = 'thinking'; + parts.push({ text: content, type: 'thinking' }); + } else if (t === 'RESPONSE') { + newType = 'text'; + parts.push({ text: content, type: 'text' }); + } else { + parts.push({ text: content, type: partType }); + } + } + } + } + return { parts, finished: false, newType }; +} + +function extractContentRecursive(items, defaultType) { + const parts = []; + for (const it of items) { + if (!it || typeof it !== 'object') { + continue; + } + if (!Object.prototype.hasOwnProperty.call(it, 'v')) { + continue; + } + const itemPath = asString(it.p); + const itemV = it.v; + if (itemPath === 'status' && asString(itemV) === 'FINISHED') { + return { parts: [], finished: true }; + } + if (shouldSkipPath(itemPath)) { + continue; + } + const content = asString(it.content); + if (content) { + const typeName = asString(it.type).toUpperCase(); + if (typeName === 'THINK' || typeName === 'THINKING') { + parts.push({ text: content, type: 'thinking' }); + } else if (typeName === 'RESPONSE') { + parts.push({ text: content, type: 'text' }); + } else { + parts.push({ text: content, type: defaultType }); + } + continue; + } + + let partType = defaultType; + if (itemPath.includes('thinking')) { + partType = 'thinking'; + } else if (itemPath.includes('content') || itemPath === 'response' || itemPath === 'fragments') { + partType = 'text'; + } + + if (typeof itemV === 'string') { + if (itemV && itemV !== 'FINISHED') { + parts.push({ text: 
itemV, type: partType }); + } + continue; + } + + if (!Array.isArray(itemV)) { + continue; + } + for (const inner of itemV) { + if (typeof inner === 'string') { + if (inner) { + parts.push({ text: inner, type: partType }); + } + continue; + } + if (!inner || typeof inner !== 'object') { + continue; + } + const ct = asString(inner.content); + if (!ct) { + continue; + } + const typeName = asString(inner.type).toUpperCase(); + if (typeName === 'THINK' || typeName === 'THINKING') { + parts.push({ text: ct, type: 'thinking' }); + } else if (typeName === 'RESPONSE') { + parts.push({ text: ct, type: 'text' }); + } else { + parts.push({ text: ct, type: partType }); + } + } + } + return { parts, finished: false }; +} + +function shouldSkipPath(pathValue) { + if (SKIP_EXACT_PATHS.has(pathValue)) { + return true; + } + for (const p of SKIP_PATTERNS) { + if (pathValue.includes(p)) { + return true; + } + } + return false; +} + +function isCitation(text) { + return asString(text).trim().startsWith('[citation:'); +} + +function asString(v) { + if (typeof v === 'string') { + return v.trim(); + } + if (Array.isArray(v)) { + return asString(v[0]); + } + if (v == null) { + return ''; + } + return String(v).trim(); +} + +module.exports = { + parseChunkForContent, + extractContentRecursive, + shouldSkipPath, + isCitation, +}; diff --git a/internal/js/chat-stream/stream_emitter.js b/internal/js/chat-stream/stream_emitter.js new file mode 100644 index 0000000..442c24e --- /dev/null +++ b/internal/js/chat-stream/stream_emitter.js @@ -0,0 +1,39 @@ +'use strict'; + +function createChatCompletionEmitter({ res, sessionID, created, model, isClosed }) { + let firstChunkSent = false; + + const sendFrame = (obj) => { + if (isClosed() || res.writableEnded || res.destroyed) { + return; + } + res.write(`data: ${JSON.stringify(obj)}\n\n`); + if (typeof res.flush === 'function') { + res.flush(); + } + }; + + const sendDeltaFrame = (delta) => { + const payloadDelta = { ...delta }; + if 
(!firstChunkSent) { + payloadDelta.role = 'assistant'; + firstChunkSent = true; + } + sendFrame({ + id: sessionID, + object: 'chat.completion.chunk', + created, + model, + choices: [{ delta: payloadDelta, index: 0 }], + }); + }; + + return { + sendFrame, + sendDeltaFrame, + }; +} + +module.exports = { + createChatCompletionEmitter, +}; diff --git a/internal/js/chat-stream/token_usage.js b/internal/js/chat-stream/token_usage.js new file mode 100644 index 0000000..57a36fb --- /dev/null +++ b/internal/js/chat-stream/token_usage.js @@ -0,0 +1,51 @@ +'use strict'; + +function buildUsage(prompt, thinking, output) { + const promptTokens = estimateTokens(prompt); + const reasoningTokens = estimateTokens(thinking); + const completionTokens = estimateTokens(output); + return { + prompt_tokens: promptTokens, + completion_tokens: reasoningTokens + completionTokens, + total_tokens: promptTokens + reasoningTokens + completionTokens, + completion_tokens_details: { + reasoning_tokens: reasoningTokens, + }, + }; +} + +function estimateTokens(text) { + const t = asString(text); + if (!t) { + return 0; + } + let asciiChars = 0; + let nonASCIIChars = 0; + for (const ch of Array.from(t)) { + if (ch.charCodeAt(0) < 128) { + asciiChars += 1; + } else { + nonASCIIChars += 1; + } + } + const n = Math.floor(asciiChars / 4) + Math.floor((nonASCIIChars * 10 + 7) / 13); + return n < 1 ? 
1 : n; +} + +function asString(v) { + if (typeof v === 'string') { + return v.trim(); + } + if (Array.isArray(v)) { + return asString(v[0]); + } + if (v == null) { + return ''; + } + return String(v).trim(); +} + +module.exports = { + buildUsage, + estimateTokens, +}; diff --git a/internal/js/chat-stream/toolcall_policy.js b/internal/js/chat-stream/toolcall_policy.js new file mode 100644 index 0000000..4f4b37c --- /dev/null +++ b/internal/js/chat-stream/toolcall_policy.js @@ -0,0 +1,107 @@ +'use strict'; + +const crypto = require('crypto'); + +const { + extractToolNames, +} = require('../helpers/stream-tool-sieve'); + +function resolveToolcallPolicy(prepBody, payloadTools) { + const preparedToolNames = normalizePreparedToolNames(prepBody && prepBody.tool_names); + const toolNames = preparedToolNames.length > 0 ? preparedToolNames : extractToolNames(payloadTools); + const featureMatchEnabled = boolDefaultTrue(prepBody && prepBody.toolcall_feature_match); + const emitEarlyToolDeltas = boolDefaultTrue(prepBody && prepBody.toolcall_early_emit_high); + return { + toolNames, + toolSieveEnabled: toolNames.length > 0 && featureMatchEnabled, + emitEarlyToolDeltas, + }; +} + +function normalizePreparedToolNames(v) { + if (!Array.isArray(v) || v.length === 0) { + return []; + } + const out = []; + for (const item of v) { + const name = asString(item); + if (!name) { + continue; + } + out.push(name); + } + return out; +} + +function boolDefaultTrue(v) { + return v !== false; +} + +function formatIncrementalToolCallDeltas(deltas, idStore) { + if (!Array.isArray(deltas) || deltas.length === 0) { + return []; + } + const out = []; + for (const d of deltas) { + if (!d || typeof d !== 'object') { + continue; + } + const index = Number.isInteger(d.index) ? 
d.index : 0; + const id = ensureStreamToolCallID(idStore, index); + const item = { + index, + id, + type: 'function', + }; + const fn = {}; + if (asString(d.name)) { + fn.name = asString(d.name); + } + if (typeof d.arguments === 'string' && d.arguments !== '') { + fn.arguments = d.arguments; + } + if (Object.keys(fn).length > 0) { + item.function = fn; + } + out.push(item); + } + return out; +} + +function ensureStreamToolCallID(idStore, index) { + const key = Number.isInteger(index) ? index : 0; + const existing = idStore.get(key); + if (existing) { + return existing; + } + const next = `call_${newCallID()}`; + idStore.set(key, next); + return next; +} + +function newCallID() { + if (typeof crypto.randomUUID === 'function') { + return crypto.randomUUID().replace(/-/g, ''); + } + return `${Date.now()}${Math.floor(Math.random() * 1e9)}`; +} + +function asString(v) { + if (typeof v === 'string') { + return v.trim(); + } + if (Array.isArray(v)) { + return asString(v[0]); + } + if (v == null) { + return ''; + } + return String(v).trim(); +} + +module.exports = { + resolveToolcallPolicy, + normalizePreparedToolNames, + boolDefaultTrue, + formatIncrementalToolCallDeltas, +}; diff --git a/internal/js/chat-stream/vercel_stream.js b/internal/js/chat-stream/vercel_stream.js new file mode 100644 index 0000000..324a3d8 --- /dev/null +++ b/internal/js/chat-stream/vercel_stream.js @@ -0,0 +1,297 @@ +'use strict'; + +const { + extractToolNames, + createToolSieveState, + processToolSieveChunk, + flushToolSieve, + parseToolCalls, + formatOpenAIStreamToolCalls, +} = require('../helpers/stream-tool-sieve'); +const { + BASE_HEADERS, +} = require('../shared/deepseek-constants'); + +const { + writeOpenAIError, +} = require('./error_shape'); +const { + parseChunkForContent, + isCitation, +} = require('./sse_parse'); +const { + buildUsage, +} = require('./token_usage'); +const { + resolveToolcallPolicy, + formatIncrementalToolCallDeltas, +} = require('./toolcall_policy'); +const { + 
createChatCompletionEmitter, +} = require('./stream_emitter'); +const { + asString, + isAbortError, + fetchStreamPrepare, + relayPreparedFailure, + safeReadText, + createLeaseReleaser, +} = require('./http_internal'); + +const DEEPSEEK_COMPLETION_URL = 'https://chat.deepseek.com/api/v0/chat/completion'; + +async function handleVercelStream(req, res, rawBody, payload) { + const prep = await fetchStreamPrepare(req, rawBody); + if (!prep.ok) { + relayPreparedFailure(res, prep); + return; + } + + const model = asString(prep.body.model) || asString(payload.model); + const sessionID = asString(prep.body.session_id) || `chatcmpl-${Date.now()}`; + const leaseID = asString(prep.body.lease_id); + const deepseekToken = asString(prep.body.deepseek_token); + const powHeader = asString(prep.body.pow_header); + const completionPayload = prep.body.payload && typeof prep.body.payload === 'object' ? prep.body.payload : null; + const finalPrompt = asString(prep.body.final_prompt); + const thinkingEnabled = toBool(prep.body.thinking_enabled); + const searchEnabled = toBool(prep.body.search_enabled); + const toolPolicy = resolveToolcallPolicy(prep.body, payload.tools); + const toolNames = toolPolicy.toolNames; + + if (!model || !leaseID || !deepseekToken || !powHeader || !completionPayload) { + writeOpenAIError(res, 500, 'invalid vercel prepare response'); + return; + } + + const releaseLease = createLeaseReleaser(req, leaseID); + const upstreamController = new AbortController(); + let clientClosed = false; + let reader = null; + const markClientClosed = () => { + if (clientClosed) { + return; + } + clientClosed = true; + upstreamController.abort(); + if (reader && typeof reader.cancel === 'function') { + Promise.resolve(reader.cancel()).catch(() => {}); + } + }; + const onReqAborted = () => markClientClosed(); + const onResClose = () => { + if (!res.writableEnded) { + markClientClosed(); + } + }; + req.on('aborted', onReqAborted); + res.on('close', onResClose); + + try { + let 
completionRes; + try { + completionRes = await fetch(DEEPSEEK_COMPLETION_URL, { + method: 'POST', + headers: { + ...BASE_HEADERS, + authorization: `Bearer ${deepseekToken}`, + 'x-ds-pow-response': powHeader, + }, + body: JSON.stringify(completionPayload), + signal: upstreamController.signal, + }); + } catch (err) { + if (clientClosed || isAbortError(err)) { + return; + } + throw err; + } + if (clientClosed) { + return; + } + + if (!completionRes.ok || !completionRes.body) { + const detail = await safeReadText(completionRes); + writeOpenAIError(res, 500, detail ? `Failed to get completion: ${detail}` : 'Failed to get completion.'); + return; + } + + res.statusCode = 200; + res.setHeader('Content-Type', 'text/event-stream'); + res.setHeader('Cache-Control', 'no-cache, no-transform'); + res.setHeader('Connection', 'keep-alive'); + res.setHeader('X-Accel-Buffering', 'no'); + if (typeof res.flushHeaders === 'function') { + res.flushHeaders(); + } + + const created = Math.floor(Date.now() / 1000); + let currentType = thinkingEnabled ? 
'thinking' : 'text'; + let thinkingText = ''; + let outputText = ''; + const toolSieveEnabled = toolPolicy.toolSieveEnabled; + const emitEarlyToolDeltas = toolPolicy.emitEarlyToolDeltas; + const toolSieveState = createToolSieveState(); + let toolCallsEmitted = false; + const streamToolCallIDs = new Map(); + const decoder = new TextDecoder(); + reader = completionRes.body.getReader(); + let buffered = ''; + let ended = false; + const { sendFrame, sendDeltaFrame } = createChatCompletionEmitter({ + res, + sessionID, + created, + model, + isClosed: () => clientClosed, + }); + + const finish = async (reason) => { + if (ended) { + return; + } + ended = true; + if (clientClosed || res.writableEnded || res.destroyed) { + await releaseLease(); + return; + } + const detected = parseToolCalls(outputText, toolNames); + if (detected.length > 0 && !toolCallsEmitted) { + toolCallsEmitted = true; + sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(detected) }); + } else if (toolSieveEnabled) { + const tailEvents = flushToolSieve(toolSieveState, toolNames); + for (const evt of tailEvents) { + if (evt.text) { + sendDeltaFrame({ content: evt.text }); + } + } + } + if (detected.length > 0 || toolCallsEmitted) { + reason = 'tool_calls'; + } + sendFrame({ + id: sessionID, + object: 'chat.completion.chunk', + created, + model, + choices: [{ delta: {}, index: 0, finish_reason: reason }], + usage: buildUsage(finalPrompt, thinkingText, outputText), + }); + if (!res.writableEnded && !res.destroyed) { + res.write('data: [DONE]\n\n'); + } + await releaseLease(); + if (!res.writableEnded && !res.destroyed) { + res.end(); + } + }; + + try { + // eslint-disable-next-line no-constant-condition + while (true) { + if (clientClosed) { + await finish('stop'); + return; + } + const { value, done } = await reader.read(); + if (done) { + break; + } + buffered += decoder.decode(value, { stream: true }); + const lines = buffered.split('\n'); + buffered = lines.pop() || ''; + + for (const rawLine of 
lines) { + const line = rawLine.trim(); + if (!line.startsWith('data:')) { + continue; + } + const dataStr = line.slice(5).trim(); + if (!dataStr) { + continue; + } + if (dataStr === '[DONE]') { + await finish('stop'); + return; + } + let chunk; + try { + chunk = JSON.parse(dataStr); + } catch (_err) { + continue; + } + if (chunk.error || chunk.code === 'content_filter') { + await finish('content_filter'); + return; + } + const parsed = parseChunkForContent(chunk, thinkingEnabled, currentType); + currentType = parsed.newType; + if (parsed.finished) { + await finish('stop'); + return; + } + + for (const p of parsed.parts) { + if (!p.text) { + continue; + } + if (searchEnabled && isCitation(p.text)) { + continue; + } + if (p.type === 'thinking') { + if (thinkingEnabled) { + thinkingText += p.text; + sendDeltaFrame({ reasoning_content: p.text }); + } + } else { + outputText += p.text; + if (!toolSieveEnabled) { + sendDeltaFrame({ content: p.text }); + continue; + } + const events = processToolSieveChunk(toolSieveState, p.text, toolNames); + for (const evt of events) { + if (evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0) { + if (!emitEarlyToolDeltas) { + continue; + } + toolCallsEmitted = true; + sendDeltaFrame({ tool_calls: formatIncrementalToolCallDeltas(evt.deltas, streamToolCallIDs) }); + continue; + } + if (evt.type === 'tool_calls') { + toolCallsEmitted = true; + sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls) }); + continue; + } + if (evt.text) { + sendDeltaFrame({ content: evt.text }); + } + } + } + } + } + } + await finish('stop'); + } catch (err) { + if (clientClosed || isAbortError(err)) { + await finish('stop'); + return; + } + await finish('stop'); + } + } finally { + req.removeListener('aborted', onReqAborted); + res.removeListener('close', onResClose); + await releaseLease(); + } +} + +function toBool(v) { + return v === true; +} + +module.exports = { + handleVercelStream, +}; diff --git 
a/internal/js/helpers/stream-tool-sieve.js b/internal/js/helpers/stream-tool-sieve.js new file mode 100644 index 0000000..8985478 --- /dev/null +++ b/internal/js/helpers/stream-tool-sieve.js @@ -0,0 +1,3 @@ +'use strict'; + +module.exports = require('./stream-tool-sieve/index.js'); diff --git a/internal/js/helpers/stream-tool-sieve/format.js b/internal/js/helpers/stream-tool-sieve/format.js new file mode 100644 index 0000000..ff1dcef --- /dev/null +++ b/internal/js/helpers/stream-tool-sieve/format.js @@ -0,0 +1,29 @@ +'use strict'; + +const crypto = require('crypto'); + +function formatOpenAIStreamToolCalls(calls) { + if (!Array.isArray(calls) || calls.length === 0) { + return []; + } + return calls.map((c, idx) => ({ + index: idx, + id: `call_${newCallID()}`, + type: 'function', + function: { + name: c.name, + arguments: JSON.stringify(c.input || {}), + }, + })); +} + +function newCallID() { + if (typeof crypto.randomUUID === 'function') { + return crypto.randomUUID().replace(/-/g, ''); + } + return `${Date.now()}${Math.floor(Math.random() * 1e9)}`; +} + +module.exports = { + formatOpenAIStreamToolCalls, +}; diff --git a/internal/js/helpers/stream-tool-sieve/incremental.js b/internal/js/helpers/stream-tool-sieve/incremental.js new file mode 100644 index 0000000..1895075 --- /dev/null +++ b/internal/js/helpers/stream-tool-sieve/incremental.js @@ -0,0 +1,226 @@ +'use strict'; + +const { + looksLikeToolExampleContext, + insideCodeFence, +} = require('./state'); +const { + findObjectFieldValueStart, + parseJSONStringLiteral, + skipSpaces, +} = require('./jsonscan'); + +function buildIncrementalToolDeltas(state) { + const captured = state.capture || ''; + if (!captured) { + return []; + } + if (looksLikeToolExampleContext(state.recentTextTail)) { + return []; + } + const lower = captured.toLowerCase(); + const keyIdx = lower.indexOf('tool_calls'); + if (keyIdx < 0) { + return []; + } + const start = captured.slice(0, keyIdx).lastIndexOf('{'); + if (start < 0) { + 
return []; + } + if (insideCodeFence((state.recentTextTail || '') + captured.slice(0, start))) { + return []; + } + const callStart = findFirstToolCallObjectStart(captured, keyIdx); + if (callStart < 0) { + return []; + } + + const deltas = []; + if (!state.toolName) { + const name = extractToolCallName(captured, callStart); + if (!name) { + return []; + } + state.toolName = name; + } + + if (state.toolArgsStart < 0) { + const args = findToolCallArgsStart(captured, callStart); + if (args) { + state.toolArgsString = Boolean(args.stringMode); + state.toolArgsStart = state.toolArgsString ? args.start + 1 : args.start; + state.toolArgsSent = state.toolArgsStart; + } + } + if (!state.toolNameSent) { + if (state.toolArgsStart < 0) { + return []; + } + state.toolNameSent = true; + deltas.push({ index: 0, name: state.toolName }); + } + if (state.toolArgsStart < 0 || state.toolArgsDone) { + return deltas; + } + const progress = scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString); + if (!progress) { + return deltas; + } + if (progress.end > state.toolArgsSent) { + deltas.push({ + index: 0, + arguments: captured.slice(state.toolArgsSent, progress.end), + }); + state.toolArgsSent = progress.end; + } + if (progress.complete) { + state.toolArgsDone = true; + } + return deltas; +} + +function findFirstToolCallObjectStart(text, keyIdx) { + const arrStart = findToolCallsArrayStart(text, keyIdx); + if (arrStart < 0) { + return -1; + } + const i = skipSpaces(text, arrStart + 1); + if (i >= text.length || text[i] !== '{') { + return -1; + } + return i; +} + +function findToolCallsArrayStart(text, keyIdx) { + let i = keyIdx + 'tool_calls'.length; + while (i < text.length && text[i] !== ':') { + i += 1; + } + if (i >= text.length) { + return -1; + } + i = skipSpaces(text, i + 1); + if (i >= text.length || text[i] !== '[') { + return -1; + } + return i; +} + +function extractToolCallName(text, callStart) { + let valueStart = findObjectFieldValueStart(text, 
callStart, ['name']); + if (valueStart < 0 || text[valueStart] !== '"') { + const fnStart = findFunctionObjectStart(text, callStart); + if (fnStart < 0) { + return ''; + } + valueStart = findObjectFieldValueStart(text, fnStart, ['name']); + if (valueStart < 0 || text[valueStart] !== '"') { + return ''; + } + } + const parsed = parseJSONStringLiteral(text, valueStart); + if (!parsed) { + return ''; + } + return parsed.value; +} + +function findToolCallArgsStart(text, callStart) { + const keys = ['input', 'arguments', 'args', 'parameters', 'params']; + let valueStart = findObjectFieldValueStart(text, callStart, keys); + if (valueStart < 0) { + const fnStart = findFunctionObjectStart(text, callStart); + if (fnStart < 0) { + return null; + } + valueStart = findObjectFieldValueStart(text, fnStart, keys); + if (valueStart < 0) { + return null; + } + } + if (valueStart >= text.length) { + return null; + } + const ch = text[valueStart]; + if (ch === '{' || ch === '[') { + return { start: valueStart, stringMode: false }; + } + if (ch === '"') { + return { start: valueStart, stringMode: true }; + } + return null; +} + +function scanToolCallArgsProgress(text, start, stringMode) { + if (start < 0 || start > text.length) { + return null; + } + if (stringMode) { + let escaped = false; + for (let i = start; i < text.length; i += 1) { + const ch = text[i]; + if (escaped) { + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === '"') { + return { end: i, complete: true }; + } + } + return { end: text.length, complete: false }; + } + if (start >= text.length || (text[start] !== '{' && text[start] !== '[')) { + return null; + } + let depth = 0; + let quote = ''; + let escaped = false; + for (let i = start; i < text.length; i += 1) { + const ch = text[i]; + if (quote) { + if (escaped) { + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === quote) { + quote = ''; + } + continue; + } + if (ch 
=== '"' || ch === "'") { + quote = ch; + continue; + } + if (ch === '{' || ch === '[') { + depth += 1; + continue; + } + if (ch === '}' || ch === ']') { + depth -= 1; + if (depth === 0) { + return { end: i + 1, complete: true }; + } + } + } + return { end: text.length, complete: false }; +} + +function findFunctionObjectStart(text, callStart) { + const valueStart = findObjectFieldValueStart(text, callStart, ['function']); + if (valueStart < 0 || valueStart >= text.length || text[valueStart] !== '{') { + return -1; + } + return valueStart; +} + +module.exports = { + buildIncrementalToolDeltas, +}; diff --git a/internal/js/helpers/stream-tool-sieve/index.js b/internal/js/helpers/stream-tool-sieve/index.js new file mode 100644 index 0000000..f218b52 --- /dev/null +++ b/internal/js/helpers/stream-tool-sieve/index.js @@ -0,0 +1,27 @@ +'use strict'; + +const { + createToolSieveState, +} = require('./state'); +const { + processToolSieveChunk, + flushToolSieve, +} = require('./sieve'); +const { + extractToolNames, + parseToolCalls, + parseStandaloneToolCalls, +} = require('./parse'); +const { + formatOpenAIStreamToolCalls, +} = require('./format'); + +module.exports = { + extractToolNames, + createToolSieveState, + processToolSieveChunk, + flushToolSieve, + parseToolCalls, + parseStandaloneToolCalls, + formatOpenAIStreamToolCalls, +}; diff --git a/internal/js/helpers/stream-tool-sieve/jsonscan.js b/internal/js/helpers/stream-tool-sieve/jsonscan.js new file mode 100644 index 0000000..a86ed05 --- /dev/null +++ b/internal/js/helpers/stream-tool-sieve/jsonscan.js @@ -0,0 +1,148 @@ +'use strict'; + +function findObjectFieldValueStart(text, objStart, keys) { + if (!text || objStart < 0 || objStart >= text.length || text[objStart] !== '{') { + return -1; + } + let depth = 0; + let quote = ''; + let escaped = false; + for (let i = objStart; i < text.length; i += 1) { + const ch = text[i]; + if (quote) { + if (escaped) { + escaped = false; + continue; + } + if (ch === '\\') { + 
escaped = true; + continue; + } + if (ch === quote) { + quote = ''; + } + continue; + } + if (ch === '"' || ch === "'") { + if (depth === 1) { + const parsed = parseJSONStringLiteral(text, i); + if (!parsed) { + return -1; + } + let j = skipSpaces(text, parsed.end); + if (j >= text.length || text[j] !== ':') { + i = parsed.end - 1; + continue; + } + j = skipSpaces(text, j + 1); + if (j >= text.length) { + return -1; + } + if (keys.includes(parsed.value)) { + return j; + } + i = j - 1; + continue; + } + quote = ch; + continue; + } + if (ch === '{') { + depth += 1; + continue; + } + if (ch === '}') { + depth -= 1; + if (depth === 0) { + break; + } + } + } + return -1; +} + +function parseJSONStringLiteral(text, start) { + if (!text || start < 0 || start >= text.length || text[start] !== '"') { + return null; + } + let out = ''; + let escaped = false; + for (let i = start + 1; i < text.length; i += 1) { + const ch = text[i]; + if (escaped) { + out += ch; + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === '"') { + return { value: out, end: i + 1 }; + } + out += ch; + } + return null; +} + +function skipSpaces(text, i) { + let idx = i; + while (idx < text.length) { + const ch = text[idx]; + if (ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r') { + idx += 1; + continue; + } + break; + } + return idx; +} + +function extractJSONObjectFrom(text, start) { + if (!text || start < 0 || start >= text.length || text[start] !== '{') { + return { ok: false, end: 0 }; + } + let depth = 0; + let quote = ''; + let escaped = false; + for (let i = start; i < text.length; i += 1) { + const ch = text[i]; + if (quote) { + if (escaped) { + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === quote) { + quote = ''; + } + continue; + } + if (ch === '"' || ch === "'") { + quote = ch; + continue; + } + if (ch === '{') { + depth += 1; + continue; + } + if (ch === '}') { + depth -= 1; + if 
(depth === 0) { + return { ok: true, end: i + 1 }; + } + } + } + return { ok: false, end: 0 }; +} + +module.exports = { + findObjectFieldValueStart, + parseJSONStringLiteral, + skipSpaces, + extractJSONObjectFrom, +}; diff --git a/internal/js/helpers/stream-tool-sieve/parse.js b/internal/js/helpers/stream-tool-sieve/parse.js new file mode 100644 index 0000000..f1efdda --- /dev/null +++ b/internal/js/helpers/stream-tool-sieve/parse.js @@ -0,0 +1,273 @@ +'use strict'; + +const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s; + +const { + toStringSafe, + looksLikeToolExampleContext, +} = require('./state'); +const { + extractJSONObjectFrom, +} = require('./jsonscan'); + +function extractToolNames(tools) { + if (!Array.isArray(tools) || tools.length === 0) { + return []; + } + const out = []; + for (const t of tools) { + if (!t || typeof t !== 'object') { + continue; + } + const fn = t.function && typeof t.function === 'object' ? t.function : t; + const name = toStringSafe(fn.name); + // Keep parity with Go injectToolPrompt: object tools without name still + // enter tool mode via fallback name "unknown". + out.push(name || 'unknown'); + } + return out; +} + +function parseToolCalls(text, toolNames) { + if (!toStringSafe(text)) { + return []; + } + const sanitized = stripFencedCodeBlocks(text); + if (!toStringSafe(sanitized)) { + return []; + } + const candidates = buildToolCallCandidates(sanitized); + let parsed = []; + for (const c of candidates) { + parsed = parseToolCallsPayload(c); + if (parsed.length > 0) { + break; + } + } + if (parsed.length === 0) { + return []; + } + return filterToolCalls(parsed, toolNames); +} + +function stripFencedCodeBlocks(text) { + const t = typeof text === 'string' ? 
text : ''; + if (!t) { + return ''; + } + return t.replace(/```[\s\S]*?```/g, ' '); +} + +function parseStandaloneToolCalls(text, toolNames) { + const trimmed = toStringSafe(text); + if (!trimmed) { + return []; + } + if (trimmed.includes('```') && !(trimmed.startsWith('```') && trimmed.endsWith('```'))) { + return []; + } + if (looksLikeToolExampleContext(trimmed)) { + return []; + } + const candidates = [trimmed]; + if (trimmed.startsWith('```') && trimmed.endsWith('```')) { + const m = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/i); + if (m && m[1]) { + candidates.push(toStringSafe(m[1])); + } + } + for (const candidate of candidates) { + const c = toStringSafe(candidate); + if (!c) { + continue; + } + if (!c.startsWith('{') && !c.startsWith('[')) { + continue; + } + const parsed = parseToolCallsPayload(c); + if (parsed.length > 0) { + return filterToolCalls(parsed, toolNames); + } + } + return []; +} + +function buildToolCallCandidates(text) { + const trimmed = toStringSafe(text); + const candidates = [trimmed]; + const fenced = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/gi) || []; + for (const block of fenced) { + const m = block.match(/```(?:json)?\s*([\s\S]*?)\s*```/i); + if (m && m[1]) { + candidates.push(toStringSafe(m[1])); + } + } + for (const candidate of extractToolCallObjects(trimmed)) { + candidates.push(toStringSafe(candidate)); + } + const first = trimmed.indexOf('{'); + const last = trimmed.lastIndexOf('}'); + if (first >= 0 && last > first) { + candidates.push(toStringSafe(trimmed.slice(first, last + 1))); + } + const m = trimmed.match(TOOL_CALL_PATTERN); + if (m && m[1]) { + candidates.push(`{"tool_calls":[${m[1]}]}`); + } + return [...new Set(candidates.filter(Boolean))]; +} + +function extractToolCallObjects(text) { + const raw = toStringSafe(text); + if (!raw) { + return []; + } + const lower = raw.toLowerCase(); + const out = []; + let offset = 0; + // eslint-disable-next-line no-constant-condition + while (true) { + let idx = 
lower.indexOf('tool_calls', offset); + if (idx < 0) { + break; + } + let start = raw.slice(0, idx).lastIndexOf('{'); + while (start >= 0) { + const obj = extractJSONObjectFrom(raw, start); + if (obj.ok) { + out.push(raw.slice(start, obj.end).trim()); + offset = obj.end; + idx = -1; + break; + } + start = raw.slice(0, start).lastIndexOf('{'); + } + if (idx >= 0) { + offset = idx + 'tool_calls'.length; + } + } + return out; +} + +function parseToolCallsPayload(payload) { + let decoded; + try { + decoded = JSON.parse(payload); + } catch (_err) { + return []; + } + if (Array.isArray(decoded)) { + return parseToolCallList(decoded); + } + if (!decoded || typeof decoded !== 'object') { + return []; + } + if (decoded.tool_calls) { + return parseToolCallList(decoded.tool_calls); + } + const one = parseToolCallItem(decoded); + return one ? [one] : []; +} + +function parseToolCallList(v) { + if (!Array.isArray(v)) { + return []; + } + const out = []; + for (const item of v) { + if (!item || typeof item !== 'object') { + continue; + } + const one = parseToolCallItem(item); + if (one) { + out.push(one); + } + } + return out; +} + +function parseToolCallItem(m) { + let name = toStringSafe(m.name); + let inputRaw = m.input; + let hasInput = Object.prototype.hasOwnProperty.call(m, 'input'); + const fn = m.function && typeof m.function === 'object' ? 
m.function : null; + if (fn) { + if (!name) { + name = toStringSafe(fn.name); + } + if (!hasInput && Object.prototype.hasOwnProperty.call(fn, 'arguments')) { + inputRaw = fn.arguments; + hasInput = true; + } + } + if (!hasInput) { + for (const k of ['arguments', 'args', 'parameters', 'params']) { + if (Object.prototype.hasOwnProperty.call(m, k)) { + inputRaw = m[k]; + hasInput = true; + break; + } + } + } + if (!name) { + return null; + } + return { + name, + input: parseToolCallInput(inputRaw), + }; +} + +function parseToolCallInput(v) { + if (v == null) { + return {}; + } + if (typeof v === 'string') { + const raw = toStringSafe(v); + if (!raw) { + return {}; + } + try { + const parsed = JSON.parse(raw); + if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { + return parsed; + } + return { _raw: raw }; + } catch (_err) { + return { _raw: raw }; + } + } + if (typeof v === 'object' && !Array.isArray(v)) { + return v; + } + try { + const parsed = JSON.parse(JSON.stringify(v)); + if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { + return parsed; + } + } catch (_err) { + return {}; + } + return {}; +} + +function filterToolCalls(parsed, toolNames) { + const allowed = new Set((toolNames || []).filter(Boolean)); + const out = []; + for (const tc of parsed) { + if (!tc || !tc.name) { + continue; + } + if (allowed.size > 0 && !allowed.has(tc.name)) { + continue; + } + out.push({ name: tc.name, input: tc.input || {} }); + } + return out; +} + +module.exports = { + extractToolNames, + parseToolCalls, + parseStandaloneToolCalls, +}; diff --git a/internal/js/helpers/stream-tool-sieve/sieve.js b/internal/js/helpers/stream-tool-sieve/sieve.js new file mode 100644 index 0000000..699c3a8 --- /dev/null +++ b/internal/js/helpers/stream-tool-sieve/sieve.js @@ -0,0 +1,261 @@ +'use strict'; + +const { + TOOL_SIEVE_CAPTURE_LIMIT, + resetIncrementalToolState, + noteText, + insideCodeFence, +} = require('./state'); +const { + 
buildIncrementalToolDeltas, +} = require('./incremental'); +const { + parseStandaloneToolCalls, +} = require('./parse'); +const { + extractJSONObjectFrom, +} = require('./jsonscan'); + +function processToolSieveChunk(state, chunk, toolNames) { + if (!state) { + return []; + } + if (chunk) { + state.pending += chunk; + } + const events = []; + // eslint-disable-next-line no-constant-condition + while (true) { + if (state.capturing) { + if (state.pending) { + state.capture += state.pending; + state.pending = ''; + } + const deltas = buildIncrementalToolDeltas(state); + if (deltas.length > 0) { + events.push({ type: 'tool_call_deltas', deltas }); + } + const consumed = consumeToolCapture(state, toolNames); + if (!consumed.ready) { + if (state.capture.length > TOOL_SIEVE_CAPTURE_LIMIT) { + noteText(state, state.capture); + events.push({ type: 'text', text: state.capture }); + state.capture = ''; + state.capturing = false; + resetIncrementalToolState(state); + continue; + } + break; + } + state.capture = ''; + state.capturing = false; + resetIncrementalToolState(state); + if (consumed.prefix) { + noteText(state, consumed.prefix); + events.push({ type: 'text', text: consumed.prefix }); + } + if (Array.isArray(consumed.calls) && consumed.calls.length > 0) { + events.push({ type: 'tool_calls', calls: consumed.calls }); + } + if (consumed.suffix) { + state.pending += consumed.suffix; + } + continue; + } + + if (!state.pending) { + break; + } + + const start = findToolSegmentStart(state.pending); + if (start >= 0) { + const prefix = state.pending.slice(0, start); + if (prefix) { + noteText(state, prefix); + events.push({ type: 'text', text: prefix }); + } + state.capture = state.pending.slice(start); + state.pending = ''; + state.capturing = true; + resetIncrementalToolState(state); + continue; + } + + const [safe, hold] = splitSafeContentForToolDetection(state.pending); + if (!safe) { + break; + } + state.pending = hold; + noteText(state, safe); + events.push({ type: 
'text', text: safe }); + } + return events; +} + +function flushToolSieve(state, toolNames) { + if (!state) { + return []; + } + const events = processToolSieveChunk(state, '', toolNames); + if (state.capturing) { + const consumed = consumeToolCapture(state, toolNames); + if (consumed.ready) { + if (consumed.prefix) { + noteText(state, consumed.prefix); + events.push({ type: 'text', text: consumed.prefix }); + } + if (Array.isArray(consumed.calls) && consumed.calls.length > 0) { + events.push({ type: 'tool_calls', calls: consumed.calls }); + } + if (consumed.suffix) { + noteText(state, consumed.suffix); + events.push({ type: 'text', text: consumed.suffix }); + } + } else if (state.capture) { + noteText(state, state.capture); + events.push({ type: 'text', text: state.capture }); + } + state.capture = ''; + state.capturing = false; + resetIncrementalToolState(state); + } + if (state.pending) { + noteText(state, state.pending); + events.push({ type: 'text', text: state.pending }); + state.pending = ''; + } + return events; +} + +function splitSafeContentForToolDetection(s) { + const text = s || ''; + if (!text) { + return ['', '']; + } + const suspiciousStart = findSuspiciousPrefixStart(text); + if (suspiciousStart < 0) { + return [text, '']; + } + if (suspiciousStart > 0) { + return [text.slice(0, suspiciousStart), text.slice(suspiciousStart)]; + } + // If suspicious content starts at the beginning, keep holding until we can + // either parse a full tool JSON block or reach stream flush. 
+ return ['', text]; +} + +function findSuspiciousPrefixStart(s) { + let start = -1; + for (const needle of ['{', '[', '```']) { + const idx = s.lastIndexOf(needle); + if (idx > start) { + start = idx; + } + } + return start; +} + +function findToolSegmentStart(s) { + if (!s) { + return -1; + } + const lower = s.toLowerCase(); + let offset = 0; + // eslint-disable-next-line no-constant-condition + while (true) { + const keyRel = lower.indexOf('tool_calls', offset); + if (keyRel < 0) { + return -1; + } + const keyIdx = keyRel; + const start = s.slice(0, keyIdx).lastIndexOf('{'); + const candidateStart = start >= 0 ? start : keyIdx; + if (!insideCodeFence(s.slice(0, candidateStart))) { + return candidateStart; + } + offset = keyIdx + 'tool_calls'.length; + } +} + +function consumeToolCapture(state, toolNames) { + const captured = state.capture; + if (!captured) { + return { ready: false, prefix: '', calls: [], suffix: '' }; + } + const lower = captured.toLowerCase(); + const keyIdx = lower.indexOf('tool_calls'); + if (keyIdx < 0) { + return { ready: false, prefix: '', calls: [], suffix: '' }; + } + const start = captured.slice(0, keyIdx).lastIndexOf('{'); + if (start < 0) { + return { ready: false, prefix: '', calls: [], suffix: '' }; + } + const obj = extractJSONObjectFrom(captured, start); + if (!obj.ok) { + return { ready: false, prefix: '', calls: [], suffix: '' }; + } + const prefixPart = captured.slice(0, start); + const suffixPart = captured.slice(obj.end); + if (insideCodeFence((state.recentTextTail || '') + prefixPart)) { + return { + ready: true, + prefix: captured, + calls: [], + suffix: '', + }; + } + const rawParsed = parseStandaloneToolCalls(captured.slice(start, obj.end), []); + const parsed = parseStandaloneToolCalls(captured.slice(start, obj.end), toolNames); + if (parsed.length === 0) { + if (rawParsed.length > 0 && Array.isArray(toolNames) && toolNames.length > 0) { + return { + ready: true, + prefix: prefixPart, + calls: [], + suffix: suffixPart, 
+ }; + } + if (state.toolNameSent) { + return { + ready: true, + prefix: prefixPart, + calls: [], + suffix: suffixPart, + }; + } + return { + ready: true, + prefix: captured, + calls: [], + suffix: '', + }; + } + if (state.toolNameSent) { + if (parsed.length > 1) { + return { + ready: true, + prefix: prefixPart, + calls: parsed.slice(1), + suffix: suffixPart, + }; + } + return { + ready: true, + prefix: prefixPart, + calls: [], + suffix: suffixPart, + }; + } + return { + ready: true, + prefix: prefixPart, + calls: parsed, + suffix: suffixPart, + }; +} + +module.exports = { + processToolSieveChunk, + flushToolSieve, +}; diff --git a/internal/js/helpers/stream-tool-sieve/state.js b/internal/js/helpers/stream-tool-sieve/state.js new file mode 100644 index 0000000..a2d2b5c --- /dev/null +++ b/internal/js/helpers/stream-tool-sieve/state.js @@ -0,0 +1,91 @@ +'use strict'; + +const TOOL_SIEVE_CAPTURE_LIMIT = 8 * 1024; +const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 256; + +function createToolSieveState() { + return { + pending: '', + capture: '', + capturing: false, + recentTextTail: '', + toolNameSent: false, + toolName: '', + toolArgsStart: -1, + toolArgsSent: -1, + toolArgsString: false, + toolArgsDone: false, + }; +} + +function resetIncrementalToolState(state) { + state.toolNameSent = false; + state.toolName = ''; + state.toolArgsStart = -1; + state.toolArgsSent = -1; + state.toolArgsString = false; + state.toolArgsDone = false; +} + +function noteText(state, text) { + if (!state || !hasMeaningfulText(text)) { + return; + } + state.recentTextTail = appendTail(state.recentTextTail, text, TOOL_SIEVE_CONTEXT_TAIL_LIMIT); +} + +function appendTail(prev, next, max) { + const left = typeof prev === 'string' ? prev : ''; + const right = typeof next === 'string' ? 
next : ''; + if (!Number.isFinite(max) || max <= 0) { + return ''; + } + const combined = left + right; + if (combined.length <= max) { + return combined; + } + return combined.slice(combined.length - max); +} + +function looksLikeToolExampleContext(text) { + return insideCodeFence(text); +} + +function insideCodeFence(text) { + const t = typeof text === 'string' ? text : ''; + if (!t) { + return false; + } + const ticks = (t.match(/```/g) || []).length; + return ticks % 2 === 1; +} + +function hasMeaningfulText(text) { + return toStringSafe(text) !== ''; +} + +function toStringSafe(v) { + if (typeof v === 'string') { + return v.trim(); + } + if (Array.isArray(v)) { + return toStringSafe(v[0]); + } + if (v == null) { + return ''; + } + return String(v).trim(); +} + +module.exports = { + TOOL_SIEVE_CAPTURE_LIMIT, + TOOL_SIEVE_CONTEXT_TAIL_LIMIT, + createToolSieveState, + resetIncrementalToolState, + noteText, + appendTail, + looksLikeToolExampleContext, + insideCodeFence, + hasMeaningfulText, + toStringSafe, +}; diff --git a/internal/js/shared/deepseek-constants.js b/internal/js/shared/deepseek-constants.js new file mode 100644 index 0000000..1ec74f1 --- /dev/null +++ b/internal/js/shared/deepseek-constants.js @@ -0,0 +1,66 @@ +'use strict'; + +const fs = require('fs'); +const path = require('path'); + +const DEFAULT_BASE_HEADERS = Object.freeze({ + Host: 'chat.deepseek.com', + 'User-Agent': 'DeepSeek/1.6.11 Android/35', + Accept: 'application/json', + 'Content-Type': 'application/json', + 'x-client-platform': 'android', + 'x-client-version': '1.6.11', + 'x-client-locale': 'zh_CN', + 'accept-charset': 'UTF-8', +}); + +const DEFAULT_SKIP_PATTERNS = Object.freeze([ + 'quasi_status', + 'elapsed_secs', + 'token_usage', + 'pending_fragment', + 'conversation_mode', + 'fragments/-1/status', + 'fragments/-2/status', + 'fragments/-3/status', +]); + +const DEFAULT_SKIP_EXACT_PATHS = Object.freeze([ + 'response/search_status', +]); + +function loadSharedConstants() { + const 
sharedPath = path.resolve(__dirname, '../../deepseek/constants_shared.json'); + try { + const raw = fs.readFileSync(sharedPath, 'utf8'); + const parsed = JSON.parse(raw); + const baseHeaders = parsed && typeof parsed.base_headers === 'object' && !Array.isArray(parsed.base_headers) + ? { ...DEFAULT_BASE_HEADERS, ...parsed.base_headers } + : { ...DEFAULT_BASE_HEADERS }; + const skipPatterns = Array.isArray(parsed && parsed.skip_contains_patterns) + ? parsed.skip_contains_patterns.filter((v) => typeof v === 'string' && v !== '') + : [...DEFAULT_SKIP_PATTERNS]; + const skipExactPaths = Array.isArray(parsed && parsed.skip_exact_paths) + ? parsed.skip_exact_paths.filter((v) => typeof v === 'string' && v !== '') + : [...DEFAULT_SKIP_EXACT_PATHS]; + return { + baseHeaders, + skipPatterns, + skipExactPaths, + }; + } catch (_err) { + return { + baseHeaders: { ...DEFAULT_BASE_HEADERS }, + skipPatterns: [...DEFAULT_SKIP_PATTERNS], + skipExactPaths: [...DEFAULT_SKIP_EXACT_PATHS], + }; + } +} + +const shared = loadSharedConstants(); + +module.exports = { + BASE_HEADERS: Object.freeze(shared.baseHeaders), + SKIP_PATTERNS: Object.freeze(shared.skipPatterns), + SKIP_EXACT_PATHS: new Set(shared.skipExactPaths), +}; diff --git a/internal/prompt/messages.go b/internal/prompt/messages.go new file mode 100644 index 0000000..69cfe5a --- /dev/null +++ b/internal/prompt/messages.go @@ -0,0 +1,84 @@ +package prompt + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" +) + +var markdownImagePattern = regexp.MustCompile(`!\[(.*?)\]\((.*?)\)`) + +func MessagesPrepare(messages []map[string]any) string { + type block struct { + Role string + Text string + } + processed := make([]block, 0, len(messages)) + for _, m := range messages { + role, _ := m["role"].(string) + text := NormalizeContent(m["content"]) + processed = append(processed, block{Role: role, Text: text}) + } + if len(processed) == 0 { + return "" + } + merged := make([]block, 0, len(processed)) + for _, msg := range 
processed { + if len(merged) > 0 && merged[len(merged)-1].Role == msg.Role { + merged[len(merged)-1].Text += "\n\n" + msg.Text + continue + } + merged = append(merged, msg) + } + parts := make([]string, 0, len(merged)) + for i, m := range merged { + switch m.Role { + case "assistant": + parts = append(parts, "<|Assistant|>"+m.Text+"<|end▁of▁sentence|>") + case "user", "system": + if i > 0 { + parts = append(parts, "<|User|>"+m.Text) + } else { + parts = append(parts, m.Text) + } + default: + parts = append(parts, m.Text) + } + } + out := strings.Join(parts, "") + return markdownImagePattern.ReplaceAllString(out, `[${1}](${2})`) +} + +func NormalizeContent(v any) string { + switch x := v.(type) { + case string: + return x + case []any: + parts := make([]string, 0, len(x)) + for _, item := range x { + m, ok := item.(map[string]any) + if !ok { + continue + } + typeStr, _ := m["type"].(string) + typeStr = strings.ToLower(strings.TrimSpace(typeStr)) + if typeStr == "text" || typeStr == "output_text" || typeStr == "input_text" { + if txt, ok := m["text"].(string); ok { + parts = append(parts, txt) + continue + } + if txt, ok := m["content"].(string); ok { + parts = append(parts, txt) + } + } + } + return strings.Join(parts, "\n") + default: + b, err := json.Marshal(v) + if err != nil { + return fmt.Sprintf("%v", v) + } + return string(b) + } +} diff --git a/internal/server/router.go b/internal/server/router.go index c6339fb..ae3108e 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -12,6 +12,7 @@ import ( "ds2api/internal/account" "ds2api/internal/adapter/claude" + "ds2api/internal/adapter/gemini" "ds2api/internal/adapter/openai" "ds2api/internal/admin" "ds2api/internal/auth" @@ -44,6 +45,7 @@ func NewApp() *App { openaiHandler := &openai.Handler{Store: store, Auth: resolver, DS: dsClient} claudeHandler := &claude.Handler{Store: store, Auth: resolver, DS: dsClient} + geminiHandler := &gemini.Handler{Store: store, Auth: resolver, DS: dsClient} 
adminHandler := &admin.Handler{Store: store, Pool: pool, DS: dsClient} webuiHandler := webui.NewHandler() @@ -67,6 +69,7 @@ func NewApp() *App { }) openai.RegisterRoutes(r, openaiHandler) claude.RegisterRoutes(r, claudeHandler) + gemini.RegisterRoutes(r, geminiHandler) r.Route("/admin", func(ar chi.Router) { admin.RegisterRoutes(ar, adminHandler) }) @@ -92,7 +95,7 @@ func cors(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key, X-Ds2-Target-Account, X-Vercel-Protection-Bypass") if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) return diff --git a/internal/sse/consumer_edge_test.go b/internal/sse/consumer_edge_test.go new file mode 100644 index 0000000..8f78f01 --- /dev/null +++ b/internal/sse/consumer_edge_test.go @@ -0,0 +1,140 @@ +package sse + +import ( + "io" + "net/http" + "strings" + "testing" +) + +// ─── CollectStream edge cases ──────────────────────────────────────── + +func makeHTTPResponse(body string) *http.Response { + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func TestCollectStreamEmpty(t *testing.T) { + resp := makeHTTPResponse("") + result := CollectStream(resp, false, false) + if result.Text != "" || result.Thinking != "" { + t.Fatalf("expected empty result, got text=%q think=%q", result.Text, result.Thinking) + } +} + +func TestCollectStreamTextOnly(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\" World\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, false, false) + 
if result.Text != "Hello World" { + t.Fatalf("expected 'Hello World', got %q", result.Text) + } + if result.Thinking != "" { + t.Fatalf("expected no thinking, got %q", result.Thinking) + } +} + +func TestCollectStreamThinkingAndText(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/thinking_content\",\"v\":\"Thinking...\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"Answer\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, true, true) + if result.Thinking != "Thinking..." { + t.Fatalf("expected 'Thinking...', got %q", result.Thinking) + } + if result.Text != "Answer" { + t.Fatalf("expected 'Answer', got %q", result.Text) + } +} + +func TestCollectStreamOnlyThinking(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/thinking_content\",\"v\":\"Only thinking\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, true, true) + if result.Thinking != "Only thinking" { + t.Fatalf("expected 'Only thinking', got %q", result.Thinking) + } + if result.Text != "" { + t.Fatalf("expected empty text, got %q", result.Text) + } +} + +func TestCollectStreamSkipsInvalidLines(t *testing.T) { + resp := makeHTTPResponse( + "event: comment\n" + + "data: invalid_json\n" + + "data: {\"p\":\"response/content\",\"v\":\"valid\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, false, false) + if result.Text != "valid" { + t.Fatalf("expected 'valid', got %q", result.Text) + } +} + +func TestCollectStreamWithFragments(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"THINK\",\"content\":\"Think\"}]}\n" + + "data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"RESPONSE\",\"content\":\"Done\"}]}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, true, true) + if result.Thinking != "Think" { + t.Fatalf("expected 'Think' thinking, got %q", result.Thinking) + } + if result.Text != "Done" { + t.Fatalf("expected 'Done' text, 
got %q", result.Text) + } +} + +func TestCollectStreamWithCitation(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"[citation:1] cited text\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\" more\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, false, false) + // CollectStream does NOT filter citations (that's done by the adapters) + // So citations are passed through as-is + if !strings.Contains(result.Text, "[citation:1]") { + t.Fatalf("expected citations to be passed through, got %q", result.Text) + } + if result.Text != "Hello[citation:1] cited text more" { + t.Fatalf("expected full text with citation, got %q", result.Text) + } +} + +func TestCollectStreamMultipleThinkingChunks(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/thinking_content\",\"v\":\"part1\"}\n" + + "data: {\"p\":\"response/thinking_content\",\"v\":\" part2\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"answer\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, true, true) + if result.Thinking != "part1 part2" { + t.Fatalf("expected 'part1 part2', got %q", result.Thinking) + } +} + +func TestCollectStreamStatusFinished(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/content\",\"v\":\"Hello\"}\n" + + "data: {\"p\":\"response/status\",\"v\":\"FINISHED\"}\n", + ) + result := CollectStream(resp, false, false) + if result.Text != "Hello" { + t.Fatalf("expected 'Hello', got %q", result.Text) + } +} diff --git a/internal/sse/line_edge_test.go b/internal/sse/line_edge_test.go new file mode 100644 index 0000000..2ae53a6 --- /dev/null +++ b/internal/sse/line_edge_test.go @@ -0,0 +1,70 @@ +package sse + +import "testing" + +func TestParseDeepSeekContentLineNotParsed(t *testing.T) { + res := ParseDeepSeekContentLine([]byte("not a data line"), false, "text") + if res.Parsed { + t.Fatal("expected not parsed") + } + if 
res.NextType != "text" { + t.Fatalf("expected nextType preserved, got %q", res.NextType) + } +} + +func TestParseDeepSeekContentLinePreservesNextType(t *testing.T) { + res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/thinking_content","v":"think"}`), true, "thinking") + if !res.Parsed || res.Stop { + t.Fatalf("expected parsed non-stop: %#v", res) + } + if len(res.Parts) != 1 || res.Parts[0].Type != "thinking" { + t.Fatalf("unexpected parts: %#v", res.Parts) + } +} + +func TestParseDeepSeekContentLineFragmentSwitchType(t *testing.T) { + res := ParseDeepSeekContentLine( + []byte(`data: {"p":"response/fragments","o":"APPEND","v":[{"type":"RESPONSE","content":"hi"}]}`), + true, "thinking", + ) + if !res.Parsed || res.Stop { + t.Fatalf("expected parsed non-stop: %#v", res) + } + if res.NextType != "text" { + t.Fatalf("expected nextType text after RESPONSE fragment, got %q", res.NextType) + } +} + +func TestParseDeepSeekContentLineContentFilterMessage(t *testing.T) { + res := ParseDeepSeekContentLine([]byte(`data: {"code":"content_filter"}`), false, "text") + if !res.ContentFilter { + t.Fatal("expected content filter flag") + } + if res.ErrorMessage == "" { + t.Fatal("expected error message on content filter") + } +} + +func TestParseDeepSeekContentLineErrorObjectFormat(t *testing.T) { + res := ParseDeepSeekContentLine([]byte(`data: {"error":{"message":"rate limit","code":429}}`), false, "text") + if !res.Parsed || !res.Stop { + t.Fatalf("expected parsed stop: %#v", res) + } + if res.ErrorMessage == "" { + t.Fatal("expected non-empty error message") + } +} + +func TestParseDeepSeekContentLineInvalidJSON(t *testing.T) { + res := ParseDeepSeekContentLine([]byte("data: {broken"), false, "text") + if res.Parsed { + t.Fatal("expected not parsed for broken JSON") + } +} + +func TestParseDeepSeekContentLineEmptyBytes(t *testing.T) { + res := ParseDeepSeekContentLine([]byte{}, false, "text") + if res.Parsed { + t.Fatal("expected not parsed for empty bytes") + } +} 
diff --git a/internal/sse/parser.go b/internal/sse/parser.go index 38429d9..c20bc79 100644 --- a/internal/sse/parser.go +++ b/internal/sse/parser.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/json" "strings" + + "ds2api/internal/deepseek" ) type ContentPart struct { @@ -11,11 +13,6 @@ type ContentPart struct { Type string } -var skipPatterns = []string{ - "quasi_status", "elapsed_secs", "token_usage", "pending_fragment", "conversation_mode", - "fragments/-1/status", "fragments/-2/status", "fragments/-3/status", -} - func ParseDeepSeekSSELine(raw []byte) (map[string]any, bool, bool) { line := strings.TrimSpace(string(raw)) if line == "" || !strings.HasPrefix(line, "data:") { @@ -33,10 +30,10 @@ func ParseDeepSeekSSELine(raw []byte) (map[string]any, bool, bool) { } func shouldSkipPath(path string) bool { - if path == "response/search_status" { + if _, ok := deepseek.SkipExactPathSet[path]; ok { return true } - for _, p := range skipPatterns { + for _, p := range deepseek.SkipContainsPatterns { if strings.Contains(path, p) { return true } @@ -60,126 +57,159 @@ func ParseSSEChunkForContent(chunk map[string]any, thinkingEnabled bool, current } newType := currentFragmentType parts := make([]ContentPart, 0, 8) + collectDirectFragments(path, chunk, v, &newType, &parts) + updateTypeFromNestedResponse(path, v, &newType) + partType := resolvePartType(path, thinkingEnabled, newType) + finished := appendChunkValueContent(v, partType, &newType, &parts, path) + if finished { + return nil, true, newType + } + return parts, false, newType +} - // Newer DeepSeek responses may emit fragment APPEND directly on - // path "response/fragments" instead of wrapping it in path "response". 
- if path == "response/fragments" { - if op, _ := chunk["o"].(string); strings.EqualFold(op, "APPEND") { - if frags, ok := v.([]any); ok { - for _, frag := range frags { - fm, ok := frag.(map[string]any) - if !ok { - continue - } - t, _ := fm["type"].(string) - content, _ := fm["content"].(string) - t = strings.ToUpper(t) - switch t { - case "THINK", "THINKING": - newType = "thinking" - if content != "" { - parts = append(parts, ContentPart{Text: content, Type: "thinking"}) - } - case "RESPONSE": - newType = "text" - if content != "" { - parts = append(parts, ContentPart{Text: content, Type: "text"}) - } - default: - if content != "" { - parts = append(parts, ContentPart{Text: content, Type: "text"}) - } - } - } +func collectDirectFragments(path string, chunk map[string]any, v any, newType *string, parts *[]ContentPart) { + if path != "response/fragments" { + return + } + op, _ := chunk["o"].(string) + if !strings.EqualFold(op, "APPEND") { + return + } + frags, ok := v.([]any) + if !ok { + return + } + for _, frag := range frags { + m, ok := frag.(map[string]any) + if !ok { + continue + } + typeName, content, fragType := parseFragmentTypeContent(m) + if typeName == "" { + typeName = fragType + } + switch typeName { + case "THINK", "THINKING": + *newType = "thinking" + appendContentPart(parts, content, "thinking") + case "RESPONSE": + *newType = "text" + appendContentPart(parts, content, "text") + default: + appendContentPart(parts, content, "text") + } + } +} + +func updateTypeFromNestedResponse(path string, v any, newType *string) { + if path != "response" { + return + } + arr, ok := v.([]any) + if !ok { + return + } + for _, it := range arr { + m, ok := it.(map[string]any) + if !ok || m["p"] != "fragments" || m["o"] != "APPEND" { + continue + } + frags, ok := m["v"].([]any) + if !ok { + continue + } + for _, frag := range frags { + fm, ok := frag.(map[string]any) + if !ok { + continue + } + typeName, _, _ := parseFragmentTypeContent(fm) + switch typeName { + case 
"THINK", "THINKING": + *newType = "thinking" + case "RESPONSE": + *newType = "text" } } } +} - if path == "response" { - if arr, ok := v.([]any); ok { - for _, it := range arr { - m, ok := it.(map[string]any) - if !ok { - continue - } - if m["p"] == "fragments" && m["o"] == "APPEND" { - if frags, ok := m["v"].([]any); ok { - for _, frag := range frags { - fm, ok := frag.(map[string]any) - if !ok { - continue - } - t, _ := fm["type"].(string) - t = strings.ToUpper(t) - if t == "THINK" || t == "THINKING" { - newType = "thinking" - } else if t == "RESPONSE" { - newType = "text" - } - } - } - } - } - } - } - partType := "text" +func resolvePartType(path string, thinkingEnabled bool, newType string) string { switch { case path == "response/thinking_content": - partType = "thinking" + return "thinking" case path == "response/content": - partType = "text" + return "text" case strings.Contains(path, "response/fragments") && strings.Contains(path, "/content"): - partType = newType - case path == "": - if thinkingEnabled { - partType = newType - } + return newType + case path == "" && thinkingEnabled: + return newType + default: + return "text" } +} + +func appendChunkValueContent(v any, partType string, newType *string, parts *[]ContentPart, path string) bool { switch val := v.(type) { case string: if val == "FINISHED" && (path == "" || path == "status") { - return nil, true, newType - } - if val != "" { - parts = append(parts, ContentPart{Text: val, Type: partType}) + return true } + appendContentPart(parts, val, partType) case []any: pp, finished := extractContentRecursive(val, partType) if finished { - return nil, true, newType + return true } - parts = append(parts, pp...) + *parts = append(*parts, pp...) 
case map[string]any: - resp := val - if wrapped, ok := val["response"].(map[string]any); ok { - resp = wrapped + appendWrappedFragments(val, partType, newType, parts) + } + return false +} + +func appendWrappedFragments(val map[string]any, partType string, newType *string, parts *[]ContentPart) { + resp := val + if wrapped, ok := val["response"].(map[string]any); ok { + resp = wrapped + } + frags, ok := resp["fragments"].([]any) + if !ok { + return + } + for _, item := range frags { + m, ok := item.(map[string]any) + if !ok { + continue } - if frags, ok := resp["fragments"].([]any); ok { - for _, item := range frags { - m, ok := item.(map[string]any) - if !ok { - continue - } - t, _ := m["type"].(string) - content, _ := m["content"].(string) - t = strings.ToUpper(t) - if t == "THINK" || t == "THINKING" { - newType = "thinking" - if content != "" { - parts = append(parts, ContentPart{Text: content, Type: "thinking"}) - } - } else if t == "RESPONSE" { - newType = "text" - if content != "" { - parts = append(parts, ContentPart{Text: content, Type: "text"}) - } - } else if content != "" { - parts = append(parts, ContentPart{Text: content, Type: partType}) - } - } + typeName, content, fragType := parseFragmentTypeContent(m) + if typeName == "" { + typeName = fragType + } + switch typeName { + case "THINK", "THINKING": + *newType = "thinking" + appendContentPart(parts, content, "thinking") + case "RESPONSE": + *newType = "text" + appendContentPart(parts, content, "text") + default: + appendContentPart(parts, content, partType) } } - return parts, false, newType +} + +func parseFragmentTypeContent(m map[string]any) (string, string, string) { + typeName, _ := m["type"].(string) + content, _ := m["content"].(string) + return strings.ToUpper(typeName), content, strings.ToUpper(typeName) +} + +func appendContentPart(parts *[]ContentPart, content, kind string) { + if content == "" { + return + } + *parts = append(*parts, ContentPart{Text: content, Type: kind}) } func 
extractContentRecursive(items []any, defaultType string) ([]ContentPart, bool) { diff --git a/internal/sse/parser_edge_test.go b/internal/sse/parser_edge_test.go new file mode 100644 index 0000000..c851c1f --- /dev/null +++ b/internal/sse/parser_edge_test.go @@ -0,0 +1,631 @@ +package sse + +import "testing" + +// ─── ParseDeepSeekSSELine edge cases ───────────────────────────────── + +func TestParseDeepSeekSSELineEmptyLine(t *testing.T) { + _, _, ok := ParseDeepSeekSSELine([]byte("")) + if ok { + t.Fatal("expected not parsed for empty line") + } +} + +func TestParseDeepSeekSSELineNoDataPrefix(t *testing.T) { + _, _, ok := ParseDeepSeekSSELine([]byte("event: message")) + if ok { + t.Fatal("expected not parsed for non-data line") + } +} + +func TestParseDeepSeekSSELineInvalidJSON(t *testing.T) { + _, _, ok := ParseDeepSeekSSELine([]byte("data: {invalid json")) + if ok { + t.Fatal("expected not parsed for invalid JSON") + } +} + +func TestParseDeepSeekSSELineWhitespaceOnly(t *testing.T) { + _, _, ok := ParseDeepSeekSSELine([]byte(" ")) + if ok { + t.Fatal("expected not parsed for whitespace-only line") + } +} + +func TestParseDeepSeekSSELineDataWithExtraSpaces(t *testing.T) { + chunk, done, ok := ParseDeepSeekSSELine([]byte(`data: {"v":"hello"} `)) + if !ok || done { + t.Fatalf("expected parsed chunk for spaced data line") + } + if chunk["v"] != "hello" { + t.Fatalf("unexpected chunk: %#v", chunk) + } +} + +// ─── shouldSkipPath edge cases ─────────────────────────────────────── + +func TestShouldSkipPathQuasiStatus(t *testing.T) { + if !shouldSkipPath("response/quasi_status") { + t.Fatal("expected skip for quasi_status path") + } +} + +func TestShouldSkipPathElapsedSecs(t *testing.T) { + if !shouldSkipPath("response/elapsed_secs") { + t.Fatal("expected skip for elapsed_secs path") + } +} + +func TestShouldSkipPathTokenUsage(t *testing.T) { + if !shouldSkipPath("response/token_usage") { + t.Fatal("expected skip for token_usage path") + } +} + +func 
TestShouldSkipPathPendingFragment(t *testing.T) { + if !shouldSkipPath("response/pending_fragment") { + t.Fatal("expected skip for pending_fragment path") + } +} + +func TestShouldSkipPathConversationMode(t *testing.T) { + if !shouldSkipPath("response/conversation_mode") { + t.Fatal("expected skip for conversation_mode path") + } +} + +func TestShouldSkipPathSearchStatus(t *testing.T) { + if !shouldSkipPath("response/search_status") { + t.Fatal("expected skip for search_status path") + } +} + +func TestShouldSkipPathFragmentStatus(t *testing.T) { + if !shouldSkipPath("response/fragments/-1/status") { + t.Fatal("expected skip for fragment -1 status") + } + if !shouldSkipPath("response/fragments/-2/status") { + t.Fatal("expected skip for fragment -2 status") + } + if !shouldSkipPath("response/fragments/-3/status") { + t.Fatal("expected skip for fragment -3 status") + } +} + +func TestShouldSkipPathRegularContent(t *testing.T) { + if shouldSkipPath("response/content") { + t.Fatal("expected not skip for content path") + } + if shouldSkipPath("response/thinking_content") { + t.Fatal("expected not skip for thinking_content path") + } +} + +// ─── ParseSSEChunkForContent edge cases ────────────────────────────── + +func TestParseSSEChunkForContentNoVField(t *testing.T) { + parts, finished, nextType := ParseSSEChunkForContent(map[string]any{"p": "response/content"}, false, "text") + if finished { + t.Fatal("expected not finished") + } + if len(parts) != 0 { + t.Fatalf("expected no parts when v is missing, got %#v", parts) + } + if nextType != "text" { + t.Fatalf("expected type preserved, got %q", nextType) + } +} + +func TestParseSSEChunkForContentSkippedPath(t *testing.T) { + parts, finished, nextType := ParseSSEChunkForContent(map[string]any{ + "p": "response/token_usage", + "v": "some data", + }, false, "text") + if finished || len(parts) > 0 { + t.Fatalf("expected skipped path to produce no output") + } + if nextType != "text" { + t.Fatalf("expected type preserved for 
skipped path") + } +} + +func TestParseSSEChunkForContentFinishedStatus(t *testing.T) { + parts, finished, _ := ParseSSEChunkForContent(map[string]any{ + "p": "response/status", + "v": "FINISHED", + }, false, "text") + if !finished { + t.Fatal("expected finished on status FINISHED") + } + if len(parts) != 0 { + t.Fatalf("expected no parts on finished, got %d", len(parts)) + } +} + +func TestParseSSEChunkForContentStatusNotFinished(t *testing.T) { + parts, finished, _ := ParseSSEChunkForContent(map[string]any{ + "p": "response/status", + "v": "IN_PROGRESS", + }, false, "text") + if finished { + t.Fatal("expected not finished for non-FINISHED status") + } + if len(parts) != 1 || parts[0].Text != "IN_PROGRESS" { + t.Fatalf("expected content for non-FINISHED status, got %#v", parts) + } +} + +func TestParseSSEChunkForContentEmptyStringV(t *testing.T) { + parts, finished, _ := ParseSSEChunkForContent(map[string]any{ + "p": "response/content", + "v": "", + }, false, "text") + if finished { + t.Fatal("expected not finished") + } + if len(parts) != 0 { + t.Fatalf("expected no parts for empty string v, got %#v", parts) + } +} + +func TestParseSSEChunkForContentFinishedOnEmptyPath(t *testing.T) { + parts, finished, _ := ParseSSEChunkForContent(map[string]any{ + "p": "", + "v": "FINISHED", + }, false, "text") + if !finished { + t.Fatal("expected finished on empty path with FINISHED value") + } + if len(parts) != 0 { + t.Fatalf("expected no parts on finished") + } +} + +func TestParseSSEChunkForContentFinishedOnStatusPath(t *testing.T) { + _, finished, _ := ParseSSEChunkForContent(map[string]any{ + "p": "status", + "v": "FINISHED", + }, false, "text") + if !finished { + t.Fatal("expected finished on status path with FINISHED value") + } +} + +func TestParseSSEChunkForContentThinkingPathEmptyPath(t *testing.T) { + parts, _, nextType := ParseSSEChunkForContent(map[string]any{ + "v": "some thought", + }, true, "thinking") + if len(parts) != 1 || parts[0].Type != "thinking" { + 
t.Fatalf("expected thinking part on empty path, got %#v", parts) + } + if nextType != "thinking" { + t.Fatalf("expected nextType thinking, got %q", nextType) + } +} + +func TestParseSSEChunkForContentThinkingEnabledTextType(t *testing.T) { + parts, _, nextType := ParseSSEChunkForContent(map[string]any{ + "v": "text content", + }, true, "text") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected text part when currentType=text, got %#v", parts) + } + if nextType != "text" { + t.Fatalf("expected nextType text, got %q", nextType) + } +} + +// ─── ParseSSEChunkForContent: fragments path with THINK type ───────── + +func TestParseSSEChunkForContentFragmentsAppendThink(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments", + "o": "APPEND", + "v": []any{ + map[string]any{ + "type": "THINK", + "content": "深入思考...", + }, + }, + } + parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "text") + if finished { + t.Fatal("expected not finished") + } + if nextType != "thinking" { + t.Fatalf("expected nextType thinking, got %q", nextType) + } + if len(parts) != 1 || parts[0].Type != "thinking" || parts[0].Text != "深入思考..." 
{ + t.Fatalf("unexpected parts: %#v", parts) + } +} + +func TestParseSSEChunkForContentFragmentsAppendEmptyContent(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments", + "o": "APPEND", + "v": []any{ + map[string]any{ + "type": "RESPONSE", + "content": "", + }, + }, + } + parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "thinking") + if finished { + t.Fatal("expected not finished") + } + if nextType != "text" { + t.Fatalf("expected nextType text, got %q", nextType) + } + if len(parts) != 0 { + t.Fatalf("expected no parts for empty content, got %#v", parts) + } +} + +func TestParseSSEChunkForContentFragmentsAppendDefaultType(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments", + "o": "APPEND", + "v": []any{ + map[string]any{ + "type": "UNKNOWN", + "content": "some text", + }, + }, + } + parts, _, _ := ParseSSEChunkForContent(chunk, true, "text") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected text type for unknown fragment type, got %#v", parts) + } +} + +func TestParseSSEChunkForContentFragmentsAppendNonArray(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments", + "o": "APPEND", + "v": "not an array", + } + parts, finished, _ := ParseSSEChunkForContent(chunk, true, "text") + if finished { + t.Fatal("expected not finished") + } + // "not an array" should be treated as string value at the end + if len(parts) != 1 || parts[0].Text != "not an array" { + t.Fatalf("unexpected parts: %#v", parts) + } +} + +func TestParseSSEChunkForContentFragmentsAppendNonMap(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments", + "o": "APPEND", + "v": []any{"string item"}, + } + parts, _, _ := ParseSSEChunkForContent(chunk, false, "text") + // Non-map items in fragment array are skipped; the []any itself is handled later + _ = parts // just checking it doesn't panic +} + +// ─── ParseSSEChunkForContent: response path with nested fragment ───── + +func 
TestParseSSEChunkForContentResponsePathFragmentsAppend(t *testing.T) { + chunk := map[string]any{ + "p": "response", + "v": []any{ + map[string]any{ + "p": "fragments", + "o": "APPEND", + "v": []any{ + map[string]any{ + "type": "THINKING", + }, + }, + }, + }, + } + _, _, nextType := ParseSSEChunkForContent(chunk, true, "text") + if nextType != "thinking" { + t.Fatalf("expected nextType thinking from response path fragments, got %q", nextType) + } +} + +func TestParseSSEChunkForContentResponsePathResponseFragment(t *testing.T) { + chunk := map[string]any{ + "p": "response", + "v": []any{ + map[string]any{ + "p": "fragments", + "o": "APPEND", + "v": []any{ + map[string]any{ + "type": "RESPONSE", + }, + }, + }, + }, + } + _, _, nextType := ParseSSEChunkForContent(chunk, true, "thinking") + if nextType != "text" { + t.Fatalf("expected nextType text for RESPONSE fragment, got %q", nextType) + } +} + +// ─── ParseSSEChunkForContent: map value with wrapped response ──────── + +func TestParseSSEChunkForContentMapValueWithFragments(t *testing.T) { + chunk := map[string]any{ + "v": map[string]any{ + "response": map[string]any{ + "fragments": []any{ + map[string]any{ + "type": "THINK", + "content": "思考...", + }, + map[string]any{ + "type": "RESPONSE", + "content": "回答...", + }, + }, + }, + }, + } + parts, finished, nextType := ParseSSEChunkForContent(chunk, true, "text") + if finished { + t.Fatal("expected not finished") + } + if nextType != "text" { + t.Fatalf("expected nextType text after RESPONSE, got %q", nextType) + } + if len(parts) != 2 { + t.Fatalf("expected 2 parts, got %d: %#v", len(parts), parts) + } + if parts[0].Type != "thinking" || parts[0].Text != "思考..." { + t.Fatalf("first part mismatch: %#v", parts[0]) + } + if parts[1].Type != "text" || parts[1].Text != "回答..." 
{ + t.Fatalf("second part mismatch: %#v", parts[1]) + } +} + +func TestParseSSEChunkForContentMapValueDirectFragments(t *testing.T) { + chunk := map[string]any{ + "v": map[string]any{ + "fragments": []any{ + map[string]any{ + "type": "RESPONSE", + "content": "直接回答", + }, + }, + }, + } + parts, _, _ := ParseSSEChunkForContent(chunk, false, "text") + if len(parts) != 1 || parts[0].Text != "直接回答" || parts[0].Type != "text" { + t.Fatalf("unexpected parts for direct fragments: %#v", parts) + } +} + +func TestParseSSEChunkForContentMapValueUnknownType(t *testing.T) { + chunk := map[string]any{ + "v": map[string]any{ + "fragments": []any{ + map[string]any{ + "type": "CUSTOM", + "content": "custom content", + }, + }, + }, + } + parts, _, _ := ParseSSEChunkForContent(chunk, false, "text") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected partType fallback for unknown type, got %#v", parts) + } +} + +func TestParseSSEChunkForContentMapValueEmptyFragmentContent(t *testing.T) { + chunk := map[string]any{ + "v": map[string]any{ + "fragments": []any{ + map[string]any{ + "type": "RESPONSE", + "content": "", + }, + }, + }, + } + parts, _, _ := ParseSSEChunkForContent(chunk, false, "text") + if len(parts) != 0 { + t.Fatalf("expected no parts for empty fragment content, got %#v", parts) + } +} + +// ─── ParseSSEChunkForContent: fragments/-1/content path ────────────── + +func TestParseSSEChunkForContentFragmentContentPathInheritsType(t *testing.T) { + chunk := map[string]any{ + "p": "response/fragments/-1/content", + "v": "继续思考", + } + parts, _, _ := ParseSSEChunkForContent(chunk, true, "thinking") + if len(parts) != 1 || parts[0].Type != "thinking" { + t.Fatalf("expected inherited thinking type, got %#v", parts) + } +} + +// ─── IsCitation edge cases ─────────────────────────────────────────── + +func TestIsCitationWithLeadingWhitespace(t *testing.T) { + if !IsCitation(" [citation:2] text") { + t.Fatal("expected citation true with leading whitespace") + } +} + 
+func TestIsCitationEmpty(t *testing.T) { + if IsCitation("") { + t.Fatal("expected citation false for empty string") + } +} + +func TestIsCitationSimilarPrefix(t *testing.T) { + if IsCitation("[cite:1] text") { + t.Fatal("expected citation false for [cite: prefix") + } +} + +// ─── extractContentRecursive edge cases ────────────────────────────── + +func TestExtractContentRecursiveFinishedStatus(t *testing.T) { + items := []any{ + map[string]any{"p": "status", "v": "FINISHED"}, + } + parts, finished := extractContentRecursive(items, "text") + if !finished { + t.Fatal("expected finished on status FINISHED") + } + if len(parts) != 0 { + t.Fatalf("expected no parts, got %#v", parts) + } +} + +func TestExtractContentRecursiveSkipsPath(t *testing.T) { + items := []any{ + map[string]any{"p": "token_usage", "v": "data"}, + } + parts, finished := extractContentRecursive(items, "text") + if finished { + t.Fatal("expected not finished") + } + if len(parts) != 0 { + t.Fatalf("expected no parts for skipped path, got %#v", parts) + } +} + +func TestExtractContentRecursiveContentField(t *testing.T) { + items := []any{ + map[string]any{"p": "x", "v": "val", "content": "actual content", "type": "RESPONSE"}, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 1 || parts[0].Text != "actual content" || parts[0].Type != "text" { + t.Fatalf("unexpected parts: %#v", parts) + } +} + +func TestExtractContentRecursiveContentFieldThinkType(t *testing.T) { + items := []any{ + map[string]any{"p": "x", "v": "val", "content": "think text", "type": "THINK"}, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 1 || parts[0].Type != "thinking" { + t.Fatalf("expected thinking type for THINK content, got %#v", parts) + } +} + +func TestExtractContentRecursiveThinkingPath(t *testing.T) { + items := []any{ + map[string]any{"p": "thinking_content", "v": "deep thought"}, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 1 || 
parts[0].Type != "thinking" || parts[0].Text != "deep thought" { + t.Fatalf("unexpected parts for thinking path: %#v", parts) + } +} + +func TestExtractContentRecursiveContentPath(t *testing.T) { + items := []any{ + map[string]any{"p": "content", "v": "text content"}, + } + parts, _ := extractContentRecursive(items, "thinking") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected text type for content path, got %#v", parts) + } +} + +func TestExtractContentRecursiveResponsePath(t *testing.T) { + items := []any{ + map[string]any{"p": "response", "v": "text content"}, + } + parts, _ := extractContentRecursive(items, "thinking") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected text type for response path, got %#v", parts) + } +} + +func TestExtractContentRecursiveFragmentsPath(t *testing.T) { + items := []any{ + map[string]any{"p": "fragments", "v": "fragment text"}, + } + parts, _ := extractContentRecursive(items, "thinking") + if len(parts) != 1 || parts[0].Type != "text" { + t.Fatalf("expected text type for fragments path, got %#v", parts) + } +} + +func TestExtractContentRecursiveNestedArrayWithTypes(t *testing.T) { + items := []any{ + map[string]any{ + "p": "fragments", + "v": []any{ + map[string]any{"content": "thought", "type": "THINKING"}, + map[string]any{"content": "answer", "type": "RESPONSE"}, + "raw string", + }, + }, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 3 { + t.Fatalf("expected 3 parts, got %d: %#v", len(parts), parts) + } + if parts[0].Type != "thinking" || parts[0].Text != "thought" { + t.Fatalf("first part mismatch: %#v", parts[0]) + } + if parts[1].Type != "text" || parts[1].Text != "answer" { + t.Fatalf("second part mismatch: %#v", parts[1]) + } + if parts[2].Type != "text" || parts[2].Text != "raw string" { + t.Fatalf("third part mismatch: %#v", parts[2]) + } +} + +func TestExtractContentRecursiveEmptyContentSkipped(t *testing.T) { + items := []any{ + map[string]any{ + 
"p": "fragments", + "v": []any{ + map[string]any{"content": "", "type": "RESPONSE"}, + }, + }, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 0 { + t.Fatalf("expected no parts for empty nested content, got %#v", parts) + } +} + +func TestExtractContentRecursiveFinishedString(t *testing.T) { + items := []any{ + map[string]any{"p": "content", "v": "FINISHED"}, + } + parts, _ := extractContentRecursive(items, "text") + // "FINISHED" string value on non-status path should be skipped + if len(parts) != 0 { + t.Fatalf("expected FINISHED string to be skipped, got %#v", parts) + } +} + +func TestExtractContentRecursiveNoVField(t *testing.T) { + items := []any{ + map[string]any{"p": "content"}, + } + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 0 { + t.Fatalf("expected no parts for missing v field, got %#v", parts) + } +} + +func TestExtractContentRecursiveNonMapItem(t *testing.T) { + items := []any{"just a string", 42} + parts, _ := extractContentRecursive(items, "text") + if len(parts) != 0 { + t.Fatalf("expected no parts for non-map items, got %#v", parts) + } +} diff --git a/internal/sse/stream_edge_test.go b/internal/sse/stream_edge_test.go new file mode 100644 index 0000000..927b023 --- /dev/null +++ b/internal/sse/stream_edge_test.go @@ -0,0 +1,177 @@ +package sse + +import ( + "context" + "io" + "strings" + "testing" +) + +func TestStartParsedLinePumpEmptyBody(t *testing.T) { + body := strings.NewReader("") + results, done := StartParsedLinePump(context.Background(), body, false, "text") + + collected := make([]LineResult, 0) + for r := range results { + collected = append(collected, r) + } + if err := <-done; err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(collected) != 0 { + t.Fatalf("expected no results for empty body, got %d", len(collected)) + } +} + +func TestStartParsedLinePumpMultipleLines(t *testing.T) { + body := strings.NewReader( + "data: 
{\"p\":\"response/thinking_content\",\"v\":\"think\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"text\"}\n" + + "data: [DONE]\n", + ) + results, done := StartParsedLinePump(context.Background(), body, true, "thinking") + + collected := make([]LineResult, 0) + for r := range results { + collected = append(collected, r) + } + if err := <-done; err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(collected) < 3 { + t.Fatalf("expected at least 3 results, got %d", len(collected)) + } + // First should be thinking + if collected[0].Parts[0].Type != "thinking" { + t.Fatalf("expected first part thinking, got %q", collected[0].Parts[0].Type) + } + // Last should be stop + last := collected[len(collected)-1] + if !last.Stop { + t.Fatal("expected last result to be stop") + } +} + +func TestStartParsedLinePumpTypeTracking(t *testing.T) { + body := strings.NewReader( + "data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"THINK\",\"content\":\"思\"}]}\n" + + "data: {\"p\":\"response/fragments/-1/content\",\"v\":\"考\"}\n" + + "data: {\"p\":\"response/fragments\",\"o\":\"APPEND\",\"v\":[{\"type\":\"RESPONSE\",\"content\":\"答\"}]}\n" + + "data: {\"p\":\"response/fragments/-1/content\",\"v\":\"案\"}\n" + + "data: [DONE]\n", + ) + results, done := StartParsedLinePump(context.Background(), body, true, "text") + + types := make([]string, 0) + for r := range results { + for _, p := range r.Parts { + types = append(types, p.Type) + } + } + <-done + + // Should have: thinking, thinking, text, text + expected := []string{"thinking", "thinking", "text", "text"} + if len(types) != len(expected) { + t.Fatalf("expected types %v, got %v", expected, types) + } + for i, want := range expected { + if types[i] != want { + t.Fatalf("type[%d] mismatch: want %q got %q (all=%v)", i, want, types[i], types) + } + } +} + +func TestStartParsedLinePumpContextCancellation(t *testing.T) { + pr, pw := io.Pipe() + + ctx, cancel := context.WithCancel(context.Background()) + 
results, done := StartParsedLinePump(ctx, pr, false, "text") + + // Write one line to allow it to start + go func() { + _, _ = io.WriteString(pw, "data: {\"p\":\"response/content\",\"v\":\"hello\"}\n") + // Don't close yet - wait for context cancel + }() + + // Read first result + r := <-results + if !r.Parsed || len(r.Parts) == 0 { + t.Fatalf("expected first parsed result, got %#v", r) + } + + // Cancel context - this will cause the pump to exit on next send + cancel() + // Close the pipe to unblock scanner.Scan() + pw.Close() + + // Drain remaining results + for range results { + } + + err := <-done + // Error may be context.Canceled or nil (if pipe closed first) + if err != nil && err != context.Canceled { + t.Fatalf("expected context.Canceled or nil error, got %v", err) + } +} + +func TestStartParsedLinePumpOnlyDONE(t *testing.T) { + body := strings.NewReader("data: [DONE]\n") + results, done := StartParsedLinePump(context.Background(), body, false, "text") + + collected := make([]LineResult, 0) + for r := range results { + collected = append(collected, r) + } + <-done + + if len(collected) != 1 { + t.Fatalf("expected 1 result, got %d", len(collected)) + } + if !collected[0].Stop { + t.Fatal("expected stop on [DONE]") + } +} + +func TestStartParsedLinePumpNonSSELines(t *testing.T) { + body := strings.NewReader( + "event: update\n" + + ": comment line\n" + + "data: {\"p\":\"response/content\",\"v\":\"valid\"}\n" + + "data: [DONE]\n", + ) + results, done := StartParsedLinePump(context.Background(), body, false, "text") + + var validCount int + for r := range results { + if r.Parsed && len(r.Parts) > 0 { + validCount++ + } + } + <-done + + if validCount != 1 { + t.Fatalf("expected 1 valid result, got %d", validCount) + } +} + +func TestStartParsedLinePumpThinkingDisabled(t *testing.T) { + body := strings.NewReader( + "data: {\"p\":\"response/thinking_content\",\"v\":\"thought\"}\n" + + "data: {\"p\":\"response/content\",\"v\":\"response\"}\n" + + "data: [DONE]\n", 
+ ) + // With thinking disabled, thinking content should still be emitted but marked differently + results, done := StartParsedLinePump(context.Background(), body, false, "text") + + var parts []ContentPart + for r := range results { + parts = append(parts, r.Parts...) + } + <-done + + if len(parts) < 1 { + t.Fatalf("expected at least 1 part, got %d", len(parts)) + } +} diff --git a/internal/stream/engine.go b/internal/stream/engine.go new file mode 100644 index 0000000..c63cd7b --- /dev/null +++ b/internal/stream/engine.go @@ -0,0 +1,128 @@ +package stream + +import ( + "context" + "io" + "time" + + "ds2api/internal/sse" +) + +type StopReason string + +const ( + StopReasonNone StopReason = "" + StopReasonContextCancelled StopReason = "context_cancelled" + StopReasonNoContentTimeout StopReason = "no_content_timeout" + StopReasonIdleTimeout StopReason = "idle_timeout" + StopReasonUpstreamCompleted StopReason = "upstream_completed" + StopReasonHandlerRequested StopReason = "handler_requested" +) + +type ConsumeConfig struct { + Context context.Context + Body io.Reader + ThinkingEnabled bool + InitialType string + KeepAliveInterval time.Duration + IdleTimeout time.Duration + MaxKeepAliveNoInput int +} + +type ParsedDecision struct { + Stop bool + StopReason StopReason + ContentSeen bool +} + +type ConsumeHooks struct { + OnParsed func(parsed sse.LineResult) ParsedDecision + OnKeepAlive func() + OnFinalize func(reason StopReason, scannerErr error) + OnContextDone func() +} + +func ConsumeSSE(cfg ConsumeConfig, hooks ConsumeHooks) { + if cfg.Context == nil { + cfg.Context = context.Background() + } + initialType := cfg.InitialType + if initialType == "" { + if cfg.ThinkingEnabled { + initialType = "thinking" + } else { + initialType = "text" + } + } + parsedLines, done := sse.StartParsedLinePump(cfg.Context, cfg.Body, cfg.ThinkingEnabled, initialType) + + var ticker *time.Ticker + if cfg.KeepAliveInterval > 0 { + ticker = time.NewTicker(cfg.KeepAliveInterval) + defer 
ticker.Stop() + } + + hasContent := false + lastContent := time.Now() + keepaliveCount := 0 + + finalize := func(reason StopReason, scannerErr error) { + if hooks.OnFinalize != nil { + hooks.OnFinalize(reason, scannerErr) + } + } + + for { + select { + case <-cfg.Context.Done(): + if hooks.OnContextDone != nil { + hooks.OnContextDone() + } + return + case <-tickCh(ticker): + if !hasContent { + keepaliveCount++ + if cfg.MaxKeepAliveNoInput > 0 && keepaliveCount >= cfg.MaxKeepAliveNoInput { + finalize(StopReasonNoContentTimeout, nil) + return + } + } + if hasContent && cfg.IdleTimeout > 0 && time.Since(lastContent) > cfg.IdleTimeout { + finalize(StopReasonIdleTimeout, nil) + return + } + if hooks.OnKeepAlive != nil { + hooks.OnKeepAlive() + } + case parsed, ok := <-parsedLines: + if !ok { + finalize(StopReasonUpstreamCompleted, <-done) + return + } + if hooks.OnParsed == nil { + continue + } + decision := hooks.OnParsed(parsed) + if decision.ContentSeen { + hasContent = true + lastContent = time.Now() + keepaliveCount = 0 + } + if decision.Stop { + reason := decision.StopReason + if reason == StopReasonNone { + reason = StopReasonHandlerRequested + } + finalize(reason, nil) + return + } + } + } +} + +func tickCh(ticker *time.Ticker) <-chan time.Time { + if ticker == nil { + return nil + } + return ticker.C +} diff --git a/internal/testsuite/edge_cases.go b/internal/testsuite/edge_cases.go index cba0b5a..50bc8ac 100644 --- a/internal/testsuite/edge_cases.go +++ b/internal/testsuite/edge_cases.go @@ -1,7 +1,6 @@ package testsuite import ( - "bytes" "context" "encoding/json" "fmt" @@ -125,72 +124,6 @@ func (r *Runner) caseStreamAbortRelease(ctx context.Context, cc *caseContext) er return nil } -func (cc *caseContext) abortStreamRequest(ctx context.Context, spec requestSpec) error { - cc.seq++ - traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), cc.seq) - cc.traceIDsSet[traceID] = struct{}{} - fullURL, err := 
withTraceQuery(cc.runner.baseURL+spec.Path, traceID) - if err != nil { - return err - } - headers := map[string]string{} - for k, v := range spec.Headers { - headers[k] = v - } - headers["X-Ds2-Test-Trace"] = traceID - bodyBytes, _ := json.Marshal(spec.Body) - headers["Content-Type"] = "application/json" - cc.requests = append(cc.requests, requestLog{ - Seq: cc.seq, - Attempt: 1, - TraceID: traceID, - Method: spec.Method, - URL: fullURL, - Headers: headers, - Body: spec.Body, - Timestamp: time.Now().Format(time.RFC3339Nano), - }) - - reqCtx, cancel := context.WithTimeout(ctx, cc.runner.opts.Timeout) - defer cancel() - req, err := http.NewRequestWithContext(reqCtx, spec.Method, fullURL, bytes.NewReader(bodyBytes)) - if err != nil { - return err - } - for k, v := range headers { - req.Header.Set(k, v) - } - start := time.Now() - resp, err := cc.runner.httpClient.Do(req) - if err != nil { - cc.responses = append(cc.responses, responseLog{ - Seq: cc.seq, - Attempt: 1, - TraceID: traceID, - StatusCode: 0, - DurationMS: time.Since(start).Milliseconds(), - NetworkErr: err.Error(), - ReceivedAt: time.Now().Format(time.RFC3339Nano), - }) - return err - } - defer resp.Body.Close() - buf := make([]byte, 512) - _, _ = resp.Body.Read(buf) - _ = resp.Body.Close() - cc.responses = append(cc.responses, responseLog{ - Seq: cc.seq, - Attempt: 1, - TraceID: traceID, - StatusCode: resp.StatusCode, - Headers: resp.Header, - BodyText: "aborted_after_first_chunk", - DurationMS: time.Since(start).Milliseconds(), - ReceivedAt: time.Now().Format(time.RFC3339Nano), - }) - return nil -} - func (r *Runner) caseToolcallStreamMixed(ctx context.Context, cc *caseContext) error { payload := toolcallPayload(true) payload["messages"] = []map[string]any{ @@ -293,167 +226,6 @@ func (r *Runner) caseSSEJSONIntegrity(ctx context.Context, cc *caseContext) erro return nil } -func (r *Runner) caseInvalidModel(ctx context.Context, cc *caseContext) error { - resp, err := cc.requestOnce(ctx, requestSpec{ - 
Method: http.MethodPost, - Path: "/v1/chat/completions", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - }, - Body: map[string]any{ - "model": "deepseek-not-exists", - "messages": []map[string]any{ - {"role": "user", "content": "hi"}, - }, - "stream": false, - }, - Retryable: false, - }, 1) - if err != nil { - return err - } - cc.assert("status_503", resp.StatusCode == http.StatusServiceUnavailable, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - e, _ := m["error"].(map[string]any) - cc.assert("error_type_service_unavailable", asString(e["type"]) == "service_unavailable_error", fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseMissingMessages(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/chat/completions", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - }, - Body: map[string]any{ - "model": "deepseek-chat", - "stream": false, - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_400", resp.StatusCode == http.StatusBadRequest, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - e, _ := m["error"].(map[string]any) - cc.assert("error_type_invalid_request", asString(e["type"]) == "invalid_request_error", fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseAdminUnauthorized(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodGet, - Path: "/admin/config", - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_401", resp.StatusCode == http.StatusUnauthorized, fmt.Sprintf("status=%d", resp.StatusCode)) - return nil -} - -func (r *Runner) caseTokenRefreshManagedAccount(ctx context.Context, cc *caseContext) error { - if len(r.configRaw.Accounts) == 0 { - 
cc.assert("account_present", false, "no account in config") - return nil - } - acc := r.configRaw.Accounts[0] - id := strings.TrimSpace(acc.Email) - if id == "" { - id = strings.TrimSpace(acc.Mobile) - } - if id == "" { - cc.assert("account_identifier", false, "first account has no identifier") - return nil - } - if strings.TrimSpace(acc.Password) == "" { - r.warnings = append(r.warnings, "token refresh edge case skipped strict check: first account password empty") - cc.assert("account_password_present", true, "skipped strict refresh check due empty password") - return nil - } - invalidToken := "invalid-testsuite-refresh-token-" + sanitizeID(r.runID) - update := map[string]any{ - "keys": r.configRaw.Keys, - "accounts": []map[string]any{ - { - "email": acc.Email, - "mobile": acc.Mobile, - "password": acc.Password, - "token": invalidToken, - }, - }, - } - updResp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/admin/config", - Headers: map[string]string{ - "Authorization": "Bearer " + r.adminJWT, - }, - Body: update, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("update_config_status_200", updResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", updResp.StatusCode)) - - chatResp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/chat/completions", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - "X-Ds2-Target-Account": id, - }, - Body: map[string]any{ - "model": "deepseek-chat", - "messages": []map[string]any{ - {"role": "user", "content": "token refresh test"}, - }, - "stream": false, - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("chat_status_200", chatResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d body=%s", chatResp.StatusCode, string(chatResp.Body))) - - cfgResp, err := cc.request(ctx, requestSpec{ - Method: http.MethodGet, - Path: "/admin/config", - Headers: map[string]string{ - "Authorization": "Bearer " + r.adminJWT, 
- }, - Retryable: true, - }) - if err != nil { - return err - } - var cfg map[string]any - _ = json.Unmarshal(cfgResp.Body, &cfg) - accounts, _ := cfg["accounts"].([]any) - preview := "" - hasToken := false - for _, item := range accounts { - m, _ := item.(map[string]any) - e := asString(m["email"]) - mo := asString(m["mobile"]) - if e == acc.Email && mo == acc.Mobile { - preview = asString(m["token_preview"]) - hasToken, _ = m["has_token"].(bool) - break - } - } - cc.assert("has_token_after_refresh", hasToken, fmt.Sprintf("config=%s", string(cfgResp.Body))) - cc.assert("token_preview_changed_from_invalid", !strings.HasPrefix(preview, invalidToken[:20]), fmt.Sprintf("preview=%s invalid_prefix=%s", preview, invalidToken[:20])) - return nil -} - func (r *Runner) fetchQueueStatus(ctx context.Context, cc *caseContext) (map[string]any, error) { resp, err := cc.request(ctx, requestSpec{ Method: http.MethodGet, diff --git a/internal/testsuite/edge_cases_abort.go b/internal/testsuite/edge_cases_abort.go new file mode 100644 index 0000000..2cc1fc1 --- /dev/null +++ b/internal/testsuite/edge_cases_abort.go @@ -0,0 +1,76 @@ +package testsuite + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "time" +) + +func (cc *caseContext) abortStreamRequest(ctx context.Context, spec requestSpec) error { + cc.seq++ + traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), cc.seq) + cc.traceIDsSet[traceID] = struct{}{} + fullURL, err := withTraceQuery(cc.runner.baseURL+spec.Path, traceID) + if err != nil { + return err + } + headers := map[string]string{} + for k, v := range spec.Headers { + headers[k] = v + } + headers["X-Ds2-Test-Trace"] = traceID + bodyBytes, _ := json.Marshal(spec.Body) + headers["Content-Type"] = "application/json" + cc.requests = append(cc.requests, requestLog{ + Seq: cc.seq, + Attempt: 1, + TraceID: traceID, + Method: spec.Method, + URL: fullURL, + Headers: headers, + Body: spec.Body, + Timestamp: 
time.Now().Format(time.RFC3339Nano), + }) + + reqCtx, cancel := context.WithTimeout(ctx, cc.runner.opts.Timeout) + defer cancel() + req, err := http.NewRequestWithContext(reqCtx, spec.Method, fullURL, bytes.NewReader(bodyBytes)) + if err != nil { + return err + } + for k, v := range headers { + req.Header.Set(k, v) + } + start := time.Now() + resp, err := cc.runner.httpClient.Do(req) + if err != nil { + cc.responses = append(cc.responses, responseLog{ + Seq: cc.seq, + Attempt: 1, + TraceID: traceID, + StatusCode: 0, + DurationMS: time.Since(start).Milliseconds(), + NetworkErr: err.Error(), + ReceivedAt: time.Now().Format(time.RFC3339Nano), + }) + return err + } + defer resp.Body.Close() + buf := make([]byte, 512) + _, _ = resp.Body.Read(buf) + _ = resp.Body.Close() + cc.responses = append(cc.responses, responseLog{ + Seq: cc.seq, + Attempt: 1, + TraceID: traceID, + StatusCode: resp.StatusCode, + Headers: resp.Header, + BodyText: "aborted_after_first_chunk", + DurationMS: time.Since(start).Milliseconds(), + ReceivedAt: time.Now().Format(time.RFC3339Nano), + }) + return nil +} diff --git a/internal/testsuite/edge_cases_error_contract.go b/internal/testsuite/edge_cases_error_contract.go new file mode 100644 index 0000000..d65ce6d --- /dev/null +++ b/internal/testsuite/edge_cases_error_contract.go @@ -0,0 +1,170 @@ +package testsuite + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" +) + +func (r *Runner) caseInvalidModel(ctx context.Context, cc *caseContext) error { + resp, err := cc.requestOnce(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "deepseek-not-exists", + "messages": []map[string]any{ + {"role": "user", "content": "hi"}, + }, + "stream": false, + }, + Retryable: false, + }, 1) + if err != nil { + return err + } + cc.assert("status_503", resp.StatusCode == http.StatusServiceUnavailable, 
fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + e, _ := m["error"].(map[string]any) + cc.assert("error_type_service_unavailable", asString(e["type"]) == "service_unavailable_error", fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + +func (r *Runner) caseMissingMessages(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "deepseek-chat", + "stream": false, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_400", resp.StatusCode == http.StatusBadRequest, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + e, _ := m["error"].(map[string]any) + cc.assert("error_type_invalid_request", asString(e["type"]) == "invalid_request_error", fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + +func (r *Runner) caseAdminUnauthorized(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodGet, + Path: "/admin/config", + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_401", resp.StatusCode == http.StatusUnauthorized, fmt.Sprintf("status=%d", resp.StatusCode)) + return nil +} + +func (r *Runner) caseTokenRefreshManagedAccount(ctx context.Context, cc *caseContext) error { + if len(r.configRaw.Accounts) == 0 { + cc.assert("account_present", false, "no account in config") + return nil + } + acc := r.configRaw.Accounts[0] + id := strings.TrimSpace(acc.Email) + if id == "" { + id = strings.TrimSpace(acc.Mobile) + } + if id == "" { + cc.assert("account_identifier", false, "first account has no identifier") + return nil + } + if strings.TrimSpace(acc.Password) == "" { + r.warnings = append(r.warnings, "token refresh edge case skipped strict 
check: first account password empty") + cc.assert("account_password_present", true, "skipped strict refresh check due empty password") + return nil + } + invalidToken := "invalid-testsuite-refresh-token-" + sanitizeID(r.runID) + update := map[string]any{ + "keys": r.configRaw.Keys, + "accounts": []map[string]any{ + { + "email": acc.Email, + "mobile": acc.Mobile, + "password": acc.Password, + "token": invalidToken, + }, + }, + } + updResp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/admin/config", + Headers: map[string]string{ + "Authorization": "Bearer " + r.adminJWT, + }, + Body: update, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("update_config_status_200", updResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", updResp.StatusCode)) + + chatResp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + "X-Ds2-Target-Account": id, + }, + Body: map[string]any{ + "model": "deepseek-chat", + "messages": []map[string]any{ + {"role": "user", "content": "token refresh test"}, + }, + "stream": false, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("chat_status_200", chatResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d body=%s", chatResp.StatusCode, string(chatResp.Body))) + + cfgResp, err := cc.request(ctx, requestSpec{ + Method: http.MethodGet, + Path: "/admin/config", + Headers: map[string]string{ + "Authorization": "Bearer " + r.adminJWT, + }, + Retryable: true, + }) + if err != nil { + return err + } + var cfg map[string]any + _ = json.Unmarshal(cfgResp.Body, &cfg) + accounts, _ := cfg["accounts"].([]any) + preview := "" + hasToken := false + for _, item := range accounts { + m, _ := item.(map[string]any) + e := asString(m["email"]) + mo := asString(m["mobile"]) + if e == acc.Email && mo == acc.Mobile { + preview = asString(m["token_preview"]) + hasToken, _ = 
m["has_token"].(bool) + break + } + } + cc.assert("has_token_after_refresh", hasToken, fmt.Sprintf("config=%s", string(cfgResp.Body))) + cc.assert("token_preview_changed_from_invalid", !strings.HasPrefix(preview, invalidToken[:20]), fmt.Sprintf("preview=%s invalid_prefix=%s", preview, invalidToken[:20])) + return nil +} diff --git a/internal/testsuite/runner.go b/internal/testsuite/runner.go deleted file mode 100644 index b48bce5..0000000 --- a/internal/testsuite/runner.go +++ /dev/null @@ -1,1640 +0,0 @@ -package testsuite - -import ( - "bytes" - "context" - "crypto/sha256" - "encoding/hex" - "encoding/json" - "errors" - "fmt" - "io" - "net" - "net/http" - "net/url" - "os" - "os/exec" - "path/filepath" - "runtime" - "sort" - "strconv" - "strings" - "sync" - "time" -) - -type Options struct { - ConfigPath string - AdminKey string - OutputDir string - Port int - Timeout time.Duration - Retries int - NoPreflight bool - MaxKeepRuns int -} - -type runSummary struct { - RunID string `json:"run_id"` - StartedAt string `json:"started_at"` - EndedAt string `json:"ended_at"` - DurationMS int64 `json:"duration_ms"` - Stats map[string]any `json:"stats"` - Environment map[string]any `json:"environment"` - Cases []caseResult `json:"cases"` - Warnings []string `json:"warnings,omitempty"` -} - -type caseResult struct { - CaseID string `json:"case_id"` - Passed bool `json:"passed"` - DurationMS int64 `json:"duration_ms"` - TraceIDs []string `json:"trace_ids"` - StatusCodes []int `json:"status_codes"` - Error string `json:"error,omitempty"` - ArtifactPath string `json:"artifact_path"` - Assertions []assertionResult `json:"assertions"` -} - -type assertionResult struct { - Name string `json:"name"` - Passed bool `json:"passed"` - Detail string `json:"detail,omitempty"` -} - -type requestLog struct { - Seq int `json:"seq"` - Attempt int `json:"attempt"` - TraceID string `json:"trace_id"` - Method string `json:"method"` - URL string `json:"url"` - Headers map[string]string 
`json:"headers"` - Body any `json:"body,omitempty"` - Timestamp string `json:"timestamp"` -} - -type responseLog struct { - Seq int `json:"seq"` - Attempt int `json:"attempt"` - TraceID string `json:"trace_id"` - StatusCode int `json:"status_code"` - Headers map[string][]string `json:"headers"` - BodyText string `json:"body_text"` - DurationMS int64 `json:"duration_ms"` - NetworkErr string `json:"network_error,omitempty"` - ReceivedAt string `json:"received_at"` -} - -type caseContext struct { - runner *Runner - id string - dir string - startedAt time.Time - mu sync.Mutex - seq int - assertions []assertionResult - requests []requestLog - responses []responseLog - streamRaw strings.Builder - traceIDsSet map[string]struct{} -} - -type requestSpec struct { - Method string - Path string - Headers map[string]string - Body any - Stream bool - Retryable bool -} - -type responseResult struct { - StatusCode int - Headers http.Header - Body []byte - TraceID string - URL string -} - -type Runner struct { - opts Options - - runID string - runDir string - serverLog string - preflightLog string - - baseURL string - httpClient *http.Client - serverCmd *exec.Cmd - serverLogFd *os.File - - configCopyPath string - originalConfigPath string - originalConfigHash string - - configRaw runConfig - apiKey string - adminKey string - adminJWT string - accountID string - - warnings []string - results []caseResult -} - -type runConfig struct { - Keys []string `json:"keys"` - Accounts []struct { - Email string `json:"email,omitempty"` - Mobile string `json:"mobile,omitempty"` - Password string `json:"password,omitempty"` - Token string `json:"token,omitempty"` - } `json:"accounts"` -} - -func DefaultOptions() Options { - return Options{ - ConfigPath: "config.json", - AdminKey: strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")), - OutputDir: "artifacts/testsuite", - Port: 0, - Timeout: 120 * time.Second, - Retries: 2, - NoPreflight: false, - MaxKeepRuns: 5, - } -} - -func Run(ctx context.Context, 
opts Options) error { - r, err := newRunner(opts) - if err != nil { - return err - } - start := time.Now() - defer func() { - _ = r.stopServer() - }() - - if err := r.prepareRunDir(); err != nil { - return err - } - - if !r.opts.NoPreflight { - if err := r.runPreflight(ctx); err != nil { - _ = r.writeSummary(start, time.Now()) - return err - } - } - - if err := r.prepareConfigIsolation(); err != nil { - _ = r.writeSummary(start, time.Now()) - return err - } - - if err := r.startServer(ctx); err != nil { - _ = r.writeSummary(start, time.Now()) - return err - } - - if err := r.prepareAuth(ctx); err != nil { - r.warnings = append(r.warnings, "auth prepare failed: "+err.Error()) - } - - for _, c := range r.cases() { - r.runCase(ctx, c) - } - - if err := r.ensureOriginalConfigUntouched(); err != nil { - r.warnings = append(r.warnings, err.Error()) - } - - end := time.Now() - if err := r.writeSummary(start, end); err != nil { - return err - } - - // Prune old test runs, keeping only the most recent N. 
- if err := r.pruneOldRuns(); err != nil { - r.warnings = append(r.warnings, "prune old runs: "+err.Error()) - } - - failed := 0 - for _, cs := range r.results { - if !cs.Passed { - failed++ - } - } - if failed > 0 { - return fmt.Errorf("testsuite failed: %d case(s) failed, see %s", failed, filepath.Join(r.runDir, "summary.md")) - } - return nil -} - -func newRunner(opts Options) (*Runner, error) { - if strings.TrimSpace(opts.ConfigPath) == "" { - opts.ConfigPath = "config.json" - } - if strings.TrimSpace(opts.OutputDir) == "" { - opts.OutputDir = "artifacts/testsuite" - } - if opts.Timeout <= 0 { - opts.Timeout = 120 * time.Second - } - if opts.Retries < 0 { - opts.Retries = 0 - } - adminKey := strings.TrimSpace(opts.AdminKey) - if adminKey == "" { - adminKey = strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")) - } - if adminKey == "" { - adminKey = "admin" - } - opts.AdminKey = adminKey - - return &Runner{ - opts: opts, - httpClient: &http.Client{ - Timeout: 0, - }, - runID: time.Now().UTC().Format("20060102T150405Z"), - adminKey: adminKey, - }, nil -} - -func (r *Runner) prepareRunDir() error { - r.runDir = filepath.Join(r.opts.OutputDir, r.runID) - if err := os.MkdirAll(r.runDir, 0o755); err != nil { - return err - } - if err := os.MkdirAll(filepath.Join(r.runDir, "cases"), 0o755); err != nil { - return err - } - r.serverLog = filepath.Join(r.runDir, "server.log") - r.preflightLog = filepath.Join(r.runDir, "preflight.log") - return nil -} - -// pruneOldRuns removes old test run directories, keeping the most recent MaxKeepRuns. -// Run IDs use the format "20060102T150405Z", so alphabetical order == chronological order. -func (r *Runner) pruneOldRuns() error { - keep := r.opts.MaxKeepRuns - if keep <= 0 { - return nil // 0 or negative means no pruning - } - - entries, err := os.ReadDir(r.opts.OutputDir) - if err != nil { - return err - } - - // Collect only directories (each run is a directory). 
- var runDirs []string - for _, e := range entries { - if !e.IsDir() { - continue - } - runDirs = append(runDirs, e.Name()) - } - - sort.Strings(runDirs) - - if len(runDirs) <= keep { - return nil - } - - // Remove oldest runs (those at the beginning of the sorted list). - toRemove := runDirs[:len(runDirs)-keep] - var errs []string - for _, name := range toRemove { - dirPath := filepath.Join(r.opts.OutputDir, name) - if err := os.RemoveAll(dirPath); err != nil { - errs = append(errs, fmt.Sprintf("remove %s: %v", name, err)) - } else { - fmt.Fprintf(os.Stdout, "pruned old test run: %s\n", name) - } - } - - if len(errs) > 0 { - return errors.New(strings.Join(errs, "; ")) - } - return nil -} - -func (r *Runner) runPreflight(ctx context.Context) error { - steps := [][]string{ - {"go", "test", "./...", "-count=1"}, - {"node", "--check", "api/chat-stream.js"}, - {"node", "--check", "api/helpers/stream-tool-sieve.js"}, - {"node", "--test", "api/helpers/stream-tool-sieve.test.js", "api/chat-stream.test.js"}, - {"npm", "run", "build", "--prefix", "webui"}, - } - f, err := os.OpenFile(r.preflightLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) - if err != nil { - return err - } - defer f.Close() - for _, step := range steps { - if _, err := fmt.Fprintf(f, "\n$ %s\n", strings.Join(step, " ")); err != nil { - return err - } - cmd := exec.CommandContext(ctx, step[0], step[1:]...) 
- cmd.Stdout = f - cmd.Stderr = f - if err := cmd.Run(); err != nil { - return fmt.Errorf("preflight failed at `%s`: %w", strings.Join(step, " "), err) - } - } - return nil -} - -func (r *Runner) prepareConfigIsolation() error { - abs, err := filepath.Abs(r.opts.ConfigPath) - if err != nil { - return err - } - r.originalConfigPath = abs - raw, err := os.ReadFile(abs) - if err != nil { - return err - } - sum := sha256.Sum256(raw) - r.originalConfigHash = hex.EncodeToString(sum[:]) - - tmpDir := filepath.Join(r.runDir, "tmp") - if err := os.MkdirAll(tmpDir, 0o755); err != nil { - return err - } - r.configCopyPath = filepath.Join(tmpDir, "config.json") - if err := os.WriteFile(r.configCopyPath, raw, 0o644); err != nil { - return err - } - var cfg runConfig - if err := json.Unmarshal(raw, &cfg); err != nil { - return fmt.Errorf("parse config failed: %w", err) - } - r.configRaw = cfg - if len(cfg.Keys) > 0 { - r.apiKey = strings.TrimSpace(cfg.Keys[0]) - } - for _, acc := range cfg.Accounts { - id := strings.TrimSpace(acc.Email) - if id == "" { - id = strings.TrimSpace(acc.Mobile) - } - if id != "" { - r.accountID = id - break - } - } - return nil -} - -func (r *Runner) startServer(ctx context.Context) error { - port := r.opts.Port - if port <= 0 { - p, err := findFreePort() - if err != nil { - return err - } - port = p - } - r.baseURL = "http://127.0.0.1:" + strconv.Itoa(port) - - logFd, err := os.OpenFile(r.serverLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) - if err != nil { - return err - } - r.serverLogFd = logFd - cmd := exec.CommandContext(ctx, "go", "run", "./cmd/ds2api") - cmd.Stdout = logFd - cmd.Stderr = logFd - cmd.Env = prepareServerEnv(os.Environ(), map[string]string{ - "PORT": strconv.Itoa(port), - "DS2API_CONFIG_PATH": r.configCopyPath, - "DS2API_AUTO_BUILD_WEBUI": "false", - "DS2API_CONFIG_JSON": "", - "CONFIG_JSON": "", - }) - if err := cmd.Start(); err != nil { - _ = logFd.Close() - return err - } - r.serverCmd = cmd - - deadline := 
time.Now().Add(90 * time.Second) - for time.Now().Before(deadline) { - if r.ping("/healthz") == nil && r.ping("/readyz") == nil { - return nil - } - time.Sleep(500 * time.Millisecond) - } - return errors.New("server readiness timeout") -} - -func (r *Runner) stopServer() error { - var errs []string - if r.serverCmd != nil && r.serverCmd.Process != nil { - _ = r.serverCmd.Process.Signal(os.Interrupt) - done := make(chan error, 1) - go func() { done <- r.serverCmd.Wait() }() - select { - case <-time.After(5 * time.Second): - _ = r.serverCmd.Process.Kill() - <-done - case <-done: - } - } - if r.serverLogFd != nil { - if err := r.serverLogFd.Close(); err != nil { - errs = append(errs, err.Error()) - } - } - if len(errs) > 0 { - return errors.New(strings.Join(errs, "; ")) - } - return nil -} - -func (r *Runner) ping(path string) error { - resp, err := r.httpClient.Get(r.baseURL + path) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("status=%d", resp.StatusCode) - } - return nil -} - -func (r *Runner) prepareAuth(ctx context.Context) error { - reqBody := map[string]any{ - "admin_key": r.adminKey, - "expire_hours": 24, - } - resp, err := r.doSimpleJSON(ctx, http.MethodPost, "/admin/login", nil, reqBody) - if err != nil { - return err - } - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("admin login status=%d body=%s", resp.StatusCode, string(resp.Body)) - } - var m map[string]any - if err := json.Unmarshal(resp.Body, &m); err != nil { - return err - } - token, _ := m["token"].(string) - if strings.TrimSpace(token) == "" { - return errors.New("empty admin jwt token") - } - r.adminJWT = token - return nil -} - -func (r *Runner) ensureOriginalConfigUntouched() error { - raw, err := os.ReadFile(r.originalConfigPath) - if err != nil { - return err - } - sum := sha256.Sum256(raw) - current := hex.EncodeToString(sum[:]) - if current != r.originalConfigHash { - return fmt.Errorf("original config 
changed unexpectedly: %s", r.originalConfigPath) - } - return nil -} - -func (r *Runner) runCase(ctx context.Context, c caseDef) { - caseDir := filepath.Join(r.runDir, "cases", c.ID) - _ = os.MkdirAll(caseDir, 0o755) - cc := &caseContext{ - runner: r, - id: c.ID, - dir: caseDir, - startedAt: time.Now(), - traceIDsSet: map[string]struct{}{}, - } - err := c.Run(ctx, cc) - duration := time.Since(cc.startedAt).Milliseconds() - - if err != nil { - cc.assertions = append(cc.assertions, assertionResult{ - Name: "case_error", - Passed: false, - Detail: err.Error(), - }) - } - passed := err == nil - for _, a := range cc.assertions { - if !a.Passed { - passed = false - break - } - } - - traceIDs := make([]string, 0, len(cc.traceIDsSet)) - for t := range cc.traceIDsSet { - traceIDs = append(traceIDs, t) - } - sort.Strings(traceIDs) - statuses := uniqueStatusCodes(cc.responses) - cs := caseResult{ - CaseID: c.ID, - Passed: passed, - DurationMS: duration, - TraceIDs: traceIDs, - StatusCodes: statuses, - ArtifactPath: caseDir, - Assertions: cc.assertions, - } - if err != nil { - cs.Error = err.Error() - } - _ = cc.flushArtifacts(cs) - r.results = append(r.results, cs) -} - -func (cc *caseContext) assert(name string, ok bool, detail string) { - cc.mu.Lock() - defer cc.mu.Unlock() - cc.assertions = append(cc.assertions, assertionResult{ - Name: name, - Passed: ok, - Detail: detail, - }) -} - -func (cc *caseContext) request(ctx context.Context, spec requestSpec) (*responseResult, error) { - retries := cc.runner.opts.Retries - if !spec.Retryable { - retries = 0 - } - var lastErr error - for attempt := 1; attempt <= retries+1; attempt++ { - resp, err := cc.requestOnce(ctx, spec, attempt) - if err == nil && resp.StatusCode < 500 { - return resp, nil - } - if err != nil { - lastErr = err - } else if resp.StatusCode >= 500 { - lastErr = fmt.Errorf("status=%d", resp.StatusCode) - } - if attempt <= retries { - sleep := time.Duration(300*(1<<(attempt-1))) * time.Millisecond - 
time.Sleep(sleep) - } - } - return nil, lastErr -} - -func (cc *caseContext) requestOnce(ctx context.Context, spec requestSpec, attempt int) (*responseResult, error) { - cc.mu.Lock() - cc.seq++ - seq := cc.seq - traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), seq) - cc.traceIDsSet[traceID] = struct{}{} - cc.mu.Unlock() - - fullURL, err := withTraceQuery(cc.runner.baseURL+spec.Path, traceID) - if err != nil { - return nil, err - } - - headers := map[string]string{} - for k, v := range spec.Headers { - headers[k] = v - } - headers["X-Ds2-Test-Trace"] = traceID - - var bodyBytes []byte - var bodyAny any - if spec.Body != nil { - b, err := json.Marshal(spec.Body) - if err != nil { - return nil, err - } - bodyBytes = b - bodyAny = spec.Body - headers["Content-Type"] = "application/json" - } - cc.mu.Lock() - cc.requests = append(cc.requests, requestLog{ - Seq: seq, - Attempt: attempt, - TraceID: traceID, - Method: spec.Method, - URL: fullURL, - Headers: headers, - Body: bodyAny, - Timestamp: time.Now().Format(time.RFC3339Nano), - }) - cc.mu.Unlock() - - reqCtx, cancel := context.WithTimeout(ctx, cc.runner.opts.Timeout) - defer cancel() - req, err := http.NewRequestWithContext(reqCtx, spec.Method, fullURL, bytes.NewReader(bodyBytes)) - if err != nil { - return nil, err - } - for k, v := range headers { - req.Header.Set(k, v) - } - start := time.Now() - resp, err := cc.runner.httpClient.Do(req) - if err != nil { - cc.mu.Lock() - cc.responses = append(cc.responses, responseLog{ - Seq: seq, - Attempt: attempt, - TraceID: traceID, - StatusCode: 0, - DurationMS: time.Since(start).Milliseconds(), - NetworkErr: err.Error(), - ReceivedAt: time.Now().Format(time.RFC3339Nano), - }) - cc.mu.Unlock() - return nil, err - } - defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) - - cc.mu.Lock() - cc.responses = append(cc.responses, responseLog{ - Seq: seq, - Attempt: attempt, - TraceID: traceID, - StatusCode: resp.StatusCode, - Headers: resp.Header, - 
BodyText: string(body), - DurationMS: time.Since(start).Milliseconds(), - ReceivedAt: time.Now().Format(time.RFC3339Nano), - }) - - if spec.Stream { - cc.streamRaw.WriteString(fmt.Sprintf("### trace=%s url=%s\n", traceID, fullURL)) - cc.streamRaw.Write(body) - cc.streamRaw.WriteString("\n\n") - } - cc.mu.Unlock() - - return &responseResult{ - StatusCode: resp.StatusCode, - Headers: resp.Header, - Body: body, - TraceID: traceID, - URL: fullURL, - }, nil -} - -func (cc *caseContext) flushArtifacts(cs caseResult) error { - requestPath := filepath.Join(cc.dir, "request.json") - headersPath := filepath.Join(cc.dir, "response.headers") - bodyPath := filepath.Join(cc.dir, "response.body") - streamPath := filepath.Join(cc.dir, "stream.raw") - assertPath := filepath.Join(cc.dir, "assertions.json") - metaPath := filepath.Join(cc.dir, "meta.json") - - if err := writeJSONFile(requestPath, cc.requests); err != nil { - return err - } - respHeaders := make([]map[string]any, 0, len(cc.responses)) - respBodies := make([]map[string]any, 0, len(cc.responses)) - for _, r := range cc.responses { - respHeaders = append(respHeaders, map[string]any{ - "seq": r.Seq, - "attempt": r.Attempt, - "trace_id": r.TraceID, - "status_code": r.StatusCode, - "headers": r.Headers, - }) - respBodies = append(respBodies, map[string]any{ - "seq": r.Seq, - "attempt": r.Attempt, - "trace_id": r.TraceID, - "status_code": r.StatusCode, - "body_text": r.BodyText, - "network_error": r.NetworkErr, - "duration_ms": r.DurationMS, - }) - } - if err := writeJSONFile(headersPath, respHeaders); err != nil { - return err - } - if err := writeJSONFile(bodyPath, respBodies); err != nil { - return err - } - if err := os.WriteFile(streamPath, []byte(cc.streamRaw.String()), 0o644); err != nil { - return err - } - if err := writeJSONFile(assertPath, cc.assertions); err != nil { - return err - } - meta := map[string]any{ - "case_id": cs.CaseID, - "trace_id": strings.Join(cs.TraceIDs, ","), - "attempt": len(cc.responses), - 
"duration_ms": cs.DurationMS, - "status": map[bool]string{true: "passed", false: "failed"}[cs.Passed], - "status_codes": cs.StatusCodes, - "assertions": cs.Assertions, - "artifact_path": cs.ArtifactPath, - } - return writeJSONFile(metaPath, meta) -} - -type caseDef struct { - ID string - Run func(context.Context, *caseContext) error -} - -func (r *Runner) cases() []caseDef { - return []caseDef{ - {ID: "healthz_ok", Run: r.caseHealthz}, - {ID: "readyz_ok", Run: r.caseReadyz}, - {ID: "models_openai", Run: r.caseModelsOpenAI}, - {ID: "models_claude", Run: r.caseModelsClaude}, - {ID: "admin_login_verify", Run: r.caseAdminLoginVerify}, - {ID: "admin_queue_status", Run: r.caseAdminQueueStatus}, - {ID: "chat_nonstream_basic", Run: r.caseChatNonstream}, - {ID: "chat_stream_basic", Run: r.caseChatStream}, - {ID: "reasoner_stream", Run: r.caseReasonerStream}, - {ID: "toolcall_nonstream", Run: r.caseToolcallNonstream}, - {ID: "toolcall_stream", Run: r.caseToolcallStream}, - {ID: "anthropic_messages_nonstream", Run: r.caseAnthropicNonstream}, - {ID: "anthropic_messages_stream", Run: r.caseAnthropicStream}, - {ID: "anthropic_count_tokens", Run: r.caseAnthropicCountTokens}, - {ID: "admin_account_test_single", Run: r.caseAdminAccountTest}, - {ID: "concurrency_burst", Run: r.caseConcurrencyBurst}, - {ID: "concurrency_threshold_limit", Run: r.caseConcurrencyThresholdLimit}, - {ID: "stream_abort_release", Run: r.caseStreamAbortRelease}, - {ID: "toolcall_stream_mixed", Run: r.caseToolcallStreamMixed}, - {ID: "sse_json_integrity", Run: r.caseSSEJSONIntegrity}, - {ID: "error_contract_invalid_model", Run: r.caseInvalidModel}, - {ID: "error_contract_missing_messages", Run: r.caseMissingMessages}, - {ID: "admin_unauthorized_contract", Run: r.caseAdminUnauthorized}, - {ID: "config_write_isolated", Run: r.caseConfigWriteIsolated}, - {ID: "token_refresh_managed_account", Run: r.caseTokenRefreshManagedAccount}, - {ID: "error_contract_invalid_key", Run: r.caseInvalidKey}, - } -} - -func (r 
*Runner) caseHealthz(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/healthz", Retryable: true}) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - cc.assert("status_ok", asString(m["status"]) == "ok", fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseReadyz(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/readyz", Retryable: true}) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - cc.assert("status_ready", asString(m["status"]) == "ready", fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseModelsOpenAI(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/v1/models", Retryable: true}) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - ids := extractModelIDs(resp.Body) - cc.assert("has_deepseek_chat", contains(ids, "deepseek-chat"), strings.Join(ids, ",")) - cc.assert("has_deepseek_reasoner", contains(ids, "deepseek-reasoner"), strings.Join(ids, ",")) - return nil -} - -func (r *Runner) caseModelsClaude(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/anthropic/v1/models", Retryable: true}) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - ids := extractModelIDs(resp.Body) - cc.assert("non_empty", len(ids) > 0, fmt.Sprintf("models=%v", ids)) - return nil 
-} - -func (r *Runner) caseAdminLoginVerify(ctx context.Context, cc *caseContext) error { - loginResp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/admin/login", - Body: map[string]any{"admin_key": r.adminKey, "expire_hours": 24}, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("login_status_200", loginResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", loginResp.StatusCode)) - var payload map[string]any - _ = json.Unmarshal(loginResp.Body, &payload) - token := asString(payload["token"]) - cc.assert("token_exists", token != "", fmt.Sprintf("body=%s", string(loginResp.Body))) - if token == "" { - return nil - } - verifyResp, err := cc.request(ctx, requestSpec{ - Method: http.MethodGet, - Path: "/admin/verify", - Headers: map[string]string{ - "Authorization": "Bearer " + token, - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("verify_status_200", verifyResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", verifyResp.StatusCode)) - var v map[string]any - _ = json.Unmarshal(verifyResp.Body, &v) - valid, _ := v["valid"].(bool) - cc.assert("verify_valid_true", valid, fmt.Sprintf("body=%s", string(verifyResp.Body))) - return nil -} - -func (r *Runner) caseAdminQueueStatus(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodGet, - Path: "/admin/queue/status", - Headers: map[string]string{ - "Authorization": "Bearer " + r.adminJWT, - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - _, hasRec := m["recommended_concurrency"] - _, hasQueue := m["max_queue_size"] - cc.assert("has_recommended_concurrency", hasRec, fmt.Sprintf("body=%s", string(resp.Body))) - cc.assert("has_max_queue_size", hasQueue, fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - 
-func (r *Runner) caseChatNonstream(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/chat/completions", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - }, - Body: map[string]any{ - "model": "deepseek-chat", - "messages": []map[string]any{ - {"role": "user", "content": "请简单回复一句话"}, - }, - "stream": false, - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - cc.assert("object_chat_completion", asString(m["object"]) == "chat.completion", fmt.Sprintf("body=%s", string(resp.Body))) - choices, _ := m["choices"].([]any) - cc.assert("choices_non_empty", len(choices) > 0, fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseChatStream(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/chat/completions", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - }, - Body: map[string]any{ - "model": "deepseek-chat", - "messages": []map[string]any{ - {"role": "user", "content": "请流式回复一句话"}, - }, - "stream": true, - }, - Stream: true, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - frames, done := parseSSEFrames(resp.Body) - cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames))) - cc.assert("done_terminated", done, "expected [DONE]") - return nil -} - -func (r *Runner) caseReasonerStream(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/chat/completions", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - }, - Body: map[string]any{ - 
"model": "deepseek-reasoner", - "messages": []map[string]any{ - {"role": "user", "content": "先思考后回答:1+1"}, - }, - "stream": true, - }, - Stream: true, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - frames, done := parseSSEFrames(resp.Body) - hasReasoning := false - for _, f := range frames { - choices, _ := f["choices"].([]any) - for _, c := range choices { - ch, _ := c.(map[string]any) - delta, _ := ch["delta"].(map[string]any) - if asString(delta["reasoning_content"]) != "" { - hasReasoning = true - } - } - } - cc.assert("has_reasoning_content", hasReasoning, "reasoning_content not found") - cc.assert("done_terminated", done, "expected [DONE]") - return nil -} - -func (r *Runner) caseToolcallNonstream(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/chat/completions", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - }, - Body: toolcallPayload(false), - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - choices, _ := m["choices"].([]any) - if len(choices) == 0 { - cc.assert("choices_non_empty", false, fmt.Sprintf("body=%s", string(resp.Body))) - return nil - } - c0, _ := choices[0].(map[string]any) - cc.assert("finish_reason_tool_calls", asString(c0["finish_reason"]) == "tool_calls", fmt.Sprintf("body=%s", string(resp.Body))) - msg, _ := c0["message"].(map[string]any) - tc, _ := msg["tool_calls"].([]any) - cc.assert("tool_calls_present", len(tc) > 0, fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseToolcallStream(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: 
"/v1/chat/completions", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - }, - Body: toolcallPayload(true), - Stream: true, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - frames, done := parseSSEFrames(resp.Body) - hasTool := false - rawLeak := false - for _, f := range frames { - choices, _ := f["choices"].([]any) - for _, c := range choices { - ch, _ := c.(map[string]any) - delta, _ := ch["delta"].(map[string]any) - if _, ok := delta["tool_calls"]; ok { - hasTool = true - } - content := asString(delta["content"]) - if strings.Contains(strings.ToLower(content), `"tool_calls"`) { - rawLeak = true - } - } - } - cc.assert("tool_calls_delta_present", hasTool, "tool_calls delta missing") - cc.assert("no_raw_tool_json_leak", !rawLeak, "raw tool_calls JSON leaked in content") - cc.assert("done_terminated", done, "expected [DONE]") - return nil -} - -func (r *Runner) caseAnthropicNonstream(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/anthropic/v1/messages", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - "anthropic-version": "2023-06-01", - "content-type": "application/json", - }, - Body: map[string]any{ - "model": "claude-sonnet-4-5", - "messages": []map[string]any{ - {"role": "user", "content": "hello"}, - }, - "stream": false, - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - cc.assert("type_message", asString(m["type"]) == "message", fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseAnthropicStream(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - 
Path: "/anthropic/v1/messages", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - "anthropic-version": "2023-06-01", - "content-type": "application/json", - }, - Body: map[string]any{ - "model": "claude-sonnet-4-5", - "messages": []map[string]any{ - {"role": "user", "content": "stream hello"}, - }, - "stream": true, - }, - Stream: true, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - events := parseClaudeStreamEvents(resp.Body) - cc.assert("has_message_start", contains(events, "message_start"), fmt.Sprintf("events=%v", events)) - cc.assert("has_message_stop", contains(events, "message_stop"), fmt.Sprintf("events=%v", events)) - return nil -} - -func (r *Runner) caseAnthropicCountTokens(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/anthropic/v1/messages/count_tokens", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - "anthropic-version": "2023-06-01", - "content-type": "application/json", - }, - Body: map[string]any{ - "model": "claude-sonnet-4-5", - "messages": []map[string]any{ - {"role": "user", "content": "count me"}, - }, - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - v := toInt(m["input_tokens"]) - cc.assert("input_tokens_gt_zero", v > 0, fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseAdminAccountTest(ctx context.Context, cc *caseContext) error { - if strings.TrimSpace(r.accountID) == "" { - cc.assert("account_present", false, "no account in config") - return nil - } - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/admin/accounts/test", - Headers: map[string]string{ - 
"Authorization": "Bearer " + r.adminJWT, - }, - Body: map[string]any{ - "identifier": r.accountID, - "model": "deepseek-chat", - "message": "ping", - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - ok, _ := m["success"].(bool) - cc.assert("success_true", ok, fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseConcurrencyBurst(ctx context.Context, cc *caseContext) error { - accountCount := len(r.configRaw.Accounts) - n := accountCount*2 + 2 - if n < 2 { - n = 2 - } - type one struct { - Status int - Err string - } - results := make([]one, n) - var wg sync.WaitGroup - for i := 0; i < n; i++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/chat/completions", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - }, - Body: map[string]any{ - "model": "deepseek-chat", - "messages": []map[string]any{ - {"role": "user", "content": fmt.Sprintf("并发请求 #%d,请回复ok", idx)}, - }, - "stream": true, - }, - Stream: true, - Retryable: true, - }) - if err != nil { - results[idx] = one{Err: err.Error()} - return - } - results[idx] = one{Status: resp.StatusCode} - }(i) - } - wg.Wait() - - dist := map[int]int{} - success := 0 - for _, it := range results { - if it.Status > 0 { - dist[it.Status]++ - if it.Status == http.StatusOK { - success++ - } - } - } - cc.assert("success_gt_zero", success > 0, fmt.Sprintf("distribution=%v", dist)) - _, has5xx := has5xx(dist) - cc.assert("no_5xx", !has5xx, fmt.Sprintf("distribution=%v", dist)) - if err := r.ping("/healthz"); err != nil { - cc.assert("server_alive", false, err.Error()) - } else { - cc.assert("server_alive", true, "") - } - return nil -} - -func (r *Runner) caseConfigWriteIsolated(ctx context.Context, cc *caseContext) error { - 
k := "testsuite-temp-" + sanitizeID(r.runID) - add, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/admin/keys", - Headers: map[string]string{ - "Authorization": "Bearer " + r.adminJWT, - }, - Body: map[string]any{"key": k}, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("add_key_status_200", add.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", add.StatusCode)) - - cfg1, err := cc.request(ctx, requestSpec{ - Method: http.MethodGet, - Path: "/admin/config", - Headers: map[string]string{ - "Authorization": "Bearer " + r.adminJWT, - }, - Retryable: true, - }) - if err != nil { - return err - } - containsAdded := strings.Contains(string(cfg1.Body), k) - cc.assert("key_present_in_isolated_config", containsAdded, "added key not found in isolated config") - - delPath := "/admin/keys/" + url.PathEscape(k) - del, err := cc.request(ctx, requestSpec{ - Method: http.MethodDelete, - Path: delPath, - Headers: map[string]string{ - "Authorization": "Bearer " + r.adminJWT, - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("delete_key_status_200", del.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", del.StatusCode)) - - cfg2, err := cc.request(ctx, requestSpec{ - Method: http.MethodGet, - Path: "/admin/config", - Headers: map[string]string{ - "Authorization": "Bearer " + r.adminJWT, - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("key_removed_in_isolated_config", !strings.Contains(string(cfg2.Body), k), "temporary key still present") - - if err := r.ensureOriginalConfigUntouched(); err != nil { - cc.assert("original_config_unchanged", false, err.Error()) - } else { - cc.assert("original_config_unchanged", true, "") - } - return nil -} - -func (r *Runner) caseInvalidKey(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/chat/completions", - Headers: map[string]string{ - "Authorization": "Bearer 
invalid-testsuite-key-" + sanitizeID(r.runID), - }, - Body: map[string]any{ - "model": "deepseek-chat", - "messages": []map[string]any{ - {"role": "user", "content": "hi"}, - }, - "stream": false, - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_401", resp.StatusCode == http.StatusUnauthorized, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - e, _ := m["error"].(map[string]any) - cc.assert("error_object_present", len(e) > 0, fmt.Sprintf("body=%s", string(resp.Body))) - cc.assert("error_message_present", asString(e["message"]) != "", fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) doSimpleJSON(ctx context.Context, method, path string, headers map[string]string, body any) (*responseResult, error) { - cc := &caseContext{ - runner: r, - id: "auth_prepare", - traceIDsSet: map[string]struct{}{}, - } - return cc.request(ctx, requestSpec{ - Method: method, - Path: path, - Headers: headers, - Body: body, - Retryable: true, - }) -} - -func (r *Runner) writeSummary(start, end time.Time) error { - passed := 0 - failed := 0 - for _, cs := range r.results { - if cs.Passed { - passed++ - } else { - failed++ - } - } - summary := runSummary{ - RunID: r.runID, - StartedAt: start.Format(time.RFC3339Nano), - EndedAt: end.Format(time.RFC3339Nano), - DurationMS: end.Sub(start).Milliseconds(), - Stats: map[string]any{ - "total": len(r.results), - "passed": passed, - "failed": failed, - }, - Environment: map[string]any{ - "go_version": runtime.Version(), - "os": runtime.GOOS, - "arch": runtime.GOARCH, - "base_url": r.baseURL, - "config_source": r.originalConfigPath, - "config_isolated": r.configCopyPath, - "server_log": r.serverLog, - "preflight_log": r.preflightLog, - "retries": r.opts.Retries, - "timeout_seconds": int(r.opts.Timeout.Seconds()), - }, - Cases: r.results, - Warnings: r.warnings, - } - if err := writeJSONFile(filepath.Join(r.runDir, "summary.json"), 
summary); err != nil { - return err - } - return os.WriteFile(filepath.Join(r.runDir, "summary.md"), []byte(r.summaryMarkdown(summary)), 0o644) -} - -func (r *Runner) summaryMarkdown(s runSummary) string { - var b strings.Builder - b.WriteString("# DS2API Live Testsuite Summary\n\n") - b.WriteString("**Sensitive Notice:** this run stores full raw request/response logs. Do not share artifacts publicly.\n\n") - fmt.Fprintf(&b, "- Run ID: `%s`\n", s.RunID) - fmt.Fprintf(&b, "- Started: `%s`\n", s.StartedAt) - fmt.Fprintf(&b, "- Ended: `%s`\n", s.EndedAt) - fmt.Fprintf(&b, "- Duration: `%d ms`\n", s.DurationMS) - fmt.Fprintf(&b, "- Passed/Failed: `%d/%d`\n\n", s.Stats["passed"], s.Stats["failed"]) - if len(s.Warnings) > 0 { - b.WriteString("## Warnings\n\n") - for _, w := range s.Warnings { - fmt.Fprintf(&b, "- %s\n", w) - } - b.WriteString("\n") - } - b.WriteString("## Failed Cases\n\n") - hasFailed := false - for _, c := range s.Cases { - if c.Passed { - continue - } - hasFailed = true - fmt.Fprintf(&b, "- `%s`: %s\n", c.CaseID, c.Error) - if len(c.TraceIDs) > 0 { - fmt.Fprintf(&b, " - trace_ids: `%s`\n", strings.Join(c.TraceIDs, ", ")) - fmt.Fprintf(&b, " - grep: `rg \"%s\" %s`\n", c.TraceIDs[0], filepath.Join(r.runDir, "server.log")) - } - fmt.Fprintf(&b, " - artifact: `%s`\n", c.ArtifactPath) - } - if !hasFailed { - b.WriteString("- none\n") - } - b.WriteString("\n## Case Table\n\n") - b.WriteString("| case_id | status | duration_ms | statuses | artifact |\n") - b.WriteString("|---|---:|---:|---|---|\n") - for _, c := range s.Cases { - status := "PASS" - if !c.Passed { - status = "FAIL" - } - fmt.Fprintf(&b, "| %s | %s | %d | %v | `%s` |\n", c.CaseID, status, c.DurationMS, c.StatusCodes, c.ArtifactPath) - } - return b.String() -} - -func toolcallPayload(stream bool) map[string]any { - return map[string]any{ - "model": "deepseek-chat", - "messages": []map[string]any{ - { - "role": "user", - "content": "你必须调用工具 search 查询 golang,并仅返回工具调用。", - }, - }, - "tools": 
[]map[string]any{ - { - "type": "function", - "function": map[string]any{ - "name": "search", - "description": "search documents", - "parameters": map[string]any{ - "type": "object", - "properties": map[string]any{ - "q": map[string]any{ - "type": "string", - }, - }, - "required": []string{"q"}, - }, - }, - }, - }, - "stream": stream, - } -} - -func parseSSEFrames(body []byte) ([]map[string]any, bool) { - lines := strings.Split(string(body), "\n") - frames := make([]map[string]any, 0, len(lines)) - done := false - for _, line := range lines { - line = strings.TrimSpace(line) - if !strings.HasPrefix(line, "data:") { - continue - } - payload := strings.TrimSpace(strings.TrimPrefix(line, "data:")) - if payload == "" { - continue - } - if payload == "[DONE]" { - done = true - continue - } - var m map[string]any - if err := json.Unmarshal([]byte(payload), &m); err == nil { - frames = append(frames, m) - } - } - return frames, done -} - -func parseClaudeStreamEvents(body []byte) []string { - events := []string{} - seen := map[string]bool{} - lines := strings.Split(string(body), "\n") - for _, line := range lines { - line = strings.TrimSpace(line) - if !strings.HasPrefix(line, "data:") { - continue - } - payload := strings.TrimSpace(strings.TrimPrefix(line, "data:")) - if payload == "" { - continue - } - var m map[string]any - if err := json.Unmarshal([]byte(payload), &m); err != nil { - continue - } - t := asString(m["type"]) - if t == "" || seen[t] { - continue - } - seen[t] = true - events = append(events, t) - } - return events -} - -func extractModelIDs(body []byte) []string { - var m map[string]any - if err := json.Unmarshal(body, &m); err != nil { - return nil - } - out := []string{} - data, _ := m["data"].([]any) - for _, it := range data { - item, _ := it.(map[string]any) - id := asString(item["id"]) - if id != "" { - out = append(out, id) - } - } - return out -} - -func withTraceQuery(rawURL, traceID string) (string, error) { - u, err := url.Parse(rawURL) - if 
err != nil { - return "", err - } - q := u.Query() - q.Set("__trace_id", traceID) - u.RawQuery = q.Encode() - return u.String(), nil -} - -func writeJSONFile(path string, v any) error { - b, err := json.MarshalIndent(v, "", " ") - if err != nil { - return err - } - return os.WriteFile(path, b, 0o644) -} - -func prepareServerEnv(base []string, overrides map[string]string) []string { - out := make([]string, 0, len(base)+len(overrides)) - skip := map[string]struct{}{} - for k := range overrides { - skip[k] = struct{}{} - } - for _, e := range base { - parts := strings.SplitN(e, "=", 2) - if len(parts) != 2 { - continue - } - if _, ok := skip[parts[0]]; ok { - continue - } - out = append(out, e) - } - for k, v := range overrides { - out = append(out, k+"="+v) - } - return out -} - -func findFreePort() (int, error) { - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - return 0, err - } - defer ln.Close() - addr, ok := ln.Addr().(*net.TCPAddr) - if !ok { - return 0, errors.New("failed to detect tcp port") - } - return addr.Port, nil -} - -func uniqueStatusCodes(in []responseLog) []int { - set := map[int]struct{}{} - for _, it := range in { - if it.StatusCode > 0 { - set[it.StatusCode] = struct{}{} - } - } - out := make([]int, 0, len(set)) - for k := range set { - out = append(out, k) - } - sort.Ints(out) - return out -} - -func has5xx(dist map[int]int) (int, bool) { - for k := range dist { - if k >= 500 { - return k, true - } - } - return 0, false -} - -func sanitizeID(s string) string { - s = strings.ReplaceAll(s, ":", "_") - s = strings.ReplaceAll(s, "/", "_") - s = strings.ReplaceAll(s, " ", "_") - return s -} - -func asString(v any) string { - if v == nil { - return "" - } - switch x := v.(type) { - case string: - return strings.TrimSpace(x) - default: - return strings.TrimSpace(fmt.Sprintf("%v", v)) - } -} - -func toInt(v any) int { - switch x := v.(type) { - case float64: - return int(x) - case float32: - return int(x) - case int: - return x - case 
int64: - return int(x) - default: - return 0 - } -} - -func contains(xs []string, target string) bool { - for _, x := range xs { - if x == target { - return true - } - } - return false -} diff --git a/internal/testsuite/runner_cases_admin.go b/internal/testsuite/runner_cases_admin.go new file mode 100644 index 0000000..d66adea --- /dev/null +++ b/internal/testsuite/runner_cases_admin.go @@ -0,0 +1,161 @@ +package testsuite + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" +) + +func (r *Runner) caseAdminLoginVerify(ctx context.Context, cc *caseContext) error { + loginResp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/admin/login", + Body: map[string]any{"admin_key": r.adminKey, "expire_hours": 24}, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("login_status_200", loginResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", loginResp.StatusCode)) + var payload map[string]any + _ = json.Unmarshal(loginResp.Body, &payload) + token := asString(payload["token"]) + cc.assert("token_exists", token != "", fmt.Sprintf("body=%s", string(loginResp.Body))) + if token == "" { + return nil + } + verifyResp, err := cc.request(ctx, requestSpec{ + Method: http.MethodGet, + Path: "/admin/verify", + Headers: map[string]string{ + "Authorization": "Bearer " + token, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("verify_status_200", verifyResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", verifyResp.StatusCode)) + var v map[string]any + _ = json.Unmarshal(verifyResp.Body, &v) + valid, _ := v["valid"].(bool) + cc.assert("verify_valid_true", valid, fmt.Sprintf("body=%s", string(verifyResp.Body))) + return nil +} + +func (r *Runner) caseAdminQueueStatus(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodGet, + Path: "/admin/queue/status", + Headers: map[string]string{ + "Authorization": "Bearer " + 
r.adminJWT, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + _, hasRec := m["recommended_concurrency"] + _, hasQueue := m["max_queue_size"] + cc.assert("has_recommended_concurrency", hasRec, fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("has_max_queue_size", hasQueue, fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} +func (r *Runner) caseAdminAccountTest(ctx context.Context, cc *caseContext) error { + if strings.TrimSpace(r.accountID) == "" { + cc.assert("account_present", false, "no account in config") + return nil + } + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/admin/accounts/test", + Headers: map[string]string{ + "Authorization": "Bearer " + r.adminJWT, + }, + Body: map[string]any{ + "identifier": r.accountID, + "model": "deepseek-chat", + "message": "ping", + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + ok, _ := m["success"].(bool) + cc.assert("success_true", ok, fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} +func (r *Runner) caseConfigWriteIsolated(ctx context.Context, cc *caseContext) error { + k := "testsuite-temp-" + sanitizeID(r.runID) + add, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/admin/keys", + Headers: map[string]string{ + "Authorization": "Bearer " + r.adminJWT, + }, + Body: map[string]any{"key": k}, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("add_key_status_200", add.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", add.StatusCode)) + + cfg1, err := cc.request(ctx, requestSpec{ + Method: http.MethodGet, + Path: "/admin/config", + Headers: map[string]string{ + 
"Authorization": "Bearer " + r.adminJWT, + }, + Retryable: true, + }) + if err != nil { + return err + } + containsAdded := strings.Contains(string(cfg1.Body), k) + cc.assert("key_present_in_isolated_config", containsAdded, "added key not found in isolated config") + + delPath := "/admin/keys/" + url.PathEscape(k) + del, err := cc.request(ctx, requestSpec{ + Method: http.MethodDelete, + Path: delPath, + Headers: map[string]string{ + "Authorization": "Bearer " + r.adminJWT, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("delete_key_status_200", del.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", del.StatusCode)) + + cfg2, err := cc.request(ctx, requestSpec{ + Method: http.MethodGet, + Path: "/admin/config", + Headers: map[string]string{ + "Authorization": "Bearer " + r.adminJWT, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("key_removed_in_isolated_config", !strings.Contains(string(cfg2.Body), k), "temporary key still present") + + if err := r.ensureOriginalConfigUntouched(); err != nil { + cc.assert("original_config_unchanged", false, err.Error()) + } else { + cc.assert("original_config_unchanged", true, "") + } + return nil +} diff --git a/internal/testsuite/runner_cases_claude.go b/internal/testsuite/runner_cases_claude.go new file mode 100644 index 0000000..590e524 --- /dev/null +++ b/internal/testsuite/runner_cases_claude.go @@ -0,0 +1,103 @@ +package testsuite + +import ( + "context" + "encoding/json" + "fmt" + "net/http" +) + +func (r *Runner) caseModelsClaude(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/anthropic/v1/models", Retryable: true}) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + ids := extractModelIDs(resp.Body) + cc.assert("non_empty", len(ids) > 0, fmt.Sprintf("models=%v", ids)) + return nil +} +func (r *Runner) 
caseAnthropicNonstream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/anthropic/v1/messages", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + Body: map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{ + {"role": "user", "content": "hello"}, + }, + "stream": false, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + cc.assert("type_message", asString(m["type"]) == "message", fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + +func (r *Runner) caseAnthropicStream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/anthropic/v1/messages", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + Body: map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{ + {"role": "user", "content": "stream hello"}, + }, + "stream": true, + }, + Stream: true, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + events := parseClaudeStreamEvents(resp.Body) + cc.assert("has_message_start", contains(events, "message_start"), fmt.Sprintf("events=%v", events)) + cc.assert("has_message_stop", contains(events, "message_stop"), fmt.Sprintf("events=%v", events)) + return nil +} + +func (r *Runner) caseAnthropicCountTokens(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/anthropic/v1/messages/count_tokens", + Headers: 
map[string]string{ + "Authorization": "Bearer " + r.apiKey, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + Body: map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{ + {"role": "user", "content": "count me"}, + }, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + v := toInt(m["input_tokens"]) + cc.assert("input_tokens_gt_zero", v > 0, fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} diff --git a/internal/testsuite/runner_cases_openai.go b/internal/testsuite/runner_cases_openai.go new file mode 100644 index 0000000..4ca2e40 --- /dev/null +++ b/internal/testsuite/runner_cases_openai.go @@ -0,0 +1,221 @@ +package testsuite + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" +) + +func (r *Runner) caseHealthz(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/healthz", Retryable: true}) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + cc.assert("status_ok", asString(m["status"]) == "ok", fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + +func (r *Runner) caseReadyz(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/readyz", Retryable: true}) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + cc.assert("status_ready", asString(m["status"]) == "ready", fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + +func (r *Runner) caseModelsOpenAI(ctx 
context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/v1/models", Retryable: true}) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + ids := extractModelIDs(resp.Body) + cc.assert("has_deepseek_chat", contains(ids, "deepseek-chat"), strings.Join(ids, ",")) + cc.assert("has_deepseek_reasoner", contains(ids, "deepseek-reasoner"), strings.Join(ids, ",")) + return nil +} + +func (r *Runner) caseModelOpenAIByID(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/v1/models/gpt-4o", Retryable: true}) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + cc.assert("object_model", asString(m["object"]) == "model", fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("id_deepseek_chat", asString(m["id"]) == "deepseek-chat", fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} +func (r *Runner) caseChatNonstream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "deepseek-chat", + "messages": []map[string]any{ + {"role": "user", "content": "请简单回复一句话"}, + }, + "stream": false, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + cc.assert("object_chat_completion", asString(m["object"]) == "chat.completion", fmt.Sprintf("body=%s", string(resp.Body))) + choices, _ := m["choices"].([]any) + cc.assert("choices_non_empty", len(choices) > 
0, fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + +func (r *Runner) caseChatStream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "deepseek-chat", + "messages": []map[string]any{ + {"role": "user", "content": "请流式回复一句话"}, + }, + "stream": true, + }, + Stream: true, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + frames, done := parseSSEFrames(resp.Body) + cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames))) + cc.assert("done_terminated", done, "expected [DONE]") + return nil +} + +func (r *Runner) caseResponsesNonstream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/responses", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "gpt-4o", + "input": "请简要回答 hello", + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + cc.assert("object_response", asString(m["object"]) == "response", fmt.Sprintf("body=%s", string(resp.Body))) + responseID := asString(m["id"]) + cc.assert("response_id_present", responseID != "", fmt.Sprintf("body=%s", string(resp.Body))) + if responseID != "" { + getResp, getErr := cc.request(ctx, requestSpec{ + Method: http.MethodGet, + Path: "/v1/responses/" + responseID, + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Retryable: true, + }) + if getErr != nil { + return getErr + } + cc.assert("get_status_200", getResp.StatusCode == 
http.StatusOK, fmt.Sprintf("status=%d", getResp.StatusCode)) + } + return nil +} + +func (r *Runner) caseResponsesStream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/responses", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "gpt-4o", + "input": "请流式回答 hello", + "stream": true, + }, + Stream: true, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + frames, done := parseSSEFrames(resp.Body) + cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames))) + hasCreated := false + hasCompleted := false + for _, f := range frames { + switch asString(f["type"]) { + case "response.created": + hasCreated = true + case "response.completed": + hasCompleted = true + } + } + cc.assert("has_response_created", hasCreated, fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("has_response_completed", hasCompleted, fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("done_terminated", done, "expected [DONE]") + return nil +} + +func (r *Runner) caseEmbeddings(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/embeddings", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "gpt-4o", + "input": []string{"hello", "world"}, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200_or_501", resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusNotImplemented, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + if resp.StatusCode == http.StatusOK { + cc.assert("object_list", asString(m["object"]) == "list", fmt.Sprintf("body=%s", string(resp.Body))) + data, _ := 
m["data"].([]any) + cc.assert("data_non_empty", len(data) > 0, fmt.Sprintf("body=%s", string(resp.Body))) + return nil + } + errObj, _ := m["error"].(map[string]any) + _, hasCode := errObj["code"] + _, hasParam := errObj["param"] + cc.assert("error_has_code", hasCode, fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("error_has_param", hasParam, fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} diff --git a/internal/testsuite/runner_cases_openai_advanced.go b/internal/testsuite/runner_cases_openai_advanced.go new file mode 100644 index 0000000..34e9f01 --- /dev/null +++ b/internal/testsuite/runner_cases_openai_advanced.go @@ -0,0 +1,236 @@ +package testsuite + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "sync" +) + +func (r *Runner) caseReasonerStream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "deepseek-reasoner", + "messages": []map[string]any{ + {"role": "user", "content": "先思考后回答:1+1"}, + }, + "stream": true, + }, + Stream: true, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + frames, done := parseSSEFrames(resp.Body) + hasReasoning := false + for _, f := range frames { + choices, _ := f["choices"].([]any) + for _, c := range choices { + ch, _ := c.(map[string]any) + delta, _ := ch["delta"].(map[string]any) + if asString(delta["reasoning_content"]) != "" { + hasReasoning = true + } + } + } + cc.assert("has_reasoning_content", hasReasoning, "reasoning_content not found") + cc.assert("done_terminated", done, "expected [DONE]") + return nil +} + +func (r *Runner) caseToolcallNonstream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: 
http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: toolcallPayload(false), + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + choices, _ := m["choices"].([]any) + if len(choices) == 0 { + cc.assert("choices_non_empty", false, fmt.Sprintf("body=%s", string(resp.Body))) + return nil + } + c0, _ := choices[0].(map[string]any) + cc.assert("finish_reason_tool_calls", asString(c0["finish_reason"]) == "tool_calls", fmt.Sprintf("body=%s", string(resp.Body))) + msg, _ := c0["message"].(map[string]any) + tc, _ := msg["tool_calls"].([]any) + cc.assert("tool_calls_present", len(tc) > 0, fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + +func (r *Runner) caseToolcallStream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: toolcallPayload(true), + Stream: true, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + frames, done := parseSSEFrames(resp.Body) + hasTool := false + rawLeak := false + for _, f := range frames { + choices, _ := f["choices"].([]any) + for _, c := range choices { + ch, _ := c.(map[string]any) + delta, _ := ch["delta"].(map[string]any) + if _, ok := delta["tool_calls"]; ok { + hasTool = true + } + content := asString(delta["content"]) + if strings.Contains(strings.ToLower(content), `"tool_calls"`) { + rawLeak = true + } + } + } + cc.assert("tool_calls_delta_present", hasTool, "tool_calls delta missing") + cc.assert("no_raw_tool_json_leak", !rawLeak, "raw tool_calls JSON leaked in content") + 
cc.assert("done_terminated", done, "expected [DONE]") + return nil +} + +func (r *Runner) caseConcurrencyBurst(ctx context.Context, cc *caseContext) error { + accountCount := len(r.configRaw.Accounts) + n := accountCount*2 + 2 + if n < 2 { + n = 2 + } + type one struct { + Status int + Err string + } + results := make([]one, n) + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "deepseek-chat", + "messages": []map[string]any{ + {"role": "user", "content": fmt.Sprintf("并发请求 #%d,请回复ok", idx)}, + }, + "stream": true, + }, + Stream: true, + Retryable: true, + }) + if err != nil { + results[idx] = one{Err: err.Error()} + return + } + results[idx] = one{Status: resp.StatusCode} + }(i) + } + wg.Wait() + + dist := map[int]int{} + success := 0 + for _, it := range results { + if it.Status > 0 { + dist[it.Status]++ + if it.Status == http.StatusOK { + success++ + } + } + } + cc.assert("success_gt_zero", success > 0, fmt.Sprintf("distribution=%v", dist)) + _, has5xx := has5xx(dist) + cc.assert("no_5xx", !has5xx, fmt.Sprintf("distribution=%v", dist)) + if err := r.ping("/healthz"); err != nil { + cc.assert("server_alive", false, err.Error()) + } else { + cc.assert("server_alive", true, "") + } + return nil +} + +func (r *Runner) caseInvalidKey(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer invalid-testsuite-key-" + sanitizeID(r.runID), + }, + Body: map[string]any{ + "model": "deepseek-chat", + "messages": []map[string]any{ + {"role": "user", "content": "hi"}, + }, + "stream": false, + }, + Retryable: true, + }) + if err != nil { + return err + } + 
cc.assert("status_401", resp.StatusCode == http.StatusUnauthorized, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + e, _ := m["error"].(map[string]any) + cc.assert("error_object_present", len(e) > 0, fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("error_message_present", asString(e["message"]) != "", fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + +func toolcallPayload(stream bool) map[string]any { + return map[string]any{ + "model": "deepseek-chat", + "messages": []map[string]any{ + { + "role": "user", + "content": "你必须调用工具 search 查询 golang,并仅返回工具调用。", + }, + }, + "tools": []map[string]any{ + { + "type": "function", + "function": map[string]any{ + "name": "search", + "description": "search documents", + "parameters": map[string]any{ + "type": "object", + "properties": map[string]any{ + "q": map[string]any{ + "type": "string", + }, + }, + "required": []string{"q"}, + }, + }, + }, + }, + "stream": stream, + } +} diff --git a/internal/testsuite/runner_core.go b/internal/testsuite/runner_core.go new file mode 100644 index 0000000..06eafa5 --- /dev/null +++ b/internal/testsuite/runner_core.go @@ -0,0 +1,290 @@ +package testsuite + +import ( + "context" + "fmt" + "net/http" + "os" + "os/exec" + "path/filepath" + "sort" + "strings" + "sync" + "time" +) + +type Options struct { + ConfigPath string + AdminKey string + OutputDir string + Port int + Timeout time.Duration + Retries int + NoPreflight bool + MaxKeepRuns int +} + +type runSummary struct { + RunID string `json:"run_id"` + StartedAt string `json:"started_at"` + EndedAt string `json:"ended_at"` + DurationMS int64 `json:"duration_ms"` + Stats map[string]any `json:"stats"` + Environment map[string]any `json:"environment"` + Cases []caseResult `json:"cases"` + Warnings []string `json:"warnings,omitempty"` +} + +type caseResult struct { + CaseID string `json:"case_id"` + Passed bool `json:"passed"` + DurationMS int64 `json:"duration_ms"` + 
TraceIDs []string `json:"trace_ids"` + StatusCodes []int `json:"status_codes"` + Error string `json:"error,omitempty"` + ArtifactPath string `json:"artifact_path"` + Assertions []assertionResult `json:"assertions"` +} + +type assertionResult struct { + Name string `json:"name"` + Passed bool `json:"passed"` + Detail string `json:"detail,omitempty"` +} + +type requestLog struct { + Seq int `json:"seq"` + Attempt int `json:"attempt"` + TraceID string `json:"trace_id"` + Method string `json:"method"` + URL string `json:"url"` + Headers map[string]string `json:"headers"` + Body any `json:"body,omitempty"` + Timestamp string `json:"timestamp"` +} + +type responseLog struct { + Seq int `json:"seq"` + Attempt int `json:"attempt"` + TraceID string `json:"trace_id"` + StatusCode int `json:"status_code"` + Headers map[string][]string `json:"headers"` + BodyText string `json:"body_text"` + DurationMS int64 `json:"duration_ms"` + NetworkErr string `json:"network_error,omitempty"` + ReceivedAt string `json:"received_at"` +} + +type caseContext struct { + runner *Runner + id string + dir string + startedAt time.Time + mu sync.Mutex + seq int + assertions []assertionResult + requests []requestLog + responses []responseLog + streamRaw strings.Builder + traceIDsSet map[string]struct{} +} + +type requestSpec struct { + Method string + Path string + Headers map[string]string + Body any + Stream bool + Retryable bool +} + +type responseResult struct { + StatusCode int + Headers http.Header + Body []byte + TraceID string + URL string +} + +type Runner struct { + opts Options + + runID string + runDir string + serverLog string + preflightLog string + + baseURL string + httpClient *http.Client + serverCmd *exec.Cmd + serverLogFd *os.File + + configCopyPath string + originalConfigPath string + originalConfigHash string + + configRaw runConfig + apiKey string + adminKey string + adminJWT string + accountID string + + warnings []string + results []caseResult +} + +type runConfig struct { + 
Keys []string `json:"keys"` + Accounts []struct { + Email string `json:"email,omitempty"` + Mobile string `json:"mobile,omitempty"` + Password string `json:"password,omitempty"` + Token string `json:"token,omitempty"` + } `json:"accounts"` +} + +func Run(ctx context.Context, opts Options) error { + r, err := newRunner(opts) + if err != nil { + return err + } + start := time.Now() + defer func() { + _ = r.stopServer() + }() + + if err := r.prepareRunDir(); err != nil { + return err + } + + if !r.opts.NoPreflight { + if err := r.runPreflight(ctx); err != nil { + _ = r.writeSummary(start, time.Now()) + return err + } + } + + if err := r.prepareConfigIsolation(); err != nil { + _ = r.writeSummary(start, time.Now()) + return err + } + + if err := r.startServer(ctx); err != nil { + _ = r.writeSummary(start, time.Now()) + return err + } + + if err := r.prepareAuth(ctx); err != nil { + r.warnings = append(r.warnings, "auth prepare failed: "+err.Error()) + } + + for _, c := range r.cases() { + r.runCase(ctx, c) + } + + if err := r.ensureOriginalConfigUntouched(); err != nil { + r.warnings = append(r.warnings, err.Error()) + } + + end := time.Now() + if err := r.writeSummary(start, end); err != nil { + return err + } + + // Prune old test runs, keeping only the most recent N. 
+ if err := r.pruneOldRuns(); err != nil { + r.warnings = append(r.warnings, "prune old runs: "+err.Error()) + } + + failed := 0 + for _, cs := range r.results { + if !cs.Passed { + failed++ + } + } + if failed > 0 { + return fmt.Errorf("testsuite failed: %d case(s) failed, see %s", failed, filepath.Join(r.runDir, "summary.md")) + } + return nil +} + +func newRunner(opts Options) (*Runner, error) { + if strings.TrimSpace(opts.ConfigPath) == "" { + opts.ConfigPath = "config.json" + } + if strings.TrimSpace(opts.OutputDir) == "" { + opts.OutputDir = "artifacts/testsuite" + } + if opts.Timeout <= 0 { + opts.Timeout = 120 * time.Second + } + if opts.Retries < 0 { + opts.Retries = 0 + } + adminKey := strings.TrimSpace(opts.AdminKey) + if adminKey == "" { + adminKey = strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")) + } + if adminKey == "" { + adminKey = "admin" + } + opts.AdminKey = adminKey + + return &Runner{ + opts: opts, + httpClient: &http.Client{ + Timeout: 0, + }, + runID: time.Now().UTC().Format("20060102T150405Z"), + adminKey: adminKey, + }, nil +} +func (r *Runner) runCase(ctx context.Context, c caseDef) { + caseDir := filepath.Join(r.runDir, "cases", c.ID) + _ = os.MkdirAll(caseDir, 0o755) + cc := &caseContext{ + runner: r, + id: c.ID, + dir: caseDir, + startedAt: time.Now(), + traceIDsSet: map[string]struct{}{}, + } + err := c.Run(ctx, cc) + duration := time.Since(cc.startedAt).Milliseconds() + + if err != nil { + cc.assertions = append(cc.assertions, assertionResult{ + Name: "case_error", + Passed: false, + Detail: err.Error(), + }) + } + passed := err == nil + for _, a := range cc.assertions { + if !a.Passed { + passed = false + break + } + } + + traceIDs := make([]string, 0, len(cc.traceIDsSet)) + for t := range cc.traceIDsSet { + traceIDs = append(traceIDs, t) + } + sort.Strings(traceIDs) + statuses := uniqueStatusCodes(cc.responses) + cs := caseResult{ + CaseID: c.ID, + Passed: passed, + DurationMS: duration, + TraceIDs: traceIDs, + StatusCodes: 
statuses, + ArtifactPath: caseDir, + Assertions: cc.assertions, + } + if err != nil { + cs.Error = err.Error() + } + _ = cc.flushArtifacts(cs) + r.results = append(r.results, cs) +} diff --git a/internal/testsuite/runner_defaults.go b/internal/testsuite/runner_defaults.go new file mode 100644 index 0000000..ab30bf1 --- /dev/null +++ b/internal/testsuite/runner_defaults.go @@ -0,0 +1,20 @@ +package testsuite + +import ( + "os" + "strings" + "time" +) + +func DefaultOptions() Options { + return Options{ + ConfigPath: "config.json", + AdminKey: strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")), + OutputDir: "artifacts/testsuite", + Port: 0, + Timeout: 120 * time.Second, + Retries: 2, + NoPreflight: false, + MaxKeepRuns: 5, + } +} diff --git a/internal/testsuite/runner_env.go b/internal/testsuite/runner_env.go new file mode 100644 index 0000000..a953936 --- /dev/null +++ b/internal/testsuite/runner_env.go @@ -0,0 +1,264 @@ +package testsuite + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "net/http" + "os" + "os/exec" + "path/filepath" + "sort" + "strconv" + "strings" + "time" +) + +func (r *Runner) prepareRunDir() error { + r.runDir = filepath.Join(r.opts.OutputDir, r.runID) + if err := os.MkdirAll(r.runDir, 0o755); err != nil { + return err + } + if err := os.MkdirAll(filepath.Join(r.runDir, "cases"), 0o755); err != nil { + return err + } + r.serverLog = filepath.Join(r.runDir, "server.log") + r.preflightLog = filepath.Join(r.runDir, "preflight.log") + return nil +} + +// pruneOldRuns removes old test run directories, keeping the most recent MaxKeepRuns. +// Run IDs use the format "20060102T150405Z", so alphabetical order == chronological order. 
+func (r *Runner) pruneOldRuns() error { + keep := r.opts.MaxKeepRuns + if keep <= 0 { + return nil // 0 or negative means no pruning + } + + entries, err := os.ReadDir(r.opts.OutputDir) + if err != nil { + return err + } + + // Collect only directories (each run is a directory). + var runDirs []string + for _, e := range entries { + if !e.IsDir() { + continue + } + runDirs = append(runDirs, e.Name()) + } + + sort.Strings(runDirs) + + if len(runDirs) <= keep { + return nil + } + + // Remove oldest runs (those at the beginning of the sorted list). + toRemove := runDirs[:len(runDirs)-keep] + var errs []string + for _, name := range toRemove { + dirPath := filepath.Join(r.opts.OutputDir, name) + if err := os.RemoveAll(dirPath); err != nil { + errs = append(errs, fmt.Sprintf("remove %s: %v", name, err)) + } else { + fmt.Fprintf(os.Stdout, "pruned old test run: %s\n", name) + } + } + + if len(errs) > 0 { + return errors.New(strings.Join(errs, "; ")) + } + return nil +} + +func (r *Runner) runPreflight(ctx context.Context) error { + steps := preflightSteps() + f, err := os.OpenFile(r.preflightLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + if err != nil { + return err + } + defer f.Close() + for _, step := range steps { + if _, err := fmt.Fprintf(f, "\n$ %s\n", strings.Join(step, " ")); err != nil { + return err + } + cmd := exec.CommandContext(ctx, step[0], step[1:]...) 
+ cmd.Stdout = f + cmd.Stderr = f + if err := cmd.Run(); err != nil { + return fmt.Errorf("preflight failed at `%s`: %w", strings.Join(step, " "), err) + } + } + return nil +} + +func preflightSteps() [][]string { + return [][]string{ + {"go", "test", "./...", "-count=1"}, + {"./tests/scripts/check-node-split-syntax.sh"}, + {"node", "--test", "tests/node/stream-tool-sieve.test.js", "tests/node/chat-stream.test.js", "tests/node/js_compat_test.js"}, + {"npm", "run", "build", "--prefix", "webui"}, + } +} + +func (r *Runner) prepareConfigIsolation() error { + abs, err := filepath.Abs(r.opts.ConfigPath) + if err != nil { + return err + } + r.originalConfigPath = abs + raw, err := os.ReadFile(abs) + if err != nil { + return err + } + sum := sha256.Sum256(raw) + r.originalConfigHash = hex.EncodeToString(sum[:]) + + tmpDir := filepath.Join(r.runDir, "tmp") + if err := os.MkdirAll(tmpDir, 0o755); err != nil { + return err + } + r.configCopyPath = filepath.Join(tmpDir, "config.json") + if err := os.WriteFile(r.configCopyPath, raw, 0o644); err != nil { + return err + } + var cfg runConfig + if err := json.Unmarshal(raw, &cfg); err != nil { + return fmt.Errorf("parse config failed: %w", err) + } + r.configRaw = cfg + if len(cfg.Keys) > 0 { + r.apiKey = strings.TrimSpace(cfg.Keys[0]) + } + for _, acc := range cfg.Accounts { + id := strings.TrimSpace(acc.Email) + if id == "" { + id = strings.TrimSpace(acc.Mobile) + } + if id != "" { + r.accountID = id + break + } + } + return nil +} + +func (r *Runner) startServer(ctx context.Context) error { + port := r.opts.Port + if port <= 0 { + p, err := findFreePort() + if err != nil { + return err + } + port = p + } + r.baseURL = "http://127.0.0.1:" + strconv.Itoa(port) + + logFd, err := os.OpenFile(r.serverLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + if err != nil { + return err + } + r.serverLogFd = logFd + cmd := exec.CommandContext(ctx, "go", "run", "./cmd/ds2api") + cmd.Stdout = logFd + cmd.Stderr = logFd + cmd.Env = 
prepareServerEnv(os.Environ(), map[string]string{ + "PORT": strconv.Itoa(port), + "DS2API_CONFIG_PATH": r.configCopyPath, + "DS2API_AUTO_BUILD_WEBUI": "false", + "DS2API_CONFIG_JSON": "", + "CONFIG_JSON": "", + }) + if err := cmd.Start(); err != nil { + _ = logFd.Close() + return err + } + r.serverCmd = cmd + + deadline := time.Now().Add(90 * time.Second) + for time.Now().Before(deadline) { + if r.ping("/healthz") == nil && r.ping("/readyz") == nil { + return nil + } + time.Sleep(500 * time.Millisecond) + } + return errors.New("server readiness timeout") +} + +func (r *Runner) stopServer() error { + var errs []string + if r.serverCmd != nil && r.serverCmd.Process != nil { + _ = r.serverCmd.Process.Signal(os.Interrupt) + done := make(chan error, 1) + go func() { done <- r.serverCmd.Wait() }() + select { + case <-time.After(5 * time.Second): + _ = r.serverCmd.Process.Kill() + <-done + case <-done: + } + } + if r.serverLogFd != nil { + if err := r.serverLogFd.Close(); err != nil { + errs = append(errs, err.Error()) + } + } + if len(errs) > 0 { + return errors.New(strings.Join(errs, "; ")) + } + return nil +} + +func (r *Runner) ping(path string) error { + resp, err := r.httpClient.Get(r.baseURL + path) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("status=%d", resp.StatusCode) + } + return nil +} + +func (r *Runner) prepareAuth(ctx context.Context) error { + reqBody := map[string]any{ + "admin_key": r.adminKey, + "expire_hours": 24, + } + resp, err := r.doSimpleJSON(ctx, http.MethodPost, "/admin/login", nil, reqBody) + if err != nil { + return err + } + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("admin login status=%d body=%s", resp.StatusCode, string(resp.Body)) + } + var m map[string]any + if err := json.Unmarshal(resp.Body, &m); err != nil { + return err + } + token, _ := m["token"].(string) + if strings.TrimSpace(token) == "" { + return errors.New("empty admin jwt token") + } 
+ r.adminJWT = token + return nil +} + +func (r *Runner) ensureOriginalConfigUntouched() error { + raw, err := os.ReadFile(r.originalConfigPath) + if err != nil { + return err + } + sum := sha256.Sum256(raw) + current := hex.EncodeToString(sum[:]) + if current != r.originalConfigHash { + return fmt.Errorf("original config changed unexpectedly: %s", r.originalConfigPath) + } + return nil +} diff --git a/internal/testsuite/runner_env_test.go b/internal/testsuite/runner_env_test.go new file mode 100644 index 0000000..98df72c --- /dev/null +++ b/internal/testsuite/runner_env_test.go @@ -0,0 +1,20 @@ +package testsuite + +import ( + "reflect" + "testing" +) + +func TestPreflightStepsExactSequence(t *testing.T) { + want := [][]string{ + {"go", "test", "./...", "-count=1"}, + {"./tests/scripts/check-node-split-syntax.sh"}, + {"node", "--test", "tests/node/stream-tool-sieve.test.js", "tests/node/chat-stream.test.js", "tests/node/js_compat_test.js"}, + {"npm", "run", "build", "--prefix", "webui"}, + } + + got := preflightSteps() + if !reflect.DeepEqual(got, want) { + t.Fatalf("preflight steps mismatch\nwant=%v\ngot=%v", want, got) + } +} diff --git a/internal/testsuite/runner_http.go b/internal/testsuite/runner_http.go new file mode 100644 index 0000000..d98c60a --- /dev/null +++ b/internal/testsuite/runner_http.go @@ -0,0 +1,217 @@ +package testsuite + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" +) + +func (cc *caseContext) assert(name string, ok bool, detail string) { + cc.mu.Lock() + defer cc.mu.Unlock() + cc.assertions = append(cc.assertions, assertionResult{ + Name: name, + Passed: ok, + Detail: detail, + }) +} + +func (cc *caseContext) request(ctx context.Context, spec requestSpec) (*responseResult, error) { + retries := cc.runner.opts.Retries + if !spec.Retryable { + retries = 0 + } + var lastErr error + for attempt := 1; attempt <= retries+1; attempt++ { + resp, err := 
cc.requestOnce(ctx, spec, attempt) + if err == nil && resp.StatusCode < 500 { + return resp, nil + } + if err != nil { + lastErr = err + } else if resp.StatusCode >= 500 { + lastErr = fmt.Errorf("status=%d", resp.StatusCode) + } + if attempt <= retries { + sleep := time.Duration(300*(1<<(attempt-1))) * time.Millisecond + time.Sleep(sleep) + } + } + return nil, lastErr +} + +func (cc *caseContext) requestOnce(ctx context.Context, spec requestSpec, attempt int) (*responseResult, error) { + cc.mu.Lock() + cc.seq++ + seq := cc.seq + traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), seq) + cc.traceIDsSet[traceID] = struct{}{} + cc.mu.Unlock() + + fullURL, err := withTraceQuery(cc.runner.baseURL+spec.Path, traceID) + if err != nil { + return nil, err + } + + headers := map[string]string{} + for k, v := range spec.Headers { + headers[k] = v + } + headers["X-Ds2-Test-Trace"] = traceID + + var bodyBytes []byte + var bodyAny any + if spec.Body != nil { + b, err := json.Marshal(spec.Body) + if err != nil { + return nil, err + } + bodyBytes = b + bodyAny = spec.Body + headers["Content-Type"] = "application/json" + } + cc.mu.Lock() + cc.requests = append(cc.requests, requestLog{ + Seq: seq, + Attempt: attempt, + TraceID: traceID, + Method: spec.Method, + URL: fullURL, + Headers: headers, + Body: bodyAny, + Timestamp: time.Now().Format(time.RFC3339Nano), + }) + cc.mu.Unlock() + + reqCtx, cancel := context.WithTimeout(ctx, cc.runner.opts.Timeout) + defer cancel() + req, err := http.NewRequestWithContext(reqCtx, spec.Method, fullURL, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, err + } + for k, v := range headers { + req.Header.Set(k, v) + } + start := time.Now() + resp, err := cc.runner.httpClient.Do(req) + if err != nil { + cc.mu.Lock() + cc.responses = append(cc.responses, responseLog{ + Seq: seq, + Attempt: attempt, + TraceID: traceID, + StatusCode: 0, + DurationMS: time.Since(start).Milliseconds(), + NetworkErr: err.Error(), + 
ReceivedAt: time.Now().Format(time.RFC3339Nano), + }) + cc.mu.Unlock() + return nil, err + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + + cc.mu.Lock() + cc.responses = append(cc.responses, responseLog{ + Seq: seq, + Attempt: attempt, + TraceID: traceID, + StatusCode: resp.StatusCode, + Headers: resp.Header, + BodyText: string(body), + DurationMS: time.Since(start).Milliseconds(), + ReceivedAt: time.Now().Format(time.RFC3339Nano), + }) + + if spec.Stream { + cc.streamRaw.WriteString(fmt.Sprintf("### trace=%s url=%s\n", traceID, fullURL)) + cc.streamRaw.Write(body) + cc.streamRaw.WriteString("\n\n") + } + cc.mu.Unlock() + + return &responseResult{ + StatusCode: resp.StatusCode, + Headers: resp.Header, + Body: body, + TraceID: traceID, + URL: fullURL, + }, nil +} + +func (cc *caseContext) flushArtifacts(cs caseResult) error { + requestPath := filepath.Join(cc.dir, "request.json") + headersPath := filepath.Join(cc.dir, "response.headers") + bodyPath := filepath.Join(cc.dir, "response.body") + streamPath := filepath.Join(cc.dir, "stream.raw") + assertPath := filepath.Join(cc.dir, "assertions.json") + metaPath := filepath.Join(cc.dir, "meta.json") + + if err := writeJSONFile(requestPath, cc.requests); err != nil { + return err + } + respHeaders := make([]map[string]any, 0, len(cc.responses)) + respBodies := make([]map[string]any, 0, len(cc.responses)) + for _, r := range cc.responses { + respHeaders = append(respHeaders, map[string]any{ + "seq": r.Seq, + "attempt": r.Attempt, + "trace_id": r.TraceID, + "status_code": r.StatusCode, + "headers": r.Headers, + }) + respBodies = append(respBodies, map[string]any{ + "seq": r.Seq, + "attempt": r.Attempt, + "trace_id": r.TraceID, + "status_code": r.StatusCode, + "body_text": r.BodyText, + "network_error": r.NetworkErr, + "duration_ms": r.DurationMS, + }) + } + if err := writeJSONFile(headersPath, respHeaders); err != nil { + return err + } + if err := writeJSONFile(bodyPath, respBodies); err != nil { + return 
err + } + if err := os.WriteFile(streamPath, []byte(cc.streamRaw.String()), 0o644); err != nil { + return err + } + if err := writeJSONFile(assertPath, cc.assertions); err != nil { + return err + } + meta := map[string]any{ + "case_id": cs.CaseID, + "trace_id": strings.Join(cs.TraceIDs, ","), + "attempt": len(cc.responses), + "duration_ms": cs.DurationMS, + "status": map[bool]string{true: "passed", false: "failed"}[cs.Passed], + "status_codes": cs.StatusCodes, + "assertions": cs.Assertions, + "artifact_path": cs.ArtifactPath, + } + return writeJSONFile(metaPath, meta) +} +func (r *Runner) doSimpleJSON(ctx context.Context, method, path string, headers map[string]string, body any) (*responseResult, error) { + cc := &caseContext{ + runner: r, + id: "auth_prepare", + traceIDsSet: map[string]struct{}{}, + } + return cc.request(ctx, requestSpec{ + Method: method, + Path: path, + Headers: headers, + Body: body, + Retryable: true, + }) +} diff --git a/internal/testsuite/runner_registry.go b/internal/testsuite/runner_registry.go new file mode 100644 index 0000000..08b602a --- /dev/null +++ b/internal/testsuite/runner_registry.go @@ -0,0 +1,43 @@ +package testsuite + +import "context" + +type caseDef struct { + ID string + Run func(context.Context, *caseContext) error +} + +func (r *Runner) cases() []caseDef { + return []caseDef{ + {ID: "healthz_ok", Run: r.caseHealthz}, + {ID: "readyz_ok", Run: r.caseReadyz}, + {ID: "models_openai", Run: r.caseModelsOpenAI}, + {ID: "model_openai_by_id", Run: r.caseModelOpenAIByID}, + {ID: "models_claude", Run: r.caseModelsClaude}, + {ID: "admin_login_verify", Run: r.caseAdminLoginVerify}, + {ID: "admin_queue_status", Run: r.caseAdminQueueStatus}, + {ID: "chat_nonstream_basic", Run: r.caseChatNonstream}, + {ID: "chat_stream_basic", Run: r.caseChatStream}, + {ID: "responses_nonstream_basic", Run: r.caseResponsesNonstream}, + {ID: "responses_stream_basic", Run: r.caseResponsesStream}, + {ID: "embeddings_contract", Run: r.caseEmbeddings}, + 
{ID: "reasoner_stream", Run: r.caseReasonerStream}, + {ID: "toolcall_nonstream", Run: r.caseToolcallNonstream}, + {ID: "toolcall_stream", Run: r.caseToolcallStream}, + {ID: "anthropic_messages_nonstream", Run: r.caseAnthropicNonstream}, + {ID: "anthropic_messages_stream", Run: r.caseAnthropicStream}, + {ID: "anthropic_count_tokens", Run: r.caseAnthropicCountTokens}, + {ID: "admin_account_test_single", Run: r.caseAdminAccountTest}, + {ID: "concurrency_burst", Run: r.caseConcurrencyBurst}, + {ID: "concurrency_threshold_limit", Run: r.caseConcurrencyThresholdLimit}, + {ID: "stream_abort_release", Run: r.caseStreamAbortRelease}, + {ID: "toolcall_stream_mixed", Run: r.caseToolcallStreamMixed}, + {ID: "sse_json_integrity", Run: r.caseSSEJSONIntegrity}, + {ID: "error_contract_invalid_model", Run: r.caseInvalidModel}, + {ID: "error_contract_missing_messages", Run: r.caseMissingMessages}, + {ID: "admin_unauthorized_contract", Run: r.caseAdminUnauthorized}, + {ID: "config_write_isolated", Run: r.caseConfigWriteIsolated}, + {ID: "token_refresh_managed_account", Run: r.caseTokenRefreshManagedAccount}, + {ID: "error_contract_invalid_key", Run: r.caseInvalidKey}, + } +} diff --git a/internal/testsuite/runner_registry_test.go b/internal/testsuite/runner_registry_test.go new file mode 100644 index 0000000..5e5cd7e --- /dev/null +++ b/internal/testsuite/runner_registry_test.go @@ -0,0 +1,85 @@ +package testsuite + +import ( + "sort" + "testing" +) + +func TestRunnerCasesRegistryExactSet(t *testing.T) { + r := &Runner{} + got := r.cases() + wantIDs := []string{ + "healthz_ok", + "readyz_ok", + "models_openai", + "model_openai_by_id", + "models_claude", + "admin_login_verify", + "admin_queue_status", + "chat_nonstream_basic", + "chat_stream_basic", + "responses_nonstream_basic", + "responses_stream_basic", + "embeddings_contract", + "reasoner_stream", + "toolcall_nonstream", + "toolcall_stream", + "anthropic_messages_nonstream", + "anthropic_messages_stream", + 
"anthropic_count_tokens", + "admin_account_test_single", + "concurrency_burst", + "concurrency_threshold_limit", + "stream_abort_release", + "toolcall_stream_mixed", + "sse_json_integrity", + "error_contract_invalid_model", + "error_contract_missing_messages", + "admin_unauthorized_contract", + "config_write_isolated", + "token_refresh_managed_account", + "error_contract_invalid_key", + } + + if len(got) != len(wantIDs) { + t.Fatalf("unexpected case count: got=%d want=%d", len(got), len(wantIDs)) + } + + wantSet := map[string]struct{}{} + for _, id := range wantIDs { + wantSet[id] = struct{}{} + } + + gotSet := map[string]struct{}{} + for i, cs := range got { + if cs.ID == "" { + t.Fatalf("case[%d] has empty ID", i) + } + if cs.Run == nil { + t.Fatalf("case[%d] (%s) has nil Run", i, cs.ID) + } + if _, exists := gotSet[cs.ID]; exists { + t.Fatalf("duplicate case ID: %s", cs.ID) + } + gotSet[cs.ID] = struct{}{} + } + + var missing []string + for id := range wantSet { + if _, ok := gotSet[id]; !ok { + missing = append(missing, id) + } + } + var extra []string + for id := range gotSet { + if _, ok := wantSet[id]; !ok { + extra = append(extra, id) + } + } + sort.Strings(missing) + sort.Strings(extra) + + if len(missing) > 0 || len(extra) > 0 { + t.Fatalf("registry mismatch: missing=%v extra=%v", missing, extra) + } +} diff --git a/internal/testsuite/runner_summary.go b/internal/testsuite/runner_summary.go new file mode 100644 index 0000000..25b44a4 --- /dev/null +++ b/internal/testsuite/runner_summary.go @@ -0,0 +1,97 @@ +package testsuite + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + "time" +) + +func (r *Runner) writeSummary(start, end time.Time) error { + passed := 0 + failed := 0 + for _, cs := range r.results { + if cs.Passed { + passed++ + } else { + failed++ + } + } + summary := runSummary{ + RunID: r.runID, + StartedAt: start.Format(time.RFC3339Nano), + EndedAt: end.Format(time.RFC3339Nano), + DurationMS: end.Sub(start).Milliseconds(), + 
Stats: map[string]any{ + "total": len(r.results), + "passed": passed, + "failed": failed, + }, + Environment: map[string]any{ + "go_version": runtime.Version(), + "os": runtime.GOOS, + "arch": runtime.GOARCH, + "base_url": r.baseURL, + "config_source": r.originalConfigPath, + "config_isolated": r.configCopyPath, + "server_log": r.serverLog, + "preflight_log": r.preflightLog, + "retries": r.opts.Retries, + "timeout_seconds": int(r.opts.Timeout.Seconds()), + }, + Cases: r.results, + Warnings: r.warnings, + } + if err := writeJSONFile(filepath.Join(r.runDir, "summary.json"), summary); err != nil { + return err + } + return os.WriteFile(filepath.Join(r.runDir, "summary.md"), []byte(r.summaryMarkdown(summary)), 0o644) +} + +func (r *Runner) summaryMarkdown(s runSummary) string { + var b strings.Builder + b.WriteString("# DS2API Live Testsuite Summary\n\n") + b.WriteString("**Sensitive Notice:** this run stores full raw request/response logs. Do not share artifacts publicly.\n\n") + fmt.Fprintf(&b, "- Run ID: `%s`\n", s.RunID) + fmt.Fprintf(&b, "- Started: `%s`\n", s.StartedAt) + fmt.Fprintf(&b, "- Ended: `%s`\n", s.EndedAt) + fmt.Fprintf(&b, "- Duration: `%d ms`\n", s.DurationMS) + fmt.Fprintf(&b, "- Passed/Failed: `%d/%d`\n\n", s.Stats["passed"], s.Stats["failed"]) + if len(s.Warnings) > 0 { + b.WriteString("## Warnings\n\n") + for _, w := range s.Warnings { + fmt.Fprintf(&b, "- %s\n", w) + } + b.WriteString("\n") + } + b.WriteString("## Failed Cases\n\n") + hasFailed := false + for _, c := range s.Cases { + if c.Passed { + continue + } + hasFailed = true + fmt.Fprintf(&b, "- `%s`: %s\n", c.CaseID, c.Error) + if len(c.TraceIDs) > 0 { + fmt.Fprintf(&b, " - trace_ids: `%s`\n", strings.Join(c.TraceIDs, ", ")) + fmt.Fprintf(&b, " - grep: `rg \"%s\" %s`\n", c.TraceIDs[0], filepath.Join(r.runDir, "server.log")) + } + fmt.Fprintf(&b, " - artifact: `%s`\n", c.ArtifactPath) + } + if !hasFailed { + b.WriteString("- none\n") + } + b.WriteString("\n## Case Table\n\n") + 
b.WriteString("| case_id | status | duration_ms | statuses | artifact |\n") + b.WriteString("|---|---:|---:|---|---|\n") + for _, c := range s.Cases { + status := "PASS" + if !c.Passed { + status = "FAIL" + } + fmt.Fprintf(&b, "| %s | %s | %d | %v | `%s` |\n", c.CaseID, status, c.DurationMS, c.StatusCodes, c.ArtifactPath) + } + return b.String() +} diff --git a/internal/testsuite/runner_utils.go b/internal/testsuite/runner_utils.go new file mode 100644 index 0000000..c4879c6 --- /dev/null +++ b/internal/testsuite/runner_utils.go @@ -0,0 +1,202 @@ +package testsuite + +import ( + "encoding/json" + "errors" + "fmt" + "net" + "net/url" + "os" + "sort" + "strings" +) + +func parseSSEFrames(body []byte) ([]map[string]any, bool) { + lines := strings.Split(string(body), "\n") + frames := make([]map[string]any, 0, len(lines)) + done := false + for _, line := range lines { + line = strings.TrimSpace(line) + if !strings.HasPrefix(line, "data:") { + continue + } + payload := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if payload == "" { + continue + } + if payload == "[DONE]" { + done = true + continue + } + var m map[string]any + if err := json.Unmarshal([]byte(payload), &m); err == nil { + frames = append(frames, m) + } + } + return frames, done +} + +func parseClaudeStreamEvents(body []byte) []string { + events := []string{} + seen := map[string]bool{} + lines := strings.Split(string(body), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if !strings.HasPrefix(line, "data:") { + continue + } + payload := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if payload == "" { + continue + } + var m map[string]any + if err := json.Unmarshal([]byte(payload), &m); err != nil { + continue + } + t := asString(m["type"]) + if t == "" || seen[t] { + continue + } + seen[t] = true + events = append(events, t) + } + return events +} + +func extractModelIDs(body []byte) []string { + var m map[string]any + if err := json.Unmarshal(body, &m); err != 
nil { + return nil + } + out := []string{} + data, _ := m["data"].([]any) + for _, it := range data { + item, _ := it.(map[string]any) + id := asString(item["id"]) + if id != "" { + out = append(out, id) + } + } + return out +} + +func withTraceQuery(rawURL, traceID string) (string, error) { + u, err := url.Parse(rawURL) + if err != nil { + return "", err + } + q := u.Query() + q.Set("__trace_id", traceID) + u.RawQuery = q.Encode() + return u.String(), nil +} + +func writeJSONFile(path string, v any) error { + b, err := json.MarshalIndent(v, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, b, 0o644) +} + +func prepareServerEnv(base []string, overrides map[string]string) []string { + out := make([]string, 0, len(base)+len(overrides)) + skip := map[string]struct{}{} + for k := range overrides { + skip[k] = struct{}{} + } + for _, e := range base { + parts := strings.SplitN(e, "=", 2) + if len(parts) != 2 { + continue + } + if _, ok := skip[parts[0]]; ok { + continue + } + out = append(out, e) + } + for k, v := range overrides { + out = append(out, k+"="+v) + } + return out +} + +func findFreePort() (int, error) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return 0, err + } + defer ln.Close() + addr, ok := ln.Addr().(*net.TCPAddr) + if !ok { + return 0, errors.New("failed to detect tcp port") + } + return addr.Port, nil +} + +func uniqueStatusCodes(in []responseLog) []int { + set := map[int]struct{}{} + for _, it := range in { + if it.StatusCode > 0 { + set[it.StatusCode] = struct{}{} + } + } + out := make([]int, 0, len(set)) + for k := range set { + out = append(out, k) + } + sort.Ints(out) + return out +} + +func has5xx(dist map[int]int) (int, bool) { + for k := range dist { + if k >= 500 { + return k, true + } + } + return 0, false +} + +func sanitizeID(s string) string { + s = strings.ReplaceAll(s, ":", "_") + s = strings.ReplaceAll(s, "/", "_") + s = strings.ReplaceAll(s, " ", "_") + return s +} + +func asString(v 
any) string { + if v == nil { + return "" + } + switch x := v.(type) { + case string: + return strings.TrimSpace(x) + default: + return strings.TrimSpace(fmt.Sprintf("%v", v)) + } +} + +func toInt(v any) int { + switch x := v.(type) { + case float64: + return int(x) + case float32: + return int(x) + case int: + return x + case int64: + return int(x) + default: + return 0 + } +} + +func contains(xs []string, target string) bool { + for _, x := range xs { + if x == target { + return true + } + } + return false +} diff --git a/internal/util/messages.go b/internal/util/messages.go index 19f2948..b6920c0 100644 --- a/internal/util/messages.go +++ b/internal/util/messages.go @@ -1,14 +1,11 @@ package util import ( - "regexp" - "strings" - + "ds2api/internal/claudeconv" "ds2api/internal/config" + "ds2api/internal/prompt" ) -var markdownImagePattern = regexp.MustCompile(`!\[(.*?)\]\((.*?)\)`) - const ClaudeDefaultModel = "claude-sonnet-4-5" type Message struct { @@ -17,102 +14,15 @@ type Message struct { } func MessagesPrepare(messages []map[string]any) string { - type block struct { - Role string - Text string - } - processed := make([]block, 0, len(messages)) - for _, m := range messages { - role, _ := m["role"].(string) - text := normalizeContent(m["content"]) - processed = append(processed, block{Role: role, Text: text}) - } - if len(processed) == 0 { - return "" - } - merged := make([]block, 0, len(processed)) - for _, msg := range processed { - if len(merged) > 0 && merged[len(merged)-1].Role == msg.Role { - merged[len(merged)-1].Text += "\n\n" + msg.Text - continue - } - merged = append(merged, msg) - } - parts := make([]string, 0, len(merged)) - for i, m := range merged { - switch m.Role { - case "assistant": - parts = append(parts, "<|Assistant|>"+m.Text+"<|end▁of▁sentence|>") - case "user", "system": - if i > 0 { - parts = append(parts, "<|User|>"+m.Text) - } else { - parts = append(parts, m.Text) - } - default: - parts = append(parts, m.Text) - } - } - out := 
strings.Join(parts, "") - return markdownImagePattern.ReplaceAllString(out, `[${1}](${2})`) + return prompt.MessagesPrepare(messages) } func normalizeContent(v any) string { - switch x := v.(type) { - case string: - return x - case []any: - parts := make([]string, 0, len(x)) - for _, item := range x { - m, ok := item.(map[string]any) - if !ok { - continue - } - if m["type"] == "text" { - if txt, ok := m["text"].(string); ok { - parts = append(parts, txt) - } - } - } - return strings.Join(parts, "\n") - default: - return "" - } + return prompt.NormalizeContent(v) } func ConvertClaudeToDeepSeek(claudeReq map[string]any, store *config.Store) map[string]any { - messages, _ := claudeReq["messages"].([]any) - model, _ := claudeReq["model"].(string) - if model == "" { - model = ClaudeDefaultModel - } - mapping := store.ClaudeMapping() - dsModel := mapping["fast"] - if dsModel == "" { - dsModel = "deepseek-chat" - } - modelLower := strings.ToLower(model) - if strings.Contains(modelLower, "opus") || strings.Contains(modelLower, "reasoner") || strings.Contains(modelLower, "slow") { - if slow := mapping["slow"]; slow != "" { - dsModel = slow - } - } - convertedMessages := make([]any, 0, len(messages)+1) - if system, ok := claudeReq["system"].(string); ok && system != "" { - convertedMessages = append(convertedMessages, map[string]any{"role": "system", "content": system}) - } - convertedMessages = append(convertedMessages, messages...) - - out := map[string]any{"model": dsModel, "messages": convertedMessages} - for _, k := range []string{"temperature", "top_p", "stream"} { - if v, ok := claudeReq[k]; ok { - out[k] = v - } - } - if stopSeq, ok := claudeReq["stop_sequences"]; ok { - out["stop"] = stopSeq - } - return out + return claudeconv.ConvertClaudeToDeepSeek(claudeReq, store, ClaudeDefaultModel) } // EstimateTokens provides a rough token count approximation. 
diff --git a/internal/util/messages_test.go b/internal/util/messages_test.go index 30b8cc0..776853b 100644 --- a/internal/util/messages_test.go +++ b/internal/util/messages_test.go @@ -33,6 +33,33 @@ func TestMessagesPrepareRoles(t *testing.T) { } } +func TestMessagesPrepareObjectContent(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": map[string]any{"temp": 18, "ok": true}}, + } + got := MessagesPrepare(messages) + if !contains(got, `"temp":18`) || !contains(got, `"ok":true`) { + t.Fatalf("expected serialized object content, got %q", got) + } +} + +func TestMessagesPrepareArrayTextVariants(t *testing.T) { + messages := []map[string]any{ + { + "role": "user", + "content": []any{ + map[string]any{"type": "output_text", "text": "line1"}, + map[string]any{"type": "input_text", "text": "line2"}, + map[string]any{"type": "image_url", "image_url": "https://example.com/a.png"}, + }, + }, + } + got := MessagesPrepare(messages) + if got != "line1\nline2" { + t.Fatalf("unexpected content from text variants: %q", got) + } +} + func TestConvertClaudeToDeepSeek(t *testing.T) { store := config.LoadStore() req := map[string]any{ diff --git a/internal/util/render.go b/internal/util/render.go new file mode 100644 index 0000000..fff8501 --- /dev/null +++ b/internal/util/render.go @@ -0,0 +1,146 @@ +package util + +import ( + "fmt" + "strings" + "time" + + "github.com/google/uuid" +) + +// BuildOpenAIChatCompletion is kept for backward compatibility. +// Prefer internal/format/openai.BuildChatCompletion for new code. 
+func BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { + detected := ParseToolCalls(finalText, toolNames) + finishReason := "stop" + messageObj := map[string]any{"role": "assistant", "content": finalText} + if strings.TrimSpace(finalThinking) != "" { + messageObj["reasoning_content"] = finalThinking + } + if len(detected) > 0 { + finishReason = "tool_calls" + messageObj["tool_calls"] = FormatOpenAIToolCalls(detected) + messageObj["content"] = nil + } + promptTokens := EstimateTokens(finalPrompt) + reasoningTokens := EstimateTokens(finalThinking) + completionTokens := EstimateTokens(finalText) + + return map[string]any{ + "id": completionID, + "object": "chat.completion", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}}, + "usage": map[string]any{ + "prompt_tokens": promptTokens, + "completion_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + "completion_tokens_details": map[string]any{ + "reasoning_tokens": reasoningTokens, + }, + }, + } +} + +// BuildOpenAIResponseObject is kept for backward compatibility. +// Prefer internal/format/openai.BuildResponseObject for new code. +func BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { + detected := ParseToolCalls(finalText, toolNames) + exposedOutputText := finalText + output := make([]any, 0, 2) + if len(detected) > 0 { + // Keep structured tool output only; avoid leaking raw tool-call JSON + // into response.output_text for clients reading completed responses. 
+ exposedOutputText = "" + toolCalls := make([]any, 0, len(detected)) + for _, tc := range detected { + toolCalls = append(toolCalls, map[string]any{ + "type": "tool_call", + "name": tc.Name, + "arguments": tc.Input, + }) + } + output = append(output, map[string]any{ + "type": "tool_calls", + "tool_calls": toolCalls, + }) + } else { + content := []any{ + map[string]any{ + "type": "output_text", + "text": finalText, + }, + } + if finalThinking != "" { + content = append([]any{map[string]any{ + "type": "reasoning", + "text": finalThinking, + }}, content...) + } + output = append(output, map[string]any{ + "type": "message", + "id": "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + "role": "assistant", + "content": content, + }) + } + promptTokens := EstimateTokens(finalPrompt) + reasoningTokens := EstimateTokens(finalThinking) + completionTokens := EstimateTokens(finalText) + return map[string]any{ + "id": responseID, + "type": "response", + "object": "response", + "created_at": time.Now().Unix(), + "status": "completed", + "model": model, + "output": output, + "output_text": exposedOutputText, + "usage": map[string]any{ + "input_tokens": promptTokens, + "output_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + }, + } +} + +// BuildClaudeMessageResponse is kept for backward compatibility. +// Prefer internal/format/claude.BuildMessageResponse for new code. 
+func BuildClaudeMessageResponse(messageID, model string, normalizedMessages []any, finalThinking, finalText string, toolNames []string) map[string]any { + detected := ParseToolCalls(finalText, toolNames) + content := make([]map[string]any, 0, 4) + if finalThinking != "" { + content = append(content, map[string]any{"type": "thinking", "thinking": finalThinking}) + } + stopReason := "end_turn" + if len(detected) > 0 { + stopReason = "tool_use" + for i, tc := range detected { + content = append(content, map[string]any{ + "type": "tool_use", + "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i), + "name": tc.Name, + "input": tc.Input, + }) + } + } else { + if finalText == "" { + finalText = "抱歉,没有生成有效的响应内容。" + } + content = append(content, map[string]any{"type": "text", "text": finalText}) + } + return map[string]any{ + "id": messageID, + "type": "message", + "role": "assistant", + "model": model, + "content": content, + "stop_reason": stopReason, + "stop_sequence": nil, + "usage": map[string]any{ + "input_tokens": EstimateTokens(fmt.Sprintf("%v", normalizedMessages)), + "output_tokens": EstimateTokens(finalThinking) + EstimateTokens(finalText), + }, + } +} diff --git a/internal/util/render_test.go b/internal/util/render_test.go new file mode 100644 index 0000000..9d4feec --- /dev/null +++ b/internal/util/render_test.go @@ -0,0 +1,94 @@ +package util + +import "testing" + +func TestBuildOpenAIChatCompletionWithToolCalls(t *testing.T) { + out := BuildOpenAIChatCompletion( + "cid1", + "deepseek-chat", + "prompt", + "", + `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`, + []string{"search"}, + ) + if out["object"] != "chat.completion" { + t.Fatalf("unexpected object: %#v", out["object"]) + } + choices, _ := out["choices"].([]map[string]any) + if len(choices) == 0 { + // json-like map from generic marshalling may be []any in some paths + rawChoices, _ := out["choices"].([]any) + if len(rawChoices) == 0 { + t.Fatalf("expected choices") + } + c0, _ := 
rawChoices[0].(map[string]any) + if c0["finish_reason"] != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, got %#v", c0["finish_reason"]) + } + return + } + if choices[0]["finish_reason"] != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, got %#v", choices[0]["finish_reason"]) + } +} + +func TestBuildOpenAIResponseObjectWithText(t *testing.T) { + out := BuildOpenAIResponseObject( + "resp_1", + "gpt-4o", + "prompt", + "reasoning", + "text", + nil, + ) + if out["object"] != "response" { + t.Fatalf("unexpected object: %#v", out["object"]) + } + output, _ := out["output"].([]any) + if len(output) == 0 { + t.Fatalf("expected output entries") + } + first, _ := output[0].(map[string]any) + if first["type"] != "message" { + t.Fatalf("expected first output type message, got %#v", first["type"]) + } +} + +func TestBuildOpenAIResponseObjectToolCallsHidesRawOutputText(t *testing.T) { + out := BuildOpenAIResponseObject( + "resp_2", + "gpt-4o", + "prompt", + "", + `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`, + []string{"search"}, + ) + if out["output_text"] != "" { + t.Fatalf("expected empty output_text for tool_calls, got %#v", out["output_text"]) + } + output, _ := out["output"].([]any) + if len(output) == 0 { + t.Fatalf("expected output entries") + } + first, _ := output[0].(map[string]any) + if first["type"] != "tool_calls" { + t.Fatalf("expected first output type tool_calls, got %#v", first["type"]) + } +} + +func TestBuildClaudeMessageResponseToolUse(t *testing.T) { + out := BuildClaudeMessageResponse( + "msg_1", + "claude-sonnet-4-5", + []any{map[string]any{"role": "user", "content": "hi"}}, + "", + `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`, + []string{"search"}, + ) + if out["type"] != "message" { + t.Fatalf("unexpected type: %#v", out["type"]) + } + if out["stop_reason"] != "tool_use" { + t.Fatalf("expected stop_reason=tool_use, got %#v", out["stop_reason"]) + } +} diff --git a/internal/util/standard_request.go 
b/internal/util/standard_request.go new file mode 100644 index 0000000..84a4c98 --- /dev/null +++ b/internal/util/standard_request.go @@ -0,0 +1,66 @@ +package util + +type StandardRequest struct { + Surface string + RequestedModel string + ResolvedModel string + ResponseModel string + Messages []any + FinalPrompt string + ToolNames []string + ToolChoice ToolChoicePolicy + Stream bool + Thinking bool + Search bool + PassThrough map[string]any +} + +type ToolChoiceMode string + +const ( + ToolChoiceAuto ToolChoiceMode = "auto" + ToolChoiceNone ToolChoiceMode = "none" + ToolChoiceRequired ToolChoiceMode = "required" + ToolChoiceForced ToolChoiceMode = "forced" +) + +type ToolChoicePolicy struct { + Mode ToolChoiceMode + ForcedName string + Allowed map[string]struct{} +} + +func DefaultToolChoicePolicy() ToolChoicePolicy { + return ToolChoicePolicy{Mode: ToolChoiceAuto} +} + +func (p ToolChoicePolicy) IsNone() bool { + return p.Mode == ToolChoiceNone +} + +func (p ToolChoicePolicy) IsRequired() bool { + return p.Mode == ToolChoiceRequired || p.Mode == ToolChoiceForced +} + +func (p ToolChoicePolicy) Allows(name string) bool { + if len(p.Allowed) == 0 { + return true + } + _, ok := p.Allowed[name] + return ok +} + +func (r StandardRequest) CompletionPayload(sessionID string) map[string]any { + payload := map[string]any{ + "chat_session_id": sessionID, + "parent_message_id": nil, + "prompt": r.FinalPrompt, + "ref_file_ids": []any{}, + "thinking_enabled": r.Thinking, + "search_enabled": r.Search, + } + for k, v := range r.PassThrough { + payload[k] = v + } + return payload +} diff --git a/internal/util/toolcalls.go b/internal/util/toolcalls.go deleted file mode 100644 index 9b9d4e6..0000000 --- a/internal/util/toolcalls.go +++ /dev/null @@ -1,317 +0,0 @@ -package util - -import ( - "encoding/json" - "regexp" - "strings" - - "github.com/google/uuid" -) - -var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`) -var fencedJSONPattern = 
regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```") - -type ParsedToolCall struct { - Name string `json:"name"` - Input map[string]any `json:"input"` -} - -func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall { - if strings.TrimSpace(text) == "" { - return nil - } - - candidates := buildToolCallCandidates(text) - var parsed []ParsedToolCall - for _, candidate := range candidates { - if tc := parseToolCallsPayload(candidate); len(tc) > 0 { - parsed = tc - break - } - } - if len(parsed) == 0 { - return nil - } - - allowed := map[string]struct{}{} - for _, name := range availableToolNames { - allowed[name] = struct{}{} - } - out := make([]ParsedToolCall, 0, len(parsed)) - for _, tc := range parsed { - if tc.Name == "" { - continue - } - if len(allowed) > 0 { - if _, ok := allowed[tc.Name]; !ok { - continue - } - } - if tc.Input == nil { - tc.Input = map[string]any{} - } - out = append(out, tc) - } - // If the model clearly emitted tool_calls JSON but all names are outside the - // declared set, keep the parsed calls as a fallback so upper layers can still - // intercept structured tool output instead of leaking raw JSON to users. - if len(out) == 0 && len(parsed) > 0 { - for _, tc := range parsed { - if tc.Name == "" { - continue - } - if tc.Input == nil { - tc.Input = map[string]any{} - } - out = append(out, tc) - } - } - return out -} - -func buildToolCallCandidates(text string) []string { - trimmed := strings.TrimSpace(text) - candidates := []string{trimmed} - - // fenced code block candidates: ```json ... ``` - for _, match := range fencedJSONPattern.FindAllStringSubmatch(trimmed, -1) { - if len(match) >= 2 { - candidates = append(candidates, strings.TrimSpace(match[1])) - } - } - - // best-effort extraction around "tool_calls" key in mixed text payloads. - candidates = append(candidates, extractToolCallObjects(trimmed)...) 
- - // best-effort object slice: from first '{' to last '}' - first := strings.Index(trimmed, "{") - last := strings.LastIndex(trimmed, "}") - if first >= 0 && last > first { - candidates = append(candidates, strings.TrimSpace(trimmed[first:last+1])) - } - - // legacy regex extraction fallback - if m := toolCallPattern.FindStringSubmatch(trimmed); len(m) >= 2 { - candidates = append(candidates, "{"+`"tool_calls":[`+m[1]+"]}") - } - - uniq := make([]string, 0, len(candidates)) - seen := map[string]struct{}{} - for _, c := range candidates { - if c == "" { - continue - } - if _, ok := seen[c]; ok { - continue - } - seen[c] = struct{}{} - uniq = append(uniq, c) - } - return uniq -} - -func parseToolCallsPayload(payload string) []ParsedToolCall { - var decoded any - if err := json.Unmarshal([]byte(payload), &decoded); err != nil { - return nil - } - switch v := decoded.(type) { - case map[string]any: - if tc, ok := v["tool_calls"]; ok { - return parseToolCallList(tc) - } - if parsed, ok := parseToolCallItem(v); ok { - return []ParsedToolCall{parsed} - } - case []any: - return parseToolCallList(v) - } - return nil -} - -func parseToolCallList(v any) []ParsedToolCall { - items, ok := v.([]any) - if !ok { - return nil - } - out := make([]ParsedToolCall, 0, len(items)) - for _, item := range items { - m, ok := item.(map[string]any) - if !ok { - continue - } - if tc, ok := parseToolCallItem(m); ok { - out = append(out, tc) - } - } - if len(out) == 0 { - return nil - } - return out -} - -func parseToolCallItem(m map[string]any) (ParsedToolCall, bool) { - name, _ := m["name"].(string) - inputRaw, hasInput := m["input"] - if fn, ok := m["function"].(map[string]any); ok { - if name == "" { - name, _ = fn["name"].(string) - } - if !hasInput { - if v, ok := fn["arguments"]; ok { - inputRaw = v - hasInput = true - } - } - } - if !hasInput { - for _, key := range []string{"arguments", "args", "parameters", "params"} { - if v, ok := m[key]; ok { - inputRaw = v - hasInput = true - 
break - } - } - } - if strings.TrimSpace(name) == "" { - return ParsedToolCall{}, false - } - return ParsedToolCall{ - Name: strings.TrimSpace(name), - Input: parseToolCallInput(inputRaw), - }, true -} - -func parseToolCallInput(v any) map[string]any { - switch x := v.(type) { - case nil: - return map[string]any{} - case map[string]any: - return x - case string: - raw := strings.TrimSpace(x) - if raw == "" { - return map[string]any{} - } - var parsed map[string]any - if err := json.Unmarshal([]byte(raw), &parsed); err == nil && parsed != nil { - return parsed - } - return map[string]any{"_raw": raw} - default: - b, err := json.Marshal(x) - if err != nil { - return map[string]any{} - } - var parsed map[string]any - if err := json.Unmarshal(b, &parsed); err == nil && parsed != nil { - return parsed - } - return map[string]any{} - } -} - -func extractToolCallObjects(text string) []string { - if text == "" { - return nil - } - lower := strings.ToLower(text) - out := []string{} - offset := 0 - for { - idx := strings.Index(lower[offset:], "tool_calls") - if idx < 0 { - break - } - idx += offset - start := strings.LastIndex(text[:idx], "{") - for start >= 0 { - candidate, end, ok := extractJSONObject(text, start) - if ok { - // Move forward to avoid repeatedly matching the same object. 
- offset = end - out = append(out, strings.TrimSpace(candidate)) - break - } - start = strings.LastIndex(text[:start], "{") - } - if start < 0 { - offset = idx + len("tool_calls") - } - } - return out -} - -func extractJSONObject(text string, start int) (string, int, bool) { - if start < 0 || start >= len(text) || text[start] != '{' { - return "", 0, false - } - depth := 0 - quote := byte(0) - escaped := false - for i := start; i < len(text); i++ { - ch := text[i] - if quote != 0 { - if escaped { - escaped = false - continue - } - if ch == '\\' { - escaped = true - continue - } - if ch == quote { - quote = 0 - } - continue - } - if ch == '"' || ch == '\'' { - quote = ch - continue - } - if ch == '{' { - depth++ - continue - } - if ch == '}' { - depth-- - if depth == 0 { - return text[start : i+1], i + 1, true - } - } - } - return "", 0, false -} - -func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any { - out := make([]map[string]any, 0, len(calls)) - for _, c := range calls { - args, _ := json.Marshal(c.Input) - out = append(out, map[string]any{ - "id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""), - "type": "function", - "function": map[string]any{ - "name": c.Name, - "arguments": string(args), - }, - }) - } - return out -} - -func FormatOpenAIStreamToolCalls(calls []ParsedToolCall) []map[string]any { - out := make([]map[string]any, 0, len(calls)) - for i, c := range calls { - args, _ := json.Marshal(c.Input) - out = append(out, map[string]any{ - "index": i, - "id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""), - "type": "function", - "function": map[string]any{ - "name": c.Name, - "arguments": string(args), - }, - }) - } - return out -} diff --git a/internal/util/toolcalls_candidates.go b/internal/util/toolcalls_candidates.go new file mode 100644 index 0000000..4e8afc4 --- /dev/null +++ b/internal/util/toolcalls_candidates.go @@ -0,0 +1,138 @@ +package util + +import ( + "regexp" + "strings" +) + +var toolCallPattern = 
regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`) +var fencedJSONPattern = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```") +var fencedBlockPattern = regexp.MustCompile("(?s)```.*?```") + +func buildToolCallCandidates(text string) []string { + trimmed := strings.TrimSpace(text) + candidates := []string{trimmed} + + // fenced code block candidates: ```json ... ``` + for _, match := range fencedJSONPattern.FindAllStringSubmatch(trimmed, -1) { + if len(match) >= 2 { + candidates = append(candidates, strings.TrimSpace(match[1])) + } + } + + // best-effort extraction around "tool_calls" key in mixed text payloads. + candidates = append(candidates, extractToolCallObjects(trimmed)...) + + // best-effort object slice: from first '{' to last '}' + first := strings.Index(trimmed, "{") + last := strings.LastIndex(trimmed, "}") + if first >= 0 && last > first { + candidates = append(candidates, strings.TrimSpace(trimmed[first:last+1])) + } + + // legacy regex extraction fallback + if m := toolCallPattern.FindStringSubmatch(trimmed); len(m) >= 2 { + candidates = append(candidates, "{"+`"tool_calls":[`+m[1]+"]}") + } + + uniq := make([]string, 0, len(candidates)) + seen := map[string]struct{}{} + for _, c := range candidates { + if c == "" { + continue + } + if _, ok := seen[c]; ok { + continue + } + seen[c] = struct{}{} + uniq = append(uniq, c) + } + return uniq +} + +func extractToolCallObjects(text string) []string { + if text == "" { + return nil + } + lower := strings.ToLower(text) + out := []string{} + offset := 0 + for { + idx := strings.Index(lower[offset:], "tool_calls") + if idx < 0 { + break + } + idx += offset + start := strings.LastIndex(text[:idx], "{") + for start >= 0 { + candidate, end, ok := extractJSONObject(text, start) + if ok { + // Move forward to avoid repeatedly matching the same object. 
+ offset = end + out = append(out, strings.TrimSpace(candidate)) + break + } + start = strings.LastIndex(text[:start], "{") + } + if start < 0 { + offset = idx + len("tool_calls") + } + } + return out +} + +func extractJSONObject(text string, start int) (string, int, bool) { + if start < 0 || start >= len(text) || text[start] != '{' { + return "", 0, false + } + depth := 0 + quote := byte(0) + escaped := false + for i := start; i < len(text); i++ { + ch := text[i] + if quote != 0 { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '"' || ch == '\'' { + quote = ch + continue + } + if ch == '{' { + depth++ + continue + } + if ch == '}' { + depth-- + if depth == 0 { + return text[start : i+1], i + 1, true + } + } + } + return "", 0, false +} + +func looksLikeToolExampleContext(text string) bool { + t := strings.ToLower(strings.TrimSpace(text)) + if t == "" { + return false + } + return strings.Contains(t, "```") +} + +func stripFencedCodeBlocks(text string) string { + if strings.TrimSpace(text) == "" { + return "" + } + return fencedBlockPattern.ReplaceAllString(text, " ") +} diff --git a/internal/util/toolcalls_format.go b/internal/util/toolcalls_format.go new file mode 100644 index 0000000..8feb48f --- /dev/null +++ b/internal/util/toolcalls_format.go @@ -0,0 +1,41 @@ +package util + +import ( + "encoding/json" + "strings" + + "github.com/google/uuid" +) + +func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any { + out := make([]map[string]any, 0, len(calls)) + for _, c := range calls { + args, _ := json.Marshal(c.Input) + out = append(out, map[string]any{ + "id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + "type": "function", + "function": map[string]any{ + "name": c.Name, + "arguments": string(args), + }, + }) + } + return out +} + +func FormatOpenAIStreamToolCalls(calls []ParsedToolCall) []map[string]any { + out := 
make([]map[string]any, 0, len(calls)) + for i, c := range calls { + args, _ := json.Marshal(c.Input) + out = append(out, map[string]any{ + "index": i, + "id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + "type": "function", + "function": map[string]any{ + "name": c.Name, + "arguments": string(args), + }, + }) + } + return out +} diff --git a/internal/util/toolcalls_parse.go b/internal/util/toolcalls_parse.go new file mode 100644 index 0000000..5b386c2 --- /dev/null +++ b/internal/util/toolcalls_parse.go @@ -0,0 +1,230 @@ +package util + +import ( + "encoding/json" + "strings" +) + +type ParsedToolCall struct { + Name string `json:"name"` + Input map[string]any `json:"input"` +} + +type ToolCallParseResult struct { + Calls []ParsedToolCall + SawToolCallSyntax bool + RejectedByPolicy bool + RejectedToolNames []string +} + +func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall { + return ParseToolCallsDetailed(text, availableToolNames).Calls +} + +func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallParseResult { + result := ToolCallParseResult{} + if strings.TrimSpace(text) == "" { + return result + } + text = stripFencedCodeBlocks(text) + if strings.TrimSpace(text) == "" { + return result + } + result.SawToolCallSyntax = strings.Contains(strings.ToLower(text), "tool_calls") + + candidates := buildToolCallCandidates(text) + var parsed []ParsedToolCall + for _, candidate := range candidates { + if tc := parseToolCallsPayload(candidate); len(tc) > 0 { + parsed = tc + result.SawToolCallSyntax = true + break + } + } + if len(parsed) == 0 { + return result + } + + calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames) + result.Calls = calls + result.RejectedToolNames = rejectedNames + result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0 + return result +} + +func ParseStandaloneToolCalls(text string, availableToolNames []string) []ParsedToolCall { + return 
ParseStandaloneToolCallsDetailed(text, availableToolNames).Calls +} + +func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string) ToolCallParseResult { + result := ToolCallParseResult{} + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return result + } + if looksLikeToolExampleContext(trimmed) { + return result + } + result.SawToolCallSyntax = strings.Contains(strings.ToLower(trimmed), "tool_calls") + candidates := []string{trimmed} + for _, candidate := range candidates { + candidate = strings.TrimSpace(candidate) + if candidate == "" { + continue + } + if !strings.HasPrefix(candidate, "{") && !strings.HasPrefix(candidate, "[") { + continue + } + if parsed := parseToolCallsPayload(candidate); len(parsed) > 0 { + result.SawToolCallSyntax = true + calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames) + result.Calls = calls + result.RejectedToolNames = rejectedNames + result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0 + return result + } + } + return result +} + +func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []string) ([]ParsedToolCall, []string) { + allowed := map[string]struct{}{} + for _, name := range availableToolNames { + allowed[name] = struct{}{} + } + if len(allowed) == 0 { + rejectedSet := map[string]struct{}{} + for _, tc := range parsed { + if tc.Name == "" { + continue + } + rejectedSet[tc.Name] = struct{}{} + } + rejected := make([]string, 0, len(rejectedSet)) + for name := range rejectedSet { + rejected = append(rejected, name) + } + return nil, rejected + } + out := make([]ParsedToolCall, 0, len(parsed)) + rejectedSet := map[string]struct{}{} + for _, tc := range parsed { + if tc.Name == "" { + continue + } + if _, ok := allowed[tc.Name]; !ok { + rejectedSet[tc.Name] = struct{}{} + continue + } + if tc.Input == nil { + tc.Input = map[string]any{} + } + out = append(out, tc) + } + rejected := make([]string, 0, len(rejectedSet)) + for name := range 
rejectedSet { + rejected = append(rejected, name) + } + return out, rejected +} + +func parseToolCallsPayload(payload string) []ParsedToolCall { + var decoded any + if err := json.Unmarshal([]byte(payload), &decoded); err != nil { + return nil + } + switch v := decoded.(type) { + case map[string]any: + if tc, ok := v["tool_calls"]; ok { + return parseToolCallList(tc) + } + if parsed, ok := parseToolCallItem(v); ok { + return []ParsedToolCall{parsed} + } + case []any: + return parseToolCallList(v) + } + return nil +} + +func parseToolCallList(v any) []ParsedToolCall { + items, ok := v.([]any) + if !ok { + return nil + } + out := make([]ParsedToolCall, 0, len(items)) + for _, item := range items { + m, ok := item.(map[string]any) + if !ok { + continue + } + if tc, ok := parseToolCallItem(m); ok { + out = append(out, tc) + } + } + if len(out) == 0 { + return nil + } + return out +} + +func parseToolCallItem(m map[string]any) (ParsedToolCall, bool) { + name, _ := m["name"].(string) + inputRaw, hasInput := m["input"] + if fn, ok := m["function"].(map[string]any); ok { + if name == "" { + name, _ = fn["name"].(string) + } + if !hasInput { + if v, ok := fn["arguments"]; ok { + inputRaw = v + hasInput = true + } + } + } + if !hasInput { + for _, key := range []string{"arguments", "args", "parameters", "params"} { + if v, ok := m[key]; ok { + inputRaw = v + hasInput = true + break + } + } + } + if strings.TrimSpace(name) == "" { + return ParsedToolCall{}, false + } + return ParsedToolCall{ + Name: strings.TrimSpace(name), + Input: parseToolCallInput(inputRaw), + }, true +} + +func parseToolCallInput(v any) map[string]any { + switch x := v.(type) { + case nil: + return map[string]any{} + case map[string]any: + return x + case string: + raw := strings.TrimSpace(x) + if raw == "" { + return map[string]any{} + } + var parsed map[string]any + if err := json.Unmarshal([]byte(raw), &parsed); err == nil && parsed != nil { + return parsed + } + return map[string]any{"_raw": raw} + 
default: + b, err := json.Marshal(x) + if err != nil { + return map[string]any{} + } + var parsed map[string]any + if err := json.Unmarshal(b, &parsed); err == nil && parsed != nil { + return parsed + } + return map[string]any{} + } +} diff --git a/internal/util/toolcalls_test.go b/internal/util/toolcalls_test.go index 8c44320..0e823c0 100644 --- a/internal/util/toolcalls_test.go +++ b/internal/util/toolcalls_test.go @@ -19,11 +19,8 @@ func TestParseToolCalls(t *testing.T) { func TestParseToolCallsFromFencedJSON(t *testing.T) { text := "I will call tools now\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"news\"}}]}\n```" calls := ParseToolCalls(text, []string{"search"}) - if len(calls) != 1 { - t.Fatalf("expected 1 call, got %d", len(calls)) - } - if calls[0].Input["q"] != "news" { - t.Fatalf("unexpected args: %#v", calls[0].Input) + if len(calls) != 0 { + t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls) } } @@ -41,14 +38,39 @@ func TestParseToolCallsWithFunctionArgumentsString(t *testing.T) { } } -func TestParseToolCallsKeepsUnknownAsFallback(t *testing.T) { +func TestParseToolCallsRejectsUnknownToolName(t *testing.T) { text := `{"tool_calls":[{"name":"unknown","input":{}}]}` calls := ParseToolCalls(text, []string{"search"}) - if len(calls) != 1 { - t.Fatalf("expected fallback 1 call, got %d", len(calls)) + if len(calls) != 0 { + t.Fatalf("expected unknown tool to be rejected, got %#v", calls) } - if calls[0].Name != "unknown" { - t.Fatalf("unexpected name: %s", calls[0].Name) +} + +func TestParseToolCallsDetailedMarksPolicyRejection(t *testing.T) { + text := `{"tool_calls":[{"name":"unknown","input":{}}]}` + res := ParseToolCallsDetailed(text, []string{"search"}) + if !res.SawToolCallSyntax { + t.Fatalf("expected SawToolCallSyntax=true, got %#v", res) + } + if !res.RejectedByPolicy { + t.Fatalf("expected RejectedByPolicy=true, got %#v", res) + } + if len(res.Calls) != 0 { + t.Fatalf("expected no calls after policy 
rejection, got %#v", res.Calls) + } +} + +func TestParseToolCallsDetailedRejectsWhenAllowListEmpty(t *testing.T) { + text := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}` + res := ParseToolCallsDetailed(text, nil) + if !res.SawToolCallSyntax { + t.Fatalf("expected SawToolCallSyntax=true, got %#v", res) + } + if !res.RejectedByPolicy { + t.Fatalf("expected RejectedByPolicy=true, got %#v", res) + } + if len(res.Calls) != 0 { + t.Fatalf("expected no calls when allow-list is empty, got %#v", res.Calls) } } @@ -62,3 +84,23 @@ func TestFormatOpenAIToolCalls(t *testing.T) { t.Fatalf("unexpected function name: %#v", fn) } } + +func TestParseStandaloneToolCallsOnlyMatchesStandalonePayload(t *testing.T) { + mixed := `这里是示例:{"tool_calls":[{"name":"search","input":{"q":"go"}}]}` + if calls := ParseStandaloneToolCalls(mixed, []string{"search"}); len(calls) != 0 { + t.Fatalf("expected standalone parser to ignore mixed prose, got %#v", calls) + } + + standalone := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}` + calls := ParseStandaloneToolCalls(standalone, []string{"search"}) + if len(calls) != 1 { + t.Fatalf("expected standalone parser to match, got %#v", calls) + } +} + +func TestParseStandaloneToolCallsIgnoresFencedCodeBlock(t *testing.T) { + fenced := "```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```" + if calls := ParseStandaloneToolCalls(fenced, []string{"search"}); len(calls) != 0 { + t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls) + } +} diff --git a/internal/util/util_edge_test.go b/internal/util/util_edge_test.go new file mode 100644 index 0000000..876cd04 --- /dev/null +++ b/internal/util/util_edge_test.go @@ -0,0 +1,429 @@ +package util + +import ( + "encoding/json" + "net/http/httptest" + "strings" + "testing" + + "ds2api/internal/config" +) + +// ─── EstimateTokens edge cases ─────────────────────────────────────── + +func TestEstimateTokensEmpty(t *testing.T) { + if got := 
EstimateTokens(""); got != 0 { + t.Fatalf("expected 0 for empty string, got %d", got) + } +} + +func TestEstimateTokensShortASCII(t *testing.T) { + got := EstimateTokens("ab") + if got != 1 { + t.Fatalf("expected 1 for 2 ascii chars, got %d", got) + } +} + +func TestEstimateTokensLongASCII(t *testing.T) { + got := EstimateTokens(strings.Repeat("x", 100)) + if got != 25 { + t.Fatalf("expected 25 for 100 ascii chars, got %d", got) + } +} + +func TestEstimateTokensChinese(t *testing.T) { + got := EstimateTokens("你好世界") + if got < 1 { + t.Fatalf("expected at least 1 token for Chinese text, got %d", got) + } +} + +func TestEstimateTokensMixed(t *testing.T) { + got := EstimateTokens("Hello 你好世界") + if got < 2 { + t.Fatalf("expected at least 2 tokens for mixed text, got %d", got) + } +} + +func TestEstimateTokensSingleByte(t *testing.T) { + got := EstimateTokens("x") + if got != 1 { + t.Fatalf("expected 1 for single char (minimum), got %d", got) + } +} + +func TestEstimateTokensSingleChinese(t *testing.T) { + got := EstimateTokens("你") + if got != 1 { + t.Fatalf("expected 1 for single Chinese char, got %d", got) + } +} + +// ─── ToBool edge cases ─────────────────────────────────────────────── + +func TestToBoolTrue(t *testing.T) { + if !ToBool(true) { + t.Fatal("expected true") + } +} + +func TestToBoolFalse(t *testing.T) { + if ToBool(false) { + t.Fatal("expected false") + } +} + +func TestToBoolNonBool(t *testing.T) { + if ToBool("true") { + t.Fatal("expected false for string 'true'") + } + if ToBool(1) { + t.Fatal("expected false for int 1") + } + if ToBool(nil) { + t.Fatal("expected false for nil") + } +} + +// ─── IntFrom edge cases ───────────────────────────────────────────── + +func TestIntFromFloat64(t *testing.T) { + if got := IntFrom(float64(42.5)); got != 42 { + t.Fatalf("expected 42 for float64(42.5), got %d", got) + } +} + +func TestIntFromInt(t *testing.T) { + if got := IntFrom(int(42)); got != 42 { + t.Fatalf("expected 42, got %d", got) + } +} + +func 
TestIntFromInt64(t *testing.T) { + if got := IntFrom(int64(42)); got != 42 { + t.Fatalf("expected 42, got %d", got) + } +} + +func TestIntFromString(t *testing.T) { + if got := IntFrom("42"); got != 0 { + t.Fatalf("expected 0 for string, got %d", got) + } +} + +func TestIntFromNil(t *testing.T) { + if got := IntFrom(nil); got != 0 { + t.Fatalf("expected 0 for nil, got %d", got) + } +} + +// ─── WriteJSON ─────────────────────────────────────────────────────── + +func TestWriteJSON(t *testing.T) { + rec := httptest.NewRecorder() + WriteJSON(rec, 200, map[string]any{"key": "value"}) + if rec.Code != 200 { + t.Fatalf("expected 200, got %d", rec.Code) + } + if ct := rec.Header().Get("Content-Type"); ct != "application/json" { + t.Fatalf("expected application/json content type, got %q", ct) + } + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("decode error: %v", err) + } + if body["key"] != "value" { + t.Fatalf("unexpected body: %#v", body) + } +} + +func TestWriteJSONStatusCodes(t *testing.T) { + for _, code := range []int{200, 201, 400, 404, 500} { + rec := httptest.NewRecorder() + WriteJSON(rec, code, map[string]any{"status": code}) + if rec.Code != code { + t.Fatalf("expected %d, got %d", code, rec.Code) + } + } +} + +// ─── MessagesPrepare edge cases ────────────────────────────────────── + +func TestMessagesPrepareEmpty(t *testing.T) { + got := MessagesPrepare(nil) + if got != "" { + t.Fatalf("expected empty for nil messages, got %q", got) + } +} + +func TestMessagesPrepareMergesConsecutiveSameRole(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": "Hello"}, + {"role": "user", "content": "World"}, + } + got := MessagesPrepare(messages) + if !strings.Contains(got, "Hello") || !strings.Contains(got, "World") { + t.Fatalf("expected both messages, got %q", got) + } + // Should be merged without <|User|> between them + count := strings.Count(got, "<|User|>") + if count != 0 { + 
t.Fatalf("expected no User marker for first message pair, got %d occurrences", count) + } +} + +func TestMessagesPrepareAssistantMarkers(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + } + got := MessagesPrepare(messages) + if !strings.Contains(got, "<|Assistant|>") { + t.Fatalf("expected assistant marker, got %q", got) + } + if !strings.Contains(got, "<|end▁of▁sentence|>") { + t.Fatalf("expected end of sentence marker, got %q", got) + } +} + +func TestMessagesPrepareUnknownRole(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": "Hello"}, + {"role": "unknown_role", "content": "Unknown"}, + } + got := MessagesPrepare(messages) + if !strings.Contains(got, "Unknown") { + t.Fatalf("expected unknown role content, got %q", got) + } +} + +func TestMessagesPrepareMarkdownImageReplaced(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": "Look at this: ![alt](https://example.com/img.png)"}, + } + got := MessagesPrepare(messages) + if strings.Contains(got, "![alt]") { + t.Fatalf("expected markdown image to be replaced, got %q", got) + } +} + +func TestMessagesPrepareNilContent(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": nil}, + } + got := MessagesPrepare(messages) + if got != "null" { + t.Logf("nil content handled as: %q", got) + } +} + +// ─── normalizeContent edge cases ───────────────────────────────────── + +func TestNormalizeContentString(t *testing.T) { + got := normalizeContent("hello") + if got != "hello" { + t.Fatalf("expected 'hello', got %q", got) + } +} + +func TestNormalizeContentArray(t *testing.T) { + got := normalizeContent([]any{ + map[string]any{"type": "text", "text": "line1"}, + map[string]any{"type": "text", "text": "line2"}, + }) + if got != "line1\nline2" { + t.Fatalf("expected 'line1\\nline2', got %q", got) + } +} + +func TestNormalizeContentArrayWithContentField(t *testing.T) { + 
got := normalizeContent([]any{ + map[string]any{"type": "text", "content": "from-content"}, + }) + if got != "from-content" { + t.Fatalf("expected 'from-content', got %q", got) + } +} + +func TestNormalizeContentArraySkipsImage(t *testing.T) { + got := normalizeContent([]any{ + map[string]any{"type": "image_url", "image_url": "https://example.com/img.png"}, + map[string]any{"type": "text", "text": "caption"}, + }) + if strings.Contains(got, "image") { + t.Fatalf("expected image skipped, got %q", got) + } + if got != "caption" { + t.Fatalf("expected 'caption', got %q", got) + } +} + +func TestNormalizeContentArrayNonMapItems(t *testing.T) { + got := normalizeContent([]any{"string item", 42}) + if got != "" { + t.Fatalf("expected empty for non-map items, got %q", got) + } +} + +func TestNormalizeContentJSON(t *testing.T) { + got := normalizeContent(map[string]any{"key": "value"}) + if !strings.Contains(got, `"key":"value"`) { + t.Fatalf("expected JSON serialized, got %q", got) + } +} + +// ─── ConvertClaudeToDeepSeek edge cases ────────────────────────────── + +func TestConvertClaudeToDeepSeekDefaultModel(t *testing.T) { + store := config.LoadStore() + req := map[string]any{ + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, + } + out := ConvertClaudeToDeepSeek(req, store) + if out["model"] == "" { + t.Fatal("expected default model") + } +} + +func TestConvertClaudeToDeepSeekWithStopSequences(t *testing.T) { + store := config.LoadStore() + req := map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, + "stop_sequences": []any{"\n\n"}, + } + out := ConvertClaudeToDeepSeek(req, store) + if out["stop"] == nil { + t.Fatal("expected stop field from stop_sequences") + } +} + +func TestConvertClaudeToDeepSeekWithTemperature(t *testing.T) { + store := config.LoadStore() + req := map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, 
+ "temperature": 0.7, + "top_p": 0.9, + } + out := ConvertClaudeToDeepSeek(req, store) + if out["temperature"] != 0.7 { + t.Fatalf("expected temperature 0.7, got %v", out["temperature"]) + } + if out["top_p"] != 0.9 { + t.Fatalf("expected top_p 0.9, got %v", out["top_p"]) + } +} + +func TestConvertClaudeToDeepSeekNoSystem(t *testing.T) { + store := config.LoadStore() + req := map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, + } + out := ConvertClaudeToDeepSeek(req, store) + msgs, _ := out["messages"].([]any) + if len(msgs) != 1 { + t.Fatalf("expected 1 message without system, got %d", len(msgs)) + } +} + +func TestConvertClaudeToDeepSeekOpusUsesSlowMapping(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[],"claude_mapping":{"fast":"deepseek-chat","slow":"deepseek-reasoner"}}`) + store := config.LoadStore() + req := map[string]any{ + "model": "claude-opus-4-6", + "messages": []any{map[string]any{"role": "user", "content": "Hi"}}, + } + out := ConvertClaudeToDeepSeek(req, store) + if out["model"] != "deepseek-reasoner" { + t.Fatalf("expected opus to use slow mapping, got %q", out["model"]) + } +} + +// ─── FormatOpenAIStreamToolCalls ───────────────────────────────────── + +func TestFormatOpenAIStreamToolCalls(t *testing.T) { + formatted := FormatOpenAIStreamToolCalls([]ParsedToolCall{ + {Name: "search", Input: map[string]any{"q": "test"}}, + }) + if len(formatted) != 1 { + t.Fatalf("expected 1, got %d", len(formatted)) + } + fn, _ := formatted[0]["function"].(map[string]any) + if fn["name"] != "search" { + t.Fatalf("unexpected function name: %#v", fn) + } + if formatted[0]["index"] != 0 { + t.Fatalf("expected index 0, got %v", formatted[0]["index"]) + } +} + +// ─── ParseToolCalls more edge cases ────────────────────────────────── + +func TestParseToolCallsNoToolNames(t *testing.T) { + text := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}` + calls := 
ParseToolCalls(text, nil) + if len(calls) != 0 { + t.Fatalf("expected 0 call with nil tool names, got %d", len(calls)) + } +} + +func TestParseToolCallsEmptyText(t *testing.T) { + calls := ParseToolCalls("", []string{"search"}) + if len(calls) != 0 { + t.Fatalf("expected 0 calls for empty text, got %d", len(calls)) + } +} + +func TestParseToolCallsMultipleTools(t *testing.T) { + text := `{"tool_calls":[{"name":"search","input":{"q":"go"}},{"name":"get_weather","input":{"city":"beijing"}}]}` + calls := ParseToolCalls(text, []string{"search", "get_weather"}) + if len(calls) != 2 { + t.Fatalf("expected 2 calls, got %d", len(calls)) + } +} + +func TestParseToolCallsInputAsString(t *testing.T) { + text := `{"tool_calls":[{"name":"search","input":"{\"q\":\"golang\"}"}]}` + calls := ParseToolCalls(text, []string{"search"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %d", len(calls)) + } + if calls[0].Input["q"] != "golang" { + t.Fatalf("expected parsed string input, got %#v", calls[0].Input) + } +} + +func TestParseToolCallsWithFunctionWrapper(t *testing.T) { + text := `{"tool_calls":[{"function":{"name":"calc","arguments":{"x":1,"y":2}}}]}` + calls := ParseToolCalls(text, []string{"calc"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %d", len(calls)) + } + if calls[0].Name != "calc" { + t.Fatalf("expected calc, got %q", calls[0].Name) + } +} + +func TestParseStandaloneToolCallsFencedCodeBlock(t *testing.T) { + fenced := "Here's an example:\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\nDon't execute this." 
+ calls := ParseStandaloneToolCalls(fenced, []string{"search"}) + if len(calls) != 0 { + t.Fatalf("expected fenced code block ignored, got %d calls", len(calls)) + } +} + +// ─── looksLikeToolExampleContext ───────────────────────────────────── + +func TestLooksLikeToolExampleContextNone(t *testing.T) { + if looksLikeToolExampleContext("I will call the tool now") { + t.Fatal("expected false for non-example context") + } +} + +func TestLooksLikeToolExampleContextFenced(t *testing.T) { + if !looksLikeToolExampleContext("```json") { + t.Fatal("expected true for fenced code block context") + } +} diff --git a/opencode.json.example b/opencode.json.example index 2933e9f..ed18a63 100644 --- a/opencode.json.example +++ b/opencode.json.example @@ -9,6 +9,12 @@ "apiKey": "your-api-key" }, "models": { + "gpt-4o": { + "name": "GPT-4o (aliased to deepseek-chat)" + }, + "gpt-5-codex": { + "name": "GPT-5 Codex (aliased to deepseek-reasoner)" + }, "deepseek-chat": { "name": "DeepSeek Chat (DS2API)" }, @@ -18,5 +24,5 @@ } } }, - "model": "ds2api/deepseek-chat" + "model": "ds2api/gpt-5-codex" } diff --git a/plans/node-syntax-gate-targets.txt b/plans/node-syntax-gate-targets.txt new file mode 100644 index 0000000..7b268a8 --- /dev/null +++ b/plans/node-syntax-gate-targets.txt @@ -0,0 +1,22 @@ +# Node split syntax gate targets +# Keep this list in sync with api/chat-stream and internal/js/helpers/stream-tool-sieve split modules. 
+ +api/chat-stream.js +internal/js/chat-stream/index.js +internal/js/chat-stream/error_shape.js +internal/js/chat-stream/http_internal.js +internal/js/chat-stream/proxy_go.js +internal/js/chat-stream/sse_parse.js +internal/js/chat-stream/stream_emitter.js +internal/js/chat-stream/token_usage.js +internal/js/chat-stream/toolcall_policy.js +internal/js/chat-stream/vercel_stream.js + +internal/js/helpers/stream-tool-sieve.js +internal/js/helpers/stream-tool-sieve/index.js +internal/js/helpers/stream-tool-sieve/state.js +internal/js/helpers/stream-tool-sieve/sieve.js +internal/js/helpers/stream-tool-sieve/incremental.js +internal/js/helpers/stream-tool-sieve/jsonscan.js +internal/js/helpers/stream-tool-sieve/parse.js +internal/js/helpers/stream-tool-sieve/format.js diff --git a/plans/refactor-baseline.md b/plans/refactor-baseline.md new file mode 100644 index 0000000..151683a --- /dev/null +++ b/plans/refactor-baseline.md @@ -0,0 +1,32 @@ +# DS2API Refactor Baseline (Historical Snapshot) + +- Snapshot time: `2026-02-22T08:53:54Z` +- Snapshot branch: `dev` +- Snapshot HEAD: `5d3989a` +- Scope: backend + node api + webui large-file decoupling (no behavior change) + +## Gate Commands + +1. `./tests/scripts/run-unit-all.sh` + - Result: PASS + - Includes: + - `go test ./...` + - `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js` +2. `npm --prefix webui run build` + - Result: PASS +3. `./tests/scripts/check-refactor-line-gate.sh` + - Result: PASS (`checked=131 missing=0 over_limit=0`) +4. 
Stage gates (1-5) replay: + - `go test ./internal/config ./internal/admin ./internal/account ./internal/deepseek ./internal/format/openai` -> PASS + - `go test ./internal/adapter/openai ./internal/util ./internal/sse ./internal/compat` -> PASS + - `go test ./internal/adapter/claude ./internal/adapter/gemini ./internal/config` -> PASS + - `go test ./internal/testsuite ./cmd/ds2api-tests` -> PASS + - `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js` -> PASS +5. Final full regression: + - `go test ./... -count=1` -> PASS + +## Notes + +- This file records a historical baseline for refactor process tracking. +- It is not intended to represent the current repository HEAD. +- Frontend manual smoke for phase 6 still requires human execution and sign-off. diff --git a/plans/refactor-line-gate-targets.txt b/plans/refactor-line-gate-targets.txt new file mode 100644 index 0000000..c9839b2 --- /dev/null +++ b/plans/refactor-line-gate-targets.txt @@ -0,0 +1,151 @@ +# Line gate targets for large-file decoupling refactor. 
+# Default limit: 300 lines +# Entry/facade limit: 120 lines (enforced in script) + +internal/config/config.go +internal/config/logger.go +internal/config/paths.go +internal/config/codec.go +internal/config/store.go +internal/config/store_index.go +internal/config/store_accessors.go +internal/config/account.go + +internal/admin/handler_config_read.go +internal/admin/handler_config_write.go +internal/admin/handler_config_import.go +internal/admin/handler_settings_read.go +internal/admin/handler_settings_write.go +internal/admin/handler_settings_parse.go +internal/admin/handler_settings_runtime.go +internal/admin/handler_accounts_crud.go +internal/admin/handler_accounts_testing.go +internal/admin/handler_accounts_queue.go + +internal/account/pool_core.go +internal/account/pool_acquire.go +internal/account/pool_waiters.go +internal/account/pool_limits.go + +internal/deepseek/client_core.go +internal/deepseek/client_auth.go +internal/deepseek/client_completion.go +internal/deepseek/client_http_json.go +internal/deepseek/client_http_helpers.go + +internal/format/openai/render_chat.go +internal/format/openai/render_responses.go +internal/format/openai/render_stream_events.go +internal/format/openai/render_usage.go + +internal/adapter/openai/handler_routes.go +internal/adapter/openai/handler_chat.go +internal/adapter/openai/handler_errors.go +internal/adapter/openai/handler_toolcall_policy.go +internal/adapter/openai/handler_toolcall_format.go +internal/adapter/openai/responses_handler.go +internal/adapter/openai/responses_input_normalize.go +internal/adapter/openai/responses_input_items.go +internal/adapter/openai/responses_stream_runtime_core.go +internal/adapter/openai/responses_stream_runtime_events.go +internal/adapter/openai/responses_stream_runtime_toolcalls.go +internal/adapter/openai/tool_sieve_state.go +internal/adapter/openai/tool_sieve_core.go +internal/adapter/openai/tool_sieve_incremental.go +internal/adapter/openai/tool_sieve_jsonscan.go + 
+internal/util/toolcalls_parse.go +internal/util/toolcalls_candidates.go +internal/util/toolcalls_format.go + +internal/adapter/claude/handler_routes.go +internal/adapter/claude/handler_messages.go +internal/adapter/claude/handler_tokens.go +internal/adapter/claude/handler_errors.go +internal/adapter/claude/handler_utils.go +internal/adapter/claude/stream_runtime_core.go +internal/adapter/claude/stream_runtime_emit.go +internal/adapter/claude/stream_runtime_finalize.go + +internal/adapter/gemini/handler_routes.go +internal/adapter/gemini/handler_generate.go +internal/adapter/gemini/handler_stream_runtime.go +internal/adapter/gemini/handler_errors.go +internal/adapter/gemini/convert_request.go +internal/adapter/gemini/convert_messages.go +internal/adapter/gemini/convert_tools.go +internal/adapter/gemini/convert_passthrough.go + +internal/testsuite/runner_core.go +internal/testsuite/runner_env.go +internal/testsuite/runner_http.go +internal/testsuite/runner_cases_openai.go +internal/testsuite/runner_cases_openai_advanced.go +internal/testsuite/runner_cases_admin.go +internal/testsuite/runner_cases_claude.go +internal/testsuite/runner_summary.go +internal/testsuite/runner_utils.go +internal/testsuite/runner_defaults.go +internal/testsuite/runner_registry.go +internal/testsuite/edge_cases_abort.go +internal/testsuite/edge_cases_error_contract.go + +api/chat-stream.js +internal/js/chat-stream/index.js +internal/js/chat-stream/vercel_stream.js +internal/js/chat-stream/proxy_go.js +internal/js/chat-stream/sse_parse.js +internal/js/chat-stream/http_internal.js +internal/js/chat-stream/toolcall_policy.js +internal/js/chat-stream/error_shape.js +internal/js/chat-stream/token_usage.js +internal/js/chat-stream/stream_emitter.js + +internal/js/helpers/stream-tool-sieve.js +internal/js/helpers/stream-tool-sieve/index.js +internal/js/helpers/stream-tool-sieve/state.js +internal/js/helpers/stream-tool-sieve/sieve.js +internal/js/helpers/stream-tool-sieve/incremental.js 
+internal/js/helpers/stream-tool-sieve/jsonscan.js +internal/js/helpers/stream-tool-sieve/parse.js +internal/js/helpers/stream-tool-sieve/format.js + +webui/src/App.jsx +webui/src/app/AppRoutes.jsx +webui/src/app/useAdminAuth.js +webui/src/app/useAdminConfig.js +webui/src/layout/DashboardShell.jsx + +webui/src/components/AccountManager.jsx +webui/src/features/account/AccountManagerContainer.jsx +webui/src/features/account/useAccountsData.js +webui/src/features/account/useAccountActions.js +webui/src/features/account/QueueCards.jsx +webui/src/features/account/ApiKeysPanel.jsx +webui/src/features/account/AccountsTable.jsx +webui/src/features/account/AddKeyModal.jsx +webui/src/features/account/AddAccountModal.jsx + +webui/src/components/ApiTester.jsx +webui/src/features/apiTester/ApiTesterContainer.jsx +webui/src/features/apiTester/useApiTesterState.js +webui/src/features/apiTester/useChatStreamClient.js +webui/src/features/apiTester/ConfigPanel.jsx +webui/src/features/apiTester/ChatPanel.jsx + +webui/src/components/Settings.jsx +webui/src/features/settings/SettingsContainer.jsx +webui/src/features/settings/useSettingsForm.js +webui/src/features/settings/settingsApi.js +webui/src/features/settings/SecuritySection.jsx +webui/src/features/settings/RuntimeSection.jsx +webui/src/features/settings/BehaviorSection.jsx +webui/src/features/settings/ModelSection.jsx +webui/src/features/settings/BackupSection.jsx + +webui/src/components/VercelSync.jsx +webui/src/features/vercel/VercelSyncContainer.jsx +webui/src/features/vercel/useVercelSyncState.js +webui/src/features/vercel/VercelSyncForm.jsx +webui/src/features/vercel/VercelSyncStatus.jsx +webui/src/features/vercel/VercelGuide.jsx diff --git a/plans/refactor-line-gate.md b/plans/refactor-line-gate.md new file mode 100644 index 0000000..86f0d82 --- /dev/null +++ b/plans/refactor-line-gate.md @@ -0,0 +1,21 @@ +# Refactor Line Gate + +## Rules + +1. Production file default upper bound: `<= 300` lines. +2. 
Entry/facade files upper bound: `<= 120` lines. +3. Scope is limited to target files in `plans/refactor-line-gate-targets.txt`. +4. Test files are out of scope for this gate. + +## Command + +```bash +./tests/scripts/check-refactor-line-gate.sh +``` + +## Naming Note + +- Original split plan used `internal/admin/handler_accounts_test.go` for account probing logic. +- In Go, `*_test.go` files are test-only compilation units and cannot host production handlers. +- The production file is implemented as `internal/admin/handler_accounts_testing.go`. + diff --git a/plans/stage6-manual-smoke.md b/plans/stage6-manual-smoke.md new file mode 100644 index 0000000..4c06d85 --- /dev/null +++ b/plans/stage6-manual-smoke.md @@ -0,0 +1,28 @@ +# Stage 6 Manual Smoke Checklist + +- Date: 2026-02-22 +- Tester: release-maintainer +- Environment: local macOS + latest Chrome + +## Items + +1. Login flow (`/admin/login`) succeeds and failure message shape unchanged. +2. Account manager: + - add/edit/delete account + - queue status cards render and refresh +3. API tester: + - non-stream request succeeds + - stream request receives incremental output and final state +4. Settings: + - read settings + - save settings + - backup/export path works +5. Vercel sync: + - status poll + - manual refresh + - sync action and status feedback text + +## Result + +- Status: `PASS` +- Notes: login/account/api-tester/settings/vercel-sync smoke passed with no behavior regressions. 
diff --git a/tests/compat/expected/sse_fragments_append.json b/tests/compat/expected/sse_fragments_append.json new file mode 100644 index 0000000..8647f3a --- /dev/null +++ b/tests/compat/expected/sse_fragments_append.json @@ -0,0 +1,8 @@ +{ + "parts": [ + {"text": "思考中", "type": "thinking"}, + {"text": "结论", "type": "text"} + ], + "finished": false, + "new_type": "text" +} diff --git a/tests/compat/expected/sse_nested_finished.json b/tests/compat/expected/sse_nested_finished.json new file mode 100644 index 0000000..7d588f7 --- /dev/null +++ b/tests/compat/expected/sse_nested_finished.json @@ -0,0 +1,5 @@ +{ + "parts": [], + "finished": true, + "new_type": "text" +} diff --git a/tests/compat/expected/sse_split_tool_json.json b/tests/compat/expected/sse_split_tool_json.json new file mode 100644 index 0000000..2afed2a --- /dev/null +++ b/tests/compat/expected/sse_split_tool_json.json @@ -0,0 +1,8 @@ +{ + "parts": [ + {"text": "{\"", "type": "text"}, + {"text": "tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}", "type": "text"} + ], + "finished": false, + "new_type": "text" +} diff --git a/tests/compat/expected/token_cases.json b/tests/compat/expected/token_cases.json new file mode 100644 index 0000000..69694eb --- /dev/null +++ b/tests/compat/expected/token_cases.json @@ -0,0 +1,7 @@ +{ + "cases": [ + {"name": "ascii_short", "tokens": 1}, + {"name": "cjk", "tokens": 3}, + {"name": "mixed", "tokens": 4} + ] +} diff --git a/tests/compat/expected/toolcalls_fenced_json.json b/tests/compat/expected/toolcalls_fenced_json.json new file mode 100644 index 0000000..97646bf --- /dev/null +++ b/tests/compat/expected/toolcalls_fenced_json.json @@ -0,0 +1,3 @@ +{ + "calls": [] +} diff --git a/tests/compat/expected/toolcalls_unknown_name.json b/tests/compat/expected/toolcalls_unknown_name.json new file mode 100644 index 0000000..97646bf --- /dev/null +++ b/tests/compat/expected/toolcalls_unknown_name.json @@ -0,0 +1,3 @@ +{ + "calls": [] +} diff --git 
a/tests/compat/fixtures/sse_chunks/fragments_append.json b/tests/compat/fixtures/sse_chunks/fragments_append.json new file mode 100644 index 0000000..c6f8ae6 --- /dev/null +++ b/tests/compat/fixtures/sse_chunks/fragments_append.json @@ -0,0 +1,12 @@ +{ + "chunk": { + "p": "response/fragments", + "o": "APPEND", + "v": [ + {"type": "THINK", "content": "思考中"}, + {"type": "RESPONSE", "content": "结论"} + ] + }, + "thinking_enabled": true, + "current_type": "thinking" +} diff --git a/tests/compat/fixtures/sse_chunks/nested_finished.json b/tests/compat/fixtures/sse_chunks/nested_finished.json new file mode 100644 index 0000000..da76280 --- /dev/null +++ b/tests/compat/fixtures/sse_chunks/nested_finished.json @@ -0,0 +1,10 @@ +{ + "chunk": { + "p": "response", + "v": [ + {"p": "status", "v": "FINISHED"} + ] + }, + "thinking_enabled": false, + "current_type": "text" +} diff --git a/tests/compat/fixtures/sse_chunks/split_tool_json.json b/tests/compat/fixtures/sse_chunks/split_tool_json.json new file mode 100644 index 0000000..e915fbb --- /dev/null +++ b/tests/compat/fixtures/sse_chunks/split_tool_json.json @@ -0,0 +1,11 @@ +{ + "chunk": { + "p": "response", + "v": [ + {"p": "response/content", "v": "{\""}, + {"p": "response/content", "v": "tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}"} + ] + }, + "thinking_enabled": false, + "current_type": "text" +} diff --git a/tests/compat/fixtures/token_cases.json b/tests/compat/fixtures/token_cases.json new file mode 100644 index 0000000..3887356 --- /dev/null +++ b/tests/compat/fixtures/token_cases.json @@ -0,0 +1,7 @@ +{ + "cases": [ + {"name": "ascii_short", "text": "abcd"}, + {"name": "cjk", "text": "你好世界"}, + {"name": "mixed", "text": "Hello 你好世界"} + ] +} diff --git a/tests/compat/fixtures/toolcalls/fenced_json.json b/tests/compat/fixtures/toolcalls/fenced_json.json new file mode 100644 index 0000000..8d75cc1 --- /dev/null +++ b/tests/compat/fixtures/toolcalls/fenced_json.json @@ -0,0 +1,4 @@ +{ + 
"text": "```json\n{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}\n```", + "tool_names": ["read_file"] +} diff --git a/tests/compat/fixtures/toolcalls/unknown_name.json b/tests/compat/fixtures/toolcalls/unknown_name.json new file mode 100644 index 0000000..0ba9e76 --- /dev/null +++ b/tests/compat/fixtures/toolcalls/unknown_name.json @@ -0,0 +1,4 @@ +{ + "text": "{\"tool_calls\":[{\"name\":\"unknown_tool\",\"input\":{\"x\":1}}]}", + "tool_names": ["read_file"] +} diff --git a/api/chat-stream.test.js b/tests/node/chat-stream.test.js similarity index 69% rename from api/chat-stream.test.js rename to tests/node/chat-stream.test.js index b347342..e31afbe 100644 --- a/api/chat-stream.test.js +++ b/tests/node/chat-stream.test.js @@ -3,17 +3,57 @@ const test = require('node:test'); const assert = require('node:assert/strict'); -const handler = require('./chat-stream'); +const handler = require('../../api/chat-stream.js'); const { createToolSieveState, processToolSieveChunk, flushToolSieve, -} = require('./helpers/stream-tool-sieve'); +} = require('../../internal/js/helpers/stream-tool-sieve.js'); -const { parseChunkForContent } = handler.__test; +const { + parseChunkForContent, + resolveToolcallPolicy, + normalizePreparedToolNames, + boolDefaultTrue, +} = handler.__test; test('chat-stream exposes parser test hooks', () => { assert.equal(typeof parseChunkForContent, 'function'); + assert.equal(typeof resolveToolcallPolicy, 'function'); +}); + +test('resolveToolcallPolicy defaults to feature-match + early emit when prepare flags missing', () => { + const policy = resolveToolcallPolicy( + {}, + [{ type: 'function', function: { name: 'read_file', parameters: { type: 'object' } } }], + ); + assert.deepEqual(policy.toolNames, ['read_file']); + assert.equal(policy.toolSieveEnabled, true); + assert.equal(policy.emitEarlyToolDeltas, true); +}); + +test('resolveToolcallPolicy respects prepare flags and prepared tool names', () => { + const policy = 
resolveToolcallPolicy( + { + tool_names: [' prepped_tool ', '', null], + toolcall_feature_match: false, + toolcall_early_emit_high: false, + }, + [{ type: 'function', function: { name: 'fallback_tool', parameters: { type: 'object' } } }], + ); + assert.deepEqual(policy.toolNames, ['prepped_tool']); + assert.equal(policy.toolSieveEnabled, false); + assert.equal(policy.emitEarlyToolDeltas, false); +}); + +test('normalizePreparedToolNames filters empty values', () => { + assert.deepEqual(normalizePreparedToolNames([' a ', '', null, 'b']), ['a', 'b']); +}); + +test('boolDefaultTrue keeps false only when explicitly false', () => { + assert.equal(boolDefaultTrue(false), false); + assert.equal(boolDefaultTrue(true), true); + assert.equal(boolDefaultTrue(undefined), true); }); test('parseChunkForContent keeps split response/content fragments inside response array', () => { @@ -49,12 +89,13 @@ test('parseChunkForContent + sieve does not leak suspicious prefix in split tool events.push(...flushToolSieve(state, ['read_file'])); const hasToolCalls = events.some((evt) => evt.type === 'tool_calls' && evt.calls && evt.calls.length > 0); + const hasToolDeltas = events.some((evt) => evt.type === 'tool_call_deltas' && evt.deltas && evt.deltas.length > 0); const leakedText = events .filter((evt) => evt.type === 'text' && evt.text) .map((evt) => evt.text) .join(''); - assert.equal(hasToolCalls, true); + assert.equal(hasToolCalls || hasToolDeltas, true); assert.equal(leakedText.includes('{'), false); assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); }); diff --git a/tests/node/js_compat_test.js b/tests/node/js_compat_test.js new file mode 100644 index 0000000..0029abe --- /dev/null +++ b/tests/node/js_compat_test.js @@ -0,0 +1,60 @@ +'use strict'; + +const test = require('node:test'); +const assert = require('node:assert/strict'); +const fs = require('node:fs'); +const path = require('node:path'); + +const chatStream = require('../../api/chat-stream.js'); +const { 
parseToolCalls } = require('../../internal/js/helpers/stream-tool-sieve.js'); + +const { parseChunkForContent, estimateTokens } = chatStream.__test; + +const compatRoot = path.resolve(__dirname, '../../tests/compat'); + +function readJSON(filePath) { + return JSON.parse(fs.readFileSync(filePath, 'utf8')); +} + +test('js compat: sse fixtures', () => { + const fixtureDir = path.join(compatRoot, 'fixtures', 'sse_chunks'); + const expectedDir = path.join(compatRoot, 'expected'); + const files = fs.readdirSync(fixtureDir).filter((f) => f.endsWith('.json')).sort(); + assert.ok(files.length > 0); + + for (const file of files) { + const name = file.replace(/\.json$/i, ''); + const fixture = readJSON(path.join(fixtureDir, file)); + const expected = readJSON(path.join(expectedDir, `sse_${name}.json`)); + const got = parseChunkForContent(fixture.chunk, Boolean(fixture.thinking_enabled), fixture.current_type || 'text'); + assert.deepEqual(got.parts, expected.parts, `${name}: parts mismatch`); + assert.equal(got.finished, expected.finished, `${name}: finished mismatch`); + assert.equal(got.newType, expected.new_type, `${name}: newType mismatch`); + } +}); + +test('js compat: toolcall fixtures', () => { + const fixtureDir = path.join(compatRoot, 'fixtures', 'toolcalls'); + const expectedDir = path.join(compatRoot, 'expected'); + const files = fs.readdirSync(fixtureDir).filter((f) => f.endsWith('.json')).sort(); + assert.ok(files.length > 0); + + for (const file of files) { + const name = file.replace(/\.json$/i, ''); + const fixture = readJSON(path.join(fixtureDir, file)); + const expected = readJSON(path.join(expectedDir, `toolcalls_${name}.json`)); + const got = parseToolCalls(fixture.text, fixture.tool_names || []); + assert.deepEqual(got, expected.calls, `${name}: calls mismatch`); + } +}); + +test('js compat: token fixtures', () => { + const fixture = readJSON(path.join(compatRoot, 'fixtures', 'token_cases.json')); + const expected = readJSON(path.join(compatRoot, 
'expected', 'token_cases.json')); + const expectedByName = new Map(expected.cases.map((c) => [c.name, c.tokens])); + for (const c of fixture.cases) { + assert.ok(expectedByName.has(c.name), `missing expected case: ${c.name}`); + const got = estimateTokens(c.text); + assert.equal(got, expectedByName.get(c.name), `${c.name}: tokens mismatch`); + } +}); diff --git a/tests/node/stream-tool-sieve.test.js b/tests/node/stream-tool-sieve.test.js new file mode 100644 index 0000000..f20cb11 --- /dev/null +++ b/tests/node/stream-tool-sieve.test.js @@ -0,0 +1,217 @@ +'use strict'; + +const test = require('node:test'); +const assert = require('node:assert/strict'); + +const { + extractToolNames, + createToolSieveState, + processToolSieveChunk, + flushToolSieve, + parseToolCalls, + parseStandaloneToolCalls, +} = require('../../internal/js/helpers/stream-tool-sieve.js'); + +function runSieve(chunks, toolNames) { + const state = createToolSieveState(); + const events = []; + for (const chunk of chunks) { + events.push(...processToolSieveChunk(state, chunk, toolNames)); + } + events.push(...flushToolSieve(state, toolNames)); + return events; +} + +function collectText(events) { + return events + .filter((evt) => evt.type === 'text' && evt.text) + .map((evt) => evt.text) + .join(''); +} + +test('extractToolNames keeps tool mode enabled with unknown fallback', () => { + const names = extractToolNames([ + { function: { description: 'no name tool' } }, + { function: { name: ' read_file ' } }, + {}, + ]); + assert.deepEqual(names, ['unknown', 'read_file', 'unknown']); +}); + +test('parseToolCalls keeps non-object argument strings as _raw (Go parity)', () => { + const payload = JSON.stringify({ + tool_calls: [ + { name: 'read_file', input: '123' }, + { name: 'list_dir', input: '[1,2,3]' }, + ], + }); + const calls = parseToolCalls(payload, ['read_file', 'list_dir']); + assert.deepEqual(calls, [ + { name: 'read_file', input: { _raw: '123' } }, + { name: 'list_dir', input: { _raw: 
'[1,2,3]' } }, + ]); +}); + +test('parseToolCalls drops unknown schema names when toolNames is provided', () => { + const payload = JSON.stringify({ + tool_calls: [{ name: 'not_in_schema', input: { q: 'go' } }], + }); + const calls = parseToolCalls(payload, ['search']); + assert.equal(calls.length, 0); +}); + +test('parseToolCalls keeps unknown names when toolNames is empty', () => { + const payload = JSON.stringify({ + tool_calls: [{ name: 'not_in_schema', input: { q: 'go' } }], + }); + const calls = parseToolCalls(payload, []); + assert.equal(calls.length, 1); + assert.equal(calls[0].name, 'not_in_schema'); +}); + +test('parseToolCalls supports fenced json and function.arguments string payload', () => { + const text = [ + 'I will call a tool now.', + '```json', + '{"tool_calls":[{"function":{"name":"read_file","arguments":"{\\"path\\":\\"README.md\\"}"}}]}', + '```', + ].join('\n'); + const calls = parseToolCalls(text, ['read_file']); + assert.equal(calls.length, 0); +}); + +test('parseStandaloneToolCalls only matches standalone payload and ignores mixed prose', () => { + const mixed = '这里是示例:{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]},请勿执行。'; + const standalone = '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}'; + const mixedCalls = parseStandaloneToolCalls(mixed, ['read_file']); + const standaloneCalls = parseStandaloneToolCalls(standalone, ['read_file']); + assert.equal(mixedCalls.length, 0); + assert.equal(standaloneCalls.length, 1); +}); + +test('parseStandaloneToolCalls ignores fenced code block tool_call examples', () => { + const fenced = ['```json', '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}', '```'].join('\n'); + const calls = parseStandaloneToolCalls(fenced, ['read_file']); + assert.equal(calls.length, 0); +}); + +test('sieve emits tool_calls and does not leak suspicious prefix on late key convergence', () => { + const events = runSieve( + [ + '{"', + 
'tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}', + '后置正文C。', + ], + ['read_file'], + ); + const leakedText = collectText(events); + const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0); + const hasToolDelta = events.some((evt) => evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0); + assert.equal(hasToolCall || hasToolDelta, true); + assert.equal(leakedText.includes('{'), false); + assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); + assert.equal(leakedText.includes('后置正文C。'), true); +}); + +test('sieve keeps embedded invalid tool-like json as normal text to avoid stream stalls', () => { + const events = runSieve( + [ + '前置正文D。', + "{'tool_calls':[{'name':'read_file','input':{'path':'README.MD'}}]}", + '后置正文E。', + ], + ['read_file'], + ); + const leakedText = collectText(events); + const hasToolCall = events.some((evt) => evt.type === 'tool_calls'); + assert.equal(hasToolCall, false); + assert.equal(leakedText.includes('前置正文D。'), true); + assert.equal(leakedText.includes('后置正文E。'), true); + assert.equal(leakedText.toLowerCase().includes('tool_calls'), true); +}); + +test('sieve flushes incomplete captured tool json as text on stream finalize', () => { + const events = runSieve( + ['前置正文F。', '{"tool_calls":[{"name":"read_file"'], + ['read_file'], + ); + const leakedText = collectText(events); + assert.equal(leakedText.includes('前置正文F。'), true); + assert.equal(leakedText.toLowerCase().includes('tool_calls'), true); + assert.equal(leakedText.includes('{'), true); +}); + +test('sieve keeps plain text intact in tool mode when no tool call appears', () => { + const events = runSieve( + ['你好,', '这是普通文本回复。', '请继续。'], + ['read_file'], + ); + const leakedText = collectText(events); + const hasToolCall = events.some((evt) => evt.type === 'tool_calls'); + assert.equal(hasToolCall, false); + assert.equal(leakedText, '你好,这是普通文本回复。请继续。'); 
+}); + +test('sieve intercepts rejected unknown tool payload (no args) without raw leak', () => { + const events = runSieve( + ['{"tool_calls":[{"name":"not_in_schema"}]}', '后置正文G。'], + ['read_file'], + ); + const leakedText = collectText(events); + const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0); + const hasToolDelta = events.some((evt) => evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0); + assert.equal(hasToolCall, false); + assert.equal(hasToolDelta, false); + assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); + assert.equal(leakedText.includes('后置正文G。'), true); +}); + +test('sieve emits incremental tool_call_deltas for split arguments payload', () => { + const state = createToolSieveState(); + const first = processToolSieveChunk( + state, + '{"tool_calls":[{"name":"read_file","input":{"path":"READ', + ['read_file'], + ); + const second = processToolSieveChunk( + state, + 'ME.MD","mode":"head"}}]}', + ['read_file'], + ); + const tail = flushToolSieve(state, ['read_file']); + const events = [...first, ...second, ...tail]; + const deltaEvents = events.filter((evt) => evt.type === 'tool_call_deltas'); + assert.equal(deltaEvents.length > 0, true); + const merged = deltaEvents.flatMap((evt) => evt.deltas || []); + const hasName = merged.some((d) => d.name === 'read_file'); + const argsJoined = merged + .map((d) => d.arguments || '') + .join(''); + assert.equal(hasName, true); + assert.equal(argsJoined.includes('"path":"README.MD"'), true); + assert.equal(argsJoined.includes('"mode":"head"'), true); +}); + +test('sieve still intercepts tool call after leading plain text without suffix', () => { + const events = runSieve( + ['我将调用工具。', '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}'], + ['read_file'], + ); + const hasTool = events.some((evt) => (evt.type === 'tool_calls' && evt.calls?.length > 0) || (evt.type === 
'tool_call_deltas' && evt.deltas?.length > 0)); + const leakedText = collectText(events); + assert.equal(hasTool, true); + assert.equal(leakedText.includes('我将调用工具。'), true); + assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); +}); + +test('sieve intercepts tool call and preserves trailing same-chunk text', () => { + const events = runSieve( + ['{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}然后继续解释。'], + ['read_file'], + ); + const hasTool = events.some((evt) => (evt.type === 'tool_calls' && evt.calls?.length > 0) || (evt.type === 'tool_call_deltas' && evt.deltas?.length > 0)); + const leakedText = collectText(events); + assert.equal(hasTool, true); + assert.equal(leakedText.includes('然后继续解释。'), true); + assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); +}); diff --git a/tests/scripts/check-node-split-syntax.sh b/tests/scripts/check-node-split-syntax.sh new file mode 100755 index 0000000..e06cb47 --- /dev/null +++ b/tests/scripts/check-node-split-syntax.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)" +TARGETS_FILE="${1:-$ROOT_DIR/plans/node-syntax-gate-targets.txt}" + +if [[ ! -f "$TARGETS_FILE" ]]; then + echo "missing targets file: $TARGETS_FILE" >&2 + exit 1 +fi + +checked=0 +missing=0 +invalid=0 + +while IFS= read -r file; do + [[ -z "$file" ]] && continue + [[ "${file:0:1}" == "#" ]] && continue + + checked=$((checked + 1)) + abs="$ROOT_DIR/$file" + if [[ ! -f "$abs" ]]; then + echo "MISSING $file" + missing=$((missing + 1)) + continue + fi + + if ! 
node --check "$abs"; then + echo "INVALID $file" + invalid=$((invalid + 1)) + fi +done < "$TARGETS_FILE" + +echo "checked=$checked missing=$missing invalid=$invalid" + +if (( missing > 0 || invalid > 0 )); then + exit 1 +fi diff --git a/tests/scripts/check-refactor-line-gate.sh b/tests/scripts/check-refactor-line-gate.sh new file mode 100755 index 0000000..4118d15 --- /dev/null +++ b/tests/scripts/check-refactor-line-gate.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)" +TARGETS_FILE="$ROOT_DIR/plans/refactor-line-gate-targets.txt" + +DEFAULT_MAX=300 +ENTRY_MAX=120 + +is_entry_file() { + case "$1" in + api/chat-stream.js|\ + internal/js/helpers/stream-tool-sieve.js|\ + webui/src/App.jsx|\ + webui/src/components/AccountManager.jsx|\ + webui/src/components/ApiTester.jsx|\ + webui/src/components/Settings.jsx|\ + webui/src/components/VercelSync.jsx) + return 0 + ;; + esac + return 1 +} + +if [[ ! -f "$TARGETS_FILE" ]]; then + echo "missing targets file: $TARGETS_FILE" >&2 + exit 1 +fi + +missing=0 +over=0 +checked=0 + +while IFS= read -r file; do + [[ -z "$file" ]] && continue + [[ "${file:0:1}" == "#" ]] && continue + + checked=$((checked + 1)) + abs="$ROOT_DIR/$file" + if [[ ! 
-f "$abs" ]]; then + echo "MISSING $file" + missing=$((missing + 1)) + continue + fi + + lines="$(wc -l < "$abs" | tr -d ' ')" + limit="$DEFAULT_MAX" + if is_entry_file "$file"; then + limit="$ENTRY_MAX" + fi + + if (( lines > limit )); then + echo "OVER $file lines=$lines limit=$limit" + over=$((over + 1)) + fi +done < "$TARGETS_FILE" + +echo "checked=$checked missing=$missing over_limit=$over" + +if (( missing > 0 || over > 0 )); then + exit 1 +fi diff --git a/tests/scripts/check-stage6-manual-smoke.sh b/tests/scripts/check-stage6-manual-smoke.sh new file mode 100755 index 0000000..5e29aba --- /dev/null +++ b/tests/scripts/check-stage6-manual-smoke.sh @@ -0,0 +1,52 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)" +SMOKE_FILE="${1:-$ROOT_DIR/plans/stage6-manual-smoke.md}" + +if [[ ! -f "$SMOKE_FILE" ]]; then + echo "missing smoke file: $SMOKE_FILE" >&2 + exit 1 +fi + +extract_field() { + local field="$1" + local line + line="$(grep -E "^[[:space:]]*-[[:space:]]*$field:" "$SMOKE_FILE" | head -n 1 || true)" + if [[ -z "$line" ]]; then + echo "" + return + fi + printf '%s' "$line" | sed -E "s/^[[:space:]]*-[[:space:]]*$field:[[:space:]]*//" | sed -E 's/`//g;s/^[[:space:]]+//;s/[[:space:]]+$//' +} + +date_value="$(extract_field "Date")" +tester_value="$(extract_field "Tester")" +env_value="$(extract_field "Environment")" +status_value="$(extract_field "Status")" +status_upper="$(printf '%s' "$status_value" | tr '[:lower:]' '[:upper:]')" + +failed=0 + +if [[ -z "$date_value" ]]; then + echo "invalid smoke file: Date is empty" + failed=1 +fi +if [[ -z "$tester_value" ]]; then + echo "invalid smoke file: Tester is empty" + failed=1 +fi +if [[ -z "$env_value" ]]; then + echo "invalid smoke file: Environment is empty" + failed=1 +fi +if [[ "$status_upper" != "PASS" ]]; then + echo "invalid smoke file: Status must be PASS (got: ${status_value:-})" + failed=1 +fi + +if (( failed != 0 )); then + exit 1 +fi + +echo 
"stage6_manual_smoke=PASS file=$SMOKE_FILE" diff --git a/scripts/testsuite/run-live.sh b/tests/scripts/run-live.sh similarity index 100% rename from scripts/testsuite/run-live.sh rename to tests/scripts/run-live.sh diff --git a/tests/scripts/run-unit-all.sh b/tests/scripts/run-unit-all.sh new file mode 100755 index 0000000..59b202c --- /dev/null +++ b/tests/scripts/run-unit-all.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)" +cd "$ROOT_DIR" + +./tests/scripts/run-unit-go.sh +./tests/scripts/run-unit-node.sh diff --git a/tests/scripts/run-unit-go.sh b/tests/scripts/run-unit-go.sh new file mode 100755 index 0000000..38a11b8 --- /dev/null +++ b/tests/scripts/run-unit-go.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)" +cd "$ROOT_DIR" + +go test ./... "$@" diff --git a/tests/scripts/run-unit-node.sh b/tests/scripts/run-unit-node.sh new file mode 100755 index 0000000..69ddf4e --- /dev/null +++ b/tests/scripts/run-unit-node.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)" +cd "$ROOT_DIR" + +./tests/scripts/check-node-split-syntax.sh + +# Keep Node's file-level test scheduling serial to avoid intermittent cross-file +# interference when multiple suites import mutable module singletons. +NODE_TEST_LOG="$(mktemp)" +cleanup() { + rm -f "$NODE_TEST_LOG" +} +trap cleanup EXIT + +if ! node --test --test-concurrency=1 tests/node/stream-tool-sieve.test.js tests/node/chat-stream.test.js tests/node/js_compat_test.js "$@" 2>&1 | tee "$NODE_TEST_LOG"; then + echo + echo "[run-unit-node] Node tests failed. 
失败摘要如下:" + rg -n "^(not ok|# fail)|ERR_TEST_FAILURE" "$NODE_TEST_LOG" || true + exit 1 +fi diff --git a/vercel.json b/vercel.json index 2e68a94..bad49e0 100644 --- a/vercel.json +++ b/vercel.json @@ -38,6 +38,18 @@ "source": "/admin/config", "destination": "/api/index" }, + { + "source": "/admin/config/(.*)", + "destination": "/api/index" + }, + { + "source": "/admin/settings", + "destination": "/api/index" + }, + { + "source": "/admin/settings/(.*)", + "destination": "/api/index" + }, { "source": "/admin/keys(.*)", "destination": "/api/index" diff --git a/webui/src/App.jsx b/webui/src/App.jsx index 53d0b4a..8067b8b 100644 --- a/webui/src/App.jsx +++ b/webui/src/App.jsx @@ -1,341 +1,3 @@ -import { useState, useEffect } from 'react' -import { - Routes, - Route, - Navigate, - useNavigate, - useLocation -} from 'react-router-dom' -import { - LayoutDashboard, - Key, - Upload, - Cloud, - LogOut, - Menu, - X, - Server, - Users -} from 'lucide-react' -import clsx from 'clsx' +import AppRoutes from './app/AppRoutes' -import AccountManager from './components/AccountManager' -import ApiTester from './components/ApiTester' -import BatchImport from './components/BatchImport' -import VercelSync from './components/VercelSync' -import Login from './components/Login' -import LandingPage from './components/LandingPage' -import LanguageToggle from './components/LanguageToggle' -import { useI18n } from './i18n' - -function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message }) { - const { t } = useI18n() - const [activeTab, setActiveTab] = useState('accounts') - const [sidebarOpen, setSidebarOpen] = useState(false) - const [loading, setLoading] = useState(false) - - const navItems = [ - { id: 'accounts', label: t('nav.accounts.label'), icon: Users, description: t('nav.accounts.desc') }, - { id: 'test', label: t('nav.test.label'), icon: Server, description: t('nav.test.desc') }, - { id: 'import', label: t('nav.import.label'), icon: Upload, description: 
t('nav.import.desc') }, - { id: 'vercel', label: t('nav.vercel.label'), icon: Cloud, description: t('nav.vercel.desc') }, - ] - - const authFetch = async (url, options = {}) => { - const headers = { - ...options.headers, - 'Authorization': `Bearer ${token}` - } - const res = await fetch(url, { ...options, headers }) - - if (res.status === 401) { - onLogout() - throw new Error(t('auth.expired')) - } - return res - } - - const renderTab = () => { - switch (activeTab) { - case 'accounts': - return - case 'test': - return - case 'import': - return - case 'vercel': - return - default: - return null - } - } - - return ( -
- {sidebarOpen && ( -
setSidebarOpen(false)} - /> - )} - - - -
-
-
-
- -
- DS2API -
-
- - -
-
- -
-
-
-

- {navItems.find(n => n.id === activeTab)?.label} -

-

- {navItems.find(n => n.id === activeTab)?.description} -

-
- - {message && ( -
- {message.type === 'error' ? :
} - {message.text} -
- )} - -
- {renderTab()} -
-
-
-
-
- ) -} - -export default function App() { - const { t } = useI18n() - const navigate = useNavigate() - const location = useLocation() - const [config, setConfig] = useState({ keys: [], accounts: [] }) - const [loading, setLoading] = useState(true) - const [message, setMessage] = useState(null) - const [token, setToken] = useState(null) - const [authChecking, setAuthChecking] = useState(true) - - const isProduction = import.meta.env.MODE === 'production' - const isAdminRoute = location.pathname.startsWith('/admin') || isProduction - - useEffect(() => { - // Only check auth status on admin routes. - if (!isAdminRoute) { - setAuthChecking(false) - return - } - - const checkAuth = async () => { - const storedToken = localStorage.getItem('ds2api_token') || sessionStorage.getItem('ds2api_token') - const expiresAt = parseInt(localStorage.getItem('ds2api_token_expires') || sessionStorage.getItem('ds2api_token_expires') || '0') - - if (storedToken && expiresAt > Date.now()) { - try { - const res = await fetch('/admin/verify', { - headers: { 'Authorization': `Bearer ${storedToken}` } - }) - if (res.ok) { - setToken(storedToken) - } else { - handleLogout() - } - } catch { - setToken(storedToken) - } - } - setAuthChecking(false) - } - checkAuth() - }, [isAdminRoute]) - - const fetchConfig = async () => { - if (!token) return - try { - setLoading(true) - const res = await fetch('/admin/config', { - headers: { 'Authorization': `Bearer ${token}` } - }) - if (res.ok) { - const data = await res.json() - setConfig(data) - } - } catch (e) { - console.error('Failed to fetch config:', e) - showMessage('error', t('errors.fetchConfig', { error: e.message })) - } finally { - setLoading(false) - } - } - - useEffect(() => { - if (token) { - fetchConfig() - } - }, [token]) - - const showMessage = (type, text) => { - setMessage({ type, text }) - setTimeout(() => setMessage(null), 5000) - } - - const handleLogin = (newToken) => { - setToken(newToken) - } - - const handleLogout = () => { - 
setToken(null) - localStorage.removeItem('ds2api_token') - localStorage.removeItem('ds2api_token_expires') - sessionStorage.removeItem('ds2api_token') - sessionStorage.removeItem('ds2api_token_expires') - } - - // Wait for auth checks on admin routes. - if (isAdminRoute && authChecking) { - return ( -
-
-
-

{t('auth.checking')}

-
-
- ) - } - - return ( - - {!isProduction && ( - navigate('/admin')} />} /> - )} - - ) : ( -
-
-
-
-
- - {message && ( -
- {message.text} -
- )} - -
- ) - } /> - } /> -
- ) -} +export default AppRoutes diff --git a/webui/src/app/AppRoutes.jsx b/webui/src/app/AppRoutes.jsx new file mode 100644 index 0000000..9a75dba --- /dev/null +++ b/webui/src/app/AppRoutes.jsx @@ -0,0 +1,84 @@ +import { Navigate, Route, Routes, useLocation, useNavigate } from 'react-router-dom' +import clsx from 'clsx' + +import LandingPage from '../components/LandingPage' +import Login from '../components/Login' +import DashboardShell from '../layout/DashboardShell' +import { useI18n } from '../i18n' +import { useAdminAuth } from './useAdminAuth' +import { useAdminConfig } from './useAdminConfig' + +export default function AppRoutes() { + const { t } = useI18n() + const navigate = useNavigate() + const location = useLocation() + + const isProduction = import.meta.env.MODE === 'production' + const { + token, + authChecking, + message, + isAdminRoute, + isVercel, + showMessage, + handleLogin, + handleLogout, + } = useAdminAuth({ isProduction, location, t }) + + const { + config, + fetchConfig, + } = useAdminConfig({ token, showMessage, t }) + + if (isAdminRoute && authChecking) { + return ( +
+
+
+

{t('auth.checking')}

+
+
+ ) + } + + return ( + + {!isProduction && ( + navigate('/admin')} />} /> + )} + + ) : ( +
+
+
+
+
+ + {message && ( +
+ {message.text} +
+ )} + +
+ ) + } /> + } /> +
+ ) +} diff --git a/webui/src/app/useAdminAuth.js b/webui/src/app/useAdminAuth.js new file mode 100644 index 0000000..2da2391 --- /dev/null +++ b/webui/src/app/useAdminAuth.js @@ -0,0 +1,70 @@ +import { useCallback, useEffect, useMemo, useState } from 'react' +import { detectRuntimeEnv } from '../utils/runtimeEnv' + +export function useAdminAuth({ isProduction, location, t }) { + const [message, setMessage] = useState(null) + const [token, setToken] = useState(null) + const [authChecking, setAuthChecking] = useState(true) + + const isAdminRoute = location.pathname.startsWith('/admin') || isProduction + const runtimeEnv = useMemo(() => detectRuntimeEnv(), []) + const isVercel = runtimeEnv.isVercel + + const showMessage = useCallback((type, text) => { + setMessage({ type, text }) + setTimeout(() => setMessage(null), 5000) + }, []) + + const handleLogout = useCallback(() => { + setToken(null) + localStorage.removeItem('ds2api_token') + localStorage.removeItem('ds2api_token_expires') + sessionStorage.removeItem('ds2api_token') + sessionStorage.removeItem('ds2api_token_expires') + }, []) + + const handleLogin = useCallback((newToken) => { + setToken(newToken) + }, []) + + useEffect(() => { + if (!isAdminRoute) { + setAuthChecking(false) + return + } + + const checkAuth = async () => { + const storedToken = localStorage.getItem('ds2api_token') || sessionStorage.getItem('ds2api_token') + const expiresAt = parseInt(localStorage.getItem('ds2api_token_expires') || sessionStorage.getItem('ds2api_token_expires') || '0') + + if (storedToken && expiresAt > Date.now()) { + try { + const res = await fetch('/admin/verify', { + headers: { 'Authorization': `Bearer ${storedToken}` } + }) + if (res.ok) { + setToken(storedToken) + } else { + handleLogout() + } + } catch { + setToken(storedToken) + } + } + setAuthChecking(false) + } + + checkAuth() + }, [handleLogout, isAdminRoute, t]) + + return { + token, + authChecking, + message, + isAdminRoute, + isVercel, + showMessage, + 
handleLogin, + handleLogout, + } +} diff --git a/webui/src/app/useAdminConfig.js b/webui/src/app/useAdminConfig.js new file mode 100644 index 0000000..3fa410d --- /dev/null +++ b/webui/src/app/useAdminConfig.js @@ -0,0 +1,32 @@ +import { useCallback, useEffect, useState } from 'react' + +export function useAdminConfig({ token, showMessage, t }) { + const [config, setConfig] = useState({ keys: [], accounts: [] }) + + const fetchConfig = useCallback(async () => { + if (!token) return + try { + const res = await fetch('/admin/config', { + headers: { 'Authorization': `Bearer ${token}` } + }) + if (res.ok) { + const data = await res.json() + setConfig(data) + } + } catch (e) { + console.error('Failed to fetch config:', e) + showMessage('error', t('errors.fetchConfig', { error: e.message })) + } + }, [showMessage, t, token]) + + useEffect(() => { + if (token) { + fetchConfig() + } + }, [fetchConfig, token]) + + return { + config, + fetchConfig, + } +} diff --git a/webui/src/components/AccountManager.jsx b/webui/src/components/AccountManager.jsx index 773b84e..2a37010 100644 --- a/webui/src/components/AccountManager.jsx +++ b/webui/src/components/AccountManager.jsx @@ -1,559 +1,3 @@ -import { useState, useEffect } from 'react' -import { - Plus, - Trash2, - CheckCircle2, - Play, - X, - Server, - ShieldCheck, - Copy, - Check, - ChevronLeft, - ChevronRight, - ChevronDown -} from 'lucide-react' -import clsx from 'clsx' -import { useI18n } from '../i18n' +import AccountManagerContainer from '../features/account/AccountManagerContainer' -export default function AccountManager({ config, onRefresh, onMessage, authFetch }) { - const { t } = useI18n() - const [showAddKey, setShowAddKey] = useState(false) - const [showAddAccount, setShowAddAccount] = useState(false) - const [newKey, setNewKey] = useState('') - const [copiedKey, setCopiedKey] = useState(null) - const [newAccount, setNewAccount] = useState({ email: '', mobile: '', password: '' }) - const [loading, setLoading] = 
useState(false) - const [testing, setTesting] = useState({}) - const [testingAll, setTestingAll] = useState(false) - const [batchProgress, setBatchProgress] = useState({ current: 0, total: 0, results: [] }) - const [queueStatus, setQueueStatus] = useState(null) - const [keysExpanded, setKeysExpanded] = useState(false) - - // 分页状态 - const [accounts, setAccounts] = useState([]) - const [page, setPage] = useState(1) - const [pageSize] = useState(10) - const [totalPages, setTotalPages] = useState(1) - const [totalAccounts, setTotalAccounts] = useState(0) - const [loadingAccounts, setLoadingAccounts] = useState(false) - - const apiFetch = authFetch || fetch - - const fetchAccounts = async (targetPage = page) => { - setLoadingAccounts(true) - try { - const res = await apiFetch(`/admin/accounts?page=${targetPage}&page_size=${pageSize}`) - if (res.ok) { - const data = await res.json() - setAccounts(data.items || []) - setTotalPages(data.total_pages || 1) - setTotalAccounts(data.total || 0) - setPage(data.page || 1) - } - } catch (e) { - console.error('Failed to fetch accounts:', e) - } finally { - setLoadingAccounts(false) - } - } - - const fetchQueueStatus = async () => { - try { - const res = await apiFetch('/admin/queue/status') - if (res.ok) { - const data = await res.json() - setQueueStatus(data) - } - } catch (e) { - console.error('Failed to fetch queue status:', e) - } - } - - useEffect(() => { - fetchAccounts() - fetchQueueStatus() - const interval = setInterval(fetchQueueStatus, 5000) - return () => clearInterval(interval) - }, []) - - const addKey = async () => { - if (!newKey.trim()) return - setLoading(true) - try { - const res = await apiFetch('/admin/keys', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ key: newKey.trim() }), - }) - if (res.ok) { - onMessage('success', t('accountManager.addKeySuccess')) - setNewKey('') - setShowAddKey(false) - onRefresh() - } else { - const data = await res.json() - 
onMessage('error', data.detail || t('messages.failedToAdd')) - } - } catch (e) { - onMessage('error', t('messages.networkError')) - } finally { - setLoading(false) - } - } - - const deleteKey = async (key) => { - if (!confirm(t('accountManager.deleteKeyConfirm'))) return - try { - const res = await apiFetch(`/admin/keys/${encodeURIComponent(key)}`, { method: 'DELETE' }) - if (res.ok) { - onMessage('success', t('messages.deleted')) - onRefresh() - } else { - onMessage('error', t('messages.deleteFailed')) - } - } catch (e) { - onMessage('error', t('messages.networkError')) - } - } - - const addAccount = async () => { - if (!newAccount.password || (!newAccount.email && !newAccount.mobile)) { - onMessage('error', t('accountManager.requiredFields')) - return - } - setLoading(true) - try { - const res = await apiFetch('/admin/accounts', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify(newAccount), - }) - if (res.ok) { - onMessage('success', t('accountManager.addAccountSuccess')) - setNewAccount({ email: '', mobile: '', password: '' }) - setShowAddAccount(false) - fetchAccounts(1) // 添加后回到第一页 - onRefresh() - } else { - const data = await res.json() - onMessage('error', data.detail || t('messages.failedToAdd')) - } - } catch (e) { - onMessage('error', t('messages.networkError')) - } finally { - setLoading(false) - } - } - - const deleteAccount = async (id) => { - if (!confirm(t('accountManager.deleteAccountConfirm'))) return - try { - const res = await apiFetch(`/admin/accounts/${encodeURIComponent(id)}`, { method: 'DELETE' }) - if (res.ok) { - onMessage('success', t('messages.deleted')) - fetchAccounts() // 刷新当前页 - onRefresh() - } else { - onMessage('error', t('messages.deleteFailed')) - } - } catch (e) { - onMessage('error', t('messages.networkError')) - } - } - - const testAccount = async (identifier) => { - setTesting(prev => ({ ...prev, [identifier]: true })) - try { - const res = await apiFetch('/admin/accounts/test', { - 
method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ identifier }), - }) - const data = await res.json() - const statusMessage = data.success - ? t('apiTester.testSuccess', { account: identifier, time: data.response_time }) - : `${identifier}: ${data.message}` - onMessage(data.success ? 'success' : 'error', statusMessage) - fetchAccounts() // 刷新当前页 - onRefresh() - } catch (e) { - onMessage('error', t('accountManager.testFailed', { error: e.message })) - } finally { - setTesting(prev => ({ ...prev, [identifier]: false })) - } - } - - const testAllAccounts = async () => { - if (!confirm(t('accountManager.testAllConfirm'))) return - const allAccounts = config.accounts || [] - if (allAccounts.length === 0) return - - setTestingAll(true) - setBatchProgress({ current: 0, total: allAccounts.length, results: [] }) - - let successCount = 0 - const results = [] - - for (let i = 0; i < allAccounts.length; i++) { - const acc = allAccounts[i] - const id = acc.email || acc.mobile - - try { - const res = await apiFetch('/admin/accounts/test', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ identifier: id }), - }) - const data = await res.json() - results.push({ id, success: data.success, message: data.message, time: data.response_time }) - if (data.success) successCount++ - } catch (e) { - results.push({ id, success: false, message: e.message }) - } - - setBatchProgress({ current: i + 1, total: allAccounts.length, results: [...results] }) - } - - onMessage('success', t('accountManager.testAllCompleted', { success: successCount, total: allAccounts.length })) - fetchAccounts() // 刷新当前页 - onRefresh() - setTestingAll(false) - } - - return ( -
- {/* Queue Status - Flat & Clean */} - { - queueStatus && ( -
-
-
- -
-

{t('accountManager.available')}

-
- {queueStatus.available} - {t('accountManager.accountsUnit')} -
-
-
-
- -
-

{t('accountManager.inUse')}

-
- {queueStatus.in_use} - {t('accountManager.threadsUnit')} -
-
-
-
- -
-

{t('accountManager.totalPool')}

-
- {queueStatus.total} - {t('accountManager.accountsUnit')} -
-
-
- ) - } - - {/* API Keys Section */} -
-
setKeysExpanded(!keysExpanded)} - > -
- -
-

{t('accountManager.apiKeysTitle')}

-

{t('accountManager.apiKeysDesc')} ({config.keys?.length || 0})

-
-
- -
- - {keysExpanded && ( -
- {config.keys?.length > 0 ? ( - config.keys.map((key, i) => ( -
-
-
- {key.slice(0, 16)}**** -
- {copiedKey === key && ( - {t('accountManager.copied')} - )} -
-
- - -
-
- )) - ) : ( -
{t('accountManager.noApiKeys')}
- )} -
- )} -
- - {/* Accounts Section */} -
-
-
-

{t('accountManager.accountsTitle')}

-

{t('accountManager.accountsDesc')}

-
-
- - -
-
- - {/* Batch Progress */} - {testingAll && batchProgress.total > 0 && ( -
-
- {t('accountManager.testingAllAccounts')} - {batchProgress.current} / {batchProgress.total} -
-
-
-
- {batchProgress.results.length > 0 && ( -
- {batchProgress.results.map((r, i) => ( -
- {r.success ? '✓' : '✗'} {r.id} -
- ))} -
- )} -
- )} - -
- {loadingAccounts ? ( -
{t('actions.loading')}
- ) : accounts.length > 0 ? ( - accounts.map((acc, i) => { - const id = acc.email || acc.mobile - return ( -
-
-
-
-
{id}
-
- {acc.has_token ? t('accountManager.sessionActive') : t('accountManager.reauthRequired')} - {acc.token_preview && ( - - {acc.token_preview} - - )} -
-
-
-
- - -
-
- ) - }) - ) : ( -
{t('accountManager.noAccounts')}
- )} -
- - {/* 分页控件 */} - {totalPages > 1 && ( -
-
- {t('accountManager.pageInfo', { current: page, total: totalPages, count: totalAccounts })} -
-
- - {page} / {totalPages} - -
-
- )} -
- - {/* Modals */} - { - showAddKey && ( -
-
-
-

{t('accountManager.modalAddKeyTitle')}

- -
-
-
- -
- setNewKey(e.target.value)} - autoFocus - /> - -
-

{t('accountManager.generateHint')}

-
-
- - -
-
-
-
- ) - } - - { - showAddAccount && ( -
-
-
-

{t('accountManager.modalAddAccountTitle')}

- -
-
-
- - setNewAccount({ ...newAccount, email: e.target.value })} - /> -
-
- - setNewAccount({ ...newAccount, mobile: e.target.value })} - /> -
-
- - setNewAccount({ ...newAccount, password: e.target.value })} - /> -
-
- - -
-
-
-
- ) - } -
- ) -} +export default AccountManagerContainer diff --git a/webui/src/components/ApiTester.jsx b/webui/src/components/ApiTester.jsx index 7d49982..b688195 100644 --- a/webui/src/components/ApiTester.jsx +++ b/webui/src/components/ApiTester.jsx @@ -1,439 +1,3 @@ -import { useEffect, useRef, useState } from 'react' -import { - Send, - Square, - MessageSquare, - Cpu, - Search as SearchIcon, - Sparkles, - Bot, - User, - Loader2, - CheckCircle2, - AlertCircle, - ChevronDown, - ShieldCheck, - Terminal, - Zap, - ToggleLeft, - ToggleRight -} from 'lucide-react' -import clsx from 'clsx' -import { useI18n } from '../i18n' +import ApiTesterContainer from '../features/apiTester/ApiTesterContainer' -export default function ApiTester({ config, onMessage, authFetch }) { - const { t } = useI18n() - const [model, setModel] = useState('deepseek-chat') - const defaultMessage = t('apiTester.defaultMessage') - const [message, setMessage] = useState(defaultMessage) - const [apiKey, setApiKey] = useState('') - const [selectedAccount, setSelectedAccount] = useState('') - const [response, setResponse] = useState(null) - const [loading, setLoading] = useState(false) - const [streamingContent, setStreamingContent] = useState('') - const [streamingThinking, setStreamingThinking] = useState('') - const [isStreaming, setIsStreaming] = useState(false) - const [streamingMode, setStreamingMode] = useState(true) - const abortControllerRef = useRef(null) - const defaultMessageRef = useRef(defaultMessage) - - const [sidebarOpen, setSidebarOpen] = useState(false) - const [configExpanded, setConfigExpanded] = useState(false) - - const apiFetch = authFetch || fetch - const accounts = config.accounts || [] - const configuredKeys = config.keys || [] - const trimmedApiKey = apiKey.trim() - const defaultKey = configuredKeys[0] || '' - const effectiveKey = trimmedApiKey || defaultKey - const customKeyActive = trimmedApiKey !== '' - const customKeyManaged = customKeyActive && 
configuredKeys.includes(trimmedApiKey) - const models = [ - { id: "deepseek-chat", name: "deepseek-chat", icon: MessageSquare, desc: t('apiTester.models.chat'), color: "text-amber-500" }, - { id: "deepseek-reasoner", name: "deepseek-reasoner", icon: Cpu, desc: t('apiTester.models.reasoner'), color: "text-amber-600" }, - { id: "deepseek-chat-search", name: "deepseek-chat-search", icon: SearchIcon, desc: t('apiTester.models.chatSearch'), color: "text-cyan-500" }, - { id: "deepseek-reasoner-search", name: "deepseek-reasoner-search", icon: SearchIcon, desc: t('apiTester.models.reasonerSearch'), color: "text-cyan-600" }, - ] - - const stopGeneration = () => { - if (abortControllerRef.current) { - abortControllerRef.current.abort() - abortControllerRef.current = null - } - setLoading(false) - setIsStreaming(false) - } - - const extractErrorMessage = async (res) => { - let raw = '' - try { - raw = await res.text() - } catch { - return t('apiTester.requestFailed') - } - if (!raw) { - return t('apiTester.requestFailed') - } - try { - const data = JSON.parse(raw) - const fromErrorObject = data?.error?.message - const fromErrorString = typeof data?.error === 'string' ? data.error : '' - const detail = typeof data?.detail === 'string' ? data.detail : '' - const message = typeof data?.message === 'string' ? data.message : '' - return fromErrorObject || fromErrorString || detail || message || t('apiTester.requestFailed') - } catch { - return raw.length > 240 ? 
`${raw.slice(0, 240)}...` : raw - } - } - - const runTest = async () => { - if (loading) return - - const startedAt = Date.now() - - setLoading(true) - setIsStreaming(true) - setResponse(null) - setStreamingContent('') - setStreamingThinking('') - - abortControllerRef.current = new AbortController() - - try { - if (!effectiveKey) { - onMessage('error', t('apiTester.missingApiKey')) - setLoading(false) - setIsStreaming(false) - return - } - - const headers = { - 'Content-Type': 'application/json', - 'Authorization': `Bearer ${effectiveKey}`, - } - if (selectedAccount) { - headers['X-Ds2-Target-Account'] = selectedAccount - } - - const endpoint = streamingMode ? '/v1/chat/completions' : '/v1/chat/completions?__go=1' - const res = await fetch(endpoint, { - method: 'POST', - headers, - body: JSON.stringify({ - model, - messages: [{ role: 'user', content: message }], - stream: streamingMode, - }), - signal: abortControllerRef.current.signal, - }) - - if (!res.ok) { - const errorMsg = await extractErrorMessage(res) - setResponse({ success: false, error: errorMsg }) - onMessage('error', errorMsg) - setLoading(false) - setIsStreaming(false) - return - } - - if (streamingMode) { - setResponse({ success: true, status_code: res.status }) - - const reader = res.body.getReader() - const decoder = new TextDecoder() - let buffer = '' - - while (true) { - const { done, value } = await reader.read() - if (done) break - - buffer += decoder.decode(value, { stream: true }) - const lines = buffer.split('\n') - buffer = lines.pop() || '' - - for (const line of lines) { - const trimmed = line.trim() - if (!trimmed || !trimmed.startsWith('data: ')) continue - - const dataStr = trimmed.slice(6) - if (dataStr === '[DONE]') continue - - try { - const json = JSON.parse(dataStr) - const choice = json.choices?.[0] - if (choice?.delta) { - const delta = choice.delta - if (delta.reasoning_content) { - setStreamingThinking(prev => prev + delta.reasoning_content) - } - if (delta.content) { - 
setStreamingContent(prev => prev + delta.content) - } - } - } catch (e) { - console.error('Invalid JSON hunk:', dataStr, e) - } - } - } - } else { - const data = await res.json() - setResponse({ success: true, status_code: res.status, ...data }) - const elapsed = Math.max(0, Date.now() - startedAt) - onMessage('success', t('apiTester.testSuccess', { account: selectedAccount || 'Auto', time: elapsed })) - } - } catch (e) { - if (e.name === 'AbortError') { - onMessage('info', t('messages.generationStopped')) - } else { - onMessage('error', t('apiTester.networkError', { error: e.message })) - setResponse({ error: e.message, success: false }) - } - } finally { - setLoading(false) - setIsStreaming(false) - abortControllerRef.current = null - } - } - -useEffect(() => { - setMessage((prev) => (prev === defaultMessageRef.current ? defaultMessage : prev)) - defaultMessageRef.current = defaultMessage -}, [defaultMessage]) - -return ( -
- {/* Configuration Panel */} -
-
- {/* Mobile Toggle Header */} - - -
-
- -
- {models.map(m => { - const Icon = m.icon - return ( - - ) - })} -
-
- -
- - -
- -
- -
- - -
-
- -
- - setApiKey(e.target.value)} - /> - {customKeyActive && ( -

- {customKeyManaged ? t('apiTester.modeManaged') : t('apiTester.modeDirect')} -

- )} -
-
-
-
- - {/* Chat Interface */} -
- - {/* Messages Area */} -
- {/* User Message */} -
-
- -
-
-
- {message} -
-
-
- - {/* AI Response */} - {(response || isStreaming) && ( -
-
- -
-
-
- - DeepSeek - - {response && ( - - {response.status_code || t('apiTester.statusError')} - - )} -
- - {(streamingThinking || response?.choices?.[0]?.message?.reasoning_content) && ( -
-
- - {t('apiTester.reasoningTrace')} -
-
- {streamingThinking || response?.choices?.[0]?.message?.reasoning_content} -
-
- )} - -
- {streamingContent || response?.choices?.[0]?.message?.content || (response?.error && {response.error}) || (loading && {t('apiTester.generating')})} - {isStreaming && } -
-
-
- )} -
- - {/* Input Area */} -
-
-