mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-02 07:25:26 +08:00
工具调用优化
This commit is contained in:
@@ -98,7 +98,7 @@ func (s *chatStreamRuntime) sendDone() {
|
||||
func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := s.text.String()
|
||||
detected := util.ParseToolCalls(finalText, s.toolNames)
|
||||
detected := util.ParseStandaloneToolCalls(finalText, s.toolNames)
|
||||
if len(detected) > 0 && !s.toolCallsDoneEmitted {
|
||||
finishReason = "tool_calls"
|
||||
delta := map[string]any{
|
||||
|
||||
@@ -211,7 +211,7 @@ func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamEmbeddedToolCallExampleIntercepted(t *testing.T) {
|
||||
func TestHandleNonStreamEmbeddedToolCallExampleRemainsText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"下面是示例:"}`,
|
||||
@@ -229,16 +229,16 @@ func TestHandleNonStreamEmbeddedToolCallExampleIntercepted(t *testing.T) {
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
|
||||
if choice["finish_reason"] != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
toolCalls, _ := msg["tool_calls"].([]any)
|
||||
if len(toolCalls) == 0 {
|
||||
t.Fatalf("expected tool_calls field for embedded example: %#v", msg["tool_calls"])
|
||||
if _, ok := msg["tool_calls"]; ok {
|
||||
t.Fatalf("did not expect tool_calls field for embedded example: %#v", msg["tool_calls"])
|
||||
}
|
||||
if msg["content"] != nil {
|
||||
t.Fatalf("expected content nil when tool_calls detected, got %#v", msg["content"])
|
||||
content, _ := msg["content"].(string)
|
||||
if !strings.Contains(content, "下面是示例:") || !strings.Contains(content, "请勿执行。") || !strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected embedded example to remain plain text, got %#v", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -513,8 +513,8 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -531,15 +531,15 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") {
|
||||
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(got), `"tool_calls"`) {
|
||||
t.Fatalf("expected no raw tool_calls json leak in content, got=%q", got)
|
||||
if !strings.Contains(strings.ToLower(got), `"tool_calls"`) {
|
||||
t.Fatalf("expected embedded tool json to remain text in strict mode, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String())
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop for mixed prose, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) {
|
||||
func TestHandleStreamToolCallAfterLeadingTextRemainsText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"我将调用工具。"}`,
|
||||
@@ -555,8 +555,8 @@ func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -573,15 +573,15 @@ func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) {
|
||||
if !strings.Contains(got, "我将调用工具。") {
|
||||
t.Fatalf("expected leading text to keep streaming, got=%q", got)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("unexpected raw tool json leak, got=%q", got)
|
||||
if !strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected tool_calls example text preserved, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testing.T) {
|
||||
func TestHandleStreamToolCallWithSameChunkTrailingTextRemainsText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}接下来我会继续说明。"}`,
|
||||
@@ -596,8 +596,8 @@ func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testin
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -614,15 +614,15 @@ func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testin
|
||||
if !strings.Contains(got, "接下来我会继续说明。") {
|
||||
t.Fatalf("expected trailing plain text to be preserved, got=%q", got)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("unexpected raw tool json leak, got=%q", got)
|
||||
if !strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected tool_calls example text preserved, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) {
|
||||
func TestHandleStreamToolCallKeyAppearsLateRemainsText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
spaces := strings.Repeat(" ", 200)
|
||||
resp := makeSSEHTTPResponse(
|
||||
@@ -640,11 +640,8 @@ func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -658,14 +655,14 @@ func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) {
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if strings.Contains(got, "{") {
|
||||
t.Fatalf("unexpected suspicious prefix leak in content: %q", got)
|
||||
if !strings.Contains(strings.ToLower(got), "tool_calls") || !strings.Contains(got, "{") {
|
||||
t.Fatalf("expected embedded tool json to remain in text, got=%q", got)
|
||||
}
|
||||
if !strings.Contains(got, "后置正文C。") {
|
||||
t.Fatalf("expected stream to continue after tool json convergence, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -743,7 +740,7 @@ func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testin
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) {
|
||||
func TestHandleStreamToolCallArgumentsEmitAsSingleCompletedChunk(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`,
|
||||
@@ -766,8 +763,8 @@ func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
}
|
||||
argChunks := streamToolCallArgumentChunks(frames)
|
||||
if len(argChunks) < 2 {
|
||||
t.Fatalf("expected incremental arguments chunks, got=%v body=%s", argChunks, rec.Body.String())
|
||||
if len(argChunks) == 0 {
|
||||
t.Fatalf("expected tool call arguments chunk, got=%v body=%s", argChunks, rec.Body.String())
|
||||
}
|
||||
joined := strings.Join(argChunks, "")
|
||||
if !strings.Contains(joined, `"q":"golang"`) || !strings.Contains(joined, `"page":1`) {
|
||||
|
||||
@@ -3,7 +3,6 @@ package openai
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
@@ -175,30 +174,11 @@ func normalizeToolArgumentString(raw string) string {
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if !looksLikeConcatenatedJSON(trimmed) {
|
||||
return trimmed
|
||||
if looksLikeConcatenatedJSON(trimmed) {
|
||||
// Keep original payload to avoid silent argument rewrites.
|
||||
return raw
|
||||
}
|
||||
dec := json.NewDecoder(strings.NewReader(trimmed))
|
||||
values := make([]any, 0, 2)
|
||||
for {
|
||||
var v any
|
||||
if err := dec.Decode(&v); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
values = append(values, v)
|
||||
}
|
||||
if len(values) < 2 {
|
||||
return trimmed
|
||||
}
|
||||
last := values[len(values)-1]
|
||||
b, err := json.Marshal(last)
|
||||
if err != nil || len(b) == 0 {
|
||||
return trimmed
|
||||
}
|
||||
return string(b)
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func marshalToPromptString(v any) string {
|
||||
|
||||
@@ -168,7 +168,7 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSepara
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIMessagesForPrompt_RepairsConcatenatedToolArguments(t *testing.T) {
|
||||
func TestNormalizeOpenAIMessagesForPrompt_PreservesConcatenatedToolArguments(t *testing.T) {
|
||||
raw := []any{
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
@@ -189,10 +189,7 @@ func TestNormalizeOpenAIMessagesForPrompt_RepairsConcatenatedToolArguments(t *te
|
||||
t.Fatalf("expected one normalized message, got %d", len(normalized))
|
||||
}
|
||||
content, _ := normalized[0]["content"].(string)
|
||||
if !strings.Contains(content, `function.arguments: {"query":"测试工具调用"}`) {
|
||||
t.Fatalf("expected repaired arguments in tool history, got %q", content)
|
||||
}
|
||||
if strings.Contains(content, `{}{"query":"测试工具调用"}`) {
|
||||
t.Fatalf("expected concatenated JSON to be repaired, got %q", content)
|
||||
if !strings.Contains(content, `function.arguments: {}{"query":"测试工具调用"}`) {
|
||||
t.Fatalf("expected original concatenated arguments in tool history, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,7 +135,7 @@ func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesInputAsMessagesFunctionCallItemRepairsConcatenatedArguments(t *testing.T) {
|
||||
func TestNormalizeResponsesInputAsMessagesFunctionCallItemPreservesConcatenatedArguments(t *testing.T) {
|
||||
msgs := normalizeResponsesInputAsMessages([]any{
|
||||
map[string]any{
|
||||
"type": "function_call",
|
||||
@@ -151,8 +151,8 @@ func TestNormalizeResponsesInputAsMessagesFunctionCallItemRepairsConcatenatedArg
|
||||
toolCalls, _ := m["tool_calls"].([]any)
|
||||
call, _ := toolCalls[0].(map[string]any)
|
||||
fn, _ := call["function"].(map[string]any)
|
||||
if fn["arguments"] != `{"q":"golang"}` {
|
||||
t.Fatalf("expected concatenated call arguments repaired, got %#v", fn["arguments"])
|
||||
if fn["arguments"] != `{}{"q":"golang"}` {
|
||||
t.Fatalf("expected original concatenated call arguments preserved, got %#v", fn["arguments"])
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -99,9 +99,6 @@ func TestHandleResponsesStreamUsesOfficialOutputItemEvents(t *testing.T) {
|
||||
if !strings.Contains(body, "event: response.output_item.done") {
|
||||
t.Fatalf("expected response.output_item.done event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.delta") {
|
||||
t.Fatalf("expected response.function_call_arguments.delta event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("expected response.function_call_arguments.done event, body=%s", body)
|
||||
}
|
||||
@@ -360,7 +357,7 @@ func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *testing.T) {
|
||||
func TestHandleResponsesStreamMalformedToolJSONFallsBackToText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -373,7 +370,7 @@ func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *t
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
// invalid JSON (NaN) can still trigger incremental tool deltas before final parse rejects it
|
||||
// invalid JSON (NaN) should remain plain text in strict mode.
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
@@ -382,14 +379,11 @@ func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *t
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.delta") {
|
||||
t.Fatalf("expected response.function_call_arguments.delta event for malformed payload, body=%s", body)
|
||||
if strings.Contains(body, "event: response.function_call_arguments.delta") || strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("did not expect function_call events for malformed payload in strict mode, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("expected runtime to close in-progress function_call with done event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.output_item.done") {
|
||||
t.Fatalf("expected runtime to close function output item, body=%s", body)
|
||||
if !strings.Contains(body, "event: response.output_text.delta") {
|
||||
t.Fatalf("expected response.output_text.delta for malformed payload, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.completed") {
|
||||
t.Fatalf("expected response.completed event, body=%s", body)
|
||||
|
||||
@@ -167,19 +167,15 @@ func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) {
|
||||
t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String())
|
||||
}
|
||||
outputText, _ := out["output_text"].(string)
|
||||
if outputText != "" {
|
||||
t.Fatalf("expected output_text hidden for tool call payload, got %q", outputText)
|
||||
if outputText == "" {
|
||||
t.Fatalf("expected output_text preserved for mixed prose payload")
|
||||
}
|
||||
output, _ := out["output"].([]any)
|
||||
hasFunctionCall := false
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m != nil && m["type"] == "function_call" {
|
||||
hasFunctionCall = true
|
||||
break
|
||||
}
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected one output item, got %#v", output)
|
||||
}
|
||||
if !hasFunctionCall {
|
||||
t.Fatalf("expected function_call output item, got %#v", output)
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "message" {
|
||||
t.Fatalf("expected message output item, got %#v", output)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,21 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
state.pending.WriteString(chunk)
|
||||
}
|
||||
events := make([]toolStreamEvent, 0, 2)
|
||||
if len(state.pendingToolCalls) > 0 {
|
||||
pending := state.pending.String()
|
||||
if strings.TrimSpace(pending) != "" {
|
||||
content := state.pendingToolRaw + pending
|
||||
state.pending.Reset()
|
||||
state.pendingToolRaw = ""
|
||||
state.pendingToolCalls = nil
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
} else {
|
||||
// Wait for either more non-whitespace content (demote to plain text)
|
||||
// or stream flush (promote to executable tool calls).
|
||||
return events
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
if state.capturing {
|
||||
@@ -21,23 +36,23 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
state.capture.WriteString(state.pending.String())
|
||||
state.pending.Reset()
|
||||
}
|
||||
if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 {
|
||||
events = append(events, toolStreamEvent{ToolCallDeltas: deltas})
|
||||
}
|
||||
prefix, calls, suffix, ready := consumeToolCapture(state, toolNames)
|
||||
if !ready {
|
||||
break
|
||||
}
|
||||
captured := state.capture.String()
|
||||
state.capture.Reset()
|
||||
state.capturing = false
|
||||
state.resetIncrementalToolState()
|
||||
if len(calls) > 0 {
|
||||
state.pendingToolRaw = captured
|
||||
state.pendingToolCalls = calls
|
||||
continue
|
||||
}
|
||||
if prefix != "" {
|
||||
state.noteText(prefix)
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
if len(calls) > 0 {
|
||||
events = append(events, toolStreamEvent{ToolCalls: calls})
|
||||
}
|
||||
if suffix != "" {
|
||||
state.pending.WriteString(suffix)
|
||||
}
|
||||
@@ -80,6 +95,11 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea
|
||||
return nil
|
||||
}
|
||||
events := processToolSieveChunk(state, "", toolNames)
|
||||
if len(state.pendingToolCalls) > 0 {
|
||||
events = append(events, toolStreamEvent{ToolCalls: state.pendingToolCalls})
|
||||
state.pendingToolRaw = ""
|
||||
state.pendingToolCalls = nil
|
||||
}
|
||||
if state.capturing {
|
||||
consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames)
|
||||
if ready {
|
||||
@@ -191,6 +211,11 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
|
||||
if insideCodeFence(state.recentTextTail + prefixPart) {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
// Strict mode: only standalone tool payloads are executable. If the
|
||||
// payload is wrapped by non-whitespace prose, keep it as plain text.
|
||||
if strings.TrimSpace(state.recentTextTail) != "" || strings.TrimSpace(prefixPart) != "" || strings.TrimSpace(suffixPart) != "" {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
parsed := util.ParseStandaloneToolCallsDetailed(obj, toolNames)
|
||||
if len(parsed.Calls) == 0 {
|
||||
if parsed.SawToolCallSyntax && parsed.RejectedByPolicy {
|
||||
|
||||
@@ -7,17 +7,19 @@ import (
|
||||
)
|
||||
|
||||
type toolStreamSieveState struct {
|
||||
pending strings.Builder
|
||||
capture strings.Builder
|
||||
capturing bool
|
||||
recentTextTail string
|
||||
disableDeltas bool
|
||||
toolNameSent bool
|
||||
toolName string
|
||||
toolArgsStart int
|
||||
toolArgsSent int
|
||||
toolArgsString bool
|
||||
toolArgsDone bool
|
||||
pending strings.Builder
|
||||
capture strings.Builder
|
||||
capturing bool
|
||||
recentTextTail string
|
||||
pendingToolRaw string
|
||||
pendingToolCalls []util.ParsedToolCall
|
||||
disableDeltas bool
|
||||
toolNameSent bool
|
||||
toolName string
|
||||
toolArgsStart int
|
||||
toolArgsSent int
|
||||
toolArgsString bool
|
||||
toolArgsDone bool
|
||||
}
|
||||
|
||||
type toolStreamEvent struct {
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
detected := util.ParseStandaloneToolCalls(finalText, toolNames)
|
||||
finishReason := "stop"
|
||||
messageObj := map[string]any{"role": "assistant", "content": finalText}
|
||||
if strings.TrimSpace(finalThinking) != "" {
|
||||
|
||||
@@ -11,12 +11,9 @@ import (
|
||||
)
|
||||
|
||||
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
// Align responses tool-call semantics with chat/completions:
|
||||
// mixed prose + tool_call payloads should still be interpreted as tool calls.
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
|
||||
detected = util.ParseToolCalls(finalThinking, toolNames)
|
||||
}
|
||||
// Strict mode: only standalone, structured tool-call payloads are treated
|
||||
// as executable tool calls.
|
||||
detected := util.ParseStandaloneToolCalls(finalText, toolNames)
|
||||
exposedOutputText := finalText
|
||||
output := make([]any, 0, 2)
|
||||
if len(detected) > 0 {
|
||||
|
||||
@@ -45,7 +45,7 @@ func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T) {
|
||||
func TestBuildResponseObjectTreatsMixedProseToolPayloadAsText(t *testing.T) {
|
||||
obj := BuildResponseObject(
|
||||
"resp_test",
|
||||
"gpt-4o",
|
||||
@@ -56,17 +56,16 @@ func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T)
|
||||
)
|
||||
|
||||
outputText, _ := obj["output_text"].(string)
|
||||
if outputText != "" {
|
||||
t.Fatalf("expected output_text hidden once tool calls are detected, got %q", outputText)
|
||||
if outputText == "" {
|
||||
t.Fatalf("expected output_text preserved for mixed prose payload")
|
||||
}
|
||||
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected function_call output only, got %#v", obj["output"])
|
||||
t.Fatalf("expected one message output item, got %#v", obj["output"])
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "function_call" {
|
||||
t.Fatalf("expected first output type function_call, got %#v", first["type"])
|
||||
if first["type"] != "message" {
|
||||
t.Fatalf("expected message output type, got %#v", first["type"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,7 +126,7 @@ func TestBuildResponseObjectReasoningOnlyFallsBackToOutputText(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
|
||||
func TestBuildResponseObjectIgnoresToolCallFromThinkingChannel(t *testing.T) {
|
||||
obj := BuildResponseObject(
|
||||
"resp_test",
|
||||
"gpt-4o",
|
||||
@@ -139,10 +138,10 @@ func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
|
||||
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected function_call output only, got %#v", obj["output"])
|
||||
t.Fatalf("expected one message output item, got %#v", obj["output"])
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "function_call" {
|
||||
t.Fatalf("expected output function_call, got %#v", first["type"])
|
||||
if first["type"] != "message" {
|
||||
t.Fatalf("expected output message, got %#v", first["type"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,16 +57,20 @@ func NewApp() *App {
|
||||
r.Use(cors)
|
||||
r.Use(timeout(0))
|
||||
|
||||
r.Get("/healthz", func(w http.ResponseWriter, _ *http.Request) {
|
||||
healthzHandler := func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"status":"ok"}`))
|
||||
})
|
||||
r.Get("/readyz", func(w http.ResponseWriter, _ *http.Request) {
|
||||
}
|
||||
readyzHandler := func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"status":"ready"}`))
|
||||
})
|
||||
}
|
||||
r.Get("/healthz", healthzHandler)
|
||||
r.Head("/healthz", healthzHandler)
|
||||
r.Get("/readyz", readyzHandler)
|
||||
r.Head("/readyz", readyzHandler)
|
||||
openai.RegisterRoutes(r, openaiHandler)
|
||||
claude.RegisterRoutes(r, claudeHandler)
|
||||
gemini.RegisterRoutes(r, geminiHandler)
|
||||
|
||||
20
internal/server/router_health_test.go
Normal file
20
internal/server/router_health_test.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHealthEndpointsSupportHEAD(t *testing.T) {
|
||||
app := NewApp()
|
||||
|
||||
for _, path := range []string{"/healthz", "/readyz"} {
|
||||
req := httptest.NewRequest(http.MethodHead, path, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.Router.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected %s HEAD status 200, got %d", path, rec.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,12 @@ func (r *Runner) caseHealthz(ctx context.Context, cc *caseContext) error {
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(resp.Body, &m)
|
||||
cc.assert("status_ok", asString(m["status"]) == "ok", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
|
||||
headResp, headErr := cc.request(ctx, requestSpec{Method: http.MethodHead, Path: "/healthz", Retryable: true})
|
||||
if headErr != nil {
|
||||
return headErr
|
||||
}
|
||||
cc.assert("head_status_200", headResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", headResp.StatusCode))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -29,6 +35,12 @@ func (r *Runner) caseReadyz(ctx context.Context, cc *caseContext) error {
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(resp.Body, &m)
|
||||
cc.assert("status_ready", asString(m["status"]) == "ready", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
|
||||
headResp, headErr := cc.request(ctx, requestSpec{Method: http.MethodHead, Path: "/readyz", Retryable: true})
|
||||
if headErr != nil {
|
||||
return headErr
|
||||
}
|
||||
cc.assert("head_status_200", headResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", headResp.StatusCode))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user