mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-04 08:25:26 +08:00
316 lines
9.4 KiB
Go
316 lines
9.4 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"runtime"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/go-chi/chi/v5/middleware"
|
|
|
|
"ds2api/internal/account"
|
|
"ds2api/internal/auth"
|
|
"ds2api/internal/chathistory"
|
|
"ds2api/internal/config"
|
|
dsclient "ds2api/internal/deepseek/client"
|
|
"ds2api/internal/httpapi/admin"
|
|
"ds2api/internal/httpapi/claude"
|
|
"ds2api/internal/httpapi/gemini"
|
|
"ds2api/internal/httpapi/openai/chat"
|
|
"ds2api/internal/httpapi/openai/embeddings"
|
|
"ds2api/internal/httpapi/openai/files"
|
|
"ds2api/internal/httpapi/openai/responses"
|
|
"ds2api/internal/httpapi/openai/shared"
|
|
"ds2api/internal/webui"
|
|
)
|
|
|
|
type App struct {
|
|
Store *config.Store
|
|
Pool *account.Pool
|
|
Resolver *auth.Resolver
|
|
DS *dsclient.Client
|
|
Router http.Handler
|
|
}
|
|
|
|
func NewApp() (*App, error) {
|
|
store, err := config.LoadStoreWithError()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("load config: %w", err)
|
|
}
|
|
pool := account.NewPool(store)
|
|
var dsClient *dsclient.Client
|
|
resolver := auth.NewResolver(store, pool, func(ctx context.Context, acc config.Account) (string, error) {
|
|
return dsClient.Login(ctx, acc)
|
|
})
|
|
dsClient = dsclient.NewClient(store, resolver)
|
|
if err := dsClient.PreloadPow(context.Background()); err != nil {
|
|
config.Logger.Warn("[PoW] init failed", "error", err)
|
|
} else {
|
|
config.Logger.Info("[PoW] pure Go solver ready")
|
|
}
|
|
chatHistoryStore := chathistory.New(config.ChatHistoryPath())
|
|
if err := chatHistoryStore.Err(); err != nil {
|
|
config.Logger.Warn("[chat_history] unavailable", "path", chatHistoryStore.Path(), "error", err)
|
|
}
|
|
|
|
modelsHandler := &shared.ModelsHandler{Store: store}
|
|
chatHandler := &chat.Handler{Store: store, Auth: resolver, DS: dsClient, ChatHistory: chatHistoryStore}
|
|
responsesHandler := &responses.Handler{Store: store, Auth: resolver, DS: dsClient, ChatHistory: chatHistoryStore}
|
|
filesHandler := &files.Handler{Store: store, Auth: resolver, DS: dsClient, ChatHistory: chatHistoryStore}
|
|
embeddingsHandler := &embeddings.Handler{Store: store, Auth: resolver, DS: dsClient, ChatHistory: chatHistoryStore}
|
|
claudeHandler := &claude.Handler{Store: store, Auth: resolver, DS: dsClient, OpenAI: chatHandler}
|
|
geminiHandler := &gemini.Handler{Store: store, Auth: resolver, DS: dsClient, OpenAI: chatHandler}
|
|
adminHandler := &admin.Handler{Store: store, Pool: pool, DS: dsClient, OpenAI: chatHandler, ChatHistory: chatHistoryStore}
|
|
webuiHandler := webui.NewHandler()
|
|
|
|
r := chi.NewRouter()
|
|
r.Use(middleware.RequestID)
|
|
r.Use(middleware.RealIP)
|
|
r.Use(filteredLogger())
|
|
r.Use(middleware.Recoverer)
|
|
r.Use(cors)
|
|
r.Use(timeout(0))
|
|
|
|
healthzHandler := func(w http.ResponseWriter, _ *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte(`{"status":"ok"}`))
|
|
}
|
|
readyzHandler := func(w http.ResponseWriter, _ *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte(`{"status":"ready"}`))
|
|
}
|
|
r.Get("/healthz", healthzHandler)
|
|
r.Head("/healthz", healthzHandler)
|
|
r.Get("/readyz", readyzHandler)
|
|
r.Head("/readyz", readyzHandler)
|
|
r.Get("/v1/models", modelsHandler.ListModels)
|
|
r.Get("/v1/models/{model_id}", modelsHandler.GetModel)
|
|
r.Post("/v1/chat/completions", chatHandler.ChatCompletions)
|
|
r.Post("/v1/responses", responsesHandler.Responses)
|
|
r.Get("/v1/responses/{response_id}", responsesHandler.GetResponseByID)
|
|
r.Post("/v1/files", filesHandler.UploadFile)
|
|
r.Post("/v1/embeddings", embeddingsHandler.Embeddings)
|
|
// Root OpenAI aliases support clients configured with the bare DS2API service URL.
|
|
r.Get("/models", modelsHandler.ListModels)
|
|
r.Get("/models/{model_id}", modelsHandler.GetModel)
|
|
r.Post("/chat/completions", chatHandler.ChatCompletions)
|
|
r.Post("/responses", responsesHandler.Responses)
|
|
r.Get("/responses/{response_id}", responsesHandler.GetResponseByID)
|
|
r.Post("/files", filesHandler.UploadFile)
|
|
r.Post("/embeddings", embeddingsHandler.Embeddings)
|
|
claude.RegisterRoutes(r, claudeHandler)
|
|
gemini.RegisterRoutes(r, geminiHandler)
|
|
r.Route("/admin", func(ar chi.Router) {
|
|
admin.RegisterRoutes(ar, adminHandler)
|
|
})
|
|
webui.RegisterRoutes(r, webuiHandler)
|
|
r.NotFound(func(w http.ResponseWriter, req *http.Request) {
|
|
if strings.HasPrefix(req.URL.Path, "/admin/") && webuiHandler.HandleAdminFallback(w, req) {
|
|
return
|
|
}
|
|
http.NotFound(w, req)
|
|
})
|
|
|
|
return &App{Store: store, Pool: pool, Resolver: resolver, DS: dsClient, Router: r}, nil
|
|
}
|
|
|
|
func timeout(d time.Duration) func(http.Handler) http.Handler {
|
|
if d <= 0 {
|
|
return func(next http.Handler) http.Handler { return next }
|
|
}
|
|
return middleware.Timeout(d)
|
|
}
|
|
|
|
func filteredLogger() func(http.Handler) http.Handler {
|
|
color := !isWindowsRuntime()
|
|
base := &middleware.DefaultLogFormatter{
|
|
Logger: log.New(os.Stdout, "", log.LstdFlags),
|
|
NoColor: !color,
|
|
}
|
|
return middleware.RequestLogger(&filteredLogFormatter{base: base})
|
|
}
|
|
|
|
func isWindowsRuntime() bool {
|
|
return runtime.GOOS == "windows"
|
|
}
|
|
|
|
type filteredLogFormatter struct {
|
|
base *middleware.DefaultLogFormatter
|
|
}
|
|
|
|
func (f *filteredLogFormatter) NewLogEntry(r *http.Request) middleware.LogEntry {
|
|
if r != nil && r.Method == http.MethodGet {
|
|
path := strings.TrimSpace(r.URL.Path)
|
|
if path == "/admin/chat-history" || strings.HasPrefix(path, "/admin/chat-history/") {
|
|
return noopLogEntry{}
|
|
}
|
|
}
|
|
return f.base.NewLogEntry(r)
|
|
}
|
|
|
|
type noopLogEntry struct{}
|
|
|
|
func (noopLogEntry) Write(_ int, _ int, _ http.Header, _ time.Duration, _ interface{}) {}
|
|
|
|
func (noopLogEntry) Panic(_ interface{}, _ []byte) {}
|
|
|
|
var defaultCORSAllowHeaders = []string{
|
|
"Content-Type",
|
|
"Authorization",
|
|
"X-API-Key",
|
|
"X-Ds2-Target-Account",
|
|
"X-Ds2-Source",
|
|
"X-Vercel-Protection-Bypass",
|
|
"X-Goog-Api-Key",
|
|
"Anthropic-Version",
|
|
"Anthropic-Beta",
|
|
}
|
|
|
|
var blockedCORSRequestHeaders = map[string]struct{}{
|
|
"x-ds2-internal-token": {},
|
|
}
|
|
|
|
func cors(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
setCORSHeaders(w, r)
|
|
if r.Method == http.MethodOptions {
|
|
w.WriteHeader(http.StatusNoContent)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func setCORSHeaders(w http.ResponseWriter, r *http.Request) {
|
|
origin := strings.TrimSpace(r.Header.Get("Origin"))
|
|
if origin == "" {
|
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
} else {
|
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
|
addVaryHeaderToken(w.Header(), "Origin")
|
|
}
|
|
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE")
|
|
w.Header().Set("Access-Control-Allow-Headers", buildCORSAllowHeaders(r))
|
|
w.Header().Set("Access-Control-Max-Age", "600")
|
|
addVaryHeaderToken(w.Header(), "Access-Control-Request-Headers")
|
|
if strings.EqualFold(strings.TrimSpace(r.Header.Get("Access-Control-Request-Private-Network")), "true") {
|
|
w.Header().Set("Access-Control-Allow-Private-Network", "true")
|
|
addVaryHeaderToken(w.Header(), "Access-Control-Request-Private-Network")
|
|
}
|
|
}
|
|
|
|
func buildCORSAllowHeaders(r *http.Request) string {
|
|
names := make([]string, 0, len(defaultCORSAllowHeaders)+4)
|
|
seen := make(map[string]struct{}, len(defaultCORSAllowHeaders)+4)
|
|
for _, name := range defaultCORSAllowHeaders {
|
|
appendCORSHeaderName(&names, seen, name)
|
|
}
|
|
if r == nil {
|
|
return strings.Join(names, ", ")
|
|
}
|
|
for _, name := range splitCORSRequestHeaders(r.Header.Get("Access-Control-Request-Headers")) {
|
|
appendCORSHeaderName(&names, seen, name)
|
|
}
|
|
return strings.Join(names, ", ")
|
|
}
|
|
|
|
func splitCORSRequestHeaders(raw string) []string {
|
|
if strings.TrimSpace(raw) == "" {
|
|
return nil
|
|
}
|
|
parts := strings.Split(raw, ",")
|
|
out := make([]string, 0, len(parts))
|
|
for _, part := range parts {
|
|
name := strings.TrimSpace(part)
|
|
if !isValidCORSHeaderToken(name) {
|
|
continue
|
|
}
|
|
if _, blocked := blockedCORSRequestHeaders[strings.ToLower(name)]; blocked {
|
|
continue
|
|
}
|
|
out = append(out, name)
|
|
}
|
|
return out
|
|
}
|
|
|
|
func appendCORSHeaderName(dst *[]string, seen map[string]struct{}, name string) {
|
|
name = strings.TrimSpace(name)
|
|
if !isValidCORSHeaderToken(name) {
|
|
return
|
|
}
|
|
key := strings.ToLower(name)
|
|
if _, blocked := blockedCORSRequestHeaders[key]; blocked {
|
|
return
|
|
}
|
|
if _, ok := seen[key]; ok {
|
|
return
|
|
}
|
|
seen[key] = struct{}{}
|
|
*dst = append(*dst, name)
|
|
}
|
|
|
|
func isValidCORSHeaderToken(v string) bool {
|
|
if v == "" {
|
|
return false
|
|
}
|
|
for i := 0; i < len(v); i++ {
|
|
c := v[i]
|
|
if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') {
|
|
continue
|
|
}
|
|
switch c {
|
|
case '!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~':
|
|
continue
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func addVaryHeaderToken(h http.Header, token string) {
|
|
if h == nil {
|
|
return
|
|
}
|
|
token = strings.TrimSpace(token)
|
|
if token == "" {
|
|
return
|
|
}
|
|
current := h.Values("Vary")
|
|
seen := map[string]struct{}{}
|
|
merged := make([]string, 0, len(current)+1)
|
|
for _, value := range current {
|
|
for _, part := range strings.Split(value, ",") {
|
|
name := strings.TrimSpace(part)
|
|
if name == "" {
|
|
continue
|
|
}
|
|
key := strings.ToLower(name)
|
|
if _, ok := seen[key]; ok {
|
|
continue
|
|
}
|
|
seen[key] = struct{}{}
|
|
merged = append(merged, name)
|
|
}
|
|
}
|
|
key := strings.ToLower(token)
|
|
if _, ok := seen[key]; !ok {
|
|
merged = append(merged, token)
|
|
}
|
|
h.Set("Vary", strings.Join(merged, ", "))
|
|
}
|
|
|
|
func WriteUnhandledError(w http.ResponseWriter, err error) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
_ = json.NewEncoder(w).Encode(map[string]any{"error": map[string]any{"type": "api_error", "message": "Internal Server Error", "detail": err.Error()}})
|
|
}
|