mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-05 00:45:29 +08:00
fix: classify empty upstream and tighten xml tool-name parsing
This commit is contained in:
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
@@ -107,8 +106,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re
|
||||
|
||||
finalThinking := result.Thinking
|
||||
finalText := sanitizeLeakedOutput(result.Text)
|
||||
if strings.TrimSpace(finalThinking) == "" && strings.TrimSpace(finalText) == "" {
|
||||
writeOpenAIError(w, http.StatusTooManyRequests, "Upstream model returned empty output; please retry.")
|
||||
if writeUpstreamEmptyOutputError(w, result) {
|
||||
return
|
||||
}
|
||||
respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
|
||||
|
||||
@@ -275,7 +275,7 @@ func TestHandleNonStreamFencedToolCallExamplePromotesToolCall(t *testing.T) {
|
||||
TestHandleNonStreamFencedToolCallExampleDoesNotPromoteToolCall(t)
|
||||
}
|
||||
|
||||
func TestHandleNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T) {
|
||||
func TestHandleNonStreamReturns502WhenUpstreamOutputEmpty(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":""}`,
|
||||
@@ -284,14 +284,32 @@ func TestHandleNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid-empty", "deepseek-chat", "prompt", false, nil)
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected status 429 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
if rec.Code != http.StatusBadGateway {
|
||||
t.Fatalf("expected status 502 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
errObj, _ := out["error"].(map[string]any)
|
||||
msg, _ := errObj["message"].(string)
|
||||
if !strings.Contains(strings.ToLower(msg), "empty") {
|
||||
t.Fatalf("expected empty-output hint in error message, got %#v", out)
|
||||
if asString(errObj["code"]) != "upstream_empty_output" {
|
||||
t.Fatalf("expected code=upstream_empty_output, got %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWithoutOutput(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"code":"content_filter"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid-empty-filtered", "deepseek-chat", "prompt", false, nil)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400 for filtered upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
errObj, _ := out["error"].(map[string]any)
|
||||
if asString(errObj["code"]) != "content_filter" {
|
||||
t.Fatalf("expected code=content_filter, got %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -114,8 +114,7 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
|
||||
}
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
sanitizedText := sanitizeLeakedOutput(result.Text)
|
||||
if strings.TrimSpace(result.Thinking) == "" && strings.TrimSpace(sanitizedText) == "" {
|
||||
writeOpenAIError(w, http.StatusTooManyRequests, "Upstream model returned empty output; please retry.")
|
||||
if writeUpstreamEmptyOutputError(w, result) {
|
||||
return
|
||||
}
|
||||
textParsed := util.ParseStandaloneToolCallsDetailed(sanitizedText, toolNames)
|
||||
|
||||
@@ -627,7 +627,7 @@ func TestHandleResponsesNonStreamToolChoiceNoneStillAllowsFunctionCall(t *testin
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T) {
|
||||
func TestHandleResponsesNonStreamReturns502WhenUpstreamOutputEmpty(t *testing.T) {
|
||||
h := &Handler{}
|
||||
rec := httptest.NewRecorder()
|
||||
resp := &http.Response{
|
||||
@@ -639,14 +639,35 @@ func TestHandleResponsesNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T)
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, util.DefaultToolChoicePolicy(), "")
|
||||
if rec.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected 429 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
if rec.Code != http.StatusBadGateway {
|
||||
t.Fatalf("expected 502 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
errObj, _ := out["error"].(map[string]any)
|
||||
msg, _ := errObj["message"].(string)
|
||||
if !strings.Contains(strings.ToLower(msg), "empty") {
|
||||
t.Fatalf("expected empty-output message, got %#v", out)
|
||||
if asString(errObj["code"]) != "upstream_empty_output" {
|
||||
t.Fatalf("expected code=upstream_empty_output, got %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWithoutOutput(t *testing.T) {
|
||||
h := &Handler{}
|
||||
rec := httptest.NewRecorder()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`data: {"code":"content_filter"}` + "\n" +
|
||||
`data: [DONE]` + "\n",
|
||||
)),
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, util.DefaultToolChoicePolicy(), "")
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for filtered empty upstream output, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
errObj, _ := out["error"].(map[string]any)
|
||||
if asString(errObj["code"]) != "content_filter" {
|
||||
t.Fatalf("expected code=content_filter, got %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
20
internal/adapter/openai/upstream_empty.go
Normal file
20
internal/adapter/openai/upstream_empty.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/sse"
|
||||
)
|
||||
|
||||
func writeUpstreamEmptyOutputError(w http.ResponseWriter, result sse.CollectResult) bool {
|
||||
if strings.TrimSpace(result.Thinking) != "" || strings.TrimSpace(sanitizeLeakedOutput(result.Text)) != "" {
|
||||
return false
|
||||
}
|
||||
if result.ContentFilter {
|
||||
writeOpenAIErrorWithCode(w, http.StatusBadRequest, "Upstream content filtered the response and returned no output.", "content_filter")
|
||||
return true
|
||||
}
|
||||
writeOpenAIErrorWithCode(w, http.StatusBadGateway, "Upstream model returned empty output.", "upstream_empty_output")
|
||||
return true
|
||||
}
|
||||
@@ -10,9 +10,10 @@ import (
|
||||
// CollectResult holds the aggregated text and thinking content from a
|
||||
// DeepSeek SSE stream, consumed to completion (non-streaming use case).
|
||||
type CollectResult struct {
|
||||
Text string
|
||||
Thinking string
|
||||
OutputTokens int
|
||||
Text string
|
||||
Thinking string
|
||||
OutputTokens int
|
||||
ContentFilter bool
|
||||
}
|
||||
|
||||
// CollectStream fully consumes a DeepSeek SSE response and separates
|
||||
@@ -28,6 +29,7 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
|
||||
text := strings.Builder{}
|
||||
thinking := strings.Builder{}
|
||||
outputTokens := 0
|
||||
contentFilter := false
|
||||
currentType := "text"
|
||||
if thinkingEnabled {
|
||||
currentType = "thinking"
|
||||
@@ -39,6 +41,9 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
|
||||
return true
|
||||
}
|
||||
if result.Stop {
|
||||
if result.ContentFilter {
|
||||
contentFilter = true
|
||||
}
|
||||
if result.OutputTokens > 0 {
|
||||
outputTokens = result.OutputTokens
|
||||
}
|
||||
@@ -56,5 +61,10 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
|
||||
}
|
||||
return true
|
||||
})
|
||||
return CollectResult{Text: text.String(), Thinking: thinking.String(), OutputTokens: outputTokens}
|
||||
return CollectResult{
|
||||
Text: text.String(),
|
||||
Thinking: thinking.String(),
|
||||
OutputTokens: outputTokens,
|
||||
ContentFilter: contentFilter,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(extractXMLToolNameByRegex(inner))
|
||||
name := ""
|
||||
params := extractXMLToolParamsByRegex(inner)
|
||||
dec := xml.NewDecoder(strings.NewReader(block))
|
||||
inParams := false
|
||||
@@ -136,9 +136,13 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
|
||||
}
|
||||
}
|
||||
inParams = false
|
||||
case "tool_name", "name":
|
||||
case "tool_name", "function_name", "name":
|
||||
var v string
|
||||
if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" {
|
||||
if inParams {
|
||||
params[t.Name.Local] = strings.TrimSpace(v)
|
||||
break
|
||||
}
|
||||
name = strings.TrimSpace(v)
|
||||
}
|
||||
case "input", "arguments", "argument", "args", "params":
|
||||
@@ -168,12 +172,37 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(name) == "" {
|
||||
name = strings.TrimSpace(extractXMLToolNameByRegex(stripTopLevelXMLParameters(inner)))
|
||||
}
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
return ParsedToolCall{Name: strings.TrimSpace(name), Input: params}, true
|
||||
}
|
||||
|
||||
func stripTopLevelXMLParameters(inner string) string {
|
||||
out := strings.TrimSpace(inner)
|
||||
for {
|
||||
idx := strings.Index(strings.ToLower(out), "<parameters")
|
||||
if idx < 0 {
|
||||
return out
|
||||
}
|
||||
segment := out[idx:]
|
||||
segmentLower := strings.ToLower(segment)
|
||||
openEnd := strings.Index(segmentLower, ">")
|
||||
if openEnd < 0 {
|
||||
return out
|
||||
}
|
||||
closeIdx := strings.Index(segmentLower, "</parameters>")
|
||||
if closeIdx < 0 {
|
||||
return out[:idx]
|
||||
}
|
||||
end := idx + closeIdx + len("</parameters>")
|
||||
out = out[:idx] + out[end:]
|
||||
}
|
||||
}
|
||||
|
||||
func extractXMLToolNameByRegex(inner string) string {
|
||||
for _, pattern := range xmlToolNamePatterns {
|
||||
if m := pattern.FindStringSubmatch(inner); len(m) >= 2 {
|
||||
|
||||
@@ -431,6 +431,14 @@ func TestParseToolCallsDoesNotAcceptMismatchedMarkupTags(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsDoesNotTreatParametersFunctionNameAsToolName(t *testing.T) {
|
||||
text := `<tool_call><parameters><function_name>data_only</function_name><path>README.md</path></parameters></tool_call>`
|
||||
calls := ParseToolCalls(text, []string{"read_file"})
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected no tool call when function_name appears only under parameters, got %#v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairInvalidJSONBackslashes(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
|
||||
Reference in New Issue
Block a user