mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-13 12:47:41 +08:00
feat: enhance tool call parsing robustness, authentication flexibility, and streaming output for tool content
This commit is contained in:
@@ -211,6 +211,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
|
||||
created := time.Now().Unix()
|
||||
firstChunkSent := false
|
||||
bufferToolContent := len(toolNames) > 0
|
||||
currentType := "text"
|
||||
if thinkingEnabled {
|
||||
currentType = "thinking"
|
||||
@@ -240,12 +241,34 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
if len(detected) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
delta := map[string]any{
|
||||
"tool_calls": util.FormatOpenAIToolCalls(detected),
|
||||
}
|
||||
if !firstChunkSent {
|
||||
delta["role"] = "assistant"
|
||||
firstChunkSent = true
|
||||
}
|
||||
sendChunk(map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]any{{"delta": map[string]any{"tool_calls": util.FormatOpenAIToolCalls(detected)}, "index": 0}},
|
||||
"choices": []map[string]any{{"delta": delta, "index": 0}},
|
||||
})
|
||||
} else if bufferToolContent && strings.TrimSpace(finalText) != "" {
|
||||
delta := map[string]any{
|
||||
"content": finalText,
|
||||
}
|
||||
if !firstChunkSent {
|
||||
delta["role"] = "assistant"
|
||||
firstChunkSent = true
|
||||
}
|
||||
sendChunk(map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": []map[string]any{{"delta": delta, "index": 0}},
|
||||
})
|
||||
}
|
||||
promptTokens := util.EstimateTokens(finalPrompt)
|
||||
@@ -325,7 +348,9 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
}
|
||||
} else {
|
||||
text.WriteString(p.Text)
|
||||
delta["content"] = p.Text
|
||||
if !bufferToolContent {
|
||||
delta["content"] = p.Text
|
||||
}
|
||||
}
|
||||
if len(delta) > 0 {
|
||||
newChoices = append(newChoices, map[string]any{"delta": delta, "index": 0})
|
||||
|
||||
@@ -17,7 +17,7 @@ func AdminKey() string {
|
||||
if v := strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")); v != "" {
|
||||
return v
|
||||
}
|
||||
return "your-admin-secret-key"
|
||||
return "admin"
|
||||
}
|
||||
|
||||
func jwtSecret() string {
|
||||
|
||||
@@ -15,7 +15,7 @@ type ctxKey string
|
||||
const authCtxKey ctxKey = "auth_context"
|
||||
|
||||
var (
|
||||
ErrUnauthorized = errors.New("unauthorized: missing Bearer token")
|
||||
ErrUnauthorized = errors.New("unauthorized: missing auth token")
|
||||
ErrNoAccount = errors.New("no accounts configured or all accounts are busy")
|
||||
)
|
||||
|
||||
@@ -41,11 +41,10 @@ func NewResolver(store *config.Store, pool *account.Pool, login LoginFunc) *Reso
|
||||
}
|
||||
|
||||
func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
|
||||
authHeader := req.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
callerKey := extractCallerToken(req)
|
||||
if callerKey == "" {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
callerKey := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer "))
|
||||
ctx := req.Context()
|
||||
if !r.Store.HasAPIKey(callerKey) {
|
||||
return &RequestAuth{UseConfigToken: false, DeepSeekToken: callerKey, resolver: r, TriedAccounts: map[string]bool{}}, nil
|
||||
@@ -148,3 +147,14 @@ func (r *Resolver) Release(a *RequestAuth) {
|
||||
}
|
||||
r.Pool.Release(a.AccountID)
|
||||
}
|
||||
|
||||
func extractCallerToken(req *http.Request) string {
|
||||
authHeader := strings.TrimSpace(req.Header.Get("Authorization"))
|
||||
if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
||||
token := strings.TrimSpace(authHeader[7:])
|
||||
if token != "" {
|
||||
return token
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(req.Header.Get("x-api-key"))
|
||||
}
|
||||
|
||||
74
internal/auth/request_test.go
Normal file
74
internal/auth/request_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/account"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func newTestResolver(t *testing.T) *Resolver {
|
||||
t.Helper()
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["managed-key"],
|
||||
"accounts":[{"email":"acc@example.com","password":"pwd","token":"account-token"}]
|
||||
}`)
|
||||
store := config.LoadStore()
|
||||
pool := account.NewPool(store)
|
||||
return NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
|
||||
return "fresh-token", nil
|
||||
})
|
||||
}
|
||||
|
||||
func TestDetermineWithXAPIKeyUsesDirectToken(t *testing.T) {
|
||||
r := newTestResolver(t)
|
||||
req, _ := http.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
req.Header.Set("x-api-key", "direct-token")
|
||||
|
||||
auth, err := r.Determine(req)
|
||||
if err != nil {
|
||||
t.Fatalf("determine failed: %v", err)
|
||||
}
|
||||
if auth.UseConfigToken {
|
||||
t.Fatalf("expected direct token mode")
|
||||
}
|
||||
if auth.DeepSeekToken != "direct-token" {
|
||||
t.Fatalf("unexpected token: %q", auth.DeepSeekToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) {
|
||||
r := newTestResolver(t)
|
||||
req, _ := http.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
req.Header.Set("x-api-key", "managed-key")
|
||||
|
||||
auth, err := r.Determine(req)
|
||||
if err != nil {
|
||||
t.Fatalf("determine failed: %v", err)
|
||||
}
|
||||
defer r.Release(auth)
|
||||
if !auth.UseConfigToken {
|
||||
t.Fatalf("expected managed key mode")
|
||||
}
|
||||
if auth.AccountID != "acc@example.com" {
|
||||
t.Fatalf("unexpected account id: %q", auth.AccountID)
|
||||
}
|
||||
if auth.DeepSeekToken != "account-token" {
|
||||
t.Fatalf("unexpected account token: %q", auth.DeepSeekToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetermineMissingToken(t *testing.T) {
|
||||
r := newTestResolver(t)
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
_, err := r.Determine(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected unauthorized error")
|
||||
}
|
||||
if err != ErrUnauthorized {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -356,5 +356,5 @@ func (s *Store) ClaudeMapping() map[string]string {
|
||||
if len(s.cfg.ClaudeMapping) > 0 {
|
||||
return cloneStringMap(s.cfg.ClaudeMapping)
|
||||
}
|
||||
return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-chat"}
|
||||
return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
)
|
||||
|
||||
var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`)
|
||||
var fencedJSONPattern = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```")
|
||||
|
||||
type ParsedToolCall struct {
|
||||
Name string `json:"name"`
|
||||
@@ -19,23 +20,25 @@ func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return nil
|
||||
}
|
||||
m := toolCallPattern.FindStringSubmatch(text)
|
||||
if len(m) < 2 {
|
||||
return nil
|
||||
}
|
||||
payload := "{" + `"tool_calls":[` + m[1] + "]}"
|
||||
var obj struct {
|
||||
ToolCalls []ParsedToolCall `json:"tool_calls"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(payload), &obj); err != nil {
|
||||
|
||||
candidates := buildToolCallCandidates(text)
|
||||
var parsed []ParsedToolCall
|
||||
for _, candidate := range candidates {
|
||||
if tc := parseToolCallsPayload(candidate); len(tc) > 0 {
|
||||
parsed = tc
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(parsed) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
allowed := map[string]struct{}{}
|
||||
for _, name := range availableToolNames {
|
||||
allowed[name] = struct{}{}
|
||||
}
|
||||
out := make([]ParsedToolCall, 0, len(obj.ToolCalls))
|
||||
for _, tc := range obj.ToolCalls {
|
||||
out := make([]ParsedToolCall, 0, len(parsed))
|
||||
for _, tc := range parsed {
|
||||
if tc.Name == "" {
|
||||
continue
|
||||
}
|
||||
@@ -52,6 +55,220 @@ func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
|
||||
return out
|
||||
}
|
||||
|
||||
func buildToolCallCandidates(text string) []string {
|
||||
trimmed := strings.TrimSpace(text)
|
||||
candidates := []string{trimmed}
|
||||
|
||||
// fenced code block candidates: ```json ... ```
|
||||
for _, match := range fencedJSONPattern.FindAllStringSubmatch(trimmed, -1) {
|
||||
if len(match) >= 2 {
|
||||
candidates = append(candidates, strings.TrimSpace(match[1]))
|
||||
}
|
||||
}
|
||||
|
||||
// best-effort extraction around "tool_calls" key in mixed text payloads.
|
||||
candidates = append(candidates, extractToolCallObjects(trimmed)...)
|
||||
|
||||
// best-effort object slice: from first '{' to last '}'
|
||||
first := strings.Index(trimmed, "{")
|
||||
last := strings.LastIndex(trimmed, "}")
|
||||
if first >= 0 && last > first {
|
||||
candidates = append(candidates, strings.TrimSpace(trimmed[first:last+1]))
|
||||
}
|
||||
|
||||
// legacy regex extraction fallback
|
||||
if m := toolCallPattern.FindStringSubmatch(trimmed); len(m) >= 2 {
|
||||
candidates = append(candidates, "{"+`"tool_calls":[`+m[1]+"]}")
|
||||
}
|
||||
|
||||
uniq := make([]string, 0, len(candidates))
|
||||
seen := map[string]struct{}{}
|
||||
for _, c := range candidates {
|
||||
if c == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[c]; ok {
|
||||
continue
|
||||
}
|
||||
seen[c] = struct{}{}
|
||||
uniq = append(uniq, c)
|
||||
}
|
||||
return uniq
|
||||
}
|
||||
|
||||
func parseToolCallsPayload(payload string) []ParsedToolCall {
|
||||
var decoded any
|
||||
if err := json.Unmarshal([]byte(payload), &decoded); err != nil {
|
||||
return nil
|
||||
}
|
||||
switch v := decoded.(type) {
|
||||
case map[string]any:
|
||||
if tc, ok := v["tool_calls"]; ok {
|
||||
return parseToolCallList(tc)
|
||||
}
|
||||
if parsed, ok := parseToolCallItem(v); ok {
|
||||
return []ParsedToolCall{parsed}
|
||||
}
|
||||
case []any:
|
||||
return parseToolCallList(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseToolCallList(v any) []ParsedToolCall {
|
||||
items, ok := v.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
out := make([]ParsedToolCall, 0, len(items))
|
||||
for _, item := range items {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if tc, ok := parseToolCallItem(m); ok {
|
||||
out = append(out, tc)
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseToolCallItem(m map[string]any) (ParsedToolCall, bool) {
|
||||
name, _ := m["name"].(string)
|
||||
inputRaw, hasInput := m["input"]
|
||||
if fn, ok := m["function"].(map[string]any); ok {
|
||||
if name == "" {
|
||||
name, _ = fn["name"].(string)
|
||||
}
|
||||
if !hasInput {
|
||||
if v, ok := fn["arguments"]; ok {
|
||||
inputRaw = v
|
||||
hasInput = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if !hasInput {
|
||||
for _, key := range []string{"arguments", "args", "parameters", "params"} {
|
||||
if v, ok := m[key]; ok {
|
||||
inputRaw = v
|
||||
hasInput = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
return ParsedToolCall{
|
||||
Name: strings.TrimSpace(name),
|
||||
Input: parseToolCallInput(inputRaw),
|
||||
}, true
|
||||
}
|
||||
|
||||
func parseToolCallInput(v any) map[string]any {
|
||||
switch x := v.(type) {
|
||||
case nil:
|
||||
return map[string]any{}
|
||||
case map[string]any:
|
||||
return x
|
||||
case string:
|
||||
raw := strings.TrimSpace(x)
|
||||
if raw == "" {
|
||||
return map[string]any{}
|
||||
}
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(raw), &parsed); err == nil && parsed != nil {
|
||||
return parsed
|
||||
}
|
||||
return map[string]any{"_raw": raw}
|
||||
default:
|
||||
b, err := json.Marshal(x)
|
||||
if err != nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(b, &parsed); err == nil && parsed != nil {
|
||||
return parsed
|
||||
}
|
||||
return map[string]any{}
|
||||
}
|
||||
}
|
||||
|
||||
func extractToolCallObjects(text string) []string {
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
lower := strings.ToLower(text)
|
||||
out := []string{}
|
||||
offset := 0
|
||||
for {
|
||||
idx := strings.Index(lower[offset:], "tool_calls")
|
||||
if idx < 0 {
|
||||
break
|
||||
}
|
||||
idx += offset
|
||||
start := strings.LastIndex(text[:idx], "{")
|
||||
for start >= 0 {
|
||||
candidate, end, ok := extractJSONObject(text, start)
|
||||
if ok {
|
||||
// Move forward to avoid repeatedly matching the same object.
|
||||
offset = end
|
||||
out = append(out, strings.TrimSpace(candidate))
|
||||
break
|
||||
}
|
||||
start = strings.LastIndex(text[:start], "{")
|
||||
}
|
||||
if start < 0 {
|
||||
offset = idx + len("tool_calls")
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func extractJSONObject(text string, start int) (string, int, bool) {
|
||||
if start < 0 || start >= len(text) || text[start] != '{' {
|
||||
return "", 0, false
|
||||
}
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
for i := start; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '{' {
|
||||
depth++
|
||||
continue
|
||||
}
|
||||
if ch == '}' {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return text[start : i+1], i + 1, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(calls))
|
||||
for _, c := range calls {
|
||||
|
||||
@@ -11,6 +11,34 @@ func TestParseToolCalls(t *testing.T) {
|
||||
if calls[0].Name != "search" {
|
||||
t.Fatalf("unexpected tool name: %s", calls[0].Name)
|
||||
}
|
||||
if calls[0].Input["q"] != "golang" {
|
||||
t.Fatalf("unexpected args: %#v", calls[0].Input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsFromFencedJSON(t *testing.T) {
|
||||
text := "I will call tools now\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"news\"}}]}\n```"
|
||||
calls := ParseToolCalls(text, []string{"search"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Input["q"] != "news" {
|
||||
t.Fatalf("unexpected args: %#v", calls[0].Input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsWithFunctionArgumentsString(t *testing.T) {
|
||||
text := `{"tool_calls":[{"function":{"name":"get_weather","arguments":"{\"city\":\"beijing\"}"}}]}`
|
||||
calls := ParseToolCalls(text, []string{"get_weather"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Name != "get_weather" {
|
||||
t.Fatalf("unexpected tool name: %s", calls[0].Name)
|
||||
}
|
||||
if calls[0].Input["city"] != "beijing" {
|
||||
t.Fatalf("unexpected args: %#v", calls[0].Input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsRejectUnknown(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user