mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-02 07:25:26 +08:00
Compare commits
180 Commits
v2.2.1
...
v2.5.1_bet
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c95bf7b667 | ||
|
|
d79565b250 | ||
|
|
dc39de062b | ||
|
|
a7c9dfd7c0 | ||
|
|
822b14ed6b | ||
|
|
af7c7c6770 | ||
|
|
868a60b70b | ||
|
|
30a53b6c43 | ||
|
|
c3c644ff8c | ||
|
|
621599f8ad | ||
|
|
aeb519c211 | ||
|
|
075728cca6 | ||
|
|
883607ac87 | ||
|
|
1d6a8e7008 | ||
|
|
f041ebab93 | ||
|
|
3ab9d44f60 | ||
|
|
4b42fe9086 | ||
|
|
302bcefeb5 | ||
|
|
19b4f879c5 | ||
|
|
56a3ed19e8 | ||
|
|
958f4e39b5 | ||
|
|
6e8f3185d5 | ||
|
|
0925e83b9b | ||
|
|
87c231e736 | ||
|
|
5887821a9d | ||
|
|
7794006513 | ||
|
|
47d4499d47 | ||
|
|
15891ddc25 | ||
|
|
97a81c4191 | ||
|
|
b0a09dfab0 | ||
|
|
58f753d0c0 | ||
|
|
2e0586d060 | ||
|
|
1676c8e4f2 | ||
|
|
add13366d2 | ||
|
|
d5a23191f2 | ||
|
|
d2d4e39983 | ||
|
|
6e0dca3b30 | ||
|
|
b108a7915a | ||
|
|
2caabd8ce6 | ||
|
|
6802a3d53e | ||
|
|
e828006cb0 | ||
|
|
a6499cbece | ||
|
|
a504905626 | ||
|
|
59bf78d2c4 | ||
|
|
6cf4f0528c | ||
|
|
d8f8dcb704 | ||
|
|
455489ffeb | ||
|
|
5031ae0e6f | ||
|
|
3fccec0e22 | ||
|
|
00d38f1187 | ||
|
|
fe0f3d2c17 | ||
|
|
f67cbfad35 | ||
|
|
9afc533153 | ||
|
|
6a39543288 | ||
|
|
8fa1f998aa | ||
|
|
f8936887d0 | ||
|
|
db89744055 | ||
|
|
65312fc573 | ||
|
|
661d753fd3 | ||
|
|
7ca3f141c6 | ||
|
|
d530d25793 | ||
|
|
990cdcf02d | ||
|
|
648bb74587 | ||
|
|
9e5baed061 | ||
|
|
4884773639 | ||
|
|
6758514c61 | ||
|
|
01f33c409f | ||
|
|
55f11e655a | ||
|
|
2275e931f9 | ||
|
|
40594a44db | ||
|
|
67787d9c99 | ||
|
|
7061094964 | ||
|
|
492c603300 | ||
|
|
7e473dffc9 | ||
|
|
43a6e6712f | ||
|
|
ce1b76c90f | ||
|
|
1e7e0b2ae3 | ||
|
|
fd158e5ae2 | ||
|
|
95c96f7744 | ||
|
|
e7f59fac80 | ||
|
|
1bf059396f | ||
|
|
696b403173 | ||
|
|
f4db2732b0 | ||
|
|
ee88a74dcf | ||
|
|
ca08bb66b9 | ||
|
|
708fcb5beb | ||
|
|
7a65d1eaa2 | ||
|
|
6de2457743 | ||
|
|
ce44e260bf | ||
|
|
09f6537ffc | ||
|
|
ab8f494fdb | ||
|
|
b56a211da9 | ||
|
|
fcce5308cb | ||
|
|
d27b19cc53 | ||
|
|
b8ff678f24 | ||
|
|
b24ef1282d | ||
|
|
65e0de3c82 | ||
|
|
0c2743a48c | ||
|
|
dc73e8a6da | ||
|
|
b8495eeeb3 | ||
|
|
b3eae22cef | ||
|
|
7af0098d1b | ||
|
|
17405be300 | ||
|
|
5bc03e5de6 | ||
|
|
5a5f93148d | ||
|
|
32dc5b6099 | ||
|
|
7936d4675f | ||
|
|
808eafa7c6 | ||
|
|
bcb8ed6df2 | ||
|
|
8ec5dcc0cc | ||
|
|
88a79f212d | ||
|
|
b1f8d6192f | ||
|
|
acfb3b225d | ||
|
|
99a6164000 | ||
|
|
e49d9d33e2 | ||
|
|
184a3d1e4e | ||
|
|
c4ec14f49a | ||
|
|
fb5fc0e885 | ||
|
|
20b603666d | ||
|
|
4d549b7102 | ||
|
|
33b0d1d144 | ||
|
|
41c0f7ce28 | ||
|
|
efb484ba4f | ||
|
|
145501d4a5 | ||
|
|
2d5103997b | ||
|
|
52e7e7aae8 | ||
|
|
5b5a4000d7 | ||
|
|
2bbf603148 | ||
|
|
d14b8a0664 | ||
|
|
f16e0b579e | ||
|
|
43cbc4aac0 | ||
|
|
cf569f4749 | ||
|
|
c9c59f2490 | ||
|
|
16216cc2ca | ||
|
|
de50fd3954 | ||
|
|
7648d5f192 | ||
|
|
d35e5eab25 | ||
|
|
90610a52ce | ||
|
|
f6296d506f | ||
|
|
dfea092583 | ||
|
|
af7dc134bb | ||
|
|
2657d37f76 | ||
|
|
7318d1f4a8 | ||
|
|
f2674487c7 | ||
|
|
71cdcb43e8 | ||
|
|
9c46c3a874 | ||
|
|
12d5f136d5 | ||
|
|
00c37d8d2f | ||
|
|
0f1985af4a | ||
|
|
fa8affe1b7 | ||
|
|
c59a0b7799 | ||
|
|
bd72b91f27 | ||
|
|
9240f85246 | ||
|
|
ea4bd1e483 | ||
|
|
9e0de62707 | ||
|
|
128de290db | ||
|
|
286d266723 | ||
|
|
8aad1005b2 | ||
|
|
11b2f24fc2 | ||
|
|
d1f08cbb89 | ||
|
|
60e9d707d4 | ||
|
|
9b93badb57 | ||
|
|
892213071a | ||
|
|
5484d6e59d | ||
|
|
0ce3fd22a7 | ||
|
|
25e40cc3a6 | ||
|
|
af68d21095 | ||
|
|
1fafd25e86 | ||
|
|
5f8f28a943 | ||
|
|
94cf1bfcc7 | ||
|
|
13562cf521 | ||
|
|
d27e700c4f | ||
|
|
d6bce5af93 | ||
|
|
75969e710d | ||
|
|
6c39c8e191 | ||
|
|
0e261ff0a0 | ||
|
|
fab326eca1 | ||
|
|
c033eceee7 | ||
|
|
a10e03ebe0 | ||
|
|
a6aa4a1839 |
100
.env.example
100
.env.example
@@ -1,93 +1,17 @@
|
||||
# DS2API environment template (Go runtime)
|
||||
# Copy this file to .env and adjust values.
|
||||
# Updated: 2026-02
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Runtime
|
||||
# ---------------------------------------------------------------
|
||||
# HTTP listen port (default: 5001)
|
||||
# DS2API runtime
|
||||
PORT=5001
|
||||
|
||||
# Log level: DEBUG | INFO | WARN | ERROR
|
||||
LOG_LEVEL=INFO
|
||||
|
||||
# Max concurrent inflight requests per account in managed-key mode.
|
||||
# Default: 2
|
||||
# Recommended client concurrency is calculated dynamically as:
|
||||
# account_count * DS2API_ACCOUNT_MAX_INFLIGHT
|
||||
# So by default it is account_count * 2.
|
||||
# Requests beyond inflight slots enter a waiting queue first.
|
||||
# Default queue size equals recommended concurrency, so 429 starts after:
|
||||
# account_count * DS2API_ACCOUNT_MAX_INFLIGHT * 2
|
||||
# Alias: DS2API_ACCOUNT_CONCURRENCY
|
||||
# DS2API_ACCOUNT_MAX_INFLIGHT=2
|
||||
# Admin authentication
|
||||
DS2API_ADMIN_KEY=change-me
|
||||
|
||||
# Optional waiting queue size override for managed-key mode.
|
||||
# Default: recommended_concurrency (same as account_count * inflight_limit)
|
||||
# Alias: DS2API_ACCOUNT_QUEUE_SIZE
|
||||
# DS2API_ACCOUNT_MAX_QUEUE=10
|
||||
# Config loading (choose one)
|
||||
# 1) file-based config
|
||||
DS2API_CONFIG_PATH=/app/config.json
|
||||
# 2) inline JSON or Base64 JSON
|
||||
# DS2API_CONFIG_JSON=
|
||||
# 3) legacy compatibility alias
|
||||
# CONFIG_JSON=
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Admin auth
|
||||
# ---------------------------------------------------------------
|
||||
# Admin key for /admin login and protected admin APIs.
|
||||
# Default is "admin" when unset, but setting it explicitly is recommended.
|
||||
DS2API_ADMIN_KEY=admin
|
||||
|
||||
# Optional JWT signing secret for admin token.
|
||||
# Defaults to DS2API_ADMIN_KEY when unset.
|
||||
# DS2API_JWT_SECRET=change-me
|
||||
|
||||
# Optional admin JWT validity in hours (default: 24)
|
||||
# DS2API_JWT_EXPIRE_HOURS=24
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Config source (choose one)
|
||||
# ---------------------------------------------------------------
|
||||
# Option A: config file path (local/dev recommended)
|
||||
# DS2API_CONFIG_PATH=config.json
|
||||
|
||||
# Option B: JSON string
|
||||
# DS2API_CONFIG_JSON={"keys":["your-api-key"],"accounts":[{"email":"user@example.com","password":"xxx","token":""}]}
|
||||
|
||||
# Option C: Base64 encoded JSON (recommended for Vercel env var)
|
||||
# DS2API_CONFIG_JSON=eyJrZXlzIjpbInlvdXItYXBpLWtleSJdLCJhY2NvdW50cyI6W3siZW1haWwiOiJ1c2VyQGV4YW1wbGUuY29tIiwicGFzc3dvcmQiOiJ4eHgiLCJ0b2tlbiI6IiJ9XX0=
|
||||
#
|
||||
# Generate from local config.json:
|
||||
# DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Paths (optional)
|
||||
# ---------------------------------------------------------------
|
||||
# WASM file used for PoW solving
|
||||
# DS2API_WASM_PATH=sha3_wasm_bg.7b9ca65ddd.wasm
|
||||
|
||||
# Built admin static assets directory
|
||||
# DS2API_STATIC_ADMIN_DIR=static/admin
|
||||
|
||||
# Auto-build WebUI on startup when static/admin is missing.
|
||||
# Default: enabled on local/Docker, disabled on Vercel.
|
||||
# DS2API_AUTO_BUILD_WEBUI=true
|
||||
|
||||
# Internal auth secret used by the Vercel hybrid streaming path
|
||||
# (Go prepare endpoint <-> Node stream function).
|
||||
# Optional: falls back to DS2API_ADMIN_KEY when unset.
|
||||
# DS2API_VERCEL_INTERNAL_SECRET=change-me
|
||||
|
||||
# Stream lease TTL seconds for Vercel hybrid streaming.
|
||||
# During this window, the managed account stays occupied until Node calls release.
|
||||
# Default: 900 (15 minutes)
|
||||
# DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS=900
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Vercel sync integration (optional)
|
||||
# ---------------------------------------------------------------
|
||||
# VERCEL_TOKEN=your-vercel-token
|
||||
# VERCEL_PROJECT_ID=prj_xxxxxxxxxxxx
|
||||
# VERCEL_TEAM_ID=team_xxxxxxxxxxxx
|
||||
|
||||
# Optional: Vercel deployment protection bypass secret.
|
||||
# If deployment protection is enabled, DS2API will use this value as
|
||||
# x-vercel-protection-bypass for internal Node->Go calls on Vercel.
|
||||
# You can also use VERCEL_AUTOMATION_BYPASS_SECRET directly.
|
||||
# DS2API_VERCEL_PROTECTION_BYPASS=your-bypass-secret
|
||||
# Optional: static admin assets path
|
||||
# DS2API_STATIC_ADMIN_DIR=/app/static/admin
|
||||
|
||||
2
.github/workflows/quality-gates.yml
vendored
2
.github/workflows/quality-gates.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
- name: Setup Node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
node-version: "24"
|
||||
cache: "npm"
|
||||
cache-dependency-path: webui/package-lock.json
|
||||
|
||||
|
||||
10
.github/workflows/release-artifacts.yml
vendored
10
.github/workflows/release-artifacts.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
- name: Setup Node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
node-version: "24"
|
||||
cache: "npm"
|
||||
cache-dependency-path: webui/package-lock.json
|
||||
|
||||
@@ -51,6 +51,10 @@ jobs:
|
||||
run: |
|
||||
set -euo pipefail
|
||||
TAG="${RELEASE_TAG}"
|
||||
BUILD_VERSION="${TAG}"
|
||||
if [ -z "${BUILD_VERSION}" ] && [ -f VERSION ]; then
|
||||
BUILD_VERSION="$(cat VERSION | tr -d '[:space:]')"
|
||||
fi
|
||||
mkdir -p dist
|
||||
|
||||
targets=(
|
||||
@@ -73,9 +77,9 @@ jobs:
|
||||
|
||||
mkdir -p "${STAGE}/static"
|
||||
CGO_ENABLED=0 GOOS="${GOOS}" GOARCH="${GOARCH}" \
|
||||
go build -trimpath -ldflags="-s -w" -o "${STAGE}/${BIN}" ./cmd/ds2api
|
||||
go build -trimpath -ldflags="-s -w -X ds2api/internal/version.BuildVersion=${BUILD_VERSION}" -o "${STAGE}/${BIN}" ./cmd/ds2api
|
||||
|
||||
cp config.example.json .env.example sha3_wasm_bg.7b9ca65ddd.wasm LICENSE README.MD README.en.md "${STAGE}/"
|
||||
cp config.example.json .env.example internal/deepseek/assets/sha3_wasm_bg.7b9ca65ddd.wasm LICENSE README.MD README.en.md "${STAGE}/"
|
||||
cp -R static/admin "${STAGE}/static/admin"
|
||||
|
||||
if [ "${GOOS}" = "windows" ]; then
|
||||
|
||||
2
.github/workflows/release-dockerhub.yml
vendored
2
.github/workflows/release-dockerhub.yml
vendored
@@ -123,5 +123,7 @@ jobs:
|
||||
labels: |
|
||||
org.opencontainers.image.version=${{ steps.next_version.outputs.new_version }}
|
||||
org.opencontainers.image.revision=${{ github.sha }}
|
||||
build-args: |
|
||||
BUILD_VERSION=${{ steps.next_version.outputs.new_tag }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
258
.github/workflows/release.yml
vendored
258
.github/workflows/release.yml
vendored
@@ -1,128 +1,130 @@
|
||||
name: Release to Aliyun CR
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version_type:
|
||||
description: '版本类型'
|
||||
required: true
|
||||
default: 'patch'
|
||||
type: choice
|
||||
options:
|
||||
- patch
|
||||
- minor
|
||||
- major
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Get current version
|
||||
id: get_version
|
||||
run: |
|
||||
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
|
||||
TAG_VERSION=${LATEST_TAG#v}
|
||||
|
||||
if [ -f VERSION ]; then
|
||||
FILE_VERSION=$(cat VERSION | tr -d '[:space:]')
|
||||
else
|
||||
FILE_VERSION="0.0.0"
|
||||
fi
|
||||
|
||||
function version_gt() { test "$(printf '%s\n' "$@" | sort -V | head -n 1)" != "$1"; }
|
||||
|
||||
if version_gt "$FILE_VERSION" "$TAG_VERSION"; then
|
||||
VERSION="$FILE_VERSION"
|
||||
else
|
||||
VERSION="$TAG_VERSION"
|
||||
fi
|
||||
|
||||
echo "Current version: $VERSION"
|
||||
echo "current_version=$VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Calculate next version
|
||||
id: next_version
|
||||
env:
|
||||
VERSION_TYPE: ${{ github.event.inputs.version_type }}
|
||||
run: |
|
||||
VERSION="${{ steps.get_version.outputs.current_version }}"
|
||||
BASE_VERSION=$(echo "$VERSION" | sed 's/-.*$//')
|
||||
|
||||
IFS='.' read -r -a version_parts <<< "$BASE_VERSION"
|
||||
MAJOR="${version_parts[0]:-0}"
|
||||
MINOR="${version_parts[1]:-0}"
|
||||
PATCH="${version_parts[2]:-0}"
|
||||
|
||||
case "$VERSION_TYPE" in
|
||||
major)
|
||||
NEW_VERSION="$((MAJOR + 1)).0.0"
|
||||
;;
|
||||
minor)
|
||||
NEW_VERSION="${MAJOR}.$((MINOR + 1)).0"
|
||||
;;
|
||||
*)
|
||||
NEW_VERSION="${MAJOR}.${MINOR}.$((PATCH + 1))"
|
||||
;;
|
||||
esac
|
||||
|
||||
echo "New version: $NEW_VERSION"
|
||||
echo "new_version=$NEW_VERSION" >> $GITHUB_OUTPUT
|
||||
echo "new_tag=v$NEW_VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Update VERSION file
|
||||
run: |
|
||||
echo "${{ steps.next_version.outputs.new_version }}" > VERSION
|
||||
|
||||
- name: Commit VERSION and create tag
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
git add VERSION
|
||||
if ! git diff --cached --quiet; then
|
||||
git commit -m "chore: bump version to ${{ steps.next_version.outputs.new_tag }} [skip ci]"
|
||||
fi
|
||||
|
||||
NEW_TAG="${{ steps.next_version.outputs.new_tag }}"
|
||||
git tag -a "$NEW_TAG" -m "Release $NEW_TAG"
|
||||
git push origin HEAD:main "$NEW_TAG"
|
||||
|
||||
# Docker 构建并推送到阿里云
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Aliyun Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ secrets.ALIYUN_REGISTRY }}
|
||||
username: ${{ secrets.ALIYUN_REGISTRY_USER }}
|
||||
password: ${{ secrets.ALIYUN_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: ./Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ secrets.ALIYUN_REGISTRY }}/${{ secrets.ALIYUN_REGISTRY_NAMESPACE }}/ds2api:${{ steps.next_version.outputs.new_tag }}
|
||||
${{ secrets.ALIYUN_REGISTRY }}/${{ secrets.ALIYUN_REGISTRY_NAMESPACE }}/ds2api:${{ steps.next_version.outputs.new_version }}
|
||||
${{ secrets.ALIYUN_REGISTRY }}/${{ secrets.ALIYUN_REGISTRY_NAMESPACE }}/ds2api:latest
|
||||
labels: |
|
||||
org.opencontainers.image.version=${{ steps.next_version.outputs.new_version }}
|
||||
org.opencontainers.image.revision=${{ github.sha }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
name: Release to Aliyun CR
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version_type:
|
||||
description: '版本类型'
|
||||
required: true
|
||||
default: 'patch'
|
||||
type: choice
|
||||
options:
|
||||
- patch
|
||||
- minor
|
||||
- major
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Get current version
|
||||
id: get_version
|
||||
run: |
|
||||
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
|
||||
TAG_VERSION=${LATEST_TAG#v}
|
||||
|
||||
if [ -f VERSION ]; then
|
||||
FILE_VERSION=$(cat VERSION | tr -d '[:space:]')
|
||||
else
|
||||
FILE_VERSION="0.0.0"
|
||||
fi
|
||||
|
||||
function version_gt() { test "$(printf '%s\n' "$@" | sort -V | head -n 1)" != "$1"; }
|
||||
|
||||
if version_gt "$FILE_VERSION" "$TAG_VERSION"; then
|
||||
VERSION="$FILE_VERSION"
|
||||
else
|
||||
VERSION="$TAG_VERSION"
|
||||
fi
|
||||
|
||||
echo "Current version: $VERSION"
|
||||
echo "current_version=$VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Calculate next version
|
||||
id: next_version
|
||||
env:
|
||||
VERSION_TYPE: ${{ github.event.inputs.version_type }}
|
||||
run: |
|
||||
VERSION="${{ steps.get_version.outputs.current_version }}"
|
||||
BASE_VERSION=$(echo "$VERSION" | sed 's/-.*$//')
|
||||
|
||||
IFS='.' read -r -a version_parts <<< "$BASE_VERSION"
|
||||
MAJOR="${version_parts[0]:-0}"
|
||||
MINOR="${version_parts[1]:-0}"
|
||||
PATCH="${version_parts[2]:-0}"
|
||||
|
||||
case "$VERSION_TYPE" in
|
||||
major)
|
||||
NEW_VERSION="$((MAJOR + 1)).0.0"
|
||||
;;
|
||||
minor)
|
||||
NEW_VERSION="${MAJOR}.$((MINOR + 1)).0"
|
||||
;;
|
||||
*)
|
||||
NEW_VERSION="${MAJOR}.${MINOR}.$((PATCH + 1))"
|
||||
;;
|
||||
esac
|
||||
|
||||
echo "New version: $NEW_VERSION"
|
||||
echo "new_version=$NEW_VERSION" >> $GITHUB_OUTPUT
|
||||
echo "new_tag=v$NEW_VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Update VERSION file
|
||||
run: |
|
||||
echo "${{ steps.next_version.outputs.new_version }}" > VERSION
|
||||
|
||||
- name: Commit VERSION and create tag
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
git add VERSION
|
||||
if ! git diff --cached --quiet; then
|
||||
git commit -m "chore: bump version to ${{ steps.next_version.outputs.new_tag }} [skip ci]"
|
||||
fi
|
||||
|
||||
NEW_TAG="${{ steps.next_version.outputs.new_tag }}"
|
||||
git tag -a "$NEW_TAG" -m "Release $NEW_TAG"
|
||||
git push origin HEAD:main "$NEW_TAG"
|
||||
|
||||
# Docker 构建并推送到阿里云
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Aliyun Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ secrets.ALIYUN_REGISTRY }}
|
||||
username: ${{ secrets.ALIYUN_REGISTRY_USER }}
|
||||
password: ${{ secrets.ALIYUN_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: ./Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ secrets.ALIYUN_REGISTRY }}/${{ secrets.ALIYUN_REGISTRY_NAMESPACE }}/ds2api:${{ steps.next_version.outputs.new_tag }}
|
||||
${{ secrets.ALIYUN_REGISTRY }}/${{ secrets.ALIYUN_REGISTRY_NAMESPACE }}/ds2api:${{ steps.next_version.outputs.new_version }}
|
||||
${{ secrets.ALIYUN_REGISTRY }}/${{ secrets.ALIYUN_REGISTRY_NAMESPACE }}/ds2api:latest
|
||||
labels: |
|
||||
org.opencontainers.image.version=${{ steps.next_version.outputs.new_version }}
|
||||
org.opencontainers.image.revision=${{ github.sha }}
|
||||
build-args: |
|
||||
BUILD_VERSION=${{ steps.next_version.outputs.new_tag }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
92
API.en.md
92
API.en.md
@@ -46,6 +46,7 @@ Use it per deployment mode:
|
||||
|
||||
- Local run: read `config.json` directly
|
||||
- Docker / Vercel: generate Base64 from `config.json`, then set `DS2API_CONFIG_JSON`
|
||||
- Compatibility note: `DS2API_CONFIG_JSON` may also contain raw JSON directly; `CONFIG_JSON` is the legacy fallback variable
|
||||
|
||||
```bash
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
@@ -65,6 +66,7 @@ Two header formats accepted:
|
||||
| --- | --- |
|
||||
| Bearer Token | `Authorization: Bearer <token>` |
|
||||
| API Key Header | `x-api-key: <token>` (no `Bearer` prefix) |
|
||||
| Gemini-compatible | `x-goog-api-key: <token>` or `?key=<token>` / `?api_key=<token>` |
|
||||
|
||||
**Auth behavior**:
|
||||
|
||||
@@ -72,6 +74,7 @@ Two header formats accepted:
|
||||
- Token is not in `config.keys` → **Direct token mode**: treated as a DeepSeek token directly
|
||||
|
||||
**Optional header**: `X-Ds2-Target-Account: <email_or_mobile>` — Pin a specific managed account.
|
||||
Gemini-compatible clients can also send `x-goog-api-key`, `?key=`, or `?api_key=` as the caller credential source.
|
||||
|
||||
### Admin Endpoints (`/admin/*`)
|
||||
|
||||
@@ -124,13 +127,16 @@ Two header formats accepted:
|
||||
| GET | `/admin/queue/status` | Admin | Account queue status |
|
||||
| POST | `/admin/accounts/test` | Admin | Test one account |
|
||||
| POST | `/admin/accounts/test-all` | Admin | Test all accounts |
|
||||
| POST | `/admin/accounts/sessions/delete-all` | Admin | Delete all sessions for one account |
|
||||
| POST | `/admin/import` | Admin | Batch import keys/accounts |
|
||||
| POST | `/admin/test` | Admin | Test API through service |
|
||||
| POST | `/admin/vercel/sync` | Admin | Sync config to Vercel |
|
||||
| GET | `/admin/vercel/status` | Admin | Vercel sync status |
|
||||
| POST | `/admin/vercel/status` | Admin | Vercel sync status / draft compare |
|
||||
| GET | `/admin/export` | Admin | Export config JSON/Base64 |
|
||||
| GET | `/admin/dev/captures` | Admin | Read local packet-capture entries |
|
||||
| DELETE | `/admin/dev/captures` | Admin | Clear local packet-capture entries |
|
||||
| GET | `/admin/version` | Admin | Check current version and latest Release |
|
||||
|
||||
---
|
||||
|
||||
@@ -580,6 +586,7 @@ Returns sanitized config.
|
||||
```json
|
||||
{
|
||||
"keys": ["k1", "k2"],
|
||||
"env_backed": false,
|
||||
"accounts": [
|
||||
{
|
||||
"identifier": "user@example.com",
|
||||
@@ -599,7 +606,7 @@ Returns sanitized config.
|
||||
|
||||
### `POST /admin/config`
|
||||
|
||||
Updatable fields: `keys`, `accounts`, `claude_mapping`.
|
||||
Only updates `keys`, `accounts`, and `claude_mapping`.
|
||||
|
||||
**Request**:
|
||||
|
||||
@@ -620,23 +627,27 @@ Updatable fields: `keys`, `accounts`, `claude_mapping`.
|
||||
|
||||
Reads runtime settings and status, including:
|
||||
|
||||
- `admin` (JWT expiry, default-password warning, etc.)
|
||||
- `runtime` (`account_max_inflight`, `account_max_queue`, `global_max_inflight`)
|
||||
- `toolcall` / `responses` / `embeddings`
|
||||
- `success`
|
||||
- `admin` (`has_password_hash`, `jwt_expire_hours`, `jwt_valid_after_unix`, `default_password_warning`)
|
||||
- `runtime` (`account_max_inflight`, `account_max_queue`, `global_max_inflight`, `token_refresh_interval_hours`)
|
||||
- `responses` / `embeddings`
|
||||
- `auto_delete` (`sessions`)
|
||||
- `claude_mapping` / `model_aliases`
|
||||
- `env_backed`, `needs_vercel_sync`
|
||||
- `toolcall` policy is fixed to `feature_match + high` and is no longer returned or editable via settings
|
||||
|
||||
### `PUT /admin/settings`
|
||||
|
||||
Hot-updates runtime settings. Supported fields:
|
||||
|
||||
- `admin.jwt_expire_hours`
|
||||
- `runtime.account_max_inflight` / `runtime.account_max_queue` / `runtime.global_max_inflight`
|
||||
- `toolcall.mode` / `toolcall.early_emit_confidence`
|
||||
- `runtime.account_max_inflight` / `runtime.account_max_queue` / `runtime.global_max_inflight` / `runtime.token_refresh_interval_hours`
|
||||
- `responses.store_ttl_seconds`
|
||||
- `embeddings.provider`
|
||||
- `auto_delete.sessions`
|
||||
- `claude_mapping`
|
||||
- `model_aliases`
|
||||
- `toolcall` policy is fixed and is no longer writable through settings
|
||||
|
||||
### `POST /admin/settings/password`
|
||||
|
||||
@@ -648,6 +659,8 @@ Request example:
|
||||
{"new_password":"your-new-password"}
|
||||
```
|
||||
|
||||
It also accepts `{"password":"your-new-password"}`.
|
||||
|
||||
### `POST /admin/config/import`
|
||||
|
||||
Imports full config with:
|
||||
@@ -656,6 +669,8 @@ Imports full config with:
|
||||
- `mode=replace`
|
||||
|
||||
The request can send config directly, or wrapped as `{"config": {...}, "mode":"merge"}`.
|
||||
Query params `?mode=merge` / `?mode=replace` are also supported.
|
||||
Import accepts `keys`, `accounts`, `claude_mapping` / `claude_model_mapping`, `model_aliases`, `admin`, `runtime`, `responses`, `embeddings`, and `auto_delete`; legacy `toolcall` fields are ignored.
|
||||
|
||||
### `GET /admin/config/export`
|
||||
|
||||
@@ -681,6 +696,7 @@ Exports full config in three forms: `config`, `json`, and `base64`.
|
||||
| --- | --- | --- |
|
||||
| `page` | `1` | ≥ 1 |
|
||||
| `page_size` | `10` | 1–100 |
|
||||
| `q` | empty | Filter by identifier / email / mobile |
|
||||
|
||||
**Response**:
|
||||
|
||||
@@ -693,7 +709,8 @@ Exports full config in three forms: `config`, `json`, and `base64`.
|
||||
"mobile": "",
|
||||
"has_password": true,
|
||||
"has_token": true,
|
||||
"token_preview": "abc..."
|
||||
"token_preview": "abc...",
|
||||
"test_status": "ok"
|
||||
}
|
||||
],
|
||||
"total": 25,
|
||||
@@ -703,6 +720,8 @@ Exports full config in three forms: `config`, `json`, and `base64`.
|
||||
}
|
||||
```
|
||||
|
||||
Returned items also include `test_status`, usually `ok` or `failed`.
|
||||
|
||||
### `POST /admin/accounts`
|
||||
|
||||
```json
|
||||
@@ -755,10 +774,14 @@ Exports full config in three forms: `config`, `json`, and `base64`.
|
||||
"success": true,
|
||||
"response_time": 1240,
|
||||
"message": "API test successful (session creation only)",
|
||||
"model": "deepseek-chat"
|
||||
"model": "deepseek-chat",
|
||||
"session_count": 0,
|
||||
"config_writable": true
|
||||
}
|
||||
```
|
||||
|
||||
If a `message` is provided, `thinking` may also be included when the upstream response carries reasoning text.
|
||||
|
||||
### `POST /admin/accounts/test-all`
|
||||
|
||||
Optional request field: `model`.
|
||||
@@ -772,6 +795,25 @@ Optional request field: `model`.
|
||||
}
|
||||
```
|
||||
|
||||
The internal concurrency limit is currently fixed at 5.
|
||||
|
||||
### `POST /admin/accounts/sessions/delete-all`
|
||||
|
||||
Deletes all DeepSeek sessions for a specific account. Request example:
|
||||
|
||||
```json
|
||||
{"identifier":"user@example.com"}
|
||||
```
|
||||
|
||||
Response:
|
||||
|
||||
```json
|
||||
{"success": true, "message": "删除成功"}
|
||||
```
|
||||
|
||||
If the account is missing or deletion fails, `success` becomes `false` and `message` contains the error.
|
||||
The current handler returns the Chinese literal `删除成功` on success.
|
||||
|
||||
### `POST /admin/import`
|
||||
|
||||
Batch import keys and accounts.
|
||||
@@ -849,16 +891,25 @@ Or manual deploy required:
|
||||
}
|
||||
```
|
||||
|
||||
Failed account checks are returned in `failed_accounts`, and any saved Vercel credentials are returned in `saved_credentials`.
|
||||
|
||||
### `GET /admin/vercel/status`
|
||||
|
||||
```json
|
||||
{
|
||||
"synced": true,
|
||||
"last_sync_time": 1738400000,
|
||||
"has_synced_before": true
|
||||
"has_synced_before": true,
|
||||
"env_backed": false,
|
||||
"config_hash": "....",
|
||||
"last_synced_hash": "....",
|
||||
"draft_hash": "....",
|
||||
"draft_differs": false
|
||||
}
|
||||
```
|
||||
|
||||
`POST /admin/vercel/status` can also accept `config_override` to compare a draft config against the current synced config.
|
||||
|
||||
### `GET /admin/export`
|
||||
|
||||
```json
|
||||
@@ -868,6 +919,29 @@ Or manual deploy required:
|
||||
}
|
||||
```
|
||||
|
||||
This is the same payload as `GET /admin/config/export`, just with a shorter path.
|
||||
|
||||
### `GET /admin/version`
|
||||
|
||||
Checks the current build version and the latest GitHub Release:
|
||||
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"current_version": "2.3.5",
|
||||
"current_tag": "v2.3.5",
|
||||
"source": "file:VERSION",
|
||||
"checked_at": "2026-03-29T00:00:00Z",
|
||||
"latest_tag": "v2.3.6",
|
||||
"latest_version": "2.3.6",
|
||||
"release_url": "https://github.com/CJackHwang/ds2api/releases/tag/v2.3.6",
|
||||
"published_at": "2026-03-28T12:00:00Z",
|
||||
"has_update": true
|
||||
}
|
||||
```
|
||||
|
||||
If GitHub API access fails, the response includes `check_error` while still returning HTTP 200.
|
||||
|
||||
### `GET /admin/dev/captures`
|
||||
|
||||
Reads local packet-capture status and recent entries (Admin auth required):
|
||||
|
||||
102
API.md
102
API.md
@@ -46,6 +46,7 @@ cp config.example.json config.json
|
||||
|
||||
- 本地运行:直接读取 `config.json`
|
||||
- Docker / Vercel:从 `config.json` 生成 Base64,填入 `DS2API_CONFIG_JSON`
|
||||
- 兼容写法:`DS2API_CONFIG_JSON` 也可直接填原始 JSON;`CONFIG_JSON` 是旧版兼容回退变量
|
||||
|
||||
```bash
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
@@ -65,6 +66,7 @@ Vercel 一键部署可先只填 `DS2API_ADMIN_KEY`,部署后在 `/admin` 导
|
||||
| --- | --- |
|
||||
| Bearer Token | `Authorization: Bearer <token>` |
|
||||
| API Key Header | `x-api-key: <token>`(无 `Bearer` 前缀) |
|
||||
| Gemini 兼容 | `x-goog-api-key: <token>` 或 `?key=<token>` / `?api_key=<token>` |
|
||||
|
||||
**鉴权行为**:
|
||||
|
||||
@@ -72,6 +74,7 @@ Vercel 一键部署可先只填 `DS2API_ADMIN_KEY`,部署后在 `/admin` 导
|
||||
- token 不在 `config.keys` 中 → **直通 token 模式**,直接作为 DeepSeek token 使用
|
||||
|
||||
**可选请求头**:`X-Ds2-Target-Account: <email_or_mobile>` — 指定使用某个托管账号。
|
||||
Gemini 兼容客户端还可以使用 `x-goog-api-key`、`?key=` 或 `?api_key=` 作为凭据来源。
|
||||
|
||||
### Admin 接口(`/admin/*`)
|
||||
|
||||
@@ -124,13 +127,16 @@ Vercel 一键部署可先只填 `DS2API_ADMIN_KEY`,部署后在 `/admin` 导
|
||||
| GET | `/admin/queue/status` | Admin | 账号队列状态 |
|
||||
| POST | `/admin/accounts/test` | Admin | 测试单个账号 |
|
||||
| POST | `/admin/accounts/test-all` | Admin | 测试全部账号 |
|
||||
| POST | `/admin/accounts/sessions/delete-all` | Admin | 删除某账号的全部会话 |
|
||||
| POST | `/admin/import` | Admin | 批量导入 keys/accounts |
|
||||
| POST | `/admin/test` | Admin | 测试当前 API 可用性 |
|
||||
| POST | `/admin/vercel/sync` | Admin | 同步配置到 Vercel |
|
||||
| GET | `/admin/vercel/status` | Admin | Vercel 同步状态 |
|
||||
| POST | `/admin/vercel/status` | Admin | Vercel 同步状态 / 草稿对比 |
|
||||
| GET | `/admin/export` | Admin | 导出配置 JSON/Base64 |
|
||||
| GET | `/admin/dev/captures` | Admin | 查看本地抓包记录 |
|
||||
| DELETE | `/admin/dev/captures` | Admin | 清空本地抓包记录 |
|
||||
| GET | `/admin/version` | Admin | 查询当前版本与最新 Release |
|
||||
|
||||
---
|
||||
|
||||
@@ -284,6 +290,12 @@ data: [DONE]
|
||||
|
||||
**流式**:命中高置信特征后立即输出 `delta.tool_calls`(不等待完整 JSON 闭合),并持续发送 arguments 增量;已确认的 toolcall 原始 JSON 不会回流到 `delta.content`。
|
||||
|
||||
补充说明:
|
||||
|
||||
- **非代码块上下文**下,工具负载即使与普通文本混合,也会按特征识别并产出可执行 tool call(前后普通文本仍可透传)。
|
||||
- 解析器以 XML/Markup 为最高优先级,并兼容 JSON、ANTML、text-kv 等格式输入;最终按客户端协议转译为对应 tool call 结构(OpenAI/Claude/Gemini)。
|
||||
- Markdown fenced code block(例如 ```json ... ```)中的 `tool_calls` 仅视为示例文本,不会被执行。
|
||||
|
||||
---
|
||||
|
||||
### `GET /v1/models/{id}`
|
||||
@@ -301,7 +313,7 @@ OpenAI Responses 风格接口,兼容 `input` 或 `messages`。
|
||||
| `messages` | array | ❌ | 与 `input` 二选一 |
|
||||
| `instructions` | string | ❌ | 自动前置为 system 消息 |
|
||||
| `stream` | boolean | ❌ | 默认 `false` |
|
||||
| `tools` | array | ❌ | 与 chat 同样的工具识别与转译策略 |
|
||||
| `tools` | array | ❌ | 与 chat 同样的工具识别与转译策略(含代码块示例豁免) |
|
||||
| `tool_choice` | string/object | ❌ | 支持 `auto`/`none`/`required` 与强制函数(`{"type":"function","name":"..."}`) |
|
||||
|
||||
**非流式响应**:返回标准 `response` 对象,`id` 形如 `resp_xxx`,并写入内存 TTL 存储。
|
||||
@@ -341,7 +353,8 @@ data: [DONE]
|
||||
```
|
||||
|
||||
流式场景下若 `tool_choice=required` 违规,会返回 `response.failed` 后结束(不再发送 `response.completed`)。
|
||||
未在 `tools` 声明中的工具名会被严格拒绝,不会作为有效 tool call 下发。
|
||||
|
||||
> 当前版本说明:解析层默认“尽量提取结构化 tool call”,未启用基于 `tools` allow-list 的硬拒绝;是否执行仍应由你的工具执行器做白名单校验。
|
||||
|
||||
### `GET /v1/responses/{response_id}`
|
||||
|
||||
@@ -487,6 +500,8 @@ data: {"type":"message_stop"}
|
||||
}
|
||||
```
|
||||
|
||||
返回项还会包含 `test_status`,当前值通常为 `ok` 或 `failed`。
|
||||
|
||||
---
|
||||
|
||||
## Gemini 兼容接口
|
||||
@@ -580,6 +595,7 @@ data: {"type":"message_stop"}
|
||||
```json
|
||||
{
|
||||
"keys": ["k1", "k2"],
|
||||
"env_backed": false,
|
||||
"accounts": [
|
||||
{
|
||||
"identifier": "user@example.com",
|
||||
@@ -599,7 +615,7 @@ data: {"type":"message_stop"}
|
||||
|
||||
### `POST /admin/config`
|
||||
|
||||
可更新 `keys`、`accounts`、`claude_mapping`。
|
||||
只更新 `keys`、`accounts`、`claude_mapping`。
|
||||
|
||||
**请求**:
|
||||
|
||||
@@ -620,23 +636,27 @@ data: {"type":"message_stop"}
|
||||
|
||||
读取运行时设置与状态,返回:
|
||||
|
||||
- `admin`(JWT 过期、默认密码告警等)
|
||||
- `runtime`(`account_max_inflight`、`account_max_queue`、`global_max_inflight`)
|
||||
- `toolcall` / `responses` / `embeddings`
|
||||
- `success`
|
||||
- `admin`(`has_password_hash`、`jwt_expire_hours`、`jwt_valid_after_unix`、`default_password_warning`)
|
||||
- `runtime`(`account_max_inflight`、`account_max_queue`、`global_max_inflight`、`token_refresh_interval_hours`)
|
||||
- `responses` / `embeddings`
|
||||
- `auto_delete`(`sessions`)
|
||||
- `claude_mapping` / `model_aliases`
|
||||
- `env_backed`、`needs_vercel_sync`
|
||||
- `toolcall` 策略已固定为 `feature_match + high`,不再通过 settings 返回或修改
|
||||
|
||||
### `PUT /admin/settings`
|
||||
|
||||
热更新运行时设置。支持更新:
|
||||
|
||||
- `admin.jwt_expire_hours`
|
||||
- `runtime.account_max_inflight` / `runtime.account_max_queue` / `runtime.global_max_inflight`
|
||||
- `toolcall.mode` / `toolcall.early_emit_confidence`
|
||||
- `runtime.account_max_inflight` / `runtime.account_max_queue` / `runtime.global_max_inflight` / `runtime.token_refresh_interval_hours`
|
||||
- `responses.store_ttl_seconds`
|
||||
- `embeddings.provider`
|
||||
- `auto_delete.sessions`
|
||||
- `claude_mapping`
|
||||
- `model_aliases`
|
||||
- `toolcall` 策略已固定,不再作为可写入字段
|
||||
|
||||
### `POST /admin/settings/password`
|
||||
|
||||
@@ -648,6 +668,8 @@ data: {"type":"message_stop"}
|
||||
{"new_password":"your-new-password"}
|
||||
```
|
||||
|
||||
也兼容 `{"password":"your-new-password"}`。
|
||||
|
||||
### `POST /admin/config/import`
|
||||
|
||||
导入完整配置,支持:
|
||||
@@ -656,6 +678,8 @@ data: {"type":"message_stop"}
|
||||
- `mode=replace`
|
||||
|
||||
请求可直接传配置对象,或使用 `{"config": {...}, "mode":"merge"}` 包裹格式。
|
||||
也支持在查询参数里传 `?mode=merge` / `?mode=replace`。
|
||||
导入时会接受 `keys`、`accounts`、`claude_mapping` / `claude_model_mapping`、`model_aliases`、`admin`、`runtime`、`responses`、`embeddings`、`auto_delete` 等字段;`toolcall` 相关字段会被忽略。
|
||||
|
||||
### `GET /admin/config/export`
|
||||
|
||||
@@ -681,6 +705,7 @@ data: {"type":"message_stop"}
|
||||
| --- | --- | --- |
|
||||
| `page` | `1` | ≥ 1 |
|
||||
| `page_size` | `10` | 1–100 |
|
||||
| `q` | 空 | 按 identifier / email / mobile 过滤 |
|
||||
|
||||
**响应**:
|
||||
|
||||
@@ -693,7 +718,8 @@ data: {"type":"message_stop"}
|
||||
"mobile": "",
|
||||
"has_password": true,
|
||||
"has_token": true,
|
||||
"token_preview": "abc..."
|
||||
"token_preview": "abc...",
|
||||
"test_status": "ok"
|
||||
}
|
||||
],
|
||||
"total": 25,
|
||||
@@ -755,10 +781,14 @@ data: {"type":"message_stop"}
|
||||
"success": true,
|
||||
"response_time": 1240,
|
||||
"message": "API 测试成功(仅会话创建)",
|
||||
"model": "deepseek-chat"
|
||||
"model": "deepseek-chat",
|
||||
"session_count": 0,
|
||||
"config_writable": true
|
||||
}
|
||||
```
|
||||
|
||||
如果传入 `message`,还会附带 `thinking`(当上游返回思考内容时)。
|
||||
|
||||
### `POST /admin/accounts/test-all`
|
||||
|
||||
可选请求字段:`model`
|
||||
@@ -772,6 +802,24 @@ data: {"type":"message_stop"}
|
||||
}
|
||||
```
|
||||
|
||||
内部并发上限当前固定为 5。
|
||||
|
||||
### `POST /admin/accounts/sessions/delete-all`
|
||||
|
||||
清空指定账号的所有 DeepSeek 会话。请求体示例:
|
||||
|
||||
```json
|
||||
{"identifier":"user@example.com"}
|
||||
```
|
||||
|
||||
响应:
|
||||
|
||||
```json
|
||||
{"success": true, "message": "删除成功"}
|
||||
```
|
||||
|
||||
如果账号不存在或删除失败,`success` 会是 `false`,`message` 会返回错误原因。
|
||||
|
||||
### `POST /admin/import`
|
||||
|
||||
批量导入 keys 与 accounts。
|
||||
@@ -849,16 +897,25 @@ data: {"type":"message_stop"}
|
||||
}
|
||||
```
|
||||
|
||||
失败校验的账号会通过 `failed_accounts` 返回;成功保存到 Vercel 的凭据会通过 `saved_credentials` 返回。
|
||||
|
||||
### `GET /admin/vercel/status`
|
||||
|
||||
```json
|
||||
{
|
||||
"synced": true,
|
||||
"last_sync_time": 1738400000,
|
||||
"has_synced_before": true
|
||||
"has_synced_before": true,
|
||||
"env_backed": false,
|
||||
"config_hash": "....",
|
||||
"last_synced_hash": "....",
|
||||
"draft_hash": "....",
|
||||
"draft_differs": false
|
||||
}
|
||||
```
|
||||
|
||||
`POST /admin/vercel/status` 还可以携带 `config_override`,用于对比“草稿配置”和当前已同步配置。
|
||||
|
||||
### `GET /admin/export`
|
||||
|
||||
```json
|
||||
@@ -868,6 +925,29 @@ data: {"type":"message_stop"}
|
||||
}
|
||||
```
|
||||
|
||||
该接口与 `GET /admin/config/export` 返回相同内容,只是路径更短。
|
||||
|
||||
### `GET /admin/version`
|
||||
|
||||
查询当前构建版本与 GitHub 最新 Release:
|
||||
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"current_version": "2.3.5",
|
||||
"current_tag": "v2.3.5",
|
||||
"source": "file:VERSION",
|
||||
"checked_at": "2026-03-29T00:00:00Z",
|
||||
"latest_tag": "v2.3.6",
|
||||
"latest_version": "2.3.6",
|
||||
"release_url": "https://github.com/CJackHwang/ds2api/releases/tag/v2.3.6",
|
||||
"published_at": "2026-03-28T12:00:00Z",
|
||||
"has_update": true
|
||||
}
|
||||
```
|
||||
|
||||
如果 GitHub API 不可用,响应里会额外包含 `check_error`,但 HTTP 状态仍为 200。
|
||||
|
||||
### `GET /admin/dev/captures`
|
||||
|
||||
查看本地抓包状态与最近记录(需 Admin 鉴权):
|
||||
|
||||
11
Dockerfile
11
Dockerfile
@@ -10,26 +10,31 @@ FROM golang:1.24 AS go-builder
|
||||
WORKDIR /app
|
||||
ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
ARG BUILD_VERSION
|
||||
COPY go.mod go.sum* ./
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
RUN set -eux; \
|
||||
GOOS="${TARGETOS:-$(go env GOOS)}"; \
|
||||
GOARCH="${TARGETARCH:-$(go env GOARCH)}"; \
|
||||
CGO_ENABLED=0 GOOS="${GOOS}" GOARCH="${GOARCH}" go build -o /out/ds2api ./cmd/ds2api
|
||||
BUILD_VERSION_RESOLVED="${BUILD_VERSION:-}"; \
|
||||
if [ -z "${BUILD_VERSION_RESOLVED}" ] && [ -f VERSION ]; then BUILD_VERSION_RESOLVED="$(cat VERSION | tr -d "[:space:]")"; fi; \
|
||||
CGO_ENABLED=0 GOOS="${GOOS}" GOARCH="${GOARCH}" go build -ldflags="-s -w -X ds2api/internal/version.BuildVersion=${BUILD_VERSION_RESOLVED}" -o /out/ds2api ./cmd/ds2api
|
||||
|
||||
FROM busybox:1.36.1-musl AS busybox-tools
|
||||
|
||||
FROM debian:bookworm-slim AS runtime-base
|
||||
WORKDIR /app
|
||||
COPY --from=go-builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
COPY --from=busybox-tools /bin/busybox /usr/local/bin/busybox
|
||||
EXPOSE 5001
|
||||
CMD ["/usr/local/bin/ds2api"]
|
||||
|
||||
FROM runtime-base AS runtime-from-source
|
||||
COPY --from=go-builder /out/ds2api /usr/local/bin/ds2api
|
||||
COPY --from=go-builder /app/sha3_wasm_bg.7b9ca65ddd.wasm /app/sha3_wasm_bg.7b9ca65ddd.wasm
|
||||
COPY --from=go-builder /app/internal/deepseek/assets/sha3_wasm_bg.7b9ca65ddd.wasm /app/sha3_wasm_bg.7b9ca65ddd.wasm
|
||||
COPY --from=go-builder /app/config.example.json /app/config.example.json
|
||||
COPY --from=webui-builder /app/static/admin /app/static/admin
|
||||
|
||||
|
||||
123
README.MD
123
README.MD
@@ -8,7 +8,7 @@
|
||||

|
||||

|
||||
[](https://github.com/CJackHwang/ds2api/releases)
|
||||
[](DEPLOY.md)
|
||||
[](docs/DEPLOY.md)
|
||||
[](https://zeabur.com/templates/L4CFHP)
|
||||
[](https://vercel.com/new/clone?repository-url=https://github.com/CJackHwang/ds2api)
|
||||
|
||||
@@ -16,6 +16,14 @@
|
||||
|
||||
将 DeepSeek Web 对话能力转换为 OpenAI、Claude 与 Gemini 兼容 API。后端为 **Go 全量实现**,前端为 React WebUI 管理台(源码在 `webui/`,部署时自动构建到 `static/admin`)。
|
||||
|
||||
> **重要免责声明**
|
||||
>
|
||||
> 本仓库仅供学习、研究、个人实验和内部验证使用,不提供任何形式的商业授权、适用性保证或结果保证。
|
||||
>
|
||||
> 作者及仓库维护者不对因使用、修改、分发、部署或依赖本项目而产生的任何直接或间接损失、账号封禁、数据丢失、法律风险或第三方索赔负责。
|
||||
>
|
||||
> 请勿将本项目用于违反服务条款、协议、法律法规或平台规则的场景。商业使用前请自行确认 `LICENSE`、相关协议以及你是否获得了作者的书面许可。
|
||||
|
||||
## 架构概览
|
||||
|
||||
```mermaid
|
||||
@@ -68,7 +76,7 @@ flowchart LR
|
||||
| 并发队列控制 | 每账号 in-flight 上限 + 等待队列,动态计算建议并发值 |
|
||||
| DeepSeek PoW | WASM 计算(`wazero`),无需外部 Node.js 依赖 |
|
||||
| Tool Calling | 防泄漏处理:非代码块高置信特征识别、`delta.tool_calls` 早发、结构化增量输出 |
|
||||
| Admin API | 配置管理、运行时设置热更新、账号测试 / 批量测试、导入导出、Vercel 同步 |
|
||||
| Admin API | 配置管理、运行时设置热更新、账号测试 / 批量测试、会话清理、导入导出、Vercel 同步、版本检查 |
|
||||
| WebUI 管理台 | `/admin` 单页应用(中英文双语、深色模式) |
|
||||
| 运维探针 | `GET /healthz`(存活)、`GET /readyz`(就绪) |
|
||||
|
||||
@@ -106,6 +114,14 @@ flowchart LR
|
||||
可通过配置中的 `claude_mapping` 或 `claude_model_mapping` 覆盖映射关系。
|
||||
另外,`/anthropic/v1/models` 现已包含 Claude 1.x/2.x/3.x/4.x 历史模型 ID 与常见别名,便于旧客户端直接兼容。
|
||||
|
||||
|
||||
#### Claude Code 接入避坑(实测)
|
||||
|
||||
- `ANTHROPIC_BASE_URL` 推荐直接指向 DS2API 根地址(例如 `http://127.0.0.1:5001`),Claude Code 会请求 `/v1/messages?beta=true`。
|
||||
- `ANTHROPIC_API_KEY` 需要与 `config.json` 中 `keys` 一致;建议同时保留常规 key 与 `sk-ant-*` 形态 key,兼容不同客户端校验习惯。
|
||||
- 若系统设置了代理,建议对 DS2API 地址配置 `NO_PROXY=127.0.0.1,localhost,<你的主机IP>`,避免本地回环请求被代理拦截。
|
||||
- 如遇“工具调用输出成文本、未执行”问题,请升级到包含 Claude 工具调用多格式解析(JSON/XML/ANTML/invoke)的版本。
|
||||
|
||||
### Gemini 接口
|
||||
|
||||
Gemini 适配器将模型名通过 `model_aliases` 或内置规则映射到 DeepSeek 原生模型,支持 `generateContent` 和 `streamGenerateContent` 两种调用方式,并完整支持 Tool Calling(`functionDeclarations` → `functionCall` 输出)。
|
||||
@@ -124,6 +140,7 @@ cp config.example.json config.json
|
||||
后续部署建议:
|
||||
- 本地运行:直接读取 `config.json`
|
||||
- Docker / Vercel:由 `config.json` 生成 `DS2API_CONFIG_JSON`(Base64)注入环境变量
|
||||
- 兼容写法:`DS2API_CONFIG_JSON` 也可以直接写原始 JSON;`CONFIG_JSON` 是旧版回退变量
|
||||
|
||||
### 方式一:本地运行
|
||||
|
||||
@@ -144,7 +161,7 @@ go run ./cmd/ds2api
|
||||
|
||||
默认监听地址:`http://localhost:5001`
|
||||
|
||||
> **WebUI 自动构建**:本地首次启动时,若 `static/admin` 不存在,会自动尝试执行 `npm install && npm run build`(需要本机有 Node.js)。你也可以手动构建:`./scripts/build-webui.sh`
|
||||
> **WebUI 自动构建**:本地首次启动时,若 `static/admin` 不存在,会自动尝试执行 `npm ci`(仅在缺少依赖时)和 `npm run build -- --outDir static/admin --emptyOutDir`(需要本机有 Node.js)。你也可以手动构建:`./scripts/build-webui.sh`
|
||||
|
||||
### 方式二:Docker 运行
|
||||
|
||||
@@ -152,20 +169,18 @@ go run ./cmd/ds2api
|
||||
# 1. 准备环境变量文件
|
||||
cp .env.example .env
|
||||
|
||||
# 2. 从 config.json 生成 DS2API_CONFIG_JSON(单行 Base64)
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
|
||||
# 3. 编辑 .env,设置:
|
||||
# 2. 编辑 .env(至少设置 DS2API_ADMIN_KEY)
|
||||
# DS2API_ADMIN_KEY=请替换为强密码
|
||||
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
|
||||
|
||||
# 4. 启动
|
||||
# 3. 启动
|
||||
docker-compose up -d
|
||||
|
||||
# 5. 查看日志
|
||||
# 4. 查看日志
|
||||
docker-compose logs -f
|
||||
```
|
||||
|
||||
默认 `docker-compose.yml` 会把宿主机 `6011` 映射到容器内的 `5001`。如果你希望直接对外暴露 `5001`,请调整 `ports` 配置。
|
||||
|
||||
更新镜像:`docker-compose up -d --build`
|
||||
|
||||
#### Zeabur 一键部署(Dockerfile)
|
||||
@@ -174,6 +189,8 @@ docker-compose logs -f
|
||||
2. 部署完成后访问 `/admin`,使用 Zeabur 环境变量/模板指引中的 `DS2API_ADMIN_KEY` 登录。
|
||||
3. 在管理台导入/编辑配置(会写入并持久化到 `/data/config.json`)。
|
||||
|
||||
说明:Zeabur 使用仓库内 `Dockerfile` 直接构建时,不需要额外传入 `BUILD_VERSION`;镜像会优先读取该构建参数,未提供时自动回退到仓库根目录的 `VERSION` 文件。
|
||||
|
||||
### 方式三:Vercel 部署
|
||||
|
||||
1. Fork 仓库到自己的 GitHub
|
||||
@@ -196,7 +213,7 @@ base64 < config.json | tr -d '\n'
|
||||
|
||||
> **流式说明**:`/v1/chat/completions` 在 Vercel 上默认走 `api/chat-stream.js`(Node Runtime)以保证实时 SSE。鉴权、账号选择、会话/PoW 准备仍由 Go 内部 prepare 接口完成;流式响应(含 `tools`)在 Node 侧执行与 Go 对齐的输出组装与防泄漏处理。
|
||||
|
||||
详细部署说明请参阅 [部署指南](DEPLOY.md)。
|
||||
详细部署说明请参阅 [部署指南](docs/DEPLOY.md)。
|
||||
|
||||
### 方式四:下载 Release 构建包
|
||||
|
||||
@@ -238,13 +255,11 @@ cp opencode.json.example opencode.json
|
||||
"accounts": [
|
||||
{
|
||||
"email": "user@example.com",
|
||||
"password": "your-password",
|
||||
"token": ""
|
||||
"password": "your-password"
|
||||
},
|
||||
{
|
||||
"mobile": "12345678901",
|
||||
"password": "your-password",
|
||||
"token": ""
|
||||
"password": "your-password"
|
||||
}
|
||||
],
|
||||
"model_aliases": {
|
||||
@@ -255,17 +270,13 @@ cp opencode.json.example opencode.json
|
||||
"compat": {
|
||||
"wide_input_strict_output": true
|
||||
},
|
||||
"toolcall": {
|
||||
"mode": "feature_match",
|
||||
"early_emit_confidence": "high"
|
||||
},
|
||||
"responses": {
|
||||
"store_ttl_seconds": 900
|
||||
},
|
||||
"embeddings": {
|
||||
"provider": "deterministic"
|
||||
},
|
||||
"claude_model_mapping": {
|
||||
"claude_mapping": {
|
||||
"fast": "deepseek-chat",
|
||||
"slow": "deepseek-reasoner"
|
||||
},
|
||||
@@ -275,22 +286,27 @@ cp opencode.json.example opencode.json
|
||||
"runtime": {
|
||||
"account_max_inflight": 2,
|
||||
"account_max_queue": 0,
|
||||
"global_max_inflight": 0
|
||||
"global_max_inflight": 0,
|
||||
"token_refresh_interval_hours": 6
|
||||
},
|
||||
"auto_delete": {
|
||||
"sessions": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
- `keys`:API 访问密钥列表,客户端通过 `Authorization: Bearer <key>` 鉴权
|
||||
- `accounts`:DeepSeek 账号列表,支持 `email` 或 `mobile` 登录
|
||||
- `token`:留空则首次请求时自动登录获取;也可预填已有 token
|
||||
- `token`:配置文件中即使填写也会在加载时被清空(不会从 `config.json` 读取 token);实际 token 仅在运行时内存中维护并自动刷新
|
||||
- `model_aliases`:常见模型名(如 GPT/Codex/Claude)到 DeepSeek 模型的映射
|
||||
- `compat.wide_input_strict_output`:建议保持 `true`(当前实现默认宽进严出)
|
||||
- `toolcall`:固定采用特征匹配 + 高置信早发策略
|
||||
- `toolcall`:策略已固定为特征匹配 + 高置信早发,不再作为可配置项
|
||||
- `responses.store_ttl_seconds`:`/v1/responses/{id}` 的内存缓存 TTL
|
||||
- `embeddings.provider`:embedding 提供方(当前内置 `deterministic/mock/builtin`)
|
||||
- `claude_model_mapping`:字典中 `fast`/`slow` 后缀映射到对应 DeepSeek 模型
|
||||
- `claude_mapping`:字典中 `fast`/`slow` 后缀映射到对应 DeepSeek 模型(兼容读取 `claude_model_mapping`)
|
||||
- `admin`:管理后台设置(JWT 过期时间、密码哈希等),可通过 Admin Settings API 热更新
|
||||
- `runtime`:运行时参数(并发限制、队列大小),可通过 Admin Settings API 热更新
|
||||
- `runtime`:运行时参数(并发限制、队列大小、托管账号 token 刷新间隔),可通过 Admin Settings API 热更新;`account_max_queue=0`/`global_max_inflight=0` 表示按推荐值自动计算,`token_refresh_interval_hours=6` 为默认强制重登间隔
|
||||
- `auto_delete.sessions`:是否在请求结束后自动清理 DeepSeek 会话(默认 `false`,可在 Settings 热更新)
|
||||
|
||||
### 环境变量
|
||||
|
||||
@@ -303,9 +319,13 @@ cp opencode.json.example opencode.json
|
||||
| `DS2API_JWT_EXPIRE_HOURS` | Admin JWT 过期小时数 | `24` |
|
||||
| `DS2API_CONFIG_PATH` | 配置文件路径 | `config.json` |
|
||||
| `DS2API_CONFIG_JSON` | 直接注入配置(JSON 或 Base64) | — |
|
||||
| `CONFIG_JSON` | 旧版兼容配置注入 | — |
|
||||
| `DS2API_WASM_PATH` | PoW WASM 文件路径 | 自动查找 |
|
||||
| `DS2API_STATIC_ADMIN_DIR` | 管理台静态文件目录 | `static/admin` |
|
||||
| `DS2API_AUTO_BUILD_WEBUI` | 启动时自动构建 WebUI | 本地开启,Vercel 关闭 |
|
||||
| `DS2API_DEV_PACKET_CAPTURE` | 本地开发抓包开关(记录最近会话请求/响应体) | 本地非 Vercel 默认开启 |
|
||||
| `DS2API_DEV_PACKET_CAPTURE_LIMIT` | 本地抓包保留条数(超出自动淘汰) | `5` |
|
||||
| `DS2API_DEV_PACKET_CAPTURE_MAX_BODY_BYTES` | 单条响应体最大记录字节数 | `2097152` |
|
||||
| `DS2API_ACCOUNT_MAX_INFLIGHT` | 每账号最大并发 in-flight 请求数 | `2` |
|
||||
| `DS2API_ACCOUNT_CONCURRENCY` | 同上(兼容旧名) | — |
|
||||
| `DS2API_ACCOUNT_MAX_QUEUE` | 等待队列上限 | `recommended_concurrency` |
|
||||
@@ -332,6 +352,7 @@ cp opencode.json.example opencode.json
|
||||
| **直通 token 模式** | 传入 token 不在 `config.keys` 中时,直接作为 DeepSeek token 使用 |
|
||||
|
||||
可选请求头 `X-Ds2-Target-Account`:指定使用某个托管账号(值为 email 或 mobile)。
|
||||
Gemini 路由还可以使用 `x-goog-api-key`,或在没有认证头时使用 `?key=` / `?api_key=` 作为调用方凭据。
|
||||
|
||||
## 并发模型
|
||||
|
||||
@@ -348,13 +369,17 @@ cp opencode.json.example opencode.json
|
||||
|
||||
## Tool Call 适配
|
||||
|
||||
当请求中带 `tools` 时,DS2API 会做防泄漏处理:
|
||||
当请求中带 `tools` 时,DS2API 会做防泄漏处理与结构化转译:
|
||||
|
||||
1. 只在**非代码块上下文**启用 toolcall 特征识别(代码块示例不会触发)
|
||||
2. `responses` 流式严格使用官方 item 生命周期事件(`response.output_item.*`、`response.content_part.*`、`response.function_call_arguments.*`)
|
||||
3. 未在 `tools` 声明中的工具名会被严格拒绝,不会下发为有效 tool call
|
||||
1. 只在**非代码块上下文**启用执行型 toolcall 识别(代码块示例默认不触发)
|
||||
2. 解析层以 XML/Markup 为最高优先级,同时兼容 JSON / ANTML / invoke / text-kv,并统一归一到内部工具调用结构
|
||||
3. `responses` 流式严格使用官方 item 生命周期事件(`response.output_item.*`、`response.content_part.*`、`response.function_call_arguments.*`)
|
||||
4. `responses` 支持并执行 `tool_choice`(`auto`/`none`/`required`/强制函数);`required` 违规时非流式返回 `422`,流式返回 `response.failed`
|
||||
5. 仅在通过策略校验后才会发出有效工具调用事件,避免错误工具名进入客户端执行链
|
||||
5. 客户端请求哪种协议,就按该协议返回工具调用(OpenAI/Claude/Gemini 各自原生结构);模型侧优先约束输出规范 XML,再由兼容层转译
|
||||
|
||||
> 说明:当前版本在 parser 层仍以“尽量解析成功”为优先,未启用基于 allow-list 的工具名硬拒绝。
|
||||
>
|
||||
> 想评估“把工具调用封装成 XML 再输入模型”的方案,可参考:`docs/toolcall-semantics.md`。
|
||||
|
||||
## 本地开发抓包工具
|
||||
|
||||
@@ -389,7 +414,7 @@ ds2api/
|
||||
├── api/
|
||||
│ ├── index.go # Vercel Serverless Go 入口
|
||||
│ ├── chat-stream.js # Vercel Node.js 流式转发
|
||||
│ └── helpers/ # Node.js 辅助模块
|
||||
│ └── (rewrite targets in vercel.json)
|
||||
├── internal/
|
||||
│ ├── account/ # 账号池与并发队列
|
||||
│ ├── adapter/
|
||||
@@ -402,6 +427,7 @@ ds2api/
|
||||
│ ├── compat/ # 兼容性辅助
|
||||
│ ├── config/ # 配置加载与热更新
|
||||
│ ├── deepseek/ # DeepSeek API 客户端、PoW WASM
|
||||
│ ├── js/ # Node 运行时流式处理与兼容逻辑
|
||||
│ ├── devcapture/ # 开发抓包模块
|
||||
│ ├── format/ # 输出格式化
|
||||
│ ├── prompt/ # Prompt 构建
|
||||
@@ -412,13 +438,16 @@ ds2api/
|
||||
│ └── webui/ # WebUI 静态文件托管与自动构建
|
||||
├── webui/ # React WebUI 源码(Vite + Tailwind)
|
||||
│ └── src/
|
||||
│ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage
|
||||
│ ├── app/ # 路由、鉴权、配置状态管理
|
||||
│ ├── features/ # 业务功能模块(account/settings/vercel/apiTester)
|
||||
│ ├── components/ # 登录/落地页等通用组件
|
||||
│ └── locales/ # 中英文语言包(zh.json / en.json)
|
||||
├── scripts/
|
||||
│ └── build-webui.sh # WebUI 手动构建脚本
|
||||
├── tests/
|
||||
│ ├── compat/ # 兼容性测试夹具与期望输出
|
||||
│ └── scripts/ # 统一测试脚本入口(unit/e2e)
|
||||
├── docs/ # 部署 / 贡献 / 测试等辅助文档
|
||||
├── static/admin/ # WebUI 构建产物(不提交到 Git)
|
||||
├── .github/
|
||||
│ ├── workflows/ # GitHub Actions(质量门禁 + Release 自动构建)
|
||||
@@ -438,9 +467,9 @@ ds2api/
|
||||
| 文档 | 说明 |
|
||||
| --- | --- |
|
||||
| [API.md](API.md) / [API.en.md](API.en.md) | API 接口文档(含请求/响应示例) |
|
||||
| [DEPLOY.md](DEPLOY.md) / [DEPLOY.en.md](DEPLOY.en.md) | 部署指南(本地/Docker/Vercel/systemd) |
|
||||
| [CONTRIBUTING.md](CONTRIBUTING.md) / [CONTRIBUTING.en.md](CONTRIBUTING.en.md) | 贡献指南 |
|
||||
| [TESTING.md](TESTING.md) | 测试集使用指南 |
|
||||
| [DEPLOY.md](docs/DEPLOY.md) / [DEPLOY.en.md](docs/DEPLOY.en.md) | 部署指南(本地/Docker/Vercel/systemd) |
|
||||
| [CONTRIBUTING.md](docs/CONTRIBUTING.md) / [CONTRIBUTING.en.md](docs/CONTRIBUTING.en.md) | 贡献指南 |
|
||||
| [TESTING.md](docs/TESTING.md) | 测试集使用指南 |
|
||||
|
||||
## 测试
|
||||
|
||||
@@ -468,6 +497,23 @@ go run ./cmd/ds2api-tests \
|
||||
npm ci --prefix webui && npm run build --prefix webui
|
||||
```
|
||||
|
||||
## 测试
|
||||
|
||||
详细测试指南请参阅 [docs/TESTING.md](docs/TESTING.md)。
|
||||
|
||||
### 快速测试命令
|
||||
|
||||
```bash
|
||||
# 运行所有单元测试
|
||||
go test ./...
|
||||
|
||||
# 运行 tool calls 相关测试(调试工具调用问题)
|
||||
go test -v -run 'TestParseToolCalls|TestRepair' ./internal/util/
|
||||
|
||||
# 运行端到端测试
|
||||
./tests/scripts/run-live.sh
|
||||
```
|
||||
|
||||
## Release 自动构建(GitHub Actions)
|
||||
|
||||
工作流文件:`.github/workflows/release-artifacts.yml`
|
||||
@@ -475,8 +521,11 @@ npm ci --prefix webui && npm run build --prefix webui
|
||||
- **触发条件**:仅在 GitHub Release `published` 时触发(普通 push 不会触发)
|
||||
- **构建产物**:多平台二进制包(`linux/amd64`、`linux/arm64`、`darwin/amd64`、`darwin/arm64`、`windows/amd64`)+ `sha256sums.txt`
|
||||
- **容器镜像发布**:仅推送到 GHCR(`ghcr.io/cjackhwang/ds2api`)
|
||||
- **每个压缩包包含**:`ds2api` 可执行文件、`static/admin`、WASM 文件、配置示例、README、LICENSE
|
||||
- **每个压缩包包含**:`ds2api` 可执行文件、`static/admin`、WASM 文件(同时支持内置 fallback)、配置示例、README、LICENSE
|
||||
|
||||
## 免责声明
|
||||
|
||||
本项目基于逆向方式实现,仅供学习与研究使用。稳定性和可用性不作保证,请勿用于违反服务条款或法律法规的场景。
|
||||
本项目基于逆向方式实现,仅供学习、研究、个人实验和内部验证使用,不提供任何商业授权、稳定性保证或可用性保证。
|
||||
作者及仓库维护者不对因使用、修改、分发、部署或依赖本项目而产生的任何直接或间接损失、账号封禁、数据丢失、法律风险或第三方索赔负责。
|
||||
|
||||
请勿将本项目用于违反服务条款、协议、法律法规或平台规则的场景。商业使用前请自行确认 `LICENSE`、相关协议以及你是否获得了作者的书面许可。
|
||||
|
||||
90
README.en.md
90
README.en.md
@@ -8,7 +8,7 @@
|
||||

|
||||

|
||||
[](https://github.com/CJackHwang/ds2api/releases)
|
||||
[](DEPLOY.en.md)
|
||||
[](docs/DEPLOY.en.md)
|
||||
[](https://zeabur.com/templates/L4CFHP)
|
||||
[](https://vercel.com/new/clone?repository-url=https://github.com/CJackHwang/ds2api)
|
||||
|
||||
@@ -16,6 +16,14 @@ Language: [中文](README.MD) | [English](README.en.md)
|
||||
|
||||
DS2API converts DeepSeek Web chat capability into OpenAI-compatible, Claude-compatible, and Gemini-compatible APIs. The backend is a **pure Go implementation**, with a React WebUI admin panel (source in `webui/`, build output auto-generated to `static/admin` during deployment).
|
||||
|
||||
> **Important Disclaimer**
|
||||
>
|
||||
> This repository is provided for learning, research, personal experimentation, and internal validation only. It does not grant any commercial authorization and comes with no warranty of fitness, stability, or results.
|
||||
>
|
||||
> The author and repository maintainers are not responsible for any direct or indirect loss, account suspension, data loss, legal risk, or third-party claims arising from use, modification, distribution, deployment, or reliance on this project.
|
||||
>
|
||||
> Do not use this project in ways that violate service terms, agreements, laws, or platform rules. Before any commercial use, review the `LICENSE`, the relevant terms, and confirm that you have the author's written permission.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```mermaid
|
||||
@@ -68,7 +76,7 @@ flowchart LR
|
||||
| Concurrency control | Per-account in-flight limit + waiting queue, dynamic recommended concurrency |
|
||||
| DeepSeek PoW | WASM solving via `wazero`, no external Node.js dependency |
|
||||
| Tool Calling | Anti-leak handling: non-code-block feature match, early `delta.tool_calls`, structured incremental output |
|
||||
| Admin API | Config management, runtime settings hot-reload, account testing/batch test, import/export, Vercel sync |
|
||||
| Admin API | Config management, runtime settings hot-reload, account testing/batch test, session cleanup, import/export, Vercel sync, version check |
|
||||
| WebUI Admin Panel | SPA at `/admin` (bilingual Chinese/English, dark mode) |
|
||||
| Health Probes | `GET /healthz` (liveness), `GET /readyz` (readiness) |
|
||||
|
||||
@@ -106,6 +114,14 @@ flowchart LR
|
||||
Override mapping via `claude_mapping` or `claude_model_mapping` in config.
|
||||
In addition, `/anthropic/v1/models` now includes historical Claude 1.x/2.x/3.x/4.x IDs and common aliases for legacy client compatibility.
|
||||
|
||||
|
||||
#### Claude Code integration pitfalls (validated)
|
||||
|
||||
- Set `ANTHROPIC_BASE_URL` to the DS2API root URL (for example `http://127.0.0.1:5001`). Claude Code sends requests to `/v1/messages?beta=true`.
|
||||
- `ANTHROPIC_API_KEY` must match an entry in `keys` from `config.json`. Keeping both a regular key and an `sk-ant-*` style key improves client compatibility.
|
||||
- If your environment has proxy variables, set `NO_PROXY=127.0.0.1,localhost,<your_host_ip>` for DS2API to avoid proxy interception of local traffic.
|
||||
- If tool calls are rendered as plain text and not executed, upgrade to a build that includes multi-format Claude tool-call parsing (JSON/XML/ANTML/invoke).
|
||||
|
||||
### Gemini Endpoint
|
||||
|
||||
The Gemini adapter maps model names to DeepSeek native models via `model_aliases` or built-in heuristics, supporting both `generateContent` and `streamGenerateContent` call patterns with full Tool Calling support (`functionDeclarations` → `functionCall` output).
|
||||
@@ -124,6 +140,7 @@ cp config.example.json config.json
|
||||
Recommended per deployment mode:
|
||||
- Local run: read `config.json` directly
|
||||
- Docker / Vercel: generate Base64 from `config.json` and inject as `DS2API_CONFIG_JSON`
|
||||
- Compatibility note: `DS2API_CONFIG_JSON` may also contain raw JSON directly; `CONFIG_JSON` is the legacy fallback variable
|
||||
|
||||
### Option 1: Local Run
|
||||
|
||||
@@ -144,7 +161,7 @@ go run ./cmd/ds2api
|
||||
|
||||
Default URL: `http://localhost:5001`
|
||||
|
||||
> **WebUI auto-build**: On first local startup, if `static/admin` is missing, DS2API will auto-run `npm install && npm run build` (requires Node.js). You can also build manually: `./scripts/build-webui.sh`
|
||||
> **WebUI auto-build**: On first local startup, if `static/admin` is missing, DS2API will auto-run `npm ci` (only when dependencies are missing) and `npm run build -- --outDir static/admin --emptyOutDir` (requires Node.js). You can also build manually: `./scripts/build-webui.sh`
|
||||
|
||||
### Option 2: Docker
|
||||
|
||||
@@ -152,20 +169,18 @@ Default URL: `http://localhost:5001`
|
||||
# 1. Prepare env file
|
||||
cp .env.example .env
|
||||
|
||||
# 2. Generate DS2API_CONFIG_JSON from config.json (single-line Base64)
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
|
||||
# 3. Edit .env and set:
|
||||
# 2. Edit .env (at least set DS2API_ADMIN_KEY)
|
||||
# DS2API_ADMIN_KEY=replace-with-a-strong-secret
|
||||
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
|
||||
|
||||
# 4. Start
|
||||
# 3. Start
|
||||
docker-compose up -d
|
||||
|
||||
# 5. View logs
|
||||
# 4. View logs
|
||||
docker-compose logs -f
|
||||
```
|
||||
|
||||
The default `docker-compose.yml` maps host port `6011` to container port `5001`. If you want `5001` exposed directly, adjust the `ports` mapping.
|
||||
|
||||
Rebuild after updates: `docker-compose up -d --build`
|
||||
|
||||
#### Zeabur One-Click (Dockerfile)
|
||||
@@ -174,6 +189,8 @@ Rebuild after updates: `docker-compose up -d --build`
|
||||
2. After deployment, open `/admin` and login with `DS2API_ADMIN_KEY` shown in Zeabur env/template instructions.
|
||||
3. Import / edit config in Admin UI (it will be written and persisted to `/data/config.json`).
|
||||
|
||||
Note: when Zeabur builds directly from the repo `Dockerfile`, you do not need to pass `BUILD_VERSION`. The image prefers that build arg when provided, and automatically falls back to the repo-root `VERSION` file when it is absent.
|
||||
|
||||
### Option 3: Vercel
|
||||
|
||||
1. Fork this repo to your GitHub account
|
||||
@@ -196,7 +213,7 @@ base64 < config.json | tr -d '\n'
|
||||
|
||||
> **Streaming note**: `/v1/chat/completions` on Vercel is routed to `api/chat-stream.js` (Node Runtime) for real-time SSE. Auth, account selection, and session/PoW preparation are still handled by the Go internal prepare endpoint; streaming output (including `tools`) is assembled on Node with Go-aligned anti-leak handling.
|
||||
|
||||
For detailed deployment instructions, see the [Deployment Guide](DEPLOY.en.md).
|
||||
For detailed deployment instructions, see the [Deployment Guide](docs/DEPLOY.en.md).
|
||||
|
||||
### Option 4: Download Release Binaries
|
||||
|
||||
@@ -238,13 +255,11 @@ cp opencode.json.example opencode.json
|
||||
"accounts": [
|
||||
{
|
||||
"email": "user@example.com",
|
||||
"password": "your-password",
|
||||
"token": ""
|
||||
"password": "your-password"
|
||||
},
|
||||
{
|
||||
"mobile": "12345678901",
|
||||
"password": "your-password",
|
||||
"token": ""
|
||||
"password": "your-password"
|
||||
}
|
||||
],
|
||||
"model_aliases": {
|
||||
@@ -255,17 +270,13 @@ cp opencode.json.example opencode.json
|
||||
"compat": {
|
||||
"wide_input_strict_output": true
|
||||
},
|
||||
"toolcall": {
|
||||
"mode": "feature_match",
|
||||
"early_emit_confidence": "high"
|
||||
},
|
||||
"responses": {
|
||||
"store_ttl_seconds": 900
|
||||
},
|
||||
"embeddings": {
|
||||
"provider": "deterministic"
|
||||
},
|
||||
"claude_model_mapping": {
|
||||
"claude_mapping": {
|
||||
"fast": "deepseek-chat",
|
||||
"slow": "deepseek-reasoner"
|
||||
},
|
||||
@@ -275,22 +286,27 @@ cp opencode.json.example opencode.json
|
||||
"runtime": {
|
||||
"account_max_inflight": 2,
|
||||
"account_max_queue": 0,
|
||||
"global_max_inflight": 0
|
||||
"global_max_inflight": 0,
|
||||
"token_refresh_interval_hours": 6
|
||||
},
|
||||
"auto_delete": {
|
||||
"sessions": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
- `keys`: API access keys; clients authenticate via `Authorization: Bearer <key>`
|
||||
- `accounts`: DeepSeek account list, supports `email` or `mobile` login
|
||||
- `token`: Leave empty for auto-login on first request; or pre-fill an existing token
|
||||
- `token`: Even if set in `config.json`, it is cleared during load (DS2API does not read persisted tokens from config); runtime tokens are maintained/refreshed in memory only
|
||||
- `model_aliases`: Map common model names (GPT/Codex/Claude) to DeepSeek models
|
||||
- `compat.wide_input_strict_output`: Keep `true` (current default policy)
|
||||
- `toolcall`: Fixed to feature matching + high-confidence early emit
|
||||
- `toolcall`: Fixed to feature matching + high-confidence early emit, no longer configurable
|
||||
- `responses.store_ttl_seconds`: In-memory TTL for `/v1/responses/{id}`
|
||||
- `embeddings.provider`: Embeddings provider (`deterministic/mock/builtin` built-in)
|
||||
- `claude_model_mapping`: Maps `fast`/`slow` suffixes to corresponding DeepSeek models
|
||||
- `claude_mapping`: Maps `fast`/`slow` suffixes to corresponding DeepSeek models (still compatible with `claude_model_mapping`)
|
||||
- `admin`: Admin panel settings (JWT expiry, password hash, etc.), hot-reloadable via Admin Settings API
|
||||
- `runtime`: Runtime parameters (concurrency limits, queue sizes), hot-reloadable via Admin Settings API
|
||||
- `runtime`: Runtime parameters (concurrency limits, queue sizes, managed token refresh interval), hot-reloadable via Admin Settings API; `account_max_queue=0`/`global_max_inflight=0` means auto-calculate from recommended values, `token_refresh_interval_hours=6` is the default forced re-login interval
|
||||
- `auto_delete.sessions`: Whether to auto-delete DeepSeek sessions after request completion (default `false`, hot-reloadable via Settings)
|
||||
|
||||
### Environment Variables
|
||||
|
||||
@@ -303,6 +319,7 @@ cp opencode.json.example opencode.json
|
||||
| `DS2API_JWT_EXPIRE_HOURS` | Admin JWT TTL in hours | `24` |
|
||||
| `DS2API_CONFIG_PATH` | Config file path | `config.json` |
|
||||
| `DS2API_CONFIG_JSON` | Inline config (JSON or Base64) | — |
|
||||
| `CONFIG_JSON` | Legacy compatibility config input | — |
|
||||
| `DS2API_WASM_PATH` | PoW WASM file path | Auto-detect |
|
||||
| `DS2API_STATIC_ADMIN_DIR` | Admin static assets dir | `static/admin` |
|
||||
| `DS2API_AUTO_BUILD_WEBUI` | Auto-build WebUI on startup | Enabled locally, disabled on Vercel |
|
||||
@@ -332,6 +349,7 @@ For business endpoints (`/v1/*`, `/anthropic/*`, Gemini routes), DS2API supports
|
||||
| **Direct token** | If the token is not in `config.keys`, DS2API treats it as a DeepSeek token directly |
|
||||
|
||||
Optional header `X-Ds2-Target-Account`: Pin a specific managed account (value is email or mobile).
|
||||
Gemini routes also accept `x-goog-api-key`, or `?key=` / `?api_key=` when no auth header is present.
|
||||
|
||||
## Concurrency Model
|
||||
|
||||
@@ -351,6 +369,7 @@ Queue limit = DS2API_ACCOUNT_MAX_QUEUE (default = recommended concurrency)
|
||||
When `tools` is present in the request, DS2API performs anti-leak handling:
|
||||
|
||||
1. Toolcall feature matching is enabled only in **non-code-block context** (fenced examples are ignored)
|
||||
- In non-code-block context, tool JSON may still be recognized even when mixed with normal prose; surrounding prose can remain as text output.
|
||||
2. `responses` streaming strictly uses official item lifecycle events (`response.output_item.*`, `response.content_part.*`, `response.function_call_arguments.*`)
|
||||
3. Tool names not declared in the `tools` schema are strictly rejected and will not be emitted as valid tool calls
|
||||
4. `responses` supports and enforces `tool_choice` (`auto`/`none`/`required`/forced function); `required` violations return `422` for non-stream and `response.failed` for stream
|
||||
@@ -389,7 +408,7 @@ ds2api/
|
||||
├── api/
|
||||
│ ├── index.go # Vercel Serverless Go entry
|
||||
│ ├── chat-stream.js # Vercel Node.js stream relay
|
||||
│ └── helpers/ # Node.js helper modules
|
||||
│ └── (rewrite targets in vercel.json)
|
||||
├── internal/
|
||||
│ ├── account/ # Account pool and concurrency queue
|
||||
│ ├── adapter/
|
||||
@@ -402,6 +421,7 @@ ds2api/
|
||||
│ ├── compat/ # Compatibility helpers
|
||||
│ ├── config/ # Config loading and hot-reload
|
||||
│ ├── deepseek/ # DeepSeek API client, PoW WASM
|
||||
│ ├── js/ # Node runtime stream/compat logic
|
||||
│ ├── devcapture/ # Dev packet capture module
|
||||
│ ├── format/ # Output formatting
|
||||
│ ├── prompt/ # Prompt construction
|
||||
@@ -412,13 +432,16 @@ ds2api/
|
||||
│ └── webui/ # WebUI static file serving and auto-build
|
||||
├── webui/ # React WebUI source (Vite + Tailwind)
|
||||
│ └── src/
|
||||
│ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage
|
||||
│ ├── app/ # Routing, auth, config state
|
||||
│ ├── features/ # Feature modules (account/settings/vercel/apiTester)
|
||||
│ ├── components/ # Shared UI pieces (login/landing, etc.)
|
||||
│ └── locales/ # Language packs (zh.json / en.json)
|
||||
├── scripts/
|
||||
│ └── build-webui.sh # Manual WebUI build script
|
||||
├── tests/
|
||||
│ ├── compat/ # Compatibility fixtures and expected outputs
|
||||
│ └── scripts/ # Unified test script entrypoints (unit/e2e)
|
||||
├── docs/ # Deployment / contributing / testing docs
|
||||
├── static/admin/ # WebUI build output (not committed to Git)
|
||||
├── .github/
|
||||
│ ├── workflows/ # GitHub Actions (quality gates + release automation)
|
||||
@@ -438,9 +461,9 @@ ds2api/
|
||||
| Document | Description |
|
||||
| --- | --- |
|
||||
| [API.md](API.md) / [API.en.md](API.en.md) | API reference with request/response examples |
|
||||
| [DEPLOY.md](DEPLOY.md) / [DEPLOY.en.md](DEPLOY.en.md) | Deployment guide (local/Docker/Vercel/systemd) |
|
||||
| [CONTRIBUTING.md](CONTRIBUTING.md) / [CONTRIBUTING.en.md](CONTRIBUTING.en.md) | Contributing guide |
|
||||
| [TESTING.md](TESTING.md) | Testsuite guide |
|
||||
| [DEPLOY.md](docs/DEPLOY.md) / [DEPLOY.en.md](docs/DEPLOY.en.md) | Deployment guide (local/Docker/Vercel/systemd) |
|
||||
| [CONTRIBUTING.md](docs/CONTRIBUTING.md) / [CONTRIBUTING.en.md](docs/CONTRIBUTING.en.md) | Contributing guide |
|
||||
| [TESTING.md](docs/TESTING.md) | Testsuite guide |
|
||||
|
||||
## Testing
|
||||
|
||||
@@ -475,8 +498,11 @@ Workflow: `.github/workflows/release-artifacts.yml`
|
||||
- **Trigger**: only on GitHub Release `published` (normal pushes do not trigger builds)
|
||||
- **Outputs**: multi-platform archives (`linux/amd64`, `linux/arm64`, `darwin/amd64`, `darwin/arm64`, `windows/amd64`) + `sha256sums.txt`
|
||||
- **Container publishing**: GHCR only (`ghcr.io/cjackhwang/ds2api`)
|
||||
- **Each archive includes**: `ds2api` executable, `static/admin`, WASM file, config template, README, LICENSE
|
||||
- **Each archive includes**: `ds2api` executable, `static/admin`, WASM file (with embedded fallback support), config template, README, LICENSE
|
||||
|
||||
## Disclaimer
|
||||
|
||||
This project is built through reverse engineering and is provided for learning and research only. Stability is not guaranteed. Do not use it in scenarios that violate terms of service or laws.
|
||||
This project is built through reverse engineering and is provided for learning, research, personal experimentation, and internal validation only. No commercial authorization is granted, and no warranty of stability, fitness, or results is provided.
|
||||
The author and repository maintainers are not responsible for any direct or indirect loss, account suspension, data loss, legal risk, or third-party claims arising from use, modification, distribution, deployment, or reliance on this project.
|
||||
|
||||
Do not use this project in ways that violate service terms, agreements, laws, or platform rules. Before any commercial use, review the `LICENSE`, the relevant terms, and confirm that you have the author's written permission.
|
||||
|
||||
@@ -9,20 +9,17 @@
|
||||
{
|
||||
"_comment": "邮箱登录方式",
|
||||
"email": "example1@example.com",
|
||||
"password": "your-password-1",
|
||||
"token": ""
|
||||
"password": "your-password-1"
|
||||
},
|
||||
{
|
||||
"_comment": "邮箱登录方式 - 账号2",
|
||||
"email": "example2@example.com",
|
||||
"password": "your-password-2",
|
||||
"token": ""
|
||||
"password": "your-password-2"
|
||||
},
|
||||
{
|
||||
"_comment": "手机号登录方式(中国大陆)",
|
||||
"mobile": "12345678901",
|
||||
"password": "your-password-3",
|
||||
"token": ""
|
||||
"password": "your-password-3"
|
||||
}
|
||||
],
|
||||
"model_aliases": {
|
||||
@@ -43,8 +40,19 @@
|
||||
"embeddings": {
|
||||
"provider": "deterministic"
|
||||
},
|
||||
"claude_model_mapping": {
|
||||
"claude_mapping": {
|
||||
"fast": "deepseek-chat",
|
||||
"slow": "deepseek-reasoner"
|
||||
},
|
||||
"admin": {
|
||||
"jwt_expire_hours": 24
|
||||
},
|
||||
"runtime": {
|
||||
"account_max_inflight": 2,
|
||||
"account_max_queue": 0,
|
||||
"global_max_inflight": 0
|
||||
},
|
||||
"auto_delete": {
|
||||
"sessions": false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,6 +70,7 @@ docker-compose -f docker-compose.dev.yml up
|
||||
5. Open a Pull Request
|
||||
|
||||
> 💡 If you modify files under `webui/`, no manual build is needed — CI handles it automatically.
|
||||
> If you want to verify the generated `static/admin/` assets locally, you can still run `./scripts/build-webui.sh`.
|
||||
|
||||
## Build WebUI
|
||||
|
||||
@@ -99,7 +100,7 @@ ds2api/
|
||||
├── api/
|
||||
│ ├── index.go # Vercel Serverless Go entry
|
||||
│ ├── chat-stream.js # Vercel Node.js stream relay
|
||||
│ └── helpers/ # Node.js helper modules
|
||||
│ └── (rewrite targets in vercel.json)
|
||||
├── internal/
|
||||
│ ├── account/ # Account pool and concurrency queue
|
||||
│ ├── adapter/
|
||||
@@ -112,6 +113,7 @@ ds2api/
|
||||
│ ├── compat/ # Compatibility helpers
|
||||
│ ├── config/ # Config loading and hot-reload
|
||||
│ ├── deepseek/ # DeepSeek client, PoW WASM
|
||||
│ ├── js/ # Node runtime stream/compat logic
|
||||
│ ├── devcapture/ # Dev packet capture
|
||||
│ ├── format/ # Output formatting
|
||||
│ ├── prompt/ # Prompt building
|
||||
@@ -123,9 +125,13 @@ ds2api/
|
||||
│ └── webui/ # WebUI static hosting
|
||||
├── webui/ # React WebUI source
|
||||
│ └── src/
|
||||
│ ├── components/ # Components
|
||||
│ ├── app/ # Routing, auth, config state
|
||||
│ ├── features/ # Feature modules
|
||||
│ ├── components/ # Shared components
|
||||
│ └── locales/ # Language packs
|
||||
├── scripts/ # Build and test scripts
|
||||
├── tests/ # Unit tests, Node tests, and end-to-end tests
|
||||
├── plans/ # Plans, gates, and manual smoke-test records
|
||||
├── static/admin/ # WebUI build output (not committed)
|
||||
├── Dockerfile # Multi-stage build
|
||||
├── docker-compose.yml # Production
|
||||
@@ -70,6 +70,7 @@ docker-compose -f docker-compose.dev.yml up
|
||||
5. 发起 Pull Request
|
||||
|
||||
> 💡 如果修改了 `webui/` 目录下的文件,无需手动构建——CI 会自动处理。
|
||||
> 但如果你本地想验证 `static/admin/` 产物,还是可以手动运行 `./scripts/build-webui.sh`。
|
||||
|
||||
## WebUI 构建
|
||||
|
||||
@@ -99,7 +100,7 @@ ds2api/
|
||||
├── api/
|
||||
│ ├── index.go # Vercel Serverless Go 入口
|
||||
│ ├── chat-stream.js # Vercel Node.js 流式转发
|
||||
│ └── helpers/ # Node.js 辅助模块
|
||||
│ └── (rewrite targets in vercel.json)
|
||||
├── internal/
|
||||
│ ├── account/ # 账号池与并发队列
|
||||
│ ├── adapter/
|
||||
@@ -112,6 +113,7 @@ ds2api/
|
||||
│ ├── compat/ # 兼容性辅助
|
||||
│ ├── config/ # 配置加载与热更新
|
||||
│ ├── deepseek/ # DeepSeek 客户端、PoW WASM
|
||||
│ ├── js/ # Node 运行时流式/兼容逻辑
|
||||
│ ├── devcapture/ # 开发抓包
|
||||
│ ├── format/ # 输出格式化
|
||||
│ ├── prompt/ # Prompt 构建
|
||||
@@ -123,9 +125,13 @@ ds2api/
|
||||
│ └── webui/ # WebUI 静态托管
|
||||
├── webui/ # React WebUI 源码
|
||||
│ └── src/
|
||||
│ ├── components/ # 组件
|
||||
│ ├── app/ # 路由、鉴权、配置状态
|
||||
│ ├── features/ # 业务功能模块
|
||||
│ ├── components/ # 通用组件
|
||||
│ └── locales/ # 语言包
|
||||
├── scripts/ # 构建与测试脚本
|
||||
├── tests/ # 单元测试、Node 测试与端到端测试
|
||||
├── plans/ # 计划、门禁和手工烟测记录
|
||||
├── static/admin/ # WebUI 构建产物(不提交)
|
||||
├── Dockerfile # 多阶段构建
|
||||
├── docker-compose.yml # 生产环境
|
||||
@@ -32,6 +32,7 @@ Config source (choose one):
|
||||
|
||||
- **File**: `config.json` (recommended for local/Docker)
|
||||
- **Environment variable**: `DS2API_CONFIG_JSON` (recommended for Vercel; supports raw JSON or Base64)
|
||||
- Compatibility note: `CONFIG_JSON` is the legacy fallback variable; `DS2API_CONFIG_JSON` may also contain raw JSON directly
|
||||
|
||||
Unified recommendation (best practice):
|
||||
|
||||
@@ -69,7 +70,7 @@ Default address: `http://0.0.0.0:5001` (override with `PORT`).
|
||||
|
||||
### 1.2 WebUI Build
|
||||
|
||||
On first local startup, if `static/admin/` is missing, DS2API will automatically attempt to build the WebUI (requires Node.js/npm).
|
||||
On first local startup, if `static/admin/` is missing, DS2API will automatically attempt to build the WebUI (requires Node.js/npm; when dependencies are missing it runs `npm ci` first, then `npm run build -- --outDir static/admin --emptyOutDir`).
|
||||
|
||||
Manual build:
|
||||
|
||||
@@ -113,12 +114,8 @@ go build -o ds2api ./cmd/ds2api
|
||||
# Copy env template
|
||||
cp .env.example .env
|
||||
|
||||
# Generate single-line Base64 from config.json
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
|
||||
# Edit .env and set:
|
||||
# Edit .env and set at least:
|
||||
# DS2API_ADMIN_KEY=your-admin-key
|
||||
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
|
||||
|
||||
# Start
|
||||
docker-compose up -d
|
||||
@@ -127,6 +124,8 @@ docker-compose up -d
|
||||
docker-compose logs -f
|
||||
```
|
||||
|
||||
The default `docker-compose.yml` maps host port `6011` to container port `5001`. If you want `5001` exposed directly, adjust the `ports` mapping.
|
||||
|
||||
### 2.2 Update
|
||||
|
||||
```bash
|
||||
@@ -185,6 +184,7 @@ Notes:
|
||||
|
||||
- **Port**: DS2API listens on `5001` by default; the template sets `PORT=5001`.
|
||||
- **Persistent config**: the template mounts `/data` and sets `DS2API_CONFIG_PATH=/data/config.json`. After importing config in Admin UI, it will be written and persisted to this path.
|
||||
- **Build version**: Zeabur / regular `docker build` does not require `BUILD_VERSION` by default. The image prefers that build arg when provided, and automatically falls back to the repo-root `VERSION` file when it is absent.
|
||||
- **First login**: after deployment, open `/admin` and login with `DS2API_ADMIN_KEY` shown in Zeabur env/template instructions (recommended: rotate to a strong secret after first login).
|
||||
|
||||
---
|
||||
@@ -366,7 +366,7 @@ Each archive includes:
|
||||
|
||||
- `ds2api` executable (`ds2api.exe` on Windows)
|
||||
- `static/admin/` (built WebUI assets)
|
||||
- `sha3_wasm_bg.7b9ca65ddd.wasm`
|
||||
- `sha3_wasm_bg.7b9ca65ddd.wasm` (optional; binary has embedded fallback)
|
||||
- `config.example.json`, `.env.example`
|
||||
- `README.MD`, `README.en.md`, `LICENSE`
|
||||
|
||||
@@ -455,7 +455,9 @@ server {
|
||||
```bash
|
||||
# Copy compiled binary and related files to target directory
|
||||
sudo mkdir -p /opt/ds2api
|
||||
sudo cp ds2api config.json sha3_wasm_bg.7b9ca65ddd.wasm /opt/ds2api/
|
||||
sudo cp ds2api config.json /opt/ds2api/
|
||||
# Optional: if you want to use an external WASM file (override the embedded one, from a release package or build output)
|
||||
# sudo cp /path/to/sha3_wasm_bg.7b9ca65ddd.wasm /opt/ds2api/
|
||||
sudo cp -r static/admin /opt/ds2api/static/admin
|
||||
```
|
||||
|
||||
@@ -32,6 +32,7 @@
|
||||
|
||||
- **文件方式**:`config.json`(推荐本地/Docker 使用)
|
||||
- **环境变量方式**:`DS2API_CONFIG_JSON`(推荐 Vercel 使用,支持 JSON 字符串或 Base64 编码)
|
||||
- 兼容写法:`CONFIG_JSON` 是旧版回退变量;`DS2API_CONFIG_JSON` 也可以直接写原始 JSON
|
||||
|
||||
统一建议(最优实践):
|
||||
|
||||
@@ -69,7 +70,7 @@ go run ./cmd/ds2api
|
||||
|
||||
### 1.2 WebUI 构建
|
||||
|
||||
本地首次启动时,若 `static/admin/` 不存在,服务会自动尝试构建 WebUI(需要 Node.js/npm)。
|
||||
本地首次启动时,若 `static/admin/` 不存在,服务会自动尝试构建 WebUI(需要 Node.js/npm;缺依赖时会先执行 `npm ci`,再执行 `npm run build -- --outDir static/admin --emptyOutDir`)。
|
||||
|
||||
你也可以手动构建:
|
||||
|
||||
@@ -113,12 +114,8 @@ go build -o ds2api ./cmd/ds2api
|
||||
# 复制环境变量模板
|
||||
cp .env.example .env
|
||||
|
||||
# 从 config.json 生成单行 Base64
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
|
||||
# 编辑 .env(请改成你的强密码),设置:
|
||||
# 编辑 .env(请改成你的强密码),至少设置:
|
||||
# DS2API_ADMIN_KEY=your-admin-key
|
||||
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
|
||||
|
||||
# 启动
|
||||
docker-compose up -d
|
||||
@@ -127,6 +124,8 @@ docker-compose up -d
|
||||
docker-compose logs -f
|
||||
```
|
||||
|
||||
默认 `docker-compose.yml` 会把宿主机 `6011` 映射到容器内的 `5001`。如果你希望直接对外暴露 `5001`,请调整 `ports` 配置。
|
||||
|
||||
### 2.2 更新
|
||||
|
||||
```bash
|
||||
@@ -185,6 +184,7 @@ healthcheck:
|
||||
|
||||
- **端口**:服务默认监听 `5001`,模板会固定设置 `PORT=5001`。
|
||||
- **配置持久化**:模板挂载卷 `/data`,并设置 `DS2API_CONFIG_PATH=/data/config.json`;在管理台导入配置后,会写入并持久化到该路径。
|
||||
- **构建版本号**:Zeabur / 普通 `docker build` 默认不需要传 `BUILD_VERSION`;镜像会优先使用该构建参数,未提供时自动回退到仓库根目录的 `VERSION` 文件。
|
||||
- **首次登录**:部署完成后访问 `/admin`,使用 Zeabur 环境变量/模板指引中的 `DS2API_ADMIN_KEY` 登录(建议首次登录后自行更换为强密码)。
|
||||
|
||||
---
|
||||
@@ -366,7 +366,7 @@ No Output Directory named "public" found after the Build completed.
|
||||
|
||||
- `ds2api` 可执行文件(Windows 为 `ds2api.exe`)
|
||||
- `static/admin/`(WebUI 构建产物)
|
||||
- `sha3_wasm_bg.7b9ca65ddd.wasm`
|
||||
- `sha3_wasm_bg.7b9ca65ddd.wasm`(可选;程序内置 embed fallback)
|
||||
- `config.example.json`、`.env.example`
|
||||
- `README.MD`、`README.en.md`、`LICENSE`
|
||||
|
||||
@@ -455,7 +455,9 @@ server {
|
||||
```bash
|
||||
# 将编译好的二进制文件和相关文件复制到目标目录
|
||||
sudo mkdir -p /opt/ds2api
|
||||
sudo cp ds2api config.json sha3_wasm_bg.7b9ca65ddd.wasm /opt/ds2api/
|
||||
sudo cp ds2api config.json /opt/ds2api/
|
||||
# 可选:若你希望使用外置 WASM 文件(覆盖内置版本,来自 release 包或构建产物)
|
||||
# sudo cp /path/to/sha3_wasm_bg.7b9ca65ddd.wasm /opt/ds2api/
|
||||
sudo cp -r static/admin /opt/ds2api/static/admin
|
||||
```
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# DS2API 测试指南
|
||||
|
||||
语言 / Language: [中文 + English](TESTING.md)
|
||||
语言 / Language: 中文 + English(同页)
|
||||
|
||||
## 概述 | Overview
|
||||
|
||||
@@ -14,6 +14,7 @@ DS2API 提供两个层级的测试:
|
||||
| 端到端测试 | `./tests/scripts/run-live.sh` | 使用真实账号执行全链路测试 |
|
||||
|
||||
端到端测试集会录制完整的请求/响应日志,用于故障排查。
|
||||
Node 单元测试脚本会先做 `node --check` 语法门禁,再以 `--test-concurrency=1` 串行执行测试文件,减少模块级共享状态带来的干扰。
|
||||
|
||||
---
|
||||
|
||||
@@ -51,7 +52,7 @@ DS2API 提供两个层级的测试:
|
||||
1. **Preflight 检查**:
|
||||
- `go test ./... -count=1`(单元测试)
|
||||
- `./tests/scripts/check-node-split-syntax.sh`(Node 拆分模块语法门禁)
|
||||
- `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js`(Node 流式拦截 + compat 单测)
|
||||
- `node --test tests/node/stream-tool-sieve.test.js tests/node/chat-stream.test.js tests/node/js_compat_test.js`
|
||||
- `npm run build --prefix webui`(WebUI 构建检查)
|
||||
|
||||
2. **隔离启动**:复制 `config.json` 到临时目录,启动独立服务进程
|
||||
@@ -66,6 +67,8 @@ DS2API 提供两个层级的测试:
|
||||
|
||||
4. **结果收集**:继续执行所有用例(不中断),写入最终汇总
|
||||
|
||||
如果你只想跳过这些 preflight 检查,可以直接运行 `go run ./cmd/ds2api-tests --no-preflight`。
|
||||
|
||||
---
|
||||
|
||||
## CLI 参数 | CLI Flags
|
||||
@@ -173,6 +176,50 @@ rg "<trace_id>" artifacts/testsuite/<run_id>/server.log
|
||||
go test ./...
|
||||
```
|
||||
|
||||
### 运行特定模块的单元测试
|
||||
|
||||
```bash
|
||||
# 运行 tool calls 相关测试(推荐用于调试 tool call 解析问题)
|
||||
go test -v -run 'TestParseToolCalls|TestRepair' ./internal/util/
|
||||
|
||||
# 运行单个测试用例
|
||||
go test -v -run TestParseToolCallsWithDeepSeekHallucination ./internal/util/
|
||||
|
||||
# 运行 format 相关测试
|
||||
go test -v ./internal/format/...
|
||||
|
||||
# 运行 adapter 相关测试
|
||||
go test -v ./internal/adapter/openai/...
|
||||
```
|
||||
|
||||
### 调试 Tool Call 问题 | Debugging Tool Call Issues
|
||||
|
||||
当遇到 DeepSeek 工具调用解析问题时,可以使用以下方法:
|
||||
|
||||
```bash
|
||||
# 1. 运行 tool calls 相关的所有测试
|
||||
go test -v -run 'TestParseToolCalls|TestRepair' ./internal/util/
|
||||
|
||||
# 2. 查看测试输出中的详细调试信息
|
||||
go test -v -run TestParseToolCallsWithDeepSeekHallucination ./internal/util/ 2>&1
|
||||
|
||||
# 3. 检查具体测试用例的修复效果
|
||||
# 测试用例位于 internal/util/toolcalls_test.go,包含:
|
||||
# - TestParseToolCallsWithDeepSeekHallucination: DeepSeek 典型幻觉输出
|
||||
# - TestRepairLooseJSONWithNestedObjects: 嵌套对象的方括号修复
|
||||
# - TestParseToolCallsWithMixedWindowsPaths: Windows 路径处理
|
||||
```
|
||||
|
||||
### 运行 Node.js 测试
|
||||
|
||||
```bash
|
||||
# 运行 Node 测试
|
||||
node --test tests/node/stream-tool-sieve.test.js
|
||||
|
||||
# 或使用脚本
|
||||
./tests/scripts/run-unit-node.sh
|
||||
```
|
||||
|
||||
### 跑端到端测试(跳过 preflight)
|
||||
|
||||
```bash
|
||||
72
docs/toolcall-semantics.md
Normal file
72
docs/toolcall-semantics.md
Normal file
@@ -0,0 +1,72 @@
|
||||
# Tool call parsing semantics(Go/Node 统一语义)
|
||||
|
||||
本文档描述当前代码中 `ParseToolCallsDetailed` / `parseToolCallsDetailed` 的**实际行为**,用于对齐 Go 与 Node Runtime。
|
||||
|
||||
## 1) 输出结构(当前实现)
|
||||
|
||||
- `calls`:解析得到的工具调用列表(`name` + `input`)。
|
||||
- `sawToolCallSyntax`:检测到工具调用语法特征时为 `true`(例如 `tool_calls`、`<tool_call>`、`<function_call>`、`<invoke>`、`function.name:`)。
|
||||
- `rejectedByPolicy`:当前实现固定为 `false`(预留字段,尚未启用 allow-list 拒绝)。
|
||||
- `rejectedToolNames`:当前实现固定为空数组(预留字段)。
|
||||
|
||||
> 说明:`filterToolCallsDetailed` 当前仅做结构清洗,不做工具名策略拒绝。
|
||||
|
||||
## 2) 解析管线
|
||||
|
||||
1. **示例保护**:若判定为 fenced code block 示例上下文,则跳过执行型解析。
|
||||
2. **候选片段构建**:从完整文本中构建候选(原文、围绕 `tool_calls` 的 JSON 片段、首尾大括号切片等)。
|
||||
3. **按序尝试解析(命中即停)**:
|
||||
- 对“明显 JSON 工具载荷候选”(以 `{`/`[` 开头且包含 `tool_calls`/`\"function\"`)先走 JSON 解析,避免 JSON 字符串内偶发 XML 片段误命中;
|
||||
- 其余候选优先 XML 解析(`<tool_call>` / `<function_call>` / `<invoke>` / `tool_use` / `antml:function_call` 等);
|
||||
- JSON 解析(`{"tool_calls": [...]}`、列表、单对象);
|
||||
- Markup 解析;
|
||||
- Text-KV 回退(如 `function.name:` + `function.arguments:`)。
|
||||
4. **兜底**:候选全部失败后,再对全文做 XML / Text-KV 回退。
|
||||
|
||||
## 3) XML 能力边界(当前)
|
||||
|
||||
当前已支持输入端的“多 XML/标记风格”解析,包括但不限于:
|
||||
|
||||
- `<tool_call><tool_name>...</tool_name><parameters>...</parameters></tool_call>`
|
||||
- `<function_call>tool</function_call><function parameter name="x">...</function parameter>`
|
||||
- `<invoke name="tool"><parameter name="x">...</parameter></invoke>`
|
||||
- `antml:function_call` / `antml:argument` / `antml:parameters`
|
||||
- `tool_use` 家族标签
|
||||
|
||||
但**输出端仍统一转换为 OpenAI 兼容 JSON 事件/对象**(`message.tool_calls`、`delta.tool_calls`、`response.function_call_arguments.*`)。
|
||||
|
||||
## 4) 关于“是否可以封装成 XML 再喂给模型”
|
||||
|
||||
结论:**可以做,而且当前解析器已经能兼容 XML 作为输入格式之一**,但代码里并没有 `toolcall.prefer_xml_output` 这个开关。现有可调配置只有:
|
||||
|
||||
- `toolcall.mode`:`feature_match` / `off`
|
||||
- `toolcall.early_emit_confidence`:`high` / `low` / `off`
|
||||
|
||||
推荐思路仍然是“输入兼容层 + 输出按客户端协议渲染”:
|
||||
|
||||
1. **Prompt 约束层**:如果你要尝试 XML-first,可以在系统提示词里约束模型输出规范 XML tool block(例如 `<tool_calls><tool_call>...</tool_call></tool_calls>`)。
|
||||
2. **解析兼容层**:继续在 parser 中同时接受 JSON / XML / ANTML / invoke / text-kv。
|
||||
3. **协议归一层**:无论模型输出什么格式,统一落到内部 `ParsedToolCall`。
|
||||
4. **对外渲染层**:根据客户端请求协议渲染(OpenAI / Claude / Gemini 各自格式)。
|
||||
|
||||
这样可以同时获得:
|
||||
|
||||
- 减少模型端 JSON 转义/引号错误;
|
||||
- 不破坏现有 SDK / 客户端生态;
|
||||
- 逐步灰度(按模型、按租户、按请求开关)。
|
||||
|
||||
## 5) 落地建议(低风险迭代)
|
||||
|
||||
- 继续使用现有的 `toolcall.mode=feature_match` 和 `toolcall.early_emit_confidence=high` 作为默认策略。
|
||||
- 如果要试 XML-first,把它放在 prompt 层或上游模板层,不要假设代码里已有专门的 XML 输出开关。
|
||||
- 增加观测指标:
|
||||
- `toolcall_parse_source`(json/xml/markup/textkv);
|
||||
- `toolcall_parse_success_rate`;
|
||||
- `toolcall_malformed_rate`;
|
||||
- `toolcall_repair_rate`。
|
||||
- 先在 `responses` 链路灰度,再扩展 `chat.completions`。
|
||||
|
||||
## 6) 兼容性提醒
|
||||
|
||||
- 上游模型若输出混合文本 + XML,仍可能出现“半结构化”噪声,需要依赖现有 sieve 增量消费策略。
|
||||
- XML 不等于安全:仍需做 tool 名、参数 schema、执行权限的服务端校验。
|
||||
@@ -194,7 +194,7 @@ func TestPoolAccountConcurrencyAliasEnv(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolSupportsTokenOnlyAccount(t *testing.T) {
|
||||
func TestPoolDropsLegacyTokenOnlyAccountOnLoad(t *testing.T) {
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1")
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["k1"],
|
||||
@@ -203,19 +203,15 @@ func TestPoolSupportsTokenOnlyAccount(t *testing.T) {
|
||||
|
||||
pool := NewPool(config.LoadStore())
|
||||
status := pool.Status()
|
||||
if got, ok := status["total"].(int); !ok || got != 1 {
|
||||
if got, ok := status["total"].(int); !ok || got != 0 {
|
||||
t.Fatalf("unexpected total in pool status: %#v", status["total"])
|
||||
}
|
||||
if got, ok := status["available"].(int); !ok || got != 1 {
|
||||
if got, ok := status["available"].(int); !ok || got != 0 {
|
||||
t.Fatalf("unexpected available in pool status: %#v", status["available"])
|
||||
}
|
||||
|
||||
acc, ok := pool.Acquire("", nil)
|
||||
if !ok {
|
||||
t.Fatalf("expected acquire success for token-only account")
|
||||
}
|
||||
if acc.Token != "token-only-account" {
|
||||
t.Fatalf("unexpected token on acquired account: %q", acc.Token)
|
||||
if _, ok := pool.Acquire("", nil); ok {
|
||||
t.Fatalf("expected acquire to fail for token-only account")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -315,3 +315,122 @@ func asString(v any) string {
|
||||
s, _ := v.(string)
|
||||
return s
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeToolSafetyAcrossStructuredFormats(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload string
|
||||
}{
|
||||
{name: "xml_tool_call", payload: `<tool_call><tool_name>Bash</tool_name><parameters><command>pwd</command></parameters></tool_call>`},
|
||||
{name: "xml_json_tool_call", payload: `<tool_call>{"tool":"Bash","params":{"command":"pwd"}}</tool_call>`},
|
||||
{name: "nested_tool_tag_style", payload: `<tool_call><tool name="Bash"><command>pwd</command></tool></tool_call>`},
|
||||
{name: "function_tag_style", payload: `<function_call>Bash</function_call><function parameter name="command">pwd</function parameter>`},
|
||||
{name: "antml_argument_style", payload: `<antml:function_calls><antml:function_call id="1" name="Bash"><antml:argument name="command">pwd</antml:argument></antml:function_call></antml:function_calls>`},
|
||||
{name: "antml_function_attr_parameters", payload: `<antml:function_calls><antml:function_call id="1" function="Bash"><antml:parameters>{"command":"pwd"}</antml:parameters></antml:function_call></antml:function_calls>`},
|
||||
{name: "invoke_parameter_style", payload: `<function_calls><invoke name="Bash"><parameter name="command">pwd</parameter></invoke></function_calls>`},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"`+strings.ReplaceAll(tc.payload, `"`, `\"`)+`"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, false, false, []string{"Bash"})
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
foundToolUse := false
|
||||
for _, f := range findClaudeFrames(frames, "content_block_start") {
|
||||
contentBlock, _ := f.Payload["content_block"].(map[string]any)
|
||||
if contentBlock["type"] == "tool_use" {
|
||||
foundToolUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundToolUse {
|
||||
t.Fatalf("expected tool_use block for format %s, body=%s", tc.name, rec.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeDetectsToolUseWithLeadingProse(t *testing.T) {
|
||||
h := &Handler{}
|
||||
payload := "I'll call a tool now.\\n<tool_use><tool_name>write_file</tool_name><parameters>{\\\"path\\\":\\\"/tmp/a.txt\\\",\\\"content\\\":\\\"abc\\\"}</parameters></tool_use>"
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"`+payload+`"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, false, false, []string{"write_file"})
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
foundToolUse := false
|
||||
for _, f := range findClaudeFrames(frames, "content_block_start") {
|
||||
contentBlock, _ := f.Payload["content_block"].(map[string]any)
|
||||
if contentBlock["type"] == "tool_use" && contentBlock["name"] == "write_file" {
|
||||
foundToolUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundToolUse {
|
||||
t.Fatalf("expected tool_use block with leading prose payload, body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
for _, f := range findClaudeFrames(frames, "message_delta") {
|
||||
delta, _ := f.Payload["delta"].(map[string]any)
|
||||
if delta["stop_reason"] == "tool_use" {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatalf("expected stop_reason=tool_use, body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeIgnoresUnclosedFencedToolExample(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
"data: {\"p\":\"response/content\",\"v\":\"Here is an example:\\n```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"Bash\\\",\\\"input\\\":{\\\"command\\\":\\\"pwd\\\"}}]}\"}",
|
||||
"data: {\"p\":\"response/content\",\"v\":\"\\n```\\nDo not execute it.\"}",
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "show example only"}}, false, false, []string{"Bash"})
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
foundToolUse := false
|
||||
for _, f := range findClaudeFrames(frames, "content_block_start") {
|
||||
contentBlock, _ := f.Payload["content_block"].(map[string]any)
|
||||
if contentBlock["type"] == "tool_use" {
|
||||
foundToolUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundToolUse {
|
||||
t.Fatalf("expected no tool_use for fenced example, body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
foundToolStop := false
|
||||
for _, f := range findClaudeFrames(frames, "message_delta") {
|
||||
delta, _ := f.Payload["delta"].(map[string]any)
|
||||
if delta["stop_reason"] == "tool_use" {
|
||||
foundToolStop = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundToolStop {
|
||||
t.Fatalf("expected stop_reason to remain content-only, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// Backward-compatible alias for historical test name used in CI logs.
|
||||
func TestHandleClaudeStreamRealtimePromotesUnclosedFencedToolExample(t *testing.T) {
|
||||
TestHandleClaudeStreamRealtimeIgnoresUnclosedFencedToolExample(t)
|
||||
}
|
||||
|
||||
@@ -48,10 +48,88 @@ func TestNormalizeClaudeMessagesToolResult(t *testing.T) {
|
||||
},
|
||||
}
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %d", len(got))
|
||||
}
|
||||
m := got[0].(map[string]any)
|
||||
if m["role"] != "tool" {
|
||||
t.Fatalf("expected tool role preserved, got %#v", m["role"])
|
||||
}
|
||||
content, _ := m["content"].(string)
|
||||
if !strings.Contains(content, "[TOOL_RESULT_HISTORY]") || !strings.Contains(content, "content: tool output") {
|
||||
t.Fatalf("expected serialized tool result marker, got %q", content)
|
||||
if content != "tool output" {
|
||||
t.Fatalf("expected raw tool output content preserved, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesToolUseToAssistantToolCalls(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_1",
|
||||
"name": "search_web",
|
||||
"input": map[string]any{"query": "latest"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one normalized tool-call message, got %d", len(got))
|
||||
}
|
||||
m := got[0].(map[string]any)
|
||||
if m["role"] != "assistant" {
|
||||
t.Fatalf("expected assistant role, got %#v", m["role"])
|
||||
}
|
||||
tc, _ := m["tool_calls"].([]any)
|
||||
if len(tc) != 1 {
|
||||
t.Fatalf("expected one tool call, got %#v", m["tool_calls"])
|
||||
}
|
||||
call, _ := tc[0].(map[string]any)
|
||||
if call["id"] != "call_1" {
|
||||
t.Fatalf("expected call id preserved, got %#v", call)
|
||||
}
|
||||
content, _ := m["content"].(string)
|
||||
if !containsStr(content, "<tool_calls>") || !containsStr(content, "<tool_name>search_web</tool_name>") {
|
||||
t.Fatalf("expected assistant content to include XML tool call history, got %q", content)
|
||||
}
|
||||
if !containsStr(content, `<parameters>{"query":"latest"}</parameters>`) {
|
||||
t.Fatalf("expected assistant content to include serialized parameters, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesDoesNotPromoteUserToolUse(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_unsafe",
|
||||
"name": "dangerous_tool",
|
||||
"input": map[string]any{"value": "x"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %d", len(got))
|
||||
}
|
||||
m := got[0].(map[string]any)
|
||||
if m["role"] != "user" {
|
||||
t.Fatalf("expected user role preserved, got %#v", m["role"])
|
||||
}
|
||||
if _, ok := m["tool_calls"]; ok {
|
||||
t.Fatalf("expected no tool_calls promotion for user message, got %#v", m["tool_calls"])
|
||||
}
|
||||
content, _ := m["content"].(string)
|
||||
if !containsStr(content, `"type":"tool_use"`) || !containsStr(content, "dangerous_tool") {
|
||||
t.Fatalf("expected raw tool_use block preserved in user content, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,15 +165,63 @@ func TestNormalizeClaudeMessagesMixedContentBlocks(t *testing.T) {
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "Hello"},
|
||||
map[string]any{"type": "image", "source": "data:..."},
|
||||
map[string]any{"type": "image", "source": map[string]any{"type": "base64", "data": strings.Repeat("A", 2048)}},
|
||||
map[string]any{"type": "text", "text": "World"},
|
||||
},
|
||||
},
|
||||
}
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
m := got[0].(map[string]any)
|
||||
if m["content"] != "Hello\nWorld" {
|
||||
t.Fatalf("expected only text parts joined, got %q", m["content"])
|
||||
content, _ := m["content"].(string)
|
||||
if !containsStr(content, "Hello") || !containsStr(content, "World") || !containsStr(content, `"type":"image"`) {
|
||||
t.Fatalf("expected text plus non-text block marker preserved, got %q", content)
|
||||
}
|
||||
if !containsStr(content, omittedBinaryMarker) {
|
||||
t.Fatalf("expected binary payload omitted marker, got %q", content)
|
||||
}
|
||||
if containsStr(content, strings.Repeat("A", 100)) {
|
||||
t.Fatalf("expected raw base64 payload not to be included, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesToolResultNonTextPayloadStringified(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "call_image_1",
|
||||
"name": "vision_tool",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "image analysis"},
|
||||
map[string]any{
|
||||
"type": "image",
|
||||
"source": map[string]any{"type": "base64", "media_type": "image/png", "data": strings.Repeat("B", 2048)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %d", len(got))
|
||||
}
|
||||
m := got[0].(map[string]any)
|
||||
if m["role"] != "tool" {
|
||||
t.Fatalf("expected tool role, got %#v", m["role"])
|
||||
}
|
||||
content, _ := m["content"].(string)
|
||||
if !containsStr(content, `"type":"tool_result"`) || !containsStr(content, `"type":"image"`) {
|
||||
t.Fatalf("expected non-text tool_result payload to be JSON stringified, got %q", content)
|
||||
}
|
||||
if !containsStr(content, omittedBinaryMarker) {
|
||||
t.Fatalf("expected binary data to be sanitized with omitted marker, got %q", content)
|
||||
}
|
||||
if containsStr(content, strings.Repeat("B", 100)) {
|
||||
t.Fatalf("expected raw base64 payload not to be included, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,8 +251,11 @@ func TestBuildClaudeToolPromptSingleTool(t *testing.T) {
|
||||
if !containsStr(prompt, "Search the web") {
|
||||
t.Fatalf("expected description in prompt")
|
||||
}
|
||||
if !containsStr(prompt, "tool_calls") {
|
||||
t.Fatalf("expected tool_calls instruction in prompt")
|
||||
if !containsStr(prompt, "<tool_calls>") {
|
||||
t.Fatalf("expected XML tool_calls format in prompt")
|
||||
}
|
||||
if !containsStr(prompt, "TOOL CALL FORMAT") {
|
||||
t.Fatalf("expected tool call format header in prompt")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,12 +301,9 @@ func TestBuildClaudeToolPromptSupportsOpenAIStyleFunctionTool(t *testing.T) {
|
||||
func TestBuildClaudeToolPromptSkipsNonMap(t *testing.T) {
|
||||
tools := []any{"not a map"}
|
||||
prompt := buildClaudeToolPrompt(tools)
|
||||
if prompt == "" {
|
||||
t.Fatal("expected non-empty prompt even with invalid tools")
|
||||
}
|
||||
// Should still contain the intro and instruction
|
||||
if !containsStr(prompt, "You are Claude") {
|
||||
t.Fatalf("expected intro in prompt")
|
||||
// No valid tools → empty prompt
|
||||
if prompt != "" {
|
||||
t.Fatalf("expected empty prompt for non-map tools, got: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,9 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/prompt"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func normalizeClaudeMessages(messages []any) []any {
|
||||
@@ -13,71 +16,195 @@ func normalizeClaudeMessages(messages []any) []any {
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
copied := cloneMap(msg)
|
||||
role := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", msg["role"])))
|
||||
switch content := msg["content"].(type) {
|
||||
case []any:
|
||||
parts := make([]string, 0, len(content))
|
||||
textParts := make([]string, 0, len(content))
|
||||
flushText := func() {
|
||||
if len(textParts) == 0 {
|
||||
return
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"role": role,
|
||||
"content": strings.Join(textParts, "\n"),
|
||||
})
|
||||
textParts = textParts[:0]
|
||||
}
|
||||
for _, block := range content {
|
||||
b, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
typeStr, _ := b["type"].(string)
|
||||
if typeStr == "text" {
|
||||
typeStr := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", b["type"])))
|
||||
switch typeStr {
|
||||
case "text":
|
||||
if t, ok := b["text"].(string); ok {
|
||||
parts = append(parts, t)
|
||||
textParts = append(textParts, t)
|
||||
}
|
||||
case "tool_use":
|
||||
if role == "assistant" {
|
||||
flushText()
|
||||
if toolMsg := normalizeClaudeToolUseToAssistant(b); toolMsg != nil {
|
||||
out = append(out, toolMsg)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if raw := strings.TrimSpace(formatClaudeUnknownBlockForPrompt(b)); raw != "" {
|
||||
textParts = append(textParts, raw)
|
||||
}
|
||||
case "tool_result":
|
||||
flushText()
|
||||
if toolMsg := normalizeClaudeToolResultToToolMessage(b); toolMsg != nil {
|
||||
out = append(out, toolMsg)
|
||||
}
|
||||
default:
|
||||
if raw := strings.TrimSpace(formatClaudeUnknownBlockForPrompt(b)); raw != "" {
|
||||
textParts = append(textParts, raw)
|
||||
}
|
||||
}
|
||||
if typeStr == "tool_result" {
|
||||
parts = append(parts, formatClaudeToolResultForPrompt(b))
|
||||
}
|
||||
}
|
||||
copied["content"] = strings.Join(parts, "\n")
|
||||
flushText()
|
||||
default:
|
||||
copied := cloneMap(msg)
|
||||
out = append(out, copied)
|
||||
}
|
||||
out = append(out, copied)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildClaudeToolPrompt(tools []any) string {
|
||||
parts := []string{"You are Claude, a helpful AI assistant. You have access to these tools:"}
|
||||
toolSchemas := make([]string, 0, len(tools))
|
||||
names := make([]string, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
m, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name, desc, schemaObj := extractClaudeToolMeta(m)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
names = append(names, name)
|
||||
schema, _ := json.Marshal(schemaObj)
|
||||
parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
|
||||
toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
|
||||
}
|
||||
parts = append(parts,
|
||||
"When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}",
|
||||
"History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.",
|
||||
"After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.",
|
||||
)
|
||||
return strings.Join(parts, "\n\n")
|
||||
if len(toolSchemas) == 0 {
|
||||
return ""
|
||||
}
|
||||
return "You have access to these tools:\n\n" +
|
||||
strings.Join(toolSchemas, "\n\n") + "\n\n" +
|
||||
util.BuildToolCallInstructions(names)
|
||||
}
|
||||
|
||||
func formatClaudeToolResultForPrompt(block map[string]any) string {
|
||||
if block == nil {
|
||||
return ""
|
||||
}
|
||||
payload := map[string]any{
|
||||
"type": "tool_result",
|
||||
"content": block["content"],
|
||||
}
|
||||
if toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"])); toolCallID != "" {
|
||||
payload["tool_call_id"] = toolCallID
|
||||
} else if toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"])); toolCallID != "" {
|
||||
payload["tool_call_id"] = toolCallID
|
||||
}
|
||||
if name := strings.TrimSpace(fmt.Sprintf("%v", block["name"])); name != "" {
|
||||
payload["name"] = name
|
||||
}
|
||||
b, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", payload))
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func normalizeClaudeToolUseToAssistant(block map[string]any) map[string]any {
|
||||
if block == nil {
|
||||
return nil
|
||||
}
|
||||
name := strings.TrimSpace(fmt.Sprintf("%v", block["name"]))
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
callID := strings.TrimSpace(fmt.Sprintf("%v", block["id"]))
|
||||
if callID == "" {
|
||||
callID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
|
||||
}
|
||||
if callID == "" {
|
||||
callID = "call_claude"
|
||||
}
|
||||
arguments := block["input"]
|
||||
if arguments == nil {
|
||||
arguments = map[string]any{}
|
||||
}
|
||||
argsJSON, err := json.Marshal(arguments)
|
||||
if err != nil || len(argsJSON) == 0 {
|
||||
argsJSON = []byte("{}")
|
||||
}
|
||||
toolCalls := []any{
|
||||
map[string]any{
|
||||
"id": callID,
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": name,
|
||||
"arguments": string(argsJSON),
|
||||
},
|
||||
},
|
||||
}
|
||||
return map[string]any{
|
||||
"role": "assistant",
|
||||
"content": prompt.FormatToolCallsForPrompt(toolCalls),
|
||||
"tool_calls": toolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeClaudeToolResultToToolMessage(block map[string]any) map[string]any {
|
||||
if block == nil {
|
||||
return nil
|
||||
}
|
||||
toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
|
||||
if toolCallID == "" {
|
||||
toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"]))
|
||||
}
|
||||
if toolCallID == "" {
|
||||
toolCallID = "unknown"
|
||||
toolCallID = "call_claude"
|
||||
}
|
||||
name := strings.TrimSpace(fmt.Sprintf("%v", block["name"]))
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
out := map[string]any{
|
||||
"role": "tool",
|
||||
"tool_call_id": toolCallID,
|
||||
"content": normalizeClaudeToolResultContent(block["content"]),
|
||||
}
|
||||
content := strings.TrimSpace(fmt.Sprintf("%v", block["content"]))
|
||||
if content == "" {
|
||||
content = "null"
|
||||
if name := strings.TrimSpace(fmt.Sprintf("%v", block["name"])); name != "" {
|
||||
out["name"] = name
|
||||
}
|
||||
return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content)
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeClaudeToolResultContent(content any) any {
|
||||
if text, ok := content.(string); ok {
|
||||
return text
|
||||
}
|
||||
payload := map[string]any{
|
||||
"type": "tool_result",
|
||||
"content": content,
|
||||
}
|
||||
b, err := json.Marshal(sanitizeClaudeBlockForPrompt(payload))
|
||||
if err != nil {
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", content))
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func formatClaudeBlockRaw(block map[string]any) string {
|
||||
if block == nil {
|
||||
return ""
|
||||
}
|
||||
b, err := json.Marshal(block)
|
||||
if err != nil {
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", block))
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func hasSystemMessage(messages []any) bool {
|
||||
|
||||
105
internal/adapter/claude/handler_utils_sanitize.go
Normal file
105
internal/adapter/claude/handler_utils_sanitize.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
maxClaudeRawPromptChars = 1024
|
||||
omittedBinaryMarker = "[omitted_binary_payload]"
|
||||
)
|
||||
|
||||
func formatClaudeUnknownBlockForPrompt(block map[string]any) string {
|
||||
if block == nil {
|
||||
return ""
|
||||
}
|
||||
safe := sanitizeClaudeBlockForPrompt(block)
|
||||
raw := strings.TrimSpace(formatClaudeBlockRaw(safe))
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
if len(raw) > maxClaudeRawPromptChars {
|
||||
return raw[:maxClaudeRawPromptChars] + "...(truncated)"
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
func sanitizeClaudeBlockForPrompt(block map[string]any) map[string]any {
|
||||
out := cloneMap(block)
|
||||
for k, v := range out {
|
||||
if looksLikeBinaryFieldName(k) {
|
||||
out[k] = omittedBinaryMarker
|
||||
continue
|
||||
}
|
||||
switch inner := v.(type) {
|
||||
case map[string]any:
|
||||
out[k] = sanitizeClaudeBlockForPrompt(inner)
|
||||
case []any:
|
||||
out[k] = sanitizeClaudeArrayForPrompt(inner)
|
||||
case string:
|
||||
out[k] = sanitizeClaudeStringForPrompt(k, inner)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func sanitizeClaudeArrayForPrompt(items []any) []any {
|
||||
out := make([]any, 0, len(items))
|
||||
for _, item := range items {
|
||||
switch v := item.(type) {
|
||||
case map[string]any:
|
||||
out = append(out, sanitizeClaudeBlockForPrompt(v))
|
||||
case []any:
|
||||
out = append(out, sanitizeClaudeArrayForPrompt(v))
|
||||
default:
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func sanitizeClaudeStringForPrompt(key, value string) string {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if looksLikeBinaryFieldName(key) || looksLikeBase64Payload(trimmed) {
|
||||
return omittedBinaryMarker
|
||||
}
|
||||
if len(trimmed) > maxClaudeRawPromptChars {
|
||||
return trimmed[:maxClaudeRawPromptChars] + "...(truncated)"
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func looksLikeBinaryFieldName(name string) bool {
|
||||
n := strings.ToLower(strings.TrimSpace(name))
|
||||
return n == "data" || n == "bytes" || n == "base64" || n == "inline_data" || n == "inlinedata"
|
||||
}
|
||||
|
||||
func looksLikeBase64Payload(v string) bool {
|
||||
if len(v) < 512 {
|
||||
return false
|
||||
}
|
||||
compact := strings.TrimRight(v, "=")
|
||||
if compact == "" {
|
||||
return false
|
||||
}
|
||||
for _, ch := range compact {
|
||||
if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '+' || ch == '/' || ch == '-' || ch == '_' {
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func marshalCompactJSON(v any) string {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
@@ -38,6 +38,9 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma
|
||||
}
|
||||
finalPrompt := deepseek.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
|
||||
toolNames := extractClaudeToolNames(toolsRequested)
|
||||
if len(toolNames) == 0 && len(toolsRequested) > 0 {
|
||||
toolNames = []string{"__any_tool__"}
|
||||
}
|
||||
|
||||
return claudeNormalizedRequest{
|
||||
Standard: util.StandardRequest{
|
||||
|
||||
@@ -116,6 +116,9 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
|
||||
|
||||
s.text.WriteString(p.Text)
|
||||
if s.bufferToolContent {
|
||||
if hasUnclosedCodeFence(s.text.String()) {
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
s.closeThinkingBlock()
|
||||
@@ -144,3 +147,7 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
|
||||
|
||||
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||
}
|
||||
|
||||
func hasUnclosedCodeFence(text string) bool {
|
||||
return strings.Count(text, "```")%2 == 1
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -45,9 +46,9 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
|
||||
finalText := s.text.String()
|
||||
|
||||
if s.bufferToolContent {
|
||||
detected := util.ParseToolCalls(finalText, s.toolNames)
|
||||
detected := util.ParseStandaloneToolCalls(finalText, s.toolNames)
|
||||
if len(detected) == 0 && finalText == "" && finalThinking != "" {
|
||||
detected = util.ParseToolCalls(finalThinking, s.toolNames)
|
||||
detected = util.ParseStandaloneToolCalls(finalThinking, s.toolNames)
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
stopReason = "tool_use"
|
||||
@@ -60,9 +61,20 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
|
||||
"type": "tool_use",
|
||||
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx),
|
||||
"name": tc.Name,
|
||||
"input": tc.Input,
|
||||
"input": map[string]any{},
|
||||
},
|
||||
})
|
||||
|
||||
inputBytes, _ := json.Marshal(tc.Input)
|
||||
s.send("content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": idx,
|
||||
"delta": map[string]any{
|
||||
"type": "input_json_delta",
|
||||
"partial_json": string(inputBytes),
|
||||
},
|
||||
})
|
||||
|
||||
s.send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": idx,
|
||||
|
||||
@@ -2,6 +2,8 @@ package gemini
|
||||
|
||||
import "strings"
|
||||
|
||||
const maxGeminiRawPromptChars = 1024
|
||||
|
||||
func geminiMessagesFromRequest(req map[string]any) []any {
|
||||
out := make([]any, 0, 8)
|
||||
if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" {
|
||||
@@ -107,6 +109,11 @@ func geminiMessagesFromRequest(req map[string]any) []any {
|
||||
msg["name"] = name
|
||||
}
|
||||
out = append(out, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
if raw := strings.TrimSpace(formatGeminiUnknownPartForPrompt(part)); raw != "" && raw != "null" {
|
||||
textParts = append(textParts, raw)
|
||||
}
|
||||
}
|
||||
flushText()
|
||||
@@ -151,3 +158,87 @@ func mapGeminiRole(v any) string {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func formatGeminiUnknownPartForPrompt(part map[string]any) string {
|
||||
safe := sanitizeGeminiPartForPrompt(part)
|
||||
raw := strings.TrimSpace(stringifyJSON(safe))
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
if len(raw) > maxGeminiRawPromptChars {
|
||||
return raw[:maxGeminiRawPromptChars] + "...(truncated)"
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
func sanitizeGeminiPartForPrompt(part map[string]any) map[string]any {
|
||||
out := make(map[string]any, len(part))
|
||||
for k, v := range part {
|
||||
if looksLikeGeminiBinaryField(k) {
|
||||
out[k] = "[omitted_binary_payload]"
|
||||
continue
|
||||
}
|
||||
switch x := v.(type) {
|
||||
case map[string]any:
|
||||
out[k] = sanitizeGeminiPartForPrompt(x)
|
||||
case []any:
|
||||
out[k] = sanitizeGeminiArrayForPrompt(x)
|
||||
case string:
|
||||
out[k] = sanitizeGeminiStringForPrompt(k, x)
|
||||
default:
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func sanitizeGeminiArrayForPrompt(items []any) []any {
|
||||
out := make([]any, 0, len(items))
|
||||
for _, item := range items {
|
||||
switch x := item.(type) {
|
||||
case map[string]any:
|
||||
out = append(out, sanitizeGeminiPartForPrompt(x))
|
||||
case []any:
|
||||
out = append(out, sanitizeGeminiArrayForPrompt(x))
|
||||
default:
|
||||
out = append(out, x)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func sanitizeGeminiStringForPrompt(key, value string) string {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if looksLikeGeminiBinaryField(key) || looksLikeGeminiBase64(trimmed) {
|
||||
return "[omitted_binary_payload]"
|
||||
}
|
||||
if len(trimmed) > maxGeminiRawPromptChars {
|
||||
return trimmed[:maxGeminiRawPromptChars] + "...(truncated)"
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func looksLikeGeminiBinaryField(name string) bool {
|
||||
n := strings.ToLower(strings.TrimSpace(name))
|
||||
return n == "data" || n == "bytes" || n == "inlinedata" || n == "inline_data" || n == "base64"
|
||||
}
|
||||
|
||||
func looksLikeGeminiBase64(v string) bool {
|
||||
if len(v) < 512 {
|
||||
return false
|
||||
}
|
||||
compact := strings.TrimRight(v, "=")
|
||||
if compact == "" {
|
||||
return false
|
||||
}
|
||||
for _, ch := range compact {
|
||||
if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '+' || ch == '/' || ch == '-' || ch == '_' {
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
84
internal/adapter/gemini/convert_messages_test.go
Normal file
84
internal/adapter/gemini/convert_messages_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGeminiMessagesFromRequestPreservesFunctionRoundtrip(t *testing.T) {
|
||||
req := map[string]any{
|
||||
"contents": []any{
|
||||
map[string]any{
|
||||
"role": "model",
|
||||
"parts": []any{
|
||||
map[string]any{
|
||||
"functionCall": map[string]any{
|
||||
"id": "call_g1",
|
||||
"name": "search_web",
|
||||
"args": map[string]any{"query": "ai"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"parts": []any{
|
||||
map[string]any{
|
||||
"functionResponse": map[string]any{
|
||||
"id": "call_g1",
|
||||
"name": "search_web",
|
||||
"response": "ok",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := geminiMessagesFromRequest(req)
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected two normalized messages, got %#v", got)
|
||||
}
|
||||
assistant, _ := got[0].(map[string]any)
|
||||
if assistant["role"] != "assistant" {
|
||||
t.Fatalf("expected assistant first, got %#v", assistant)
|
||||
}
|
||||
tc, _ := assistant["tool_calls"].([]any)
|
||||
if len(tc) != 1 {
|
||||
t.Fatalf("expected one tool call, got %#v", assistant["tool_calls"])
|
||||
}
|
||||
toolMsg, _ := got[1].(map[string]any)
|
||||
if toolMsg["role"] != "tool" || toolMsg["tool_call_id"] != "call_g1" {
|
||||
t.Fatalf("expected tool message with call id, got %#v", toolMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiMessagesFromRequestPreservesUnknownPartAsRawJSONText(t *testing.T) {
|
||||
req := map[string]any{
|
||||
"contents": []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"parts": []any{
|
||||
map[string]any{"text": "hello"},
|
||||
map[string]any{"inlineData": map[string]any{"mimeType": "image/png", "data": strings.Repeat("A", 2048)}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := geminiMessagesFromRequest(req)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %#v", got)
|
||||
}
|
||||
msg, _ := got[0].(map[string]any)
|
||||
content, _ := msg["content"].(string)
|
||||
if !strings.Contains(content, "hello") || !strings.Contains(content, "inlineData") {
|
||||
t.Fatalf("expected unknown part preserved as raw json text, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "[omitted_binary_payload]") {
|
||||
t.Fatalf("expected inlineData payload to be redacted, got %q", content)
|
||||
}
|
||||
if strings.Contains(content, strings.Repeat("A", 100)) {
|
||||
t.Fatalf("expected raw base64 payload not to be embedded, got %q", content)
|
||||
}
|
||||
}
|
||||
@@ -99,7 +99,7 @@ func TestGeminiRoutesRegistered(t *testing.T) {
|
||||
|
||||
func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
|
||||
upstream := makeGeminiUpstreamResponse(
|
||||
`data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
h := &Handler{
|
||||
@@ -143,6 +143,42 @@ func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateContentMixedToolSnippetAlsoTriggersFunctionCall(t *testing.T) {
|
||||
upstream := makeGeminiUpstreamResponse(
|
||||
`data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
h := &Handler{Store: testGeminiConfig{}, Auth: testGeminiAuth{}, DS: testGeminiDS{resp: upstream}}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
body := `{
|
||||
"contents":[{"role":"user","parts":[{"text":"call tool"}]}],
|
||||
"tools":[{"functionDeclarations":[{"name":"eval_javascript","description":"eval","parameters":{"type":"object","properties":{"code":{"type":"string"}}}}]}]
|
||||
}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent", strings.NewReader(body))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("decode response failed: %v", err)
|
||||
}
|
||||
candidates, _ := out["candidates"].([]any)
|
||||
c0, _ := candidates[0].(map[string]any)
|
||||
content, _ := c0["content"].(map[string]any)
|
||||
parts, _ := content["parts"].([]any)
|
||||
part0, _ := parts[0].(map[string]any)
|
||||
functionCall, _ := part0["functionCall"].(map[string]any)
|
||||
if functionCall["name"] != "eval_javascript" {
|
||||
t.Fatalf("expected functionCall name eval_javascript for mixed snippet, got %#v", functionCall)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamGenerateContentEmitsSSE(t *testing.T) {
|
||||
upstream := makeGeminiUpstreamResponse(
|
||||
`data: {"p":"response/content","v":"hello "}`,
|
||||
|
||||
@@ -97,12 +97,12 @@ func (s *chatStreamRuntime) sendDone() {
|
||||
|
||||
func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := s.text.String()
|
||||
detected := util.ParseStandaloneToolCalls(finalText, s.toolNames)
|
||||
if len(detected) > 0 && !s.toolCallsDoneEmitted {
|
||||
finalText := sanitizeLeakedOutput(s.text.String())
|
||||
detected := util.ParseStandaloneToolCallsDetailed(finalText, s.toolNames)
|
||||
if len(detected.Calls) > 0 && !s.toolCallsDoneEmitted {
|
||||
finishReason = "tool_calls"
|
||||
delta := map[string]any{
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected, s.streamToolCallIDs),
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected.Calls, s.streamToolCallIDs),
|
||||
}
|
||||
if !s.firstChunkSent {
|
||||
delta["role"] = "assistant"
|
||||
@@ -141,8 +141,12 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
if evt.Content == "" {
|
||||
continue
|
||||
}
|
||||
cleaned := sanitizeLeakedOutput(evt.Content)
|
||||
if cleaned == "" {
|
||||
continue
|
||||
}
|
||||
delta := map[string]any{
|
||||
"content": evt.Content,
|
||||
"content": cleaned,
|
||||
}
|
||||
if !s.firstChunkSent {
|
||||
delta["role"] = "assistant"
|
||||
@@ -158,7 +162,7 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
}
|
||||
}
|
||||
|
||||
if len(detected) > 0 || s.toolCallsEmitted {
|
||||
if len(detected.Calls) > 0 || s.toolCallsEmitted {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
s.sendChunk(openaifmt.BuildChatStreamChunk(
|
||||
@@ -246,8 +250,12 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
|
||||
continue
|
||||
}
|
||||
if evt.Content != "" {
|
||||
cleaned := sanitizeLeakedOutput(evt.Content)
|
||||
if cleaned == "" {
|
||||
continue
|
||||
}
|
||||
contentDelta := map[string]any{
|
||||
"content": evt.Content,
|
||||
"content": cleaned,
|
||||
}
|
||||
if !s.firstChunkSent {
|
||||
contentDelta["role"] = "assistant"
|
||||
|
||||
@@ -19,6 +19,7 @@ type DeepSeekCaller interface {
|
||||
CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
||||
GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
||||
CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
|
||||
DeleteAllSessionsForToken(ctx context.Context, token string) error
|
||||
}
|
||||
|
||||
type ConfigReader interface {
|
||||
@@ -28,6 +29,7 @@ type ConfigReader interface {
|
||||
ToolcallEarlyEmitConfidence() string
|
||||
ResponsesStoreTTLSeconds() int
|
||||
EmbeddingsProvider() string
|
||||
AutoDeleteSessions() bool
|
||||
}
|
||||
|
||||
var _ AuthResolver = (*auth.Resolver)(nil)
|
||||
|
||||
@@ -19,6 +19,7 @@ func (m mockOpenAIConfig) ToolcallMode() string { return m.toolMo
|
||||
func (m mockOpenAIConfig) ToolcallEarlyEmitConfidence() string { return m.earlyEmit }
|
||||
func (m mockOpenAIConfig) ResponsesStoreTTLSeconds() int { return m.responsesTTL }
|
||||
func (m mockOpenAIConfig) EmbeddingsProvider() string { return m.embedProv }
|
||||
func (m mockOpenAIConfig) AutoDeleteSessions() bool { return false }
|
||||
|
||||
func TestNormalizeOpenAIChatRequestWithConfigInterface(t *testing.T) {
|
||||
cfg := mockOpenAIConfig{
|
||||
|
||||
@@ -35,7 +35,25 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
writeOpenAIError(w, status, detail)
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
defer func() {
|
||||
// 自动删除会话(同步)
|
||||
// 必须在 Release 之前同步删除,否则:
|
||||
// 1. 异步删除时账号已被 Release
|
||||
// 2. 新请求可能获取到同一账号并开始使用
|
||||
// 3. 异步删除仍在进行,会截断新请求正在使用的会话
|
||||
if h.Store.AutoDeleteSessions() && a.DeepSeekToken != "" {
|
||||
deleteCtx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
err := h.DS.DeleteAllSessionsForToken(deleteCtx, a.DeepSeekToken)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[auto_delete_sessions] failed", "account", a.AccountID, "error", err)
|
||||
} else {
|
||||
config.Logger.Debug("[auto_delete_sessions] success", "account", a.AccountID)
|
||||
}
|
||||
}
|
||||
h.Auth.Release(a)
|
||||
}()
|
||||
|
||||
r = r.WithContext(auth.WithAuth(r.Context(), a))
|
||||
|
||||
var req map[string]any
|
||||
@@ -87,7 +105,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
|
||||
finalThinking := result.Thinking
|
||||
finalText := result.Text
|
||||
finalText := sanitizeLeakedOutput(result.Text)
|
||||
respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
|
||||
writeJSON(w, http.StatusOK, respBody)
|
||||
}
|
||||
@@ -110,8 +128,8 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
}
|
||||
|
||||
created := time.Now().Unix()
|
||||
bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
|
||||
emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
|
||||
bufferToolContent := len(toolNames) > 0
|
||||
emitEarlyToolDeltas := h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence()
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
|
||||
@@ -53,13 +53,13 @@ func injectToolPrompt(messages []map[string]any, tools []any, policy util.ToolCh
|
||||
if len(toolSchemas) == 0 {
|
||||
return messages, names
|
||||
}
|
||||
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block."
|
||||
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\n" + buildToolCallInstructions(names)
|
||||
if policy.Mode == util.ToolChoiceRequired {
|
||||
toolPrompt += "\n5) For this response, you MUST call at least one tool from the allowed list."
|
||||
toolPrompt += "\n7) For this response, you MUST call at least one tool from the allowed list."
|
||||
}
|
||||
if policy.Mode == util.ToolChoiceForced && strings.TrimSpace(policy.ForcedName) != "" {
|
||||
toolPrompt += "\n5) For this response, you MUST call exactly this tool name: " + strings.TrimSpace(policy.ForcedName)
|
||||
toolPrompt += "\n6) Do not call any other tool."
|
||||
toolPrompt += "\n7) For this response, you MUST call exactly this tool name: " + strings.TrimSpace(policy.ForcedName)
|
||||
toolPrompt += "\n8) Do not call any other tool."
|
||||
}
|
||||
|
||||
for i := range messages {
|
||||
@@ -73,6 +73,11 @@ func injectToolPrompt(messages []map[string]any, tools []any, policy util.ToolCh
|
||||
return messages, names
|
||||
}
|
||||
|
||||
// buildToolCallInstructions delegates to the shared util implementation.
|
||||
func buildToolCallInstructions(toolNames []string) string {
|
||||
return util.BuildToolCallInstructions(toolNames)
|
||||
}
|
||||
|
||||
func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]string) []map[string]any {
|
||||
if len(deltas) == 0 {
|
||||
return nil
|
||||
@@ -111,28 +116,21 @@ func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, allowedNam
|
||||
if len(deltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
allowed := namesToSet(allowedNames)
|
||||
if len(allowed) == 0 {
|
||||
for _, d := range deltas {
|
||||
if d.Name != "" {
|
||||
seenNames[d.Index] = "__blocked__"
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
out := make([]toolCallDelta, 0, len(deltas))
|
||||
for _, d := range deltas {
|
||||
if d.Name != "" {
|
||||
if _, ok := allowed[d.Name]; !ok {
|
||||
seenNames[d.Index] = "__blocked__"
|
||||
continue
|
||||
if seenNames != nil {
|
||||
seenNames[d.Index] = d.Name
|
||||
}
|
||||
seenNames[d.Index] = d.Name
|
||||
out = append(out, d)
|
||||
continue
|
||||
}
|
||||
if seenNames == nil {
|
||||
out = append(out, d)
|
||||
continue
|
||||
}
|
||||
name := strings.TrimSpace(seenNames[d.Index])
|
||||
if name == "" || name == "__blocked__" {
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, d)
|
||||
|
||||
@@ -1,25 +1,9 @@
|
||||
package openai
|
||||
|
||||
import "strings"
|
||||
|
||||
func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) {
|
||||
for k, v := range collectOpenAIChatPassThrough(req) {
|
||||
payload[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) toolcallFeatureMatchEnabled() bool {
|
||||
if h == nil || h.Store == nil {
|
||||
return true
|
||||
}
|
||||
mode := strings.TrimSpace(strings.ToLower(h.Store.ToolcallMode()))
|
||||
return mode == "" || mode == "feature_match"
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *Handler) toolcallEarlyEmitHighConfidence() bool {
|
||||
if h == nil || h.Store == nil {
|
||||
return true
|
||||
}
|
||||
level := strings.TrimSpace(strings.ToLower(h.Store.ToolcallEarlyEmitConfidence()))
|
||||
return level == "" || level == "high"
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -182,7 +182,7 @@ func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) {
|
||||
func TestHandleNonStreamUnknownToolIntercepted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
@@ -198,20 +198,17 @@ func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) {
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
|
||||
if choice["finish_reason"] != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
if _, ok := msg["tool_calls"]; ok {
|
||||
t.Fatalf("did not expect tool_calls for unknown schema name, got %#v", msg["tool_calls"])
|
||||
}
|
||||
content, _ := msg["content"].(string)
|
||||
if !strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected unknown tool json to pass through as text, got %#v", content)
|
||||
toolCalls, _ := msg["tool_calls"].([]any)
|
||||
if len(toolCalls) != 1 {
|
||||
t.Fatalf("expected tool_calls for unknown schema name, got %#v", msg["tool_calls"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamEmbeddedToolCallExampleRemainsText(t *testing.T) {
|
||||
func TestHandleNonStreamEmbeddedToolCallExamplePromotesToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"下面是示例:"}`,
|
||||
@@ -229,20 +226,21 @@ func TestHandleNonStreamEmbeddedToolCallExampleRemainsText(t *testing.T) {
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
|
||||
if choice["finish_reason"] != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
if _, ok := msg["tool_calls"]; ok {
|
||||
t.Fatalf("did not expect tool_calls field for embedded example: %#v", msg["tool_calls"])
|
||||
toolCalls, _ := msg["tool_calls"].([]any)
|
||||
if len(toolCalls) != 1 {
|
||||
t.Fatalf("expected one tool_call field for embedded example: %#v", msg["tool_calls"])
|
||||
}
|
||||
content, _ := msg["content"].(string)
|
||||
if !strings.Contains(content, "下面是示例:") || !strings.Contains(content, "请勿执行。") || !strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected embedded example to remain plain text, got %#v", content)
|
||||
if strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected raw tool_calls json stripped from content, got %#v", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamFencedToolCallExampleNotIntercepted(t *testing.T) {
|
||||
func TestHandleNonStreamFencedToolCallExampleDoesNotPromoteToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
"data: {\"p\":\"response/content\",\"v\":\"```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"search\\\",\\\"input\\\":{\\\"q\\\":\\\"go\\\"}}]}\\n```\"}",
|
||||
@@ -258,19 +256,25 @@ func TestHandleNonStreamFencedToolCallExampleNotIntercepted(t *testing.T) {
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
|
||||
if choice["finish_reason"] == "tool_calls" {
|
||||
t.Fatalf("expected fenced example to remain content-only, got finish_reason=%#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
if _, ok := msg["tool_calls"]; ok {
|
||||
t.Fatalf("did not expect tool_calls field for fenced example: %#v", msg["tool_calls"])
|
||||
toolCalls, _ := msg["tool_calls"].([]any)
|
||||
if len(toolCalls) != 0 {
|
||||
t.Fatalf("expected no tool_call field for fenced example: %#v", msg["tool_calls"])
|
||||
}
|
||||
content, _ := msg["content"].(string)
|
||||
if !strings.Contains(content, "```json") || !strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected fenced tool example to pass through as text, got %q", content)
|
||||
if !strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected fenced example content preserved, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
// Backward-compatible alias for historical test name used in CI logs.
|
||||
func TestHandleNonStreamFencedToolCallExamplePromotesToolCall(t *testing.T) {
|
||||
TestHandleNonStreamFencedToolCallExampleDoesNotPromoteToolCall(t)
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
@@ -406,7 +410,7 @@ func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamUnknownToolDoesNotLeakRawPayload(t *testing.T) {
|
||||
func TestHandleStreamUnknownToolEmitsToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
@@ -421,18 +425,18 @@ func TestHandleStreamUnknownToolDoesNotLeakRawPayload(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta for unknown schema name, body=%s", rec.Body.String())
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta for unknown schema name, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("did not expect raw tool_calls json leak for unknown schema name: %s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamUnknownToolNoArgsDoesNotLeakRawPayload(t *testing.T) {
|
||||
func TestHandleStreamUnknownToolNoArgsEmitsToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\"}]}"}`,
|
||||
@@ -447,14 +451,14 @@ func TestHandleStreamUnknownToolNoArgsDoesNotLeakRawPayload(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta for unknown schema name (no args), body=%s", rec.Body.String())
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta for unknown schema name (no args), body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("did not expect raw tool_calls json leak for unknown schema name (no args): %s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -513,8 +517,8 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -531,11 +535,8 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") {
|
||||
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(got), `"tool_calls"`) {
|
||||
t.Fatalf("expected embedded tool json to remain text in strict mode, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop for mixed prose, body=%s", rec.Body.String())
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -555,8 +556,8 @@ func TestHandleStreamToolCallAfterLeadingTextRemainsText(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String())
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -573,11 +574,9 @@ func TestHandleStreamToolCallAfterLeadingTextRemainsText(t *testing.T) {
|
||||
if !strings.Contains(got, "我将调用工具。") {
|
||||
t.Fatalf("expected leading text to keep streaming, got=%q", got)
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected tool_calls example text preserved, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -596,8 +595,8 @@ func TestHandleStreamToolCallWithSameChunkTrailingTextRemainsText(t *testing.T)
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String())
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -614,11 +613,90 @@ func TestHandleStreamToolCallWithSameChunkTrailingTextRemainsText(t *testing.T)
|
||||
if !strings.Contains(got, "接下来我会继续说明。") {
|
||||
t.Fatalf("expected trailing plain text to be preserved, got=%q", got)
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected tool_calls example text preserved, got=%q", got)
|
||||
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestHandleStreamFencedToolCallSnippetPromotesToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "下面是调用示例:\n```json\n"),
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\n仅示例,不要执行。"),
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7f", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta for fenced snippet, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected raw fenced tool_calls snippet stripped from content, got=%q", got)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(got), "```json") || strings.Contains(got, "\n```\n") {
|
||||
t.Fatalf("expected consumed fenced tool payload to not leave empty code fence, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamStandaloneToolCallAfterClosedFenceKeepsFence(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "先给一个代码示例:\n```text\nhello\n```\n"),
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"),
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7g", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta for standalone payload, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "```") {
|
||||
t.Fatalf("expected closed fence before standalone tool json to be preserved, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -640,8 +718,8 @@ func TestHandleStreamToolCallKeyAppearsLateRemainsText(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String())
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -655,14 +733,11 @@ func TestHandleStreamToolCallKeyAppearsLateRemainsText(t *testing.T) {
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(strings.ToLower(got), "tool_calls") || !strings.Contains(got, "{") {
|
||||
t.Fatalf("expected embedded tool json to remain in text, got=%q", got)
|
||||
}
|
||||
if !strings.Contains(got, "后置正文C。") {
|
||||
t.Fatalf("expected stream to continue after tool json convergence, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
54
internal/adapter/openai/leaked_output_sanitize.go
Normal file
54
internal/adapter/openai/leaked_output_sanitize.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var emptyJSONFencePattern = regexp.MustCompile("(?is)```json\\s*```")
|
||||
var leakedToolCallArrayPattern = regexp.MustCompile(`(?is)\[\{\s*"function"\s*:\s*\{[\s\S]*?\}\s*,\s*"id"\s*:\s*"call[^"]*"\s*,\s*"type"\s*:\s*"function"\s*}\]`)
|
||||
var leakedToolResultBlobPattern = regexp.MustCompile(`(?is)<\s*\|\s*tool\s*\|\s*>\s*\{[\s\S]*?"tool_call_id"\s*:\s*"call[^"]*"\s*}`)
|
||||
|
||||
// leakedMetaMarkerPattern matches DeepSeek special tokens in BOTH forms:
|
||||
// - ASCII underscore: <|end_of_sentence|>
|
||||
// - U+2581 variant: <|end▁of▁sentence|> (used in some DeepSeek outputs)
|
||||
var leakedMetaMarkerPattern = regexp.MustCompile(`(?i)<[|\|]\s*(?:assistant|tool|end[_▁]of[_▁]sentence|end[_▁]of[_▁]thinking)\s*[|\|]>`)
|
||||
|
||||
// leakedAgentXMLBlockPatterns catch agent-style XML blocks that leak through
|
||||
// when the sieve fails to capture them. These are applied only to complete
|
||||
// wrapper blocks so standalone "<result>" examples in normal output remain
|
||||
// untouched.
|
||||
var leakedAgentXMLBlockPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?is)<attempt_completion\b[^>]*>(.*?)</attempt_completion>`),
|
||||
regexp.MustCompile(`(?is)<ask_followup_question\b[^>]*>(.*?)</ask_followup_question>`),
|
||||
regexp.MustCompile(`(?is)<new_task\b[^>]*>(.*?)</new_task>`),
|
||||
}
|
||||
|
||||
var leakedAgentResultTagPattern = regexp.MustCompile(`(?is)</?result>`)
|
||||
|
||||
func sanitizeLeakedOutput(text string) string {
|
||||
if text == "" {
|
||||
return text
|
||||
}
|
||||
out := emptyJSONFencePattern.ReplaceAllString(text, "")
|
||||
out = leakedToolCallArrayPattern.ReplaceAllString(out, "")
|
||||
out = leakedToolResultBlobPattern.ReplaceAllString(out, "")
|
||||
out = leakedMetaMarkerPattern.ReplaceAllString(out, "")
|
||||
out = sanitizeLeakedAgentXMLBlocks(out)
|
||||
return out
|
||||
}
|
||||
|
||||
func sanitizeLeakedAgentXMLBlocks(text string) string {
|
||||
out := text
|
||||
for _, pattern := range leakedAgentXMLBlockPatterns {
|
||||
out = pattern.ReplaceAllStringFunc(out, func(match string) string {
|
||||
submatches := pattern.FindStringSubmatch(match)
|
||||
if len(submatches) < 2 {
|
||||
return match
|
||||
}
|
||||
// Preserve the inner text so leaked agent instructions do not erase
|
||||
// the actual answer, but strip the wrapper/result markup itself.
|
||||
return leakedAgentResultTagPattern.ReplaceAllString(submatches[1], "")
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
43
internal/adapter/openai/leaked_output_sanitize_test.go
Normal file
43
internal/adapter/openai/leaked_output_sanitize_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package openai
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSanitizeLeakedOutputRemovesEmptyJSONFence(t *testing.T) {
|
||||
raw := "before\n```json\n```\nafter"
|
||||
got := sanitizeLeakedOutput(raw)
|
||||
if got != "before\n\nafter" {
|
||||
t.Fatalf("unexpected sanitized empty json fence: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeLeakedOutputRemovesLeakedWireToolCallAndResult(t *testing.T) {
|
||||
raw := "开始\n[{\"function\":{\"arguments\":\"{\\\"command\\\":\\\"java -version\\\"}\",\"name\":\"exec\"},\"id\":\"callb9a321\",\"type\":\"function\"}]< | Tool | >{\"content\":\"openjdk version 21\",\"tool_call_id\":\"callb9a321\"}\n结束"
|
||||
got := sanitizeLeakedOutput(raw)
|
||||
if got != "开始\n\n结束" {
|
||||
t.Fatalf("unexpected sanitize result for leaked wire format: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeLeakedOutputRemovesStandaloneMetaMarkers(t *testing.T) {
|
||||
raw := "A<| end_of_sentence |><| Assistant |>B<| end_of_thinking |>C<|end▁of▁thinking|>D<|end▁of▁sentence|>E"
|
||||
got := sanitizeLeakedOutput(raw)
|
||||
if got != "ABCDE" {
|
||||
t.Fatalf("unexpected sanitize result for meta markers: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeLeakedOutputRemovesAgentXMLLeaks(t *testing.T) {
|
||||
raw := "Done.<attempt_completion><result>Some final answer</result></attempt_completion>"
|
||||
got := sanitizeLeakedOutput(raw)
|
||||
if got != "Done.Some final answer" {
|
||||
t.Fatalf("unexpected sanitize result for agent XML leak: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeLeakedOutputPreservesStandaloneResultTags(t *testing.T) {
|
||||
raw := "Example XML: <result>value</result>"
|
||||
got := sanitizeLeakedOutput(raw)
|
||||
if got != raw {
|
||||
t.Fatalf("unexpected sanitize result for standalone result tag: %q", got)
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1,13 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/prompt"
|
||||
)
|
||||
|
||||
func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any {
|
||||
_ = traceID
|
||||
out := make([]map[string]any, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
msg, ok := item.(map[string]any)
|
||||
@@ -18,24 +17,23 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an
|
||||
role := strings.ToLower(strings.TrimSpace(asString(msg["role"])))
|
||||
switch role {
|
||||
case "assistant":
|
||||
content := normalizeOpenAIContentForPrompt(msg["content"])
|
||||
toolCalls := formatAssistantToolCallsForPrompt(msg, traceID)
|
||||
combined := joinNonEmpty(content, toolCalls)
|
||||
if combined == "" {
|
||||
content := buildAssistantContentForPrompt(msg)
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"role": "assistant",
|
||||
"content": combined,
|
||||
"content": content,
|
||||
})
|
||||
case "tool", "function":
|
||||
content := buildToolContentForPrompt(msg)
|
||||
out = append(out, map[string]any{
|
||||
"role": "user",
|
||||
"content": formatToolResultForPrompt(msg),
|
||||
"role": "tool",
|
||||
"content": content,
|
||||
})
|
||||
case "user", "system":
|
||||
case "user", "system", "developer":
|
||||
out = append(out, map[string]any{
|
||||
"role": role,
|
||||
"role": normalizeOpenAIRoleForPrompt(role),
|
||||
"content": normalizeOpenAIContentForPrompt(msg["content"]),
|
||||
})
|
||||
default:
|
||||
@@ -47,7 +45,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an
|
||||
role = "user"
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"role": role,
|
||||
"role": normalizeOpenAIRoleForPrompt(role),
|
||||
"content": content,
|
||||
})
|
||||
}
|
||||
@@ -55,138 +53,39 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an
|
||||
return out
|
||||
}
|
||||
|
||||
func formatAssistantToolCallsForPrompt(msg map[string]any, traceID string) string {
|
||||
entries := make([]string, 0)
|
||||
if calls, ok := msg["tool_calls"].([]any); ok {
|
||||
for i, item := range calls {
|
||||
call, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
id := strings.TrimSpace(asString(call["id"]))
|
||||
if id == "" {
|
||||
id = fmt.Sprintf("call_%d", i+1)
|
||||
}
|
||||
name := strings.TrimSpace(asString(call["name"]))
|
||||
args := ""
|
||||
|
||||
if fn, ok := call["function"].(map[string]any); ok {
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(asString(fn["name"]))
|
||||
}
|
||||
args = normalizeOpenAIArgumentsForPrompt(fn["arguments"])
|
||||
}
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
}
|
||||
if args == "" {
|
||||
args = normalizeOpenAIArgumentsForPrompt(call["arguments"])
|
||||
}
|
||||
if args == "" {
|
||||
args = normalizeOpenAIArgumentsForPrompt(call["input"])
|
||||
}
|
||||
if args == "" {
|
||||
args = "{}"
|
||||
}
|
||||
maybeWarnSuspiciousToolHistory(traceID, id, name, args)
|
||||
entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: %s\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", id, name, args))
|
||||
}
|
||||
func buildAssistantContentForPrompt(msg map[string]any) string {
|
||||
content := strings.TrimSpace(normalizeOpenAIContentForPrompt(msg["content"]))
|
||||
toolHistory := prompt.FormatToolCallsForPrompt(msg["tool_calls"])
|
||||
switch {
|
||||
case content == "" && toolHistory == "":
|
||||
return ""
|
||||
case content == "":
|
||||
return toolHistory
|
||||
case toolHistory == "":
|
||||
return content
|
||||
default:
|
||||
return content + "\n\n" + toolHistory
|
||||
}
|
||||
|
||||
if legacy, ok := msg["function_call"].(map[string]any); ok {
|
||||
name := strings.TrimSpace(asString(legacy["name"]))
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
}
|
||||
args := normalizeOpenAIArgumentsForPrompt(legacy["arguments"])
|
||||
if args == "" {
|
||||
args = "{}"
|
||||
}
|
||||
maybeWarnSuspiciousToolHistory(traceID, "call_legacy", name, args)
|
||||
entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: call_legacy\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", name, args))
|
||||
}
|
||||
|
||||
return strings.Join(entries, "\n\n")
|
||||
}
|
||||
|
||||
func formatToolResultForPrompt(msg map[string]any) string {
|
||||
toolCallID := strings.TrimSpace(asString(msg["tool_call_id"]))
|
||||
if toolCallID == "" {
|
||||
toolCallID = strings.TrimSpace(asString(msg["id"]))
|
||||
}
|
||||
if toolCallID == "" {
|
||||
toolCallID = "unknown"
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(asString(msg["name"]))
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
}
|
||||
|
||||
func buildToolContentForPrompt(msg map[string]any) string {
|
||||
content := normalizeOpenAIContentForPrompt(msg["content"])
|
||||
if content == "" {
|
||||
content = "null"
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return "null"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content)
|
||||
return content
|
||||
}
|
||||
|
||||
func normalizeOpenAIContentForPrompt(v any) string {
|
||||
switch x := v.(type) {
|
||||
case string:
|
||||
return x
|
||||
case []any:
|
||||
parts := make([]string, 0, len(x))
|
||||
for _, item := range x {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
t := strings.ToLower(strings.TrimSpace(asString(m["type"])))
|
||||
if t != "text" && t != "output_text" && t != "input_text" {
|
||||
continue
|
||||
}
|
||||
if text := asString(m["text"]); text != "" {
|
||||
parts = append(parts, text)
|
||||
continue
|
||||
}
|
||||
if text := asString(m["content"]); text != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
default:
|
||||
return marshalToPromptString(v)
|
||||
}
|
||||
return prompt.NormalizeContent(v)
|
||||
}
|
||||
|
||||
func normalizeOpenAIArgumentsForPrompt(v any) string {
|
||||
switch x := v.(type) {
|
||||
case string:
|
||||
return normalizeToolArgumentString(x)
|
||||
default:
|
||||
return marshalToPromptString(v)
|
||||
func normalizeOpenAIRoleForPrompt(role string) string {
|
||||
role = strings.ToLower(strings.TrimSpace(role))
|
||||
if role == "developer" {
|
||||
return "system"
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeToolArgumentString(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if looksLikeConcatenatedJSON(trimmed) {
|
||||
// Keep original payload to avoid silent argument rewrites.
|
||||
return raw
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func marshalToPromptString(v any) string {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
}
|
||||
return string(b)
|
||||
return role
|
||||
}
|
||||
|
||||
func asString(v any) string {
|
||||
@@ -195,56 +94,3 @@ func asString(v any) string {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func joinNonEmpty(parts ...string) string {
|
||||
nonEmpty := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
if strings.TrimSpace(p) == "" {
|
||||
continue
|
||||
}
|
||||
nonEmpty = append(nonEmpty, p)
|
||||
}
|
||||
return strings.Join(nonEmpty, "\n\n")
|
||||
}
|
||||
|
||||
func maybeWarnSuspiciousToolHistory(traceID, callID, name, args string) {
|
||||
if !looksLikeConcatenatedJSON(args) {
|
||||
return
|
||||
}
|
||||
traceID = strings.TrimSpace(traceID)
|
||||
if traceID == "" {
|
||||
traceID = "unknown"
|
||||
}
|
||||
config.Logger.Warn(
|
||||
"[openai] suspicious tool call history payload detected",
|
||||
"trace_id", traceID,
|
||||
"tool_call_id", strings.TrimSpace(callID),
|
||||
"name", strings.TrimSpace(name),
|
||||
"arguments_preview", previewToolArgs(args, 160),
|
||||
)
|
||||
}
|
||||
|
||||
func looksLikeConcatenatedJSON(raw string) bool {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(trimmed, "}{") || strings.Contains(trimmed, "][") {
|
||||
return true
|
||||
}
|
||||
dec := json.NewDecoder(strings.NewReader(trimmed))
|
||||
var first any
|
||||
if err := dec.Decode(&first); err != nil {
|
||||
return false
|
||||
}
|
||||
var second any
|
||||
return dec.Decode(&second) == nil
|
||||
}
|
||||
|
||||
func previewToolArgs(raw string, max int) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if max <= 0 || len(trimmed) <= max {
|
||||
return trimmed
|
||||
}
|
||||
return trimmed[:max]
|
||||
}
|
||||
|
||||
@@ -35,23 +35,22 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *tes
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 4 {
|
||||
t.Fatalf("expected 4 normalized messages, got %d", len(normalized))
|
||||
t.Fatalf("expected 4 normalized messages with assistant tool history preserved, got %d", len(normalized))
|
||||
}
|
||||
assistantContent, _ := normalized[2]["content"].(string)
|
||||
if !strings.Contains(assistantContent, "[TOOL_CALL_HISTORY]") ||
|
||||
!strings.Contains(assistantContent, "tool_call_id: call_1") ||
|
||||
!strings.Contains(assistantContent, "function.name: get_weather") ||
|
||||
!strings.Contains(assistantContent, "function.arguments: {\"city\":\"beijing\"}") {
|
||||
t.Fatalf("assistant tool call not serialized correctly: %q", assistantContent)
|
||||
if !strings.Contains(assistantContent, "<tool_calls>") {
|
||||
t.Fatalf("assistant tool history should be preserved in XML form, got %q", assistantContent)
|
||||
}
|
||||
toolContent, _ := normalized[3]["content"].(string)
|
||||
if !strings.Contains(toolContent, "[TOOL_RESULT_HISTORY]") || !strings.Contains(toolContent, "name: get_weather") {
|
||||
t.Fatalf("tool result not serialized correctly: %q", toolContent)
|
||||
if !strings.Contains(assistantContent, "<tool_name>get_weather</tool_name>") {
|
||||
t.Fatalf("expected tool name in preserved history, got %q", assistantContent)
|
||||
}
|
||||
if !strings.Contains(normalized[3]["content"].(string), `"temp":18`) {
|
||||
t.Fatalf("tool result should be transparently forwarded, got %#v", normalized[3]["content"])
|
||||
}
|
||||
|
||||
prompt := util.MessagesPrepare(normalized)
|
||||
if !strings.Contains(prompt, "tool_call_id: call_1") || !strings.Contains(prompt, "[TOOL_RESULT_HISTORY]") {
|
||||
t.Fatalf("expected prompt to include tool call + result semantics: %q", prompt)
|
||||
if !strings.Contains(prompt, "<tool_calls>") {
|
||||
t.Fatalf("expected preserved assistant tool history in prompt: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,8 +90,8 @@ func TestNormalizeOpenAIMessagesForPrompt_ToolArrayBlocksJoined(t *testing.T) {
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
got, _ := normalized[0]["content"].(string)
|
||||
if !strings.Contains(got, "line-1\nline-2") {
|
||||
t.Fatalf("expected joined text blocks, got %q", got)
|
||||
if !strings.Contains(got, `line-1`) || !strings.Contains(got, `line-2`) {
|
||||
t.Fatalf("expected tool content blocks preserved, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,15 +111,42 @@ func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) {
|
||||
if len(normalized) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %d", len(normalized))
|
||||
}
|
||||
if normalized[0]["role"] != "user" {
|
||||
t.Fatalf("expected function role mapped to user, got %#v", normalized[0]["role"])
|
||||
if normalized[0]["role"] != "tool" {
|
||||
t.Fatalf("expected function role normalized as tool, got %#v", normalized[0]["role"])
|
||||
}
|
||||
got, _ := normalized[0]["content"].(string)
|
||||
if !strings.Contains(got, "name: legacy_tool") || !strings.Contains(got, `"ok":true`) {
|
||||
if !strings.Contains(got, `"ok":true`) || strings.Contains(got, `"name":"legacy_tool"`) {
|
||||
t.Fatalf("unexpected normalized function-role content: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIMessagesForPrompt_EmptyToolContentPreservedAsNull(t *testing.T) {
|
||||
raw := []any{
|
||||
map[string]any{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_5",
|
||||
"name": "noop_tool",
|
||||
"content": "",
|
||||
},
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": "done",
|
||||
},
|
||||
}
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 2 {
|
||||
t.Fatalf("expected tool completion turn to be preserved, got %#v", normalized)
|
||||
}
|
||||
if normalized[0]["role"] != "tool" {
|
||||
t.Fatalf("expected tool role preserved, got %#v", normalized[0]["role"])
|
||||
}
|
||||
got, _ := normalized[0]["content"].(string)
|
||||
if got != "null" {
|
||||
t.Fatalf("expected empty tool content normalized as null string, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSeparated(t *testing.T) {
|
||||
raw := []any{
|
||||
map[string]any{
|
||||
@@ -148,23 +174,14 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSepara
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 1 {
|
||||
t.Fatalf("expected one normalized assistant message, got %d", len(normalized))
|
||||
t.Fatalf("expected assistant tool_call-only message preserved, got %#v", normalized)
|
||||
}
|
||||
content, _ := normalized[0]["content"].(string)
|
||||
if strings.Count(content, "[TOOL_CALL_HISTORY]") != 2 {
|
||||
t.Fatalf("expected two TOOL_CALL_HISTORY blocks, got %q", content)
|
||||
if strings.Count(content, "<tool_call>") != 2 {
|
||||
t.Fatalf("expected two preserved tool call blocks, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "tool_call_id: call_search") || !strings.Contains(content, "function.name: search_web") {
|
||||
t.Fatalf("missing first tool call block, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "tool_call_id: call_eval") || !strings.Contains(content, "function.name: eval_javascript") {
|
||||
t.Fatalf("missing second tool call block, got %q", content)
|
||||
}
|
||||
if strings.Contains(content, "search_webeval_javascript") {
|
||||
t.Fatalf("unexpected merged function name detected: %q", content)
|
||||
}
|
||||
if strings.Contains(content, `}{"`) {
|
||||
t.Fatalf("unexpected concatenated function arguments detected: %q", content)
|
||||
if !strings.Contains(content, "<tool_name>search_web</tool_name>") || !strings.Contains(content, "<tool_name>eval_javascript</tool_name>") {
|
||||
t.Fatalf("expected both tool names in preserved history, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -184,12 +201,98 @@ func TestNormalizeOpenAIMessagesForPrompt_PreservesConcatenatedToolArguments(t *
|
||||
},
|
||||
}
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 1 {
|
||||
t.Fatalf("expected assistant tool_call-only content preserved, got %#v", normalized)
|
||||
}
|
||||
content, _ := normalized[0]["content"].(string)
|
||||
if !strings.Contains(content, `{}{"query":"测试工具调用"}`) {
|
||||
t.Fatalf("expected concatenated tool arguments preserved, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsMissingNameAreDropped(t *testing.T) {
|
||||
raw := []any{
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"tool_calls": []any{
|
||||
map[string]any{
|
||||
"id": "call_missing_name",
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"arguments": `{"path":"README.MD"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 0 {
|
||||
t.Fatalf("expected assistant tool_calls without text to be dropped when name is missing, got %#v", normalized)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIMessagesForPrompt_AssistantNilContentDoesNotInjectNullLiteral(t *testing.T) {
|
||||
raw := []any{
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": nil,
|
||||
"tool_calls": []any{
|
||||
map[string]any{
|
||||
"id": "call_screenshot",
|
||||
"function": map[string]any{
|
||||
"name": "send_file_to_user",
|
||||
"arguments": `{"file_path":"/tmp/a.png"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 1 {
|
||||
t.Fatalf("expected nil-content assistant tool_call-only message preserved, got %#v", normalized)
|
||||
}
|
||||
content, _ := normalized[0]["content"].(string)
|
||||
if strings.Contains(content, "null") {
|
||||
t.Fatalf("expected no null literal injection, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "<tool_calls>") {
|
||||
t.Fatalf("expected assistant tool history in normalized content, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIMessagesForPrompt_DeveloperRoleMapsToSystem(t *testing.T) {
|
||||
raw := []any{
|
||||
map[string]any{"role": "developer", "content": "必须先走工具调用"},
|
||||
map[string]any{"role": "user", "content": "你好"},
|
||||
}
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 2 {
|
||||
t.Fatalf("expected 2 normalized messages, got %d", len(normalized))
|
||||
}
|
||||
if normalized[0]["role"] != "system" {
|
||||
t.Fatalf("expected developer role converted to system, got %#v", normalized[0]["role"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIMessagesForPrompt_AssistantArrayContentFallbackWhenTextEmpty(t *testing.T) {
|
||||
raw := []any{
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "", "content": "工具说明文本"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %d", len(normalized))
|
||||
}
|
||||
content, _ := normalized[0]["content"].(string)
|
||||
if !strings.Contains(content, `function.arguments: {}{"query":"测试工具调用"}`) {
|
||||
t.Fatalf("expected original concatenated arguments in tool history, got %q", content)
|
||||
if content != "工具说明文本" {
|
||||
t.Fatalf("expected content fallback text preserved, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,11 +44,14 @@ func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *tes
|
||||
if len(toolNames) != 1 || toolNames[0] != "get_weather" {
|
||||
t.Fatalf("unexpected tool names: %#v", toolNames)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "tool_call_id: call_1") ||
|
||||
!strings.Contains(finalPrompt, "function.name: get_weather") ||
|
||||
!strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") ||
|
||||
!strings.Contains(finalPrompt, `"condition":"sunny"`) {
|
||||
t.Fatalf("handler finalPrompt missing tool roundtrip semantics: %q", finalPrompt)
|
||||
if !strings.Contains(finalPrompt, `"condition":"sunny"`) {
|
||||
t.Fatalf("handler finalPrompt should preserve tool output content: %q", finalPrompt)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "<tool_calls>") {
|
||||
t.Fatalf("handler finalPrompt should preserve assistant tool history: %q", finalPrompt)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "<tool_name>get_weather</tool_name>") {
|
||||
t.Fatalf("handler finalPrompt should include tool name history: %q", finalPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,13 +74,19 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *
|
||||
}
|
||||
|
||||
finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "")
|
||||
if !strings.Contains(finalPrompt, "After receiving a tool result, you MUST use it to produce the final answer.") {
|
||||
if !strings.Contains(finalPrompt, "After receiving a tool result, use it directly.") {
|
||||
t.Fatalf("vercel prepare finalPrompt missing final-answer instruction: %q", finalPrompt)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "Only call another tool when the previous result is missing required data or returned an error.") {
|
||||
if !strings.Contains(finalPrompt, "Only call another tool if the result is insufficient.") {
|
||||
t.Fatalf("vercel prepare finalPrompt missing retry guard instruction: %q", finalPrompt)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") {
|
||||
t.Fatalf("vercel prepare finalPrompt missing history marker instruction: %q", finalPrompt)
|
||||
if !strings.Contains(finalPrompt, "TOOL CALL FORMAT") {
|
||||
t.Fatalf("vercel prepare finalPrompt missing xml format instruction: %q", finalPrompt)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "Do NOT wrap the XML in markdown code fences") {
|
||||
t.Fatalf("vercel prepare finalPrompt missing no-fence xml instruction: %q", finalPrompt)
|
||||
}
|
||||
if strings.Contains(finalPrompt, "```json") {
|
||||
t.Fatalf("vercel prepare finalPrompt should not require fenced tool calls: %q", finalPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,7 +113,8 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
|
||||
return
|
||||
}
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
textParsed := util.ParseStandaloneToolCallsDetailed(result.Text, toolNames)
|
||||
sanitizedText := sanitizeLeakedOutput(result.Text)
|
||||
textParsed := util.ParseStandaloneToolCallsDetailed(sanitizedText, toolNames)
|
||||
logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text")
|
||||
|
||||
callCount := len(textParsed.Calls)
|
||||
@@ -122,7 +123,7 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
|
||||
return
|
||||
}
|
||||
|
||||
responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames)
|
||||
responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, result.Thinking, sanitizedText, toolNames)
|
||||
h.getResponseStore().put(owner, responseID, responseObj)
|
||||
writeJSON(w, http.StatusOK, responseObj)
|
||||
}
|
||||
@@ -145,8 +146,8 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
}
|
||||
bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
|
||||
emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
|
||||
bufferToolContent := len(toolNames) > 0
|
||||
emitEarlyToolDeltas := h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence()
|
||||
|
||||
streamRuntime := newResponsesStreamRuntime(
|
||||
w,
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/prompt"
|
||||
)
|
||||
|
||||
func normalizeResponsesInputItem(m map[string]any) map[string]any {
|
||||
@@ -19,6 +19,27 @@ func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[str
|
||||
|
||||
role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
|
||||
if role != "" {
|
||||
if role == "assistant" {
|
||||
out := map[string]any{
|
||||
"role": "assistant",
|
||||
}
|
||||
if toolCalls, ok := m["tool_calls"].([]any); ok && len(toolCalls) > 0 {
|
||||
out["tool_calls"] = toolCalls
|
||||
}
|
||||
content := m["content"]
|
||||
if content == nil {
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
content = txt
|
||||
}
|
||||
}
|
||||
if content != nil {
|
||||
out["content"] = content
|
||||
}
|
||||
if _, hasToolCalls := out["tool_calls"]; hasToolCalls || out["content"] != nil {
|
||||
return out
|
||||
}
|
||||
return nil
|
||||
}
|
||||
content := m["content"]
|
||||
if content == nil {
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
@@ -28,10 +49,22 @@ func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[str
|
||||
if content == nil {
|
||||
return nil
|
||||
}
|
||||
return map[string]any{
|
||||
"role": role,
|
||||
out := map[string]any{
|
||||
"role": normalizeOpenAIRoleForPrompt(role),
|
||||
"content": content,
|
||||
}
|
||||
if role == "tool" || role == "function" {
|
||||
if callID := strings.TrimSpace(asString(m["tool_call_id"])); callID != "" {
|
||||
out["tool_call_id"] = callID
|
||||
}
|
||||
if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
|
||||
out["tool_call_id"] = callID
|
||||
}
|
||||
if name := strings.TrimSpace(asString(m["name"])); name != "" {
|
||||
out["name"] = name
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
itemType := strings.ToLower(strings.TrimSpace(asString(m["type"])))
|
||||
@@ -51,7 +84,7 @@ func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[str
|
||||
role = "user"
|
||||
}
|
||||
return map[string]any{
|
||||
"role": role,
|
||||
"role": normalizeOpenAIRoleForPrompt(role),
|
||||
"content": content,
|
||||
}
|
||||
case "function_call_output", "tool_result":
|
||||
@@ -115,7 +148,7 @@ func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[str
|
||||
|
||||
functionPayload := map[string]any{
|
||||
"name": name,
|
||||
"arguments": stringifyToolCallArguments(argsRaw),
|
||||
"arguments": prompt.StringifyToolCallArguments(argsRaw),
|
||||
}
|
||||
call := map[string]any{
|
||||
"type": "function",
|
||||
@@ -178,26 +211,3 @@ func normalizeResponsesFallbackPart(m map[string]any) string {
|
||||
}
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", m))
|
||||
}
|
||||
|
||||
func stringifyToolCallArguments(v any) string {
|
||||
switch x := v.(type) {
|
||||
case nil:
|
||||
return "{}"
|
||||
case string:
|
||||
s := strings.TrimSpace(x)
|
||||
if s == "" {
|
||||
return "{}"
|
||||
}
|
||||
s = normalizeToolArgumentString(s)
|
||||
if s == "" {
|
||||
return "{}"
|
||||
}
|
||||
return s
|
||||
default:
|
||||
b, err := json.Marshal(x)
|
||||
if err != nil || len(b) == 0 {
|
||||
return "{}"
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,6 @@ type responsesStreamRuntime struct {
|
||||
toolCallsDoneEmitted bool
|
||||
|
||||
sieve toolStreamSieveState
|
||||
thinkingSieve toolStreamSieveState
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
visibleText strings.Builder
|
||||
@@ -98,7 +97,7 @@ func newResponsesStreamRuntime(
|
||||
|
||||
func (s *responsesStreamRuntime) finalize() {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := s.text.String()
|
||||
finalText := sanitizeLeakedOutput(s.text.String())
|
||||
|
||||
if s.bufferToolContent {
|
||||
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
|
||||
@@ -169,15 +168,6 @@ func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed util.ToolCal
|
||||
logRejected(textParsed, "text")
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) hasFunctionCallDone() bool {
|
||||
for _, done := range s.functionDone {
|
||||
if done {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
@@ -204,12 +194,16 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
|
||||
continue
|
||||
}
|
||||
|
||||
s.text.WriteString(p.Text)
|
||||
if !s.bufferToolContent {
|
||||
s.emitTextDelta(p.Text)
|
||||
cleanedText := sanitizeLeakedOutput(p.Text)
|
||||
if cleanedText == "" {
|
||||
continue
|
||||
}
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true)
|
||||
s.text.WriteString(cleanedText)
|
||||
if !s.bufferToolContent {
|
||||
s.emitTextDelta(cleanedText)
|
||||
continue
|
||||
}
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, cleanedText, s.toolNames), true)
|
||||
}
|
||||
|
||||
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||
|
||||
@@ -94,6 +94,16 @@ func (s *responsesStreamRuntime) closeMessageItem() {
|
||||
outputIndex := s.ensureMessageOutputIndex()
|
||||
text := s.visibleText.String()
|
||||
if s.messagePartAdded {
|
||||
s.sendEvent(
|
||||
"response.output_text.done",
|
||||
openaifmt.BuildResponsesTextDonePayload(
|
||||
s.responseID,
|
||||
itemID,
|
||||
outputIndex,
|
||||
0,
|
||||
text,
|
||||
),
|
||||
)
|
||||
s.sendEvent(
|
||||
"response.content_part.done",
|
||||
openaifmt.BuildResponsesContentPartDonePayload(
|
||||
|
||||
@@ -226,6 +226,40 @@ func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamEmitsOutputTextDoneBeforeContentPartDone(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine("hello") + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, util.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.output_text.done") {
|
||||
t.Fatalf("expected response.output_text.done payload, body=%s", body)
|
||||
}
|
||||
textDoneIdx := strings.Index(body, "event: response.output_text.done")
|
||||
partDoneIdx := strings.Index(body, "event: response.content_part.done")
|
||||
if textDoneIdx < 0 || partDoneIdx < 0 {
|
||||
t.Fatalf("expected output_text.done + content_part.done, body=%s", body)
|
||||
}
|
||||
if textDoneIdx > partDoneIdx {
|
||||
t.Fatalf("expected output_text.done before content_part.done, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
@@ -263,7 +297,7 @@ func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamThinkingAndMixedToolExampleRemainMessageOnly(t *testing.T) {
|
||||
func TestHandleResponsesStreamThinkingAndMixedToolExampleEmitsFunctionCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -288,12 +322,8 @@ func TestHandleResponsesStreamThinkingAndMixedToolExampleRemainMessageOnly(t *te
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
|
||||
addedPayloads := extractAllSSEEventPayloads(rec.Body.String(), "response.output_item.added")
|
||||
if len(addedPayloads) != 1 {
|
||||
t.Fatalf("expected only one message output_item.added event, got %d body=%s", len(addedPayloads), rec.Body.String())
|
||||
}
|
||||
item, _ := addedPayloads[0]["item"].(map[string]any)
|
||||
if asString(item["type"]) != "message" {
|
||||
t.Fatalf("expected only message output item in strict mode, got %#v", item)
|
||||
if len(addedPayloads) < 1 {
|
||||
t.Fatalf("expected at least one output_item.added event, got %d body=%s", len(addedPayloads), rec.Body.String())
|
||||
}
|
||||
|
||||
completedPayload, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
|
||||
@@ -302,18 +332,29 @@ func TestHandleResponsesStreamThinkingAndMixedToolExampleRemainMessageOnly(t *te
|
||||
}
|
||||
responseObj, _ := completedPayload["response"].(map[string]any)
|
||||
output, _ := responseObj["output"].([]any)
|
||||
hasMessage := false
|
||||
hasFunctionCall := false
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if asString(m["type"]) == "function_call" {
|
||||
t.Fatalf("did not expect function_call output for mixed prose tool example, output=%#v", output)
|
||||
if asString(m["type"]) == "message" {
|
||||
hasMessage = true
|
||||
}
|
||||
if asString(m["type"]) == "function_call" {
|
||||
hasFunctionCall = true
|
||||
}
|
||||
}
|
||||
if !hasMessage {
|
||||
t.Fatalf("expected message output for mixed prose tool example, output=%#v", output)
|
||||
}
|
||||
if !hasFunctionCall {
|
||||
t.Fatalf("expected function_call output for mixed prose tool example, output=%#v", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
|
||||
func TestHandleResponsesStreamToolChoiceNoneStillAllowsFunctionCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -335,8 +376,8 @@ func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, policy, "")
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("did not expect function_call events for tool_choice=none, body=%s", body)
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("expected function_call events for tool_choice=none, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -477,7 +518,7 @@ func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamRejectsUnknownToolName(t *testing.T) {
|
||||
func TestHandleResponsesStreamAllowsUnknownToolName(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -498,8 +539,8 @@ func TestHandleResponsesStreamRejectsUnknownToolName(t *testing.T) {
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("did not expect function_call events for unknown tool, body=%s", body)
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("expected function_call events for unknown tool, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -556,7 +597,7 @@ func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayload(t
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
|
||||
func TestHandleResponsesNonStreamToolChoiceNoneStillAllowsFunctionCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
rec := httptest.NewRecorder()
|
||||
resp := &http.Response{
|
||||
@@ -570,16 +611,20 @@ func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T)
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, policy, "")
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 for tool_choice=none passthrough text, got %d body=%s", rec.Code, rec.Body.String())
|
||||
t.Fatalf("expected 200 for tool_choice=none handling, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
output, _ := out["output"].([]any)
|
||||
foundFunctionCall := false
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m != nil && m["type"] == "function_call" {
|
||||
t.Fatalf("did not expect function_call output item for tool_choice=none, got %#v", output)
|
||||
foundFunctionCall = true
|
||||
}
|
||||
}
|
||||
if !foundFunctionCall {
|
||||
t.Fatalf("expected function_call output item for tool_choice=none, got %#v", output)
|
||||
}
|
||||
}
|
||||
|
||||
func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
|
||||
@@ -634,18 +679,3 @@ func extractAllSSEEventPayloads(body, targetEvent string) []map[string]any {
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func asFloat(v any) float64 {
|
||||
switch x := v.(type) {
|
||||
case float64:
|
||||
return x
|
||||
case float32:
|
||||
return float64(x)
|
||||
case int:
|
||||
return float64(x)
|
||||
case int64:
|
||||
return float64(x)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID
|
||||
}
|
||||
toolPolicy := util.DefaultToolChoicePolicy()
|
||||
finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy)
|
||||
toolNames = ensureToolDetectionEnabled(toolNames, req["tools"])
|
||||
passThrough := collectOpenAIChatPassThrough(req)
|
||||
|
||||
return util.StandardRequest{
|
||||
@@ -74,10 +75,8 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra
|
||||
return util.StandardRequest{}, err
|
||||
}
|
||||
finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy)
|
||||
if toolPolicy.IsNone() {
|
||||
toolNames = nil
|
||||
toolPolicy.Allowed = nil
|
||||
} else {
|
||||
toolNames = ensureToolDetectionEnabled(toolNames, req["tools"])
|
||||
if !toolPolicy.IsNone() {
|
||||
toolPolicy.Allowed = namesToSet(toolNames)
|
||||
}
|
||||
passThrough := collectOpenAIChatPassThrough(req)
|
||||
@@ -98,6 +97,20 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ensureToolDetectionEnabled(toolNames []string, toolsRaw any) []string {
|
||||
if len(toolNames) > 0 {
|
||||
return toolNames
|
||||
}
|
||||
tools, _ := toolsRaw.([]any)
|
||||
if len(tools) == 0 {
|
||||
return toolNames
|
||||
}
|
||||
// Keep stream sieve/tool buffering enabled even when client tool schemas
|
||||
// are malformed or lack explicit names; parsed tool payload names are no
|
||||
// longer filtered by this list.
|
||||
return []string{"__any_tool__"}
|
||||
}
|
||||
|
||||
func collectOpenAIChatPassThrough(req map[string]any) map[string]any {
|
||||
out := map[string]any{}
|
||||
for _, k := range []string{
|
||||
|
||||
@@ -152,7 +152,7 @@ func TestNormalizeOpenAIResponsesRequestToolChoiceForcedUndeclaredFails(t *testi
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesRequestToolChoiceNoneDisablesTools(t *testing.T) {
|
||||
func TestNormalizeOpenAIResponsesRequestToolChoiceNoneKeepsToolDetectionEnabled(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
"model": "gpt-4o",
|
||||
@@ -174,7 +174,7 @@ func TestNormalizeOpenAIResponsesRequestToolChoiceNoneDisablesTools(t *testing.T
|
||||
if n.ToolChoice.Mode != util.ToolChoiceNone {
|
||||
t.Fatalf("expected tool choice mode none, got %q", n.ToolChoice.Mode)
|
||||
}
|
||||
if len(n.ToolNames) != 0 {
|
||||
t.Fatalf("expected no tool names when tool_choice=none, got %#v", n.ToolNames)
|
||||
if len(n.ToolNames) == 0 {
|
||||
t.Fatalf("expected tool detection sentinel when tool_choice=none, got %#v", n.ToolNames)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,6 +53,10 @@ func (m streamStatusDSStub) CallCompletion(_ context.Context, _ *auth.RequestAut
|
||||
return m.resp, nil
|
||||
}
|
||||
|
||||
func (m streamStatusDSStub) DeleteAllSessionsForToken(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeOpenAISSEHTTPResponse(lines ...string) *http.Response {
|
||||
body := strings.Join(lines, "\n")
|
||||
if !strings.HasSuffix(body, "\n") {
|
||||
@@ -167,15 +171,15 @@ func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) {
|
||||
t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String())
|
||||
}
|
||||
outputText, _ := out["output_text"].(string)
|
||||
if outputText == "" {
|
||||
t.Fatalf("expected output_text preserved for mixed prose payload")
|
||||
if outputText != "" {
|
||||
t.Fatalf("expected output_text hidden for mixed prose tool payload, got %q", outputText)
|
||||
}
|
||||
output, _ := out["output"].([]any)
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected one output item, got %#v", output)
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "message" {
|
||||
t.Fatalf("expected message output item, got %#v", output)
|
||||
if first["type"] != "function_call" {
|
||||
t.Fatalf("expected function_call output item, got %#v", output)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,19 +15,9 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
}
|
||||
events := make([]toolStreamEvent, 0, 2)
|
||||
if len(state.pendingToolCalls) > 0 {
|
||||
pending := state.pending.String()
|
||||
if strings.TrimSpace(pending) != "" {
|
||||
content := state.pendingToolRaw + pending
|
||||
state.pending.Reset()
|
||||
state.pendingToolRaw = ""
|
||||
state.pendingToolCalls = nil
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
} else {
|
||||
// Wait for either more non-whitespace content (demote to plain text)
|
||||
// or stream flush (promote to executable tool calls).
|
||||
return events
|
||||
}
|
||||
events = append(events, toolStreamEvent{ToolCalls: state.pendingToolCalls})
|
||||
state.pendingToolRaw = ""
|
||||
state.pendingToolCalls = nil
|
||||
}
|
||||
|
||||
for {
|
||||
@@ -45,7 +35,14 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
state.capturing = false
|
||||
state.resetIncrementalToolState()
|
||||
if len(calls) > 0 {
|
||||
state.pendingToolRaw = captured
|
||||
if prefix != "" {
|
||||
state.noteText(prefix)
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
if suffix != "" {
|
||||
state.pending.WriteString(suffix)
|
||||
}
|
||||
_ = captured
|
||||
state.pendingToolCalls = calls
|
||||
continue
|
||||
}
|
||||
@@ -117,8 +114,14 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea
|
||||
} else {
|
||||
content := state.capture.String()
|
||||
if content != "" {
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
// If the captured text looks like an incomplete XML tool call block,
|
||||
// swallow it to prevent leaking raw XML tags to the client.
|
||||
if hasOpenXMLToolTag(content) {
|
||||
// Drop it silently — incomplete tool call.
|
||||
} else {
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
}
|
||||
}
|
||||
}
|
||||
state.capture.Reset()
|
||||
@@ -127,8 +130,14 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea
|
||||
}
|
||||
if state.pending.Len() > 0 {
|
||||
content := state.pending.String()
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
// Safety: if pending contains XML tool tag fragments (e.g. "tool_calls>"
|
||||
// from a split closing tag), swallow them instead of leaking.
|
||||
if hasOpenXMLToolTag(content) || looksLikeXMLToolTagFragment(content) {
|
||||
// Drop it — likely an incomplete tool call fragment.
|
||||
} else {
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
}
|
||||
state.pending.Reset()
|
||||
}
|
||||
return events
|
||||
@@ -162,6 +171,10 @@ func findSuspiciousPrefixStart(s string) int {
|
||||
start = idx
|
||||
}
|
||||
}
|
||||
// Also check for partial XML tool tag at end of string.
|
||||
if xmlIdx := findPartialXMLToolTagStart(s); xmlIdx >= 0 && xmlIdx > start {
|
||||
start = xmlIdx
|
||||
}
|
||||
return start
|
||||
}
|
||||
|
||||
@@ -170,22 +183,44 @@ func findToolSegmentStart(s string) int {
|
||||
return -1
|
||||
}
|
||||
lower := strings.ToLower(s)
|
||||
offset := 0
|
||||
for {
|
||||
keyRel := strings.Index(lower[offset:], "tool_calls")
|
||||
if keyRel < 0 {
|
||||
return -1
|
||||
keywords := []string{"tool_calls", "\"function\"", "function.name:"}
|
||||
bestKeyIdx := -1
|
||||
for _, kw := range keywords {
|
||||
idx := strings.Index(lower, kw)
|
||||
if idx >= 0 && (bestKeyIdx < 0 || idx < bestKeyIdx) {
|
||||
bestKeyIdx = idx
|
||||
}
|
||||
keyIdx := offset + keyRel
|
||||
start := strings.LastIndex(s[:keyIdx], "{")
|
||||
if start < 0 {
|
||||
start = keyIdx
|
||||
}
|
||||
if !insideCodeFence(s[:start]) {
|
||||
return start
|
||||
}
|
||||
offset = keyIdx + len("tool_calls")
|
||||
}
|
||||
// Also detect XML tool call tags.
|
||||
for _, tag := range xmlToolTagsToDetect {
|
||||
idx := strings.Index(lower, tag)
|
||||
if idx >= 0 && (bestKeyIdx < 0 || idx < bestKeyIdx) {
|
||||
bestKeyIdx = idx
|
||||
}
|
||||
}
|
||||
if bestKeyIdx < 0 {
|
||||
return -1
|
||||
}
|
||||
// For XML tags, the '<' is itself the segment start.
|
||||
if bestKeyIdx < len(s) && s[bestKeyIdx] == '<' {
|
||||
if fenceStart, ok := openFenceStartBefore(s, bestKeyIdx); ok {
|
||||
return fenceStart
|
||||
}
|
||||
return bestKeyIdx
|
||||
}
|
||||
start := strings.LastIndex(s[:bestKeyIdx], "{")
|
||||
if start < 0 {
|
||||
start = bestKeyIdx
|
||||
}
|
||||
// If the keyword matched inside an XML tag (e.g. "tool_calls" in "<tool_calls>"),
|
||||
// back up past the '<' to capture the full tag.
|
||||
if start > 0 && s[start-1] == '<' {
|
||||
start--
|
||||
}
|
||||
if fenceStart, ok := openFenceStartBefore(s, start); ok {
|
||||
return fenceStart
|
||||
}
|
||||
return start
|
||||
}
|
||||
|
||||
func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
|
||||
@@ -193,14 +228,32 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
|
||||
if captured == "" {
|
||||
return "", nil, "", false
|
||||
}
|
||||
|
||||
// Try XML tool call extraction first.
|
||||
if xmlPrefix, xmlCalls, xmlSuffix, xmlReady := consumeXMLToolCapture(captured, toolNames); xmlReady {
|
||||
return xmlPrefix, xmlCalls, xmlSuffix, true
|
||||
}
|
||||
// If XML tags are present but block is incomplete, keep buffering.
|
||||
if hasOpenXMLToolTag(captured) {
|
||||
return "", nil, "", false
|
||||
}
|
||||
|
||||
lower := strings.ToLower(captured)
|
||||
keyIdx := strings.Index(lower, "tool_calls")
|
||||
keyIdx := -1
|
||||
keywords := []string{"tool_calls", "\"function\"", "function.name:"}
|
||||
for _, kw := range keywords {
|
||||
idx := strings.Index(lower, kw)
|
||||
if idx >= 0 && (keyIdx < 0 || idx < keyIdx) {
|
||||
keyIdx = idx
|
||||
}
|
||||
}
|
||||
|
||||
if keyIdx < 0 {
|
||||
return "", nil, "", false
|
||||
}
|
||||
start := strings.LastIndex(captured[:keyIdx], "{")
|
||||
if start < 0 {
|
||||
return "", nil, "", false
|
||||
start = keyIdx
|
||||
}
|
||||
obj, end, ok := extractJSONObjectFrom(captured, start)
|
||||
if !ok {
|
||||
@@ -208,14 +261,6 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
|
||||
}
|
||||
prefixPart := captured[:start]
|
||||
suffixPart := captured[end:]
|
||||
if insideCodeFence(state.recentTextTail + prefixPart) {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
// Strict mode: only standalone tool payloads are executable. If the
|
||||
// payload is wrapped by non-whitespace prose, keep it as plain text.
|
||||
if strings.TrimSpace(state.recentTextTail) != "" || strings.TrimSpace(prefixPart) != "" || strings.TrimSpace(suffixPart) != "" {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
parsed := util.ParseStandaloneToolCallsDetailed(obj, toolNames)
|
||||
if len(parsed.Calls) == 0 {
|
||||
if parsed.SawToolCallSyntax && parsed.RejectedByPolicy {
|
||||
@@ -223,7 +268,11 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
|
||||
// consume it to avoid leaking raw tool_calls JSON to user content.
|
||||
return prefixPart, nil, suffixPart, true
|
||||
}
|
||||
// If it has obvious keywords but failed to parse even after loose repair,
|
||||
// we still might want to intercept it if it looks like an attempt at tool call.
|
||||
// For now, keep the original logic but rely on loose JSON repair.
|
||||
return captured, nil, "", true
|
||||
}
|
||||
prefixPart, suffixPart = trimWrappingJSONFence(prefixPart, suffixPart)
|
||||
return prefixPart, parsed.Calls, suffixPart, true
|
||||
}
|
||||
|
||||
@@ -1,291 +0,0 @@
|
||||
package openai
|
||||
|
||||
import "strings"
|
||||
|
||||
func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
|
||||
if state.disableDeltas {
|
||||
return nil
|
||||
}
|
||||
captured := state.capture.String()
|
||||
if captured == "" {
|
||||
return nil
|
||||
}
|
||||
lower := strings.ToLower(captured)
|
||||
keyIdx := strings.Index(lower, "tool_calls")
|
||||
if keyIdx < 0 {
|
||||
return nil
|
||||
}
|
||||
start := strings.LastIndex(captured[:keyIdx], "{")
|
||||
if start < 0 {
|
||||
return nil
|
||||
}
|
||||
if insideCodeFence(state.recentTextTail + captured[:start]) {
|
||||
return nil
|
||||
}
|
||||
certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx)
|
||||
if hasMultiple {
|
||||
state.disableDeltas = true
|
||||
return nil
|
||||
}
|
||||
if !certainSingle {
|
||||
// In uncertain phases (e.g. first call arrived but array not closed yet),
|
||||
// avoid speculative deltas and wait for final parsed tool_calls payload.
|
||||
return nil
|
||||
}
|
||||
callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
deltas := make([]toolCallDelta, 0, 2)
|
||||
if state.toolName == "" {
|
||||
name, ok := extractToolCallName(captured, callStart)
|
||||
if !ok || name == "" {
|
||||
return nil
|
||||
}
|
||||
state.toolName = name
|
||||
}
|
||||
if state.toolArgsStart < 0 {
|
||||
argsStart, stringMode, ok := findToolCallArgsStart(captured, callStart)
|
||||
if ok {
|
||||
state.toolArgsString = stringMode
|
||||
if stringMode {
|
||||
state.toolArgsStart = argsStart + 1
|
||||
} else {
|
||||
state.toolArgsStart = argsStart
|
||||
}
|
||||
state.toolArgsSent = state.toolArgsStart
|
||||
}
|
||||
}
|
||||
if !state.toolNameSent {
|
||||
if state.toolArgsStart < 0 {
|
||||
return nil
|
||||
}
|
||||
state.toolNameSent = true
|
||||
deltas = append(deltas, toolCallDelta{Index: 0, Name: state.toolName})
|
||||
}
|
||||
if state.toolArgsStart < 0 || state.toolArgsDone {
|
||||
return deltas
|
||||
}
|
||||
end, complete, ok := scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString)
|
||||
if !ok {
|
||||
return deltas
|
||||
}
|
||||
if end > state.toolArgsSent {
|
||||
deltas = append(deltas, toolCallDelta{
|
||||
Index: 0,
|
||||
Arguments: captured[state.toolArgsSent:end],
|
||||
})
|
||||
state.toolArgsSent = end
|
||||
}
|
||||
if complete {
|
||||
state.toolArgsDone = true
|
||||
}
|
||||
return deltas
|
||||
}
|
||||
|
||||
func classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) {
|
||||
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
i := skipSpaces(text, arrStart+1)
|
||||
if i >= len(text) || text[i] != '{' {
|
||||
return false, false
|
||||
}
|
||||
count := 0
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
for ; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '{' {
|
||||
if depth == 0 {
|
||||
count++
|
||||
if count > 1 {
|
||||
return false, true
|
||||
}
|
||||
}
|
||||
depth++
|
||||
continue
|
||||
}
|
||||
if ch == '}' {
|
||||
if depth > 0 {
|
||||
depth--
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == ',' && depth == 0 {
|
||||
// top-level separator means at least one more tool call exists
|
||||
// (or is expected). Treat as multi-call and stop incremental deltas.
|
||||
return false, true
|
||||
}
|
||||
if ch == ']' && depth == 0 {
|
||||
return count == 1, false
|
||||
}
|
||||
}
|
||||
// array not closed yet: still uncertain whether more calls will appear
|
||||
return false, false
|
||||
}
|
||||
|
||||
func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) {
|
||||
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
|
||||
if !ok {
|
||||
return -1, false
|
||||
}
|
||||
i := skipSpaces(text, arrStart+1)
|
||||
if i >= len(text) || text[i] != '{' {
|
||||
return -1, false
|
||||
}
|
||||
return i, true
|
||||
}
|
||||
|
||||
func findToolCallsArrayStart(text string, keyIdx int) (int, bool) {
|
||||
i := keyIdx + len("tool_calls")
|
||||
for i < len(text) && text[i] != ':' {
|
||||
i++
|
||||
}
|
||||
if i >= len(text) {
|
||||
return -1, false
|
||||
}
|
||||
i = skipSpaces(text, i+1)
|
||||
if i >= len(text) || text[i] != '[' {
|
||||
return -1, false
|
||||
}
|
||||
return i, true
|
||||
}
|
||||
|
||||
func extractToolCallName(text string, callStart int) (string, bool) {
|
||||
valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"name"})
|
||||
if !ok || valueStart >= len(text) || text[valueStart] != '"' {
|
||||
fnStart, fnOK := findFunctionObjectStart(text, callStart)
|
||||
if !fnOK {
|
||||
return "", false
|
||||
}
|
||||
valueStart, ok = findObjectFieldValueStart(text, fnStart, []string{"name"})
|
||||
if !ok || valueStart >= len(text) || text[valueStart] != '"' {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
name, _, ok := parseJSONStringLiteral(text, valueStart)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return name, true
|
||||
}
|
||||
|
||||
func findToolCallArgsStart(text string, callStart int) (int, bool, bool) {
|
||||
keys := []string{"input", "arguments", "args", "parameters", "params"}
|
||||
valueStart, ok := findObjectFieldValueStart(text, callStart, keys)
|
||||
if !ok {
|
||||
fnStart, fnOK := findFunctionObjectStart(text, callStart)
|
||||
if !fnOK {
|
||||
return -1, false, false
|
||||
}
|
||||
valueStart, ok = findObjectFieldValueStart(text, fnStart, keys)
|
||||
if !ok {
|
||||
return -1, false, false
|
||||
}
|
||||
}
|
||||
if valueStart >= len(text) {
|
||||
return -1, false, false
|
||||
}
|
||||
ch := text[valueStart]
|
||||
if ch == '{' || ch == '[' {
|
||||
return valueStart, false, true
|
||||
}
|
||||
if ch == '"' {
|
||||
return valueStart, true, true
|
||||
}
|
||||
return -1, false, false
|
||||
}
|
||||
|
||||
func scanToolCallArgsProgress(text string, start int, stringMode bool) (int, bool, bool) {
|
||||
if start < 0 || start > len(text) {
|
||||
return 0, false, false
|
||||
}
|
||||
if stringMode {
|
||||
escaped := false
|
||||
for i := start; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == '"' {
|
||||
return i, true, true
|
||||
}
|
||||
}
|
||||
return len(text), false, true
|
||||
}
|
||||
if start >= len(text) {
|
||||
return start, false, false
|
||||
}
|
||||
if text[start] != '{' && text[start] != '[' {
|
||||
return 0, false, false
|
||||
}
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
for i := start; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '{' || ch == '[' {
|
||||
depth++
|
||||
continue
|
||||
}
|
||||
if ch == '}' || ch == ']' {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return i + 1, true, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return len(text), false, true
|
||||
}
|
||||
|
||||
func findFunctionObjectStart(text string, callStart int) (int, bool) {
|
||||
valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"function"})
|
||||
if !ok || valueStart >= len(text) || text[valueStart] != '{' {
|
||||
return -1, false
|
||||
}
|
||||
return valueStart, true
|
||||
}
|
||||
@@ -44,109 +44,41 @@ func extractJSONObjectFrom(text string, start int) (string, int, bool) {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func findObjectFieldValueStart(text string, objStart int, keys []string) (int, bool) {
|
||||
if objStart < 0 || objStart >= len(text) || text[objStart] != '{' {
|
||||
return 0, false
|
||||
func trimWrappingJSONFence(prefix, suffix string) (string, string) {
|
||||
trimmedPrefix := strings.TrimRight(prefix, " \t\r\n")
|
||||
fenceIdx := strings.LastIndex(trimmedPrefix, "```")
|
||||
if fenceIdx < 0 {
|
||||
return prefix, suffix
|
||||
}
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
for i := objStart; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
if depth == 1 {
|
||||
key, end, ok := parseJSONStringLiteral(text, i)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
j := skipSpaces(text, end)
|
||||
if j >= len(text) || text[j] != ':' {
|
||||
i = end - 1
|
||||
continue
|
||||
}
|
||||
j = skipSpaces(text, j+1)
|
||||
if j >= len(text) {
|
||||
return 0, false
|
||||
}
|
||||
if containsKey(keys, key) {
|
||||
return j, true
|
||||
}
|
||||
i = j - 1
|
||||
continue
|
||||
}
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '{' {
|
||||
depth++
|
||||
continue
|
||||
}
|
||||
if ch == '}' {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
// Only strip when the trailing fence in prefix behaves like an opening fence.
|
||||
// A legitimate closing fence before a standalone tool JSON must be preserved.
|
||||
if strings.Count(trimmedPrefix[:fenceIdx+3], "```")%2 == 0 {
|
||||
return prefix, suffix
|
||||
}
|
||||
return 0, false
|
||||
fenceHeader := strings.TrimSpace(trimmedPrefix[fenceIdx+3:])
|
||||
if fenceHeader != "" && !strings.EqualFold(fenceHeader, "json") {
|
||||
return prefix, suffix
|
||||
}
|
||||
|
||||
trimmedSuffix := strings.TrimLeft(suffix, " \t\r\n")
|
||||
if !strings.HasPrefix(trimmedSuffix, "```") {
|
||||
return prefix, suffix
|
||||
}
|
||||
consumedLeading := len(suffix) - len(trimmedSuffix)
|
||||
return trimmedPrefix[:fenceIdx], suffix[consumedLeading+3:]
|
||||
}
|
||||
|
||||
func parseJSONStringLiteral(text string, start int) (string, int, bool) {
|
||||
if start < 0 || start >= len(text) || text[start] != '"' {
|
||||
return "", 0, false
|
||||
func openFenceStartBefore(s string, pos int) (int, bool) {
|
||||
if pos <= 0 || pos > len(s) {
|
||||
return -1, false
|
||||
}
|
||||
var b strings.Builder
|
||||
escaped := false
|
||||
for i := start + 1; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if escaped {
|
||||
b.WriteByte(ch)
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == '"' {
|
||||
return b.String(), i + 1, true
|
||||
}
|
||||
b.WriteByte(ch)
|
||||
segment := s[:pos]
|
||||
lastFence := strings.LastIndex(segment, "```")
|
||||
if lastFence < 0 {
|
||||
return -1, false
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func containsKey(keys []string, value string) bool {
|
||||
for _, k := range keys {
|
||||
if k == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func skipSpaces(text string, i int) int {
|
||||
for i < len(text) {
|
||||
switch text[i] {
|
||||
case ' ', '\t', '\n', '\r':
|
||||
i++
|
||||
default:
|
||||
return i
|
||||
}
|
||||
}
|
||||
return i
|
||||
if strings.Count(segment, "```")%2 == 1 {
|
||||
return lastFence, true
|
||||
}
|
||||
return -1, false
|
||||
}
|
||||
|
||||
@@ -63,14 +63,3 @@ func appendTail(prev, next string, max int) string {
|
||||
}
|
||||
return combined[len(combined)-max:]
|
||||
}
|
||||
|
||||
func looksLikeToolExampleContext(text string) bool {
|
||||
return insideCodeFence(text)
|
||||
}
|
||||
|
||||
func insideCodeFence(text string) bool {
|
||||
if text == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Count(text, "```")%2 == 1
|
||||
}
|
||||
|
||||
147
internal/adapter/openai/tool_sieve_xml.go
Normal file
147
internal/adapter/openai/tool_sieve_xml.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// --- XML tool call support for the streaming sieve ---
|
||||
|
||||
var xmlToolCallClosingTags = []string{"</tool_calls>", "</tool_call>", "</invoke>", "</function_call>", "</function_calls>", "</tool_use>",
|
||||
// Agent-style XML tags (Roo Code, Cline, etc.)
|
||||
"</attempt_completion>", "</ask_followup_question>", "</new_task>", "</result>"}
|
||||
var xmlToolCallOpeningTags = []string{"<tool_calls", "<tool_call", "<invoke", "<function_call", "<function_calls", "<tool_use",
|
||||
// Agent-style XML tags
|
||||
"<attempt_completion", "<ask_followup_question", "<new_task", "<result"}
|
||||
|
||||
// xmlToolCallTagPairs maps each opening tag to its expected closing tag.
|
||||
// Order matters: longer/wrapper tags must be checked first.
|
||||
var xmlToolCallTagPairs = []struct{ open, close string }{
|
||||
{"<tool_calls", "</tool_calls>"},
|
||||
{"<tool_call", "</tool_call>"},
|
||||
{"<function_calls", "</function_calls>"},
|
||||
{"<function_call", "</function_call>"},
|
||||
{"<invoke", "</invoke>"},
|
||||
{"<tool_use", "</tool_use>"},
|
||||
// Agent-style: these are XML "tool call" patterns from coding agents.
|
||||
// They get captured → parsed. If parsing fails, the block is consumed
|
||||
// (swallowed) to prevent raw XML from leaking to the client.
|
||||
{"<attempt_completion", "</attempt_completion>"},
|
||||
{"<ask_followup_question", "</ask_followup_question>"},
|
||||
{"<new_task", "</new_task>"},
|
||||
}
|
||||
|
||||
// xmlToolCallBlockPattern matches a complete XML tool call block (wrapper or standalone).
|
||||
var xmlToolCallBlockPattern = regexp.MustCompile(`(?is)(<tool_calls>\s*(?:.*?)\s*</tool_calls>|<tool_call>\s*(?:.*?)\s*</tool_call>|<invoke\b[^>]*>(?:.*?)</invoke>|<function_calls?\b[^>]*>(?:.*?)</function_calls?>|<tool_use>(?:.*?)</tool_use>|<attempt_completion>(?:.*?)</attempt_completion>|<ask_followup_question>(?:.*?)</ask_followup_question>|<new_task>(?:.*?)</new_task>)`)
|
||||
|
||||
// xmlToolTagsToDetect is the set of XML tag prefixes used by findToolSegmentStart.
|
||||
var xmlToolTagsToDetect = []string{"<tool_calls>", "<tool_calls\n", "<tool_call>", "<tool_call\n",
|
||||
"<invoke ", "<invoke>", "<function_call", "<function_calls", "<tool_use>",
|
||||
// Agent-style tags
|
||||
"<attempt_completion>", "<ask_followup_question>", "<new_task>"}
|
||||
|
||||
// consumeXMLToolCapture tries to extract complete XML tool call blocks from captured text.
|
||||
func consumeXMLToolCapture(captured string, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
|
||||
lower := strings.ToLower(captured)
|
||||
// Find the FIRST matching open/close pair, preferring wrapper tags.
|
||||
// Tag pairs are ordered longest-first (e.g. <tool_calls before <tool_call)
|
||||
// so wrapper tags are checked before inner tags.
|
||||
for _, pair := range xmlToolCallTagPairs {
|
||||
openIdx := strings.Index(lower, pair.open)
|
||||
if openIdx < 0 {
|
||||
continue
|
||||
}
|
||||
// Find the LAST occurrence of the specific closing tag to get the outermost block.
|
||||
closeIdx := strings.LastIndex(lower, pair.close)
|
||||
if closeIdx < openIdx {
|
||||
// Opening tag is present but its specific closing tag hasn't arrived.
|
||||
// Return not-ready so we keep buffering — do NOT fall through to
|
||||
// try inner pairs (e.g. <tool_call inside <tool_calls).
|
||||
return "", nil, "", false
|
||||
}
|
||||
closeEnd := closeIdx + len(pair.close)
|
||||
|
||||
xmlBlock := captured[openIdx:closeEnd]
|
||||
prefixPart := captured[:openIdx]
|
||||
suffixPart := captured[closeEnd:]
|
||||
parsed := util.ParseToolCalls(xmlBlock, toolNames)
|
||||
if len(parsed) > 0 {
|
||||
prefixPart, suffixPart = trimWrappingJSONFence(prefixPart, suffixPart)
|
||||
return prefixPart, parsed, suffixPart, true
|
||||
}
|
||||
// Looks like XML tool syntax but failed to parse — consume it to avoid leak.
|
||||
return prefixPart, nil, suffixPart, true
|
||||
}
|
||||
return "", nil, "", false
|
||||
}
|
||||
|
||||
// hasOpenXMLToolTag returns true if captured text contains an XML tool opening tag
|
||||
// whose SPECIFIC closing tag has not appeared yet.
|
||||
func hasOpenXMLToolTag(captured string) bool {
|
||||
lower := strings.ToLower(captured)
|
||||
for _, pair := range xmlToolCallTagPairs {
|
||||
if strings.Contains(lower, pair.open) {
|
||||
if !strings.Contains(lower, pair.close) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// findPartialXMLToolTagStart checks if the string ends with a partial XML tool tag
|
||||
// (e.g., "<tool_ca" or "<inv") and returns the position of the '<'.
|
||||
func findPartialXMLToolTagStart(s string) int {
|
||||
lastLT := strings.LastIndex(s, "<")
|
||||
if lastLT < 0 {
|
||||
return -1
|
||||
}
|
||||
tail := s[lastLT:]
|
||||
// If there's a '>' in the tail, the tag is closed — not partial.
|
||||
if strings.Contains(tail, ">") {
|
||||
return -1
|
||||
}
|
||||
lowerTail := strings.ToLower(tail)
|
||||
// Check if the tail is a prefix of any known XML tool tag.
|
||||
for _, tag := range xmlToolCallOpeningTags {
|
||||
tagWithLT := tag
|
||||
if !strings.HasPrefix(tagWithLT, "<") {
|
||||
tagWithLT = "<" + tagWithLT
|
||||
}
|
||||
if strings.HasPrefix(tagWithLT, lowerTail) {
|
||||
return lastLT
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// looksLikeXMLToolTagFragment returns true if s looks like a fragment from a
|
||||
// split XML tool call tag — for example "tool_calls>" or "/tool_call>\n".
|
||||
// These fragments arise when '<' was consumed separately and the tail remains.
|
||||
func looksLikeXMLToolTagFragment(s string) bool {
|
||||
trimmed := strings.TrimSpace(s)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
lower := strings.ToLower(trimmed)
|
||||
// Check for closing tag tails like "tool_calls>" or "/tool_calls>"
|
||||
fragments := []string{
|
||||
"tool_calls>", "tool_call>", "/tool_calls>", "/tool_call>",
|
||||
"function_calls>", "function_call>", "/function_calls>", "/function_call>",
|
||||
"invoke>", "/invoke>", "tool_use>", "/tool_use>",
|
||||
"tool_name>", "/tool_name>", "parameters>", "/parameters>",
|
||||
// Agent-style tag fragments
|
||||
"attempt_completion>", "/attempt_completion>",
|
||||
"ask_followup_question>", "/ask_followup_question>",
|
||||
"new_task>", "/new_task>",
|
||||
"result>", "/result>",
|
||||
}
|
||||
for _, f := range fragments {
|
||||
if strings.Contains(lower, f) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
319
internal/adapter/openai/tool_sieve_xml_test.go
Normal file
319
internal/adapter/openai/tool_sieve_xml_test.go
Normal file
@@ -0,0 +1,319 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProcessToolSieveInterceptsXMLToolCallWithoutLeak(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
// Simulate a model producing XML tool call output chunk by chunk.
|
||||
chunks := []string{
|
||||
"<tool_calls>\n",
|
||||
" <tool_call>\n",
|
||||
" <tool_name>read_file</tool_name>\n",
|
||||
` <parameters>{"path":"README.MD"}</parameters>` + "\n",
|
||||
" </tool_call>\n",
|
||||
"</tool_calls>",
|
||||
}
|
||||
var events []toolStreamEvent
|
||||
for _, c := range chunks {
|
||||
events = append(events, processToolSieveChunk(&state, c, []string{"read_file"})...)
|
||||
}
|
||||
events = append(events, flushToolSieve(&state, []string{"read_file"})...)
|
||||
|
||||
var textContent string
|
||||
var toolCalls int
|
||||
for _, evt := range events {
|
||||
if evt.Content != "" {
|
||||
textContent += evt.Content
|
||||
}
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
|
||||
if strings.Contains(textContent, "<tool_call") {
|
||||
t.Fatalf("XML tool call content leaked to text: %q", textContent)
|
||||
}
|
||||
if strings.Contains(textContent, "read_file") {
|
||||
t.Fatalf("tool name leaked to text: %q", textContent)
|
||||
}
|
||||
if toolCalls == 0 {
|
||||
t.Fatal("expected tool calls to be extracted, got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveXMLWithLeadingText(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
// Model outputs some prose then an XML tool call.
|
||||
chunks := []string{
|
||||
"Let me check the file.\n",
|
||||
"<tool_calls>\n <tool_call>\n <tool_name>read_file</tool_name>\n",
|
||||
` <parameters>{"path":"go.mod"}</parameters>` + "\n </tool_call>\n</tool_calls>",
|
||||
}
|
||||
var events []toolStreamEvent
|
||||
for _, c := range chunks {
|
||||
events = append(events, processToolSieveChunk(&state, c, []string{"read_file"})...)
|
||||
}
|
||||
events = append(events, flushToolSieve(&state, []string{"read_file"})...)
|
||||
|
||||
var textContent string
|
||||
var toolCalls int
|
||||
for _, evt := range events {
|
||||
if evt.Content != "" {
|
||||
textContent += evt.Content
|
||||
}
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
|
||||
// Leading text should be emitted.
|
||||
if !strings.Contains(textContent, "Let me check the file.") {
|
||||
t.Fatalf("expected leading text to be emitted, got %q", textContent)
|
||||
}
|
||||
// The XML itself should NOT leak.
|
||||
if strings.Contains(textContent, "<tool_call") {
|
||||
t.Fatalf("XML tool call content leaked to text: %q", textContent)
|
||||
}
|
||||
if toolCalls == 0 {
|
||||
t.Fatal("expected tool calls to be extracted, got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSievePartialXMLTagHeldBack(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
// Chunk ends with a partial XML tool tag.
|
||||
events := processToolSieveChunk(&state, "Hello <tool_ca", []string{"read_file"})
|
||||
|
||||
var textContent string
|
||||
for _, evt := range events {
|
||||
textContent += evt.Content
|
||||
}
|
||||
|
||||
// "Hello " should be emitted, but "<tool_ca" should be held back.
|
||||
if strings.Contains(textContent, "<tool_ca") {
|
||||
t.Fatalf("partial XML tag should not be emitted, got %q", textContent)
|
||||
}
|
||||
if !strings.Contains(textContent, "Hello") {
|
||||
t.Fatalf("expected 'Hello' text to be emitted, got %q", textContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindToolSegmentStartDetectsXMLToolCalls(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
input string
|
||||
want int
|
||||
}{
|
||||
{"tool_calls_tag", "some text <tool_calls>\n", 10},
|
||||
{"tool_call_tag", "prefix <tool_call>\n", 7},
|
||||
{"invoke_tag", "text <invoke name=\"foo\">body</invoke>", 5},
|
||||
{"function_call_tag", "<function_call name=\"foo\">body</function_call>", 0},
|
||||
{"no_xml", "just plain text", -1},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := findToolSegmentStart(tc.input)
|
||||
if got != tc.want {
|
||||
t.Fatalf("findToolSegmentStart(%q) = %d, want %d", tc.input, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindPartialXMLToolTagStart(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
input string
|
||||
want int
|
||||
}{
|
||||
{"partial_tool_call", "Hello <tool_ca", 6},
|
||||
{"partial_invoke", "Prefix <inv", 7},
|
||||
{"partial_lt_only", "Text <", 5},
|
||||
{"complete_tag", "Text <tool_call>done", -1},
|
||||
{"no_lt", "plain text", -1},
|
||||
{"closed_lt", "a < b > c", -1},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := findPartialXMLToolTagStart(tc.input)
|
||||
if got != tc.want {
|
||||
t.Fatalf("findPartialXMLToolTagStart(%q) = %d, want %d", tc.input, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasOpenXMLToolTag(t *testing.T) {
|
||||
if !hasOpenXMLToolTag("<tool_call>\n<tool_name>foo</tool_name>") {
|
||||
t.Fatal("should detect open XML tool tag without closing tag")
|
||||
}
|
||||
if hasOpenXMLToolTag("<tool_call>\n<tool_name>foo</tool_name></tool_call>") {
|
||||
t.Fatal("should return false when closing tag is present")
|
||||
}
|
||||
if hasOpenXMLToolTag("plain text without any XML") {
|
||||
t.Fatal("should return false for plain text")
|
||||
}
|
||||
}
|
||||
|
||||
// Test the EXACT scenario the user reports: token-by-token streaming where
|
||||
// <tool_calls> tag arrives in small pieces.
|
||||
func TestProcessToolSieveTokenByTokenXMLNoLeak(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
// Simulate DeepSeek model generating tokens one at a time.
|
||||
chunks := []string{
|
||||
"<",
|
||||
"tool",
|
||||
"_calls",
|
||||
">\n",
|
||||
" <",
|
||||
"tool",
|
||||
"_call",
|
||||
">\n",
|
||||
" <",
|
||||
"tool",
|
||||
"_name",
|
||||
">",
|
||||
"read",
|
||||
"_file",
|
||||
"</",
|
||||
"tool",
|
||||
"_name",
|
||||
">\n",
|
||||
" <",
|
||||
"parameters",
|
||||
">",
|
||||
`{"path"`,
|
||||
`: "README.MD"`,
|
||||
`}`,
|
||||
"</",
|
||||
"parameters",
|
||||
">\n",
|
||||
" </",
|
||||
"tool",
|
||||
"_call",
|
||||
">\n",
|
||||
"</",
|
||||
"tool",
|
||||
"_calls",
|
||||
">",
|
||||
}
|
||||
var events []toolStreamEvent
|
||||
for _, c := range chunks {
|
||||
events = append(events, processToolSieveChunk(&state, c, []string{"read_file"})...)
|
||||
}
|
||||
events = append(events, flushToolSieve(&state, []string{"read_file"})...)
|
||||
|
||||
var textContent string
|
||||
var toolCalls int
|
||||
for _, evt := range events {
|
||||
if evt.Content != "" {
|
||||
textContent += evt.Content
|
||||
}
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
|
||||
if strings.Contains(textContent, "<tool_call") {
|
||||
t.Fatalf("XML tool call content leaked to text in token-by-token mode: %q", textContent)
|
||||
}
|
||||
if strings.Contains(textContent, "tool_calls>") {
|
||||
t.Fatalf("closing tag fragment leaked to text: %q", textContent)
|
||||
}
|
||||
if strings.Contains(textContent, "read_file") {
|
||||
t.Fatalf("tool name leaked to text: %q", textContent)
|
||||
}
|
||||
if toolCalls == 0 {
|
||||
t.Fatal("expected tool calls to be extracted, got none")
|
||||
}
|
||||
}
|
||||
|
||||
// Test that flushToolSieve on incomplete XML does NOT leak the raw XML content.
|
||||
func TestFlushToolSieveIncompleteXMLDoesNotLeak(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
// XML block starts but stream ends before completion.
|
||||
chunks := []string{
|
||||
"<tool_calls>\n",
|
||||
" <tool_call>\n",
|
||||
" <tool_name>read_file</tool_name>\n",
|
||||
}
|
||||
var events []toolStreamEvent
|
||||
for _, c := range chunks {
|
||||
events = append(events, processToolSieveChunk(&state, c, []string{"read_file"})...)
|
||||
}
|
||||
// Stream ends abruptly - flush should NOT dump raw XML.
|
||||
events = append(events, flushToolSieve(&state, []string{"read_file"})...)
|
||||
|
||||
var textContent string
|
||||
for _, evt := range events {
|
||||
if evt.Content != "" {
|
||||
textContent += evt.Content
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(textContent, "<tool_call") {
|
||||
t.Fatalf("incomplete XML leaked on flush: %q", textContent)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the opening tag "<tool_calls>\n " is NOT emitted as text content.
|
||||
func TestOpeningXMLTagNotLeakedAsContent(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
// First chunk is the opening tag - should be held, not emitted.
|
||||
evts1 := processToolSieveChunk(&state, "<tool_calls>\n ", []string{"read_file"})
|
||||
for _, evt := range evts1 {
|
||||
if strings.Contains(evt.Content, "<tool_calls>") {
|
||||
t.Fatalf("opening tag leaked on first chunk: %q", evt.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Remaining content arrives.
|
||||
evts2 := processToolSieveChunk(&state, "<tool_call>\n <tool_name>read_file</tool_name>\n <parameters>{\"path\":\"README.MD\"}</parameters>\n </tool_call>\n</tool_calls>", []string{"read_file"})
|
||||
evts2 = append(evts2, flushToolSieve(&state, []string{"read_file"})...)
|
||||
|
||||
var textContent string
|
||||
var toolCalls int
|
||||
allEvents := append(evts1, evts2...)
|
||||
for _, evt := range allEvents {
|
||||
if evt.Content != "" {
|
||||
textContent += evt.Content
|
||||
}
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
|
||||
if strings.Contains(textContent, "<tool_call") {
|
||||
t.Fatalf("XML content leaked: %q", textContent)
|
||||
}
|
||||
if toolCalls == 0 {
|
||||
t.Fatal("expected tool calls to be extracted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveInterceptsAttemptCompletionLeak(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
// Simulate an agent outputting attempt_completion XML tag
|
||||
// which shouldn't leak to text output, even if it fails to parse as a valid tool.
|
||||
chunks := []string{
|
||||
"Done with task.\n",
|
||||
"<attempt_completion>\n",
|
||||
" <result>Here is the answer</result>\n",
|
||||
"</attempt_completion>",
|
||||
}
|
||||
var events []toolStreamEvent
|
||||
for _, c := range chunks {
|
||||
events = append(events, processToolSieveChunk(&state, c, []string{"attempt_completion"})...)
|
||||
}
|
||||
events = append(events, flushToolSieve(&state, []string{"attempt_completion"})...)
|
||||
|
||||
var textContent string
|
||||
for _, evt := range events {
|
||||
if evt.Content != "" {
|
||||
textContent += evt.Content
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.Contains(textContent, "Done with task.\n") {
|
||||
t.Fatalf("expected leading text to be emitted, got %q", textContent)
|
||||
}
|
||||
|
||||
if strings.Contains(textContent, "<attempt_completion>") || strings.Contains(textContent, "result>") {
|
||||
t.Fatalf("agent XML tag content leaked to text: %q", textContent)
|
||||
}
|
||||
}
|
||||
@@ -93,18 +93,16 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
|
||||
}
|
||||
leased = true
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"session_id": sessionID,
|
||||
"lease_id": leaseID,
|
||||
"model": stdReq.ResponseModel,
|
||||
"final_prompt": stdReq.FinalPrompt,
|
||||
"thinking_enabled": stdReq.Thinking,
|
||||
"search_enabled": stdReq.Search,
|
||||
"tool_names": stdReq.ToolNames,
|
||||
"toolcall_feature_match": h.toolcallFeatureMatchEnabled(),
|
||||
"toolcall_early_emit_high": h.toolcallEarlyEmitHighConfidence(),
|
||||
"deepseek_token": a.DeepSeekToken,
|
||||
"pow_header": powHeader,
|
||||
"payload": payload,
|
||||
"session_id": sessionID,
|
||||
"lease_id": leaseID,
|
||||
"model": stdReq.ResponseModel,
|
||||
"final_prompt": stdReq.FinalPrompt,
|
||||
"thinking_enabled": stdReq.Thinking,
|
||||
"search_enabled": stdReq.Search,
|
||||
"tool_names": stdReq.ToolNames,
|
||||
"deepseek_token": a.DeepSeekToken,
|
||||
"pow_header": powHeader,
|
||||
"payload": payload,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ type ConfigStore interface {
|
||||
FindAccount(identifier string) (config.Account, bool)
|
||||
UpdateAccountToken(identifier, token string) error
|
||||
UpdateAccountTestStatus(identifier, status string) error
|
||||
AccountTestStatus(identifier string) (string, bool)
|
||||
Update(mutator func(*config.Config) error) error
|
||||
ExportJSONAndBase64() (string, string, error)
|
||||
IsEnvBacked() bool
|
||||
@@ -27,6 +28,8 @@ type ConfigStore interface {
|
||||
RuntimeAccountMaxInflight() int
|
||||
RuntimeAccountMaxQueue(defaultSize int) int
|
||||
RuntimeGlobalMaxInflight(defaultSize int) int
|
||||
RuntimeTokenRefreshIntervalHours() int
|
||||
AutoDeleteSessions() bool
|
||||
}
|
||||
|
||||
type PoolController interface {
|
||||
@@ -40,6 +43,8 @@ type DeepSeekCaller interface {
|
||||
CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
||||
GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
||||
CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
|
||||
GetSessionCountForToken(ctx context.Context, token string) (*deepseek.SessionStats, error)
|
||||
DeleteAllSessionsForToken(ctx context.Context, token string) error
|
||||
}
|
||||
|
||||
var _ ConfigStore = (*config.Store)(nil)
|
||||
|
||||
@@ -31,12 +31,15 @@ func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
pr.Get("/queue/status", h.queueStatus)
|
||||
pr.Post("/accounts/test", h.testSingleAccount)
|
||||
pr.Post("/accounts/test-all", h.testAllAccounts)
|
||||
pr.Post("/accounts/sessions/delete-all", h.deleteAllSessions)
|
||||
pr.Post("/import", h.batchImport)
|
||||
pr.Post("/test", h.testAPI)
|
||||
pr.Post("/vercel/sync", h.syncVercel)
|
||||
pr.Get("/vercel/status", h.vercelStatus)
|
||||
pr.Post("/vercel/status", h.vercelStatus)
|
||||
pr.Get("/export", h.exportConfig)
|
||||
pr.Get("/dev/captures", h.getDevCaptures)
|
||||
pr.Delete("/dev/captures", h.clearDevCaptures)
|
||||
pr.Get("/version", h.getVersion)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -54,6 +54,7 @@ func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
items := make([]map[string]any, 0, end-start)
|
||||
for _, acc := range accounts[start:end] {
|
||||
testStatus, _ := h.Store.AccountTestStatus(acc.Identifier())
|
||||
token := strings.TrimSpace(acc.Token)
|
||||
preview := ""
|
||||
if token != "" {
|
||||
@@ -70,7 +71,7 @@ func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
"has_password": acc.Password != "",
|
||||
"has_token": token != "",
|
||||
"token_preview": preview,
|
||||
"test_status": acc.TestStatus,
|
||||
"test_status": testStatus,
|
||||
})
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages})
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@@ -26,9 +25,9 @@ func newAdminTestHandler(t *testing.T, raw string) *Handler {
|
||||
}
|
||||
}
|
||||
|
||||
func TestListAccountsIncludesTokenOnlyIdentifier(t *testing.T) {
|
||||
func TestListAccountsUsesEmailIdentifier(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"accounts":[{"token":"token-only-account"}]
|
||||
"accounts":[{"email":"u@example.com","password":"pwd"}]
|
||||
}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/accounts?page=1&page_size=10", nil)
|
||||
@@ -49,38 +48,8 @@ func TestListAccountsIncludesTokenOnlyIdentifier(t *testing.T) {
|
||||
}
|
||||
first, _ := items[0].(map[string]any)
|
||||
identifier, _ := first["identifier"].(string)
|
||||
if identifier == "" {
|
||||
t.Fatalf("expected non-empty identifier: %#v", first)
|
||||
}
|
||||
if !strings.HasPrefix(identifier, "token:") {
|
||||
t.Fatalf("expected token synthetic identifier, got %q", identifier)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAccountSupportsTokenOnlyIdentifier(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"accounts":[{"token":"token-only-account"}]
|
||||
}`)
|
||||
accounts := h.Store.Accounts()
|
||||
if len(accounts) != 1 {
|
||||
t.Fatalf("expected 1 account, got %d", len(accounts))
|
||||
}
|
||||
id := accounts[0].Identifier()
|
||||
if id == "" {
|
||||
t.Fatal("expected token-only synthetic identifier")
|
||||
}
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Delete("/admin/accounts/{identifier}", h.deleteAccount)
|
||||
req := httptest.NewRequest(http.MethodDelete, "/admin/accounts/"+url.PathEscape(id), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if got := len(h.Store.Accounts()); got != 0 {
|
||||
t.Fatalf("expected account removed, remaining=%d", got)
|
||||
if identifier != "u@example.com" {
|
||||
t.Fatalf("expected email identifier, got %q", identifier)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,11 +111,10 @@ func TestAddAccountRejectsCanonicalMobileDuplicate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) {
|
||||
func TestFindAccountByIdentifierSupportsMobile(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"accounts":[
|
||||
{"email":"u@example.com","mobile":"13800138000","password":"pwd"},
|
||||
{"token":"token-only-account"}
|
||||
{"email":"u@example.com","mobile":"13800138000","password":"pwd"}
|
||||
]
|
||||
}`)
|
||||
|
||||
@@ -165,21 +133,4 @@ func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) {
|
||||
t.Fatalf("unexpected account by +86 mobile: %#v", accByMobileWithCountryCode)
|
||||
}
|
||||
|
||||
tokenOnlyID := ""
|
||||
for _, acc := range h.Store.Accounts() {
|
||||
if strings.TrimSpace(acc.Email) == "" && strings.TrimSpace(acc.Mobile) == "" {
|
||||
tokenOnlyID = acc.Identifier()
|
||||
break
|
||||
}
|
||||
}
|
||||
if tokenOnlyID == "" {
|
||||
t.Fatal("expected token-only account identifier")
|
||||
}
|
||||
accByTokenOnly, ok := findAccountByIdentifier(h.Store, tokenOnlyID)
|
||||
if !ok {
|
||||
t.Fatalf("expected find by token-only id=%q", tokenOnlyID)
|
||||
}
|
||||
if accByTokenOnly.Token != "token-only-account" {
|
||||
t.Fatalf("unexpected token-only account: %#v", accByTokenOnly)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,7 +89,15 @@ func runAccountTestsConcurrently(accounts []config.Account, maxConcurrency int,
|
||||
func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, message string) map[string]any {
|
||||
start := time.Now()
|
||||
identifier := acc.Identifier()
|
||||
result := map[string]any{"account": identifier, "success": false, "response_time": 0, "message": "", "model": model}
|
||||
result := map[string]any{
|
||||
"account": identifier,
|
||||
"success": false,
|
||||
"response_time": 0,
|
||||
"message": "",
|
||||
"model": model,
|
||||
"session_count": 0,
|
||||
"config_writable": !h.Store.IsEnvBacked(),
|
||||
}
|
||||
defer func() {
|
||||
status := "failed"
|
||||
if ok, _ := result["success"].(bool); ok {
|
||||
@@ -97,15 +105,14 @@ func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, me
|
||||
}
|
||||
_ = h.Store.UpdateAccountTestStatus(identifier, status)
|
||||
}()
|
||||
token := strings.TrimSpace(acc.Token)
|
||||
if token == "" {
|
||||
newToken, err := h.DS.Login(ctx, acc)
|
||||
if err != nil {
|
||||
result["message"] = "登录失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
token = newToken
|
||||
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
||||
token, err := h.DS.Login(ctx, acc)
|
||||
if err != nil {
|
||||
result["message"] = "登录失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
if err := h.Store.UpdateAccountToken(acc.Identifier(), token); err != nil {
|
||||
result["message"] = "登录成功但写入运行时 token 失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
authCtx := &authn.RequestAuth{UseConfigToken: false, DeepSeekToken: token}
|
||||
sessionID, err := h.DS.CreateSession(ctx, authCtx, 1)
|
||||
@@ -117,16 +124,26 @@ func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, me
|
||||
}
|
||||
token = newToken
|
||||
authCtx.DeepSeekToken = token
|
||||
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
||||
if err := h.Store.UpdateAccountToken(acc.Identifier(), token); err != nil {
|
||||
result["message"] = "刷新 token 成功但写入运行时 token 失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
sessionID, err = h.DS.CreateSession(ctx, authCtx, 1)
|
||||
if err != nil {
|
||||
result["message"] = "创建会话失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// 获取会话数量
|
||||
sessionStats, sessionErr := h.DS.GetSessionCountForToken(ctx, token)
|
||||
if sessionErr == nil && sessionStats != nil {
|
||||
result["session_count"] = sessionStats.FirstPageCount
|
||||
}
|
||||
|
||||
if strings.TrimSpace(message) == "" {
|
||||
result["success"] = true
|
||||
result["message"] = "API 测试成功(仅会话创建)"
|
||||
result["message"] = "Token 刷新成功(登录与会话创建成功)"
|
||||
result["response_time"] = int(time.Since(start).Milliseconds())
|
||||
return result
|
||||
}
|
||||
@@ -210,3 +227,45 @@ func (h *Handler) testAPI(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": false, "status_code": resp.StatusCode, "response": string(body)})
|
||||
}
|
||||
|
||||
func (h *Handler) deleteAllSessions(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
identifier, _ := req["identifier"].(string)
|
||||
if strings.TrimSpace(identifier) == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(identifier / email / mobile)"})
|
||||
return
|
||||
}
|
||||
acc, ok := findAccountByIdentifier(h.Store, identifier)
|
||||
if !ok {
|
||||
writeJSON(w, http.StatusNotFound, map[string]any{"detail": "账号不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 每次先登录刷新一次 token,避免使用过期 token。
|
||||
token, err := h.DS.Login(r.Context(), acc)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": false, "message": "登录失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
||||
|
||||
// 删除所有会话
|
||||
err = h.DS.DeleteAllSessionsForToken(r.Context(), token)
|
||||
if err != nil {
|
||||
// token 可能过期,尝试重新登录并重试一次
|
||||
newToken, loginErr := h.DS.Login(r.Context(), acc)
|
||||
if loginErr != nil {
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": false, "message": "删除失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
token = newToken
|
||||
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
||||
if retryErr := h.DS.DeleteAllSessionsForToken(r.Context(), token); retryErr != nil {
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": false, "message": "删除失败: " + retryErr.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "message": "删除成功"})
|
||||
}
|
||||
|
||||
@@ -1,21 +1,28 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
)
|
||||
|
||||
type testingDSMock struct {
|
||||
loginCalls int
|
||||
createSessionCalls int
|
||||
getPowCalls int
|
||||
callCompletionCalls int
|
||||
loginCalls int
|
||||
createSessionCalls int
|
||||
getPowCalls int
|
||||
callCompletionCalls int
|
||||
deleteAllSessionsCalls int
|
||||
deleteAllSessionsError error
|
||||
deleteAllSessionsErrorOnce bool
|
||||
}
|
||||
|
||||
func (m *testingDSMock) Login(_ context.Context, _ config.Account) (string, error) {
|
||||
@@ -38,6 +45,22 @@ func (m *testingDSMock) CallCompletion(_ context.Context, _ *auth.RequestAuth, _
|
||||
return nil, errors.New("should not call CallCompletion in this test")
|
||||
}
|
||||
|
||||
func (m *testingDSMock) DeleteAllSessionsForToken(_ context.Context, _ string) error {
|
||||
m.deleteAllSessionsCalls++
|
||||
if m.deleteAllSessionsError != nil {
|
||||
err := m.deleteAllSessionsError
|
||||
if m.deleteAllSessionsErrorOnce {
|
||||
m.deleteAllSessionsError = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testingDSMock) GetSessionCountForToken(_ context.Context, _ string) (*deepseek.SessionStats, error) {
|
||||
return &deepseek.SessionStats{Success: true}, nil
|
||||
}
|
||||
|
||||
func TestTestAccount_BatchModeOnlyCreatesSession(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"accounts":[{"email":"batch@example.com","password":"pwd","token":""}]}`)
|
||||
store := config.LoadStore()
|
||||
@@ -54,7 +77,7 @@ func TestTestAccount_BatchModeOnlyCreatesSession(t *testing.T) {
|
||||
t.Fatalf("expected success=true, got %#v", result)
|
||||
}
|
||||
msg, _ := result["message"].(string)
|
||||
if !strings.Contains(msg, "仅会话创建") {
|
||||
if !strings.Contains(msg, "Token 刷新成功") {
|
||||
t.Fatalf("expected session-only success message, got %q", msg)
|
||||
}
|
||||
if ds.loginCalls != 1 || ds.createSessionCalls != 1 {
|
||||
@@ -70,7 +93,43 @@ func TestTestAccount_BatchModeOnlyCreatesSession(t *testing.T) {
|
||||
if updated.Token != "new-token" {
|
||||
t.Fatalf("expected refreshed token to be persisted, got %q", updated.Token)
|
||||
}
|
||||
if updated.TestStatus != "ok" {
|
||||
t.Fatalf("expected test status ok, got %q", updated.TestStatus)
|
||||
testStatus, ok := store.AccountTestStatus("batch@example.com")
|
||||
if !ok || testStatus != "ok" {
|
||||
t.Fatalf("expected runtime test status ok, got %q (ok=%v)", testStatus, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAllSessions_RetryWithReloginOnDeleteFailure(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"accounts":[{"email":"batch@example.com","password":"pwd","token":"expired-token"}]}`)
|
||||
store := config.LoadStore()
|
||||
ds := &testingDSMock{deleteAllSessionsError: errors.New("token expired"), deleteAllSessionsErrorOnce: true}
|
||||
h := &Handler{Store: store, DS: ds}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/delete-all", bytes.NewBufferString(`{"identifier":"batch@example.com"}`))
|
||||
rec := httptest.NewRecorder()
|
||||
h.deleteAllSessions(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", rec.Code)
|
||||
}
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
if ok, _ := resp["success"].(bool); !ok {
|
||||
t.Fatalf("expected success response, got %#v", resp)
|
||||
}
|
||||
if ds.loginCalls != 2 {
|
||||
t.Fatalf("expected initial login plus relogin, got %d", ds.loginCalls)
|
||||
}
|
||||
if ds.deleteAllSessionsCalls != 2 {
|
||||
t.Fatalf("expected delete called twice, got %d", ds.deleteAllSessionsCalls)
|
||||
}
|
||||
updated, ok := store.FindAccount("batch@example.com")
|
||||
if !ok {
|
||||
t.Fatal("expected account")
|
||||
}
|
||||
if updated.Token != "new-token" {
|
||||
t.Fatalf("expected refreshed token persisted, got %q", updated.Token)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,6 +43,7 @@ func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
incoming.ClearAccountTokens()
|
||||
|
||||
importedKeys, importedAccounts := 0, 0
|
||||
err = h.Store.Update(func(c *config.Config) error {
|
||||
@@ -119,12 +120,6 @@ func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
|
||||
next.ModelAliases[k] = v
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(incoming.Toolcall.Mode) != "" {
|
||||
next.Toolcall.Mode = incoming.Toolcall.Mode
|
||||
}
|
||||
if strings.TrimSpace(incoming.Toolcall.EarlyEmitConfidence) != "" {
|
||||
next.Toolcall.EarlyEmitConfidence = incoming.Toolcall.EarlyEmitConfidence
|
||||
}
|
||||
if incoming.Responses.StoreTTLSeconds > 0 {
|
||||
next.Responses.StoreTTLSeconds = incoming.Responses.StoreTTLSeconds
|
||||
}
|
||||
@@ -149,6 +144,9 @@ func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
|
||||
if incoming.Runtime.GlobalMaxInflight > 0 {
|
||||
next.Runtime.GlobalMaxInflight = incoming.Runtime.GlobalMaxInflight
|
||||
}
|
||||
if incoming.Runtime.TokenRefreshIntervalHours > 0 {
|
||||
next.Runtime.TokenRefreshIntervalHours = incoming.Runtime.TokenRefreshIntervalHours
|
||||
}
|
||||
}
|
||||
|
||||
normalizeSettingsConfig(&next)
|
||||
@@ -180,6 +178,7 @@ func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func (h *Handler) computeSyncHash() string {
|
||||
snap := h.Store.Snapshot().Clone()
|
||||
snap.ClearAccountTokens()
|
||||
snap.VercelSyncHash = ""
|
||||
snap.VercelSyncTime = 0
|
||||
b, _ := json.Marshal(snap)
|
||||
|
||||
@@ -8,8 +8,9 @@ import (
|
||||
func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) {
|
||||
snap := h.Store.Snapshot()
|
||||
safe := map[string]any{
|
||||
"keys": snap.Keys,
|
||||
"accounts": []map[string]any{},
|
||||
"keys": snap.Keys,
|
||||
"accounts": []map[string]any{},
|
||||
"env_backed": h.Store.IsEnvBacked(),
|
||||
"claude_mapping": func() map[string]string {
|
||||
if len(snap.ClaudeMapping) > 0 {
|
||||
return snap.ClaudeMapping
|
||||
|
||||
@@ -50,9 +50,6 @@ func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.TrimSpace(acc.Password) == "" {
|
||||
acc.Password = prev.Password
|
||||
}
|
||||
if strings.TrimSpace(acc.Token) == "" {
|
||||
acc.Token = prev.Token
|
||||
}
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
accounts = append(accounts, acc)
|
||||
|
||||
@@ -7,15 +7,29 @@ import (
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ToolcallConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, map[string]string, map[string]string, error) {
|
||||
func boolFrom(v any) bool {
|
||||
if v == nil {
|
||||
return false
|
||||
}
|
||||
switch x := v.(type) {
|
||||
case bool:
|
||||
return x
|
||||
case string:
|
||||
return strings.ToLower(strings.TrimSpace(x)) == "true"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, *config.AutoDeleteConfig, map[string]string, map[string]string, error) {
|
||||
var (
|
||||
adminCfg *config.AdminConfig
|
||||
runtimeCfg *config.RuntimeConfig
|
||||
toolcallCfg *config.ToolcallConfig
|
||||
respCfg *config.ResponsesConfig
|
||||
embCfg *config.EmbeddingsConfig
|
||||
claudeMap map[string]string
|
||||
aliasMap map[string]string
|
||||
adminCfg *config.AdminConfig
|
||||
runtimeCfg *config.RuntimeConfig
|
||||
respCfg *config.ResponsesConfig
|
||||
embCfg *config.EmbeddingsConfig
|
||||
autoDeleteCfg *config.AutoDeleteConfig
|
||||
claudeMap map[string]string
|
||||
aliasMap map[string]string
|
||||
)
|
||||
|
||||
if raw, ok := req["admin"].(map[string]any); ok {
|
||||
@@ -53,35 +67,19 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
}
|
||||
cfg.GlobalMaxInflight = n
|
||||
}
|
||||
if v, exists := raw["token_refresh_interval_hours"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 1 || n > 720 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.token_refresh_interval_hours must be between 1 and 720")
|
||||
}
|
||||
cfg.TokenRefreshIntervalHours = n
|
||||
}
|
||||
if cfg.AccountMaxInflight > 0 && cfg.GlobalMaxInflight > 0 && cfg.GlobalMaxInflight < cfg.AccountMaxInflight {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
|
||||
}
|
||||
runtimeCfg = cfg
|
||||
}
|
||||
|
||||
if raw, ok := req["toolcall"].(map[string]any); ok {
|
||||
cfg := &config.ToolcallConfig{}
|
||||
if v, exists := raw["mode"]; exists {
|
||||
mode := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
|
||||
switch mode {
|
||||
case "feature_match", "off":
|
||||
cfg.Mode = mode
|
||||
default:
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.mode must be feature_match or off")
|
||||
}
|
||||
}
|
||||
if v, exists := raw["early_emit_confidence"]; exists {
|
||||
level := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
|
||||
switch level {
|
||||
case "high", "low", "off":
|
||||
cfg.EarlyEmitConfidence = level
|
||||
default:
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.early_emit_confidence must be high, low or off")
|
||||
}
|
||||
}
|
||||
toolcallCfg = cfg
|
||||
}
|
||||
|
||||
if raw, ok := req["responses"].(map[string]any); ok {
|
||||
cfg := &config.ResponsesConfig{}
|
||||
if v, exists := raw["store_ttl_seconds"]; exists {
|
||||
@@ -98,9 +96,6 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
cfg := &config.EmbeddingsConfig{}
|
||||
if v, exists := raw["provider"]; exists {
|
||||
p := strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
if p == "" {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("embeddings.provider cannot be empty")
|
||||
}
|
||||
cfg.Provider = p
|
||||
}
|
||||
embCfg = cfg
|
||||
@@ -130,5 +125,13 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
}
|
||||
}
|
||||
|
||||
return adminCfg, runtimeCfg, toolcallCfg, respCfg, embCfg, claudeMap, aliasMap, nil
|
||||
if raw, ok := req["auto_delete"].(map[string]any); ok {
|
||||
cfg := &config.AutoDeleteConfig{}
|
||||
if v, exists := raw["sessions"]; exists {
|
||||
cfg.Sessions = boolFrom(v)
|
||||
}
|
||||
autoDeleteCfg = cfg
|
||||
}
|
||||
|
||||
return adminCfg, runtimeCfg, respCfg, embCfg, autoDeleteCfg, claudeMap, aliasMap, nil
|
||||
}
|
||||
|
||||
@@ -21,13 +21,14 @@ func (h *Handler) getSettings(w http.ResponseWriter, _ *http.Request) {
|
||||
"default_password_warning": authn.UsingDefaultAdminKey(h.Store),
|
||||
},
|
||||
"runtime": map[string]any{
|
||||
"account_max_inflight": h.Store.RuntimeAccountMaxInflight(),
|
||||
"account_max_queue": h.Store.RuntimeAccountMaxQueue(recommended),
|
||||
"global_max_inflight": h.Store.RuntimeGlobalMaxInflight(recommended),
|
||||
"account_max_inflight": h.Store.RuntimeAccountMaxInflight(),
|
||||
"account_max_queue": h.Store.RuntimeAccountMaxQueue(recommended),
|
||||
"global_max_inflight": h.Store.RuntimeGlobalMaxInflight(recommended),
|
||||
"token_refresh_interval_hours": h.Store.RuntimeTokenRefreshIntervalHours(),
|
||||
},
|
||||
"toolcall": snap.Toolcall,
|
||||
"responses": snap.Responses,
|
||||
"embeddings": snap.Embeddings,
|
||||
"auto_delete": snap.AutoDelete,
|
||||
"claude_mapping": settingsClaudeMapping(snap),
|
||||
"model_aliases": snap.ModelAliases,
|
||||
"env_backed": h.Store.IsEnvBacked(),
|
||||
|
||||
@@ -14,6 +14,9 @@ func validateMergedRuntimeSettings(current config.RuntimeConfig, incoming *confi
|
||||
if incoming.GlobalMaxInflight > 0 {
|
||||
merged.GlobalMaxInflight = incoming.GlobalMaxInflight
|
||||
}
|
||||
if incoming.TokenRefreshIntervalHours > 0 {
|
||||
merged.TokenRefreshIntervalHours = incoming.TokenRefreshIntervalHours
|
||||
}
|
||||
}
|
||||
return validateRuntimeSettings(merged)
|
||||
}
|
||||
|
||||
@@ -28,6 +28,25 @@ func TestGetSettingsDefaultPasswordWarning(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSettingsIncludesTokenRefreshInterval(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"keys":["k1"],
|
||||
"runtime":{"token_refresh_interval_hours":9}
|
||||
}`)
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/settings", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
h.getSettings(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var body map[string]any
|
||||
_ = json.Unmarshal(rec.Body.Bytes(), &body)
|
||||
runtime, _ := body["runtime"].(map[string]any)
|
||||
if got := intFrom(runtime["token_refresh_interval_hours"]); got != 9 {
|
||||
t.Fatalf("expected token_refresh_interval_hours=9, got %d body=%v", got, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSettingsValidation(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{"keys":["k1"]}`)
|
||||
payload := map[string]any{
|
||||
@@ -44,6 +63,25 @@ func TestUpdateSettingsValidation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSettingsValidationRejectsTokenRefreshInterval(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{"keys":["k1"]}`)
|
||||
payload := map[string]any{
|
||||
"runtime": map[string]any{
|
||||
"token_refresh_interval_hours": 0,
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b))
|
||||
rec := httptest.NewRecorder()
|
||||
h.updateSettings(rec, req)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if !bytes.Contains(rec.Body.Bytes(), []byte("runtime.token_refresh_interval_hours")) {
|
||||
t.Fatalf("expected token refresh validation detail, got %s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSettingsValidationWithMergedRuntimeSnapshot(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"keys":["k1"],
|
||||
@@ -126,6 +164,29 @@ func TestUpdateSettingsHotReloadRuntime(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSettingsHotReloadTokenRefreshInterval(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"keys":["k1"],
|
||||
"runtime":{"token_refresh_interval_hours":6}
|
||||
}`)
|
||||
|
||||
payload := map[string]any{
|
||||
"runtime": map[string]any{
|
||||
"token_refresh_interval_hours": 12,
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b))
|
||||
rec := httptest.NewRecorder()
|
||||
h.updateSettings(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if got := h.Store.RuntimeTokenRefreshIntervalHours(); got != 12 {
|
||||
t.Fatalf("token_refresh_interval_hours=%d want=12", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSettingsPasswordInvalidatesOldJWT(t *testing.T) {
|
||||
hash := authn.HashAdminPassword("old-password")
|
||||
h := newAdminTestHandler(t, `{"admin":{"password_hash":"`+hash+`"}}`)
|
||||
@@ -207,6 +268,30 @@ func TestConfigImportMergeAndReplace(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigImportAppliesTokenRefreshInterval(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{"keys":["k1"]}`)
|
||||
|
||||
replace := map[string]any{
|
||||
"mode": "replace",
|
||||
"config": map[string]any{
|
||||
"keys": []any{"k9"},
|
||||
"runtime": map[string]any{
|
||||
"token_refresh_interval_hours": 11,
|
||||
},
|
||||
},
|
||||
}
|
||||
replaceBytes, _ := json.Marshal(replace)
|
||||
replaceReq := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=replace", bytes.NewReader(replaceBytes))
|
||||
replaceRec := httptest.NewRecorder()
|
||||
h.configImport(replaceRec, replaceReq)
|
||||
if replaceRec.Code != http.StatusOK {
|
||||
t.Fatalf("replace status=%d body=%s", replaceRec.Code, replaceRec.Body.String())
|
||||
}
|
||||
if got := h.Store.RuntimeTokenRefreshIntervalHours(); got != 11 {
|
||||
t.Fatalf("token_refresh_interval_hours=%d want=11", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigImportRejectsInvalidRuntimeBounds(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{"keys":["k1"]}`)
|
||||
payload := map[string]any{
|
||||
|
||||
@@ -17,7 +17,7 @@ func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
adminCfg, runtimeCfg, toolcallCfg, responsesCfg, embeddingsCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req)
|
||||
adminCfg, runtimeCfg, responsesCfg, embeddingsCfg, autoDeleteCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
@@ -45,13 +45,8 @@ func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) {
|
||||
if runtimeCfg.GlobalMaxInflight > 0 {
|
||||
c.Runtime.GlobalMaxInflight = runtimeCfg.GlobalMaxInflight
|
||||
}
|
||||
}
|
||||
if toolcallCfg != nil {
|
||||
if strings.TrimSpace(toolcallCfg.Mode) != "" {
|
||||
c.Toolcall.Mode = strings.TrimSpace(toolcallCfg.Mode)
|
||||
}
|
||||
if strings.TrimSpace(toolcallCfg.EarlyEmitConfidence) != "" {
|
||||
c.Toolcall.EarlyEmitConfidence = strings.TrimSpace(toolcallCfg.EarlyEmitConfidence)
|
||||
if runtimeCfg.TokenRefreshIntervalHours > 0 {
|
||||
c.Runtime.TokenRefreshIntervalHours = runtimeCfg.TokenRefreshIntervalHours
|
||||
}
|
||||
}
|
||||
if responsesCfg != nil && responsesCfg.StoreTTLSeconds > 0 {
|
||||
@@ -60,6 +55,9 @@ func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) {
|
||||
if embeddingsCfg != nil && strings.TrimSpace(embeddingsCfg.Provider) != "" {
|
||||
c.Embeddings.Provider = strings.TrimSpace(embeddingsCfg.Provider)
|
||||
}
|
||||
if autoDeleteCfg != nil {
|
||||
c.AutoDelete.Sessions = autoDeleteCfg.Sessions
|
||||
}
|
||||
if claudeMap != nil {
|
||||
c.ClaudeMapping = claudeMap
|
||||
c.ClaudeModelMap = nil
|
||||
|
||||
@@ -3,6 +3,8 @@ package admin
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -11,6 +13,8 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func (h *Handler) syncVercel(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -25,7 +29,7 @@ func (h *Handler) syncVercel(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
validated, failed := h.validateAccountsForVercelSync(r.Context(), opts.AutoValidate)
|
||||
_, cfgB64, err := h.Store.ExportJSONAndBase64()
|
||||
cfgJSON, cfgB64, err := h.exportSyncConfig(req)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
@@ -47,7 +51,7 @@ func (h *Handler) syncVercel(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
savedCreds := h.saveVercelProjectCredentials(r.Context(), client, opts, params, headers, envs)
|
||||
manual, deployURL := triggerVercelDeployment(r.Context(), client, opts.ProjectID, params, headers)
|
||||
_ = h.Store.SetVercelSync(h.computeSyncHash(), time.Now().Unix())
|
||||
_ = h.Store.SetVercelSync(syncHashForJSON(cfgJSON), time.Now().Unix())
|
||||
result := map[string]any{"success": true, "validated_accounts": validated}
|
||||
if manual {
|
||||
result["message"] = "配置已同步到 Vercel,请手动触发重新部署"
|
||||
@@ -209,11 +213,71 @@ func triggerVercelDeployment(ctx context.Context, client *http.Client, projectID
|
||||
return false, deployURL
|
||||
}
|
||||
|
||||
func (h *Handler) vercelStatus(w http.ResponseWriter, _ *http.Request) {
|
||||
func (h *Handler) vercelStatus(w http.ResponseWriter, r *http.Request) {
|
||||
snap := h.Store.Snapshot()
|
||||
current := h.computeSyncHash()
|
||||
synced := snap.VercelSyncHash != "" && snap.VercelSyncHash == current
|
||||
writeJSON(w, http.StatusOK, map[string]any{"synced": synced, "last_sync_time": nilIfZero(snap.VercelSyncTime), "has_synced_before": snap.VercelSyncHash != ""})
|
||||
draftHash := ""
|
||||
draftDiffers := false
|
||||
if r != nil && r.Method == http.MethodPost && r.Body != nil {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err == nil {
|
||||
if cfgJSON, _, err := h.exportSyncConfig(req); err == nil {
|
||||
draftHash = syncHashForJSON(cfgJSON)
|
||||
draftDiffers = draftHash != "" && draftHash != current
|
||||
}
|
||||
}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"synced": synced,
|
||||
"last_sync_time": nilIfZero(snap.VercelSyncTime),
|
||||
"has_synced_before": snap.VercelSyncHash != "",
|
||||
"env_backed": h.Store.IsEnvBacked(),
|
||||
"config_hash": current,
|
||||
"last_synced_hash": snap.VercelSyncHash,
|
||||
"draft_hash": draftHash,
|
||||
"draft_differs": draftDiffers,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) exportSyncConfig(req map[string]any) (string, string, error) {
|
||||
override, ok := req["config_override"]
|
||||
if !ok || override == nil {
|
||||
return h.Store.ExportJSONAndBase64()
|
||||
}
|
||||
raw, err := json.Marshal(override)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
var cfg config.Config
|
||||
if err := json.Unmarshal(raw, &cfg); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
cfg.DropInvalidAccounts()
|
||||
cfg.ClearAccountTokens()
|
||||
cfg.VercelSyncHash = ""
|
||||
cfg.VercelSyncTime = 0
|
||||
b, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return string(b), base64.StdEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func syncHashForJSON(s string) string {
|
||||
var cfg config.Config
|
||||
if err := json.Unmarshal([]byte(s), &cfg); err != nil {
|
||||
return ""
|
||||
}
|
||||
cfg.VercelSyncHash = ""
|
||||
cfg.VercelSyncTime = 0
|
||||
cfg.ClearAccountTokens()
|
||||
b, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
sum := md5.Sum(b)
|
||||
return fmt.Sprintf("%x", sum)
|
||||
}
|
||||
|
||||
func vercelRequest(ctx context.Context, client *http.Client, method, endpoint string, params url.Values, headers map[string]string, body any) (map[string]any, int, error) {
|
||||
|
||||
75
internal/admin/handler_version.go
Normal file
75
internal/admin/handler_version.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/version"
|
||||
)
|
||||
|
||||
const latestReleaseAPI = "https://api.github.com/repos/CJackHwang/ds2api/releases/latest"
|
||||
|
||||
type latestReleasePayload struct {
|
||||
TagName string `json:"tag_name"`
|
||||
HTMLURL string `json:"html_url"`
|
||||
PublishedAt string `json:"published_at"`
|
||||
}
|
||||
|
||||
func (h *Handler) getVersion(w http.ResponseWriter, _ *http.Request) {
|
||||
current, source := version.Current()
|
||||
resp := map[string]any{
|
||||
"success": true,
|
||||
"current_version": current,
|
||||
"current_tag": version.Tag(current),
|
||||
"source": source,
|
||||
"checked_at": time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, latestReleaseAPI, nil)
|
||||
if err != nil {
|
||||
resp["check_error"] = err.Error()
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
req.Header.Set("User-Agent", "ds2api-version-check")
|
||||
|
||||
client := &http.Client{Timeout: 4 * time.Second}
|
||||
r, err := client.Do(req)
|
||||
if err != nil {
|
||||
resp["check_error"] = err.Error()
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
if r.StatusCode < 200 || r.StatusCode >= 300 {
|
||||
resp["check_error"] = "github api status: " + r.Status
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
|
||||
var data latestReleasePayload
|
||||
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
|
||||
resp["check_error"] = err.Error()
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
|
||||
latest := strings.TrimSpace(data.TagName)
|
||||
if latest == "" {
|
||||
resp["check_error"] = "missing latest tag"
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
latestVersion := strings.TrimPrefix(latest, "v")
|
||||
|
||||
resp["latest_tag"] = latest
|
||||
resp["latest_version"] = latestVersion
|
||||
resp["release_url"] = data.HTMLURL
|
||||
resp["published_at"] = data.PublishedAt
|
||||
resp["has_update"] = version.Compare(current, latestVersion) < 0
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
@@ -65,7 +65,6 @@ func toAccount(m map[string]any) config.Account {
|
||||
Email: email,
|
||||
Mobile: mobile,
|
||||
Password: fieldString(m, "password"),
|
||||
Token: fieldString(m, "token"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -188,8 +188,8 @@ func TestToAccountAllFields(t *testing.T) {
|
||||
if acc.Password != "secret" {
|
||||
t.Fatalf("unexpected password: %q", acc.Password)
|
||||
}
|
||||
if acc.Token != "tok123" {
|
||||
t.Fatalf("unexpected token: %q", acc.Token)
|
||||
if acc.Token != "" {
|
||||
t.Fatalf("expected token to be ignored, got %q", acc.Token)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,8 +12,6 @@ func normalizeSettingsConfig(c *config.Config) {
|
||||
return
|
||||
}
|
||||
c.Admin.PasswordHash = strings.TrimSpace(c.Admin.PasswordHash)
|
||||
c.Toolcall.Mode = strings.ToLower(strings.TrimSpace(c.Toolcall.Mode))
|
||||
c.Toolcall.EarlyEmitConfidence = strings.ToLower(strings.TrimSpace(c.Toolcall.EarlyEmitConfidence))
|
||||
c.Embeddings.Provider = strings.TrimSpace(c.Embeddings.Provider)
|
||||
}
|
||||
|
||||
@@ -27,20 +25,6 @@ func validateSettingsConfig(c config.Config) error {
|
||||
if c.Responses.StoreTTLSeconds != 0 && (c.Responses.StoreTTLSeconds < 30 || c.Responses.StoreTTLSeconds > 86400) {
|
||||
return fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400")
|
||||
}
|
||||
if mode := strings.TrimSpace(c.Toolcall.Mode); mode != "" {
|
||||
switch mode {
|
||||
case "feature_match", "off":
|
||||
default:
|
||||
return fmt.Errorf("toolcall.mode must be feature_match or off")
|
||||
}
|
||||
}
|
||||
if level := strings.TrimSpace(c.Toolcall.EarlyEmitConfidence); level != "" {
|
||||
switch level {
|
||||
case "high", "low", "off":
|
||||
default:
|
||||
return fmt.Errorf("toolcall.early_emit_confidence must be high, low or off")
|
||||
}
|
||||
}
|
||||
if c.Embeddings.Provider != "" && strings.TrimSpace(c.Embeddings.Provider) == "" {
|
||||
return fmt.Errorf("embeddings.provider cannot be empty")
|
||||
}
|
||||
@@ -57,6 +41,9 @@ func validateRuntimeSettings(runtime config.RuntimeConfig) error {
|
||||
if runtime.GlobalMaxInflight != 0 && (runtime.GlobalMaxInflight < 1 || runtime.GlobalMaxInflight > 200000) {
|
||||
return fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000")
|
||||
}
|
||||
if runtime.TokenRefreshIntervalHours != 0 && (runtime.TokenRefreshIntervalHours < 1 || runtime.TokenRefreshIntervalHours > 720) {
|
||||
return fmt.Errorf("runtime.token_refresh_interval_hours must be between 1 and 720")
|
||||
}
|
||||
if runtime.AccountMaxInflight > 0 && runtime.GlobalMaxInflight > 0 && runtime.GlobalMaxInflight < runtime.AccountMaxInflight {
|
||||
return fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
|
||||
}
|
||||
|
||||
109
internal/admin/token_runtime_http_test.go
Normal file
109
internal/admin/token_runtime_http_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/account"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func newHTTPAdminHarness(t *testing.T, rawConfig string, ds DeepSeekCaller) http.Handler {
|
||||
t.Helper()
|
||||
t.Setenv("DS2API_CONFIG_JSON", rawConfig)
|
||||
t.Setenv("CONFIG_JSON", "")
|
||||
store := config.LoadStore()
|
||||
h := &Handler{
|
||||
Store: store,
|
||||
Pool: account.NewPool(store),
|
||||
DS: ds,
|
||||
}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
return r
|
||||
}
|
||||
|
||||
func adminReq(method, path string, body []byte) *http.Request {
|
||||
req := httptest.NewRequest(method, path, bytes.NewReader(body))
|
||||
req.Header.Set("Authorization", "Bearer admin")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
return req
|
||||
}
|
||||
|
||||
func TestConfigImportIgnoresTokenFieldInPayload(t *testing.T) {
|
||||
ds := &testingDSMock{}
|
||||
router := newHTTPAdminHarness(t, `{"accounts":[]}`, ds)
|
||||
|
||||
payload := []byte(`{
|
||||
"mode":"replace",
|
||||
"config":{
|
||||
"accounts":[{"email":"u@example.com","password":"pwd","token":"expired-token"}]
|
||||
}
|
||||
}`)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, adminReq(http.MethodPost, "/config/import", payload))
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("import status=%d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
readRec := httptest.NewRecorder()
|
||||
router.ServeHTTP(readRec, adminReq(http.MethodGet, "/config", nil))
|
||||
if readRec.Code != http.StatusOK {
|
||||
t.Fatalf("get config status=%d body=%s", readRec.Code, readRec.Body.String())
|
||||
}
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal(readRec.Body.Bytes(), &data); err != nil {
|
||||
t.Fatalf("decode config response: %v", err)
|
||||
}
|
||||
accounts, _ := data["accounts"].([]any)
|
||||
if len(accounts) != 1 {
|
||||
t.Fatalf("expected one account, got %d", len(accounts))
|
||||
}
|
||||
accountMap, _ := accounts[0].(map[string]any)
|
||||
if hasToken, _ := accountMap["has_token"].(bool); hasToken {
|
||||
t.Fatalf("expected imported token to be ignored, account=%#v", accountMap)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountTestRefreshesRuntimeTokenButExportOmitsToken(t *testing.T) {
|
||||
ds := &testingDSMock{}
|
||||
router := newHTTPAdminHarness(t, `{
|
||||
"accounts":[{"email":"batch@example.com","password":"pwd","token":"stale-token"}]
|
||||
}`, ds)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, adminReq(http.MethodPost, "/accounts/test", []byte(`{"identifier":"batch@example.com"}`)))
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("test account status=%d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var testResp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &testResp); err != nil {
|
||||
t.Fatalf("decode test response: %v", err)
|
||||
}
|
||||
if ok, _ := testResp["success"].(bool); !ok {
|
||||
t.Fatalf("expected test success, got %#v", testResp)
|
||||
}
|
||||
if ds.loginCalls < 1 {
|
||||
t.Fatalf("expected login to be called at least once, got %d", ds.loginCalls)
|
||||
}
|
||||
|
||||
exportRec := httptest.NewRecorder()
|
||||
router.ServeHTTP(exportRec, adminReq(http.MethodGet, "/config/export", nil))
|
||||
if exportRec.Code != http.StatusOK {
|
||||
t.Fatalf("export status=%d body=%s", exportRec.Code, exportRec.Body.String())
|
||||
}
|
||||
var exportResp map[string]any
|
||||
if err := json.Unmarshal(exportRec.Body.Bytes(), &exportResp); err != nil {
|
||||
t.Fatalf("decode export response: %v", err)
|
||||
}
|
||||
exportJSON, _ := exportResp["json"].(string)
|
||||
if strings.Contains(exportJSON, `"token"`) {
|
||||
t.Fatalf("expected export json to omit tokens, got %s", exportJSON)
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/account"
|
||||
"ds2api/internal/config"
|
||||
@@ -37,10 +39,18 @@ type Resolver struct {
|
||||
Store *config.Store
|
||||
Pool *account.Pool
|
||||
Login LoginFunc
|
||||
|
||||
mu sync.Mutex
|
||||
tokenRefreshedAt map[string]time.Time
|
||||
}
|
||||
|
||||
func NewResolver(store *config.Store, pool *account.Pool, login LoginFunc) *Resolver {
|
||||
return &Resolver{Store: store, Pool: pool, Login: login}
|
||||
return &Resolver{
|
||||
Store: store,
|
||||
Pool: pool,
|
||||
Login: login,
|
||||
tokenRefreshedAt: map[string]time.Time{},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
|
||||
@@ -72,13 +82,9 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
|
||||
TriedAccounts: map[string]bool{},
|
||||
resolver: r,
|
||||
}
|
||||
if acc.Token == "" {
|
||||
if err := r.loginAndPersist(ctx, a); err != nil {
|
||||
r.Pool.Release(a.AccountID)
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
a.DeepSeekToken = acc.Token
|
||||
if err := r.ensureManagedToken(ctx, a); err != nil {
|
||||
r.Pool.Release(a.AccountID)
|
||||
return nil, err
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
@@ -120,6 +126,7 @@ func (r *Resolver) loginAndPersist(ctx context.Context, a *RequestAuth) error {
|
||||
}
|
||||
a.Account.Token = token
|
||||
a.DeepSeekToken = token
|
||||
r.markTokenRefreshedNow(a.AccountID)
|
||||
return r.Store.UpdateAccountToken(a.AccountID, token)
|
||||
}
|
||||
|
||||
@@ -142,6 +149,7 @@ func (r *Resolver) MarkTokenInvalid(a *RequestAuth) {
|
||||
}
|
||||
a.Account.Token = ""
|
||||
a.DeepSeekToken = ""
|
||||
r.clearTokenRefreshMark(a.AccountID)
|
||||
_ = r.Store.UpdateAccountToken(a.AccountID, "")
|
||||
}
|
||||
|
||||
@@ -162,12 +170,8 @@ func (r *Resolver) SwitchAccount(ctx context.Context, a *RequestAuth) bool {
|
||||
}
|
||||
a.Account = acc
|
||||
a.AccountID = acc.Identifier()
|
||||
if acc.Token == "" {
|
||||
if err := r.loginAndPersist(ctx, a); err != nil {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
a.DeepSeekToken = acc.Token
|
||||
if err := r.ensureManagedToken(ctx, a); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -210,3 +214,57 @@ func callerTokenID(token string) string {
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return "caller:" + hex.EncodeToString(sum[:8])
|
||||
}
|
||||
|
||||
func (r *Resolver) ensureManagedToken(ctx context.Context, a *RequestAuth) error {
|
||||
if strings.TrimSpace(a.Account.Token) == "" {
|
||||
return r.loginAndPersist(ctx, a)
|
||||
}
|
||||
if r.shouldForceRefresh(a.AccountID) {
|
||||
if err := r.loginAndPersist(ctx, a); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
a.DeepSeekToken = a.Account.Token
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Resolver) shouldForceRefresh(accountID string) bool {
|
||||
if r == nil || r.Store == nil {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(accountID) == "" {
|
||||
return false
|
||||
}
|
||||
intervalHours := r.Store.RuntimeTokenRefreshIntervalHours()
|
||||
if intervalHours <= 0 {
|
||||
return false
|
||||
}
|
||||
now := time.Now()
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
last, ok := r.tokenRefreshedAt[accountID]
|
||||
if !ok || last.IsZero() {
|
||||
r.tokenRefreshedAt[accountID] = now
|
||||
return false
|
||||
}
|
||||
return now.Sub(last) >= time.Duration(intervalHours)*time.Hour
|
||||
}
|
||||
|
||||
func (r *Resolver) markTokenRefreshedNow(accountID string) {
|
||||
if strings.TrimSpace(accountID) == "" {
|
||||
return
|
||||
}
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.tokenRefreshedAt[accountID] = time.Now()
|
||||
}
|
||||
|
||||
func (r *Resolver) clearTokenRefreshMark(accountID string) {
|
||||
if strings.TrimSpace(accountID) == "" {
|
||||
return
|
||||
}
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
delete(r.tokenRefreshedAt, accountID)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,9 @@ package auth
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/account"
|
||||
"ds2api/internal/config"
|
||||
@@ -58,7 +60,7 @@ func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) {
|
||||
if auth.AccountID != "acc@example.com" {
|
||||
t.Fatalf("unexpected account id: %q", auth.AccountID)
|
||||
}
|
||||
if auth.DeepSeekToken != "account-token" {
|
||||
if auth.DeepSeekToken != "fresh-token" {
|
||||
t.Fatalf("unexpected account token: %q", auth.DeepSeekToken)
|
||||
}
|
||||
if auth.CallerID == "" {
|
||||
@@ -193,3 +195,109 @@ func TestDetermineCallerMissingToken(t *testing.T) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetermineManagedAccountForcesRefreshEverySixHours(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["managed-key"],
|
||||
"accounts":[{"email":"acc@example.com","password":"pwd","token":"seed-token"}]
|
||||
}`)
|
||||
store := config.LoadStore()
|
||||
if err := store.UpdateAccountToken("acc@example.com", "seed-token"); err != nil {
|
||||
t.Fatalf("update token failed: %v", err)
|
||||
}
|
||||
pool := account.NewPool(store)
|
||||
|
||||
var loginCount int32
|
||||
resolver := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
|
||||
n := atomic.AddInt32(&loginCount, 1)
|
||||
return "fresh-token-" + string(rune('0'+n)), nil
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
req.Header.Set("x-api-key", "managed-key")
|
||||
|
||||
a1, err := resolver.Determine(req)
|
||||
if err != nil {
|
||||
t.Fatalf("determine failed: %v", err)
|
||||
}
|
||||
if a1.DeepSeekToken != "seed-token" {
|
||||
t.Fatalf("expected initial token without forced refresh, got %q", a1.DeepSeekToken)
|
||||
}
|
||||
resolver.Release(a1)
|
||||
if got := atomic.LoadInt32(&loginCount); got != 0 {
|
||||
t.Fatalf("expected no login before refresh interval, got %d", got)
|
||||
}
|
||||
|
||||
resolver.mu.Lock()
|
||||
resolver.tokenRefreshedAt["acc@example.com"] = time.Now().Add(-7 * time.Hour)
|
||||
resolver.mu.Unlock()
|
||||
|
||||
a2, err := resolver.Determine(req)
|
||||
if err != nil {
|
||||
t.Fatalf("determine after interval failed: %v", err)
|
||||
}
|
||||
defer resolver.Release(a2)
|
||||
if a2.DeepSeekToken != "fresh-token-1" {
|
||||
t.Fatalf("expected refreshed token after interval, got %q", a2.DeepSeekToken)
|
||||
}
|
||||
if got := atomic.LoadInt32(&loginCount); got != 1 {
|
||||
t.Fatalf("expected exactly one forced refresh login, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetermineManagedAccountUsesUpdatedRefreshInterval(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["managed-key"],
|
||||
"accounts":[{"email":"acc@example.com","password":"pwd","token":"seed-token"}],
|
||||
"runtime":{"token_refresh_interval_hours":6}
|
||||
}`)
|
||||
store := config.LoadStore()
|
||||
if err := store.UpdateAccountToken("acc@example.com", "seed-token"); err != nil {
|
||||
t.Fatalf("update token failed: %v", err)
|
||||
}
|
||||
pool := account.NewPool(store)
|
||||
|
||||
var loginCount int32
|
||||
resolver := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
|
||||
n := atomic.AddInt32(&loginCount, 1)
|
||||
return "fresh-token-" + string(rune('0'+n)), nil
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
req.Header.Set("x-api-key", "managed-key")
|
||||
|
||||
a1, err := resolver.Determine(req)
|
||||
if err != nil {
|
||||
t.Fatalf("determine failed: %v", err)
|
||||
}
|
||||
if a1.DeepSeekToken != "seed-token" {
|
||||
t.Fatalf("expected initial token without forced refresh, got %q", a1.DeepSeekToken)
|
||||
}
|
||||
resolver.Release(a1)
|
||||
if got := atomic.LoadInt32(&loginCount); got != 0 {
|
||||
t.Fatalf("expected no login before runtime update, got %d", got)
|
||||
}
|
||||
|
||||
if err := store.Update(func(c *config.Config) error {
|
||||
c.Runtime.TokenRefreshIntervalHours = 1
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatalf("update runtime failed: %v", err)
|
||||
}
|
||||
|
||||
resolver.mu.Lock()
|
||||
resolver.tokenRefreshedAt["acc@example.com"] = time.Now().Add(-2 * time.Hour)
|
||||
resolver.mu.Unlock()
|
||||
|
||||
a2, err := resolver.Determine(req)
|
||||
if err != nil {
|
||||
t.Fatalf("determine after runtime update failed: %v", err)
|
||||
}
|
||||
defer resolver.Release(a2)
|
||||
if a2.DeepSeekToken != "fresh-token-1" {
|
||||
t.Fatalf("expected refreshed token after runtime update, got %q", a2.DeepSeekToken)
|
||||
}
|
||||
if got := atomic.LoadInt32(&loginCount); got != 1 {
|
||||
t.Fatalf("expected exactly one login after runtime update, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,22 +73,31 @@ func TestGoCompatToolcallFixtures(t *testing.T) {
|
||||
mustLoadJSON(t, fixturePath, &fixture)
|
||||
|
||||
var expected struct {
|
||||
Calls []util.ParsedToolCall `json:"calls"`
|
||||
Calls []util.ParsedToolCall `json:"calls"`
|
||||
SawToolCallSyntax bool `json:"sawToolCallSyntax"`
|
||||
RejectedByPolicy bool `json:"rejectedByPolicy"`
|
||||
RejectedToolNames []string `json:"rejectedToolNames"`
|
||||
}
|
||||
mustLoadJSON(t, expectedPath, &expected)
|
||||
|
||||
var got []util.ParsedToolCall
|
||||
var got util.ToolCallParseResult
|
||||
switch strings.ToLower(strings.TrimSpace(fixture.Mode)) {
|
||||
case "standalone":
|
||||
got = util.ParseStandaloneToolCalls(fixture.Text, fixture.ToolNames)
|
||||
got = util.ParseStandaloneToolCallsDetailed(fixture.Text, fixture.ToolNames)
|
||||
default:
|
||||
got = util.ParseToolCalls(fixture.Text, fixture.ToolNames)
|
||||
got = util.ParseToolCallsDetailed(fixture.Text, fixture.ToolNames)
|
||||
}
|
||||
if len(got) == 0 && len(expected.Calls) == 0 {
|
||||
continue
|
||||
if got.Calls == nil {
|
||||
got.Calls = []util.ParsedToolCall{}
|
||||
}
|
||||
if !reflect.DeepEqual(got, expected.Calls) {
|
||||
t.Fatalf("toolcall fixture %s mismatch:\n got=%#v\nwant=%#v", name, got, expected.Calls)
|
||||
if got.RejectedToolNames == nil {
|
||||
got.RejectedToolNames = []string{}
|
||||
}
|
||||
if !reflect.DeepEqual(got.Calls, expected.Calls) ||
|
||||
got.SawToolCallSyntax != expected.SawToolCallSyntax ||
|
||||
got.RejectedByPolicy != expected.RejectedByPolicy ||
|
||||
!reflect.DeepEqual(got.RejectedToolNames, expected.RejectedToolNames) {
|
||||
t.Fatalf("toolcall fixture %s mismatch:\n got=%#v\nwant=%#v", name, got, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
)
|
||||
import "strings"
|
||||
|
||||
func (a Account) Identifier() string {
|
||||
if strings.TrimSpace(a.Email) != "" {
|
||||
@@ -13,12 +9,5 @@ func (a Account) Identifier() string {
|
||||
if mobile := NormalizeMobileForStorage(a.Mobile); mobile != "" {
|
||||
return mobile
|
||||
}
|
||||
// Backward compatibility: old configs may contain token-only accounts.
|
||||
// Use a stable non-sensitive synthetic id so they can still join the pool.
|
||||
token := strings.TrimSpace(a.Token)
|
||||
if token == "" {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return "token:" + hex.EncodeToString(sum[:8])
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -32,21 +32,19 @@ func (c Config) MarshalJSON() ([]byte, error) {
|
||||
if strings.TrimSpace(c.Admin.PasswordHash) != "" || c.Admin.JWTExpireHours > 0 || c.Admin.JWTValidAfterUnix > 0 {
|
||||
m["admin"] = c.Admin
|
||||
}
|
||||
if c.Runtime.AccountMaxInflight > 0 || c.Runtime.AccountMaxQueue > 0 || c.Runtime.GlobalMaxInflight > 0 {
|
||||
if c.Runtime.AccountMaxInflight > 0 || c.Runtime.AccountMaxQueue > 0 || c.Runtime.GlobalMaxInflight > 0 || c.Runtime.TokenRefreshIntervalHours > 0 {
|
||||
m["runtime"] = c.Runtime
|
||||
}
|
||||
if c.Compat.WideInputStrictOutput != nil {
|
||||
m["compat"] = c.Compat
|
||||
}
|
||||
if strings.TrimSpace(c.Toolcall.Mode) != "" || strings.TrimSpace(c.Toolcall.EarlyEmitConfidence) != "" {
|
||||
m["toolcall"] = c.Toolcall
|
||||
}
|
||||
if c.Responses.StoreTTLSeconds > 0 {
|
||||
m["responses"] = c.Responses
|
||||
}
|
||||
if strings.TrimSpace(c.Embeddings.Provider) != "" {
|
||||
m["embeddings"] = c.Embeddings
|
||||
}
|
||||
m["auto_delete"] = c.AutoDelete
|
||||
if c.VercelSyncHash != "" {
|
||||
m["_vercel_sync_hash"] = c.VercelSyncHash
|
||||
}
|
||||
@@ -97,9 +95,7 @@ func (c *Config) UnmarshalJSON(b []byte) error {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "toolcall":
|
||||
if err := json.Unmarshal(v, &c.Toolcall); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
// Legacy field ignored. Toolcall policy is fixed and no longer configurable.
|
||||
case "responses":
|
||||
if err := json.Unmarshal(v, &c.Responses); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
@@ -108,6 +104,10 @@ func (c *Config) UnmarshalJSON(b []byte) error {
|
||||
if err := json.Unmarshal(v, &c.Embeddings); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "auto_delete":
|
||||
if err := json.Unmarshal(v, &c.AutoDelete); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "_vercel_sync_hash":
|
||||
if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
@@ -138,9 +138,9 @@ func (c Config) Clone() Config {
|
||||
Compat: CompatConfig{
|
||||
WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput),
|
||||
},
|
||||
Toolcall: c.Toolcall,
|
||||
Responses: c.Responses,
|
||||
Embeddings: c.Embeddings,
|
||||
AutoDelete: c.AutoDelete,
|
||||
VercelSyncHash: c.VercelSyncHash,
|
||||
VercelSyncTime: c.VercelSyncTime,
|
||||
AdditionalFields: map[string]any{},
|
||||
|
||||
@@ -9,20 +9,45 @@ type Config struct {
|
||||
Admin AdminConfig `json:"admin,omitempty"`
|
||||
Runtime RuntimeConfig `json:"runtime,omitempty"`
|
||||
Compat CompatConfig `json:"compat,omitempty"`
|
||||
Toolcall ToolcallConfig `json:"toolcall,omitempty"`
|
||||
Responses ResponsesConfig `json:"responses,omitempty"`
|
||||
Embeddings EmbeddingsConfig `json:"embeddings,omitempty"`
|
||||
AutoDelete AutoDeleteConfig `json:"auto_delete"`
|
||||
VercelSyncHash string `json:"_vercel_sync_hash,omitempty"`
|
||||
VercelSyncTime int64 `json:"_vercel_sync_time,omitempty"`
|
||||
AdditionalFields map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
Email string `json:"email,omitempty"`
|
||||
Mobile string `json:"mobile,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
TestStatus string `json:"test_status,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Mobile string `json:"mobile,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
}
|
||||
|
||||
func (c *Config) ClearAccountTokens() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
for i := range c.Accounts {
|
||||
c.Accounts[i].Token = ""
|
||||
}
|
||||
}
|
||||
|
||||
// DropInvalidAccounts removes accounts that cannot be addressed by admin APIs
|
||||
// (no email and no normalizable mobile). This prevents legacy token-only
|
||||
// records from becoming orphaned empty entries after token stripping.
|
||||
func (c *Config) DropInvalidAccounts() {
|
||||
if c == nil || len(c.Accounts) == 0 {
|
||||
return
|
||||
}
|
||||
kept := make([]Account, 0, len(c.Accounts))
|
||||
for _, acc := range c.Accounts {
|
||||
if acc.Identifier() == "" {
|
||||
continue
|
||||
}
|
||||
kept = append(kept, acc)
|
||||
}
|
||||
c.Accounts = kept
|
||||
}
|
||||
|
||||
type CompatConfig struct {
|
||||
@@ -36,14 +61,10 @@ type AdminConfig struct {
|
||||
}
|
||||
|
||||
type RuntimeConfig struct {
|
||||
AccountMaxInflight int `json:"account_max_inflight,omitempty"`
|
||||
AccountMaxQueue int `json:"account_max_queue,omitempty"`
|
||||
GlobalMaxInflight int `json:"global_max_inflight,omitempty"`
|
||||
}
|
||||
|
||||
type ToolcallConfig struct {
|
||||
Mode string `json:"mode,omitempty"`
|
||||
EarlyEmitConfidence string `json:"early_emit_confidence,omitempty"`
|
||||
AccountMaxInflight int `json:"account_max_inflight,omitempty"`
|
||||
AccountMaxQueue int `json:"account_max_queue,omitempty"`
|
||||
GlobalMaxInflight int `json:"global_max_inflight,omitempty"`
|
||||
TokenRefreshIntervalHours int `json:"token_refresh_interval_hours,omitempty"`
|
||||
}
|
||||
|
||||
type ResponsesConfig struct {
|
||||
@@ -53,3 +74,7 @@ type ResponsesConfig struct {
|
||||
type EmbeddingsConfig struct {
|
||||
Provider string `json:"provider,omitempty"`
|
||||
}
|
||||
|
||||
type AutoDeleteConfig struct {
|
||||
Sessions bool `json:"sessions"`
|
||||
}
|
||||
|
||||
@@ -104,6 +104,9 @@ func TestConfigJSONRoundtrip(t *testing.T) {
|
||||
"fast": "deepseek-chat",
|
||||
"slow": "deepseek-reasoner",
|
||||
},
|
||||
Runtime: RuntimeConfig{
|
||||
TokenRefreshIntervalHours: 12,
|
||||
},
|
||||
VercelSyncHash: "hash123",
|
||||
VercelSyncTime: 1234567890,
|
||||
AdditionalFields: map[string]any{
|
||||
@@ -130,6 +133,9 @@ func TestConfigJSONRoundtrip(t *testing.T) {
|
||||
if decoded.ClaudeMapping["fast"] != "deepseek-chat" {
|
||||
t.Fatalf("unexpected claude mapping: %#v", decoded.ClaudeMapping)
|
||||
}
|
||||
if decoded.Runtime.TokenRefreshIntervalHours != 12 {
|
||||
t.Fatalf("unexpected runtime refresh interval: %#v", decoded.Runtime.TokenRefreshIntervalHours)
|
||||
}
|
||||
if decoded.VercelSyncHash != "hash123" {
|
||||
t.Fatalf("unexpected vercel sync hash: %q", decoded.VercelSyncHash)
|
||||
}
|
||||
|
||||
@@ -2,25 +2,23 @@ package config
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAccountIdentifierFallsBackToTokenHash(t *testing.T) {
|
||||
func TestAccountIdentifierRequiresEmailOrMobile(t *testing.T) {
|
||||
acc := Account{Token: "example-token-value"}
|
||||
id := acc.Identifier()
|
||||
if !strings.HasPrefix(id, "token:") {
|
||||
t.Fatalf("expected token-prefixed identifier, got %q", id)
|
||||
}
|
||||
if len(id) != len("token:")+16 {
|
||||
t.Fatalf("unexpected identifier length: %d (%q)", len(id), id)
|
||||
if id != "" {
|
||||
t.Fatalf("expected empty identifier when only token is present, got %q", id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreFindAccountWithTokenOnlyIdentifier(t *testing.T) {
|
||||
func TestLoadStoreClearsTokensFromConfigInput(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["k1"],
|
||||
"accounts":[{"token":"token-only-account"}]
|
||||
"accounts":[{"email":"u@example.com","password":"p","token":"token-only-account"}]
|
||||
}`)
|
||||
|
||||
store := LoadStore()
|
||||
@@ -28,22 +26,87 @@ func TestStoreFindAccountWithTokenOnlyIdentifier(t *testing.T) {
|
||||
if len(accounts) != 1 {
|
||||
t.Fatalf("expected 1 account, got %d", len(accounts))
|
||||
}
|
||||
id := accounts[0].Identifier()
|
||||
if id == "" {
|
||||
t.Fatalf("expected synthetic identifier for token-only account")
|
||||
}
|
||||
found, ok := store.FindAccount(id)
|
||||
if !ok {
|
||||
t.Fatalf("expected FindAccount to locate token-only account by synthetic id")
|
||||
}
|
||||
if found.Token != "token-only-account" {
|
||||
t.Fatalf("unexpected token value: %q", found.Token)
|
||||
if accounts[0].Token != "" {
|
||||
t.Fatalf("expected token to be cleared after loading, got %q", accounts[0].Token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreUpdateAccountTokenKeepsOldAndNewIdentifierResolvable(t *testing.T) {
|
||||
func TestLoadStoreDropsLegacyTokenOnlyAccounts(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"accounts":[{"token":"old-token"}]
|
||||
"accounts":[
|
||||
{"token":"legacy-token-only"},
|
||||
{"email":"u@example.com","password":"p","token":"runtime-token"}
|
||||
]
|
||||
}`)
|
||||
|
||||
store := LoadStore()
|
||||
accounts := store.Accounts()
|
||||
if len(accounts) != 1 {
|
||||
t.Fatalf("expected token-only account to be dropped, got %d accounts", len(accounts))
|
||||
}
|
||||
if accounts[0].Identifier() != "u@example.com" {
|
||||
t.Fatalf("unexpected remaining account: %#v", accounts[0])
|
||||
}
|
||||
if accounts[0].Token != "" {
|
||||
t.Fatalf("expected persisted token to be cleared, got %q", accounts[0].Token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadStorePreservesFileBackedTokensForRuntime(t *testing.T) {
|
||||
tmp, err := os.CreateTemp(t.TempDir(), "config-*.json")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp config: %v", err)
|
||||
}
|
||||
defer tmp.Close()
|
||||
|
||||
if _, err := tmp.WriteString(`{
|
||||
"accounts":[{"email":"u@example.com","password":"p","token":"persisted-token"}]
|
||||
}`); err != nil {
|
||||
t.Fatalf("write temp config: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("DS2API_CONFIG_JSON", "")
|
||||
t.Setenv("CONFIG_JSON", "")
|
||||
t.Setenv("DS2API_CONFIG_PATH", tmp.Name())
|
||||
|
||||
store := LoadStore()
|
||||
accounts := store.Accounts()
|
||||
if len(accounts) != 1 {
|
||||
t.Fatalf("expected 1 account, got %d", len(accounts))
|
||||
}
|
||||
if accounts[0].Token != "persisted-token" {
|
||||
t.Fatalf("expected file-backed token preserved for runtime use, got %q", accounts[0].Token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeTokenRefreshIntervalHoursDefaultsToSix(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["k1"],
|
||||
"accounts":[{"email":"u@example.com","password":"p"}]
|
||||
}`)
|
||||
|
||||
store := LoadStore()
|
||||
if got := store.RuntimeTokenRefreshIntervalHours(); got != 6 {
|
||||
t.Fatalf("expected default refresh interval 6, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeTokenRefreshIntervalHoursUsesConfigValue(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["k1"],
|
||||
"accounts":[{"email":"u@example.com","password":"p"}],
|
||||
"runtime":{"token_refresh_interval_hours":9}
|
||||
}`)
|
||||
|
||||
store := LoadStore()
|
||||
if got := store.RuntimeTokenRefreshIntervalHours(); got != 9 {
|
||||
t.Fatalf("expected configured refresh interval 9, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreUpdateAccountTokenKeepsIdentifierResolvable(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"accounts":[{"email":"user@example.com","password":"p"}]
|
||||
}`)
|
||||
|
||||
store := LoadStore()
|
||||
@@ -52,23 +115,12 @@ func TestStoreUpdateAccountTokenKeepsOldAndNewIdentifierResolvable(t *testing.T)
|
||||
t.Fatalf("expected 1 account, got %d", len(before))
|
||||
}
|
||||
oldID := before[0].Identifier()
|
||||
if oldID == "" {
|
||||
t.Fatal("expected old identifier")
|
||||
}
|
||||
if err := store.UpdateAccountToken(oldID, "new-token"); err != nil {
|
||||
t.Fatalf("update token failed: %v", err)
|
||||
}
|
||||
|
||||
after := store.Accounts()
|
||||
newID := after[0].Identifier()
|
||||
if newID == "" || newID == oldID {
|
||||
t.Fatalf("expected changed identifier, old=%q new=%q", oldID, newID)
|
||||
}
|
||||
if got, ok := store.FindAccount(newID); !ok || got.Token != "new-token" {
|
||||
t.Fatalf("expected find by new identifier")
|
||||
}
|
||||
if got, ok := store.FindAccount(oldID); !ok || got.Token != "new-token" {
|
||||
t.Fatalf("expected find by old identifier alias")
|
||||
t.Fatalf("expected find by stable account identifier")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,3 +173,39 @@ func TestLoadConfigOnVercelWithoutConfigFileFallsBackToMemory(t *testing.T) {
|
||||
t.Fatalf("expected empty bootstrap config, got keys=%d accounts=%d", len(cfg.Keys), len(cfg.Accounts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountTestStatusIsRuntimeOnlyAndNotPersisted(t *testing.T) {
|
||||
tmp, err := os.CreateTemp(t.TempDir(), "config-*.json")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp config: %v", err)
|
||||
}
|
||||
defer tmp.Close()
|
||||
if _, err := tmp.WriteString(`{
|
||||
"accounts":[{"email":"u@example.com","password":"p","test_status":"ok"}]
|
||||
}`); err != nil {
|
||||
t.Fatalf("write temp config: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("DS2API_CONFIG_JSON", "")
|
||||
t.Setenv("CONFIG_JSON", "")
|
||||
t.Setenv("DS2API_CONFIG_PATH", tmp.Name())
|
||||
|
||||
store := LoadStore()
|
||||
if got, ok := store.AccountTestStatus("u@example.com"); ok || got != "" {
|
||||
t.Fatalf("expected no runtime status loaded from config, got %q", got)
|
||||
}
|
||||
if err := store.UpdateAccountTestStatus("u@example.com", "ok"); err != nil {
|
||||
t.Fatalf("update test status: %v", err)
|
||||
}
|
||||
if got, ok := store.AccountTestStatus("u@example.com"); !ok || got != "ok" {
|
||||
t.Fatalf("expected runtime status to be available, got %q (ok=%v)", got, ok)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(tmp.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("read config: %v", err)
|
||||
}
|
||||
if strings.Contains(string(content), "test_status") {
|
||||
t.Fatalf("expected test_status to stay out of persisted config, got: %s", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ type Store struct {
|
||||
fromEnv bool
|
||||
keyMap map[string]struct{} // O(1) API key lookup index
|
||||
accMap map[string]int // O(1) account lookup: identifier -> slice index
|
||||
accTest map[string]string // runtime-only account test status cache
|
||||
}
|
||||
|
||||
func LoadStore() *Store {
|
||||
@@ -39,6 +40,8 @@ func loadConfig() (Config, bool, error) {
|
||||
}
|
||||
if rawCfg != "" {
|
||||
cfg, err := parseConfigString(rawCfg)
|
||||
cfg.ClearAccountTokens()
|
||||
cfg.DropInvalidAccounts()
|
||||
return cfg, true, err
|
||||
}
|
||||
|
||||
@@ -55,6 +58,12 @@ func loadConfig() (Config, bool, error) {
|
||||
if err := json.Unmarshal(content, &cfg); err != nil {
|
||||
return Config{}, false, err
|
||||
}
|
||||
cfg.DropInvalidAccounts()
|
||||
if strings.Contains(string(content), `"test_status"`) && !IsVercel() {
|
||||
if b, err := json.MarshalIndent(cfg, "", " "); err == nil {
|
||||
_ = os.WriteFile(ConfigPath(), b, 0o644)
|
||||
}
|
||||
}
|
||||
if IsVercel() {
|
||||
// Vercel filesystem is ephemeral/read-only for runtime writes; avoid save errors.
|
||||
return cfg, true, nil
|
||||
@@ -105,8 +114,19 @@ func (s *Store) UpdateAccountTestStatus(identifier, status string) error {
|
||||
if !ok {
|
||||
return errors.New("account not found")
|
||||
}
|
||||
s.cfg.Accounts[idx].TestStatus = status
|
||||
return s.saveLocked()
|
||||
s.setAccountTestStatusLocked(s.cfg.Accounts[idx], status, identifier)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) AccountTestStatus(identifier string) (string, bool) {
|
||||
identifier = strings.TrimSpace(identifier)
|
||||
if identifier == "" {
|
||||
return "", false
|
||||
}
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
status, ok := s.accTest[identifier]
|
||||
return status, ok
|
||||
}
|
||||
|
||||
func (s *Store) UpdateAccountToken(identifier, token string) error {
|
||||
@@ -161,7 +181,9 @@ func (s *Store) Save() error {
|
||||
Logger.Info("[save_config] source from env, skip write")
|
||||
return nil
|
||||
}
|
||||
b, err := json.MarshalIndent(s.cfg, "", " ")
|
||||
persistCfg := s.cfg.Clone()
|
||||
persistCfg.ClearAccountTokens()
|
||||
b, err := json.MarshalIndent(persistCfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -173,7 +195,9 @@ func (s *Store) saveLocked() error {
|
||||
Logger.Info("[save_config] source from env, skip write")
|
||||
return nil
|
||||
}
|
||||
b, err := json.MarshalIndent(s.cfg, "", " ")
|
||||
persistCfg := s.cfg.Clone()
|
||||
persistCfg.ClearAccountTokens()
|
||||
b, err := json.MarshalIndent(persistCfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -197,7 +221,9 @@ func (s *Store) SetVercelSync(hash string, ts int64) error {
|
||||
func (s *Store) ExportJSONAndBase64() (string, string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
b, err := json.Marshal(s.cfg)
|
||||
exportCfg := s.cfg.Clone()
|
||||
exportCfg.ClearAccountTokens()
|
||||
b, err := json.Marshal(exportCfg)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
@@ -43,23 +43,11 @@ func (s *Store) CompatWideInputStrictOutput() bool {
|
||||
}
|
||||
|
||||
func (s *Store) ToolcallMode() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
mode := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.Mode))
|
||||
if mode == "" {
|
||||
return "feature_match"
|
||||
}
|
||||
return mode
|
||||
return "feature_match"
|
||||
}
|
||||
|
||||
func (s *Store) ToolcallEarlyEmitConfidence() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
level := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.EarlyEmitConfidence))
|
||||
if level == "" {
|
||||
return "high"
|
||||
}
|
||||
return level
|
||||
return "high"
|
||||
}
|
||||
|
||||
func (s *Store) ResponsesStoreTTLSeconds() int {
|
||||
@@ -165,3 +153,18 @@ func (s *Store) RuntimeGlobalMaxInflight(defaultSize int) int {
|
||||
}
|
||||
return defaultSize
|
||||
}
|
||||
|
||||
func (s *Store) RuntimeTokenRefreshIntervalHours() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.cfg.Runtime.TokenRefreshIntervalHours > 0 {
|
||||
return s.cfg.Runtime.TokenRefreshIntervalHours
|
||||
}
|
||||
return 6
|
||||
}
|
||||
|
||||
func (s *Store) AutoDeleteSessions() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.cfg.AutoDelete.Sessions
|
||||
}
|
||||
|
||||
@@ -2,15 +2,20 @@ package config
|
||||
|
||||
// rebuildIndexes must be called with the lock already held (or during init).
|
||||
func (s *Store) rebuildIndexes() {
|
||||
prevStatus := s.accTest
|
||||
s.keyMap = make(map[string]struct{}, len(s.cfg.Keys))
|
||||
for _, k := range s.cfg.Keys {
|
||||
s.keyMap[k] = struct{}{}
|
||||
}
|
||||
s.accMap = make(map[string]int, len(s.cfg.Accounts))
|
||||
s.accTest = make(map[string]string, len(s.cfg.Accounts))
|
||||
for i, acc := range s.cfg.Accounts {
|
||||
id := acc.Identifier()
|
||||
if id != "" {
|
||||
s.accMap[id] = i
|
||||
if status, ok := prevStatus[id]; ok {
|
||||
s.setAccountTestStatusLocked(acc, status, "")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -29,3 +34,22 @@ func (s *Store) findAccountIndexLocked(identifier string) (int, bool) {
|
||||
}
|
||||
return -1, false
|
||||
}
|
||||
|
||||
func (s *Store) setAccountTestStatusLocked(acc Account, status, hintedIdentifier string) {
|
||||
status = lower(status)
|
||||
if status == "" {
|
||||
return
|
||||
}
|
||||
if id := acc.Identifier(); id != "" {
|
||||
s.accTest[id] = status
|
||||
}
|
||||
if email := acc.Email; email != "" {
|
||||
s.accTest[email] = status
|
||||
}
|
||||
if mobile := CanonicalMobileKey(acc.Mobile); mobile != "" {
|
||||
s.accTest[mobile] = status
|
||||
}
|
||||
if hintedIdentifier = lower(hintedIdentifier); hintedIdentifier != "" {
|
||||
s.accTest[hintedIdentifier] = status
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,8 +62,8 @@ func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAtte
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
code := intFrom(resp["code"])
|
||||
if status == http.StatusOK && code == 0 {
|
||||
code, bizCode, msg, bizMsg := extractResponseStatus(resp)
|
||||
if status == http.StatusOK && code == 0 && bizCode == 0 {
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
bizData, _ := data["biz_data"].(map[string]any)
|
||||
sessionID, _ := bizData["id"].(string)
|
||||
@@ -71,10 +71,9 @@ func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAtte
|
||||
return sessionID, nil
|
||||
}
|
||||
}
|
||||
msg, _ := resp["msg"].(string)
|
||||
config.Logger.Warn("[create_session] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
|
||||
config.Logger.Warn("[create_session] failed", "status", status, "code", code, "biz_code", bizCode, "msg", msg, "biz_msg", bizMsg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
|
||||
if a.UseConfigToken {
|
||||
if isTokenInvalid(status, code, msg) && !refreshed {
|
||||
if !refreshed && shouldAttemptRefresh(status, code, bizCode, msg, bizMsg) {
|
||||
if c.Auth.RefreshToken(ctx, a) {
|
||||
refreshed = true
|
||||
continue
|
||||
@@ -96,6 +95,7 @@ func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts in
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
attempts := 0
|
||||
refreshed := false
|
||||
for attempts < maxAttempts {
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"})
|
||||
@@ -104,8 +104,8 @@ func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts in
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
code := intFrom(resp["code"])
|
||||
if status == http.StatusOK && code == 0 {
|
||||
code, bizCode, msg, bizMsg := extractResponseStatus(resp)
|
||||
if status == http.StatusOK && code == 0 && bizCode == 0 {
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
bizData, _ := data["biz_data"].(map[string]any)
|
||||
challenge, _ := bizData["challenge"].(map[string]any)
|
||||
@@ -116,15 +116,16 @@ func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts in
|
||||
}
|
||||
return BuildPowHeader(challenge, answer)
|
||||
}
|
||||
msg, _ := resp["msg"].(string)
|
||||
config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
|
||||
config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "biz_code", bizCode, "msg", msg, "biz_msg", bizMsg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
|
||||
if a.UseConfigToken {
|
||||
if isTokenInvalid(status, code, msg) {
|
||||
if !refreshed && shouldAttemptRefresh(status, code, bizCode, msg, bizMsg) {
|
||||
if c.Auth.RefreshToken(ctx, a) {
|
||||
refreshed = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
if c.Auth.SwitchAccount(ctx, a) {
|
||||
refreshed = false
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
@@ -143,15 +144,75 @@ func (c *Client) authHeaders(token string) map[string]string {
|
||||
return headers
|
||||
}
|
||||
|
||||
func isTokenInvalid(status int, code int, msg string) bool {
|
||||
msg = strings.ToLower(msg)
|
||||
func isTokenInvalid(status int, code int, bizCode int, msg string, bizMsg string) bool {
|
||||
msg = strings.ToLower(strings.TrimSpace(msg) + " " + strings.TrimSpace(bizMsg))
|
||||
if status == http.StatusUnauthorized || status == http.StatusForbidden {
|
||||
return true
|
||||
}
|
||||
if code == 40001 || code == 40002 || code == 40003 {
|
||||
if code == 40001 || code == 40002 || code == 40003 || bizCode == 40001 || bizCode == 40002 || bizCode == 40003 {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(msg, "token") || strings.Contains(msg, "unauthorized")
|
||||
return strings.Contains(msg, "token") ||
|
||||
strings.Contains(msg, "unauthorized") ||
|
||||
strings.Contains(msg, "expired") ||
|
||||
strings.Contains(msg, "not login") ||
|
||||
strings.Contains(msg, "login required") ||
|
||||
strings.Contains(msg, "invalid jwt")
|
||||
}
|
||||
|
||||
func shouldAttemptRefresh(status int, code int, bizCode int, msg string, bizMsg string) bool {
|
||||
if isTokenInvalid(status, code, bizCode, msg, bizMsg) {
|
||||
return true
|
||||
}
|
||||
// Some DeepSeek failures come back as HTTP 200/code=0 but with non-zero biz_code.
|
||||
// Only attempt refresh when these biz failures still look auth-related.
|
||||
return status == http.StatusOK &&
|
||||
code == 0 &&
|
||||
bizCode != 0 &&
|
||||
isAuthIndicativeBizFailure(msg, bizMsg)
|
||||
}
|
||||
|
||||
func isAuthIndicativeBizFailure(msg string, bizMsg string) bool {
|
||||
combined := strings.ToLower(strings.TrimSpace(msg) + " " + strings.TrimSpace(bizMsg))
|
||||
authKeywords := []string{
|
||||
"auth",
|
||||
"authorization",
|
||||
"credential",
|
||||
"expired",
|
||||
"invalid jwt",
|
||||
"jwt",
|
||||
"login",
|
||||
"not login",
|
||||
"session expired",
|
||||
"token",
|
||||
"unauthorized",
|
||||
"登录",
|
||||
"未登录",
|
||||
"认证",
|
||||
"凭证",
|
||||
"会话过期",
|
||||
"令牌",
|
||||
}
|
||||
for _, keyword := range authKeywords {
|
||||
if strings.Contains(combined, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func extractResponseStatus(resp map[string]any) (code int, bizCode int, msg string, bizMsg string) {
|
||||
code = intFrom(resp["code"])
|
||||
msg, _ = resp["msg"].(string)
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
bizCode = intFrom(data["biz_code"])
|
||||
bizMsg, _ = data["biz_msg"].(string)
|
||||
if strings.TrimSpace(bizMsg) == "" {
|
||||
if bizData, ok := data["biz_data"].(map[string]any); ok {
|
||||
bizMsg, _ = bizData["msg"].(string)
|
||||
}
|
||||
}
|
||||
return code, bizCode, msg, bizMsg
|
||||
}
|
||||
|
||||
func normalizeMobileForLogin(raw string) (mobile string, areaCode any) {
|
||||
|
||||
27
internal/deepseek/client_auth_refresh_test.go
Normal file
27
internal/deepseek/client_auth_refresh_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package deepseek
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestShouldAttemptRefreshOnTokenInvalidSignal(t *testing.T) {
|
||||
if !shouldAttemptRefresh(401, 0, 0, "unauthorized", "") {
|
||||
t.Fatal("expected refresh when response indicates invalid token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldAttemptRefreshOnAuthIndicativeBizCodeFailure(t *testing.T) {
|
||||
if !shouldAttemptRefresh(200, 0, 400123, "", "login expired, token invalid") {
|
||||
t.Fatal("expected refresh on auth-indicative biz_code failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldAttemptRefreshFalseOnNonAuthBizCodeFailure(t *testing.T) {
|
||||
if shouldAttemptRefresh(200, 0, 400123, "", "session create failed: quota reached") {
|
||||
t.Fatal("did not expect refresh on non-auth biz_code failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldAttemptRefreshFalseOnGenericServerError(t *testing.T) {
|
||||
if shouldAttemptRefresh(500, 500, 0, "internal error", "") {
|
||||
t.Fatal("did not expect refresh on generic server error")
|
||||
}
|
||||
}
|
||||
@@ -62,3 +62,40 @@ func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url st
|
||||
}
|
||||
return out, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
func (c *Client) getJSONWithStatus(ctx context.Context, doer trans.Doer, url string, headers map[string]string) (map[string]any, int, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
resp, err := doer.Do(req)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[deepseek] fingerprint GET request failed, fallback to std transport", "url", url, "error", err)
|
||||
req2, reqErr := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if reqErr != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
for k, v := range headers {
|
||||
req2.Header.Set(k, v)
|
||||
}
|
||||
resp, err = c.fallback.Do(req2)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
payloadBytes, err := readResponseBody(resp)
|
||||
if err != nil {
|
||||
return nil, resp.StatusCode, err
|
||||
}
|
||||
out := map[string]any{}
|
||||
if len(payloadBytes) > 0 {
|
||||
if err := json.Unmarshal(payloadBytes, &out); err != nil {
|
||||
config.Logger.Warn("[deepseek] json parse failed", "url", url, "status", resp.StatusCode, "content_encoding", resp.Header.Get("Content-Encoding"), "preview", preview(payloadBytes))
|
||||
}
|
||||
}
|
||||
return out, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
256
internal/deepseek/client_session.go
Normal file
256
internal/deepseek/client_session.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
// SessionInfo 会话信息
|
||||
type SessionInfo struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
TitleType string `json:"title_type"`
|
||||
Pinned bool `json:"pinned"`
|
||||
UpdatedAt float64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
// SessionStats 会话统计结果
|
||||
type SessionStats struct {
|
||||
AccountID string // 账号标识 (email 或 mobile)
|
||||
FirstPageCount int // 第一页会话数量(当 HasMore 为 true 时,真实总数可能更大)
|
||||
PinnedCount int // 置顶会话数量
|
||||
HasMore bool // 是否还有更多页
|
||||
Success bool // 请求是否成功
|
||||
ErrorMessage string // 错误信息
|
||||
}
|
||||
|
||||
// GetSessionCount 获取单个账号的会话数量
|
||||
func (c *Client) GetSessionCount(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (*SessionStats, error) {
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
|
||||
stats := &SessionStats{
|
||||
AccountID: a.AccountID,
|
||||
}
|
||||
|
||||
attempts := 0
|
||||
refreshed := false
|
||||
|
||||
for attempts < maxAttempts {
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
|
||||
// 构建请求 URL
|
||||
reqURL := DeepSeekFetchSessionURL + "?lte_cursor.pinned=false"
|
||||
|
||||
resp, status, err := c.getJSONWithStatus(ctx, c.regular, reqURL, headers)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[get_session_count] request error", "error", err, "account", a.AccountID)
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
|
||||
code, bizCode, msg, bizMsg := extractResponseStatus(resp)
|
||||
if status == http.StatusOK && code == 0 && bizCode == 0 {
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
bizData, _ := data["biz_data"].(map[string]any)
|
||||
chatSessions, _ := bizData["chat_sessions"].([]any)
|
||||
hasMore, _ := bizData["has_more"].(bool)
|
||||
|
||||
stats.FirstPageCount = len(chatSessions)
|
||||
stats.HasMore = hasMore
|
||||
stats.Success = true
|
||||
|
||||
// 统计置顶会话数量
|
||||
for _, session := range chatSessions {
|
||||
if s, ok := session.(map[string]any); ok {
|
||||
if pinned, ok := s["pinned"].(bool); ok && pinned {
|
||||
stats.PinnedCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
stats.ErrorMessage = fmt.Sprintf("status=%d, code=%d, msg=%s", status, code, msg)
|
||||
config.Logger.Warn("[get_session_count] failed", "status", status, "code", code, "biz_code", bizCode, "msg", msg, "biz_msg", bizMsg, "account", a.AccountID)
|
||||
|
||||
if a.UseConfigToken {
|
||||
if isTokenInvalid(status, code, bizCode, msg, bizMsg) && !refreshed {
|
||||
if c.Auth.RefreshToken(ctx, a) {
|
||||
refreshed = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
if c.Auth.SwitchAccount(ctx, a) {
|
||||
refreshed = false
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
}
|
||||
attempts++
|
||||
}
|
||||
|
||||
stats.Success = false
|
||||
stats.ErrorMessage = "get session count failed after retries"
|
||||
return stats, errors.New(stats.ErrorMessage)
|
||||
}
|
||||
|
||||
// GetSessionCountForToken 直接使用 token 获取会话数量(直通模式)
|
||||
func (c *Client) GetSessionCountForToken(ctx context.Context, token string) (*SessionStats, error) {
|
||||
headers := c.authHeaders(token)
|
||||
reqURL := DeepSeekFetchSessionURL + "?lte_cursor.pinned=false"
|
||||
|
||||
resp, status, err := c.getJSONWithStatus(ctx, c.regular, reqURL, headers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
code, bizCode, msg, bizMsg := extractResponseStatus(resp)
|
||||
if status != http.StatusOK || code != 0 || bizCode != 0 {
|
||||
if strings.TrimSpace(bizMsg) != "" {
|
||||
msg = bizMsg
|
||||
}
|
||||
return nil, fmt.Errorf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
|
||||
}
|
||||
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
bizData, _ := data["biz_data"].(map[string]any)
|
||||
chatSessions, _ := bizData["chat_sessions"].([]any)
|
||||
hasMore, _ := bizData["has_more"].(bool)
|
||||
|
||||
stats := &SessionStats{
|
||||
FirstPageCount: len(chatSessions),
|
||||
HasMore: hasMore,
|
||||
Success: true,
|
||||
}
|
||||
|
||||
// 统计置顶会话数量
|
||||
for _, session := range chatSessions {
|
||||
if s, ok := session.(map[string]any); ok {
|
||||
if pinned, ok := s["pinned"].(bool); ok && pinned {
|
||||
stats.PinnedCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetSessionCountAll 获取所有账号的会话数量统计
|
||||
func (c *Client) GetSessionCountAll(ctx context.Context) []*SessionStats {
|
||||
accounts := c.Store.Accounts()
|
||||
results := make([]*SessionStats, 0, len(accounts))
|
||||
|
||||
for _, acc := range accounts {
|
||||
token := acc.Token
|
||||
accountID := acc.Email
|
||||
if accountID == "" {
|
||||
accountID = acc.Mobile
|
||||
}
|
||||
|
||||
// 如果没有 token,尝试登录获取
|
||||
if token == "" {
|
||||
var err error
|
||||
token, err = c.Login(ctx, acc)
|
||||
if err != nil {
|
||||
results = append(results, &SessionStats{
|
||||
AccountID: accountID,
|
||||
Success: false,
|
||||
ErrorMessage: fmt.Sprintf("login failed: %v", err),
|
||||
})
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := c.GetSessionCountForToken(ctx, token)
|
||||
if err != nil {
|
||||
results = append(results, &SessionStats{
|
||||
AccountID: accountID,
|
||||
Success: false,
|
||||
ErrorMessage: err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
stats.AccountID = accountID
|
||||
results = append(results, stats)
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// FetchSessionPage 获取会话列表(支持分页)
|
||||
func (c *Client) FetchSessionPage(ctx context.Context, a *auth.RequestAuth, cursor string) ([]SessionInfo, bool, error) {
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
|
||||
// 构建请求 URL
|
||||
params := url.Values{}
|
||||
params.Set("lte_cursor.pinned", "false")
|
||||
if cursor != "" {
|
||||
params.Set("lte_cursor", cursor)
|
||||
}
|
||||
reqURL := DeepSeekFetchSessionURL + "?" + params.Encode()
|
||||
|
||||
resp, status, err := c.getJSONWithStatus(ctx, c.regular, reqURL, headers)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
code := intFrom(resp["code"])
|
||||
if status != http.StatusOK || code != 0 {
|
||||
msg, _ := resp["msg"].(string)
|
||||
return nil, false, fmt.Errorf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
|
||||
}
|
||||
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
bizData, _ := data["biz_data"].(map[string]any)
|
||||
chatSessions, _ := bizData["chat_sessions"].([]any)
|
||||
hasMore, _ := bizData["has_more"].(bool)
|
||||
|
||||
sessions := make([]SessionInfo, 0, len(chatSessions))
|
||||
for _, s := range chatSessions {
|
||||
if m, ok := s.(map[string]any); ok {
|
||||
session := SessionInfo{
|
||||
ID: stringFromMap(m, "id"),
|
||||
Title: stringFromMap(m, "title"),
|
||||
TitleType: stringFromMap(m, "title_type"),
|
||||
Pinned: boolFromMap(m, "pinned"),
|
||||
UpdatedAt: floatFromMap(m, "updated_at"),
|
||||
}
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
}
|
||||
|
||||
return sessions, hasMore, nil
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func stringFromMap(m map[string]any, key string) string {
|
||||
if v, ok := m[key].(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func boolFromMap(m map[string]any, key string) bool {
|
||||
if v, ok := m[key].(bool); ok {
|
||||
return v
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func floatFromMap(m map[string]any, key string) float64 {
|
||||
if v, ok := m[key].(float64); ok {
|
||||
return v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
155
internal/deepseek/client_session_delete.go
Normal file
155
internal/deepseek/client_session_delete.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
// DeleteSessionResult 删除会话结果
|
||||
type DeleteSessionResult struct {
|
||||
SessionID string // 会话 ID
|
||||
Success bool // 是否成功
|
||||
ErrorMessage string // 错误信息
|
||||
}
|
||||
|
||||
// DeleteSession 删除单个会话
|
||||
func (c *Client) DeleteSession(ctx context.Context, a *auth.RequestAuth, sessionID string, maxAttempts int) (*DeleteSessionResult, error) {
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
|
||||
result := &DeleteSessionResult{
|
||||
SessionID: sessionID,
|
||||
}
|
||||
|
||||
if sessionID == "" {
|
||||
result.ErrorMessage = "session_id is required"
|
||||
return result, errors.New(result.ErrorMessage)
|
||||
}
|
||||
|
||||
attempts := 0
|
||||
refreshed := false
|
||||
|
||||
for attempts < maxAttempts {
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
}
|
||||
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteSessionURL, headers, payload)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[delete_session] request error", "error", err, "session_id", sessionID)
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
|
||||
code, bizCode, msg, bizMsg := extractResponseStatus(resp)
|
||||
if status == http.StatusOK && code == 0 && bizCode == 0 {
|
||||
result.Success = true
|
||||
return result, nil
|
||||
}
|
||||
|
||||
result.ErrorMessage = fmt.Sprintf("status=%d, code=%d, msg=%s", status, code, msg)
|
||||
config.Logger.Warn("[delete_session] failed", "status", status, "code", code, "biz_code", bizCode, "msg", msg, "biz_msg", bizMsg, "session_id", sessionID)
|
||||
|
||||
if a.UseConfigToken {
|
||||
if isTokenInvalid(status, code, bizCode, msg, bizMsg) && !refreshed {
|
||||
if c.Auth.RefreshToken(ctx, a) {
|
||||
refreshed = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
if c.Auth.SwitchAccount(ctx, a) {
|
||||
refreshed = false
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
}
|
||||
attempts++
|
||||
}
|
||||
|
||||
result.Success = false
|
||||
result.ErrorMessage = "delete session failed after retries"
|
||||
return result, errors.New(result.ErrorMessage)
|
||||
}
|
||||
|
||||
// DeleteSessionForToken 直接使用 token 删除会话(直通模式)
|
||||
func (c *Client) DeleteSessionForToken(ctx context.Context, token string, sessionID string) (*DeleteSessionResult, error) {
|
||||
result := &DeleteSessionResult{
|
||||
SessionID: sessionID,
|
||||
}
|
||||
|
||||
if sessionID == "" {
|
||||
result.ErrorMessage = "session_id is required"
|
||||
return result, errors.New(result.ErrorMessage)
|
||||
}
|
||||
|
||||
headers := c.authHeaders(token)
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
}
|
||||
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteSessionURL, headers, payload)
|
||||
if err != nil {
|
||||
result.ErrorMessage = err.Error()
|
||||
return result, err
|
||||
}
|
||||
|
||||
code := intFrom(resp["code"])
|
||||
if status != http.StatusOK || code != 0 {
|
||||
msg, _ := resp["msg"].(string)
|
||||
result.ErrorMessage = fmt.Sprintf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
|
||||
return result, errors.New(result.ErrorMessage)
|
||||
}
|
||||
|
||||
result.Success = true
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteAllSessions 删除所有会话(谨慎使用)
|
||||
func (c *Client) DeleteAllSessions(ctx context.Context, a *auth.RequestAuth) error {
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
payload := map[string]any{}
|
||||
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteAllSessionsURL, headers, payload)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[delete_all_sessions] request error", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
code := intFrom(resp["code"])
|
||||
if status != http.StatusOK || code != 0 {
|
||||
msg, _ := resp["msg"].(string)
|
||||
config.Logger.Warn("[delete_all_sessions] failed", "status", status, "code", code, "msg", msg)
|
||||
return fmt.Errorf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteAllSessionsForToken 直接使用 token 删除所有会话(直通模式)
|
||||
func (c *Client) DeleteAllSessionsForToken(ctx context.Context, token string) error {
|
||||
headers := c.authHeaders(token)
|
||||
payload := map[string]any{}
|
||||
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteAllSessionsURL, headers, payload)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[delete_all_sessions_for_token] request error", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
code := intFrom(resp["code"])
|
||||
if status != http.StatusOK || code != 0 {
|
||||
msg, _ := resp["msg"].(string)
|
||||
config.Logger.Warn("[delete_all_sessions_for_token] failed", "status", status, "code", code, "msg", msg)
|
||||
return fmt.Errorf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -11,6 +11,9 @@ const (
|
||||
DeepSeekCreateSessionURL = "https://chat.deepseek.com/api/v0/chat_session/create"
|
||||
DeepSeekCreatePowURL = "https://chat.deepseek.com/api/v0/chat/create_pow_challenge"
|
||||
DeepSeekCompletionURL = "https://chat.deepseek.com/api/v0/chat/completion"
|
||||
DeepSeekFetchSessionURL = "https://chat.deepseek.com/api/v0/chat_session/fetch_page"
|
||||
DeepSeekDeleteSessionURL = "https://chat.deepseek.com/api/v0/chat_session/delete"
|
||||
DeepSeekDeleteAllSessionsURL = "https://chat.deepseek.com/api/v0/chat_session/delete_all"
|
||||
)
|
||||
|
||||
var defaultBaseHeaders = map[string]string{
|
||||
|
||||
@@ -8,15 +8,15 @@ import (
|
||||
)
|
||||
|
||||
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
detected := util.ParseStandaloneToolCalls(finalText, toolNames)
|
||||
detected := util.ParseStandaloneToolCallsDetailed(finalText, toolNames)
|
||||
finishReason := "stop"
|
||||
messageObj := map[string]any{"role": "assistant", "content": finalText}
|
||||
if strings.TrimSpace(finalThinking) != "" {
|
||||
messageObj["reasoning_content"] = finalThinking
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
if len(detected.Calls) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected)
|
||||
messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected.Calls)
|
||||
messageObj["content"] = nil
|
||||
}
|
||||
|
||||
|
||||
@@ -13,12 +13,12 @@ import (
|
||||
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
// Strict mode: only standalone, structured tool-call payloads are treated
|
||||
// as executable tool calls.
|
||||
detected := util.ParseStandaloneToolCalls(finalText, toolNames)
|
||||
detected := util.ParseStandaloneToolCallsDetailed(finalText, toolNames)
|
||||
exposedOutputText := finalText
|
||||
output := make([]any, 0, 2)
|
||||
if len(detected) > 0 {
|
||||
if len(detected.Calls) > 0 {
|
||||
exposedOutputText = ""
|
||||
output = append(output, toResponsesFunctionCallItems(detected)...)
|
||||
output = append(output, toResponsesFunctionCallItems(detected.Calls)...)
|
||||
} else {
|
||||
content := make([]any, 0, 2)
|
||||
if finalThinking != "" {
|
||||
|
||||
@@ -71,6 +71,19 @@ func BuildResponsesTextDeltaPayload(responseID, itemID string, outputIndex, cont
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func BuildResponsesTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_text.done",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"item_id": itemID,
|
||||
"output_index": outputIndex,
|
||||
"content_index": contentIndex,
|
||||
"text": text,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.reasoning.delta",
|
||||
|
||||
@@ -2,6 +2,7 @@ package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -45,7 +46,7 @@ func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildResponseObjectTreatsMixedProseToolPayloadAsText(t *testing.T) {
|
||||
func TestBuildResponseObjectPromotesMixedProseToolPayloadToFunctionCall(t *testing.T) {
|
||||
obj := BuildResponseObject(
|
||||
"resp_test",
|
||||
"gpt-4o",
|
||||
@@ -56,8 +57,32 @@ func TestBuildResponseObjectTreatsMixedProseToolPayloadAsText(t *testing.T) {
|
||||
)
|
||||
|
||||
outputText, _ := obj["output_text"].(string)
|
||||
if outputText == "" {
|
||||
t.Fatalf("expected output_text preserved for mixed prose payload")
|
||||
if outputText != "" {
|
||||
t.Fatalf("expected output_text hidden for mixed prose tool payload, got %q", outputText)
|
||||
}
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected one function_call output item, got %#v", obj["output"])
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "function_call" {
|
||||
t.Fatalf("expected function_call output type, got %#v", first["type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildResponseObjectKeepsFencedToolPayloadAsText(t *testing.T) {
|
||||
obj := BuildResponseObject(
|
||||
"resp_test",
|
||||
"gpt-4o",
|
||||
"prompt",
|
||||
"",
|
||||
"```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"golang\"}}]}\n```",
|
||||
[]string{"search"},
|
||||
)
|
||||
|
||||
outputText, _ := obj["output_text"].(string)
|
||||
if !strings.Contains(outputText, "\"tool_calls\"") {
|
||||
t.Fatalf("expected output_text to preserve fenced tool payload, got %q", outputText)
|
||||
}
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) != 1 {
|
||||
@@ -69,28 +94,9 @@ func TestBuildResponseObjectTreatsMixedProseToolPayloadAsText(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildResponseObjectFencedToolPayloadRemainsText(t *testing.T) {
|
||||
obj := BuildResponseObject(
|
||||
"resp_test",
|
||||
"gpt-4o",
|
||||
"prompt",
|
||||
"",
|
||||
"```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"golang\"}}]}\n```",
|
||||
[]string{"search"},
|
||||
)
|
||||
|
||||
outputText, _ := obj["output_text"].(string)
|
||||
if outputText == "" {
|
||||
t.Fatalf("expected output_text preserved for fenced example")
|
||||
}
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected one message output item, got %#v", obj["output"])
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "message" {
|
||||
t.Fatalf("expected message output type, got %#v", first["type"])
|
||||
}
|
||||
// Backward-compatible alias for historical test name used in CI logs.
|
||||
func TestBuildResponseObjectPromotesFencedToolPayloadToFunctionCall(t *testing.T) {
|
||||
TestBuildResponseObjectKeepsFencedToolPayloadAsText(t)
|
||||
}
|
||||
|
||||
func TestBuildResponseObjectReasoningOnlyFallsBackToOutputText(t *testing.T) {
|
||||
|
||||
@@ -8,13 +8,14 @@ const {
|
||||
|
||||
function resolveToolcallPolicy(prepBody, payloadTools) {
|
||||
const preparedToolNames = normalizePreparedToolNames(prepBody && prepBody.tool_names);
|
||||
const toolNames = preparedToolNames.length > 0 ? preparedToolNames : extractToolNames(payloadTools);
|
||||
const featureMatchEnabled = boolDefaultTrue(prepBody && prepBody.toolcall_feature_match);
|
||||
const emitEarlyToolDeltas = boolDefaultTrue(prepBody && prepBody.toolcall_early_emit_high);
|
||||
let toolNames = preparedToolNames.length > 0 ? preparedToolNames : extractToolNames(payloadTools);
|
||||
if (toolNames.length === 0 && Array.isArray(payloadTools) && payloadTools.length > 0) {
|
||||
toolNames = ['__any_tool__'];
|
||||
}
|
||||
return {
|
||||
toolNames,
|
||||
toolSieveEnabled: toolNames.length > 0 && featureMatchEnabled,
|
||||
emitEarlyToolDeltas,
|
||||
toolSieveEnabled: toolNames.length > 0,
|
||||
emitEarlyToolDeltas: true,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -60,6 +61,9 @@ function formatIncrementalToolCallDeltas(deltas, idStore) {
|
||||
if (typeof d.arguments === 'string' && d.arguments !== '') {
|
||||
fn.arguments = d.arguments;
|
||||
}
|
||||
if (Object.keys(fn).length === 0) {
|
||||
continue;
|
||||
}
|
||||
if (Object.keys(fn).length > 0) {
|
||||
item.function = fn;
|
||||
}
|
||||
@@ -73,17 +77,6 @@ function filterIncrementalToolCallDeltasByAllowed(deltas, allowedNames, seenName
|
||||
return [];
|
||||
}
|
||||
const seen = seenNames instanceof Map ? seenNames : new Map();
|
||||
const allowed = new Set((allowedNames || []).filter((name) => asString(name) !== ''));
|
||||
if (allowed.size === 0) {
|
||||
for (const d of deltas) {
|
||||
if (d && typeof d === 'object' && asString(d.name)) {
|
||||
const index = Number.isInteger(d.index) ? d.index : 0;
|
||||
seen.set(index, '__blocked__');
|
||||
}
|
||||
}
|
||||
return [];
|
||||
}
|
||||
|
||||
const out = [];
|
||||
for (const d of deltas) {
|
||||
if (!d || typeof d !== 'object') {
|
||||
@@ -92,16 +85,12 @@ function filterIncrementalToolCallDeltasByAllowed(deltas, allowedNames, seenName
|
||||
const index = Number.isInteger(d.index) ? d.index : 0;
|
||||
const name = asString(d.name);
|
||||
if (name) {
|
||||
if (!allowed.has(name)) {
|
||||
seen.set(index, '__blocked__');
|
||||
continue;
|
||||
}
|
||||
seen.set(index, name);
|
||||
out.push(d);
|
||||
continue;
|
||||
}
|
||||
const existing = asString(seen.get(index));
|
||||
if (!existing || existing === '__blocked__') {
|
||||
if (!existing) {
|
||||
continue;
|
||||
}
|
||||
out.push(d);
|
||||
|
||||
@@ -1,33 +1,22 @@
|
||||
'use strict';
|
||||
|
||||
const {
|
||||
extractToolNames,
|
||||
createToolSieveState,
|
||||
processToolSieveChunk,
|
||||
flushToolSieve,
|
||||
parseStandaloneToolCalls,
|
||||
formatOpenAIStreamToolCalls,
|
||||
} = require('../helpers/stream-tool-sieve');
|
||||
const {
|
||||
BASE_HEADERS,
|
||||
} = require('../shared/deepseek-constants');
|
||||
|
||||
const {
|
||||
writeOpenAIError,
|
||||
} = require('./error_shape');
|
||||
const {
|
||||
parseChunkForContent,
|
||||
isCitation,
|
||||
} = require('./sse_parse');
|
||||
const {
|
||||
buildUsage,
|
||||
} = require('./token_usage');
|
||||
const { BASE_HEADERS } = require('../shared/deepseek-constants');
|
||||
const { writeOpenAIError } = require('./error_shape');
|
||||
const { parseChunkForContent, isCitation } = require('./sse_parse');
|
||||
const { buildUsage } = require('./token_usage');
|
||||
const {
|
||||
resolveToolcallPolicy,
|
||||
formatIncrementalToolCallDeltas,
|
||||
filterIncrementalToolCallDeltasByAllowed,
|
||||
} = require('./toolcall_policy');
|
||||
const {
|
||||
createChatCompletionEmitter,
|
||||
} = require('./stream_emitter');
|
||||
const { createChatCompletionEmitter } = require('./stream_emitter');
|
||||
const {
|
||||
asString,
|
||||
isAbortError,
|
||||
@@ -57,6 +46,7 @@ async function handleVercelStream(req, res, rawBody, payload) {
|
||||
const searchEnabled = toBool(prep.body.search_enabled);
|
||||
const toolPolicy = resolveToolcallPolicy(prep.body, payload.tools);
|
||||
const toolNames = toolPolicy.toolNames;
|
||||
const emitEarlyToolDeltas = toolPolicy.emitEarlyToolDeltas;
|
||||
|
||||
if (!model || !leaseID || !deepseekToken || !powHeader || !completionPayload) {
|
||||
writeOpenAIError(res, 500, 'invalid vercel prepare response');
|
||||
@@ -132,6 +122,7 @@ async function handleVercelStream(req, res, rawBody, payload) {
|
||||
const toolSieveState = createToolSieveState();
|
||||
let toolCallsEmitted = false;
|
||||
const streamToolCallIDs = new Map();
|
||||
const streamToolNames = new Map();
|
||||
const decoder = new TextDecoder();
|
||||
reader = completionRes.body.getReader();
|
||||
let buffered = '';
|
||||
@@ -255,6 +246,18 @@ async function handleVercelStream(req, res, rawBody, payload) {
|
||||
}
|
||||
const events = processToolSieveChunk(toolSieveState, p.text, toolNames);
|
||||
for (const evt of events) {
|
||||
if (evt.type === 'tool_call_deltas') {
|
||||
if (!emitEarlyToolDeltas) {
|
||||
continue;
|
||||
}
|
||||
const filtered = filterIncrementalToolCallDeltasByAllowed(evt.deltas, toolNames, streamToolNames);
|
||||
const formatted = formatIncrementalToolCallDeltas(filtered, streamToolCallIDs);
|
||||
if (formatted.length > 0) {
|
||||
toolCallsEmitted = true;
|
||||
sendDeltaFrame({ tool_calls: formatted });
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (evt.type === 'tool_calls') {
|
||||
toolCallsEmitted = true;
|
||||
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls, streamToolCallIDs) });
|
||||
|
||||
@@ -140,9 +140,33 @@ function extractJSONObjectFrom(text, start) {
|
||||
return { ok: false, end: 0 };
|
||||
}
|
||||
|
||||
function trimWrappingJSONFence(prefix, suffix) {
|
||||
const rightTrimmedPrefix = (prefix || '').replace(/[ \t\r\n]+$/g, '');
|
||||
const fenceIdx = rightTrimmedPrefix.lastIndexOf('```');
|
||||
if (fenceIdx < 0) return { prefix, suffix };
|
||||
const fenceCount = (rightTrimmedPrefix.slice(0, fenceIdx + 3).match(/```/g) || []).length;
|
||||
if (fenceCount % 2 === 0) {
|
||||
return { prefix, suffix };
|
||||
}
|
||||
const header = rightTrimmedPrefix.slice(fenceIdx + 3).trim().toLowerCase();
|
||||
if (header && header !== 'json') {
|
||||
return { prefix, suffix };
|
||||
}
|
||||
const leftTrimmedSuffix = (suffix || '').replace(/^[ \t\r\n]+/g, '');
|
||||
if (!leftTrimmedSuffix.startsWith('```')) {
|
||||
return { prefix, suffix };
|
||||
}
|
||||
const consumed = (suffix || '').length - leftTrimmedSuffix.length;
|
||||
return {
|
||||
prefix: rightTrimmedPrefix.slice(0, fenceIdx),
|
||||
suffix: (suffix || '').slice(consumed + 3),
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
findObjectFieldValueStart,
|
||||
parseJSONStringLiteral,
|
||||
skipSpaces,
|
||||
extractJSONObjectFrom,
|
||||
trimWrappingJSONFence,
|
||||
};
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user