feat: enhance tool call parsing robustness, authentication flexibility, and streaming output for tool content

This commit is contained in:
CJACK
2026-02-16 01:24:52 +08:00
parent bd788a12b1
commit 57f2041edb
14 changed files with 1753 additions and 1519 deletions

View File

@@ -211,6 +211,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
created := time.Now().Unix()
firstChunkSent := false
bufferToolContent := len(toolNames) > 0
currentType := "text"
if thinkingEnabled {
currentType = "thinking"
@@ -240,12 +241,34 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
detected := util.ParseToolCalls(finalText, toolNames)
if len(detected) > 0 {
finishReason = "tool_calls"
delta := map[string]any{
"tool_calls": util.FormatOpenAIToolCalls(detected),
}
if !firstChunkSent {
delta["role"] = "assistant"
firstChunkSent = true
}
sendChunk(map[string]any{
"id": completionID,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": []map[string]any{{"delta": map[string]any{"tool_calls": util.FormatOpenAIToolCalls(detected)}, "index": 0}},
"choices": []map[string]any{{"delta": delta, "index": 0}},
})
} else if bufferToolContent && strings.TrimSpace(finalText) != "" {
delta := map[string]any{
"content": finalText,
}
if !firstChunkSent {
delta["role"] = "assistant"
firstChunkSent = true
}
sendChunk(map[string]any{
"id": completionID,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": []map[string]any{{"delta": delta, "index": 0}},
})
}
promptTokens := util.EstimateTokens(finalPrompt)
@@ -325,7 +348,9 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
}
} else {
text.WriteString(p.Text)
delta["content"] = p.Text
if !bufferToolContent {
delta["content"] = p.Text
}
}
if len(delta) > 0 {
newChoices = append(newChoices, map[string]any{"delta": delta, "index": 0})

View File

@@ -17,7 +17,7 @@ func AdminKey() string {
if v := strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")); v != "" {
return v
}
return "your-admin-secret-key"
return "admin"
}
func jwtSecret() string {

View File

@@ -15,7 +15,7 @@ type ctxKey string
const authCtxKey ctxKey = "auth_context"
var (
ErrUnauthorized = errors.New("unauthorized: missing Bearer token")
ErrUnauthorized = errors.New("unauthorized: missing auth token")
ErrNoAccount = errors.New("no accounts configured or all accounts are busy")
)
@@ -41,11 +41,10 @@ func NewResolver(store *config.Store, pool *account.Pool, login LoginFunc) *Reso
}
func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
authHeader := req.Header.Get("Authorization")
if !strings.HasPrefix(authHeader, "Bearer ") {
callerKey := extractCallerToken(req)
if callerKey == "" {
return nil, ErrUnauthorized
}
callerKey := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer "))
ctx := req.Context()
if !r.Store.HasAPIKey(callerKey) {
return &RequestAuth{UseConfigToken: false, DeepSeekToken: callerKey, resolver: r, TriedAccounts: map[string]bool{}}, nil
@@ -148,3 +147,14 @@ func (r *Resolver) Release(a *RequestAuth) {
}
r.Pool.Release(a.AccountID)
}
func extractCallerToken(req *http.Request) string {
authHeader := strings.TrimSpace(req.Header.Get("Authorization"))
if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
token := strings.TrimSpace(authHeader[7:])
if token != "" {
return token
}
}
return strings.TrimSpace(req.Header.Get("x-api-key"))
}

View File

