mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-03 16:05:26 +08:00
Compare commits
119 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
586d31e556 | ||
|
|
b0a09dfab0 | ||
|
|
58f753d0c0 | ||
|
|
2e0586d060 | ||
|
|
1676c8e4f2 | ||
|
|
add13366d2 | ||
|
|
d5a23191f2 | ||
|
|
d2d4e39983 | ||
|
|
6e0dca3b30 | ||
|
|
b108a7915a | ||
|
|
2caabd8ce6 | ||
|
|
c4a73e871a | ||
|
|
6802a3d53e | ||
|
|
e828006cb0 | ||
|
|
a6499cbece | ||
|
|
a504905626 | ||
|
|
59bf78d2c4 | ||
|
|
25b3292497 | ||
|
|
6cf4f0528c | ||
|
|
d8f8dcb704 | ||
|
|
455489ffeb | ||
|
|
5031ae0e6f | ||
|
|
3fccec0e22 | ||
|
|
00d38f1187 | ||
|
|
fe0f3d2c17 | ||
|
|
f67cbfad35 | ||
|
|
11f66db87d | ||
|
|
9afc533153 | ||
|
|
6a39543288 | ||
|
|
7131b06e26 | ||
|
|
8fa1f998aa | ||
|
|
f8936887d0 | ||
|
|
db89744055 | ||
|
|
65312fc573 | ||
|
|
661d753fd3 | ||
|
|
7ca3f141c6 | ||
|
|
d530d25793 | ||
|
|
990cdcf02d | ||
|
|
648bb74587 | ||
|
|
9e5baed061 | ||
|
|
4884773639 | ||
|
|
6758514c61 | ||
|
|
01f33c409f | ||
|
|
55f11e655a | ||
|
|
2275e931f9 | ||
|
|
40594a44db | ||
|
|
67787d9c99 | ||
|
|
7061094964 | ||
|
|
492c603300 | ||
|
|
7e473dffc9 | ||
|
|
43a6e6712f | ||
|
|
ce1b76c90f | ||
|
|
1e7e0b2ae3 | ||
|
|
fd158e5ae2 | ||
|
|
95c96f7744 | ||
|
|
e7f59fac80 | ||
|
|
1bf059396f | ||
|
|
696b403173 | ||
|
|
f4db2732b0 | ||
|
|
ee88a74dcf | ||
|
|
ca08bb66b9 | ||
|
|
708fcb5beb | ||
|
|
7a65d1eaa2 | ||
|
|
6de2457743 | ||
|
|
ce44e260bf | ||
|
|
09f6537ffc | ||
|
|
ab8f494fdb | ||
|
|
b56a211da9 | ||
|
|
fcce5308cb | ||
|
|
d27b19cc53 | ||
|
|
b8ff678f24 | ||
|
|
b24ef1282d | ||
|
|
65e0de3c82 | ||
|
|
0c2743a48c | ||
|
|
dc73e8a6da | ||
|
|
b8495eeeb3 | ||
|
|
b3eae22cef | ||
|
|
7af0098d1b | ||
|
|
17405be300 | ||
|
|
5bc03e5de6 | ||
|
|
5a5f93148d | ||
|
|
32dc5b6099 | ||
|
|
7936d4675f | ||
|
|
808eafa7c6 | ||
|
|
bcb8ed6df2 | ||
|
|
8ec5dcc0cc | ||
|
|
88a79f212d | ||
|
|
b1f8d6192f | ||
|
|
acfb3b225d | ||
|
|
99a6164000 | ||
|
|
e49d9d33e2 | ||
|
|
184a3d1e4e | ||
|
|
c4ec14f49a | ||
|
|
fb5fc0e885 | ||
|
|
20b603666d | ||
|
|
4d549b7102 | ||
|
|
33b0d1d144 | ||
|
|
41c0f7ce28 | ||
|
|
efb484ba4f | ||
|
|
145501d4a5 | ||
|
|
2d5103997b | ||
|
|
52e7e7aae8 | ||
|
|
5b5a4000d7 | ||
|
|
2bbf603148 | ||
|
|
d14b8a0664 | ||
|
|
f16e0b579e | ||
|
|
43cbc4aac0 | ||
|
|
cf569f4749 | ||
|
|
c9c59f2490 | ||
|
|
16216cc2ca | ||
|
|
de50fd3954 | ||
|
|
7648d5f192 | ||
|
|
d35e5eab25 | ||
|
|
90610a52ce | ||
|
|
f6296d506f | ||
|
|
dfea092583 | ||
|
|
af7dc134bb | ||
|
|
2657d37f76 | ||
|
|
7318d1f4a8 |
98
.env.example
98
.env.example
@@ -1,93 +1,15 @@
|
||||
# DS2API environment template (Go runtime)
|
||||
# Copy this file to .env and adjust values.
|
||||
# Updated: 2026-02
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Runtime
|
||||
# ---------------------------------------------------------------
|
||||
# HTTP listen port (default: 5001)
|
||||
# DS2API runtime
|
||||
PORT=5001
|
||||
|
||||
# Log level: DEBUG | INFO | WARN | ERROR
|
||||
LOG_LEVEL=INFO
|
||||
|
||||
# Max concurrent inflight requests per account in managed-key mode.
|
||||
# Default: 2
|
||||
# Recommended client concurrency is calculated dynamically as:
|
||||
# account_count * DS2API_ACCOUNT_MAX_INFLIGHT
|
||||
# So by default it is account_count * 2.
|
||||
# Requests beyond inflight slots enter a waiting queue first.
|
||||
# Default queue size equals recommended concurrency, so 429 starts after:
|
||||
# account_count * DS2API_ACCOUNT_MAX_INFLIGHT * 2
|
||||
# Alias: DS2API_ACCOUNT_CONCURRENCY
|
||||
# DS2API_ACCOUNT_MAX_INFLIGHT=2
|
||||
# Admin authentication
|
||||
DS2API_ADMIN_KEY=change-me
|
||||
|
||||
# Optional waiting queue size override for managed-key mode.
|
||||
# Default: recommended_concurrency (same as account_count * inflight_limit)
|
||||
# Alias: DS2API_ACCOUNT_QUEUE_SIZE
|
||||
# DS2API_ACCOUNT_MAX_QUEUE=10
|
||||
# Config loading (choose one)
|
||||
# 1) file-based config
|
||||
DS2API_CONFIG_PATH=/app/config.json
|
||||
# 2) inline JSON or Base64 JSON
|
||||
# DS2API_CONFIG_JSON=
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Admin auth
|
||||
# ---------------------------------------------------------------
|
||||
# Admin key for /admin login and protected admin APIs.
|
||||
# Default is "admin" when unset, but setting it explicitly is recommended.
|
||||
DS2API_ADMIN_KEY=admin
|
||||
|
||||
# Optional JWT signing secret for admin token.
|
||||
# Defaults to DS2API_ADMIN_KEY when unset.
|
||||
# DS2API_JWT_SECRET=change-me
|
||||
|
||||
# Optional admin JWT validity in hours (default: 24)
|
||||
# DS2API_JWT_EXPIRE_HOURS=24
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Config source (choose one)
|
||||
# ---------------------------------------------------------------
|
||||
# Option A: config file path (local/dev recommended)
|
||||
# DS2API_CONFIG_PATH=config.json
|
||||
|
||||
# Option B: JSON string
|
||||
# DS2API_CONFIG_JSON={"keys":["your-api-key"],"accounts":[{"email":"user@example.com","password":"xxx","token":""}]}
|
||||
|
||||
# Option C: Base64 encoded JSON (recommended for Vercel env var)
|
||||
# DS2API_CONFIG_JSON=eyJrZXlzIjpbInlvdXItYXBpLWtleSJdLCJhY2NvdW50cyI6W3siZW1haWwiOiJ1c2VyQGV4YW1wbGUuY29tIiwicGFzc3dvcmQiOiJ4eHgiLCJ0b2tlbiI6IiJ9XX0=
|
||||
#
|
||||
# Generate from local config.json:
|
||||
# DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Paths (optional)
|
||||
# ---------------------------------------------------------------
|
||||
# WASM file used for PoW solving
|
||||
# DS2API_WASM_PATH=sha3_wasm_bg.7b9ca65ddd.wasm
|
||||
|
||||
# Built admin static assets directory
|
||||
# DS2API_STATIC_ADMIN_DIR=static/admin
|
||||
|
||||
# Auto-build WebUI on startup when static/admin is missing.
|
||||
# Default: enabled on local/Docker, disabled on Vercel.
|
||||
# DS2API_AUTO_BUILD_WEBUI=true
|
||||
|
||||
# Internal auth secret used by the Vercel hybrid streaming path
|
||||
# (Go prepare endpoint <-> Node stream function).
|
||||
# Optional: falls back to DS2API_ADMIN_KEY when unset.
|
||||
# DS2API_VERCEL_INTERNAL_SECRET=change-me
|
||||
|
||||
# Stream lease TTL seconds for Vercel hybrid streaming.
|
||||
# During this window, the managed account stays occupied until Node calls release.
|
||||
# Default: 900 (15 minutes)
|
||||
# DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS=900
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Vercel sync integration (optional)
|
||||
# ---------------------------------------------------------------
|
||||
# VERCEL_TOKEN=your-vercel-token
|
||||
# VERCEL_PROJECT_ID=prj_xxxxxxxxxxxx
|
||||
# VERCEL_TEAM_ID=team_xxxxxxxxxxxx
|
||||
|
||||
# Optional: Vercel deployment protection bypass secret.
|
||||
# If deployment protection is enabled, DS2API will use this value as
|
||||
# x-vercel-protection-bypass for internal Node->Go calls on Vercel.
|
||||
# You can also use VERCEL_AUTOMATION_BYPASS_SECRET directly.
|
||||
# DS2API_VERCEL_PROTECTION_BYPASS=your-bypass-secret
|
||||
# Optional: static admin assets path
|
||||
# DS2API_STATIC_ADMIN_DIR=/app/static/admin
|
||||
|
||||
2
.github/workflows/quality-gates.yml
vendored
2
.github/workflows/quality-gates.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
- name: Setup Node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
node-version: "24"
|
||||
cache: "npm"
|
||||
cache-dependency-path: webui/package-lock.json
|
||||
|
||||
|
||||
8
.github/workflows/release-artifacts.yml
vendored
8
.github/workflows/release-artifacts.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
- name: Setup Node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
node-version: "24"
|
||||
cache: "npm"
|
||||
cache-dependency-path: webui/package-lock.json
|
||||
|
||||
@@ -51,6 +51,10 @@ jobs:
|
||||
run: |
|
||||
set -euo pipefail
|
||||
TAG="${RELEASE_TAG}"
|
||||
BUILD_VERSION="${TAG}"
|
||||
if [ -z "${BUILD_VERSION}" ] && [ -f VERSION ]; then
|
||||
BUILD_VERSION="$(cat VERSION | tr -d '[:space:]')"
|
||||
fi
|
||||
mkdir -p dist
|
||||
|
||||
targets=(
|
||||
@@ -73,7 +77,7 @@ jobs:
|
||||
|
||||
mkdir -p "${STAGE}/static"
|
||||
CGO_ENABLED=0 GOOS="${GOOS}" GOARCH="${GOARCH}" \
|
||||
go build -trimpath -ldflags="-s -w" -o "${STAGE}/${BIN}" ./cmd/ds2api
|
||||
go build -trimpath -ldflags="-s -w -X ds2api/internal/version.BuildVersion=${BUILD_VERSION}" -o "${STAGE}/${BIN}" ./cmd/ds2api
|
||||
|
||||
cp config.example.json .env.example sha3_wasm_bg.7b9ca65ddd.wasm LICENSE README.MD README.en.md "${STAGE}/"
|
||||
cp -R static/admin "${STAGE}/static/admin"
|
||||
|
||||
2
.github/workflows/release-dockerhub.yml
vendored
2
.github/workflows/release-dockerhub.yml
vendored
@@ -123,5 +123,7 @@ jobs:
|
||||
labels: |
|
||||
org.opencontainers.image.version=${{ steps.next_version.outputs.new_version }}
|
||||
org.opencontainers.image.revision=${{ github.sha }}
|
||||
build-args: |
|
||||
BUILD_VERSION=${{ steps.next_version.outputs.new_tag }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
258
.github/workflows/release.yml
vendored
258
.github/workflows/release.yml
vendored
@@ -1,128 +1,130 @@
|
||||
name: Release to Aliyun CR
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version_type:
|
||||
description: '版本类型'
|
||||
required: true
|
||||
default: 'patch'
|
||||
type: choice
|
||||
options:
|
||||
- patch
|
||||
- minor
|
||||
- major
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Get current version
|
||||
id: get_version
|
||||
run: |
|
||||
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
|
||||
TAG_VERSION=${LATEST_TAG#v}
|
||||
|
||||
if [ -f VERSION ]; then
|
||||
FILE_VERSION=$(cat VERSION | tr -d '[:space:]')
|
||||
else
|
||||
FILE_VERSION="0.0.0"
|
||||
fi
|
||||
|
||||
function version_gt() { test "$(printf '%s\n' "$@" | sort -V | head -n 1)" != "$1"; }
|
||||
|
||||
if version_gt "$FILE_VERSION" "$TAG_VERSION"; then
|
||||
VERSION="$FILE_VERSION"
|
||||
else
|
||||
VERSION="$TAG_VERSION"
|
||||
fi
|
||||
|
||||
echo "Current version: $VERSION"
|
||||
echo "current_version=$VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Calculate next version
|
||||
id: next_version
|
||||
env:
|
||||
VERSION_TYPE: ${{ github.event.inputs.version_type }}
|
||||
run: |
|
||||
VERSION="${{ steps.get_version.outputs.current_version }}"
|
||||
BASE_VERSION=$(echo "$VERSION" | sed 's/-.*$//')
|
||||
|
||||
IFS='.' read -r -a version_parts <<< "$BASE_VERSION"
|
||||
MAJOR="${version_parts[0]:-0}"
|
||||
MINOR="${version_parts[1]:-0}"
|
||||
PATCH="${version_parts[2]:-0}"
|
||||
|
||||
case "$VERSION_TYPE" in
|
||||
major)
|
||||
NEW_VERSION="$((MAJOR + 1)).0.0"
|
||||
;;
|
||||
minor)
|
||||
NEW_VERSION="${MAJOR}.$((MINOR + 1)).0"
|
||||
;;
|
||||
*)
|
||||
NEW_VERSION="${MAJOR}.${MINOR}.$((PATCH + 1))"
|
||||
;;
|
||||
esac
|
||||
|
||||
echo "New version: $NEW_VERSION"
|
||||
echo "new_version=$NEW_VERSION" >> $GITHUB_OUTPUT
|
||||
echo "new_tag=v$NEW_VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Update VERSION file
|
||||
run: |
|
||||
echo "${{ steps.next_version.outputs.new_version }}" > VERSION
|
||||
|
||||
- name: Commit VERSION and create tag
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
git add VERSION
|
||||
if ! git diff --cached --quiet; then
|
||||
git commit -m "chore: bump version to ${{ steps.next_version.outputs.new_tag }} [skip ci]"
|
||||
fi
|
||||
|
||||
NEW_TAG="${{ steps.next_version.outputs.new_tag }}"
|
||||
git tag -a "$NEW_TAG" -m "Release $NEW_TAG"
|
||||
git push origin HEAD:main "$NEW_TAG"
|
||||
|
||||
# Docker 构建并推送到阿里云
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Aliyun Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ secrets.ALIYUN_REGISTRY }}
|
||||
username: ${{ secrets.ALIYUN_REGISTRY_USER }}
|
||||
password: ${{ secrets.ALIYUN_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: ./Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ secrets.ALIYUN_REGISTRY }}/${{ secrets.ALIYUN_REGISTRY_NAMESPACE }}/ds2api:${{ steps.next_version.outputs.new_tag }}
|
||||
${{ secrets.ALIYUN_REGISTRY }}/${{ secrets.ALIYUN_REGISTRY_NAMESPACE }}/ds2api:${{ steps.next_version.outputs.new_version }}
|
||||
${{ secrets.ALIYUN_REGISTRY }}/${{ secrets.ALIYUN_REGISTRY_NAMESPACE }}/ds2api:latest
|
||||
labels: |
|
||||
org.opencontainers.image.version=${{ steps.next_version.outputs.new_version }}
|
||||
org.opencontainers.image.revision=${{ github.sha }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
name: Release to Aliyun CR
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version_type:
|
||||
description: '版本类型'
|
||||
required: true
|
||||
default: 'patch'
|
||||
type: choice
|
||||
options:
|
||||
- patch
|
||||
- minor
|
||||
- major
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Get current version
|
||||
id: get_version
|
||||
run: |
|
||||
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
|
||||
TAG_VERSION=${LATEST_TAG#v}
|
||||
|
||||
if [ -f VERSION ]; then
|
||||
FILE_VERSION=$(cat VERSION | tr -d '[:space:]')
|
||||
else
|
||||
FILE_VERSION="0.0.0"
|
||||
fi
|
||||
|
||||
function version_gt() { test "$(printf '%s\n' "$@" | sort -V | head -n 1)" != "$1"; }
|
||||
|
||||
if version_gt "$FILE_VERSION" "$TAG_VERSION"; then
|
||||
VERSION="$FILE_VERSION"
|
||||
else
|
||||
VERSION="$TAG_VERSION"
|
||||
fi
|
||||
|
||||
echo "Current version: $VERSION"
|
||||
echo "current_version=$VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Calculate next version
|
||||
id: next_version
|
||||
env:
|
||||
VERSION_TYPE: ${{ github.event.inputs.version_type }}
|
||||
run: |
|
||||
VERSION="${{ steps.get_version.outputs.current_version }}"
|
||||
BASE_VERSION=$(echo "$VERSION" | sed 's/-.*$//')
|
||||
|
||||
IFS='.' read -r -a version_parts <<< "$BASE_VERSION"
|
||||
MAJOR="${version_parts[0]:-0}"
|
||||
MINOR="${version_parts[1]:-0}"
|
||||
PATCH="${version_parts[2]:-0}"
|
||||
|
||||
case "$VERSION_TYPE" in
|
||||
major)
|
||||
NEW_VERSION="$((MAJOR + 1)).0.0"
|
||||
;;
|
||||
minor)
|
||||
NEW_VERSION="${MAJOR}.$((MINOR + 1)).0"
|
||||
;;
|
||||
*)
|
||||
NEW_VERSION="${MAJOR}.${MINOR}.$((PATCH + 1))"
|
||||
;;
|
||||
esac
|
||||
|
||||
echo "New version: $NEW_VERSION"
|
||||
echo "new_version=$NEW_VERSION" >> $GITHUB_OUTPUT
|
||||
echo "new_tag=v$NEW_VERSION" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Update VERSION file
|
||||
run: |
|
||||
echo "${{ steps.next_version.outputs.new_version }}" > VERSION
|
||||
|
||||
- name: Commit VERSION and create tag
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
git add VERSION
|
||||
if ! git diff --cached --quiet; then
|
||||
git commit -m "chore: bump version to ${{ steps.next_version.outputs.new_tag }} [skip ci]"
|
||||
fi
|
||||
|
||||
NEW_TAG="${{ steps.next_version.outputs.new_tag }}"
|
||||
git tag -a "$NEW_TAG" -m "Release $NEW_TAG"
|
||||
git push origin HEAD:main "$NEW_TAG"
|
||||
|
||||
# Docker 构建并推送到阿里云
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Aliyun Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ secrets.ALIYUN_REGISTRY }}
|
||||
username: ${{ secrets.ALIYUN_REGISTRY_USER }}
|
||||
password: ${{ secrets.ALIYUN_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: ./Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ secrets.ALIYUN_REGISTRY }}/${{ secrets.ALIYUN_REGISTRY_NAMESPACE }}/ds2api:${{ steps.next_version.outputs.new_tag }}
|
||||
${{ secrets.ALIYUN_REGISTRY }}/${{ secrets.ALIYUN_REGISTRY_NAMESPACE }}/ds2api:${{ steps.next_version.outputs.new_version }}
|
||||
${{ secrets.ALIYUN_REGISTRY }}/${{ secrets.ALIYUN_REGISTRY_NAMESPACE }}/ds2api:latest
|
||||
labels: |
|
||||
org.opencontainers.image.version=${{ steps.next_version.outputs.new_version }}
|
||||
org.opencontainers.image.revision=${{ github.sha }}
|
||||
build-args: |
|
||||
BUILD_VERSION=${{ steps.next_version.outputs.new_tag }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
@@ -623,6 +623,7 @@ Reads runtime settings and status, including:
|
||||
- `admin` (JWT expiry, default-password warning, etc.)
|
||||
- `runtime` (`account_max_inflight`, `account_max_queue`, `global_max_inflight`)
|
||||
- `toolcall` / `responses` / `embeddings`
|
||||
- `auto_delete` (`sessions`)
|
||||
- `claude_mapping` / `model_aliases`
|
||||
- `env_backed`, `needs_vercel_sync`
|
||||
|
||||
@@ -635,6 +636,7 @@ Hot-updates runtime settings. Supported fields:
|
||||
- `toolcall.mode` / `toolcall.early_emit_confidence`
|
||||
- `responses.store_ttl_seconds`
|
||||
- `embeddings.provider`
|
||||
- `auto_delete.sessions`
|
||||
- `claude_mapping`
|
||||
- `model_aliases`
|
||||
|
||||
|
||||
2
API.md
2
API.md
@@ -628,6 +628,7 @@ data: {"type":"message_stop"}
|
||||
- `admin`(JWT 过期、默认密码告警等)
|
||||
- `runtime`(`account_max_inflight`、`account_max_queue`、`global_max_inflight`)
|
||||
- `toolcall` / `responses` / `embeddings`
|
||||
- `auto_delete`(`sessions`)
|
||||
- `claude_mapping` / `model_aliases`
|
||||
- `env_backed`、`needs_vercel_sync`
|
||||
|
||||
@@ -640,6 +641,7 @@ data: {"type":"message_stop"}
|
||||
- `toolcall.mode` / `toolcall.early_emit_confidence`
|
||||
- `responses.store_ttl_seconds`
|
||||
- `embeddings.provider`
|
||||
- `auto_delete.sessions`
|
||||
- `claude_mapping`
|
||||
- `model_aliases`
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ ds2api/
|
||||
├── api/
|
||||
│ ├── index.go # Vercel Serverless Go entry
|
||||
│ ├── chat-stream.js # Vercel Node.js stream relay
|
||||
│ └── helpers/ # Node.js helper modules
|
||||
│ └── (rewrite targets in vercel.json)
|
||||
├── internal/
|
||||
│ ├── account/ # Account pool and concurrency queue
|
||||
│ ├── adapter/
|
||||
@@ -112,6 +112,7 @@ ds2api/
|
||||
│ ├── compat/ # Compatibility helpers
|
||||
│ ├── config/ # Config loading and hot-reload
|
||||
│ ├── deepseek/ # DeepSeek client, PoW WASM
|
||||
│ ├── js/ # Node runtime stream/compat logic
|
||||
│ ├── devcapture/ # Dev packet capture
|
||||
│ ├── format/ # Output formatting
|
||||
│ ├── prompt/ # Prompt building
|
||||
@@ -123,7 +124,9 @@ ds2api/
|
||||
│ └── webui/ # WebUI static hosting
|
||||
├── webui/ # React WebUI source
|
||||
│ └── src/
|
||||
│ ├── components/ # Components
|
||||
│ ├── app/ # Routing, auth, config state
|
||||
│ ├── features/ # Feature modules
|
||||
│ ├── components/ # Shared components
|
||||
│ └── locales/ # Language packs
|
||||
├── scripts/ # Build and test scripts
|
||||
├── static/admin/ # WebUI build output (not committed)
|
||||
|
||||
@@ -99,7 +99,7 @@ ds2api/
|
||||
├── api/
|
||||
│ ├── index.go # Vercel Serverless Go 入口
|
||||
│ ├── chat-stream.js # Vercel Node.js 流式转发
|
||||
│ └── helpers/ # Node.js 辅助模块
|
||||
│ └── (rewrite targets in vercel.json)
|
||||
├── internal/
|
||||
│ ├── account/ # 账号池与并发队列
|
||||
│ ├── adapter/
|
||||
@@ -112,6 +112,7 @@ ds2api/
|
||||
│ ├── compat/ # 兼容性辅助
|
||||
│ ├── config/ # 配置加载与热更新
|
||||
│ ├── deepseek/ # DeepSeek 客户端、PoW WASM
|
||||
│ ├── js/ # Node 运行时流式/兼容逻辑
|
||||
│ ├── devcapture/ # 开发抓包
|
||||
│ ├── format/ # 输出格式化
|
||||
│ ├── prompt/ # Prompt 构建
|
||||
@@ -123,7 +124,9 @@ ds2api/
|
||||
│ └── webui/ # WebUI 静态托管
|
||||
├── webui/ # React WebUI 源码
|
||||
│ └── src/
|
||||
│ ├── components/ # 组件
|
||||
│ ├── app/ # 路由、鉴权、配置状态
|
||||
│ ├── features/ # 业务功能模块
|
||||
│ ├── components/ # 通用组件
|
||||
│ └── locales/ # 语言包
|
||||
├── scripts/ # 构建与测试脚本
|
||||
├── static/admin/ # WebUI 构建产物(不提交)
|
||||
|
||||
13
DEPLOY.en.md
13
DEPLOY.en.md
@@ -113,12 +113,8 @@ go build -o ds2api ./cmd/ds2api
|
||||
# Copy env template
|
||||
cp .env.example .env
|
||||
|
||||
# Generate single-line Base64 from config.json
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
|
||||
# Edit .env and set:
|
||||
# Edit .env and set at least:
|
||||
# DS2API_ADMIN_KEY=your-admin-key
|
||||
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
|
||||
|
||||
# Start
|
||||
docker-compose up -d
|
||||
@@ -185,6 +181,7 @@ Notes:
|
||||
|
||||
- **Port**: DS2API listens on `5001` by default; the template sets `PORT=5001`.
|
||||
- **Persistent config**: the template mounts `/data` and sets `DS2API_CONFIG_PATH=/data/config.json`. After importing config in Admin UI, it will be written and persisted to this path.
|
||||
- **Build version**: Zeabur / regular `docker build` does not require `BUILD_VERSION` by default. The image prefers that build arg when provided, and automatically falls back to the repo-root `VERSION` file when it is absent.
|
||||
- **First login**: after deployment, open `/admin` and login with `DS2API_ADMIN_KEY` shown in Zeabur env/template instructions (recommended: rotate to a strong secret after first login).
|
||||
|
||||
---
|
||||
@@ -366,7 +363,7 @@ Each archive includes:
|
||||
|
||||
- `ds2api` executable (`ds2api.exe` on Windows)
|
||||
- `static/admin/` (built WebUI assets)
|
||||
- `sha3_wasm_bg.7b9ca65ddd.wasm`
|
||||
- `sha3_wasm_bg.7b9ca65ddd.wasm` (optional; binary has embedded fallback)
|
||||
- `config.example.json`, `.env.example`
|
||||
- `README.MD`, `README.en.md`, `LICENSE`
|
||||
|
||||
@@ -455,7 +452,9 @@ server {
|
||||
```bash
|
||||
# Copy compiled binary and related files to target directory
|
||||
sudo mkdir -p /opt/ds2api
|
||||
sudo cp ds2api config.json sha3_wasm_bg.7b9ca65ddd.wasm /opt/ds2api/
|
||||
sudo cp ds2api config.json /opt/ds2api/
|
||||
# Optional: if you want to use an external WASM file (override embedded one)
|
||||
# sudo cp sha3_wasm_bg.7b9ca65ddd.wasm /opt/ds2api/
|
||||
sudo cp -r static/admin /opt/ds2api/static/admin
|
||||
```
|
||||
|
||||
|
||||
13
DEPLOY.md
13
DEPLOY.md
@@ -113,12 +113,8 @@ go build -o ds2api ./cmd/ds2api
|
||||
# 复制环境变量模板
|
||||
cp .env.example .env
|
||||
|
||||
# 从 config.json 生成单行 Base64
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
|
||||
# 编辑 .env(请改成你的强密码),设置:
|
||||
# 编辑 .env(请改成你的强密码),至少设置:
|
||||
# DS2API_ADMIN_KEY=your-admin-key
|
||||
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
|
||||
|
||||
# 启动
|
||||
docker-compose up -d
|
||||
@@ -185,6 +181,7 @@ healthcheck:
|
||||
|
||||
- **端口**:服务默认监听 `5001`,模板会固定设置 `PORT=5001`。
|
||||
- **配置持久化**:模板挂载卷 `/data`,并设置 `DS2API_CONFIG_PATH=/data/config.json`;在管理台导入配置后,会写入并持久化到该路径。
|
||||
- **构建版本号**:Zeabur / 普通 `docker build` 默认不需要传 `BUILD_VERSION`;镜像会优先使用该构建参数,未提供时自动回退到仓库根目录的 `VERSION` 文件。
|
||||
- **首次登录**:部署完成后访问 `/admin`,使用 Zeabur 环境变量/模板指引中的 `DS2API_ADMIN_KEY` 登录(建议首次登录后自行更换为强密码)。
|
||||
|
||||
---
|
||||
@@ -366,7 +363,7 @@ No Output Directory named "public" found after the Build completed.
|
||||
|
||||
- `ds2api` 可执行文件(Windows 为 `ds2api.exe`)
|
||||
- `static/admin/`(WebUI 构建产物)
|
||||
- `sha3_wasm_bg.7b9ca65ddd.wasm`
|
||||
- `sha3_wasm_bg.7b9ca65ddd.wasm`(可选;程序内置 embed fallback)
|
||||
- `config.example.json`、`.env.example`
|
||||
- `README.MD`、`README.en.md`、`LICENSE`
|
||||
|
||||
@@ -455,7 +452,9 @@ server {
|
||||
```bash
|
||||
# 将编译好的二进制文件和相关文件复制到目标目录
|
||||
sudo mkdir -p /opt/ds2api
|
||||
sudo cp ds2api config.json sha3_wasm_bg.7b9ca65ddd.wasm /opt/ds2api/
|
||||
sudo cp ds2api config.json /opt/ds2api/
|
||||
# 可选:若你希望使用外置 WASM 文件(覆盖内置版本)
|
||||
# sudo cp sha3_wasm_bg.7b9ca65ddd.wasm /opt/ds2api/
|
||||
sudo cp -r static/admin /opt/ds2api/static/admin
|
||||
```
|
||||
|
||||
|
||||
@@ -10,19 +10,24 @@ FROM golang:1.24 AS go-builder
|
||||
WORKDIR /app
|
||||
ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
ARG BUILD_VERSION
|
||||
COPY go.mod go.sum* ./
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
RUN set -eux; \
|
||||
GOOS="${TARGETOS:-$(go env GOOS)}"; \
|
||||
GOARCH="${TARGETARCH:-$(go env GOARCH)}"; \
|
||||
CGO_ENABLED=0 GOOS="${GOOS}" GOARCH="${GOARCH}" go build -o /out/ds2api ./cmd/ds2api
|
||||
BUILD_VERSION_RESOLVED="${BUILD_VERSION:-}"; \
|
||||
if [ -z "${BUILD_VERSION_RESOLVED}" ] && [ -f VERSION ]; then BUILD_VERSION_RESOLVED="$(cat VERSION | tr -d "[:space:]")"; fi; \
|
||||
CGO_ENABLED=0 GOOS="${GOOS}" GOARCH="${GOARCH}" go build -ldflags="-s -w -X ds2api/internal/version.BuildVersion=${BUILD_VERSION_RESOLVED}" -o /out/ds2api ./cmd/ds2api
|
||||
|
||||
FROM busybox:1.36.1-musl AS busybox-tools
|
||||
|
||||
FROM debian:bookworm-slim AS runtime-base
|
||||
WORKDIR /app
|
||||
COPY --from=go-builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
COPY --from=busybox-tools /bin/busybox /usr/local/bin/busybox
|
||||
EXPOSE 5001
|
||||
CMD ["/usr/local/bin/ds2api"]
|
||||
|
||||
56
README.MD
56
README.MD
@@ -160,17 +160,13 @@ go run ./cmd/ds2api
|
||||
# 1. 准备环境变量文件
|
||||
cp .env.example .env
|
||||
|
||||
# 2. 从 config.json 生成 DS2API_CONFIG_JSON(单行 Base64)
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
|
||||
# 3. 编辑 .env,设置:
|
||||
# 2. 编辑 .env(至少设置 DS2API_ADMIN_KEY)
|
||||
# DS2API_ADMIN_KEY=请替换为强密码
|
||||
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
|
||||
|
||||
# 4. 启动
|
||||
# 3. 启动
|
||||
docker-compose up -d
|
||||
|
||||
# 5. 查看日志
|
||||
# 4. 查看日志
|
||||
docker-compose logs -f
|
||||
```
|
||||
|
||||
@@ -182,6 +178,8 @@ docker-compose logs -f
|
||||
2. 部署完成后访问 `/admin`,使用 Zeabur 环境变量/模板指引中的 `DS2API_ADMIN_KEY` 登录。
|
||||
3. 在管理台导入/编辑配置(会写入并持久化到 `/data/config.json`)。
|
||||
|
||||
说明:Zeabur 使用仓库内 `Dockerfile` 直接构建时,不需要额外传入 `BUILD_VERSION`;镜像会优先读取该构建参数,未提供时自动回退到仓库根目录的 `VERSION` 文件。
|
||||
|
||||
### 方式三:Vercel 部署
|
||||
|
||||
1. Fork 仓库到自己的 GitHub
|
||||
@@ -246,13 +244,11 @@ cp opencode.json.example opencode.json
|
||||
"accounts": [
|
||||
{
|
||||
"email": "user@example.com",
|
||||
"password": "your-password",
|
||||
"token": ""
|
||||
"password": "your-password"
|
||||
},
|
||||
{
|
||||
"mobile": "12345678901",
|
||||
"password": "your-password",
|
||||
"token": ""
|
||||
"password": "your-password"
|
||||
}
|
||||
],
|
||||
"model_aliases": {
|
||||
@@ -273,7 +269,7 @@ cp opencode.json.example opencode.json
|
||||
"embeddings": {
|
||||
"provider": "deterministic"
|
||||
},
|
||||
"claude_model_mapping": {
|
||||
"claude_mapping": {
|
||||
"fast": "deepseek-chat",
|
||||
"slow": "deepseek-reasoner"
|
||||
},
|
||||
@@ -284,21 +280,25 @@ cp opencode.json.example opencode.json
|
||||
"account_max_inflight": 2,
|
||||
"account_max_queue": 0,
|
||||
"global_max_inflight": 0
|
||||
},
|
||||
"auto_delete": {
|
||||
"sessions": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
- `keys`:API 访问密钥列表,客户端通过 `Authorization: Bearer <key>` 鉴权
|
||||
- `accounts`:DeepSeek 账号列表,支持 `email` 或 `mobile` 登录
|
||||
- `token`:留空则首次请求时自动登录获取;也可预填已有 token
|
||||
- `token`:配置文件中即使填写也会在加载时被清空(不会从 `config.json` 读取 token);实际 token 仅在运行时内存中维护并自动刷新
|
||||
- `model_aliases`:常见模型名(如 GPT/Codex/Claude)到 DeepSeek 模型的映射
|
||||
- `compat.wide_input_strict_output`:建议保持 `true`(当前实现默认宽进严出)
|
||||
- `toolcall`:固定采用特征匹配 + 高置信早发策略
|
||||
- `responses.store_ttl_seconds`:`/v1/responses/{id}` 的内存缓存 TTL
|
||||
- `embeddings.provider`:embedding 提供方(当前内置 `deterministic/mock/builtin`)
|
||||
- `claude_model_mapping`:字典中 `fast`/`slow` 后缀映射到对应 DeepSeek 模型
|
||||
- `claude_mapping`:字典中 `fast`/`slow` 后缀映射到对应 DeepSeek 模型(兼容读取 `claude_model_mapping`)
|
||||
- `admin`:管理后台设置(JWT 过期时间、密码哈希等),可通过 Admin Settings API 热更新
|
||||
- `runtime`:运行时参数(并发限制、队列大小),可通过 Admin Settings API 热更新
|
||||
- `runtime`:运行时参数(并发限制、队列大小),可通过 Admin Settings API 热更新;`account_max_queue=0`/`global_max_inflight=0` 表示按推荐值自动计算
|
||||
- `auto_delete.sessions`:是否在请求结束后自动清理 DeepSeek 会话(默认 `false`,可在 Settings 热更新)
|
||||
|
||||
### 环境变量
|
||||
|
||||
@@ -397,7 +397,7 @@ ds2api/
|
||||
├── api/
|
||||
│ ├── index.go # Vercel Serverless Go 入口
|
||||
│ ├── chat-stream.js # Vercel Node.js 流式转发
|
||||
│ └── helpers/ # Node.js 辅助模块
|
||||
│ └── (rewrite targets in vercel.json)
|
||||
├── internal/
|
||||
│ ├── account/ # 账号池与并发队列
|
||||
│ ├── adapter/
|
||||
@@ -410,6 +410,7 @@ ds2api/
|
||||
│ ├── compat/ # 兼容性辅助
|
||||
│ ├── config/ # 配置加载与热更新
|
||||
│ ├── deepseek/ # DeepSeek API 客户端、PoW WASM
|
||||
│ ├── js/ # Node 运行时流式处理与兼容逻辑
|
||||
│ ├── devcapture/ # 开发抓包模块
|
||||
│ ├── format/ # 输出格式化
|
||||
│ ├── prompt/ # Prompt 构建
|
||||
@@ -420,7 +421,9 @@ ds2api/
|
||||
│ └── webui/ # WebUI 静态文件托管与自动构建
|
||||
├── webui/ # React WebUI 源码(Vite + Tailwind)
|
||||
│ └── src/
|
||||
│ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage
|
||||
│ ├── app/ # 路由、鉴权、配置状态管理
|
||||
│ ├── features/ # 业务功能模块(account/settings/vercel/apiTester)
|
||||
│ ├── components/ # 登录/落地页等通用组件
|
||||
│ └── locales/ # 中英文语言包(zh.json / en.json)
|
||||
├── scripts/
|
||||
│ └── build-webui.sh # WebUI 手动构建脚本
|
||||
@@ -476,6 +479,23 @@ go run ./cmd/ds2api-tests \
|
||||
npm ci --prefix webui && npm run build --prefix webui
|
||||
```
|
||||
|
||||
## 测试
|
||||
|
||||
详细测试指南请参阅 [TESTING.md](TESTING.md)。
|
||||
|
||||
### 快速测试命令
|
||||
|
||||
```bash
|
||||
# 运行所有单元测试
|
||||
go test ./...
|
||||
|
||||
# 运行 tool calls 相关测试(调试工具调用问题)
|
||||
go test -v -run 'TestParseToolCalls|TestRepair' ./internal/util/
|
||||
|
||||
# 运行端到端测试
|
||||
./tests/scripts/run-live.sh
|
||||
```
|
||||
|
||||
## Release 自动构建(GitHub Actions)
|
||||
|
||||
工作流文件:`.github/workflows/release-artifacts.yml`
|
||||
@@ -483,7 +503,7 @@ npm ci --prefix webui && npm run build --prefix webui
|
||||
- **触发条件**:仅在 GitHub Release `published` 时触发(普通 push 不会触发)
|
||||
- **构建产物**:多平台二进制包(`linux/amd64`、`linux/arm64`、`darwin/amd64`、`darwin/arm64`、`windows/amd64`)+ `sha256sums.txt`
|
||||
- **容器镜像发布**:仅推送到 GHCR(`ghcr.io/cjackhwang/ds2api`)
|
||||
- **每个压缩包包含**:`ds2api` 可执行文件、`static/admin`、WASM 文件、配置示例、README、LICENSE
|
||||
- **每个压缩包包含**:`ds2api` 可执行文件、`static/admin`、WASM 文件(同时支持内置 fallback)、配置示例、README、LICENSE
|
||||
|
||||
## 免责声明
|
||||
|
||||
|
||||
39
README.en.md
39
README.en.md
@@ -160,17 +160,13 @@ Default URL: `http://localhost:5001`
|
||||
# 1. Prepare env file
|
||||
cp .env.example .env
|
||||
|
||||
# 2. Generate DS2API_CONFIG_JSON from config.json (single-line Base64)
|
||||
DS2API_CONFIG_JSON="$(base64 < config.json | tr -d '\n')"
|
||||
|
||||
# 3. Edit .env and set:
|
||||
# 2. Edit .env (at least set DS2API_ADMIN_KEY)
|
||||
# DS2API_ADMIN_KEY=replace-with-a-strong-secret
|
||||
# DS2API_CONFIG_JSON=${DS2API_CONFIG_JSON}
|
||||
|
||||
# 4. Start
|
||||
# 3. Start
|
||||
docker-compose up -d
|
||||
|
||||
# 5. View logs
|
||||
# 4. View logs
|
||||
docker-compose logs -f
|
||||
```
|
||||
|
||||
@@ -182,6 +178,8 @@ Rebuild after updates: `docker-compose up -d --build`
|
||||
2. After deployment, open `/admin` and login with `DS2API_ADMIN_KEY` shown in Zeabur env/template instructions.
|
||||
3. Import / edit config in Admin UI (it will be written and persisted to `/data/config.json`).
|
||||
|
||||
Note: when Zeabur builds directly from the repo `Dockerfile`, you do not need to pass `BUILD_VERSION`. The image prefers that build arg when provided, and automatically falls back to the repo-root `VERSION` file when it is absent.
|
||||
|
||||
### Option 3: Vercel
|
||||
|
||||
1. Fork this repo to your GitHub account
|
||||
@@ -246,13 +244,11 @@ cp opencode.json.example opencode.json
|
||||
"accounts": [
|
||||
{
|
||||
"email": "user@example.com",
|
||||
"password": "your-password",
|
||||
"token": ""
|
||||
"password": "your-password"
|
||||
},
|
||||
{
|
||||
"mobile": "12345678901",
|
||||
"password": "your-password",
|
||||
"token": ""
|
||||
"password": "your-password"
|
||||
}
|
||||
],
|
||||
"model_aliases": {
|
||||
@@ -273,7 +269,7 @@ cp opencode.json.example opencode.json
|
||||
"embeddings": {
|
||||
"provider": "deterministic"
|
||||
},
|
||||
"claude_model_mapping": {
|
||||
"claude_mapping": {
|
||||
"fast": "deepseek-chat",
|
||||
"slow": "deepseek-reasoner"
|
||||
},
|
||||
@@ -284,21 +280,25 @@ cp opencode.json.example opencode.json
|
||||
"account_max_inflight": 2,
|
||||
"account_max_queue": 0,
|
||||
"global_max_inflight": 0
|
||||
},
|
||||
"auto_delete": {
|
||||
"sessions": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
- `keys`: API access keys; clients authenticate via `Authorization: Bearer <key>`
|
||||
- `accounts`: DeepSeek account list, supports `email` or `mobile` login
|
||||
- `token`: Leave empty for auto-login on first request; or pre-fill an existing token
|
||||
- `token`: Even if set in `config.json`, it is cleared during load (DS2API does not read persisted tokens from config); runtime tokens are maintained/refreshed in memory only
|
||||
- `model_aliases`: Map common model names (GPT/Codex/Claude) to DeepSeek models
|
||||
- `compat.wide_input_strict_output`: Keep `true` (current default policy)
|
||||
- `toolcall`: Fixed to feature matching + high-confidence early emit
|
||||
- `responses.store_ttl_seconds`: In-memory TTL for `/v1/responses/{id}`
|
||||
- `embeddings.provider`: Embeddings provider (`deterministic/mock/builtin` built-in)
|
||||
- `claude_model_mapping`: Maps `fast`/`slow` suffixes to corresponding DeepSeek models
|
||||
- `claude_mapping`: Maps `fast`/`slow` suffixes to corresponding DeepSeek models (still compatible with `claude_model_mapping`)
|
||||
- `admin`: Admin panel settings (JWT expiry, password hash, etc.), hot-reloadable via Admin Settings API
|
||||
- `runtime`: Runtime parameters (concurrency limits, queue sizes), hot-reloadable via Admin Settings API
|
||||
- `runtime`: Runtime parameters (concurrency limits, queue sizes), hot-reloadable via Admin Settings API; `account_max_queue=0`/`global_max_inflight=0` means auto-calculate from recommended values
|
||||
- `auto_delete.sessions`: Whether to auto-delete DeepSeek sessions after request completion (default `false`, hot-reloadable via Settings)
|
||||
|
||||
### Environment Variables
|
||||
|
||||
@@ -398,7 +398,7 @@ ds2api/
|
||||
├── api/
|
||||
│ ├── index.go # Vercel Serverless Go entry
|
||||
│ ├── chat-stream.js # Vercel Node.js stream relay
|
||||
│ └── helpers/ # Node.js helper modules
|
||||
│ └── (rewrite targets in vercel.json)
|
||||
├── internal/
|
||||
│ ├── account/ # Account pool and concurrency queue
|
||||
│ ├── adapter/
|
||||
@@ -411,6 +411,7 @@ ds2api/
|
||||
│ ├── compat/ # Compatibility helpers
|
||||
│ ├── config/ # Config loading and hot-reload
|
||||
│ ├── deepseek/ # DeepSeek API client, PoW WASM
|
||||
│ ├── js/ # Node runtime stream/compat logic
|
||||
│ ├── devcapture/ # Dev packet capture module
|
||||
│ ├── format/ # Output formatting
|
||||
│ ├── prompt/ # Prompt construction
|
||||
@@ -421,7 +422,9 @@ ds2api/
|
||||
│ └── webui/ # WebUI static file serving and auto-build
|
||||
├── webui/ # React WebUI source (Vite + Tailwind)
|
||||
│ └── src/
|
||||
│ ├── components/ # AccountManager / ApiTester / BatchImport / VercelSync / Login / LandingPage
|
||||
│ ├── app/ # Routing, auth, config state
|
||||
│ ├── features/ # Feature modules (account/settings/vercel/apiTester)
|
||||
│ ├── components/ # Shared UI pieces (login/landing, etc.)
|
||||
│ └── locales/ # Language packs (zh.json / en.json)
|
||||
├── scripts/
|
||||
│ └── build-webui.sh # Manual WebUI build script
|
||||
@@ -484,7 +487,7 @@ Workflow: `.github/workflows/release-artifacts.yml`
|
||||
- **Trigger**: only on GitHub Release `published` (normal pushes do not trigger builds)
|
||||
- **Outputs**: multi-platform archives (`linux/amd64`, `linux/arm64`, `darwin/amd64`, `darwin/arm64`, `windows/amd64`) + `sha256sums.txt`
|
||||
- **Container publishing**: GHCR only (`ghcr.io/cjackhwang/ds2api`)
|
||||
- **Each archive includes**: `ds2api` executable, `static/admin`, WASM file, config template, README, LICENSE
|
||||
- **Each archive includes**: `ds2api` executable, `static/admin`, WASM file (with embedded fallback support), config template, README, LICENSE
|
||||
|
||||
## Disclaimer
|
||||
|
||||
|
||||
46
TESTING.md
46
TESTING.md
@@ -51,7 +51,7 @@ DS2API 提供两个层级的测试:
|
||||
1. **Preflight 检查**:
|
||||
- `go test ./... -count=1`(单元测试)
|
||||
- `./tests/scripts/check-node-split-syntax.sh`(Node 拆分模块语法门禁)
|
||||
- `node --test`(如仓库存在 Node 单测文件时执行;当前默认以 Go 测试 + Node 语法门禁为主)
|
||||
- `node --test tests/node/stream-tool-sieve.test.js tests/node/chat-stream.test.js tests/node/js_compat_test.js`
|
||||
- `npm run build --prefix webui`(WebUI 构建检查)
|
||||
|
||||
2. **隔离启动**:复制 `config.json` 到临时目录,启动独立服务进程
|
||||
@@ -173,6 +173,50 @@ rg "<trace_id>" artifacts/testsuite/<run_id>/server.log
|
||||
go test ./...
|
||||
```
|
||||
|
||||
### 运行特定模块的单元测试
|
||||
|
||||
```bash
|
||||
# 运行 tool calls 相关测试(推荐用于调试 tool call 解析问题)
|
||||
go test -v -run 'TestParseToolCalls|TestRepair' ./internal/util/
|
||||
|
||||
# 运行单个测试用例
|
||||
go test -v -run TestParseToolCallsWithDeepSeekHallucination ./internal/util/
|
||||
|
||||
# 运行 format 相关测试
|
||||
go test -v ./internal/format/...
|
||||
|
||||
# 运行 adapter 相关测试
|
||||
go test -v ./internal/adapter/openai/...
|
||||
```
|
||||
|
||||
### 调试 Tool Call 问题 | Debugging Tool Call Issues
|
||||
|
||||
当遇到 DeepSeek 工具调用解析问题时,可以使用以下方法:
|
||||
|
||||
```bash
|
||||
# 1. 运行 tool calls 相关的所有测试
|
||||
go test -v -run 'TestParseToolCalls|TestRepair' ./internal/util/
|
||||
|
||||
# 2. 查看测试输出中的详细调试信息
|
||||
go test -v -run TestParseToolCallsWithDeepSeekHallucination ./internal/util/ 2>&1
|
||||
|
||||
# 3. 检查具体测试用例的修复效果
|
||||
# 测试用例位于 internal/util/toolcalls_test.go,包含:
|
||||
# - TestParseToolCallsWithDeepSeekHallucination: DeepSeek 典型幻觉输出
|
||||
# - TestRepairLooseJSONWithNestedObjects: 嵌套对象的方括号修复
|
||||
# - TestParseToolCallsWithMixedWindowsPaths: Windows 路径处理
|
||||
```
|
||||
|
||||
### 运行 Node.js 测试
|
||||
|
||||
```bash
|
||||
# 运行 Node 测试
|
||||
node --test tests/node/stream-tool-sieve.test.js
|
||||
|
||||
# 或使用脚本
|
||||
./tests/scripts/run-unit-node.sh
|
||||
```
|
||||
|
||||
### 跑端到端测试(跳过 preflight)
|
||||
|
||||
```bash
|
||||
|
||||
@@ -9,20 +9,17 @@
|
||||
{
|
||||
"_comment": "邮箱登录方式",
|
||||
"email": "example1@example.com",
|
||||
"password": "your-password-1",
|
||||
"token": ""
|
||||
"password": "your-password-1"
|
||||
},
|
||||
{
|
||||
"_comment": "邮箱登录方式 - 账号2",
|
||||
"email": "example2@example.com",
|
||||
"password": "your-password-2",
|
||||
"token": ""
|
||||
"password": "your-password-2"
|
||||
},
|
||||
{
|
||||
"_comment": "手机号登录方式(中国大陆)",
|
||||
"mobile": "12345678901",
|
||||
"password": "your-password-3",
|
||||
"token": ""
|
||||
"password": "your-password-3"
|
||||
}
|
||||
],
|
||||
"model_aliases": {
|
||||
@@ -43,8 +40,19 @@
|
||||
"embeddings": {
|
||||
"provider": "deterministic"
|
||||
},
|
||||
"claude_model_mapping": {
|
||||
"claude_mapping": {
|
||||
"fast": "deepseek-chat",
|
||||
"slow": "deepseek-reasoner"
|
||||
},
|
||||
"admin": {
|
||||
"jwt_expire_hours": 24
|
||||
},
|
||||
"runtime": {
|
||||
"account_max_inflight": 2,
|
||||
"account_max_queue": 0,
|
||||
"global_max_inflight": 0
|
||||
},
|
||||
"auto_delete": {
|
||||
"sessions": false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -194,7 +194,7 @@ func TestPoolAccountConcurrencyAliasEnv(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolSupportsTokenOnlyAccount(t *testing.T) {
|
||||
func TestPoolDropsLegacyTokenOnlyAccountOnLoad(t *testing.T) {
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1")
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["k1"],
|
||||
@@ -203,19 +203,15 @@ func TestPoolSupportsTokenOnlyAccount(t *testing.T) {
|
||||
|
||||
pool := NewPool(config.LoadStore())
|
||||
status := pool.Status()
|
||||
if got, ok := status["total"].(int); !ok || got != 1 {
|
||||
if got, ok := status["total"].(int); !ok || got != 0 {
|
||||
t.Fatalf("unexpected total in pool status: %#v", status["total"])
|
||||
}
|
||||
if got, ok := status["available"].(int); !ok || got != 1 {
|
||||
if got, ok := status["available"].(int); !ok || got != 0 {
|
||||
t.Fatalf("unexpected available in pool status: %#v", status["available"])
|
||||
}
|
||||
|
||||
acc, ok := pool.Acquire("", nil)
|
||||
if !ok {
|
||||
t.Fatalf("expected acquire success for token-only account")
|
||||
}
|
||||
if acc.Token != "token-only-account" {
|
||||
t.Fatalf("unexpected token on acquired account: %q", acc.Token)
|
||||
if _, ok := pool.Acquire("", nil); ok {
|
||||
t.Fatalf("expected acquire to fail for token-only account")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -358,7 +358,7 @@ func TestHandleClaudeStreamRealtimeToolSafetyAcrossStructuredFormats(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeDoesNotStopOnUnclosedFencedToolExample(t *testing.T) {
|
||||
func TestHandleClaudeStreamRealtimeIgnoresUnclosedFencedToolExample(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
"data: {\"p\":\"response/content\",\"v\":\"Here is an example:\\n```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"Bash\\\",\\\"input\\\":{\\\"command\\\":\\\"pwd\\\"}}]}\"}",
|
||||
@@ -371,22 +371,32 @@ func TestHandleClaudeStreamRealtimeDoesNotStopOnUnclosedFencedToolExample(t *tes
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "show example only"}}, false, false, []string{"Bash"})
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
foundToolUse := false
|
||||
for _, f := range findClaudeFrames(frames, "content_block_start") {
|
||||
contentBlock, _ := f.Payload["content_block"].(map[string]any)
|
||||
if contentBlock["type"] == "tool_use" {
|
||||
t.Fatalf("unexpected tool_use for fenced example, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
foundEndTurn := false
|
||||
for _, f := range findClaudeFrames(frames, "message_delta") {
|
||||
delta, _ := f.Payload["delta"].(map[string]any)
|
||||
if delta["stop_reason"] == "end_turn" {
|
||||
foundEndTurn = true
|
||||
foundToolUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundEndTurn {
|
||||
t.Fatalf("expected stop_reason=end_turn, body=%s", rec.Body.String())
|
||||
if foundToolUse {
|
||||
t.Fatalf("expected no tool_use for fenced example, body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
foundToolStop := false
|
||||
for _, f := range findClaudeFrames(frames, "message_delta") {
|
||||
delta, _ := f.Payload["delta"].(map[string]any)
|
||||
if delta["stop_reason"] == "tool_use" {
|
||||
foundToolStop = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundToolStop {
|
||||
t.Fatalf("expected stop_reason to remain content-only, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// Backward-compatible alias for historical test name used in CI logs.
|
||||
func TestHandleClaudeStreamRealtimePromotesUnclosedFencedToolExample(t *testing.T) {
|
||||
TestHandleClaudeStreamRealtimeIgnoresUnclosedFencedToolExample(t)
|
||||
}
|
||||
|
||||
@@ -48,10 +48,85 @@ func TestNormalizeClaudeMessagesToolResult(t *testing.T) {
|
||||
},
|
||||
}
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %d", len(got))
|
||||
}
|
||||
m := got[0].(map[string]any)
|
||||
if m["role"] != "tool" {
|
||||
t.Fatalf("expected tool role preserved, got %#v", m["role"])
|
||||
}
|
||||
content, _ := m["content"].(string)
|
||||
if !strings.Contains(content, "[TOOL_RESULT_HISTORY]") || !strings.Contains(content, "content: tool output") {
|
||||
t.Fatalf("expected serialized tool result marker, got %q", content)
|
||||
if content != "tool output" {
|
||||
t.Fatalf("expected raw tool output content preserved, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesToolUseToAssistantToolCalls(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_1",
|
||||
"name": "search_web",
|
||||
"input": map[string]any{"query": "latest"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one normalized tool-call message, got %d", len(got))
|
||||
}
|
||||
m := got[0].(map[string]any)
|
||||
if m["role"] != "assistant" {
|
||||
t.Fatalf("expected assistant role, got %#v", m["role"])
|
||||
}
|
||||
tc, _ := m["tool_calls"].([]any)
|
||||
if len(tc) != 1 {
|
||||
t.Fatalf("expected one tool call, got %#v", m["tool_calls"])
|
||||
}
|
||||
call, _ := tc[0].(map[string]any)
|
||||
if call["id"] != "call_1" {
|
||||
t.Fatalf("expected call id preserved, got %#v", call)
|
||||
}
|
||||
content, _ := m["content"].(string)
|
||||
if !containsStr(content, "search_web") || !containsStr(content, `"arguments":"{\"query\":\"latest\"}"`) {
|
||||
t.Fatalf("expected assistant content to include serialized tool call for prompt roundtrip, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesDoesNotPromoteUserToolUse(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_unsafe",
|
||||
"name": "dangerous_tool",
|
||||
"input": map[string]any{"value": "x"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %d", len(got))
|
||||
}
|
||||
m := got[0].(map[string]any)
|
||||
if m["role"] != "user" {
|
||||
t.Fatalf("expected user role preserved, got %#v", m["role"])
|
||||
}
|
||||
if _, ok := m["tool_calls"]; ok {
|
||||
t.Fatalf("expected no tool_calls promotion for user message, got %#v", m["tool_calls"])
|
||||
}
|
||||
content, _ := m["content"].(string)
|
||||
if !containsStr(content, `"type":"tool_use"`) || !containsStr(content, "dangerous_tool") {
|
||||
t.Fatalf("expected raw tool_use block preserved in user content, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,15 +162,63 @@ func TestNormalizeClaudeMessagesMixedContentBlocks(t *testing.T) {
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "Hello"},
|
||||
map[string]any{"type": "image", "source": "data:..."},
|
||||
map[string]any{"type": "image", "source": map[string]any{"type": "base64", "data": strings.Repeat("A", 2048)}},
|
||||
map[string]any{"type": "text", "text": "World"},
|
||||
},
|
||||
},
|
||||
}
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
m := got[0].(map[string]any)
|
||||
if m["content"] != "Hello\nWorld" {
|
||||
t.Fatalf("expected only text parts joined, got %q", m["content"])
|
||||
content, _ := m["content"].(string)
|
||||
if !containsStr(content, "Hello") || !containsStr(content, "World") || !containsStr(content, `"type":"image"`) {
|
||||
t.Fatalf("expected text plus non-text block marker preserved, got %q", content)
|
||||
}
|
||||
if !containsStr(content, omittedBinaryMarker) {
|
||||
t.Fatalf("expected binary payload omitted marker, got %q", content)
|
||||
}
|
||||
if containsStr(content, strings.Repeat("A", 100)) {
|
||||
t.Fatalf("expected raw base64 payload not to be included, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesToolResultNonTextPayloadStringified(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "call_image_1",
|
||||
"name": "vision_tool",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "image analysis"},
|
||||
map[string]any{
|
||||
"type": "image",
|
||||
"source": map[string]any{"type": "base64", "media_type": "image/png", "data": strings.Repeat("B", 2048)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %d", len(got))
|
||||
}
|
||||
m := got[0].(map[string]any)
|
||||
if m["role"] != "tool" {
|
||||
t.Fatalf("expected tool role, got %#v", m["role"])
|
||||
}
|
||||
content, _ := m["content"].(string)
|
||||
if !containsStr(content, `"type":"tool_result"`) || !containsStr(content, `"type":"image"`) {
|
||||
t.Fatalf("expected non-text tool_result payload to be JSON stringified, got %q", content)
|
||||
}
|
||||
if !containsStr(content, omittedBinaryMarker) {
|
||||
t.Fatalf("expected binary data to be sanitized with omitted marker, got %q", content)
|
||||
}
|
||||
if containsStr(content, strings.Repeat("B", 100)) {
|
||||
t.Fatalf("expected raw base64 payload not to be included, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,8 +251,11 @@ func TestBuildClaudeToolPromptSingleTool(t *testing.T) {
|
||||
if !containsStr(prompt, "tool_use") {
|
||||
t.Fatalf("expected tool_use instruction in prompt")
|
||||
}
|
||||
if containsStr(prompt, "tool_calls") {
|
||||
t.Fatalf("expected prompt to avoid tool_calls JSON instruction")
|
||||
if containsStr(prompt, "TOOL_CALL_HISTORY") || containsStr(prompt, "TOOL_RESULT_HISTORY") {
|
||||
t.Fatalf("expected legacy tool history markers removed from prompt")
|
||||
}
|
||||
if !containsStr(prompt, "Do not print tool-call JSON in text") {
|
||||
t.Fatalf("expected prompt to keep no tool-call-json instruction")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,28 +13,58 @@ func normalizeClaudeMessages(messages []any) []any {
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
copied := cloneMap(msg)
|
||||
role := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", msg["role"])))
|
||||
switch content := msg["content"].(type) {
|
||||
case []any:
|
||||
parts := make([]string, 0, len(content))
|
||||
textParts := make([]string, 0, len(content))
|
||||
flushText := func() {
|
||||
if len(textParts) == 0 {
|
||||
return
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"role": role,
|
||||
"content": strings.Join(textParts, "\n"),
|
||||
})
|
||||
textParts = textParts[:0]
|
||||
}
|
||||
for _, block := range content {
|
||||
b, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
typeStr, _ := b["type"].(string)
|
||||
if typeStr == "text" {
|
||||
typeStr := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", b["type"])))
|
||||
switch typeStr {
|
||||
case "text":
|
||||
if t, ok := b["text"].(string); ok {
|
||||
parts = append(parts, t)
|
||||
textParts = append(textParts, t)
|
||||
}
|
||||
case "tool_use":
|
||||
if role == "assistant" {
|
||||
flushText()
|
||||
if toolMsg := normalizeClaudeToolUseToAssistant(b); toolMsg != nil {
|
||||
out = append(out, toolMsg)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if raw := strings.TrimSpace(formatClaudeUnknownBlockForPrompt(b)); raw != "" {
|
||||
textParts = append(textParts, raw)
|
||||
}
|
||||
case "tool_result":
|
||||
flushText()
|
||||
if toolMsg := normalizeClaudeToolResultToToolMessage(b); toolMsg != nil {
|
||||
out = append(out, toolMsg)
|
||||
}
|
||||
default:
|
||||
if raw := strings.TrimSpace(formatClaudeUnknownBlockForPrompt(b)); raw != "" {
|
||||
textParts = append(textParts, raw)
|
||||
}
|
||||
}
|
||||
if typeStr == "tool_result" {
|
||||
parts = append(parts, formatClaudeToolResultForPrompt(b))
|
||||
}
|
||||
}
|
||||
copied["content"] = strings.Join(parts, "\n")
|
||||
flushText()
|
||||
default:
|
||||
copied := cloneMap(msg)
|
||||
out = append(out, copied)
|
||||
}
|
||||
out = append(out, copied)
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -52,8 +82,8 @@ func buildClaudeToolPrompt(tools []any) string {
|
||||
}
|
||||
parts = append(parts,
|
||||
"When you need a tool, respond with Claude-native tool use (tool_use) using the provided tool schema. Do not print tool-call JSON in text.",
|
||||
"History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.",
|
||||
"After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.",
|
||||
"Tool roundtrip context is included directly in the conversation messages (assistant tool_use/tool_calls and tool results).",
|
||||
"After receiving a valid tool result, continue with final answer instead of repeating the same call unless required fields are still missing.",
|
||||
)
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
@@ -62,22 +92,111 @@ func formatClaudeToolResultForPrompt(block map[string]any) string {
|
||||
if block == nil {
|
||||
return ""
|
||||
}
|
||||
payload := map[string]any{
|
||||
"type": "tool_result",
|
||||
"content": block["content"],
|
||||
}
|
||||
if toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"])); toolCallID != "" {
|
||||
payload["tool_call_id"] = toolCallID
|
||||
} else if toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"])); toolCallID != "" {
|
||||
payload["tool_call_id"] = toolCallID
|
||||
}
|
||||
if name := strings.TrimSpace(fmt.Sprintf("%v", block["name"])); name != "" {
|
||||
payload["name"] = name
|
||||
}
|
||||
b, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", payload))
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func normalizeClaudeToolUseToAssistant(block map[string]any) map[string]any {
|
||||
if block == nil {
|
||||
return nil
|
||||
}
|
||||
name := strings.TrimSpace(fmt.Sprintf("%v", block["name"]))
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
callID := strings.TrimSpace(fmt.Sprintf("%v", block["id"]))
|
||||
if callID == "" {
|
||||
callID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
|
||||
}
|
||||
if callID == "" {
|
||||
callID = "call_claude"
|
||||
}
|
||||
arguments := block["input"]
|
||||
if arguments == nil {
|
||||
arguments = map[string]any{}
|
||||
}
|
||||
argsJSON, err := json.Marshal(arguments)
|
||||
if err != nil || len(argsJSON) == 0 {
|
||||
argsJSON = []byte("{}")
|
||||
}
|
||||
toolCalls := []any{
|
||||
map[string]any{
|
||||
"id": callID,
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": name,
|
||||
"arguments": string(argsJSON),
|
||||
},
|
||||
},
|
||||
}
|
||||
return map[string]any{
|
||||
"role": "assistant",
|
||||
"content": marshalCompactJSON(toolCalls),
|
||||
"tool_calls": toolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeClaudeToolResultToToolMessage(block map[string]any) map[string]any {
|
||||
if block == nil {
|
||||
return nil
|
||||
}
|
||||
toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
|
||||
if toolCallID == "" {
|
||||
toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"]))
|
||||
}
|
||||
if toolCallID == "" {
|
||||
toolCallID = "unknown"
|
||||
toolCallID = "call_claude"
|
||||
}
|
||||
name := strings.TrimSpace(fmt.Sprintf("%v", block["name"]))
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
out := map[string]any{
|
||||
"role": "tool",
|
||||
"tool_call_id": toolCallID,
|
||||
"content": normalizeClaudeToolResultContent(block["content"]),
|
||||
}
|
||||
content := strings.TrimSpace(fmt.Sprintf("%v", block["content"]))
|
||||
if content == "" {
|
||||
content = "null"
|
||||
if name := strings.TrimSpace(fmt.Sprintf("%v", block["name"])); name != "" {
|
||||
out["name"] = name
|
||||
}
|
||||
return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content)
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeClaudeToolResultContent(content any) any {
|
||||
if text, ok := content.(string); ok {
|
||||
return text
|
||||
}
|
||||
payload := map[string]any{
|
||||
"type": "tool_result",
|
||||
"content": content,
|
||||
}
|
||||
b, err := json.Marshal(sanitizeClaudeBlockForPrompt(payload))
|
||||
if err != nil {
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", content))
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func formatClaudeBlockRaw(block map[string]any) string {
|
||||
if block == nil {
|
||||
return ""
|
||||
}
|
||||
b, err := json.Marshal(block)
|
||||
if err != nil {
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", block))
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func hasSystemMessage(messages []any) bool {
|
||||
|
||||
105
internal/adapter/claude/handler_utils_sanitize.go
Normal file
105
internal/adapter/claude/handler_utils_sanitize.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
maxClaudeRawPromptChars = 1024
|
||||
omittedBinaryMarker = "[omitted_binary_payload]"
|
||||
)
|
||||
|
||||
func formatClaudeUnknownBlockForPrompt(block map[string]any) string {
|
||||
if block == nil {
|
||||
return ""
|
||||
}
|
||||
safe := sanitizeClaudeBlockForPrompt(block)
|
||||
raw := strings.TrimSpace(formatClaudeBlockRaw(safe))
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
if len(raw) > maxClaudeRawPromptChars {
|
||||
return raw[:maxClaudeRawPromptChars] + "...(truncated)"
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
func sanitizeClaudeBlockForPrompt(block map[string]any) map[string]any {
|
||||
out := cloneMap(block)
|
||||
for k, v := range out {
|
||||
if looksLikeBinaryFieldName(k) {
|
||||
out[k] = omittedBinaryMarker
|
||||
continue
|
||||
}
|
||||
switch inner := v.(type) {
|
||||
case map[string]any:
|
||||
out[k] = sanitizeClaudeBlockForPrompt(inner)
|
||||
case []any:
|
||||
out[k] = sanitizeClaudeArrayForPrompt(inner)
|
||||
case string:
|
||||
out[k] = sanitizeClaudeStringForPrompt(k, inner)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func sanitizeClaudeArrayForPrompt(items []any) []any {
|
||||
out := make([]any, 0, len(items))
|
||||
for _, item := range items {
|
||||
switch v := item.(type) {
|
||||
case map[string]any:
|
||||
out = append(out, sanitizeClaudeBlockForPrompt(v))
|
||||
case []any:
|
||||
out = append(out, sanitizeClaudeArrayForPrompt(v))
|
||||
default:
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func sanitizeClaudeStringForPrompt(key, value string) string {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if looksLikeBinaryFieldName(key) || looksLikeBase64Payload(trimmed) {
|
||||
return omittedBinaryMarker
|
||||
}
|
||||
if len(trimmed) > maxClaudeRawPromptChars {
|
||||
return trimmed[:maxClaudeRawPromptChars] + "...(truncated)"
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func looksLikeBinaryFieldName(name string) bool {
|
||||
n := strings.ToLower(strings.TrimSpace(name))
|
||||
return n == "data" || n == "bytes" || n == "base64" || n == "inline_data" || n == "inlinedata"
|
||||
}
|
||||
|
||||
func looksLikeBase64Payload(v string) bool {
|
||||
if len(v) < 512 {
|
||||
return false
|
||||
}
|
||||
compact := strings.TrimRight(v, "=")
|
||||
if compact == "" {
|
||||
return false
|
||||
}
|
||||
for _, ch := range compact {
|
||||
if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '+' || ch == '/' || ch == '-' || ch == '_' {
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func marshalCompactJSON(v any) string {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
@@ -2,6 +2,8 @@ package gemini
|
||||
|
||||
import "strings"
|
||||
|
||||
const maxGeminiRawPromptChars = 1024
|
||||
|
||||
func geminiMessagesFromRequest(req map[string]any) []any {
|
||||
out := make([]any, 0, 8)
|
||||
if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" {
|
||||
@@ -107,6 +109,11 @@ func geminiMessagesFromRequest(req map[string]any) []any {
|
||||
msg["name"] = name
|
||||
}
|
||||
out = append(out, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
if raw := strings.TrimSpace(formatGeminiUnknownPartForPrompt(part)); raw != "" && raw != "null" {
|
||||
textParts = append(textParts, raw)
|
||||
}
|
||||
}
|
||||
flushText()
|
||||
@@ -151,3 +158,87 @@ func mapGeminiRole(v any) string {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func formatGeminiUnknownPartForPrompt(part map[string]any) string {
|
||||
safe := sanitizeGeminiPartForPrompt(part)
|
||||
raw := strings.TrimSpace(stringifyJSON(safe))
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
if len(raw) > maxGeminiRawPromptChars {
|
||||
return raw[:maxGeminiRawPromptChars] + "...(truncated)"
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
func sanitizeGeminiPartForPrompt(part map[string]any) map[string]any {
|
||||
out := make(map[string]any, len(part))
|
||||
for k, v := range part {
|
||||
if looksLikeGeminiBinaryField(k) {
|
||||
out[k] = "[omitted_binary_payload]"
|
||||
continue
|
||||
}
|
||||
switch x := v.(type) {
|
||||
case map[string]any:
|
||||
out[k] = sanitizeGeminiPartForPrompt(x)
|
||||
case []any:
|
||||
out[k] = sanitizeGeminiArrayForPrompt(x)
|
||||
case string:
|
||||
out[k] = sanitizeGeminiStringForPrompt(k, x)
|
||||
default:
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func sanitizeGeminiArrayForPrompt(items []any) []any {
|
||||
out := make([]any, 0, len(items))
|
||||
for _, item := range items {
|
||||
switch x := item.(type) {
|
||||
case map[string]any:
|
||||
out = append(out, sanitizeGeminiPartForPrompt(x))
|
||||
case []any:
|
||||
out = append(out, sanitizeGeminiArrayForPrompt(x))
|
||||
default:
|
||||
out = append(out, x)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func sanitizeGeminiStringForPrompt(key, value string) string {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if looksLikeGeminiBinaryField(key) || looksLikeGeminiBase64(trimmed) {
|
||||
return "[omitted_binary_payload]"
|
||||
}
|
||||
if len(trimmed) > maxGeminiRawPromptChars {
|
||||
return trimmed[:maxGeminiRawPromptChars] + "...(truncated)"
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func looksLikeGeminiBinaryField(name string) bool {
|
||||
n := strings.ToLower(strings.TrimSpace(name))
|
||||
return n == "data" || n == "bytes" || n == "inlinedata" || n == "inline_data" || n == "base64"
|
||||
}
|
||||
|
||||
func looksLikeGeminiBase64(v string) bool {
|
||||
if len(v) < 512 {
|
||||
return false
|
||||
}
|
||||
compact := strings.TrimRight(v, "=")
|
||||
if compact == "" {
|
||||
return false
|
||||
}
|
||||
for _, ch := range compact {
|
||||
if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '+' || ch == '/' || ch == '-' || ch == '_' {
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
84
internal/adapter/gemini/convert_messages_test.go
Normal file
84
internal/adapter/gemini/convert_messages_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGeminiMessagesFromRequestPreservesFunctionRoundtrip(t *testing.T) {
|
||||
req := map[string]any{
|
||||
"contents": []any{
|
||||
map[string]any{
|
||||
"role": "model",
|
||||
"parts": []any{
|
||||
map[string]any{
|
||||
"functionCall": map[string]any{
|
||||
"id": "call_g1",
|
||||
"name": "search_web",
|
||||
"args": map[string]any{"query": "ai"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"parts": []any{
|
||||
map[string]any{
|
||||
"functionResponse": map[string]any{
|
||||
"id": "call_g1",
|
||||
"name": "search_web",
|
||||
"response": "ok",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := geminiMessagesFromRequest(req)
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected two normalized messages, got %#v", got)
|
||||
}
|
||||
assistant, _ := got[0].(map[string]any)
|
||||
if assistant["role"] != "assistant" {
|
||||
t.Fatalf("expected assistant first, got %#v", assistant)
|
||||
}
|
||||
tc, _ := assistant["tool_calls"].([]any)
|
||||
if len(tc) != 1 {
|
||||
t.Fatalf("expected one tool call, got %#v", assistant["tool_calls"])
|
||||
}
|
||||
toolMsg, _ := got[1].(map[string]any)
|
||||
if toolMsg["role"] != "tool" || toolMsg["tool_call_id"] != "call_g1" {
|
||||
t.Fatalf("expected tool message with call id, got %#v", toolMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiMessagesFromRequestPreservesUnknownPartAsRawJSONText(t *testing.T) {
|
||||
req := map[string]any{
|
||||
"contents": []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"parts": []any{
|
||||
map[string]any{"text": "hello"},
|
||||
map[string]any{"inlineData": map[string]any{"mimeType": "image/png", "data": strings.Repeat("A", 2048)}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := geminiMessagesFromRequest(req)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %#v", got)
|
||||
}
|
||||
msg, _ := got[0].(map[string]any)
|
||||
content, _ := msg["content"].(string)
|
||||
if !strings.Contains(content, "hello") || !strings.Contains(content, "inlineData") {
|
||||
t.Fatalf("expected unknown part preserved as raw json text, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "[omitted_binary_payload]") {
|
||||
t.Fatalf("expected inlineData payload to be redacted, got %q", content)
|
||||
}
|
||||
if strings.Contains(content, strings.Repeat("A", 100)) {
|
||||
t.Fatalf("expected raw base64 payload not to be embedded, got %q", content)
|
||||
}
|
||||
}
|
||||
@@ -97,12 +97,12 @@ func (s *chatStreamRuntime) sendDone() {
|
||||
|
||||
func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := s.text.String()
|
||||
detected := util.ParseStandaloneToolCalls(finalText, s.toolNames)
|
||||
if len(detected) > 0 && !s.toolCallsDoneEmitted {
|
||||
finalText := sanitizeLeakedToolHistory(s.text.String())
|
||||
detected := util.ParseStandaloneToolCallsDetailed(finalText, s.toolNames)
|
||||
if len(detected.Calls) > 0 && !s.toolCallsDoneEmitted {
|
||||
finishReason = "tool_calls"
|
||||
delta := map[string]any{
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected, s.streamToolCallIDs),
|
||||
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected.Calls, s.streamToolCallIDs),
|
||||
}
|
||||
if !s.firstChunkSent {
|
||||
delta["role"] = "assistant"
|
||||
@@ -141,8 +141,12 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
if evt.Content == "" {
|
||||
continue
|
||||
}
|
||||
cleaned := sanitizeLeakedToolHistory(evt.Content)
|
||||
if cleaned == "" {
|
||||
continue
|
||||
}
|
||||
delta := map[string]any{
|
||||
"content": evt.Content,
|
||||
"content": cleaned,
|
||||
}
|
||||
if !s.firstChunkSent {
|
||||
delta["role"] = "assistant"
|
||||
@@ -158,7 +162,7 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
}
|
||||
}
|
||||
|
||||
if len(detected) > 0 || s.toolCallsEmitted {
|
||||
if len(detected.Calls) > 0 || s.toolCallsEmitted {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
s.sendChunk(openaifmt.BuildChatStreamChunk(
|
||||
@@ -246,8 +250,12 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
|
||||
continue
|
||||
}
|
||||
if evt.Content != "" {
|
||||
cleaned := sanitizeLeakedToolHistory(evt.Content)
|
||||
if cleaned == "" {
|
||||
continue
|
||||
}
|
||||
contentDelta := map[string]any{
|
||||
"content": evt.Content,
|
||||
"content": cleaned,
|
||||
}
|
||||
if !s.firstChunkSent {
|
||||
contentDelta["role"] = "assistant"
|
||||
|
||||
@@ -19,6 +19,7 @@ type DeepSeekCaller interface {
|
||||
CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
||||
GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
||||
CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
|
||||
DeleteAllSessionsForToken(ctx context.Context, token string) error
|
||||
}
|
||||
|
||||
type ConfigReader interface {
|
||||
@@ -28,6 +29,7 @@ type ConfigReader interface {
|
||||
ToolcallEarlyEmitConfidence() string
|
||||
ResponsesStoreTTLSeconds() int
|
||||
EmbeddingsProvider() string
|
||||
AutoDeleteSessions() bool
|
||||
}
|
||||
|
||||
var _ AuthResolver = (*auth.Resolver)(nil)
|
||||
|
||||
@@ -19,6 +19,7 @@ func (m mockOpenAIConfig) ToolcallMode() string { return m.toolMo
|
||||
func (m mockOpenAIConfig) ToolcallEarlyEmitConfidence() string { return m.earlyEmit }
|
||||
func (m mockOpenAIConfig) ResponsesStoreTTLSeconds() int { return m.responsesTTL }
|
||||
func (m mockOpenAIConfig) EmbeddingsProvider() string { return m.embedProv }
|
||||
func (m mockOpenAIConfig) AutoDeleteSessions() bool { return false }
|
||||
|
||||
func TestNormalizeOpenAIChatRequestWithConfigInterface(t *testing.T) {
|
||||
cfg := mockOpenAIConfig{
|
||||
|
||||
@@ -35,7 +35,25 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
writeOpenAIError(w, status, detail)
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
defer func() {
|
||||
// 自动删除会话(同步)
|
||||
// 必须在 Release 之前同步删除,否则:
|
||||
// 1. 异步删除时账号已被 Release
|
||||
// 2. 新请求可能获取到同一账号并开始使用
|
||||
// 3. 异步删除仍在进行,会截断新请求正在使用的会话
|
||||
if h.Store.AutoDeleteSessions() && a.DeepSeekToken != "" {
|
||||
deleteCtx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
err := h.DS.DeleteAllSessionsForToken(deleteCtx, a.DeepSeekToken)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[auto_delete_sessions] failed", "account", a.AccountID, "error", err)
|
||||
} else {
|
||||
config.Logger.Debug("[auto_delete_sessions] success", "account", a.AccountID)
|
||||
}
|
||||
}
|
||||
h.Auth.Release(a)
|
||||
}()
|
||||
|
||||
r = r.WithContext(auth.WithAuth(r.Context(), a))
|
||||
|
||||
var req map[string]any
|
||||
@@ -87,7 +105,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
|
||||
finalThinking := result.Thinking
|
||||
finalText := result.Text
|
||||
finalText := sanitizeLeakedToolHistory(result.Text)
|
||||
respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
|
||||
writeJSON(w, http.StatusOK, respBody)
|
||||
}
|
||||
@@ -110,8 +128,8 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
}
|
||||
|
||||
created := time.Now().Unix()
|
||||
bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
|
||||
emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
|
||||
bufferToolContent := len(toolNames) > 0
|
||||
emitEarlyToolDeltas := h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence()
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
|
||||
@@ -53,13 +53,13 @@ func injectToolPrompt(messages []map[string]any, tools []any, policy util.ToolCh
|
||||
if len(toolSchemas) == 0 {
|
||||
return messages, names
|
||||
}
|
||||
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block."
|
||||
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON object format:\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\n【EXAMPLE】\nUser: Please check the weather in Beijing and Shanghai, and update my todo list.\nAssistant:\n{\"tool_calls\": [\n {\"name\": \"get_weather\", \"input\": {\"city\": \"Beijing\"}},\n {\"name\": \"get_weather\", \"input\": {\"city\": \"Shanghai\"}},\n {\"name\": \"update_todo\", \"input\": {\"todos\": [{\"content\": \"Buy milk\"}, {\"content\": \"Write report\"}]}}\n]}\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON object above. Do NOT include any extra text.\n2) Do NOT wrap tool-call JSON in markdown/code fences (for example, do not use triple backticks).\n3) After receiving a tool result, you MUST use it to produce the final answer.\n4) Only call another tool when the previous result is missing required data or returned an error.\n5) JSON SYNTAX STRICTLY REQUIRED: All property names MUST be enclosed in double quotes (e.g., \"name\", not name).\n6) ARRAY FORMAT: If providing a list of items, you MUST enclose them in square brackets `[]` (e.g., \"todos\": [{\"item\": \"a\"}, {\"item\": \"b\"}]). DO NOT output comma-separated objects without brackets."
|
||||
if policy.Mode == util.ToolChoiceRequired {
|
||||
toolPrompt += "\n5) For this response, you MUST call at least one tool from the allowed list."
|
||||
toolPrompt += "\n7) For this response, you MUST call at least one tool from the allowed list."
|
||||
}
|
||||
if policy.Mode == util.ToolChoiceForced && strings.TrimSpace(policy.ForcedName) != "" {
|
||||
toolPrompt += "\n5) For this response, you MUST call exactly this tool name: " + strings.TrimSpace(policy.ForcedName)
|
||||
toolPrompt += "\n6) Do not call any other tool."
|
||||
toolPrompt += "\n7) For this response, you MUST call exactly this tool name: " + strings.TrimSpace(policy.ForcedName)
|
||||
toolPrompt += "\n8) Do not call any other tool."
|
||||
}
|
||||
|
||||
for i := range messages {
|
||||
|
||||
@@ -2,12 +2,6 @@ package openai
|
||||
|
||||
import "strings"
|
||||
|
||||
func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) {
|
||||
for k, v := range collectOpenAIChatPassThrough(req) {
|
||||
payload[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) toolcallFeatureMatchEnabled() bool {
|
||||
if h == nil || h.Store == nil {
|
||||
return true
|
||||
|
||||
@@ -211,7 +211,7 @@ func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamEmbeddedToolCallExampleRemainsText(t *testing.T) {
|
||||
func TestHandleNonStreamEmbeddedToolCallExamplePromotesToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"下面是示例:"}`,
|
||||
@@ -229,20 +229,21 @@ func TestHandleNonStreamEmbeddedToolCallExampleRemainsText(t *testing.T) {
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
|
||||
if choice["finish_reason"] != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
if _, ok := msg["tool_calls"]; ok {
|
||||
t.Fatalf("did not expect tool_calls field for embedded example: %#v", msg["tool_calls"])
|
||||
toolCalls, _ := msg["tool_calls"].([]any)
|
||||
if len(toolCalls) != 1 {
|
||||
t.Fatalf("expected one tool_call field for embedded example: %#v", msg["tool_calls"])
|
||||
}
|
||||
content, _ := msg["content"].(string)
|
||||
if !strings.Contains(content, "下面是示例:") || !strings.Contains(content, "请勿执行。") || !strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected embedded example to remain plain text, got %#v", content)
|
||||
if strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected raw tool_calls json stripped from content, got %#v", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamFencedToolCallExampleNotIntercepted(t *testing.T) {
|
||||
func TestHandleNonStreamFencedToolCallExampleDoesNotPromoteToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
"data: {\"p\":\"response/content\",\"v\":\"```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"search\\\",\\\"input\\\":{\\\"q\\\":\\\"go\\\"}}]}\\n```\"}",
|
||||
@@ -258,19 +259,25 @@ func TestHandleNonStreamFencedToolCallExampleNotIntercepted(t *testing.T) {
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
|
||||
if choice["finish_reason"] == "tool_calls" {
|
||||
t.Fatalf("expected fenced example to remain content-only, got finish_reason=%#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
if _, ok := msg["tool_calls"]; ok {
|
||||
t.Fatalf("did not expect tool_calls field for fenced example: %#v", msg["tool_calls"])
|
||||
toolCalls, _ := msg["tool_calls"].([]any)
|
||||
if len(toolCalls) != 0 {
|
||||
t.Fatalf("expected no tool_call field for fenced example: %#v", msg["tool_calls"])
|
||||
}
|
||||
content, _ := msg["content"].(string)
|
||||
if !strings.Contains(content, "```json") || !strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected fenced tool example to pass through as text, got %q", content)
|
||||
if !strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected fenced example content preserved, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
// Backward-compatible alias for historical test name used in CI logs.
|
||||
func TestHandleNonStreamFencedToolCallExamplePromotesToolCall(t *testing.T) {
|
||||
TestHandleNonStreamFencedToolCallExampleDoesNotPromoteToolCall(t)
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
@@ -615,7 +622,7 @@ func TestHandleStreamToolCallWithSameChunkTrailingTextRemainsText(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamFencedToolCallSnippetRemainsText(t *testing.T) {
|
||||
func TestHandleStreamFencedToolCallSnippetPromotesToolCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "下面是调用示例:\n```json\n"),
|
||||
@@ -631,8 +638,8 @@ func TestHandleStreamFencedToolCallSnippetRemainsText(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta for fenced snippet, body=%s", rec.Body.String())
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta for fenced snippet, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -646,11 +653,53 @@ func TestHandleStreamFencedToolCallSnippetRemainsText(t *testing.T) {
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "```json") || !strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected fenced tool snippet in content, got=%q", got)
|
||||
if strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected raw fenced tool_calls snippet stripped from content, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
if strings.Contains(strings.ToLower(got), "```json") || strings.Contains(got, "\n```\n") {
|
||||
t.Fatalf("expected consumed fenced tool payload to not leave empty code fence, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamStandaloneToolCallAfterClosedFenceKeepsFence(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "先给一个代码示例:\n```text\nhello\n```\n"),
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"),
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7g", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta for standalone payload, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "```") {
|
||||
t.Fatalf("expected closed fence before standalone tool json to be preserved, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,11 +5,11 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/prompt"
|
||||
)
|
||||
|
||||
func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any {
|
||||
_ = traceID
|
||||
out := make([]map[string]any, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
msg, ok := item.(map[string]any)
|
||||
@@ -19,20 +19,19 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an
|
||||
role := strings.ToLower(strings.TrimSpace(asString(msg["role"])))
|
||||
switch role {
|
||||
case "assistant":
|
||||
content := normalizeOpenAIContentForPrompt(msg["content"])
|
||||
toolCalls := formatAssistantToolCallsForPrompt(msg, traceID)
|
||||
combined := joinNonEmpty(content, toolCalls)
|
||||
if combined == "" {
|
||||
content := buildAssistantContentForPrompt(msg)
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"role": "assistant",
|
||||
"content": combined,
|
||||
"content": content,
|
||||
})
|
||||
case "tool", "function":
|
||||
content := buildToolContentForPrompt(msg)
|
||||
out = append(out, map[string]any{
|
||||
"role": "user",
|
||||
"content": formatToolResultForPrompt(msg),
|
||||
"role": "tool",
|
||||
"content": content,
|
||||
})
|
||||
case "user", "system", "developer":
|
||||
out = append(out, map[string]any{
|
||||
@@ -56,95 +55,54 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an
|
||||
return out
|
||||
}
|
||||
|
||||
func formatAssistantToolCallsForPrompt(msg map[string]any, traceID string) string {
|
||||
entries := make([]string, 0)
|
||||
if calls, ok := msg["tool_calls"].([]any); ok {
|
||||
for i, item := range calls {
|
||||
call, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
id := strings.TrimSpace(asString(call["id"]))
|
||||
if id == "" {
|
||||
id = fmt.Sprintf("call_%d", i+1)
|
||||
}
|
||||
name := strings.TrimSpace(asString(call["name"]))
|
||||
args := ""
|
||||
|
||||
if fn, ok := call["function"].(map[string]any); ok {
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(asString(fn["name"]))
|
||||
}
|
||||
args = normalizeOpenAIArgumentsForPrompt(fn["arguments"])
|
||||
}
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if args == "" {
|
||||
args = normalizeOpenAIArgumentsForPrompt(call["arguments"])
|
||||
}
|
||||
if args == "" {
|
||||
args = normalizeOpenAIArgumentsForPrompt(call["input"])
|
||||
}
|
||||
if args == "" {
|
||||
args = "{}"
|
||||
}
|
||||
maybeWarnSuspiciousToolHistory(traceID, id, name, args)
|
||||
entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: %s\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", id, name, args))
|
||||
}
|
||||
func buildAssistantContentForPrompt(msg map[string]any) string {
|
||||
content := normalizeOpenAIContentForPrompt(msg["content"])
|
||||
toolCalls := normalizeAssistantToolCallsForPrompt(msg["tool_calls"])
|
||||
if toolCalls == "" {
|
||||
return strings.TrimSpace(content)
|
||||
}
|
||||
|
||||
if legacy, ok := msg["function_call"].(map[string]any); ok {
|
||||
name := strings.TrimSpace(asString(legacy["name"]))
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
}
|
||||
args := normalizeOpenAIArgumentsForPrompt(legacy["arguments"])
|
||||
if args == "" {
|
||||
args = "{}"
|
||||
}
|
||||
maybeWarnSuspiciousToolHistory(traceID, "call_legacy", name, args)
|
||||
entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: call_legacy\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", name, args))
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
return strings.Join(entries, "\n\n")
|
||||
return strings.TrimSpace(content + "\n" + toolCalls)
|
||||
}
|
||||
|
||||
func formatToolResultForPrompt(msg map[string]any) string {
|
||||
toolCallID := strings.TrimSpace(asString(msg["tool_call_id"]))
|
||||
if toolCallID == "" {
|
||||
toolCallID = strings.TrimSpace(asString(msg["id"]))
|
||||
func normalizeAssistantToolCallsForPrompt(v any) string {
|
||||
calls, ok := v.([]any)
|
||||
if !ok || len(calls) == 0 {
|
||||
return ""
|
||||
}
|
||||
if toolCallID == "" {
|
||||
toolCallID = "unknown"
|
||||
b, err := json.Marshal(calls)
|
||||
if err != nil {
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", calls))
|
||||
}
|
||||
return strings.TrimSpace(string(b))
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(asString(msg["name"]))
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
func buildToolContentForPrompt(msg map[string]any) string {
|
||||
payload := map[string]any{
|
||||
"content": msg["content"],
|
||||
}
|
||||
|
||||
content := normalizeOpenAIContentForPrompt(msg["content"])
|
||||
if content == "" {
|
||||
content = "null"
|
||||
if id := strings.TrimSpace(asString(msg["tool_call_id"])); id != "" {
|
||||
payload["tool_call_id"] = id
|
||||
}
|
||||
|
||||
return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content)
|
||||
if id := strings.TrimSpace(asString(msg["id"])); id != "" {
|
||||
payload["id"] = id
|
||||
}
|
||||
if name := strings.TrimSpace(asString(msg["name"])); name != "" {
|
||||
payload["name"] = name
|
||||
}
|
||||
content := normalizeOpenAIContentForPrompt(payload)
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return `{"content":"null"}`
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
func normalizeOpenAIContentForPrompt(v any) string {
|
||||
return prompt.NormalizeContent(v)
|
||||
}
|
||||
|
||||
func normalizeOpenAIArgumentsForPrompt(v any) string {
|
||||
switch x := v.(type) {
|
||||
case string:
|
||||
return normalizeToolArgumentString(x)
|
||||
default:
|
||||
return marshalToPromptString(v)
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeToolArgumentString(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
@@ -157,14 +115,6 @@ func normalizeToolArgumentString(raw string) string {
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func marshalToPromptString(v any) string {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func normalizeOpenAIRoleForPrompt(role string) string {
|
||||
role = strings.ToLower(strings.TrimSpace(role))
|
||||
if role == "developer" {
|
||||
@@ -180,34 +130,6 @@ func asString(v any) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func joinNonEmpty(parts ...string) string {
|
||||
nonEmpty := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
if strings.TrimSpace(p) == "" {
|
||||
continue
|
||||
}
|
||||
nonEmpty = append(nonEmpty, p)
|
||||
}
|
||||
return strings.Join(nonEmpty, "\n\n")
|
||||
}
|
||||
|
||||
func maybeWarnSuspiciousToolHistory(traceID, callID, name, args string) {
|
||||
if !looksLikeConcatenatedJSON(args) {
|
||||
return
|
||||
}
|
||||
traceID = strings.TrimSpace(traceID)
|
||||
if traceID == "" {
|
||||
traceID = "unknown"
|
||||
}
|
||||
config.Logger.Warn(
|
||||
"[openai] suspicious tool call history payload detected",
|
||||
"trace_id", traceID,
|
||||
"tool_call_id", strings.TrimSpace(callID),
|
||||
"name", strings.TrimSpace(name),
|
||||
"arguments_preview", previewToolArgs(args, 160),
|
||||
)
|
||||
}
|
||||
|
||||
func looksLikeConcatenatedJSON(raw string) bool {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
@@ -224,11 +146,3 @@ func looksLikeConcatenatedJSON(raw string) bool {
|
||||
var second any
|
||||
return dec.Decode(&second) == nil
|
||||
}
|
||||
|
||||
func previewToolArgs(raw string, max int) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if max <= 0 || len(trimmed) <= max {
|
||||
return trimmed
|
||||
}
|
||||
return trimmed[:max]
|
||||
}
|
||||
|
||||
@@ -35,23 +35,19 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *tes
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 4 {
|
||||
t.Fatalf("expected 4 normalized messages, got %d", len(normalized))
|
||||
}
|
||||
assistantContent, _ := normalized[2]["content"].(string)
|
||||
if !strings.Contains(assistantContent, "[TOOL_CALL_HISTORY]") ||
|
||||
!strings.Contains(assistantContent, "tool_call_id: call_1") ||
|
||||
!strings.Contains(assistantContent, "function.name: get_weather") ||
|
||||
!strings.Contains(assistantContent, "function.arguments: {\"city\":\"beijing\"}") {
|
||||
t.Fatalf("assistant tool call not serialized correctly: %q", assistantContent)
|
||||
t.Fatalf("expected 4 normalized messages with assistant tool_call history preserved, got %d", len(normalized))
|
||||
}
|
||||
toolContent, _ := normalized[3]["content"].(string)
|
||||
if !strings.Contains(toolContent, "[TOOL_RESULT_HISTORY]") || !strings.Contains(toolContent, "name: get_weather") {
|
||||
t.Fatalf("tool result not serialized correctly: %q", toolContent)
|
||||
if !strings.Contains(toolContent, `\"temp\":18`) {
|
||||
t.Fatalf("tool result should be transparently forwarded, got %q", toolContent)
|
||||
}
|
||||
if strings.Contains(toolContent, "[TOOL_RESULT_HISTORY]") {
|
||||
t.Fatalf("tool history marker should not be injected: %q", toolContent)
|
||||
}
|
||||
|
||||
prompt := util.MessagesPrepare(normalized)
|
||||
if !strings.Contains(prompt, "tool_call_id: call_1") || !strings.Contains(prompt, "[TOOL_RESULT_HISTORY]") {
|
||||
t.Fatalf("expected prompt to include tool call + result semantics: %q", prompt)
|
||||
if strings.Contains(prompt, "[TOOL_CALL_HISTORY]") || strings.Contains(prompt, "[TOOL_RESULT_HISTORY]") {
|
||||
t.Fatalf("expected no synthetic history markers in prompt: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,8 +87,8 @@ func TestNormalizeOpenAIMessagesForPrompt_ToolArrayBlocksJoined(t *testing.T) {
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
got, _ := normalized[0]["content"].(string)
|
||||
if !strings.Contains(got, "line-1\nline-2") {
|
||||
t.Fatalf("expected joined text blocks, got %q", got)
|
||||
if !strings.Contains(got, `"line-1"`) || !strings.Contains(got, `"line-2"`) || !strings.Contains(got, `"name":"read_file"`) {
|
||||
t.Fatalf("expected tool envelope to preserve content blocks and metadata, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,15 +108,42 @@ func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) {
|
||||
if len(normalized) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %d", len(normalized))
|
||||
}
|
||||
if normalized[0]["role"] != "user" {
|
||||
t.Fatalf("expected function role mapped to user, got %#v", normalized[0]["role"])
|
||||
if normalized[0]["role"] != "tool" {
|
||||
t.Fatalf("expected function role normalized as tool, got %#v", normalized[0]["role"])
|
||||
}
|
||||
got, _ := normalized[0]["content"].(string)
|
||||
if !strings.Contains(got, "name: legacy_tool") || !strings.Contains(got, `"ok":true`) {
|
||||
if !strings.Contains(got, `"name":"legacy_tool"`) || !strings.Contains(got, `"ok":true`) {
|
||||
t.Fatalf("unexpected normalized function-role content: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIMessagesForPrompt_EmptyToolContentPreservedAsNull(t *testing.T) {
|
||||
raw := []any{
|
||||
map[string]any{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_5",
|
||||
"name": "noop_tool",
|
||||
"content": "",
|
||||
},
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": "done",
|
||||
},
|
||||
}
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 2 {
|
||||
t.Fatalf("expected tool completion turn to be preserved, got %#v", normalized)
|
||||
}
|
||||
if normalized[0]["role"] != "tool" {
|
||||
t.Fatalf("expected tool role preserved, got %#v", normalized[0]["role"])
|
||||
}
|
||||
got, _ := normalized[0]["content"].(string)
|
||||
if !strings.Contains(got, `"content":""`) || !strings.Contains(got, `"name":"noop_tool"`) || !strings.Contains(got, `"tool_call_id":"call_5"`) {
|
||||
t.Fatalf("expected tool metadata preserved in content envelope, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSeparated(t *testing.T) {
|
||||
raw := []any{
|
||||
map[string]any{
|
||||
@@ -148,23 +171,11 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSepara
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 1 {
|
||||
t.Fatalf("expected one normalized assistant message, got %d", len(normalized))
|
||||
t.Fatalf("expected assistant tool_call-only message to be preserved, got %#v", normalized)
|
||||
}
|
||||
content, _ := normalized[0]["content"].(string)
|
||||
if strings.Count(content, "[TOOL_CALL_HISTORY]") != 2 {
|
||||
t.Fatalf("expected two TOOL_CALL_HISTORY blocks, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "tool_call_id: call_search") || !strings.Contains(content, "function.name: search_web") {
|
||||
t.Fatalf("missing first tool call block, got %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "tool_call_id: call_eval") || !strings.Contains(content, "function.name: eval_javascript") {
|
||||
t.Fatalf("missing second tool call block, got %q", content)
|
||||
}
|
||||
if strings.Contains(content, "search_webeval_javascript") {
|
||||
t.Fatalf("unexpected merged function name detected: %q", content)
|
||||
}
|
||||
if strings.Contains(content, `}{"`) {
|
||||
t.Fatalf("unexpected concatenated function arguments detected: %q", content)
|
||||
got, _ := normalized[0]["content"].(string)
|
||||
if !strings.Contains(got, `"name":"search_web"`) || !strings.Contains(got, `"name":"eval_javascript"`) {
|
||||
t.Fatalf("expected tool_calls payload preserved in assistant content, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -186,15 +197,14 @@ func TestNormalizeOpenAIMessagesForPrompt_PreservesConcatenatedToolArguments(t *
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %d", len(normalized))
|
||||
t.Fatalf("expected assistant tool_call-only content to be preserved, got %#v", normalized)
|
||||
}
|
||||
content, _ := normalized[0]["content"].(string)
|
||||
if !strings.Contains(content, `function.arguments: {}{"query":"测试工具调用"}`) {
|
||||
t.Fatalf("expected original concatenated arguments in tool history, got %q", content)
|
||||
got, _ := normalized[0]["content"].(string)
|
||||
if !strings.Contains(got, `{}{\"query\":\"测试工具调用\"}`) {
|
||||
t.Fatalf("expected concatenated arguments preserved verbatim, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsMissingNameAreDropped(t *testing.T) {
|
||||
raw := []any{
|
||||
map[string]any{
|
||||
@@ -212,8 +222,12 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsMissingNameAreDroppe
|
||||
}
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 0 {
|
||||
t.Fatalf("expected nameless assistant tool_calls to be dropped, got %#v", normalized)
|
||||
if len(normalized) != 1 {
|
||||
t.Fatalf("expected assistant tool_calls history to be preserved even when name missing, got %#v", normalized)
|
||||
}
|
||||
got, _ := normalized[0]["content"].(string)
|
||||
if !strings.Contains(got, "call_missing_name") {
|
||||
t.Fatalf("expected raw tool_call payload preserved, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -236,14 +250,11 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantNilContentDoesNotInjectNullLi
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %d", len(normalized))
|
||||
t.Fatalf("expected nil-content assistant tool_call-only message to be preserved, got %#v", normalized)
|
||||
}
|
||||
content, _ := normalized[0]["content"].(string)
|
||||
if strings.Contains(content, "<|Assistant|>null") || strings.HasPrefix(strings.TrimSpace(content), "null") {
|
||||
t.Fatalf("unexpected null literal injected into assistant tool history: %q", content)
|
||||
}
|
||||
if !strings.Contains(content, "function.name: send_file_to_user") {
|
||||
t.Fatalf("expected tool history block preserved, got %q", content)
|
||||
got, _ := normalized[0]["content"].(string)
|
||||
if !strings.Contains(got, "send_file_to_user") {
|
||||
t.Fatalf("expected tool call payload preserved, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -44,11 +44,11 @@ func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *tes
|
||||
if len(toolNames) != 1 || toolNames[0] != "get_weather" {
|
||||
t.Fatalf("unexpected tool names: %#v", toolNames)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "tool_call_id: call_1") ||
|
||||
!strings.Contains(finalPrompt, "function.name: get_weather") ||
|
||||
!strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") ||
|
||||
!strings.Contains(finalPrompt, `"condition":"sunny"`) {
|
||||
t.Fatalf("handler finalPrompt missing tool roundtrip semantics: %q", finalPrompt)
|
||||
if !strings.Contains(finalPrompt, `"condition":"sunny"`) {
|
||||
t.Fatalf("handler finalPrompt should preserve tool output content: %q", finalPrompt)
|
||||
}
|
||||
if strings.Contains(finalPrompt, "[TOOL_CALL_HISTORY]") || strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") {
|
||||
t.Fatalf("handler finalPrompt should not include synthetic history markers: %q", finalPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,7 +77,10 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *
|
||||
if !strings.Contains(finalPrompt, "Only call another tool when the previous result is missing required data or returned an error.") {
|
||||
t.Fatalf("vercel prepare finalPrompt missing retry guard instruction: %q", finalPrompt)
|
||||
}
|
||||
if !strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") {
|
||||
t.Fatalf("vercel prepare finalPrompt missing history marker instruction: %q", finalPrompt)
|
||||
if !strings.Contains(finalPrompt, "Do NOT wrap tool-call JSON in markdown/code fences") {
|
||||
t.Fatalf("vercel prepare finalPrompt missing no-fence instruction: %q", finalPrompt)
|
||||
}
|
||||
if strings.Contains(finalPrompt, "```json") {
|
||||
t.Fatalf("vercel prepare finalPrompt should not require fenced json tool calls: %q", finalPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,7 +113,8 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
|
||||
return
|
||||
}
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
textParsed := util.ParseStandaloneToolCallsDetailed(result.Text, toolNames)
|
||||
sanitizedText := sanitizeLeakedToolHistory(result.Text)
|
||||
textParsed := util.ParseStandaloneToolCallsDetailed(sanitizedText, toolNames)
|
||||
logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text")
|
||||
|
||||
callCount := len(textParsed.Calls)
|
||||
@@ -122,7 +123,7 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
|
||||
return
|
||||
}
|
||||
|
||||
responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames)
|
||||
responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, result.Thinking, sanitizedText, toolNames)
|
||||
h.getResponseStore().put(owner, responseID, responseObj)
|
||||
writeJSON(w, http.StatusOK, responseObj)
|
||||
}
|
||||
@@ -145,8 +146,8 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
}
|
||||
bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
|
||||
emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
|
||||
bufferToolContent := len(toolNames) > 0
|
||||
emitEarlyToolDeltas := h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence()
|
||||
|
||||
streamRuntime := newResponsesStreamRuntime(
|
||||
w,
|
||||
|
||||
@@ -19,6 +19,27 @@ func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[str
|
||||
|
||||
role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
|
||||
if role != "" {
|
||||
if role == "assistant" {
|
||||
out := map[string]any{
|
||||
"role": "assistant",
|
||||
}
|
||||
if toolCalls, ok := m["tool_calls"].([]any); ok && len(toolCalls) > 0 {
|
||||
out["tool_calls"] = toolCalls
|
||||
}
|
||||
content := m["content"]
|
||||
if content == nil {
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
content = txt
|
||||
}
|
||||
}
|
||||
if content != nil {
|
||||
out["content"] = content
|
||||
}
|
||||
if _, hasToolCalls := out["tool_calls"]; hasToolCalls || out["content"] != nil {
|
||||
return out
|
||||
}
|
||||
return nil
|
||||
}
|
||||
content := m["content"]
|
||||
if content == nil {
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
@@ -28,10 +49,22 @@ func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[str
|
||||
if content == nil {
|
||||
return nil
|
||||
}
|
||||
return map[string]any{
|
||||
out := map[string]any{
|
||||
"role": normalizeOpenAIRoleForPrompt(role),
|
||||
"content": content,
|
||||
}
|
||||
if role == "tool" || role == "function" {
|
||||
if callID := strings.TrimSpace(asString(m["tool_call_id"])); callID != "" {
|
||||
out["tool_call_id"] = callID
|
||||
}
|
||||
if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
|
||||
out["tool_call_id"] = callID
|
||||
}
|
||||
if name := strings.TrimSpace(asString(m["name"])); name != "" {
|
||||
out["name"] = name
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
itemType := strings.ToLower(strings.TrimSpace(asString(m["type"])))
|
||||
|
||||
@@ -32,7 +32,6 @@ type responsesStreamRuntime struct {
|
||||
toolCallsDoneEmitted bool
|
||||
|
||||
sieve toolStreamSieveState
|
||||
thinkingSieve toolStreamSieveState
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
visibleText strings.Builder
|
||||
@@ -98,7 +97,7 @@ func newResponsesStreamRuntime(
|
||||
|
||||
func (s *responsesStreamRuntime) finalize() {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := s.text.String()
|
||||
finalText := sanitizeLeakedToolHistory(s.text.String())
|
||||
|
||||
if s.bufferToolContent {
|
||||
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
|
||||
@@ -169,15 +168,6 @@ func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed util.ToolCal
|
||||
logRejected(textParsed, "text")
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) hasFunctionCallDone() bool {
|
||||
for _, done := range s.functionDone {
|
||||
if done {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
@@ -204,12 +194,16 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
|
||||
continue
|
||||
}
|
||||
|
||||
s.text.WriteString(p.Text)
|
||||
if !s.bufferToolContent {
|
||||
s.emitTextDelta(p.Text)
|
||||
cleanedText := sanitizeLeakedToolHistory(p.Text)
|
||||
if cleanedText == "" {
|
||||
continue
|
||||
}
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true)
|
||||
s.text.WriteString(cleanedText)
|
||||
if !s.bufferToolContent {
|
||||
s.emitTextDelta(cleanedText)
|
||||
continue
|
||||
}
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, cleanedText, s.toolNames), true)
|
||||
}
|
||||
|
||||
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||
|
||||
@@ -297,7 +297,7 @@ func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamThinkingAndMixedToolExampleRemainMessageOnly(t *testing.T) {
|
||||
func TestHandleResponsesStreamThinkingAndMixedToolExampleEmitsFunctionCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -333,6 +333,7 @@ func TestHandleResponsesStreamThinkingAndMixedToolExampleRemainMessageOnly(t *te
|
||||
responseObj, _ := completedPayload["response"].(map[string]any)
|
||||
output, _ := responseObj["output"].([]any)
|
||||
hasMessage := false
|
||||
hasFunctionCall := false
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m == nil {
|
||||
@@ -342,12 +343,15 @@ func TestHandleResponsesStreamThinkingAndMixedToolExampleRemainMessageOnly(t *te
|
||||
hasMessage = true
|
||||
}
|
||||
if asString(m["type"]) == "function_call" {
|
||||
t.Fatalf("did not expect function_call output for mixed prose tool example, output=%#v", output)
|
||||
hasFunctionCall = true
|
||||
}
|
||||
}
|
||||
if !hasMessage {
|
||||
t.Fatalf("expected message output for mixed prose tool example, output=%#v", output)
|
||||
}
|
||||
if !hasFunctionCall {
|
||||
t.Fatalf("expected function_call output for mixed prose tool example, output=%#v", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
|
||||
@@ -671,18 +675,3 @@ func extractAllSSEEventPayloads(body, targetEvent string) []map[string]any {
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func asFloat(v any) float64 {
|
||||
switch x := v.(type) {
|
||||
case float64:
|
||||
return x
|
||||
case float32:
|
||||
return float64(x)
|
||||
case int:
|
||||
return float64(x)
|
||||
case int64:
|
||||
return float64(x)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,6 +53,10 @@ func (m streamStatusDSStub) CallCompletion(_ context.Context, _ *auth.RequestAut
|
||||
return m.resp, nil
|
||||
}
|
||||
|
||||
func (m streamStatusDSStub) DeleteAllSessionsForToken(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeOpenAISSEHTTPResponse(lines ...string) *http.Response {
|
||||
body := strings.Join(lines, "\n")
|
||||
if !strings.HasSuffix(body, "\n") {
|
||||
@@ -167,15 +171,15 @@ func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) {
|
||||
t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String())
|
||||
}
|
||||
outputText, _ := out["output_text"].(string)
|
||||
if outputText == "" {
|
||||
t.Fatalf("expected output_text preserved for mixed prose payload")
|
||||
if outputText != "" {
|
||||
t.Fatalf("expected output_text hidden for mixed prose tool payload, got %q", outputText)
|
||||
}
|
||||
output, _ := out["output"].([]any)
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected one output item, got %#v", output)
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "message" {
|
||||
t.Fatalf("expected message output item, got %#v", output)
|
||||
if first["type"] != "function_call" {
|
||||
t.Fatalf("expected function_call output item, got %#v", output)
|
||||
}
|
||||
}
|
||||
|
||||
17
internal/adapter/openai/tool_history_sanitize.go
Normal file
17
internal/adapter/openai/tool_history_sanitize.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var leakedToolHistoryPattern = regexp.MustCompile(`(?is)\[TOOL_CALL_HISTORY\][\s\S]*?\[/TOOL_CALL_HISTORY\]|\[TOOL_RESULT_HISTORY\][\s\S]*?\[/TOOL_RESULT_HISTORY\]`)
|
||||
var emptyJSONFencePattern = regexp.MustCompile("(?is)```json\\s*```")
|
||||
|
||||
func sanitizeLeakedToolHistory(text string) string {
|
||||
if text == "" {
|
||||
return text
|
||||
}
|
||||
out := leakedToolHistoryPattern.ReplaceAllString(text, "")
|
||||
out = emptyJSONFencePattern.ReplaceAllString(out, "")
|
||||
return out
|
||||
}
|
||||
106
internal/adapter/openai/tool_history_sanitize_test.go
Normal file
106
internal/adapter/openai/tool_history_sanitize_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package openai
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSanitizeLeakedToolHistoryRemovesMarkerBlocks(t *testing.T) {
|
||||
raw := "前缀\n[TOOL_CALL_HISTORY]\nfunction.name: exec\nfunction.arguments: {}\n[/TOOL_CALL_HISTORY]\n后缀"
|
||||
got := sanitizeLeakedToolHistory(raw)
|
||||
if got != "前缀\n\n后缀" {
|
||||
t.Fatalf("unexpected sanitized content: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeLeakedToolHistoryPreservesChunkWhitespace(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
raw string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "trailing space kept",
|
||||
raw: "Hello ",
|
||||
want: "Hello ",
|
||||
},
|
||||
{
|
||||
name: "leading newline kept",
|
||||
raw: "\nworld",
|
||||
want: "\nworld",
|
||||
},
|
||||
{
|
||||
name: "surrounding whitespace around marker is preserved",
|
||||
raw: "A \n[TOOL_RESULT_HISTORY]\nfunction.name: exec\nfunction.arguments: {}\n[/TOOL_RESULT_HISTORY]\n B",
|
||||
want: "A \n\n B",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := sanitizeLeakedToolHistory(tc.raw)
|
||||
if got != tc.want {
|
||||
t.Fatalf("unexpected sanitize result, want %q got %q", tc.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeLeakedToolHistoryRemovesEmptyJSONFence(t *testing.T) {
|
||||
raw := "before\n```json\n```\nafter"
|
||||
got := sanitizeLeakedToolHistory(raw)
|
||||
if got != "before\n\nafter" {
|
||||
t.Fatalf("unexpected sanitized empty json fence: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlushToolSieveDropsToolHistoryLeak(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
chunk := "[TOOL_CALL_HISTORY]\nstatus: already_called\nfunction.name: exec\nfunction.arguments: {}\n[/TOOL_CALL_HISTORY]"
|
||||
evts := processToolSieveChunk(&state, chunk, []string{"exec"})
|
||||
if len(evts) != 0 {
|
||||
t.Fatalf("expected no immediate output before history block is complete, got %+v", evts)
|
||||
}
|
||||
flushed := flushToolSieve(&state, []string{"exec"})
|
||||
if len(flushed) != 0 {
|
||||
t.Fatalf("expected history block to be swallowed, got %+v", flushed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlushToolSieveDropsToolResultHistoryLeak(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
chunk := "[TOOL_RESULT_HISTORY]\nstatus: already_called\nfunction.name: exec\nfunction.arguments: {}\n[/TOOL_RESULT_HISTORY]"
|
||||
evts := processToolSieveChunk(&state, chunk, []string{"exec"})
|
||||
if len(evts) != 0 {
|
||||
t.Fatalf("expected no immediate output before result history block is complete, got %+v", evts)
|
||||
}
|
||||
flushed := flushToolSieve(&state, []string{"exec"})
|
||||
if len(flushed) != 0 {
|
||||
t.Fatalf("expected result history block to be swallowed, got %+v", flushed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveChunkSplitsResultHistoryBoundary(t *testing.T) {
|
||||
var state toolStreamSieveState
|
||||
parts := []string{
|
||||
"Hello ",
|
||||
"[TOOL_RESULT_HISTORY]\nstatus: already_called\n",
|
||||
"function.name: exec\nfunction.arguments: {}\n[/TOOL_RESULT_HISTORY]",
|
||||
"world",
|
||||
}
|
||||
var events []toolStreamEvent
|
||||
for _, p := range parts {
|
||||
events = append(events, processToolSieveChunk(&state, p, []string{"exec"})...)
|
||||
}
|
||||
events = append(events, flushToolSieve(&state, []string{"exec"})...)
|
||||
|
||||
var text string
|
||||
for _, evt := range events {
|
||||
if evt.Content != "" {
|
||||
text += evt.Content
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
t.Fatalf("did not expect parsed tool calls from history leak: %+v", evt.ToolCalls)
|
||||
}
|
||||
}
|
||||
if text != "Hello world" {
|
||||
t.Fatalf("expected clean text output preserving boundary spaces, got %q", text)
|
||||
}
|
||||
}
|
||||
@@ -167,22 +167,25 @@ func findToolSegmentStart(s string) int {
|
||||
return -1
|
||||
}
|
||||
lower := strings.ToLower(s)
|
||||
offset := 0
|
||||
for {
|
||||
keyRel := strings.Index(lower[offset:], "tool_calls")
|
||||
if keyRel < 0 {
|
||||
return -1
|
||||
keywords := []string{"tool_calls", "function.name:", "[tool_call_history]", "[tool_result_history]"}
|
||||
bestKeyIdx := -1
|
||||
for _, kw := range keywords {
|
||||
idx := strings.Index(lower, kw)
|
||||
if idx >= 0 && (bestKeyIdx < 0 || idx < bestKeyIdx) {
|
||||
bestKeyIdx = idx
|
||||
}
|
||||
keyIdx := offset + keyRel
|
||||
start := strings.LastIndex(s[:keyIdx], "{")
|
||||
if start < 0 {
|
||||
start = keyIdx
|
||||
}
|
||||
if !insideCodeFence(s[:start]) {
|
||||
return start
|
||||
}
|
||||
offset = keyIdx + len("tool_calls")
|
||||
}
|
||||
if bestKeyIdx < 0 {
|
||||
return -1
|
||||
}
|
||||
start := strings.LastIndex(s[:bestKeyIdx], "{")
|
||||
if start < 0 {
|
||||
start = bestKeyIdx
|
||||
}
|
||||
if fenceStart, ok := openFenceStartBefore(s, start); ok {
|
||||
return fenceStart
|
||||
}
|
||||
return start
|
||||
}
|
||||
|
||||
func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
|
||||
@@ -191,13 +194,24 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
|
||||
return "", nil, "", false
|
||||
}
|
||||
lower := strings.ToLower(captured)
|
||||
keyIdx := strings.Index(lower, "tool_calls")
|
||||
keyIdx := -1
|
||||
keywords := []string{"tool_calls", "function.name:", "[tool_call_history]", "[tool_result_history]"}
|
||||
for _, kw := range keywords {
|
||||
idx := strings.Index(lower, kw)
|
||||
if idx >= 0 && (keyIdx < 0 || idx < keyIdx) {
|
||||
keyIdx = idx
|
||||
}
|
||||
}
|
||||
|
||||
if keyIdx < 0 {
|
||||
return "", nil, "", false
|
||||
}
|
||||
start := strings.LastIndex(captured[:keyIdx], "{")
|
||||
if start < 0 {
|
||||
return "", nil, "", false
|
||||
if blockStart, blockEnd, ok := extractToolHistoryBlock(captured, keyIdx); ok {
|
||||
return captured[:blockStart], nil, captured[blockEnd:], true
|
||||
}
|
||||
start = keyIdx
|
||||
}
|
||||
obj, end, ok := extractJSONObjectFrom(captured, start)
|
||||
if !ok {
|
||||
@@ -205,9 +219,6 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
|
||||
}
|
||||
prefixPart := captured[:start]
|
||||
suffixPart := captured[end:]
|
||||
if insideCodeFence(state.recentTextTail + prefixPart) {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
parsed := util.ParseStandaloneToolCallsDetailed(obj, toolNames)
|
||||
if len(parsed.Calls) == 0 {
|
||||
if parsed.SawToolCallSyntax && parsed.RejectedByPolicy {
|
||||
@@ -215,7 +226,75 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
|
||||
// consume it to avoid leaking raw tool_calls JSON to user content.
|
||||
return prefixPart, nil, suffixPart, true
|
||||
}
|
||||
// If it has obvious keywords but failed to parse even after loose repair,
|
||||
// we still might want to intercept it if it looks like an attempt at tool call.
|
||||
// For now, keep the original logic but rely on loose JSON repair.
|
||||
return captured, nil, "", true
|
||||
}
|
||||
prefixPart, suffixPart = trimWrappingJSONFence(prefixPart, suffixPart)
|
||||
return prefixPart, parsed.Calls, suffixPart, true
|
||||
}
|
||||
|
||||
func extractToolHistoryBlock(captured string, keyIdx int) (start int, end int, ok bool) {
|
||||
if keyIdx < 0 || keyIdx >= len(captured) {
|
||||
return 0, 0, false
|
||||
}
|
||||
rest := strings.ToLower(captured[keyIdx:])
|
||||
switch {
|
||||
case strings.HasPrefix(rest, "[tool_call_history]"):
|
||||
closeTag := "[/tool_call_history]"
|
||||
closeIdx := strings.Index(rest, closeTag)
|
||||
if closeIdx < 0 {
|
||||
return 0, 0, false
|
||||
}
|
||||
return keyIdx, keyIdx + closeIdx + len(closeTag), true
|
||||
case strings.HasPrefix(rest, "[tool_result_history]"):
|
||||
closeTag := "[/tool_result_history]"
|
||||
closeIdx := strings.Index(rest, closeTag)
|
||||
if closeIdx < 0 {
|
||||
return 0, 0, false
|
||||
}
|
||||
return keyIdx, keyIdx + closeIdx + len(closeTag), true
|
||||
default:
|
||||
return 0, 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func trimWrappingJSONFence(prefix, suffix string) (string, string) {
|
||||
trimmedPrefix := strings.TrimRight(prefix, " \t\r\n")
|
||||
fenceIdx := strings.LastIndex(trimmedPrefix, "```")
|
||||
if fenceIdx < 0 {
|
||||
return prefix, suffix
|
||||
}
|
||||
// Only strip when the trailing fence in prefix behaves like an opening fence.
|
||||
// A legitimate closing fence before a standalone tool JSON must be preserved.
|
||||
if strings.Count(trimmedPrefix[:fenceIdx+3], "```")%2 == 0 {
|
||||
return prefix, suffix
|
||||
}
|
||||
fenceHeader := strings.TrimSpace(trimmedPrefix[fenceIdx+3:])
|
||||
if fenceHeader != "" && !strings.EqualFold(fenceHeader, "json") {
|
||||
return prefix, suffix
|
||||
}
|
||||
|
||||
trimmedSuffix := strings.TrimLeft(suffix, " \t\r\n")
|
||||
if !strings.HasPrefix(trimmedSuffix, "```") {
|
||||
return prefix, suffix
|
||||
}
|
||||
consumedLeading := len(suffix) - len(trimmedSuffix)
|
||||
return trimmedPrefix[:fenceIdx], suffix[consumedLeading+3:]
|
||||
}
|
||||
|
||||
func openFenceStartBefore(s string, pos int) (int, bool) {
|
||||
if pos <= 0 || pos > len(s) {
|
||||
return -1, false
|
||||
}
|
||||
segment := s[:pos]
|
||||
lastFence := strings.LastIndex(segment, "```")
|
||||
if lastFence < 0 {
|
||||
return -1, false
|
||||
}
|
||||
if strings.Count(segment, "```")%2 == 1 {
|
||||
return lastFence, true
|
||||
}
|
||||
return -1, false
|
||||
}
|
||||
|
||||
@@ -1,291 +0,0 @@
|
||||
package openai
|
||||
|
||||
import "strings"
|
||||
|
||||
func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
|
||||
if state.disableDeltas {
|
||||
return nil
|
||||
}
|
||||
captured := state.capture.String()
|
||||
if captured == "" {
|
||||
return nil
|
||||
}
|
||||
lower := strings.ToLower(captured)
|
||||
keyIdx := strings.Index(lower, "tool_calls")
|
||||
if keyIdx < 0 {
|
||||
return nil
|
||||
}
|
||||
start := strings.LastIndex(captured[:keyIdx], "{")
|
||||
if start < 0 {
|
||||
return nil
|
||||
}
|
||||
if insideCodeFence(state.recentTextTail + captured[:start]) {
|
||||
return nil
|
||||
}
|
||||
certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx)
|
||||
if hasMultiple {
|
||||
state.disableDeltas = true
|
||||
return nil
|
||||
}
|
||||
if !certainSingle {
|
||||
// In uncertain phases (e.g. first call arrived but array not closed yet),
|
||||
// avoid speculative deltas and wait for final parsed tool_calls payload.
|
||||
return nil
|
||||
}
|
||||
callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
deltas := make([]toolCallDelta, 0, 2)
|
||||
if state.toolName == "" {
|
||||
name, ok := extractToolCallName(captured, callStart)
|
||||
if !ok || name == "" {
|
||||
return nil
|
||||
}
|
||||
state.toolName = name
|
||||
}
|
||||
if state.toolArgsStart < 0 {
|
||||
argsStart, stringMode, ok := findToolCallArgsStart(captured, callStart)
|
||||
if ok {
|
||||
state.toolArgsString = stringMode
|
||||
if stringMode {
|
||||
state.toolArgsStart = argsStart + 1
|
||||
} else {
|
||||
state.toolArgsStart = argsStart
|
||||
}
|
||||
state.toolArgsSent = state.toolArgsStart
|
||||
}
|
||||
}
|
||||
if !state.toolNameSent {
|
||||
if state.toolArgsStart < 0 {
|
||||
return nil
|
||||
}
|
||||
state.toolNameSent = true
|
||||
deltas = append(deltas, toolCallDelta{Index: 0, Name: state.toolName})
|
||||
}
|
||||
if state.toolArgsStart < 0 || state.toolArgsDone {
|
||||
return deltas
|
||||
}
|
||||
end, complete, ok := scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString)
|
||||
if !ok {
|
||||
return deltas
|
||||
}
|
||||
if end > state.toolArgsSent {
|
||||
deltas = append(deltas, toolCallDelta{
|
||||
Index: 0,
|
||||
Arguments: captured[state.toolArgsSent:end],
|
||||
})
|
||||
state.toolArgsSent = end
|
||||
}
|
||||
if complete {
|
||||
state.toolArgsDone = true
|
||||
}
|
||||
return deltas
|
||||
}
|
||||
|
||||
func classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) {
|
||||
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
i := skipSpaces(text, arrStart+1)
|
||||
if i >= len(text) || text[i] != '{' {
|
||||
return false, false
|
||||
}
|
||||
count := 0
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
for ; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '{' {
|
||||
if depth == 0 {
|
||||
count++
|
||||
if count > 1 {
|
||||
return false, true
|
||||
}
|
||||
}
|
||||
depth++
|
||||
continue
|
||||
}
|
||||
if ch == '}' {
|
||||
if depth > 0 {
|
||||
depth--
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == ',' && depth == 0 {
|
||||
// top-level separator means at least one more tool call exists
|
||||
// (or is expected). Treat as multi-call and stop incremental deltas.
|
||||
return false, true
|
||||
}
|
||||
if ch == ']' && depth == 0 {
|
||||
return count == 1, false
|
||||
}
|
||||
}
|
||||
// array not closed yet: still uncertain whether more calls will appear
|
||||
return false, false
|
||||
}
|
||||
|
||||
func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) {
|
||||
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
|
||||
if !ok {
|
||||
return -1, false
|
||||
}
|
||||
i := skipSpaces(text, arrStart+1)
|
||||
if i >= len(text) || text[i] != '{' {
|
||||
return -1, false
|
||||
}
|
||||
return i, true
|
||||
}
|
||||
|
||||
func findToolCallsArrayStart(text string, keyIdx int) (int, bool) {
|
||||
i := keyIdx + len("tool_calls")
|
||||
for i < len(text) && text[i] != ':' {
|
||||
i++
|
||||
}
|
||||
if i >= len(text) {
|
||||
return -1, false
|
||||
}
|
||||
i = skipSpaces(text, i+1)
|
||||
if i >= len(text) || text[i] != '[' {
|
||||
return -1, false
|
||||
}
|
||||
return i, true
|
||||
}
|
||||
|
||||
func extractToolCallName(text string, callStart int) (string, bool) {
|
||||
valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"name"})
|
||||
if !ok || valueStart >= len(text) || text[valueStart] != '"' {
|
||||
fnStart, fnOK := findFunctionObjectStart(text, callStart)
|
||||
if !fnOK {
|
||||
return "", false
|
||||
}
|
||||
valueStart, ok = findObjectFieldValueStart(text, fnStart, []string{"name"})
|
||||
if !ok || valueStart >= len(text) || text[valueStart] != '"' {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
name, _, ok := parseJSONStringLiteral(text, valueStart)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return name, true
|
||||
}
|
||||
|
||||
func findToolCallArgsStart(text string, callStart int) (int, bool, bool) {
|
||||
keys := []string{"input", "arguments", "args", "parameters", "params"}
|
||||
valueStart, ok := findObjectFieldValueStart(text, callStart, keys)
|
||||
if !ok {
|
||||
fnStart, fnOK := findFunctionObjectStart(text, callStart)
|
||||
if !fnOK {
|
||||
return -1, false, false
|
||||
}
|
||||
valueStart, ok = findObjectFieldValueStart(text, fnStart, keys)
|
||||
if !ok {
|
||||
return -1, false, false
|
||||
}
|
||||
}
|
||||
if valueStart >= len(text) {
|
||||
return -1, false, false
|
||||
}
|
||||
ch := text[valueStart]
|
||||
if ch == '{' || ch == '[' {
|
||||
return valueStart, false, true
|
||||
}
|
||||
if ch == '"' {
|
||||
return valueStart, true, true
|
||||
}
|
||||
return -1, false, false
|
||||
}
|
||||
|
||||
func scanToolCallArgsProgress(text string, start int, stringMode bool) (int, bool, bool) {
|
||||
if start < 0 || start > len(text) {
|
||||
return 0, false, false
|
||||
}
|
||||
if stringMode {
|
||||
escaped := false
|
||||
for i := start; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == '"' {
|
||||
return i, true, true
|
||||
}
|
||||
}
|
||||
return len(text), false, true
|
||||
}
|
||||
if start >= len(text) {
|
||||
return start, false, false
|
||||
}
|
||||
if text[start] != '{' && text[start] != '[' {
|
||||
return 0, false, false
|
||||
}
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
for i := start; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '{' || ch == '[' {
|
||||
depth++
|
||||
continue
|
||||
}
|
||||
if ch == '}' || ch == ']' {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return i + 1, true, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return len(text), false, true
|
||||
}
|
||||
|
||||
func findFunctionObjectStart(text string, callStart int) (int, bool) {
|
||||
valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"function"})
|
||||
if !ok || valueStart >= len(text) || text[valueStart] != '{' {
|
||||
return -1, false
|
||||
}
|
||||
return valueStart, true
|
||||
}
|
||||
@@ -1,7 +1,5 @@
|
||||
package openai
|
||||
|
||||
import "strings"
|
||||
|
||||
func extractJSONObjectFrom(text string, start int) (string, int, bool) {
|
||||
if start < 0 || start >= len(text) || text[start] != '{' {
|
||||
return "", 0, false
|
||||
@@ -43,110 +41,3 @@ func extractJSONObjectFrom(text string, start int) (string, int, bool) {
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func findObjectFieldValueStart(text string, objStart int, keys []string) (int, bool) {
|
||||
if objStart < 0 || objStart >= len(text) || text[objStart] != '{' {
|
||||
return 0, false
|
||||
}
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
for i := objStart; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
if depth == 1 {
|
||||
key, end, ok := parseJSONStringLiteral(text, i)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
j := skipSpaces(text, end)
|
||||
if j >= len(text) || text[j] != ':' {
|
||||
i = end - 1
|
||||
continue
|
||||
}
|
||||
j = skipSpaces(text, j+1)
|
||||
if j >= len(text) {
|
||||
return 0, false
|
||||
}
|
||||
if containsKey(keys, key) {
|
||||
return j, true
|
||||
}
|
||||
i = j - 1
|
||||
continue
|
||||
}
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '{' {
|
||||
depth++
|
||||
continue
|
||||
}
|
||||
if ch == '}' {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func parseJSONStringLiteral(text string, start int) (string, int, bool) {
|
||||
if start < 0 || start >= len(text) || text[start] != '"' {
|
||||
return "", 0, false
|
||||
}
|
||||
var b strings.Builder
|
||||
escaped := false
|
||||
for i := start + 1; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if escaped {
|
||||
b.WriteByte(ch)
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == '"' {
|
||||
return b.String(), i + 1, true
|
||||
}
|
||||
b.WriteByte(ch)
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func containsKey(keys []string, value string) bool {
|
||||
for _, k := range keys {
|
||||
if k == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func skipSpaces(text string, i int) int {
|
||||
for i < len(text) {
|
||||
switch text[i] {
|
||||
case ' ', '\t', '\n', '\r':
|
||||
i++
|
||||
default:
|
||||
return i
|
||||
}
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
@@ -63,14 +63,3 @@ func appendTail(prev, next string, max int) string {
|
||||
}
|
||||
return combined[len(combined)-max:]
|
||||
}
|
||||
|
||||
func looksLikeToolExampleContext(text string) bool {
|
||||
return insideCodeFence(text)
|
||||
}
|
||||
|
||||
func insideCodeFence(text string) bool {
|
||||
if text == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Count(text, "```")%2 == 1
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ type ConfigStore interface {
|
||||
FindAccount(identifier string) (config.Account, bool)
|
||||
UpdateAccountToken(identifier, token string) error
|
||||
UpdateAccountTestStatus(identifier, status string) error
|
||||
AccountTestStatus(identifier string) (string, bool)
|
||||
Update(mutator func(*config.Config) error) error
|
||||
ExportJSONAndBase64() (string, string, error)
|
||||
IsEnvBacked() bool
|
||||
@@ -27,6 +28,7 @@ type ConfigStore interface {
|
||||
RuntimeAccountMaxInflight() int
|
||||
RuntimeAccountMaxQueue(defaultSize int) int
|
||||
RuntimeGlobalMaxInflight(defaultSize int) int
|
||||
AutoDeleteSessions() bool
|
||||
}
|
||||
|
||||
type PoolController interface {
|
||||
@@ -40,6 +42,8 @@ type DeepSeekCaller interface {
|
||||
CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
||||
GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
||||
CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
|
||||
GetSessionCountForToken(ctx context.Context, token string) (*deepseek.SessionStats, error)
|
||||
DeleteAllSessionsForToken(ctx context.Context, token string) error
|
||||
}
|
||||
|
||||
var _ ConfigStore = (*config.Store)(nil)
|
||||
|
||||
@@ -31,12 +31,15 @@ func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
pr.Get("/queue/status", h.queueStatus)
|
||||
pr.Post("/accounts/test", h.testSingleAccount)
|
||||
pr.Post("/accounts/test-all", h.testAllAccounts)
|
||||
pr.Post("/accounts/sessions/delete-all", h.deleteAllSessions)
|
||||
pr.Post("/import", h.batchImport)
|
||||
pr.Post("/test", h.testAPI)
|
||||
pr.Post("/vercel/sync", h.syncVercel)
|
||||
pr.Get("/vercel/status", h.vercelStatus)
|
||||
pr.Post("/vercel/status", h.vercelStatus)
|
||||
pr.Get("/export", h.exportConfig)
|
||||
pr.Get("/dev/captures", h.getDevCaptures)
|
||||
pr.Delete("/dev/captures", h.clearDevCaptures)
|
||||
pr.Get("/version", h.getVersion)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -54,6 +54,7 @@ func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
items := make([]map[string]any, 0, end-start)
|
||||
for _, acc := range accounts[start:end] {
|
||||
testStatus, _ := h.Store.AccountTestStatus(acc.Identifier())
|
||||
token := strings.TrimSpace(acc.Token)
|
||||
preview := ""
|
||||
if token != "" {
|
||||
@@ -70,7 +71,7 @@ func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
"has_password": acc.Password != "",
|
||||
"has_token": token != "",
|
||||
"token_preview": preview,
|
||||
"test_status": acc.TestStatus,
|
||||
"test_status": testStatus,
|
||||
})
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages})
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@@ -26,9 +25,9 @@ func newAdminTestHandler(t *testing.T, raw string) *Handler {
|
||||
}
|
||||
}
|
||||
|
||||
func TestListAccountsIncludesTokenOnlyIdentifier(t *testing.T) {
|
||||
func TestListAccountsUsesEmailIdentifier(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"accounts":[{"token":"token-only-account"}]
|
||||
"accounts":[{"email":"u@example.com","password":"pwd"}]
|
||||
}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/accounts?page=1&page_size=10", nil)
|
||||
@@ -49,38 +48,8 @@ func TestListAccountsIncludesTokenOnlyIdentifier(t *testing.T) {
|
||||
}
|
||||
first, _ := items[0].(map[string]any)
|
||||
identifier, _ := first["identifier"].(string)
|
||||
if identifier == "" {
|
||||
t.Fatalf("expected non-empty identifier: %#v", first)
|
||||
}
|
||||
if !strings.HasPrefix(identifier, "token:") {
|
||||
t.Fatalf("expected token synthetic identifier, got %q", identifier)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAccountSupportsTokenOnlyIdentifier(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"accounts":[{"token":"token-only-account"}]
|
||||
}`)
|
||||
accounts := h.Store.Accounts()
|
||||
if len(accounts) != 1 {
|
||||
t.Fatalf("expected 1 account, got %d", len(accounts))
|
||||
}
|
||||
id := accounts[0].Identifier()
|
||||
if id == "" {
|
||||
t.Fatal("expected token-only synthetic identifier")
|
||||
}
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Delete("/admin/accounts/{identifier}", h.deleteAccount)
|
||||
req := httptest.NewRequest(http.MethodDelete, "/admin/accounts/"+url.PathEscape(id), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if got := len(h.Store.Accounts()); got != 0 {
|
||||
t.Fatalf("expected account removed, remaining=%d", got)
|
||||
if identifier != "u@example.com" {
|
||||
t.Fatalf("expected email identifier, got %q", identifier)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,11 +111,10 @@ func TestAddAccountRejectsCanonicalMobileDuplicate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) {
|
||||
func TestFindAccountByIdentifierSupportsMobile(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"accounts":[
|
||||
{"email":"u@example.com","mobile":"13800138000","password":"pwd"},
|
||||
{"token":"token-only-account"}
|
||||
{"email":"u@example.com","mobile":"13800138000","password":"pwd"}
|
||||
]
|
||||
}`)
|
||||
|
||||
@@ -165,21 +133,4 @@ func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) {
|
||||
t.Fatalf("unexpected account by +86 mobile: %#v", accByMobileWithCountryCode)
|
||||
}
|
||||
|
||||
tokenOnlyID := ""
|
||||
for _, acc := range h.Store.Accounts() {
|
||||
if strings.TrimSpace(acc.Email) == "" && strings.TrimSpace(acc.Mobile) == "" {
|
||||
tokenOnlyID = acc.Identifier()
|
||||
break
|
||||
}
|
||||
}
|
||||
if tokenOnlyID == "" {
|
||||
t.Fatal("expected token-only account identifier")
|
||||
}
|
||||
accByTokenOnly, ok := findAccountByIdentifier(h.Store, tokenOnlyID)
|
||||
if !ok {
|
||||
t.Fatalf("expected find by token-only id=%q", tokenOnlyID)
|
||||
}
|
||||
if accByTokenOnly.Token != "token-only-account" {
|
||||
t.Fatalf("unexpected token-only account: %#v", accByTokenOnly)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,7 +89,15 @@ func runAccountTestsConcurrently(accounts []config.Account, maxConcurrency int,
|
||||
func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, message string) map[string]any {
|
||||
start := time.Now()
|
||||
identifier := acc.Identifier()
|
||||
result := map[string]any{"account": identifier, "success": false, "response_time": 0, "message": "", "model": model}
|
||||
result := map[string]any{
|
||||
"account": identifier,
|
||||
"success": false,
|
||||
"response_time": 0,
|
||||
"message": "",
|
||||
"model": model,
|
||||
"session_count": 0,
|
||||
"config_writable": !h.Store.IsEnvBacked(),
|
||||
}
|
||||
defer func() {
|
||||
status := "failed"
|
||||
if ok, _ := result["success"].(bool); ok {
|
||||
@@ -97,15 +105,14 @@ func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, me
|
||||
}
|
||||
_ = h.Store.UpdateAccountTestStatus(identifier, status)
|
||||
}()
|
||||
token := strings.TrimSpace(acc.Token)
|
||||
if token == "" {
|
||||
newToken, err := h.DS.Login(ctx, acc)
|
||||
if err != nil {
|
||||
result["message"] = "登录失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
token = newToken
|
||||
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
||||
token, err := h.DS.Login(ctx, acc)
|
||||
if err != nil {
|
||||
result["message"] = "登录失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
if err := h.Store.UpdateAccountToken(acc.Identifier(), token); err != nil {
|
||||
result["message"] = "登录成功但写入运行时 token 失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
authCtx := &authn.RequestAuth{UseConfigToken: false, DeepSeekToken: token}
|
||||
sessionID, err := h.DS.CreateSession(ctx, authCtx, 1)
|
||||
@@ -117,16 +124,26 @@ func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, me
|
||||
}
|
||||
token = newToken
|
||||
authCtx.DeepSeekToken = token
|
||||
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
||||
if err := h.Store.UpdateAccountToken(acc.Identifier(), token); err != nil {
|
||||
result["message"] = "刷新 token 成功但写入运行时 token 失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
sessionID, err = h.DS.CreateSession(ctx, authCtx, 1)
|
||||
if err != nil {
|
||||
result["message"] = "创建会话失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// 获取会话数量
|
||||
sessionStats, sessionErr := h.DS.GetSessionCountForToken(ctx, token)
|
||||
if sessionErr == nil && sessionStats != nil {
|
||||
result["session_count"] = sessionStats.FirstPageCount
|
||||
}
|
||||
|
||||
if strings.TrimSpace(message) == "" {
|
||||
result["success"] = true
|
||||
result["message"] = "API 测试成功(仅会话创建)"
|
||||
result["message"] = "Token 刷新成功(登录与会话创建成功)"
|
||||
result["response_time"] = int(time.Since(start).Milliseconds())
|
||||
return result
|
||||
}
|
||||
@@ -210,3 +227,45 @@ func (h *Handler) testAPI(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": false, "status_code": resp.StatusCode, "response": string(body)})
|
||||
}
|
||||
|
||||
func (h *Handler) deleteAllSessions(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
identifier, _ := req["identifier"].(string)
|
||||
if strings.TrimSpace(identifier) == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(identifier / email / mobile)"})
|
||||
return
|
||||
}
|
||||
acc, ok := findAccountByIdentifier(h.Store, identifier)
|
||||
if !ok {
|
||||
writeJSON(w, http.StatusNotFound, map[string]any{"detail": "账号不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 每次先登录刷新一次 token,避免使用过期 token。
|
||||
token, err := h.DS.Login(r.Context(), acc)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": false, "message": "登录失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
||||
|
||||
// 删除所有会话
|
||||
err = h.DS.DeleteAllSessionsForToken(r.Context(), token)
|
||||
if err != nil {
|
||||
// token 可能过期,尝试重新登录并重试一次
|
||||
newToken, loginErr := h.DS.Login(r.Context(), acc)
|
||||
if loginErr != nil {
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": false, "message": "删除失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
token = newToken
|
||||
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
||||
if retryErr := h.DS.DeleteAllSessionsForToken(r.Context(), token); retryErr != nil {
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": false, "message": "删除失败: " + retryErr.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "message": "删除成功"})
|
||||
}
|
||||
|
||||
@@ -1,21 +1,28 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
)
|
||||
|
||||
type testingDSMock struct {
|
||||
loginCalls int
|
||||
createSessionCalls int
|
||||
getPowCalls int
|
||||
callCompletionCalls int
|
||||
loginCalls int
|
||||
createSessionCalls int
|
||||
getPowCalls int
|
||||
callCompletionCalls int
|
||||
deleteAllSessionsCalls int
|
||||
deleteAllSessionsError error
|
||||
deleteAllSessionsErrorOnce bool
|
||||
}
|
||||
|
||||
func (m *testingDSMock) Login(_ context.Context, _ config.Account) (string, error) {
|
||||
@@ -38,6 +45,22 @@ func (m *testingDSMock) CallCompletion(_ context.Context, _ *auth.RequestAuth, _
|
||||
return nil, errors.New("should not call CallCompletion in this test")
|
||||
}
|
||||
|
||||
func (m *testingDSMock) DeleteAllSessionsForToken(_ context.Context, _ string) error {
|
||||
m.deleteAllSessionsCalls++
|
||||
if m.deleteAllSessionsError != nil {
|
||||
err := m.deleteAllSessionsError
|
||||
if m.deleteAllSessionsErrorOnce {
|
||||
m.deleteAllSessionsError = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testingDSMock) GetSessionCountForToken(_ context.Context, _ string) (*deepseek.SessionStats, error) {
|
||||
return &deepseek.SessionStats{Success: true}, nil
|
||||
}
|
||||
|
||||
func TestTestAccount_BatchModeOnlyCreatesSession(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"accounts":[{"email":"batch@example.com","password":"pwd","token":""}]}`)
|
||||
store := config.LoadStore()
|
||||
@@ -54,7 +77,7 @@ func TestTestAccount_BatchModeOnlyCreatesSession(t *testing.T) {
|
||||
t.Fatalf("expected success=true, got %#v", result)
|
||||
}
|
||||
msg, _ := result["message"].(string)
|
||||
if !strings.Contains(msg, "仅会话创建") {
|
||||
if !strings.Contains(msg, "Token 刷新成功") {
|
||||
t.Fatalf("expected session-only success message, got %q", msg)
|
||||
}
|
||||
if ds.loginCalls != 1 || ds.createSessionCalls != 1 {
|
||||
@@ -70,7 +93,43 @@ func TestTestAccount_BatchModeOnlyCreatesSession(t *testing.T) {
|
||||
if updated.Token != "new-token" {
|
||||
t.Fatalf("expected refreshed token to be persisted, got %q", updated.Token)
|
||||
}
|
||||
if updated.TestStatus != "ok" {
|
||||
t.Fatalf("expected test status ok, got %q", updated.TestStatus)
|
||||
testStatus, ok := store.AccountTestStatus("batch@example.com")
|
||||
if !ok || testStatus != "ok" {
|
||||
t.Fatalf("expected runtime test status ok, got %q (ok=%v)", testStatus, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAllSessions_RetryWithReloginOnDeleteFailure(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"accounts":[{"email":"batch@example.com","password":"pwd","token":"expired-token"}]}`)
|
||||
store := config.LoadStore()
|
||||
ds := &testingDSMock{deleteAllSessionsError: errors.New("token expired"), deleteAllSessionsErrorOnce: true}
|
||||
h := &Handler{Store: store, DS: ds}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/delete-all", bytes.NewBufferString(`{"identifier":"batch@example.com"}`))
|
||||
rec := httptest.NewRecorder()
|
||||
h.deleteAllSessions(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", rec.Code)
|
||||
}
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
if ok, _ := resp["success"].(bool); !ok {
|
||||
t.Fatalf("expected success response, got %#v", resp)
|
||||
}
|
||||
if ds.loginCalls != 2 {
|
||||
t.Fatalf("expected initial login plus relogin, got %d", ds.loginCalls)
|
||||
}
|
||||
if ds.deleteAllSessionsCalls != 2 {
|
||||
t.Fatalf("expected delete called twice, got %d", ds.deleteAllSessionsCalls)
|
||||
}
|
||||
updated, ok := store.FindAccount("batch@example.com")
|
||||
if !ok {
|
||||
t.Fatal("expected account")
|
||||
}
|
||||
if updated.Token != "new-token" {
|
||||
t.Fatalf("expected refreshed token persisted, got %q", updated.Token)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,6 +43,7 @@ func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
incoming.ClearAccountTokens()
|
||||
|
||||
importedKeys, importedAccounts := 0, 0
|
||||
err = h.Store.Update(func(c *config.Config) error {
|
||||
@@ -180,6 +181,7 @@ func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func (h *Handler) computeSyncHash() string {
|
||||
snap := h.Store.Snapshot().Clone()
|
||||
snap.ClearAccountTokens()
|
||||
snap.VercelSyncHash = ""
|
||||
snap.VercelSyncTime = 0
|
||||
b, _ := json.Marshal(snap)
|
||||
|
||||
@@ -8,8 +8,9 @@ import (
|
||||
func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) {
|
||||
snap := h.Store.Snapshot()
|
||||
safe := map[string]any{
|
||||
"keys": snap.Keys,
|
||||
"accounts": []map[string]any{},
|
||||
"keys": snap.Keys,
|
||||
"accounts": []map[string]any{},
|
||||
"env_backed": h.Store.IsEnvBacked(),
|
||||
"claude_mapping": func() map[string]string {
|
||||
if len(snap.ClaudeMapping) > 0 {
|
||||
return snap.ClaudeMapping
|
||||
|
||||
@@ -50,9 +50,6 @@ func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.TrimSpace(acc.Password) == "" {
|
||||
acc.Password = prev.Password
|
||||
}
|
||||
if strings.TrimSpace(acc.Token) == "" {
|
||||
acc.Token = prev.Token
|
||||
}
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
accounts = append(accounts, acc)
|
||||
|
||||
@@ -7,15 +7,30 @@ import (
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ToolcallConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, map[string]string, map[string]string, error) {
|
||||
func boolFrom(v any) bool {
|
||||
if v == nil {
|
||||
return false
|
||||
}
|
||||
switch x := v.(type) {
|
||||
case bool:
|
||||
return x
|
||||
case string:
|
||||
return strings.ToLower(strings.TrimSpace(x)) == "true"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ToolcallConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, *config.AutoDeleteConfig, map[string]string, map[string]string, error) {
|
||||
var (
|
||||
adminCfg *config.AdminConfig
|
||||
runtimeCfg *config.RuntimeConfig
|
||||
toolcallCfg *config.ToolcallConfig
|
||||
respCfg *config.ResponsesConfig
|
||||
embCfg *config.EmbeddingsConfig
|
||||
claudeMap map[string]string
|
||||
aliasMap map[string]string
|
||||
adminCfg *config.AdminConfig
|
||||
runtimeCfg *config.RuntimeConfig
|
||||
toolcallCfg *config.ToolcallConfig
|
||||
respCfg *config.ResponsesConfig
|
||||
embCfg *config.EmbeddingsConfig
|
||||
autoDeleteCfg *config.AutoDeleteConfig
|
||||
claudeMap map[string]string
|
||||
aliasMap map[string]string
|
||||
)
|
||||
|
||||
if raw, ok := req["admin"].(map[string]any); ok {
|
||||
@@ -23,7 +38,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
if v, exists := raw["jwt_expire_hours"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 1 || n > 720 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720")
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720")
|
||||
}
|
||||
cfg.JWTExpireHours = n
|
||||
}
|
||||
@@ -35,26 +50,26 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
if v, exists := raw["account_max_inflight"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 1 || n > 256 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_inflight must be between 1 and 256")
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_inflight must be between 1 and 256")
|
||||
}
|
||||
cfg.AccountMaxInflight = n
|
||||
}
|
||||
if v, exists := raw["account_max_queue"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 1 || n > 200000 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_queue must be between 1 and 200000")
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_queue must be between 1 and 200000")
|
||||
}
|
||||
cfg.AccountMaxQueue = n
|
||||
}
|
||||
if v, exists := raw["global_max_inflight"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 1 || n > 200000 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000")
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000")
|
||||
}
|
||||
cfg.GlobalMaxInflight = n
|
||||
}
|
||||
if cfg.AccountMaxInflight > 0 && cfg.GlobalMaxInflight > 0 && cfg.GlobalMaxInflight < cfg.AccountMaxInflight {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
|
||||
}
|
||||
runtimeCfg = cfg
|
||||
}
|
||||
@@ -67,7 +82,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
case "feature_match", "off":
|
||||
cfg.Mode = mode
|
||||
default:
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.mode must be feature_match or off")
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.mode must be feature_match or off")
|
||||
}
|
||||
}
|
||||
if v, exists := raw["early_emit_confidence"]; exists {
|
||||
@@ -76,7 +91,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
case "high", "low", "off":
|
||||
cfg.EarlyEmitConfidence = level
|
||||
default:
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.early_emit_confidence must be high, low or off")
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.early_emit_confidence must be high, low or off")
|
||||
}
|
||||
}
|
||||
toolcallCfg = cfg
|
||||
@@ -87,7 +102,7 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
if v, exists := raw["store_ttl_seconds"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 30 || n > 86400 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400")
|
||||
return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400")
|
||||
}
|
||||
cfg.StoreTTLSeconds = n
|
||||
}
|
||||
@@ -98,9 +113,6 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
cfg := &config.EmbeddingsConfig{}
|
||||
if v, exists := raw["provider"]; exists {
|
||||
p := strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
if p == "" {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("embeddings.provider cannot be empty")
|
||||
}
|
||||
cfg.Provider = p
|
||||
}
|
||||
embCfg = cfg
|
||||
@@ -130,5 +142,13 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
|
||||
}
|
||||
}
|
||||
|
||||
return adminCfg, runtimeCfg, toolcallCfg, respCfg, embCfg, claudeMap, aliasMap, nil
|
||||
if raw, ok := req["auto_delete"].(map[string]any); ok {
|
||||
cfg := &config.AutoDeleteConfig{}
|
||||
if v, exists := raw["sessions"]; exists {
|
||||
cfg.Sessions = boolFrom(v)
|
||||
}
|
||||
autoDeleteCfg = cfg
|
||||
}
|
||||
|
||||
return adminCfg, runtimeCfg, toolcallCfg, respCfg, embCfg, autoDeleteCfg, claudeMap, aliasMap, nil
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ func (h *Handler) getSettings(w http.ResponseWriter, _ *http.Request) {
|
||||
"toolcall": snap.Toolcall,
|
||||
"responses": snap.Responses,
|
||||
"embeddings": snap.Embeddings,
|
||||
"auto_delete": snap.AutoDelete,
|
||||
"claude_mapping": settingsClaudeMapping(snap),
|
||||
"model_aliases": snap.ModelAliases,
|
||||
"env_backed": h.Store.IsEnvBacked(),
|
||||
|
||||
@@ -17,7 +17,7 @@ func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
adminCfg, runtimeCfg, toolcallCfg, responsesCfg, embeddingsCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req)
|
||||
adminCfg, runtimeCfg, toolcallCfg, responsesCfg, embeddingsCfg, autoDeleteCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
@@ -60,6 +60,9 @@ func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) {
|
||||
if embeddingsCfg != nil && strings.TrimSpace(embeddingsCfg.Provider) != "" {
|
||||
c.Embeddings.Provider = strings.TrimSpace(embeddingsCfg.Provider)
|
||||
}
|
||||
if autoDeleteCfg != nil {
|
||||
c.AutoDelete.Sessions = autoDeleteCfg.Sessions
|
||||
}
|
||||
if claudeMap != nil {
|
||||
c.ClaudeMapping = claudeMap
|
||||
c.ClaudeModelMap = nil
|
||||
|
||||
@@ -3,6 +3,8 @@ package admin
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -11,6 +13,8 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func (h *Handler) syncVercel(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -25,7 +29,7 @@ func (h *Handler) syncVercel(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
validated, failed := h.validateAccountsForVercelSync(r.Context(), opts.AutoValidate)
|
||||
_, cfgB64, err := h.Store.ExportJSONAndBase64()
|
||||
cfgJSON, cfgB64, err := h.exportSyncConfig(req)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
@@ -47,7 +51,7 @@ func (h *Handler) syncVercel(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
savedCreds := h.saveVercelProjectCredentials(r.Context(), client, opts, params, headers, envs)
|
||||
manual, deployURL := triggerVercelDeployment(r.Context(), client, opts.ProjectID, params, headers)
|
||||
_ = h.Store.SetVercelSync(h.computeSyncHash(), time.Now().Unix())
|
||||
_ = h.Store.SetVercelSync(syncHashForJSON(cfgJSON), time.Now().Unix())
|
||||
result := map[string]any{"success": true, "validated_accounts": validated}
|
||||
if manual {
|
||||
result["message"] = "配置已同步到 Vercel,请手动触发重新部署"
|
||||
@@ -209,11 +213,71 @@ func triggerVercelDeployment(ctx context.Context, client *http.Client, projectID
|
||||
return false, deployURL
|
||||
}
|
||||
|
||||
func (h *Handler) vercelStatus(w http.ResponseWriter, _ *http.Request) {
|
||||
func (h *Handler) vercelStatus(w http.ResponseWriter, r *http.Request) {
|
||||
snap := h.Store.Snapshot()
|
||||
current := h.computeSyncHash()
|
||||
synced := snap.VercelSyncHash != "" && snap.VercelSyncHash == current
|
||||
writeJSON(w, http.StatusOK, map[string]any{"synced": synced, "last_sync_time": nilIfZero(snap.VercelSyncTime), "has_synced_before": snap.VercelSyncHash != ""})
|
||||
draftHash := ""
|
||||
draftDiffers := false
|
||||
if r != nil && r.Method == http.MethodPost && r.Body != nil {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err == nil {
|
||||
if cfgJSON, _, err := h.exportSyncConfig(req); err == nil {
|
||||
draftHash = syncHashForJSON(cfgJSON)
|
||||
draftDiffers = draftHash != "" && draftHash != current
|
||||
}
|
||||
}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"synced": synced,
|
||||
"last_sync_time": nilIfZero(snap.VercelSyncTime),
|
||||
"has_synced_before": snap.VercelSyncHash != "",
|
||||
"env_backed": h.Store.IsEnvBacked(),
|
||||
"config_hash": current,
|
||||
"last_synced_hash": snap.VercelSyncHash,
|
||||
"draft_hash": draftHash,
|
||||
"draft_differs": draftDiffers,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) exportSyncConfig(req map[string]any) (string, string, error) {
|
||||
override, ok := req["config_override"]
|
||||
if !ok || override == nil {
|
||||
return h.Store.ExportJSONAndBase64()
|
||||
}
|
||||
raw, err := json.Marshal(override)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
var cfg config.Config
|
||||
if err := json.Unmarshal(raw, &cfg); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
cfg.DropInvalidAccounts()
|
||||
cfg.ClearAccountTokens()
|
||||
cfg.VercelSyncHash = ""
|
||||
cfg.VercelSyncTime = 0
|
||||
b, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return string(b), base64.StdEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func syncHashForJSON(s string) string {
|
||||
var cfg config.Config
|
||||
if err := json.Unmarshal([]byte(s), &cfg); err != nil {
|
||||
return ""
|
||||
}
|
||||
cfg.VercelSyncHash = ""
|
||||
cfg.VercelSyncTime = 0
|
||||
cfg.ClearAccountTokens()
|
||||
b, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
sum := md5.Sum(b)
|
||||
return fmt.Sprintf("%x", sum)
|
||||
}
|
||||
|
||||
func vercelRequest(ctx context.Context, client *http.Client, method, endpoint string, params url.Values, headers map[string]string, body any) (map[string]any, int, error) {
|
||||
|
||||
75
internal/admin/handler_version.go
Normal file
75
internal/admin/handler_version.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/version"
|
||||
)
|
||||
|
||||
const latestReleaseAPI = "https://api.github.com/repos/CJackHwang/ds2api/releases/latest"
|
||||
|
||||
type latestReleasePayload struct {
|
||||
TagName string `json:"tag_name"`
|
||||
HTMLURL string `json:"html_url"`
|
||||
PublishedAt string `json:"published_at"`
|
||||
}
|
||||
|
||||
func (h *Handler) getVersion(w http.ResponseWriter, _ *http.Request) {
|
||||
current, source := version.Current()
|
||||
resp := map[string]any{
|
||||
"success": true,
|
||||
"current_version": current,
|
||||
"current_tag": version.Tag(current),
|
||||
"source": source,
|
||||
"checked_at": time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, latestReleaseAPI, nil)
|
||||
if err != nil {
|
||||
resp["check_error"] = err.Error()
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
req.Header.Set("User-Agent", "ds2api-version-check")
|
||||
|
||||
client := &http.Client{Timeout: 4 * time.Second}
|
||||
r, err := client.Do(req)
|
||||
if err != nil {
|
||||
resp["check_error"] = err.Error()
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
if r.StatusCode < 200 || r.StatusCode >= 300 {
|
||||
resp["check_error"] = "github api status: " + r.Status
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
|
||||
var data latestReleasePayload
|
||||
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
|
||||
resp["check_error"] = err.Error()
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
|
||||
latest := strings.TrimSpace(data.TagName)
|
||||
if latest == "" {
|
||||
resp["check_error"] = "missing latest tag"
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
latestVersion := strings.TrimPrefix(latest, "v")
|
||||
|
||||
resp["latest_tag"] = latest
|
||||
resp["latest_version"] = latestVersion
|
||||
resp["release_url"] = data.HTMLURL
|
||||
resp["published_at"] = data.PublishedAt
|
||||
resp["has_update"] = version.Compare(current, latestVersion) < 0
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
@@ -65,7 +65,6 @@ func toAccount(m map[string]any) config.Account {
|
||||
Email: email,
|
||||
Mobile: mobile,
|
||||
Password: fieldString(m, "password"),
|
||||
Token: fieldString(m, "token"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -188,8 +188,8 @@ func TestToAccountAllFields(t *testing.T) {
|
||||
if acc.Password != "secret" {
|
||||
t.Fatalf("unexpected password: %q", acc.Password)
|
||||
}
|
||||
if acc.Token != "tok123" {
|
||||
t.Fatalf("unexpected token: %q", acc.Token)
|
||||
if acc.Token != "" {
|
||||
t.Fatalf("expected token to be ignored, got %q", acc.Token)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
109
internal/admin/token_runtime_http_test.go
Normal file
109
internal/admin/token_runtime_http_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/account"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func newHTTPAdminHarness(t *testing.T, rawConfig string, ds DeepSeekCaller) http.Handler {
|
||||
t.Helper()
|
||||
t.Setenv("DS2API_CONFIG_JSON", rawConfig)
|
||||
t.Setenv("CONFIG_JSON", "")
|
||||
store := config.LoadStore()
|
||||
h := &Handler{
|
||||
Store: store,
|
||||
Pool: account.NewPool(store),
|
||||
DS: ds,
|
||||
}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
return r
|
||||
}
|
||||
|
||||
func adminReq(method, path string, body []byte) *http.Request {
|
||||
req := httptest.NewRequest(method, path, bytes.NewReader(body))
|
||||
req.Header.Set("Authorization", "Bearer admin")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
return req
|
||||
}
|
||||
|
||||
func TestConfigImportIgnoresTokenFieldInPayload(t *testing.T) {
|
||||
ds := &testingDSMock{}
|
||||
router := newHTTPAdminHarness(t, `{"accounts":[]}`, ds)
|
||||
|
||||
payload := []byte(`{
|
||||
"mode":"replace",
|
||||
"config":{
|
||||
"accounts":[{"email":"u@example.com","password":"pwd","token":"expired-token"}]
|
||||
}
|
||||
}`)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, adminReq(http.MethodPost, "/config/import", payload))
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("import status=%d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
readRec := httptest.NewRecorder()
|
||||
router.ServeHTTP(readRec, adminReq(http.MethodGet, "/config", nil))
|
||||
if readRec.Code != http.StatusOK {
|
||||
t.Fatalf("get config status=%d body=%s", readRec.Code, readRec.Body.String())
|
||||
}
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal(readRec.Body.Bytes(), &data); err != nil {
|
||||
t.Fatalf("decode config response: %v", err)
|
||||
}
|
||||
accounts, _ := data["accounts"].([]any)
|
||||
if len(accounts) != 1 {
|
||||
t.Fatalf("expected one account, got %d", len(accounts))
|
||||
}
|
||||
accountMap, _ := accounts[0].(map[string]any)
|
||||
if hasToken, _ := accountMap["has_token"].(bool); hasToken {
|
||||
t.Fatalf("expected imported token to be ignored, account=%#v", accountMap)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountTestRefreshesRuntimeTokenButExportOmitsToken(t *testing.T) {
|
||||
ds := &testingDSMock{}
|
||||
router := newHTTPAdminHarness(t, `{
|
||||
"accounts":[{"email":"batch@example.com","password":"pwd","token":"stale-token"}]
|
||||
}`, ds)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, adminReq(http.MethodPost, "/accounts/test", []byte(`{"identifier":"batch@example.com"}`)))
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("test account status=%d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var testResp map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &testResp); err != nil {
|
||||
t.Fatalf("decode test response: %v", err)
|
||||
}
|
||||
if ok, _ := testResp["success"].(bool); !ok {
|
||||
t.Fatalf("expected test success, got %#v", testResp)
|
||||
}
|
||||
if ds.loginCalls < 1 {
|
||||
t.Fatalf("expected login to be called at least once, got %d", ds.loginCalls)
|
||||
}
|
||||
|
||||
exportRec := httptest.NewRecorder()
|
||||
router.ServeHTTP(exportRec, adminReq(http.MethodGet, "/config/export", nil))
|
||||
if exportRec.Code != http.StatusOK {
|
||||
t.Fatalf("export status=%d body=%s", exportRec.Code, exportRec.Body.String())
|
||||
}
|
||||
var exportResp map[string]any
|
||||
if err := json.Unmarshal(exportRec.Body.Bytes(), &exportResp); err != nil {
|
||||
t.Fatalf("decode export response: %v", err)
|
||||
}
|
||||
exportJSON, _ := exportResp["json"].(string)
|
||||
if strings.Contains(exportJSON, `"token"`) {
|
||||
t.Fatalf("expected export json to omit tokens, got %s", exportJSON)
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/account"
|
||||
"ds2api/internal/config"
|
||||
@@ -37,10 +39,20 @@ type Resolver struct {
|
||||
Store *config.Store
|
||||
Pool *account.Pool
|
||||
Login LoginFunc
|
||||
|
||||
mu sync.Mutex
|
||||
tokenRefreshedAt map[string]time.Time
|
||||
tokenRefreshInterval time.Duration
|
||||
}
|
||||
|
||||
func NewResolver(store *config.Store, pool *account.Pool, login LoginFunc) *Resolver {
|
||||
return &Resolver{Store: store, Pool: pool, Login: login}
|
||||
return &Resolver{
|
||||
Store: store,
|
||||
Pool: pool,
|
||||
Login: login,
|
||||
tokenRefreshedAt: map[string]time.Time{},
|
||||
tokenRefreshInterval: 6 * time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
|
||||
@@ -72,13 +84,9 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
|
||||
TriedAccounts: map[string]bool{},
|
||||
resolver: r,
|
||||
}
|
||||
if acc.Token == "" {
|
||||
if err := r.loginAndPersist(ctx, a); err != nil {
|
||||
r.Pool.Release(a.AccountID)
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
a.DeepSeekToken = acc.Token
|
||||
if err := r.ensureManagedToken(ctx, a); err != nil {
|
||||
r.Pool.Release(a.AccountID)
|
||||
return nil, err
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
@@ -120,6 +128,7 @@ func (r *Resolver) loginAndPersist(ctx context.Context, a *RequestAuth) error {
|
||||
}
|
||||
a.Account.Token = token
|
||||
a.DeepSeekToken = token
|
||||
r.markTokenRefreshedNow(a.AccountID)
|
||||
return r.Store.UpdateAccountToken(a.AccountID, token)
|
||||
}
|
||||
|
||||
@@ -142,6 +151,7 @@ func (r *Resolver) MarkTokenInvalid(a *RequestAuth) {
|
||||
}
|
||||
a.Account.Token = ""
|
||||
a.DeepSeekToken = ""
|
||||
r.clearTokenRefreshMark(a.AccountID)
|
||||
_ = r.Store.UpdateAccountToken(a.AccountID, "")
|
||||
}
|
||||
|
||||
@@ -162,12 +172,8 @@ func (r *Resolver) SwitchAccount(ctx context.Context, a *RequestAuth) bool {
|
||||
}
|
||||
a.Account = acc
|
||||
a.AccountID = acc.Identifier()
|
||||
if acc.Token == "" {
|
||||
if err := r.loginAndPersist(ctx, a); err != nil {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
a.DeepSeekToken = acc.Token
|
||||
if err := r.ensureManagedToken(ctx, a); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -210,3 +216,53 @@ func callerTokenID(token string) string {
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return "caller:" + hex.EncodeToString(sum[:8])
|
||||
}
|
||||
|
||||
func (r *Resolver) ensureManagedToken(ctx context.Context, a *RequestAuth) error {
|
||||
if strings.TrimSpace(a.Account.Token) == "" {
|
||||
return r.loginAndPersist(ctx, a)
|
||||
}
|
||||
if r.shouldForceRefresh(a.AccountID) {
|
||||
if err := r.loginAndPersist(ctx, a); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
a.DeepSeekToken = a.Account.Token
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Resolver) shouldForceRefresh(accountID string) bool {
|
||||
if strings.TrimSpace(accountID) == "" {
|
||||
return false
|
||||
}
|
||||
if r.tokenRefreshInterval <= 0 {
|
||||
return false
|
||||
}
|
||||
now := time.Now()
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
last, ok := r.tokenRefreshedAt[accountID]
|
||||
if !ok || last.IsZero() {
|
||||
r.tokenRefreshedAt[accountID] = now
|
||||
return false
|
||||
}
|
||||
return now.Sub(last) >= r.tokenRefreshInterval
|
||||
}
|
||||
|
||||
func (r *Resolver) markTokenRefreshedNow(accountID string) {
|
||||
if strings.TrimSpace(accountID) == "" {
|
||||
return
|
||||
}
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.tokenRefreshedAt[accountID] = time.Now()
|
||||
}
|
||||
|
||||
func (r *Resolver) clearTokenRefreshMark(accountID string) {
|
||||
if strings.TrimSpace(accountID) == "" {
|
||||
return
|
||||
}
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
delete(r.tokenRefreshedAt, accountID)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,9 @@ package auth
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/account"
|
||||
"ds2api/internal/config"
|
||||
@@ -58,7 +60,7 @@ func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) {
|
||||
if auth.AccountID != "acc@example.com" {
|
||||
t.Fatalf("unexpected account id: %q", auth.AccountID)
|
||||
}
|
||||
if auth.DeepSeekToken != "account-token" {
|
||||
if auth.DeepSeekToken != "fresh-token" {
|
||||
t.Fatalf("unexpected account token: %q", auth.DeepSeekToken)
|
||||
}
|
||||
if auth.CallerID == "" {
|
||||
@@ -193,3 +195,52 @@ func TestDetermineCallerMissingToken(t *testing.T) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetermineManagedAccountForcesRefreshEverySixHours(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["managed-key"],
|
||||
"accounts":[{"email":"acc@example.com","password":"pwd","token":"seed-token"}]
|
||||
}`)
|
||||
store := config.LoadStore()
|
||||
if err := store.UpdateAccountToken("acc@example.com", "seed-token"); err != nil {
|
||||
t.Fatalf("update token failed: %v", err)
|
||||
}
|
||||
pool := account.NewPool(store)
|
||||
|
||||
var loginCount int32
|
||||
resolver := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
|
||||
n := atomic.AddInt32(&loginCount, 1)
|
||||
return "fresh-token-" + string(rune('0'+n)), nil
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
req.Header.Set("x-api-key", "managed-key")
|
||||
|
||||
a1, err := resolver.Determine(req)
|
||||
if err != nil {
|
||||
t.Fatalf("determine failed: %v", err)
|
||||
}
|
||||
if a1.DeepSeekToken != "seed-token" {
|
||||
t.Fatalf("expected initial token without forced refresh, got %q", a1.DeepSeekToken)
|
||||
}
|
||||
resolver.Release(a1)
|
||||
if got := atomic.LoadInt32(&loginCount); got != 0 {
|
||||
t.Fatalf("expected no login before refresh interval, got %d", got)
|
||||
}
|
||||
|
||||
resolver.mu.Lock()
|
||||
resolver.tokenRefreshedAt["acc@example.com"] = time.Now().Add(-7 * time.Hour)
|
||||
resolver.mu.Unlock()
|
||||
|
||||
a2, err := resolver.Determine(req)
|
||||
if err != nil {
|
||||
t.Fatalf("determine after interval failed: %v", err)
|
||||
}
|
||||
defer resolver.Release(a2)
|
||||
if a2.DeepSeekToken != "fresh-token-1" {
|
||||
t.Fatalf("expected refreshed token after interval, got %q", a2.DeepSeekToken)
|
||||
}
|
||||
if got := atomic.LoadInt32(&loginCount); got != 1 {
|
||||
t.Fatalf("expected exactly one forced refresh login, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
)
|
||||
import "strings"
|
||||
|
||||
func (a Account) Identifier() string {
|
||||
if strings.TrimSpace(a.Email) != "" {
|
||||
@@ -13,12 +9,5 @@ func (a Account) Identifier() string {
|
||||
if mobile := NormalizeMobileForStorage(a.Mobile); mobile != "" {
|
||||
return mobile
|
||||
}
|
||||
// Backward compatibility: old configs may contain token-only accounts.
|
||||
// Use a stable non-sensitive synthetic id so they can still join the pool.
|
||||
token := strings.TrimSpace(a.Token)
|
||||
if token == "" {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return "token:" + hex.EncodeToString(sum[:8])
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -47,6 +47,7 @@ func (c Config) MarshalJSON() ([]byte, error) {
|
||||
if strings.TrimSpace(c.Embeddings.Provider) != "" {
|
||||
m["embeddings"] = c.Embeddings
|
||||
}
|
||||
m["auto_delete"] = c.AutoDelete
|
||||
if c.VercelSyncHash != "" {
|
||||
m["_vercel_sync_hash"] = c.VercelSyncHash
|
||||
}
|
||||
@@ -108,6 +109,10 @@ func (c *Config) UnmarshalJSON(b []byte) error {
|
||||
if err := json.Unmarshal(v, &c.Embeddings); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "auto_delete":
|
||||
if err := json.Unmarshal(v, &c.AutoDelete); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "_vercel_sync_hash":
|
||||
if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
@@ -141,6 +146,7 @@ func (c Config) Clone() Config {
|
||||
Toolcall: c.Toolcall,
|
||||
Responses: c.Responses,
|
||||
Embeddings: c.Embeddings,
|
||||
AutoDelete: c.AutoDelete,
|
||||
VercelSyncHash: c.VercelSyncHash,
|
||||
VercelSyncTime: c.VercelSyncTime,
|
||||
AdditionalFields: map[string]any{},
|
||||
|
||||
@@ -12,17 +12,43 @@ type Config struct {
|
||||
Toolcall ToolcallConfig `json:"toolcall,omitempty"`
|
||||
Responses ResponsesConfig `json:"responses,omitempty"`
|
||||
Embeddings EmbeddingsConfig `json:"embeddings,omitempty"`
|
||||
AutoDelete AutoDeleteConfig `json:"auto_delete"`
|
||||
VercelSyncHash string `json:"_vercel_sync_hash,omitempty"`
|
||||
VercelSyncTime int64 `json:"_vercel_sync_time,omitempty"`
|
||||
AdditionalFields map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
Email string `json:"email,omitempty"`
|
||||
Mobile string `json:"mobile,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
TestStatus string `json:"test_status,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Mobile string `json:"mobile,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
}
|
||||
|
||||
func (c *Config) ClearAccountTokens() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
for i := range c.Accounts {
|
||||
c.Accounts[i].Token = ""
|
||||
}
|
||||
}
|
||||
|
||||
// DropInvalidAccounts removes accounts that cannot be addressed by admin APIs
|
||||
// (no email and no normalizable mobile). This prevents legacy token-only
|
||||
// records from becoming orphaned empty entries after token stripping.
|
||||
func (c *Config) DropInvalidAccounts() {
|
||||
if c == nil || len(c.Accounts) == 0 {
|
||||
return
|
||||
}
|
||||
kept := make([]Account, 0, len(c.Accounts))
|
||||
for _, acc := range c.Accounts {
|
||||
if acc.Identifier() == "" {
|
||||
continue
|
||||
}
|
||||
kept = append(kept, acc)
|
||||
}
|
||||
c.Accounts = kept
|
||||
}
|
||||
|
||||
type CompatConfig struct {
|
||||
@@ -53,3 +79,7 @@ type ResponsesConfig struct {
|
||||
type EmbeddingsConfig struct {
|
||||
Provider string `json:"provider,omitempty"`
|
||||
}
|
||||
|
||||
type AutoDeleteConfig struct {
|
||||
Sessions bool `json:"sessions"`
|
||||
}
|
||||
|
||||
@@ -2,25 +2,23 @@ package config
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAccountIdentifierFallsBackToTokenHash(t *testing.T) {
|
||||
func TestAccountIdentifierRequiresEmailOrMobile(t *testing.T) {
|
||||
acc := Account{Token: "example-token-value"}
|
||||
id := acc.Identifier()
|
||||
if !strings.HasPrefix(id, "token:") {
|
||||
t.Fatalf("expected token-prefixed identifier, got %q", id)
|
||||
}
|
||||
if len(id) != len("token:")+16 {
|
||||
t.Fatalf("unexpected identifier length: %d (%q)", len(id), id)
|
||||
if id != "" {
|
||||
t.Fatalf("expected empty identifier when only token is present, got %q", id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreFindAccountWithTokenOnlyIdentifier(t *testing.T) {
|
||||
func TestLoadStoreClearsTokensFromConfigInput(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["k1"],
|
||||
"accounts":[{"token":"token-only-account"}]
|
||||
"accounts":[{"email":"u@example.com","password":"p","token":"token-only-account"}]
|
||||
}`)
|
||||
|
||||
store := LoadStore()
|
||||
@@ -28,22 +26,62 @@ func TestStoreFindAccountWithTokenOnlyIdentifier(t *testing.T) {
|
||||
if len(accounts) != 1 {
|
||||
t.Fatalf("expected 1 account, got %d", len(accounts))
|
||||
}
|
||||
id := accounts[0].Identifier()
|
||||
if id == "" {
|
||||
t.Fatalf("expected synthetic identifier for token-only account")
|
||||
}
|
||||
found, ok := store.FindAccount(id)
|
||||
if !ok {
|
||||
t.Fatalf("expected FindAccount to locate token-only account by synthetic id")
|
||||
}
|
||||
if found.Token != "token-only-account" {
|
||||
t.Fatalf("unexpected token value: %q", found.Token)
|
||||
if accounts[0].Token != "" {
|
||||
t.Fatalf("expected token to be cleared after loading, got %q", accounts[0].Token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreUpdateAccountTokenKeepsOldAndNewIdentifierResolvable(t *testing.T) {
|
||||
func TestLoadStoreDropsLegacyTokenOnlyAccounts(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"accounts":[{"token":"old-token"}]
|
||||
"accounts":[
|
||||
{"token":"legacy-token-only"},
|
||||
{"email":"u@example.com","password":"p","token":"runtime-token"}
|
||||
]
|
||||
}`)
|
||||
|
||||
store := LoadStore()
|
||||
accounts := store.Accounts()
|
||||
if len(accounts) != 1 {
|
||||
t.Fatalf("expected token-only account to be dropped, got %d accounts", len(accounts))
|
||||
}
|
||||
if accounts[0].Identifier() != "u@example.com" {
|
||||
t.Fatalf("unexpected remaining account: %#v", accounts[0])
|
||||
}
|
||||
if accounts[0].Token != "" {
|
||||
t.Fatalf("expected persisted token to be cleared, got %q", accounts[0].Token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadStorePreservesFileBackedTokensForRuntime(t *testing.T) {
|
||||
tmp, err := os.CreateTemp(t.TempDir(), "config-*.json")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp config: %v", err)
|
||||
}
|
||||
defer tmp.Close()
|
||||
|
||||
if _, err := tmp.WriteString(`{
|
||||
"accounts":[{"email":"u@example.com","password":"p","token":"persisted-token"}]
|
||||
}`); err != nil {
|
||||
t.Fatalf("write temp config: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("DS2API_CONFIG_JSON", "")
|
||||
t.Setenv("CONFIG_JSON", "")
|
||||
t.Setenv("DS2API_CONFIG_PATH", tmp.Name())
|
||||
|
||||
store := LoadStore()
|
||||
accounts := store.Accounts()
|
||||
if len(accounts) != 1 {
|
||||
t.Fatalf("expected 1 account, got %d", len(accounts))
|
||||
}
|
||||
if accounts[0].Token != "persisted-token" {
|
||||
t.Fatalf("expected file-backed token preserved for runtime use, got %q", accounts[0].Token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreUpdateAccountTokenKeepsIdentifierResolvable(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"accounts":[{"email":"user@example.com","password":"p"}]
|
||||
}`)
|
||||
|
||||
store := LoadStore()
|
||||
@@ -52,23 +90,12 @@ func TestStoreUpdateAccountTokenKeepsOldAndNewIdentifierResolvable(t *testing.T)
|
||||
t.Fatalf("expected 1 account, got %d", len(before))
|
||||
}
|
||||
oldID := before[0].Identifier()
|
||||
if oldID == "" {
|
||||
t.Fatal("expected old identifier")
|
||||
}
|
||||
if err := store.UpdateAccountToken(oldID, "new-token"); err != nil {
|
||||
t.Fatalf("update token failed: %v", err)
|
||||
}
|
||||
|
||||
after := store.Accounts()
|
||||
newID := after[0].Identifier()
|
||||
if newID == "" || newID == oldID {
|
||||
t.Fatalf("expected changed identifier, old=%q new=%q", oldID, newID)
|
||||
}
|
||||
if got, ok := store.FindAccount(newID); !ok || got.Token != "new-token" {
|
||||
t.Fatalf("expected find by new identifier")
|
||||
}
|
||||
if got, ok := store.FindAccount(oldID); !ok || got.Token != "new-token" {
|
||||
t.Fatalf("expected find by old identifier alias")
|
||||
t.Fatalf("expected find by stable account identifier")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,3 +148,39 @@ func TestLoadConfigOnVercelWithoutConfigFileFallsBackToMemory(t *testing.T) {
|
||||
t.Fatalf("expected empty bootstrap config, got keys=%d accounts=%d", len(cfg.Keys), len(cfg.Accounts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountTestStatusIsRuntimeOnlyAndNotPersisted(t *testing.T) {
|
||||
tmp, err := os.CreateTemp(t.TempDir(), "config-*.json")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp config: %v", err)
|
||||
}
|
||||
defer tmp.Close()
|
||||
if _, err := tmp.WriteString(`{
|
||||
"accounts":[{"email":"u@example.com","password":"p","test_status":"ok"}]
|
||||
}`); err != nil {
|
||||
t.Fatalf("write temp config: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("DS2API_CONFIG_JSON", "")
|
||||
t.Setenv("CONFIG_JSON", "")
|
||||
t.Setenv("DS2API_CONFIG_PATH", tmp.Name())
|
||||
|
||||
store := LoadStore()
|
||||
if got, ok := store.AccountTestStatus("u@example.com"); ok || got != "" {
|
||||
t.Fatalf("expected no runtime status loaded from config, got %q", got)
|
||||
}
|
||||
if err := store.UpdateAccountTestStatus("u@example.com", "ok"); err != nil {
|
||||
t.Fatalf("update test status: %v", err)
|
||||
}
|
||||
if got, ok := store.AccountTestStatus("u@example.com"); !ok || got != "ok" {
|
||||
t.Fatalf("expected runtime status to be available, got %q (ok=%v)", got, ok)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(tmp.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("read config: %v", err)
|
||||
}
|
||||
if strings.Contains(string(content), "test_status") {
|
||||
t.Fatalf("expected test_status to stay out of persisted config, got: %s", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ type Store struct {
|
||||
fromEnv bool
|
||||
keyMap map[string]struct{} // O(1) API key lookup index
|
||||
accMap map[string]int // O(1) account lookup: identifier -> slice index
|
||||
accTest map[string]string // runtime-only account test status cache
|
||||
}
|
||||
|
||||
func LoadStore() *Store {
|
||||
@@ -39,6 +40,8 @@ func loadConfig() (Config, bool, error) {
|
||||
}
|
||||
if rawCfg != "" {
|
||||
cfg, err := parseConfigString(rawCfg)
|
||||
cfg.ClearAccountTokens()
|
||||
cfg.DropInvalidAccounts()
|
||||
return cfg, true, err
|
||||
}
|
||||
|
||||
@@ -55,6 +58,12 @@ func loadConfig() (Config, bool, error) {
|
||||
if err := json.Unmarshal(content, &cfg); err != nil {
|
||||
return Config{}, false, err
|
||||
}
|
||||
cfg.DropInvalidAccounts()
|
||||
if strings.Contains(string(content), `"test_status"`) && !IsVercel() {
|
||||
if b, err := json.MarshalIndent(cfg, "", " "); err == nil {
|
||||
_ = os.WriteFile(ConfigPath(), b, 0o644)
|
||||
}
|
||||
}
|
||||
if IsVercel() {
|
||||
// Vercel filesystem is ephemeral/read-only for runtime writes; avoid save errors.
|
||||
return cfg, true, nil
|
||||
@@ -105,8 +114,19 @@ func (s *Store) UpdateAccountTestStatus(identifier, status string) error {
|
||||
if !ok {
|
||||
return errors.New("account not found")
|
||||
}
|
||||
s.cfg.Accounts[idx].TestStatus = status
|
||||
return s.saveLocked()
|
||||
s.setAccountTestStatusLocked(s.cfg.Accounts[idx], status, identifier)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) AccountTestStatus(identifier string) (string, bool) {
|
||||
identifier = strings.TrimSpace(identifier)
|
||||
if identifier == "" {
|
||||
return "", false
|
||||
}
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
status, ok := s.accTest[identifier]
|
||||
return status, ok
|
||||
}
|
||||
|
||||
func (s *Store) UpdateAccountToken(identifier, token string) error {
|
||||
@@ -161,7 +181,9 @@ func (s *Store) Save() error {
|
||||
Logger.Info("[save_config] source from env, skip write")
|
||||
return nil
|
||||
}
|
||||
b, err := json.MarshalIndent(s.cfg, "", " ")
|
||||
persistCfg := s.cfg.Clone()
|
||||
persistCfg.ClearAccountTokens()
|
||||
b, err := json.MarshalIndent(persistCfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -173,7 +195,9 @@ func (s *Store) saveLocked() error {
|
||||
Logger.Info("[save_config] source from env, skip write")
|
||||
return nil
|
||||
}
|
||||
b, err := json.MarshalIndent(s.cfg, "", " ")
|
||||
persistCfg := s.cfg.Clone()
|
||||
persistCfg.ClearAccountTokens()
|
||||
b, err := json.MarshalIndent(persistCfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -197,7 +221,9 @@ func (s *Store) SetVercelSync(hash string, ts int64) error {
|
||||
func (s *Store) ExportJSONAndBase64() (string, string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
b, err := json.Marshal(s.cfg)
|
||||
exportCfg := s.cfg.Clone()
|
||||
exportCfg.ClearAccountTokens()
|
||||
b, err := json.Marshal(exportCfg)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
@@ -165,3 +165,9 @@ func (s *Store) RuntimeGlobalMaxInflight(defaultSize int) int {
|
||||
}
|
||||
return defaultSize
|
||||
}
|
||||
|
||||
func (s *Store) AutoDeleteSessions() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.cfg.AutoDelete.Sessions
|
||||
}
|
||||
|
||||
@@ -2,15 +2,20 @@ package config
|
||||
|
||||
// rebuildIndexes must be called with the lock already held (or during init).
|
||||
func (s *Store) rebuildIndexes() {
|
||||
prevStatus := s.accTest
|
||||
s.keyMap = make(map[string]struct{}, len(s.cfg.Keys))
|
||||
for _, k := range s.cfg.Keys {
|
||||
s.keyMap[k] = struct{}{}
|
||||
}
|
||||
s.accMap = make(map[string]int, len(s.cfg.Accounts))
|
||||
s.accTest = make(map[string]string, len(s.cfg.Accounts))
|
||||
for i, acc := range s.cfg.Accounts {
|
||||
id := acc.Identifier()
|
||||
if id != "" {
|
||||
s.accMap[id] = i
|
||||
if status, ok := prevStatus[id]; ok {
|
||||
s.setAccountTestStatusLocked(acc, status, "")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -29,3 +34,22 @@ func (s *Store) findAccountIndexLocked(identifier string) (int, bool) {
|
||||
}
|
||||
return -1, false
|
||||
}
|
||||
|
||||
func (s *Store) setAccountTestStatusLocked(acc Account, status, hintedIdentifier string) {
|
||||
status = lower(status)
|
||||
if status == "" {
|
||||
return
|
||||
}
|
||||
if id := acc.Identifier(); id != "" {
|
||||
s.accTest[id] = status
|
||||
}
|
||||
if email := acc.Email; email != "" {
|
||||
s.accTest[email] = status
|
||||
}
|
||||
if mobile := CanonicalMobileKey(acc.Mobile); mobile != "" {
|
||||
s.accTest[mobile] = status
|
||||
}
|
||||
if hintedIdentifier = lower(hintedIdentifier); hintedIdentifier != "" {
|
||||
s.accTest[hintedIdentifier] = status
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,8 +62,8 @@ func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAtte
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
code := intFrom(resp["code"])
|
||||
if status == http.StatusOK && code == 0 {
|
||||
code, bizCode, msg, bizMsg := extractResponseStatus(resp)
|
||||
if status == http.StatusOK && code == 0 && bizCode == 0 {
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
bizData, _ := data["biz_data"].(map[string]any)
|
||||
sessionID, _ := bizData["id"].(string)
|
||||
@@ -71,10 +71,9 @@ func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAtte
|
||||
return sessionID, nil
|
||||
}
|
||||
}
|
||||
msg, _ := resp["msg"].(string)
|
||||
config.Logger.Warn("[create_session] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
|
||||
config.Logger.Warn("[create_session] failed", "status", status, "code", code, "biz_code", bizCode, "msg", msg, "biz_msg", bizMsg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
|
||||
if a.UseConfigToken {
|
||||
if isTokenInvalid(status, code, msg) && !refreshed {
|
||||
if !refreshed && shouldAttemptRefresh(status, code, bizCode, msg, bizMsg) {
|
||||
if c.Auth.RefreshToken(ctx, a) {
|
||||
refreshed = true
|
||||
continue
|
||||
@@ -96,6 +95,7 @@ func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts in
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
attempts := 0
|
||||
refreshed := false
|
||||
for attempts < maxAttempts {
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"})
|
||||
@@ -104,8 +104,8 @@ func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts in
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
code := intFrom(resp["code"])
|
||||
if status == http.StatusOK && code == 0 {
|
||||
code, bizCode, msg, bizMsg := extractResponseStatus(resp)
|
||||
if status == http.StatusOK && code == 0 && bizCode == 0 {
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
bizData, _ := data["biz_data"].(map[string]any)
|
||||
challenge, _ := bizData["challenge"].(map[string]any)
|
||||
@@ -116,15 +116,16 @@ func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts in
|
||||
}
|
||||
return BuildPowHeader(challenge, answer)
|
||||
}
|
||||
msg, _ := resp["msg"].(string)
|
||||
config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
|
||||
config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "biz_code", bizCode, "msg", msg, "biz_msg", bizMsg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
|
||||
if a.UseConfigToken {
|
||||
if isTokenInvalid(status, code, msg) {
|
||||
if !refreshed && shouldAttemptRefresh(status, code, bizCode, msg, bizMsg) {
|
||||
if c.Auth.RefreshToken(ctx, a) {
|
||||
refreshed = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
if c.Auth.SwitchAccount(ctx, a) {
|
||||
refreshed = false
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
@@ -143,15 +144,75 @@ func (c *Client) authHeaders(token string) map[string]string {
|
||||
return headers
|
||||
}
|
||||
|
||||
func isTokenInvalid(status int, code int, msg string) bool {
|
||||
msg = strings.ToLower(msg)
|
||||
func isTokenInvalid(status int, code int, bizCode int, msg string, bizMsg string) bool {
|
||||
msg = strings.ToLower(strings.TrimSpace(msg) + " " + strings.TrimSpace(bizMsg))
|
||||
if status == http.StatusUnauthorized || status == http.StatusForbidden {
|
||||
return true
|
||||
}
|
||||
if code == 40001 || code == 40002 || code == 40003 {
|
||||
if code == 40001 || code == 40002 || code == 40003 || bizCode == 40001 || bizCode == 40002 || bizCode == 40003 {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(msg, "token") || strings.Contains(msg, "unauthorized")
|
||||
return strings.Contains(msg, "token") ||
|
||||
strings.Contains(msg, "unauthorized") ||
|
||||
strings.Contains(msg, "expired") ||
|
||||
strings.Contains(msg, "not login") ||
|
||||
strings.Contains(msg, "login required") ||
|
||||
strings.Contains(msg, "invalid jwt")
|
||||
}
|
||||
|
||||
func shouldAttemptRefresh(status int, code int, bizCode int, msg string, bizMsg string) bool {
|
||||
if isTokenInvalid(status, code, bizCode, msg, bizMsg) {
|
||||
return true
|
||||
}
|
||||
// Some DeepSeek failures come back as HTTP 200/code=0 but with non-zero biz_code.
|
||||
// Only attempt refresh when these biz failures still look auth-related.
|
||||
return status == http.StatusOK &&
|
||||
code == 0 &&
|
||||
bizCode != 0 &&
|
||||
isAuthIndicativeBizFailure(msg, bizMsg)
|
||||
}
|
||||
|
||||
func isAuthIndicativeBizFailure(msg string, bizMsg string) bool {
|
||||
combined := strings.ToLower(strings.TrimSpace(msg) + " " + strings.TrimSpace(bizMsg))
|
||||
authKeywords := []string{
|
||||
"auth",
|
||||
"authorization",
|
||||
"credential",
|
||||
"expired",
|
||||
"invalid jwt",
|
||||
"jwt",
|
||||
"login",
|
||||
"not login",
|
||||
"session expired",
|
||||
"token",
|
||||
"unauthorized",
|
||||
"登录",
|
||||
"未登录",
|
||||
"认证",
|
||||
"凭证",
|
||||
"会话过期",
|
||||
"令牌",
|
||||
}
|
||||
for _, keyword := range authKeywords {
|
||||
if strings.Contains(combined, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func extractResponseStatus(resp map[string]any) (code int, bizCode int, msg string, bizMsg string) {
|
||||
code = intFrom(resp["code"])
|
||||
msg, _ = resp["msg"].(string)
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
bizCode = intFrom(data["biz_code"])
|
||||
bizMsg, _ = data["biz_msg"].(string)
|
||||
if strings.TrimSpace(bizMsg) == "" {
|
||||
if bizData, ok := data["biz_data"].(map[string]any); ok {
|
||||
bizMsg, _ = bizData["msg"].(string)
|
||||
}
|
||||
}
|
||||
return code, bizCode, msg, bizMsg
|
||||
}
|
||||
|
||||
func normalizeMobileForLogin(raw string) (mobile string, areaCode any) {
|
||||
|
||||
27
internal/deepseek/client_auth_refresh_test.go
Normal file
27
internal/deepseek/client_auth_refresh_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package deepseek
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestShouldAttemptRefreshOnTokenInvalidSignal(t *testing.T) {
|
||||
if !shouldAttemptRefresh(401, 0, 0, "unauthorized", "") {
|
||||
t.Fatal("expected refresh when response indicates invalid token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldAttemptRefreshOnAuthIndicativeBizCodeFailure(t *testing.T) {
|
||||
if !shouldAttemptRefresh(200, 0, 400123, "", "login expired, token invalid") {
|
||||
t.Fatal("expected refresh on auth-indicative biz_code failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldAttemptRefreshFalseOnNonAuthBizCodeFailure(t *testing.T) {
|
||||
if shouldAttemptRefresh(200, 0, 400123, "", "session create failed: quota reached") {
|
||||
t.Fatal("did not expect refresh on non-auth biz_code failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldAttemptRefreshFalseOnGenericServerError(t *testing.T) {
|
||||
if shouldAttemptRefresh(500, 500, 0, "internal error", "") {
|
||||
t.Fatal("did not expect refresh on generic server error")
|
||||
}
|
||||
}
|
||||
@@ -62,3 +62,40 @@ func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url st
|
||||
}
|
||||
return out, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
func (c *Client) getJSONWithStatus(ctx context.Context, doer trans.Doer, url string, headers map[string]string) (map[string]any, int, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
resp, err := doer.Do(req)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[deepseek] fingerprint GET request failed, fallback to std transport", "url", url, "error", err)
|
||||
req2, reqErr := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if reqErr != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
for k, v := range headers {
|
||||
req2.Header.Set(k, v)
|
||||
}
|
||||
resp, err = c.fallback.Do(req2)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
payloadBytes, err := readResponseBody(resp)
|
||||
if err != nil {
|
||||
return nil, resp.StatusCode, err
|
||||
}
|
||||
out := map[string]any{}
|
||||
if len(payloadBytes) > 0 {
|
||||
if err := json.Unmarshal(payloadBytes, &out); err != nil {
|
||||
config.Logger.Warn("[deepseek] json parse failed", "url", url, "status", resp.StatusCode, "content_encoding", resp.Header.Get("Content-Encoding"), "preview", preview(payloadBytes))
|
||||
}
|
||||
}
|
||||
return out, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
256
internal/deepseek/client_session.go
Normal file
256
internal/deepseek/client_session.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
// SessionInfo 会话信息
|
||||
type SessionInfo struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
TitleType string `json:"title_type"`
|
||||
Pinned bool `json:"pinned"`
|
||||
UpdatedAt float64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
// SessionStats 会话统计结果
|
||||
type SessionStats struct {
|
||||
AccountID string // 账号标识 (email 或 mobile)
|
||||
FirstPageCount int // 第一页会话数量(当 HasMore 为 true 时,真实总数可能更大)
|
||||
PinnedCount int // 置顶会话数量
|
||||
HasMore bool // 是否还有更多页
|
||||
Success bool // 请求是否成功
|
||||
ErrorMessage string // 错误信息
|
||||
}
|
||||
|
||||
// GetSessionCount 获取单个账号的会话数量
|
||||
func (c *Client) GetSessionCount(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (*SessionStats, error) {
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
|
||||
stats := &SessionStats{
|
||||
AccountID: a.AccountID,
|
||||
}
|
||||
|
||||
attempts := 0
|
||||
refreshed := false
|
||||
|
||||
for attempts < maxAttempts {
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
|
||||
// 构建请求 URL
|
||||
reqURL := DeepSeekFetchSessionURL + "?lte_cursor.pinned=false"
|
||||
|
||||
resp, status, err := c.getJSONWithStatus(ctx, c.regular, reqURL, headers)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[get_session_count] request error", "error", err, "account", a.AccountID)
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
|
||||
code, bizCode, msg, bizMsg := extractResponseStatus(resp)
|
||||
if status == http.StatusOK && code == 0 && bizCode == 0 {
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
bizData, _ := data["biz_data"].(map[string]any)
|
||||
chatSessions, _ := bizData["chat_sessions"].([]any)
|
||||
hasMore, _ := bizData["has_more"].(bool)
|
||||
|
||||
stats.FirstPageCount = len(chatSessions)
|
||||
stats.HasMore = hasMore
|
||||
stats.Success = true
|
||||
|
||||
// 统计置顶会话数量
|
||||
for _, session := range chatSessions {
|
||||
if s, ok := session.(map[string]any); ok {
|
||||
if pinned, ok := s["pinned"].(bool); ok && pinned {
|
||||
stats.PinnedCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
stats.ErrorMessage = fmt.Sprintf("status=%d, code=%d, msg=%s", status, code, msg)
|
||||
config.Logger.Warn("[get_session_count] failed", "status", status, "code", code, "biz_code", bizCode, "msg", msg, "biz_msg", bizMsg, "account", a.AccountID)
|
||||
|
||||
if a.UseConfigToken {
|
||||
if isTokenInvalid(status, code, bizCode, msg, bizMsg) && !refreshed {
|
||||
if c.Auth.RefreshToken(ctx, a) {
|
||||
refreshed = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
if c.Auth.SwitchAccount(ctx, a) {
|
||||
refreshed = false
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
}
|
||||
attempts++
|
||||
}
|
||||
|
||||
stats.Success = false
|
||||
stats.ErrorMessage = "get session count failed after retries"
|
||||
return stats, errors.New(stats.ErrorMessage)
|
||||
}
|
||||
|
||||
// GetSessionCountForToken 直接使用 token 获取会话数量(直通模式)
|
||||
func (c *Client) GetSessionCountForToken(ctx context.Context, token string) (*SessionStats, error) {
|
||||
headers := c.authHeaders(token)
|
||||
reqURL := DeepSeekFetchSessionURL + "?lte_cursor.pinned=false"
|
||||
|
||||
resp, status, err := c.getJSONWithStatus(ctx, c.regular, reqURL, headers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
code, bizCode, msg, bizMsg := extractResponseStatus(resp)
|
||||
if status != http.StatusOK || code != 0 || bizCode != 0 {
|
||||
if strings.TrimSpace(bizMsg) != "" {
|
||||
msg = bizMsg
|
||||
}
|
||||
return nil, fmt.Errorf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
|
||||
}
|
||||
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
bizData, _ := data["biz_data"].(map[string]any)
|
||||
chatSessions, _ := bizData["chat_sessions"].([]any)
|
||||
hasMore, _ := bizData["has_more"].(bool)
|
||||
|
||||
stats := &SessionStats{
|
||||
FirstPageCount: len(chatSessions),
|
||||
HasMore: hasMore,
|
||||
Success: true,
|
||||
}
|
||||
|
||||
// 统计置顶会话数量
|
||||
for _, session := range chatSessions {
|
||||
if s, ok := session.(map[string]any); ok {
|
||||
if pinned, ok := s["pinned"].(bool); ok && pinned {
|
||||
stats.PinnedCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetSessionCountAll 获取所有账号的会话数量统计
|
||||
func (c *Client) GetSessionCountAll(ctx context.Context) []*SessionStats {
|
||||
accounts := c.Store.Accounts()
|
||||
results := make([]*SessionStats, 0, len(accounts))
|
||||
|
||||
for _, acc := range accounts {
|
||||
token := acc.Token
|
||||
accountID := acc.Email
|
||||
if accountID == "" {
|
||||
accountID = acc.Mobile
|
||||
}
|
||||
|
||||
// 如果没有 token,尝试登录获取
|
||||
if token == "" {
|
||||
var err error
|
||||
token, err = c.Login(ctx, acc)
|
||||
if err != nil {
|
||||
results = append(results, &SessionStats{
|
||||
AccountID: accountID,
|
||||
Success: false,
|
||||
ErrorMessage: fmt.Sprintf("login failed: %v", err),
|
||||
})
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := c.GetSessionCountForToken(ctx, token)
|
||||
if err != nil {
|
||||
results = append(results, &SessionStats{
|
||||
AccountID: accountID,
|
||||
Success: false,
|
||||
ErrorMessage: err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
stats.AccountID = accountID
|
||||
results = append(results, stats)
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// FetchSessionPage 获取会话列表(支持分页)
|
||||
func (c *Client) FetchSessionPage(ctx context.Context, a *auth.RequestAuth, cursor string) ([]SessionInfo, bool, error) {
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
|
||||
// 构建请求 URL
|
||||
params := url.Values{}
|
||||
params.Set("lte_cursor.pinned", "false")
|
||||
if cursor != "" {
|
||||
params.Set("lte_cursor", cursor)
|
||||
}
|
||||
reqURL := DeepSeekFetchSessionURL + "?" + params.Encode()
|
||||
|
||||
resp, status, err := c.getJSONWithStatus(ctx, c.regular, reqURL, headers)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
code := intFrom(resp["code"])
|
||||
if status != http.StatusOK || code != 0 {
|
||||
msg, _ := resp["msg"].(string)
|
||||
return nil, false, fmt.Errorf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
|
||||
}
|
||||
|
||||
data, _ := resp["data"].(map[string]any)
|
||||
bizData, _ := data["biz_data"].(map[string]any)
|
||||
chatSessions, _ := bizData["chat_sessions"].([]any)
|
||||
hasMore, _ := bizData["has_more"].(bool)
|
||||
|
||||
sessions := make([]SessionInfo, 0, len(chatSessions))
|
||||
for _, s := range chatSessions {
|
||||
if m, ok := s.(map[string]any); ok {
|
||||
session := SessionInfo{
|
||||
ID: stringFromMap(m, "id"),
|
||||
Title: stringFromMap(m, "title"),
|
||||
TitleType: stringFromMap(m, "title_type"),
|
||||
Pinned: boolFromMap(m, "pinned"),
|
||||
UpdatedAt: floatFromMap(m, "updated_at"),
|
||||
}
|
||||
sessions = append(sessions, session)
|
||||
}
|
||||
}
|
||||
|
||||
return sessions, hasMore, nil
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func stringFromMap(m map[string]any, key string) string {
|
||||
if v, ok := m[key].(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func boolFromMap(m map[string]any, key string) bool {
|
||||
if v, ok := m[key].(bool); ok {
|
||||
return v
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func floatFromMap(m map[string]any, key string) float64 {
|
||||
if v, ok := m[key].(float64); ok {
|
||||
return v
|
||||
}
|
||||
return 0
|
||||
}
|
||||
155
internal/deepseek/client_session_delete.go
Normal file
155
internal/deepseek/client_session_delete.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
// DeleteSessionResult 删除会话结果
|
||||
type DeleteSessionResult struct {
|
||||
SessionID string // 会话 ID
|
||||
Success bool // 是否成功
|
||||
ErrorMessage string // 错误信息
|
||||
}
|
||||
|
||||
// DeleteSession 删除单个会话
|
||||
func (c *Client) DeleteSession(ctx context.Context, a *auth.RequestAuth, sessionID string, maxAttempts int) (*DeleteSessionResult, error) {
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
|
||||
result := &DeleteSessionResult{
|
||||
SessionID: sessionID,
|
||||
}
|
||||
|
||||
if sessionID == "" {
|
||||
result.ErrorMessage = "session_id is required"
|
||||
return result, errors.New(result.ErrorMessage)
|
||||
}
|
||||
|
||||
attempts := 0
|
||||
refreshed := false
|
||||
|
||||
for attempts < maxAttempts {
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
}
|
||||
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteSessionURL, headers, payload)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[delete_session] request error", "error", err, "session_id", sessionID)
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
|
||||
code, bizCode, msg, bizMsg := extractResponseStatus(resp)
|
||||
if status == http.StatusOK && code == 0 && bizCode == 0 {
|
||||
result.Success = true
|
||||
return result, nil
|
||||
}
|
||||
|
||||
result.ErrorMessage = fmt.Sprintf("status=%d, code=%d, msg=%s", status, code, msg)
|
||||
config.Logger.Warn("[delete_session] failed", "status", status, "code", code, "biz_code", bizCode, "msg", msg, "biz_msg", bizMsg, "session_id", sessionID)
|
||||
|
||||
if a.UseConfigToken {
|
||||
if isTokenInvalid(status, code, bizCode, msg, bizMsg) && !refreshed {
|
||||
if c.Auth.RefreshToken(ctx, a) {
|
||||
refreshed = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
if c.Auth.SwitchAccount(ctx, a) {
|
||||
refreshed = false
|
||||
attempts++
|
||||
continue
|
||||
}
|
||||
}
|
||||
attempts++
|
||||
}
|
||||
|
||||
result.Success = false
|
||||
result.ErrorMessage = "delete session failed after retries"
|
||||
return result, errors.New(result.ErrorMessage)
|
||||
}
|
||||
|
||||
// DeleteSessionForToken 直接使用 token 删除会话(直通模式)
|
||||
func (c *Client) DeleteSessionForToken(ctx context.Context, token string, sessionID string) (*DeleteSessionResult, error) {
|
||||
result := &DeleteSessionResult{
|
||||
SessionID: sessionID,
|
||||
}
|
||||
|
||||
if sessionID == "" {
|
||||
result.ErrorMessage = "session_id is required"
|
||||
return result, errors.New(result.ErrorMessage)
|
||||
}
|
||||
|
||||
headers := c.authHeaders(token)
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
}
|
||||
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteSessionURL, headers, payload)
|
||||
if err != nil {
|
||||
result.ErrorMessage = err.Error()
|
||||
return result, err
|
||||
}
|
||||
|
||||
code := intFrom(resp["code"])
|
||||
if status != http.StatusOK || code != 0 {
|
||||
msg, _ := resp["msg"].(string)
|
||||
result.ErrorMessage = fmt.Sprintf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
|
||||
return result, errors.New(result.ErrorMessage)
|
||||
}
|
||||
|
||||
result.Success = true
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteAllSessions 删除所有会话(谨慎使用)
|
||||
func (c *Client) DeleteAllSessions(ctx context.Context, a *auth.RequestAuth) error {
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
payload := map[string]any{}
|
||||
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteAllSessionsURL, headers, payload)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[delete_all_sessions] request error", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
code := intFrom(resp["code"])
|
||||
if status != http.StatusOK || code != 0 {
|
||||
msg, _ := resp["msg"].(string)
|
||||
config.Logger.Warn("[delete_all_sessions] failed", "status", status, "code", code, "msg", msg)
|
||||
return fmt.Errorf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteAllSessionsForToken 直接使用 token 删除所有会话(直通模式)
|
||||
func (c *Client) DeleteAllSessionsForToken(ctx context.Context, token string) error {
|
||||
headers := c.authHeaders(token)
|
||||
payload := map[string]any{}
|
||||
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteAllSessionsURL, headers, payload)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[delete_all_sessions_for_token] request error", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
code := intFrom(resp["code"])
|
||||
if status != http.StatusOK || code != 0 {
|
||||
msg, _ := resp["msg"].(string)
|
||||
config.Logger.Warn("[delete_all_sessions_for_token] failed", "status", status, "code", code, "msg", msg)
|
||||
return fmt.Errorf("request failed: status=%d, code=%d, msg=%s", status, code, msg)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -11,6 +11,9 @@ const (
|
||||
DeepSeekCreateSessionURL = "https://chat.deepseek.com/api/v0/chat_session/create"
|
||||
DeepSeekCreatePowURL = "https://chat.deepseek.com/api/v0/chat/create_pow_challenge"
|
||||
DeepSeekCompletionURL = "https://chat.deepseek.com/api/v0/chat/completion"
|
||||
DeepSeekFetchSessionURL = "https://chat.deepseek.com/api/v0/chat_session/fetch_page"
|
||||
DeepSeekDeleteSessionURL = "https://chat.deepseek.com/api/v0/chat_session/delete"
|
||||
DeepSeekDeleteAllSessionsURL = "https://chat.deepseek.com/api/v0/chat_session/delete_all"
|
||||
)
|
||||
|
||||
var defaultBaseHeaders = map[string]string{
|
||||
|
||||
@@ -8,15 +8,15 @@ import (
|
||||
)
|
||||
|
||||
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
detected := util.ParseStandaloneToolCalls(finalText, toolNames)
|
||||
detected := util.ParseStandaloneToolCallsDetailed(finalText, toolNames)
|
||||
finishReason := "stop"
|
||||
messageObj := map[string]any{"role": "assistant", "content": finalText}
|
||||
if strings.TrimSpace(finalThinking) != "" {
|
||||
messageObj["reasoning_content"] = finalThinking
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
if len(detected.Calls) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected)
|
||||
messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected.Calls)
|
||||
messageObj["content"] = nil
|
||||
}
|
||||
|
||||
|
||||
@@ -13,12 +13,12 @@ import (
|
||||
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
// Strict mode: only standalone, structured tool-call payloads are treated
|
||||
// as executable tool calls.
|
||||
detected := util.ParseStandaloneToolCalls(finalText, toolNames)
|
||||
detected := util.ParseStandaloneToolCallsDetailed(finalText, toolNames)
|
||||
exposedOutputText := finalText
|
||||
output := make([]any, 0, 2)
|
||||
if len(detected) > 0 {
|
||||
if len(detected.Calls) > 0 {
|
||||
exposedOutputText = ""
|
||||
output = append(output, toResponsesFunctionCallItems(detected)...)
|
||||
output = append(output, toResponsesFunctionCallItems(detected.Calls)...)
|
||||
} else {
|
||||
content := make([]any, 0, 2)
|
||||
if finalThinking != "" {
|
||||
|
||||
@@ -2,6 +2,7 @@ package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -45,7 +46,7 @@ func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildResponseObjectTreatsMixedProseToolPayloadAsText(t *testing.T) {
|
||||
func TestBuildResponseObjectPromotesMixedProseToolPayloadToFunctionCall(t *testing.T) {
|
||||
obj := BuildResponseObject(
|
||||
"resp_test",
|
||||
"gpt-4o",
|
||||
@@ -56,8 +57,32 @@ func TestBuildResponseObjectTreatsMixedProseToolPayloadAsText(t *testing.T) {
|
||||
)
|
||||
|
||||
outputText, _ := obj["output_text"].(string)
|
||||
if outputText == "" {
|
||||
t.Fatalf("expected output_text preserved for mixed prose payload")
|
||||
if outputText != "" {
|
||||
t.Fatalf("expected output_text hidden for mixed prose tool payload, got %q", outputText)
|
||||
}
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected one function_call output item, got %#v", obj["output"])
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "function_call" {
|
||||
t.Fatalf("expected function_call output type, got %#v", first["type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildResponseObjectKeepsFencedToolPayloadAsText(t *testing.T) {
|
||||
obj := BuildResponseObject(
|
||||
"resp_test",
|
||||
"gpt-4o",
|
||||
"prompt",
|
||||
"",
|
||||
"```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"golang\"}}]}\n```",
|
||||
[]string{"search"},
|
||||
)
|
||||
|
||||
outputText, _ := obj["output_text"].(string)
|
||||
if !strings.Contains(outputText, "\"tool_calls\"") {
|
||||
t.Fatalf("expected output_text to preserve fenced tool payload, got %q", outputText)
|
||||
}
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) != 1 {
|
||||
@@ -69,28 +94,9 @@ func TestBuildResponseObjectTreatsMixedProseToolPayloadAsText(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildResponseObjectFencedToolPayloadRemainsText(t *testing.T) {
|
||||
obj := BuildResponseObject(
|
||||
"resp_test",
|
||||
"gpt-4o",
|
||||
"prompt",
|
||||
"",
|
||||
"```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"golang\"}}]}\n```",
|
||||
[]string{"search"},
|
||||
)
|
||||
|
||||
outputText, _ := obj["output_text"].(string)
|
||||
if outputText == "" {
|
||||
t.Fatalf("expected output_text preserved for fenced example")
|
||||
}
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected one message output item, got %#v", obj["output"])
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "message" {
|
||||
t.Fatalf("expected message output type, got %#v", first["type"])
|
||||
}
|
||||
// Backward-compatible alias for historical test name used in CI logs.
|
||||
func TestBuildResponseObjectPromotesFencedToolPayloadToFunctionCall(t *testing.T) {
|
||||
TestBuildResponseObjectKeepsFencedToolPayloadAsText(t)
|
||||
}
|
||||
|
||||
func TestBuildResponseObjectReasoningOnlyFallsBackToOutputText(t *testing.T) {
|
||||
|
||||
@@ -10,10 +10,10 @@ function resolveToolcallPolicy(prepBody, payloadTools) {
|
||||
const preparedToolNames = normalizePreparedToolNames(prepBody && prepBody.tool_names);
|
||||
const toolNames = preparedToolNames.length > 0 ? preparedToolNames : extractToolNames(payloadTools);
|
||||
const featureMatchEnabled = boolDefaultTrue(prepBody && prepBody.toolcall_feature_match);
|
||||
const emitEarlyToolDeltas = boolDefaultTrue(prepBody && prepBody.toolcall_early_emit_high);
|
||||
const emitEarlyToolDeltas = featureMatchEnabled && boolDefaultTrue(prepBody && prepBody.toolcall_early_emit_high);
|
||||
return {
|
||||
toolNames,
|
||||
toolSieveEnabled: toolNames.length > 0 && featureMatchEnabled,
|
||||
toolSieveEnabled: toolNames.length > 0,
|
||||
emitEarlyToolDeltas,
|
||||
};
|
||||
}
|
||||
@@ -60,6 +60,9 @@ function formatIncrementalToolCallDeltas(deltas, idStore) {
|
||||
if (typeof d.arguments === 'string' && d.arguments !== '') {
|
||||
fn.arguments = d.arguments;
|
||||
}
|
||||
if (Object.keys(fn).length === 0) {
|
||||
continue;
|
||||
}
|
||||
if (Object.keys(fn).length > 0) {
|
||||
item.function = fn;
|
||||
}
|
||||
|
||||
@@ -1,33 +1,22 @@
|
||||
'use strict';
|
||||
|
||||
const {
|
||||
extractToolNames,
|
||||
createToolSieveState,
|
||||
processToolSieveChunk,
|
||||
flushToolSieve,
|
||||
parseStandaloneToolCalls,
|
||||
formatOpenAIStreamToolCalls,
|
||||
} = require('../helpers/stream-tool-sieve');
|
||||
const {
|
||||
BASE_HEADERS,
|
||||
} = require('../shared/deepseek-constants');
|
||||
|
||||
const {
|
||||
writeOpenAIError,
|
||||
} = require('./error_shape');
|
||||
const {
|
||||
parseChunkForContent,
|
||||
isCitation,
|
||||
} = require('./sse_parse');
|
||||
const {
|
||||
buildUsage,
|
||||
} = require('./token_usage');
|
||||
const { BASE_HEADERS } = require('../shared/deepseek-constants');
|
||||
const { writeOpenAIError } = require('./error_shape');
|
||||
const { parseChunkForContent, isCitation } = require('./sse_parse');
|
||||
const { buildUsage } = require('./token_usage');
|
||||
const {
|
||||
resolveToolcallPolicy,
|
||||
formatIncrementalToolCallDeltas,
|
||||
filterIncrementalToolCallDeltasByAllowed,
|
||||
} = require('./toolcall_policy');
|
||||
const {
|
||||
createChatCompletionEmitter,
|
||||
} = require('./stream_emitter');
|
||||
const { createChatCompletionEmitter } = require('./stream_emitter');
|
||||
const {
|
||||
asString,
|
||||
isAbortError,
|
||||
@@ -57,6 +46,7 @@ async function handleVercelStream(req, res, rawBody, payload) {
|
||||
const searchEnabled = toBool(prep.body.search_enabled);
|
||||
const toolPolicy = resolveToolcallPolicy(prep.body, payload.tools);
|
||||
const toolNames = toolPolicy.toolNames;
|
||||
const emitEarlyToolDeltas = toolPolicy.emitEarlyToolDeltas;
|
||||
|
||||
if (!model || !leaseID || !deepseekToken || !powHeader || !completionPayload) {
|
||||
writeOpenAIError(res, 500, 'invalid vercel prepare response');
|
||||
@@ -132,6 +122,7 @@ async function handleVercelStream(req, res, rawBody, payload) {
|
||||
const toolSieveState = createToolSieveState();
|
||||
let toolCallsEmitted = false;
|
||||
const streamToolCallIDs = new Map();
|
||||
const streamToolNames = new Map();
|
||||
const decoder = new TextDecoder();
|
||||
reader = completionRes.body.getReader();
|
||||
let buffered = '';
|
||||
@@ -255,6 +246,18 @@ async function handleVercelStream(req, res, rawBody, payload) {
|
||||
}
|
||||
const events = processToolSieveChunk(toolSieveState, p.text, toolNames);
|
||||
for (const evt of events) {
|
||||
if (evt.type === 'tool_call_deltas') {
|
||||
if (!emitEarlyToolDeltas) {
|
||||
continue;
|
||||
}
|
||||
const filtered = filterIncrementalToolCallDeltasByAllowed(evt.deltas, toolNames, streamToolNames);
|
||||
const formatted = formatIncrementalToolCallDeltas(filtered, streamToolCallIDs);
|
||||
if (formatted.length > 0) {
|
||||
toolCallsEmitted = true;
|
||||
sendDeltaFrame({ tool_calls: formatted });
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (evt.type === 'tool_calls') {
|
||||
toolCallsEmitted = true;
|
||||
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls, streamToolCallIDs) });
|
||||
|
||||
@@ -2,32 +2,36 @@
|
||||
|
||||
const {
|
||||
toStringSafe,
|
||||
looksLikeToolExampleContext,
|
||||
} = require('./state');
|
||||
const {
|
||||
stripFencedCodeBlocks,
|
||||
buildToolCallCandidates,
|
||||
parseToolCallsPayload,
|
||||
parseMarkupToolCalls,
|
||||
parseTextKVToolCalls,
|
||||
stripFencedCodeBlocks,
|
||||
} = require('./parse_payload');
|
||||
const { TOOL_SEGMENT_KEYWORDS } = require('./tool-keywords');
|
||||
|
||||
const TOOL_NAME_LOOSE_PATTERN = /[^a-z0-9]+/g;
|
||||
const TOOL_MARKUP_PREFIXES = ['<tool_call', '<function_call', '<invoke'];
|
||||
|
||||
function extractToolNames(tools) {
|
||||
if (!Array.isArray(tools) || tools.length === 0) {
|
||||
return [];
|
||||
}
|
||||
const out = [];
|
||||
const seen = new Set();
|
||||
for (const t of tools) {
|
||||
if (!t || typeof t !== 'object') {
|
||||
continue;
|
||||
}
|
||||
const fn = t.function && typeof t.function === 'object' ? t.function : t;
|
||||
const name = toStringSafe(fn.name);
|
||||
// Keep parity with Go injectToolPrompt: object tools without name still
|
||||
// enter tool mode via fallback name "unknown".
|
||||
out.push(name || 'unknown');
|
||||
if (!name || seen.has(name)) {
|
||||
continue;
|
||||
}
|
||||
seen.add(name);
|
||||
out.push(name);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@@ -38,16 +42,16 @@ function parseToolCalls(text, toolNames) {
|
||||
|
||||
function parseToolCallsDetailed(text, toolNames) {
|
||||
const result = emptyParseResult();
|
||||
if (!toStringSafe(text)) {
|
||||
const normalized = toStringSafe(text);
|
||||
if (!normalized) {
|
||||
return result;
|
||||
}
|
||||
const sanitized = stripFencedCodeBlocks(text);
|
||||
if (!toStringSafe(sanitized)) {
|
||||
result.sawToolCallSyntax = looksLikeToolCallSyntax(normalized);
|
||||
if (shouldSkipToolCallParsingForCodeFenceExample(normalized)) {
|
||||
return result;
|
||||
}
|
||||
result.sawToolCallSyntax = looksLikeToolCallSyntax(sanitized);
|
||||
|
||||
const candidates = buildToolCallCandidates(sanitized);
|
||||
const candidates = buildToolCallCandidates(normalized);
|
||||
let parsed = [];
|
||||
for (const c of candidates) {
|
||||
parsed = parseToolCallsPayload(c);
|
||||
@@ -63,9 +67,9 @@ function parseToolCallsDetailed(text, toolNames) {
|
||||
}
|
||||
}
|
||||
if (parsed.length === 0) {
|
||||
parsed = parseMarkupToolCalls(sanitized);
|
||||
parsed = parseMarkupToolCalls(normalized);
|
||||
if (parsed.length === 0) {
|
||||
parsed = parseTextKVToolCalls(sanitized);
|
||||
parsed = parseTextKVToolCalls(normalized);
|
||||
if (parsed.length === 0) {
|
||||
return result;
|
||||
}
|
||||
@@ -90,22 +94,32 @@ function parseStandaloneToolCallsDetailed(text, toolNames) {
|
||||
if (!trimmed) {
|
||||
return result;
|
||||
}
|
||||
if (trimmed.includes('```')) {
|
||||
return result;
|
||||
}
|
||||
if (looksLikeToolExampleContext(trimmed)) {
|
||||
return result;
|
||||
}
|
||||
result.sawToolCallSyntax = looksLikeToolCallSyntax(trimmed);
|
||||
let parsed = parseToolCallsPayload(trimmed);
|
||||
if (shouldSkipToolCallParsingForCodeFenceExample(trimmed)) {
|
||||
return result;
|
||||
}
|
||||
const candidates = buildToolCallCandidates(trimmed);
|
||||
let parsed = [];
|
||||
for (const c of candidates) {
|
||||
parsed = parseToolCallsPayload(c);
|
||||
if (parsed.length === 0) {
|
||||
parsed = parseMarkupToolCalls(c);
|
||||
}
|
||||
if (parsed.length === 0) {
|
||||
parsed = parseTextKVToolCalls(c);
|
||||
}
|
||||
if (parsed.length > 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (parsed.length === 0) {
|
||||
parsed = parseMarkupToolCalls(trimmed);
|
||||
}
|
||||
if (parsed.length === 0) {
|
||||
parsed = parseTextKVToolCalls(trimmed);
|
||||
}
|
||||
if (parsed.length === 0) {
|
||||
return result;
|
||||
if (parsed.length === 0) {
|
||||
parsed = parseTextKVToolCalls(trimmed);
|
||||
if (parsed.length === 0) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result.sawToolCallSyntax = true;
|
||||
@@ -218,11 +232,16 @@ function resolveAllowedToolName(name, allowed, allowedCanonical) {
|
||||
|
||||
function looksLikeToolCallSyntax(text) {
|
||||
const lower = toStringSafe(text).toLowerCase();
|
||||
return lower.includes('tool_calls')
|
||||
|| lower.includes('<tool_call')
|
||||
|| lower.includes('<function_call')
|
||||
|| lower.includes('<invoke')
|
||||
|| lower.includes('function.name:');
|
||||
return TOOL_SEGMENT_KEYWORDS.some((kw) => lower.includes(kw))
|
||||
|| TOOL_MARKUP_PREFIXES.some((prefix) => lower.includes(prefix));
|
||||
}
|
||||
|
||||
function shouldSkipToolCallParsingForCodeFenceExample(text) {
|
||||
if (!looksLikeToolCallSyntax(text)) {
|
||||
return false;
|
||||
}
|
||||
const stripped = stripFencedCodeBlocks(text);
|
||||
return !looksLikeToolCallSyntax(stripped);
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
|
||||
@@ -114,6 +114,9 @@ function parseToolCallsPayload(payload) {
|
||||
return [];
|
||||
}
|
||||
if (decoded.tool_calls) {
|
||||
if (isLikelyChatMessageEnvelope(decoded)) {
|
||||
return [];
|
||||
}
|
||||
return parseToolCallList(decoded.tool_calls);
|
||||
}
|
||||
|
||||
@@ -121,6 +124,21 @@ function parseToolCallsPayload(payload) {
|
||||
return one ? [one] : [];
|
||||
}
|
||||
|
||||
function isLikelyChatMessageEnvelope(value) {
|
||||
if (!value || typeof value !== 'object' || Array.isArray(value)) {
|
||||
return false;
|
||||
}
|
||||
if (!Object.prototype.hasOwnProperty.call(value, 'tool_calls')) {
|
||||
return false;
|
||||
}
|
||||
const role = toStringSafe(value.role).trim().toLowerCase();
|
||||
if (role === 'assistant' || role === 'tool' || role === 'user' || role === 'system') {
|
||||
return true;
|
||||
}
|
||||
return Object.prototype.hasOwnProperty.call(value, 'tool_call_id')
|
||||
|| Object.prototype.hasOwnProperty.call(value, 'content');
|
||||
}
|
||||
|
||||
function parseMarkupToolCalls(text) {
|
||||
const raw = toStringSafe(text).trim();
|
||||
if (!raw) {
|
||||
|
||||
@@ -1,17 +1,12 @@
|
||||
'use strict';
|
||||
|
||||
const {
|
||||
resetIncrementalToolState,
|
||||
noteText,
|
||||
insideCodeFence,
|
||||
insideCodeFenceWithState,
|
||||
} = require('./state');
|
||||
const {
|
||||
parseStandaloneToolCallsDetailed,
|
||||
} = require('./parse');
|
||||
const {
|
||||
extractJSONObjectFrom,
|
||||
} = require('./jsonscan');
|
||||
|
||||
const { parseStandaloneToolCallsDetailed } = require('./parse');
|
||||
const { extractJSONObjectFrom } = require('./jsonscan');
|
||||
const { TOOL_SEGMENT_KEYWORDS, earliestKeywordIndex } = require('./tool-keywords');
|
||||
function processToolSieveChunk(state, chunk, toolNames) {
|
||||
if (!state) {
|
||||
return [];
|
||||
@@ -20,8 +15,6 @@ function processToolSieveChunk(state, chunk, toolNames) {
|
||||
state.pending += chunk;
|
||||
}
|
||||
const events = [];
|
||||
|
||||
// eslint-disable-next-line no-constant-condition
|
||||
while (true) {
|
||||
if (Array.isArray(state.pendingToolCalls) && state.pendingToolCalls.length > 0) {
|
||||
events.push({ type: 'tool_calls', calls: state.pendingToolCalls });
|
||||
@@ -46,6 +39,9 @@ function processToolSieveChunk(state, chunk, toolNames) {
|
||||
if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
|
||||
state.pendingToolRaw = captured;
|
||||
state.pendingToolCalls = consumed.calls;
|
||||
if (consumed.suffix) {
|
||||
state.pending = consumed.suffix + state.pending;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (consumed.prefix) {
|
||||
@@ -57,13 +53,11 @@ function processToolSieveChunk(state, chunk, toolNames) {
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const pending = state.pending || '';
|
||||
if (!pending) {
|
||||
break;
|
||||
}
|
||||
|
||||
const start = findToolSegmentStart(pending);
|
||||
const start = findToolSegmentStart(state, pending);
|
||||
if (start >= 0) {
|
||||
const prefix = pending.slice(0, start);
|
||||
if (prefix) {
|
||||
@@ -76,7 +70,6 @@ function processToolSieveChunk(state, chunk, toolNames) {
|
||||
resetIncrementalToolState(state);
|
||||
continue;
|
||||
}
|
||||
|
||||
const [safe, hold] = splitSafeContentForToolDetection(pending);
|
||||
if (!safe) {
|
||||
break;
|
||||
@@ -93,13 +86,11 @@ function flushToolSieve(state, toolNames) {
|
||||
return [];
|
||||
}
|
||||
const events = processToolSieveChunk(state, '', toolNames);
|
||||
|
||||
if (Array.isArray(state.pendingToolCalls) && state.pendingToolCalls.length > 0) {
|
||||
events.push({ type: 'tool_calls', calls: state.pendingToolCalls });
|
||||
state.pendingToolRaw = '';
|
||||
state.pendingToolCalls = [];
|
||||
}
|
||||
|
||||
if (state.capturing) {
|
||||
const consumed = consumeToolCapture(state, toolNames);
|
||||
if (consumed.ready) {
|
||||
@@ -122,13 +113,11 @@ function flushToolSieve(state, toolNames) {
|
||||
state.capturing = false;
|
||||
resetIncrementalToolState(state);
|
||||
}
|
||||
|
||||
if (state.pending) {
|
||||
noteText(state, state.pending);
|
||||
events.push({ type: 'text', text: state.pending });
|
||||
state.pending = '';
|
||||
}
|
||||
|
||||
return events;
|
||||
}
|
||||
|
||||
@@ -144,8 +133,6 @@ function splitSafeContentForToolDetection(s) {
|
||||
if (suspiciousStart > 0) {
|
||||
return [text.slice(0, suspiciousStart), text.slice(suspiciousStart)];
|
||||
}
|
||||
// If suspicious content starts at the beginning, keep holding until we can
|
||||
// either parse a full tool JSON block or reach stream flush.
|
||||
return ['', text];
|
||||
}
|
||||
|
||||
@@ -160,24 +147,24 @@ function findSuspiciousPrefixStart(s) {
|
||||
return start;
|
||||
}
|
||||
|
||||
function findToolSegmentStart(s) {
|
||||
function findToolSegmentStart(state, s) {
|
||||
if (!s) {
|
||||
return -1;
|
||||
}
|
||||
const lower = s.toLowerCase();
|
||||
let offset = 0;
|
||||
// eslint-disable-next-line no-constant-condition
|
||||
while (true) {
|
||||
const keyIdx = lower.indexOf('tool_calls', offset);
|
||||
if (keyIdx < 0) {
|
||||
const { index: bestKeyIdx, keyword: matchedKeyword } = earliestKeywordIndex(lower, TOOL_SEGMENT_KEYWORDS, offset);
|
||||
if (bestKeyIdx < 0) {
|
||||
return -1;
|
||||
}
|
||||
const keyIdx = bestKeyIdx;
|
||||
const start = s.slice(0, keyIdx).lastIndexOf('{');
|
||||
const candidateStart = start >= 0 ? start : keyIdx;
|
||||
if (!insideCodeFence(s.slice(0, candidateStart))) {
|
||||
if (!insideCodeFenceWithState(state, s.slice(0, candidateStart))) {
|
||||
return candidateStart;
|
||||
}
|
||||
offset = keyIdx + 'tool_calls'.length;
|
||||
offset = keyIdx + matchedKeyword.length;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,23 +174,30 @@ function consumeToolCapture(state, toolNames) {
|
||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||
}
|
||||
const lower = captured.toLowerCase();
|
||||
const keyIdx = lower.indexOf('tool_calls');
|
||||
const { index: keyIdx } = earliestKeywordIndex(lower, TOOL_SEGMENT_KEYWORDS);
|
||||
if (keyIdx < 0) {
|
||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||
}
|
||||
const start = captured.slice(0, keyIdx).lastIndexOf('{');
|
||||
const actualStart = start >= 0 ? start : keyIdx;
|
||||
if (start < 0) {
|
||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||
const history = extractToolHistoryBlock(captured, keyIdx);
|
||||
if (history.ok) {
|
||||
return {
|
||||
ready: true,
|
||||
prefix: captured.slice(0, history.start),
|
||||
calls: [],
|
||||
suffix: captured.slice(history.end),
|
||||
};
|
||||
}
|
||||
}
|
||||
const obj = extractJSONObjectFrom(captured, start);
|
||||
const obj = extractJSONObjectFrom(captured, actualStart);
|
||||
if (!obj.ok) {
|
||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||
}
|
||||
|
||||
const prefixPart = captured.slice(0, start);
|
||||
const prefixPart = captured.slice(0, actualStart);
|
||||
const suffixPart = captured.slice(obj.end);
|
||||
|
||||
if (insideCodeFence((state.recentTextTail || '') + prefixPart)) {
|
||||
if (insideCodeFenceWithState(state, prefixPart)) {
|
||||
return {
|
||||
ready: true,
|
||||
prefix: captured,
|
||||
@@ -211,17 +205,7 @@ function consumeToolCapture(state, toolNames) {
|
||||
suffix: '',
|
||||
};
|
||||
}
|
||||
|
||||
if ((state.recentTextTail || '').trim() !== '' || prefixPart.trim() !== '' || suffixPart.trim() !== '') {
|
||||
return {
|
||||
ready: true,
|
||||
prefix: captured,
|
||||
calls: [],
|
||||
suffix: '',
|
||||
};
|
||||
}
|
||||
|
||||
const parsed = parseStandaloneToolCallsDetailed(captured.slice(start, obj.end), toolNames);
|
||||
const parsed = parseStandaloneToolCallsDetailed(captured.slice(actualStart, obj.end), toolNames);
|
||||
if (!Array.isArray(parsed.calls) || parsed.calls.length === 0) {
|
||||
if (parsed.sawToolCallSyntax && parsed.rejectedByPolicy) {
|
||||
return {
|
||||
@@ -238,15 +222,61 @@ function consumeToolCapture(state, toolNames) {
|
||||
suffix: '',
|
||||
};
|
||||
}
|
||||
|
||||
const trimmedFence = trimWrappingJSONFence(prefixPart, suffixPart);
|
||||
return {
|
||||
ready: true,
|
||||
prefix: prefixPart,
|
||||
prefix: trimmedFence.prefix,
|
||||
calls: parsed.calls,
|
||||
suffix: suffixPart,
|
||||
suffix: trimmedFence.suffix,
|
||||
};
|
||||
}
|
||||
|
||||
function extractToolHistoryBlock(captured, keyIdx) {
|
||||
if (typeof captured !== 'string' || keyIdx < 0 || keyIdx >= captured.length) {
|
||||
return { ok: false, start: 0, end: 0 };
|
||||
}
|
||||
const rest = captured.slice(keyIdx).toLowerCase();
|
||||
if (rest.startsWith('[tool_call_history]')) {
|
||||
const closeTag = '[/tool_call_history]';
|
||||
const closeIdx = rest.indexOf(closeTag);
|
||||
if (closeIdx < 0) {
|
||||
return { ok: false, start: 0, end: 0 };
|
||||
}
|
||||
return { ok: true, start: keyIdx, end: keyIdx + closeIdx + closeTag.length };
|
||||
}
|
||||
if (rest.startsWith('[tool_result_history]')) {
|
||||
const closeTag = '[/tool_result_history]';
|
||||
const closeIdx = rest.indexOf(closeTag);
|
||||
if (closeIdx < 0) {
|
||||
return { ok: false, start: 0, end: 0 };
|
||||
}
|
||||
return { ok: true, start: keyIdx, end: keyIdx + closeIdx + closeTag.length };
|
||||
}
|
||||
return { ok: false, start: 0, end: 0 };
|
||||
}
|
||||
|
||||
function trimWrappingJSONFence(prefix, suffix) {
|
||||
const rightTrimmedPrefix = (prefix || '').replace(/[ \t\r\n]+$/g, '');
|
||||
const fenceIdx = rightTrimmedPrefix.lastIndexOf('```');
|
||||
if (fenceIdx < 0) return { prefix, suffix };
|
||||
const fenceCount = (rightTrimmedPrefix.slice(0, fenceIdx + 3).match(/```/g) || []).length;
|
||||
if (fenceCount % 2 === 0) {
|
||||
return { prefix, suffix };
|
||||
}
|
||||
const header = rightTrimmedPrefix.slice(fenceIdx + 3).trim().toLowerCase();
|
||||
if (header && header !== 'json') {
|
||||
return { prefix, suffix };
|
||||
}
|
||||
const leftTrimmedSuffix = (suffix || '').replace(/^[ \t\r\n]+/g, '');
|
||||
if (!leftTrimmedSuffix.startsWith('```')) {
|
||||
return { prefix, suffix };
|
||||
}
|
||||
const consumed = (suffix || '').length - leftTrimmedSuffix.length;
|
||||
return {
|
||||
prefix: rightTrimmedPrefix.slice(0, fenceIdx),
|
||||
suffix: (suffix || '').slice(consumed + 3),
|
||||
};
|
||||
}
|
||||
module.exports = {
|
||||
processToolSieveChunk,
|
||||
flushToolSieve,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
'use strict';
|
||||
|
||||
const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 256;
|
||||
const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 4096;
|
||||
|
||||
function createToolSieveState() {
|
||||
return {
|
||||
@@ -8,6 +8,9 @@ function createToolSieveState() {
|
||||
capture: '',
|
||||
capturing: false,
|
||||
recentTextTail: '',
|
||||
codeFenceStack: [],
|
||||
codeFencePendingTicks: 0,
|
||||
codeFenceLineStart: true,
|
||||
pendingToolRaw: '',
|
||||
pendingToolCalls: [],
|
||||
disableDeltas: false,
|
||||
@@ -34,6 +37,7 @@ function noteText(state, text) {
|
||||
if (!state || !hasMeaningfulText(text)) {
|
||||
return;
|
||||
}
|
||||
updateCodeFenceState(state, text);
|
||||
state.recentTextTail = appendTail(state.recentTextTail, text, TOOL_SIEVE_CONTEXT_TAIL_LIMIT);
|
||||
}
|
||||
|
||||
@@ -63,6 +67,91 @@ function insideCodeFence(text) {
|
||||
return ticks % 2 === 1;
|
||||
}
|
||||
|
||||
function insideCodeFenceWithState(state, text) {
|
||||
if (!state) {
|
||||
return insideCodeFence(text);
|
||||
}
|
||||
const simulated = simulateCodeFenceState(
|
||||
Array.isArray(state.codeFenceStack) ? state.codeFenceStack : [],
|
||||
Number.isInteger(state.codeFencePendingTicks) ? state.codeFencePendingTicks : 0,
|
||||
state.codeFenceLineStart !== false,
|
||||
text,
|
||||
);
|
||||
return simulated.stack.length > 0;
|
||||
}
|
||||
|
||||
function updateCodeFenceState(state, text) {
|
||||
if (!state) {
|
||||
return;
|
||||
}
|
||||
const next = simulateCodeFenceState(
|
||||
Array.isArray(state.codeFenceStack) ? state.codeFenceStack : [],
|
||||
Number.isInteger(state.codeFencePendingTicks) ? state.codeFencePendingTicks : 0,
|
||||
state.codeFenceLineStart !== false,
|
||||
text,
|
||||
);
|
||||
state.codeFenceStack = next.stack;
|
||||
state.codeFencePendingTicks = next.pendingTicks;
|
||||
state.codeFenceLineStart = next.lineStart;
|
||||
}
|
||||
|
||||
function simulateCodeFenceState(stack, pendingTicks, lineStart, text) {
|
||||
const chunk = typeof text === 'string' ? text : '';
|
||||
const nextStack = Array.isArray(stack) ? [...stack] : [];
|
||||
let ticks = Number.isInteger(pendingTicks) ? pendingTicks : 0;
|
||||
let atLineStart = lineStart !== false;
|
||||
|
||||
const flushTicks = () => {
|
||||
if (ticks > 0) {
|
||||
if (atLineStart && ticks >= 3) {
|
||||
applyFenceMarker(nextStack, ticks);
|
||||
}
|
||||
atLineStart = false;
|
||||
ticks = 0;
|
||||
}
|
||||
};
|
||||
|
||||
for (let i = 0; i < chunk.length; i += 1) {
|
||||
const ch = chunk[i];
|
||||
if (ch === '`') {
|
||||
ticks += 1;
|
||||
continue;
|
||||
}
|
||||
flushTicks();
|
||||
if (ch === '\n' || ch === '\r') {
|
||||
atLineStart = true;
|
||||
continue;
|
||||
}
|
||||
if ((ch === ' ' || ch === '\t') && atLineStart) {
|
||||
continue;
|
||||
}
|
||||
atLineStart = false;
|
||||
}
|
||||
// keep ticks for cross-chunk continuation.
|
||||
return {
|
||||
stack: nextStack,
|
||||
pendingTicks: ticks,
|
||||
lineStart: atLineStart,
|
||||
};
|
||||
}
|
||||
|
||||
function applyFenceMarker(stack, ticks) {
|
||||
if (!Array.isArray(stack)) {
|
||||
return;
|
||||
}
|
||||
if (stack.length === 0) {
|
||||
stack.push(ticks);
|
||||
return;
|
||||
}
|
||||
const top = stack[stack.length - 1];
|
||||
if (ticks >= top) {
|
||||
stack.pop();
|
||||
return;
|
||||
}
|
||||
// nested/open inner fence using longer marker for robustness.
|
||||
stack.push(ticks);
|
||||
}
|
||||
|
||||
function hasMeaningfulText(text) {
|
||||
return toStringSafe(text) !== '';
|
||||
}
|
||||
@@ -88,6 +177,8 @@ module.exports = {
|
||||
appendTail,
|
||||
looksLikeToolExampleContext,
|
||||
insideCodeFence,
|
||||
insideCodeFenceWithState,
|
||||
updateCodeFenceState,
|
||||
hasMeaningfulText,
|
||||
toStringSafe,
|
||||
};
|
||||
|
||||
29
internal/js/helpers/stream-tool-sieve/tool-keywords.js
Normal file
29
internal/js/helpers/stream-tool-sieve/tool-keywords.js
Normal file
@@ -0,0 +1,29 @@
|
||||
'use strict';
|
||||
|
||||
const TOOL_SEGMENT_KEYWORDS = [
|
||||
'tool_calls',
|
||||
'function.name:',
|
||||
'[tool_call_history]',
|
||||
'[tool_result_history]',
|
||||
];
|
||||
|
||||
function earliestKeywordIndex(text, keywords = TOOL_SEGMENT_KEYWORDS, offset = 0) {
|
||||
if (!text) {
|
||||
return { index: -1, keyword: '' };
|
||||
}
|
||||
let index = -1;
|
||||
let keyword = '';
|
||||
for (const kw of keywords) {
|
||||
const candidate = text.indexOf(kw, offset);
|
||||
if (candidate >= 0 && (index < 0 || candidate < index)) {
|
||||
index = candidate;
|
||||
keyword = kw;
|
||||
}
|
||||
}
|
||||
return { index, keyword };
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
TOOL_SEGMENT_KEYWORDS,
|
||||
earliestKeywordIndex,
|
||||
};
|
||||
@@ -36,6 +36,12 @@ func MessagesPrepare(messages []map[string]any) string {
|
||||
switch m.Role {
|
||||
case "assistant":
|
||||
parts = append(parts, "<|Assistant|>"+m.Text+"<|end▁of▁sentence|>")
|
||||
case "tool":
|
||||
if i > 0 {
|
||||
parts = append(parts, "<|Tool|>"+m.Text)
|
||||
} else {
|
||||
parts = append(parts, m.Text)
|
||||
}
|
||||
case "user", "system":
|
||||
if i > 0 {
|
||||
parts = append(parts, "<|User|>"+m.Text)
|
||||
|
||||
@@ -7,7 +7,8 @@ import (
|
||||
|
||||
var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`)
|
||||
var fencedJSONPattern = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```")
|
||||
var fencedBlockPattern = regexp.MustCompile("(?s)```.*?```")
|
||||
var fencedCodeBlockPattern = regexp.MustCompile("(?s)```[\\s\\S]*?```")
|
||||
var markupToolSyntaxPattern = regexp.MustCompile(`(?i)<(?:(?:[a-z0-9_:-]+:)?(?:tool_call|function_call|invoke)\b|(?:[a-z0-9_:-]+:)?function_calls\b|(?:[a-z0-9_:-]+:)?tool_use\b)`)
|
||||
|
||||
func buildToolCallCandidates(text string) []string {
|
||||
trimmed := strings.TrimSpace(text)
|
||||
@@ -20,7 +21,7 @@ func buildToolCallCandidates(text string) []string {
|
||||
}
|
||||
}
|
||||
|
||||
// best-effort extraction around "tool_calls" key in mixed text payloads.
|
||||
// best-effort extraction around tool call keywords in mixed text payloads.
|
||||
candidates = append(candidates, extractToolCallObjects(trimmed)...)
|
||||
|
||||
// best-effort object slice: from first '{' to last '}'
|
||||
@@ -57,25 +58,65 @@ func extractToolCallObjects(text string) []string {
|
||||
lower := strings.ToLower(text)
|
||||
out := []string{}
|
||||
offset := 0
|
||||
keywords := []string{"tool_calls", "function.name:", "[tool_call_history]"}
|
||||
for {
|
||||
idx := strings.Index(lower[offset:], "tool_calls")
|
||||
if idx < 0 {
|
||||
bestIdx := -1
|
||||
matchedKeyword := ""
|
||||
for _, kw := range keywords {
|
||||
idx := strings.Index(lower[offset:], kw)
|
||||
if idx >= 0 {
|
||||
absIdx := offset + idx
|
||||
if bestIdx < 0 || absIdx < bestIdx {
|
||||
bestIdx = absIdx
|
||||
matchedKeyword = kw
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if bestIdx < 0 {
|
||||
break
|
||||
}
|
||||
idx += offset
|
||||
start := strings.LastIndex(text[:idx], "{")
|
||||
for start >= 0 {
|
||||
|
||||
idx := bestIdx
|
||||
// Avoid backtracking too far to prevent OOM on malicious or very long strings
|
||||
searchLimit := idx - 2000
|
||||
if searchLimit < offset {
|
||||
searchLimit = offset
|
||||
}
|
||||
|
||||
start := strings.LastIndex(text[searchLimit:idx], "{")
|
||||
if start >= 0 {
|
||||
start += searchLimit
|
||||
}
|
||||
|
||||
if start < 0 {
|
||||
offset = idx + len(matchedKeyword)
|
||||
continue
|
||||
}
|
||||
|
||||
foundObj := false
|
||||
for start >= searchLimit {
|
||||
candidate, end, ok := extractJSONObject(text, start)
|
||||
if ok {
|
||||
// Move forward to avoid repeatedly matching the same object.
|
||||
offset = end
|
||||
out = append(out, strings.TrimSpace(candidate))
|
||||
foundObj = true
|
||||
break
|
||||
}
|
||||
start = strings.LastIndex(text[:start], "{")
|
||||
// Try previous '{'
|
||||
if start > searchLimit {
|
||||
prevStart := strings.LastIndex(text[searchLimit:start], "{")
|
||||
if prevStart >= 0 {
|
||||
start = searchLimit + prevStart
|
||||
continue
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
if start < 0 {
|
||||
offset = idx + len("tool_calls")
|
||||
|
||||
if !foundObj {
|
||||
offset = idx + len(matchedKeyword)
|
||||
}
|
||||
}
|
||||
return out
|
||||
@@ -88,7 +129,12 @@ func extractJSONObject(text string, start int) (string, int, bool) {
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
for i := start; i < len(text); i++ {
|
||||
// Limit scan length to avoid OOM on unclosed objects
|
||||
maxLen := start + 50000
|
||||
if maxLen > len(text) {
|
||||
maxLen = len(text)
|
||||
}
|
||||
for i := start; i < maxLen; i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
@@ -130,9 +176,21 @@ func looksLikeToolExampleContext(text string) bool {
|
||||
return strings.Contains(t, "```")
|
||||
}
|
||||
|
||||
func shouldSkipToolCallParsingForCodeFenceExample(text string) bool {
|
||||
if !looksLikeToolCallSyntax(text) {
|
||||
return false
|
||||
}
|
||||
stripped := strings.TrimSpace(stripFencedCodeBlocks(text))
|
||||
return !looksLikeToolCallSyntax(stripped)
|
||||
}
|
||||
|
||||
func looksLikeMarkupToolSyntax(text string) bool {
|
||||
return markupToolSyntaxPattern.MatchString(text)
|
||||
}
|
||||
|
||||
func stripFencedCodeBlocks(text string) string {
|
||||
if strings.TrimSpace(text) == "" {
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
return fencedBlockPattern.ReplaceAllString(text, " ")
|
||||
return fencedCodeBlockPattern.ReplaceAllString(text, " ")
|
||||
}
|
||||
|
||||
108
internal/util/toolcalls_input_parse.go
Normal file
108
internal/util/toolcalls_input_parse.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
func parseToolCallInput(v any) map[string]any {
|
||||
switch x := v.(type) {
|
||||
case nil:
|
||||
return map[string]any{}
|
||||
case map[string]any:
|
||||
return x
|
||||
case string:
|
||||
raw := strings.TrimSpace(x)
|
||||
if raw == "" {
|
||||
return map[string]any{}
|
||||
}
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(raw), &parsed); err == nil && parsed != nil {
|
||||
repairPathLikeControlChars(parsed)
|
||||
return parsed
|
||||
}
|
||||
// Try to repair invalid backslashes (common in Windows paths output by models)
|
||||
repaired := repairInvalidJSONBackslashes(raw)
|
||||
if repaired != raw {
|
||||
if err := json.Unmarshal([]byte(repaired), &parsed); err == nil && parsed != nil {
|
||||
repairPathLikeControlChars(parsed)
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
// Try to repair loose JSON in string argument as well
|
||||
repairedLoose := RepairLooseJSON(raw)
|
||||
if repairedLoose != raw {
|
||||
if err := json.Unmarshal([]byte(repairedLoose), &parsed); err == nil && parsed != nil {
|
||||
repairPathLikeControlChars(parsed)
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
return map[string]any{"_raw": raw}
|
||||
default:
|
||||
b, err := json.Marshal(x)
|
||||
if err != nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(b, &parsed); err == nil && parsed != nil {
|
||||
return parsed
|
||||
}
|
||||
return map[string]any{}
|
||||
}
|
||||
}
|
||||
|
||||
func repairPathLikeControlChars(m map[string]any) {
|
||||
for k, v := range m {
|
||||
switch vv := v.(type) {
|
||||
case map[string]any:
|
||||
repairPathLikeControlChars(vv)
|
||||
case []any:
|
||||
for _, item := range vv {
|
||||
if child, ok := item.(map[string]any); ok {
|
||||
repairPathLikeControlChars(child)
|
||||
}
|
||||
}
|
||||
case string:
|
||||
if isPathLikeKey(k) && containsControlRune(vv) {
|
||||
m[k] = escapeControlRunes(vv)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isPathLikeKey(key string) bool {
|
||||
k := strings.ToLower(strings.TrimSpace(key))
|
||||
return strings.Contains(k, "path") || strings.Contains(k, "file")
|
||||
}
|
||||
|
||||
func containsControlRune(s string) bool {
|
||||
for _, r := range s {
|
||||
if unicode.IsControl(r) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func escapeControlRunes(s string) string {
|
||||
var b strings.Builder
|
||||
b.Grow(len(s) + 8)
|
||||
for _, r := range s {
|
||||
switch r {
|
||||
case '\b':
|
||||
b.WriteString(`\b`)
|
||||
case '\f':
|
||||
b.WriteString(`\f`)
|
||||
case '\n':
|
||||
b.WriteString(`\n`)
|
||||
case '\r':
|
||||
b.WriteString(`\r`)
|
||||
case '\t':
|
||||
b.WriteString(`\t`)
|
||||
default:
|
||||
b.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
79
internal/util/toolcalls_json_repair.go
Normal file
79
internal/util/toolcalls_json_repair.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func repairInvalidJSONBackslashes(s string) string {
|
||||
if !strings.Contains(s, "\\") {
|
||||
return s
|
||||
}
|
||||
var out strings.Builder
|
||||
out.Grow(len(s) + 10)
|
||||
runes := []rune(s)
|
||||
for i := 0; i < len(runes); i++ {
|
||||
if runes[i] == '\\' {
|
||||
if i+1 < len(runes) {
|
||||
next := runes[i+1]
|
||||
switch next {
|
||||
case '"', '\\', '/', 'b', 'f', 'n', 'r', 't':
|
||||
out.WriteRune('\\')
|
||||
out.WriteRune(next)
|
||||
i++
|
||||
continue
|
||||
case 'u':
|
||||
if i+5 < len(runes) {
|
||||
isHex := true
|
||||
for j := 1; j <= 4; j++ {
|
||||
r := runes[i+1+j]
|
||||
if !((r >= '0' && r <= '9') || (r >= 'a' && r <= 'f') || (r >= 'A' && r <= 'F')) {
|
||||
isHex = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if isHex {
|
||||
out.WriteRune('\\')
|
||||
out.WriteRune('u')
|
||||
for j := 1; j <= 4; j++ {
|
||||
out.WriteRune(runes[i+1+j])
|
||||
}
|
||||
i += 5
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Not a valid escape sequence, double it
|
||||
out.WriteString("\\\\")
|
||||
} else {
|
||||
out.WriteRune(runes[i])
|
||||
}
|
||||
}
|
||||
return out.String()
|
||||
}
|
||||
|
||||
var unquotedKeyPattern = regexp.MustCompile(`([{,]\s*)([a-zA-Z_][a-zA-Z0-9_]*)\s*:`)
|
||||
|
||||
// missingArrayBracketsPattern identifies a sequence of two or more JSON objects separated by commas
|
||||
// that immediately follow a colon, which indicates a missing array bracket `[` `]`.
|
||||
// E.g., "key": {"a": 1}, {"b": 2} -> "key": [{"a": 1}, {"b": 2}]
|
||||
// NOTE: The pattern uses (?:[^{}]|\{[^{}]*\})* to support single-level nested {} objects,
|
||||
// which handles cases like {"content": "x", "input": {"q": "y"}}
|
||||
var missingArrayBracketsPattern = regexp.MustCompile(`(:\s*)(\{(?:[^{}]|\{[^{}]*\})*\}(?:\s*,\s*\{(?:[^{}]|\{[^{}]*\})*\})+)`)
|
||||
|
||||
func RepairLooseJSON(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return s
|
||||
}
|
||||
// 1. Replace unquoted keys: {key: -> {"key":
|
||||
s = unquotedKeyPattern.ReplaceAllString(s, `$1"$2":`)
|
||||
|
||||
// 2. Heuristic: Fix missing array brackets for list of objects
|
||||
// e.g., : {obj1}, {obj2} -> : [{obj1}, {obj2}]
|
||||
// This specifically addresses DeepSeek's "list hallucination"
|
||||
s = missingArrayBracketsPattern.ReplaceAllString(s, `$1[$2]`)
|
||||
|
||||
return s
|
||||
}
|
||||
@@ -16,7 +16,6 @@ type ToolCallParseResult struct {
|
||||
RejectedByPolicy bool
|
||||
RejectedToolNames []string
|
||||
}
|
||||
|
||||
func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
|
||||
return ParseToolCallsDetailed(text, availableToolNames).Calls
|
||||
}
|
||||
@@ -26,11 +25,10 @@ func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallPa
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return result
|
||||
}
|
||||
text = stripFencedCodeBlocks(text)
|
||||
if strings.TrimSpace(text) == "" {
|
||||
result.SawToolCallSyntax = looksLikeToolCallSyntax(text)
|
||||
if shouldSkipToolCallParsingForCodeFenceExample(text) {
|
||||
return result
|
||||
}
|
||||
result.SawToolCallSyntax = looksLikeToolCallSyntax(text)
|
||||
|
||||
candidates := buildToolCallCandidates(text)
|
||||
var parsed []ParsedToolCall
|
||||
@@ -68,7 +66,6 @@ func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallPa
|
||||
result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
|
||||
return result
|
||||
}
|
||||
|
||||
func ParseStandaloneToolCalls(text string, availableToolNames []string) []ParsedToolCall {
|
||||
return ParseStandaloneToolCallsDetailed(text, availableToolNames).Calls
|
||||
}
|
||||
@@ -79,17 +76,18 @@ func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string)
|
||||
if trimmed == "" {
|
||||
return result
|
||||
}
|
||||
if looksLikeToolExampleContext(trimmed) {
|
||||
result.SawToolCallSyntax = looksLikeToolCallSyntax(trimmed)
|
||||
if shouldSkipToolCallParsingForCodeFenceExample(trimmed) {
|
||||
return result
|
||||
}
|
||||
result.SawToolCallSyntax = looksLikeToolCallSyntax(trimmed)
|
||||
candidates := []string{trimmed}
|
||||
candidates := buildToolCallCandidates(trimmed)
|
||||
var parsed []ParsedToolCall
|
||||
for _, candidate := range candidates {
|
||||
candidate = strings.TrimSpace(candidate)
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
parsed := parseToolCallsPayload(candidate)
|
||||
parsed = parseToolCallsPayload(candidate)
|
||||
if len(parsed) == 0 {
|
||||
parsed = parseXMLToolCalls(candidate)
|
||||
}
|
||||
@@ -100,14 +98,23 @@ func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string)
|
||||
parsed = parseTextKVToolCalls(candidate)
|
||||
}
|
||||
if len(parsed) > 0 {
|
||||
result.SawToolCallSyntax = true
|
||||
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
|
||||
result.Calls = calls
|
||||
result.RejectedToolNames = rejectedNames
|
||||
result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
|
||||
return result
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(parsed) == 0 {
|
||||
parsed = parseXMLToolCalls(trimmed)
|
||||
if len(parsed) == 0 {
|
||||
parsed = parseTextKVToolCalls(trimmed)
|
||||
if len(parsed) == 0 {
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
result.SawToolCallSyntax = true
|
||||
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
|
||||
result.Calls = calls
|
||||
result.RejectedToolNames = rejectedNames
|
||||
result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -171,11 +178,20 @@ func resolveAllowedToolName(name string, allowed map[string]struct{}, allowedCan
|
||||
func parseToolCallsPayload(payload string) []ParsedToolCall {
|
||||
var decoded any
|
||||
if err := json.Unmarshal([]byte(payload), &decoded); err != nil {
|
||||
return nil
|
||||
// Try to repair backslashes first! Because LLMs often mix these two problems.
|
||||
repaired := repairInvalidJSONBackslashes(payload)
|
||||
// Try loose repair on top of that
|
||||
repaired = RepairLooseJSON(repaired)
|
||||
if err := json.Unmarshal([]byte(repaired), &decoded); err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
switch v := decoded.(type) {
|
||||
case map[string]any:
|
||||
if tc, ok := v["tool_calls"]; ok {
|
||||
if isLikelyChatMessageEnvelope(v) {
|
||||
return nil
|
||||
}
|
||||
return parseToolCallList(tc)
|
||||
}
|
||||
if parsed, ok := parseToolCallItem(v); ok {
|
||||
@@ -187,6 +203,28 @@ func parseToolCallsPayload(payload string) []ParsedToolCall {
|
||||
return nil
|
||||
}
|
||||
|
||||
func isLikelyChatMessageEnvelope(v map[string]any) bool {
|
||||
if v == nil {
|
||||
return false
|
||||
}
|
||||
if _, ok := v["tool_calls"]; !ok {
|
||||
return false
|
||||
}
|
||||
if role, ok := v["role"].(string); ok {
|
||||
switch strings.ToLower(strings.TrimSpace(role)) {
|
||||
case "assistant", "tool", "user", "system":
|
||||
return true
|
||||
}
|
||||
}
|
||||
if _, ok := v["tool_call_id"]; ok {
|
||||
return true
|
||||
}
|
||||
if _, ok := v["content"]; ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func looksLikeToolCallSyntax(text string) bool {
|
||||
lower := strings.ToLower(text)
|
||||
return strings.Contains(lower, "tool_calls") ||
|
||||
@@ -248,32 +286,3 @@ func parseToolCallItem(m map[string]any) (ParsedToolCall, bool) {
|
||||
Input: parseToolCallInput(inputRaw),
|
||||
}, true
|
||||
}
|
||||
|
||||
func parseToolCallInput(v any) map[string]any {
|
||||
switch x := v.(type) {
|
||||
case nil:
|
||||
return map[string]any{}
|
||||
case map[string]any:
|
||||
return x
|
||||
case string:
|
||||
raw := strings.TrimSpace(x)
|
||||
if raw == "" {
|
||||
return map[string]any{}
|
||||
}
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(raw), &parsed); err == nil && parsed != nil {
|
||||
return parsed
|
||||
}
|
||||
return map[string]any{"_raw": raw}
|
||||
default:
|
||||
b, err := json.Marshal(x)
|
||||
if err != nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(b, &parsed); err == nil && parsed != nil {
|
||||
return parsed
|
||||
}
|
||||
return map[string]any{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ var antmlArgumentPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?argument\s+
|
||||
var antmlParametersPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?parameters\s*>\s*(\{.*?\})\s*</(?:[a-z0-9_]+:)?parameters>`)
|
||||
var invokeCallPattern = regexp.MustCompile(`(?is)<invoke\s+name="([^"]+)"\s*>(.*?)</invoke>`)
|
||||
var invokeParamPattern = regexp.MustCompile(`(?is)<parameter\s+name="([^"]+)"\s*>\s*(.*?)\s*</parameter>`)
|
||||
var toolUseFunctionPattern = regexp.MustCompile(`(?is)<tool_use>\s*<function\s+name="([^"]+)"\s*>(.*?)</function>\s*</tool_use>`)
|
||||
|
||||
func parseXMLToolCalls(text string) []ParsedToolCall {
|
||||
matches := xmlToolCallPattern.FindAllString(text, -1)
|
||||
@@ -38,6 +39,9 @@ func parseXMLToolCalls(text string) []ParsedToolCall {
|
||||
if call, ok := parseInvokeFunctionCallStyle(text); ok {
|
||||
return []ParsedToolCall{call}
|
||||
}
|
||||
if call, ok := parseToolUseFunctionStyle(text); ok {
|
||||
return []ParsedToolCall{call}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -229,6 +233,30 @@ func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) {
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
}
|
||||
|
||||
func parseToolUseFunctionStyle(text string) (ParsedToolCall, bool) {
|
||||
m := toolUseFunctionPattern.FindStringSubmatch(text)
|
||||
if len(m) < 3 {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
name := strings.TrimSpace(m[1])
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
body := m[2]
|
||||
input := map[string]any{}
|
||||
for _, pm := range invokeParamPattern.FindAllStringSubmatch(body, -1) {
|
||||
if len(pm) < 3 {
|
||||
continue
|
||||
}
|
||||
k := strings.TrimSpace(pm[1])
|
||||
v := strings.TrimSpace(pm[2])
|
||||
if k != "" {
|
||||
input[k] = v
|
||||
}
|
||||
}
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
}
|
||||
|
||||
func asString(v any) string {
|
||||
s, _ := v.(string)
|
||||
return s
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package util
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseToolCalls(t *testing.T) {
|
||||
text := `prefix {"tool_calls":[{"name":"search","input":{"q":"golang"}}]} suffix`
|
||||
@@ -16,11 +19,11 @@ func TestParseToolCalls(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsFromFencedJSON(t *testing.T) {
|
||||
func TestParseToolCallsIgnoresFencedJSON(t *testing.T) {
|
||||
text := "I will call tools now\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"news\"}}]}\n```"
|
||||
calls := ParseToolCalls(text, []string{"search"})
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls)
|
||||
t.Fatalf("expected fenced tool_call payload to be ignored, got %#v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,10 +99,10 @@ func TestFormatOpenAIToolCalls(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStandaloneToolCallsOnlyMatchesStandalonePayload(t *testing.T) {
|
||||
func TestParseStandaloneToolCallsSupportsMixedProsePayload(t *testing.T) {
|
||||
mixed := `这里是示例:{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
|
||||
if calls := ParseStandaloneToolCalls(mixed, []string{"search"}); len(calls) != 0 {
|
||||
t.Fatalf("expected standalone parser to ignore mixed prose, got %#v", calls)
|
||||
if calls := ParseStandaloneToolCalls(mixed, []string{"search"}); len(calls) != 1 {
|
||||
t.Fatalf("expected standalone parser to parse mixed prose payload, got %#v", calls)
|
||||
}
|
||||
|
||||
standalone := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
|
||||
@@ -112,7 +115,14 @@ func TestParseStandaloneToolCallsOnlyMatchesStandalonePayload(t *testing.T) {
|
||||
func TestParseStandaloneToolCallsIgnoresFencedCodeBlock(t *testing.T) {
|
||||
fenced := "```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```"
|
||||
if calls := ParseStandaloneToolCalls(fenced, []string{"search"}); len(calls) != 0 {
|
||||
t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls)
|
||||
t.Fatalf("expected fenced tool_call payload to be ignored, got %#v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStandaloneToolCallsIgnoresChatTranscriptEnvelope(t *testing.T) {
|
||||
transcript := `[{"role":"user","content":"请展示完整会话"},{"role":"assistant","content":null,"tool_calls":[{"function":{"name":"search","arguments":"{\"q\":\"go\"}"}}]}]`
|
||||
if calls := ParseStandaloneToolCalls(transcript, []string{"search"}); len(calls) != 0 {
|
||||
t.Fatalf("expected transcript envelope not to trigger tool call parse, got %#v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -233,6 +243,20 @@ func TestParseToolCallsSupportsInvokeFunctionCallStyle(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsToolUseFunctionParameterStyle(t *testing.T) {
|
||||
text := `<tool_use><function name="search_web"><parameter name="query">test</parameter></function></tool_use>`
|
||||
calls := ParseToolCalls(text, []string{"search_web"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "search_web" {
|
||||
t.Fatalf("expected canonical tool name search_web, got %q", calls[0].Name)
|
||||
}
|
||||
if calls[0].Input["query"] != "test" {
|
||||
t.Fatalf("expected query argument, got %#v", calls[0].Input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsNestedToolTagStyle(t *testing.T) {
|
||||
text := `<tool_call><tool name="Bash"><command>pwd</command><description>show cwd</description></tool></tool_call>`
|
||||
calls := ParseToolCalls(text, []string{"bash"})
|
||||
@@ -279,3 +303,238 @@ func TestParseToolCallsDoesNotAcceptMismatchedMarkupTags(t *testing.T) {
|
||||
t.Fatalf("expected mismatched tags to be rejected, got %#v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairInvalidJSONBackslashes(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{`{"path": "C:\Users\name"}`, `{"path": "C:\\Users\name"}`},
|
||||
{`{"cmd": "cd D:\git_codes"}`, `{"cmd": "cd D:\\git_codes"}`},
|
||||
{`{"text": "line1\nline2"}`, `{"text": "line1\nline2"}`},
|
||||
{`{"path": "D:\\back\\slash"}`, `{"path": "D:\\back\\slash"}`},
|
||||
{`{"unicode": "\u2705"}`, `{"unicode": "\u2705"}`},
|
||||
{`{"invalid_u": "\u123"}`, `{"invalid_u": "\\u123"}`},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := repairInvalidJSONBackslashes(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("repairInvalidJSONBackslashes(%s) = %s; want %s", tt.input, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairLooseJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{`{tool_calls: [{"name": "search", "input": {"q": "go"}}]}`, `{"tool_calls": [{"name": "search", "input": {"q": "go"}}]}`},
|
||||
{`{name: "search", input: {q: "go"}}`, `{"name": "search", "input": {"q": "go"}}`},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := RepairLooseJSON(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("RepairLooseJSON(%s) = %s; want %s", tt.input, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsWithUnquotedKeys(t *testing.T) {
|
||||
text := `这里是列表:{tool_calls: [{"name": "todowrite", "input": {"todos": "test"}}]}`
|
||||
availableTools := []string{"todowrite"}
|
||||
|
||||
parsed := ParseToolCalls(text, availableTools)
|
||||
if len(parsed) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(parsed))
|
||||
}
|
||||
if parsed[0].Name != "todowrite" {
|
||||
t.Errorf("expected tool todowrite, got %s", parsed[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsWithInvalidBackslashes(t *testing.T) {
|
||||
// DeepSeek sometimes outputs Windows paths with single backslashes in JSON strings
|
||||
// Note: using raw string to simulate what AI actually sends in the stream
|
||||
text := `好的,执行以下命令:{"name": "execute_command", "input": "{\"command\": \"cd D:\git_codes && dir\"}"}`
|
||||
availableTools := []string{"execute_command"}
|
||||
|
||||
parsed := ParseToolCalls(text, availableTools)
|
||||
// If standard JSON fails, buildToolCallCandidates should still extract the object,
|
||||
// and parseToolCallsPayload should repair it.
|
||||
if len(parsed) != 1 {
|
||||
// If it still fails, let's see why
|
||||
candidates := buildToolCallCandidates(text)
|
||||
t.Logf("Candidates: %v", candidates)
|
||||
t.Fatalf("expected 1 tool call, got %d", len(parsed))
|
||||
}
|
||||
|
||||
cmd, ok := parsed[0].Input["command"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("expected command string in input, got %v", parsed[0].Input)
|
||||
}
|
||||
|
||||
expected := "cd D:\\git_codes && dir"
|
||||
if cmd != expected {
|
||||
t.Errorf("expected command %q, got %q", expected, cmd)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsWithDeepSeekHallucination(t *testing.T) {
|
||||
// 模拟 DeepSeek 典型的幻觉输出:未加引号的键名 + 包含 Windows 路径的嵌套 JSON 字符串 + 漏掉列表的方括号
|
||||
text := `检测到实施意图——实现经典算法。需在misc/目录创建Python文件。
|
||||
关键约束:
|
||||
1. Windows UTF-8编码处理
|
||||
2. 必须用绝对路径导入
|
||||
3. 禁止write覆盖已有文件(misc/目录允许创建新文件)
|
||||
将任务分解并委托:
|
||||
- 研究8皇后算法模式(并行探索)
|
||||
- 实现带可视化输出的解决方案(unspecified-high)
|
||||
先创建todo列表追踪步骤。
|
||||
{tool_calls: [{"name": "todowrite", "input": {"todos": {"content": "研究8皇后问题算法模式(回溯法)和输出格式", "status": "pending", "priority": "high"}, {"content": "在misc/目录创建8皇后Python脚本,包含完整解决方案和可视化输出", "status": "pending", "priority": "high"}, {"content": "验证脚本正确性(运行测试)", "status": "pending", "priority": "medium"}}}]}`
|
||||
|
||||
availableTools := []string{"todowrite"}
|
||||
parsed := ParseToolCalls(text, availableTools)
|
||||
|
||||
if len(parsed) != 1 {
|
||||
cands := buildToolCallCandidates(text)
|
||||
for i, c := range cands {
|
||||
t.Logf("CAND %d: %s", i, c)
|
||||
repaired := RepairLooseJSON(c)
|
||||
t.Logf(" REPAIRED: %s", repaired)
|
||||
}
|
||||
t.Fatalf("expected 1 tool call, got %d. Candidates: %v", len(parsed), buildToolCallCandidates(text))
|
||||
}
|
||||
|
||||
if parsed[0].Name != "todowrite" {
|
||||
t.Errorf("expected tool name 'todowrite', got %q", parsed[0].Name)
|
||||
}
|
||||
|
||||
todos, ok := parsed[0].Input["todos"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected 'todos' to be parsed as a list, got %T: %#v", parsed[0].Input["todos"], parsed[0].Input["todos"])
|
||||
}
|
||||
if len(todos) != 3 {
|
||||
t.Errorf("expected 3 todo items, got %d", len(todos))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsWithMixedWindowsPaths(t *testing.T) {
|
||||
// 更复杂的案例:嵌套 JSON 字符串中的反斜杠未转义
|
||||
text := `关键约束: 1. Windows UTF-8编码处理 2. 必须用绝对路径导入 D:\git_codes\ds2api\misc
|
||||
{tool_calls: [{"name": "write_file", "input": "{\"path\": \"D:\\git_codes\\ds2api\\misc\\queens.py\", \"content\": \"print('hello')\"}"}]}`
|
||||
|
||||
availableTools := []string{"write_file"}
|
||||
parsed := ParseToolCalls(text, availableTools)
|
||||
|
||||
if len(parsed) != 1 {
|
||||
t.Fatalf("expected 1 tool call from mixed text with paths, got %d", len(parsed))
|
||||
}
|
||||
|
||||
path, _ := parsed[0].Input["path"].(string)
|
||||
// 在解析后的 Go map 中,反斜杠应该被还原
|
||||
if !strings.Contains(path, "D:\\git_codes") && !strings.Contains(path, "D:/git_codes") {
|
||||
t.Errorf("expected path to contain Windows style separators, got %q", path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallInputRepairsControlCharsInPath(t *testing.T) {
|
||||
in := `{"path":"D:\tmp\new\readme.txt","content":"line1\nline2"}`
|
||||
parsed := parseToolCallInput(in)
|
||||
|
||||
path, ok := parsed["path"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("expected path string in parsed input, got %#v", parsed["path"])
|
||||
}
|
||||
if path != `D:\tmp\new\readme.txt` {
|
||||
t.Fatalf("expected repaired windows path, got %q", path)
|
||||
}
|
||||
|
||||
content, ok := parsed["content"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("expected content string in parsed input, got %#v", parsed["content"])
|
||||
}
|
||||
if content != "line1\nline2" {
|
||||
t.Fatalf("expected non-path field to keep decoded escapes, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairLooseJSONWithNestedObjects(t *testing.T) {
|
||||
// 测试嵌套对象的修复:DeepSeek 幻觉输出,每个元素内部包含嵌套 {}
|
||||
// 注意:正则只支持单层嵌套,不支持更深层次的嵌套
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
// 1. 单层嵌套对象(核心修复目标)
|
||||
{
|
||||
name: "单层嵌套 - 2个元素",
|
||||
input: `"todos": {"content": "研究算法", "input": {"q": "8 queens"}}, {"content": "实现", "input": {"path": "queens.py"}}`,
|
||||
expected: `"todos": [{"content": "研究算法", "input": {"q": "8 queens"}}, {"content": "实现", "input": {"path": "queens.py"}}]`,
|
||||
},
|
||||
// 2. 3个单层嵌套对象
|
||||
{
|
||||
name: "3个单层嵌套对象",
|
||||
input: `"items": {"a": {"x":1}}, {"b": {"y":2}}, {"c": {"z":3}}`,
|
||||
expected: `"items": [{"a": {"x":1}}, {"b": {"y":2}}, {"c": {"z":3}}]`,
|
||||
},
|
||||
// 3. 混合嵌套:有些字段是对象,有些是原始值
|
||||
{
|
||||
name: "混合嵌套 - 对象和原始值混合",
|
||||
input: `"items": {"name": "test", "config": {"timeout": 30}}, {"name": "test2", "config": {"timeout": 60}}`,
|
||||
expected: `"items": [{"name": "test", "config": {"timeout": 30}}, {"name": "test2", "config": {"timeout": 60}}]`,
|
||||
},
|
||||
// 4. 4个嵌套对象(边界测试)
|
||||
{
|
||||
name: "4个嵌套对象",
|
||||
input: `"todos": {"id": 1}, {"id": 2}, {"id": 3}, {"id": 4}`,
|
||||
expected: `"todos": [{"id": 1}, {"id": 2}, {"id": 3}, {"id": 4}]`,
|
||||
},
|
||||
// 5. DeepSeek 典型幻觉:无空格逗号分隔
|
||||
{
|
||||
name: "无空格逗号分隔",
|
||||
input: `"results": {"name": "a"}, {"name": "b"}, {"name": "c"}`,
|
||||
expected: `"results": [{"name": "a"}, {"name": "b"}, {"name": "c"}]`,
|
||||
},
|
||||
// 6. 嵌套数组(数组在对象内,不是深层嵌套)
|
||||
{
|
||||
name: "对象内包含数组",
|
||||
input: `"data": {"items": [1,2,3]}, {"items": [4,5,6]}`,
|
||||
expected: `"data": [{"items": [1,2,3]}, {"items": [4,5,6]}]`,
|
||||
},
|
||||
// 7. 真实的 DeepSeek 8皇后问题输出
|
||||
{
|
||||
name: "DeepSeek 8皇后真实输出",
|
||||
input: `"todos": {"content": "研究8皇后算法", "status": "pending"}, {"content": "实现Python脚本", "status": "pending"}, {"content": "验证结果", "status": "pending"}`,
|
||||
expected: `"todos": [{"content": "研究8皇后算法", "status": "pending"}, {"content": "实现Python脚本", "status": "pending"}, {"content": "验证结果", "status": "pending"}]`,
|
||||
},
|
||||
// 8. 简单无嵌套对象(回归测试)
|
||||
{
|
||||
name: "简单无嵌套对象",
|
||||
input: `"items": {"a": 1}, {"b": 2}`,
|
||||
expected: `"items": [{"a": 1}, {"b": 2}]`,
|
||||
},
|
||||
// 9. 更复杂的单层嵌套
|
||||
{
|
||||
name: "复杂单层嵌套",
|
||||
input: `"functions": {"name": "execute", "input": {"command": "ls"}}, {"name": "read", "input": {"file": "a.txt"}}`,
|
||||
expected: `"functions": [{"name": "execute", "input": {"command": "ls"}}, {"name": "read", "input": {"file": "a.txt"}}]`,
|
||||
},
|
||||
// 10. 5个嵌套对象
|
||||
{
|
||||
name: "5个嵌套对象",
|
||||
input: `"tasks": {"id":1}, {"id":2}, {"id":3}, {"id":4}, {"id":5}`,
|
||||
expected: `"tasks": [{"id":1}, {"id":2}, {"id":3}, {"id":4}, {"id":5}]`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := RepairLooseJSON(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("[%s] RepairLooseJSON with nested objects:\n input: %s\n got: %s\n expected: %s", tt.name, tt.input, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -410,7 +410,7 @@ func TestParseStandaloneToolCallsFencedCodeBlock(t *testing.T) {
|
||||
fenced := "Here's an example:\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\nDon't execute this."
|
||||
calls := ParseStandaloneToolCalls(fenced, []string{"search"})
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected fenced code block ignored, got %d calls", len(calls))
|
||||
t.Fatalf("expected fenced code block to be ignored, got %d calls", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
185
internal/version/version.go
Normal file
185
internal/version/version.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package version
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// BuildVersion can be injected at build time via -ldflags.
|
||||
// In release builds it should come from Git tag (e.g. v2.3.5).
|
||||
var BuildVersion = ""
|
||||
|
||||
var (
|
||||
currentOnce sync.Once
|
||||
currentVal string
|
||||
sourceVal string
|
||||
)
|
||||
|
||||
func Current() (value string, source string) {
|
||||
currentOnce.Do(func() {
|
||||
if build := strings.TrimSpace(BuildVersion); build != "" {
|
||||
currentVal = normalize(build)
|
||||
sourceVal = "build-ldflags"
|
||||
return
|
||||
}
|
||||
if fv := readVersionFile(); fv != "" {
|
||||
currentVal = normalize(fv)
|
||||
sourceVal = "file:VERSION"
|
||||
return
|
||||
}
|
||||
|
||||
if vv := versionFromVercelEnv(); vv != "" {
|
||||
currentVal = vv
|
||||
sourceVal = "env:vercel"
|
||||
return
|
||||
}
|
||||
currentVal = "dev"
|
||||
sourceVal = "default"
|
||||
})
|
||||
return currentVal, sourceVal
|
||||
}
|
||||
|
||||
func readVersionFile() string {
|
||||
candidates := []string{"VERSION"}
|
||||
if wd, err := os.Getwd(); err == nil {
|
||||
candidates = append(candidates, filepath.Join(wd, "VERSION"))
|
||||
}
|
||||
if _, file, _, ok := runtime.Caller(0); ok {
|
||||
repoRoot := filepath.Clean(filepath.Join(filepath.Dir(file), "../.."))
|
||||
candidates = append(candidates, filepath.Join(repoRoot, "VERSION"))
|
||||
}
|
||||
seen := map[string]struct{}{}
|
||||
for _, c := range candidates {
|
||||
c = filepath.Clean(strings.TrimSpace(c))
|
||||
if c == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[c]; ok {
|
||||
continue
|
||||
}
|
||||
seen[c] = struct{}{}
|
||||
b, err := os.ReadFile(c)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if v := strings.TrimSpace(string(b)); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func normalize(v string) string {
|
||||
v = strings.TrimSpace(v)
|
||||
if v == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimPrefix(v, "v")
|
||||
}
|
||||
|
||||
func Tag(v string) string {
|
||||
v = normalize(v)
|
||||
if v == "" || v == "dev" {
|
||||
return v
|
||||
}
|
||||
if v[0] < '0' || v[0] > '9' {
|
||||
return v
|
||||
}
|
||||
return "v" + v
|
||||
}
|
||||
|
||||
func versionFromVercelEnv() string {
|
||||
if tag := normalize(strings.TrimSpace(os.Getenv("VERCEL_GIT_COMMIT_TAG"))); tag != "" {
|
||||
return tag
|
||||
}
|
||||
ref := strings.TrimSpace(os.Getenv("VERCEL_GIT_COMMIT_REF"))
|
||||
sha := strings.TrimSpace(os.Getenv("VERCEL_GIT_COMMIT_SHA"))
|
||||
if len(sha) > 7 {
|
||||
sha = sha[:7]
|
||||
}
|
||||
ref = sanitizeVersionLabel(ref)
|
||||
sha = sanitizeVersionLabel(sha)
|
||||
if ref == "" && sha == "" {
|
||||
return ""
|
||||
}
|
||||
if ref != "" && sha != "" {
|
||||
return "preview-" + ref + "." + sha
|
||||
}
|
||||
if ref != "" {
|
||||
return "preview-" + ref
|
||||
}
|
||||
return "preview-" + sha
|
||||
}
|
||||
|
||||
func sanitizeVersionLabel(in string) string {
|
||||
in = strings.TrimSpace(strings.ToLower(in))
|
||||
if in == "" {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
b.Grow(len(in))
|
||||
prevDash := false
|
||||
for i := 0; i < len(in); i++ {
|
||||
c := in[i]
|
||||
if (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') {
|
||||
b.WriteByte(c)
|
||||
prevDash = false
|
||||
continue
|
||||
}
|
||||
if !prevDash {
|
||||
b.WriteByte('-')
|
||||
prevDash = true
|
||||
}
|
||||
}
|
||||
out := strings.Trim(b.String(), "-")
|
||||
return out
|
||||
}
|
||||
|
||||
func Compare(a, b string) int {
|
||||
pa := parse(normalize(a))
|
||||
pb := parse(normalize(b))
|
||||
for i := 0; i < 3; i++ {
|
||||
if pa[i] < pb[i] {
|
||||
return -1
|
||||
}
|
||||
if pa[i] > pb[i] {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func parse(v string) [3]int {
|
||||
var out [3]int
|
||||
parts := strings.SplitN(v, ".", 4)
|
||||
for i := 0; i < 3 && i < len(parts); i++ {
|
||||
n := readLeadingInt(parts[i])
|
||||
out[i] = n
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func readLeadingInt(s string) int {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return 0
|
||||
}
|
||||
i := 0
|
||||
for ; i < len(s); i++ {
|
||||
if s[i] < '0' || s[i] > '9' {
|
||||
break
|
||||
}
|
||||
}
|
||||
if i == 0 {
|
||||
return 0
|
||||
}
|
||||
n, err := strconv.Atoi(s[:i])
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return n
|
||||
}
|
||||
39
internal/version/version_test.go
Normal file
39
internal/version/version_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package version
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeAndTag(t *testing.T) {
|
||||
if got := normalize("v2.3.5"); got != "2.3.5" {
|
||||
t.Fatalf("normalize failed: %q", got)
|
||||
}
|
||||
if got := Tag("2.3.5"); got != "v2.3.5" {
|
||||
t.Fatalf("tag failed: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompare(t *testing.T) {
|
||||
if Compare("2.3.5", "2.3.5") != 0 {
|
||||
t.Fatal("expected equal")
|
||||
}
|
||||
if Compare("2.3.5", "2.3.6") >= 0 {
|
||||
t.Fatal("expected less")
|
||||
}
|
||||
if Compare("v2.10.0", "2.3.9") <= 0 {
|
||||
t.Fatal("expected greater")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagKeepsPreviewStyle(t *testing.T) {
|
||||
if got := Tag("preview-dev.abcd123"); got != "preview-dev.abcd123" {
|
||||
t.Fatalf("expected preview tag unchanged, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersionFromVercelEnv(t *testing.T) {
|
||||
t.Setenv("VERCEL_GIT_COMMIT_TAG", "")
|
||||
t.Setenv("VERCEL_GIT_COMMIT_REF", "dev")
|
||||
t.Setenv("VERCEL_GIT_COMMIT_SHA", "abcdef123456")
|
||||
if got := versionFromVercelEnv(); got != "preview-dev.abcdef1" {
|
||||
t.Fatalf("unexpected vercel preview version: %q", got)
|
||||
}
|
||||
}
|
||||
@@ -1,101 +0,0 @@
|
||||
# DeepSeek Function Calling 缺陷分析与 ds2api 的增强修复策略
|
||||
|
||||
> **相关 PR**: #74 (代码核心实现) 与 #75 (Merge to dev)
|
||||
> **问题背景**: 解决因包括 DeepSeek 在内的部分模型在函数调用(Function Calling/Tool Call)表现不够“规范”,从而导致工具调用失败的问题。
|
||||
|
||||
## 一、底层架构对比:为什么会产生 Function Calling 缺陷?
|
||||
|
||||
在探讨缺陷前,我们需要理解两种 Function Calling 的底层结构差异:
|
||||
|
||||
### 1. OpenAI 的原生结构化返回 (API 级分离)
|
||||
在 OpenAI 的规范中,**聊天文字与工具调用是在底层的 JSON 结构中被硬性拆分的**:
|
||||
* 聊天废话存放在 `response.choices[0].message.content` 里。
|
||||
* 工具请求存放在单独的数组 `response.choices[0].message.tool_calls` 里。
|
||||
|
||||
**优势:** 这种设计对客户端极其友好。客户端只需判断 `tool_calls` 是否为空,就能决定是执行代码还是渲染文字。它支持同时并发多个工具请求,且底层的生成殷勤被严格训练和约束,极少抛出语法错误的 JSON。
|
||||
|
||||
### 2. DeepSeek 等模型的“单文本流”机制
|
||||
相比之下,部分未经深度专门微调的模型(或者在特定的通信适配层中),它们依然倾向于把一切内容打包成一个纯文本流吐出。这就是为什么它们的输出往往不仅包含了本该属于 `tool_calls` 结构里的 JSON,还会像个“老实人”一样夹杂了属于 `content` 里的散文。
|
||||
|
||||
---
|
||||
|
||||
## 二、DeepSeek 在 Function Calling 上的特定缺陷表现
|
||||
|
||||
相比于 OpenAI 严格遵循 API 约定的原生结构,DeepSeek 等开源/国产推理模型在工具调用时,经常会暴露出以下三种典型的“不守规矩”的输出行为:
|
||||
|
||||
### 1. 混合输出:散文文本与工具 JSON 混杂 (Mixed Prose Streams)
|
||||
当应用要求模型直接返回工具请求时,DeepSeek 有时候会**“忍不住想和用户搭话”**。
|
||||
它常常前置一段解释性废话,中间插入工具调用的 JSON 参数,并在末尾再补上一句总结:
|
||||
```text
|
||||
好的,我这就帮你读取 README.md 的内容:
|
||||
{"tool_calls":[{"name":"read_file","input":{"path":"README.md"}}]}
|
||||
请稍等片刻,我马上把它读出来。
|
||||
```
|
||||
**旧版系统痛点:**
|
||||
原有的代码存在**严格模式(Strict Mode)**校验:
|
||||
```go
|
||||
// 如果解析到的 JSON 块前后存在任何非空字符串,就放弃当作工具调用!
|
||||
if strings.TrimSpace(state.recentTextTail) != "" || strings.TrimSpace(prefixPart) != "" ... {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
```
|
||||
这直接导致上述结构被网关认定是一段“普通聊天”,直接原封不动地返回给用户,这直接干挂了后续的工具自动执行流程。
|
||||
|
||||
### 2. 工具名格式幻觉:擅自修改或前缀化工具名称
|
||||
由于 DeepSeek 的预训练数据中有大量的代码和不同的平台结构,它在回复工具名称时,常常无法忠实于 System Prompt 中提供的纯命名(也就是 `name: "read_file"`),而是加上前缀或者拼写变形,例如:
|
||||
* `{"name": "mcp.search_web"}` (自带命名空间)
|
||||
* `{"name": "tools.read_file"}`
|
||||
* `{"name": "search-web"}` (下划线变成了中划线)
|
||||
|
||||
**旧版系统痛点:**
|
||||
旧版系统对于工具名的匹配几乎只有“绝对相等”的字典级比对,只要差了一个字符或加了前缀,就会由于找不到合法工具而直接失败。
|
||||
|
||||
### 3. Role 角色的非标准返回
|
||||
在部分工具通信流的响应中,返回的内容其所属的 `role` 没有被标准化处理,可能携带意料之外的属性,或是与下游严格比对出现冲突。
|
||||
|
||||
---
|
||||
|
||||
## 二、PR #74 的代码增强修复方案
|
||||
|
||||
为了解决大模型这种自身的不规范行为,PR #74 在系统的中间层网关联入了一个**极其包容的容错引擎**。它并不强制要求模型“改过自新”,而是主动做了以下三块增强:
|
||||
|
||||
### 1. 从流中分离混合内容(废除 Strict Mode)
|
||||
修改了 `internal/adapter/openai/tool_sieve_core.go`。
|
||||
取消了前后包裹文本的拦截逻辑。当系统扫描到流式结构中有完整的 `{"tool_calls":...}` 时,它会将废话和 JSON 分发到不同的事件流中:
|
||||
```go
|
||||
if prefix != "" {
|
||||
// 将前面的“好的,帮你读文件”剥离出来作为常规文本输出
|
||||
state.noteText(prefix)
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
// 捕获并拦截中间的工具请求,进行背后执行
|
||||
state.pendingToolCalls = calls
|
||||
```
|
||||
**效果:** 用户的屏幕上只能看到正常的文字交流,而后端的工具也会立刻挂载。
|
||||
|
||||
### 2. 多级宽容匹配引擎 (Resolve Allowed Tool Name)
|
||||
在 `internal/util/toolcalls_parse.go` 中,新增了一个由严到松降级匹配的强大漏斗策略函数 `resolveAllowedToolName`:
|
||||
|
||||
1. **绝对匹配**:和以前一样,`read_file` == `read_file`。
|
||||
2. **忽略大小写**:`Read_File` 算作合法。
|
||||
3. **命名空间抹除**:通过寻找最后一个 `.` 来剥离前缀,强制将 `mcp.search_web` 还原出真实的 `search_web`。
|
||||
4. **终极正则清洗**:
|
||||
引入 `var toolNameLoosePattern = regexp.MustCompile(`[^a-z0-9]+`)`。
|
||||
这个正则剥离了字符串里所有的符号、空格、格式符。
|
||||
将传入的 `read-file` 洗除符号成为 `readfile`,并去和系统中所有合法工具同样清洗后的版本进行比较。只要核心字母一致,即算作匹配成功。
|
||||
|
||||
### 3. Role 归一化 (Normalize OpenAIRoleForPrompt)
|
||||
在 `internal/adapter/openai/responses_input_items.go` 等处,引入了特定的 `normalizeOpenAIRoleForPrompt(role)` 清洗,保证输入和传递给上游的 Role 枚举始终受控,消除了因为意外的身份字段传参崩溃。
|
||||
|
||||
---
|
||||
|
||||
## 报告总结与 tool_sieve 的本质作用
|
||||
|
||||
PR #74 / #75 并没有从模型本身开刀,而是基于**网关应足够健壮**的设计哲学。
|
||||
|
||||
**其实整个增强实现,本质上实现了一个名为 `tool_sieve` (工具筛子) 的中间层网关。**
|
||||
面对 DeepSeek 这种吐出一团混合了聊天文字与 JSON 面团的“不标准”数据流,`tool_sieve` 就像一个勤劳的高精度筛子,不仅人工揉开了面团:
|
||||
1. 它把散文分拣出来,塞回标准结构的 `content` 字段去展示;
|
||||
2. 剥离并清洗出有瑕疵的 JSON 块,按照 OpenAI 的标准格式小心翼翼地放进 `tool_calls` 结构里去等待执行。
|
||||
|
||||
这意味着,即便 AI 被配置了奇怪的回复设定、加粗了强调语言,甚至是犯了标点符号拼写小失误,**只要它输出了可以拼凑成工具指令的 JSON 核心单元,整个中继层就能将其挽救,并把正确的工具结果呈现给模型和用户**。 这不仅修复了缺陷,更极大地增强了工具网关的通用性和鲁棒性。
|
||||
@@ -1,32 +0,0 @@
|
||||
# DS2API Refactor Baseline (Historical Snapshot)
|
||||
|
||||
- Snapshot time: `2026-02-22T08:53:54Z`
|
||||
- Snapshot branch: `dev`
|
||||
- Snapshot HEAD: `5d3989a`
|
||||
- Scope: backend + node api + webui large-file decoupling (no behavior change)
|
||||
|
||||
## Gate Commands
|
||||
|
||||
1. `./tests/scripts/run-unit-all.sh`
|
||||
- Result: PASS
|
||||
- Includes:
|
||||
- `go test ./...`
|
||||
- `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js`
|
||||
2. `npm --prefix webui run build`
|
||||
- Result: PASS
|
||||
3. `./tests/scripts/check-refactor-line-gate.sh`
|
||||
- Result: PASS (`checked=131 missing=0 over_limit=0`)
|
||||
4. Stage gates (1-5) replay:
|
||||
- `go test ./internal/config ./internal/admin ./internal/account ./internal/deepseek ./internal/format/openai` -> PASS
|
||||
- `go test ./internal/adapter/openai ./internal/util ./internal/sse ./internal/compat` -> PASS
|
||||
- `go test ./internal/adapter/claude ./internal/adapter/gemini ./internal/config` -> PASS
|
||||
- `go test ./internal/testsuite ./cmd/ds2api-tests` -> PASS
|
||||
- `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js` -> PASS
|
||||
5. Final full regression:
|
||||
- `go test ./... -count=1` -> PASS
|
||||
|
||||
## Notes
|
||||
|
||||
- This file records a historical baseline for refactor process tracking.
|
||||
- It is not intended to represent the current repository HEAD.
|
||||
- Frontend manual smoke for phase 6 still requires human execution and sign-off.
|
||||
@@ -1,6 +1,8 @@
|
||||
# Line gate targets for large-file decoupling refactor.
|
||||
# Default limit: 300 lines
|
||||
# Backend default limit: 300 lines
|
||||
# Frontend (webui/) default limit: 500 lines
|
||||
# Entry/facade limit: 120 lines (enforced in script)
|
||||
# Test files are ignored by the gate script.
|
||||
|
||||
internal/config/config.go
|
||||
internal/config/logger.go
|
||||
@@ -51,7 +53,6 @@ internal/adapter/openai/responses_stream_runtime_events.go
|
||||
internal/adapter/openai/responses_stream_runtime_toolcalls.go
|
||||
internal/adapter/openai/tool_sieve_state.go
|
||||
internal/adapter/openai/tool_sieve_core.go
|
||||
internal/adapter/openai/tool_sieve_incremental.go
|
||||
internal/adapter/openai/tool_sieve_jsonscan.go
|
||||
|
||||
internal/util/toolcalls_parse.go
|
||||
@@ -115,7 +116,6 @@ webui/src/app/useAdminAuth.js
|
||||
webui/src/app/useAdminConfig.js
|
||||
webui/src/layout/DashboardShell.jsx
|
||||
|
||||
webui/src/components/AccountManager.jsx
|
||||
webui/src/features/account/AccountManagerContainer.jsx
|
||||
webui/src/features/account/useAccountsData.js
|
||||
webui/src/features/account/useAccountActions.js
|
||||
@@ -125,14 +125,12 @@ webui/src/features/account/AccountsTable.jsx
|
||||
webui/src/features/account/AddKeyModal.jsx
|
||||
webui/src/features/account/AddAccountModal.jsx
|
||||
|
||||
webui/src/components/ApiTester.jsx
|
||||
webui/src/features/apiTester/ApiTesterContainer.jsx
|
||||
webui/src/features/apiTester/useApiTesterState.js
|
||||
webui/src/features/apiTester/useChatStreamClient.js
|
||||
webui/src/features/apiTester/ConfigPanel.jsx
|
||||
webui/src/features/apiTester/ChatPanel.jsx
|
||||
|
||||
webui/src/components/Settings.jsx
|
||||
webui/src/features/settings/SettingsContainer.jsx
|
||||
webui/src/features/settings/useSettingsForm.js
|
||||
webui/src/features/settings/settingsApi.js
|
||||
@@ -142,7 +140,6 @@ webui/src/features/settings/BehaviorSection.jsx
|
||||
webui/src/features/settings/ModelSection.jsx
|
||||
webui/src/features/settings/BackupSection.jsx
|
||||
|
||||
webui/src/components/VercelSync.jsx
|
||||
webui/src/features/vercel/VercelSyncContainer.jsx
|
||||
webui/src/features/vercel/useVercelSyncState.js
|
||||
webui/src/features/vercel/VercelSyncForm.jsx
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user