Files
ds2api/internal/translatorcliproxy/stream_writer.go

241 lines
5.4 KiB
Go

package translatorcliproxy
import (
"bytes"
"context"
"encoding/json"
"net/http"
"strconv"
"strings"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
)
// OpenAIStreamTranslatorWriter translates OpenAI SSE output to another client format in real-time.
type OpenAIStreamTranslatorWriter struct {
dst http.ResponseWriter
target sdktranslator.Format
model string
originalReq []byte
translatedReq []byte
param any
statusCode int
headersSent bool
lineBuf bytes.Buffer
}
func NewOpenAIStreamTranslatorWriter(dst http.ResponseWriter, target sdktranslator.Format, model string, originalReq, translatedReq []byte) *OpenAIStreamTranslatorWriter {
return &OpenAIStreamTranslatorWriter{
dst: dst,
target: target,
model: model,
originalReq: originalReq,
translatedReq: translatedReq,
statusCode: http.StatusOK,
}
}
func (w *OpenAIStreamTranslatorWriter) Header() http.Header {
return w.dst.Header()
}
func (w *OpenAIStreamTranslatorWriter) WriteHeader(statusCode int) {
if w.headersSent {
return
}
w.statusCode = statusCode
w.headersSent = true
w.dst.WriteHeader(statusCode)
}
func (w *OpenAIStreamTranslatorWriter) Write(p []byte) (int, error) {
if !w.headersSent {
w.WriteHeader(http.StatusOK)
}
if w.statusCode < 200 || w.statusCode >= 300 {
return w.dst.Write(p)
}
w.lineBuf.Write(p)
for {
line, ok := w.readOneLine()
if !ok {
break
}
trimmed := bytes.TrimSpace(line)
if len(trimmed) == 0 {
continue
}
if bytes.HasPrefix(trimmed, []byte(":")) {
if _, err := w.dst.Write(trimmed); err != nil {
return len(p), err
}
if _, err := w.dst.Write([]byte("\n\n")); err != nil {
return len(p), err
}
if f, ok := w.dst.(http.Flusher); ok {
f.Flush()
}
continue
}
if !bytes.HasPrefix(trimmed, []byte("data:")) {
continue
}
usage, hasUsage := extractOpenAIUsage(trimmed)
chunks := sdktranslator.TranslateStream(context.Background(), sdktranslator.FormatOpenAI, w.target, w.model, w.originalReq, w.translatedReq, trimmed, &w.param)
if hasUsage {
for i := range chunks {
chunks[i] = injectStreamUsageMetadata(chunks[i], w.target, usage)
}
}
for i := range chunks {
if len(chunks[i]) == 0 {
continue
}
if _, err := w.dst.Write(chunks[i]); err != nil {
return len(p), err
}
if !bytes.HasSuffix(chunks[i], []byte("\n")) {
if _, err := w.dst.Write([]byte("\n")); err != nil {
return len(p), err
}
}
}
if f, ok := w.dst.(http.Flusher); ok {
f.Flush()
}
}
return len(p), nil
}
func (w *OpenAIStreamTranslatorWriter) Flush() {
if f, ok := w.dst.(http.Flusher); ok {
f.Flush()
}
}
func (w *OpenAIStreamTranslatorWriter) Unwrap() http.ResponseWriter {
return w.dst
}
func (w *OpenAIStreamTranslatorWriter) readOneLine() ([]byte, bool) {
b := w.lineBuf.Bytes()
idx := bytes.IndexByte(b, '\n')
if idx < 0 {
return nil, false
}
line := append([]byte(nil), b[:idx]...)
w.lineBuf.Next(idx + 1)
return line, true
}
type openAIUsage struct {
PromptTokens int
CompletionTokens int
TotalTokens int
}
func extractOpenAIUsage(line []byte) (openAIUsage, bool) {
raw := strings.TrimSpace(strings.TrimPrefix(string(line), "data:"))
if raw == "" || raw == "[DONE]" {
return openAIUsage{}, false
}
var payload map[string]any
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
return openAIUsage{}, false
}
usageObj, _ := payload["usage"].(map[string]any)
if usageObj == nil {
return openAIUsage{}, false
}
p := toInt(usageObj["prompt_tokens"])
c := toInt(usageObj["completion_tokens"])
t := toInt(usageObj["total_tokens"])
if p <= 0 {
p = toInt(usageObj["input_tokens"])
}
if c <= 0 {
c = toInt(usageObj["output_tokens"])
}
if p <= 0 && c <= 0 && t <= 0 {
return openAIUsage{}, false
}
if t <= 0 {
t = p + c
}
return openAIUsage{PromptTokens: p, CompletionTokens: c, TotalTokens: t}, true
}
func injectStreamUsageMetadata(chunk []byte, target sdktranslator.Format, usage openAIUsage) []byte {
if target != sdktranslator.FormatGemini {
return chunk
}
suffix := ""
switch {
case bytes.HasSuffix(chunk, []byte("\n\n")):
suffix = "\n\n"
case bytes.HasSuffix(chunk, []byte("\n")):
suffix = "\n"
}
text := strings.TrimSpace(string(chunk))
if text == "" {
return chunk
}
var (
hasDataPrefix bool
jsonText = text
)
if strings.HasPrefix(jsonText, "data:") {
hasDataPrefix = true
jsonText = strings.TrimSpace(strings.TrimPrefix(jsonText, "data:"))
}
if jsonText == "" || jsonText == "[DONE]" {
return chunk
}
obj := map[string]any{}
if err := json.Unmarshal([]byte(jsonText), &obj); err != nil {
return chunk
}
if _, ok := obj["candidates"]; !ok {
return chunk
}
obj["usageMetadata"] = map[string]any{
"promptTokenCount": usage.PromptTokens,
"candidatesTokenCount": usage.CompletionTokens,
"totalTokenCount": usage.TotalTokens,
}
b, err := json.Marshal(obj)
if err != nil {
return chunk
}
if hasDataPrefix {
return []byte("data: " + string(b) + suffix)
}
if suffix != "" {
return append(b, []byte(suffix)...)
}
return b
}
func toInt(v any) int {
switch x := v.(type) {
case int:
return x
case int32:
return int(x)
case int64:
return int(x)
case float64:
return int(x)
case float32:
return int(x)
case string:
n, err := strconv.Atoi(strings.TrimSpace(x))
if err != nil {
return 0
}
return n
default:
return 0
}
}