feat: Introduce model alias resolution, enhanced configuration options, and improved OpenAI/Claude adapter handling for responses, embeddings, and tool calls.

This commit is contained in:
CJACK
2026-02-18 23:06:18 +08:00
parent 27ecb4b69b
commit 3a75b75ae0
28 changed files with 1665 additions and 183 deletions

View File

@@ -0,0 +1,35 @@
package claude
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func TestWriteClaudeErrorIncludesUnifiedFields(t *testing.T) {
rec := httptest.NewRecorder()
writeClaudeError(rec, http.StatusUnauthorized, "bad token")
if rec.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", rec.Code)
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("decode body: %v", err)
}
errObj, _ := body["error"].(map[string]any)
if errObj["message"] != "bad token" {
t.Fatalf("unexpected message: %v", errObj["message"])
}
if errObj["type"] != "invalid_request_error" {
t.Fatalf("unexpected type: %v", errObj["type"])
}
if errObj["code"] != "authentication_failed" {
t.Fatalf("unexpected code: %v", errObj["code"])
}
if _, ok := errObj["param"]; !ok {
t.Fatal("expected param field")
}
}

View File

@@ -43,6 +43,9 @@ func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
}
func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" {
r.Header.Set("anthropic-version", "2023-06-01")
}
a, err := h.Auth.Determine(r)
if err != nil {
status := http.StatusUnauthorized
@@ -50,22 +53,25 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
if err == auth.ErrNoAccount {
status = http.StatusTooManyRequests
}
writeJSON(w, status, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": detail}})
writeClaudeError(w, status, detail)
return
}
defer h.Auth.Release(a)
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "invalid json"}})
writeClaudeError(w, http.StatusBadRequest, "invalid json")
return
}
model, _ := req["model"].(string)
messagesRaw, _ := req["messages"].([]any)
if model == "" || len(messagesRaw) == 0 {
writeJSON(w, http.StatusBadRequest, map[string]any{"error": map[string]any{"type": "invalid_request_error", "message": "Request must include 'model' and 'messages'."}})
writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
return
}
if _, ok := req["max_tokens"]; !ok {
req["max_tokens"] = 8192
}
normalized := normalizeClaudeMessages(messagesRaw)
payload := cloneMap(req)
@@ -86,12 +92,12 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "invalid token."}})
writeClaudeError(w, http.StatusUnauthorized, "invalid token.")
return
}
pow, err := h.DS.GetPow(r.Context(), a, 3)
if err != nil {
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get PoW"}})
writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW")
return
}
requestPayload := map[string]any{
@@ -104,13 +110,13 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
}
resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3)
if err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": "Failed to get Claude response."}})
writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.")
return
}
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}})
writeClaudeError(w, http.StatusInternalServerError, string(body))
return
}
@@ -162,20 +168,20 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
a, err := h.Auth.Determine(r)
if err != nil {
writeJSON(w, http.StatusUnauthorized, map[string]any{"error": err.Error()})
writeClaudeError(w, http.StatusUnauthorized, err.Error())
return
}
defer h.Auth.Release(a)
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "invalid json"})
writeClaudeError(w, http.StatusBadRequest, "invalid json")
return
}
model, _ := req["model"].(string)
messages, _ := req["messages"].([]any)
if model == "" || len(messages) == 0 {
writeJSON(w, http.StatusBadRequest, map[string]any{"error": "Request must include 'model' and 'messages'."})
writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
return
}
inputTokens := 0
@@ -206,7 +212,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
writeJSON(w, http.StatusInternalServerError, map[string]any{"error": map[string]any{"type": "api_error", "message": string(body)}})
writeClaudeError(w, http.StatusInternalServerError, string(body))
return
}
@@ -241,6 +247,8 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
"error": map[string]any{
"type": "api_error",
"message": msg,
"code": "internal_error",
"param": nil,
},
})
}
@@ -492,6 +500,28 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
}
}
func writeClaudeError(w http.ResponseWriter, status int, message string) {
code := "invalid_request"
switch status {
case http.StatusUnauthorized:
code = "authentication_failed"
case http.StatusTooManyRequests:
code = "rate_limit_exceeded"
case http.StatusNotFound:
code = "not_found"
case http.StatusInternalServerError:
code = "internal_error"
}
writeJSON(w, status, map[string]any{
"error": map[string]any{
"type": "invalid_request_error",
"message": message,
"code": code,
"param": nil,
},
})
}
func normalizeClaudeMessages(messages []any) []any {
out := make([]any, 0, len(messages))
for _, m := range messages {

View File

@@ -0,0 +1,138 @@
package openai
import (
"crypto/sha256"
"encoding/binary"
"encoding/json"
"fmt"
"net/http"
"strings"
"ds2api/internal/auth"
"ds2api/internal/config"
"ds2api/internal/util"
)
func (h *Handler) Embeddings(w http.ResponseWriter, r *http.Request) {
a, err := h.Auth.Determine(r)
if err != nil {
status := http.StatusUnauthorized
detail := err.Error()
if err == auth.ErrNoAccount {
status = http.StatusTooManyRequests
}
writeOpenAIError(w, status, detail)
return
}
defer h.Auth.Release(a)
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
return
}
model, _ := req["model"].(string)
model = strings.TrimSpace(model)
if model == "" {
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model'.")
return
}
if _, ok := config.ResolveModel(h.Store, model); !ok {
writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model))
return
}
inputs := extractEmbeddingInputs(req["input"])
if len(inputs) == 0 {
writeOpenAIError(w, http.StatusBadRequest, "Request must include non-empty 'input'.")
return
}
provider := ""
if h.Store != nil {
provider = strings.ToLower(strings.TrimSpace(h.Store.EmbeddingsProvider()))
}
if provider == "" {
writeOpenAIError(w, http.StatusNotImplemented, "Embeddings provider is not configured. Set embeddings.provider in config.")
return
}
switch provider {
case "mock", "deterministic", "builtin":
// supported local deterministic provider
default:
writeOpenAIError(w, http.StatusNotImplemented, fmt.Sprintf("Embeddings provider '%s' is not supported.", provider))
return
}
data := make([]map[string]any, 0, len(inputs))
totalTokens := 0
for i, input := range inputs {
totalTokens += util.EstimateTokens(input)
data = append(data, map[string]any{
"object": "embedding",
"index": i,
"embedding": deterministicEmbedding(input),
})
}
writeJSON(w, http.StatusOK, map[string]any{
"object": "list",
"data": data,
"model": model,
"usage": map[string]any{
"prompt_tokens": totalTokens,
"total_tokens": totalTokens,
},
})
}
func extractEmbeddingInputs(raw any) []string {
switch v := raw.(type) {
case string:
s := strings.TrimSpace(v)
if s == "" {
return nil
}
return []string{s}
case []any:
out := make([]string, 0, len(v))
for _, item := range v {
switch iv := item.(type) {
case string:
s := strings.TrimSpace(iv)
if s != "" {
out = append(out, s)
}
case []any:
// Token array input support: convert to stable string form.
out = append(out, fmt.Sprintf("%v", iv))
default:
s := strings.TrimSpace(fmt.Sprintf("%v", iv))
if s != "" {
out = append(out, s)
}
}
}
return out
default:
return nil
}
}
func deterministicEmbedding(input string) []float64 {
// Keep response shape stable without external dependencies.
const dims = 64
out := make([]float64, dims)
seed := sha256.Sum256([]byte(input))
buf := seed[:]
for i := 0; i < dims; i++ {
if len(buf) < 4 {
next := sha256.Sum256(buf)
buf = next[:]
}
v := binary.BigEndian.Uint32(buf[:4])
buf = buf[4:]
// map [0, 2^32) -> [-1, 1]
out[i] = (float64(v)/2147483647.5 - 1.0)
}
return out
}