@@ -0,0 +1,74 @@
package auth
import (
"context"
"net/http"
"testing"
"ds2api/internal/account"
"ds2api/internal/config"
)
func newTestResolver(t *testing.T) *Resolver {
t.Helper()
t.Setenv("DS2API_CONFIG_JSON", `{
"keys":["managed-key"],
"accounts":[{"email":"acc@example.com","password":"pwd","token":"account-token"}]
}`)
store := config.LoadStore()
pool := account.NewPool(store)
return NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
return "fresh-token", nil
})
}
func TestDetermineWithXAPIKeyUsesDirectToken(t *testing.T) {
r := newTestResolver(t)
req, _ := http.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
req.Header.Set("x-api-key", "direct-token")
auth, err := r.Determine(req)
if err != nil {
t.Fatalf("determine failed: %v", err)
}
if auth.UseConfigToken {
t.Fatalf("expected direct token mode")
}
if auth.DeepSeekToken != "direct-token" {
t.Fatalf("unexpected token: %q", auth.DeepSeekToken)
}
}
func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) {
r := newTestResolver(t)
req, _ := http.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
req.Header.Set("x-api-key", "managed-key")
auth, err := r.Determine(req)
if err != nil {
t.Fatalf("determine failed: %v", err)
}
defer r.Release(auth)
if !auth.UseConfigToken {
t.Fatalf("expected managed key mode")
}
if auth.AccountID != "acc@example.com" {
t.Fatalf("unexpected account id: %q", auth.AccountID)
}
if auth.DeepSeekToken != "account-token" {
t.Fatalf("unexpected account token: %q", auth.DeepSeekToken)
}
}
func TestDetermineMissingToken(t *testing.T) {
r := newTestResolver(t)
req, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
_, err := r.Determine(req)
if err == nil {
t.Fatal("expected unauthorized error")
}
if err != ErrUnauthorized {
t.Fatalf("unexpected error: %v", err)
}
}

View File

@@ -356,5 +356,5 @@ func (s *Store) ClaudeMapping() map[string]string {
if len(s.cfg.ClaudeMapping) > 0 {
return cloneStringMap(s.cfg.ClaudeMapping)
}
return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-chat"}
return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"}
}

View File

