// Mirror of https://github.com/CJackHwang/ds2api.git
// Synced 2026-05-04 00:15:28 +08:00
// 241 lines, 5.4 KiB, Go

package translatorcliproxy
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
|
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
|
)
|
|
|
|
// OpenAIStreamTranslatorWriter translates OpenAI SSE output to another client format in real-time.
|
|
type OpenAIStreamTranslatorWriter struct {
|
|
dst http.ResponseWriter
|
|
target sdktranslator.Format
|
|
model string
|
|
originalReq []byte
|
|
translatedReq []byte
|
|
param any
|
|
statusCode int
|
|
headersSent bool
|
|
lineBuf bytes.Buffer
|
|
}
|
|
|
|
func NewOpenAIStreamTranslatorWriter(dst http.ResponseWriter, target sdktranslator.Format, model string, originalReq, translatedReq []byte) *OpenAIStreamTranslatorWriter {
|
|
return &OpenAIStreamTranslatorWriter{
|
|
dst: dst,
|
|
target: target,
|
|
model: model,
|
|
originalReq: originalReq,
|
|
translatedReq: translatedReq,
|
|
statusCode: http.StatusOK,
|
|
}
|
|
}
|
|
|
|
func (w *OpenAIStreamTranslatorWriter) Header() http.Header {
|
|
return w.dst.Header()
|
|
}
|
|
|
|
func (w *OpenAIStreamTranslatorWriter) WriteHeader(statusCode int) {
|
|
if w.headersSent {
|
|
return
|
|
}
|
|
w.statusCode = statusCode
|
|
w.headersSent = true
|
|
w.dst.WriteHeader(statusCode)
|
|
}
|
|
|
|
func (w *OpenAIStreamTranslatorWriter) Write(p []byte) (int, error) {
|
|
if !w.headersSent {
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
if w.statusCode < 200 || w.statusCode >= 300 {
|
|
return w.dst.Write(p)
|
|
}
|
|
w.lineBuf.Write(p)
|
|
for {
|
|
line, ok := w.readOneLine()
|
|
if !ok {
|
|
break
|
|
}
|
|
trimmed := bytes.TrimSpace(line)
|
|
if len(trimmed) == 0 {
|
|
continue
|
|
}
|
|
if bytes.HasPrefix(trimmed, []byte(":")) {
|
|
if _, err := w.dst.Write(trimmed); err != nil {
|
|
return len(p), err
|
|
}
|
|
if _, err := w.dst.Write([]byte("\n\n")); err != nil {
|
|
return len(p), err
|
|
}
|
|
if f, ok := w.dst.(http.Flusher); ok {
|
|
f.Flush()
|
|
}
|
|
continue
|
|
}
|
|
if !bytes.HasPrefix(trimmed, []byte("data:")) {
|
|
continue
|
|
}
|
|
usage, hasUsage := extractOpenAIUsage(trimmed)
|
|
chunks := sdktranslator.TranslateStream(context.Background(), sdktranslator.FormatOpenAI, w.target, w.model, w.originalReq, w.translatedReq, trimmed, &w.param)
|
|
if hasUsage {
|
|
for i := range chunks {
|
|
chunks[i] = injectStreamUsageMetadata(chunks[i], w.target, usage)
|
|
}
|
|
}
|
|
for i := range chunks {
|
|
if len(chunks[i]) == 0 {
|
|
continue
|
|
}
|
|
if _, err := w.dst.Write(chunks[i]); err != nil {
|
|
return len(p), err
|
|
}
|
|
if !bytes.HasSuffix(chunks[i], []byte("\n")) {
|
|
if _, err := w.dst.Write([]byte("\n")); err != nil {
|
|
return len(p), err
|
|
}
|
|
}
|
|
}
|
|
if f, ok := w.dst.(http.Flusher); ok {
|
|
f.Flush()
|
|
}
|
|
}
|
|
return len(p), nil
|
|
}
|
|
|
|
func (w *OpenAIStreamTranslatorWriter) Flush() {
|
|
if f, ok := w.dst.(http.Flusher); ok {
|
|
f.Flush()
|
|
}
|
|
}
|
|
|
|
func (w *OpenAIStreamTranslatorWriter) Unwrap() http.ResponseWriter {
|
|
return w.dst
|
|
}
|
|
|
|
func (w *OpenAIStreamTranslatorWriter) readOneLine() ([]byte, bool) {
|
|
b := w.lineBuf.Bytes()
|
|
idx := bytes.IndexByte(b, '\n')
|
|
if idx < 0 {
|
|
return nil, false
|
|
}
|
|
line := append([]byte(nil), b[:idx]...)
|
|
w.lineBuf.Next(idx + 1)
|
|
return line, true
|
|
}
|
|
|
|
// openAIUsage holds the token counts read from the "usage" object of an
// OpenAI stream chunk.
type openAIUsage struct {
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
}
|
|
|
|
func extractOpenAIUsage(line []byte) (openAIUsage, bool) {
|
|
raw := strings.TrimSpace(strings.TrimPrefix(string(line), "data:"))
|
|
if raw == "" || raw == "[DONE]" {
|
|
return openAIUsage{}, false
|
|
}
|
|
var payload map[string]any
|
|
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
|
|
return openAIUsage{}, false
|
|
}
|
|
usageObj, _ := payload["usage"].(map[string]any)
|
|
if usageObj == nil {
|
|
return openAIUsage{}, false
|
|
}
|
|
p := toInt(usageObj["prompt_tokens"])
|
|
c := toInt(usageObj["completion_tokens"])
|
|
t := toInt(usageObj["total_tokens"])
|
|
if p <= 0 {
|
|
p = toInt(usageObj["input_tokens"])
|
|
}
|
|
if c <= 0 {
|
|
c = toInt(usageObj["output_tokens"])
|
|
}
|
|
if p <= 0 && c <= 0 && t <= 0 {
|
|
return openAIUsage{}, false
|
|
}
|
|
if t <= 0 {
|
|
t = p + c
|
|
}
|
|
return openAIUsage{PromptTokens: p, CompletionTokens: c, TotalTokens: t}, true
|
|
}
|
|
|
|
func injectStreamUsageMetadata(chunk []byte, target sdktranslator.Format, usage openAIUsage) []byte {
|
|
if target != sdktranslator.FormatGemini {
|
|
return chunk
|
|
}
|
|
suffix := ""
|
|
switch {
|
|
case bytes.HasSuffix(chunk, []byte("\n\n")):
|
|
suffix = "\n\n"
|
|
case bytes.HasSuffix(chunk, []byte("\n")):
|
|
suffix = "\n"
|
|
}
|
|
text := strings.TrimSpace(string(chunk))
|
|
if text == "" {
|
|
return chunk
|
|
}
|
|
var (
|
|
hasDataPrefix bool
|
|
jsonText = text
|
|
)
|
|
if strings.HasPrefix(jsonText, "data:") {
|
|
hasDataPrefix = true
|
|
jsonText = strings.TrimSpace(strings.TrimPrefix(jsonText, "data:"))
|
|
}
|
|
if jsonText == "" || jsonText == "[DONE]" {
|
|
return chunk
|
|
}
|
|
obj := map[string]any{}
|
|
if err := json.Unmarshal([]byte(jsonText), &obj); err != nil {
|
|
return chunk
|
|
}
|
|
if _, ok := obj["candidates"]; !ok {
|
|
return chunk
|
|
}
|
|
obj["usageMetadata"] = map[string]any{
|
|
"promptTokenCount": usage.PromptTokens,
|
|
"candidatesTokenCount": usage.CompletionTokens,
|
|
"totalTokenCount": usage.TotalTokens,
|
|
}
|
|
b, err := json.Marshal(obj)
|
|
if err != nil {
|
|
return chunk
|
|
}
|
|
if hasDataPrefix {
|
|
return []byte("data: " + string(b) + suffix)
|
|
}
|
|
if suffix != "" {
|
|
return append(b, []byte(suffix)...)
|
|
}
|
|
return b
|
|
}
|
|
|
|
// toInt converts a loosely typed value, such as one decoded from JSON, to an
// int.
//
// Integer and floating-point types are converted directly (floats are
// truncated toward zero). json.Number and string values are trimmed and
// parsed, first as an integer and then as a decimal such as "12.0". Anything
// that cannot be represented as an int (nil, other types, unparsable text,
// NaN, infinities, out-of-range values) yields 0.
func toInt(v any) int {
	switch x := v.(type) {
	case int:
		return x
	case int8:
		return int(x)
	case int16:
		return int(x)
	case int32:
		return int(x)
	case int64:
		return int(x)
	case uint8:
		return int(x)
	case uint16:
		return int(x)
	case uint32:
		return toIntFromUint(uint64(x))
	case uint:
		return toIntFromUint(uint64(x))
	case uint64:
		return toIntFromUint(x)
	case float64:
		return toIntFromFloat(x)
	case float32:
		return toIntFromFloat(float64(x))
	case json.Number:
		return toIntFromString(x.String())
	case string:
		return toIntFromString(x)
	default:
		return 0
	}
}

// toIntFromUint converts x to int, returning 0 when it does not fit.
func toIntFromUint(x uint64) int {
	const maxInt = uint64(^uint(0) >> 1)
	if x > maxInt {
		return 0
	}
	return int(x)
}

// toIntFromFloat truncates f toward zero, returning 0 for NaN, infinities and
// magnitudes outside the int range (converting those is implementation-defined
// in Go).
func toIntFromFloat(f float64) int {
	const maxInt = int(^uint(0) >> 1)
	// Written as a negated conjunction so that NaN, for which every comparison
	// is false, is rejected as well.
	if !(f > -float64(maxInt) && f < float64(maxInt)) {
		return 0
	}
	return int(f)
}

// toIntFromString parses s, ignoring surrounding whitespace, as an integer or,
// failing that, as a decimal number that is then truncated.
func toIntFromString(s string) int {
	s = strings.TrimSpace(s)
	if n, err := strconv.Atoi(s); err == nil {
		return n
	}
	f, err := strconv.ParseFloat(s, 64)
	if err != nil {
		return 0
	}
	return toIntFromFloat(f)
}
|