mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-06 09:25:27 +08:00
365 lines
9.8 KiB
Go
365 lines
9.8 KiB
Go
package claude
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
|
|
"ds2api/internal/auth"
|
|
"ds2api/internal/config"
|
|
"ds2api/internal/deepseek"
|
|
claudefmt "ds2api/internal/format/claude"
|
|
"ds2api/internal/sse"
|
|
streamengine "ds2api/internal/stream"
|
|
"ds2api/internal/util"
|
|
)
|
|
|
|
// writeJSON is a package-internal alias to avoid mass-renaming all call-sites.
|
|
var writeJSON = util.WriteJSON
|
|
|
|
type Handler struct {
|
|
Store ConfigReader
|
|
Auth AuthResolver
|
|
DS DeepSeekCaller
|
|
}
|
|
|
|
var (
|
|
claudeStreamPingInterval = time.Duration(deepseek.KeepAliveTimeout) * time.Second
|
|
claudeStreamIdleTimeout = time.Duration(deepseek.StreamIdleTimeout) * time.Second
|
|
claudeStreamMaxKeepaliveCnt = deepseek.MaxKeepaliveCount
|
|
)
|
|
|
|
func RegisterRoutes(r chi.Router, h *Handler) {
|
|
r.Get("/anthropic/v1/models", h.ListModels)
|
|
r.Post("/anthropic/v1/messages", h.Messages)
|
|
r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens)
|
|
}
|
|
|
|
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
|
|
writeJSON(w, http.StatusOK, config.ClaudeModelsResponse())
|
|
}
|
|
|
|
func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
|
if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" {
|
|
r.Header.Set("anthropic-version", "2023-06-01")
|
|
}
|
|
a, err := h.Auth.Determine(r)
|
|
if err != nil {
|
|
status := http.StatusUnauthorized
|
|
detail := err.Error()
|
|
if err == auth.ErrNoAccount {
|
|
status = http.StatusTooManyRequests
|
|
}
|
|
writeClaudeError(w, status, detail)
|
|
return
|
|
}
|
|
defer h.Auth.Release(a)
|
|
|
|
var req map[string]any
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
writeClaudeError(w, http.StatusBadRequest, "invalid json")
|
|
return
|
|
}
|
|
norm, err := normalizeClaudeRequest(h.Store, req)
|
|
if err != nil {
|
|
writeClaudeError(w, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
stdReq := norm.Standard
|
|
|
|
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
|
if err != nil {
|
|
writeClaudeError(w, http.StatusUnauthorized, "invalid token.")
|
|
return
|
|
}
|
|
pow, err := h.DS.GetPow(r.Context(), a, 3)
|
|
if err != nil {
|
|
writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW")
|
|
return
|
|
}
|
|
requestPayload := stdReq.CompletionPayload(sessionID)
|
|
resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3)
|
|
if err != nil {
|
|
writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.")
|
|
return
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
defer resp.Body.Close()
|
|
body, _ := io.ReadAll(resp.Body)
|
|
writeClaudeError(w, http.StatusInternalServerError, string(body))
|
|
return
|
|
}
|
|
|
|
if stdReq.Stream {
|
|
h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
|
|
return
|
|
}
|
|
result := sse.CollectStream(resp, stdReq.Thinking, true)
|
|
respBody := claudefmt.BuildMessageResponse(
|
|
fmt.Sprintf("msg_%d", time.Now().UnixNano()),
|
|
stdReq.ResponseModel,
|
|
norm.NormalizedMessages,
|
|
result.Thinking,
|
|
result.Text,
|
|
stdReq.ToolNames,
|
|
)
|
|
writeJSON(w, http.StatusOK, respBody)
|
|
}
|
|
|
|
func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
|
|
a, err := h.Auth.Determine(r)
|
|
if err != nil {
|
|
writeClaudeError(w, http.StatusUnauthorized, err.Error())
|
|
return
|
|
}
|
|
defer h.Auth.Release(a)
|
|
|
|
var req map[string]any
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
writeClaudeError(w, http.StatusBadRequest, "invalid json")
|
|
return
|
|
}
|
|
model, _ := req["model"].(string)
|
|
messages, _ := req["messages"].([]any)
|
|
if model == "" || len(messages) == 0 {
|
|
writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
|
|
return
|
|
}
|
|
inputTokens := 0
|
|
if sys, ok := req["system"].(string); ok {
|
|
inputTokens += util.EstimateTokens(sys)
|
|
}
|
|
for _, item := range messages {
|
|
msg, ok := item.(map[string]any)
|
|
if !ok {
|
|
continue
|
|
}
|
|
inputTokens += 2
|
|
inputTokens += util.EstimateTokens(extractMessageContent(msg["content"]))
|
|
}
|
|
if tools, ok := req["tools"].([]any); ok {
|
|
for _, t := range tools {
|
|
b, _ := json.Marshal(t)
|
|
inputTokens += util.EstimateTokens(string(b))
|
|
}
|
|
}
|
|
if inputTokens < 1 {
|
|
inputTokens = 1
|
|
}
|
|
writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens})
|
|
}
|
|
|
|
func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
writeClaudeError(w, http.StatusInternalServerError, string(body))
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
|
w.Header().Set("Connection", "keep-alive")
|
|
w.Header().Set("X-Accel-Buffering", "no")
|
|
rc := http.NewResponseController(w)
|
|
canFlush := rc.Flush() == nil
|
|
if !canFlush {
|
|
config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered")
|
|
}
|
|
|
|
streamRuntime := newClaudeStreamRuntime(
|
|
w,
|
|
rc,
|
|
canFlush,
|
|
model,
|
|
messages,
|
|
thinkingEnabled,
|
|
searchEnabled,
|
|
toolNames,
|
|
)
|
|
streamRuntime.sendMessageStart()
|
|
|
|
initialType := "text"
|
|
if thinkingEnabled {
|
|
initialType = "thinking"
|
|
}
|
|
streamengine.ConsumeSSE(streamengine.ConsumeConfig{
|
|
Context: r.Context(),
|
|
Body: resp.Body,
|
|
ThinkingEnabled: thinkingEnabled,
|
|
InitialType: initialType,
|
|
KeepAliveInterval: claudeStreamPingInterval,
|
|
IdleTimeout: claudeStreamIdleTimeout,
|
|
MaxKeepAliveNoInput: claudeStreamMaxKeepaliveCnt,
|
|
}, streamengine.ConsumeHooks{
|
|
OnKeepAlive: func() {
|
|
streamRuntime.sendPing()
|
|
},
|
|
OnParsed: streamRuntime.onParsed,
|
|
OnFinalize: streamRuntime.onFinalize,
|
|
})
|
|
}
|
|
|
|
func writeClaudeError(w http.ResponseWriter, status int, message string) {
|
|
code := "invalid_request"
|
|
switch status {
|
|
case http.StatusUnauthorized:
|
|
code = "authentication_failed"
|
|
case http.StatusTooManyRequests:
|
|
code = "rate_limit_exceeded"
|
|
case http.StatusNotFound:
|
|
code = "not_found"
|
|
case http.StatusInternalServerError:
|
|
code = "internal_error"
|
|
}
|
|
writeJSON(w, status, map[string]any{
|
|
"error": map[string]any{
|
|
"type": "invalid_request_error",
|
|
"message": message,
|
|
"code": code,
|
|
"param": nil,
|
|
},
|
|
})
|
|
}
|
|
|
|
func normalizeClaudeMessages(messages []any) []any {
|
|
out := make([]any, 0, len(messages))
|
|
for _, m := range messages {
|
|
msg, ok := m.(map[string]any)
|
|
if !ok {
|
|
continue
|
|
}
|
|
copied := cloneMap(msg)
|
|
switch content := msg["content"].(type) {
|
|
case []any:
|
|
parts := make([]string, 0, len(content))
|
|
for _, block := range content {
|
|
b, ok := block.(map[string]any)
|
|
if !ok {
|
|
continue
|
|
}
|
|
typeStr, _ := b["type"].(string)
|
|
if typeStr == "text" {
|
|
if t, ok := b["text"].(string); ok {
|
|
parts = append(parts, t)
|
|
}
|
|
}
|
|
if typeStr == "tool_result" {
|
|
parts = append(parts, formatClaudeToolResultForPrompt(b))
|
|
}
|
|
}
|
|
copied["content"] = strings.Join(parts, "\n")
|
|
}
|
|
out = append(out, copied)
|
|
}
|
|
return out
|
|
}
|
|
|
|
// buildClaudeToolPrompt renders the tool instructions that are placed in the
// prompt.
//
// The result is a fixed introduction, one section per map-typed entry of
// tools (its "name", "description" and JSON-encoded "input_schema"), and
// three fixed closing instructions, all separated by blank lines. Entries
// that are not maps are skipped; missing string fields render as empty and a
// missing schema renders as null.
func buildClaudeToolPrompt(tools []any) string {
	const (
		intro          = "You are Claude, a helpful AI assistant. You have access to these tools:"
		callFormat     = "When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}"
		historyMarkers = "History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input."
		afterResult    = "After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing."
	)

	sections := make([]string, 0, len(tools)+4)
	sections = append(sections, intro)

	for _, raw := range tools {
		tool, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		toolName, _ := tool["name"].(string)
		summary, _ := tool["description"].(string)
		// The encoding error is ignored; on failure the schema renders as empty.
		schemaJSON, _ := json.Marshal(tool["input_schema"])
		sections = append(sections, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", toolName, summary, schemaJSON))
	}

	sections = append(sections, callFormat, historyMarkers, afterResult)
	return strings.Join(sections, "\n\n")
}
|
|
|
|
// formatClaudeToolResultForPrompt renders a tool_result block as a
// [TOOL_RESULT_HISTORY] section for inclusion in the prompt.
//
// The tool call id is read from "tool_use_id", falling back to
// "tool_call_id"; the tool name from "name"; the payload from "content".
// Absent, nil or blank values become the placeholders "unknown" (id and
// name) and "null" (content). A nil block yields the empty string.
//
// Absent keys are treated as empty before formatting. Formatting a nil value
// with %v produces the literal "<nil>", which would otherwise defeat both the
// id fallback and the placeholders.
func formatClaudeToolResultForPrompt(block map[string]any) string {
	if block == nil {
		return ""
	}

	toolCallID := claudeToolResultField(block["tool_use_id"])
	if toolCallID == "" {
		toolCallID = claudeToolResultField(block["tool_call_id"])
	}
	if toolCallID == "" {
		toolCallID = "unknown"
	}

	name := claudeToolResultField(block["name"])
	if name == "" {
		name = "unknown"
	}

	content := claudeToolResultContent(block["content"])
	if content == "" {
		content = "null"
	}

	return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content)
}

