package claude import ( "encoding/json" "fmt" "io" "net/http" "strings" "time" "github.com/go-chi/chi/v5" "ds2api/internal/auth" "ds2api/internal/config" "ds2api/internal/deepseek" claudefmt "ds2api/internal/format/claude" "ds2api/internal/sse" streamengine "ds2api/internal/stream" "ds2api/internal/util" ) // writeJSON is a package-internal alias to avoid mass-renaming all call-sites. var writeJSON = util.WriteJSON type Handler struct { Store ConfigReader Auth AuthResolver DS DeepSeekCaller } var ( claudeStreamPingInterval = time.Duration(deepseek.KeepAliveTimeout) * time.Second claudeStreamIdleTimeout = time.Duration(deepseek.StreamIdleTimeout) * time.Second claudeStreamMaxKeepaliveCnt = deepseek.MaxKeepaliveCount ) func RegisterRoutes(r chi.Router, h *Handler) { r.Get("/anthropic/v1/models", h.ListModels) r.Post("/anthropic/v1/messages", h.Messages) r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens) } func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) { writeJSON(w, http.StatusOK, config.ClaudeModelsResponse()) } func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" { r.Header.Set("anthropic-version", "2023-06-01") } a, err := h.Auth.Determine(r) if err != nil { status := http.StatusUnauthorized detail := err.Error() if err == auth.ErrNoAccount { status = http.StatusTooManyRequests } writeClaudeError(w, status, detail) return } defer h.Auth.Release(a) var req map[string]any if err := json.NewDecoder(r.Body).Decode(&req); err != nil { writeClaudeError(w, http.StatusBadRequest, "invalid json") return } norm, err := normalizeClaudeRequest(h.Store, req) if err != nil { writeClaudeError(w, http.StatusBadRequest, err.Error()) return } stdReq := norm.Standard sessionID, err := h.DS.CreateSession(r.Context(), a, 3) if err != nil { writeClaudeError(w, http.StatusUnauthorized, "invalid token.") return } pow, err := h.DS.GetPow(r.Context(), a, 3) if err != nil { writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW") return } requestPayload := stdReq.CompletionPayload(sessionID) resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3) if err != nil { writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.") return } if resp.StatusCode != http.StatusOK { defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) writeClaudeError(w, http.StatusInternalServerError, string(body)) return } if stdReq.Stream { h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) return } result := sse.CollectStream(resp, stdReq.Thinking, true) respBody := claudefmt.BuildMessageResponse( fmt.Sprintf("msg_%d", time.Now().UnixNano()), stdReq.ResponseModel, norm.NormalizedMessages, result.Thinking, result.Text, stdReq.ToolNames, ) writeJSON(w, http.StatusOK, respBody) } func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) { a, err := h.Auth.Determine(r) if err != nil { writeClaudeError(w, http.StatusUnauthorized, err.Error()) return } defer h.Auth.Release(a) var req map[string]any if err := json.NewDecoder(r.Body).Decode(&req); err != nil { writeClaudeError(w, http.StatusBadRequest, "invalid json") return } model, _ := req["model"].(string) messages, _ := req["messages"].([]any) if model == "" || len(messages) == 0 { writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") return } inputTokens := 0 if sys, ok := req["system"].(string); ok { inputTokens += util.EstimateTokens(sys) } for _, item := range messages { msg, ok := item.(map[string]any) if !ok { continue } inputTokens += 2 inputTokens += util.EstimateTokens(extractMessageContent(msg["content"])) } if tools, ok := req["tools"].([]any); ok { for _, t := range tools { b, _ := json.Marshal(t) inputTokens += util.EstimateTokens(string(b)) } } if inputTokens < 1 { inputTokens = 1 } writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens}) } func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) writeClaudeError(w, http.StatusInternalServerError, string(body)) return } w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") rc := http.NewResponseController(w) canFlush := rc.Flush() == nil if !canFlush { config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered") } streamRuntime := newClaudeStreamRuntime( w, rc, canFlush, model, messages, thinkingEnabled, searchEnabled, toolNames, ) streamRuntime.sendMessageStart() initialType := "text" if thinkingEnabled { initialType = "thinking" } streamengine.ConsumeSSE(streamengine.ConsumeConfig{ Context: r.Context(), Body: resp.Body, ThinkingEnabled: thinkingEnabled, InitialType: initialType, KeepAliveInterval: claudeStreamPingInterval, IdleTimeout: claudeStreamIdleTimeout, MaxKeepAliveNoInput: claudeStreamMaxKeepaliveCnt, }, streamengine.ConsumeHooks{ OnKeepAlive: func() { streamRuntime.sendPing() }, OnParsed: streamRuntime.onParsed, OnFinalize: streamRuntime.onFinalize, }) } func writeClaudeError(w http.ResponseWriter, status int, message string) { code := "invalid_request" switch status { case http.StatusUnauthorized: code = "authentication_failed" case http.StatusTooManyRequests: code = "rate_limit_exceeded" case http.StatusNotFound: code = "not_found" case http.StatusInternalServerError: code = "internal_error" } writeJSON(w, status, map[string]any{ "error": map[string]any{ "type": "invalid_request_error", "message": message, "code": code, "param": nil, }, }) } func normalizeClaudeMessages(messages []any) []any { out := make([]any, 0, len(messages)) for _, m := range messages { msg, ok := m.(map[string]any) if !ok { continue } copied := cloneMap(msg) switch content := msg["content"].(type) { case []any: parts := make([]string, 0, len(content)) for _, block := range content { b, ok := block.(map[string]any) if !ok { continue } typeStr, _ := b["type"].(string) if typeStr == "text" { if t, ok := b["text"].(string); ok { parts = append(parts, t) } } if typeStr == "tool_result" { parts = append(parts, formatClaudeToolResultForPrompt(b)) } } copied["content"] = strings.Join(parts, "\n") } out = append(out, copied) } return out } func buildClaudeToolPrompt(tools []any) string { parts := []string{"You are Claude, a helpful AI assistant. You have access to these tools:"} for _, t := range tools { m, ok := t.(map[string]any) if !ok { continue } name, _ := m["name"].(string) desc, _ := m["description"].(string) schema, _ := json.Marshal(m["input_schema"]) parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema)) } parts = append(parts, "When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}", "History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.", "After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.", ) return strings.Join(parts, "\n\n") } func formatClaudeToolResultForPrompt(block map[string]any) string { if block == nil { return "" } toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"])) if toolCallID == "" { toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"])) } if toolCallID == "" { toolCallID = "unknown" } name := strings.TrimSpace(fmt.Sprintf("%v", block["name"])) if name == "" { name = "unknown" } content := strings.TrimSpace(fmt.Sprintf("%v", block["content"])) if content == "" { content = "null" } return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content) } func hasSystemMessage(messages []any) bool { for _, m := range messages { msg, ok := m.(map[string]any) if ok && msg["role"] == "system" { return true } } return false } func extractClaudeToolNames(tools []any) []string { out := make([]string, 0, len(tools)) for _, t := range tools { m, ok := t.(map[string]any) if !ok { continue } if name, ok := m["name"].(string); ok && name != "" { out = append(out, name) } } return out } func toMessageMaps(v any) []map[string]any { arr, ok := v.([]any) if !ok { return nil } out := make([]map[string]any, 0, len(arr)) for _, item := range arr { if m, ok := item.(map[string]any); ok { out = append(out, m) } } return out } func extractMessageContent(v any) string { switch x := v.(type) { case string: return x case []any: parts := make([]string, 0, len(x)) for _, it := range x { parts = append(parts, fmt.Sprintf("%v", it)) } return strings.Join(parts, "\n") default: return fmt.Sprintf("%v", x) } } func cloneMap(in map[string]any) map[string]any { out := make(map[string]any, len(in)) for k, v := range in { out[k] = v } return out }