feat: Introduce standard request normalization and response building for OpenAI and Claude, enhance tool call streaming, and improve caller identification.

This commit is contained in:
CJACK
2026-02-18 23:35:17 +08:00
parent 3a75b75ae0
commit eb253a9d3a
18 changed files with 805 additions and 155 deletions

View File

@@ -309,7 +309,7 @@ data: [DONE]
### `GET /v1/responses/{response_id}`
Business auth required. Fetches cached responses created by `POST /v1/responses`.
Business auth required. Fetches cached responses created by `POST /v1/responses` (caller-scoped; only the same key/token can read).
> Backed by in-memory TTL store. Default TTL is `900s` (configurable via `responses.store_ttl_seconds`).

2
API.md
View File

@@ -309,7 +309,7 @@ data: [DONE]
### `GET /v1/responses/{response_id}`
需要业务鉴权。查询 `POST /v1/responses` 生成并缓存的 response 对象。
需要业务鉴权。查询 `POST /v1/responses` 生成并缓存的 response 对象(按调用方鉴权隔离,仅同一 key/token 可读取)
> 当前为内存 TTL 存储,默认过期时间 `900s`(可用 `responses.store_ttl_seconds` 调整)。

View File

@@ -63,32 +63,12 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
writeClaudeError(w, http.StatusBadRequest, "invalid json")
return
}
model, _ := req["model"].(string)
messagesRaw, _ := req["messages"].([]any)
if model == "" || len(messagesRaw) == 0 {
writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
norm, err := normalizeClaudeRequest(h.Store, req)
if err != nil {
writeClaudeError(w, http.StatusBadRequest, err.Error())
return
}
if _, ok := req["max_tokens"]; !ok {
req["max_tokens"] = 8192
}
normalized := normalizeClaudeMessages(messagesRaw)
payload := cloneMap(req)
payload["messages"] = normalized
toolsRequested, _ := req["tools"].([]any)
if len(toolsRequested) > 0 && !hasSystemMessage(normalized) {
payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalized...)
}
dsPayload := util.ConvertClaudeToDeepSeek(payload, h.Store)
dsModel, _ := dsPayload["model"].(string)
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel)
if !ok {
thinkingEnabled = false
searchEnabled = false
}
finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
stdReq := norm.Standard
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
@@ -100,14 +80,7 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW")
return
}
requestPayload := map[string]any{
"chat_session_id": sessionID,
"parent_message_id": nil,
"prompt": finalPrompt,
"ref_file_ids": []any{},
"thinking_enabled": thinkingEnabled,
"search_enabled": searchEnabled,
}
requestPayload := stdReq.CompletionPayload(sessionID)
resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3)
if err != nil {
writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.")
@@ -120,15 +93,14 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
return
}
toolNames := extractClaudeToolNames(toolsRequested)
if util.ToBool(req["stream"]) {
h.handleClaudeStreamRealtime(w, r, resp, model, normalized, thinkingEnabled, searchEnabled, toolNames)
if stdReq.Stream {
h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
return
}
result := sse.CollectStream(resp, thinkingEnabled, true)
result := sse.CollectStream(resp, stdReq.Thinking, true)
fullText := result.Text
fullThinking := result.Thinking
detected := util.ParseToolCalls(fullText, toolNames)
detected := util.ParseToolCalls(fullText, stdReq.ToolNames)
content := make([]map[string]any, 0, 4)
if fullThinking != "" {
content = append(content, map[string]any{"type": "thinking", "thinking": fullThinking})
@@ -154,12 +126,12 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
"id": fmt.Sprintf("msg_%d", time.Now().UnixNano()),
"type": "message",
"role": "assistant",
"model": model,
"model": stdReq.ResponseModel,
"content": content,
"stop_reason": stopReason,
"stop_sequence": nil,
"usage": map[string]any{
"input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalized)),
"input_tokens": util.EstimateTokens(fmt.Sprintf("%v", norm.NormalizedMessages)),
"output_tokens": util.EstimateTokens(fullThinking) + util.EstimateTokens(fullText),
},
})

View File

@@ -0,0 +1,58 @@
package claude
import (
"fmt"
"strings"
"ds2api/internal/config"
"ds2api/internal/util"
)
type claudeNormalizedRequest struct {
Standard util.StandardRequest
NormalizedMessages []any
}
// normalizeClaudeRequest validates an Anthropic Messages request, injects a
// synthetic tool-use system prompt when tools are requested without one,
// converts the payload to the DeepSeek wire format, and returns the standard
// request plus the normalized messages (reused by response builders for token
// accounting).
//
// It mutates req by defaulting max_tokens to 8192 when absent. A missing
// model or empty messages list yields an error suitable for a 400 response.
func normalizeClaudeRequest(store *config.Store, req map[string]any) (claudeNormalizedRequest, error) {
	model, _ := req["model"].(string)
	trimmedModel := strings.TrimSpace(model)
	messagesRaw, _ := req["messages"].([]any)
	if trimmedModel == "" || len(messagesRaw) == 0 {
		return claudeNormalizedRequest{}, fmt.Errorf("Request must include 'model' and 'messages'.")
	}
	// Claude clients may omit max_tokens; the upstream call needs a value.
	if _, ok := req["max_tokens"]; !ok {
		req["max_tokens"] = 8192
	}
	normalizedMessages := normalizeClaudeMessages(messagesRaw)
	payload := cloneMap(req)
	// Track the final message slice explicitly instead of re-asserting it out
	// of the payload map afterwards (avoids an unchecked type assertion).
	finalMessages := normalizedMessages
	toolsRequested, _ := req["tools"].([]any)
	if len(toolsRequested) > 0 && !hasSystemMessage(normalizedMessages) {
		finalMessages = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalizedMessages...)
	}
	payload["messages"] = finalMessages
	dsPayload := util.ConvertClaudeToDeepSeek(payload, store)
	dsModel, _ := dsPayload["model"].(string)
	thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel)
	if !ok {
		// Unknown model: fall back to the most conservative feature set.
		thinkingEnabled = false
		searchEnabled = false
	}
	finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
	toolNames := extractClaudeToolNames(toolsRequested)
	return claudeNormalizedRequest{
		Standard: util.StandardRequest{
			Surface:        "anthropic_messages",
			RequestedModel: trimmedModel,
			ResolvedModel:  dsModel,
			ResponseModel:  trimmedModel,
			Messages:       finalMessages,
			FinalPrompt:    finalPrompt,
			ToolNames:      toolNames,
			Stream:         util.ToBool(req["stream"]),
			Thinking:       thinkingEnabled,
			Search:         searchEnabled,
		},
		NormalizedMessages: normalizedMessages,
	}, nil
}

