mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-11 11:47:43 +08:00
test: Introduce comprehensive edge case tests across multiple modules and refine tool call and OpenAI handler logic.
This commit is contained in:
@@ -241,6 +241,35 @@ func TestHandleNonStreamEmbeddedToolCallExampleNotIntercepted(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamFencedToolCallExampleNotIntercepted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
"data: {\"p\":\"response/content\",\"v\":\"```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"search\\\",\\\"input\\\":{\\\"q\\\":\\\"go\\\"}}]}\\n```\"}",
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2d", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
if _, ok := msg["tool_calls"]; ok {
|
||||
t.Fatalf("did not expect tool_calls field for fenced example: %#v", msg["tool_calls"])
|
||||
}
|
||||
content, _ := msg["content"].(string)
|
||||
if !strings.Contains(content, "```json") || !strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected fenced tool example to pass through as text, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
@@ -428,9 +457,9 @@ func TestHandleStreamToolsPlainTextStreamsBeforeFinish(t *testing.T) {
|
||||
func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"前置正文A。"}`,
|
||||
`data: {"p":"response/content","v":"下面是示例:"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: {"p":"response/content","v":"后置正文B。"}`,
|
||||
`data: {"p":"response/content","v":"请勿执行。"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -457,7 +486,7 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "前置正文A。") || !strings.Contains(got, "后置正文B。") {
|
||||
if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") {
|
||||
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
|
||||
}
|
||||
if !strings.Contains(got, `"tool_calls"`) {
|
||||
@@ -468,6 +497,48 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"我将调用工具。"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7b", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "我将调用工具。") {
|
||||
t.Fatalf("expected leading text to keep streaming, got=%q", got)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("unexpected raw tool json leak, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) {
|
||||
h := &Handler{}
|
||||
spaces := strings.Repeat(" ", 200)
|
||||
|
||||
@@ -11,6 +11,7 @@ type toolStreamSieveState struct {
|
||||
capture strings.Builder
|
||||
capturing bool
|
||||
hasMeaningfulText bool
|
||||
recentTextTail string
|
||||
toolNameSent bool
|
||||
toolName string
|
||||
toolArgsStart int
|
||||
@@ -32,6 +33,7 @@ type toolCallDelta struct {
|
||||
}
|
||||
|
||||
const toolSieveCaptureLimit = 8 * 1024
|
||||
const toolSieveContextTailLimit = 256
|
||||
|
||||
func (s *toolStreamSieveState) resetIncrementalToolState() {
|
||||
s.toolNameSent = false
|
||||
@@ -67,9 +69,7 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
state.capture.Reset()
|
||||
state.capturing = false
|
||||
state.resetIncrementalToolState()
|
||||
if strings.TrimSpace(content) != "" {
|
||||
state.hasMeaningfulText = true
|
||||
}
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
continue
|
||||
}
|
||||
@@ -79,9 +79,7 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
state.capturing = false
|
||||
state.resetIncrementalToolState()
|
||||
if prefix != "" {
|
||||
if strings.TrimSpace(prefix) != "" {
|
||||
state.hasMeaningfulText = true
|
||||
}
|
||||
state.noteText(prefix)
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
if len(calls) > 0 {
|
||||
@@ -101,9 +99,7 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
if start >= 0 {
|
||||
prefix := pending[:start]
|
||||
if prefix != "" {
|
||||
if strings.TrimSpace(prefix) != "" {
|
||||
state.hasMeaningfulText = true
|
||||
}
|
||||
state.noteText(prefix)
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
state.pending.Reset()
|
||||
@@ -119,9 +115,7 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
}
|
||||
state.pending.Reset()
|
||||
state.pending.WriteString(hold)
|
||||
if strings.TrimSpace(safe) != "" {
|
||||
state.hasMeaningfulText = true
|
||||
}
|
||||
state.noteText(safe)
|
||||
events = append(events, toolStreamEvent{Content: safe})
|
||||
}
|
||||
|
||||
@@ -137,26 +131,20 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea
|
||||
consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames)
|
||||
if ready {
|
||||
if consumedPrefix != "" {
|
||||
if strings.TrimSpace(consumedPrefix) != "" {
|
||||
state.hasMeaningfulText = true
|
||||
}
|
||||
state.noteText(consumedPrefix)
|
||||
events = append(events, toolStreamEvent{Content: consumedPrefix})
|
||||
}
|
||||
if len(consumedCalls) > 0 {
|
||||
events = append(events, toolStreamEvent{ToolCalls: consumedCalls})
|
||||
}
|
||||
if consumedSuffix != "" {
|
||||
if strings.TrimSpace(consumedSuffix) != "" {
|
||||
state.hasMeaningfulText = true
|
||||
}
|
||||
state.noteText(consumedSuffix)
|
||||
events = append(events, toolStreamEvent{Content: consumedSuffix})
|
||||
}
|
||||
} else {
|
||||
content := state.capture.String()
|
||||
if content != "" {
|
||||
if strings.TrimSpace(content) != "" {
|
||||
state.hasMeaningfulText = true
|
||||
}
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
}
|
||||
}
|
||||
@@ -166,9 +154,7 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea
|
||||
}
|
||||
if state.pending.Len() > 0 {
|
||||
content := state.pending.String()
|
||||
if strings.TrimSpace(content) != "" {
|
||||
state.hasMeaningfulText = true
|
||||
}
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
state.pending.Reset()
|
||||
}
|
||||
@@ -241,7 +227,7 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
|
||||
}
|
||||
prefixPart := captured[:start]
|
||||
suffixPart := captured[end:]
|
||||
if !state.toolNameSent && (state.hasMeaningfulText || strings.TrimSpace(prefixPart) != "" || strings.TrimSpace(suffixPart) != "") {
|
||||
if !state.toolNameSent && (strings.TrimSpace(prefixPart) != "" || strings.TrimSpace(suffixPart) != "" || looksLikeToolExampleContext(state.recentTextTail)) {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
parsed := util.ParseStandaloneToolCalls(obj, toolNames)
|
||||
@@ -304,7 +290,10 @@ func extractJSONObjectFrom(text string, start int) (string, int, bool) {
|
||||
|
||||
func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
|
||||
captured := state.capture.String()
|
||||
if captured == "" || state.hasMeaningfulText {
|
||||
if captured == "" {
|
||||
return nil
|
||||
}
|
||||
if looksLikeToolExampleContext(state.recentTextTail) {
|
||||
return nil
|
||||
}
|
||||
lower := strings.ToLower(captured)
|
||||
@@ -618,3 +607,46 @@ func skipSpaces(text string, i int) int {
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func (s *toolStreamSieveState) noteText(content string) {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return
|
||||
}
|
||||
s.hasMeaningfulText = true
|
||||
s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit)
|
||||
}
|
||||
|
||||
func appendTail(prev, next string, max int) string {
|
||||
if max <= 0 {
|
||||
return ""
|
||||
}
|
||||
combined := prev + next
|
||||
if len(combined) <= max {
|
||||
return combined
|
||||
}
|
||||
return combined[len(combined)-max:]
|
||||
}
|
||||
|
||||
func looksLikeToolExampleContext(text string) bool {
|
||||
t := strings.ToLower(strings.TrimSpace(text))
|
||||
if t == "" {
|
||||
return false
|
||||
}
|
||||
cues := []string{
|
||||
"示例",
|
||||
"例子",
|
||||
"for example",
|
||||
"example",
|
||||
"demo",
|
||||
"请勿执行",
|
||||
"不要执行",
|
||||
"do not execute",
|
||||
"```",
|
||||
}
|
||||
for _, cue := range cues {
|
||||
if strings.Contains(t, cue) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -41,6 +41,9 @@ func ParseStandaloneToolCalls(text string, availableToolNames []string) []Parsed
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
if looksLikeToolExampleContext(trimmed) {
|
||||
return nil
|
||||
}
|
||||
candidates := []string{trimmed}
|
||||
if strings.HasPrefix(trimmed, "```") && strings.HasSuffix(trimmed, "```") {
|
||||
if m := fencedJSONPattern.FindStringSubmatch(trimmed); len(m) >= 2 {
|
||||
@@ -313,6 +316,30 @@ func extractJSONObject(text string, start int) (string, int, bool) {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func looksLikeToolExampleContext(text string) bool {
|
||||
t := strings.ToLower(strings.TrimSpace(text))
|
||||
if t == "" {
|
||||
return false
|
||||
}
|
||||
cues := []string{
|
||||
"```",
|
||||
"示例",
|
||||
"例子",
|
||||
"for example",
|
||||
"example",
|
||||
"demo",
|
||||
"请勿执行",
|
||||
"不要执行",
|
||||
"do not execute",
|
||||
}
|
||||
for _, cue := range cues {
|
||||
if strings.Contains(t, cue) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {
|
||||
out := make([]map[string]any, 0, len(calls))
|
||||
for _, c := range calls {
|
||||
|
||||
@@ -75,3 +75,10 @@ func TestParseStandaloneToolCallsOnlyMatchesStandalonePayload(t *testing.T) {
|
||||
t.Fatalf("expected standalone parser to match, got %#v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStandaloneToolCallsIgnoresFencedCodeBlock(t *testing.T) {
|
||||
fenced := "```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```"
|
||||
if calls := ParseStandaloneToolCalls(fenced, []string{"search"}); len(calls) != 0 {
|
||||
t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user