mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-10 19:27:41 +08:00
feat: add Gemini API compatibility, refactor stream rendering, and enhance tool call handling and configuration options
This commit is contained in:
@@ -112,14 +112,20 @@ func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, allowedNam
|
||||
return nil
|
||||
}
|
||||
allowed := namesToSet(allowedNames)
|
||||
if len(allowed) == 0 {
|
||||
for _, d := range deltas {
|
||||
if d.Name != "" {
|
||||
seenNames[d.Index] = "__blocked__"
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
out := make([]toolCallDelta, 0, len(deltas))
|
||||
for _, d := range deltas {
|
||||
if d.Name != "" {
|
||||
if len(allowed) > 0 {
|
||||
if _, ok := allowed[d.Name]; !ok {
|
||||
seenNames[d.Index] = "__blocked__"
|
||||
continue
|
||||
}
|
||||
if _, ok := allowed[d.Name]; !ok {
|
||||
seenNames[d.Index] = "__blocked__"
|
||||
continue
|
||||
}
|
||||
seenNames[d.Index] = d.Name
|
||||
out = append(out, d)
|
||||
|
||||
@@ -3,6 +3,7 @@ package openai
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
@@ -163,12 +164,43 @@ func normalizeOpenAIContentForPrompt(v any) string {
|
||||
func normalizeOpenAIArgumentsForPrompt(v any) string {
|
||||
switch x := v.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(x)
|
||||
return normalizeToolArgumentString(x)
|
||||
default:
|
||||
return marshalToPromptString(v)
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeToolArgumentString(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if !looksLikeConcatenatedJSON(trimmed) {
|
||||
return trimmed
|
||||
}
|
||||
dec := json.NewDecoder(strings.NewReader(trimmed))
|
||||
values := make([]any, 0, 2)
|
||||
for {
|
||||
var v any
|
||||
if err := dec.Decode(&v); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
values = append(values, v)
|
||||
}
|
||||
if len(values) < 2 {
|
||||
return trimmed
|
||||
}
|
||||
last := values[len(values)-1]
|
||||
b, err := json.Marshal(last)
|
||||
if err != nil || len(b) == 0 {
|
||||
return trimmed
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func marshalToPromptString(v any) string {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
|
||||
@@ -167,3 +167,32 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSepara
|
||||
t.Fatalf("unexpected concatenated function arguments detected: %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIMessagesForPrompt_RepairsConcatenatedToolArguments(t *testing.T) {
|
||||
raw := []any{
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"tool_calls": []any{
|
||||
map[string]any{
|
||||
"id": "call_1",
|
||||
"function": map[string]any{
|
||||
"name": "search_web",
|
||||
"arguments": `{}{"query":"测试工具调用"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %d", len(normalized))
|
||||
}
|
||||
content, _ := normalized[0]["content"].(string)
|
||||
if !strings.Contains(content, `function.arguments: {"query":"测试工具调用"}`) {
|
||||
t.Fatalf("expected repaired arguments in tool history, got %q", content)
|
||||
}
|
||||
if strings.Contains(content, `{}{"query":"测试工具调用"}`) {
|
||||
t.Fatalf("expected concatenated JSON to be repaired, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,6 +135,27 @@ func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesInputAsMessagesFunctionCallItemRepairsConcatenatedArguments(t *testing.T) {
|
||||
msgs := normalizeResponsesInputAsMessages([]any{
|
||||
map[string]any{
|
||||
"type": "function_call",
|
||||
"call_id": "call_456",
|
||||
"name": "search",
|
||||
"arguments": `{}{"q":"golang"}`,
|
||||
},
|
||||
})
|
||||
if len(msgs) != 1 {
|
||||
t.Fatalf("expected one message, got %d", len(msgs))
|
||||
}
|
||||
m, _ := msgs[0].(map[string]any)
|
||||
toolCalls, _ := m["tool_calls"].([]any)
|
||||
call, _ := toolCalls[0].(map[string]any)
|
||||
fn, _ := call["function"].(map[string]any)
|
||||
if fn["arguments"] != `{"q":"golang"}` {
|
||||
t.Fatalf("expected concatenated call arguments repaired, got %#v", fn["arguments"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractEmbeddingInputs(t *testing.T) {
|
||||
got := extractEmbeddingInputs([]any{"a", "b"})
|
||||
if len(got) != 2 || got[0] != "a" || got[1] != "b" {
|
||||
|
||||
@@ -190,7 +190,8 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
|
||||
}
|
||||
|
||||
func logResponsesToolPolicyRejection(traceID string, policy util.ToolChoicePolicy, parsed util.ToolCallParseResult, channel string) {
|
||||
if !parsed.RejectedByPolicy || len(parsed.RejectedToolNames) == 0 {
|
||||
rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames)
|
||||
if !parsed.RejectedByPolicy || len(rejected) == 0 {
|
||||
return
|
||||
}
|
||||
config.Logger.Warn(
|
||||
@@ -198,6 +199,23 @@ func logResponsesToolPolicyRejection(traceID string, policy util.ToolChoicePolic
|
||||
"trace_id", strings.TrimSpace(traceID),
|
||||
"channel", channel,
|
||||
"tool_choice_mode", policy.Mode,
|
||||
"rejected_tool_names", strings.Join(parsed.RejectedToolNames, ","),
|
||||
"rejected_tool_names", strings.Join(rejected, ","),
|
||||
)
|
||||
}
|
||||
|
||||
func filteredRejectedToolNamesForLog(names []string) []string {
|
||||
if len(names) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(names))
|
||||
for _, name := range names {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
switch strings.ToLower(trimmed) {
|
||||
case "", "tool_name":
|
||||
continue
|
||||
default:
|
||||
out = append(out, trimmed)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -188,6 +188,10 @@ func stringifyToolCallArguments(v any) string {
|
||||
if s == "" {
|
||||
return "{}"
|
||||
}
|
||||
s = normalizeToolArgumentString(s)
|
||||
if s == "" {
|
||||
return "{}"
|
||||
}
|
||||
return s
|
||||
default:
|
||||
b, err := json.Marshal(x)
|
||||
|
||||
@@ -37,13 +37,14 @@ type responsesStreamRuntime struct {
|
||||
text strings.Builder
|
||||
visibleText strings.Builder
|
||||
streamToolCallIDs map[int]string
|
||||
streamFunctionIDs map[int]string
|
||||
functionItemIDs map[int]string
|
||||
functionOutputIDs map[int]int
|
||||
functionDone map[int]bool
|
||||
functionAdded map[int]bool
|
||||
functionNames map[int]string
|
||||
toolCallsDoneSigs map[string]bool
|
||||
reasoningItemID string
|
||||
messageItemID string
|
||||
messageOutputID int
|
||||
nextOutputID int
|
||||
messageAdded bool
|
||||
messagePartAdded bool
|
||||
sequence int
|
||||
@@ -81,11 +82,12 @@ func newResponsesStreamRuntime(
|
||||
bufferToolContent: bufferToolContent,
|
||||
emitEarlyToolDeltas: emitEarlyToolDeltas,
|
||||
streamToolCallIDs: map[int]string{},
|
||||
streamFunctionIDs: map[int]string{},
|
||||
functionItemIDs: map[int]string{},
|
||||
functionOutputIDs: map[int]int{},
|
||||
functionDone: map[int]bool{},
|
||||
functionAdded: map[int]bool{},
|
||||
functionNames: map[int]string{},
|
||||
toolCallsDoneSigs: map[string]bool{},
|
||||
messageOutputID: -1,
|
||||
toolChoice: toolChoice,
|
||||
traceID: traceID,
|
||||
persistResponse: persistResponse,
|
||||
@@ -144,10 +146,7 @@ func (s *responsesStreamRuntime) finalize() {
|
||||
return
|
||||
}
|
||||
|
||||
obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames)
|
||||
if s.toolCallsEmitted {
|
||||
s.alignCompletedOutputCallIDs(obj)
|
||||
}
|
||||
obj := s.buildCompletedResponseObject(finalThinking, finalText, detected)
|
||||
if s.persistResponse != nil {
|
||||
s.persistResponse(obj)
|
||||
}
|
||||
@@ -157,7 +156,8 @@ func (s *responsesStreamRuntime) finalize() {
|
||||
|
||||
func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingParsed util.ToolCallParseResult) {
|
||||
logRejected := func(parsed util.ToolCallParseResult, channel string) {
|
||||
if !parsed.RejectedByPolicy || len(parsed.RejectedToolNames) == 0 {
|
||||
rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames)
|
||||
if !parsed.RejectedByPolicy || len(rejected) == 0 {
|
||||
return
|
||||
}
|
||||
config.Logger.Warn(
|
||||
@@ -165,7 +165,7 @@ func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingPar
|
||||
"trace_id", strings.TrimSpace(s.traceID),
|
||||
"channel", channel,
|
||||
"tool_choice_mode", s.toolChoice.Mode,
|
||||
"rejected_tool_names", strings.Join(parsed.RejectedToolNames, ","),
|
||||
"rejected_tool_names", strings.Join(rejected, ","),
|
||||
)
|
||||
}
|
||||
logRejected(textParsed, "text")
|
||||
|
||||
@@ -11,6 +11,12 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func (s *responsesStreamRuntime) allocateOutputIndex() int {
|
||||
idx := s.nextOutputID
|
||||
s.nextOutputID++
|
||||
return idx
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureMessageItemID() string {
|
||||
if strings.TrimSpace(s.messageItemID) != "" {
|
||||
return s.messageItemID
|
||||
@@ -19,11 +25,12 @@ func (s *responsesStreamRuntime) ensureMessageItemID() string {
|
||||
return s.messageItemID
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) messageOutputIndex() int {
|
||||
if strings.TrimSpace(s.thinking.String()) != "" {
|
||||
return 1
|
||||
func (s *responsesStreamRuntime) ensureMessageOutputIndex() int {
|
||||
if s.messageOutputID >= 0 {
|
||||
return s.messageOutputID
|
||||
}
|
||||
return 0
|
||||
s.messageOutputID = s.allocateOutputIndex()
|
||||
return s.messageOutputID
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureMessageItemAdded() {
|
||||
@@ -39,7 +46,7 @@ func (s *responsesStreamRuntime) ensureMessageItemAdded() {
|
||||
}
|
||||
s.sendEvent(
|
||||
"response.output_item.added",
|
||||
openaifmt.BuildResponsesOutputItemAddedPayload(s.responseID, itemID, s.messageOutputIndex(), item),
|
||||
openaifmt.BuildResponsesOutputItemAddedPayload(s.responseID, itemID, s.ensureMessageOutputIndex(), item),
|
||||
)
|
||||
s.messageAdded = true
|
||||
}
|
||||
@@ -54,7 +61,7 @@ func (s *responsesStreamRuntime) ensureMessageContentPartAdded() {
|
||||
openaifmt.BuildResponsesContentPartAddedPayload(
|
||||
s.responseID,
|
||||
s.ensureMessageItemID(),
|
||||
s.messageOutputIndex(),
|
||||
s.ensureMessageOutputIndex(),
|
||||
0,
|
||||
map[string]any{"type": "output_text", "text": ""},
|
||||
),
|
||||
@@ -68,7 +75,16 @@ func (s *responsesStreamRuntime) emitTextDelta(content string) {
|
||||
}
|
||||
s.ensureMessageContentPartAdded()
|
||||
s.visibleText.WriteString(content)
|
||||
s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, content))
|
||||
s.sendEvent(
|
||||
"response.output_text.delta",
|
||||
openaifmt.BuildResponsesTextDeltaPayload(
|
||||
s.responseID,
|
||||
s.ensureMessageItemID(),
|
||||
s.ensureMessageOutputIndex(),
|
||||
0,
|
||||
content,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) closeMessageItem() {
|
||||
@@ -76,6 +92,7 @@ func (s *responsesStreamRuntime) closeMessageItem() {
|
||||
return
|
||||
}
|
||||
itemID := s.ensureMessageItemID()
|
||||
outputIndex := s.ensureMessageOutputIndex()
|
||||
text := s.visibleText.String()
|
||||
if s.messagePartAdded {
|
||||
s.sendEvent(
|
||||
@@ -83,7 +100,7 @@ func (s *responsesStreamRuntime) closeMessageItem() {
|
||||
openaifmt.BuildResponsesContentPartDonePayload(
|
||||
s.responseID,
|
||||
itemID,
|
||||
s.messageOutputIndex(),
|
||||
outputIndex,
|
||||
0,
|
||||
map[string]any{"type": "output_text", "text": text},
|
||||
),
|
||||
@@ -104,45 +121,35 @@ func (s *responsesStreamRuntime) closeMessageItem() {
|
||||
}
|
||||
s.sendEvent(
|
||||
"response.output_item.done",
|
||||
openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, s.messageOutputIndex(), item),
|
||||
openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, outputIndex, item),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureReasoningItemID() string {
|
||||
if strings.TrimSpace(s.reasoningItemID) != "" {
|
||||
return s.reasoningItemID
|
||||
}
|
||||
s.reasoningItemID = "rs_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
return s.reasoningItemID
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureFunctionItemID(index int) string {
|
||||
if id, ok := s.streamFunctionIDs[index]; ok && strings.TrimSpace(id) != "" {
|
||||
func (s *responsesStreamRuntime) ensureFunctionItemID(callIndex int) string {
|
||||
if id, ok := s.functionItemIDs[callIndex]; ok && strings.TrimSpace(id) != "" {
|
||||
return id
|
||||
}
|
||||
id := "fc_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
s.streamFunctionIDs[index] = id
|
||||
s.functionItemIDs[callIndex] = id
|
||||
return id
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureToolCallID(index int) string {
|
||||
if id, ok := s.streamToolCallIDs[index]; ok && strings.TrimSpace(id) != "" {
|
||||
func (s *responsesStreamRuntime) ensureToolCallID(callIndex int) string {
|
||||
if id, ok := s.streamToolCallIDs[callIndex]; ok && strings.TrimSpace(id) != "" {
|
||||
return id
|
||||
}
|
||||
id := "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
s.streamToolCallIDs[index] = id
|
||||
s.streamToolCallIDs[callIndex] = id
|
||||
return id
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) functionOutputBaseIndex() int {
|
||||
if strings.TrimSpace(s.thinking.String()) != "" {
|
||||
return 1
|
||||
func (s *responsesStreamRuntime) ensureFunctionOutputIndex(callIndex int) int {
|
||||
if idx, ok := s.functionOutputIDs[callIndex]; ok {
|
||||
return idx
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) functionOutputIndex(callIndex int) int {
|
||||
return s.functionOutputBaseIndex() + callIndex
|
||||
idx := s.allocateOutputIndex()
|
||||
s.functionOutputIDs[callIndex] = idx
|
||||
return idx
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureFunctionItemAdded(callIndex int, name string) {
|
||||
@@ -156,15 +163,15 @@ func (s *responsesStreamRuntime) ensureFunctionItemAdded(callIndex int, name str
|
||||
if fnName == "" {
|
||||
return
|
||||
}
|
||||
outputIndex := s.functionOutputIndex(callIndex)
|
||||
itemID := s.ensureFunctionItemID(outputIndex)
|
||||
outputIndex := s.ensureFunctionOutputIndex(callIndex)
|
||||
itemID := s.ensureFunctionItemID(callIndex)
|
||||
callID := s.ensureToolCallID(callIndex)
|
||||
item := map[string]any{
|
||||
"id": itemID,
|
||||
"type": "function_call",
|
||||
"call_id": callID,
|
||||
"name": fnName,
|
||||
"arguments": "{}",
|
||||
"arguments": "",
|
||||
"status": "in_progress",
|
||||
}
|
||||
s.sendEvent(
|
||||
@@ -181,8 +188,8 @@ func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDe
|
||||
if strings.TrimSpace(d.Arguments) == "" {
|
||||
continue
|
||||
}
|
||||
outputIndex := s.functionOutputIndex(d.Index)
|
||||
itemID := s.ensureFunctionItemID(outputIndex)
|
||||
outputIndex := s.ensureFunctionOutputIndex(d.Index)
|
||||
itemID := s.ensureFunctionItemID(d.Index)
|
||||
callID := s.ensureToolCallID(d.Index)
|
||||
s.sendEvent(
|
||||
"response.function_call_arguments.delta",
|
||||
@@ -192,18 +199,16 @@ func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDe
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedToolCall) {
|
||||
base := s.functionOutputBaseIndex()
|
||||
for idx, tc := range calls {
|
||||
if strings.TrimSpace(tc.Name) == "" {
|
||||
continue
|
||||
}
|
||||
s.ensureFunctionItemAdded(idx, tc.Name)
|
||||
|
||||
outputIndex := base + idx
|
||||
if s.functionDone[outputIndex] {
|
||||
if s.functionDone[idx] {
|
||||
continue
|
||||
}
|
||||
itemID := s.ensureFunctionItemID(outputIndex)
|
||||
outputIndex := s.ensureFunctionOutputIndex(idx)
|
||||
itemID := s.ensureFunctionItemID(idx)
|
||||
callID := s.ensureToolCallID(idx)
|
||||
argsBytes, _ := json.Marshal(tc.Input)
|
||||
args := string(argsBytes)
|
||||
@@ -223,48 +228,105 @@ func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedT
|
||||
"response.output_item.done",
|
||||
openaifmt.BuildResponsesOutputItemDonePayload(s.responseID, itemID, outputIndex, item),
|
||||
)
|
||||
s.functionDone[outputIndex] = true
|
||||
s.functionDone[idx] = true
|
||||
s.toolCallsDoneEmitted = true
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any) {
|
||||
if obj == nil || len(s.streamToolCallIDs) == 0 {
|
||||
return
|
||||
func (s *responsesStreamRuntime) buildCompletedResponseObject(finalThinking, finalText string, calls []util.ParsedToolCall) map[string]any {
|
||||
type indexedItem struct {
|
||||
index int
|
||||
item map[string]any
|
||||
}
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) == 0 {
|
||||
return
|
||||
}
|
||||
indices := make([]int, 0, len(s.streamToolCallIDs))
|
||||
for idx := range s.streamToolCallIDs {
|
||||
indices = append(indices, idx)
|
||||
}
|
||||
sort.Ints(indices)
|
||||
ordered := make([]string, 0, len(indices))
|
||||
for _, idx := range indices {
|
||||
id := strings.TrimSpace(s.streamToolCallIDs[idx])
|
||||
if id == "" {
|
||||
continue
|
||||
indexed := make([]indexedItem, 0, len(calls)+1)
|
||||
|
||||
if s.messageAdded {
|
||||
text := s.visibleText.String()
|
||||
indexed = append(indexed, indexedItem{
|
||||
index: s.ensureMessageOutputIndex(),
|
||||
item: map[string]any{
|
||||
"id": s.ensureMessageItemID(),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "completed",
|
||||
"content": []map[string]any{
|
||||
{
|
||||
"type": "output_text",
|
||||
"text": text,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
} else if len(calls) == 0 {
|
||||
content := make([]map[string]any, 0, 2)
|
||||
if strings.TrimSpace(finalThinking) != "" {
|
||||
content = append(content, map[string]any{
|
||||
"type": "reasoning",
|
||||
"text": finalThinking,
|
||||
})
|
||||
}
|
||||
if strings.TrimSpace(finalText) != "" {
|
||||
content = append(content, map[string]any{
|
||||
"type": "output_text",
|
||||
"text": finalText,
|
||||
})
|
||||
}
|
||||
if len(content) > 0 {
|
||||
indexed = append(indexed, indexedItem{
|
||||
index: s.ensureMessageOutputIndex(),
|
||||
item: map[string]any{
|
||||
"id": s.ensureMessageItemID(),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "completed",
|
||||
"content": content,
|
||||
},
|
||||
})
|
||||
}
|
||||
ordered = append(ordered, id)
|
||||
}
|
||||
if len(ordered) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
functionIdx := 0
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m == nil {
|
||||
for idx, tc := range calls {
|
||||
if strings.TrimSpace(tc.Name) == "" {
|
||||
continue
|
||||
}
|
||||
if m["type"] != "function_call" {
|
||||
continue
|
||||
}
|
||||
if functionIdx < len(ordered) {
|
||||
m["call_id"] = ordered[functionIdx]
|
||||
functionIdx++
|
||||
argsBytes, _ := json.Marshal(tc.Input)
|
||||
indexed = append(indexed, indexedItem{
|
||||
index: s.ensureFunctionOutputIndex(idx),
|
||||
item: map[string]any{
|
||||
"id": s.ensureFunctionItemID(idx),
|
||||
"type": "function_call",
|
||||
"call_id": s.ensureToolCallID(idx),
|
||||
"name": tc.Name,
|
||||
"arguments": string(argsBytes),
|
||||
"status": "completed",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
sort.SliceStable(indexed, func(i, j int) bool {
|
||||
return indexed[i].index < indexed[j].index
|
||||
})
|
||||
output := make([]any, 0, len(indexed))
|
||||
for _, it := range indexed {
|
||||
output = append(output, it.item)
|
||||
}
|
||||
|
||||
outputText := s.visibleText.String()
|
||||
if strings.TrimSpace(outputText) == "" && len(calls) == 0 {
|
||||
if strings.TrimSpace(finalText) != "" {
|
||||
outputText = finalText
|
||||
} else if strings.TrimSpace(finalThinking) != "" {
|
||||
outputText = finalThinking
|
||||
}
|
||||
}
|
||||
|
||||
return openaifmt.BuildResponseObjectFromItems(
|
||||
s.responseID,
|
||||
s.model,
|
||||
s.finalPrompt,
|
||||
finalThinking,
|
||||
finalText,
|
||||
output,
|
||||
outputText,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -109,6 +109,22 @@ func TestHandleResponsesStreamUsesOfficialOutputItemEvents(t *testing.T) {
|
||||
t.Fatalf("legacy response.output_tool_call.* event must not appear, body=%s", body)
|
||||
}
|
||||
|
||||
addedPayloads := extractAllSSEEventPayloads(body, "response.output_item.added")
|
||||
hasFunctionCallAdded := false
|
||||
for _, payload := range addedPayloads {
|
||||
item, _ := payload["item"].(map[string]any)
|
||||
if item == nil || asString(item["type"]) != "function_call" {
|
||||
continue
|
||||
}
|
||||
hasFunctionCallAdded = true
|
||||
if asString(item["arguments"]) != "" {
|
||||
t.Fatalf("expected in-progress function_call.arguments to start empty string, got %#v", item["arguments"])
|
||||
}
|
||||
}
|
||||
if !hasFunctionCallAdded {
|
||||
t.Fatalf("expected function_call output_item.added payload, body=%s", body)
|
||||
}
|
||||
|
||||
donePayload, ok := extractSSEEventPayload(body, "response.function_call_arguments.done")
|
||||
if !ok {
|
||||
t.Fatalf("expected to parse response.function_call_arguments.done payload, body=%s", body)
|
||||
@@ -213,6 +229,137 @@ func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine("hello") + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, util.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
|
||||
deltaPayload, ok := extractSSEEventPayload(body, "response.output_text.delta")
|
||||
if !ok {
|
||||
t.Fatalf("expected response.output_text.delta payload, body=%s", body)
|
||||
}
|
||||
if strings.TrimSpace(asString(deltaPayload["item_id"])) == "" {
|
||||
t.Fatalf("expected non-empty item_id in output_text.delta, payload=%#v", deltaPayload)
|
||||
}
|
||||
if _, ok := deltaPayload["output_index"]; !ok {
|
||||
t.Fatalf("expected output_index in output_text.delta, payload=%#v", deltaPayload)
|
||||
}
|
||||
if _, ok := deltaPayload["content_index"]; !ok {
|
||||
t.Fatalf("expected content_index in output_text.delta, payload=%#v", deltaPayload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(path, value string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": path,
|
||||
"v": value,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine("response/thinking_content", "thinking...") +
|
||||
sseLine("response/content", "先读取文件。") +
|
||||
sseLine("response/content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) +
|
||||
"data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
|
||||
addedPayloads := extractAllSSEEventPayloads(rec.Body.String(), "response.output_item.added")
|
||||
if len(addedPayloads) < 2 {
|
||||
t.Fatalf("expected message + function_call output_item.added events, got %d body=%s", len(addedPayloads), rec.Body.String())
|
||||
}
|
||||
|
||||
indexes := map[int]struct{}{}
|
||||
typeByIndex := map[int]string{}
|
||||
addedIDs := map[string]string{}
|
||||
for _, payload := range addedPayloads {
|
||||
item, _ := payload["item"].(map[string]any)
|
||||
itemType := strings.TrimSpace(asString(item["type"]))
|
||||
outputIndex := int(asFloat(payload["output_index"]))
|
||||
if _, exists := indexes[outputIndex]; exists {
|
||||
t.Fatalf("found duplicated output_index=%d for item types=%q and %q payload=%#v", outputIndex, typeByIndex[outputIndex], itemType, payload)
|
||||
}
|
||||
indexes[outputIndex] = struct{}{}
|
||||
typeByIndex[outputIndex] = itemType
|
||||
addedIDs[itemType] = strings.TrimSpace(asString(payload["item_id"]))
|
||||
}
|
||||
|
||||
completedPayload, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
|
||||
if !ok {
|
||||
t.Fatalf("expected response.completed payload, body=%s", rec.Body.String())
|
||||
}
|
||||
responseObj, _ := completedPayload["response"].(map[string]any)
|
||||
output, _ := responseObj["output"].([]any)
|
||||
found := map[string]bool{}
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
itemType := strings.TrimSpace(asString(m["type"]))
|
||||
itemID := strings.TrimSpace(asString(m["id"]))
|
||||
if itemType == "" || itemID == "" {
|
||||
continue
|
||||
}
|
||||
if wantID := strings.TrimSpace(addedIDs[itemType]); wantID != "" && wantID == itemID {
|
||||
found[itemType] = true
|
||||
}
|
||||
}
|
||||
if !found["message"] || !found["function_call"] {
|
||||
t.Fatalf("expected completed output to contain streamed message/function_call item ids, found=%#v output=%#v", found, output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
policy := util.ToolChoicePolicy{Mode: util.ToolChoiceNone}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, policy, "")
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("did not expect function_call events for tool_choice=none, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
@@ -299,6 +446,32 @@ func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
rec := httptest.NewRecorder()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}"}` + "\n" +
|
||||
`data: [DONE]` + "\n",
|
||||
)),
|
||||
}
|
||||
policy := util.ToolChoicePolicy{Mode: util.ToolChoiceNone}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, policy, "")
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 for tool_choice=none passthrough text, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
output, _ := out["output"].([]any)
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m != nil && m["type"] == "function_call" {
|
||||
t.Fatalf("did not expect function_call output item for tool_choice=none, got %#v", output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
|
||||
scanner := bufio.NewScanner(strings.NewReader(body))
|
||||
matched := false
|
||||
@@ -351,3 +524,18 @@ func extractAllSSEEventPayloads(body, targetEvent string) []map[string]any {
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func asFloat(v any) float64 {
|
||||
switch x := v.(type) {
|
||||
case float64:
|
||||
return x
|
||||
case float32:
|
||||
return float64(x)
|
||||
case int:
|
||||
return float64(x)
|
||||
case int64:
|
||||
return float64(x)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,3 +151,30 @@ func TestNormalizeOpenAIResponsesRequestToolChoiceForcedUndeclaredFails(t *testi
|
||||
t.Fatalf("expected forced undeclared tool to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesRequestToolChoiceNoneDisablesTools(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": "ping",
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "search",
|
||||
},
|
||||
},
|
||||
},
|
||||
"tool_choice": "none",
|
||||
}
|
||||
n, err := normalizeOpenAIResponsesRequest(store, req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
if n.ToolChoice.Mode != util.ToolChoiceNone {
|
||||
t.Fatalf("expected tool choice mode none, got %q", n.ToolChoice.Mode)
|
||||
}
|
||||
if len(n.ToolNames) != 0 {
|
||||
t.Fatalf("expected no tool names when tool_choice=none, got %#v", n.ToolNames)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,12 +21,6 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex
|
||||
output := make([]any, 0, 2)
|
||||
if len(detected) > 0 {
|
||||
exposedOutputText = ""
|
||||
if strings.TrimSpace(finalThinking) != "" {
|
||||
output = append(output, map[string]any{
|
||||
"type": "reasoning",
|
||||
"text": finalThinking,
|
||||
})
|
||||
}
|
||||
output = append(output, toResponsesFunctionCallItems(detected)...)
|
||||
} else {
|
||||
content := make([]any, 0, 2)
|
||||
@@ -52,6 +46,21 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex
|
||||
"content": content,
|
||||
})
|
||||
}
|
||||
return BuildResponseObjectFromItems(
|
||||
responseID,
|
||||
model,
|
||||
finalPrompt,
|
||||
finalThinking,
|
||||
finalText,
|
||||
output,
|
||||
exposedOutputText,
|
||||
)
|
||||
}
|
||||
|
||||
func BuildResponseObjectFromItems(responseID, model, finalPrompt, finalThinking, finalText string, output []any, outputText string) map[string]any {
|
||||
if output == nil {
|
||||
output = []any{}
|
||||
}
|
||||
return map[string]any{
|
||||
"id": responseID,
|
||||
"type": "response",
|
||||
@@ -60,7 +69,7 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex
|
||||
"status": "completed",
|
||||
"model": model,
|
||||
"output": output,
|
||||
"output_text": exposedOutputText,
|
||||
"output_text": outputText,
|
||||
"usage": BuildResponsesUsage(finalPrompt, finalThinking, finalText),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,12 +59,15 @@ func BuildResponsesContentPartDonePayload(responseID, itemID string, outputIndex
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any {
|
||||
func BuildResponsesTextDeltaPayload(responseID, itemID string, outputIndex, contentIndex int, delta string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_text.delta",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"delta": delta,
|
||||
"type": "response.output_text.delta",
|
||||
"id": responseID,
|
||||
"response_id": responseID,
|
||||
"item_id": itemID,
|
||||
"output_index": outputIndex,
|
||||
"content_index": contentIndex,
|
||||
"delta": delta,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -138,15 +138,11 @@ func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
|
||||
)
|
||||
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) != 2 {
|
||||
t.Fatalf("expected reasoning + function_call outputs, got %#v", obj["output"])
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected function_call output only, got %#v", obj["output"])
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "reasoning" {
|
||||
t.Fatalf("expected first output reasoning, got %#v", first["type"])
|
||||
}
|
||||
second, _ := output[1].(map[string]any)
|
||||
if second["type"] != "function_call" {
|
||||
t.Fatalf("expected second output function_call, got %#v", second["type"])
|
||||
if first["type"] != "function_call" {
|
||||
t.Fatalf("expected output function_call, got %#v", first["type"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
package util
|
||||
|
||||
// BuildOpenAIChatStreamDeltaChoice is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildChatStreamDeltaChoice for new code.
|
||||
func BuildOpenAIChatStreamDeltaChoice(index int, delta map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"delta": delta,
|
||||
"index": index,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIChatStreamFinishChoice is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildChatStreamFinishChoice for new code.
|
||||
func BuildOpenAIChatStreamFinishChoice(index int, finishReason string) map[string]any {
|
||||
return map[string]any{
|
||||
"delta": map[string]any{},
|
||||
"index": index,
|
||||
"finish_reason": finishReason,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIChatStreamChunk is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildChatStreamChunk for new code.
|
||||
func BuildOpenAIChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any {
|
||||
out := map[string]any{
|
||||
"id": completionID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": choices,
|
||||
}
|
||||
if len(usage) > 0 {
|
||||
out["usage"] = usage
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// BuildOpenAIChatUsage is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildChatUsage for new code.
|
||||
func BuildOpenAIChatUsage(finalPrompt, finalThinking, finalText string) map[string]any {
|
||||
promptTokens := EstimateTokens(finalPrompt)
|
||||
reasoningTokens := EstimateTokens(finalThinking)
|
||||
completionTokens := EstimateTokens(finalText)
|
||||
return map[string]any{
|
||||
"prompt_tokens": promptTokens,
|
||||
"completion_tokens": reasoningTokens + completionTokens,
|
||||
"total_tokens": promptTokens + reasoningTokens + completionTokens,
|
||||
"completion_tokens_details": map[string]any{
|
||||
"reasoning_tokens": reasoningTokens,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIResponsesCreatedPayload is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildResponsesCreatedPayload for new code.
|
||||
func BuildOpenAIResponsesCreatedPayload(responseID, model string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.created",
|
||||
"id": responseID,
|
||||
"object": "response",
|
||||
"model": model,
|
||||
"status": "in_progress",
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIResponsesTextDeltaPayload is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildResponsesTextDeltaPayload for new code.
|
||||
func BuildOpenAIResponsesTextDeltaPayload(responseID, delta string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_text.delta",
|
||||
"id": responseID,
|
||||
"delta": delta,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIResponsesReasoningDeltaPayload is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildResponsesReasoningDeltaPayload for new code.
|
||||
func BuildOpenAIResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.reasoning.delta",
|
||||
"id": responseID,
|
||||
"delta": delta,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIResponsesToolCallDeltaPayload is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildResponsesToolCallDeltaPayload for new code.
|
||||
func BuildOpenAIResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_tool_call.delta",
|
||||
"id": responseID,
|
||||
"tool_calls": toolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIResponsesToolCallDonePayload is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildResponsesToolCallDonePayload for new code.
|
||||
func BuildOpenAIResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.output_tool_call.done",
|
||||
"id": responseID,
|
||||
"tool_calls": toolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIResponsesCompletedPayload is kept for backward compatibility.
|
||||
// Prefer internal/format/openai.BuildResponsesCompletedPayload for new code.
|
||||
func BuildOpenAIResponsesCompletedPayload(response map[string]any) map[string]any {
|
||||
return map[string]any{
|
||||
"type": "response.completed",
|
||||
"response": response,
|
||||
}
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
package util
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBuildOpenAIChatStreamChunk(t *testing.T) {
|
||||
chunk := BuildOpenAIChatStreamChunk(
|
||||
"cid",
|
||||
123,
|
||||
"deepseek-chat",
|
||||
[]map[string]any{BuildOpenAIChatStreamDeltaChoice(0, map[string]any{"role": "assistant"})},
|
||||
nil,
|
||||
)
|
||||
if chunk["object"] != "chat.completion.chunk" {
|
||||
t.Fatalf("unexpected object: %#v", chunk["object"])
|
||||
}
|
||||
choices, _ := chunk["choices"].([]map[string]any)
|
||||
if len(choices) == 0 {
|
||||
rawChoices, _ := chunk["choices"].([]any)
|
||||
if len(rawChoices) == 0 {
|
||||
t.Fatalf("expected choices")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIChatUsage(t *testing.T) {
|
||||
usage := BuildOpenAIChatUsage("prompt", "think", "answer")
|
||||
if _, ok := usage["prompt_tokens"]; !ok {
|
||||
t.Fatalf("expected prompt_tokens")
|
||||
}
|
||||
if _, ok := usage["completion_tokens_details"]; !ok {
|
||||
t.Fatalf("expected completion_tokens_details")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIResponsesEventPayloads(t *testing.T) {
|
||||
created := BuildOpenAIResponsesCreatedPayload("resp_1", "gpt-4o")
|
||||
if created["type"] != "response.created" {
|
||||
t.Fatalf("unexpected type: %#v", created["type"])
|
||||
}
|
||||
done := BuildOpenAIResponsesToolCallDonePayload("resp_1", []map[string]any{{"index": 0}})
|
||||
if done["type"] != "response.output_tool_call.done" {
|
||||
t.Fatalf("unexpected type: %#v", done["type"])
|
||||
}
|
||||
completed := BuildOpenAIResponsesCompletedPayload(map[string]any{"id": "resp_1"})
|
||||
if completed["type"] != "response.completed" {
|
||||
t.Fatalf("unexpected type: %#v", completed["type"])
|
||||
}
|
||||
}
|
||||
@@ -92,17 +92,29 @@ func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []strin
|
||||
for _, name := range availableToolNames {
|
||||
allowed[name] = struct{}{}
|
||||
}
|
||||
if len(allowed) == 0 {
|
||||
rejectedSet := map[string]struct{}{}
|
||||
for _, tc := range parsed {
|
||||
if tc.Name == "" {
|
||||
continue
|
||||
}
|
||||
rejectedSet[tc.Name] = struct{}{}
|
||||
}
|
||||
rejected := make([]string, 0, len(rejectedSet))
|
||||
for name := range rejectedSet {
|
||||
rejected = append(rejected, name)
|
||||
}
|
||||
return nil, rejected
|
||||
}
|
||||
out := make([]ParsedToolCall, 0, len(parsed))
|
||||
rejectedSet := map[string]struct{}{}
|
||||
for _, tc := range parsed {
|
||||
if tc.Name == "" {
|
||||
continue
|
||||
}
|
||||
if len(allowed) > 0 {
|
||||
if _, ok := allowed[tc.Name]; !ok {
|
||||
rejectedSet[tc.Name] = struct{}{}
|
||||
continue
|
||||
}
|
||||
if _, ok := allowed[tc.Name]; !ok {
|
||||
rejectedSet[tc.Name] = struct{}{}
|
||||
continue
|
||||
}
|
||||
if tc.Input == nil {
|
||||
tc.Input = map[string]any{}
|
||||
|
||||
@@ -60,6 +60,20 @@ func TestParseToolCallsDetailedMarksPolicyRejection(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsDetailedRejectsWhenAllowListEmpty(t *testing.T) {
|
||||
text := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
|
||||
res := ParseToolCallsDetailed(text, nil)
|
||||
if !res.SawToolCallSyntax {
|
||||
t.Fatalf("expected SawToolCallSyntax=true, got %#v", res)
|
||||
}
|
||||
if !res.RejectedByPolicy {
|
||||
t.Fatalf("expected RejectedByPolicy=true, got %#v", res)
|
||||
}
|
||||
if len(res.Calls) != 0 {
|
||||
t.Fatalf("expected no calls when allow-list is empty, got %#v", res.Calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatOpenAIToolCalls(t *testing.T) {
|
||||
formatted := FormatOpenAIToolCalls([]ParsedToolCall{{Name: "search", Input: map[string]any{"q": "x"}}})
|
||||
if len(formatted) != 1 {
|
||||
|
||||
@@ -364,8 +364,8 @@ func TestFormatOpenAIStreamToolCalls(t *testing.T) {
|
||||
func TestParseToolCallsNoToolNames(t *testing.T) {
|
||||
text := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
|
||||
calls := ParseToolCalls(text, nil)
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call with nil tool names, got %d", len(calls))
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected 0 call with nil tool names, got %d", len(calls))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user