mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-02 07:25:26 +08:00
Merge pull request #193 from CJackHwang/codex/analyze-context-absence-issue
Proxy Claude/Gemini via OpenAI with translation bridge and streaming translator; update deps
This commit is contained in:
16
go.mod
16
go.mod
@@ -1,17 +1,25 @@
|
||||
module ds2api
|
||||
|
||||
go 1.24
|
||||
go 1.26.0
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.0.6
|
||||
github.com/go-chi/chi/v5 v5.2.3
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/refraction-networking/utls v1.8.1
|
||||
github.com/refraction-networking/utls v1.8.2
|
||||
github.com/tetratelabs/wazero v1.9.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/klauspost/compress v1.17.4 // indirect
|
||||
golang.org/x/crypto v0.36.0 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
github.com/router-for-me/CLIProxyAPI/v6 v6.9.8 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
github.com/tidwall/gjson v1.18.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
31
go.sum
31
go.sum
@@ -1,16 +1,47 @@
|
||||
github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI=
|
||||
github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE=
|
||||
github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
|
||||
github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
|
||||
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
|
||||
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/router-for-me/CLIProxyAPI/v6 v6.9.8 h1:O65R38THenp8E1IK0paQlOfop3Y6UYlfqSdLlepidSY=
|
||||
github.com/router-for-me/CLIProxyAPI/v6 v6.9.8/go.mod h1:P1jsIPFXorYGuS2N/3BlZYkpRKi/z7+oR3+1tdG0u4k=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I=
|
||||
github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -24,6 +24,10 @@ type ConfigReader interface {
|
||||
ClaudeMapping() map[string]string
|
||||
}
|
||||
|
||||
type OpenAIChatRunner interface {
|
||||
ChatCompletions(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
var _ AuthResolver = (*auth.Resolver)(nil)
|
||||
var _ DeepSeekCaller = (*deepseek.Client)(nil)
|
||||
var _ ConfigReader = (*config.Store)(nil)
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -13,12 +15,21 @@ import (
|
||||
claudefmt "ds2api/internal/format/claude"
|
||||
"ds2api/internal/sse"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/translatorcliproxy"
|
||||
"ds2api/internal/util"
|
||||
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
)
|
||||
|
||||
func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" {
|
||||
r.Header.Set("anthropic-version", "2023-06-01")
|
||||
}
|
||||
if h.OpenAI != nil {
|
||||
if h.proxyViaOpenAI(w, r) {
|
||||
return
|
||||
}
|
||||
}
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
status := http.StatusUnauthorized
|
||||
@@ -79,9 +90,99 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||
result.Text,
|
||||
stdReq.ToolNames,
|
||||
)
|
||||
if result.OutputTokens > 0 {
|
||||
if usage, ok := respBody["usage"].(map[string]any); ok {
|
||||
usage["output_tokens"] = result.OutputTokens
|
||||
}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, respBody)
|
||||
}
|
||||
|
||||
func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request) bool {
|
||||
raw, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
writeClaudeError(w, http.StatusBadRequest, "invalid body")
|
||||
return true
|
||||
}
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(raw, &req); err != nil {
|
||||
writeClaudeError(w, http.StatusBadRequest, "invalid json")
|
||||
return true
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
stream := util.ToBool(req["stream"])
|
||||
translatedReq := translatorcliproxy.ToOpenAI(sdktranslator.FormatClaude, model, raw, stream)
|
||||
|
||||
isVercelPrepare := strings.TrimSpace(r.URL.Query().Get("__stream_prepare")) == "1"
|
||||
isVercelRelease := strings.TrimSpace(r.URL.Query().Get("__stream_release")) == "1"
|
||||
|
||||
if isVercelRelease {
|
||||
proxyReq := r.Clone(r.Context())
|
||||
proxyReq.URL.Path = "/v1/chat/completions"
|
||||
proxyReq.Body = io.NopCloser(bytes.NewReader(raw))
|
||||
proxyReq.ContentLength = int64(len(raw))
|
||||
rec := httptest.NewRecorder()
|
||||
h.OpenAI.ChatCompletions(rec, proxyReq)
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
for k, vv := range res.Header {
|
||||
for _, v := range vv {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(res.StatusCode)
|
||||
_, _ = w.Write(body)
|
||||
return true
|
||||
}
|
||||
|
||||
proxyReq := r.Clone(r.Context())
|
||||
proxyReq.URL.Path = "/v1/chat/completions"
|
||||
proxyReq.Body = io.NopCloser(bytes.NewReader(translatedReq))
|
||||
proxyReq.ContentLength = int64(len(translatedReq))
|
||||
|
||||
if stream && !isVercelPrepare {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
streamWriter := translatorcliproxy.NewOpenAIStreamTranslatorWriter(w, sdktranslator.FormatClaude, model, raw, translatedReq)
|
||||
h.OpenAI.ChatCompletions(streamWriter, proxyReq)
|
||||
return true
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
h.OpenAI.ChatCompletions(rec, proxyReq)
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||||
for k, vv := range res.Header {
|
||||
for _, v := range vv {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(res.StatusCode)
|
||||
_, _ = w.Write(body)
|
||||
return true
|
||||
}
|
||||
if isVercelPrepare {
|
||||
for k, vv := range res.Header {
|
||||
for _, v := range vv {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(res.StatusCode)
|
||||
_, _ = w.Write(body)
|
||||
return true
|
||||
}
|
||||
converted := translatorcliproxy.FromOpenAINonStream(sdktranslator.FormatClaude, model, raw, translatedReq, body)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(converted)
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
|
||||
@@ -15,9 +15,10 @@ import (
|
||||
var writeJSON = util.WriteJSON
|
||||
|
||||
type Handler struct {
|
||||
Store ConfigReader
|
||||
Auth AuthResolver
|
||||
DS DeepSeekCaller
|
||||
Store ConfigReader
|
||||
Auth AuthResolver
|
||||
DS DeepSeekCaller
|
||||
OpenAI OpenAIChatRunner
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
42
internal/adapter/claude/proxy_vercel_test.go
Normal file
42
internal/adapter/claude/proxy_vercel_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type openAIProxyStub struct {
|
||||
status int
|
||||
body string
|
||||
}
|
||||
|
||||
func (s openAIProxyStub) ChatCompletions(w http.ResponseWriter, _ *http.Request) {
|
||||
if s.status == 0 {
|
||||
s.status = http.StatusOK
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(s.status)
|
||||
_, _ = w.Write([]byte(s.body))
|
||||
}
|
||||
|
||||
func TestClaudeProxyViaOpenAIVercelPreparePassthrough(t *testing.T) {
|
||||
h := &Handler{OpenAI: openAIProxyStub{status: 200, body: `{"lease_id":"lease_123","payload":{"a":1}}`}}
|
||||
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages?__stream_prepare=1", strings.NewReader(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":true}`))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.Messages(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("expected json response, got err=%v body=%s", err, rec.Body.String())
|
||||
}
|
||||
if _, ok := out["lease_id"]; !ok {
|
||||
t.Fatalf("expected lease_id in prepare passthrough, got=%v", out)
|
||||
}
|
||||
}
|
||||
@@ -26,6 +26,7 @@ type claudeStreamRuntime struct {
|
||||
messageID string
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
outputTokens int
|
||||
|
||||
nextBlockIndex int
|
||||
thinkingBlockOpen bool
|
||||
@@ -66,6 +67,9 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
}
|
||||
if parsed.OutputTokens > 0 {
|
||||
s.outputTokens = parsed.OutputTokens
|
||||
}
|
||||
if parsed.ErrorMessage != "" {
|
||||
s.upstreamErr = parsed.ErrorMessage
|
||||
return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("upstream_error")}
|
||||
|
||||
@@ -108,6 +108,9 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
|
||||
}
|
||||
|
||||
outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
|
||||
if s.outputTokens > 0 {
|
||||
outputTokens = s.outputTokens
|
||||
}
|
||||
s.send("message_delta", map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
|
||||
@@ -24,6 +24,10 @@ type ConfigReader interface {
|
||||
ModelAliases() map[string]string
|
||||
}
|
||||
|
||||
type OpenAIChatRunner interface {
|
||||
ChatCompletions(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
var _ AuthResolver = (*auth.Resolver)(nil)
|
||||
var _ DeepSeekCaller = (*deepseek.Client)(nil)
|
||||
var _ ConfigReader = (*config.Store)(nil)
|
||||
|
||||
@@ -1,19 +1,29 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/sse"
|
||||
"ds2api/internal/translatorcliproxy"
|
||||
"ds2api/internal/util"
|
||||
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
)
|
||||
|
||||
func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) {
|
||||
if h.OpenAI != nil {
|
||||
if h.proxyViaOpenAI(w, r, stream) {
|
||||
return
|
||||
}
|
||||
}
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
status := http.StatusUnauthorized
|
||||
@@ -67,6 +77,94 @@ func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request,
|
||||
h.handleNonStreamGenerateContent(w, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
|
||||
}
|
||||
|
||||
func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, stream bool) bool {
|
||||
raw, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
writeGeminiError(w, http.StatusBadRequest, "invalid body")
|
||||
return true
|
||||
}
|
||||
routeModel := strings.TrimSpace(chi.URLParam(r, "model"))
|
||||
translatedReq := translatorcliproxy.ToOpenAI(sdktranslator.FormatGemini, routeModel, raw, stream)
|
||||
if !strings.Contains(string(translatedReq), `"stream"`) {
|
||||
var reqMap map[string]any
|
||||
if json.Unmarshal(translatedReq, &reqMap) == nil {
|
||||
reqMap["stream"] = stream
|
||||
if b, e := json.Marshal(reqMap); e == nil {
|
||||
translatedReq = b
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
isVercelPrepare := strings.TrimSpace(r.URL.Query().Get("__stream_prepare")) == "1"
|
||||
isVercelRelease := strings.TrimSpace(r.URL.Query().Get("__stream_release")) == "1"
|
||||
|
||||
if isVercelRelease {
|
||||
proxyReq := r.Clone(r.Context())
|
||||
proxyReq.URL.Path = "/v1/chat/completions"
|
||||
proxyReq.Body = io.NopCloser(bytes.NewReader(raw))
|
||||
proxyReq.ContentLength = int64(len(raw))
|
||||
rec := httptest.NewRecorder()
|
||||
h.OpenAI.ChatCompletions(rec, proxyReq)
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
for k, vv := range res.Header {
|
||||
for _, v := range vv {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(res.StatusCode)
|
||||
_, _ = w.Write(body)
|
||||
return true
|
||||
}
|
||||
|
||||
proxyReq := r.Clone(r.Context())
|
||||
proxyReq.URL.Path = "/v1/chat/completions"
|
||||
proxyReq.Body = io.NopCloser(bytes.NewReader(translatedReq))
|
||||
proxyReq.ContentLength = int64(len(translatedReq))
|
||||
|
||||
if stream && !isVercelPrepare {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache, no-transform")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
streamWriter := translatorcliproxy.NewOpenAIStreamTranslatorWriter(w, sdktranslator.FormatGemini, routeModel, raw, translatedReq)
|
||||
h.OpenAI.ChatCompletions(streamWriter, proxyReq)
|
||||
return true
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
h.OpenAI.ChatCompletions(rec, proxyReq)
|
||||
res := rec.Result()
|
||||
defer res.Body.Close()
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||||
for k, vv := range res.Header {
|
||||
for _, v := range vv {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(res.StatusCode)
|
||||
_, _ = w.Write(body)
|
||||
return true
|
||||
}
|
||||
if isVercelPrepare {
|
||||
for k, vv := range res.Header {
|
||||
for _, v := range vv {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(res.StatusCode)
|
||||
_, _ = w.Write(body)
|
||||
return true
|
||||
}
|
||||
converted := translatorcliproxy.FromOpenAINonStream(sdktranslator.FormatGemini, routeModel, raw, translatedReq, body)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(converted)
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *http.Response, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
@@ -76,12 +174,12 @@ func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *ht
|
||||
}
|
||||
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames))
|
||||
writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames, result.OutputTokens))
|
||||
}
|
||||
|
||||
func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string, outputTokens int) map[string]any {
|
||||
parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames)
|
||||
usage := buildGeminiUsage(finalPrompt, finalThinking, finalText)
|
||||
usage := buildGeminiUsage(finalPrompt, finalThinking, finalText, outputTokens)
|
||||
return map[string]any{
|
||||
"candidates": []map[string]any{
|
||||
{
|
||||
@@ -98,10 +196,14 @@ func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, final
|
||||
}
|
||||
}
|
||||
|
||||
func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any {
|
||||
func buildGeminiUsage(finalPrompt, finalThinking, finalText string, outputTokens int) map[string]any {
|
||||
promptTokens := util.EstimateTokens(finalPrompt)
|
||||
reasoningTokens := util.EstimateTokens(finalThinking)
|
||||
completionTokens := util.EstimateTokens(finalText)
|
||||
if outputTokens > 0 {
|
||||
completionTokens = outputTokens
|
||||
reasoningTokens = 0
|
||||
}
|
||||
return map[string]any{
|
||||
"promptTokenCount": promptTokens,
|
||||
"candidatesTokenCount": reasoningTokens + completionTokens,
|
||||
|
||||
@@ -11,9 +11,10 @@ import (
|
||||
var writeJSON = util.WriteJSON
|
||||
|
||||
type Handler struct {
|
||||
Store ConfigReader
|
||||
Auth AuthResolver
|
||||
DS DeepSeekCaller
|
||||
Store ConfigReader
|
||||
Auth AuthResolver
|
||||
DS DeepSeekCaller
|
||||
OpenAI OpenAIChatRunner
|
||||
}
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
|
||||
@@ -64,6 +64,7 @@ type geminiStreamRuntime struct {
|
||||
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
outputTokens int
|
||||
}
|
||||
|
||||
func newGeminiStreamRuntime(
|
||||
@@ -103,6 +104,9 @@ func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
}
|
||||
if parsed.OutputTokens > 0 {
|
||||
s.outputTokens = parsed.OutputTokens
|
||||
}
|
||||
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
|
||||
return streamengine.ParsedDecision{Stop: true}
|
||||
}
|
||||
@@ -176,6 +180,6 @@ func (s *geminiStreamRuntime) finalize() {
|
||||
},
|
||||
},
|
||||
"modelVersion": s.model,
|
||||
"usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText),
|
||||
"usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText, s.outputTokens),
|
||||
})
|
||||
}
|
||||
|
||||
42
internal/adapter/gemini/proxy_vercel_test.go
Normal file
42
internal/adapter/gemini/proxy_vercel_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type openAIProxyStub struct {
|
||||
status int
|
||||
body string
|
||||
}
|
||||
|
||||
func (s openAIProxyStub) ChatCompletions(w http.ResponseWriter, _ *http.Request) {
|
||||
if s.status == 0 {
|
||||
s.status = http.StatusOK
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(s.status)
|
||||
_, _ = w.Write([]byte(s.body))
|
||||
}
|
||||
|
||||
func TestGeminiProxyViaOpenAIVercelReleasePassthrough(t *testing.T) {
|
||||
h := &Handler{OpenAI: openAIProxyStub{status: 200, body: `{"success":true}`}}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:streamGenerateContent?__stream_release=1", strings.NewReader(`{"lease_id":"lease_123"}`))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.StreamGenerateContent(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("expected json response, got err=%v body=%s", err, rec.Body.String())
|
||||
}
|
||||
if v, ok := out["success"].(bool); !ok || !v {
|
||||
t.Fatalf("expected success=true passthrough, got=%v", out)
|
||||
}
|
||||
}
|
||||
@@ -36,6 +36,7 @@ type chatStreamRuntime struct {
|
||||
streamToolNames map[int]string
|
||||
thinking strings.Builder
|
||||
text strings.Builder
|
||||
outputTokens int
|
||||
}
|
||||
|
||||
func newChatStreamRuntime(
|
||||
@@ -165,12 +166,19 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
if len(detected.Calls) > 0 || s.toolCallsEmitted {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
usage := openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText)
|
||||
if s.outputTokens > 0 {
|
||||
usage["completion_tokens"] = s.outputTokens
|
||||
if prompt, ok := usage["prompt_tokens"].(int); ok {
|
||||
usage["total_tokens"] = prompt + s.outputTokens
|
||||
}
|
||||
}
|
||||
s.sendChunk(openaifmt.BuildChatStreamChunk(
|
||||
s.completionID,
|
||||
s.created,
|
||||
s.model,
|
||||
[]map[string]any{openaifmt.BuildChatStreamFinishChoice(0, finishReason)},
|
||||
openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText),
|
||||
usage,
|
||||
))
|
||||
s.sendDone()
|
||||
}
|
||||
@@ -179,7 +187,13 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
}
|
||||
if parsed.ContentFilter || parsed.ErrorMessage != "" {
|
||||
if parsed.OutputTokens > 0 {
|
||||
s.outputTokens = parsed.OutputTokens
|
||||
}
|
||||
if parsed.ContentFilter {
|
||||
return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReasonHandlerRequested}
|
||||
}
|
||||
if parsed.ErrorMessage != "" {
|
||||
return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("content_filter")}
|
||||
}
|
||||
if parsed.Stop {
|
||||
|
||||
@@ -107,6 +107,14 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re
|
||||
finalThinking := result.Thinking
|
||||
finalText := sanitizeLeakedOutput(result.Text)
|
||||
respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
|
||||
if result.OutputTokens > 0 {
|
||||
if usage, ok := respBody["usage"].(map[string]any); ok {
|
||||
usage["completion_tokens"] = result.OutputTokens
|
||||
if prompt, ok := usage["prompt_tokens"].(int); ok {
|
||||
usage["total_tokens"] = prompt + result.OutputTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, respBody)
|
||||
}
|
||||
|
||||
|
||||
@@ -124,6 +124,14 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
|
||||
}
|
||||
|
||||
responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, result.Thinking, sanitizedText, toolNames)
|
||||
if result.OutputTokens > 0 {
|
||||
if usage, ok := responseObj["usage"].(map[string]any); ok {
|
||||
usage["output_tokens"] = result.OutputTokens
|
||||
if input, ok := usage["input_tokens"].(int); ok {
|
||||
usage["total_tokens"] = input + result.OutputTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
h.getResponseStore().put(owner, responseID, responseObj)
|
||||
writeJSON(w, http.StatusOK, responseObj)
|
||||
}
|
||||
|
||||
@@ -49,6 +49,7 @@ type responsesStreamRuntime struct {
|
||||
messagePartAdded bool
|
||||
sequence int
|
||||
failed bool
|
||||
outputTokens int
|
||||
|
||||
persistResponse func(obj map[string]any)
|
||||
}
|
||||
@@ -144,6 +145,14 @@ func (s *responsesStreamRuntime) finalize() {
|
||||
s.closeIncompleteFunctionItems()
|
||||
|
||||
obj := s.buildCompletedResponseObject(finalThinking, finalText, detected)
|
||||
if s.outputTokens > 0 {
|
||||
if usage, ok := obj["usage"].(map[string]any); ok {
|
||||
usage["output_tokens"] = s.outputTokens
|
||||
if input, ok := usage["input_tokens"].(int); ok {
|
||||
usage["total_tokens"] = input + s.outputTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
if s.persistResponse != nil {
|
||||
s.persistResponse(obj)
|
||||
}
|
||||
@@ -172,6 +181,9 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
|
||||
if !parsed.Parsed {
|
||||
return streamengine.ParsedDecision{}
|
||||
}
|
||||
if parsed.OutputTokens > 0 {
|
||||
s.outputTokens = parsed.OutputTokens
|
||||
}
|
||||
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
|
||||
return streamengine.ParsedDecision{Stop: true}
|
||||
}
|
||||
|
||||
@@ -183,3 +183,53 @@ func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) {
|
||||
t.Fatalf("expected function_call output item, got %#v", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsStreamContentFilterStopsNormallyWithoutLeak(t *testing.T) {
|
||||
statuses := make([]int, 0, 1)
|
||||
h := &Handler{
|
||||
Store: mockOpenAIConfig{wideInput: true},
|
||||
Auth: streamStatusAuthStub{},
|
||||
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"合法前缀"}`,
|
||||
`data: {"p":"response/status","v":"CONTENT_FILTER","accumulated_token_usage":77}`,
|
||||
`data: {"p":"response/content","v":"CONTENT_FILTER你好,这个问题我暂时无法回答,让我们换个话题再聊聊吧。"}`,
|
||||
)},
|
||||
}
|
||||
r := chi.NewRouter()
|
||||
r.Use(captureStatusMiddleware(&statuses))
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
reqBody := `{"model":"deepseek-chat","messages":[{"role":"user","content":"hi"}],"stream":true}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if len(statuses) != 1 || statuses[0] != http.StatusOK {
|
||||
t.Fatalf("expected captured status 200, got %#v", statuses)
|
||||
}
|
||||
if strings.Contains(rec.Body.String(), "这个问题我暂时无法回答") {
|
||||
t.Fatalf("expected leaked content-filter suffix to be hidden, body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if len(frames) == 0 {
|
||||
t.Fatalf("expected at least one json frame, body=%s", rec.Body.String())
|
||||
}
|
||||
last := frames[len(frames)-1]
|
||||
choices, _ := last["choices"].([]any)
|
||||
if len(choices) != 1 {
|
||||
t.Fatalf("expected one choice in final frame, got %#v", last)
|
||||
}
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop for content-filter upstream stop, got %#v", choice["finish_reason"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,8 +44,8 @@ func NewApp() *App {
|
||||
}
|
||||
|
||||
openaiHandler := &openai.Handler{Store: store, Auth: resolver, DS: dsClient}
|
||||
claudeHandler := &claude.Handler{Store: store, Auth: resolver, DS: dsClient}
|
||||
geminiHandler := &gemini.Handler{Store: store, Auth: resolver, DS: dsClient}
|
||||
claudeHandler := &claude.Handler{Store: store, Auth: resolver, DS: dsClient, OpenAI: openaiHandler}
|
||||
geminiHandler := &gemini.Handler{Store: store, Auth: resolver, DS: dsClient, OpenAI: openaiHandler}
|
||||
adminHandler := &admin.Handler{Store: store, Pool: pool, DS: dsClient}
|
||||
webuiHandler := webui.NewHandler()
|
||||
|
||||
|
||||
@@ -10,8 +10,9 @@ import (
|
||||
// CollectResult holds the aggregated text and thinking content from a
|
||||
// DeepSeek SSE stream, consumed to completion (non-streaming use case).
|
||||
type CollectResult struct {
|
||||
Text string
|
||||
Thinking string
|
||||
Text string
|
||||
Thinking string
|
||||
OutputTokens int
|
||||
}
|
||||
|
||||
// CollectStream fully consumes a DeepSeek SSE response and separates
|
||||
@@ -26,6 +27,7 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
|
||||
}
|
||||
text := strings.Builder{}
|
||||
thinking := strings.Builder{}
|
||||
outputTokens := 0
|
||||
currentType := "text"
|
||||
if thinkingEnabled {
|
||||
currentType = "thinking"
|
||||
@@ -37,8 +39,14 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
|
||||
return true
|
||||
}
|
||||
if result.Stop {
|
||||
if result.OutputTokens > 0 {
|
||||
outputTokens = result.OutputTokens
|
||||
}
|
||||
return false
|
||||
}
|
||||
if result.OutputTokens > 0 {
|
||||
outputTokens = result.OutputTokens
|
||||
}
|
||||
for _, p := range result.Parts {
|
||||
if p.Type == "thinking" {
|
||||
thinking.WriteString(p.Text)
|
||||
@@ -48,5 +56,5 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
|
||||
}
|
||||
return true
|
||||
})
|
||||
return CollectResult{Text: text.String(), Thinking: thinking.String()}
|
||||
return CollectResult{Text: text.String(), Thinking: thinking.String(), OutputTokens: outputTokens}
|
||||
}
|
||||
|
||||
@@ -138,3 +138,15 @@ func TestCollectStreamStatusFinished(t *testing.T) {
|
||||
t.Fatalf("expected 'Hello', got %q", result.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectStreamStopsOnContentFilterStatus(t *testing.T) {
|
||||
resp := makeHTTPResponse(
|
||||
"data: {\"p\":\"response/content\",\"v\":\"safe\"}\n" +
|
||||
"data: {\"p\":\"response/status\",\"v\":\"CONTENT_FILTER\"}\n" +
|
||||
"data: {\"p\":\"response/content\",\"v\":\"blocked\"}\n",
|
||||
)
|
||||
result := CollectStream(resp, false, false)
|
||||
if result.Text != "safe" {
|
||||
t.Fatalf("expected stream to stop before blocked tail, got %q", result.Text)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ type LineResult struct {
|
||||
ErrorMessage string
|
||||
Parts []ContentPart
|
||||
NextType string
|
||||
OutputTokens int
|
||||
}
|
||||
|
||||
// ParseDeepSeekContentLine centralizes one-line DeepSeek SSE parsing for both
|
||||
@@ -35,8 +36,17 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
|
||||
Parsed: true,
|
||||
Stop: true,
|
||||
ContentFilter: true,
|
||||
ErrorMessage: "content filtered by upstream",
|
||||
NextType: currentType,
|
||||
OutputTokens: extractAccumulatedTokenUsage(chunk),
|
||||
}
|
||||
}
|
||||
if hasContentFilterStatus(chunk) {
|
||||
return LineResult{
|
||||
Parsed: true,
|
||||
Stop: true,
|
||||
ContentFilter: true,
|
||||
NextType: currentType,
|
||||
OutputTokens: extractAccumulatedTokenUsage(chunk),
|
||||
}
|
||||
}
|
||||
parts, finished, nextType := ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
|
||||
@@ -46,5 +56,6 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
|
||||
Stop: finished,
|
||||
Parts: parts,
|
||||
NextType: nextType,
|
||||
OutputTokens: extractAccumulatedTokenUsage(chunk),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,8 +40,8 @@ func TestParseDeepSeekContentLineContentFilterMessage(t *testing.T) {
|
||||
if !res.ContentFilter {
|
||||
t.Fatal("expected content filter flag")
|
||||
}
|
||||
if res.ErrorMessage == "" {
|
||||
t.Fatal("expected error message on content filter")
|
||||
if res.ErrorMessage != "" {
|
||||
t.Fatalf("expected empty error message on content filter, got %q", res.ErrorMessage)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,33 @@ func TestParseDeepSeekContentLineContentFilter(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekContentLineContentFilterCodeIncludesOutputTokens(t *testing.T) {
|
||||
res := ParseDeepSeekContentLine(
|
||||
[]byte(`data: {"code":"content_filter","accumulated_token_usage":99}`),
|
||||
false, "text",
|
||||
)
|
||||
if !res.Parsed || !res.Stop || !res.ContentFilter {
|
||||
t.Fatalf("expected content-filter stop result: %#v", res)
|
||||
}
|
||||
if res.OutputTokens != 99 {
|
||||
t.Fatalf("expected output token usage 99, got %d", res.OutputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekContentLineContentFilterStatus(t *testing.T) {
|
||||
res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/status","v":"CONTENT_FILTER"}`), false, "text")
|
||||
if !res.Parsed || !res.Stop || !res.ContentFilter {
|
||||
t.Fatalf("expected status-based content-filter stop result: %#v", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekContentLineCapturesAccumulatedTokenUsage(t *testing.T) {
|
||||
res := ParseDeepSeekContentLine([]byte(`data: {"p":"response","o":"BATCH","v":[{"p":"accumulated_token_usage","v":1383},{"p":"quasi_status","v":"FINISHED"}]}`), false, "text")
|
||||
if res.OutputTokens != 1383 {
|
||||
t.Fatalf("expected output token usage 1383, got %d", res.OutputTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekContentLineContent(t *testing.T) {
|
||||
res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/content","v":"hi"}`), false, "text")
|
||||
if !res.Parsed || res.Stop {
|
||||
@@ -65,3 +92,13 @@ func TestParseDeepSeekContentLineTrimsFromContentFilterKeyword(t *testing.T) {
|
||||
t.Fatalf("unexpected parts after filter: %#v", res.Parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeepSeekContentLineContentTextEqualContentFilterDoesNotStop(t *testing.T) {
|
||||
res := ParseDeepSeekContentLine([]byte(`data: {"p":"response/content","v":"content_filter"}`), false, "text")
|
||||
if !res.Parsed {
|
||||
t.Fatalf("expected parsed result: %#v", res)
|
||||
}
|
||||
if res.Stop || res.ContentFilter {
|
||||
t.Fatalf("did not expect content-filter stop for content text: %#v", res)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package sse
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/deepseek"
|
||||
@@ -287,3 +288,90 @@ func extractContentRecursive(items []any, defaultType string) ([]ContentPart, bo
|
||||
func IsCitation(text string) bool {
|
||||
return bytes.HasPrefix([]byte(strings.TrimSpace(text)), []byte("[citation:"))
|
||||
}
|
||||
|
||||
func hasContentFilterStatus(chunk map[string]any) bool {
|
||||
if code, _ := chunk["code"].(string); strings.EqualFold(strings.TrimSpace(code), "content_filter") {
|
||||
return true
|
||||
}
|
||||
return hasContentFilterStatusValue(chunk)
|
||||
}
|
||||
|
||||
func hasContentFilterStatusValue(v any) bool {
|
||||
switch x := v.(type) {
|
||||
case []any:
|
||||
for _, item := range x {
|
||||
if hasContentFilterStatusValue(item) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
case map[string]any:
|
||||
if p, _ := x["p"].(string); strings.Contains(strings.ToLower(p), "status") {
|
||||
if s, _ := x["v"].(string); strings.EqualFold(strings.TrimSpace(s), "content_filter") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if code, _ := x["code"].(string); strings.EqualFold(strings.TrimSpace(code), "content_filter") {
|
||||
return true
|
||||
}
|
||||
for _, vv := range x {
|
||||
if hasContentFilterStatusValue(vv) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func extractAccumulatedTokenUsage(chunk map[string]any) int {
|
||||
return findAccumulatedTokenUsage(chunk)
|
||||
}
|
||||
|
||||
func findAccumulatedTokenUsage(v any) int {
|
||||
switch x := v.(type) {
|
||||
case map[string]any:
|
||||
if p, _ := x["p"].(string); strings.Contains(strings.ToLower(p), "accumulated_token_usage") {
|
||||
if n, ok := toInt(x["v"]); ok && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
if n, ok := toInt(x["accumulated_token_usage"]); ok && n > 0 {
|
||||
return n
|
||||
}
|
||||
for _, vv := range x {
|
||||
if n := findAccumulatedTokenUsage(vv); n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
case []any:
|
||||
for _, item := range x {
|
||||
if n := findAccumulatedTokenUsage(item); n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func toInt(v any) (int, bool) {
|
||||
switch x := v.(type) {
|
||||
case int:
|
||||
return x, true
|
||||
case int32:
|
||||
return int(x), true
|
||||
case int64:
|
||||
return int(x), true
|
||||
case float64:
|
||||
if math.IsNaN(x) || math.IsInf(x, 0) {
|
||||
return 0, false
|
||||
}
|
||||
return int(x), true
|
||||
case json.Number:
|
||||
i, err := x.Int64()
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return int(i), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
67
internal/translatorcliproxy/bridge.go
Normal file
67
internal/translatorcliproxy/bridge.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package translatorcliproxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator/builtin"
|
||||
)
|
||||
|
||||
func ToOpenAI(from sdktranslator.Format, model string, raw []byte, stream bool) []byte {
|
||||
return sdktranslator.TranslateRequest(from, sdktranslator.FormatOpenAI, model, raw, stream)
|
||||
}
|
||||
|
||||
func FromOpenAINonStream(to sdktranslator.Format, model string, originalReq, translatedReq, raw []byte) []byte {
|
||||
var param any
|
||||
return sdktranslator.TranslateNonStream(context.Background(), sdktranslator.FormatOpenAI, to, model, originalReq, translatedReq, raw, ¶m)
|
||||
}
|
||||
|
||||
func FromOpenAIStream(to sdktranslator.Format, model string, originalReq, translatedReq, streamBody []byte) []byte {
|
||||
var out bytes.Buffer
|
||||
var param any
|
||||
for _, line := range bytes.Split(streamBody, []byte("\n")) {
|
||||
trimmed := strings.TrimSpace(string(line))
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
payload := append([]byte(nil), line...)
|
||||
if !bytes.HasPrefix(payload, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(context.Background(), sdktranslator.FormatOpenAI, to, model, originalReq, translatedReq, payload, ¶m)
|
||||
for i := range chunks {
|
||||
out.Write(chunks[i])
|
||||
if !bytes.HasSuffix(chunks[i], []byte("\n")) {
|
||||
out.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
}
|
||||
return out.Bytes()
|
||||
}
|
||||
|
||||
func ParseFormat(name string) sdktranslator.Format {
|
||||
switch strings.ToLower(strings.TrimSpace(name)) {
|
||||
case "openai", "openai-chat", "chat", "chat-completions":
|
||||
return sdktranslator.FormatOpenAI
|
||||
case "openai-response", "responses", "openai-responses":
|
||||
return sdktranslator.FormatOpenAIResponse
|
||||
case "claude", "anthropic":
|
||||
return sdktranslator.FormatClaude
|
||||
case "gemini", "google":
|
||||
return sdktranslator.FormatGemini
|
||||
case "gemini-cli", "geminicli":
|
||||
return sdktranslator.FormatGeminiCLI
|
||||
case "codex", "openai-codex":
|
||||
return sdktranslator.FormatCodex
|
||||
case "antigravity":
|
||||
return sdktranslator.FormatAntigravity
|
||||
default:
|
||||
return sdktranslator.FromString(name)
|
||||
}
|
||||
}
|
||||
|
||||
func ToOpenAIByName(formatName, model string, raw []byte, stream bool) []byte {
|
||||
return ToOpenAI(ParseFormat(formatName), model, raw, stream)
|
||||
}
|
||||
72
internal/translatorcliproxy/bridge_test.go
Normal file
72
internal/translatorcliproxy/bridge_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package translatorcliproxy
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
)
|
||||
|
||||
func TestToOpenAIClaude(t *testing.T) {
|
||||
raw := []byte(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":false}`)
|
||||
got := ToOpenAI(sdktranslator.FormatClaude, "claude-sonnet-4-5", raw, false)
|
||||
s := string(got)
|
||||
if !strings.Contains(s, `"messages"`) || !strings.Contains(s, `"model"`) {
|
||||
t.Fatalf("unexpected translated request: %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromOpenAINonStreamClaude(t *testing.T) {
|
||||
original := []byte(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":false}`)
|
||||
translatedReq := []byte(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":false}`)
|
||||
openaibody := []byte(`{"id":"chatcmpl_1","object":"chat.completion","created":1,"model":"claude-sonnet-4-5","choices":[{"index":0,"message":{"role":"assistant","content":"hello"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`)
|
||||
got := FromOpenAINonStream(sdktranslator.FormatClaude, "claude-sonnet-4-5", original, translatedReq, openaibody)
|
||||
if !strings.Contains(string(got), `"type":"message"`) {
|
||||
t.Fatalf("expected claude response format, got: %s", string(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFormatAliases(t *testing.T) {
|
||||
cases := map[string]sdktranslator.Format{
|
||||
"responses": sdktranslator.FormatOpenAIResponse,
|
||||
"anthropic": sdktranslator.FormatClaude,
|
||||
"geminicli": sdktranslator.FormatGeminiCLI,
|
||||
"openai-codex": sdktranslator.FormatCodex,
|
||||
"antigravity": sdktranslator.FormatAntigravity,
|
||||
"chat-completions": sdktranslator.FormatOpenAI,
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := ParseFormat(in); got != want {
|
||||
t.Fatalf("ParseFormat(%q)=%q want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestToOpenAIByNameAllSupportedFormats(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
format string
|
||||
model string
|
||||
body string
|
||||
}{
|
||||
{name: "openai", format: "openai", model: "gpt-4.1", body: `{"model":"gpt-4.1","messages":[{"role":"user","content":"hi"}],"stream":false}`},
|
||||
{name: "responses", format: "responses", model: "gpt-4.1", body: `{"model":"gpt-4.1","input":"hello","stream":false}`},
|
||||
{name: "claude", format: "claude", model: "claude-sonnet-4-5", body: `{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hello"}],"stream":false}`},
|
||||
{name: "gemini", format: "gemini", model: "gemini-2.5-pro", body: `{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}`},
|
||||
{name: "gemini-cli", format: "gemini-cli", model: "gemini-2.5-pro", body: `{"model":"gemini-2.5-pro","messages":[{"role":"user","content":"hello"}],"stream":false}`},
|
||||
{name: "codex", format: "codex", model: "gpt-5-codex", body: `{"model":"gpt-5-codex","messages":[{"role":"user","content":"hello"}],"stream":false}`},
|
||||
{name: "antigravity", format: "antigravity", model: "gpt-4.1", body: `{"model":"gpt-4.1","messages":[{"role":"user","content":"hello"}],"stream":false}`},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := ToOpenAIByName(tc.format, tc.model, []byte(tc.body), false)
|
||||
if len(got) == 0 {
|
||||
t.Fatalf("expected non-empty conversion result for format=%s", tc.format)
|
||||
}
|
||||
if !strings.Contains(string(got), `"model"`) {
|
||||
t.Fatalf("expected model field in converted payload, got=%s", string(got))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
120
internal/translatorcliproxy/stream_writer.go
Normal file
120
internal/translatorcliproxy/stream_writer.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package translatorcliproxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
)
|
||||
|
||||
// OpenAIStreamTranslatorWriter translates OpenAI SSE output to another client format in real-time.
|
||||
type OpenAIStreamTranslatorWriter struct {
|
||||
dst http.ResponseWriter
|
||||
target sdktranslator.Format
|
||||
model string
|
||||
originalReq []byte
|
||||
translatedReq []byte
|
||||
param any
|
||||
statusCode int
|
||||
headersSent bool
|
||||
lineBuf bytes.Buffer
|
||||
}
|
||||
|
||||
func NewOpenAIStreamTranslatorWriter(dst http.ResponseWriter, target sdktranslator.Format, model string, originalReq, translatedReq []byte) *OpenAIStreamTranslatorWriter {
|
||||
return &OpenAIStreamTranslatorWriter{
|
||||
dst: dst,
|
||||
target: target,
|
||||
model: model,
|
||||
originalReq: originalReq,
|
||||
translatedReq: translatedReq,
|
||||
statusCode: http.StatusOK,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *OpenAIStreamTranslatorWriter) Header() http.Header {
|
||||
return w.dst.Header()
|
||||
}
|
||||
|
||||
func (w *OpenAIStreamTranslatorWriter) WriteHeader(statusCode int) {
|
||||
if w.headersSent {
|
||||
return
|
||||
}
|
||||
w.statusCode = statusCode
|
||||
w.headersSent = true
|
||||
w.dst.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *OpenAIStreamTranslatorWriter) Write(p []byte) (int, error) {
|
||||
if !w.headersSent {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
if w.statusCode < 200 || w.statusCode >= 300 {
|
||||
return w.dst.Write(p)
|
||||
}
|
||||
w.lineBuf.Write(p)
|
||||
for {
|
||||
line, ok := w.readOneLine()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
trimmed := bytes.TrimSpace(line)
|
||||
if len(trimmed) == 0 {
|
||||
continue
|
||||
}
|
||||
if bytes.HasPrefix(trimmed, []byte(":")) {
|
||||
if _, err := w.dst.Write(trimmed); err != nil {
|
||||
return len(p), err
|
||||
}
|
||||
if _, err := w.dst.Write([]byte("\n\n")); err != nil {
|
||||
return len(p), err
|
||||
}
|
||||
if f, ok := w.dst.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !bytes.HasPrefix(trimmed, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(context.Background(), sdktranslator.FormatOpenAI, w.target, w.model, w.originalReq, w.translatedReq, trimmed, &w.param)
|
||||
for i := range chunks {
|
||||
if len(chunks[i]) == 0 {
|
||||
continue
|
||||
}
|
||||
if _, err := w.dst.Write(chunks[i]); err != nil {
|
||||
return len(p), err
|
||||
}
|
||||
if !bytes.HasSuffix(chunks[i], []byte("\n")) {
|
||||
if _, err := w.dst.Write([]byte("\n")); err != nil {
|
||||
return len(p), err
|
||||
}
|
||||
}
|
||||
}
|
||||
if f, ok := w.dst.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *OpenAIStreamTranslatorWriter) Flush() {
|
||||
if f, ok := w.dst.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *OpenAIStreamTranslatorWriter) Unwrap() http.ResponseWriter {
|
||||
return w.dst
|
||||
}
|
||||
|
||||
func (w *OpenAIStreamTranslatorWriter) readOneLine() ([]byte, bool) {
|
||||
b := w.lineBuf.Bytes()
|
||||
idx := bytes.IndexByte(b, '\n')
|
||||
if idx < 0 {
|
||||
return nil, false
|
||||
}
|
||||
line := append([]byte(nil), b[:idx]...)
|
||||
w.lineBuf.Next(idx + 1)
|
||||
return line, true
|
||||
}
|
||||
57
internal/translatorcliproxy/stream_writer_test.go
Normal file
57
internal/translatorcliproxy/stream_writer_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package translatorcliproxy
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
)
|
||||
|
||||
func TestOpenAIStreamTranslatorWriterClaude(t *testing.T) {
|
||||
original := []byte(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":true}`)
|
||||
translated := []byte(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":true}`)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
w := NewOpenAIStreamTranslatorWriter(rec, sdktranslator.FormatClaude, "claude-sonnet-4-5", original, translated)
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(200)
|
||||
_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl_1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"claude-sonnet-4-5\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\"},\"finish_reason\":null}]}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl_1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"claude-sonnet-4-5\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"},\"finish_reason\":null}]}\n\n"))
|
||||
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: message_start") {
|
||||
t.Fatalf("expected claude message_start event, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamTranslatorWriterGemini(t *testing.T) {
|
||||
original := []byte(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`)
|
||||
translated := []byte(`{"model":"gemini-2.5-pro","messages":[{"role":"user","content":"hi"}],"stream":true}`)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
w := NewOpenAIStreamTranslatorWriter(rec, sdktranslator.FormatGemini, "gemini-2.5-pro", original, translated)
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(200)
|
||||
_, _ = w.Write([]byte("data: {\"id\":\"chatcmpl_1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gemini-2.5-pro\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"},\"finish_reason\":null}]}\n\n"))
|
||||
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "candidates") {
|
||||
t.Fatalf("expected gemini stream payload, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamTranslatorWriterPreservesKeepAliveComment(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
w := NewOpenAIStreamTranslatorWriter(rec, sdktranslator.FormatGemini, "gemini-2.5-pro", []byte(`{}`), []byte(`{}`))
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(200)
|
||||
_, _ = w.Write([]byte(": keep-alive\n\n"))
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, ": keep-alive\n\n") {
|
||||
t.Fatalf("expected keep-alive comment passthrough, got %q", body)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user