View File

@@ -0,0 +1,38 @@
package claude
import (
"testing"
"ds2api/internal/config"
)
// TestNormalizeClaudeRequest covers the happy path: a minimal streaming
// Claude request with one tool should resolve a model, preserve the stream
// flag, surface the tool names, and produce a non-empty final prompt.
func TestNormalizeClaudeRequest(t *testing.T) {
// Empty config JSON exercises the default model/config resolution path.
t.Setenv("DS2API_CONFIG_JSON", `{}`)
store := config.LoadStore()
req := map[string]any{
"model": "claude-opus-4-6",
"messages": []any{
map[string]any{"role": "user", "content": "hello"},
},
"stream": true,
"tools": []any{
map[string]any{"name": "search", "description": "Search"},
},
}
norm, err := normalizeClaudeRequest(store, req)
if err != nil {
t.Fatalf("normalize failed: %v", err)
}
if norm.Standard.ResolvedModel == "" {
t.Fatalf("expected resolved model")
}
if !norm.Standard.Stream {
t.Fatalf("expected stream=true")
}
if len(norm.Standard.ToolNames) == 0 {
t.Fatalf("expected tool names")
}
if norm.Standard.FinalPrompt == "" {
t.Fatalf("expected non-empty final prompt")
}
}

View File

@@ -0,0 +1,96 @@
package openai
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"ds2api/internal/account"
"ds2api/internal/auth"
"ds2api/internal/config"
)
// newResolverWithConfigJSON builds a config store and auth resolver from the
// given config JSON (injected via the DS2API_CONFIG_JSON env var), with a
// stubbed token provider so tests never perform real upstream authentication.
func newResolverWithConfigJSON(t *testing.T, cfgJSON string) (*config.Store, *auth.Resolver) {
t.Helper()
t.Setenv("DS2API_CONFIG_JSON", cfgJSON)
store := config.LoadStore()
pool := account.NewPool(store)
resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
// Token provider stub; the returned value is never used by these tests.
return "unused", nil
})
return store, resolver
}
// TestEmbeddingsRouteContract exercises /v1/embeddings end-to-end through the
// router: an unauthenticated request is rejected with 401, and an authorized
// batch request returns an OpenAI-style "list" object with one embedding per
// input element.
func TestEmbeddingsRouteContract(t *testing.T) {
store, resolver := newResolverWithConfigJSON(t, `{"embeddings":{"provider":"deterministic"}}`)
h := &Handler{Store: store, Auth: resolver}
r := chi.NewRouter()
RegisterRoutes(r, h)
t.Run("unauthorized", func(t *testing.T) {
// No Authorization header at all.
body := bytes.NewBufferString(`{"model":"gpt-4o","input":"hello"}`)
req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body)
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String())
}
})
t.Run("ok", func(t *testing.T) {
body := bytes.NewBufferString(`{"model":"gpt-4o","input":["a","b"]}`)
req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body)
req.Header.Set("Authorization", "Bearer test-token")
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
}
var out map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
t.Fatalf("decode response failed: %v", err)
}
if out["object"] != "list" {
t.Fatalf("unexpected object: %#v", out["object"])
}
// Two inputs must yield exactly two embedding entries.
data, _ := out["data"].([]any)
if len(data) != 2 {
t.Fatalf("expected 2 embeddings, got %d", len(data))
}
})
}
// TestEmbeddingsRouteProviderMissing verifies that when no embeddings
// provider is configured the route answers 501 with a structured OpenAI-style
// error object carrying both "code" and "param" fields.
func TestEmbeddingsRouteProviderMissing(t *testing.T) {
// Empty config: no embeddings provider available.
store, resolver := newResolverWithConfigJSON(t, `{}`)
h := &Handler{Store: store, Auth: resolver}
r := chi.NewRouter()
RegisterRoutes(r, h)
body := bytes.NewBufferString(`{"model":"gpt-4o","input":"hello"}`)
req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body)
req.Header.Set("Authorization", "Bearer test-token")
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusNotImplemented {
t.Fatalf("expected 501, got %d body=%s", rec.Code, rec.Body.String())
}
var out map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
t.Fatalf("decode response failed: %v", err)
}
errObj, _ := out["error"].(map[string]any)
if _, ok := errObj["code"]; !ok {
t.Fatalf("expected error.code in response: %#v", out)
}
if _, ok := errObj["param"]; !ok {
t.Fatalf("expected error.param in response: %#v", out)
}
}

