mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-17 14:45:11 +08:00
feat: add unified response history session management across Claude, Gemini, and OpenAI API backends
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -5,15 +5,25 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/chathistory"
|
||||
dsclient "ds2api/internal/deepseek/client"
|
||||
)
|
||||
|
||||
type claudeCurrentInputAuth struct{}
|
||||
|
||||
type claudeHistoryConfig struct {
|
||||
aliases map[string]string
|
||||
}
|
||||
|
||||
func (m claudeHistoryConfig) ModelAliases() map[string]string { return m.aliases }
|
||||
func (claudeHistoryConfig) CurrentInputFileEnabled() bool { return false }
|
||||
func (claudeHistoryConfig) CurrentInputFileMinChars() int { return 0 }
|
||||
|
||||
func (claudeCurrentInputAuth) Determine(*http.Request) (*auth.RequestAuth, error) {
|
||||
return &auth.RequestAuth{
|
||||
DeepSeekToken: "direct-token",
|
||||
@@ -22,6 +32,50 @@ func (claudeCurrentInputAuth) Determine(*http.Request) (*auth.RequestAuth, error
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestClaudeDirectRecordsResponseHistory(t *testing.T) {
|
||||
ds := &claudeCurrentInputDS{}
|
||||
historyStore := chathistory.New(filepath.Join(t.TempDir(), "history.json"))
|
||||
h := &Handler{
|
||||
Store: claudeHistoryConfig{aliases: map[string]string{"claude-sonnet-4-6": "deepseek-v4-flash"}},
|
||||
Auth: claudeCurrentInputAuth{},
|
||||
DS: ds,
|
||||
ChatHistory: historyStore,
|
||||
}
|
||||
reqBody := `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hello from claude"}],"max_tokens":1024}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.Messages(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot history: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one history item, got %d", len(snapshot.Items))
|
||||
}
|
||||
item, err := historyStore.Get(snapshot.Items[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get history item: %v", err)
|
||||
}
|
||||
if item.Surface != "claude.messages" {
|
||||
t.Fatalf("unexpected surface: %q", item.Surface)
|
||||
}
|
||||
if item.Model != "claude-sonnet-4-6" {
|
||||
t.Fatalf("unexpected model: %q", item.Model)
|
||||
}
|
||||
if item.UserInput != "hello from claude" {
|
||||
t.Fatalf("unexpected user input: %q", item.UserInput)
|
||||
}
|
||||
if item.Content != "ok" {
|
||||
t.Fatalf("expected raw upstream content, got %q", item.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func (claudeCurrentInputAuth) Release(*auth.RequestAuth) {}
|
||||
|
||||
type claudeCurrentInputDS struct {
|
||||
@@ -53,10 +107,12 @@ func (d *claudeCurrentInputDS) CallCompletion(_ context.Context, _ *auth.Request
|
||||
|
||||
func TestClaudeDirectAppliesCurrentInputFile(t *testing.T) {
|
||||
ds := &claudeCurrentInputDS{}
|
||||
historyStore := chathistory.New(filepath.Join(t.TempDir(), "history.json"))
|
||||
h := &Handler{
|
||||
Store: mockClaudeConfig{aliases: map[string]string{"claude-sonnet-4-6": "deepseek-v4-flash"}},
|
||||
Auth: claudeCurrentInputAuth{},
|
||||
DS: ds,
|
||||
Store: mockClaudeConfig{aliases: map[string]string{"claude-sonnet-4-6": "deepseek-v4-flash"}},
|
||||
Auth: claudeCurrentInputAuth{},
|
||||
DS: ds,
|
||||
ChatHistory: historyStore,
|
||||
}
|
||||
reqBody := `{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"hello from claude"}],"max_tokens":1024}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(reqBody))
|
||||
@@ -82,4 +138,21 @@ func TestClaudeDirectAppliesCurrentInputFile(t *testing.T) {
|
||||
if !strings.Contains(prompt, "Continue from the latest state in the attached DS2API_HISTORY.txt context.") {
|
||||
t.Fatalf("expected continuation prompt, got %q", prompt)
|
||||
}
|
||||
snapshot, err := historyStore.Snapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("snapshot history: %v", err)
|
||||
}
|
||||
if len(snapshot.Items) != 1 {
|
||||
t.Fatalf("expected one history item, got %d", len(snapshot.Items))
|
||||
}
|
||||
full, err := historyStore.Get(snapshot.Items[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get history item: %v", err)
|
||||
}
|
||||
if full.HistoryText != string(ds.uploads[0].Data) {
|
||||
t.Fatalf("expected uploaded current input file to be persisted in history text")
|
||||
}
|
||||
if len(full.Messages) != 1 || !strings.Contains(full.Messages[0].Content, "Continue from the latest state in the attached DS2API_HISTORY.txt context.") {
|
||||
t.Fatalf("expected persisted message to match upstream continuation prompt, got %#v", full.Messages)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -15,8 +16,10 @@ import (
|
||||
"ds2api/internal/completionruntime"
|
||||
"ds2api/internal/config"
|
||||
claudefmt "ds2api/internal/format/claude"
|
||||
"ds2api/internal/httpapi/openai/history"
|
||||
"ds2api/internal/httpapi/requestbody"
|
||||
"ds2api/internal/promptcompat"
|
||||
"ds2api/internal/responsehistory"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/translatorcliproxy"
|
||||
"ds2api/internal/util"
|
||||
@@ -79,38 +82,71 @@ func (h *Handler) handleClaudeDirect(w http.ResponseWriter, r *http.Request) boo
|
||||
return true
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
if norm.Standard.Stream {
|
||||
h.handleClaudeDirectStream(w, r, a, norm.Standard)
|
||||
stdReq, err := h.applyCurrentInputFile(r.Context(), a, norm.Standard)
|
||||
if err != nil {
|
||||
status, message := mapCurrentInputFileError(err)
|
||||
writeClaudeError(w, status, message)
|
||||
return true
|
||||
}
|
||||
result, outErr := completionruntime.ExecuteNonStreamWithRetry(r.Context(), h.DS, a, norm.Standard, completionruntime.Options{
|
||||
historySession := responsehistory.Start(responsehistory.StartParams{
|
||||
Store: h.ChatHistory,
|
||||
Request: r,
|
||||
Auth: a,
|
||||
Surface: "claude.messages",
|
||||
Standard: stdReq,
|
||||
})
|
||||
if stdReq.Stream {
|
||||
h.handleClaudeDirectStream(w, r, a, stdReq, historySession)
|
||||
return true
|
||||
}
|
||||
result, outErr := completionruntime.ExecuteNonStreamWithRetry(r.Context(), h.DS, a, stdReq, completionruntime.Options{
|
||||
StripReferenceMarkers: stripReferenceMarkersEnabled(),
|
||||
RetryEnabled: true,
|
||||
CurrentInputFile: h.Store,
|
||||
})
|
||||
if outErr != nil {
|
||||
if historySession != nil {
|
||||
historySession.ErrorTurn(outErr.Status, outErr.Message, outErr.Code, result.Turn)
|
||||
}
|
||||
writeClaudeError(w, outErr.Status, outErr.Message)
|
||||
return true
|
||||
}
|
||||
if historySession != nil {
|
||||
historySession.SuccessTurn(http.StatusOK, result.Turn, responsehistory.GenericUsage(result.Turn))
|
||||
}
|
||||
writeJSON(w, http.StatusOK, claudefmt.BuildMessageResponseFromTurn(
|
||||
fmt.Sprintf("msg_%d", time.Now().UnixNano()),
|
||||
norm.Standard.ResponseModel,
|
||||
stdReq.ResponseModel,
|
||||
result.Turn,
|
||||
exposeThinking,
|
||||
))
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *Handler) handleClaudeDirectStream(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, stdReq promptcompat.StandardRequest) {
|
||||
func (h *Handler) applyCurrentInputFile(ctx context.Context, a *auth.RequestAuth, stdReq promptcompat.StandardRequest) (promptcompat.StandardRequest, error) {
|
||||
if h == nil {
|
||||
return stdReq, nil
|
||||
}
|
||||
return (history.Service{Store: h.Store, DS: h.DS}).ApplyCurrentInputFile(ctx, a, stdReq)
|
||||
}
|
||||
|
||||
func mapCurrentInputFileError(err error) (int, string) {
|
||||
return history.MapError(err)
|
||||
}
|
||||
|
||||
func (h *Handler) handleClaudeDirectStream(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, stdReq promptcompat.StandardRequest, historySession *responsehistory.Session) {
|
||||
start, outErr := completionruntime.StartCompletion(r.Context(), h.DS, a, stdReq, completionruntime.Options{
|
||||
CurrentInputFile: h.Store,
|
||||
})
|
||||
if outErr != nil {
|
||||
if historySession != nil {
|
||||
historySession.Error(outErr.Status, outErr.Message, outErr.Code, "", "")
|
||||
}
|
||||
writeClaudeError(w, outErr.Status, outErr.Message)
|
||||
return
|
||||
}
|
||||
streamReq := start.Request
|
||||
h.handleClaudeStreamRealtime(w, r, start.Response, streamReq.ResponseModel, streamReq.Messages, streamReq.Thinking, streamReq.Search, streamReq.ToolNames, streamReq.ToolsRaw)
|
||||
h.handleClaudeStreamRealtime(w, r, start.Response, streamReq.ResponseModel, streamReq.Messages, streamReq.Thinking, streamReq.Search, streamReq.ToolNames, streamReq.ToolsRaw, historySession)
|
||||
}
|
||||
|
||||
func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, store ConfigReader) bool {
|
||||
@@ -264,10 +300,17 @@ func stripClaudeThinkingBlocks(raw []byte) []byte {
|
||||
return out
|
||||
}
|
||||
|
||||
func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) {
|
||||
func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySessions ...*responsehistory.Session) {
|
||||
var historySession *responsehistory.Session
|
||||
if len(historySessions) > 0 {
|
||||
historySession = historySessions[0]
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if historySession != nil {
|
||||
historySession.Error(resp.StatusCode, strings.TrimSpace(string(body)), "error", "", "")
|
||||
}
|
||||
writeClaudeError(w, http.StatusInternalServerError, string(body))
|
||||
return
|
||||
}
|
||||
@@ -294,6 +337,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
|
||||
toolNames,
|
||||
toolsRaw,
|
||||
buildClaudePromptTokenText(messages, thinkingEnabled),
|
||||
historySession,
|
||||
)
|
||||
streamRuntime.sendMessageStart()
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/chathistory"
|
||||
"ds2api/internal/config"
|
||||
dsprotocol "ds2api/internal/deepseek/protocol"
|
||||
"ds2api/internal/textclean"
|
||||
@@ -16,10 +17,11 @@ import (
|
||||
var writeJSON = util.WriteJSON
|
||||
|
||||
type Handler struct {
|
||||
Store ConfigReader
|
||||
Auth AuthResolver
|
||||
DS DeepSeekCaller
|
||||
OpenAI OpenAIChatRunner
|
||||
Store ConfigReader
|
||||
Auth AuthResolver
|
||||
DS DeepSeekCaller
|
||||
OpenAI OpenAIChatRunner
|
||||
ChatHistory *chathistory.Store
|
||||
}
|
||||
|
||||
func stripReferenceMarkersEnabled() bool {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/responsehistory"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/toolcall"
|
||||
@@ -46,6 +47,7 @@ type claudeStreamRuntime struct {
|
||||
textEmitted bool
|
||||
ended bool
|
||||
upstreamErr string
|
||||
history *responsehistory.Session
|
||||
}
|
||||
|
||||
func newClaudeStreamRuntime(
|
||||
@@ -60,6 +62,7 @@ func newClaudeStreamRuntime(
|
||||
toolNames []string,
|
||||
toolsRaw any,
|
||||
promptTokenText string,
|
||||
history *responsehistory.Session,
|
||||
) *claudeStreamRuntime {
|
||||
return &claudeStreamRuntime{
|
||||
w: w,
|
||||
@@ -74,6 +77,7 @@ func newClaudeStreamRuntime(
|
||||
toolNames: toolNames,
|
||||
toolsRaw: toolsRaw,
|
||||
promptTokenText: promptTokenText,
|
||||
history: history,
|
||||
messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()),
|
||||
thinkingBlockIndex: -1,
|
||||
textBlockIndex: -1,
|
||||
@@ -232,5 +236,11 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
|
||||
}
|
||||
}
|
||||
|
||||
if s.history != nil {
|
||||
s.history.Progress(
|
||||
responsehistory.ThinkingForArchive(s.rawThinking.String(), s.toolDetectionThinking.String(), s.thinking.String()),
|
||||
responsehistory.TextForArchive(s.rawText.String(), s.text.String()),
|
||||
)
|
||||
}
|
||||
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package claude
|
||||
|
||||
import (
|
||||
"ds2api/internal/assistantturn"
|
||||
"ds2api/internal/responsehistory"
|
||||
"ds2api/internal/sse"
|
||||
"ds2api/internal/toolcall"
|
||||
"ds2api/internal/toolstream"
|
||||
@@ -175,6 +176,15 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
|
||||
if outcome.HasToolCalls {
|
||||
stopReason = "tool_use"
|
||||
}
|
||||
if s.history != nil {
|
||||
s.history.Success(
|
||||
200,
|
||||
responsehistory.ThinkingForArchive(turn.RawThinking, turn.DetectionThinking, turn.Thinking),
|
||||
responsehistory.TextForArchive(turn.RawText, turn.Text),
|
||||
stopReason,
|
||||
responsehistory.GenericUsage(turn),
|
||||
)
|
||||
}
|
||||
|
||||
s.send("message_delta", map[string]any{
|
||||
"type": "message_delta",
|
||||
@@ -191,10 +201,16 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
|
||||
|
||||
func (s *claudeStreamRuntime) onFinalize(reason streamengine.StopReason, scannerErr error) {
|
||||
if string(reason) == "upstream_error" {
|
||||
if s.history != nil {
|
||||
s.history.Error(500, s.upstreamErr, "upstream_error", responsehistory.ThinkingForArchive(s.rawThinking.String(), s.toolDetectionThinking.String(), s.thinking.String()), responsehistory.TextForArchive(s.rawText.String(), s.text.String()))
|
||||
}
|
||||
s.sendError(s.upstreamErr)
|
||||
return
|
||||
}
|
||||
if scannerErr != nil {
|
||||
if s.history != nil {
|
||||
s.history.Error(500, scannerErr.Error(), "error", responsehistory.ThinkingForArchive(s.rawThinking.String(), s.toolDetectionThinking.String(), s.thinking.String()), responsehistory.TextForArchive(s.rawText.String(), s.text.String()))
|
||||
}
|
||||
s.sendError(scannerErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user