mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-05 00:45:29 +08:00
604 lines
16 KiB
Go
604 lines
16 KiB
Go
package claude
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
|
|
"ds2api/internal/auth"
|
|
"ds2api/internal/config"
|
|
"ds2api/internal/deepseek"
|
|
"ds2api/internal/sse"
|
|
"ds2api/internal/util"
|
|
)
|
|
|
|
// writeJSON is a package-internal alias for util.WriteJSON, kept so the
// existing call-sites in this package did not have to be renamed. The
// handlers below write both success and error JSON bodies through it.
var writeJSON = util.WriteJSON
|
|
|
|
// Handler serves the Anthropic-compatible endpoints mounted by
// RegisterRoutes and forwards requests to DeepSeek.
type Handler struct {
	// Store is handed to util.ConvertClaudeToDeepSeek when a Claude request
	// payload is converted for DeepSeek.
	Store *config.Store
	// Auth picks the account for a request (Determine) and gives it back
	// when the request is done (Release).
	Auth *auth.Resolver
	// DS is the DeepSeek client used to create chat sessions, obtain the
	// PoW value and call the completion endpoint.
	DS *deepseek.Client
}
|
|
|
|
var (
|
|
claudeStreamPingInterval = time.Duration(deepseek.KeepAliveTimeout) * time.Second
|
|
claudeStreamIdleTimeout = time.Duration(deepseek.StreamIdleTimeout) * time.Second
|
|
claudeStreamMaxKeepaliveCnt = deepseek.MaxKeepaliveCount
|
|
)
|
|
|
|
func RegisterRoutes(r chi.Router, h *Handler) {
|
|
r.Get("/anthropic/v1/models", h.ListModels)
|
|
r.Post("/anthropic/v1/messages", h.Messages)
|
|
r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens)
|
|
}
|
|
|
|
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
|
|
writeJSON(w, http.StatusOK, config.ClaudeModelsResponse())
|
|
}
|
|
|
|
func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
|
a, err := h.Auth.Determine(r)
|
|
if err != nil {
|
|
status := http.StatusUnauthorized
|
|
detail := err.Error()
|
|
if err == auth.ErrNoAccount {
|
|
status = http.StatusTooManyRequests
|
|
}
|
|
writeJSON(w, status, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": detail}})
|
|
return
|
|
}
|
|
defer h.Auth.Release(a)
|
|
|
|
var req map[string]any
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "invalid json"}})
|
|
return
|
|
}
|
|
model, _ := req["model"].(string)
|
|
messagesRaw, _ := req["messages"].([]any)
|
|
if model == "" || len(messagesRaw) == 0 {
|
|
writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "Request must include 'model' and 'messages'."}})
|
|
return
|
|
}
|
|
|
|
normalized := normalizeClaudeMessages(messagesRaw)
|
|
payload := cloneMap(req)
|
|
payload["messages"] = normalized
|
|
toolsRequested, _ := req["tools"].([]any)
|
|
if len(toolsRequested) > 0 && !hasSystemMessage(normalized) {
|
|
payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalized...)
|
|
}
|
|
|
|
dsPayload := util.ConvertClaudeToDeepSeek(payload, h.Store)
|
|
dsModel, _ := dsPayload["model"].(string)
|
|
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel)
|
|
if !ok {
|
|
thinkingEnabled = false
|
|
searchEnabled = false
|
|
}
|
|
finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
|
|
|
|
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
|
if err != nil {
|
|
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "invalid token."}})
|
|
return
|
|
}
|
|
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
|
if err != nil {
|
|
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get PoW"}})
|
|
return
|
|
}
|
|
requestPayload := map[string]any{
|
|
"chat_session_id": sessionID,
|
|
"parent_message_id": nil,
|
|
"prompt": finalPrompt,
|
|
"ref_file_ids": []any{},
|
|
"thinking_enabled": thinkingEnabled,
|
|
"search_enabled": searchEnabled,
|
|
}
|
|
resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3)
|
|
if err != nil {
|
|
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get Claude response."}})
|
|
return
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
defer resp.Body.Close()
|
|
body, _ := io.ReadAll(resp.Body)
|
|
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}})
|
|
return
|
|
}
|
|
|
|
toolNames := extractClaudeToolNames(toolsRequested)
|
|
if util.ToBool(req["stream"]) {
|
|
h.handleClaudeStreamRealtime(w, r, resp, model, normalized, thinkingEnabled, searchEnabled, toolNames)
|
|
return
|
|
}
|
|
result := sse.CollectStream(resp, thinkingEnabled, true)
|
|
fullText := result.Text
|
|
fullThinking := result.Thinking
|
|
detected := util.ParseToolCalls(fullText, toolNames)
|
|
content := make([]map[string]any, 0, 4)
|
|
if fullThinking != "" {
|
|
content = append(content, map[string]any{"type": "thinking", "thinking": fullThinking})
|
|
}
|
|
stopReason := "end_turn"
|
|
if len(detected) > 0 {
|
|
stopReason = "tool_use"
|
|
for i, tc := range detected {
|
|
content = append(content, map[string]any{
|
|
"type": "tool_use",
|
|
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), i),
|
|
"name": tc.Name,
|
|
"input": tc.Input,
|
|
})
|
|
}
|
|
} else {
|
|
if fullText == "" {
|
|
fullText = "抱歉,没有生成有效的响应内容。"
|
|
}
|
|
content = append(content, map[string]any{"type": "text", "text": fullText})
|
|
}
|
|
writeJSON(w, http.StatusOK, map[string]any{
|
|
"id": fmt.Sprintf("msg_%d", time.Now().UnixNano()),
|
|
"type": "message",
|
|
"role": "assistant",
|
|
"model": model,
|
|
"content": content,
|
|
"stop_reason": stopReason,
|
|
"stop_sequence": nil,
|
|
"usage": map[string]any{
|
|
"input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalized)),
|
|
"output_tokens": util.EstimateTokens(fullThinking) + util.EstimateTokens(fullText),
|
|
},
|
|
})
|
|
}
|
|
|
|
func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
|
|
a, err := h.Auth.Determine(r)
|
|
if err != nil {
|
|
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": err.Error()})
|
|
return
|
|
}
|
|
defer h.Auth.Release(a)
|
|
|
|
var req map[string]any
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "invalid json"})
|
|
return
|
|
}
|
|
model, _ := req["model"].(string)
|
|
messages, _ := req["messages"].([]any)
|
|
if model == "" || len(messages) == 0 {
|
|
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "Request must include 'model' and 'messages'."})
|
|
return
|
|
}
|
|
inputTokens := 0
|
|
if sys, ok := req["system"].(string); ok {
|
|
inputTokens += util.EstimateTokens(sys)
|
|
}
|
|
for _, item := range messages {
|
|
msg, ok := item.(map[string]any)
|
|
if !ok {
|
|
continue
|
|
}
|
|
inputTokens += 2
|
|
inputTokens += util.EstimateTokens(extractMessageContent(msg["content"]))
|
|
}
|
|
if tools, ok := req["tools"].([]any); ok {
|
|
for _, t := range tools {
|
|
b, _ := json.Marshal(t)
|
|
inputTokens += util.EstimateTokens(string(b))
|
|
}
|
|
}
|
|
if inputTokens < 1 {
|
|
inputTokens = 1
|
|
}
|
|
writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens})
|
|
}
|
|
|
|
func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}})
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
|
w.Header().Set("Connection", "keep-alive")
|
|
w.Header().Set("X-Accel-Buffering", "no")
|
|
rc := http.NewResponseController(w)
|
|
canFlush := rc.Flush() == nil
|
|
if !canFlush {
|
|
config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered")
|
|
}
|
|
send := func(event string, v any) {
|
|
b, _ := json.Marshal(v)
|
|
_, _ = w.Write([]byte("event: "))
|
|
_, _ = w.Write([]byte(event))
|
|
_, _ = w.Write([]byte("\n"))
|
|
_, _ = w.Write([]byte("data: "))
|
|
_, _ = w.Write(b)
|
|
_, _ = w.Write([]byte("\n\n"))
|
|
if canFlush {
|
|
_ = rc.Flush()
|
|
}
|
|
}
|
|
sendError := func(message string) {
|
|
msg := strings.TrimSpace(message)
|
|
if msg == "" {
|
|
msg = "upstream stream error"
|
|
}
|
|
send("error", map[string]any{
|
|
"type": "error",
|
|
"error": map[string]any{
|
|
"type": "api_error",
|
|
"message": msg,
|
|
},
|
|
})
|
|
}
|
|
|
|
messageID := fmt.Sprintf("msg_%d", time.Now().UnixNano())
|
|
inputTokens := util.EstimateTokens(fmt.Sprintf("%v", messages))
|
|
send("message_start", map[string]any{
|
|
"type": "message_start",
|
|
"message": map[string]any{
|
|
"id": messageID,
|
|
"type": "message",
|
|
"role": "assistant",
|
|
"model": model,
|
|
"content": []any{},
|
|
"stop_reason": nil,
|
|
"stop_sequence": nil,
|
|
"usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0},
|
|
},
|
|
})
|
|
|
|
initialType := "text"
|
|
if thinkingEnabled {
|
|
initialType = "thinking"
|
|
}
|
|
parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType)
|
|
bufferToolContent := len(toolNames) > 0
|
|
hasContent := false
|
|
lastContent := time.Now()
|
|
keepaliveCount := 0
|
|
|
|
thinking := strings.Builder{}
|
|
text := strings.Builder{}
|
|
|
|
nextBlockIndex := 0
|
|
thinkingBlockOpen := false
|
|
thinkingBlockIndex := -1
|
|
textBlockOpen := false
|
|
textBlockIndex := -1
|
|
ended := false
|
|
|
|
closeThinkingBlock := func() {
|
|
if !thinkingBlockOpen {
|
|
return
|
|
}
|
|
send("content_block_stop", map[string]any{
|
|
"type": "content_block_stop",
|
|
"index": thinkingBlockIndex,
|
|
})
|
|
thinkingBlockOpen = false
|
|
thinkingBlockIndex = -1
|
|
}
|
|
closeTextBlock := func() {
|
|
if !textBlockOpen {
|
|
return
|
|
}
|
|
send("content_block_stop", map[string]any{
|
|
"type": "content_block_stop",
|
|
"index": textBlockIndex,
|
|
})
|
|
textBlockOpen = false
|
|
textBlockIndex = -1
|
|
}
|
|
|
|
finalize := func(stopReason string) {
|
|
if ended {
|
|
return
|
|
}
|
|
ended = true
|
|
|
|
closeThinkingBlock()
|
|
closeTextBlock()
|
|
|
|
finalThinking := thinking.String()
|
|
finalText := text.String()
|
|
|
|
if bufferToolContent {
|
|
detected := util.ParseToolCalls(finalText, toolNames)
|
|
if len(detected) > 0 {
|
|
stopReason = "tool_use"
|
|
for i, tc := range detected {
|
|
idx := nextBlockIndex + i
|
|
send("content_block_start", map[string]any{
|
|
"type": "content_block_start",
|
|
"index": idx,
|
|
"content_block": map[string]any{
|
|
"type": "tool_use",
|
|
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx),
|
|
"name": tc.Name,
|
|
"input": tc.Input,
|
|
},
|
|
})
|
|
send("content_block_stop", map[string]any{
|
|
"type": "content_block_stop",
|
|
"index": idx,
|
|
})
|
|
}
|
|
nextBlockIndex += len(detected)
|
|
} else if finalText != "" {
|
|
idx := nextBlockIndex
|
|
nextBlockIndex++
|
|
send("content_block_start", map[string]any{
|
|
"type": "content_block_start",
|
|
"index": idx,
|
|
"content_block": map[string]any{
|
|
"type": "text",
|
|
"text": "",
|
|
},
|
|
})
|
|
send("content_block_delta", map[string]any{
|
|
"type": "content_block_delta",
|
|
"index": idx,
|
|
"delta": map[string]any{
|
|
"type": "text_delta",
|
|
"text": finalText,
|
|
},
|
|
})
|
|
send("content_block_stop", map[string]any{
|
|
"type": "content_block_stop",
|
|
"index": idx,
|
|
})
|
|
}
|
|
}
|
|
|
|
outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
|
|
send("message_delta", map[string]any{
|
|
"type": "message_delta",
|
|
"delta": map[string]any{
|
|
"stop_reason": stopReason,
|
|
"stop_sequence": nil,
|
|
},
|
|
"usage": map[string]any{
|
|
"output_tokens": outputTokens,
|
|
},
|
|
})
|
|
send("message_stop", map[string]any{"type": "message_stop"})
|
|
}
|
|
|
|
pingTicker := time.NewTicker(claudeStreamPingInterval)
|
|
defer pingTicker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-r.Context().Done():
|
|
return
|
|
case <-pingTicker.C:
|
|
if !hasContent {
|
|
keepaliveCount++
|
|
if keepaliveCount >= claudeStreamMaxKeepaliveCnt {
|
|
finalize("end_turn")
|
|
return
|
|
}
|
|
}
|
|
if hasContent && time.Since(lastContent) > claudeStreamIdleTimeout {
|
|
finalize("end_turn")
|
|
return
|
|
}
|
|
send("ping", map[string]any{"type": "ping"})
|
|
case parsed, ok := <-parsedLines:
|
|
if !ok {
|
|
if err := <-done; err != nil {
|
|
sendError(err.Error())
|
|
return
|
|
}
|
|
finalize("end_turn")
|
|
return
|
|
}
|
|
if !parsed.Parsed {
|
|
continue
|
|
}
|
|
if parsed.ErrorMessage != "" {
|
|
sendError(parsed.ErrorMessage)
|
|
return
|
|
}
|
|
if parsed.Stop {
|
|
finalize("end_turn")
|
|
return
|
|
}
|
|
|
|
for _, p := range parsed.Parts {
|
|
if p.Text == "" {
|
|
continue
|
|
}
|
|
if p.Type != "thinking" && searchEnabled && sse.IsCitation(p.Text) {
|
|
continue
|
|
}
|
|
|
|
hasContent = true
|
|
lastContent = time.Now()
|
|
keepaliveCount = 0
|
|
|
|
if p.Type == "thinking" {
|
|
if !thinkingEnabled {
|
|
continue
|
|
}
|
|
thinking.WriteString(p.Text)
|
|
closeTextBlock()
|
|
if !thinkingBlockOpen {
|
|
thinkingBlockIndex = nextBlockIndex
|
|
nextBlockIndex++
|
|
send("content_block_start", map[string]any{
|
|
"type": "content_block_start",
|
|
"index": thinkingBlockIndex,
|
|
"content_block": map[string]any{
|
|
"type": "thinking",
|
|
"thinking": "",
|
|
},
|
|
})
|
|
thinkingBlockOpen = true
|
|
}
|
|
send("content_block_delta", map[string]any{
|
|
"type": "content_block_delta",
|
|
"index": thinkingBlockIndex,
|
|
"delta": map[string]any{
|
|
"type": "thinking_delta",
|
|
"thinking": p.Text,
|
|
},
|
|
})
|
|
continue
|
|
}
|
|
|
|
text.WriteString(p.Text)
|
|
if bufferToolContent {
|
|
continue
|
|
}
|
|
closeThinkingBlock()
|
|
if !textBlockOpen {
|
|
textBlockIndex = nextBlockIndex
|
|
nextBlockIndex++
|
|
send("content_block_start", map[string]any{
|
|
"type": "content_block_start",
|
|
"index": textBlockIndex,
|
|
"content_block": map[string]any{
|
|
"type": "text",
|
|
"text": "",
|
|
},
|
|
})
|
|
textBlockOpen = true
|
|
}
|
|
send("content_block_delta", map[string]any{
|
|
"type": "content_block_delta",
|
|
"index": textBlockIndex,
|
|
"delta": map[string]any{
|
|
"type": "text_delta",
|
|
"text": p.Text,
|
|
},
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func normalizeClaudeMessages(messages []any) []any {
|
|
out := make([]any, 0, len(messages))
|
|
for _, m := range messages {
|
|
msg, ok := m.(map[string]any)
|
|
if !ok {
|
|
continue
|
|
}
|
|
copied := cloneMap(msg)
|
|
switch content := msg["content"].(type) {
|
|
case []any:
|
|
parts := make([]string, 0, len(content))
|
|
for _, block := range content {
|
|
b, ok := block.(map[string]any)
|
|
if !ok {
|
|
continue
|
|
}
|
|
typeStr, _ := b["type"].(string)
|
|
if typeStr == "text" {
|
|
if t, ok := b["text"].(string); ok {
|
|
parts = append(parts, t)
|
|
}
|
|
}
|
|
if typeStr == "tool_result" {
|
|
parts = append(parts, fmt.Sprintf("%v", b["content"]))
|
|
}
|
|
}
|
|
copied["content"] = strings.Join(parts, "\n")
|
|
}
|
|
out = append(out, copied)
|
|
}
|
|
return out
|
|
}
|
|
|
|
// buildClaudeToolPrompt renders Anthropic tool definitions into the text of a
// system prompt. The prompt opens with a fixed preamble, then one section per
// tool (name, description and JSON-encoded input_schema), and closes with the
// JSON tool-call format the model must answer with. Sections are separated by
// blank lines, and entries that are not objects are ignored.
func buildClaudeToolPrompt(tools []any) string {
	const (
		preamble = "You are Claude, a helpful AI assistant. You have access to these tools:"
		format   = "When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}"
	)
	sections := make([]string, 0, len(tools)+2)
	sections = append(sections, preamble)
	for _, raw := range tools {
		tool, isObject := raw.(map[string]any)
		if !isObject {
			continue
		}
		name, _ := tool["name"].(string)
		description, _ := tool["description"].(string)
		schemaJSON, _ := json.Marshal(tool["input_schema"])
		section := fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, description, schemaJSON)
		sections = append(sections, section)
	}
	sections = append(sections, format)
	return strings.Join(sections, "\n\n")
}
|
|
|
|
// hasSystemMessage reports whether any entry of messages is a map whose
// "role" is the string "system".
func hasSystemMessage(messages []any) bool {
	for i := range messages {
		entry, isMap := messages[i].(map[string]any)
		if !isMap {
			continue
		}
		if role, _ := entry["role"].(string); role == "system" {
			return true
		}
	}
	return false
}
|
|
|
|
// extractClaudeToolNames returns the non-empty string "name" of every tool
// definition that is a map, in input order. The result is never nil.
func extractClaudeToolNames(tools []any) []string {
	names := make([]string, 0, len(tools))
	for _, tool := range tools {
		def, isMap := tool.(map[string]any)
		if !isMap {
			continue
		}
		name, _ := def["name"].(string)
		if name == "" {
			continue
		}
		names = append(names, name)
	}
	return names
}
|
|
|
|
// toMessageMaps converts v, expected to be a []any, into the map entries it
// contains, skipping every non-map element. It returns nil when v is not a
// []any; otherwise the result is non-nil, possibly empty.
func toMessageMaps(v any) []map[string]any {
	items, isSlice := v.([]any)
	if !isSlice {
		return nil
	}
	result := make([]map[string]any, 0, len(items))
	for i := range items {
		m, isMap := items[i].(map[string]any)
		if !isMap {
			continue
		}
		result = append(result, m)
	}
	return result
}
|
|
|
|
// extractMessageContent renders a message "content" value as text for token
// estimation. A string is returned unchanged, and a []any has each element
// formatted with %v and joined by "\n". Any other value, including nil, is
// formatted with %v.
func extractMessageContent(v any) string {
	if s, isString := v.(string); isString {
		return s
	}
	items, isSlice := v.([]any)
	if !isSlice {
		return fmt.Sprintf("%v", v)
	}
	var sb strings.Builder
	for i, item := range items {
		if i > 0 {
			sb.WriteByte('\n')
		}
		fmt.Fprintf(&sb, "%v", item)
	}
	return sb.String()
}
|
|
|
|
// cloneMap returns a shallow copy of in. Values are shared with the original,
// so nested maps and slices are not duplicated. A nil input yields an empty,
// non-nil map.
func cloneMap(in map[string]any) map[string]any {
	dup := make(map[string]any, len(in))
	for key := range in {
		dup[key] = in[key]
	}
	return dup
}
|