View File

@@ -91,24 +91,11 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
return
}
model, _ := req["model"].(string)
messagesRaw, _ := req["messages"].([]any)
if model == "" || len(messagesRaw) == 0 {
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return
}
resolvedModel, ok := config.ResolveModel(h.Store, model)
if !ok {
writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model))
return
}
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
responseModel := strings.TrimSpace(model)
if responseModel == "" {
responseModel = resolvedModel
}
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
@@ -124,25 +111,17 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
return
}
payload := map[string]any{
"chat_session_id": sessionID,
"parent_message_id": nil,
"prompt": finalPrompt,
"ref_file_ids": []any{},
"thinking_enabled": thinkingEnabled,
"search_enabled": searchEnabled,
}
applyOpenAIChatPassThrough(req, payload)
payload := stdReq.CompletionPayload(sessionID)
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
if err != nil {
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
return
}
if util.ToBool(req["stream"]) {
h.handleStream(w, r, resp, sessionID, responseModel, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
if stdReq.Stream {
h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
return
}
h.handleNonStream(w, r.Context(), resp, sessionID, responseModel, finalPrompt, thinkingEnabled, toolNames)
h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
}
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
@@ -208,7 +187,8 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
created := time.Now().Unix()
firstChunkSent := false
bufferToolContent := len(toolNames) > 0
bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
var toolSieve toolStreamSieveState
toolCallsEmitted := false
streamToolCallIDs := map[int]string{}
@@ -377,6 +357,9 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
}
for _, evt := range events {
if len(evt.ToolCallDeltas) > 0 {
if !emitEarlyToolDeltas {
continue
}
toolCallsEmitted = true
tcDelta := map[string]any{
"tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs),
@@ -568,17 +551,23 @@ func openAIErrorCode(status int) string {
}
func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) {
for _, k := range []string{
"temperature",
"top_p",
"max_tokens",
"max_completion_tokens",
"presence_penalty",
"frequency_penalty",
"stop",
} {
if v, ok := req[k]; ok {
payload[k] = v
}
for k, v := range collectOpenAIChatPassThrough(req) {
payload[k] = v
}
}
// toolcallFeatureMatchEnabled reports whether tool-call feature matching is
// active. The default (nil handler/store or unset mode) is enabled; only an
// explicit mode other than "feature_match" disables it.
func (h *Handler) toolcallFeatureMatchEnabled() bool {
	if h == nil || h.Store == nil {
		return true
	}
	switch strings.ToLower(strings.TrimSpace(h.Store.ToolcallMode())) {
	case "", "feature_match":
		return true
	default:
		return false
	}
}
// toolcallEarlyEmitHighConfidence reports whether streaming tool-call deltas
// may be emitted early. The default (nil handler/store or unset level) is
// enabled; only an explicit level other than "high" disables it.
func (h *Handler) toolcallEarlyEmitHighConfidence() bool {
	if h == nil || h.Store == nil {
		return true
	}
	switch strings.ToLower(strings.TrimSpace(h.Store.ToolcallEarlyEmitConfidence())) {
	case "", "high":
		return true
	default:
		return false
	}
}

View File

@@ -3,9 +3,12 @@ package openai
import (
"sync"
"time"
"ds2api/internal/auth"
)
type storedResponse struct {
Owner string
Value map[string]any
ExpiresAt time.Time
}
@@ -26,32 +29,47 @@ func newResponseStore(ttl time.Duration) *responseStore {
}
}
func (s *responseStore) put(id string, value map[string]any) {
if s == nil || id == "" || value == nil {
// responseStoreKey builds the composite map key for a caller-scoped response.
// NUL cannot occur inside either component, so the join is unambiguous.
func responseStoreKey(owner, id string) string {
	const sep = "\x00"
	return owner + sep + id
}
// responseStoreOwner derives the tenant key for the response store from the
// request's authenticated caller. A nil auth maps to the empty string, which
// callers treat as unauthorized.
func responseStoreOwner(a *auth.RequestAuth) string {
	if a != nil {
		return a.CallerID
	}
	return ""
}
// put stores a response object under the composite (owner, id) key so only
// the same caller can read it back. Nil stores, empty owner/id, and nil
// values are ignored. Expired entries are swept opportunistically on every
// write, and the value is deep-copied before storage.
func (s *responseStore) put(owner, id string, value map[string]any) {
if s == nil || owner == "" || id == "" || value == nil {
return
}
now := time.Now()
s.mu.Lock()
defer s.mu.Unlock()
s.sweepLocked(now)
s.items[responseStoreKey(owner, id)] = storedResponse{
Owner: owner,
// Clone so later mutations by the caller cannot corrupt the cached copy.
Value: cloneAnyMap(value),
ExpiresAt: now.Add(s.ttl),
}
}
func (s *responseStore) get(id string) (map[string]any, bool) {
if s == nil || id == "" {
func (s *responseStore) get(owner, id string) (map[string]any, bool) {
if s == nil || owner == "" || id == "" {
return nil, false
}
now := time.Now()
s.mu.Lock()
defer s.mu.Unlock()
s.sweepLocked(now)
item, ok := s.items[id]
item, ok := s.items[responseStoreKey(owner, id)]
if !ok {
return nil, false
}
if item.Owner != owner {
return nil, false
}
return cloneAnyMap(item.Value), true
}

View File

@@ -54,8 +54,8 @@ func TestDeterministicEmbeddingStable(t *testing.T) {
func TestResponseStorePutGet(t *testing.T) {
st := newResponseStore(100 * time.Millisecond)
st.put("resp_1", map[string]any{"id": "resp_1"})
got, ok := st.get("resp_1")
st.put("owner_1", "resp_1", map[string]any{"id": "resp_1"})
got, ok := st.get("owner_1", "resp_1")
if !ok {
t.Fatal("expected stored response")
}
@@ -63,3 +63,11 @@ func TestResponseStorePutGet(t *testing.T) {
t.Fatalf("unexpected response payload: %#v", got)
}
}
// TestResponseStoreTenantIsolation: a response stored by one owner must be
// invisible to lookups by any other owner.
func TestResponseStoreTenantIsolation(t *testing.T) {
	store := newResponseStore(100 * time.Millisecond)
	store.put("owner_a", "resp_1", map[string]any{"id": "resp_1"})
	if _, visible := store.get("owner_b", "resp_1"); visible {
		t.Fatal("expected owner_b to be isolated from owner_a response")
	}
}

View File

@@ -12,19 +12,35 @@ import (
"github.com/google/uuid"
"ds2api/internal/auth"
"ds2api/internal/config"
"ds2api/internal/sse"
"ds2api/internal/util"
)
func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) {
a, err := h.Auth.Determine(r)
if err != nil {
status := http.StatusUnauthorized
detail := err.Error()
if err == auth.ErrNoAccount {
status = http.StatusTooManyRequests
}
writeOpenAIError(w, status, detail)
return
}
defer h.Auth.Release(a)
id := strings.TrimSpace(chi.URLParam(r, "response_id"))
if id == "" {
writeOpenAIError(w, http.StatusBadRequest, "response_id is required.")
return
}
owner := responseStoreOwner(a)
if owner == "" {
writeOpenAIError(w, http.StatusUnauthorized, "unauthorized")
return
}
st := h.getResponseStore()
item, ok := st.get(id)
item, ok := st.get(owner, id)
if !ok {
writeOpenAIError(w, http.StatusNotFound, "Response not found.")
return
@@ -45,32 +61,22 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
}
defer h.Auth.Release(a)
r = r.WithContext(auth.WithAuth(r.Context(), a))
owner := responseStoreOwner(a)
if owner == "" {
writeOpenAIError(w, http.StatusUnauthorized, "unauthorized")
return
}
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
return
}
model, _ := req["model"].(string)
model = strings.TrimSpace(model)
if model == "" {
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model'.")
stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req)
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return
}
resolvedModel, ok := config.ResolveModel(h.Store, model)
if !ok {
writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model))
return
}
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
messagesRaw := responsesMessagesFromRequest(req)
if len(messagesRaw) == 0 {
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'input' or 'messages'.")
return
}
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
@@ -86,15 +92,7 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
return
}
payload := map[string]any{
"chat_session_id": sessionID,
"parent_message_id": nil,
"prompt": finalPrompt,
"ref_file_ids": []any{},
"thinking_enabled": thinkingEnabled,
"search_enabled": searchEnabled,
}
applyOpenAIChatPassThrough(req, payload)
payload := stdReq.CompletionPayload(sessionID)
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
if err != nil {
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
@@ -102,14 +100,14 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
}
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
if util.ToBool(req["stream"]) {
h.handleResponsesStream(w, r, resp, responseID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
if stdReq.Stream {
h.handleResponsesStream(w, r, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
return
}
h.handleResponsesNonStream(w, resp, responseID, model, finalPrompt, thinkingEnabled, toolNames)
h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
}
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
@@ -118,11 +116,11 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
}
result := sse.CollectStream(resp, thinkingEnabled, true)
responseObj := buildResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames)
h.getResponseStore().put(responseID, responseObj)
h.getResponseStore().put(owner, responseID, responseObj)
writeJSON(w, http.StatusOK, responseObj)
}
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
@@ -160,7 +158,8 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
initialType = "thinking"
}
parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType)
bufferToolContent := len(toolNames) > 0
bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
var sieve toolStreamSieveState
thinking := strings.Builder{}
text := strings.Builder{}
@@ -194,7 +193,7 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
if toolCallsEmitted {
obj["status"] = "completed"
}
h.getResponseStore().put(responseID, obj)
h.getResponseStore().put(owner, responseID, obj)
sendEvent("response.completed", map[string]any{
"type": "response.completed",
"response": obj,
@@ -259,6 +258,9 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
})
}
if len(evt.ToolCallDeltas) > 0 {
if !emitEarlyToolDeltas {
continue
}
toolCallsEmitted = true
sendEvent("response.output_tool_call.delta", map[string]any{
"type": "response.output_tool_call.delta",

View File

@@ -0,0 +1,125 @@
package openai
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"ds2api/internal/account"
"ds2api/internal/auth"
"ds2api/internal/config"
)
// newDirectTokenResolver builds a store with no configured API keys or
// accounts, so any bearer token presented to the resolver is treated as a
// direct upstream token and becomes its own caller identity.
func newDirectTokenResolver(t *testing.T) (*config.Store, *auth.Resolver) {
t.Helper()
t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`)
store := config.LoadStore()
pool := account.NewPool(store)
resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
// Token provider stub; never exercised with an empty account list.
return "unused", nil
})
return store, resolver
}
// authForToken resolves a RequestAuth for the given bearer token using a
// synthetic request, failing the test if resolution errors.
func authForToken(t *testing.T, resolver *auth.Resolver, token string) *auth.RequestAuth {
t.Helper()
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
req.Header.Set("Authorization", "Bearer "+token)
a, err := resolver.Determine(req)
if err != nil {
t.Fatalf("determine auth failed: %v", err)
}
return a
}
// TestGetResponseByIDRequiresAuthAndIsTenantIsolated seeds the response store
// for the caller behind token-a and then checks the GET route's contract:
// no credentials -> 401, a different caller -> 404 (isolation), and the
// original caller -> 200 with the stored object.
func TestGetResponseByIDRequiresAuthAndIsTenantIsolated(t *testing.T) {
store, resolver := newDirectTokenResolver(t)
h := &Handler{Store: store, Auth: resolver}
r := chi.NewRouter()
RegisterRoutes(r, h)
// Seed a response owned by the caller identity derived from token-a.
ownerA := responseStoreOwner(authForToken(t, resolver, "token-a"))
h.getResponseStore().put(ownerA, "resp_test", map[string]any{
"id": "resp_test",
"object": "response",
})
t.Run("unauthorized", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String())
}
})
t.Run("cross-tenant-not-found", func(t *testing.T) {
// token-b must not be able to see token-a's response.
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
req.Header.Set("Authorization", "Bearer token-b")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d body=%s", rec.Code, rec.Body.String())
}
})
t.Run("same-tenant-ok", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
req.Header.Set("Authorization", "Bearer token-a")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("decode body failed: %v", err)
}
if body["id"] != "resp_test" {
t.Fatalf("unexpected body: %#v", body)
}
})
}
// TestResponsesRouteValidationContract checks that invalid POST /v1/responses
// bodies (missing model, or missing both input and messages) are rejected
// with 400 and a structured error object containing "code" and "param".
func TestResponsesRouteValidationContract(t *testing.T) {
store, resolver := newDirectTokenResolver(t)
h := &Handler{Store: store, Auth: resolver}
r := chi.NewRouter()
RegisterRoutes(r, h)
tests := []struct {
name string
body string
}{
{name: "missing_model", body: `{"input":"hello"}`},
{name: "missing_input_and_messages", body: `{"model":"gpt-4o"}`},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewBufferString(tc.body))
req.Header.Set("Authorization", "Bearer token-a")
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String())
}
var out map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
t.Fatalf("decode response failed: %v", err)
}
errObj, _ := out["error"].(map[string]any)
if _, ok := errObj["code"]; !ok {
t.Fatalf("expected error.code: %#v", out)
}
if _, ok := errObj["param"]; !ok {
t.Fatalf("expected error.param: %#v", out)
}
})
}
}

View File

@@ -0,0 +1,104 @@
package openai
import (
"fmt"
"strings"
"ds2api/internal/config"
"ds2api/internal/util"
)
// normalizeOpenAIChatRequest validates an OpenAI chat-completions request and
// flattens it into the transport-agnostic StandardRequest: resolved upstream
// model, feature flags, the prepared prompt, tool names, and the whitelisted
// pass-through sampling parameters.
//
// It returns an error (suitable for a 400 response) when model/messages are
// missing or the model cannot be resolved.
func normalizeOpenAIChatRequest(store *config.Store, req map[string]any) (util.StandardRequest, error) {
	model, _ := req["model"].(string)
	trimmedModel := strings.TrimSpace(model)
	messagesRaw, _ := req["messages"].([]any)
	if trimmedModel == "" || len(messagesRaw) == 0 {
		return util.StandardRequest{}, fmt.Errorf("Request must include 'model' and 'messages'.")
	}
	resolvedModel, ok := config.ResolveModel(store, model)
	if !ok {
		return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", model)
	}
	thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
	finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
	return util.StandardRequest{
		Surface:        "openai_chat",
		RequestedModel: trimmedModel,
		// trimmedModel is known non-empty here, so it is always the model
		// echoed back to the client (the old "" fallback was unreachable).
		ResolvedModel: resolvedModel,
		ResponseModel: trimmedModel,
		Messages:      messagesRaw,
		FinalPrompt:   finalPrompt,
		ToolNames:     toolNames,
		Stream:        util.ToBool(req["stream"]),
		Thinking:      thinkingEnabled,
		Search:        searchEnabled,
		PassThrough:   collectOpenAIChatPassThrough(req),
	}, nil
}
// normalizeOpenAIResponsesRequest validates a /v1/responses request and
// flattens it into a StandardRequest. Unlike chat completions, the input may
// arrive either as `messages` or as the wider `input`/`instructions` shape;
// which forms are accepted is governed by the store's
// CompatWideInputStrictOutput policy.
func normalizeOpenAIResponsesRequest(store *config.Store, req map[string]any) (util.StandardRequest, error) {
model, _ := req["model"].(string)
model = strings.TrimSpace(model)
if model == "" {
return util.StandardRequest{}, fmt.Errorf("Request must include 'model'.")
}
resolvedModel, ok := config.ResolveModel(store, model)
if !ok {
return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", model)
}
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
// Keep width-control as an explicit policy hook even if current default is true.
allowWideInput := true
if store != nil {
allowWideInput = store.CompatWideInputStrictOutput()
}
var messagesRaw []any
if allowWideInput {
// Wide mode: accept `input`/`instructions` in addition to `messages`.
messagesRaw = responsesMessagesFromRequest(req)
} else if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 {
// Strict mode: only an explicit non-empty `messages` array is accepted.
messagesRaw = msgs
}
if len(messagesRaw) == 0 {
return util.StandardRequest{}, fmt.Errorf("Request must include 'input' or 'messages'.")
}
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
passThrough := collectOpenAIChatPassThrough(req)
return util.StandardRequest{
Surface: "openai_responses",
RequestedModel: model,
ResolvedModel: resolvedModel,
ResponseModel: model,
Messages: messagesRaw,
FinalPrompt: finalPrompt,
ToolNames: toolNames,
Stream: util.ToBool(req["stream"]),
Thinking: thinkingEnabled,
Search: searchEnabled,
PassThrough: passThrough,
}, nil
}
// collectOpenAIChatPassThrough extracts the whitelisted sampling/limit knobs
// from an OpenAI-style request so they can be forwarded verbatim to the
// upstream payload. Absent keys are simply omitted; the result is never nil.
func collectOpenAIChatPassThrough(req map[string]any) map[string]any {
	keys := []string{
		"temperature",
		"top_p",
		"max_tokens",
		"max_completion_tokens",
		"presence_penalty",
		"frequency_penalty",
		"stop",
	}
	picked := make(map[string]any, len(keys))
	for _, key := range keys {
		if value, present := req[key]; present {
			picked[key] = value
		}
	}
	return picked
}

View File

@@ -0,0 +1,60 @@
package openai
import (
"testing"
"ds2api/internal/config"
)
// newEmptyStoreForNormalizeTest loads a config store from an empty JSON
// config, exercising the default model resolution used by the normalizers.
func newEmptyStoreForNormalizeTest(t *testing.T) *config.Store {
t.Helper()
t.Setenv("DS2API_CONFIG_JSON", `{}`)
return config.LoadStore()
}
// TestNormalizeOpenAIChatRequest checks the chat normalizer's happy path:
// the default config maps "gpt-5-codex" to deepseek-reasoner, the stream
// flag is preserved, temperature is captured as pass-through, and a final
// prompt is produced.
func TestNormalizeOpenAIChatRequest(t *testing.T) {
store := newEmptyStoreForNormalizeTest(t)
req := map[string]any{
"model": "gpt-5-codex",
"messages": []any{
map[string]any{"role": "user", "content": "hello"},
},
"temperature": 0.3,
"stream": true,
}
n, err := normalizeOpenAIChatRequest(store, req)
if err != nil {
t.Fatalf("normalize failed: %v", err)
}
if n.ResolvedModel != "deepseek-reasoner" {
t.Fatalf("unexpected resolved model: %s", n.ResolvedModel)
}
if !n.Stream {
t.Fatalf("expected stream=true")
}
if _, ok := n.PassThrough["temperature"]; !ok {
t.Fatalf("expected temperature passthrough")
}
if n.FinalPrompt == "" {
t.Fatalf("expected non-empty final prompt")
}
}
// TestNormalizeOpenAIResponsesRequestInput checks the wide-input path: a
// string `input` plus `instructions` (no `messages`) should be expanded into
// two normalized messages, with "gpt-4o" resolving to deepseek-chat under the
// default config.
func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) {
store := newEmptyStoreForNormalizeTest(t)
req := map[string]any{
"model": "gpt-4o",
"input": "ping",
"instructions": "system",
}
n, err := normalizeOpenAIResponsesRequest(store, req)
if err != nil {
t.Fatalf("normalize failed: %v", err)
}
if n.ResolvedModel != "deepseek-chat" {
t.Fatalf("unexpected resolved model: %s", n.ResolvedModel)
}
if len(n.Messages) != 2 {
t.Fatalf("expected 2 normalized messages, got %d", len(n.Messages))
}
}

View File

@@ -56,24 +56,15 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
return
}
model, _ := req["model"].(string)
messagesRaw, _ := req["messages"].([]any)
if model == "" || len(messagesRaw) == 0 {
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
if err != nil {
writeOpenAIError(w, http.StatusBadRequest, err.Error())
return
}
resolvedModel, ok := config.ResolveModel(h.Store, model)
if !ok {
writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model))
if !stdReq.Stream {
writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
return
}
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
responseModel := strings.TrimSpace(model)
if responseModel == "" {
responseModel = resolvedModel
}
finalPrompt, _ := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
if err != nil {
@@ -94,15 +85,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
return
}
payload := map[string]any{
"chat_session_id": sessionID,
"parent_message_id": nil,
"prompt": finalPrompt,
"ref_file_ids": []any{},
"thinking_enabled": thinkingEnabled,
"search_enabled": searchEnabled,
}
applyOpenAIChatPassThrough(req, payload)
payload := stdReq.CompletionPayload(sessionID)
leaseID := h.holdStreamLease(a)
if leaseID == "" {
writeOpenAIError(w, http.StatusInternalServerError, "failed to create stream lease")
@@ -112,10 +95,10 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
writeJSON(w, http.StatusOK, map[string]any{
"session_id": sessionID,
"lease_id": leaseID,
"model": responseModel,
"final_prompt": finalPrompt,
"thinking_enabled": thinkingEnabled,
"search_enabled": searchEnabled,
"model": stdReq.ResponseModel,
"final_prompt": stdReq.FinalPrompt,
"thinking_enabled": stdReq.Thinking,
"search_enabled": stdReq.Search,
"deepseek_token": a.DeepSeekToken,
"pow_header": powHeader,
"payload": payload,

View File

@@ -2,6 +2,8 @@ package auth
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"net/http"
"strings"
@@ -22,6 +24,7 @@ var (
type RequestAuth struct {
UseConfigToken bool
DeepSeekToken string
CallerID string
AccountID string
Account config.Account
TriedAccounts map[string]bool
@@ -45,9 +48,16 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
if callerKey == "" {
return nil, ErrUnauthorized
}
callerID := callerTokenID(callerKey)
ctx := req.Context()
if !r.Store.HasAPIKey(callerKey) {
return &RequestAuth{UseConfigToken: false, DeepSeekToken: callerKey, resolver: r, TriedAccounts: map[string]bool{}}, nil
return &RequestAuth{
UseConfigToken: false,
DeepSeekToken: callerKey,
CallerID: callerID,
resolver: r,
TriedAccounts: map[string]bool{},
}, nil
}
target := strings.TrimSpace(req.Header.Get("X-Ds2-Target-Account"))
acc, ok := r.Pool.AcquireWait(ctx, target, nil)
@@ -56,6 +66,7 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
}
a := &RequestAuth{
UseConfigToken: true,
CallerID: callerID,
AccountID: acc.Identifier(),
Account: acc,
TriedAccounts: map[string]bool{},
@@ -158,3 +169,12 @@ func extractCallerToken(req *http.Request) string {
}
return strings.TrimSpace(req.Header.Get("x-api-key"))
}
func callerTokenID(token string) string {
token = strings.TrimSpace(token)
if token == "" {
return ""
}
sum := sha256.Sum256([]byte(token))
return "caller:" + hex.EncodeToString(sum[:8])
}

View File

@@ -37,6 +37,9 @@ func TestDetermineWithXAPIKeyUsesDirectToken(t *testing.T) {
if auth.DeepSeekToken != "direct-token" {
t.Fatalf("unexpected token: %q", auth.DeepSeekToken)
}
if auth.CallerID == "" {
t.Fatalf("expected caller id to be populated")
}
}
func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) {
@@ -58,6 +61,24 @@ func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) {
if auth.DeepSeekToken != "account-token" {
t.Fatalf("unexpected account token: %q", auth.DeepSeekToken)
}
if auth.CallerID == "" {
t.Fatalf("expected caller id to be populated")
}
}
// TestCallerTokenIDStable checks that callerTokenID is deterministic for the
// same token and produces distinct ids for different tokens.
func TestCallerTokenIDStable(t *testing.T) {
	first := callerTokenID("token-a")
	repeat := callerTokenID("token-a")
	other := callerTokenID("token-b")
	for _, id := range []string{first, repeat, other} {
		if id == "" {
			t.Fatalf("expected non-empty caller ids")
		}
	}
	if first != repeat {
		t.Fatalf("expected stable caller id, got %q and %q", first, repeat)
	}
	if first == other {
		t.Fatalf("expected different caller id for different tokens")
	}
}
func TestDetermineMissingToken(t *testing.T) {

View File

@@ -755,11 +755,15 @@ func (r *Runner) cases() []caseDef {
{ID: "healthz_ok", Run: r.caseHealthz},
{ID: "readyz_ok", Run: r.caseReadyz},
{ID: "models_openai", Run: r.caseModelsOpenAI},
{ID: "model_openai_by_id", Run: r.caseModelOpenAIByID},
{ID: "models_claude", Run: r.caseModelsClaude},
{ID: "admin_login_verify", Run: r.caseAdminLoginVerify},
{ID: "admin_queue_status", Run: r.caseAdminQueueStatus},
{ID: "chat_nonstream_basic", Run: r.caseChatNonstream},
{ID: "chat_stream_basic", Run: r.caseChatStream},
{ID: "responses_nonstream_basic", Run: r.caseResponsesNonstream},
{ID: "responses_stream_basic", Run: r.caseResponsesStream},
{ID: "embeddings_contract", Run: r.caseEmbeddings},
{ID: "reasoner_stream", Run: r.caseReasonerStream},
{ID: "toolcall_nonstream", Run: r.caseToolcallNonstream},
{ID: "toolcall_stream", Run: r.caseToolcallStream},
@@ -817,6 +821,19 @@ func (r *Runner) caseModelsOpenAI(ctx context.Context, cc *caseContext) error {
return nil
}
// caseModelOpenAIByID fetches a single model via GET /v1/models/gpt-4o and
// verifies the response is an OpenAI "model" object whose id is the mapped
// upstream model ("deepseek-chat").
func (r *Runner) caseModelOpenAIByID(ctx context.Context, cc *caseContext) error {
	spec := requestSpec{Method: http.MethodGet, Path: "/v1/models/gpt-4o", Retryable: true}
	resp, err := cc.request(ctx, spec)
	if err != nil {
		return err
	}
	cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
	var body map[string]any
	_ = json.Unmarshal(resp.Body, &body)
	detail := fmt.Sprintf("body=%s", string(resp.Body))
	cc.assert("object_model", asString(body["object"]) == "model", detail)
	cc.assert("id_deepseek_chat", asString(body["id"]) == "deepseek-chat", detail)
	return nil
}
func (r *Runner) caseModelsClaude(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/anthropic/v1/models", Retryable: true})
if err != nil {
@@ -942,6 +959,115 @@ func (r *Runner) caseChatStream(ctx context.Context, cc *caseContext) error {
return nil
}
// caseResponsesNonstream creates a non-streaming response via POST
// /v1/responses, then reads it back through GET /v1/responses/{id} using the
// same API key (the store is caller-scoped).
func (r *Runner) caseResponsesNonstream(ctx context.Context, cc *caseContext) error {
	createSpec := requestSpec{
		Method: http.MethodPost,
		Path:   "/v1/responses",
		Headers: map[string]string{
			"Authorization": "Bearer " + r.apiKey,
		},
		Body: map[string]any{
			"model": "gpt-4o",
			"input": "请简要回答 hello",
		},
		Retryable: true,
	}
	resp, err := cc.request(ctx, createSpec)
	if err != nil {
		return err
	}
	cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
	var parsed map[string]any
	_ = json.Unmarshal(resp.Body, &parsed)
	cc.assert("object_response", asString(parsed["object"]) == "response", fmt.Sprintf("body=%s", string(resp.Body)))
	id := asString(parsed["id"])
	cc.assert("response_id_present", id != "", fmt.Sprintf("body=%s", string(resp.Body)))
	if id == "" {
		return nil
	}
	getResp, err := cc.request(ctx, requestSpec{
		Method: http.MethodGet,
		Path:   "/v1/responses/" + id,
		Headers: map[string]string{
			"Authorization": "Bearer " + r.apiKey,
		},
		Retryable: true,
	})
	if err != nil {
		return err
	}
	cc.assert("get_status_200", getResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", getResp.StatusCode))
	return nil
}
// caseResponsesStream exercises streaming POST /v1/responses and verifies the
// SSE lifecycle: at least one frame, a response.created and a
// response.completed event, and a terminating [DONE].
func (r *Runner) caseResponsesStream(ctx context.Context, cc *caseContext) error {
	resp, err := cc.request(ctx, requestSpec{
		Method: http.MethodPost,
		Path:   "/v1/responses",
		Headers: map[string]string{
			"Authorization": "Bearer " + r.apiKey,
		},
		Body: map[string]any{
			"model":  "gpt-4o",
			"input":  "请流式回答 hello",
			"stream": true,
		},
		Stream:    true,
		Retryable: true,
	})
	if err != nil {
		return err
	}
	cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
	frames, done := parseSSEFrames(resp.Body)
	cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames)))
	seen := map[string]bool{}
	for _, frame := range frames {
		seen[asString(frame["type"])] = true
	}
	cc.assert("has_response_created", seen["response.created"], fmt.Sprintf("body=%s", string(resp.Body)))
	cc.assert("has_response_completed", seen["response.completed"], fmt.Sprintf("body=%s", string(resp.Body)))
	cc.assert("done_terminated", done, "expected [DONE]")
	return nil
}
// caseEmbeddings posts to /v1/embeddings and accepts either outcome: 200 with
// an OpenAI-style list of vectors, or 501 with a structured error object that
// carries both "code" and "param" fields.
func (r *Runner) caseEmbeddings(ctx context.Context, cc *caseContext) error {
	resp, err := cc.request(ctx, requestSpec{
		Method: http.MethodPost,
		Path:   "/v1/embeddings",
		Headers: map[string]string{
			"Authorization": "Bearer " + r.apiKey,
		},
		Body: map[string]any{
			"model": "gpt-4o",
			"input": []string{"hello", "world"},
		},
		Retryable: true,
	})
	if err != nil {
		return err
	}
	implemented := resp.StatusCode == http.StatusOK
	cc.assert("status_200_or_501", implemented || resp.StatusCode == http.StatusNotImplemented, fmt.Sprintf("status=%d", resp.StatusCode))
	var parsed map[string]any
	_ = json.Unmarshal(resp.Body, &parsed)
	if implemented {
		cc.assert("object_list", asString(parsed["object"]) == "list", fmt.Sprintf("body=%s", string(resp.Body)))
		rows, _ := parsed["data"].([]any)
		cc.assert("data_non_empty", len(rows) > 0, fmt.Sprintf("body=%s", string(resp.Body)))
		return nil
	}
	errObj, _ := parsed["error"].(map[string]any)
	_, hasCode := errObj["code"]
	_, hasParam := errObj["param"]
	cc.assert("error_has_code", hasCode, fmt.Sprintf("body=%s", string(resp.Body)))
	cc.assert("error_has_param", hasParam, fmt.Sprintf("body=%s", string(resp.Body)))
	return nil
}
func (r *Runner) caseReasonerStream(ctx context.Context, cc *caseContext) error {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodPost,

View File

@@ -0,0 +1,30 @@
package util
// StandardRequest is the normalized form of an inbound completion request
// after model resolution and message normalization, shared by the request
// surfaces that build upstream payloads from it.
type StandardRequest struct {
	Surface        string         // identifier of the API surface that produced the request
	RequestedModel string         // model name exactly as sent by the caller
	ResolvedModel  string         // upstream model the request was resolved to
	ResponseModel  string         // model name to echo back in responses
	Messages       []any          // normalized message list
	FinalPrompt    string         // flattened prompt text sent upstream
	ToolNames      []string       // names of tools declared by the caller
	Stream         bool           // whether the caller requested streaming
	Thinking       bool           // thinking/reasoning mode flag for the resolved model
	Search         bool           // search mode flag for the resolved model
	PassThrough    map[string]any // extra fields copied verbatim into the upstream payload
}

// CompletionPayload assembles the upstream completion payload for the given
// chat session. Pass-through fields are merged last and therefore take
// precedence over the base keys if they collide.
func (r StandardRequest) CompletionPayload(sessionID string) map[string]any {
	base := map[string]any{
		"chat_session_id":   sessionID,
		"parent_message_id": nil,
		"prompt":            r.FinalPrompt,
		"ref_file_ids":      []any{},
		"thinking_enabled":  r.Thinking,
		"search_enabled":    r.Search,
	}
	for key, value := range r.PassThrough {
		base[key] = value
	}
	return base
}