Files
ds2api/internal/adapter/claude/handler.go

604 lines
16 KiB
Go

package claude
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/go-chi/chi/v5"
"ds2api/internal/auth"
"ds2api/internal/config"
"ds2api/internal/deepseek"
"ds2api/internal/sse"
"ds2api/internal/util"
)
// writeJSON is a package-internal alias to avoid mass-renaming all call-sites.
var writeJSON = util.WriteJSON
type Handler struct {
Store *config.Store
Auth *auth.Resolver
DS *deepseek.Client
}
var (
claudeStreamPingInterval = time.Duration(deepseek.KeepAliveTimeout) * time.Second
claudeStreamIdleTimeout = time.Duration(deepseek.StreamIdleTimeout) * time.Second
claudeStreamMaxKeepaliveCnt = deepseek.MaxKeepaliveCount
)
func RegisterRoutes(r chi.Router, h *Handler) {
r.Get("/anthropic/v1/models", h.ListModels)
r.Post("/anthropic/v1/messages", h.Messages)
r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens)
}
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
writeJSON(w, http.StatusOK, config.ClaudeModelsResponse())
}
func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
a, err := h.Auth.Determine(r)
if err != nil {
status := http.StatusUnauthorized
detail := err.Error()
if err == auth.ErrNoAccount {
status = http.StatusTooManyRequests
}
writeJSON(w, status, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": detail}})
return
}
defer h.Auth.Release(a)
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "invalid json"}})
return
}
model, _ := req["model"].(string)
messagesRaw, _ := req["messages"].([]any)
if model == "" || len(messagesRaw) == 0 {
writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "Request must include 'model' and 'messages'."}})
return
}
normalized := normalizeClaudeMessages(messagesRaw)
payload := cloneMap(req)
payload["messages"] = normalized
toolsRequested, _ := req["tools"].([]any)
if len(toolsRequested) > 0 && !hasSystemMessage(normalized) {
payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalized...)
}
dsPayload := util.ConvertClaudeToDeepSeek(payload, h.Store)
dsModel, _ := dsPayload["model"].(string)
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel)
if !ok {
thinkingEnabled = false
searchEnabled = false
}
finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "invalid token."}})
return
}
pow, err := h.DS.GetPow(r.Context(), a, 3)
if err != nil {
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get PoW"}})
return
}
requestPayload := map[string]any{
"chat_session_id": sessionID,
"parent_message_id": nil,
"prompt": finalPrompt,
"ref_file_ids": []any{},
"thinking_enabled": thinkingEnabled,
"search_enabled": searchEnabled,
}
resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3)
if err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get Claude response."}})
return
}
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}})
return
}
toolNames := extractClaudeToolNames(toolsRequested)
if util.ToBool(req["stream"]) {
h.handleClaudeStreamRealtime(w, r, resp, model, normalized, thinkingEnabled, searchEnabled, toolNames)
return
}
result := sse.CollectStream(resp, thinkingEnabled, true)
fullText := result.Text
fullThinking := result.Thinking
detected := util.ParseToolCalls(fullText, toolNames)
content := make([]map[string]any, 0, 4)
if fullThinking != "" {
content = append(content, map[string]any{"type": "thinking", "thinking": fullThinking})
}
stopReason := "end_turn"
if len(detected) > 0 {
stopReason = "tool_use"
for i, tc := range detected {
content = append(content, map[string]any{
"type": "tool_use",
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i),
"name": tc.Name,
"input": tc.Input,
})
}
} else {
if fullText == "" {
fullText = "抱歉,没有生成有效的响应内容。"
}
content = append(content, map[string]any{"type": "text", "text": fullText})
}
writeJSON(w, http.StatusOK, map[string]any{
"id": fmt.Sprintf("msg_%d", time.Now().UnixNano()),
"type": "message",
"role": "assistant",
"model": model,
"content": content,
"stop_reason": stopReason,
"stop_sequence": nil,
"usage": map[string]any{
"input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalized)),
"output_tokens": util.EstimateTokens(fullThinking) + util.EstimateTokens(fullText),
},
})
}
func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
a, err := h.Auth.Determine(r)
if err != nil {
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": err.Error()})
return
}
defer h.Auth.Release(a)
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "invalid json"})
return
}
model, _ := req["model"].(string)
messages, _ := req["messages"].([]any)
if model == "" || len(messages) == 0 {
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "Request must include 'model' and 'messages'."})
return
}
inputTokens := 0
if sys, ok := req["system"].(string); ok {
inputTokens += util.EstimateTokens(sys)
}
for _, item := range messages {
msg, ok := item.(map[string]any)
if !ok {
continue
}
inputTokens += 2
inputTokens += util.EstimateTokens(extractMessageContent(msg["content"]))
}
if tools, ok := req["tools"].([]any); ok {
for _, t := range tools {
b, _ := json.Marshal(t)
inputTokens += util.EstimateTokens(string(b))
}
}
if inputTokens < 1 {
inputTokens = 1
}
writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens})
}
func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}})
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache, no-transform")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
rc := http.NewResponseController(w)
canFlush := rc.Flush() == nil
if !canFlush {
config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered")
}
send := func(event string, v any) {
b, _ := json.Marshal(v)
_, _ = w.Write([]byte("event: "))
_, _ = w.Write([]byte(event))
_, _ = w.Write([]byte("\n"))
_, _ = w.Write([]byte("data: "))
_, _ = w.Write(b)
_, _ = w.Write([]byte("\n\n"))
if canFlush {
_ = rc.Flush()
}
}
sendError := func(message string) {
msg := strings.TrimSpace(message)
if msg == "" {
msg = "upstream stream error"
}
send("error", map[string]any{
"type": "error",
"error": map[string]any{
"type": "api_error",
"message": msg,
},
})
}
messageID := fmt.Sprintf("msg_%d", time.Now().UnixNano())
inputTokens := util.EstimateTokens(fmt.Sprintf("%v", messages))
send("message_start", map[string]any{
"type": "message_start",
"message": map[string]any{
"id": messageID,
"type": "message",
"role": "assistant",
"model": model,
"content": []any{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0},
},
})
initialType := "text"
if thinkingEnabled {
initialType = "thinking"
}
parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType)
bufferToolContent := len(toolNames) > 0
hasContent := false
lastContent := time.Now()
keepaliveCount := 0
thinking := strings.Builder{}
text := strings.Builder{}
nextBlockIndex := 0
thinkingBlockOpen := false
thinkingBlockIndex := -1
textBlockOpen := false
textBlockIndex := -1
ended := false
closeThinkingBlock := func() {
if !thinkingBlockOpen {
return
}
send("content_block_stop", map[string]any{
"type": "content_block_stop",
"index": thinkingBlockIndex,
})
thinkingBlockOpen = false
thinkingBlockIndex = -1
}
closeTextBlock := func() {
if !textBlockOpen {
return
}
send("content_block_stop", map[string]any{
"type": "content_block_stop",
"index": textBlockIndex,
})
textBlockOpen = false
textBlockIndex = -1
}
finalize := func(stopReason string) {
if ended {
return
}
ended = true
closeThinkingBlock()
closeTextBlock()
finalThinking := thinking.String()
finalText := text.String()
if bufferToolContent {
detected := util.ParseToolCalls(finalText, toolNames)
if len(detected) > 0 {
stopReason = "tool_use"
for i, tc := range detected {
idx := nextBlockIndex + i
send("content_block_start", map[string]any{
"type": "content_block_start",
"index": idx,
"content_block": map[string]any{
"type": "tool_use",
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx),
"name": tc.Name,
"input": tc.Input,
},
})
send("content_block_stop", map[string]any{
"type": "content_block_stop",
"index": idx,
})
}
nextBlockIndex += len(detected)
} else if finalText != "" {
idx := nextBlockIndex
nextBlockIndex++
send("content_block_start", map[string]any{
"type": "content_block_start",
"index": idx,
"content_block": map[string]any{
"type": "text",
"text": "",
},
})
send("content_block_delta", map[string]any{
"type": "content_block_delta",
"index": idx,
"delta": map[string]any{
"type": "text_delta",
"text": finalText,
},
})
send("content_block_stop", map[string]any{
"type": "content_block_stop",
"index": idx,
})
}
}
outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
send("message_delta", map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": stopReason,
"stop_sequence": nil,
},
"usage": map[string]any{
"output_tokens": outputTokens,
},
})
send("message_stop", map[string]any{"type": "message_stop"})
}
pingTicker := time.NewTicker(claudeStreamPingInterval)
defer pingTicker.Stop()
for {
select {
case <-r.Context().Done():
return
case <-pingTicker.C:
if !hasContent {
keepaliveCount++
if keepaliveCount >= claudeStreamMaxKeepaliveCnt {
finalize("end_turn")
return
}
}
if hasContent && time.Since(lastContent) > claudeStreamIdleTimeout {
finalize("end_turn")
return
}
send("ping", map[string]any{"type": "ping"})
case parsed, ok := <-parsedLines:
if !ok {
if err := <-done; err != nil {
sendError(err.Error())
return
}
finalize("end_turn")
return
}
if !parsed.Parsed {
continue
}
if parsed.ErrorMessage != "" {
sendError(parsed.ErrorMessage)
return
}
if parsed.Stop {
finalize("end_turn")
return
}
for _, p := range parsed.Parts {
if p.Text == "" {
continue
}
if p.Type != "thinking" && searchEnabled && sse.IsCitation(p.Text) {
continue
}
hasContent = true
lastContent = time.Now()
keepaliveCount = 0
if p.Type == "thinking" {
if !thinkingEnabled {
continue
}
thinking.WriteString(p.Text)
closeTextBlock()
if !thinkingBlockOpen {
thinkingBlockIndex = nextBlockIndex
nextBlockIndex++
send("content_block_start", map[string]any{
"type": "content_block_start",
"index": thinkingBlockIndex,
"content_block": map[string]any{
"type": "thinking",
"thinking": "",
},
})
thinkingBlockOpen = true
}
send("content_block_delta", map[string]any{
"type": "content_block_delta",
"index": thinkingBlockIndex,
"delta": map[string]any{
"type": "thinking_delta",
"thinking": p.Text,
},
})
continue
}
text.WriteString(p.Text)
if bufferToolContent {
continue
}
closeThinkingBlock()
if !textBlockOpen {
textBlockIndex = nextBlockIndex
nextBlockIndex++
send("content_block_start", map[string]any{
"type": "content_block_start",
"index": textBlockIndex,
"content_block": map[string]any{
"type": "text",
"text": "",
},
})
textBlockOpen = true
}
send("content_block_delta", map[string]any{
"type": "content_block_delta",
"index": textBlockIndex,
"delta": map[string]any{
"type": "text_delta",
"text": p.Text,
},
})
}
}
}
}
func normalizeClaudeMessages(messages []any) []any {
out := make([]any, 0, len(messages))
for _, m := range messages {
msg, ok := m.(map[string]any)
if !ok {
continue
}
copied := cloneMap(msg)
switch content := msg["content"].(type) {
case []any:
parts := make([]string, 0, len(content))
for _, block := range content {
b, ok := block.(map[string]any)
if !ok {
continue
}
typeStr, _ := b["type"].(string)
if typeStr == "text" {
if t, ok := b["text"].(string); ok {
parts = append(parts, t)
}
}
if typeStr == "tool_result" {
parts = append(parts, fmt.Sprintf("%v", b["content"]))
}
}
copied["content"] = strings.Join(parts, "\n")
}
out = append(out, copied)
}
return out
}
func buildClaudeToolPrompt(tools []any) string {
parts := []string{"You are Claude, a helpful AI assistant. You have access to these tools:"}
for _, t := range tools {
m, ok := t.(map[string]any)
if !ok {
continue
}
name, _ := m["name"].(string)
desc, _ := m["description"].(string)
schema, _ := json.Marshal(m["input_schema"])
parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
}
parts = append(parts, "When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}")
return strings.Join(parts, "\n\n")
}
func hasSystemMessage(messages []any) bool {
for _, m := range messages {
msg, ok := m.(map[string]any)
if ok && msg["role"] == "system" {
return true
}
}
return false
}
func extractClaudeToolNames(tools []any) []string {
out := make([]string, 0, len(tools))
for _, t := range tools {
m, ok := t.(map[string]any)
if !ok {
continue
}
if name, ok := m["name"].(string); ok && name != "" {
out = append(out, name)
}
}
return out
}
func toMessageMaps(v any) []map[string]any {
arr, ok := v.([]any)
if !ok {
return nil
}
out := make([]map[string]any, 0, len(arr))
for _, item := range arr {
if m, ok := item.(map[string]any); ok {
out = append(out, m)
}
}
return out
}
func extractMessageContent(v any) string {
switch x := v.(type) {
case string:
return x
case []any:
parts := make([]string, 0, len(x))
for _, it := range x {
parts = append(parts, fmt.Sprintf("%v", it))
}
return strings.Join(parts, "\n")
default:
return fmt.Sprintf("%v", x)
}
}
func cloneMap(in map[string]any) map[string]any {
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}