@@ -9,6 +9,7 @@ import (
)
var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`)
var fencedJSONPattern = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```")
type ParsedToolCall struct {
Name string `json:"name"`
@@ -19,23 +20,25 @@ func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
if strings.TrimSpace(text) == "" {
return nil
}
m := toolCallPattern.FindStringSubmatch(text)
if len(m) < 2 {
return nil
}
payload := "{" + `"tool_calls":[` + m[1] + "]}"
var obj struct {
ToolCalls []ParsedToolCall `json:"tool_calls"`
}
if err := json.Unmarshal([]byte(payload), &obj); err != nil {
candidates := buildToolCallCandidates(text)
var parsed []ParsedToolCall
for _, candidate := range candidates {
if tc := parseToolCallsPayload(candidate); len(tc) > 0 {
parsed = tc
break
}
}
if len(parsed) == 0 {
return nil
}
allowed := map[string]struct{}{}
for _, name := range availableToolNames {
allowed[name] = struct{}{}
}
out := make([]ParsedToolCall, 0, len(obj.ToolCalls))
for _, tc := range obj.ToolCalls {
out := make([]ParsedToolCall, 0, len(parsed))
for _, tc := range parsed {
if tc.Name == "" {
continue
}
@@ -52,6 +55,220 @@ func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
return out
}
func buildToolCallCandidates(text string) []string {
trimmed := strings.TrimSpace(text)
candidates := []string{trimmed}
// fenced code block candidates: ```json ... ```
for _, match := range fencedJSONPattern.FindAllStringSubmatch(trimmed, -1) {
if len(match) >= 2 {
candidates = append(candidates, strings.TrimSpace(match[1]))
}
}
// best-effort extraction around "tool_calls" key in mixed text payloads.
candidates = append(candidates, extractToolCallObjects(trimmed)...)
// best-effort object slice: from first '{' to last '}'
first := strings.Index(trimmed, "{")
last := strings.LastIndex(trimmed, "}")
if first >= 0 && last > first {
candidates = append(candidates, strings.TrimSpace(trimmed[first:last+1]))
}
// legacy regex extraction fallback
if m := toolCallPattern.FindStringSubmatch(trimmed); len(m) >= 2 {
candidates = append(candidates, "{"+`"tool_calls":[`+m[1]+"]}")
}
uniq := make([]string, 0, len(candidates))
seen := map[string]struct{}{}
for _, c := range candidates {
if c == "" {
continue
}
if _, ok := seen[c]; ok {
continue
}
seen[c] = struct{}{}
uniq = append(uniq, c)
}
return uniq
}
func parseToolCallsPayload(payload string) []ParsedToolCall {
var decoded any
if err := json.Unmarshal([]byte(payload), &decoded); err != nil {
return nil
}
switch v := decoded.(type) {
case map[string]any:
if tc, ok := v["tool_calls"]; ok {
return parseToolCallList(tc)
}
if parsed, ok := parseToolCallItem(v); ok {
return []ParsedToolCall{parsed}
}
case []any:
return parseToolCallList(v)
}
return nil
}
func parseToolCallList(v any) []ParsedToolCall {
items, ok := v.([]any)
if !ok {
return nil
}
out := make([]ParsedToolCall, 0, len(items))
for _, item := range items {
m, ok := item.(map[string]any)
if !ok {
continue
}
if tc, ok := parseToolCallItem(m); ok {
out = append(out, tc)
}
}
if len(out) == 0 {
return nil
}
return out
}
func parseToolCallItem(m map[string]any) (ParsedToolCall, bool) {
name, _ := m["name"].(string)
inputRaw, hasInput := m["input"]
if fn, ok := m["function"].(map[string]any); ok {
if name == "" {
name, _ = fn["name"].(string)
}
if !hasInput {
if v, ok := fn["arguments"]; ok {
inputRaw = v
hasInput = true
}
}
}
if !hasInput {
for _, key := range []string{"arguments", "args", "parameters", "params"} {
if v, ok := m[key]; ok {
inputRaw = v
hasInput = true
break
}
}
}
if strings.TrimSpace(name) == "" {
return ParsedToolCall{}, false
}
return ParsedToolCall{
Name: strings.TrimSpace(name),
Input: parseToolCallInput(inputRaw),
}, true
}
func parseToolCallInput(v any) map[string]any {
switch x := v.(type) {
case nil:
return map[string]any{}
case map[string]any:
return x
case string:
raw := strings.TrimSpace(x)
if raw == "" {
return map[string]any{}
}
var parsed map[string]any
if err := json.Unmarshal([]byte(raw), &parsed); err == nil && parsed != nil {
return parsed
}
return map[string]any{"_raw": raw}
default:
b, err := json.Marshal(x)
if err != nil {
return map[string]any{}
}
var parsed map[string]any
if err := json.Unmarshal(b, &parsed); err == nil && parsed != nil {
return parsed
}
return map[string]any{}
}
}
func extractToolCallObjects(text string) []string {
if text == "" {
return nil
}
lower := strings.ToLower(text)
out := []string{}
offset := 0
for {
idx := strings.Index(lower[offset:], "tool_calls")
if idx < 0 {
break
}
idx += offset
start := strings.LastIndex(text[:idx], "{")
for start >= 0 {
candidate, end, ok := extractJSONObject(text, start)
if ok {
// Move forward to avoid repeatedly matching the same object.
offset = end
out = append(out, strings.TrimSpace(candidate))
break
}
start = strings.LastIndex(text[:start], "{")
}
if start < 0 {
offset = idx + len("tool_calls")
}
}
return out
}
func extractJSONObject(text string, start int) (string, int, bool) {
if start < 0 || start >= len(text) || text[start] != '{' {
return "", 0, false
}
depth := 0
quote := byte(0)
escaped := false
for i := start; i < len(text); i++ {
ch := text[i]
if quote != 0 {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == quote {
quote = 0
}
continue
}
if ch == '"' || ch == '\'' {
quote = ch
continue
}
if ch == '{' {
depth++
continue
}
if ch == '}' {
depth--
if depth == 0 {
return text[start : i+1], i + 1, true
}
}
}
return "", 0, false
}
func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {
out := make([]map[string]any, 0, len(calls))
for _, c := range calls {

View File

@@ -11,6 +11,34 @@ func TestParseToolCalls(t *testing.T) {
if calls[0].Name != "search" {
t.Fatalf("unexpected tool name: %s", calls[0].Name)
}
if calls[0].Input["q"] != "golang" {
t.Fatalf("unexpected args: %#v", calls[0].Input)
}
}
func TestParseToolCallsFromFencedJSON(t *testing.T) {
text := "I will call tools now\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"news\"}}]}\n```"
calls := ParseToolCalls(text, []string{"search"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if calls[0].Input["q"] != "news" {
t.Fatalf("unexpected args: %#v", calls[0].Input)
}
}
func TestParseToolCallsWithFunctionArgumentsString(t *testing.T) {
text := `{"tool_calls":[{"function":{"name":"get_weather","arguments":"{\"city\":\"beijing\"}"}}]}`
calls := ParseToolCalls(text, []string{"get_weather"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if calls[0].Name != "get_weather" {
t.Fatalf("unexpected tool name: %s", calls[0].Name)
}
if calls[0].Input["city"] != "beijing" {
t.Fatalf("unexpected args: %#v", calls[0].Input)
}
}
func TestParseToolCallsRejectUnknown(t *testing.T) {