mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-08 10:25:28 +08:00
Compare commits
10 Commits
v3.0.0
...
v3.1.0_bet
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
068f4b0df6 | ||
|
|
5a51045ba4 | ||
|
|
3497d5d019 | ||
|
|
95a9d16843 | ||
|
|
0847091864 | ||
|
|
c6340354ec | ||
|
|
6bf08e00cd | ||
|
|
35221002d5 | ||
|
|
4b1f1ea550 | ||
|
|
0258f83d10 |
@@ -106,6 +106,9 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re
|
|||||||
|
|
||||||
finalThinking := result.Thinking
|
finalThinking := result.Thinking
|
||||||
finalText := sanitizeLeakedOutput(result.Text)
|
finalText := sanitizeLeakedOutput(result.Text)
|
||||||
|
if writeUpstreamEmptyOutputError(w, result) {
|
||||||
|
return
|
||||||
|
}
|
||||||
respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
|
respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
|
||||||
if result.OutputTokens > 0 {
|
if result.OutputTokens > 0 {
|
||||||
if usage, ok := respBody["usage"].(map[string]any); ok {
|
if usage, ok := respBody["usage"].(map[string]any); ok {
|
||||||
|
|||||||
@@ -275,6 +275,44 @@ func TestHandleNonStreamFencedToolCallExamplePromotesToolCall(t *testing.T) {
|
|||||||
TestHandleNonStreamFencedToolCallExampleDoesNotPromoteToolCall(t)
|
TestHandleNonStreamFencedToolCallExampleDoesNotPromoteToolCall(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandleNonStreamReturns502WhenUpstreamOutputEmpty(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
resp := makeSSEHTTPResponse(
|
||||||
|
`data: {"p":"response/content","v":""}`,
|
||||||
|
`data: [DONE]`,
|
||||||
|
)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.handleNonStream(rec, context.Background(), resp, "cid-empty", "deepseek-chat", "prompt", false, nil)
|
||||||
|
if rec.Code != http.StatusBadGateway {
|
||||||
|
t.Fatalf("expected status 502 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
out := decodeJSONBody(t, rec.Body.String())
|
||||||
|
errObj, _ := out["error"].(map[string]any)
|
||||||
|
if asString(errObj["code"]) != "upstream_empty_output" {
|
||||||
|
t.Fatalf("expected code=upstream_empty_output, got %#v", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWithoutOutput(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
resp := makeSSEHTTPResponse(
|
||||||
|
`data: {"code":"content_filter"}`,
|
||||||
|
`data: [DONE]`,
|
||||||
|
)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.handleNonStream(rec, context.Background(), resp, "cid-empty-filtered", "deepseek-chat", "prompt", false, nil)
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status 400 for filtered upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
out := decodeJSONBody(t, rec.Body.String())
|
||||||
|
errObj, _ := out["error"].(map[string]any)
|
||||||
|
if asString(errObj["code"]) != "content_filter" {
|
||||||
|
t.Fatalf("expected code=content_filter, got %#v", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
|
func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
resp := makeSSEHTTPResponse(
|
resp := makeSSEHTTPResponse(
|
||||||
|
|||||||
@@ -114,6 +114,9 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
|
|||||||
}
|
}
|
||||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||||
sanitizedText := sanitizeLeakedOutput(result.Text)
|
sanitizedText := sanitizeLeakedOutput(result.Text)
|
||||||
|
if writeUpstreamEmptyOutputError(w, result) {
|
||||||
|
return
|
||||||
|
}
|
||||||
textParsed := util.ParseStandaloneToolCallsDetailed(sanitizedText, toolNames)
|
textParsed := util.ParseStandaloneToolCallsDetailed(sanitizedText, toolNames)
|
||||||
logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text")
|
logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text")
|
||||||
|
|
||||||
|
|||||||
@@ -627,6 +627,50 @@ func TestHandleResponsesNonStreamToolChoiceNoneStillAllowsFunctionCall(t *testin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandleResponsesNonStreamReturns502WhenUpstreamOutputEmpty(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(
|
||||||
|
`data: {"p":"response/content","v":""}` + "\n" +
|
||||||
|
`data: [DONE]` + "\n",
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
|
||||||
|
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, util.DefaultToolChoicePolicy(), "")
|
||||||
|
if rec.Code != http.StatusBadGateway {
|
||||||
|
t.Fatalf("expected 502 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
out := decodeJSONBody(t, rec.Body.String())
|
||||||
|
errObj, _ := out["error"].(map[string]any)
|
||||||
|
if asString(errObj["code"]) != "upstream_empty_output" {
|
||||||
|
t.Fatalf("expected code=upstream_empty_output, got %#v", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleResponsesNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWithoutOutput(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader(
|
||||||
|
`data: {"code":"content_filter"}` + "\n" +
|
||||||
|
`data: [DONE]` + "\n",
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
|
||||||
|
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, util.DefaultToolChoicePolicy(), "")
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400 for filtered empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
out := decodeJSONBody(t, rec.Body.String())
|
||||||
|
errObj, _ := out["error"].(map[string]any)
|
||||||
|
if asString(errObj["code"]) != "content_filter" {
|
||||||
|
t.Fatalf("expected code=content_filter, got %#v", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
|
func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
|
||||||
scanner := bufio.NewScanner(strings.NewReader(body))
|
scanner := bufio.NewScanner(strings.NewReader(body))
|
||||||
matched := false
|
matched := false
|
||||||
|
|||||||
@@ -71,12 +71,31 @@ func consumeXMLToolCapture(captured string, toolNames []string) (prefix string,
|
|||||||
prefixPart, suffixPart = trimWrappingJSONFence(prefixPart, suffixPart)
|
prefixPart, suffixPart = trimWrappingJSONFence(prefixPart, suffixPart)
|
||||||
return prefixPart, parsed, suffixPart, true
|
return prefixPart, parsed, suffixPart, true
|
||||||
}
|
}
|
||||||
|
// If this block does not look like an executable tool-call payload,
|
||||||
|
// pass it through as normal content (e.g. user-requested XML snippets).
|
||||||
|
if !looksLikeExecutableXMLToolCallBlock(xmlBlock, pair.open) {
|
||||||
|
return prefixPart + xmlBlock, nil, suffixPart, true
|
||||||
|
}
|
||||||
// Looks like XML tool syntax but failed to parse — consume it to avoid leak.
|
// Looks like XML tool syntax but failed to parse — consume it to avoid leak.
|
||||||
return prefixPart, nil, suffixPart, true
|
return prefixPart, nil, suffixPart, true
|
||||||
}
|
}
|
||||||
return "", nil, "", false
|
return "", nil, "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// looksLikeExecutableXMLToolCallBlock reports whether xmlBlock should be
// treated as a tool-call payload rather than ordinary XML content.
//
// Blocks opened by one of the agent wrapper tags are always treated as tool
// calls. Any other block qualifies only if its text, compared
// case-insensitively, contains one of the tool-call markers listed below.
func looksLikeExecutableXMLToolCallBlock(xmlBlock, openTag string) bool {
	// Agent wrapper tags are always internal tool-call wrappers.
	if openTag == "<attempt_completion" || openTag == "<ask_followup_question" || openTag == "<new_task" {
		return true
	}

	folded := strings.ToLower(xmlBlock)
	markers := [...]string{
		"<tool_name",
		"<parameters",
		`"tool"`,
		`"tool_name"`,
		`"name"`,
	}
	for _, marker := range markers {
		if strings.Contains(folded, marker) {
			return true
		}
	}
	return false
}
|
||||||
|
|
||||||
// hasOpenXMLToolTag returns true if captured text contains an XML tool opening tag
|
// hasOpenXMLToolTag returns true if captured text contains an XML tool opening tag
|
||||||
// whose SPECIFIC closing tag has not appeared yet.
|
// whose SPECIFIC closing tag has not appeared yet.
|
||||||
func hasOpenXMLToolTag(captured string) bool {
|
func hasOpenXMLToolTag(captured string) bool {
|
||||||
|
|||||||
@@ -78,6 +78,49 @@ func TestProcessToolSieveXMLWithLeadingText(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProcessToolSievePassesThroughNonToolXMLBlock(t *testing.T) {
|
||||||
|
var state toolStreamSieveState
|
||||||
|
chunk := `<tool_call><title>示例 XML</title><body>plain text xml payload</body></tool_call>`
|
||||||
|
events := processToolSieveChunk(&state, chunk, []string{"read_file"})
|
||||||
|
events = append(events, flushToolSieve(&state, []string{"read_file"})...)
|
||||||
|
|
||||||
|
var textContent strings.Builder
|
||||||
|
toolCalls := 0
|
||||||
|
for _, evt := range events {
|
||||||
|
textContent.WriteString(evt.Content)
|
||||||
|
toolCalls += len(evt.ToolCalls)
|
||||||
|
}
|
||||||
|
if toolCalls != 0 {
|
||||||
|
t.Fatalf("expected no tool calls for plain XML payload, got %d events=%#v", toolCalls, events)
|
||||||
|
}
|
||||||
|
if textContent.String() != chunk {
|
||||||
|
t.Fatalf("expected XML payload to pass through unchanged, got %q", textContent.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessToolSieveNonToolXMLKeepsSuffixForToolParsing(t *testing.T) {
|
||||||
|
var state toolStreamSieveState
|
||||||
|
chunk := `<tool_call><title>plain xml</title></tool_call><invoke name="read_file"><parameters>{"path":"README.MD"}</parameters></invoke>`
|
||||||
|
events := processToolSieveChunk(&state, chunk, []string{"read_file"})
|
||||||
|
events = append(events, flushToolSieve(&state, []string{"read_file"})...)
|
||||||
|
|
||||||
|
var textContent strings.Builder
|
||||||
|
toolCalls := 0
|
||||||
|
for _, evt := range events {
|
||||||
|
textContent.WriteString(evt.Content)
|
||||||
|
toolCalls += len(evt.ToolCalls)
|
||||||
|
}
|
||||||
|
if !strings.Contains(textContent.String(), `<tool_call><title>plain xml</title></tool_call>`) {
|
||||||
|
t.Fatalf("expected leading non-tool XML to be preserved, got %q", textContent.String())
|
||||||
|
}
|
||||||
|
if strings.Contains(textContent.String(), `<invoke name="read_file">`) {
|
||||||
|
t.Fatalf("expected invoke tool XML to be intercepted, got %q", textContent.String())
|
||||||
|
}
|
||||||
|
if toolCalls != 1 {
|
||||||
|
t.Fatalf("expected exactly one parsed tool call from suffix, got %d events=%#v", toolCalls, events)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestProcessToolSievePartialXMLTagHeldBack(t *testing.T) {
|
func TestProcessToolSievePartialXMLTagHeldBack(t *testing.T) {
|
||||||
var state toolStreamSieveState
|
var state toolStreamSieveState
|
||||||
// Chunk ends with a partial XML tool tag.
|
// Chunk ends with a partial XML tool tag.
|
||||||
|
|||||||
20
internal/adapter/openai/upstream_empty.go
Normal file
20
internal/adapter/openai/upstream_empty.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"ds2api/internal/sse"
|
||||||
|
)
|
||||||
|
|
||||||
|
func writeUpstreamEmptyOutputError(w http.ResponseWriter, result sse.CollectResult) bool {
|
||||||
|
if strings.TrimSpace(result.Thinking) != "" || strings.TrimSpace(sanitizeLeakedOutput(result.Text)) != "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if result.ContentFilter {
|
||||||
|
writeOpenAIErrorWithCode(w, http.StatusBadRequest, "Upstream content filtered the response and returned no output.", "content_filter")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
writeOpenAIErrorWithCode(w, http.StatusBadGateway, "Upstream model returned empty output.", "upstream_empty_output")
|
||||||
|
return true
|
||||||
|
}
|
||||||
@@ -10,9 +10,10 @@ import (
|
|||||||
// CollectResult holds the aggregated text and thinking content from a
|
// CollectResult holds the aggregated text and thinking content from a
|
||||||
// DeepSeek SSE stream, consumed to completion (non-streaming use case).
|
// DeepSeek SSE stream, consumed to completion (non-streaming use case).
|
||||||
type CollectResult struct {
|
type CollectResult struct {
|
||||||
Text string
|
Text string
|
||||||
Thinking string
|
Thinking string
|
||||||
OutputTokens int
|
OutputTokens int
|
||||||
|
ContentFilter bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// CollectStream fully consumes a DeepSeek SSE response and separates
|
// CollectStream fully consumes a DeepSeek SSE response and separates
|
||||||
@@ -28,6 +29,7 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
|
|||||||
text := strings.Builder{}
|
text := strings.Builder{}
|
||||||
thinking := strings.Builder{}
|
thinking := strings.Builder{}
|
||||||
outputTokens := 0
|
outputTokens := 0
|
||||||
|
contentFilter := false
|
||||||
currentType := "text"
|
currentType := "text"
|
||||||
if thinkingEnabled {
|
if thinkingEnabled {
|
||||||
currentType = "thinking"
|
currentType = "thinking"
|
||||||
@@ -39,6 +41,9 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if result.Stop {
|
if result.Stop {
|
||||||
|
if result.ContentFilter {
|
||||||
|
contentFilter = true
|
||||||
|
}
|
||||||
if result.OutputTokens > 0 {
|
if result.OutputTokens > 0 {
|
||||||
outputTokens = result.OutputTokens
|
outputTokens = result.OutputTokens
|
||||||
}
|
}
|
||||||
@@ -56,5 +61,10 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
|
|||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
return CollectResult{Text: text.String(), Thinking: thinking.String(), OutputTokens: outputTokens}
|
return CollectResult{
|
||||||
|
Text: text.String(),
|
||||||
|
Thinking: thinking.String(),
|
||||||
|
OutputTokens: outputTokens,
|
||||||
|
ContentFilter: contentFilter,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ func filterLeakedContentFilterParts(parts []ContentPart) []ContentPart {
|
|||||||
out := make([]ContentPart, 0, len(parts))
|
out := make([]ContentPart, 0, len(parts))
|
||||||
for _, p := range parts {
|
for _, p := range parts {
|
||||||
cleaned := stripLeakedContentFilterSuffix(p.Text)
|
cleaned := stripLeakedContentFilterSuffix(p.Text)
|
||||||
if strings.TrimSpace(cleaned) == "" {
|
if shouldDropCleanedLeakedChunk(cleaned) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
p.Text = cleaned
|
p.Text = cleaned
|
||||||
@@ -27,5 +27,19 @@ func stripLeakedContentFilterSuffix(text string) string {
|
|||||||
if idx < 0 {
|
if idx < 0 {
|
||||||
return text
|
return text
|
||||||
}
|
}
|
||||||
return strings.TrimRight(text[:idx], " \t\r\n")
|
// Keep "\n" so we don't collapse line structure when the upstream model
|
||||||
|
// appends leaked CONTENT_FILTER markers after a line break.
|
||||||
|
return strings.TrimRight(text[:idx], " \t\r")
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldDropCleanedLeakedChunk reports whether a chunk, after leaked
// CONTENT_FILTER text has been stripped from it, should be discarded.
//
// A chunk is dropped when it is empty or consists solely of whitespace that
// contains no newline. Chunks containing a newline are always kept so that
// line breaks preceding a leaked marker are not lost.
func shouldDropCleanedLeakedChunk(cleaned string) bool {
	switch {
	case strings.Contains(cleaned, "\n"):
		// Newline-bearing chunks carry line structure; never drop them.
		return false
	default:
		// Covers both the empty string and blank, newline-free chunks.
		return strings.TrimSpace(cleaned) == ""
	}
}
|
||||||
|
|||||||
@@ -102,3 +102,23 @@ func TestParseDeepSeekContentLineContentTextEqualContentFilterDoesNotStop(t *tes
|
|||||||
t.Fatalf("did not expect content-filter stop for content text: %#v", res)
|
t.Fatalf("did not expect content-filter stop for content text: %#v", res)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseDeepSeekContentLinePreservesTrailingNewlineBeforeLeakedContentFilter(t *testing.T) {
|
||||||
|
res := ParseDeepSeekContentLine([]byte("data: {\"p\":\"response/content\",\"v\":\"line1\\nCONTENT_FILTERblocked\"}"), false, "text")
|
||||||
|
if !res.Parsed || res.Stop {
|
||||||
|
t.Fatalf("expected parsed non-stop result: %#v", res)
|
||||||
|
}
|
||||||
|
if len(res.Parts) != 1 || res.Parts[0].Text != "line1\n" {
|
||||||
|
t.Fatalf("expected trailing newline preserved, got %#v", res.Parts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseDeepSeekContentLineKeepsNewlineOnlyChunkBeforeLeakedContentFilter(t *testing.T) {
|
||||||
|
res := ParseDeepSeekContentLine([]byte("data: {\"p\":\"response/content\",\"v\":\"\\nCONTENT_FILTERblocked\"}"), false, "text")
|
||||||
|
if !res.Parsed || res.Stop {
|
||||||
|
t.Fatalf("expected parsed non-stop result: %#v", res)
|
||||||
|
}
|
||||||
|
if len(res.Parts) != 1 || res.Parts[0].Text != "\n" {
|
||||||
|
t.Fatalf("expected newline-only chunk preserved, got %#v", res.Parts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -19,6 +19,10 @@ var toolUseFunctionPattern = regexp.MustCompile(`(?is)<tool_use>\s*<function\s+n
|
|||||||
var toolUseNameParametersPattern = regexp.MustCompile(`(?is)<tool_use>\s*<tool_name>\s*([^<]+?)\s*</tool_name>\s*<parameters>\s*(.*?)\s*</parameters>\s*</tool_use>`)
|
var toolUseNameParametersPattern = regexp.MustCompile(`(?is)<tool_use>\s*<tool_name>\s*([^<]+?)\s*</tool_name>\s*<parameters>\s*(.*?)\s*</parameters>\s*</tool_use>`)
|
||||||
var toolUseFunctionNameParametersPattern = regexp.MustCompile(`(?is)<tool_use>\s*<function_name>\s*([^<]+?)\s*</function_name>\s*<parameters>\s*(.*?)\s*</parameters>\s*</tool_use>`)
|
var toolUseFunctionNameParametersPattern = regexp.MustCompile(`(?is)<tool_use>\s*<function_name>\s*([^<]+?)\s*</function_name>\s*<parameters>\s*(.*?)\s*</parameters>\s*</tool_use>`)
|
||||||
var toolUseToolNameBodyPattern = regexp.MustCompile(`(?is)<tool_use>\s*<tool_name>\s*([^<]+?)\s*</tool_name>\s*(.*?)\s*</tool_use>`)
|
var toolUseToolNameBodyPattern = regexp.MustCompile(`(?is)<tool_use>\s*<tool_name>\s*([^<]+?)\s*</tool_name>\s*(.*?)\s*</tool_use>`)
|
||||||
|
// xmlToolNamePatterns lists, in priority order, the regular expressions used
// to pull a tool name out of XML-style markup. Each pattern matches
// case-insensitively across newlines, tolerates an optional namespace-style
// prefix and attributes on the opening tag, and captures the element body in
// group 1.
var xmlToolNamePatterns = []*regexp.Regexp{
	// <tool_name>…</tool_name>
	regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?tool_name\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?tool_name>`),
	// <function_name>…</function_name>
	regexp.MustCompile(`(?is)<(?:[a-z0-9_:-]+:)?function_name\b[^>]*>(.*?)</(?:[a-z0-9_:-]+:)?function_name>`),
}
|
||||||
|
|
||||||
func parseXMLToolCalls(text string) []ParsedToolCall {
|
func parseXMLToolCalls(text string) []ParsedToolCall {
|
||||||
matches := xmlToolCallPattern.FindAllString(text, -1)
|
matches := xmlToolCallPattern.FindAllString(text, -1)
|
||||||
@@ -81,9 +85,9 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dec := xml.NewDecoder(strings.NewReader(block))
|
|
||||||
name := ""
|
name := ""
|
||||||
params := map[string]any{}
|
params := extractXMLToolParamsByRegex(inner)
|
||||||
|
dec := xml.NewDecoder(strings.NewReader(block))
|
||||||
inParams := false
|
inParams := false
|
||||||
inTool := false
|
inTool := false
|
||||||
for {
|
for {
|
||||||
@@ -132,9 +136,13 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
inParams = false
|
inParams = false
|
||||||
case "tool_name", "name":
|
case "tool_name", "function_name", "name":
|
||||||
var v string
|
var v string
|
||||||
if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" {
|
if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" {
|
||||||
|
if inParams {
|
||||||
|
params[t.Name.Local] = strings.TrimSpace(v)
|
||||||
|
break
|
||||||
|
}
|
||||||
name = strings.TrimSpace(v)
|
name = strings.TrimSpace(v)
|
||||||
}
|
}
|
||||||
case "input", "arguments", "argument", "args", "params":
|
case "input", "arguments", "argument", "args", "params":
|
||||||
@@ -164,12 +172,60 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(name) == "" {
|
||||||
|
name = strings.TrimSpace(extractXMLToolNameByRegex(stripTopLevelXMLParameters(inner)))
|
||||||
|
}
|
||||||
if strings.TrimSpace(name) == "" {
|
if strings.TrimSpace(name) == "" {
|
||||||
return ParsedToolCall{}, false
|
return ParsedToolCall{}, false
|
||||||
}
|
}
|
||||||
return ParsedToolCall{Name: strings.TrimSpace(name), Input: params}, true
|
return ParsedToolCall{Name: strings.TrimSpace(name), Input: params}, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// stripTopLevelXMLParameters returns inner, trimmed of surrounding whitespace,
// with every "<parameters ...>...</parameters>" span removed. Tag names are
// matched case-insensitively.
//
// Edge cases:
//   - an opening "<parameters" with no ">" anywhere after it leaves the text
//     untouched from that point on;
//   - an opening tag with no matching "</parameters>" truncates the text at
//     the opening tag.
//
// Offsets are computed directly on the original string using ASCII-only case
// folding. Lower-casing the whole string first is unsafe because Unicode case
// mapping can change a rune's encoded length (e.g. U+212A KELVIN SIGN is three
// bytes but lower-cases to the one-byte "k"), which would make offsets found
// in the lowered copy point at the wrong bytes of the original.
func stripTopLevelXMLParameters(inner string) string {
	const (
		openTag  = "<parameters"
		closeTag = "</parameters>"
	)
	out := strings.TrimSpace(inner)
	for {
		idx := indexASCIIFold(out, openTag)
		if idx < 0 {
			return out
		}
		segment := out[idx:]
		if !strings.Contains(segment, ">") {
			// Opening tag is never terminated; nothing sensible to strip.
			return out
		}
		closeIdx := indexASCIIFold(segment, closeTag)
		if closeIdx < 0 {
			// Unclosed parameters block: drop everything from the opening tag.
			return out[:idx]
		}
		out = out[:idx] + segment[closeIdx+len(closeTag):]
	}
}

// indexASCIIFold returns the byte index of the first occurrence of lowerSub in
// s, treating ASCII letters in s case-insensitively, or -1 if there is none.
// lowerSub must already be lower-case ASCII.
func indexASCIIFold(s, lowerSub string) int {
	n := len(lowerSub)
	for i := 0; i+n <= len(s); i++ {
		matched := true
		for j := 0; j < n; j++ {
			c := s[i+j]
			if 'A' <= c && c <= 'Z' {
				c += 'a' - 'A'
			}
			if c != lowerSub[j] {
				matched = false
				break
			}
		}
		if matched {
			return i
		}
	}
	return -1
}
|
||||||
|
|
||||||
|
func extractXMLToolNameByRegex(inner string) string {
|
||||||
|
for _, pattern := range xmlToolNamePatterns {
|
||||||
|
if m := pattern.FindStringSubmatch(inner); len(m) >= 2 {
|
||||||
|
if v := strings.TrimSpace(stripTagText(m[1])); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractXMLToolParamsByRegex(inner string) map[string]any {
|
||||||
|
raw := findMarkupTagValue(inner, toolCallMarkupArgsTagNames, toolCallMarkupArgsPatternByTag)
|
||||||
|
if raw == "" {
|
||||||
|
return map[string]any{}
|
||||||
|
}
|
||||||
|
parsed := parseMarkupInput(raw)
|
||||||
|
if parsed == nil {
|
||||||
|
return map[string]any{}
|
||||||
|
}
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
|
||||||
func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) {
|
func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) {
|
||||||
m := functionCallPattern.FindStringSubmatch(text)
|
m := functionCallPattern.FindStringSubmatch(text)
|
||||||
if len(m) < 2 {
|
if len(m) < 2 {
|
||||||
|
|||||||
@@ -176,6 +176,35 @@ func TestParseToolCallsSupportsCanonicalXMLParametersJSON(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsSupportsXMLParametersJSONWithAmpersandCommand(t *testing.T) {
|
||||||
|
text := `<tool_calls><tool_call><tool_name>execute_command</tool_name><parameters>{"command":"sshpass -p 'xxx' ssh -o StrictHostKeyChecking=no -p 1111 root@111.111.111.111 'cd /root && git clone https://github.com/ericc-ch/copilot-api.git'","cwd":null,"timeout":null}</parameters></tool_call></tool_calls>`
|
||||||
|
calls := ParseToolCalls(text, []string{"execute_command"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "execute_command" {
|
||||||
|
t.Fatalf("expected tool name execute_command, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
cmd, _ := calls[0].Input["command"].(string)
|
||||||
|
if !strings.Contains(cmd, "&& git clone") {
|
||||||
|
t.Fatalf("expected command to keep && segment, got %#v", calls[0].Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsDoesNotTreatParameterNameTagAsToolName(t *testing.T) {
|
||||||
|
text := `<tool_call><tool name="execute_command"><parameters><name>file.txt</name><command>pwd</command></parameters></tool></tool_call>`
|
||||||
|
calls := ParseToolCalls(text, []string{"execute_command"})
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 call, got %#v", calls)
|
||||||
|
}
|
||||||
|
if calls[0].Name != "execute_command" {
|
||||||
|
t.Fatalf("expected tool name execute_command, got %q", calls[0].Name)
|
||||||
|
}
|
||||||
|
if calls[0].Input["name"] != "file.txt" {
|
||||||
|
t.Fatalf("expected parameter name preserved, got %#v", calls[0].Input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestParseToolCallsPrefersJSONPayloadOverIncidentalXMLInString(t *testing.T) {
|
func TestParseToolCallsPrefersJSONPayloadOverIncidentalXMLInString(t *testing.T) {
|
||||||
text := `{"tool_calls":[{"name":"search","input":{"q":"latest <tool_call><tool_name>wrong</tool_name><parameters>{\"x\":1}</parameters></tool_call>"}}]}`
|
text := `{"tool_calls":[{"name":"search","input":{"q":"latest <tool_call><tool_name>wrong</tool_name><parameters>{\"x\":1}</parameters></tool_call>"}}]}`
|
||||||
calls := ParseToolCallsDetailed(text, []string{"search"}).Calls
|
calls := ParseToolCallsDetailed(text, []string{"search"}).Calls
|
||||||
@@ -402,6 +431,14 @@ func TestParseToolCallsDoesNotAcceptMismatchedMarkupTags(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseToolCallsDoesNotTreatParametersFunctionNameAsToolName(t *testing.T) {
|
||||||
|
text := `<tool_call><parameters><function_name>data_only</function_name><path>README.md</path></parameters></tool_call>`
|
||||||
|
calls := ParseToolCalls(text, []string{"read_file"})
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool call when function_name appears only under parameters, got %#v", calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRepairInvalidJSONBackslashes(t *testing.T) {
|
func TestRepairInvalidJSONBackslashes(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
input string
|
input string
|
||||||
|
|||||||
Reference in New Issue
Block a user