mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-07 09:55:29 +08:00
Merge pull request #352 from shern-point/fix/tool-string-schema-protection
Fix/tool type schema protection
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"ds2api/internal/toolcall"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
@@ -31,30 +32,9 @@ func extractClaudeToolNames(tools []any) []string {
|
||||
}
|
||||
|
||||
func extractClaudeToolMeta(m map[string]any) (string, string, any) {
|
||||
name, _ := m["name"].(string)
|
||||
desc, _ := m["description"].(string)
|
||||
schemaObj := m["input_schema"]
|
||||
if schemaObj == nil {
|
||||
schemaObj = m["parameters"]
|
||||
}
|
||||
|
||||
if fn, ok := m["function"].(map[string]any); ok {
|
||||
if strings.TrimSpace(name) == "" {
|
||||
name, _ = fn["name"].(string)
|
||||
}
|
||||
if strings.TrimSpace(desc) == "" {
|
||||
desc, _ = fn["description"].(string)
|
||||
}
|
||||
if schemaObj == nil {
|
||||
if v, ok := fn["input_schema"]; ok {
|
||||
schemaObj = v
|
||||
}
|
||||
}
|
||||
if schemaObj == nil {
|
||||
if v, ok := fn["parameters"]; ok {
|
||||
schemaObj = v
|
||||
}
|
||||
}
|
||||
name, desc, schemaObj := toolcall.ExtractToolMeta(m)
|
||||
if strings.TrimSpace(desc) == "" {
|
||||
desc = "No description available"
|
||||
}
|
||||
return strings.TrimSpace(name), strings.TrimSpace(desc), schemaObj
|
||||
}
|
||||
|
||||
@@ -177,7 +177,7 @@ func stripClaudeThinkingBlocks(raw []byte) []byte {
|
||||
return out
|
||||
}
|
||||
|
||||
func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -205,6 +205,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
|
||||
searchEnabled,
|
||||
h.compatStripReferenceMarkers(),
|
||||
toolNames,
|
||||
toolsRaw,
|
||||
)
|
||||
streamRuntime.sendMessageStart()
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ func TestHandleClaudeStreamRealtimeTextIncrementsWithEventHeaders(t *testing.T)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "hi"}}, false, false, nil)
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "hi"}}, false, false, nil, nil)
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: message_start") {
|
||||
@@ -122,7 +122,7 @@ func TestHandleClaudeStreamRealtimeThinkingDelta(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "hi"}}, true, false, nil)
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "hi"}}, true, false, nil, nil)
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
foundThinkingDelta := false
|
||||
@@ -149,7 +149,7 @@ func TestHandleClaudeStreamRealtimeSkipsThinkingFallbackWhenFinalTextExists(t *t
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, true, false, []string{"search"})
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, true, false, []string{"search"}, nil)
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
for _, f := range findClaudeFrames(frames, "content_block_start") {
|
||||
@@ -180,7 +180,7 @@ func TestHandleClaudeStreamRealtimeUpstreamErrorEvent(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "hi"}}, false, false, nil)
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "hi"}}, false, false, nil, nil)
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
errFrames := findClaudeFrames(frames, "error")
|
||||
@@ -217,7 +217,7 @@ func TestHandleClaudeStreamRealtimePingEvent(t *testing.T) {
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "hi"}}, false, false, nil)
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "hi"}}, false, false, nil, nil)
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
if len(findClaudeFrames(frames, "ping")) == 0 {
|
||||
@@ -271,7 +271,7 @@ func TestHandleClaudeStreamRealtimeToolSafetyAcrossStructuredFormats(t *testing.
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, false, false, []string{"Bash"})
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, false, false, []string{"Bash"}, nil)
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
foundToolUse := false
|
||||
@@ -299,7 +299,7 @@ func TestHandleClaudeStreamRealtimeDetectsToolUseWithLeadingProse(t *testing.T)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, false, false, []string{"write_file"})
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, false, false, []string{"write_file"}, nil)
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
foundToolUse := false
|
||||
@@ -333,7 +333,7 @@ func TestHandleClaudeStreamRealtimeIgnoresUnclosedFencedToolExample(t *testing.T
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "show example only"}}, false, false, []string{"Bash"})
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "show example only"}}, false, false, []string{"Bash"}, nil)
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
foundToolUse := false
|
||||
@@ -365,3 +365,48 @@ func TestHandleClaudeStreamRealtimeIgnoresUnclosedFencedToolExample(t *testing.T
|
||||
func TestHandleClaudeStreamRealtimePromotesUnclosedFencedToolExample(t *testing.T) {
|
||||
TestHandleClaudeStreamRealtimeIgnoresUnclosedFencedToolExample(t)
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeNormalizesToolInputBySchema(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"<tool_calls><invoke name=\"Write\">{\"input\":{\"content\":{\"message\":\"hi\"},\"taskId\":1}}</invoke></tool_calls>"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
toolsRaw := []any{
|
||||
map[string]any{
|
||||
"name": "Write",
|
||||
"inputSchema": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{"type": "string"},
|
||||
"taskId": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "write"}}, false, false, []string{"Write"}, toolsRaw)
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
for _, f := range findClaudeFrames(frames, "content_block_delta") {
|
||||
delta, _ := f.Payload["delta"].(map[string]any)
|
||||
if delta["type"] != "input_json_delta" {
|
||||
continue
|
||||
}
|
||||
partial := asString(delta["partial_json"])
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(partial), &args); err != nil {
|
||||
t.Fatalf("decode partial_json failed: %v payload=%s", err, partial)
|
||||
}
|
||||
if args["content"] != `{"message":"hi"}` {
|
||||
t.Fatalf("expected content normalized to string, got %#v", args["content"])
|
||||
}
|
||||
if args["taskId"] != "1" {
|
||||
t.Fatalf("expected taskId normalized to string, got %#v", args["taskId"])
|
||||
}
|
||||
return
|
||||
}
|
||||
t.Fatalf("expected input_json_delta frame, body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
@@ -53,6 +53,7 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma
|
||||
ResolvedModel: dsModel,
|
||||
ResponseModel: strings.TrimSpace(model),
|
||||
Messages: payload["messages"].([]any),
|
||||
ToolsRaw: toolsRequested,
|
||||
FinalPrompt: finalPrompt,
|
||||
ToolNames: toolNames,
|
||||
Stream: util.ToBool(req["stream"]),
|
||||
|
||||
@@ -32,11 +32,39 @@ func TestNormalizeClaudeRequest(t *testing.T) {
|
||||
if len(norm.Standard.ToolNames) == 0 {
|
||||
t.Fatalf("expected tool names")
|
||||
}
|
||||
if norm.Standard.ToolsRaw == nil {
|
||||
t.Fatalf("expected ToolsRaw preserved for downstream normalization")
|
||||
}
|
||||
if norm.Standard.FinalPrompt == "" {
|
||||
t.Fatalf("expected non-empty final prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeRequestSupportsCamelCaseInputSchemaPromptInjection(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{}`)
|
||||
store := config.LoadStore()
|
||||
req := map[string]any{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"name": "todowrite",
|
||||
"description": "Write todos",
|
||||
"inputSchema": map[string]any{"type": "object", "properties": map[string]any{"todos": map[string]any{"type": "array"}}},
|
||||
},
|
||||
},
|
||||
}
|
||||
norm, err := normalizeClaudeRequest(store, req)
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
if !containsStr(norm.Standard.FinalPrompt, `"type":"array"`) {
|
||||
t.Fatalf("expected inputSchema to be injected into prompt, got=%q", norm.Standard.FinalPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeRequestInjectsToolsIntoExistingSystemMessage(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{}`)
|
||||
store := config.LoadStore()
|
||||
|
||||
@@ -18,6 +18,7 @@ type claudeStreamRuntime struct {
|
||||
model string
|
||||
toolNames []string
|
||||
messages []any
|
||||
toolsRaw any
|
||||
|
||||
thinkingEnabled bool
|
||||
searchEnabled bool
|
||||
@@ -47,6 +48,7 @@ func newClaudeStreamRuntime(
|
||||
searchEnabled bool,
|
||||
stripReferenceMarkers bool,
|
||||
toolNames []string,
|
||||
toolsRaw any,
|
||||
) *claudeStreamRuntime {
|
||||
return &claudeStreamRuntime{
|
||||
w: w,
|
||||
@@ -59,6 +61,7 @@ func newClaudeStreamRuntime(
|
||||
bufferToolContent: len(toolNames) > 0,
|
||||
stripReferenceMarkers: stripReferenceMarkers,
|
||||
toolNames: toolNames,
|
||||
toolsRaw: toolsRaw,
|
||||
messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()),
|
||||
thinkingBlockIndex: -1,
|
||||
textBlockIndex: -1,
|
||||
|
||||
@@ -52,6 +52,7 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
|
||||
detected = toolcall.ParseStandaloneToolCalls(finalThinking, s.toolNames)
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
detected = toolcall.NormalizeParsedToolCallsForSchemas(detected, s.toolsRaw)
|
||||
stopReason = "tool_use"
|
||||
for i, tc := range detected {
|
||||
idx := s.nextBlockIndex + i
|
||||
|
||||
Reference in New Issue
Block a user