diff --git a/VERSION b/VERSION index bc4abe8..3f5e730 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.3.8 +2.3.8 \ No newline at end of file diff --git a/internal/adapter/openai/chat_stream_runtime.go b/internal/adapter/openai/chat_stream_runtime.go index 1a81660..03ed429 100644 --- a/internal/adapter/openai/chat_stream_runtime.go +++ b/internal/adapter/openai/chat_stream_runtime.go @@ -97,7 +97,7 @@ func (s *chatStreamRuntime) sendDone() { func (s *chatStreamRuntime) finalize(finishReason string) { finalThinking := s.thinking.String() - finalText := s.text.String() + finalText := sanitizeLeakedToolHistory(s.text.String()) detected := util.ParseStandaloneToolCallsDetailed(finalText, s.toolNames) if len(detected.Calls) > 0 && !s.toolCallsDoneEmitted { finishReason = "tool_calls" @@ -141,8 +141,12 @@ func (s *chatStreamRuntime) finalize(finishReason string) { if evt.Content == "" { continue } + cleaned := sanitizeLeakedToolHistory(evt.Content) + if cleaned == "" { + continue + } delta := map[string]any{ - "content": evt.Content, + "content": cleaned, } if !s.firstChunkSent { delta["role"] = "assistant" @@ -246,8 +250,12 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD continue } if evt.Content != "" { + cleaned := sanitizeLeakedToolHistory(evt.Content) + if cleaned == "" { + continue + } contentDelta := map[string]any{ - "content": evt.Content, + "content": cleaned, } if !s.firstChunkSent { contentDelta["role"] = "assistant" diff --git a/internal/adapter/openai/handler_chat.go b/internal/adapter/openai/handler_chat.go index c514e36..27ef187 100644 --- a/internal/adapter/openai/handler_chat.go +++ b/internal/adapter/openai/handler_chat.go @@ -105,7 +105,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re result := sse.CollectStream(resp, thinkingEnabled, true) finalThinking := result.Thinking - finalText := result.Text + finalText := sanitizeLeakedToolHistory(result.Text) respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames) writeJSON(w, http.StatusOK, respBody) } diff --git a/internal/adapter/openai/handler_toolcall_format.go b/internal/adapter/openai/handler_toolcall_format.go index 7f2a340..6ddced4 100644 --- a/internal/adapter/openai/handler_toolcall_format.go +++ b/internal/adapter/openai/handler_toolcall_format.go @@ -53,7 +53,7 @@ func injectToolPrompt(messages []map[string]any, tools []any, policy util.ToolCh if len(toolSchemas) == 0 { return messages, names } - toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY a JSON code block like this:\n```json\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n```\n\n【EXAMPLE】\nUser: Please check the weather in Beijing and Shanghai, and update my todo list.\nAssistant:\n```json\n{\"tool_calls\": [\n {\"name\": \"get_weather\", \"input\": {\"city\": \"Beijing\"}},\n {\"name\": \"get_weather\", \"input\": {\"city\": \"Shanghai\"}},\n {\"name\": \"update_todo\", \"input\": {\"todos\": [{\"content\": \"Buy milk\"}, {\"content\": \"Write report\"}]}}\n]}\n```\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON code block. The response must start with ```json and end with ```.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block.\n5) Never output [TOOL_CALL_HISTORY] or [TOOL_RESULT_HISTORY] markers in your answer; these markers are system-side context only.\n6) JSON SYNTAX STRICTLY REQUIRED: All property names MUST be enclosed in double quotes (e.g., \"name\", not name).\n7) ARRAY FORMAT: If providing a list of items, you MUST enclose them in square brackets `[]` (e.g., \"todos\": [{\"item\": \"a\"}, {\"item\": \"b\"}]). DO NOT output comma-separated objects without brackets." + toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY a JSON code block like this:\n```json\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n```\n\n【EXAMPLE】\nUser: Please check the weather in Beijing and Shanghai, and update my todo list.\nAssistant:\n```json\n{\"tool_calls\": [\n {\"name\": \"get_weather\", \"input\": {\"city\": \"Beijing\"}},\n {\"name\": \"get_weather\", \"input\": {\"city\": \"Shanghai\"}},\n {\"name\": \"update_todo\", \"input\": {\"todos\": [{\"content\": \"Buy milk\"}, {\"content\": \"Write report\"}]}}\n]}\n```\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON code block. The response must start with ```json and end with ```.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) JSON SYNTAX STRICTLY REQUIRED: All property names MUST be enclosed in double quotes (e.g., \"name\", not name).\n5) ARRAY FORMAT: If providing a list of items, you MUST enclose them in square brackets `[]` (e.g., \"todos\": [{\"item\": \"a\"}, {\"item\": \"b\"}]). DO NOT output comma-separated objects without brackets." if policy.Mode == util.ToolChoiceRequired { toolPrompt += "\n5) For this response, you MUST call at least one tool from the allowed list." } diff --git a/internal/adapter/openai/message_normalize.go b/internal/adapter/openai/message_normalize.go index c4f4c4a..a831599 100644 --- a/internal/adapter/openai/message_normalize.go +++ b/internal/adapter/openai/message_normalize.go @@ -2,14 +2,13 @@ package openai import ( "encoding/json" - "fmt" "strings" - "ds2api/internal/config" "ds2api/internal/prompt" ) func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any { + _ = traceID out := make([]map[string]any, 0, len(raw)) for _, item := range raw { msg, ok := item.(map[string]any) @@ -20,19 +19,21 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an switch role { case "assistant": content := normalizeOpenAIContentForPrompt(msg["content"]) - toolCalls := formatAssistantToolCallsForPrompt(msg, traceID) - combined := joinNonEmpty(content, toolCalls) - if combined == "" { + if content == "" { continue } out = append(out, map[string]any{ "role": "assistant", - "content": combined, + "content": content, }) case "tool", "function": + content := normalizeOpenAIContentForPrompt(msg["content"]) + if content == "" { + content = "null" + } out = append(out, map[string]any{ "role": "user", - "content": formatToolResultForPrompt(msg), + "content": content, }) case "user", "system", "developer": out = append(out, map[string]any{ @@ -56,95 +57,10 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an return out } -func formatAssistantToolCallsForPrompt(msg map[string]any, traceID string) string { - entries := make([]string, 0) - if calls, ok := msg["tool_calls"].([]any); ok { - for i, item := range calls { - call, ok := item.(map[string]any) - if !ok { - continue - } - id := strings.TrimSpace(asString(call["id"])) - if id == "" { - id = fmt.Sprintf("call_%d", i+1) - } - name := strings.TrimSpace(asString(call["name"])) - args := "" - - if fn, ok := call["function"].(map[string]any); ok { - if name == "" { - name = strings.TrimSpace(asString(fn["name"])) - } - args = normalizeOpenAIArgumentsForPrompt(fn["arguments"]) - } - if name == "" { - continue - } - if args == "" { - args = normalizeOpenAIArgumentsForPrompt(call["arguments"]) - } - if args == "" { - args = normalizeOpenAIArgumentsForPrompt(call["input"]) - } - if args == "" { - args = "{}" - } - maybeWarnSuspiciousToolHistory(traceID, id, name, args) - entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: %s\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", id, name, args)) - } - } - - if legacy, ok := msg["function_call"].(map[string]any); ok { - name := strings.TrimSpace(asString(legacy["name"])) - if name == "" { - name = "unknown" - } - args := normalizeOpenAIArgumentsForPrompt(legacy["arguments"]) - if args == "" { - args = "{}" - } - maybeWarnSuspiciousToolHistory(traceID, "call_legacy", name, args) - entries = append(entries, fmt.Sprintf("[TOOL_CALL_HISTORY]\nstatus: already_called\norigin: assistant\nnot_user_input: true\ntool_call_id: call_legacy\nfunction.name: %s\nfunction.arguments: %s\n[/TOOL_CALL_HISTORY]", name, args)) - } - - return strings.Join(entries, "\n\n") -} - -func formatToolResultForPrompt(msg map[string]any) string { - toolCallID := strings.TrimSpace(asString(msg["tool_call_id"])) - if toolCallID == "" { - toolCallID = strings.TrimSpace(asString(msg["id"])) - } - if toolCallID == "" { - toolCallID = "unknown" - } - - name := strings.TrimSpace(asString(msg["name"])) - if name == "" { - name = "unknown" - } - - content := normalizeOpenAIContentForPrompt(msg["content"]) - if content == "" { - content = "null" - } - - return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content) -} - func normalizeOpenAIContentForPrompt(v any) string { return prompt.NormalizeContent(v) } -func normalizeOpenAIArgumentsForPrompt(v any) string { - switch x := v.(type) { - case string: - return normalizeToolArgumentString(x) - default: - return marshalToPromptString(v) - } -} - func normalizeToolArgumentString(raw string) string { trimmed := strings.TrimSpace(raw) if trimmed == "" { @@ -157,14 +73,6 @@ func normalizeToolArgumentString(raw string) string { return trimmed } -func marshalToPromptString(v any) string { - b, err := json.Marshal(v) - if err != nil { - return strings.TrimSpace(fmt.Sprintf("%v", v)) - } - return string(b) -} - func normalizeOpenAIRoleForPrompt(role string) string { role = strings.ToLower(strings.TrimSpace(role)) if role == "developer" { @@ -180,34 +88,6 @@ func asString(v any) string { return "" } -func joinNonEmpty(parts ...string) string { - nonEmpty := make([]string, 0, len(parts)) - for _, p := range parts { - if strings.TrimSpace(p) == "" { - continue - } - nonEmpty = append(nonEmpty, p) - } - return strings.Join(nonEmpty, "\n\n") -} - -func maybeWarnSuspiciousToolHistory(traceID, callID, name, args string) { - if !looksLikeConcatenatedJSON(args) { - return - } - traceID = strings.TrimSpace(traceID) - if traceID == "" { - traceID = "unknown" - } - config.Logger.Warn( - "[openai] suspicious tool call history payload detected", - "trace_id", traceID, - "tool_call_id", strings.TrimSpace(callID), - "name", strings.TrimSpace(name), - "arguments_preview", previewToolArgs(args, 160), - ) -} - func looksLikeConcatenatedJSON(raw string) bool { trimmed := strings.TrimSpace(raw) if trimmed == "" { @@ -224,11 +104,3 @@ func looksLikeConcatenatedJSON(raw string) bool { var second any return dec.Decode(&second) == nil } - -func previewToolArgs(raw string, max int) string { - trimmed := strings.TrimSpace(raw) - if max <= 0 || len(trimmed) <= max { - return trimmed - } - return trimmed[:max] -} diff --git a/internal/adapter/openai/message_normalize_test.go b/internal/adapter/openai/message_normalize_test.go index c9c967d..fa17dfe 100644 --- a/internal/adapter/openai/message_normalize_test.go +++ b/internal/adapter/openai/message_normalize_test.go @@ -34,24 +34,20 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *tes } normalized := normalizeOpenAIMessagesForPrompt(raw, "") - if len(normalized) != 4 { - t.Fatalf("expected 4 normalized messages, got %d", len(normalized)) + if len(normalized) != 3 { + t.Fatalf("expected 3 normalized messages, got %d", len(normalized)) } - assistantContent, _ := normalized[2]["content"].(string) - if !strings.Contains(assistantContent, "[TOOL_CALL_HISTORY]") || - !strings.Contains(assistantContent, "tool_call_id: call_1") || - !strings.Contains(assistantContent, "function.name: get_weather") || - !strings.Contains(assistantContent, "function.arguments: {\"city\":\"beijing\"}") { - t.Fatalf("assistant tool call not serialized correctly: %q", assistantContent) + toolContent, _ := normalized[2]["content"].(string) + if !strings.Contains(toolContent, `"temp":18`) { + t.Fatalf("tool result should be transparently forwarded, got %q", toolContent) } - toolContent, _ := normalized[3]["content"].(string) - if !strings.Contains(toolContent, "[TOOL_RESULT_HISTORY]") || !strings.Contains(toolContent, "name: get_weather") { - t.Fatalf("tool result not serialized correctly: %q", toolContent) + if strings.Contains(toolContent, "[TOOL_RESULT_HISTORY]") { + t.Fatalf("tool history marker should not be injected: %q", toolContent) } prompt := util.MessagesPrepare(normalized) - if !strings.Contains(prompt, "tool_call_id: call_1") || !strings.Contains(prompt, "[TOOL_RESULT_HISTORY]") { - t.Fatalf("expected prompt to include tool call + result semantics: %q", prompt) + if strings.Contains(prompt, "[TOOL_CALL_HISTORY]") || strings.Contains(prompt, "[TOOL_RESULT_HISTORY]") { + t.Fatalf("expected no synthetic history markers in prompt: %q", prompt) } } @@ -116,11 +112,38 @@ func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) { t.Fatalf("expected function role mapped to user, got %#v", normalized[0]["role"]) } got, _ := normalized[0]["content"].(string) - if !strings.Contains(got, "name: legacy_tool") || !strings.Contains(got, `"ok":true`) { + if strings.Contains(got, "name: legacy_tool") || !strings.Contains(got, `"ok":true`) { t.Fatalf("unexpected normalized function-role content: %q", got) } } +func TestNormalizeOpenAIMessagesForPrompt_EmptyToolContentPreservedAsNull(t *testing.T) { + raw := []any{ + map[string]any{ + "role": "tool", + "tool_call_id": "call_5", + "name": "noop_tool", + "content": "", + }, + map[string]any{ + "role": "assistant", + "content": "done", + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw, "") + if len(normalized) != 2 { + t.Fatalf("expected tool completion turn to be preserved, got %#v", normalized) + } + if normalized[0]["role"] != "user" { + t.Fatalf("expected tool role mapped to user, got %#v", normalized[0]["role"]) + } + got, _ := normalized[0]["content"].(string) + if got != "null" { + t.Fatalf("expected empty tool content to be preserved as null placeholder, got %q", got) + } +} + func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSeparated(t *testing.T) { raw := []any{ map[string]any{ @@ -147,24 +170,8 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSepara } normalized := normalizeOpenAIMessagesForPrompt(raw, "") - if len(normalized) != 1 { - t.Fatalf("expected one normalized assistant message, got %d", len(normalized)) - } - content, _ := normalized[0]["content"].(string) - if strings.Count(content, "[TOOL_CALL_HISTORY]") != 2 { - t.Fatalf("expected two TOOL_CALL_HISTORY blocks, got %q", content) - } - if !strings.Contains(content, "tool_call_id: call_search") || !strings.Contains(content, "function.name: search_web") { - t.Fatalf("missing first tool call block, got %q", content) - } - if !strings.Contains(content, "tool_call_id: call_eval") || !strings.Contains(content, "function.name: eval_javascript") { - t.Fatalf("missing second tool call block, got %q", content) - } - if strings.Contains(content, "search_webeval_javascript") { - t.Fatalf("unexpected merged function name detected: %q", content) - } - if strings.Contains(content, `}{"`) { - t.Fatalf("unexpected concatenated function arguments detected: %q", content) + if len(normalized) != 0 { + t.Fatalf("expected assistant tool_call-only message to be dropped in passthrough mode, got %#v", normalized) } } @@ -185,16 +192,11 @@ func TestNormalizeOpenAIMessagesForPrompt_PreservesConcatenatedToolArguments(t * } normalized := normalizeOpenAIMessagesForPrompt(raw, "") - if len(normalized) != 1 { - t.Fatalf("expected one normalized message, got %d", len(normalized)) - } - content, _ := normalized[0]["content"].(string) - if !strings.Contains(content, `function.arguments: {}{"query":"测试工具调用"}`) { - t.Fatalf("expected original concatenated arguments in tool history, got %q", content) + if len(normalized) != 0 { + t.Fatalf("expected no synthetic assistant message for tool_call-only content, got %#v", normalized) } } - func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsMissingNameAreDropped(t *testing.T) { raw := []any{ map[string]any{ @@ -235,15 +237,8 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantNilContentDoesNotInjectNullLi } normalized := normalizeOpenAIMessagesForPrompt(raw, "") - if len(normalized) != 1 { - t.Fatalf("expected one normalized message, got %d", len(normalized)) - } - content, _ := normalized[0]["content"].(string) - if strings.Contains(content, "<|Assistant|>null") || strings.HasPrefix(strings.TrimSpace(content), "null") { - t.Fatalf("unexpected null literal injected into assistant tool history: %q", content) - } - if !strings.Contains(content, "function.name: send_file_to_user") { - t.Fatalf("expected tool history block preserved, got %q", content) + if len(normalized) != 0 { + t.Fatalf("expected nil-content assistant tool_call-only message to be dropped, got %#v", normalized) } } diff --git a/internal/adapter/openai/prompt_build_test.go b/internal/adapter/openai/prompt_build_test.go index 09b3a10..c7d4dc2 100644 --- a/internal/adapter/openai/prompt_build_test.go +++ b/internal/adapter/openai/prompt_build_test.go @@ -44,11 +44,11 @@ func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *tes if len(toolNames) != 1 || toolNames[0] != "get_weather" { t.Fatalf("unexpected tool names: %#v", toolNames) } - if !strings.Contains(finalPrompt, "tool_call_id: call_1") || - !strings.Contains(finalPrompt, "function.name: get_weather") || - !strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") || - !strings.Contains(finalPrompt, `"condition":"sunny"`) { - t.Fatalf("handler finalPrompt missing tool roundtrip semantics: %q", finalPrompt) + if !strings.Contains(finalPrompt, `"condition":"sunny"`) { + t.Fatalf("handler finalPrompt should preserve tool output content: %q", finalPrompt) + } + if strings.Contains(finalPrompt, "[TOOL_CALL_HISTORY]") || strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") { + t.Fatalf("handler finalPrompt should not include synthetic history markers: %q", finalPrompt) } } @@ -77,10 +77,4 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t * if !strings.Contains(finalPrompt, "Only call another tool when the previous result is missing required data or returned an error.") { t.Fatalf("vercel prepare finalPrompt missing retry guard instruction: %q", finalPrompt) } - if !strings.Contains(finalPrompt, "[TOOL_RESULT_HISTORY]") { - t.Fatalf("vercel prepare finalPrompt missing history marker instruction: %q", finalPrompt) - } - if !strings.Contains(finalPrompt, "Never output [TOOL_CALL_HISTORY] or [TOOL_RESULT_HISTORY] markers in your answer") { - t.Fatalf("vercel prepare finalPrompt missing marker-output guard instruction: %q", finalPrompt) - } } diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index e4b1de8..b204442 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -113,7 +113,8 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res return } result := sse.CollectStream(resp, thinkingEnabled, true) - textParsed := util.ParseStandaloneToolCallsDetailed(result.Text, toolNames) + sanitizedText := sanitizeLeakedToolHistory(result.Text) + textParsed := util.ParseStandaloneToolCallsDetailed(sanitizedText, toolNames) logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text") callCount := len(textParsed.Calls) @@ -122,7 +123,7 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res return } - responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames) + responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, result.Thinking, sanitizedText, toolNames) h.getResponseStore().put(owner, responseID, responseObj) writeJSON(w, http.StatusOK, responseObj) } diff --git a/internal/adapter/openai/responses_stream_runtime_core.go b/internal/adapter/openai/responses_stream_runtime_core.go index e8ec6df..c1ca926 100644 --- a/internal/adapter/openai/responses_stream_runtime_core.go +++ b/internal/adapter/openai/responses_stream_runtime_core.go @@ -98,7 +98,7 @@ func newResponsesStreamRuntime( func (s *responsesStreamRuntime) finalize() { finalThinking := s.thinking.String() - finalText := s.text.String() + finalText := sanitizeLeakedToolHistory(s.text.String()) if s.bufferToolContent { s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true) @@ -204,12 +204,16 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa continue } - s.text.WriteString(p.Text) - if !s.bufferToolContent { - s.emitTextDelta(p.Text) + cleanedText := sanitizeLeakedToolHistory(p.Text) + if cleanedText == "" { continue } - s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true) + s.text.WriteString(cleanedText) + if !s.bufferToolContent { + s.emitTextDelta(cleanedText) + continue + } + s.processToolStreamEvents(processToolSieveChunk(&s.sieve, cleanedText, s.toolNames), true) } return streamengine.ParsedDecision{ContentSeen: contentSeen} diff --git a/internal/adapter/openai/tool_history_sanitize.go b/internal/adapter/openai/tool_history_sanitize.go new file mode 100644 index 0000000..6a2e80a --- /dev/null +++ b/internal/adapter/openai/tool_history_sanitize.go @@ -0,0 +1,14 @@ +package openai + +import ( + "regexp" +) + +var leakedToolHistoryPattern = regexp.MustCompile(`(?is)\[TOOL_CALL_HISTORY\][\s\S]*?\[/TOOL_CALL_HISTORY\]|\[TOOL_RESULT_HISTORY\][\s\S]*?\[/TOOL_RESULT_HISTORY\]`) + +func sanitizeLeakedToolHistory(text string) string { + if text == "" { + return text + } + return leakedToolHistoryPattern.ReplaceAllString(text, "") +} diff --git a/internal/adapter/openai/tool_history_sanitize_test.go b/internal/adapter/openai/tool_history_sanitize_test.go new file mode 100644 index 0000000..02128c9 --- /dev/null +++ b/internal/adapter/openai/tool_history_sanitize_test.go @@ -0,0 +1,98 @@ +package openai + +import "testing" + +func TestSanitizeLeakedToolHistoryRemovesMarkerBlocks(t *testing.T) { + raw := "前缀\n[TOOL_CALL_HISTORY]\nfunction.name: exec\nfunction.arguments: {}\n[/TOOL_CALL_HISTORY]\n后缀" + got := sanitizeLeakedToolHistory(raw) + if got != "前缀\n\n后缀" { + t.Fatalf("unexpected sanitized content: %q", got) + } +} + +func TestSanitizeLeakedToolHistoryPreservesChunkWhitespace(t *testing.T) { + cases := []struct { + name string + raw string + want string + }{ + { + name: "trailing space kept", + raw: "Hello ", + want: "Hello ", + }, + { + name: "leading newline kept", + raw: "\nworld", + want: "\nworld", + }, + { + name: "surrounding whitespace around marker is preserved", + raw: "A \n[TOOL_RESULT_HISTORY]\nfunction.name: exec\nfunction.arguments: {}\n[/TOOL_RESULT_HISTORY]\n B", + want: "A \n\n B", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := sanitizeLeakedToolHistory(tc.raw) + if got != tc.want { + t.Fatalf("unexpected sanitize result, want %q got %q", tc.want, got) + } + }) + } +} + +func TestFlushToolSieveDropsToolHistoryLeak(t *testing.T) { + var state toolStreamSieveState + chunk := "[TOOL_CALL_HISTORY]\nstatus: already_called\nfunction.name: exec\nfunction.arguments: {}\n[/TOOL_CALL_HISTORY]" + evts := processToolSieveChunk(&state, chunk, []string{"exec"}) + if len(evts) != 0 { + t.Fatalf("expected no immediate output before history block is complete, got %+v", evts) + } + flushed := flushToolSieve(&state, []string{"exec"}) + if len(flushed) != 0 { + t.Fatalf("expected history block to be swallowed, got %+v", flushed) + } +} + +func TestFlushToolSieveDropsToolResultHistoryLeak(t *testing.T) { + var state toolStreamSieveState + chunk := "[TOOL_RESULT_HISTORY]\nstatus: already_called\nfunction.name: exec\nfunction.arguments: {}\n[/TOOL_RESULT_HISTORY]" + evts := processToolSieveChunk(&state, chunk, []string{"exec"}) + if len(evts) != 0 { + t.Fatalf("expected no immediate output before result history block is complete, got %+v", evts) + } + flushed := flushToolSieve(&state, []string{"exec"}) + if len(flushed) != 0 { + t.Fatalf("expected result history block to be swallowed, got %+v", flushed) + } +} + +func TestProcessToolSieveChunkSplitsResultHistoryBoundary(t *testing.T) { + var state toolStreamSieveState + parts := []string{ + "Hello ", + "[TOOL_RESULT_HISTORY]\nstatus: already_called\n", + "function.name: exec\nfunction.arguments: {}\n[/TOOL_RESULT_HISTORY]", + "world", + } + var events []toolStreamEvent + for _, p := range parts { + events = append(events, processToolSieveChunk(&state, p, []string{"exec"})...) + } + events = append(events, flushToolSieve(&state, []string{"exec"})...) + + var text string + for _, evt := range events { + if evt.Content != "" { + text += evt.Content + } + if len(evt.ToolCalls) > 0 { + t.Fatalf("did not expect parsed tool calls from history leak: %+v", evt.ToolCalls) + } + } + if text != "Hello world" { + t.Fatalf("expected clean text output preserving boundary spaces, got %q", text) + } +} diff --git a/internal/adapter/openai/tool_sieve_core.go b/internal/adapter/openai/tool_sieve_core.go index 7618b01..3ee9eda 100644 --- a/internal/adapter/openai/tool_sieve_core.go +++ b/internal/adapter/openai/tool_sieve_core.go @@ -167,7 +167,7 @@ func findToolSegmentStart(s string) int { return -1 } lower := strings.ToLower(s) - keywords := []string{"tool_calls", "function.name:", "[tool_call_history]"} + keywords := []string{"tool_calls", "function.name:", "[tool_call_history]", "[tool_result_history]"} bestKeyIdx := -1 for _, kw := range keywords { idx := strings.Index(lower, kw) @@ -194,9 +194,8 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix return "", nil, "", false } lower := strings.ToLower(captured) - keyIdx := -1 - keywords := []string{"tool_calls", "function.name:", "[tool_call_history]"} + keywords := []string{"tool_calls", "function.name:", "[tool_call_history]", "[tool_result_history]"} for _, kw := range keywords { idx := strings.Index(lower, kw) if idx >= 0 && (keyIdx < 0 || idx < keyIdx) { @@ -209,6 +208,9 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix } start := strings.LastIndex(captured[:keyIdx], "{") if start < 0 { + if blockStart, blockEnd, ok := extractToolHistoryBlock(captured, keyIdx); ok { + return captured[:blockStart], nil, captured[blockEnd:], true + } start = keyIdx } obj, end, ok := extractJSONObjectFrom(captured, start) @@ -233,6 +235,31 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix return prefixPart, parsed.Calls, suffixPart, true } +func extractToolHistoryBlock(captured string, keyIdx int) (start int, end int, ok bool) { + if keyIdx < 0 || keyIdx >= len(captured) { + return 0, 0, false + } + rest := strings.ToLower(captured[keyIdx:]) + switch { + case strings.HasPrefix(rest, "[tool_call_history]"): + closeTag := "[/tool_call_history]" + closeIdx := strings.Index(rest, closeTag) + if closeIdx < 0 { + return 0, 0, false + } + return keyIdx, keyIdx + closeIdx + len(closeTag), true + case strings.HasPrefix(rest, "[tool_result_history]"): + closeTag := "[/tool_result_history]" + closeIdx := strings.Index(rest, closeTag) + if closeIdx < 0 { + return 0, 0, false + } + return keyIdx, keyIdx + closeIdx + len(closeTag), true + default: + return 0, 0, false + } +} + func trimWrappingJSONFence(prefix, suffix string) (string, string) { trimmedPrefix := strings.TrimRight(prefix, " \t\r\n") fenceIdx := strings.LastIndex(trimmedPrefix, "```") diff --git a/internal/admin/deps.go b/internal/admin/deps.go index 997c42b..d95eecf 100644 --- a/internal/admin/deps.go +++ b/internal/admin/deps.go @@ -17,6 +17,7 @@ type ConfigStore interface { FindAccount(identifier string) (config.Account, bool) UpdateAccountToken(identifier, token string) error UpdateAccountTestStatus(identifier, status string) error + AccountTestStatus(identifier string) (string, bool) Update(mutator func(*config.Config) error) error ExportJSONAndBase64() (string, string, error) IsEnvBacked() bool diff --git a/internal/admin/handler_accounts_crud.go b/internal/admin/handler_accounts_crud.go index 6536760..3761a7a 100644 --- a/internal/admin/handler_accounts_crud.go +++ b/internal/admin/handler_accounts_crud.go @@ -54,6 +54,7 @@ func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) { } items := make([]map[string]any, 0, end-start) for _, acc := range accounts[start:end] { + testStatus, _ := h.Store.AccountTestStatus(acc.Identifier()) token := strings.TrimSpace(acc.Token) preview := "" if token != "" { @@ -70,7 +71,7 @@ func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) { "has_password": acc.Password != "", "has_token": token != "", "token_preview": preview, - "test_status": acc.TestStatus, + "test_status": testStatus, }) } writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages}) diff --git a/internal/admin/handler_accounts_testing_test.go b/internal/admin/handler_accounts_testing_test.go index b07afaa..961794e 100644 --- a/internal/admin/handler_accounts_testing_test.go +++ b/internal/admin/handler_accounts_testing_test.go @@ -93,8 +93,9 @@ func TestTestAccount_BatchModeOnlyCreatesSession(t *testing.T) { if updated.Token != "new-token" { t.Fatalf("expected refreshed token to be persisted, got %q", updated.Token) } - if updated.TestStatus != "ok" { - t.Fatalf("expected test status ok, got %q", updated.TestStatus) + testStatus, ok := store.AccountTestStatus("batch@example.com") + if !ok || testStatus != "ok" { + t.Fatalf("expected runtime test status ok, got %q (ok=%v)", testStatus, ok) } } diff --git a/internal/auth/request.go b/internal/auth/request.go index c0cdd52..ffcd980 100644 --- a/internal/auth/request.go +++ b/internal/auth/request.go @@ -7,6 +7,8 @@ import ( "errors" "net/http" "strings" + "sync" + "time" "ds2api/internal/account" "ds2api/internal/config" @@ -37,10 +39,20 @@ type Resolver struct { Store *config.Store Pool *account.Pool Login LoginFunc + + mu sync.Mutex + tokenRefreshedAt map[string]time.Time + tokenRefreshInterval time.Duration } func NewResolver(store *config.Store, pool *account.Pool, login LoginFunc) *Resolver { - return &Resolver{Store: store, Pool: pool, Login: login} + return &Resolver{ + Store: store, + Pool: pool, + Login: login, + tokenRefreshedAt: map[string]time.Time{}, + tokenRefreshInterval: 6 * time.Hour, + } } func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) { @@ -72,13 +84,9 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) { TriedAccounts: map[string]bool{}, resolver: r, } - if acc.Token == "" { - if err := r.loginAndPersist(ctx, a); err != nil { - r.Pool.Release(a.AccountID) - return nil, err - } - } else { - a.DeepSeekToken = acc.Token + if err := r.ensureManagedToken(ctx, a); err != nil { + r.Pool.Release(a.AccountID) + return nil, err } return a, nil } @@ -120,6 +128,7 @@ func (r *Resolver) loginAndPersist(ctx context.Context, a *RequestAuth) error { } a.Account.Token = token a.DeepSeekToken = token + r.markTokenRefreshedNow(a.AccountID) return r.Store.UpdateAccountToken(a.AccountID, token) } @@ -142,6 +151,7 @@ func (r *Resolver) MarkTokenInvalid(a *RequestAuth) { } a.Account.Token = "" a.DeepSeekToken = "" + r.clearTokenRefreshMark(a.AccountID) _ = r.Store.UpdateAccountToken(a.AccountID, "") } @@ -162,12 +172,8 @@ func (r *Resolver) SwitchAccount(ctx context.Context, a *RequestAuth) bool { } a.Account = acc a.AccountID = acc.Identifier() - if acc.Token == "" { - if err := r.loginAndPersist(ctx, a); err != nil { - return false - } - } else { - a.DeepSeekToken = acc.Token + if err := r.ensureManagedToken(ctx, a); err != nil { + return false } return true } @@ -210,3 +216,53 @@ func callerTokenID(token string) string { sum := sha256.Sum256([]byte(token)) return "caller:" + hex.EncodeToString(sum[:8]) } + +func (r *Resolver) ensureManagedToken(ctx context.Context, a *RequestAuth) error { + if strings.TrimSpace(a.Account.Token) == "" { + return r.loginAndPersist(ctx, a) + } + if r.shouldForceRefresh(a.AccountID) { + if err := r.loginAndPersist(ctx, a); err != nil { + return err + } + return nil + } + a.DeepSeekToken = a.Account.Token + return nil +} + +func (r *Resolver) shouldForceRefresh(accountID string) bool { + if strings.TrimSpace(accountID) == "" { + return false + } + if r.tokenRefreshInterval <= 0 { + return false + } + now := time.Now() + r.mu.Lock() + defer r.mu.Unlock() + last, ok := r.tokenRefreshedAt[accountID] + if !ok || last.IsZero() { + r.tokenRefreshedAt[accountID] = now + return false + } + return now.Sub(last) >= r.tokenRefreshInterval +} + +func (r *Resolver) markTokenRefreshedNow(accountID string) { + if strings.TrimSpace(accountID) == "" { + return + } + r.mu.Lock() + defer r.mu.Unlock() + r.tokenRefreshedAt[accountID] = time.Now() +} + +func (r *Resolver) clearTokenRefreshMark(accountID string) { + if strings.TrimSpace(accountID) == "" { + return + } + r.mu.Lock() + defer r.mu.Unlock() + delete(r.tokenRefreshedAt, accountID) +} diff --git a/internal/auth/request_test.go b/internal/auth/request_test.go index f8cb40f..3e31907 100644 --- a/internal/auth/request_test.go +++ b/internal/auth/request_test.go @@ -3,7 +3,9 @@ package auth import ( "context" "net/http" + "sync/atomic" "testing" + "time" "ds2api/internal/account" "ds2api/internal/config" @@ -193,3 +195,52 @@ func TestDetermineCallerMissingToken(t *testing.T) { t.Fatalf("unexpected error: %v", err) } } + +func TestDetermineManagedAccountForcesRefreshEverySixHours(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[{"email":"acc@example.com","password":"pwd","token":"seed-token"}] + }`) + store := config.LoadStore() + if err := store.UpdateAccountToken("acc@example.com", "seed-token"); err != nil { + t.Fatalf("update token failed: %v", err) + } + pool := account.NewPool(store) + + var loginCount int32 + resolver := NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + n := atomic.AddInt32(&loginCount, 1) + return "fresh-token-" + string(rune('0'+n)), nil + }) + + req, _ := http.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + req.Header.Set("x-api-key", "managed-key") + + a1, err := resolver.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + if a1.DeepSeekToken != "seed-token" { + t.Fatalf("expected initial token without forced refresh, got %q", a1.DeepSeekToken) + } + resolver.Release(a1) + if got := atomic.LoadInt32(&loginCount); got != 0 { + t.Fatalf("expected no login before refresh interval, got %d", got) + } + + resolver.mu.Lock() + resolver.tokenRefreshedAt["acc@example.com"] = time.Now().Add(-7 * time.Hour) + resolver.mu.Unlock() + + a2, err := resolver.Determine(req) + if err != nil { + t.Fatalf("determine after interval failed: %v", err) + } + defer resolver.Release(a2) + if a2.DeepSeekToken != "fresh-token-1" { + t.Fatalf("expected refreshed token after interval, got %q", a2.DeepSeekToken) + } + if got := atomic.LoadInt32(&loginCount); got != 1 { + t.Fatalf("expected exactly one forced refresh login, got %d", got) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 8c50f8e..7ab4587 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -19,11 +19,10 @@ type Config struct { } type Account struct { - Email string `json:"email,omitempty"` - Mobile string `json:"mobile,omitempty"` - Password string `json:"password,omitempty"` - Token string `json:"token,omitempty"` - TestStatus string `json:"test_status,omitempty"` + Email string `json:"email,omitempty"` + Mobile string `json:"mobile,omitempty"` + Password string `json:"password,omitempty"` + Token string `json:"token,omitempty"` } func (c *Config) ClearAccountTokens() { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index bd8b714..5429bb8 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -3,6 +3,7 @@ package config import ( "encoding/base64" "os" + "strings" "testing" ) @@ -147,3 +148,39 @@ func TestLoadConfigOnVercelWithoutConfigFileFallsBackToMemory(t *testing.T) { t.Fatalf("expected empty bootstrap config, got keys=%d accounts=%d", len(cfg.Keys), len(cfg.Accounts)) } } + +func TestAccountTestStatusIsRuntimeOnlyAndNotPersisted(t *testing.T) { + tmp, err := os.CreateTemp(t.TempDir(), "config-*.json") + if err != nil { + t.Fatalf("create temp config: %v", err) + } + defer tmp.Close() + if _, err := tmp.WriteString(`{ + "accounts":[{"email":"u@example.com","password":"p","test_status":"ok"}] + }`); err != nil { + t.Fatalf("write temp config: %v", err) + } + + t.Setenv("DS2API_CONFIG_JSON", "") + t.Setenv("CONFIG_JSON", "") + t.Setenv("DS2API_CONFIG_PATH", tmp.Name()) + + store := LoadStore() + if got, ok := store.AccountTestStatus("u@example.com"); ok || got != "" { + t.Fatalf("expected no runtime status loaded from config, got %q", got) + } + if err := store.UpdateAccountTestStatus("u@example.com", "ok"); err != nil { + t.Fatalf("update test status: %v", err) + } + if got, ok := store.AccountTestStatus("u@example.com"); !ok || got != "ok" { + t.Fatalf("expected runtime status to be available, got %q (ok=%v)", got, ok) + } + + content, err := os.ReadFile(tmp.Name()) + if err != nil { + t.Fatalf("read config: %v", err) + } + if strings.Contains(string(content), "test_status") { + t.Fatalf("expected test_status to stay out of persisted config, got: %s", content) + } +} diff --git a/internal/config/store.go b/internal/config/store.go index c607f81..d212594 100644 --- a/internal/config/store.go +++ b/internal/config/store.go @@ -17,6 +17,7 @@ type Store struct { fromEnv bool keyMap map[string]struct{} // O(1) API key lookup index accMap map[string]int // O(1) account lookup: identifier -> slice index + accTest map[string]string // runtime-only account test status cache } func LoadStore() *Store { @@ -58,6 +59,11 @@ func loadConfig() (Config, bool, error) { return Config{}, false, err } cfg.DropInvalidAccounts() + if strings.Contains(string(content), `"test_status"`) && !IsVercel() { + if b, err := json.MarshalIndent(cfg, "", " "); err == nil { + _ = os.WriteFile(ConfigPath(), b, 0o644) + } + } if IsVercel() { // Vercel filesystem is ephemeral/read-only for runtime writes; avoid save errors. return cfg, true, nil @@ -108,8 +114,19 @@ func (s *Store) UpdateAccountTestStatus(identifier, status string) error { if !ok { return errors.New("account not found") } - s.cfg.Accounts[idx].TestStatus = status - return s.saveLocked() + s.setAccountTestStatusLocked(s.cfg.Accounts[idx], status, identifier) + return nil +} + +func (s *Store) AccountTestStatus(identifier string) (string, bool) { + identifier = strings.TrimSpace(identifier) + if identifier == "" { + return "", false + } + s.mu.RLock() + defer s.mu.RUnlock() + status, ok := s.accTest[identifier] + return status, ok } func (s *Store) UpdateAccountToken(identifier, token string) error { diff --git a/internal/config/store_index.go b/internal/config/store_index.go index 7d0f62a..a0e6638 100644 --- a/internal/config/store_index.go +++ b/internal/config/store_index.go @@ -2,15 +2,20 @@ package config // rebuildIndexes must be called with the lock already held (or during init). func (s *Store) rebuildIndexes() { + prevStatus := s.accTest s.keyMap = make(map[string]struct{}, len(s.cfg.Keys)) for _, k := range s.cfg.Keys { s.keyMap[k] = struct{}{} } s.accMap = make(map[string]int, len(s.cfg.Accounts)) + s.accTest = make(map[string]string, len(s.cfg.Accounts)) for i, acc := range s.cfg.Accounts { id := acc.Identifier() if id != "" { s.accMap[id] = i + if status, ok := prevStatus[id]; ok { + s.setAccountTestStatusLocked(acc, status, "") + } } } } @@ -29,3 +34,22 @@ func (s *Store) findAccountIndexLocked(identifier string) (int, bool) { } return -1, false } + +func (s *Store) setAccountTestStatusLocked(acc Account, status, hintedIdentifier string) { + status = lower(status) + if status == "" { + return + } + if id := acc.Identifier(); id != "" { + s.accTest[id] = status + } + if email := acc.Email; email != "" { + s.accTest[email] = status + } + if mobile := CanonicalMobileKey(acc.Mobile); mobile != "" { + s.accTest[mobile] = status + } + if hintedIdentifier = lower(hintedIdentifier); hintedIdentifier != "" { + s.accTest[hintedIdentifier] = status + } +} diff --git a/internal/js/helpers/stream-tool-sieve/sieve.js b/internal/js/helpers/stream-tool-sieve/sieve.js index 12534f9..6cf8b5c 100644 --- a/internal/js/helpers/stream-tool-sieve/sieve.js +++ b/internal/js/helpers/stream-tool-sieve/sieve.js @@ -1,16 +1,7 @@ 'use strict'; - -const { - resetIncrementalToolState, - noteText, - insideCodeFence, -} = require('./state'); -const { - parseStandaloneToolCallsDetailed, -} = require('./parse'); -const { - extractJSONObjectFrom, -} = require('./jsonscan'); +const { resetIncrementalToolState, noteText, insideCodeFence } = require('./state'); +const { parseStandaloneToolCallsDetailed } = require('./parse'); +const { extractJSONObjectFrom } = require('./jsonscan'); function processToolSieveChunk(state, chunk, toolNames) { if (!state) { @@ -20,8 +11,6 @@ function processToolSieveChunk(state, chunk, toolNames) { state.pending += chunk; } const events = []; - - // eslint-disable-next-line no-constant-condition while (true) { if (Array.isArray(state.pendingToolCalls) && state.pendingToolCalls.length > 0) { events.push({ type: 'tool_calls', calls: state.pendingToolCalls }); @@ -60,12 +49,10 @@ function processToolSieveChunk(state, chunk, toolNames) { } continue; } - const pending = state.pending || ''; if (!pending) { break; } - const start = findToolSegmentStart(pending); if (start >= 0) { const prefix = pending.slice(0, start); @@ -79,7 +66,6 @@ function processToolSieveChunk(state, chunk, toolNames) { resetIncrementalToolState(state); continue; } - const [safe, hold] = splitSafeContentForToolDetection(pending); if (!safe) { break; @@ -96,13 +82,11 @@ function flushToolSieve(state, toolNames) { return []; } const events = processToolSieveChunk(state, '', toolNames); - if (Array.isArray(state.pendingToolCalls) && state.pendingToolCalls.length > 0) { events.push({ type: 'tool_calls', calls: state.pendingToolCalls }); state.pendingToolRaw = ''; state.pendingToolCalls = []; } - if (state.capturing) { const consumed = consumeToolCapture(state, toolNames); if (consumed.ready) { @@ -125,13 +109,11 @@ function flushToolSieve(state, toolNames) { state.capturing = false; resetIncrementalToolState(state); } - if (state.pending) { noteText(state, state.pending); events.push({ type: 'text', text: state.pending }); state.pending = ''; } - return events; } @@ -147,8 +129,6 @@ function splitSafeContentForToolDetection(s) { if (suspiciousStart > 0) { return [text.slice(0, suspiciousStart), text.slice(suspiciousStart)]; } - // If suspicious content starts at the beginning, keep holding until we can - // either parse a full tool JSON block or reach stream flush. return ['', text]; } @@ -168,13 +148,11 @@ function findToolSegmentStart(s) { return -1; } const lower = s.toLowerCase(); - const keywords = ['tool_calls', 'function.name:', '[tool_call_history]']; + const keywords = ['tool_calls', 'function.name:', '[tool_call_history]', '[tool_result_history]']; let offset = 0; - // eslint-disable-next-line no-constant-condition while (true) { let bestKeyIdx = -1; let matchedKeyword = ''; - for (const kw of keywords) { const idx = lower.indexOf(kw, offset); if (idx >= 0) { @@ -184,11 +162,9 @@ function findToolSegmentStart(s) { } } } - if (bestKeyIdx < 0) { return -1; } - const keyIdx = bestKeyIdx; const start = s.slice(0, keyIdx).lastIndexOf('{'); const candidateStart = start >= 0 ? start : keyIdx; @@ -205,30 +181,36 @@ function consumeToolCapture(state, toolNames) { return { ready: false, prefix: '', calls: [], suffix: '' }; } const lower = captured.toLowerCase(); - let keyIdx = -1; - const keywords = ['tool_calls', 'function.name:', '[tool_call_history]']; + const keywords = ['tool_calls', 'function.name:', '[tool_call_history]', '[tool_result_history]']; for (const kw of keywords) { const idx = lower.indexOf(kw); if (idx >= 0 && (keyIdx < 0 || idx < keyIdx)) { keyIdx = idx; } } - if (keyIdx < 0) { return { ready: false, prefix: '', calls: [], suffix: '' }; } const start = captured.slice(0, keyIdx).lastIndexOf('{'); const actualStart = start >= 0 ? start : keyIdx; - + if (start < 0) { + const history = extractToolHistoryBlock(captured, keyIdx); + if (history.ok) { + return { + ready: true, + prefix: captured.slice(0, history.start), + calls: [], + suffix: captured.slice(history.end), + }; + } + } const obj = extractJSONObjectFrom(captured, actualStart); if (!obj.ok) { return { ready: false, prefix: '', calls: [], suffix: '' }; } - const prefixPart = captured.slice(0, actualStart); const suffixPart = captured.slice(obj.end); - if (insideCodeFence((state.recentTextTail || '') + prefixPart)) { return { ready: true, @@ -237,7 +219,6 @@ function consumeToolCapture(state, toolNames) { suffix: '', }; } - const parsed = parseStandaloneToolCallsDetailed(captured.slice(actualStart, obj.end), toolNames); if (!Array.isArray(parsed.calls) || parsed.calls.length === 0) { if (parsed.sawToolCallSyntax && parsed.rejectedByPolicy) { @@ -255,7 +236,6 @@ function consumeToolCapture(state, toolNames) { suffix: '', }; } - const trimmedFence = trimWrappingJSONFence(prefixPart, suffixPart); return { ready: true, @@ -265,14 +245,34 @@ function consumeToolCapture(state, toolNames) { }; } +function extractToolHistoryBlock(captured, keyIdx) { + if (typeof captured !== 'string' || keyIdx < 0 || keyIdx >= captured.length) { + return { ok: false, start: 0, end: 0 }; + } + const rest = captured.slice(keyIdx).toLowerCase(); + if (rest.startsWith('[tool_call_history]')) { + const closeTag = '[/tool_call_history]'; + const closeIdx = rest.indexOf(closeTag); + if (closeIdx < 0) { + return { ok: false, start: 0, end: 0 }; + } + return { ok: true, start: keyIdx, end: keyIdx + closeIdx + closeTag.length }; + } + if (rest.startsWith('[tool_result_history]')) { + const closeTag = '[/tool_result_history]'; + const closeIdx = rest.indexOf(closeTag); + if (closeIdx < 0) { + return { ok: false, start: 0, end: 0 }; + } + return { ok: true, start: keyIdx, end: keyIdx + closeIdx + closeTag.length }; + } + return { ok: false, start: 0, end: 0 }; +} + function trimWrappingJSONFence(prefix, suffix) { const rightTrimmedPrefix = (prefix || '').replace(/[ \t\r\n]+$/g, ''); const fenceIdx = rightTrimmedPrefix.lastIndexOf('```'); - if (fenceIdx < 0) { - return { prefix, suffix }; - } - // Only strip when this behaves like an opening fence. - // If it's a legitimate closing fence before standalone tool JSON, keep it. + if (fenceIdx < 0) return { prefix, suffix }; const fenceCount = (rightTrimmedPrefix.slice(0, fenceIdx + 3).match(/```/g) || []).length; if (fenceCount % 2 === 0) { return { prefix, suffix }; diff --git a/tests/node/stream-tool-sieve.test.js b/tests/node/stream-tool-sieve.test.js index e352ca7..b24e138 100644 --- a/tests/node/stream-tool-sieve.test.js +++ b/tests/node/stream-tool-sieve.test.js @@ -226,6 +226,56 @@ test('sieve keeps plain text intact in tool mode when no tool call appears', () assert.equal(leakedText, '你好,这是普通文本回复。请继续。'); }); +test('sieve swallows leaked TOOL_CALL_HISTORY marker blocks', () => { + const events = runSieve( + [ + '前置文本。', + '[TOOL_CALL_HISTORY]\nstatus: already_called\nfunction.name: exec\nfunction.arguments: {}\n[/TOOL_CALL_HISTORY]', + '后置文本。', + ], + ['exec'], + ); + const leakedText = collectText(events); + const hasToolCall = events.some((evt) => evt.type === 'tool_calls'); + assert.equal(hasToolCall, false); + assert.equal(leakedText.includes('前置文本。'), true); + assert.equal(leakedText.includes('后置文本。'), true); + assert.equal(leakedText.includes('[TOOL_CALL_HISTORY]'), false); +}); + +test('sieve swallows leaked TOOL_RESULT_HISTORY marker blocks', () => { + const events = runSieve( + [ + '前置文本。', + '[TOOL_RESULT_HISTORY]\nstatus: already_called\nfunction.name: exec\nfunction.arguments: {}\n[/TOOL_RESULT_HISTORY]', + '后置文本。', + ], + ['exec'], + ); + const leakedText = collectText(events); + const hasToolCall = events.some((evt) => evt.type === 'tool_calls'); + assert.equal(hasToolCall, false); + assert.equal(leakedText.includes('前置文本。'), true); + assert.equal(leakedText.includes('后置文本。'), true); + assert.equal(leakedText.includes('[TOOL_RESULT_HISTORY]'), false); +}); + +test('sieve preserves text spacing when TOOL_RESULT_HISTORY spans chunks', () => { + const events = runSieve( + [ + 'Hello ', + '[TOOL_RESULT_HISTORY]\nstatus: already_called\n', + 'function.name: exec\nfunction.arguments: {}\n[/TOOL_RESULT_HISTORY]', + 'world', + ], + ['exec'], + ); + const leakedText = collectText(events); + const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && evt.calls?.length > 0); + assert.equal(hasToolCall, false); + assert.equal(leakedText, 'Hello world'); +}); + test('sieve intercepts rejected unknown tool payload (no args) without raw leak', () => { const events = runSieve( ['{"tool_calls":[{"name":"not_in_schema"}]}', '后置正文G。'], diff --git a/webui/src/features/settings/BackupSection.jsx b/webui/src/features/settings/BackupSection.jsx index c31a56f..06a1268 100644 --- a/webui/src/features/settings/BackupSection.jsx +++ b/webui/src/features/settings/BackupSection.jsx @@ -6,7 +6,9 @@ export default function BackupSection({ setImportMode, importing, onLoadExportData, + onDownloadExportFile, onImport, + onImportFileChange, importText, setImportText, exportData, @@ -23,6 +25,27 @@ export default function BackupSection({ {t('settings.loadExport')} + +