View File

@@ -0,0 +1,35 @@
package openai
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func TestWriteOpenAIErrorIncludesUnifiedFields(t *testing.T) {
rec := httptest.NewRecorder()
writeOpenAIError(rec, http.StatusBadRequest, "invalid input")
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", rec.Code)
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("decode body: %v", err)
}
errObj, _ := body["error"].(map[string]any)
if errObj["message"] != "invalid input" {
t.Fatalf("unexpected message: %v", errObj["message"])
}
if errObj["type"] != "invalid_request_error" {
t.Fatalf("unexpected type: %v", errObj["type"])
}
if errObj["code"] != "invalid_request" {
t.Fatalf("unexpected code: %v", errObj["code"])
}
if _, ok := errObj["param"]; !ok {
t.Fatal("expected param field")
}
}

View File

@@ -31,6 +31,8 @@ type Handler struct {
leaseMu sync.Mutex
streamLeases map[string]streamLease
responsesMu sync.Mutex
responses *responseStore
}
type streamLease struct {
@@ -40,13 +42,27 @@ type streamLease struct {
func RegisterRoutes(r chi.Router, h *Handler) {
r.Get("/v1/models", h.ListModels)
r.Get("/v1/models/{model_id}", h.GetModel)
r.Post("/v1/chat/completions", h.ChatCompletions)
r.Post("/v1/responses", h.Responses)
r.Get("/v1/responses/{response_id}", h.GetResponseByID)
r.Post("/v1/embeddings", h.Embeddings)
}
func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
writeJSON(w, http.StatusOK, config.OpenAIModelsResponse())
}
func (h *Handler) GetModel(w http.ResponseWriter, r *http.Request) {
modelID := strings.TrimSpace(chi.URLParam(r, "model_id"))
model, ok := config.OpenAIModelByID(h.Store, modelID)
if !ok {
writeOpenAIError(w, http.StatusNotFound, "Model not found.")
return
}
writeJSON(w, http.StatusOK, model)
}
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
if isVercelStreamReleaseRequest(r) {
h.handleVercelStreamRelease(w, r)
@@ -81,11 +97,16 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
return
}
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model)
resolvedModel, ok := config.ResolveModel(h.Store, model)
if !ok {
writeOpenAIError(w, http.StatusServiceUnavailable, fmt.Sprintf("Model '%s' is not available.", model))
writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model))
return
}
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
responseModel := strings.TrimSpace(model)
if responseModel == "" {
responseModel = resolvedModel
}
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
@@ -111,16 +132,17 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
"thinking_enabled": thinkingEnabled,
"search_enabled": searchEnabled,
}
applyOpenAIChatPassThrough(req, payload)
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
if err != nil {
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
return
}
if util.ToBool(req["stream"]) {
h.handleStream(w, r, resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
h.handleStream(w, r, resp, sessionID, responseModel, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
return
}
h.handleNonStream(w, r.Context(), resp, sessionID, model, finalPrompt, thinkingEnabled, toolNames)
h.handleNonStream(w, r.Context(), resp, sessionID, responseModel, finalPrompt, thinkingEnabled, toolNames)
}
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
@@ -135,7 +157,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re
finalThinking := result.Thinking
finalText := result.Text
detected := util.ParseStandaloneToolCalls(finalText, toolNames)
detected := util.ParseToolCalls(finalText, toolNames)
finishReason := "stop"
messageObj := map[string]any{"role": "assistant", "content": finalText}
if thinkingEnabled && finalThinking != "" {
@@ -222,7 +244,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
finalize := func(finishReason string) {
finalThinking := thinking.String()
finalText := text.String()
detected := util.ParseStandaloneToolCalls(finalText, toolNames)
detected := util.ParseToolCalls(finalText, toolNames)
if len(detected) > 0 && !toolCallsEmitted {
finishReason = "tool_calls"
delta := map[string]any{
@@ -497,6 +519,8 @@ func writeOpenAIError(w http.ResponseWriter, status int, message string) {
"error": map[string]any{
"message": message,
"type": openAIErrorType(status),
"code": openAIErrorCode(status),
"param": nil,
},
})
}
@@ -520,3 +544,41 @@ func openAIErrorType(status int) string {
return "invalid_request_error"
}
}
func openAIErrorCode(status int) string {
switch status {
case http.StatusBadRequest:
return "invalid_request"
case http.StatusUnauthorized:
return "authentication_failed"
case http.StatusForbidden:
return "forbidden"
case http.StatusTooManyRequests:
return "rate_limit_exceeded"
case http.StatusNotFound:
return "not_found"
case http.StatusServiceUnavailable:
return "service_unavailable"
default:
if status >= 500 {
return "internal_error"
}
return "invalid_request"
}
}
func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) {
for _, k := range []string{
"temperature",
"top_p",
"max_tokens",
"max_completion_tokens",
"presence_penalty",
"frequency_penalty",
"stop",
} {
if v, ok := req[k]; ok {
payload[k] = v
}
}
}

View File

@@ -210,7 +210,7 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) {
}
}
func TestHandleNonStreamEmbeddedToolCallExampleNotIntercepted(t *testing.T) {
func TestHandleNonStreamEmbeddedToolCallExampleIntercepted(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"下面是示例:"}`,
@@ -228,16 +228,16 @@ func TestHandleNonStreamEmbeddedToolCallExampleNotIntercepted(t *testing.T) {
out := decodeJSONBody(t, rec.Body.String())
choices, _ := out["choices"].([]any)
choice, _ := choices[0].(map[string]any)
if choice["finish_reason"] != "stop" {
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
if choice["finish_reason"] != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
}
msg, _ := choice["message"].(map[string]any)
if _, ok := msg["tool_calls"]; ok {
t.Fatalf("did not expect tool_calls field for embedded example: %#v", msg["tool_calls"])
toolCalls, _ := msg["tool_calls"].([]any)
if len(toolCalls) == 0 {
t.Fatalf("expected tool_calls field for embedded example: %#v", msg["tool_calls"])
}
content, _ := msg["content"].(string)
if !strings.Contains(content, "示例") || !strings.Contains(content, `"tool_calls"`) {
t.Fatalf("expected embedded example to pass through as text, got %q", content)
if msg["content"] != nil {
t.Fatalf("expected content nil when tool_calls detected, got %#v", msg["content"])
}
}
@@ -471,8 +471,8 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if streamHasToolCallsDelta(frames) {
t.Fatalf("did not expect tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
}
content := strings.Builder{}
for _, frame := range frames {
@@ -489,11 +489,11 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") {
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
}
if !strings.Contains(got, `"tool_calls"`) {
t.Fatalf("expected mixed stream to preserve embedded tool_calls example text, got=%q", got)
if strings.Contains(strings.ToLower(got), `"tool_calls"`) {
t.Fatalf("expected no raw tool_calls json leak in content, got=%q", got)
}
if streamFinishReason(frames) != "stop" {
t.Fatalf("expected finish_reason=stop for mixed prose, body=%s", rec.Body.String())
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String())
}
}

View File

@@ -0,0 +1,46 @@
package openai
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
)
func TestGetModelRouteDirectAndAlias(t *testing.T) {
h := &Handler{}
r := chi.NewRouter()
RegisterRoutes(r, h)
t.Run("direct", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/v1/models/deepseek-chat", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
}
})
t.Run("alias", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/v1/models/gpt-4.1", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200 for alias, got %d body=%s", rec.Code, rec.Body.String())
}
})
}
func TestGetModelRouteNotFound(t *testing.T) {
h := &Handler{}
r := chi.NewRouter()
RegisterRoutes(r, h)
req := httptest.NewRequest(http.MethodGet, "/v1/models/not-exists", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d body=%s", rec.Code, rec.Body.String())
}
}

View File

@@ -0,0 +1,91 @@
package openai
import (
"sync"
"time"
)
type storedResponse struct {
Value map[string]any
ExpiresAt time.Time
}
type responseStore struct {
mu sync.Mutex
ttl time.Duration
items map[string]storedResponse
}
func newResponseStore(ttl time.Duration) *responseStore {
if ttl <= 0 {
ttl = 15 * time.Minute
}
return &responseStore{
ttl: ttl,
items: make(map[string]storedResponse),
}
}
func (s *responseStore) put(id string, value map[string]any) {
if s == nil || id == "" || value == nil {
return
}
now := time.Now()
s.mu.Lock()
defer s.mu.Unlock()
s.sweepLocked(now)
s.items[id] = storedResponse{
Value: cloneAnyMap(value),
ExpiresAt: now.Add(s.ttl),
}
}
func (s *responseStore) get(id string) (map[string]any, bool) {
if s == nil || id == "" {
return nil, false
}
now := time.Now()
s.mu.Lock()
defer s.mu.Unlock()
s.sweepLocked(now)
item, ok := s.items[id]
if !ok {
return nil, false
}
return cloneAnyMap(item.Value), true
}
func (s *responseStore) sweepLocked(now time.Time) {
for k, v := range s.items {
if now.After(v.ExpiresAt) {
delete(s.items, k)
}
}
}
func cloneAnyMap(in map[string]any) map[string]any {
if in == nil {
return nil
}
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func (h *Handler) getResponseStore() *responseStore {
if h == nil {
return nil
}
h.responsesMu.Lock()
defer h.responsesMu.Unlock()
if h.responses == nil {
ttl := 15 * time.Minute
if h.Store != nil {
ttl = time.Duration(h.Store.ResponsesStoreTTLSeconds()) * time.Second
}
h.responses = newResponseStore(ttl)
}
return h.responses
}

View File

@@ -0,0 +1,65 @@
package openai
import (
"testing"
"time"
)
func TestNormalizeResponsesInputAsMessagesString(t *testing.T) {
msgs := normalizeResponsesInputAsMessages("hello")
if len(msgs) != 1 {
t.Fatalf("expected one message, got %d", len(msgs))
}
m, _ := msgs[0].(map[string]any)
if m["role"] != "user" || m["content"] != "hello" {
t.Fatalf("unexpected message: %#v", m)
}
}
func TestResponsesMessagesFromRequestWithInstructions(t *testing.T) {
req := map[string]any{
"model": "gpt-4.1",
"input": "ping",
"instructions": "system text",
}
msgs := responsesMessagesFromRequest(req)
if len(msgs) != 2 {
t.Fatalf("expected two messages, got %d", len(msgs))
}
sys, _ := msgs[0].(map[string]any)
if sys["role"] != "system" {
t.Fatalf("unexpected first message: %#v", sys)
}
}
func TestExtractEmbeddingInputs(t *testing.T) {
got := extractEmbeddingInputs([]any{"a", "b"})
if len(got) != 2 || got[0] != "a" || got[1] != "b" {
t.Fatalf("unexpected inputs: %#v", got)
}
}
func TestDeterministicEmbeddingStable(t *testing.T) {
a := deterministicEmbedding("hello")
b := deterministicEmbedding("hello")
if len(a) != 64 || len(b) != 64 {
t.Fatalf("expected 64 dims, got %d and %d", len(a), len(b))
}
for i := range a {
if a[i] != b[i] {
t.Fatalf("expected stable embedding at %d: %v != %v", i, a[i], b[i])
}
}
}
func TestResponseStorePutGet(t *testing.T) {
st := newResponseStore(100 * time.Millisecond)
st.put("resp_1", map[string]any{"id": "resp_1"})
got, ok := st.get("resp_1")
if !ok {
t.Fatal("expected stored response")
}
if got["id"] != "resp_1" {
t.Fatalf("unexpected response payload: %#v", got)
}
}

View File

@@ -0,0 +1,407 @@
package openai
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"ds2api/internal/auth"
"ds2api/internal/config"
"ds2api/internal/sse"
"ds2api/internal/util"
)
func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) {
id := strings.TrimSpace(chi.URLParam(r, "response_id"))
if id == "" {
writeOpenAIError(w, http.StatusBadRequest, "response_id is required.")
return
}
st := h.getResponseStore()
item, ok := st.get(id)
if !ok {
writeOpenAIError(w, http.StatusNotFound, "Response not found.")
return
}
writeJSON(w, http.StatusOK, item)
}
func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
a, err := h.Auth.Determine(r)
if err != nil {
status := http.StatusUnauthorized
detail := err.Error()
if err == auth.ErrNoAccount {
status = http.StatusTooManyRequests
}
writeOpenAIError(w, status, detail)
return
}
defer h.Auth.Release(a)
r = r.WithContext(auth.WithAuth(r.Context(), a))
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
return
}
model, _ := req["model"].(string)
model = strings.TrimSpace(model)
if model == "" {
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model'.")
return
}
resolvedModel, ok := config.ResolveModel(h.Store, model)
if !ok {
writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model))
return
}
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
messagesRaw := responsesMessagesFromRequest(req)
if len(messagesRaw) == 0 {
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'input' or 'messages'.")
return
}
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
if a.UseConfigToken {
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
} else {
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
}
return
}
pow, err := h.DS.GetPow(r.Context(), a, 3)
if err != nil {
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
return
}
payload := map[string]any{
"chat_session_id": sessionID,
"parent_message_id": nil,
"prompt": finalPrompt,
"ref_file_ids": []any{},
"thinking_enabled": thinkingEnabled,
"search_enabled": searchEnabled,
}
applyOpenAIChatPassThrough(req, payload)
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
if err != nil {
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
return
}
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
if util.ToBool(req["stream"]) {
h.handleResponsesStream(w, r, resp, responseID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
return
}
h.handleResponsesNonStream(w, resp, responseID, model, finalPrompt, thinkingEnabled, toolNames)
}
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
writeOpenAIError(w, resp.StatusCode, strings.TrimSpace(string(body)))
return
}
result := sse.CollectStream(resp, thinkingEnabled, true)
responseObj := buildResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames)
h.getResponseStore().put(responseID, responseObj)
writeJSON(w, http.StatusOK, responseObj)
}
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
writeOpenAIError(w, resp.StatusCode, strings.TrimSpace(string(body)))
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache, no-transform")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
rc := http.NewResponseController(w)
canFlush := rc.Flush() == nil
sendEvent := func(event string, payload map[string]any) {
b, _ := json.Marshal(payload)
_, _ = w.Write([]byte("event: " + event + "\n"))
_, _ = w.Write([]byte("data: "))
_, _ = w.Write(b)
_, _ = w.Write([]byte("\n\n"))
if canFlush {
_ = rc.Flush()
}
}
sendEvent("response.created", map[string]any{
"type": "response.created",
"id": responseID,
"object": "response",
"model": model,
"status": "in_progress",
})
initialType := "text"
if thinkingEnabled {
initialType = "thinking"
}
parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType)
bufferToolContent := len(toolNames) > 0
var sieve toolStreamSieveState
thinking := strings.Builder{}
text := strings.Builder{}
toolCallsEmitted := false
streamToolCallIDs := map[int]string{}
finalize := func() {
finalThinking := thinking.String()
finalText := text.String()
if bufferToolContent {
for _, evt := range flushToolSieve(&sieve, toolNames) {
if evt.Content != "" {
finalText += evt.Content
sendEvent("response.output_text.delta", map[string]any{
"type": "response.output_text.delta",
"id": responseID,
"delta": evt.Content,
})
}
if len(evt.ToolCalls) > 0 {
toolCallsEmitted = true
sendEvent("response.output_tool_call.done", map[string]any{
"type": "response.output_tool_call.done",
"id": responseID,
"tool_calls": util.FormatOpenAIStreamToolCalls(evt.ToolCalls),
})
}
}
}
obj := buildResponseObject(responseID, model, finalPrompt, finalThinking, finalText, toolNames)
if toolCallsEmitted {
obj["status"] = "completed"
}
h.getResponseStore().put(responseID, obj)
sendEvent("response.completed", map[string]any{
"type": "response.completed",
"response": obj,
})
_, _ = w.Write([]byte("data: [DONE]\n\n"))
if canFlush {
_ = rc.Flush()
}
}
for {
select {
case <-r.Context().Done():
return
case parsed, ok := <-parsedLines:
if !ok {
_ = <-done
finalize()
return
}
if !parsed.Parsed {
continue
}
if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
finalize()
return
}
for _, p := range parsed.Parts {
if p.Text == "" {
continue
}
if p.Type != "thinking" && searchEnabled && sse.IsCitation(p.Text) {
continue
}
if p.Type == "thinking" {
if !thinkingEnabled {
continue
}
thinking.WriteString(p.Text)
sendEvent("response.reasoning.delta", map[string]any{
"type": "response.reasoning.delta",
"id": responseID,
"delta": p.Text,
})
continue
}
text.WriteString(p.Text)
if !bufferToolContent {
sendEvent("response.output_text.delta", map[string]any{
"type": "response.output_text.delta",
"id": responseID,
"delta": p.Text,
})
continue
}
for _, evt := range processToolSieveChunk(&sieve, p.Text, toolNames) {
if evt.Content != "" {
sendEvent("response.output_text.delta", map[string]any{
"type": "response.output_text.delta",
"id": responseID,
"delta": evt.Content,
})
}
if len(evt.ToolCallDeltas) > 0 {
toolCallsEmitted = true
sendEvent("response.output_tool_call.delta", map[string]any{
"type": "response.output_tool_call.delta",
"id": responseID,
"tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs),
})
}
if len(evt.ToolCalls) > 0 {
toolCallsEmitted = true
sendEvent("response.output_tool_call.done", map[string]any{
"type": "response.output_tool_call.done",
"id": responseID,
"tool_calls": util.FormatOpenAIStreamToolCalls(evt.ToolCalls),
})
}
}
}
}
}
}
func buildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
detected := util.ParseToolCalls(finalText, toolNames)
output := make([]any, 0, 2)
if len(detected) > 0 {
toolCalls := make([]any, 0, len(detected))
for _, tc := range detected {
toolCalls = append(toolCalls, map[string]any{
"type": "tool_call",
"name": tc.Name,
"arguments": tc.Input,
})
}
output = append(output, map[string]any{
"type": "tool_calls",
"tool_calls": toolCalls,
})
} else {
content := []any{
map[string]any{
"type": "output_text",
"text": finalText,
},
}
if finalThinking != "" {
content = append([]any{map[string]any{
"type": "reasoning",
"text": finalThinking,
}}, content...)
}
output = append(output, map[string]any{
"type": "message",
"id": "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
"role": "assistant",
"content": content,
})
}
promptTokens := util.EstimateTokens(finalPrompt)
reasoningTokens := util.EstimateTokens(finalThinking)
completionTokens := util.EstimateTokens(finalText)
return map[string]any{
"id": responseID,
"type": "response",
"object": "response",
"created_at": time.Now().Unix(),
"status": "completed",
"model": model,
"output": output,
"output_text": finalText,
"usage": map[string]any{
"input_tokens": promptTokens,
"output_tokens": reasoningTokens + completionTokens,
"total_tokens": promptTokens + reasoningTokens + completionTokens,
},
}
}
func responsesMessagesFromRequest(req map[string]any) []any {
if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 {
return prependInstructionMessage(msgs, req["instructions"])
}
if rawInput, ok := req["input"]; ok {
if msgs := normalizeResponsesInputAsMessages(rawInput); len(msgs) > 0 {
return prependInstructionMessage(msgs, req["instructions"])
}
}
return nil
}
func prependInstructionMessage(messages []any, instructions any) []any {
sys, _ := instructions.(string)
sys = strings.TrimSpace(sys)
if sys == "" {
return messages
}
out := make([]any, 0, len(messages)+1)
out = append(out, map[string]any{"role": "system", "content": sys})
out = append(out, messages...)
return out
}
func normalizeResponsesInputAsMessages(input any) []any {
switch v := input.(type) {
case string:
if strings.TrimSpace(v) == "" {
return nil
}
return []any{map[string]any{"role": "user", "content": v}}
case []any:
if len(v) == 0 {
return nil
}
// If caller already provides role-shaped items, keep as-is.
if first, ok := v[0].(map[string]any); ok {
if _, hasRole := first["role"]; hasRole {
return v
}
}
parts := make([]string, 0, len(v))
for _, item := range v {
if m, ok := item.(map[string]any); ok {
if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") {
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
parts = append(parts, txt)
continue
}
}
}
if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" {
parts = append(parts, s)
}
}
if len(parts) == 0 {
return nil
}
return []any{map[string]any{"role": "user", "content": strings.Join(parts, "\n")}}
case map[string]any:
if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" {
return []any{map[string]any{"role": "user", "content": txt}}
}
if content, ok := v["content"].(string); ok && strings.TrimSpace(content) != "" {
return []any{map[string]any{"role": "user", "content": content}}
}
}
return nil
}

View File

@@ -7,17 +7,16 @@ import (
)
type toolStreamSieveState struct {
pending strings.Builder
capture strings.Builder
capturing bool
hasMeaningfulText bool
recentTextTail string
toolNameSent bool
toolName string
toolArgsStart int
toolArgsSent int
toolArgsString bool
toolArgsDone bool
pending strings.Builder
capture strings.Builder
capturing bool
recentTextTail string
toolNameSent bool
toolName string
toolArgsStart int
toolArgsSent int
toolArgsString bool
toolArgsDone bool
}
type toolStreamEvent struct {
@@ -197,14 +196,22 @@ func findToolSegmentStart(s string) int {
return -1
}
lower := strings.ToLower(s)
keyIdx := strings.Index(lower, "tool_calls")
if keyIdx < 0 {
return -1
offset := 0
for {
keyRel := strings.Index(lower[offset:], "tool_calls")
if keyRel < 0 {
return -1
}
keyIdx := offset + keyRel
start := strings.LastIndex(s[:keyIdx], "{")
if start < 0 {
start = keyIdx
}
if !insideCodeFence(s[:start]) {
return start
}
offset = keyIdx + len("tool_calls")
}
if start := strings.LastIndex(s[:keyIdx], "{"); start >= 0 {
return start
}
return keyIdx
}
func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
@@ -227,7 +234,7 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
}
prefixPart := captured[:start]
suffixPart := captured[end:]
if !state.toolNameSent && (strings.TrimSpace(prefixPart) != "" || looksLikeToolExampleContext(state.recentTextTail) || looksLikeToolExampleContext(suffixPart)) {
if insideCodeFence(state.recentTextTail + prefixPart) {
return captured, nil, "", true
}
parsed := util.ParseStandaloneToolCalls(obj, toolNames)
@@ -293,16 +300,16 @@ func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
if captured == "" {
return nil
}
if looksLikeToolExampleContext(state.recentTextTail) {
return nil
}
lower := strings.ToLower(captured)
keyIdx := strings.Index(lower, "tool_calls")
if keyIdx < 0 {
return nil
}
start := strings.LastIndex(captured[:keyIdx], "{")
if start < 0 || strings.TrimSpace(captured[:start]) != "" {
if start < 0 {
return nil
}
if insideCodeFence(state.recentTextTail + captured[:start]) {
return nil
}
callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
@@ -612,7 +619,6 @@ func (s *toolStreamSieveState) noteText(content string) {
if strings.TrimSpace(content) == "" {
return
}
s.hasMeaningfulText = true
s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit)
}
@@ -628,25 +634,12 @@ func appendTail(prev, next string, max int) string {
}
func looksLikeToolExampleContext(text string) bool {
t := strings.ToLower(strings.TrimSpace(text))
if t == "" {
return insideCodeFence(text)
}
func insideCodeFence(text string) bool {
if text == "" {
return false
}
cues := []string{
"示例",
"例子",
"for example",
"example",
"demo",
"请勿执行",
"不要执行",
"do not execute",
"```",
}
for _, cue := range cues {
if strings.Contains(t, cue) {
return true
}
}
return false
return strings.Count(text, "```")%2 == 1
}

