mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-21 16:37:47 +08:00
feat: centralize utility functions, abstract SSE stream collection, and add concurrency to admin account testing.
This commit is contained in:
@@ -18,6 +18,9 @@ import (
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// writeJSON is a package-internal alias to avoid mass-renaming all call-sites.
|
||||
var writeJSON = util.WriteJSON
|
||||
|
||||
type Handler struct {
|
||||
Store *config.Store
|
||||
Auth *auth.Resolver
|
||||
@@ -113,11 +116,13 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
toolNames := extractClaudeToolNames(toolsRequested)
|
||||
if toBool(req["stream"]) {
|
||||
if util.ToBool(req["stream"]) {
|
||||
h.handleClaudeStreamRealtime(w, r, resp, model, normalized, thinkingEnabled, searchEnabled, toolNames)
|
||||
return
|
||||
}
|
||||
fullText, fullThinking := collectDeepSeek(resp, thinkingEnabled)
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
fullText := result.Text
|
||||
fullThinking := result.Thinking
|
||||
detected := util.ParseToolCalls(fullText, toolNames)
|
||||
content := make([]map[string]any, 0, 4)
|
||||
if fullThinking != "" {
|
||||
@@ -198,41 +203,6 @@ func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens})
|
||||
}
|
||||
|
||||
func collectDeepSeek(resp *http.Response, thinkingEnabled bool) (string, string) {
|
||||
defer resp.Body.Close()
|
||||
text := strings.Builder{}
|
||||
thinking := strings.Builder{}
|
||||
currentType := "text"
|
||||
if thinkingEnabled {
|
||||
currentType = "thinking"
|
||||
}
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
buf := make([]byte, 0, 64*1024)
|
||||
scanner.Buffer(buf, 2*1024*1024)
|
||||
for scanner.Scan() {
|
||||
chunk, done, ok := sse.ParseDeepSeekSSELine(scanner.Bytes())
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
|
||||
currentType = newType
|
||||
if finished {
|
||||
break
|
||||
}
|
||||
for _, p := range parts {
|
||||
if p.Type == "thinking" {
|
||||
thinking.WriteString(p.Text)
|
||||
} else {
|
||||
text.WriteString(p.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
return text.String(), thinking.String()
|
||||
}
|
||||
|
||||
func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
@@ -657,14 +627,3 @@ func cloneMap(in map[string]any) map[string]any {
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func toBool(v any) bool {
|
||||
b, _ := v.(bool)
|
||||
return b
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, payload any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(payload)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"ds2api/internal/sse"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -241,12 +242,12 @@ func TestCollectDeepSeekRegression(t *testing.T) {
|
||||
`data: {"p":"response/content","v":"答"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
text, thinking := collectDeepSeek(resp, true)
|
||||
if thinking != "想" {
|
||||
t.Fatalf("unexpected thinking: %q", thinking)
|
||||
result := sse.CollectStream(resp, true, true)
|
||||
if result.Thinking != "想" {
|
||||
t.Fatalf("unexpected thinking: %q", result.Thinking)
|
||||
}
|
||||
if text != "答" {
|
||||
t.Fatalf("unexpected text: %q", text)
|
||||
if result.Text != "答" {
|
||||
t.Fatalf("unexpected text: %q", result.Text)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,10 @@ import (
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
// writeJSON is a package-internal alias kept to avoid mass-renaming across
|
||||
// every call-site in this file. It delegates to the shared util version.
|
||||
var writeJSON = util.WriteJSON
|
||||
|
||||
type Handler struct {
|
||||
Store *config.Store
|
||||
Auth *auth.Resolver
|
||||
@@ -117,7 +121,7 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
|
||||
return
|
||||
}
|
||||
if toBool(req["stream"]) {
|
||||
if util.ToBool(req["stream"]) {
|
||||
h.handleStream(w, r, resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
|
||||
return
|
||||
}
|
||||
@@ -125,50 +129,17 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
writeOpenAIError(w, resp.StatusCode, string(body))
|
||||
return
|
||||
}
|
||||
thinking := strings.Builder{}
|
||||
text := strings.Builder{}
|
||||
currentType := "text"
|
||||
if thinkingEnabled {
|
||||
currentType = "thinking"
|
||||
}
|
||||
_ = ctx
|
||||
_ = deepseek.ScanSSELines(resp, func(line []byte) bool {
|
||||
chunk, done, ok := sse.ParseDeepSeekSSELine(line)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
if done {
|
||||
return false
|
||||
}
|
||||
if _, hasErr := chunk["error"]; hasErr {
|
||||
return false
|
||||
}
|
||||
parts, finished, newType := sse.ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
|
||||
currentType = newType
|
||||
if finished {
|
||||
return false
|
||||
}
|
||||
for _, p := range parts {
|
||||
if searchEnabled && sse.IsCitation(p.Text) {
|
||||
continue
|
||||
}
|
||||
if p.Type == "thinking" {
|
||||
thinking.WriteString(p.Text)
|
||||
} else {
|
||||
text.WriteString(p.Text)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
|
||||
finalThinking := thinking.String()
|
||||
finalText := text.String()
|
||||
finalThinking := result.Thinking
|
||||
finalText := result.Text
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
finishReason := "stop"
|
||||
messageObj := map[string]any{"role": "assistant", "content": finalText}
|
||||
@@ -507,19 +478,6 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any,
|
||||
return messages, names
|
||||
}
|
||||
|
||||
func toBool(v any) bool {
|
||||
if b, ok := v.(bool); ok {
|
||||
return b
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, payload any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(payload)
|
||||
}
|
||||
|
||||
func writeOpenAIError(w http.ResponseWriter, status int, message string) {
|
||||
writeJSON(w, status, map[string]any{
|
||||
"error": map[string]any{
|
||||
|
||||
@@ -52,7 +52,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if !toBool(req["stream"]) {
|
||||
if !util.ToBool(req["stream"]) {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user