package openai

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/google/uuid"

	"ds2api/internal/auth"
	"ds2api/internal/config"
	"ds2api/internal/deepseek"
	openaifmt "ds2api/internal/format/openai"
	"ds2api/internal/sse"
	streamengine "ds2api/internal/stream"
	"ds2api/internal/util"
)

// writeJSON is a package-internal alias kept to avoid mass-renaming across
// every call site in this file. It delegates to the shared util version.
var writeJSON = util.WriteJSON

// Handler serves the OpenAI-compatible surface of the API. Store supplies
// runtime configuration, Auth resolves per-request credentials, and DS
// performs the upstream DeepSeek calls.
type Handler struct {
	Store ConfigReader
	Auth  AuthResolver
	DS    DeepSeekCaller

	leaseMu      sync.Mutex
	streamLeases map[string]streamLease
	responsesMu  sync.Mutex
	responses    *responseStore
}

// streamLease holds an auth reservation for a prepared stream request until
// it is released or expires.
type streamLease struct {
	Auth      *auth.RequestAuth
	ExpiresAt time.Time
}

// RegisterRoutes mounts the OpenAI-compatible endpoints on r.
func RegisterRoutes(r chi.Router, h *Handler) {
	r.Get("/v1/models", h.ListModels)
	r.Get("/v1/models/{model_id}", h.GetModel)
	r.Post("/v1/chat/completions", h.ChatCompletions)
	r.Post("/v1/responses", h.Responses)
	r.Get("/v1/responses/{response_id}", h.GetResponseByID)
	r.Post("/v1/embeddings", h.Embeddings)
}
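
// A minimal wiring sketch (illustrative only: the concrete store, resolver,
// and caller values are assumed to be constructed elsewhere in the binary):
//
//	r := chi.NewRouter()
//	h := &openai.Handler{Store: store, Auth: resolver, DS: caller}
//	openai.RegisterRoutes(r, h)
//	_ = http.ListenAndServe(":8080", r)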

// ListModels serves GET /v1/models with the configured model list.
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
	writeJSON(w, http.StatusOK, config.OpenAIModelsResponse())
}

// GetModel serves GET /v1/models/{model_id}, returning a 404 in OpenAI error
// format for unknown IDs.
func (h *Handler) GetModel(w http.ResponseWriter, r *http.Request) {
	modelID := strings.TrimSpace(chi.URLParam(r, "model_id"))
	model, ok := config.OpenAIModelByID(h.Store, modelID)
	if !ok {
		writeOpenAIError(w, http.StatusNotFound, "Model not found.")
		return
	}
	writeJSON(w, http.StatusOK, model)
}

// ChatCompletions implements POST /v1/chat/completions. It resolves
// per-request auth, normalizes the OpenAI-shaped body, performs the DeepSeek
// session and PoW handshake, then serves the completion as a stream or as a
// single response.
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
	// Vercel two-phase streaming control requests are tunnelled over this
	// endpoint and handled separately.
	if isVercelStreamReleaseRequest(r) {
		h.handleVercelStreamRelease(w, r)
		return
	}
	if isVercelStreamPrepareRequest(r) {
		h.handleVercelStreamPrepare(w, r)
		return
	}

	a, err := h.Auth.Determine(r)
	if err != nil {
		status := http.StatusUnauthorized
		detail := err.Error()
		if errors.Is(err, auth.ErrNoAccount) {
			status = http.StatusTooManyRequests
		}
		writeOpenAIError(w, status, detail)
		return
	}
	defer h.Auth.Release(a)
	r = r.WithContext(auth.WithAuth(r.Context(), a))

	var req map[string]any
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeOpenAIError(w, http.StatusBadRequest, "invalid json")
		return
	}
	stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
	if err != nil {
		writeOpenAIError(w, http.StatusBadRequest, err.Error())
		return
	}

	sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
	if err != nil {
		if a.UseConfigToken {
			writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please log the account in again from the admin panel.")
		} else {
			writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
		}
		return
	}
	pow, err := h.DS.GetPow(r.Context(), a, 3)
	if err != nil {
		writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
		return
	}
	payload := stdReq.CompletionPayload(sessionID)
	resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
	if err != nil {
		writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
		return
	}
	if stdReq.Stream {
		h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
		return
	}
	h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
}
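
// A request to this endpoint follows the standard OpenAI chat-completions
// shape, e.g. (sketch; the model name must be one listed by /v1/models):
//
//	POST /v1/chat/completions
//	Authorization: Bearer <DS2API key or account token>
//
//	{"model": "deepseek-chat", "stream": true,
//	 "messages": [{"role": "user", "content": "Hello"}]}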

// handleNonStream drains the upstream SSE stream to completion and replies
// with a single chat.completion object.
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()
		body, _ := io.ReadAll(resp.Body)
		writeOpenAIError(w, resp.StatusCode, string(body))
		return
	}
	_ = ctx // ctx is currently unused on this path
	result := sse.CollectStream(resp, thinkingEnabled, true)

	finalThinking := result.Thinking
	finalText := result.Text
	respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
	writeJSON(w, http.StatusOK, respBody)
}
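
// Sketch of the shape BuildChatCompletion emits (exact fields live in
// internal/format/openai):
//
//	{"id": "<completionID>", "object": "chat.completion", "model": "<model>",
//	 "choices": [{"index": 0, "message": {"role": "assistant", "content": "..."}}]}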

// handleStream relays the upstream DeepSeek SSE stream to the client as
// OpenAI chat.completion.chunk events, emitting keep-alives and enforcing
// idle timeouts via the shared stream engine.
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		writeOpenAIError(w, resp.StatusCode, string(body))
		return
	}
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache, no-transform")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no")
	rc := http.NewResponseController(w)
	canFlush := rc.Flush() == nil
	if !canFlush {
		config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered")
	}

	created := time.Now().Unix()
	bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
	emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
	initialType := "text"
	if thinkingEnabled {
		initialType = "thinking"
	}

	streamRuntime := newChatStreamRuntime(
		w,
		rc,
		canFlush,
		completionID,
		created,
		model,
		finalPrompt,
		thinkingEnabled,
		searchEnabled,
		toolNames,
		bufferToolContent,
		emitEarlyToolDeltas,
	)

	streamengine.ConsumeSSE(streamengine.ConsumeConfig{
		Context:             r.Context(),
		Body:                resp.Body,
		ThinkingEnabled:     thinkingEnabled,
		InitialType:         initialType,
		KeepAliveInterval:   time.Duration(deepseek.KeepAliveTimeout) * time.Second,
		IdleTimeout:         time.Duration(deepseek.StreamIdleTimeout) * time.Second,
		MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
	}, streamengine.ConsumeHooks{
		OnKeepAlive: streamRuntime.sendKeepAlive,
		OnParsed:    streamRuntime.onParsed,
		OnFinalize: func(reason streamengine.StopReason, _ error) {
			// Collapse every stop reason other than content_filter to "stop".
			if string(reason) == "content_filter" {
				streamRuntime.finalize("content_filter")
				return
			}
			streamRuntime.finalize("stop")
		},
	})
}
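
// On the wire, each event the runtime emits is an OpenAI-style chunk frame;
// assuming the runtime follows the standard convention, the session ends with
// the [DONE] sentinel (sketch; exact fields live in the chat stream runtime):
//
//	data: {"id":"<completionID>","object":"chat.completion.chunk",
//	       "choices":[{"index":0,"delta":{"content":"..."}}]}
//
//	data: [DONE]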

// injectToolPrompt builds a textual tool-usage instruction block from the
// OpenAI tools array and appends it to the existing system message, or
// prepends a new system message if none exists. It returns the updated
// messages and the collected tool names.
func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) {
	toolSchemas := make([]string, 0, len(tools))
	names := make([]string, 0, len(tools))
	for _, t := range tools {
		tool, ok := t.(map[string]any)
		if !ok {
			continue
		}
		fn, _ := tool["function"].(map[string]any)
		if len(fn) == 0 {
			// Tolerate the flat tool shape where name/description/parameters
			// sit at the top level instead of under "function".
			fn = tool
		}
		name, _ := fn["name"].(string)
		desc, _ := fn["description"].(string)
		schema, _ := fn["parameters"].(map[string]any)
		if name == "" {
			name = "unknown"
		}
		names = append(names, name)
		if desc == "" {
			desc = "No description available"
		}
		b, _ := json.Marshal(schema)
		toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b)))
	}
	if len(toolSchemas) == 0 {
		return messages, names
	}
	toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") +
		"\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n" +
		"{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\n" +
		"History markers in conversation:\n" +
		"- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n" +
		"- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\n" +
		"IMPORTANT:\n" +
		"1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n" +
		"2) After receiving a tool result, you MUST use it to produce the final answer.\n" +
		"3) Only call another tool when the previous result is missing required data or returned an error.\n" +
		"4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block."

	for i := range messages {
		if messages[i]["role"] == "system" {
			old, _ := messages[i]["content"].(string)
			messages[i]["content"] = strings.TrimSpace(old + "\n\n" + toolPrompt)
			return messages, names
		}
	}
	messages = append([]map[string]any{{"role": "system", "content": toolPrompt}}, messages...)
	return messages, names
}
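
// For a single hypothetical tool named get_weather, the injected block
// renders as:
//
//	You have access to these tools:
//
//	Tool: get_weather
//	Description: Look up current weather
//	Parameters: {"type":"object","properties":{"city":{"type":"string"}}}
//
//	When you need to use tools, output ONLY this JSON format (no other text):
//	{"tool_calls": [{"name": "tool_name", "input": {"param": "value"}}]}
//	...
//
// and a conforming model reply would then look like
// {"tool_calls": [{"name": "get_weather", "input": {"city": "Paris"}}]}.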

// formatIncrementalStreamToolCallDeltas converts parsed tool-call deltas into
// the OpenAI streaming tool_calls shape, assigning a stable call ID per index
// the first time that index is seen.
func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]string) []map[string]any {
	if len(deltas) == 0 {
		return nil
	}
	out := make([]map[string]any, 0, len(deltas))
	for _, d := range deltas {
		if d.Name == "" && d.Arguments == "" {
			continue
		}
		callID, ok := ids[d.Index]
		if !ok || callID == "" {
			callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
			ids[d.Index] = callID
		}
		item := map[string]any{
			"index": d.Index,
			"id":    callID,
			"type":  "function",
		}
		fn := map[string]any{}
		if d.Name != "" {
			fn["name"] = d.Name
		}
		if d.Arguments != "" {
			fn["arguments"] = d.Arguments
		}
		if len(fn) > 0 {
			item["function"] = fn
		}
		out = append(out, item)
	}
	return out
}
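
// The first delta seen for index 0 therefore marshals to something like
// (tool name illustrative; later deltas for the same index reuse the ID):
//
//	{"index": 0, "id": "call_<32 hex chars>", "type": "function",
//	 "function": {"name": "get_weather", "arguments": "{\"city\":"}}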

// writeOpenAIError writes an OpenAI-compatible error envelope with the given
// status and message.
func writeOpenAIError(w http.ResponseWriter, status int, message string) {
	writeJSON(w, status, map[string]any{
		"error": map[string]any{
			"message": message,
			"type":    openAIErrorType(status),
			"code":    openAIErrorCode(status),
			"param":   nil,
		},
	})
}
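
// A 401 written through this helper, for example, serializes as:
//
//	{"error": {"message": "...", "type": "authentication_error",
//	           "code": "authentication_failed", "param": null}}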

// openAIErrorType maps an HTTP status code to an OpenAI-compatible error
// type string.
func openAIErrorType(status int) string {
	switch status {
	case http.StatusBadRequest:
		return "invalid_request_error"
	case http.StatusUnauthorized:
		return "authentication_error"
	case http.StatusForbidden:
		return "permission_error"
	case http.StatusTooManyRequests:
		return "rate_limit_error"
	case http.StatusServiceUnavailable:
		return "service_unavailable_error"
	default:
		if status >= 500 {
			return "api_error"
		}
		return "invalid_request_error"
	}
}

// openAIErrorCode maps an HTTP status code to an OpenAI-compatible error
// code string.
func openAIErrorCode(status int) string {
	switch status {
	case http.StatusBadRequest:
		return "invalid_request"
	case http.StatusUnauthorized:
		return "authentication_failed"
	case http.StatusForbidden:
		return "forbidden"
	case http.StatusTooManyRequests:
		return "rate_limit_exceeded"
	case http.StatusNotFound:
		return "not_found"
	case http.StatusServiceUnavailable:
		return "service_unavailable"
	default:
		if status >= 500 {
			return "internal_error"
		}
		return "invalid_request"
	}
}

// applyOpenAIChatPassThrough copies the pass-through fields collected from
// the incoming OpenAI request onto the upstream payload.
func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) {
	for k, v := range collectOpenAIChatPassThrough(req) {
		payload[k] = v
	}
}

// toolcallFeatureMatchEnabled reports whether tool-call handling runs in
// "feature_match" mode; an unset mode (or a nil handler/store) counts as
// feature_match.
func (h *Handler) toolcallFeatureMatchEnabled() bool {
	if h == nil || h.Store == nil {
		return true
	}
	mode := strings.TrimSpace(strings.ToLower(h.Store.ToolcallMode()))
	return mode == "" || mode == "feature_match"
}

// toolcallEarlyEmitHighConfidence reports whether early tool-call deltas may
// be emitted, which requires the configured confidence level to be unset or
// "high".
func (h *Handler) toolcallEarlyEmitHighConfidence() bool {
	if h == nil || h.Store == nil {
		return true
	}
	level := strings.TrimSpace(strings.ToLower(h.Store.ToolcallEarlyEmitConfidence()))
	return level == "" || level == "high"
}
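
// As a sketch, a store configured along these lines selects the non-default
// behaviors (key names inferred from the accessor methods, so treat them as
// illustrative). Any mode other than ""/"feature_match" disables tool-content
// buffering, and any confidence other than ""/"high" disables early deltas:
//
//	toolcall_mode: "prompt_only"
//	toolcall_early_emit_confidence: "low"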