mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-04 16:35:27 +08:00
@@ -24,7 +24,7 @@ go test ./...
|
||||
```
|
||||
|
||||
```bash
|
||||
node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js
|
||||
node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js
|
||||
```
|
||||
|
||||
### 端到端测试 | End-to-End Tests
|
||||
@@ -39,7 +39,7 @@ node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js
|
||||
- `go test ./... -count=1`(单元测试)
|
||||
- `node --check api/chat-stream.js`(语法检查)
|
||||
- `node --check api/helpers/stream-tool-sieve.js`(语法检查)
|
||||
- `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js`(Node 流式拦截单测)
|
||||
- `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js`(Node 流式拦截 + compat 单测)
|
||||
- `npm run build --prefix webui`(WebUI 构建检查)
|
||||
|
||||
2. **隔离启动**:复制 `config.json` 到临时目录,启动独立服务进程
|
||||
|
||||
@@ -10,31 +10,14 @@ const {
|
||||
parseToolCalls,
|
||||
formatOpenAIStreamToolCalls,
|
||||
} = require('./helpers/stream-tool-sieve');
|
||||
const {
|
||||
BASE_HEADERS,
|
||||
SKIP_PATTERNS,
|
||||
SKIP_EXACT_PATHS,
|
||||
} = require('./shared/deepseek-constants');
|
||||
|
||||
const DEEPSEEK_COMPLETION_URL = 'https://chat.deepseek.com/api/v0/chat/completion';
|
||||
|
||||
const BASE_HEADERS = {
|
||||
Host: 'chat.deepseek.com',
|
||||
'User-Agent': 'DeepSeek/1.6.11 Android/35',
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
'x-client-platform': 'android',
|
||||
'x-client-version': '1.6.11',
|
||||
'x-client-locale': 'zh_CN',
|
||||
'accept-charset': 'UTF-8',
|
||||
};
|
||||
|
||||
const SKIP_PATTERNS = [
|
||||
'quasi_status',
|
||||
'elapsed_secs',
|
||||
'token_usage',
|
||||
'pending_fragment',
|
||||
'conversation_mode',
|
||||
'fragments/-1/status',
|
||||
'fragments/-2/status',
|
||||
'fragments/-3/status',
|
||||
];
|
||||
|
||||
module.exports = async function handler(req, res) {
|
||||
setCorsHeaders(res);
|
||||
if (req.method === 'OPTIONS') {
|
||||
@@ -725,7 +708,7 @@ function extractContentRecursive(items, defaultType) {
|
||||
}
|
||||
|
||||
function shouldSkipPath(pathValue) {
|
||||
if (pathValue === 'response/search_status') {
|
||||
if (SKIP_EXACT_PATHS.has(pathValue)) {
|
||||
return true;
|
||||
}
|
||||
for (const p of SKIP_PATTERNS) {
|
||||
@@ -808,7 +791,16 @@ function estimateTokens(text) {
|
||||
if (!t) {
|
||||
return 0;
|
||||
}
|
||||
const n = Math.floor(Array.from(t).length / 4);
|
||||
let asciiChars = 0;
|
||||
let nonASCIIChars = 0;
|
||||
for (const ch of Array.from(t)) {
|
||||
if (ch.charCodeAt(0) < 128) {
|
||||
asciiChars += 1;
|
||||
} else {
|
||||
nonASCIIChars += 1;
|
||||
}
|
||||
}
|
||||
const n = Math.floor(asciiChars / 4) + Math.floor((nonASCIIChars * 10 + 7) / 13);
|
||||
return n < 1 ? 1 : n;
|
||||
}
|
||||
|
||||
@@ -972,4 +964,5 @@ module.exports.__test = {
|
||||
resolveToolcallPolicy,
|
||||
normalizePreparedToolNames,
|
||||
boolDefaultTrue,
|
||||
estimateTokens,
|
||||
};
|
||||
|
||||
60
api/compat/js_compat_test.js
Normal file
60
api/compat/js_compat_test.js
Normal file
@@ -0,0 +1,60 @@
|
||||
'use strict';
|
||||
|
||||
const test = require('node:test');
|
||||
const assert = require('node:assert/strict');
|
||||
const fs = require('node:fs');
|
||||
const path = require('node:path');
|
||||
|
||||
const chatStream = require('../chat-stream');
|
||||
const { parseToolCalls } = require('../helpers/stream-tool-sieve');
|
||||
|
||||
const { parseChunkForContent, estimateTokens } = chatStream.__test;
|
||||
|
||||
const compatRoot = path.resolve(__dirname, '../../tests/compat');
|
||||
|
||||
function readJSON(filePath) {
|
||||
return JSON.parse(fs.readFileSync(filePath, 'utf8'));
|
||||
}
|
||||
|
||||
test('js compat: sse fixtures', () => {
|
||||
const fixtureDir = path.join(compatRoot, 'fixtures', 'sse_chunks');
|
||||
const expectedDir = path.join(compatRoot, 'expected');
|
||||
const files = fs.readdirSync(fixtureDir).filter((f) => f.endsWith('.json')).sort();
|
||||
assert.ok(files.length > 0);
|
||||
|
||||
for (const file of files) {
|
||||
const name = file.replace(/\.json$/i, '');
|
||||
const fixture = readJSON(path.join(fixtureDir, file));
|
||||
const expected = readJSON(path.join(expectedDir, `sse_${name}.json`));
|
||||
const got = parseChunkForContent(fixture.chunk, Boolean(fixture.thinking_enabled), fixture.current_type || 'text');
|
||||
assert.deepEqual(got.parts, expected.parts, `${name}: parts mismatch`);
|
||||
assert.equal(got.finished, expected.finished, `${name}: finished mismatch`);
|
||||
assert.equal(got.newType, expected.new_type, `${name}: newType mismatch`);
|
||||
}
|
||||
});
|
||||
|
||||
test('js compat: toolcall fixtures', () => {
|
||||
const fixtureDir = path.join(compatRoot, 'fixtures', 'toolcalls');
|
||||
const expectedDir = path.join(compatRoot, 'expected');
|
||||
const files = fs.readdirSync(fixtureDir).filter((f) => f.endsWith('.json')).sort();
|
||||
assert.ok(files.length > 0);
|
||||
|
||||
for (const file of files) {
|
||||
const name = file.replace(/\.json$/i, '');
|
||||
const fixture = readJSON(path.join(fixtureDir, file));
|
||||
const expected = readJSON(path.join(expectedDir, `toolcalls_${name}.json`));
|
||||
const got = parseToolCalls(fixture.text, fixture.tool_names || []);
|
||||
assert.deepEqual(got, expected.calls, `${name}: calls mismatch`);
|
||||
}
|
||||
});
|
||||
|
||||
test('js compat: token fixtures', () => {
|
||||
const fixture = readJSON(path.join(compatRoot, 'fixtures', 'token_cases.json'));
|
||||
const expected = readJSON(path.join(compatRoot, 'expected', 'token_cases.json'));
|
||||
const expectedByName = new Map(expected.cases.map((c) => [c.name, c.tokens]));
|
||||
for (const c of fixture.cases) {
|
||||
assert.ok(expectedByName.has(c.name), `missing expected case: ${c.name}`);
|
||||
const got = estimateTokens(c.text);
|
||||
assert.equal(got, expectedByName.get(c.name), `${c.name}: tokens mismatch`);
|
||||
}
|
||||
});
|
||||
66
api/shared/deepseek-constants.js
Normal file
66
api/shared/deepseek-constants.js
Normal file
@@ -0,0 +1,66 @@
|
||||
'use strict';
|
||||
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
|
||||
const DEFAULT_BASE_HEADERS = Object.freeze({
|
||||
Host: 'chat.deepseek.com',
|
||||
'User-Agent': 'DeepSeek/1.6.11 Android/35',
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
'x-client-platform': 'android',
|
||||
'x-client-version': '1.6.11',
|
||||
'x-client-locale': 'zh_CN',
|
||||
'accept-charset': 'UTF-8',
|
||||
});
|
||||
|
||||
const DEFAULT_SKIP_PATTERNS = Object.freeze([
|
||||
'quasi_status',
|
||||
'elapsed_secs',
|
||||
'token_usage',
|
||||
'pending_fragment',
|
||||
'conversation_mode',
|
||||
'fragments/-1/status',
|
||||
'fragments/-2/status',
|
||||
'fragments/-3/status',
|
||||
]);
|
||||
|
||||
const DEFAULT_SKIP_EXACT_PATHS = Object.freeze([
|
||||
'response/search_status',
|
||||
]);
|
||||
|
||||
function loadSharedConstants() {
|
||||
const sharedPath = path.resolve(__dirname, '../../internal/deepseek/constants_shared.json');
|
||||
try {
|
||||
const raw = fs.readFileSync(sharedPath, 'utf8');
|
||||
const parsed = JSON.parse(raw);
|
||||
const baseHeaders = parsed && typeof parsed.base_headers === 'object' && !Array.isArray(parsed.base_headers)
|
||||
? { ...DEFAULT_BASE_HEADERS, ...parsed.base_headers }
|
||||
: { ...DEFAULT_BASE_HEADERS };
|
||||
const skipPatterns = Array.isArray(parsed && parsed.skip_contains_patterns)
|
||||
? parsed.skip_contains_patterns.filter((v) => typeof v === 'string' && v !== '')
|
||||
: [...DEFAULT_SKIP_PATTERNS];
|
||||
const skipExactPaths = Array.isArray(parsed && parsed.skip_exact_paths)
|
||||
? parsed.skip_exact_paths.filter((v) => typeof v === 'string' && v !== '')
|
||||
: [...DEFAULT_SKIP_EXACT_PATHS];
|
||||
return {
|
||||
baseHeaders,
|
||||
skipPatterns,
|
||||
skipExactPaths,
|
||||
};
|
||||
} catch (_err) {
|
||||
return {
|
||||
baseHeaders: { ...DEFAULT_BASE_HEADERS },
|
||||
skipPatterns: [...DEFAULT_SKIP_PATTERNS],
|
||||
skipExactPaths: [...DEFAULT_SKIP_EXACT_PATHS],
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
const shared = loadSharedConstants();
|
||||
|
||||
module.exports = {
|
||||
BASE_HEADERS: Object.freeze(shared.baseHeaders),
|
||||
SKIP_PATTERNS: Object.freeze(shared.skipPatterns),
|
||||
SKIP_EXACT_PATHS: new Set(shared.skipExactPaths),
|
||||
};
|
||||
@@ -2,6 +2,8 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -28,10 +30,21 @@ func main() {
|
||||
Addr: "0.0.0.0:" + port,
|
||||
Handler: app.Router,
|
||||
}
|
||||
localURL := fmt.Sprintf("http://127.0.0.1:%s", port)
|
||||
lanIP := detectLANIPv4()
|
||||
lanURL := ""
|
||||
if lanIP != "" {
|
||||
lanURL = fmt.Sprintf("http://%s:%s", lanIP, port)
|
||||
}
|
||||
|
||||
// Start server in a goroutine so we can listen for shutdown signals.
|
||||
go func() {
|
||||
config.Logger.Info("starting ds2api", "port", port)
|
||||
if lanURL != "" {
|
||||
config.Logger.Info("starting ds2api", "bind", srv.Addr, "port", port, "local_url", localURL, "lan_url", lanURL, "lan_ip", lanIP)
|
||||
} else {
|
||||
config.Logger.Info("starting ds2api", "bind", srv.Addr, "port", port, "local_url", localURL)
|
||||
config.Logger.Warn("lan ip not detected; check active network interfaces")
|
||||
}
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
config.Logger.Error("server stopped unexpectedly", "error", err)
|
||||
os.Exit(1)
|
||||
@@ -54,3 +67,36 @@ func main() {
|
||||
}
|
||||
config.Logger.Info("server gracefully stopped")
|
||||
}
|
||||
|
||||
func detectLANIPv4() string {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, iface := range ifaces {
|
||||
if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 {
|
||||
continue
|
||||
}
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
var ip net.IP
|
||||
switch v := addr.(type) {
|
||||
case *net.IPNet:
|
||||
ip = v.IP
|
||||
case *net.IPAddr:
|
||||
ip = v.IP
|
||||
default:
|
||||
continue
|
||||
}
|
||||
ip = ip.To4()
|
||||
if ip == nil || !ip.IsPrivate() {
|
||||
continue
|
||||
}
|
||||
return ip.String()
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -20,13 +20,18 @@ type Pool struct {
|
||||
maxInflightPerAccount int
|
||||
recommendedConcurrency int
|
||||
maxQueueSize int
|
||||
globalMaxInflight int
|
||||
}
|
||||
|
||||
func NewPool(store *config.Store) *Pool {
|
||||
maxPer := 2
|
||||
if store != nil {
|
||||
maxPer = store.RuntimeAccountMaxInflight()
|
||||
}
|
||||
p := &Pool{
|
||||
store: store,
|
||||
inUse: map[string]int{},
|
||||
maxInflightPerAccount: maxInflightFromEnv(),
|
||||
maxInflightPerAccount: maxPer,
|
||||
}
|
||||
p.Reset()
|
||||
return p
|
||||
@@ -49,8 +54,18 @@ func (p *Pool) Reset() {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
if p.store != nil {
|
||||
p.maxInflightPerAccount = p.store.RuntimeAccountMaxInflight()
|
||||
} else {
|
||||
p.maxInflightPerAccount = maxInflightFromEnv()
|
||||
}
|
||||
recommended := defaultRecommendedConcurrency(len(ids), p.maxInflightPerAccount)
|
||||
queueLimit := maxQueueFromEnv(recommended)
|
||||
globalLimit := recommended
|
||||
if p.store != nil {
|
||||
queueLimit = p.store.RuntimeAccountMaxQueue(recommended)
|
||||
globalLimit = p.store.RuntimeGlobalMaxInflight(recommended)
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.drainWaitersLocked()
|
||||
@@ -58,10 +73,12 @@ func (p *Pool) Reset() {
|
||||
p.inUse = map[string]int{}
|
||||
p.recommendedConcurrency = recommended
|
||||
p.maxQueueSize = queueLimit
|
||||
p.globalMaxInflight = globalLimit
|
||||
config.Logger.Info(
|
||||
"[init_account_queue] initialized",
|
||||
"total", len(ids),
|
||||
"max_inflight_per_account", p.maxInflightPerAccount,
|
||||
"global_max_inflight", p.globalMaxInflight,
|
||||
"recommended_concurrency", p.recommendedConcurrency,
|
||||
"max_queue_size", p.maxQueueSize,
|
||||
)
|
||||
@@ -109,7 +126,7 @@ func (p *Pool) AcquireWait(ctx context.Context, target string, exclude map[strin
|
||||
|
||||
func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Account, bool) {
|
||||
if target != "" {
|
||||
if exclude[target] || p.inUse[target] >= p.maxInflightPerAccount {
|
||||
if exclude[target] || !p.canAcquireIDLocked(target) {
|
||||
return config.Account{}, false
|
||||
}
|
||||
acc, ok := p.store.FindAccount(target)
|
||||
@@ -133,7 +150,7 @@ func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Acc
|
||||
func (p *Pool) tryAcquire(exclude map[string]bool, requireToken bool) (config.Account, bool) {
|
||||
for i := 0; i < len(p.queue); i++ {
|
||||
id := p.queue[i]
|
||||
if exclude[id] || p.inUse[id] >= p.maxInflightPerAccount {
|
||||
if exclude[id] || !p.canAcquireIDLocked(id) {
|
||||
continue
|
||||
}
|
||||
acc, ok := p.store.FindAccount(id)
|
||||
@@ -205,12 +222,35 @@ func (p *Pool) Status() map[string]any {
|
||||
"available_accounts": available,
|
||||
"in_use_accounts": inUseAccounts,
|
||||
"max_inflight_per_account": p.maxInflightPerAccount,
|
||||
"global_max_inflight": p.globalMaxInflight,
|
||||
"recommended_concurrency": p.recommendedConcurrency,
|
||||
"waiting": len(p.waiters),
|
||||
"max_queue_size": p.maxQueueSize,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pool) ApplyRuntimeLimits(maxInflightPerAccount, maxQueueSize, globalMaxInflight int) {
|
||||
if maxInflightPerAccount <= 0 {
|
||||
maxInflightPerAccount = 1
|
||||
}
|
||||
if maxQueueSize < 0 {
|
||||
maxQueueSize = 0
|
||||
}
|
||||
if globalMaxInflight <= 0 {
|
||||
globalMaxInflight = maxInflightPerAccount * len(p.store.Accounts())
|
||||
if globalMaxInflight <= 0 {
|
||||
globalMaxInflight = maxInflightPerAccount
|
||||
}
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.maxInflightPerAccount = maxInflightPerAccount
|
||||
p.maxQueueSize = maxQueueSize
|
||||
p.globalMaxInflight = globalMaxInflight
|
||||
p.recommendedConcurrency = defaultRecommendedConcurrency(len(p.queue), p.maxInflightPerAccount)
|
||||
p.notifyWaiterLocked()
|
||||
}
|
||||
|
||||
func maxInflightFromEnv() int {
|
||||
for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} {
|
||||
raw := strings.TrimSpace(os.Getenv(key))
|
||||
@@ -300,3 +340,24 @@ func maxQueueFromEnv(defaultSize int) int {
|
||||
}
|
||||
return defaultSize
|
||||
}
|
||||
|
||||
func (p *Pool) canAcquireIDLocked(accountID string) bool {
|
||||
if accountID == "" {
|
||||
return false
|
||||
}
|
||||
if p.inUse[accountID] >= p.maxInflightPerAccount {
|
||||
return false
|
||||
}
|
||||
if p.globalMaxInflight > 0 && p.currentInUseLocked() >= p.globalMaxInflight {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *Pool) currentInUseLocked() int {
|
||||
total := 0
|
||||
for _, n := range p.inUse {
|
||||
total += n
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
11
internal/adapter/claude/convert.go
Normal file
11
internal/adapter/claude/convert.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"ds2api/internal/claudeconv"
|
||||
)
|
||||
|
||||
const defaultClaudeModel = "claude-sonnet-4-5"
|
||||
|
||||
func convertClaudeToDeepSeek(claudeReq map[string]any, store ConfigReader) map[string]any {
|
||||
return claudeconv.ConvertClaudeToDeepSeek(claudeReq, store, defaultClaudeModel)
|
||||
}
|
||||
29
internal/adapter/claude/deps.go
Normal file
29
internal/adapter/claude/deps.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
)
|
||||
|
||||
type AuthResolver interface {
|
||||
Determine(req *http.Request) (*auth.RequestAuth, error)
|
||||
Release(a *auth.RequestAuth)
|
||||
}
|
||||
|
||||
type DeepSeekCaller interface {
|
||||
CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
||||
GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
||||
CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
|
||||
}
|
||||
|
||||
type ConfigReader interface {
|
||||
ClaudeMapping() map[string]string
|
||||
}
|
||||
|
||||
var _ AuthResolver = (*auth.Resolver)(nil)
|
||||
var _ DeepSeekCaller = (*deepseek.Client)(nil)
|
||||
var _ ConfigReader = (*config.Store)(nil)
|
||||
33
internal/adapter/claude/deps_injection_test.go
Normal file
33
internal/adapter/claude/deps_injection_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package claude
|
||||
|
||||
import "testing"
|
||||
|
||||
type mockClaudeConfig struct {
|
||||
m map[string]string
|
||||
}
|
||||
|
||||
func (m mockClaudeConfig) ClaudeMapping() map[string]string { return m.m }
|
||||
|
||||
func TestNormalizeClaudeRequestUsesConfigInterfaceMapping(t *testing.T) {
|
||||
req := map[string]any{
|
||||
"model": "claude-opus-4-6",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
}
|
||||
out, err := normalizeClaudeRequest(mockClaudeConfig{
|
||||
m: map[string]string{
|
||||
"fast": "deepseek-chat",
|
||||
"slow": "deepseek-reasoner-search",
|
||||
},
|
||||
}, req)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeClaudeRequest error: %v", err)
|
||||
}
|
||||
if out.Standard.ResolvedModel != "deepseek-reasoner-search" {
|
||||
t.Fatalf("resolved model mismatch: got=%q", out.Standard.ResolvedModel)
|
||||
}
|
||||
if !out.Standard.Thinking || !out.Standard.Search {
|
||||
t.Fatalf("unexpected flags: thinking=%v search=%v", out.Standard.Thinking, out.Standard.Search)
|
||||
}
|
||||
}
|
||||
@@ -13,7 +13,9 @@ import (
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
claudefmt "ds2api/internal/format/claude"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
@@ -21,9 +23,9 @@ import (
|
||||
var writeJSON = util.WriteJSON
|
||||
|
||||
type Handler struct {
|
||||
Store *config.Store
|
||||
Auth *auth.Resolver
|
||||
DS *deepseek.Client
|
||||
Store ConfigReader
|
||||
Auth AuthResolver
|
||||
DS DeepSeekCaller
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -98,7 +100,7 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
result := sse.CollectStream(resp, stdReq.Thinking, true)
|
||||
respBody := util.BuildClaudeMessageResponse(
|
||||
respBody := claudefmt.BuildMessageResponse(
|
||||
fmt.Sprintf("msg_%d", time.Now().UnixNano()),
|
||||
stdReq.ResponseModel,
|
||||
norm.NormalizedMessages,
|
||||
@@ -169,279 +171,38 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
|
||||
if !canFlush {
|
||||
config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered")
|
||||
}
|
||||
send := func(event string, v any) {
|
||||
b, _ := json.Marshal(v)
|
||||
_, _ = w.Write([]byte("event: "))
|
||||
_, _ = w.Write([]byte(event))
|
||||
_, _ = w.Write([]byte("\n"))
|
||||
_, _ = w.Write([]byte("data: "))
|
||||
_, _ = w.Write(b)
|
||||
_, _ = w.Write([]byte("\n\n"))
|
||||
if canFlush {
|
||||
_ = rc.Flush()
|
||||
}
|
||||
}
|
||||
sendError := func(message string) {
|
||||
msg := strings.TrimSpace(message)
|
||||
if msg == "" {
|
||||
msg = "upstream stream error"
|
||||
}
|
||||
send("error", map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"type": "api_error",
|
||||
"message": msg,
|
||||
"code": "internal_error",
|
||||
"param": nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
messageID := fmt.Sprintf("msg_%d", time.Now().UnixNano())
|
||||
inputTokens := util.EstimateTokens(fmt.Sprintf("%v", messages))
|
||||
send("message_start", map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": messageID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []any{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0},
|
||||
},
|
||||
})
|
||||
streamRuntime := newClaudeStreamRuntime(
|
||||
w,
|
||||
rc,
|
||||
canFlush,
|
||||
model,
|
||||
messages,
|
||||
thinkingEnabled,
|
||||
searchEnabled,
|
||||
toolNames,
|
||||
)
|
||||
streamRuntime.sendMessageStart()
|
||||
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
}
|
||||
parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType)
|
||||
bufferToolContent := len(toolNames) > 0
|
||||
hasContent := false
|
||||
lastContent := time.Now()
|
||||
keepaliveCount := 0
|
||||
|
||||
thinking := strings.Builder{}
|
||||
text := strings.Builder{}
|
||||
|
||||
nextBlockIndex := 0
|
||||
thinkingBlockOpen := false
|
||||
thinkingBlockIndex := -1
|
||||
textBlockOpen := false
|
||||
textBlockIndex := -1
|
||||
ended := false
|
||||
|
||||
closeThinkingBlock := func() {
|
||||
if !thinkingBlockOpen {
|
||||
return
|
||||
}
|
||||
send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": thinkingBlockIndex,
|
||||
})
|
||||
thinkingBlockOpen = false
|
||||
thinkingBlockIndex = -1
|
||||
}
|
||||
closeTextBlock := func() {
|
||||
if !textBlockOpen {
|
||||
return
|
||||
}
|
||||
send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": textBlockIndex,
|
||||
})
|
||||
textBlockOpen = false
|
||||
textBlockIndex = -1
|
||||
}
|
||||
|
||||
finalize := func(stopReason string) {
|
||||
if ended {
|
||||
return
|
||||
}
|
||||
ended = true
|
||||
|
||||
closeThinkingBlock()
|
||||
closeTextBlock()
|
||||
|
||||
finalThinking := thinking.String()
|
||||
finalText := text.String()
|
||||
|
||||
if bufferToolContent {
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
if len(detected) > 0 {
|
||||
stopReason = "tool_use"
|
||||
for i, tc := range detected {
|
||||
idx := nextBlockIndex + i
|
||||
send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": idx,
|
||||
"content_block": map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx),
|
||||
"name": tc.Name,
|
||||
"input": tc.Input,
|
||||
},
|
||||
})
|
||||
send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": idx,
|
||||
})
|
||||
}
|
||||
nextBlockIndex += len(detected)
|
||||
} else if finalText != "" {
|
||||
idx := nextBlockIndex
|
||||
nextBlockIndex++
|
||||
send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": idx,
|
||||
"content_block": map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
})
|
||||
send("content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": idx,
|
||||
"delta": map[string]any{
|
||||
"type": "text_delta",
|
||||
"text": finalText,
|
||||
},
|
||||
})
|
||||
send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": idx,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
|
||||
send("message_delta", map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]any{
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
})
|
||||
send("message_stop", map[string]any{"type": "message_stop"})
|
||||
}
|
||||
|
||||
pingTicker := time.NewTicker(claudeStreamPingInterval)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case <-pingTicker.C:
|
||||
if !hasContent {
|
||||
keepaliveCount++
|
||||
if keepaliveCount >= claudeStreamMaxKeepaliveCnt {
|
||||
finalize("end_turn")
|
||||
return
|
||||
}
|
||||
}
|
||||
if hasContent && time.Since(lastContent) > claudeStreamIdleTimeout {
|
||||
finalize("end_turn")
|
||||
return
|
||||
}
|
||||
send("ping", map[string]any{"type": "ping"})
|
||||
case parsed, ok := <-parsedLines:
|
||||
if !ok {
|
||||
if err := <-done; err != nil {
|
||||
sendError(err.Error())
|
||||
return
|
||||
}
|
||||
finalize("end_turn")
|
||||
return
|
||||
}
|
||||
if !parsed.Parsed {
|
||||
continue
|
||||
}
|
||||
if parsed.ErrorMessage != "" {
|
||||
sendError(parsed.ErrorMessage)
|
||||
return
|
||||
}
|
||||
if parsed.Stop {
|
||||
finalize("end_turn")
|
||||
return
|
||||
}
|
||||
|
||||
for _, p := range parsed.Parts {
|
||||
if p.Text == "" {
|
||||
continue
|
||||
}
|
||||
if p.Type != "thinking" && searchEnabled && sse.IsCitation(p.Text) {
|
||||
continue
|
||||
}
|
||||
|
||||
hasContent = true
|
||||
lastContent = time.Now()
|
||||
keepaliveCount = 0
|
||||
|
||||
if p.Type == "thinking" {
|
||||
if !thinkingEnabled {
|
||||
continue
|
||||
}
|
||||
thinking.WriteString(p.Text)
|
||||
closeTextBlock()
|
||||
if !thinkingBlockOpen {
|
||||
thinkingBlockIndex = nextBlockIndex
|
||||
nextBlockIndex++
|
||||
send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": thinkingBlockIndex,
|
||||
"content_block": map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "",
|
||||
},
|
||||
})
|
||||
thinkingBlockOpen = true
|
||||
}
|
||||
send("content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": thinkingBlockIndex,
|
||||
"delta": map[string]any{
|
||||
"type": "thinking_delta",
|
||||
"thinking": p.Text,
|
||||
},
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
text.WriteString(p.Text)
|
||||
if bufferToolContent {
|
||||
continue
|
||||
}
|
||||
closeThinkingBlock()
|
||||
if !textBlockOpen {
|
||||
textBlockIndex = nextBlockIndex
|
||||
nextBlockIndex++
|
||||
send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": textBlockIndex,
|
||||
"content_block": map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
})
|
||||
textBlockOpen = true
|
||||
}
|
||||
send("content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": textBlockIndex,
|
||||
"delta": map[string]any{
|
||||
"type": "text_delta",
|
||||
"text": p.Text,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
|
||||
Context: r.Context(),
|
||||
Body: resp.Body,
|
||||
ThinkingEnabled: thinkingEnabled,
|
||||
InitialType: initialType,
|
||||
KeepAliveInterval: claudeStreamPingInterval,
|
||||
IdleTimeout: claudeStreamIdleTimeout,
|
||||
MaxKeepAliveNoInput: claudeStreamMaxKeepaliveCnt,
|
||||
}, streamengine.ConsumeHooks{
|
||||
OnKeepAlive: func() {
|
||||
streamRuntime.sendPing()
|
||||
},
|
||||
OnParsed: streamRuntime.onParsed,
|
||||
OnFinalize: streamRuntime.onFinalize,
|
||||
})
|
||||
}
|
||||
|
||||
func writeClaudeError(w http.ResponseWriter, status int, message string) {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
@@ -13,7 +14,7 @@ type claudeNormalizedRequest struct {
|
||||
NormalizedMessages []any
|
||||
}
|
||||
|
||||
func normalizeClaudeRequest(store *config.Store, req map[string]any) (claudeNormalizedRequest, error) {
|
||||
func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNormalizedRequest, error) {
|
||||
model, _ := req["model"].(string)
|
||||
messagesRaw, _ := req["messages"].([]any)
|
||||
if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 {
|
||||
@@ -30,14 +31,14 @@ func normalizeClaudeRequest(store *config.Store, req map[string]any) (claudeNorm
|
||||
payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalizedMessages...)
|
||||
}
|
||||
|
||||
dsPayload := util.ConvertClaudeToDeepSeek(payload, store)
|
||||
dsPayload := convertClaudeToDeepSeek(payload, store)
|
||||
dsModel, _ := dsPayload["model"].(string)
|
||||
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel)
|
||||
if !ok {
|
||||
thinkingEnabled = false
|
||||
searchEnabled = false
|
||||
}
|
||||
finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
|
||||
finalPrompt := deepseek.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
|
||||
toolNames := extractClaudeToolNames(toolsRequested)
|
||||
|
||||
return claudeNormalizedRequest{
|
||||
|
||||
308
internal/adapter/claude/stream_runtime.go
Normal file
308
internal/adapter/claude/stream_runtime.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
type claudeStreamRuntime struct {
|
||||
w http.ResponseWriter
|
||||
rc *http.ResponseController
|
||||
canFlush bool
|
||||
|
||||
model string
|
||||
toolNames []string
|
||||
messages []any
|
||||
|
||||
thinkingEnabled bool
|
||||
searchEnabled bool
|
||||
bufferToolContent bool
|
||||
|
||||
messageID string
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
|
||||
nextBlockIndex int
|
||||
thinkingBlockOpen bool
|
||||
thinkingBlockIndex int
|
||||
textBlockOpen bool
|
||||
textBlockIndex int
|
||||
ended bool
|
||||
upstreamErr string
|
||||
}
|
||||
|
||||
func newClaudeStreamRuntime(
|
||||
w http.ResponseWriter,
|
||||
rc *http.ResponseController,
|
||||
canFlush bool,
|
||||
model string,
|
||||
messages []any,
|
||||
thinkingEnabled bool,
|
||||
searchEnabled bool,
|
||||
toolNames []string,
|
||||
) *claudeStreamRuntime {
|
||||
return &claudeStreamRuntime{
|
||||
w: w,
|
||||
rc: rc,
|
||||
canFlush: canFlush,
|
||||
model: model,
|
||||
messages: messages,
|
||||
thinkingEnabled: thinkingEnabled,
|
||||
searchEnabled: searchEnabled,
|
||||
bufferToolContent: len(toolNames) > 0,
|
||||
toolNames: toolNames,
|
||||
messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()),
|
||||
thinkingBlockIndex: -1,
|
||||
textBlockIndex: -1,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) send(event string, v any) {
|
||||
b, _ := json.Marshal(v)
|
||||
_, _ = s.w.Write([]byte("event: "))
|
||||
_, _ = s.w.Write([]byte(event))
|
||||
_, _ = s.w.Write([]byte("\n"))
|
||||
_, _ = s.w.Write([]byte("data: "))
|
||||
_, _ = s.w.Write(b)
|
||||
_, _ = s.w.Write([]byte("\n\n"))
|
||||
if s.canFlush {
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) sendError(message string) {
|
||||
msg := strings.TrimSpace(message)
|
||||
if msg == "" {
|
||||
msg = "upstream stream error"
|
||||
}
|
||||
s.send("error", map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"type": "api_error",
|
||||
"message": msg,
|
||||
"code": "internal_error",
|
||||
"param": nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) sendPing() {
|
||||
s.send("ping", map[string]any{"type": "ping"})
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) sendMessageStart() {
|
||||
inputTokens := util.EstimateTokens(fmt.Sprintf("%v", s.messages))
|
||||
s.send("message_start", map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": s.messageID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": s.model,
|
||||
"content": []any{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) closeThinkingBlock() {
|
||||
if !s.thinkingBlockOpen {
|
||||
return
|
||||
}
|
||||
s.send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": s.thinkingBlockIndex,
|
||||
})
|
||||
s.thinkingBlockOpen = false
|
||||
s.thinkingBlockIndex = -1
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) closeTextBlock() {
|
||||
if !s.textBlockOpen {
|
||||
return
|
||||
}
|
||||
s.send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": s.textBlockIndex,
|
||||
})
|
||||
s.textBlockOpen = false
|
||||
s.textBlockIndex = -1
|
||||
}
|
||||
|
||||
// finalize closes any open content blocks, emits buffered tool calls or
// buffered text (when tool buffering was active), then sends the trailing
// message_delta (stop reason + output token estimate) and message_stop.
// Idempotent: repeated calls after the first are no-ops.
func (s *claudeStreamRuntime) finalize(stopReason string) {
	if s.ended {
		return
	}
	s.ended = true

	s.closeThinkingBlock()
	s.closeTextBlock()

	finalThinking := s.thinking.String()
	finalText := s.text.String()

	if s.bufferToolContent {
		// Text was withheld during streaming; parse it now for tool calls.
		detected := util.ParseToolCalls(finalText, s.toolNames)
		if len(detected) > 0 {
			// Tool calls take precedence over the caller-supplied stop reason.
			stopReason = "tool_use"
			for i, tc := range detected {
				idx := s.nextBlockIndex + i
				// Each tool call is emitted as a complete start/stop block pair.
				s.send("content_block_start", map[string]any{
					"type":  "content_block_start",
					"index": idx,
					"content_block": map[string]any{
						"type":  "tool_use",
						"id":    fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx),
						"name":  tc.Name,
						"input": tc.Input,
					},
				})
				s.send("content_block_stop", map[string]any{
					"type":  "content_block_stop",
					"index": idx,
				})
			}
			s.nextBlockIndex += len(detected)
		} else if finalText != "" {
			// No tool calls detected: flush the withheld text as one text block.
			idx := s.nextBlockIndex
			s.nextBlockIndex++
			s.send("content_block_start", map[string]any{
				"type":  "content_block_start",
				"index": idx,
				"content_block": map[string]any{
					"type": "text",
					"text": "",
				},
			})
			s.send("content_block_delta", map[string]any{
				"type":  "content_block_delta",
				"index": idx,
				"delta": map[string]any{
					"type": "text_delta",
					"text": finalText,
				},
			})
			s.send("content_block_stop", map[string]any{
				"type":  "content_block_stop",
				"index": idx,
			})
		}
	}

	// Output tokens are estimated from the accumulated thinking + text.
	outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
	s.send("message_delta", map[string]any{
		"type": "message_delta",
		"delta": map[string]any{
			"stop_reason":   stopReason,
			"stop_sequence": nil,
		},
		"usage": map[string]any{
			"output_tokens": outputTokens,
		},
	})
	s.send("message_stop", map[string]any{"type": "message_stop"})
}
|
||||
|
||||
// onParsed consumes one parsed upstream SSE line and translates its parts
// into Anthropic content_block events. It returns a decision telling the
// stream engine whether to stop and whether any content was observed.
func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
	if !parsed.Parsed {
		return streamengine.ParsedDecision{}
	}
	if parsed.ErrorMessage != "" {
		// Stash the upstream error text; onFinalize turns it into an error event.
		s.upstreamErr = parsed.ErrorMessage
		return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("upstream_error")}
	}
	if parsed.Stop {
		return streamengine.ParsedDecision{Stop: true}
	}

	contentSeen := false
	for _, p := range parsed.Parts {
		if p.Text == "" {
			continue
		}
		// Citation fragments from search results are dropped from non-thinking
		// output when search mode is active.
		if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
			continue
		}
		contentSeen = true

		if p.Type == "thinking" {
			if !s.thinkingEnabled {
				continue
			}
			s.thinking.WriteString(p.Text)
			// Thinking and text blocks are mutually exclusive on the wire.
			s.closeTextBlock()
			if !s.thinkingBlockOpen {
				s.thinkingBlockIndex = s.nextBlockIndex
				s.nextBlockIndex++
				s.send("content_block_start", map[string]any{
					"type":  "content_block_start",
					"index": s.thinkingBlockIndex,
					"content_block": map[string]any{
						"type":     "thinking",
						"thinking": "",
					},
				})
				s.thinkingBlockOpen = true
			}
			s.send("content_block_delta", map[string]any{
				"type":  "content_block_delta",
				"index": s.thinkingBlockIndex,
				"delta": map[string]any{
					"type":     "thinking_delta",
					"thinking": p.Text,
				},
			})
			continue
		}

		s.text.WriteString(p.Text)
		// With tool buffering on, text is accumulated only; finalize decides
		// whether it becomes tool_use blocks or a single text block.
		if s.bufferToolContent {
			continue
		}
		s.closeThinkingBlock()
		if !s.textBlockOpen {
			s.textBlockIndex = s.nextBlockIndex
			s.nextBlockIndex++
			s.send("content_block_start", map[string]any{
				"type":  "content_block_start",
				"index": s.textBlockIndex,
				"content_block": map[string]any{
					"type": "text",
					"text": "",
				},
			})
			s.textBlockOpen = true
		}
		s.send("content_block_delta", map[string]any{
			"type":  "content_block_delta",
			"index": s.textBlockIndex,
			"delta": map[string]any{
				"type": "text_delta",
				"text": p.Text,
			},
		})
	}

	return streamengine.ParsedDecision{ContentSeen: contentSeen}
}
|
||||
|
||||
func (s *claudeStreamRuntime) onFinalize(reason streamengine.StopReason, scannerErr error) {
|
||||
if string(reason) == "upstream_error" {
|
||||
s.sendError(s.upstreamErr)
|
||||
return
|
||||
}
|
||||
if scannerErr != nil {
|
||||
s.sendError(scannerErr.Error())
|
||||
return
|
||||
}
|
||||
s.finalize("end_turn")
|
||||
}
|
||||
237
internal/adapter/openai/chat_stream_runtime.go
Normal file
237
internal/adapter/openai/chat_stream_runtime.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// chatStreamRuntime holds per-request state for streaming an OpenAI
// chat-completions response translated from the upstream SSE feed.
type chatStreamRuntime struct {
	w        http.ResponseWriter
	rc       *http.ResponseController
	canFlush bool // true when the ResponseWriter supports Flush

	completionID string
	created      int64
	model        string
	finalPrompt  string // prompt text used for the final usage estimate
	toolNames    []string

	thinkingEnabled bool
	searchEnabled   bool

	firstChunkSent      bool // whether the role:"assistant" delta was already emitted
	bufferToolContent   bool // route text through the tool-call sieve instead of emitting directly
	emitEarlyToolDeltas bool // forward incremental tool-call deltas before a call completes
	toolCallsEmitted    bool // at least one tool-call chunk has been sent

	toolSieve         toolStreamSieveState
	streamToolCallIDs map[int]string // stable per-index tool-call IDs across delta frames
	thinking          strings.Builder
	text              strings.Builder
}
|
||||
|
||||
// newChatStreamRuntime builds the per-request streaming state for the OpenAI
// chat-completions surface. Unlike the Claude runtime, tool buffering and
// early tool-delta emission are decided by the caller and passed in.
func newChatStreamRuntime(
	w http.ResponseWriter,
	rc *http.ResponseController,
	canFlush bool,
	completionID string,
	created int64,
	model string,
	finalPrompt string,
	thinkingEnabled bool,
	searchEnabled bool,
	toolNames []string,
	bufferToolContent bool,
	emitEarlyToolDeltas bool,
) *chatStreamRuntime {
	return &chatStreamRuntime{
		w:                   w,
		rc:                  rc,
		canFlush:            canFlush,
		completionID:        completionID,
		created:             created,
		model:               model,
		finalPrompt:         finalPrompt,
		toolNames:           toolNames,
		thinkingEnabled:     thinkingEnabled,
		searchEnabled:       searchEnabled,
		bufferToolContent:   bufferToolContent,
		emitEarlyToolDeltas: emitEarlyToolDeltas,
		streamToolCallIDs:   map[int]string{},
	}
}
|
||||
|
||||
func (s *chatStreamRuntime) sendKeepAlive() {
|
||||
if !s.canFlush {
|
||||
return
|
||||
}
|
||||
_, _ = s.w.Write([]byte(": keep-alive\n\n"))
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
|
||||
func (s *chatStreamRuntime) sendChunk(v any) {
|
||||
b, _ := json.Marshal(v)
|
||||
_, _ = s.w.Write([]byte("data: "))
|
||||
_, _ = s.w.Write(b)
|
||||
_, _ = s.w.Write([]byte("\n\n"))
|
||||
if s.canFlush {
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *chatStreamRuntime) sendDone() {
|
||||
_, _ = s.w.Write([]byte("data: [DONE]\n\n"))
|
||||
if s.canFlush {
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// finalize emits any buffered tool calls or withheld sieve content, then the
// finish-reason chunk with a usage estimate, and finally the [DONE] frame.
// finishReason is overridden to "tool_calls" whenever tool calls were
// detected in the full text or already streamed incrementally.
func (s *chatStreamRuntime) finalize(finishReason string) {
	finalThinking := s.thinking.String()
	finalText := s.text.String()
	detected := util.ParseToolCalls(finalText, s.toolNames)
	if len(detected) > 0 && !s.toolCallsEmitted {
		// Tool calls found in the accumulated text and none streamed yet:
		// emit them all in a single delta chunk.
		finishReason = "tool_calls"
		delta := map[string]any{
			"tool_calls": util.FormatOpenAIStreamToolCalls(detected),
		}
		if !s.firstChunkSent {
			delta["role"] = "assistant"
			s.firstChunkSent = true
		}
		s.sendChunk(openaifmt.BuildChatStreamChunk(
			s.completionID,
			s.created,
			s.model,
			[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, delta)},
			nil,
		))
	} else if s.bufferToolContent {
		// No (new) tool calls: drain any text still held by the sieve so the
		// client receives the complete content.
		for _, evt := range flushToolSieve(&s.toolSieve, s.toolNames) {
			if evt.Content == "" {
				continue
			}
			delta := map[string]any{
				"content": evt.Content,
			}
			if !s.firstChunkSent {
				delta["role"] = "assistant"
				s.firstChunkSent = true
			}
			s.sendChunk(openaifmt.BuildChatStreamChunk(
				s.completionID,
				s.created,
				s.model,
				[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, delta)},
				nil,
			))
		}
	}

	// Covers the case where tool calls were streamed earlier (toolCallsEmitted)
	// and the first branch above was skipped.
	if len(detected) > 0 || s.toolCallsEmitted {
		finishReason = "tool_calls"
	}
	s.sendChunk(openaifmt.BuildChatStreamChunk(
		s.completionID,
		s.created,
		s.model,
		[]map[string]any{openaifmt.BuildChatStreamFinishChoice(0, finishReason)},
		openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText),
	))
	s.sendDone()
}
|
||||
|
||||
// onParsed consumes one parsed upstream SSE line and translates its parts
// into OpenAI chat-completions delta choices, sending them as a single chunk.
// Returns a decision telling the stream engine whether to stop and whether
// content was observed.
func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
	if !parsed.Parsed {
		return streamengine.ParsedDecision{}
	}
	if parsed.ContentFilter || parsed.ErrorMessage != "" {
		return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("content_filter")}
	}
	if parsed.Stop {
		return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReasonHandlerRequested}
	}

	newChoices := make([]map[string]any, 0, len(parsed.Parts))
	contentSeen := false
	for _, p := range parsed.Parts {
		// NOTE(review): unlike the Claude runtime, this citation filter also
		// applies to thinking parts — confirm that is intentional.
		if s.searchEnabled && sse.IsCitation(p.Text) {
			continue
		}
		if p.Text == "" {
			continue
		}
		contentSeen = true
		delta := map[string]any{}
		if !s.firstChunkSent {
			delta["role"] = "assistant"
			s.firstChunkSent = true
		}
		if p.Type == "thinking" {
			if s.thinkingEnabled {
				s.thinking.WriteString(p.Text)
				delta["reasoning_content"] = p.Text
			}
		} else {
			s.text.WriteString(p.Text)
			if !s.bufferToolContent {
				// No tool sieve: forward text directly.
				delta["content"] = p.Text
			} else {
				// Run the text through the sieve; it yields plain content and/or
				// tool-call events, each sent as its own delta choice.
				events := processToolSieveChunk(&s.toolSieve, p.Text, s.toolNames)
				for _, evt := range events {
					if len(evt.ToolCallDeltas) > 0 {
						// Incremental (partial) tool-call arguments.
						if !s.emitEarlyToolDeltas {
							continue
						}
						s.toolCallsEmitted = true
						tcDelta := map[string]any{
							"tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs),
						}
						if !s.firstChunkSent {
							tcDelta["role"] = "assistant"
							s.firstChunkSent = true
						}
						newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, tcDelta))
						continue
					}
					if len(evt.ToolCalls) > 0 {
						// Complete tool calls detected by the sieve.
						s.toolCallsEmitted = true
						tcDelta := map[string]any{
							"tool_calls": util.FormatOpenAIStreamToolCalls(evt.ToolCalls),
						}
						if !s.firstChunkSent {
							tcDelta["role"] = "assistant"
							s.firstChunkSent = true
						}
						newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, tcDelta))
						continue
					}
					if evt.Content != "" {
						// Text the sieve released as ordinary content.
						contentDelta := map[string]any{
							"content": evt.Content,
						}
						if !s.firstChunkSent {
							contentDelta["role"] = "assistant"
							s.firstChunkSent = true
						}
						newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, contentDelta))
					}
				}
			}
		}
		// delta may hold role and/or reasoning/content; empty deltas (e.g.
		// suppressed thinking after the first chunk) are dropped.
		if len(delta) > 0 {
			newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, delta))
		}
	}

	if len(newChoices) > 0 {
		s.sendChunk(openaifmt.BuildChatStreamChunk(s.completionID, s.created, s.model, newChoices, nil))
	}
	return streamengine.ParsedDecision{ContentSeen: contentSeen}
}
|
||||
35
internal/adapter/openai/deps.go
Normal file
35
internal/adapter/openai/deps.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
)
|
||||
|
||||
// AuthResolver selects and releases the authentication context used to serve
// an incoming request.
type AuthResolver interface {
	Determine(req *http.Request) (*auth.RequestAuth, error)
	DetermineCaller(req *http.Request) (*auth.RequestAuth, error)
	Release(a *auth.RequestAuth)
}

// DeepSeekCaller is the subset of the upstream DeepSeek client this adapter
// depends on: session creation, proof-of-work retrieval, and the completion
// call itself.
type DeepSeekCaller interface {
	CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
	GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
	CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
}

// ConfigReader exposes the configuration values read by this adapter.
type ConfigReader interface {
	ModelAliases() map[string]string
	CompatWideInputStrictOutput() bool
	ToolcallMode() string
	ToolcallEarlyEmitConfidence() string
	ResponsesStoreTTLSeconds() int
	EmbeddingsProvider() string
}

// Compile-time assertions that the concrete implementations satisfy the
// interfaces above.
var _ AuthResolver = (*auth.Resolver)(nil)
var _ DeepSeekCaller = (*deepseek.Client)(nil)
var _ ConfigReader = (*config.Store)(nil)
||||
70
internal/adapter/openai/deps_injection_test.go
Normal file
70
internal/adapter/openai/deps_injection_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package openai
|
||||
|
||||
import "testing"
|
||||
|
||||
// mockOpenAIConfig is a test double implementing ConfigReader with fixed
// values, letting handler code be exercised without a real config store.
type mockOpenAIConfig struct {
	aliases      map[string]string
	wideInput    bool
	toolMode     string
	earlyEmit    string
	responsesTTL int
	embedProv    string
}

func (m mockOpenAIConfig) ModelAliases() map[string]string { return m.aliases }
func (m mockOpenAIConfig) CompatWideInputStrictOutput() bool {
	return m.wideInput
}
func (m mockOpenAIConfig) ToolcallMode() string                { return m.toolMode }
func (m mockOpenAIConfig) ToolcallEarlyEmitConfidence() string { return m.earlyEmit }
func (m mockOpenAIConfig) ResponsesStoreTTLSeconds() int       { return m.responsesTTL }
func (m mockOpenAIConfig) EmbeddingsProvider() string          { return m.embedProv }
|
||||
|
||||
// TestNormalizeOpenAIChatRequestWithConfigInterface verifies that chat
// request normalization resolves model aliases through the ConfigReader
// interface and derives the search/thinking flags from the resolved model.
func TestNormalizeOpenAIChatRequestWithConfigInterface(t *testing.T) {
	cfg := mockOpenAIConfig{
		aliases: map[string]string{
			// Alias maps to a "-search" model, so Search should be true.
			"my-model": "deepseek-chat-search",
		},
		wideInput: true,
	}
	req := map[string]any{
		"model":    "my-model",
		"messages": []any{map[string]any{"role": "user", "content": "hello"}},
	}
	out, err := normalizeOpenAIChatRequest(cfg, req)
	if err != nil {
		t.Fatalf("normalizeOpenAIChatRequest error: %v", err)
	}
	if out.ResolvedModel != "deepseek-chat-search" {
		t.Fatalf("resolved model mismatch: got=%q", out.ResolvedModel)
	}
	if !out.Search || out.Thinking {
		t.Fatalf("unexpected model flags: thinking=%v search=%v", out.Thinking, out.Search)
	}
}
|
||||
|
||||
// TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface verifies
// that the wide-input policy read from the ConfigReader gates acceptance of
// a Responses request that supplies only "input" (no messages): rejected
// when disabled, accepted with the "openai_responses" surface when enabled.
func TestNormalizeOpenAIResponsesRequestWideInputPolicyFromInterface(t *testing.T) {
	req := map[string]any{
		"model": "deepseek-chat",
		"input": "hi",
	}

	_, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{
		aliases:   map[string]string{},
		wideInput: false,
	}, req)
	if err == nil {
		t.Fatal("expected error when wide input is disabled and only input is provided")
	}

	out, err := normalizeOpenAIResponsesRequest(mockOpenAIConfig{
		aliases:   map[string]string{},
		wideInput: true,
	}, req)
	if err != nil {
		t.Fatalf("unexpected error when wide input is enabled: %v", err)
	}
	if out.Surface != "openai_responses" {
		t.Fatalf("unexpected surface: %q", out.Surface)
	}
}
|
||||
@@ -16,7 +16,9 @@ import (
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
@@ -25,9 +27,9 @@ import (
|
||||
var writeJSON = util.WriteJSON
|
||||
|
||||
type Handler struct {
|
||||
Store *config.Store
|
||||
Auth *auth.Resolver
|
||||
DS *deepseek.Client
|
||||
Store ConfigReader
|
||||
Auth AuthResolver
|
||||
DS DeepSeekCaller
|
||||
|
||||
leaseMu sync.Mutex
|
||||
streamLeases map[string]streamLease
|
||||
@@ -136,7 +138,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re
|
||||
|
||||
finalThinking := result.Thinking
|
||||
finalText := result.Text
|
||||
respBody := util.BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
|
||||
respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
|
||||
writeJSON(w, http.StatusOK, respBody)
|
||||
}
|
||||
|
||||
@@ -158,214 +160,49 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
}
|
||||
|
||||
created := time.Now().Unix()
|
||||
firstChunkSent := false
|
||||
bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
|
||||
emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
|
||||
var toolSieve toolStreamSieveState
|
||||
toolCallsEmitted := false
|
||||
streamToolCallIDs := map[int]string{}
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
}
|
||||
parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType)
|
||||
thinking := strings.Builder{}
|
||||
text := strings.Builder{}
|
||||
lastContent := time.Now()
|
||||
hasContent := false
|
||||
keepaliveTicker := time.NewTicker(time.Duration(deepseek.KeepAliveTimeout) * time.Second)
|
||||
defer keepaliveTicker.Stop()
|
||||
keepaliveCountWithoutContent := 0
|
||||
|
||||
sendChunk := func(v any) {
|
||||
b, _ := json.Marshal(v)
|
||||
_, _ = w.Write([]byte("data: "))
|
||||
_, _ = w.Write(b)
|
||||
_, _ = w.Write([]byte("\n\n"))
|
||||
if canFlush {
|
||||
_ = rc.Flush()
|
||||
}
|
||||
}
|
||||
sendDone := func() {
|
||||
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||
if canFlush {
|
||||
_ = rc.Flush()
|
||||
}
|
||||
}
|
||||
streamRuntime := newChatStreamRuntime(
|
||||
w,
|
||||
rc,
|
||||
canFlush,
|
||||
completionID,
|
||||
created,
|
||||
model,
|
||||
finalPrompt,
|
||||
thinkingEnabled,
|
||||
searchEnabled,
|
||||
toolNames,
|
||||
bufferToolContent,
|
||||
emitEarlyToolDeltas,
|
||||
)
|
||||
|
||||
finalize := func(finishReason string) {
|
||||
finalThinking := thinking.String()
|
||||
finalText := text.String()
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
if len(detected) > 0 && !toolCallsEmitted {
|
||||
finishReason = "tool_calls"
|
||||
delta := map[string]any{
|
||||
"tool_calls": util.FormatOpenAIStreamToolCalls(detected),
|
||||
}
|
||||
if !firstChunkSent {
|
||||
delta["role"] = "assistant"
|
||||
firstChunkSent = true
|
||||
}
|
||||
sendChunk(util.BuildOpenAIChatStreamChunk(
|
||||
completionID,
|
||||
created,
|
||||
model,
|
||||
[]map[string]any{util.BuildOpenAIChatStreamDeltaChoice(0, delta)},
|
||||
nil,
|
||||
))
|
||||
} else if bufferToolContent {
|
||||
for _, evt := range flushToolSieve(&toolSieve, toolNames) {
|
||||
if evt.Content == "" {
|
||||
continue
|
||||
}
|
||||
delta := map[string]any{
|
||||
"content": evt.Content,
|
||||
}
|
||||
if !firstChunkSent {
|
||||
delta["role"] = "assistant"
|
||||
firstChunkSent = true
|
||||
}
|
||||
sendChunk(util.BuildOpenAIChatStreamChunk(
|
||||
completionID,
|
||||
created,
|
||||
model,
|
||||
[]map[string]any{util.BuildOpenAIChatStreamDeltaChoice(0, delta)},
|
||||
nil,
|
||||
))
|
||||
}
|
||||
}
|
||||
if len(detected) > 0 || toolCallsEmitted {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
sendChunk(util.BuildOpenAIChatStreamChunk(
|
||||
completionID,
|
||||
created,
|
||||
model,
|
||||
[]map[string]any{util.BuildOpenAIChatStreamFinishChoice(0, finishReason)},
|
||||
util.BuildOpenAIChatUsage(finalPrompt, finalThinking, finalText),
|
||||
))
|
||||
sendDone()
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case <-keepaliveTicker.C:
|
||||
if !hasContent {
|
||||
keepaliveCountWithoutContent++
|
||||
if keepaliveCountWithoutContent >= deepseek.MaxKeepaliveCount {
|
||||
finalize("stop")
|
||||
return
|
||||
}
|
||||
}
|
||||
if hasContent && time.Since(lastContent) > time.Duration(deepseek.StreamIdleTimeout)*time.Second {
|
||||
finalize("stop")
|
||||
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
|
||||
Context: r.Context(),
|
||||
Body: resp.Body,
|
||||
ThinkingEnabled: thinkingEnabled,
|
||||
InitialType: initialType,
|
||||
KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
|
||||
IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
|
||||
MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
|
||||
}, streamengine.ConsumeHooks{
|
||||
OnKeepAlive: func() {
|
||||
streamRuntime.sendKeepAlive()
|
||||
},
|
||||
OnParsed: streamRuntime.onParsed,
|
||||
OnFinalize: func(reason streamengine.StopReason, _ error) {
|
||||
if string(reason) == "content_filter" {
|
||||
streamRuntime.finalize("content_filter")
|
||||
return
|
||||
}
|
||||
if canFlush {
|
||||
_, _ = w.Write([]byte(": keep-alive\n\n"))
|
||||
_ = rc.Flush()
|
||||
}
|
||||
case parsed, ok := <-parsedLines:
|
||||
if !ok {
|
||||
// Ensure scanner completion is observed only after all queued
|
||||
// SSE lines are drained, avoiding early finalize races.
|
||||
_ = <-done
|
||||
finalize("stop")
|
||||
return
|
||||
}
|
||||
if !parsed.Parsed {
|
||||
continue
|
||||
}
|
||||
if parsed.ContentFilter || parsed.ErrorMessage != "" {
|
||||
finalize("content_filter")
|
||||
return
|
||||
}
|
||||
if parsed.Stop {
|
||||
finalize("stop")
|
||||
return
|
||||
}
|
||||
newChoices := make([]map[string]any, 0, len(parsed.Parts))
|
||||
for _, p := range parsed.Parts {
|
||||
if searchEnabled && sse.IsCitation(p.Text) {
|
||||
continue
|
||||
}
|
||||
if p.Text == "" {
|
||||
continue
|
||||
}
|
||||
hasContent = true
|
||||
lastContent = time.Now()
|
||||
keepaliveCountWithoutContent = 0
|
||||
delta := map[string]any{}
|
||||
if !firstChunkSent {
|
||||
delta["role"] = "assistant"
|
||||
firstChunkSent = true
|
||||
}
|
||||
if p.Type == "thinking" {
|
||||
if thinkingEnabled {
|
||||
thinking.WriteString(p.Text)
|
||||
delta["reasoning_content"] = p.Text
|
||||
}
|
||||
} else {
|
||||
text.WriteString(p.Text)
|
||||
if !bufferToolContent {
|
||||
delta["content"] = p.Text
|
||||
} else {
|
||||
events := processToolSieveChunk(&toolSieve, p.Text, toolNames)
|
||||
if len(events) == 0 {
|
||||
// Keep thinking delta only frame.
|
||||
}
|
||||
for _, evt := range events {
|
||||
if len(evt.ToolCallDeltas) > 0 {
|
||||
if !emitEarlyToolDeltas {
|
||||
continue
|
||||
}
|
||||
toolCallsEmitted = true
|
||||
tcDelta := map[string]any{
|
||||
"tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs),
|
||||
}
|
||||
if !firstChunkSent {
|
||||
tcDelta["role"] = "assistant"
|
||||
firstChunkSent = true
|
||||
}
|
||||
newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, tcDelta))
|
||||
continue
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
toolCallsEmitted = true
|
||||
tcDelta := map[string]any{
|
||||
"tool_calls": util.FormatOpenAIStreamToolCalls(evt.ToolCalls),
|
||||
}
|
||||
if !firstChunkSent {
|
||||
tcDelta["role"] = "assistant"
|
||||
firstChunkSent = true
|
||||
}
|
||||
newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, tcDelta))
|
||||
continue
|
||||
}
|
||||
if evt.Content != "" {
|
||||
contentDelta := map[string]any{
|
||||
"content": evt.Content,
|
||||
}
|
||||
if !firstChunkSent {
|
||||
contentDelta["role"] = "assistant"
|
||||
firstChunkSent = true
|
||||
}
|
||||
newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, contentDelta))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(delta) > 0 {
|
||||
newChoices = append(newChoices, util.BuildOpenAIChatStreamDeltaChoice(0, delta))
|
||||
}
|
||||
}
|
||||
if len(newChoices) > 0 {
|
||||
sendChunk(util.BuildOpenAIChatStreamChunk(completionID, created, model, newChoices, nil))
|
||||
}
|
||||
}
|
||||
}
|
||||
streamRuntime.finalize("stop")
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package openai
|
||||
|
||||
import "ds2api/internal/util"
|
||||
import (
|
||||
"ds2api/internal/deepseek"
|
||||
)
|
||||
|
||||
func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any) (string, []string) {
|
||||
messages := normalizeOpenAIMessagesForPrompt(messagesRaw)
|
||||
@@ -8,5 +10,5 @@ func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any) (string, []string)
|
||||
if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 {
|
||||
messages, toolNames = injectToolPrompt(messages, tools)
|
||||
}
|
||||
return util.MessagesPrepare(messages), toolNames
|
||||
return deepseek.MessagesPrepare(messages), toolNames
|
||||
}
|
||||
|
||||
@@ -6,13 +6,16 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/deepseek"
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/sse"
|
||||
"ds2api/internal/util"
|
||||
streamengine "ds2api/internal/stream"
|
||||
)
|
||||
|
||||
func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -108,7 +111,7 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
|
||||
return
|
||||
}
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
responseObj := util.BuildOpenAIResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames)
|
||||
responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames)
|
||||
h.getResponseStore().put(owner, responseID, responseObj)
|
||||
writeJSON(w, http.StatusOK, responseObj)
|
||||
}
|
||||
@@ -127,114 +130,45 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
|
||||
rc := http.NewResponseController(w)
|
||||
canFlush := rc.Flush() == nil
|
||||
|
||||
sendEvent := func(event string, payload map[string]any) {
|
||||
b, _ := json.Marshal(payload)
|
||||
_, _ = w.Write([]byte("event: " + event + "\n"))
|
||||
_, _ = w.Write([]byte("data: "))
|
||||
_, _ = w.Write(b)
|
||||
_, _ = w.Write([]byte("\n\n"))
|
||||
if canFlush {
|
||||
_ = rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
sendEvent("response.created", util.BuildOpenAIResponsesCreatedPayload(responseID, model))
|
||||
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
}
|
||||
parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType)
|
||||
bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
|
||||
emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
|
||||
var sieve toolStreamSieveState
|
||||
thinking := strings.Builder{}
|
||||
text := strings.Builder{}
|
||||
toolCallsEmitted := false
|
||||
streamToolCallIDs := map[int]string{}
|
||||
|
||||
finalize := func() {
|
||||
finalThinking := thinking.String()
|
||||
finalText := text.String()
|
||||
if bufferToolContent {
|
||||
for _, evt := range flushToolSieve(&sieve, toolNames) {
|
||||
if evt.Content != "" {
|
||||
sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, evt.Content))
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
toolCallsEmitted = true
|
||||
sendEvent("response.output_tool_call.done", util.BuildOpenAIResponsesToolCallDonePayload(responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls)))
|
||||
}
|
||||
}
|
||||
}
|
||||
obj := util.BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, finalText, toolNames)
|
||||
if toolCallsEmitted {
|
||||
obj["status"] = "completed"
|
||||
}
|
||||
h.getResponseStore().put(owner, responseID, obj)
|
||||
sendEvent("response.completed", util.BuildOpenAIResponsesCompletedPayload(obj))
|
||||
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||
if canFlush {
|
||||
_ = rc.Flush()
|
||||
}
|
||||
}
|
||||
streamRuntime := newResponsesStreamRuntime(
|
||||
w,
|
||||
rc,
|
||||
canFlush,
|
||||
responseID,
|
||||
model,
|
||||
finalPrompt,
|
||||
thinkingEnabled,
|
||||
searchEnabled,
|
||||
toolNames,
|
||||
bufferToolContent,
|
||||
emitEarlyToolDeltas,
|
||||
func(obj map[string]any) {
|
||||
h.getResponseStore().put(owner, responseID, obj)
|
||||
},
|
||||
)
|
||||
streamRuntime.sendCreated()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case parsed, ok := <-parsedLines:
|
||||
if !ok {
|
||||
_ = <-done
|
||||
finalize()
|
||||
return
|
||||
}
|
||||
if !parsed.Parsed {
|
||||
continue
|
||||
}
|
||||
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
|
||||
finalize()
|
||||
return
|
||||
}
|
||||
for _, p := range parsed.Parts {
|
||||
if p.Text == "" {
|
||||
continue
|
||||
}
|
||||
if p.Type != "thinking" && searchEnabled && sse.IsCitation(p.Text) {
|
||||
continue
|
||||
}
|
||||
if p.Type == "thinking" {
|
||||
if !thinkingEnabled {
|
||||
continue
|
||||
}
|
||||
thinking.WriteString(p.Text)
|
||||
sendEvent("response.reasoning.delta", util.BuildOpenAIResponsesReasoningDeltaPayload(responseID, p.Text))
|
||||
continue
|
||||
}
|
||||
text.WriteString(p.Text)
|
||||
if !bufferToolContent {
|
||||
sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, p.Text))
|
||||
continue
|
||||
}
|
||||
for _, evt := range processToolSieveChunk(&sieve, p.Text, toolNames) {
|
||||
if evt.Content != "" {
|
||||
sendEvent("response.output_text.delta", util.BuildOpenAIResponsesTextDeltaPayload(responseID, evt.Content))
|
||||
}
|
||||
if len(evt.ToolCallDeltas) > 0 {
|
||||
if !emitEarlyToolDeltas {
|
||||
continue
|
||||
}
|
||||
toolCallsEmitted = true
|
||||
sendEvent("response.output_tool_call.delta", util.BuildOpenAIResponsesToolCallDeltaPayload(responseID, formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs)))
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
toolCallsEmitted = true
|
||||
sendEvent("response.output_tool_call.done", util.BuildOpenAIResponsesToolCallDonePayload(responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
|
||||
Context: r.Context(),
|
||||
Body: resp.Body,
|
||||
ThinkingEnabled: thinkingEnabled,
|
||||
InitialType: initialType,
|
||||
KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
|
||||
IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
|
||||
MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
|
||||
}, streamengine.ConsumeHooks{
|
||||
OnParsed: streamRuntime.onParsed,
|
||||
OnFinalize: func(_ streamengine.StopReason, _ error) {
|
||||
streamRuntime.finalize()
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func responsesMessagesFromRequest(req map[string]any) []any {
|
||||
|
||||
168
internal/adapter/openai/responses_stream_runtime.go
Normal file
168
internal/adapter/openai/responses_stream_runtime.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
openaifmt "ds2api/internal/format/openai"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
type responsesStreamRuntime struct {
|
||||
w http.ResponseWriter
|
||||
rc *http.ResponseController
|
||||
canFlush bool
|
||||
|
||||
responseID string
|
||||
model string
|
||||
finalPrompt string
|
||||
toolNames []string
|
||||
|
||||
thinkingEnabled bool
|
||||
searchEnabled bool
|
||||
|
||||
bufferToolContent bool
|
||||
emitEarlyToolDeltas bool
|
||||
toolCallsEmitted bool
|
||||
|
||||
sieve toolStreamSieveState
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
streamToolCallIDs map[int]string
|
||||
|
||||
persistResponse func(obj map[string]any)
|
||||
}
|
||||
|
||||
func newResponsesStreamRuntime(
|
||||
w http.ResponseWriter,
|
||||
rc *http.ResponseController,
|
||||
canFlush bool,
|
||||
responseID string,
|
||||
model string,
|
||||
finalPrompt string,
|
||||
thinkingEnabled bool,
|
||||
searchEnabled bool,
|
||||
toolNames []string,
|
||||
bufferToolContent bool,
|
||||
emitEarlyToolDeltas bool,
|
||||
persistResponse func(obj map[string]any),
|
||||
) *responsesStreamRuntime {
|
||||
return &responsesStreamRuntime{
|
||||
w: w,
|
||||
rc: rc,
|
||||
canFlush: canFlush,
|
||||
responseID: responseID,
|
||||
model: model,
|
||||
finalPrompt: finalPrompt,
|
||||
thinkingEnabled: thinkingEnabled,
|
||||
searchEnabled: searchEnabled,
|
||||
toolNames: toolNames,
|
||||
bufferToolContent: bufferToolContent,
|
||||
emitEarlyToolDeltas: emitEarlyToolDeltas,
|
||||
streamToolCallIDs: map[int]string{},
|
||||
persistResponse: persistResponse,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) sendEvent(event string, payload map[string]any) {
|
||||
b, _ := json.Marshal(payload)
|
||||
_, _ = s.w.Write([]byte("event: " + event + "\n"))
|
||||
_, _ = s.w.Write([]byte("data: "))
|
||||
_, _ = s.w.Write(b)
|
||||
_, _ = s.w.Write([]byte("\n\n"))
|
||||
if s.canFlush {
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) sendCreated() {
|
||||
s.sendEvent("response.created", openaifmt.BuildResponsesCreatedPayload(s.responseID, s.model))
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) sendDone() {
|
||||
_, _ = s.w.Write([]byte("data: [DONE]\n\n"))
|
||||
if s.canFlush {
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) finalize() {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := s.text.String()
|
||||
if s.bufferToolContent {
|
||||
for _, evt := range flushToolSieve(&s.sieve, s.toolNames) {
|
||||
if evt.Content != "" {
|
||||
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
s.toolCallsEmitted = true
|
||||
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames)
|
||||
if s.toolCallsEmitted {
|
||||
obj["status"] = "completed"
|
||||
}
|
||||
if s.persistResponse != nil {
|
||||
s.persistResponse(obj)
|
||||
}
|
||||
s.sendEvent("response.completed", openaifmt.BuildResponsesCompletedPayload(obj))
|
||||
s.sendDone()
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
}
|
||||
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
|
||||
return streamengine.ParsedDecision{Stop: true}
|
||||
}
|
||||
|
||||
contentSeen := false
|
||||
for _, p := range parsed.Parts {
|
||||
if p.Text == "" {
|
||||
continue
|
||||
}
|
||||
if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
|
||||
continue
|
||||
}
|
||||
contentSeen = true
|
||||
if p.Type == "thinking" {
|
||||
if !s.thinkingEnabled {
|
||||
continue
|
||||
}
|
||||
s.thinking.WriteString(p.Text)
|
||||
s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
|
||||
continue
|
||||
}
|
||||
|
||||
s.text.WriteString(p.Text)
|
||||
if !s.bufferToolContent {
|
||||
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text))
|
||||
continue
|
||||
}
|
||||
for _, evt := range processToolSieveChunk(&s.sieve, p.Text, s.toolNames) {
|
||||
if evt.Content != "" {
|
||||
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
|
||||
}
|
||||
if len(evt.ToolCallDeltas) > 0 {
|
||||
if !s.emitEarlyToolDeltas {
|
||||
continue
|
||||
}
|
||||
s.toolCallsEmitted = true
|
||||
s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)))
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
s.toolCallsEmitted = true
|
||||
s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, util.FormatOpenAIStreamToolCalls(evt.ToolCalls)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||
}
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func normalizeOpenAIChatRequest(store *config.Store, req map[string]any) (util.StandardRequest, error) {
|
||||
func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any) (util.StandardRequest, error) {
|
||||
model, _ := req["model"].(string)
|
||||
messagesRaw, _ := req["messages"].([]any)
|
||||
if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 {
|
||||
@@ -41,7 +41,7 @@ func normalizeOpenAIChatRequest(store *config.Store, req map[string]any) (util.S
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeOpenAIResponsesRequest(store *config.Store, req map[string]any) (util.StandardRequest, error) {
|
||||
func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any) (util.StandardRequest, error) {
|
||||
model, _ := req["model"].(string)
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
|
||||
46
internal/admin/deps.go
Normal file
46
internal/admin/deps.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"ds2api/internal/account"
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
)
|
||||
|
||||
type ConfigStore interface {
|
||||
Snapshot() config.Config
|
||||
Keys() []string
|
||||
Accounts() []config.Account
|
||||
FindAccount(identifier string) (config.Account, bool)
|
||||
UpdateAccountToken(identifier, token string) error
|
||||
Update(mutator func(*config.Config) error) error
|
||||
ExportJSONAndBase64() (string, string, error)
|
||||
IsEnvBacked() bool
|
||||
SetVercelSync(hash string, ts int64) error
|
||||
AdminPasswordHash() string
|
||||
AdminJWTExpireHours() int
|
||||
AdminJWTValidAfterUnix() int64
|
||||
RuntimeAccountMaxInflight() int
|
||||
RuntimeAccountMaxQueue(defaultSize int) int
|
||||
RuntimeGlobalMaxInflight(defaultSize int) int
|
||||
}
|
||||
|
||||
type PoolController interface {
|
||||
Reset()
|
||||
Status() map[string]any
|
||||
ApplyRuntimeLimits(maxInflightPerAccount, maxQueueSize, globalMaxInflight int)
|
||||
}
|
||||
|
||||
type DeepSeekCaller interface {
|
||||
Login(ctx context.Context, acc config.Account) (string, error)
|
||||
CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
||||
GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
||||
CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
|
||||
}
|
||||
|
||||
var _ ConfigStore = (*config.Store)(nil)
|
||||
var _ PoolController = (*account.Pool)(nil)
|
||||
var _ DeepSeekCaller = (*deepseek.Client)(nil)
|
||||
@@ -2,16 +2,12 @@ package admin
|
||||
|
||||
import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/account"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/deepseek"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
Store *config.Store
|
||||
Pool *account.Pool
|
||||
DS *deepseek.Client
|
||||
Store ConfigStore
|
||||
Pool PoolController
|
||||
DS DeepSeekCaller
|
||||
}
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
@@ -22,6 +18,11 @@ func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
pr.Get("/vercel/config", h.getVercelConfig)
|
||||
pr.Get("/config", h.getConfig)
|
||||
pr.Post("/config", h.updateConfig)
|
||||
pr.Get("/settings", h.getSettings)
|
||||
pr.Put("/settings", h.updateSettings)
|
||||
pr.Post("/settings/password", h.updateSettingsPassword)
|
||||
pr.Post("/config/import", h.configImport)
|
||||
pr.Get("/config/export", h.configExport)
|
||||
pr.Post("/keys", h.addKey)
|
||||
pr.Delete("/keys/{key}", h.deleteKey)
|
||||
pr.Get("/accounts", h.listAccounts)
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
func (h *Handler) requireAdmin(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := authn.VerifyAdminRequest(r); err != nil {
|
||||
if err := authn.VerifyAdminRequestWithStore(r, h.Store); err != nil {
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -25,18 +25,18 @@ func (h *Handler) login(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
adminKey, _ := req["admin_key"].(string)
|
||||
expireHours := intFrom(req["expire_hours"])
|
||||
if expireHours <= 0 {
|
||||
expireHours = 24
|
||||
}
|
||||
if adminKey != authn.AdminKey() {
|
||||
if !authn.VerifyAdminCredential(adminKey, h.Store) {
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": "Invalid admin key"})
|
||||
return
|
||||
}
|
||||
token, err := authn.CreateJWT(expireHours)
|
||||
token, err := authn.CreateJWTWithStore(expireHours, h.Store)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
if expireHours <= 0 {
|
||||
expireHours = h.Store.AdminJWTExpireHours()
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "token": token, "expires_in": expireHours * 3600})
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ func (h *Handler) verify(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
token := strings.TrimSpace(header[7:])
|
||||
payload, err := authn.VerifyJWT(token)
|
||||
payload, err := authn.VerifyJWTWithStore(token, h.Store)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusUnauthorized, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@@ -204,38 +203,191 @@ func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *Handler) exportConfig(w http.ResponseWriter, _ *http.Request) {
|
||||
h.configExport(w, nil)
|
||||
}
|
||||
|
||||
func (h *Handler) configExport(w http.ResponseWriter, _ *http.Request) {
|
||||
snap := h.Store.Snapshot()
|
||||
jsonStr, b64, err := h.Store.ExportJSONAndBase64()
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"json": jsonStr, "base64": b64})
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"success": true,
|
||||
"config": snap,
|
||||
"json": jsonStr,
|
||||
"base64": b64,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
|
||||
mode := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("mode")))
|
||||
if mode == "" {
|
||||
mode = strings.TrimSpace(strings.ToLower(fieldString(req, "mode")))
|
||||
}
|
||||
if mode == "" {
|
||||
mode = "merge"
|
||||
}
|
||||
if mode != "merge" && mode != "replace" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "mode must be merge or replace"})
|
||||
return
|
||||
}
|
||||
|
||||
payload := req
|
||||
if raw, ok := req["config"].(map[string]any); ok && len(raw) > 0 {
|
||||
payload = raw
|
||||
}
|
||||
rawJSON, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid config payload"})
|
||||
return
|
||||
}
|
||||
var incoming config.Config
|
||||
if err := json.Unmarshal(rawJSON, &incoming); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
importedKeys, importedAccounts := 0, 0
|
||||
err = h.Store.Update(func(c *config.Config) error {
|
||||
next := c.Clone()
|
||||
if mode == "replace" {
|
||||
next = incoming.Clone()
|
||||
next.VercelSyncHash = c.VercelSyncHash
|
||||
next.VercelSyncTime = c.VercelSyncTime
|
||||
importedKeys = len(next.Keys)
|
||||
importedAccounts = len(next.Accounts)
|
||||
} else {
|
||||
existingKeys := map[string]struct{}{}
|
||||
for _, k := range next.Keys {
|
||||
existingKeys[k] = struct{}{}
|
||||
}
|
||||
for _, k := range incoming.Keys {
|
||||
key := strings.TrimSpace(k)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := existingKeys[key]; ok {
|
||||
continue
|
||||
}
|
||||
existingKeys[key] = struct{}{}
|
||||
next.Keys = append(next.Keys, key)
|
||||
importedKeys++
|
||||
}
|
||||
|
||||
existingAccounts := map[string]struct{}{}
|
||||
for _, acc := range next.Accounts {
|
||||
existingAccounts[acc.Identifier()] = struct{}{}
|
||||
}
|
||||
for _, acc := range incoming.Accounts {
|
||||
id := acc.Identifier()
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := existingAccounts[id]; ok {
|
||||
continue
|
||||
}
|
||||
existingAccounts[id] = struct{}{}
|
||||
next.Accounts = append(next.Accounts, acc)
|
||||
importedAccounts++
|
||||
}
|
||||
|
||||
if len(incoming.ClaudeMapping) > 0 {
|
||||
if next.ClaudeMapping == nil {
|
||||
next.ClaudeMapping = map[string]string{}
|
||||
}
|
||||
for k, v := range incoming.ClaudeMapping {
|
||||
next.ClaudeMapping[k] = v
|
||||
}
|
||||
}
|
||||
if len(incoming.ClaudeModelMap) > 0 {
|
||||
if next.ClaudeModelMap == nil {
|
||||
next.ClaudeModelMap = map[string]string{}
|
||||
}
|
||||
for k, v := range incoming.ClaudeModelMap {
|
||||
next.ClaudeModelMap[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
if len(incoming.ModelAliases) > 0 {
|
||||
if next.ModelAliases == nil {
|
||||
next.ModelAliases = map[string]string{}
|
||||
}
|
||||
for k, v := range incoming.ModelAliases {
|
||||
next.ModelAliases[k] = v
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(incoming.Toolcall.Mode) != "" {
|
||||
next.Toolcall.Mode = incoming.Toolcall.Mode
|
||||
}
|
||||
if strings.TrimSpace(incoming.Toolcall.EarlyEmitConfidence) != "" {
|
||||
next.Toolcall.EarlyEmitConfidence = incoming.Toolcall.EarlyEmitConfidence
|
||||
}
|
||||
if incoming.Responses.StoreTTLSeconds > 0 {
|
||||
next.Responses.StoreTTLSeconds = incoming.Responses.StoreTTLSeconds
|
||||
}
|
||||
if strings.TrimSpace(incoming.Embeddings.Provider) != "" {
|
||||
next.Embeddings.Provider = incoming.Embeddings.Provider
|
||||
}
|
||||
if strings.TrimSpace(incoming.Admin.PasswordHash) != "" {
|
||||
next.Admin.PasswordHash = incoming.Admin.PasswordHash
|
||||
}
|
||||
if incoming.Admin.JWTExpireHours > 0 {
|
||||
next.Admin.JWTExpireHours = incoming.Admin.JWTExpireHours
|
||||
}
|
||||
if incoming.Admin.JWTValidAfterUnix > 0 {
|
||||
next.Admin.JWTValidAfterUnix = incoming.Admin.JWTValidAfterUnix
|
||||
}
|
||||
if incoming.Runtime.AccountMaxInflight > 0 {
|
||||
next.Runtime.AccountMaxInflight = incoming.Runtime.AccountMaxInflight
|
||||
}
|
||||
if incoming.Runtime.AccountMaxQueue > 0 {
|
||||
next.Runtime.AccountMaxQueue = incoming.Runtime.AccountMaxQueue
|
||||
}
|
||||
if incoming.Runtime.GlobalMaxInflight > 0 {
|
||||
next.Runtime.GlobalMaxInflight = incoming.Runtime.GlobalMaxInflight
|
||||
}
|
||||
}
|
||||
|
||||
normalizeSettingsConfig(&next)
|
||||
if err := validateSettingsConfig(next); err != nil {
|
||||
return newRequestError(err.Error())
|
||||
}
|
||||
|
||||
*c = next
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
if detail, ok := requestErrorDetail(err); ok {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": detail})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.Pool.Reset()
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"success": true,
|
||||
"mode": mode,
|
||||
"imported_keys": importedKeys,
|
||||
"imported_accounts": importedAccounts,
|
||||
"message": "config imported",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) computeSyncHash() string {
|
||||
snap := h.Store.Snapshot()
|
||||
syncable := map[string]any{"keys": snap.Keys, "accounts": []map[string]any{}}
|
||||
accounts := make([]map[string]any, 0, len(snap.Accounts))
|
||||
for _, a := range snap.Accounts {
|
||||
m := map[string]any{}
|
||||
if a.Email != "" {
|
||||
m["email"] = a.Email
|
||||
}
|
||||
if a.Mobile != "" {
|
||||
m["mobile"] = a.Mobile
|
||||
}
|
||||
if a.Password != "" {
|
||||
m["password"] = a.Password
|
||||
}
|
||||
accounts = append(accounts, m)
|
||||
}
|
||||
sort.Slice(accounts, func(i, j int) bool {
|
||||
ai := fmt.Sprintf("%v%v", accounts[i]["email"], accounts[i]["mobile"])
|
||||
aj := fmt.Sprintf("%v%v", accounts[j]["email"], accounts[j]["mobile"])
|
||||
return ai < aj
|
||||
})
|
||||
syncable["accounts"] = accounts
|
||||
b, _ := json.Marshal(syncable)
|
||||
snap := h.Store.Snapshot().Clone()
|
||||
snap.VercelSyncHash = ""
|
||||
snap.VercelSyncTime = 0
|
||||
b, _ := json.Marshal(snap)
|
||||
sum := md5.Sum(b)
|
||||
return fmt.Sprintf("%x", sum)
|
||||
}
|
||||
|
||||
321
internal/admin/handler_settings.go
Normal file
321
internal/admin/handler_settings.go
Normal file
@@ -0,0 +1,321 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
authn "ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func (h *Handler) getSettings(w http.ResponseWriter, _ *http.Request) {
|
||||
snap := h.Store.Snapshot()
|
||||
recommended := defaultRuntimeRecommended(len(snap.Accounts), h.Store.RuntimeAccountMaxInflight())
|
||||
needsSync := config.IsVercel() && snap.VercelSyncHash != "" && snap.VercelSyncHash != h.computeSyncHash()
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"success": true,
|
||||
"admin": map[string]any{
|
||||
"has_password_hash": strings.TrimSpace(snap.Admin.PasswordHash) != "",
|
||||
"jwt_expire_hours": h.Store.AdminJWTExpireHours(),
|
||||
"jwt_valid_after_unix": snap.Admin.JWTValidAfterUnix,
|
||||
"default_password_warning": authn.UsingDefaultAdminKey(h.Store),
|
||||
},
|
||||
"runtime": map[string]any{
|
||||
"account_max_inflight": h.Store.RuntimeAccountMaxInflight(),
|
||||
"account_max_queue": h.Store.RuntimeAccountMaxQueue(recommended),
|
||||
"global_max_inflight": h.Store.RuntimeGlobalMaxInflight(recommended),
|
||||
},
|
||||
"toolcall": snap.Toolcall,
|
||||
"responses": snap.Responses,
|
||||
"embeddings": snap.Embeddings,
|
||||
"claude_mapping": settingsClaudeMapping(snap),
|
||||
"model_aliases": snap.ModelAliases,
|
||||
"env_backed": h.Store.IsEnvBacked(),
|
||||
"needs_vercel_sync": needsSync,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
|
||||
adminCfg, runtimeCfg, toolcallCfg, responsesCfg, embeddingsCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
if runtimeCfg != nil {
|
||||
if err := validateMergedRuntimeSettings(h.Store.Snapshot().Runtime, runtimeCfg); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.Store.Update(func(c *config.Config) error {
|
||||
if adminCfg != nil {
|
||||
if adminCfg.JWTExpireHours > 0 {
|
||||
c.Admin.JWTExpireHours = adminCfg.JWTExpireHours
|
||||
}
|
||||
}
|
||||
if runtimeCfg != nil {
|
||||
if runtimeCfg.AccountMaxInflight > 0 {
|
||||
c.Runtime.AccountMaxInflight = runtimeCfg.AccountMaxInflight
|
||||
}
|
||||
if runtimeCfg.AccountMaxQueue > 0 {
|
||||
c.Runtime.AccountMaxQueue = runtimeCfg.AccountMaxQueue
|
||||
}
|
||||
if runtimeCfg.GlobalMaxInflight > 0 {
|
||||
c.Runtime.GlobalMaxInflight = runtimeCfg.GlobalMaxInflight
|
||||
}
|
||||
}
|
||||
if toolcallCfg != nil {
|
||||
if strings.TrimSpace(toolcallCfg.Mode) != "" {
|
||||
c.Toolcall.Mode = strings.TrimSpace(toolcallCfg.Mode)
|
||||
}
|
||||
if strings.TrimSpace(toolcallCfg.EarlyEmitConfidence) != "" {
|
||||
c.Toolcall.EarlyEmitConfidence = strings.TrimSpace(toolcallCfg.EarlyEmitConfidence)
|
||||
}
|
||||
}
|
||||
if responsesCfg != nil && responsesCfg.StoreTTLSeconds > 0 {
|
||||
c.Responses.StoreTTLSeconds = responsesCfg.StoreTTLSeconds
|
||||
}
|
||||
if embeddingsCfg != nil && strings.TrimSpace(embeddingsCfg.Provider) != "" {
|
||||
c.Embeddings.Provider = strings.TrimSpace(embeddingsCfg.Provider)
|
||||
}
|
||||
if claudeMap != nil {
|
||||
c.ClaudeMapping = claudeMap
|
||||
c.ClaudeModelMap = nil
|
||||
}
|
||||
if aliasMap != nil {
|
||||
c.ModelAliases = aliasMap
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.applyRuntimeSettings()
|
||||
needsSync := config.IsVercel() || h.Store.IsEnvBacked()
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"success": true,
|
||||
"message": "settings updated and hot reloaded",
|
||||
"env_backed": h.Store.IsEnvBacked(),
|
||||
"needs_vercel_sync": needsSync,
|
||||
"manual_sync_message": "配置已保存。Vercel 部署请在 Vercel Sync 页面手动同步。",
|
||||
})
|
||||
}
|
||||
|
||||
func validateMergedRuntimeSettings(current config.RuntimeConfig, incoming *config.RuntimeConfig) error {
|
||||
merged := current
|
||||
if incoming != nil {
|
||||
if incoming.AccountMaxInflight > 0 {
|
||||
merged.AccountMaxInflight = incoming.AccountMaxInflight
|
||||
}
|
||||
if incoming.AccountMaxQueue > 0 {
|
||||
merged.AccountMaxQueue = incoming.AccountMaxQueue
|
||||
}
|
||||
if incoming.GlobalMaxInflight > 0 {
|
||||
merged.GlobalMaxInflight = incoming.GlobalMaxInflight
|
||||
}
|
||||
}
|
||||
return validateRuntimeSettings(merged)
|
||||
}
|
||||
|
||||
func (h *Handler) updateSettingsPassword(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
newPassword := strings.TrimSpace(fieldString(req, "new_password"))
|
||||
if newPassword == "" {
|
||||
newPassword = strings.TrimSpace(fieldString(req, "password"))
|
||||
}
|
||||
if len(newPassword) < 4 {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "new password must be at least 4 characters"})
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
hash := authn.HashAdminPassword(newPassword)
|
||||
if err := h.Store.Update(func(c *config.Config) error {
|
||||
c.Admin.PasswordHash = hash
|
||||
c.Admin.JWTValidAfterUnix = now
|
||||
return nil
|
||||
}); err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"success": true,
|
||||
"message": "password updated",
|
||||
"force_relogin": true,
|
||||
"jwt_valid_after_unix": now,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) applyRuntimeSettings() {
|
||||
if h == nil || h.Store == nil || h.Pool == nil {
|
||||
return
|
||||
}
|
||||
accountCount := len(h.Store.Accounts())
|
||||
maxPer := h.Store.RuntimeAccountMaxInflight()
|
||||
recommended := defaultRuntimeRecommended(accountCount, maxPer)
|
||||
maxQueue := h.Store.RuntimeAccountMaxQueue(recommended)
|
||||
global := h.Store.RuntimeGlobalMaxInflight(recommended)
|
||||
h.Pool.ApplyRuntimeLimits(maxPer, maxQueue, global)
|
||||
}
|
||||
|
||||
func defaultRuntimeRecommended(accountCount, maxPer int) int {
|
||||
if maxPer <= 0 {
|
||||
maxPer = 1
|
||||
}
|
||||
if accountCount <= 0 {
|
||||
return maxPer
|
||||
}
|
||||
return accountCount * maxPer
|
||||
}
|
||||
|
||||
func settingsClaudeMapping(c config.Config) map[string]string {
|
||||
if len(c.ClaudeMapping) > 0 {
|
||||
return c.ClaudeMapping
|
||||
}
|
||||
if len(c.ClaudeModelMap) > 0 {
|
||||
return c.ClaudeModelMap
|
||||
}
|
||||
return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"}
|
||||
}
|
||||
|
||||
func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ToolcallConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, map[string]string, map[string]string, error) {
|
||||
var (
|
||||
adminCfg *config.AdminConfig
|
||||
runtimeCfg *config.RuntimeConfig
|
||||
toolcallCfg *config.ToolcallConfig
|
||||
respCfg *config.ResponsesConfig
|
||||
embCfg *config.EmbeddingsConfig
|
||||
claudeMap map[string]string
|
||||
aliasMap map[string]string
|
||||
)
|
||||
|
||||
if raw, ok := req["admin"].(map[string]any); ok {
|
||||
cfg := &config.AdminConfig{}
|
||||
if v, exists := raw["jwt_expire_hours"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 1 || n > 720 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720")
|
||||
}
|
||||
cfg.JWTExpireHours = n
|
||||
}
|
||||
adminCfg = cfg
|
||||
}
|
||||
|
||||
if raw, ok := req["runtime"].(map[string]any); ok {
|
||||
cfg := &config.RuntimeConfig{}
|
||||
if v, exists := raw["account_max_inflight"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 1 || n > 256 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_inflight must be between 1 and 256")
|
||||
}
|
||||
cfg.AccountMaxInflight = n
|
||||
}
|
||||
if v, exists := raw["account_max_queue"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 1 || n > 200000 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_queue must be between 1 and 200000")
|
||||
}
|
||||
cfg.AccountMaxQueue = n
|
||||
}
|
||||
if v, exists := raw["global_max_inflight"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 1 || n > 200000 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000")
|
||||
}
|
||||
cfg.GlobalMaxInflight = n
|
||||
}
|
||||
if cfg.AccountMaxInflight > 0 && cfg.GlobalMaxInflight > 0 && cfg.GlobalMaxInflight < cfg.AccountMaxInflight {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
|
||||
}
|
||||
runtimeCfg = cfg
|
||||
}
|
||||
|
||||
if raw, ok := req["toolcall"].(map[string]any); ok {
|
||||
cfg := &config.ToolcallConfig{}
|
||||
if v, exists := raw["mode"]; exists {
|
||||
mode := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
|
||||
switch mode {
|
||||
case "feature_match", "off":
|
||||
cfg.Mode = mode
|
||||
default:
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.mode must be feature_match or off")
|
||||
}
|
||||
}
|
||||
if v, exists := raw["early_emit_confidence"]; exists {
|
||||
level := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
|
||||
switch level {
|
||||
case "high", "low", "off":
|
||||
cfg.EarlyEmitConfidence = level
|
||||
default:
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.early_emit_confidence must be high, low or off")
|
||||
}
|
||||
}
|
||||
toolcallCfg = cfg
|
||||
}
|
||||
|
||||
if raw, ok := req["responses"].(map[string]any); ok {
|
||||
cfg := &config.ResponsesConfig{}
|
||||
if v, exists := raw["store_ttl_seconds"]; exists {
|
||||
n := intFrom(v)
|
||||
if n < 30 || n > 86400 {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400")
|
||||
}
|
||||
cfg.StoreTTLSeconds = n
|
||||
}
|
||||
respCfg = cfg
|
||||
}
|
||||
|
||||
if raw, ok := req["embeddings"].(map[string]any); ok {
|
||||
cfg := &config.EmbeddingsConfig{}
|
||||
if v, exists := raw["provider"]; exists {
|
||||
p := strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
if p == "" {
|
||||
return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("embeddings.provider cannot be empty")
|
||||
}
|
||||
cfg.Provider = p
|
||||
}
|
||||
embCfg = cfg
|
||||
}
|
||||
|
||||
if raw, ok := req["claude_mapping"].(map[string]any); ok {
|
||||
claudeMap = map[string]string{}
|
||||
for k, v := range raw {
|
||||
key := strings.TrimSpace(k)
|
||||
val := strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
if key == "" || val == "" {
|
||||
continue
|
||||
}
|
||||
claudeMap[key] = val
|
||||
}
|
||||
}
|
||||
|
||||
if raw, ok := req["model_aliases"].(map[string]any); ok {
|
||||
aliasMap = map[string]string{}
|
||||
for k, v := range raw {
|
||||
key := strings.TrimSpace(k)
|
||||
val := strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
if key == "" || val == "" {
|
||||
continue
|
||||
}
|
||||
aliasMap[key] = val
|
||||
}
|
||||
}
|
||||
|
||||
return adminCfg, runtimeCfg, toolcallCfg, respCfg, embCfg, claudeMap, aliasMap, nil
|
||||
}
|
||||
267
internal/admin/handler_settings_test.go
Normal file
267
internal/admin/handler_settings_test.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
authn "ds2api/internal/auth"
|
||||
)
|
||||
|
||||
func TestGetSettingsDefaultPasswordWarning(t *testing.T) {
|
||||
t.Setenv("DS2API_ADMIN_KEY", "")
|
||||
h := newAdminTestHandler(t, `{"keys":["k1"]}`)
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/settings", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
h.getSettings(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var body map[string]any
|
||||
_ = json.Unmarshal(rec.Body.Bytes(), &body)
|
||||
admin, _ := body["admin"].(map[string]any)
|
||||
warn, _ := admin["default_password_warning"].(bool)
|
||||
if !warn {
|
||||
t.Fatalf("expected default password warning true, body=%v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSettingsValidation(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{"keys":["k1"]}`)
|
||||
payload := map[string]any{
|
||||
"runtime": map[string]any{
|
||||
"account_max_inflight": 0,
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b))
|
||||
rec := httptest.NewRecorder()
|
||||
h.updateSettings(rec, req)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSettingsValidationWithMergedRuntimeSnapshot(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"keys":["k1"],
|
||||
"runtime":{
|
||||
"account_max_inflight":8,
|
||||
"global_max_inflight":8
|
||||
}
|
||||
}`)
|
||||
payload := map[string]any{
|
||||
"runtime": map[string]any{
|
||||
"account_max_inflight": 16,
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b))
|
||||
rec := httptest.NewRecorder()
|
||||
h.updateSettings(rec, req)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if !bytes.Contains(rec.Body.Bytes(), []byte("runtime.global_max_inflight")) {
|
||||
t.Fatalf("expected merged runtime validation detail, got %s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSettingsWithoutRuntimeSkipsMergedRuntimeValidation(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"keys":["k1"],
|
||||
"runtime":{
|
||||
"account_max_inflight":8,
|
||||
"global_max_inflight":4
|
||||
}
|
||||
}`)
|
||||
payload := map[string]any{
|
||||
"responses": map[string]any{
|
||||
"store_ttl_seconds": 600,
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b))
|
||||
rec := httptest.NewRecorder()
|
||||
h.updateSettings(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if got := h.Store.Snapshot().Responses.StoreTTLSeconds; got != 600 {
|
||||
t.Fatalf("store_ttl_seconds=%d want=600", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSettingsHotReloadRuntime(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"keys":["k1"],
|
||||
"accounts":[{"email":"a@test.com","token":"t1"},{"email":"b@test.com","token":"t2"}]
|
||||
}`)
|
||||
|
||||
payload := map[string]any{
|
||||
"runtime": map[string]any{
|
||||
"account_max_inflight": 3,
|
||||
"account_max_queue": 20,
|
||||
"global_max_inflight": 5,
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b))
|
||||
rec := httptest.NewRecorder()
|
||||
h.updateSettings(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
status := h.Pool.Status()
|
||||
if got := intFrom(status["max_inflight_per_account"]); got != 3 {
|
||||
t.Fatalf("max_inflight_per_account=%d want=3", got)
|
||||
}
|
||||
if got := intFrom(status["max_queue_size"]); got != 20 {
|
||||
t.Fatalf("max_queue_size=%d want=20", got)
|
||||
}
|
||||
if got := intFrom(status["global_max_inflight"]); got != 5 {
|
||||
t.Fatalf("global_max_inflight=%d want=5", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSettingsPasswordInvalidatesOldJWT(t *testing.T) {
|
||||
hash := authn.HashAdminPassword("old-password")
|
||||
h := newAdminTestHandler(t, `{"admin":{"password_hash":"`+hash+`"}}`)
|
||||
|
||||
token, err := authn.CreateJWTWithStore(1, h.Store)
|
||||
if err != nil {
|
||||
t.Fatalf("create jwt failed: %v", err)
|
||||
}
|
||||
if _, err := authn.VerifyJWTWithStore(token, h.Store); err != nil {
|
||||
t.Fatalf("verify before update failed: %v", err)
|
||||
}
|
||||
|
||||
body := map[string]any{"new_password": "new-password"}
|
||||
b, _ := json.Marshal(body)
|
||||
req := httptest.NewRequest(http.MethodPost, "/admin/settings/password", bytes.NewReader(b))
|
||||
rec := httptest.NewRecorder()
|
||||
h.updateSettingsPassword(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
if _, err := authn.VerifyJWTWithStore(token, h.Store); err == nil {
|
||||
t.Fatal("expected old token to be invalid after password update")
|
||||
}
|
||||
if !authn.VerifyAdminCredential("new-password", h.Store) {
|
||||
t.Fatal("expected new password credential to be accepted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigImportMergeAndReplace(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"keys":["k1"],
|
||||
"accounts":[{"email":"a@test.com","password":"p1"}]
|
||||
}`)
|
||||
|
||||
merge := map[string]any{
|
||||
"mode": "merge",
|
||||
"config": map[string]any{
|
||||
"keys": []any{"k1", "k2"},
|
||||
"accounts": []any{
|
||||
map[string]any{"email": "a@test.com", "password": "p1"},
|
||||
map[string]any{"email": "b@test.com", "password": "p2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
mergeBytes, _ := json.Marshal(merge)
|
||||
mergeReq := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=merge", bytes.NewReader(mergeBytes))
|
||||
mergeRec := httptest.NewRecorder()
|
||||
h.configImport(mergeRec, mergeReq)
|
||||
if mergeRec.Code != http.StatusOK {
|
||||
t.Fatalf("merge status=%d body=%s", mergeRec.Code, mergeRec.Body.String())
|
||||
}
|
||||
if got := len(h.Store.Keys()); got != 2 {
|
||||
t.Fatalf("keys after merge=%d want=2", got)
|
||||
}
|
||||
if got := len(h.Store.Accounts()); got != 2 {
|
||||
t.Fatalf("accounts after merge=%d want=2", got)
|
||||
}
|
||||
|
||||
replace := map[string]any{
|
||||
"mode": "replace",
|
||||
"config": map[string]any{
|
||||
"keys": []any{"k9"},
|
||||
},
|
||||
}
|
||||
replaceBytes, _ := json.Marshal(replace)
|
||||
replaceReq := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=replace", bytes.NewReader(replaceBytes))
|
||||
replaceRec := httptest.NewRecorder()
|
||||
h.configImport(replaceRec, replaceReq)
|
||||
if replaceRec.Code != http.StatusOK {
|
||||
t.Fatalf("replace status=%d body=%s", replaceRec.Code, replaceRec.Body.String())
|
||||
}
|
||||
keys := h.Store.Keys()
|
||||
if len(keys) != 1 || keys[0] != "k9" {
|
||||
t.Fatalf("unexpected keys after replace: %#v", keys)
|
||||
}
|
||||
if got := len(h.Store.Accounts()); got != 0 {
|
||||
t.Fatalf("accounts after replace=%d want=0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigImportRejectsInvalidRuntimeBounds(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{"keys":["k1"]}`)
|
||||
payload := map[string]any{
|
||||
"mode": "replace",
|
||||
"config": map[string]any{
|
||||
"keys": []any{"k2"},
|
||||
"runtime": map[string]any{
|
||||
"account_max_inflight": 300,
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=replace", bytes.NewReader(b))
|
||||
rec := httptest.NewRecorder()
|
||||
h.configImport(rec, req)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if !bytes.Contains(rec.Body.Bytes(), []byte("runtime.account_max_inflight")) {
|
||||
t.Fatalf("expected runtime bound detail, got %s", rec.Body.String())
|
||||
}
|
||||
keys := h.Store.Keys()
|
||||
if len(keys) != 1 || keys[0] != "k1" {
|
||||
t.Fatalf("store should remain unchanged, keys=%v", keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigImportRejectsMergedRuntimeConflict(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"keys":["k1"],
|
||||
"runtime":{
|
||||
"account_max_inflight":8,
|
||||
"global_max_inflight":8
|
||||
}
|
||||
}`)
|
||||
payload := map[string]any{
|
||||
"mode": "merge",
|
||||
"config": map[string]any{
|
||||
"runtime": map[string]any{
|
||||
"account_max_inflight": 16,
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=merge", bytes.NewReader(b))
|
||||
rec := httptest.NewRecorder()
|
||||
h.configImport(rec, req)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if !bytes.Contains(rec.Body.Bytes(), []byte("runtime.global_max_inflight")) {
|
||||
t.Fatalf("expected merged runtime validation detail, got %s", rec.Body.String())
|
||||
}
|
||||
snap := h.Store.Snapshot()
|
||||
if snap.Runtime.AccountMaxInflight != 8 || snap.Runtime.GlobalMaxInflight != 8 {
|
||||
t.Fatalf("runtime should remain unchanged, runtime=%+v", snap.Runtime)
|
||||
}
|
||||
}
|
||||
@@ -3,8 +3,8 @@ package admin
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -19,6 +19,62 @@ func (h *Handler) syncVercel(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
|
||||
return
|
||||
}
|
||||
opts, err := parseVercelSyncOptions(req)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
validated, failed := h.validateAccountsForVercelSync(r.Context(), opts.AutoValidate)
|
||||
_, cfgB64, err := h.Store.ExportJSONAndBase64()
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
params := buildVercelParams(opts.TeamID)
|
||||
headers := map[string]string{"Authorization": "Bearer " + opts.VercelToken}
|
||||
|
||||
envResp, status, err := vercelRequest(r.Context(), client, http.MethodGet, "https://api.vercel.com/v9/projects/"+opts.ProjectID+"/env", params, headers, nil)
|
||||
if err != nil || status != http.StatusOK {
|
||||
writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "获取环境变量失败"})
|
||||
return
|
||||
}
|
||||
envs, _ := envResp["envs"].([]any)
|
||||
status, err = upsertVercelEnv(r.Context(), client, opts.ProjectID, params, headers, envs, "DS2API_CONFIG_JSON", cfgB64)
|
||||
if err != nil || (status != http.StatusOK && status != http.StatusCreated) {
|
||||
writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "更新环境变量失败"})
|
||||
return
|
||||
}
|
||||
savedCreds := h.saveVercelProjectCredentials(r.Context(), client, opts, params, headers, envs)
|
||||
manual, deployURL := triggerVercelDeployment(r.Context(), client, opts.ProjectID, params, headers)
|
||||
_ = h.Store.SetVercelSync(h.computeSyncHash(), time.Now().Unix())
|
||||
result := map[string]any{"success": true, "validated_accounts": validated}
|
||||
if manual {
|
||||
result["message"] = "配置已同步到 Vercel,请手动触发重新部署"
|
||||
result["manual_deploy_required"] = true
|
||||
} else {
|
||||
result["message"] = "配置已同步,正在重新部署..."
|
||||
result["deployment_url"] = deployURL
|
||||
}
|
||||
if len(failed) > 0 {
|
||||
result["failed_accounts"] = failed
|
||||
}
|
||||
if len(savedCreds) > 0 {
|
||||
result["saved_credentials"] = savedCreds
|
||||
}
|
||||
writeJSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
type vercelSyncOptions struct {
|
||||
VercelToken string
|
||||
ProjectID string
|
||||
TeamID string
|
||||
AutoValidate bool
|
||||
SaveCreds bool
|
||||
UsePreconfig bool
|
||||
}
|
||||
|
||||
func parseVercelSyncOptions(req map[string]any) (vercelSyncOptions, error) {
|
||||
vercelToken, _ := req["vercel_token"].(string)
|
||||
projectID, _ := req["project_id"].(string)
|
||||
teamID, _ := req["team_id"].(string)
|
||||
@@ -40,108 +96,117 @@ func (h *Handler) syncVercel(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.TrimSpace(teamID) == "" {
|
||||
teamID = strings.TrimSpace(os.Getenv("VERCEL_TEAM_ID"))
|
||||
}
|
||||
vercelToken = strings.TrimSpace(vercelToken)
|
||||
projectID = strings.TrimSpace(projectID)
|
||||
teamID = strings.TrimSpace(teamID)
|
||||
if vercelToken == "" || projectID == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 Vercel Token 和 Project ID"})
|
||||
return
|
||||
return vercelSyncOptions{}, fmt.Errorf("需要 Vercel Token 和 Project ID")
|
||||
}
|
||||
return vercelSyncOptions{
|
||||
VercelToken: vercelToken,
|
||||
ProjectID: projectID,
|
||||
TeamID: teamID,
|
||||
AutoValidate: autoValidate,
|
||||
SaveCreds: saveCreds,
|
||||
UsePreconfig: usePreconfig,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildVercelParams(teamID string) url.Values {
|
||||
params := url.Values{}
|
||||
if strings.TrimSpace(teamID) != "" {
|
||||
params.Set("teamId", strings.TrimSpace(teamID))
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
func (h *Handler) validateAccountsForVercelSync(ctx context.Context, enabled bool) (int, []string) {
|
||||
if !enabled {
|
||||
return 0, nil
|
||||
}
|
||||
validated, failed := 0, []string{}
|
||||
if autoValidate {
|
||||
for _, acc := range h.Store.Snapshot().Accounts {
|
||||
if strings.TrimSpace(acc.Token) != "" {
|
||||
continue
|
||||
}
|
||||
token, err := h.DS.Login(r.Context(), acc)
|
||||
if err != nil {
|
||||
failed = append(failed, acc.Identifier())
|
||||
} else {
|
||||
validated++
|
||||
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
for _, acc := range h.Store.Snapshot().Accounts {
|
||||
if strings.TrimSpace(acc.Token) != "" {
|
||||
continue
|
||||
}
|
||||
token, err := h.DS.Login(ctx, acc)
|
||||
if err != nil {
|
||||
failed = append(failed, acc.Identifier())
|
||||
} else {
|
||||
validated++
|
||||
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
return validated, failed
|
||||
}
|
||||
|
||||
cfgJSON, _, err := h.Store.ExportJSONAndBase64()
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
func upsertVercelEnv(ctx context.Context, client *http.Client, projectID string, params url.Values, headers map[string]string, envs []any, key, value string) (int, error) {
|
||||
existingID := findEnvID(envs, key)
|
||||
if existingID != "" {
|
||||
_, status, err := vercelRequest(ctx, client, http.MethodPatch, "https://api.vercel.com/v9/projects/"+projectID+"/env/"+existingID, params, headers, map[string]any{"value": value})
|
||||
return status, err
|
||||
}
|
||||
cfgB64 := base64.StdEncoding.EncodeToString([]byte(cfgJSON))
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
params := url.Values{}
|
||||
if teamID != "" {
|
||||
params.Set("teamId", teamID)
|
||||
_, status, err := vercelRequest(ctx, client, http.MethodPost, "https://api.vercel.com/v10/projects/"+projectID+"/env", params, headers, map[string]any{
|
||||
"key": key,
|
||||
"value": value,
|
||||
"type": "encrypted",
|
||||
"target": []string{"production", "preview"},
|
||||
})
|
||||
return status, err
|
||||
}
|
||||
|
||||
func (h *Handler) saveVercelProjectCredentials(ctx context.Context, client *http.Client, opts vercelSyncOptions, params url.Values, headers map[string]string, envs []any) []string {
|
||||
if !opts.SaveCreds || opts.UsePreconfig {
|
||||
return nil
|
||||
}
|
||||
headers := map[string]string{"Authorization": "Bearer " + vercelToken}
|
||||
envResp, status, err := vercelRequest(r.Context(), client, http.MethodGet, "https://api.vercel.com/v9/projects/"+projectID+"/env", params, headers, nil)
|
||||
if err != nil || status != http.StatusOK {
|
||||
writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "获取环境变量失败"})
|
||||
return
|
||||
saved := []string{}
|
||||
creds := [][2]string{{"VERCEL_TOKEN", opts.VercelToken}, {"VERCEL_PROJECT_ID", opts.ProjectID}}
|
||||
if opts.TeamID != "" {
|
||||
creds = append(creds, [2]string{"VERCEL_TEAM_ID", opts.TeamID})
|
||||
}
|
||||
envs, _ := envResp["envs"].([]any)
|
||||
existingEnvID := findEnvID(envs, "DS2API_CONFIG_JSON")
|
||||
if existingEnvID != "" {
|
||||
_, status, err = vercelRequest(r.Context(), client, http.MethodPatch, "https://api.vercel.com/v9/projects/"+projectID+"/env/"+existingEnvID, params, headers, map[string]any{"value": cfgB64})
|
||||
} else {
|
||||
_, status, err = vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v10/projects/"+projectID+"/env", params, headers, map[string]any{"key": "DS2API_CONFIG_JSON", "value": cfgB64, "type": "encrypted", "target": []string{"production", "preview"}})
|
||||
}
|
||||
if err != nil || (status != http.StatusOK && status != http.StatusCreated) {
|
||||
writeJSON(w, statusOr(status, http.StatusInternalServerError), map[string]any{"detail": "更新环境变量失败"})
|
||||
return
|
||||
}
|
||||
savedCreds := []string{}
|
||||
if saveCreds && !usePreconfig {
|
||||
creds := [][2]string{{"VERCEL_TOKEN", vercelToken}, {"VERCEL_PROJECT_ID", projectID}}
|
||||
if teamID != "" {
|
||||
creds = append(creds, [2]string{"VERCEL_TEAM_ID", teamID})
|
||||
}
|
||||
for _, kv := range creds {
|
||||
id := findEnvID(envs, kv[0])
|
||||
if id != "" {
|
||||
_, status, _ = vercelRequest(r.Context(), client, http.MethodPatch, "https://api.vercel.com/v9/projects/"+projectID+"/env/"+id, params, headers, map[string]any{"value": kv[1]})
|
||||
} else {
|
||||
_, status, _ = vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v10/projects/"+projectID+"/env", params, headers, map[string]any{"key": kv[0], "value": kv[1], "type": "encrypted", "target": []string{"production", "preview"}})
|
||||
}
|
||||
if status == http.StatusOK || status == http.StatusCreated {
|
||||
savedCreds = append(savedCreds, kv[0])
|
||||
}
|
||||
for _, kv := range creds {
|
||||
status, _ := upsertVercelEnv(ctx, client, opts.ProjectID, params, headers, envs, kv[0], kv[1])
|
||||
if status == http.StatusOK || status == http.StatusCreated {
|
||||
saved = append(saved, kv[0])
|
||||
}
|
||||
}
|
||||
projectResp, status, _ := vercelRequest(r.Context(), client, http.MethodGet, "https://api.vercel.com/v9/projects/"+projectID, params, headers, nil)
|
||||
manual := true
|
||||
deployURL := ""
|
||||
if status == http.StatusOK {
|
||||
if link, ok := projectResp["link"].(map[string]any); ok {
|
||||
if linkType, _ := link["type"].(string); linkType == "github" {
|
||||
repoID := intFrom(link["repoId"])
|
||||
ref, _ := link["productionBranch"].(string)
|
||||
if ref == "" {
|
||||
ref = "main"
|
||||
}
|
||||
depResp, depStatus, _ := vercelRequest(r.Context(), client, http.MethodPost, "https://api.vercel.com/v13/deployments", params, headers, map[string]any{"name": projectID, "project": projectID, "target": "production", "gitSource": map[string]any{"type": "github", "repoId": repoID, "ref": ref}})
|
||||
if depStatus == http.StatusOK || depStatus == http.StatusCreated {
|
||||
deployURL, _ = depResp["url"].(string)
|
||||
manual = false
|
||||
}
|
||||
}
|
||||
}
|
||||
return saved
|
||||
}
|
||||
|
||||
func triggerVercelDeployment(ctx context.Context, client *http.Client, projectID string, params url.Values, headers map[string]string) (bool, string) {
|
||||
projectResp, status, _ := vercelRequest(ctx, client, http.MethodGet, "https://api.vercel.com/v9/projects/"+projectID, params, headers, nil)
|
||||
if status != http.StatusOK {
|
||||
return true, ""
|
||||
}
|
||||
_ = h.Store.SetVercelSync(h.computeSyncHash(), time.Now().Unix())
|
||||
result := map[string]any{"success": true, "validated_accounts": validated}
|
||||
if manual {
|
||||
result["message"] = "配置已同步到 Vercel,请手动触发重新部署"
|
||||
result["manual_deploy_required"] = true
|
||||
} else {
|
||||
result["message"] = "配置已同步,正在重新部署..."
|
||||
result["deployment_url"] = deployURL
|
||||
link, ok := projectResp["link"].(map[string]any)
|
||||
if !ok {
|
||||
return true, ""
|
||||
}
|
||||
if len(failed) > 0 {
|
||||
result["failed_accounts"] = failed
|
||||
linkType, _ := link["type"].(string)
|
||||
if linkType != "github" {
|
||||
return true, ""
|
||||
}
|
||||
if len(savedCreds) > 0 {
|
||||
result["saved_credentials"] = savedCreds
|
||||
repoID := intFrom(link["repoId"])
|
||||
ref, _ := link["productionBranch"].(string)
|
||||
if ref == "" {
|
||||
ref = "main"
|
||||
}
|
||||
writeJSON(w, http.StatusOK, result)
|
||||
depResp, depStatus, _ := vercelRequest(ctx, client, http.MethodPost, "https://api.vercel.com/v13/deployments", params, headers, map[string]any{
|
||||
"name": projectID,
|
||||
"project": projectID,
|
||||
"target": "production",
|
||||
"gitSource": map[string]any{
|
||||
"type": "github",
|
||||
"repoId": repoID,
|
||||
"ref": ref,
|
||||
},
|
||||
})
|
||||
if depStatus != http.StatusOK && depStatus != http.StatusCreated {
|
||||
return true, ""
|
||||
}
|
||||
deployURL, _ := depResp["url"].(string)
|
||||
return false, deployURL
|
||||
}
|
||||
|
||||
func (h *Handler) vercelStatus(w http.ResponseWriter, _ *http.Request) {
|
||||
|
||||
@@ -96,7 +96,7 @@ func accountMatchesIdentifier(acc config.Account, identifier string) bool {
|
||||
return acc.Identifier() == id
|
||||
}
|
||||
|
||||
func findAccountByIdentifier(store *config.Store, identifier string) (config.Account, bool) {
|
||||
func findAccountByIdentifier(store ConfigStore, identifier string) (config.Account, bool) {
|
||||
id := strings.TrimSpace(identifier)
|
||||
if id == "" {
|
||||
return config.Account{}, false
|
||||
|
||||
23
internal/admin/request_error.go
Normal file
23
internal/admin/request_error.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package admin
|
||||
|
||||
import "errors"
|
||||
|
||||
type requestError struct {
|
||||
detail string
|
||||
}
|
||||
|
||||
func (e *requestError) Error() string {
|
||||
return e.detail
|
||||
}
|
||||
|
||||
func newRequestError(detail string) error {
|
||||
return &requestError{detail: detail}
|
||||
}
|
||||
|
||||
func requestErrorDetail(err error) (string, bool) {
|
||||
var reqErr *requestError
|
||||
if errors.As(err, &reqErr) {
|
||||
return reqErr.detail, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
64
internal/admin/settings_validation.go
Normal file
64
internal/admin/settings_validation.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func normalizeSettingsConfig(c *config.Config) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.Admin.PasswordHash = strings.TrimSpace(c.Admin.PasswordHash)
|
||||
c.Toolcall.Mode = strings.ToLower(strings.TrimSpace(c.Toolcall.Mode))
|
||||
c.Toolcall.EarlyEmitConfidence = strings.ToLower(strings.TrimSpace(c.Toolcall.EarlyEmitConfidence))
|
||||
c.Embeddings.Provider = strings.TrimSpace(c.Embeddings.Provider)
|
||||
}
|
||||
|
||||
func validateSettingsConfig(c config.Config) error {
|
||||
if c.Admin.JWTExpireHours != 0 && (c.Admin.JWTExpireHours < 1 || c.Admin.JWTExpireHours > 720) {
|
||||
return fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720")
|
||||
}
|
||||
if err := validateRuntimeSettings(c.Runtime); err != nil {
|
||||
return err
|
||||
}
|
||||
if c.Responses.StoreTTLSeconds != 0 && (c.Responses.StoreTTLSeconds < 30 || c.Responses.StoreTTLSeconds > 86400) {
|
||||
return fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400")
|
||||
}
|
||||
if mode := strings.TrimSpace(c.Toolcall.Mode); mode != "" {
|
||||
switch mode {
|
||||
case "feature_match", "off":
|
||||
default:
|
||||
return fmt.Errorf("toolcall.mode must be feature_match or off")
|
||||
}
|
||||
}
|
||||
if level := strings.TrimSpace(c.Toolcall.EarlyEmitConfidence); level != "" {
|
||||
switch level {
|
||||
case "high", "low", "off":
|
||||
default:
|
||||
return fmt.Errorf("toolcall.early_emit_confidence must be high, low or off")
|
||||
}
|
||||
}
|
||||
if c.Embeddings.Provider != "" && strings.TrimSpace(c.Embeddings.Provider) == "" {
|
||||
return fmt.Errorf("embeddings.provider cannot be empty")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateRuntimeSettings(runtime config.RuntimeConfig) error {
|
||||
if runtime.AccountMaxInflight != 0 && (runtime.AccountMaxInflight < 1 || runtime.AccountMaxInflight > 256) {
|
||||
return fmt.Errorf("runtime.account_max_inflight must be between 1 and 256")
|
||||
}
|
||||
if runtime.AccountMaxQueue != 0 && (runtime.AccountMaxQueue < 1 || runtime.AccountMaxQueue > 200000) {
|
||||
return fmt.Errorf("runtime.account_max_queue must be between 1 and 200000")
|
||||
}
|
||||
if runtime.GlobalMaxInflight != 0 && (runtime.GlobalMaxInflight < 1 || runtime.GlobalMaxInflight > 200000) {
|
||||
return fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000")
|
||||
}
|
||||
if runtime.AccountMaxInflight > 0 && runtime.GlobalMaxInflight > 0 && runtime.GlobalMaxInflight < runtime.AccountMaxInflight {
|
||||
return fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -3,7 +3,9 @@ package auth
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
@@ -17,7 +19,22 @@ import (
|
||||
|
||||
var warnOnce sync.Once
|
||||
|
||||
type AdminConfigReader interface {
|
||||
AdminPasswordHash() string
|
||||
AdminJWTExpireHours() int
|
||||
AdminJWTValidAfterUnix() int64
|
||||
}
|
||||
|
||||
func AdminKey() string {
|
||||
return effectiveAdminKey(nil)
|
||||
}
|
||||
|
||||
func effectiveAdminKey(store AdminConfigReader) string {
|
||||
if store != nil {
|
||||
if hash := strings.TrimSpace(store.AdminPasswordHash()); hash != "" {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
if v := strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")); v != "" {
|
||||
return v
|
||||
}
|
||||
@@ -27,14 +44,24 @@ func AdminKey() string {
|
||||
return "admin"
|
||||
}
|
||||
|
||||
func jwtSecret() string {
|
||||
func jwtSecret(store AdminConfigReader) string {
|
||||
if v := strings.TrimSpace(os.Getenv("DS2API_JWT_SECRET")); v != "" {
|
||||
return v
|
||||
}
|
||||
return AdminKey()
|
||||
if store != nil {
|
||||
if hash := strings.TrimSpace(store.AdminPasswordHash()); hash != "" {
|
||||
return hash
|
||||
}
|
||||
}
|
||||
return effectiveAdminKey(store)
|
||||
}
|
||||
|
||||
func jwtExpireHours() int {
|
||||
func jwtExpireHours(store AdminConfigReader) int {
|
||||
if store != nil {
|
||||
if n := store.AdminJWTExpireHours(); n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
if v := strings.TrimSpace(os.Getenv("DS2API_JWT_EXPIRE_HOURS")); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||
return n
|
||||
@@ -44,27 +71,44 @@ func jwtExpireHours() int {
|
||||
}
|
||||
|
||||
func CreateJWT(expireHours int) (string, error) {
|
||||
return CreateJWTWithStore(expireHours, nil)
|
||||
}
|
||||
|
||||
func CreateJWTWithStore(expireHours int, store AdminConfigReader) (string, error) {
|
||||
if expireHours <= 0 {
|
||||
expireHours = jwtExpireHours()
|
||||
expireHours = jwtExpireHours(store)
|
||||
}
|
||||
issuedAt := time.Now().Unix()
|
||||
// If sessions were invalidated in this same second, move iat forward by
|
||||
// one second so newly minted tokens remain valid with strict cutoff checks.
|
||||
if store != nil {
|
||||
if validAfter := store.AdminJWTValidAfterUnix(); validAfter >= issuedAt {
|
||||
issuedAt = validAfter + 1
|
||||
}
|
||||
}
|
||||
expireAt := time.Unix(issuedAt, 0).Add(time.Duration(expireHours) * time.Hour).Unix()
|
||||
header := map[string]any{"alg": "HS256", "typ": "JWT"}
|
||||
payload := map[string]any{"iat": time.Now().Unix(), "exp": time.Now().Add(time.Duration(expireHours) * time.Hour).Unix(), "role": "admin"}
|
||||
payload := map[string]any{"iat": issuedAt, "exp": expireAt, "role": "admin"}
|
||||
h, _ := json.Marshal(header)
|
||||
p, _ := json.Marshal(payload)
|
||||
headerB64 := rawB64Encode(h)
|
||||
payloadB64 := rawB64Encode(p)
|
||||
msg := headerB64 + "." + payloadB64
|
||||
sig := signHS256(msg)
|
||||
sig := signHS256(msg, store)
|
||||
return msg + "." + rawB64Encode(sig), nil
|
||||
}
|
||||
|
||||
func VerifyJWT(token string) (map[string]any, error) {
|
||||
return VerifyJWTWithStore(token, nil)
|
||||
}
|
||||
|
||||
func VerifyJWTWithStore(token string, store AdminConfigReader) (map[string]any, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, errors.New("invalid token format")
|
||||
}
|
||||
msg := parts[0] + "." + parts[1]
|
||||
expected := signHS256(msg)
|
||||
expected := signHS256(msg, store)
|
||||
actual, err := rawB64Decode(parts[2])
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid signature")
|
||||
@@ -84,10 +128,23 @@ func VerifyJWT(token string) (map[string]any, error) {
|
||||
if int64(exp) < time.Now().Unix() {
|
||||
return nil, errors.New("token expired")
|
||||
}
|
||||
if store != nil {
|
||||
validAfter := store.AdminJWTValidAfterUnix()
|
||||
if validAfter > 0 {
|
||||
iat, _ := payload["iat"].(float64)
|
||||
if int64(iat) <= validAfter {
|
||||
return nil, errors.New("token expired")
|
||||
}
|
||||
}
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func VerifyAdminRequest(r *http.Request) error {
|
||||
return VerifyAdminRequestWithStore(r, nil)
|
||||
}
|
||||
|
||||
func VerifyAdminRequestWithStore(r *http.Request, store AdminConfigReader) error {
|
||||
authHeader := strings.TrimSpace(r.Header.Get("Authorization"))
|
||||
if !strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
||||
return errors.New("authentication required")
|
||||
@@ -96,17 +153,65 @@ func VerifyAdminRequest(r *http.Request) error {
|
||||
if token == "" {
|
||||
return errors.New("authentication required")
|
||||
}
|
||||
if token == AdminKey() {
|
||||
if VerifyAdminCredential(token, store) {
|
||||
return nil
|
||||
}
|
||||
if _, err := VerifyJWT(token); err == nil {
|
||||
if _, err := VerifyJWTWithStore(token, store); err == nil {
|
||||
return nil
|
||||
}
|
||||
return errors.New("invalid credentials")
|
||||
}
|
||||
|
||||
func signHS256(msg string) []byte {
|
||||
h := hmac.New(sha256.New, []byte(jwtSecret()))
|
||||
func VerifyAdminCredential(candidate string, store AdminConfigReader) bool {
|
||||
candidate = strings.TrimSpace(candidate)
|
||||
if candidate == "" {
|
||||
return false
|
||||
}
|
||||
if store != nil {
|
||||
hash := strings.TrimSpace(store.AdminPasswordHash())
|
||||
if hash != "" {
|
||||
return verifyAdminPasswordHash(candidate, hash)
|
||||
}
|
||||
}
|
||||
key := effectiveAdminKey(store)
|
||||
if key == "" {
|
||||
return false
|
||||
}
|
||||
return subtle.ConstantTimeCompare([]byte(candidate), []byte(key)) == 1
|
||||
}
|
||||
|
||||
func UsingDefaultAdminKey(store AdminConfigReader) bool {
|
||||
if store != nil && strings.TrimSpace(store.AdminPasswordHash()) != "" {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")) == ""
|
||||
}
|
||||
|
||||
func HashAdminPassword(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(raw))
|
||||
return "sha256:" + hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func verifyAdminPasswordHash(candidate, encoded string) bool {
|
||||
encoded = strings.TrimSpace(strings.ToLower(encoded))
|
||||
if encoded == "" {
|
||||
return false
|
||||
}
|
||||
if strings.HasPrefix(encoded, "sha256:") {
|
||||
want := strings.TrimPrefix(encoded, "sha256:")
|
||||
sum := sha256.Sum256([]byte(candidate))
|
||||
got := hex.EncodeToString(sum[:])
|
||||
return subtle.ConstantTimeCompare([]byte(got), []byte(want)) == 1
|
||||
}
|
||||
return subtle.ConstantTimeCompare([]byte(candidate), []byte(encoded)) == 1
|
||||
}
|
||||
|
||||
// signHS256 returns the HMAC-SHA256 signature of msg, keyed by the JWT
// secret derived from the store configuration.
func signHS256(msg string, store AdminConfigReader) []byte {
	mac := hmac.New(sha256.New, []byte(jwtSecret(store)))
	_, _ = mac.Write([]byte(msg)) // hash.Hash.Write never returns an error
	return mac.Sum(nil)
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package auth
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func TestJWTCreateVerify(t *testing.T) {
|
||||
@@ -27,3 +29,58 @@ func TestVerifyAdminRequest(t *testing.T) {
|
||||
t.Fatalf("expected token accepted: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyJWTWithStoreValidAfter verifies that raising the admin
// jwt_valid_after_unix cutoff in the store invalidates tokens issued before
// the update.
func TestVerifyJWTWithStoreValidAfter(t *testing.T) {
	// Seed the config via env before loading the store.
	t.Setenv("DS2API_CONFIG_JSON", `{"admin":{"password_hash":"`+HashAdminPassword("oldpass")+`"}}`)
	store := config.LoadStore()
	token, err := CreateJWTWithStore(1, store)
	if err != nil {
		t.Fatalf("create jwt failed: %v", err)
	}
	if _, err := VerifyJWTWithStore(token, store); err != nil {
		t.Fatalf("verify before invalidation failed: %v", err)
	}
	// Push the cutoff far into the future so any existing iat is older.
	if err := store.Update(func(c *config.Config) error {
		c.Admin.JWTValidAfterUnix = 1<<62 - 1
		return nil
	}); err != nil {
		t.Fatalf("set valid-after failed: %v", err)
	}
	if _, err := VerifyJWTWithStore(token, store); err == nil {
		t.Fatal("expected token invalid after valid-after update")
	}
}
|
||||
|
||||
// TestVerifyJWTWithStoreSameSecondInvalidationAndRelogin pins the boundary
// semantics of the valid-after cutoff: a token whose iat equals the cutoff
// is rejected, while a token issued afterwards (even within the same second)
// is accepted — so a password change immediately followed by re-login works.
func TestVerifyJWTWithStoreSameSecondInvalidationAndRelogin(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{"admin":{"password_hash":"`+HashAdminPassword("oldpass")+`"}}`)
	store := config.LoadStore()

	oldToken, err := CreateJWTWithStore(1, store)
	if err != nil {
		t.Fatalf("create old jwt failed: %v", err)
	}
	oldPayload, err := VerifyJWTWithStore(oldToken, store)
	if err != nil {
		t.Fatalf("verify old jwt before invalidation failed: %v", err)
	}
	// JSON-decoded claims carry numbers as float64.
	oldIAT, _ := oldPayload["iat"].(float64)

	// Set the cutoff exactly to the old token's issue time.
	if err := store.Update(func(c *config.Config) error {
		c.Admin.JWTValidAfterUnix = int64(oldIAT)
		return nil
	}); err != nil {
		t.Fatalf("set valid-after failed: %v", err)
	}

	if _, err := VerifyJWTWithStore(oldToken, store); err == nil {
		t.Fatal("expected old token invalid when iat == valid-after")
	}

	newToken, err := CreateJWTWithStore(1, store)
	if err != nil {
		t.Fatalf("create new jwt failed: %v", err)
	}
	if _, err := VerifyJWTWithStore(newToken, store); err != nil {
		t.Fatalf("expected new token valid after invalidation cutoff: %v", err)
	}
}
|
||||
|
||||
48
internal/claudeconv/convert.go
Normal file
48
internal/claudeconv/convert.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package claudeconv
|
||||
|
||||
import "strings"
|
||||
|
||||
type ClaudeMappingProvider interface {
|
||||
ClaudeMapping() map[string]string
|
||||
}
|
||||
|
||||
func ConvertClaudeToDeepSeek(claudeReq map[string]any, mappingProvider ClaudeMappingProvider, defaultClaudeModel string) map[string]any {
|
||||
messages, _ := claudeReq["messages"].([]any)
|
||||
model, _ := claudeReq["model"].(string)
|
||||
if model == "" {
|
||||
model = defaultClaudeModel
|
||||
}
|
||||
|
||||
mapping := map[string]string{}
|
||||
if mappingProvider != nil {
|
||||
mapping = mappingProvider.ClaudeMapping()
|
||||
}
|
||||
dsModel := mapping["fast"]
|
||||
if dsModel == "" {
|
||||
dsModel = "deepseek-chat"
|
||||
}
|
||||
|
||||
modelLower := strings.ToLower(model)
|
||||
if strings.Contains(modelLower, "opus") || strings.Contains(modelLower, "reasoner") || strings.Contains(modelLower, "slow") {
|
||||
if slow := mapping["slow"]; slow != "" {
|
||||
dsModel = slow
|
||||
}
|
||||
}
|
||||
|
||||
convertedMessages := make([]any, 0, len(messages)+1)
|
||||
if system, ok := claudeReq["system"].(string); ok && system != "" {
|
||||
convertedMessages = append(convertedMessages, map[string]any{"role": "system", "content": system})
|
||||
}
|
||||
convertedMessages = append(convertedMessages, messages...)
|
||||
|
||||
out := map[string]any{"model": dsModel, "messages": convertedMessages}
|
||||
for _, k := range []string{"temperature", "top_p", "stream"} {
|
||||
if v, ok := claudeReq[k]; ok {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
if stopSeq, ok := claudeReq["stop_sequences"]; ok {
|
||||
out["stop"] = stopSeq
|
||||
}
|
||||
return out
|
||||
}
|
||||
142
internal/compat/go_compat_test.go
Normal file
142
internal/compat/go_compat_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package compat
|
||||
|
||||
import (
	"encoding/json"
	"os"
	"path/filepath"
	"reflect"
	"strings"
	"testing"

	"ds2api/internal/sse"
	"ds2api/internal/util"
)
|
||||
|
||||
// TestGoCompatSSEFixtures runs every SSE chunk fixture shared with the Node
// implementation through sse.ParseSSEChunkForContent and compares parts,
// finished flag, and the resulting fragment type against the expected files.
func TestGoCompatSSEFixtures(t *testing.T) {
	files, err := filepath.Glob(compatPath("fixtures", "sse_chunks", "*.json"))
	if err != nil {
		t.Fatalf("glob fixtures failed: %v", err)
	}
	if len(files) == 0 {
		t.Fatal("no sse fixtures found")
	}
	for _, fixturePath := range files {
		// Expected output lives next to the fixture under expected/sse_<name>.json.
		name := trimExt(filepath.Base(fixturePath))
		expectedPath := compatPath("expected", "sse_"+name+".json")

		var fixture struct {
			Chunk          map[string]any `json:"chunk"`
			ThinkingEnable bool           `json:"thinking_enabled"`
			CurrentType    string         `json:"current_type"`
		}
		mustLoadJSON(t, fixturePath, &fixture)

		var expected struct {
			Parts    []map[string]any `json:"parts"`
			Finished bool             `json:"finished"`
			NewType  string           `json:"new_type"`
		}
		mustLoadJSON(t, expectedPath, &expected)

		parts, finished, newType := sse.ParseSSEChunkForContent(fixture.Chunk, fixture.ThinkingEnable, fixture.CurrentType)
		// Re-shape typed parts into plain maps so DeepEqual matches the JSON form.
		gotParts := make([]map[string]any, 0, len(parts))
		for _, p := range parts {
			gotParts = append(gotParts, map[string]any{
				"text": p.Text,
				"type": p.Type,
			})
		}
		if !reflect.DeepEqual(gotParts, expected.Parts) || finished != expected.Finished || newType != expected.NewType {
			t.Fatalf("fixture %s mismatch:\n got parts=%#v finished=%v newType=%q\nwant parts=%#v finished=%v newType=%q",
				name, gotParts, finished, newType, expected.Parts, expected.Finished, expected.NewType)
		}
	}
}
|
||||
|
||||
// TestGoCompatToolcallFixtures checks util.ParseToolCalls against the shared
// tool-call fixtures so Go and Node stay in lockstep.
func TestGoCompatToolcallFixtures(t *testing.T) {
	files, err := filepath.Glob(compatPath("fixtures", "toolcalls", "*.json"))
	if err != nil {
		t.Fatalf("glob toolcall fixtures failed: %v", err)
	}
	if len(files) == 0 {
		t.Fatal("no toolcall fixtures found")
	}
	for _, fixturePath := range files {
		name := trimExt(filepath.Base(fixturePath))
		expectedPath := compatPath("expected", "toolcalls_"+name+".json")

		var fixture struct {
			Text      string   `json:"text"`
			ToolNames []string `json:"tool_names"`
		}
		mustLoadJSON(t, fixturePath, &fixture)

		var expected struct {
			Calls []util.ParsedToolCall `json:"calls"`
		}
		mustLoadJSON(t, expectedPath, &expected)

		got := util.ParseToolCalls(fixture.Text, fixture.ToolNames)
		// nil and empty slice both mean "no calls"; DeepEqual would
		// distinguish them, so treat that pair as equal explicitly.
		if len(got) == 0 && len(expected.Calls) == 0 {
			continue
		}
		if !reflect.DeepEqual(got, expected.Calls) {
			t.Fatalf("toolcall fixture %s mismatch:\n got=%#v\nwant=%#v", name, got, expected.Calls)
		}
	}
}
|
||||
|
||||
// TestGoCompatTokenFixtures compares util.EstimateTokens against the shared
// token-count cases, matching fixture and expectation entries by name.
func TestGoCompatTokenFixtures(t *testing.T) {
	var fixture struct {
		Cases []struct {
			Name string `json:"name"`
			Text string `json:"text"`
		} `json:"cases"`
	}
	mustLoadJSON(t, compatPath("fixtures", "token_cases.json"), &fixture)

	var expected struct {
		Cases []struct {
			Name   string `json:"name"`
			Tokens int    `json:"tokens"`
		} `json:"cases"`
	}
	mustLoadJSON(t, compatPath("expected", "token_cases.json"), &expected)

	// Index expectations by case name so fixture order doesn't matter.
	expectByName := map[string]int{}
	for _, c := range expected.Cases {
		expectByName[c.Name] = c.Tokens
	}
	for _, c := range fixture.Cases {
		want, ok := expectByName[c.Name]
		if !ok {
			t.Fatalf("missing expected token case: %s", c.Name)
		}
		got := util.EstimateTokens(c.Text)
		if got != want {
			t.Fatalf("token fixture %s mismatch: got=%d want=%d", c.Name, got, want)
		}
	}
}
|
||||
|
||||
// mustLoadJSON reads the file at path and unmarshals it into out, failing
// the calling test on any read or decode error.
func mustLoadJSON(t *testing.T, path string, out any) {
	t.Helper()
	raw, err := os.ReadFile(path)
	if err != nil {
		t.Fatalf("read %s failed: %v", path, err)
	}
	if err := json.Unmarshal(raw, out); err != nil {
		t.Fatalf("decode %s failed: %v", path, err)
	}
}
|
||||
|
||||
func trimExt(name string) string {
|
||||
if len(name) > 5 && name[len(name)-5:] == ".json" {
|
||||
return name[:len(name)-5]
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
// compatPath joins parts under the shared tests/compat directory, relative
// to this package's location in the repository.
func compatPath(parts ...string) string {
	base := filepath.Join("..", "..", "tests", "compat")
	return filepath.Join(append([]string{base}, parts...)...)
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
@@ -63,6 +64,8 @@ type Config struct {
|
||||
ClaudeMapping map[string]string `json:"claude_mapping,omitempty"`
|
||||
ClaudeModelMap map[string]string `json:"claude_model_mapping,omitempty"`
|
||||
ModelAliases map[string]string `json:"model_aliases,omitempty"`
|
||||
Admin AdminConfig `json:"admin,omitempty"`
|
||||
Runtime RuntimeConfig `json:"runtime,omitempty"`
|
||||
Compat CompatConfig `json:"compat,omitempty"`
|
||||
Toolcall ToolcallConfig `json:"toolcall,omitempty"`
|
||||
Responses ResponsesConfig `json:"responses,omitempty"`
|
||||
@@ -76,6 +79,18 @@ type CompatConfig struct {
|
||||
WideInputStrictOutput *bool `json:"wide_input_strict_output,omitempty"`
|
||||
}
|
||||
|
||||
// AdminConfig holds admin authentication settings persisted in the config.
type AdminConfig struct {
	// PasswordHash is the stored admin credential ("sha256:<hex>" or bare value).
	PasswordHash string `json:"password_hash,omitempty"`
	// JWTExpireHours overrides the default token lifetime when > 0.
	JWTExpireHours int `json:"jwt_expire_hours,omitempty"`
	// JWTValidAfterUnix rejects tokens whose iat is at or before this Unix time.
	JWTValidAfterUnix int64 `json:"jwt_valid_after_unix,omitempty"`
}
|
||||
|
||||
// RuntimeConfig holds request-concurrency limits; zero values defer to
// environment variables and built-in defaults (see the Store accessors).
type RuntimeConfig struct {
	// AccountMaxInflight caps concurrent requests per account when > 0.
	AccountMaxInflight int `json:"account_max_inflight,omitempty"`
	// AccountMaxQueue caps queued requests per account when > 0.
	AccountMaxQueue int `json:"account_max_queue,omitempty"`
	// GlobalMaxInflight caps concurrent requests across all accounts when > 0.
	GlobalMaxInflight int `json:"global_max_inflight,omitempty"`
}
|
||||
|
||||
type ToolcallConfig struct {
|
||||
Mode string `json:"mode,omitempty"`
|
||||
EarlyEmitConfidence string `json:"early_emit_confidence,omitempty"`
|
||||
@@ -109,6 +124,12 @@ func (c Config) MarshalJSON() ([]byte, error) {
|
||||
if len(c.ModelAliases) > 0 {
|
||||
m["model_aliases"] = c.ModelAliases
|
||||
}
|
||||
if strings.TrimSpace(c.Admin.PasswordHash) != "" || c.Admin.JWTExpireHours > 0 || c.Admin.JWTValidAfterUnix > 0 {
|
||||
m["admin"] = c.Admin
|
||||
}
|
||||
if c.Runtime.AccountMaxInflight > 0 || c.Runtime.AccountMaxQueue > 0 || c.Runtime.GlobalMaxInflight > 0 {
|
||||
m["runtime"] = c.Runtime
|
||||
}
|
||||
if c.Compat.WideInputStrictOutput != nil {
|
||||
m["compat"] = c.Compat
|
||||
}
|
||||
@@ -158,6 +179,14 @@ func (c *Config) UnmarshalJSON(b []byte) error {
|
||||
if err := json.Unmarshal(v, &c.ModelAliases); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "admin":
|
||||
if err := json.Unmarshal(v, &c.Admin); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "runtime":
|
||||
if err := json.Unmarshal(v, &c.Runtime); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
}
|
||||
case "compat":
|
||||
if err := json.Unmarshal(v, &c.Compat); err != nil {
|
||||
return fmt.Errorf("invalid field %q: %w", k, err)
|
||||
@@ -199,6 +228,8 @@ func (c Config) Clone() Config {
|
||||
ClaudeMapping: cloneStringMap(c.ClaudeMapping),
|
||||
ClaudeModelMap: cloneStringMap(c.ClaudeModelMap),
|
||||
ModelAliases: cloneStringMap(c.ModelAliases),
|
||||
Admin: c.Admin,
|
||||
Runtime: c.Runtime,
|
||||
Compat: CompatConfig{
|
||||
WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput),
|
||||
},
|
||||
@@ -621,3 +652,92 @@ func (s *Store) EmbeddingsProvider() string {
|
||||
defer s.mu.RUnlock()
|
||||
return strings.TrimSpace(s.cfg.Embeddings.Provider)
|
||||
}
|
||||
|
||||
// AdminPasswordHash returns the configured admin password hash, trimmed of
// surrounding whitespace; empty when no password is set.
func (s *Store) AdminPasswordHash() string {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return strings.TrimSpace(s.cfg.Admin.PasswordHash)
}
|
||||
|
||||
// AdminJWTExpireHours returns the JWT lifetime in hours: the config value
// when positive, else a positive integer from DS2API_JWT_EXPIRE_HOURS,
// else the default of 24.
func (s *Store) AdminJWTExpireHours() int {
	s.mu.RLock()
	defer s.mu.RUnlock()
	if hours := s.cfg.Admin.JWTExpireHours; hours > 0 {
		return hours
	}
	raw := strings.TrimSpace(os.Getenv("DS2API_JWT_EXPIRE_HOURS"))
	if raw != "" {
		if n, err := strconv.Atoi(raw); err == nil && n > 0 {
			return n
		}
	}
	return 24
}
|
||||
|
||||
// AdminJWTValidAfterUnix returns the Unix-time cutoff at or before which
// issued JWTs are considered invalid (0 when never set).
func (s *Store) AdminJWTValidAfterUnix() int64 {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.cfg.Admin.JWTValidAfterUnix
}
|
||||
|
||||
// RuntimeAccountMaxInflight returns the per-account concurrency cap: the
// config value when positive, else the first positive integer found in
// DS2API_ACCOUNT_MAX_INFLIGHT or DS2API_ACCOUNT_CONCURRENCY, else 2.
func (s *Store) RuntimeAccountMaxInflight() int {
	s.mu.RLock()
	defer s.mu.RUnlock()
	if v := s.cfg.Runtime.AccountMaxInflight; v > 0 {
		return v
	}
	for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} {
		raw := strings.TrimSpace(os.Getenv(key))
		if raw == "" {
			continue
		}
		if n, err := strconv.Atoi(raw); err == nil && n > 0 {
			return n
		}
	}
	return 2
}
|
||||
|
||||
// RuntimeAccountMaxQueue returns the per-account queue size: the config
// value when positive, else a non-negative integer from
// DS2API_ACCOUNT_MAX_QUEUE or DS2API_ACCOUNT_QUEUE_SIZE (zero disables
// queuing), else defaultSize clamped to a minimum of 0.
func (s *Store) RuntimeAccountMaxQueue(defaultSize int) int {
	s.mu.RLock()
	defer s.mu.RUnlock()
	if v := s.cfg.Runtime.AccountMaxQueue; v > 0 {
		return v
	}
	for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} {
		raw := strings.TrimSpace(os.Getenv(key))
		if raw == "" {
			continue
		}
		// Unlike the inflight caps, zero is a valid (disabling) value here.
		if n, err := strconv.Atoi(raw); err == nil && n >= 0 {
			return n
		}
	}
	if defaultSize < 0 {
		return 0
	}
	return defaultSize
}
|
||||
|
||||
// RuntimeGlobalMaxInflight returns the global concurrency cap: the config
// value when positive, else the first positive integer found in
// DS2API_GLOBAL_MAX_INFLIGHT or DS2API_MAX_INFLIGHT, else defaultSize
// clamped to a minimum of 0.
func (s *Store) RuntimeGlobalMaxInflight(defaultSize int) int {
	s.mu.RLock()
	defer s.mu.RUnlock()
	if v := s.cfg.Runtime.GlobalMaxInflight; v > 0 {
		return v
	}
	for _, key := range []string{"DS2API_GLOBAL_MAX_INFLIGHT", "DS2API_MAX_INFLIGHT"} {
		raw := strings.TrimSpace(os.Getenv(key))
		if raw == "" {
			continue
		}
		if n, err := strconv.Atoi(raw); err == nil && n > 0 {
			return n
		}
	}
	if defaultSize < 0 {
		return 0
	}
	return defaultSize
}
|
||||
|
||||
@@ -10,6 +10,10 @@ type ModelInfo struct {
|
||||
Permission []any `json:"permission,omitempty"`
|
||||
}
|
||||
|
||||
type ModelAliasReader interface {
|
||||
ModelAliases() map[string]string
|
||||
}
|
||||
|
||||
var DeepSeekModels = []ModelInfo{
|
||||
{ID: "deepseek-chat", Object: "model", Created: 1677610602, OwnedBy: "deepseek", Permission: []any{}},
|
||||
{ID: "deepseek-reasoner", Object: "model", Created: 1677610602, OwnedBy: "deepseek", Permission: []any{}},
|
||||
@@ -104,7 +108,7 @@ func DefaultModelAliases() map[string]string {
|
||||
}
|
||||
}
|
||||
|
||||
func ResolveModel(store *Store, requested string) (string, bool) {
|
||||
func ResolveModel(store ModelAliasReader, requested string) (string, bool) {
|
||||
model := lower(strings.TrimSpace(requested))
|
||||
if model == "" {
|
||||
return "", false
|
||||
@@ -172,7 +176,7 @@ func OpenAIModelsResponse() map[string]any {
|
||||
return map[string]any{"object": "list", "data": DeepSeekModels}
|
||||
}
|
||||
|
||||
func OpenAIModelByID(store *Store, id string) (ModelInfo, bool) {
|
||||
func OpenAIModelByID(store ModelAliasReader, id string) (ModelInfo, bool) {
|
||||
canonical, ok := ResolveModel(store, id)
|
||||
if !ok {
|
||||
return ModelInfo{}, false
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
const (
|
||||
DeepSeekHost = "chat.deepseek.com"
|
||||
DeepSeekLoginURL = "https://chat.deepseek.com/api/v0/users/login"
|
||||
@@ -8,7 +13,7 @@ const (
|
||||
DeepSeekCompletionURL = "https://chat.deepseek.com/api/v0/chat/completion"
|
||||
)
|
||||
|
||||
var BaseHeaders = map[string]string{
|
||||
var defaultBaseHeaders = map[string]string{
|
||||
"Host": "chat.deepseek.com",
|
||||
"User-Agent": "DeepSeek/1.6.11 Android/35",
|
||||
"Accept": "application/json",
|
||||
@@ -19,6 +24,75 @@ var BaseHeaders = map[string]string{
|
||||
"accept-charset": "UTF-8",
|
||||
}
|
||||
|
||||
var defaultSkipContainsPatterns = []string{
|
||||
"quasi_status",
|
||||
"elapsed_secs",
|
||||
"token_usage",
|
||||
"pending_fragment",
|
||||
"conversation_mode",
|
||||
"fragments/-1/status",
|
||||
"fragments/-2/status",
|
||||
"fragments/-3/status",
|
||||
}
|
||||
|
||||
var defaultSkipExactPaths = []string{
|
||||
"response/search_status",
|
||||
}
|
||||
|
||||
var BaseHeaders = cloneStringMap(defaultBaseHeaders)
|
||||
var SkipContainsPatterns = cloneStringSlice(defaultSkipContainsPatterns)
|
||||
var SkipExactPathSet = toStringSet(defaultSkipExactPaths)
|
||||
|
||||
type sharedConstants struct {
|
||||
BaseHeaders map[string]string `json:"base_headers"`
|
||||
SkipContainsPattern []string `json:"skip_contains_patterns"`
|
||||
SkipExactPaths []string `json:"skip_exact_paths"`
|
||||
}
|
||||
|
||||
//go:embed constants_shared.json
|
||||
var sharedConstantsJSON []byte
|
||||
|
||||
// init overlays the embedded constants_shared.json onto the compiled-in
// defaults so the Go and Node implementations share one source of truth.
// On decode failure, or for any section left empty in the JSON, the
// compiled-in defaults are kept.
func init() {
	cfg := sharedConstants{}
	if err := json.Unmarshal(sharedConstantsJSON, &cfg); err != nil {
		return // fall back to compiled-in defaults silently
	}
	if len(cfg.BaseHeaders) > 0 {
		BaseHeaders = cloneStringMap(cfg.BaseHeaders)
	}
	if len(cfg.SkipContainsPattern) > 0 {
		SkipContainsPatterns = cloneStringSlice(cfg.SkipContainsPattern)
	}
	if len(cfg.SkipExactPaths) > 0 {
		SkipExactPathSet = toStringSet(cfg.SkipExactPaths)
	}
}
|
||||
|
||||
// cloneStringMap returns a shallow copy of in (never nil, even for nil input).
func cloneStringMap(in map[string]string) map[string]string {
	dup := make(map[string]string, len(in))
	for key, value := range in {
		dup[key] = value
	}
	return dup
}
|
||||
|
||||
// cloneStringSlice returns an independent copy of in (non-nil even for nil input).
func cloneStringSlice(in []string) []string {
	dup := make([]string, len(in))
	copy(dup, in)
	return dup
}
|
||||
|
||||
// toStringSet converts in to a membership set, dropping empty strings.
func toStringSet(in []string) map[string]struct{} {
	set := make(map[string]struct{}, len(in))
	for _, item := range in {
		if item != "" {
			set[item] = struct{}{}
		}
	}
	return set
}
|
||||
|
||||
const (
|
||||
KeepAliveTimeout = 5
|
||||
StreamIdleTimeout = 30
|
||||
|
||||
25
internal/deepseek/constants_shared.json
Normal file
25
internal/deepseek/constants_shared.json
Normal file
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"base_headers": {
|
||||
"Host": "chat.deepseek.com",
|
||||
"User-Agent": "DeepSeek/1.6.11 Android/35",
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"x-client-platform": "android",
|
||||
"x-client-version": "1.6.11",
|
||||
"x-client-locale": "zh_CN",
|
||||
"accept-charset": "UTF-8"
|
||||
},
|
||||
"skip_contains_patterns": [
|
||||
"quasi_status",
|
||||
"elapsed_secs",
|
||||
"token_usage",
|
||||
"pending_fragment",
|
||||
"conversation_mode",
|
||||
"fragments/-1/status",
|
||||
"fragments/-2/status",
|
||||
"fragments/-3/status"
|
||||
],
|
||||
"skip_exact_paths": [
|
||||
"response/search_status"
|
||||
]
|
||||
}
|
||||
15
internal/deepseek/constants_test.go
Normal file
15
internal/deepseek/constants_test.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package deepseek
|
||||
|
||||
import "testing"
|
||||
|
||||
// TestSharedConstantsLoaded sanity-checks that the shared constants (headers
// and skip-path filters) are populated after package init.
func TestSharedConstantsLoaded(t *testing.T) {
	if BaseHeaders["x-client-platform"] != "android" {
		t.Fatalf("unexpected base header x-client-platform=%q", BaseHeaders["x-client-platform"])
	}
	if len(SkipContainsPatterns) == 0 {
		t.Fatal("expected skip contains patterns to be loaded")
	}
	if _, ok := SkipExactPathSet["response/search_status"]; !ok {
		t.Fatal("expected response/search_status in exact skip path set")
	}
}
|
||||
7
internal/deepseek/prompt.go
Normal file
7
internal/deepseek/prompt.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package deepseek
|
||||
|
||||
import "ds2api/internal/prompt"
|
||||
|
||||
// MessagesPrepare delegates to prompt.MessagesPrepare, keeping the function
// available under the deepseek package for existing callers.
func MessagesPrepare(messages []map[string]any) string {
	return prompt.MessagesPrepare(messages)
}
|
||||
46
internal/format/claude/render.go
Normal file
46
internal/format/claude/render.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// BuildMessageResponse assembles a Claude Messages API response from the
// final thinking/text output. Tool calls detected in the text become
// tool_use blocks with stop_reason "tool_use"; otherwise a single text block
// is emitted (with a Chinese fallback message when the text is empty).
func BuildMessageResponse(messageID, model string, normalizedMessages []any, finalThinking, finalText string, toolNames []string) map[string]any {
	calls := util.ParseToolCalls(finalText, toolNames)
	blocks := make([]map[string]any, 0, 4)
	if finalThinking != "" {
		blocks = append(blocks, map[string]any{"type": "thinking", "thinking": finalThinking})
	}
	stopReason := "end_turn"
	switch {
	case len(calls) > 0:
		stopReason = "tool_use"
		for idx, call := range calls {
			blocks = append(blocks, map[string]any{
				"type":  "tool_use",
				"id":    fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx),
				"name":  call.Name,
				"input": call.Input,
			})
		}
	default:
		text := finalText
		if text == "" {
			text = "抱歉,没有生成有效的响应内容。"
		}
		blocks = append(blocks, map[string]any{"type": "text", "text": text})
	}
	return map[string]any{
		"id":            messageID,
		"type":          "message",
		"role":          "assistant",
		"model":         model,
		"content":       blocks,
		"stop_reason":   stopReason,
		"stop_sequence": nil,
		"usage": map[string]any{
			// Input size is estimated from the stringified message list.
			"input_tokens":  util.EstimateTokens(fmt.Sprintf("%v", normalizedMessages)),
			"output_tokens": util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText),
		},
	}
}
|
||||
193
internal/format/openai/render.go
Normal file
193
internal/format/openai/render.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// BuildChatCompletion assembles a non-streaming OpenAI chat.completion body.
// Tool calls detected in finalText switch finish_reason to "tool_calls" and
// null out content; reasoning text is exposed as reasoning_content and
// counted inside completion_tokens (broken out in details).
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
	calls := util.ParseToolCalls(finalText, toolNames)
	message := map[string]any{"role": "assistant", "content": finalText}
	if strings.TrimSpace(finalThinking) != "" {
		message["reasoning_content"] = finalThinking
	}
	finish := "stop"
	if len(calls) > 0 {
		finish = "tool_calls"
		message["tool_calls"] = util.FormatOpenAIToolCalls(calls)
		message["content"] = nil
	}
	promptTok := util.EstimateTokens(finalPrompt)
	reasonTok := util.EstimateTokens(finalThinking)
	textTok := util.EstimateTokens(finalText)

	return map[string]any{
		"id":      completionID,
		"object":  "chat.completion",
		"created": time.Now().Unix(),
		"model":   model,
		"choices": []map[string]any{{"index": 0, "message": message, "finish_reason": finish}},
		"usage": map[string]any{
			"prompt_tokens":     promptTok,
			"completion_tokens": reasonTok + textTok,
			"total_tokens":      promptTok + reasonTok + textTok,
			"completion_tokens_details": map[string]any{
				"reasoning_tokens": reasonTok,
			},
		},
	}
}
|
||||
|
||||
// BuildResponseObject assembles a completed OpenAI Responses-API object.
// When tool calls are detected in finalText, the output carries a single
// tool_calls item and output_text is blanked; otherwise a message item is
// emitted with an optional leading reasoning part followed by output_text.
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
	detected := util.ParseToolCalls(finalText, toolNames)
	exposedOutputText := finalText
	output := make([]any, 0, 2)
	if len(detected) > 0 {
		// Tool-call responses do not expose raw text output.
		exposedOutputText = ""
		toolCalls := make([]any, 0, len(detected))
		for _, tc := range detected {
			toolCalls = append(toolCalls, map[string]any{
				"type":      "tool_call",
				"name":      tc.Name,
				"arguments": tc.Input,
			})
		}
		output = append(output, map[string]any{
			"type":       "tool_calls",
			"tool_calls": toolCalls,
		})
	} else {
		content := []any{
			map[string]any{
				"type": "output_text",
				"text": finalText,
			},
		}
		if finalThinking != "" {
			// Reasoning is prepended so it precedes the text part.
			content = append([]any{map[string]any{
				"type": "reasoning",
				"text": finalThinking,
			}}, content...)
		}
		output = append(output, map[string]any{
			"type":    "message",
			"id":      "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
			"role":    "assistant",
			"content": content,
		})
	}
	promptTokens := util.EstimateTokens(finalPrompt)
	reasoningTokens := util.EstimateTokens(finalThinking)
	completionTokens := util.EstimateTokens(finalText)
	return map[string]any{
		"id":          responseID,
		"type":        "response",
		"object":      "response",
		"created_at":  time.Now().Unix(),
		"status":      "completed",
		"model":       model,
		"output":      output,
		"output_text": exposedOutputText,
		"usage": map[string]any{
			"input_tokens":  promptTokens,
			"output_tokens": reasoningTokens + completionTokens,
			"total_tokens":  promptTokens + reasoningTokens + completionTokens,
		},
	}
}
||||
|
||||
// BuildChatStreamDeltaChoice wraps a streaming delta into a choice entry.
func BuildChatStreamDeltaChoice(index int, delta map[string]any) map[string]any {
	choice := map[string]any{"index": index}
	choice["delta"] = delta
	return choice
}
|
||||
|
||||
// BuildChatStreamFinishChoice builds the terminal streaming choice: an empty
// delta plus the finish_reason.
func BuildChatStreamFinishChoice(index int, finishReason string) map[string]any {
	choice := map[string]any{"index": index, "finish_reason": finishReason}
	choice["delta"] = map[string]any{}
	return choice
}
|
||||
|
||||
// BuildChatStreamChunk wraps choices (and an optional non-empty usage map)
// into a single chat.completion.chunk payload.
func BuildChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any {
	chunk := map[string]any{
		"id":      completionID,
		"object":  "chat.completion.chunk",
		"created": created,
		"model":   model,
		"choices": choices,
	}
	if len(usage) > 0 {
		chunk["usage"] = usage
	}
	return chunk
}
|
||||
|
||||
func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any {
|
||||
promptTokens := util.EstimateTokens(finalPrompt)
|
||||
reasoningTokens := util.EstimateTokens(finalThinking)
|
||||
completionTokens := util.EstimateTokens(finalText)
|
||||
return map[string]any{
|
||||
"prompt_tokens": promptTokens,
|
||||
"completion_tokens": reasoningTokens + completionTokens,
|
||||
"total_tokens": promptTokens + reasoningTokens + completionTokens,
|
||||
"completion_tokens_details": map[string]any{
|
||||
"reasoning_tokens": reasoningTokens,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// BuildResponsesCreatedPayload returns the response.created SSE event body.
func BuildResponsesCreatedPayload(responseID, model string) map[string]any {
	payload := map[string]any{"type": "response.created", "status": "in_progress"}
	payload["id"] = responseID
	payload["object"] = "response"
	payload["model"] = model
	return payload
}
|
||||
|
||||
// BuildResponsesTextDeltaPayload returns an output_text delta event body.
func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any {
	payload := map[string]any{"type": "response.output_text.delta"}
	payload["id"] = responseID
	payload["delta"] = delta
	return payload
}
|
||||
|
||||
// BuildResponsesReasoningDeltaPayload returns a reasoning delta event body.
func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
	payload := map[string]any{"type": "response.reasoning.delta"}
	payload["id"] = responseID
	payload["delta"] = delta
	return payload
}
|
||||
|
||||
// BuildResponsesToolCallDeltaPayload returns a tool-call delta event body.
func BuildResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any {
	payload := map[string]any{"type": "response.output_tool_call.delta"}
	payload["id"] = responseID
	payload["tool_calls"] = toolCalls
	return payload
}
|
||||
|
||||
// BuildResponsesToolCallDonePayload returns a tool-call completion event body.
func BuildResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any {
	payload := map[string]any{"type": "response.output_tool_call.done"}
	payload["id"] = responseID
	payload["tool_calls"] = toolCalls
	return payload
}
|
||||
|
||||
// BuildResponsesCompletedPayload wraps a finished response object in the
// response.completed event body.
func BuildResponsesCompletedPayload(response map[string]any) map[string]any {
	payload := map[string]any{"type": "response.completed"}
	payload["response"] = response
	return payload
}
|
||||
84
internal/prompt/messages.go
Normal file
84
internal/prompt/messages.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var markdownImagePattern = regexp.MustCompile(`!\[(.*?)\]\((.*?)\)`)
|
||||
|
||||
func MessagesPrepare(messages []map[string]any) string {
|
||||
type block struct {
|
||||
Role string
|
||||
Text string
|
||||
}
|
||||
processed := make([]block, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
role, _ := m["role"].(string)
|
||||
text := NormalizeContent(m["content"])
|
||||
processed = append(processed, block{Role: role, Text: text})
|
||||
}
|
||||
if len(processed) == 0 {
|
||||
return ""
|
||||
}
|
||||
merged := make([]block, 0, len(processed))
|
||||
for _, msg := range processed {
|
||||
if len(merged) > 0 && merged[len(merged)-1].Role == msg.Role {
|
||||
merged[len(merged)-1].Text += "\n\n" + msg.Text
|
||||
continue
|
||||
}
|
||||
merged = append(merged, msg)
|
||||
}
|
||||
parts := make([]string, 0, len(merged))
|
||||
for i, m := range merged {
|
||||
switch m.Role {
|
||||
case "assistant":
|
||||
parts = append(parts, "<|Assistant|>"+m.Text+"<|end▁of▁sentence|>")
|
||||
case "user", "system":
|
||||
if i > 0 {
|
||||
parts = append(parts, "<|User|>"+m.Text)
|
||||
} else {
|
||||
parts = append(parts, m.Text)
|
||||
}
|
||||
default:
|
||||
parts = append(parts, m.Text)
|
||||
}
|
||||
}
|
||||
out := strings.Join(parts, "")
|
||||
return markdownImagePattern.ReplaceAllString(out, `[${1}](${2})`)
|
||||
}
|
||||
|
||||
func NormalizeContent(v any) string {
|
||||
switch x := v.(type) {
|
||||
case string:
|
||||
return x
|
||||
case []any:
|
||||
parts := make([]string, 0, len(x))
|
||||
for _, item := range x {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
typeStr, _ := m["type"].(string)
|
||||
typeStr = strings.ToLower(strings.TrimSpace(typeStr))
|
||||
if typeStr == "text" || typeStr == "output_text" || typeStr == "input_text" {
|
||||
if txt, ok := m["text"].(string); ok {
|
||||
parts = append(parts, txt)
|
||||
continue
|
||||
}
|
||||
if txt, ok := m["content"].(string); ok {
|
||||
parts = append(parts, txt)
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
default:
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/deepseek"
|
||||
)
|
||||
|
||||
type ContentPart struct {
|
||||
@@ -11,11 +13,6 @@ type ContentPart struct {
|
||||
Type string
|
||||
}
|
||||
|
||||
var skipPatterns = []string{
|
||||
"quasi_status", "elapsed_secs", "token_usage", "pending_fragment", "conversation_mode",
|
||||
"fragments/-1/status", "fragments/-2/status", "fragments/-3/status",
|
||||
}
|
||||
|
||||
func ParseDeepSeekSSELine(raw []byte) (map[string]any, bool, bool) {
|
||||
line := strings.TrimSpace(string(raw))
|
||||
if line == "" || !strings.HasPrefix(line, "data:") {
|
||||
@@ -33,10 +30,10 @@ func ParseDeepSeekSSELine(raw []byte) (map[string]any, bool, bool) {
|
||||
}
|
||||
|
||||
func shouldSkipPath(path string) bool {
|
||||
if path == "response/search_status" {
|
||||
if _, ok := deepseek.SkipExactPathSet[path]; ok {
|
||||
return true
|
||||
}
|
||||
for _, p := range skipPatterns {
|
||||
for _, p := range deepseek.SkipContainsPatterns {
|
||||
if strings.Contains(path, p) {
|
||||
return true
|
||||
}
|
||||
@@ -60,126 +57,159 @@ func ParseSSEChunkForContent(chunk map[string]any, thinkingEnabled bool, current
|
||||
}
|
||||
newType := currentFragmentType
|
||||
parts := make([]ContentPart, 0, 8)
|
||||
collectDirectFragments(path, chunk, v, &newType, &parts)
|
||||
updateTypeFromNestedResponse(path, v, &newType)
|
||||
partType := resolvePartType(path, thinkingEnabled, newType)
|
||||
finished := appendChunkValueContent(v, partType, &newType, &parts, path)
|
||||
if finished {
|
||||
return nil, true, newType
|
||||
}
|
||||
return parts, false, newType
|
||||
}
|
||||
|
||||
// Newer DeepSeek responses may emit fragment APPEND directly on
|
||||
// path "response/fragments" instead of wrapping it in path "response".
|
||||
if path == "response/fragments" {
|
||||
if op, _ := chunk["o"].(string); strings.EqualFold(op, "APPEND") {
|
||||
if frags, ok := v.([]any); ok {
|
||||
for _, frag := range frags {
|
||||
fm, ok := frag.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
t, _ := fm["type"].(string)
|
||||
content, _ := fm["content"].(string)
|
||||
t = strings.ToUpper(t)
|
||||
switch t {
|
||||
case "THINK", "THINKING":
|
||||
newType = "thinking"
|
||||
if content != "" {
|
||||
parts = append(parts, ContentPart{Text: content, Type: "thinking"})
|
||||
}
|
||||
case "RESPONSE":
|
||||
newType = "text"
|
||||
if content != "" {
|
||||
parts = append(parts, ContentPart{Text: content, Type: "text"})
|
||||
}
|
||||
default:
|
||||
if content != "" {
|
||||
parts = append(parts, ContentPart{Text: content, Type: "text"})
|
||||
}
|
||||
}
|
||||
}
|
||||
func collectDirectFragments(path string, chunk map[string]any, v any, newType *string, parts *[]ContentPart) {
|
||||
if path != "response/fragments" {
|
||||
return
|
||||
}
|
||||
op, _ := chunk["o"].(string)
|
||||
if !strings.EqualFold(op, "APPEND") {
|
||||
return
|
||||
}
|
||||
frags, ok := v.([]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, frag := range frags {
|
||||
m, ok := frag.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
typeName, content, fragType := parseFragmentTypeContent(m)
|
||||
if typeName == "" {
|
||||
typeName = fragType
|
||||
}
|
||||
switch typeName {
|
||||
case "THINK", "THINKING":
|
||||
*newType = "thinking"
|
||||
appendContentPart(parts, content, "thinking")
|
||||
case "RESPONSE":
|
||||
*newType = "text"
|
||||
appendContentPart(parts, content, "text")
|
||||
default:
|
||||
appendContentPart(parts, content, "text")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func updateTypeFromNestedResponse(path string, v any, newType *string) {
|
||||
if path != "response" {
|
||||
return
|
||||
}
|
||||
arr, ok := v.([]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, it := range arr {
|
||||
m, ok := it.(map[string]any)
|
||||
if !ok || m["p"] != "fragments" || m["o"] != "APPEND" {
|
||||
continue
|
||||
}
|
||||
frags, ok := m["v"].([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, frag := range frags {
|
||||
fm, ok := frag.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
typeName, _, _ := parseFragmentTypeContent(fm)
|
||||
switch typeName {
|
||||
case "THINK", "THINKING":
|
||||
*newType = "thinking"
|
||||
case "RESPONSE":
|
||||
*newType = "text"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if path == "response" {
|
||||
if arr, ok := v.([]any); ok {
|
||||
for _, it := range arr {
|
||||
m, ok := it.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if m["p"] == "fragments" && m["o"] == "APPEND" {
|
||||
if frags, ok := m["v"].([]any); ok {
|
||||
for _, frag := range frags {
|
||||
fm, ok := frag.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
t, _ := fm["type"].(string)
|
||||
t = strings.ToUpper(t)
|
||||
if t == "THINK" || t == "THINKING" {
|
||||
newType = "thinking"
|
||||
} else if t == "RESPONSE" {
|
||||
newType = "text"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
partType := "text"
|
||||
func resolvePartType(path string, thinkingEnabled bool, newType string) string {
|
||||
switch {
|
||||
case path == "response/thinking_content":
|
||||
partType = "thinking"
|
||||
return "thinking"
|
||||
case path == "response/content":
|
||||
partType = "text"
|
||||
return "text"
|
||||
case strings.Contains(path, "response/fragments") && strings.Contains(path, "/content"):
|
||||
partType = newType
|
||||
case path == "":
|
||||
if thinkingEnabled {
|
||||
partType = newType
|
||||
}
|
||||
return newType
|
||||
case path == "" && thinkingEnabled:
|
||||
return newType
|
||||
default:
|
||||
return "text"
|
||||
}
|
||||
}
|
||||
|
||||
func appendChunkValueContent(v any, partType string, newType *string, parts *[]ContentPart, path string) bool {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
if val == "FINISHED" && (path == "" || path == "status") {
|
||||
return nil, true, newType
|
||||
}
|
||||
if val != "" {
|
||||
parts = append(parts, ContentPart{Text: val, Type: partType})
|
||||
return true
|
||||
}
|
||||
appendContentPart(parts, val, partType)
|
||||
case []any:
|
||||
pp, finished := extractContentRecursive(val, partType)
|
||||
if finished {
|
||||
return nil, true, newType
|
||||
return true
|
||||
}
|
||||
parts = append(parts, pp...)
|
||||
*parts = append(*parts, pp...)
|
||||
case map[string]any:
|
||||
resp := val
|
||||
if wrapped, ok := val["response"].(map[string]any); ok {
|
||||
resp = wrapped
|
||||
appendWrappedFragments(val, partType, newType, parts)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func appendWrappedFragments(val map[string]any, partType string, newType *string, parts *[]ContentPart) {
|
||||
resp := val
|
||||
if wrapped, ok := val["response"].(map[string]any); ok {
|
||||
resp = wrapped
|
||||
}
|
||||
frags, ok := resp["fragments"].([]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, item := range frags {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if frags, ok := resp["fragments"].([]any); ok {
|
||||
for _, item := range frags {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
t, _ := m["type"].(string)
|
||||
content, _ := m["content"].(string)
|
||||
t = strings.ToUpper(t)
|
||||
if t == "THINK" || t == "THINKING" {
|
||||
newType = "thinking"
|
||||
if content != "" {
|
||||
parts = append(parts, ContentPart{Text: content, Type: "thinking"})
|
||||
}
|
||||
} else if t == "RESPONSE" {
|
||||
newType = "text"
|
||||
if content != "" {
|
||||
parts = append(parts, ContentPart{Text: content, Type: "text"})
|
||||
}
|
||||
} else if content != "" {
|
||||
parts = append(parts, ContentPart{Text: content, Type: partType})
|
||||
}
|
||||
}
|
||||
typeName, content, fragType := parseFragmentTypeContent(m)
|
||||
if typeName == "" {
|
||||
typeName = fragType
|
||||
}
|
||||
switch typeName {
|
||||
case "THINK", "THINKING":
|
||||
*newType = "thinking"
|
||||
appendContentPart(parts, content, "thinking")
|
||||
case "RESPONSE":
|
||||
*newType = "text"
|
||||
appendContentPart(parts, content, "text")
|
||||
default:
|
||||
appendContentPart(parts, content, partType)
|
||||
}
|
||||
}
|
||||
return parts, false, newType
|
||||
}
|
||||
|
||||
func parseFragmentTypeContent(m map[string]any) (string, string, string) {
|
||||
typeName, _ := m["type"].(string)
|
||||
content, _ := m["content"].(string)
|
||||
return strings.ToUpper(typeName), content, strings.ToUpper(typeName)
|
||||
}
|
||||
|
||||
func appendContentPart(parts *[]ContentPart, content, kind string) {
|
||||
if content == "" {
|
||||
return
|
||||
}
|
||||
*parts = append(*parts, ContentPart{Text: content, Type: kind})
|
||||
}
|
||||
|
||||
func extractContentRecursive(items []any, defaultType string) ([]ContentPart, bool) {
|
||||
|
||||
128
internal/stream/engine.go
Normal file
128
internal/stream/engine.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/sse"
|
||||
)
|
||||
|
||||
type StopReason string
|
||||
|
||||
const (
|
||||
StopReasonNone StopReason = ""
|
||||
StopReasonContextCancelled StopReason = "context_cancelled"
|
||||
StopReasonNoContentTimeout StopReason = "no_content_timeout"
|
||||
StopReasonIdleTimeout StopReason = "idle_timeout"
|
||||
StopReasonUpstreamCompleted StopReason = "upstream_completed"
|
||||
StopReasonHandlerRequested StopReason = "handler_requested"
|
||||
)
|
||||
|
||||
type ConsumeConfig struct {
|
||||
Context context.Context
|
||||
Body io.Reader
|
||||
ThinkingEnabled bool
|
||||
InitialType string
|
||||
KeepAliveInterval time.Duration
|
||||
IdleTimeout time.Duration
|
||||
MaxKeepAliveNoInput int
|
||||
}
|
||||
|
||||
type ParsedDecision struct {
|
||||
Stop bool
|
||||
StopReason StopReason
|
||||
ContentSeen bool
|
||||
}
|
||||
|
||||
type ConsumeHooks struct {
|
||||
OnParsed func(parsed sse.LineResult) ParsedDecision
|
||||
OnKeepAlive func()
|
||||
OnFinalize func(reason StopReason, scannerErr error)
|
||||
OnContextDone func()
|
||||
}
|
||||
|
||||
func ConsumeSSE(cfg ConsumeConfig, hooks ConsumeHooks) {
|
||||
if cfg.Context == nil {
|
||||
cfg.Context = context.Background()
|
||||
}
|
||||
initialType := cfg.InitialType
|
||||
if initialType == "" {
|
||||
if cfg.ThinkingEnabled {
|
||||
initialType = "thinking"
|
||||
} else {
|
||||
initialType = "text"
|
||||
}
|
||||
}
|
||||
parsedLines, done := sse.StartParsedLinePump(cfg.Context, cfg.Body, cfg.ThinkingEnabled, initialType)
|
||||
|
||||
var ticker *time.Ticker
|
||||
if cfg.KeepAliveInterval > 0 {
|
||||
ticker = time.NewTicker(cfg.KeepAliveInterval)
|
||||
defer ticker.Stop()
|
||||
}
|
||||
|
||||
hasContent := false
|
||||
lastContent := time.Now()
|
||||
keepaliveCount := 0
|
||||
|
||||
finalize := func(reason StopReason, scannerErr error) {
|
||||
if hooks.OnFinalize != nil {
|
||||
hooks.OnFinalize(reason, scannerErr)
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-cfg.Context.Done():
|
||||
if hooks.OnContextDone != nil {
|
||||
hooks.OnContextDone()
|
||||
}
|
||||
return
|
||||
case <-tickCh(ticker):
|
||||
if !hasContent {
|
||||
keepaliveCount++
|
||||
if cfg.MaxKeepAliveNoInput > 0 && keepaliveCount >= cfg.MaxKeepAliveNoInput {
|
||||
finalize(StopReasonNoContentTimeout, nil)
|
||||
return
|
||||
}
|
||||
}
|
||||
if hasContent && cfg.IdleTimeout > 0 && time.Since(lastContent) > cfg.IdleTimeout {
|
||||
finalize(StopReasonIdleTimeout, nil)
|
||||
return
|
||||
}
|
||||
if hooks.OnKeepAlive != nil {
|
||||
hooks.OnKeepAlive()
|
||||
}
|
||||
case parsed, ok := <-parsedLines:
|
||||
if !ok {
|
||||
finalize(StopReasonUpstreamCompleted, <-done)
|
||||
return
|
||||
}
|
||||
if hooks.OnParsed == nil {
|
||||
continue
|
||||
}
|
||||
decision := hooks.OnParsed(parsed)
|
||||
if decision.ContentSeen {
|
||||
hasContent = true
|
||||
lastContent = time.Now()
|
||||
keepaliveCount = 0
|
||||
}
|
||||
if decision.Stop {
|
||||
reason := decision.StopReason
|
||||
if reason == StopReasonNone {
|
||||
reason = StopReasonHandlerRequested
|
||||
}
|
||||
finalize(reason, nil)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func tickCh(ticker *time.Ticker) <-chan time.Time {
|
||||
if ticker == nil {
|
||||
return nil
|
||||
}
|
||||
return ticker.C
|
||||
}
|
||||
@@ -327,7 +327,7 @@ func (r *Runner) runPreflight(ctx context.Context) error {
|
||||
{"go", "test", "./...", "-count=1"},
|
||||
{"node", "--check", "api/chat-stream.js"},
|
||||
{"node", "--check", "api/helpers/stream-tool-sieve.js"},
|
||||
{"node", "--test", "api/helpers/stream-tool-sieve.test.js", "api/chat-stream.test.js"},
|
||||
{"node", "--test", "api/helpers/stream-tool-sieve.test.js", "api/chat-stream.test.js", "api/compat/js_compat_test.js"},
|
||||
{"npm", "run", "build", "--prefix", "webui"},
|
||||
}
|
||||
f, err := os.OpenFile(r.preflightLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
|
||||
|
||||
@@ -1,16 +1,11 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/claudeconv"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/prompt"
|
||||
)
|
||||
|
||||
var markdownImagePattern = regexp.MustCompile(`!\[(.*?)\]\((.*?)\)`)
|
||||
|
||||
const ClaudeDefaultModel = "claude-sonnet-4-5"
|
||||
|
||||
type Message struct {
|
||||
@@ -19,112 +14,15 @@ type Message struct {
|
||||
}
|
||||
|
||||
func MessagesPrepare(messages []map[string]any) string {
|
||||
type block struct {
|
||||
Role string
|
||||
Text string
|
||||
}
|
||||
processed := make([]block, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
role, _ := m["role"].(string)
|
||||
text := normalizeContent(m["content"])
|
||||
processed = append(processed, block{Role: role, Text: text})
|
||||
}
|
||||
if len(processed) == 0 {
|
||||
return ""
|
||||
}
|
||||
merged := make([]block, 0, len(processed))
|
||||
for _, msg := range processed {
|
||||
if len(merged) > 0 && merged[len(merged)-1].Role == msg.Role {
|
||||
merged[len(merged)-1].Text += "\n\n" + msg.Text
|
||||
continue
|
||||
}
|
||||
merged = append(merged, msg)
|
||||
}
|
||||
parts := make([]string, 0, len(merged))
|
||||
for i, m := range merged {
|
||||
switch m.Role {
|
||||
case "assistant":
|
||||
parts = append(parts, "<|Assistant|>"+m.Text+"<|end▁of▁sentence|>")
|
||||
case "user", "system":
|
||||
if i > 0 {
|
||||
parts = append(parts, "<|User|>"+m.Text)
|
||||
} else {
|
||||
parts = append(parts, m.Text)
|
||||
}
|
||||
default:
|
||||
parts = append(parts, m.Text)
|
||||
}
|
||||
}
|
||||
out := strings.Join(parts, "")
|
||||
return markdownImagePattern.ReplaceAllString(out, `[${1}](${2})`)
|
||||
return prompt.MessagesPrepare(messages)
|
||||
}
|
||||
|
||||
func normalizeContent(v any) string {
|
||||
switch x := v.(type) {
|
||||
case string:
|
||||
return x
|
||||
case []any:
|
||||
parts := make([]string, 0, len(x))
|
||||
for _, item := range x {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
typeStr, _ := m["type"].(string)
|
||||
typeStr = strings.ToLower(strings.TrimSpace(typeStr))
|
||||
if typeStr == "text" || typeStr == "output_text" || typeStr == "input_text" {
|
||||
if txt, ok := m["text"].(string); ok {
|
||||
parts = append(parts, txt)
|
||||
continue
|
||||
}
|
||||
if txt, ok := m["content"].(string); ok {
|
||||
parts = append(parts, txt)
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
default:
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
return prompt.NormalizeContent(v)
|
||||
}
|
||||
|
||||
func ConvertClaudeToDeepSeek(claudeReq map[string]any, store *config.Store) map[string]any {
|
||||
messages, _ := claudeReq["messages"].([]any)
|
||||
model, _ := claudeReq["model"].(string)
|
||||
if model == "" {
|
||||
model = ClaudeDefaultModel
|
||||
}
|
||||
mapping := store.ClaudeMapping()
|
||||
dsModel := mapping["fast"]
|
||||
if dsModel == "" {
|
||||
dsModel = "deepseek-chat"
|
||||
}
|
||||
modelLower := strings.ToLower(model)
|
||||
if strings.Contains(modelLower, "opus") || strings.Contains(modelLower, "reasoner") || strings.Contains(modelLower, "slow") {
|
||||
if slow := mapping["slow"]; slow != "" {
|
||||
dsModel = slow
|
||||
}
|
||||
}
|
||||
convertedMessages := make([]any, 0, len(messages)+1)
|
||||
if system, ok := claudeReq["system"].(string); ok && system != "" {
|
||||
convertedMessages = append(convertedMessages, map[string]any{"role": "system", "content": system})
|
||||
}
|
||||
convertedMessages = append(convertedMessages, messages...)
|
||||
|
||||
out := map[string]any{"model": dsModel, "messages": convertedMessages}
|
||||
for _, k := range []string{"temperature", "top_p", "stream"} {
|
||||
if v, ok := claudeReq[k]; ok {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
if stopSeq, ok := claudeReq["stop_sequences"]; ok {
|
||||
out["stop"] = stopSeq
|
||||
}
|
||||
return out
|
||||
return claudeconv.ConvertClaudeToDeepSeek(claudeReq, store, ClaudeDefaultModel)
|
||||
}
|
||||
|
||||
// EstimateTokens provides a rough token count approximation.
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// BuildOpenAIChatCompletion is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildChatCompletion for new code.
|
||||
func BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
detected := ParseToolCalls(finalText, toolNames)
|
||||
finishReason := "stop"
|
||||
@@ -41,6 +43,8 @@ func BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIResponseObject is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildResponseObject for new code.
|
||||
func BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
detected := ParseToolCalls(finalText, toolNames)
|
||||
exposedOutputText := finalText
|
||||
@@ -101,6 +105,8 @@ func BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, fi
|
||||
}
|
||||
}
|
||||
|
||||
// BuildClaudeMessageResponse is kept for backward compatibility.
|
||||
// Prefer internal/format/claude.BuildMessageResponse for new code.
|
||||
func BuildClaudeMessageResponse(messageID, model string, normalizedMessages []any, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
detected := ParseToolCalls(finalText, toolNames)
|
||||
content := make([]map[string]any, 0, 4)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package util
|
||||
|
||||
// BuildOpenAIChatStreamDeltaChoice is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildChatStreamDeltaChoice for new code.
|
||||
func BuildOpenAIChatStreamDeltaChoice(index int, delta map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"delta": delta,
|
||||
@@ -7,6 +9,8 @@ func BuildOpenAIChatStreamDeltaChoice(index int, delta map[string]any) map[strin
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIChatStreamFinishChoice is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildChatStreamFinishChoice for new code.
|
||||
func BuildOpenAIChatStreamFinishChoice(index int, finishReason string) map[string]any {
|
||||
return map[string]any{
|
||||
"delta": map[string]any{},
|
||||
@@ -15,6 +19,8 @@ func BuildOpenAIChatStreamFinishChoice(index int, finishReason string) map[strin
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIChatStreamChunk is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildChatStreamChunk for new code.
|
||||
func BuildOpenAIChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any {
|
||||
out := map[string]any{
|
||||
"id": completionID,
|
||||
@@ -29,6 +35,8 @@ func BuildOpenAIChatStreamChunk(completionID string, created int64, model string
|
||||
return out
|
||||
}
|
||||
|
||||
// BuildOpenAIChatUsage is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildChatUsage for new code.
|
||||
func BuildOpenAIChatUsage(finalPrompt, finalThinking, finalText string) map[string]any {
|
||||
promptTokens := EstimateTokens(finalPrompt)
|
||||
reasoningTokens := EstimateTokens(finalThinking)
|
||||
@@ -43,6 +51,8 @@ func BuildOpenAIChatUsage(finalPrompt, finalThinking, finalText string) map[stri
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIResponsesCreatedPayload is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildResponsesCreatedPayload for new code.
|
||||
func BuildOpenAIResponsesCreatedPayload(responseID, model string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.created",
|
||||
@@ -53,6 +63,8 @@ func BuildOpenAIResponsesCreatedPayload(responseID, model string) map[string]any
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIResponsesTextDeltaPayload is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildResponsesTextDeltaPayload for new code.
|
||||
func BuildOpenAIResponsesTextDeltaPayload(responseID, delta string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_text.delta",
|
||||
@@ -61,6 +73,8 @@ func BuildOpenAIResponsesTextDeltaPayload(responseID, delta string) map[string]a
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIResponsesReasoningDeltaPayload is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildResponsesReasoningDeltaPayload for new code.
|
||||
func BuildOpenAIResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.reasoning.delta",
|
||||
@@ -69,6 +83,8 @@ func BuildOpenAIResponsesReasoningDeltaPayload(responseID, delta string) map[str
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIResponsesToolCallDeltaPayload is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildResponsesToolCallDeltaPayload for new code.
|
||||
func BuildOpenAIResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_tool_call.delta",
|
||||
@@ -77,6 +93,8 @@ func BuildOpenAIResponsesToolCallDeltaPayload(responseID string, toolCalls []map
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIResponsesToolCallDonePayload is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildResponsesToolCallDonePayload for new code.
|
||||
func BuildOpenAIResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_tool_call.done",
|
||||
@@ -85,6 +103,8 @@ func BuildOpenAIResponsesToolCallDonePayload(responseID string, toolCalls []map[
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIResponsesCompletedPayload is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildResponsesCompletedPayload for new code.
|
||||
func BuildOpenAIResponsesCompletedPayload(response map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.completed",
|
||||
|
||||
8
tests/compat/expected/sse_fragments_append.json
Normal file
8
tests/compat/expected/sse_fragments_append.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"parts": [
|
||||
{"text": "思考中", "type": "thinking"},
|
||||
{"text": "结论", "type": "text"}
|
||||
],
|
||||
"finished": false,
|
||||
"new_type": "text"
|
||||
}
|
||||
5
tests/compat/expected/sse_nested_finished.json
Normal file
5
tests/compat/expected/sse_nested_finished.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"parts": [],
|
||||
"finished": true,
|
||||
"new_type": "text"
|
||||
}
|
||||
8
tests/compat/expected/sse_split_tool_json.json
Normal file
8
tests/compat/expected/sse_split_tool_json.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"parts": [
|
||||
{"text": "{\"", "type": "text"},
|
||||
{"text": "tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}", "type": "text"}
|
||||
],
|
||||
"finished": false,
|
||||
"new_type": "text"
|
||||
}
|
||||
7
tests/compat/expected/token_cases.json
Normal file
7
tests/compat/expected/token_cases.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"cases": [
|
||||
{"name": "ascii_short", "tokens": 1},
|
||||
{"name": "cjk", "tokens": 3},
|
||||
{"name": "mixed", "tokens": 4}
|
||||
]
|
||||
}
|
||||
3
tests/compat/expected/toolcalls_fenced_json.json
Normal file
3
tests/compat/expected/toolcalls_fenced_json.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"calls": []
|
||||
}
|
||||
5
tests/compat/expected/toolcalls_unknown_name.json
Normal file
5
tests/compat/expected/toolcalls_unknown_name.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"calls": [
|
||||
{"name": "unknown_tool", "input": {"x": 1}}
|
||||
]
|
||||
}
|
||||
12
tests/compat/fixtures/sse_chunks/fragments_append.json
Normal file
12
tests/compat/fixtures/sse_chunks/fragments_append.json
Normal file
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"chunk": {
|
||||
"p": "response/fragments",
|
||||
"o": "APPEND",
|
||||
"v": [
|
||||
{"type": "THINK", "content": "思考中"},
|
||||
{"type": "RESPONSE", "content": "结论"}
|
||||
]
|
||||
},
|
||||
"thinking_enabled": true,
|
||||
"current_type": "thinking"
|
||||
}
|
||||
10
tests/compat/fixtures/sse_chunks/nested_finished.json
Normal file
10
tests/compat/fixtures/sse_chunks/nested_finished.json
Normal file
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"chunk": {
|
||||
"p": "response",
|
||||
"v": [
|
||||
{"p": "status", "v": "FINISHED"}
|
||||
]
|
||||
},
|
||||
"thinking_enabled": false,
|
||||
"current_type": "text"
|
||||
}
|
||||
11
tests/compat/fixtures/sse_chunks/split_tool_json.json
Normal file
11
tests/compat/fixtures/sse_chunks/split_tool_json.json
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"chunk": {
|
||||
"p": "response",
|
||||
"v": [
|
||||
{"p": "response/content", "v": "{\""},
|
||||
{"p": "response/content", "v": "tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}"}
|
||||
]
|
||||
},
|
||||
"thinking_enabled": false,
|
||||
"current_type": "text"
|
||||
}
|
||||
7
tests/compat/fixtures/token_cases.json
Normal file
7
tests/compat/fixtures/token_cases.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"cases": [
|
||||
{"name": "ascii_short", "text": "abcd"},
|
||||
{"name": "cjk", "text": "你好世界"},
|
||||
{"name": "mixed", "text": "Hello 你好世界"}
|
||||
]
|
||||
}
|
||||
4
tests/compat/fixtures/toolcalls/fenced_json.json
Normal file
4
tests/compat/fixtures/toolcalls/fenced_json.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"text": "```json\n{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}\n```",
|
||||
"tool_names": ["read_file"]
|
||||
}
|
||||
4
tests/compat/fixtures/toolcalls/unknown_name.json
Normal file
4
tests/compat/fixtures/toolcalls/unknown_name.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"text": "{\"tool_calls\":[{\"name\":\"unknown_tool\",\"input\":{\"x\":1}}]}",
|
||||
"tool_names": ["read_file"]
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import {
|
||||
Key,
|
||||
Upload,
|
||||
Cloud,
|
||||
Settings as SettingsIcon,
|
||||
LogOut,
|
||||
Menu,
|
||||
X,
|
||||
@@ -23,12 +24,13 @@ import AccountManager from './components/AccountManager'
|
||||
import ApiTester from './components/ApiTester'
|
||||
import BatchImport from './components/BatchImport'
|
||||
import VercelSync from './components/VercelSync'
|
||||
import Settings from './components/Settings'
|
||||
import Login from './components/Login'
|
||||
import LandingPage from './components/LandingPage'
|
||||
import LanguageToggle from './components/LanguageToggle'
|
||||
import { useI18n } from './i18n'
|
||||
|
||||
function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message }) {
|
||||
function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message, onForceLogout }) {
|
||||
const { t } = useI18n()
|
||||
const [activeTab, setActiveTab] = useState('accounts')
|
||||
const [sidebarOpen, setSidebarOpen] = useState(false)
|
||||
@@ -39,6 +41,7 @@ function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message
|
||||
{ id: 'test', label: t('nav.test.label'), icon: Server, description: t('nav.test.desc') },
|
||||
{ id: 'import', label: t('nav.import.label'), icon: Upload, description: t('nav.import.desc') },
|
||||
{ id: 'vercel', label: t('nav.vercel.label'), icon: Cloud, description: t('nav.vercel.desc') },
|
||||
{ id: 'settings', label: t('nav.settings.label'), icon: SettingsIcon, description: t('nav.settings.desc') },
|
||||
]
|
||||
|
||||
const authFetch = async (url, options = {}) => {
|
||||
@@ -65,6 +68,8 @@ function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message
|
||||
return <BatchImport onRefresh={fetchConfig} onMessage={showMessage} authFetch={authFetch} />
|
||||
case 'vercel':
|
||||
return <VercelSync onMessage={showMessage} authFetch={authFetch} />
|
||||
case 'settings':
|
||||
return <Settings onRefresh={fetchConfig} onMessage={showMessage} authFetch={authFetch} onForceLogout={onForceLogout} />
|
||||
default:
|
||||
return null
|
||||
}
|
||||
@@ -314,6 +319,7 @@ export default function App() {
|
||||
fetchConfig={fetchConfig}
|
||||
showMessage={showMessage}
|
||||
message={message}
|
||||
onForceLogout={handleLogout}
|
||||
/>
|
||||
) : (
|
||||
<div className="min-h-screen flex flex-col bg-background relative overflow-hidden">
|
||||
|
||||
376
webui/src/components/Settings.jsx
Normal file
376
webui/src/components/Settings.jsx
Normal file
@@ -0,0 +1,376 @@
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react'
|
||||
import { AlertTriangle, Download, Lock, Save, Upload } from 'lucide-react'
|
||||
import { useI18n } from '../i18n'
|
||||
|
||||
export default function Settings({ onRefresh, onMessage, authFetch, onForceLogout }) {
|
||||
const { t } = useI18n()
|
||||
const apiFetch = authFetch || fetch
|
||||
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [saving, setSaving] = useState(false)
|
||||
const [changingPassword, setChangingPassword] = useState(false)
|
||||
const [importing, setImporting] = useState(false)
|
||||
const [exportData, setExportData] = useState(null)
|
||||
const [importMode, setImportMode] = useState('merge')
|
||||
const [importText, setImportText] = useState('')
|
||||
const [newPassword, setNewPassword] = useState('')
|
||||
const [settingsMeta, setSettingsMeta] = useState({ default_password_warning: false, env_backed: false, needs_vercel_sync: false })
|
||||
|
||||
const [form, setForm] = useState({
|
||||
admin: { jwt_expire_hours: 24 },
|
||||
runtime: { account_max_inflight: 2, account_max_queue: 10, global_max_inflight: 10 },
|
||||
toolcall: { mode: 'feature_match', early_emit_confidence: 'high' },
|
||||
responses: { store_ttl_seconds: 900 },
|
||||
embeddings: { provider: '' },
|
||||
claude_mapping_text: '{\n "fast": "deepseek-chat",\n "slow": "deepseek-reasoner"\n}',
|
||||
model_aliases_text: '{}',
|
||||
})
|
||||
|
||||
const parseJSONMap = (raw, fieldName) => {
|
||||
const text = String(raw || '').trim()
|
||||
if (!text) {
|
||||
return {}
|
||||
}
|
||||
let parsed
|
||||
try {
|
||||
parsed = JSON.parse(text)
|
||||
} catch (_e) {
|
||||
throw new Error(t('settings.invalidJsonField', { field: fieldName }))
|
||||
}
|
||||
if (!parsed || typeof parsed !== 'object' || Array.isArray(parsed)) {
|
||||
throw new Error(t('settings.invalidJsonField', { field: fieldName }))
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
|
||||
const loadSettings = useCallback(async () => {
|
||||
setLoading(true)
|
||||
try {
|
||||
const res = await apiFetch('/admin/settings')
|
||||
const data = await res.json()
|
||||
if (!res.ok) {
|
||||
onMessage('error', data.detail || t('settings.loadFailed'))
|
||||
return
|
||||
}
|
||||
setSettingsMeta({
|
||||
default_password_warning: Boolean(data.admin?.default_password_warning),
|
||||
env_backed: Boolean(data.env_backed),
|
||||
needs_vercel_sync: Boolean(data.needs_vercel_sync),
|
||||
})
|
||||
setForm({
|
||||
admin: { jwt_expire_hours: Number(data.admin?.jwt_expire_hours || 24) },
|
||||
runtime: {
|
||||
account_max_inflight: Number(data.runtime?.account_max_inflight || 2),
|
||||
account_max_queue: Number(data.runtime?.account_max_queue || 10),
|
||||
global_max_inflight: Number(data.runtime?.global_max_inflight || 10),
|
||||
},
|
||||
toolcall: {
|
||||
mode: data.toolcall?.mode || 'feature_match',
|
||||
early_emit_confidence: data.toolcall?.early_emit_confidence || 'high',
|
||||
},
|
||||
responses: {
|
||||
store_ttl_seconds: Number(data.responses?.store_ttl_seconds || 900),
|
||||
},
|
||||
embeddings: {
|
||||
provider: data.embeddings?.provider || '',
|
||||
},
|
||||
claude_mapping_text: JSON.stringify(data.claude_mapping || {}, null, 2),
|
||||
model_aliases_text: JSON.stringify(data.model_aliases || {}, null, 2),
|
||||
})
|
||||
} catch (e) {
|
||||
onMessage('error', t('settings.loadFailed'))
|
||||
// eslint-disable-next-line no-console
|
||||
console.error(e)
|
||||
} finally {
|
||||
setLoading(false)
|
||||
}
|
||||
}, [apiFetch, onMessage, t])
|
||||
|
||||
useEffect(() => {
|
||||
loadSettings()
|
||||
}, [loadSettings])
|
||||
|
||||
const saveSettings = async () => {
|
||||
let claudeMapping = {}
|
||||
let modelAliases = {}
|
||||
try {
|
||||
claudeMapping = parseJSONMap(form.claude_mapping_text, 'claude_mapping')
|
||||
modelAliases = parseJSONMap(form.model_aliases_text, 'model_aliases')
|
||||
} catch (e) {
|
||||
onMessage('error', e.message)
|
||||
return
|
||||
}
|
||||
|
||||
const payload = {
|
||||
admin: { jwt_expire_hours: Number(form.admin.jwt_expire_hours) },
|
||||
runtime: {
|
||||
account_max_inflight: Number(form.runtime.account_max_inflight),
|
||||
account_max_queue: Number(form.runtime.account_max_queue),
|
||||
global_max_inflight: Number(form.runtime.global_max_inflight),
|
||||
},
|
||||
toolcall: {
|
||||
mode: String(form.toolcall.mode || '').trim(),
|
||||
early_emit_confidence: String(form.toolcall.early_emit_confidence || '').trim(),
|
||||
},
|
||||
responses: { store_ttl_seconds: Number(form.responses.store_ttl_seconds) },
|
||||
embeddings: { provider: String(form.embeddings.provider || '').trim() },
|
||||
claude_mapping: claudeMapping,
|
||||
model_aliases: modelAliases,
|
||||
}
|
||||
|
||||
setSaving(true)
|
||||
try {
|
||||
const res = await apiFetch('/admin/settings', {
|
||||
method: 'PUT',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(payload),
|
||||
})
|
||||
const data = await res.json()
|
||||
if (!res.ok) {
|
||||
onMessage('error', data.detail || t('settings.saveFailed'))
|
||||
return
|
||||
}
|
||||
onMessage('success', t('settings.saveSuccess'))
|
||||
if (typeof onRefresh === 'function') {
|
||||
onRefresh()
|
||||
}
|
||||
await loadSettings()
|
||||
} catch (e) {
|
||||
onMessage('error', t('settings.saveFailed'))
|
||||
// eslint-disable-next-line no-console
|
||||
console.error(e)
|
||||
} finally {
|
||||
setSaving(false)
|
||||
}
|
||||
}
|
||||
|
||||
const updatePassword = async () => {
|
||||
if (String(newPassword || '').trim().length < 4) {
|
||||
onMessage('error', t('settings.passwordTooShort'))
|
||||
return
|
||||
}
|
||||
setChangingPassword(true)
|
||||
try {
|
||||
const res = await apiFetch('/admin/settings/password', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ new_password: newPassword.trim() }),
|
||||
})
|
||||
const data = await res.json()
|
||||
if (!res.ok) {
|
||||
onMessage('error', data.detail || t('settings.passwordUpdateFailed'))
|
||||
return
|
||||
}
|
||||
onMessage('success', t('settings.passwordUpdated'))
|
||||
setNewPassword('')
|
||||
if (typeof onForceLogout === 'function') {
|
||||
onForceLogout()
|
||||
}
|
||||
} catch (e) {
|
||||
onMessage('error', t('settings.passwordUpdateFailed'))
|
||||
} finally {
|
||||
setChangingPassword(false)
|
||||
}
|
||||
}
|
||||
|
||||
const loadExportData = async () => {
|
||||
try {
|
||||
const res = await apiFetch('/admin/config/export')
|
||||
const data = await res.json()
|
||||
if (!res.ok) {
|
||||
onMessage('error', data.detail || t('settings.exportFailed'))
|
||||
return
|
||||
}
|
||||
setExportData(data)
|
||||
onMessage('success', t('settings.exportLoaded'))
|
||||
} catch (e) {
|
||||
onMessage('error', t('settings.exportFailed'))
|
||||
}
|
||||
}
|
||||
|
||||
const doImport = async () => {
|
||||
if (!String(importText || '').trim()) {
|
||||
onMessage('error', t('settings.importEmpty'))
|
||||
return
|
||||
}
|
||||
let parsed
|
||||
try {
|
||||
parsed = JSON.parse(importText)
|
||||
} catch (_e) {
|
||||
onMessage('error', t('settings.importInvalidJson'))
|
||||
return
|
||||
}
|
||||
setImporting(true)
|
||||
try {
|
||||
const res = await apiFetch(`/admin/config/import?mode=${encodeURIComponent(importMode)}`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ config: parsed, mode: importMode }),
|
||||
})
|
||||
const data = await res.json()
|
||||
if (!res.ok) {
|
||||
onMessage('error', data.detail || t('settings.importFailed'))
|
||||
return
|
||||
}
|
||||
onMessage('success', t('settings.importSuccess', { mode: importMode }))
|
||||
if (typeof onRefresh === 'function') {
|
||||
onRefresh()
|
||||
}
|
||||
await loadSettings()
|
||||
} catch (e) {
|
||||
onMessage('error', t('settings.importFailed'))
|
||||
} finally {
|
||||
setImporting(false)
|
||||
}
|
||||
}
|
||||
|
||||
const syncHintVisible = useMemo(() => settingsMeta.env_backed || settingsMeta.needs_vercel_sync, [settingsMeta.env_backed, settingsMeta.needs_vercel_sync])
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
{settingsMeta.default_password_warning && (
|
||||
<div className="p-4 rounded-lg border border-amber-300/30 bg-amber-500/10 text-amber-700 flex items-center gap-2">
|
||||
<AlertTriangle className="w-4 h-4" />
|
||||
<span className="text-sm">{t('settings.defaultPasswordWarning')}</span>
|
||||
</div>
|
||||
)}
|
||||
{syncHintVisible && (
|
||||
<div className="p-4 rounded-lg border border-amber-300/30 bg-amber-500/10 text-amber-700 flex items-center gap-2">
|
||||
<AlertTriangle className="w-4 h-4" />
|
||||
<span className="text-sm">{t('settings.vercelSyncHint')}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="bg-card border border-border rounded-xl p-5 space-y-4">
|
||||
<h3 className="font-semibold">{t('settings.securityTitle')}</h3>
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
|
||||
<label className="text-sm space-y-2">
|
||||
<span className="text-muted-foreground">{t('settings.jwtExpireHours')}</span>
|
||||
<input
|
||||
type="number"
|
||||
min={1}
|
||||
max={720}
|
||||
value={form.admin.jwt_expire_hours}
|
||||
onChange={(e) => setForm((prev) => ({ ...prev, admin: { ...prev.admin, jwt_expire_hours: Number(e.target.value || 1) } }))}
|
||||
className="w-full bg-background border border-border rounded-lg px-3 py-2"
|
||||
/>
|
||||
</label>
|
||||
<label className="text-sm space-y-2">
|
||||
<span className="text-muted-foreground">{t('settings.newPassword')}</span>
|
||||
<div className="flex gap-2">
|
||||
<input
|
||||
type="password"
|
||||
value={newPassword}
|
||||
onChange={(e) => setNewPassword(e.target.value)}
|
||||
placeholder={t('settings.newPasswordPlaceholder')}
|
||||
className="w-full bg-background border border-border rounded-lg px-3 py-2"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
onClick={updatePassword}
|
||||
disabled={changingPassword}
|
||||
className="px-3 py-2 rounded-lg bg-secondary border border-border hover:bg-secondary/80 text-sm flex items-center gap-1"
|
||||
>
|
||||
<Lock className="w-4 h-4" />
|
||||
{changingPassword ? t('settings.updating') : t('settings.updatePassword')}
|
||||
</button>
|
||||
</div>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="bg-card border border-border rounded-xl p-5 space-y-4">
|
||||
<h3 className="font-semibold">{t('settings.runtimeTitle')}</h3>
|
||||
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
|
||||
<label className="text-sm space-y-2">
|
||||
<span className="text-muted-foreground">{t('settings.accountMaxInflight')}</span>
|
||||
<input type="number" min={1} value={form.runtime.account_max_inflight} onChange={(e) => setForm((prev) => ({ ...prev, runtime: { ...prev.runtime, account_max_inflight: Number(e.target.value || 1) } }))} className="w-full bg-background border border-border rounded-lg px-3 py-2" />
|
||||
</label>
|
||||
<label className="text-sm space-y-2">
|
||||
<span className="text-muted-foreground">{t('settings.accountMaxQueue')}</span>
|
||||
<input type="number" min={1} value={form.runtime.account_max_queue} onChange={(e) => setForm((prev) => ({ ...prev, runtime: { ...prev.runtime, account_max_queue: Number(e.target.value || 1) } }))} className="w-full bg-background border border-border rounded-lg px-3 py-2" />
|
||||
</label>
|
||||
<label className="text-sm space-y-2">
|
||||
<span className="text-muted-foreground">{t('settings.globalMaxInflight')}</span>
|
||||
<input type="number" min={1} value={form.runtime.global_max_inflight} onChange={(e) => setForm((prev) => ({ ...prev, runtime: { ...prev.runtime, global_max_inflight: Number(e.target.value || 1) } }))} className="w-full bg-background border border-border rounded-lg px-3 py-2" />
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="bg-card border border-border rounded-xl p-5 space-y-4">
|
||||
<h3 className="font-semibold">{t('settings.behaviorTitle')}</h3>
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
|
||||
<label className="text-sm space-y-2">
|
||||
<span className="text-muted-foreground">{t('settings.toolcallMode')}</span>
|
||||
<select value={form.toolcall.mode} onChange={(e) => setForm((prev) => ({ ...prev, toolcall: { ...prev.toolcall, mode: e.target.value } }))} className="w-full bg-background border border-border rounded-lg px-3 py-2">
|
||||
<option value="feature_match">feature_match</option>
|
||||
<option value="off">off</option>
|
||||
</select>
|
||||
</label>
|
||||
<label className="text-sm space-y-2">
|
||||
<span className="text-muted-foreground">{t('settings.earlyEmitConfidence')}</span>
|
||||
<select value={form.toolcall.early_emit_confidence} onChange={(e) => setForm((prev) => ({ ...prev, toolcall: { ...prev.toolcall, early_emit_confidence: e.target.value } }))} className="w-full bg-background border border-border rounded-lg px-3 py-2">
|
||||
<option value="high">high</option>
|
||||
<option value="low">low</option>
|
||||
<option value="off">off</option>
|
||||
</select>
|
||||
</label>
|
||||
<label className="text-sm space-y-2">
|
||||
<span className="text-muted-foreground">{t('settings.responsesTTL')}</span>
|
||||
<input type="number" min={30} value={form.responses.store_ttl_seconds} onChange={(e) => setForm((prev) => ({ ...prev, responses: { ...prev.responses, store_ttl_seconds: Number(e.target.value || 30) } }))} className="w-full bg-background border border-border rounded-lg px-3 py-2" />
|
||||
</label>
|
||||
<label className="text-sm space-y-2">
|
||||
<span className="text-muted-foreground">{t('settings.embeddingsProvider')}</span>
|
||||
<input type="text" value={form.embeddings.provider} onChange={(e) => setForm((prev) => ({ ...prev, embeddings: { ...prev.embeddings, provider: e.target.value } }))} className="w-full bg-background border border-border rounded-lg px-3 py-2" />
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="bg-card border border-border rounded-xl p-5 space-y-4">
|
||||
<h3 className="font-semibold">{t('settings.modelTitle')}</h3>
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
|
||||
<label className="text-sm space-y-2">
|
||||
<span className="text-muted-foreground">{t('settings.claudeMapping')}</span>
|
||||
<textarea value={form.claude_mapping_text} onChange={(e) => setForm((prev) => ({ ...prev, claude_mapping_text: e.target.value }))} rows={8} className="w-full bg-background border border-border rounded-lg px-3 py-2 font-mono text-xs" />
|
||||
</label>
|
||||
<label className="text-sm space-y-2">
|
||||
<span className="text-muted-foreground">{t('settings.modelAliases')}</span>
|
||||
<textarea value={form.model_aliases_text} onChange={(e) => setForm((prev) => ({ ...prev, model_aliases_text: e.target.value }))} rows={8} className="w-full bg-background border border-border rounded-lg px-3 py-2 font-mono text-xs" />
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="bg-card border border-border rounded-xl p-5 space-y-4">
|
||||
<h3 className="font-semibold">{t('settings.backupTitle')}</h3>
|
||||
<div className="flex flex-wrap items-center gap-3">
|
||||
<button type="button" onClick={loadExportData} className="px-3 py-2 rounded-lg bg-secondary border border-border hover:bg-secondary/80 text-sm flex items-center gap-2">
|
||||
<Download className="w-4 h-4" />
|
||||
{t('settings.loadExport')}
|
||||
</button>
|
||||
<select value={importMode} onChange={(e) => setImportMode(e.target.value)} className="bg-background border border-border rounded-lg px-3 py-2 text-sm">
|
||||
<option value="merge">{t('settings.importModeMerge')}</option>
|
||||
<option value="replace">{t('settings.importModeReplace')}</option>
|
||||
</select>
|
||||
<button type="button" onClick={doImport} disabled={importing} className="px-3 py-2 rounded-lg bg-secondary border border-border hover:bg-secondary/80 text-sm flex items-center gap-2">
|
||||
<Upload className="w-4 h-4" />
|
||||
{importing ? t('settings.importing') : t('settings.importNow')}
|
||||
</button>
|
||||
</div>
|
||||
<textarea value={importText} onChange={(e) => setImportText(e.target.value)} rows={8} className="w-full bg-background border border-border rounded-lg px-3 py-2 font-mono text-xs" placeholder={t('settings.importPlaceholder')} />
|
||||
{exportData && (
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm text-muted-foreground">{t('settings.exportJson')}</label>
|
||||
<textarea value={exportData.json || ''} readOnly rows={6} className="w-full bg-background border border-border rounded-lg px-3 py-2 font-mono text-xs" />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex justify-end">
|
||||
<button type="button" onClick={saveSettings} disabled={loading || saving} className="px-4 py-2 rounded-lg bg-primary text-primary-foreground hover:bg-primary/90 flex items-center gap-2">
|
||||
<Save className="w-4 h-4" />
|
||||
{saving ? t('settings.saving') : t('settings.save')}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -20,6 +20,10 @@
|
||||
"vercel": {
|
||||
"label": "Vercel Sync",
|
||||
"desc": "Sync configuration to Vercel"
|
||||
},
|
||||
"settings": {
|
||||
"label": "Settings",
|
||||
"desc": "Edit runtime and security settings online"
|
||||
}
|
||||
},
|
||||
"sidebar": {
|
||||
@@ -192,6 +196,51 @@
|
||||
"importComplete": "Import complete",
|
||||
"importSummary": "Imported {keys} API keys and updated {accounts} accounts."
|
||||
},
|
||||
"settings": {
|
||||
"loadFailed": "Failed to load settings.",
|
||||
"save": "Save settings",
|
||||
"saving": "Saving...",
|
||||
"saveSuccess": "Settings saved and hot reloaded.",
|
||||
"saveFailed": "Failed to save settings.",
|
||||
"securityTitle": "Security",
|
||||
"jwtExpireHours": "JWT expiry (hours)",
|
||||
"newPassword": "New admin password",
|
||||
"newPasswordPlaceholder": "Enter new password (min 4 chars)",
|
||||
"updatePassword": "Update password",
|
||||
"updating": "Updating...",
|
||||
"passwordTooShort": "Password must be at least 4 characters.",
|
||||
"passwordUpdated": "Password updated. Please sign in again.",
|
||||
"passwordUpdateFailed": "Failed to update password.",
|
||||
"runtimeTitle": "Concurrency & Queue",
|
||||
"accountMaxInflight": "Per-account max inflight",
|
||||
"accountMaxQueue": "Account max queue size",
|
||||
"globalMaxInflight": "Global max inflight",
|
||||
"behaviorTitle": "Behavior",
|
||||
"toolcallMode": "Toolcall mode",
|
||||
"earlyEmitConfidence": "Early emit confidence",
|
||||
"responsesTTL": "Responses store TTL (seconds)",
|
||||
"embeddingsProvider": "Embeddings provider",
|
||||
"modelTitle": "Model mapping",
|
||||
"claudeMapping": "Claude mapping (JSON)",
|
||||
"modelAliases": "Model aliases (JSON)",
|
||||
"backupTitle": "Backup & Restore",
|
||||
"loadExport": "Load current export",
|
||||
"importModeMerge": "Merge import (default)",
|
||||
"importModeReplace": "Replace all import",
|
||||
"importNow": "Import now",
|
||||
"importing": "Importing...",
|
||||
"importPlaceholder": "Paste config JSON to import",
|
||||
"importEmpty": "Please input import JSON.",
|
||||
"importInvalidJson": "Import JSON is invalid.",
|
||||
"importFailed": "Import failed.",
|
||||
"importSuccess": "Config imported (mode: {mode}).",
|
||||
"exportFailed": "Export failed.",
|
||||
"exportLoaded": "Current export loaded.",
|
||||
"exportJson": "Export JSON",
|
||||
"invalidJsonField": "{field} is not a valid JSON object.",
|
||||
"defaultPasswordWarning": "You are using the default admin password \"admin\". Please change it.",
|
||||
"vercelSyncHint": "Configuration changed. For Vercel deployments, sync manually in Vercel Sync and redeploy."
|
||||
},
|
||||
"login": {
|
||||
"welcome": "Welcome back",
|
||||
"subtitle": "Enter your admin key to continue",
|
||||
|
||||
@@ -20,6 +20,10 @@
|
||||
"vercel": {
|
||||
"label": "Vercel 同步",
|
||||
"desc": "同步配置到 Vercel"
|
||||
},
|
||||
"settings": {
|
||||
"label": "设置中心",
|
||||
"desc": "在线修改系统设置与配置"
|
||||
}
|
||||
},
|
||||
"sidebar": {
|
||||
@@ -192,6 +196,51 @@
|
||||
"importComplete": "导入操作已完成",
|
||||
"importSummary": "成功导入了 {keys} 个 API 密钥,并更新了 {accounts} 个账号。"
|
||||
},
|
||||
"settings": {
|
||||
"loadFailed": "加载设置失败",
|
||||
"save": "保存设置",
|
||||
"saving": "保存中...",
|
||||
"saveSuccess": "设置已保存并热更新生效",
|
||||
"saveFailed": "保存设置失败",
|
||||
"securityTitle": "安全设置",
|
||||
"jwtExpireHours": "JWT 有效期(小时)",
|
||||
"newPassword": "面板新密码",
|
||||
"newPasswordPlaceholder": "输入新密码(至少 4 位)",
|
||||
"updatePassword": "修改密码",
|
||||
"updating": "更新中...",
|
||||
"passwordTooShort": "新密码至少 4 位",
|
||||
"passwordUpdated": "密码已更新,需重新登录",
|
||||
"passwordUpdateFailed": "密码更新失败",
|
||||
"runtimeTitle": "并发与队列",
|
||||
"accountMaxInflight": "每账号并发上限",
|
||||
"accountMaxQueue": "账号等待队列上限",
|
||||
"globalMaxInflight": "全局并发上限",
|
||||
"behaviorTitle": "行为设置",
|
||||
"toolcallMode": "Toolcall 模式",
|
||||
"earlyEmitConfidence": "早发置信度",
|
||||
"responsesTTL": "Responses 缓存 TTL(秒)",
|
||||
"embeddingsProvider": "Embeddings Provider",
|
||||
"modelTitle": "模型映射",
|
||||
"claudeMapping": "Claude 映射(JSON)",
|
||||
"modelAliases": "模型别名(JSON)",
|
||||
"backupTitle": "备份与恢复",
|
||||
"loadExport": "加载当前导出",
|
||||
"importModeMerge": "合并导入(默认)",
|
||||
"importModeReplace": "全量覆盖导入",
|
||||
"importNow": "立即导入",
|
||||
"importing": "导入中...",
|
||||
"importPlaceholder": "粘贴要导入的 JSON 配置",
|
||||
"importEmpty": "请先输入导入 JSON",
|
||||
"importInvalidJson": "导入 JSON 格式无效",
|
||||
"importFailed": "导入失败",
|
||||
"importSuccess": "配置导入成功(模式:{mode})",
|
||||
"exportFailed": "导出失败",
|
||||
"exportLoaded": "已加载当前配置导出",
|
||||
"exportJson": "导出 JSON",
|
||||
"invalidJsonField": "{field} 不是有效 JSON 对象",
|
||||
"defaultPasswordWarning": "当前使用默认密码 admin,请尽快在此修改。",
|
||||
"vercelSyncHint": "当前配置已更新。Vercel 部署请到 Vercel 同步页面手动同步并重部署。"
|
||||
},
|
||||
"login": {
|
||||
"welcome": "欢迎回来",
|
||||
"subtitle": "请输入管理员密钥以继续",
|
||||
|
||||
Reference in New Issue
Block a user