mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-03 16:05:26 +08:00
Compare commits
266 Commits
v2.0.0_Bet
...
v2.4.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
586d31e556 | ||
|
|
b0a09dfab0 | ||
|
|
58f753d0c0 | ||
|
|
2e0586d060 | ||
|
|
1676c8e4f2 | ||
|
|
add13366d2 | ||
|
|
d5a23191f2 | ||
|
|
d2d4e39983 | ||
|
|
6e0dca3b30 | ||
|
|
b108a7915a | ||
|
|
2caabd8ce6 | ||
|
|
c4a73e871a | ||
|
|
6802a3d53e | ||
|
|
e828006cb0 | ||
|
|
a6499cbece | ||
|
|
a504905626 | ||
|
|
59bf78d2c4 | ||
|
|
25b3292497 | ||
|
|
6cf4f0528c | ||
|
|
d8f8dcb704 | ||
|
|
455489ffeb | ||
|
|
5031ae0e6f | ||
|
|
3fccec0e22 | ||
|
|
00d38f1187 | ||
|
|
fe0f3d2c17 | ||
|
|
f67cbfad35 | ||
|
|
11f66db87d | ||
|
|
9afc533153 | ||
|
|
6a39543288 | ||
|
|
7131b06e26 | ||
|
|
8fa1f998aa | ||
|
|
f8936887d0 | ||
|
|
db89744055 | ||
|
|
65312fc573 | ||
|
|
661d753fd3 | ||
|
|
7ca3f141c6 | ||
|
|
d530d25793 | ||
|
|
990cdcf02d | ||
|
|
648bb74587 | ||
|
|
9e5baed061 | ||
|
|
4884773639 | ||
|
|
6758514c61 | ||
|
|
01f33c409f | ||
|
|
55f11e655a | ||
|
|
2275e931f9 | ||
|
|
40594a44db | ||
|
|
67787d9c99 | ||
|
|
7061094964 | ||
|
|
492c603300 | ||
|
|
7e473dffc9 | ||
|
|
43a6e6712f | ||
|
|
ce1b76c90f | ||
|
|
1e7e0b2ae3 | ||
|
|
fd158e5ae2 | ||
|
|
95c96f7744 | ||
|
|
e7f59fac80 | ||
|
|
1bf059396f | ||
|
|
696b403173 | ||
|
|
f4db2732b0 | ||
|
|
ee88a74dcf | ||
|
|
ca08bb66b9 | ||
|
|
708fcb5beb | ||
|
|
7a65d1eaa2 | ||
|
|
6de2457743 | ||
|
|
ce44e260bf | ||
|
|
09f6537ffc | ||
|
|
ab8f494fdb | ||
|
|
b56a211da9 | ||
|
|
fcce5308cb | ||
|
|
d27b19cc53 | ||
|
|
b8ff678f24 | ||
|
|
b24ef1282d | ||
|
|
65e0de3c82 | ||
|
|
0c2743a48c | ||
|
|
dc73e8a6da | ||
|
|
b8495eeeb3 | ||
|
|
b3eae22cef | ||
|
|
7af0098d1b | ||
|
|
17405be300 | ||
|
|
5bc03e5de6 | ||
|
|
5a5f93148d | ||
|
|
32dc5b6099 | ||
|
|
7936d4675f | ||
|
|
808eafa7c6 | ||
|
|
bcb8ed6df2 | ||
|
|
8ec5dcc0cc | ||
|
|
88a79f212d | ||
|
|
b1f8d6192f | ||
|
|
acfb3b225d | ||
|
|
99a6164000 | ||
|
|
e49d9d33e2 | ||
|
|
184a3d1e4e | ||
|
|
c4ec14f49a | ||
|
|
fb5fc0e885 | ||
|
|
20b603666d | ||
|
|
4d549b7102 | ||
|
|
33b0d1d144 | ||
|
|
41c0f7ce28 | ||
|
|
efb484ba4f | ||
|
|
145501d4a5 | ||
|
|
2d5103997b | ||
|
|
52e7e7aae8 | ||
|
|
5b5a4000d7 | ||
|
|
2bbf603148 | ||
|
|
d14b8a0664 | ||
|
|
f16e0b579e | ||
|
|
43cbc4aac0 | ||
|
|
cf569f4749 | ||
|
|
c9c59f2490 | ||
|
|
16216cc2ca | ||
|
|
de50fd3954 | ||
|
|
7648d5f192 | ||
|
|
d35e5eab25 | ||
|
|
90610a52ce | ||
|
|
f6296d506f | ||
|
|
dfea092583 | ||
|
|
af7dc134bb | ||
|
|
2657d37f76 | ||
|
|
7318d1f4a8 | ||
|
|
f2674487c7 | ||
|
|
71cdcb43e8 | ||
|
|
9c46c3a874 | ||
|
|
12d5f136d5 | ||
|
|
00c37d8d2f | ||
|
|
0f1985af4a | ||
|
|
fa8affe1b7 | ||
|
|
c59a0b7799 | ||
|
|
bd72b91f27 | ||
|
|
9240f85246 | ||
|
|
ea4bd1e483 | ||
|
|
9e0de62707 | ||
|
|
128de290db | ||
|
|
286d266723 | ||
|
|
8aad1005b2 | ||
|
|
11b2f24fc2 | ||
|
|
d1f08cbb89 | ||
|
|
60e9d707d4 | ||
|
|
9b93badb57 | ||
|
|
892213071a | ||
|
|
5484d6e59d | ||
|
|
0ce3fd22a7 | ||
|
|
25e40cc3a6 | ||
|
|
af68d21095 | ||
|
|
1fafd25e86 | ||
|
|
5f8f28a943 | ||
|
|
94cf1bfcc7 | ||
|
|
13562cf521 | ||
|
|
d27e700c4f | ||
|
|
d6bce5af93 | ||
|
|
75969e710d | ||
|
|
6c39c8e191 | ||
|
|
0e261ff0a0 | ||
|
|
fab326eca1 | ||
|
|
c033eceee7 | ||
|
|
a10e03ebe0 | ||
|
|
a6aa4a1839 | ||
|
|
1c749b6803 | ||
|
|
c329bf26b6 | ||
|
|
3ae5b57ebe | ||
|
|
0bf5d5440c | ||
|
|
d731a1fd4f | ||
|
|
93e9fb531d | ||
|
|
6daeb2553d | ||
|
|
321b8a89ee | ||
|
|
d84875e466 | ||
|
|
ea8c9a28a9 | ||
|
|
a302fb3c25 | ||
|
|
958bd124cc | ||
|
|
b89e154e43 | ||
|
|
01924f4a69 | ||
|
|
3725694bdf | ||
|
|
21b12f583a | ||
|
|
d97b86e0ee | ||
|
|
0869ea56cd | ||
|
|
4768440627 | ||
|
|
9f91da403f | ||
|
|
89e5ad24b9 | ||
|
|
3f106ac112 | ||
|
|
f6f6a651fd | ||
|
|
37b867c7ad | ||
|
|
25ea28a277 | ||
|
|
0ac49ab32b | ||
|
|
70c59eb71d | ||
|
|
f60a3ea501 | ||
|
|
3f09d60cdc | ||
|
|
d3b5493d2e | ||
|
|
255feb2e65 | ||
|
|
4b73315df0 | ||
|
|
a086e0cfa1 | ||
|
|
f3bc022a36 | ||
|
|
b7cb7ef0c1 | ||
|
|
267420a46a | ||
|
|
3c66ab958a | ||
|
|
cf2f79b6f4 | ||
|
|
ab6e817c8e | ||
|
|
9ae4630a3b | ||
|
|
d1b8537cfb | ||
|
|
d32b4481da | ||
|
|
52a04ac575 | ||
|
|
0d3d535c08 | ||
|
|
224462018a | ||
|
|
35e89230fd | ||
|
|
9a57af6092 | ||
|
|
2e1bd8a481 | ||
|
|
1e678ecc1a | ||
|
|
6b3523a66d | ||
|
|
d4017b87c1 | ||
|
|
d3b60edb6f | ||
|
|
6baf687ecf | ||
|
|
7da012a4d8 | ||
|
|
6c318f1910 | ||
|
|
a9403c5392 | ||
|
|
ae7dce0b32 | ||
|
|
312728c8b6 | ||
|
|
acf39f2823 | ||
|
|
8de87fb9e0 | ||
|
|
6c48429b90 | ||
|
|
cc6af8fd28 | ||
|
|
5d3989a9a7 | ||
|
|
920767f486 | ||
|
|
7a4e994f3a | ||
|
|
13b1ec46ee | ||
|
|
e2cb07f08c | ||
|
|
541816f2ab | ||
|
|
dec9d03fc5 | ||
|
|
2781951ce7 | ||
|
|
1d2a6bf281 | ||
|
|
db49a3ec02 | ||
|
|
c509066943 | ||
|
|
0283846543 | ||
|
|
210d9f5793 | ||
|
|
dd6af0788e | ||
|
|
7307a5cc9a | ||
|
|
3239ef3c3e | ||
|
|
d21aedac83 | ||
|
|
df9aea194c | ||
|
|
2dcc230852 | ||
|
|
51c543631b | ||
|
|
895423852f | ||
|
|
eb253a9d3a | ||
|
|
3a75b75ae0 | ||
|
|
27ecb4b69b | ||
|
|
962700f525 | ||
|
|
e143d13ff6 | ||
|
|
2f853d7364 | ||
|
|
36099a4ada | ||
|
|
0348fa8a22 | ||
|
|
73bdb55cee | ||
|
|
7fc10573ab | ||
|
|
ce74b124d2 | ||
|
|
f2b10992cc | ||
|
|
deec72416e | ||
|
|
7beeea5779 | ||
|
|
19289c9008 | ||
|
|
89e93a1674 | ||
|
|
f62fa22338 | ||
|
|
2acf58590a | ||
|
|
46a56d0389 | ||
|
|
cfd57288d7 | ||
|
|
1049a723d8 | ||
|
|
3f3198c959 | ||
|
|
6b8f7f8821 | ||
|
|
ac9a1ae742 | ||
|
|
bd4c2bacbc | ||
|
|
6cfc7051c4 | ||
|
|
22a2a97a76 |
@@ -10,7 +10,9 @@ __pycache__
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
dist/*
|
||||
!dist/docker-input/
|
||||
!dist/docker-input/*.tar.gz
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
|
||||
95
.env.example
95
.env.example
@@ -1,90 +1,15 @@
|
||||
# DS2API environment template (Go runtime)
|
||||
# Copy this file to .env and adjust values.
|
||||
# Updated: 2026-02
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Runtime
|
||||
# ---------------------------------------------------------------
|
||||
# HTTP listen port (default: 5001)
|
||||
# DS2API runtime
|
||||
PORT=5001
|
||||
|
||||
# Log level: DEBUG | INFO | WARN | ERROR
|
||||
LOG_LEVEL=INFO
|
||||
|
||||
# Max concurrent inflight requests per account in managed-key mode.
|
||||
# Default: 2
|
||||
# Recommended client concurrency is calculated dynamically as:
|
||||
# account_count * DS2API_ACCOUNT_MAX_INFLIGHT
|
||||
# So by default it is account_count * 2.
|
||||
# Requests beyond inflight slots enter a waiting queue first.
|
||||
# Default queue size equals recommended concurrency, so 429 starts after:
|
||||
# account_count * DS2API_ACCOUNT_MAX_INFLIGHT * 2
|
||||
# Alias: DS2API_ACCOUNT_CONCURRENCY
|
||||
# DS2API_ACCOUNT_MAX_INFLIGHT=2
|
||||
# Admin authentication
|
||||
DS2API_ADMIN_KEY=change-me
|
||||
|
||||
# Optional waiting queue size override for managed-key mode.
|
||||
# Default: recommended_concurrency (same as account_count * inflight_limit)
|
||||
# Alias: DS2API_ACCOUNT_QUEUE_SIZE
|
||||
# DS2API_ACCOUNT_MAX_QUEUE=10
|
||||
# Config loading (choose one)
|
||||
# 1) file-based config
|
||||
DS2API_CONFIG_PATH=/app/config.json
|
||||
# 2) inline JSON or Base64 JSON
|
||||
# DS2API_CONFIG_JSON=
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Admin auth
|
||||
# ---------------------------------------------------------------
|
||||
# Admin key for /admin login and protected admin APIs.
|
||||
# Default is "admin" when unset, but setting it explicitly is recommended.
|
||||
DS2API_ADMIN_KEY=admin
|
||||
|
||||
# Optional JWT signing secret for admin token.
|
||||
# Defaults to DS2API_ADMIN_KEY when unset.
|
||||
# DS2API_JWT_SECRET=change-me
|
||||
|
||||
# Optional admin JWT validity in hours (default: 24)
|
||||
# DS2API_JWT_EXPIRE_HOURS=24
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Config source (choose one)
|
||||
# ---------------------------------------------------------------
|
||||
# Option A: config file path (local/dev recommended)
|
||||
# DS2API_CONFIG_PATH=config.json
|
||||
|
||||
# Option B: JSON string
|
||||
# DS2API_CONFIG_JSON={"keys":["your-api-key"],"accounts":[{"email":"user@example.com","password":"xxx","token":""}]}
|
||||
|
||||
# Option C: Base64 encoded JSON (recommended for Vercel env var)
|
||||
# DS2API_CONFIG_JSON=eyJrZXlzIjpbInlvdXItYXBpLWtleSJdLCJhY2NvdW50cyI6W3siZW1haWwiOiJ1c2VyQGV4YW1wbGUuY29tIiwicGFzc3dvcmQiOiJ4eHgiLCJ0b2tlbiI6IiJ9XX0=
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Paths (optional)
|
||||
# ---------------------------------------------------------------
|
||||
# WASM file used for PoW solving
|
||||
# DS2API_WASM_PATH=sha3_wasm_bg.7b9ca65ddd.wasm
|
||||
|
||||
# Built admin static assets directory
|
||||
# DS2API_STATIC_ADMIN_DIR=static/admin
|
||||
|
||||
# Auto-build WebUI on startup when static/admin is missing.
|
||||
# Default: enabled on local/Docker, disabled on Vercel.
|
||||
# DS2API_AUTO_BUILD_WEBUI=true
|
||||
|
||||
# Internal auth secret used by the Vercel hybrid streaming path
|
||||
# (Go prepare endpoint <-> Node stream function).
|
||||
# Optional: falls back to DS2API_ADMIN_KEY when unset.
|
||||
# DS2API_VERCEL_INTERNAL_SECRET=change-me
|
||||
|
||||
# Stream lease TTL seconds for Vercel hybrid streaming.
|
||||
# During this window, the managed account stays occupied until Node calls release.
|
||||
# Default: 900 (15 minutes)
|
||||
# DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS=900
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Vercel sync integration (optional)
|
||||
# ---------------------------------------------------------------
|
||||
# VERCEL_TOKEN=your-vercel-token
|
||||
# VERCEL_PROJECT_ID=prj_xxxxxxxxxxxx
|
||||
# VERCEL_TEAM_ID=team_xxxxxxxxxxxx
|
||||
|
||||
# Optional: Vercel deployment protection bypass secret.
|
||||
# If deployment protection is enabled, DS2API will use this value as
|
||||
# x-vercel-protection-bypass for internal Node->Go calls on Vercel.
|
||||
# You can also use VERCEL_AUTOMATION_BYPASS_SECRET directly.
|
||||
# DS2API_VERCEL_PROTECTION_BYPASS=your-bypass-secret
|
||||
# Optional: static admin assets path
|
||||
# DS2API_STATIC_ADMIN_DIR=/app/static/admin
|
||||
|
||||
44
.github/PULL_REQUEST_TEMPLATE.md
vendored
44
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,24 +1,20 @@
|
||||
#### 💻 变更类型 | Change Type
|
||||
|
||||
<!-- For change type, change [ ] to [x]. -->
|
||||
|
||||
- [ ] ✨ feat
|
||||
- [ ] 🐛 fix
|
||||
- [ ] ♻️ refactor
|
||||
- [ ] 💄 style
|
||||
- [ ] 👷 build
|
||||
- [ ] ⚡️ perf
|
||||
- [ ] 📝 docs
|
||||
- [ ] 🔨 chore
|
||||
|
||||
#### 🔀 变更说明 | Description of Change
|
||||
|
||||
<!-- Thank you for your Pull Request. Please provide a description above. -->
|
||||
|
||||
#### 📝 补充信息 | Additional Information
|
||||
|
||||
<!-- Add any other context about the Pull Request here. -->
|
||||
|
||||
---
|
||||
|
||||
> 💡 **提示**:如果修改了 `webui/` 目录下的文件,PR 合并后 CI 会自动构建并提交产物,无需手动构建。
|
||||
#### 💻 变更类型 | Change Type
|
||||
|
||||
<!-- For change type, change [ ] to [x]. -->
|
||||
|
||||
- [ ] ✨ feat
|
||||
- [ ] 🐛 fix
|
||||
- [ ] ♻️ refactor
|
||||
- [ ] 💄 style
|
||||
- [ ] 👷 build
|
||||
- [ ] ⚡️ perf
|
||||
- [ ] 📝 docs
|
||||
- [ ] 🔨 chore
|
||||
|
||||
#### 🔀 变更说明 | Description of Change
|
||||
|
||||
<!-- Thank you for your Pull Request. Please provide a description above. -->
|
||||
|
||||
#### 📝 补充信息 | Additional Information
|
||||
|
||||
<!-- Add any other context about the Pull Request here. -->
|
||||
40
.github/workflows/quality-gates.yml
vendored
Normal file
40
.github/workflows/quality-gates.yml
vendored
Normal file
@@ -0,0 +1,40 @@
|
||||
name: Quality Gates
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches:
|
||||
- dev
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
quality-gates:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.24.x"
|
||||
|
||||
- name: Setup Node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "24"
|
||||
cache: "npm"
|
||||
cache-dependency-path: webui/package-lock.json
|
||||
|
||||
- name: Refactor Line Gate
|
||||
run: ./tests/scripts/check-refactor-line-gate.sh
|
||||
|
||||
- name: Unit Gates (Go + Node)
|
||||
run: ./tests/scripts/run-unit-all.sh
|
||||
|
||||
- name: WebUI Build Gate
|
||||
run: |
|
||||
npm ci --prefix webui
|
||||
npm run build --prefix webui
|
||||
136
.github/workflows/release-artifacts.yml
vendored
136
.github/workflows/release-artifacts.yml
vendored
@@ -4,13 +4,22 @@ on:
|
||||
release:
|
||||
types:
|
||||
- published
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
release_tag:
|
||||
description: "Release tag to build/publish (e.g. v2.1.6)"
|
||||
required: true
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
build-and-upload:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.event.release.tag_name || github.event.inputs.release_tag }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -23,10 +32,16 @@ jobs:
|
||||
- name: Setup Node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
node-version: "24"
|
||||
cache: "npm"
|
||||
cache-dependency-path: webui/package-lock.json
|
||||
|
||||
- name: Release Blocking Gates
|
||||
run: |
|
||||
./tests/scripts/check-stage6-manual-smoke.sh
|
||||
./tests/scripts/check-refactor-line-gate.sh
|
||||
./tests/scripts/run-unit-all.sh
|
||||
|
||||
- name: Build WebUI
|
||||
run: |
|
||||
npm ci --prefix webui
|
||||
@@ -35,7 +50,11 @@ jobs:
|
||||
- name: Build Multi-Platform Archives
|
||||
run: |
|
||||
set -euo pipefail
|
||||
TAG="${{ github.event.release.tag_name }}"
|
||||
TAG="${RELEASE_TAG}"
|
||||
BUILD_VERSION="${TAG}"
|
||||
if [ -z "${BUILD_VERSION}" ] && [ -f VERSION ]; then
|
||||
BUILD_VERSION="$(cat VERSION | tr -d '[:space:]')"
|
||||
fi
|
||||
mkdir -p dist
|
||||
|
||||
targets=(
|
||||
@@ -58,7 +77,7 @@ jobs:
|
||||
|
||||
mkdir -p "${STAGE}/static"
|
||||
CGO_ENABLED=0 GOOS="${GOOS}" GOARCH="${GOARCH}" \
|
||||
go build -trimpath -ldflags="-s -w" -o "${STAGE}/${BIN}" ./cmd/ds2api
|
||||
go build -trimpath -ldflags="-s -w -X ds2api/internal/version.BuildVersion=${BUILD_VERSION}" -o "${STAGE}/${BIN}" ./cmd/ds2api
|
||||
|
||||
cp config.example.json .env.example sha3_wasm_bg.7b9ca65ddd.wasm LICENSE README.MD README.en.md "${STAGE}/"
|
||||
cp -R static/admin "${STAGE}/static/admin"
|
||||
@@ -72,12 +91,117 @@ jobs:
|
||||
rm -rf "${STAGE}"
|
||||
done
|
||||
|
||||
- name: Prepare Docker release inputs
|
||||
run: |
|
||||
set -euo pipefail
|
||||
TAG="${RELEASE_TAG}"
|
||||
mkdir -p dist/docker-input
|
||||
cp "dist/ds2api_${TAG}_linux_amd64.tar.gz" "dist/docker-input/linux_amd64.tar.gz"
|
||||
cp "dist/ds2api_${TAG}_linux_arm64.tar.gz" "dist/docker-input/linux_arm64.tar.gz"
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Wait for GHCR endpoint
|
||||
run: |
|
||||
set -euo pipefail
|
||||
for i in {1..6}; do
|
||||
code="$(curl -sS -o /dev/null -w '%{http_code}' --max-time 15 https://ghcr.io/v2/ || true)"
|
||||
if [ "${code}" = "200" ] || [ "${code}" = "401" ] || [ "${code}" = "405" ]; then
|
||||
exit 0
|
||||
fi
|
||||
sleep "$((i * 10))"
|
||||
done
|
||||
echo "GHCR endpoint is unreachable after multiple retries (last status: ${code:-unknown})." >&2
|
||||
exit 1
|
||||
|
||||
- name: Log in to GHCR (with retry)
|
||||
run: |
|
||||
set -euo pipefail
|
||||
for i in {1..6}; do
|
||||
if echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin; then
|
||||
exit 0
|
||||
fi
|
||||
sleep "$((i * 10))"
|
||||
done
|
||||
echo "Failed to login to GHCR after multiple retries." >&2
|
||||
exit 1
|
||||
|
||||
- name: Extract Docker metadata
|
||||
id: meta_release
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
ghcr.io/${{ github.repository }}
|
||||
tags: |
|
||||
type=raw,value=${{ env.RELEASE_TAG }}
|
||||
type=raw,value=latest
|
||||
|
||||
- name: Build and Push Docker Image
|
||||
uses: docker/build-push-action@v6
|
||||
env:
|
||||
DOCKER_BUILD_RECORD_UPLOAD: "false"
|
||||
DOCKER_BUILD_SUMMARY: "false"
|
||||
with:
|
||||
context: .
|
||||
file: ./Dockerfile
|
||||
target: runtime-from-dist
|
||||
push: true
|
||||
platforms: linux/amd64,linux/arm64
|
||||
tags: ${{ steps.meta_release.outputs.tags }}
|
||||
labels: ${{ steps.meta_release.outputs.labels }}
|
||||
|
||||
- name: Export Docker image archives for release assets
|
||||
run: |
|
||||
set -euo pipefail
|
||||
TAG="${RELEASE_TAG}"
|
||||
|
||||
docker buildx build \
|
||||
--platform linux/amd64 \
|
||||
--target runtime-from-dist \
|
||||
--output type=docker,dest="dist/ds2api_${TAG}_docker_linux_amd64.tar" \
|
||||
.
|
||||
|
||||
docker buildx build \
|
||||
--platform linux/arm64 \
|
||||
--target runtime-from-dist \
|
||||
--output type=docker,dest="dist/ds2api_${TAG}_docker_linux_arm64.tar" \
|
||||
.
|
||||
|
||||
gzip -f "dist/ds2api_${TAG}_docker_linux_amd64.tar"
|
||||
gzip -f "dist/ds2api_${TAG}_docker_linux_arm64.tar"
|
||||
|
||||
- name: Generate checksums
|
||||
run: |
|
||||
set -euo pipefail
|
||||
(cd dist && sha256sum *.tar.gz *.zip > sha256sums.txt)
|
||||
|
||||
- name: Validate release tag
|
||||
run: |
|
||||
set -euo pipefail
|
||||
TAG="${RELEASE_TAG}"
|
||||
if [ -z "${TAG}" ]; then
|
||||
echo "release tag is empty; set release_tag when using workflow_dispatch." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Upload Release Assets
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
files: |
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
TAG="${RELEASE_TAG}"
|
||||
FILES=(
|
||||
dist/*.tar.gz
|
||||
dist/*.zip
|
||||
dist/sha256sums.txt
|
||||
)
|
||||
|
||||
if gh release view "${TAG}" >/dev/null 2>&1; then
|
||||
gh release upload "${TAG}" "${FILES[@]}" --clobber
|
||||
else
|
||||
gh release create "${TAG}" "${FILES[@]}" --title "${TAG}" --notes ""
|
||||
fi
|
||||
|
||||
129
.github/workflows/release-dockerhub.yml
vendored
Normal file
129
.github/workflows/release-dockerhub.yml
vendored
Normal file
@@ -0,0 +1,129 @@
|
||||
name: Release to Docker Hub
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version_type:
|
||||
description: '版本类型'
|
||||
required: true
|
||||
default: 'patch'
|
||||
type: choice
|
||||
options:
|
||||
- patch
|
||||
- minor
|
||||
- major
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Get current version
|
||||
id: get_version
|
||||
run: |
|
||||
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
|
||||
TAG_VERSION=${LATEST_TAG#v}
|
||||
|
||||
if [ -f VERSION ]; then
|
||||
FILE_VERSION=$(cat VERSION | tr -d '[:space:]')
|
||||
else
|
||||
FILE_VERSION="0.0.0"
|
||||
fi
|
||||
|
||||
function version_gt() { test "$(printf '%s\n' "$@" | sort -V | head -n 1)" != "$1"; }
|
||||
|
||||
if version_gt "$FILE_VERSION" "$TAG_VERSION"; then
|
||||
VERSION="$FILE_VERSION"
|
||||
else
|
||||
VERSION="$TAG_VERSION"
|
||||
fi
|
||||
|
||||
echo "Current version: $VERSION"
|
||||
echo "current_version=$VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Calculate next version
|
||||
id: next_version
|
||||
env:
|
||||
VERSION_TYPE: ${{ github.event.inputs.version_type }}
|
||||
run: |
|
||||
VERSION="${{ steps.get_version.outputs.current_version }}"
|
||||
BASE_VERSION=$(echo "$VERSION" | sed 's/-.*$//')
|
||||
|
||||
IFS='.' read -r -a version_parts <<< "$BASE_VERSION"
|
||||
MAJOR="${version_parts[0]:-0}"
|
||||
MINOR="${version_parts[1]:-0}"
|
||||
PATCH="${version_parts[2]:-0}"
|
||||
|
||||
case "$VERSION_TYPE" in
|
||||
major)
|
||||
NEW_VERSION="$((MAJOR + 1)).0.0"
|
||||
;;
|
||||
minor)
|
||||
NEW_VERSION="${MAJOR}.$((MINOR + 1)).0"
|
||||
;;
|
||||
*)
|
||||
NEW_VERSION="${MAJOR}.${MINOR}.$((PATCH + 1))"
|
||||
;;
|
||||
esac
|
||||
|
||||
echo "New version: $NEW_VERSION"
|
||||
echo "new_version=$NEW_VERSION" >> $GITHUB_OUTPUT
|
||||
echo "new_tag=v$NEW_VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Update VERSION file
|
||||
run: |
|
||||
echo "${{ steps.next_version.outputs.new_version }}" > VERSION
|
||||
|
||||
- name: Commit VERSION and create tag
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
git add VERSION
|
||||
if ! git diff --cached --quiet; then
|
||||
git commit -m "chore: bump version to ${{ steps.next_version.outputs.new_tag }} [skip ci]"
|
||||
fi
|
||||
|
||||
NEW_TAG="${{ steps.next_version.outputs.new_tag }}"
|
||||
git tag -a "$NEW_TAG" -m "Release $NEW_TAG"
|
||||
git push origin HEAD:main "$NEW_TAG"
|
||||
|
||||
# Docker 构建并推送到 Docker Hub
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: ./Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ secrets.DOCKERHUB_USERNAME }}/ds2api:${{ steps.next_version.outputs.new_tag }}
|
||||
${{ secrets.DOCKERHUB_USERNAME }}/ds2api:${{ steps.next_version.outputs.new_version }}
|
||||
${{ secrets.DOCKERHUB_USERNAME }}/ds2api:latest
|
||||
labels: |
|
||||
org.opencontainers.image.version=${{ steps.next_version.outputs.new_version }}
|
||||
org.opencontainers.image.revision=${{ github.sha }}
|
||||
build-args: |
|
||||
BUILD_VERSION=${{ steps.next_version.outputs.new_tag }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
130
.github/workflows/release.yml
vendored
Normal file
130
.github/workflows/release.yml
vendored
Normal file
@@ -0,0 +1,130 @@
|
||||
name: Release to Aliyun CR
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version_type:
|
||||
description: '版本类型'
|
||||
required: true
|
||||
default: 'patch'
|
||||
type: choice
|
||||
options:
|
||||
- patch
|
||||
- minor
|
||||
- major
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Get current version
|
||||
id: get_version
|
||||
run: |
|
||||
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
|
||||
TAG_VERSION=${LATEST_TAG#v}
|
||||
|
||||
if [ -f VERSION ]; then
|
||||
FILE_VERSION=$(cat VERSION | tr -d '[:space:]')
|
||||
else
|
||||
FILE_VERSION="0.0.0"
|
||||
fi
|
||||
|
||||
function version_gt() { test "$(printf '%s\n' "$@" | sort -V | head -n 1)" != "$1"; }
|
||||
|
||||
if version_gt "$FILE_VERSION" "$TAG_VERSION"; then
|
||||
VERSION="$FILE_VERSION"
|
||||
else
|
||||
VERSION="$TAG_VERSION"
|
||||
fi
|
||||
|
||||
echo "Current version: $VERSION"
|
||||
echo "current_version=$VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Calculate next version
|
||||
id: next_version
|
||||
env:
|
||||
VERSION_TYPE: ${{ github.event.inputs.version_type }}
|
||||
run: |
|
||||
VERSION="${{ steps.get_version.outputs.current_version }}"
|
||||
BASE_VERSION=$(echo "$VERSION" | sed 's/-.*$//')
|
||||
|
||||
IFS='.' read -r -a version_parts <<< "$BASE_VERSION"
|
||||
MAJOR="${version_parts[0]:-0}"
|
||||
MINOR="${version_parts[1]:-0}"
|
||||
PATCH="${version_parts[2]:-0}"
|
||||
|
||||
case "$VERSION_TYPE" in
|
||||
major)
|
||||
NEW_VERSION="$((MAJOR + 1)).0.0"
|
||||
;;
|
||||
minor)
|
||||
NEW_VERSION="${MAJOR}.$((MINOR + 1)).0"
|
||||
;;
|
||||
*)
|
||||
NEW_VERSION="${MAJOR}.${MINOR}.$((PATCH + 1))"
|
||||
;;
|
||||
esac
|
||||
|
||||
echo "New version: $NEW_VERSION"
|
||||
echo "new_version=$NEW_VERSION" >> $GITHUB_OUTPUT
|
||||
echo "new_tag=v$NEW_VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Update VERSION file
|
||||
run: |
|
||||
echo "${{ steps.next_version.outputs.new_version }}" > VERSION
|
||||
|
||||
- name: Commit VERSION and create tag
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
git add VERSION
|
||||
if ! git diff --cached --quiet; then
|
||||
git commit -m "chore: bump version to ${{ steps.next_version.outputs.new_tag }} [skip ci]"
|
||||
fi
|
||||
|
||||
NEW_TAG="${{ steps.next_version.outputs.new_tag }}"
|
||||
git tag -a "$NEW_TAG" -m "Release $NEW_TAG"
|
||||
git push origin HEAD:main "$NEW_TAG"
|
||||
|
||||
# Docker 构建并推送到阿里云
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Aliyun Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ secrets.ALIYUN_REGISTRY }}
|
||||
username: ${{ secrets.ALIYUN_REGISTRY_USER }}
|
||||
password: ${{ secrets.ALIYUN_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: ./Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ secrets.ALIYUN_REGISTRY }}/${{ secrets.ALIYUN_REGISTRY_NAMESPACE }}/ds2api:${{ steps.next_version.outputs.new_tag }}
|
||||
${{ secrets.ALIYUN_REGISTRY }}/${{ secrets.ALIYUN_REGISTRY_NAMESPACE }}/ds2api:${{ steps.next_version.outputs.new_version }}
|
||||
${{ secrets.ALIYUN_REGISTRY }}/${{ secrets.ALIYUN_REGISTRY_NAMESPACE }}/ds2api:latest
|
||||
labels: |
|
||||
org.opencontainers.image.version=${{ steps.next_version.outputs.new_version }}
|
||||
org.opencontainers.image.revision=${{ github.sha }}
|
||||
build-args: |
|
||||
BUILD_VERSION=${{ steps.next_version.outputs.new_tag }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
43
.gitignore
vendored
43
.gitignore
vendored
@@ -2,37 +2,6 @@
|
||||
config.json
|
||||
.env
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# Virtual environments
|
||||
venv/
|
||||
ENV/
|
||||
env/
|
||||
.venv
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
@@ -44,7 +13,6 @@ env/
|
||||
# Logs
|
||||
*.log
|
||||
logs/
|
||||
uvicorn.log
|
||||
artifacts/
|
||||
|
||||
# Vercel
|
||||
@@ -56,8 +24,6 @@ webui/node_modules/
|
||||
webui/dist/
|
||||
.npm
|
||||
.pnpm-store/
|
||||
# 保留 webui/package-lock.json 用于 CI 缓存
|
||||
# package-lock.json # 如果有根目录的可以忽略
|
||||
yarn.lock
|
||||
pnpm-lock.yaml
|
||||
|
||||
@@ -81,9 +47,14 @@ ds2api-tests
|
||||
htmlcov/
|
||||
.pytest_cache/
|
||||
.tox/
|
||||
*.coverprofile
|
||||
coverage*.out
|
||||
cover/
|
||||
|
||||
# Misc
|
||||
*.pyc
|
||||
*.pyo
|
||||
.git/
|
||||
Thumbs.db
|
||||
|
||||
# Claude Code
|
||||
.claude/
|
||||
CLAUDE.local.md
|
||||
|
||||
336
API.en.md
336
API.en.md
@@ -9,11 +9,13 @@ This document describes the actual behavior of the current Go codebase.
|
||||
## Table of Contents
|
||||
|
||||
- [Basics](#basics)
|
||||
- [Configuration Best Practice](#configuration-best-practice)
|
||||
- [Authentication](#authentication)
|
||||
- [Route Index](#route-index)
|
||||
- [Health Endpoints](#health-endpoints)
|
||||
- [OpenAI-Compatible API](#openai-compatible-api)
|
||||
- [Claude-Compatible API](#claude-compatible-api)
|
||||
- [Gemini-Compatible API](#gemini-compatible-api)
|
||||
- [Admin API](#admin-api)
|
||||
- [Error Payloads](#error-payloads)
|
||||
- [cURL Examples](#curl-examples)
|
||||
@@ -27,13 +29,35 @@ This document describes the actual behavior of the current Go codebase.
|
||||
| Base URL | `http://localhost:5001` or your deployment domain |
|
||||
| Default Content-Type | `application/json` |
|
||||
| Health probes | `GET /healthz`, `GET /readyz` |
|
||||
| CORS | Enabled (`Access-Control-Allow-Origin: *`, allows `Content-Type`, `Authorization`) |
|
||||
| CORS | Enabled (`Access-Control-Allow-Origin: *`, allows `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Vercel-Protection-Bypass`) |
|
||||
|
||||
---
|
||||
|
||||
## Configuration Best Practice
|
||||
|
||||
Use `config.json` as the single source of truth:
|
||||
|
||||
```bash
|
||||
cp config.example.json config.json
|
||||
# Edit config.json (keys/accounts)
|
||||
```
|
||||
|
||||
Use it per deployment mode:
|
||||
|
||||
- Local run: read `config.json` directly
|
||||
- Docker / Vercel: generate Base64 from `config.json`, then set `DS2API_CONFIG_JSON`
|
||||
|
||||
```bash
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
```
|
||||
|
||||
For Vercel one-click bootstrap, you can set only `DS2API_ADMIN_KEY` first, then import config at `/admin` and sync env vars from the "Vercel Sync" page.
|
||||
|
||||
---
|
||||
|
||||
## Authentication
|
||||
|
||||
### Business Endpoints (`/v1/*`, `/anthropic/*`)
|
||||
### Business Endpoints (`/v1/*`, `/anthropic/*`, `/v1beta/models/*`)
|
||||
|
||||
Two header formats accepted:
|
||||
|
||||
@@ -66,15 +90,32 @@ Two header formats accepted:
|
||||
| GET | `/healthz` | None | Liveness probe |
|
||||
| GET | `/readyz` | None | Readiness probe |
|
||||
| GET | `/v1/models` | None | OpenAI model list |
|
||||
| GET | `/v1/models/{id}` | None | OpenAI single-model query (alias accepted) |
|
||||
| POST | `/v1/chat/completions` | Business | OpenAI chat completions |
|
||||
| POST | `/v1/responses` | Business | OpenAI Responses API (stream/non-stream) |
|
||||
| GET | `/v1/responses/{response_id}` | Business | Query stored response (in-memory TTL) |
|
||||
| POST | `/v1/embeddings` | Business | OpenAI Embeddings API |
|
||||
| GET | `/anthropic/v1/models` | None | Claude model list |
|
||||
| POST | `/anthropic/v1/messages` | Business | Claude messages |
|
||||
| POST | `/anthropic/v1/messages/count_tokens` | Business | Claude token counting |
|
||||
| POST | `/v1/messages` | Business | Claude shortcut path |
|
||||
| POST | `/messages` | Business | Claude shortcut path |
|
||||
| POST | `/v1/messages/count_tokens` | Business | Claude token counting shortcut |
|
||||
| POST | `/messages/count_tokens` | Business | Claude token counting shortcut |
|
||||
| POST | `/v1beta/models/{model}:generateContent` | Business | Gemini non-stream |
|
||||
| POST | `/v1beta/models/{model}:streamGenerateContent` | Business | Gemini stream |
|
||||
| POST | `/v1/models/{model}:generateContent` | Business | Gemini non-stream compat path |
|
||||
| POST | `/v1/models/{model}:streamGenerateContent` | Business | Gemini stream compat path |
|
||||
| POST | `/admin/login` | None | Admin login |
|
||||
| GET | `/admin/verify` | JWT | Verify admin JWT |
|
||||
| GET | `/admin/vercel/config` | Admin | Read preconfigured Vercel creds |
|
||||
| GET | `/admin/config` | Admin | Read sanitized config |
|
||||
| POST | `/admin/config` | Admin | Update config |
|
||||
| GET | `/admin/settings` | Admin | Read runtime settings |
|
||||
| PUT | `/admin/settings` | Admin | Update runtime settings (hot reload) |
|
||||
| POST | `/admin/settings/password` | Admin | Update admin password and invalidate old JWTs |
|
||||
| POST | `/admin/config/import` | Admin | Import config (merge/replace) |
|
||||
| GET | `/admin/config/export` | Admin | Export full config (`config`/`json`/`base64`) |
|
||||
| POST | `/admin/keys` | Admin | Add API key |
|
||||
| DELETE | `/admin/keys/{key}` | Admin | Delete API key |
|
||||
| GET | `/admin/accounts` | Admin | Paginated account list |
|
||||
@@ -88,6 +129,8 @@ Two header formats accepted:
|
||||
| POST | `/admin/vercel/sync` | Admin | Sync config to Vercel |
|
||||
| GET | `/admin/vercel/status` | Admin | Vercel sync status |
|
||||
| GET | `/admin/export` | Admin | Export config JSON/Base64 |
|
||||
| GET | `/admin/dev/captures` | Admin | Read local packet-capture entries |
|
||||
| DELETE | `/admin/dev/captures` | Admin | Clear local packet-capture entries |
|
||||
|
||||
---
|
||||
|
||||
@@ -127,6 +170,15 @@ No auth required. Returns supported models.
|
||||
}
|
||||
```
|
||||
|
||||
### Model Alias Resolution
|
||||
|
||||
For `chat` / `responses` / `embeddings`, DS2API follows a wide-input/strict-output policy:
|
||||
|
||||
1. Match DeepSeek native model IDs first.
|
||||
2. Then match exact keys in `model_aliases`.
|
||||
3. If still unmatched, fall back by known family heuristics (`o*`, `gpt-*`, `claude-*`, etc.).
|
||||
4. If still unmatched, return `invalid_request_error`.
|
||||
|
||||
### `POST /v1/chat/completions`
|
||||
|
||||
**Headers**:
|
||||
@@ -140,7 +192,7 @@ Content-Type: application/json
|
||||
|
||||
| Field | Type | Required | Notes |
|
||||
| --- | --- | --- | --- |
|
||||
| `model` | string | ✅ | `deepseek-chat` / `deepseek-reasoner` / `deepseek-chat-search` / `deepseek-reasoner-search` |
|
||||
| `model` | string | ✅ | DeepSeek native models + common aliases (`gpt-4o`, `gpt-5-codex`, `o3`, `claude-sonnet-4-5`, etc.) |
|
||||
| `messages` | array | ✅ | OpenAI-style messages |
|
||||
| `stream` | boolean | ❌ | Default `false` |
|
||||
| `tools` | array | ❌ | Function calling schema |
|
||||
@@ -230,12 +282,90 @@ When `tools` is present, DS2API performs anti-leak handling:
|
||||
}
|
||||
```
|
||||
|
||||
**Stream**: DS2API buffers text first. If tool call detected → only structured `delta.tool_calls` (each with `index`); otherwise emits buffered text at once.
|
||||
**Stream**: Once high-confidence toolcall features are matched, DS2API emits `delta.tool_calls` immediately (without waiting for full JSON closure), then keeps sending argument deltas; confirmed raw tool JSON is never forwarded as `delta.content`.
|
||||
|
||||
---
|
||||
|
||||
### `GET /v1/models/{id}`
|
||||
|
||||
No auth required. Alias values are accepted as path params (for example `gpt-4o`), and the returned object is the mapped DeepSeek model.
|
||||
|
||||
### `POST /v1/responses`
|
||||
|
||||
OpenAI Responses-style endpoint, accepting either `input` or `messages`.
|
||||
|
||||
| Field | Type | Required | Notes |
|
||||
| --- | --- | --- | --- |
|
||||
| `model` | string | ✅ | Supports native models + alias mapping |
|
||||
| `input` | string/array/object | ❌ | One of `input` or `messages` is required |
|
||||
| `messages` | array | ❌ | One of `input` or `messages` is required |
|
||||
| `instructions` | string | ❌ | Prepended as a system message |
|
||||
| `stream` | boolean | ❌ | Default `false` |
|
||||
| `tools` | array | ❌ | Same tool detection/translation policy as chat |
|
||||
| `tool_choice` | string/object | ❌ | Supports `auto`/`none`/`required` and forced function selection (`{"type":"function","name":"..."}`) |
|
||||
|
||||
**Non-stream**: Returns a standard `response` object with an ID like `resp_xxx`, and stores it in in-memory TTL cache.
|
||||
If `tool_choice=required` and no valid tool call is produced, DS2API returns HTTP `422` (`error.code=tool_choice_violation`).
|
||||
|
||||
**Stream (SSE)**: minimal event sequence:
|
||||
|
||||
```text
|
||||
event: response.created
|
||||
data: {"type":"response.created","id":"resp_xxx","status":"in_progress",...}
|
||||
|
||||
event: response.output_item.added
|
||||
data: {"type":"response.output_item.added","response_id":"resp_xxx","item":{"type":"message|function_call",...},...}
|
||||
|
||||
event: response.content_part.added
|
||||
data: {"type":"response.content_part.added","response_id":"resp_xxx","part":{"type":"output_text",...},...}
|
||||
|
||||
event: response.output_text.delta
|
||||
data: {"type":"response.output_text.delta","response_id":"resp_xxx","item_id":"msg_xxx","output_index":0,"content_index":0,"delta":"..."}
|
||||
|
||||
event: response.function_call_arguments.delta
|
||||
data: {"type":"response.function_call_arguments.delta","response_id":"resp_xxx","call_id":"call_xxx","delta":"..."}
|
||||
|
||||
event: response.function_call_arguments.done
|
||||
data: {"type":"response.function_call_arguments.done","response_id":"resp_xxx","call_id":"call_xxx","name":"tool","arguments":"{...}"}
|
||||
|
||||
event: response.content_part.done
|
||||
data: {"type":"response.content_part.done","response_id":"resp_xxx",...}
|
||||
|
||||
event: response.output_item.done
|
||||
data: {"type":"response.output_item.done","response_id":"resp_xxx","item":{"type":"message|function_call",...},...}
|
||||
|
||||
event: response.completed
|
||||
data: {"type":"response.completed","response":{...}}
|
||||
|
||||
data: [DONE]
|
||||
```
|
||||
|
||||
If `tool_choice=required` is violated in stream mode, DS2API emits `response.failed` then `[DONE]` (no `response.completed`).
|
||||
Unknown tool names (outside declared `tools`) are rejected and will not be emitted as valid tool calls.
|
||||
|
||||
### `GET /v1/responses/{response_id}`
|
||||
|
||||
Business auth required. Fetches cached responses created by `POST /v1/responses` (caller-scoped; only the same key/token can read).
|
||||
|
||||
> Backed by in-memory TTL store. Default TTL is `900s` (configurable via `responses.store_ttl_seconds`).
|
||||
|
||||
### `POST /v1/embeddings`
|
||||
|
||||
Business auth required. Returns OpenAI-compatible embeddings shape.
|
||||
|
||||
| Field | Type | Required | Notes |
|
||||
| --- | --- | --- | --- |
|
||||
| `model` | string | ✅ | Supports native models + alias mapping |
|
||||
| `input` | string/array | ✅ | Supports a single string, a string array, or a token array |
|
||||
|
||||
> Requires `embeddings.provider`. Currently supported values: `mock` / `deterministic` / `builtin`. If the provider is missing or unsupported, DS2API returns the standard error shape with HTTP 501.
|
||||
|
||||
---
|
||||
|
||||
## Claude-Compatible API
|
||||
|
||||
Besides `/anthropic/v1/*`, DS2API also supports shortcut paths: `/v1/messages`, `/messages`, `/v1/messages/count_tokens`, `/messages/count_tokens`.
|
||||
|
||||
### `GET /anthropic/v1/models`
|
||||
|
||||
No auth required.
|
||||
@@ -249,7 +379,10 @@ No auth required.
|
||||
{"id": "claude-sonnet-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"},
|
||||
{"id": "claude-haiku-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"},
|
||||
{"id": "claude-opus-4-6", "object": "model", "created": 1715635200, "owned_by": "anthropic"}
|
||||
]
|
||||
],
|
||||
"first_id": "claude-opus-4-6",
|
||||
"last_id": "claude-instant-1.0",
|
||||
"has_more": false
|
||||
}
|
||||
```
|
||||
|
||||
@@ -265,13 +398,15 @@ Content-Type: application/json
|
||||
anthropic-version: 2023-06-01
|
||||
```
|
||||
|
||||
> `anthropic-version` is optional; DS2API auto-fills `2023-06-01` when absent.
|
||||
|
||||
**Request body**:
|
||||
|
||||
| Field | Type | Required | Notes |
|
||||
| --- | --- | --- | --- |
|
||||
| `model` | string | ✅ | For example `claude-sonnet-4-5` / `claude-opus-4-6` / `claude-haiku-4-5` (compatible with `claude-3-5-haiku-latest`), plus historical Claude model IDs |
|
||||
| `messages` | array | ✅ | Claude-style messages |
|
||||
| `max_tokens` | number | ❌ | Not strictly enforced by upstream bridge |
|
||||
| `max_tokens` | number | ❌ | Auto-filled to `8192` when omitted; not strictly enforced by upstream bridge |
|
||||
| `stream` | boolean | ❌ | Default `false` |
|
||||
| `system` | string | ❌ | Optional system prompt |
|
||||
| `tools` | array | ❌ | Claude tool schema |
|
||||
@@ -354,6 +489,37 @@ data: {"type":"message_stop"}
|
||||
|
||||
---
|
||||
|
||||
## Gemini-Compatible API
|
||||
|
||||
Supported paths:
|
||||
|
||||
- `/v1beta/models/{model}:generateContent`
|
||||
- `/v1beta/models/{model}:streamGenerateContent`
|
||||
- `/v1/models/{model}:generateContent` (compat path)
|
||||
- `/v1/models/{model}:streamGenerateContent` (compat path)
|
||||
|
||||
Authentication is the same as other business routes (`Authorization: Bearer <token>` or `x-api-key`).
|
||||
|
||||
### `POST /v1beta/models/{model}:generateContent`
|
||||
|
||||
Request body accepts Gemini-style `contents` / `tools`. Model names can use aliases and are mapped to DeepSeek models.
|
||||
|
||||
Response uses Gemini-compatible fields, including:
|
||||
|
||||
- `candidates[].content.parts[].text`
|
||||
- `candidates[].content.parts[].functionCall` (when tool call is produced)
|
||||
- `usageMetadata` (`promptTokenCount` / `candidatesTokenCount` / `totalTokenCount`)
|
||||
|
||||
### `POST /v1beta/models/{model}:streamGenerateContent`
|
||||
|
||||
Returns SSE (`text/event-stream`), each chunk as `data: <json>`:
|
||||
|
||||
- regular text: incremental text chunks
|
||||
- `tools` mode: output is buffered and emitted as a `functionCall` during the finalize phase
|
||||
- final chunk: includes `finishReason: "STOP"` and `usageMetadata`
|
||||
|
||||
---
|
||||
|
||||
## Admin API
|
||||
|
||||
### `POST /admin/login`
|
||||
@@ -416,6 +582,7 @@ Returns sanitized config.
|
||||
"keys": ["k1", "k2"],
|
||||
"accounts": [
|
||||
{
|
||||
"identifier": "user@example.com",
|
||||
"email": "user@example.com",
|
||||
"mobile": "",
|
||||
"has_password": true,
|
||||
@@ -449,6 +616,53 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`.
|
||||
}
|
||||
```
|
||||
|
||||
### `GET /admin/settings`
|
||||
|
||||
Reads runtime settings and status, including:
|
||||
|
||||
- `admin` (JWT expiry, default-password warning, etc.)
|
||||
- `runtime` (`account_max_inflight`, `account_max_queue`, `global_max_inflight`)
|
||||
- `toolcall` / `responses` / `embeddings`
|
||||
- `auto_delete` (`sessions`)
|
||||
- `claude_mapping` / `model_aliases`
|
||||
- `env_backed`, `needs_vercel_sync`
|
||||
|
||||
### `PUT /admin/settings`
|
||||
|
||||
Hot-updates runtime settings. Supported fields:
|
||||
|
||||
- `admin.jwt_expire_hours`
|
||||
- `runtime.account_max_inflight` / `runtime.account_max_queue` / `runtime.global_max_inflight`
|
||||
- `toolcall.mode` / `toolcall.early_emit_confidence`
|
||||
- `responses.store_ttl_seconds`
|
||||
- `embeddings.provider`
|
||||
- `auto_delete.sessions`
|
||||
- `claude_mapping`
|
||||
- `model_aliases`
|
||||
|
||||
### `POST /admin/settings/password`
|
||||
|
||||
Updates admin password and invalidates existing JWTs.
|
||||
|
||||
Request example:
|
||||
|
||||
```json
|
||||
{"new_password":"your-new-password"}
|
||||
```
|
||||
|
||||
### `POST /admin/config/import`
|
||||
|
||||
Imports full config with:
|
||||
|
||||
- `mode=merge` (default)
|
||||
- `mode=replace`
|
||||
|
||||
The request can send config directly, or wrapped as `{"config": {...}, "mode":"merge"}`.
|
||||
|
||||
### `GET /admin/config/export`
|
||||
|
||||
Exports full config in three forms: `config`, `json`, and `base64`.
|
||||
|
||||
### `POST /admin/keys`
|
||||
|
||||
```json
|
||||
@@ -476,6 +690,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`.
|
||||
{
|
||||
"items": [
|
||||
{
|
||||
"identifier": "user@example.com",
|
||||
"email": "user@example.com",
|
||||
"mobile": "",
|
||||
"has_password": true,
|
||||
@@ -500,7 +715,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`.
|
||||
|
||||
### `DELETE /admin/accounts/{identifier}`
|
||||
|
||||
`identifier` is email or mobile.
|
||||
`identifier` can be email, mobile, or the synthetic id for token-only accounts (`token:<hash>`).
|
||||
|
||||
**Response**: `{"success": true, "total_accounts": 5}`
|
||||
|
||||
@@ -530,7 +745,7 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`.
|
||||
|
||||
| Field | Required | Notes |
|
||||
| --- | --- | --- |
|
||||
| `identifier` | ✅ | email or mobile |
|
||||
| `identifier` | ✅ | email / mobile / token-only synthetic id |
|
||||
| `model` | ❌ | default `deepseek-chat` |
|
||||
| `message` | ❌ | if empty, only session creation is tested |
|
||||
|
||||
@@ -655,17 +870,53 @@ Or manual deploy required:
|
||||
}
|
||||
```
|
||||
|
||||
### `GET /admin/dev/captures`
|
||||
|
||||
Reads local packet-capture status and recent entries (Admin auth required):
|
||||
|
||||
- `enabled`
|
||||
- `limit`
|
||||
- `max_body_bytes`
|
||||
- `items`
|
||||
|
||||
### `DELETE /admin/dev/captures`
|
||||
|
||||
Clears packet-capture entries:
|
||||
|
||||
```json
|
||||
{"success":true,"detail":"capture logs cleared"}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Error Payloads
|
||||
|
||||
Error formats vary by module:
|
||||
Compatible routes (`/v1/*`, `/anthropic/*`) use the same error envelope:
|
||||
|
||||
| Module | Format |
|
||||
| --- | --- |
|
||||
| OpenAI routes | `{"error": {"message": "...", "type": "..."}}` |
|
||||
| Claude routes | `{"error": {"type": "...", "message": "..."}}` |
|
||||
| Admin routes | `{"detail": "..."}` |
|
||||
```json
|
||||
{
|
||||
"error": {
|
||||
"message": "...",
|
||||
"type": "invalid_request_error",
|
||||
"code": "invalid_request",
|
||||
"param": null
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Admin routes keep `{"detail":"..."}`.
|
||||
|
||||
Gemini routes use Google-style errors:
|
||||
|
||||
```json
|
||||
{
|
||||
"error": {
|
||||
"code": 400,
|
||||
"message": "invalid json",
|
||||
"status": "INVALID_ARGUMENT"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Clients should check the HTTP status code and then parse the `error` or `detail` field accordingly.
|
||||
|
||||
@@ -707,6 +958,31 @@ curl http://localhost:5001/v1/chat/completions \
|
||||
}'
|
||||
```
|
||||
|
||||
### OpenAI Responses (Stream)
|
||||
|
||||
```bash
|
||||
curl http://localhost:5001/v1/responses \
|
||||
-H "Authorization: Bearer your-api-key" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-5-codex",
|
||||
"input": "Write a hello world in golang",
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
### OpenAI Embeddings
|
||||
|
||||
```bash
|
||||
curl http://localhost:5001/v1/embeddings \
|
||||
-H "Authorization: Bearer your-api-key" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-4o",
|
||||
"input": ["first text", "second text"]
|
||||
}'
|
||||
```
|
||||
|
||||
### OpenAI with Search
|
||||
|
||||
```bash
|
||||
@@ -748,6 +1024,38 @@ curl http://localhost:5001/v1/chat/completions \
|
||||
}'
|
||||
```
|
||||
|
||||
### Gemini Non-Stream
|
||||
|
||||
```bash
|
||||
curl "http://localhost:5001/v1beta/models/gemini-2.5-pro:generateContent" \
|
||||
-H "Authorization: Bearer your-api-key" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"contents": [
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [{"text": "Introduce Go in three sentences"}]
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
### Gemini Stream
|
||||
|
||||
```bash
|
||||
curl "http://localhost:5001/v1beta/models/gemini-2.5-flash:streamGenerateContent" \
|
||||
-H "Authorization: Bearer your-api-key" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"contents": [
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [{"text": "Write a short summary"}]
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
### Claude Non-Stream
|
||||
|
||||
```bash
|
||||
|
||||
341
API.md
341
API.md
@@ -9,11 +9,13 @@
|
||||
## 目录
|
||||
|
||||
- [基础信息](#基础信息)
|
||||
- [配置最佳实践](#配置最佳实践)
|
||||
- [鉴权规则](#鉴权规则)
|
||||
- [路由总览](#路由总览)
|
||||
- [健康检查](#健康检查)
|
||||
- [OpenAI 兼容接口](#openai-兼容接口)
|
||||
- [Claude 兼容接口](#claude-兼容接口)
|
||||
- [Gemini 兼容接口](#gemini-兼容接口)
|
||||
- [Admin 接口](#admin-接口)
|
||||
- [错误响应格式](#错误响应格式)
|
||||
- [cURL 示例](#curl-示例)
|
||||
@@ -27,13 +29,35 @@
|
||||
| Base URL | `http://localhost:5001` 或你的部署域名 |
|
||||
| 默认 Content-Type | `application/json` |
|
||||
| 健康检查 | `GET /healthz`、`GET /readyz` |
|
||||
| CORS | 已启用(`Access-Control-Allow-Origin: *`,允许 `Content-Type`, `Authorization`) |
|
||||
| CORS | 已启用(`Access-Control-Allow-Origin: *`,允许 `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Vercel-Protection-Bypass`) |
|
||||
|
||||
---
|
||||
|
||||
## 配置最佳实践
|
||||
|
||||
推荐把 `config.json` 作为唯一配置源:
|
||||
|
||||
```bash
|
||||
cp config.example.json config.json
|
||||
# 编辑 config.json(keys/accounts)
|
||||
```
|
||||
|
||||
按部署方式使用:
|
||||
|
||||
- 本地运行:直接读取 `config.json`
|
||||
- Docker / Vercel:从 `config.json` 生成 Base64,填入 `DS2API_CONFIG_JSON`
|
||||
|
||||
```bash
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
```
|
||||
|
||||
Vercel 一键部署可先只填 `DS2API_ADMIN_KEY`,部署后在 `/admin` 导入配置,再通过 “Vercel 同步” 写回环境变量。
|
||||
|
||||
---
|
||||
|
||||
## 鉴权规则
|
||||
|
||||
### 业务接口(`/v1/*`、`/anthropic/*`)
|
||||
### 业务接口(`/v1/*`、`/anthropic/*`、`/v1beta/models/*`)
|
||||
|
||||
支持两种传参方式:
|
||||
|
||||
@@ -66,15 +90,32 @@
|
||||
| GET | `/healthz` | 无 | 存活探针 |
|
||||
| GET | `/readyz` | 无 | 就绪探针 |
|
||||
| GET | `/v1/models` | 无 | OpenAI 模型列表 |
|
||||
| GET | `/v1/models/{id}` | 无 | OpenAI 单模型查询(支持 alias 入参) |
|
||||
| POST | `/v1/chat/completions` | 业务 | OpenAI 对话补全 |
|
||||
| POST | `/v1/responses` | 业务 | OpenAI Responses 接口(流式/非流式) |
|
||||
| GET | `/v1/responses/{response_id}` | 业务 | 查询已生成 response(内存 TTL) |
|
||||
| POST | `/v1/embeddings` | 业务 | OpenAI Embeddings 接口 |
|
||||
| GET | `/anthropic/v1/models` | 无 | Claude 模型列表 |
|
||||
| POST | `/anthropic/v1/messages` | 业务 | Claude 消息接口 |
|
||||
| POST | `/anthropic/v1/messages/count_tokens` | 业务 | Claude token 计数 |
|
||||
| POST | `/v1/messages` | 业务 | Claude 消息快捷路径 |
|
||||
| POST | `/messages` | 业务 | Claude 消息快捷路径 |
|
||||
| POST | `/v1/messages/count_tokens` | 业务 | Claude token 计数快捷路径 |
|
||||
| POST | `/messages/count_tokens` | 业务 | Claude token 计数快捷路径 |
|
||||
| POST | `/v1beta/models/{model}:generateContent` | 业务 | Gemini 非流式 |
|
||||
| POST | `/v1beta/models/{model}:streamGenerateContent` | 业务 | Gemini 流式 |
|
||||
| POST | `/v1/models/{model}:generateContent` | 业务 | Gemini 非流式兼容路径 |
|
||||
| POST | `/v1/models/{model}:streamGenerateContent` | 业务 | Gemini 流式兼容路径 |
|
||||
| POST | `/admin/login` | 无 | 管理登录 |
|
||||
| GET | `/admin/verify` | JWT | 校验管理 JWT |
|
||||
| GET | `/admin/vercel/config` | Admin | 读取 Vercel 预配置 |
|
||||
| GET | `/admin/config` | Admin | 读取配置(脱敏) |
|
||||
| POST | `/admin/config` | Admin | 更新配置 |
|
||||
| GET | `/admin/settings` | Admin | 读取运行时设置 |
|
||||
| PUT | `/admin/settings` | Admin | 更新运行时设置(热更新) |
|
||||
| POST | `/admin/settings/password` | Admin | 更新 Admin 密码并使旧 JWT 失效 |
|
||||
| POST | `/admin/config/import` | Admin | 导入配置(merge/replace) |
|
||||
| GET | `/admin/config/export` | Admin | 导出完整配置(含 `config`/`json`/`base64`) |
|
||||
| POST | `/admin/keys` | Admin | 添加 API key |
|
||||
| DELETE | `/admin/keys/{key}` | Admin | 删除 API key |
|
||||
| GET | `/admin/accounts` | Admin | 分页账号列表 |
|
||||
@@ -88,6 +129,8 @@
|
||||
| POST | `/admin/vercel/sync` | Admin | 同步配置到 Vercel |
|
||||
| GET | `/admin/vercel/status` | Admin | Vercel 同步状态 |
|
||||
| GET | `/admin/export` | Admin | 导出配置 JSON/Base64 |
|
||||
| GET | `/admin/dev/captures` | Admin | 查看本地抓包记录 |
|
||||
| DELETE | `/admin/dev/captures` | Admin | 清空本地抓包记录 |
|
||||
|
||||
---
|
||||
|
||||
@@ -127,6 +170,15 @@
|
||||
}
|
||||
```
|
||||
|
||||
### 模型 alias 解析策略
|
||||
|
||||
对 `chat` / `responses` / `embeddings` 的 `model` 字段采用“宽进严出”:
|
||||
|
||||
1. 先匹配 DeepSeek 原生模型。
|
||||
2. 再匹配 `model_aliases` 精确映射。
|
||||
3. 未命中时按模型家族规则回退(如 `o*`、`gpt-*`、`claude-*`)。
|
||||
4. 仍未命中则返回 `invalid_request_error`。
|
||||
|
||||
### `POST /v1/chat/completions`
|
||||
|
||||
**请求头**:
|
||||
@@ -140,7 +192,7 @@ Content-Type: application/json
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
| --- | --- | --- | --- |
|
||||
| `model` | string | ✅ | `deepseek-chat` / `deepseek-reasoner` / `deepseek-chat-search` / `deepseek-reasoner-search` |
|
||||
| `model` | string | ✅ | 支持 DeepSeek 原生模型 + 常见 alias(如 `gpt-4o`、`gpt-5-codex`、`o3`、`claude-sonnet-4-5`) |
|
||||
| `messages` | array | ✅ | OpenAI 风格消息数组 |
|
||||
| `stream` | boolean | ❌ | 默认 `false` |
|
||||
| `tools` | array | ❌ | Function Calling 定义 |
|
||||
@@ -230,12 +282,95 @@ data: [DONE]
|
||||
}
|
||||
```
|
||||
|
||||
**流式**:先缓冲正文片段。识别到工具调用 → 仅输出结构化 `delta.tool_calls`(每个 tool call 带 `index`);否则一次性输出普通文本。
|
||||
**流式**:命中高置信特征后立即输出 `delta.tool_calls`(不等待完整 JSON 闭合),并持续发送 arguments 增量;已确认的 toolcall 原始 JSON 不会回流到 `delta.content`。
|
||||
|
||||
补充说明:
|
||||
|
||||
- **非代码块上下文**下,工具 JSON 即使与普通文本混合,也会按特征识别并产出可执行 tool call(前后普通文本仍可透传)。
|
||||
- Markdown fenced code block(例如 ```json ... ```)中的 `tool_calls` 仅视为示例文本,不会被执行。
|
||||
|
||||
---
|
||||
|
||||
### `GET /v1/models/{id}`
|
||||
|
||||
无需鉴权。入参支持 alias(例如 `gpt-4o`),返回的是映射后的 DeepSeek 模型对象。
|
||||
|
||||
### `POST /v1/responses`
|
||||
|
||||
OpenAI Responses 风格接口,兼容 `input` 或 `messages`。
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
| --- | --- | --- | --- |
|
||||
| `model` | string | ✅ | 支持原生模型 + alias 自动映射 |
|
||||
| `input` | string/array/object | ❌ | 与 `messages` 二选一 |
|
||||
| `messages` | array | ❌ | 与 `input` 二选一 |
|
||||
| `instructions` | string | ❌ | 自动前置为 system 消息 |
|
||||
| `stream` | boolean | ❌ | 默认 `false` |
|
||||
| `tools` | array | ❌ | 与 chat 同样的工具识别与转译策略(含代码块示例豁免) |
|
||||
| `tool_choice` | string/object | ❌ | 支持 `auto`/`none`/`required` 与强制函数(`{"type":"function","name":"..."}`) |
|
||||
|
||||
**非流式响应**:返回标准 `response` 对象,`id` 形如 `resp_xxx`,并写入内存 TTL 存储。
|
||||
当 `tool_choice=required` 且未产出有效工具调用时,返回 HTTP `422`(`error.code=tool_choice_violation`)。
|
||||
|
||||
**流式响应(SSE)**:最小事件序列如下。
|
||||
|
||||
```text
|
||||
event: response.created
|
||||
data: {"type":"response.created","id":"resp_xxx","status":"in_progress",...}
|
||||
|
||||
event: response.output_item.added
|
||||
data: {"type":"response.output_item.added","response_id":"resp_xxx","item":{"type":"message|function_call",...},...}
|
||||
|
||||
event: response.content_part.added
|
||||
data: {"type":"response.content_part.added","response_id":"resp_xxx","part":{"type":"output_text",...},...}
|
||||
|
||||
event: response.output_text.delta
|
||||
data: {"type":"response.output_text.delta","response_id":"resp_xxx","item_id":"msg_xxx","output_index":0,"content_index":0,"delta":"..."}
|
||||
|
||||
event: response.function_call_arguments.delta
|
||||
data: {"type":"response.function_call_arguments.delta","response_id":"resp_xxx","call_id":"call_xxx","delta":"..."}
|
||||
|
||||
event: response.function_call_arguments.done
|
||||
data: {"type":"response.function_call_arguments.done","response_id":"resp_xxx","call_id":"call_xxx","name":"tool","arguments":"{...}"}
|
||||
|
||||
event: response.content_part.done
|
||||
data: {"type":"response.content_part.done","response_id":"resp_xxx",...}
|
||||
|
||||
event: response.output_item.done
|
||||
data: {"type":"response.output_item.done","response_id":"resp_xxx","item":{"type":"message|function_call",...},...}
|
||||
|
||||
event: response.completed
|
||||
data: {"type":"response.completed","response":{...}}
|
||||
|
||||
data: [DONE]
|
||||
```
|
||||
|
||||
流式场景下若 `tool_choice=required` 违规,会返回 `response.failed` 后结束(不再发送 `response.completed`)。
|
||||
未在 `tools` 声明中的工具名会被严格拒绝,不会作为有效 tool call 下发。
|
||||
|
||||
### `GET /v1/responses/{response_id}`
|
||||
|
||||
需要业务鉴权。查询 `POST /v1/responses` 生成并缓存的 response 对象(按调用方鉴权隔离,仅同一 key/token 可读取)。
|
||||
|
||||
> 当前为内存 TTL 存储,默认过期时间 `900s`(可用 `responses.store_ttl_seconds` 调整)。
|
||||
|
||||
### `POST /v1/embeddings`
|
||||
|
||||
需要业务鉴权。返回 OpenAI Embeddings 兼容结构。
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
| --- | --- | --- | --- |
|
||||
| `model` | string | ✅ | 支持原生模型 + alias 自动映射 |
|
||||
| `input` | string/array | ✅ | 支持字符串、字符串数组、token 数组 |
|
||||
|
||||
> 需配置 `embeddings.provider`。当前支持:`mock` / `deterministic` / `builtin`。未配置或不支持时返回标准错误结构(HTTP 501)。
|
||||
|
||||
---
|
||||
|
||||
## Claude 兼容接口
|
||||
|
||||
除标准路径 `/anthropic/v1/*` 外,还支持快捷路径 `/v1/messages`、`/messages`、`/v1/messages/count_tokens`、`/messages/count_tokens`。
|
||||
|
||||
### `GET /anthropic/v1/models`
|
||||
|
||||
无需鉴权。
|
||||
@@ -249,7 +384,10 @@ data: [DONE]
|
||||
{"id": "claude-sonnet-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"},
|
||||
{"id": "claude-haiku-4-5", "object": "model", "created": 1715635200, "owned_by": "anthropic"},
|
||||
{"id": "claude-opus-4-6", "object": "model", "created": 1715635200, "owned_by": "anthropic"}
|
||||
]
|
||||
],
|
||||
"first_id": "claude-opus-4-6",
|
||||
"last_id": "claude-instant-1.0",
|
||||
"has_more": false
|
||||
}
|
||||
```
|
||||
|
||||
@@ -265,13 +403,15 @@ Content-Type: application/json
|
||||
anthropic-version: 2023-06-01
|
||||
```
|
||||
|
||||
> `anthropic-version` 可省略,服务端会自动补为 `2023-06-01`。
|
||||
|
||||
**请求体**:
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
| --- | --- | --- | --- |
|
||||
| `model` | string | ✅ | 例如 `claude-sonnet-4-5` / `claude-opus-4-6` / `claude-haiku-4-5`(兼容 `claude-3-5-haiku-latest`),并支持历史 Claude 模型 ID |
|
||||
| `messages` | array | ✅ | Claude 风格消息数组 |
|
||||
| `max_tokens` | number | ❌ | 当前实现不会硬性截断上游输出 |
|
||||
| `max_tokens` | number | ❌ | 缺省自动补 `8192`;当前实现不会硬性截断上游输出 |
|
||||
| `stream` | boolean | ❌ | 默认 `false` |
|
||||
| `system` | string | ❌ | 可选系统提示 |
|
||||
| `tools` | array | ❌ | Claude tool 定义 |
|
||||
@@ -354,6 +494,37 @@ data: {"type":"message_stop"}
|
||||
|
||||
---
|
||||
|
||||
## Gemini 兼容接口
|
||||
|
||||
支持路径:
|
||||
|
||||
- `/v1beta/models/{model}:generateContent`
|
||||
- `/v1beta/models/{model}:streamGenerateContent`
|
||||
- `/v1/models/{model}:generateContent`(兼容路径)
|
||||
- `/v1/models/{model}:streamGenerateContent`(兼容路径)
|
||||
|
||||
鉴权方式同业务接口(`Authorization: Bearer <token>` 或 `x-api-key`)。
|
||||
|
||||
### `POST /v1beta/models/{model}:generateContent`
|
||||
|
||||
请求体兼容 Gemini `contents` / `tools` 字段,模型名可用 alias 自动映射到 DeepSeek 模型。
|
||||
|
||||
响应为 Gemini 兼容结构,核心字段包括:
|
||||
|
||||
- `candidates[].content.parts[].text`
|
||||
- `candidates[].content.parts[].functionCall`(工具调用时)
|
||||
- `usageMetadata`(`promptTokenCount` / `candidatesTokenCount` / `totalTokenCount`)
|
||||
|
||||
### `POST /v1beta/models/{model}:streamGenerateContent`
|
||||
|
||||
返回 SSE(`text/event-stream`),每个 chunk 为一条 `data: <json>`:
|
||||
|
||||
- 常规文本:持续返回增量文本 chunk
|
||||
- `tools` 场景:会缓冲并在结束时输出 `functionCall` 结构
|
||||
- 结束 chunk:包含 `finishReason: "STOP"` 与 `usageMetadata`
|
||||
|
||||
---
|
||||
|
||||
## Admin 接口
|
||||
|
||||
### `POST /admin/login`
|
||||
@@ -416,6 +587,7 @@ data: {"type":"message_stop"}
|
||||
"keys": ["k1", "k2"],
|
||||
"accounts": [
|
||||
{
|
||||
"identifier": "user@example.com",
|
||||
"email": "user@example.com",
|
||||
"mobile": "",
|
||||
"has_password": true,
|
||||
@@ -449,6 +621,53 @@ data: {"type":"message_stop"}
|
||||
}
|
||||
```
|
||||
|
||||
### `GET /admin/settings`
|
||||
|
||||
读取运行时设置与状态,返回:
|
||||
|
||||
- `admin`(JWT 过期、默认密码告警等)
|
||||
- `runtime`(`account_max_inflight`、`account_max_queue`、`global_max_inflight`)
|
||||
- `toolcall` / `responses` / `embeddings`
|
||||
- `auto_delete`(`sessions`)
|
||||
- `claude_mapping` / `model_aliases`
|
||||
- `env_backed`、`needs_vercel_sync`
|
||||
|
||||
### `PUT /admin/settings`
|
||||
|
||||
热更新运行时设置。支持更新:
|
||||
|
||||
- `admin.jwt_expire_hours`
|
||||
- `runtime.account_max_inflight` / `runtime.account_max_queue` / `runtime.global_max_inflight`
|
||||
- `toolcall.mode` / `toolcall.early_emit_confidence`
|
||||
- `responses.store_ttl_seconds`
|
||||
- `embeddings.provider`
|
||||
- `auto_delete.sessions`
|
||||
- `claude_mapping`
|
||||
- `model_aliases`
|
||||
|
||||
### `POST /admin/settings/password`
|
||||
|
||||
更新管理密码并使旧 JWT 失效。
|
||||
|
||||
请求示例:
|
||||
|
||||
```json
|
||||
{"new_password":"your-new-password"}
|
||||
```
|
||||
|
||||
### `POST /admin/config/import`
|
||||
|
||||
导入完整配置,支持:
|
||||
|
||||
- `mode=merge`(默认)
|
||||
- `mode=replace`
|
||||
|
||||
请求可直接传配置对象,或使用 `{"config": {...}, "mode":"merge"}` 包裹格式。
|
||||
|
||||
### `GET /admin/config/export`
|
||||
|
||||
导出完整配置,返回 `config`、`json`、`base64` 三种格式。
|
||||
|
||||
### `POST /admin/keys`
|
||||
|
||||
```json
|
||||
@@ -476,6 +695,7 @@ data: {"type":"message_stop"}
|
||||
{
|
||||
"items": [
|
||||
{
|
||||
"identifier": "user@example.com",
|
||||
"email": "user@example.com",
|
||||
"mobile": "",
|
||||
"has_password": true,
|
||||
@@ -500,7 +720,7 @@ data: {"type":"message_stop"}
|
||||
|
||||
### `DELETE /admin/accounts/{identifier}`
|
||||
|
||||
`identifier` 为 email 或 mobile。
|
||||
`identifier` 可为 email、mobile,或 token-only 账号的合成标识(`token:<hash>`)。
|
||||
|
||||
**响应**:`{"success": true, "total_accounts": 5}`
|
||||
|
||||
@@ -530,7 +750,7 @@ data: {"type":"message_stop"}
|
||||
|
||||
| 字段 | 必填 | 说明 |
|
||||
| --- | --- | --- |
|
||||
| `identifier` | ✅ | email 或 mobile |
|
||||
| `identifier` | ✅ | email / mobile / token-only 合成标识 |
|
||||
| `model` | ❌ | 默认 `deepseek-chat` |
|
||||
| `message` | ❌ | 空字符串时仅测试会话创建 |
|
||||
|
||||
@@ -655,17 +875,53 @@ data: {"type":"message_stop"}
|
||||
}
|
||||
```
|
||||
|
||||
### `GET /admin/dev/captures`
|
||||
|
||||
查看本地抓包状态与最近记录(需 Admin 鉴权):
|
||||
|
||||
- `enabled`
|
||||
- `limit`
|
||||
- `max_body_bytes`
|
||||
- `items`
|
||||
|
||||
### `DELETE /admin/dev/captures`
|
||||
|
||||
清空抓包记录,返回:
|
||||
|
||||
```json
|
||||
{"success":true,"detail":"capture logs cleared"}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 错误响应格式
|
||||
|
||||
不同模块的错误格式略有差异:
|
||||
兼容路由(`/v1/*`、`/anthropic/*`)统一使用以下结构:
|
||||
|
||||
| 模块 | 格式 |
|
||||
| --- | --- |
|
||||
| OpenAI 接口 | `{"error": {"message": "...", "type": "..."}}` |
|
||||
| Claude 接口 | `{"error": {"type": "...", "message": "..."}}` |
|
||||
| Admin 接口 | `{"detail": "..."}` |
|
||||
```json
|
||||
{
|
||||
"error": {
|
||||
"message": "...",
|
||||
"type": "invalid_request_error",
|
||||
"code": "invalid_request",
|
||||
"param": null
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Admin 接口保持 `{"detail":"..."}`。
|
||||
|
||||
Gemini 路由使用 Google 风格错误结构:
|
||||
|
||||
```json
|
||||
{
|
||||
"error": {
|
||||
"code": 400,
|
||||
"message": "invalid json",
|
||||
"status": "INVALID_ARGUMENT"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
建议客户端处理逻辑:检查 HTTP 状态码 + 解析 `error` 或 `detail` 字段。
|
||||
|
||||
@@ -707,6 +963,31 @@ curl http://localhost:5001/v1/chat/completions \
|
||||
}'
|
||||
```
|
||||
|
||||
### OpenAI Responses(流式)
|
||||
|
||||
```bash
|
||||
curl http://localhost:5001/v1/responses \
|
||||
-H "Authorization: Bearer your-api-key" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-5-codex",
|
||||
"input": "写一个 golang 的 hello world",
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
### OpenAI Embeddings
|
||||
|
||||
```bash
|
||||
curl http://localhost:5001/v1/embeddings \
|
||||
-H "Authorization: Bearer your-api-key" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-4o",
|
||||
"input": ["第一段文本", "第二段文本"]
|
||||
}'
|
||||
```
|
||||
|
||||
### OpenAI 带搜索
|
||||
|
||||
```bash
|
||||
@@ -748,6 +1029,38 @@ curl http://localhost:5001/v1/chat/completions \
|
||||
}'
|
||||
```
|
||||
|
||||
### Gemini 非流式
|
||||
|
||||
```bash
|
||||
curl "http://localhost:5001/v1beta/models/gemini-2.5-pro:generateContent" \
|
||||
-H "Authorization: Bearer your-api-key" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"contents": [
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [{"text": "用三句话介绍 Go 语言"}]
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
### Gemini 流式
|
||||
|
||||
```bash
|
||||
curl "http://localhost:5001/v1beta/models/gemini-2.5-flash:streamGenerateContent" \
|
||||
-H "Authorization: Bearer your-api-key" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"contents": [
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [{"text": "写一个简短摘要"}]
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
### Claude 非流式
|
||||
|
||||
```bash
|
||||
|
||||
@@ -82,11 +82,11 @@ Manually build WebUI to `static/admin/`:
|
||||
## Running Tests
|
||||
|
||||
```bash
|
||||
# Go unit tests
|
||||
go test ./...
|
||||
# Go + Node unit tests (recommended)
|
||||
./tests/scripts/run-unit-all.sh
|
||||
|
||||
# End-to-end live tests (real accounts)
|
||||
./scripts/testsuite/run-live.sh
|
||||
./tests/scripts/run-live.sh
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
@@ -99,24 +99,34 @@ ds2api/
|
||||
├── api/
|
||||
│ ├── index.go # Vercel Serverless Go entry
|
||||
│ ├── chat-stream.js # Vercel Node.js stream relay
|
||||
│ └── helpers/ # Node.js helper modules
|
||||
│ └── (rewrite targets in vercel.json)
|
||||
├── internal/
|
||||
│ ├── account/ # Account pool and concurrency queue
|
||||
│ ├── adapter/
|
||||
│ │ ├── openai/ # OpenAI adapter
|
||||
│ │ └── claude/ # Claude adapter
|
||||
│ │ ├── claude/ # Claude adapter
|
||||
│ │ └── gemini/ # Gemini adapter
|
||||
│ ├── admin/ # Admin API handlers
|
||||
│ ├── auth/ # Auth and JWT
|
||||
│ ├── claudeconv/ # Claude message conversion
|
||||
│ ├── compat/ # Compatibility helpers
|
||||
│ ├── config/ # Config loading and hot-reload
|
||||
│ ├── deepseek/ # DeepSeek client, PoW WASM
|
||||
│ ├── js/ # Node runtime stream/compat logic
|
||||
│ ├── devcapture/ # Dev packet capture
|
||||
│ ├── format/ # Output formatting
|
||||
│ ├── prompt/ # Prompt building
|
||||
│ ├── server/ # HTTP routing (chi router)
|
||||
│ ├── sse/ # SSE parsing utilities
|
||||
│ ├── stream/ # Unified stream consumption engine
|
||||
│ ├── testsuite/ # Testsuite core logic
|
||||
│ ├── util/ # Common utilities
|
||||
│ └── webui/ # WebUI static hosting
|
||||
├── webui/ # React WebUI source
|
||||
│ └── src/
|
||||
│ ├── components/ # Components
|
||||
│ ├── app/ # Routing, auth, config state
|
||||
│ ├── features/ # Feature modules
|
||||
│ ├── components/ # Shared components
|
||||
│ └── locales/ # Language packs
|
||||
├── scripts/ # Build and test scripts
|
||||
├── static/admin/ # WebUI build output (not committed)
|
||||
|
||||
@@ -82,11 +82,11 @@ docker-compose -f docker-compose.dev.yml up
|
||||
## 运行测试
|
||||
|
||||
```bash
|
||||
# Go 单元测试
|
||||
go test ./...
|
||||
# Go + Node 单元测试(推荐)
|
||||
./tests/scripts/run-unit-all.sh
|
||||
|
||||
# 端到端全链路测试(真实账号)
|
||||
./scripts/testsuite/run-live.sh
|
||||
./tests/scripts/run-live.sh
|
||||
```
|
||||
|
||||
## 项目结构
|
||||
@@ -99,24 +99,34 @@ ds2api/
|
||||
├── api/
|
||||
│ ├── index.go # Vercel Serverless Go 入口
|
||||
│ ├── chat-stream.js # Vercel Node.js 流式转发
|
||||
│ └── helpers/ # Node.js 辅助模块
|
||||
│ └── (rewrite targets in vercel.json)
|
||||
├── internal/
|
||||
│ ├── account/ # 账号池与并发队列
|
||||
│ ├── adapter/
|
||||
│ │ ├── openai/ # OpenAI 兼容适配器
|
||||
│ │ └── claude/ # Claude 兼容适配器
|
||||
│ │ ├── claude/ # Claude 兼容适配器
|
||||
│ │ └── gemini/ # Gemini 兼容适配器
|
||||
│ ├── admin/ # Admin API handlers
|
||||
│ ├── auth/ # 鉴权与 JWT
|
||||
│ ├── claudeconv/ # Claude 消息格式转换
|
||||
│ ├── compat/ # 兼容性辅助
|
||||
│ ├── config/ # 配置加载与热更新
|
||||
│ ├── deepseek/ # DeepSeek 客户端、PoW WASM
|
||||
│ ├── js/ # Node 运行时流式/兼容逻辑
|
||||
│ ├── devcapture/ # 开发抓包
|
||||
│ ├── format/ # 输出格式化
|
||||
│ ├── prompt/ # Prompt 构建
|
||||
│ ├── server/ # HTTP 路由(chi router)
|
||||
│ ├── sse/ # SSE 解析工具
|
||||
│ ├── stream/ # 统一流式消费引擎
|
||||
│ ├── testsuite/ # 测试集核心逻辑
|
||||
│ ├── util/ # 通用工具
|
||||
│ └── webui/ # WebUI 静态托管
|
||||
├── webui/ # React WebUI 源码
|
||||
│ └── src/
|
||||
│ ├── components/ # 组件
|
||||
│ ├── app/ # 路由、鉴权、配置状态
|
||||
│ ├── features/ # 业务功能模块
|
||||
│ ├── components/ # 通用组件
|
||||
│ └── locales/ # 语言包
|
||||
├── scripts/ # 构建与测试脚本
|
||||
├── static/admin/ # WebUI 构建产物(不提交)
|
||||
|
||||
106
DEPLOY.en.md
106
DEPLOY.en.md
@@ -33,6 +33,17 @@ Config source (choose one):
|
||||
- **File**: `config.json` (recommended for local/Docker)
|
||||
- **Environment variable**: `DS2API_CONFIG_JSON` (recommended for Vercel; supports raw JSON or Base64)
|
||||
|
||||
Unified recommendation (best practice):
|
||||
|
||||
```bash
|
||||
cp config.example.json config.json
|
||||
# Edit config.json
|
||||
```
|
||||
|
||||
Use `config.json` as the single source of truth:
|
||||
- Local run: read `config.json` directly
|
||||
- Docker / Vercel: generate `DS2API_CONFIG_JSON` (Base64) from `config.json` and inject it
|
||||
|
||||
---
|
||||
|
||||
## 1. Local Run
|
||||
@@ -99,11 +110,11 @@ go build -o ds2api ./cmd/ds2api
|
||||
### 2.1 Basic Steps
|
||||
|
||||
```bash
|
||||
# Copy and edit environment
|
||||
# Copy env template
|
||||
cp .env.example .env
|
||||
# Edit .env, at minimum set:
|
||||
|
||||
# Edit .env and set at least:
|
||||
# DS2API_ADMIN_KEY=your-admin-key
|
||||
# DS2API_CONFIG_JSON={"keys":[...],"accounts":[...]}
|
||||
|
||||
# Start
|
||||
docker-compose up -d
|
||||
@@ -120,11 +131,12 @@ docker-compose up -d --build
|
||||
|
||||
### 2.3 Docker Architecture
|
||||
|
||||
The `Dockerfile` uses a three-stage build:
|
||||
The `Dockerfile` now provides two image paths:
|
||||
|
||||
1. **WebUI build stage**: `node:20` image, runs `npm ci && npm run build`
|
||||
2. **Go build stage**: `golang:1.24` image, compiles the binary
|
||||
3. **Runtime stage**: `debian:bookworm-slim` minimal image
|
||||
1. **Default local/dev path (`runtime-from-source`)**: a three-stage build (WebUI build + Go build + runtime).
|
||||
2. **Release path (`runtime-from-dist`)**: CI first creates `dist/ds2api_<tag>_linux_<arch>.tar.gz`, then Docker directly reuses the binary and `static/admin` assets from those release archives, without running `npm build`/`go build` again.
|
||||
|
||||
The release path keeps Docker images aligned with release archives and reduces duplicate build work.
|
||||
|
||||
Container entry command: `/usr/local/bin/ds2api`, default exposed port: `5001`.
|
||||
|
||||
@@ -145,7 +157,7 @@ Docker Compose includes a built-in health check:
|
||||
|
||||
```yaml
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "-qO-", "http://localhost:${PORT:-5001}/healthz"]
|
||||
test: ["CMD", "/usr/local/bin/busybox", "wget", "-qO-", "http://localhost:${PORT:-5001}/healthz"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
@@ -159,6 +171,19 @@ If container logs look normal but the admin panel is unreachable, check these fi
|
||||
1. **Port alignment**: when `PORT` is not `5001`, use the same port in your URL (for example `http://localhost:8080/admin`).
|
||||
2. **WebUI assets in dev compose**: `docker-compose.dev.yml` runs `go run` in a dev image and does not auto-install Node.js inside the container; if `static/admin` is missing in your repo, `/admin` will return 404. Build once on host: `./scripts/build-webui.sh`.
|
||||
|
||||
### 2.7 Zeabur One-Click (Dockerfile)
|
||||
|
||||
This repo includes a `zeabur.yaml` template for one-click deployment on Zeabur:
|
||||
|
||||
[](https://zeabur.com/templates/L4CFHP)
|
||||
|
||||
Notes:
|
||||
|
||||
- **Port**: DS2API listens on `5001` by default; the template sets `PORT=5001`.
|
||||
- **Persistent config**: the template mounts `/data` and sets `DS2API_CONFIG_PATH=/data/config.json`. After importing config in Admin UI, it will be written and persisted to this path.
|
||||
- **Build version**: Zeabur / regular `docker build` does not require `BUILD_VERSION` by default. The image prefers that build arg when provided, and automatically falls back to the repo-root `VERSION` file when it is absent.
|
||||
- **First login**: after deployment, open `/admin` and login with `DS2API_ADMIN_KEY` shown in Zeabur env/template instructions (recommended: rotate to a strong secret after first login).
|
||||
|
||||
---
|
||||
|
||||
## 3. Vercel Deployment
|
||||
@@ -167,15 +192,49 @@ If container logs look normal but the admin panel is unreachable, check these fi
|
||||
|
||||
1. **Fork** the repo to your GitHub account
|
||||
2. **Import** the project on Vercel
|
||||
3. **Set environment variables** (at minimum):
|
||||
3. **Set environment variables** (minimum required: one variable):
|
||||
|
||||
| Variable | Description |
|
||||
| --- | --- |
|
||||
| `DS2API_ADMIN_KEY` | Admin key (required) |
|
||||
| `DS2API_CONFIG_JSON` | Config content, raw JSON or Base64 (required) |
|
||||
| `DS2API_CONFIG_JSON` | Config content, raw JSON or Base64 (optional, recommended) |
|
||||
|
||||
4. **Deploy**
|
||||
|
||||
### 3.1.1 Recommended Input (avoid `DS2API_CONFIG_JSON` mistakes)
|
||||
|
||||
If you prefer faster one-click bootstrap, you can leave `DS2API_CONFIG_JSON` empty first, then open `/admin` after deployment, import config, and sync it back to Vercel env vars from the "Vercel Sync" page.
|
||||
|
||||
Recommended: in repo root, copy the template first and fill your real accounts:
|
||||
|
||||
```bash
|
||||
cp config.example.json config.json
|
||||
# Edit config.json
|
||||
```
|
||||
|
||||
Do not hand-edit large JSON directly in Vercel. Generate Base64 locally and paste it:
|
||||
|
||||
```bash
|
||||
# Run in repo root
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
echo "$DS2API_CONFIG_JSON"
|
||||
```
|
||||
|
||||
If you choose to preconfigure before first deploy, set these vars in Vercel Project Settings -> Environment Variables:
|
||||
|
||||
```text
|
||||
DS2API_ADMIN_KEY=replace-with-a-strong-secret
|
||||
DS2API_CONFIG_JSON=<the single-line Base64 output above>
|
||||
```
|
||||
|
||||
Optional but recommended (for WebUI one-click Vercel sync):
|
||||
|
||||
```text
|
||||
VERCEL_TOKEN=your-vercel-token
|
||||
VERCEL_PROJECT_ID=prj_xxxxxxxxxxxx
|
||||
VERCEL_TEAM_ID=team_xxxxxxxxxxxx # optional for personal accounts
|
||||
```
|
||||
|
||||
### 3.2 Optional Environment Variables
|
||||
|
||||
| Variable | Description | Default |
|
||||
@@ -184,6 +243,8 @@ If container logs look normal but the admin panel is unreachable, check these fi
|
||||
| `DS2API_ACCOUNT_CONCURRENCY` | Alias (legacy compat) | — |
|
||||
| `DS2API_ACCOUNT_MAX_QUEUE` | Waiting queue limit | `recommended_concurrency` |
|
||||
| `DS2API_ACCOUNT_QUEUE_SIZE` | Alias (legacy compat) | — |
|
||||
| `DS2API_GLOBAL_MAX_INFLIGHT` | Global inflight limit | `recommended_concurrency` |
|
||||
| `DS2API_MAX_INFLIGHT` | Alias (legacy compat) | — |
|
||||
| `DS2API_VERCEL_INTERNAL_SECRET` | Hybrid streaming internal auth | Falls back to `DS2API_ADMIN_KEY` |
|
||||
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | Stream lease TTL | `900` |
|
||||
| `VERCEL_TOKEN` | Vercel sync token | — |
|
||||
@@ -290,6 +351,7 @@ Built-in GitHub Actions workflow: `.github/workflows/release-artifacts.yml`
|
||||
|
||||
- **Trigger**: only on Release `published` (no build on normal push)
|
||||
- **Outputs**: multi-platform binary archives + `sha256sums.txt`
|
||||
- **Container publishing**: GHCR only (`ghcr.io/cjackhwang/ds2api`)
|
||||
|
||||
| Platform | Architecture | Format |
|
||||
| --- | --- | --- |
|
||||
@@ -301,7 +363,7 @@ Each archive includes:
|
||||
|
||||
- `ds2api` executable (`ds2api.exe` on Windows)
|
||||
- `static/admin/` (built WebUI assets)
|
||||
- `sha3_wasm_bg.7b9ca65ddd.wasm`
|
||||
- `sha3_wasm_bg.7b9ca65ddd.wasm` (optional; binary has embedded fallback)
|
||||
- `config.example.json`, `.env.example`
|
||||
- `README.MD`, `README.en.md`, `LICENSE`
|
||||
|
||||
@@ -310,8 +372,8 @@ Each archive includes:
|
||||
```bash
|
||||
# 1. Download the archive for your platform
|
||||
# 2. Extract
|
||||
tar -xzf ds2api_v1.7.0_linux_amd64.tar.gz
|
||||
cd ds2api_v1.7.0_linux_amd64
|
||||
tar -xzf ds2api_<tag>_linux_amd64.tar.gz
|
||||
cd ds2api_<tag>_linux_amd64
|
||||
|
||||
# 3. Configure
|
||||
cp config.example.json config.json
|
||||
@@ -323,10 +385,20 @@ cp config.example.json config.json
|
||||
|
||||
### Maintainer Release Flow
|
||||
|
||||
1. Create and publish a GitHub Release (with tag, e.g. `v1.7.0`)
|
||||
1. Create and publish a GitHub Release (with tag, for example `vX.Y.Z`)
|
||||
2. Wait for the `Release Artifacts` workflow to complete
|
||||
3. Download the matching archive from Release Assets
|
||||
|
||||
### Pull from GHCR (Optional)
|
||||
|
||||
```bash
|
||||
# latest
|
||||
docker pull ghcr.io/cjackhwang/ds2api:latest
|
||||
|
||||
# specific version (example)
|
||||
docker pull ghcr.io/cjackhwang/ds2api:v2.1.2
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. Reverse Proxy (Nginx)
|
||||
@@ -380,7 +452,9 @@ server {
|
||||
```bash
|
||||
# Copy compiled binary and related files to target directory
|
||||
sudo mkdir -p /opt/ds2api
|
||||
sudo cp ds2api config.json sha3_wasm_bg.7b9ca65ddd.wasm /opt/ds2api/
|
||||
sudo cp ds2api config.json /opt/ds2api/
|
||||
# Optional: if you want to use an external WASM file (override embedded one)
|
||||
# sudo cp sha3_wasm_bg.7b9ca65ddd.wasm /opt/ds2api/
|
||||
sudo cp -r static/admin /opt/ds2api/static/admin
|
||||
```
|
||||
|
||||
@@ -469,7 +543,7 @@ curl http://127.0.0.1:5001/v1/chat/completions \
|
||||
Run the full live testsuite before release (real account tests):
|
||||
|
||||
```bash
|
||||
./scripts/testsuite/run-live.sh
|
||||
./tests/scripts/run-live.sh
|
||||
```
|
||||
|
||||
With custom flags:
|
||||
|
||||
106
DEPLOY.md
106
DEPLOY.md
@@ -33,6 +33,17 @@
|
||||
- **文件方式**:`config.json`(推荐本地/Docker 使用)
|
||||
- **环境变量方式**:`DS2API_CONFIG_JSON`(推荐 Vercel 使用,支持 JSON 字符串或 Base64 编码)
|
||||
|
||||
统一建议(最优实践):
|
||||
|
||||
```bash
|
||||
cp config.example.json config.json
|
||||
# 编辑 config.json
|
||||
```
|
||||
|
||||
建议把 `config.json` 作为唯一配置源:
|
||||
- 本地运行:直接读 `config.json`
|
||||
- Docker / Vercel:从 `config.json` 生成 `DS2API_CONFIG_JSON`(Base64)注入环境变量
|
||||
|
||||
---
|
||||
|
||||
## 一、本地运行
|
||||
@@ -99,11 +110,11 @@ go build -o ds2api ./cmd/ds2api
|
||||
### 2.1 基本步骤
|
||||
|
||||
```bash
|
||||
# 复制并编辑环境变量
|
||||
# 复制环境变量模板
|
||||
cp .env.example .env
|
||||
# 编辑 .env,至少设置:
|
||||
|
||||
# 编辑 .env(请改成你的强密码),至少设置:
|
||||
# DS2API_ADMIN_KEY=your-admin-key
|
||||
# DS2API_CONFIG_JSON={"keys":[...],"accounts":[...]}
|
||||
|
||||
# 启动
|
||||
docker-compose up -d
|
||||
@@ -120,11 +131,12 @@ docker-compose up -d --build
|
||||
|
||||
### 2.3 Docker 架构说明
|
||||
|
||||
`Dockerfile` 使用三阶段构建:
|
||||
`Dockerfile` 提供两条构建路径:
|
||||
|
||||
1. **WebUI 构建阶段**:`node:20` 镜像,执行 `npm ci && npm run build`
|
||||
2. **Go 构建阶段**:`golang:1.24` 镜像,编译二进制文件
|
||||
3. **运行阶段**:`debian:bookworm-slim` 精简镜像
|
||||
1. **本地/开发默认路径(`runtime-from-source`)**:三阶段构建(WebUI 构建 + Go 构建 + 运行阶段)。
|
||||
2. **Release 路径(`runtime-from-dist`)**:CI 先生成 `dist/ds2api_<tag>_linux_<arch>.tar.gz`,再由 Docker 直接复用该发布包内的二进制和 `static/admin` 产物组装运行镜像,不再重复执行 `npm build`/`go build`。
|
||||
|
||||
Release 路径可确保 Docker 镜像与 release 压缩包使用同一套产物,减少重复构建带来的差异。
|
||||
|
||||
容器内启动命令:`/usr/local/bin/ds2api`,默认暴露端口 `5001`。
|
||||
|
||||
@@ -145,7 +157,7 @@ Docker Compose 已配置内置健康检查:
|
||||
|
||||
```yaml
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "-qO-", "http://localhost:${PORT:-5001}/healthz"]
|
||||
test: ["CMD", "/usr/local/bin/busybox", "wget", "-qO-", "http://localhost:${PORT:-5001}/healthz"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
@@ -159,6 +171,19 @@ healthcheck:
|
||||
1. **端口是否一致**:`PORT` 改成非 `5001` 时,访问地址也要改成对应端口(如 `http://localhost:8080/admin`)。
|
||||
2. **开发 compose 的 WebUI 静态文件**:`docker-compose.dev.yml` 使用 `go run` 开发镜像,不会在容器内自动安装 Node.js;若仓库里没有 `static/admin`,`/admin` 会返回 404。可先在宿主机构建一次:`./scripts/build-webui.sh`。
|
||||
|
||||
### 2.7 Zeabur 一键部署(Dockerfile)
|
||||
|
||||
仓库提供 `zeabur.yaml` 模板,可在 Zeabur 上一键部署:
|
||||
|
||||
[](https://zeabur.com/templates/L4CFHP)
|
||||
|
||||
部署要点:
|
||||
|
||||
- **端口**:服务默认监听 `5001`,模板会固定设置 `PORT=5001`。
|
||||
- **配置持久化**:模板挂载卷 `/data`,并设置 `DS2API_CONFIG_PATH=/data/config.json`;在管理台导入配置后,会写入并持久化到该路径。
|
||||
- **构建版本号**:Zeabur / 普通 `docker build` 默认不需要传 `BUILD_VERSION`;镜像会优先使用该构建参数,未提供时自动回退到仓库根目录的 `VERSION` 文件。
|
||||
- **首次登录**:部署完成后访问 `/admin`,使用 Zeabur 环境变量/模板指引中的 `DS2API_ADMIN_KEY` 登录(建议首次登录后自行更换为强密码)。
|
||||
|
||||
---
|
||||
|
||||
## 三、Vercel 部署
|
||||
@@ -167,15 +192,49 @@ healthcheck:
|
||||
|
||||
1. **Fork 仓库**到你的 GitHub 账号
|
||||
2. **在 Vercel 上导入项目**
|
||||
3. **配置环境变量**(至少设置以下两项):
|
||||
3. **配置环境变量**(最少只需设置以下一项):
|
||||
|
||||
| 变量 | 说明 |
|
||||
| --- | --- |
|
||||
| `DS2API_ADMIN_KEY` | 管理密钥(必填) |
|
||||
| `DS2API_CONFIG_JSON` | 配置内容,JSON 字符串或 Base64 编码(必填) |
|
||||
| `DS2API_CONFIG_JSON` | 配置内容,JSON 字符串或 Base64 编码(可选,建议) |
|
||||
|
||||
4. **部署**
|
||||
|
||||
### 3.1.1 推荐填写方式(避免 `DS2API_CONFIG_JSON` 填错)
|
||||
|
||||
如果你想先完成一键部署,也可以先不填 `DS2API_CONFIG_JSON`,部署后进入 `/admin` 导入配置,再在「Vercel 同步」里写回环境变量。
|
||||
|
||||
建议先在仓库目录复制示例配置,再按实际账号填写:
|
||||
|
||||
```bash
|
||||
cp config.example.json config.json
|
||||
# 编辑 config.json
|
||||
```
|
||||
|
||||
不要在 Vercel 面板里手写复杂 JSON,建议本地生成 Base64 后粘贴:
|
||||
|
||||
```bash
|
||||
# 在仓库根目录执行
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
echo "$DS2API_CONFIG_JSON"
|
||||
```
|
||||
|
||||
如果你选择在部署前就预置配置,请在 Vercel Project Settings -> Environment Variables 配置:
|
||||
|
||||
```text
|
||||
DS2API_ADMIN_KEY=请替换为强密码
|
||||
DS2API_CONFIG_JSON=上一步生成的一整行 Base64
|
||||
```
|
||||
|
||||
可选但推荐(用于 WebUI 一键同步 Vercel 配置):
|
||||
|
||||
```text
|
||||
VERCEL_TOKEN=你的 Vercel Token
|
||||
VERCEL_PROJECT_ID=prj_xxxxxxxxxxxx
|
||||
VERCEL_TEAM_ID=team_xxxxxxxxxxxx # 个人账号可留空
|
||||
```
|
||||
|
||||
### 3.2 可选环境变量
|
||||
|
||||
| 变量 | 说明 | 默认值 |
|
||||
@@ -184,6 +243,8 @@ healthcheck:
|
||||
| `DS2API_ACCOUNT_CONCURRENCY` | 同上(兼容别名) | — |
|
||||
| `DS2API_ACCOUNT_MAX_QUEUE` | 等待队列上限 | `recommended_concurrency` |
|
||||
| `DS2API_ACCOUNT_QUEUE_SIZE` | 同上(兼容别名) | — |
|
||||
| `DS2API_GLOBAL_MAX_INFLIGHT` | 全局并发上限 | `recommended_concurrency` |
|
||||
| `DS2API_MAX_INFLIGHT` | 同上(兼容别名) | — |
|
||||
| `DS2API_VERCEL_INTERNAL_SECRET` | 混合流式内部鉴权 | 回退用 `DS2API_ADMIN_KEY` |
|
||||
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | 流式 lease TTL | `900` |
|
||||
| `VERCEL_TOKEN` | Vercel 同步 token | — |
|
||||
@@ -290,6 +351,7 @@ No Output Directory named "public" found after the Build completed.
|
||||
|
||||
- **触发条件**:仅在 Release `published` 时触发(普通 push 不会构建)
|
||||
- **构建产物**:多平台二进制压缩包 + `sha256sums.txt`
|
||||
- **容器镜像发布**:仅发布到 GHCR(`ghcr.io/cjackhwang/ds2api`)
|
||||
|
||||
| 平台 | 架构 | 文件格式 |
|
||||
| --- | --- | --- |
|
||||
@@ -301,7 +363,7 @@ No Output Directory named "public" found after the Build completed.
|
||||
|
||||
- `ds2api` 可执行文件(Windows 为 `ds2api.exe`)
|
||||
- `static/admin/`(WebUI 构建产物)
|
||||
- `sha3_wasm_bg.7b9ca65ddd.wasm`
|
||||
- `sha3_wasm_bg.7b9ca65ddd.wasm`(可选;程序内置 embed fallback)
|
||||
- `config.example.json`、`.env.example`
|
||||
- `README.MD`、`README.en.md`、`LICENSE`
|
||||
|
||||
@@ -310,8 +372,8 @@ No Output Directory named "public" found after the Build completed.
|
||||
```bash
|
||||
# 1. 下载对应平台的压缩包
|
||||
# 2. 解压
|
||||
tar -xzf ds2api_v1.7.0_linux_amd64.tar.gz
|
||||
cd ds2api_v1.7.0_linux_amd64
|
||||
tar -xzf ds2api_<tag>_linux_amd64.tar.gz
|
||||
cd ds2api_<tag>_linux_amd64
|
||||
|
||||
# 3. 配置
|
||||
cp config.example.json config.json
|
||||
@@ -323,10 +385,20 @@ cp config.example.json config.json
|
||||
|
||||
### 维护者发布步骤
|
||||
|
||||
1. 在 GitHub 创建并发布 Release(带 tag,如 `v1.7.0`)
|
||||
1. 在 GitHub 创建并发布 Release(带 tag,如 `vX.Y.Z`)
|
||||
2. 等待 Actions 工作流 `Release Artifacts` 完成
|
||||
3. 在 Release 的 Assets 下载对应平台压缩包
|
||||
|
||||
### 拉取 GHCR 镜像(可选)
|
||||
|
||||
```bash
|
||||
# latest
|
||||
docker pull ghcr.io/cjackhwang/ds2api:latest
|
||||
|
||||
# 指定版本(示例)
|
||||
docker pull ghcr.io/cjackhwang/ds2api:v2.1.2
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 五、反向代理(Nginx)
|
||||
@@ -380,7 +452,9 @@ server {
|
||||
```bash
|
||||
# 将编译好的二进制文件和相关文件复制到目标目录
|
||||
sudo mkdir -p /opt/ds2api
|
||||
sudo cp ds2api config.json sha3_wasm_bg.7b9ca65ddd.wasm /opt/ds2api/
|
||||
sudo cp ds2api config.json /opt/ds2api/
|
||||
# 可选:若你希望使用外置 WASM 文件(覆盖内置版本)
|
||||
# sudo cp sha3_wasm_bg.7b9ca65ddd.wasm /opt/ds2api/
|
||||
sudo cp -r static/admin /opt/ds2api/static/admin
|
||||
```
|
||||
|
||||
@@ -469,7 +543,7 @@ curl http://127.0.0.1:5001/v1/chat/completions \
|
||||
建议在发布前执行完整的端到端测试集(使用真实账号):
|
||||
|
||||
```bash
|
||||
./scripts/testsuite/run-live.sh
|
||||
./tests/scripts/run-live.sh
|
||||
```
|
||||
|
||||
可自定义参数:
|
||||
|
||||
52
Dockerfile
52
Dockerfile
@@ -8,17 +8,59 @@ RUN npm run build
|
||||
|
||||
FROM golang:1.24 AS go-builder
|
||||
WORKDIR /app
|
||||
ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
ARG BUILD_VERSION
|
||||
COPY go.mod go.sum* ./
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o /out/ds2api ./cmd/ds2api
|
||||
RUN set -eux; \
|
||||
GOOS="${TARGETOS:-$(go env GOOS)}"; \
|
||||
GOARCH="${TARGETARCH:-$(go env GOARCH)}"; \
|
||||
BUILD_VERSION_RESOLVED="${BUILD_VERSION:-}"; \
|
||||
if [ -z "${BUILD_VERSION_RESOLVED}" ] && [ -f VERSION ]; then BUILD_VERSION_RESOLVED="$(cat VERSION | tr -d "[:space:]")"; fi; \
|
||||
CGO_ENABLED=0 GOOS="${GOOS}" GOARCH="${GOARCH}" go build -ldflags="-s -w -X ds2api/internal/version.BuildVersion=${BUILD_VERSION_RESOLVED}" -o /out/ds2api ./cmd/ds2api
|
||||
|
||||
FROM debian:bookworm-slim
|
||||
FROM busybox:1.36.1-musl AS busybox-tools
|
||||
|
||||
FROM debian:bookworm-slim AS runtime-base
|
||||
WORKDIR /app
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates wget && rm -rf /var/lib/apt/lists/*
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
COPY --from=busybox-tools /bin/busybox /usr/local/bin/busybox
|
||||
EXPOSE 5001
|
||||
CMD ["/usr/local/bin/ds2api"]
|
||||
|
||||
FROM runtime-base AS runtime-from-source
|
||||
COPY --from=go-builder /out/ds2api /usr/local/bin/ds2api
|
||||
COPY --from=go-builder /app/sha3_wasm_bg.7b9ca65ddd.wasm /app/sha3_wasm_bg.7b9ca65ddd.wasm
|
||||
COPY --from=go-builder /app/config.example.json /app/config.example.json
|
||||
COPY --from=webui-builder /app/static/admin /app/static/admin
|
||||
EXPOSE 5001
|
||||
CMD ["/usr/local/bin/ds2api"]
|
||||
|
||||
FROM busybox-tools AS dist-extract
|
||||
ARG TARGETARCH
|
||||
COPY dist/docker-input/linux_amd64.tar.gz /tmp/ds2api_linux_amd64.tar.gz
|
||||
COPY dist/docker-input/linux_arm64.tar.gz /tmp/ds2api_linux_arm64.tar.gz
|
||||
RUN set -eux; \
|
||||
case "${TARGETARCH}" in \
|
||||
amd64) ARCHIVE="/tmp/ds2api_linux_amd64.tar.gz" ;; \
|
||||
arm64) ARCHIVE="/tmp/ds2api_linux_arm64.tar.gz" ;; \
|
||||
*) echo "unsupported TARGETARCH: ${TARGETARCH}" >&2; exit 1 ;; \
|
||||
esac; \
|
||||
tar -xzf "${ARCHIVE}" -C /tmp; \
|
||||
PKG_DIR="$(find /tmp -maxdepth 1 -type d -name "ds2api_*_linux_${TARGETARCH}" | head -n1)"; \
|
||||
test -n "${PKG_DIR}"; \
|
||||
mkdir -p /out/static; \
|
||||
cp "${PKG_DIR}/ds2api" /out/ds2api; \
|
||||
cp "${PKG_DIR}/sha3_wasm_bg.7b9ca65ddd.wasm" /out/sha3_wasm_bg.7b9ca65ddd.wasm; \
|
||||
cp "${PKG_DIR}/config.example.json" /out/config.example.json; \
|
||||
cp -R "${PKG_DIR}/static/admin" /out/static/admin
|
||||
|
||||
FROM runtime-base AS runtime-from-dist
|
||||
COPY --from=dist-extract /out/ds2api /usr/local/bin/ds2api
|
||||
COPY --from=dist-extract /out/sha3_wasm_bg.7b9ca65ddd.wasm /app/sha3_wasm_bg.7b9ca65ddd.wasm
|
||||
COPY --from=dist-extract /out/config.example.json /app/config.example.json
|
||||
COPY --from=dist-extract /out/static/admin /app/static/admin
|
||||
|
||||
FROM runtime-from-source AS final
|
||||
|
||||
256
README.MD
256
README.MD
@@ -1,20 +1,26 @@
|
||||
<p align="center">
|
||||
<img src="webui/public/ds2api-favicon.svg" width="128" height="128" alt="DS2API icon" />
|
||||
</p>
|
||||
|
||||
# DS2API
|
||||
|
||||
[](LICENSE)
|
||||

|
||||

|
||||
[](version.txt)
|
||||
[](https://github.com/CJackHwang/ds2api/releases)
|
||||
[](DEPLOY.md)
|
||||
[](https://zeabur.com/templates/L4CFHP)
|
||||
[](https://vercel.com/new/clone?repository-url=https://github.com/CJackHwang/ds2api)
|
||||
|
||||
语言 / Language: [中文](README.MD) | [English](README.en.md)
|
||||
|
||||
将 DeepSeek Web 对话能力转换为 OpenAI 与 Claude 兼容 API。后端为 **Go 全量实现**,前端为 React WebUI 管理台(源码在 `webui/`,部署时自动构建到 `static/admin`)。
|
||||
将 DeepSeek Web 对话能力转换为 OpenAI、Claude 与 Gemini 兼容 API。后端为 **Go 全量实现**,前端为 React WebUI 管理台(源码在 `webui/`,部署时自动构建到 `static/admin`)。
|
||||
|
||||
## 架构概览
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
Client["🖥️ 客户端\n(OpenAI / Claude 兼容)"]
|
||||
Client["🖥️ 客户端\n(OpenAI / Claude / Gemini 兼容)"]
|
||||
|
||||
subgraph DS2API["DS2API 服务"]
|
||||
direction TB
|
||||
@@ -24,6 +30,7 @@ flowchart LR
|
||||
subgraph Adapters["适配器层"]
|
||||
OA["OpenAI 适配器\n/v1/*"]
|
||||
CA["Claude 适配器\n/anthropic/*"]
|
||||
GA["Gemini 适配器\n/v1beta/models/*"]
|
||||
end
|
||||
|
||||
subgraph Support["支撑模块"]
|
||||
@@ -38,11 +45,11 @@ flowchart LR
|
||||
DS["☁️ DeepSeek API"]
|
||||
|
||||
Client -- "请求" --> CORS --> Auth
|
||||
Auth --> OA & CA
|
||||
OA & CA -- "调用" --> DS
|
||||
Auth --> OA & CA & GA
|
||||
OA & CA & GA -- "调用" --> DS
|
||||
Auth --> Admin
|
||||
OA & CA -. "轮询选账号" .-> Pool
|
||||
OA & CA -. "计算 PoW" .-> PoW
|
||||
OA & CA & GA -. "轮询选账号" .-> Pool
|
||||
OA & CA & GA -. "计算 PoW" .-> PoW
|
||||
DS -- "响应" --> Client
|
||||
```
|
||||
|
||||
@@ -54,16 +61,29 @@ flowchart LR
|
||||
|
||||
| 能力 | 说明 |
|
||||
| --- | --- |
|
||||
| OpenAI 兼容 | `GET /v1/models`、`POST /v1/chat/completions`(流式/非流式) |
|
||||
| Claude 兼容 | `GET /anthropic/v1/models`、`POST /anthropic/v1/messages`、`POST /anthropic/v1/messages/count_tokens` |
|
||||
| OpenAI 兼容 | `GET /v1/models`、`GET /v1/models/{id}`、`POST /v1/chat/completions`、`POST /v1/responses`、`GET /v1/responses/{response_id}`、`POST /v1/embeddings` |
|
||||
| Claude 兼容 | `GET /anthropic/v1/models`、`POST /anthropic/v1/messages`、`POST /anthropic/v1/messages/count_tokens`(及快捷路径 `/v1/messages`、`/messages`) |
|
||||
| Gemini 兼容 | `POST /v1beta/models/{model}:generateContent`、`POST /v1beta/models/{model}:streamGenerateContent`(及 `/v1/models/{model}:*` 路径) |
|
||||
| 多账号轮询 | 自动 token 刷新、邮箱/手机号双登录方式 |
|
||||
| 并发队列控制 | 每账号 in-flight 上限 + 等待队列,动态计算建议并发值 |
|
||||
| DeepSeek PoW | WASM 计算(`wazero`),无需外部 Node.js 依赖 |
|
||||
| Tool Calling | 防泄漏处理:自动缓冲、识别、结构化输出 |
|
||||
| Admin API | 配置管理、账号测试 / 批量测试、导入导出、Vercel 同步 |
|
||||
| Tool Calling | 防泄漏处理:非代码块高置信特征识别、`delta.tool_calls` 早发、结构化增量输出 |
|
||||
| Admin API | 配置管理、运行时设置热更新、账号测试 / 批量测试、导入导出、Vercel 同步 |
|
||||
| WebUI 管理台 | `/admin` 单页应用(中英文双语、深色模式) |
|
||||
| 运维探针 | `GET /healthz`(存活)、`GET /readyz`(就绪) |
|
||||
|
||||
## 平台兼容矩阵
|
||||
|
||||
| 级别 | 平台 | 当前状态 |
|
||||
| --- | --- | --- |
|
||||
| P0 | Codex CLI/SDK(`wire_api=chat` / `wire_api=responses`) | ✅ |
|
||||
| P0 | OpenAI SDK(JS/Python,chat + responses) | ✅ |
|
||||
| P0 | Vercel AI SDK(openai-compatible) | ✅ |
|
||||
| P0 | Anthropic SDK(messages) | ✅ |
|
||||
| P0 | Google Gemini SDK(generateContent) | ✅ |
|
||||
| P1 | LangChain / LlamaIndex / OpenWebUI(OpenAI 兼容接入) | ✅ |
|
||||
| P2 | MCP 独立桥接层 | 规划中 |
|
||||
|
||||
## 模型支持
|
||||
|
||||
### OpenAI 接口
|
||||
@@ -86,8 +106,33 @@ flowchart LR
|
||||
可通过配置中的 `claude_mapping` 或 `claude_model_mapping` 覆盖映射关系。
|
||||
另外,`/anthropic/v1/models` 现已包含 Claude 1.x/2.x/3.x/4.x 历史模型 ID 与常见别名,便于旧客户端直接兼容。
|
||||
|
||||
|
||||
#### Claude Code 接入避坑(实测)
|
||||
|
||||
- `ANTHROPIC_BASE_URL` 推荐直接指向 DS2API 根地址(例如 `http://127.0.0.1:5001`),Claude Code 会请求 `/v1/messages?beta=true`。
|
||||
- `ANTHROPIC_API_KEY` 需要与 `config.json` 中 `keys` 一致;建议同时保留常规 key 与 `sk-ant-*` 形态 key,兼容不同客户端校验习惯。
|
||||
- 若系统设置了代理,建议对 DS2API 地址配置 `NO_PROXY=127.0.0.1,localhost,<你的主机IP>`,避免本地回环请求被代理拦截。
|
||||
- 如遇“工具调用输出成文本、未执行”问题,请升级到包含 Claude 工具调用多格式解析(JSON/XML/ANTML/invoke)的版本。
|
||||
|
||||
### Gemini 接口
|
||||
|
||||
Gemini 适配器将模型名通过 `model_aliases` 或内置规则映射到 DeepSeek 原生模型,支持 `generateContent` 和 `streamGenerateContent` 两种调用方式,并完整支持 Tool Calling(`functionDeclarations` → `functionCall` 输出)。
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 通用第一步(所有部署方式)
|
||||
|
||||
把 `config.json` 作为唯一配置源(推荐做法):
|
||||
|
||||
```bash
|
||||
cp config.example.json config.json
|
||||
# 编辑 config.json
|
||||
```
|
||||
|
||||
后续部署建议:
|
||||
- 本地运行:直接读取 `config.json`
|
||||
- Docker / Vercel:由 `config.json` 生成 `DS2API_CONFIG_JSON`(Base64)注入环境变量
|
||||
|
||||
### 方式一:本地运行
|
||||
|
||||
**前置要求**:Go 1.24+,Node.js 20+(仅在需要构建 WebUI 时)
|
||||
@@ -112,26 +157,49 @@ go run ./cmd/ds2api
|
||||
### 方式二:Docker 运行
|
||||
|
||||
```bash
|
||||
# 1. 配置环境变量
|
||||
# 1. 准备环境变量文件
|
||||
cp .env.example .env
|
||||
# 编辑 .env
|
||||
|
||||
# 2. 启动
|
||||
# 2. 编辑 .env(至少设置 DS2API_ADMIN_KEY)
|
||||
# DS2API_ADMIN_KEY=请替换为强密码
|
||||
|
||||
# 3. 启动
|
||||
docker-compose up -d
|
||||
|
||||
# 3. 查看日志
|
||||
# 4. 查看日志
|
||||
docker-compose logs -f
|
||||
```
|
||||
|
||||
更新镜像:`docker-compose up -d --build`
|
||||
|
||||
#### Zeabur 一键部署(Dockerfile)
|
||||
|
||||
1. 点击上方 “Deploy on Zeabur” 按钮,一键部署。
|
||||
2. 部署完成后访问 `/admin`,使用 Zeabur 环境变量/模板指引中的 `DS2API_ADMIN_KEY` 登录。
|
||||
3. 在管理台导入/编辑配置(会写入并持久化到 `/data/config.json`)。
|
||||
|
||||
说明:Zeabur 使用仓库内 `Dockerfile` 直接构建时,不需要额外传入 `BUILD_VERSION`;镜像会优先读取该构建参数,未提供时自动回退到仓库根目录的 `VERSION` 文件。
|
||||
|
||||
### 方式三:Vercel 部署
|
||||
|
||||
1. Fork 仓库到自己的 GitHub
|
||||
2. 在 Vercel 上导入项目
|
||||
3. 配置环境变量(至少设置 `DS2API_ADMIN_KEY` 和 `DS2API_CONFIG_JSON`)
|
||||
3. 配置环境变量(最少设置 `DS2API_ADMIN_KEY`;推荐同时设置 `DS2API_CONFIG_JSON`)
|
||||
4. 部署
|
||||
|
||||
建议先在仓库目录复制模板并填写:
|
||||
|
||||
```bash
|
||||
cp config.example.json config.json
|
||||
# 编辑 config.json
|
||||
```
|
||||
|
||||
推荐:先本地把 `config.json` 转成 Base64,再粘贴到 `DS2API_CONFIG_JSON`,避免 JSON 格式错误:
|
||||
|
||||
```bash
|
||||
base64 < config.json | tr -d '\n'
|
||||
```
|
||||
|
||||
> **流式说明**:`/v1/chat/completions` 在 Vercel 上默认走 `api/chat-stream.js`(Node Runtime)以保证实时 SSE。鉴权、账号选择、会话/PoW 准备仍由 Go 内部 prepare 接口完成;流式响应(含 `tools`)在 Node 侧执行与 Go 对齐的输出组装与防泄漏处理。
|
||||
|
||||
详细部署说明请参阅 [部署指南](DEPLOY.md)。
|
||||
@@ -142,8 +210,8 @@ docker-compose logs -f
|
||||
|
||||
```bash
|
||||
# 下载对应平台的压缩包后
|
||||
tar -xzf ds2api_v1.7.0_linux_amd64.tar.gz
|
||||
cd ds2api_v1.7.0_linux_amd64
|
||||
tar -xzf ds2api_<tag>_linux_amd64.tar.gz
|
||||
cd ds2api_<tag>_linux_amd64
|
||||
cp config.example.json config.json
|
||||
# 编辑 config.json
|
||||
./ds2api
|
||||
@@ -164,6 +232,7 @@ cp opencode.json.example opencode.json
|
||||
3. 在项目目录启动 OpenCode CLI(按你的安装方式运行 `opencode`)。
|
||||
|
||||
> 建议优先使用 OpenAI 兼容路径(`/v1/*`),即示例里的 `@ai-sdk/openai-compatible` provider。
|
||||
> 若客户端支持 `wire_api`,可分别测试 `responses` 与 `chat`,DS2API 两条链路都兼容。
|
||||
|
||||
## 配置说明
|
||||
|
||||
@@ -175,26 +244,61 @@ cp opencode.json.example opencode.json
|
||||
"accounts": [
|
||||
{
|
||||
"email": "user@example.com",
|
||||
"password": "your-password",
|
||||
"token": ""
|
||||
"password": "your-password"
|
||||
},
|
||||
{
|
||||
"mobile": "12345678901",
|
||||
"password": "your-password",
|
||||
"token": ""
|
||||
"password": "your-password"
|
||||
}
|
||||
],
|
||||
"claude_model_mapping": {
|
||||
"model_aliases": {
|
||||
"gpt-4o": "deepseek-chat",
|
||||
"gpt-5-codex": "deepseek-reasoner",
|
||||
"o3": "deepseek-reasoner"
|
||||
},
|
||||
"compat": {
|
||||
"wide_input_strict_output": true
|
||||
},
|
||||
"toolcall": {
|
||||
"mode": "feature_match",
|
||||
"early_emit_confidence": "high"
|
||||
},
|
||||
"responses": {
|
||||
"store_ttl_seconds": 900
|
||||
},
|
||||
"embeddings": {
|
||||
"provider": "deterministic"
|
||||
},
|
||||
"claude_mapping": {
|
||||
"fast": "deepseek-chat",
|
||||
"slow": "deepseek-reasoner"
|
||||
},
|
||||
"admin": {
|
||||
"jwt_expire_hours": 24
|
||||
},
|
||||
"runtime": {
|
||||
"account_max_inflight": 2,
|
||||
"account_max_queue": 0,
|
||||
"global_max_inflight": 0
|
||||
},
|
||||
"auto_delete": {
|
||||
"sessions": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
- `keys`:API 访问密钥列表,客户端通过 `Authorization: Bearer <key>` 鉴权
|
||||
- `accounts`:DeepSeek 账号列表,支持 `email` 或 `mobile` 登录
|
||||
- `token`:留空则首次请求时自动登录获取;也可预填已有 token
|
||||
- `claude_model_mapping`:字典中 `fast`/`slow` 后缀映射到对应 DeepSeek 模型
|
||||
- `token`:配置文件中即使填写也会在加载时被清空(不会从 `config.json` 读取 token);实际 token 仅在运行时内存中维护并自动刷新
|
||||
- `model_aliases`:常见模型名(如 GPT/Codex/Claude)到 DeepSeek 模型的映射
|
||||
- `compat.wide_input_strict_output`:建议保持 `true`(当前实现默认宽进严出)
|
||||
- `toolcall`:固定采用特征匹配 + 高置信早发策略
|
||||
- `responses.store_ttl_seconds`:`/v1/responses/{id}` 的内存缓存 TTL
|
||||
- `embeddings.provider`:embedding 提供方(当前内置 `deterministic/mock/builtin`)
|
||||
- `claude_mapping`:字典中 `fast`/`slow` 后缀映射到对应 DeepSeek 模型(兼容读取 `claude_model_mapping`)
|
||||
- `admin`:管理后台设置(JWT 过期时间、密码哈希等),可通过 Admin Settings API 热更新
|
||||
- `runtime`:运行时参数(并发限制、队列大小),可通过 Admin Settings API 热更新;`account_max_queue=0`/`global_max_inflight=0` 表示按推荐值自动计算
|
||||
- `auto_delete.sessions`:是否在请求结束后自动清理 DeepSeek 会话(默认 `false`,可在 Settings 热更新)
|
||||
|
||||
### 环境变量
|
||||
|
||||
@@ -214,8 +318,13 @@ cp opencode.json.example opencode.json
|
||||
| `DS2API_ACCOUNT_CONCURRENCY` | 同上(兼容旧名) | — |
|
||||
| `DS2API_ACCOUNT_MAX_QUEUE` | 等待队列上限 | `recommended_concurrency` |
|
||||
| `DS2API_ACCOUNT_QUEUE_SIZE` | 同上(兼容旧名) | — |
|
||||
| `DS2API_GLOBAL_MAX_INFLIGHT` | 全局最大 in-flight 请求数 | `recommended_concurrency` |
|
||||
| `DS2API_MAX_INFLIGHT` | 同上(兼容旧名) | — |
|
||||
| `DS2API_VERCEL_INTERNAL_SECRET` | Vercel 混合流式内部鉴权密钥 | 回退用 `DS2API_ADMIN_KEY` |
|
||||
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | 流式 lease 过期秒数 | `900` |
|
||||
| `DS2API_DEV_PACKET_CAPTURE` | 本地开发抓包开关(记录最近会话请求/响应体) | 本地非 Vercel 默认开启 |
|
||||
| `DS2API_DEV_PACKET_CAPTURE_LIMIT` | 本地抓包保留条数(超出自动淘汰) | `5` |
|
||||
| `DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES` | 单条响应体最大记录字节数 | `2097152` |
|
||||
| `VERCEL_TOKEN` | Vercel 同步 token | — |
|
||||
| `VERCEL_PROJECT_ID` | Vercel 项目 ID | — |
|
||||
| `VERCEL_TEAM_ID` | Vercel 团队 ID | — |
|
||||
@@ -223,7 +332,7 @@ cp opencode.json.example opencode.json
|
||||
|
||||
## 鉴权模式
|
||||
|
||||
调用业务接口(`/v1/*`、`/anthropic/*`)时支持两种模式:
|
||||
调用业务接口(`/v1/*`、`/anthropic/*`、Gemini 路由)时支持两种模式:
|
||||
|
||||
| 模式 | 说明 |
|
||||
| --- | --- |
|
||||
@@ -249,10 +358,34 @@ cp opencode.json.example opencode.json
|
||||
|
||||
当请求中带 `tools` 时,DS2API 会做防泄漏处理:
|
||||
|
||||
1. `stream=true` 时先**缓冲**正文片段
|
||||
2. 若识别到工具调用 → 仅输出结构化 `tool_calls`,不透传原始 JSON 文本
|
||||
3. 若最终不是工具调用 → 一次性输出普通文本
|
||||
4. 解析器支持混合文本、fenced JSON、`function.arguments` 字符串等格式
|
||||
1. 只在**非代码块上下文**启用 toolcall 特征识别(代码块示例不会触发)
|
||||
2. `responses` 流式严格使用官方 item 生命周期事件(`response.output_item.*`、`response.content_part.*`、`response.function_call_arguments.*`)
|
||||
3. 未在 `tools` 声明中的工具名会被严格拒绝,不会下发为有效 tool call
|
||||
4. `responses` 支持并执行 `tool_choice`(`auto`/`none`/`required`/强制函数);`required` 违规时非流式返回 `422`,流式返回 `response.failed`
|
||||
5. 仅在通过策略校验后才会发出有效工具调用事件,避免错误工具名进入客户端执行链
|
||||
|
||||
## 本地开发抓包工具
|
||||
|
||||
用于定位「responses 思考流/工具调用」等问题。开启后会自动记录最近 N 条 DeepSeek 对话上游请求体与响应体(默认 5 条,超出自动淘汰)。
|
||||
|
||||
启用示例:
|
||||
|
||||
```bash
|
||||
DS2API_DEV_PACKET_CAPTURE=true \
|
||||
DS2API_DEV_PACKET_CAPTURE_LIMIT=5 \
|
||||
go run ./cmd/ds2api
|
||||
```
|
||||
|
||||
查询/清空(需 Admin JWT):
|
||||
|
||||
- `GET /admin/dev/captures`:查看抓包列表(最新在前)
|
||||
- `DELETE /admin/dev/captures`:清空抓包
|
||||
|
||||
返回字段包含:
|
||||
|
||||
- `request_body`:发送给 DeepSeek 的完整请求体
|
||||
- `response_body`:上游返回的原始流式内容拼接文本
|
||||
- `response_truncated`:是否触发单条大小截断
|
||||
|
||||
## 项目结构
|
||||
|
||||
@@ -264,30 +397,42 @@ ds2api/
|
||||
├── api/
|
||||
│ ├── index.go # Vercel Serverless Go 入口
|
||||
│ ├── chat-stream.js # Vercel Node.js 流式转发
|
||||
│ └── helpers/ # Node.js 辅助模块
|
||||
│ └── (rewrite targets in vercel.json)
|
||||
├── internal/
|
||||
│ ├── account/ # 账号池与并发队列
|
||||
│ ├── adapter/
|
||||
│ │ ├── openai/ # OpenAI 兼容适配器(含 Tool Call 解析、Vercel 流式 prepare/release)
|
||||
│ │ └── claude/ # Claude 兼容适配器
|
||||
│ ├── admin/ # Admin API handlers
|
||||
│ │ ├── claude/ # Claude 兼容适配器
|
||||
│ │ └── gemini/ # Gemini 兼容适配器(generateContent / streamGenerateContent)
|
||||
│ ├── admin/ # Admin API handlers(含 Settings 热更新)
|
||||
│ ├── auth/ # 鉴权与 JWT
|
||||
│ ├── claudeconv/ # Claude 消息格式转换
|
||||
│ ├── compat/ # 兼容性辅助
|
||||
│ ├── config/ # 配置加载与热更新
|
||||
│ ├── deepseek/ # DeepSeek API 客户端、PoW WASM
|
||||
│ ├── js/ # Node 运行时流式处理与兼容逻辑
|
||||
│ ├── devcapture/ # 开发抓包模块
|
||||
│ ├── format/ # 输出格式化
|
||||
│ ├── prompt/ # Prompt 构建
|
||||
│ ├── server/ # HTTP 路由与中间件(chi router)
|
||||
│ ├── sse/ # SSE 解析工具
|
||||
│ ├── stream/ # 统一流式消费引擎
|
||||
│ ├── util/ # 通用工具函数
|
||||
│ └── webui/ # WebUI 静态文件托管与自动构建
|
||||
├── webui/ # React WebUI 源码(Vite + Tailwind)
|
||||
│ └── src/
|
||||
│ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage
|
||||
│ ├── app/ # 路由、鉴权、配置状态管理
|
||||
│ ├── features/ # 业务功能模块(account/settings/vercel/apiTester)
|
||||
│ ├── components/ # 登录/落地页等通用组件
|
||||
│ └── locales/ # 中英文语言包(zh.json / en.json)
|
||||
├── scripts/
|
||||
│ ├── build-webui.sh # WebUI 手动构建脚本
|
||||
│ └── testsuite/ # 测试集运行脚本
|
||||
│ └── build-webui.sh # WebUI 手动构建脚本
|
||||
├── tests/
|
||||
│ ├── compat/ # 兼容性测试夹具与期望输出
|
||||
│ └── scripts/ # 统一测试脚本入口(unit/e2e)
|
||||
├── static/admin/ # WebUI 构建产物(不提交到 Git)
|
||||
├── .github/
|
||||
│ ├── workflows/ # GitHub Actions(Release 自动构建)
|
||||
│ ├── workflows/ # GitHub Actions(质量门禁 + Release 自动构建)
|
||||
│ ├── ISSUE_TEMPLATE/ # Issue 模板
|
||||
│ └── PULL_REQUEST_TEMPLATE.md
|
||||
├── config.example.json # 配置文件示例
|
||||
@@ -296,8 +441,7 @@ ds2api/
|
||||
├── docker-compose.yml # 生产环境 Docker Compose
|
||||
├── docker-compose.dev.yml # 开发环境 Docker Compose
|
||||
├── vercel.json # Vercel 路由与构建配置
|
||||
├── go.mod / go.sum # Go 模块依赖
|
||||
└── version.txt # 版本号
|
||||
└── go.mod / go.sum # Go 模块依赖
|
||||
```
|
||||
|
||||
## 文档索引
|
||||
@@ -312,11 +456,11 @@ ds2api/
|
||||
## 测试
|
||||
|
||||
```bash
|
||||
# 单元测试
|
||||
go test ./...
|
||||
# 单元测试(Go + Node)
|
||||
./tests/scripts/run-unit-all.sh
|
||||
|
||||
# 一键端到端全链路测试(真实账号,生成完整请求/响应日志)
|
||||
./scripts/testsuite/run-live.sh
|
||||
./tests/scripts/run-live.sh
|
||||
|
||||
# 或自定义参数
|
||||
go run ./cmd/ds2api-tests \
|
||||
@@ -327,13 +471,39 @@ go run ./cmd/ds2api-tests \
|
||||
--retries 2
|
||||
```
|
||||
|
||||
```bash
|
||||
# 发布前阻断门禁
|
||||
./tests/scripts/check-stage6-manual-smoke.sh
|
||||
./tests/scripts/check-refactor-line-gate.sh
|
||||
./tests/scripts/run-unit-all.sh
|
||||
npm ci --prefix webui && npm run build --prefix webui
|
||||
```
|
||||
|
||||
## 测试
|
||||
|
||||
详细测试指南请参阅 [TESTING.md](TESTING.md)。
|
||||
|
||||
### 快速测试命令
|
||||
|
||||
```bash
|
||||
# 运行所有单元测试
|
||||
go test ./...
|
||||
|
||||
# 运行 tool calls 相关测试(调试工具调用问题)
|
||||
go test -v -run 'TestParseToolCalls|TestRepair' ./internal/util/
|
||||
|
||||
# 运行端到端测试
|
||||
./tests/scripts/run-live.sh
|
||||
```
|
||||
|
||||
## Release 自动构建(GitHub Actions)
|
||||
|
||||
工作流文件:`.github/workflows/release-artifacts.yml`
|
||||
|
||||
- **触发条件**:仅在 GitHub Release `published` 时触发(普通 push 不会触发)
|
||||
- **构建产物**:多平台二进制包(`linux/amd64`、`linux/arm64`、`darwin/amd64`、`darwin/arm64`、`windows/amd64`)+ `sha256sums.txt`
|
||||
- **每个压缩包包含**:`ds2api` 可执行文件、`static/admin`、WASM 文件、配置示例、README、LICENSE
|
||||
- **容器镜像发布**:仅推送到 GHCR(`ghcr.io/cjackhwang/ds2api`)
|
||||
- **每个压缩包包含**:`ds2api` 可执行文件、`static/admin`、WASM 文件(同时支持内置 fallback)、配置示例、README、LICENSE
|
||||
|
||||
## 免责声明
|
||||
|
||||
|
||||
240
README.en.md
240
README.en.md
@@ -1,20 +1,26 @@
|
||||
<p align="center">
|
||||
<img src="webui/public/ds2api-favicon.svg" width="128" height="128" alt="DS2API icon" />
|
||||
</p>
|
||||
|
||||
# DS2API
|
||||
|
||||
[](LICENSE)
|
||||

|
||||

|
||||
[](version.txt)
|
||||
[](https://github.com/CJackHwang/ds2api/releases)
|
||||
[](DEPLOY.en.md)
|
||||
[](https://zeabur.com/templates/L4CFHP)
|
||||
[](https://vercel.com/new/clone?repository-url=https://github.com/CJackHwang/ds2api)
|
||||
|
||||
Language: [中文](README.MD) | [English](README.en.md)
|
||||
|
||||
DS2API converts DeepSeek Web chat capability into OpenAI-compatible and Claude-compatible APIs. The backend is a **pure Go implementation**, with a React WebUI admin panel (source in `webui/`, build output auto-generated to `static/admin` during deployment).
|
||||
DS2API converts DeepSeek Web chat capability into OpenAI-compatible, Claude-compatible, and Gemini-compatible APIs. The backend is a **pure Go implementation**, with a React WebUI admin panel (source in `webui/`, build output auto-generated to `static/admin` during deployment).
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
Client["🖥️ Clients\n(OpenAI / Claude compat)"]
|
||||
Client["🖥️ Clients\n(OpenAI / Claude / Gemini compat)"]
|
||||
|
||||
subgraph DS2API["DS2API Service"]
|
||||
direction TB
|
||||
@@ -24,6 +30,7 @@ flowchart LR
|
||||
subgraph Adapters["Adapter Layer"]
|
||||
OA["OpenAI Adapter\n/v1/*"]
|
||||
CA["Claude Adapter\n/anthropic/*"]
|
||||
GA["Gemini Adapter\n/v1beta/models/*"]
|
||||
end
|
||||
|
||||
subgraph Support["Support Modules"]
|
||||
@@ -38,11 +45,11 @@ flowchart LR
|
||||
DS["☁️ DeepSeek API"]
|
||||
|
||||
Client -- "Request" --> CORS --> Auth
|
||||
Auth --> OA & CA
|
||||
OA & CA -- "Call" --> DS
|
||||
Auth --> OA & CA & GA
|
||||
OA & CA & GA -- "Call" --> DS
|
||||
Auth --> Admin
|
||||
OA & CA -. "Rotate accounts" .-> Pool
|
||||
OA & CA -. "Compute PoW" .-> PoW
|
||||
OA & CA & GA -. "Rotate accounts" .-> Pool
|
||||
OA & CA & GA -. "Compute PoW" .-> PoW
|
||||
DS -- "Response" --> Client
|
||||
```
|
||||
|
||||
@@ -54,16 +61,29 @@ flowchart LR
|
||||
|
||||
| Capability | Details |
|
||||
| --- | --- |
|
||||
| OpenAI compatible | `GET /v1/models`, `POST /v1/chat/completions` (stream/non-stream) |
|
||||
| Claude compatible | `GET /anthropic/v1/models`, `POST /anthropic/v1/messages`, `POST /anthropic/v1/messages/count_tokens` |
|
||||
| OpenAI compatible | `GET /v1/models`, `GET /v1/models/{id}`, `POST /v1/chat/completions`, `POST /v1/responses`, `GET /v1/responses/{response_id}`, `POST /v1/embeddings` |
|
||||
| Claude compatible | `GET /anthropic/v1/models`, `POST /anthropic/v1/messages`, `POST /anthropic/v1/messages/count_tokens` (plus shortcut paths `/v1/messages`, `/messages`) |
|
||||
| Gemini compatible | `POST /v1beta/models/{model}:generateContent`, `POST /v1beta/models/{model}:streamGenerateContent` (plus `/v1/models/{model}:*` paths) |
|
||||
| Multi-account rotation | Auto token refresh, email/mobile dual login |
|
||||
| Concurrency control | Per-account in-flight limit + waiting queue, dynamic recommended concurrency |
|
||||
| DeepSeek PoW | WASM solving via `wazero`, no external Node.js dependency |
|
||||
| Tool Calling | Anti-leak handling: auto buffer, detect, structured output |
|
||||
| Admin API | Config management, account testing/batch test, import/export, Vercel sync |
|
||||
| Tool Calling | Anti-leak handling: non-code-block feature match, early `delta.tool_calls`, structured incremental output |
|
||||
| Admin API | Config management, runtime settings hot-reload, account testing/batch test, import/export, Vercel sync |
|
||||
| WebUI Admin Panel | SPA at `/admin` (bilingual Chinese/English, dark mode) |
|
||||
| Health Probes | `GET /healthz` (liveness), `GET /readyz` (readiness) |
|
||||
|
||||
## Platform Compatibility Matrix
|
||||
|
||||
| Tier | Platform | Status |
|
||||
| --- | --- | --- |
|
||||
| P0 | Codex CLI/SDK (`wire_api=chat` / `wire_api=responses`) | ✅ |
|
||||
| P0 | OpenAI SDK (JS/Python, chat + responses) | ✅ |
|
||||
| P0 | Vercel AI SDK (openai-compatible) | ✅ |
|
||||
| P0 | Anthropic SDK (messages) | ✅ |
|
||||
| P0 | Google Gemini SDK (generateContent) | ✅ |
|
||||
| P1 | LangChain / LlamaIndex / OpenWebUI (OpenAI-compatible integration) | ✅ |
|
||||
| P2 | MCP standalone bridge | Planned |
|
||||
|
||||
## Model Support
|
||||
|
||||
### OpenAI Endpoint
|
||||
@@ -86,8 +106,33 @@ flowchart LR
|
||||
Override mapping via `claude_mapping` or `claude_model_mapping` in config.
|
||||
In addition, `/anthropic/v1/models` now includes historical Claude 1.x/2.x/3.x/4.x IDs and common aliases for legacy client compatibility.
|
||||
|
||||
|
||||
#### Claude Code integration pitfalls (validated)
|
||||
|
||||
- Set `ANTHROPIC_BASE_URL` to the DS2API root URL (for example `http://127.0.0.1:5001`). Claude Code sends requests to `/v1/messages?beta=true`.
|
||||
- `ANTHROPIC_API_KEY` must match an entry in `keys` from `config.json`. Keeping both a regular key and an `sk-ant-*` style key improves client compatibility.
|
||||
- If your environment has proxy variables, set `NO_PROXY=127.0.0.1,localhost,<your_host_ip>` for DS2API to avoid proxy interception of local traffic.
|
||||
- If tool calls are rendered as plain text and not executed, upgrade to a build that includes multi-format Claude tool-call parsing (JSON/XML/ANTML/invoke).
|
||||
|
||||
### Gemini Endpoint
|
||||
|
||||
The Gemini adapter maps model names to DeepSeek native models via `model_aliases` or built-in heuristics, supporting both `generateContent` and `streamGenerateContent` call patterns with full Tool Calling support (`functionDeclarations` → `functionCall` output).
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Universal First Step (all deployment modes)
|
||||
|
||||
Use `config.json` as the single source of truth (recommended):
|
||||
|
||||
```bash
|
||||
cp config.example.json config.json
|
||||
# Edit config.json
|
||||
```
|
||||
|
||||
Recommended per deployment mode:
|
||||
- Local run: read `config.json` directly
|
||||
- Docker / Vercel: generate Base64 from `config.json` and inject as `DS2API_CONFIG_JSON`
|
||||
|
||||
### Option 1: Local Run
|
||||
|
||||
**Prerequisites**: Go 1.24+, Node.js 20+ (only if building WebUI locally)
|
||||
@@ -112,26 +157,49 @@ Default URL: `http://localhost:5001`
|
||||
### Option 2: Docker
|
||||
|
||||
```bash
|
||||
# 1. Configure environment
|
||||
# 1. Prepare env file
|
||||
cp .env.example .env
|
||||
# Edit .env
|
||||
|
||||
# 2. Start
|
||||
# 2. Edit .env (at least set DS2API_ADMIN_KEY)
|
||||
# DS2API_ADMIN_KEY=replace-with-a-strong-secret
|
||||
|
||||
# 3. Start
|
||||
docker-compose up -d
|
||||
|
||||
# 3. View logs
|
||||
# 4. View logs
|
||||
docker-compose logs -f
|
||||
```
|
||||
|
||||
Rebuild after updates: `docker-compose up -d --build`
|
||||
|
||||
#### Zeabur One-Click (Dockerfile)
|
||||
|
||||
1. Click the “Deploy on Zeabur” button above to deploy.
|
||||
2. After deployment, open `/admin` and login with `DS2API_ADMIN_KEY` shown in Zeabur env/template instructions.
|
||||
3. Import / edit config in Admin UI (it will be written and persisted to `/data/config.json`).
|
||||
|
||||
Note: when Zeabur builds directly from the repo `Dockerfile`, you do not need to pass `BUILD_VERSION`. The image prefers that build arg when provided, and automatically falls back to the repo-root `VERSION` file when it is absent.
|
||||
|
||||
### Option 3: Vercel
|
||||
|
||||
1. Fork this repo to your GitHub account
|
||||
2. Import the project on Vercel
|
||||
3. Set environment variables (minimum: `DS2API_ADMIN_KEY` and `DS2API_CONFIG_JSON`)
|
||||
3. Set environment variables (minimum: `DS2API_ADMIN_KEY`; recommended to also set `DS2API_CONFIG_JSON`)
|
||||
4. Deploy
|
||||
|
||||
Recommended first step in repo root:
|
||||
|
||||
```bash
|
||||
cp config.example.json config.json
|
||||
# Edit config.json
|
||||
```
|
||||
|
||||
Recommended: convert `config.json` to Base64 locally, then paste into `DS2API_CONFIG_JSON` to avoid JSON formatting mistakes:
|
||||
|
||||
```bash
|
||||
base64 < config.json | tr -d '\n'
|
||||
```
|
||||
|
||||
> **Streaming note**: `/v1/chat/completions` on Vercel is routed to `api/chat-stream.js` (Node Runtime) for real-time SSE. Auth, account selection, and session/PoW preparation are still handled by the Go internal prepare endpoint; streaming output (including `tools`) is assembled on Node with Go-aligned anti-leak handling.
|
||||
|
||||
For detailed deployment instructions, see the [Deployment Guide](DEPLOY.en.md).
|
||||
@@ -142,8 +210,8 @@ GitHub Actions automatically builds multi-platform archives on each Release:
|
||||
|
||||
```bash
|
||||
# After downloading the archive for your platform
|
||||
tar -xzf ds2api_v1.7.0_linux_amd64.tar.gz
|
||||
cd ds2api_v1.7.0_linux_amd64
|
||||
tar -xzf ds2api_<tag>_linux_amd64.tar.gz
|
||||
cd ds2api_<tag>_linux_amd64
|
||||
cp config.example.json config.json
|
||||
# Edit config.json
|
||||
./ds2api
|
||||
@@ -164,6 +232,7 @@ cp opencode.json.example opencode.json
|
||||
3. Start OpenCode CLI in the project directory (run `opencode` using your installed method).
|
||||
|
||||
> Recommended: use the OpenAI-compatible path (`/v1/*`) via `@ai-sdk/openai-compatible` as shown in the example.
|
||||
> If your client supports `wire_api`, test both `responses` and `chat`; DS2API supports both paths.
|
||||
|
||||
## Configuration
|
||||
|
||||
@@ -175,26 +244,61 @@ cp opencode.json.example opencode.json
|
||||
"accounts": [
|
||||
{
|
||||
"email": "user@example.com",
|
||||
"password": "your-password",
|
||||
"token": ""
|
||||
"password": "your-password"
|
||||
},
|
||||
{
|
||||
"mobile": "12345678901",
|
||||
"password": "your-password",
|
||||
"token": ""
|
||||
"password": "your-password"
|
||||
}
|
||||
],
|
||||
"claude_model_mapping": {
|
||||
"model_aliases": {
|
||||
"gpt-4o": "deepseek-chat",
|
||||
"gpt-5-codex": "deepseek-reasoner",
|
||||
"o3": "deepseek-reasoner"
|
||||
},
|
||||
"compat": {
|
||||
"wide_input_strict_output": true
|
||||
},
|
||||
"toolcall": {
|
||||
"mode": "feature_match",
|
||||
"early_emit_confidence": "high"
|
||||
},
|
||||
"responses": {
|
||||
"store_ttl_seconds": 900
|
||||
},
|
||||
"embeddings": {
|
||||
"provider": "deterministic"
|
||||
},
|
||||
"claude_mapping": {
|
||||
"fast": "deepseek-chat",
|
||||
"slow": "deepseek-reasoner"
|
||||
},
|
||||
"admin": {
|
||||
"jwt_expire_hours": 24
|
||||
},
|
||||
"runtime": {
|
||||
"account_max_inflight": 2,
|
||||
"account_max_queue": 0,
|
||||
"global_max_inflight": 0
|
||||
},
|
||||
"auto_delete": {
|
||||
"sessions": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
- `keys`: API access keys; clients authenticate via `Authorization: Bearer <key>`
|
||||
- `accounts`: DeepSeek account list, supports `email` or `mobile` login
|
||||
- `token`: Leave empty for auto-login on first request; or pre-fill an existing token
|
||||
- `claude_model_mapping`: Maps `fast`/`slow` suffixes to corresponding DeepSeek models
|
||||
- `token`: Even if set in `config.json`, it is cleared during load (DS2API does not read persisted tokens from config); runtime tokens are maintained/refreshed in memory only
|
||||
- `model_aliases`: Map common model names (GPT/Codex/Claude) to DeepSeek models
|
||||
- `compat.wide_input_strict_output`: Keep `true` (current default policy)
|
||||
- `toolcall`: Fixed to feature matching + high-confidence early emit
|
||||
- `responses.store_ttl_seconds`: In-memory TTL for `/v1/responses/{id}`
|
||||
- `embeddings.provider`: Embeddings provider (`deterministic/mock/builtin` built-in)
|
||||
- `claude_mapping`: Maps `fast`/`slow` suffixes to corresponding DeepSeek models (still compatible with `claude_model_mapping`)
|
||||
- `admin`: Admin panel settings (JWT expiry, password hash, etc.), hot-reloadable via Admin Settings API
|
||||
- `runtime`: Runtime parameters (concurrency limits, queue sizes), hot-reloadable via Admin Settings API; `account_max_queue=0`/`global_max_inflight=0` means auto-calculate from recommended values
|
||||
- `auto_delete.sessions`: Whether to auto-delete DeepSeek sessions after request completion (default `false`, hot-reloadable via Settings)
|
||||
|
||||
### Environment Variables
|
||||
|
||||
@@ -214,8 +318,13 @@ cp opencode.json.example opencode.json
|
||||
| `DS2API_ACCOUNT_CONCURRENCY` | Alias (legacy compat) | — |
|
||||
| `DS2API_ACCOUNT_MAX_QUEUE` | Waiting queue limit | `recommended_concurrency` |
|
||||
| `DS2API_ACCOUNT_QUEUE_SIZE` | Alias (legacy compat) | — |
|
||||
| `DS2API_GLOBAL_MAX_INFLIGHT` | Global max in-flight requests | `recommended_concurrency` |
|
||||
| `DS2API_MAX_INFLIGHT` | Alias (legacy compat) | — |
|
||||
| `DS2API_VERCEL_INTERNAL_SECRET` | Vercel hybrid streaming internal auth | Falls back to `DS2API_ADMIN_KEY` |
|
||||
| `DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS` | Stream lease TTL seconds | `900` |
|
||||
| `DS2API_DEV_PACKET_CAPTURE` | Local dev packet capture switch (record recent request/response bodies) | Enabled by default on non-Vercel local runtime |
|
||||
| `DS2API_DEV_PACKET_CAPTURE_LIMIT` | Number of captured sessions to retain (auto-evict overflow) | `5` |
|
||||
| `DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES` | Max recorded bytes per captured response body | `2097152` |
|
||||
| `VERCEL_TOKEN` | Vercel sync token | — |
|
||||
| `VERCEL_PROJECT_ID` | Vercel project ID | — |
|
||||
| `VERCEL_TEAM_ID` | Vercel team ID | — |
|
||||
@@ -223,7 +332,7 @@ cp opencode.json.example opencode.json
|
||||
|
||||
## Authentication Modes
|
||||
|
||||
For business endpoints (`/v1/*`, `/anthropic/*`), DS2API supports two modes:
|
||||
For business endpoints (`/v1/*`, `/anthropic/*`, Gemini routes), DS2API supports two modes:
|
||||
|
||||
| Mode | Description |
|
||||
| --- | --- |
|
||||
@@ -249,10 +358,35 @@ Queue limit = DS2API_ACCOUNT_MAX_QUEUE (default = recommended concurrency)
|
||||
|
||||
When `tools` is present in the request, DS2API performs anti-leak handling:
|
||||
|
||||
1. With `stream=true`, DS2API **buffers** text deltas first
|
||||
2. If a tool call is detected → only structured `tool_calls` are emitted, raw JSON is not leaked
|
||||
3. If no tool call → buffered text is emitted at once
|
||||
4. Parser supports mixed text, fenced JSON, and `function.arguments` payloads
|
||||
1. Toolcall feature matching is enabled only in **non-code-block context** (fenced examples are ignored)
|
||||
- In non-code-block context, tool JSON may still be recognized even when mixed with normal prose; surrounding prose can remain as text output.
|
||||
2. `responses` streaming strictly uses official item lifecycle events (`response.output_item.*`, `response.content_part.*`, `response.function_call_arguments.*`)
|
||||
3. Tool names not declared in the `tools` schema are strictly rejected and will not be emitted as valid tool calls
|
||||
4. `responses` supports and enforces `tool_choice` (`auto`/`none`/`required`/forced function); `required` violations return `422` for non-stream and `response.failed` for stream
|
||||
5. Valid tool call events are only emitted after passing policy validation, preventing invalid tool names from entering the client execution chain
|
||||
|
||||
## Local Dev Packet Capture
|
||||
|
||||
This is for debugging issues such as Responses reasoning streaming and tool-call handoff. When enabled, DS2API stores the latest N DeepSeek conversation payload pairs (request body + upstream response body), defaulting to 5 entries with auto-eviction.
|
||||
|
||||
Enable example:
|
||||
|
||||
```bash
|
||||
DS2API_DEV_PACKET_CAPTURE=true \
|
||||
DS2API_DEV_PACKET_CAPTURE_LIMIT=5 \
|
||||
go run ./cmd/ds2api
|
||||
```
|
||||
|
||||
Inspect/clear (Admin JWT required):
|
||||
|
||||
- `GET /admin/dev/captures`: list captured items (newest first)
|
||||
- `DELETE /admin/dev/captures`: clear captured items
|
||||
|
||||
Response fields include:
|
||||
|
||||
- `request_body`: full payload sent to DeepSeek
|
||||
- `response_body`: concatenated raw upstream stream body text
|
||||
- `response_truncated`: whether body-size truncation happened
|
||||
|
||||
## Project Structure
|
||||
|
||||
@@ -264,30 +398,42 @@ ds2api/
|
||||
├── api/
|
||||
│ ├── index.go # Vercel Serverless Go entry
|
||||
│ ├── chat-stream.js # Vercel Node.js stream relay
|
||||
│ └── helpers/ # Node.js helper modules
|
||||
│ └── (rewrite targets in vercel.json)
|
||||
├── internal/
|
||||
│ ├── account/ # Account pool and concurrency queue
|
||||
│ ├── adapter/
|
||||
│ │ ├── openai/ # OpenAI adapter (incl. tool call parsing, Vercel stream prepare/release)
|
||||
│ │ └── claude/ # Claude adapter
|
||||
│ ├── admin/ # Admin API handlers
|
||||
│ │ ├── claude/ # Claude adapter
|
||||
│ │ └── gemini/ # Gemini adapter (generateContent / streamGenerateContent)
|
||||
│ ├── admin/ # Admin API handlers (incl. Settings hot-reload)
|
||||
│ ├── auth/ # Auth and JWT
|
||||
│ ├── claudeconv/ # Claude message format conversion
|
||||
│ ├── compat/ # Compatibility helpers
|
||||
│ ├── config/ # Config loading and hot-reload
|
||||
│ ├── deepseek/ # DeepSeek API client, PoW WASM
|
||||
│ ├── js/ # Node runtime stream/compat logic
|
||||
│ ├── devcapture/ # Dev packet capture module
|
||||
│ ├── format/ # Output formatting
|
||||
│ ├── prompt/ # Prompt construction
|
||||
│ ├── server/ # HTTP routing and middleware (chi router)
|
||||
│ ├── sse/ # SSE parsing utilities
|
||||
│ ├── stream/ # Unified stream consumption engine
|
||||
│ ├── util/ # Common utilities
|
||||
│ └── webui/ # WebUI static file serving and auto-build
|
||||
├── webui/ # React WebUI source (Vite + Tailwind)
|
||||
│ └── src/
|
||||
│ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage
|
||||
│ ├── app/ # Routing, auth, config state
|
||||
│ ├── features/ # Feature modules (account/settings/vercel/apiTester)
|
||||
│ ├── components/ # Shared UI pieces (login/landing, etc.)
|
||||
│ └── locales/ # Language packs (zh.json / en.json)
|
||||
├── scripts/
|
||||
│ ├── build-webui.sh # Manual WebUI build script
|
||||
│ └── testsuite/ # Testsuite runner scripts
|
||||
│ └── build-webui.sh # Manual WebUI build script
|
||||
├── tests/
|
||||
│ ├── compat/ # Compatibility fixtures and expected outputs
|
||||
│ └── scripts/ # Unified test script entrypoints (unit/e2e)
|
||||
├── static/admin/ # WebUI build output (not committed to Git)
|
||||
├── .github/
|
||||
│ ├── workflows/ # GitHub Actions (Release artifact automation)
|
||||
│ ├── workflows/ # GitHub Actions (quality gates + release automation)
|
||||
│ ├── ISSUE_TEMPLATE/ # Issue templates
|
||||
│ └── PULL_REQUEST_TEMPLATE.md
|
||||
├── config.example.json # Config file template
|
||||
@@ -296,8 +442,7 @@ ds2api/
|
||||
├── docker-compose.yml # Production Docker Compose
|
||||
├── docker-compose.dev.yml # Development Docker Compose
|
||||
├── vercel.json # Vercel routing and build config
|
||||
├── go.mod / go.sum # Go module dependencies
|
||||
└── version.txt # Version number
|
||||
└── go.mod / go.sum # Go module dependencies
|
||||
```
|
||||
|
||||
## Documentation Index
|
||||
@@ -312,11 +457,11 @@ ds2api/
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
# Unit tests
|
||||
go test ./...
|
||||
# Unit tests (Go + Node)
|
||||
./tests/scripts/run-unit-all.sh
|
||||
|
||||
# One-command live end-to-end tests (real accounts, full request/response logs)
|
||||
./scripts/testsuite/run-live.sh
|
||||
./tests/scripts/run-live.sh
|
||||
|
||||
# Or with custom flags
|
||||
go run ./cmd/ds2api-tests \
|
||||
@@ -327,13 +472,22 @@ go run ./cmd/ds2api-tests \
|
||||
--retries 2
|
||||
```
|
||||
|
||||
```bash
|
||||
# Release-blocking gates
|
||||
./tests/scripts/check-stage6-manual-smoke.sh
|
||||
./tests/scripts/check-refactor-line-gate.sh
|
||||
./tests/scripts/run-unit-all.sh
|
||||
npm ci --prefix webui && npm run build --prefix webui
|
||||
```
|
||||
|
||||
## Release Artifact Automation (GitHub Actions)
|
||||
|
||||
Workflow: `.github/workflows/release-artifacts.yml`
|
||||
|
||||
- **Trigger**: only on GitHub Release `published` (normal pushes do not trigger builds)
|
||||
- **Outputs**: multi-platform archives (`linux/amd64`, `linux/arm64`, `darwin/amd64`, `darwin/arm64`, `windows/amd64`) + `sha256sums.txt`
|
||||
- **Each archive includes**: `ds2api` executable, `static/admin`, WASM file, config template, README, LICENSE
|
||||
- **Container publishing**: GHCR only (`ghcr.io/cjackhwang/ds2api`)
|
||||
- **Each archive includes**: `ds2api` executable, `static/admin`, WASM file (with embedded fallback support), config template, README, LICENSE
|
||||
|
||||
## Disclaimer
|
||||
|
||||
|
||||
74
TESTING.md
74
TESTING.md
@@ -8,8 +8,10 @@ DS2API 提供两个层级的测试:
|
||||
|
||||
| 层级 | 命令 | 说明 |
|
||||
| --- | --- | --- |
|
||||
| 单元测试 | `go test ./...` | 不需要真实账号 |
|
||||
| 端到端测试 | `./scripts/testsuite/run-live.sh` | 使用真实账号执行全链路测试 |
|
||||
| 单元测试(Go) | `./tests/scripts/run-unit-go.sh` | 不需要真实账号 |
|
||||
| 单元测试(Node) | `./tests/scripts/run-unit-node.sh` | 不需要真实账号 |
|
||||
| 单元测试(全部) | `./tests/scripts/run-unit-all.sh` | 不需要真实账号 |
|
||||
| 端到端测试 | `./tests/scripts/run-live.sh` | 使用真实账号执行全链路测试 |
|
||||
|
||||
端到端测试集会录制完整的请求/响应日志,用于故障排查。
|
||||
|
||||
@@ -20,26 +22,36 @@ DS2API 提供两个层级的测试:
|
||||
### 单元测试 | Unit Tests
|
||||
|
||||
```bash
|
||||
go test ./...
|
||||
./tests/scripts/run-unit-all.sh
|
||||
```
|
||||
|
||||
```bash
|
||||
node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js
|
||||
# 或按语言拆分执行
|
||||
./tests/scripts/run-unit-go.sh
|
||||
./tests/scripts/run-unit-node.sh
|
||||
```
|
||||
|
||||
```bash
|
||||
# 结构与流程门禁
|
||||
./tests/scripts/check-refactor-line-gate.sh
|
||||
./tests/scripts/check-node-split-syntax.sh
|
||||
|
||||
# 发布阻断:阶段 6 手工烟测签字检查(默认读取 plans/stage6-manual-smoke.md)
|
||||
./tests/scripts/check-stage6-manual-smoke.sh
|
||||
```
|
||||
|
||||
### 端到端测试 | End-to-End Tests
|
||||
|
||||
```bash
|
||||
./scripts/testsuite/run-live.sh
|
||||
./tests/scripts/run-live.sh
|
||||
```
|
||||
|
||||
**默认行为**:
|
||||
|
||||
1. **Preflight 检查**:
|
||||
- `go test ./... -count=1`(单元测试)
|
||||
- `node --check api/chat-stream.js`(语法检查)
|
||||
- `node --check api/helpers/stream-tool-sieve.js`(语法检查)
|
||||
- `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js`(Node 流式拦截单测)
|
||||
- `./tests/scripts/check-node-split-syntax.sh`(Node 拆分模块语法门禁)
|
||||
- `node --test tests/node/stream-tool-sieve.test.js tests/node/chat-stream.test.js tests/node/js_compat_test.js`
|
||||
- `npm run build --prefix webui`(WebUI 构建检查)
|
||||
|
||||
2. **隔离启动**:复制 `config.json` 到临时目录,启动独立服务进程
|
||||
@@ -161,6 +173,50 @@ rg "<trace_id>" artifacts/testsuite/<run_id>/server.log
|
||||
go test ./...
|
||||
```
|
||||
|
||||
### 运行特定模块的单元测试
|
||||
|
||||
```bash
|
||||
# 运行 tool calls 相关测试(推荐用于调试 tool call 解析问题)
|
||||
go test -v -run 'TestParseToolCalls|TestRepair' ./internal/util/
|
||||
|
||||
# 运行单个测试用例
|
||||
go test -v -run TestParseToolCallsWithDeepSeekHallucination ./internal/util/
|
||||
|
||||
# 运行 format 相关测试
|
||||
go test -v ./internal/format/...
|
||||
|
||||
# 运行 adapter 相关测试
|
||||
go test -v ./internal/adapter/openai/...
|
||||
```
|
||||
|
||||
### 调试 Tool Call 问题 | Debugging Tool Call Issues
|
||||
|
||||
当遇到 DeepSeek 工具调用解析问题时,可以使用以下方法:
|
||||
|
||||
```bash
|
||||
# 1. 运行 tool calls 相关的所有测试
|
||||
go test -v -run 'TestParseToolCalls|TestRepair' ./internal/util/
|
||||
|
||||
# 2. 查看测试输出中的详细调试信息
|
||||
go test -v -run TestParseToolCallsWithDeepSeekHallucination ./internal/util/ 2>&1
|
||||
|
||||
# 3. 检查具体测试用例的修复效果
|
||||
# 测试用例位于 internal/util/toolcalls_test.go,包含:
|
||||
# - TestParseToolCallsWithDeepSeekHallucination: DeepSeek 典型幻觉输出
|
||||
# - TestRepairLooseJSONWithNestedObjects: 嵌套对象的方括号修复
|
||||
# - TestParseToolCallsWithMixedWindowsPaths: Windows 路径处理
|
||||
```
|
||||
|
||||
### 运行 Node.js 测试
|
||||
|
||||
```bash
|
||||
# 运行 Node 测试
|
||||
node --test tests/node/stream-tool-sieve.test.js
|
||||
|
||||
# 或使用脚本
|
||||
./tests/scripts/run-unit-node.sh
|
||||
```
|
||||
|
||||
### 跑端到端测试(跳过 preflight)
|
||||
|
||||
```bash
|
||||
@@ -179,7 +235,7 @@ go run ./cmd/ds2api-tests \
|
||||
|
||||
```bash
|
||||
# 确保 config.json 存在且包含有效测试账号
|
||||
./scripts/testsuite/run-live.sh
|
||||
./tests/scripts/run-live.sh
|
||||
exit_code=$?
|
||||
if [ $exit_code -ne 0 ]; then
|
||||
echo "Tests failed! Check artifacts for details."
|
||||
|
||||
@@ -1,770 +1,3 @@
|
||||
'use strict';
|
||||
|
||||
const {
|
||||
extractToolNames,
|
||||
createToolSieveState,
|
||||
processToolSieveChunk,
|
||||
flushToolSieve,
|
||||
parseToolCalls,
|
||||
formatOpenAIStreamToolCalls,
|
||||
} = require('./helpers/stream-tool-sieve');
|
||||
|
||||
const DEEPSEEK_COMPLETION_URL = 'https://chat.deepseek.com/api/v0/chat/completion';
|
||||
|
||||
const BASE_HEADERS = {
|
||||
Host: 'chat.deepseek.com',
|
||||
'User-Agent': 'DeepSeek/1.6.11 Android/35',
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
'x-client-platform': 'android',
|
||||
'x-client-version': '1.6.11',
|
||||
'x-client-locale': 'zh_CN',
|
||||
'accept-charset': 'UTF-8',
|
||||
};
|
||||
|
||||
const SKIP_PATTERNS = [
|
||||
'quasi_status',
|
||||
'elapsed_secs',
|
||||
'token_usage',
|
||||
'pending_fragment',
|
||||
'conversation_mode',
|
||||
'fragments/-1/status',
|
||||
'fragments/-2/status',
|
||||
'fragments/-3/status',
|
||||
];
|
||||
|
||||
module.exports = async function handler(req, res) {
|
||||
setCorsHeaders(res);
|
||||
if (req.method === 'OPTIONS') {
|
||||
res.statusCode = 204;
|
||||
res.end();
|
||||
return;
|
||||
}
|
||||
if (req.method !== 'POST') {
|
||||
writeOpenAIError(res, 405, 'method not allowed');
|
||||
return;
|
||||
}
|
||||
|
||||
const rawBody = await readRawBody(req);
|
||||
|
||||
// Hard guard: only use Node data path for streaming on Vercel runtime.
|
||||
// Any non-Vercel runtime always falls back to Go for full behavior parity.
|
||||
if (!isVercelRuntime()) {
|
||||
await proxyToGo(req, res, rawBody);
|
||||
return;
|
||||
}
|
||||
|
||||
let payload;
|
||||
try {
|
||||
payload = JSON.parse(rawBody.toString('utf8') || '{}');
|
||||
} catch (_err) {
|
||||
writeOpenAIError(res, 400, 'invalid json');
|
||||
return;
|
||||
}
|
||||
|
||||
// Keep all non-stream behavior on Go side to avoid compatibility regressions.
|
||||
if (!toBool(payload.stream)) {
|
||||
await proxyToGo(req, res, rawBody);
|
||||
return;
|
||||
}
|
||||
|
||||
const prep = await fetchStreamPrepare(req, rawBody);
|
||||
if (!prep.ok) {
|
||||
relayPreparedFailure(res, prep);
|
||||
return;
|
||||
}
|
||||
|
||||
const model = asString(prep.body.model) || asString(payload.model);
|
||||
const sessionID = asString(prep.body.session_id) || `chatcmpl-${Date.now()}`;
|
||||
const leaseID = asString(prep.body.lease_id);
|
||||
const deepseekToken = asString(prep.body.deepseek_token);
|
||||
const powHeader = asString(prep.body.pow_header);
|
||||
const completionPayload = prep.body.payload && typeof prep.body.payload === 'object' ? prep.body.payload : null;
|
||||
const finalPrompt = asString(prep.body.final_prompt);
|
||||
const thinkingEnabled = toBool(prep.body.thinking_enabled);
|
||||
const searchEnabled = toBool(prep.body.search_enabled);
|
||||
const toolNames = extractToolNames(payload.tools);
|
||||
|
||||
if (!model || !leaseID || !deepseekToken || !powHeader || !completionPayload) {
|
||||
writeOpenAIError(res, 500, 'invalid vercel prepare response');
|
||||
return;
|
||||
}
|
||||
const releaseLease = createLeaseReleaser(req, leaseID);
|
||||
try {
|
||||
const completionRes = await fetch(DEEPSEEK_COMPLETION_URL, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
...BASE_HEADERS,
|
||||
authorization: `Bearer ${deepseekToken}`,
|
||||
'x-ds-pow-response': powHeader,
|
||||
},
|
||||
body: JSON.stringify(completionPayload),
|
||||
});
|
||||
|
||||
if (!completionRes.ok || !completionRes.body) {
|
||||
const detail = await safeReadText(completionRes);
|
||||
writeOpenAIError(res, 500, detail ? `Failed to get completion: ${detail}` : 'Failed to get completion.');
|
||||
return;
|
||||
}
|
||||
|
||||
res.statusCode = 200;
|
||||
res.setHeader('Content-Type', 'text/event-stream');
|
||||
res.setHeader('Cache-Control', 'no-cache, no-transform');
|
||||
res.setHeader('Connection', 'keep-alive');
|
||||
res.setHeader('X-Accel-Buffering', 'no');
|
||||
if (typeof res.flushHeaders === 'function') {
|
||||
res.flushHeaders();
|
||||
}
|
||||
|
||||
const created = Math.floor(Date.now() / 1000);
|
||||
let firstChunkSent = false;
|
||||
let currentType = thinkingEnabled ? 'thinking' : 'text';
|
||||
let thinkingText = '';
|
||||
let outputText = '';
|
||||
const toolSieveEnabled = toolNames.length > 0;
|
||||
const toolSieveState = createToolSieveState();
|
||||
let toolCallsEmitted = false;
|
||||
const decoder = new TextDecoder();
|
||||
const reader = completionRes.body.getReader();
|
||||
let buffered = '';
|
||||
let ended = false;
|
||||
|
||||
const sendFrame = (obj) => {
|
||||
res.write(`data: ${JSON.stringify(obj)}\n\n`);
|
||||
if (typeof res.flush === 'function') {
|
||||
res.flush();
|
||||
}
|
||||
};
|
||||
|
||||
const sendDeltaFrame = (delta) => {
|
||||
const payloadDelta = { ...delta };
|
||||
if (!firstChunkSent) {
|
||||
payloadDelta.role = 'assistant';
|
||||
firstChunkSent = true;
|
||||
}
|
||||
sendFrame({
|
||||
id: sessionID,
|
||||
object: 'chat.completion.chunk',
|
||||
created,
|
||||
model,
|
||||
choices: [{ delta: payloadDelta, index: 0 }],
|
||||
});
|
||||
};
|
||||
|
||||
const finish = async (reason) => {
|
||||
if (ended) {
|
||||
return;
|
||||
}
|
||||
ended = true;
|
||||
const detected = parseToolCalls(outputText, toolNames);
|
||||
if (detected.length > 0 && !toolCallsEmitted) {
|
||||
toolCallsEmitted = true;
|
||||
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(detected) });
|
||||
} else if (toolSieveEnabled) {
|
||||
const tailEvents = flushToolSieve(toolSieveState, toolNames);
|
||||
for (const evt of tailEvents) {
|
||||
if (evt.text) {
|
||||
sendDeltaFrame({ content: evt.text });
|
||||
}
|
||||
}
|
||||
}
|
||||
if (detected.length > 0 || toolCallsEmitted) {
|
||||
reason = 'tool_calls';
|
||||
}
|
||||
sendFrame({
|
||||
id: sessionID,
|
||||
object: 'chat.completion.chunk',
|
||||
created,
|
||||
model,
|
||||
choices: [{ delta: {}, index: 0, finish_reason: reason }],
|
||||
usage: buildUsage(finalPrompt, thinkingText, outputText),
|
||||
});
|
||||
res.write('data: [DONE]\n\n');
|
||||
await releaseLease();
|
||||
res.end();
|
||||
};
|
||||
|
||||
try {
|
||||
// eslint-disable-next-line no-constant-condition
|
||||
while (true) {
|
||||
const { value, done } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
buffered += decoder.decode(value, { stream: true });
|
||||
const lines = buffered.split('\n');
|
||||
buffered = lines.pop() || '';
|
||||
|
||||
for (const rawLine of lines) {
|
||||
const line = rawLine.trim();
|
||||
if (!line.startsWith('data:')) {
|
||||
continue;
|
||||
}
|
||||
const dataStr = line.slice(5).trim();
|
||||
if (!dataStr) {
|
||||
continue;
|
||||
}
|
||||
if (dataStr === '[DONE]') {
|
||||
await finish('stop');
|
||||
return;
|
||||
}
|
||||
let chunk;
|
||||
try {
|
||||
chunk = JSON.parse(dataStr);
|
||||
} catch (_err) {
|
||||
continue;
|
||||
}
|
||||
if (chunk.error || chunk.code === 'content_filter') {
|
||||
await finish('content_filter');
|
||||
return;
|
||||
}
|
||||
const parsed = parseChunkForContent(chunk, thinkingEnabled, currentType);
|
||||
currentType = parsed.newType;
|
||||
if (parsed.finished) {
|
||||
await finish('stop');
|
||||
return;
|
||||
}
|
||||
|
||||
for (const p of parsed.parts) {
|
||||
if (!p.text) {
|
||||
continue;
|
||||
}
|
||||
if (searchEnabled && isCitation(p.text)) {
|
||||
continue;
|
||||
}
|
||||
if (p.type === 'thinking') {
|
||||
if (thinkingEnabled) {
|
||||
thinkingText += p.text;
|
||||
sendDeltaFrame({ reasoning_content: p.text });
|
||||
}
|
||||
} else {
|
||||
outputText += p.text;
|
||||
if (!toolSieveEnabled) {
|
||||
sendDeltaFrame({ content: p.text });
|
||||
continue;
|
||||
}
|
||||
const events = processToolSieveChunk(toolSieveState, p.text, toolNames);
|
||||
for (const evt of events) {
|
||||
if (evt.type === 'tool_calls') {
|
||||
toolCallsEmitted = true;
|
||||
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls) });
|
||||
continue;
|
||||
}
|
||||
if (evt.text) {
|
||||
sendDeltaFrame({ content: evt.text });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
await finish('stop');
|
||||
} catch (_err) {
|
||||
await finish('stop');
|
||||
}
|
||||
} finally {
|
||||
await releaseLease();
|
||||
}
|
||||
};
|
||||
|
||||
function setCorsHeaders(res) {
|
||||
res.setHeader('Access-Control-Allow-Origin', '*');
|
||||
res.setHeader('Access-Control-Allow-Methods', 'GET, POST, OPTIONS, PUT, DELETE');
|
||||
res.setHeader(
|
||||
'Access-Control-Allow-Headers',
|
||||
'Content-Type, Authorization, X-API-Key, X-Ds2-Target-Account, X-Vercel-Protection-Bypass',
|
||||
);
|
||||
}
|
||||
|
||||
function header(req, key) {
|
||||
if (!req || !req.headers) {
|
||||
return '';
|
||||
}
|
||||
return asString(req.headers[key.toLowerCase()]);
|
||||
}
|
||||
|
||||
async function readRawBody(req) {
|
||||
if (Buffer.isBuffer(req.body)) {
|
||||
return req.body;
|
||||
}
|
||||
if (typeof req.body === 'string') {
|
||||
return Buffer.from(req.body);
|
||||
}
|
||||
if (req.body && typeof req.body === 'object') {
|
||||
return Buffer.from(JSON.stringify(req.body));
|
||||
}
|
||||
const chunks = [];
|
||||
for await (const chunk of req) {
|
||||
chunks.push(Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk));
|
||||
}
|
||||
return Buffer.concat(chunks);
|
||||
}
|
||||
|
||||
async function fetchStreamPrepare(req, rawBody) {
|
||||
const url = buildInternalGoURL(req);
|
||||
url.searchParams.set('__stream_prepare', '1');
|
||||
|
||||
const upstream = await fetch(url.toString(), {
|
||||
method: 'POST',
|
||||
headers: buildInternalGoHeaders(req, { withInternalToken: true, withContentType: true }),
|
||||
body: rawBody,
|
||||
});
|
||||
|
||||
const text = await upstream.text();
|
||||
let body = {};
|
||||
try {
|
||||
body = JSON.parse(text || '{}');
|
||||
} catch (_err) {
|
||||
body = {};
|
||||
}
|
||||
|
||||
return {
|
||||
ok: upstream.ok,
|
||||
status: upstream.status,
|
||||
contentType: upstream.headers.get('content-type') || 'application/json',
|
||||
text,
|
||||
body,
|
||||
};
|
||||
}
|
||||
|
||||
function relayPreparedFailure(res, prep) {
|
||||
if (prep.status === 401 && looksLikeVercelAuthPage(prep.text)) {
|
||||
writeOpenAIError(
|
||||
res,
|
||||
401,
|
||||
'Vercel Deployment Protection blocked internal prepare request. Disable protection for this deployment or set VERCEL_AUTOMATION_BYPASS_SECRET.',
|
||||
);
|
||||
return;
|
||||
}
|
||||
res.statusCode = prep.status || 500;
|
||||
res.setHeader('Content-Type', prep.contentType || 'application/json');
|
||||
if (prep.text) {
|
||||
res.end(prep.text);
|
||||
return;
|
||||
}
|
||||
writeOpenAIError(res, prep.status || 500, 'vercel prepare failed');
|
||||
}
|
||||
|
||||
async function safeReadText(resp) {
|
||||
if (!resp) {
|
||||
return '';
|
||||
}
|
||||
try {
|
||||
const text = await resp.text();
|
||||
return text.trim();
|
||||
} catch (_err) {
|
||||
return '';
|
||||
}
|
||||
}
|
||||
|
||||
function internalSecret() {
|
||||
return asString(process.env.DS2API_VERCEL_INTERNAL_SECRET) || asString(process.env.DS2API_ADMIN_KEY) || 'admin';
|
||||
}
|
||||
|
||||
function buildInternalGoURL(req) {
|
||||
const proto = asString(header(req, 'x-forwarded-proto')) || 'https';
|
||||
const host = asString(header(req, 'host'));
|
||||
const url = new URL(`${proto}://${host}${req.url || '/v1/chat/completions'}`);
|
||||
url.searchParams.set('__go', '1');
|
||||
const protectionBypass = resolveProtectionBypass(req);
|
||||
if (protectionBypass) {
|
||||
url.searchParams.set('x-vercel-protection-bypass', protectionBypass);
|
||||
}
|
||||
return url;
|
||||
}
|
||||
|
||||
function buildInternalGoHeaders(req, opts = {}) {
|
||||
const headers = {
|
||||
authorization: asString(header(req, 'authorization')),
|
||||
'x-api-key': asString(header(req, 'x-api-key')),
|
||||
'x-ds2-target-account': asString(header(req, 'x-ds2-target-account')),
|
||||
'x-vercel-protection-bypass': resolveProtectionBypass(req),
|
||||
};
|
||||
if (opts.withInternalToken) {
|
||||
headers['x-ds2-internal-token'] = internalSecret();
|
||||
}
|
||||
if (opts.withContentType) {
|
||||
headers['content-type'] = asString(header(req, 'content-type')) || 'application/json';
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
|
||||
function createLeaseReleaser(req, leaseID) {
|
||||
let released = false;
|
||||
return async () => {
|
||||
if (released || !leaseID) {
|
||||
return;
|
||||
}
|
||||
released = true;
|
||||
try {
|
||||
await releaseStreamLease(req, leaseID);
|
||||
} catch (_err) {
|
||||
// Ignore release errors. Lease TTL cleanup on Go side still prevents permanent leaks.
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
async function releaseStreamLease(req, leaseID) {
|
||||
const url = buildInternalGoURL(req);
|
||||
url.searchParams.set('__stream_release', '1');
|
||||
const body = Buffer.from(JSON.stringify({ lease_id: leaseID }));
|
||||
|
||||
const controller = new AbortController();
|
||||
const timeout = setTimeout(() => controller.abort(), 1500);
|
||||
try {
|
||||
await fetch(url.toString(), {
|
||||
method: 'POST',
|
||||
headers: buildInternalGoHeaders(req, { withInternalToken: true, withContentType: true }),
|
||||
body,
|
||||
signal: controller.signal,
|
||||
});
|
||||
} finally {
|
||||
clearTimeout(timeout);
|
||||
}
|
||||
}
|
||||
|
||||
function resolveProtectionBypass(req) {
|
||||
const fromHeader = asString(header(req, 'x-vercel-protection-bypass'));
|
||||
if (fromHeader) {
|
||||
return fromHeader;
|
||||
}
|
||||
return asString(process.env.VERCEL_AUTOMATION_BYPASS_SECRET) || asString(process.env.DS2API_VERCEL_PROTECTION_BYPASS);
|
||||
}
|
||||
|
||||
function looksLikeVercelAuthPage(text) {
|
||||
const body = asString(text).toLowerCase();
|
||||
if (!body) {
|
||||
return false;
|
||||
}
|
||||
return body.includes('authentication required') && body.includes('vercel');
|
||||
}
|
||||
|
||||
function parseChunkForContent(chunk, thinkingEnabled, currentType) {
|
||||
if (!chunk || typeof chunk !== 'object' || !Object.prototype.hasOwnProperty.call(chunk, 'v')) {
|
||||
return { parts: [], finished: false, newType: currentType };
|
||||
}
|
||||
const pathValue = asString(chunk.p);
|
||||
if (shouldSkipPath(pathValue)) {
|
||||
return { parts: [], finished: false, newType: currentType };
|
||||
}
|
||||
if (pathValue === 'response/status' && asString(chunk.v) === 'FINISHED') {
|
||||
return { parts: [], finished: true, newType: currentType };
|
||||
}
|
||||
|
||||
let newType = currentType;
|
||||
const parts = [];
|
||||
|
||||
if (pathValue === 'response/fragments' && asString(chunk.o).toUpperCase() === 'APPEND' && Array.isArray(chunk.v)) {
|
||||
for (const frag of chunk.v) {
|
||||
if (!frag || typeof frag !== 'object') {
|
||||
continue;
|
||||
}
|
||||
const fragType = asString(frag.type).toUpperCase();
|
||||
const content = asString(frag.content);
|
||||
if (!content) {
|
||||
continue;
|
||||
}
|
||||
if (fragType === 'THINK' || fragType === 'THINKING') {
|
||||
newType = 'thinking';
|
||||
parts.push({ text: content, type: 'thinking' });
|
||||
} else if (fragType === 'RESPONSE') {
|
||||
newType = 'text';
|
||||
parts.push({ text: content, type: 'text' });
|
||||
} else {
|
||||
parts.push({ text: content, type: 'text' });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (pathValue === 'response' && Array.isArray(chunk.v)) {
|
||||
for (const item of chunk.v) {
|
||||
if (!item || typeof item !== 'object') {
|
||||
continue;
|
||||
}
|
||||
if (item.p === 'fragments' && item.o === 'APPEND' && Array.isArray(item.v)) {
|
||||
for (const frag of item.v) {
|
||||
const fragType = asString(frag && frag.type).toUpperCase();
|
||||
if (fragType === 'THINK' || fragType === 'THINKING') {
|
||||
newType = 'thinking';
|
||||
} else if (fragType === 'RESPONSE') {
|
||||
newType = 'text';
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let partType = 'text';
|
||||
if (pathValue === 'response/thinking_content') {
|
||||
partType = 'thinking';
|
||||
} else if (pathValue === 'response/content') {
|
||||
partType = 'text';
|
||||
} else if (pathValue.includes('response/fragments') && pathValue.includes('/content')) {
|
||||
partType = newType;
|
||||
} else if (!pathValue && thinkingEnabled) {
|
||||
partType = newType;
|
||||
}
|
||||
|
||||
const val = chunk.v;
|
||||
if (typeof val === 'string') {
|
||||
if (val === 'FINISHED' && (!pathValue || pathValue === 'status')) {
|
||||
return { parts: [], finished: true, newType };
|
||||
}
|
||||
if (val) {
|
||||
parts.push({ text: val, type: partType });
|
||||
}
|
||||
return { parts, finished: false, newType };
|
||||
}
|
||||
|
||||
if (Array.isArray(val)) {
|
||||
const extracted = extractContentRecursive(val, partType);
|
||||
if (extracted.finished) {
|
||||
return { parts: [], finished: true, newType };
|
||||
}
|
||||
parts.push(...extracted.parts);
|
||||
return { parts, finished: false, newType };
|
||||
}
|
||||
|
||||
if (val && typeof val === 'object') {
|
||||
const resp = val.response && typeof val.response === 'object' ? val.response : val;
|
||||
if (Array.isArray(resp.fragments)) {
|
||||
for (const frag of resp.fragments) {
|
||||
if (!frag || typeof frag !== 'object') {
|
||||
continue;
|
||||
}
|
||||
const content = asString(frag.content);
|
||||
if (!content) {
|
||||
continue;
|
||||
}
|
||||
const t = asString(frag.type).toUpperCase();
|
||||
if (t === 'THINK' || t === 'THINKING') {
|
||||
newType = 'thinking';
|
||||
parts.push({ text: content, type: 'thinking' });
|
||||
} else if (t === 'RESPONSE') {
|
||||
newType = 'text';
|
||||
parts.push({ text: content, type: 'text' });
|
||||
} else {
|
||||
parts.push({ text: content, type: partType });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return { parts, finished: false, newType };
|
||||
}
|
||||
|
||||
function extractContentRecursive(items, defaultType) {
|
||||
const parts = [];
|
||||
for (const it of items) {
|
||||
if (!it || typeof it !== 'object') {
|
||||
continue;
|
||||
}
|
||||
if (!Object.prototype.hasOwnProperty.call(it, 'v')) {
|
||||
continue;
|
||||
}
|
||||
const itemPath = asString(it.p);
|
||||
const itemV = it.v;
|
||||
if (itemPath === 'status' && asString(itemV) === 'FINISHED') {
|
||||
return { parts: [], finished: true };
|
||||
}
|
||||
if (shouldSkipPath(itemPath)) {
|
||||
continue;
|
||||
}
|
||||
const content = asString(it.content);
|
||||
if (content) {
|
||||
const typeName = asString(it.type).toUpperCase();
|
||||
if (typeName === 'THINK' || typeName === 'THINKING') {
|
||||
parts.push({ text: content, type: 'thinking' });
|
||||
} else if (typeName === 'RESPONSE') {
|
||||
parts.push({ text: content, type: 'text' });
|
||||
} else {
|
||||
parts.push({ text: content, type: defaultType });
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
let partType = defaultType;
|
||||
if (itemPath.includes('thinking')) {
|
||||
partType = 'thinking';
|
||||
} else if (itemPath.includes('content') || itemPath === 'response' || itemPath === 'fragments') {
|
||||
partType = 'text';
|
||||
}
|
||||
|
||||
if (typeof itemV === 'string') {
|
||||
if (itemV && itemV !== 'FINISHED') {
|
||||
parts.push({ text: itemV, type: partType });
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!Array.isArray(itemV)) {
|
||||
continue;
|
||||
}
|
||||
for (const inner of itemV) {
|
||||
if (typeof inner === 'string') {
|
||||
if (inner) {
|
||||
parts.push({ text: inner, type: partType });
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (!inner || typeof inner !== 'object') {
|
||||
continue;
|
||||
}
|
||||
const ct = asString(inner.content);
|
||||
if (!ct) {
|
||||
continue;
|
||||
}
|
||||
const typeName = asString(inner.type).toUpperCase();
|
||||
if (typeName === 'THINK' || typeName === 'THINKING') {
|
||||
parts.push({ text: ct, type: 'thinking' });
|
||||
} else if (typeName === 'RESPONSE') {
|
||||
parts.push({ text: ct, type: 'text' });
|
||||
} else {
|
||||
parts.push({ text: ct, type: partType });
|
||||
}
|
||||
}
|
||||
}
|
||||
return { parts, finished: false };
|
||||
}
|
||||
|
||||
function shouldSkipPath(pathValue) {
|
||||
if (pathValue === 'response/search_status') {
|
||||
return true;
|
||||
}
|
||||
for (const p of SKIP_PATTERNS) {
|
||||
if (pathValue.includes(p)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
function isCitation(text) {
|
||||
return asString(text).trim().startsWith('[citation:');
|
||||
}
|
||||
|
||||
function buildUsage(prompt, thinking, output) {
|
||||
const promptTokens = estimateTokens(prompt);
|
||||
const reasoningTokens = estimateTokens(thinking);
|
||||
const completionTokens = estimateTokens(output);
|
||||
return {
|
||||
prompt_tokens: promptTokens,
|
||||
completion_tokens: reasoningTokens + completionTokens,
|
||||
total_tokens: promptTokens + reasoningTokens + completionTokens,
|
||||
completion_tokens_details: {
|
||||
reasoning_tokens: reasoningTokens,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function estimateTokens(text) {
|
||||
const t = asString(text);
|
||||
if (!t) {
|
||||
return 0;
|
||||
}
|
||||
const n = Math.floor(Array.from(t).length / 4);
|
||||
return n < 1 ? 1 : n;
|
||||
}
|
||||
|
||||
async function proxyToGo(req, res, rawBody) {
|
||||
const url = buildInternalGoURL(req);
|
||||
|
||||
const upstream = await fetch(url.toString(), {
|
||||
method: 'POST',
|
||||
headers: buildInternalGoHeaders(req, { withContentType: true }),
|
||||
body: rawBody,
|
||||
});
|
||||
|
||||
res.statusCode = upstream.status;
|
||||
upstream.headers.forEach((value, key) => {
|
||||
if (key.toLowerCase() === 'content-length') {
|
||||
return;
|
||||
}
|
||||
res.setHeader(key, value);
|
||||
});
|
||||
|
||||
if (!upstream.body || typeof upstream.body.getReader !== 'function') {
|
||||
const bytes = Buffer.from(await upstream.arrayBuffer());
|
||||
res.end(bytes);
|
||||
return;
|
||||
}
|
||||
|
||||
const reader = upstream.body.getReader();
|
||||
try {
|
||||
// eslint-disable-next-line no-constant-condition
|
||||
while (true) {
|
||||
const { value, done } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
if (value && value.length > 0) {
|
||||
res.write(Buffer.from(value));
|
||||
if (typeof res.flush === 'function') {
|
||||
res.flush();
|
||||
}
|
||||
}
|
||||
}
|
||||
res.end();
|
||||
} catch (_err) {
|
||||
if (!res.writableEnded) {
|
||||
res.end();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function writeOpenAIError(res, status, message) {
|
||||
res.statusCode = status;
|
||||
res.setHeader('Content-Type', 'application/json');
|
||||
res.end(
|
||||
JSON.stringify({
|
||||
error: {
|
||||
message,
|
||||
type: openAIErrorType(status),
|
||||
},
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
function openAIErrorType(status) {
|
||||
switch (status) {
|
||||
case 400:
|
||||
return 'invalid_request_error';
|
||||
case 401:
|
||||
return 'authentication_error';
|
||||
case 403:
|
||||
return 'permission_error';
|
||||
case 429:
|
||||
return 'rate_limit_error';
|
||||
case 503:
|
||||
return 'service_unavailable_error';
|
||||
default:
|
||||
return status >= 500 ? 'api_error' : 'invalid_request_error';
|
||||
}
|
||||
}
|
||||
|
||||
function toBool(v) {
|
||||
return v === true;
|
||||
}
|
||||
|
||||
function isVercelRuntime() {
|
||||
return asString(process.env.VERCEL) !== '' || asString(process.env.NOW_REGION) !== '';
|
||||
}
|
||||
|
||||
function asString(v) {
|
||||
if (typeof v === 'string') {
|
||||
return v.trim();
|
||||
}
|
||||
if (Array.isArray(v)) {
|
||||
return asString(v[0]);
|
||||
}
|
||||
if (v == null) {
|
||||
return '';
|
||||
}
|
||||
return String(v).trim();
|
||||
}
|
||||
|
||||
module.exports.__test = {
|
||||
parseChunkForContent,
|
||||
extractContentRecursive,
|
||||
shouldSkipPath,
|
||||
asString,
|
||||
};
|
||||
module.exports = require('../internal/js/chat-stream/index.js');
|
||||
|
||||
@@ -1,477 +0,0 @@
|
||||
'use strict';
|
||||
|
||||
const crypto = require('crypto');
|
||||
const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s;
|
||||
|
||||
function extractToolNames(tools) {
|
||||
if (!Array.isArray(tools) || tools.length === 0) {
|
||||
return [];
|
||||
}
|
||||
const out = [];
|
||||
for (const t of tools) {
|
||||
if (!t || typeof t !== 'object') {
|
||||
continue;
|
||||
}
|
||||
const fn = t.function && typeof t.function === 'object' ? t.function : t;
|
||||
const name = toStringSafe(fn.name);
|
||||
// Keep parity with Go injectToolPrompt: object tools without name still
|
||||
// enter tool mode via fallback name "unknown".
|
||||
out.push(name || 'unknown');
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function createToolSieveState() {
|
||||
return {
|
||||
pending: '',
|
||||
capture: '',
|
||||
capturing: false,
|
||||
};
|
||||
}
|
||||
|
||||
function processToolSieveChunk(state, chunk, toolNames) {
|
||||
if (!state) {
|
||||
return [];
|
||||
}
|
||||
if (chunk) {
|
||||
state.pending += chunk;
|
||||
}
|
||||
const events = [];
|
||||
// eslint-disable-next-line no-constant-condition
|
||||
while (true) {
|
||||
if (state.capturing) {
|
||||
if (state.pending) {
|
||||
state.capture += state.pending;
|
||||
state.pending = '';
|
||||
}
|
||||
const consumed = consumeToolCapture(state.capture, toolNames);
|
||||
if (!consumed.ready) {
|
||||
break;
|
||||
}
|
||||
state.capture = '';
|
||||
state.capturing = false;
|
||||
if (consumed.prefix) {
|
||||
events.push({ type: 'text', text: consumed.prefix });
|
||||
}
|
||||
if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
|
||||
events.push({ type: 'tool_calls', calls: consumed.calls });
|
||||
}
|
||||
if (consumed.suffix) {
|
||||
state.pending += consumed.suffix;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!state.pending) {
|
||||
break;
|
||||
}
|
||||
|
||||
const start = findToolSegmentStart(state.pending);
|
||||
if (start >= 0) {
|
||||
const prefix = state.pending.slice(0, start);
|
||||
if (prefix) {
|
||||
events.push({ type: 'text', text: prefix });
|
||||
}
|
||||
state.capture = state.pending.slice(start);
|
||||
state.pending = '';
|
||||
state.capturing = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
const [safe, hold] = splitSafeContentForToolDetection(state.pending);
|
||||
if (!safe) {
|
||||
break;
|
||||
}
|
||||
state.pending = hold;
|
||||
events.push({ type: 'text', text: safe });
|
||||
}
|
||||
return events;
|
||||
}
|
||||
|
||||
function flushToolSieve(state, toolNames) {
|
||||
if (!state) {
|
||||
return [];
|
||||
}
|
||||
const events = processToolSieveChunk(state, '', toolNames);
|
||||
if (state.capturing) {
|
||||
const consumed = consumeToolCapture(state.capture, toolNames);
|
||||
if (consumed.ready) {
|
||||
if (consumed.prefix) {
|
||||
events.push({ type: 'text', text: consumed.prefix });
|
||||
}
|
||||
if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
|
||||
events.push({ type: 'tool_calls', calls: consumed.calls });
|
||||
}
|
||||
if (consumed.suffix) {
|
||||
events.push({ type: 'text', text: consumed.suffix });
|
||||
}
|
||||
} else if (state.capture) {
|
||||
// Incomplete captured tool JSON at stream end: suppress raw capture.
|
||||
}
|
||||
state.capture = '';
|
||||
state.capturing = false;
|
||||
}
|
||||
if (state.pending) {
|
||||
events.push({ type: 'text', text: state.pending });
|
||||
state.pending = '';
|
||||
}
|
||||
return events;
|
||||
}
|
||||
|
||||
function splitSafeContentForToolDetection(s) {
|
||||
const text = s || '';
|
||||
if (!text) {
|
||||
return ['', ''];
|
||||
}
|
||||
const suspiciousStart = findSuspiciousPrefixStart(text);
|
||||
if (suspiciousStart < 0) {
|
||||
return [text, ''];
|
||||
}
|
||||
if (suspiciousStart > 0) {
|
||||
return [text.slice(0, suspiciousStart), text.slice(suspiciousStart)];
|
||||
}
|
||||
// If suspicious content starts at the beginning, keep holding until we can
|
||||
// either parse a full tool JSON block or reach stream flush.
|
||||
return ['', text];
|
||||
}
|
||||
|
||||
function findSuspiciousPrefixStart(s) {
|
||||
let start = -1;
|
||||
for (const needle of ['{', '[', '```']) {
|
||||
const idx = s.lastIndexOf(needle);
|
||||
if (idx > start) {
|
||||
start = idx;
|
||||
}
|
||||
}
|
||||
return start;
|
||||
}
|
||||
|
||||
function findToolSegmentStart(s) {
|
||||
if (!s) {
|
||||
return -1;
|
||||
}
|
||||
const lower = s.toLowerCase();
|
||||
const keyIdx = lower.indexOf('tool_calls');
|
||||
if (keyIdx < 0) {
|
||||
return -1;
|
||||
}
|
||||
const start = s.slice(0, keyIdx).lastIndexOf('{');
|
||||
return start >= 0 ? start : keyIdx;
|
||||
}
|
||||
|
||||
function consumeToolCapture(captured, toolNames) {
|
||||
if (!captured) {
|
||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||
}
|
||||
const lower = captured.toLowerCase();
|
||||
const keyIdx = lower.indexOf('tool_calls');
|
||||
if (keyIdx < 0) {
|
||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||
}
|
||||
const start = captured.slice(0, keyIdx).lastIndexOf('{');
|
||||
if (start < 0) {
|
||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||
}
|
||||
const obj = extractJSONObjectFrom(captured, start);
|
||||
if (!obj.ok) {
|
||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||
}
|
||||
const parsed = parseToolCalls(captured.slice(start, obj.end), toolNames);
|
||||
if (parsed.length === 0) {
|
||||
// `tool_calls` key exists but strict JSON parse failed.
|
||||
// Drop the captured object body to avoid leaking raw tool JSON.
|
||||
return {
|
||||
ready: true,
|
||||
prefix: captured.slice(0, start),
|
||||
calls: [],
|
||||
suffix: captured.slice(obj.end),
|
||||
};
|
||||
}
|
||||
return {
|
||||
ready: true,
|
||||
prefix: captured.slice(0, start),
|
||||
calls: parsed,
|
||||
suffix: captured.slice(obj.end),
|
||||
};
|
||||
}
|
||||
|
||||
function extractJSONObjectFrom(text, start) {
|
||||
if (!text || start < 0 || start >= text.length || text[start] !== '{') {
|
||||
return { ok: false, end: 0 };
|
||||
}
|
||||
let depth = 0;
|
||||
let quote = '';
|
||||
let escaped = false;
|
||||
for (let i = start; i < text.length; i += 1) {
|
||||
const ch = text[i];
|
||||
if (quote) {
|
||||
if (escaped) {
|
||||
escaped = false;
|
||||
continue;
|
||||
}
|
||||
if (ch === '\\') {
|
||||
escaped = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === quote) {
|
||||
quote = '';
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (ch === '"' || ch === "'") {
|
||||
quote = ch;
|
||||
continue;
|
||||
}
|
||||
if (ch === '{') {
|
||||
depth += 1;
|
||||
continue;
|
||||
}
|
||||
if (ch === '}') {
|
||||
depth -= 1;
|
||||
if (depth === 0) {
|
||||
return { ok: true, end: i + 1 };
|
||||
}
|
||||
}
|
||||
}
|
||||
return { ok: false, end: 0 };
|
||||
}
|
||||
|
||||
function parseToolCalls(text, toolNames) {
|
||||
if (!toStringSafe(text)) {
|
||||
return [];
|
||||
}
|
||||
const candidates = buildToolCallCandidates(text);
|
||||
let parsed = [];
|
||||
for (const c of candidates) {
|
||||
parsed = parseToolCallsPayload(c);
|
||||
if (parsed.length > 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (parsed.length === 0) {
|
||||
return [];
|
||||
}
|
||||
const allowed = new Set((toolNames || []).filter(Boolean));
|
||||
const out = [];
|
||||
for (const tc of parsed) {
|
||||
if (!tc || !tc.name) {
|
||||
continue;
|
||||
}
|
||||
if (allowed.size > 0 && !allowed.has(tc.name)) {
|
||||
continue;
|
||||
}
|
||||
out.push({ name: tc.name, input: tc.input || {} });
|
||||
}
|
||||
if (out.length === 0 && parsed.length > 0) {
|
||||
for (const tc of parsed) {
|
||||
if (!tc || !tc.name) {
|
||||
continue;
|
||||
}
|
||||
out.push({ name: tc.name, input: tc.input || {} });
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function buildToolCallCandidates(text) {
|
||||
const trimmed = toStringSafe(text);
|
||||
const candidates = [trimmed];
|
||||
const fenced = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/gi) || [];
|
||||
for (const block of fenced) {
|
||||
const m = block.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
|
||||
if (m && m[1]) {
|
||||
candidates.push(toStringSafe(m[1]));
|
||||
}
|
||||
}
|
||||
for (const candidate of extractToolCallObjects(trimmed)) {
|
||||
candidates.push(toStringSafe(candidate));
|
||||
}
|
||||
const first = trimmed.indexOf('{');
|
||||
const last = trimmed.lastIndexOf('}');
|
||||
if (first >= 0 && last > first) {
|
||||
candidates.push(toStringSafe(trimmed.slice(first, last + 1)));
|
||||
}
|
||||
const m = trimmed.match(TOOL_CALL_PATTERN);
|
||||
if (m && m[1]) {
|
||||
candidates.push(`{"tool_calls":[${m[1]}]}`);
|
||||
}
|
||||
return [...new Set(candidates.filter(Boolean))];
|
||||
}
|
||||
|
||||
function extractToolCallObjects(text) {
|
||||
const raw = toStringSafe(text);
|
||||
if (!raw) {
|
||||
return [];
|
||||
}
|
||||
const lower = raw.toLowerCase();
|
||||
const out = [];
|
||||
let offset = 0;
|
||||
// eslint-disable-next-line no-constant-condition
|
||||
while (true) {
|
||||
let idx = lower.indexOf('tool_calls', offset);
|
||||
if (idx < 0) {
|
||||
break;
|
||||
}
|
||||
let start = raw.slice(0, idx).lastIndexOf('{');
|
||||
while (start >= 0) {
|
||||
const obj = extractJSONObjectFrom(raw, start);
|
||||
if (obj.ok) {
|
||||
out.push(raw.slice(start, obj.end).trim());
|
||||
offset = obj.end;
|
||||
idx = -1;
|
||||
break;
|
||||
}
|
||||
start = raw.slice(0, start).lastIndexOf('{');
|
||||
}
|
||||
if (idx >= 0) {
|
||||
offset = idx + 'tool_calls'.length;
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function parseToolCallsPayload(payload) {
|
||||
let decoded;
|
||||
try {
|
||||
decoded = JSON.parse(payload);
|
||||
} catch (_err) {
|
||||
return [];
|
||||
}
|
||||
if (Array.isArray(decoded)) {
|
||||
return parseToolCallList(decoded);
|
||||
}
|
||||
if (!decoded || typeof decoded !== 'object') {
|
||||
return [];
|
||||
}
|
||||
if (decoded.tool_calls) {
|
||||
return parseToolCallList(decoded.tool_calls);
|
||||
}
|
||||
const one = parseToolCallItem(decoded);
|
||||
return one ? [one] : [];
|
||||
}
|
||||
|
||||
function parseToolCallList(v) {
|
||||
if (!Array.isArray(v)) {
|
||||
return [];
|
||||
}
|
||||
const out = [];
|
||||
for (const item of v) {
|
||||
if (!item || typeof item !== 'object') {
|
||||
continue;
|
||||
}
|
||||
const one = parseToolCallItem(item);
|
||||
if (one) {
|
||||
out.push(one);
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function parseToolCallItem(m) {
|
||||
let name = toStringSafe(m.name);
|
||||
let inputRaw = m.input;
|
||||
let hasInput = Object.prototype.hasOwnProperty.call(m, 'input');
|
||||
const fn = m.function && typeof m.function === 'object' ? m.function : null;
|
||||
if (fn) {
|
||||
if (!name) {
|
||||
name = toStringSafe(fn.name);
|
||||
}
|
||||
if (!hasInput && Object.prototype.hasOwnProperty.call(fn, 'arguments')) {
|
||||
inputRaw = fn.arguments;
|
||||
hasInput = true;
|
||||
}
|
||||
}
|
||||
if (!hasInput) {
|
||||
for (const k of ['arguments', 'args', 'parameters', 'params']) {
|
||||
if (Object.prototype.hasOwnProperty.call(m, k)) {
|
||||
inputRaw = m[k];
|
||||
hasInput = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!name) {
|
||||
return null;
|
||||
}
|
||||
return {
|
||||
name,
|
||||
input: parseToolCallInput(inputRaw),
|
||||
};
|
||||
}
|
||||
|
||||
function parseToolCallInput(v) {
|
||||
if (v == null) {
|
||||
return {};
|
||||
}
|
||||
if (typeof v === 'string') {
|
||||
const raw = toStringSafe(v);
|
||||
if (!raw) {
|
||||
return {};
|
||||
}
|
||||
try {
|
||||
const parsed = JSON.parse(raw);
|
||||
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
|
||||
return parsed;
|
||||
}
|
||||
return { _raw: raw };
|
||||
} catch (_err) {
|
||||
return { _raw: raw };
|
||||
}
|
||||
}
|
||||
if (typeof v === 'object' && !Array.isArray(v)) {
|
||||
return v;
|
||||
}
|
||||
try {
|
||||
const parsed = JSON.parse(JSON.stringify(v));
|
||||
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
|
||||
return parsed;
|
||||
}
|
||||
} catch (_err) {
|
||||
return {};
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
function formatOpenAIStreamToolCalls(calls) {
|
||||
if (!Array.isArray(calls) || calls.length === 0) {
|
||||
return [];
|
||||
}
|
||||
return calls.map((c, idx) => ({
|
||||
index: idx,
|
||||
id: `call_${newCallID()}`,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: c.name,
|
||||
arguments: JSON.stringify(c.input || {}),
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
function newCallID() {
|
||||
if (typeof crypto.randomUUID === 'function') {
|
||||
return crypto.randomUUID().replace(/-/g, '');
|
||||
}
|
||||
return `${Date.now()}${Math.floor(Math.random() * 1e9)}`;
|
||||
}
|
||||
|
||||
function toStringSafe(v) {
|
||||
if (typeof v === 'string') {
|
||||
return v.trim();
|
||||
}
|
||||
if (Array.isArray(v)) {
|
||||
return toStringSafe(v[0]);
|
||||
}
|
||||
if (v == null) {
|
||||
return '';
|
||||
}
|
||||
return String(v).trim();
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
extractToolNames,
|
||||
createToolSieveState,
|
||||
processToolSieveChunk,
|
||||
flushToolSieve,
|
||||
parseToolCalls,
|
||||
formatOpenAIStreamToolCalls,
|
||||
};
|
||||
@@ -1,130 +0,0 @@
|
||||
'use strict';
|
||||
|
||||
const test = require('node:test');
|
||||
const assert = require('node:assert/strict');
|
||||
|
||||
const {
|
||||
extractToolNames,
|
||||
createToolSieveState,
|
||||
processToolSieveChunk,
|
||||
flushToolSieve,
|
||||
parseToolCalls,
|
||||
} = require('./stream-tool-sieve');
|
||||
|
||||
function runSieve(chunks, toolNames) {
|
||||
const state = createToolSieveState();
|
||||
const events = [];
|
||||
for (const chunk of chunks) {
|
||||
events.push(...processToolSieveChunk(state, chunk, toolNames));
|
||||
}
|
||||
events.push(...flushToolSieve(state, toolNames));
|
||||
return events;
|
||||
}
|
||||
|
||||
function collectText(events) {
|
||||
return events
|
||||
.filter((evt) => evt.type === 'text' && evt.text)
|
||||
.map((evt) => evt.text)
|
||||
.join('');
|
||||
}
|
||||
|
||||
test('extractToolNames keeps tool mode enabled with unknown fallback', () => {
|
||||
const names = extractToolNames([
|
||||
{ function: { description: 'no name tool' } },
|
||||
{ function: { name: ' read_file ' } },
|
||||
{},
|
||||
]);
|
||||
assert.deepEqual(names, ['unknown', 'read_file', 'unknown']);
|
||||
});
|
||||
|
||||
test('parseToolCalls keeps non-object argument strings as _raw (Go parity)', () => {
|
||||
const payload = JSON.stringify({
|
||||
tool_calls: [
|
||||
{ name: 'read_file', input: '123' },
|
||||
{ name: 'list_dir', input: '[1,2,3]' },
|
||||
],
|
||||
});
|
||||
const calls = parseToolCalls(payload, ['read_file', 'list_dir']);
|
||||
assert.deepEqual(calls, [
|
||||
{ name: 'read_file', input: { _raw: '123' } },
|
||||
{ name: 'list_dir', input: { _raw: '[1,2,3]' } },
|
||||
]);
|
||||
});
|
||||
|
||||
test('parseToolCalls still intercepts unknown schema names to avoid leaks', () => {
|
||||
const payload = JSON.stringify({
|
||||
tool_calls: [{ name: 'not_in_schema', input: { q: 'go' } }],
|
||||
});
|
||||
const calls = parseToolCalls(payload, ['search']);
|
||||
assert.equal(calls.length, 1);
|
||||
assert.equal(calls[0].name, 'not_in_schema');
|
||||
});
|
||||
|
||||
test('parseToolCalls supports fenced json and function.arguments string payload', () => {
|
||||
const text = [
|
||||
'I will call a tool now.',
|
||||
'```json',
|
||||
'{"tool_calls":[{"function":{"name":"read_file","arguments":"{\\"path\\":\\"README.md\\"}"}}]}',
|
||||
'```',
|
||||
].join('\n');
|
||||
const calls = parseToolCalls(text, ['read_file']);
|
||||
assert.equal(calls.length, 1);
|
||||
assert.equal(calls[0].name, 'read_file');
|
||||
assert.deepEqual(calls[0].input, { path: 'README.md' });
|
||||
});
|
||||
|
||||
test('sieve emits tool_calls and does not leak suspicious prefix on late key convergence', () => {
|
||||
const events = runSieve(
|
||||
[
|
||||
'{"',
|
||||
'tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}',
|
||||
'后置正文C。',
|
||||
],
|
||||
['read_file'],
|
||||
);
|
||||
const leakedText = collectText(events);
|
||||
const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0);
|
||||
assert.equal(hasToolCall, true);
|
||||
assert.equal(leakedText.includes('{'), false);
|
||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
||||
assert.equal(leakedText.includes('后置正文C。'), true);
|
||||
});
|
||||
|
||||
test('sieve drops invalid tool json body while preserving surrounding text', () => {
|
||||
const events = runSieve(
|
||||
[
|
||||
'前置正文D。',
|
||||
"{'tool_calls':[{'name':'read_file','input':{'path':'README.MD'}}]}",
|
||||
'后置正文E。',
|
||||
],
|
||||
['read_file'],
|
||||
);
|
||||
const leakedText = collectText(events);
|
||||
const hasToolCall = events.some((evt) => evt.type === 'tool_calls');
|
||||
assert.equal(hasToolCall, false);
|
||||
assert.equal(leakedText.includes('前置正文D。'), true);
|
||||
assert.equal(leakedText.includes('后置正文E。'), true);
|
||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
||||
});
|
||||
|
||||
test('sieve suppresses incomplete captured tool json on stream finalize', () => {
|
||||
const events = runSieve(
|
||||
['前置正文F。', '{"tool_calls":[{"name":"read_file"'],
|
||||
['read_file'],
|
||||
);
|
||||
const leakedText = collectText(events);
|
||||
assert.equal(leakedText.includes('前置正文F。'), true);
|
||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
||||
assert.equal(leakedText.includes('{'), false);
|
||||
});
|
||||
|
||||
test('sieve keeps plain text intact in tool mode when no tool call appears', () => {
|
||||
const events = runSieve(
|
||||
['你好,', '这是普通文本回复。', '请继续。'],
|
||||
['read_file'],
|
||||
);
|
||||
const leakedText = collectText(events);
|
||||
const hasToolCall = events.some((evt) => evt.type === 'tool_calls');
|
||||
assert.equal(hasToolCall, false);
|
||||
assert.equal(leakedText, '你好,这是普通文本回复。请继续。');
|
||||
});
|
||||
@@ -2,6 +2,8 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -28,10 +30,21 @@ func main() {
|
||||
Addr: "0.0.0.0:" + port,
|
||||
Handler: app.Router,
|
||||
}
|
||||
localURL := fmt.Sprintf("http://127.0.0.1:%s", port)
|
||||
lanIP := detectLANIPv4()
|
||||
lanURL := ""
|
||||
if lanIP != "" {
|
||||
lanURL = fmt.Sprintf("http://%s:%s", lanIP, port)
|
||||
}
|
||||
|
||||
// Start server in a goroutine so we can listen for shutdown signals.
|
||||
go func() {
|
||||
config.Logger.Info("starting ds2api", "port", port)
|
||||
if lanURL != "" {
|
||||
config.Logger.Info("starting ds2api", "bind", srv.Addr, "port", port, "local_url", localURL, "lan_url", lanURL, "lan_ip", lanIP)
|
||||
} else {
|
||||
config.Logger.Info("starting ds2api", "bind", srv.Addr, "port", port, "local_url", localURL)
|
||||
config.Logger.Warn("lan ip not detected; check active network interfaces")
|
||||
}
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
config.Logger.Error("server stopped unexpectedly", "error", err)
|
||||
os.Exit(1)
|
||||
@@ -54,3 +67,36 @@ func main() {
|
||||
}
|
||||
config.Logger.Info("server gracefully stopped")
|
||||
}
|
||||
|
||||
func detectLANIPv4() string {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, iface := range ifaces {
|
||||
if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 {
|
||||
continue
|
||||
}
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
var ip net.IP
|
||||
switch v := addr.(type) {
|
||||
case *net.IPNet:
|
||||
ip = v.IP
|
||||
case *net.IPAddr:
|
||||
ip = v.IP
|
||||
default:
|
||||
continue
|
||||
}
|
||||
ip = ip.To4()
|
||||
if ip == nil || !ip.IsPrivate() {
|
||||
continue
|
||||
}
|
||||
return ip.String()
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -9,20 +9,50 @@
|
||||
{
|
||||
"_comment": "邮箱登录方式",
|
||||
"email": "example1@example.com",
|
||||
"password": "your-password-1",
|
||||
"token": ""
|
||||
"password": "your-password-1"
|
||||
},
|
||||
{
|
||||
"_comment": "邮箱登录方式 - 账号2",
|
||||
"email": "example2@example.com",
|
||||
"password": "your-password-2",
|
||||
"token": ""
|
||||
"password": "your-password-2"
|
||||
},
|
||||
{
|
||||
"_comment": "手机号登录方式(中国大陆)",
|
||||
"mobile": "12345678901",
|
||||
"password": "your-password-3",
|
||||
"token": ""
|
||||
"password": "your-password-3"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"model_aliases": {
|
||||
"gpt-4o": "deepseek-chat",
|
||||
"gpt-5-codex": "deepseek-reasoner",
|
||||
"o3": "deepseek-reasoner"
|
||||
},
|
||||
"compat": {
|
||||
"wide_input_strict_output": true
|
||||
},
|
||||
"toolcall": {
|
||||
"mode": "feature_match",
|
||||
"early_emit_confidence": "high"
|
||||
},
|
||||
"responses": {
|
||||
"store_ttl_seconds": 900
|
||||
},
|
||||
"embeddings": {
|
||||
"provider": "deterministic"
|
||||
},
|
||||
"claude_mapping": {
|
||||
"fast": "deepseek-chat",
|
||||
"slow": "deepseek-reasoner"
|
||||
},
|
||||
"admin": {
|
||||
"jwt_expire_hours": 24
|
||||
},
|
||||
"runtime": {
|
||||
"account_max_inflight": 2,
|
||||
"account_max_queue": 0,
|
||||
"global_max_inflight": 0
|
||||
},
|
||||
"auto_delete": {
|
||||
"sessions": false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
services:
|
||||
ds2api:
|
||||
build: .
|
||||
image: ds2api:latest
|
||||
container_name: ds2api
|
||||
ports:
|
||||
- "${PORT:-5001}:${PORT:-5001}"
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- HOST=0.0.0.0
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "-qO-", "http://localhost:${PORT:-5001}/healthz"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
image: ghcr.io/cjackhwang/ds2api:latest
|
||||
container_name: ds2api
|
||||
restart: always
|
||||
ports:
|
||||
- "6011:5001"
|
||||
volumes:
|
||||
- ./config.json:/app/config.json # 配置文件
|
||||
- ./.env:/app/.env # 环境变量
|
||||
environment:
|
||||
- TZ=Asia/Shanghai
|
||||
- LOG_LEVEL=INFO
|
||||
- DS2API_ADMIN_KEY=${DS2API_ADMIN_KEY:-ds2api}
|
||||
|
||||
41
docs/toolcall-semantics.md
Normal file
41
docs/toolcall-semantics.md
Normal file
@@ -0,0 +1,41 @@
|
||||
# Tool call parsing semantics (Go canonical spec)
|
||||
|
||||
This document defines the cross-runtime contract for `ParseToolCallsDetailed` / `parseToolCallsDetailed`.
|
||||
|
||||
## Output contract
|
||||
|
||||
- `calls`: accepted tool calls with normalized tool names.
|
||||
- `sawToolCallSyntax`: true when tool-call-like syntax is detected (`tool_calls`, `<tool_call>`, `<function_call>`, `<invoke>`) or a valid call is parsed.
|
||||
- `rejectedByPolicy`: true when parser extracted call syntax but all calls are rejected by allow-list policy.
|
||||
- `rejectedToolNames`: de-duplicated rejected tool names in first-seen order.
|
||||
|
||||
## Parse pipeline
|
||||
|
||||
1. Strip fenced code blocks for non-standalone parsing.
|
||||
2. Build candidates from:
|
||||
- full text,
|
||||
- fenced JSON snippets,
|
||||
- extracted JSON objects around `tool_calls`,
|
||||
- first `{` to last `}` object slice.
|
||||
3. Parse each candidate in order:
|
||||
- JSON payload parser (`tool_calls`, list, single call object),
|
||||
- XML/Markup parser (`<tool_call>`, `<function_call>`, `<invoke>`; supports attributes + nested fields),
|
||||
- Text KV fallback parser (`function.name: <name>` ... `function.arguments: {json}`).
|
||||
4. Stop at first candidate that yields at least one call.
|
||||
|
||||
## Name normalization policy
|
||||
|
||||
When matching parsed names against configured tools:
|
||||
|
||||
1. exact match,
|
||||
2. case-insensitive match,
|
||||
3. namespace tail match (`a.b.c` => `c`),
|
||||
4. loose alnum match (remove non `[a-z0-9]`, compare).
|
||||
|
||||
## Standalone mode
|
||||
|
||||
Standalone mode (`ParseStandaloneToolCallsDetailed`) parses the whole input directly (no candidate slicing), while still applying:
|
||||
|
||||
- example-context guard,
|
||||
- JSON then markup fallback,
|
||||
- the same allow-list normalization policy.
|
||||
@@ -1,302 +0,0 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
type Pool struct {
|
||||
store *config.Store
|
||||
mu sync.Mutex
|
||||
queue []string
|
||||
inUse map[string]int
|
||||
waiters []chan struct{}
|
||||
maxInflightPerAccount int
|
||||
recommendedConcurrency int
|
||||
maxQueueSize int
|
||||
}
|
||||
|
||||
func NewPool(store *config.Store) *Pool {
|
||||
p := &Pool{
|
||||
store: store,
|
||||
inUse: map[string]int{},
|
||||
maxInflightPerAccount: maxInflightFromEnv(),
|
||||
}
|
||||
p.Reset()
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *Pool) Reset() {
|
||||
accounts := p.store.Accounts()
|
||||
sort.SliceStable(accounts, func(i, j int) bool {
|
||||
iHas := accounts[i].Token != ""
|
||||
jHas := accounts[j].Token != ""
|
||||
if iHas == jHas {
|
||||
return i < j
|
||||
}
|
||||
return iHas
|
||||
})
|
||||
ids := make([]string, 0, len(accounts))
|
||||
for _, a := range accounts {
|
||||
id := a.Identifier()
|
||||
if id != "" {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
recommended := defaultRecommendedConcurrency(len(ids), p.maxInflightPerAccount)
|
||||
queueLimit := maxQueueFromEnv(recommended)
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.drainWaitersLocked()
|
||||
p.queue = ids
|
||||
p.inUse = map[string]int{}
|
||||
p.recommendedConcurrency = recommended
|
||||
p.maxQueueSize = queueLimit
|
||||
config.Logger.Info(
|
||||
"[init_account_queue] initialized",
|
||||
"total", len(ids),
|
||||
"max_inflight_per_account", p.maxInflightPerAccount,
|
||||
"recommended_concurrency", p.recommendedConcurrency,
|
||||
"max_queue_size", p.maxQueueSize,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *Pool) Acquire(target string, exclude map[string]bool) (config.Account, bool) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.acquireLocked(target, normalizeExclude(exclude))
|
||||
}
|
||||
|
||||
func (p *Pool) AcquireWait(ctx context.Context, target string, exclude map[string]bool) (config.Account, bool) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
exclude = normalizeExclude(exclude)
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return config.Account{}, false
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
if acc, ok := p.acquireLocked(target, exclude); ok {
|
||||
p.mu.Unlock()
|
||||
return acc, true
|
||||
}
|
||||
if !p.canQueueLocked(target, exclude) {
|
||||
p.mu.Unlock()
|
||||
return config.Account{}, false
|
||||
}
|
||||
waiter := make(chan struct{})
|
||||
p.waiters = append(p.waiters, waiter)
|
||||
p.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
p.mu.Lock()
|
||||
p.removeWaiterLocked(waiter)
|
||||
p.mu.Unlock()
|
||||
return config.Account{}, false
|
||||
case <-waiter:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Account, bool) {
|
||||
if target != "" {
|
||||
if exclude[target] || p.inUse[target] >= p.maxInflightPerAccount {
|
||||
return config.Account{}, false
|
||||
}
|
||||
acc, ok := p.store.FindAccount(target)
|
||||
if !ok {
|
||||
return config.Account{}, false
|
||||
}
|
||||
p.inUse[target]++
|
||||
p.bumpQueue(target)
|
||||
return acc, true
|
||||
}
|
||||
|
||||
if acc, ok := p.tryAcquire(exclude, true); ok {
|
||||
return acc, true
|
||||
}
|
||||
if acc, ok := p.tryAcquire(exclude, false); ok {
|
||||
return acc, true
|
||||
}
|
||||
return config.Account{}, false
|
||||
}
|
||||
|
||||
func (p *Pool) tryAcquire(exclude map[string]bool, requireToken bool) (config.Account, bool) {
|
||||
for i := 0; i < len(p.queue); i++ {
|
||||
id := p.queue[i]
|
||||
if exclude[id] || p.inUse[id] >= p.maxInflightPerAccount {
|
||||
continue
|
||||
}
|
||||
acc, ok := p.store.FindAccount(id)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if requireToken && acc.Token == "" {
|
||||
continue
|
||||
}
|
||||
p.inUse[id]++
|
||||
p.bumpQueue(id)
|
||||
return acc, true
|
||||
}
|
||||
return config.Account{}, false
|
||||
}
|
||||
|
||||
func (p *Pool) bumpQueue(accountID string) {
|
||||
for i, id := range p.queue {
|
||||
if id != accountID {
|
||||
continue
|
||||
}
|
||||
p.queue = append(p.queue[:i], p.queue[i+1:]...)
|
||||
p.queue = append(p.queue, accountID)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pool) Release(accountID string) {
|
||||
if accountID == "" {
|
||||
return
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
count := p.inUse[accountID]
|
||||
if count <= 0 {
|
||||
return
|
||||
}
|
||||
if count == 1 {
|
||||
delete(p.inUse, accountID)
|
||||
p.notifyWaiterLocked()
|
||||
return
|
||||
}
|
||||
p.inUse[accountID] = count - 1
|
||||
p.notifyWaiterLocked()
|
||||
}
|
||||
|
||||
func (p *Pool) Status() map[string]any {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
available := make([]string, 0, len(p.queue))
|
||||
inUseAccounts := make([]string, 0, len(p.inUse))
|
||||
inUseSlots := 0
|
||||
for _, id := range p.queue {
|
||||
if p.inUse[id] < p.maxInflightPerAccount {
|
||||
available = append(available, id)
|
||||
}
|
||||
}
|
||||
for id, count := range p.inUse {
|
||||
if count > 0 {
|
||||
inUseAccounts = append(inUseAccounts, id)
|
||||
inUseSlots += count
|
||||
}
|
||||
}
|
||||
sort.Strings(inUseAccounts)
|
||||
return map[string]any{
|
||||
"available": len(available),
|
||||
"in_use": inUseSlots,
|
||||
"total": len(p.store.Accounts()),
|
||||
"available_accounts": available,
|
||||
"in_use_accounts": inUseAccounts,
|
||||
"max_inflight_per_account": p.maxInflightPerAccount,
|
||||
"recommended_concurrency": p.recommendedConcurrency,
|
||||
"waiting": len(p.waiters),
|
||||
"max_queue_size": p.maxQueueSize,
|
||||
}
|
||||
}
|
||||
|
||||
func maxInflightFromEnv() int {
|
||||
for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} {
|
||||
raw := strings.TrimSpace(os.Getenv(key))
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return 2
|
||||
}
|
||||
|
||||
func defaultRecommendedConcurrency(accountCount, maxInflightPerAccount int) int {
|
||||
if accountCount <= 0 {
|
||||
return 0
|
||||
}
|
||||
if maxInflightPerAccount <= 0 {
|
||||
maxInflightPerAccount = 2
|
||||
}
|
||||
return accountCount * maxInflightPerAccount
|
||||
}
|
||||
|
||||
func normalizeExclude(exclude map[string]bool) map[string]bool {
|
||||
if exclude == nil {
|
||||
return map[string]bool{}
|
||||
}
|
||||
return exclude
|
||||
}
|
||||
|
||||
func (p *Pool) canQueueLocked(target string, exclude map[string]bool) bool {
|
||||
if target != "" {
|
||||
if exclude[target] {
|
||||
return false
|
||||
}
|
||||
if _, ok := p.store.FindAccount(target); !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if p.maxQueueSize <= 0 {
|
||||
return false
|
||||
}
|
||||
return len(p.waiters) < p.maxQueueSize
|
||||
}
|
||||
|
||||
func (p *Pool) notifyWaiterLocked() {
|
||||
if len(p.waiters) == 0 {
|
||||
return
|
||||
}
|
||||
waiter := p.waiters[0]
|
||||
p.waiters = p.waiters[1:]
|
||||
close(waiter)
|
||||
}
|
||||
|
||||
func (p *Pool) removeWaiterLocked(waiter chan struct{}) bool {
|
||||
for i, w := range p.waiters {
|
||||
if w != waiter {
|
||||
continue
|
||||
}
|
||||
p.waiters = append(p.waiters[:i], p.waiters[i+1:]...)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Pool) drainWaitersLocked() {
|
||||
for _, waiter := range p.waiters {
|
||||
close(waiter)
|
||||
}
|
||||
p.waiters = nil
|
||||
}
|
||||
|
||||
func maxQueueFromEnv(defaultSize int) int {
|
||||
for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} {
|
||||
raw := strings.TrimSpace(os.Getenv(key))
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err == nil && n >= 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
if defaultSize < 0 {
|
||||
return 0
|
||||
}
|
||||
return defaultSize
|
||||
}
|
||||
108
internal/account/pool_acquire.go
Normal file
108
internal/account/pool_acquire.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func (p *Pool) Acquire(target string, exclude map[string]bool) (config.Account, bool) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.acquireLocked(target, normalizeExclude(exclude))
|
||||
}
|
||||
|
||||
func (p *Pool) AcquireWait(ctx context.Context, target string, exclude map[string]bool) (config.Account, bool) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
exclude = normalizeExclude(exclude)
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return config.Account{}, false
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
if acc, ok := p.acquireLocked(target, exclude); ok {
|
||||
p.mu.Unlock()
|
||||
return acc, true
|
||||
}
|
||||
if !p.canQueueLocked(target, exclude) {
|
||||
p.mu.Unlock()
|
||||
return config.Account{}, false
|
||||
}
|
||||
waiter := make(chan struct{})
|
||||
p.waiters = append(p.waiters, waiter)
|
||||
p.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
p.mu.Lock()
|
||||
p.removeWaiterLocked(waiter)
|
||||
p.mu.Unlock()
|
||||
return config.Account{}, false
|
||||
case <-waiter:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Account, bool) {
|
||||
if target != "" {
|
||||
if exclude[target] || !p.canAcquireIDLocked(target) {
|
||||
return config.Account{}, false
|
||||
}
|
||||
acc, ok := p.store.FindAccount(target)
|
||||
if !ok {
|
||||
return config.Account{}, false
|
||||
}
|
||||
p.inUse[target]++
|
||||
p.bumpQueue(target)
|
||||
return acc, true
|
||||
}
|
||||
|
||||
if acc, ok := p.tryAcquire(exclude, true); ok {
|
||||
return acc, true
|
||||
}
|
||||
if acc, ok := p.tryAcquire(exclude, false); ok {
|
||||
return acc, true
|
||||
}
|
||||
return config.Account{}, false
|
||||
}
|
||||
|
||||
func (p *Pool) tryAcquire(exclude map[string]bool, requireToken bool) (config.Account, bool) {
|
||||
for i := 0; i < len(p.queue); i++ {
|
||||
id := p.queue[i]
|
||||
if exclude[id] || !p.canAcquireIDLocked(id) {
|
||||
continue
|
||||
}
|
||||
acc, ok := p.store.FindAccount(id)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if requireToken && acc.Token == "" {
|
||||
continue
|
||||
}
|
||||
p.inUse[id]++
|
||||
p.bumpQueue(id)
|
||||
return acc, true
|
||||
}
|
||||
return config.Account{}, false
|
||||
}
|
||||
|
||||
func (p *Pool) bumpQueue(accountID string) {
|
||||
for i, id := range p.queue {
|
||||
if id != accountID {
|
||||
continue
|
||||
}
|
||||
p.queue = append(p.queue[:i], p.queue[i+1:]...)
|
||||
p.queue = append(p.queue, accountID)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeExclude(exclude map[string]bool) map[string]bool {
|
||||
if exclude == nil {
|
||||
return map[string]bool{}
|
||||
}
|
||||
return exclude
|
||||
}
|
||||
132
internal/account/pool_core.go
Normal file
132
internal/account/pool_core.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
type Pool struct {
|
||||
store *config.Store
|
||||
mu sync.Mutex
|
||||
queue []string
|
||||
inUse map[string]int
|
||||
waiters []chan struct{}
|
||||
maxInflightPerAccount int
|
||||
recommendedConcurrency int
|
||||
maxQueueSize int
|
||||
globalMaxInflight int
|
||||
}
|
||||
|
||||
func NewPool(store *config.Store) *Pool {
|
||||
maxPer := 2
|
||||
if store != nil {
|
||||
maxPer = store.RuntimeAccountMaxInflight()
|
||||
}
|
||||
p := &Pool{
|
||||
store: store,
|
||||
inUse: map[string]int{},
|
||||
maxInflightPerAccount: maxPer,
|
||||
}
|
||||
p.Reset()
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *Pool) Reset() {
|
||||
accounts := p.store.Accounts()
|
||||
sort.SliceStable(accounts, func(i, j int) bool {
|
||||
iHas := accounts[i].Token != ""
|
||||
jHas := accounts[j].Token != ""
|
||||
if iHas == jHas {
|
||||
return i < j
|
||||
}
|
||||
return iHas
|
||||
})
|
||||
ids := make([]string, 0, len(accounts))
|
||||
for _, a := range accounts {
|
||||
id := a.Identifier()
|
||||
if id != "" {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
if p.store != nil {
|
||||
p.maxInflightPerAccount = p.store.RuntimeAccountMaxInflight()
|
||||
} else {
|
||||
p.maxInflightPerAccount = maxInflightFromEnv()
|
||||
}
|
||||
recommended := defaultRecommendedConcurrency(len(ids), p.maxInflightPerAccount)
|
||||
queueLimit := maxQueueFromEnv(recommended)
|
||||
globalLimit := recommended
|
||||
if p.store != nil {
|
||||
queueLimit = p.store.RuntimeAccountMaxQueue(recommended)
|
||||
globalLimit = p.store.RuntimeGlobalMaxInflight(recommended)
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.drainWaitersLocked()
|
||||
p.queue = ids
|
||||
p.inUse = map[string]int{}
|
||||
p.recommendedConcurrency = recommended
|
||||
p.maxQueueSize = queueLimit
|
||||
p.globalMaxInflight = globalLimit
|
||||
config.Logger.Info(
|
||||
"[init_account_queue] initialized",
|
||||
"total", len(ids),
|
||||
"max_inflight_per_account", p.maxInflightPerAccount,
|
||||
"global_max_inflight", p.globalMaxInflight,
|
||||
"recommended_concurrency", p.recommendedConcurrency,
|
||||
"max_queue_size", p.maxQueueSize,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *Pool) Release(accountID string) {
|
||||
if accountID == "" {
|
||||
return
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
count := p.inUse[accountID]
|
||||
if count <= 0 {
|
||||
return
|
||||
}
|
||||
if count == 1 {
|
||||
delete(p.inUse, accountID)
|
||||
p.notifyWaiterLocked()
|
||||
return
|
||||
}
|
||||
p.inUse[accountID] = count - 1
|
||||
p.notifyWaiterLocked()
|
||||
}
|
||||
|
||||
func (p *Pool) Status() map[string]any {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
available := make([]string, 0, len(p.queue))
|
||||
inUseAccounts := make([]string, 0, len(p.inUse))
|
||||
inUseSlots := 0
|
||||
for _, id := range p.queue {
|
||||
if p.inUse[id] < p.maxInflightPerAccount {
|
||||
available = append(available, id)
|
||||
}
|
||||
}
|
||||
for id, count := range p.inUse {
|
||||
if count > 0 {
|
||||
inUseAccounts = append(inUseAccounts, id)
|
||||
inUseSlots += count
|
||||
}
|
||||
}
|
||||
sort.Strings(inUseAccounts)
|
||||
return map[string]any{
|
||||
"available": len(available),
|
||||
"in_use": inUseSlots,
|
||||
"total": len(p.store.Accounts()),
|
||||
"available_accounts": available,
|
||||
"in_use_accounts": inUseAccounts,
|
||||
"max_inflight_per_account": p.maxInflightPerAccount,
|
||||
"global_max_inflight": p.globalMaxInflight,
|
||||
"recommended_concurrency": p.recommendedConcurrency,
|
||||
"waiting": len(p.waiters),
|
||||
"max_queue_size": p.maxQueueSize,
|
||||
}
|
||||
}
|
||||
249
internal/account/pool_edge_test.go
Normal file
249
internal/account/pool_edge_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
// ─── Pool edge cases ─────────────────────────────────────────────────
|
||||
|
||||
func TestPoolEmptyNoAccounts(t *testing.T) {
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "2")
|
||||
t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "")
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "")
|
||||
t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "")
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`)
|
||||
pool := NewPool(config.LoadStore())
|
||||
if _, ok := pool.Acquire("", nil); ok {
|
||||
t.Fatal("expected acquire to fail with no accounts")
|
||||
}
|
||||
status := pool.Status()
|
||||
if total, ok := status["total"].(int); !ok || total != 0 {
|
||||
t.Fatalf("unexpected total: %#v", status["total"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolReleaseNonExistentAccount(t *testing.T) {
|
||||
pool := newPoolForTest(t, "2")
|
||||
pool.Release("nonexistent@example.com") // should not panic
|
||||
}
|
||||
|
||||
func TestPoolReleaseAlreadyReleased(t *testing.T) {
|
||||
pool := newPoolForTest(t, "2")
|
||||
acc, ok := pool.Acquire("", nil)
|
||||
if !ok {
|
||||
t.Fatal("expected acquire success")
|
||||
}
|
||||
pool.Release(acc.Identifier())
|
||||
pool.Release(acc.Identifier()) // double release should not panic
|
||||
}
|
||||
|
||||
func TestPoolAcquireTargetNotFound(t *testing.T) {
|
||||
pool := newPoolForTest(t, "2")
|
||||
if _, ok := pool.Acquire("nonexistent@example.com", nil); ok {
|
||||
t.Fatal("expected acquire to fail for non-existent target")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolAcquireWithExclusionList(t *testing.T) {
|
||||
pool := newPoolForTest(t, "2")
|
||||
acc, ok := pool.Acquire("", map[string]bool{"acc1@example.com": true})
|
||||
if !ok {
|
||||
t.Fatal("expected acquire success with exclusion")
|
||||
}
|
||||
if acc.Identifier() != "acc2@example.com" {
|
||||
t.Fatalf("expected acc2 when acc1 excluded, got %q", acc.Identifier())
|
||||
}
|
||||
pool.Release(acc.Identifier())
|
||||
}
|
||||
|
||||
func TestPoolAcquireAllExcluded(t *testing.T) {
|
||||
pool := newPoolForTest(t, "2")
|
||||
if _, ok := pool.Acquire("", map[string]bool{
|
||||
"acc1@example.com": true,
|
||||
"acc2@example.com": true,
|
||||
}); ok {
|
||||
t.Fatal("expected acquire to fail when all accounts excluded")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolStatusFields(t *testing.T) {
|
||||
pool := newPoolForTest(t, "2")
|
||||
status := pool.Status()
|
||||
|
||||
// Check all expected fields are present
|
||||
for _, key := range []string{"total", "available", "max_inflight_per_account", "recommended_concurrency", "available_accounts", "in_use_accounts", "waiting", "max_queue_size"} {
|
||||
if _, ok := status[key]; !ok {
|
||||
t.Fatalf("missing status field: %s", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolStatusAccountDetails(t *testing.T) {
|
||||
pool := newPoolForTest(t, "2")
|
||||
acc, _ := pool.Acquire("acc1@example.com", nil)
|
||||
|
||||
status := pool.Status()
|
||||
inUseAccounts, ok := status["in_use_accounts"].([]string)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected in_use_accounts type: %T", status["in_use_accounts"])
|
||||
}
|
||||
found := false
|
||||
for _, id := range inUseAccounts {
|
||||
if id == "acc1@example.com" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("expected acc1 in in_use_accounts, got %v", inUseAccounts)
|
||||
}
|
||||
if status["in_use"] != 1 {
|
||||
t.Fatalf("expected 1 in_use, got %v", status["in_use"])
|
||||
}
|
||||
|
||||
pool.Release(acc.Identifier())
|
||||
}
|
||||
|
||||
func TestPoolAcquireWaitContextCancelled(t *testing.T) {
|
||||
pool := newSingleAccountPoolForTest(t, "1")
|
||||
// Exhaust the pool
|
||||
first, ok := pool.Acquire("", nil)
|
||||
if !ok {
|
||||
t.Fatal("expected first acquire to succeed")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
var waitOK bool
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, waitOK = pool.AcquireWait(ctx, "", nil)
|
||||
}()
|
||||
|
||||
// Wait until queued
|
||||
waitForWaitingCount(t, pool, 1)
|
||||
|
||||
// Cancel context
|
||||
cancel()
|
||||
|
||||
wg.Wait()
|
||||
if waitOK {
|
||||
t.Fatal("expected acquire to fail after context cancellation")
|
||||
}
|
||||
|
||||
pool.Release(first.Identifier())
|
||||
}
|
||||
|
||||
func TestPoolAcquireWaitTargetAccount(t *testing.T) {
|
||||
pool := newPoolForTest(t, "1")
|
||||
// Exhaust acc1
|
||||
acc1, ok := pool.Acquire("acc1@example.com", nil)
|
||||
if !ok {
|
||||
t.Fatal("expected acquire acc1 success")
|
||||
}
|
||||
|
||||
// Acquire acc2 directly (should succeed since acc2 is free)
|
||||
ctx := context.Background()
|
||||
acc2, ok := pool.AcquireWait(ctx, "acc2@example.com", nil)
|
||||
if !ok {
|
||||
t.Fatal("expected acquire acc2 success via AcquireWait")
|
||||
}
|
||||
if acc2.Identifier() != "acc2@example.com" {
|
||||
t.Fatalf("expected acc2, got %q", acc2.Identifier())
|
||||
}
|
||||
|
||||
pool.Release(acc1.Identifier())
|
||||
pool.Release(acc2.Identifier())
|
||||
}
|
||||
|
||||
func TestPoolMaxQueueSizeOverride(t *testing.T) {
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1")
|
||||
t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "")
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "5")
|
||||
t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "")
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"acc1@example.com","token":"t1"}]}`)
|
||||
pool := NewPool(config.LoadStore())
|
||||
status := pool.Status()
|
||||
if got, ok := status["max_queue_size"].(int); !ok || got != 5 {
|
||||
t.Fatalf("expected max_queue_size=5, got %#v", status["max_queue_size"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolQueueSizeAliasEnv(t *testing.T) {
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1")
|
||||
t.Setenv("DS2API_ACCOUNT_CONCURRENCY", "")
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "")
|
||||
t.Setenv("DS2API_ACCOUNT_QUEUE_SIZE", "7")
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"acc1@example.com","token":"t1"}]}`)
|
||||
pool := NewPool(config.LoadStore())
|
||||
status := pool.Status()
|
||||
if got, ok := status["max_queue_size"].(int); !ok || got != 7 {
|
||||
t.Fatalf("expected max_queue_size=7, got %#v", status["max_queue_size"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolMultipleAcquireReleaseCycles(t *testing.T) {
|
||||
pool := newSingleAccountPoolForTest(t, "1")
|
||||
for i := 0; i < 10; i++ {
|
||||
acc, ok := pool.Acquire("", nil)
|
||||
if !ok {
|
||||
t.Fatalf("acquire failed at cycle %d", i)
|
||||
}
|
||||
pool.Release(acc.Identifier())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolConcurrentAcquireWait(t *testing.T) {
|
||||
pool := newSingleAccountPoolForTest(t, "1")
|
||||
first, ok := pool.Acquire("", nil)
|
||||
if !ok {
|
||||
t.Fatal("expected first acquire success")
|
||||
}
|
||||
|
||||
const waiters = 3
|
||||
results := make(chan bool, waiters)
|
||||
|
||||
for i := 0; i < waiters; i++ {
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_, ok := pool.AcquireWait(ctx, "", nil)
|
||||
results <- ok
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all to be queued (only 1 can queue)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Release and allow queued requests to proceed
|
||||
pool.Release(first.Identifier())
|
||||
|
||||
successCount := 0
|
||||
timeoutCount := 0
|
||||
for i := 0; i < waiters; i++ {
|
||||
select {
|
||||
case ok := <-results:
|
||||
if ok {
|
||||
successCount++
|
||||
// Release for next waiter
|
||||
pool.Release("acc1@example.com")
|
||||
} else {
|
||||
timeoutCount++
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timed out waiting for results")
|
||||
}
|
||||
}
|
||||
|
||||
// At least 1 should succeed; 2 may fail due to queue limit
|
||||
if successCount < 1 {
|
||||
t.Fatalf("expected at least 1 success, got success=%d timeout=%d", successCount, timeoutCount)
|
||||
}
|
||||
}
|
||||
91
internal/account/pool_limits.go
Normal file
91
internal/account/pool_limits.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (p *Pool) ApplyRuntimeLimits(maxInflightPerAccount, maxQueueSize, globalMaxInflight int) {
|
||||
if maxInflightPerAccount <= 0 {
|
||||
maxInflightPerAccount = 1
|
||||
}
|
||||
if maxQueueSize < 0 {
|
||||
maxQueueSize = 0
|
||||
}
|
||||
if globalMaxInflight <= 0 {
|
||||
globalMaxInflight = maxInflightPerAccount * len(p.store.Accounts())
|
||||
if globalMaxInflight <= 0 {
|
||||
globalMaxInflight = maxInflightPerAccount
|
||||
}
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.maxInflightPerAccount = maxInflightPerAccount
|
||||
p.maxQueueSize = maxQueueSize
|
||||
p.globalMaxInflight = globalMaxInflight
|
||||
p.recommendedConcurrency = defaultRecommendedConcurrency(len(p.queue), p.maxInflightPerAccount)
|
||||
p.notifyWaiterLocked()
|
||||
}
|
||||
|
||||
func maxInflightFromEnv() int {
|
||||
for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} {
|
||||
raw := strings.TrimSpace(os.Getenv(key))
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return 2
|
||||
}
|
||||
|
||||
func defaultRecommendedConcurrency(accountCount, maxInflightPerAccount int) int {
|
||||
if accountCount <= 0 {
|
||||
return 0
|
||||
}
|
||||
if maxInflightPerAccount <= 0 {
|
||||
maxInflightPerAccount = 2
|
||||
}
|
||||
return accountCount * maxInflightPerAccount
|
||||
}
|
||||
|
||||
func maxQueueFromEnv(defaultSize int) int {
|
||||
for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} {
|
||||
raw := strings.TrimSpace(os.Getenv(key))
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err == nil && n >= 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
if defaultSize < 0 {
|
||||
return 0
|
||||
}
|
||||
return defaultSize
|
||||
}
|
||||
|
||||
func (p *Pool) canAcquireIDLocked(accountID string) bool {
|
||||
if accountID == "" {
|
||||
return false
|
||||
}
|
||||
if p.inUse[accountID] >= p.maxInflightPerAccount {
|
||||
return false
|
||||
}
|
||||
if p.globalMaxInflight > 0 && p.currentInUseLocked() >= p.globalMaxInflight {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *Pool) currentInUseLocked() int {
|
||||
total := 0
|
||||
for _, n := range p.inUse {
|
||||
total += n
|
||||
}
|
||||
return total
|
||||
}
|
||||
@@ -194,7 +194,7 @@ func TestPoolAccountConcurrencyAliasEnv(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolSupportsTokenOnlyAccount(t *testing.T) {
|
||||
func TestPoolDropsLegacyTokenOnlyAccountOnLoad(t *testing.T) {
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1")
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["k1"],
|
||||
@@ -203,19 +203,15 @@ func TestPoolSupportsTokenOnlyAccount(t *testing.T) {
|
||||
|
||||
pool := NewPool(config.LoadStore())
|
||||
status := pool.Status()
|
||||
if got, ok := status["total"].(int); !ok || got != 1 {
|
||||
if got, ok := status["total"].(int); !ok || got != 0 {
|
||||
t.Fatalf("unexpected total in pool status: %#v", status["total"])
|
||||
}
|
||||
if got, ok := status["available"].(int); !ok || got != 1 {
|
||||
if got, ok := status["available"].(int); !ok || got != 0 {
|
||||
t.Fatalf("unexpected available in pool status: %#v", status["available"])
|
||||
}
|
||||
|
||||
acc, ok := pool.Acquire("", nil)
|
||||
if !ok {
|
||||
t.Fatalf("expected acquire success for token-only account")
|
||||
}
|
||||
if acc.Token != "token-only-account" {
|
||||
t.Fatalf("unexpected token on acquired account: %q", acc.Token)
|
||||
if _, ok := pool.Acquire("", nil); ok {
|
||||
t.Fatalf("expected acquire to fail for token-only account")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
43
internal/account/pool_waiters.go
Normal file
43
internal/account/pool_waiters.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package account
|
||||
|
||||
func (p *Pool) canQueueLocked(target string, exclude map[string]bool) bool {
|
||||
if target != "" {
|
||||
if exclude[target] {
|
||||
return false
|
||||
}
|
||||
if _, ok := p.store.FindAccount(target); !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if p.maxQueueSize <= 0 {
|
||||
return false
|
||||
}
|
||||
return len(p.waiters) < p.maxQueueSize
|
||||
}
|
||||
|
||||
func (p *Pool) notifyWaiterLocked() {
|
||||
if len(p.waiters) == 0 {
|
||||
return
|
||||
}
|
||||
waiter := p.waiters[0]
|
||||
p.waiters = p.waiters[1:]
|
||||
close(waiter)
|
||||
}
|
||||
|
||||
func (p *Pool) removeWaiterLocked(waiter chan struct{}) bool {
|
||||
for i, w := range p.waiters {
|
||||
if w != waiter {
|
||||
continue
|
||||
}
|
||||
p.waiters = append(p.waiters[:i], p.waiters[i+1:]...)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Pool) drainWaitersLocked() {
|
||||
for _, waiter := range p.waiters {
|
||||
close(waiter)
|
||||
}
|
||||
p.waiters = nil
|
||||
}
|
||||
11
internal/adapter/claude/convert.go
Normal file
11
internal/adapter/claude/convert.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"ds2api/internal/claudeconv"
|
||||
)
|
||||
|
||||
const defaultClaudeModel = "claude-sonnet-4-5"
|
||||
|
||||
func convertClaudeToDeepSeek(claudeReq map[string]any, store ConfigReader) map[string]any {
|
||||
return claudeconv.ConvertClaudeToDeepSeek(claudeReq, store, defaultClaudeModel)
|
||||
}
|
||||
29
internal/adapter/claude/deps.go
Normal file
29
internal/adapter/claude/deps.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
)
|
||||
|
||||
type AuthResolver interface {
|
||||
Determine(req *http.Request) (*auth.RequestAuth, error)
|
||||
Release(a *auth.RequestAuth)
|
||||
}
|
||||
|
||||
type DeepSeekCaller interface {
|
||||
CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
||||
GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
||||
CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
|
||||
}
|
||||
|
||||
type ConfigReader interface {
|
||||
ClaudeMapping() map[string]string
|
||||
}
|
||||
|
||||
var _ AuthResolver = (*auth.Resolver)(nil)
|
||||
var _ DeepSeekCaller = (*deepseek.Client)(nil)
|
||||
var _ ConfigReader = (*config.Store)(nil)
|
||||
33
internal/adapter/claude/deps_injection_test.go
Normal file
33
internal/adapter/claude/deps_injection_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package claude
|
||||
|
||||
import "testing"
|
||||
|
||||
type mockClaudeConfig struct {
|
||||
m map[string]string
|
||||
}
|
||||
|
||||
func (m mockClaudeConfig) ClaudeMapping() map[string]string { return m.m }
|
||||
|
||||
func TestNormalizeClaudeRequestUsesConfigInterfaceMapping(t *testing.T) {
|
||||
req := map[string]any{
|
||||
"model": "claude-opus-4-6",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
}
|
||||
out, err := normalizeClaudeRequest(mockClaudeConfig{
|
||||
m: map[string]string{
|
||||
"fast": "deepseek-chat",
|
||||
"slow": "deepseek-reasoner-search",
|
||||
},
|
||||
}, req)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeClaudeRequest error: %v", err)
|
||||
}
|
||||
if out.Standard.ResolvedModel != "deepseek-reasoner-search" {
|
||||
t.Fatalf("resolved model mismatch: got=%q", out.Standard.ResolvedModel)
|
||||
}
|
||||
if !out.Standard.Thinking || !out.Standard.Search {
|
||||
t.Fatalf("unexpected flags: thinking=%v search=%v", out.Standard.Thinking, out.Standard.Search)
|
||||
}
|
||||
}
|
||||
34
internal/adapter/claude/error_shape_test.go
Normal file
34
internal/adapter/claude/error_shape_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWriteClaudeErrorIncludesUnifiedFields(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
writeClaudeError(rec, http.StatusUnauthorized, "bad token")
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", rec.Code)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("decode body: %v", err)
|
||||
}
|
||||
errObj, _ := body["error"].(map[string]any)
|
||||
if errObj["message"] != "bad token" {
|
||||
t.Fatalf("unexpected message: %v", errObj["message"])
|
||||
}
|
||||
if errObj["type"] != "invalid_request_error" {
|
||||
t.Fatalf("unexpected type: %v", errObj["type"])
|
||||
}
|
||||
if errObj["code"] != "authentication_failed" {
|
||||
t.Fatalf("unexpected code: %v", errObj["code"])
|
||||
}
|
||||
if _, ok := errObj["param"]; !ok {
|
||||
t.Fatal("expected param field")
|
||||
}
|
||||
}
|
||||
@@ -1,603 +0,0 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
"ds2api/internal/sse"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// writeJSON is a package-internal alias to avoid mass-renaming all call-sites.
|
||||
var writeJSON = util.WriteJSON
|
||||
|
||||
type Handler struct {
|
||||
Store *config.Store
|
||||
Auth *auth.Resolver
|
||||
DS *deepseek.Client
|
||||
}
|
||||
|
||||
var (
|
||||
claudeStreamPingInterval = time.Duration(deepseek.KeepAliveTimeout) * time.Second
|
||||
claudeStreamIdleTimeout = time.Duration(deepseek.StreamIdleTimeout) * time.Second
|
||||
claudeStreamMaxKeepaliveCnt = deepseek.MaxKeepaliveCount
|
||||
)
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
r.Get("/anthropic/v1/models", h.ListModels)
|
||||
r.Post("/anthropic/v1/messages", h.Messages)
|
||||
r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens)
|
||||
}
|
||||
|
||||
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
|
||||
writeJSON(w, http.StatusOK, config.ClaudeModelsResponse())
|
||||
}
|
||||
|
||||
func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
status := http.StatusUnauthorized
|
||||
detail := err.Error()
|
||||
if err == auth.ErrNoAccount {
|
||||
status = http.StatusTooManyRequests
|
||||
}
|
||||
writeJSON(w, status, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": detail}})
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "invalid json"}})
|
||||
return
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
messagesRaw, _ := req["messages"].([]any)
|
||||
if model == "" || len(messagesRaw) == 0 {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "Request must include 'model' and 'messages'."}})
|
||||
return
|
||||
}
|
||||
|
||||
normalized := normalizeClaudeMessages(messagesRaw)
|
||||
payload := cloneMap(req)
|
||||
payload["messages"] = normalized
|
||||
toolsRequested, _ := req["tools"].([]any)
|
||||
if len(toolsRequested) > 0 && !hasSystemMessage(normalized) {
|
||||
payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalized...)
|
||||
}
|
||||
|
||||
dsPayload := util.ConvertClaudeToDeepSeek(payload, h.Store)
|
||||
dsModel, _ := dsPayload["model"].(string)
|
||||
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel)
|
||||
if !ok {
|
||||
thinkingEnabled = false
|
||||
searchEnabled = false
|
||||
}
|
||||
finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "invalid token."}})
|
||||
return
|
||||
}
|
||||
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get PoW"}})
|
||||
return
|
||||
}
|
||||
requestPayload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
"parent_message_id": nil,
|
||||
"prompt": finalPrompt,
|
||||
"ref_file_ids": []any{},
|
||||
"thinking_enabled": thinkingEnabled,
|
||||
"search_enabled": searchEnabled,
|
||||
}
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get Claude response."}})
|
||||
return
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}})
|
||||
return
|
||||
}
|
||||
|
||||
toolNames := extractClaudeToolNames(toolsRequested)
|
||||
if util.ToBool(req["stream"]) {
|
||||
h.handleClaudeStreamRealtime(w, r, resp, model, normalized, thinkingEnabled, searchEnabled, toolNames)
|
||||
return
|
||||
}
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
fullText := result.Text
|
||||
fullThinking := result.Thinking
|
||||
detected := util.ParseToolCalls(fullText, toolNames)
|
||||
content := make([]map[string]any, 0, 4)
|
||||
if fullThinking != "" {
|
||||
content = append(content, map[string]any{"type": "thinking", "thinking": fullThinking})
|
||||
}
|
||||
stopReason := "end_turn"
|
||||
if len(detected) > 0 {
|
||||
stopReason = "tool_use"
|
||||
for i, tc := range detected {
|
||||
content = append(content, map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i),
|
||||
"name": tc.Name,
|
||||
"input": tc.Input,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
if fullText == "" {
|
||||
fullText = "抱歉,没有生成有效的响应内容。"
|
||||
}
|
||||
content = append(content, map[string]any{"type": "text", "text": fullText})
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"id": fmt.Sprintf("msg_%d", time.Now().UnixNano()),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": content,
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]any{
|
||||
"input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalized)),
|
||||
"output_tokens": util.EstimateTokens(fullThinking) + util.EstimateTokens(fullText),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "invalid json"})
|
||||
return
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
messages, _ := req["messages"].([]any)
|
||||
if model == "" || len(messages) == 0 {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "Request must include 'model' and 'messages'."})
|
||||
return
|
||||
}
|
||||
inputTokens := 0
|
||||
if sys, ok := req["system"].(string); ok {
|
||||
inputTokens += util.EstimateTokens(sys)
|
||||
}
|
||||
for _, item := range messages {
|
||||
msg, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
inputTokens += 2
|
||||
inputTokens += util.EstimateTokens(extractMessageContent(msg["content"]))
|
||||
}
|
||||
if tools, ok := req["tools"].([]any); ok {
|
||||
for _, t := range tools {
|
||||
b, _ := json.Marshal(t)
|
||||
inputTokens += util.EstimateTokens(string(b))
|
||||
}
|
||||
}
|
||||
if inputTokens < 1 {
|
||||
inputTokens = 1
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens})
|
||||
}
|
||||
|
||||
func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}})
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
rc := http.NewResponseController(w)
|
||||
canFlush := rc.Flush() == nil
|
||||
if !canFlush {
|
||||
config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered")
|
||||
}
|
||||
send := func(event string, v any) {
|
||||
b, _ := json.Marshal(v)
|
||||
_, _ = w.Write([]byte("event: "))
|
||||
_, _ = w.Write([]byte(event))
|
||||
_, _ = w.Write([]byte("\n"))
|
||||
_, _ = w.Write([]byte("data: "))
|
||||
_, _ = w.Write(b)
|
||||
_, _ = w.Write([]byte("\n\n"))
|
||||
if canFlush {
|
||||
_ = rc.Flush()
|
||||
}
|
||||
}
|
||||
sendError := func(message string) {
|
||||
msg := strings.TrimSpace(message)
|
||||
if msg == "" {
|
||||
msg = "upstream stream error"
|
||||
}
|
||||
send("error", map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"type": "api_error",
|
||||
"message": msg,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
messageID := fmt.Sprintf("msg_%d", time.Now().UnixNano())
|
||||
inputTokens := util.EstimateTokens(fmt.Sprintf("%v", messages))
|
||||
send("message_start", map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": messageID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []any{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0},
|
||||
},
|
||||
})
|
||||
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
}
|
||||
parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType)
|
||||
bufferToolContent := len(toolNames) > 0
|
||||
hasContent := false
|
||||
lastContent := time.Now()
|
||||
keepaliveCount := 0
|
||||
|
||||
thinking := strings.Builder{}
|
||||
text := strings.Builder{}
|
||||
|
||||
nextBlockIndex := 0
|
||||
thinkingBlockOpen := false
|
||||
thinkingBlockIndex := -1
|
||||
textBlockOpen := false
|
||||
textBlockIndex := -1
|
||||
ended := false
|
||||
|
||||
closeThinkingBlock := func() {
|
||||
if !thinkingBlockOpen {
|
||||
return
|
||||
}
|
||||
send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": thinkingBlockIndex,
|
||||
})
|
||||
thinkingBlockOpen = false
|
||||
thinkingBlockIndex = -1
|
||||
}
|
||||
closeTextBlock := func() {
|
||||
if !textBlockOpen {
|
||||
return
|
||||
}
|
||||
send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": textBlockIndex,
|
||||
})
|
||||
textBlockOpen = false
|
||||
textBlockIndex = -1
|
||||
}
|
||||
|
||||
finalize := func(stopReason string) {
|
||||
if ended {
|
||||
return
|
||||
}
|
||||
ended = true
|
||||
|
||||
closeThinkingBlock()
|
||||
closeTextBlock()
|
||||
|
||||
finalThinking := thinking.String()
|
||||
finalText := text.String()
|
||||
|
||||
if bufferToolContent {
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
if len(detected) > 0 {
|
||||
stopReason = "tool_use"
|
||||
for i, tc := range detected {
|
||||
idx := nextBlockIndex + i
|
||||
send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": idx,
|
||||
"content_block": map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx),
|
||||
"name": tc.Name,
|
||||
"input": tc.Input,
|
||||
},
|
||||
})
|
||||
send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": idx,
|
||||
})
|
||||
}
|
||||
nextBlockIndex += len(detected)
|
||||
} else if finalText != "" {
|
||||
idx := nextBlockIndex
|
||||
nextBlockIndex++
|
||||
send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": idx,
|
||||
"content_block": map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
})
|
||||
send("content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": idx,
|
||||
"delta": map[string]any{
|
||||
"type": "text_delta",
|
||||
"text": finalText,
|
||||
},
|
||||
})
|
||||
send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": idx,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
|
||||
send("message_delta", map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]any{
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
})
|
||||
send("message_stop", map[string]any{"type": "message_stop"})
|
||||
}
|
||||
|
||||
pingTicker := time.NewTicker(claudeStreamPingInterval)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case <-pingTicker.C:
|
||||
if !hasContent {
|
||||
keepaliveCount++
|
||||
if keepaliveCount >= claudeStreamMaxKeepaliveCnt {
|
||||
finalize("end_turn")
|
||||
return
|
||||
}
|
||||
}
|
||||
if hasContent && time.Since(lastContent) > claudeStreamIdleTimeout {
|
||||
finalize("end_turn")
|
||||
return
|
||||
}
|
||||
send("ping", map[string]any{"type": "ping"})
|
||||
case parsed, ok := <-parsedLines:
|
||||
if !ok {
|
||||
if err := <-done; err != nil {
|
||||
sendError(err.Error())
|
||||
return
|
||||
}
|
||||
finalize("end_turn")
|
||||
return
|
||||
}
|
||||
if !parsed.Parsed {
|
||||
continue
|
||||
}
|
||||
if parsed.ErrorMessage != "" {
|
||||
sendError(parsed.ErrorMessage)
|
||||
return
|
||||
}
|
||||
if parsed.Stop {
|
||||
finalize("end_turn")
|
||||
return
|
||||
}
|
||||
|
||||
for _, p := range parsed.Parts {
|
||||
if p.Text == "" {
|
||||
continue
|
||||
}
|
||||
if p.Type != "thinking" && searchEnabled && sse.IsCitation(p.Text) {
|
||||
continue
|
||||
}
|
||||
|
||||
hasContent = true
|
||||
lastContent = time.Now()
|
||||
keepaliveCount = 0
|
||||
|
||||
if p.Type == "thinking" {
|
||||
if !thinkingEnabled {
|
||||
continue
|
||||
}
|
||||
thinking.WriteString(p.Text)
|
||||
closeTextBlock()
|
||||
if !thinkingBlockOpen {
|
||||
thinkingBlockIndex = nextBlockIndex
|
||||
nextBlockIndex++
|
||||
send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": thinkingBlockIndex,
|
||||
"content_block": map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "",
|
||||
},
|
||||
})
|
||||
thinkingBlockOpen = true
|
||||
}
|
||||
send("content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": thinkingBlockIndex,
|
||||
"delta": map[string]any{
|
||||
"type": "thinking_delta",
|
||||
"thinking": p.Text,
|
||||
},
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
text.WriteString(p.Text)
|
||||
if bufferToolContent {
|
||||
continue
|
||||
}
|
||||
closeThinkingBlock()
|
||||
if !textBlockOpen {
|
||||
textBlockIndex = nextBlockIndex
|
||||
nextBlockIndex++
|
||||
send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": textBlockIndex,
|
||||
"content_block": map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
})
|
||||
textBlockOpen = true
|
||||
}
|
||||
send("content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": textBlockIndex,
|
||||
"delta": map[string]any{
|
||||
"type": "text_delta",
|
||||
"text": p.Text,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeClaudeMessages flattens Claude-style messages into plain-string
// content. Array-form content is collapsed to newline-joined text: "text"
// blocks contribute their text, "tool_result" blocks contribute their
// content rendered via %v; every other block type is dropped. String
// content and non-map entries in the list are left as-is / skipped.
func normalizeClaudeMessages(messages []any) []any {
	normalized := make([]any, 0, len(messages))
	for _, raw := range messages {
		original, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		// Shallow-copy so the caller's message map is never mutated.
		flattened := make(map[string]any, len(original))
		for key, value := range original {
			flattened[key] = value
		}
		if blocks, isList := original["content"].([]any); isList {
			segments := make([]string, 0, len(blocks))
			for _, item := range blocks {
				blockMap, okBlock := item.(map[string]any)
				if !okBlock {
					continue
				}
				kind, _ := blockMap["type"].(string)
				switch kind {
				case "text":
					if body, okText := blockMap["text"].(string); okText {
						segments = append(segments, body)
					}
				case "tool_result":
					segments = append(segments, fmt.Sprintf("%v", blockMap["content"]))
				}
			}
			flattened["content"] = strings.Join(segments, "\n")
		}
		normalized = append(normalized, flattened)
	}
	return normalized
}
|
||||
|
||||
// buildClaudeToolPrompt renders a system prompt describing the available
// tools. Each map-shaped tool contributes a "Tool/Description/Parameters"
// section (the input_schema serialized as JSON); non-map entries are
// skipped. The prompt ends with the tool-call output-format instruction.
func buildClaudeToolPrompt(tools []any) string {
	sections := make([]string, 0, len(tools)+2)
	sections = append(sections, "You are Claude, a helpful AI assistant. You have access to these tools:")
	for _, candidate := range tools {
		spec, isMap := candidate.(map[string]any)
		if !isMap {
			continue
		}
		toolName, _ := spec["name"].(string)
		toolDesc, _ := spec["description"].(string)
		// Marshal error deliberately ignored; an unmarshalable schema just
		// yields an empty Parameters field.
		schemaJSON, _ := json.Marshal(spec["input_schema"])
		sections = append(sections, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", toolName, toolDesc, schemaJSON))
	}
	sections = append(sections, "When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}")
	return strings.Join(sections, "\n\n")
}
|
||||
|
||||
// hasSystemMessage reports whether any map-shaped entry carries
// role == "system". Non-map entries are ignored.
func hasSystemMessage(messages []any) bool {
	for _, raw := range messages {
		if entry, isMap := raw.(map[string]any); isMap && entry["role"] == "system" {
			return true
		}
	}
	return false
}
|
||||
|
||||
// extractClaudeToolNames collects the non-empty "name" field of every
// map-shaped tool definition, preserving order. Non-map entries and tools
// without a string name are skipped.
func extractClaudeToolNames(tools []any) []string {
	names := make([]string, 0, len(tools))
	for _, candidate := range tools {
		spec, isMap := candidate.(map[string]any)
		if !isMap {
			continue
		}
		name, isString := spec["name"].(string)
		if !isString || name == "" {
			continue
		}
		names = append(names, name)
	}
	return names
}
|
||||
|
||||
// toMessageMaps narrows an arbitrary value to a slice of message maps.
// It returns nil when v is not a []any; otherwise it returns a (possibly
// empty, non-nil) slice containing only the map-shaped elements, in order.
func toMessageMaps(v any) []map[string]any {
	items, isSlice := v.([]any)
	if !isSlice {
		return nil
	}
	converted := make([]map[string]any, 0, len(items))
	for _, entry := range items {
		if asMap, isMap := entry.(map[string]any); isMap {
			converted = append(converted, asMap)
		}
	}
	return converted
}
|
||||
|
||||
// extractMessageContent renders arbitrary message content as text.
// Strings pass through unchanged; slices are rendered element-by-element
// with %v and newline-joined; anything else (including nil, which renders
// as "<nil>") is formatted with %v.
func extractMessageContent(v any) string {
	if text, isString := v.(string); isString {
		return text
	}
	if items, isSlice := v.([]any); isSlice {
		rendered := make([]string, 0, len(items))
		for _, entry := range items {
			rendered = append(rendered, fmt.Sprintf("%v", entry))
		}
		return strings.Join(rendered, "\n")
	}
	return fmt.Sprintf("%v", v)
}
|
||||
|
||||
// cloneMap returns a shallow copy of in. The result is always non-nil,
// even for a nil input; nested values are shared, not copied.
func cloneMap(in map[string]any) map[string]any {
	copied := make(map[string]any, len(in))
	for key, value := range in {
		copied[key] = value
	}
	return copied
}
|
||||
25
internal/adapter/claude/handler_errors.go
Normal file
25
internal/adapter/claude/handler_errors.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package claude
|
||||
|
||||
import "net/http"
|
||||
|
||||
func writeClaudeError(w http.ResponseWriter, status int, message string) {
|
||||
code := "invalid_request"
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
code = "authentication_failed"
|
||||
case http.StatusTooManyRequests:
|
||||
code = "rate_limit_exceeded"
|
||||
case http.StatusNotFound:
|
||||
code = "not_found"
|
||||
case http.StatusInternalServerError:
|
||||
code = "internal_error"
|
||||
}
|
||||
writeJSON(w, status, map[string]any{
|
||||
"error": map[string]any{
|
||||
"type": "invalid_request_error",
|
||||
"message": message,
|
||||
"code": code,
|
||||
"param": nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
134
internal/adapter/claude/handler_messages.go
Normal file
134
internal/adapter/claude/handler_messages.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
claudefmt "ds2api/internal/format/claude"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
)
|
||||
|
||||
// Messages implements the Anthropic-compatible POST /v1/messages endpoint.
// It resolves an upstream account, normalizes the Claude request into the
// internal standard form, calls the upstream completion API, and replies
// either as an SSE stream or as a single JSON message body.
func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
	// Default the protocol version header when the client omits it.
	if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" {
		r.Header.Set("anthropic-version", "2023-06-01")
	}
	a, err := h.Auth.Determine(r)
	if err != nil {
		status := http.StatusUnauthorized
		detail := err.Error()
		// "No account available" maps to 429 rather than 401.
		if err == auth.ErrNoAccount {
			status = http.StatusTooManyRequests
		}
		writeClaudeError(w, status, detail)
		return
	}
	defer h.Auth.Release(a)

	var req map[string]any
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeClaudeError(w, http.StatusBadRequest, "invalid json")
		return
	}
	norm, err := normalizeClaudeRequest(h.Store, req)
	if err != nil {
		writeClaudeError(w, http.StatusBadRequest, err.Error())
		return
	}
	stdReq := norm.Standard

	// The literal 3 is passed as a retry/attempt budget to each upstream
	// call — TODO confirm exact semantics in the DeepSeekCaller contract.
	sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
	if err != nil {
		writeClaudeError(w, http.StatusUnauthorized, "invalid token.")
		return
	}
	pow, err := h.DS.GetPow(r.Context(), a, 3)
	if err != nil {
		writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW")
		return
	}
	requestPayload := stdReq.CompletionPayload(sessionID)
	resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3)
	if err != nil {
		writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.")
		return
	}
	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()
		// Surface the upstream error body verbatim to the caller.
		body, _ := io.ReadAll(resp.Body)
		writeClaudeError(w, http.StatusInternalServerError, string(body))
		return
	}

	if stdReq.Stream {
		// Streaming path: relay upstream SSE as Anthropic events (it takes
		// ownership of resp and closes the body).
		h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
		return
	}
	// Non-streaming path: drain the whole upstream stream, then build one
	// Anthropic message response. (Assumes CollectStream consumes/closes
	// resp — confirm in sse package.)
	result := sse.CollectStream(resp, stdReq.Thinking, true)
	respBody := claudefmt.BuildMessageResponse(
		fmt.Sprintf("msg_%d", time.Now().UnixNano()),
		stdReq.ResponseModel,
		norm.NormalizedMessages,
		result.Thinking,
		result.Text,
		stdReq.ToolNames,
	)
	writeJSON(w, http.StatusOK, respBody)
}
|
||||
|
||||
// handleClaudeStreamRealtime relays an upstream SSE response to the client
// as Anthropic-style streaming events (message_start, content_block_*,
// ping, message_delta, ...). It owns resp and closes its body.
func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) {
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		writeClaudeError(w, http.StatusInternalServerError, string(body))
		return
	}

	// Standard SSE headers; X-Accel-Buffering: no disables buffering in
	// nginx-style reverse proxies so events reach the client immediately.
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache, no-transform")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no")
	rc := http.NewResponseController(w)
	_, canFlush := w.(http.Flusher)
	if !canFlush {
		config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered")
	}

	// The runtime tracks content-block indices/state and writes the
	// Anthropic event frames to w.
	streamRuntime := newClaudeStreamRuntime(
		w,
		rc,
		canFlush,
		model,
		messages,
		thinkingEnabled,
		searchEnabled,
		toolNames,
	)
	streamRuntime.sendMessageStart()

	// First content-block type: reasoning output leads when enabled.
	initialType := "text"
	if thinkingEnabled {
		initialType = "thinking"
	}
	// Drive the upstream SSE parse loop; hooks forward events into the
	// runtime. Keep-alive pings are sent while the upstream is silent.
	streamengine.ConsumeSSE(streamengine.ConsumeConfig{
		Context:             r.Context(),
		Body:                resp.Body,
		ThinkingEnabled:     thinkingEnabled,
		InitialType:         initialType,
		KeepAliveInterval:   claudeStreamPingInterval,
		IdleTimeout:         claudeStreamIdleTimeout,
		MaxKeepAliveNoInput: claudeStreamMaxKeepaliveCnt,
	}, streamengine.ConsumeHooks{
		OnKeepAlive: func() {
			streamRuntime.sendPing()
		},
		OnParsed:   streamRuntime.onParsed,
		OnFinalize: streamRuntime.onFinalize,
	})
}
|
||||
41
internal/adapter/claude/handler_routes.go
Normal file
41
internal/adapter/claude/handler_routes.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// writeJSON is a package-internal alias to avoid mass-renaming all call-sites.
var writeJSON = util.WriteJSON

// Handler wires the Claude-compatible HTTP endpoints to their dependencies.
type Handler struct {
	Store ConfigReader   // configuration lookup used during request normalization
	Auth  AuthResolver   // resolves (Determine) and releases (Release) per-request accounts
	DS    DeepSeekCaller // upstream API client (sessions, PoW, completions)
}

// Streaming pacing knobs, derived from the shared deepseek settings.
var (
	claudeStreamPingInterval    = time.Duration(deepseek.KeepAliveTimeout) * time.Second  // ping cadence while upstream is silent
	claudeStreamIdleTimeout     = time.Duration(deepseek.StreamIdleTimeout) * time.Second // abort after this much upstream silence
	claudeStreamMaxKeepaliveCnt = deepseek.MaxKeepaliveCount                              // max pings sent with no upstream data
)
|
||||
|
||||
// RegisterRoutes mounts the Anthropic-compatible endpoints on r, both under
// the /anthropic prefix and at the bare paths some Claude clients use.
func RegisterRoutes(r chi.Router, h *Handler) {
	r.Get("/anthropic/v1/models", h.ListModels)
	r.Post("/anthropic/v1/messages", h.Messages)
	r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens)
	r.Post("/v1/messages", h.Messages)
	r.Post("/messages", h.Messages)
	r.Post("/v1/messages/count_tokens", h.CountTokens)
	r.Post("/messages/count_tokens", h.CountTokens)
}
|
||||
|
||||
// ListModels returns the static Claude model catalog from config.
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
	writeJSON(w, http.StatusOK, config.ClaudeModelsResponse())
}
|
||||
@@ -183,6 +183,66 @@ func TestHandleClaudeStreamRealtimeToolSafety(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Verifies that a tool-call JSON payload split across thinking_content
// chunks (with no final text) is promoted to a tool_use content block.
func TestHandleClaudeStreamRealtimeToolDetectionFromThinkingFallback(t *testing.T) {
	h := &Handler{}
	resp := makeClaudeSSEHTTPResponse(
		`data: {"p":"response/thinking_content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
		`data: {"p":"response/thinking_content","v":",\"input\":{\"q\":\"go\"}}]}"}`,
		`data: [DONE]`,
	)
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)

	h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, true, false, []string{"search"})

	frames := parseClaudeFrames(t, rec.Body.String())
	foundToolUse := false
	for _, f := range findClaudeFrames(frames, "content_block_start") {
		contentBlock, _ := f.Payload["content_block"].(map[string]any)
		if contentBlock["type"] == "tool_use" && contentBlock["name"] == "search" {
			foundToolUse = true
			break
		}
	}
	if !foundToolUse {
		t.Fatalf("expected tool_use block from thinking fallback, body=%s", rec.Body.String())
	}
}
|
||||
|
||||
// Verifies the thinking-content tool fallback is suppressed when real final
// text arrives: no tool_use block is emitted and the stream ends with
// stop_reason=end_turn.
func TestHandleClaudeStreamRealtimeSkipsThinkingFallbackWhenFinalTextExists(t *testing.T) {
	h := &Handler{}
	resp := makeClaudeSSEHTTPResponse(
		`data: {"p":"response/thinking_content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
		`data: {"p":"response/thinking_content","v":",\"input\":{\"q\":\"go\"}}]}"}`,
		`data: {"p":"response/content","v":"normal answer"}`,
		`data: [DONE]`,
	)
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)

	h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, true, false, []string{"search"})

	frames := parseClaudeFrames(t, rec.Body.String())
	for _, f := range findClaudeFrames(frames, "content_block_start") {
		contentBlock, _ := f.Payload["content_block"].(map[string]any)
		if contentBlock["type"] == "tool_use" {
			t.Fatalf("unexpected tool_use block when final text exists, body=%s", rec.Body.String())
		}
	}

	foundEndTurn := false
	for _, f := range findClaudeFrames(frames, "message_delta") {
		delta, _ := f.Payload["delta"].(map[string]any)
		if delta["stop_reason"] == "end_turn" {
			foundEndTurn = true
			break
		}
	}
	if !foundEndTurn {
		t.Fatalf("expected stop_reason=end_turn, body=%s", rec.Body.String())
	}
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeUpstreamErrorEvent(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
@@ -255,3 +315,88 @@ func asString(v any) string {
|
||||
s, _ := v.(string)
|
||||
return s
|
||||
}
|
||||
|
||||
// Table test: each XML-ish tool-call dialect emitted as final content must
// be detected and converted into a tool_use content block.
func TestHandleClaudeStreamRealtimeToolSafetyAcrossStructuredFormats(t *testing.T) {
	tests := []struct {
		name    string
		payload string
	}{
		{name: "xml_tool_call", payload: `<tool_call><tool_name>Bash</tool_name><parameters><command>pwd</command></parameters></tool_call>`},
		{name: "xml_json_tool_call", payload: `<tool_call>{"tool":"Bash","params":{"command":"pwd"}}</tool_call>`},
		{name: "nested_tool_tag_style", payload: `<tool_call><tool name="Bash"><command>pwd</command></tool></tool_call>`},
		{name: "function_tag_style", payload: `<function_call>Bash</function_call><function parameter name="command">pwd</function parameter>`},
		{name: "antml_argument_style", payload: `<function_calls><function_call id="1" name="Bash"><argument name="command">pwd</argument></function_call></function_calls>`},
		{name: "antml_function_attr_parameters", payload: `<function_calls><function_call id="1" function="Bash"><parameters>{"command":"pwd"}</parameters></function_call></function_calls>`},
		{name: "invoke_parameter_style", payload: `<function_calls><invoke name="Bash"><parameter name="command">pwd</parameter></invoke></function_calls>`},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			h := &Handler{}
			// Escape quotes so the payload survives embedding in the SSE JSON.
			resp := makeClaudeSSEHTTPResponse(
				`data: {"p":"response/content","v":"`+strings.ReplaceAll(tc.payload, `"`, `\"`)+`"}`,
				`data: [DONE]`,
			)
			rec := httptest.NewRecorder()
			req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)

			h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, false, false, []string{"Bash"})

			frames := parseClaudeFrames(t, rec.Body.String())
			foundToolUse := false
			for _, f := range findClaudeFrames(frames, "content_block_start") {
				contentBlock, _ := f.Payload["content_block"].(map[string]any)
				if contentBlock["type"] == "tool_use" {
					foundToolUse = true
					break
				}
			}
			if !foundToolUse {
				t.Fatalf("expected tool_use block for format %s, body=%s", tc.name, rec.Body.String())
			}
		})
	}
}
|
||||
|
||||
// Verifies that tool-call JSON shown inside a fenced code block (an example
// the model is describing, not requesting) is NOT promoted to tool_use and
// the stop_reason stays content-only. The fence opens in one chunk and
// closes in the next, exercising cross-chunk fence tracking.
func TestHandleClaudeStreamRealtimeIgnoresUnclosedFencedToolExample(t *testing.T) {
	h := &Handler{}
	resp := makeClaudeSSEHTTPResponse(
		"data: {\"p\":\"response/content\",\"v\":\"Here is an example:\\n```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"Bash\\\",\\\"input\\\":{\\\"command\\\":\\\"pwd\\\"}}]}\"}",
		"data: {\"p\":\"response/content\",\"v\":\"\\n```\\nDo not execute it.\"}",
		`data: [DONE]`,
	)
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)

	h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "show example only"}}, false, false, []string{"Bash"})

	frames := parseClaudeFrames(t, rec.Body.String())
	foundToolUse := false
	for _, f := range findClaudeFrames(frames, "content_block_start") {
		contentBlock, _ := f.Payload["content_block"].(map[string]any)
		if contentBlock["type"] == "tool_use" {
			foundToolUse = true
			break
		}
	}
	if foundToolUse {
		t.Fatalf("expected no tool_use for fenced example, body=%s", rec.Body.String())
	}

	foundToolStop := false
	for _, f := range findClaudeFrames(frames, "message_delta") {
		delta, _ := f.Payload["delta"].(map[string]any)
		if delta["stop_reason"] == "tool_use" {
			foundToolStop = true
			break
		}
	}
	if foundToolStop {
		t.Fatalf("expected stop_reason to remain content-only, body=%s", rec.Body.String())
	}
}
|
||||
|
||||
// Backward-compatible alias for historical test name used in CI logs.
// NOTE: delegating makes the underlying scenario run twice per `go test`;
// kept intentionally so old CI log greps still match.
func TestHandleClaudeStreamRealtimePromotesUnclosedFencedToolExample(t *testing.T) {
	TestHandleClaudeStreamRealtimeIgnoresUnclosedFencedToolExample(t)
}
|
||||
|
||||
51
internal/adapter/claude/handler_tokens.go
Normal file
51
internal/adapter/claude/handler_tokens.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
writeClaudeError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeClaudeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
messages, _ := req["messages"].([]any)
|
||||
if model == "" || len(messages) == 0 {
|
||||
writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
|
||||
return
|
||||
}
|
||||
inputTokens := 0
|
||||
if sys, ok := req["system"].(string); ok {
|
||||
inputTokens += util.EstimateTokens(sys)
|
||||
}
|
||||
for _, item := range messages {
|
||||
msg, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
inputTokens += 2
|
||||
inputTokens += util.EstimateTokens(extractMessageContent(msg["content"]))
|
||||
}
|
||||
if tools, ok := req["tools"].([]any); ok {
|
||||
for _, t := range tools {
|
||||
b, _ := json.Marshal(t)
|
||||
inputTokens += util.EstimateTokens(string(b))
|
||||
}
|
||||
}
|
||||
if inputTokens < 1 {
|
||||
inputTokens = 1
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens})
|
||||
}
|
||||
522
internal/adapter/claude/handler_util_test.go
Normal file
522
internal/adapter/claude/handler_util_test.go
Normal file
@@ -0,0 +1,522 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// ─── normalizeClaudeMessages ─────────────────────────────────────────
|
||||
|
||||
// Plain string content must pass through unchanged.
func TestNormalizeClaudeMessagesSimpleString(t *testing.T) {
	msgs := []any{
		map[string]any{"role": "user", "content": "Hello"},
	}
	got := normalizeClaudeMessages(msgs)
	if len(got) != 1 {
		t.Fatalf("expected 1 message, got %d", len(got))
	}
	m := got[0].(map[string]any)
	if m["content"] != "Hello" {
		t.Fatalf("expected 'Hello', got %v", m["content"])
	}
}

// Array content with multiple text blocks is newline-joined.
func TestNormalizeClaudeMessagesArrayContent(t *testing.T) {
	msgs := []any{
		map[string]any{
			"role": "user",
			"content": []any{
				map[string]any{"type": "text", "text": "line1"},
				map[string]any{"type": "text", "text": "line2"},
			},
		},
	}
	got := normalizeClaudeMessages(msgs)
	m := got[0].(map[string]any)
	if m["content"] != "line1\nline2" {
		t.Fatalf("expected joined text, got %q", m["content"])
	}
}

// A tool_result block becomes a role=tool message carrying the raw output.
func TestNormalizeClaudeMessagesToolResult(t *testing.T) {
	msgs := []any{
		map[string]any{
			"role": "user",
			"content": []any{
				map[string]any{"type": "tool_result", "content": "tool output"},
			},
		},
	}
	got := normalizeClaudeMessages(msgs)
	if len(got) != 1 {
		t.Fatalf("expected one normalized message, got %d", len(got))
	}
	m := got[0].(map[string]any)
	if m["role"] != "tool" {
		t.Fatalf("expected tool role preserved, got %#v", m["role"])
	}
	content, _ := m["content"].(string)
	if content != "tool output" {
		t.Fatalf("expected raw tool output content preserved, got %q", content)
	}
}
|
||||
|
||||
// Assistant tool_use blocks are promoted into OpenAI-style tool_calls
// while keeping a serialized copy in the content for prompt roundtrips.
func TestNormalizeClaudeMessagesToolUseToAssistantToolCalls(t *testing.T) {
	msgs := []any{
		map[string]any{
			"role": "assistant",
			"content": []any{
				map[string]any{
					"type":  "tool_use",
					"id":    "call_1",
					"name":  "search_web",
					"input": map[string]any{"query": "latest"},
				},
			},
		},
	}

	got := normalizeClaudeMessages(msgs)
	if len(got) != 1 {
		t.Fatalf("expected one normalized tool-call message, got %d", len(got))
	}
	m := got[0].(map[string]any)
	if m["role"] != "assistant" {
		t.Fatalf("expected assistant role, got %#v", m["role"])
	}
	tc, _ := m["tool_calls"].([]any)
	if len(tc) != 1 {
		t.Fatalf("expected one tool call, got %#v", m["tool_calls"])
	}
	call, _ := tc[0].(map[string]any)
	if call["id"] != "call_1" {
		t.Fatalf("expected call id preserved, got %#v", call)
	}
	content, _ := m["content"].(string)
	if !containsStr(content, "search_web") || !containsStr(content, `"arguments":"{\"query\":\"latest\"}"`) {
		t.Fatalf("expected assistant content to include serialized tool call for prompt roundtrip, got %q", content)
	}
}

// tool_use in a USER message must never be promoted to tool_calls — that
// would let untrusted input forge tool invocations.
func TestNormalizeClaudeMessagesDoesNotPromoteUserToolUse(t *testing.T) {
	msgs := []any{
		map[string]any{
			"role": "user",
			"content": []any{
				map[string]any{
					"type":  "tool_use",
					"id":    "call_unsafe",
					"name":  "dangerous_tool",
					"input": map[string]any{"value": "x"},
				},
			},
		},
	}

	got := normalizeClaudeMessages(msgs)
	if len(got) != 1 {
		t.Fatalf("expected one normalized message, got %d", len(got))
	}
	m := got[0].(map[string]any)
	if m["role"] != "user" {
		t.Fatalf("expected user role preserved, got %#v", m["role"])
	}
	if _, ok := m["tool_calls"]; ok {
		t.Fatalf("expected no tool_calls promotion for user message, got %#v", m["tool_calls"])
	}
	content, _ := m["content"].(string)
	if !containsStr(content, `"type":"tool_use"`) || !containsStr(content, "dangerous_tool") {
		t.Fatalf("expected raw tool_use block preserved in user content, got %q", content)
	}
}
|
||||
|
||||
// Non-map entries are dropped entirely.
func TestNormalizeClaudeMessagesSkipsNonMap(t *testing.T) {
	msgs := []any{"not a map", 42}
	got := normalizeClaudeMessages(msgs)
	if len(got) != 0 {
		t.Fatalf("expected 0 messages for non-map items, got %d", len(got))
	}
}

// Nil input yields an empty result, not a panic.
func TestNormalizeClaudeMessagesEmpty(t *testing.T) {
	got := normalizeClaudeMessages(nil)
	if len(got) != 0 {
		t.Fatalf("expected 0, got %d", len(got))
	}
}

// The role field survives normalization untouched.
func TestNormalizeClaudeMessagesPreservesRole(t *testing.T) {
	msgs := []any{
		map[string]any{"role": "assistant", "content": "response"},
	}
	got := normalizeClaudeMessages(msgs)
	m := got[0].(map[string]any)
	if m["role"] != "assistant" {
		t.Fatalf("expected 'assistant', got %q", m["role"])
	}
}

// Mixed text + image content: text survives, the image block is marked but
// its base64 payload is replaced by omittedBinaryMarker.
func TestNormalizeClaudeMessagesMixedContentBlocks(t *testing.T) {
	msgs := []any{
		map[string]any{
			"role": "user",
			"content": []any{
				map[string]any{"type": "text", "text": "Hello"},
				map[string]any{"type": "image", "source": map[string]any{"type": "base64", "data": strings.Repeat("A", 2048)}},
				map[string]any{"type": "text", "text": "World"},
			},
		},
	}
	got := normalizeClaudeMessages(msgs)
	m := got[0].(map[string]any)
	content, _ := m["content"].(string)
	if !containsStr(content, "Hello") || !containsStr(content, "World") || !containsStr(content, `"type":"image"`) {
		t.Fatalf("expected text plus non-text block marker preserved, got %q", content)
	}
	if !containsStr(content, omittedBinaryMarker) {
		t.Fatalf("expected binary payload omitted marker, got %q", content)
	}
	if containsStr(content, strings.Repeat("A", 100)) {
		t.Fatalf("expected raw base64 payload not to be included, got %q", content)
	}
}
|
||||
|
||||
// A tool_result whose content mixes text and an image is JSON-stringified
// as a role=tool message, with the embedded base64 data sanitized.
func TestNormalizeClaudeMessagesToolResultNonTextPayloadStringified(t *testing.T) {
	msgs := []any{
		map[string]any{
			"role": "user",
			"content": []any{
				map[string]any{
					"type":        "tool_result",
					"tool_use_id": "call_image_1",
					"name":        "vision_tool",
					"content": []any{
						map[string]any{"type": "text", "text": "image analysis"},
						map[string]any{
							"type":   "image",
							"source": map[string]any{"type": "base64", "media_type": "image/png", "data": strings.Repeat("B", 2048)},
						},
					},
				},
			},
		},
	}

	got := normalizeClaudeMessages(msgs)
	if len(got) != 1 {
		t.Fatalf("expected one normalized message, got %d", len(got))
	}
	m := got[0].(map[string]any)
	if m["role"] != "tool" {
		t.Fatalf("expected tool role, got %#v", m["role"])
	}
	content, _ := m["content"].(string)
	if !containsStr(content, `"type":"tool_result"`) || !containsStr(content, `"type":"image"`) {
		t.Fatalf("expected non-text tool_result payload to be JSON stringified, got %q", content)
	}
	if !containsStr(content, omittedBinaryMarker) {
		t.Fatalf("expected binary data to be sanitized with omitted marker, got %q", content)
	}
	if containsStr(content, strings.Repeat("B", 100)) {
		t.Fatalf("expected raw base64 payload not to be included, got %q", content)
	}
}
|
||||
|
||||
// ─── buildClaudeToolPrompt ───────────────────────────────────────────
|
||||
|
||||
// A single Anthropic-style tool yields a prompt with the name, description,
// tool_use instruction, and no legacy history markers.
func TestBuildClaudeToolPromptSingleTool(t *testing.T) {
	tools := []any{
		map[string]any{
			"name":        "search",
			"description": "Search the web",
			"input_schema": map[string]any{
				"type": "object",
				"properties": map[string]any{
					"query": map[string]any{"type": "string"},
				},
			},
		},
	}
	prompt := buildClaudeToolPrompt(tools)
	if prompt == "" {
		t.Fatal("expected non-empty prompt")
	}
	// Should contain tool name and description
	if !containsStr(prompt, "search") {
		t.Fatalf("expected 'search' in prompt")
	}
	if !containsStr(prompt, "Search the web") {
		t.Fatalf("expected description in prompt")
	}
	if !containsStr(prompt, "tool_use") {
		t.Fatalf("expected tool_use instruction in prompt")
	}
	if containsStr(prompt, "TOOL_CALL_HISTORY") || containsStr(prompt, "TOOL_RESULT_HISTORY") {
		t.Fatalf("expected legacy tool history markers removed from prompt")
	}
	if !containsStr(prompt, "Do not print tool-call JSON in text") {
		t.Fatalf("expected prompt to keep no tool-call-json instruction")
	}
}

// All tools appear in the prompt, in some form.
func TestBuildClaudeToolPromptMultipleTools(t *testing.T) {
	tools := []any{
		map[string]any{"name": "tool1", "description": "desc1"},
		map[string]any{"name": "tool2", "description": "desc2"},
	}
	prompt := buildClaudeToolPrompt(tools)
	if !containsStr(prompt, "tool1") || !containsStr(prompt, "tool2") {
		t.Fatalf("expected both tools in prompt")
	}
}

// OpenAI-style {"type":"function","function":{...}} tools are unwrapped.
func TestBuildClaudeToolPromptSupportsOpenAIStyleFunctionTool(t *testing.T) {
	tools := []any{
		map[string]any{
			"type": "function",
			"function": map[string]any{
				"name":        "search",
				"description": "Search via function tool",
				"parameters": map[string]any{
					"type": "object",
					"properties": map[string]any{
						"q": map[string]any{"type": "string"},
					},
				},
			},
		},
	}
	prompt := buildClaudeToolPrompt(tools)
	if !containsStr(prompt, "Tool: search") {
		t.Fatalf("expected OpenAI-style function tool name in prompt, got: %q", prompt)
	}
	if !containsStr(prompt, "Search via function tool") {
		t.Fatalf("expected OpenAI-style function tool description in prompt, got: %q", prompt)
	}
	if !containsStr(prompt, "\"q\"") {
		t.Fatalf("expected parameters schema serialized in prompt, got: %q", prompt)
	}
}

// Invalid tool entries are skipped but the prompt skeleton still renders.
func TestBuildClaudeToolPromptSkipsNonMap(t *testing.T) {
	tools := []any{"not a map"}
	prompt := buildClaudeToolPrompt(tools)
	if prompt == "" {
		t.Fatal("expected non-empty prompt even with invalid tools")
	}
	// Should still contain the intro and instruction
	if !containsStr(prompt, "You are Claude") {
		t.Fatalf("expected intro in prompt")
	}
}
|
||||
|
||||
// ─── hasSystemMessage ────────────────────────────────────────────────
|
||||
|
||||
// hasSystemMessage: positive, negative, nil, and non-map inputs.
func TestHasSystemMessageTrue(t *testing.T) {
	msgs := []any{
		map[string]any{"role": "system", "content": "You are a helper"},
		map[string]any{"role": "user", "content": "Hi"},
	}
	if !hasSystemMessage(msgs) {
		t.Fatal("expected true")
	}
}

func TestHasSystemMessageFalse(t *testing.T) {
	msgs := []any{
		map[string]any{"role": "user", "content": "Hi"},
		map[string]any{"role": "assistant", "content": "Hello"},
	}
	if hasSystemMessage(msgs) {
		t.Fatal("expected false")
	}
}

func TestHasSystemMessageEmpty(t *testing.T) {
	if hasSystemMessage(nil) {
		t.Fatal("expected false for nil")
	}
}

func TestHasSystemMessageNonMap(t *testing.T) {
	msgs := []any{"not a map"}
	if hasSystemMessage(msgs) {
		t.Fatal("expected false for non-map")
	}
}
|
||||
|
||||
// ─── extractClaudeToolNames ──────────────────────────────────────────
|
||||
|
||||
// extractClaudeToolNames: single, multiple, empty-name, non-map, nil, and
// OpenAI-style function-wrapper inputs.
func TestExtractClaudeToolNamesSingle(t *testing.T) {
	tools := []any{
		map[string]any{"name": "search"},
	}
	names := extractClaudeToolNames(tools)
	if len(names) != 1 || names[0] != "search" {
		t.Fatalf("expected [search], got %v", names)
	}
}

func TestExtractClaudeToolNamesMultiple(t *testing.T) {
	tools := []any{
		map[string]any{"name": "search"},
		map[string]any{"name": "calculate"},
	}
	names := extractClaudeToolNames(tools)
	if len(names) != 2 {
		t.Fatalf("expected 2 names, got %v", names)
	}
}

func TestExtractClaudeToolNamesSkipsEmptyName(t *testing.T) {
	tools := []any{
		map[string]any{"name": ""},
		map[string]any{"name": "valid"},
	}
	names := extractClaudeToolNames(tools)
	if len(names) != 1 || names[0] != "valid" {
		t.Fatalf("expected [valid], got %v", names)
	}
}

func TestExtractClaudeToolNamesSkipsNonMap(t *testing.T) {
	tools := []any{"not a map", 42}
	names := extractClaudeToolNames(tools)
	if len(names) != 0 {
		t.Fatalf("expected 0, got %v", names)
	}
}

func TestExtractClaudeToolNamesNil(t *testing.T) {
	names := extractClaudeToolNames(nil)
	if len(names) != 0 {
		t.Fatalf("expected 0, got %v", names)
	}
}

func TestExtractClaudeToolNamesSupportsOpenAIStyleFunctionTool(t *testing.T) {
	tools := []any{
		map[string]any{
			"type": "function",
			"function": map[string]any{
				"name": "search",
			},
		},
	}
	names := extractClaudeToolNames(tools)
	if len(names) != 1 || names[0] != "search" {
		t.Fatalf("expected [search], got %v", names)
	}
}
|
||||
|
||||
// ─── toMessageMaps ───────────────────────────────────────────────────
|
||||
|
||||
func TestToMessageMapsNormal(t *testing.T) {
|
||||
input := []any{
|
||||
map[string]any{"role": "user", "content": "Hello"},
|
||||
}
|
||||
got := toMessageMaps(input)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessageMapsNonSlice(t *testing.T) {
|
||||
got := toMessageMaps("not a slice")
|
||||
if got != nil {
|
||||
t.Fatalf("expected nil, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessageMapsSkipsNonMap(t *testing.T) {
|
||||
input := []any{"string", map[string]any{"role": "user"}, 42}
|
||||
got := toMessageMaps(input)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 map, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessageMapsNil(t *testing.T) {
|
||||
got := toMessageMaps(nil)
|
||||
if got != nil {
|
||||
t.Fatalf("expected nil, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── extractMessageContent ──────────────────────────────────────────
|
||||
|
||||
func TestExtractMessageContentString(t *testing.T) {
|
||||
if got := extractMessageContent("hello"); got != "hello" {
|
||||
t.Fatalf("expected 'hello', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractMessageContentArray(t *testing.T) {
|
||||
input := []any{"part1", "part2"}
|
||||
got := extractMessageContent(input)
|
||||
if got != "part1\npart2" {
|
||||
t.Fatalf("expected joined, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractMessageContentOther(t *testing.T) {
|
||||
got := extractMessageContent(42)
|
||||
if got != "42" {
|
||||
t.Fatalf("expected '42', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractMessageContentNil(t *testing.T) {
|
||||
got := extractMessageContent(nil)
|
||||
if got != "<nil>" {
|
||||
t.Fatalf("expected '<nil>', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── cloneMap ────────────────────────────────────────────────────────
|
||||
|
||||
func TestCloneMapBasic(t *testing.T) {
|
||||
original := map[string]any{"a": 1, "b": "hello"}
|
||||
clone := cloneMap(original)
|
||||
original["a"] = 999
|
||||
if clone["a"] != 1 {
|
||||
t.Fatalf("expected 1, got %v", clone["a"])
|
||||
}
|
||||
if clone["b"] != "hello" {
|
||||
t.Fatalf("expected 'hello', got %v", clone["b"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloneMapEmpty(t *testing.T) {
|
||||
clone := cloneMap(map[string]any{})
|
||||
if len(clone) != 0 {
|
||||
t.Fatalf("expected empty, got %v", clone)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloneMapNested(t *testing.T) {
|
||||
// cloneMap is shallow, so nested maps share references
|
||||
inner := map[string]any{"key": "value"}
|
||||
original := map[string]any{"nested": inner}
|
||||
clone := cloneMap(original)
|
||||
// Shallow clone means inner is shared
|
||||
inner["key"] = "modified"
|
||||
cloneNested := clone["nested"].(map[string]any)
|
||||
if cloneNested["key"] != "modified" {
|
||||
t.Fatal("expected shallow clone to share nested references")
|
||||
}
|
||||
}
|
||||
|
||||
// helper
|
||||
// containsStr reports whether sub occurs within s. The original expression
// (`len(s) >= len(sub) && (s == sub || len(s) > 0 && findSubstring(s, sub))`)
// duplicated checks findSubstring already performs; it now simply delegates.
func containsStr(s, sub string) bool {
	return findSubstring(s, sub)
}

// findSubstring reports whether sub occurs within s by probing every offset.
// An empty sub matches any string, including the empty string.
func findSubstring(s, sub string) bool {
	for i := 0; i+len(sub) <= len(s); i++ {
		if s[i:i+len(sub)] == sub {
			return true
		}
	}
	return false
}
|
||||
291
internal/adapter/claude/handler_utils.go
Normal file
291
internal/adapter/claude/handler_utils.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// normalizeClaudeMessages flattens Anthropic-style messages (whose content may
// be an array of typed blocks) into simple role/content maps: adjacent text
// blocks merge into one message, assistant tool_use blocks become OpenAI-style
// tool_calls messages, and tool_result blocks become role:"tool" messages.
// Non-array content is passed through on a shallow copy.
func normalizeClaudeMessages(messages []any) []any {
	out := make([]any, 0, len(messages))
	for _, m := range messages {
		msg, ok := m.(map[string]any)
		if !ok {
			// Entries without a map shape carry no usable role/content; drop.
			continue
		}
		role := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", msg["role"])))
		switch content := msg["content"].(type) {
		case []any:
			// Accumulate consecutive text fragments; flushText emits them as a
			// single message so tool messages keep their relative ordering.
			textParts := make([]string, 0, len(content))
			flushText := func() {
				if len(textParts) == 0 {
					return
				}
				out = append(out, map[string]any{
					"role":    role,
					"content": strings.Join(textParts, "\n"),
				})
				textParts = textParts[:0]
			}
			for _, block := range content {
				b, ok := block.(map[string]any)
				if !ok {
					continue
				}
				typeStr := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", b["type"])))
				switch typeStr {
				case "text":
					if t, ok := b["text"].(string); ok {
						textParts = append(textParts, t)
					}
				case "tool_use":
					if role == "assistant" {
						// Assistant tool_use becomes a structured tool_calls
						// message; flush pending text first to preserve order.
						flushText()
						if toolMsg := normalizeClaudeToolUseToAssistant(b); toolMsg != nil {
							out = append(out, toolMsg)
						}
						continue
					}
					// tool_use on a non-assistant role: fall back to a
					// sanitized raw dump appended as text.
					if raw := strings.TrimSpace(formatClaudeUnknownBlockForPrompt(b)); raw != "" {
						textParts = append(textParts, raw)
					}
				case "tool_result":
					flushText()
					if toolMsg := normalizeClaudeToolResultToToolMessage(b); toolMsg != nil {
						out = append(out, toolMsg)
					}
				default:
					// Unknown block types are serialized (sanitized/truncated)
					// and carried along as text.
					if raw := strings.TrimSpace(formatClaudeUnknownBlockForPrompt(b)); raw != "" {
						textParts = append(textParts, raw)
					}
				}
			}
			flushText()
		default:
			// Plain string (or other scalar) content: keep message as-is on a
			// shallow copy so later mutation of the input does not leak.
			copied := cloneMap(msg)
			out = append(out, copied)
		}
	}
	return out
}
|
||||
|
||||
func buildClaudeToolPrompt(tools []any) string {
|
||||
parts := []string{"You are Claude, a helpful AI assistant. You have access to these tools:"}
|
||||
for _, t := range tools {
|
||||
m, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name, desc, schemaObj := extractClaudeToolMeta(m)
|
||||
schema, _ := json.Marshal(schemaObj)
|
||||
parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
|
||||
}
|
||||
parts = append(parts,
|
||||
"When you need a tool, respond with Claude-native tool use (tool_use) using the provided tool schema. Do not print tool-call JSON in text.",
|
||||
"Tool roundtrip context is included directly in the conversation messages (assistant tool_use/tool_calls and tool results).",
|
||||
"After receiving a valid tool result, continue with final answer instead of repeating the same call unless required fields are still missing.",
|
||||
)
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
// formatClaudeToolResultForPrompt serializes a tool_result block into a
// compact JSON line suitable for embedding in a prompt. Optional identifier
// fields (tool_use_id / tool_call_id / name) are included only when actually
// present and non-empty.
func formatClaudeToolResultForPrompt(block map[string]any) string {
	if block == nil {
		return ""
	}
	payload := map[string]any{
		"type":    "tool_result",
		"content": block["content"],
	}
	// Fix: fmt.Sprintf("%v", nil) yields the string "<nil>", so formatting
	// absent keys previously injected tool_call_id/name values of "<nil>".
	// Only format values that are present and non-nil.
	toolCallID := ""
	if v, ok := block["tool_use_id"]; ok && v != nil {
		toolCallID = strings.TrimSpace(fmt.Sprintf("%v", v))
	}
	if toolCallID == "" {
		if v, ok := block["tool_call_id"]; ok && v != nil {
			toolCallID = strings.TrimSpace(fmt.Sprintf("%v", v))
		}
	}
	if toolCallID != "" {
		payload["tool_call_id"] = toolCallID
	}
	if v, ok := block["name"]; ok && v != nil {
		if name := strings.TrimSpace(fmt.Sprintf("%v", v)); name != "" {
			payload["name"] = name
		}
	}
	b, err := json.Marshal(payload)
	if err != nil {
		// Content may be unmarshalable (e.g. a channel); fall back to %v.
		return strings.TrimSpace(fmt.Sprintf("%v", payload))
	}
	return string(b)
}
|
||||
|
||||
func normalizeClaudeToolUseToAssistant(block map[string]any) map[string]any {
|
||||
if block == nil {
|
||||
return nil
|
||||
}
|
||||
name := strings.TrimSpace(fmt.Sprintf("%v", block["name"]))
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
callID := strings.TrimSpace(fmt.Sprintf("%v", block["id"]))
|
||||
if callID == "" {
|
||||
callID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
|
||||
}
|
||||
if callID == "" {
|
||||
callID = "call_claude"
|
||||
}
|
||||
arguments := block["input"]
|
||||
if arguments == nil {
|
||||
arguments = map[string]any{}
|
||||
}
|
||||
argsJSON, err := json.Marshal(arguments)
|
||||
if err != nil || len(argsJSON) == 0 {
|
||||
argsJSON = []byte("{}")
|
||||
}
|
||||
toolCalls := []any{
|
||||
map[string]any{
|
||||
"id": callID,
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": name,
|
||||
"arguments": string(argsJSON),
|
||||
},
|
||||
},
|
||||
}
|
||||
return map[string]any{
|
||||
"role": "assistant",
|
||||
"content": marshalCompactJSON(toolCalls),
|
||||
"tool_calls": toolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeClaudeToolResultToToolMessage(block map[string]any) map[string]any {
|
||||
if block == nil {
|
||||
return nil
|
||||
}
|
||||
toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
|
||||
if toolCallID == "" {
|
||||
toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"]))
|
||||
}
|
||||
if toolCallID == "" {
|
||||
toolCallID = "call_claude"
|
||||
}
|
||||
out := map[string]any{
|
||||
"role": "tool",
|
||||
"tool_call_id": toolCallID,
|
||||
"content": normalizeClaudeToolResultContent(block["content"]),
|
||||
}
|
||||
if name := strings.TrimSpace(fmt.Sprintf("%v", block["name"])); name != "" {
|
||||
out["name"] = name
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeClaudeToolResultContent(content any) any {
|
||||
if text, ok := content.(string); ok {
|
||||
return text
|
||||
}
|
||||
payload := map[string]any{
|
||||
"type": "tool_result",
|
||||
"content": content,
|
||||
}
|
||||
b, err := json.Marshal(sanitizeClaudeBlockForPrompt(payload))
|
||||
if err != nil {
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", content))
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// formatClaudeBlockRaw renders a content block as compact JSON, falling back
// to a trimmed %v rendering when marshaling fails. A nil block yields "".
func formatClaudeBlockRaw(block map[string]any) string {
	if block == nil {
		return ""
	}
	encoded, err := json.Marshal(block)
	if err != nil {
		return strings.TrimSpace(fmt.Sprintf("%v", block))
	}
	return string(encoded)
}
|
||||
|
||||
// hasSystemMessage reports whether any message carries the system role.
// The comparison is trimmed and case-insensitive for consistency with
// injectClaudeToolPrompt, which matches the system role via EqualFold.
func hasSystemMessage(messages []any) bool {
	for _, m := range messages {
		msg, ok := m.(map[string]any)
		if !ok {
			continue
		}
		if role, ok := msg["role"].(string); ok && strings.EqualFold(strings.TrimSpace(role), "system") {
			return true
		}
	}
	return false
}
|
||||
|
||||
func extractClaudeToolNames(tools []any) []string {
|
||||
out := make([]string, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
m, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name, _, _ := extractClaudeToolMeta(m)
|
||||
if name != "" {
|
||||
out = append(out, name)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// extractClaudeToolMeta pulls (name, description, schema) out of a tool
// declaration. It understands both the flat Anthropic shape
// ({name, description, input_schema|parameters}) and the OpenAI shape
// ({type:"function", function:{...}}), preferring flat fields when present.
// Name and description are returned trimmed.
func extractClaudeToolMeta(m map[string]any) (string, string, any) {
	name, _ := m["name"].(string)
	desc, _ := m["description"].(string)
	schema := m["input_schema"]
	if schema == nil {
		schema = m["parameters"]
	}

	if fn, isFn := m["function"].(map[string]any); isFn {
		// Nested function object fills in whatever the flat shape lacked.
		if strings.TrimSpace(name) == "" {
			name, _ = fn["name"].(string)
		}
		if strings.TrimSpace(desc) == "" {
			desc, _ = fn["description"].(string)
		}
		if schema == nil {
			if v, ok := fn["input_schema"]; ok {
				schema = v
			}
		}
		if schema == nil {
			if v, ok := fn["parameters"]; ok {
				schema = v
			}
		}
	}
	return strings.TrimSpace(name), strings.TrimSpace(desc), schema
}
|
||||
|
||||
// toMessageMaps narrows an any value to a slice of message maps, dropping
// non-map elements. Non-slice input yields nil.
func toMessageMaps(v any) []map[string]any {
	items, ok := v.([]any)
	if !ok {
		return nil
	}
	msgs := make([]map[string]any, 0, len(items))
	for _, item := range items {
		if m, isMap := item.(map[string]any); isMap {
			msgs = append(msgs, m)
		}
	}
	return msgs
}
|
||||
|
||||
// extractMessageContent flattens message content into a string: strings pass
// through, arrays are newline-joined with %v per element, and anything else
// (including nil, which renders as "<nil>") is %v-formatted.
func extractMessageContent(v any) string {
	if s, ok := v.(string); ok {
		return s
	}
	items, isList := v.([]any)
	if !isList {
		return fmt.Sprintf("%v", v)
	}
	var b strings.Builder
	for i, item := range items {
		if i > 0 {
			b.WriteString("\n")
		}
		fmt.Fprintf(&b, "%v", item)
	}
	return b.String()
}
|
||||
|
||||
// cloneMap returns a shallow copy of in: top-level keys are independent but
// nested values remain shared. A nil input yields an empty, non-nil map.
func cloneMap(in map[string]any) map[string]any {
	copied := make(map[string]any, len(in))
	for key, value := range in {
		copied[key] = value
	}
	return copied
}
|
||||
105
internal/adapter/claude/handler_utils_sanitize.go
Normal file
105
internal/adapter/claude/handler_utils_sanitize.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Limits applied when re-serializing unknown Claude content blocks into
// plain-text prompt material.
const (
	// maxClaudeRawPromptChars caps how many characters of raw block JSON
	// (or of any single string field) are forwarded into the prompt.
	maxClaudeRawPromptChars = 1024
	// omittedBinaryMarker replaces fields that look like binary/base64
	// payloads so large blobs never reach the prompt.
	omittedBinaryMarker = "[omitted_binary_payload]"
)
|
||||
|
||||
func formatClaudeUnknownBlockForPrompt(block map[string]any) string {
|
||||
if block == nil {
|
||||
return ""
|
||||
}
|
||||
safe := sanitizeClaudeBlockForPrompt(block)
|
||||
raw := strings.TrimSpace(formatClaudeBlockRaw(safe))
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
if len(raw) > maxClaudeRawPromptChars {
|
||||
return raw[:maxClaudeRawPromptChars] + "...(truncated)"
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
func sanitizeClaudeBlockForPrompt(block map[string]any) map[string]any {
|
||||
out := cloneMap(block)
|
||||
for k, v := range out {
|
||||
if looksLikeBinaryFieldName(k) {
|
||||
out[k] = omittedBinaryMarker
|
||||
continue
|
||||
}
|
||||
switch inner := v.(type) {
|
||||
case map[string]any:
|
||||
out[k] = sanitizeClaudeBlockForPrompt(inner)
|
||||
case []any:
|
||||
out[k] = sanitizeClaudeArrayForPrompt(inner)
|
||||
case string:
|
||||
out[k] = sanitizeClaudeStringForPrompt(k, inner)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func sanitizeClaudeArrayForPrompt(items []any) []any {
|
||||
out := make([]any, 0, len(items))
|
||||
for _, item := range items {
|
||||
switch v := item.(type) {
|
||||
case map[string]any:
|
||||
out = append(out, sanitizeClaudeBlockForPrompt(v))
|
||||
case []any:
|
||||
out = append(out, sanitizeClaudeArrayForPrompt(v))
|
||||
default:
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func sanitizeClaudeStringForPrompt(key, value string) string {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if looksLikeBinaryFieldName(key) || looksLikeBase64Payload(trimmed) {
|
||||
return omittedBinaryMarker
|
||||
}
|
||||
if len(trimmed) > maxClaudeRawPromptChars {
|
||||
return trimmed[:maxClaudeRawPromptChars] + "...(truncated)"
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
// looksLikeBinaryFieldName reports whether a field name (trimmed,
// case-insensitive) conventionally carries binary/base64 payloads.
func looksLikeBinaryFieldName(name string) bool {
	switch strings.ToLower(strings.TrimSpace(name)) {
	case "data", "bytes", "base64", "inline_data", "inlinedata":
		return true
	default:
		return false
	}
}
|
||||
|
||||
// looksLikeBase64Payload heuristically detects large base64 blobs: at least
// 512 bytes, consisting (after stripping trailing '=' padding) solely of
// standard or URL-safe base64 alphabet characters.
func looksLikeBase64Payload(v string) bool {
	if len(v) < 512 {
		return false
	}
	body := strings.TrimRight(v, "=")
	if body == "" {
		return false
	}
	for _, ch := range body {
		switch {
		case ch >= 'a' && ch <= 'z':
		case ch >= 'A' && ch <= 'Z':
		case ch >= '0' && ch <= '9':
		case ch == '+', ch == '/', ch == '-', ch == '_':
		default:
			return false
		}
	}
	return true
}
|
||||
|
||||
// marshalCompactJSON encodes v as compact JSON, falling back to a trimmed %v
// rendering when marshaling fails.
func marshalCompactJSON(v any) string {
	encoded, err := json.Marshal(v)
	if err == nil {
		return string(encoded)
	}
	return strings.TrimSpace(fmt.Sprintf("%v", v))
}
|
||||
44
internal/adapter/claude/route_alias_test.go
Normal file
44
internal/adapter/claude/route_alias_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
)
|
||||
|
||||
// routeAliasAuthStub satisfies the handler's auth dependency while always
// rejecting: routing tests only need "route exists" (not 404), never a
// successful auth.
type routeAliasAuthStub struct{}

// Determine always reports unauthorized so handlers bail out early.
func (routeAliasAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) {
	return nil, auth.ErrUnauthorized
}

// Release is a no-op; Determine never hands out auth resources.
func (routeAliasAuthStub) Release(_ *auth.RequestAuth) {}
|
||||
|
||||
func TestClaudeRouteAliasesDoNot404(t *testing.T) {
|
||||
h := &Handler{
|
||||
Auth: routeAliasAuthStub{},
|
||||
}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
paths := []string{
|
||||
"/anthropic/v1/messages",
|
||||
"/v1/messages",
|
||||
"/messages",
|
||||
"/anthropic/v1/messages/count_tokens",
|
||||
"/v1/messages/count_tokens",
|
||||
"/messages/count_tokens",
|
||||
}
|
||||
for _, path := range paths {
|
||||
req := httptest.NewRequest(http.MethodPost, path, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code == http.StatusNotFound {
|
||||
t.Fatalf("expected route %s to be registered, got 404", path)
|
||||
}
|
||||
}
|
||||
}
|
||||
113
internal/adapter/claude/standard_request.go
Normal file
113
internal/adapter/claude/standard_request.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// claudeNormalizedRequest bundles the surface-independent request form with
// the Claude messages after content-block normalization.
type claudeNormalizedRequest struct {
	// Standard is the normalized request handed to the DeepSeek backend.
	Standard util.StandardRequest
	// NormalizedMessages are the messages produced by normalizeClaudeMessages
	// (before any tool-prompt injection).
	NormalizedMessages []any
}
|
||||
|
||||
// normalizeClaudeRequest validates an Anthropic Messages request map and
// converts it into the shared StandardRequest used by the DeepSeek backend:
// messages are flattened, a tool prompt is injected when tools are declared,
// and model capability flags are resolved via config.GetModelConfig.
//
// Note: req is mutated — a default max_tokens of 8192 is set when absent.
func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNormalizedRequest, error) {
	model, _ := req["model"].(string)
	messagesRaw, _ := req["messages"].([]any)
	if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 {
		return claudeNormalizedRequest{}, fmt.Errorf("Request must include 'model' and 'messages'.")
	}
	if _, ok := req["max_tokens"]; !ok {
		req["max_tokens"] = 8192
	}
	normalizedMessages := normalizeClaudeMessages(messagesRaw)
	// Work on a shallow copy so tool-prompt injection never rewrites req.
	payload := cloneMap(req)
	payload["messages"] = normalizedMessages
	toolsRequested, _ := req["tools"].([]any)
	payload["messages"] = injectClaudeToolPrompt(payload, normalizedMessages, toolsRequested)

	// Map the Claude payload onto a DeepSeek request and resolve model flags.
	dsPayload := convertClaudeToDeepSeek(payload, store)
	dsModel, _ := dsPayload["model"].(string)
	thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel)
	if !ok {
		// Unknown model: disable optional capabilities rather than failing.
		thinkingEnabled = false
		searchEnabled = false
	}
	finalPrompt := deepseek.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
	toolNames := extractClaudeToolNames(toolsRequested)

	return claudeNormalizedRequest{
		Standard: util.StandardRequest{
			Surface:        "anthropic_messages",
			RequestedModel: strings.TrimSpace(model),
			// The client's model name is echoed back in responses.
			ResolvedModel: dsModel,
			ResponseModel: strings.TrimSpace(model),
			Messages:      payload["messages"].([]any),
			FinalPrompt:   finalPrompt,
			ToolNames:     toolNames,
			Stream:        util.ToBool(req["stream"]),
			Thinking:      thinkingEnabled,
			Search:        searchEnabled,
		},
		NormalizedMessages: normalizedMessages,
	}, nil
}
|
||||
|
||||
// injectClaudeToolPrompt merges the generated tool prompt into the request.
// Precedence: (1) a non-empty top-level Anthropic "system" field on payload
// (mutated in place), else (2) the first system message in the conversation
// (merged on a copied message), else (3) a new system message prepended.
// normalizedMessages itself is never mutated; a copied slice is returned
// whenever the message list changes.
func injectClaudeToolPrompt(payload map[string]any, normalizedMessages []any, tools []any) []any {
	if len(tools) == 0 {
		return normalizedMessages
	}
	toolPrompt := strings.TrimSpace(buildClaudeToolPrompt(tools))
	if toolPrompt == "" {
		return normalizedMessages
	}

	// Prefer top-level Anthropic-style system prompt when available.
	if systemText, ok := payload["system"].(string); ok && strings.TrimSpace(systemText) != "" {
		payload["system"] = mergeSystemPrompt(systemText, toolPrompt)
		return normalizedMessages
	}

	messages := cloneAnySlice(normalizedMessages)
	for i := range messages {
		msg, ok := messages[i].(map[string]any)
		if !ok {
			continue
		}
		role, _ := msg["role"].(string)
		if !strings.EqualFold(strings.TrimSpace(role), "system") {
			continue
		}
		// Merge into the first system message only, on a copy so the shared
		// message map is untouched.
		// NOTE(review): a nil content renders as "<nil>" via %v here — confirm
		// system messages always carry string content at this point.
		copied := cloneMap(msg)
		copied["content"] = mergeSystemPrompt(strings.TrimSpace(fmt.Sprintf("%v", copied["content"])), toolPrompt)
		messages[i] = copied
		return messages
	}

	// No system message found: prepend one carrying only the tool prompt.
	return append([]any{map[string]any{"role": "system", "content": toolPrompt}}, messages...)
}
|
||||
|
||||
// mergeSystemPrompt joins two trimmed prompt fragments with a blank line,
// returning whichever one is non-empty when the other is blank.
func mergeSystemPrompt(base, extra string) string {
	trimmedBase := strings.TrimSpace(base)
	trimmedExtra := strings.TrimSpace(extra)
	if trimmedBase == "" {
		return trimmedExtra
	}
	if trimmedExtra == "" {
		return trimmedBase
	}
	return trimmedBase + "\n\n" + trimmedExtra
}
|
||||
|
||||
// cloneAnySlice returns a shallow copy of in with its own backing array, or
// nil when in is empty (matching the original nil-for-empty contract).
func cloneAnySlice(in []any) []any {
	if len(in) == 0 {
		return nil
	}
	return append(make([]any, 0, len(in)), in...)
}
|
||||
92
internal/adapter/claude/standard_request_test.go
Normal file
92
internal/adapter/claude/standard_request_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func TestNormalizeClaudeRequest(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{}`)
|
||||
store := config.LoadStore()
|
||||
req := map[string]any{
|
||||
"model": "claude-opus-4-6",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
"stream": true,
|
||||
"tools": []any{
|
||||
map[string]any{"name": "search", "description": "Search"},
|
||||
},
|
||||
}
|
||||
norm, err := normalizeClaudeRequest(store, req)
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
if norm.Standard.ResolvedModel == "" {
|
||||
t.Fatalf("expected resolved model")
|
||||
}
|
||||
if !norm.Standard.Stream {
|
||||
t.Fatalf("expected stream=true")
|
||||
}
|
||||
if len(norm.Standard.ToolNames) == 0 {
|
||||
t.Fatalf("expected tool names")
|
||||
}
|
||||
if norm.Standard.FinalPrompt == "" {
|
||||
t.Fatalf("expected non-empty final prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeRequestInjectsToolsIntoExistingSystemMessage(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{}`)
|
||||
store := config.LoadStore()
|
||||
req := map[string]any{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "system", "content": "baseline rule"},
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
"tools": []any{
|
||||
map[string]any{"name": "search", "description": "Search"},
|
||||
},
|
||||
}
|
||||
|
||||
norm, err := normalizeClaudeRequest(store, req)
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
|
||||
if !containsStr(norm.Standard.FinalPrompt, "You have access to these tools") {
|
||||
t.Fatalf("expected tool prompt injected into final prompt, got=%q", norm.Standard.FinalPrompt)
|
||||
}
|
||||
if !containsStr(norm.Standard.FinalPrompt, "baseline rule") {
|
||||
t.Fatalf("expected existing system message preserved, got=%q", norm.Standard.FinalPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeRequestInjectsToolsIntoTopLevelSystem(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{}`)
|
||||
store := config.LoadStore()
|
||||
req := map[string]any{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"system": "top-level system",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
"tools": []any{
|
||||
map[string]any{"name": "search", "description": "Search"},
|
||||
},
|
||||
}
|
||||
|
||||
norm, err := normalizeClaudeRequest(store, req)
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
|
||||
if !containsStr(norm.Standard.FinalPrompt, "top-level system") {
|
||||
t.Fatalf("expected top-level system preserved, got=%q", norm.Standard.FinalPrompt)
|
||||
}
|
||||
if !containsStr(norm.Standard.FinalPrompt, "You have access to these tools") {
|
||||
t.Fatalf("expected tool prompt injected, got=%q", norm.Standard.FinalPrompt)
|
||||
}
|
||||
}
|
||||
163
internal/adapter/claude/stream_runtime_core.go
Normal file
163
internal/adapter/claude/stream_runtime_core.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// claudeStreamRuntime converts an upstream SSE stream into Anthropic
// Messages streaming events, tracking which content block (thinking or text)
// is currently open and buffering text when tool-call detection is active.
type claudeStreamRuntime struct {
	w        http.ResponseWriter
	rc       *http.ResponseController
	canFlush bool // whether rc.Flush is usable on this writer

	model     string
	toolNames []string // declared tool names used by util.ParseToolCalls
	messages  []any    // request messages (used for input-token estimation)

	thinkingEnabled bool
	searchEnabled   bool
	// bufferToolContent is true when tools are declared: answer text is then
	// withheld and sniffed for tool calls instead of being streamed live.
	bufferToolContent bool

	messageID string
	thinking  strings.Builder // accumulated thinking text
	text      strings.Builder // accumulated answer text

	nextBlockIndex     int
	thinkingBlockOpen  bool
	thinkingBlockIndex int // -1 while no thinking block is open
	textBlockOpen      bool
	textBlockIndex     int // -1 while no text block is open
	ended              bool
	upstreamErr        string // last upstream error message, if any
}
|
||||
|
||||
func newClaudeStreamRuntime(
|
||||
w http.ResponseWriter,
|
||||
rc *http.ResponseController,
|
||||
canFlush bool,
|
||||
model string,
|
||||
messages []any,
|
||||
thinkingEnabled bool,
|
||||
searchEnabled bool,
|
||||
toolNames []string,
|
||||
) *claudeStreamRuntime {
|
||||
return &claudeStreamRuntime{
|
||||
w: w,
|
||||
rc: rc,
|
||||
canFlush: canFlush,
|
||||
model: model,
|
||||
messages: messages,
|
||||
thinkingEnabled: thinkingEnabled,
|
||||
searchEnabled: searchEnabled,
|
||||
bufferToolContent: len(toolNames) > 0,
|
||||
toolNames: toolNames,
|
||||
messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()),
|
||||
thinkingBlockIndex: -1,
|
||||
textBlockIndex: -1,
|
||||
}
|
||||
}
|
||||
|
||||
// onParsed consumes one parsed upstream SSE line and emits the matching
// Anthropic streaming events (content_block_start/_delta), returning a
// decision that tells the stream engine whether content was seen and whether
// to stop.
func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
	if !parsed.Parsed {
		return streamengine.ParsedDecision{}
	}
	if parsed.ErrorMessage != "" {
		// Remember the error; it is surfaced after the stream stops.
		s.upstreamErr = parsed.ErrorMessage
		return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("upstream_error")}
	}
	if parsed.Stop {
		return streamengine.ParsedDecision{Stop: true}
	}

	contentSeen := false
	for _, p := range parsed.Parts {
		if p.Text == "" {
			continue
		}
		// In search mode, citation fragments are dropped from answer text
		// (thinking parts keep them).
		if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
			continue
		}
		contentSeen = true

		if p.Type == "thinking" {
			if !s.thinkingEnabled {
				continue
			}
			s.thinking.WriteString(p.Text)
			// Thinking and text blocks never overlap; close any open text
			// block before (re)opening a thinking block.
			s.closeTextBlock()
			if !s.thinkingBlockOpen {
				s.thinkingBlockIndex = s.nextBlockIndex
				s.nextBlockIndex++
				s.send("content_block_start", map[string]any{
					"type":  "content_block_start",
					"index": s.thinkingBlockIndex,
					"content_block": map[string]any{
						"type":     "thinking",
						"thinking": "",
					},
				})
				s.thinkingBlockOpen = true
			}
			s.send("content_block_delta", map[string]any{
				"type":  "content_block_delta",
				"index": s.thinkingBlockIndex,
				"delta": map[string]any{
					"type":     "thinking_delta",
					"thinking": p.Text,
				},
			})
			continue
		}

		s.text.WriteString(p.Text)
		if s.bufferToolContent {
			// Inside an unclosed ``` fence the text may be example code that
			// merely looks like a tool call; wait for the closing fence.
			if hasUnclosedCodeFence(s.text.String()) {
				continue
			}
			detected := util.ParseToolCalls(s.text.String(), s.toolNames)
			if len(detected) > 0 {
				s.finalize("tool_use")
				return streamengine.ParsedDecision{
					ContentSeen: true,
					Stop:        true,
					StopReason:  streamengine.StopReason("tool_use_detected"),
				}
			}
			// Tools declared: text stays buffered until finalize decides
			// between tool_use and plain text output.
			continue
		}
		s.closeThinkingBlock()
		if !s.textBlockOpen {
			s.textBlockIndex = s.nextBlockIndex
			s.nextBlockIndex++
			s.send("content_block_start", map[string]any{
				"type":  "content_block_start",
				"index": s.textBlockIndex,
				"content_block": map[string]any{
					"type": "text",
					"text": "",
				},
			})
			s.textBlockOpen = true
		}
		s.send("content_block_delta", map[string]any{
			"type":  "content_block_delta",
			"index": s.textBlockIndex,
			"delta": map[string]any{
				"type": "text_delta",
				"text": p.Text,
			},
		})
	}

	return streamengine.ParsedDecision{ContentSeen: contentSeen}
}
|
||||
|
||||
// hasUnclosedCodeFence reports whether text contains an odd number of ```
// markers, i.e. a Markdown code fence that has been opened but not closed.
func hasUnclosedCodeFence(text string) bool {
	fences := strings.Count(text, "```")
	return fences%2 != 0
}
|
||||
59
internal/adapter/claude/stream_runtime_emit.go
Normal file
59
internal/adapter/claude/stream_runtime_emit.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func (s *claudeStreamRuntime) send(event string, v any) {
|
||||
b, _ := json.Marshal(v)
|
||||
_, _ = s.w.Write([]byte("event: "))
|
||||
_, _ = s.w.Write([]byte(event))
|
||||
_, _ = s.w.Write([]byte("\n"))
|
||||
_, _ = s.w.Write([]byte("data: "))
|
||||
_, _ = s.w.Write(b)
|
||||
_, _ = s.w.Write([]byte("\n\n"))
|
||||
if s.canFlush {
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) sendError(message string) {
|
||||
msg := strings.TrimSpace(message)
|
||||
if msg == "" {
|
||||
msg = "upstream stream error"
|
||||
}
|
||||
s.send("error", map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"type": "api_error",
|
||||
"message": msg,
|
||||
"code": "internal_error",
|
||||
"param": nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) sendPing() {
|
||||
s.send("ping", map[string]any{"type": "ping"})
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) sendMessageStart() {
|
||||
inputTokens := util.EstimateTokens(fmt.Sprintf("%v", s.messages))
|
||||
s.send("message_start", map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": s.messageID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": s.model,
|
||||
"content": []any{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0},
|
||||
},
|
||||
})
|
||||
}
|
||||
122
internal/adapter/claude/stream_runtime_finalize.go
Normal file
122
internal/adapter/claude/stream_runtime_finalize.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func (s *claudeStreamRuntime) closeThinkingBlock() {
|
||||
if !s.thinkingBlockOpen {
|
||||
return
|
||||
}
|
||||
s.send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": s.thinkingBlockIndex,
|
||||
})
|
||||
s.thinkingBlockOpen = false
|
||||
s.thinkingBlockIndex = -1
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) closeTextBlock() {
|
||||
if !s.textBlockOpen {
|
||||
return
|
||||
}
|
||||
s.send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": s.textBlockIndex,
|
||||
})
|
||||
s.textBlockOpen = false
|
||||
s.textBlockIndex = -1
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) finalize(stopReason string) {
|
||||
if s.ended {
|
||||
return
|
||||
}
|
||||
s.ended = true
|
||||
|
||||
s.closeThinkingBlock()
|
||||
s.closeTextBlock()
|
||||
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := s.text.String()
|
||||
|
||||
if s.bufferToolContent {
|
||||
detected := util.ParseToolCalls(finalText, s.toolNames)
|
||||
if len(detected) == 0 && finalText == "" && finalThinking != "" {
|
||||
detected = util.ParseToolCalls(finalThinking, s.toolNames)
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
stopReason = "tool_use"
|
||||
for i, tc := range detected {
|
||||
idx := s.nextBlockIndex + i
|
||||
s.send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": idx,
|
||||
"content_block": map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx),
|
||||
"name": tc.Name,
|
||||
"input": tc.Input,
|
||||
},
|
||||
})
|
||||
s.send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": idx,
|
||||
})
|
||||
}
|
||||
s.nextBlockIndex += len(detected)
|
||||
} else if finalText != "" {
|
||||
idx := s.nextBlockIndex
|
||||
s.nextBlockIndex++
|
||||
s.send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": idx,
|
||||
"content_block": map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
})
|
||||
s.send("content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": idx,
|
||||
"delta": map[string]any{
|
||||
"type": "text_delta",
|
||||
"text": finalText,
|
||||
},
|
||||
})
|
||||
s.send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": idx,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
|
||||
s.send("message_delta", map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]any{
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
})
|
||||
s.send("message_stop", map[string]any{"type": "message_stop"})
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) onFinalize(reason streamengine.StopReason, scannerErr error) {
|
||||
if string(reason) == "upstream_error" {
|
||||
s.sendError(s.upstreamErr)
|
||||
return
|
||||
}
|
||||
if scannerErr != nil {
|
||||
s.sendError(scannerErr.Error())
|
||||
return
|
||||
}
|
||||
s.finalize("end_turn")
|
||||
}
|
||||
100
internal/adapter/claude/stream_status_test.go
Normal file
100
internal/adapter/claude/stream_status_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
chimw "github.com/go-chi/chi/v5/middleware"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
)
|
||||
|
||||
// streamStatusClaudeAuthStub always grants a direct-token auth with no
// config-managed account.
type streamStatusClaudeAuthStub struct{}

// Determine returns a fixed direct-token RequestAuth for every request.
func (streamStatusClaudeAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) {
	return &auth.RequestAuth{
		UseConfigToken: false,
		DeepSeekToken:  "direct-token",
		CallerID:       "caller:test",
		TriedAccounts:  map[string]bool{},
	}, nil
}

// Release is a no-op: the stub holds no pooled account to return.
func (streamStatusClaudeAuthStub) Release(_ *auth.RequestAuth) {}
|
||||
|
||||
// streamStatusClaudeDSStub fakes the DeepSeek client with canned responses
// and a minimal two-line SSE completion body.
type streamStatusClaudeDSStub struct{}

// CreateSession returns a fixed session id without any network call.
func (streamStatusClaudeDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
	return "session-id", nil
}

// GetPow returns a fixed proof-of-work response.
func (streamStatusClaudeDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
	return "pow", nil
}

// CallCompletion returns a 200 response whose body is a minimal SSE stream:
// one content chunk followed by the [DONE] sentinel.
func (streamStatusClaudeDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
	body := "data: {\"p\":\"response/content\",\"v\":\"hello\"}\n" + "data: [DONE]\n"
	return &http.Response{
		StatusCode: http.StatusOK,
		Header:     make(http.Header),
		Body:       ioNopCloser{strings.NewReader(body)},
	}, nil
}
|
||||
|
||||
// ioNopCloser adapts a strings.Reader into an io.ReadCloser whose Close is
// a no-op, for use as a fake http.Response body.
type ioNopCloser struct {
	*strings.Reader
}

// Close implements io.Closer and does nothing.
func (ioNopCloser) Close() error { return nil }
|
||||
|
||||
// streamStatusClaudeStoreStub provides a fixed Claude-tier-to-DeepSeek
// model mapping.
type streamStatusClaudeStoreStub struct{}

// ClaudeMapping returns the static tier-to-model table used by the test.
func (streamStatusClaudeStoreStub) ClaudeMapping() map[string]string {
	return map[string]string{
		"fast": "deepseek-chat",
		"slow": "deepseek-reasoner",
	}
}
|
||||
|
||||
func captureClaudeStatusMiddleware(statuses *[]int) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ww := chimw.NewWrapResponseWriter(w, r.ProtoMajor)
|
||||
next.ServeHTTP(ww, r)
|
||||
*statuses = append(*statuses, ww.Status())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeMessagesStreamStatusCapturedAs200(t *testing.T) {
|
||||
statuses := make([]int, 0, 1)
|
||||
h := &Handler{
|
||||
Store: streamStatusClaudeStoreStub{},
|
||||
Auth: streamStatusClaudeAuthStub{},
|
||||
DS: streamStatusClaudeDSStub{},
|
||||
}
|
||||
r := chi.NewRouter()
|
||||
r.Use(captureClaudeStatusMiddleware(&statuses))
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
reqBody := `{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":true}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if len(statuses) != 1 {
|
||||
t.Fatalf("expected one captured status, got %d", len(statuses))
|
||||
}
|
||||
if statuses[0] != http.StatusOK {
|
||||
t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0])
|
||||
}
|
||||
}
|
||||
244
internal/adapter/gemini/convert_messages.go
Normal file
244
internal/adapter/gemini/convert_messages.go
Normal file
@@ -0,0 +1,244 @@
|
||||
package gemini
|
||||
|
||||
import "strings"
|
||||
|
||||
const maxGeminiRawPromptChars = 1024
|
||||
|
||||
func geminiMessagesFromRequest(req map[string]any) []any {
|
||||
out := make([]any, 0, 8)
|
||||
if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" {
|
||||
out = append(out, map[string]any{
|
||||
"role": "system",
|
||||
"content": sys,
|
||||
})
|
||||
}
|
||||
|
||||
contents, _ := req["contents"].([]any)
|
||||
for _, item := range contents {
|
||||
content, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role := mapGeminiRole(content["role"])
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
parts, _ := content["parts"].([]any)
|
||||
if len(parts) == 0 {
|
||||
if text := strings.TrimSpace(asString(content["text"])); text != "" {
|
||||
out = append(out, map[string]any{
|
||||
"role": role,
|
||||
"content": text,
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
textParts := make([]string, 0, len(parts))
|
||||
flushText := func() {
|
||||
if len(textParts) == 0 {
|
||||
return
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"role": role,
|
||||
"content": strings.Join(textParts, "\n"),
|
||||
})
|
||||
textParts = textParts[:0]
|
||||
}
|
||||
|
||||
for _, rawPart := range parts {
|
||||
part, ok := rawPart.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if text := strings.TrimSpace(asString(part["text"])); text != "" {
|
||||
textParts = append(textParts, text)
|
||||
continue
|
||||
}
|
||||
|
||||
if fnCall, ok := part["functionCall"].(map[string]any); ok {
|
||||
flushText()
|
||||
if name := strings.TrimSpace(asString(fnCall["name"])); name != "" {
|
||||
callID := strings.TrimSpace(asString(fnCall["id"]))
|
||||
if callID == "" {
|
||||
callID = "call_gemini"
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"role": "assistant",
|
||||
"tool_calls": []any{
|
||||
map[string]any{
|
||||
"id": callID,
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": name,
|
||||
"arguments": stringifyJSON(fnCall["args"]),
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if fnResp, ok := part["functionResponse"].(map[string]any); ok {
|
||||
flushText()
|
||||
name := strings.TrimSpace(asString(fnResp["name"]))
|
||||
callID := strings.TrimSpace(asString(fnResp["id"]))
|
||||
if callID == "" {
|
||||
callID = strings.TrimSpace(asString(fnResp["callId"]))
|
||||
}
|
||||
if callID == "" {
|
||||
callID = strings.TrimSpace(asString(fnResp["tool_call_id"]))
|
||||
}
|
||||
if callID == "" {
|
||||
callID = "call_gemini"
|
||||
}
|
||||
content := fnResp["response"]
|
||||
if content == nil {
|
||||
content = fnResp["output"]
|
||||
}
|
||||
if content == nil {
|
||||
content = ""
|
||||
}
|
||||
msg := map[string]any{
|
||||
"role": "tool",
|
||||
"tool_call_id": callID,
|
||||
"content": content,
|
||||
}
|
||||
if name != "" {
|
||||
msg["name"] = name
|
||||
}
|
||||
out = append(out, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
if raw := strings.TrimSpace(formatGeminiUnknownPartForPrompt(part)); raw != "" && raw != "null" {
|
||||
textParts = append(textParts, raw)
|
||||
}
|
||||
}
|
||||
flushText()
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeGeminiSystemInstruction(raw any) string {
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
case map[string]any:
|
||||
if parts, ok := v["parts"].([]any); ok {
|
||||
texts := make([]string, 0, len(parts))
|
||||
for _, item := range parts {
|
||||
part, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if text := strings.TrimSpace(asString(part["text"])); text != "" {
|
||||
texts = append(texts, text)
|
||||
}
|
||||
}
|
||||
return strings.Join(texts, "\n")
|
||||
}
|
||||
if text := strings.TrimSpace(asString(v["text"])); text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func mapGeminiRole(v any) string {
|
||||
switch strings.ToLower(strings.TrimSpace(asString(v))) {
|
||||
case "user":
|
||||
return "user"
|
||||
case "model", "assistant":
|
||||
return "assistant"
|
||||
case "system":
|
||||
return "system"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func formatGeminiUnknownPartForPrompt(part map[string]any) string {
|
||||
safe := sanitizeGeminiPartForPrompt(part)
|
||||
raw := strings.TrimSpace(stringifyJSON(safe))
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
if len(raw) > maxGeminiRawPromptChars {
|
||||
return raw[:maxGeminiRawPromptChars] + "...(truncated)"
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
func sanitizeGeminiPartForPrompt(part map[string]any) map[string]any {
|
||||
out := make(map[string]any, len(part))
|
||||
for k, v := range part {
|
||||
if looksLikeGeminiBinaryField(k) {
|
||||
out[k] = "[omitted_binary_payload]"
|
||||
continue
|
||||
}
|
||||
switch x := v.(type) {
|
||||
case map[string]any:
|
||||
out[k] = sanitizeGeminiPartForPrompt(x)
|
||||
case []any:
|
||||
out[k] = sanitizeGeminiArrayForPrompt(x)
|
||||
case string:
|
||||
out[k] = sanitizeGeminiStringForPrompt(k, x)
|
||||
default:
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func sanitizeGeminiArrayForPrompt(items []any) []any {
|
||||
out := make([]any, 0, len(items))
|
||||
for _, item := range items {
|
||||
switch x := item.(type) {
|
||||
case map[string]any:
|
||||
out = append(out, sanitizeGeminiPartForPrompt(x))
|
||||
case []any:
|
||||
out = append(out, sanitizeGeminiArrayForPrompt(x))
|
||||
default:
|
||||
out = append(out, x)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func sanitizeGeminiStringForPrompt(key, value string) string {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if looksLikeGeminiBinaryField(key) || looksLikeGeminiBase64(trimmed) {
|
||||
return "[omitted_binary_payload]"
|
||||
}
|
||||
if len(trimmed) > maxGeminiRawPromptChars {
|
||||
return trimmed[:maxGeminiRawPromptChars] + "...(truncated)"
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
// looksLikeGeminiBinaryField reports whether a field name conventionally
// carries raw binary or base64 content.
func looksLikeGeminiBinaryField(name string) bool {
	switch strings.ToLower(strings.TrimSpace(name)) {
	case "data", "bytes", "inlinedata", "inline_data", "base64":
		return true
	default:
		return false
	}
}
|
||||
|
||||
// looksLikeGeminiBase64 heuristically detects long base64 payloads: at
// least 512 bytes and, ignoring trailing '=' padding, composed only of the
// standard or URL-safe base64 alphabet.
func looksLikeGeminiBase64(v string) bool {
	if len(v) < 512 {
		return false
	}
	body := strings.TrimRight(v, "=")
	if body == "" {
		return false
	}
	for i := 0; i < len(body); i++ {
		c := body[i]
		isAlpha := (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
		isDigit := c >= '0' && c <= '9'
		if isAlpha || isDigit || c == '+' || c == '/' || c == '-' || c == '_' {
			continue
		}
		return false
	}
	return true
}
|
||||
84
internal/adapter/gemini/convert_messages_test.go
Normal file
84
internal/adapter/gemini/convert_messages_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGeminiMessagesFromRequestPreservesFunctionRoundtrip(t *testing.T) {
|
||||
req := map[string]any{
|
||||
"contents": []any{
|
||||
map[string]any{
|
||||
"role": "model",
|
||||
"parts": []any{
|
||||
map[string]any{
|
||||
"functionCall": map[string]any{
|
||||
"id": "call_g1",
|
||||
"name": "search_web",
|
||||
"args": map[string]any{"query": "ai"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"parts": []any{
|
||||
map[string]any{
|
||||
"functionResponse": map[string]any{
|
||||
"id": "call_g1",
|
||||
"name": "search_web",
|
||||
"response": "ok",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := geminiMessagesFromRequest(req)
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected two normalized messages, got %#v", got)
|
||||
}
|
||||
assistant, _ := got[0].(map[string]any)
|
||||
if assistant["role"] != "assistant" {
|
||||
t.Fatalf("expected assistant first, got %#v", assistant)
|
||||
}
|
||||
tc, _ := assistant["tool_calls"].([]any)
|
||||
if len(tc) != 1 {
|
||||
t.Fatalf("expected one tool call, got %#v", assistant["tool_calls"])
|
||||
}
|
||||
toolMsg, _ := got[1].(map[string]any)
|
||||
if toolMsg["role"] != "tool" || toolMsg["tool_call_id"] != "call_g1" {
|
||||
t.Fatalf("expected tool message with call id, got %#v", toolMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiMessagesFromRequestPreservesUnknownPartAsRawJSONText(t *testing.T) {
|
||||
req := map[string]any{
|
||||
"contents": []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"parts": []any{
|
||||
map[string]any{"text": "hello"},
|
||||
map[string]any{"inlineData": map[string]any{"mimeType": "image/png", "data": strings.Repeat("A", 2048)}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := geminiMessagesFromRequest(req)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %#v", got)
|
||||
}
|
||||
msg, _ := got[0].(map[string]any)
|
||||
content, _ := msg["content"].(string)
|
||||
if !strings.Contains(content, "hello") || !strings.Contains(content, "inlineData") {
|
||||
t.Fatalf("expected unknown part preserved as raw json text, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "[omitted_binary_payload]") {
|
||||
t.Fatalf("expected inlineData payload to be redacted, got %q", content)
|
||||
}
|
||||
if strings.Contains(content, strings.Repeat("A", 100)) {
|
||||
t.Fatalf("expected raw base64 payload not to be embedded, got %q", content)
|
||||
}
|
||||
}
|
||||
54
internal/adapter/gemini/convert_passthrough.go
Normal file
54
internal/adapter/gemini/convert_passthrough.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// collectGeminiPassThrough lifts the supported generationConfig knobs into
// their OpenAI-style equivalents; it returns nil when nothing maps.
func collectGeminiPassThrough(req map[string]any) map[string]any {
	cfg, _ := req["generationConfig"].(map[string]any)
	if len(cfg) == 0 {
		return nil
	}
	keyMap := map[string]string{
		"temperature":     "temperature",
		"topP":            "top_p",
		"maxOutputTokens": "max_tokens",
		"stopSequences":   "stop",
	}
	mapped := map[string]any{}
	for src, dst := range keyMap {
		if v, ok := cfg[src]; ok {
			mapped[dst] = v
		}
	}
	if len(mapped) == 0 {
		return nil
	}
	return mapped
}
|
||||
|
||||
// asString returns v when it is a string, otherwise "".
func asString(v any) string {
	if s, ok := v.(string); ok {
		return s
	}
	return ""
}
|
||||
|
||||
// stringifyJSON renders v as a JSON string; nil, blank strings, and
// marshal failures all collapse to "{}". Non-blank strings are returned
// trimmed, as-is.
func stringifyJSON(v any) string {
	switch x := v.(type) {
	case nil:
		return "{}"
	case string:
		trimmed := strings.TrimSpace(x)
		if trimmed == "" {
			return "{}"
		}
		return trimmed
	}
	encoded, err := json.Marshal(v)
	if err != nil || len(encoded) == 0 {
		return "{}"
	}
	return string(encoded)
}
|
||||
46
internal/adapter/gemini/convert_request.go
Normal file
46
internal/adapter/gemini/convert_request.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/adapter/openai"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[string]any, stream bool) (util.StandardRequest, error) {
|
||||
requestedModel := strings.TrimSpace(routeModel)
|
||||
if requestedModel == "" {
|
||||
return util.StandardRequest{}, fmt.Errorf("model is required in request path")
|
||||
}
|
||||
|
||||
resolvedModel, ok := config.ResolveModel(store, requestedModel)
|
||||
if !ok {
|
||||
return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", requestedModel)
|
||||
}
|
||||
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
|
||||
|
||||
messagesRaw := geminiMessagesFromRequest(req)
|
||||
if len(messagesRaw) == 0 {
|
||||
return util.StandardRequest{}, fmt.Errorf("Request must include non-empty contents.")
|
||||
}
|
||||
|
||||
toolsRaw := convertGeminiTools(req["tools"])
|
||||
finalPrompt, toolNames := openai.BuildPromptForAdapter(messagesRaw, toolsRaw, "")
|
||||
passThrough := collectGeminiPassThrough(req)
|
||||
|
||||
return util.StandardRequest{
|
||||
Surface: "google_gemini",
|
||||
RequestedModel: requestedModel,
|
||||
ResolvedModel: resolvedModel,
|
||||
ResponseModel: requestedModel,
|
||||
Messages: messagesRaw,
|
||||
FinalPrompt: finalPrompt,
|
||||
ToolNames: toolNames,
|
||||
Stream: stream,
|
||||
Thinking: thinkingEnabled,
|
||||
Search: searchEnabled,
|
||||
PassThrough: passThrough,
|
||||
}, nil
|
||||
}
|
||||
71
internal/adapter/gemini/convert_tools.go
Normal file
71
internal/adapter/gemini/convert_tools.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package gemini
|
||||
|
||||
import "strings"
|
||||
|
||||
func convertGeminiTools(raw any) []any {
|
||||
tools, _ := raw.([]any)
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]any, 0, len(tools))
|
||||
for _, item := range tools {
|
||||
tool, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if fnDecls, ok := tool["functionDeclarations"].([]any); ok && len(fnDecls) > 0 {
|
||||
for _, declRaw := range fnDecls {
|
||||
decl, ok := declRaw.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name := strings.TrimSpace(asString(decl["name"]))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
function := map[string]any{
|
||||
"name": name,
|
||||
}
|
||||
if desc := strings.TrimSpace(asString(decl["description"])); desc != "" {
|
||||
function["description"] = desc
|
||||
}
|
||||
if params, ok := decl["parameters"].(map[string]any); ok {
|
||||
function["parameters"] = params
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"type": "function",
|
||||
"function": function,
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// OpenAI-style passthrough fallback.
|
||||
if _, ok := tool["function"].(map[string]any); ok {
|
||||
out = append(out, tool)
|
||||
continue
|
||||
}
|
||||
|
||||
// Loose fallback for flattened function schema objects.
|
||||
name := strings.TrimSpace(asString(tool["name"]))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
fn := map[string]any{"name": name}
|
||||
if desc := strings.TrimSpace(asString(tool["description"])); desc != "" {
|
||||
fn["description"] = desc
|
||||
}
|
||||
if params, ok := tool["parameters"].(map[string]any); ok {
|
||||
fn["parameters"] = params
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"type": "function",
|
||||
"function": fn,
|
||||
})
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
29
internal/adapter/gemini/deps.go
Normal file
29
internal/adapter/gemini/deps.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
)
|
||||
|
||||
// AuthResolver selects and releases the DeepSeek account/token used for a
// single request.
type AuthResolver interface {
	Determine(req *http.Request) (*auth.RequestAuth, error)
	Release(a *auth.RequestAuth)
}

// DeepSeekCaller is the minimal upstream API surface the Gemini adapter
// needs: session creation, proof-of-work retrieval, and the completion
// call itself.
type DeepSeekCaller interface {
	CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
	GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
	CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
}

// ConfigReader exposes the model-alias table used for model resolution.
type ConfigReader interface {
	ModelAliases() map[string]string
}

// Compile-time checks that the concrete implementations satisfy the
// adapter-local interfaces.
var _ AuthResolver = (*auth.Resolver)(nil)
var _ DeepSeekCaller = (*deepseek.Client)(nil)
var _ ConfigReader = (*config.Store)(nil)
|
||||
28
internal/adapter/gemini/handler_errors.go
Normal file
28
internal/adapter/gemini/handler_errors.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package gemini
|
||||
|
||||
import "net/http"
|
||||
|
||||
func writeGeminiError(w http.ResponseWriter, status int, message string) {
|
||||
errorStatus := "INVALID_ARGUMENT"
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
errorStatus = "UNAUTHENTICATED"
|
||||
case http.StatusForbidden:
|
||||
errorStatus = "PERMISSION_DENIED"
|
||||
case http.StatusTooManyRequests:
|
||||
errorStatus = "RESOURCE_EXHAUSTED"
|
||||
case http.StatusNotFound:
|
||||
errorStatus = "NOT_FOUND"
|
||||
default:
|
||||
if status >= 500 {
|
||||
errorStatus = "INTERNAL"
|
||||
}
|
||||
}
|
||||
writeJSON(w, status, map[string]any{
|
||||
"error": map[string]any{
|
||||
"code": status,
|
||||
"message": message,
|
||||
"status": errorStatus,
|
||||
},
|
||||
})
|
||||
}
|
||||
135
internal/adapter/gemini/handler_generate.go
Normal file
135
internal/adapter/gemini/handler_generate.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package gemini
|
||||
|
||||
import (
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"strings"

	"github.com/go-chi/chi/v5"

	"ds2api/internal/auth"
	"ds2api/internal/sse"
	"ds2api/internal/util"
)
|
||||
|
||||
func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) {
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
status := http.StatusUnauthorized
|
||||
detail := err.Error()
|
||||
if err == auth.ErrNoAccount {
|
||||
status = http.StatusTooManyRequests
|
||||
}
|
||||
writeGeminiError(w, status, detail)
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeGeminiError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
|
||||
routeModel := strings.TrimSpace(chi.URLParam(r, "model"))
|
||||
stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream)
|
||||
if err != nil {
|
||||
writeGeminiError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
if a.UseConfigToken {
|
||||
writeGeminiError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
|
||||
} else {
|
||||
writeGeminiError(w, http.StatusUnauthorized, "Invalid token.")
|
||||
}
|
||||
return
|
||||
}
|
||||
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
writeGeminiError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
|
||||
return
|
||||
}
|
||||
payload := stdReq.CompletionPayload(sessionID)
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
|
||||
if err != nil {
|
||||
writeGeminiError(w, http.StatusInternalServerError, "Failed to get completion.")
|
||||
return
|
||||
}
|
||||
|
||||
if stream {
|
||||
h.handleStreamGenerateContent(w, r, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
|
||||
return
|
||||
}
|
||||
h.handleNonStreamGenerateContent(w, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *http.Response, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
return
|
||||
}
|
||||
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames))
|
||||
}
|
||||
|
||||
func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames)
|
||||
usage := buildGeminiUsage(finalPrompt, finalThinking, finalText)
|
||||
return map[string]any{
|
||||
"candidates": []map[string]any{
|
||||
{
|
||||
"index": 0,
|
||||
"content": map[string]any{
|
||||
"role": "model",
|
||||
"parts": parts,
|
||||
},
|
||||
"finishReason": "STOP",
|
||||
},
|
||||
},
|
||||
"modelVersion": model,
|
||||
"usageMetadata": usage,
|
||||
}
|
||||
}
|
||||
|
||||
func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any {
|
||||
promptTokens := util.EstimateTokens(finalPrompt)
|
||||
reasoningTokens := util.EstimateTokens(finalThinking)
|
||||
completionTokens := util.EstimateTokens(finalText)
|
||||
return map[string]any{
|
||||
"promptTokenCount": promptTokens,
|
||||
"candidatesTokenCount": reasoningTokens + completionTokens,
|
||||
"totalTokenCount": promptTokens + reasoningTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func buildGeminiPartsFromFinal(finalText, finalThinking string, toolNames []string) []map[string]any {
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
|
||||
detected = util.ParseToolCalls(finalThinking, toolNames)
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
parts := make([]map[string]any, 0, len(detected))
|
||||
for _, tc := range detected {
|
||||
parts = append(parts, map[string]any{
|
||||
"functionCall": map[string]any{
|
||||
"name": tc.Name,
|
||||
"args": tc.Input,
|
||||
},
|
||||
})
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
text := finalText
|
||||
if strings.TrimSpace(text) == "" {
|
||||
text = finalThinking
|
||||
}
|
||||
return []map[string]any{{"text": text}}
|
||||
}
|
||||
32
internal/adapter/gemini/handler_routes.go
Normal file
32
internal/adapter/gemini/handler_routes.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// writeJSON is the package-local JSON response writer (defaults to
// util.WriteJSON).
var writeJSON = util.WriteJSON

// Handler wires the Gemini-compatible surface to its dependencies.
type Handler struct {
	Store ConfigReader   // model-alias configuration source
	Auth  AuthResolver   // per-request account selection/release
	DS    DeepSeekCaller // upstream DeepSeek API client
}
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
r.Post("/v1beta/models/{model}:generateContent", h.GenerateContent)
|
||||
r.Post("/v1beta/models/{model}:streamGenerateContent", h.StreamGenerateContent)
|
||||
r.Post("/v1/models/{model}:generateContent", h.GenerateContent)
|
||||
r.Post("/v1/models/{model}:streamGenerateContent", h.StreamGenerateContent)
|
||||
}
|
||||
|
||||
// GenerateContent handles the blocking (non-stream) Gemini endpoint.
func (h *Handler) GenerateContent(w http.ResponseWriter, r *http.Request) {
	h.handleGenerateContent(w, r, false)
}
|
||||
|
||||
// StreamGenerateContent handles the streaming Gemini endpoint.
func (h *Handler) StreamGenerateContent(w http.ResponseWriter, r *http.Request) {
	h.handleGenerateContent(w, r, true)
}
|
||||
181
internal/adapter/gemini/handler_stream_runtime.go
Normal file
181
internal/adapter/gemini/handler_stream_runtime.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/deepseek"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
)
|
||||
|
||||
func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
rc := http.NewResponseController(w)
|
||||
_, canFlush := w.(http.Flusher)
|
||||
runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
|
||||
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
}
|
||||
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
|
||||
Context: r.Context(),
|
||||
Body: resp.Body,
|
||||
ThinkingEnabled: thinkingEnabled,
|
||||
InitialType: initialType,
|
||||
KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
|
||||
IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
|
||||
MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
|
||||
}, streamengine.ConsumeHooks{
|
||||
OnParsed: runtime.onParsed,
|
||||
OnFinalize: func(_ streamengine.StopReason, _ error) {
|
||||
runtime.finalize()
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// geminiStreamRuntime carries the per-request state needed to translate
// the upstream SSE stream into Gemini streaming chunks.
type geminiStreamRuntime struct {
	w        http.ResponseWriter
	rc       *http.ResponseController
	canFlush bool // whether w also implements http.Flusher

	model       string
	finalPrompt string

	thinkingEnabled bool
	searchEnabled   bool
	bufferContent   bool // accumulate text instead of streaming per chunk (set when tools are declared)
	toolNames       []string

	thinking strings.Builder // accumulated thinking-typed text
	text     strings.Builder // accumulated visible text
}
|
||||
|
||||
// newGeminiStreamRuntime assembles the per-request streaming state.
// bufferContent is derived from toolNames: when tools are declared, plain text
// is withheld during streaming and emitted once at finalize so tool calls can
// be extracted from the full output.
func newGeminiStreamRuntime(
	w http.ResponseWriter,
	rc *http.ResponseController,
	canFlush bool,
	model string,
	finalPrompt string,
	thinkingEnabled bool,
	searchEnabled bool,
	toolNames []string,
) *geminiStreamRuntime {
	return &geminiStreamRuntime{
		w:               w,
		rc:              rc,
		canFlush:        canFlush,
		model:           model,
		finalPrompt:     finalPrompt,
		thinkingEnabled: thinkingEnabled,
		searchEnabled:   searchEnabled,
		bufferContent:   len(toolNames) > 0,
		toolNames:       toolNames,
	}
}
|
||||
|
||||
func (s *geminiStreamRuntime) sendChunk(payload map[string]any) {
|
||||
b, _ := json.Marshal(payload)
|
||||
_, _ = s.w.Write([]byte("data: "))
|
||||
_, _ = s.w.Write(b)
|
||||
_, _ = s.w.Write([]byte("\n\n"))
|
||||
if s.canFlush {
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// onParsed handles one parsed upstream SSE event. Thinking text is only
// accumulated; answer text is accumulated and — unless tool buffering is
// active — forwarded immediately as a Gemini candidate chunk.
func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
	if !parsed.Parsed {
		return streamengine.ParsedDecision{}
	}
	// Content filter hits, upstream errors, and explicit stop all end the stream;
	// finalize() still runs via OnFinalize.
	if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
		return streamengine.ParsedDecision{Stop: true}
	}

	contentSeen := false
	for _, p := range parsed.Parts {
		if p.Text == "" {
			continue
		}
		// Drop search-citation markers from answer text (not from thinking).
		if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
			continue
		}
		contentSeen = true
		if p.Type == "thinking" {
			// Thinking text is never streamed to Gemini clients; it only feeds
			// the final usage/parts computation when enabled.
			if s.thinkingEnabled {
				s.thinking.WriteString(p.Text)
			}
			continue
		}
		s.text.WriteString(p.Text)
		if s.bufferContent {
			// Tool mode: withhold text until finalize so tool calls can be parsed.
			continue
		}
		s.sendChunk(map[string]any{
			"candidates": []map[string]any{
				{
					"index": 0,
					"content": map[string]any{
						"role":  "model",
						"parts": []map[string]any{{"text": p.Text}},
					},
				},
			},
			"modelVersion": s.model,
		})
	}
	return streamengine.ParsedDecision{ContentSeen: contentSeen}
}
|
||||
|
||||
// finalize emits the buffered content frame (when tool buffering is active)
// followed by the terminating frame carrying finishReason STOP and usage
// metadata. It runs exactly once, for every stop reason.
func (s *geminiStreamRuntime) finalize() {
	finalThinking := s.thinking.String()
	finalText := s.text.String()

	if s.bufferContent {
		// Tool mode: content was withheld during streaming; emit it now as one
		// frame so detected tool calls can be rendered as functionCall parts.
		parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames)
		s.sendChunk(map[string]any{
			"candidates": []map[string]any{
				{
					"index": 0,
					"content": map[string]any{
						"role":  "model",
						"parts": parts,
					},
				},
			},
			"modelVersion": s.model,
		})
	}

	// Terminating frame: empty text part plus finishReason and usage metadata.
	s.sendChunk(map[string]any{
		"candidates": []map[string]any{
			{
				"index": 0,
				"content": map[string]any{
					"role": "model",
					"parts": []map[string]any{
						{"text": ""},
					},
				},
				"finishReason": "STOP",
			},
		},
		"modelVersion":  s.model,
		"usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText),
	})
}
|
||||
252
internal/adapter/gemini/handler_test.go
Normal file
252
internal/adapter/gemini/handler_test.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
)
|
||||
|
||||
// testGeminiConfig is a minimal ConfigReader stub with no model aliases.
type testGeminiConfig struct{}

func (testGeminiConfig) ModelAliases() map[string]string { return nil }
||||
|
||||
// testGeminiAuth is an AuthResolver stub: it returns the configured error or
// auth, falling back to a fixed direct-token identity.
type testGeminiAuth struct {
	a   *auth.RequestAuth // optional canned auth to return
	err error             // optional error to force
}

func (m testGeminiAuth) Determine(_ *http.Request) (*auth.RequestAuth, error) {
	if m.err != nil {
		return nil, m.err
	}
	if m.a != nil {
		return m.a, nil
	}
	// Default: a caller authenticated with a direct DeepSeek token.
	return &auth.RequestAuth{
		UseConfigToken: false,
		DeepSeekToken:  "direct-token",
		CallerID:       "caller:test",
		TriedAccounts:  map[string]bool{},
	}, nil
}

func (testGeminiAuth) Release(_ *auth.RequestAuth) {}
|
||||
|
||||
// testGeminiDS is a DeepSeek client stub that returns canned session/pow
// values and a pre-built completion response (or error).
type testGeminiDS struct {
	resp *http.Response // response returned by CallCompletion
	err  error          // optional error returned by CallCompletion
}

func (m testGeminiDS) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
	return "session-id", nil
}

func (m testGeminiDS) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
	return "pow", nil
}

func (m testGeminiDS) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
	if m.err != nil {
		return nil, m.err
	}
	return m.resp, nil
}
|
||||
|
||||
// makeGeminiUpstreamResponse builds a fake upstream SSE response whose body is
// the given lines joined with "\n" and guaranteed to end with a newline.
func makeGeminiUpstreamResponse(lines ...string) *http.Response {
	var b strings.Builder
	for i, line := range lines {
		if i > 0 {
			b.WriteByte('\n')
		}
		b.WriteString(line)
	}
	if !strings.HasSuffix(b.String(), "\n") {
		b.WriteByte('\n')
	}
	return &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{},
		Body:       io.NopCloser(strings.NewReader(b.String())),
	}
}
|
||||
|
||||
// TestGeminiRoutesRegistered verifies generateContent and streamGenerateContent
// are mounted under both /v1 and /v1beta (any status other than 404 proves the
// route exists; auth is forced to fail).
func TestGeminiRoutesRegistered(t *testing.T) {
	h := &Handler{
		Store: testGeminiConfig{},
		Auth:  testGeminiAuth{err: auth.ErrUnauthorized},
	}
	r := chi.NewRouter()
	RegisterRoutes(r, h)

	paths := []string{
		"/v1beta/models/gemini-2.5-pro:generateContent",
		"/v1beta/models/gemini-2.5-pro:streamGenerateContent",
		"/v1/models/gemini-2.5-pro:generateContent",
		"/v1/models/gemini-2.5-pro:streamGenerateContent",
	}
	for _, path := range paths {
		req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`))
		rec := httptest.NewRecorder()
		r.ServeHTTP(rec, req)
		if rec.Code == http.StatusNotFound {
			t.Fatalf("expected route %s to be registered, got 404", path)
		}
	}
}
|
||||
|
||||
// TestGenerateContentReturnsFunctionCallParts checks that a pure tool_calls
// JSON payload from upstream is surfaced as a Gemini functionCall part.
func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
	upstream := makeGeminiUpstreamResponse(
		`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
		`data: [DONE]`,
	)
	h := &Handler{
		Store: testGeminiConfig{},
		Auth:  testGeminiAuth{},
		DS:    testGeminiDS{resp: upstream},
	}
	r := chi.NewRouter()
	RegisterRoutes(r, h)

	// Declare the tool so the handler enables tool-call extraction.
	body := `{
		"contents":[{"role":"user","parts":[{"text":"call tool"}]}],
		"tools":[{"functionDeclarations":[{"name":"eval_javascript","description":"eval","parameters":{"type":"object","properties":{"code":{"type":"string"}}}}]}]
	}`
	req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent", strings.NewReader(body))
	req.Header.Set("Authorization", "Bearer direct-token")
	rec := httptest.NewRecorder()
	r.ServeHTTP(rec, req)
	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
	}

	var out map[string]any
	if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
		t.Fatalf("decode response failed: %v", err)
	}
	candidates, _ := out["candidates"].([]any)
	if len(candidates) == 0 {
		t.Fatalf("expected non-empty candidates: %#v", out)
	}
	c0, _ := candidates[0].(map[string]any)
	content, _ := c0["content"].(map[string]any)
	parts, _ := content["parts"].([]any)
	if len(parts) == 0 {
		t.Fatalf("expected non-empty parts: %#v", content)
	}
	part0, _ := parts[0].(map[string]any)
	functionCall, _ := part0["functionCall"].(map[string]any)
	if functionCall["name"] != "eval_javascript" {
		t.Fatalf("expected functionCall name eval_javascript, got %#v", functionCall)
	}
}
|
||||
|
||||
// TestGenerateContentMixedToolSnippetAlsoTriggersFunctionCall checks that a
// tool_calls JSON payload preceded by prose text is still detected and the
// first emitted part is the functionCall.
func TestGenerateContentMixedToolSnippetAlsoTriggersFunctionCall(t *testing.T) {
	upstream := makeGeminiUpstreamResponse(
		`data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
		`data: [DONE]`,
	)
	h := &Handler{Store: testGeminiConfig{}, Auth: testGeminiAuth{}, DS: testGeminiDS{resp: upstream}}
	r := chi.NewRouter()
	RegisterRoutes(r, h)

	body := `{
		"contents":[{"role":"user","parts":[{"text":"call tool"}]}],
		"tools":[{"functionDeclarations":[{"name":"eval_javascript","description":"eval","parameters":{"type":"object","properties":{"code":{"type":"string"}}}}]}]
	}`
	req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent", strings.NewReader(body))
	req.Header.Set("Authorization", "Bearer direct-token")
	rec := httptest.NewRecorder()
	r.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
	}
	var out map[string]any
	if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
		t.Fatalf("decode response failed: %v", err)
	}
	candidates, _ := out["candidates"].([]any)
	c0, _ := candidates[0].(map[string]any)
	content, _ := c0["content"].(map[string]any)
	parts, _ := content["parts"].([]any)
	part0, _ := parts[0].(map[string]any)
	functionCall, _ := part0["functionCall"].(map[string]any)
	if functionCall["name"] != "eval_javascript" {
		t.Fatalf("expected functionCall name eval_javascript for mixed snippet, got %#v", functionCall)
	}
}
|
||||
|
||||
// TestStreamGenerateContentEmitsSSE streams a two-chunk upstream response and
// verifies SSE framing, the STOP finish frame, and that the finish frame still
// carries non-null content with at least one part.
func TestStreamGenerateContentEmitsSSE(t *testing.T) {
	upstream := makeGeminiUpstreamResponse(
		`data: {"p":"response/content","v":"hello "}`,
		`data: {"p":"response/content","v":"world"}`,
		`data: [DONE]`,
	)
	h := &Handler{
		Store: testGeminiConfig{},
		Auth:  testGeminiAuth{},
		DS:    testGeminiDS{resp: upstream},
	}
	r := chi.NewRouter()
	RegisterRoutes(r, h)

	body := `{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`
	req := httptest.NewRequest(http.MethodPost, "/v1/models/gemini-2.5-pro:streamGenerateContent?alt=sse", strings.NewReader(body))
	req.Header.Set("Authorization", "Bearer direct-token")
	rec := httptest.NewRecorder()
	r.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
	}
	if !strings.Contains(rec.Body.String(), "data: ") {
		t.Fatalf("expected SSE data frames, got body=%s", rec.Body.String())
	}
	if !strings.Contains(rec.Body.String(), `"finishReason":"STOP"`) {
		t.Fatalf("expected stream finish frame, got body=%s", rec.Body.String())
	}

	// The last frame is the terminating frame; it must keep the Gemini shape.
	frames := extractGeminiSSEFrames(t, rec.Body.String())
	if len(frames) == 0 {
		t.Fatalf("expected non-empty sse frames, body=%s", rec.Body.String())
	}
	last := frames[len(frames)-1]
	candidates, _ := last["candidates"].([]any)
	if len(candidates) == 0 {
		t.Fatalf("expected finish frame candidates, got %#v", last)
	}
	c0, _ := candidates[0].(map[string]any)
	content, _ := c0["content"].(map[string]any)
	if content == nil {
		t.Fatalf("expected non-null content in finish frame, got %#v", c0)
	}
	parts, _ := content["parts"].([]any)
	if len(parts) == 0 {
		t.Fatalf("expected non-empty parts in finish frame content, got %#v", content)
	}
}
|
||||
|
||||
// extractGeminiSSEFrames parses an SSE body and returns every JSON-decodable
// "data:" frame in order. Non-JSON frames (e.g. keep-alives) are skipped.
func extractGeminiSSEFrames(t *testing.T, body string) []map[string]any {
	t.Helper()
	scanner := bufio.NewScanner(strings.NewReader(body))
	out := make([]map[string]any, 0, 4)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		raw := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
		if raw == "" {
			continue
		}
		var frame map[string]any
		if err := json.Unmarshal([]byte(raw), &frame); err != nil {
			// Ignore frames that are not JSON objects (e.g. "[DONE]").
			continue
		}
		out = append(out, frame)
	}
	return out
}
|
||||
278
internal/adapter/openai/chat_stream_runtime.go
Normal file
278
internal/adapter/openai/chat_stream_runtime.go
Normal file
@@ -0,0 +1,278 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// chatStreamRuntime carries per-request state while translating a DeepSeek SSE
// stream into OpenAI chat.completions stream chunks, including tool-call
// detection via the tool sieve.
type chatStreamRuntime struct {
	w        http.ResponseWriter
	rc       *http.ResponseController
	canFlush bool // true when w implements http.Flusher

	completionID string
	created      int64
	model        string
	finalPrompt  string // original prompt text, used for usage estimation
	toolNames    []string

	thinkingEnabled bool
	searchEnabled   bool

	firstChunkSent       bool // first delta must carry role:"assistant"
	bufferToolContent    bool // route text through the tool sieve instead of emitting directly
	emitEarlyToolDeltas  bool // emit incremental tool_call deltas before calls are complete
	toolCallsEmitted     bool // at least one tool_call delta has been sent
	toolCallsDoneEmitted bool // complete tool_calls have been sent; suppress re-detection at finalize

	toolSieve         toolStreamSieveState
	streamToolCallIDs map[int]string // stable IDs per tool-call index across deltas
	streamToolNames   map[int]string // resolved names per tool-call index
	thinking          strings.Builder
	text              strings.Builder
}
|
||||
|
||||
// newChatStreamRuntime assembles the per-request streaming state for an OpenAI
// chat.completions stream; ID maps are pre-created so deltas can record stable
// tool-call IDs without nil-map checks.
func newChatStreamRuntime(
	w http.ResponseWriter,
	rc *http.ResponseController,
	canFlush bool,
	completionID string,
	created int64,
	model string,
	finalPrompt string,
	thinkingEnabled bool,
	searchEnabled bool,
	toolNames []string,
	bufferToolContent bool,
	emitEarlyToolDeltas bool,
) *chatStreamRuntime {
	return &chatStreamRuntime{
		w:                   w,
		rc:                  rc,
		canFlush:            canFlush,
		completionID:        completionID,
		created:             created,
		model:               model,
		finalPrompt:         finalPrompt,
		toolNames:           toolNames,
		thinkingEnabled:     thinkingEnabled,
		searchEnabled:       searchEnabled,
		bufferToolContent:   bufferToolContent,
		emitEarlyToolDeltas: emitEarlyToolDeltas,
		streamToolCallIDs:   map[int]string{},
		streamToolNames:     map[int]string{},
	}
}
|
||||
|
||||
func (s *chatStreamRuntime) sendKeepAlive() {
|
||||
if !s.canFlush {
|
||||
return
|
||||
}
|
||||
_, _ = s.w.Write([]byte(": keep-alive\n\n"))
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
|
||||
func (s *chatStreamRuntime) sendChunk(v any) {
|
||||
b, _ := json.Marshal(v)
|
||||
_, _ = s.w.Write([]byte("data: "))
|
||||
_, _ = s.w.Write(b)
|
||||
_, _ = s.w.Write([]byte("\n\n"))
|
||||
if s.canFlush {
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *chatStreamRuntime) sendDone() {
|
||||
_, _ = s.w.Write([]byte("data: [DONE]\n\n"))
|
||||
if s.canFlush {
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// finalize flushes any remaining tool-sieve content, emits tool_calls that
// were only detectable from the complete text, and writes the finish chunk
// (with usage) followed by [DONE]. finishReason is upgraded to "tool_calls"
// whenever any tool call was emitted during the stream or here.
func (s *chatStreamRuntime) finalize(finishReason string) {
	finalThinking := s.thinking.String()
	finalText := sanitizeLeakedToolHistory(s.text.String())
	detected := util.ParseStandaloneToolCallsDetailed(finalText, s.toolNames)
	if len(detected.Calls) > 0 && !s.toolCallsDoneEmitted {
		// Whole-text detection found complete calls that were not already
		// emitted incrementally — send them now in a single delta.
		finishReason = "tool_calls"
		delta := map[string]any{
			"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected.Calls, s.streamToolCallIDs),
		}
		if !s.firstChunkSent {
			delta["role"] = "assistant"
			s.firstChunkSent = true
		}
		s.sendChunk(openaifmt.BuildChatStreamChunk(
			s.completionID,
			s.created,
			s.model,
			[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, delta)},
			nil,
		))
		s.toolCallsEmitted = true
		s.toolCallsDoneEmitted = true
	} else if s.bufferToolContent {
		// Drain whatever the sieve still holds: completed calls first, then
		// residual plain content (sanitized against leaked tool history).
		for _, evt := range flushToolSieve(&s.toolSieve, s.toolNames) {
			if len(evt.ToolCalls) > 0 {
				finishReason = "tool_calls"
				s.toolCallsEmitted = true
				s.toolCallsDoneEmitted = true
				tcDelta := map[string]any{
					"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
				}
				if !s.firstChunkSent {
					tcDelta["role"] = "assistant"
					s.firstChunkSent = true
				}
				s.sendChunk(openaifmt.BuildChatStreamChunk(
					s.completionID,
					s.created,
					s.model,
					[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, tcDelta)},
					nil,
				))
			}
			if evt.Content == "" {
				continue
			}
			cleaned := sanitizeLeakedToolHistory(evt.Content)
			if cleaned == "" {
				continue
			}
			delta := map[string]any{
				"content": cleaned,
			}
			if !s.firstChunkSent {
				delta["role"] = "assistant"
				s.firstChunkSent = true
			}
			s.sendChunk(openaifmt.BuildChatStreamChunk(
				s.completionID,
				s.created,
				s.model,
				[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, delta)},
				nil,
			))
		}
	}

	// Ensure finishReason reflects tool calls emitted at any earlier point.
	if len(detected.Calls) > 0 || s.toolCallsEmitted {
		finishReason = "tool_calls"
	}
	s.sendChunk(openaifmt.BuildChatStreamChunk(
		s.completionID,
		s.created,
		s.model,
		[]map[string]any{openaifmt.BuildChatStreamFinishChoice(0, finishReason)},
		openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText),
	))
	s.sendDone()
}
|
||||
|
||||
// onParsed handles one parsed upstream SSE event, converting its parts into
// OpenAI stream delta choices. Thinking text becomes reasoning_content; answer
// text is emitted directly, or routed through the tool sieve when tool
// buffering is active. All choices from one event are sent in a single chunk.
func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
	if !parsed.Parsed {
		return streamengine.ParsedDecision{}
	}
	if parsed.ContentFilter || parsed.ErrorMessage != "" {
		return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("content_filter")}
	}
	if parsed.Stop {
		return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReasonHandlerRequested}
	}

	newChoices := make([]map[string]any, 0, len(parsed.Parts))
	contentSeen := false
	for _, p := range parsed.Parts {
		// Drop search citation markers before the empty-text check.
		if s.searchEnabled && sse.IsCitation(p.Text) {
			continue
		}
		if p.Text == "" {
			continue
		}
		contentSeen = true
		delta := map[string]any{}
		if !s.firstChunkSent {
			// OpenAI streams carry the role only on the first delta.
			delta["role"] = "assistant"
			s.firstChunkSent = true
		}
		if p.Type == "thinking" {
			if s.thinkingEnabled {
				s.thinking.WriteString(p.Text)
				delta["reasoning_content"] = p.Text
			}
		} else {
			s.text.WriteString(p.Text)
			if !s.bufferToolContent {
				delta["content"] = p.Text
			} else {
				// Tool mode: feed the sieve and translate its events.
				events := processToolSieveChunk(&s.toolSieve, p.Text, s.toolNames)
				for _, evt := range events {
					if len(evt.ToolCallDeltas) > 0 {
						// Incremental (partial) tool-call deltas, emitted only
						// when early emission is enabled and names are allowed.
						if !s.emitEarlyToolDeltas {
							continue
						}
						filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.toolNames, s.streamToolNames)
						if len(filtered) == 0 {
							continue
						}
						formatted := formatIncrementalStreamToolCallDeltas(filtered, s.streamToolCallIDs)
						if len(formatted) == 0 {
							continue
						}
						tcDelta := map[string]any{
							"tool_calls": formatted,
						}
						s.toolCallsEmitted = true
						if !s.firstChunkSent {
							tcDelta["role"] = "assistant"
							s.firstChunkSent = true
						}
						newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, tcDelta))
						continue
					}
					if len(evt.ToolCalls) > 0 {
						// Complete tool calls: emit once with stable IDs and
						// suppress re-detection at finalize.
						s.toolCallsEmitted = true
						s.toolCallsDoneEmitted = true
						tcDelta := map[string]any{
							"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
						}
						if !s.firstChunkSent {
							tcDelta["role"] = "assistant"
							s.firstChunkSent = true
						}
						newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, tcDelta))
						continue
					}
					if evt.Content != "" {
						// Plain content the sieve released; sanitize before emit.
						cleaned := sanitizeLeakedToolHistory(evt.Content)
						if cleaned == "" {
							continue
						}
						contentDelta := map[string]any{
							"content": cleaned,
						}
						if !s.firstChunkSent {
							contentDelta["role"] = "assistant"
							s.firstChunkSent = true
						}
						newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, contentDelta))
					}
				}
			}
		}
		if len(delta) > 0 {
			newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, delta))
		}
	}

	if len(newChoices) > 0 {
		s.sendChunk(openaifmt.BuildChatStreamChunk(s.completionID, s.created, s.model, newChoices, nil))
	}
	return streamengine.ParsedDecision{ContentSeen: contentSeen}
}
|
||||
37
internal/adapter/openai/deps.go
Normal file
37
internal/adapter/openai/deps.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
)
|
||||
|
||||
// AuthResolver resolves request credentials to a RequestAuth and releases any
// pooled account when the request finishes.
type AuthResolver interface {
	Determine(req *http.Request) (*auth.RequestAuth, error)
	DetermineCaller(req *http.Request) (*auth.RequestAuth, error)
	Release(a *auth.RequestAuth)
}
|
||||
|
||||
// DeepSeekCaller is the subset of the DeepSeek client the OpenAI adapter uses:
// session creation, proof-of-work retrieval, completion calls, and session
// cleanup for a token.
type DeepSeekCaller interface {
	CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
	GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
	CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
	DeleteAllSessionsForToken(ctx context.Context, token string) error
}
|
||||
|
||||
// ConfigReader exposes the configuration values the OpenAI adapter reads.
type ConfigReader interface {
	ModelAliases() map[string]string
	CompatWideInputStrictOutput() bool
	ToolcallMode() string
	ToolcallEarlyEmitConfidence() string
	ResponsesStoreTTLSeconds() int
	EmbeddingsProvider() string
	AutoDeleteSessions() bool
}
|
||||
|
||||
// Compile-time checks that the concrete implementations satisfy the
// adapter-facing interfaces above.
var _ AuthResolver = (*auth.Resolver)(nil)
var _ DeepSeekCaller = (*deepseek.Client)(nil)
var _ ConfigReader = (*config.Store)(nil)
|
||||
71
internal/adapter/openai/deps_injection_test.go
Normal file
71
internal/adapter/openai/deps_injection_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package openai
|
||||
|
||||
import "testing"
|
||||
|
||||
// mockOpenAIConfig is a value-configurable ConfigReader stub for tests.
type mockOpenAIConfig struct {
	aliases      map[string]string
	wideInput    bool
	toolMode     string
	earlyEmit    string
	responsesTTL int
	embedProv    string
}

func (m mockOpenAIConfig) ModelAliases() map[string]string { return m.aliases }
func (m mockOpenAIConfig) CompatWideInputStrictOutput() bool {
	return m.wideInput
}
func (m mockOpenAIConfig) ToolcallMode() string                { return m.toolMode }
func (m mockOpenAIConfig) ToolcallEarlyEmitConfidence() string { return m.earlyEmit }
func (m mockOpenAIConfig) ResponsesStoreTTLSeconds() int       { return m.responsesTTL }
func (m mockOpenAIConfig) EmbeddingsProvider() string          { return m.embedProv }
func (m mockOpenAIConfig) AutoDeleteSessions() bool            { return false }
|
||||
|
||||
// TestNormalizeOpenAIChatRequestWithConfigInterface verifies alias resolution
// through the ConfigReader interface and that the "-search" suffix enables
// search without thinking.
func TestNormalizeOpenAIChatRequestWithConfigInterface(t *testing.T) {
	cfg := mockOpenAIConfig{
		aliases: map[string]string{
			"my-model": "deepseek-chat-search",
		},
		wideInput: true,
	}
	req := map[string]any{
		"model":    "my-model",
		"messages": []any{map[string]any{"role": "user", "content": "hello"}},
	}
	out, err := normalizeOpenAIChatRequest(cfg, req, "")
	if err != nil {
		t.Fatalf("normalizeOpenAIChatRequest error: %v", err)
	}
	if out.ResolvedModel != "deepseek-chat-search" {
		t.Fatalf("resolved model mismatch: got=%q", out.ResolvedModel)
	}
	if !out.Search || out.Thinking {
		t.Fatalf("unexpected model flags: thinking=%v search=%v", out.Thinking, out.Search)
	}
}
|
||||
|
||||
// TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface verifies that
// an "input"-only Responses request is rejected unless the wide-input compat
// policy is enabled, and that the normalized surface is tagged correctly.
func TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface(t *testing.T) {
	req := map[string]any{
		"model": "deepseek-chat",
		"input": "hi",
	}

	// Wide input disabled: "input" alone must be rejected.
	_, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{
		aliases:   map[string]string{},
		wideInput: false,
	}, req, "")
	if err == nil {
		t.Fatal("expected error when wide input is disabled and only input is provided")
	}

	// Wide input enabled: same request must be accepted.
	out, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{
		aliases:   map[string]string{},
		wideInput: true,
	}, req, "")
	if err != nil {
		t.Fatalf("unexpected error when wide input is enabled: %v", err)
	}
	if out.Surface != "openai_responses" {
		t.Fatalf("unexpected surface: %q", out.Surface)
	}
}
|
||||
138
internal/adapter/openai/embeddings_handler.go
Normal file
138
internal/adapter/openai/embeddings_handler.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// Embeddings serves POST /v1/embeddings. It authenticates the caller, validates
// model and input, and — for the local deterministic provider — returns
// hash-derived embeddings with token-estimate usage. Unconfigured or unknown
// providers yield 501.
func (h *Handler) Embeddings(w http.ResponseWriter, r *http.Request) {
	a, err := h.Auth.Determine(r)
	if err != nil {
		status := http.StatusUnauthorized
		detail := err.Error()
		// No available pooled account maps to 429 rather than 401.
		if err == auth.ErrNoAccount {
			status = http.StatusTooManyRequests
		}
		writeOpenAIError(w, status, detail)
		return
	}
	defer h.Auth.Release(a)

	var req map[string]any
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeOpenAIError(w, http.StatusBadRequest, "invalid json")
		return
	}
	model, _ := req["model"].(string)
	model = strings.TrimSpace(model)
	if model == "" {
		writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model'.")
		return
	}
	if _, ok := config.ResolveModel(h.Store, model); !ok {
		writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model))
		return
	}

	inputs := extractEmbeddingInputs(req["input"])
	if len(inputs) == 0 {
		writeOpenAIError(w, http.StatusBadRequest, "Request must include non-empty 'input'.")
		return
	}

	provider := ""
	if h.Store != nil {
		provider = strings.ToLower(strings.TrimSpace(h.Store.EmbeddingsProvider()))
	}
	if provider == "" {
		writeOpenAIError(w, http.StatusNotImplemented, "Embeddings provider is not configured. Set embeddings.provider in config.")
		return
	}
	switch provider {
	case "mock", "deterministic", "builtin":
		// supported local deterministic provider
	default:
		writeOpenAIError(w, http.StatusNotImplemented, fmt.Sprintf("Embeddings provider '%s' is not supported.", provider))
		return
	}

	// Build one embedding per input; usage is a token estimate, not a real count.
	data := make([]map[string]any, 0, len(inputs))
	totalTokens := 0
	for i, input := range inputs {
		totalTokens += util.EstimateTokens(input)
		data = append(data, map[string]any{
			"object":    "embedding",
			"index":     i,
			"embedding": deterministicEmbedding(input),
		})
	}
	writeJSON(w, http.StatusOK, map[string]any{
		"object": "list",
		"data":   data,
		"model":  model,
		"usage": map[string]any{
			"prompt_tokens": totalTokens,
			"total_tokens":  totalTokens,
		},
	})
}
|
||||
|
||||
// extractEmbeddingInputs normalizes the request's "input" field into a slice
// of strings. A bare string yields at most one trimmed entry; an array yields
// one entry per element (strings trimmed, nested token arrays kept verbatim in
// their %v form, other scalars stringified then trimmed). Anything else is nil.
func extractEmbeddingInputs(raw any) []string {
	if s, ok := raw.(string); ok {
		if trimmed := strings.TrimSpace(s); trimmed != "" {
			return []string{trimmed}
		}
		return nil
	}
	items, ok := raw.([]any)
	if !ok {
		return nil
	}
	out := make([]string, 0, len(items))
	for _, item := range items {
		if tokens, isTokens := item.([]any); isTokens {
			// Token array input support: convert to a stable string form,
			// kept even when empty.
			out = append(out, fmt.Sprintf("%v", tokens))
			continue
		}
		text, isString := item.(string)
		if !isString {
			text = fmt.Sprintf("%v", item)
		}
		if trimmed := strings.TrimSpace(text); trimmed != "" {
			out = append(out, trimmed)
		}
	}
	return out
}
|
||||
|
||||
// deterministicEmbedding maps input to a fixed-size pseudo-embedding derived
// from a chained SHA-256 stream, so equal inputs always yield equal vectors
// and the response shape stays stable without external dependencies. Each
// value lies in [-1, 1).
//
// Bug fix: the previous implementation refilled the byte stream by hashing the
// already-exhausted (empty) buffer, so every dimension past the first eight was
// the same constant sequence for all inputs (and repeated with period 8). The
// stream is now extended by hashing the previous digest, keeping every
// dimension input-dependent.
func deterministicEmbedding(input string) []float64 {
	const dims = 64
	out := make([]float64, dims)
	digest := sha256.Sum256([]byte(input))
	buf := digest[:]
	for i := 0; i < dims; i++ {
		if len(buf) < 4 {
			// Chain from the previous digest (not the drained buffer) so each
			// refill still depends on the original input.
			digest = sha256.Sum256(digest[:])
			buf = digest[:]
		}
		v := binary.BigEndian.Uint32(buf[:4])
		buf = buf[4:]
		// map [0, 2^32) -> [-1, 1)
		out[i] = float64(v)/2147483647.5 - 1.0
	}
	return out
}
|
||||
96
internal/adapter/openai/embeddings_route_test.go
Normal file
96
internal/adapter/openai/embeddings_route_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/account"
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
// newResolverWithConfigJSON loads a config store from the given JSON (via the
// DS2API_CONFIG_JSON env override) and builds an auth resolver whose login
// callback is a stub — tests here never log in to DeepSeek.
func newResolverWithConfigJSON(t *testing.T, cfgJSON string) (*config.Store, *auth.Resolver) {
	t.Helper()
	t.Setenv("DS2API_CONFIG_JSON", cfgJSON)
	store := config.LoadStore()
	pool := account.NewPool(store)
	resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
		return "unused", nil
	})
	return store, resolver
}
|
||||
|
||||
// TestEmbeddingsRouteContract exercises /v1/embeddings with the deterministic
// provider: missing auth yields 401; an authenticated two-item request yields
// an OpenAI-shaped list with one embedding per input.
func TestEmbeddingsRouteContract(t *testing.T) {
	store, resolver := newResolverWithConfigJSON(t, `{"embeddings":{"provider":"deterministic"}}`)
	h := &Handler{Store: store, Auth: resolver}
	r := chi.NewRouter()
	RegisterRoutes(r, h)

	t.Run("unauthorized", func(t *testing.T) {
		body := bytes.NewBufferString(`{"model":"gpt-4o","input":"hello"}`)
		req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body)
		req.Header.Set("Content-Type", "application/json")
		rec := httptest.NewRecorder()
		r.ServeHTTP(rec, req)
		if rec.Code != http.StatusUnauthorized {
			t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String())
		}
	})

	t.Run("ok", func(t *testing.T) {
		body := bytes.NewBufferString(`{"model":"gpt-4o","input":["a","b"]}`)
		req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body)
		req.Header.Set("Authorization", "Bearer test-token")
		req.Header.Set("Content-Type", "application/json")
		rec := httptest.NewRecorder()
		r.ServeHTTP(rec, req)
		if rec.Code != http.StatusOK {
			t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
		}
		var out map[string]any
		if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
			t.Fatalf("decode response failed: %v", err)
		}
		if out["object"] != "list" {
			t.Fatalf("unexpected object: %#v", out["object"])
		}
		data, _ := out["data"].([]any)
		if len(data) != 2 {
			t.Fatalf("expected 2 embeddings, got %d", len(data))
		}
	})
}
|
||||
|
||||
func TestEmbeddingsRouteProviderMissing(t *testing.T) {
|
||||
store, resolver := newResolverWithConfigJSON(t, `{}`)
|
||||
h := &Handler{Store: store, Auth: resolver}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
body := bytes.NewBufferString(`{"model":"gpt-4o","input":"hello"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body)
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusNotImplemented {
|
||||
t.Fatalf("expected 501, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("decode response failed: %v", err)
|
||||
}
|
||||
errObj, _ := out["error"].(map[string]any)
|
||||
if _, ok := errObj["code"]; !ok {
|
||||
t.Fatalf("expected error.code in response: %#v", out)
|
||||
}
|
||||
if _, ok := errObj["param"]; !ok {
|
||||
t.Fatalf("expected error.param in response: %#v", out)
|
||||
}
|
||||
}
|
||||
34
internal/adapter/openai/error_shape_test.go
Normal file
34
internal/adapter/openai/error_shape_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWriteOpenAIErrorIncludesUnifiedFields(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
writeOpenAIError(rec, http.StatusBadRequest, "invalid input")
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", rec.Code)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("decode body: %v", err)
|
||||
}
|
||||
errObj, _ := body["error"].(map[string]any)
|
||||
if errObj["message"] != "invalid input" {
|
||||
t.Fatalf("unexpected message: %v", errObj["message"])
|
||||
}
|
||||
if errObj["type"] != "invalid_request_error" {
|
||||
t.Fatalf("unexpected type: %v", errObj["type"])
|
||||
}
|
||||
if errObj["code"] != "invalid_request" {
|
||||
t.Fatalf("unexpected code: %v", errObj["code"])
|
||||
}
|
||||
if _, ok := errObj["param"]; !ok {
|
||||
t.Fatal("expected param field")
|
||||
}
|
||||
}
|
||||
@@ -1,487 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
"ds2api/internal/sse"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// writeJSON is a package-internal alias kept to avoid mass-renaming across
|
||||
// every call-site in this file. It delegates to the shared util version.
|
||||
var writeJSON = util.WriteJSON
|
||||
|
||||
type Handler struct {
|
||||
Store *config.Store
|
||||
Auth *auth.Resolver
|
||||
DS *deepseek.Client
|
||||
|
||||
leaseMu sync.Mutex
|
||||
streamLeases map[string]streamLease
|
||||
}
|
||||
|
||||
type streamLease struct {
|
||||
Auth *auth.RequestAuth
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
r.Get("/v1/models", h.ListModels)
|
||||
r.Post("/v1/chat/completions", h.ChatCompletions)
|
||||
}
|
||||
|
||||
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
|
||||
writeJSON(w, http.StatusOK, config.OpenAIModelsResponse())
|
||||
}
|
||||
|
||||
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
if isVercelStreamReleaseRequest(r) {
|
||||
h.handleVercelStreamRelease(w, r)
|
||||
return
|
||||
}
|
||||
if isVercelStreamPrepareRequest(r) {
|
||||
h.handleVercelStreamPrepare(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
status := http.StatusUnauthorized
|
||||
detail := err.Error()
|
||||
if err == auth.ErrNoAccount {
|
||||
status = http.StatusTooManyRequests
|
||||
}
|
||||
writeOpenAIError(w, status, detail)
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
r = r.WithContext(auth.WithAuth(r.Context(), a))
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
messagesRaw, _ := req["messages"].([]any)
|
||||
if model == "" || len(messagesRaw) == 0 {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
|
||||
return
|
||||
}
|
||||
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model)
|
||||
if !ok {
|
||||
writeOpenAIError(w, http.StatusServiceUnavailable, fmt.Sprintf("Model '%s' is not available.", model))
|
||||
return
|
||||
}
|
||||
|
||||
messages := normalizeMessages(messagesRaw)
|
||||
toolNames := []string{}
|
||||
if tools, ok := req["tools"].([]any); ok && len(tools) > 0 {
|
||||
messages, toolNames = injectToolPrompt(messages, tools)
|
||||
}
|
||||
finalPrompt := util.MessagesPrepare(messages)
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
if a.UseConfigToken {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
|
||||
} else {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
|
||||
}
|
||||
return
|
||||
}
|
||||
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
|
||||
return
|
||||
}
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
"parent_message_id": nil,
|
||||
"prompt": finalPrompt,
|
||||
"ref_file_ids": []any{},
|
||||
"thinking_enabled": thinkingEnabled,
|
||||
"search_enabled": searchEnabled,
|
||||
}
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
|
||||
return
|
||||
}
|
||||
if util.ToBool(req["stream"]) {
|
||||
h.handleStream(w, r, resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
|
||||
return
|
||||
}
|
||||
h.handleNonStream(w, r.Context(), resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeOpenAIError(w, resp.StatusCode, string(body))
|
||||
return
|
||||
}
|
||||
_ = ctx
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
|
||||
finalThinking := result.Thinking
|
||||
finalText := result.Text
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
finishReason := "stop"
|
||||
messageObj := map[string]any{"role": "assistant", "content": finalText}
|
||||
if thinkingEnabled && finalThinking != "" {
|
||||
messageObj["reasoning_content"] = finalThinking
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected)
|
||||
messageObj["content"] = nil
|
||||
}
|
||||
promptTokens := util.EstimateTokens(finalPrompt)
|
||||
reasoningTokens := util.EstimateTokens(finalThinking)
|
||||
completionTokens := util.EstimateTokens(finalText)
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}},
|
||||
"usage": map[string]any{
|
||||
"prompt_tokens": promptTokens,
|
||||
"completion_tokens": reasoningTokens + completionTokens,
|
||||
"total_tokens": promptTokens + reasoningTokens + completionTokens,
|
||||
"completion_tokens_details": map[string]any{
|
||||
"reasoning_tokens": reasoningTokens,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeOpenAIError(w, resp.StatusCode, string(body))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
rc := http.NewResponseController(w)
|
||||
canFlush := rc.Flush() == nil
|
||||
if !canFlush {
|
||||
config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered")
|
||||
}
|
||||
|
||||
created := time.Now().Unix()
|
||||
firstChunkSent := false
|
||||
bufferToolContent := len(toolNames) > 0
|
||||
var toolSieve toolStreamSieveState
|
||||
toolCallsEmitted := false
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
}
|
||||
parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType)
|
||||
thinking := strings.Builder{}
|
||||
text := strings.Builder{}
|
||||
lastContent := time.Now()
|
||||
hasContent := false
|
||||
keepaliveTicker := time.NewTicker(time.Duration(deepseek.KeepAliveTimeout) * time.Second)
|
||||
defer keepaliveTicker.Stop()
|
||||
keepaliveCountWithoutContent := 0
|
||||
|
||||
sendChunk := func(v any) {
|
||||
b, _ := json.Marshal(v)
|
||||
_, _ = w.Write([]byte("data: "))
|
||||
_, _ = w.Write(b)
|
||||
_, _ = w.Write([]byte("\n\n"))
|
||||
if canFlush {
|
||||
_ = rc.Flush()
|
||||
}
|
||||
}
|
||||
sendDone := func() {
|
||||
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||
if canFlush {
|
||||
_ = rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
finalize := func(finishReason string) {
|
||||
finalThinking := thinking.String()
|
||||
finalText := text.String()
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
if len(detected) > 0 && !toolCallsEmitted {
|
||||
finishReason = "tool_calls"
|
||||
delta := map[string]any{
|
||||
"tool_calls": util.FormatOpenAIStreamToolCalls(detected),
|
||||
}
|
||||
if !firstChunkSent {
|
||||
delta["role"] = "assistant"
|
||||
firstChunkSent = true
|
||||
}
|
||||
sendChunk(map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]any{{"delta": delta, "index": 0}},
|
||||
})
|
||||
} else if bufferToolContent {
|
||||
for _, evt := range flushToolSieve(&toolSieve, toolNames) {
|
||||
if evt.Content == "" {
|
||||
continue
|
||||
}
|
||||
delta := map[string]any{
|
||||
"content": evt.Content,
|
||||
}
|
||||
if !firstChunkSent {
|
||||
delta["role"] = "assistant"
|
||||
firstChunkSent = true
|
||||
}
|
||||
sendChunk(map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]any{{"delta": delta, "index": 0}},
|
||||
})
|
||||
}
|
||||
}
|
||||
if len(detected) > 0 || toolCallsEmitted {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
promptTokens := util.EstimateTokens(finalPrompt)
|
||||
reasoningTokens := util.EstimateTokens(finalThinking)
|
||||
completionTokens := util.EstimateTokens(finalText)
|
||||
sendChunk(map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]any{{"delta": map[string]any{}, "index": 0, "finish_reason": finishReason}},
|
||||
"usage": map[string]any{
|
||||
"prompt_tokens": promptTokens,
|
||||
"completion_tokens": reasoningTokens + completionTokens,
|
||||
"total_tokens": promptTokens + reasoningTokens + completionTokens,
|
||||
"completion_tokens_details": map[string]any{
|
||||
"reasoning_tokens": reasoningTokens,
|
||||
},
|
||||
},
|
||||
})
|
||||
sendDone()
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case <-keepaliveTicker.C:
|
||||
if !hasContent {
|
||||
keepaliveCountWithoutContent++
|
||||
if keepaliveCountWithoutContent >= deepseek.MaxKeepaliveCount {
|
||||
finalize("stop")
|
||||
return
|
||||
}
|
||||
}
|
||||
if hasContent && time.Since(lastContent) > time.Duration(deepseek.StreamIdleTimeout)*time.Second {
|
||||
finalize("stop")
|
||||
return
|
||||
}
|
||||
if canFlush {
|
||||
_, _ = w.Write([]byte(": keep-alive\n\n"))
|
||||
_ = rc.Flush()
|
||||
}
|
||||
case parsed, ok := <-parsedLines:
|
||||
if !ok {
|
||||
// Ensure scanner completion is observed only after all queued
|
||||
// SSE lines are drained, avoiding early finalize races.
|
||||
_ = <-done
|
||||
finalize("stop")
|
||||
return
|
||||
}
|
||||
if !parsed.Parsed {
|
||||
continue
|
||||
}
|
||||
if parsed.ContentFilter || parsed.ErrorMessage != "" {
|
||||
finalize("content_filter")
|
||||
return
|
||||
}
|
||||
if parsed.Stop {
|
||||
finalize("stop")
|
||||
return
|
||||
}
|
||||
newChoices := make([]map[string]any, 0, len(parsed.Parts))
|
||||
for _, p := range parsed.Parts {
|
||||
if searchEnabled && sse.IsCitation(p.Text) {
|
||||
continue
|
||||
}
|
||||
if p.Text == "" {
|
||||
continue
|
||||
}
|
||||
hasContent = true
|
||||
lastContent = time.Now()
|
||||
keepaliveCountWithoutContent = 0
|
||||
delta := map[string]any{}
|
||||
if !firstChunkSent {
|
||||
delta["role"] = "assistant"
|
||||
firstChunkSent = true
|
||||
}
|
||||
if p.Type == "thinking" {
|
||||
if thinkingEnabled {
|
||||
thinking.WriteString(p.Text)
|
||||
delta["reasoning_content"] = p.Text
|
||||
}
|
||||
} else {
|
||||
text.WriteString(p.Text)
|
||||
if !bufferToolContent {
|
||||
delta["content"] = p.Text
|
||||
} else {
|
||||
events := processToolSieveChunk(&toolSieve, p.Text, toolNames)
|
||||
if len(events) == 0 {
|
||||
// Keep thinking delta only frame.
|
||||
}
|
||||
for _, evt := range events {
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
toolCallsEmitted = true
|
||||
tcDelta := map[string]any{
|
||||
"tool_calls": util.FormatOpenAIStreamToolCalls(evt.ToolCalls),
|
||||
}
|
||||
if !firstChunkSent {
|
||||
tcDelta["role"] = "assistant"
|
||||
firstChunkSent = true
|
||||
}
|
||||
newChoices = append(newChoices, map[string]any{
|
||||
"delta": tcDelta,
|
||||
"index": 0,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if evt.Content != "" {
|
||||
contentDelta := map[string]any{
|
||||
"content": evt.Content,
|
||||
}
|
||||
if !firstChunkSent {
|
||||
contentDelta["role"] = "assistant"
|
||||
firstChunkSent = true
|
||||
}
|
||||
newChoices = append(newChoices, map[string]any{
|
||||
"delta": contentDelta,
|
||||
"index": 0,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(delta) > 0 {
|
||||
newChoices = append(newChoices, map[string]any{"delta": delta, "index": 0})
|
||||
}
|
||||
}
|
||||
if len(newChoices) > 0 {
|
||||
sendChunk(map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": newChoices,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeMessages(raw []any) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
m, ok := item.(map[string]any)
|
||||
if ok {
|
||||
out = append(out, m)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) {
|
||||
toolSchemas := make([]string, 0, len(tools))
|
||||
names := make([]string, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
tool, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
fn, _ := tool["function"].(map[string]any)
|
||||
if len(fn) == 0 {
|
||||
fn = tool
|
||||
}
|
||||
name, _ := fn["name"].(string)
|
||||
desc, _ := fn["description"].(string)
|
||||
schema, _ := fn["parameters"].(map[string]any)
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
}
|
||||
names = append(names, name)
|
||||
if desc == "" {
|
||||
desc = "No description available"
|
||||
}
|
||||
b, _ := json.Marshal(schema)
|
||||
toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b)))
|
||||
}
|
||||
if len(toolSchemas) == 0 {
|
||||
return messages, names
|
||||
}
|
||||
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nIMPORTANT: If calling tools, output ONLY the JSON. The response must start with { and end with }"
|
||||
|
||||
for i := range messages {
|
||||
if messages[i]["role"] == "system" {
|
||||
old, _ := messages[i]["content"].(string)
|
||||
messages[i]["content"] = strings.TrimSpace(old + "\n\n" + toolPrompt)
|
||||
return messages, names
|
||||
}
|
||||
}
|
||||
messages = append([]map[string]any{{"role": "system", "content": toolPrompt}}, messages...)
|
||||
return messages, names
|
||||
}
|
||||
|
||||
func writeOpenAIError(w http.ResponseWriter, status int, message string) {
|
||||
writeJSON(w, status, map[string]any{
|
||||
"error": map[string]any{
|
||||
"message": message,
|
||||
"type": openAIErrorType(status),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func openAIErrorType(status int) string {
|
||||
switch status {
|
||||
case http.StatusBadRequest:
|
||||
return "invalid_request_error"
|
||||
case http.StatusUnauthorized:
|
||||
return "authentication_error"
|
||||
case http.StatusForbidden:
|
||||
return "permission_error"
|
||||
case http.StatusTooManyRequests:
|
||||
return "rate_limit_error"
|
||||
case http.StatusServiceUnavailable:
|
||||
return "service_unavailable_error"
|
||||
default:
|
||||
if status >= 500 {
|
||||
return "api_error"
|
||||
}
|
||||
return "invalid_request_error"
|
||||
}
|
||||
}
|
||||
174
internal/adapter/openai/handler_chat.go
Normal file
174
internal/adapter/openai/handler_chat.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
)
|
||||
|
||||
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
if isVercelStreamReleaseRequest(r) {
|
||||
h.handleVercelStreamRelease(w, r)
|
||||
return
|
||||
}
|
||||
if isVercelStreamPrepareRequest(r) {
|
||||
h.handleVercelStreamPrepare(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
status := http.StatusUnauthorized
|
||||
detail := err.Error()
|
||||
if err == auth.ErrNoAccount {
|
||||
status = http.StatusTooManyRequests
|
||||
}
|
||||
writeOpenAIError(w, status, detail)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
// 自动删除会话(同步)
|
||||
// 必须在 Release 之前同步删除,否则:
|
||||
// 1. 异步删除时账号已被 Release
|
||||
// 2. 新请求可能获取到同一账号并开始使用
|
||||
// 3. 异步删除仍在进行,会截断新请求正在使用的会话
|
||||
if h.Store.AutoDeleteSessions() && a.DeepSeekToken != "" {
|
||||
deleteCtx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
err := h.DS.DeleteAllSessionsForToken(deleteCtx, a.DeepSeekToken)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[auto_delete_sessions] failed", "account", a.AccountID, "error", err)
|
||||
} else {
|
||||
config.Logger.Debug("[auto_delete_sessions] success", "account", a.AccountID)
|
||||
}
|
||||
}
|
||||
h.Auth.Release(a)
|
||||
}()
|
||||
|
||||
r = r.WithContext(auth.WithAuth(r.Context(), a))
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
if a.UseConfigToken {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
|
||||
} else {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
|
||||
}
|
||||
return
|
||||
}
|
||||
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
|
||||
return
|
||||
}
|
||||
payload := stdReq.CompletionPayload(sessionID)
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
|
||||
return
|
||||
}
|
||||
if stdReq.Stream {
|
||||
h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
|
||||
return
|
||||
}
|
||||
h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeOpenAIError(w, resp.StatusCode, string(body))
|
||||
return
|
||||
}
|
||||
_ = ctx
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
|
||||
finalThinking := result.Thinking
|
||||
finalText := sanitizeLeakedToolHistory(result.Text)
|
||||
respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
|
||||
writeJSON(w, http.StatusOK, respBody)
|
||||
}
|
||||
|
||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeOpenAIError(w, resp.StatusCode, string(body))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
rc := http.NewResponseController(w)
|
||||
_, canFlush := w.(http.Flusher)
|
||||
if !canFlush {
|
||||
config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered")
|
||||
}
|
||||
|
||||
created := time.Now().Unix()
|
||||
bufferToolContent := len(toolNames) > 0
|
||||
emitEarlyToolDeltas := h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence()
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
}
|
||||
|
||||
streamRuntime := newChatStreamRuntime(
|
||||
w,
|
||||
rc,
|
||||
canFlush,
|
||||
completionID,
|
||||
created,
|
||||
model,
|
||||
finalPrompt,
|
||||
thinkingEnabled,
|
||||
searchEnabled,
|
||||
toolNames,
|
||||
bufferToolContent,
|
||||
emitEarlyToolDeltas,
|
||||
)
|
||||
|
||||
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
|
||||
Context: r.Context(),
|
||||
Body: resp.Body,
|
||||
ThinkingEnabled: thinkingEnabled,
|
||||
InitialType: initialType,
|
||||
KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
|
||||
IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
|
||||
MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
|
||||
}, streamengine.ConsumeHooks{
|
||||
OnKeepAlive: func() {
|
||||
streamRuntime.sendKeepAlive()
|
||||
},
|
||||
OnParsed: streamRuntime.onParsed,
|
||||
OnFinalize: func(reason streamengine.StopReason, _ error) {
|
||||
if string(reason) == "content_filter" {
|
||||
streamRuntime.finalize("content_filter")
|
||||
return
|
||||
}
|
||||
streamRuntime.finalize("stop")
|
||||
},
|
||||
})
|
||||
}
|
||||
63
internal/adapter/openai/handler_errors.go
Normal file
63
internal/adapter/openai/handler_errors.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package openai
|
||||
|
||||
import "net/http"
|
||||
|
||||
func writeOpenAIError(w http.ResponseWriter, status int, message string) {
|
||||
writeOpenAIErrorWithCode(w, status, message, "")
|
||||
}
|
||||
|
||||
func writeOpenAIErrorWithCode(w http.ResponseWriter, status int, message, code string) {
|
||||
if code == "" {
|
||||
code = openAIErrorCode(status)
|
||||
}
|
||||
writeJSON(w, status, map[string]any{
|
||||
"error": map[string]any{
|
||||
"message": message,
|
||||
"type": openAIErrorType(status),
|
||||
"code": code,
|
||||
"param": nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func openAIErrorType(status int) string {
|
||||
switch status {
|
||||
case http.StatusBadRequest:
|
||||
return "invalid_request_error"
|
||||
case http.StatusUnauthorized:
|
||||
return "authentication_error"
|
||||
case http.StatusForbidden:
|
||||
return "permission_error"
|
||||
case http.StatusTooManyRequests:
|
||||
return "rate_limit_error"
|
||||
case http.StatusServiceUnavailable:
|
||||
return "service_unavailable_error"
|
||||
default:
|
||||
if status >= 500 {
|
||||
return "api_error"
|
||||
}
|
||||
return "invalid_request_error"
|
||||
}
|
||||
}
|
||||
|
||||
func openAIErrorCode(status int) string {
|
||||
switch status {
|
||||
case http.StatusBadRequest:
|
||||
return "invalid_request"
|
||||
case http.StatusUnauthorized:
|
||||
return "authentication_failed"
|
||||
case http.StatusForbidden:
|
||||
return "forbidden"
|
||||
case http.StatusTooManyRequests:
|
||||
return "rate_limit_exceeded"
|
||||
case http.StatusNotFound:
|
||||
return "not_found"
|
||||
case http.StatusServiceUnavailable:
|
||||
return "service_unavailable"
|
||||
default:
|
||||
if status >= 500 {
|
||||
return "internal_error"
|
||||
}
|
||||
return "invalid_request"
|
||||
}
|
||||
}
|
||||
57
internal/adapter/openai/handler_routes.go
Normal file
57
internal/adapter/openai/handler_routes.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// writeJSON is a package-internal alias kept to avoid mass-renaming across
|
||||
// every call-site in this package.
|
||||
var writeJSON = util.WriteJSON
|
||||
|
||||
type Handler struct {
|
||||
Store ConfigReader
|
||||
Auth AuthResolver
|
||||
DS DeepSeekCaller
|
||||
|
||||
leaseMu sync.Mutex
|
||||
streamLeases map[string]streamLease
|
||||
responsesMu sync.Mutex
|
||||
responses *responseStore
|
||||
}
|
||||
|
||||
type streamLease struct {
|
||||
Auth *auth.RequestAuth
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
r.Get("/v1/models", h.ListModels)
|
||||
r.Get("/v1/models/{model_id}", h.GetModel)
|
||||
r.Post("/v1/chat/completions", h.ChatCompletions)
|
||||
r.Post("/v1/responses", h.Responses)
|
||||
r.Get("/v1/responses/{response_id}", h.GetResponseByID)
|
||||
r.Post("/v1/embeddings", h.Embeddings)
|
||||
}
|
||||
|
||||
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
|
||||
writeJSON(w, http.StatusOK, config.OpenAIModelsResponse())
|
||||
}
|
||||
|
||||
func (h *Handler) GetModel(w http.ResponseWriter, r *http.Request) {
|
||||
modelID := strings.TrimSpace(chi.URLParam(r, "model_id"))
|
||||
model, ok := config.OpenAIModelByID(h.Store, modelID)
|
||||
if !ok {
|
||||
writeOpenAIError(w, http.StatusNotFound, "Model not found.")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, model)
|
||||
}
|
||||
171
internal/adapter/openai/handler_toolcall_format.go
Normal file
171
internal/adapter/openai/handler_toolcall_format.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func injectToolPrompt(messages []map[string]any, tools []any, policy util.ToolChoicePolicy) ([]map[string]any, []string) {
|
||||
if policy.IsNone() {
|
||||
return messages, nil
|
||||
}
|
||||
toolSchemas := make([]string, 0, len(tools))
|
||||
names := make([]string, 0, len(tools))
|
||||
isAllowed := func(name string) bool {
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return false
|
||||
}
|
||||
if len(policy.Allowed) == 0 {
|
||||
return true
|
||||
}
|
||||
_, ok := policy.Allowed[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
for _, t := range tools {
|
||||
tool, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
fn, _ := tool["function"].(map[string]any)
|
||||
if len(fn) == 0 {
|
||||
fn = tool
|
||||
}
|
||||
name, _ := fn["name"].(string)
|
||||
desc, _ := fn["description"].(string)
|
||||
schema, _ := fn["parameters"].(map[string]any)
|
||||
name = strings.TrimSpace(name)
|
||||
if !isAllowed(name) {
|
||||
continue
|
||||
}
|
||||
names = append(names, name)
|
||||
if desc == "" {
|
||||
desc = "No description available"
|
||||
}
|
||||
b, _ := json.Marshal(schema)
|
||||
toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b)))
|
||||
}
|
||||
if len(toolSchemas) == 0 {
|
||||
return messages, names
|
||||
}
|
||||
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON object format:\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\n【EXAMPLE】\nUser: Please check the weather in Beijing and Shanghai, and update my todo list.\nAssistant:\n{\"tool_calls\": [\n {\"name\": \"get_weather\", \"input\": {\"city\": \"Beijing\"}},\n {\"name\": \"get_weather\", \"input\": {\"city\": \"Shanghai\"}},\n {\"name\": \"update_todo\", \"input\": {\"todos\": [{\"content\": \"Buy milk\"}, {\"content\": \"Write report\"}]}}\n]}\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON object above. Do NOT include any extra text.\n2) Do NOT wrap tool-call JSON in markdown/code fences (for example, do not use triple backticks).\n3) After receiving a tool result, you MUST use it to produce the final answer.\n4) Only call another tool when the previous result is missing required data or returned an error.\n5) JSON SYNTAX STRICTLY REQUIRED: All property names MUST be enclosed in double quotes (e.g., \"name\", not name).\n6) ARRAY FORMAT: If providing a list of items, you MUST enclose them in square brackets `[]` (e.g., \"todos\": [{\"item\": \"a\"}, {\"item\": \"b\"}]). DO NOT output comma-separated objects without brackets."
|
||||
if policy.Mode == util.ToolChoiceRequired {
|
||||
toolPrompt += "\n7) For this response, you MUST call at least one tool from the allowed list."
|
||||
}
|
||||
if policy.Mode == util.ToolChoiceForced && strings.TrimSpace(policy.ForcedName) != "" {
|
||||
toolPrompt += "\n7) For this response, you MUST call exactly this tool name: " + strings.TrimSpace(policy.ForcedName)
|
||||
toolPrompt += "\n8) Do not call any other tool."
|
||||
}
|
||||
|
||||
for i := range messages {
|
||||
if messages[i]["role"] == "system" {
|
||||
old, _ := messages[i]["content"].(string)
|
||||
messages[i]["content"] = strings.TrimSpace(old + "\n\n" + toolPrompt)
|
||||
return messages, names
|
||||
}
|
||||
}
|
||||
messages = append([]map[string]any{{"role": "system", "content": toolPrompt}}, messages...)
|
||||
return messages, names
|
||||
}
|
||||
|
||||
func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]string) []map[string]any {
|
||||
if len(deltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]map[string]any, 0, len(deltas))
|
||||
for _, d := range deltas {
|
||||
if d.Name == "" && d.Arguments == "" {
|
||||
continue
|
||||
}
|
||||
callID, ok := ids[d.Index]
|
||||
if !ok || callID == "" {
|
||||
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
ids[d.Index] = callID
|
||||
}
|
||||
item := map[string]any{
|
||||
"index": d.Index,
|
||||
"id": callID,
|
||||
"type": "function",
|
||||
}
|
||||
fn := map[string]any{}
|
||||
if d.Name != "" {
|
||||
fn["name"] = d.Name
|
||||
}
|
||||
if d.Arguments != "" {
|
||||
fn["arguments"] = d.Arguments
|
||||
}
|
||||
if len(fn) > 0 {
|
||||
item["function"] = fn
|
||||
}
|
||||
out = append(out, item)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, allowedNames []string, seenNames map[int]string) []toolCallDelta {
|
||||
if len(deltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
allowed := namesToSet(allowedNames)
|
||||
if len(allowed) == 0 {
|
||||
for _, d := range deltas {
|
||||
if d.Name != "" {
|
||||
seenNames[d.Index] = "__blocked__"
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
out := make([]toolCallDelta, 0, len(deltas))
|
||||
for _, d := range deltas {
|
||||
if d.Name != "" {
|
||||
if _, ok := allowed[d.Name]; !ok {
|
||||
seenNames[d.Index] = "__blocked__"
|
||||
continue
|
||||
}
|
||||
seenNames[d.Index] = d.Name
|
||||
out = append(out, d)
|
||||
continue
|
||||
}
|
||||
name := strings.TrimSpace(seenNames[d.Index])
|
||||
if name == "" || name == "__blocked__" {
|
||||
continue
|
||||
}
|
||||
out = append(out, d)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any {
|
||||
if len(calls) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]map[string]any, 0, len(calls))
|
||||
for i, c := range calls {
|
||||
callID := ""
|
||||
if ids != nil {
|
||||
callID = strings.TrimSpace(ids[i])
|
||||
}
|
||||
if callID == "" {
|
||||
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
if ids != nil {
|
||||
ids[i] = callID
|
||||
}
|
||||
}
|
||||
args, _ := json.Marshal(c.Input)
|
||||
out = append(out, map[string]any{
|
||||
"index": i,
|
||||
"id": callID,
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": c.Name,
|
||||
"arguments": string(args),
|
||||
},
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
19
internal/adapter/openai/handler_toolcall_policy.go
Normal file
19
internal/adapter/openai/handler_toolcall_policy.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package openai
|
||||
|
||||
import "strings"
|
||||
|
||||
func (h *Handler) toolcallFeatureMatchEnabled() bool {
|
||||
if h == nil || h.Store == nil {
|
||||
return true
|
||||
}
|
||||
mode := strings.TrimSpace(strings.ToLower(h.Store.ToolcallMode()))
|
||||
return mode == "" || mode == "feature_match"
|
||||
}
|
||||
|
||||
func (h *Handler) toolcallEarlyEmitHighConfidence() bool {
|
||||
if h == nil || h.Store == nil {
|
||||
return true
|
||||
}
|
||||
level := strings.TrimSpace(strings.ToLower(h.Store.ToolcallEarlyEmitConfidence()))
|
||||
return level == "" || level == "high"
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package openai
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -100,6 +101,26 @@ func streamFinishReason(frames []map[string]any) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func streamToolCallArgumentChunks(frames []map[string]any) []string {
|
||||
out := make([]string, 0, 4)
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
toolCalls, _ := delta["tool_calls"].([]any)
|
||||
for _, tc := range toolCalls {
|
||||
tcm, _ := tc.(map[string]any)
|
||||
fn, _ := tcm["function"].(map[string]any)
|
||||
if args, ok := fn["arguments"].(string); ok && args != "" {
|
||||
out = append(out, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
@@ -108,7 +129,7 @@ func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) {
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
@@ -141,7 +162,7 @@ func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) {
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, false, []string{"search"})
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
@@ -161,7 +182,7 @@ func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) {
|
||||
func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
@@ -169,7 +190,38 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) {
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
if _, ok := msg["tool_calls"]; ok {
|
||||
t.Fatalf("did not expect tool_calls for unknown schema name, got %#v", msg["tool_calls"])
|
||||
}
|
||||
content, _ := msg["content"].(string)
|
||||
if !strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected unknown tool json to pass through as text, got %#v", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamEmbeddedToolCallExamplePromotesToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"下面是示例:"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: {"p":"response/content","v":"请勿执行。"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2c", "deepseek-chat", "prompt", false, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
@@ -181,13 +233,49 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) {
|
||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
if msg["content"] != nil {
|
||||
t.Fatalf("expected content nil, got %#v", msg["content"])
|
||||
}
|
||||
toolCalls, _ := msg["tool_calls"].([]any)
|
||||
if len(toolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %#v", msg["tool_calls"])
|
||||
t.Fatalf("expected one tool_call field for embedded example: %#v", msg["tool_calls"])
|
||||
}
|
||||
content, _ := msg["content"].(string)
|
||||
if strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected raw tool_calls json stripped from content, got %#v", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamFencedToolCallExampleDoesNotPromoteToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
"data: {\"p\":\"response/content\",\"v\":\"```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"search\\\",\\\"input\\\":{\\\"q\\\":\\\"go\\\"}}]}\\n```\"}",
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2d", "deepseek-chat", "prompt", false, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] == "tool_calls" {
|
||||
t.Fatalf("expected fenced example to remain content-only, got finish_reason=%#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
toolCalls, _ := msg["tool_calls"].([]any)
|
||||
if len(toolCalls) != 0 {
|
||||
t.Fatalf("expected no tool_call field for fenced example: %#v", msg["tool_calls"])
|
||||
}
|
||||
content, _ := msg["content"].(string)
|
||||
if !strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected fenced example content preserved, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
// Backward-compatible alias for historical test name used in CI logs.
|
||||
func TestHandleNonStreamFencedToolCallExamplePromotesToolCall(t *testing.T) {
|
||||
TestHandleNonStreamFencedToolCallExampleDoesNotPromoteToolCall(t)
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
|
||||
@@ -235,6 +323,36 @@ func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallLargeArgumentsStillIntercepted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
large := strings.Repeat("a", 9000)
|
||||
payload := fmt.Sprintf(`{"tool_calls":[{"name":"search","input":{"q":"%s"}}]}`, large)
|
||||
splitAt := len(payload) / 2
|
||||
resp := makeSSEHTTPResponse(
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, payload[:splitAt]),
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, payload[splitAt:]),
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid3-large", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
@@ -295,7 +413,7 @@ func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamUnknownToolStillIntercepted(t *testing.T) {
|
||||
func TestHandleStreamUnknownToolDoesNotLeakRawPayload(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
@@ -310,29 +428,40 @@ func TestHandleStreamUnknownToolStillIntercepted(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
foundToolIndex := false
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
toolCalls, _ := delta["tool_calls"].([]any)
|
||||
for _, tc := range toolCalls {
|
||||
tcm, _ := tc.(map[string]any)
|
||||
if _, ok := tcm["index"].(float64); ok {
|
||||
foundToolIndex = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundToolIndex {
|
||||
t.Fatalf("expected stream tool_calls item with index, body=%s", rec.Body.String())
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta for unknown schema name, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
t.Fatalf("did not expect raw tool_calls json leak for unknown schema name: %s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamUnknownToolNoArgsDoesNotLeakRawPayload(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\"}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid5b", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta for unknown schema name (no args), body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("did not expect raw tool_calls json leak for unknown schema name (no args): %s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -377,9 +506,9 @@ func TestHandleStreamToolsPlainTextStreamsBeforeFinish(t *testing.T) {
|
||||
func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"前置正文A。"}`,
|
||||
`data: {"p":"response/content","v":"下面是示例:"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: {"p":"response/content","v":"后置正文B。"}`,
|
||||
`data: {"p":"response/content","v":"请勿执行。"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -392,10 +521,7 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta in mixed stream, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in mixed stream: %s", rec.Body.String())
|
||||
t.Fatalf("expected tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -409,15 +535,175 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "前置正文A。") || !strings.Contains(got, "后置正文B。") {
|
||||
if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") {
|
||||
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallAfterLeadingTextRemainsText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"我将调用工具。"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7b", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "我将调用工具。") {
|
||||
t.Fatalf("expected leading text to keep streaming, got=%q", got)
|
||||
}
|
||||
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallWithSameChunkTrailingTextRemainsText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}接下来我会继续说明。"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7c", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "接下来我会继续说明。") {
|
||||
t.Fatalf("expected trailing plain text to be preserved, got=%q", got)
|
||||
}
|
||||
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamFencedToolCallSnippetPromotesToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "下面是调用示例:\n```json\n"),
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\n仅示例,不要执行。"),
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7f", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta for fenced snippet, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected raw fenced tool_calls snippet stripped from content, got=%q", got)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(got), "```json") || strings.Contains(got, "\n```\n") {
|
||||
t.Fatalf("expected consumed fenced tool payload to not leave empty code fence, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) {
|
||||
func TestHandleStreamStandaloneToolCallAfterClosedFenceKeepsFence(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "先给一个代码示例:\n```text\nhello\n```\n"),
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"),
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7g", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta for standalone payload, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "```") {
|
||||
t.Fatalf("expected closed fence before standalone tool json to be preserved, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallKeyAppearsLateRemainsText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
spaces := strings.Repeat(" ", 200)
|
||||
resp := makeSSEHTTPResponse(
|
||||
@@ -438,9 +724,6 @@ func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) {
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
@@ -453,9 +736,6 @@ func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) {
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if strings.Contains(got, "{") {
|
||||
t.Fatalf("unexpected suspicious prefix leak in content: %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "后置正文C。") {
|
||||
t.Fatalf("expected stream to continue after tool json convergence, got=%q", got)
|
||||
}
|
||||
@@ -495,16 +775,16 @@ func TestHandleStreamInvalidToolJSONDoesNotLeakRawObject(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
got := strings.ToLower(content.String())
|
||||
if strings.Contains(got, "tool_calls") {
|
||||
t.Fatalf("unexpected raw tool_calls leak in content: %q", content.String())
|
||||
}
|
||||
if !strings.Contains(content.String(), "前置正文D。") || !strings.Contains(content.String(), "后置正文E。") {
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "前置正文D。") || !strings.Contains(got, "后置正文E。") {
|
||||
t.Fatalf("expected pre/post plain text to remain, got=%q", content.String())
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected invalid embedded tool-like json to pass through as text, got=%q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamIncompleteCapturedToolJSONDoesNotLeakOnFinalize(t *testing.T) {
|
||||
func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
|
||||
@@ -533,7 +813,112 @@ func TestHandleStreamIncompleteCapturedToolJSONDoesNotLeakOnFinalize(t *testing.
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.Contains(strings.ToLower(content.String()), "tool_calls") || strings.Contains(content.String(), "{") {
|
||||
t.Fatalf("unexpected incomplete tool json leak in content: %q", content.String())
|
||||
if !strings.Contains(strings.ToLower(content.String()), "tool_calls") || !strings.Contains(content.String(), "{") {
|
||||
t.Fatalf("expected incomplete capture to flush as plain text instead of stalling, got=%q", content.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallArgumentsEmitAsSingleCompletedChunk(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`,
|
||||
`data: {"p":"response/content","v":"lang\",\"page\":1}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid11", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
}
|
||||
argChunks := streamToolCallArgumentChunks(frames)
|
||||
if len(argChunks) == 0 {
|
||||
t.Fatalf("expected tool call arguments chunk, got=%v body=%s", argChunks, rec.Body.String())
|
||||
}
|
||||
joined := strings.Join(argChunks, "")
|
||||
if !strings.Contains(joined, `"q":"golang"`) || !strings.Contains(joined, `"page":1`) {
|
||||
t.Fatalf("unexpected merged arguments stream: %q", joined)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamMultiToolCallDoesNotMergeNamesOrArguments(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search_web\",\"input\":{\"query\":\"latest ai news\"}},{"}`,
|
||||
`data: {"p":"response/content","v":"\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid12", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
foundSearch := false
|
||||
foundEval := false
|
||||
foundIndex1 := false
|
||||
toolCallsDeltaLens := make([]int, 0, 2)
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
rawToolCalls, hasToolCalls := delta["tool_calls"]
|
||||
if !hasToolCalls {
|
||||
continue
|
||||
}
|
||||
toolCalls, _ := rawToolCalls.([]any)
|
||||
toolCallsDeltaLens = append(toolCallsDeltaLens, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
tcm, _ := tc.(map[string]any)
|
||||
if idx, ok := tcm["index"].(float64); ok && int(idx) == 1 {
|
||||
foundIndex1 = true
|
||||
}
|
||||
fn, _ := tcm["function"].(map[string]any)
|
||||
name, _ := fn["name"].(string)
|
||||
switch name {
|
||||
case "search_web":
|
||||
foundSearch = true
|
||||
case "eval_javascript":
|
||||
foundEval = true
|
||||
case "search_webeval_javascript":
|
||||
t.Fatalf("unexpected merged tool name: %s, body=%s", name, rec.Body.String())
|
||||
}
|
||||
if args, ok := fn["arguments"].(string); ok && strings.Contains(args, `}{"`) {
|
||||
t.Fatalf("unexpected concatenated tool arguments: %q, body=%s", args, rec.Body.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundSearch || !foundEval {
|
||||
t.Fatalf("expected both tool names in stream deltas, foundSearch=%v foundEval=%v body=%s", foundSearch, foundEval, rec.Body.String())
|
||||
}
|
||||
if len(toolCallsDeltaLens) != 1 || toolCallsDeltaLens[0] != 2 {
|
||||
t.Fatalf("expected exactly one tool_calls delta with two calls, got lens=%v body=%s", toolCallsDeltaLens, rec.Body.String())
|
||||
}
|
||||
if !foundIndex1 {
|
||||
t.Fatalf("expected second tool call index in stream deltas, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
148
internal/adapter/openai/message_normalize.go
Normal file
148
internal/adapter/openai/message_normalize.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/prompt"
|
||||
)
|
||||
|
||||
func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any {
|
||||
_ = traceID
|
||||
out := make([]map[string]any, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
msg, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(asString(msg["role"])))
|
||||
switch role {
|
||||
case "assistant":
|
||||
content := buildAssistantContentForPrompt(msg)
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
})
|
||||
case "tool", "function":
|
||||
content := buildToolContentForPrompt(msg)
|
||||
out = append(out, map[string]any{
|
||||
"role": "tool",
|
||||
"content": content,
|
||||
})
|
||||
case "user", "system", "developer":
|
||||
out = append(out, map[string]any{
|
||||
"role": normalizeOpenAIRoleForPrompt(role),
|
||||
"content": normalizeOpenAIContentForPrompt(msg["content"]),
|
||||
})
|
||||
default:
|
||||
content := normalizeOpenAIContentForPrompt(msg["content"])
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"role": normalizeOpenAIRoleForPrompt(role),
|
||||
"content": content,
|
||||
})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildAssistantContentForPrompt(msg map[string]any) string {
|
||||
content := normalizeOpenAIContentForPrompt(msg["content"])
|
||||
toolCalls := normalizeAssistantToolCallsForPrompt(msg["tool_calls"])
|
||||
if toolCalls == "" {
|
||||
return strings.TrimSpace(content)
|
||||
}
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return toolCalls
|
||||
}
|
||||
return strings.TrimSpace(content + "\n" + toolCalls)
|
||||
}
|
||||
|
||||
// normalizeAssistantToolCallsForPrompt serializes an assistant tool_calls
// slice to compact JSON. Non-slice or empty input yields ""; if marshalling
// fails, a fmt-formatted fallback is returned instead of dropping the calls.
func normalizeAssistantToolCallsForPrompt(v any) string {
	calls, isSlice := v.([]any)
	if !isSlice || len(calls) == 0 {
		return ""
	}
	if encoded, err := json.Marshal(calls); err == nil {
		return strings.TrimSpace(string(encoded))
	}
	return strings.TrimSpace(fmt.Sprintf("%v", calls))
}
|
||||
|
||||
func buildToolContentForPrompt(msg map[string]any) string {
|
||||
payload := map[string]any{
|
||||
"content": msg["content"],
|
||||
}
|
||||
if id := strings.TrimSpace(asString(msg["tool_call_id"])); id != "" {
|
||||
payload["tool_call_id"] = id
|
||||
}
|
||||
if id := strings.TrimSpace(asString(msg["id"])); id != "" {
|
||||
payload["id"] = id
|
||||
}
|
||||
if name := strings.TrimSpace(asString(msg["name"])); name != "" {
|
||||
payload["name"] = name
|
||||
}
|
||||
content := normalizeOpenAIContentForPrompt(payload)
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return `{"content":"null"}`
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
// normalizeOpenAIContentForPrompt renders an OpenAI message content value
// (string, content-block array, or arbitrary object) into a single string by
// delegating to the shared prompt.NormalizeContent helper.
func normalizeOpenAIContentForPrompt(v any) string {
	return prompt.NormalizeContent(v)
}
|
||||
|
||||
func normalizeToolArgumentString(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if looksLikeConcatenatedJSON(trimmed) {
|
||||
// Keep original payload to avoid silent argument rewrites.
|
||||
return raw
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
// normalizeOpenAIRoleForPrompt lowercases and trims a role name, mapping the
// OpenAI "developer" role onto "system"; every other role passes through.
func normalizeOpenAIRoleForPrompt(role string) string {
	switch normalized := strings.ToLower(strings.TrimSpace(role)); normalized {
	case "developer":
		return "system"
	default:
		return normalized
	}
}
|
||||
|
||||
// asString returns v when it holds a string and "" for any other type,
// including nil.
func asString(v any) string {
	s, _ := v.(string)
	return s
}
|
||||
|
||||
func looksLikeConcatenatedJSON(raw string) bool {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(trimmed, "}{") || strings.Contains(trimmed, "][") {
|
||||
return true
|
||||
}
|
||||
dec := json.NewDecoder(strings.NewReader(trimmed))
|
||||
var first any
|
||||
if err := dec.Decode(&first); err != nil {
|
||||
return false
|
||||
}
|
||||
var second any
|
||||
return dec.Decode(&second) == nil
|
||||
}
|
||||
293
internal/adapter/openai/message_normalize_test.go
Normal file
293
internal/adapter/openai/message_normalize_test.go
Normal file
@@ -0,0 +1,293 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// Verifies that assistant tool_call history and tool results survive
// normalization intact and that no synthetic history markers are injected.
func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *testing.T) {
	raw := []any{
		map[string]any{"role": "system", "content": "You are helpful"},
		map[string]any{"role": "user", "content": "查北京天气"},
		map[string]any{
			"role":    "assistant",
			"content": nil,
			"tool_calls": []any{
				map[string]any{
					"id":   "call_1",
					"type": "function",
					"function": map[string]any{
						"name":      "get_weather",
						"arguments": "{\"city\":\"beijing\"}",
					},
				},
			},
		},
		map[string]any{
			"role":         "tool",
			"tool_call_id": "call_1",
			"name":         "get_weather",
			"content":      "{\"temp\":18}",
		},
	}

	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
	if len(normalized) != 4 {
		t.Fatalf("expected 4 normalized messages with assistant tool_call history preserved, got %d", len(normalized))
	}
	toolContent, _ := normalized[3]["content"].(string)
	if !strings.Contains(toolContent, `\"temp\":18`) {
		t.Fatalf("tool result should be transparently forwarded, got %q", toolContent)
	}
	if strings.Contains(toolContent, "[TOOL_RESULT_HISTORY]") {
		t.Fatalf("tool history marker should not be injected: %q", toolContent)
	}

	prompt := util.MessagesPrepare(normalized)
	if strings.Contains(prompt, "[TOOL_CALL_HISTORY]") || strings.Contains(prompt, "[TOOL_RESULT_HISTORY]") {
		t.Fatalf("expected no synthetic history markers in prompt: %q", prompt)
	}
}

// Verifies that an object-valued tool result is serialized into the envelope.
func TestNormalizeOpenAIMessagesForPrompt_ToolObjectContentPreserved(t *testing.T) {
	raw := []any{
		map[string]any{
			"role":         "tool",
			"tool_call_id": "call_2",
			"name":         "get_weather",
			"content": map[string]any{
				"temp":      18,
				"condition": "sunny",
			},
		},
	}

	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
	got, _ := normalized[0]["content"].(string)
	if !strings.Contains(got, `"temp":18`) || !strings.Contains(got, `"condition":"sunny"`) {
		t.Fatalf("expected serialized object in tool content, got %q", got)
	}
}

// Verifies that array-of-blocks tool content is joined and metadata kept.
func TestNormalizeOpenAIMessagesForPrompt_ToolArrayBlocksJoined(t *testing.T) {
	raw := []any{
		map[string]any{
			"role":         "tool",
			"tool_call_id": "call_3",
			"name":         "read_file",
			"content": []any{
				map[string]any{"type": "input_text", "text": "line-1"},
				map[string]any{"type": "output_text", "text": "line-2"},
				map[string]any{"type": "image_url", "image_url": "https://example.com/a.png"},
			},
		},
	}

	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
	got, _ := normalized[0]["content"].(string)
	if !strings.Contains(got, `"line-1"`) || !strings.Contains(got, `"line-2"`) || !strings.Contains(got, `"name":"read_file"`) {
		t.Fatalf("expected tool envelope to preserve content blocks and metadata, got %q", got)
	}
}

// Verifies that the legacy "function" role is normalized to "tool".
func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) {
	raw := []any{
		map[string]any{
			"role":         "function",
			"tool_call_id": "call_4",
			"name":         "legacy_tool",
			"content": map[string]any{
				"ok": true,
			},
		},
	}

	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
	if len(normalized) != 1 {
		t.Fatalf("expected one normalized message, got %d", len(normalized))
	}
	if normalized[0]["role"] != "tool" {
		t.Fatalf("expected function role normalized as tool, got %#v", normalized[0]["role"])
	}
	got, _ := normalized[0]["content"].(string)
	if !strings.Contains(got, `"name":"legacy_tool"`) || !strings.Contains(got, `"ok":true`) {
		t.Fatalf("unexpected normalized function-role content: %q", got)
	}
}

// Verifies that an empty tool result is kept (as an envelope), not dropped.
func TestNormalizeOpenAIMessagesForPrompt_EmptyToolContentPreservedAsNull(t *testing.T) {
	raw := []any{
		map[string]any{
			"role":         "tool",
			"tool_call_id": "call_5",
			"name":         "noop_tool",
			"content":      "",
		},
		map[string]any{
			"role":    "assistant",
			"content": "done",
		},
	}

	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
	if len(normalized) != 2 {
		t.Fatalf("expected tool completion turn to be preserved, got %#v", normalized)
	}
	if normalized[0]["role"] != "tool" {
		t.Fatalf("expected tool role preserved, got %#v", normalized[0]["role"])
	}
	got, _ := normalized[0]["content"].(string)
	if !strings.Contains(got, `"content":""`) || !strings.Contains(got, `"name":"noop_tool"`) || !strings.Contains(got, `"tool_call_id":"call_5"`) {
		t.Fatalf("expected tool metadata preserved in content envelope, got %q", got)
	}
}

// Verifies that multiple tool_calls on one assistant turn stay distinct.
func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSeparated(t *testing.T) {
	raw := []any{
		map[string]any{
			"role": "assistant",
			"tool_calls": []any{
				map[string]any{
					"id":   "call_search",
					"type": "function",
					"function": map[string]any{
						"name":      "search_web",
						"arguments": `{"query":"latest ai news"}`,
					},
				},
				map[string]any{
					"id":   "call_eval",
					"type": "function",
					"function": map[string]any{
						"name":      "eval_javascript",
						"arguments": `{"code":"1+1"}`,
					},
				},
			},
		},
	}

	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
	if len(normalized) != 1 {
		t.Fatalf("expected assistant tool_call-only message to be preserved, got %#v", normalized)
	}
	got, _ := normalized[0]["content"].(string)
	if !strings.Contains(got, `"name":"search_web"`) || !strings.Contains(got, `"name":"eval_javascript"`) {
		t.Fatalf("expected tool_calls payload preserved in assistant content, got %q", got)
	}
}

// Verifies that concatenated (malformed) argument strings are not rewritten.
func TestNormalizeOpenAIMessagesForPrompt_PreservesConcatenatedToolArguments(t *testing.T) {
	raw := []any{
		map[string]any{
			"role": "assistant",
			"tool_calls": []any{
				map[string]any{
					"id": "call_1",
					"function": map[string]any{
						"name":      "search_web",
						"arguments": `{}{"query":"测试工具调用"}`,
					},
				},
			},
		},
	}

	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
	if len(normalized) != 1 {
		t.Fatalf("expected assistant tool_call-only content to be preserved, got %#v", normalized)
	}
	got, _ := normalized[0]["content"].(string)
	if !strings.Contains(got, `{}{\"query\":\"测试工具调用\"}`) {
		t.Fatalf("expected concatenated arguments preserved verbatim, got %q", got)
	}
}

// Verifies that tool_calls missing a function name are still preserved raw.
func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsMissingNameAreDropped(t *testing.T) {
	raw := []any{
		map[string]any{
			"role": "assistant",
			"tool_calls": []any{
				map[string]any{
					"id":   "call_missing_name",
					"type": "function",
					"function": map[string]any{
						"arguments": `{"path":"README.MD"}`,
					},
				},
			},
		},
	}

	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
	if len(normalized) != 1 {
		t.Fatalf("expected assistant tool_calls history to be preserved even when name missing, got %#v", normalized)
	}
	got, _ := normalized[0]["content"].(string)
	if !strings.Contains(got, "call_missing_name") {
		t.Fatalf("expected raw tool_call payload preserved, got %q", got)
	}
}

// Verifies that a nil assistant content does not leak a "null" literal.
func TestNormalizeOpenAIMessagesForPrompt_AssistantNilContentDoesNotInjectNullLiteral(t *testing.T) {
	raw := []any{
		map[string]any{
			"role":    "assistant",
			"content": nil,
			"tool_calls": []any{
				map[string]any{
					"id": "call_screenshot",
					"function": map[string]any{
						"name":      "send_file_to_user",
						"arguments": `{"file_path":"/tmp/a.png"}`,
					},
				},
			},
		},
	}

	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
	if len(normalized) != 1 {
		t.Fatalf("expected nil-content assistant tool_call-only message to be preserved, got %#v", normalized)
	}
	got, _ := normalized[0]["content"].(string)
	if !strings.Contains(got, "send_file_to_user") {
		t.Fatalf("expected tool call payload preserved, got %q", got)
	}
}

// Verifies the developer -> system role mapping.
func TestNormalizeOpenAIMessagesForPrompt_DeveloperRoleMapsToSystem(t *testing.T) {
	raw := []any{
		map[string]any{"role": "developer", "content": "必须先走工具调用"},
		map[string]any{"role": "user", "content": "你好"},
	}
	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
	if len(normalized) != 2 {
		t.Fatalf("expected 2 normalized messages, got %d", len(normalized))
	}
	if normalized[0]["role"] != "system" {
		t.Fatalf("expected developer role converted to system, got %#v", normalized[0]["role"])
	}
}

// Verifies the block-level "content" fallback when a text block's text is "".
func TestNormalizeOpenAIMessagesForPrompt_AssistantArrayContentFallbackWhenTextEmpty(t *testing.T) {
	raw := []any{
		map[string]any{
			"role": "assistant",
			"content": []any{
				map[string]any{"type": "text", "text": "", "content": "工具说明文本"},
			},
		},
	}

	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
	if len(normalized) != 1 {
		t.Fatalf("expected one normalized message, got %d", len(normalized))
	}
	content, _ := normalized[0]["content"].(string)
	if content != "工具说明文本" {
		t.Fatalf("expected content fallback text preserved, got %q", content)
	}
}
|
||||
46
internal/adapter/openai/models_route_test.go
Normal file
46
internal/adapter/openai/models_route_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
// Verifies that /v1/models/{id} serves both a native model id and an alias.
func TestGetModelRouteDirectAndAlias(t *testing.T) {
	h := &Handler{}
	r := chi.NewRouter()
	RegisterRoutes(r, h)

	t.Run("direct", func(t *testing.T) {
		req := httptest.NewRequest(http.MethodGet, "/v1/models/deepseek-chat", nil)
		rec := httptest.NewRecorder()
		r.ServeHTTP(rec, req)
		if rec.Code != http.StatusOK {
			t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
		}
	})

	t.Run("alias", func(t *testing.T) {
		req := httptest.NewRequest(http.MethodGet, "/v1/models/gpt-4.1", nil)
		rec := httptest.NewRecorder()
		r.ServeHTTP(rec, req)
		if rec.Code != http.StatusOK {
			t.Fatalf("expected 200 for alias, got %d body=%s", rec.Code, rec.Body.String())
		}
	})
}

// Verifies that an unknown model id yields 404.
func TestGetModelRouteNotFound(t *testing.T) {
	h := &Handler{}
	r := chi.NewRouter()
	RegisterRoutes(r, h)

	req := httptest.NewRequest(http.MethodGet, "/v1/models/not-exists", nil)
	rec := httptest.NewRecorder()
	r.ServeHTTP(rec, req)
	if rec.Code != http.StatusNotFound {
		t.Fatalf("expected 404, got %d body=%s", rec.Code, rec.Body.String())
	}
}
|
||||
26
internal/adapter/openai/prompt_build.go
Normal file
26
internal/adapter/openai/prompt_build.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"ds2api/internal/deepseek"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// buildOpenAIFinalPrompt builds the final prompt text and the declared tool
// names using the default tool-choice policy.
func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any, traceID string) (string, []string) {
	return buildOpenAIFinalPromptWithPolicy(messagesRaw, toolsRaw, traceID, util.DefaultToolChoicePolicy())
}
|
||||
|
||||
func buildOpenAIFinalPromptWithPolicy(messagesRaw []any, toolsRaw any, traceID string, toolPolicy util.ToolChoicePolicy) (string, []string) {
|
||||
messages := normalizeOpenAIMessagesForPrompt(messagesRaw, traceID)
|
||||
toolNames := []string{}
|
||||
if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 {
|
||||
messages, toolNames = injectToolPrompt(messages, tools, toolPolicy)
|
||||
}
|
||||
return deepseek.MessagesPrepare(messages), toolNames
|
||||
}
|
||||
|
||||
// BuildPromptForAdapter exposes the OpenAI-compatible prompt building flow so
// other protocol adapters (for example Gemini) can reuse the same tool/history
// normalization logic and remain behavior-compatible with chat/completions.
// It applies the default tool-choice policy.
func BuildPromptForAdapter(messagesRaw []any, toolsRaw any, traceID string) (string, []string) {
	return buildOpenAIFinalPrompt(messagesRaw, toolsRaw, traceID)
}
|
||||
86
internal/adapter/openai/prompt_build_test.go
Normal file
86
internal/adapter/openai/prompt_build_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Verifies the handler path: tool names are collected, tool output content is
// preserved verbatim, and no synthetic history markers leak into the prompt.
func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *testing.T) {
	messages := []any{
		map[string]any{"role": "user", "content": "查北京天气"},
		map[string]any{
			"role": "assistant",
			"tool_calls": []any{
				map[string]any{
					"id": "call_1",
					"function": map[string]any{
						"name":      "get_weather",
						"arguments": "{\"city\":\"beijing\"}",
					},
				},
			},
		},
		map[string]any{
			"role":         "tool",
			"tool_call_id": "call_1",
			"name":         "get_weather",
			"content":      map[string]any{"temp": 18, "condition": "sunny"},
		},
	}
	tools := []any{
		map[string]any{
			"type": "function",
			"function": map[string]any{
				"name":        "get_weather",
				"description": "Get weather",
				"parameters": map[string]any{
					"type": "object",
				},
			},
		},
	}

	finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools, "")
	if len(toolNames) != 1 || toolNames[0] != "get_weather" {
		t.Fatalf("unexpected tool names: %#v", toolNames)
	}
	if !strings.Contains(finalPrompt, `"condition":"sunny"`) {
		t.Fatalf("handler finalPrompt should preserve tool output content: %q", finalPrompt)
	}
	if strings.Contains(finalPrompt, "[TOOL_CALL_HISTORY]") || strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") {
		t.Fatalf("handler finalPrompt should not include synthetic history markers: %q", finalPrompt)
	}
}

// Verifies the injected tool-prompt instructions (final-answer rule, retry
// guard, and the no-markdown-fence rule) are present in the rendered prompt.
func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *testing.T) {
	messages := []any{
		map[string]any{"role": "system", "content": "You are helpful"},
		map[string]any{"role": "user", "content": "请调用工具"},
	}
	tools := []any{
		map[string]any{
			"type": "function",
			"function": map[string]any{
				"name":        "search",
				"description": "search docs",
				"parameters": map[string]any{
					"type": "object",
				},
			},
		},
	}

	finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "")
	if !strings.Contains(finalPrompt, "After receiving a tool result, you MUST use it to produce the final answer.") {
		t.Fatalf("vercel prepare finalPrompt missing final-answer instruction: %q", finalPrompt)
	}
	if !strings.Contains(finalPrompt, "Only call another tool when the previous result is missing required data or returned an error.") {
		t.Fatalf("vercel prepare finalPrompt missing retry guard instruction: %q", finalPrompt)
	}
	if !strings.Contains(finalPrompt, "Do NOT wrap tool-call JSON in markdown/code fences") {
		t.Fatalf("vercel prepare finalPrompt missing no-fence instruction: %q", finalPrompt)
	}
	if strings.Contains(finalPrompt, "```json") {
		t.Fatalf("vercel prepare finalPrompt should not require fenced json tool calls: %q", finalPrompt)
	}
}
|
||||
109
internal/adapter/openai/response_store.go
Normal file
109
internal/adapter/openai/response_store.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
)
|
||||
|
||||
// storedResponse is a single cached /v1/responses payload scoped to the
// caller that created it.
type storedResponse struct {
	Owner     string         // caller identity that may read this entry
	Value     map[string]any // shallow-copied response payload
	ExpiresAt time.Time      // absolute expiry; swept lazily on access
}

// responseStore is an in-memory, mutex-guarded TTL cache of stored
// responses, keyed by a composite (owner, response id) key.
type responseStore struct {
	mu    sync.Mutex
	ttl   time.Duration
	items map[string]storedResponse
}
|
||||
|
||||
func newResponseStore(ttl time.Duration) *responseStore {
|
||||
if ttl <= 0 {
|
||||
ttl = 15 * time.Minute
|
||||
}
|
||||
return &responseStore{
|
||||
ttl: ttl,
|
||||
items: make(map[string]storedResponse),
|
||||
}
|
||||
}
|
||||
|
||||
// responseStoreKey joins owner and response id with a NUL separator (which
// cannot appear in either part) into a collision-free composite map key.
func responseStoreKey(owner, id string) string {
	const sep = "\x00"
	return owner + sep + id
}
|
||||
|
||||
func responseStoreOwner(a *auth.RequestAuth) string {
|
||||
if a == nil {
|
||||
return ""
|
||||
}
|
||||
return a.CallerID
|
||||
}
|
||||
|
||||
func (s *responseStore) put(owner, id string, value map[string]any) {
|
||||
if s == nil || owner == "" || id == "" || value == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sweepLocked(now)
|
||||
s.items[responseStoreKey(owner, id)] = storedResponse{
|
||||
Owner: owner,
|
||||
Value: cloneAnyMap(value),
|
||||
ExpiresAt: now.Add(s.ttl),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responseStore) get(owner, id string) (map[string]any, bool) {
|
||||
if s == nil || owner == "" || id == "" {
|
||||
return nil, false
|
||||
}
|
||||
now := time.Now()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sweepLocked(now)
|
||||
item, ok := s.items[responseStoreKey(owner, id)]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if item.Owner != owner {
|
||||
return nil, false
|
||||
}
|
||||
return cloneAnyMap(item.Value), true
|
||||
}
|
||||
|
||||
func (s *responseStore) sweepLocked(now time.Time) {
|
||||
for k, v := range s.items {
|
||||
if now.After(v.ExpiresAt) {
|
||||
delete(s.items, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cloneAnyMap returns a shallow copy of in, preserving nil-ness so callers
// can distinguish "absent" from "empty".
func cloneAnyMap(in map[string]any) map[string]any {
	if in == nil {
		return nil
	}
	dup := make(map[string]any, len(in))
	for k := range in {
		dup[k] = in[k]
	}
	return dup
}
|
||||
|
||||
func (h *Handler) getResponseStore() *responseStore {
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
h.responsesMu.Lock()
|
||||
defer h.responsesMu.Unlock()
|
||||
if h.responses == nil {
|
||||
ttl := 15 * time.Minute
|
||||
if h.Store != nil {
|
||||
ttl = time.Duration(h.Store.ResponsesStoreTTLSeconds()) * time.Second
|
||||
}
|
||||
h.responses = newResponseStore(ttl)
|
||||
}
|
||||
return h.responses
|
||||
}
|
||||
197
internal/adapter/openai/responses_embeddings_test.go
Normal file
197
internal/adapter/openai/responses_embeddings_test.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Verifies that a bare string input becomes a single user message.
func TestNormalizeResponsesInputAsMessagesString(t *testing.T) {
	msgs := normalizeResponsesInputAsMessages("hello")
	if len(msgs) != 1 {
		t.Fatalf("expected one message, got %d", len(msgs))
	}
	m, _ := msgs[0].(map[string]any)
	if m["role"] != "user" || m["content"] != "hello" {
		t.Fatalf("unexpected message: %#v", m)
	}
}

// Verifies that "instructions" is prepended as a system message.
func TestResponsesMessagesFromRequestWithInstructions(t *testing.T) {
	req := map[string]any{
		"model":        "gpt-4.1",
		"input":        "ping",
		"instructions": "system text",
	}
	msgs := responsesMessagesFromRequest(req)
	if len(msgs) != 2 {
		t.Fatalf("expected two messages, got %d", len(msgs))
	}
	sys, _ := msgs[0].(map[string]any)
	if sys["role"] != "system" {
		t.Fatalf("unexpected first message: %#v", sys)
	}
}

// Verifies an object input with role + content blocks keeps all block text.
func TestNormalizeResponsesInputAsMessagesObjectRoleContentBlocks(t *testing.T) {
	msgs := normalizeResponsesInputAsMessages(map[string]any{
		"role": "user",
		"content": []any{
			map[string]any{"type": "input_text", "text": "line-1"},
			map[string]any{"type": "input_text", "text": "line-2"},
		},
	})
	if len(msgs) != 1 {
		t.Fatalf("expected one message, got %d", len(msgs))
	}
	m, _ := msgs[0].(map[string]any)
	if m["role"] != "user" {
		t.Fatalf("unexpected role: %#v", m)
	}
	if strings.TrimSpace(normalizeOpenAIContentForPrompt(m["content"])) != "line-1\nline-2" {
		t.Fatalf("unexpected content: %#v", m["content"])
	}
}

// Verifies function_call_output items become tool messages with call_id.
func TestNormalizeResponsesInputAsMessagesFunctionCallOutput(t *testing.T) {
	msgs := normalizeResponsesInputAsMessages([]any{
		map[string]any{
			"type":    "function_call_output",
			"call_id": "call_123",
			"output":  map[string]any{"ok": true},
		},
	})
	if len(msgs) != 1 {
		t.Fatalf("expected one message, got %d", len(msgs))
	}
	m, _ := msgs[0].(map[string]any)
	if m["role"] != "tool" {
		t.Fatalf("expected tool role, got %#v", m)
	}
	if m["tool_call_id"] != "call_123" {
		t.Fatalf("expected tool_call_id propagated, got %#v", m)
	}
}

// Verifies that a tool result's name is backfilled from the matching
// function_call item that shares its call_id.
func TestNormalizeResponsesInputAsMessagesBackfillsToolResultNameFromCallID(t *testing.T) {
	msgs := normalizeResponsesInputAsMessages([]any{
		map[string]any{
			"type":      "function_call",
			"call_id":   "call_999",
			"name":      "search",
			"arguments": `{"q":"golang"}`,
		},
		map[string]any{
			"type":    "function_call_output",
			"call_id": "call_999",
			"output":  map[string]any{"ok": true},
		},
	})
	if len(msgs) != 2 {
		t.Fatalf("expected two messages, got %d", len(msgs))
	}
	toolMsg, _ := msgs[1].(map[string]any)
	if toolMsg["role"] != "tool" {
		t.Fatalf("expected tool role, got %#v", toolMsg)
	}
	if toolMsg["name"] != "search" {
		t.Fatalf("expected tool name backfilled from call_id, got %#v", toolMsg["name"])
	}
}

// Verifies function_call items become assistant tool_calls with id, type,
// name, and arguments preserved.
func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) {
	msgs := normalizeResponsesInputAsMessages([]any{
		map[string]any{
			"type":      "function_call",
			"call_id":   "call_456",
			"name":      "search",
			"arguments": `{"q":"golang"}`,
		},
	})
	if len(msgs) != 1 {
		t.Fatalf("expected one message, got %d", len(msgs))
	}
	m, _ := msgs[0].(map[string]any)
	if m["role"] != "assistant" {
		t.Fatalf("expected assistant role, got %#v", m["role"])
	}
	toolCalls, _ := m["tool_calls"].([]any)
	if len(toolCalls) != 1 {
		t.Fatalf("expected one tool_call, got %#v", m["tool_calls"])
	}
	call, _ := toolCalls[0].(map[string]any)
	if call["id"] != "call_456" {
		t.Fatalf("expected call id preserved, got %#v", call)
	}
	if call["type"] != "function" {
		t.Fatalf("expected function type, got %#v", call)
	}
	fn, _ := call["function"].(map[string]any)
	if fn["name"] != "search" {
		t.Fatalf("expected call name preserved, got %#v", call)
	}
	if fn["arguments"] != `{"q":"golang"}` {
		t.Fatalf("expected call arguments preserved, got %#v", call)
	}
}

// Verifies concatenated (malformed) arguments survive verbatim.
func TestNormalizeResponsesInputAsMessagesFunctionCallItemPreservesConcatenatedArguments(t *testing.T) {
	msgs := normalizeResponsesInputAsMessages([]any{
		map[string]any{
			"type":      "function_call",
			"call_id":   "call_456",
			"name":      "search",
			"arguments": `{}{"q":"golang"}`,
		},
	})
	if len(msgs) != 1 {
		t.Fatalf("expected one message, got %d", len(msgs))
	}
	m, _ := msgs[0].(map[string]any)
	toolCalls, _ := m["tool_calls"].([]any)
	call, _ := toolCalls[0].(map[string]any)
	fn, _ := call["function"].(map[string]any)
	if fn["arguments"] != `{}{"q":"golang"}` {
		t.Fatalf("expected original concatenated call arguments preserved, got %#v", fn["arguments"])
	}
}

// Verifies string-array embedding input extraction.
func TestExtractEmbeddingInputs(t *testing.T) {
	got := extractEmbeddingInputs([]any{"a", "b"})
	if len(got) != 2 || got[0] != "a" || got[1] != "b" {
		t.Fatalf("unexpected inputs: %#v", got)
	}
}

// Verifies the deterministic embedding is 64-dimensional and stable across
// calls for the same input.
func TestDeterministicEmbeddingStable(t *testing.T) {
	a := deterministicEmbedding("hello")
	b := deterministicEmbedding("hello")
	if len(a) != 64 || len(b) != 64 {
		t.Fatalf("expected 64 dims, got %d and %d", len(a), len(b))
	}
	for i := range a {
		if a[i] != b[i] {
			t.Fatalf("expected stable embedding at %d: %v != %v", i, a[i], b[i])
		}
	}
}

// Verifies basic put/get round-trip through the response store.
func TestResponseStorePutGet(t *testing.T) {
	st := newResponseStore(100 * time.Millisecond)
	st.put("owner_1", "resp_1", map[string]any{"id": "resp_1"})
	got, ok := st.get("owner_1", "resp_1")
	if !ok {
		t.Fatal("expected stored response")
	}
	if got["id"] != "resp_1" {
		t.Fatalf("unexpected response payload: %#v", got)
	}
}

// Verifies one owner cannot read another owner's stored response.
func TestResponseStoreTenantIsolation(t *testing.T) {
	st := newResponseStore(100 * time.Millisecond)
	st.put("owner_a", "resp_1", map[string]any{"id": "resp_1"})
	if _, ok := st.get("owner_b", "resp_1"); ok {
		t.Fatal("expected owner_b to be isolated from owner_a response")
	}
}
|
||||
217
internal/adapter/openai/responses_handler.go
Normal file
217
internal/adapter/openai/responses_handler.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) {
|
||||
a, err := h.Auth.DetermineCaller(r)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
id := strings.TrimSpace(chi.URLParam(r, "response_id"))
|
||||
if id == "" {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "response_id is required.")
|
||||
return
|
||||
}
|
||||
owner := responseStoreOwner(a)
|
||||
if owner == "" {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "unauthorized")
|
||||
return
|
||||
}
|
||||
st := h.getResponseStore()
|
||||
item, ok := st.get(owner, id)
|
||||
if !ok {
|
||||
writeOpenAIError(w, http.StatusNotFound, "Response not found.")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, item)
|
||||
}
|
||||
|
||||
func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
status := http.StatusUnauthorized
|
||||
detail := err.Error()
|
||||
if err == auth.ErrNoAccount {
|
||||
status = http.StatusTooManyRequests
|
||||
}
|
||||
writeOpenAIError(w, status, detail)
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
r = r.WithContext(auth.WithAuth(r.Context(), a))
|
||||
owner := responseStoreOwner(a)
|
||||
if owner == "" {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
traceID := requestTraceID(r)
|
||||
stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req, traceID)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
if a.UseConfigToken {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
|
||||
} else {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
|
||||
}
|
||||
return
|
||||
}
|
||||
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
|
||||
return
|
||||
}
|
||||
payload := stdReq.CompletionPayload(sessionID)
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
|
||||
return
|
||||
}
|
||||
|
||||
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
if stdReq.Stream {
|
||||
h.handleResponsesStream(w, r, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID)
|
||||
return
|
||||
}
|
||||
h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames, stdReq.ToolChoice, traceID)
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string, toolChoice util.ToolChoicePolicy, traceID string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeOpenAIError(w, resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
return
|
||||
}
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
sanitizedText := sanitizeLeakedToolHistory(result.Text)
|
||||
textParsed := util.ParseStandaloneToolCallsDetailed(sanitizedText, toolNames)
|
||||
logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text")
|
||||
|
||||
callCount := len(textParsed.Calls)
|
||||
if toolChoice.IsRequired() && callCount == 0 {
|
||||
writeOpenAIErrorWithCode(w, http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation")
|
||||
return
|
||||
}
|
||||
|
||||
responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, result.Thinking, sanitizedText, toolNames)
|
||||
h.getResponseStore().put(owner, responseID, responseObj)
|
||||
writeJSON(w, http.StatusOK, responseObj)
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice util.ToolChoicePolicy, traceID string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeOpenAIError(w, resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
rc := http.NewResponseController(w)
|
||||
_, canFlush := w.(http.Flusher)
|
||||
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
}
|
||||
bufferToolContent := len(toolNames) > 0
|
||||
emitEarlyToolDeltas := h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence()
|
||||
|
||||
streamRuntime := newResponsesStreamRuntime(
|
||||
w,
|
||||
rc,
|
||||
canFlush,
|
||||
responseID,
|
||||
model,
|
||||
finalPrompt,
|
||||
thinkingEnabled,
|
||||
searchEnabled,
|
||||
toolNames,
|
||||
bufferToolContent,
|
||||
emitEarlyToolDeltas,
|
||||
toolChoice,
|
||||
traceID,
|
||||
func(obj map[string]any) {
|
||||
h.getResponseStore().put(owner, responseID, obj)
|
||||
},
|
||||
)
|
||||
streamRuntime.sendCreated()
|
||||
|
||||
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
|
||||
Context: r.Context(),
|
||||
Body: resp.Body,
|
||||
ThinkingEnabled: thinkingEnabled,
|
||||
InitialType: initialType,
|
||||
KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
|
||||
IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
|
||||
MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
|
||||
}, streamengine.ConsumeHooks{
|
||||
OnParsed: streamRuntime.onParsed,
|
||||
OnFinalize: func(_ streamengine.StopReason, _ error) {
|
||||
streamRuntime.finalize()
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func logResponsesToolPolicyRejection(traceID string, policy util.ToolChoicePolicy, parsed util.ToolCallParseResult, channel string) {
|
||||
rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames)
|
||||
if !parsed.RejectedByPolicy || len(rejected) == 0 {
|
||||
return
|
||||
}
|
||||
config.Logger.Warn(
|
||||
"[responses] rejected tool calls by policy",
|
||||
"trace_id", strings.TrimSpace(traceID),
|
||||
"channel", channel,
|
||||
"tool_choice_mode", policy.Mode,
|
||||
"rejected_tool_names", strings.Join(rejected, ","),
|
||||
)
|
||||
}
|
||||
|
||||
func filteredRejectedToolNamesForLog(names []string) []string {
|
||||
if len(names) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(names))
|
||||
for _, name := range names {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
switch strings.ToLower(trimmed) {
|
||||
case "", "tool_name":
|
||||
continue
|
||||
default:
|
||||
out = append(out, trimmed)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
236
internal/adapter/openai/responses_input_items.go
Normal file
236
internal/adapter/openai/responses_input_items.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func normalizeResponsesInputItem(m map[string]any) map[string]any {
|
||||
return normalizeResponsesInputItemWithState(m, nil)
|
||||
}
|
||||
|
||||
func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[string]string) map[string]any {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
|
||||
if role != "" {
|
||||
if role == "assistant" {
|
||||
out := map[string]any{
|
||||
"role": "assistant",
|
||||
}
|
||||
if toolCalls, ok := m["tool_calls"].([]any); ok && len(toolCalls) > 0 {
|
||||
out["tool_calls"] = toolCalls
|
||||
}
|
||||
content := m["content"]
|
||||
if content == nil {
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
content = txt
|
||||
}
|
||||
}
|
||||
if content != nil {
|
||||
out["content"] = content
|
||||
}
|
||||
if _, hasToolCalls := out["tool_calls"]; hasToolCalls || out["content"] != nil {
|
||||
return out
|
||||
}
|
||||
return nil
|
||||
}
|
||||
content := m["content"]
|
||||
if content == nil {
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
content = txt
|
||||
}
|
||||
}
|
||||
if content == nil {
|
||||
return nil
|
||||
}
|
||||
out := map[string]any{
|
||||
"role": normalizeOpenAIRoleForPrompt(role),
|
||||
"content": content,
|
||||
}
|
||||
if role == "tool" || role == "function" {
|
||||
if callID := strings.TrimSpace(asString(m["tool_call_id"])); callID != "" {
|
||||
out["tool_call_id"] = callID
|
||||
}
|
||||
if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
|
||||
out["tool_call_id"] = callID
|
||||
}
|
||||
if name := strings.TrimSpace(asString(m["name"])); name != "" {
|
||||
out["name"] = name
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
itemType := strings.ToLower(strings.TrimSpace(asString(m["type"])))
|
||||
switch itemType {
|
||||
case "message", "input_message":
|
||||
content := m["content"]
|
||||
if content == nil {
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
content = txt
|
||||
}
|
||||
}
|
||||
if content == nil {
|
||||
return nil
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
return map[string]any{
|
||||
"role": normalizeOpenAIRoleForPrompt(role),
|
||||
"content": content,
|
||||
}
|
||||
case "function_call_output", "tool_result":
|
||||
content := m["output"]
|
||||
if content == nil {
|
||||
content = m["content"]
|
||||
}
|
||||
if content == nil {
|
||||
content = ""
|
||||
}
|
||||
out := map[string]any{
|
||||
"role": "tool",
|
||||
"content": content,
|
||||
}
|
||||
if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
|
||||
out["tool_call_id"] = callID
|
||||
} else if callID = strings.TrimSpace(asString(m["tool_call_id"])); callID != "" {
|
||||
out["tool_call_id"] = callID
|
||||
}
|
||||
if name := strings.TrimSpace(asString(m["name"])); name != "" {
|
||||
out["name"] = name
|
||||
} else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" {
|
||||
out["name"] = name
|
||||
} else if callID := strings.TrimSpace(asString(out["tool_call_id"])); callID != "" {
|
||||
if inferred := strings.TrimSpace(callNameByID[callID]); inferred != "" {
|
||||
out["name"] = inferred
|
||||
} else {
|
||||
config.Logger.Warn(
|
||||
"[responses] unable to backfill tool result name from call_id",
|
||||
"call_id", callID,
|
||||
)
|
||||
}
|
||||
}
|
||||
return out
|
||||
case "function_call", "tool_call":
|
||||
name := strings.TrimSpace(asString(m["name"]))
|
||||
var fn map[string]any
|
||||
if rawFn, ok := m["function"].(map[string]any); ok {
|
||||
fn = rawFn
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(asString(fn["name"]))
|
||||
}
|
||||
}
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var argsRaw any
|
||||
if v, ok := m["arguments"]; ok {
|
||||
argsRaw = v
|
||||
} else if v, ok := m["input"]; ok {
|
||||
argsRaw = v
|
||||
}
|
||||
if argsRaw == nil && fn != nil {
|
||||
if v, ok := fn["arguments"]; ok {
|
||||
argsRaw = v
|
||||
} else if v, ok := fn["input"]; ok {
|
||||
argsRaw = v
|
||||
}
|
||||
}
|
||||
|
||||
functionPayload := map[string]any{
|
||||
"name": name,
|
||||
"arguments": stringifyToolCallArguments(argsRaw),
|
||||
}
|
||||
call := map[string]any{
|
||||
"type": "function",
|
||||
"function": functionPayload,
|
||||
}
|
||||
if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
|
||||
call["id"] = callID
|
||||
} else if callID = strings.TrimSpace(asString(m["id"])); callID != "" {
|
||||
call["id"] = callID
|
||||
}
|
||||
if callID := strings.TrimSpace(asString(call["id"])); callID != "" && callNameByID != nil {
|
||||
callNameByID[callID] = name
|
||||
}
|
||||
return map[string]any{
|
||||
"role": "assistant",
|
||||
"tool_calls": []any{call},
|
||||
}
|
||||
case "input_text":
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
return map[string]any{
|
||||
"role": "user",
|
||||
"content": txt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
return map[string]any{
|
||||
"role": "user",
|
||||
"content": txt,
|
||||
}
|
||||
}
|
||||
if content, ok := m["content"]; ok {
|
||||
if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
|
||||
return map[string]any{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeResponsesFallbackPart(m map[string]any) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") {
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
return txt
|
||||
}
|
||||
}
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
return txt
|
||||
}
|
||||
if content, ok := m["content"]; ok {
|
||||
if normalized := strings.TrimSpace(normalizeOpenAIContentForPrompt(content)); normalized != "" {
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", m))
|
||||
}
|
||||
|
||||
func stringifyToolCallArguments(v any) string {
|
||||
switch x := v.(type) {
|
||||
case nil:
|
||||
return "{}"
|
||||
case string:
|
||||
s := strings.TrimSpace(x)
|
||||
if s == "" {
|
||||
return "{}"
|
||||
}
|
||||
s = normalizeToolArgumentString(s)
|
||||
if s == "" {
|
||||
return "{}"
|
||||
}
|
||||
return s
|
||||
default:
|
||||
b, err := json.Marshal(x)
|
||||
if err != nil || len(b) == 0 {
|
||||
return "{}"
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
}
|
||||
94
internal/adapter/openai/responses_input_normalize.go
Normal file
94
internal/adapter/openai/responses_input_normalize.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func responsesMessagesFromRequest(req map[string]any) []any {
|
||||
if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 {
|
||||
return prependInstructionMessage(msgs, req["instructions"])
|
||||
}
|
||||
if rawInput, ok := req["input"]; ok {
|
||||
if msgs := normalizeResponsesInputAsMessages(rawInput); len(msgs) > 0 {
|
||||
return prependInstructionMessage(msgs, req["instructions"])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func prependInstructionMessage(messages []any, instructions any) []any {
|
||||
sys, _ := instructions.(string)
|
||||
sys = strings.TrimSpace(sys)
|
||||
if sys == "" {
|
||||
return messages
|
||||
}
|
||||
out := make([]any, 0, len(messages)+1)
|
||||
out = append(out, map[string]any{"role": "system", "content": sys})
|
||||
out = append(out, messages...)
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeResponsesInputAsMessages(input any) []any {
|
||||
switch v := input.(type) {
|
||||
case string:
|
||||
if strings.TrimSpace(v) == "" {
|
||||
return nil
|
||||
}
|
||||
return []any{map[string]any{"role": "user", "content": v}}
|
||||
case []any:
|
||||
return normalizeResponsesInputArray(v)
|
||||
case map[string]any:
|
||||
if msg := normalizeResponsesInputItem(v); msg != nil {
|
||||
return []any{msg}
|
||||
}
|
||||
if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
return []any{map[string]any{"role": "user", "content": txt}}
|
||||
}
|
||||
if content, ok := v["content"]; ok {
|
||||
if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
|
||||
return []any{map[string]any{"role": "user", "content": content}}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeResponsesInputArray(items []any) []any {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]any, 0, len(items))
|
||||
callNameByID := map[string]string{}
|
||||
fallbackParts := make([]string, 0, len(items))
|
||||
flushFallback := func() {
|
||||
if len(fallbackParts) == 0 {
|
||||
return
|
||||
}
|
||||
out = append(out, map[string]any{"role": "user", "content": strings.Join(fallbackParts, "\n")})
|
||||
fallbackParts = fallbackParts[:0]
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
switch x := item.(type) {
|
||||
case map[string]any:
|
||||
if msg := normalizeResponsesInputItemWithState(x, callNameByID); msg != nil {
|
||||
flushFallback()
|
||||
out = append(out, msg)
|
||||
continue
|
||||
}
|
||||
if s := normalizeResponsesFallbackPart(x); s != "" {
|
||||
fallbackParts = append(fallbackParts, s)
|
||||
}
|
||||
default:
|
||||
if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" {
|
||||
fallbackParts = append(fallbackParts, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
flushFallback()
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
176
internal/adapter/openai/responses_route_test.go
Normal file
176
internal/adapter/openai/responses_route_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/account"
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func newDirectTokenResolver(t *testing.T) (*config.Store, *auth.Resolver) {
|
||||
t.Helper()
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`)
|
||||
store := config.LoadStore()
|
||||
pool := account.NewPool(store)
|
||||
resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
|
||||
return "unused", nil
|
||||
})
|
||||
return store, resolver
|
||||
}
|
||||
|
||||
func newManagedKeyResolver(t *testing.T) (*config.Store, *auth.Resolver) {
|
||||
t.Helper()
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["managed-key"],
|
||||
"accounts":[{"email":"acc@example.com","password":"pwd","token":"account-token"}]
|
||||
}`)
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1")
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "0")
|
||||
store := config.LoadStore()
|
||||
pool := account.NewPool(store)
|
||||
resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
|
||||
return "unused", nil
|
||||
})
|
||||
return store, resolver
|
||||
}
|
||||
|
||||
func authForToken(t *testing.T, resolver *auth.Resolver, token string) *auth.RequestAuth {
|
||||
t.Helper()
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
a, err := resolver.Determine(req)
|
||||
if err != nil {
|
||||
t.Fatalf("determine auth failed: %v", err)
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func TestGetResponseByIDRequiresAuthAndIsTenantIsolated(t *testing.T) {
|
||||
store, resolver := newDirectTokenResolver(t)
|
||||
h := &Handler{Store: store, Auth: resolver}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
ownerA := responseStoreOwner(authForToken(t, resolver, "token-a"))
|
||||
h.getResponseStore().put(ownerA, "resp_test", map[string]any{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
})
|
||||
|
||||
t.Run("unauthorized", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cross-tenant-not-found", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
|
||||
req.Header.Set("Authorization", "Bearer token-b")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Fatalf("expected 404, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("same-tenant-ok", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
|
||||
req.Header.Set("Authorization", "Bearer token-a")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("decode body failed: %v", err)
|
||||
}
|
||||
if body["id"] != "resp_test" {
|
||||
t.Fatalf("unexpected body: %#v", body)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestResponsesRouteValidationContract(t *testing.T) {
|
||||
store, resolver := newDirectTokenResolver(t)
|
||||
h := &Handler{Store: store, Auth: resolver}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
}{
|
||||
{name: "missing_model", body: `{"input":"hello"}`},
|
||||
{name: "missing_input_and_messages", body: `{"model":"gpt-4o"}`},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewBufferString(tc.body))
|
||||
req.Header.Set("Authorization", "Bearer token-a")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("decode response failed: %v", err)
|
||||
}
|
||||
errObj, _ := out["error"].(map[string]any)
|
||||
if _, ok := errObj["code"]; !ok {
|
||||
t.Fatalf("expected error.code: %#v", out)
|
||||
}
|
||||
if _, ok := errObj["param"]; !ok {
|
||||
t.Fatalf("expected error.param: %#v", out)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetResponseByIDManagedKeySkipsAccountPoolPressure(t *testing.T) {
|
||||
store, resolver := newManagedKeyResolver(t)
|
||||
h := &Handler{Store: store, Auth: resolver}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
ownerReq := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
|
||||
ownerReq.Header.Set("Authorization", "Bearer managed-key")
|
||||
ownerAuth, err := resolver.DetermineCaller(ownerReq)
|
||||
if err != nil {
|
||||
t.Fatalf("determine caller failed: %v", err)
|
||||
}
|
||||
owner := responseStoreOwner(ownerAuth)
|
||||
h.getResponseStore().put(owner, "resp_test", map[string]any{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
})
|
||||
|
||||
occupyReq := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
occupyReq.Header.Set("Authorization", "Bearer managed-key")
|
||||
occupied, err := resolver.Determine(occupyReq)
|
||||
if err != nil {
|
||||
t.Fatalf("expected first acquire to succeed: %v", err)
|
||||
}
|
||||
defer resolver.Release(occupied)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
|
||||
req.Header.Set("Authorization", "Bearer managed-key")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 under pool pressure, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
}
|
||||
210
internal/adapter/openai/responses_stream_runtime_core.go
Normal file
210
internal/adapter/openai/responses_stream_runtime_core.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
type responsesStreamRuntime struct {
|
||||
w http.ResponseWriter
|
||||
rc *http.ResponseController
|
||||
canFlush bool
|
||||
|
||||
responseID string
|
||||
model string
|
||||
finalPrompt string
|
||||
toolNames []string
|
||||
traceID string
|
||||
toolChoice util.ToolChoicePolicy
|
||||
|
||||
thinkingEnabled bool
|
||||
searchEnabled bool
|
||||
|
||||
bufferToolContent bool
|
||||
emitEarlyToolDeltas bool
|
||||
toolCallsEmitted bool
|
||||
toolCallsDoneEmitted bool
|
||||
|
||||
sieve toolStreamSieveState
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
visibleText strings.Builder
|
||||
streamToolCallIDs map[int]string
|
||||
functionItemIDs map[int]string
|
||||
functionOutputIDs map[int]int
|
||||
functionArgs map[int]string
|
||||
functionDone map[int]bool
|
||||
functionAdded map[int]bool
|
||||
functionNames map[int]string
|
||||
messageItemID string
|
||||
messageOutputID int
|
||||
nextOutputID int
|
||||
messageAdded bool
|
||||
messagePartAdded bool
|
||||
sequence int
|
||||
failed bool
|
||||
|
||||
persistResponse func(obj map[string]any)
|
||||
}
|
||||
|
||||
func newResponsesStreamRuntime(
|
||||
w http.ResponseWriter,
|
||||
rc *http.ResponseController,
|
||||
canFlush bool,
|
||||
responseID string,
|
||||
model string,
|
||||
finalPrompt string,
|
||||
thinkingEnabled bool,
|
||||
searchEnabled bool,
|
||||
toolNames []string,
|
||||
bufferToolContent bool,
|
||||
emitEarlyToolDeltas bool,
|
||||
toolChoice util.ToolChoicePolicy,
|
||||
traceID string,
|
||||
persistResponse func(obj map[string]any),
|
||||
) *responsesStreamRuntime {
|
||||
return &responsesStreamRuntime{
|
||||
w: w,
|
||||
rc: rc,
|
||||
canFlush: canFlush,
|
||||
responseID: responseID,
|
||||
model: model,
|
||||
finalPrompt: finalPrompt,
|
||||
thinkingEnabled: thinkingEnabled,
|
||||
searchEnabled: searchEnabled,
|
||||
toolNames: toolNames,
|
||||
bufferToolContent: bufferToolContent,
|
||||
emitEarlyToolDeltas: emitEarlyToolDeltas,
|
||||
streamToolCallIDs: map[int]string{},
|
||||
functionItemIDs: map[int]string{},
|
||||
functionOutputIDs: map[int]int{},
|
||||
functionArgs: map[int]string{},
|
||||
functionDone: map[int]bool{},
|
||||
functionAdded: map[int]bool{},
|
||||
functionNames: map[int]string{},
|
||||
messageOutputID: -1,
|
||||
toolChoice: toolChoice,
|
||||
traceID: traceID,
|
||||
persistResponse: persistResponse,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) finalize() {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := sanitizeLeakedToolHistory(s.text.String())
|
||||
|
||||
if s.bufferToolContent {
|
||||
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
|
||||
}
|
||||
|
||||
textParsed := util.ParseStandaloneToolCallsDetailed(finalText, s.toolNames)
|
||||
detected := textParsed.Calls
|
||||
s.logToolPolicyRejections(textParsed)
|
||||
|
||||
if len(detected) > 0 {
|
||||
s.toolCallsEmitted = true
|
||||
if !s.toolCallsDoneEmitted {
|
||||
s.emitFunctionCallDoneEvents(detected)
|
||||
}
|
||||
}
|
||||
|
||||
s.closeMessageItem()
|
||||
|
||||
if s.toolChoice.IsRequired() && len(detected) == 0 {
|
||||
s.failed = true
|
||||
message := "tool_choice requires at least one valid tool call."
|
||||
failedResp := map[string]any{
|
||||
"id": s.responseID,
|
||||
"type": "response",
|
||||
"object": "response",
|
||||
"model": s.model,
|
||||
"status": "failed",
|
||||
"output": []any{},
|
||||
"output_text": "",
|
||||
"error": map[string]any{
|
||||
"message": message,
|
||||
"type": "invalid_request_error",
|
||||
"code": "tool_choice_violation",
|
||||
"param": nil,
|
||||
},
|
||||
}
|
||||
if s.persistResponse != nil {
|
||||
s.persistResponse(failedResp)
|
||||
}
|
||||
s.sendEvent("response.failed", openaifmt.BuildResponsesFailedPayload(s.responseID, s.model, message, "tool_choice_violation"))
|
||||
s.sendDone()
|
||||
return
|
||||
}
|
||||
s.closeIncompleteFunctionItems()
|
||||
|
||||
obj := s.buildCompletedResponseObject(finalThinking, finalText, detected)
|
||||
if s.persistResponse != nil {
|
||||
s.persistResponse(obj)
|
||||
}
|
||||
s.sendEvent("response.completed", openaifmt.BuildResponsesCompletedPayload(obj))
|
||||
s.sendDone()
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed util.ToolCallParseResult) {
|
||||
logRejected := func(parsed util.ToolCallParseResult, channel string) {
|
||||
rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames)
|
||||
if !parsed.RejectedByPolicy || len(rejected) == 0 {
|
||||
return
|
||||
}
|
||||
config.Logger.Warn(
|
||||
"[responses] rejected tool calls by policy",
|
||||
"trace_id", strings.TrimSpace(s.traceID),
|
||||
"channel", channel,
|
||||
"tool_choice_mode", s.toolChoice.Mode,
|
||||
"rejected_tool_names", strings.Join(rejected, ","),
|
||||
)
|
||||
}
|
||||
logRejected(textParsed, "text")
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
}
|
||||
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
|
||||
return streamengine.ParsedDecision{Stop: true}
|
||||
}
|
||||
|
||||
contentSeen := false
|
||||
for _, p := range parsed.Parts {
|
||||
if p.Text == "" {
|
||||
continue
|
||||
}
|
||||
if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
|
||||
continue
|
||||
}
|
||||
contentSeen = true
|
||||
if p.Type == "thinking" {
|
||||
if !s.thinkingEnabled {
|
||||
continue
|
||||
}
|
||||
s.thinking.WriteString(p.Text)
|
||||
s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
|
||||
continue
|
||||
}
|
||||
|
||||
cleanedText := sanitizeLeakedToolHistory(p.Text)
|
||||
if cleanedText == "" {
|
||||
continue
|
||||
}
|
||||
s.text.WriteString(cleanedText)
|
||||
if !s.bufferToolContent {
|
||||
s.emitTextDelta(cleanedText)
|
||||
continue
|
||||
}
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, cleanedText, s.toolNames), true)
|
||||
}
|
||||
|
||||
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||
}
|
||||
61
internal/adapter/openai/responses_stream_runtime_events.go
Normal file
61
internal/adapter/openai/responses_stream_runtime_events.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
)
|
||||
|
||||
func (s *responsesStreamRuntime) nextSequence() int {
|
||||
s.sequence++
|
||||
return s.sequence
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) sendEvent(event string, payload map[string]any) {
|
||||
if payload == nil {
|
||||
payload = map[string]any{}
|
||||
}
|
||||
if _, ok := payload["sequence_number"]; !ok {
|
||||
payload["sequence_number"] = s.nextSequence()
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
_, _ = s.w.Write([]byte("event: " + event + "\n"))
|
||||
_, _ = s.w.Write([]byte("data: "))
|
||||
_, _ = s.w.Write(b)
|
||||
_, _ = s.w.Write([]byte("\n\n"))
|
||||
if s.canFlush {
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) sendCreated() {
|
||||
s.sendEvent("response.created", openaifmt.BuildResponsesCreatedPayload(s.responseID, s.model))
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) sendDone() {
|
||||
_, _ = s.w.Write([]byte("data: [DONE]\n\n"))
|
||||
if s.canFlush {
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
|
||||
for _, evt := range events {
|
||||
if emitContent && evt.Content != "" {
|
||||
s.emitTextDelta(evt.Content)
|
||||
}
|
||||
if len(evt.ToolCallDeltas) > 0 {
|
||||
if !s.emitEarlyToolDeltas {
|
||||
continue
|
||||
}
|
||||
filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.toolNames, s.functionNames)
|
||||
if len(filtered) == 0 {
|
||||
continue
|
||||
}
|
||||
s.emitFunctionCallDeltaEvents(filtered)
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
s.emitFunctionCallDoneEvents(evt.ToolCalls)
|
||||
}
|
||||
}
|
||||
}
|
||||
245
internal/adapter/openai/responses_stream_runtime_toolcalls.go
Normal file
245
internal/adapter/openai/responses_stream_runtime_toolcalls.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/util"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func (s *responsesStreamRuntime) allocateOutputIndex() int {
|
||||
idx := s.nextOutputID
|
||||
s.nextOutputID++
|
||||
return idx
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureMessageItemID() string {
|
||||
if strings.TrimSpace(s.messageItemID) != "" {
|
||||
return s.messageItemID
|
||||
}
|
||||
s.messageItemID = "msg_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
return s.messageItemID
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureMessageOutputIndex() int {
|
||||
if s.messageOutputID >= 0 {
|
||||
return s.messageOutputID
|
||||
}
|
||||
s.messageOutputID = s.allocateOutputIndex()
|
||||
return s.messageOutputID
|
||||
}
|
||||
|
||||
// ensureMessageItemAdded emits response.output_item.added for the assistant
// message item exactly once per stream; later calls are no-ops. The item is
// announced with status "in_progress" and is closed by closeMessageItem.
func (s *responsesStreamRuntime) ensureMessageItemAdded() {
	if s.messageAdded {
		return
	}
	itemID := s.ensureMessageItemID()
	item := map[string]any{
		"id":     itemID,
		"type":   "message",
		"role":   "assistant",
		"status": "in_progress",
	}
	s.sendEvent(
		"response.output_item.added",
		openaifmt.BuildResponsesOutputItemAddedPayload(s.responseID, itemID, s.ensureMessageOutputIndex(), item),
	)
	s.messageAdded = true
}
|
||||
|
||||
// ensureMessageContentPartAdded emits response.content_part.added (an empty
// output_text part at content index 0) exactly once, making sure the parent
// message item has been announced first so event order matches the spec.
func (s *responsesStreamRuntime) ensureMessageContentPartAdded() {
	if s.messagePartAdded {
		return
	}
	s.ensureMessageItemAdded()
	s.sendEvent(
		"response.content_part.added",
		openaifmt.BuildResponsesContentPartAddedPayload(
			s.responseID,
			s.ensureMessageItemID(),
			s.ensureMessageOutputIndex(),
			0,
			map[string]any{"type": "output_text", "text": ""},
		),
	)
	s.messagePartAdded = true
}
|
||||
|
||||
// emitTextDelta streams one response.output_text.delta chunk. Whitespace-only
// chunks are dropped entirely; the first real chunk lazily opens the message
// item and content part. Every emitted chunk is also accumulated in
// visibleText so the final response object can report the full text.
func (s *responsesStreamRuntime) emitTextDelta(content string) {
	if strings.TrimSpace(content) == "" {
		return
	}
	s.ensureMessageContentPartAdded()
	s.visibleText.WriteString(content)
	s.sendEvent(
		"response.output_text.delta",
		openaifmt.BuildResponsesTextDeltaPayload(
			s.responseID,
			s.ensureMessageItemID(),
			s.ensureMessageOutputIndex(),
			0,
			content,
		),
	)
}
|
||||
|
||||
// closeMessageItem finalizes the assistant message item, if one was opened:
// it emits response.output_text.done then response.content_part.done (in
// that order, per the Responses event sequence) for the accumulated visible
// text, and finally response.output_item.done with the completed item.
// No-op when no message item was ever announced.
func (s *responsesStreamRuntime) closeMessageItem() {
	if !s.messageAdded {
		return
	}
	itemID := s.ensureMessageItemID()
	outputIndex := s.ensureMessageOutputIndex()
	text := s.visibleText.String()
	if s.messagePartAdded {
		s.sendEvent(
			"response.output_text.done",
			openaifmt.BuildResponsesTextDonePayload(
				s.responseID,
				itemID,
				outputIndex,
				0,
				text,
			),
		)
		s.sendEvent(
			"response.content_part.done",
			openaifmt.BuildResponsesContentPartDonePayload(
				s.responseID,
				itemID,
				outputIndex,
				0,
				map[string]any{"type": "output_text", "text": text},
			),
		)
		// Reset so a repeated close cannot re-emit the part-done pair.
		s.messagePartAdded = false
	}
	item := map[string]any{
		"id":     itemID,
		"type":   "message",
		"role":   "assistant",
		"status": "completed",
		"content": []map[string]any{
			{
				"type": "output_text",
				"text": text,
			},
		},
	}
	s.sendEvent(
		"response.output_item.done",
		openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, outputIndex, item),
	)
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureFunctionItemID(callIndex int) string {
|
||||
if id, ok := s.functionItemIDs[callIndex]; ok && strings.TrimSpace(id) != "" {
|
||||
return id
|
||||
}
|
||||
id := "fc_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
s.functionItemIDs[callIndex] = id
|
||||
return id
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureToolCallID(callIndex int) string {
|
||||
if id, ok := s.streamToolCallIDs[callIndex]; ok && strings.TrimSpace(id) != "" {
|
||||
return id
|
||||
}
|
||||
id := "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
s.streamToolCallIDs[callIndex] = id
|
||||
return id
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureFunctionOutputIndex(callIndex int) int {
|
||||
if idx, ok := s.functionOutputIDs[callIndex]; ok {
|
||||
return idx
|
||||
}
|
||||
idx := s.allocateOutputIndex()
|
||||
s.functionOutputIDs[callIndex] = idx
|
||||
return idx
|
||||
}
|
||||
|
||||
// ensureFunctionItemAdded records the tool name for callIndex (when given)
// and announces the function_call output item once via
// response.output_item.added. The announcement is deferred until a non-empty
// name is known, since the payload requires one; callers may invoke this
// repeatedly as the name streams in.
func (s *responsesStreamRuntime) ensureFunctionItemAdded(callIndex int, name string) {
	if strings.TrimSpace(name) != "" {
		s.functionNames[callIndex] = strings.TrimSpace(name)
	}
	if s.functionAdded[callIndex] {
		return
	}
	fnName := strings.TrimSpace(s.functionNames[callIndex])
	if fnName == "" {
		// Name not seen yet — wait for a later delta to carry it.
		return
	}
	outputIndex := s.ensureFunctionOutputIndex(callIndex)
	itemID := s.ensureFunctionItemID(callIndex)
	callID := s.ensureToolCallID(callIndex)
	item := map[string]any{
		"id":        itemID,
		"type":      "function_call",
		"call_id":   callID,
		"name":      fnName,
		"arguments": "", // arguments start empty and stream in via deltas
		"status":    "in_progress",
	}
	s.sendEvent(
		"response.output_item.added",
		openaifmt.BuildResponsesOutputItemAddedPayload(s.responseID, itemID, outputIndex, item),
	)
	s.functionAdded[callIndex] = true
	s.toolCallsEmitted = true
}
|
||||
|
||||
// emitFunctionCallDeltaEvents streams incremental
// response.function_call_arguments.delta events. Each delta first ensures
// its function_call item is announced (possibly carrying the name), then
// appends the argument fragment to the per-call accumulator before emitting.
// Deltas with blank arguments only register the name and emit nothing.
func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDelta) {
	for _, d := range deltas {
		s.ensureFunctionItemAdded(d.Index, d.Name)
		if strings.TrimSpace(d.Arguments) == "" {
			continue
		}
		s.functionArgs[d.Index] += d.Arguments
		outputIndex := s.ensureFunctionOutputIndex(d.Index)
		itemID := s.ensureFunctionItemID(d.Index)
		callID := s.ensureToolCallID(d.Index)
		s.sendEvent(
			"response.function_call_arguments.delta",
			openaifmt.BuildResponsesFunctionCallArgumentsDeltaPayload(s.responseID, itemID, outputIndex, callID, d.Arguments),
		)
	}
}
|
||||
|
||||
// emitFunctionCallDoneEvents finalizes fully parsed tool calls: for each
// named call it emits response.function_call_arguments.done followed by
// response.output_item.done with the completed item, at most once per index.
// The parsed call's canonical JSON arguments overwrite anything accumulated
// from deltas. NOTE(review): the slice position is used as the call index
// keying the shared ID maps — assumes parsed calls arrive in the same order
// as their incremental deltas; confirm against the sieve implementation.
func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedToolCall) {
	for idx, tc := range calls {
		if strings.TrimSpace(tc.Name) == "" {
			continue
		}
		s.ensureFunctionItemAdded(idx, tc.Name)
		if s.functionDone[idx] {
			continue
		}
		outputIndex := s.ensureFunctionOutputIndex(idx)
		itemID := s.ensureFunctionItemID(idx)
		callID := s.ensureToolCallID(idx)
		argsBytes, _ := json.Marshal(tc.Input)
		args := string(argsBytes)
		s.functionArgs[idx] = args
		s.sendEvent(
			"response.function_call_arguments.done",
			openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, args),
		)
		item := map[string]any{
			"id":        itemID,
			"type":      "function_call",
			"call_id":   callID,
			"name":      tc.Name,
			"arguments": args,
			"status":    "completed",
		}
		s.sendEvent(
			"response.output_item.done",
			openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, outputIndex, item),
		)
		s.functionDone[idx] = true
		s.toolCallsDoneEmitted = true
	}
}
|
||||
@@ -0,0 +1,156 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// closeIncompleteFunctionItems closes any function_call items that were
// announced but never finalized (e.g. the stream ended mid-arguments). It
// emits the arguments.done / output_item.done pair using whatever arguments
// were accumulated from deltas, defaulting to "{}" when none arrived.
// Indexes are processed in sorted order for deterministic event output;
// nameless items are skipped since a valid payload requires a name.
func (s *responsesStreamRuntime) closeIncompleteFunctionItems() {
	if len(s.functionAdded) == 0 {
		return
	}
	indices := make([]int, 0, len(s.functionAdded))
	for idx, added := range s.functionAdded {
		if !added || s.functionDone[idx] {
			continue
		}
		indices = append(indices, idx)
	}
	if len(indices) == 0 {
		return
	}
	// Map iteration order is random — sort for stable event ordering.
	sort.Ints(indices)
	for _, idx := range indices {
		name := strings.TrimSpace(s.functionNames[idx])
		if name == "" {
			continue
		}
		args := strings.TrimSpace(s.functionArgs[idx])
		if args == "" {
			args = "{}"
		}
		outputIndex := s.ensureFunctionOutputIndex(idx)
		itemID := s.ensureFunctionItemID(idx)
		callID := s.ensureToolCallID(idx)
		s.sendEvent(
			"response.function_call_arguments.done",
			openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, name, args),
		)
		item := map[string]any{
			"id":        itemID,
			"type":      "function_call",
			"call_id":   callID,
			"name":      name,
			"arguments": args,
			"status":    "completed",
		}
		s.sendEvent(
			"response.output_item.done",
			openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, outputIndex, item),
		)
		s.functionDone[idx] = true
		s.toolCallsDoneEmitted = true
	}
}
|
||||
|
||||
// buildCompletedResponseObject assembles the terminal response object for
// the response.completed event. It collects the assistant message item (from
// streamed visible text, or — when nothing was streamed and no tool calls
// exist — from the final thinking/text snapshots) plus one completed
// function_call item per named parsed call, sorts them by their allocated
// output index, and delegates final shaping to BuildResponseObjectFromItems.
// output_text is suppressed when tool calls are present so raw tool JSON
// never leaks to text-only consumers.
func (s *responsesStreamRuntime) buildCompletedResponseObject(finalThinking, finalText string, calls []util.ParsedToolCall) map[string]any {
	// Pair each item with its stream output index so ordering survives
	// the map-based bookkeeping.
	type indexedItem struct {
		index int
		item  map[string]any
	}
	indexed := make([]indexedItem, 0, len(calls)+1)

	if s.messageAdded {
		// A message item was streamed: report exactly the visible text.
		text := s.visibleText.String()
		indexed = append(indexed, indexedItem{
			index: s.ensureMessageOutputIndex(),
			item: map[string]any{
				"id":     s.ensureMessageItemID(),
				"type":   "message",
				"role":   "assistant",
				"status": "completed",
				"content": []map[string]any{
					{
						"type": "output_text",
						"text": text,
					},
				},
			},
		})
	} else if len(calls) == 0 {
		// Nothing streamed and no tool calls: synthesize a message from
		// the final snapshots (reasoning first, then text).
		content := make([]map[string]any, 0, 2)
		if strings.TrimSpace(finalThinking) != "" {
			content = append(content, map[string]any{
				"type": "reasoning",
				"text": finalThinking,
			})
		}
		if strings.TrimSpace(finalText) != "" {
			content = append(content, map[string]any{
				"type": "output_text",
				"text": finalText,
			})
		}
		if len(content) > 0 {
			indexed = append(indexed, indexedItem{
				index: s.ensureMessageOutputIndex(),
				item: map[string]any{
					"id":      s.ensureMessageItemID(),
					"type":    "message",
					"role":    "assistant",
					"status":  "completed",
					"content": content,
				},
			})
		}
	}

	// One completed function_call item per named parsed call; IDs and
	// indexes reuse those minted during streaming.
	for idx, tc := range calls {
		if strings.TrimSpace(tc.Name) == "" {
			continue
		}
		argsBytes, _ := json.Marshal(tc.Input)
		indexed = append(indexed, indexedItem{
			index: s.ensureFunctionOutputIndex(idx),
			item: map[string]any{
				"id":        s.ensureFunctionItemID(idx),
				"type":      "function_call",
				"call_id":   s.ensureToolCallID(idx),
				"name":      tc.Name,
				"arguments": string(argsBytes),
				"status":    "completed",
			},
		})
	}

	// Stable sort keeps insertion order for equal indexes.
	sort.SliceStable(indexed, func(i, j int) bool {
		return indexed[i].index < indexed[j].index
	})
	output := make([]any, 0, len(indexed))
	for _, it := range indexed {
		output = append(output, it.item)
	}

	// output_text falls back to the final snapshots only for tool-free
	// responses with no streamed visible text.
	outputText := s.visibleText.String()
	if strings.TrimSpace(outputText) == "" && len(calls) == 0 {
		if strings.TrimSpace(finalText) != "" {
			outputText = finalText
		} else if strings.TrimSpace(finalThinking) != "" {
			outputText = finalThinking
		}
	}

	return openaifmt.BuildResponseObjectFromItems(
		s.responseID,
		s.model,
		s.finalPrompt,
		finalThinking,
		finalText,
		output,
		outputText,
	)
}
|
||||
677
internal/adapter/openai/responses_stream_test.go
Normal file
677
internal/adapter/openai/responses_stream_test.go
Normal file
@@ -0,0 +1,677 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted verifies
// that a pure tool_calls upstream payload yields an empty output_text, a
// structured function_call item, and no legacy "tool_calls" wrapper in the
// response.completed object.
func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T) {
	h := &Handler{}
	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
	rec := httptest.NewRecorder()

	// sseLine wraps a content chunk in the upstream SSE envelope.
	sseLine := func(v string) string {
		b, _ := json.Marshal(map[string]any{
			"p": "response/content",
			"v": v,
		})
		return "data: " + string(b) + "\n"
	}

	rawToolJSON := `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`
	streamBody := sseLine(rawToolJSON) + "data: [DONE]\n"
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(streamBody)),
	}

	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")

	completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
	if !ok {
		t.Fatalf("expected response.completed event, body=%s", rec.Body.String())
	}
	responseObj, _ := completed["response"].(map[string]any)
	outputText, _ := responseObj["output_text"].(string)
	if outputText != "" {
		t.Fatalf("expected empty output_text for tool_calls response, got output_text=%q", outputText)
	}
	output, _ := responseObj["output"].([]any)
	if len(output) == 0 {
		t.Fatalf("expected structured output entries, got %#v", responseObj["output"])
	}
	hasFunctionCall := false
	hasLegacyWrapper := false
	for _, item := range output {
		m, _ := item.(map[string]any)
		if m == nil {
			continue
		}
		if m["type"] == "function_call" {
			hasFunctionCall = true
		}
		if m["type"] == "tool_calls" {
			hasLegacyWrapper = true
		}
	}
	if !hasFunctionCall {
		t.Fatalf("expected function_call item, got %#v", responseObj["output"])
	}
	if hasLegacyWrapper {
		t.Fatalf("did not expect legacy tool_calls wrapper, got %#v", responseObj["output"])
	}
	if strings.Contains(outputText, `"tool_calls"`) {
		t.Fatalf("raw tool_calls JSON leaked in output_text: %q", outputText)
	}
}
|
||||
|
||||
// TestHandleResponsesStreamUsesOfficialOutputItemEvents verifies the stream
// uses the official response.output_item.* / function_call_arguments.*
// events (never legacy response.output_tool_call.*), that in-progress
// function_call items start with empty arguments, and that the call_id is
// consistent between the arguments.done event and the completed output.
func TestHandleResponsesStreamUsesOfficialOutputItemEvents(t *testing.T) {
	h := &Handler{}
	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
	rec := httptest.NewRecorder()

	// sseLine wraps a content chunk in the upstream SSE envelope.
	sseLine := func(v string) string {
		b, _ := json.Marshal(map[string]any{
			"p": "response/content",
			"v": v,
		})
		return "data: " + string(b) + "\n"
	}

	streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n"
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(streamBody)),
	}

	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
	body := rec.Body.String()
	if !strings.Contains(body, "event: response.output_item.added") {
		t.Fatalf("expected response.output_item.added event, body=%s", body)
	}
	if !strings.Contains(body, "event: response.output_item.done") {
		t.Fatalf("expected response.output_item.done event, body=%s", body)
	}
	if !strings.Contains(body, "event: response.function_call_arguments.done") {
		t.Fatalf("expected response.function_call_arguments.done event, body=%s", body)
	}
	if strings.Contains(body, "event: response.output_tool_call.delta") || strings.Contains(body, "event: response.output_tool_call.done") {
		t.Fatalf("legacy response.output_tool_call.* event must not appear, body=%s", body)
	}

	addedPayloads := extractAllSSEEventPayloads(body, "response.output_item.added")
	hasFunctionCallAdded := false
	for _, payload := range addedPayloads {
		item, _ := payload["item"].(map[string]any)
		if item == nil || asString(item["type"]) != "function_call" {
			continue
		}
		hasFunctionCallAdded = true
		if asString(item["arguments"]) != "" {
			t.Fatalf("expected in-progress function_call.arguments to start empty string, got %#v", item["arguments"])
		}
	}
	if !hasFunctionCallAdded {
		t.Fatalf("expected function_call output_item.added payload, body=%s", body)
	}

	donePayload, ok := extractSSEEventPayload(body, "response.function_call_arguments.done")
	if !ok {
		t.Fatalf("expected to parse response.function_call_arguments.done payload, body=%s", body)
	}
	doneCallID := strings.TrimSpace(asString(donePayload["call_id"]))
	if doneCallID == "" {
		t.Fatalf("expected non-empty call_id in done payload, payload=%#v", donePayload)
	}
	completed, ok := extractSSEEventPayload(body, "response.completed")
	if !ok {
		t.Fatalf("expected response.completed payload, body=%s", body)
	}
	responseObj, _ := completed["response"].(map[string]any)
	output, _ := responseObj["output"].([]any)
	var completedCallID string
	for _, item := range output {
		m, _ := item.(map[string]any)
		if m == nil || m["type"] != "function_call" {
			continue
		}
		completedCallID = strings.TrimSpace(asString(m["call_id"]))
		if completedCallID != "" {
			break
		}
	}
	if completedCallID == "" {
		t.Fatalf("expected function_call.call_id in completed output, output=%#v", output)
	}
	if completedCallID != doneCallID {
		t.Fatalf("expected completed call_id to match stream done call_id, done=%q completed=%q", doneCallID, completedCallID)
	}
}
|
||||
|
||||
// TestHandleResponsesStreamDoesNotEmitReasoningTextCompatEvents verifies
// thinking content is surfaced as response.reasoning.delta only, without
// the response.reasoning_text.* compatibility events.
func TestHandleResponsesStreamDoesNotEmitReasoningTextCompatEvents(t *testing.T) {
	h := &Handler{}
	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
	rec := httptest.NewRecorder()

	b, _ := json.Marshal(map[string]any{
		"p": "response/thinking_content",
		"v": "thought",
	})
	streamBody := "data: " + string(b) + "\n" + "data: [DONE]\n"
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(streamBody)),
	}

	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, nil, util.DefaultToolChoicePolicy(), "")

	body := rec.Body.String()
	if !strings.Contains(body, "event: response.reasoning.delta") {
		t.Fatalf("expected response.reasoning.delta event, body=%s", body)
	}
	if strings.Contains(body, "event: response.reasoning_text.delta") || strings.Contains(body, "event: response.reasoning_text.done") {
		t.Fatalf("did not expect response.reasoning_text.* compatibility events, body=%s", body)
	}
}
|
||||
|
||||
// TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned verifies
// that two tool calls split across SSE chunks each get a done event with
// the correct name and a distinct, non-empty call_id.
func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.T) {
	h := &Handler{}
	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
	rec := httptest.NewRecorder()

	// sseLine wraps a content chunk in the upstream SSE envelope.
	sseLine := func(v string) string {
		b, _ := json.Marshal(map[string]any{
			"p": "response/content",
			"v": v,
		})
		return "data: " + string(b) + "\n"
	}

	// The tool_calls JSON is deliberately split mid-array across two chunks.
	streamBody := sseLine(`{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) +
		sseLine(`{"name":"eval_javascript","input":{"code":"1+1"}}]}`) +
		"data: [DONE]\n"
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(streamBody)),
	}

	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"}, util.DefaultToolChoicePolicy(), "")

	body := rec.Body.String()
	donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
	if len(donePayloads) != 2 {
		t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
	}
	seenNames := map[string]string{}
	for _, payload := range donePayloads {
		name := strings.TrimSpace(asString(payload["name"]))
		callID := strings.TrimSpace(asString(payload["call_id"]))
		if name != "search_web" && name != "eval_javascript" {
			t.Fatalf("unexpected tool name in done payload: %#v", payload)
		}
		if callID == "" {
			t.Fatalf("expected non-empty call_id in done payload: %#v", payload)
		}
		seenNames[name] = callID
	}
	if seenNames["search_web"] == seenNames["eval_javascript"] {
		t.Fatalf("expected distinct call_id per tool, got %#v", seenNames)
	}
}
|
||||
|
||||
// TestHandleResponsesStreamEmitsOutputTextDoneBeforeContentPartDone verifies
// the spec-mandated ordering: response.output_text.done must precede
// response.content_part.done in the emitted stream.
func TestHandleResponsesStreamEmitsOutputTextDoneBeforeContentPartDone(t *testing.T) {
	h := &Handler{}
	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
	rec := httptest.NewRecorder()

	// sseLine wraps a content chunk in the upstream SSE envelope.
	sseLine := func(v string) string {
		b, _ := json.Marshal(map[string]any{
			"p": "response/content",
			"v": v,
		})
		return "data: " + string(b) + "\n"
	}

	streamBody := sseLine("hello") + "data: [DONE]\n"
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(streamBody)),
	}

	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, util.DefaultToolChoicePolicy(), "")
	body := rec.Body.String()
	if !strings.Contains(body, "event: response.output_text.done") {
		t.Fatalf("expected response.output_text.done payload, body=%s", body)
	}
	textDoneIdx := strings.Index(body, "event: response.output_text.done")
	partDoneIdx := strings.Index(body, "event: response.content_part.done")
	if textDoneIdx < 0 || partDoneIdx < 0 {
		t.Fatalf("expected output_text.done + content_part.done, body=%s", body)
	}
	if textDoneIdx > partDoneIdx {
		t.Fatalf("expected output_text.done before content_part.done, body=%s", body)
	}
}
|
||||
|
||||
// TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes verifies each
// response.output_text.delta payload carries a non-empty item_id plus
// output_index and content_index fields.
func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
	h := &Handler{}
	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
	rec := httptest.NewRecorder()

	// sseLine wraps a content chunk in the upstream SSE envelope.
	sseLine := func(v string) string {
		b, _ := json.Marshal(map[string]any{
			"p": "response/content",
			"v": v,
		})
		return "data: " + string(b) + "\n"
	}

	streamBody := sseLine("hello") + "data: [DONE]\n"
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(streamBody)),
	}

	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, util.DefaultToolChoicePolicy(), "")
	body := rec.Body.String()

	deltaPayload, ok := extractSSEEventPayload(body, "response.output_text.delta")
	if !ok {
		t.Fatalf("expected response.output_text.delta payload, body=%s", body)
	}
	if strings.TrimSpace(asString(deltaPayload["item_id"])) == "" {
		t.Fatalf("expected non-empty item_id in output_text.delta, payload=%#v", deltaPayload)
	}
	if _, ok := deltaPayload["output_index"]; !ok {
		t.Fatalf("expected output_index in output_text.delta, payload=%#v", deltaPayload)
	}
	if _, ok := deltaPayload["content_index"]; !ok {
		t.Fatalf("expected content_index in output_text.delta, payload=%#v", deltaPayload)
	}
}
|
||||
|
||||
// TestHandleResponsesStreamThinkingAndMixedToolExampleEmitsFunctionCall
// verifies a mixed stream (thinking + prose + tool_calls JSON) produces
// both a message item and a function_call item in the completed output.
func TestHandleResponsesStreamThinkingAndMixedToolExampleEmitsFunctionCall(t *testing.T) {
	h := &Handler{}
	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
	rec := httptest.NewRecorder()

	// sseLine wraps a value for an arbitrary upstream path ("p") in the
	// SSE envelope, so both thinking and content chunks can be produced.
	sseLine := func(path, value string) string {
		b, _ := json.Marshal(map[string]any{
			"p": path,
			"v": value,
		})
		return "data: " + string(b) + "\n"
	}

	streamBody := sseLine("response/thinking_content", "thinking...") +
		sseLine("response/content", "先读取文件。") +
		sseLine("response/content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) +
		"data: [DONE]\n"
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(streamBody)),
	}

	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")

	addedPayloads := extractAllSSEEventPayloads(rec.Body.String(), "response.output_item.added")
	if len(addedPayloads) < 1 {
		t.Fatalf("expected at least one output_item.added event, got %d body=%s", len(addedPayloads), rec.Body.String())
	}

	completedPayload, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
	if !ok {
		t.Fatalf("expected response.completed payload, body=%s", rec.Body.String())
	}
	responseObj, _ := completedPayload["response"].(map[string]any)
	output, _ := responseObj["output"].([]any)
	hasMessage := false
	hasFunctionCall := false
	for _, item := range output {
		m, _ := item.(map[string]any)
		if m == nil {
			continue
		}
		if asString(m["type"]) == "message" {
			hasMessage = true
		}
		if asString(m["type"]) == "function_call" {
			hasFunctionCall = true
		}
	}
	if !hasMessage {
		t.Fatalf("expected message output for mixed prose tool example, output=%#v", output)
	}
	if !hasFunctionCall {
		t.Fatalf("expected function_call output for mixed prose tool example, output=%#v", output)
	}
}
|
||||
|
||||
// TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall verifies that
// with tool_choice=none an upstream tool_calls payload produces no
// function_call events.
func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
	h := &Handler{}
	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
	rec := httptest.NewRecorder()

	// sseLine wraps a content chunk in the upstream SSE envelope.
	sseLine := func(v string) string {
		b, _ := json.Marshal(map[string]any{
			"p": "response/content",
			"v": v,
		})
		return "data: " + string(b) + "\n"
	}

	streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n"
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(streamBody)),
	}
	policy := util.ToolChoicePolicy{Mode: util.ToolChoiceNone}

	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, policy, "")
	body := rec.Body.String()
	if strings.Contains(body, "event: response.function_call_arguments.done") {
		t.Fatalf("did not expect function_call events for tool_choice=none, body=%s", body)
	}
}
|
||||
|
||||
// TestHandleResponsesStreamMalformedToolJSONFallsBackToText verifies that
// invalid tool_calls JSON (NaN is not valid JSON) is passed through as plain
// text in strict mode: no function_call events, but text deltas and a
// normal response.completed.
func TestHandleResponsesStreamMalformedToolJSONFallsBackToText(t *testing.T) {
	h := &Handler{}
	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
	rec := httptest.NewRecorder()

	// sseLine wraps a content chunk in the upstream SSE envelope.
	sseLine := func(v string) string {
		b, _ := json.Marshal(map[string]any{
			"p": "response/content",
			"v": v,
		})
		return "data: " + string(b) + "\n"
	}

	// invalid JSON (NaN) should remain plain text in strict mode.
	streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n"
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(streamBody)),
	}

	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
	body := rec.Body.String()
	if strings.Contains(body, "event: response.function_call_arguments.delta") || strings.Contains(body, "event: response.function_call_arguments.done") {
		t.Fatalf("did not expect function_call events for malformed payload in strict mode, body=%s", body)
	}
	if !strings.Contains(body, "event: response.output_text.delta") {
		t.Fatalf("expected response.output_text.delta for malformed payload, body=%s", body)
	}
	if !strings.Contains(body, "event: response.completed") {
		t.Fatalf("expected response.completed event, body=%s", body)
	}
}
|
||||
|
||||
// TestHandleResponsesStreamRequiredToolChoiceFailure verifies that when
// tool_choice=required and the model returns only plain text, the stream
// ends with response.failed and never emits response.completed.
func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
	h := &Handler{}
	req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
	rec := httptest.NewRecorder()

	// sseLine wraps a content chunk in the upstream SSE envelope.
	sseLine := func(v string) string {
		b, _ := json.Marshal(map[string]any{
			"p": "response/content",
			"v": v,
		})
		return "data: " + string(b) + "\n"
	}

	streamBody := sseLine("plain text only") + "data: [DONE]\n"
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(strings.NewReader(streamBody)),
	}

	policy := util.ToolChoicePolicy{
		Mode:    util.ToolChoiceRequired,
		Allowed: map[string]struct{}{"read_file": {}},
	}
	h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, policy, "")

	body := rec.Body.String()
	if !strings.Contains(body, "event: response.failed") {
		t.Fatalf("expected response.failed event for required tool_choice violation, body=%s", body)
	}
	if strings.Contains(body, "event: response.completed") {
		t.Fatalf("did not expect response.completed after failure, body=%s", body)
	}
}
|
||||
|
||||
func TestHandleResponsesStreamRequiredToolChoiceIgnoresThinkingToolPayload(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(path, value string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": path,
|
||||
"v": value,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) +
|
||||
sseLine("response/content", "plain text only") +
|
||||
"data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
policy := util.ToolChoicePolicy{
|
||||
Mode: util.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, false, []string{"read_file"}, policy, "")
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.failed") {
|
||||
t.Fatalf("expected response.failed event for required tool_choice violation, body=%s", body)
|
||||
}
|
||||
if strings.Contains(body, "event: response.completed") {
|
||||
t.Fatalf("did not expect response.completed after failure, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
policy := util.ToolChoicePolicy{
|
||||
Mode: util.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, policy, "")
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.failed") {
|
||||
t.Fatalf("expected response.failed event, body=%s", body)
|
||||
}
|
||||
if strings.Contains(body, "event: response.completed") {
|
||||
t.Fatalf("did not expect response.completed, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamRejectsUnknownToolName(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"not_in_schema","input":{"q":"go"}}]}`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("did not expect function_call events for unknown tool, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) {
|
||||
h := &Handler{}
|
||||
rec := httptest.NewRecorder()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`data: {"p":"response/content","v":"plain text only"}` + "\n" +
|
||||
`data: [DONE]` + "\n",
|
||||
)),
|
||||
}
|
||||
policy := util.ToolChoicePolicy{
|
||||
Mode: util.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, []string{"read_file"}, policy, "")
|
||||
if rec.Code != http.StatusUnprocessableEntity {
|
||||
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
errObj, _ := out["error"].(map[string]any)
|
||||
if asString(errObj["code"]) != "tool_choice_violation" {
|
||||
t.Fatalf("expected code=tool_choice_violation, got %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayload(t *testing.T) {
|
||||
h := &Handler{}
|
||||
rec := httptest.NewRecorder()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`data: {"p":"response/thinking_content","v":"{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}"}` + "\n" +
|
||||
`data: {"p":"response/content","v":"plain text only"}` + "\n" +
|
||||
`data: [DONE]` + "\n",
|
||||
)),
|
||||
}
|
||||
policy := util.ToolChoicePolicy{
|
||||
Mode: util.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, []string{"read_file"}, policy, "")
|
||||
if rec.Code != http.StatusUnprocessableEntity {
|
||||
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
errObj, _ := out["error"].(map[string]any)
|
||||
if asString(errObj["code"]) != "tool_choice_violation" {
|
||||
t.Fatalf("expected code=tool_choice_violation, got %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
rec := httptest.NewRecorder()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}"}` + "\n" +
|
||||
`data: [DONE]` + "\n",
|
||||
)),
|
||||
}
|
||||
policy := util.ToolChoicePolicy{Mode: util.ToolChoiceNone}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, policy, "")
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 for tool_choice=none passthrough text, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
output, _ := out["output"].([]any)
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m != nil && m["type"] == "function_call" {
|
||||
t.Fatalf("did not expect function_call output item for tool_choice=none, got %#v", output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
|
||||
scanner := bufio.NewScanner(strings.NewReader(body))
|
||||
matched := false
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
evt := strings.TrimSpace(strings.TrimPrefix(line, "event: "))
|
||||
matched = evt == targetEvent
|
||||
continue
|
||||
}
|
||||
if !matched || !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
raw := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
|
||||
if raw == "" || raw == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return payload, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func extractAllSSEEventPayloads(body, targetEvent string) []map[string]any {
|
||||
scanner := bufio.NewScanner(strings.NewReader(body))
|
||||
matched := false
|
||||
out := make([]map[string]any, 0, 2)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
evt := strings.TrimSpace(strings.TrimPrefix(line, "event: "))
|
||||
matched = evt == targetEvent
|
||||
continue
|
||||
}
|
||||
if !matched || !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
raw := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
|
||||
if raw == "" || raw == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, payload)
|
||||
}
|
||||
return out
|
||||
}
|
||||
326
internal/adapter/openai/standard_request.go
Normal file
326
internal/adapter/openai/standard_request.go
Normal file
@@ -0,0 +1,326 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID string) (util.StandardRequest, error) {
|
||||
model, _ := req["model"].(string)
|
||||
messagesRaw, _ := req["messages"].([]any)
|
||||
if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 {
|
||||
return util.StandardRequest{}, fmt.Errorf("Request must include 'model' and 'messages'.")
|
||||
}
|
||||
resolvedModel, ok := config.ResolveModel(store, model)
|
||||
if !ok {
|
||||
return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", model)
|
||||
}
|
||||
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
|
||||
responseModel := strings.TrimSpace(model)
|
||||
if responseModel == "" {
|
||||
responseModel = resolvedModel
|
||||
}
|
||||
toolPolicy := util.DefaultToolChoicePolicy()
|
||||
finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy)
|
||||
passThrough := collectOpenAIChatPassThrough(req)
|
||||
|
||||
return util.StandardRequest{
|
||||
Surface: "openai_chat",
|
||||
RequestedModel: strings.TrimSpace(model),
|
||||
ResolvedModel: resolvedModel,
|
||||
ResponseModel: responseModel,
|
||||
Messages: messagesRaw,
|
||||
FinalPrompt: finalPrompt,
|
||||
ToolNames: toolNames,
|
||||
ToolChoice: toolPolicy,
|
||||
Stream: util.ToBool(req["stream"]),
|
||||
Thinking: thinkingEnabled,
|
||||
Search: searchEnabled,
|
||||
PassThrough: passThrough,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, traceID string) (util.StandardRequest, error) {
|
||||
model, _ := req["model"].(string)
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return util.StandardRequest{}, fmt.Errorf("Request must include 'model'.")
|
||||
}
|
||||
resolvedModel, ok := config.ResolveModel(store, model)
|
||||
if !ok {
|
||||
return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", model)
|
||||
}
|
||||
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
|
||||
|
||||
// Keep width-control as an explicit policy hook even if current default is true.
|
||||
allowWideInput := true
|
||||
if store != nil {
|
||||
allowWideInput = store.CompatWideInputStrictOutput()
|
||||
}
|
||||
var messagesRaw []any
|
||||
if allowWideInput {
|
||||
messagesRaw = responsesMessagesFromRequest(req)
|
||||
} else if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 {
|
||||
messagesRaw = msgs
|
||||
}
|
||||
if len(messagesRaw) == 0 {
|
||||
return util.StandardRequest{}, fmt.Errorf("Request must include 'input' or 'messages'.")
|
||||
}
|
||||
toolPolicy, err := parseToolChoicePolicy(req["tool_choice"], req["tools"])
|
||||
if err != nil {
|
||||
return util.StandardRequest{}, err
|
||||
}
|
||||
finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy)
|
||||
if toolPolicy.IsNone() {
|
||||
toolNames = nil
|
||||
toolPolicy.Allowed = nil
|
||||
} else {
|
||||
toolPolicy.Allowed = namesToSet(toolNames)
|
||||
}
|
||||
passThrough := collectOpenAIChatPassThrough(req)
|
||||
|
||||
return util.StandardRequest{
|
||||
Surface: "openai_responses",
|
||||
RequestedModel: model,
|
||||
ResolvedModel: resolvedModel,
|
||||
ResponseModel: model,
|
||||
Messages: messagesRaw,
|
||||
FinalPrompt: finalPrompt,
|
||||
ToolNames: toolNames,
|
||||
ToolChoice: toolPolicy,
|
||||
Stream: util.ToBool(req["stream"]),
|
||||
Thinking: thinkingEnabled,
|
||||
Search: searchEnabled,
|
||||
PassThrough: passThrough,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func collectOpenAIChatPassThrough(req map[string]any) map[string]any {
|
||||
out := map[string]any{}
|
||||
for _, k := range []string{
|
||||
"temperature",
|
||||
"top_p",
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"stop",
|
||||
} {
|
||||
if v, ok := req[k]; ok {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseToolChoicePolicy(toolChoiceRaw any, toolsRaw any) (util.ToolChoicePolicy, error) {
|
||||
policy := util.DefaultToolChoicePolicy()
|
||||
declaredNames := extractDeclaredToolNames(toolsRaw)
|
||||
declaredSet := namesToSet(declaredNames)
|
||||
if len(declaredNames) > 0 {
|
||||
policy.Allowed = declaredSet
|
||||
}
|
||||
|
||||
if toolChoiceRaw == nil {
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
switch v := toolChoiceRaw.(type) {
|
||||
case string:
|
||||
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||
case "", "auto":
|
||||
policy.Mode = util.ToolChoiceAuto
|
||||
case "none":
|
||||
policy.Mode = util.ToolChoiceNone
|
||||
policy.Allowed = nil
|
||||
case "required":
|
||||
policy.Mode = util.ToolChoiceRequired
|
||||
default:
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("Unsupported tool_choice: %q", v)
|
||||
}
|
||||
case map[string]any:
|
||||
allowedOverride, hasAllowedOverride, err := parseAllowedToolNames(v["allowed_tools"])
|
||||
if err != nil {
|
||||
return util.ToolChoicePolicy{}, err
|
||||
}
|
||||
if hasAllowedOverride {
|
||||
filtered := make([]string, 0, len(allowedOverride))
|
||||
for _, name := range allowedOverride {
|
||||
if _, ok := declaredSet[name]; !ok {
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice.allowed_tools contains undeclared tool %q", name)
|
||||
}
|
||||
filtered = append(filtered, name)
|
||||
}
|
||||
policy.Allowed = namesToSet(filtered)
|
||||
}
|
||||
|
||||
typ := strings.ToLower(strings.TrimSpace(asString(v["type"])))
|
||||
switch typ {
|
||||
case "", "auto":
|
||||
if hasFunctionSelector(v) {
|
||||
name, err := parseForcedToolName(v)
|
||||
if err != nil {
|
||||
return util.ToolChoicePolicy{}, err
|
||||
}
|
||||
policy.Mode = util.ToolChoiceForced
|
||||
policy.ForcedName = name
|
||||
policy.Allowed = namesToSet([]string{name})
|
||||
} else {
|
||||
policy.Mode = util.ToolChoiceAuto
|
||||
}
|
||||
case "none":
|
||||
policy.Mode = util.ToolChoiceNone
|
||||
policy.Allowed = nil
|
||||
case "required":
|
||||
policy.Mode = util.ToolChoiceRequired
|
||||
case "function":
|
||||
name, err := parseForcedToolName(v)
|
||||
if err != nil {
|
||||
return util.ToolChoicePolicy{}, err
|
||||
}
|
||||
policy.Mode = util.ToolChoiceForced
|
||||
policy.ForcedName = name
|
||||
policy.Allowed = namesToSet([]string{name})
|
||||
default:
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("Unsupported tool_choice.type: %q", typ)
|
||||
}
|
||||
default:
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice must be a string or object")
|
||||
}
|
||||
|
||||
if policy.Mode == util.ToolChoiceRequired || policy.Mode == util.ToolChoiceForced {
|
||||
if len(declaredNames) == 0 {
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice=%s requires non-empty tools.", policy.Mode)
|
||||
}
|
||||
}
|
||||
if policy.Mode == util.ToolChoiceForced {
|
||||
if _, ok := declaredSet[policy.ForcedName]; !ok {
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice forced function %q is not declared in tools", policy.ForcedName)
|
||||
}
|
||||
}
|
||||
if len(policy.Allowed) == 0 && (policy.Mode == util.ToolChoiceRequired || policy.Mode == util.ToolChoiceForced) {
|
||||
return util.ToolChoicePolicy{}, fmt.Errorf("tool_choice policy resolved to empty allowed tool set")
|
||||
}
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
func parseForcedToolName(v map[string]any) (string, error) {
|
||||
if name := strings.TrimSpace(asString(v["name"])); name != "" {
|
||||
return name, nil
|
||||
}
|
||||
if fn, ok := v["function"].(map[string]any); ok {
|
||||
if name := strings.TrimSpace(asString(fn["name"])); name != "" {
|
||||
return name, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("tool_choice function requires name")
|
||||
}
|
||||
|
||||
func parseAllowedToolNames(raw any) ([]string, bool, error) {
|
||||
if raw == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
collectName := func(v any) string {
|
||||
if name := strings.TrimSpace(asString(v)); name != "" {
|
||||
return name
|
||||
}
|
||||
if m, ok := v.(map[string]any); ok {
|
||||
if name := strings.TrimSpace(asString(m["name"])); name != "" {
|
||||
return name
|
||||
}
|
||||
if fn, ok := m["function"].(map[string]any); ok {
|
||||
if name := strings.TrimSpace(asString(fn["name"])); name != "" {
|
||||
return name
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
names := []string{}
|
||||
switch x := raw.(type) {
|
||||
case []any:
|
||||
for _, item := range x {
|
||||
name := collectName(item)
|
||||
if name == "" {
|
||||
return nil, true, fmt.Errorf("tool_choice.allowed_tools contains invalid item")
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
case []string:
|
||||
for _, item := range x {
|
||||
name := strings.TrimSpace(item)
|
||||
if name == "" {
|
||||
return nil, true, fmt.Errorf("tool_choice.allowed_tools contains empty name")
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
default:
|
||||
return nil, true, fmt.Errorf("tool_choice.allowed_tools must be an array")
|
||||
}
|
||||
|
||||
if len(names) == 0 {
|
||||
return nil, true, fmt.Errorf("tool_choice.allowed_tools must not be empty")
|
||||
}
|
||||
return names, true, nil
|
||||
}
|
||||
|
||||
func hasFunctionSelector(v map[string]any) bool {
|
||||
if strings.TrimSpace(asString(v["name"])) != "" {
|
||||
return true
|
||||
}
|
||||
if fn, ok := v["function"].(map[string]any); ok {
|
||||
return strings.TrimSpace(asString(fn["name"])) != ""
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func extractDeclaredToolNames(toolsRaw any) []string {
|
||||
tools, ok := toolsRaw.([]any)
|
||||
if !ok || len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(tools))
|
||||
seen := map[string]struct{}{}
|
||||
for _, t := range tools {
|
||||
tool, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
fn, _ := tool["function"].(map[string]any)
|
||||
if len(fn) == 0 {
|
||||
fn = tool
|
||||
}
|
||||
name := strings.TrimSpace(asString(fn["name"]))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[name]; ok {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
out = append(out, name)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func namesToSet(names []string) map[string]struct{} {
|
||||
if len(names) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]struct{}, len(names))
|
||||
for _, name := range names {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
out[trimmed] = struct{}{}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
180
internal/adapter/openai/standard_request_test.go
Normal file
180
internal/adapter/openai/standard_request_test.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func newEmptyStoreForNormalizeTest(t *testing.T) *config.Store {
|
||||
t.Helper()
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{}`)
|
||||
return config.LoadStore()
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIChatRequest(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
"model": "gpt-5-codex",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
"temperature": 0.3,
|
||||
"stream": true,
|
||||
}
|
||||
n, err := normalizeOpenAIChatRequest(store, req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
if n.ResolvedModel != "deepseek-reasoner" {
|
||||
t.Fatalf("unexpected resolved model: %s", n.ResolvedModel)
|
||||
}
|
||||
if !n.Stream {
|
||||
t.Fatalf("expected stream=true")
|
||||
}
|
||||
if _, ok := n.PassThrough["temperature"]; !ok {
|
||||
t.Fatalf("expected temperature passthrough")
|
||||
}
|
||||
if n.FinalPrompt == "" {
|
||||
t.Fatalf("expected non-empty final prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": "ping",
|
||||
"instructions": "system",
|
||||
}
|
||||
n, err := normalizeOpenAIResponsesRequest(store, req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
if n.ResolvedModel != "deepseek-chat" {
|
||||
t.Fatalf("unexpected resolved model: %s", n.ResolvedModel)
|
||||
}
|
||||
if len(n.Messages) != 2 {
|
||||
t.Fatalf("expected 2 normalized messages, got %d", len(n.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesRequestToolChoiceRequired(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": "ping",
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "search",
|
||||
"parameters": map[string]any{
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"tool_choice": "required",
|
||||
}
|
||||
n, err := normalizeOpenAIResponsesRequest(store, req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
if n.ToolChoice.Mode != util.ToolChoiceRequired {
|
||||
t.Fatalf("expected tool choice mode required, got %q", n.ToolChoice.Mode)
|
||||
}
|
||||
if len(n.ToolNames) != 1 || n.ToolNames[0] != "search" {
|
||||
t.Fatalf("unexpected tool names: %#v", n.ToolNames)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesRequestToolChoiceForcedFunction(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": "ping",
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "search",
|
||||
},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "read_file",
|
||||
},
|
||||
},
|
||||
},
|
||||
"tool_choice": map[string]any{
|
||||
"type": "function",
|
||||
"name": "read_file",
|
||||
},
|
||||
}
|
||||
n, err := normalizeOpenAIResponsesRequest(store, req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
if n.ToolChoice.Mode != util.ToolChoiceForced {
|
||||
t.Fatalf("expected tool choice mode forced, got %q", n.ToolChoice.Mode)
|
||||
}
|
||||
if n.ToolChoice.ForcedName != "read_file" {
|
||||
t.Fatalf("expected forced tool name read_file, got %q", n.ToolChoice.ForcedName)
|
||||
}
|
||||
if len(n.ToolNames) != 1 || n.ToolNames[0] != "read_file" {
|
||||
t.Fatalf("expected filtered tool names [read_file], got %#v", n.ToolNames)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesRequestToolChoiceForcedUndeclaredFails(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": "ping",
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "search",
|
||||
},
|
||||
},
|
||||
},
|
||||
"tool_choice": map[string]any{
|
||||
"type": "function",
|
||||
"name": "read_file",
|
||||
},
|
||||
}
|
||||
if _, err := normalizeOpenAIResponsesRequest(store, req, ""); err == nil {
|
||||
t.Fatalf("expected forced undeclared tool to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesRequestToolChoiceNoneDisablesTools(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": "ping",
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "search",
|
||||
},
|
||||
},
|
||||
},
|
||||
"tool_choice": "none",
|
||||
}
|
||||
n, err := normalizeOpenAIResponsesRequest(store, req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
if n.ToolChoice.Mode != util.ToolChoiceNone {
|
||||
t.Fatalf("expected tool choice mode none, got %q", n.ToolChoice.Mode)
|
||||
}
|
||||
if len(n.ToolNames) != 0 {
|
||||
t.Fatalf("expected no tool names when tool_choice=none, got %#v", n.ToolNames)
|
||||
}
|
||||
}
|
||||
185
internal/adapter/openai/stream_status_test.go
Normal file
185
internal/adapter/openai/stream_status_test.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
chimw "github.com/go-chi/chi/v5/middleware"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
)
|
||||
|
||||
type streamStatusAuthStub struct{}
|
||||
|
||||
func (streamStatusAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) {
|
||||
return &auth.RequestAuth{
|
||||
UseConfigToken: false,
|
||||
DeepSeekToken: "direct-token",
|
||||
CallerID: "caller:test",
|
||||
TriedAccounts: map[string]bool{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (streamStatusAuthStub) DetermineCaller(_ *http.Request) (*auth.RequestAuth, error) {
|
||||
return &auth.RequestAuth{
|
||||
UseConfigToken: false,
|
||||
DeepSeekToken: "direct-token",
|
||||
CallerID: "caller:test",
|
||||
TriedAccounts: map[string]bool{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (streamStatusAuthStub) Release(_ *auth.RequestAuth) {}
|
||||
|
||||
type streamStatusDSStub struct {
|
||||
resp *http.Response
|
||||
}
|
||||
|
||||
func (m streamStatusDSStub) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
|
||||
return "session-id", nil
|
||||
}
|
||||
|
||||
func (m streamStatusDSStub) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
|
||||
return "pow", nil
|
||||
}
|
||||
|
||||
func (m streamStatusDSStub) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
|
||||
return m.resp, nil
|
||||
}
|
||||
|
||||
func (m streamStatusDSStub) DeleteAllSessionsForToken(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeOpenAISSEHTTPResponse(lines ...string) *http.Response {
|
||||
body := strings.Join(lines, "\n")
|
||||
if !strings.HasSuffix(body, "\n") {
|
||||
body += "\n"
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
}
|
||||
|
||||
func captureStatusMiddleware(statuses *[]int) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ww := chimw.NewWrapResponseWriter(w, r.ProtoMajor)
|
||||
next.ServeHTTP(ww, r)
|
||||
*statuses = append(*statuses, ww.Status())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsStreamStatusCapturedAs200(t *testing.T) {
|
||||
statuses := make([]int, 0, 1)
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{wideInput: true},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, "data: [DONE]")},
|
||||
}
|
||||
r := chi.NewRouter()
|
||||
r.Use(captureStatusMiddleware(&statuses))
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":"hi"}],"stream":true}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if len(statuses) != 1 {
|
||||
t.Fatalf("expected one captured status, got %d", len(statuses))
|
||||
}
|
||||
if statuses[0] != http.StatusOK {
|
||||
t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesStreamStatusCapturedAs200(t *testing.T) {
|
||||
statuses := make([]int, 0, 1)
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{wideInput: true},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, "data: [DONE]")},
|
||||
}
|
||||
r := chi.NewRouter()
|
||||
r.Use(captureStatusMiddleware(&statuses))
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
reqBody := `{"model":"deepseek-chat","input":"hi","stream":true}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if len(statuses) != 1 {
|
||||
t.Fatalf("expected one captured status, got %d", len(statuses))
|
||||
}
|
||||
if statuses[0] != http.StatusOK {
|
||||
t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0])
|
||||
}
|
||||
}
|
||||
|
||||
// TestResponsesNonStreamMixedProseToolPayloadHandlerPath drives a full
// non-streaming /v1/responses request whose upstream SSE payload mixes prose
// with an inline tool_calls JSON object, and verifies that:
//   - the handler answers 200 and the middleware captures exactly one status;
//   - output_text is suppressed (the prose wrapped a tool payload);
//   - the single output item is surfaced as a function_call.
func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) {
	statuses := make([]int, 0, 1)
	// Upstream SSE event body: prose followed by a raw tool_calls JSON blob.
	content, _ := json.Marshal(map[string]any{
		"p": "response/content",
		"v": "我来调用工具\n{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}",
	})
	h := &Handler{
		Store: mockOpenAIConfig{wideInput: true},
		Auth:  streamStatusAuthStub{},
		DS:    streamStatusDSStub{resp: makeOpenAISSEHTTPResponse("data: "+string(content), "data: [DONE]")},
	}
	r := chi.NewRouter()
	// Record every status the handler writes so we can assert exactly one.
	r.Use(captureStatusMiddleware(&statuses))
	RegisterRoutes(r, h)

	// Non-streaming request declaring a read_file tool matching the payload.
	reqBody := `{"model":"deepseek-chat","input":"请调用工具","tools":[{"type":"function","function":{"name":"read_file","description":"read","parameters":{"type":"object","properties":{"path":{"type":"string"}}}}}],"stream":false}`
	req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody))
	req.Header.Set("Authorization", "Bearer direct-token")
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	r.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
	}
	if len(statuses) != 1 || statuses[0] != http.StatusOK {
		t.Fatalf("expected captured status 200, got %#v", statuses)
	}

	var out map[string]any
	if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
		t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String())
	}
	// Mixed prose around a tool payload must not leak into output_text.
	outputText, _ := out["output_text"].(string)
	if outputText != "" {
		t.Fatalf("expected output_text hidden for mixed prose tool payload, got %q", outputText)
	}
	output, _ := out["output"].([]any)
	if len(output) != 1 {
		t.Fatalf("expected one output item, got %#v", output)
	}
	first, _ := output[0].(map[string]any)
	if first["type"] != "function_call" {
		t.Fatalf("expected function_call output item, got %#v", output)
	}
}
|
||||
17
internal/adapter/openai/tool_history_sanitize.go
Normal file
17
internal/adapter/openai/tool_history_sanitize.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// leakedToolHistoryPattern matches whole [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY]
// or [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] blocks (case-insensitive,
// spanning newlines) that upstream occasionally leaks into model output.
var leakedToolHistoryPattern = regexp.MustCompile(`(?is)\[TOOL_CALL_HISTORY\][\s\S]*?\[/TOOL_CALL_HISTORY\]|\[TOOL_RESULT_HISTORY\][\s\S]*?\[/TOOL_RESULT_HISTORY\]`)

// emptyJSONFencePattern matches a ```json fence with nothing inside it,
// typically left behind once a tool payload has been stripped out.
var emptyJSONFencePattern = regexp.MustCompile("(?is)```json\\s*```")

// sanitizeLeakedToolHistory strips leaked tool-history marker blocks and any
// resulting empty ```json fences from text, leaving surrounding prose intact.
func sanitizeLeakedToolHistory(text string) string {
	if len(text) == 0 {
		return text
	}
	withoutHistory := leakedToolHistoryPattern.ReplaceAllString(text, "")
	return emptyJSONFencePattern.ReplaceAllString(withoutHistory, "")
}
|
||||
106
internal/adapter/openai/tool_history_sanitize_test.go
Normal file
106
internal/adapter/openai/tool_history_sanitize_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package openai
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSanitizeLeakedToolHistoryRemovesMarkerBlocks(t *testing.T) {
|
||||
raw := "前缀\n[TOOL_CALL_HISTORY]\nfunction.name: exec\nfunction.arguments: {}\n[/TOOL_CALL_HISTORY]\n后缀"
|
||||
got := sanitizeLeakedToolHistory(raw)
|
||||
if got != "前缀\n\n后缀" {
|
||||
t.Fatalf("unexpected sanitized content: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSanitizeLeakedToolHistoryPreservesChunkWhitespace checks that
// sanitization removes only marker blocks and never trims leading or trailing
// whitespace of the surrounding text — spacing matters when the input is a
// streamed chunk whose whitespace carries across chunk boundaries.
func TestSanitizeLeakedToolHistoryPreservesChunkWhitespace(t *testing.T) {
	cases := []struct {
		name string
		raw  string
		want string
	}{
		{
			name: "trailing space kept",
			raw:  "Hello ",
			want: "Hello ",
		},
		{
			name: "leading newline kept",
			raw:  "\nworld",
			want: "\nworld",
		},
		{
			name: "surrounding whitespace around marker is preserved",
			raw:  "A \n[TOOL_RESULT_HISTORY]\nfunction.name: exec\nfunction.arguments: {}\n[/TOOL_RESULT_HISTORY]\n B",
			want: "A \n\n B",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := sanitizeLeakedToolHistory(tc.raw)
			if got != tc.want {
				t.Fatalf("unexpected sanitize result, want %q got %q", tc.want, got)
			}
		})
	}
}
|
||||
|
||||
func TestSanitizeLeakedToolHistoryRemovesEmptyJSONFence(t *testing.T) {
|
||||
raw := "before\n```json\n```\nafter"
|
||||
got := sanitizeLeakedToolHistory(raw)
|
||||
if got != "before\n\nafter" {
|
||||
t.Fatalf("unexpected sanitized empty json fence: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFlushToolSieveDropsToolHistoryLeak feeds a complete
// [TOOL_CALL_HISTORY] block through the sieve and expects it to be swallowed
// entirely: no events while the block is buffered, and none on flush.
func TestFlushToolSieveDropsToolHistoryLeak(t *testing.T) {
	var state toolStreamSieveState
	chunk := "[TOOL_CALL_HISTORY]\nstatus: already_called\nfunction.name: exec\nfunction.arguments: {}\n[/TOOL_CALL_HISTORY]"
	evts := processToolSieveChunk(&state, chunk, []string{"exec"})
	if len(evts) != 0 {
		t.Fatalf("expected no immediate output before history block is complete, got %+v", evts)
	}
	// Flushing must not resurface the swallowed block as plain content.
	flushed := flushToolSieve(&state, []string{"exec"})
	if len(flushed) != 0 {
		t.Fatalf("expected history block to be swallowed, got %+v", flushed)
	}
}
|
||||
|
||||
// TestFlushToolSieveDropsToolResultHistoryLeak mirrors the call-history test
// for the [TOOL_RESULT_HISTORY] marker: the whole block must be swallowed by
// the sieve, both during processing and at flush time.
func TestFlushToolSieveDropsToolResultHistoryLeak(t *testing.T) {
	var state toolStreamSieveState
	chunk := "[TOOL_RESULT_HISTORY]\nstatus: already_called\nfunction.name: exec\nfunction.arguments: {}\n[/TOOL_RESULT_HISTORY]"
	evts := processToolSieveChunk(&state, chunk, []string{"exec"})
	if len(evts) != 0 {
		t.Fatalf("expected no immediate output before result history block is complete, got %+v", evts)
	}
	flushed := flushToolSieve(&state, []string{"exec"})
	if len(flushed) != 0 {
		t.Fatalf("expected result history block to be swallowed, got %+v", flushed)
	}
}
|
||||
|
||||
// TestProcessToolSieveChunkSplitsResultHistoryBoundary streams a
// [TOOL_RESULT_HISTORY] block split across chunk boundaries and asserts the
// sieve removes exactly the block, emits no tool calls, and preserves the
// surrounding text including the boundary spaces.
func TestProcessToolSieveChunkSplitsResultHistoryBoundary(t *testing.T) {
	var state toolStreamSieveState
	// The history block is deliberately cut mid-line between parts 2 and 3.
	parts := []string{
		"Hello ",
		"[TOOL_RESULT_HISTORY]\nstatus: already_called\n",
		"function.name: exec\nfunction.arguments: {}\n[/TOOL_RESULT_HISTORY]",
		"world",
	}
	var events []toolStreamEvent
	for _, p := range parts {
		events = append(events, processToolSieveChunk(&state, p, []string{"exec"})...)
	}
	events = append(events, flushToolSieve(&state, []string{"exec"})...)

	// Reassemble emitted text and ensure no tool calls leaked from history.
	var text string
	for _, evt := range events {
		if evt.Content != "" {
			text += evt.Content
		}
		if len(evt.ToolCalls) > 0 {
			t.Fatalf("did not expect parsed tool calls from history leak: %+v", evt.ToolCalls)
		}
	}
	if text != "Hello world" {
		t.Fatalf("expected clean text output preserving boundary spaces, got %q", text)
	}
}
|
||||
@@ -1,223 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// toolStreamSieveState accumulates streamed model output while it is being
// scanned for embedded tool-call payloads.
// It holds strings.Builder values, so it must not be copied after first use.
type toolStreamSieveState struct {
	pending   strings.Builder // text not yet classified as safe prose or tool payload
	capture   strings.Builder // text being captured as a suspected tool payload
	capturing bool            // true while capture holds an in-progress payload
}

// toolStreamEvent is one unit of sieve output: plain content, or a batch of
// parsed tool calls. The sieve in this file populates at most one of the two
// fields per event.
type toolStreamEvent struct {
	Content   string
	ToolCalls []util.ParsedToolCall
}
|
||||
|
||||
// processToolSieveChunk feeds one streamed content chunk through the sieve.
// Text that cannot be part of a tool payload is emitted as Content events;
// once a tool marker is seen the sieve captures until a complete JSON object
// can be cut out, then emits any parsed calls as a ToolCalls event.
// A nil state yields nil.
func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent {
	if state == nil {
		return nil
	}
	if chunk != "" {
		state.pending.WriteString(chunk)
	}
	events := make([]toolStreamEvent, 0, 2)

	for {
		if state.capturing {
			// Fold newly buffered text into the capture and try to extract
			// a complete tool payload from it.
			if state.pending.Len() > 0 {
				state.capture.WriteString(state.pending.String())
				state.pending.Reset()
			}
			prefix, calls, suffix, ready := consumeToolCapture(state.capture.String(), toolNames)
			if !ready {
				// Payload still incomplete; wait for more input.
				break
			}
			state.capture.Reset()
			state.capturing = false
			if prefix != "" {
				events = append(events, toolStreamEvent{Content: prefix})
			}
			if len(calls) > 0 {
				events = append(events, toolStreamEvent{ToolCalls: calls})
			}
			if suffix != "" {
				// Re-queue the remainder; it may contain further payloads.
				state.pending.WriteString(suffix)
			}
			continue
		}

		pending := state.pending.String()
		if pending == "" {
			break
		}
		start := findToolSegmentStart(pending)
		if start >= 0 {
			// A tool-call marker is visible: emit the text before it and
			// switch to capture mode for the rest.
			prefix := pending[:start]
			if prefix != "" {
				events = append(events, toolStreamEvent{Content: prefix})
			}
			state.pending.Reset()
			state.capture.WriteString(pending[start:])
			state.capturing = true
			continue
		}

		// No marker yet: release the provably safe prefix and hold anything
		// that could still grow into a tool payload.
		safe, hold := splitSafeContentForToolDetection(pending)
		if safe == "" {
			break
		}
		state.pending.Reset()
		state.pending.WriteString(hold)
		events = append(events, toolStreamEvent{Content: safe})
	}

	return events
}
|
||||
|
||||
// flushToolSieve drains the sieve at end of stream: it runs one final
// processing pass, resolves (or suppresses) any partially captured tool
// payload, and emits whatever pending text remains. A nil state yields nil.
func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStreamEvent {
	if state == nil {
		return nil
	}
	events := processToolSieveChunk(state, "", toolNames)
	if state.capturing {
		consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state.capture.String(), toolNames)
		if ready {
			if consumedPrefix != "" {
				events = append(events, toolStreamEvent{Content: consumedPrefix})
			}
			if len(consumedCalls) > 0 {
				events = append(events, toolStreamEvent{ToolCalls: consumedCalls})
			}
			if consumedSuffix != "" {
				events = append(events, toolStreamEvent{Content: consumedSuffix})
			}
		} else {
			// Incomplete captured tool JSON at stream end: suppress raw capture.
		}
		state.capture.Reset()
		state.capturing = false
	}
	if state.pending.Len() > 0 {
		events = append(events, toolStreamEvent{Content: state.pending.String()})
		state.pending.Reset()
	}
	return events
}
|
||||
|
||||
// splitSafeContentForToolDetection splits s into a prefix that is safe to
// emit immediately and a tail that must be held because it might be the
// beginning of a tool payload (an unmatched brace, bracket, or code fence).
func splitSafeContentForToolDetection(s string) (safe, hold string) {
	if s == "" {
		return "", ""
	}
	suspiciousStart := findSuspiciousPrefixStart(s)
	if suspiciousStart < 0 {
		// Nothing suspicious anywhere: the whole chunk can be released.
		return s, ""
	}
	if suspiciousStart > 0 {
		return s[:suspiciousStart], s[suspiciousStart:]
	}
	// If suspicious content starts at position 0, keep holding until we can
	// parse a complete tool JSON block or reach stream flush.
	return "", s
}
|
||||
|
||||
// findSuspiciousPrefixStart returns the index of the last occurrence of any
// token that could open a tool payload ("{", "[", or a ``` fence), or -1 if
// none is present. The latest occurrence is used so only the smallest
// possible tail is withheld from emission.
func findSuspiciousPrefixStart(s string) int {
	start := -1
	indices := []int{
		strings.LastIndex(s, "{"),
		strings.LastIndex(s, "["),
		strings.LastIndex(s, "```"),
	}
	for _, idx := range indices {
		if idx > start {
			start = idx
		}
	}
	return start
}
|
||||
|
||||
// findToolSegmentStart reports the byte offset where a suspected tool-call
// segment begins, or -1 when s contains no "tool_calls" marker
// (case-insensitive). When an opening brace precedes the marker, the segment
// starts at that brace so the whole JSON object can be captured.
func findToolSegmentStart(s string) int {
	keyIdx := strings.Index(strings.ToLower(s), "tool_calls")
	if keyIdx < 0 {
		return -1
	}
	brace := strings.LastIndex(s[:keyIdx], "{")
	if brace < 0 {
		return keyIdx
	}
	return brace
}
|
||||
|
||||
// consumeToolCapture tries to cut a complete tool_calls JSON object out of
// captured. When a complete object is found it returns the text before it,
// the parsed calls (nil if strict parsing failed), the text after it, and
// ready=true. ready=false means the capture is still incomplete and more
// input is needed.
func consumeToolCapture(captured string, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
	if captured == "" {
		return "", nil, "", false
	}
	lower := strings.ToLower(captured)
	keyIdx := strings.Index(lower, "tool_calls")
	if keyIdx < 0 {
		return "", nil, "", false
	}
	// The JSON object is expected to open at the last "{" before the key.
	start := strings.LastIndex(captured[:keyIdx], "{")
	if start < 0 {
		return "", nil, "", false
	}
	obj, end, ok := extractJSONObjectFrom(captured, start)
	if !ok {
		return "", nil, "", false
	}
	parsed := util.ParseToolCalls(obj, toolNames)
	if len(parsed) == 0 {
		// `tool_calls` key exists but strict JSON parse failed.
		// Drop the captured object body to avoid leaking raw tool JSON.
		return captured[:start], nil, captured[end:], true
	}
	return captured[:start], parsed, captured[end:], true
}
|
||||
|
||||
// extractJSONObjectFrom scans text from start — which must point at a '{' —
// and returns the balanced JSON object, the index just past it, and true.
// Braces inside single- or double-quoted strings (with backslash escapes)
// are ignored. It returns ("", 0, false) for a bad start or an unterminated
// object.
func extractJSONObjectFrom(text string, start int) (string, int, bool) {
	if start < 0 || start >= len(text) || text[start] != '{' {
		return "", 0, false
	}
	var (
		depth   int
		quote   byte
		escaped bool
	)
	for i := start; i < len(text); i++ {
		c := text[i]
		if quote != 0 {
			// Inside a quoted string: only escapes and the closing quote matter.
			switch {
			case escaped:
				escaped = false
			case c == '\\':
				escaped = true
			case c == quote:
				quote = 0
			}
			continue
		}
		switch c {
		case '"', '\'':
			quote = c
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				return text[start : i+1], i + 1, true
			}
		}
	}
	return "", 0, false
}
|
||||
300
internal/adapter/openai/tool_sieve_core.go
Normal file
300
internal/adapter/openai/tool_sieve_core.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
if chunk != "" {
|
||||
state.pending.WriteString(chunk)
|
||||
}
|
||||
events := make([]toolStreamEvent, 0, 2)
|
||||
if len(state.pendingToolCalls) > 0 {
|
||||
events = append(events, toolStreamEvent{ToolCalls: state.pendingToolCalls})
|
||||
state.pendingToolRaw = ""
|
||||
state.pendingToolCalls = nil
|
||||
}
|
||||
|
||||
for {
|
||||
if state.capturing {
|
||||
if state.pending.Len() > 0 {
|
||||
state.capture.WriteString(state.pending.String())
|
||||
state.pending.Reset()
|
||||
}
|
||||
prefix, calls, suffix, ready := consumeToolCapture(state, toolNames)
|
||||
if !ready {
|
||||
break
|
||||
}
|
||||
captured := state.capture.String()
|
||||
state.capture.Reset()
|
||||
state.capturing = false
|
||||
state.resetIncrementalToolState()
|
||||
if len(calls) > 0 {
|
||||
if prefix != "" {
|
||||
state.noteText(prefix)
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
if suffix != "" {
|
||||
state.pending.WriteString(suffix)
|
||||
}
|
||||
_ = captured
|
||||
state.pendingToolCalls = calls
|
||||
continue
|
||||
}
|
||||
if prefix != "" {
|
||||
state.noteText(prefix)
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
if suffix != "" {
|
||||
state.pending.WriteString(suffix)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
pending := state.pending.String()
|
||||
if pending == "" {
|
||||
break
|
||||
}
|
||||
start := findToolSegmentStart(pending)
|
||||
if start >= 0 {
|
||||
prefix := pending[:start]
|
||||
if prefix != "" {
|
||||
state.noteText(prefix)
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
state.pending.Reset()
|
||||
state.capture.WriteString(pending[start:])
|
||||
state.capturing = true
|
||||
state.resetIncrementalToolState()
|
||||
continue
|
||||
}
|
||||
|
||||
safe, hold := splitSafeContentForToolDetection(pending)
|
||||
if safe == "" {
|
||||
break
|
||||
}
|
||||
state.pending.Reset()
|
||||
state.pending.WriteString(hold)
|
||||
state.noteText(safe)
|
||||
events = append(events, toolStreamEvent{Content: safe})
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
// flushToolSieve drains the sieve at end of stream: it runs one final
// processing pass, releases any deferred tool calls, resolves a partially
// captured payload (emitting it as plain content when it never completed),
// and flushes whatever pending text remains. A nil state yields nil.
func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStreamEvent {
	if state == nil {
		return nil
	}
	events := processToolSieveChunk(state, "", toolNames)
	// Tool calls deferred from the final processing round must go out now.
	if len(state.pendingToolCalls) > 0 {
		events = append(events, toolStreamEvent{ToolCalls: state.pendingToolCalls})
		state.pendingToolRaw = ""
		state.pendingToolCalls = nil
	}
	if state.capturing {
		consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames)
		if ready {
			if consumedPrefix != "" {
				state.noteText(consumedPrefix)
				events = append(events, toolStreamEvent{Content: consumedPrefix})
			}
			if len(consumedCalls) > 0 {
				events = append(events, toolStreamEvent{ToolCalls: consumedCalls})
			}
			if consumedSuffix != "" {
				state.noteText(consumedSuffix)
				events = append(events, toolStreamEvent{Content: consumedSuffix})
			}
		} else {
			// Capture never completed into a payload: surface it as content
			// rather than silently dropping the user's text.
			content := state.capture.String()
			if content != "" {
				state.noteText(content)
				events = append(events, toolStreamEvent{Content: content})
			}
		}
		state.capture.Reset()
		state.capturing = false
		state.resetIncrementalToolState()
	}
	if state.pending.Len() > 0 {
		content := state.pending.String()
		state.noteText(content)
		events = append(events, toolStreamEvent{Content: content})
		state.pending.Reset()
	}
	return events
}
|
||||
|
||||
func splitSafeContentForToolDetection(s string) (safe, hold string) {
|
||||
if s == "" {
|
||||
return "", ""
|
||||
}
|
||||
suspiciousStart := findSuspiciousPrefixStart(s)
|
||||
if suspiciousStart < 0 {
|
||||
return s, ""
|
||||
}
|
||||
if suspiciousStart > 0 {
|
||||
return s[:suspiciousStart], s[suspiciousStart:]
|
||||
}
|
||||
// If suspicious content starts at position 0, keep holding until we can
|
||||
// parse a complete tool JSON block or reach stream flush.
|
||||
return "", s
|
||||
}
|
||||
|
||||
// findSuspiciousPrefixStart returns the index of the last occurrence of any
// token that could open a tool payload ("{", "[", or a ``` fence), or -1 if
// none is present. Using the latest occurrence keeps the withheld tail as
// short as possible.
func findSuspiciousPrefixStart(s string) int {
	best := strings.LastIndex(s, "{")
	if idx := strings.LastIndex(s, "["); idx > best {
		best = idx
	}
	if idx := strings.LastIndex(s, "```"); idx > best {
		best = idx
	}
	return best
}
|
||||
|
||||
// findToolSegmentStart reports the byte offset where a suspected tool-call
// segment begins, or -1 when no marker keyword is present. The earliest of
// the known markers wins; the segment is widened to a preceding "{" (to
// capture the whole JSON object) and further to an unclosed ``` fence so the
// fence is captured together with the payload it wraps.
func findToolSegmentStart(s string) int {
	if s == "" {
		return -1
	}
	lower := strings.ToLower(s)
	keywords := []string{"tool_calls", "function.name:", "[tool_call_history]", "[tool_result_history]"}
	bestKeyIdx := -1
	// Pick the earliest keyword occurrence across all markers.
	for _, kw := range keywords {
		idx := strings.Index(lower, kw)
		if idx >= 0 && (bestKeyIdx < 0 || idx < bestKeyIdx) {
			bestKeyIdx = idx
		}
	}
	if bestKeyIdx < 0 {
		return -1
	}
	start := strings.LastIndex(s[:bestKeyIdx], "{")
	if start < 0 {
		start = bestKeyIdx
	}
	// If the segment sits inside an unclosed code fence, start at the fence.
	if fenceStart, ok := openFenceStartBefore(s, start); ok {
		return fenceStart
	}
	return start
}
|
||||
|
||||
// consumeToolCapture tries to resolve the sieve's captured text into a tool
// payload. It returns the text before the payload, any parsed calls, the
// text after it, and ready=true once a complete unit was cut out;
// ready=false means the capture is still incomplete. Leaked tool-history
// marker blocks are consumed without producing calls.
func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
	captured := state.capture.String()
	if captured == "" {
		return "", nil, "", false
	}
	lower := strings.ToLower(captured)
	keyIdx := -1
	keywords := []string{"tool_calls", "function.name:", "[tool_call_history]", "[tool_result_history]"}
	// Earliest keyword occurrence anchors the payload search.
	for _, kw := range keywords {
		idx := strings.Index(lower, kw)
		if idx >= 0 && (keyIdx < 0 || idx < keyIdx) {
			keyIdx = idx
		}
	}

	if keyIdx < 0 {
		return "", nil, "", false
	}
	start := strings.LastIndex(captured[:keyIdx], "{")
	if start < 0 {
		// No brace before the keyword: it may be a bare history marker block,
		// which is swallowed whole once its closing tag arrives.
		if blockStart, blockEnd, ok := extractToolHistoryBlock(captured, keyIdx); ok {
			return captured[:blockStart], nil, captured[blockEnd:], true
		}
		start = keyIdx
	}
	obj, end, ok := extractJSONObjectFrom(captured, start)
	if !ok {
		return "", nil, "", false
	}
	prefixPart := captured[:start]
	suffixPart := captured[end:]
	parsed := util.ParseStandaloneToolCallsDetailed(obj, toolNames)
	if len(parsed.Calls) == 0 {
		if parsed.SawToolCallSyntax && parsed.RejectedByPolicy {
			// Parsed as tool-call payload but rejected by schema/policy:
			// consume it to avoid leaking raw tool_calls JSON to user content.
			return prefixPart, nil, suffixPart, true
		}
		// If it has obvious keywords but failed to parse even after loose repair,
		// we still might want to intercept it if it looks like an attempt at tool call.
		// For now, keep the original logic but rely on loose JSON repair.
		return captured, nil, "", true
	}
	// Strip a ```json fence that wraps only the payload before emitting.
	prefixPart, suffixPart = trimWrappingJSONFence(prefixPart, suffixPart)
	return prefixPart, parsed.Calls, suffixPart, true
}
|
||||
|
||||
// extractToolHistoryBlock locates a complete [TOOL_CALL_HISTORY] or
// [TOOL_RESULT_HISTORY] block whose opening tag begins at keyIdx (matched
// case-insensitively) and returns its [start, end) byte bounds in captured.
// It reports ok=false when keyIdx is out of range, when no known tag opens
// there, or when the matching closing tag has not arrived yet.
//
// Deduplicated: the original switch repeated the identical scan for each tag
// pair; a small open/close table keeps one copy of the logic.
func extractToolHistoryBlock(captured string, keyIdx int) (start int, end int, ok bool) {
	if keyIdx < 0 || keyIdx >= len(captured) {
		return 0, 0, false
	}
	rest := strings.ToLower(captured[keyIdx:])
	// Open/close tag pairs; lowercase to match the lowered rest.
	pairs := [][2]string{
		{"[tool_call_history]", "[/tool_call_history]"},
		{"[tool_result_history]", "[/tool_result_history]"},
	}
	for _, p := range pairs {
		if !strings.HasPrefix(rest, p[0]) {
			continue
		}
		closeIdx := strings.Index(rest, p[1])
		if closeIdx < 0 {
			// Opening tag seen but the block is not complete yet.
			return 0, 0, false
		}
		return keyIdx, keyIdx + closeIdx + len(p[1]), true
	}
	return 0, 0, false
}
|
||||
|
||||
// trimWrappingJSONFence removes a ``` / ```json fence that immediately wraps
// an extracted tool payload: the opening fence at the end of prefix and the
// closing fence at the start of suffix. If either side does not look like a
// wrapping fence, both strings are returned unchanged.
func trimWrappingJSONFence(prefix, suffix string) (string, string) {
	trimmedPrefix := strings.TrimRight(prefix, " \t\r\n")
	fenceIdx := strings.LastIndex(trimmedPrefix, "```")
	if fenceIdx < 0 {
		return prefix, suffix
	}
	// Only strip when the trailing fence in prefix behaves like an opening fence.
	// A legitimate closing fence before a standalone tool JSON must be preserved.
	// (Odd count of fences up to and including this one means it opens a block.)
	if strings.Count(trimmedPrefix[:fenceIdx+3], "```")%2 == 0 {
		return prefix, suffix
	}
	// Anything after ``` on the fence line must be empty or the "json" tag.
	fenceHeader := strings.TrimSpace(trimmedPrefix[fenceIdx+3:])
	if fenceHeader != "" && !strings.EqualFold(fenceHeader, "json") {
		return prefix, suffix
	}

	// The suffix must begin (after whitespace) with the closing fence.
	trimmedSuffix := strings.TrimLeft(suffix, " \t\r\n")
	if !strings.HasPrefix(trimmedSuffix, "```") {
		return prefix, suffix
	}
	consumedLeading := len(suffix) - len(trimmedSuffix)
	return trimmedPrefix[:fenceIdx], suffix[consumedLeading+3:]
}
|
||||
|
||||
// openFenceStartBefore reports whether position pos in s lies inside an
// unclosed ``` code fence, returning the index of that opening fence when it
// does. An odd number of fences before pos means the last one is still open.
func openFenceStartBefore(s string, pos int) (int, bool) {
	if pos <= 0 || pos > len(s) {
		return -1, false
	}
	head := s[:pos]
	idx := strings.LastIndex(head, "```")
	if idx >= 0 && strings.Count(head, "```")%2 == 1 {
		return idx, true
	}
	return -1, false
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user