diff --git a/README.MD b/README.MD index 9b77d09..0ce3ec7 100644 --- a/README.MD +++ b/README.MD @@ -1,5 +1,5 @@

- DS2API icon + DS2API icon

# DS2API @@ -10,6 +10,7 @@ [![Release](https://img.shields.io/github/v/release/CJackHwang/ds2api?display_name=tag)](https://github.com/CJackHwang/ds2api/releases) [![Docker](https://img.shields.io/badge/docker-ready-blue.svg)](DEPLOY.md) [![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/L4CFHP) +[![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https://github.com/CJackHwang/ds2api) 语言 / Language: [中文](README.MD) | [English](README.en.md) diff --git a/README.en.md b/README.en.md index add018b..fe07911 100644 --- a/README.en.md +++ b/README.en.md @@ -1,5 +1,5 @@

- DS2API icon + DS2API icon

# DS2API @@ -10,6 +10,7 @@ [![Release](https://img.shields.io/github/v/release/CJackHwang/ds2api?display_name=tag)](https://github.com/CJackHwang/ds2api/releases) [![Docker](https://img.shields.io/badge/docker-ready-blue.svg)](DEPLOY.en.md) [![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/L4CFHP) +[![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https://github.com/CJackHwang/ds2api) Language: [中文](README.MD) | [English](README.en.md) diff --git a/assets/ds2api-icon.svg b/assets/ds2api-icon.svg deleted file mode 100644 index faf8eb3..0000000 --- a/assets/ds2api-icon.svg +++ /dev/null @@ -1,63 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/docker-compose.yml b/docker-compose.yml index 3e6b605..e5e2ff1 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,6 +1,6 @@ -services: - ds2api: - image: crpi-cnazxqmg4avmg4fq.cn-beijing.personal.cr.aliyuncs.com/ronghuaxueleng/ds2api:latest +services: + ds2api: + image: ghcr.io/cjackhwang/ds2api:latest container_name: ds2api restart: always ports: diff --git a/internal/adapter/openai/chat_stream_runtime.go b/internal/adapter/openai/chat_stream_runtime.go index a5ecbd6..5cd16da 100644 --- a/internal/adapter/openai/chat_stream_runtime.go +++ b/internal/adapter/openai/chat_stream_runtime.go @@ -98,7 +98,7 @@ func (s *chatStreamRuntime) sendDone() { func (s *chatStreamRuntime) finalize(finishReason string) { finalThinking := s.thinking.String() finalText := s.text.String() - detected := util.ParseToolCalls(finalText, s.toolNames) + detected := util.ParseStandaloneToolCalls(finalText, s.toolNames) if len(detected) > 0 && !s.toolCallsDoneEmitted { finishReason = "tool_calls" delta := map[string]any{ diff --git a/internal/adapter/openai/handler_toolcall_test.go b/internal/adapter/openai/handler_toolcall_test.go index 895605f..5e78f0b 100644 --- a/internal/adapter/openai/handler_toolcall_test.go +++ b/internal/adapter/openai/handler_toolcall_test.go @@ -3,6 +3,7 @@ package openai import ( "context" "encoding/json" + "fmt" "io" "net/http" "net/http/httptest" @@ -210,7 +211,7 @@ func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) { } } -func TestHandleNonStreamEmbeddedToolCallExampleIntercepted(t *testing.T) { +func TestHandleNonStreamEmbeddedToolCallExampleRemainsText(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( `data: {"p":"response/content","v":"下面是示例:"}`, @@ -228,16 +229,16 @@ func TestHandleNonStreamEmbeddedToolCallExampleIntercepted(t *testing.T) { out := decodeJSONBody(t, rec.Body.String()) choices, _ := out["choices"].([]any) choice, _ := choices[0].(map[string]any) - if choice["finish_reason"] != "tool_calls" { - t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"]) + if choice["finish_reason"] != "stop" { + t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"]) } msg, _ := choice["message"].(map[string]any) - toolCalls, _ := msg["tool_calls"].([]any) - if len(toolCalls) == 0 { - t.Fatalf("expected tool_calls field for embedded example: %#v", msg["tool_calls"]) + if _, ok := msg["tool_calls"]; ok { + t.Fatalf("did not expect tool_calls field for embedded example: %#v", msg["tool_calls"]) } - if msg["content"] != nil { - t.Fatalf("expected content nil when tool_calls detected, got %#v", msg["content"]) + content, _ := msg["content"].(string) + if !strings.Contains(content, "下面是示例:") || !strings.Contains(content, "请勿执行。") || !strings.Contains(content, `"tool_calls"`) { + t.Fatalf("expected embedded example to remain plain text, got %#v", content) } } @@ -315,6 +316,36 @@ func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) { } } +func TestHandleStreamToolCallLargeArgumentsStillIntercepted(t *testing.T) { + h := &Handler{} + large := strings.Repeat("a", 9000) + payload := fmt.Sprintf(`{"tool_calls":[{"name":"search","input":{"q":"%s"}}]}`, large) + splitAt := len(payload) / 2 + resp := makeSSEHTTPResponse( + fmt.Sprintf(`data: {"p":"response/content","v":%q}`, payload[:splitAt]), + fmt.Sprintf(`data: {"p":"response/content","v":%q}`, payload[splitAt:]), + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + h.handleStream(rec, req, resp, "cid3-large", "deepseek-chat", "prompt", false, false, []string{"search"}) + + frames, done := parseSSEDataFrames(t, rec.Body.String()) + if !done { + t.Fatalf("expected [DONE], body=%s", rec.Body.String()) + } + if !streamHasToolCallsDelta(frames) { + t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) + } + if streamHasRawToolJSONContent(frames) { + t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String()) + } + if streamFinishReason(frames) != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) + } +} + func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( @@ -482,8 +513,8 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) { if !done { t.Fatalf("expected [DONE], body=%s", rec.Body.String()) } - if !streamHasToolCallsDelta(frames) { - t.Fatalf("expected tool_calls delta in mixed prose stream, body=%s", rec.Body.String()) + if streamHasToolCallsDelta(frames) { + t.Fatalf("did not expect tool_calls delta in mixed prose stream, body=%s", rec.Body.String()) } content := strings.Builder{} for _, frame := range frames { @@ -500,15 +531,15 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) { if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") { t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got) } - if strings.Contains(strings.ToLower(got), `"tool_calls"`) { - t.Fatalf("expected no raw tool_calls json leak in content, got=%q", got) + if !strings.Contains(strings.ToLower(got), `"tool_calls"`) { + t.Fatalf("expected embedded tool json to remain text in strict mode, got=%q", got) } - if streamFinishReason(frames) != "tool_calls" { - t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String()) + if streamFinishReason(frames) != "stop" { + t.Fatalf("expected finish_reason=stop for mixed prose, body=%s", rec.Body.String()) } } -func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) { +func TestHandleStreamToolCallAfterLeadingTextRemainsText(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( `data: {"p":"response/content","v":"我将调用工具。"}`, @@ -524,8 +555,8 @@ func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) { if !done { t.Fatalf("expected [DONE], body=%s", rec.Body.String()) } - if !streamHasToolCallsDelta(frames) { - t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) + if streamHasToolCallsDelta(frames) { + t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String()) } content := strings.Builder{} for _, frame := range frames { @@ -542,15 +573,15 @@ func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) { if !strings.Contains(got, "我将调用工具。") { t.Fatalf("expected leading text to keep streaming, got=%q", got) } - if strings.Contains(strings.ToLower(got), "tool_calls") { - t.Fatalf("unexpected raw tool json leak, got=%q", got) + if !strings.Contains(strings.ToLower(got), "tool_calls") { + t.Fatalf("expected tool_calls example text preserved, got=%q", got) } - if streamFinishReason(frames) != "tool_calls" { - t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) + if streamFinishReason(frames) != "stop" { + t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String()) } } -func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testing.T) { +func TestHandleStreamToolCallWithSameChunkTrailingTextRemainsText(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}接下来我会继续说明。"}`, @@ -565,8 +596,8 @@ func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testin if !done { t.Fatalf("expected [DONE], body=%s", rec.Body.String()) } - if !streamHasToolCallsDelta(frames) { - t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) + if streamHasToolCallsDelta(frames) { + t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String()) } content := strings.Builder{} for _, frame := range frames { @@ -583,15 +614,15 @@ func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testin if !strings.Contains(got, "接下来我会继续说明。") { t.Fatalf("expected trailing plain text to be preserved, got=%q", got) } - if strings.Contains(strings.ToLower(got), "tool_calls") { - t.Fatalf("unexpected raw tool json leak, got=%q", got) + if !strings.Contains(strings.ToLower(got), "tool_calls") { + t.Fatalf("expected tool_calls example text preserved, got=%q", got) } - if streamFinishReason(frames) != "tool_calls" { - t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) + if streamFinishReason(frames) != "stop" { + t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String()) } } -func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) { +func TestHandleStreamToolCallKeyAppearsLateRemainsText(t *testing.T) { h := &Handler{} spaces := strings.Repeat(" ", 200) resp := makeSSEHTTPResponse( @@ -609,11 +640,8 @@ func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) { if !done { t.Fatalf("expected [DONE], body=%s", rec.Body.String()) } - if !streamHasToolCallsDelta(frames) { - t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) - } - if streamHasRawToolJSONContent(frames) { - t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String()) + if streamHasToolCallsDelta(frames) { + t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String()) } content := strings.Builder{} for _, frame := range frames { @@ -627,14 +655,14 @@ func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) { } } got := content.String() - if strings.Contains(got, "{") { - t.Fatalf("unexpected suspicious prefix leak in content: %q", got) + if !strings.Contains(strings.ToLower(got), "tool_calls") || !strings.Contains(got, "{") { + t.Fatalf("expected embedded tool json to remain in text, got=%q", got) } if !strings.Contains(got, "后置正文C。") { t.Fatalf("expected stream to continue after tool json convergence, got=%q", got) } - if streamFinishReason(frames) != "tool_calls" { - t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) + if streamFinishReason(frames) != "stop" { + t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String()) } } @@ -712,7 +740,7 @@ func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testin } } -func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) { +func TestHandleStreamToolCallArgumentsEmitAsSingleCompletedChunk(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`, @@ -735,8 +763,8 @@ func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) { t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String()) } argChunks := streamToolCallArgumentChunks(frames) - if len(argChunks) < 2 { - t.Fatalf("expected incremental arguments chunks, got=%v body=%s", argChunks, rec.Body.String()) + if len(argChunks) == 0 { + t.Fatalf("expected tool call arguments chunk, got=%v body=%s", argChunks, rec.Body.String()) } joined := strings.Join(argChunks, "") if !strings.Contains(joined, `"q":"golang"`) || !strings.Contains(joined, `"page":1`) { diff --git a/internal/adapter/openai/message_normalize.go b/internal/adapter/openai/message_normalize.go index 94b2339..8c6bb8f 100644 --- a/internal/adapter/openai/message_normalize.go +++ b/internal/adapter/openai/message_normalize.go @@ -3,7 +3,6 @@ package openai import ( "encoding/json" "fmt" - "io" "strings" "ds2api/internal/config" @@ -175,30 +174,11 @@ func normalizeToolArgumentString(raw string) string { if trimmed == "" { return "" } - if !looksLikeConcatenatedJSON(trimmed) { - return trimmed + if looksLikeConcatenatedJSON(trimmed) { + // Keep original payload to avoid silent argument rewrites. + return raw } - dec := json.NewDecoder(strings.NewReader(trimmed)) - values := make([]any, 0, 2) - for { - var v any - if err := dec.Decode(&v); err != nil { - if err == io.EOF { - break - } - return trimmed - } - values = append(values, v) - } - if len(values) < 2 { - return trimmed - } - last := values[len(values)-1] - b, err := json.Marshal(last) - if err != nil || len(b) == 0 { - return trimmed - } - return string(b) + return trimmed } func marshalToPromptString(v any) string { diff --git a/internal/adapter/openai/message_normalize_test.go b/internal/adapter/openai/message_normalize_test.go index ff36bd9..1abe426 100644 --- a/internal/adapter/openai/message_normalize_test.go +++ b/internal/adapter/openai/message_normalize_test.go @@ -168,7 +168,7 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSepara } } -func TestNormalizeOpenAIMessagesForPrompt_RepairsConcatenatedToolArguments(t *testing.T) { +func TestNormalizeOpenAIMessagesForPrompt_PreservesConcatenatedToolArguments(t *testing.T) { raw := []any{ map[string]any{ "role": "assistant", @@ -189,10 +189,7 @@ func TestNormalizeOpenAIMessagesForPrompt_RepairsConcatenatedToolArguments(t *te t.Fatalf("expected one normalized message, got %d", len(normalized)) } content, _ := normalized[0]["content"].(string) - if !strings.Contains(content, `function.arguments: {"query":"测试工具调用"}`) { - t.Fatalf("expected repaired arguments in tool history, got %q", content) - } - if strings.Contains(content, `{}{"query":"测试工具调用"}`) { - t.Fatalf("expected concatenated JSON to be repaired, got %q", content) + if !strings.Contains(content, `function.arguments: {}{"query":"测试工具调用"}`) { + t.Fatalf("expected original concatenated arguments in tool history, got %q", content) } } diff --git a/internal/adapter/openai/responses_embeddings_test.go b/internal/adapter/openai/responses_embeddings_test.go index a586682..2907bd6 100644 --- a/internal/adapter/openai/responses_embeddings_test.go +++ b/internal/adapter/openai/responses_embeddings_test.go @@ -135,7 +135,7 @@ func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) { } } -func TestNormalizeResponsesInputAsMessagesFunctionCallItemRepairsConcatenatedArguments(t *testing.T) { +func TestNormalizeResponsesInputAsMessagesFunctionCallItemPreservesConcatenatedArguments(t *testing.T) { msgs := normalizeResponsesInputAsMessages([]any{ map[string]any{ "type": "function_call", @@ -151,8 +151,8 @@ func TestNormalizeResponsesInputAsMessagesFunctionCallItemRepairsConcatenatedArg toolCalls, _ := m["tool_calls"].([]any) call, _ := toolCalls[0].(map[string]any) fn, _ := call["function"].(map[string]any) - if fn["arguments"] != `{"q":"golang"}` { - t.Fatalf("expected concatenated call arguments repaired, got %#v", fn["arguments"]) + if fn["arguments"] != `{}{"q":"golang"}` { + t.Fatalf("expected original concatenated call arguments preserved, got %#v", fn["arguments"]) } } diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index 81da92d..e4b1de8 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -113,15 +113,10 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res return } result := sse.CollectStream(resp, thinkingEnabled, true) - textParsed := util.ParseToolCallsDetailed(result.Text, toolNames) - thinkingParsed := util.ParseToolCallsDetailed(result.Thinking, toolNames) + textParsed := util.ParseStandaloneToolCallsDetailed(result.Text, toolNames) logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text") - logResponsesToolPolicyRejection(traceID, toolChoice, thinkingParsed, "thinking") callCount := len(textParsed.Calls) - if callCount == 0 { - callCount = len(thinkingParsed.Calls) - } if toolChoice.IsRequired() && callCount == 0 { writeOpenAIErrorWithCode(w, http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation") return diff --git a/internal/adapter/openai/responses_stream_runtime_core.go b/internal/adapter/openai/responses_stream_runtime_core.go index 02303d0..e8ec6df 100644 --- a/internal/adapter/openai/responses_stream_runtime_core.go +++ b/internal/adapter/openai/responses_stream_runtime_core.go @@ -102,16 +102,11 @@ func (s *responsesStreamRuntime) finalize() { if s.bufferToolContent { s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true) - s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false) } - textParsed := util.ParseToolCallsDetailed(finalText, s.toolNames) - thinkingParsed := util.ParseToolCallsDetailed(finalThinking, s.toolNames) + textParsed := util.ParseStandaloneToolCallsDetailed(finalText, s.toolNames) detected := textParsed.Calls - if len(detected) == 0 { - detected = thinkingParsed.Calls - } - s.logToolPolicyRejections(textParsed, thinkingParsed) + s.logToolPolicyRejections(textParsed) if len(detected) > 0 { s.toolCallsEmitted = true @@ -157,7 +152,7 @@ func (s *responsesStreamRuntime) finalize() { s.sendDone() } -func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingParsed util.ToolCallParseResult) { +func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed util.ToolCallParseResult) { logRejected := func(parsed util.ToolCallParseResult, channel string) { rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames) if !parsed.RejectedByPolicy || len(rejected) == 0 { @@ -172,7 +167,6 @@ func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingPar ) } logRejected(textParsed, "text") - logRejected(thinkingParsed, "thinking") } func (s *responsesStreamRuntime) hasFunctionCallDone() bool { @@ -207,9 +201,6 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa } s.thinking.WriteString(p.Text) s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text)) - if s.bufferToolContent { - s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false) - } continue } diff --git a/internal/adapter/openai/responses_stream_test.go b/internal/adapter/openai/responses_stream_test.go index ca3c4a3..90ade96 100644 --- a/internal/adapter/openai/responses_stream_test.go +++ b/internal/adapter/openai/responses_stream_test.go @@ -99,9 +99,6 @@ func TestHandleResponsesStreamUsesOfficialOutputItemEvents(t *testing.T) { if !strings.Contains(body, "event: response.output_item.done") { t.Fatalf("expected response.output_item.done event, body=%s", body) } - if !strings.Contains(body, "event: response.function_call_arguments.delta") { - t.Fatalf("expected response.function_call_arguments.delta event, body=%s", body) - } if !strings.Contains(body, "event: response.function_call_arguments.done") { t.Fatalf("expected response.function_call_arguments.done event, body=%s", body) } @@ -266,7 +263,7 @@ func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) { } } -func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *testing.T) { +func TestHandleResponsesStreamThinkingAndMixedToolExampleRemainMessageOnly(t *testing.T) { h := &Handler{} req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) rec := httptest.NewRecorder() @@ -291,23 +288,12 @@ func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *tes h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "") addedPayloads := extractAllSSEEventPayloads(rec.Body.String(), "response.output_item.added") - if len(addedPayloads) < 2 { - t.Fatalf("expected message + function_call output_item.added events, got %d body=%s", len(addedPayloads), rec.Body.String()) + if len(addedPayloads) != 1 { + t.Fatalf("expected only one message output_item.added event, got %d body=%s", len(addedPayloads), rec.Body.String()) } - - indexes := map[int]struct{}{} - typeByIndex := map[int]string{} - addedIDs := map[string]string{} - for _, payload := range addedPayloads { - item, _ := payload["item"].(map[string]any) - itemType := strings.TrimSpace(asString(item["type"])) - outputIndex := int(asFloat(payload["output_index"])) - if _, exists := indexes[outputIndex]; exists { - t.Fatalf("found duplicated output_index=%d for item types=%q and %q payload=%#v", outputIndex, typeByIndex[outputIndex], itemType, payload) - } - indexes[outputIndex] = struct{}{} - typeByIndex[outputIndex] = itemType - addedIDs[itemType] = strings.TrimSpace(asString(payload["item_id"])) + item, _ := addedPayloads[0]["item"].(map[string]any) + if asString(item["type"]) != "message" { + t.Fatalf("expected only message output item in strict mode, got %#v", item) } completedPayload, ok := extractSSEEventPayload(rec.Body.String(), "response.completed") @@ -316,21 +302,15 @@ func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *tes } responseObj, _ := completedPayload["response"].(map[string]any) output, _ := responseObj["output"].([]any) - found := map[string]bool{} for _, item := range output { m, _ := item.(map[string]any) - itemType := strings.TrimSpace(asString(m["type"])) - itemID := strings.TrimSpace(asString(m["id"])) - if itemType == "" || itemID == "" { + if m == nil { continue } - if wantID := strings.TrimSpace(addedIDs[itemType]); wantID != "" && wantID == itemID { - found[itemType] = true + if asString(m["type"]) == "function_call" { + t.Fatalf("did not expect function_call output for mixed prose tool example, output=%#v", output) } } - if !found["message"] || !found["function_call"] { - t.Fatalf("expected completed output to contain streamed message/function_call item ids, found=%#v output=%#v", found, output) - } } func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) { @@ -360,7 +340,7 @@ func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) { } } -func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *testing.T) { +func TestHandleResponsesStreamMalformedToolJSONFallsBackToText(t *testing.T) { h := &Handler{} req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) rec := httptest.NewRecorder() @@ -373,7 +353,7 @@ func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *t return "data: " + string(b) + "\n" } - // invalid JSON (NaN) can still trigger incremental tool deltas before final parse rejects it + // invalid JSON (NaN) should remain plain text in strict mode. streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n" resp := &http.Response{ StatusCode: http.StatusOK, @@ -382,14 +362,11 @@ func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *t h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "") body := rec.Body.String() - if !strings.Contains(body, "event: response.function_call_arguments.delta") { - t.Fatalf("expected response.function_call_arguments.delta event for malformed payload, body=%s", body) + if strings.Contains(body, "event: response.function_call_arguments.delta") || strings.Contains(body, "event: response.function_call_arguments.done") { + t.Fatalf("did not expect function_call events for malformed payload in strict mode, body=%s", body) } - if !strings.Contains(body, "event: response.function_call_arguments.done") { - t.Fatalf("expected runtime to close in-progress function_call with done event, body=%s", body) - } - if !strings.Contains(body, "event: response.output_item.done") { - t.Fatalf("expected runtime to close function output item, body=%s", body) + if !strings.Contains(body, "event: response.output_text.delta") { + t.Fatalf("expected response.output_text.delta for malformed payload, body=%s", body) } if !strings.Contains(body, "event: response.completed") { t.Fatalf("expected response.completed event, body=%s", body) @@ -430,6 +407,42 @@ func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) { } } +func TestHandleResponsesStreamRequiredToolChoiceIgnoresThinkingToolPayload(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(path, value string) string { + b, _ := json.Marshal(map[string]any{ + "p": path, + "v": value, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + + sseLine("response/content", "plain text only") + + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + policy := util.ToolChoicePolicy{ + Mode: util.ToolChoiceRequired, + Allowed: map[string]struct{}{"read_file": {}}, + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, false, []string{"read_file"}, policy, "") + body := rec.Body.String() + if !strings.Contains(body, "event: response.failed") { + t.Fatalf("expected response.failed event for required tool_choice violation, body=%s", body) + } + if strings.Contains(body, "event: response.completed") { + t.Fatalf("did not expect response.completed after failure, body=%s", body) + } +} + func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) { h := &Handler{} req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) @@ -516,6 +529,33 @@ func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) { } } +func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayload(t *testing.T) { + h := &Handler{} + rec := httptest.NewRecorder() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader( + `data: {"p":"response/thinking_content","v":"{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}"}` + "\n" + + `data: {"p":"response/content","v":"plain text only"}` + "\n" + + `data: [DONE]` + "\n", + )), + } + policy := util.ToolChoicePolicy{ + Mode: util.ToolChoiceRequired, + Allowed: map[string]struct{}{"read_file": {}}, + } + + h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, []string{"read_file"}, policy, "") + if rec.Code != http.StatusUnprocessableEntity { + t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String()) + } + out := decodeJSONBody(t, rec.Body.String()) + errObj, _ := out["error"].(map[string]any) + if asString(errObj["code"]) != "tool_choice_violation" { + t.Fatalf("expected code=tool_choice_violation, got %#v", out) + } +} + func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) { h := &Handler{} rec := httptest.NewRecorder() diff --git a/internal/adapter/openai/stream_status_test.go b/internal/adapter/openai/stream_status_test.go index 4f8305a..4d66b46 100644 --- a/internal/adapter/openai/stream_status_test.go +++ b/internal/adapter/openai/stream_status_test.go @@ -167,19 +167,15 @@ func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) { t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String()) } outputText, _ := out["output_text"].(string) - if outputText != "" { - t.Fatalf("expected output_text hidden for tool call payload, got %q", outputText) + if outputText == "" { + t.Fatalf("expected output_text preserved for mixed prose payload") } output, _ := out["output"].([]any) - hasFunctionCall := false - for _, item := range output { - m, _ := item.(map[string]any) - if m != nil && m["type"] == "function_call" { - hasFunctionCall = true - break - } + if len(output) != 1 { + t.Fatalf("expected one output item, got %#v", output) } - if !hasFunctionCall { - t.Fatalf("expected function_call output item, got %#v", output) + first, _ := output[0].(map[string]any) + if first["type"] != "message" { + t.Fatalf("expected message output item, got %#v", output) } } diff --git a/internal/adapter/openai/tool_sieve_core.go b/internal/adapter/openai/tool_sieve_core.go index 5ed9b90..fd0261d 100644 --- a/internal/adapter/openai/tool_sieve_core.go +++ b/internal/adapter/openai/tool_sieve_core.go @@ -14,6 +14,21 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames state.pending.WriteString(chunk) } events := make([]toolStreamEvent, 0, 2) + if len(state.pendingToolCalls) > 0 { + pending := state.pending.String() + if strings.TrimSpace(pending) != "" { + content := state.pendingToolRaw + pending + state.pending.Reset() + state.pendingToolRaw = "" + state.pendingToolCalls = nil + state.noteText(content) + events = append(events, toolStreamEvent{Content: content}) + } else { + // Wait for either more non-whitespace content (demote to plain text) + // or stream flush (promote to executable tool calls). + return events + } + } for { if state.capturing { @@ -21,32 +36,23 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames state.capture.WriteString(state.pending.String()) state.pending.Reset() } - if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 { - events = append(events, toolStreamEvent{ToolCallDeltas: deltas}) - } prefix, calls, suffix, ready := consumeToolCapture(state, toolNames) if !ready { - if state.capture.Len() > toolSieveCaptureLimit { - content := state.capture.String() - state.capture.Reset() - state.capturing = false - state.resetIncrementalToolState() - state.noteText(content) - events = append(events, toolStreamEvent{Content: content}) - continue - } break } + captured := state.capture.String() state.capture.Reset() state.capturing = false state.resetIncrementalToolState() + if len(calls) > 0 { + state.pendingToolRaw = captured + state.pendingToolCalls = calls + continue + } if prefix != "" { state.noteText(prefix) events = append(events, toolStreamEvent{Content: prefix}) } - if len(calls) > 0 { - events = append(events, toolStreamEvent{ToolCalls: calls}) - } if suffix != "" { state.pending.WriteString(suffix) } @@ -89,6 +95,11 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea return nil } events := processToolSieveChunk(state, "", toolNames) + if len(state.pendingToolCalls) > 0 { + events = append(events, toolStreamEvent{ToolCalls: state.pendingToolCalls}) + state.pendingToolRaw = "" + state.pendingToolCalls = nil + } if state.capturing { consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames) if ready { @@ -200,6 +211,11 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix if insideCodeFence(state.recentTextTail + prefixPart) { return captured, nil, "", true } + // Strict mode: only standalone tool payloads are executable. If the + // payload is wrapped by non-whitespace prose, keep it as plain text. + if strings.TrimSpace(state.recentTextTail) != "" || strings.TrimSpace(prefixPart) != "" || strings.TrimSpace(suffixPart) != "" { + return captured, nil, "", true + } parsed := util.ParseStandaloneToolCallsDetailed(obj, toolNames) if len(parsed.Calls) == 0 { if parsed.SawToolCallSyntax && parsed.RejectedByPolicy { diff --git a/internal/adapter/openai/tool_sieve_state.go b/internal/adapter/openai/tool_sieve_state.go index 04699e6..1db9413 100644 --- a/internal/adapter/openai/tool_sieve_state.go +++ b/internal/adapter/openai/tool_sieve_state.go @@ -7,17 +7,19 @@ import ( ) type toolStreamSieveState struct { - pending strings.Builder - capture strings.Builder - capturing bool - recentTextTail string - disableDeltas bool - toolNameSent bool - toolName string - toolArgsStart int - toolArgsSent int - toolArgsString bool - toolArgsDone bool + pending strings.Builder + capture strings.Builder + capturing bool + recentTextTail string + pendingToolRaw string + pendingToolCalls []util.ParsedToolCall + disableDeltas bool + toolNameSent bool + toolName string + toolArgsStart int + toolArgsSent int + toolArgsString bool + toolArgsDone bool } type toolStreamEvent struct { @@ -32,7 +34,6 @@ type toolCallDelta struct { Arguments string } -const toolSieveCaptureLimit = 8 * 1024 const toolSieveContextTailLimit = 256 func (s *toolStreamSieveState) resetIncrementalToolState() { diff --git a/internal/admin/handler_accounts_crud.go b/internal/admin/handler_accounts_crud.go index 768e59e..6536760 100644 --- a/internal/admin/handler_accounts_crud.go +++ b/internal/admin/handler_accounts_crud.go @@ -1,128 +1,133 @@ -package admin - -import ( - "encoding/json" - "fmt" - "net/http" - "strings" - - "github.com/go-chi/chi/v5" - - "ds2api/internal/config" -) - -func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) { - page := intFromQuery(r, "page", 1) - pageSize := intFromQuery(r, "page_size", 10) - if page < 1 { - page = 1 - } - if pageSize < 1 { - pageSize = 1 - } - if pageSize > 100 { - pageSize = 100 - } - accounts := h.Store.Snapshot().Accounts - reverseAccounts(accounts) - q := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("q"))) - if q != "" { - filtered := make([]config.Account, 0, len(accounts)) - for _, acc := range accounts { - id := strings.ToLower(acc.Identifier()) - if strings.Contains(id, q) || - strings.Contains(strings.ToLower(acc.Email), q) || - strings.Contains(strings.ToLower(acc.Mobile), q) { - filtered = append(filtered, acc) - } - } - accounts = filtered - } - total := len(accounts) - totalPages := 1 - if total > 0 { - totalPages = (total + pageSize - 1) / pageSize - } - start := (page - 1) * pageSize - if start > total { - start = total - } - end := start + pageSize - if end > total { - end = total - } - items := make([]map[string]any, 0, end-start) - for _, acc := range accounts[start:end] { - token := strings.TrimSpace(acc.Token) - preview := "" - if token != "" { - if len(token) > 20 { - preview = token[:20] + "..." - } else { - preview = token - } - } - items = append(items, map[string]any{ - "identifier": acc.Identifier(), - "email": acc.Email, - "mobile": acc.Mobile, - "has_password": acc.Password != "", - "has_token": token != "", - "token_preview": preview, - "test_status": acc.TestStatus, - }) - } - writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages}) -} - -func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) { - var req map[string]any - _ = json.NewDecoder(r.Body).Decode(&req) - acc := toAccount(req) - if acc.Identifier() == "" { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"}) - return - } - err := h.Store.Update(func(c *config.Config) error { - for _, a := range c.Accounts { - if acc.Email != "" && a.Email == acc.Email { - return fmt.Errorf("邮箱已存在") - } - if acc.Mobile != "" && a.Mobile == acc.Mobile { - return fmt.Errorf("手机号已存在") - } - } - c.Accounts = append(c.Accounts, acc) - return nil - }) - if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) - return - } - h.Pool.Reset() - writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)}) -} - -func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) { - identifier := chi.URLParam(r, "identifier") - err := h.Store.Update(func(c *config.Config) error { - idx := -1 - for i, a := range c.Accounts { - if accountMatchesIdentifier(a, identifier) { - idx = i - break - } - } - if idx < 0 { - return fmt.Errorf("账号不存在") - } - c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...) - return nil - }) - if err != nil { - writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()}) - return - } - h.Pool.Reset() - writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)}) -} +package admin + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/config" +) + +func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) { + page := intFromQuery(r, "page", 1) + pageSize := intFromQuery(r, "page_size", 10) + if page < 1 { + page = 1 + } + if pageSize < 1 { + pageSize = 1 + } + if pageSize > 100 { + pageSize = 100 + } + accounts := h.Store.Snapshot().Accounts + reverseAccounts(accounts) + q := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("q"))) + if q != "" { + filtered := make([]config.Account, 0, len(accounts)) + for _, acc := range accounts { + id := strings.ToLower(acc.Identifier()) + if strings.Contains(id, q) || + strings.Contains(strings.ToLower(acc.Email), q) || + strings.Contains(strings.ToLower(acc.Mobile), q) { + filtered = append(filtered, acc) + } + } + accounts = filtered + } + total := len(accounts) + totalPages := 1 + if total > 0 { + totalPages = (total + pageSize - 1) / pageSize + } + start := (page - 1) * pageSize + if start > total { + start = total + } + end := start + pageSize + if end > total { + end = total + } + items := make([]map[string]any, 0, end-start) + for _, acc := range accounts[start:end] { + token := strings.TrimSpace(acc.Token) + preview := "" + if token != "" { + if len(token) > 20 { + preview = token[:20] + "..." + } else { + preview = token + } + } + items = append(items, map[string]any{ + "identifier": acc.Identifier(), + "email": acc.Email, + "mobile": acc.Mobile, + "has_password": acc.Password != "", + "has_token": token != "", + "token_preview": preview, + "test_status": acc.TestStatus, + }) + } + writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages}) +} + +func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) { + var req map[string]any + _ = json.NewDecoder(r.Body).Decode(&req) + acc := toAccount(req) + if acc.Identifier() == "" { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"}) + return + } + err := h.Store.Update(func(c *config.Config) error { + mobileKey := config.CanonicalMobileKey(acc.Mobile) + for _, a := range c.Accounts { + if acc.Email != "" && a.Email == acc.Email { + return fmt.Errorf("邮箱已存在") + } + if mobileKey != "" && config.CanonicalMobileKey(a.Mobile) == mobileKey { + return fmt.Errorf("手机号已存在") + } + } + c.Accounts = append(c.Accounts, acc) + return nil + }) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + h.Pool.Reset() + writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)}) +} + +func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) { + identifier := chi.URLParam(r, "identifier") + if decoded, err := url.PathUnescape(identifier); err == nil { + identifier = decoded + } + err := h.Store.Update(func(c *config.Config) error { + idx := -1 + for i, a := range c.Accounts { + if accountMatchesIdentifier(a, identifier) { + idx = i + break + } + } + if idx < 0 { + return fmt.Errorf("账号不存在") + } + c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...) + return nil + }) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()}) + return + } + h.Pool.Reset() + writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)}) +} diff --git a/internal/admin/handler_accounts_identifier_test.go b/internal/admin/handler_accounts_identifier_test.go index 591d43a..b6f63ca 100644 --- a/internal/admin/handler_accounts_identifier_test.go +++ b/internal/admin/handler_accounts_identifier_test.go @@ -1,6 +1,7 @@ package admin import ( + "bytes" "encoding/json" "net/http" "net/http/httptest" @@ -102,6 +103,45 @@ func TestDeleteAccountSupportsMobileAlias(t *testing.T) { } } +func TestDeleteAccountSupportsEncodedPlusMobile(t *testing.T) { + h := newAdminTestHandler(t, `{ + "accounts":[{"mobile":"+8613800138000","password":"pwd"}] + }`) + + r := chi.NewRouter() + r.Delete("/admin/accounts/{identifier}", h.deleteAccount) + req := httptest.NewRequest(http.MethodDelete, "/admin/accounts/"+url.PathEscape("+8613800138000"), nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String()) + } + if got := len(h.Store.Accounts()); got != 0 { + t.Fatalf("expected account removed, remaining=%d", got) + } +} + +func TestAddAccountRejectsCanonicalMobileDuplicate(t *testing.T) { + h := newAdminTestHandler(t, `{ + "accounts":[{"mobile":"+8613800138000","password":"pwd"}] + }`) + + r := chi.NewRouter() + r.Post("/admin/accounts", h.addAccount) + body := []byte(`{"mobile":"13800138000","password":"pwd2"}`) + req := httptest.NewRequest(http.MethodPost, "/admin/accounts", bytes.NewReader(body)) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String()) + } + if got := len(h.Store.Accounts()); got != 1 { + t.Fatalf("expected no duplicate insert, got=%d", got) + } +} + func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) { h := newAdminTestHandler(t, `{ "accounts":[ @@ -117,6 +157,13 @@ func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) { if accByMobile.Email != "u@example.com" { t.Fatalf("unexpected account by mobile: %#v", accByMobile) } + accByMobileWithCountryCode, ok := findAccountByIdentifier(h.Store, "+8613800138000") + if !ok { + t.Fatal("expected find by +86 mobile") + } + if accByMobileWithCountryCode.Email != "u@example.com" { + t.Fatalf("unexpected account by +86 mobile: %#v", accByMobileWithCountryCode) + } tokenOnlyID := "" for _, acc := range h.Store.Accounts() { diff --git a/internal/admin/handler_config_import.go b/internal/admin/handler_config_import.go index 674d8b2..2b88d45 100644 --- a/internal/admin/handler_config_import.go +++ b/internal/admin/handler_config_import.go @@ -49,6 +49,7 @@ func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) { next := c.Clone() if mode == "replace" { next = incoming.Clone() + next.Accounts = normalizeAndDedupeAccounts(next.Accounts) next.VercelSyncHash = c.VercelSyncHash next.VercelSyncTime = c.VercelSyncTime importedKeys = len(next.Keys) @@ -73,17 +74,22 @@ func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) { existingAccounts := map[string]struct{}{} for _, acc := range next.Accounts { - existingAccounts[acc.Identifier()] = struct{}{} + acc = normalizeAccountForStorage(acc) + key := accountDedupeKey(acc) + if key != "" { + existingAccounts[key] = struct{}{} + } } for _, acc := range incoming.Accounts { - id := acc.Identifier() - if id == "" { + acc = normalizeAccountForStorage(acc) + key := accountDedupeKey(acc) + if key == "" { continue } - if _, ok := existingAccounts[id]; ok { + if _, ok := existingAccounts[key]; ok { continue } - existingAccounts[id] = struct{}{} + existingAccounts[key] = struct{}{} next.Accounts = append(next.Accounts, acc) importedAccounts++ } diff --git a/internal/admin/handler_config_write.go b/internal/admin/handler_config_write.go index 792e696..e09edfe 100644 --- a/internal/admin/handler_config_write.go +++ b/internal/admin/handler_config_write.go @@ -25,17 +25,28 @@ func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) { if accountsRaw, ok := req["accounts"].([]any); ok { existing := map[string]config.Account{} for _, a := range old.Accounts { - existing[a.Identifier()] = a + a = normalizeAccountForStorage(a) + key := accountDedupeKey(a) + if key != "" { + existing[key] = a + } } + seen := map[string]struct{}{} accounts := make([]config.Account, 0, len(accountsRaw)) for _, item := range accountsRaw { m, ok := item.(map[string]any) if !ok { continue } - acc := toAccount(m) - id := acc.Identifier() - if prev, ok := existing[id]; ok { + acc := normalizeAccountForStorage(toAccount(m)) + key := accountDedupeKey(acc) + if key == "" { + continue + } + if _, ok := seen[key]; ok { + continue + } + if prev, ok := existing[key]; ok { if strings.TrimSpace(acc.Password) == "" { acc.Password = prev.Password } @@ -43,6 +54,7 @@ func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) { acc.Token = prev.Token } } + seen[key] = struct{}{} accounts = append(accounts, acc) } c.Accounts = accounts @@ -138,20 +150,24 @@ func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) { if accounts, ok := req["accounts"].([]any); ok { existing := map[string]bool{} for _, a := range c.Accounts { - existing[a.Identifier()] = true + a = normalizeAccountForStorage(a) + key := accountDedupeKey(a) + if key != "" { + existing[key] = true + } } for _, item := range accounts { m, ok := item.(map[string]any) if !ok { continue } - acc := toAccount(m) - id := acc.Identifier() - if id == "" || existing[id] { + acc := normalizeAccountForStorage(toAccount(m)) + key := accountDedupeKey(acc) + if key == "" || existing[key] { continue } c.Accounts = append(c.Accounts, acc) - existing[id] = true + existing[key] = true importedAccounts++ } } diff --git a/internal/admin/handler_settings_test.go b/internal/admin/handler_settings_test.go index 3eb5114..2a606fb 100644 --- a/internal/admin/handler_settings_test.go +++ b/internal/admin/handler_settings_test.go @@ -265,3 +265,57 @@ func TestConfigImportRejectsMergedRuntimeConflict(t *testing.T) { t.Fatalf("runtime should remain unchanged, runtime=%+v", snap.Runtime) } } + +func TestConfigImportMergeDedupesMobileAliases(t *testing.T) { + h := newAdminTestHandler(t, `{ + "keys":["k1"], + "accounts":[{"mobile":"+8613800138000","password":"p1"}] + }`) + + merge := map[string]any{ + "mode": "merge", + "config": map[string]any{ + "accounts": []any{ + map[string]any{"mobile": "13800138000", "password": "p2"}, + }, + }, + } + b, _ := json.Marshal(merge) + req := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=merge", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.configImport(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String()) + } + if got := len(h.Store.Accounts()); got != 1 { + t.Fatalf("expected merge dedupe by canonical mobile, got=%d", got) + } +} + +func TestUpdateConfigDedupesMobileAliases(t *testing.T) { + h := newAdminTestHandler(t, `{ + "keys":["k1"], + "accounts":[{"mobile":"+8613800138000","password":"old"}] + }`) + + reqBody := map[string]any{ + "accounts": []any{ + map[string]any{"mobile": "+8613800138000"}, + map[string]any{"mobile": "13800138000"}, + }, + } + b, _ := json.Marshal(reqBody) + req := httptest.NewRequest(http.MethodPost, "/admin/config", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.updateConfig(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String()) + } + accounts := h.Store.Accounts() + if len(accounts) != 1 { + t.Fatalf("expected update dedupe by canonical mobile, got=%d", len(accounts)) + } + if accounts[0].Identifier() != "+8613800138000" { + t.Fatalf("unexpected identifier: %q", accounts[0].Identifier()) + } +} diff --git a/internal/admin/helpers.go b/internal/admin/helpers.go index 2e00323..af27676 100644 --- a/internal/admin/helpers.go +++ b/internal/admin/helpers.go @@ -59,9 +59,11 @@ func toStringSlice(v any) ([]string, bool) { } func toAccount(m map[string]any) config.Account { + email := fieldString(m, "email") + mobile := config.NormalizeMobileForStorage(fieldString(m, "mobile")) return config.Account{ - Email: fieldString(m, "email"), - Mobile: fieldString(m, "mobile"), + Email: email, + Mobile: mobile, Password: fieldString(m, "password"), Token: fieldString(m, "token"), } @@ -90,12 +92,52 @@ func accountMatchesIdentifier(acc config.Account, identifier string) bool { if strings.TrimSpace(acc.Email) == id { return true } - if strings.TrimSpace(acc.Mobile) == id { + if mobileKey := config.CanonicalMobileKey(id); mobileKey != "" && mobileKey == config.CanonicalMobileKey(acc.Mobile) { return true } return acc.Identifier() == id } +func normalizeAccountForStorage(acc config.Account) config.Account { + acc.Email = strings.TrimSpace(acc.Email) + acc.Mobile = config.NormalizeMobileForStorage(acc.Mobile) + return acc +} + +func accountDedupeKey(acc config.Account) string { + if email := strings.TrimSpace(acc.Email); email != "" { + return "email:" + email + } + if mobile := config.CanonicalMobileKey(acc.Mobile); mobile != "" { + return "mobile:" + mobile + } + if id := strings.TrimSpace(acc.Identifier()); id != "" { + return "id:" + id + } + return "" +} + +func normalizeAndDedupeAccounts(accounts []config.Account) []config.Account { + if len(accounts) == 0 { + return nil + } + out := make([]config.Account, 0, len(accounts)) + seen := make(map[string]struct{}, len(accounts)) + for _, acc := range accounts { + acc = normalizeAccountForStorage(acc) + key := accountDedupeKey(acc) + if key == "" { + continue + } + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + out = append(out, acc) + } + return out +} + func findAccountByIdentifier(store ConfigStore, identifier string) (config.Account, bool) { id := strings.TrimSpace(identifier) if id == "" { diff --git a/internal/admin/helpers_edge_test.go b/internal/admin/helpers_edge_test.go index 2a0bf20..0b2a0ab 100644 --- a/internal/admin/helpers_edge_test.go +++ b/internal/admin/helpers_edge_test.go @@ -182,7 +182,7 @@ func TestToAccountAllFields(t *testing.T) { if acc.Email != "user@test.com" { t.Fatalf("unexpected email: %q", acc.Email) } - if acc.Mobile != "13800138000" { + if acc.Mobile != "+8613800138000" { t.Fatalf("unexpected mobile: %q", acc.Mobile) } if acc.Password != "secret" { diff --git a/internal/config/account.go b/internal/config/account.go index 29a4947..3d6fa7d 100644 --- a/internal/config/account.go +++ b/internal/config/account.go @@ -10,8 +10,8 @@ func (a Account) Identifier() string { if strings.TrimSpace(a.Email) != "" { return strings.TrimSpace(a.Email) } - if strings.TrimSpace(a.Mobile) != "" { - return strings.TrimSpace(a.Mobile) + if mobile := NormalizeMobileForStorage(a.Mobile); mobile != "" { + return mobile } // Backward compatibility: old configs may contain token-only accounts. // Use a stable non-sensitive synthetic id so they can still join the pool. diff --git a/internal/config/config_edge_test.go b/internal/config/config_edge_test.go index 1138867..8a969df 100644 --- a/internal/config/config_edge_test.go +++ b/internal/config/config_edge_test.go @@ -202,7 +202,7 @@ func TestConfigCloneNilMaps(t *testing.T) { func TestAccountIdentifierPreferenceMobileOverToken(t *testing.T) { acc := Account{Mobile: "13800138000", Token: "tok"} - if acc.Identifier() != "13800138000" { + if acc.Identifier() != "+8613800138000" { t.Fatalf("expected mobile identifier, got %q", acc.Identifier()) } } diff --git a/internal/config/mobile.go b/internal/config/mobile.go new file mode 100644 index 0000000..7e2158b --- /dev/null +++ b/internal/config/mobile.go @@ -0,0 +1,82 @@ +package config + +import "strings" + +// NormalizeMobileForStorage normalizes user input to a stable storage format. +// It keeps existing country codes and auto-prefixes mainland China numbers with +86. +func NormalizeMobileForStorage(raw string) string { + digits, hasPlus := extractMobileDigits(raw) + if digits == "" { + return "" + } + if hasPlus { + return "+" + digits + } + if isChinaMobileWithCountryCode(digits) { + return "+86" + digits[2:] + } + if isChinaMainlandMobileDigits(digits) { + return "+86" + digits + } + // For non-China numbers without a leading +, preserve semantics by adding it. + return "+" + digits +} + +// CanonicalMobileKey returns the comparison key used by dedupe/matching logic. +func CanonicalMobileKey(raw string) string { + return NormalizeMobileForStorage(raw) +} + +func extractMobileDigits(raw string) (digits string, hasPlus bool) { + s := strings.TrimSpace(raw) + if s == "" { + return "", false + } + + for _, r := range s { + switch { + case r >= '0' && r <= '9': + goto collect + case isMobileSeparator(r): + continue + case r == '+': + hasPlus = true + goto collect + default: + goto collect + } + } + +collect: + var b strings.Builder + b.Grow(len(s)) + for _, r := range s { + if r >= '0' && r <= '9' { + b.WriteRune(r) + } + } + return b.String(), hasPlus +} + +func isChinaMainlandMobileDigits(digits string) bool { + if len(digits) != 11 || digits[0] != '1' { + return false + } + return digits[1] >= '3' && digits[1] <= '9' +} + +func isChinaMobileWithCountryCode(digits string) bool { + if len(digits) != 13 || !strings.HasPrefix(digits, "86") { + return false + } + return isChinaMainlandMobileDigits(digits[2:]) +} + +func isMobileSeparator(r rune) bool { + switch r { + case ' ', '\t', '\n', '\r', '-', '(', ')', '.', '/': + return true + default: + return false + } +} diff --git a/internal/config/mobile_test.go b/internal/config/mobile_test.go new file mode 100644 index 0000000..96a98b6 --- /dev/null +++ b/internal/config/mobile_test.go @@ -0,0 +1,36 @@ +package config + +import "testing" + +func TestNormalizeMobileForStorageChinaMainlandAddsPlus86(t *testing.T) { + if got := NormalizeMobileForStorage("13800138000"); got != "+8613800138000" { + t.Fatalf("got %q", got) + } +} + +func TestNormalizeMobileForStorageChinaWithCountryCode(t *testing.T) { + if got := NormalizeMobileForStorage("8613800138000"); got != "+8613800138000" { + t.Fatalf("got %q", got) + } +} + +func TestNormalizeMobileForStorageKeepsExistingCountryCode(t *testing.T) { + if got := NormalizeMobileForStorage(" +1 (415) 555-2671 "); got != "+14155552671" { + t.Fatalf("got %q", got) + } +} + +func TestCanonicalMobileKeyMatchesChinaAliases(t *testing.T) { + a := CanonicalMobileKey("+8613800138000") + b := CanonicalMobileKey("13800138000") + c := CanonicalMobileKey("86 13800138000") + if a == "" || a != b || b != c { + t.Fatalf("alias mismatch: a=%q b=%q c=%q", a, b, c) + } +} + +func TestCanonicalMobileKeyEmptyForInvalidInput(t *testing.T) { + if got := CanonicalMobileKey("() --"); got != "" { + t.Fatalf("got %q", got) + } +} diff --git a/internal/format/openai/render_chat.go b/internal/format/openai/render_chat.go index 1e58fbd..181e8b9 100644 --- a/internal/format/openai/render_chat.go +++ b/internal/format/openai/render_chat.go @@ -8,7 +8,7 @@ import ( ) func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { - detected := util.ParseToolCalls(finalText, toolNames) + detected := util.ParseStandaloneToolCalls(finalText, toolNames) finishReason := "stop" messageObj := map[string]any{"role": "assistant", "content": finalText} if strings.TrimSpace(finalThinking) != "" { diff --git a/internal/format/openai/render_responses.go b/internal/format/openai/render_responses.go index f55ee9f..21df584 100644 --- a/internal/format/openai/render_responses.go +++ b/internal/format/openai/render_responses.go @@ -11,12 +11,9 @@ import ( ) func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { - // Align responses tool-call semantics with chat/completions: - // mixed prose + tool_call payloads should still be interpreted as tool calls. - detected := util.ParseToolCalls(finalText, toolNames) - if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" { - detected = util.ParseToolCalls(finalThinking, toolNames) - } + // Strict mode: only standalone, structured tool-call payloads are treated + // as executable tool calls. + detected := util.ParseStandaloneToolCalls(finalText, toolNames) exposedOutputText := finalText output := make([]any, 0, 2) if len(detected) > 0 { diff --git a/internal/format/openai/render_test.go b/internal/format/openai/render_test.go index df792ed..7a9d897 100644 --- a/internal/format/openai/render_test.go +++ b/internal/format/openai/render_test.go @@ -45,7 +45,7 @@ func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) { } } -func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T) { +func TestBuildResponseObjectTreatsMixedProseToolPayloadAsText(t *testing.T) { obj := BuildResponseObject( "resp_test", "gpt-4o", @@ -56,17 +56,16 @@ func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T) ) outputText, _ := obj["output_text"].(string) - if outputText != "" { - t.Fatalf("expected output_text hidden once tool calls are detected, got %q", outputText) + if outputText == "" { + t.Fatalf("expected output_text preserved for mixed prose payload") } - output, _ := obj["output"].([]any) if len(output) != 1 { - t.Fatalf("expected function_call output only, got %#v", obj["output"]) + t.Fatalf("expected one message output item, got %#v", obj["output"]) } first, _ := output[0].(map[string]any) - if first["type"] != "function_call" { - t.Fatalf("expected first output type function_call, got %#v", first["type"]) + if first["type"] != "message" { + t.Fatalf("expected message output type, got %#v", first["type"]) } } @@ -127,7 +126,7 @@ func TestBuildResponseObjectReasoningOnlyFallsBackToOutputText(t *testing.T) { } } -func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) { +func TestBuildResponseObjectIgnoresToolCallFromThinkingChannel(t *testing.T) { obj := BuildResponseObject( "resp_test", "gpt-4o", @@ -139,10 +138,10 @@ func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) { output, _ := obj["output"].([]any) if len(output) != 1 { - t.Fatalf("expected function_call output only, got %#v", obj["output"]) + t.Fatalf("expected one message output item, got %#v", obj["output"]) } first, _ := output[0].(map[string]any) - if first["type"] != "function_call" { - t.Fatalf("expected output function_call, got %#v", first["type"]) + if first["type"] != "message" { + t.Fatalf("expected output message, got %#v", first["type"]) } } diff --git a/internal/js/helpers/stream-tool-sieve/sieve.js b/internal/js/helpers/stream-tool-sieve/sieve.js index 699c3a8..0abe507 100644 --- a/internal/js/helpers/stream-tool-sieve/sieve.js +++ b/internal/js/helpers/stream-tool-sieve/sieve.js @@ -1,7 +1,6 @@ 'use strict'; const { - TOOL_SIEVE_CAPTURE_LIMIT, resetIncrementalToolState, noteText, insideCodeFence, @@ -37,14 +36,6 @@ function processToolSieveChunk(state, chunk, toolNames) { } const consumed = consumeToolCapture(state, toolNames); if (!consumed.ready) { - if (state.capture.length > TOOL_SIEVE_CAPTURE_LIMIT) { - noteText(state, state.capture); - events.push({ type: 'text', text: state.capture }); - state.capture = ''; - state.capturing = false; - resetIncrementalToolState(state); - continue; - } break; } state.capture = ''; diff --git a/internal/js/helpers/stream-tool-sieve/state.js b/internal/js/helpers/stream-tool-sieve/state.js index a2d2b5c..ff588e2 100644 --- a/internal/js/helpers/stream-tool-sieve/state.js +++ b/internal/js/helpers/stream-tool-sieve/state.js @@ -1,6 +1,5 @@ 'use strict'; -const TOOL_SIEVE_CAPTURE_LIMIT = 8 * 1024; const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 256; function createToolSieveState() { @@ -78,7 +77,6 @@ function toStringSafe(v) { } module.exports = { - TOOL_SIEVE_CAPTURE_LIMIT, TOOL_SIEVE_CONTEXT_TAIL_LIMIT, createToolSieveState, resetIncrementalToolState, diff --git a/internal/server/router.go b/internal/server/router.go index ae3108e..6672ad6 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -57,16 +57,20 @@ func NewApp() *App { r.Use(cors) r.Use(timeout(0)) - r.Get("/healthz", func(w http.ResponseWriter, _ *http.Request) { + healthzHandler := func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{"status":"ok"}`)) - }) - r.Get("/readyz", func(w http.ResponseWriter, _ *http.Request) { + } + readyzHandler := func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{"status":"ready"}`)) - }) + } + r.Get("/healthz", healthzHandler) + r.Head("/healthz", healthzHandler) + r.Get("/readyz", readyzHandler) + r.Head("/readyz", readyzHandler) openai.RegisterRoutes(r, openaiHandler) claude.RegisterRoutes(r, claudeHandler) gemini.RegisterRoutes(r, geminiHandler) diff --git a/internal/server/router_health_test.go b/internal/server/router_health_test.go new file mode 100644 index 0000000..0f744dd --- /dev/null +++ b/internal/server/router_health_test.go @@ -0,0 +1,20 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestHealthEndpointsSupportHEAD(t *testing.T) { + app := NewApp() + + for _, path := range []string{"/healthz", "/readyz"} { + req := httptest.NewRequest(http.MethodHead, path, nil) + rec := httptest.NewRecorder() + app.Router.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected %s HEAD status 200, got %d", path, rec.Code) + } + } +} diff --git a/internal/testsuite/runner_cases_openai.go b/internal/testsuite/runner_cases_openai.go index 4ca2e40..6de3fd9 100644 --- a/internal/testsuite/runner_cases_openai.go +++ b/internal/testsuite/runner_cases_openai.go @@ -17,6 +17,12 @@ func (r *Runner) caseHealthz(ctx context.Context, cc *caseContext) error { var m map[string]any _ = json.Unmarshal(resp.Body, &m) cc.assert("status_ok", asString(m["status"]) == "ok", fmt.Sprintf("body=%s", string(resp.Body))) + + headResp, headErr := cc.request(ctx, requestSpec{Method: http.MethodHead, Path: "/healthz", Retryable: true}) + if headErr != nil { + return headErr + } + cc.assert("head_status_200", headResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", headResp.StatusCode)) return nil } @@ -29,6 +35,12 @@ func (r *Runner) caseReadyz(ctx context.Context, cc *caseContext) error { var m map[string]any _ = json.Unmarshal(resp.Body, &m) cc.assert("status_ready", asString(m["status"]) == "ready", fmt.Sprintf("body=%s", string(resp.Body))) + + headResp, headErr := cc.request(ctx, requestSpec{Method: http.MethodHead, Path: "/readyz", Retryable: true}) + if headErr != nil { + return headErr + } + cc.assert("head_status_200", headResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", headResp.StatusCode)) return nil } diff --git a/tests/node/stream-tool-sieve.test.js b/tests/node/stream-tool-sieve.test.js index f20cb11..ccbd160 100644 --- a/tests/node/stream-tool-sieve.test.js +++ b/tests/node/stream-tool-sieve.test.js @@ -141,6 +141,20 @@ test('sieve flushes incomplete captured tool json as text on stream finalize', ( assert.equal(leakedText.includes('{'), true); }); +test('sieve still intercepts large tool json payloads over previous capture limit', () => { + const large = 'a'.repeat(9000); + const payload = `{"tool_calls":[{"name":"read_file","input":{"path":"${large}"}}]}`; + const events = runSieve( + [payload.slice(0, 3000), payload.slice(3000, 7000), payload.slice(7000)], + ['read_file'], + ); + const leakedText = collectText(events); + const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && evt.calls?.length > 0); + const hasToolDelta = events.some((evt) => evt.type === 'tool_call_deltas' && evt.deltas?.length > 0); + assert.equal(hasToolCall || hasToolDelta, true); + assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); +}); + test('sieve keeps plain text intact in tool mode when no tool call appears', () => { const events = runSieve( ['你好,', '这是普通文本回复。', '请继续。'], diff --git a/webui/index.html b/webui/index.html index 556ed1e..370d1f1 100644 --- a/webui/index.html +++ b/webui/index.html @@ -24,9 +24,8 @@ - - + + diff --git a/webui/public/ds2api-favicon.svg b/webui/public/ds2api-favicon.svg new file mode 100644 index 0000000..feb9dcd --- /dev/null +++ b/webui/public/ds2api-favicon.svg @@ -0,0 +1,20 @@ + + + + + + + + + + DS + +