mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-23 01:17:44 +08:00
refactor backend API structure
This commit is contained in:
11
internal/httpapi/claude/convert.go
Normal file
11
internal/httpapi/claude/convert.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"ds2api/internal/claudeconv"
|
||||
)
|
||||
|
||||
const defaultClaudeModel = "claude-sonnet-4-6"
|
||||
|
||||
func convertClaudeToDeepSeek(claudeReq map[string]any, store ConfigReader) map[string]any {
|
||||
return claudeconv.ConvertClaudeToDeepSeek(claudeReq, store, defaultClaudeModel)
|
||||
}
|
||||
34
internal/httpapi/claude/deps.go
Normal file
34
internal/httpapi/claude/deps.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
dsclient "ds2api/internal/deepseek/client"
|
||||
)
|
||||
|
||||
type AuthResolver interface {
|
||||
Determine(req *http.Request) (*auth.RequestAuth, error)
|
||||
Release(a *auth.RequestAuth)
|
||||
}
|
||||
|
||||
type DeepSeekCaller interface {
|
||||
CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
||||
GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error)
|
||||
CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error)
|
||||
}
|
||||
|
||||
type ConfigReader interface {
|
||||
ModelAliases() map[string]string
|
||||
CompatStripReferenceMarkers() bool
|
||||
}
|
||||
|
||||
type OpenAIChatRunner interface {
|
||||
ChatCompletions(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
var _ AuthResolver = (*auth.Resolver)(nil)
|
||||
var _ DeepSeekCaller = (*dsclient.Client)(nil)
|
||||
var _ ConfigReader = (*config.Store)(nil)
|
||||
74
internal/httpapi/claude/deps_injection_test.go
Normal file
74
internal/httpapi/claude/deps_injection_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package claude
|
||||
|
||||
import "testing"
|
||||
|
||||
type mockClaudeConfig struct {
|
||||
aliases map[string]string
|
||||
}
|
||||
|
||||
func (m mockClaudeConfig) ModelAliases() map[string]string { return m.aliases }
|
||||
func (mockClaudeConfig) CompatStripReferenceMarkers() bool { return true }
|
||||
|
||||
func TestNormalizeClaudeRequestUsesGlobalAliasMapping(t *testing.T) {
|
||||
req := map[string]any{
|
||||
"model": "claude-opus-4-6",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
}
|
||||
out, err := normalizeClaudeRequest(mockClaudeConfig{
|
||||
aliases: map[string]string{
|
||||
"claude-opus-4-6": "deepseek-v4-pro-search",
|
||||
},
|
||||
}, req)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeClaudeRequest error: %v", err)
|
||||
}
|
||||
if out.Standard.ResolvedModel != "deepseek-v4-pro-search" {
|
||||
t.Fatalf("resolved model mismatch: got=%q", out.Standard.ResolvedModel)
|
||||
}
|
||||
if out.Standard.Thinking || !out.Standard.Search {
|
||||
t.Fatalf("unexpected flags: thinking=%v search=%v", out.Standard.Thinking, out.Standard.Search)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeRequestEnablesThinkingWhenRequested(t *testing.T) {
|
||||
req := map[string]any{
|
||||
"model": "claude-opus-4-6",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
"thinking": map[string]any{"type": "enabled", "budget_tokens": 1024},
|
||||
}
|
||||
out, err := normalizeClaudeRequest(mockClaudeConfig{
|
||||
aliases: map[string]string{
|
||||
"claude-opus-4-6": "deepseek-v4-pro",
|
||||
},
|
||||
}, req)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeClaudeRequest error: %v", err)
|
||||
}
|
||||
if !out.Standard.Thinking {
|
||||
t.Fatalf("expected explicit Claude thinking request to enable downstream thinking")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeRequestPrefersGlobalAliasMapping(t *testing.T) {
|
||||
req := map[string]any{
|
||||
"model": "claude-sonnet-4-6",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
}
|
||||
out, err := normalizeClaudeRequest(mockClaudeConfig{
|
||||
aliases: map[string]string{
|
||||
"claude-sonnet-4-6": "deepseek-v4-flash",
|
||||
},
|
||||
}, req)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeClaudeRequest error: %v", err)
|
||||
}
|
||||
if out.Standard.ResolvedModel != "deepseek-v4-flash" {
|
||||
t.Fatalf("expected global alias to win for explicit model, got=%q", out.Standard.ResolvedModel)
|
||||
}
|
||||
}
|
||||
34
internal/httpapi/claude/error_shape_test.go
Normal file
34
internal/httpapi/claude/error_shape_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWriteClaudeErrorIncludesUnifiedFields(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
writeClaudeError(rec, http.StatusUnauthorized, "bad token")
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", rec.Code)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("decode body: %v", err)
|
||||
}
|
||||
errObj, _ := body["error"].(map[string]any)
|
||||
if errObj["message"] != "bad token" {
|
||||
t.Fatalf("unexpected message: %v", errObj["message"])
|
||||
}
|
||||
if errObj["type"] != "invalid_request_error" {
|
||||
t.Fatalf("unexpected type: %v", errObj["type"])
|
||||
}
|
||||
if errObj["code"] != "authentication_failed" {
|
||||
t.Fatalf("unexpected code: %v", errObj["code"])
|
||||
}
|
||||
if _, ok := errObj["param"]; !ok {
|
||||
t.Fatal("expected param field")
|
||||
}
|
||||
}
|
||||
25
internal/httpapi/claude/handler_errors.go
Normal file
25
internal/httpapi/claude/handler_errors.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package claude
|
||||
|
||||
import "net/http"
|
||||
|
||||
func writeClaudeError(w http.ResponseWriter, status int, message string) {
|
||||
code := "invalid_request"
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
code = "authentication_failed"
|
||||
case http.StatusTooManyRequests:
|
||||
code = "rate_limit_exceeded"
|
||||
case http.StatusNotFound:
|
||||
code = "not_found"
|
||||
case http.StatusInternalServerError:
|
||||
code = "internal_error"
|
||||
}
|
||||
writeJSON(w, status, map[string]any{
|
||||
"error": map[string]any{
|
||||
"type": "invalid_request_error",
|
||||
"message": message,
|
||||
"code": code,
|
||||
"param": nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
97
internal/httpapi/claude/handler_helpers_misc.go
Normal file
97
internal/httpapi/claude/handler_helpers_misc.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func hasSystemMessage(messages []any) bool {
|
||||
for _, m := range messages {
|
||||
msg, ok := m.(map[string]any)
|
||||
if ok && msg["role"] == "system" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func extractClaudeToolNames(tools []any) []string {
|
||||
out := make([]string, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
m, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name, _, _ := extractClaudeToolMeta(m)
|
||||
if name != "" {
|
||||
out = append(out, name)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func extractClaudeToolMeta(m map[string]any) (string, string, any) {
|
||||
name, _ := m["name"].(string)
|
||||
desc, _ := m["description"].(string)
|
||||
schemaObj := m["input_schema"]
|
||||
if schemaObj == nil {
|
||||
schemaObj = m["parameters"]
|
||||
}
|
||||
|
||||
if fn, ok := m["function"].(map[string]any); ok {
|
||||
if strings.TrimSpace(name) == "" {
|
||||
name, _ = fn["name"].(string)
|
||||
}
|
||||
if strings.TrimSpace(desc) == "" {
|
||||
desc, _ = fn["description"].(string)
|
||||
}
|
||||
if schemaObj == nil {
|
||||
if v, ok := fn["input_schema"]; ok {
|
||||
schemaObj = v
|
||||
}
|
||||
}
|
||||
if schemaObj == nil {
|
||||
if v, ok := fn["parameters"]; ok {
|
||||
schemaObj = v
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(name), strings.TrimSpace(desc), schemaObj
|
||||
}
|
||||
|
||||
func toMessageMaps(v any) []map[string]any {
|
||||
arr, ok := v.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
out := make([]map[string]any, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
out = append(out, m)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func extractMessageContent(v any) string {
|
||||
switch x := v.(type) {
|
||||
case string:
|
||||
return x
|
||||
case []any:
|
||||
parts := make([]string, 0, len(x))
|
||||
for _, it := range x {
|
||||
parts = append(parts, fmt.Sprintf("%v", it))
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
default:
|
||||
return fmt.Sprintf("%v", x)
|
||||
}
|
||||
}
|
||||
|
||||
func cloneMap(in map[string]any) map[string]any {
|
||||
out := make(map[string]any, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
201
internal/httpapi/claude/handler_messages.go
Normal file
201
internal/httpapi/claude/handler_messages.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/translatorcliproxy"
|
||||
"ds2api/internal/util"
|
||||
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
)
|
||||
|
||||
func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" {
|
||||
r.Header.Set("anthropic-version", "2023-06-01")
|
||||
}
|
||||
if h.OpenAI == nil {
|
||||
writeClaudeError(w, http.StatusInternalServerError, "OpenAI proxy backend unavailable.")
|
||||
return
|
||||
}
|
||||
if h.proxyViaOpenAI(w, r, h.Store) {
|
||||
return
|
||||
}
|
||||
writeClaudeError(w, http.StatusBadGateway, "Failed to proxy Claude request.")
|
||||
}
|
||||
|
||||
func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, store ConfigReader) bool {
|
||||
raw, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
writeClaudeError(w, http.StatusBadRequest, "invalid body")
|
||||
return true
|
||||
}
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(raw, &req); err != nil {
|
||||
writeClaudeError(w, http.StatusBadRequest, "invalid json")
|
||||
return true
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
stream := util.ToBool(req["stream"])
|
||||
|
||||
// Use the shared global model resolver so Claude/OpenAI/Gemini stay consistent.
|
||||
translateModel := model
|
||||
if store != nil {
|
||||
if norm, normErr := normalizeClaudeRequest(store, cloneMap(req)); normErr == nil && strings.TrimSpace(norm.Standard.ResolvedModel) != "" {
|
||||
translateModel = strings.TrimSpace(norm.Standard.ResolvedModel)
|
||||
}
|
||||
}
|
||||
translatedReq := translatorcliproxy.ToOpenAI(sdktranslator.FormatClaude, translateModel, raw, stream)
|
||||
translatedReq = applyClaudeThinkingPolicyToOpenAIRequest(translatedReq, req)
|
||||
|
||||
isVercelPrepare := strings.TrimSpace(r.URL.Query().Get("__stream_prepare")) == "1"
|
||||
isVercelRelease := strings.TrimSpace(r.URL.Query().Get("__stream_release")) == "1"
|
||||
|
||||
if isVercelRelease {
|
||||
proxyReq := r.Clone(r.Context())
|
||||
proxyReq.URL.Path = "/v1/chat/completions"
|
||||
proxyReq.Body = io.NopCloser(bytes.NewReader(raw))
|
||||
proxyReq.ContentLength = int64(len(raw))
|
||||
rec := httptest.NewRecorder()
|
||||
h.OpenAI.ChatCompletions(rec, proxyReq)
|
||||
res := rec.Result()
|
||||
defer func() { _ = res.Body.Close() }()
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
for k, vv := range res.Header {
|
||||
for _, v := range vv {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(res.StatusCode)
|
||||
_, _ = w.Write(body)
|
||||
return true
|
||||
}
|
||||
|
||||
proxyReq := r.Clone(r.Context())
|
||||
proxyReq.URL.Path = "/v1/chat/completions"
|
||||
proxyReq.Body = io.NopCloser(bytes.NewReader(translatedReq))
|
||||
proxyReq.ContentLength = int64(len(translatedReq))
|
||||
|
||||
if stream && !isVercelPrepare {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
streamWriter := translatorcliproxy.NewOpenAIStreamTranslatorWriter(w, sdktranslator.FormatClaude, model, raw, translatedReq)
|
||||
h.OpenAI.ChatCompletions(streamWriter, proxyReq)
|
||||
return true
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
h.OpenAI.ChatCompletions(rec, proxyReq)
|
||||
res := rec.Result()
|
||||
defer func() { _ = res.Body.Close() }()
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||||
for k, vv := range res.Header {
|
||||
for _, v := range vv {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(res.StatusCode)
|
||||
_, _ = w.Write(body)
|
||||
return true
|
||||
}
|
||||
if isVercelPrepare {
|
||||
for k, vv := range res.Header {
|
||||
for _, v := range vv {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(res.StatusCode)
|
||||
_, _ = w.Write(body)
|
||||
return true
|
||||
}
|
||||
converted := translatorcliproxy.FromOpenAINonStream(sdktranslator.FormatClaude, model, raw, translatedReq, body)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(converted)
|
||||
return true
|
||||
}
|
||||
|
||||
func applyClaudeThinkingPolicyToOpenAIRequest(translated []byte, original map[string]any) []byte {
|
||||
req := map[string]any{}
|
||||
if err := json.Unmarshal(translated, &req); err != nil {
|
||||
return translated
|
||||
}
|
||||
enabled, ok := util.ResolveThinkingOverride(original)
|
||||
if !ok {
|
||||
if _, translatedHasOverride := util.ResolveThinkingOverride(req); translatedHasOverride {
|
||||
return translated
|
||||
}
|
||||
enabled = false
|
||||
}
|
||||
typ := "disabled"
|
||||
if enabled {
|
||||
typ = "enabled"
|
||||
}
|
||||
req["thinking"] = map[string]any{"type": typ}
|
||||
out, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return translated
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeClaudeError(w, http.StatusInternalServerError, string(body))
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
rc := http.NewResponseController(w)
|
||||
_, canFlush := w.(http.Flusher)
|
||||
if !canFlush {
|
||||
config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered")
|
||||
}
|
||||
|
||||
streamRuntime := newClaudeStreamRuntime(
|
||||
w,
|
||||
rc,
|
||||
canFlush,
|
||||
model,
|
||||
messages,
|
||||
thinkingEnabled,
|
||||
searchEnabled,
|
||||
h.compatStripReferenceMarkers(),
|
||||
toolNames,
|
||||
)
|
||||
streamRuntime.sendMessageStart()
|
||||
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
}
|
||||
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
|
||||
Context: r.Context(),
|
||||
Body: resp.Body,
|
||||
ThinkingEnabled: thinkingEnabled,
|
||||
InitialType: initialType,
|
||||
KeepAliveInterval: claudeStreamPingInterval,
|
||||
IdleTimeout: claudeStreamIdleTimeout,
|
||||
MaxKeepAliveNoInput: claudeStreamMaxKeepaliveCnt,
|
||||
}, streamengine.ConsumeHooks{
|
||||
OnKeepAlive: func() {
|
||||
streamRuntime.sendPing()
|
||||
},
|
||||
OnParsed: streamRuntime.onParsed,
|
||||
OnFinalize: streamRuntime.onFinalize,
|
||||
})
|
||||
}
|
||||
49
internal/httpapi/claude/handler_routes.go
Normal file
49
internal/httpapi/claude/handler_routes.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/config"
|
||||
dsprotocol "ds2api/internal/deepseek/protocol"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// writeJSON is a package-internal alias to avoid mass-renaming all call-sites.
|
||||
var writeJSON = util.WriteJSON
|
||||
|
||||
type Handler struct {
|
||||
Store ConfigReader
|
||||
Auth AuthResolver
|
||||
DS DeepSeekCaller
|
||||
OpenAI OpenAIChatRunner
|
||||
}
|
||||
|
||||
func (h *Handler) compatStripReferenceMarkers() bool {
|
||||
if h == nil || h.Store == nil {
|
||||
return true
|
||||
}
|
||||
return h.Store.CompatStripReferenceMarkers()
|
||||
}
|
||||
|
||||
var (
|
||||
claudeStreamPingInterval = time.Duration(dsprotocol.KeepAliveTimeout) * time.Second
|
||||
claudeStreamIdleTimeout = time.Duration(dsprotocol.StreamIdleTimeout) * time.Second
|
||||
claudeStreamMaxKeepaliveCnt = dsprotocol.MaxKeepaliveCount
|
||||
)
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
r.Get("/anthropic/v1/models", h.ListModels)
|
||||
r.Post("/anthropic/v1/messages", h.Messages)
|
||||
r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens)
|
||||
r.Post("/v1/messages", h.Messages)
|
||||
r.Post("/messages", h.Messages)
|
||||
r.Post("/v1/messages/count_tokens", h.CountTokens)
|
||||
r.Post("/messages/count_tokens", h.CountTokens)
|
||||
}
|
||||
|
||||
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
|
||||
writeJSON(w, http.StatusOK, config.ClaudeModelsResponse())
|
||||
}
|
||||
367
internal/httpapi/claude/handler_stream_test.go
Normal file
367
internal/httpapi/claude/handler_stream_test.go
Normal file
@@ -0,0 +1,367 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"ds2api/internal/sse"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type claudeFrame struct {
|
||||
Event string
|
||||
Payload map[string]any
|
||||
}
|
||||
|
||||
func makeClaudeSSEHTTPResponse(lines ...string) *http.Response {
|
||||
body := strings.Join(lines, "\n")
|
||||
if !strings.HasSuffix(body, "\n") {
|
||||
body += "\n"
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
}
|
||||
|
||||
func parseClaudeFrames(t *testing.T, body string) []claudeFrame {
|
||||
t.Helper()
|
||||
chunks := strings.Split(body, "\n\n")
|
||||
frames := make([]claudeFrame, 0, len(chunks))
|
||||
for _, chunk := range chunks {
|
||||
chunk = strings.TrimSpace(chunk)
|
||||
if chunk == "" {
|
||||
continue
|
||||
}
|
||||
lines := strings.Split(chunk, "\n")
|
||||
eventName := ""
|
||||
dataPayload := ""
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
switch {
|
||||
case strings.HasPrefix(line, "event:"):
|
||||
eventName = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
|
||||
case strings.HasPrefix(line, "data:"):
|
||||
dataPayload = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
}
|
||||
}
|
||||
if eventName == "" || dataPayload == "" {
|
||||
continue
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(dataPayload), &payload); err != nil {
|
||||
t.Fatalf("decode frame failed: %v, payload=%s", err, dataPayload)
|
||||
}
|
||||
frames = append(frames, claudeFrame{Event: eventName, Payload: payload})
|
||||
}
|
||||
return frames
|
||||
}
|
||||
|
||||
func findClaudeFrames(frames []claudeFrame, event string) []claudeFrame {
|
||||
out := make([]claudeFrame, 0)
|
||||
for _, f := range frames {
|
||||
if f.Event == event {
|
||||
out = append(out, f)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeTextIncrementsWithEventHeaders(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"Hel"}`,
|
||||
`data: {"p":"response/content","v":"lo"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "hi"}}, false, false, nil)
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: message_start") {
|
||||
t.Fatalf("missing event header: message_start, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: content_block_delta") {
|
||||
t.Fatalf("missing event header: content_block_delta, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: message_stop") {
|
||||
t.Fatalf("missing event header: message_stop, body=%s", body)
|
||||
}
|
||||
|
||||
frames := parseClaudeFrames(t, body)
|
||||
deltas := findClaudeFrames(frames, "content_block_delta")
|
||||
if len(deltas) < 2 {
|
||||
t.Fatalf("expected at least 2 text deltas, got=%d body=%s", len(deltas), body)
|
||||
}
|
||||
combined := strings.Builder{}
|
||||
for _, f := range deltas {
|
||||
delta, _ := f.Payload["delta"].(map[string]any)
|
||||
if delta["type"] == "text_delta" {
|
||||
combined.WriteString(asString(delta["text"]))
|
||||
}
|
||||
}
|
||||
if combined.String() != "Hello" {
|
||||
t.Fatalf("unexpected combined text: %q body=%s", combined.String(), body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeThinkingDelta(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
`data: {"p":"response/thinking_content","v":"思"}`,
|
||||
`data: {"p":"response/thinking_content","v":"考"}`,
|
||||
`data: {"p":"response/content","v":"ok"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "hi"}}, true, false, nil)
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
foundThinkingDelta := false
|
||||
for _, f := range findClaudeFrames(frames, "content_block_delta") {
|
||||
delta, _ := f.Payload["delta"].(map[string]any)
|
||||
if delta["type"] == "thinking_delta" {
|
||||
foundThinkingDelta = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundThinkingDelta {
|
||||
t.Fatalf("expected thinking_delta event, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeSkipsThinkingFallbackWhenFinalTextExists(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
`data: {"p":"response/thinking_content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
|
||||
`data: {"p":"response/thinking_content","v":",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: {"p":"response/content","v":"normal answer"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, true, false, []string{"search"})
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
for _, f := range findClaudeFrames(frames, "content_block_start") {
|
||||
contentBlock, _ := f.Payload["content_block"].(map[string]any)
|
||||
if contentBlock["type"] == "tool_use" {
|
||||
t.Fatalf("unexpected tool_use block when final text exists, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
foundEndTurn := false
|
||||
for _, f := range findClaudeFrames(frames, "message_delta") {
|
||||
delta, _ := f.Payload["delta"].(map[string]any)
|
||||
if delta["stop_reason"] == "end_turn" {
|
||||
foundEndTurn = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundEndTurn {
|
||||
t.Fatalf("expected stop_reason=end_turn, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeUpstreamErrorEvent(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
`data: {"error":{"message":"boom"}}`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "hi"}}, false, false, nil)
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
errFrames := findClaudeFrames(frames, "error")
|
||||
if len(errFrames) == 0 {
|
||||
t.Fatalf("expected error event frame, body=%s", rec.Body.String())
|
||||
}
|
||||
if errFrames[0].Payload["type"] != "error" {
|
||||
t.Fatalf("expected error payload type, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimePingEvent(t *testing.T) {
|
||||
h := &Handler{}
|
||||
oldPing := claudeStreamPingInterval
|
||||
oldIdle := claudeStreamIdleTimeout
|
||||
oldKeepalive := claudeStreamMaxKeepaliveCnt
|
||||
claudeStreamPingInterval = 10 * time.Millisecond
|
||||
claudeStreamIdleTimeout = 300 * time.Millisecond
|
||||
claudeStreamMaxKeepaliveCnt = 50
|
||||
defer func() {
|
||||
claudeStreamPingInterval = oldPing
|
||||
claudeStreamIdleTimeout = oldIdle
|
||||
claudeStreamMaxKeepaliveCnt = oldKeepalive
|
||||
}()
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Header: make(http.Header), Body: pr}
|
||||
go func() {
|
||||
time.Sleep(40 * time.Millisecond)
|
||||
_, _ = io.WriteString(pw, "data: {\"p\":\"response/content\",\"v\":\"hi\"}\n")
|
||||
_, _ = io.WriteString(pw, "data: [DONE]\n")
|
||||
_ = pw.Close()
|
||||
}()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "hi"}}, false, false, nil)
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
if len(findClaudeFrames(frames, "ping")) == 0 {
|
||||
t.Fatalf("expected ping event in stream, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectDeepSeekRegression(t *testing.T) {
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
`data: {"p":"response/thinking_content","v":"想"}`,
|
||||
`data: {"p":"response/content","v":"答"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
result := sse.CollectStream(resp, true, true)
|
||||
if result.Thinking != "想" {
|
||||
t.Fatalf("unexpected thinking: %q", result.Thinking)
|
||||
}
|
||||
if result.Text != "答" {
|
||||
t.Fatalf("unexpected text: %q", result.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func asString(v any) string {
|
||||
s, _ := v.(string)
|
||||
return s
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeToolSafetyAcrossStructuredFormats(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload string
|
||||
wantToolUse bool
|
||||
}{
|
||||
{name: "invoke_parameter_wrapper", payload: `<tool_calls><invoke name="Bash"><parameter name="command">pwd</parameter></invoke></tool_calls>`, wantToolUse: true},
|
||||
{name: "legacy_single_tool_root", payload: `<tool><tool_name>Bash</tool_name><param><command>pwd</command></param></tool>`, wantToolUse: false},
|
||||
{name: "legacy_tool_call_json", payload: `<tool>{"tool":"Bash","params":{"command":"pwd"}}</tool>`, wantToolUse: false},
|
||||
{name: "legacy_nested_tool_tag_style", payload: `<tool><tool name="Bash"><command>pwd</command></tool_call></tool>`, wantToolUse: false},
|
||||
{name: "legacy_function_tag_style", payload: `<function_call>Bash</function_call><function parameter name="command">pwd</function parameter>`, wantToolUse: false},
|
||||
{name: "legacy_antml_argument_style", payload: `<antml:function_calls><antml:function_call id="1" name="Bash"><antml:argument name="command">pwd</antml:argument></antml:function_call></antml:function_calls>`, wantToolUse: false},
|
||||
{name: "legacy_antml_function_attr_parameters", payload: `<antml:function_calls><antml:function_call id="1" function="Bash"><antml:parameters>{"command":"pwd"}</antml:parameters></antml:function_call></antml:function_calls>`, wantToolUse: false},
|
||||
{name: "legacy_function_calls_wrapper", payload: `<function_calls><invoke name="Bash"><parameter name="command">pwd</parameter></invoke></function_calls>`, wantToolUse: false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"`+strings.ReplaceAll(tc.payload, `"`, `\"`)+`"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, false, false, []string{"Bash"})
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
foundToolUse := false
|
||||
for _, f := range findClaudeFrames(frames, "content_block_start") {
|
||||
contentBlock, _ := f.Payload["content_block"].(map[string]any)
|
||||
if contentBlock["type"] == "tool_use" {
|
||||
foundToolUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundToolUse != tc.wantToolUse {
|
||||
t.Fatalf("unexpected tool_use=%v for format %s, body=%s", foundToolUse, tc.name, rec.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeDetectsToolUseWithLeadingProse(t *testing.T) {
|
||||
h := &Handler{}
|
||||
payload := "I'll call a tool now.\\n<tool_calls><invoke name=\\\"write_file\\\"><parameter name=\\\"path\\\">/tmp/a.txt</parameter><parameter name=\\\"content\\\">abc</parameter></invoke></tool_calls>"
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"`+payload+`"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, false, false, []string{"write_file"})
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
foundToolUse := false
|
||||
for _, f := range findClaudeFrames(frames, "content_block_start") {
|
||||
contentBlock, _ := f.Payload["content_block"].(map[string]any)
|
||||
if contentBlock["type"] == "tool_use" && contentBlock["name"] == "write_file" {
|
||||
foundToolUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundToolUse {
|
||||
t.Fatalf("expected tool_use block with leading prose payload, body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
for _, f := range findClaudeFrames(frames, "message_delta") {
|
||||
delta, _ := f.Payload["delta"].(map[string]any)
|
||||
if delta["stop_reason"] == "tool_use" {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatalf("expected stop_reason=tool_use, body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestHandleClaudeStreamRealtimeIgnoresUnclosedFencedToolExample(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeClaudeSSEHTTPResponse(
|
||||
"data: {\"p\":\"response/content\",\"v\":\"Here is an example:\\n```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"Bash\\\",\\\"input\\\":{\\\"command\\\":\\\"pwd\\\"}}]}\"}",
|
||||
"data: {\"p\":\"response/content\",\"v\":\"\\n```\\nDo not execute it.\"}",
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
|
||||
|
||||
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "show example only"}}, false, false, []string{"Bash"})
|
||||
|
||||
frames := parseClaudeFrames(t, rec.Body.String())
|
||||
foundToolUse := false
|
||||
for _, f := range findClaudeFrames(frames, "content_block_start") {
|
||||
contentBlock, _ := f.Payload["content_block"].(map[string]any)
|
||||
if contentBlock["type"] == "tool_use" {
|
||||
foundToolUse = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundToolUse {
|
||||
t.Fatalf("expected no tool_use for fenced example, body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
foundToolStop := false
|
||||
for _, f := range findClaudeFrames(frames, "message_delta") {
|
||||
delta, _ := f.Payload["delta"].(map[string]any)
|
||||
if delta["stop_reason"] == "tool_use" {
|
||||
foundToolStop = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundToolStop {
|
||||
t.Fatalf("expected stop_reason to remain content-only, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// Backward-compatible alias for historical test name used in CI logs.
|
||||
func TestHandleClaudeStreamRealtimePromotesUnclosedFencedToolExample(t *testing.T) {
|
||||
TestHandleClaudeStreamRealtimeIgnoresUnclosedFencedToolExample(t)
|
||||
}
|
||||
51
internal/httpapi/claude/handler_tokens.go
Normal file
51
internal/httpapi/claude/handler_tokens.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
writeClaudeError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeClaudeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
messages, _ := req["messages"].([]any)
|
||||
if model == "" || len(messages) == 0 {
|
||||
writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
|
||||
return
|
||||
}
|
||||
inputTokens := 0
|
||||
if sys, ok := req["system"].(string); ok {
|
||||
inputTokens += util.EstimateTokens(sys)
|
||||
}
|
||||
for _, item := range messages {
|
||||
msg, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
inputTokens += 2
|
||||
inputTokens += util.EstimateTokens(extractMessageContent(msg["content"]))
|
||||
}
|
||||
if tools, ok := req["tools"].([]any); ok {
|
||||
for _, t := range tools {
|
||||
b, _ := json.Marshal(t)
|
||||
inputTokens += util.EstimateTokens(string(b))
|
||||
}
|
||||
}
|
||||
if inputTokens < 1 {
|
||||
inputTokens = 1
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens})
|
||||
}
|
||||
560
internal/httpapi/claude/handler_util_test.go
Normal file
560
internal/httpapi/claude/handler_util_test.go
Normal file
@@ -0,0 +1,560 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// ─── normalizeClaudeMessages ─────────────────────────────────────────
|
||||
|
||||
func TestNormalizeClaudeMessagesSimpleString(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{"role": "user", "content": "Hello"},
|
||||
}
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(got))
|
||||
}
|
||||
m := got[0].(map[string]any)
|
||||
if m["content"] != "Hello" {
|
||||
t.Fatalf("expected 'Hello', got %v", m["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesArrayContent(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "line1"},
|
||||
map[string]any{"type": "text", "text": "line2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
m := got[0].(map[string]any)
|
||||
if m["content"] != "line1\nline2" {
|
||||
t.Fatalf("expected joined text, got %q", m["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesToolResult(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "tool_result", "content": "tool output"},
|
||||
},
|
||||
},
|
||||
}
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %d", len(got))
|
||||
}
|
||||
m := got[0].(map[string]any)
|
||||
if m["role"] != "tool" {
|
||||
t.Fatalf("expected tool role preserved, got %#v", m["role"])
|
||||
}
|
||||
content, _ := m["content"].(string)
|
||||
if content != "tool output" {
|
||||
t.Fatalf("expected raw tool output content preserved, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesToolUseToAssistantToolCalls(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_1",
|
||||
"name": "search_web",
|
||||
"input": map[string]any{"query": "latest"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one normalized tool-call message, got %d", len(got))
|
||||
}
|
||||
m := got[0].(map[string]any)
|
||||
if m["role"] != "assistant" {
|
||||
t.Fatalf("expected assistant role, got %#v", m["role"])
|
||||
}
|
||||
tc, _ := m["tool_calls"].([]any)
|
||||
if len(tc) != 1 {
|
||||
t.Fatalf("expected one tool call, got %#v", m["tool_calls"])
|
||||
}
|
||||
call, _ := tc[0].(map[string]any)
|
||||
if call["id"] != "call_1" {
|
||||
t.Fatalf("expected call id preserved, got %#v", call)
|
||||
}
|
||||
content, _ := m["content"].(string)
|
||||
if !containsStr(content, "<tool_calls>") || !containsStr(content, `<invoke name="search_web">`) {
|
||||
t.Fatalf("expected assistant content to include XML tool call history, got %q", content)
|
||||
}
|
||||
if !containsStr(content, `<parameter name="query"><![CDATA[latest]]></parameter>`) {
|
||||
t.Fatalf("expected assistant content to include serialized parameters, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesDoesNotPromoteUserToolUse(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "call_unsafe",
|
||||
"name": "dangerous_tool",
|
||||
"input": map[string]any{"value": "x"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %d", len(got))
|
||||
}
|
||||
m := got[0].(map[string]any)
|
||||
if m["role"] != "user" {
|
||||
t.Fatalf("expected user role preserved, got %#v", m["role"])
|
||||
}
|
||||
if _, ok := m["tool_calls"]; ok {
|
||||
t.Fatalf("expected no tool_calls promotion for user message, got %#v", m["tool_calls"])
|
||||
}
|
||||
content, _ := m["content"].(string)
|
||||
if !containsStr(content, `"type":"tool_use"`) || !containsStr(content, "dangerous_tool") {
|
||||
t.Fatalf("expected raw tool_use block preserved in user content, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesSkipsNonMap(t *testing.T) {
|
||||
msgs := []any{"not a map", 42}
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected 0 messages for non-map items, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesEmpty(t *testing.T) {
|
||||
got := normalizeClaudeMessages(nil)
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected 0, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesPreservesRole(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{"role": "assistant", "content": "response"},
|
||||
}
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
m := got[0].(map[string]any)
|
||||
if m["role"] != "assistant" {
|
||||
t.Fatalf("expected 'assistant', got %q", m["role"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesMixedContentBlocks(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "Hello"},
|
||||
map[string]any{"type": "image", "source": map[string]any{"type": "base64", "data": strings.Repeat("A", 2048)}},
|
||||
map[string]any{"type": "text", "text": "World"},
|
||||
},
|
||||
},
|
||||
}
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
m := got[0].(map[string]any)
|
||||
content, _ := m["content"].(string)
|
||||
if !containsStr(content, "Hello") || !containsStr(content, "World") || !containsStr(content, `"type":"image"`) {
|
||||
t.Fatalf("expected text plus non-text block marker preserved, got %q", content)
|
||||
}
|
||||
if !containsStr(content, omittedBinaryMarker) {
|
||||
t.Fatalf("expected binary payload omitted marker, got %q", content)
|
||||
}
|
||||
if containsStr(content, strings.Repeat("A", 100)) {
|
||||
t.Fatalf("expected raw base64 payload not to be included, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesToolResultNonTextPayloadStringified(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "call_image_1",
|
||||
"name": "vision_tool",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "image analysis"},
|
||||
map[string]any{
|
||||
"type": "image",
|
||||
"source": map[string]any{"type": "base64", "media_type": "image/png", "data": strings.Repeat("B", 2048)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected one normalized message, got %d", len(got))
|
||||
}
|
||||
m := got[0].(map[string]any)
|
||||
if m["role"] != "tool" {
|
||||
t.Fatalf("expected tool role, got %#v", m["role"])
|
||||
}
|
||||
content, _ := m["content"].(string)
|
||||
if !containsStr(content, `"type":"tool_result"`) || !containsStr(content, `"type":"image"`) {
|
||||
t.Fatalf("expected non-text tool_result payload to be JSON stringified, got %q", content)
|
||||
}
|
||||
if !containsStr(content, omittedBinaryMarker) {
|
||||
t.Fatalf("expected binary data to be sanitized with omitted marker, got %q", content)
|
||||
}
|
||||
if containsStr(content, strings.Repeat("B", 100)) {
|
||||
t.Fatalf("expected raw base64 payload not to be included, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeMessagesBackfillsToolResultCallIDByName(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"name": "search_web",
|
||||
"input": map[string]any{"query": "latest"},
|
||||
},
|
||||
},
|
||||
},
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "tool_result",
|
||||
"name": "search_web",
|
||||
"content": "ok",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := normalizeClaudeMessages(msgs)
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %#v", got)
|
||||
}
|
||||
assistant, _ := got[0].(map[string]any)
|
||||
tc, _ := assistant["tool_calls"].([]any)
|
||||
call, _ := tc[0].(map[string]any)
|
||||
callID, _ := call["id"].(string)
|
||||
if !strings.HasPrefix(callID, "call_claude_") {
|
||||
t.Fatalf("expected generated call id, got %#v", call)
|
||||
}
|
||||
toolMsg, _ := got[1].(map[string]any)
|
||||
if toolMsg["tool_call_id"] != callID {
|
||||
t.Fatalf("expected tool_result to reuse generated id, got %#v", toolMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── buildClaudeToolPrompt ───────────────────────────────────────────
|
||||
|
||||
func TestBuildClaudeToolPromptSingleTool(t *testing.T) {
|
||||
tools := []any{
|
||||
map[string]any{
|
||||
"name": "search",
|
||||
"description": "Search the web",
|
||||
"input_schema": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
prompt := buildClaudeToolPrompt(tools)
|
||||
if prompt == "" {
|
||||
t.Fatal("expected non-empty prompt")
|
||||
}
|
||||
// Should contain tool name and description
|
||||
if !containsStr(prompt, "search") {
|
||||
t.Fatalf("expected 'search' in prompt")
|
||||
}
|
||||
if !containsStr(prompt, "Search the web") {
|
||||
t.Fatalf("expected description in prompt")
|
||||
}
|
||||
if !containsStr(prompt, "<tool_calls>") {
|
||||
t.Fatalf("expected XML tool_calls format in prompt")
|
||||
}
|
||||
if !containsStr(prompt, "TOOL CALL FORMAT") {
|
||||
t.Fatalf("expected tool call format header in prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClaudeToolPromptMultipleTools(t *testing.T) {
|
||||
tools := []any{
|
||||
map[string]any{"name": "tool1", "description": "desc1"},
|
||||
map[string]any{"name": "tool2", "description": "desc2"},
|
||||
}
|
||||
prompt := buildClaudeToolPrompt(tools)
|
||||
if !containsStr(prompt, "tool1") || !containsStr(prompt, "tool2") {
|
||||
t.Fatalf("expected both tools in prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClaudeToolPromptSupportsOpenAIStyleFunctionTool(t *testing.T) {
|
||||
tools := []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "search",
|
||||
"description": "Search via function tool",
|
||||
"parameters": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"q": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
prompt := buildClaudeToolPrompt(tools)
|
||||
if !containsStr(prompt, "Tool: search") {
|
||||
t.Fatalf("expected OpenAI-style function tool name in prompt, got: %q", prompt)
|
||||
}
|
||||
if !containsStr(prompt, "Search via function tool") {
|
||||
t.Fatalf("expected OpenAI-style function tool description in prompt, got: %q", prompt)
|
||||
}
|
||||
if !containsStr(prompt, "\"q\"") {
|
||||
t.Fatalf("expected parameters schema serialized in prompt, got: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClaudeToolPromptSkipsNonMap(t *testing.T) {
|
||||
tools := []any{"not a map"}
|
||||
prompt := buildClaudeToolPrompt(tools)
|
||||
// No valid tools → empty prompt
|
||||
if prompt != "" {
|
||||
t.Fatalf("expected empty prompt for non-map tools, got: %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── hasSystemMessage ────────────────────────────────────────────────
|
||||
|
||||
func TestHasSystemMessageTrue(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{"role": "system", "content": "You are a helper"},
|
||||
map[string]any{"role": "user", "content": "Hi"},
|
||||
}
|
||||
if !hasSystemMessage(msgs) {
|
||||
t.Fatal("expected true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasSystemMessageFalse(t *testing.T) {
|
||||
msgs := []any{
|
||||
map[string]any{"role": "user", "content": "Hi"},
|
||||
map[string]any{"role": "assistant", "content": "Hello"},
|
||||
}
|
||||
if hasSystemMessage(msgs) {
|
||||
t.Fatal("expected false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasSystemMessageEmpty(t *testing.T) {
|
||||
if hasSystemMessage(nil) {
|
||||
t.Fatal("expected false for nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasSystemMessageNonMap(t *testing.T) {
|
||||
msgs := []any{"not a map"}
|
||||
if hasSystemMessage(msgs) {
|
||||
t.Fatal("expected false for non-map")
|
||||
}
|
||||
}
|
||||
|
||||
// ─── extractClaudeToolNames ──────────────────────────────────────────
|
||||
|
||||
func TestExtractClaudeToolNamesSingle(t *testing.T) {
|
||||
tools := []any{
|
||||
map[string]any{"name": "search"},
|
||||
}
|
||||
names := extractClaudeToolNames(tools)
|
||||
if len(names) != 1 || names[0] != "search" {
|
||||
t.Fatalf("expected [search], got %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClaudeToolNamesMultiple(t *testing.T) {
|
||||
tools := []any{
|
||||
map[string]any{"name": "search"},
|
||||
map[string]any{"name": "calculate"},
|
||||
}
|
||||
names := extractClaudeToolNames(tools)
|
||||
if len(names) != 2 {
|
||||
t.Fatalf("expected 2 names, got %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClaudeToolNamesSkipsEmptyName(t *testing.T) {
|
||||
tools := []any{
|
||||
map[string]any{"name": ""},
|
||||
map[string]any{"name": "valid"},
|
||||
}
|
||||
names := extractClaudeToolNames(tools)
|
||||
if len(names) != 1 || names[0] != "valid" {
|
||||
t.Fatalf("expected [valid], got %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClaudeToolNamesSkipsNonMap(t *testing.T) {
|
||||
tools := []any{"not a map", 42}
|
||||
names := extractClaudeToolNames(tools)
|
||||
if len(names) != 0 {
|
||||
t.Fatalf("expected 0, got %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClaudeToolNamesNil(t *testing.T) {
|
||||
names := extractClaudeToolNames(nil)
|
||||
if len(names) != 0 {
|
||||
t.Fatalf("expected 0, got %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClaudeToolNamesSupportsOpenAIStyleFunctionTool(t *testing.T) {
|
||||
tools := []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "search",
|
||||
},
|
||||
},
|
||||
}
|
||||
names := extractClaudeToolNames(tools)
|
||||
if len(names) != 1 || names[0] != "search" {
|
||||
t.Fatalf("expected [search], got %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── toMessageMaps ───────────────────────────────────────────────────
|
||||
|
||||
func TestToMessageMapsNormal(t *testing.T) {
|
||||
input := []any{
|
||||
map[string]any{"role": "user", "content": "Hello"},
|
||||
}
|
||||
got := toMessageMaps(input)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessageMapsNonSlice(t *testing.T) {
|
||||
got := toMessageMaps("not a slice")
|
||||
if got != nil {
|
||||
t.Fatalf("expected nil, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessageMapsSkipsNonMap(t *testing.T) {
|
||||
input := []any{"string", map[string]any{"role": "user"}, 42}
|
||||
got := toMessageMaps(input)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 map, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestToMessageMapsNil(t *testing.T) {
|
||||
got := toMessageMaps(nil)
|
||||
if got != nil {
|
||||
t.Fatalf("expected nil, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── extractMessageContent ──────────────────────────────────────────
|
||||
|
||||
func TestExtractMessageContentString(t *testing.T) {
|
||||
if got := extractMessageContent("hello"); got != "hello" {
|
||||
t.Fatalf("expected 'hello', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractMessageContentArray(t *testing.T) {
|
||||
input := []any{"part1", "part2"}
|
||||
got := extractMessageContent(input)
|
||||
if got != "part1\npart2" {
|
||||
t.Fatalf("expected joined, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractMessageContentOther(t *testing.T) {
|
||||
got := extractMessageContent(42)
|
||||
if got != "42" {
|
||||
t.Fatalf("expected '42', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractMessageContentNil(t *testing.T) {
|
||||
got := extractMessageContent(nil)
|
||||
if got != "<nil>" {
|
||||
t.Fatalf("expected '<nil>', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── cloneMap ────────────────────────────────────────────────────────
|
||||
|
||||
func TestCloneMapBasic(t *testing.T) {
|
||||
original := map[string]any{"a": 1, "b": "hello"}
|
||||
clone := cloneMap(original)
|
||||
original["a"] = 999
|
||||
if clone["a"] != 1 {
|
||||
t.Fatalf("expected 1, got %v", clone["a"])
|
||||
}
|
||||
if clone["b"] != "hello" {
|
||||
t.Fatalf("expected 'hello', got %v", clone["b"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloneMapEmpty(t *testing.T) {
|
||||
clone := cloneMap(map[string]any{})
|
||||
if len(clone) != 0 {
|
||||
t.Fatalf("expected empty, got %v", clone)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloneMapNested(t *testing.T) {
|
||||
// cloneMap is shallow, so nested maps share references
|
||||
inner := map[string]any{"key": "value"}
|
||||
original := map[string]any{"nested": inner}
|
||||
clone := cloneMap(original)
|
||||
// Shallow clone means inner is shared
|
||||
inner["key"] = "modified"
|
||||
cloneNested := clone["nested"].(map[string]any)
|
||||
if cloneNested["key"] != "modified" {
|
||||
t.Fatal("expected shallow clone to share nested references")
|
||||
}
|
||||
}
|
||||
|
||||
// helper
|
||||
func containsStr(s, sub string) bool {
|
||||
return len(s) >= len(sub) && (s == sub || len(s) > 0 && findSubstring(s, sub))
|
||||
}
|
||||
|
||||
func findSubstring(s, sub string) bool {
|
||||
for i := 0; i <= len(s)-len(sub); i++ {
|
||||
if s[i:i+len(sub)] == sub {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
226
internal/httpapi/claude/handler_utils.go
Normal file
226
internal/httpapi/claude/handler_utils.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"ds2api/internal/toolcall"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/prompt"
|
||||
)
|
||||
|
||||
func normalizeClaudeMessages(messages []any) []any {
|
||||
out := make([]any, 0, len(messages))
|
||||
state := &claudeToolCallState{
|
||||
nameByID: map[string]string{},
|
||||
lastIDByName: map[string]string{},
|
||||
callIDSequence: 0,
|
||||
}
|
||||
for _, m := range messages {
|
||||
msg, ok := m.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", msg["role"])))
|
||||
switch content := msg["content"].(type) {
|
||||
case []any:
|
||||
textParts := make([]string, 0, len(content))
|
||||
flushText := func() {
|
||||
if len(textParts) == 0 {
|
||||
return
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"role": role,
|
||||
"content": strings.Join(textParts, "\n"),
|
||||
})
|
||||
textParts = textParts[:0]
|
||||
}
|
||||
for _, block := range content {
|
||||
b, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
typeStr := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", b["type"])))
|
||||
switch typeStr {
|
||||
case "text":
|
||||
if t, ok := b["text"].(string); ok {
|
||||
textParts = append(textParts, t)
|
||||
}
|
||||
case "tool_use":
|
||||
if role == "assistant" {
|
||||
flushText()
|
||||
if toolMsg := normalizeClaudeToolUseToAssistant(b, state); toolMsg != nil {
|
||||
out = append(out, toolMsg)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if raw := strings.TrimSpace(formatClaudeUnknownBlockForPrompt(b)); raw != "" {
|
||||
textParts = append(textParts, raw)
|
||||
}
|
||||
case "tool_result":
|
||||
flushText()
|
||||
if toolMsg := normalizeClaudeToolResultToToolMessage(b, state); toolMsg != nil {
|
||||
out = append(out, toolMsg)
|
||||
}
|
||||
default:
|
||||
if raw := strings.TrimSpace(formatClaudeUnknownBlockForPrompt(b)); raw != "" {
|
||||
textParts = append(textParts, raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
flushText()
|
||||
default:
|
||||
copied := cloneMap(msg)
|
||||
out = append(out, copied)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildClaudeToolPrompt(tools []any) string {
|
||||
toolSchemas := make([]string, 0, len(tools))
|
||||
names := make([]string, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
m, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name, desc, schemaObj := extractClaudeToolMeta(m)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
names = append(names, name)
|
||||
schema, _ := json.Marshal(schemaObj)
|
||||
toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
|
||||
}
|
||||
if len(toolSchemas) == 0 {
|
||||
return ""
|
||||
}
|
||||
return "You have access to these tools:\n\n" +
|
||||
strings.Join(toolSchemas, "\n\n") + "\n\n" +
|
||||
toolcall.BuildToolCallInstructions(names)
|
||||
}
|
||||
|
||||
//nolint:unused // retained for compatibility with pending Claude tool-result prompt flow.
|
||||
func formatClaudeToolResultForPrompt(block map[string]any) string {
|
||||
if block == nil {
|
||||
return ""
|
||||
}
|
||||
payload := map[string]any{
|
||||
"type": "tool_result",
|
||||
"content": block["content"],
|
||||
}
|
||||
if toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"])); toolCallID != "" {
|
||||
payload["tool_call_id"] = toolCallID
|
||||
} else if toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"])); toolCallID != "" {
|
||||
payload["tool_call_id"] = toolCallID
|
||||
}
|
||||
if name := strings.TrimSpace(fmt.Sprintf("%v", block["name"])); name != "" {
|
||||
payload["name"] = name
|
||||
}
|
||||
b, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", payload))
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func normalizeClaudeToolUseToAssistant(block map[string]any, state *claudeToolCallState) map[string]any {
|
||||
if block == nil {
|
||||
return nil
|
||||
}
|
||||
name := strings.TrimSpace(fmt.Sprintf("%v", block["name"]))
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
callID := safeStringValue(block["id"])
|
||||
if callID == "" {
|
||||
callID = safeStringValue(block["tool_use_id"])
|
||||
}
|
||||
if callID == "" {
|
||||
callID = state.nextID()
|
||||
}
|
||||
state.nameByID[callID] = name
|
||||
state.lastIDByName[strings.ToLower(name)] = callID
|
||||
arguments := block["input"]
|
||||
if arguments == nil {
|
||||
arguments = map[string]any{}
|
||||
}
|
||||
argsJSON, err := json.Marshal(arguments)
|
||||
if err != nil || len(argsJSON) == 0 {
|
||||
argsJSON = []byte("{}")
|
||||
}
|
||||
toolCalls := []any{
|
||||
map[string]any{
|
||||
"id": callID,
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": name,
|
||||
"arguments": string(argsJSON),
|
||||
},
|
||||
},
|
||||
}
|
||||
return map[string]any{
|
||||
"role": "assistant",
|
||||
"content": prompt.FormatToolCallsForPrompt(toolCalls),
|
||||
"tool_calls": toolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeClaudeToolResultToToolMessage(block map[string]any, state *claudeToolCallState) map[string]any {
|
||||
if block == nil {
|
||||
return nil
|
||||
}
|
||||
name := safeStringValue(block["name"])
|
||||
toolCallID := safeStringValue(block["tool_use_id"])
|
||||
if toolCallID == "" {
|
||||
toolCallID = safeStringValue(block["tool_call_id"])
|
||||
}
|
||||
if toolCallID == "" {
|
||||
if name != "" {
|
||||
toolCallID = strings.TrimSpace(state.lastIDByName[strings.ToLower(name)])
|
||||
}
|
||||
}
|
||||
if toolCallID == "" {
|
||||
toolCallID = state.nextID()
|
||||
}
|
||||
out := map[string]any{
|
||||
"role": "tool",
|
||||
"tool_call_id": toolCallID,
|
||||
"content": normalizeClaudeToolResultContent(block["content"]),
|
||||
}
|
||||
if name != "" {
|
||||
out["name"] = name
|
||||
state.nameByID[toolCallID] = name
|
||||
state.lastIDByName[strings.ToLower(name)] = toolCallID
|
||||
} else if inferred := strings.TrimSpace(state.nameByID[toolCallID]); inferred != "" {
|
||||
out["name"] = inferred
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeClaudeToolResultContent(content any) any {
|
||||
if text, ok := content.(string); ok {
|
||||
return text
|
||||
}
|
||||
payload := map[string]any{
|
||||
"type": "tool_result",
|
||||
"content": content,
|
||||
}
|
||||
b, err := json.Marshal(sanitizeClaudeBlockForPrompt(payload))
|
||||
if err != nil {
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", content))
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func formatClaudeBlockRaw(block map[string]any) string {
|
||||
if block == nil {
|
||||
return ""
|
||||
}
|
||||
b, err := json.Marshal(block)
|
||||
if err != nil {
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", block))
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
106
internal/httpapi/claude/handler_utils_sanitize.go
Normal file
106
internal/httpapi/claude/handler_utils_sanitize.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
maxClaudeRawPromptChars = 1024
|
||||
omittedBinaryMarker = "[omitted_binary_payload]"
|
||||
)
|
||||
|
||||
func formatClaudeUnknownBlockForPrompt(block map[string]any) string {
|
||||
if block == nil {
|
||||
return ""
|
||||
}
|
||||
safe := sanitizeClaudeBlockForPrompt(block)
|
||||
raw := strings.TrimSpace(formatClaudeBlockRaw(safe))
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
if len(raw) > maxClaudeRawPromptChars {
|
||||
return raw[:maxClaudeRawPromptChars] + "...(truncated)"
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
func sanitizeClaudeBlockForPrompt(block map[string]any) map[string]any {
|
||||
out := cloneMap(block)
|
||||
for k, v := range out {
|
||||
if looksLikeBinaryFieldName(k) {
|
||||
out[k] = omittedBinaryMarker
|
||||
continue
|
||||
}
|
||||
switch inner := v.(type) {
|
||||
case map[string]any:
|
||||
out[k] = sanitizeClaudeBlockForPrompt(inner)
|
||||
case []any:
|
||||
out[k] = sanitizeClaudeArrayForPrompt(inner)
|
||||
case string:
|
||||
out[k] = sanitizeClaudeStringForPrompt(k, inner)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func sanitizeClaudeArrayForPrompt(items []any) []any {
|
||||
out := make([]any, 0, len(items))
|
||||
for _, item := range items {
|
||||
switch v := item.(type) {
|
||||
case map[string]any:
|
||||
out = append(out, sanitizeClaudeBlockForPrompt(v))
|
||||
case []any:
|
||||
out = append(out, sanitizeClaudeArrayForPrompt(v))
|
||||
default:
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func sanitizeClaudeStringForPrompt(key, value string) string {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if looksLikeBinaryFieldName(key) || looksLikeBase64Payload(trimmed) {
|
||||
return omittedBinaryMarker
|
||||
}
|
||||
if len(trimmed) > maxClaudeRawPromptChars {
|
||||
return trimmed[:maxClaudeRawPromptChars] + "...(truncated)"
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func looksLikeBinaryFieldName(name string) bool {
|
||||
n := strings.ToLower(strings.TrimSpace(name))
|
||||
return n == "data" || n == "bytes" || n == "base64" || n == "inline_data" || n == "inlinedata"
|
||||
}
|
||||
|
||||
func looksLikeBase64Payload(v string) bool {
|
||||
if len(v) < 512 {
|
||||
return false
|
||||
}
|
||||
compact := strings.TrimRight(v, "=")
|
||||
if compact == "" {
|
||||
return false
|
||||
}
|
||||
for _, ch := range compact {
|
||||
if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '+' || ch == '/' || ch == '-' || ch == '_' {
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
//nolint:unused // helper kept for compatibility with upcoming sanitize pipeline.
|
||||
func marshalCompactJSON(v any) string {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
13
internal/httpapi/claude/output_clean.go
Normal file
13
internal/httpapi/claude/output_clean.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package claude
|
||||
|
||||
import textclean "ds2api/internal/textclean"
|
||||
|
||||
func cleanVisibleOutput(text string, stripReferenceMarkers bool) string {
|
||||
if text == "" {
|
||||
return text
|
||||
}
|
||||
if stripReferenceMarkers {
|
||||
text = textclean.StripReferenceMarkers(text)
|
||||
}
|
||||
return text
|
||||
}
|
||||
197
internal/httpapi/claude/proxy_vercel_test.go
Normal file
197
internal/httpapi/claude/proxy_vercel_test.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type claudeProxyStoreStub struct {
|
||||
aliases map[string]string
|
||||
}
|
||||
|
||||
func (s claudeProxyStoreStub) ModelAliases() map[string]string { return s.aliases }
|
||||
|
||||
func (claudeProxyStoreStub) CompatStripReferenceMarkers() bool { return true }
|
||||
|
||||
type openAIProxyStub struct {
|
||||
status int
|
||||
body string
|
||||
}
|
||||
|
||||
func TestClaudeProxyViaOpenAIPrefersGlobalAliasMapping(t *testing.T) {
|
||||
openAI := &openAIProxyCaptureStub{}
|
||||
h := &Handler{
|
||||
Store: claudeProxyStoreStub{
|
||||
aliases: map[string]string{"claude-sonnet-4-6": "deepseek-v4-flash"},
|
||||
},
|
||||
OpenAI: openAI,
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}],"stream":false}`))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.Messages(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if got := strings.TrimSpace(openAI.seenModel); got != "deepseek-v4-flash" {
|
||||
t.Fatalf("expected global alias mapped proxy model deepseek-v4-flash, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func (s openAIProxyStub) ChatCompletions(w http.ResponseWriter, _ *http.Request) {
|
||||
if s.status == 0 {
|
||||
s.status = http.StatusOK
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(s.status)
|
||||
_, _ = w.Write([]byte(s.body))
|
||||
}
|
||||
|
||||
type openAIProxyCaptureStub struct {
|
||||
seenModel string
|
||||
seenReq map[string]any
|
||||
}
|
||||
|
||||
func (s *openAIProxyCaptureStub) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
s.seenReq = req
|
||||
if m, ok := req["model"].(string); ok {
|
||||
s.seenModel = m
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"id":"ok","choices":[{"message":{"role":"assistant","content":"ok"}}]}`))
|
||||
}
|
||||
|
||||
func TestClaudeProxyViaOpenAIVercelPreparePassthrough(t *testing.T) {
|
||||
h := &Handler{OpenAI: openAIProxyStub{status: 200, body: `{"lease_id":"lease_123","payload":{"a":1}}`}}
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages?__stream_prepare=1", strings.NewReader(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":true}`))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.Messages(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("expected json response, got err=%v body=%s", err, rec.Body.String())
|
||||
}
|
||||
if _, ok := out["lease_id"]; !ok {
|
||||
t.Fatalf("expected lease_id in prepare passthrough, got=%v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeProxyViaOpenAIUsesGlobalAliasMapping(t *testing.T) {
|
||||
openAI := &openAIProxyCaptureStub{}
|
||||
h := &Handler{
|
||||
Store: claudeProxyStoreStub{aliases: map[string]string{"claude-3-opus": "deepseek-v4-pro"}},
|
||||
OpenAI: openAI,
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(`{"model":"claude-3-opus","messages":[{"role":"user","content":"hi"}],"stream":false}`))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.Messages(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if got := strings.TrimSpace(openAI.seenModel); got != "deepseek-v4-pro" {
|
||||
t.Fatalf("expected mapped proxy model deepseek-v4-pro, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeProxyViaOpenAIPreservesThinkingOverride(t *testing.T) {
|
||||
openAI := &openAIProxyCaptureStub{}
|
||||
h := &Handler{
|
||||
Store: claudeProxyStoreStub{aliases: map[string]string{"claude-sonnet-4-6": "deepseek-v4-flash"}},
|
||||
OpenAI: openAI,
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"disabled"},"stream":false}`))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.Messages(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
thinking, _ := openAI.seenReq["thinking"].(map[string]any)
|
||||
if thinking["type"] != "disabled" {
|
||||
t.Fatalf("expected translated OpenAI request to preserve disabled thinking, got %#v", openAI.seenReq)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeProxyViaOpenAIDisablesThinkingByDefault(t *testing.T) {
|
||||
openAI := &openAIProxyCaptureStub{}
|
||||
h := &Handler{
|
||||
Store: claudeProxyStoreStub{aliases: map[string]string{"claude-sonnet-4-6": "deepseek-v4-flash"}},
|
||||
OpenAI: openAI,
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}],"stream":false}`))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.Messages(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
thinking, _ := openAI.seenReq["thinking"].(map[string]any)
|
||||
if thinking["type"] != "disabled" {
|
||||
t.Fatalf("expected Claude default to disable downstream thinking, got %#v", openAI.seenReq)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeProxyViaOpenAIEnablesThinkingWhenRequested(t *testing.T) {
|
||||
openAI := &openAIProxyCaptureStub{}
|
||||
h := &Handler{
|
||||
Store: claudeProxyStoreStub{aliases: map[string]string{"claude-sonnet-4-6": "deepseek-v4-flash"}},
|
||||
OpenAI: openAI,
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(`{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":1024},"stream":false}`))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.Messages(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
thinking, _ := openAI.seenReq["thinking"].(map[string]any)
|
||||
if thinking["type"] != "enabled" {
|
||||
t.Fatalf("expected Claude explicit thinking to enable downstream thinking, got %#v", openAI.seenReq)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeProxyTranslatesInlineImageToOpenAIDataURL(t *testing.T) {
|
||||
openAI := &openAIProxyCaptureStub{}
|
||||
h := &Handler{OpenAI: openAI}
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":[{"type":"text","text":"hello"},{"type":"image","source":{"type":"base64","media_type":"image/png","data":"QUJDRA=="}}]}],"stream":false}`))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.Messages(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
messages, _ := openAI.seenReq["messages"].([]any)
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected one translated message, got %#v", openAI.seenReq)
|
||||
}
|
||||
msg, _ := messages[0].(map[string]any)
|
||||
content, _ := msg["content"].([]any)
|
||||
if len(content) != 2 {
|
||||
t.Fatalf("expected translated content blocks, got %#v", msg)
|
||||
}
|
||||
imageBlock, _ := content[1].(map[string]any)
|
||||
if strings.TrimSpace(asString(imageBlock["type"])) != "image_url" {
|
||||
t.Fatalf("expected image_url block, got %#v", imageBlock)
|
||||
}
|
||||
imageURL, _ := imageBlock["image_url"].(map[string]any)
|
||||
if !strings.HasPrefix(strings.TrimSpace(asString(imageURL["url"])), "data:image/png;base64,") {
|
||||
t.Fatalf("expected translated data url, got %#v", imageBlock)
|
||||
}
|
||||
}
|
||||
44
internal/httpapi/claude/route_alias_test.go
Normal file
44
internal/httpapi/claude/route_alias_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
)
|
||||
|
||||
type routeAliasAuthStub struct{}
|
||||
|
||||
func (routeAliasAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) {
|
||||
return nil, auth.ErrUnauthorized
|
||||
}
|
||||
|
||||
func (routeAliasAuthStub) Release(_ *auth.RequestAuth) {}
|
||||
|
||||
func TestClaudeRouteAliasesDoNot404(t *testing.T) {
|
||||
h := &Handler{
|
||||
Auth: routeAliasAuthStub{},
|
||||
}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
paths := []string{
|
||||
"/anthropic/v1/messages",
|
||||
"/v1/messages",
|
||||
"/messages",
|
||||
"/anthropic/v1/messages/count_tokens",
|
||||
"/v1/messages/count_tokens",
|
||||
"/messages/count_tokens",
|
||||
}
|
||||
for _, path := range paths {
|
||||
req := httptest.NewRequest(http.MethodPost, path, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code == http.StatusNotFound {
|
||||
t.Fatalf("expected route %s to be registered, got 404", path)
|
||||
}
|
||||
}
|
||||
}
|
||||
117
internal/httpapi/claude/standard_request.go
Normal file
117
internal/httpapi/claude/standard_request.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/prompt"
|
||||
"ds2api/internal/promptcompat"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
type claudeNormalizedRequest struct {
|
||||
Standard promptcompat.StandardRequest
|
||||
NormalizedMessages []any
|
||||
}
|
||||
|
||||
func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNormalizedRequest, error) {
|
||||
model, _ := req["model"].(string)
|
||||
messagesRaw, _ := req["messages"].([]any)
|
||||
if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 {
|
||||
return claudeNormalizedRequest{}, fmt.Errorf("request must include 'model' and 'messages'")
|
||||
}
|
||||
if _, ok := req["max_tokens"]; !ok {
|
||||
req["max_tokens"] = 8192
|
||||
}
|
||||
normalizedMessages := normalizeClaudeMessages(messagesRaw)
|
||||
payload := cloneMap(req)
|
||||
payload["messages"] = normalizedMessages
|
||||
toolsRequested, _ := req["tools"].([]any)
|
||||
payload["messages"] = injectClaudeToolPrompt(payload, normalizedMessages, toolsRequested)
|
||||
|
||||
dsPayload := convertClaudeToDeepSeek(payload, store)
|
||||
dsModel, _ := dsPayload["model"].(string)
|
||||
_, searchEnabled, ok := config.GetModelConfig(dsModel)
|
||||
if !ok {
|
||||
searchEnabled = false
|
||||
}
|
||||
thinkingEnabled := util.ResolveThinkingEnabled(req, false)
|
||||
finalPrompt := prompt.MessagesPrepareWithThinking(toMessageMaps(dsPayload["messages"]), thinkingEnabled)
|
||||
toolNames := extractClaudeToolNames(toolsRequested)
|
||||
if len(toolNames) == 0 && len(toolsRequested) > 0 {
|
||||
toolNames = []string{"__any_tool__"}
|
||||
}
|
||||
|
||||
return claudeNormalizedRequest{
|
||||
Standard: promptcompat.StandardRequest{
|
||||
Surface: "anthropic_messages",
|
||||
RequestedModel: strings.TrimSpace(model),
|
||||
ResolvedModel: dsModel,
|
||||
ResponseModel: strings.TrimSpace(model),
|
||||
Messages: payload["messages"].([]any),
|
||||
FinalPrompt: finalPrompt,
|
||||
ToolNames: toolNames,
|
||||
Stream: util.ToBool(req["stream"]),
|
||||
Thinking: thinkingEnabled,
|
||||
Search: searchEnabled,
|
||||
},
|
||||
NormalizedMessages: normalizedMessages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func injectClaudeToolPrompt(payload map[string]any, normalizedMessages []any, tools []any) []any {
|
||||
if len(tools) == 0 {
|
||||
return normalizedMessages
|
||||
}
|
||||
toolPrompt := strings.TrimSpace(buildClaudeToolPrompt(tools))
|
||||
if toolPrompt == "" {
|
||||
return normalizedMessages
|
||||
}
|
||||
|
||||
// Prefer top-level Anthropic-style system prompt when available.
|
||||
if systemText, ok := payload["system"].(string); ok && strings.TrimSpace(systemText) != "" {
|
||||
payload["system"] = mergeSystemPrompt(systemText, toolPrompt)
|
||||
return normalizedMessages
|
||||
}
|
||||
|
||||
messages := cloneAnySlice(normalizedMessages)
|
||||
for i := range messages {
|
||||
msg, ok := messages[i].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role, _ := msg["role"].(string)
|
||||
if !strings.EqualFold(strings.TrimSpace(role), "system") {
|
||||
continue
|
||||
}
|
||||
copied := cloneMap(msg)
|
||||
copied["content"] = mergeSystemPrompt(strings.TrimSpace(fmt.Sprintf("%v", copied["content"])), toolPrompt)
|
||||
messages[i] = copied
|
||||
return messages
|
||||
}
|
||||
|
||||
return append([]any{map[string]any{"role": "system", "content": toolPrompt}}, messages...)
|
||||
}
|
||||
|
||||
func mergeSystemPrompt(base, extra string) string {
|
||||
base = strings.TrimSpace(base)
|
||||
extra = strings.TrimSpace(extra)
|
||||
switch {
|
||||
case base == "":
|
||||
return extra
|
||||
case extra == "":
|
||||
return base
|
||||
default:
|
||||
return base + "\n\n" + extra
|
||||
}
|
||||
}
|
||||
|
||||
func cloneAnySlice(in []any) []any {
|
||||
if len(in) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]any, len(in))
|
||||
copy(out, in)
|
||||
return out
|
||||
}
|
||||
92
internal/httpapi/claude/standard_request_test.go
Normal file
92
internal/httpapi/claude/standard_request_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func TestNormalizeClaudeRequest(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{}`)
|
||||
store := config.LoadStore()
|
||||
req := map[string]any{
|
||||
"model": "claude-opus-4-6",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
"stream": true,
|
||||
"tools": []any{
|
||||
map[string]any{"name": "search", "description": "Search"},
|
||||
},
|
||||
}
|
||||
norm, err := normalizeClaudeRequest(store, req)
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
if norm.Standard.ResolvedModel == "" {
|
||||
t.Fatalf("expected resolved model")
|
||||
}
|
||||
if !norm.Standard.Stream {
|
||||
t.Fatalf("expected stream=true")
|
||||
}
|
||||
if len(norm.Standard.ToolNames) == 0 {
|
||||
t.Fatalf("expected tool names")
|
||||
}
|
||||
if norm.Standard.FinalPrompt == "" {
|
||||
t.Fatalf("expected non-empty final prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeRequestInjectsToolsIntoExistingSystemMessage(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{}`)
|
||||
store := config.LoadStore()
|
||||
req := map[string]any{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "system", "content": "baseline rule"},
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
"tools": []any{
|
||||
map[string]any{"name": "search", "description": "Search"},
|
||||
},
|
||||
}
|
||||
|
||||
norm, err := normalizeClaudeRequest(store, req)
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
|
||||
if !containsStr(norm.Standard.FinalPrompt, "You have access to these tools") {
|
||||
t.Fatalf("expected tool prompt injected into final prompt, got=%q", norm.Standard.FinalPrompt)
|
||||
}
|
||||
if !containsStr(norm.Standard.FinalPrompt, "baseline rule") {
|
||||
t.Fatalf("expected existing system message preserved, got=%q", norm.Standard.FinalPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeRequestInjectsToolsIntoTopLevelSystem(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{}`)
|
||||
store := config.LoadStore()
|
||||
req := map[string]any{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"system": "top-level system",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
"tools": []any{
|
||||
map[string]any{"name": "search", "description": "Search"},
|
||||
},
|
||||
}
|
||||
|
||||
norm, err := normalizeClaudeRequest(store, req)
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
|
||||
if !containsStr(norm.Standard.FinalPrompt, "top-level system") {
|
||||
t.Fatalf("expected top-level system preserved, got=%q", norm.Standard.FinalPrompt)
|
||||
}
|
||||
if !containsStr(norm.Standard.FinalPrompt, "You have access to these tools") {
|
||||
t.Fatalf("expected tool prompt injected, got=%q", norm.Standard.FinalPrompt)
|
||||
}
|
||||
}
|
||||
165
internal/httpapi/claude/stream_runtime_core.go
Normal file
165
internal/httpapi/claude/stream_runtime_core.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
)
|
||||
|
||||
type claudeStreamRuntime struct {
|
||||
w http.ResponseWriter
|
||||
rc *http.ResponseController
|
||||
canFlush bool
|
||||
|
||||
model string
|
||||
toolNames []string
|
||||
messages []any
|
||||
|
||||
thinkingEnabled bool
|
||||
searchEnabled bool
|
||||
bufferToolContent bool
|
||||
stripReferenceMarkers bool
|
||||
|
||||
messageID string
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
|
||||
nextBlockIndex int
|
||||
thinkingBlockOpen bool
|
||||
thinkingBlockIndex int
|
||||
textBlockOpen bool
|
||||
textBlockIndex int
|
||||
ended bool
|
||||
upstreamErr string
|
||||
}
|
||||
|
||||
func newClaudeStreamRuntime(
|
||||
w http.ResponseWriter,
|
||||
rc *http.ResponseController,
|
||||
canFlush bool,
|
||||
model string,
|
||||
messages []any,
|
||||
thinkingEnabled bool,
|
||||
searchEnabled bool,
|
||||
stripReferenceMarkers bool,
|
||||
toolNames []string,
|
||||
) *claudeStreamRuntime {
|
||||
return &claudeStreamRuntime{
|
||||
w: w,
|
||||
rc: rc,
|
||||
canFlush: canFlush,
|
||||
model: model,
|
||||
messages: messages,
|
||||
thinkingEnabled: thinkingEnabled,
|
||||
searchEnabled: searchEnabled,
|
||||
bufferToolContent: len(toolNames) > 0,
|
||||
stripReferenceMarkers: stripReferenceMarkers,
|
||||
toolNames: toolNames,
|
||||
messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()),
|
||||
thinkingBlockIndex: -1,
|
||||
textBlockIndex: -1,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
}
|
||||
if parsed.ErrorMessage != "" {
|
||||
s.upstreamErr = parsed.ErrorMessage
|
||||
return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("upstream_error")}
|
||||
}
|
||||
if parsed.Stop {
|
||||
return streamengine.ParsedDecision{Stop: true}
|
||||
}
|
||||
|
||||
contentSeen := false
|
||||
for _, p := range parsed.Parts {
|
||||
cleanedText := cleanVisibleOutput(p.Text, s.stripReferenceMarkers)
|
||||
if cleanedText == "" {
|
||||
continue
|
||||
}
|
||||
if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(cleanedText) {
|
||||
continue
|
||||
}
|
||||
contentSeen = true
|
||||
|
||||
if p.Type == "thinking" {
|
||||
if !s.thinkingEnabled {
|
||||
continue
|
||||
}
|
||||
trimmed := sse.TrimContinuationOverlap(s.thinking.String(), cleanedText)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
s.thinking.WriteString(trimmed)
|
||||
s.closeTextBlock()
|
||||
if !s.thinkingBlockOpen {
|
||||
s.thinkingBlockIndex = s.nextBlockIndex
|
||||
s.nextBlockIndex++
|
||||
s.send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": s.thinkingBlockIndex,
|
||||
"content_block": map[string]any{
|
||||
"type": "thinking",
|
||||
"thinking": "",
|
||||
},
|
||||
})
|
||||
s.thinkingBlockOpen = true
|
||||
}
|
||||
s.send("content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": s.thinkingBlockIndex,
|
||||
"delta": map[string]any{
|
||||
"type": "thinking_delta",
|
||||
"thinking": trimmed,
|
||||
},
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
trimmed := sse.TrimContinuationOverlap(s.text.String(), cleanedText)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
s.text.WriteString(trimmed)
|
||||
if s.bufferToolContent {
|
||||
if hasUnclosedCodeFence(s.text.String()) {
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
s.closeThinkingBlock()
|
||||
if !s.textBlockOpen {
|
||||
s.textBlockIndex = s.nextBlockIndex
|
||||
s.nextBlockIndex++
|
||||
s.send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": s.textBlockIndex,
|
||||
"content_block": map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
})
|
||||
s.textBlockOpen = true
|
||||
}
|
||||
s.send("content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": s.textBlockIndex,
|
||||
"delta": map[string]any{
|
||||
"type": "text_delta",
|
||||
"text": trimmed,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||
}
|
||||
|
||||
func hasUnclosedCodeFence(text string) bool {
|
||||
return strings.Count(text, "```")%2 == 1
|
||||
}
|
||||
59
internal/httpapi/claude/stream_runtime_emit.go
Normal file
59
internal/httpapi/claude/stream_runtime_emit.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func (s *claudeStreamRuntime) send(event string, v any) {
|
||||
b, _ := json.Marshal(v)
|
||||
_, _ = s.w.Write([]byte("event: "))
|
||||
_, _ = s.w.Write([]byte(event))
|
||||
_, _ = s.w.Write([]byte("\n"))
|
||||
_, _ = s.w.Write([]byte("data: "))
|
||||
_, _ = s.w.Write(b)
|
||||
_, _ = s.w.Write([]byte("\n\n"))
|
||||
if s.canFlush {
|
||||
_ = s.rc.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) sendError(message string) {
|
||||
msg := strings.TrimSpace(message)
|
||||
if msg == "" {
|
||||
msg = "upstream stream error"
|
||||
}
|
||||
s.send("error", map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"type": "api_error",
|
||||
"message": msg,
|
||||
"code": "internal_error",
|
||||
"param": nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) sendPing() {
|
||||
s.send("ping", map[string]any{"type": "ping"})
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) sendMessageStart() {
|
||||
inputTokens := util.EstimateTokens(fmt.Sprintf("%v", s.messages))
|
||||
s.send("message_start", map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": s.messageID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": s.model,
|
||||
"content": []any{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0},
|
||||
},
|
||||
})
|
||||
}
|
||||
135
internal/httpapi/claude/stream_runtime_finalize.go
Normal file
135
internal/httpapi/claude/stream_runtime_finalize.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"ds2api/internal/toolcall"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func (s *claudeStreamRuntime) closeThinkingBlock() {
|
||||
if !s.thinkingBlockOpen {
|
||||
return
|
||||
}
|
||||
s.send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": s.thinkingBlockIndex,
|
||||
})
|
||||
s.thinkingBlockOpen = false
|
||||
s.thinkingBlockIndex = -1
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) closeTextBlock() {
|
||||
if !s.textBlockOpen {
|
||||
return
|
||||
}
|
||||
s.send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": s.textBlockIndex,
|
||||
})
|
||||
s.textBlockOpen = false
|
||||
s.textBlockIndex = -1
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) finalize(stopReason string) {
|
||||
if s.ended {
|
||||
return
|
||||
}
|
||||
s.ended = true
|
||||
|
||||
s.closeThinkingBlock()
|
||||
s.closeTextBlock()
|
||||
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := cleanVisibleOutput(s.text.String(), s.stripReferenceMarkers)
|
||||
|
||||
if s.bufferToolContent {
|
||||
detected := toolcall.ParseStandaloneToolCalls(finalText, s.toolNames)
|
||||
if len(detected) == 0 && finalText == "" && finalThinking != "" {
|
||||
detected = toolcall.ParseStandaloneToolCalls(finalThinking, s.toolNames)
|
||||
}
|
||||
if len(detected) > 0 {
|
||||
stopReason = "tool_use"
|
||||
for i, tc := range detected {
|
||||
idx := s.nextBlockIndex + i
|
||||
s.send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": idx,
|
||||
"content_block": map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx),
|
||||
"name": tc.Name,
|
||||
"input": map[string]any{},
|
||||
},
|
||||
})
|
||||
|
||||
inputBytes, _ := json.Marshal(tc.Input)
|
||||
s.send("content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": idx,
|
||||
"delta": map[string]any{
|
||||
"type": "input_json_delta",
|
||||
"partial_json": string(inputBytes),
|
||||
},
|
||||
})
|
||||
|
||||
s.send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": idx,
|
||||
})
|
||||
}
|
||||
s.nextBlockIndex += len(detected)
|
||||
} else if finalText != "" {
|
||||
idx := s.nextBlockIndex
|
||||
s.nextBlockIndex++
|
||||
s.send("content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": idx,
|
||||
"content_block": map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
})
|
||||
s.send("content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": idx,
|
||||
"delta": map[string]any{
|
||||
"type": "text_delta",
|
||||
"text": finalText,
|
||||
},
|
||||
})
|
||||
s.send("content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": idx,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
|
||||
s.send("message_delta", map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]any{
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
})
|
||||
s.send("message_stop", map[string]any{"type": "message_stop"})
|
||||
}
|
||||
|
||||
func (s *claudeStreamRuntime) onFinalize(reason streamengine.StopReason, scannerErr error) {
|
||||
if string(reason) == "upstream_error" {
|
||||
s.sendError(s.upstreamErr)
|
||||
return
|
||||
}
|
||||
if scannerErr != nil {
|
||||
s.sendError(scannerErr.Error())
|
||||
return
|
||||
}
|
||||
s.finalize("end_turn")
|
||||
}
|
||||
63
internal/httpapi/claude/stream_status_test.go
Normal file
63
internal/httpapi/claude/stream_status_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
chimw "github.com/go-chi/chi/v5/middleware"
|
||||
)
|
||||
|
||||
type streamStatusClaudeOpenAIStub struct{}
|
||||
|
||||
func (streamStatusClaudeOpenAIStub) ChatCompletions(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl-1\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hello\"},\"finish_reason\":null}]}\n\n"))
|
||||
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||
}
|
||||
|
||||
type streamStatusClaudeStoreStub struct{}
|
||||
|
||||
func (streamStatusClaudeStoreStub) ModelAliases() map[string]string { return nil }
|
||||
|
||||
func (streamStatusClaudeStoreStub) CompatStripReferenceMarkers() bool { return true }
|
||||
|
||||
func captureClaudeStatusMiddleware(statuses *[]int) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ww := chimw.NewWrapResponseWriter(w, r.ProtoMajor)
|
||||
next.ServeHTTP(ww, r)
|
||||
*statuses = append(*statuses, ww.Status())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeMessagesStreamStatusCapturedAs200(t *testing.T) {
|
||||
statuses := make([]int, 0, 1)
|
||||
h := &Handler{
|
||||
Store: streamStatusClaudeStoreStub{},
|
||||
OpenAI: streamStatusClaudeOpenAIStub{},
|
||||
}
|
||||
r := chi.NewRouter()
|
||||
r.Use(captureClaudeStatusMiddleware(&statuses))
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
reqBody := `{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":true}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if len(statuses) != 1 {
|
||||
t.Fatalf("expected one captured status, got %d", len(statuses))
|
||||
}
|
||||
if statuses[0] != http.StatusOK {
|
||||
t.Fatalf("expected captured status 200 (not 000), got %d", statuses[0])
|
||||
}
|
||||
}
|
||||
25
internal/httpapi/claude/tool_call_state.go
Normal file
25
internal/httpapi/claude/tool_call_state.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type claudeToolCallState struct {
|
||||
nameByID map[string]string
|
||||
lastIDByName map[string]string
|
||||
callIDSequence int
|
||||
}
|
||||
|
||||
func (s *claudeToolCallState) nextID() string {
|
||||
s.callIDSequence++
|
||||
return fmt.Sprintf("call_claude_%d", s.callIDSequence)
|
||||
}
|
||||
|
||||
func safeStringValue(v any) string {
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
Reference in New Issue
Block a user