mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-17 14:45:11 +08:00
refactor: centralize assistant turn semantics and stream accumulation into new assistantturn and completionruntime packages
This commit is contained in:
@@ -4,13 +4,19 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/completionruntime"
|
||||
"ds2api/internal/config"
|
||||
claudefmt "ds2api/internal/format/claude"
|
||||
"ds2api/internal/httpapi/requestbody"
|
||||
"ds2api/internal/promptcompat"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/translatorcliproxy"
|
||||
"ds2api/internal/util"
|
||||
@@ -22,14 +28,90 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" {
|
||||
r.Header.Set("anthropic-version", "2023-06-01")
|
||||
}
|
||||
if h.OpenAI == nil {
|
||||
writeClaudeError(w, http.StatusInternalServerError, "OpenAI proxy backend unavailable.")
|
||||
if isClaudeVercelProxyRequest(r) && h.proxyViaOpenAI(w, r, h.Store) {
|
||||
return
|
||||
}
|
||||
if h.proxyViaOpenAI(w, r, h.Store) {
|
||||
if h.Auth == nil || h.DS == nil {
|
||||
if h.OpenAI != nil && h.proxyViaOpenAI(w, r, h.Store) {
|
||||
return
|
||||
}
|
||||
writeClaudeError(w, http.StatusInternalServerError, "Claude runtime backend unavailable.")
|
||||
return
|
||||
}
|
||||
writeClaudeError(w, http.StatusBadGateway, "Failed to proxy Claude request.")
|
||||
if h.handleClaudeDirect(w, r) {
|
||||
return
|
||||
}
|
||||
writeClaudeError(w, http.StatusBadGateway, "Failed to handle Claude request.")
|
||||
}
|
||||
|
||||
func isClaudeVercelProxyRequest(r *http.Request) bool {
|
||||
if r == nil || r.URL == nil {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(r.URL.Query().Get("__stream_prepare")) == "1" ||
|
||||
strings.TrimSpace(r.URL.Query().Get("__stream_release")) == "1"
|
||||
}
|
||||
|
||||
func (h *Handler) handleClaudeDirect(w http.ResponseWriter, r *http.Request) bool {
|
||||
raw, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
if errors.Is(err, requestbody.ErrInvalidUTF8Body) {
|
||||
writeClaudeError(w, http.StatusBadRequest, "invalid json")
|
||||
} else {
|
||||
writeClaudeError(w, http.StatusBadRequest, "invalid body")
|
||||
}
|
||||
return true
|
||||
}
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(raw, &req); err != nil {
|
||||
writeClaudeError(w, http.StatusBadRequest, "invalid json")
|
||||
return true
|
||||
}
|
||||
exposeThinking := false
|
||||
if enabled, ok := util.ResolveThinkingOverride(req); ok && enabled {
|
||||
exposeThinking = true
|
||||
} else if _, ok := util.ResolveThinkingOverride(req); !ok && !util.ToBool(req["stream"]) {
|
||||
req["thinking"] = map[string]any{"type": "enabled"}
|
||||
}
|
||||
norm, err := normalizeClaudeRequest(h.Store, req)
|
||||
if err != nil {
|
||||
writeClaudeError(w, http.StatusBadRequest, err.Error())
|
||||
return true
|
||||
}
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
writeClaudeError(w, http.StatusUnauthorized, err.Error())
|
||||
return true
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
if norm.Standard.Stream {
|
||||
h.handleClaudeDirectStream(w, r, a, norm.Standard)
|
||||
return true
|
||||
}
|
||||
result, outErr := completionruntime.ExecuteNonStreamWithRetry(r.Context(), h.DS, a, norm.Standard, completionruntime.Options{
|
||||
StripReferenceMarkers: h.compatStripReferenceMarkers(),
|
||||
RetryEnabled: true,
|
||||
})
|
||||
if outErr != nil {
|
||||
writeClaudeError(w, outErr.Status, outErr.Message)
|
||||
return true
|
||||
}
|
||||
writeJSON(w, http.StatusOK, claudefmt.BuildMessageResponseFromTurn(
|
||||
fmt.Sprintf("msg_%d", time.Now().UnixNano()),
|
||||
norm.Standard.ResponseModel,
|
||||
result.Turn,
|
||||
exposeThinking,
|
||||
))
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *Handler) handleClaudeDirectStream(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, stdReq promptcompat.StandardRequest) {
|
||||
start, outErr := completionruntime.StartCompletion(r.Context(), h.DS, a, stdReq, completionruntime.Options{})
|
||||
if outErr != nil {
|
||||
writeClaudeError(w, outErr.Status, outErr.Message)
|
||||
return
|
||||
}
|
||||
h.handleClaudeStreamRealtime(w, r, start.Response, stdReq.ResponseModel, stdReq.Messages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw)
|
||||
}
|
||||
|
||||
func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, store ConfigReader) bool {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"ds2api/internal/assistantturn"
|
||||
"ds2api/internal/sse"
|
||||
"ds2api/internal/toolcall"
|
||||
"ds2api/internal/toolstream"
|
||||
@@ -9,7 +10,6 @@ import (
|
||||
"time"
|
||||
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func (s *claudeStreamRuntime) closeThinkingBlock() {
|
||||
@@ -115,18 +115,28 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
|
||||
|
||||
s.closeTextBlock()
|
||||
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := cleanVisibleOutput(s.text.String(), s.stripReferenceMarkers)
|
||||
turn := assistantturn.BuildTurnFromStreamSnapshot(assistantturn.StreamSnapshot{
|
||||
RawText: s.rawText.String(),
|
||||
VisibleText: s.text.String(),
|
||||
RawThinking: s.rawThinking.String(),
|
||||
VisibleThinking: s.thinking.String(),
|
||||
DetectionThinking: s.toolDetectionThinking.String(),
|
||||
AlreadyEmittedCalls: s.toolCallsDetected,
|
||||
AlreadyEmittedToolRaw: s.toolCallsDetected,
|
||||
}, assistantturn.BuildOptions{
|
||||
Model: s.model,
|
||||
Prompt: s.promptTokenText,
|
||||
SearchEnabled: s.searchEnabled,
|
||||
StripReferenceMarkers: s.stripReferenceMarkers,
|
||||
ToolNames: s.toolNames,
|
||||
ToolsRaw: s.toolsRaw,
|
||||
})
|
||||
finalText := turn.Text
|
||||
|
||||
if s.bufferToolContent && !s.toolCallsDetected {
|
||||
detected := toolcall.ParseStandaloneToolCallsDetailed(s.rawText.String(), s.toolNames)
|
||||
if len(detected.Calls) == 0 {
|
||||
detected = toolcall.ParseStandaloneToolCallsDetailed(s.rawThinking.String(), s.toolNames)
|
||||
}
|
||||
if len(detected.Calls) > 0 {
|
||||
normalized := toolcall.NormalizeParsedToolCallsForSchemas(detected.Calls, s.toolsRaw)
|
||||
if len(turn.ToolCalls) > 0 {
|
||||
stopReason = "tool_use"
|
||||
for _, tc := range normalized {
|
||||
for _, tc := range turn.ToolCalls {
|
||||
idx := s.nextBlockIndex
|
||||
s.nextBlockIndex++
|
||||
s.sendToolUseBlock(idx, tc)
|
||||
@@ -161,7 +171,6 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
|
||||
stopReason = "tool_use"
|
||||
}
|
||||
|
||||
outputTokens := util.CountOutputTokens(finalThinking, s.model) + util.CountOutputTokens(finalText, s.model)
|
||||
s.send("message_delta", map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
@@ -169,7 +178,7 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]any{
|
||||
"output_tokens": outputTokens,
|
||||
"output_tokens": turn.Usage.OutputTokens,
|
||||
},
|
||||
})
|
||||
s.send("message_stop", map[string]any{"type": "message_stop"})
|
||||
|
||||
Reference in New Issue
Block a user