// claudeToolResultField renders a scalar tool_result field as trimmed text,
// mapping a nil value to the empty string.
func claudeToolResultField(v any) string {
	switch x := v.(type) {
	case nil:
		return ""
	case string:
		return strings.TrimSpace(x)
	default:
		return strings.TrimSpace(fmt.Sprintf("%v", x))
	}
}

// claudeToolResultContent renders the tool_result payload: strings are
// trimmed, nil becomes the empty string, and any structured value is encoded
// as JSON (without HTML escaping) instead of Go's map/slice print syntax.
func claudeToolResultContent(v any) string {
	switch x := v.(type) {
	case nil:
		return ""
	case string:
		return strings.TrimSpace(x)
	}

	var sb strings.Builder
	enc := json.NewEncoder(&sb)
	enc.SetEscapeHTML(false)
	if err := enc.Encode(v); err != nil {
		// Unencodable value: fall back to the generic text form.
		return strings.TrimSpace(fmt.Sprintf("%v", v))
	}
	// Encode appends a trailing newline; drop it.
	return strings.TrimSpace(sb.String())
}
|
|
|
|
// hasSystemMessage reports whether any map-typed entry of messages has a
// "role" whose value is the string "system". Entries that are not maps are
// ignored.
func hasSystemMessage(messages []any) bool {
	for i := range messages {
		entry, isMap := messages[i].(map[string]any)
		if !isMap {
			continue
		}
		if role, isString := entry["role"].(string); isString && role == "system" {
			return true
		}
	}
	return false
}
|
|
|
|
// extractClaudeToolNames collects, in order, the non-empty string "name" of
// every map-typed entry in tools. Other entries are skipped. The returned
// slice is never nil.
func extractClaudeToolNames(tools []any) []string {
	names := make([]string, 0, len(tools))
	for _, raw := range tools {
		tool, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		name, isString := tool["name"].(string)
		if !isString || name == "" {
			continue
		}
		names = append(names, name)
	}
	return names
}
|
|
|
|
// toMessageMaps narrows v to its map-typed elements. It returns nil when v
// is not a []any; otherwise it returns a non-nil slice holding, in order,
// every element that is a map[string]any.
func toMessageMaps(v any) []map[string]any {
	items, isList := v.([]any)
	if !isList {
		return nil
	}

	result := make([]map[string]any, 0, len(items))
	for i := range items {
		entry, isMap := items[i].(map[string]any)
		if !isMap {
			continue
		}
		result = append(result, entry)
	}
	return result
}
|
|
|
|
// extractMessageContent flattens a message content value to text. A string
// is returned unchanged; a []any is rendered element by element with %v and
// joined with newlines; anything else is rendered with %v.
func extractMessageContent(v any) string {
	if text, isString := v.(string); isString {
		return text
	}

	items, isList := v.([]any)
	if !isList {
		return fmt.Sprintf("%v", v)
	}

	rendered := make([]string, len(items))
	for i, item := range items {
		rendered[i] = fmt.Sprintf("%v", item)
	}
	return strings.Join(rendered, "\n")
}
|
|
|
|
// cloneMap returns a new map with the same keys and values as in. The copy
// is shallow (values are not duplicated) and the result is never nil, even
// when in is nil.
func cloneMap(in map[string]any) map[string]any {
	dup := make(map[string]any, len(in))
	for key := range in {
		dup[key] = in[key]
	}
	return dup
}
|