View File

@@ -62,11 +62,16 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
return
}
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model)
resolvedModel, ok := config.ResolveModel(h.Store, model)
if !ok {
writeOpenAIError(w, http.StatusServiceUnavailable, fmt.Sprintf("Model '%s' is not available.", model))
writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model))
return
}
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
responseModel := strings.TrimSpace(model)
if responseModel == "" {
responseModel = resolvedModel
}
finalPrompt, _ := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
@@ -97,6 +102,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
"thinking_enabled": thinkingEnabled,
"search_enabled": searchEnabled,
}
applyOpenAIChatPassThrough(req, payload)
leaseID := h.holdStreamLease(a)
if leaseID == "" {
writeOpenAIError(w, http.StatusInternalServerError, "failed to create stream lease")
@@ -106,7 +112,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
writeJSON(w, http.StatusOK, map[string]any{
"session_id": sessionID,
"lease_id": leaseID,
"model": model,
"model": responseModel,
"final_prompt": finalPrompt,
"thinking_enabled": thinkingEnabled,
"search_enabled": searchEnabled,

View File

@@ -62,11 +62,33 @@ type Config struct {
Accounts []Account `json:"accounts,omitempty"`
ClaudeMapping map[string]string `json:"claude_mapping,omitempty"`
ClaudeModelMap map[string]string `json:"claude_model_mapping,omitempty"`
ModelAliases map[string]string `json:"model_aliases,omitempty"`
Compat CompatConfig `json:"compat,omitempty"`
Toolcall ToolcallConfig `json:"toolcall,omitempty"`
Responses ResponsesConfig `json:"responses,omitempty"`
Embeddings EmbeddingsConfig `json:"embeddings,omitempty"`
VercelSyncHash string `json:"_vercel_sync_hash,omitempty"`
VercelSyncTime int64 `json:"_vercel_sync_time,omitempty"`
AdditionalFields map[string]any `json:"-"`
}
type CompatConfig struct {
WideInputStrictOutput bool `json:"wide_input_strict_output,omitempty"`
}
type ToolcallConfig struct {
Mode string `json:"mode,omitempty"`
EarlyEmitConfidence string `json:"early_emit_confidence,omitempty"`
}
type ResponsesConfig struct {
StoreTTLSeconds int `json:"store_ttl_seconds,omitempty"`
}
type EmbeddingsConfig struct {
Provider string `json:"provider,omitempty"`
}
func (c Config) MarshalJSON() ([]byte, error) {
m := map[string]any{}
for k, v := range c.AdditionalFields {
@@ -84,6 +106,21 @@ func (c Config) MarshalJSON() ([]byte, error) {
if len(c.ClaudeModelMap) > 0 {
m["claude_model_mapping"] = c.ClaudeModelMap
}
if len(c.ModelAliases) > 0 {
m["model_aliases"] = c.ModelAliases
}
if c.Compat.WideInputStrictOutput {
m["compat"] = c.Compat
}
if strings.TrimSpace(c.Toolcall.Mode) != "" || strings.TrimSpace(c.Toolcall.EarlyEmitConfidence) != "" {
m["toolcall"] = c.Toolcall
}
if c.Responses.StoreTTLSeconds > 0 {
m["responses"] = c.Responses
}
if strings.TrimSpace(c.Embeddings.Provider) != "" {
m["embeddings"] = c.Embeddings
}
if c.VercelSyncHash != "" {
m["_vercel_sync_hash"] = c.VercelSyncHash
}
@@ -117,6 +154,26 @@ func (c *Config) UnmarshalJSON(b []byte) error {
if err := json.Unmarshal(v, &c.ClaudeModelMap); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "model_aliases":
if err := json.Unmarshal(v, &c.ModelAliases); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "compat":
if err := json.Unmarshal(v, &c.Compat); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "toolcall":
if err := json.Unmarshal(v, &c.Toolcall); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "responses":
if err := json.Unmarshal(v, &c.Responses); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "embeddings":
if err := json.Unmarshal(v, &c.Embeddings); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "_vercel_sync_hash":
if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
@@ -141,6 +198,11 @@ func (c Config) Clone() Config {
Accounts: slices.Clone(c.Accounts),
ClaudeMapping: cloneStringMap(c.ClaudeMapping),
ClaudeModelMap: cloneStringMap(c.ClaudeModelMap),
ModelAliases: cloneStringMap(c.ModelAliases),
Compat: c.Compat,
Toolcall: c.Toolcall,
Responses: c.Responses,
Embeddings: c.Embeddings,
VercelSyncHash: c.VercelSyncHash,
VercelSyncTime: c.VercelSyncTime,
AdditionalFields: map[string]any{},
@@ -490,3 +552,59 @@ func (s *Store) ClaudeMapping() map[string]string {
}
return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"}
}
func (s *Store) ModelAliases() map[string]string {
s.mu.RLock()
defer s.mu.RUnlock()
out := DefaultModelAliases()
for k, v := range s.cfg.ModelAliases {
key := strings.TrimSpace(lower(k))
val := strings.TrimSpace(lower(v))
if key == "" || val == "" {
continue
}
out[key] = val
}
return out
}
func (s *Store) CompatWideInputStrictOutput() bool {
// Current default policy is always wide-input / strict-output.
// Kept as a method so callers do not depend on storage shape.
return true
}
func (s *Store) ToolcallMode() string {
s.mu.RLock()
defer s.mu.RUnlock()
mode := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.Mode))
if mode == "" {
return "feature_match"
}
return mode
}
func (s *Store) ToolcallEarlyEmitConfidence() string {
s.mu.RLock()
defer s.mu.RUnlock()
level := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.EarlyEmitConfidence))
if level == "" {
return "high"
}
return level
}
func (s *Store) ResponsesStoreTTLSeconds() int {
s.mu.RLock()
defer s.mu.RUnlock()
if s.cfg.Responses.StoreTTLSeconds > 0 {
return s.cfg.Responses.StoreTTLSeconds
}
return 900
}
func (s *Store) EmbeddingsProvider() string {
s.mu.RLock()
defer s.mu.RUnlock()
return strings.TrimSpace(s.cfg.Embeddings.Provider)
}

View File

@@ -0,0 +1,44 @@
package config
import "testing"
func TestResolveModelDirectDeepSeek(t *testing.T) {
got, ok := ResolveModel(nil, "deepseek-chat")
if !ok || got != "deepseek-chat" {
t.Fatalf("expected deepseek-chat, got ok=%v model=%q", ok, got)
}
}
func TestResolveModelAlias(t *testing.T) {
got, ok := ResolveModel(nil, "gpt-4.1")
if !ok || got != "deepseek-chat" {
t.Fatalf("expected alias gpt-4.1 -> deepseek-chat, got ok=%v model=%q", ok, got)
}
}
func TestResolveModelHeuristicReasoner(t *testing.T) {
got, ok := ResolveModel(nil, "o3-super")
if !ok || got != "deepseek-reasoner" {
t.Fatalf("expected heuristic reasoner, got ok=%v model=%q", ok, got)
}
}
func TestResolveModelUnknown(t *testing.T) {
_, ok := ResolveModel(nil, "totally-custom-model")
if ok {
t.Fatal("expected unknown model to fail resolve")
}
}
func TestClaudeModelsResponsePaginationFields(t *testing.T) {
resp := ClaudeModelsResponse()
if _, ok := resp["first_id"]; !ok {
t.Fatalf("expected first_id in response: %#v", resp)
}
if _, ok := resp["last_id"]; !ok {
t.Fatalf("expected last_id in response: %#v", resp)
}
if _, ok := resp["has_more"]; !ok {
t.Fatalf("expected has_more in response: %#v", resp)
}
}

View File

@@ -1,5 +1,7 @@
package config
import "strings"
type ModelInfo struct {
ID string `json:"id"`
Object string `json:"object"`
@@ -71,6 +73,91 @@ func GetModelConfig(model string) (thinking bool, search bool, ok bool) {
}
}
func IsSupportedDeepSeekModel(model string) bool {
_, _, ok := GetModelConfig(model)
return ok
}
func DefaultModelAliases() map[string]string {
return map[string]string{
"gpt-4o": "deepseek-chat",
"gpt-4.1": "deepseek-chat",
"gpt-4.1-mini": "deepseek-chat",
"gpt-4.1-nano": "deepseek-chat",
"gpt-5": "deepseek-chat",
"gpt-5-mini": "deepseek-chat",
"gpt-5-codex": "deepseek-reasoner",
"o1": "deepseek-reasoner",
"o1-mini": "deepseek-reasoner",
"o3": "deepseek-reasoner",
"o3-mini": "deepseek-reasoner",
"claude-sonnet-4-5": "deepseek-chat",
"claude-haiku-4-5": "deepseek-chat",
"claude-opus-4-6": "deepseek-reasoner",
"claude-3-5-sonnet": "deepseek-chat",
"claude-3-5-haiku": "deepseek-chat",
"claude-3-opus": "deepseek-reasoner",
"gemini-2.5-pro": "deepseek-chat",
"gemini-2.5-flash": "deepseek-chat",
"llama-3.1-70b-instruct": "deepseek-chat",
"qwen-max": "deepseek-chat",
}
}
func ResolveModel(store *Store, requested string) (string, bool) {
model := lower(strings.TrimSpace(requested))
if model == "" {
return "", false
}
if IsSupportedDeepSeekModel(model) {
return model, true
}
aliases := DefaultModelAliases()
if store != nil {
for k, v := range store.ModelAliases() {
aliases[lower(strings.TrimSpace(k))] = lower(strings.TrimSpace(v))
}
}
if mapped, ok := aliases[model]; ok && IsSupportedDeepSeekModel(mapped) {
return mapped, true
}
if strings.HasPrefix(model, "deepseek-") {
return "", false
}
knownFamily := false
for _, prefix := range []string{
"gpt-", "o1", "o3", "claude-", "gemini-", "llama-", "qwen-", "mistral-", "command-",
} {
if strings.HasPrefix(model, prefix) {
knownFamily = true
break
}
}
if !knownFamily {
return "", false
}
useReasoner := strings.Contains(model, "reason") ||
strings.Contains(model, "reasoner") ||
strings.HasPrefix(model, "o1") ||
strings.HasPrefix(model, "o3") ||
strings.Contains(model, "opus") ||
strings.Contains(model, "r1")
useSearch := strings.Contains(model, "search")
switch {
case useReasoner && useSearch:
return "deepseek-reasoner-search", true
case useReasoner:
return "deepseek-reasoner", true
case useSearch:
return "deepseek-chat-search", true
default:
return "deepseek-chat", true
}
}
func lower(s string) string {
b := []byte(s)
for i, c := range b {
@@ -85,6 +172,28 @@ func OpenAIModelsResponse() map[string]any {
return map[string]any{"object": "list", "data": DeepSeekModels}
}
func ClaudeModelsResponse() map[string]any {
return map[string]any{"object": "list", "data": ClaudeModels}
func OpenAIModelByID(store *Store, id string) (ModelInfo, bool) {
canonical, ok := ResolveModel(store, id)
if !ok {
return ModelInfo{}, false
}
for _, model := range DeepSeekModels {
if model.ID == canonical {
return model, true
}
}
return ModelInfo{}, false
}
func ClaudeModelsResponse() map[string]any {
resp := map[string]any{"object": "list", "data": ClaudeModels}
if len(ClaudeModels) > 0 {
resp["first_id"] = ClaudeModels[0].ID
resp["last_id"] = ClaudeModels[len(ClaudeModels)-1].ID
} else {
resp["first_id"] = nil
resp["last_id"] = nil
}
resp["has_more"] = false
return resp
}

View File

@@ -92,7 +92,7 @@ func cors(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key, X-Ds2-Target-Account, X-Vercel-Protection-Bypass")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return

View File

@@ -10,6 +10,7 @@ import (
var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`)
var fencedJSONPattern = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```")
var fencedBlockPattern = regexp.MustCompile("(?s)```.*?```")
type ParsedToolCall struct {
Name string `json:"name"`
@@ -20,6 +21,10 @@ func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
if strings.TrimSpace(text) == "" {
return nil
}
text = stripFencedCodeBlocks(text)
if strings.TrimSpace(text) == "" {
return nil
}
candidates := buildToolCallCandidates(text)
var parsed []ParsedToolCall
@@ -45,11 +50,6 @@ func ParseStandaloneToolCalls(text string, availableToolNames []string) []Parsed
return nil
}
candidates := []string{trimmed}
if strings.HasPrefix(trimmed, "```") && strings.HasSuffix(trimmed, "```") {
if m := fencedJSONPattern.FindStringSubmatch(trimmed); len(m) >= 2 {
candidates = append(candidates, strings.TrimSpace(m[1]))
}
}
for _, candidate := range candidates {
candidate = strings.TrimSpace(candidate)
if candidate == "" {
@@ -321,23 +321,14 @@ func looksLikeToolExampleContext(text string) bool {
if t == "" {
return false
}
cues := []string{
"```",
"示例",
"例子",
"for example",
"example",
"demo",
"请勿执行",
"不要执行",
"do not execute",
return strings.Contains(t, "```")
}
func stripFencedCodeBlocks(text string) string {
if strings.TrimSpace(text) == "" {
return ""
}
for _, cue := range cues {
if strings.Contains(t, cue) {
return true
}
}
return false
return fencedBlockPattern.ReplaceAllString(text, " ")
}
func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {

View File

@@ -19,11 +19,8 @@ func TestParseToolCalls(t *testing.T) {
func TestParseToolCallsFromFencedJSON(t *testing.T) {
text := "I will call tools now\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"news\"}}]}\n```"
calls := ParseToolCalls(text, []string{"search"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if calls[0].Input["q"] != "news" {
t.Fatalf("unexpected args: %#v", calls[0].Input)
if len(calls) != 0 {
t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls)
}
}

View File

@@ -416,18 +416,6 @@ func TestParseStandaloneToolCallsFencedCodeBlock(t *testing.T) {
// ─── looksLikeToolExampleContext ─────────────────────────────────────
func TestLooksLikeToolExampleContextChinese(t *testing.T) {
if !looksLikeToolExampleContext("下面是示例") {
t.Fatal("expected true for Chinese example context")
}
}
func TestLooksLikeToolExampleContextEnglish(t *testing.T) {
if !looksLikeToolExampleContext("here is an example of") {
t.Fatal("expected true for English example context")
}
}
func TestLooksLikeToolExampleContextNone(t *testing.T) {
if looksLikeToolExampleContext("I will call the tool now") {
t.Fatal("expected false for non-example context")