mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-04 00:15:28 +08:00
Compare commits
52 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
224462018a | ||
|
|
35e89230fd | ||
|
|
9a57af6092 | ||
|
|
2e1bd8a481 | ||
|
|
1e678ecc1a | ||
|
|
6b3523a66d | ||
|
|
d4017b87c1 | ||
|
|
d3b60edb6f | ||
|
|
6baf687ecf | ||
|
|
7da012a4d8 | ||
|
|
6c318f1910 | ||
|
|
a9403c5392 | ||
|
|
ae7dce0b32 | ||
|
|
312728c8b6 | ||
|
|
acf39f2823 | ||
|
|
8de87fb9e0 | ||
|
|
6c48429b90 | ||
|
|
cc6af8fd28 | ||
|
|
5d3989a9a7 | ||
|
|
920767f486 | ||
|
|
7a4e994f3a | ||
|
|
13b1ec46ee | ||
|
|
e2cb07f08c | ||
|
|
541816f2ab | ||
|
|
dec9d03fc5 | ||
|
|
2781951ce7 | ||
|
|
1d2a6bf281 | ||
|
|
db49a3ec02 | ||
|
|
c509066943 | ||
|
|
0283846543 | ||
|
|
210d9f5793 | ||
|
|
dd6af0788e | ||
|
|
7307a5cc9a | ||
|
|
3239ef3c3e | ||
|
|
d21aedac83 | ||
|
|
df9aea194c | ||
|
|
2dcc230852 | ||
|
|
51c543631b | ||
|
|
895423852f | ||
|
|
eb253a9d3a | ||
|
|
3a75b75ae0 | ||
|
|
27ecb4b69b | ||
|
|
0348fa8a22 | ||
|
|
7fc10573ab | ||
|
|
ce74b124d2 | ||
|
|
f2b10992cc | ||
|
|
deec72416e | ||
|
|
7beeea5779 | ||
|
|
19289c9008 | ||
|
|
89e93a1674 | ||
|
|
f62fa22338 | ||
|
|
2acf58590a |
@@ -52,6 +52,9 @@ DS2API_ADMIN_KEY=admin
|
|||||||
|
|
||||||
# Option C: Base64 encoded JSON (recommended for Vercel env var)
|
# Option C: Base64 encoded JSON (recommended for Vercel env var)
|
||||||
# DS2API_CONFIG_JSON=eyJrZXlzIjpbInlvdXItYXBpLWtleSJdLCJhY2NvdW50cyI6W3siZW1haWwiOiJ1c2VyQGV4YW1wbGUuY29tIiwicGFzc3dvcmQiOiJ4eHgiLCJ0b2tlbiI6IiJ9XX0=
|
# DS2API_CONFIG_JSON=eyJrZXlzIjpbInlvdXItYXBpLWtleSJdLCJhY2NvdW50cyI6W3siZW1haWwiOiJ1c2VyQGV4YW1wbGUuY29tIiwicGFzc3dvcmQiOiJ4eHgiLCJ0b2tlbiI6IiJ9XX0=
|
||||||
|
#
|
||||||
|
# Generate from local config.json:
|
||||||
|
# DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||||
|
|
||||||
# ---------------------------------------------------------------
|
# ---------------------------------------------------------------
|
||||||
# Paths (optional)
|
# Paths (optional)
|
||||||
|
|||||||
6
.github/PULL_REQUEST_TEMPLATE.md
vendored
6
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -13,12 +13,8 @@
|
|||||||
|
|
||||||
#### 🔀 变更说明 | Description of Change
|
#### 🔀 变更说明 | Description of Change
|
||||||
|
|
||||||
<!-- Thank you for your Pull Request. Please provide a description above. -->
|
|
||||||
|
|
||||||
#### 📝 补充信息 | Additional Information
|
#### 📝 补充信息 | Additional Information
|
||||||
|
|
||||||
<!-- Add any other context about the Pull Request here. -->
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
> 💡 **提示**:如果修改了 `webui/` 目录下的文件,PR 合并后 CI 会自动构建并提交产物,无需手动构建。
|
|
||||||
40
.github/workflows/quality-gates.yml
vendored
Normal file
40
.github/workflows/quality-gates.yml
vendored
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
name: Quality Gates
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- dev
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
quality-gates:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.24.x"
|
||||||
|
|
||||||
|
- name: Setup Node
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: "20"
|
||||||
|
cache: "npm"
|
||||||
|
cache-dependency-path: webui/package-lock.json
|
||||||
|
|
||||||
|
- name: Refactor Line Gate
|
||||||
|
run: ./tests/scripts/check-refactor-line-gate.sh
|
||||||
|
|
||||||
|
- name: Unit Gates (Go + Node)
|
||||||
|
run: ./tests/scripts/run-unit-all.sh
|
||||||
|
|
||||||
|
- name: WebUI Build Gate
|
||||||
|
run: |
|
||||||
|
npm ci --prefix webui
|
||||||
|
npm run build --prefix webui
|
||||||
67
.github/workflows/release-artifacts.yml
vendored
67
.github/workflows/release-artifacts.yml
vendored
@@ -12,6 +12,9 @@ permissions:
|
|||||||
jobs:
|
jobs:
|
||||||
build-and-upload:
|
build-and-upload:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -28,6 +31,12 @@ jobs:
|
|||||||
cache: "npm"
|
cache: "npm"
|
||||||
cache-dependency-path: webui/package-lock.json
|
cache-dependency-path: webui/package-lock.json
|
||||||
|
|
||||||
|
- name: Release Blocking Gates
|
||||||
|
run: |
|
||||||
|
./tests/scripts/check-stage6-manual-smoke.sh
|
||||||
|
./tests/scripts/check-refactor-line-gate.sh
|
||||||
|
./tests/scripts/run-unit-all.sh
|
||||||
|
|
||||||
- name: Build WebUI
|
- name: Build WebUI
|
||||||
run: |
|
run: |
|
||||||
npm ci --prefix webui
|
npm ci --prefix webui
|
||||||
@@ -73,16 +82,6 @@ jobs:
|
|||||||
rm -rf "${STAGE}"
|
rm -rf "${STAGE}"
|
||||||
done
|
done
|
||||||
|
|
||||||
(cd dist && sha256sum *.tar.gz *.zip > sha256sums.txt)
|
|
||||||
|
|
||||||
- name: Upload Release Assets
|
|
||||||
uses: softprops/action-gh-release@v2
|
|
||||||
with:
|
|
||||||
files: |
|
|
||||||
dist/*.tar.gz
|
|
||||||
dist/*.zip
|
|
||||||
dist/sha256sums.txt
|
|
||||||
|
|
||||||
- name: Set up QEMU
|
- name: Set up QEMU
|
||||||
uses: docker/setup-qemu-action@v3
|
uses: docker/setup-qemu-action@v3
|
||||||
|
|
||||||
@@ -96,11 +95,20 @@ jobs:
|
|||||||
username: ${{ github.actor }}
|
username: ${{ github.actor }}
|
||||||
password: ${{ secrets.GITHUB_TOKEN }}
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
- name: Log in to Docker Hub
|
||||||
|
if: "${{ env.DOCKERHUB_USERNAME != '' }}"
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
username: ${{ env.DOCKERHUB_USERNAME }}
|
||||||
|
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||||
|
|
||||||
- name: Extract Docker metadata
|
- name: Extract Docker metadata
|
||||||
id: meta
|
id: meta_release
|
||||||
uses: docker/metadata-action@v5
|
uses: docker/metadata-action@v5
|
||||||
with:
|
with:
|
||||||
images: ghcr.io/${{ github.repository }}
|
images: |
|
||||||
|
ghcr.io/${{ github.repository }}
|
||||||
|
${{ env.DOCKERHUB_USERNAME || 'cjackhwang' }}/ds2api
|
||||||
tags: |
|
tags: |
|
||||||
type=raw,value=${{ github.event.release.tag_name }}
|
type=raw,value=${{ github.event.release.tag_name }}
|
||||||
type=raw,value=latest
|
type=raw,value=latest
|
||||||
@@ -112,5 +120,36 @@ jobs:
|
|||||||
file: ./Dockerfile
|
file: ./Dockerfile
|
||||||
push: true
|
push: true
|
||||||
platforms: linux/amd64,linux/arm64
|
platforms: linux/amd64,linux/arm64
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
tags: ${{ steps.meta_release.outputs.tags }}
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta_release.outputs.labels }}
|
||||||
|
|
||||||
|
- name: Export Docker image archives for release assets
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
TAG="${{ github.event.release.tag_name }}"
|
||||||
|
|
||||||
|
docker buildx build \
|
||||||
|
--platform linux/amd64 \
|
||||||
|
--output type=docker,dest="dist/ds2api_${TAG}_docker_linux_amd64.tar" \
|
||||||
|
.
|
||||||
|
|
||||||
|
docker buildx build \
|
||||||
|
--platform linux/arm64 \
|
||||||
|
--output type=docker,dest="dist/ds2api_${TAG}_docker_linux_arm64.tar" \
|
||||||
|
.
|
||||||
|
|
||||||
|
gzip -f "dist/ds2api_${TAG}_docker_linux_amd64.tar"
|
||||||
|
gzip -f "dist/ds2api_${TAG}_docker_linux_arm64.tar"
|
||||||
|
|
||||||
|
- name: Generate checksums
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
(cd dist && sha256sum *.tar.gz *.zip > sha256sums.txt)
|
||||||
|
|
||||||
|
- name: Upload Release Assets
|
||||||
|
uses: softprops/action-gh-release@v2
|
||||||
|
with:
|
||||||
|
files: |
|
||||||
|
dist/*.tar.gz
|
||||||
|
dist/*.zip
|
||||||
|
dist/sha256sums.txt
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -81,6 +81,9 @@ ds2api-tests
|
|||||||
htmlcov/
|
htmlcov/
|
||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
.tox/
|
.tox/
|
||||||
|
*.coverprofile
|
||||||
|
coverage*.out
|
||||||
|
cover/
|
||||||
|
|
||||||
# Misc
|
# Misc
|
||||||
*.pyc
|
*.pyc
|
||||||
|
|||||||
334
API.en.md
334
API.en.md
@@ -9,11 +9,13 @@ This document describes the actual behavior of the current Go codebase.
|
|||||||
## Table of Contents
|
## Table of Contents
|
||||||
|
|
||||||
- [Basics](#basics)
|
- [Basics](#basics)
|
||||||
|
- [Configuration Best Practice](#configuration-best-practice)
|
||||||
- [Authentication](#authentication)
|
- [Authentication](#authentication)
|
||||||
- [Route Index](#route-index)
|
- [Route Index](#route-index)
|
||||||
- [Health Endpoints](#health-endpoints)
|
- [Health Endpoints](#health-endpoints)
|
||||||
- [OpenAI-Compatible API](#openai-compatible-api)
|
- [OpenAI-Compatible API](#openai-compatible-api)
|
||||||
- [Claude-Compatible API](#claude-compatible-api)
|
- [Claude-Compatible API](#claude-compatible-api)
|
||||||
|
- [Gemini-Compatible API](#gemini-compatible-api)
|
||||||
- [Admin API](#admin-api)
|
- [Admin API](#admin-api)
|
||||||
- [Error Payloads](#error-payloads)
|
- [Error Payloads](#error-payloads)
|
||||||
- [cURL Examples](#curl-examples)
|
- [cURL Examples](#curl-examples)
|
||||||
@@ -27,13 +29,35 @@ This document describes the actual behavior of the current Go codebase.
|
|||||||
| Base URL | `http://localhost:5001` or your deployment domain |
|
| Base URL | `http://localhost:5001` or your deployment domain |
|
||||||
| Default Content-Type | `application/json` |
|
| Default Content-Type | `application/json` |
|
||||||
| Health probes | `GET /healthz`, `GET /readyz` |
|
| Health probes | `GET /healthz`, `GET /readyz` |
|
||||||
| CORS | Enabled (`Access-Control-Allow-Origin: *`, allows `Content-Type`, `Authorization`) |
|
| CORS | Enabled (`Access-Control-Allow-Origin: *`, allows `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Vercel-Protection-Bypass`) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Configuration Best Practice
|
||||||
|
|
||||||
|
Use `config.json` as the single source of truth:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp config.example.json config.json
|
||||||
|
# Edit config.json (keys/accounts)
|
||||||
|
```
|
||||||
|
|
||||||
|
Use it per deployment mode:
|
||||||
|
|
||||||
|
- Local run: read `config.json` directly
|
||||||
|
- Docker / Vercel: generate Base64 from `config.json`, then set `DS2API_CONFIG_JSON`
|
||||||
|
|
||||||
|
```bash
|
||||||
|
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||||
|
```
|
||||||
|
|
||||||
|
For Vercel one-click bootstrap, you can set only `DS2API_ADMIN_KEY` first, then import config at `/admin` and sync env vars from the "Vercel Sync" page.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Authentication
|
## Authentication
|
||||||
|
|
||||||
### Business Endpoints (`/v1/*`, `/anthropic/*`)
|
### Business Endpoints (`/v1/*`, `/anthropic/*`, `/v1beta/models/*`)
|
||||||
|
|
||||||
Two header formats accepted:
|
Two header formats accepted:
|
||||||
|
|
||||||
@@ -66,15 +90,32 @@ Two header formats accepted:
|
|||||||
| GET | `/healthz` | None | Liveness probe |
|
| GET | `/healthz` | None | Liveness probe |
|
||||||
| GET | `/readyz` | None | Readiness probe |
|
| GET | `/readyz` | None | Readiness probe |
|
||||||
| GET | `/v1/models` | None | OpenAI model list |
|
| GET | `/v1/models` | None | OpenAI model list |
|
||||||
|
| GET | `/v1/models/{id}` | None | OpenAI single-model query (alias accepted) |
|
||||||
| POST | `/v1/chat/completions` | Business | OpenAI chat completions |
|
| POST | `/v1/chat/completions` | Business | OpenAI chat completions |
|
||||||
|
| POST | `/v1/responses` | Business | OpenAI Responses API (stream/non-stream) |
|
||||||
|
| GET | `/v1/responses/{response_id}` | Business | Query stored response (in-memory TTL) |
|
||||||
|
| POST | `/v1/embeddings` | Business | OpenAI Embeddings API |
|
||||||
| GET | `/anthropic/v1/models` | None | Claude model list |
|
| GET | `/anthropic/v1/models` | None | Claude model list |
|
||||||
| POST | `/anthropic/v1/messages` | Business | Claude messages |
|
| POST | `/anthropic/v1/messages` | Business | Claude messages |
|
||||||
| POST | `/anthropic/v1/messages/count_tokens` | Business | Claude token counting |
|
| POST | `/anthropic/v1/messages/count_tokens` | Business | Claude token counting |
|
||||||
|
| POST | `/v1/messages` | Business | Claude shortcut path |
|
||||||
|
| POST | `/messages` | Business | Claude shortcut path |
|
||||||
|
| POST | `/v1/messages/count_tokens` | Business | Claude token counting shortcut |
|
||||||
|
| POST | `/messages/count_tokens` | Business | Claude token counting shortcut |
|
||||||
|
| POST | `/v1beta/models/{model}:generateContent` | Business | Gemini non-stream |
|
||||||
|
| POST | `/v1beta/models/{model}:streamGenerateContent` | Business | Gemini stream |
|
||||||
|
| POST | `/v1/models/{model}:generateContent` | Business | Gemini non-stream compat path |
|
||||||
|
| POST | `/v1/models/{model}:streamGenerateContent` | Business | Gemini stream compat path |
|
||||||
| POST | `/admin/login` | None | Admin login |
|
| POST | `/admin/login` | None | Admin login |
|
||||||
| GET | `/admin/verify` | JWT | Verify admin JWT |
|
| GET | `/admin/verify` | JWT | Verify admin JWT |
|
||||||
| GET | `/admin/vercel/config` | Admin | Read preconfigured Vercel creds |
|
| GET | `/admin/vercel/config` | Admin | Read preconfigured Vercel creds |
|
||||||
| GET | `/admin/config` | Admin | Read sanitized config |
|
| GET | `/admin/config` | Admin | Read sanitized config |
|
||||||
| POST | `/admin/config` | Admin | Update config |
|
| POST | `/admin/config` | Admin | Update config |
|
||||||
|
| GET | `/admin/settings` | Admin | Read runtime settings |
|
||||||
|
| PUT | `/admin/settings` | Admin | Update runtime settings (hot reload) |
|
||||||
|
| POST | `/admin/settings/password` | Admin | Update admin password and invalidate old JWTs |
|
||||||
|
| POST | `/admin/config/import` | Admin | Import config (merge/replace) |
|
||||||
|
| GET | `/admin/config/export` | Admin | Export full config (`config`/`json`/`base64`) |
|
||||||
| POST | `/admin/keys` | Admin | Add API key |
|
| POST | `/admin/keys` | Admin | Add API key |
|
||||||
| DELETE | `/admin/keys/{key}` | Admin | Delete API key |
|
| DELETE | `/admin/keys/{key}` | Admin | Delete API key |
|
||||||
| GET | `/admin/accounts` | Admin | Paginated account list |
|
| GET | `/admin/accounts` | Admin | Paginated account list |
|
||||||
@@ -88,6 +129,8 @@ Two header formats accepted:
|
|||||||
| POST | `/admin/vercel/sync` | Admin | Sync config to Vercel |
|
| POST | `/admin/vercel/sync` | Admin | Sync config to Vercel |
|
||||||
| GET | `/admin/vercel/status` | Admin | Vercel sync status |
|
| GET | `/admin/vercel/status` | Admin | Vercel sync status |
|
||||||
| GET | `/admin/export` | Admin | Export config JSON/Base64 |
|
| GET | `/admin/export` | Admin | Export config JSON/Base64 |
|
||||||
|
| GET | `/admin/dev/captures` | Admin | Read local packet-capture entries |
|
||||||
|
| DELETE | `/admin/dev/captures` | Admin | Clear local packet-capture entries |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -127,6 +170,15 @@ No auth required. Returns supported models.
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Model Alias Resolution
|
||||||
|
|
||||||
|
For `chat` / `responses` / `embeddings`, DS2API follows a wide-input/strict-output policy:
|
||||||
|
|
||||||
|
1. Match DeepSeek native model IDs first.
|
||||||
|
2. Then match exact keys in `model_aliases`.
|
||||||
|
3. If still unmatched, fall back by known family heuristics (`o*`, `gpt-*`, `claude-*`, etc.).
|
||||||
|
4. If still unmatched, return `invalid_request_error`.
|
||||||
|
|
||||||
### `POST /v1/chat/completions`
|
### `POST /v1/chat/completions`
|
||||||
|
|
||||||
**Headers**:
|
**Headers**:
|
||||||
@@ -140,7 +192,7 @@ Content-Type: application/json
|
|||||||
|
|
||||||
| Field | Type | Required | Notes |
|
| Field | Type | Required | Notes |
|
||||||
| --- | --- | --- | --- |
|
| --- | --- | --- | --- |
|
||||||
| `model` | string | ✅ | `deepseek-chat` / `deepseek-reasoner` / `deepseek-chat-search` / `deepseek-reasoner-search` |
|
| `model` | string | ✅ | DeepSeek native models + common aliases (`gpt-4o`, `gpt-5-codex`, `o3`, `claude-sonnet-4-5`, etc.) |
|
||||||
| `messages` | array | ✅ | OpenAI-style messages |
|
| `messages` | array | ✅ | OpenAI-style messages |
|
||||||
| `stream` | boolean | ❌ | Default `false` |
|
| `stream` | boolean | ❌ | Default `false` |
|
||||||
| `tools` | array | ❌ | Function calling schema |
|
| `tools` | array | ❌ | Function calling schema |
|
||||||
@@ -230,12 +282,90 @@ When `tools` is present, DS2API performs anti-leak handling:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
**Stream**: DS2API buffers text first. If tool call detected → only structured `delta.tool_calls` (each with `index`); otherwise emits buffered text at once.
|
**Stream**: Once high-confidence toolcall features are matched, DS2API emits `delta.tool_calls` immediately (without waiting for full JSON closure), then keeps sending argument deltas; confirmed raw tool JSON is never forwarded as `delta.content`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### `GET /v1/models/{id}`
|
||||||
|
|
||||||
|
No auth required. Alias values are accepted as path params (for example `gpt-4o`), and the returned object is the mapped DeepSeek model.
|
||||||
|
|
||||||
|
### `POST /v1/responses`
|
||||||
|
|
||||||
|
OpenAI Responses-style endpoint, accepting either `input` or `messages`.
|
||||||
|
|
||||||
|
| Field | Type | Required | Notes |
|
||||||
|
| --- | --- | --- | --- |
|
||||||
|
| `model` | string | ✅ | Supports native models + alias mapping |
|
||||||
|
| `input` | string/array/object | ❌ | One of `input` or `messages` is required |
|
||||||
|
| `messages` | array | ❌ | One of `input` or `messages` is required |
|
||||||
|
| `instructions` | string | ❌ | Prepended as a system message |
|
||||||
|
| `stream` | boolean | ❌ | Default `false` |
|
||||||
|
| `tools` | array | ❌ | Same tool detection/translation policy as chat |
|
||||||
|
| `tool_choice` | string/object | ❌ | Supports `auto`/`none`/`required` and forced function selection (`{"type":"function","name":"..."}`) |
|
||||||
|
|
||||||
|
**Non-stream**: Returns a standard `response` object with an ID like `resp_xxx`, and stores it in in-memory TTL cache.
|
||||||
|
If `tool_choice=required` and no valid tool call is produced, DS2API returns HTTP `422` (`error.code=tool_choice_violation`).
|
||||||
|
|
||||||
|
**Stream (SSE)**: minimal event sequence:
|
||||||
|
|
||||||
|
```text
|
||||||
|
event: response.created
|
||||||
|
data: {"type":"response.created","id":"resp_xxx","status":"in_progress",...}
|
||||||
|
|
||||||
|
event: response.output_item.added
|
||||||
|
data: {"type":"response.output_item.added","response_id":"resp_xxx","item":{"type":"message|function_call",...},...}
|
||||||
|
|
||||||
|
event: response.content_part.added
|
||||||
|
data: {"type":"response.content_part.added","response_id":"resp_xxx","part":{"type":"output_text",...},...}
|
||||||
|
|
||||||
|
event: response.output_text.delta
|
||||||
|
data: {"type":"response.output_text.delta","response_id":"resp_xxx","item_id":"msg_xxx","output_index":0,"content_index":0,"delta":"..."}
|
||||||
|
|
||||||
|
event: response.function_call_arguments.delta
|
||||||
|
data: {"type":"response.function_call_arguments.delta","response_id":"resp_xxx","call_id":"call_xxx","delta":"..."}
|
||||||
|
|
||||||
|
event: response.function_call_arguments.done
|
||||||
|
data: {"type":"response.function_call_arguments.done","response_id":"resp_xxx","call_id":"call_xxx","name":"tool","arguments":"{...}"}
|
||||||
|
|
||||||
|
event: response.content_part.done
|
||||||
|
data: {"type":"response.content_part.done","response_id":"resp_xxx",...}
|
||||||
|
|
||||||
|
event: response.output_item.done
|
||||||
|
data: {"type":"response.output_item.done","response_id":"resp_xxx","item":{"type":"message|function_call",...},...}
|
||||||
|
|
||||||
|
event: response.completed
|
||||||
|
data: {"type":"response.completed","response":{...}}
|
||||||
|
|
||||||
|
data: [DONE]
|
||||||
|
```
|
||||||
|
|
||||||
|
If `tool_choice=required` is violated in stream mode, DS2API emits `response.failed` then `[DONE]` (no `response.completed`).
|
||||||
|
Unknown tool names (outside declared `tools`) are rejected and will not be emitted as valid tool calls.
|
||||||
|
|
||||||
|
### `GET /v1/responses/{response_id}`
|
||||||
|
|
||||||
|
Business auth required. Fetches cached responses created by `POST /v1/responses` (caller-scoped; only the same key/token can read).
|
||||||
|
|
||||||
|
> Backed by in-memory TTL store. Default TTL is `900s` (configurable via `responses.store_ttl_seconds`).
|
||||||
|
|
||||||
|
### `POST /v1/embeddings`
|
||||||
|
|
||||||
|
Business auth required. Returns OpenAI-compatible embeddings shape.
|
||||||
|
|
||||||
|
| Field | Type | Required | Notes |
|
||||||
|
| --- | --- | --- | --- |
|
||||||
|
| `model` | string | ✅ | Supports native models + alias mapping |
|
||||||
|
| `input` | string/array | ✅ | Supports string, string array, token array |
|
||||||
|
|
||||||
|
> Requires `embeddings.provider`. Current supported values: `mock` / `deterministic` / `builtin`. If missing/unsupported, returns standard error shape with HTTP 501.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Claude-Compatible API
|
## Claude-Compatible API
|
||||||
|
|
||||||
|
Besides `/anthropic/v1/*`, DS2API also supports shortcut paths: `/v1/messages`, `/messages`, `/v1/messages/count_tokens`, `/messages/count_tokens`.
|
||||||
|
|
||||||
### `GET /anthropic/v1/models`
|
### `GET /anthropic/v1/models`
|
||||||
|
|
||||||
No auth required.
|
No auth required.
|
||||||
@@ -249,7 +379,10 @@ No auth required.
|
|||||||
{"id": "claude-sonnet-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"},
|
{"id": "claude-sonnet-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"},
|
||||||
{"id": "claude-haiku-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"},
|
{"id": "claude-haiku-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"},
|
||||||
{"id": "claude-opus-4-6", "object": "model", "created": 1715635200, "owned_by": "anthropic"}
|
{"id": "claude-opus-4-6", "object": "model", "created": 1715635200, "owned_by": "anthropic"}
|
||||||
]
|
],
|
||||||
|
"first_id": "claude-opus-4-6",
|
||||||
|
"last_id": "claude-instant-1.0",
|
||||||
|
"has_more": false
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -265,13 +398,15 @@ Content-Type: application/json
|
|||||||
anthropic-version: 2023-06-01
|
anthropic-version: 2023-06-01
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> `anthropic-version` is optional; DS2API auto-fills `2023-06-01` when absent.
|
||||||
|
|
||||||
**Request body**:
|
**Request body**:
|
||||||
|
|
||||||
| Field | Type | Required | Notes |
|
| Field | Type | Required | Notes |
|
||||||
| --- | --- | --- | --- |
|
| --- | --- | --- | --- |
|
||||||
| `model` | string | ✅ | For example `claude-sonnet-4-5` / `claude-opus-4-6` / `claude-haiku-4-5` (compatible with `claude-3-5-haiku-latest`), plus historical Claude model IDs |
|
| `model` | string | ✅ | For example `claude-sonnet-4-5` / `claude-opus-4-6` / `claude-haiku-4-5` (compatible with `claude-3-5-haiku-latest`), plus historical Claude model IDs |
|
||||||
| `messages` | array | ✅ | Claude-style messages |
|
| `messages` | array | ✅ | Claude-style messages |
|
||||||
| `max_tokens` | number | ❌ | Not strictly enforced by upstream bridge |
|
| `max_tokens` | number | ❌ | Auto-filled to `8192` when omitted; not strictly enforced by upstream bridge |
|
||||||
| `stream` | boolean | ❌ | Default `false` |
|
| `stream` | boolean | ❌ | Default `false` |
|
||||||
| `system` | string | ❌ | Optional system prompt |
|
| `system` | string | ❌ | Optional system prompt |
|
||||||
| `tools` | array | ❌ | Claude tool schema |
|
| `tools` | array | ❌ | Claude tool schema |
|
||||||
@@ -354,6 +489,37 @@ data: {"type":"message_stop"}
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## Gemini-Compatible API
|
||||||
|
|
||||||
|
Supported paths:
|
||||||
|
|
||||||
|
- `/v1beta/models/{model}:generateContent`
|
||||||
|
- `/v1beta/models/{model}:streamGenerateContent`
|
||||||
|
- `/v1/models/{model}:generateContent` (compat path)
|
||||||
|
- `/v1/models/{model}:streamGenerateContent` (compat path)
|
||||||
|
|
||||||
|
Authentication is the same as other business routes (`Authorization: Bearer <token>` or `x-api-key`).
|
||||||
|
|
||||||
|
### `POST /v1beta/models/{model}:generateContent`
|
||||||
|
|
||||||
|
Request body accepts Gemini-style `contents` / `tools`. Model names can use aliases and are mapped to DeepSeek models.
|
||||||
|
|
||||||
|
Response uses Gemini-compatible fields, including:
|
||||||
|
|
||||||
|
- `candidates[].content.parts[].text`
|
||||||
|
- `candidates[].content.parts[].functionCall` (when tool call is produced)
|
||||||
|
- `usageMetadata` (`promptTokenCount` / `candidatesTokenCount` / `totalTokenCount`)
|
||||||
|
|
||||||
|
### `POST /v1beta/models/{model}:streamGenerateContent`
|
||||||
|
|
||||||
|
Returns SSE (`text/event-stream`), each chunk as `data: <json>`:
|
||||||
|
|
||||||
|
- regular text: incremental text chunks
|
||||||
|
- `tools` mode: buffered and emitted as `functionCall` at finalize phase
|
||||||
|
- final chunk: includes `finishReason: "STOP"` and `usageMetadata`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Admin API
|
## Admin API
|
||||||
|
|
||||||
### `POST /admin/login`
|
### `POST /admin/login`
|
||||||
@@ -416,6 +582,7 @@ Returns sanitized config.
|
|||||||
"keys": ["k1", "k2"],
|
"keys": ["k1", "k2"],
|
||||||
"accounts": [
|
"accounts": [
|
||||||
{
|
{
|
||||||
|
"identifier": "user@example.com",
|
||||||
"email": "user@example.com",
|
"email": "user@example.com",
|
||||||
"mobile": "",
|
"mobile": "",
|
||||||
"has_password": true,
|
"has_password": true,
|
||||||
@@ -449,6 +616,51 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`.
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### `GET /admin/settings`
|
||||||
|
|
||||||
|
Reads runtime settings and status, including:
|
||||||
|
|
||||||
|
- `admin` (JWT expiry, default-password warning, etc.)
|
||||||
|
- `runtime` (`account_max_inflight`, `account_max_queue`, `global_max_inflight`)
|
||||||
|
- `toolcall` / `responses` / `embeddings`
|
||||||
|
- `claude_mapping` / `model_aliases`
|
||||||
|
- `env_backed`, `needs_vercel_sync`
|
||||||
|
|
||||||
|
### `PUT /admin/settings`
|
||||||
|
|
||||||
|
Hot-updates runtime settings. Supported fields:
|
||||||
|
|
||||||
|
- `admin.jwt_expire_hours`
|
||||||
|
- `runtime.account_max_inflight` / `runtime.account_max_queue` / `runtime.global_max_inflight`
|
||||||
|
- `toolcall.mode` / `toolcall.early_emit_confidence`
|
||||||
|
- `responses.store_ttl_seconds`
|
||||||
|
- `embeddings.provider`
|
||||||
|
- `claude_mapping`
|
||||||
|
- `model_aliases`
|
||||||
|
|
||||||
|
### `POST /admin/settings/password`
|
||||||
|
|
||||||
|
Updates admin password and invalidates existing JWTs.
|
||||||
|
|
||||||
|
Request example:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"new_password":"your-new-password"}
|
||||||
|
```
|
||||||
|
|
||||||
|
### `POST /admin/config/import`
|
||||||
|
|
||||||
|
Imports full config with:
|
||||||
|
|
||||||
|
- `mode=merge` (default)
|
||||||
|
- `mode=replace`
|
||||||
|
|
||||||
|
The request can send config directly, or wrapped as `{"config": {...}, "mode":"merge"}`.
|
||||||
|
|
||||||
|
### `GET /admin/config/export`
|
||||||
|
|
||||||
|
Exports full config in three forms: `config`, `json`, and `base64`.
|
||||||
|
|
||||||
### `POST /admin/keys`
|
### `POST /admin/keys`
|
||||||
|
|
||||||
```json
|
```json
|
||||||
@@ -476,6 +688,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`.
|
|||||||
{
|
{
|
||||||
"items": [
|
"items": [
|
||||||
{
|
{
|
||||||
|
"identifier": "user@example.com",
|
||||||
"email": "user@example.com",
|
"email": "user@example.com",
|
||||||
"mobile": "",
|
"mobile": "",
|
||||||
"has_password": true,
|
"has_password": true,
|
||||||
@@ -500,7 +713,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`.
|
|||||||
|
|
||||||
### `DELETE /admin/accounts/{identifier}`
|
### `DELETE /admin/accounts/{identifier}`
|
||||||
|
|
||||||
`identifier` is email or mobile.
|
`identifier` can be email, mobile, or the synthetic id for token-only accounts (`token:<hash>`).
|
||||||
|
|
||||||
**Response**: `{"success": true, "total_accounts": 5}`
|
**Response**: `{"success": true, "total_accounts": 5}`
|
||||||
|
|
||||||
@@ -530,7 +743,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`.
|
|||||||
|
|
||||||
| Field | Required | Notes |
|
| Field | Required | Notes |
|
||||||
| --- | --- | --- |
|
| --- | --- | --- |
|
||||||
| `identifier` | ✅ | email or mobile |
|
| `identifier` | ✅ | email / mobile / token-only synthetic id |
|
||||||
| `model` | ❌ | default `deepseek-chat` |
|
| `model` | ❌ | default `deepseek-chat` |
|
||||||
| `message` | ❌ | if empty, only session creation is tested |
|
| `message` | ❌ | if empty, only session creation is tested |
|
||||||
|
|
||||||
@@ -655,17 +868,53 @@ Or manual deploy required:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### `GET /admin/dev/captures`
|
||||||
|
|
||||||
|
Reads local packet-capture status and recent entries (Admin auth required):
|
||||||
|
|
||||||
|
- `enabled`
|
||||||
|
- `limit`
|
||||||
|
- `max_body_bytes`
|
||||||
|
- `items`
|
||||||
|
|
||||||
|
### `DELETE /admin/dev/captures`
|
||||||
|
|
||||||
|
Clears packet-capture entries:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"success":true,"detail":"capture logs cleared"}
|
||||||
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Error Payloads
|
## Error Payloads
|
||||||
|
|
||||||
Error formats vary by module:
|
Compatible routes (`/v1/*`, `/anthropic/*`) use the same error envelope:
|
||||||
|
|
||||||
| Module | Format |
|
```json
|
||||||
| --- | --- |
|
{
|
||||||
| OpenAI routes | `{"error": {"message": "...", "type": "..."}}` |
|
"error": {
|
||||||
| Claude routes | `{"error": {"type": "...", "message": "..."}}` |
|
"message": "...",
|
||||||
| Admin routes | `{"detail": "..."}` |
|
"type": "invalid_request_error",
|
||||||
|
"code": "invalid_request",
|
||||||
|
"param": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Admin routes keep `{"detail":"..."}`.
|
||||||
|
|
||||||
|
Gemini routes use Google-style errors:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": {
|
||||||
|
"code": 400,
|
||||||
|
"message": "invalid json",
|
||||||
|
"status": "INVALID_ARGUMENT"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
Clients should handle HTTP status code plus `error` / `detail` fields.
|
Clients should handle HTTP status code plus `error` / `detail` fields.
|
||||||
|
|
||||||
@@ -707,6 +956,31 @@ curl http://localhost:5001/v1/chat/completions \
|
|||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### OpenAI Responses (Stream)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://localhost:5001/v1/responses \
|
||||||
|
-H "Authorization: Bearer your-api-key" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "gpt-5-codex",
|
||||||
|
"input": "Write a hello world in golang",
|
||||||
|
"stream": true
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### OpenAI Embeddings
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://localhost:5001/v1/embeddings \
|
||||||
|
-H "Authorization: Bearer your-api-key" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"input": ["first text", "second text"]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
### OpenAI with Search
|
### OpenAI with Search
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -748,6 +1022,38 @@ curl http://localhost:5001/v1/chat/completions \
|
|||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Gemini Non-Stream
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl "http://localhost:5001/v1beta/models/gemini-2.5-pro:generateContent" \
|
||||||
|
-H "Authorization: Bearer your-api-key" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"contents": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"parts": [{"text": "Introduce Go in three sentences"}]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Gemini Stream
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl "http://localhost:5001/v1beta/models/gemini-2.5-flash:streamGenerateContent" \
|
||||||
|
-H "Authorization: Bearer your-api-key" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"contents": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"parts": [{"text": "Write a short summary"}]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
### Claude Non-Stream
|
### Claude Non-Stream
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
334
API.md
334
API.md
@@ -9,11 +9,13 @@
|
|||||||
## 目录
|
## 目录
|
||||||
|
|
||||||
- [基础信息](#基础信息)
|
- [基础信息](#基础信息)
|
||||||
|
- [配置最佳实践](#配置最佳实践)
|
||||||
- [鉴权规则](#鉴权规则)
|
- [鉴权规则](#鉴权规则)
|
||||||
- [路由总览](#路由总览)
|
- [路由总览](#路由总览)
|
||||||
- [健康检查](#健康检查)
|
- [健康检查](#健康检查)
|
||||||
- [OpenAI 兼容接口](#openai-兼容接口)
|
- [OpenAI 兼容接口](#openai-兼容接口)
|
||||||
- [Claude 兼容接口](#claude-兼容接口)
|
- [Claude 兼容接口](#claude-兼容接口)
|
||||||
|
- [Gemini 兼容接口](#gemini-兼容接口)
|
||||||
- [Admin 接口](#admin-接口)
|
- [Admin 接口](#admin-接口)
|
||||||
- [错误响应格式](#错误响应格式)
|
- [错误响应格式](#错误响应格式)
|
||||||
- [cURL 示例](#curl-示例)
|
- [cURL 示例](#curl-示例)
|
||||||
@@ -27,13 +29,35 @@
|
|||||||
| Base URL | `http://localhost:5001` 或你的部署域名 |
|
| Base URL | `http://localhost:5001` 或你的部署域名 |
|
||||||
| 默认 Content-Type | `application/json` |
|
| 默认 Content-Type | `application/json` |
|
||||||
| 健康检查 | `GET /healthz`、`GET /readyz` |
|
| 健康检查 | `GET /healthz`、`GET /readyz` |
|
||||||
| CORS | 已启用(`Access-Control-Allow-Origin: *`,允许 `Content-Type`, `Authorization`) |
|
| CORS | 已启用(`Access-Control-Allow-Origin: *`,允许 `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Vercel-Protection-Bypass`) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 配置最佳实践
|
||||||
|
|
||||||
|
推荐把 `config.json` 作为唯一配置源:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp config.example.json config.json
|
||||||
|
# 编辑 config.json(keys/accounts)
|
||||||
|
```
|
||||||
|
|
||||||
|
按部署方式使用:
|
||||||
|
|
||||||
|
- 本地运行:直接读取 `config.json`
|
||||||
|
- Docker / Vercel:从 `config.json` 生成 Base64,填入 `DS2API_CONFIG_JSON`
|
||||||
|
|
||||||
|
```bash
|
||||||
|
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||||
|
```
|
||||||
|
|
||||||
|
Vercel 一键部署可先只填 `DS2API_ADMIN_KEY`,部署后在 `/admin` 导入配置,再通过 “Vercel 同步” 写回环境变量。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 鉴权规则
|
## 鉴权规则
|
||||||
|
|
||||||
### 业务接口(`/v1/*`、`/anthropic/*`)
|
### 业务接口(`/v1/*`、`/anthropic/*`、`/v1beta/models/*`)
|
||||||
|
|
||||||
支持两种传参方式:
|
支持两种传参方式:
|
||||||
|
|
||||||
@@ -66,15 +90,32 @@
|
|||||||
| GET | `/healthz` | 无 | 存活探针 |
|
| GET | `/healthz` | 无 | 存活探针 |
|
||||||
| GET | `/readyz` | 无 | 就绪探针 |
|
| GET | `/readyz` | 无 | 就绪探针 |
|
||||||
| GET | `/v1/models` | 无 | OpenAI 模型列表 |
|
| GET | `/v1/models` | 无 | OpenAI 模型列表 |
|
||||||
|
| GET | `/v1/models/{id}` | 无 | OpenAI 单模型查询(支持 alias 入参) |
|
||||||
| POST | `/v1/chat/completions` | 业务 | OpenAI 对话补全 |
|
| POST | `/v1/chat/completions` | 业务 | OpenAI 对话补全 |
|
||||||
|
| POST | `/v1/responses` | 业务 | OpenAI Responses 接口(流式/非流式) |
|
||||||
|
| GET | `/v1/responses/{response_id}` | 业务 | 查询已生成 response(内存 TTL) |
|
||||||
|
| POST | `/v1/embeddings` | 业务 | OpenAI Embeddings 接口 |
|
||||||
| GET | `/anthropic/v1/models` | 无 | Claude 模型列表 |
|
| GET | `/anthropic/v1/models` | 无 | Claude 模型列表 |
|
||||||
| POST | `/anthropic/v1/messages` | 业务 | Claude 消息接口 |
|
| POST | `/anthropic/v1/messages` | 业务 | Claude 消息接口 |
|
||||||
| POST | `/anthropic/v1/messages/count_tokens` | 业务 | Claude token 计数 |
|
| POST | `/anthropic/v1/messages/count_tokens` | 业务 | Claude token 计数 |
|
||||||
|
| POST | `/v1/messages` | 业务 | Claude 消息快捷路径 |
|
||||||
|
| POST | `/messages` | 业务 | Claude 消息快捷路径 |
|
||||||
|
| POST | `/v1/messages/count_tokens` | 业务 | Claude token 计数快捷路径 |
|
||||||
|
| POST | `/messages/count_tokens` | 业务 | Claude token 计数快捷路径 |
|
||||||
|
| POST | `/v1beta/models/{model}:generateContent` | 业务 | Gemini 非流式 |
|
||||||
|
| POST | `/v1beta/models/{model}:streamGenerateContent` | 业务 | Gemini 流式 |
|
||||||
|
| POST | `/v1/models/{model}:generateContent` | 业务 | Gemini 非流式兼容路径 |
|
||||||
|
| POST | `/v1/models/{model}:streamGenerateContent` | 业务 | Gemini 流式兼容路径 |
|
||||||
| POST | `/admin/login` | 无 | 管理登录 |
|
| POST | `/admin/login` | 无 | 管理登录 |
|
||||||
| GET | `/admin/verify` | JWT | 校验管理 JWT |
|
| GET | `/admin/verify` | JWT | 校验管理 JWT |
|
||||||
| GET | `/admin/vercel/config` | Admin | 读取 Vercel 预配置 |
|
| GET | `/admin/vercel/config` | Admin | 读取 Vercel 预配置 |
|
||||||
| GET | `/admin/config` | Admin | 读取配置(脱敏) |
|
| GET | `/admin/config` | Admin | 读取配置(脱敏) |
|
||||||
| POST | `/admin/config` | Admin | 更新配置 |
|
| POST | `/admin/config` | Admin | 更新配置 |
|
||||||
|
| GET | `/admin/settings` | Admin | 读取运行时设置 |
|
||||||
|
| PUT | `/admin/settings` | Admin | 更新运行时设置(热更新) |
|
||||||
|
| POST | `/admin/settings/password` | Admin | 更新 Admin 密码并使旧 JWT 失效 |
|
||||||
|
| POST | `/admin/config/import` | Admin | 导入配置(merge/replace) |
|
||||||
|
| GET | `/admin/config/export` | Admin | 导出完整配置(含 `config`/`json`/`base64`) |
|
||||||
| POST | `/admin/keys` | Admin | 添加 API key |
|
| POST | `/admin/keys` | Admin | 添加 API key |
|
||||||
| DELETE | `/admin/keys/{key}` | Admin | 删除 API key |
|
| DELETE | `/admin/keys/{key}` | Admin | 删除 API key |
|
||||||
| GET | `/admin/accounts` | Admin | 分页账号列表 |
|
| GET | `/admin/accounts` | Admin | 分页账号列表 |
|
||||||
@@ -88,6 +129,8 @@
|
|||||||
| POST | `/admin/vercel/sync` | Admin | 同步配置到 Vercel |
|
| POST | `/admin/vercel/sync` | Admin | 同步配置到 Vercel |
|
||||||
| GET | `/admin/vercel/status` | Admin | Vercel 同步状态 |
|
| GET | `/admin/vercel/status` | Admin | Vercel 同步状态 |
|
||||||
| GET | `/admin/export` | Admin | 导出配置 JSON/Base64 |
|
| GET | `/admin/export` | Admin | 导出配置 JSON/Base64 |
|
||||||
|
| GET | `/admin/dev/captures` | Admin | 查看本地抓包记录 |
|
||||||
|
| DELETE | `/admin/dev/captures` | Admin | 清空本地抓包记录 |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -127,6 +170,15 @@
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### 模型 alias 解析策略
|
||||||
|
|
||||||
|
对 `chat` / `responses` / `embeddings` 的 `model` 字段采用“宽进严出”:
|
||||||
|
|
||||||
|
1. 先匹配 DeepSeek 原生模型。
|
||||||
|
2. 再匹配 `model_aliases` 精确映射。
|
||||||
|
3. 未命中时按模型家族规则回退(如 `o*`、`gpt-*`、`claude-*`)。
|
||||||
|
4. 仍未命中则返回 `invalid_request_error`。
|
||||||
|
|
||||||
### `POST /v1/chat/completions`
|
### `POST /v1/chat/completions`
|
||||||
|
|
||||||
**请求头**:
|
**请求头**:
|
||||||
@@ -140,7 +192,7 @@ Content-Type: application/json
|
|||||||
|
|
||||||
| 字段 | 类型 | 必填 | 说明 |
|
| 字段 | 类型 | 必填 | 说明 |
|
||||||
| --- | --- | --- | --- |
|
| --- | --- | --- | --- |
|
||||||
| `model` | string | ✅ | `deepseek-chat` / `deepseek-reasoner` / `deepseek-chat-search` / `deepseek-reasoner-search` |
|
| `model` | string | ✅ | 支持 DeepSeek 原生模型 + 常见 alias(如 `gpt-4o`、`gpt-5-codex`、`o3`、`claude-sonnet-4-5`) |
|
||||||
| `messages` | array | ✅ | OpenAI 风格消息数组 |
|
| `messages` | array | ✅ | OpenAI 风格消息数组 |
|
||||||
| `stream` | boolean | ❌ | 默认 `false` |
|
| `stream` | boolean | ❌ | 默认 `false` |
|
||||||
| `tools` | array | ❌ | Function Calling 定义 |
|
| `tools` | array | ❌ | Function Calling 定义 |
|
||||||
@@ -230,12 +282,90 @@ data: [DONE]
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
**流式**:先缓冲正文片段。识别到工具调用 → 仅输出结构化 `delta.tool_calls`(每个 tool call 带 `index`);否则一次性输出普通文本。
|
**流式**:命中高置信特征后立即输出 `delta.tool_calls`(不等待完整 JSON 闭合),并持续发送 arguments 增量;已确认的 toolcall 原始 JSON 不会回流到 `delta.content`。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### `GET /v1/models/{id}`
|
||||||
|
|
||||||
|
无需鉴权。入参支持 alias(例如 `gpt-4o`),返回的是映射后的 DeepSeek 模型对象。
|
||||||
|
|
||||||
|
### `POST /v1/responses`
|
||||||
|
|
||||||
|
OpenAI Responses 风格接口,兼容 `input` 或 `messages`。
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 说明 |
|
||||||
|
| --- | --- | --- | --- |
|
||||||
|
| `model` | string | ✅ | 支持原生模型 + alias 自动映射 |
|
||||||
|
| `input` | string/array/object | ❌ | 与 `messages` 二选一 |
|
||||||
|
| `messages` | array | ❌ | 与 `input` 二选一 |
|
||||||
|
| `instructions` | string | ❌ | 自动前置为 system 消息 |
|
||||||
|
| `stream` | boolean | ❌ | 默认 `false` |
|
||||||
|
| `tools` | array | ❌ | 与 chat 同样的工具识别与转译策略 |
|
||||||
|
| `tool_choice` | string/object | ❌ | 支持 `auto`/`none`/`required` 与强制函数(`{"type":"function","name":"..."}`) |
|
||||||
|
|
||||||
|
**非流式响应**:返回标准 `response` 对象,`id` 形如 `resp_xxx`,并写入内存 TTL 存储。
|
||||||
|
当 `tool_choice=required` 且未产出有效工具调用时,返回 HTTP `422`(`error.code=tool_choice_violation`)。
|
||||||
|
|
||||||
|
**流式响应(SSE)**:最小事件序列如下。
|
||||||
|
|
||||||
|
```text
|
||||||
|
event: response.created
|
||||||
|
data: {"type":"response.created","id":"resp_xxx","status":"in_progress",...}
|
||||||
|
|
||||||
|
event: response.output_item.added
|
||||||
|
data: {"type":"response.output_item.added","response_id":"resp_xxx","item":{"type":"message|function_call",...},...}
|
||||||
|
|
||||||
|
event: response.content_part.added
|
||||||
|
data: {"type":"response.content_part.added","response_id":"resp_xxx","part":{"type":"output_text",...},...}
|
||||||
|
|
||||||
|
event: response.output_text.delta
|
||||||
|
data: {"type":"response.output_text.delta","response_id":"resp_xxx","item_id":"msg_xxx","output_index":0,"content_index":0,"delta":"..."}
|
||||||
|
|
||||||
|
event: response.function_call_arguments.delta
|
||||||
|
data: {"type":"response.function_call_arguments.delta","response_id":"resp_xxx","call_id":"call_xxx","delta":"..."}
|
||||||
|
|
||||||
|
event: response.function_call_arguments.done
|
||||||
|
data: {"type":"response.function_call_arguments.done","response_id":"resp_xxx","call_id":"call_xxx","name":"tool","arguments":"{...}"}
|
||||||
|
|
||||||
|
event: response.content_part.done
|
||||||
|
data: {"type":"response.content_part.done","response_id":"resp_xxx",...}
|
||||||
|
|
||||||
|
event: response.output_item.done
|
||||||
|
data: {"type":"response.output_item.done","response_id":"resp_xxx","item":{"type":"message|function_call",...},...}
|
||||||
|
|
||||||
|
event: response.completed
|
||||||
|
data: {"type":"response.completed","response":{...}}
|
||||||
|
|
||||||
|
data: [DONE]
|
||||||
|
```
|
||||||
|
|
||||||
|
流式场景下若 `tool_choice=required` 违规,会返回 `response.failed` 后结束(不再发送 `response.completed`)。
|
||||||
|
未在 `tools` 声明中的工具名会被严格拒绝,不会作为有效 tool call 下发。
|
||||||
|
|
||||||
|
### `GET /v1/responses/{response_id}`
|
||||||
|
|
||||||
|
需要业务鉴权。查询 `POST /v1/responses` 生成并缓存的 response 对象(按调用方鉴权隔离,仅同一 key/token 可读取)。
|
||||||
|
|
||||||
|
> 当前为内存 TTL 存储,默认过期时间 `900s`(可用 `responses.store_ttl_seconds` 调整)。
|
||||||
|
|
||||||
|
### `POST /v1/embeddings`
|
||||||
|
|
||||||
|
需要业务鉴权。返回 OpenAI Embeddings 兼容结构。
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 说明 |
|
||||||
|
| --- | --- | --- | --- |
|
||||||
|
| `model` | string | ✅ | 支持原生模型 + alias 自动映射 |
|
||||||
|
| `input` | string/array | ✅ | 支持字符串、字符串数组、token 数组 |
|
||||||
|
|
||||||
|
> 需配置 `embeddings.provider`。当前支持:`mock` / `deterministic` / `builtin`。未配置或不支持时返回标准错误结构(HTTP 501)。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Claude 兼容接口
|
## Claude 兼容接口
|
||||||
|
|
||||||
|
除标准路径 `/anthropic/v1/*` 外,还支持快捷路径 `/v1/messages`、`/messages`、`/v1/messages/count_tokens`、`/messages/count_tokens`。
|
||||||
|
|
||||||
### `GET /anthropic/v1/models`
|
### `GET /anthropic/v1/models`
|
||||||
|
|
||||||
无需鉴权。
|
无需鉴权。
|
||||||
@@ -249,7 +379,10 @@ data: [DONE]
|
|||||||
{"id": "claude-sonnet-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"},
|
{"id": "claude-sonnet-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"},
|
||||||
{"id": "claude-haiku-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"},
|
{"id": "claude-haiku-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"},
|
||||||
{"id": "claude-opus-4-6", "object": "model", "created": 1715635200, "owned_by": "anthropic"}
|
{"id": "claude-opus-4-6", "object": "model", "created": 1715635200, "owned_by": "anthropic"}
|
||||||
]
|
],
|
||||||
|
"first_id": "claude-opus-4-6",
|
||||||
|
"last_id": "claude-instant-1.0",
|
||||||
|
"has_more": false
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -265,13 +398,15 @@ Content-Type: application/json
|
|||||||
anthropic-version: 2023-06-01
|
anthropic-version: 2023-06-01
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> `anthropic-version` 可省略,服务端会自动补为 `2023-06-01`。
|
||||||
|
|
||||||
**请求体**:
|
**请求体**:
|
||||||
|
|
||||||
| 字段 | 类型 | 必填 | 说明 |
|
| 字段 | 类型 | 必填 | 说明 |
|
||||||
| --- | --- | --- | --- |
|
| --- | --- | --- | --- |
|
||||||
| `model` | string | ✅ | 例如 `claude-sonnet-4-5` / `claude-opus-4-6` / `claude-haiku-4-5`(兼容 `claude-3-5-haiku-latest`),并支持历史 Claude 模型 ID |
|
| `model` | string | ✅ | 例如 `claude-sonnet-4-5` / `claude-opus-4-6` / `claude-haiku-4-5`(兼容 `claude-3-5-haiku-latest`),并支持历史 Claude 模型 ID |
|
||||||
| `messages` | array | ✅ | Claude 风格消息数组 |
|
| `messages` | array | ✅ | Claude 风格消息数组 |
|
||||||
| `max_tokens` | number | ❌ | 当前实现不会硬性截断上游输出 |
|
| `max_tokens` | number | ❌ | 缺省自动补 `8192`;当前实现不会硬性截断上游输出 |
|
||||||
| `stream` | boolean | ❌ | 默认 `false` |
|
| `stream` | boolean | ❌ | 默认 `false` |
|
||||||
| `system` | string | ❌ | 可选系统提示 |
|
| `system` | string | ❌ | 可选系统提示 |
|
||||||
| `tools` | array | ❌ | Claude tool 定义 |
|
| `tools` | array | ❌ | Claude tool 定义 |
|
||||||
@@ -354,6 +489,37 @@ data: {"type":"message_stop"}
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## Gemini 兼容接口
|
||||||
|
|
||||||
|
支持路径:
|
||||||
|
|
||||||
|
- `/v1beta/models/{model}:generateContent`
|
||||||
|
- `/v1beta/models/{model}:streamGenerateContent`
|
||||||
|
- `/v1/models/{model}:generateContent`(兼容路径)
|
||||||
|
- `/v1/models/{model}:streamGenerateContent`(兼容路径)
|
||||||
|
|
||||||
|
鉴权方式同业务接口(`Authorization: Bearer <token>` 或 `x-api-key`)。
|
||||||
|
|
||||||
|
### `POST /v1beta/models/{model}:generateContent`
|
||||||
|
|
||||||
|
请求体兼容 Gemini `contents` / `tools` 字段,模型名可用 alias 自动映射到 DeepSeek 模型。
|
||||||
|
|
||||||
|
响应为 Gemini 兼容结构,核心字段包括:
|
||||||
|
|
||||||
|
- `candidates[].content.parts[].text`
|
||||||
|
- `candidates[].content.parts[].functionCall`(工具调用时)
|
||||||
|
- `usageMetadata`(`promptTokenCount` / `candidatesTokenCount` / `totalTokenCount`)
|
||||||
|
|
||||||
|
### `POST /v1beta/models/{model}:streamGenerateContent`
|
||||||
|
|
||||||
|
返回 SSE(`text/event-stream`),每个 chunk 为一条 `data: <json>`:
|
||||||
|
|
||||||
|
- 常规文本:持续返回增量文本 chunk
|
||||||
|
- `tools` 场景:会缓冲并在结束时输出 `functionCall` 结构
|
||||||
|
- 结束 chunk:包含 `finishReason: "STOP"` 与 `usageMetadata`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Admin 接口
|
## Admin 接口
|
||||||
|
|
||||||
### `POST /admin/login`
|
### `POST /admin/login`
|
||||||
@@ -416,6 +582,7 @@ data: {"type":"message_stop"}
|
|||||||
"keys": ["k1", "k2"],
|
"keys": ["k1", "k2"],
|
||||||
"accounts": [
|
"accounts": [
|
||||||
{
|
{
|
||||||
|
"identifier": "user@example.com",
|
||||||
"email": "user@example.com",
|
"email": "user@example.com",
|
||||||
"mobile": "",
|
"mobile": "",
|
||||||
"has_password": true,
|
"has_password": true,
|
||||||
@@ -449,6 +616,51 @@ data: {"type":"message_stop"}
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### `GET /admin/settings`
|
||||||
|
|
||||||
|
读取运行时设置与状态,返回:
|
||||||
|
|
||||||
|
- `admin`(JWT 过期、默认密码告警等)
|
||||||
|
- `runtime`(`account_max_inflight`、`account_max_queue`、`global_max_inflight`)
|
||||||
|
- `toolcall` / `responses` / `embeddings`
|
||||||
|
- `claude_mapping` / `model_aliases`
|
||||||
|
- `env_backed`、`needs_vercel_sync`
|
||||||
|
|
||||||
|
### `PUT /admin/settings`
|
||||||
|
|
||||||
|
热更新运行时设置。支持更新:
|
||||||
|
|
||||||
|
- `admin.jwt_expire_hours`
|
||||||
|
- `runtime.account_max_inflight` / `runtime.account_max_queue` / `runtime.global_max_inflight`
|
||||||
|
- `toolcall.mode` / `toolcall.early_emit_confidence`
|
||||||
|
- `responses.store_ttl_seconds`
|
||||||
|
- `embeddings.provider`
|
||||||
|
- `claude_mapping`
|
||||||
|
- `model_aliases`
|
||||||
|
|
||||||
|
### `POST /admin/settings/password`
|
||||||
|
|
||||||
|
更新管理密码并使旧 JWT 失效。
|
||||||
|
|
||||||
|
请求示例:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"new_password":"your-new-password"}
|
||||||
|
```
|
||||||
|
|
||||||
|
### `POST /admin/config/import`
|
||||||
|
|
||||||
|
导入完整配置,支持:
|
||||||
|
|
||||||
|
- `mode=merge`(默认)
|
||||||
|
- `mode=replace`
|
||||||
|
|
||||||
|
请求可直接传配置对象,或使用 `{"config": {...}, "mode":"merge"}` 包裹格式。
|
||||||
|
|
||||||
|
### `GET /admin/config/export`
|
||||||
|
|
||||||
|
导出完整配置,返回 `config`、`json`、`base64` 三种格式。
|
||||||
|
|
||||||
### `POST /admin/keys`
|
### `POST /admin/keys`
|
||||||
|
|
||||||
```json
|
```json
|
||||||
@@ -476,6 +688,7 @@ data: {"type":"message_stop"}
|
|||||||
{
|
{
|
||||||
"items": [
|
"items": [
|
||||||
{
|
{
|
||||||
|
"identifier": "user@example.com",
|
||||||
"email": "user@example.com",
|
"email": "user@example.com",
|
||||||
"mobile": "",
|
"mobile": "",
|
||||||
"has_password": true,
|
"has_password": true,
|
||||||
@@ -500,7 +713,7 @@ data: {"type":"message_stop"}
|
|||||||
|
|
||||||
### `DELETE /admin/accounts/{identifier}`
|
### `DELETE /admin/accounts/{identifier}`
|
||||||
|
|
||||||
`identifier` 为 email 或 mobile。
|
`identifier` 可为 email、mobile,或 token-only 账号的合成标识(`token:<hash>`)。
|
||||||
|
|
||||||
**响应**:`{"success": true, "total_accounts": 5}`
|
**响应**:`{"success": true, "total_accounts": 5}`
|
||||||
|
|
||||||
@@ -530,7 +743,7 @@ data: {"type":"message_stop"}
|
|||||||
|
|
||||||
| 字段 | 必填 | 说明 |
|
| 字段 | 必填 | 说明 |
|
||||||
| --- | --- | --- |
|
| --- | --- | --- |
|
||||||
| `identifier` | ✅ | email 或 mobile |
|
| `identifier` | ✅ | email / mobile / token-only 合成标识 |
|
||||||
| `model` | ❌ | 默认 `deepseek-chat` |
|
| `model` | ❌ | 默认 `deepseek-chat` |
|
||||||
| `message` | ❌ | 空字符串时仅测试会话创建 |
|
| `message` | ❌ | 空字符串时仅测试会话创建 |
|
||||||
|
|
||||||
@@ -655,17 +868,53 @@ data: {"type":"message_stop"}
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### `GET /admin/dev/captures`
|
||||||
|
|
||||||
|
查看本地抓包状态与最近记录(需 Admin 鉴权):
|
||||||
|
|
||||||
|
- `enabled`
|
||||||
|
- `limit`
|
||||||
|
- `max_body_bytes`
|
||||||
|
- `items`
|
||||||
|
|
||||||
|
### `DELETE /admin/dev/captures`
|
||||||
|
|
||||||
|
清空抓包记录,返回:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"success":true,"detail":"capture logs cleared"}
|
||||||
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 错误响应格式
|
## 错误响应格式
|
||||||
|
|
||||||
不同模块的错误格式略有差异:
|
兼容路由(`/v1/*`、`/anthropic/*`)统一使用以下结构:
|
||||||
|
|
||||||
| 模块 | 格式 |
|
```json
|
||||||
| --- | --- |
|
{
|
||||||
| OpenAI 接口 | `{"error": {"message": "...", "type": "..."}}` |
|
"error": {
|
||||||
| Claude 接口 | `{"error": {"type": "...", "message": "..."}}` |
|
"message": "...",
|
||||||
| Admin 接口 | `{"detail": "..."}` |
|
"type": "invalid_request_error",
|
||||||
|
"code": "invalid_request",
|
||||||
|
"param": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Admin 接口保持 `{"detail":"..."}`。
|
||||||
|
|
||||||
|
Gemini 路由使用 Google 风格错误结构:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": {
|
||||||
|
"code": 400,
|
||||||
|
"message": "invalid json",
|
||||||
|
"status": "INVALID_ARGUMENT"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
建议客户端处理逻辑:检查 HTTP 状态码 + 解析 `error` 或 `detail` 字段。
|
建议客户端处理逻辑:检查 HTTP 状态码 + 解析 `error` 或 `detail` 字段。
|
||||||
|
|
||||||
@@ -707,6 +956,31 @@ curl http://localhost:5001/v1/chat/completions \
|
|||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### OpenAI Responses(流式)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://localhost:5001/v1/responses \
|
||||||
|
-H "Authorization: Bearer your-api-key" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "gpt-5-codex",
|
||||||
|
"input": "写一个 golang 的 hello world",
|
||||||
|
"stream": true
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### OpenAI Embeddings
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://localhost:5001/v1/embeddings \
|
||||||
|
-H "Authorization: Bearer your-api-key" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"input": ["第一段文本", "第二段文本"]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
### OpenAI 带搜索
|
### OpenAI 带搜索
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -748,6 +1022,38 @@ curl http://localhost:5001/v1/chat/completions \
|
|||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Gemini 非流式
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl "http://localhost:5001/v1beta/models/gemini-2.5-pro:generateContent" \
|
||||||
|
-H "Authorization: Bearer your-api-key" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"contents": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"parts": [{"text": "用三句话介绍 Go 语言"}]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Gemini 流式
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl "http://localhost:5001/v1beta/models/gemini-2.5-flash:streamGenerateContent" \
|
||||||
|
-H "Authorization: Bearer your-api-key" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"contents": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"parts": [{"text": "写一个简短摘要"}]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
### Claude 非流式
|
### Claude 非流式
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@@ -82,11 +82,11 @@ Manually build WebUI to `static/admin/`:
|
|||||||
## Running Tests
|
## Running Tests
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Go unit tests
|
# Go + Node unit tests (recommended)
|
||||||
go test ./...
|
./tests/scripts/run-unit-all.sh
|
||||||
|
|
||||||
# End-to-end live tests (real accounts)
|
# End-to-end live tests (real accounts)
|
||||||
./scripts/testsuite/run-live.sh
|
./tests/scripts/run-live.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
## Project Structure
|
## Project Structure
|
||||||
@@ -104,13 +104,20 @@ ds2api/
|
|||||||
│ ├── account/ # Account pool and concurrency queue
|
│ ├── account/ # Account pool and concurrency queue
|
||||||
│ ├── adapter/
|
│ ├── adapter/
|
||||||
│ │ ├── openai/ # OpenAI adapter
|
│ │ ├── openai/ # OpenAI adapter
|
||||||
│ │ └── claude/ # Claude adapter
|
│ │ ├── claude/ # Claude adapter
|
||||||
|
│ │ └── gemini/ # Gemini adapter
|
||||||
│ ├── admin/ # Admin API handlers
|
│ ├── admin/ # Admin API handlers
|
||||||
│ ├── auth/ # Auth and JWT
|
│ ├── auth/ # Auth and JWT
|
||||||
|
│ ├── claudeconv/ # Claude message conversion
|
||||||
|
│ ├── compat/ # Compatibility helpers
|
||||||
│ ├── config/ # Config loading and hot-reload
|
│ ├── config/ # Config loading and hot-reload
|
||||||
│ ├── deepseek/ # DeepSeek client, PoW WASM
|
│ ├── deepseek/ # DeepSeek client, PoW WASM
|
||||||
|
│ ├── devcapture/ # Dev packet capture
|
||||||
|
│ ├── format/ # Output formatting
|
||||||
|
│ ├── prompt/ # Prompt building
|
||||||
│ ├── server/ # HTTP routing (chi router)
|
│ ├── server/ # HTTP routing (chi router)
|
||||||
│ ├── sse/ # SSE parsing utilities
|
│ ├── sse/ # SSE parsing utilities
|
||||||
|
│ ├── stream/ # Unified stream consumption engine
|
||||||
│ ├── testsuite/ # Testsuite core logic
|
│ ├── testsuite/ # Testsuite core logic
|
||||||
│ ├── util/ # Common utilities
|
│ ├── util/ # Common utilities
|
||||||
│ └── webui/ # WebUI static hosting
|
│ └── webui/ # WebUI static hosting
|
||||||
|
|||||||
@@ -82,11 +82,11 @@ docker-compose -f docker-compose.dev.yml up
|
|||||||
## 运行测试
|
## 运行测试
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Go 单元测试
|
# Go + Node 单元测试(推荐)
|
||||||
go test ./...
|
./tests/scripts/run-unit-all.sh
|
||||||
|
|
||||||
# 端到端全链路测试(真实账号)
|
# 端到端全链路测试(真实账号)
|
||||||
./scripts/testsuite/run-live.sh
|
./tests/scripts/run-live.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
## 项目结构
|
## 项目结构
|
||||||
@@ -104,13 +104,20 @@ ds2api/
|
|||||||
│ ├── account/ # 账号池与并发队列
|
│ ├── account/ # 账号池与并发队列
|
||||||
│ ├── adapter/
|
│ ├── adapter/
|
||||||
│ │ ├── openai/ # OpenAI 兼容适配器
|
│ │ ├── openai/ # OpenAI 兼容适配器
|
||||||
│ │ └── claude/ # Claude 兼容适配器
|
│ │ ├── claude/ # Claude 兼容适配器
|
||||||
|
│ │ └── gemini/ # Gemini 兼容适配器
|
||||||
│ ├── admin/ # Admin API handlers
|
│ ├── admin/ # Admin API handlers
|
||||||
│ ├── auth/ # 鉴权与 JWT
|
│ ├── auth/ # 鉴权与 JWT
|
||||||
|
│ ├── claudeconv/ # Claude 消息格式转换
|
||||||
|
│ ├── compat/ # 兼容性辅助
|
||||||
│ ├── config/ # 配置加载与热更新
|
│ ├── config/ # 配置加载与热更新
|
||||||
│ ├── deepseek/ # DeepSeek 客户端、PoW WASM
|
│ ├── deepseek/ # DeepSeek 客户端、PoW WASM
|
||||||
|
│ ├── devcapture/ # 开发抓包
|
||||||
|
│ ├── format/ # 输出格式化
|
||||||
|
│ ├── prompt/ # Prompt 构建
|
||||||
│ ├── server/ # HTTP 路由(chi router)
|
│ ├── server/ # HTTP 路由(chi router)
|
||||||
│ ├── sse/ # SSE 解析工具
|
│ ├── sse/ # SSE 解析工具
|
||||||
|
│ ├── stream/ # 统一流式消费引擎
|
||||||
│ ├── testsuite/ # 测试集核心逻辑
|
│ ├── testsuite/ # 测试集核心逻辑
|
||||||
│ ├── util/ # 通用工具
|
│ ├── util/ # 通用工具
|
||||||
│ └── webui/ # WebUI 静态托管
|
│ └── webui/ # WebUI 静态托管
|
||||||
|
|||||||
69
DEPLOY.en.md
69
DEPLOY.en.md
@@ -33,6 +33,17 @@ Config source (choose one):
|
|||||||
- **File**: `config.json` (recommended for local/Docker)
|
- **File**: `config.json` (recommended for local/Docker)
|
||||||
- **Environment variable**: `DS2API_CONFIG_JSON` (recommended for Vercel; supports raw JSON or Base64)
|
- **Environment variable**: `DS2API_CONFIG_JSON` (recommended for Vercel; supports raw JSON or Base64)
|
||||||
|
|
||||||
|
Unified recommendation (best practice):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp config.example.json config.json
|
||||||
|
# Edit config.json
|
||||||
|
```
|
||||||
|
|
||||||
|
Use `config.json` as the single source of truth:
|
||||||
|
- Local run: read `config.json` directly
|
||||||
|
- Docker / Vercel: generate `DS2API_CONFIG_JSON` (Base64) from `config.json` and inject it
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 1. Local Run
|
## 1. Local Run
|
||||||
@@ -99,11 +110,15 @@ go build -o ds2api ./cmd/ds2api
|
|||||||
### 2.1 Basic Steps
|
### 2.1 Basic Steps
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Copy and edit environment
|
# Copy env template
|
||||||
cp .env.example .env
|
cp .env.example .env
|
||||||
# Edit .env, at minimum set:
|
|
||||||
|
# Generate single-line Base64 from config.json
|
||||||
|
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||||
|
|
||||||
|
# Edit .env and set:
|
||||||
# DS2API_ADMIN_KEY=your-admin-key
|
# DS2API_ADMIN_KEY=your-admin-key
|
||||||
# DS2API_CONFIG_JSON={"keys":[...],"accounts":[...]}
|
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
|
||||||
|
|
||||||
# Start
|
# Start
|
||||||
docker-compose up -d
|
docker-compose up -d
|
||||||
@@ -167,15 +182,49 @@ If container logs look normal but the admin panel is unreachable, check these fi
|
|||||||
|
|
||||||
1. **Fork** the repo to your GitHub account
|
1. **Fork** the repo to your GitHub account
|
||||||
2. **Import** the project on Vercel
|
2. **Import** the project on Vercel
|
||||||
3. **Set environment variables** (at minimum):
|
3. **Set environment variables** (minimum required: one variable):
|
||||||
|
|
||||||
| Variable | Description |
|
| Variable | Description |
|
||||||
| --- | --- |
|
| --- | --- |
|
||||||
| `DS2API_ADMIN_KEY` | Admin key (required) |
|
| `DS2API_ADMIN_KEY` | Admin key (required) |
|
||||||
| `DS2API_CONFIG_JSON` | Config content, raw JSON or Base64 (required) |
|
| `DS2API_CONFIG_JSON` | Config content, raw JSON or Base64 (optional, recommended) |
|
||||||
|
|
||||||
4. **Deploy**
|
4. **Deploy**
|
||||||
|
|
||||||
|
### 3.1.1 Recommended Input (avoid `DS2API_CONFIG_JSON` mistakes)
|
||||||
|
|
||||||
|
If you prefer faster one-click bootstrap, you can leave `DS2API_CONFIG_JSON` empty first, then open `/admin` after deployment, import config, and sync it back to Vercel env vars from the "Vercel Sync" page.
|
||||||
|
|
||||||
|
Recommended: in repo root, copy the template first and fill your real accounts:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp config.example.json config.json
|
||||||
|
# Edit config.json
|
||||||
|
```
|
||||||
|
|
||||||
|
Do not hand-edit large JSON directly in Vercel. Generate Base64 locally and paste it:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run in repo root
|
||||||
|
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||||
|
echo "$DS2API_CONFIG_JSON"
|
||||||
|
```
|
||||||
|
|
||||||
|
If you choose to preconfigure before first deploy, set these vars in Vercel Project Settings -> Environment Variables:
|
||||||
|
|
||||||
|
```text
|
||||||
|
DS2API_ADMIN_KEY=replace-with-a-strong-secret
|
||||||
|
DS2API_CONFIG_JSON=<the single-line Base64 output above>
|
||||||
|
```
|
||||||
|
|
||||||
|
Optional but recommended (for WebUI one-click Vercel sync):
|
||||||
|
|
||||||
|
```text
|
||||||
|
VERCEL_TOKEN=your-vercel-token
|
||||||
|
VERCEL_PROJECT_ID=prj_xxxxxxxxxxxx
|
||||||
|
VERCEL_TEAM_ID=team_xxxxxxxxxxxx # optional for personal accounts
|
||||||
|
```
|
||||||
|
|
||||||
### 3.2 Optional Environment Variables
|
### 3.2 Optional Environment Variables
|
||||||
|
|
||||||
| Variable | Description | Default |
|
| Variable | Description | Default |
|
||||||
@@ -184,6 +233,8 @@ If container logs look normal but the admin panel is unreachable, check these fi
|
|||||||
| `DS2API_ACCOUNT_CONCURRENCY` | Alias (legacy compat) | — |
|
| `DS2API_ACCOUNT_CONCURRENCY` | Alias (legacy compat) | — |
|
||||||
| `DS2API_ACCOUNT_MAX_QUEUE` | Waiting queue limit | `recommended_concurrency` |
|
| `DS2API_ACCOUNT_MAX_QUEUE` | Waiting queue limit | `recommended_concurrency` |
|
||||||
| `DS2API_ACCOUNT_QUEUE_SIZE` | Alias (legacy compat) | — |
|
| `DS2API_ACCOUNT_QUEUE_SIZE` | Alias (legacy compat) | — |
|
||||||
|
| `DS2API_GLOBAL_MAX_INFLIGHT` | Global inflight limit | `recommended_concurrency` |
|
||||||
|
| `DS2API_MAX_INFLIGHT` | Alias (legacy compat) | — |
|
||||||
| `DS2API_VERCEL_INTERNAL_SECRET` | Hybrid streaming internal auth | Falls back to `DS2API_ADMIN_KEY` |
|
| `DS2API_VERCEL_INTERNAL_SECRET` | Hybrid streaming internal auth | Falls back to `DS2API_ADMIN_KEY` |
|
||||||
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | Stream lease TTL | `900` |
|
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | Stream lease TTL | `900` |
|
||||||
| `VERCEL_TOKEN` | Vercel sync token | — |
|
| `VERCEL_TOKEN` | Vercel sync token | — |
|
||||||
@@ -310,8 +361,8 @@ Each archive includes:
|
|||||||
```bash
|
```bash
|
||||||
# 1. Download the archive for your platform
|
# 1. Download the archive for your platform
|
||||||
# 2. Extract
|
# 2. Extract
|
||||||
tar -xzf ds2api_v1.7.0_linux_amd64.tar.gz
|
tar -xzf ds2api_<tag>_linux_amd64.tar.gz
|
||||||
cd ds2api_v1.7.0_linux_amd64
|
cd ds2api_<tag>_linux_amd64
|
||||||
|
|
||||||
# 3. Configure
|
# 3. Configure
|
||||||
cp config.example.json config.json
|
cp config.example.json config.json
|
||||||
@@ -323,7 +374,7 @@ cp config.example.json config.json
|
|||||||
|
|
||||||
### Maintainer Release Flow
|
### Maintainer Release Flow
|
||||||
|
|
||||||
1. Create and publish a GitHub Release (with tag, e.g. `v1.7.0`)
|
1. Create and publish a GitHub Release (with tag, for example `vX.Y.Z`)
|
||||||
2. Wait for the `Release Artifacts` workflow to complete
|
2. Wait for the `Release Artifacts` workflow to complete
|
||||||
3. Download the matching archive from Release Assets
|
3. Download the matching archive from Release Assets
|
||||||
|
|
||||||
@@ -469,7 +520,7 @@ curl http://127.0.0.1:5001/v1/chat/completions \
|
|||||||
Run the full live testsuite before release (real account tests):
|
Run the full live testsuite before release (real account tests):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./scripts/testsuite/run-live.sh
|
./tests/scripts/run-live.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
With custom flags:
|
With custom flags:
|
||||||
|
|||||||
69
DEPLOY.md
69
DEPLOY.md
@@ -33,6 +33,17 @@
|
|||||||
- **文件方式**:`config.json`(推荐本地/Docker 使用)
|
- **文件方式**:`config.json`(推荐本地/Docker 使用)
|
||||||
- **环境变量方式**:`DS2API_CONFIG_JSON`(推荐 Vercel 使用,支持 JSON 字符串或 Base64 编码)
|
- **环境变量方式**:`DS2API_CONFIG_JSON`(推荐 Vercel 使用,支持 JSON 字符串或 Base64 编码)
|
||||||
|
|
||||||
|
统一建议(最优实践):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp config.example.json config.json
|
||||||
|
# 编辑 config.json
|
||||||
|
```
|
||||||
|
|
||||||
|
建议把 `config.json` 作为唯一配置源:
|
||||||
|
- 本地运行:直接读 `config.json`
|
||||||
|
- Docker / Vercel:从 `config.json` 生成 `DS2API_CONFIG_JSON`(Base64)注入环境变量
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 一、本地运行
|
## 一、本地运行
|
||||||
@@ -99,11 +110,15 @@ go build -o ds2api ./cmd/ds2api
|
|||||||
### 2.1 基本步骤
|
### 2.1 基本步骤
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 复制并编辑环境变量
|
# 复制环境变量模板
|
||||||
cp .env.example .env
|
cp .env.example .env
|
||||||
# 编辑 .env,至少设置:
|
|
||||||
|
# 从 config.json 生成单行 Base64
|
||||||
|
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||||
|
|
||||||
|
# 编辑 .env(请改成你的强密码),设置:
|
||||||
# DS2API_ADMIN_KEY=your-admin-key
|
# DS2API_ADMIN_KEY=your-admin-key
|
||||||
# DS2API_CONFIG_JSON={"keys":[...],"accounts":[...]}
|
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
|
||||||
|
|
||||||
# 启动
|
# 启动
|
||||||
docker-compose up -d
|
docker-compose up -d
|
||||||
@@ -167,15 +182,49 @@ healthcheck:
|
|||||||
|
|
||||||
1. **Fork 仓库**到你的 GitHub 账号
|
1. **Fork 仓库**到你的 GitHub 账号
|
||||||
2. **在 Vercel 上导入项目**
|
2. **在 Vercel 上导入项目**
|
||||||
3. **配置环境变量**(至少设置以下两项):
|
3. **配置环境变量**(最少只需设置以下一项):
|
||||||
|
|
||||||
| 变量 | 说明 |
|
| 变量 | 说明 |
|
||||||
| --- | --- |
|
| --- | --- |
|
||||||
| `DS2API_ADMIN_KEY` | 管理密钥(必填) |
|
| `DS2API_ADMIN_KEY` | 管理密钥(必填) |
|
||||||
| `DS2API_CONFIG_JSON` | 配置内容,JSON 字符串或 Base64 编码(必填) |
|
| `DS2API_CONFIG_JSON` | 配置内容,JSON 字符串或 Base64 编码(可选,建议) |
|
||||||
|
|
||||||
4. **部署**
|
4. **部署**
|
||||||
|
|
||||||
|
### 3.1.1 推荐填写方式(避免 `DS2API_CONFIG_JSON` 填错)
|
||||||
|
|
||||||
|
如果你想先完成一键部署,也可以先不填 `DS2API_CONFIG_JSON`,部署后进入 `/admin` 导入配置,再在「Vercel 同步」里写回环境变量。
|
||||||
|
|
||||||
|
建议先在仓库目录复制示例配置,再按实际账号填写:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp config.example.json config.json
|
||||||
|
# 编辑 config.json
|
||||||
|
```
|
||||||
|
|
||||||
|
不要在 Vercel 面板里手写复杂 JSON,建议本地生成 Base64 后粘贴:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 在仓库根目录执行
|
||||||
|
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||||
|
echo "$DS2API_CONFIG_JSON"
|
||||||
|
```
|
||||||
|
|
||||||
|
如果你选择在部署前就预置配置,请在 Vercel Project Settings -> Environment Variables 配置:
|
||||||
|
|
||||||
|
```text
|
||||||
|
DS2API_ADMIN_KEY=请替换为强密码
|
||||||
|
DS2API_CONFIG_JSON=上一步生成的一整行 Base64
|
||||||
|
```
|
||||||
|
|
||||||
|
可选但推荐(用于 WebUI 一键同步 Vercel 配置):
|
||||||
|
|
||||||
|
```text
|
||||||
|
VERCEL_TOKEN=你的 Vercel Token
|
||||||
|
VERCEL_PROJECT_ID=prj_xxxxxxxxxxxx
|
||||||
|
VERCEL_TEAM_ID=team_xxxxxxxxxxxx # 个人账号可留空
|
||||||
|
```
|
||||||
|
|
||||||
### 3.2 可选环境变量
|
### 3.2 可选环境变量
|
||||||
|
|
||||||
| 变量 | 说明 | 默认值 |
|
| 变量 | 说明 | 默认值 |
|
||||||
@@ -184,6 +233,8 @@ healthcheck:
|
|||||||
| `DS2API_ACCOUNT_CONCURRENCY` | 同上(兼容别名) | — |
|
| `DS2API_ACCOUNT_CONCURRENCY` | 同上(兼容别名) | — |
|
||||||
| `DS2API_ACCOUNT_MAX_QUEUE` | 等待队列上限 | `recommended_concurrency` |
|
| `DS2API_ACCOUNT_MAX_QUEUE` | 等待队列上限 | `recommended_concurrency` |
|
||||||
| `DS2API_ACCOUNT_QUEUE_SIZE` | 同上(兼容别名) | — |
|
| `DS2API_ACCOUNT_QUEUE_SIZE` | 同上(兼容别名) | — |
|
||||||
|
| `DS2API_GLOBAL_MAX_INFLIGHT` | 全局并发上限 | `recommended_concurrency` |
|
||||||
|
| `DS2API_MAX_INFLIGHT` | 同上(兼容别名) | — |
|
||||||
| `DS2API_VERCEL_INTERNAL_SECRET` | 混合流式内部鉴权 | 回退用 `DS2API_ADMIN_KEY` |
|
| `DS2API_VERCEL_INTERNAL_SECRET` | 混合流式内部鉴权 | 回退用 `DS2API_ADMIN_KEY` |
|
||||||
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | 流式 lease TTL | `900` |
|
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | 流式 lease TTL | `900` |
|
||||||
| `VERCEL_TOKEN` | Vercel 同步 token | — |
|
| `VERCEL_TOKEN` | Vercel 同步 token | — |
|
||||||
@@ -310,8 +361,8 @@ No Output Directory named "public" found after the Build completed.
|
|||||||
```bash
|
```bash
|
||||||
# 1. 下载对应平台的压缩包
|
# 1. 下载对应平台的压缩包
|
||||||
# 2. 解压
|
# 2. 解压
|
||||||
tar -xzf ds2api_v1.7.0_linux_amd64.tar.gz
|
tar -xzf ds2api_<tag>_linux_amd64.tar.gz
|
||||||
cd ds2api_v1.7.0_linux_amd64
|
cd ds2api_<tag>_linux_amd64
|
||||||
|
|
||||||
# 3. 配置
|
# 3. 配置
|
||||||
cp config.example.json config.json
|
cp config.example.json config.json
|
||||||
@@ -323,7 +374,7 @@ cp config.example.json config.json
|
|||||||
|
|
||||||
### 维护者发布步骤
|
### 维护者发布步骤
|
||||||
|
|
||||||
1. 在 GitHub 创建并发布 Release(带 tag,如 `v1.7.0`)
|
1. 在 GitHub 创建并发布 Release(带 tag,如 `vX.Y.Z`)
|
||||||
2. 等待 Actions 工作流 `Release Artifacts` 完成
|
2. 等待 Actions 工作流 `Release Artifacts` 完成
|
||||||
3. 在 Release 的 Assets 下载对应平台压缩包
|
3. 在 Release 的 Assets 下载对应平台压缩包
|
||||||
|
|
||||||
@@ -469,7 +520,7 @@ curl http://127.0.0.1:5001/v1/chat/completions \
|
|||||||
建议在发布前执行完整的端到端测试集(使用真实账号):
|
建议在发布前执行完整的端到端测试集(使用真实账号):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./scripts/testsuite/run-live.sh
|
./tests/scripts/run-live.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
可自定义参数:
|
可自定义参数:
|
||||||
|
|||||||
195
README.MD
195
README.MD
@@ -3,18 +3,18 @@
|
|||||||
[](LICENSE)
|
[](LICENSE)
|
||||||

|

|
||||||

|

|
||||||
[](version.txt)
|
[](https://github.com/CJackHwang/ds2api/releases)
|
||||||
[](DEPLOY.md)
|
[](DEPLOY.md)
|
||||||
|
|
||||||
语言 / Language: [中文](README.MD) | [English](README.en.md)
|
语言 / Language: [中文](README.MD) | [English](README.en.md)
|
||||||
|
|
||||||
将 DeepSeek Web 对话能力转换为 OpenAI 与 Claude 兼容 API。后端为 **Go 全量实现**,前端为 React WebUI 管理台(源码在 `webui/`,部署时自动构建到 `static/admin`)。
|
将 DeepSeek Web 对话能力转换为 OpenAI、Claude 与 Gemini 兼容 API。后端为 **Go 全量实现**,前端为 React WebUI 管理台(源码在 `webui/`,部署时自动构建到 `static/admin`)。
|
||||||
|
|
||||||
## 架构概览
|
## 架构概览
|
||||||
|
|
||||||
```mermaid
|
```mermaid
|
||||||
flowchart LR
|
flowchart LR
|
||||||
Client["🖥️ 客户端\n(OpenAI / Claude 兼容)"]
|
Client["🖥️ 客户端\n(OpenAI / Claude / Gemini 兼容)"]
|
||||||
|
|
||||||
subgraph DS2API["DS2API 服务"]
|
subgraph DS2API["DS2API 服务"]
|
||||||
direction TB
|
direction TB
|
||||||
@@ -24,6 +24,7 @@ flowchart LR
|
|||||||
subgraph Adapters["适配器层"]
|
subgraph Adapters["适配器层"]
|
||||||
OA["OpenAI 适配器\n/v1/*"]
|
OA["OpenAI 适配器\n/v1/*"]
|
||||||
CA["Claude 适配器\n/anthropic/*"]
|
CA["Claude 适配器\n/anthropic/*"]
|
||||||
|
GA["Gemini 适配器\n/v1beta/models/*"]
|
||||||
end
|
end
|
||||||
|
|
||||||
subgraph Support["支撑模块"]
|
subgraph Support["支撑模块"]
|
||||||
@@ -38,11 +39,11 @@ flowchart LR
|
|||||||
DS["☁️ DeepSeek API"]
|
DS["☁️ DeepSeek API"]
|
||||||
|
|
||||||
Client -- "请求" --> CORS --> Auth
|
Client -- "请求" --> CORS --> Auth
|
||||||
Auth --> OA & CA
|
Auth --> OA & CA & GA
|
||||||
OA & CA -- "调用" --> DS
|
OA & CA & GA -- "调用" --> DS
|
||||||
Auth --> Admin
|
Auth --> Admin
|
||||||
OA & CA -. "轮询选账号" .-> Pool
|
OA & CA & GA -. "轮询选账号" .-> Pool
|
||||||
OA & CA -. "计算 PoW" .-> PoW
|
OA & CA & GA -. "计算 PoW" .-> PoW
|
||||||
DS -- "响应" --> Client
|
DS -- "响应" --> Client
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -54,16 +55,29 @@ flowchart LR
|
|||||||
|
|
||||||
| 能力 | 说明 |
|
| 能力 | 说明 |
|
||||||
| --- | --- |
|
| --- | --- |
|
||||||
| OpenAI 兼容 | `GET /v1/models`、`POST /v1/chat/completions`(流式/非流式) |
|
| OpenAI 兼容 | `GET /v1/models`、`GET /v1/models/{id}`、`POST /v1/chat/completions`、`POST /v1/responses`、`GET /v1/responses/{response_id}`、`POST /v1/embeddings` |
|
||||||
| Claude 兼容 | `GET /anthropic/v1/models`、`POST /anthropic/v1/messages`、`POST /anthropic/v1/messages/count_tokens` |
|
| Claude 兼容 | `GET /anthropic/v1/models`、`POST /anthropic/v1/messages`、`POST /anthropic/v1/messages/count_tokens`(及快捷路径 `/v1/messages`、`/messages`) |
|
||||||
|
| Gemini 兼容 | `POST /v1beta/models/{model}:generateContent`、`POST /v1beta/models/{model}:streamGenerateContent`(及 `/v1/models/{model}:*` 路径) |
|
||||||
| 多账号轮询 | 自动 token 刷新、邮箱/手机号双登录方式 |
|
| 多账号轮询 | 自动 token 刷新、邮箱/手机号双登录方式 |
|
||||||
| 并发队列控制 | 每账号 in-flight 上限 + 等待队列,动态计算建议并发值 |
|
| 并发队列控制 | 每账号 in-flight 上限 + 等待队列,动态计算建议并发值 |
|
||||||
| DeepSeek PoW | WASM 计算(`wazero`),无需外部 Node.js 依赖 |
|
| DeepSeek PoW | WASM 计算(`wazero`),无需外部 Node.js 依赖 |
|
||||||
| Tool Calling | 防泄漏处理:自动缓冲、识别、结构化输出 |
|
| Tool Calling | 防泄漏处理:非代码块高置信特征识别、`delta.tool_calls` 早发、结构化增量输出 |
|
||||||
| Admin API | 配置管理、账号测试 / 批量测试、导入导出、Vercel 同步 |
|
| Admin API | 配置管理、运行时设置热更新、账号测试 / 批量测试、导入导出、Vercel 同步 |
|
||||||
| WebUI 管理台 | `/admin` 单页应用(中英文双语、深色模式) |
|
| WebUI 管理台 | `/admin` 单页应用(中英文双语、深色模式) |
|
||||||
| 运维探针 | `GET /healthz`(存活)、`GET /readyz`(就绪) |
|
| 运维探针 | `GET /healthz`(存活)、`GET /readyz`(就绪) |
|
||||||
|
|
||||||
|
## 平台兼容矩阵
|
||||||
|
|
||||||
|
| 级别 | 平台 | 当前状态 |
|
||||||
|
| --- | --- | --- |
|
||||||
|
| P0 | Codex CLI/SDK(`wire_api=chat` / `wire_api=responses`) | ✅ |
|
||||||
|
| P0 | OpenAI SDK(JS/Python,chat + responses) | ✅ |
|
||||||
|
| P0 | Vercel AI SDK(openai-compatible) | ✅ |
|
||||||
|
| P0 | Anthropic SDK(messages) | ✅ |
|
||||||
|
| P0 | Google Gemini SDK(generateContent) | ✅ |
|
||||||
|
| P1 | LangChain / LlamaIndex / OpenWebUI(OpenAI 兼容接入) | ✅ |
|
||||||
|
| P2 | MCP 独立桥接层 | 规划中 |
|
||||||
|
|
||||||
## 模型支持
|
## 模型支持
|
||||||
|
|
||||||
### OpenAI 接口
|
### OpenAI 接口
|
||||||
@@ -86,8 +100,25 @@ flowchart LR
|
|||||||
可通过配置中的 `claude_mapping` 或 `claude_model_mapping` 覆盖映射关系。
|
可通过配置中的 `claude_mapping` 或 `claude_model_mapping` 覆盖映射关系。
|
||||||
另外,`/anthropic/v1/models` 现已包含 Claude 1.x/2.x/3.x/4.x 历史模型 ID 与常见别名,便于旧客户端直接兼容。
|
另外,`/anthropic/v1/models` 现已包含 Claude 1.x/2.x/3.x/4.x 历史模型 ID 与常见别名,便于旧客户端直接兼容。
|
||||||
|
|
||||||
|
### Gemini 接口
|
||||||
|
|
||||||
|
Gemini 适配器将模型名通过 `model_aliases` 或内置规则映射到 DeepSeek 原生模型,支持 `generateContent` 和 `streamGenerateContent` 两种调用方式,并完整支持 Tool Calling(`functionDeclarations` → `functionCall` 输出)。
|
||||||
|
|
||||||
## 快速开始
|
## 快速开始
|
||||||
|
|
||||||
|
### 通用第一步(所有部署方式)
|
||||||
|
|
||||||
|
把 `config.json` 作为唯一配置源(推荐做法):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp config.example.json config.json
|
||||||
|
# 编辑 config.json
|
||||||
|
```
|
||||||
|
|
||||||
|
后续部署建议:
|
||||||
|
- 本地运行:直接读取 `config.json`
|
||||||
|
- Docker / Vercel:由 `config.json` 生成 `DS2API_CONFIG_JSON`(Base64)注入环境变量
|
||||||
|
|
||||||
### 方式一:本地运行
|
### 方式一:本地运行
|
||||||
|
|
||||||
**前置要求**:Go 1.24+,Node.js 20+(仅在需要构建 WebUI 时)
|
**前置要求**:Go 1.24+,Node.js 20+(仅在需要构建 WebUI 时)
|
||||||
@@ -112,14 +143,20 @@ go run ./cmd/ds2api
|
|||||||
### 方式二:Docker 运行
|
### 方式二:Docker 运行
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 1. 配置环境变量
|
# 1. 准备环境变量文件
|
||||||
cp .env.example .env
|
cp .env.example .env
|
||||||
# 编辑 .env
|
|
||||||
|
|
||||||
# 2. 启动
|
# 2. 从 config.json 生成 DS2API_CONFIG_JSON(单行 Base64)
|
||||||
|
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||||
|
|
||||||
|
# 3. 编辑 .env,设置:
|
||||||
|
# DS2API_ADMIN_KEY=请替换为强密码
|
||||||
|
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
|
||||||
|
|
||||||
|
# 4. 启动
|
||||||
docker-compose up -d
|
docker-compose up -d
|
||||||
|
|
||||||
# 3. 查看日志
|
# 5. 查看日志
|
||||||
docker-compose logs -f
|
docker-compose logs -f
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -129,9 +166,22 @@ docker-compose logs -f
|
|||||||
|
|
||||||
1. Fork 仓库到自己的 GitHub
|
1. Fork 仓库到自己的 GitHub
|
||||||
2. 在 Vercel 上导入项目
|
2. 在 Vercel 上导入项目
|
||||||
3. 配置环境变量(至少设置 `DS2API_ADMIN_KEY` 和 `DS2API_CONFIG_JSON`)
|
3. 配置环境变量(最少设置 `DS2API_ADMIN_KEY`;推荐同时设置 `DS2API_CONFIG_JSON`)
|
||||||
4. 部署
|
4. 部署
|
||||||
|
|
||||||
|
建议先在仓库目录复制模板并填写:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp config.example.json config.json
|
||||||
|
# 编辑 config.json
|
||||||
|
```
|
||||||
|
|
||||||
|
推荐:先本地把 `config.json` 转成 Base64,再粘贴到 `DS2API_CONFIG_JSON`,避免 JSON 格式错误:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
base64 < config.json | tr -d '\n'
|
||||||
|
```
|
||||||
|
|
||||||
> **流式说明**:`/v1/chat/completions` 在 Vercel 上默认走 `api/chat-stream.js`(Node Runtime)以保证实时 SSE。鉴权、账号选择、会话/PoW 准备仍由 Go 内部 prepare 接口完成;流式响应(含 `tools`)在 Node 侧执行与 Go 对齐的输出组装与防泄漏处理。
|
> **流式说明**:`/v1/chat/completions` 在 Vercel 上默认走 `api/chat-stream.js`(Node Runtime)以保证实时 SSE。鉴权、账号选择、会话/PoW 准备仍由 Go 内部 prepare 接口完成;流式响应(含 `tools`)在 Node 侧执行与 Go 对齐的输出组装与防泄漏处理。
|
||||||
|
|
||||||
详细部署说明请参阅 [部署指南](DEPLOY.md)。
|
详细部署说明请参阅 [部署指南](DEPLOY.md)。
|
||||||
@@ -142,8 +192,8 @@ docker-compose logs -f
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 下载对应平台的压缩包后
|
# 下载对应平台的压缩包后
|
||||||
tar -xzf ds2api_v1.7.0_linux_amd64.tar.gz
|
tar -xzf ds2api_<tag>_linux_amd64.tar.gz
|
||||||
cd ds2api_v1.7.0_linux_amd64
|
cd ds2api_<tag>_linux_amd64
|
||||||
cp config.example.json config.json
|
cp config.example.json config.json
|
||||||
# 编辑 config.json
|
# 编辑 config.json
|
||||||
./ds2api
|
./ds2api
|
||||||
@@ -164,6 +214,7 @@ cp opencode.json.example opencode.json
|
|||||||
3. 在项目目录启动 OpenCode CLI(按你的安装方式运行 `opencode`)。
|
3. 在项目目录启动 OpenCode CLI(按你的安装方式运行 `opencode`)。
|
||||||
|
|
||||||
> 建议优先使用 OpenAI 兼容路径(`/v1/*`),即示例里的 `@ai-sdk/openai-compatible` provider。
|
> 建议优先使用 OpenAI 兼容路径(`/v1/*`),即示例里的 `@ai-sdk/openai-compatible` provider。
|
||||||
|
> 若客户端支持 `wire_api`,可分别测试 `responses` 与 `chat`,DS2API 两条链路都兼容。
|
||||||
|
|
||||||
## 配置说明
|
## 配置说明
|
||||||
|
|
||||||
@@ -184,9 +235,35 @@ cp opencode.json.example opencode.json
|
|||||||
"token": ""
|
"token": ""
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
"model_aliases": {
|
||||||
|
"gpt-4o": "deepseek-chat",
|
||||||
|
"gpt-5-codex": "deepseek-reasoner",
|
||||||
|
"o3": "deepseek-reasoner"
|
||||||
|
},
|
||||||
|
"compat": {
|
||||||
|
"wide_input_strict_output": true
|
||||||
|
},
|
||||||
|
"toolcall": {
|
||||||
|
"mode": "feature_match",
|
||||||
|
"early_emit_confidence": "high"
|
||||||
|
},
|
||||||
|
"responses": {
|
||||||
|
"store_ttl_seconds": 900
|
||||||
|
},
|
||||||
|
"embeddings": {
|
||||||
|
"provider": "deterministic"
|
||||||
|
},
|
||||||
"claude_model_mapping": {
|
"claude_model_mapping": {
|
||||||
"fast": "deepseek-chat",
|
"fast": "deepseek-chat",
|
||||||
"slow": "deepseek-reasoner"
|
"slow": "deepseek-reasoner"
|
||||||
|
},
|
||||||
|
"admin": {
|
||||||
|
"jwt_expire_hours": 24
|
||||||
|
},
|
||||||
|
"runtime": {
|
||||||
|
"account_max_inflight": 2,
|
||||||
|
"account_max_queue": 0,
|
||||||
|
"global_max_inflight": 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
@@ -194,7 +271,14 @@ cp opencode.json.example opencode.json
|
|||||||
- `keys`:API 访问密钥列表,客户端通过 `Authorization: Bearer <key>` 鉴权
|
- `keys`:API 访问密钥列表,客户端通过 `Authorization: Bearer <key>` 鉴权
|
||||||
- `accounts`:DeepSeek 账号列表,支持 `email` 或 `mobile` 登录
|
- `accounts`:DeepSeek 账号列表,支持 `email` 或 `mobile` 登录
|
||||||
- `token`:留空则首次请求时自动登录获取;也可预填已有 token
|
- `token`:留空则首次请求时自动登录获取;也可预填已有 token
|
||||||
|
- `model_aliases`:常见模型名(如 GPT/Codex/Claude)到 DeepSeek 模型的映射
|
||||||
|
- `compat.wide_input_strict_output`:建议保持 `true`(当前实现默认宽进严出)
|
||||||
|
- `toolcall`:固定采用特征匹配 + 高置信早发策略
|
||||||
|
- `responses.store_ttl_seconds`:`/v1/responses/{id}` 的内存缓存 TTL
|
||||||
|
- `embeddings.provider`:embedding 提供方(当前内置 `deterministic/mock/builtin`)
|
||||||
- `claude_model_mapping`:字典中 `fast`/`slow` 后缀映射到对应 DeepSeek 模型
|
- `claude_model_mapping`:字典中 `fast`/`slow` 后缀映射到对应 DeepSeek 模型
|
||||||
|
- `admin`:管理后台设置(JWT 过期时间、密码哈希等),可通过 Admin Settings API 热更新
|
||||||
|
- `runtime`:运行时参数(并发限制、队列大小),可通过 Admin Settings API 热更新
|
||||||
|
|
||||||
### 环境变量
|
### 环境变量
|
||||||
|
|
||||||
@@ -214,8 +298,13 @@ cp opencode.json.example opencode.json
|
|||||||
| `DS2API_ACCOUNT_CONCURRENCY` | 同上(兼容旧名) | — |
|
| `DS2API_ACCOUNT_CONCURRENCY` | 同上(兼容旧名) | — |
|
||||||
| `DS2API_ACCOUNT_MAX_QUEUE` | 等待队列上限 | `recommended_concurrency` |
|
| `DS2API_ACCOUNT_MAX_QUEUE` | 等待队列上限 | `recommended_concurrency` |
|
||||||
| `DS2API_ACCOUNT_QUEUE_SIZE` | 同上(兼容旧名) | — |
|
| `DS2API_ACCOUNT_QUEUE_SIZE` | 同上(兼容旧名) | — |
|
||||||
|
| `DS2API_GLOBAL_MAX_INFLIGHT` | 全局最大 in-flight 请求数 | `recommended_concurrency` |
|
||||||
|
| `DS2API_MAX_INFLIGHT` | 同上(兼容旧名) | — |
|
||||||
| `DS2API_VERCEL_INTERNAL_SECRET` | Vercel 混合流式内部鉴权密钥 | 回退用 `DS2API_ADMIN_KEY` |
|
| `DS2API_VERCEL_INTERNAL_SECRET` | Vercel 混合流式内部鉴权密钥 | 回退用 `DS2API_ADMIN_KEY` |
|
||||||
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | 流式 lease 过期秒数 | `900` |
|
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | 流式 lease 过期秒数 | `900` |
|
||||||
|
| `DS2API_DEV_PACKET_CAPTURE` | 本地开发抓包开关(记录最近会话请求/响应体) | 本地非 Vercel 默认开启 |
|
||||||
|
| `DS2API_DEV_PACKET_CAPTURE_LIMIT` | 本地抓包保留条数(超出自动淘汰) | `5` |
|
||||||
|
| `DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES` | 单条响应体最大记录字节数 | `2097152` |
|
||||||
| `VERCEL_TOKEN` | Vercel 同步 token | — |
|
| `VERCEL_TOKEN` | Vercel 同步 token | — |
|
||||||
| `VERCEL_PROJECT_ID` | Vercel 项目 ID | — |
|
| `VERCEL_PROJECT_ID` | Vercel 项目 ID | — |
|
||||||
| `VERCEL_TEAM_ID` | Vercel 团队 ID | — |
|
| `VERCEL_TEAM_ID` | Vercel 团队 ID | — |
|
||||||
@@ -223,7 +312,7 @@ cp opencode.json.example opencode.json
|
|||||||
|
|
||||||
## 鉴权模式
|
## 鉴权模式
|
||||||
|
|
||||||
调用业务接口(`/v1/*`、`/anthropic/*`)时支持两种模式:
|
调用业务接口(`/v1/*`、`/anthropic/*`、Gemini 路由)时支持两种模式:
|
||||||
|
|
||||||
| 模式 | 说明 |
|
| 模式 | 说明 |
|
||||||
| --- | --- |
|
| --- | --- |
|
||||||
@@ -249,10 +338,34 @@ cp opencode.json.example opencode.json
|
|||||||
|
|
||||||
当请求中带 `tools` 时,DS2API 会做防泄漏处理:
|
当请求中带 `tools` 时,DS2API 会做防泄漏处理:
|
||||||
|
|
||||||
1. `stream=true` 时先**缓冲**正文片段
|
1. 只在**非代码块上下文**启用 toolcall 特征识别(代码块示例不会触发)
|
||||||
2. 若识别到工具调用 → 仅输出结构化 `tool_calls`,不透传原始 JSON 文本
|
2. `responses` 流式严格使用官方 item 生命周期事件(`response.output_item.*`、`response.content_part.*`、`response.function_call_arguments.*`)
|
||||||
3. 若最终不是工具调用 → 一次性输出普通文本
|
3. 未在 `tools` 声明中的工具名会被严格拒绝,不会下发为有效 tool call
|
||||||
4. 解析器支持混合文本、fenced JSON、`function.arguments` 字符串等格式
|
4. `responses` 支持并执行 `tool_choice`(`auto`/`none`/`required`/强制函数);`required` 违规时非流式返回 `422`,流式返回 `response.failed`
|
||||||
|
5. 仅在通过策略校验后才会发出有效工具调用事件,避免错误工具名进入客户端执行链
|
||||||
|
|
||||||
|
## 本地开发抓包工具
|
||||||
|
|
||||||
|
用于定位「responses 思考流/工具调用」等问题。开启后会自动记录最近 N 条 DeepSeek 对话上游请求体与响应体(默认 5 条,超出自动淘汰)。
|
||||||
|
|
||||||
|
启用示例:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
DS2API_DEV_PACKET_CAPTURE=true \
|
||||||
|
DS2API_DEV_PACKET_CAPTURE_LIMIT=5 \
|
||||||
|
go run ./cmd/ds2api
|
||||||
|
```
|
||||||
|
|
||||||
|
查询/清空(需 Admin JWT):
|
||||||
|
|
||||||
|
- `GET /admin/dev/captures`:查看抓包列表(最新在前)
|
||||||
|
- `DELETE /admin/dev/captures`:清空抓包
|
||||||
|
|
||||||
|
返回字段包含:
|
||||||
|
|
||||||
|
- `request_body`:发送给 DeepSeek 的完整请求体
|
||||||
|
- `response_body`:上游返回的原始流式内容拼接文本
|
||||||
|
- `response_truncated`:是否触发单条大小截断
|
||||||
|
|
||||||
## 项目结构
|
## 项目结构
|
||||||
|
|
||||||
@@ -269,13 +382,20 @@ ds2api/
|
|||||||
│ ├── account/ # 账号池与并发队列
|
│ ├── account/ # 账号池与并发队列
|
||||||
│ ├── adapter/
|
│ ├── adapter/
|
||||||
│ │ ├── openai/ # OpenAI 兼容适配器(含 Tool Call 解析、Vercel 流式 prepare/release)
|
│ │ ├── openai/ # OpenAI 兼容适配器(含 Tool Call 解析、Vercel 流式 prepare/release)
|
||||||
│ │ └── claude/ # Claude 兼容适配器
|
│ │ ├── claude/ # Claude 兼容适配器
|
||||||
│ ├── admin/ # Admin API handlers
|
│ │ └── gemini/ # Gemini 兼容适配器(generateContent / streamGenerateContent)
|
||||||
|
│ ├── admin/ # Admin API handlers(含 Settings 热更新)
|
||||||
│ ├── auth/ # 鉴权与 JWT
|
│ ├── auth/ # 鉴权与 JWT
|
||||||
|
│ ├── claudeconv/ # Claude 消息格式转换
|
||||||
|
│ ├── compat/ # 兼容性辅助
|
||||||
│ ├── config/ # 配置加载与热更新
|
│ ├── config/ # 配置加载与热更新
|
||||||
│ ├── deepseek/ # DeepSeek API 客户端、PoW WASM
|
│ ├── deepseek/ # DeepSeek API 客户端、PoW WASM
|
||||||
|
│ ├── devcapture/ # 开发抓包模块
|
||||||
|
│ ├── format/ # 输出格式化
|
||||||
|
│ ├── prompt/ # Prompt 构建
|
||||||
│ ├── server/ # HTTP 路由与中间件(chi router)
|
│ ├── server/ # HTTP 路由与中间件(chi router)
|
||||||
│ ├── sse/ # SSE 解析工具
|
│ ├── sse/ # SSE 解析工具
|
||||||
|
│ ├── stream/ # 统一流式消费引擎
|
||||||
│ ├── util/ # 通用工具函数
|
│ ├── util/ # 通用工具函数
|
||||||
│ └── webui/ # WebUI 静态文件托管与自动构建
|
│ └── webui/ # WebUI 静态文件托管与自动构建
|
||||||
├── webui/ # React WebUI 源码(Vite + Tailwind)
|
├── webui/ # React WebUI 源码(Vite + Tailwind)
|
||||||
@@ -283,11 +403,13 @@ ds2api/
|
|||||||
│ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage
|
│ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage
|
||||||
│ └── locales/ # 中英文语言包(zh.json / en.json)
|
│ └── locales/ # 中英文语言包(zh.json / en.json)
|
||||||
├── scripts/
|
├── scripts/
|
||||||
│ ├── build-webui.sh # WebUI 手动构建脚本
|
│ └── build-webui.sh # WebUI 手动构建脚本
|
||||||
│ └── testsuite/ # 测试集运行脚本
|
├── tests/
|
||||||
|
│ ├── compat/ # 兼容性测试夹具与期望输出
|
||||||
|
│ └── scripts/ # 统一测试脚本入口(unit/e2e)
|
||||||
├── static/admin/ # WebUI 构建产物(不提交到 Git)
|
├── static/admin/ # WebUI 构建产物(不提交到 Git)
|
||||||
├── .github/
|
├── .github/
|
||||||
│ ├── workflows/ # GitHub Actions(Release 自动构建)
|
│ ├── workflows/ # GitHub Actions(质量门禁 + Release 自动构建)
|
||||||
│ ├── ISSUE_TEMPLATE/ # Issue 模板
|
│ ├── ISSUE_TEMPLATE/ # Issue 模板
|
||||||
│ └── PULL_REQUEST_TEMPLATE.md
|
│ └── PULL_REQUEST_TEMPLATE.md
|
||||||
├── config.example.json # 配置文件示例
|
├── config.example.json # 配置文件示例
|
||||||
@@ -296,8 +418,7 @@ ds2api/
|
|||||||
├── docker-compose.yml # 生产环境 Docker Compose
|
├── docker-compose.yml # 生产环境 Docker Compose
|
||||||
├── docker-compose.dev.yml # 开发环境 Docker Compose
|
├── docker-compose.dev.yml # 开发环境 Docker Compose
|
||||||
├── vercel.json # Vercel 路由与构建配置
|
├── vercel.json # Vercel 路由与构建配置
|
||||||
├── go.mod / go.sum # Go 模块依赖
|
└── go.mod / go.sum # Go 模块依赖
|
||||||
└── version.txt # 版本号
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## 文档索引
|
## 文档索引
|
||||||
@@ -312,11 +433,11 @@ ds2api/
|
|||||||
## 测试
|
## 测试
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 单元测试
|
# 单元测试(Go + Node)
|
||||||
go test ./...
|
./tests/scripts/run-unit-all.sh
|
||||||
|
|
||||||
# 一键端到端全链路测试(真实账号,生成完整请求/响应日志)
|
# 一键端到端全链路测试(真实账号,生成完整请求/响应日志)
|
||||||
./scripts/testsuite/run-live.sh
|
./tests/scripts/run-live.sh
|
||||||
|
|
||||||
# 或自定义参数
|
# 或自定义参数
|
||||||
go run ./cmd/ds2api-tests \
|
go run ./cmd/ds2api-tests \
|
||||||
@@ -327,6 +448,14 @@ go run ./cmd/ds2api-tests \
|
|||||||
--retries 2
|
--retries 2
|
||||||
```
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 发布前阻断门禁
|
||||||
|
./tests/scripts/check-stage6-manual-smoke.sh
|
||||||
|
./tests/scripts/check-refactor-line-gate.sh
|
||||||
|
./tests/scripts/run-unit-all.sh
|
||||||
|
npm ci --prefix webui && npm run build --prefix webui
|
||||||
|
```
|
||||||
|
|
||||||
## Release 自动构建(GitHub Actions)
|
## Release 自动构建(GitHub Actions)
|
||||||
|
|
||||||
工作流文件:`.github/workflows/release-artifacts.yml`
|
工作流文件:`.github/workflows/release-artifacts.yml`
|
||||||
|
|||||||
195
README.en.md
195
README.en.md
@@ -3,18 +3,18 @@
|
|||||||
[](LICENSE)
|
[](LICENSE)
|
||||||

|

|
||||||

|

|
||||||
[](version.txt)
|
[](https://github.com/CJackHwang/ds2api/releases)
|
||||||
[](DEPLOY.en.md)
|
[](DEPLOY.en.md)
|
||||||
|
|
||||||
Language: [中文](README.MD) | [English](README.en.md)
|
Language: [中文](README.MD) | [English](README.en.md)
|
||||||
|
|
||||||
DS2API converts DeepSeek Web chat capability into OpenAI-compatible and Claude-compatible APIs. The backend is a **pure Go implementation**, with a React WebUI admin panel (source in `webui/`, build output auto-generated to `static/admin` during deployment).
|
DS2API converts DeepSeek Web chat capability into OpenAI-compatible, Claude-compatible, and Gemini-compatible APIs. The backend is a **pure Go implementation**, with a React WebUI admin panel (source in `webui/`, build output auto-generated to `static/admin` during deployment).
|
||||||
|
|
||||||
## Architecture Overview
|
## Architecture Overview
|
||||||
|
|
||||||
```mermaid
|
```mermaid
|
||||||
flowchart LR
|
flowchart LR
|
||||||
Client["🖥️ Clients\n(OpenAI / Claude compat)"]
|
Client["🖥️ Clients\n(OpenAI / Claude / Gemini compat)"]
|
||||||
|
|
||||||
subgraph DS2API["DS2API Service"]
|
subgraph DS2API["DS2API Service"]
|
||||||
direction TB
|
direction TB
|
||||||
@@ -24,6 +24,7 @@ flowchart LR
|
|||||||
subgraph Adapters["Adapter Layer"]
|
subgraph Adapters["Adapter Layer"]
|
||||||
OA["OpenAI Adapter\n/v1/*"]
|
OA["OpenAI Adapter\n/v1/*"]
|
||||||
CA["Claude Adapter\n/anthropic/*"]
|
CA["Claude Adapter\n/anthropic/*"]
|
||||||
|
GA["Gemini Adapter\n/v1beta/models/*"]
|
||||||
end
|
end
|
||||||
|
|
||||||
subgraph Support["Support Modules"]
|
subgraph Support["Support Modules"]
|
||||||
@@ -38,11 +39,11 @@ flowchart LR
|
|||||||
DS["☁️ DeepSeek API"]
|
DS["☁️ DeepSeek API"]
|
||||||
|
|
||||||
Client -- "Request" --> CORS --> Auth
|
Client -- "Request" --> CORS --> Auth
|
||||||
Auth --> OA & CA
|
Auth --> OA & CA & GA
|
||||||
OA & CA -- "Call" --> DS
|
OA & CA & GA -- "Call" --> DS
|
||||||
Auth --> Admin
|
Auth --> Admin
|
||||||
OA & CA -. "Rotate accounts" .-> Pool
|
OA & CA & GA -. "Rotate accounts" .-> Pool
|
||||||
OA & CA -. "Compute PoW" .-> PoW
|
OA & CA & GA -. "Compute PoW" .-> PoW
|
||||||
DS -- "Response" --> Client
|
DS -- "Response" --> Client
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -54,16 +55,29 @@ flowchart LR
|
|||||||
|
|
||||||
| Capability | Details |
|
| Capability | Details |
|
||||||
| --- | --- |
|
| --- | --- |
|
||||||
| OpenAI compatible | `GET /v1/models`, `POST /v1/chat/completions` (stream/non-stream) |
|
| OpenAI compatible | `GET /v1/models`, `GET /v1/models/{id}`, `POST /v1/chat/completions`, `POST /v1/responses`, `GET /v1/responses/{response_id}`, `POST /v1/embeddings` |
|
||||||
| Claude compatible | `GET /anthropic/v1/models`, `POST /anthropic/v1/messages`, `POST /anthropic/v1/messages/count_tokens` |
|
| Claude compatible | `GET /anthropic/v1/models`, `POST /anthropic/v1/messages`, `POST /anthropic/v1/messages/count_tokens` (plus shortcut paths `/v1/messages`, `/messages`) |
|
||||||
|
| Gemini compatible | `POST /v1beta/models/{model}:generateContent`, `POST /v1beta/models/{model}:streamGenerateContent` (plus `/v1/models/{model}:*` paths) |
|
||||||
| Multi-account rotation | Auto token refresh, email/mobile dual login |
|
| Multi-account rotation | Auto token refresh, email/mobile dual login |
|
||||||
| Concurrency control | Per-account in-flight limit + waiting queue, dynamic recommended concurrency |
|
| Concurrency control | Per-account in-flight limit + waiting queue, dynamic recommended concurrency |
|
||||||
| DeepSeek PoW | WASM solving via `wazero`, no external Node.js dependency |
|
| DeepSeek PoW | WASM solving via `wazero`, no external Node.js dependency |
|
||||||
| Tool Calling | Anti-leak handling: auto buffer, detect, structured output |
|
| Tool Calling | Anti-leak handling: non-code-block feature match, early `delta.tool_calls`, structured incremental output |
|
||||||
| Admin API | Config management, account testing/batch test, import/export, Vercel sync |
|
| Admin API | Config management, runtime settings hot-reload, account testing/batch test, import/export, Vercel sync |
|
||||||
| WebUI Admin Panel | SPA at `/admin` (bilingual Chinese/English, dark mode) |
|
| WebUI Admin Panel | SPA at `/admin` (bilingual Chinese/English, dark mode) |
|
||||||
| Health Probes | `GET /healthz` (liveness), `GET /readyz` (readiness) |
|
| Health Probes | `GET /healthz` (liveness), `GET /readyz` (readiness) |
|
||||||
|
|
||||||
|
## Platform Compatibility Matrix
|
||||||
|
|
||||||
|
| Tier | Platform | Status |
|
||||||
|
| --- | --- | --- |
|
||||||
|
| P0 | Codex CLI/SDK (`wire_api=chat` / `wire_api=responses`) | ✅ |
|
||||||
|
| P0 | OpenAI SDK (JS/Python, chat + responses) | ✅ |
|
||||||
|
| P0 | Vercel AI SDK (openai-compatible) | ✅ |
|
||||||
|
| P0 | Anthropic SDK (messages) | ✅ |
|
||||||
|
| P0 | Google Gemini SDK (generateContent) | ✅ |
|
||||||
|
| P1 | LangChain / LlamaIndex / OpenWebUI (OpenAI-compatible integration) | ✅ |
|
||||||
|
| P2 | MCP standalone bridge | Planned |
|
||||||
|
|
||||||
## Model Support
|
## Model Support
|
||||||
|
|
||||||
### OpenAI Endpoint
|
### OpenAI Endpoint
|
||||||
@@ -86,8 +100,25 @@ flowchart LR
|
|||||||
Override mapping via `claude_mapping` or `claude_model_mapping` in config.
|
Override mapping via `claude_mapping` or `claude_model_mapping` in config.
|
||||||
In addition, `/anthropic/v1/models` now includes historical Claude 1.x/2.x/3.x/4.x IDs and common aliases for legacy client compatibility.
|
In addition, `/anthropic/v1/models` now includes historical Claude 1.x/2.x/3.x/4.x IDs and common aliases for legacy client compatibility.
|
||||||
|
|
||||||
|
### Gemini Endpoint
|
||||||
|
|
||||||
|
The Gemini adapter maps model names to DeepSeek native models via `model_aliases` or built-in heuristics, supporting both `generateContent` and `streamGenerateContent` call patterns with full Tool Calling support (`functionDeclarations` → `functionCall` output).
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
|
### Universal First Step (all deployment modes)
|
||||||
|
|
||||||
|
Use `config.json` as the single source of truth (recommended):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp config.example.json config.json
|
||||||
|
# Edit config.json
|
||||||
|
```
|
||||||
|
|
||||||
|
Recommended per deployment mode:
|
||||||
|
- Local run: read `config.json` directly
|
||||||
|
- Docker / Vercel: generate Base64 from `config.json` and inject as `DS2API_CONFIG_JSON`
|
||||||
|
|
||||||
### Option 1: Local Run
|
### Option 1: Local Run
|
||||||
|
|
||||||
**Prerequisites**: Go 1.24+, Node.js 20+ (only if building WebUI locally)
|
**Prerequisites**: Go 1.24+, Node.js 20+ (only if building WebUI locally)
|
||||||
@@ -112,14 +143,20 @@ Default URL: `http://localhost:5001`
|
|||||||
### Option 2: Docker
|
### Option 2: Docker
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 1. Configure environment
|
# 1. Prepare env file
|
||||||
cp .env.example .env
|
cp .env.example .env
|
||||||
# Edit .env
|
|
||||||
|
|
||||||
# 2. Start
|
# 2. Generate DS2API_CONFIG_JSON from config.json (single-line Base64)
|
||||||
|
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||||
|
|
||||||
|
# 3. Edit .env and set:
|
||||||
|
# DS2API_ADMIN_KEY=replace-with-a-strong-secret
|
||||||
|
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
|
||||||
|
|
||||||
|
# 4. Start
|
||||||
docker-compose up -d
|
docker-compose up -d
|
||||||
|
|
||||||
# 3. View logs
|
# 5. View logs
|
||||||
docker-compose logs -f
|
docker-compose logs -f
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -129,9 +166,22 @@ Rebuild after updates: `docker-compose up -d --build`
|
|||||||
|
|
||||||
1. Fork this repo to your GitHub account
|
1. Fork this repo to your GitHub account
|
||||||
2. Import the project on Vercel
|
2. Import the project on Vercel
|
||||||
3. Set environment variables (minimum: `DS2API_ADMIN_KEY` and `DS2API_CONFIG_JSON`)
|
3. Set environment variables (minimum: `DS2API_ADMIN_KEY`; recommended to also set `DS2API_CONFIG_JSON`)
|
||||||
4. Deploy
|
4. Deploy
|
||||||
|
|
||||||
|
Recommended first step in repo root:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp config.example.json config.json
|
||||||
|
# Edit config.json
|
||||||
|
```
|
||||||
|
|
||||||
|
Recommended: convert `config.json` to Base64 locally, then paste into `DS2API_CONFIG_JSON` to avoid JSON formatting mistakes:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
base64 < config.json | tr -d '\n'
|
||||||
|
```
|
||||||
|
|
||||||
> **Streaming note**: `/v1/chat/completions` on Vercel is routed to `api/chat-stream.js` (Node Runtime) for real-time SSE. Auth, account selection, and session/PoW preparation are still handled by the Go internal prepare endpoint; streaming output (including `tools`) is assembled on Node with Go-aligned anti-leak handling.
|
> **Streaming note**: `/v1/chat/completions` on Vercel is routed to `api/chat-stream.js` (Node Runtime) for real-time SSE. Auth, account selection, and session/PoW preparation are still handled by the Go internal prepare endpoint; streaming output (including `tools`) is assembled on Node with Go-aligned anti-leak handling.
|
||||||
|
|
||||||
For detailed deployment instructions, see the [Deployment Guide](DEPLOY.en.md).
|
For detailed deployment instructions, see the [Deployment Guide](DEPLOY.en.md).
|
||||||
@@ -142,8 +192,8 @@ GitHub Actions automatically builds multi-platform archives on each Release:
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# After downloading the archive for your platform
|
# After downloading the archive for your platform
|
||||||
tar -xzf ds2api_v1.7.0_linux_amd64.tar.gz
|
tar -xzf ds2api_<tag>_linux_amd64.tar.gz
|
||||||
cd ds2api_v1.7.0_linux_amd64
|
cd ds2api_<tag>_linux_amd64
|
||||||
cp config.example.json config.json
|
cp config.example.json config.json
|
||||||
# Edit config.json
|
# Edit config.json
|
||||||
./ds2api
|
./ds2api
|
||||||
@@ -164,6 +214,7 @@ cp opencode.json.example opencode.json
|
|||||||
3. Start OpenCode CLI in the project directory (run `opencode` using your installed method).
|
3. Start OpenCode CLI in the project directory (run `opencode` using your installed method).
|
||||||
|
|
||||||
> Recommended: use the OpenAI-compatible path (`/v1/*`) via `@ai-sdk/openai-compatible` as shown in the example.
|
> Recommended: use the OpenAI-compatible path (`/v1/*`) via `@ai-sdk/openai-compatible` as shown in the example.
|
||||||
|
> If your client supports `wire_api`, test both `responses` and `chat`; DS2API supports both paths.
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
@@ -184,9 +235,35 @@ cp opencode.json.example opencode.json
|
|||||||
"token": ""
|
"token": ""
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
"model_aliases": {
|
||||||
|
"gpt-4o": "deepseek-chat",
|
||||||
|
"gpt-5-codex": "deepseek-reasoner",
|
||||||
|
"o3": "deepseek-reasoner"
|
||||||
|
},
|
||||||
|
"compat": {
|
||||||
|
"wide_input_strict_output": true
|
||||||
|
},
|
||||||
|
"toolcall": {
|
||||||
|
"mode": "feature_match",
|
||||||
|
"early_emit_confidence": "high"
|
||||||
|
},
|
||||||
|
"responses": {
|
||||||
|
"store_ttl_seconds": 900
|
||||||
|
},
|
||||||
|
"embeddings": {
|
||||||
|
"provider": "deterministic"
|
||||||
|
},
|
||||||
"claude_model_mapping": {
|
"claude_model_mapping": {
|
||||||
"fast": "deepseek-chat",
|
"fast": "deepseek-chat",
|
||||||
"slow": "deepseek-reasoner"
|
"slow": "deepseek-reasoner"
|
||||||
|
},
|
||||||
|
"admin": {
|
||||||
|
"jwt_expire_hours": 24
|
||||||
|
},
|
||||||
|
"runtime": {
|
||||||
|
"account_max_inflight": 2,
|
||||||
|
"account_max_queue": 0,
|
||||||
|
"global_max_inflight": 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
@@ -194,7 +271,14 @@ cp opencode.json.example opencode.json
|
|||||||
- `keys`: API access keys; clients authenticate via `Authorization: Bearer <key>`
|
- `keys`: API access keys; clients authenticate via `Authorization: Bearer <key>`
|
||||||
- `accounts`: DeepSeek account list, supports `email` or `mobile` login
|
- `accounts`: DeepSeek account list, supports `email` or `mobile` login
|
||||||
- `token`: Leave empty for auto-login on first request; or pre-fill an existing token
|
- `token`: Leave empty for auto-login on first request; or pre-fill an existing token
|
||||||
|
- `model_aliases`: Map common model names (GPT/Codex/Claude) to DeepSeek models
|
||||||
|
- `compat.wide_input_strict_output`: Keep `true` (current default policy)
|
||||||
|
- `toolcall`: Fixed to feature matching + high-confidence early emit
|
||||||
|
- `responses.store_ttl_seconds`: In-memory TTL for `/v1/responses/{id}`
|
||||||
|
- `embeddings.provider`: Embeddings provider (`deterministic/mock/builtin` built-in)
|
||||||
- `claude_model_mapping`: Maps `fast`/`slow` suffixes to corresponding DeepSeek models
|
- `claude_model_mapping`: Maps `fast`/`slow` suffixes to corresponding DeepSeek models
|
||||||
|
- `admin`: Admin panel settings (JWT expiry, password hash, etc.), hot-reloadable via Admin Settings API
|
||||||
|
- `runtime`: Runtime parameters (concurrency limits, queue sizes), hot-reloadable via Admin Settings API
|
||||||
|
|
||||||
### Environment Variables
|
### Environment Variables
|
||||||
|
|
||||||
@@ -214,8 +298,13 @@ cp opencode.json.example opencode.json
|
|||||||
| `DS2API_ACCOUNT_CONCURRENCY` | Alias (legacy compat) | — |
|
| `DS2API_ACCOUNT_CONCURRENCY` | Alias (legacy compat) | — |
|
||||||
| `DS2API_ACCOUNT_MAX_QUEUE` | Waiting queue limit | `recommended_concurrency` |
|
| `DS2API_ACCOUNT_MAX_QUEUE` | Waiting queue limit | `recommended_concurrency` |
|
||||||
| `DS2API_ACCOUNT_QUEUE_SIZE` | Alias (legacy compat) | — |
|
| `DS2API_ACCOUNT_QUEUE_SIZE` | Alias (legacy compat) | — |
|
||||||
|
| `DS2API_GLOBAL_MAX_INFLIGHT` | Global max in-flight requests | `recommended_concurrency` |
|
||||||
|
| `DS2API_MAX_INFLIGHT` | Alias (legacy compat) | — |
|
||||||
| `DS2API_VERCEL_INTERNAL_SECRET` | Vercel hybrid streaming internal auth | Falls back to `DS2API_ADMIN_KEY` |
|
| `DS2API_VERCEL_INTERNAL_SECRET` | Vercel hybrid streaming internal auth | Falls back to `DS2API_ADMIN_KEY` |
|
||||||
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | Stream lease TTL seconds | `900` |
|
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | Stream lease TTL seconds | `900` |
|
||||||
|
| `DS2API_DEV_PACKET_CAPTURE` | Local dev packet capture switch (record recent request/response bodies) | Enabled by default on non-Vercel local runtime |
|
||||||
|
| `DS2API_DEV_PACKET_CAPTURE_LIMIT` | Number of captured sessions to retain (auto-evict overflow) | `5` |
|
||||||
|
| `DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES` | Max recorded bytes per captured response body | `2097152` |
|
||||||
| `VERCEL_TOKEN` | Vercel sync token | — |
|
| `VERCEL_TOKEN` | Vercel sync token | — |
|
||||||
| `VERCEL_PROJECT_ID` | Vercel project ID | — |
|
| `VERCEL_PROJECT_ID` | Vercel project ID | — |
|
||||||
| `VERCEL_TEAM_ID` | Vercel team ID | — |
|
| `VERCEL_TEAM_ID` | Vercel team ID | — |
|
||||||
@@ -223,7 +312,7 @@ cp opencode.json.example opencode.json
|
|||||||
|
|
||||||
## Authentication Modes
|
## Authentication Modes
|
||||||
|
|
||||||
For business endpoints (`/v1/*`, `/anthropic/*`), DS2API supports two modes:
|
For business endpoints (`/v1/*`, `/anthropic/*`, Gemini routes), DS2API supports two modes:
|
||||||
|
|
||||||
| Mode | Description |
|
| Mode | Description |
|
||||||
| --- | --- |
|
| --- | --- |
|
||||||
@@ -249,10 +338,34 @@ Queue limit = DS2API_ACCOUNT_MAX_QUEUE (default = recommended concurrency)
|
|||||||
|
|
||||||
When `tools` is present in the request, DS2API performs anti-leak handling:
|
When `tools` is present in the request, DS2API performs anti-leak handling:
|
||||||
|
|
||||||
1. With `stream=true`, DS2API **buffers** text deltas first
|
1. Toolcall feature matching is enabled only in **non-code-block context** (fenced examples are ignored)
|
||||||
2. If a tool call is detected → only structured `tool_calls` are emitted, raw JSON is not leaked
|
2. `responses` streaming strictly uses official item lifecycle events (`response.output_item.*`, `response.content_part.*`, `response.function_call_arguments.*`)
|
||||||
3. If no tool call → buffered text is emitted at once
|
3. Tool names not declared in the `tools` schema are strictly rejected and will not be emitted as valid tool calls
|
||||||
4. Parser supports mixed text, fenced JSON, and `function.arguments` payloads
|
4. `responses` supports and enforces `tool_choice` (`auto`/`none`/`required`/forced function); `required` violations return `422` for non-stream and `response.failed` for stream
|
||||||
|
5. Valid tool call events are only emitted after passing policy validation, preventing invalid tool names from entering the client execution chain
|
||||||
|
|
||||||
|
## Local Dev Packet Capture
|
||||||
|
|
||||||
|
This is for debugging issues such as Responses reasoning streaming and tool-call handoff. When enabled, DS2API stores the latest N DeepSeek conversation payload pairs (request body + upstream response body), defaulting to 5 entries with auto-eviction.
|
||||||
|
|
||||||
|
Enable example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
DS2API_DEV_PACKET_CAPTURE=true \
|
||||||
|
DS2API_DEV_PACKET_CAPTURE_LIMIT=5 \
|
||||||
|
go run ./cmd/ds2api
|
||||||
|
```
|
||||||
|
|
||||||
|
Inspect/clear (Admin JWT required):
|
||||||
|
|
||||||
|
- `GET /admin/dev/captures`: list captured items (newest first)
|
||||||
|
- `DELETE /admin/dev/captures`: clear captured items
|
||||||
|
|
||||||
|
Response fields include:
|
||||||
|
|
||||||
|
- `request_body`: full payload sent to DeepSeek
|
||||||
|
- `response_body`: concatenated raw upstream stream body text
|
||||||
|
- `response_truncated`: whether body-size truncation happened
|
||||||
|
|
||||||
## Project Structure
|
## Project Structure
|
||||||
|
|
||||||
@@ -269,13 +382,20 @@ ds2api/
|
|||||||
│ ├── account/ # Account pool and concurrency queue
|
│ ├── account/ # Account pool and concurrency queue
|
||||||
│ ├── adapter/
|
│ ├── adapter/
|
||||||
│ │ ├── openai/ # OpenAI adapter (incl. tool call parsing, Vercel stream prepare/release)
|
│ │ ├── openai/ # OpenAI adapter (incl. tool call parsing, Vercel stream prepare/release)
|
||||||
│ │ └── claude/ # Claude adapter
|
│ │ ├── claude/ # Claude adapter
|
||||||
│ ├── admin/ # Admin API handlers
|
│ │ └── gemini/ # Gemini adapter (generateContent / streamGenerateContent)
|
||||||
|
│ ├── admin/ # Admin API handlers (incl. Settings hot-reload)
|
||||||
│ ├── auth/ # Auth and JWT
|
│ ├── auth/ # Auth and JWT
|
||||||
|
│ ├── claudeconv/ # Claude message format conversion
|
||||||
|
│ ├── compat/ # Compatibility helpers
|
||||||
│ ├── config/ # Config loading and hot-reload
|
│ ├── config/ # Config loading and hot-reload
|
||||||
│ ├── deepseek/ # DeepSeek API client, PoW WASM
|
│ ├── deepseek/ # DeepSeek API client, PoW WASM
|
||||||
|
│ ├── devcapture/ # Dev packet capture module
|
||||||
|
│ ├── format/ # Output formatting
|
||||||
|
│ ├── prompt/ # Prompt construction
|
||||||
│ ├── server/ # HTTP routing and middleware (chi router)
|
│ ├── server/ # HTTP routing and middleware (chi router)
|
||||||
│ ├── sse/ # SSE parsing utilities
|
│ ├── sse/ # SSE parsing utilities
|
||||||
|
│ ├── stream/ # Unified stream consumption engine
|
||||||
│ ├── util/ # Common utilities
|
│ ├── util/ # Common utilities
|
||||||
│ └── webui/ # WebUI static file serving and auto-build
|
│ └── webui/ # WebUI static file serving and auto-build
|
||||||
├── webui/ # React WebUI source (Vite + Tailwind)
|
├── webui/ # React WebUI source (Vite + Tailwind)
|
||||||
@@ -283,11 +403,13 @@ ds2api/
|
|||||||
│ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage
|
│ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage
|
||||||
│ └── locales/ # Language packs (zh.json / en.json)
|
│ └── locales/ # Language packs (zh.json / en.json)
|
||||||
├── scripts/
|
├── scripts/
|
||||||
│ ├── build-webui.sh # Manual WebUI build script
|
│ └── build-webui.sh # Manual WebUI build script
|
||||||
│ └── testsuite/ # Testsuite runner scripts
|
├── tests/
|
||||||
|
│ ├── compat/ # Compatibility fixtures and expected outputs
|
||||||
|
│ └── scripts/ # Unified test script entrypoints (unit/e2e)
|
||||||
├── static/admin/ # WebUI build output (not committed to Git)
|
├── static/admin/ # WebUI build output (not committed to Git)
|
||||||
├── .github/
|
├── .github/
|
||||||
│ ├── workflows/ # GitHub Actions (Release artifact automation)
|
│ ├── workflows/ # GitHub Actions (quality gates + release automation)
|
||||||
│ ├── ISSUE_TEMPLATE/ # Issue templates
|
│ ├── ISSUE_TEMPLATE/ # Issue templates
|
||||||
│ └── PULL_REQUEST_TEMPLATE.md
|
│ └── PULL_REQUEST_TEMPLATE.md
|
||||||
├── config.example.json # Config file template
|
├── config.example.json # Config file template
|
||||||
@@ -296,8 +418,7 @@ ds2api/
|
|||||||
├── docker-compose.yml # Production Docker Compose
|
├── docker-compose.yml # Production Docker Compose
|
||||||
├── docker-compose.dev.yml # Development Docker Compose
|
├── docker-compose.dev.yml # Development Docker Compose
|
||||||
├── vercel.json # Vercel routing and build config
|
├── vercel.json # Vercel routing and build config
|
||||||
├── go.mod / go.sum # Go module dependencies
|
└── go.mod / go.sum # Go module dependencies
|
||||||
└── version.txt # Version number
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Documentation Index
|
## Documentation Index
|
||||||
@@ -312,11 +433,11 @@ ds2api/
|
|||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Unit tests
|
# Unit tests (Go + Node)
|
||||||
go test ./...
|
./tests/scripts/run-unit-all.sh
|
||||||
|
|
||||||
# One-command live end-to-end tests (real accounts, full request/response logs)
|
# One-command live end-to-end tests (real accounts, full request/response logs)
|
||||||
./scripts/testsuite/run-live.sh
|
./tests/scripts/run-live.sh
|
||||||
|
|
||||||
# Or with custom flags
|
# Or with custom flags
|
||||||
go run ./cmd/ds2api-tests \
|
go run ./cmd/ds2api-tests \
|
||||||
@@ -327,6 +448,14 @@ go run ./cmd/ds2api-tests \
|
|||||||
--retries 2
|
--retries 2
|
||||||
```
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Release-blocking gates
|
||||||
|
./tests/scripts/check-stage6-manual-smoke.sh
|
||||||
|
./tests/scripts/check-refactor-line-gate.sh
|
||||||
|
./tests/scripts/run-unit-all.sh
|
||||||
|
npm ci --prefix webui && npm run build --prefix webui
|
||||||
|
```
|
||||||
|
|
||||||
## Release Artifact Automation (GitHub Actions)
|
## Release Artifact Automation (GitHub Actions)
|
||||||
|
|
||||||
Workflow: `.github/workflows/release-artifacts.yml`
|
Workflow: `.github/workflows/release-artifacts.yml`
|
||||||
|
|||||||
30
TESTING.md
30
TESTING.md
@@ -8,8 +8,10 @@ DS2API 提供两个层级的测试:
|
|||||||
|
|
||||||
| 层级 | 命令 | 说明 |
|
| 层级 | 命令 | 说明 |
|
||||||
| --- | --- | --- |
|
| --- | --- | --- |
|
||||||
| 单元测试 | `go test ./...` | 不需要真实账号 |
|
| 单元测试(Go) | `./tests/scripts/run-unit-go.sh` | 不需要真实账号 |
|
||||||
| 端到端测试 | `./scripts/testsuite/run-live.sh` | 使用真实账号执行全链路测试 |
|
| 单元测试(Node) | `./tests/scripts/run-unit-node.sh` | 不需要真实账号 |
|
||||||
|
| 单元测试(全部) | `./tests/scripts/run-unit-all.sh` | 不需要真实账号 |
|
||||||
|
| 端到端测试 | `./tests/scripts/run-live.sh` | 使用真实账号执行全链路测试 |
|
||||||
|
|
||||||
端到端测试集会录制完整的请求/响应日志,用于故障排查。
|
端到端测试集会录制完整的请求/响应日志,用于故障排查。
|
||||||
|
|
||||||
@@ -20,26 +22,36 @@ DS2API 提供两个层级的测试:
|
|||||||
### 单元测试 | Unit Tests
|
### 单元测试 | Unit Tests
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
go test ./...
|
./tests/scripts/run-unit-all.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js
|
# 或按语言拆分执行
|
||||||
|
./tests/scripts/run-unit-go.sh
|
||||||
|
./tests/scripts/run-unit-node.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 结构与流程门禁
|
||||||
|
./tests/scripts/check-refactor-line-gate.sh
|
||||||
|
./tests/scripts/check-node-split-syntax.sh
|
||||||
|
|
||||||
|
# 发布阻断:阶段 6 手工烟测签字检查(默认读取 plans/stage6-manual-smoke.md)
|
||||||
|
./tests/scripts/check-stage6-manual-smoke.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
### 端到端测试 | End-to-End Tests
|
### 端到端测试 | End-to-End Tests
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./scripts/testsuite/run-live.sh
|
./tests/scripts/run-live.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
**默认行为**:
|
**默认行为**:
|
||||||
|
|
||||||
1. **Preflight 检查**:
|
1. **Preflight 检查**:
|
||||||
- `go test ./... -count=1`(单元测试)
|
- `go test ./... -count=1`(单元测试)
|
||||||
- `node --check api/chat-stream.js`(语法检查)
|
- `./tests/scripts/check-node-split-syntax.sh`(Node 拆分模块语法门禁)
|
||||||
- `node --check api/helpers/stream-tool-sieve.js`(语法检查)
|
- `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js`(Node 流式拦截 + compat 单测)
|
||||||
- `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js`(Node 流式拦截单测)
|
|
||||||
- `npm run build --prefix webui`(WebUI 构建检查)
|
- `npm run build --prefix webui`(WebUI 构建检查)
|
||||||
|
|
||||||
2. **隔离启动**:复制 `config.json` 到临时目录,启动独立服务进程
|
2. **隔离启动**:复制 `config.json` 到临时目录,启动独立服务进程
|
||||||
@@ -179,7 +191,7 @@ go run ./cmd/ds2api-tests \
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 确保 config.json 存在且包含有效测试账号
|
# 确保 config.json 存在且包含有效测试账号
|
||||||
./scripts/testsuite/run-live.sh
|
./tests/scripts/run-live.sh
|
||||||
exit_code=$?
|
exit_code=$?
|
||||||
if [ $exit_code -ne 0 ]; then
|
if [ $exit_code -ne 0 ]; then
|
||||||
echo "Tests failed! Check artifacts for details."
|
echo "Tests failed! Check artifacts for details."
|
||||||
|
|||||||
@@ -1,770 +1,3 @@
|
|||||||
'use strict';
|
'use strict';
|
||||||
|
|
||||||
const {
|
module.exports = require('../internal/js/chat-stream/index.js');
|
||||||
extractToolNames,
|
|
||||||
createToolSieveState,
|
|
||||||
processToolSieveChunk,
|
|
||||||
flushToolSieve,
|
|
||||||
parseToolCalls,
|
|
||||||
formatOpenAIStreamToolCalls,
|
|
||||||
} = require('./helpers/stream-tool-sieve');
|
|
||||||
|
|
||||||
const DEEPSEEK_COMPLETION_URL = 'https://chat.deepseek.com/api/v0/chat/completion';
|
|
||||||
|
|
||||||
const BASE_HEADERS = {
|
|
||||||
Host: 'chat.deepseek.com',
|
|
||||||
'User-Agent': 'DeepSeek/1.6.11 Android/35',
|
|
||||||
Accept: 'application/json',
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'x-client-platform': 'android',
|
|
||||||
'x-client-version': '1.6.11',
|
|
||||||
'x-client-locale': 'zh_CN',
|
|
||||||
'accept-charset': 'UTF-8',
|
|
||||||
};
|
|
||||||
|
|
||||||
const SKIP_PATTERNS = [
|
|
||||||
'quasi_status',
|
|
||||||
'elapsed_secs',
|
|
||||||
'token_usage',
|
|
||||||
'pending_fragment',
|
|
||||||
'conversation_mode',
|
|
||||||
'fragments/-1/status',
|
|
||||||
'fragments/-2/status',
|
|
||||||
'fragments/-3/status',
|
|
||||||
];
|
|
||||||
|
|
||||||
module.exports = async function handler(req, res) {
|
|
||||||
setCorsHeaders(res);
|
|
||||||
if (req.method === 'OPTIONS') {
|
|
||||||
res.statusCode = 204;
|
|
||||||
res.end();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (req.method !== 'POST') {
|
|
||||||
writeOpenAIError(res, 405, 'method not allowed');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const rawBody = await readRawBody(req);
|
|
||||||
|
|
||||||
// Hard guard: only use Node data path for streaming on Vercel runtime.
|
|
||||||
// Any non-Vercel runtime always falls back to Go for full behavior parity.
|
|
||||||
if (!isVercelRuntime()) {
|
|
||||||
await proxyToGo(req, res, rawBody);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let payload;
|
|
||||||
try {
|
|
||||||
payload = JSON.parse(rawBody.toString('utf8') || '{}');
|
|
||||||
} catch (_err) {
|
|
||||||
writeOpenAIError(res, 400, 'invalid json');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Keep all non-stream behavior on Go side to avoid compatibility regressions.
|
|
||||||
if (!toBool(payload.stream)) {
|
|
||||||
await proxyToGo(req, res, rawBody);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const prep = await fetchStreamPrepare(req, rawBody);
|
|
||||||
if (!prep.ok) {
|
|
||||||
relayPreparedFailure(res, prep);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const model = asString(prep.body.model) || asString(payload.model);
|
|
||||||
const sessionID = asString(prep.body.session_id) || `chatcmpl-${Date.now()}`;
|
|
||||||
const leaseID = asString(prep.body.lease_id);
|
|
||||||
const deepseekToken = asString(prep.body.deepseek_token);
|
|
||||||
const powHeader = asString(prep.body.pow_header);
|
|
||||||
const completionPayload = prep.body.payload && typeof prep.body.payload === 'object' ? prep.body.payload : null;
|
|
||||||
const finalPrompt = asString(prep.body.final_prompt);
|
|
||||||
const thinkingEnabled = toBool(prep.body.thinking_enabled);
|
|
||||||
const searchEnabled = toBool(prep.body.search_enabled);
|
|
||||||
const toolNames = extractToolNames(payload.tools);
|
|
||||||
|
|
||||||
if (!model || !leaseID || !deepseekToken || !powHeader || !completionPayload) {
|
|
||||||
writeOpenAIError(res, 500, 'invalid vercel prepare response');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const releaseLease = createLeaseReleaser(req, leaseID);
|
|
||||||
try {
|
|
||||||
const completionRes = await fetch(DEEPSEEK_COMPLETION_URL, {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
...BASE_HEADERS,
|
|
||||||
authorization: `Bearer ${deepseekToken}`,
|
|
||||||
'x-ds-pow-response': powHeader,
|
|
||||||
},
|
|
||||||
body: JSON.stringify(completionPayload),
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!completionRes.ok || !completionRes.body) {
|
|
||||||
const detail = await safeReadText(completionRes);
|
|
||||||
writeOpenAIError(res, 500, detail ? `Failed to get completion: ${detail}` : 'Failed to get completion.');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
res.statusCode = 200;
|
|
||||||
res.setHeader('Content-Type', 'text/event-stream');
|
|
||||||
res.setHeader('Cache-Control', 'no-cache, no-transform');
|
|
||||||
res.setHeader('Connection', 'keep-alive');
|
|
||||||
res.setHeader('X-Accel-Buffering', 'no');
|
|
||||||
if (typeof res.flushHeaders === 'function') {
|
|
||||||
res.flushHeaders();
|
|
||||||
}
|
|
||||||
|
|
||||||
const created = Math.floor(Date.now() / 1000);
|
|
||||||
let firstChunkSent = false;
|
|
||||||
let currentType = thinkingEnabled ? 'thinking' : 'text';
|
|
||||||
let thinkingText = '';
|
|
||||||
let outputText = '';
|
|
||||||
const toolSieveEnabled = toolNames.length > 0;
|
|
||||||
const toolSieveState = createToolSieveState();
|
|
||||||
let toolCallsEmitted = false;
|
|
||||||
const decoder = new TextDecoder();
|
|
||||||
const reader = completionRes.body.getReader();
|
|
||||||
let buffered = '';
|
|
||||||
let ended = false;
|
|
||||||
|
|
||||||
const sendFrame = (obj) => {
|
|
||||||
res.write(`data: ${JSON.stringify(obj)}\n\n`);
|
|
||||||
if (typeof res.flush === 'function') {
|
|
||||||
res.flush();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const sendDeltaFrame = (delta) => {
|
|
||||||
const payloadDelta = { ...delta };
|
|
||||||
if (!firstChunkSent) {
|
|
||||||
payloadDelta.role = 'assistant';
|
|
||||||
firstChunkSent = true;
|
|
||||||
}
|
|
||||||
sendFrame({
|
|
||||||
id: sessionID,
|
|
||||||
object: 'chat.completion.chunk',
|
|
||||||
created,
|
|
||||||
model,
|
|
||||||
choices: [{ delta: payloadDelta, index: 0 }],
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
const finish = async (reason) => {
|
|
||||||
if (ended) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
ended = true;
|
|
||||||
const detected = parseToolCalls(outputText, toolNames);
|
|
||||||
if (detected.length > 0 && !toolCallsEmitted) {
|
|
||||||
toolCallsEmitted = true;
|
|
||||||
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(detected) });
|
|
||||||
} else if (toolSieveEnabled) {
|
|
||||||
const tailEvents = flushToolSieve(toolSieveState, toolNames);
|
|
||||||
for (const evt of tailEvents) {
|
|
||||||
if (evt.text) {
|
|
||||||
sendDeltaFrame({ content: evt.text });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (detected.length > 0 || toolCallsEmitted) {
|
|
||||||
reason = 'tool_calls';
|
|
||||||
}
|
|
||||||
sendFrame({
|
|
||||||
id: sessionID,
|
|
||||||
object: 'chat.completion.chunk',
|
|
||||||
created,
|
|
||||||
model,
|
|
||||||
choices: [{ delta: {}, index: 0, finish_reason: reason }],
|
|
||||||
usage: buildUsage(finalPrompt, thinkingText, outputText),
|
|
||||||
});
|
|
||||||
res.write('data: [DONE]\n\n');
|
|
||||||
await releaseLease();
|
|
||||||
res.end();
|
|
||||||
};
|
|
||||||
|
|
||||||
try {
|
|
||||||
// eslint-disable-next-line no-constant-condition
|
|
||||||
while (true) {
|
|
||||||
const { value, done } = await reader.read();
|
|
||||||
if (done) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
buffered += decoder.decode(value, { stream: true });
|
|
||||||
const lines = buffered.split('\n');
|
|
||||||
buffered = lines.pop() || '';
|
|
||||||
|
|
||||||
for (const rawLine of lines) {
|
|
||||||
const line = rawLine.trim();
|
|
||||||
if (!line.startsWith('data:')) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const dataStr = line.slice(5).trim();
|
|
||||||
if (!dataStr) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (dataStr === '[DONE]') {
|
|
||||||
await finish('stop');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
let chunk;
|
|
||||||
try {
|
|
||||||
chunk = JSON.parse(dataStr);
|
|
||||||
} catch (_err) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (chunk.error || chunk.code === 'content_filter') {
|
|
||||||
await finish('content_filter');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const parsed = parseChunkForContent(chunk, thinkingEnabled, currentType);
|
|
||||||
currentType = parsed.newType;
|
|
||||||
if (parsed.finished) {
|
|
||||||
await finish('stop');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const p of parsed.parts) {
|
|
||||||
if (!p.text) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (searchEnabled && isCitation(p.text)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (p.type === 'thinking') {
|
|
||||||
if (thinkingEnabled) {
|
|
||||||
thinkingText += p.text;
|
|
||||||
sendDeltaFrame({ reasoning_content: p.text });
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
outputText += p.text;
|
|
||||||
if (!toolSieveEnabled) {
|
|
||||||
sendDeltaFrame({ content: p.text });
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const events = processToolSieveChunk(toolSieveState, p.text, toolNames);
|
|
||||||
for (const evt of events) {
|
|
||||||
if (evt.type === 'tool_calls') {
|
|
||||||
toolCallsEmitted = true;
|
|
||||||
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls) });
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (evt.text) {
|
|
||||||
sendDeltaFrame({ content: evt.text });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
await finish('stop');
|
|
||||||
} catch (_err) {
|
|
||||||
await finish('stop');
|
|
||||||
}
|
|
||||||
} finally {
|
|
||||||
await releaseLease();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
function setCorsHeaders(res) {
|
|
||||||
res.setHeader('Access-Control-Allow-Origin', '*');
|
|
||||||
res.setHeader('Access-Control-Allow-Methods', 'GET, POST, OPTIONS, PUT, DELETE');
|
|
||||||
res.setHeader(
|
|
||||||
'Access-Control-Allow-Headers',
|
|
||||||
'Content-Type, Authorization, X-API-Key, X-Ds2-Target-Account, X-Vercel-Protection-Bypass',
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
function header(req, key) {
|
|
||||||
if (!req || !req.headers) {
|
|
||||||
return '';
|
|
||||||
}
|
|
||||||
return asString(req.headers[key.toLowerCase()]);
|
|
||||||
}
|
|
||||||
|
|
||||||
async function readRawBody(req) {
|
|
||||||
if (Buffer.isBuffer(req.body)) {
|
|
||||||
return req.body;
|
|
||||||
}
|
|
||||||
if (typeof req.body === 'string') {
|
|
||||||
return Buffer.from(req.body);
|
|
||||||
}
|
|
||||||
if (req.body && typeof req.body === 'object') {
|
|
||||||
return Buffer.from(JSON.stringify(req.body));
|
|
||||||
}
|
|
||||||
const chunks = [];
|
|
||||||
for await (const chunk of req) {
|
|
||||||
chunks.push(Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk));
|
|
||||||
}
|
|
||||||
return Buffer.concat(chunks);
|
|
||||||
}
|
|
||||||
|
|
||||||
async function fetchStreamPrepare(req, rawBody) {
|
|
||||||
const url = buildInternalGoURL(req);
|
|
||||||
url.searchParams.set('__stream_prepare', '1');
|
|
||||||
|
|
||||||
const upstream = await fetch(url.toString(), {
|
|
||||||
method: 'POST',
|
|
||||||
headers: buildInternalGoHeaders(req, { withInternalToken: true, withContentType: true }),
|
|
||||||
body: rawBody,
|
|
||||||
});
|
|
||||||
|
|
||||||
const text = await upstream.text();
|
|
||||||
let body = {};
|
|
||||||
try {
|
|
||||||
body = JSON.parse(text || '{}');
|
|
||||||
} catch (_err) {
|
|
||||||
body = {};
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
ok: upstream.ok,
|
|
||||||
status: upstream.status,
|
|
||||||
contentType: upstream.headers.get('content-type') || 'application/json',
|
|
||||||
text,
|
|
||||||
body,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
function relayPreparedFailure(res, prep) {
|
|
||||||
if (prep.status === 401 && looksLikeVercelAuthPage(prep.text)) {
|
|
||||||
writeOpenAIError(
|
|
||||||
res,
|
|
||||||
401,
|
|
||||||
'Vercel Deployment Protection blocked internal prepare request. Disable protection for this deployment or set VERCEL_AUTOMATION_BYPASS_SECRET.',
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
res.statusCode = prep.status || 500;
|
|
||||||
res.setHeader('Content-Type', prep.contentType || 'application/json');
|
|
||||||
if (prep.text) {
|
|
||||||
res.end(prep.text);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
writeOpenAIError(res, prep.status || 500, 'vercel prepare failed');
|
|
||||||
}
|
|
||||||
|
|
||||||
async function safeReadText(resp) {
|
|
||||||
if (!resp) {
|
|
||||||
return '';
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
const text = await resp.text();
|
|
||||||
return text.trim();
|
|
||||||
} catch (_err) {
|
|
||||||
return '';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function internalSecret() {
|
|
||||||
return asString(process.env.DS2API_VERCEL_INTERNAL_SECRET) || asString(process.env.DS2API_ADMIN_KEY) || 'admin';
|
|
||||||
}
|
|
||||||
|
|
||||||
function buildInternalGoURL(req) {
|
|
||||||
const proto = asString(header(req, 'x-forwarded-proto')) || 'https';
|
|
||||||
const host = asString(header(req, 'host'));
|
|
||||||
const url = new URL(`${proto}://${host}${req.url || '/v1/chat/completions'}`);
|
|
||||||
url.searchParams.set('__go', '1');
|
|
||||||
const protectionBypass = resolveProtectionBypass(req);
|
|
||||||
if (protectionBypass) {
|
|
||||||
url.searchParams.set('x-vercel-protection-bypass', protectionBypass);
|
|
||||||
}
|
|
||||||
return url;
|
|
||||||
}
|
|
||||||
|
|
||||||
function buildInternalGoHeaders(req, opts = {}) {
|
|
||||||
const headers = {
|
|
||||||
authorization: asString(header(req, 'authorization')),
|
|
||||||
'x-api-key': asString(header(req, 'x-api-key')),
|
|
||||||
'x-ds2-target-account': asString(header(req, 'x-ds2-target-account')),
|
|
||||||
'x-vercel-protection-bypass': resolveProtectionBypass(req),
|
|
||||||
};
|
|
||||||
if (opts.withInternalToken) {
|
|
||||||
headers['x-ds2-internal-token'] = internalSecret();
|
|
||||||
}
|
|
||||||
if (opts.withContentType) {
|
|
||||||
headers['content-type'] = asString(header(req, 'content-type')) || 'application/json';
|
|
||||||
}
|
|
||||||
return headers;
|
|
||||||
}
|
|
||||||
|
|
||||||
function createLeaseReleaser(req, leaseID) {
|
|
||||||
let released = false;
|
|
||||||
return async () => {
|
|
||||||
if (released || !leaseID) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
released = true;
|
|
||||||
try {
|
|
||||||
await releaseStreamLease(req, leaseID);
|
|
||||||
} catch (_err) {
|
|
||||||
// Ignore release errors. Lease TTL cleanup on Go side still prevents permanent leaks.
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
async function releaseStreamLease(req, leaseID) {
|
|
||||||
const url = buildInternalGoURL(req);
|
|
||||||
url.searchParams.set('__stream_release', '1');
|
|
||||||
const body = Buffer.from(JSON.stringify({ lease_id: leaseID }));
|
|
||||||
|
|
||||||
const controller = new AbortController();
|
|
||||||
const timeout = setTimeout(() => controller.abort(), 1500);
|
|
||||||
try {
|
|
||||||
await fetch(url.toString(), {
|
|
||||||
method: 'POST',
|
|
||||||
headers: buildInternalGoHeaders(req, { withInternalToken: true, withContentType: true }),
|
|
||||||
body,
|
|
||||||
signal: controller.signal,
|
|
||||||
});
|
|
||||||
} finally {
|
|
||||||
clearTimeout(timeout);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function resolveProtectionBypass(req) {
|
|
||||||
const fromHeader = asString(header(req, 'x-vercel-protection-bypass'));
|
|
||||||
if (fromHeader) {
|
|
||||||
return fromHeader;
|
|
||||||
}
|
|
||||||
return asString(process.env.VERCEL_AUTOMATION_BYPASS_SECRET) || asString(process.env.DS2API_VERCEL_PROTECTION_BYPASS);
|
|
||||||
}
|
|
||||||
|
|
||||||
function looksLikeVercelAuthPage(text) {
|
|
||||||
const body = asString(text).toLowerCase();
|
|
||||||
if (!body) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return body.includes('authentication required') && body.includes('vercel');
|
|
||||||
}
|
|
||||||
|
|
||||||
function parseChunkForContent(chunk, thinkingEnabled, currentType) {
|
|
||||||
if (!chunk || typeof chunk !== 'object' || !Object.prototype.hasOwnProperty.call(chunk, 'v')) {
|
|
||||||
return { parts: [], finished: false, newType: currentType };
|
|
||||||
}
|
|
||||||
const pathValue = asString(chunk.p);
|
|
||||||
if (shouldSkipPath(pathValue)) {
|
|
||||||
return { parts: [], finished: false, newType: currentType };
|
|
||||||
}
|
|
||||||
if (pathValue === 'response/status' && asString(chunk.v) === 'FINISHED') {
|
|
||||||
return { parts: [], finished: true, newType: currentType };
|
|
||||||
}
|
|
||||||
|
|
||||||
let newType = currentType;
|
|
||||||
const parts = [];
|
|
||||||
|
|
||||||
if (pathValue === 'response/fragments' && asString(chunk.o).toUpperCase() === 'APPEND' && Array.isArray(chunk.v)) {
|
|
||||||
for (const frag of chunk.v) {
|
|
||||||
if (!frag || typeof frag !== 'object') {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const fragType = asString(frag.type).toUpperCase();
|
|
||||||
const content = asString(frag.content);
|
|
||||||
if (!content) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (fragType === 'THINK' || fragType === 'THINKING') {
|
|
||||||
newType = 'thinking';
|
|
||||||
parts.push({ text: content, type: 'thinking' });
|
|
||||||
} else if (fragType === 'RESPONSE') {
|
|
||||||
newType = 'text';
|
|
||||||
parts.push({ text: content, type: 'text' });
|
|
||||||
} else {
|
|
||||||
parts.push({ text: content, type: 'text' });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (pathValue === 'response' && Array.isArray(chunk.v)) {
|
|
||||||
for (const item of chunk.v) {
|
|
||||||
if (!item || typeof item !== 'object') {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (item.p === 'fragments' && item.o === 'APPEND' && Array.isArray(item.v)) {
|
|
||||||
for (const frag of item.v) {
|
|
||||||
const fragType = asString(frag && frag.type).toUpperCase();
|
|
||||||
if (fragType === 'THINK' || fragType === 'THINKING') {
|
|
||||||
newType = 'thinking';
|
|
||||||
} else if (fragType === 'RESPONSE') {
|
|
||||||
newType = 'text';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let partType = 'text';
|
|
||||||
if (pathValue === 'response/thinking_content') {
|
|
||||||
partType = 'thinking';
|
|
||||||
} else if (pathValue === 'response/content') {
|
|
||||||
partType = 'text';
|
|
||||||
} else if (pathValue.includes('response/fragments') && pathValue.includes('/content')) {
|
|
||||||
partType = newType;
|
|
||||||
} else if (!pathValue && thinkingEnabled) {
|
|
||||||
partType = newType;
|
|
||||||
}
|
|
||||||
|
|
||||||
const val = chunk.v;
|
|
||||||
if (typeof val === 'string') {
|
|
||||||
if (val === 'FINISHED' && (!pathValue || pathValue === 'status')) {
|
|
||||||
return { parts: [], finished: true, newType };
|
|
||||||
}
|
|
||||||
if (val) {
|
|
||||||
parts.push({ text: val, type: partType });
|
|
||||||
}
|
|
||||||
return { parts, finished: false, newType };
|
|
||||||
}
|
|
||||||
|
|
||||||
if (Array.isArray(val)) {
|
|
||||||
const extracted = extractContentRecursive(val, partType);
|
|
||||||
if (extracted.finished) {
|
|
||||||
return { parts: [], finished: true, newType };
|
|
||||||
}
|
|
||||||
parts.push(...extracted.parts);
|
|
||||||
return { parts, finished: false, newType };
|
|
||||||
}
|
|
||||||
|
|
||||||
if (val && typeof val === 'object') {
|
|
||||||
const resp = val.response && typeof val.response === 'object' ? val.response : val;
|
|
||||||
if (Array.isArray(resp.fragments)) {
|
|
||||||
for (const frag of resp.fragments) {
|
|
||||||
if (!frag || typeof frag !== 'object') {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const content = asString(frag.content);
|
|
||||||
if (!content) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const t = asString(frag.type).toUpperCase();
|
|
||||||
if (t === 'THINK' || t === 'THINKING') {
|
|
||||||
newType = 'thinking';
|
|
||||||
parts.push({ text: content, type: 'thinking' });
|
|
||||||
} else if (t === 'RESPONSE') {
|
|
||||||
newType = 'text';
|
|
||||||
parts.push({ text: content, type: 'text' });
|
|
||||||
} else {
|
|
||||||
parts.push({ text: content, type: partType });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return { parts, finished: false, newType };
|
|
||||||
}
|
|
||||||
|
|
||||||
function extractContentRecursive(items, defaultType) {
|
|
||||||
const parts = [];
|
|
||||||
for (const it of items) {
|
|
||||||
if (!it || typeof it !== 'object') {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (!Object.prototype.hasOwnProperty.call(it, 'v')) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const itemPath = asString(it.p);
|
|
||||||
const itemV = it.v;
|
|
||||||
if (itemPath === 'status' && asString(itemV) === 'FINISHED') {
|
|
||||||
return { parts: [], finished: true };
|
|
||||||
}
|
|
||||||
if (shouldSkipPath(itemPath)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const content = asString(it.content);
|
|
||||||
if (content) {
|
|
||||||
const typeName = asString(it.type).toUpperCase();
|
|
||||||
if (typeName === 'THINK' || typeName === 'THINKING') {
|
|
||||||
parts.push({ text: content, type: 'thinking' });
|
|
||||||
} else if (typeName === 'RESPONSE') {
|
|
||||||
parts.push({ text: content, type: 'text' });
|
|
||||||
} else {
|
|
||||||
parts.push({ text: content, type: defaultType });
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let partType = defaultType;
|
|
||||||
if (itemPath.includes('thinking')) {
|
|
||||||
partType = 'thinking';
|
|
||||||
} else if (itemPath.includes('content') || itemPath === 'response' || itemPath === 'fragments') {
|
|
||||||
partType = 'text';
|
|
||||||
}
|
|
||||||
|
|
||||||
if (typeof itemV === 'string') {
|
|
||||||
if (itemV && itemV !== 'FINISHED') {
|
|
||||||
parts.push({ text: itemV, type: partType });
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!Array.isArray(itemV)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
for (const inner of itemV) {
|
|
||||||
if (typeof inner === 'string') {
|
|
||||||
if (inner) {
|
|
||||||
parts.push({ text: inner, type: partType });
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (!inner || typeof inner !== 'object') {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const ct = asString(inner.content);
|
|
||||||
if (!ct) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const typeName = asString(inner.type).toUpperCase();
|
|
||||||
if (typeName === 'THINK' || typeName === 'THINKING') {
|
|
||||||
parts.push({ text: ct, type: 'thinking' });
|
|
||||||
} else if (typeName === 'RESPONSE') {
|
|
||||||
parts.push({ text: ct, type: 'text' });
|
|
||||||
} else {
|
|
||||||
parts.push({ text: ct, type: partType });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return { parts, finished: false };
|
|
||||||
}
|
|
||||||
|
|
||||||
function shouldSkipPath(pathValue) {
|
|
||||||
if (pathValue === 'response/search_status') {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
for (const p of SKIP_PATTERNS) {
|
|
||||||
if (pathValue.includes(p)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
function isCitation(text) {
|
|
||||||
return asString(text).trim().startsWith('[citation:');
|
|
||||||
}
|
|
||||||
|
|
||||||
function buildUsage(prompt, thinking, output) {
|
|
||||||
const promptTokens = estimateTokens(prompt);
|
|
||||||
const reasoningTokens = estimateTokens(thinking);
|
|
||||||
const completionTokens = estimateTokens(output);
|
|
||||||
return {
|
|
||||||
prompt_tokens: promptTokens,
|
|
||||||
completion_tokens: reasoningTokens + completionTokens,
|
|
||||||
total_tokens: promptTokens + reasoningTokens + completionTokens,
|
|
||||||
completion_tokens_details: {
|
|
||||||
reasoning_tokens: reasoningTokens,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
function estimateTokens(text) {
|
|
||||||
const t = asString(text);
|
|
||||||
if (!t) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
const n = Math.floor(Array.from(t).length / 4);
|
|
||||||
return n < 1 ? 1 : n;
|
|
||||||
}
|
|
||||||
|
|
||||||
async function proxyToGo(req, res, rawBody) {
|
|
||||||
const url = buildInternalGoURL(req);
|
|
||||||
|
|
||||||
const upstream = await fetch(url.toString(), {
|
|
||||||
method: 'POST',
|
|
||||||
headers: buildInternalGoHeaders(req, { withContentType: true }),
|
|
||||||
body: rawBody,
|
|
||||||
});
|
|
||||||
|
|
||||||
res.statusCode = upstream.status;
|
|
||||||
upstream.headers.forEach((value, key) => {
|
|
||||||
if (key.toLowerCase() === 'content-length') {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
res.setHeader(key, value);
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!upstream.body || typeof upstream.body.getReader !== 'function') {
|
|
||||||
const bytes = Buffer.from(await upstream.arrayBuffer());
|
|
||||||
res.end(bytes);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const reader = upstream.body.getReader();
|
|
||||||
try {
|
|
||||||
// eslint-disable-next-line no-constant-condition
|
|
||||||
while (true) {
|
|
||||||
const { value, done } = await reader.read();
|
|
||||||
if (done) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (value && value.length > 0) {
|
|
||||||
res.write(Buffer.from(value));
|
|
||||||
if (typeof res.flush === 'function') {
|
|
||||||
res.flush();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
res.end();
|
|
||||||
} catch (_err) {
|
|
||||||
if (!res.writableEnded) {
|
|
||||||
res.end();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function writeOpenAIError(res, status, message) {
|
|
||||||
res.statusCode = status;
|
|
||||||
res.setHeader('Content-Type', 'application/json');
|
|
||||||
res.end(
|
|
||||||
JSON.stringify({
|
|
||||||
error: {
|
|
||||||
message,
|
|
||||||
type: openAIErrorType(status),
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
function openAIErrorType(status) {
|
|
||||||
switch (status) {
|
|
||||||
case 400:
|
|
||||||
return 'invalid_request_error';
|
|
||||||
case 401:
|
|
||||||
return 'authentication_error';
|
|
||||||
case 403:
|
|
||||||
return 'permission_error';
|
|
||||||
case 429:
|
|
||||||
return 'rate_limit_error';
|
|
||||||
case 503:
|
|
||||||
return 'service_unavailable_error';
|
|
||||||
default:
|
|
||||||
return status >= 500 ? 'api_error' : 'invalid_request_error';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function toBool(v) {
|
|
||||||
return v === true;
|
|
||||||
}
|
|
||||||
|
|
||||||
function isVercelRuntime() {
|
|
||||||
return asString(process.env.VERCEL) !== '' || asString(process.env.NOW_REGION) !== '';
|
|
||||||
}
|
|
||||||
|
|
||||||
function asString(v) {
|
|
||||||
if (typeof v === 'string') {
|
|
||||||
return v.trim();
|
|
||||||
}
|
|
||||||
if (Array.isArray(v)) {
|
|
||||||
return asString(v[0]);
|
|
||||||
}
|
|
||||||
if (v == null) {
|
|
||||||
return '';
|
|
||||||
}
|
|
||||||
return String(v).trim();
|
|
||||||
}
|
|
||||||
|
|
||||||
module.exports.__test = {
|
|
||||||
parseChunkForContent,
|
|
||||||
extractContentRecursive,
|
|
||||||
shouldSkipPath,
|
|
||||||
asString,
|
|
||||||
};
|
|
||||||
|
|||||||
@@ -1,477 +0,0 @@
|
|||||||
'use strict';
|
|
||||||
|
|
||||||
const crypto = require('crypto');
|
|
||||||
const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s;
|
|
||||||
|
|
||||||
function extractToolNames(tools) {
|
|
||||||
if (!Array.isArray(tools) || tools.length === 0) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
const out = [];
|
|
||||||
for (const t of tools) {
|
|
||||||
if (!t || typeof t !== 'object') {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const fn = t.function && typeof t.function === 'object' ? t.function : t;
|
|
||||||
const name = toStringSafe(fn.name);
|
|
||||||
// Keep parity with Go injectToolPrompt: object tools without name still
|
|
||||||
// enter tool mode via fallback name "unknown".
|
|
||||||
out.push(name || 'unknown');
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
function createToolSieveState() {
|
|
||||||
return {
|
|
||||||
pending: '',
|
|
||||||
capture: '',
|
|
||||||
capturing: false,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
function processToolSieveChunk(state, chunk, toolNames) {
|
|
||||||
if (!state) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
if (chunk) {
|
|
||||||
state.pending += chunk;
|
|
||||||
}
|
|
||||||
const events = [];
|
|
||||||
// eslint-disable-next-line no-constant-condition
|
|
||||||
while (true) {
|
|
||||||
if (state.capturing) {
|
|
||||||
if (state.pending) {
|
|
||||||
state.capture += state.pending;
|
|
||||||
state.pending = '';
|
|
||||||
}
|
|
||||||
const consumed = consumeToolCapture(state.capture, toolNames);
|
|
||||||
if (!consumed.ready) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
state.capture = '';
|
|
||||||
state.capturing = false;
|
|
||||||
if (consumed.prefix) {
|
|
||||||
events.push({ type: 'text', text: consumed.prefix });
|
|
||||||
}
|
|
||||||
if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
|
|
||||||
events.push({ type: 'tool_calls', calls: consumed.calls });
|
|
||||||
}
|
|
||||||
if (consumed.suffix) {
|
|
||||||
state.pending += consumed.suffix;
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!state.pending) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
const start = findToolSegmentStart(state.pending);
|
|
||||||
if (start >= 0) {
|
|
||||||
const prefix = state.pending.slice(0, start);
|
|
||||||
if (prefix) {
|
|
||||||
events.push({ type: 'text', text: prefix });
|
|
||||||
}
|
|
||||||
state.capture = state.pending.slice(start);
|
|
||||||
state.pending = '';
|
|
||||||
state.capturing = true;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
const [safe, hold] = splitSafeContentForToolDetection(state.pending);
|
|
||||||
if (!safe) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
state.pending = hold;
|
|
||||||
events.push({ type: 'text', text: safe });
|
|
||||||
}
|
|
||||||
return events;
|
|
||||||
}
|
|
||||||
|
|
||||||
function flushToolSieve(state, toolNames) {
|
|
||||||
if (!state) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
const events = processToolSieveChunk(state, '', toolNames);
|
|
||||||
if (state.capturing) {
|
|
||||||
const consumed = consumeToolCapture(state.capture, toolNames);
|
|
||||||
if (consumed.ready) {
|
|
||||||
if (consumed.prefix) {
|
|
||||||
events.push({ type: 'text', text: consumed.prefix });
|
|
||||||
}
|
|
||||||
if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
|
|
||||||
events.push({ type: 'tool_calls', calls: consumed.calls });
|
|
||||||
}
|
|
||||||
if (consumed.suffix) {
|
|
||||||
events.push({ type: 'text', text: consumed.suffix });
|
|
||||||
}
|
|
||||||
} else if (state.capture) {
|
|
||||||
// Incomplete captured tool JSON at stream end: suppress raw capture.
|
|
||||||
}
|
|
||||||
state.capture = '';
|
|
||||||
state.capturing = false;
|
|
||||||
}
|
|
||||||
if (state.pending) {
|
|
||||||
events.push({ type: 'text', text: state.pending });
|
|
||||||
state.pending = '';
|
|
||||||
}
|
|
||||||
return events;
|
|
||||||
}
|
|
||||||
|
|
||||||
function splitSafeContentForToolDetection(s) {
|
|
||||||
const text = s || '';
|
|
||||||
if (!text) {
|
|
||||||
return ['', ''];
|
|
||||||
}
|
|
||||||
const suspiciousStart = findSuspiciousPrefixStart(text);
|
|
||||||
if (suspiciousStart < 0) {
|
|
||||||
return [text, ''];
|
|
||||||
}
|
|
||||||
if (suspiciousStart > 0) {
|
|
||||||
return [text.slice(0, suspiciousStart), text.slice(suspiciousStart)];
|
|
||||||
}
|
|
||||||
// If suspicious content starts at the beginning, keep holding until we can
|
|
||||||
// either parse a full tool JSON block or reach stream flush.
|
|
||||||
return ['', text];
|
|
||||||
}
|
|
||||||
|
|
||||||
function findSuspiciousPrefixStart(s) {
|
|
||||||
let start = -1;
|
|
||||||
for (const needle of ['{', '[', '```']) {
|
|
||||||
const idx = s.lastIndexOf(needle);
|
|
||||||
if (idx > start) {
|
|
||||||
start = idx;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return start;
|
|
||||||
}
|
|
||||||
|
|
||||||
function findToolSegmentStart(s) {
|
|
||||||
if (!s) {
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
const lower = s.toLowerCase();
|
|
||||||
const keyIdx = lower.indexOf('tool_calls');
|
|
||||||
if (keyIdx < 0) {
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
const start = s.slice(0, keyIdx).lastIndexOf('{');
|
|
||||||
return start >= 0 ? start : keyIdx;
|
|
||||||
}
|
|
||||||
|
|
||||||
function consumeToolCapture(captured, toolNames) {
|
|
||||||
if (!captured) {
|
|
||||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
|
||||||
}
|
|
||||||
const lower = captured.toLowerCase();
|
|
||||||
const keyIdx = lower.indexOf('tool_calls');
|
|
||||||
if (keyIdx < 0) {
|
|
||||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
|
||||||
}
|
|
||||||
const start = captured.slice(0, keyIdx).lastIndexOf('{');
|
|
||||||
if (start < 0) {
|
|
||||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
|
||||||
}
|
|
||||||
const obj = extractJSONObjectFrom(captured, start);
|
|
||||||
if (!obj.ok) {
|
|
||||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
|
||||||
}
|
|
||||||
const parsed = parseToolCalls(captured.slice(start, obj.end), toolNames);
|
|
||||||
if (parsed.length === 0) {
|
|
||||||
// `tool_calls` key exists but strict JSON parse failed.
|
|
||||||
// Drop the captured object body to avoid leaking raw tool JSON.
|
|
||||||
return {
|
|
||||||
ready: true,
|
|
||||||
prefix: captured.slice(0, start),
|
|
||||||
calls: [],
|
|
||||||
suffix: captured.slice(obj.end),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
ready: true,
|
|
||||||
prefix: captured.slice(0, start),
|
|
||||||
calls: parsed,
|
|
||||||
suffix: captured.slice(obj.end),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
function extractJSONObjectFrom(text, start) {
|
|
||||||
if (!text || start < 0 || start >= text.length || text[start] !== '{') {
|
|
||||||
return { ok: false, end: 0 };
|
|
||||||
}
|
|
||||||
let depth = 0;
|
|
||||||
let quote = '';
|
|
||||||
let escaped = false;
|
|
||||||
for (let i = start; i < text.length; i += 1) {
|
|
||||||
const ch = text[i];
|
|
||||||
if (quote) {
|
|
||||||
if (escaped) {
|
|
||||||
escaped = false;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (ch === '\\') {
|
|
||||||
escaped = true;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (ch === quote) {
|
|
||||||
quote = '';
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (ch === '"' || ch === "'") {
|
|
||||||
quote = ch;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (ch === '{') {
|
|
||||||
depth += 1;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (ch === '}') {
|
|
||||||
depth -= 1;
|
|
||||||
if (depth === 0) {
|
|
||||||
return { ok: true, end: i + 1 };
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return { ok: false, end: 0 };
|
|
||||||
}
|
|
||||||
|
|
||||||
function parseToolCalls(text, toolNames) {
|
|
||||||
if (!toStringSafe(text)) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
const candidates = buildToolCallCandidates(text);
|
|
||||||
let parsed = [];
|
|
||||||
for (const c of candidates) {
|
|
||||||
parsed = parseToolCallsPayload(c);
|
|
||||||
if (parsed.length > 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (parsed.length === 0) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
const allowed = new Set((toolNames || []).filter(Boolean));
|
|
||||||
const out = [];
|
|
||||||
for (const tc of parsed) {
|
|
||||||
if (!tc || !tc.name) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (allowed.size > 0 && !allowed.has(tc.name)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
out.push({ name: tc.name, input: tc.input || {} });
|
|
||||||
}
|
|
||||||
if (out.length === 0 && parsed.length > 0) {
|
|
||||||
for (const tc of parsed) {
|
|
||||||
if (!tc || !tc.name) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
out.push({ name: tc.name, input: tc.input || {} });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
function buildToolCallCandidates(text) {
|
|
||||||
const trimmed = toStringSafe(text);
|
|
||||||
const candidates = [trimmed];
|
|
||||||
const fenced = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/gi) || [];
|
|
||||||
for (const block of fenced) {
|
|
||||||
const m = block.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
|
|
||||||
if (m && m[1]) {
|
|
||||||
candidates.push(toStringSafe(m[1]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (const candidate of extractToolCallObjects(trimmed)) {
|
|
||||||
candidates.push(toStringSafe(candidate));
|
|
||||||
}
|
|
||||||
const first = trimmed.indexOf('{');
|
|
||||||
const last = trimmed.lastIndexOf('}');
|
|
||||||
if (first >= 0 && last > first) {
|
|
||||||
candidates.push(toStringSafe(trimmed.slice(first, last + 1)));
|
|
||||||
}
|
|
||||||
const m = trimmed.match(TOOL_CALL_PATTERN);
|
|
||||||
if (m && m[1]) {
|
|
||||||
candidates.push(`{"tool_calls":[${m[1]}]}`);
|
|
||||||
}
|
|
||||||
return [...new Set(candidates.filter(Boolean))];
|
|
||||||
}
|
|
||||||
|
|
||||||
function extractToolCallObjects(text) {
|
|
||||||
const raw = toStringSafe(text);
|
|
||||||
if (!raw) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
const lower = raw.toLowerCase();
|
|
||||||
const out = [];
|
|
||||||
let offset = 0;
|
|
||||||
// eslint-disable-next-line no-constant-condition
|
|
||||||
while (true) {
|
|
||||||
let idx = lower.indexOf('tool_calls', offset);
|
|
||||||
if (idx < 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
let start = raw.slice(0, idx).lastIndexOf('{');
|
|
||||||
while (start >= 0) {
|
|
||||||
const obj = extractJSONObjectFrom(raw, start);
|
|
||||||
if (obj.ok) {
|
|
||||||
out.push(raw.slice(start, obj.end).trim());
|
|
||||||
offset = obj.end;
|
|
||||||
idx = -1;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
start = raw.slice(0, start).lastIndexOf('{');
|
|
||||||
}
|
|
||||||
if (idx >= 0) {
|
|
||||||
offset = idx + 'tool_calls'.length;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
function parseToolCallsPayload(payload) {
|
|
||||||
let decoded;
|
|
||||||
try {
|
|
||||||
decoded = JSON.parse(payload);
|
|
||||||
} catch (_err) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
if (Array.isArray(decoded)) {
|
|
||||||
return parseToolCallList(decoded);
|
|
||||||
}
|
|
||||||
if (!decoded || typeof decoded !== 'object') {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
if (decoded.tool_calls) {
|
|
||||||
return parseToolCallList(decoded.tool_calls);
|
|
||||||
}
|
|
||||||
const one = parseToolCallItem(decoded);
|
|
||||||
return one ? [one] : [];
|
|
||||||
}
|
|
||||||
|
|
||||||
function parseToolCallList(v) {
|
|
||||||
if (!Array.isArray(v)) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
const out = [];
|
|
||||||
for (const item of v) {
|
|
||||||
if (!item || typeof item !== 'object') {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const one = parseToolCallItem(item);
|
|
||||||
if (one) {
|
|
||||||
out.push(one);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
function parseToolCallItem(m) {
|
|
||||||
let name = toStringSafe(m.name);
|
|
||||||
let inputRaw = m.input;
|
|
||||||
let hasInput = Object.prototype.hasOwnProperty.call(m, 'input');
|
|
||||||
const fn = m.function && typeof m.function === 'object' ? m.function : null;
|
|
||||||
if (fn) {
|
|
||||||
if (!name) {
|
|
||||||
name = toStringSafe(fn.name);
|
|
||||||
}
|
|
||||||
if (!hasInput && Object.prototype.hasOwnProperty.call(fn, 'arguments')) {
|
|
||||||
inputRaw = fn.arguments;
|
|
||||||
hasInput = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!hasInput) {
|
|
||||||
for (const k of ['arguments', 'args', 'parameters', 'params']) {
|
|
||||||
if (Object.prototype.hasOwnProperty.call(m, k)) {
|
|
||||||
inputRaw = m[k];
|
|
||||||
hasInput = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!name) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
name,
|
|
||||||
input: parseToolCallInput(inputRaw),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
function parseToolCallInput(v) {
|
|
||||||
if (v == null) {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
if (typeof v === 'string') {
|
|
||||||
const raw = toStringSafe(v);
|
|
||||||
if (!raw) {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
const parsed = JSON.parse(raw);
|
|
||||||
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
|
|
||||||
return parsed;
|
|
||||||
}
|
|
||||||
return { _raw: raw };
|
|
||||||
} catch (_err) {
|
|
||||||
return { _raw: raw };
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (typeof v === 'object' && !Array.isArray(v)) {
|
|
||||||
return v;
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
const parsed = JSON.parse(JSON.stringify(v));
|
|
||||||
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
|
|
||||||
return parsed;
|
|
||||||
}
|
|
||||||
} catch (_err) {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
function formatOpenAIStreamToolCalls(calls) {
|
|
||||||
if (!Array.isArray(calls) || calls.length === 0) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
return calls.map((c, idx) => ({
|
|
||||||
index: idx,
|
|
||||||
id: `call_${newCallID()}`,
|
|
||||||
type: 'function',
|
|
||||||
function: {
|
|
||||||
name: c.name,
|
|
||||||
arguments: JSON.stringify(c.input || {}),
|
|
||||||
},
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
function newCallID() {
|
|
||||||
if (typeof crypto.randomUUID === 'function') {
|
|
||||||
return crypto.randomUUID().replace(/-/g, '');
|
|
||||||
}
|
|
||||||
return `${Date.now()}${Math.floor(Math.random() * 1e9)}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
function toStringSafe(v) {
|
|
||||||
if (typeof v === 'string') {
|
|
||||||
return v.trim();
|
|
||||||
}
|
|
||||||
if (Array.isArray(v)) {
|
|
||||||
return toStringSafe(v[0]);
|
|
||||||
}
|
|
||||||
if (v == null) {
|
|
||||||
return '';
|
|
||||||
}
|
|
||||||
return String(v).trim();
|
|
||||||
}
|
|
||||||
|
|
||||||
module.exports = {
|
|
||||||
extractToolNames,
|
|
||||||
createToolSieveState,
|
|
||||||
processToolSieveChunk,
|
|
||||||
flushToolSieve,
|
|
||||||
parseToolCalls,
|
|
||||||
formatOpenAIStreamToolCalls,
|
|
||||||
};
|
|
||||||
@@ -1,130 +0,0 @@
|
|||||||
'use strict';
|
|
||||||
|
|
||||||
const test = require('node:test');
|
|
||||||
const assert = require('node:assert/strict');
|
|
||||||
|
|
||||||
const {
|
|
||||||
extractToolNames,
|
|
||||||
createToolSieveState,
|
|
||||||
processToolSieveChunk,
|
|
||||||
flushToolSieve,
|
|
||||||
parseToolCalls,
|
|
||||||
} = require('./stream-tool-sieve');
|
|
||||||
|
|
||||||
function runSieve(chunks, toolNames) {
|
|
||||||
const state = createToolSieveState();
|
|
||||||
const events = [];
|
|
||||||
for (const chunk of chunks) {
|
|
||||||
events.push(...processToolSieveChunk(state, chunk, toolNames));
|
|
||||||
}
|
|
||||||
events.push(...flushToolSieve(state, toolNames));
|
|
||||||
return events;
|
|
||||||
}
|
|
||||||
|
|
||||||
function collectText(events) {
|
|
||||||
return events
|
|
||||||
.filter((evt) => evt.type === 'text' && evt.text)
|
|
||||||
.map((evt) => evt.text)
|
|
||||||
.join('');
|
|
||||||
}
|
|
||||||
|
|
||||||
test('extractToolNames keeps tool mode enabled with unknown fallback', () => {
|
|
||||||
const names = extractToolNames([
|
|
||||||
{ function: { description: 'no name tool' } },
|
|
||||||
{ function: { name: ' read_file ' } },
|
|
||||||
{},
|
|
||||||
]);
|
|
||||||
assert.deepEqual(names, ['unknown', 'read_file', 'unknown']);
|
|
||||||
});
|
|
||||||
|
|
||||||
test('parseToolCalls keeps non-object argument strings as _raw (Go parity)', () => {
|
|
||||||
const payload = JSON.stringify({
|
|
||||||
tool_calls: [
|
|
||||||
{ name: 'read_file', input: '123' },
|
|
||||||
{ name: 'list_dir', input: '[1,2,3]' },
|
|
||||||
],
|
|
||||||
});
|
|
||||||
const calls = parseToolCalls(payload, ['read_file', 'list_dir']);
|
|
||||||
assert.deepEqual(calls, [
|
|
||||||
{ name: 'read_file', input: { _raw: '123' } },
|
|
||||||
{ name: 'list_dir', input: { _raw: '[1,2,3]' } },
|
|
||||||
]);
|
|
||||||
});
|
|
||||||
|
|
||||||
test('parseToolCalls still intercepts unknown schema names to avoid leaks', () => {
|
|
||||||
const payload = JSON.stringify({
|
|
||||||
tool_calls: [{ name: 'not_in_schema', input: { q: 'go' } }],
|
|
||||||
});
|
|
||||||
const calls = parseToolCalls(payload, ['search']);
|
|
||||||
assert.equal(calls.length, 1);
|
|
||||||
assert.equal(calls[0].name, 'not_in_schema');
|
|
||||||
});
|
|
||||||
|
|
||||||
test('parseToolCalls supports fenced json and function.arguments string payload', () => {
|
|
||||||
const text = [
|
|
||||||
'I will call a tool now.',
|
|
||||||
'```json',
|
|
||||||
'{"tool_calls":[{"function":{"name":"read_file","arguments":"{\\"path\\":\\"README.md\\"}"}}]}',
|
|
||||||
'```',
|
|
||||||
].join('\n');
|
|
||||||
const calls = parseToolCalls(text, ['read_file']);
|
|
||||||
assert.equal(calls.length, 1);
|
|
||||||
assert.equal(calls[0].name, 'read_file');
|
|
||||||
assert.deepEqual(calls[0].input, { path: 'README.md' });
|
|
||||||
});
|
|
||||||
|
|
||||||
test('sieve emits tool_calls and does not leak suspicious prefix on late key convergence', () => {
|
|
||||||
const events = runSieve(
|
|
||||||
[
|
|
||||||
'{"',
|
|
||||||
'tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}',
|
|
||||||
'后置正文C。',
|
|
||||||
],
|
|
||||||
['read_file'],
|
|
||||||
);
|
|
||||||
const leakedText = collectText(events);
|
|
||||||
const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0);
|
|
||||||
assert.equal(hasToolCall, true);
|
|
||||||
assert.equal(leakedText.includes('{'), false);
|
|
||||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
|
||||||
assert.equal(leakedText.includes('后置正文C。'), true);
|
|
||||||
});
|
|
||||||
|
|
||||||
test('sieve drops invalid tool json body while preserving surrounding text', () => {
|
|
||||||
const events = runSieve(
|
|
||||||
[
|
|
||||||
'前置正文D。',
|
|
||||||
"{'tool_calls':[{'name':'read_file','input':{'path':'README.MD'}}]}",
|
|
||||||
'后置正文E。',
|
|
||||||
],
|
|
||||||
['read_file'],
|
|
||||||
);
|
|
||||||
const leakedText = collectText(events);
|
|
||||||
const hasToolCall = events.some((evt) => evt.type === 'tool_calls');
|
|
||||||
assert.equal(hasToolCall, false);
|
|
||||||
assert.equal(leakedText.includes('前置正文D。'), true);
|
|
||||||
assert.equal(leakedText.includes('后置正文E。'), true);
|
|
||||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
|
||||||
});
|
|
||||||
|
|
||||||
test('sieve suppresses incomplete captured tool json on stream finalize', () => {
|
|
||||||
const events = runSieve(
|
|
||||||
['前置正文F。', '{"tool_calls":[{"name":"read_file"'],
|
|
||||||
['read_file'],
|
|
||||||
);
|
|
||||||
const leakedText = collectText(events);
|
|
||||||
assert.equal(leakedText.includes('前置正文F。'), true);
|
|
||||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
|
||||||
assert.equal(leakedText.includes('{'), false);
|
|
||||||
});
|
|
||||||
|
|
||||||
test('sieve keeps plain text intact in tool mode when no tool call appears', () => {
|
|
||||||
const events = runSieve(
|
|
||||||
['你好,', '这是普通文本回复。', '请继续。'],
|
|
||||||
['read_file'],
|
|
||||||
);
|
|
||||||
const leakedText = collectText(events);
|
|
||||||
const hasToolCall = events.some((evt) => evt.type === 'tool_calls');
|
|
||||||
assert.equal(hasToolCall, false);
|
|
||||||
assert.equal(leakedText, '你好,这是普通文本回复。请继续。');
|
|
||||||
});
|
|
||||||
@@ -2,6 +2,8 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
@@ -28,10 +30,21 @@ func main() {
|
|||||||
Addr: "0.0.0.0:" + port,
|
Addr: "0.0.0.0:" + port,
|
||||||
Handler: app.Router,
|
Handler: app.Router,
|
||||||
}
|
}
|
||||||
|
localURL := fmt.Sprintf("http://127.0.0.1:%s", port)
|
||||||
|
lanIP := detectLANIPv4()
|
||||||
|
lanURL := ""
|
||||||
|
if lanIP != "" {
|
||||||
|
lanURL = fmt.Sprintf("http://%s:%s", lanIP, port)
|
||||||
|
}
|
||||||
|
|
||||||
// Start server in a goroutine so we can listen for shutdown signals.
|
// Start server in a goroutine so we can listen for shutdown signals.
|
||||||
go func() {
|
go func() {
|
||||||
config.Logger.Info("starting ds2api", "port", port)
|
if lanURL != "" {
|
||||||
|
config.Logger.Info("starting ds2api", "bind", srv.Addr, "port", port, "local_url", localURL, "lan_url", lanURL, "lan_ip", lanIP)
|
||||||
|
} else {
|
||||||
|
config.Logger.Info("starting ds2api", "bind", srv.Addr, "port", port, "local_url", localURL)
|
||||||
|
config.Logger.Warn("lan ip not detected; check active network interfaces")
|
||||||
|
}
|
||||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
config.Logger.Error("server stopped unexpectedly", "error", err)
|
config.Logger.Error("server stopped unexpectedly", "error", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
@@ -54,3 +67,36 @@ func main() {
|
|||||||
}
|
}
|
||||||
config.Logger.Info("server gracefully stopped")
|
config.Logger.Info("server gracefully stopped")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func detectLANIPv4() string {
|
||||||
|
ifaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for _, iface := range ifaces {
|
||||||
|
if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, addr := range addrs {
|
||||||
|
var ip net.IP
|
||||||
|
switch v := addr.(type) {
|
||||||
|
case *net.IPNet:
|
||||||
|
ip = v.IP
|
||||||
|
case *net.IPAddr:
|
||||||
|
ip = v.IP
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ip = ip.To4()
|
||||||
|
if ip == nil || !ip.IsPrivate() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return ip.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|||||||
@@ -24,5 +24,27 @@
|
|||||||
"password": "your-password-3",
|
"password": "your-password-3",
|
||||||
"token": ""
|
"token": ""
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
}
|
"model_aliases": {
|
||||||
|
"gpt-4o": "deepseek-chat",
|
||||||
|
"gpt-5-codex": "deepseek-reasoner",
|
||||||
|
"o3": "deepseek-reasoner"
|
||||||
|
},
|
||||||
|
"compat": {
|
||||||
|
"wide_input_strict_output": true
|
||||||
|
},
|
||||||
|
"toolcall": {
|
||||||
|
"mode": "feature_match",
|
||||||
|
"early_emit_confidence": "high"
|
||||||
|
},
|
||||||
|
"responses": {
|
||||||
|
"store_ttl_seconds": 900
|
||||||
|
},
|
||||||
|
"embeddings": {
|
||||||
|
"provider": "deterministic"
|
||||||
|
},
|
||||||
|
"claude_model_mapping": {
|
||||||
|
"fast": "deepseek-chat",
|
||||||
|
"slow": "deepseek-reasoner"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,302 +0,0 @@
|
|||||||
package account
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"os"
|
|
||||||
"sort"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"ds2api/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Pool struct {
|
|
||||||
store *config.Store
|
|
||||||
mu sync.Mutex
|
|
||||||
queue []string
|
|
||||||
inUse map[string]int
|
|
||||||
waiters []chan struct{}
|
|
||||||
maxInflightPerAccount int
|
|
||||||
recommendedConcurrency int
|
|
||||||
maxQueueSize int
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPool(store *config.Store) *Pool {
|
|
||||||
p := &Pool{
|
|
||||||
store: store,
|
|
||||||
inUse: map[string]int{},
|
|
||||||
maxInflightPerAccount: maxInflightFromEnv(),
|
|
||||||
}
|
|
||||||
p.Reset()
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Pool) Reset() {
|
|
||||||
accounts := p.store.Accounts()
|
|
||||||
sort.SliceStable(accounts, func(i, j int) bool {
|
|
||||||
iHas := accounts[i].Token != ""
|
|
||||||
jHas := accounts[j].Token != ""
|
|
||||||
if iHas == jHas {
|
|
||||||
return i < j
|
|
||||||
}
|
|
||||||
return iHas
|
|
||||||
})
|
|
||||||
ids := make([]string, 0, len(accounts))
|
|
||||||
for _, a := range accounts {
|
|
||||||
id := a.Identifier()
|
|
||||||
if id != "" {
|
|
||||||
ids = append(ids, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
recommended := defaultRecommendedConcurrency(len(ids), p.maxInflightPerAccount)
|
|
||||||
queueLimit := maxQueueFromEnv(recommended)
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
p.drainWaitersLocked()
|
|
||||||
p.queue = ids
|
|
||||||
p.inUse = map[string]int{}
|
|
||||||
p.recommendedConcurrency = recommended
|
|
||||||
p.maxQueueSize = queueLimit
|
|
||||||
config.Logger.Info(
|
|
||||||
"[init_account_queue] initialized",
|
|
||||||
"total", len(ids),
|
|
||||||
"max_inflight_per_account", p.maxInflightPerAccount,
|
|
||||||
"recommended_concurrency", p.recommendedConcurrency,
|
|
||||||
"max_queue_size", p.maxQueueSize,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Pool) Acquire(target string, exclude map[string]bool) (config.Account, bool) {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
return p.acquireLocked(target, normalizeExclude(exclude))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Pool) AcquireWait(ctx context.Context, target string, exclude map[string]bool) (config.Account, bool) {
|
|
||||||
if ctx == nil {
|
|
||||||
ctx = context.Background()
|
|
||||||
}
|
|
||||||
exclude = normalizeExclude(exclude)
|
|
||||||
for {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return config.Account{}, false
|
|
||||||
}
|
|
||||||
|
|
||||||
p.mu.Lock()
|
|
||||||
if acc, ok := p.acquireLocked(target, exclude); ok {
|
|
||||||
p.mu.Unlock()
|
|
||||||
return acc, true
|
|
||||||
}
|
|
||||||
if !p.canQueueLocked(target, exclude) {
|
|
||||||
p.mu.Unlock()
|
|
||||||
return config.Account{}, false
|
|
||||||
}
|
|
||||||
waiter := make(chan struct{})
|
|
||||||
p.waiters = append(p.waiters, waiter)
|
|
||||||
p.mu.Unlock()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
p.mu.Lock()
|
|
||||||
p.removeWaiterLocked(waiter)
|
|
||||||
p.mu.Unlock()
|
|
||||||
return config.Account{}, false
|
|
||||||
case <-waiter:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Account, bool) {
|
|
||||||
if target != "" {
|
|
||||||
if exclude[target] || p.inUse[target] >= p.maxInflightPerAccount {
|
|
||||||
return config.Account{}, false
|
|
||||||
}
|
|
||||||
acc, ok := p.store.FindAccount(target)
|
|
||||||
if !ok {
|
|
||||||
return config.Account{}, false
|
|
||||||
}
|
|
||||||
p.inUse[target]++
|
|
||||||
p.bumpQueue(target)
|
|
||||||
return acc, true
|
|
||||||
}
|
|
||||||
|
|
||||||
if acc, ok := p.tryAcquire(exclude, true); ok {
|
|
||||||
return acc, true
|
|
||||||
}
|
|
||||||
if acc, ok := p.tryAcquire(exclude, false); ok {
|
|
||||||
return acc, true
|
|
||||||
}
|
|
||||||
return config.Account{}, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Pool) tryAcquire(exclude map[string]bool, requireToken bool) (config.Account, bool) {
|
|
||||||
for i := 0; i < len(p.queue); i++ {
|
|
||||||
id := p.queue[i]
|
|
||||||
if exclude[id] || p.inUse[id] >= p.maxInflightPerAccount {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
acc, ok := p.store.FindAccount(id)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if requireToken && acc.Token == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
p.inUse[id]++
|
|
||||||
p.bumpQueue(id)
|
|
||||||
return acc, true
|
|
||||||
}
|
|
||||||
return config.Account{}, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Pool) bumpQueue(accountID string) {
|
|
||||||
for i, id := range p.queue {
|
|
||||||
if id != accountID {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
p.queue = append(p.queue[:i], p.queue[i+1:]...)
|
|
||||||
p.queue = append(p.queue, accountID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Pool) Release(accountID string) {
|
|
||||||
if accountID == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
count := p.inUse[accountID]
|
|
||||||
if count <= 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if count == 1 {
|
|
||||||
delete(p.inUse, accountID)
|
|
||||||
p.notifyWaiterLocked()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
p.inUse[accountID] = count - 1
|
|
||||||
p.notifyWaiterLocked()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Pool) Status() map[string]any {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
available := make([]string, 0, len(p.queue))
|
|
||||||
inUseAccounts := make([]string, 0, len(p.inUse))
|
|
||||||
inUseSlots := 0
|
|
||||||
for _, id := range p.queue {
|
|
||||||
if p.inUse[id] < p.maxInflightPerAccount {
|
|
||||||
available = append(available, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for id, count := range p.inUse {
|
|
||||||
if count > 0 {
|
|
||||||
inUseAccounts = append(inUseAccounts, id)
|
|
||||||
inUseSlots += count
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sort.Strings(inUseAccounts)
|
|
||||||
return map[string]any{
|
|
||||||
"available": len(available),
|
|
||||||
"in_use": inUseSlots,
|
|
||||||
"total": len(p.store.Accounts()),
|
|
||||||
"available_accounts": available,
|
|
||||||
"in_use_accounts": inUseAccounts,
|
|
||||||
"max_inflight_per_account": p.maxInflightPerAccount,
|
|
||||||
"recommended_concurrency": p.recommendedConcurrency,
|
|
||||||
"waiting": len(p.waiters),
|
|
||||||
"max_queue_size": p.maxQueueSize,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func maxInflightFromEnv() int {
|
|
||||||
for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} {
|
|
||||||
raw := strings.TrimSpace(os.Getenv(key))
|
|
||||||
if raw == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
n, err := strconv.Atoi(raw)
|
|
||||||
if err == nil && n > 0 {
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 2
|
|
||||||
}
|
|
||||||
|
|
||||||
func defaultRecommendedConcurrency(accountCount, maxInflightPerAccount int) int {
|
|
||||||
if accountCount <= 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
if maxInflightPerAccount <= 0 {
|
|
||||||
maxInflightPerAccount = 2
|
|
||||||
}
|
|
||||||
return accountCount * maxInflightPerAccount
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeExclude(exclude map[string]bool) map[string]bool {
|
|
||||||
if exclude == nil {
|
|
||||||
return map[string]bool{}
|
|
||||||
}
|
|
||||||
return exclude
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Pool) canQueueLocked(target string, exclude map[string]bool) bool {
|
|
||||||
if target != "" {
|
|
||||||
if exclude[target] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if _, ok := p.store.FindAccount(target); !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if p.maxQueueSize <= 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return len(p.waiters) < p.maxQueueSize
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Pool) notifyWaiterLocked() {
|
|
||||||
if len(p.waiters) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
waiter := p.waiters[0]
|
|
||||||
p.waiters = p.waiters[1:]
|
|
||||||
close(waiter)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Pool) removeWaiterLocked(waiter chan struct{}) bool {
|
|
||||||
for i, w := range p.waiters {
|
|
||||||
if w != waiter {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
p.waiters = append(p.waiters[:i], p.waiters[i+1:]...)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Pool) drainWaitersLocked() {
|
|
||||||
for _, waiter := range p.waiters {
|
|
||||||
close(waiter)
|
|
||||||
}
|
|
||||||
p.waiters = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func maxQueueFromEnv(defaultSize int) int {
|
|
||||||
for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} {
|
|
||||||
raw := strings.TrimSpace(os.Getenv(key))
|
|
||||||
if raw == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
n, err := strconv.Atoi(raw)
|
|
||||||
if err == nil && n >= 0 {
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if defaultSize < 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return defaultSize
|
|
||||||
}
|
|
||||||
108
internal/account/pool_acquire.go
Normal file
108
internal/account/pool_acquire.go
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
package account
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"ds2api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *Pool) Acquire(target string, exclude map[string]bool) (config.Account, bool) {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
return p.acquireLocked(target, normalizeExclude(exclude))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pool) AcquireWait(ctx context.Context, target string, exclude map[string]bool) (config.Account, bool) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
exclude = normalizeExclude(exclude)
|
||||||
|
for {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return config.Account{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
p.mu.Lock()
|
||||||
|
if acc, ok := p.acquireLocked(target, exclude); ok {
|
||||||
|
p.mu.Unlock()
|
||||||
|
return acc, true
|
||||||
|
}
|
||||||
|
if !p.canQueueLocked(target, exclude) {
|
||||||
|
p.mu.Unlock()
|
||||||
|
return config.Account{}, false
|
||||||
|
}
|
||||||
|
waiter := make(chan struct{})
|
||||||
|
p.waiters = append(p.waiters, waiter)
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
p.mu.Lock()
|
||||||
|
p.removeWaiterLocked(waiter)
|
||||||
|
p.mu.Unlock()
|
||||||
|
return config.Account{}, false
|
||||||
|
case <-waiter:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Account, bool) {
|
||||||
|
if target != "" {
|
||||||
|
if exclude[target] || !p.canAcquireIDLocked(target) {
|
||||||
|
return config.Account{}, false
|
||||||
|
}
|
||||||
|
acc, ok := p.store.FindAccount(target)
|
||||||
|
if !ok {
|
||||||
|
return config.Account{}, false
|
||||||
|
}
|
||||||
|
p.inUse[target]++
|
||||||
|
p.bumpQueue(target)
|
||||||
|
return acc, true
|
||||||
|
}
|
||||||
|
|
||||||
|
if acc, ok := p.tryAcquire(exclude, true); ok {
|
||||||
|
return acc, true
|
||||||
|
}
|
||||||
|
if acc, ok := p.tryAcquire(exclude, false); ok {
|
||||||
|
return acc, true
|
||||||
|
}
|
||||||
|
return config.Account{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pool) tryAcquire(exclude map[string]bool, requireToken bool) (config.Account, bool) {
|
||||||
|
for i := 0; i < len(p.queue); i++ {
|
||||||
|
id := p.queue[i]
|
||||||
|
if exclude[id] || !p.canAcquireIDLocked(id) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
acc, ok := p.store.FindAccount(id)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if requireToken && acc.Token == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
p.inUse[id]++
|
||||||
|
p.bumpQueue(id)
|
||||||
|
return acc, true
|
||||||
|
}
|
||||||
|
return config.Account{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pool) bumpQueue(accountID string) {
|
||||||
|
for i, id := range p.queue {
|
||||||
|
if id != accountID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
p.queue = append(p.queue[:i], p.queue[i+1:]...)
|
||||||
|
p.queue = append(p.queue, accountID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeExclude(exclude map[string]bool) map[string]bool {
|
||||||
|
if exclude == nil {
|
||||||
|
return map[string]bool{}
|
||||||
|
}
|
||||||
|
return exclude
|
||||||
|
}
|
||||||
132
internal/account/pool_core.go
Normal file
132
internal/account/pool_core.go
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
package account
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"ds2api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Pool struct {
|
||||||
|
store *config.Store
|
||||||
|
mu sync.Mutex
|
||||||
|
queue []string
|
||||||
|
inUse map[string]int
|
||||||
|
waiters []chan struct{}
|
||||||
|
maxInflightPerAccount int
|
||||||
|
recommendedConcurrency int
|
||||||
|
maxQueueSize int
|
||||||
|
globalMaxInflight int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPool(store *config.Store) *Pool {
|
||||||
|
maxPer := 2
|
||||||
|
if store != nil {
|
||||||
|
maxPer = store.RuntimeAccountMaxInflight()
|
||||||
|
}
|
||||||
|
p := &Pool{
|
||||||
|
store: store,
|
||||||
|
inUse: map[string]int{},
|
||||||
|
maxInflightPerAccount: maxPer,
|
||||||
|
}
|
||||||
|
p.Reset()
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pool) Reset() {
|
||||||
|
accounts := p.store.Accounts()
|
||||||
|
sort.SliceStable(accounts, func(i, j int) bool {
|
||||||
|
iHas := accounts[i].Token != ""
|
||||||
|
jHas := accounts[j].Token != ""
|
||||||
|
if iHas == jHas {
|
||||||
|
return i < j
|
||||||
|
}
|
||||||
|
return iHas
|
||||||
|
})
|
||||||
|
ids := make([]string, 0, len(accounts))
|
||||||
|
for _, a := range accounts {
|
||||||
|
id := a.Identifier()
|
||||||
|
if id != "" {
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if p.store != nil {
|
||||||
|
p.maxInflightPerAccount = p.store.RuntimeAccountMaxInflight()
|
||||||
|
} else {
|
||||||
|
p.maxInflightPerAccount = maxInflightFromEnv()
|
||||||
|
}
|
||||||
|
recommended := defaultRecommendedConcurrency(len(ids), p.maxInflightPerAccount)
|
||||||
|
queueLimit := maxQueueFromEnv(recommended)
|
||||||
|
globalLimit := recommended
|
||||||
|
if p.store != nil {
|
||||||
|
queueLimit = p.store.RuntimeAccountMaxQueue(recommended)
|
||||||
|
globalLimit = p.store.RuntimeGlobalMaxInflight(recommended)
|
||||||
|
}
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
p.drainWaitersLocked()
|
||||||
|
p.queue = ids
|
||||||
|
p.inUse = map[string]int{}
|
||||||
|
p.recommendedConcurrency = recommended
|
||||||
|
p.maxQueueSize = queueLimit
|
||||||
|
p.globalMaxInflight = globalLimit
|
||||||
|
config.Logger.Info(
|
||||||
|
"[init_account_queue] initialized",
|
||||||
|
"total", len(ids),
|
||||||
|
"max_inflight_per_account", p.maxInflightPerAccount,
|
||||||
|
"global_max_inflight", p.globalMaxInflight,
|
||||||
|
"recommended_concurrency", p.recommendedConcurrency,
|
||||||
|
"max_queue_size", p.maxQueueSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pool) Release(accountID string) {
|
||||||
|
if accountID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
count := p.inUse[accountID]
|
||||||
|
if count <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if count == 1 {
|
||||||
|
delete(p.inUse, accountID)
|
||||||
|
p.notifyWaiterLocked()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.inUse[accountID] = count - 1
|
||||||
|
p.notifyWaiterLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pool) Status() map[string]any {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
available := make([]string, 0, len(p.queue))
|
||||||
|
inUseAccounts := make([]string, 0, len(p.inUse))
|
||||||
|
inUseSlots := 0
|
||||||
|
for _, id := range p.queue {
|
||||||
|
if p.inUse[id] < p.maxInflightPerAccount {
|
||||||
|
available = append(available, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for id, count := range p.inUse {
|
||||||
|
if count > 0 {
|
||||||
|
inUseAccounts = append(inUseAccounts, id)
|
||||||
|
inUseSlots += count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Strings(inUseAccounts)
|
||||||
|
return map[string]any{
|
||||||
|
"available": len(available),
|
||||||
|
"in_use": inUseSlots,
|
||||||
|
"total": len(p.store.Accounts()),
|
||||||
|
"available_accounts": available,
|
||||||
|
"in_use_accounts": inUseAccounts,
|
||||||
|
"max_inflight_per_account": p.maxInflightPerAccount,
|
||||||
|
"global_max_inflight": p.globalMaxInflight,
|
||||||
|
"recommended_concurrency": p.recommendedConcurrency,
|
||||||
|
"waiting": len(p.waiters),
|
||||||
|
"max_queue_size": p.maxQueueSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
249
internal/account/pool_edge_test.go
Normal file
249
internal/account/pool_edge_test.go
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
package account
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"ds2api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ─── Pool edge cases ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestPoolEmptyNoAccounts(t *testing.T) {
|
||||||
|
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "2")
|
||||||
|
t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "")
|
||||||
|
t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "")
|
||||||
|
t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "")
|
||||||
|
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`)
|
||||||
|
pool := NewPool(config.LoadStore())
|
||||||
|
if _, ok := pool.Acquire("", nil); ok {
|
||||||
|
t.Fatal("expected acquire to fail with no accounts")
|
||||||
|
}
|
||||||
|
status := pool.Status()
|
||||||
|
if total, ok := status["total"].(int); !ok || total != 0 {
|
||||||
|
t.Fatalf("unexpected total: %#v", status["total"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolReleaseNonExistentAccount(t *testing.T) {
|
||||||
|
pool := newPoolForTest(t, "2")
|
||||||
|
pool.Release("nonexistent@example.com") // should not panic
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolReleaseAlreadyReleased(t *testing.T) {
|
||||||
|
pool := newPoolForTest(t, "2")
|
||||||
|
acc, ok := pool.Acquire("", nil)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected acquire success")
|
||||||
|
}
|
||||||
|
pool.Release(acc.Identifier())
|
||||||
|
pool.Release(acc.Identifier()) // double release should not panic
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolAcquireTargetNotFound(t *testing.T) {
|
||||||
|
pool := newPoolForTest(t, "2")
|
||||||
|
if _, ok := pool.Acquire("nonexistent@example.com", nil); ok {
|
||||||
|
t.Fatal("expected acquire to fail for non-existent target")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolAcquireWithExclusionList(t *testing.T) {
|
||||||
|
pool := newPoolForTest(t, "2")
|
||||||
|
acc, ok := pool.Acquire("", map[string]bool{"acc1@example.com": true})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected acquire success with exclusion")
|
||||||
|
}
|
||||||
|
if acc.Identifier() != "acc2@example.com" {
|
||||||
|
t.Fatalf("expected acc2 when acc1 excluded, got %q", acc.Identifier())
|
||||||
|
}
|
||||||
|
pool.Release(acc.Identifier())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolAcquireAllExcluded(t *testing.T) {
|
||||||
|
pool := newPoolForTest(t, "2")
|
||||||
|
if _, ok := pool.Acquire("", map[string]bool{
|
||||||
|
"acc1@example.com": true,
|
||||||
|
"acc2@example.com": true,
|
||||||
|
}); ok {
|
||||||
|
t.Fatal("expected acquire to fail when all accounts excluded")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolStatusFields(t *testing.T) {
|
||||||
|
pool := newPoolForTest(t, "2")
|
||||||
|
status := pool.Status()
|
||||||
|
|
||||||
|
// Check all expected fields are present
|
||||||
|
for _, key := range []string{"total", "available", "max_inflight_per_account", "recommended_concurrency", "available_accounts", "in_use_accounts", "waiting", "max_queue_size"} {
|
||||||
|
if _, ok := status[key]; !ok {
|
||||||
|
t.Fatalf("missing status field: %s", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolStatusAccountDetails(t *testing.T) {
|
||||||
|
pool := newPoolForTest(t, "2")
|
||||||
|
acc, _ := pool.Acquire("acc1@example.com", nil)
|
||||||
|
|
||||||
|
status := pool.Status()
|
||||||
|
inUseAccounts, ok := status["in_use_accounts"].([]string)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("unexpected in_use_accounts type: %T", status["in_use_accounts"])
|
||||||
|
}
|
||||||
|
found := false
|
||||||
|
for _, id := range inUseAccounts {
|
||||||
|
if id == "acc1@example.com" {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected acc1 in in_use_accounts, got %v", inUseAccounts)
|
||||||
|
}
|
||||||
|
if status["in_use"] != 1 {
|
||||||
|
t.Fatalf("expected 1 in_use, got %v", status["in_use"])
|
||||||
|
}
|
||||||
|
|
||||||
|
pool.Release(acc.Identifier())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolAcquireWaitContextCancelled(t *testing.T) {
|
||||||
|
pool := newSingleAccountPoolForTest(t, "1")
|
||||||
|
// Exhaust the pool
|
||||||
|
first, ok := pool.Acquire("", nil)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected first acquire to succeed")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
var waitOK bool
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, waitOK = pool.AcquireWait(ctx, "", nil)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait until queued
|
||||||
|
waitForWaitingCount(t, pool, 1)
|
||||||
|
|
||||||
|
// Cancel context
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
if waitOK {
|
||||||
|
t.Fatal("expected acquire to fail after context cancellation")
|
||||||
|
}
|
||||||
|
|
||||||
|
pool.Release(first.Identifier())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolAcquireWaitTargetAccount(t *testing.T) {
|
||||||
|
pool := newPoolForTest(t, "1")
|
||||||
|
// Exhaust acc1
|
||||||
|
acc1, ok := pool.Acquire("acc1@example.com", nil)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected acquire acc1 success")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Acquire acc2 directly (should succeed since acc2 is free)
|
||||||
|
ctx := context.Background()
|
||||||
|
acc2, ok := pool.AcquireWait(ctx, "acc2@example.com", nil)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected acquire acc2 success via AcquireWait")
|
||||||
|
}
|
||||||
|
if acc2.Identifier() != "acc2@example.com" {
|
||||||
|
t.Fatalf("expected acc2, got %q", acc2.Identifier())
|
||||||
|
}
|
||||||
|
|
||||||
|
pool.Release(acc1.Identifier())
|
||||||
|
pool.Release(acc2.Identifier())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolMaxQueueSizeOverride(t *testing.T) {
|
||||||
|
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1")
|
||||||
|
t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "")
|
||||||
|
t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "5")
|
||||||
|
t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "")
|
||||||
|
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"acc1@example.com","token":"t1"}]}`)
|
||||||
|
pool := NewPool(config.LoadStore())
|
||||||
|
status := pool.Status()
|
||||||
|
if got, ok := status["max_queue_size"].(int); !ok || got != 5 {
|
||||||
|
t.Fatalf("expected max_queue_size=5, got %#v", status["max_queue_size"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolQueueSizeAliasEnv(t *testing.T) {
|
||||||
|
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1")
|
||||||
|
t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "")
|
||||||
|
t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "")
|
||||||
|
t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "7")
|
||||||
|
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"acc1@example.com","token":"t1"}]}`)
|
||||||
|
pool := NewPool(config.LoadStore())
|
||||||
|
status := pool.Status()
|
||||||
|
if got, ok := status["max_queue_size"].(int); !ok || got != 7 {
|
||||||
|
t.Fatalf("expected max_queue_size=7, got %#v", status["max_queue_size"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolMultipleAcquireReleaseCycles(t *testing.T) {
|
||||||
|
pool := newSingleAccountPoolForTest(t, "1")
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
acc, ok := pool.Acquire("", nil)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("acquire failed at cycle %d", i)
|
||||||
|
}
|
||||||
|
pool.Release(acc.Identifier())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPoolConcurrentAcquireWait(t *testing.T) {
|
||||||
|
pool := newSingleAccountPoolForTest(t, "1")
|
||||||
|
first, ok := pool.Acquire("", nil)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected first acquire success")
|
||||||
|
}
|
||||||
|
|
||||||
|
const waiters = 3
|
||||||
|
results := make(chan bool, waiters)
|
||||||
|
|
||||||
|
for i := 0; i < waiters; i++ {
|
||||||
|
go func() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_, ok := pool.AcquireWait(ctx, "", nil)
|
||||||
|
results <- ok
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all to be queued (only 1 can queue)
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Release and allow queued requests to proceed
|
||||||
|
pool.Release(first.Identifier())
|
||||||
|
|
||||||
|
successCount := 0
|
||||||
|
timeoutCount := 0
|
||||||
|
for i := 0; i < waiters; i++ {
|
||||||
|
select {
|
||||||
|
case ok := <-results:
|
||||||
|
if ok {
|
||||||
|
successCount++
|
||||||
|
// Release for next waiter
|
||||||
|
pool.Release("acc1@example.com")
|
||||||
|
} else {
|
||||||
|
timeoutCount++
|
||||||
|
}
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for results")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// At least 1 should succeed; 2 may fail due to queue limit
|
||||||
|
if successCount < 1 {
|
||||||
|
t.Fatalf("expected at least 1 success, got success=%d timeout=%d", successCount, timeoutCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
91
internal/account/pool_limits.go
Normal file
91
internal/account/pool_limits.go
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
package account
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *Pool) ApplyRuntimeLimits(maxInflightPerAccount, maxQueueSize, globalMaxInflight int) {
|
||||||
|
if maxInflightPerAccount <= 0 {
|
||||||
|
maxInflightPerAccount = 1
|
||||||
|
}
|
||||||
|
if maxQueueSize < 0 {
|
||||||
|
maxQueueSize = 0
|
||||||
|
}
|
||||||
|
if globalMaxInflight <= 0 {
|
||||||
|
globalMaxInflight = maxInflightPerAccount * len(p.store.Accounts())
|
||||||
|
if globalMaxInflight <= 0 {
|
||||||
|
globalMaxInflight = maxInflightPerAccount
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
p.maxInflightPerAccount = maxInflightPerAccount
|
||||||
|
p.maxQueueSize = maxQueueSize
|
||||||
|
p.globalMaxInflight = globalMaxInflight
|
||||||
|
p.recommendedConcurrency = defaultRecommendedConcurrency(len(p.queue), p.maxInflightPerAccount)
|
||||||
|
p.notifyWaiterLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
func maxInflightFromEnv() int {
|
||||||
|
for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} {
|
||||||
|
raw := strings.TrimSpace(os.Getenv(key))
|
||||||
|
if raw == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
n, err := strconv.Atoi(raw)
|
||||||
|
if err == nil && n > 0 {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultRecommendedConcurrency returns the suggested total concurrency:
// the number of accounts times the per-account in-flight limit. Zero or
// fewer accounts yields 0; a non-positive per-account limit is treated as 2.
func defaultRecommendedConcurrency(accountCount, maxInflightPerAccount int) int {
	switch {
	case accountCount <= 0:
		return 0
	case maxInflightPerAccount <= 0:
		return accountCount * 2
	default:
		return accountCount * maxInflightPerAccount
	}
}
|
||||||
|
|
||||||
|
func maxQueueFromEnv(defaultSize int) int {
|
||||||
|
for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} {
|
||||||
|
raw := strings.TrimSpace(os.Getenv(key))
|
||||||
|
if raw == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
n, err := strconv.Atoi(raw)
|
||||||
|
if err == nil && n >= 0 {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if defaultSize < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return defaultSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pool) canAcquireIDLocked(accountID string) bool {
|
||||||
|
if accountID == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if p.inUse[accountID] >= p.maxInflightPerAccount {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if p.globalMaxInflight > 0 && p.currentInUseLocked() >= p.globalMaxInflight {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pool) currentInUseLocked() int {
|
||||||
|
total := 0
|
||||||
|
for _, n := range p.inUse {
|
||||||
|
total += n
|
||||||
|
}
|
||||||
|
return total
|
||||||
|
}
|
||||||
43
internal/account/pool_waiters.go
Normal file
43
internal/account/pool_waiters.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package account
|
||||||
|
|
||||||
|
func (p *Pool) canQueueLocked(target string, exclude map[string]bool) bool {
|
||||||
|
if target != "" {
|
||||||
|
if exclude[target] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if _, ok := p.store.FindAccount(target); !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if p.maxQueueSize <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return len(p.waiters) < p.maxQueueSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pool) notifyWaiterLocked() {
|
||||||
|
if len(p.waiters) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
waiter := p.waiters[0]
|
||||||
|
p.waiters = p.waiters[1:]
|
||||||
|
close(waiter)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pool) removeWaiterLocked(waiter chan struct{}) bool {
|
||||||
|
for i, w := range p.waiters {
|
||||||
|
if w != waiter {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
p.waiters = append(p.waiters[:i], p.waiters[i+1:]...)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pool) drainWaitersLocked() {
|
||||||
|
for _, waiter := range p.waiters {
|
||||||
|
close(waiter)
|
||||||
|
}
|
||||||
|
p.waiters = nil
|
||||||
|
}
|
||||||
11
internal/adapter/claude/convert.go
Normal file
11
internal/adapter/claude/convert.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ds2api/internal/claudeconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultClaudeModel is the fallback Claude model name passed to the
// converter when the incoming request does not specify a mapped model.
const defaultClaudeModel = "claude-sonnet-4-5"

// convertClaudeToDeepSeek translates an Anthropic-style request body into the
// DeepSeek payload shape via the claudeconv package, using the model mapping
// exposed by the given ConfigReader. This wrapper only supplies the
// adapter-local default model.
func convertClaudeToDeepSeek(claudeReq map[string]any, store ConfigReader) map[string]any {
	return claudeconv.ConvertClaudeToDeepSeek(claudeReq, store, defaultClaudeModel)
}
|
||||||
29
internal/adapter/claude/deps.go
Normal file
29
internal/adapter/claude/deps.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"ds2api/internal/auth"
|
||||||
|
"ds2api/internal/config"
|
||||||
|
"ds2api/internal/deepseek"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthResolver resolves an incoming request to pooled credentials and
// releases them once the request is finished.
type AuthResolver interface {
	Determine(req *http.Request) (*auth.RequestAuth, error)
	Release(a *auth.RequestAuth)
}

// DeepSeekCaller is the subset of the DeepSeek client this adapter needs:
// session creation, proof-of-work retrieval, and the completion call itself.
// Each method retries up to maxAttempts times.
type DeepSeekCaller interface {
	CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
	GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
	CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
}

// ConfigReader exposes the Claude→DeepSeek model name mapping.
type ConfigReader interface {
	ClaudeMapping() map[string]string
}

// Compile-time checks that the concrete implementations satisfy the
// adapter-local interfaces above.
var _ AuthResolver = (*auth.Resolver)(nil)
var _ DeepSeekCaller = (*deepseek.Client)(nil)
var _ ConfigReader = (*config.Store)(nil)
|
||||||
33
internal/adapter/claude/deps_injection_test.go
Normal file
33
internal/adapter/claude/deps_injection_test.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
// mockClaudeConfig is a minimal ConfigReader stub backed by a plain map.
type mockClaudeConfig struct {
	m map[string]string
}

// ClaudeMapping returns the stubbed Claude→DeepSeek model mapping.
func (m mockClaudeConfig) ClaudeMapping() map[string]string { return m.m }

// TestNormalizeClaudeRequestUsesConfigInterfaceMapping checks that request
// normalization consults the injected ConfigReader mapping (not a concrete
// config type) and propagates the resolved model's thinking/search flags.
func TestNormalizeClaudeRequestUsesConfigInterfaceMapping(t *testing.T) {
	req := map[string]any{
		"model": "claude-opus-4-6",
		"messages": []any{
			map[string]any{"role": "user", "content": "hello"},
		},
	}
	out, err := normalizeClaudeRequest(mockClaudeConfig{
		m: map[string]string{
			"fast": "deepseek-chat",
			"slow": "deepseek-reasoner-search",
		},
	}, req)
	if err != nil {
		t.Fatalf("normalizeClaudeRequest error: %v", err)
	}
	// "claude-opus-4-6" is expected to resolve through the "slow" mapping.
	// NOTE(review): the opus→slow routing happens inside
	// normalizeClaudeRequest, which is not visible here — confirm there.
	if out.Standard.ResolvedModel != "deepseek-reasoner-search" {
		t.Fatalf("resolved model mismatch: got=%q", out.Standard.ResolvedModel)
	}
	if !out.Standard.Thinking || !out.Standard.Search {
		t.Fatalf("unexpected flags: thinking=%v search=%v", out.Standard.Thinking, out.Standard.Search)
	}
}
|
||||||
34
internal/adapter/claude/error_shape_test.go
Normal file
34
internal/adapter/claude/error_shape_test.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWriteClaudeErrorIncludesUnifiedFields(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
writeClaudeError(rec, http.StatusUnauthorized, "bad token")
|
||||||
|
if rec.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("expected 401, got %d", rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||||
|
t.Fatalf("decode body: %v", err)
|
||||||
|
}
|
||||||
|
errObj, _ := body["error"].(map[string]any)
|
||||||
|
if errObj["message"] != "bad token" {
|
||||||
|
t.Fatalf("unexpected message: %v", errObj["message"])
|
||||||
|
}
|
||||||
|
if errObj["type"] != "invalid_request_error" {
|
||||||
|
t.Fatalf("unexpected type: %v", errObj["type"])
|
||||||
|
}
|
||||||
|
if errObj["code"] != "authentication_failed" {
|
||||||
|
t.Fatalf("unexpected code: %v", errObj["code"])
|
||||||
|
}
|
||||||
|
if _, ok := errObj["param"]; !ok {
|
||||||
|
t.Fatal("expected param field")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,603 +0,0 @@
|
|||||||
package claude
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
|
|
||||||
"ds2api/internal/auth"
|
|
||||||
"ds2api/internal/config"
|
|
||||||
"ds2api/internal/deepseek"
|
|
||||||
"ds2api/internal/sse"
|
|
||||||
"ds2api/internal/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
// writeJSON is a package-internal alias to avoid mass-renaming all call-sites.
|
|
||||||
var writeJSON = util.WriteJSON
|
|
||||||
|
|
||||||
type Handler struct {
|
|
||||||
Store *config.Store
|
|
||||||
Auth *auth.Resolver
|
|
||||||
DS *deepseek.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
claudeStreamPingInterval = time.Duration(deepseek.KeepAliveTimeout) * time.Second
|
|
||||||
claudeStreamIdleTimeout = time.Duration(deepseek.StreamIdleTimeout) * time.Second
|
|
||||||
claudeStreamMaxKeepaliveCnt = deepseek.MaxKeepaliveCount
|
|
||||||
)
|
|
||||||
|
|
||||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
|
||||||
r.Get("/anthropic/v1/models", h.ListModels)
|
|
||||||
r.Post("/anthropic/v1/messages", h.Messages)
|
|
||||||
r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
|
|
||||||
writeJSON(w, http.StatusOK, config.ClaudeModelsResponse())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
|
||||||
a, err := h.Auth.Determine(r)
|
|
||||||
if err != nil {
|
|
||||||
status := http.StatusUnauthorized
|
|
||||||
detail := err.Error()
|
|
||||||
if err == auth.ErrNoAccount {
|
|
||||||
status = http.StatusTooManyRequests
|
|
||||||
}
|
|
||||||
writeJSON(w, status, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": detail}})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer h.Auth.Release(a)
|
|
||||||
|
|
||||||
var req map[string]any
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
||||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "invalid json"}})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
model, _ := req["model"].(string)
|
|
||||||
messagesRaw, _ := req["messages"].([]any)
|
|
||||||
if model == "" || len(messagesRaw) == 0 {
|
|
||||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "Request must include 'model' and 'messages'."}})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
normalized := normalizeClaudeMessages(messagesRaw)
|
|
||||||
payload := cloneMap(req)
|
|
||||||
payload["messages"] = normalized
|
|
||||||
toolsRequested, _ := req["tools"].([]any)
|
|
||||||
if len(toolsRequested) > 0 && !hasSystemMessage(normalized) {
|
|
||||||
payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalized...)
|
|
||||||
}
|
|
||||||
|
|
||||||
dsPayload := util.ConvertClaudeToDeepSeek(payload, h.Store)
|
|
||||||
dsModel, _ := dsPayload["model"].(string)
|
|
||||||
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel)
|
|
||||||
if !ok {
|
|
||||||
thinkingEnabled = false
|
|
||||||
searchEnabled = false
|
|
||||||
}
|
|
||||||
finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
|
|
||||||
|
|
||||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
|
||||||
if err != nil {
|
|
||||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "invalid token."}})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
|
||||||
if err != nil {
|
|
||||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get PoW"}})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
requestPayload := map[string]any{
|
|
||||||
"chat_session_id": sessionID,
|
|
||||||
"parent_message_id": nil,
|
|
||||||
"prompt": finalPrompt,
|
|
||||||
"ref_file_ids": []any{},
|
|
||||||
"thinking_enabled": thinkingEnabled,
|
|
||||||
"search_enabled": searchEnabled,
|
|
||||||
}
|
|
||||||
resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3)
|
|
||||||
if err != nil {
|
|
||||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get Claude response."}})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
defer resp.Body.Close()
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
|
||||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
toolNames := extractClaudeToolNames(toolsRequested)
|
|
||||||
if util.ToBool(req["stream"]) {
|
|
||||||
h.handleClaudeStreamRealtime(w, r, resp, model, normalized, thinkingEnabled, searchEnabled, toolNames)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
|
||||||
fullText := result.Text
|
|
||||||
fullThinking := result.Thinking
|
|
||||||
detected := util.ParseToolCalls(fullText, toolNames)
|
|
||||||
content := make([]map[string]any, 0, 4)
|
|
||||||
if fullThinking != "" {
|
|
||||||
content = append(content, map[string]any{"type": "thinking", "thinking": fullThinking})
|
|
||||||
}
|
|
||||||
stopReason := "end_turn"
|
|
||||||
if len(detected) > 0 {
|
|
||||||
stopReason = "tool_use"
|
|
||||||
for i, tc := range detected {
|
|
||||||
content = append(content, map[string]any{
|
|
||||||
"type": "tool_use",
|
|
||||||
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i),
|
|
||||||
"name": tc.Name,
|
|
||||||
"input": tc.Input,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if fullText == "" {
|
|
||||||
fullText = "抱歉,没有生成有效的响应内容。"
|
|
||||||
}
|
|
||||||
content = append(content, map[string]any{"type": "text", "text": fullText})
|
|
||||||
}
|
|
||||||
writeJSON(w, http.StatusOK, map[string]any{
|
|
||||||
"id": fmt.Sprintf("msg_%d", time.Now().UnixNano()),
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"model": model,
|
|
||||||
"content": content,
|
|
||||||
"stop_reason": stopReason,
|
|
||||||
"stop_sequence": nil,
|
|
||||||
"usage": map[string]any{
|
|
||||||
"input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalized)),
|
|
||||||
"output_tokens": util.EstimateTokens(fullThinking) + util.EstimateTokens(fullText),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
|
|
||||||
a, err := h.Auth.Determine(r)
|
|
||||||
if err != nil {
|
|
||||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer h.Auth.Release(a)
|
|
||||||
|
|
||||||
var req map[string]any
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
||||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "invalid json"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
model, _ := req["model"].(string)
|
|
||||||
messages, _ := req["messages"].([]any)
|
|
||||||
if model == "" || len(messages) == 0 {
|
|
||||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "Request must include 'model' and 'messages'."})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
inputTokens := 0
|
|
||||||
if sys, ok := req["system"].(string); ok {
|
|
||||||
inputTokens += util.EstimateTokens(sys)
|
|
||||||
}
|
|
||||||
for _, item := range messages {
|
|
||||||
msg, ok := item.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
inputTokens += 2
|
|
||||||
inputTokens += util.EstimateTokens(extractMessageContent(msg["content"]))
|
|
||||||
}
|
|
||||||
if tools, ok := req["tools"].([]any); ok {
|
|
||||||
for _, t := range tools {
|
|
||||||
b, _ := json.Marshal(t)
|
|
||||||
inputTokens += util.EstimateTokens(string(b))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if inputTokens < 1 {
|
|
||||||
inputTokens = 1
|
|
||||||
}
|
|
||||||
writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
|
||||||
defer resp.Body.Close()
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
|
||||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
|
||||||
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
|
||||||
w.Header().Set("Connection", "keep-alive")
|
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
|
||||||
rc := http.NewResponseController(w)
|
|
||||||
canFlush := rc.Flush() == nil
|
|
||||||
if !canFlush {
|
|
||||||
config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered")
|
|
||||||
}
|
|
||||||
send := func(event string, v any) {
|
|
||||||
b, _ := json.Marshal(v)
|
|
||||||
_, _ = w.Write([]byte("event: "))
|
|
||||||
_, _ = w.Write([]byte(event))
|
|
||||||
_, _ = w.Write([]byte("\n"))
|
|
||||||
_, _ = w.Write([]byte("data: "))
|
|
||||||
_, _ = w.Write(b)
|
|
||||||
_, _ = w.Write([]byte("\n\n"))
|
|
||||||
if canFlush {
|
|
||||||
_ = rc.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sendError := func(message string) {
|
|
||||||
msg := strings.TrimSpace(message)
|
|
||||||
if msg == "" {
|
|
||||||
msg = "upstream stream error"
|
|
||||||
}
|
|
||||||
send("error", map[string]any{
|
|
||||||
"type": "error",
|
|
||||||
"error": map[string]any{
|
|
||||||
"type": "api_error",
|
|
||||||
"message": msg,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
messageID := fmt.Sprintf("msg_%d", time.Now().UnixNano())
|
|
||||||
inputTokens := util.EstimateTokens(fmt.Sprintf("%v", messages))
|
|
||||||
send("message_start", map[string]any{
|
|
||||||
"type": "message_start",
|
|
||||||
"message": map[string]any{
|
|
||||||
"id": messageID,
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"model": model,
|
|
||||||
"content": []any{},
|
|
||||||
"stop_reason": nil,
|
|
||||||
"stop_sequence": nil,
|
|
||||||
"usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
initialType := "text"
|
|
||||||
if thinkingEnabled {
|
|
||||||
initialType = "thinking"
|
|
||||||
}
|
|
||||||
parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType)
|
|
||||||
bufferToolContent := len(toolNames) > 0
|
|
||||||
hasContent := false
|
|
||||||
lastContent := time.Now()
|
|
||||||
keepaliveCount := 0
|
|
||||||
|
|
||||||
thinking := strings.Builder{}
|
|
||||||
text := strings.Builder{}
|
|
||||||
|
|
||||||
nextBlockIndex := 0
|
|
||||||
thinkingBlockOpen := false
|
|
||||||
thinkingBlockIndex := -1
|
|
||||||
textBlockOpen := false
|
|
||||||
textBlockIndex := -1
|
|
||||||
ended := false
|
|
||||||
|
|
||||||
closeThinkingBlock := func() {
|
|
||||||
if !thinkingBlockOpen {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
send("content_block_stop", map[string]any{
|
|
||||||
"type": "content_block_stop",
|
|
||||||
"index": thinkingBlockIndex,
|
|
||||||
})
|
|
||||||
thinkingBlockOpen = false
|
|
||||||
thinkingBlockIndex = -1
|
|
||||||
}
|
|
||||||
closeTextBlock := func() {
|
|
||||||
if !textBlockOpen {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
send("content_block_stop", map[string]any{
|
|
||||||
"type": "content_block_stop",
|
|
||||||
"index": textBlockIndex,
|
|
||||||
})
|
|
||||||
textBlockOpen = false
|
|
||||||
textBlockIndex = -1
|
|
||||||
}
|
|
||||||
|
|
||||||
finalize := func(stopReason string) {
|
|
||||||
if ended {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ended = true
|
|
||||||
|
|
||||||
closeThinkingBlock()
|
|
||||||
closeTextBlock()
|
|
||||||
|
|
||||||
finalThinking := thinking.String()
|
|
||||||
finalText := text.String()
|
|
||||||
|
|
||||||
if bufferToolContent {
|
|
||||||
detected := util.ParseToolCalls(finalText, toolNames)
|
|
||||||
if len(detected) > 0 {
|
|
||||||
stopReason = "tool_use"
|
|
||||||
for i, tc := range detected {
|
|
||||||
idx := nextBlockIndex + i
|
|
||||||
send("content_block_start", map[string]any{
|
|
||||||
"type": "content_block_start",
|
|
||||||
"index": idx,
|
|
||||||
"content_block": map[string]any{
|
|
||||||
"type": "tool_use",
|
|
||||||
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx),
|
|
||||||
"name": tc.Name,
|
|
||||||
"input": tc.Input,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
send("content_block_stop", map[string]any{
|
|
||||||
"type": "content_block_stop",
|
|
||||||
"index": idx,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
nextBlockIndex += len(detected)
|
|
||||||
} else if finalText != "" {
|
|
||||||
idx := nextBlockIndex
|
|
||||||
nextBlockIndex++
|
|
||||||
send("content_block_start", map[string]any{
|
|
||||||
"type": "content_block_start",
|
|
||||||
"index": idx,
|
|
||||||
"content_block": map[string]any{
|
|
||||||
"type": "text",
|
|
||||||
"text": "",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
send("content_block_delta", map[string]any{
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": idx,
|
|
||||||
"delta": map[string]any{
|
|
||||||
"type": "text_delta",
|
|
||||||
"text": finalText,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
send("content_block_stop", map[string]any{
|
|
||||||
"type": "content_block_stop",
|
|
||||||
"index": idx,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
|
|
||||||
send("message_delta", map[string]any{
|
|
||||||
"type": "message_delta",
|
|
||||||
"delta": map[string]any{
|
|
||||||
"stop_reason": stopReason,
|
|
||||||
"stop_sequence": nil,
|
|
||||||
},
|
|
||||||
"usage": map[string]any{
|
|
||||||
"output_tokens": outputTokens,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
send("message_stop", map[string]any{"type": "message_stop"})
|
|
||||||
}
|
|
||||||
|
|
||||||
pingTicker := time.NewTicker(claudeStreamPingInterval)
|
|
||||||
defer pingTicker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-r.Context().Done():
|
|
||||||
return
|
|
||||||
case <-pingTicker.C:
|
|
||||||
if !hasContent {
|
|
||||||
keepaliveCount++
|
|
||||||
if keepaliveCount >= claudeStreamMaxKeepaliveCnt {
|
|
||||||
finalize("end_turn")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if hasContent && time.Since(lastContent) > claudeStreamIdleTimeout {
|
|
||||||
finalize("end_turn")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
send("ping", map[string]any{"type": "ping"})
|
|
||||||
case parsed, ok := <-parsedLines:
|
|
||||||
if !ok {
|
|
||||||
if err := <-done; err != nil {
|
|
||||||
sendError(err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
finalize("end_turn")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !parsed.Parsed {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if parsed.ErrorMessage != "" {
|
|
||||||
sendError(parsed.ErrorMessage)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if parsed.Stop {
|
|
||||||
finalize("end_turn")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, p := range parsed.Parts {
|
|
||||||
if p.Text == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if p.Type != "thinking" && searchEnabled && sse.IsCitation(p.Text) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
hasContent = true
|
|
||||||
lastContent = time.Now()
|
|
||||||
keepaliveCount = 0
|
|
||||||
|
|
||||||
if p.Type == "thinking" {
|
|
||||||
if !thinkingEnabled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
thinking.WriteString(p.Text)
|
|
||||||
closeTextBlock()
|
|
||||||
if !thinkingBlockOpen {
|
|
||||||
thinkingBlockIndex = nextBlockIndex
|
|
||||||
nextBlockIndex++
|
|
||||||
send("content_block_start", map[string]any{
|
|
||||||
"type": "content_block_start",
|
|
||||||
"index": thinkingBlockIndex,
|
|
||||||
"content_block": map[string]any{
|
|
||||||
"type": "thinking",
|
|
||||||
"thinking": "",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
thinkingBlockOpen = true
|
|
||||||
}
|
|
||||||
send("content_block_delta", map[string]any{
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": thinkingBlockIndex,
|
|
||||||
"delta": map[string]any{
|
|
||||||
"type": "thinking_delta",
|
|
||||||
"thinking": p.Text,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
text.WriteString(p.Text)
|
|
||||||
if bufferToolContent {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
closeThinkingBlock()
|
|
||||||
if !textBlockOpen {
|
|
||||||
textBlockIndex = nextBlockIndex
|
|
||||||
nextBlockIndex++
|
|
||||||
send("content_block_start", map[string]any{
|
|
||||||
"type": "content_block_start",
|
|
||||||
"index": textBlockIndex,
|
|
||||||
"content_block": map[string]any{
|
|
||||||
"type": "text",
|
|
||||||
"text": "",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
textBlockOpen = true
|
|
||||||
}
|
|
||||||
send("content_block_delta", map[string]any{
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": textBlockIndex,
|
|
||||||
"delta": map[string]any{
|
|
||||||
"type": "text_delta",
|
|
||||||
"text": p.Text,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeClaudeMessages(messages []any) []any {
|
|
||||||
out := make([]any, 0, len(messages))
|
|
||||||
for _, m := range messages {
|
|
||||||
msg, ok := m.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
copied := cloneMap(msg)
|
|
||||||
switch content := msg["content"].(type) {
|
|
||||||
case []any:
|
|
||||||
parts := make([]string, 0, len(content))
|
|
||||||
for _, block := range content {
|
|
||||||
b, ok := block.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
typeStr, _ := b["type"].(string)
|
|
||||||
if typeStr == "text" {
|
|
||||||
if t, ok := b["text"].(string); ok {
|
|
||||||
parts = append(parts, t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if typeStr == "tool_result" {
|
|
||||||
parts = append(parts, fmt.Sprintf("%v", b["content"]))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
copied["content"] = strings.Join(parts, "\n")
|
|
||||||
}
|
|
||||||
out = append(out, copied)
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildClaudeToolPrompt(tools []any) string {
|
|
||||||
parts := []string{"You are Claude, a helpful AI assistant. You have access to these tools:"}
|
|
||||||
for _, t := range tools {
|
|
||||||
m, ok := t.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
name, _ := m["name"].(string)
|
|
||||||
desc, _ := m["description"].(string)
|
|
||||||
schema, _ := json.Marshal(m["input_schema"])
|
|
||||||
parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
|
|
||||||
}
|
|
||||||
parts = append(parts, "When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}")
|
|
||||||
return strings.Join(parts, "\n\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasSystemMessage(messages []any) bool {
|
|
||||||
for _, m := range messages {
|
|
||||||
msg, ok := m.(map[string]any)
|
|
||||||
if ok && msg["role"] == "system" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractClaudeToolNames(tools []any) []string {
|
|
||||||
out := make([]string, 0, len(tools))
|
|
||||||
for _, t := range tools {
|
|
||||||
m, ok := t.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if name, ok := m["name"].(string); ok && name != "" {
|
|
||||||
out = append(out, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func toMessageMaps(v any) []map[string]any {
|
|
||||||
arr, ok := v.([]any)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
out := make([]map[string]any, 0, len(arr))
|
|
||||||
for _, item := range arr {
|
|
||||||
if m, ok := item.(map[string]any); ok {
|
|
||||||
out = append(out, m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractMessageContent(v any) string {
|
|
||||||
switch x := v.(type) {
|
|
||||||
case string:
|
|
||||||
return x
|
|
||||||
case []any:
|
|
||||||
parts := make([]string, 0, len(x))
|
|
||||||
for _, it := range x {
|
|
||||||
parts = append(parts, fmt.Sprintf("%v", it))
|
|
||||||
}
|
|
||||||
return strings.Join(parts, "\n")
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("%v", x)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func cloneMap(in map[string]any) map[string]any {
|
|
||||||
out := make(map[string]any, len(in))
|
|
||||||
for k, v := range in {
|
|
||||||
out[k] = v
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
25
internal/adapter/claude/handler_errors.go
Normal file
25
internal/adapter/claude/handler_errors.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
func writeClaudeError(w http.ResponseWriter, status int, message string) {
|
||||||
|
code := "invalid_request"
|
||||||
|
switch status {
|
||||||
|
case http.StatusUnauthorized:
|
||||||
|
code = "authentication_failed"
|
||||||
|
case http.StatusTooManyRequests:
|
||||||
|
code = "rate_limit_exceeded"
|
||||||
|
case http.StatusNotFound:
|
||||||
|
code = "not_found"
|
||||||
|
case http.StatusInternalServerError:
|
||||||
|
code = "internal_error"
|
||||||
|
}
|
||||||
|
writeJSON(w, status, map[string]any{
|
||||||
|
"error": map[string]any{
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"message": message,
|
||||||
|
"code": code,
|
||||||
|
"param": nil,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
134
internal/adapter/claude/handler_messages.go
Normal file
134
internal/adapter/claude/handler_messages.go
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"ds2api/internal/auth"
|
||||||
|
"ds2api/internal/config"
|
||||||
|
claudefmt "ds2api/internal/format/claude"
|
||||||
|
"ds2api/internal/sse"
|
||||||
|
streamengine "ds2api/internal/stream"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Messages handles POST /anthropic/v1/messages: it authenticates the caller,
// normalizes the Anthropic request into the internal standard form, runs the
// DeepSeek session/PoW/completion handshake, and returns either a streamed
// or a buffered Claude-style response.
func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
	// Default the anthropic-version header so downstream code sees a
	// well-formed request even from minimal clients.
	if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" {
		r.Header.Set("anthropic-version", "2023-06-01")
	}
	a, err := h.Auth.Determine(r)
	if err != nil {
		status := http.StatusUnauthorized
		detail := err.Error()
		// No available account is a capacity problem, not an auth problem.
		if err == auth.ErrNoAccount {
			status = http.StatusTooManyRequests
		}
		writeClaudeError(w, status, detail)
		return
	}
	defer h.Auth.Release(a)

	var req map[string]any
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeClaudeError(w, http.StatusBadRequest, "invalid json")
		return
	}
	norm, err := normalizeClaudeRequest(h.Store, req)
	if err != nil {
		writeClaudeError(w, http.StatusBadRequest, err.Error())
		return
	}
	stdReq := norm.Standard

	// DeepSeek handshake: session, proof-of-work challenge, then completion.
	sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
	if err != nil {
		writeClaudeError(w, http.StatusUnauthorized, "invalid token.")
		return
	}
	pow, err := h.DS.GetPow(r.Context(), a, 3)
	if err != nil {
		writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW")
		return
	}
	requestPayload := stdReq.CompletionPayload(sessionID)
	resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3)
	if err != nil {
		writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.")
		return
	}
	// Non-200 upstream: surface the raw upstream body as the error message.
	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()
		body, _ := io.ReadAll(resp.Body)
		writeClaudeError(w, http.StatusInternalServerError, string(body))
		return
	}

	if stdReq.Stream {
		// Streaming path takes ownership of resp.Body (closed inside).
		h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
		return
	}
	// Buffered path: drain the whole SSE stream, then emit one message.
	result := sse.CollectStream(resp, stdReq.Thinking, true)
	respBody := claudefmt.BuildMessageResponse(
		fmt.Sprintf("msg_%d", time.Now().UnixNano()),
		stdReq.ResponseModel,
		norm.NormalizedMessages,
		result.Thinking,
		result.Text,
		stdReq.ToolNames,
	)
	writeJSON(w, http.StatusOK, respBody)
}
|
||||||
|
|
||||||
|
// handleClaudeStreamRealtime relays a DeepSeek SSE response to the client as
// an Anthropic-style event stream in real time.
//
// It owns resp: the body is always closed before returning. A non-200 upstream
// status is converted into a Claude error payload instead of a stream.
func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) {
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		// Upstream failed: surface its body as a plain (non-SSE) error response.
		body, _ := io.ReadAll(resp.Body)
		writeClaudeError(w, http.StatusInternalServerError, string(body))
		return
	}

	// Standard SSE headers; X-Accel-Buffering disables buffering in
	// nginx-style reverse proxies so events reach the client immediately.
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache, no-transform")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no")
	rc := http.NewResponseController(w)
	// NOTE(review): this assertion misses writers that only expose Flush via
	// ResponseController unwrapping; streaming proceeds either way, possibly
	// buffered — confirm whether rc-based flushing should be probed instead.
	_, canFlush := w.(http.Flusher)
	if !canFlush {
		config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered")
	}

	streamRuntime := newClaudeStreamRuntime(
		w,
		rc,
		canFlush,
		model,
		messages,
		thinkingEnabled,
		searchEnabled,
		toolNames,
	)
	streamRuntime.sendMessageStart()

	// The first content block mirrors the model mode: reasoning models open
	// with a "thinking" block, otherwise a plain "text" block.
	initialType := "text"
	if thinkingEnabled {
		initialType = "thinking"
	}
	streamengine.ConsumeSSE(streamengine.ConsumeConfig{
		Context:             r.Context(),
		Body:                resp.Body,
		ThinkingEnabled:     thinkingEnabled,
		InitialType:         initialType,
		KeepAliveInterval:   claudeStreamPingInterval,
		IdleTimeout:         claudeStreamIdleTimeout,
		MaxKeepAliveNoInput: claudeStreamMaxKeepaliveCnt,
	}, streamengine.ConsumeHooks{
		OnKeepAlive: func() {
			streamRuntime.sendPing()
		},
		OnParsed:   streamRuntime.onParsed,
		OnFinalize: streamRuntime.onFinalize,
	})
}
|
||||||
41
internal/adapter/claude/handler_routes.go
Normal file
41
internal/adapter/claude/handler_routes.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
|
||||||
|
"ds2api/internal/config"
|
||||||
|
"ds2api/internal/deepseek"
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// writeJSON is a package-internal alias to avoid mass-renaming all call-sites.
var writeJSON = util.WriteJSON

// Handler bundles the dependencies of the Anthropic-compatible endpoints.
type Handler struct {
	Store ConfigReader   // configuration source passed to request normalization
	Auth  AuthResolver   // resolves and releases per-request account auth
	DS    DeepSeekCaller // upstream DeepSeek API client
}

// Streaming tunables, derived from the shared deepseek constants so the
// Claude surface keeps the same keep-alive/idle behavior as the native one.
var (
	claudeStreamPingInterval    = time.Duration(deepseek.KeepAliveTimeout) * time.Second
	claudeStreamIdleTimeout     = time.Duration(deepseek.StreamIdleTimeout) * time.Second
	claudeStreamMaxKeepaliveCnt = deepseek.MaxKeepaliveCount
)
|
||||||
|
|
||||||
|
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||||
|
r.Get("/anthropic/v1/models", h.ListModels)
|
||||||
|
r.Post("/anthropic/v1/messages", h.Messages)
|
||||||
|
r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens)
|
||||||
|
r.Post("/v1/messages", h.Messages)
|
||||||
|
r.Post("/messages", h.Messages)
|
||||||
|
r.Post("/v1/messages/count_tokens", h.CountTokens)
|
||||||
|
r.Post("/messages/count_tokens", h.CountTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListModels responds with the static Claude model catalog; the request body
// is ignored.
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
	writeJSON(w, http.StatusOK, config.ClaudeModelsResponse())
}
|
||||||
51
internal/adapter/claude/handler_tokens.go
Normal file
51
internal/adapter/claude/handler_tokens.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
|
||||||
|
a, err := h.Auth.Determine(r)
|
||||||
|
if err != nil {
|
||||||
|
writeClaudeError(w, http.StatusUnauthorized, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer h.Auth.Release(a)
|
||||||
|
|
||||||
|
var req map[string]any
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
writeClaudeError(w, http.StatusBadRequest, "invalid json")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
model, _ := req["model"].(string)
|
||||||
|
messages, _ := req["messages"].([]any)
|
||||||
|
if model == "" || len(messages) == 0 {
|
||||||
|
writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
inputTokens := 0
|
||||||
|
if sys, ok := req["system"].(string); ok {
|
||||||
|
inputTokens += util.EstimateTokens(sys)
|
||||||
|
}
|
||||||
|
for _, item := range messages {
|
||||||
|
msg, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
inputTokens += 2
|
||||||
|
inputTokens += util.EstimateTokens(extractMessageContent(msg["content"]))
|
||||||
|
}
|
||||||
|
if tools, ok := req["tools"].([]any); ok {
|
||||||
|
for _, t := range tools {
|
||||||
|
b, _ := json.Marshal(t)
|
||||||
|
inputTokens += util.EstimateTokens(string(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if inputTokens < 1 {
|
||||||
|
inputTokens = 1
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens})
|
||||||
|
}
|
||||||
350
internal/adapter/claude/handler_util_test.go
Normal file
350
internal/adapter/claude/handler_util_test.go
Normal file
@@ -0,0 +1,350 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ─── normalizeClaudeMessages ─────────────────────────────────────────

// Plain string content passes through unchanged.
func TestNormalizeClaudeMessagesSimpleString(t *testing.T) {
	msgs := []any{
		map[string]any{"role": "user", "content": "Hello"},
	}
	got := normalizeClaudeMessages(msgs)
	if len(got) != 1 {
		t.Fatalf("expected 1 message, got %d", len(got))
	}
	m := got[0].(map[string]any)
	if m["content"] != "Hello" {
		t.Fatalf("expected 'Hello', got %v", m["content"])
	}
}

// Text blocks in array content are joined with newlines.
func TestNormalizeClaudeMessagesArrayContent(t *testing.T) {
	msgs := []any{
		map[string]any{
			"role": "user",
			"content": []any{
				map[string]any{"type": "text", "text": "line1"},
				map[string]any{"type": "text", "text": "line2"},
			},
		},
	}
	got := normalizeClaudeMessages(msgs)
	m := got[0].(map[string]any)
	if m["content"] != "line1\nline2" {
		t.Fatalf("expected joined text, got %q", m["content"])
	}
}

// tool_result blocks are serialized into [TOOL_RESULT_HISTORY] markers.
func TestNormalizeClaudeMessagesToolResult(t *testing.T) {
	msgs := []any{
		map[string]any{
			"role": "user",
			"content": []any{
				map[string]any{"type": "tool_result", "content": "tool output"},
			},
		},
	}
	got := normalizeClaudeMessages(msgs)
	m := got[0].(map[string]any)
	content, _ := m["content"].(string)
	if !strings.Contains(content, "[TOOL_RESULT_HISTORY]") || !strings.Contains(content, "content: tool output") {
		t.Fatalf("expected serialized tool result marker, got %q", content)
	}
}

// Entries that are not maps are dropped entirely.
func TestNormalizeClaudeMessagesSkipsNonMap(t *testing.T) {
	msgs := []any{"not a map", 42}
	got := normalizeClaudeMessages(msgs)
	if len(got) != 0 {
		t.Fatalf("expected 0 messages for non-map items, got %d", len(got))
	}
}

// nil input yields an empty result without panicking.
func TestNormalizeClaudeMessagesEmpty(t *testing.T) {
	got := normalizeClaudeMessages(nil)
	if len(got) != 0 {
		t.Fatalf("expected 0, got %d", len(got))
	}
}

// The role field survives normalization untouched.
func TestNormalizeClaudeMessagesPreservesRole(t *testing.T) {
	msgs := []any{
		map[string]any{"role": "assistant", "content": "response"},
	}
	got := normalizeClaudeMessages(msgs)
	m := got[0].(map[string]any)
	if m["role"] != "assistant" {
		t.Fatalf("expected 'assistant', got %q", m["role"])
	}
}

// Non-text blocks (e.g. images) are dropped; only text parts are joined.
func TestNormalizeClaudeMessagesMixedContentBlocks(t *testing.T) {
	msgs := []any{
		map[string]any{
			"role": "user",
			"content": []any{
				map[string]any{"type": "text", "text": "Hello"},
				map[string]any{"type": "image", "source": "data:..."},
				map[string]any{"type": "text", "text": "World"},
			},
		},
	}
	got := normalizeClaudeMessages(msgs)
	m := got[0].(map[string]any)
	if m["content"] != "Hello\nWorld" {
		t.Fatalf("expected only text parts joined, got %q", m["content"])
	}
}
|
||||||
|
|
||||||
|
// ─── buildClaudeToolPrompt ───────────────────────────────────────────

// A single tool yields a prompt containing its name, description, and the
// tool_calls output-format instruction.
func TestBuildClaudeToolPromptSingleTool(t *testing.T) {
	tools := []any{
		map[string]any{
			"name":        "search",
			"description": "Search the web",
			"input_schema": map[string]any{
				"type": "object",
				"properties": map[string]any{
					"query": map[string]any{"type": "string"},
				},
			},
		},
	}
	prompt := buildClaudeToolPrompt(tools)
	if prompt == "" {
		t.Fatal("expected non-empty prompt")
	}
	// Should contain tool name and description
	if !containsStr(prompt, "search") {
		t.Fatalf("expected 'search' in prompt")
	}
	if !containsStr(prompt, "Search the web") {
		t.Fatalf("expected description in prompt")
	}
	if !containsStr(prompt, "tool_calls") {
		t.Fatalf("expected tool_calls instruction in prompt")
	}
}

// Every provided tool appears in the prompt.
func TestBuildClaudeToolPromptMultipleTools(t *testing.T) {
	tools := []any{
		map[string]any{"name": "tool1", "description": "desc1"},
		map[string]any{"name": "tool2", "description": "desc2"},
	}
	prompt := buildClaudeToolPrompt(tools)
	if !containsStr(prompt, "tool1") || !containsStr(prompt, "tool2") {
		t.Fatalf("expected both tools in prompt")
	}
}

// Non-map entries are skipped but the scaffold text is still emitted.
func TestBuildClaudeToolPromptSkipsNonMap(t *testing.T) {
	tools := []any{"not a map"}
	prompt := buildClaudeToolPrompt(tools)
	if prompt == "" {
		t.Fatal("expected non-empty prompt even with invalid tools")
	}
	// Should still contain the intro and instruction
	if !containsStr(prompt, "You are Claude") {
		t.Fatalf("expected intro in prompt")
	}
}
|
||||||
|
|
||||||
|
// ─── hasSystemMessage ────────────────────────────────────────────────

func TestHasSystemMessageTrue(t *testing.T) {
	msgs := []any{
		map[string]any{"role": "system", "content": "You are a helper"},
		map[string]any{"role": "user", "content": "Hi"},
	}
	if !hasSystemMessage(msgs) {
		t.Fatal("expected true")
	}
}

func TestHasSystemMessageFalse(t *testing.T) {
	msgs := []any{
		map[string]any{"role": "user", "content": "Hi"},
		map[string]any{"role": "assistant", "content": "Hello"},
	}
	if hasSystemMessage(msgs) {
		t.Fatal("expected false")
	}
}

// nil input must behave the same as "no system message".
func TestHasSystemMessageEmpty(t *testing.T) {
	if hasSystemMessage(nil) {
		t.Fatal("expected false for nil")
	}
}

// Non-map entries can never match role "system".
func TestHasSystemMessageNonMap(t *testing.T) {
	msgs := []any{"not a map"}
	if hasSystemMessage(msgs) {
		t.Fatal("expected false for non-map")
	}
}
|
||||||
|
|
||||||
|
// ─── extractClaudeToolNames ──────────────────────────────────────────

func TestExtractClaudeToolNamesSingle(t *testing.T) {
	tools := []any{
		map[string]any{"name": "search"},
	}
	names := extractClaudeToolNames(tools)
	if len(names) != 1 || names[0] != "search" {
		t.Fatalf("expected [search], got %v", names)
	}
}

func TestExtractClaudeToolNamesMultiple(t *testing.T) {
	tools := []any{
		map[string]any{"name": "search"},
		map[string]any{"name": "calculate"},
	}
	names := extractClaudeToolNames(tools)
	if len(names) != 2 {
		t.Fatalf("expected 2 names, got %v", names)
	}
}

// Empty names are filtered out.
func TestExtractClaudeToolNamesSkipsEmptyName(t *testing.T) {
	tools := []any{
		map[string]any{"name": ""},
		map[string]any{"name": "valid"},
	}
	names := extractClaudeToolNames(tools)
	if len(names) != 1 || names[0] != "valid" {
		t.Fatalf("expected [valid], got %v", names)
	}
}

// Entries that are not maps contribute nothing.
func TestExtractClaudeToolNamesSkipsNonMap(t *testing.T) {
	tools := []any{"not a map", 42}
	names := extractClaudeToolNames(tools)
	if len(names) != 0 {
		t.Fatalf("expected 0, got %v", names)
	}
}

func TestExtractClaudeToolNamesNil(t *testing.T) {
	names := extractClaudeToolNames(nil)
	if len(names) != 0 {
		t.Fatalf("expected 0, got %v", names)
	}
}
|
||||||
|
|
||||||
|
// ─── toMessageMaps ───────────────────────────────────────────────────

func TestToMessageMapsNormal(t *testing.T) {
	input := []any{
		map[string]any{"role": "user", "content": "Hello"},
	}
	got := toMessageMaps(input)
	if len(got) != 1 {
		t.Fatalf("expected 1, got %d", len(got))
	}
}

// Anything that is not []any yields nil.
func TestToMessageMapsNonSlice(t *testing.T) {
	got := toMessageMaps("not a slice")
	if got != nil {
		t.Fatalf("expected nil, got %v", got)
	}
}

// Non-map elements inside the slice are dropped.
func TestToMessageMapsSkipsNonMap(t *testing.T) {
	input := []any{"string", map[string]any{"role": "user"}, 42}
	got := toMessageMaps(input)
	if len(got) != 1 {
		t.Fatalf("expected 1 map, got %d", len(got))
	}
}

func TestToMessageMapsNil(t *testing.T) {
	got := toMessageMaps(nil)
	if got != nil {
		t.Fatalf("expected nil, got %v", got)
	}
}
|
||||||
|
|
||||||
|
// ─── extractMessageContent ──────────────────────────────────────────

func TestExtractMessageContentString(t *testing.T) {
	if got := extractMessageContent("hello"); got != "hello" {
		t.Fatalf("expected 'hello', got %q", got)
	}
}

// Slice elements are stringified and newline-joined.
func TestExtractMessageContentArray(t *testing.T) {
	input := []any{"part1", "part2"}
	got := extractMessageContent(input)
	if got != "part1\npart2" {
		t.Fatalf("expected joined, got %q", got)
	}
}

// Non-string scalars go through fmt's %v formatting.
func TestExtractMessageContentOther(t *testing.T) {
	got := extractMessageContent(42)
	if got != "42" {
		t.Fatalf("expected '42', got %q", got)
	}
}

// nil deliberately renders as fmt's "<nil>" (pinned behavior).
func TestExtractMessageContentNil(t *testing.T) {
	got := extractMessageContent(nil)
	if got != "<nil>" {
		t.Fatalf("expected '<nil>', got %q", got)
	}
}
|
||||||
|
|
||||||
|
// ─── cloneMap ────────────────────────────────────────────────────────

// Top-level entries are copied, so mutating the source does not leak through.
func TestCloneMapBasic(t *testing.T) {
	original := map[string]any{"a": 1, "b": "hello"}
	clone := cloneMap(original)
	original["a"] = 999
	if clone["a"] != 1 {
		t.Fatalf("expected 1, got %v", clone["a"])
	}
	if clone["b"] != "hello" {
		t.Fatalf("expected 'hello', got %v", clone["b"])
	}
}

func TestCloneMapEmpty(t *testing.T) {
	clone := cloneMap(map[string]any{})
	if len(clone) != 0 {
		t.Fatalf("expected empty, got %v", clone)
	}
}

func TestCloneMapNested(t *testing.T) {
	// cloneMap is shallow, so nested maps share references
	inner := map[string]any{"key": "value"}
	original := map[string]any{"nested": inner}
	clone := cloneMap(original)
	// Shallow clone means inner is shared
	inner["key"] = "modified"
	cloneNested := clone["nested"].(map[string]any)
	if cloneNested["key"] != "modified" {
		t.Fatal("expected shallow clone to share nested references")
	}
}
|
||||||
|
|
||||||
|
// helper

// containsStr reports whether sub occurs within s.
//
// The previous hand-rolled scan (containsStr + findSubstring) duplicated the
// standard library; this file already imports "strings", so delegate to
// strings.Contains, which has identical semantics (including the empty-sub
// and equal-string cases).
func containsStr(s, sub string) bool {
	return strings.Contains(s, sub)
}
|
||||||
143
internal/adapter/claude/handler_utils.go
Normal file
143
internal/adapter/claude/handler_utils.go
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func normalizeClaudeMessages(messages []any) []any {
|
||||||
|
out := make([]any, 0, len(messages))
|
||||||
|
for _, m := range messages {
|
||||||
|
msg, ok := m.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
copied := cloneMap(msg)
|
||||||
|
switch content := msg["content"].(type) {
|
||||||
|
case []any:
|
||||||
|
parts := make([]string, 0, len(content))
|
||||||
|
for _, block := range content {
|
||||||
|
b, ok := block.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
typeStr, _ := b["type"].(string)
|
||||||
|
if typeStr == "text" {
|
||||||
|
if t, ok := b["text"].(string); ok {
|
||||||
|
parts = append(parts, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if typeStr == "tool_result" {
|
||||||
|
parts = append(parts, formatClaudeToolResultForPrompt(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
copied["content"] = strings.Join(parts, "\n")
|
||||||
|
}
|
||||||
|
out = append(out, copied)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildClaudeToolPrompt(tools []any) string {
|
||||||
|
parts := []string{"You are Claude, a helpful AI assistant. You have access to these tools:"}
|
||||||
|
for _, t := range tools {
|
||||||
|
m, ok := t.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name, _ := m["name"].(string)
|
||||||
|
desc, _ := m["description"].(string)
|
||||||
|
schema, _ := json.Marshal(m["input_schema"])
|
||||||
|
parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
|
||||||
|
}
|
||||||
|
parts = append(parts,
|
||||||
|
"When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}",
|
||||||
|
"History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.",
|
||||||
|
"After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.",
|
||||||
|
)
|
||||||
|
return strings.Join(parts, "\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatClaudeToolResultForPrompt(block map[string]any) string {
|
||||||
|
if block == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
|
||||||
|
if toolCallID == "" {
|
||||||
|
toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"]))
|
||||||
|
}
|
||||||
|
if toolCallID == "" {
|
||||||
|
toolCallID = "unknown"
|
||||||
|
}
|
||||||
|
name := strings.TrimSpace(fmt.Sprintf("%v", block["name"]))
|
||||||
|
if name == "" {
|
||||||
|
name = "unknown"
|
||||||
|
}
|
||||||
|
content := strings.TrimSpace(fmt.Sprintf("%v", block["content"]))
|
||||||
|
if content == "" {
|
||||||
|
content = "null"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasSystemMessage reports whether any entry in messages is a map whose
// "role" value equals "system".
func hasSystemMessage(messages []any) bool {
	for _, raw := range messages {
		if entry, isMap := raw.(map[string]any); isMap && entry["role"] == "system" {
			return true
		}
	}
	return false
}
|
||||||
|
|
||||||
|
// extractClaudeToolNames collects the non-empty string "name" field of every
// map entry in tools, preserving the original order.
func extractClaudeToolNames(tools []any) []string {
	names := make([]string, 0, len(tools))
	for _, raw := range tools {
		tool, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		name, isString := tool["name"].(string)
		if isString && name != "" {
			names = append(names, name)
		}
	}
	return names
}
|
||||||
|
|
||||||
|
// toMessageMaps narrows v down to its map elements. It returns nil when v is
// not a []any; non-map elements inside the slice are silently dropped.
func toMessageMaps(v any) []map[string]any {
	items, isSlice := v.([]any)
	if !isSlice {
		return nil
	}
	result := make([]map[string]any, 0, len(items))
	for _, item := range items {
		if entry, isMap := item.(map[string]any); isMap {
			result = append(result, entry)
		}
	}
	return result
}
|
||||||
|
|
||||||
|
// extractMessageContent renders message content as a single string: strings
// pass through unchanged, slices are stringified element-by-element and
// joined with newlines, and anything else (including nil, which renders as
// "<nil>") goes through fmt's %v formatting.
func extractMessageContent(v any) string {
	if s, isString := v.(string); isString {
		return s
	}
	items, isSlice := v.([]any)
	if !isSlice {
		return fmt.Sprintf("%v", v)
	}
	var joined strings.Builder
	for i, item := range items {
		if i > 0 {
			joined.WriteByte('\n')
		}
		fmt.Fprintf(&joined, "%v", item)
	}
	return joined.String()
}
|
||||||
|
|
||||||
|
// cloneMap returns a shallow copy of in: top-level keys are duplicated while
// values keep their original references. A nil input yields an empty
// (non-nil) map.
func cloneMap(in map[string]any) map[string]any {
	copied := make(map[string]any, len(in))
	for key, value := range in {
		copied[key] = value
	}
	return copied
}
|
||||||
44
internal/adapter/claude/route_alias_test.go
Normal file
44
internal/adapter/claude/route_alias_test.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
|
||||||
|
"ds2api/internal/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// routeAliasAuthStub is a minimal AuthResolver that always rejects, letting
// route registration be exercised without real credentials.
type routeAliasAuthStub struct{}

// Determine always fails with ErrUnauthorized; in tests, any status other
// than 404 then proves the route itself exists.
func (routeAliasAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) {
	return nil, auth.ErrUnauthorized
}

// Release is a no-op: Determine never hands out an auth to release.
func (routeAliasAuthStub) Release(_ *auth.RequestAuth) {}
|
||||||
|
|
||||||
|
// TestClaudeRouteAliasesDoNot404 verifies that every Messages/CountTokens
// alias path is registered: the stub auth turns requests into non-404
// failures, so a 404 would indicate a missing route rather than bad auth.
func TestClaudeRouteAliasesDoNot404(t *testing.T) {
	h := &Handler{
		Auth: routeAliasAuthStub{},
	}
	r := chi.NewRouter()
	RegisterRoutes(r, h)

	paths := []string{
		"/anthropic/v1/messages",
		"/v1/messages",
		"/messages",
		"/anthropic/v1/messages/count_tokens",
		"/v1/messages/count_tokens",
		"/messages/count_tokens",
	}
	for _, path := range paths {
		req := httptest.NewRequest(http.MethodPost, path, nil)
		rec := httptest.NewRecorder()
		r.ServeHTTP(rec, req)
		if rec.Code == http.StatusNotFound {
			t.Fatalf("expected route %s to be registered, got 404", path)
		}
	}
}
|
||||||
113
internal/adapter/claude/standard_request.go
Normal file
113
internal/adapter/claude/standard_request.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"ds2api/internal/config"
|
||||||
|
"ds2api/internal/deepseek"
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// claudeNormalizedRequest pairs the surface-agnostic StandardRequest with the
// flattened Claude messages, which the response builders echo back to the
// client.
type claudeNormalizedRequest struct {
	Standard           util.StandardRequest
	NormalizedMessages []any
}
|
||||||
|
|
||||||
|
// normalizeClaudeRequest validates an Anthropic messages request and converts
// it into the internal StandardRequest shape.
//
// Steps: require model+messages, default max_tokens to 8192 (note: this
// mutates the caller's req map before cloning), flatten messages, inject the
// tool prompt, convert the payload/model to DeepSeek form, and derive the
// final prompt string. The returned error text is user-facing and is written
// verbatim into the HTTP error response.
func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNormalizedRequest, error) {
	model, _ := req["model"].(string)
	messagesRaw, _ := req["messages"].([]any)
	if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 {
		return claudeNormalizedRequest{}, fmt.Errorf("Request must include 'model' and 'messages'.")
	}
	if _, ok := req["max_tokens"]; !ok {
		req["max_tokens"] = 8192
	}
	normalizedMessages := normalizeClaudeMessages(messagesRaw)
	payload := cloneMap(req)
	payload["messages"] = normalizedMessages
	toolsRequested, _ := req["tools"].([]any)
	payload["messages"] = injectClaudeToolPrompt(payload, normalizedMessages, toolsRequested)

	dsPayload := convertClaudeToDeepSeek(payload, store)
	dsModel, _ := dsPayload["model"].(string)
	thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel)
	if !ok {
		// Unknown model: fall back to plain chat (no reasoning, no search).
		thinkingEnabled = false
		searchEnabled = false
	}
	finalPrompt := deepseek.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
	toolNames := extractClaudeToolNames(toolsRequested)

	return claudeNormalizedRequest{
		Standard: util.StandardRequest{
			Surface:        "anthropic_messages",
			RequestedModel: strings.TrimSpace(model),
			ResolvedModel:  dsModel,
			ResponseModel:  strings.TrimSpace(model),
			Messages:       payload["messages"].([]any),
			FinalPrompt:    finalPrompt,
			ToolNames:      toolNames,
			Stream:         util.ToBool(req["stream"]),
			Thinking:       thinkingEnabled,
			Search:         searchEnabled,
		},
		NormalizedMessages: normalizedMessages,
	}, nil
}
|
||||||
|
|
||||||
|
// injectClaudeToolPrompt makes the tool definitions visible to the model.
//
// Priority: (1) no tools or empty prompt -> messages returned unchanged;
// (2) a non-empty top-level Anthropic "system" field gets the prompt appended
// (payload is mutated, messages unchanged); (3) otherwise the FIRST system
// message gets it appended — the loop returns immediately after one merge;
// (4) with no system message at all, a new one is prepended. The input slice
// is never mutated in place; copies are returned instead.
func injectClaudeToolPrompt(payload map[string]any, normalizedMessages []any, tools []any) []any {
	if len(tools) == 0 {
		return normalizedMessages
	}
	toolPrompt := strings.TrimSpace(buildClaudeToolPrompt(tools))
	if toolPrompt == "" {
		return normalizedMessages
	}

	// Prefer top-level Anthropic-style system prompt when available.
	if systemText, ok := payload["system"].(string); ok && strings.TrimSpace(systemText) != "" {
		payload["system"] = mergeSystemPrompt(systemText, toolPrompt)
		return normalizedMessages
	}

	messages := cloneAnySlice(normalizedMessages)
	for i := range messages {
		msg, ok := messages[i].(map[string]any)
		if !ok {
			continue
		}
		role, _ := msg["role"].(string)
		if !strings.EqualFold(strings.TrimSpace(role), "system") {
			continue
		}
		// Clone before mutating so the caller's message map is untouched.
		copied := cloneMap(msg)
		copied["content"] = mergeSystemPrompt(strings.TrimSpace(fmt.Sprintf("%v", copied["content"])), toolPrompt)
		messages[i] = copied
		return messages
	}

	return append([]any{map[string]any{"role": "system", "content": toolPrompt}}, messages...)
}
|
||||||
|
|
||||||
|
// mergeSystemPrompt joins two prompt fragments with a blank line. Both sides
// are trimmed first; when either is blank, the other is returned alone.
func mergeSystemPrompt(base, extra string) string {
	base, extra = strings.TrimSpace(base), strings.TrimSpace(extra)
	if base == "" {
		return extra
	}
	if extra == "" {
		return base
	}
	return base + "\n\n" + extra
}
|
||||||
|
|
||||||
|
// cloneAnySlice returns a shallow copy of in so the caller can replace
// elements without touching the original. Nil and empty inputs yield nil.
func cloneAnySlice(in []any) []any {
	if len(in) == 0 {
		return nil
	}
	return append(make([]any, 0, len(in)), in...)
}
|
||||||
92
internal/adapter/claude/standard_request_test.go
Normal file
92
internal/adapter/claude/standard_request_test.go
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"ds2api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeClaudeRequest(t *testing.T) {
|
||||||
|
t.Setenv("DS2API_CONFIG_JSON", `{}`)
|
||||||
|
store := config.LoadStore()
|
||||||
|
req := map[string]any{
|
||||||
|
"model": "claude-opus-4-6",
|
||||||
|
"messages": []any{
|
||||||
|
map[string]any{"role": "user", "content": "hello"},
|
||||||
|
},
|
||||||
|
"stream": true,
|
||||||
|
"tools": []any{
|
||||||
|
map[string]any{"name": "search", "description": "Search"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
norm, err := normalizeClaudeRequest(store, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalize failed: %v", err)
|
||||||
|
}
|
||||||
|
if norm.Standard.ResolvedModel == "" {
|
||||||
|
t.Fatalf("expected resolved model")
|
||||||
|
}
|
||||||
|
if !norm.Standard.Stream {
|
||||||
|
t.Fatalf("expected stream=true")
|
||||||
|
}
|
||||||
|
if len(norm.Standard.ToolNames) == 0 {
|
||||||
|
t.Fatalf("expected tool names")
|
||||||
|
}
|
||||||
|
if norm.Standard.FinalPrompt == "" {
|
||||||
|
t.Fatalf("expected non-empty final prompt")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeClaudeRequestInjectsToolsIntoExistingSystemMessage(t *testing.T) {
|
||||||
|
t.Setenv("DS2API_CONFIG_JSON", `{}`)
|
||||||
|
store := config.LoadStore()
|
||||||
|
req := map[string]any{
|
||||||
|
"model": "claude-sonnet-4-5",
|
||||||
|
"messages": []any{
|
||||||
|
map[string]any{"role": "system", "content": "baseline rule"},
|
||||||
|
map[string]any{"role": "user", "content": "hello"},
|
||||||
|
},
|
||||||
|
"tools": []any{
|
||||||
|
map[string]any{"name": "search", "description": "Search"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
norm, err := normalizeClaudeRequest(store, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalize failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !containsStr(norm.Standard.FinalPrompt, "You have access to these tools") {
|
||||||
|
t.Fatalf("expected tool prompt injected into final prompt, got=%q", norm.Standard.FinalPrompt)
|
||||||
|
}
|
||||||
|
if !containsStr(norm.Standard.FinalPrompt, "baseline rule") {
|
||||||
|
t.Fatalf("expected existing system message preserved, got=%q", norm.Standard.FinalPrompt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNormalizeClaudeRequestInjectsToolsIntoTopLevelSystem verifies that an
// Anthropic-style top-level "system" string survives normalization and that
// tool declarations are injected alongside it in the final prompt.
func TestNormalizeClaudeRequestInjectsToolsIntoTopLevelSystem(t *testing.T) {
	// Empty JSON config keeps the store at defaults for this test.
	t.Setenv("DS2API_CONFIG_JSON", `{}`)
	store := config.LoadStore()
	req := map[string]any{
		"model":  "claude-sonnet-4-5",
		"system": "top-level system",
		"messages": []any{
			map[string]any{"role": "user", "content": "hello"},
		},
		"tools": []any{
			map[string]any{"name": "search", "description": "Search"},
		},
	}

	norm, err := normalizeClaudeRequest(store, req)
	if err != nil {
		t.Fatalf("normalize failed: %v", err)
	}

	// Top-level system text must be preserved in the flattened prompt.
	if !containsStr(norm.Standard.FinalPrompt, "top-level system") {
		t.Fatalf("expected top-level system preserved, got=%q", norm.Standard.FinalPrompt)
	}
	// Tool preamble must also be present.
	if !containsStr(norm.Standard.FinalPrompt, "You have access to these tools") {
		t.Fatalf("expected tool prompt injected, got=%q", norm.Standard.FinalPrompt)
	}
}
|
||||||
146
internal/adapter/claude/stream_runtime_core.go
Normal file
146
internal/adapter/claude/stream_runtime_core.go
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"ds2api/internal/sse"
|
||||||
|
streamengine "ds2api/internal/stream"
|
||||||
|
)
|
||||||
|
|
||||||
|
// claudeStreamRuntime holds the per-request state for translating an upstream
// DeepSeek SSE stream into Anthropic Messages streaming events.
type claudeStreamRuntime struct {
	w        http.ResponseWriter
	rc       *http.ResponseController
	canFlush bool // whether rc.Flush is usable on this writer

	model     string
	toolNames []string
	messages  []any // original request messages, used for input-token estimation

	thinkingEnabled bool
	searchEnabled   bool
	// bufferToolContent is set when the request declared tools: text is then
	// accumulated and only emitted at finalize, after tool-call parsing.
	bufferToolContent bool

	messageID string
	thinking  strings.Builder // accumulated thinking text
	text      strings.Builder // accumulated answer text

	// Anthropic content blocks are indexed; these track the currently open
	// block (at most one of thinking/text is open at a time).
	nextBlockIndex     int
	thinkingBlockOpen  bool
	thinkingBlockIndex int
	textBlockOpen      bool
	textBlockIndex     int
	ended              bool   // finalize guard — events emitted at most once
	upstreamErr        string // last upstream error message, reported at finalize
}
|
||||||
|
|
||||||
|
// newClaudeStreamRuntime builds a stream runtime for one /v1/messages request.
// Tool content buffering is enabled automatically when toolNames is non-empty;
// block indices start at -1 (no block open). The message ID is derived from
// the current time in nanoseconds.
func newClaudeStreamRuntime(
	w http.ResponseWriter,
	rc *http.ResponseController,
	canFlush bool,
	model string,
	messages []any,
	thinkingEnabled bool,
	searchEnabled bool,
	toolNames []string,
) *claudeStreamRuntime {
	return &claudeStreamRuntime{
		w:                  w,
		rc:                 rc,
		canFlush:           canFlush,
		model:              model,
		messages:           messages,
		thinkingEnabled:    thinkingEnabled,
		searchEnabled:      searchEnabled,
		bufferToolContent:  len(toolNames) > 0,
		toolNames:          toolNames,
		messageID:          fmt.Sprintf("msg_%d", time.Now().UnixNano()),
		thinkingBlockIndex: -1,
		textBlockIndex:     -1,
	}
}
|
||||||
|
|
||||||
|
// onParsed consumes one parsed upstream SSE line and emits the corresponding
// Anthropic content_block_start/delta events. It returns a decision telling
// the stream engine whether to stop and whether any content was seen.
func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
	if !parsed.Parsed {
		return streamengine.ParsedDecision{}
	}
	if parsed.ErrorMessage != "" {
		// Remember the message; onFinalize reports it as an error event.
		s.upstreamErr = parsed.ErrorMessage
		return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("upstream_error")}
	}
	if parsed.Stop {
		return streamengine.ParsedDecision{Stop: true}
	}

	contentSeen := false
	for _, p := range parsed.Parts {
		if p.Text == "" {
			continue
		}
		// With search enabled, citation markers in non-thinking parts are dropped.
		if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
			continue
		}
		contentSeen = true

		if p.Type == "thinking" {
			if !s.thinkingEnabled {
				continue
			}
			s.thinking.WriteString(p.Text)
			// Thinking and text blocks are mutually exclusive: close any open
			// text block before (re)opening a thinking block.
			s.closeTextBlock()
			if !s.thinkingBlockOpen {
				s.thinkingBlockIndex = s.nextBlockIndex
				s.nextBlockIndex++
				s.send("content_block_start", map[string]any{
					"type":  "content_block_start",
					"index": s.thinkingBlockIndex,
					"content_block": map[string]any{
						"type":     "thinking",
						"thinking": "",
					},
				})
				s.thinkingBlockOpen = true
			}
			s.send("content_block_delta", map[string]any{
				"type":  "content_block_delta",
				"index": s.thinkingBlockIndex,
				"delta": map[string]any{
					"type":     "thinking_delta",
					"thinking": p.Text,
				},
			})
			continue
		}

		// Regular answer text.
		s.text.WriteString(p.Text)
		if s.bufferToolContent {
			// Tools declared: hold text back until finalize, where it is
			// scanned for tool calls before being emitted.
			continue
		}
		s.closeThinkingBlock()
		if !s.textBlockOpen {
			s.textBlockIndex = s.nextBlockIndex
			s.nextBlockIndex++
			s.send("content_block_start", map[string]any{
				"type":  "content_block_start",
				"index": s.textBlockIndex,
				"content_block": map[string]any{
					"type": "text",
					"text": "",
				},
			})
			s.textBlockOpen = true
		}
		s.send("content_block_delta", map[string]any{
			"type":  "content_block_delta",
			"index": s.textBlockIndex,
			"delta": map[string]any{
				"type": "text_delta",
				"text": p.Text,
			},
		})
	}

	return streamengine.ParsedDecision{ContentSeen: contentSeen}
}
|
||||||
59
internal/adapter/claude/stream_runtime_emit.go
Normal file
59
internal/adapter/claude/stream_runtime_emit.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *claudeStreamRuntime) send(event string, v any) {
|
||||||
|
b, _ := json.Marshal(v)
|
||||||
|
_, _ = s.w.Write([]byte("event: "))
|
||||||
|
_, _ = s.w.Write([]byte(event))
|
||||||
|
_, _ = s.w.Write([]byte("\n"))
|
||||||
|
_, _ = s.w.Write([]byte("data: "))
|
||||||
|
_, _ = s.w.Write(b)
|
||||||
|
_, _ = s.w.Write([]byte("\n\n"))
|
||||||
|
if s.canFlush {
|
||||||
|
_ = s.rc.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *claudeStreamRuntime) sendError(message string) {
|
||||||
|
msg := strings.TrimSpace(message)
|
||||||
|
if msg == "" {
|
||||||
|
msg = "upstream stream error"
|
||||||
|
}
|
||||||
|
s.send("error", map[string]any{
|
||||||
|
"type": "error",
|
||||||
|
"error": map[string]any{
|
||||||
|
"type": "api_error",
|
||||||
|
"message": msg,
|
||||||
|
"code": "internal_error",
|
||||||
|
"param": nil,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendPing emits an Anthropic-style keepalive ping event.
func (s *claudeStreamRuntime) sendPing() {
	s.send("ping", map[string]any{"type": "ping"})
}
|
||||||
|
|
||||||
|
// sendMessageStart emits the initial message_start event with an empty
// content array and an input-token estimate derived from the raw request
// messages. Output tokens are reported later via message_delta.
func (s *claudeStreamRuntime) sendMessageStart() {
	// Rough estimate: the messages slice is stringified with %v and counted.
	inputTokens := util.EstimateTokens(fmt.Sprintf("%v", s.messages))
	s.send("message_start", map[string]any{
		"type": "message_start",
		"message": map[string]any{
			"id":            s.messageID,
			"type":          "message",
			"role":          "assistant",
			"model":         s.model,
			"content":       []any{},
			"stop_reason":   nil,
			"stop_sequence": nil,
			"usage":         map[string]any{"input_tokens": inputTokens, "output_tokens": 0},
		},
	})
}
|
||||||
119
internal/adapter/claude/stream_runtime_finalize.go
Normal file
119
internal/adapter/claude/stream_runtime_finalize.go
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
streamengine "ds2api/internal/stream"
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *claudeStreamRuntime) closeThinkingBlock() {
|
||||||
|
if !s.thinkingBlockOpen {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.send("content_block_stop", map[string]any{
|
||||||
|
"type": "content_block_stop",
|
||||||
|
"index": s.thinkingBlockIndex,
|
||||||
|
})
|
||||||
|
s.thinkingBlockOpen = false
|
||||||
|
s.thinkingBlockIndex = -1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *claudeStreamRuntime) closeTextBlock() {
|
||||||
|
if !s.textBlockOpen {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.send("content_block_stop", map[string]any{
|
||||||
|
"type": "content_block_stop",
|
||||||
|
"index": s.textBlockIndex,
|
||||||
|
})
|
||||||
|
s.textBlockOpen = false
|
||||||
|
s.textBlockIndex = -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// finalize closes any open content blocks, flushes buffered tool/text content,
// and emits the terminal message_delta + message_stop events. It is idempotent:
// the second and later calls are no-ops.
func (s *claudeStreamRuntime) finalize(stopReason string) {
	if s.ended {
		return
	}
	s.ended = true

	s.closeThinkingBlock()
	s.closeTextBlock()

	finalThinking := s.thinking.String()
	finalText := s.text.String()

	if s.bufferToolContent {
		// Tools were declared, so text was buffered; scan it for tool calls now.
		detected := util.ParseToolCalls(finalText, s.toolNames)
		if len(detected) > 0 {
			// Tool calls override the caller-supplied stop reason.
			stopReason = "tool_use"
			for i, tc := range detected {
				idx := s.nextBlockIndex + i
				s.send("content_block_start", map[string]any{
					"type":  "content_block_start",
					"index": idx,
					"content_block": map[string]any{
						"type":  "tool_use",
						"id":    fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx),
						"name":  tc.Name,
						"input": tc.Input,
					},
				})
				s.send("content_block_stop", map[string]any{
					"type":  "content_block_stop",
					"index": idx,
				})
			}
			s.nextBlockIndex += len(detected)
		} else if finalText != "" {
			// No tool calls found: emit the buffered text as one text block.
			idx := s.nextBlockIndex
			s.nextBlockIndex++
			s.send("content_block_start", map[string]any{
				"type":  "content_block_start",
				"index": idx,
				"content_block": map[string]any{
					"type": "text",
					"text": "",
				},
			})
			s.send("content_block_delta", map[string]any{
				"type":  "content_block_delta",
				"index": idx,
				"delta": map[string]any{
					"type": "text_delta",
					"text": finalText,
				},
			})
			s.send("content_block_stop", map[string]any{
				"type":  "content_block_stop",
				"index": idx,
			})
		}
	}

	// Output tokens cover both thinking and answer text.
	outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
	s.send("message_delta", map[string]any{
		"type": "message_delta",
		"delta": map[string]any{
			"stop_reason":   stopReason,
			"stop_sequence": nil,
		},
		"usage": map[string]any{
			"output_tokens": outputTokens,
		},
	})
	s.send("message_stop", map[string]any{"type": "message_stop"})
}
|
||||||
|
|
||||||
|
func (s *claudeStreamRuntime) onFinalize(reason streamengine.StopReason, scannerErr error) {
|
||||||
|
if string(reason) == "upstream_error" {
|
||||||
|
s.sendError(s.upstreamErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if scannerErr != nil {
|
||||||
|
s.sendError(scannerErr.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.finalize("end_turn")
|
||||||
|
}
|
||||||
100
internal/adapter/claude/stream_status_test.go
Normal file
100
internal/adapter/claude/stream_status_test.go
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
chimw "github.com/go-chi/chi/v5/middleware"
|
||||||
|
|
||||||
|
"ds2api/internal/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// streamStatusClaudeAuthStub is a test double for the auth resolver that
// always succeeds with a fixed direct-token identity.
type streamStatusClaudeAuthStub struct{}

// Determine returns a canned RequestAuth regardless of the incoming request.
func (streamStatusClaudeAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) {
	return &auth.RequestAuth{
		UseConfigToken: false,
		DeepSeekToken:  "direct-token",
		CallerID:       "caller:test",
		TriedAccounts:  map[string]bool{},
	}, nil
}

// Release is a no-op for the stub.
func (streamStatusClaudeAuthStub) Release(_ *auth.RequestAuth) {}
|
||||||
|
|
||||||
|
// streamStatusClaudeDSStub is a test double for the DeepSeek client that
// serves a minimal two-line SSE body without any network access.
type streamStatusClaudeDSStub struct{}

// CreateSession always returns a fixed session ID.
func (streamStatusClaudeDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
	return "session-id", nil
}

// GetPow always returns a fixed proof-of-work response.
func (streamStatusClaudeDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
	return "pow", nil
}

// CallCompletion returns a 200 response whose body is a tiny canned SSE
// stream: one content chunk followed by the [DONE] marker.
func (streamStatusClaudeDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
	body := "data: {\"p\":\"response/content\",\"v\":\"hello\"}\n" + "data: [DONE]\n"
	return &http.Response{
		StatusCode: http.StatusOK,
		Header:     make(http.Header),
		Body:       ioNopCloser{strings.NewReader(body)},
	}, nil
}
|
||||||
|
|
||||||
|
// ioNopCloser adapts a strings.Reader into an io.ReadCloser with a no-op
// Close, so it can serve as an http.Response body in tests.
type ioNopCloser struct {
	*strings.Reader
}

// Close satisfies io.Closer without doing anything.
func (ioNopCloser) Close() error { return nil }
|
||||||
|
|
||||||
|
// streamStatusClaudeStoreStub is a test double for the config store exposing
// a fixed Claude-to-DeepSeek model mapping.
type streamStatusClaudeStoreStub struct{}

// ClaudeMapping returns the stub's static alias table.
func (streamStatusClaudeStoreStub) ClaudeMapping() map[string]string {
	return map[string]string{
		"fast": "deepseek-chat",
		"slow": "deepseek-reasoner",
	}
}
|
||||||
|
|
||||||
|
func captureClaudeStatusMiddleware(statuses *[]int) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ww := chimw.NewWrapResponseWriter(w, r.ProtoMajor)
|
||||||
|
next.ServeHTTP(ww, r)
|
||||||
|
*statuses = append(*statuses, ww.Status())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestClaudeMessagesStreamStatusCapturedAs200 checks that a streaming
// /anthropic/v1/messages request reports HTTP 200 both to the client and to
// status-capturing middleware (i.e. the handler explicitly writes the header
// before streaming, rather than leaving the wrapped status unset).
func TestClaudeMessagesStreamStatusCapturedAs200(t *testing.T) {
	statuses := make([]int, 0, 1)
	h := &Handler{
		Store: streamStatusClaudeStoreStub{},
		Auth:  streamStatusClaudeAuthStub{},
		DS:    streamStatusClaudeDSStub{},
	}
	r := chi.NewRouter()
	r.Use(captureClaudeStatusMiddleware(&statuses))
	RegisterRoutes(r, h)

	reqBody := `{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":true}`
	req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(reqBody))
	req.Header.Set("Authorization", "Bearer direct-token")
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	r.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
	}
	if len(statuses) != 1 {
		t.Fatalf("expected one captured status, got %d", len(statuses))
	}
	if statuses[0] != http.StatusOK {
		t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0])
	}
}
|
||||||
153
internal/adapter/gemini/convert_messages.go
Normal file
153
internal/adapter/gemini/convert_messages.go
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
package gemini
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// geminiMessagesFromRequest converts a Gemini generateContent request into
// OpenAI-style chat messages: a leading system message from systemInstruction
// (when present), then one message per run of text parts, with functionCall
// parts mapped to assistant tool_calls and functionResponse parts mapped to
// role:"tool" messages.
func geminiMessagesFromRequest(req map[string]any) []any {
	out := make([]any, 0, 8)
	if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" {
		out = append(out, map[string]any{
			"role":    "system",
			"content": sys,
		})
	}

	contents, _ := req["contents"].([]any)
	for _, item := range contents {
		content, ok := item.(map[string]any)
		if !ok {
			continue
		}
		role := mapGeminiRole(content["role"])
		if role == "" {
			// Unknown/missing roles default to user.
			role = "user"
		}
		parts, _ := content["parts"].([]any)
		if len(parts) == 0 {
			// Fallback for entries that carry a bare "text" field instead of parts.
			if text := strings.TrimSpace(asString(content["text"])); text != "" {
				out = append(out, map[string]any{
					"role":    role,
					"content": text,
				})
			}
			continue
		}

		// Consecutive text parts are joined into one message; flushText emits
		// the pending run before any function-call/response message so ordering
		// is preserved.
		textParts := make([]string, 0, len(parts))
		flushText := func() {
			if len(textParts) == 0 {
				return
			}
			out = append(out, map[string]any{
				"role":    role,
				"content": strings.Join(textParts, "\n"),
			})
			textParts = textParts[:0]
		}

		for _, rawPart := range parts {
			part, ok := rawPart.(map[string]any)
			if !ok {
				continue
			}
			if text := strings.TrimSpace(asString(part["text"])); text != "" {
				textParts = append(textParts, text)
				continue
			}

			if fnCall, ok := part["functionCall"].(map[string]any); ok {
				flushText()
				if name := strings.TrimSpace(asString(fnCall["name"])); name != "" {
					callID := strings.TrimSpace(asString(fnCall["id"]))
					if callID == "" {
						// Gemini calls may omit an ID; use a fixed placeholder.
						callID = "call_gemini"
					}
					out = append(out, map[string]any{
						"role": "assistant",
						"tool_calls": []any{
							map[string]any{
								"id":   callID,
								"type": "function",
								"function": map[string]any{
									"name":      name,
									"arguments": stringifyJSON(fnCall["args"]),
								},
							},
						},
					})
				}
				continue
			}

			if fnResp, ok := part["functionResponse"].(map[string]any); ok {
				flushText()
				name := strings.TrimSpace(asString(fnResp["name"]))
				// Accept several ID spellings: id, callId, tool_call_id.
				callID := strings.TrimSpace(asString(fnResp["id"]))
				if callID == "" {
					callID = strings.TrimSpace(asString(fnResp["callId"]))
				}
				if callID == "" {
					callID = strings.TrimSpace(asString(fnResp["tool_call_id"]))
				}
				if callID == "" {
					callID = "call_gemini"
				}
				// NOTE: this shadows the outer `content` map variable; it holds
				// the tool result payload ("response", falling back to "output").
				content := fnResp["response"]
				if content == nil {
					content = fnResp["output"]
				}
				if content == nil {
					content = ""
				}
				msg := map[string]any{
					"role":         "tool",
					"tool_call_id": callID,
					"content":      content,
				}
				if name != "" {
					msg["name"] = name
				}
				out = append(out, msg)
			}
		}
		// Emit any trailing text run for this content entry.
		flushText()
	}
	return out
}
|
||||||
|
|
||||||
|
func normalizeGeminiSystemInstruction(raw any) string {
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(v)
|
||||||
|
case map[string]any:
|
||||||
|
if parts, ok := v["parts"].([]any); ok {
|
||||||
|
texts := make([]string, 0, len(parts))
|
||||||
|
for _, item := range parts {
|
||||||
|
part, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if text := strings.TrimSpace(asString(part["text"])); text != "" {
|
||||||
|
texts = append(texts, text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(texts, "\n")
|
||||||
|
}
|
||||||
|
if text := strings.TrimSpace(asString(v["text"])); text != "" {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapGeminiRole(v any) string {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(asString(v))) {
|
||||||
|
case "user":
|
||||||
|
return "user"
|
||||||
|
case "model", "assistant":
|
||||||
|
return "assistant"
|
||||||
|
case "system":
|
||||||
|
return "system"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
54
internal/adapter/gemini/convert_passthrough.go
Normal file
54
internal/adapter/gemini/convert_passthrough.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
package gemini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// collectGeminiPassThrough extracts recognized generationConfig fields and
// renames them to their OpenAI-style equivalents (topP→top_p,
// maxOutputTokens→max_tokens, stopSequences→stop). It returns nil when the
// config is absent or contains none of the recognized keys.
func collectGeminiPassThrough(req map[string]any) map[string]any {
	cfg, _ := req["generationConfig"].(map[string]any)
	if len(cfg) == 0 {
		return nil
	}
	renames := [...][2]string{
		{"temperature", "temperature"},
		{"topP", "top_p"},
		{"maxOutputTokens", "max_tokens"},
		{"stopSequences", "stop"},
	}
	out := make(map[string]any, len(renames))
	for _, pair := range renames {
		if v, ok := cfg[pair[0]]; ok {
			out[pair[1]] = v
		}
	}
	if len(out) == 0 {
		return nil
	}
	return out
}
|
||||||
|
|
||||||
|
// asString returns v when it is a string and "" for any other type
// (including nil) — a tolerant accessor for loosely-typed JSON maps.
func asString(v any) string {
	if s, ok := v.(string); ok {
		return s
	}
	return ""
}
|
||||||
|
|
||||||
|
// stringifyJSON renders v as a JSON argument string: nil and blank strings
// become "{}", non-blank strings are trimmed and passed through as-is
// (assumed to already be JSON), and everything else is marshaled, falling
// back to "{}" on failure.
func stringifyJSON(v any) string {
	if v == nil {
		return "{}"
	}
	if s, isStr := v.(string); isStr {
		if trimmed := strings.TrimSpace(s); trimmed != "" {
			return trimmed
		}
		return "{}"
	}
	encoded, err := json.Marshal(v)
	if err != nil || len(encoded) == 0 {
		return "{}"
	}
	return string(encoded)
}
|
||||||
46
internal/adapter/gemini/convert_request.go
Normal file
46
internal/adapter/gemini/convert_request.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package gemini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"ds2api/internal/adapter/openai"
|
||||||
|
"ds2api/internal/config"
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// normalizeGeminiRequest validates a Gemini generateContent request and
// converts it into the shared StandardRequest form: the path model is
// resolved through the store's aliases, contents become OpenAI-style
// messages, tools are converted, and generationConfig fields are collected
// as pass-through parameters. Errors are user-facing bad-request messages.
func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[string]any, stream bool) (util.StandardRequest, error) {
	requestedModel := strings.TrimSpace(routeModel)
	if requestedModel == "" {
		return util.StandardRequest{}, fmt.Errorf("model is required in request path")
	}

	resolvedModel, ok := config.ResolveModel(store, requestedModel)
	if !ok {
		return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", requestedModel)
	}
	// Per-model flags (thinking / search) come from the resolved model config.
	thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)

	messagesRaw := geminiMessagesFromRequest(req)
	if len(messagesRaw) == 0 {
		return util.StandardRequest{}, fmt.Errorf("Request must include non-empty contents.")
	}

	toolsRaw := convertGeminiTools(req["tools"])
	finalPrompt, toolNames := openai.BuildPromptForAdapter(messagesRaw, toolsRaw, "")
	passThrough := collectGeminiPassThrough(req)

	return util.StandardRequest{
		Surface:        "google_gemini",
		RequestedModel: requestedModel,
		ResolvedModel:  resolvedModel,
		// Responses echo the model name the caller asked for, not the alias target.
		ResponseModel: requestedModel,
		Messages:      messagesRaw,
		FinalPrompt:   finalPrompt,
		ToolNames:     toolNames,
		Stream:        stream,
		Thinking:      thinkingEnabled,
		Search:        searchEnabled,
		PassThrough:   passThrough,
	}, nil
}
|
||||||
71
internal/adapter/gemini/convert_tools.go
Normal file
71
internal/adapter/gemini/convert_tools.go
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
package gemini
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// convertGeminiTools converts a Gemini "tools" array into OpenAI-style
// function tools. Three input shapes are accepted, in priority order per
// entry: native Gemini {functionDeclarations: [...]}, an already
// OpenAI-shaped {function: {...}} object (passed through unchanged), and a
// flattened {name, description?, parameters?} object. Entries without a
// usable name are skipped; nil is returned when nothing converts.
func convertGeminiTools(raw any) []any {
	tools, _ := raw.([]any)
	if len(tools) == 0 {
		return nil
	}
	out := make([]any, 0, len(tools))
	for _, item := range tools {
		tool, ok := item.(map[string]any)
		if !ok {
			continue
		}

		// Native Gemini shape: each declaration becomes one function tool.
		if fnDecls, ok := tool["functionDeclarations"].([]any); ok && len(fnDecls) > 0 {
			for _, declRaw := range fnDecls {
				decl, ok := declRaw.(map[string]any)
				if !ok {
					continue
				}
				name := strings.TrimSpace(asString(decl["name"]))
				if name == "" {
					continue
				}
				function := map[string]any{
					"name": name,
				}
				if desc := strings.TrimSpace(asString(decl["description"])); desc != "" {
					function["description"] = desc
				}
				if params, ok := decl["parameters"].(map[string]any); ok {
					function["parameters"] = params
				}
				out = append(out, map[string]any{
					"type":     "function",
					"function": function,
				})
			}
			continue
		}

		// OpenAI-style passthrough fallback.
		if _, ok := tool["function"].(map[string]any); ok {
			out = append(out, tool)
			continue
		}

		// Loose fallback for flattened function schema objects.
		name := strings.TrimSpace(asString(tool["name"]))
		if name == "" {
			continue
		}
		fn := map[string]any{"name": name}
		if desc := strings.TrimSpace(asString(tool["description"])); desc != "" {
			fn["description"] = desc
		}
		if params, ok := tool["parameters"].(map[string]any); ok {
			fn["parameters"] = params
		}
		out = append(out, map[string]any{
			"type":     "function",
			"function": fn,
		})
	}
	if len(out) == 0 {
		return nil
	}
	return out
}
|
||||||
29
internal/adapter/gemini/deps.go
Normal file
29
internal/adapter/gemini/deps.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package gemini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"ds2api/internal/auth"
|
||||||
|
"ds2api/internal/config"
|
||||||
|
"ds2api/internal/deepseek"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthResolver resolves and releases the per-request authentication context.
type AuthResolver interface {
	Determine(req *http.Request) (*auth.RequestAuth, error)
	Release(a *auth.RequestAuth)
}

// DeepSeekCaller is the subset of the DeepSeek client this adapter needs:
// session creation, proof-of-work retrieval, and completion calls, each with
// a bounded retry count.
type DeepSeekCaller interface {
	CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
	GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
	CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
}

// ConfigReader exposes the model alias table used for model resolution.
type ConfigReader interface {
	ModelAliases() map[string]string
}

// Compile-time checks that the production types satisfy these interfaces.
var _ AuthResolver = (*auth.Resolver)(nil)
var _ DeepSeekCaller = (*deepseek.Client)(nil)
var _ ConfigReader = (*config.Store)(nil)
|
||||||
28
internal/adapter/gemini/handler_errors.go
Normal file
28
internal/adapter/gemini/handler_errors.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package gemini
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
func writeGeminiError(w http.ResponseWriter, status int, message string) {
|
||||||
|
errorStatus := "INVALID_ARGUMENT"
|
||||||
|
switch status {
|
||||||
|
case http.StatusUnauthorized:
|
||||||
|
errorStatus = "UNAUTHENTICATED"
|
||||||
|
case http.StatusForbidden:
|
||||||
|
errorStatus = "PERMISSION_DENIED"
|
||||||
|
case http.StatusTooManyRequests:
|
||||||
|
errorStatus = "RESOURCE_EXHAUSTED"
|
||||||
|
case http.StatusNotFound:
|
||||||
|
errorStatus = "NOT_FOUND"
|
||||||
|
default:
|
||||||
|
if status >= 500 {
|
||||||
|
errorStatus = "INTERNAL"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
writeJSON(w, status, map[string]any{
|
||||||
|
"error": map[string]any{
|
||||||
|
"code": status,
|
||||||
|
"message": message,
|
||||||
|
"status": errorStatus,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
135
internal/adapter/gemini/handler_generate.go
Normal file
135
internal/adapter/gemini/handler_generate.go
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
package gemini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
|
||||||
|
"ds2api/internal/auth"
|
||||||
|
"ds2api/internal/sse"
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handleGenerateContent serves both generateContent and streamGenerateContent:
// it authenticates the caller, normalizes the Gemini request, performs the
// DeepSeek session/PoW/completion handshake, and dispatches to the streaming
// or non-streaming response path. All failures are reported as Gemini-style
// error envelopes.
func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) {
	a, err := h.Auth.Determine(r)
	if err != nil {
		status := http.StatusUnauthorized
		detail := err.Error()
		// No account available maps to 429, not 401.
		if err == auth.ErrNoAccount {
			status = http.StatusTooManyRequests
		}
		writeGeminiError(w, status, detail)
		return
	}
	defer h.Auth.Release(a)

	var req map[string]any
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeGeminiError(w, http.StatusBadRequest, "invalid json")
		return
	}

	// The model comes from the URL path, not the JSON body.
	routeModel := strings.TrimSpace(chi.URLParam(r, "model"))
	stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream)
	if err != nil {
		writeGeminiError(w, http.StatusBadRequest, err.Error())
		return
	}

	sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
	if err != nil {
		// Distinguish config-account tokens (fixable in admin) from direct tokens.
		if a.UseConfigToken {
			writeGeminiError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
		} else {
			writeGeminiError(w, http.StatusUnauthorized, "Invalid token.")
		}
		return
	}
	pow, err := h.DS.GetPow(r.Context(), a, 3)
	if err != nil {
		writeGeminiError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
		return
	}
	payload := stdReq.CompletionPayload(sessionID)
	resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
	if err != nil {
		writeGeminiError(w, http.StatusInternalServerError, "Failed to get completion.")
		return
	}

	// resp.Body ownership passes to the chosen response handler.
	if stream {
		h.handleStreamGenerateContent(w, r, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
		return
	}
	h.handleNonStreamGenerateContent(w, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
}
|
||||||
|
|
||||||
|
// handleNonStreamGenerateContent drains the upstream SSE stream to completion
// and writes a single Gemini generateContent response. Non-200 upstream
// statuses are forwarded as Gemini error envelopes with the raw body text.
func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *http.Response, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
		return
	}

	// Collect the whole stream (thinking + text) before building the response.
	result := sse.CollectStream(resp, thinkingEnabled, true)
	writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames))
}
|
||||||
|
|
||||||
|
func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||||
|
parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames)
|
||||||
|
usage := buildGeminiUsage(finalPrompt, finalThinking, finalText)
|
||||||
|
return map[string]any{
|
||||||
|
"candidates": []map[string]any{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"content": map[string]any{
|
||||||
|
"role": "model",
|
||||||
|
"parts": parts,
|
||||||
|
},
|
||||||
|
"finishReason": "STOP",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"modelVersion": model,
|
||||||
|
"usageMetadata": usage,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any {
|
||||||
|
promptTokens := util.EstimateTokens(finalPrompt)
|
||||||
|
reasoningTokens := util.EstimateTokens(finalThinking)
|
||||||
|
completionTokens := util.EstimateTokens(finalText)
|
||||||
|
return map[string]any{
|
||||||
|
"promptTokenCount": promptTokens,
|
||||||
|
"candidatesTokenCount": reasoningTokens + completionTokens,
|
||||||
|
"totalTokenCount": promptTokens + reasoningTokens + completionTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildGeminiPartsFromFinal(finalText, finalThinking string, toolNames []string) []map[string]any {
|
||||||
|
detected := util.ParseToolCalls(finalText, toolNames)
|
||||||
|
if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
|
||||||
|
detected = util.ParseToolCalls(finalThinking, toolNames)
|
||||||
|
}
|
||||||
|
if len(detected) > 0 {
|
||||||
|
parts := make([]map[string]any, 0, len(detected))
|
||||||
|
for _, tc := range detected {
|
||||||
|
parts = append(parts, map[string]any{
|
||||||
|
"functionCall": map[string]any{
|
||||||
|
"name": tc.Name,
|
||||||
|
"args": tc.Input,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return parts
|
||||||
|
}
|
||||||
|
|
||||||
|
text := finalText
|
||||||
|
if strings.TrimSpace(text) == "" {
|
||||||
|
text = finalThinking
|
||||||
|
}
|
||||||
|
return []map[string]any{{"text": text}}
|
||||||
|
}
|
||||||
32
internal/adapter/gemini/handler_routes.go
Normal file
32
internal/adapter/gemini/handler_routes.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package gemini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// writeJSON is aliased from util.WriteJSON so handlers in this package can
// emit JSON responses without repeating the util qualifier.
var writeJSON = util.WriteJSON
|
||||||
|
|
||||||
|
// Handler serves the Gemini-compatible surface (v1/v1beta generateContent
// and streamGenerateContent), translating requests onto the DeepSeek backend.
type Handler struct {
	Store ConfigReader   // read-only runtime configuration (model aliases, ...)
	Auth  AuthResolver   // resolves and releases per-request credentials
	DS    DeepSeekCaller // session/PoW/completion calls against DeepSeek
}
|
||||||
|
|
||||||
|
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||||
|
r.Post("/v1beta/models/{model}:generateContent", h.GenerateContent)
|
||||||
|
r.Post("/v1beta/models/{model}:streamGenerateContent", h.StreamGenerateContent)
|
||||||
|
r.Post("/v1/models/{model}:generateContent", h.GenerateContent)
|
||||||
|
r.Post("/v1/models/{model}:streamGenerateContent", h.StreamGenerateContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateContent handles the non-streaming Gemini generateContent endpoint.
func (h *Handler) GenerateContent(w http.ResponseWriter, r *http.Request) {
	h.handleGenerateContent(w, r, false)
}
|
||||||
|
|
||||||
|
// StreamGenerateContent handles the streaming Gemini streamGenerateContent
// endpoint (SSE output).
func (h *Handler) StreamGenerateContent(w http.ResponseWriter, r *http.Request) {
	h.handleGenerateContent(w, r, true)
}
|
||||||
181
internal/adapter/gemini/handler_stream_runtime.go
Normal file
181
internal/adapter/gemini/handler_stream_runtime.go
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
package gemini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"ds2api/internal/deepseek"
|
||||||
|
"ds2api/internal/sse"
|
||||||
|
streamengine "ds2api/internal/stream"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handleStreamGenerateContent relays a DeepSeek SSE response to the client as
// a Gemini streamGenerateContent SSE stream. A non-200 upstream status is
// reported as a plain Gemini error before any stream headers are written.
func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
		return
	}

	// SSE headers; X-Accel-Buffering: no disables response buffering in
	// intermediary proxies so chunks reach the client immediately.
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache, no-transform")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no")

	rc := http.NewResponseController(w)
	_, canFlush := w.(http.Flusher)
	runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)

	// When reasoning is enabled the upstream is expected to emit thinking
	// content first, so seed the parser accordingly.
	initialType := "text"
	if thinkingEnabled {
		initialType = "thinking"
	}
	streamengine.ConsumeSSE(streamengine.ConsumeConfig{
		Context:             r.Context(),
		Body:                resp.Body,
		ThinkingEnabled:     thinkingEnabled,
		InitialType:         initialType,
		KeepAliveInterval:   time.Duration(deepseek.KeepAliveTimeout) * time.Second,
		IdleTimeout:         time.Duration(deepseek.StreamIdleTimeout) * time.Second,
		MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
	}, streamengine.ConsumeHooks{
		OnParsed: runtime.onParsed,
		// Always emit the terminal frame, regardless of stop reason/error.
		OnFinalize: func(_ streamengine.StopReason, _ error) {
			runtime.finalize()
		},
	})
}
|
||||||
|
|
||||||
|
// geminiStreamRuntime accumulates per-request state while translating a
// DeepSeek SSE stream into Gemini streamGenerateContent chunks.
type geminiStreamRuntime struct {
	w        http.ResponseWriter
	rc       *http.ResponseController
	canFlush bool // whether w also implements http.Flusher

	model       string // echoed back as modelVersion in every chunk
	finalPrompt string // used only for final usage estimation

	thinkingEnabled bool
	searchEnabled   bool // when set, citation fragments are dropped from text
	bufferContent   bool // buffer all text until finalize (tool-call mode)
	toolNames       []string

	// Full accumulated output; consumed by finalize for usage metadata and
	// tool-call parsing.
	thinking strings.Builder
	text     strings.Builder
}
|
||||||
|
|
||||||
|
func newGeminiStreamRuntime(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
rc *http.ResponseController,
|
||||||
|
canFlush bool,
|
||||||
|
model string,
|
||||||
|
finalPrompt string,
|
||||||
|
thinkingEnabled bool,
|
||||||
|
searchEnabled bool,
|
||||||
|
toolNames []string,
|
||||||
|
) *geminiStreamRuntime {
|
||||||
|
return &geminiStreamRuntime{
|
||||||
|
w: w,
|
||||||
|
rc: rc,
|
||||||
|
canFlush: canFlush,
|
||||||
|
model: model,
|
||||||
|
finalPrompt: finalPrompt,
|
||||||
|
thinkingEnabled: thinkingEnabled,
|
||||||
|
searchEnabled: searchEnabled,
|
||||||
|
bufferContent: len(toolNames) > 0,
|
||||||
|
toolNames: toolNames,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *geminiStreamRuntime) sendChunk(payload map[string]any) {
|
||||||
|
b, _ := json.Marshal(payload)
|
||||||
|
_, _ = s.w.Write([]byte("data: "))
|
||||||
|
_, _ = s.w.Write(b)
|
||||||
|
_, _ = s.w.Write([]byte("\n\n"))
|
||||||
|
if s.canFlush {
|
||||||
|
_ = s.rc.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// onParsed handles one parsed upstream SSE event. Text parts are either sent
// to the client immediately or buffered for finalize (tool-call mode);
// thinking parts are only accumulated, since the emitted Gemini chunks carry
// no separate thinking channel. Returns Stop on upstream termination,
// error, or content filtering, and reports ContentSeen so the consume loop
// can reset its keep-alive accounting.
func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
	if !parsed.Parsed {
		return streamengine.ParsedDecision{}
	}
	if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
		return streamengine.ParsedDecision{Stop: true}
	}

	contentSeen := false
	for _, p := range parsed.Parts {
		if p.Text == "" {
			continue
		}
		// Drop citation fragments injected by the search feature (only for
		// non-thinking parts).
		if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
			continue
		}
		contentSeen = true
		if p.Type == "thinking" {
			// Thinking is accumulated (when enabled) but never streamed.
			if s.thinkingEnabled {
				s.thinking.WriteString(p.Text)
			}
			continue
		}
		s.text.WriteString(p.Text)
		if s.bufferContent {
			// Tool-call mode: nothing is emitted until finalize.
			continue
		}
		s.sendChunk(map[string]any{
			"candidates": []map[string]any{
				{
					"index": 0,
					"content": map[string]any{
						"role":  "model",
						"parts": []map[string]any{{"text": p.Text}},
					},
				},
			},
			"modelVersion": s.model,
		})
	}
	return streamengine.ParsedDecision{ContentSeen: contentSeen}
}
|
||||||
|
|
||||||
|
// finalize emits the terminal frames of the stream. In tool-call (buffered)
// mode it first sends one content frame built from the complete accumulated
// output (functionCall parts when tool calls were detected, plain text
// otherwise). It then always sends the finish frame with finishReason STOP,
// an empty text part, and usage metadata estimated from the full exchange.
func (s *geminiStreamRuntime) finalize() {
	finalThinking := s.thinking.String()
	finalText := s.text.String()

	if s.bufferContent {
		parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames)
		s.sendChunk(map[string]any{
			"candidates": []map[string]any{
				{
					"index": 0,
					"content": map[string]any{
						"role":  "model",
						"parts": parts,
					},
				},
			},
			"modelVersion": s.model,
		})
	}

	s.sendChunk(map[string]any{
		"candidates": []map[string]any{
			{
				"index": 0,
				"content": map[string]any{
					"role": "model",
					"parts": []map[string]any{
						{"text": ""},
					},
				},
				"finishReason": "STOP",
			},
		},
		"modelVersion":  s.model,
		"usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText),
	})
}
|
||||||
216
internal/adapter/gemini/handler_test.go
Normal file
216
internal/adapter/gemini/handler_test.go
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
package gemini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
|
||||||
|
"ds2api/internal/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testGeminiConfig is a ConfigReader stub with no model aliases configured.
type testGeminiConfig struct{}

func (testGeminiConfig) ModelAliases() map[string]string { return nil }
|
||||||
|
|
||||||
|
// testGeminiAuth is an AuthResolver stub: Determine returns err when set,
// then a when set, otherwise a default direct-token auth.
type testGeminiAuth struct {
	a   *auth.RequestAuth // optional fixed auth to return
	err error             // optional error; takes precedence over a
}
|
||||||
|
|
||||||
|
func (m testGeminiAuth) Determine(_ *http.Request) (*auth.RequestAuth, error) {
|
||||||
|
if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
if m.a != nil {
|
||||||
|
return m.a, nil
|
||||||
|
}
|
||||||
|
return &auth.RequestAuth{
|
||||||
|
UseConfigToken: false,
|
||||||
|
DeepSeekToken: "direct-token",
|
||||||
|
CallerID: "caller:test",
|
||||||
|
TriedAccounts: map[string]bool{},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release is a no-op for the stub resolver.
func (testGeminiAuth) Release(_ *auth.RequestAuth) {}
|
||||||
|
|
||||||
|
// testGeminiDS is a DeepSeekCaller stub returning a canned completion
// response (or a fixed error) from CallCompletion.
type testGeminiDS struct {
	resp *http.Response // returned by CallCompletion on success
	err  error          // when non-nil, CallCompletion fails with this
}
|
||||||
|
|
||||||
|
// CreateSession returns a fixed session id and never fails.
func (m testGeminiDS) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
	return "session-id", nil
}
|
||||||
|
|
||||||
|
// GetPow returns a fixed proof-of-work token and never fails.
func (m testGeminiDS) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
	return "pow", nil
}
|
||||||
|
|
||||||
|
// CallCompletion returns the preconfigured error, if any, otherwise the
// canned upstream response.
func (m testGeminiDS) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
	if m.err != nil {
		return nil, m.err
	}
	return m.resp, nil
}
|
||||||
|
|
||||||
|
// makeGeminiUpstreamResponse builds a fake upstream SSE response whose body
// is the given lines joined by newlines, guaranteed to end with a newline.
func makeGeminiUpstreamResponse(lines ...string) *http.Response {
	var b strings.Builder
	for i, line := range lines {
		if i > 0 {
			b.WriteByte('\n')
		}
		b.WriteString(line)
	}
	payload := b.String()
	if !strings.HasSuffix(payload, "\n") {
		payload += "\n"
	}
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     make(http.Header),
	}
	resp.Body = io.NopCloser(strings.NewReader(payload))
	return resp
}
|
||||||
|
|
||||||
|
// TestGeminiRoutesRegistered verifies that all four Gemini endpoints are
// mounted under both /v1 and /v1beta. The auth stub always rejects, so any
// status other than 404 proves the route pattern matched.
func TestGeminiRoutesRegistered(t *testing.T) {
	h := &Handler{
		Store: testGeminiConfig{},
		Auth:  testGeminiAuth{err: auth.ErrUnauthorized},
	}
	r := chi.NewRouter()
	RegisterRoutes(r, h)

	paths := []string{
		"/v1beta/models/gemini-2.5-pro:generateContent",
		"/v1beta/models/gemini-2.5-pro:streamGenerateContent",
		"/v1/models/gemini-2.5-pro:generateContent",
		"/v1/models/gemini-2.5-pro:streamGenerateContent",
	}
	for _, path := range paths {
		req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`))
		rec := httptest.NewRecorder()
		r.ServeHTTP(rec, req)
		// 404 means chi never matched; the failing auth stub yields a
		// non-404 code for registered routes.
		if rec.Code == http.StatusNotFound {
			t.Fatalf("expected route %s to be registered, got 404", path)
		}
	}
}
|
||||||
|
|
||||||
|
// TestGenerateContentReturnsFunctionCallParts drives the non-streaming
// endpoint with an upstream payload embedding a tool call and asserts the
// response surfaces it as a Gemini functionCall part.
func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
	// Upstream emits text that contains a serialized tool_calls JSON blob.
	upstream := makeGeminiUpstreamResponse(
		`data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
		`data: [DONE]`,
	)
	h := &Handler{
		Store: testGeminiConfig{},
		Auth:  testGeminiAuth{},
		DS:    testGeminiDS{resp: upstream},
	}
	r := chi.NewRouter()
	RegisterRoutes(r, h)

	// Request declares the tool so the handler parses tool calls.
	body := `{
		"contents":[{"role":"user","parts":[{"text":"call tool"}]}],
		"tools":[{"functionDeclarations":[{"name":"eval_javascript","description":"eval","parameters":{"type":"object","properties":{"code":{"type":"string"}}}}]}]
	}`
	req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent", strings.NewReader(body))
	req.Header.Set("Authorization", "Bearer direct-token")
	rec := httptest.NewRecorder()
	r.ServeHTTP(rec, req)
	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
	}

	var out map[string]any
	if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
		t.Fatalf("decode response failed: %v", err)
	}
	// Walk candidates[0].content.parts[0].functionCall.name.
	candidates, _ := out["candidates"].([]any)
	if len(candidates) == 0 {
		t.Fatalf("expected non-empty candidates: %#v", out)
	}
	c0, _ := candidates[0].(map[string]any)
	content, _ := c0["content"].(map[string]any)
	parts, _ := content["parts"].([]any)
	if len(parts) == 0 {
		t.Fatalf("expected non-empty parts: %#v", content)
	}
	part0, _ := parts[0].(map[string]any)
	functionCall, _ := part0["functionCall"].(map[string]any)
	if functionCall["name"] != "eval_javascript" {
		t.Fatalf("expected functionCall name eval_javascript, got %#v", functionCall)
	}
}
|
||||||
|
|
||||||
|
// TestStreamGenerateContentEmitsSSE checks that the streaming endpoint emits
// SSE data frames, a finish frame with finishReason STOP, and that the final
// frame carries non-null content with at least one part.
func TestStreamGenerateContentEmitsSSE(t *testing.T) {
	upstream := makeGeminiUpstreamResponse(
		`data: {"p":"response/content","v":"hello "}`,
		`data: {"p":"response/content","v":"world"}`,
		`data: [DONE]`,
	)
	h := &Handler{
		Store: testGeminiConfig{},
		Auth:  testGeminiAuth{},
		DS:    testGeminiDS{resp: upstream},
	}
	r := chi.NewRouter()
	RegisterRoutes(r, h)

	body := `{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`
	req := httptest.NewRequest(http.MethodPost, "/v1/models/gemini-2.5-pro:streamGenerateContent?alt=sse", strings.NewReader(body))
	req.Header.Set("Authorization", "Bearer direct-token")
	rec := httptest.NewRecorder()
	r.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
	}
	if !strings.Contains(rec.Body.String(), "data: ") {
		t.Fatalf("expected SSE data frames, got body=%s", rec.Body.String())
	}
	if !strings.Contains(rec.Body.String(), `"finishReason":"STOP"`) {
		t.Fatalf("expected stream finish frame, got body=%s", rec.Body.String())
	}

	// Decode the frames and inspect the terminal one.
	frames := extractGeminiSSEFrames(t, rec.Body.String())
	if len(frames) == 0 {
		t.Fatalf("expected non-empty sse frames, body=%s", rec.Body.String())
	}
	last := frames[len(frames)-1]
	candidates, _ := last["candidates"].([]any)
	if len(candidates) == 0 {
		t.Fatalf("expected finish frame candidates, got %#v", last)
	}
	c0, _ := candidates[0].(map[string]any)
	content, _ := c0["content"].(map[string]any)
	if content == nil {
		t.Fatalf("expected non-null content in finish frame, got %#v", c0)
	}
	parts, _ := content["parts"].([]any)
	if len(parts) == 0 {
		t.Fatalf("expected non-empty parts in finish frame content, got %#v", content)
	}
}
|
||||||
|
|
||||||
|
// extractGeminiSSEFrames scans an SSE response body and returns every
// JSON-decodable "data: " frame, in order. Non-data lines, blank payloads,
// and frames that are not valid JSON objects (e.g. [DONE]) are skipped.
func extractGeminiSSEFrames(t *testing.T, body string) []map[string]any {
	t.Helper()
	frames := make([]map[string]any, 0, 4)
	sc := bufio.NewScanner(strings.NewReader(body))
	for sc.Scan() {
		raw, ok := strings.CutPrefix(strings.TrimSpace(sc.Text()), "data: ")
		if !ok {
			continue
		}
		raw = strings.TrimSpace(raw)
		if raw == "" {
			continue
		}
		var frame map[string]any
		if json.Unmarshal([]byte(raw), &frame) != nil {
			continue
		}
		frames = append(frames, frame)
	}
	return frames
}
|
||||||
270
internal/adapter/openai/chat_stream_runtime.go
Normal file
270
internal/adapter/openai/chat_stream_runtime.go
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
openaifmt "ds2api/internal/format/openai"
|
||||||
|
"ds2api/internal/sse"
|
||||||
|
streamengine "ds2api/internal/stream"
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// chatStreamRuntime accumulates per-request state while translating a
// DeepSeek SSE stream into OpenAI chat.completion.chunk frames.
type chatStreamRuntime struct {
	w        http.ResponseWriter
	rc       *http.ResponseController
	canFlush bool // whether w also implements http.Flusher

	completionID string // stable id echoed in every chunk
	created      int64  // creation timestamp echoed in every chunk
	model        string
	finalPrompt  string // used for final usage estimation
	toolNames    []string

	thinkingEnabled bool
	searchEnabled   bool // when set, citation fragments are dropped

	firstChunkSent       bool // whether the role:"assistant" delta was sent
	bufferToolContent    bool // route text through the tool sieve
	emitEarlyToolDeltas  bool // forward incremental tool_call deltas
	toolCallsEmitted     bool // at least one tool_call delta was sent
	toolCallsDoneEmitted bool // a final/complete tool_calls delta was sent

	toolSieve         toolStreamSieveState
	streamToolCallIDs map[int]string // stable tool_call ids per index
	streamToolNames   map[int]string // tool names seen per index
	thinking          strings.Builder
	text              strings.Builder
}
|
||||||
|
|
||||||
|
func newChatStreamRuntime(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
rc *http.ResponseController,
|
||||||
|
canFlush bool,
|
||||||
|
completionID string,
|
||||||
|
created int64,
|
||||||
|
model string,
|
||||||
|
finalPrompt string,
|
||||||
|
thinkingEnabled bool,
|
||||||
|
searchEnabled bool,
|
||||||
|
toolNames []string,
|
||||||
|
bufferToolContent bool,
|
||||||
|
emitEarlyToolDeltas bool,
|
||||||
|
) *chatStreamRuntime {
|
||||||
|
return &chatStreamRuntime{
|
||||||
|
w: w,
|
||||||
|
rc: rc,
|
||||||
|
canFlush: canFlush,
|
||||||
|
completionID: completionID,
|
||||||
|
created: created,
|
||||||
|
model: model,
|
||||||
|
finalPrompt: finalPrompt,
|
||||||
|
toolNames: toolNames,
|
||||||
|
thinkingEnabled: thinkingEnabled,
|
||||||
|
searchEnabled: searchEnabled,
|
||||||
|
bufferToolContent: bufferToolContent,
|
||||||
|
emitEarlyToolDeltas: emitEarlyToolDeltas,
|
||||||
|
streamToolCallIDs: map[int]string{},
|
||||||
|
streamToolNames: map[int]string{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendKeepAlive writes an SSE comment frame to keep idle connections alive.
// It is skipped entirely when the writer cannot flush, since a buffered
// keep-alive would never reach the client in time.
func (s *chatStreamRuntime) sendKeepAlive() {
	if !s.canFlush {
		return
	}
	_, _ = s.w.Write([]byte(": keep-alive\n\n"))
	_ = s.rc.Flush()
}
|
||||||
|
|
||||||
|
func (s *chatStreamRuntime) sendChunk(v any) {
|
||||||
|
b, _ := json.Marshal(v)
|
||||||
|
_, _ = s.w.Write([]byte("data: "))
|
||||||
|
_, _ = s.w.Write(b)
|
||||||
|
_, _ = s.w.Write([]byte("\n\n"))
|
||||||
|
if s.canFlush {
|
||||||
|
_ = s.rc.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendDone terminates the SSE stream with the OpenAI [DONE] sentinel frame.
func (s *chatStreamRuntime) sendDone() {
	_, _ = s.w.Write([]byte("data: [DONE]\n\n"))
	if s.canFlush {
		_ = s.rc.Flush()
	}
}
|
||||||
|
|
||||||
|
// finalize flushes any deferred output and closes the stream. Three paths:
//  1. Tool calls parsed from the full text and not yet emitted as a complete
//     delta: send one tool_calls delta with stable ids.
//  2. Otherwise, in buffered (tool) mode: flush whatever remains in the tool
//     sieve, emitting tool_calls and/or plain content deltas.
//  3. Always: send the finish chunk (finishReason "tool_calls" when any tool
//     call was detected or emitted, else the caller-supplied reason) with
//     usage metadata, then the [DONE] sentinel.
func (s *chatStreamRuntime) finalize(finishReason string) {
	finalThinking := s.thinking.String()
	finalText := s.text.String()
	detected := util.ParseToolCalls(finalText, s.toolNames)
	if len(detected) > 0 && !s.toolCallsDoneEmitted {
		finishReason = "tool_calls"
		delta := map[string]any{
			"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected, s.streamToolCallIDs),
		}
		// First chunk of the stream must carry the assistant role.
		if !s.firstChunkSent {
			delta["role"] = "assistant"
			s.firstChunkSent = true
		}
		s.sendChunk(openaifmt.BuildChatStreamChunk(
			s.completionID,
			s.created,
			s.model,
			[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, delta)},
			nil,
		))
		s.toolCallsEmitted = true
		s.toolCallsDoneEmitted = true
	} else if s.bufferToolContent {
		// Drain the sieve: each event may carry complete tool calls and/or
		// residual plain content.
		for _, evt := range flushToolSieve(&s.toolSieve, s.toolNames) {
			if len(evt.ToolCalls) > 0 {
				finishReason = "tool_calls"
				s.toolCallsEmitted = true
				s.toolCallsDoneEmitted = true
				tcDelta := map[string]any{
					"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
				}
				if !s.firstChunkSent {
					tcDelta["role"] = "assistant"
					s.firstChunkSent = true
				}
				s.sendChunk(openaifmt.BuildChatStreamChunk(
					s.completionID,
					s.created,
					s.model,
					[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, tcDelta)},
					nil,
				))
			}
			if evt.Content == "" {
				continue
			}
			delta := map[string]any{
				"content": evt.Content,
			}
			if !s.firstChunkSent {
				delta["role"] = "assistant"
				s.firstChunkSent = true
			}
			s.sendChunk(openaifmt.BuildChatStreamChunk(
				s.completionID,
				s.created,
				s.model,
				[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, delta)},
				nil,
			))
		}
	}

	// Tool calls anywhere in the stream force the tool_calls finish reason.
	if len(detected) > 0 || s.toolCallsEmitted {
		finishReason = "tool_calls"
	}
	s.sendChunk(openaifmt.BuildChatStreamChunk(
		s.completionID,
		s.created,
		s.model,
		[]map[string]any{openaifmt.BuildChatStreamFinishChoice(0, finishReason)},
		openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText),
	))
	s.sendDone()
}
|
||||||
|
|
||||||
|
// onParsed handles one parsed upstream SSE event for the OpenAI chat stream.
// Thinking parts become reasoning_content deltas; text parts become content
// deltas, or — in buffered tool mode — are routed through the tool sieve,
// which may yield incremental tool_call deltas, complete tool calls, or
// pass-through content. All deltas collected from this event are sent as a
// single chat.completion.chunk.
func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
	if !parsed.Parsed {
		return streamengine.ParsedDecision{}
	}
	if parsed.ContentFilter || parsed.ErrorMessage != "" {
		return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("content_filter")}
	}
	if parsed.Stop {
		return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReasonHandlerRequested}
	}

	newChoices := make([]map[string]any, 0, len(parsed.Parts))
	contentSeen := false
	for _, p := range parsed.Parts {
		// Drop citation fragments injected by the search feature.
		if s.searchEnabled && sse.IsCitation(p.Text) {
			continue
		}
		if p.Text == "" {
			continue
		}
		contentSeen = true
		delta := map[string]any{}
		// First chunk of the stream must carry the assistant role.
		if !s.firstChunkSent {
			delta["role"] = "assistant"
			s.firstChunkSent = true
		}
		if p.Type == "thinking" {
			if s.thinkingEnabled {
				s.thinking.WriteString(p.Text)
				delta["reasoning_content"] = p.Text
			}
		} else {
			s.text.WriteString(p.Text)
			if !s.bufferToolContent {
				delta["content"] = p.Text
			} else {
				// Tool mode: the sieve decides what (if anything) leaves now.
				events := processToolSieveChunk(&s.toolSieve, p.Text, s.toolNames)
				for _, evt := range events {
					if len(evt.ToolCallDeltas) > 0 {
						// Incremental tool-call argument deltas; only
						// forwarded when early emission is enabled and the
						// tool name is in the allowed set.
						if !s.emitEarlyToolDeltas {
							continue
						}
						filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.toolNames, s.streamToolNames)
						if len(filtered) == 0 {
							continue
						}
						formatted := formatIncrementalStreamToolCallDeltas(filtered, s.streamToolCallIDs)
						if len(formatted) == 0 {
							continue
						}
						tcDelta := map[string]any{
							"tool_calls": formatted,
						}
						s.toolCallsEmitted = true
						if !s.firstChunkSent {
							tcDelta["role"] = "assistant"
							s.firstChunkSent = true
						}
						newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, tcDelta))
						continue
					}
					if len(evt.ToolCalls) > 0 {
						// Complete tool calls recognized mid-stream.
						s.toolCallsEmitted = true
						s.toolCallsDoneEmitted = true
						tcDelta := map[string]any{
							"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
						}
						if !s.firstChunkSent {
							tcDelta["role"] = "assistant"
							s.firstChunkSent = true
						}
						newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, tcDelta))
						continue
					}
					if evt.Content != "" {
						// Plain text the sieve decided was not tool syntax.
						contentDelta := map[string]any{
							"content": evt.Content,
						}
						if !s.firstChunkSent {
							contentDelta["role"] = "assistant"
							s.firstChunkSent = true
						}
						newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, contentDelta))
					}
				}
			}
		}
		if len(delta) > 0 {
			newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, delta))
		}
	}

	if len(newChoices) > 0 {
		s.sendChunk(openaifmt.BuildChatStreamChunk(s.completionID, s.created, s.model, newChoices, nil))
	}
	return streamengine.ParsedDecision{ContentSeen: contentSeen}
}
|
||||||
35
internal/adapter/openai/deps.go
Normal file
35
internal/adapter/openai/deps.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"ds2api/internal/auth"
|
||||||
|
"ds2api/internal/config"
|
||||||
|
"ds2api/internal/deepseek"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthResolver resolves request credentials for this adapter and releases
// them when the request finishes. Satisfied by *auth.Resolver.
type AuthResolver interface {
	// Determine resolves the auth (account/token) for a request.
	Determine(req *http.Request) (*auth.RequestAuth, error)
	// DetermineCaller is a variant of Determine; see auth.Resolver for its
	// exact semantics (presumably caller-identity-only resolution — confirm).
	DetermineCaller(req *http.Request) (*auth.RequestAuth, error)
	// Release returns the resolved auth to the resolver.
	Release(a *auth.RequestAuth)
}
|
||||||
|
|
||||||
|
// DeepSeekCaller is the subset of the DeepSeek client used by this adapter:
// session creation, proof-of-work retrieval, and the completion call itself.
// Each call accepts a retry budget via maxAttempts. Satisfied by
// *deepseek.Client.
type DeepSeekCaller interface {
	CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
	GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
	CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
}
|
||||||
|
|
||||||
|
// ConfigReader exposes the read-only configuration knobs this adapter reads.
// Satisfied by *config.Store.
type ConfigReader interface {
	ModelAliases() map[string]string
	CompatWideInputStrictOutput() bool
	ToolcallMode() string
	ToolcallEarlyEmitConfidence() string
	ResponsesStoreTTLSeconds() int
	EmbeddingsProvider() string
}
|
||||||
|
|
||||||
|
// Compile-time checks that the concrete project types satisfy this package's
// consumer-side interfaces.
var _ AuthResolver = (*auth.Resolver)(nil)
var _ DeepSeekCaller = (*deepseek.Client)(nil)
var _ ConfigReader = (*config.Store)(nil)
|
||||||
70
internal/adapter/openai/deps_injection_test.go
Normal file
70
internal/adapter/openai/deps_injection_test.go
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
// mockOpenAIConfig is a ConfigReader fixture; each getter echoes the
// corresponding field.
type mockOpenAIConfig struct {
	aliases      map[string]string // ModelAliases
	wideInput    bool              // CompatWideInputStrictOutput
	toolMode     string            // ToolcallMode
	earlyEmit    string            // ToolcallEarlyEmitConfidence
	responsesTTL int               // ResponsesStoreTTLSeconds
	embedProv    string            // EmbeddingsProvider
}
|
||||||
|
|
||||||
|
// ConfigReader implementation: each getter returns the fixture field verbatim.
func (m mockOpenAIConfig) ModelAliases() map[string]string { return m.aliases }
func (m mockOpenAIConfig) CompatWideInputStrictOutput() bool {
	return m.wideInput
}
func (m mockOpenAIConfig) ToolcallMode() string                { return m.toolMode }
func (m mockOpenAIConfig) ToolcallEarlyEmitConfidence() string { return m.earlyEmit }
func (m mockOpenAIConfig) ResponsesStoreTTLSeconds() int       { return m.responsesTTL }
func (m mockOpenAIConfig) EmbeddingsProvider() string          { return m.embedProv }
|
||||||
|
|
||||||
|
// TestNormalizeOpenAIChatRequestWithConfigInterface verifies that request
// normalization resolves model aliases through the ConfigReader interface and
// derives the search/thinking flags from the resolved model name.
func TestNormalizeOpenAIChatRequestWithConfigInterface(t *testing.T) {
	cfg := mockOpenAIConfig{
		aliases: map[string]string{
			"my-model": "deepseek-chat-search",
		},
		wideInput: true,
	}
	req := map[string]any{
		"model":    "my-model",
		"messages": []any{map[string]any{"role": "user", "content": "hello"}},
	}
	out, err := normalizeOpenAIChatRequest(cfg, req, "")
	if err != nil {
		t.Fatalf("normalizeOpenAIChatRequest error: %v", err)
	}
	if out.ResolvedModel != "deepseek-chat-search" {
		t.Fatalf("resolved model mismatch: got=%q", out.ResolvedModel)
	}
	// "-search" suffix enables search; no "-think"/reasoning marker means
	// thinking stays off.
	if !out.Search || out.Thinking {
		t.Fatalf("unexpected model flags: thinking=%v search=%v", out.Thinking, out.Search)
	}
}
|
||||||
|
|
||||||
|
// TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface verifies
// that a Responses-API request supplying only "input" is rejected when the
// wide-input compatibility policy is disabled and accepted when enabled.
func TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface(t *testing.T) {
	req := map[string]any{
		"model": "deepseek-chat",
		"input": "hi",
	}

	// Policy off: input-only requests must be refused.
	_, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{
		aliases:   map[string]string{},
		wideInput: false,
	}, req, "")
	if err == nil {
		t.Fatal("expected error when wide input is disabled and only input is provided")
	}

	// Policy on: the same request normalizes successfully.
	out, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{
		aliases:   map[string]string{},
		wideInput: true,
	}, req, "")
	if err != nil {
		t.Fatalf("unexpected error when wide input is enabled: %v", err)
	}
	if out.Surface != "openai_responses" {
		t.Fatalf("unexpected surface: %q", out.Surface)
	}
}
|
||||||
138
internal/adapter/openai/embeddings_handler.go
Normal file
138
internal/adapter/openai/embeddings_handler.go
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"ds2api/internal/auth"
|
||||||
|
"ds2api/internal/config"
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (h *Handler) Embeddings(w http.ResponseWriter, r *http.Request) {
|
||||||
|
a, err := h.Auth.Determine(r)
|
||||||
|
if err != nil {
|
||||||
|
status := http.StatusUnauthorized
|
||||||
|
detail := err.Error()
|
||||||
|
if err == auth.ErrNoAccount {
|
||||||
|
status = http.StatusTooManyRequests
|
||||||
|
}
|
||||||
|
writeOpenAIError(w, status, detail)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer h.Auth.Release(a)
|
||||||
|
|
||||||
|
var req map[string]any
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
model, _ := req["model"].(string)
|
||||||
|
model = strings.TrimSpace(model)
|
||||||
|
if model == "" {
|
||||||
|
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model'.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, ok := config.ResolveModel(h.Store, model); !ok {
|
||||||
|
writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
inputs := extractEmbeddingInputs(req["input"])
|
||||||
|
if len(inputs) == 0 {
|
||||||
|
writeOpenAIError(w, http.StatusBadRequest, "Request must include non-empty 'input'.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
provider := ""
|
||||||
|
if h.Store != nil {
|
||||||
|
provider = strings.ToLower(strings.TrimSpace(h.Store.EmbeddingsProvider()))
|
||||||
|
}
|
||||||
|
if provider == "" {
|
||||||
|
writeOpenAIError(w, http.StatusNotImplemented, "Embeddings provider is not configured. Set embeddings.provider in config.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch provider {
|
||||||
|
case "mock", "deterministic", "builtin":
|
||||||
|
// supported local deterministic provider
|
||||||
|
default:
|
||||||
|
writeOpenAIError(w, http.StatusNotImplemented, fmt.Sprintf("Embeddings provider '%s' is not supported.", provider))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data := make([]map[string]any, 0, len(inputs))
|
||||||
|
totalTokens := 0
|
||||||
|
for i, input := range inputs {
|
||||||
|
totalTokens += util.EstimateTokens(input)
|
||||||
|
data = append(data, map[string]any{
|
||||||
|
"object": "embedding",
|
||||||
|
"index": i,
|
||||||
|
"embedding": deterministicEmbedding(input),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{
|
||||||
|
"object": "list",
|
||||||
|
"data": data,
|
||||||
|
"model": model,
|
||||||
|
"usage": map[string]any{
|
||||||
|
"prompt_tokens": totalTokens,
|
||||||
|
"total_tokens": totalTokens,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractEmbeddingInputs(raw any) []string {
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case string:
|
||||||
|
s := strings.TrimSpace(v)
|
||||||
|
if s == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []string{s}
|
||||||
|
case []any:
|
||||||
|
out := make([]string, 0, len(v))
|
||||||
|
for _, item := range v {
|
||||||
|
switch iv := item.(type) {
|
||||||
|
case string:
|
||||||
|
s := strings.TrimSpace(iv)
|
||||||
|
if s != "" {
|
||||||
|
out = append(out, s)
|
||||||
|
}
|
||||||
|
case []any:
|
||||||
|
// Token array input support: convert to stable string form.
|
||||||
|
out = append(out, fmt.Sprintf("%v", iv))
|
||||||
|
default:
|
||||||
|
s := strings.TrimSpace(fmt.Sprintf("%v", iv))
|
||||||
|
if s != "" {
|
||||||
|
out = append(out, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func deterministicEmbedding(input string) []float64 {
|
||||||
|
// Keep response shape stable without external dependencies.
|
||||||
|
const dims = 64
|
||||||
|
out := make([]float64, dims)
|
||||||
|
seed := sha256.Sum256([]byte(input))
|
||||||
|
buf := seed[:]
|
||||||
|
for i := 0; i < dims; i++ {
|
||||||
|
if len(buf) < 4 {
|
||||||
|
next := sha256.Sum256(buf)
|
||||||
|
buf = next[:]
|
||||||
|
}
|
||||||
|
v := binary.BigEndian.Uint32(buf[:4])
|
||||||
|
buf = buf[4:]
|
||||||
|
// map [0, 2^32) -> [-1, 1]
|
||||||
|
out[i] = (float64(v)/2147483647.5 - 1.0)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
96
internal/adapter/openai/embeddings_route_test.go
Normal file
96
internal/adapter/openai/embeddings_route_test.go
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
|
||||||
|
"ds2api/internal/account"
|
||||||
|
"ds2api/internal/auth"
|
||||||
|
"ds2api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newResolverWithConfigJSON(t *testing.T, cfgJSON string) (*config.Store, *auth.Resolver) {
|
||||||
|
t.Helper()
|
||||||
|
t.Setenv("DS2API_CONFIG_JSON", cfgJSON)
|
||||||
|
store := config.LoadStore()
|
||||||
|
pool := account.NewPool(store)
|
||||||
|
resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
|
||||||
|
return "unused", nil
|
||||||
|
})
|
||||||
|
return store, resolver
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddingsRouteContract(t *testing.T) {
|
||||||
|
store, resolver := newResolverWithConfigJSON(t, `{"embeddings":{"provider":"deterministic"}}`)
|
||||||
|
h := &Handler{Store: store, Auth: resolver}
|
||||||
|
r := chi.NewRouter()
|
||||||
|
RegisterRoutes(r, h)
|
||||||
|
|
||||||
|
t.Run("unauthorized", func(t *testing.T) {
|
||||||
|
body := bytes.NewBufferString(`{"model":"gpt-4o","input":"hello"}`)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ok", func(t *testing.T) {
|
||||||
|
body := bytes.NewBufferString(`{"model":"gpt-4o","input":["a","b"]}`)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body)
|
||||||
|
req.Header.Set("Authorization", "Bearer test-token")
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
var out map[string]any
|
||||||
|
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||||
|
t.Fatalf("decode response failed: %v", err)
|
||||||
|
}
|
||||||
|
if out["object"] != "list" {
|
||||||
|
t.Fatalf("unexpected object: %#v", out["object"])
|
||||||
|
}
|
||||||
|
data, _ := out["data"].([]any)
|
||||||
|
if len(data) != 2 {
|
||||||
|
t.Fatalf("expected 2 embeddings, got %d", len(data))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddingsRouteProviderMissing(t *testing.T) {
|
||||||
|
store, resolver := newResolverWithConfigJSON(t, `{}`)
|
||||||
|
h := &Handler{Store: store, Auth: resolver}
|
||||||
|
r := chi.NewRouter()
|
||||||
|
RegisterRoutes(r, h)
|
||||||
|
|
||||||
|
body := bytes.NewBufferString(`{"model":"gpt-4o","input":"hello"}`)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body)
|
||||||
|
req.Header.Set("Authorization", "Bearer test-token")
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusNotImplemented {
|
||||||
|
t.Fatalf("expected 501, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
var out map[string]any
|
||||||
|
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||||
|
t.Fatalf("decode response failed: %v", err)
|
||||||
|
}
|
||||||
|
errObj, _ := out["error"].(map[string]any)
|
||||||
|
if _, ok := errObj["code"]; !ok {
|
||||||
|
t.Fatalf("expected error.code in response: %#v", out)
|
||||||
|
}
|
||||||
|
if _, ok := errObj["param"]; !ok {
|
||||||
|
t.Fatalf("expected error.param in response: %#v", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
34
internal/adapter/openai/error_shape_test.go
Normal file
34
internal/adapter/openai/error_shape_test.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWriteOpenAIErrorIncludesUnifiedFields(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
writeOpenAIError(rec, http.StatusBadRequest, "invalid input")
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||||
|
t.Fatalf("decode body: %v", err)
|
||||||
|
}
|
||||||
|
errObj, _ := body["error"].(map[string]any)
|
||||||
|
if errObj["message"] != "invalid input" {
|
||||||
|
t.Fatalf("unexpected message: %v", errObj["message"])
|
||||||
|
}
|
||||||
|
if errObj["type"] != "invalid_request_error" {
|
||||||
|
t.Fatalf("unexpected type: %v", errObj["type"])
|
||||||
|
}
|
||||||
|
if errObj["code"] != "invalid_request" {
|
||||||
|
t.Fatalf("unexpected code: %v", errObj["code"])
|
||||||
|
}
|
||||||
|
if _, ok := errObj["param"]; !ok {
|
||||||
|
t.Fatal("expected param field")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,487 +0,0 @@
|
|||||||
package openai
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
|
|
||||||
"ds2api/internal/auth"
|
|
||||||
"ds2api/internal/config"
|
|
||||||
"ds2api/internal/deepseek"
|
|
||||||
"ds2api/internal/sse"
|
|
||||||
"ds2api/internal/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
// writeJSON is a package-internal alias kept to avoid mass-renaming across
|
|
||||||
// every call-site in this file. It delegates to the shared util version.
|
|
||||||
var writeJSON = util.WriteJSON
|
|
||||||
|
|
||||||
type Handler struct {
|
|
||||||
Store *config.Store
|
|
||||||
Auth *auth.Resolver
|
|
||||||
DS *deepseek.Client
|
|
||||||
|
|
||||||
leaseMu sync.Mutex
|
|
||||||
streamLeases map[string]streamLease
|
|
||||||
}
|
|
||||||
|
|
||||||
type streamLease struct {
|
|
||||||
Auth *auth.RequestAuth
|
|
||||||
ExpiresAt time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
|
||||||
r.Get("/v1/models", h.ListModels)
|
|
||||||
r.Post("/v1/chat/completions", h.ChatCompletions)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
|
|
||||||
writeJSON(w, http.StatusOK, config.OpenAIModelsResponse())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if isVercelStreamReleaseRequest(r) {
|
|
||||||
h.handleVercelStreamRelease(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if isVercelStreamPrepareRequest(r) {
|
|
||||||
h.handleVercelStreamPrepare(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
a, err := h.Auth.Determine(r)
|
|
||||||
if err != nil {
|
|
||||||
status := http.StatusUnauthorized
|
|
||||||
detail := err.Error()
|
|
||||||
if err == auth.ErrNoAccount {
|
|
||||||
status = http.StatusTooManyRequests
|
|
||||||
}
|
|
||||||
writeOpenAIError(w, status, detail)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer h.Auth.Release(a)
|
|
||||||
r = r.WithContext(auth.WithAuth(r.Context(), a))
|
|
||||||
|
|
||||||
var req map[string]any
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
||||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
model, _ := req["model"].(string)
|
|
||||||
messagesRaw, _ := req["messages"].([]any)
|
|
||||||
if model == "" || len(messagesRaw) == 0 {
|
|
||||||
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model)
|
|
||||||
if !ok {
|
|
||||||
writeOpenAIError(w, http.StatusServiceUnavailable, fmt.Sprintf("Model '%s' is not available.", model))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
messages := normalizeMessages(messagesRaw)
|
|
||||||
toolNames := []string{}
|
|
||||||
if tools, ok := req["tools"].([]any); ok && len(tools) > 0 {
|
|
||||||
messages, toolNames = injectToolPrompt(messages, tools)
|
|
||||||
}
|
|
||||||
finalPrompt := util.MessagesPrepare(messages)
|
|
||||||
|
|
||||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
|
||||||
if err != nil {
|
|
||||||
if a.UseConfigToken {
|
|
||||||
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
|
|
||||||
} else {
|
|
||||||
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
|
||||||
if err != nil {
|
|
||||||
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
payload := map[string]any{
|
|
||||||
"chat_session_id": sessionID,
|
|
||||||
"parent_message_id": nil,
|
|
||||||
"prompt": finalPrompt,
|
|
||||||
"ref_file_ids": []any{},
|
|
||||||
"thinking_enabled": thinkingEnabled,
|
|
||||||
"search_enabled": searchEnabled,
|
|
||||||
}
|
|
||||||
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
|
|
||||||
if err != nil {
|
|
||||||
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if util.ToBool(req["stream"]) {
|
|
||||||
h.handleStream(w, r, resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
h.handleNonStream(w, r.Context(), resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
defer resp.Body.Close()
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
|
||||||
writeOpenAIError(w, resp.StatusCode, string(body))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_ = ctx
|
|
||||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
|
||||||
|
|
||||||
finalThinking := result.Thinking
|
|
||||||
finalText := result.Text
|
|
||||||
detected := util.ParseToolCalls(finalText, toolNames)
|
|
||||||
finishReason := "stop"
|
|
||||||
messageObj := map[string]any{"role": "assistant", "content": finalText}
|
|
||||||
if thinkingEnabled && finalThinking != "" {
|
|
||||||
messageObj["reasoning_content"] = finalThinking
|
|
||||||
}
|
|
||||||
if len(detected) > 0 {
|
|
||||||
finishReason = "tool_calls"
|
|
||||||
messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected)
|
|
||||||
messageObj["content"] = nil
|
|
||||||
}
|
|
||||||
promptTokens := util.EstimateTokens(finalPrompt)
|
|
||||||
reasoningTokens := util.EstimateTokens(finalThinking)
|
|
||||||
completionTokens := util.EstimateTokens(finalText)
|
|
||||||
|
|
||||||
writeJSON(w, http.StatusOK, map[string]any{
|
|
||||||
"id": completionID,
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": time.Now().Unix(),
|
|
||||||
"model": model,
|
|
||||||
"choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}},
|
|
||||||
"usage": map[string]any{
|
|
||||||
"prompt_tokens": promptTokens,
|
|
||||||
"completion_tokens": reasoningTokens + completionTokens,
|
|
||||||
"total_tokens": promptTokens + reasoningTokens + completionTokens,
|
|
||||||
"completion_tokens_details": map[string]any{
|
|
||||||
"reasoning_tokens": reasoningTokens,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
|
||||||
defer resp.Body.Close()
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
|
||||||
writeOpenAIError(w, resp.StatusCode, string(body))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
|
||||||
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
|
||||||
w.Header().Set("Connection", "keep-alive")
|
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
|
||||||
rc := http.NewResponseController(w)
|
|
||||||
canFlush := rc.Flush() == nil
|
|
||||||
if !canFlush {
|
|
||||||
config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered")
|
|
||||||
}
|
|
||||||
|
|
||||||
created := time.Now().Unix()
|
|
||||||
firstChunkSent := false
|
|
||||||
bufferToolContent := len(toolNames) > 0
|
|
||||||
var toolSieve toolStreamSieveState
|
|
||||||
toolCallsEmitted := false
|
|
||||||
initialType := "text"
|
|
||||||
if thinkingEnabled {
|
|
||||||
initialType = "thinking"
|
|
||||||
}
|
|
||||||
parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType)
|
|
||||||
thinking := strings.Builder{}
|
|
||||||
text := strings.Builder{}
|
|
||||||
lastContent := time.Now()
|
|
||||||
hasContent := false
|
|
||||||
keepaliveTicker := time.NewTicker(time.Duration(deepseek.KeepAliveTimeout) * time.Second)
|
|
||||||
defer keepaliveTicker.Stop()
|
|
||||||
keepaliveCountWithoutContent := 0
|
|
||||||
|
|
||||||
sendChunk := func(v any) {
|
|
||||||
b, _ := json.Marshal(v)
|
|
||||||
_, _ = w.Write([]byte("data: "))
|
|
||||||
_, _ = w.Write(b)
|
|
||||||
_, _ = w.Write([]byte("\n\n"))
|
|
||||||
if canFlush {
|
|
||||||
_ = rc.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sendDone := func() {
|
|
||||||
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
|
||||||
if canFlush {
|
|
||||||
_ = rc.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
finalize := func(finishReason string) {
|
|
||||||
finalThinking := thinking.String()
|
|
||||||
finalText := text.String()
|
|
||||||
detected := util.ParseToolCalls(finalText, toolNames)
|
|
||||||
if len(detected) > 0 && !toolCallsEmitted {
|
|
||||||
finishReason = "tool_calls"
|
|
||||||
delta := map[string]any{
|
|
||||||
"tool_calls": util.FormatOpenAIStreamToolCalls(detected),
|
|
||||||
}
|
|
||||||
if !firstChunkSent {
|
|
||||||
delta["role"] = "assistant"
|
|
||||||
firstChunkSent = true
|
|
||||||
}
|
|
||||||
sendChunk(map[string]any{
|
|
||||||
"id": completionID,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": created,
|
|
||||||
"model": model,
|
|
||||||
"choices": []map[string]any{{"delta": delta, "index": 0}},
|
|
||||||
})
|
|
||||||
} else if bufferToolContent {
|
|
||||||
for _, evt := range flushToolSieve(&toolSieve, toolNames) {
|
|
||||||
if evt.Content == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
delta := map[string]any{
|
|
||||||
"content": evt.Content,
|
|
||||||
}
|
|
||||||
if !firstChunkSent {
|
|
||||||
delta["role"] = "assistant"
|
|
||||||
firstChunkSent = true
|
|
||||||
}
|
|
||||||
sendChunk(map[string]any{
|
|
||||||
"id": completionID,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": created,
|
|
||||||
"model": model,
|
|
||||||
"choices": []map[string]any{{"delta": delta, "index": 0}},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(detected) > 0 || toolCallsEmitted {
|
|
||||||
finishReason = "tool_calls"
|
|
||||||
}
|
|
||||||
promptTokens := util.EstimateTokens(finalPrompt)
|
|
||||||
reasoningTokens := util.EstimateTokens(finalThinking)
|
|
||||||
completionTokens := util.EstimateTokens(finalText)
|
|
||||||
sendChunk(map[string]any{
|
|
||||||
"id": completionID,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": created,
|
|
||||||
"model": model,
|
|
||||||
"choices": []map[string]any{{"delta": map[string]any{}, "index": 0, "finish_reason": finishReason}},
|
|
||||||
"usage": map[string]any{
|
|
||||||
"prompt_tokens": promptTokens,
|
|
||||||
"completion_tokens": reasoningTokens + completionTokens,
|
|
||||||
"total_tokens": promptTokens + reasoningTokens + completionTokens,
|
|
||||||
"completion_tokens_details": map[string]any{
|
|
||||||
"reasoning_tokens": reasoningTokens,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
sendDone()
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-r.Context().Done():
|
|
||||||
return
|
|
||||||
case <-keepaliveTicker.C:
|
|
||||||
if !hasContent {
|
|
||||||
keepaliveCountWithoutContent++
|
|
||||||
if keepaliveCountWithoutContent >= deepseek.MaxKeepaliveCount {
|
|
||||||
finalize("stop")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if hasContent && time.Since(lastContent) > time.Duration(deepseek.StreamIdleTimeout)*time.Second {
|
|
||||||
finalize("stop")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if canFlush {
|
|
||||||
_, _ = w.Write([]byte(": keep-alive\n\n"))
|
|
||||||
_ = rc.Flush()
|
|
||||||
}
|
|
||||||
case parsed, ok := <-parsedLines:
|
|
||||||
if !ok {
|
|
||||||
// Ensure scanner completion is observed only after all queued
|
|
||||||
// SSE lines are drained, avoiding early finalize races.
|
|
||||||
_ = <-done
|
|
||||||
finalize("stop")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !parsed.Parsed {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if parsed.ContentFilter || parsed.ErrorMessage != "" {
|
|
||||||
finalize("content_filter")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if parsed.Stop {
|
|
||||||
finalize("stop")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
newChoices := make([]map[string]any, 0, len(parsed.Parts))
|
|
||||||
for _, p := range parsed.Parts {
|
|
||||||
if searchEnabled && sse.IsCitation(p.Text) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if p.Text == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
hasContent = true
|
|
||||||
lastContent = time.Now()
|
|
||||||
keepaliveCountWithoutContent = 0
|
|
||||||
delta := map[string]any{}
|
|
||||||
if !firstChunkSent {
|
|
||||||
delta["role"] = "assistant"
|
|
||||||
firstChunkSent = true
|
|
||||||
}
|
|
||||||
if p.Type == "thinking" {
|
|
||||||
if thinkingEnabled {
|
|
||||||
thinking.WriteString(p.Text)
|
|
||||||
delta["reasoning_content"] = p.Text
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
text.WriteString(p.Text)
|
|
||||||
if !bufferToolContent {
|
|
||||||
delta["content"] = p.Text
|
|
||||||
} else {
|
|
||||||
events := processToolSieveChunk(&toolSieve, p.Text, toolNames)
|
|
||||||
if len(events) == 0 {
|
|
||||||
// Keep thinking delta only frame.
|
|
||||||
}
|
|
||||||
for _, evt := range events {
|
|
||||||
if len(evt.ToolCalls) > 0 {
|
|
||||||
toolCallsEmitted = true
|
|
||||||
tcDelta := map[string]any{
|
|
||||||
"tool_calls": util.FormatOpenAIStreamToolCalls(evt.ToolCalls),
|
|
||||||
}
|
|
||||||
if !firstChunkSent {
|
|
||||||
tcDelta["role"] = "assistant"
|
|
||||||
firstChunkSent = true
|
|
||||||
}
|
|
||||||
newChoices = append(newChoices, map[string]any{
|
|
||||||
"delta": tcDelta,
|
|
||||||
"index": 0,
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if evt.Content != "" {
|
|
||||||
contentDelta := map[string]any{
|
|
||||||
"content": evt.Content,
|
|
||||||
}
|
|
||||||
if !firstChunkSent {
|
|
||||||
contentDelta["role"] = "assistant"
|
|
||||||
firstChunkSent = true
|
|
||||||
}
|
|
||||||
newChoices = append(newChoices, map[string]any{
|
|
||||||
"delta": contentDelta,
|
|
||||||
"index": 0,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(delta) > 0 {
|
|
||||||
newChoices = append(newChoices, map[string]any{"delta": delta, "index": 0})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(newChoices) > 0 {
|
|
||||||
sendChunk(map[string]any{
|
|
||||||
"id": completionID,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": created,
|
|
||||||
"model": model,
|
|
||||||
"choices": newChoices,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeMessages(raw []any) []map[string]any {
|
|
||||||
out := make([]map[string]any, 0, len(raw))
|
|
||||||
for _, item := range raw {
|
|
||||||
m, ok := item.(map[string]any)
|
|
||||||
if ok {
|
|
||||||
out = append(out, m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) {
|
|
||||||
toolSchemas := make([]string, 0, len(tools))
|
|
||||||
names := make([]string, 0, len(tools))
|
|
||||||
for _, t := range tools {
|
|
||||||
tool, ok := t.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
fn, _ := tool["function"].(map[string]any)
|
|
||||||
if len(fn) == 0 {
|
|
||||||
fn = tool
|
|
||||||
}
|
|
||||||
name, _ := fn["name"].(string)
|
|
||||||
desc, _ := fn["description"].(string)
|
|
||||||
schema, _ := fn["parameters"].(map[string]any)
|
|
||||||
if name == "" {
|
|
||||||
name = "unknown"
|
|
||||||
}
|
|
||||||
names = append(names, name)
|
|
||||||
if desc == "" {
|
|
||||||
desc = "No description available"
|
|
||||||
}
|
|
||||||
b, _ := json.Marshal(schema)
|
|
||||||
toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b)))
|
|
||||||
}
|
|
||||||
if len(toolSchemas) == 0 {
|
|
||||||
return messages, names
|
|
||||||
}
|
|
||||||
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nIMPORTANT: If calling tools, output ONLY the JSON. The response must start with { and end with }"
|
|
||||||
|
|
||||||
for i := range messages {
|
|
||||||
if messages[i]["role"] == "system" {
|
|
||||||
old, _ := messages[i]["content"].(string)
|
|
||||||
messages[i]["content"] = strings.TrimSpace(old + "\n\n" + toolPrompt)
|
|
||||||
return messages, names
|
|
||||||
}
|
|
||||||
}
|
|
||||||
messages = append([]map[string]any{{"role": "system", "content": toolPrompt}}, messages...)
|
|
||||||
return messages, names
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeOpenAIError(w http.ResponseWriter, status int, message string) {
|
|
||||||
writeJSON(w, status, map[string]any{
|
|
||||||
"error": map[string]any{
|
|
||||||
"message": message,
|
|
||||||
"type": openAIErrorType(status),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func openAIErrorType(status int) string {
|
|
||||||
switch status {
|
|
||||||
case http.StatusBadRequest:
|
|
||||||
return "invalid_request_error"
|
|
||||||
case http.StatusUnauthorized:
|
|
||||||
return "authentication_error"
|
|
||||||
case http.StatusForbidden:
|
|
||||||
return "permission_error"
|
|
||||||
case http.StatusTooManyRequests:
|
|
||||||
return "rate_limit_error"
|
|
||||||
case http.StatusServiceUnavailable:
|
|
||||||
return "service_unavailable_error"
|
|
||||||
default:
|
|
||||||
if status >= 500 {
|
|
||||||
return "api_error"
|
|
||||||
}
|
|
||||||
return "invalid_request_error"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
156
internal/adapter/openai/handler_chat.go
Normal file
156
internal/adapter/openai/handler_chat.go
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"ds2api/internal/auth"
|
||||||
|
"ds2api/internal/config"
|
||||||
|
"ds2api/internal/deepseek"
|
||||||
|
openaifmt "ds2api/internal/format/openai"
|
||||||
|
"ds2api/internal/sse"
|
||||||
|
streamengine "ds2api/internal/stream"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if isVercelStreamReleaseRequest(r) {
|
||||||
|
h.handleVercelStreamRelease(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if isVercelStreamPrepareRequest(r) {
|
||||||
|
h.handleVercelStreamPrepare(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
a, err := h.Auth.Determine(r)
|
||||||
|
if err != nil {
|
||||||
|
status := http.StatusUnauthorized
|
||||||
|
detail := err.Error()
|
||||||
|
if err == auth.ErrNoAccount {
|
||||||
|
status = http.StatusTooManyRequests
|
||||||
|
}
|
||||||
|
writeOpenAIError(w, status, detail)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer h.Auth.Release(a)
|
||||||
|
r = r.WithContext(auth.WithAuth(r.Context(), a))
|
||||||
|
|
||||||
|
var req map[string]any
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
|
||||||
|
if err != nil {
|
||||||
|
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||||
|
if err != nil {
|
||||||
|
if a.UseConfigToken {
|
||||||
|
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
|
||||||
|
} else {
|
||||||
|
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
||||||
|
if err != nil {
|
||||||
|
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
payload := stdReq.CompletionPayload(sessionID)
|
||||||
|
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
|
||||||
|
if err != nil {
|
||||||
|
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if stdReq.Stream {
|
||||||
|
h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
writeOpenAIError(w, resp.StatusCode, string(body))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = ctx
|
||||||
|
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||||
|
|
||||||
|
finalThinking := result.Thinking
|
||||||
|
finalText := result.Text
|
||||||
|
respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
|
||||||
|
writeJSON(w, http.StatusOK, respBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
writeOpenAIError(w, resp.StatusCode, string(body))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||||
|
w.Header().Set("Connection", "keep-alive")
|
||||||
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
rc := http.NewResponseController(w)
|
||||||
|
_, canFlush := w.(http.Flusher)
|
||||||
|
if !canFlush {
|
||||||
|
config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered")
|
||||||
|
}
|
||||||
|
|
||||||
|
created := time.Now().Unix()
|
||||||
|
bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
|
||||||
|
emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
|
||||||
|
initialType := "text"
|
||||||
|
if thinkingEnabled {
|
||||||
|
initialType = "thinking"
|
||||||
|
}
|
||||||
|
|
||||||
|
streamRuntime := newChatStreamRuntime(
|
||||||
|
w,
|
||||||
|
rc,
|
||||||
|
canFlush,
|
||||||
|
completionID,
|
||||||
|
created,
|
||||||
|
model,
|
||||||
|
finalPrompt,
|
||||||
|
thinkingEnabled,
|
||||||
|
searchEnabled,
|
||||||
|
toolNames,
|
||||||
|
bufferToolContent,
|
||||||
|
emitEarlyToolDeltas,
|
||||||
|
)
|
||||||
|
|
||||||
|
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
|
||||||
|
Context: r.Context(),
|
||||||
|
Body: resp.Body,
|
||||||
|
ThinkingEnabled: thinkingEnabled,
|
||||||
|
InitialType: initialType,
|
||||||
|
KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
|
||||||
|
IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
|
||||||
|
MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
|
||||||
|
}, streamengine.ConsumeHooks{
|
||||||
|
OnKeepAlive: func() {
|
||||||
|
streamRuntime.sendKeepAlive()
|
||||||
|
},
|
||||||
|
OnParsed: streamRuntime.onParsed,
|
||||||
|
OnFinalize: func(reason streamengine.StopReason, _ error) {
|
||||||
|
if string(reason) == "content_filter" {
|
||||||
|
streamRuntime.finalize("content_filter")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
streamRuntime.finalize("stop")
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
63
internal/adapter/openai/handler_errors.go
Normal file
63
internal/adapter/openai/handler_errors.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
func writeOpenAIError(w http.ResponseWriter, status int, message string) {
|
||||||
|
writeOpenAIErrorWithCode(w, status, message, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeOpenAIErrorWithCode(w http.ResponseWriter, status int, message, code string) {
|
||||||
|
if code == "" {
|
||||||
|
code = openAIErrorCode(status)
|
||||||
|
}
|
||||||
|
writeJSON(w, status, map[string]any{
|
||||||
|
"error": map[string]any{
|
||||||
|
"message": message,
|
||||||
|
"type": openAIErrorType(status),
|
||||||
|
"code": code,
|
||||||
|
"param": nil,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIErrorType(status int) string {
|
||||||
|
switch status {
|
||||||
|
case http.StatusBadRequest:
|
||||||
|
return "invalid_request_error"
|
||||||
|
case http.StatusUnauthorized:
|
||||||
|
return "authentication_error"
|
||||||
|
case http.StatusForbidden:
|
||||||
|
return "permission_error"
|
||||||
|
case http.StatusTooManyRequests:
|
||||||
|
return "rate_limit_error"
|
||||||
|
case http.StatusServiceUnavailable:
|
||||||
|
return "service_unavailable_error"
|
||||||
|
default:
|
||||||
|
if status >= 500 {
|
||||||
|
return "api_error"
|
||||||
|
}
|
||||||
|
return "invalid_request_error"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIErrorCode(status int) string {
|
||||||
|
switch status {
|
||||||
|
case http.StatusBadRequest:
|
||||||
|
return "invalid_request"
|
||||||
|
case http.StatusUnauthorized:
|
||||||
|
return "authentication_failed"
|
||||||
|
case http.StatusForbidden:
|
||||||
|
return "forbidden"
|
||||||
|
case http.StatusTooManyRequests:
|
||||||
|
return "rate_limit_exceeded"
|
||||||
|
case http.StatusNotFound:
|
||||||
|
return "not_found"
|
||||||
|
case http.StatusServiceUnavailable:
|
||||||
|
return "service_unavailable"
|
||||||
|
default:
|
||||||
|
if status >= 500 {
|
||||||
|
return "internal_error"
|
||||||
|
}
|
||||||
|
return "invalid_request"
|
||||||
|
}
|
||||||
|
}
|
||||||
57
internal/adapter/openai/handler_routes.go
Normal file
57
internal/adapter/openai/handler_routes.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
|
||||||
|
"ds2api/internal/auth"
|
||||||
|
"ds2api/internal/config"
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// writeJSON is a package-internal alias kept to avoid mass-renaming across
|
||||||
|
// every call-site in this package.
|
||||||
|
var writeJSON = util.WriteJSON
|
||||||
|
|
||||||
|
type Handler struct {
|
||||||
|
Store ConfigReader
|
||||||
|
Auth AuthResolver
|
||||||
|
DS DeepSeekCaller
|
||||||
|
|
||||||
|
leaseMu sync.Mutex
|
||||||
|
streamLeases map[string]streamLease
|
||||||
|
responsesMu sync.Mutex
|
||||||
|
responses *responseStore
|
||||||
|
}
|
||||||
|
|
||||||
|
type streamLease struct {
|
||||||
|
Auth *auth.RequestAuth
|
||||||
|
ExpiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||||
|
r.Get("/v1/models", h.ListModels)
|
||||||
|
r.Get("/v1/models/{model_id}", h.GetModel)
|
||||||
|
r.Post("/v1/chat/completions", h.ChatCompletions)
|
||||||
|
r.Post("/v1/responses", h.Responses)
|
||||||
|
r.Get("/v1/responses/{response_id}", h.GetResponseByID)
|
||||||
|
r.Post("/v1/embeddings", h.Embeddings)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
writeJSON(w, http.StatusOK, config.OpenAIModelsResponse())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) GetModel(w http.ResponseWriter, r *http.Request) {
|
||||||
|
modelID := strings.TrimSpace(chi.URLParam(r, "model_id"))
|
||||||
|
model, ok := config.OpenAIModelByID(h.Store, modelID)
|
||||||
|
if !ok {
|
||||||
|
writeOpenAIError(w, http.StatusNotFound, "Model not found.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, model)
|
||||||
|
}
|
||||||
171
internal/adapter/openai/handler_toolcall_format.go
Normal file
171
internal/adapter/openai/handler_toolcall_format.go
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func injectToolPrompt(messages []map[string]any, tools []any, policy util.ToolChoicePolicy) ([]map[string]any, []string) {
|
||||||
|
if policy.IsNone() {
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
toolSchemas := make([]string, 0, len(tools))
|
||||||
|
names := make([]string, 0, len(tools))
|
||||||
|
isAllowed := func(name string) bool {
|
||||||
|
if strings.TrimSpace(name) == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(policy.Allowed) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
_, ok := policy.Allowed[name]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range tools {
|
||||||
|
tool, ok := t.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fn, _ := tool["function"].(map[string]any)
|
||||||
|
if len(fn) == 0 {
|
||||||
|
fn = tool
|
||||||
|
}
|
||||||
|
name, _ := fn["name"].(string)
|
||||||
|
desc, _ := fn["description"].(string)
|
||||||
|
schema, _ := fn["parameters"].(map[string]any)
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
if !isAllowed(name) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
names = append(names, name)
|
||||||
|
if desc == "" {
|
||||||
|
desc = "No description available"
|
||||||
|
}
|
||||||
|
b, _ := json.Marshal(schema)
|
||||||
|
toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b)))
|
||||||
|
}
|
||||||
|
if len(toolSchemas) == 0 {
|
||||||
|
return messages, names
|
||||||
|
}
|
||||||
|
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block."
|
||||||
|
if policy.Mode == util.ToolChoiceRequired {
|
||||||
|
toolPrompt += "\n5) For this response, you MUST call at least one tool from the allowed list."
|
||||||
|
}
|
||||||
|
if policy.Mode == util.ToolChoiceForced && strings.TrimSpace(policy.ForcedName) != "" {
|
||||||
|
toolPrompt += "\n5) For this response, you MUST call exactly this tool name: " + strings.TrimSpace(policy.ForcedName)
|
||||||
|
toolPrompt += "\n6) Do not call any other tool."
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range messages {
|
||||||
|
if messages[i]["role"] == "system" {
|
||||||
|
old, _ := messages[i]["content"].(string)
|
||||||
|
messages[i]["content"] = strings.TrimSpace(old + "\n\n" + toolPrompt)
|
||||||
|
return messages, names
|
||||||
|
}
|
||||||
|
}
|
||||||
|
messages = append([]map[string]any{{"role": "system", "content": toolPrompt}}, messages...)
|
||||||
|
return messages, names
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]string) []map[string]any {
|
||||||
|
if len(deltas) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]map[string]any, 0, len(deltas))
|
||||||
|
for _, d := range deltas {
|
||||||
|
if d.Name == "" && d.Arguments == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
callID, ok := ids[d.Index]
|
||||||
|
if !ok || callID == "" {
|
||||||
|
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||||
|
ids[d.Index] = callID
|
||||||
|
}
|
||||||
|
item := map[string]any{
|
||||||
|
"index": d.Index,
|
||||||
|
"id": callID,
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
|
fn := map[string]any{}
|
||||||
|
if d.Name != "" {
|
||||||
|
fn["name"] = d.Name
|
||||||
|
}
|
||||||
|
if d.Arguments != "" {
|
||||||
|
fn["arguments"] = d.Arguments
|
||||||
|
}
|
||||||
|
if len(fn) > 0 {
|
||||||
|
item["function"] = fn
|
||||||
|
}
|
||||||
|
out = append(out, item)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, allowedNames []string, seenNames map[int]string) []toolCallDelta {
|
||||||
|
if len(deltas) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
allowed := namesToSet(allowedNames)
|
||||||
|
if len(allowed) == 0 {
|
||||||
|
for _, d := range deltas {
|
||||||
|
if d.Name != "" {
|
||||||
|
seenNames[d.Index] = "__blocked__"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]toolCallDelta, 0, len(deltas))
|
||||||
|
for _, d := range deltas {
|
||||||
|
if d.Name != "" {
|
||||||
|
if _, ok := allowed[d.Name]; !ok {
|
||||||
|
seenNames[d.Index] = "__blocked__"
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seenNames[d.Index] = d.Name
|
||||||
|
out = append(out, d)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := strings.TrimSpace(seenNames[d.Index])
|
||||||
|
if name == "" || name == "__blocked__" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, d)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any {
|
||||||
|
if len(calls) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]map[string]any, 0, len(calls))
|
||||||
|
for i, c := range calls {
|
||||||
|
callID := ""
|
||||||
|
if ids != nil {
|
||||||
|
callID = strings.TrimSpace(ids[i])
|
||||||
|
}
|
||||||
|
if callID == "" {
|
||||||
|
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||||
|
if ids != nil {
|
||||||
|
ids[i] = callID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
args, _ := json.Marshal(c.Input)
|
||||||
|
out = append(out, map[string]any{
|
||||||
|
"index": i,
|
||||||
|
"id": callID,
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": c.Name,
|
||||||
|
"arguments": string(args),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
25
internal/adapter/openai/handler_toolcall_policy.go
Normal file
25
internal/adapter/openai/handler_toolcall_policy.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) {
|
||||||
|
for k, v := range collectOpenAIChatPassThrough(req) {
|
||||||
|
payload[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) toolcallFeatureMatchEnabled() bool {
|
||||||
|
if h == nil || h.Store == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
mode := strings.TrimSpace(strings.ToLower(h.Store.ToolcallMode()))
|
||||||
|
return mode == "" || mode == "feature_match"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) toolcallEarlyEmitHighConfidence() bool {
|
||||||
|
if h == nil || h.Store == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
level := strings.TrimSpace(strings.ToLower(h.Store.ToolcallEarlyEmitConfidence()))
|
||||||
|
return level == "" || level == "high"
|
||||||
|
}
|
||||||
@@ -100,6 +100,26 @@ func streamFinishReason(frames []map[string]any) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func streamToolCallArgumentChunks(frames []map[string]any) []string {
|
||||||
|
out := make([]string, 0, 4)
|
||||||
|
for _, frame := range frames {
|
||||||
|
choices, _ := frame["choices"].([]any)
|
||||||
|
for _, item := range choices {
|
||||||
|
choice, _ := item.(map[string]any)
|
||||||
|
delta, _ := choice["delta"].(map[string]any)
|
||||||
|
toolCalls, _ := delta["tool_calls"].([]any)
|
||||||
|
for _, tc := range toolCalls {
|
||||||
|
tcm, _ := tc.(map[string]any)
|
||||||
|
fn, _ := tcm["function"].(map[string]any)
|
||||||
|
if args, ok := fn["arguments"].(string); ok && args != "" {
|
||||||
|
out = append(out, args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) {
|
func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
resp := makeSSEHTTPResponse(
|
resp := makeSSEHTTPResponse(
|
||||||
@@ -108,7 +128,7 @@ func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) {
|
|||||||
)
|
)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, false, []string{"search"})
|
h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, []string{"search"})
|
||||||
if rec.Code != http.StatusOK {
|
if rec.Code != http.StatusOK {
|
||||||
t.Fatalf("unexpected status: %d", rec.Code)
|
t.Fatalf("unexpected status: %d", rec.Code)
|
||||||
}
|
}
|
||||||
@@ -141,7 +161,7 @@ func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) {
|
|||||||
)
|
)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, false, []string{"search"})
|
h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, []string{"search"})
|
||||||
if rec.Code != http.StatusOK {
|
if rec.Code != http.StatusOK {
|
||||||
t.Fatalf("unexpected status: %d", rec.Code)
|
t.Fatalf("unexpected status: %d", rec.Code)
|
||||||
}
|
}
|
||||||
@@ -161,7 +181,7 @@ func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) {
|
func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
resp := makeSSEHTTPResponse(
|
resp := makeSSEHTTPResponse(
|
||||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
|
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||||
@@ -169,7 +189,38 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) {
|
|||||||
)
|
)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, false, []string{"search"})
|
h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, []string{"search"})
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected status: %d", rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
out := decodeJSONBody(t, rec.Body.String())
|
||||||
|
choices, _ := out["choices"].([]any)
|
||||||
|
choice, _ := choices[0].(map[string]any)
|
||||||
|
if choice["finish_reason"] != "stop" {
|
||||||
|
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
|
||||||
|
}
|
||||||
|
msg, _ := choice["message"].(map[string]any)
|
||||||
|
if _, ok := msg["tool_calls"]; ok {
|
||||||
|
t.Fatalf("did not expect tool_calls for unknown schema name, got %#v", msg["tool_calls"])
|
||||||
|
}
|
||||||
|
content, _ := msg["content"].(string)
|
||||||
|
if !strings.Contains(content, `"tool_calls"`) {
|
||||||
|
t.Fatalf("expected unknown tool json to pass through as text, got %#v", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleNonStreamEmbeddedToolCallExampleIntercepted(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
resp := makeSSEHTTPResponse(
|
||||||
|
`data: {"p":"response/content","v":"下面是示例:"}`,
|
||||||
|
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||||
|
`data: {"p":"response/content","v":"请勿执行。"}`,
|
||||||
|
`data: [DONE]`,
|
||||||
|
)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.handleNonStream(rec, context.Background(), resp, "cid2c", "deepseek-chat", "prompt", false, []string{"search"})
|
||||||
if rec.Code != http.StatusOK {
|
if rec.Code != http.StatusOK {
|
||||||
t.Fatalf("unexpected status: %d", rec.Code)
|
t.Fatalf("unexpected status: %d", rec.Code)
|
||||||
}
|
}
|
||||||
@@ -181,12 +232,41 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) {
|
|||||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
|
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
|
||||||
}
|
}
|
||||||
msg, _ := choice["message"].(map[string]any)
|
msg, _ := choice["message"].(map[string]any)
|
||||||
if msg["content"] != nil {
|
|
||||||
t.Fatalf("expected content nil, got %#v", msg["content"])
|
|
||||||
}
|
|
||||||
toolCalls, _ := msg["tool_calls"].([]any)
|
toolCalls, _ := msg["tool_calls"].([]any)
|
||||||
if len(toolCalls) != 1 {
|
if len(toolCalls) == 0 {
|
||||||
t.Fatalf("expected 1 tool call, got %#v", msg["tool_calls"])
|
t.Fatalf("expected tool_calls field for embedded example: %#v", msg["tool_calls"])
|
||||||
|
}
|
||||||
|
if msg["content"] != nil {
|
||||||
|
t.Fatalf("expected content nil when tool_calls detected, got %#v", msg["content"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleNonStreamFencedToolCallExampleNotIntercepted(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
resp := makeSSEHTTPResponse(
|
||||||
|
"data: {\"p\":\"response/content\",\"v\":\"```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"search\\\",\\\"input\\\":{\\\"q\\\":\\\"go\\\"}}]}\\n```\"}",
|
||||||
|
`data: [DONE]`,
|
||||||
|
)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.handleNonStream(rec, context.Background(), resp, "cid2d", "deepseek-chat", "prompt", false, []string{"search"})
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected status: %d", rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
out := decodeJSONBody(t, rec.Body.String())
|
||||||
|
choices, _ := out["choices"].([]any)
|
||||||
|
choice, _ := choices[0].(map[string]any)
|
||||||
|
if choice["finish_reason"] != "stop" {
|
||||||
|
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
|
||||||
|
}
|
||||||
|
msg, _ := choice["message"].(map[string]any)
|
||||||
|
if _, ok := msg["tool_calls"]; ok {
|
||||||
|
t.Fatalf("did not expect tool_calls field for fenced example: %#v", msg["tool_calls"])
|
||||||
|
}
|
||||||
|
content, _ := msg["content"].(string)
|
||||||
|
if !strings.Contains(content, "```json") || !strings.Contains(content, `"tool_calls"`) {
|
||||||
|
t.Fatalf("expected fenced tool example to pass through as text, got %q", content)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -295,7 +375,7 @@ func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandleStreamUnknownToolStillIntercepted(t *testing.T) {
|
func TestHandleStreamUnknownToolNotIntercepted(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
resp := makeSSEHTTPResponse(
|
resp := makeSSEHTTPResponse(
|
||||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
|
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||||
@@ -310,29 +390,14 @@ func TestHandleStreamUnknownToolStillIntercepted(t *testing.T) {
|
|||||||
if !done {
|
if !done {
|
||||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||||
}
|
}
|
||||||
if !streamHasToolCallsDelta(frames) {
|
if streamHasToolCallsDelta(frames) {
|
||||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
t.Fatalf("did not expect tool_calls delta for unknown schema name, body=%s", rec.Body.String())
|
||||||
}
|
}
|
||||||
foundToolIndex := false
|
if !streamHasRawToolJSONContent(frames) {
|
||||||
for _, frame := range frames {
|
t.Fatalf("expected raw tool_calls json to remain in content for unknown schema name: %s", rec.Body.String())
|
||||||
choices, _ := frame["choices"].([]any)
|
|
||||||
for _, item := range choices {
|
|
||||||
choice, _ := item.(map[string]any)
|
|
||||||
delta, _ := choice["delta"].(map[string]any)
|
|
||||||
toolCalls, _ := delta["tool_calls"].([]any)
|
|
||||||
for _, tc := range toolCalls {
|
|
||||||
tcm, _ := tc.(map[string]any)
|
|
||||||
if _, ok := tcm["index"].(float64); ok {
|
|
||||||
foundToolIndex = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if !foundToolIndex {
|
if streamFinishReason(frames) != "stop" {
|
||||||
t.Fatalf("expected stream tool_calls item with index, body=%s", rec.Body.String())
|
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||||
}
|
|
||||||
if streamHasRawToolJSONContent(frames) {
|
|
||||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -377,9 +442,9 @@ func TestHandleStreamToolsPlainTextStreamsBeforeFinish(t *testing.T) {
|
|||||||
func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
resp := makeSSEHTTPResponse(
|
resp := makeSSEHTTPResponse(
|
||||||
`data: {"p":"response/content","v":"前置正文A。"}`,
|
`data: {"p":"response/content","v":"下面是示例:"}`,
|
||||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||||
`data: {"p":"response/content","v":"后置正文B。"}`,
|
`data: {"p":"response/content","v":"请勿执行。"}`,
|
||||||
`data: [DONE]`,
|
`data: [DONE]`,
|
||||||
)
|
)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
@@ -392,10 +457,7 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
|||||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||||
}
|
}
|
||||||
if !streamHasToolCallsDelta(frames) {
|
if !streamHasToolCallsDelta(frames) {
|
||||||
t.Fatalf("expected tool_calls delta in mixed stream, body=%s", rec.Body.String())
|
t.Fatalf("expected tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
|
||||||
}
|
|
||||||
if streamHasRawToolJSONContent(frames) {
|
|
||||||
t.Fatalf("raw tool_calls JSON leaked in mixed stream: %s", rec.Body.String())
|
|
||||||
}
|
}
|
||||||
content := strings.Builder{}
|
content := strings.Builder{}
|
||||||
for _, frame := range frames {
|
for _, frame := range frames {
|
||||||
@@ -409,9 +471,95 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
got := content.String()
|
got := content.String()
|
||||||
if !strings.Contains(got, "前置正文A。") || !strings.Contains(got, "后置正文B。") {
|
if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") {
|
||||||
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
|
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
|
||||||
}
|
}
|
||||||
|
if strings.Contains(strings.ToLower(got), `"tool_calls"`) {
|
||||||
|
t.Fatalf("expected no raw tool_calls json leak in content, got=%q", got)
|
||||||
|
}
|
||||||
|
if streamFinishReason(frames) != "tool_calls" {
|
||||||
|
t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
resp := makeSSEHTTPResponse(
|
||||||
|
`data: {"p":"response/content","v":"我将调用工具。"}`,
|
||||||
|
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||||
|
`data: [DONE]`,
|
||||||
|
)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
h.handleStream(rec, req, resp, "cid7b", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||||
|
|
||||||
|
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||||
|
if !done {
|
||||||
|
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
if !streamHasToolCallsDelta(frames) {
|
||||||
|
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
content := strings.Builder{}
|
||||||
|
for _, frame := range frames {
|
||||||
|
choices, _ := frame["choices"].([]any)
|
||||||
|
for _, item := range choices {
|
||||||
|
choice, _ := item.(map[string]any)
|
||||||
|
delta, _ := choice["delta"].(map[string]any)
|
||||||
|
if c, ok := delta["content"].(string); ok {
|
||||||
|
content.WriteString(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
got := content.String()
|
||||||
|
if !strings.Contains(got, "我将调用工具。") {
|
||||||
|
t.Fatalf("expected leading text to keep streaming, got=%q", got)
|
||||||
|
}
|
||||||
|
if strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||||
|
t.Fatalf("unexpected raw tool json leak, got=%q", got)
|
||||||
|
}
|
||||||
|
if streamFinishReason(frames) != "tool_calls" {
|
||||||
|
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
resp := makeSSEHTTPResponse(
|
||||||
|
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}接下来我会继续说明。"}`,
|
||||||
|
`data: [DONE]`,
|
||||||
|
)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
h.handleStream(rec, req, resp, "cid7c", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||||
|
|
||||||
|
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||||
|
if !done {
|
||||||
|
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
if !streamHasToolCallsDelta(frames) {
|
||||||
|
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
content := strings.Builder{}
|
||||||
|
for _, frame := range frames {
|
||||||
|
choices, _ := frame["choices"].([]any)
|
||||||
|
for _, item := range choices {
|
||||||
|
choice, _ := item.(map[string]any)
|
||||||
|
delta, _ := choice["delta"].(map[string]any)
|
||||||
|
if c, ok := delta["content"].(string); ok {
|
||||||
|
content.WriteString(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
got := content.String()
|
||||||
|
if !strings.Contains(got, "接下来我会继续说明。") {
|
||||||
|
t.Fatalf("expected trailing plain text to be preserved, got=%q", got)
|
||||||
|
}
|
||||||
|
if strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||||
|
t.Fatalf("unexpected raw tool json leak, got=%q", got)
|
||||||
|
}
|
||||||
if streamFinishReason(frames) != "tool_calls" {
|
if streamFinishReason(frames) != "tool_calls" {
|
||||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||||
}
|
}
|
||||||
@@ -495,16 +643,16 @@ func TestHandleStreamInvalidToolJSONDoesNotLeakRawObject(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
got := strings.ToLower(content.String())
|
got := content.String()
|
||||||
if strings.Contains(got, "tool_calls") {
|
if !strings.Contains(got, "前置正文D。") || !strings.Contains(got, "后置正文E。") {
|
||||||
t.Fatalf("unexpected raw tool_calls leak in content: %q", content.String())
|
|
||||||
}
|
|
||||||
if !strings.Contains(content.String(), "前置正文D。") || !strings.Contains(content.String(), "后置正文E。") {
|
|
||||||
t.Fatalf("expected pre/post plain text to remain, got=%q", content.String())
|
t.Fatalf("expected pre/post plain text to remain, got=%q", content.String())
|
||||||
}
|
}
|
||||||
|
if !strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||||
|
t.Fatalf("expected invalid embedded tool-like json to pass through as text, got=%q", got)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandleStreamIncompleteCapturedToolJSONDoesNotLeakOnFinalize(t *testing.T) {
|
func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
resp := makeSSEHTTPResponse(
|
resp := makeSSEHTTPResponse(
|
||||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
|
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
|
||||||
@@ -533,7 +681,112 @@ func TestHandleStreamIncompleteCapturedToolJSONDoesNotLeakOnFinalize(t *testing.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if strings.Contains(strings.ToLower(content.String()), "tool_calls") || strings.Contains(content.String(), "{") {
|
if !strings.Contains(strings.ToLower(content.String()), "tool_calls") || !strings.Contains(content.String(), "{") {
|
||||||
t.Fatalf("unexpected incomplete tool json leak in content: %q", content.String())
|
t.Fatalf("expected incomplete capture to flush as plain text instead of stalling, got=%q", content.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
resp := makeSSEHTTPResponse(
|
||||||
|
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`,
|
||||||
|
`data: {"p":"response/content","v":"lang\",\"page\":1}}]}"}`,
|
||||||
|
`data: [DONE]`,
|
||||||
|
)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
h.handleStream(rec, req, resp, "cid11", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||||
|
|
||||||
|
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||||
|
if !done {
|
||||||
|
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
if !streamHasToolCallsDelta(frames) {
|
||||||
|
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
if streamHasRawToolJSONContent(frames) {
|
||||||
|
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||||
|
}
|
||||||
|
argChunks := streamToolCallArgumentChunks(frames)
|
||||||
|
if len(argChunks) < 2 {
|
||||||
|
t.Fatalf("expected incremental arguments chunks, got=%v body=%s", argChunks, rec.Body.String())
|
||||||
|
}
|
||||||
|
joined := strings.Join(argChunks, "")
|
||||||
|
if !strings.Contains(joined, `"q":"golang"`) || !strings.Contains(joined, `"page":1`) {
|
||||||
|
t.Fatalf("unexpected merged arguments stream: %q", joined)
|
||||||
|
}
|
||||||
|
if streamFinishReason(frames) != "tool_calls" {
|
||||||
|
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleStreamMultiToolCallDoesNotMergeNamesOrArguments(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
resp := makeSSEHTTPResponse(
|
||||||
|
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search_web\",\"input\":{\"query\":\"latest ai news\"}},{"}`,
|
||||||
|
`data: {"p":"response/content","v":"\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
|
||||||
|
`data: [DONE]`,
|
||||||
|
)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
h.handleStream(rec, req, resp, "cid12", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"})
|
||||||
|
|
||||||
|
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||||
|
if !done {
|
||||||
|
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
if !streamHasToolCallsDelta(frames) {
|
||||||
|
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
foundSearch := false
|
||||||
|
foundEval := false
|
||||||
|
foundIndex1 := false
|
||||||
|
toolCallsDeltaLens := make([]int, 0, 2)
|
||||||
|
for _, frame := range frames {
|
||||||
|
choices, _ := frame["choices"].([]any)
|
||||||
|
for _, item := range choices {
|
||||||
|
choice, _ := item.(map[string]any)
|
||||||
|
delta, _ := choice["delta"].(map[string]any)
|
||||||
|
rawToolCalls, hasToolCalls := delta["tool_calls"]
|
||||||
|
if !hasToolCalls {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
toolCalls, _ := rawToolCalls.([]any)
|
||||||
|
toolCallsDeltaLens = append(toolCallsDeltaLens, len(toolCalls))
|
||||||
|
for _, tc := range toolCalls {
|
||||||
|
tcm, _ := tc.(map[string]any)
|
||||||
|
if idx, ok := tcm["index"].(float64); ok && int(idx) == 1 {
|
||||||
|
foundIndex1 = true
|
||||||
|
}
|
||||||
|
fn, _ := tcm["function"].(map[string]any)
|
||||||
|
name, _ := fn["name"].(string)
|
||||||
|
switch name {
|
||||||
|
case "search_web":
|
||||||
|
foundSearch = true
|
||||||
|
case "eval_javascript":
|
||||||
|
foundEval = true
|
||||||
|
case "search_webeval_javascript":
|
||||||
|
t.Fatalf("unexpected merged tool name: %s, body=%s", name, rec.Body.String())
|
||||||
|
}
|
||||||
|
if args, ok := fn["arguments"].(string); ok && strings.Contains(args, `}{"`) {
|
||||||
|
t.Fatalf("unexpected concatenated tool arguments: %q, body=%s", args, rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundSearch || !foundEval {
|
||||||
|
t.Fatalf("expected both tool names in stream deltas, foundSearch=%v foundEval=%v body=%s", foundSearch, foundEval, rec.Body.String())
|
||||||
|
}
|
||||||
|
if len(toolCallsDeltaLens) != 1 || toolCallsDeltaLens[0] != 2 {
|
||||||
|
t.Fatalf("expected exactly one tool_calls delta with two calls, got lens=%v body=%s", toolCallsDeltaLens, rec.Body.String())
|
||||||
|
}
|
||||||
|
if !foundIndex1 {
|
||||||
|
t.Fatalf("expected second tool call index in stream deltas, body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
if streamFinishReason(frames) != "tool_calls" {
|
||||||
|
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
270
internal/adapter/openai/message_normalize.go
Normal file
270
internal/adapter/openai/message_normalize.go
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"ds2api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any {
|
||||||
|
out := make([]map[string]any, 0, len(raw))
|
||||||
|
for _, item := range raw {
|
||||||
|
msg, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
role := strings.ToLower(strings.TrimSpace(asString(msg["role"])))
|
||||||
|
switch role {
|
||||||
|
case "assistant":
|
||||||
|
content := normalizeOpenAIContentForPrompt(msg["content"])
|
||||||
|
toolCalls := formatAssistantToolCallsForPrompt(msg, traceID)
|
||||||
|
combined := joinNonEmpty(content, toolCalls)
|
||||||
|
if combined == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, map[string]any{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": combined,
|
||||||
|
})
|
||||||
|
case "tool", "function":
|
||||||
|
out = append(out, map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": formatToolResultForPrompt(msg),
|
||||||
|
})
|
||||||
|
case "user", "system":
|
||||||
|
out = append(out, map[string]any{
|
||||||
|
"role": role,
|
||||||
|
"content": normalizeOpenAIContentForPrompt(msg["content"]),
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
content := normalizeOpenAIContentForPrompt(msg["content"])
|
||||||
|
if content == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if role == "" {
|
||||||
|
role = "user"
|
||||||
|
}
|
||||||
|
out = append(out, map[string]any{
|
||||||
|
"role": role,
|
||||||
|
"content": content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatAssistantToolCallsForPrompt(msg map[string]any, traceID string) string {
|
||||||
|
entries := make([]string, 0)
|
||||||
|
if calls, ok := msg["tool_calls"].([]any); ok {
|
||||||
|
for i, item := range calls {
|
||||||
|
call, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
id := strings.TrimSpace(asString(call["id"]))
|
||||||
|
if id == "" {
|
||||||
|
id = fmt.Sprintf("call_%d", i+1)
|
||||||
|
}
|
||||||
|
name := strings.TrimSpace(asString(call["name"]))
|
||||||
|
args := ""
|
||||||
|
|
||||||
|
if fn, ok := call["function"].(map[string]any); ok {
|
||||||
|
if name == "" {
|
||||||
|
name = strings.TrimSpace(asString(fn["name"]))
|
||||||
|
}
|
||||||
|
args = normalizeOpenAIArgumentsForPrompt(fn["arguments"])
|
||||||
|
}
|
||||||
|
if name == "" {
|
||||||
|
name = "unknown"
|
||||||
|
}
|
||||||
|
if args == "" {
|
||||||
|
args = normalizeOpenAIArgumentsForPrompt(call["arguments"])
|
||||||
|
}
|
||||||
|
if args == "" {
|
||||||
|
args = normalizeOpenAIArgumentsForPrompt(call["input"])
|
||||||
|
}
|
||||||
|
if args == "" {
|
||||||
|
args = "{}"
|
||||||
|
}
|
||||||
|
maybeWarnSuspiciousToolHistory(traceID, id, name, args)
|
||||||
|
entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: %s\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", id, name, args))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if legacy, ok := msg["function_call"].(map[string]any); ok {
|
||||||
|
name := strings.TrimSpace(asString(legacy["name"]))
|
||||||
|
if name == "" {
|
||||||
|
name = "unknown"
|
||||||
|
}
|
||||||
|
args := normalizeOpenAIArgumentsForPrompt(legacy["arguments"])
|
||||||
|
if args == "" {
|
||||||
|
args = "{}"
|
||||||
|
}
|
||||||
|
maybeWarnSuspiciousToolHistory(traceID, "call_legacy", name, args)
|
||||||
|
entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: call_legacy\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", name, args))
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(entries, "\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatToolResultForPrompt(msg map[string]any) string {
|
||||||
|
toolCallID := strings.TrimSpace(asString(msg["tool_call_id"]))
|
||||||
|
if toolCallID == "" {
|
||||||
|
toolCallID = strings.TrimSpace(asString(msg["id"]))
|
||||||
|
}
|
||||||
|
if toolCallID == "" {
|
||||||
|
toolCallID = "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
name := strings.TrimSpace(asString(msg["name"]))
|
||||||
|
if name == "" {
|
||||||
|
name = "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
content := normalizeOpenAIContentForPrompt(msg["content"])
|
||||||
|
if content == "" {
|
||||||
|
content = "null"
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeOpenAIContentForPrompt(v any) string {
|
||||||
|
switch x := v.(type) {
|
||||||
|
case string:
|
||||||
|
return x
|
||||||
|
case []any:
|
||||||
|
parts := make([]string, 0, len(x))
|
||||||
|
for _, item := range x {
|
||||||
|
m, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
t := strings.ToLower(strings.TrimSpace(asString(m["type"])))
|
||||||
|
if t != "text" && t != "output_text" && t != "input_text" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if text := asString(m["text"]); text != "" {
|
||||||
|
parts = append(parts, text)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if text := asString(m["content"]); text != "" {
|
||||||
|
parts = append(parts, text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(parts, "\n")
|
||||||
|
default:
|
||||||
|
return marshalToPromptString(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeOpenAIArgumentsForPrompt(v any) string {
|
||||||
|
switch x := v.(type) {
|
||||||
|
case string:
|
||||||
|
return normalizeToolArgumentString(x)
|
||||||
|
default:
|
||||||
|
return marshalToPromptString(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeToolArgumentString(raw string) string {
|
||||||
|
trimmed := strings.TrimSpace(raw)
|
||||||
|
if trimmed == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if !looksLikeConcatenatedJSON(trimmed) {
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
dec := json.NewDecoder(strings.NewReader(trimmed))
|
||||||
|
values := make([]any, 0, 2)
|
||||||
|
for {
|
||||||
|
var v any
|
||||||
|
if err := dec.Decode(&v); err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
values = append(values, v)
|
||||||
|
}
|
||||||
|
if len(values) < 2 {
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
last := values[len(values)-1]
|
||||||
|
b, err := json.Marshal(last)
|
||||||
|
if err != nil || len(b) == 0 {
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func marshalToPromptString(v any) string {
|
||||||
|
b, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func asString(v any) string {
|
||||||
|
if s, ok := v.(string); ok {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinNonEmpty(parts ...string) string {
|
||||||
|
nonEmpty := make([]string, 0, len(parts))
|
||||||
|
for _, p := range parts {
|
||||||
|
if strings.TrimSpace(p) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
nonEmpty = append(nonEmpty, p)
|
||||||
|
}
|
||||||
|
return strings.Join(nonEmpty, "\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func maybeWarnSuspiciousToolHistory(traceID, callID, name, args string) {
|
||||||
|
if !looksLikeConcatenatedJSON(args) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
traceID = strings.TrimSpace(traceID)
|
||||||
|
if traceID == "" {
|
||||||
|
traceID = "unknown"
|
||||||
|
}
|
||||||
|
config.Logger.Warn(
|
||||||
|
"[openai] suspicious tool call history payload detected",
|
||||||
|
"trace_id", traceID,
|
||||||
|
"tool_call_id", strings.TrimSpace(callID),
|
||||||
|
"name", strings.TrimSpace(name),
|
||||||
|
"arguments_preview", previewToolArgs(args, 160),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksLikeConcatenatedJSON(raw string) bool {
|
||||||
|
trimmed := strings.TrimSpace(raw)
|
||||||
|
if trimmed == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.Contains(trimmed, "}{") || strings.Contains(trimmed, "][") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
dec := json.NewDecoder(strings.NewReader(trimmed))
|
||||||
|
var first any
|
||||||
|
if err := dec.Decode(&first); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var second any
|
||||||
|
return dec.Decode(&second) == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func previewToolArgs(raw string, max int) string {
|
||||||
|
trimmed := strings.TrimSpace(raw)
|
||||||
|
if max <= 0 || len(trimmed) <= max {
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
return trimmed[:max]
|
||||||
|
}
|
||||||
198
internal/adapter/openai/message_normalize_test.go
Normal file
198
internal/adapter/openai/message_normalize_test.go
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *testing.T) {
|
||||||
|
raw := []any{
|
||||||
|
map[string]any{"role": "system", "content": "You are helpful"},
|
||||||
|
map[string]any{"role": "user", "content": "查北京天气"},
|
||||||
|
map[string]any{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": nil,
|
||||||
|
"tool_calls": []any{
|
||||||
|
map[string]any{
|
||||||
|
"id": "call_1",
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": "{\"city\":\"beijing\"}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_1",
|
||||||
|
"name": "get_weather",
|
||||||
|
"content": "{\"temp\":18}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||||
|
if len(normalized) != 4 {
|
||||||
|
t.Fatalf("expected 4 normalized messages, got %d", len(normalized))
|
||||||
|
}
|
||||||
|
assistantContent, _ := normalized[2]["content"].(string)
|
||||||
|
if !strings.Contains(assistantContent, "[TOOL_CALL_HISTORY]") ||
|
||||||
|
!strings.Contains(assistantContent, "tool_call_id: call_1") ||
|
||||||
|
!strings.Contains(assistantContent, "function.name: get_weather") ||
|
||||||
|
!strings.Contains(assistantContent, "function.arguments: {\"city\":\"beijing\"}") {
|
||||||
|
t.Fatalf("assistant tool call not serialized correctly: %q", assistantContent)
|
||||||
|
}
|
||||||
|
toolContent, _ := normalized[3]["content"].(string)
|
||||||
|
if !strings.Contains(toolContent, "[TOOL_RESULT_HISTORY]") || !strings.Contains(toolContent, "name: get_weather") {
|
||||||
|
t.Fatalf("tool result not serialized correctly: %q", toolContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt := util.MessagesPrepare(normalized)
|
||||||
|
if !strings.Contains(prompt, "tool_call_id: call_1") || !strings.Contains(prompt, "[TOOL_RESULT_HISTORY]") {
|
||||||
|
t.Fatalf("expected prompt to include tool call + result semantics: %q", prompt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOpenAIMessagesForPrompt_ToolObjectContentPreserved(t *testing.T) {
|
||||||
|
raw := []any{
|
||||||
|
map[string]any{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_2",
|
||||||
|
"name": "get_weather",
|
||||||
|
"content": map[string]any{
|
||||||
|
"temp": 18,
|
||||||
|
"condition": "sunny",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||||
|
got, _ := normalized[0]["content"].(string)
|
||||||
|
if !strings.Contains(got, `"temp":18`) || !strings.Contains(got, `"condition":"sunny"`) {
|
||||||
|
t.Fatalf("expected serialized object in tool content, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOpenAIMessagesForPrompt_ToolArrayBlocksJoined(t *testing.T) {
|
||||||
|
raw := []any{
|
||||||
|
map[string]any{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_3",
|
||||||
|
"name": "read_file",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{"type": "input_text", "text": "line-1"},
|
||||||
|
map[string]any{"type": "output_text", "text": "line-2"},
|
||||||
|
map[string]any{"type": "image_url", "image_url": "https://example.com/a.png"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||||
|
got, _ := normalized[0]["content"].(string)
|
||||||
|
if !strings.Contains(got, "line-1\nline-2") {
|
||||||
|
t.Fatalf("expected joined text blocks, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) {
|
||||||
|
raw := []any{
|
||||||
|
map[string]any{
|
||||||
|
"role": "function",
|
||||||
|
"tool_call_id": "call_4",
|
||||||
|
"name": "legacy_tool",
|
||||||
|
"content": map[string]any{
|
||||||
|
"ok": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||||
|
if len(normalized) != 1 {
|
||||||
|
t.Fatalf("expected one normalized message, got %d", len(normalized))
|
||||||
|
}
|
||||||
|
if normalized[0]["role"] != "user" {
|
||||||
|
t.Fatalf("expected function role mapped to user, got %#v", normalized[0]["role"])
|
||||||
|
}
|
||||||
|
got, _ := normalized[0]["content"].(string)
|
||||||
|
if !strings.Contains(got, "name: legacy_tool") || !strings.Contains(got, `"ok":true`) {
|
||||||
|
t.Fatalf("unexpected normalized function-role content: %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSeparated(t *testing.T) {
|
||||||
|
raw := []any{
|
||||||
|
map[string]any{
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": []any{
|
||||||
|
map[string]any{
|
||||||
|
"id": "call_search",
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "search_web",
|
||||||
|
"arguments": `{"query":"latest ai news"}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"id": "call_eval",
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "eval_javascript",
|
||||||
|
"arguments": `{"code":"1+1"}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||||
|
if len(normalized) != 1 {
|
||||||
|
t.Fatalf("expected one normalized assistant message, got %d", len(normalized))
|
||||||
|
}
|
||||||
|
content, _ := normalized[0]["content"].(string)
|
||||||
|
if strings.Count(content, "[TOOL_CALL_HISTORY]") != 2 {
|
||||||
|
t.Fatalf("expected two TOOL_CALL_HISTORY blocks, got %q", content)
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "tool_call_id: call_search") || !strings.Contains(content, "function.name: search_web") {
|
||||||
|
t.Fatalf("missing first tool call block, got %q", content)
|
||||||
|
}
|
||||||
|
if !strings.Contains(content, "tool_call_id: call_eval") || !strings.Contains(content, "function.name: eval_javascript") {
|
||||||
|
t.Fatalf("missing second tool call block, got %q", content)
|
||||||
|
}
|
||||||
|
if strings.Contains(content, "search_webeval_javascript") {
|
||||||
|
t.Fatalf("unexpected merged function name detected: %q", content)
|
||||||
|
}
|
||||||
|
if strings.Contains(content, `}{"`) {
|
||||||
|
t.Fatalf("unexpected concatenated function arguments detected: %q", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOpenAIMessagesForPrompt_RepairsConcatenatedToolArguments(t *testing.T) {
|
||||||
|
raw := []any{
|
||||||
|
map[string]any{
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": []any{
|
||||||
|
map[string]any{
|
||||||
|
"id": "call_1",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "search_web",
|
||||||
|
"arguments": `{}{"query":"测试工具调用"}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||||
|
if len(normalized) != 1 {
|
||||||
|
t.Fatalf("expected one normalized message, got %d", len(normalized))
|
||||||
|
}
|
||||||
|
content, _ := normalized[0]["content"].(string)
|
||||||
|
if !strings.Contains(content, `function.arguments: {"query":"测试工具调用"}`) {
|
||||||
|
t.Fatalf("expected repaired arguments in tool history, got %q", content)
|
||||||
|
}
|
||||||
|
if strings.Contains(content, `{}{"query":"测试工具调用"}`) {
|
||||||
|
t.Fatalf("expected concatenated JSON to be repaired, got %q", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
46
internal/adapter/openai/models_route_test.go
Normal file
46
internal/adapter/openai/models_route_test.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetModelRouteDirectAndAlias(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
r := chi.NewRouter()
|
||||||
|
RegisterRoutes(r, h)
|
||||||
|
|
||||||
|
t.Run("direct", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/models/deepseek-chat", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("alias", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/models/gpt-4.1", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200 for alias, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetModelRouteNotFound(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
r := chi.NewRouter()
|
||||||
|
RegisterRoutes(r, h)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/models/not-exists", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected 404, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
26
internal/adapter/openai/prompt_build.go
Normal file
26
internal/adapter/openai/prompt_build.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"ds2api/internal/deepseek"
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any, traceID string) (string, []string) {
|
||||||
|
return buildOpenAIFinalPromptWithPolicy(messagesRaw, toolsRaw, traceID, util.DefaultToolChoicePolicy())
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildOpenAIFinalPromptWithPolicy(messagesRaw []any, toolsRaw any, traceID string, toolPolicy util.ToolChoicePolicy) (string, []string) {
|
||||||
|
messages := normalizeOpenAIMessagesForPrompt(messagesRaw, traceID)
|
||||||
|
toolNames := []string{}
|
||||||
|
if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 {
|
||||||
|
messages, toolNames = injectToolPrompt(messages, tools, toolPolicy)
|
||||||
|
}
|
||||||
|
return deepseek.MessagesPrepare(messages), toolNames
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildPromptForAdapter exposes the OpenAI-compatible prompt building flow so
|
||||||
|
// other protocol adapters (for example Gemini) can reuse the same tool/history
|
||||||
|
// normalization logic and remain behavior-compatible with chat/completions.
|
||||||
|
func BuildPromptForAdapter(messagesRaw []any, toolsRaw any, traceID string) (string, []string) {
|
||||||
|
return buildOpenAIFinalPrompt(messagesRaw, toolsRaw, traceID)
|
||||||
|
}
|
||||||
83
internal/adapter/openai/prompt_build_test.go
Normal file
83
internal/adapter/openai/prompt_build_test.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *testing.T) {
|
||||||
|
messages := []any{
|
||||||
|
map[string]any{"role": "user", "content": "查北京天气"},
|
||||||
|
map[string]any{
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": []any{
|
||||||
|
map[string]any{
|
||||||
|
"id": "call_1",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": "{\"city\":\"beijing\"}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_1",
|
||||||
|
"name": "get_weather",
|
||||||
|
"content": map[string]any{"temp": 18, "condition": "sunny"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tools := []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get weather",
|
||||||
|
"parameters": map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools, "")
|
||||||
|
if len(toolNames) != 1 || toolNames[0] != "get_weather" {
|
||||||
|
t.Fatalf("unexpected tool names: %#v", toolNames)
|
||||||
|
}
|
||||||
|
if !strings.Contains(finalPrompt, "tool_call_id: call_1") ||
|
||||||
|
!strings.Contains(finalPrompt, "function.name: get_weather") ||
|
||||||
|
!strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") ||
|
||||||
|
!strings.Contains(finalPrompt, `"condition":"sunny"`) {
|
||||||
|
t.Fatalf("handler finalPrompt missing tool roundtrip semantics: %q", finalPrompt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *testing.T) {
|
||||||
|
messages := []any{
|
||||||
|
map[string]any{"role": "system", "content": "You are helpful"},
|
||||||
|
map[string]any{"role": "user", "content": "请调用工具"},
|
||||||
|
}
|
||||||
|
tools := []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "search",
|
||||||
|
"description": "search docs",
|
||||||
|
"parameters": map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "")
|
||||||
|
if !strings.Contains(finalPrompt, "After receiving a tool result, you MUST use it to produce the final answer.") {
|
||||||
|
t.Fatalf("vercel prepare finalPrompt missing final-answer instruction: %q", finalPrompt)
|
||||||
|
}
|
||||||
|
if !strings.Contains(finalPrompt, "Only call another tool when the previous result is missing required data or returned an error.") {
|
||||||
|
t.Fatalf("vercel prepare finalPrompt missing retry guard instruction: %q", finalPrompt)
|
||||||
|
}
|
||||||
|
if !strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") {
|
||||||
|
t.Fatalf("vercel prepare finalPrompt missing history marker instruction: %q", finalPrompt)
|
||||||
|
}
|
||||||
|
}
|
||||||
109
internal/adapter/openai/response_store.go
Normal file
109
internal/adapter/openai/response_store.go
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"ds2api/internal/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type storedResponse struct {
|
||||||
|
Owner string
|
||||||
|
Value map[string]any
|
||||||
|
ExpiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type responseStore struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
ttl time.Duration
|
||||||
|
items map[string]storedResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func newResponseStore(ttl time.Duration) *responseStore {
|
||||||
|
if ttl <= 0 {
|
||||||
|
ttl = 15 * time.Minute
|
||||||
|
}
|
||||||
|
return &responseStore{
|
||||||
|
ttl: ttl,
|
||||||
|
items: make(map[string]storedResponse),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseStoreKey(owner, id string) string {
|
||||||
|
return owner + "\x00" + id
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseStoreOwner(a *auth.RequestAuth) string {
|
||||||
|
if a == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return a.CallerID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responseStore) put(owner, id string, value map[string]any) {
|
||||||
|
if s == nil || owner == "" || id == "" || value == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.sweepLocked(now)
|
||||||
|
s.items[responseStoreKey(owner, id)] = storedResponse{
|
||||||
|
Owner: owner,
|
||||||
|
Value: cloneAnyMap(value),
|
||||||
|
ExpiresAt: now.Add(s.ttl),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responseStore) get(owner, id string) (map[string]any, bool) {
|
||||||
|
if s == nil || owner == "" || id == "" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.sweepLocked(now)
|
||||||
|
item, ok := s.items[responseStoreKey(owner, id)]
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if item.Owner != owner {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return cloneAnyMap(item.Value), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responseStore) sweepLocked(now time.Time) {
|
||||||
|
for k, v := range s.items {
|
||||||
|
if now.After(v.ExpiresAt) {
|
||||||
|
delete(s.items, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneAnyMap(in map[string]any) map[string]any {
|
||||||
|
if in == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(map[string]any, len(in))
|
||||||
|
for k, v := range in {
|
||||||
|
out[k] = v
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) getResponseStore() *responseStore {
|
||||||
|
if h == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
h.responsesMu.Lock()
|
||||||
|
defer h.responsesMu.Unlock()
|
||||||
|
if h.responses == nil {
|
||||||
|
ttl := 15 * time.Minute
|
||||||
|
if h.Store != nil {
|
||||||
|
ttl = time.Duration(h.Store.ResponsesStoreTTLSeconds()) * time.Second
|
||||||
|
}
|
||||||
|
h.responses = newResponseStore(ttl)
|
||||||
|
}
|
||||||
|
return h.responses
|
||||||
|
}
|
||||||
197
internal/adapter/openai/responses_embeddings_test.go
Normal file
197
internal/adapter/openai/responses_embeddings_test.go
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeResponsesInputAsMessagesString(t *testing.T) {
|
||||||
|
msgs := normalizeResponsesInputAsMessages("hello")
|
||||||
|
if len(msgs) != 1 {
|
||||||
|
t.Fatalf("expected one message, got %d", len(msgs))
|
||||||
|
}
|
||||||
|
m, _ := msgs[0].(map[string]any)
|
||||||
|
if m["role"] != "user" || m["content"] != "hello" {
|
||||||
|
t.Fatalf("unexpected message: %#v", m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesMessagesFromRequestWithInstructions(t *testing.T) {
|
||||||
|
req := map[string]any{
|
||||||
|
"model": "gpt-4.1",
|
||||||
|
"input": "ping",
|
||||||
|
"instructions": "system text",
|
||||||
|
}
|
||||||
|
msgs := responsesMessagesFromRequest(req)
|
||||||
|
if len(msgs) != 2 {
|
||||||
|
t.Fatalf("expected two messages, got %d", len(msgs))
|
||||||
|
}
|
||||||
|
sys, _ := msgs[0].(map[string]any)
|
||||||
|
if sys["role"] != "system" {
|
||||||
|
t.Fatalf("unexpected first message: %#v", sys)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeResponsesInputAsMessagesObjectRoleContentBlocks(t *testing.T) {
|
||||||
|
msgs := normalizeResponsesInputAsMessages(map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{"type": "input_text", "text": "line-1"},
|
||||||
|
map[string]any{"type": "input_text", "text": "line-2"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(msgs) != 1 {
|
||||||
|
t.Fatalf("expected one message, got %d", len(msgs))
|
||||||
|
}
|
||||||
|
m, _ := msgs[0].(map[string]any)
|
||||||
|
if m["role"] != "user" {
|
||||||
|
t.Fatalf("unexpected role: %#v", m)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(normalizeOpenAIContentForPrompt(m["content"])) != "line-1\nline-2" {
|
||||||
|
t.Fatalf("unexpected content: %#v", m["content"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeResponsesInputAsMessagesFunctionCallOutput(t *testing.T) {
|
||||||
|
msgs := normalizeResponsesInputAsMessages([]any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "function_call_output",
|
||||||
|
"call_id": "call_123",
|
||||||
|
"output": map[string]any{"ok": true},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(msgs) != 1 {
|
||||||
|
t.Fatalf("expected one message, got %d", len(msgs))
|
||||||
|
}
|
||||||
|
m, _ := msgs[0].(map[string]any)
|
||||||
|
if m["role"] != "tool" {
|
||||||
|
t.Fatalf("expected tool role, got %#v", m)
|
||||||
|
}
|
||||||
|
if m["tool_call_id"] != "call_123" {
|
||||||
|
t.Fatalf("expected tool_call_id propagated, got %#v", m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeResponsesInputAsMessagesBackfillsToolResultNameFromCallID(t *testing.T) {
|
||||||
|
msgs := normalizeResponsesInputAsMessages([]any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "function_call",
|
||||||
|
"call_id": "call_999",
|
||||||
|
"name": "search",
|
||||||
|
"arguments": `{"q":"golang"}`,
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"type": "function_call_output",
|
||||||
|
"call_id": "call_999",
|
||||||
|
"output": map[string]any{"ok": true},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(msgs) != 2 {
|
||||||
|
t.Fatalf("expected two messages, got %d", len(msgs))
|
||||||
|
}
|
||||||
|
toolMsg, _ := msgs[1].(map[string]any)
|
||||||
|
if toolMsg["role"] != "tool" {
|
||||||
|
t.Fatalf("expected tool role, got %#v", toolMsg)
|
||||||
|
}
|
||||||
|
if toolMsg["name"] != "search" {
|
||||||
|
t.Fatalf("expected tool name backfilled from call_id, got %#v", toolMsg["name"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) {
|
||||||
|
msgs := normalizeResponsesInputAsMessages([]any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "function_call",
|
||||||
|
"call_id": "call_456",
|
||||||
|
"name": "search",
|
||||||
|
"arguments": `{"q":"golang"}`,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(msgs) != 1 {
|
||||||
|
t.Fatalf("expected one message, got %d", len(msgs))
|
||||||
|
}
|
||||||
|
m, _ := msgs[0].(map[string]any)
|
||||||
|
if m["role"] != "assistant" {
|
||||||
|
t.Fatalf("expected assistant role, got %#v", m["role"])
|
||||||
|
}
|
||||||
|
toolCalls, _ := m["tool_calls"].([]any)
|
||||||
|
if len(toolCalls) != 1 {
|
||||||
|
t.Fatalf("expected one tool_call, got %#v", m["tool_calls"])
|
||||||
|
}
|
||||||
|
call, _ := toolCalls[0].(map[string]any)
|
||||||
|
if call["id"] != "call_456" {
|
||||||
|
t.Fatalf("expected call id preserved, got %#v", call)
|
||||||
|
}
|
||||||
|
if call["type"] != "function" {
|
||||||
|
t.Fatalf("expected function type, got %#v", call)
|
||||||
|
}
|
||||||
|
fn, _ := call["function"].(map[string]any)
|
||||||
|
if fn["name"] != "search" {
|
||||||
|
t.Fatalf("expected call name preserved, got %#v", call)
|
||||||
|
}
|
||||||
|
if fn["arguments"] != `{"q":"golang"}` {
|
||||||
|
t.Fatalf("expected call arguments preserved, got %#v", call)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeResponsesInputAsMessagesFunctionCallItemRepairsConcatenatedArguments(t *testing.T) {
|
||||||
|
msgs := normalizeResponsesInputAsMessages([]any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "function_call",
|
||||||
|
"call_id": "call_456",
|
||||||
|
"name": "search",
|
||||||
|
"arguments": `{}{"q":"golang"}`,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if len(msgs) != 1 {
|
||||||
|
t.Fatalf("expected one message, got %d", len(msgs))
|
||||||
|
}
|
||||||
|
m, _ := msgs[0].(map[string]any)
|
||||||
|
toolCalls, _ := m["tool_calls"].([]any)
|
||||||
|
call, _ := toolCalls[0].(map[string]any)
|
||||||
|
fn, _ := call["function"].(map[string]any)
|
||||||
|
if fn["arguments"] != `{"q":"golang"}` {
|
||||||
|
t.Fatalf("expected concatenated call arguments repaired, got %#v", fn["arguments"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractEmbeddingInputs(t *testing.T) {
|
||||||
|
got := extractEmbeddingInputs([]any{"a", "b"})
|
||||||
|
if len(got) != 2 || got[0] != "a" || got[1] != "b" {
|
||||||
|
t.Fatalf("unexpected inputs: %#v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeterministicEmbeddingStable(t *testing.T) {
|
||||||
|
a := deterministicEmbedding("hello")
|
||||||
|
b := deterministicEmbedding("hello")
|
||||||
|
if len(a) != 64 || len(b) != 64 {
|
||||||
|
t.Fatalf("expected 64 dims, got %d and %d", len(a), len(b))
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if a[i] != b[i] {
|
||||||
|
t.Fatalf("expected stable embedding at %d: %v != %v", i, a[i], b[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseStorePutGet(t *testing.T) {
|
||||||
|
st := newResponseStore(100 * time.Millisecond)
|
||||||
|
st.put("owner_1", "resp_1", map[string]any{"id": "resp_1"})
|
||||||
|
got, ok := st.get("owner_1", "resp_1")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected stored response")
|
||||||
|
}
|
||||||
|
if got["id"] != "resp_1" {
|
||||||
|
t.Fatalf("unexpected response payload: %#v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseStoreTenantIsolation(t *testing.T) {
|
||||||
|
st := newResponseStore(100 * time.Millisecond)
|
||||||
|
st.put("owner_a", "resp_1", map[string]any{"id": "resp_1"})
|
||||||
|
if _, ok := st.get("owner_b", "resp_1"); ok {
|
||||||
|
t.Fatal("expected owner_b to be isolated from owner_a response")
|
||||||
|
}
|
||||||
|
}
|
||||||
221
internal/adapter/openai/responses_handler.go
Normal file
221
internal/adapter/openai/responses_handler.go
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
"ds2api/internal/auth"
|
||||||
|
"ds2api/internal/config"
|
||||||
|
"ds2api/internal/deepseek"
|
||||||
|
openaifmt "ds2api/internal/format/openai"
|
||||||
|
"ds2api/internal/sse"
|
||||||
|
streamengine "ds2api/internal/stream"
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) {
|
||||||
|
a, err := h.Auth.DetermineCaller(r)
|
||||||
|
if err != nil {
|
||||||
|
writeOpenAIError(w, http.StatusUnauthorized, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := strings.TrimSpace(chi.URLParam(r, "response_id"))
|
||||||
|
if id == "" {
|
||||||
|
writeOpenAIError(w, http.StatusBadRequest, "response_id is required.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
owner := responseStoreOwner(a)
|
||||||
|
if owner == "" {
|
||||||
|
writeOpenAIError(w, http.StatusUnauthorized, "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
st := h.getResponseStore()
|
||||||
|
item, ok := st.get(owner, id)
|
||||||
|
if !ok {
|
||||||
|
writeOpenAIError(w, http.StatusNotFound, "Response not found.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, item)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
|
||||||
|
a, err := h.Auth.Determine(r)
|
||||||
|
if err != nil {
|
||||||
|
status := http.StatusUnauthorized
|
||||||
|
detail := err.Error()
|
||||||
|
if err == auth.ErrNoAccount {
|
||||||
|
status = http.StatusTooManyRequests
|
||||||
|
}
|
||||||
|
writeOpenAIError(w, status, detail)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer h.Auth.Release(a)
|
||||||
|
r = r.WithContext(auth.WithAuth(r.Context(), a))
|
||||||
|
owner := responseStoreOwner(a)
|
||||||
|
if owner == "" {
|
||||||
|
writeOpenAIError(w, http.StatusUnauthorized, "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req map[string]any
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
traceID := requestTraceID(r)
|
||||||
|
stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req, traceID)
|
||||||
|
if err != nil {
|
||||||
|
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||||
|
if err != nil {
|
||||||
|
if a.UseConfigToken {
|
||||||
|
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
|
||||||
|
} else {
|
||||||
|
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
||||||
|
if err != nil {
|
||||||
|
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
payload := stdReq.CompletionPayload(sessionID)
|
||||||
|
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
|
||||||
|
if err != nil {
|
||||||
|
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||||
|
if stdReq.Stream {
|
||||||
|
h.handleResponsesStream(w, r, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames, stdReq.ToolChoice, traceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string, toolChoice util.ToolChoicePolicy, traceID string) {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
writeOpenAIError(w, resp.StatusCode, strings.TrimSpace(string(body)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||||
|
textParsed := util.ParseToolCallsDetailed(result.Text, toolNames)
|
||||||
|
thinkingParsed := util.ParseToolCallsDetailed(result.Thinking, toolNames)
|
||||||
|
logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text")
|
||||||
|
logResponsesToolPolicyRejection(traceID, toolChoice, thinkingParsed, "thinking")
|
||||||
|
|
||||||
|
callCount := len(textParsed.Calls)
|
||||||
|
if callCount == 0 {
|
||||||
|
callCount = len(thinkingParsed.Calls)
|
||||||
|
}
|
||||||
|
if toolChoice.IsRequired() && callCount == 0 {
|
||||||
|
writeOpenAIErrorWithCode(w, http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames)
|
||||||
|
h.getResponseStore().put(owner, responseID, responseObj)
|
||||||
|
writeJSON(w, http.StatusOK, responseObj)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice util.ToolChoicePolicy, traceID string) {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
writeOpenAIError(w, resp.StatusCode, strings.TrimSpace(string(body)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||||
|
w.Header().Set("Connection", "keep-alive")
|
||||||
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
rc := http.NewResponseController(w)
|
||||||
|
_, canFlush := w.(http.Flusher)
|
||||||
|
|
||||||
|
initialType := "text"
|
||||||
|
if thinkingEnabled {
|
||||||
|
initialType = "thinking"
|
||||||
|
}
|
||||||
|
bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
|
||||||
|
emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
|
||||||
|
|
||||||
|
streamRuntime := newResponsesStreamRuntime(
|
||||||
|
w,
|
||||||
|
rc,
|
||||||
|
canFlush,
|
||||||
|
responseID,
|
||||||
|
model,
|
||||||
|
finalPrompt,
|
||||||
|
thinkingEnabled,
|
||||||
|
searchEnabled,
|
||||||
|
toolNames,
|
||||||
|
bufferToolContent,
|
||||||
|
emitEarlyToolDeltas,
|
||||||
|
toolChoice,
|
||||||
|
traceID,
|
||||||
|
func(obj map[string]any) {
|
||||||
|
h.getResponseStore().put(owner, responseID, obj)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
streamRuntime.sendCreated()
|
||||||
|
|
||||||
|
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
|
||||||
|
Context: r.Context(),
|
||||||
|
Body: resp.Body,
|
||||||
|
ThinkingEnabled: thinkingEnabled,
|
||||||
|
InitialType: initialType,
|
||||||
|
KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
|
||||||
|
IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
|
||||||
|
MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
|
||||||
|
}, streamengine.ConsumeHooks{
|
||||||
|
OnParsed: streamRuntime.onParsed,
|
||||||
|
OnFinalize: func(_ streamengine.StopReason, _ error) {
|
||||||
|
streamRuntime.finalize()
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func logResponsesToolPolicyRejection(traceID string, policy util.ToolChoicePolicy, parsed util.ToolCallParseResult, channel string) {
|
||||||
|
rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames)
|
||||||
|
if !parsed.RejectedByPolicy || len(rejected) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
config.Logger.Warn(
|
||||||
|
"[responses] rejected tool calls by policy",
|
||||||
|
"trace_id", strings.TrimSpace(traceID),
|
||||||
|
"channel", channel,
|
||||||
|
"tool_choice_mode", policy.Mode,
|
||||||
|
"rejected_tool_names", strings.Join(rejected, ","),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func filteredRejectedToolNamesForLog(names []string) []string {
|
||||||
|
if len(names) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]string, 0, len(names))
|
||||||
|
for _, name := range names {
|
||||||
|
trimmed := strings.TrimSpace(name)
|
||||||
|
switch strings.ToLower(trimmed) {
|
||||||
|
case "", "tool_name":
|
||||||
|
continue
|
||||||
|
default:
|
||||||
|
out = append(out, trimmed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
203
internal/adapter/openai/responses_input_items.go
Normal file
203
internal/adapter/openai/responses_input_items.go
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"ds2api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func normalizeResponsesInputItem(m map[string]any) map[string]any {
|
||||||
|
return normalizeResponsesInputItemWithState(m, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[string]string) map[string]any {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
|
||||||
|
if role != "" {
|
||||||
|
content := m["content"]
|
||||||
|
if content == nil {
|
||||||
|
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||||
|
content = txt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if content == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return map[string]any{
|
||||||
|
"role": role,
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
itemType := strings.ToLower(strings.TrimSpace(asString(m["type"])))
|
||||||
|
switch itemType {
|
||||||
|
case "message", "input_message":
|
||||||
|
content := m["content"]
|
||||||
|
if content == nil {
|
||||||
|
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||||
|
content = txt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if content == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
|
||||||
|
if role == "" {
|
||||||
|
role = "user"
|
||||||
|
}
|
||||||
|
return map[string]any{
|
||||||
|
"role": role,
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
case "function_call_output", "tool_result":
|
||||||
|
content := m["output"]
|
||||||
|
if content == nil {
|
||||||
|
content = m["content"]
|
||||||
|
}
|
||||||
|
if content == nil {
|
||||||
|
content = ""
|
||||||
|
}
|
||||||
|
out := map[string]any{
|
||||||
|
"role": "tool",
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
|
||||||
|
out["tool_call_id"] = callID
|
||||||
|
} else if callID = strings.TrimSpace(asString(m["tool_call_id"])); callID != "" {
|
||||||
|
out["tool_call_id"] = callID
|
||||||
|
}
|
||||||
|
if name := strings.TrimSpace(asString(m["name"])); name != "" {
|
||||||
|
out["name"] = name
|
||||||
|
} else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" {
|
||||||
|
out["name"] = name
|
||||||
|
} else if callID := strings.TrimSpace(asString(out["tool_call_id"])); callID != "" {
|
||||||
|
if inferred := strings.TrimSpace(callNameByID[callID]); inferred != "" {
|
||||||
|
out["name"] = inferred
|
||||||
|
} else {
|
||||||
|
config.Logger.Warn(
|
||||||
|
"[responses] unable to backfill tool result name from call_id",
|
||||||
|
"call_id", callID,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
case "function_call", "tool_call":
|
||||||
|
name := strings.TrimSpace(asString(m["name"]))
|
||||||
|
var fn map[string]any
|
||||||
|
if rawFn, ok := m["function"].(map[string]any); ok {
|
||||||
|
fn = rawFn
|
||||||
|
if name == "" {
|
||||||
|
name = strings.TrimSpace(asString(fn["name"]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if name == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var argsRaw any
|
||||||
|
if v, ok := m["arguments"]; ok {
|
||||||
|
argsRaw = v
|
||||||
|
} else if v, ok := m["input"]; ok {
|
||||||
|
argsRaw = v
|
||||||
|
}
|
||||||
|
if argsRaw == nil && fn != nil {
|
||||||
|
if v, ok := fn["arguments"]; ok {
|
||||||
|
argsRaw = v
|
||||||
|
} else if v, ok := fn["input"]; ok {
|
||||||
|
argsRaw = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
functionPayload := map[string]any{
|
||||||
|
"name": name,
|
||||||
|
"arguments": stringifyToolCallArguments(argsRaw),
|
||||||
|
}
|
||||||
|
call := map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": functionPayload,
|
||||||
|
}
|
||||||
|
if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
|
||||||
|
call["id"] = callID
|
||||||
|
} else if callID = strings.TrimSpace(asString(m["id"])); callID != "" {
|
||||||
|
call["id"] = callID
|
||||||
|
}
|
||||||
|
if callID := strings.TrimSpace(asString(call["id"])); callID != "" && callNameByID != nil {
|
||||||
|
callNameByID[callID] = name
|
||||||
|
}
|
||||||
|
return map[string]any{
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": []any{call},
|
||||||
|
}
|
||||||
|
case "input_text":
|
||||||
|
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||||
|
return map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": txt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||||
|
return map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": txt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if content, ok := m["content"]; ok {
|
||||||
|
if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
|
||||||
|
return map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeResponsesFallbackPart(m map[string]any) string {
|
||||||
|
if m == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") {
|
||||||
|
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||||
|
return txt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||||
|
return txt
|
||||||
|
}
|
||||||
|
if content, ok := m["content"]; ok {
|
||||||
|
if normalized := strings.TrimSpace(normalizeOpenAIContentForPrompt(content)); normalized != "" {
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(fmt.Sprintf("%v", m))
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringifyToolCallArguments(v any) string {
|
||||||
|
switch x := v.(type) {
|
||||||
|
case nil:
|
||||||
|
return "{}"
|
||||||
|
case string:
|
||||||
|
s := strings.TrimSpace(x)
|
||||||
|
if s == "" {
|
||||||
|
return "{}"
|
||||||
|
}
|
||||||
|
s = normalizeToolArgumentString(s)
|
||||||
|
if s == "" {
|
||||||
|
return "{}"
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
default:
|
||||||
|
b, err := json.Marshal(x)
|
||||||
|
if err != nil || len(b) == 0 {
|
||||||
|
return "{}"
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
94
internal/adapter/openai/responses_input_normalize.go
Normal file
94
internal/adapter/openai/responses_input_normalize.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func responsesMessagesFromRequest(req map[string]any) []any {
|
||||||
|
if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 {
|
||||||
|
return prependInstructionMessage(msgs, req["instructions"])
|
||||||
|
}
|
||||||
|
if rawInput, ok := req["input"]; ok {
|
||||||
|
if msgs := normalizeResponsesInputAsMessages(rawInput); len(msgs) > 0 {
|
||||||
|
return prependInstructionMessage(msgs, req["instructions"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func prependInstructionMessage(messages []any, instructions any) []any {
|
||||||
|
sys, _ := instructions.(string)
|
||||||
|
sys = strings.TrimSpace(sys)
|
||||||
|
if sys == "" {
|
||||||
|
return messages
|
||||||
|
}
|
||||||
|
out := make([]any, 0, len(messages)+1)
|
||||||
|
out = append(out, map[string]any{"role": "system", "content": sys})
|
||||||
|
out = append(out, messages...)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeResponsesInputAsMessages(input any) []any {
|
||||||
|
switch v := input.(type) {
|
||||||
|
case string:
|
||||||
|
if strings.TrimSpace(v) == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []any{map[string]any{"role": "user", "content": v}}
|
||||||
|
case []any:
|
||||||
|
return normalizeResponsesInputArray(v)
|
||||||
|
case map[string]any:
|
||||||
|
if msg := normalizeResponsesInputItem(v); msg != nil {
|
||||||
|
return []any{msg}
|
||||||
|
}
|
||||||
|
if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" {
|
||||||
|
return []any{map[string]any{"role": "user", "content": txt}}
|
||||||
|
}
|
||||||
|
if content, ok := v["content"]; ok {
|
||||||
|
if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
|
||||||
|
return []any{map[string]any{"role": "user", "content": content}}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeResponsesInputArray(items []any) []any {
|
||||||
|
if len(items) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]any, 0, len(items))
|
||||||
|
callNameByID := map[string]string{}
|
||||||
|
fallbackParts := make([]string, 0, len(items))
|
||||||
|
flushFallback := func() {
|
||||||
|
if len(fallbackParts) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out = append(out, map[string]any{"role": "user", "content": strings.Join(fallbackParts, "\n")})
|
||||||
|
fallbackParts = fallbackParts[:0]
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range items {
|
||||||
|
switch x := item.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
if msg := normalizeResponsesInputItemWithState(x, callNameByID); msg != nil {
|
||||||
|
flushFallback()
|
||||||
|
out = append(out, msg)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if s := normalizeResponsesFallbackPart(x); s != "" {
|
||||||
|
fallbackParts = append(fallbackParts, s)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" {
|
||||||
|
fallbackParts = append(fallbackParts, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
flushFallback()
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
176
internal/adapter/openai/responses_route_test.go
Normal file
176
internal/adapter/openai/responses_route_test.go
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
|
||||||
|
"ds2api/internal/account"
|
||||||
|
"ds2api/internal/auth"
|
||||||
|
"ds2api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newDirectTokenResolver(t *testing.T) (*config.Store, *auth.Resolver) {
|
||||||
|
t.Helper()
|
||||||
|
t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`)
|
||||||
|
store := config.LoadStore()
|
||||||
|
pool := account.NewPool(store)
|
||||||
|
resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
|
||||||
|
return "unused", nil
|
||||||
|
})
|
||||||
|
return store, resolver
|
||||||
|
}
|
||||||
|
|
||||||
|
func newManagedKeyResolver(t *testing.T) (*config.Store, *auth.Resolver) {
|
||||||
|
t.Helper()
|
||||||
|
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||||
|
"keys":["managed-key"],
|
||||||
|
"accounts":[{"email":"acc@example.com","password":"pwd","token":"account-token"}]
|
||||||
|
}`)
|
||||||
|
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1")
|
||||||
|
t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "0")
|
||||||
|
store := config.LoadStore()
|
||||||
|
pool := account.NewPool(store)
|
||||||
|
resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
|
||||||
|
return "unused", nil
|
||||||
|
})
|
||||||
|
return store, resolver
|
||||||
|
}
|
||||||
|
|
||||||
|
func authForToken(t *testing.T, resolver *auth.Resolver, token string) *auth.RequestAuth {
|
||||||
|
t.Helper()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
a, err := resolver.Determine(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("determine auth failed: %v", err)
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetResponseByIDRequiresAuthAndIsTenantIsolated(t *testing.T) {
|
||||||
|
store, resolver := newDirectTokenResolver(t)
|
||||||
|
h := &Handler{Store: store, Auth: resolver}
|
||||||
|
r := chi.NewRouter()
|
||||||
|
RegisterRoutes(r, h)
|
||||||
|
|
||||||
|
ownerA := responseStoreOwner(authForToken(t, resolver, "token-a"))
|
||||||
|
h.getResponseStore().put(ownerA, "resp_test", map[string]any{
|
||||||
|
"id": "resp_test",
|
||||||
|
"object": "response",
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unauthorized", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cross-tenant-not-found", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer token-b")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("expected 404, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("same-tenant-ok", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer token-a")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
var body map[string]any
|
||||||
|
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||||
|
t.Fatalf("decode body failed: %v", err)
|
||||||
|
}
|
||||||
|
if body["id"] != "resp_test" {
|
||||||
|
t.Fatalf("unexpected body: %#v", body)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesRouteValidationContract(t *testing.T) {
|
||||||
|
store, resolver := newDirectTokenResolver(t)
|
||||||
|
h := &Handler{Store: store, Auth: resolver}
|
||||||
|
r := chi.NewRouter()
|
||||||
|
RegisterRoutes(r, h)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
}{
|
||||||
|
{name: "missing_model", body: `{"input":"hello"}`},
|
||||||
|
{name: "missing_input_and_messages", body: `{"model":"gpt-4o"}`},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewBufferString(tc.body))
|
||||||
|
req.Header.Set("Authorization", "Bearer token-a")
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
var out map[string]any
|
||||||
|
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||||
|
t.Fatalf("decode response failed: %v", err)
|
||||||
|
}
|
||||||
|
errObj, _ := out["error"].(map[string]any)
|
||||||
|
if _, ok := errObj["code"]; !ok {
|
||||||
|
t.Fatalf("expected error.code: %#v", out)
|
||||||
|
}
|
||||||
|
if _, ok := errObj["param"]; !ok {
|
||||||
|
t.Fatalf("expected error.param: %#v", out)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetResponseByIDManagedKeySkipsAccountPoolPressure(t *testing.T) {
|
||||||
|
store, resolver := newManagedKeyResolver(t)
|
||||||
|
h := &Handler{Store: store, Auth: resolver}
|
||||||
|
r := chi.NewRouter()
|
||||||
|
RegisterRoutes(r, h)
|
||||||
|
|
||||||
|
ownerReq := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
|
||||||
|
ownerReq.Header.Set("Authorization", "Bearer managed-key")
|
||||||
|
ownerAuth, err := resolver.DetermineCaller(ownerReq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("determine caller failed: %v", err)
|
||||||
|
}
|
||||||
|
owner := responseStoreOwner(ownerAuth)
|
||||||
|
h.getResponseStore().put(owner, "resp_test", map[string]any{
|
||||||
|
"id": "resp_test",
|
||||||
|
"object": "response",
|
||||||
|
})
|
||||||
|
|
||||||
|
occupyReq := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
occupyReq.Header.Set("Authorization", "Bearer managed-key")
|
||||||
|
occupied, err := resolver.Determine(occupyReq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected first acquire to succeed: %v", err)
|
||||||
|
}
|
||||||
|
defer resolver.Release(occupied)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer managed-key")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200 under pool pressure, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
225
internal/adapter/openai/responses_stream_runtime_core.go
Normal file
225
internal/adapter/openai/responses_stream_runtime_core.go
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"ds2api/internal/config"
|
||||||
|
openaifmt "ds2api/internal/format/openai"
|
||||||
|
"ds2api/internal/sse"
|
||||||
|
streamengine "ds2api/internal/stream"
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
type responsesStreamRuntime struct {
|
||||||
|
w http.ResponseWriter
|
||||||
|
rc *http.ResponseController
|
||||||
|
canFlush bool
|
||||||
|
|
||||||
|
responseID string
|
||||||
|
model string
|
||||||
|
finalPrompt string
|
||||||
|
toolNames []string
|
||||||
|
traceID string
|
||||||
|
toolChoice util.ToolChoicePolicy
|
||||||
|
|
||||||
|
thinkingEnabled bool
|
||||||
|
searchEnabled bool
|
||||||
|
|
||||||
|
bufferToolContent bool
|
||||||
|
emitEarlyToolDeltas bool
|
||||||
|
toolCallsEmitted bool
|
||||||
|
toolCallsDoneEmitted bool
|
||||||
|
|
||||||
|
sieve toolStreamSieveState
|
||||||
|
thinkingSieve toolStreamSieveState
|
||||||
|
thinking strings.Builder
|
||||||
|
text strings.Builder
|
||||||
|
visibleText strings.Builder
|
||||||
|
streamToolCallIDs map[int]string
|
||||||
|
functionItemIDs map[int]string
|
||||||
|
functionOutputIDs map[int]int
|
||||||
|
functionArgs map[int]string
|
||||||
|
functionDone map[int]bool
|
||||||
|
functionAdded map[int]bool
|
||||||
|
functionNames map[int]string
|
||||||
|
messageItemID string
|
||||||
|
messageOutputID int
|
||||||
|
nextOutputID int
|
||||||
|
messageAdded bool
|
||||||
|
messagePartAdded bool
|
||||||
|
sequence int
|
||||||
|
failed bool
|
||||||
|
|
||||||
|
persistResponse func(obj map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newResponsesStreamRuntime(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
rc *http.ResponseController,
|
||||||
|
canFlush bool,
|
||||||
|
responseID string,
|
||||||
|
model string,
|
||||||
|
finalPrompt string,
|
||||||
|
thinkingEnabled bool,
|
||||||
|
searchEnabled bool,
|
||||||
|
toolNames []string,
|
||||||
|
bufferToolContent bool,
|
||||||
|
emitEarlyToolDeltas bool,
|
||||||
|
toolChoice util.ToolChoicePolicy,
|
||||||
|
traceID string,
|
||||||
|
persistResponse func(obj map[string]any),
|
||||||
|
) *responsesStreamRuntime {
|
||||||
|
return &responsesStreamRuntime{
|
||||||
|
w: w,
|
||||||
|
rc: rc,
|
||||||
|
canFlush: canFlush,
|
||||||
|
responseID: responseID,
|
||||||
|
model: model,
|
||||||
|
finalPrompt: finalPrompt,
|
||||||
|
thinkingEnabled: thinkingEnabled,
|
||||||
|
searchEnabled: searchEnabled,
|
||||||
|
toolNames: toolNames,
|
||||||
|
bufferToolContent: bufferToolContent,
|
||||||
|
emitEarlyToolDeltas: emitEarlyToolDeltas,
|
||||||
|
streamToolCallIDs: map[int]string{},
|
||||||
|
functionItemIDs: map[int]string{},
|
||||||
|
functionOutputIDs: map[int]int{},
|
||||||
|
functionArgs: map[int]string{},
|
||||||
|
functionDone: map[int]bool{},
|
||||||
|
functionAdded: map[int]bool{},
|
||||||
|
functionNames: map[int]string{},
|
||||||
|
messageOutputID: -1,
|
||||||
|
toolChoice: toolChoice,
|
||||||
|
traceID: traceID,
|
||||||
|
persistResponse: persistResponse,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) finalize() {
|
||||||
|
finalThinking := s.thinking.String()
|
||||||
|
finalText := s.text.String()
|
||||||
|
|
||||||
|
if s.bufferToolContent {
|
||||||
|
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
|
||||||
|
s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false)
|
||||||
|
}
|
||||||
|
|
||||||
|
textParsed := util.ParseToolCallsDetailed(finalText, s.toolNames)
|
||||||
|
thinkingParsed := util.ParseToolCallsDetailed(finalThinking, s.toolNames)
|
||||||
|
detected := textParsed.Calls
|
||||||
|
if len(detected) == 0 {
|
||||||
|
detected = thinkingParsed.Calls
|
||||||
|
}
|
||||||
|
s.logToolPolicyRejections(textParsed, thinkingParsed)
|
||||||
|
|
||||||
|
if len(detected) > 0 {
|
||||||
|
s.toolCallsEmitted = true
|
||||||
|
if !s.toolCallsDoneEmitted {
|
||||||
|
s.emitFunctionCallDoneEvents(detected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.closeMessageItem()
|
||||||
|
|
||||||
|
if s.toolChoice.IsRequired() && len(detected) == 0 {
|
||||||
|
s.failed = true
|
||||||
|
message := "tool_choice requires at least one valid tool call."
|
||||||
|
failedResp := map[string]any{
|
||||||
|
"id": s.responseID,
|
||||||
|
"type": "response",
|
||||||
|
"object": "response",
|
||||||
|
"model": s.model,
|
||||||
|
"status": "failed",
|
||||||
|
"output": []any{},
|
||||||
|
"output_text": "",
|
||||||
|
"error": map[string]any{
|
||||||
|
"message": message,
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"code": "tool_choice_violation",
|
||||||
|
"param": nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if s.persistResponse != nil {
|
||||||
|
s.persistResponse(failedResp)
|
||||||
|
}
|
||||||
|
s.sendEvent("response.failed", openaifmt.BuildResponsesFailedPayload(s.responseID, s.model, message, "tool_choice_violation"))
|
||||||
|
s.sendDone()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.closeIncompleteFunctionItems()
|
||||||
|
|
||||||
|
obj := s.buildCompletedResponseObject(finalThinking, finalText, detected)
|
||||||
|
if s.persistResponse != nil {
|
||||||
|
s.persistResponse(obj)
|
||||||
|
}
|
||||||
|
s.sendEvent("response.completed", openaifmt.BuildResponsesCompletedPayload(obj))
|
||||||
|
s.sendDone()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingParsed util.ToolCallParseResult) {
|
||||||
|
logRejected := func(parsed util.ToolCallParseResult, channel string) {
|
||||||
|
rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames)
|
||||||
|
if !parsed.RejectedByPolicy || len(rejected) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
config.Logger.Warn(
|
||||||
|
"[responses] rejected tool calls by policy",
|
||||||
|
"trace_id", strings.TrimSpace(s.traceID),
|
||||||
|
"channel", channel,
|
||||||
|
"tool_choice_mode", s.toolChoice.Mode,
|
||||||
|
"rejected_tool_names", strings.Join(rejected, ","),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
logRejected(textParsed, "text")
|
||||||
|
logRejected(thinkingParsed, "thinking")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) hasFunctionCallDone() bool {
|
||||||
|
for _, done := range s.functionDone {
|
||||||
|
if done {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
|
||||||
|
if !parsed.Parsed {
|
||||||
|
return streamengine.ParsedDecision{}
|
||||||
|
}
|
||||||
|
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
|
||||||
|
return streamengine.ParsedDecision{Stop: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
contentSeen := false
|
||||||
|
for _, p := range parsed.Parts {
|
||||||
|
if p.Text == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
contentSeen = true
|
||||||
|
if p.Type == "thinking" {
|
||||||
|
if !s.thinkingEnabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.thinking.WriteString(p.Text)
|
||||||
|
s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
|
||||||
|
if s.bufferToolContent {
|
||||||
|
s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
s.text.WriteString(p.Text)
|
||||||
|
if !s.bufferToolContent {
|
||||||
|
s.emitTextDelta(p.Text)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true)
|
||||||
|
}
|
||||||
|
|
||||||
|
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||||
|
}
|
||||||
61
internal/adapter/openai/responses_stream_runtime_events.go
Normal file
61
internal/adapter/openai/responses_stream_runtime_events.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
openaifmt "ds2api/internal/format/openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) nextSequence() int {
|
||||||
|
s.sequence++
|
||||||
|
return s.sequence
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) sendEvent(event string, payload map[string]any) {
|
||||||
|
if payload == nil {
|
||||||
|
payload = map[string]any{}
|
||||||
|
}
|
||||||
|
if _, ok := payload["sequence_number"]; !ok {
|
||||||
|
payload["sequence_number"] = s.nextSequence()
|
||||||
|
}
|
||||||
|
b, _ := json.Marshal(payload)
|
||||||
|
_, _ = s.w.Write([]byte("event: " + event + "\n"))
|
||||||
|
_, _ = s.w.Write([]byte("data: "))
|
||||||
|
_, _ = s.w.Write(b)
|
||||||
|
_, _ = s.w.Write([]byte("\n\n"))
|
||||||
|
if s.canFlush {
|
||||||
|
_ = s.rc.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) sendCreated() {
|
||||||
|
s.sendEvent("response.created", openaifmt.BuildResponsesCreatedPayload(s.responseID, s.model))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) sendDone() {
|
||||||
|
_, _ = s.w.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
if s.canFlush {
|
||||||
|
_ = s.rc.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
|
||||||
|
for _, evt := range events {
|
||||||
|
if emitContent && evt.Content != "" {
|
||||||
|
s.emitTextDelta(evt.Content)
|
||||||
|
}
|
||||||
|
if len(evt.ToolCallDeltas) > 0 {
|
||||||
|
if !s.emitEarlyToolDeltas {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.toolNames, s.functionNames)
|
||||||
|
if len(filtered) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.emitFunctionCallDeltaEvents(filtered)
|
||||||
|
}
|
||||||
|
if len(evt.ToolCalls) > 0 {
|
||||||
|
s.emitFunctionCallDoneEvents(evt.ToolCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
235
internal/adapter/openai/responses_stream_runtime_toolcalls.go
Normal file
235
internal/adapter/openai/responses_stream_runtime_toolcalls.go
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
openaifmt "ds2api/internal/format/openai"
|
||||||
|
"ds2api/internal/util"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) allocateOutputIndex() int {
|
||||||
|
idx := s.nextOutputID
|
||||||
|
s.nextOutputID++
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) ensureMessageItemID() string {
|
||||||
|
if strings.TrimSpace(s.messageItemID) != "" {
|
||||||
|
return s.messageItemID
|
||||||
|
}
|
||||||
|
s.messageItemID = "msg_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||||
|
return s.messageItemID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) ensureMessageOutputIndex() int {
|
||||||
|
if s.messageOutputID >= 0 {
|
||||||
|
return s.messageOutputID
|
||||||
|
}
|
||||||
|
s.messageOutputID = s.allocateOutputIndex()
|
||||||
|
return s.messageOutputID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) ensureMessageItemAdded() {
|
||||||
|
if s.messageAdded {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
itemID := s.ensureMessageItemID()
|
||||||
|
item := map[string]any{
|
||||||
|
"id": itemID,
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"status": "in_progress",
|
||||||
|
}
|
||||||
|
s.sendEvent(
|
||||||
|
"response.output_item.added",
|
||||||
|
openaifmt.BuildResponsesOutputItemAddedPayload(s.responseID, itemID, s.ensureMessageOutputIndex(), item),
|
||||||
|
)
|
||||||
|
s.messageAdded = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) ensureMessageContentPartAdded() {
|
||||||
|
if s.messagePartAdded {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.ensureMessageItemAdded()
|
||||||
|
s.sendEvent(
|
||||||
|
"response.content_part.added",
|
||||||
|
openaifmt.BuildResponsesContentPartAddedPayload(
|
||||||
|
s.responseID,
|
||||||
|
s.ensureMessageItemID(),
|
||||||
|
s.ensureMessageOutputIndex(),
|
||||||
|
0,
|
||||||
|
map[string]any{"type": "output_text", "text": ""},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
s.messagePartAdded = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) emitTextDelta(content string) {
|
||||||
|
if strings.TrimSpace(content) == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.ensureMessageContentPartAdded()
|
||||||
|
s.visibleText.WriteString(content)
|
||||||
|
s.sendEvent(
|
||||||
|
"response.output_text.delta",
|
||||||
|
openaifmt.BuildResponsesTextDeltaPayload(
|
||||||
|
s.responseID,
|
||||||
|
s.ensureMessageItemID(),
|
||||||
|
s.ensureMessageOutputIndex(),
|
||||||
|
0,
|
||||||
|
content,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) closeMessageItem() {
|
||||||
|
if !s.messageAdded {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
itemID := s.ensureMessageItemID()
|
||||||
|
outputIndex := s.ensureMessageOutputIndex()
|
||||||
|
text := s.visibleText.String()
|
||||||
|
if s.messagePartAdded {
|
||||||
|
s.sendEvent(
|
||||||
|
"response.content_part.done",
|
||||||
|
openaifmt.BuildResponsesContentPartDonePayload(
|
||||||
|
s.responseID,
|
||||||
|
itemID,
|
||||||
|
outputIndex,
|
||||||
|
0,
|
||||||
|
map[string]any{"type": "output_text", "text": text},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
s.messagePartAdded = false
|
||||||
|
}
|
||||||
|
item := map[string]any{
|
||||||
|
"id": itemID,
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"status": "completed",
|
||||||
|
"content": []map[string]any{
|
||||||
|
{
|
||||||
|
"type": "output_text",
|
||||||
|
"text": text,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s.sendEvent(
|
||||||
|
"response.output_item.done",
|
||||||
|
openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, outputIndex, item),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) ensureFunctionItemID(callIndex int) string {
|
||||||
|
if id, ok := s.functionItemIDs[callIndex]; ok && strings.TrimSpace(id) != "" {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
id := "fc_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||||
|
s.functionItemIDs[callIndex] = id
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) ensureToolCallID(callIndex int) string {
|
||||||
|
if id, ok := s.streamToolCallIDs[callIndex]; ok && strings.TrimSpace(id) != "" {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
id := "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||||
|
s.streamToolCallIDs[callIndex] = id
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) ensureFunctionOutputIndex(callIndex int) int {
|
||||||
|
if idx, ok := s.functionOutputIDs[callIndex]; ok {
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
idx := s.allocateOutputIndex()
|
||||||
|
s.functionOutputIDs[callIndex] = idx
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) ensureFunctionItemAdded(callIndex int, name string) {
|
||||||
|
if strings.TrimSpace(name) != "" {
|
||||||
|
s.functionNames[callIndex] = strings.TrimSpace(name)
|
||||||
|
}
|
||||||
|
if s.functionAdded[callIndex] {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fnName := strings.TrimSpace(s.functionNames[callIndex])
|
||||||
|
if fnName == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
outputIndex := s.ensureFunctionOutputIndex(callIndex)
|
||||||
|
itemID := s.ensureFunctionItemID(callIndex)
|
||||||
|
callID := s.ensureToolCallID(callIndex)
|
||||||
|
item := map[string]any{
|
||||||
|
"id": itemID,
|
||||||
|
"type": "function_call",
|
||||||
|
"call_id": callID,
|
||||||
|
"name": fnName,
|
||||||
|
"arguments": "",
|
||||||
|
"status": "in_progress",
|
||||||
|
}
|
||||||
|
s.sendEvent(
|
||||||
|
"response.output_item.added",
|
||||||
|
openaifmt.BuildResponsesOutputItemAddedPayload(s.responseID, itemID, outputIndex, item),
|
||||||
|
)
|
||||||
|
s.functionAdded[callIndex] = true
|
||||||
|
s.toolCallsEmitted = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDelta) {
|
||||||
|
for _, d := range deltas {
|
||||||
|
s.ensureFunctionItemAdded(d.Index, d.Name)
|
||||||
|
if strings.TrimSpace(d.Arguments) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.functionArgs[d.Index] += d.Arguments
|
||||||
|
outputIndex := s.ensureFunctionOutputIndex(d.Index)
|
||||||
|
itemID := s.ensureFunctionItemID(d.Index)
|
||||||
|
callID := s.ensureToolCallID(d.Index)
|
||||||
|
s.sendEvent(
|
||||||
|
"response.function_call_arguments.delta",
|
||||||
|
openaifmt.BuildResponsesFunctionCallArgumentsDeltaPayload(s.responseID, itemID, outputIndex, callID, d.Arguments),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedToolCall) {
|
||||||
|
for idx, tc := range calls {
|
||||||
|
if strings.TrimSpace(tc.Name) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.ensureFunctionItemAdded(idx, tc.Name)
|
||||||
|
if s.functionDone[idx] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
outputIndex := s.ensureFunctionOutputIndex(idx)
|
||||||
|
itemID := s.ensureFunctionItemID(idx)
|
||||||
|
callID := s.ensureToolCallID(idx)
|
||||||
|
argsBytes, _ := json.Marshal(tc.Input)
|
||||||
|
args := string(argsBytes)
|
||||||
|
s.functionArgs[idx] = args
|
||||||
|
s.sendEvent(
|
||||||
|
"response.function_call_arguments.done",
|
||||||
|
openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, args),
|
||||||
|
)
|
||||||
|
item := map[string]any{
|
||||||
|
"id": itemID,
|
||||||
|
"type": "function_call",
|
||||||
|
"call_id": callID,
|
||||||
|
"name": tc.Name,
|
||||||
|
"arguments": args,
|
||||||
|
"status": "completed",
|
||||||
|
}
|
||||||
|
s.sendEvent(
|
||||||
|
"response.output_item.done",
|
||||||
|
openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, outputIndex, item),
|
||||||
|
)
|
||||||
|
s.functionDone[idx] = true
|
||||||
|
s.toolCallsDoneEmitted = true
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,156 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
openaifmt "ds2api/internal/format/openai"
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) closeIncompleteFunctionItems() {
|
||||||
|
if len(s.functionAdded) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
indices := make([]int, 0, len(s.functionAdded))
|
||||||
|
for idx, added := range s.functionAdded {
|
||||||
|
if !added || s.functionDone[idx] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
indices = append(indices, idx)
|
||||||
|
}
|
||||||
|
if len(indices) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sort.Ints(indices)
|
||||||
|
for _, idx := range indices {
|
||||||
|
name := strings.TrimSpace(s.functionNames[idx])
|
||||||
|
if name == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
args := strings.TrimSpace(s.functionArgs[idx])
|
||||||
|
if args == "" {
|
||||||
|
args = "{}"
|
||||||
|
}
|
||||||
|
outputIndex := s.ensureFunctionOutputIndex(idx)
|
||||||
|
itemID := s.ensureFunctionItemID(idx)
|
||||||
|
callID := s.ensureToolCallID(idx)
|
||||||
|
s.sendEvent(
|
||||||
|
"response.function_call_arguments.done",
|
||||||
|
openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, name, args),
|
||||||
|
)
|
||||||
|
item := map[string]any{
|
||||||
|
"id": itemID,
|
||||||
|
"type": "function_call",
|
||||||
|
"call_id": callID,
|
||||||
|
"name": name,
|
||||||
|
"arguments": args,
|
||||||
|
"status": "completed",
|
||||||
|
}
|
||||||
|
s.sendEvent(
|
||||||
|
"response.output_item.done",
|
||||||
|
openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, outputIndex, item),
|
||||||
|
)
|
||||||
|
s.functionDone[idx] = true
|
||||||
|
s.toolCallsDoneEmitted = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *responsesStreamRuntime) buildCompletedResponseObject(finalThinking, finalText string, calls []util.ParsedToolCall) map[string]any {
|
||||||
|
type indexedItem struct {
|
||||||
|
index int
|
||||||
|
item map[string]any
|
||||||
|
}
|
||||||
|
indexed := make([]indexedItem, 0, len(calls)+1)
|
||||||
|
|
||||||
|
if s.messageAdded {
|
||||||
|
text := s.visibleText.String()
|
||||||
|
indexed = append(indexed, indexedItem{
|
||||||
|
index: s.ensureMessageOutputIndex(),
|
||||||
|
item: map[string]any{
|
||||||
|
"id": s.ensureMessageItemID(),
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"status": "completed",
|
||||||
|
"content": []map[string]any{
|
||||||
|
{
|
||||||
|
"type": "output_text",
|
||||||
|
"text": text,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
} else if len(calls) == 0 {
|
||||||
|
content := make([]map[string]any, 0, 2)
|
||||||
|
if strings.TrimSpace(finalThinking) != "" {
|
||||||
|
content = append(content, map[string]any{
|
||||||
|
"type": "reasoning",
|
||||||
|
"text": finalThinking,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(finalText) != "" {
|
||||||
|
content = append(content, map[string]any{
|
||||||
|
"type": "output_text",
|
||||||
|
"text": finalText,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if len(content) > 0 {
|
||||||
|
indexed = append(indexed, indexedItem{
|
||||||
|
index: s.ensureMessageOutputIndex(),
|
||||||
|
item: map[string]any{
|
||||||
|
"id": s.ensureMessageItemID(),
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"status": "completed",
|
||||||
|
"content": content,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, tc := range calls {
|
||||||
|
if strings.TrimSpace(tc.Name) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
argsBytes, _ := json.Marshal(tc.Input)
|
||||||
|
indexed = append(indexed, indexedItem{
|
||||||
|
index: s.ensureFunctionOutputIndex(idx),
|
||||||
|
item: map[string]any{
|
||||||
|
"id": s.ensureFunctionItemID(idx),
|
||||||
|
"type": "function_call",
|
||||||
|
"call_id": s.ensureToolCallID(idx),
|
||||||
|
"name": tc.Name,
|
||||||
|
"arguments": string(argsBytes),
|
||||||
|
"status": "completed",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.SliceStable(indexed, func(i, j int) bool {
|
||||||
|
return indexed[i].index < indexed[j].index
|
||||||
|
})
|
||||||
|
output := make([]any, 0, len(indexed))
|
||||||
|
for _, it := range indexed {
|
||||||
|
output = append(output, it.item)
|
||||||
|
}
|
||||||
|
|
||||||
|
outputText := s.visibleText.String()
|
||||||
|
if strings.TrimSpace(outputText) == "" && len(calls) == 0 {
|
||||||
|
if strings.TrimSpace(finalText) != "" {
|
||||||
|
outputText = finalText
|
||||||
|
} else if strings.TrimSpace(finalThinking) != "" {
|
||||||
|
outputText = finalThinking
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return openaifmt.BuildResponseObjectFromItems(
|
||||||
|
s.responseID,
|
||||||
|
s.model,
|
||||||
|
s.finalPrompt,
|
||||||
|
finalThinking,
|
||||||
|
finalText,
|
||||||
|
output,
|
||||||
|
outputText,
|
||||||
|
)
|
||||||
|
}
|
||||||
611
internal/adapter/openai/responses_stream_test.go
Normal file
611
internal/adapter/openai/responses_stream_test.go
Normal file
@@ -0,0 +1,611 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
sseLine := func(v string) string {
|
||||||
|
b, _ := json.Marshal(map[string]any{
|
||||||
|
"p": "response/content",
|
||||||
|
"v": v,
|
||||||
|
})
|
||||||
|
return "data: " + string(b) + "\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
rawToolJSON := `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`
|
||||||
|
streamBody := sseLine(rawToolJSON) + "data: [DONE]\n"
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||||
|
|
||||||
|
completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected response.completed event, body=%s", rec.Body.String())
|
||||||
|
}
|
||||||
|
responseObj, _ := completed["response"].(map[string]any)
|
||||||
|
outputText, _ := responseObj["output_text"].(string)
|
||||||
|
if outputText != "" {
|
||||||
|
t.Fatalf("expected empty output_text for tool_calls response, got output_text=%q", outputText)
|
||||||
|
}
|
||||||
|
output, _ := responseObj["output"].([]any)
|
||||||
|
if len(output) == 0 {
|
||||||
|
t.Fatalf("expected structured output entries, got %#v", responseObj["output"])
|
||||||
|
}
|
||||||
|
hasFunctionCall := false
|
||||||
|
hasLegacyWrapper := false
|
||||||
|
for _, item := range output {
|
||||||
|
m, _ := item.(map[string]any)
|
||||||
|
if m == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if m["type"] == "function_call" {
|
||||||
|
hasFunctionCall = true
|
||||||
|
}
|
||||||
|
if m["type"] == "tool_calls" {
|
||||||
|
hasLegacyWrapper = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasFunctionCall {
|
||||||
|
t.Fatalf("expected function_call item, got %#v", responseObj["output"])
|
||||||
|
}
|
||||||
|
if hasLegacyWrapper {
|
||||||
|
t.Fatalf("did not expect legacy tool_calls wrapper, got %#v", responseObj["output"])
|
||||||
|
}
|
||||||
|
if strings.Contains(outputText, `"tool_calls"`) {
|
||||||
|
t.Fatalf("raw tool_calls JSON leaked in output_text: %q", outputText)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleResponsesStreamUsesOfficialOutputItemEvents verifies that a
// streamed tool call is surfaced through the official OpenAI Responses event
// vocabulary (response.output_item.* and response.function_call_arguments.*)
// rather than the legacy response.output_tool_call.* events, and that the
// call_id emitted while streaming matches the one in the final
// response.completed payload.
func TestHandleResponsesStreamUsesOfficialOutputItemEvents(t *testing.T) {
	h := &Handler{}
	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
	rec := httptest.NewRecorder()

	// sseLine wraps a content chunk in the upstream SSE envelope
	// ({"p":"response/content","v":...}).
	sseLine := func(v string) string {
		b, _ := json.Marshal(map[string]any{
			"p": "response/content",
			"v": v,
		})
		return "data: " + string(b) + "\n"
	}

	// Upstream emits one tool_calls JSON payload, then the DONE sentinel.
	streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n"
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(streamBody)),
	}

	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
	body := rec.Body.String()
	if !strings.Contains(body, "event: response.output_item.added") {
		t.Fatalf("expected response.output_item.added event, body=%s", body)
	}
	if !strings.Contains(body, "event: response.output_item.done") {
		t.Fatalf("expected response.output_item.done event, body=%s", body)
	}
	if !strings.Contains(body, "event: response.function_call_arguments.delta") {
		t.Fatalf("expected response.function_call_arguments.delta event, body=%s", body)
	}
	if !strings.Contains(body, "event: response.function_call_arguments.done") {
		t.Fatalf("expected response.function_call_arguments.done event, body=%s", body)
	}
	if strings.Contains(body, "event: response.output_tool_call.delta") || strings.Contains(body, "event: response.output_tool_call.done") {
		t.Fatalf("legacy response.output_tool_call.* event must not appear, body=%s", body)
	}

	// The in-progress function_call item must be announced with empty arguments.
	addedPayloads := extractAllSSEEventPayloads(body, "response.output_item.added")
	hasFunctionCallAdded := false
	for _, payload := range addedPayloads {
		item, _ := payload["item"].(map[string]any)
		if item == nil || asString(item["type"]) != "function_call" {
			continue
		}
		hasFunctionCallAdded = true
		if asString(item["arguments"]) != "" {
			t.Fatalf("expected in-progress function_call.arguments to start empty string, got %#v", item["arguments"])
		}
	}
	if !hasFunctionCallAdded {
		t.Fatalf("expected function_call output_item.added payload, body=%s", body)
	}

	// The call_id in the streaming done event must reappear on the
	// function_call item of the final completed output.
	donePayload, ok := extractSSEEventPayload(body, "response.function_call_arguments.done")
	if !ok {
		t.Fatalf("expected to parse response.function_call_arguments.done payload, body=%s", body)
	}
	doneCallID := strings.TrimSpace(asString(donePayload["call_id"]))
	if doneCallID == "" {
		t.Fatalf("expected non-empty call_id in done payload, payload=%#v", donePayload)
	}
	completed, ok := extractSSEEventPayload(body, "response.completed")
	if !ok {
		t.Fatalf("expected response.completed payload, body=%s", body)
	}
	responseObj, _ := completed["response"].(map[string]any)
	output, _ := responseObj["output"].([]any)
	var completedCallID string
	for _, item := range output {
		m, _ := item.(map[string]any)
		if m == nil || m["type"] != "function_call" {
			continue
		}
		completedCallID = strings.TrimSpace(asString(m["call_id"]))
		if completedCallID != "" {
			break
		}
	}
	if completedCallID == "" {
		t.Fatalf("expected function_call.call_id in completed output, output=%#v", output)
	}
	if completedCallID != doneCallID {
		t.Fatalf("expected completed call_id to match stream done call_id, done=%q completed=%q", doneCallID, completedCallID)
	}
}
|
||||||
|
|
||||||
|
func TestHandleResponsesStreamDoesNotEmitReasoningTextCompatEvents(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
b, _ := json.Marshal(map[string]any{
|
||||||
|
"p": "response/thinking_content",
|
||||||
|
"v": "thought",
|
||||||
|
})
|
||||||
|
streamBody := "data: " + string(b) + "\n" + "data: [DONE]\n"
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, nil, util.DefaultToolChoicePolicy(), "")
|
||||||
|
|
||||||
|
body := rec.Body.String()
|
||||||
|
if !strings.Contains(body, "event: response.reasoning.delta") {
|
||||||
|
t.Fatalf("expected response.reasoning.delta event, body=%s", body)
|
||||||
|
}
|
||||||
|
if strings.Contains(body, "event: response.reasoning_text.delta") || strings.Contains(body, "event: response.reasoning_text.done") {
|
||||||
|
t.Fatalf("did not expect response.reasoning_text.* compatibility events, body=%s", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned streams a
// tool_calls payload split mid-array across two SSE chunks and asserts that
// both tools finish with their correct name and a distinct, non-empty
// call_id each.
func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.T) {
	h := &Handler{}
	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
	rec := httptest.NewRecorder()

	// sseLine wraps a content chunk in the upstream SSE envelope.
	sseLine := func(v string) string {
		b, _ := json.Marshal(map[string]any{
			"p": "response/content",
			"v": v,
		})
		return "data: " + string(b) + "\n"
	}

	// The tool_calls JSON is deliberately split between two chunks so the
	// incremental parser must stitch them back together.
	streamBody := sseLine(`{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) +
		sseLine(`{"name":"eval_javascript","input":{"code":"1+1"}}]}`) +
		"data: [DONE]\n"
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(streamBody)),
	}

	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"}, util.DefaultToolChoicePolicy(), "")

	body := rec.Body.String()
	donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
	if len(donePayloads) != 2 {
		t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
	}
	// Map tool name -> call_id so the ids can be compared for distinctness.
	seenNames := map[string]string{}
	for _, payload := range donePayloads {
		name := strings.TrimSpace(asString(payload["name"]))
		callID := strings.TrimSpace(asString(payload["call_id"]))
		if name != "search_web" && name != "eval_javascript" {
			t.Fatalf("unexpected tool name in done payload: %#v", payload)
		}
		if callID == "" {
			t.Fatalf("expected non-empty call_id in done payload: %#v", payload)
		}
		seenNames[name] = callID
	}
	if seenNames["search_web"] == seenNames["eval_javascript"] {
		t.Fatalf("expected distinct call_id per tool, got %#v", seenNames)
	}
}
|
||||||
|
|
||||||
|
// TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes asserts that a
// response.output_text.delta event carries the item_id, output_index and
// content_index fields of the OpenAI Responses stream format.
func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
	h := &Handler{}
	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
	rec := httptest.NewRecorder()

	// sseLine wraps a content chunk in the upstream SSE envelope.
	sseLine := func(v string) string {
		b, _ := json.Marshal(map[string]any{
			"p": "response/content",
			"v": v,
		})
		return "data: " + string(b) + "\n"
	}

	// A single plain-text chunk followed by the DONE sentinel.
	streamBody := sseLine("hello") + "data: [DONE]\n"
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(streamBody)),
	}

	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, util.DefaultToolChoicePolicy(), "")
	body := rec.Body.String()

	deltaPayload, ok := extractSSEEventPayload(body, "response.output_text.delta")
	if !ok {
		t.Fatalf("expected response.output_text.delta payload, body=%s", body)
	}
	if strings.TrimSpace(asString(deltaPayload["item_id"])) == "" {
		t.Fatalf("expected non-empty item_id in output_text.delta, payload=%#v", deltaPayload)
	}
	// Index fields only need to be present; their values are covered elsewhere.
	if _, ok := deltaPayload["output_index"]; !ok {
		t.Fatalf("expected output_index in output_text.delta, payload=%#v", deltaPayload)
	}
	if _, ok := deltaPayload["content_index"]; !ok {
		t.Fatalf("expected content_index in output_text.delta, payload=%#v", deltaPayload)
	}
}
|
||||||
|
|
||||||
|
// TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes streams
// reasoning, visible text and a tool call in one response and checks that
// (a) every output_item.added event claims a unique output_index, and
// (b) the item ids announced during streaming reappear on the matching items
// of the final response.completed output array.
func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *testing.T) {
	h := &Handler{}
	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
	rec := httptest.NewRecorder()

	// sseLine wraps a value for an arbitrary upstream path ("p") in the SSE envelope.
	sseLine := func(path, value string) string {
		b, _ := json.Marshal(map[string]any{
			"p": path,
			"v": value,
		})
		return "data: " + string(b) + "\n"
	}

	// thinking chunk -> visible text -> tool call -> DONE sentinel.
	streamBody := sseLine("response/thinking_content", "thinking...") +
		sseLine("response/content", "先读取文件。") +
		sseLine("response/content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) +
		"data: [DONE]\n"
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(streamBody)),
	}

	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")

	addedPayloads := extractAllSSEEventPayloads(rec.Body.String(), "response.output_item.added")
	if len(addedPayloads) < 2 {
		t.Fatalf("expected message + function_call output_item.added events, got %d body=%s", len(addedPayloads), rec.Body.String())
	}

	// Each added item must get its own output_index; remember item_id per type
	// for the completed-output cross-check below.
	indexes := map[int]struct{}{}
	typeByIndex := map[int]string{}
	addedIDs := map[string]string{}
	for _, payload := range addedPayloads {
		item, _ := payload["item"].(map[string]any)
		itemType := strings.TrimSpace(asString(item["type"]))
		outputIndex := int(asFloat(payload["output_index"]))
		if _, exists := indexes[outputIndex]; exists {
			t.Fatalf("found duplicated output_index=%d for item types=%q and %q payload=%#v", outputIndex, typeByIndex[outputIndex], itemType, payload)
		}
		indexes[outputIndex] = struct{}{}
		typeByIndex[outputIndex] = itemType
		addedIDs[itemType] = strings.TrimSpace(asString(payload["item_id"]))
	}

	completedPayload, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
	if !ok {
		t.Fatalf("expected response.completed payload, body=%s", rec.Body.String())
	}
	responseObj, _ := completedPayload["response"].(map[string]any)
	output, _ := responseObj["output"].([]any)
	// The completed output must reference the same item ids seen while streaming.
	found := map[string]bool{}
	for _, item := range output {
		m, _ := item.(map[string]any)
		itemType := strings.TrimSpace(asString(m["type"]))
		itemID := strings.TrimSpace(asString(m["id"]))
		if itemType == "" || itemID == "" {
			continue
		}
		if wantID := strings.TrimSpace(addedIDs[itemType]); wantID != "" && wantID == itemID {
			found[itemType] = true
		}
	}
	if !found["message"] || !found["function_call"] {
		t.Fatalf("expected completed output to contain streamed message/function_call item ids, found=%#v output=%#v", found, output)
	}
}
|
||||||
|
|
||||||
|
func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
sseLine := func(v string) string {
|
||||||
|
b, _ := json.Marshal(map[string]any{
|
||||||
|
"p": "response/content",
|
||||||
|
"v": v,
|
||||||
|
})
|
||||||
|
return "data: " + string(b) + "\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n"
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||||
|
}
|
||||||
|
policy := util.ToolChoicePolicy{Mode: util.ToolChoiceNone}
|
||||||
|
|
||||||
|
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, policy, "")
|
||||||
|
body := rec.Body.String()
|
||||||
|
if strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||||
|
t.Fatalf("did not expect function_call events for tool_choice=none, body=%s", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem
// verifies that when a tool_calls payload turns out to be unparsable (it
// contains the literal NaN), the runtime still closes the function_call item
// it had already opened: the in-flight deltas are followed by
// function_call_arguments.done, output_item.done, and a final
// response.completed event.
func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *testing.T) {
	h := &Handler{}
	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
	rec := httptest.NewRecorder()

	// sseLine wraps a content chunk in the upstream SSE envelope.
	sseLine := func(v string) string {
		b, _ := json.Marshal(map[string]any{
			"p": "response/content",
			"v": v,
		})
		return "data: " + string(b) + "\n"
	}

	// invalid JSON (NaN) can still trigger incremental tool deltas before final parse rejects it
	streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n"
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(streamBody)),
	}

	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
	body := rec.Body.String()
	if !strings.Contains(body, "event: response.function_call_arguments.delta") {
		t.Fatalf("expected response.function_call_arguments.delta event for malformed payload, body=%s", body)
	}
	if !strings.Contains(body, "event: response.function_call_arguments.done") {
		t.Fatalf("expected runtime to close in-progress function_call with done event, body=%s", body)
	}
	if !strings.Contains(body, "event: response.output_item.done") {
		t.Fatalf("expected runtime to close function output item, body=%s", body)
	}
	if !strings.Contains(body, "event: response.completed") {
		t.Fatalf("expected response.completed event, body=%s", body)
	}
}
|
||||||
|
|
||||||
|
// TestHandleResponsesStreamRequiredToolChoiceFailure asserts that when
// tool_choice=required but the model streams only plain text, the stream
// ends with response.failed and never emits response.completed.
func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
	h := &Handler{}
	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
	rec := httptest.NewRecorder()

	// sseLine wraps a content chunk in the upstream SSE envelope.
	sseLine := func(v string) string {
		b, _ := json.Marshal(map[string]any{
			"p": "response/content",
			"v": v,
		})
		return "data: " + string(b) + "\n"
	}

	// Plain text only — violates the required tool_choice policy below.
	streamBody := sseLine("plain text only") + "data: [DONE]\n"
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(streamBody)),
	}

	policy := util.ToolChoicePolicy{
		Mode:    util.ToolChoiceRequired,
		Allowed: map[string]struct{}{"read_file": {}},
	}
	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, policy, "")

	body := rec.Body.String()
	if !strings.Contains(body, "event: response.failed") {
		t.Fatalf("expected response.failed event for required tool_choice violation, body=%s", body)
	}
	if strings.Contains(body, "event: response.completed") {
		t.Fatalf("did not expect response.completed after failure, body=%s", body)
	}
}
|
||||||
|
|
||||||
|
func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
sseLine := func(v string) string {
|
||||||
|
b, _ := json.Marshal(map[string]any{
|
||||||
|
"p": "response/content",
|
||||||
|
"v": v,
|
||||||
|
})
|
||||||
|
return "data: " + string(b) + "\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n"
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||||
|
}
|
||||||
|
policy := util.ToolChoicePolicy{
|
||||||
|
Mode: util.ToolChoiceRequired,
|
||||||
|
Allowed: map[string]struct{}{"read_file": {}},
|
||||||
|
}
|
||||||
|
|
||||||
|
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, policy, "")
|
||||||
|
|
||||||
|
body := rec.Body.String()
|
||||||
|
if !strings.Contains(body, "event: response.failed") {
|
||||||
|
t.Fatalf("expected response.failed event, body=%s", body)
|
||||||
|
}
|
||||||
|
if strings.Contains(body, "event: response.completed") {
|
||||||
|
t.Fatalf("did not expect response.completed, body=%s", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleResponsesStreamRejectsUnknownToolName(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
sseLine := func(v string) string {
|
||||||
|
b, _ := json.Marshal(map[string]any{
|
||||||
|
"p": "response/content",
|
||||||
|
"v": v,
|
||||||
|
})
|
||||||
|
return "data: " + string(b) + "\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
streamBody := sseLine(`{"tool_calls":[{"name":"not_in_schema","input":{"q":"go"}}]}`) + "data: [DONE]\n"
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||||
|
body := rec.Body.String()
|
||||||
|
if strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||||
|
t.Fatalf("did not expect function_call events for unknown tool, body=%s", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleResponsesNonStreamRequiredToolChoiceViolation asserts the
// non-streaming path returns HTTP 422 with error code
// "tool_choice_violation" when tool_choice=required but the upstream output
// is plain text with no tool call.
func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) {
	h := &Handler{}
	rec := httptest.NewRecorder()
	// Upstream yields only a plain-text content chunk.
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body: io.NopCloser(strings.NewReader(
			`data: {"p":"response/content","v":"plain text only"}` + "\n" +
				`data: [DONE]` + "\n",
		)),
	}
	policy := util.ToolChoicePolicy{
		Mode:    util.ToolChoiceRequired,
		Allowed: map[string]struct{}{"read_file": {}},
	}

	h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, []string{"read_file"}, policy, "")
	if rec.Code != http.StatusUnprocessableEntity {
		t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
	}
	out := decodeJSONBody(t, rec.Body.String())
	errObj, _ := out["error"].(map[string]any)
	if asString(errObj["code"]) != "tool_choice_violation" {
		t.Fatalf("expected code=tool_choice_violation, got %#v", out)
	}
}
|
||||||
|
|
||||||
|
// TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall checks the
// non-streaming path with tool_choice=none: a tool_calls payload is passed
// through as text (HTTP 200) and no function_call item appears in output.
func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
	h := &Handler{}
	rec := httptest.NewRecorder()
	// The content chunk carries an escaped tool_calls JSON string.
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body: io.NopCloser(strings.NewReader(
			`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}"}` + "\n" +
				`data: [DONE]` + "\n",
		)),
	}
	policy := util.ToolChoicePolicy{Mode: util.ToolChoiceNone}

	h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, policy, "")
	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200 for tool_choice=none passthrough text, got %d body=%s", rec.Code, rec.Body.String())
	}
	out := decodeJSONBody(t, rec.Body.String())
	output, _ := out["output"].([]any)
	for _, item := range output {
		m, _ := item.(map[string]any)
		if m != nil && m["type"] == "function_call" {
			t.Fatalf("did not expect function_call output item for tool_choice=none, got %#v", output)
		}
	}
}
|
||||||
|
|
||||||
|
// extractSSEEventPayload returns the decoded JSON payload of the first
// "data: " line that follows an "event: targetEvent" line in an SSE body.
// Blank and "[DONE]" data lines are skipped; a malformed payload under the
// matched event aborts the search and reports false.
//
// It iterates plain split lines instead of bufio.Scanner so that data lines
// longer than Scanner's default 64 KiB token limit (which would silently
// terminate the scan) are still processed.
func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
	matched := false
	for _, rawLine := range strings.Split(body, "\n") {
		line := strings.TrimSpace(rawLine)
		if strings.HasPrefix(line, "event: ") {
			evt := strings.TrimSpace(strings.TrimPrefix(line, "event: "))
			matched = evt == targetEvent
			continue
		}
		if !matched || !strings.HasPrefix(line, "data: ") {
			continue
		}
		raw := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
		if raw == "" || raw == "[DONE]" {
			continue
		}
		var payload map[string]any
		if err := json.Unmarshal([]byte(raw), &payload); err != nil {
			return nil, false
		}
		return payload, true
	}
	return nil, false
}
|
||||||
|
|
||||||
|
// extractAllSSEEventPayloads collects every decodable JSON payload from
// "data: " lines that follow an "event: targetEvent" line in an SSE body.
// Blank and "[DONE]" data lines are skipped; undecodable payloads are ignored.
//
// It iterates plain split lines instead of bufio.Scanner so that data lines
// longer than Scanner's default 64 KiB token limit (which would silently
// terminate the scan) are still processed.
func extractAllSSEEventPayloads(body, targetEvent string) []map[string]any {
	matched := false
	out := make([]map[string]any, 0, 2)
	for _, rawLine := range strings.Split(body, "\n") {
		line := strings.TrimSpace(rawLine)
		if strings.HasPrefix(line, "event: ") {
			evt := strings.TrimSpace(strings.TrimPrefix(line, "event: "))
			matched = evt == targetEvent
			continue
		}
		if !matched || !strings.HasPrefix(line, "data: ") {
			continue
		}
		raw := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
		if raw == "" || raw == "[DONE]" {
			continue
		}
		var payload map[string]any
		if err := json.Unmarshal([]byte(raw), &payload); err != nil {
			continue
		}
		out = append(out, payload)
	}
	return out
}
|
||||||
|
|
||||||
|
// asFloat coerces the numeric types that show up in JSON-decoded payloads
// and test fixtures to float64, returning 0 for anything else.
func asFloat(v any) float64 {
	if f, ok := v.(float64); ok {
		return f
	}
	if f, ok := v.(float32); ok {
		return float64(f)
	}
	if n, ok := v.(int); ok {
		return float64(n)
	}
	if n, ok := v.(int64); ok {
		return float64(n)
	}
	return 0
}
|
||||||
326
internal/adapter/openai/standard_request.go
Normal file
326
internal/adapter/openai/standard_request.go
Normal file
@@ -0,0 +1,326 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"ds2api/internal/config"
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID string) (util.StandardRequest, error) {
|
||||||
|
model, _ := req["model"].(string)
|
||||||
|
messagesRaw, _ := req["messages"].([]any)
|
||||||
|
if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 {
|
||||||
|
return util.StandardRequest{}, fmt.Errorf("Request must include 'model' and 'messages'.")
|
||||||
|
}
|
||||||
|
resolvedModel, ok := config.ResolveModel(store, model)
|
||||||
|
if !ok {
|
||||||
|
return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", model)
|
||||||
|
}
|
||||||
|
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
|
||||||
|
responseModel := strings.TrimSpace(model)
|
||||||
|
if responseModel == "" {
|
||||||
|
responseModel = resolvedModel
|
||||||
|
}
|
||||||
|
toolPolicy := util.DefaultToolChoicePolicy()
|
||||||
|
finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy)
|
||||||
|
passThrough := collectOpenAIChatPassThrough(req)
|
||||||
|
|
||||||
|
return util.StandardRequest{
|
||||||
|
Surface: "openai_chat",
|
||||||
|
RequestedModel: strings.TrimSpace(model),
|
||||||
|
ResolvedModel: resolvedModel,
|
||||||
|
ResponseModel: responseModel,
|
||||||
|
Messages: messagesRaw,
|
||||||
|
FinalPrompt: finalPrompt,
|
||||||
|
ToolNames: toolNames,
|
||||||
|
ToolChoice: toolPolicy,
|
||||||
|
Stream: util.ToBool(req["stream"]),
|
||||||
|
Thinking: thinkingEnabled,
|
||||||
|
Search: searchEnabled,
|
||||||
|
PassThrough: passThrough,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeOpenAIResponsesRequest converts an OpenAI /v1/responses request
// body into the internal StandardRequest. Unlike the chat surface it honors
// the caller's tool_choice, and its message sourcing depends on the store's
// wide-input compatibility setting.
func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, traceID string) (util.StandardRequest, error) {
	model, _ := req["model"].(string)
	model = strings.TrimSpace(model)
	if model == "" {
		return util.StandardRequest{}, fmt.Errorf("Request must include 'model'.")
	}
	resolvedModel, ok := config.ResolveModel(store, model)
	if !ok {
		return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", model)
	}
	thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)

	// Keep width-control as an explicit policy hook even if current default is true.
	allowWideInput := true
	if store != nil {
		allowWideInput = store.CompatWideInputStrictOutput()
	}
	var messagesRaw []any
	if allowWideInput {
		// Wide mode: delegate message extraction to the helper
		// (presumably it also accepts the Responses-style 'input' field —
		// see responsesMessagesFromRequest for the exact shapes).
		messagesRaw = responsesMessagesFromRequest(req)
	} else if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 {
		// Strict mode: only a non-empty 'messages' array is accepted.
		messagesRaw = msgs
	}
	if len(messagesRaw) == 0 {
		return util.StandardRequest{}, fmt.Errorf("Request must include 'input' or 'messages'.")
	}
	toolPolicy, err := parseToolChoicePolicy(req["tool_choice"], req["tools"])
	if err != nil {
		return util.StandardRequest{}, err
	}
	finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy)
	if toolPolicy.IsNone() {
		// tool_choice=none: drop all tools so no call can be accepted.
		toolNames = nil
		toolPolicy.Allowed = nil
	} else {
		// Otherwise the allowed set is exactly the tools kept in the prompt.
		toolPolicy.Allowed = namesToSet(toolNames)
	}
	passThrough := collectOpenAIChatPassThrough(req)

	return util.StandardRequest{
		Surface:        "openai_responses",
		RequestedModel: model,
		ResolvedModel:  resolvedModel,
		ResponseModel:  model,
		Messages:       messagesRaw,
		FinalPrompt:    finalPrompt,
		ToolNames:      toolNames,
		ToolChoice:     toolPolicy,
		Stream:         util.ToBool(req["stream"]),
		Thinking:       thinkingEnabled,
		Search:         searchEnabled,
		PassThrough:    passThrough,
	}, nil
}
|
||||||
|
|
||||||
|
// collectOpenAIChatPassThrough copies the OpenAI sampling/limit parameters
// that are forwarded verbatim from an incoming request, ignoring every
// other key.
func collectOpenAIChatPassThrough(req map[string]any) map[string]any {
	passThroughKeys := [...]string{
		"temperature",
		"top_p",
		"max_tokens",
		"max_completion_tokens",
		"presence_penalty",
		"frequency_penalty",
		"stop",
	}
	out := make(map[string]any, len(passThroughKeys))
	for _, key := range passThroughKeys {
		value, present := req[key]
		if present {
			out[key] = value
		}
	}
	return out
}
|
||||||
|
|
||||||
|
// parseToolChoicePolicy derives a ToolChoicePolicy from a request's
// tool_choice and tools fields. tool_choice may be absent (defaults to the
// declared-tools policy), a string ("auto"/"none"/"required"), or an object
// with a "type" plus optional forced-function selector and allowed_tools
// override. Required/forced modes additionally demand a non-empty, matching
// tools declaration.
func parseToolChoicePolicy(toolChoiceRaw any, toolsRaw any) (util.ToolChoicePolicy, error) {
	policy := util.DefaultToolChoicePolicy()
	declaredNames := extractDeclaredToolNames(toolsRaw)
	declaredSet := namesToSet(declaredNames)
	// By default every declared tool is allowed.
	if len(declaredNames) > 0 {
		policy.Allowed = declaredSet
	}

	// No explicit tool_choice: keep the default mode with declared tools.
	if toolChoiceRaw == nil {
		return policy, nil
	}

	switch v := toolChoiceRaw.(type) {
	case string:
		// String form: "auto" (or empty), "none", "required".
		switch strings.ToLower(strings.TrimSpace(v)) {
		case "", "auto":
			policy.Mode = util.ToolChoiceAuto
		case "none":
			policy.Mode = util.ToolChoiceNone
			policy.Allowed = nil
		case "required":
			policy.Mode = util.ToolChoiceRequired
		default:
			return util.ToolChoicePolicy{}, fmt.Errorf("Unsupported tool_choice: %q", v)
		}
	case map[string]any:
		// Optional allowed_tools narrows the declared set; every listed
		// name must have been declared in tools.
		allowedOverride, hasAllowedOverride, err := parseAllowedToolNames(v["allowed_tools"])
		if err != nil {
			return util.ToolChoicePolicy{}, err
		}
		if hasAllowedOverride {
			filtered := make([]string, 0, len(allowedOverride))
			for _, name := range allowedOverride {
				if _, ok := declaredSet[name]; !ok {
					return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice.allowed_tools contains undeclared tool %q", name)
				}
				filtered = append(filtered, name)
			}
			policy.Allowed = namesToSet(filtered)
		}

		typ := strings.ToLower(strings.TrimSpace(asString(v["type"])))
		switch typ {
		case "", "auto":
			// An auto-typed object may still force a specific function
			// when it carries a function selector.
			if hasFunctionSelector(v) {
				name, err := parseForcedToolName(v)
				if err != nil {
					return util.ToolChoicePolicy{}, err
				}
				policy.Mode = util.ToolChoiceForced
				policy.ForcedName = name
				policy.Allowed = namesToSet([]string{name})
			} else {
				policy.Mode = util.ToolChoiceAuto
			}
		case "none":
			policy.Mode = util.ToolChoiceNone
			policy.Allowed = nil
		case "required":
			policy.Mode = util.ToolChoiceRequired
		case "function":
			// Forced single function: only that tool remains allowed.
			name, err := parseForcedToolName(v)
			if err != nil {
				return util.ToolChoicePolicy{}, err
			}
			policy.Mode = util.ToolChoiceForced
			policy.ForcedName = name
			policy.Allowed = namesToSet([]string{name})
		default:
			return util.ToolChoicePolicy{}, fmt.Errorf("Unsupported tool_choice.type: %q", typ)
		}
	default:
		return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice must be a string or object")
	}

	// Required/forced modes are meaningless without declared tools.
	if policy.Mode == util.ToolChoiceRequired || policy.Mode == util.ToolChoiceForced {
		if len(declaredNames) == 0 {
			return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice=%s requires non-empty tools.", policy.Mode)
		}
	}
	// A forced function must be one of the declared tools.
	if policy.Mode == util.ToolChoiceForced {
		if _, ok := declaredSet[policy.ForcedName]; !ok {
			return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice forced function %q is not declared in tools", policy.ForcedName)
		}
	}
	// Guard against an allowed_tools override that emptied the set.
	if len(policy.Allowed) == 0 && (policy.Mode == util.ToolChoiceRequired || policy.Mode == util.ToolChoiceForced) {
		return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice policy resolved to empty allowed tool set")
	}
	return policy, nil
}
|
||||||
|
|
||||||
|
func parseForcedToolName(v map[string]any) (string, error) {
|
||||||
|
if name := strings.TrimSpace(asString(v["name"])); name != "" {
|
||||||
|
return name, nil
|
||||||
|
}
|
||||||
|
if fn, ok := v["function"].(map[string]any); ok {
|
||||||
|
if name := strings.TrimSpace(asString(fn["name"])); name != "" {
|
||||||
|
return name, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("tool_choice function requires name")
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAllowedToolNames(raw any) ([]string, bool, error) {
|
||||||
|
if raw == nil {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
collectName := func(v any) string {
|
||||||
|
if name := strings.TrimSpace(asString(v)); name != "" {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
if m, ok := v.(map[string]any); ok {
|
||||||
|
if name := strings.TrimSpace(asString(m["name"])); name != "" {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
if fn, ok := m["function"].(map[string]any); ok {
|
||||||
|
if name := strings.TrimSpace(asString(fn["name"])); name != "" {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
names := []string{}
|
||||||
|
switch x := raw.(type) {
|
||||||
|
case []any:
|
||||||
|
for _, item := range x {
|
||||||
|
name := collectName(item)
|
||||||
|
if name == "" {
|
||||||
|
return nil, true, fmt.Errorf("tool_choice.allowed_tools contains invalid item")
|
||||||
|
}
|
||||||
|
names = append(names, name)
|
||||||
|
}
|
||||||
|
case []string:
|
||||||
|
for _, item := range x {
|
||||||
|
name := strings.TrimSpace(item)
|
||||||
|
if name == "" {
|
||||||
|
return nil, true, fmt.Errorf("tool_choice.allowed_tools contains empty name")
|
||||||
|
}
|
||||||
|
names = append(names, name)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, true, fmt.Errorf("tool_choice.allowed_tools must be an array")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(names) == 0 {
|
||||||
|
return nil, true, fmt.Errorf("tool_choice.allowed_tools must not be empty")
|
||||||
|
}
|
||||||
|
return names, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasFunctionSelector(v map[string]any) bool {
|
||||||
|
if strings.TrimSpace(asString(v["name"])) != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if fn, ok := v["function"].(map[string]any); ok {
|
||||||
|
return strings.TrimSpace(asString(fn["name"])) != ""
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractDeclaredToolNames(toolsRaw any) []string {
|
||||||
|
tools, ok := toolsRaw.([]any)
|
||||||
|
if !ok || len(tools) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]string, 0, len(tools))
|
||||||
|
seen := map[string]struct{}{}
|
||||||
|
for _, t := range tools {
|
||||||
|
tool, ok := t.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fn, _ := tool["function"].(map[string]any)
|
||||||
|
if len(fn) == 0 {
|
||||||
|
fn = tool
|
||||||
|
}
|
||||||
|
name := strings.TrimSpace(asString(fn["name"]))
|
||||||
|
if name == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[name]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[name] = struct{}{}
|
||||||
|
out = append(out, name)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func namesToSet(names []string) map[string]struct{} {
|
||||||
|
if len(names) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(map[string]struct{}, len(names))
|
||||||
|
for _, name := range names {
|
||||||
|
trimmed := strings.TrimSpace(name)
|
||||||
|
if trimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out[trimmed] = struct{}{}
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
180
internal/adapter/openai/standard_request_test.go
Normal file
180
internal/adapter/openai/standard_request_test.go
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"ds2api/internal/config"
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newEmptyStoreForNormalizeTest(t *testing.T) *config.Store {
|
||||||
|
t.Helper()
|
||||||
|
t.Setenv("DS2API_CONFIG_JSON", `{}`)
|
||||||
|
return config.LoadStore()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOpenAIChatRequest(t *testing.T) {
|
||||||
|
store := newEmptyStoreForNormalizeTest(t)
|
||||||
|
req := map[string]any{
|
||||||
|
"model": "gpt-5-codex",
|
||||||
|
"messages": []any{
|
||||||
|
map[string]any{"role": "user", "content": "hello"},
|
||||||
|
},
|
||||||
|
"temperature": 0.3,
|
||||||
|
"stream": true,
|
||||||
|
}
|
||||||
|
n, err := normalizeOpenAIChatRequest(store, req, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalize failed: %v", err)
|
||||||
|
}
|
||||||
|
if n.ResolvedModel != "deepseek-reasoner" {
|
||||||
|
t.Fatalf("unexpected resolved model: %s", n.ResolvedModel)
|
||||||
|
}
|
||||||
|
if !n.Stream {
|
||||||
|
t.Fatalf("expected stream=true")
|
||||||
|
}
|
||||||
|
if _, ok := n.PassThrough["temperature"]; !ok {
|
||||||
|
t.Fatalf("expected temperature passthrough")
|
||||||
|
}
|
||||||
|
if n.FinalPrompt == "" {
|
||||||
|
t.Fatalf("expected non-empty final prompt")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) {
|
||||||
|
store := newEmptyStoreForNormalizeTest(t)
|
||||||
|
req := map[string]any{
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"input": "ping",
|
||||||
|
"instructions": "system",
|
||||||
|
}
|
||||||
|
n, err := normalizeOpenAIResponsesRequest(store, req, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalize failed: %v", err)
|
||||||
|
}
|
||||||
|
if n.ResolvedModel != "deepseek-chat" {
|
||||||
|
t.Fatalf("unexpected resolved model: %s", n.ResolvedModel)
|
||||||
|
}
|
||||||
|
if len(n.Messages) != 2 {
|
||||||
|
t.Fatalf("expected 2 normalized messages, got %d", len(n.Messages))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOpenAIResponsesRequestToolChoiceRequired(t *testing.T) {
|
||||||
|
store := newEmptyStoreForNormalizeTest(t)
|
||||||
|
req := map[string]any{
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"input": "ping",
|
||||||
|
"tools": []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "search",
|
||||||
|
"parameters": map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"tool_choice": "required",
|
||||||
|
}
|
||||||
|
n, err := normalizeOpenAIResponsesRequest(store, req, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalize failed: %v", err)
|
||||||
|
}
|
||||||
|
if n.ToolChoice.Mode != util.ToolChoiceRequired {
|
||||||
|
t.Fatalf("expected tool choice mode required, got %q", n.ToolChoice.Mode)
|
||||||
|
}
|
||||||
|
if len(n.ToolNames) != 1 || n.ToolNames[0] != "search" {
|
||||||
|
t.Fatalf("unexpected tool names: %#v", n.ToolNames)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOpenAIResponsesRequestToolChoiceForcedFunction(t *testing.T) {
|
||||||
|
store := newEmptyStoreForNormalizeTest(t)
|
||||||
|
req := map[string]any{
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"input": "ping",
|
||||||
|
"tools": []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "search",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "read_file",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"tool_choice": map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"name": "read_file",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
n, err := normalizeOpenAIResponsesRequest(store, req, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalize failed: %v", err)
|
||||||
|
}
|
||||||
|
if n.ToolChoice.Mode != util.ToolChoiceForced {
|
||||||
|
t.Fatalf("expected tool choice mode forced, got %q", n.ToolChoice.Mode)
|
||||||
|
}
|
||||||
|
if n.ToolChoice.ForcedName != "read_file" {
|
||||||
|
t.Fatalf("expected forced tool name read_file, got %q", n.ToolChoice.ForcedName)
|
||||||
|
}
|
||||||
|
if len(n.ToolNames) != 1 || n.ToolNames[0] != "read_file" {
|
||||||
|
t.Fatalf("expected filtered tool names [read_file], got %#v", n.ToolNames)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOpenAIResponsesRequestToolChoiceForcedUndeclaredFails(t *testing.T) {
|
||||||
|
store := newEmptyStoreForNormalizeTest(t)
|
||||||
|
req := map[string]any{
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"input": "ping",
|
||||||
|
"tools": []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "search",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"tool_choice": map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"name": "read_file",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if _, err := normalizeOpenAIResponsesRequest(store, req, ""); err == nil {
|
||||||
|
t.Fatalf("expected forced undeclared tool to fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeOpenAIResponsesRequestToolChoiceNoneDisablesTools(t *testing.T) {
|
||||||
|
store := newEmptyStoreForNormalizeTest(t)
|
||||||
|
req := map[string]any{
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"input": "ping",
|
||||||
|
"tools": []any{
|
||||||
|
map[string]any{
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": "search",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"tool_choice": "none",
|
||||||
|
}
|
||||||
|
n, err := normalizeOpenAIResponsesRequest(store, req, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalize failed: %v", err)
|
||||||
|
}
|
||||||
|
if n.ToolChoice.Mode != util.ToolChoiceNone {
|
||||||
|
t.Fatalf("expected tool choice mode none, got %q", n.ToolChoice.Mode)
|
||||||
|
}
|
||||||
|
if len(n.ToolNames) != 0 {
|
||||||
|
t.Fatalf("expected no tool names when tool_choice=none, got %#v", n.ToolNames)
|
||||||
|
}
|
||||||
|
}
|
||||||
185
internal/adapter/openai/stream_status_test.go
Normal file
185
internal/adapter/openai/stream_status_test.go
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
chimw "github.com/go-chi/chi/v5/middleware"
|
||||||
|
|
||||||
|
"ds2api/internal/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type streamStatusAuthStub struct{}
|
||||||
|
|
||||||
|
func (streamStatusAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) {
|
||||||
|
return &auth.RequestAuth{
|
||||||
|
UseConfigToken: false,
|
||||||
|
DeepSeekToken: "direct-token",
|
||||||
|
CallerID: "caller:test",
|
||||||
|
TriedAccounts: map[string]bool{},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (streamStatusAuthStub) DetermineCaller(_ *http.Request) (*auth.RequestAuth, error) {
|
||||||
|
return &auth.RequestAuth{
|
||||||
|
UseConfigToken: false,
|
||||||
|
DeepSeekToken: "direct-token",
|
||||||
|
CallerID: "caller:test",
|
||||||
|
TriedAccounts: map[string]bool{},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (streamStatusAuthStub) Release(_ *auth.RequestAuth) {}
|
||||||
|
|
||||||
|
type streamStatusDSStub struct {
|
||||||
|
resp *http.Response
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m streamStatusDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
|
||||||
|
return "session-id", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m streamStatusDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
|
||||||
|
return "pow", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m streamStatusDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
|
||||||
|
return m.resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeOpenAISSEHTTPResponse(lines ...string) *http.Response {
|
||||||
|
body := strings.Join(lines, "\n")
|
||||||
|
if !strings.HasSuffix(body, "\n") {
|
||||||
|
body += "\n"
|
||||||
|
}
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(strings.NewReader(body)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func captureStatusMiddleware(statuses *[]int) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ww := chimw.NewWrapResponseWriter(w, r.ProtoMajor)
|
||||||
|
next.ServeHTTP(ww, r)
|
||||||
|
*statuses = append(*statuses, ww.Status())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsStreamStatusCapturedAs200(t *testing.T) {
|
||||||
|
statuses := make([]int, 0, 1)
|
||||||
|
h := &Handler{
|
||||||
|
Store: mockOpenAIConfig{wideInput: true},
|
||||||
|
Auth: streamStatusAuthStub{},
|
||||||
|
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, "data: [DONE]")},
|
||||||
|
}
|
||||||
|
r := chi.NewRouter()
|
||||||
|
r.Use(captureStatusMiddleware(&statuses))
|
||||||
|
RegisterRoutes(r, h)
|
||||||
|
|
||||||
|
reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":"hi"}],"stream":true}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
|
||||||
|
req.Header.Set("Authorization", "Bearer direct-token")
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
if len(statuses) != 1 {
|
||||||
|
t.Fatalf("expected one captured status, got %d", len(statuses))
|
||||||
|
}
|
||||||
|
if statuses[0] != http.StatusOK {
|
||||||
|
t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesStreamStatusCapturedAs200(t *testing.T) {
|
||||||
|
statuses := make([]int, 0, 1)
|
||||||
|
h := &Handler{
|
||||||
|
Store: mockOpenAIConfig{wideInput: true},
|
||||||
|
Auth: streamStatusAuthStub{},
|
||||||
|
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, "data: [DONE]")},
|
||||||
|
}
|
||||||
|
r := chi.NewRouter()
|
||||||
|
r.Use(captureStatusMiddleware(&statuses))
|
||||||
|
RegisterRoutes(r, h)
|
||||||
|
|
||||||
|
reqBody := `{"model":"deepseek-chat","input":"hi","stream":true}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody))
|
||||||
|
req.Header.Set("Authorization", "Bearer direct-token")
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
if len(statuses) != 1 {
|
||||||
|
t.Fatalf("expected one captured status, got %d", len(statuses))
|
||||||
|
}
|
||||||
|
if statuses[0] != http.StatusOK {
|
||||||
|
t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) {
|
||||||
|
statuses := make([]int, 0, 1)
|
||||||
|
content, _ := json.Marshal(map[string]any{
|
||||||
|
"p": "response/content",
|
||||||
|
"v": "我来调用工具\n{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}",
|
||||||
|
})
|
||||||
|
h := &Handler{
|
||||||
|
Store: mockOpenAIConfig{wideInput: true},
|
||||||
|
Auth: streamStatusAuthStub{},
|
||||||
|
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse("data: "+string(content), "data: [DONE]")},
|
||||||
|
}
|
||||||
|
r := chi.NewRouter()
|
||||||
|
r.Use(captureStatusMiddleware(&statuses))
|
||||||
|
RegisterRoutes(r, h)
|
||||||
|
|
||||||
|
reqBody := `{"model":"deepseek-chat","input":"请调用工具","tools":[{"type":"function","function":{"name":"read_file","description":"read","parameters":{"type":"object","properties":{"path":{"type":"string"}}}}}],"stream":false}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody))
|
||||||
|
req.Header.Set("Authorization", "Bearer direct-token")
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
if len(statuses) != 1 || statuses[0] != http.StatusOK {
|
||||||
|
t.Fatalf("expected captured status 200, got %#v", statuses)
|
||||||
|
}
|
||||||
|
|
||||||
|
var out map[string]any
|
||||||
|
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||||
|
t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String())
|
||||||
|
}
|
||||||
|
outputText, _ := out["output_text"].(string)
|
||||||
|
if outputText != "" {
|
||||||
|
t.Fatalf("expected output_text hidden for tool call payload, got %q", outputText)
|
||||||
|
}
|
||||||
|
output, _ := out["output"].([]any)
|
||||||
|
hasFunctionCall := false
|
||||||
|
for _, item := range output {
|
||||||
|
m, _ := item.(map[string]any)
|
||||||
|
if m != nil && m["type"] == "function_call" {
|
||||||
|
hasFunctionCall = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasFunctionCall {
|
||||||
|
t.Fatalf("expected function_call output item, got %#v", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,17 +6,6 @@ import (
|
|||||||
"ds2api/internal/util"
|
"ds2api/internal/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type toolStreamSieveState struct {
|
|
||||||
pending strings.Builder
|
|
||||||
capture strings.Builder
|
|
||||||
capturing bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type toolStreamEvent struct {
|
|
||||||
Content string
|
|
||||||
ToolCalls []util.ParsedToolCall
|
|
||||||
}
|
|
||||||
|
|
||||||
func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent {
|
func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent {
|
||||||
if state == nil {
|
if state == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -32,13 +21,27 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
|||||||
state.capture.WriteString(state.pending.String())
|
state.capture.WriteString(state.pending.String())
|
||||||
state.pending.Reset()
|
state.pending.Reset()
|
||||||
}
|
}
|
||||||
prefix, calls, suffix, ready := consumeToolCapture(state.capture.String(), toolNames)
|
if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 {
|
||||||
|
events = append(events, toolStreamEvent{ToolCallDeltas: deltas})
|
||||||
|
}
|
||||||
|
prefix, calls, suffix, ready := consumeToolCapture(state, toolNames)
|
||||||
if !ready {
|
if !ready {
|
||||||
|
if state.capture.Len() > toolSieveCaptureLimit {
|
||||||
|
content := state.capture.String()
|
||||||
|
state.capture.Reset()
|
||||||
|
state.capturing = false
|
||||||
|
state.resetIncrementalToolState()
|
||||||
|
state.noteText(content)
|
||||||
|
events = append(events, toolStreamEvent{Content: content})
|
||||||
|
continue
|
||||||
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
state.capture.Reset()
|
state.capture.Reset()
|
||||||
state.capturing = false
|
state.capturing = false
|
||||||
|
state.resetIncrementalToolState()
|
||||||
if prefix != "" {
|
if prefix != "" {
|
||||||
|
state.noteText(prefix)
|
||||||
events = append(events, toolStreamEvent{Content: prefix})
|
events = append(events, toolStreamEvent{Content: prefix})
|
||||||
}
|
}
|
||||||
if len(calls) > 0 {
|
if len(calls) > 0 {
|
||||||
@@ -58,11 +61,13 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
|||||||
if start >= 0 {
|
if start >= 0 {
|
||||||
prefix := pending[:start]
|
prefix := pending[:start]
|
||||||
if prefix != "" {
|
if prefix != "" {
|
||||||
|
state.noteText(prefix)
|
||||||
events = append(events, toolStreamEvent{Content: prefix})
|
events = append(events, toolStreamEvent{Content: prefix})
|
||||||
}
|
}
|
||||||
state.pending.Reset()
|
state.pending.Reset()
|
||||||
state.capture.WriteString(pending[start:])
|
state.capture.WriteString(pending[start:])
|
||||||
state.capturing = true
|
state.capturing = true
|
||||||
|
state.resetIncrementalToolState()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,6 +77,7 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
|||||||
}
|
}
|
||||||
state.pending.Reset()
|
state.pending.Reset()
|
||||||
state.pending.WriteString(hold)
|
state.pending.WriteString(hold)
|
||||||
|
state.noteText(safe)
|
||||||
events = append(events, toolStreamEvent{Content: safe})
|
events = append(events, toolStreamEvent{Content: safe})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -84,25 +90,34 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea
|
|||||||
}
|
}
|
||||||
events := processToolSieveChunk(state, "", toolNames)
|
events := processToolSieveChunk(state, "", toolNames)
|
||||||
if state.capturing {
|
if state.capturing {
|
||||||
consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state.capture.String(), toolNames)
|
consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames)
|
||||||
if ready {
|
if ready {
|
||||||
if consumedPrefix != "" {
|
if consumedPrefix != "" {
|
||||||
|
state.noteText(consumedPrefix)
|
||||||
events = append(events, toolStreamEvent{Content: consumedPrefix})
|
events = append(events, toolStreamEvent{Content: consumedPrefix})
|
||||||
}
|
}
|
||||||
if len(consumedCalls) > 0 {
|
if len(consumedCalls) > 0 {
|
||||||
events = append(events, toolStreamEvent{ToolCalls: consumedCalls})
|
events = append(events, toolStreamEvent{ToolCalls: consumedCalls})
|
||||||
}
|
}
|
||||||
if consumedSuffix != "" {
|
if consumedSuffix != "" {
|
||||||
|
state.noteText(consumedSuffix)
|
||||||
events = append(events, toolStreamEvent{Content: consumedSuffix})
|
events = append(events, toolStreamEvent{Content: consumedSuffix})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Incomplete captured tool JSON at stream end: suppress raw capture.
|
content := state.capture.String()
|
||||||
|
if content != "" {
|
||||||
|
state.noteText(content)
|
||||||
|
events = append(events, toolStreamEvent{Content: content})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
state.capture.Reset()
|
state.capture.Reset()
|
||||||
state.capturing = false
|
state.capturing = false
|
||||||
|
state.resetIncrementalToolState()
|
||||||
}
|
}
|
||||||
if state.pending.Len() > 0 {
|
if state.pending.Len() > 0 {
|
||||||
events = append(events, toolStreamEvent{Content: state.pending.String()})
|
content := state.pending.String()
|
||||||
|
state.noteText(content)
|
||||||
|
events = append(events, toolStreamEvent{Content: content})
|
||||||
state.pending.Reset()
|
state.pending.Reset()
|
||||||
}
|
}
|
||||||
return events
|
return events
|
||||||
@@ -144,17 +159,26 @@ func findToolSegmentStart(s string) int {
|
|||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
lower := strings.ToLower(s)
|
lower := strings.ToLower(s)
|
||||||
keyIdx := strings.Index(lower, "tool_calls")
|
offset := 0
|
||||||
if keyIdx < 0 {
|
for {
|
||||||
return -1
|
keyRel := strings.Index(lower[offset:], "tool_calls")
|
||||||
|
if keyRel < 0 {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
keyIdx := offset + keyRel
|
||||||
|
start := strings.LastIndex(s[:keyIdx], "{")
|
||||||
|
if start < 0 {
|
||||||
|
start = keyIdx
|
||||||
|
}
|
||||||
|
if !insideCodeFence(s[:start]) {
|
||||||
|
return start
|
||||||
|
}
|
||||||
|
offset = keyIdx + len("tool_calls")
|
||||||
}
|
}
|
||||||
if start := strings.LastIndex(s[:keyIdx], "{"); start >= 0 {
|
|
||||||
return start
|
|
||||||
}
|
|
||||||
return keyIdx
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func consumeToolCapture(captured string, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
|
func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
|
||||||
|
captured := state.capture.String()
|
||||||
if captured == "" {
|
if captured == "" {
|
||||||
return "", nil, "", false
|
return "", nil, "", false
|
||||||
}
|
}
|
||||||
@@ -171,53 +195,14 @@ func consumeToolCapture(captured string, toolNames []string) (prefix string, cal
|
|||||||
if !ok {
|
if !ok {
|
||||||
return "", nil, "", false
|
return "", nil, "", false
|
||||||
}
|
}
|
||||||
parsed := util.ParseToolCalls(obj, toolNames)
|
prefixPart := captured[:start]
|
||||||
|
suffixPart := captured[end:]
|
||||||
|
if insideCodeFence(state.recentTextTail + prefixPart) {
|
||||||
|
return captured, nil, "", true
|
||||||
|
}
|
||||||
|
parsed := util.ParseStandaloneToolCalls(obj, toolNames)
|
||||||
if len(parsed) == 0 {
|
if len(parsed) == 0 {
|
||||||
// `tool_calls` key exists but strict JSON parse failed.
|
return captured, nil, "", true
|
||||||
// Drop the captured object body to avoid leaking raw tool JSON.
|
|
||||||
return captured[:start], nil, captured[end:], true
|
|
||||||
}
|
}
|
||||||
return captured[:start], parsed, captured[end:], true
|
return prefixPart, parsed, suffixPart, true
|
||||||
}
|
|
||||||
|
|
||||||
func extractJSONObjectFrom(text string, start int) (string, int, bool) {
|
|
||||||
if start < 0 || start >= len(text) || text[start] != '{' {
|
|
||||||
return "", 0, false
|
|
||||||
}
|
|
||||||
depth := 0
|
|
||||||
quote := byte(0)
|
|
||||||
escaped := false
|
|
||||||
for i := start; i < len(text); i++ {
|
|
||||||
ch := text[i]
|
|
||||||
if quote != 0 {
|
|
||||||
if escaped {
|
|
||||||
escaped = false
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ch == '\\' {
|
|
||||||
escaped = true
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ch == quote {
|
|
||||||
quote = 0
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ch == '"' || ch == '\'' {
|
|
||||||
quote = ch
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ch == '{' {
|
|
||||||
depth++
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ch == '}' {
|
|
||||||
depth--
|
|
||||||
if depth == 0 {
|
|
||||||
end := i + 1
|
|
||||||
return text[start:end], end, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", 0, false
|
|
||||||
}
|
}
|
||||||
291
internal/adapter/openai/tool_sieve_incremental.go
Normal file
291
internal/adapter/openai/tool_sieve_incremental.go
Normal file
@@ -0,0 +1,291 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
|
||||||
|
if state.disableDeltas {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
captured := state.capture.String()
|
||||||
|
if captured == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
lower := strings.ToLower(captured)
|
||||||
|
keyIdx := strings.Index(lower, "tool_calls")
|
||||||
|
if keyIdx < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
start := strings.LastIndex(captured[:keyIdx], "{")
|
||||||
|
if start < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if insideCodeFence(state.recentTextTail + captured[:start]) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx)
|
||||||
|
if hasMultiple {
|
||||||
|
state.disableDeltas = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !certainSingle {
|
||||||
|
// In uncertain phases (e.g. first call arrived but array not closed yet),
|
||||||
|
// avoid speculative deltas and wait for final parsed tool_calls payload.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
deltas := make([]toolCallDelta, 0, 2)
|
||||||
|
if state.toolName == "" {
|
||||||
|
name, ok := extractToolCallName(captured, callStart)
|
||||||
|
if !ok || name == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state.toolName = name
|
||||||
|
}
|
||||||
|
if state.toolArgsStart < 0 {
|
||||||
|
argsStart, stringMode, ok := findToolCallArgsStart(captured, callStart)
|
||||||
|
if ok {
|
||||||
|
state.toolArgsString = stringMode
|
||||||
|
if stringMode {
|
||||||
|
state.toolArgsStart = argsStart + 1
|
||||||
|
} else {
|
||||||
|
state.toolArgsStart = argsStart
|
||||||
|
}
|
||||||
|
state.toolArgsSent = state.toolArgsStart
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !state.toolNameSent {
|
||||||
|
if state.toolArgsStart < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state.toolNameSent = true
|
||||||
|
deltas = append(deltas, toolCallDelta{Index: 0, Name: state.toolName})
|
||||||
|
}
|
||||||
|
if state.toolArgsStart < 0 || state.toolArgsDone {
|
||||||
|
return deltas
|
||||||
|
}
|
||||||
|
end, complete, ok := scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString)
|
||||||
|
if !ok {
|
||||||
|
return deltas
|
||||||
|
}
|
||||||
|
if end > state.toolArgsSent {
|
||||||
|
deltas = append(deltas, toolCallDelta{
|
||||||
|
Index: 0,
|
||||||
|
Arguments: captured[state.toolArgsSent:end],
|
||||||
|
})
|
||||||
|
state.toolArgsSent = end
|
||||||
|
}
|
||||||
|
if complete {
|
||||||
|
state.toolArgsDone = true
|
||||||
|
}
|
||||||
|
return deltas
|
||||||
|
}
|
||||||
|
|
||||||
|
func classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) {
|
||||||
|
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
|
||||||
|
if !ok {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
i := skipSpaces(text, arrStart+1)
|
||||||
|
if i >= len(text) || text[i] != '{' {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
count := 0
|
||||||
|
depth := 0
|
||||||
|
quote := byte(0)
|
||||||
|
escaped := false
|
||||||
|
for ; i < len(text); i++ {
|
||||||
|
ch := text[i]
|
||||||
|
if quote != 0 {
|
||||||
|
if escaped {
|
||||||
|
escaped = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ch == '\\' {
|
||||||
|
escaped = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ch == quote {
|
||||||
|
quote = 0
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ch == '"' || ch == '\'' {
|
||||||
|
quote = ch
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ch == '{' {
|
||||||
|
if depth == 0 {
|
||||||
|
count++
|
||||||
|
if count > 1 {
|
||||||
|
return false, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
depth++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ch == '}' {
|
||||||
|
if depth > 0 {
|
||||||
|
depth--
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ch == ',' && depth == 0 {
|
||||||
|
// top-level separator means at least one more tool call exists
|
||||||
|
// (or is expected). Treat as multi-call and stop incremental deltas.
|
||||||
|
return false, true
|
||||||
|
}
|
||||||
|
if ch == ']' && depth == 0 {
|
||||||
|
return count == 1, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// array not closed yet: still uncertain whether more calls will appear
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) {
|
||||||
|
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
|
||||||
|
if !ok {
|
||||||
|
return -1, false
|
||||||
|
}
|
||||||
|
i := skipSpaces(text, arrStart+1)
|
||||||
|
if i >= len(text) || text[i] != '{' {
|
||||||
|
return -1, false
|
||||||
|
}
|
||||||
|
return i, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func findToolCallsArrayStart(text string, keyIdx int) (int, bool) {
|
||||||
|
i := keyIdx + len("tool_calls")
|
||||||
|
for i < len(text) && text[i] != ':' {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
if i >= len(text) {
|
||||||
|
return -1, false
|
||||||
|
}
|
||||||
|
i = skipSpaces(text, i+1)
|
||||||
|
if i >= len(text) || text[i] != '[' {
|
||||||
|
return -1, false
|
||||||
|
}
|
||||||
|
return i, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractToolCallName(text string, callStart int) (string, bool) {
|
||||||
|
valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"name"})
|
||||||
|
if !ok || valueStart >= len(text) || text[valueStart] != '"' {
|
||||||
|
fnStart, fnOK := findFunctionObjectStart(text, callStart)
|
||||||
|
if !fnOK {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
valueStart, ok = findObjectFieldValueStart(text, fnStart, []string{"name"})
|
||||||
|
if !ok || valueStart >= len(text) || text[valueStart] != '"' {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
name, _, ok := parseJSONStringLiteral(text, valueStart)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return name, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func findToolCallArgsStart(text string, callStart int) (int, bool, bool) {
|
||||||
|
keys := []string{"input", "arguments", "args", "parameters", "params"}
|
||||||
|
valueStart, ok := findObjectFieldValueStart(text, callStart, keys)
|
||||||
|
if !ok {
|
||||||
|
fnStart, fnOK := findFunctionObjectStart(text, callStart)
|
||||||
|
if !fnOK {
|
||||||
|
return -1, false, false
|
||||||
|
}
|
||||||
|
valueStart, ok = findObjectFieldValueStart(text, fnStart, keys)
|
||||||
|
if !ok {
|
||||||
|
return -1, false, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if valueStart >= len(text) {
|
||||||
|
return -1, false, false
|
||||||
|
}
|
||||||
|
ch := text[valueStart]
|
||||||
|
if ch == '{' || ch == '[' {
|
||||||
|
return valueStart, false, true
|
||||||
|
}
|
||||||
|
if ch == '"' {
|
||||||
|
return valueStart, true, true
|
||||||
|
}
|
||||||
|
return -1, false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanToolCallArgsProgress advances through an in-flight arguments value
// beginning at start. In stringMode it scans for the closing unescaped
// double quote; otherwise it balances {}/[] nesting while skipping string
// contents. It returns the scan position, whether the value is complete,
// and whether the input was scannable at all.
func scanToolCallArgsProgress(text string, start int, stringMode bool) (int, bool, bool) {
	if start < 0 || start > len(text) {
		return 0, false, false
	}
	if stringMode {
		escaped := false
		for pos := start; pos < len(text); pos++ {
			switch {
			case escaped:
				escaped = false
			case text[pos] == '\\':
				escaped = true
			case text[pos] == '"':
				return pos, true, true
			}
		}
		// Closing quote not seen yet; everything so far is consumable.
		return len(text), false, true
	}
	if start >= len(text) {
		return start, false, false
	}
	if c := text[start]; c != '{' && c != '[' {
		return 0, false, false
	}
	var (
		depth   int
		quote   byte
		escaped bool
	)
	for pos := start; pos < len(text); pos++ {
		c := text[pos]
		if quote != 0 {
			switch {
			case escaped:
				escaped = false
			case c == '\\':
				escaped = true
			case c == quote:
				quote = 0
			}
			continue
		}
		switch c {
		case '"', '\'':
			quote = c
		case '{', '[':
			depth++
		case '}', ']':
			depth--
			if depth == 0 {
				return pos + 1, true, true
			}
		}
	}
	// Container still open: report partial progress.
	return len(text), false, true
}
|
||||||
|
|
||||||
|
func findFunctionObjectStart(text string, callStart int) (int, bool) {
|
||||||
|
valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"function"})
|
||||||
|
if !ok || valueStart >= len(text) || text[valueStart] != '{' {
|
||||||
|
return -1, false
|
||||||
|
}
|
||||||
|
return valueStart, true
|
||||||
|
}
|
||||||
152
internal/adapter/openai/tool_sieve_jsonscan.go
Normal file
152
internal/adapter/openai/tool_sieve_jsonscan.go
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// extractJSONObjectFrom returns the balanced {...} object that begins at
// start, the index just past its closing brace, and true on success.
// Braces inside single- or double-quoted strings do not affect nesting.
func extractJSONObjectFrom(text string, start int) (string, int, bool) {
	if start < 0 || start >= len(text) || text[start] != '{' {
		return "", 0, false
	}
	var (
		depth   int
		quote   byte
		escaped bool
	)
	for pos := start; pos < len(text); pos++ {
		c := text[pos]
		if quote != 0 {
			switch {
			case escaped:
				escaped = false
			case c == '\\':
				escaped = true
			case c == quote:
				quote = 0
			}
			continue
		}
		switch c {
		case '"', '\'':
			quote = c
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				return text[start : pos+1], pos + 1, true
			}
		}
	}
	// Object never closed within the buffer.
	return "", 0, false
}
|
||||||
|
|
||||||
|
func findObjectFieldValueStart(text string, objStart int, keys []string) (int, bool) {
|
||||||
|
if objStart < 0 || objStart >= len(text) || text[objStart] != '{' {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
depth := 0
|
||||||
|
quote := byte(0)
|
||||||
|
escaped := false
|
||||||
|
for i := objStart; i < len(text); i++ {
|
||||||
|
ch := text[i]
|
||||||
|
if quote != 0 {
|
||||||
|
if escaped {
|
||||||
|
escaped = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ch == '\\' {
|
||||||
|
escaped = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ch == quote {
|
||||||
|
quote = 0
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ch == '"' || ch == '\'' {
|
||||||
|
if depth == 1 {
|
||||||
|
key, end, ok := parseJSONStringLiteral(text, i)
|
||||||
|
if !ok {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
j := skipSpaces(text, end)
|
||||||
|
if j >= len(text) || text[j] != ':' {
|
||||||
|
i = end - 1
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
j = skipSpaces(text, j+1)
|
||||||
|
if j >= len(text) {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
if containsKey(keys, key) {
|
||||||
|
return j, true
|
||||||
|
}
|
||||||
|
i = j - 1
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
quote = ch
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ch == '{' {
|
||||||
|
depth++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ch == '}' {
|
||||||
|
depth--
|
||||||
|
if depth == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseJSONStringLiteral reads the double-quoted string starting at start
// and returns its contents plus the index just past the closing quote.
// Escapes are handled naively: the byte after a backslash is emitted
// verbatim (so \" yields ", but \n yields the letter 'n' and \uXXXX
// sequences are not decoded).
func parseJSONStringLiteral(text string, start int) (string, int, bool) {
	if start < 0 || start >= len(text) || text[start] != '"' {
		return "", 0, false
	}
	var out strings.Builder
	escaped := false
	for pos := start + 1; pos < len(text); pos++ {
		c := text[pos]
		switch {
		case escaped:
			out.WriteByte(c)
			escaped = false
		case c == '\\':
			escaped = true
		case c == '"':
			return out.String(), pos + 1, true
		default:
			out.WriteByte(c)
		}
	}
	// No closing quote within the buffer.
	return "", 0, false
}
|
||||||
|
|
||||||
|
// containsKey reports whether value appears in keys.
func containsKey(keys []string, value string) bool {
	for i := range keys {
		if keys[i] == value {
			return true
		}
	}
	return false
}
|
||||||
|
|
||||||
|
// skipSpaces returns the index of the first byte at or after i that is not
// an ASCII space, tab, newline, or carriage return (len(text) if none).
func skipSpaces(text string, i int) int {
	for i < len(text) {
		c := text[i]
		if c != ' ' && c != '\t' && c != '\n' && c != '\r' {
			break
		}
		i++
	}
	return i
}
|
||||||
75
internal/adapter/openai/tool_sieve_state.go
Normal file
75
internal/adapter/openai/tool_sieve_state.go
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"ds2api/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// toolStreamSieveState tracks the sieve that separates plain streamed text
// from a suspected embedded tool_calls payload while a completion streams.
type toolStreamSieveState struct {
	pending        strings.Builder // NOTE(review): appears to buffer undecided text — confirm against sieve logic
	capture        strings.Builder // raw bytes of a suspected tool_calls payload
	capturing      bool
	recentTextTail string // rolling tail of emitted text (fed by noteText), capped at toolSieveContextTailLimit
	disableDeltas  bool   // when true, incremental tool-call deltas are suppressed
	toolNameSent   bool   // first call's name has already been emitted as a delta
	toolName       string
	toolArgsStart  int  // capture offset where the arguments value starts (-1 after reset)
	toolArgsSent   int  // capture offset up to which arguments have been emitted (-1 after reset)
	toolArgsString bool // arguments value is a JSON string rather than an object/array
	toolArgsDone   bool // arguments value has been fully consumed
}

// toolStreamEvent is one unit of sieve output: pass-through text plus any
// fully parsed tool calls and/or incremental tool-call deltas.
type toolStreamEvent struct {
	Content        string
	ToolCalls      []util.ParsedToolCall
	ToolCallDeltas []toolCallDelta
}

// toolCallDelta is a single incremental update for a streamed tool call.
type toolCallDelta struct {
	Index     int    // tool call index within the response
	Name      string // tool name, when carried by this delta
	Arguments string // next chunk of the serialized arguments
}

// toolSieveCaptureLimit caps how many bytes of a suspected tool payload are buffered.
const toolSieveCaptureLimit = 8 * 1024

// toolSieveContextTailLimit caps the retained tail of recently emitted text.
const toolSieveContextTailLimit = 256
|
||||||
|
|
||||||
|
// resetIncrementalToolState clears all per-call incremental delta tracking
// so a new capture starts from a clean slate; offsets reset to -1, meaning
// "not located yet".
func (s *toolStreamSieveState) resetIncrementalToolState() {
	s.disableDeltas = false
	s.toolNameSent = false
	s.toolName = ""
	s.toolArgsStart = -1
	s.toolArgsSent = -1
	s.toolArgsString = false
	s.toolArgsDone = false
}
|
||||||
|
|
||||||
|
// noteText folds non-blank streamed text into the rolling context tail;
// whitespace-only chunks are ignored so the tail reflects meaningful output.
func (s *toolStreamSieveState) noteText(content string) {
	if strings.TrimSpace(content) == "" {
		return
	}
	s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit)
}
|
||||||
|
|
||||||
|
// appendTail concatenates prev and next and keeps only the last max bytes.
// A non-positive max yields the empty string.
func appendTail(prev, next string, max int) string {
	if max <= 0 {
		return ""
	}
	joined := prev + next
	if overflow := len(joined) - max; overflow > 0 {
		return joined[overflow:]
	}
	return joined
}
|
||||||
|
|
||||||
|
// looksLikeToolExampleContext reports whether the surrounding text suggests
// the model is writing a tool-call example rather than a real call;
// currently this means the text ends inside an open markdown code fence.
func looksLikeToolExampleContext(text string) bool {
	return insideCodeFence(text)
}
|
||||||
|
|
||||||
|
// insideCodeFence reports whether text ends inside an unclosed markdown
// code fence, i.e. it contains an odd number of ``` markers.
func insideCodeFence(text string) bool {
	if len(text) == 0 {
		return false
	}
	return strings.Count(text, "```")%2 != 0
}
|
||||||
21
internal/adapter/openai/trace.go
Normal file
21
internal/adapter/openai/trace.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func requestTraceID(r *http.Request) string {
|
||||||
|
if r == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if q := strings.TrimSpace(r.URL.Query().Get("__trace_id")); q != "" {
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
if h := strings.TrimSpace(r.Header.Get("X-Ds2-Test-Trace")); h != "" {
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(middleware.GetReqID(r.Context()))
|
||||||
|
}
|
||||||
47
internal/adapter/openai/trace_test.go
Normal file
47
internal/adapter/openai/trace_test.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func traceIDViaMiddleware(req *http.Request) string {
|
||||||
|
if req == nil {
|
||||||
|
return requestTraceID(nil)
|
||||||
|
}
|
||||||
|
var got string
|
||||||
|
h := middleware.RequestID(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||||
|
got = requestTraceID(r)
|
||||||
|
}))
|
||||||
|
h.ServeHTTP(httptest.NewRecorder(), req)
|
||||||
|
return got
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestTraceIDPriority(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions?__trace_id=query-trace", nil)
|
||||||
|
req.Header.Set("X-Ds2-Test-Trace", "header-trace")
|
||||||
|
got := traceIDViaMiddleware(req)
|
||||||
|
if got != "query-trace" {
|
||||||
|
t.Fatalf("expected query trace id to win, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestTraceIDHeaderFallback(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil)
|
||||||
|
req.Header.Set("X-Ds2-Test-Trace", "header-trace")
|
||||||
|
got := traceIDViaMiddleware(req)
|
||||||
|
if got != "header-trace" {
|
||||||
|
t.Fatalf("expected header trace id to win when query missing, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestTraceIDReqIDFallback(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil)
|
||||||
|
got := traceIDViaMiddleware(req)
|
||||||
|
if got == "" {
|
||||||
|
t.Fatal("expected middleware request id fallback to be non-empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -56,24 +56,16 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
|
|||||||
writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
|
writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
model, _ := req["model"].(string)
|
stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
|
||||||
messagesRaw, _ := req["messages"].([]any)
|
if err != nil {
|
||||||
if model == "" || len(messagesRaw) == 0 {
|
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||||
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model)
|
if !stdReq.Stream {
|
||||||
if !ok {
|
writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
|
||||||
writeOpenAIError(w, http.StatusServiceUnavailable, fmt.Sprintf("Model '%s' is not available.", model))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
messages := normalizeMessages(messagesRaw)
|
|
||||||
if tools, ok := req["tools"].([]any); ok && len(tools) > 0 {
|
|
||||||
messages, _ = injectToolPrompt(messages, tools)
|
|
||||||
}
|
|
||||||
finalPrompt := util.MessagesPrepare(messages)
|
|
||||||
|
|
||||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if a.UseConfigToken {
|
if a.UseConfigToken {
|
||||||
@@ -93,14 +85,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
payload := map[string]any{
|
payload := stdReq.CompletionPayload(sessionID)
|
||||||
"chat_session_id": sessionID,
|
|
||||||
"parent_message_id": nil,
|
|
||||||
"prompt": finalPrompt,
|
|
||||||
"ref_file_ids": []any{},
|
|
||||||
"thinking_enabled": thinkingEnabled,
|
|
||||||
"search_enabled": searchEnabled,
|
|
||||||
}
|
|
||||||
leaseID := h.holdStreamLease(a)
|
leaseID := h.holdStreamLease(a)
|
||||||
if leaseID == "" {
|
if leaseID == "" {
|
||||||
writeOpenAIError(w, http.StatusInternalServerError, "failed to create stream lease")
|
writeOpenAIError(w, http.StatusInternalServerError, "failed to create stream lease")
|
||||||
@@ -108,15 +93,18 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
|
|||||||
}
|
}
|
||||||
leased = true
|
leased = true
|
||||||
writeJSON(w, http.StatusOK, map[string]any{
|
writeJSON(w, http.StatusOK, map[string]any{
|
||||||
"session_id": sessionID,
|
"session_id": sessionID,
|
||||||
"lease_id": leaseID,
|
"lease_id": leaseID,
|
||||||
"model": model,
|
"model": stdReq.ResponseModel,
|
||||||
"final_prompt": finalPrompt,
|
"final_prompt": stdReq.FinalPrompt,
|
||||||
"thinking_enabled": thinkingEnabled,
|
"thinking_enabled": stdReq.Thinking,
|
||||||
"search_enabled": searchEnabled,
|
"search_enabled": stdReq.Search,
|
||||||
"deepseek_token": a.DeepSeekToken,
|
"tool_names": stdReq.ToolNames,
|
||||||
"pow_header": powHeader,
|
"toolcall_feature_match": h.toolcallFeatureMatchEnabled(),
|
||||||
"payload": payload,
|
"toolcall_early_emit_high": h.toolcallEarlyEmitHighConfidence(),
|
||||||
|
"deepseek_token": a.DeepSeekToken,
|
||||||
|
"pow_header": powHeader,
|
||||||
|
"payload": payload,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
46
internal/admin/deps.go
Normal file
46
internal/admin/deps.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"ds2api/internal/account"
|
||||||
|
"ds2api/internal/auth"
|
||||||
|
"ds2api/internal/config"
|
||||||
|
"ds2api/internal/deepseek"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConfigStore is the configuration surface the admin handlers depend on.
// It is defined at the consumer so handlers can be tested without the
// concrete *config.Store (which satisfies it; see the assertions below).
type ConfigStore interface {
	Snapshot() config.Config
	Keys() []string
	Accounts() []config.Account
	FindAccount(identifier string) (config.Account, bool)
	UpdateAccountToken(identifier, token string) error
	Update(mutator func(*config.Config) error) error
	ExportJSONAndBase64() (string, string, error)
	IsEnvBacked() bool
	SetVercelSync(hash string, ts int64) error
	AdminPasswordHash() string
	AdminJWTExpireHours() int
	AdminJWTValidAfterUnix() int64
	RuntimeAccountMaxInflight() int
	RuntimeAccountMaxQueue(defaultSize int) int
	RuntimeGlobalMaxInflight(defaultSize int) int
}

// PoolController is the account-pool surface used by the admin handlers.
type PoolController interface {
	Reset()
	Status() map[string]any
	ApplyRuntimeLimits(maxInflightPerAccount, maxQueueSize, globalMaxInflight int)
}

// DeepSeekCaller is the upstream DeepSeek client surface used by the admin
// handlers (login, session/PoW acquisition, and completion calls).
type DeepSeekCaller interface {
	Login(ctx context.Context, acc config.Account) (string, error)
	CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
	GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
	CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
}

// Compile-time checks that the concrete implementations satisfy the
// consumer-side interfaces above.
var _ ConfigStore = (*config.Store)(nil)
var _ PoolController = (*account.Pool)(nil)
var _ DeepSeekCaller = (*deepseek.Client)(nil)
|
||||||
@@ -2,16 +2,12 @@ package admin
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
|
|
||||||
"ds2api/internal/account"
|
|
||||||
"ds2api/internal/config"
|
|
||||||
"ds2api/internal/deepseek"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
Store *config.Store
|
Store ConfigStore
|
||||||
Pool *account.Pool
|
Pool PoolController
|
||||||
DS *deepseek.Client
|
DS DeepSeekCaller
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||||
@@ -22,6 +18,11 @@ func RegisterRoutes(r chi.Router, h *Handler) {
|
|||||||
pr.Get("/vercel/config", h.getVercelConfig)
|
pr.Get("/vercel/config", h.getVercelConfig)
|
||||||
pr.Get("/config", h.getConfig)
|
pr.Get("/config", h.getConfig)
|
||||||
pr.Post("/config", h.updateConfig)
|
pr.Post("/config", h.updateConfig)
|
||||||
|
pr.Get("/settings", h.getSettings)
|
||||||
|
pr.Put("/settings", h.updateSettings)
|
||||||
|
pr.Post("/settings/password", h.updateSettingsPassword)
|
||||||
|
pr.Post("/config/import", h.configImport)
|
||||||
|
pr.Get("/config/export", h.configExport)
|
||||||
pr.Post("/keys", h.addKey)
|
pr.Post("/keys", h.addKey)
|
||||||
pr.Delete("/keys/{key}", h.deleteKey)
|
pr.Delete("/keys/{key}", h.deleteKey)
|
||||||
pr.Get("/accounts", h.listAccounts)
|
pr.Get("/accounts", h.listAccounts)
|
||||||
@@ -35,5 +36,7 @@ func RegisterRoutes(r chi.Router, h *Handler) {
|
|||||||
pr.Post("/vercel/sync", h.syncVercel)
|
pr.Post("/vercel/sync", h.syncVercel)
|
||||||
pr.Get("/vercel/status", h.vercelStatus)
|
pr.Get("/vercel/status", h.vercelStatus)
|
||||||
pr.Get("/export", h.exportConfig)
|
pr.Get("/export", h.exportConfig)
|
||||||
|
pr.Get("/dev/captures", h.getDevCaptures)
|
||||||
|
pr.Delete("/dev/captures", h.clearDevCaptures)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
114
internal/admin/handler_accounts_crud.go
Normal file
114
internal/admin/handler_accounts_crud.go
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
|
||||||
|
"ds2api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
|
||||||
|
page := intFromQuery(r, "page", 1)
|
||||||
|
pageSize := intFromQuery(r, "page_size", 10)
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if pageSize < 1 {
|
||||||
|
pageSize = 1
|
||||||
|
}
|
||||||
|
if pageSize > 100 {
|
||||||
|
pageSize = 100
|
||||||
|
}
|
||||||
|
accounts := h.Store.Snapshot().Accounts
|
||||||
|
total := len(accounts)
|
||||||
|
reverseAccounts(accounts)
|
||||||
|
totalPages := 1
|
||||||
|
if total > 0 {
|
||||||
|
totalPages = (total + pageSize - 1) / pageSize
|
||||||
|
}
|
||||||
|
start := (page - 1) * pageSize
|
||||||
|
if start > total {
|
||||||
|
start = total
|
||||||
|
}
|
||||||
|
end := start + pageSize
|
||||||
|
if end > total {
|
||||||
|
end = total
|
||||||
|
}
|
||||||
|
items := make([]map[string]any, 0, end-start)
|
||||||
|
for _, acc := range accounts[start:end] {
|
||||||
|
token := strings.TrimSpace(acc.Token)
|
||||||
|
preview := ""
|
||||||
|
if token != "" {
|
||||||
|
if len(token) > 20 {
|
||||||
|
preview = token[:20] + "..."
|
||||||
|
} else {
|
||||||
|
preview = token
|
||||||
|
}
|
||||||
|
}
|
||||||
|
items = append(items, map[string]any{
|
||||||
|
"identifier": acc.Identifier(),
|
||||||
|
"email": acc.Email,
|
||||||
|
"mobile": acc.Mobile,
|
||||||
|
"has_password": acc.Password != "",
|
||||||
|
"has_token": token != "",
|
||||||
|
"token_preview": preview,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req map[string]any
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
acc := toAccount(req)
|
||||||
|
if acc.Identifier() == "" {
|
||||||
|
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := h.Store.Update(func(c *config.Config) error {
|
||||||
|
for _, a := range c.Accounts {
|
||||||
|
if acc.Email != "" && a.Email == acc.Email {
|
||||||
|
return fmt.Errorf("邮箱已存在")
|
||||||
|
}
|
||||||
|
if acc.Mobile != "" && a.Mobile == acc.Mobile {
|
||||||
|
return fmt.Errorf("手机号已存在")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Accounts = append(c.Accounts, acc)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.Pool.Reset()
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
|
identifier := chi.URLParam(r, "identifier")
|
||||||
|
err := h.Store.Update(func(c *config.Config) error {
|
||||||
|
idx := -1
|
||||||
|
for i, a := range c.Accounts {
|
||||||
|
if accountMatchesIdentifier(a, identifier) {
|
||||||
|
idx = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if idx < 0 {
|
||||||
|
return fmt.Errorf("账号不存在")
|
||||||
|
}
|
||||||
|
c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.Pool.Reset()
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
|
||||||
|
}
|
||||||
138
internal/admin/handler_accounts_identifier_test.go
Normal file
138
internal/admin/handler_accounts_identifier_test.go
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
|
||||||
|
"ds2api/internal/account"
|
||||||
|
"ds2api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newAdminTestHandler(t *testing.T, raw string) *Handler {
|
||||||
|
t.Helper()
|
||||||
|
t.Setenv("DS2API_CONFIG_JSON", raw)
|
||||||
|
t.Setenv("CONFIG_JSON", "")
|
||||||
|
store := config.LoadStore()
|
||||||
|
return &Handler{
|
||||||
|
Store: store,
|
||||||
|
Pool: account.NewPool(store),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListAccountsIncludesTokenOnlyIdentifier(t *testing.T) {
|
||||||
|
h := newAdminTestHandler(t, `{
|
||||||
|
"accounts":[{"token":"token-only-account"}]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/accounts?page=1&page_size=10", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
h.listAccounts(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||||
|
t.Fatalf("decode response failed: %v", err)
|
||||||
|
}
|
||||||
|
items, _ := payload["items"].([]any)
|
||||||
|
if len(items) != 1 {
|
||||||
|
t.Fatalf("expected 1 item, got %d", len(items))
|
||||||
|
}
|
||||||
|
first, _ := items[0].(map[string]any)
|
||||||
|
identifier, _ := first["identifier"].(string)
|
||||||
|
if identifier == "" {
|
||||||
|
t.Fatalf("expected non-empty identifier: %#v", first)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(identifier, "token:") {
|
||||||
|
t.Fatalf("expected token synthetic identifier, got %q", identifier)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteAccountSupportsTokenOnlyIdentifier(t *testing.T) {
|
||||||
|
h := newAdminTestHandler(t, `{
|
||||||
|
"accounts":[{"token":"token-only-account"}]
|
||||||
|
}`)
|
||||||
|
accounts := h.Store.Accounts()
|
||||||
|
if len(accounts) != 1 {
|
||||||
|
t.Fatalf("expected 1 account, got %d", len(accounts))
|
||||||
|
}
|
||||||
|
id := accounts[0].Identifier()
|
||||||
|
if id == "" {
|
||||||
|
t.Fatal("expected token-only synthetic identifier")
|
||||||
|
}
|
||||||
|
|
||||||
|
r := chi.NewRouter()
|
||||||
|
r.Delete("/admin/accounts/{identifier}", h.deleteAccount)
|
||||||
|
req := httptest.NewRequest(http.MethodDelete, "/admin/accounts/"+url.PathEscape(id), nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := len(h.Store.Accounts()); got != 0 {
|
||||||
|
t.Fatalf("expected account removed, remaining=%d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteAccountSupportsMobileAlias(t *testing.T) {
|
||||||
|
h := newAdminTestHandler(t, `{
|
||||||
|
"accounts":[{"email":"u@example.com","mobile":"13800138000","password":"pwd"}]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
r := chi.NewRouter()
|
||||||
|
r.Delete("/admin/accounts/{identifier}", h.deleteAccount)
|
||||||
|
req := httptest.NewRequest(http.MethodDelete, "/admin/accounts/13800138000", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
if got := len(h.Store.Accounts()); got != 0 {
|
||||||
|
t.Fatalf("expected account removed, remaining=%d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) {
|
||||||
|
h := newAdminTestHandler(t, `{
|
||||||
|
"accounts":[
|
||||||
|
{"email":"u@example.com","mobile":"13800138000","password":"pwd"},
|
||||||
|
{"token":"token-only-account"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
accByMobile, ok := findAccountByIdentifier(h.Store, "13800138000")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected find by mobile")
|
||||||
|
}
|
||||||
|
if accByMobile.Email != "u@example.com" {
|
||||||
|
t.Fatalf("unexpected account by mobile: %#v", accByMobile)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenOnlyID := ""
|
||||||
|
for _, acc := range h.Store.Accounts() {
|
||||||
|
if strings.TrimSpace(acc.Email) == "" && strings.TrimSpace(acc.Mobile) == "" {
|
||||||
|
tokenOnlyID = acc.Identifier()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tokenOnlyID == "" {
|
||||||
|
t.Fatal("expected token-only account identifier")
|
||||||
|
}
|
||||||
|
accByTokenOnly, ok := findAccountByIdentifier(h.Store, tokenOnlyID)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected find by token-only id=%q", tokenOnlyID)
|
||||||
|
}
|
||||||
|
if accByTokenOnly.Token != "token-only-account" {
|
||||||
|
t.Fatalf("unexpected token-only account: %#v", accByTokenOnly)
|
||||||
|
}
|
||||||
|
}
|
||||||
7
internal/admin/handler_accounts_queue.go
Normal file
7
internal/admin/handler_accounts_queue.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// queueStatus writes the current account-pool status snapshot as JSON.
func (h *Handler) queueStatus(w http.ResponseWriter, _ *http.Request) {
	writeJSON(w, http.StatusOK, h.Pool.Status())
}
|
||||||
@@ -11,121 +11,20 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
|
|
||||||
authn "ds2api/internal/auth"
|
authn "ds2api/internal/auth"
|
||||||
"ds2api/internal/config"
|
"ds2api/internal/config"
|
||||||
"ds2api/internal/sse"
|
"ds2api/internal/sse"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
|
|
||||||
page := intFromQuery(r, "page", 1)
|
|
||||||
pageSize := intFromQuery(r, "page_size", 10)
|
|
||||||
if page < 1 {
|
|
||||||
page = 1
|
|
||||||
}
|
|
||||||
if pageSize < 1 {
|
|
||||||
pageSize = 1
|
|
||||||
}
|
|
||||||
if pageSize > 100 {
|
|
||||||
pageSize = 100
|
|
||||||
}
|
|
||||||
accounts := h.Store.Snapshot().Accounts
|
|
||||||
total := len(accounts)
|
|
||||||
reverseAccounts(accounts)
|
|
||||||
totalPages := 1
|
|
||||||
if total > 0 {
|
|
||||||
totalPages = (total + pageSize - 1) / pageSize
|
|
||||||
}
|
|
||||||
start := (page - 1) * pageSize
|
|
||||||
if start > total {
|
|
||||||
start = total
|
|
||||||
}
|
|
||||||
end := start + pageSize
|
|
||||||
if end > total {
|
|
||||||
end = total
|
|
||||||
}
|
|
||||||
items := make([]map[string]any, 0, end-start)
|
|
||||||
for _, acc := range accounts[start:end] {
|
|
||||||
token := strings.TrimSpace(acc.Token)
|
|
||||||
preview := ""
|
|
||||||
if token != "" {
|
|
||||||
if len(token) > 20 {
|
|
||||||
preview = token[:20] + "..."
|
|
||||||
} else {
|
|
||||||
preview = token
|
|
||||||
}
|
|
||||||
}
|
|
||||||
items = append(items, map[string]any{"email": acc.Email, "mobile": acc.Mobile, "has_password": acc.Password != "", "has_token": token != "", "token_preview": preview})
|
|
||||||
}
|
|
||||||
writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) {
|
|
||||||
var req map[string]any
|
|
||||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
|
||||||
acc := toAccount(req)
|
|
||||||
if acc.Identifier() == "" {
|
|
||||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err := h.Store.Update(func(c *config.Config) error {
|
|
||||||
for _, a := range c.Accounts {
|
|
||||||
if acc.Email != "" && a.Email == acc.Email {
|
|
||||||
return fmt.Errorf("邮箱已存在")
|
|
||||||
}
|
|
||||||
if acc.Mobile != "" && a.Mobile == acc.Mobile {
|
|
||||||
return fmt.Errorf("手机号已存在")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.Accounts = append(c.Accounts, acc)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
h.Pool.Reset()
|
|
||||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
|
||||||
identifier := chi.URLParam(r, "identifier")
|
|
||||||
err := h.Store.Update(func(c *config.Config) error {
|
|
||||||
idx := -1
|
|
||||||
for i, a := range c.Accounts {
|
|
||||||
if a.Email == identifier || a.Mobile == identifier {
|
|
||||||
idx = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if idx < 0 {
|
|
||||||
return fmt.Errorf("账号不存在")
|
|
||||||
}
|
|
||||||
c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
h.Pool.Reset()
|
|
||||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) queueStatus(w http.ResponseWriter, _ *http.Request) {
|
|
||||||
writeJSON(w, http.StatusOK, h.Pool.Status())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) testSingleAccount(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) testSingleAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
var req map[string]any
|
var req map[string]any
|
||||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
identifier, _ := req["identifier"].(string)
|
identifier, _ := req["identifier"].(string)
|
||||||
if strings.TrimSpace(identifier) == "" {
|
if strings.TrimSpace(identifier) == "" {
|
||||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(email 或 mobile)"})
|
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(identifier / email / mobile)"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
acc, ok := h.Store.FindAccount(identifier)
|
acc, ok := findAccountByIdentifier(h.Store, identifier)
|
||||||
if !ok {
|
if !ok {
|
||||||
writeJSON(w, http.StatusNotFound, map[string]any{"detail": "账号不存在"})
|
writeJSON(w, http.StatusNotFound, map[string]any{"detail": "账号不存在"})
|
||||||
return
|
return
|
||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
|
|
||||||
func (h *Handler) requireAdmin(next http.Handler) http.Handler {
|
func (h *Handler) requireAdmin(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := authn.VerifyAdminRequest(r); err != nil {
|
if err := authn.VerifyAdminRequestWithStore(r, h.Store); err != nil {
|
||||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": err.Error()})
|
writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -25,18 +25,18 @@ func (h *Handler) login(w http.ResponseWriter, r *http.Request) {
|
|||||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
adminKey, _ := req["admin_key"].(string)
|
adminKey, _ := req["admin_key"].(string)
|
||||||
expireHours := intFrom(req["expire_hours"])
|
expireHours := intFrom(req["expire_hours"])
|
||||||
if expireHours <= 0 {
|
if !authn.VerifyAdminCredential(adminKey, h.Store) {
|
||||||
expireHours = 24
|
|
||||||
}
|
|
||||||
if adminKey != authn.AdminKey() {
|
|
||||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": "Invalid admin key"})
|
writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": "Invalid admin key"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
token, err := authn.CreateJWT(expireHours)
|
token, err := authn.CreateJWTWithStore(expireHours, h.Store)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if expireHours <= 0 {
|
||||||
|
expireHours = h.Store.AdminJWTExpireHours()
|
||||||
|
}
|
||||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "token": token, "expires_in": expireHours * 3600})
|
writeJSON(w, http.StatusOK, map[string]any{"success": true, "token": token, "expires_in": expireHours * 3600})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ func (h *Handler) verify(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
token := strings.TrimSpace(header[7:])
|
token := strings.TrimSpace(header[7:])
|
||||||
payload, err := authn.VerifyJWT(token)
|
payload, err := authn.VerifyJWTWithStore(token, h.Store)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": err.Error()})
|
writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": err.Error()})
|
||||||
return
|
return
|
||||||
|
|||||||
182
internal/admin/handler_config_import.go
Normal file
182
internal/admin/handler_config_import.go
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/md5"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"ds2api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req map[string]any
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mode := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("mode")))
|
||||||
|
if mode == "" {
|
||||||
|
mode = strings.TrimSpace(strings.ToLower(fieldString(req, "mode")))
|
||||||
|
}
|
||||||
|
if mode == "" {
|
||||||
|
mode = "merge"
|
||||||
|
}
|
||||||
|
if mode != "merge" && mode != "replace" {
|
||||||
|
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "mode must be merge or replace"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := req
|
||||||
|
if raw, ok := req["config"].(map[string]any); ok && len(raw) > 0 {
|
||||||
|
payload = raw
|
||||||
|
}
|
||||||
|
rawJSON, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid config payload"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var incoming config.Config
|
||||||
|
if err := json.Unmarshal(rawJSON, &incoming); err != nil {
|
||||||
|
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
importedKeys, importedAccounts := 0, 0
|
||||||
|
err = h.Store.Update(func(c *config.Config) error {
|
||||||
|
next := c.Clone()
|
||||||
|
if mode == "replace" {
|
||||||
|
next = incoming.Clone()
|
||||||
|
next.VercelSyncHash = c.VercelSyncHash
|
||||||
|
next.VercelSyncTime = c.VercelSyncTime
|
||||||
|
importedKeys = len(next.Keys)
|
||||||
|
importedAccounts = len(next.Accounts)
|
||||||
|
} else {
|
||||||
|
existingKeys := map[string]struct{}{}
|
||||||
|
for _, k := range next.Keys {
|
||||||
|
existingKeys[k] = struct{}{}
|
||||||
|
}
|
||||||
|
for _, k := range incoming.Keys {
|
||||||
|
key := strings.TrimSpace(k)
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := existingKeys[key]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
existingKeys[key] = struct{}{}
|
||||||
|
next.Keys = append(next.Keys, key)
|
||||||
|
importedKeys++
|
||||||
|
}
|
||||||
|
|
||||||
|
existingAccounts := map[string]struct{}{}
|
||||||
|
for _, acc := range next.Accounts {
|
||||||
|
existingAccounts[acc.Identifier()] = struct{}{}
|
||||||
|
}
|
||||||
|
for _, acc := range incoming.Accounts {
|
||||||
|
id := acc.Identifier()
|
||||||
|
if id == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := existingAccounts[id]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
existingAccounts[id] = struct{}{}
|
||||||
|
next.Accounts = append(next.Accounts, acc)
|
||||||
|
importedAccounts++
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(incoming.ClaudeMapping) > 0 {
|
||||||
|
if next.ClaudeMapping == nil {
|
||||||
|
next.ClaudeMapping = map[string]string{}
|
||||||
|
}
|
||||||
|
for k, v := range incoming.ClaudeMapping {
|
||||||
|
next.ClaudeMapping[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(incoming.ClaudeModelMap) > 0 {
|
||||||
|
if next.ClaudeModelMap == nil {
|
||||||
|
next.ClaudeModelMap = map[string]string{}
|
||||||
|
}
|
||||||
|
for k, v := range incoming.ClaudeModelMap {
|
||||||
|
next.ClaudeModelMap[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(incoming.ModelAliases) > 0 {
|
||||||
|
if next.ModelAliases == nil {
|
||||||
|
next.ModelAliases = map[string]string{}
|
||||||
|
}
|
||||||
|
for k, v := range incoming.ModelAliases {
|
||||||
|
next.ModelAliases[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(incoming.Toolcall.Mode) != "" {
|
||||||
|
next.Toolcall.Mode = incoming.Toolcall.Mode
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(incoming.Toolcall.EarlyEmitConfidence) != "" {
|
||||||
|
next.Toolcall.EarlyEmitConfidence = incoming.Toolcall.EarlyEmitConfidence
|
||||||
|
}
|
||||||
|
if incoming.Responses.StoreTTLSeconds > 0 {
|
||||||
|
next.Responses.StoreTTLSeconds = incoming.Responses.StoreTTLSeconds
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(incoming.Embeddings.Provider) != "" {
|
||||||
|
next.Embeddings.Provider = incoming.Embeddings.Provider
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(incoming.Admin.PasswordHash) != "" {
|
||||||
|
next.Admin.PasswordHash = incoming.Admin.PasswordHash
|
||||||
|
}
|
||||||
|
if incoming.Admin.JWTExpireHours > 0 {
|
||||||
|
next.Admin.JWTExpireHours = incoming.Admin.JWTExpireHours
|
||||||
|
}
|
||||||
|
if incoming.Admin.JWTValidAfterUnix > 0 {
|
||||||
|
next.Admin.JWTValidAfterUnix = incoming.Admin.JWTValidAfterUnix
|
||||||
|
}
|
||||||
|
if incoming.Runtime.AccountMaxInflight > 0 {
|
||||||
|
next.Runtime.AccountMaxInflight = incoming.Runtime.AccountMaxInflight
|
||||||
|
}
|
||||||
|
if incoming.Runtime.AccountMaxQueue > 0 {
|
||||||
|
next.Runtime.AccountMaxQueue = incoming.Runtime.AccountMaxQueue
|
||||||
|
}
|
||||||
|
if incoming.Runtime.GlobalMaxInflight > 0 {
|
||||||
|
next.Runtime.GlobalMaxInflight = incoming.Runtime.GlobalMaxInflight
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizeSettingsConfig(&next)
|
||||||
|
if err := validateSettingsConfig(next); err != nil {
|
||||||
|
return newRequestError(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
*c = next
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
if detail, ok := requestErrorDetail(err); ok {
|
||||||
|
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": detail})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.Pool.Reset()
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{
|
||||||
|
"success": true,
|
||||||
|
"mode": mode,
|
||||||
|
"imported_keys": importedKeys,
|
||||||
|
"imported_accounts": importedAccounts,
|
||||||
|
"message": "config imported",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) computeSyncHash() string {
|
||||||
|
snap := h.Store.Snapshot().Clone()
|
||||||
|
snap.VercelSyncHash = ""
|
||||||
|
snap.VercelSyncTime = 0
|
||||||
|
b, _ := json.Marshal(snap)
|
||||||
|
sum := md5.Sum(b)
|
||||||
|
return fmt.Sprintf("%x", sum)
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user