diff --git a/API.en.md b/API.en.md index babd1dc..ef1a6f3 100644 --- a/API.en.md +++ b/API.en.md @@ -309,7 +309,7 @@ data: [DONE] ### `GET /v1/responses/{response_id}` -Business auth required. Fetches cached responses created by `POST /v1/responses`. +Business auth required. Fetches cached responses created by `POST /v1/responses` (caller-scoped; only the same key/token can read). > Backed by in-memory TTL store. Default TTL is `900s` (configurable via `responses.store_ttl_seconds`). diff --git a/API.md b/API.md index fa07cfa..3770924 100644 --- a/API.md +++ b/API.md @@ -309,7 +309,7 @@ data: [DONE] ### `GET /v1/responses/{response_id}` -需要业务鉴权。查询 `POST /v1/responses` 生成并缓存的 response 对象。 +需要业务鉴权。查询 `POST /v1/responses` 生成并缓存的 response 对象(按调用方鉴权隔离,仅同一 key/token 可读取)。 > 当前为内存 TTL 存储,默认过期时间 `900s`(可用 `responses.store_ttl_seconds` 调整)。 diff --git a/internal/adapter/claude/handler.go b/internal/adapter/claude/handler.go index a7d3431..44240af 100644 --- a/internal/adapter/claude/handler.go +++ b/internal/adapter/claude/handler.go @@ -63,32 +63,12 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { writeClaudeError(w, http.StatusBadRequest, "invalid json") return } - model, _ := req["model"].(string) - messagesRaw, _ := req["messages"].([]any) - if model == "" || len(messagesRaw) == 0 { - writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") + norm, err := normalizeClaudeRequest(h.Store, req) + if err != nil { + writeClaudeError(w, http.StatusBadRequest, err.Error()) return } - if _, ok := req["max_tokens"]; !ok { - req["max_tokens"] = 8192 - } - - normalized := normalizeClaudeMessages(messagesRaw) - payload := cloneMap(req) - payload["messages"] = normalized - toolsRequested, _ := req["tools"].([]any) - if len(toolsRequested) > 0 && !hasSystemMessage(normalized) { - payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalized...) 
- } - - dsPayload := util.ConvertClaudeToDeepSeek(payload, h.Store) - dsModel, _ := dsPayload["model"].(string) - thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel) - if !ok { - thinkingEnabled = false - searchEnabled = false - } - finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"])) + stdReq := norm.Standard sessionID, err := h.DS.CreateSession(r.Context(), a, 3) if err != nil { @@ -100,14 +80,7 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW") return } - requestPayload := map[string]any{ - "chat_session_id": sessionID, - "parent_message_id": nil, - "prompt": finalPrompt, - "ref_file_ids": []any{}, - "thinking_enabled": thinkingEnabled, - "search_enabled": searchEnabled, - } + requestPayload := stdReq.CompletionPayload(sessionID) resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3) if err != nil { writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.") @@ -120,15 +93,14 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { return } - toolNames := extractClaudeToolNames(toolsRequested) - if util.ToBool(req["stream"]) { - h.handleClaudeStreamRealtime(w, r, resp, model, normalized, thinkingEnabled, searchEnabled, toolNames) + if stdReq.Stream { + h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) return } - result := sse.CollectStream(resp, thinkingEnabled, true) + result := sse.CollectStream(resp, stdReq.Thinking, true) fullText := result.Text fullThinking := result.Thinking - detected := util.ParseToolCalls(fullText, toolNames) + detected := util.ParseToolCalls(fullText, stdReq.ToolNames) content := make([]map[string]any, 0, 4) if fullThinking != "" { content = append(content, map[string]any{"type": "thinking", "thinking": fullThinking}) @@ -154,12 +126,12 @@ func (h *Handler) Messages(w 
http.ResponseWriter, r *http.Request) { "id": fmt.Sprintf("msg_%d", time.Now().UnixNano()), "type": "message", "role": "assistant", - "model": model, + "model": stdReq.ResponseModel, "content": content, "stop_reason": stopReason, "stop_sequence": nil, "usage": map[string]any{ - "input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalized)), + "input_tokens": util.EstimateTokens(fmt.Sprintf("%v", norm.NormalizedMessages)), "output_tokens": util.EstimateTokens(fullThinking) + util.EstimateTokens(fullText), }, }) diff --git a/internal/adapter/claude/standard_request.go b/internal/adapter/claude/standard_request.go new file mode 100644 index 0000000..de97c6a --- /dev/null +++ b/internal/adapter/claude/standard_request.go @@ -0,0 +1,58 @@ +package claude + +import ( + "fmt" + "strings" + + "ds2api/internal/config" + "ds2api/internal/util" +) + +type claudeNormalizedRequest struct { + Standard util.StandardRequest + NormalizedMessages []any +} + +func normalizeClaudeRequest(store *config.Store, req map[string]any) (claudeNormalizedRequest, error) { + model, _ := req["model"].(string) + messagesRaw, _ := req["messages"].([]any) + if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 { + return claudeNormalizedRequest{}, fmt.Errorf("Request must include 'model' and 'messages'.") + } + if _, ok := req["max_tokens"]; !ok { + req["max_tokens"] = 8192 + } + normalizedMessages := normalizeClaudeMessages(messagesRaw) + payload := cloneMap(req) + payload["messages"] = normalizedMessages + toolsRequested, _ := req["tools"].([]any) + if len(toolsRequested) > 0 && !hasSystemMessage(normalizedMessages) { + payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalizedMessages...) 
+ } + + dsPayload := util.ConvertClaudeToDeepSeek(payload, store) + dsModel, _ := dsPayload["model"].(string) + thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel) + if !ok { + thinkingEnabled = false + searchEnabled = false + } + finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"])) + toolNames := extractClaudeToolNames(toolsRequested) + + return claudeNormalizedRequest{ + Standard: util.StandardRequest{ + Surface: "anthropic_messages", + RequestedModel: strings.TrimSpace(model), + ResolvedModel: dsModel, + ResponseModel: strings.TrimSpace(model), + Messages: payload["messages"].([]any), + FinalPrompt: finalPrompt, + ToolNames: toolNames, + Stream: util.ToBool(req["stream"]), + Thinking: thinkingEnabled, + Search: searchEnabled, + }, + NormalizedMessages: normalizedMessages, + }, nil +} diff --git a/internal/adapter/claude/standard_request_test.go b/internal/adapter/claude/standard_request_test.go new file mode 100644 index 0000000..7ffdfb8 --- /dev/null +++ b/internal/adapter/claude/standard_request_test.go @@ -0,0 +1,38 @@ +package claude + +import ( + "testing" + + "ds2api/internal/config" +) + +func TestNormalizeClaudeRequest(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{}`) + store := config.LoadStore() + req := map[string]any{ + "model": "claude-opus-4-6", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + "stream": true, + "tools": []any{ + map[string]any{"name": "search", "description": "Search"}, + }, + } + norm, err := normalizeClaudeRequest(store, req) + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + if norm.Standard.ResolvedModel == "" { + t.Fatalf("expected resolved model") + } + if !norm.Standard.Stream { + t.Fatalf("expected stream=true") + } + if len(norm.Standard.ToolNames) == 0 { + t.Fatalf("expected tool names") + } + if norm.Standard.FinalPrompt == "" { + t.Fatalf("expected non-empty final prompt") + } +} diff --git 
a/internal/adapter/openai/embeddings_route_test.go b/internal/adapter/openai/embeddings_route_test.go new file mode 100644 index 0000000..4395d16 --- /dev/null +++ b/internal/adapter/openai/embeddings_route_test.go @@ -0,0 +1,96 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/account" + "ds2api/internal/auth" + "ds2api/internal/config" +) + +func newResolverWithConfigJSON(t *testing.T, cfgJSON string) (*config.Store, *auth.Resolver) { + t.Helper() + t.Setenv("DS2API_CONFIG_JSON", cfgJSON) + store := config.LoadStore() + pool := account.NewPool(store) + resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "unused", nil + }) + return store, resolver +} + +func TestEmbeddingsRouteContract(t *testing.T) { + store, resolver := newResolverWithConfigJSON(t, `{"embeddings":{"provider":"deterministic"}}`) + h := &Handler{Store: store, Auth: resolver} + r := chi.NewRouter() + RegisterRoutes(r, h) + + t.Run("unauthorized", func(t *testing.T) { + body := bytes.NewBufferString(`{"model":"gpt-4o","input":"hello"}`) + req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String()) + } + }) + + t.Run("ok", func(t *testing.T) { + body := bytes.NewBufferString(`{"model":"gpt-4o","input":["a","b"]}`) + req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body) + req.Header.Set("Authorization", "Bearer test-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + var out map[string]any + if err := 
json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode response failed: %v", err) + } + if out["object"] != "list" { + t.Fatalf("unexpected object: %#v", out["object"]) + } + data, _ := out["data"].([]any) + if len(data) != 2 { + t.Fatalf("expected 2 embeddings, got %d", len(data)) + } + }) +} + +func TestEmbeddingsRouteProviderMissing(t *testing.T) { + store, resolver := newResolverWithConfigJSON(t, `{}`) + h := &Handler{Store: store, Auth: resolver} + r := chi.NewRouter() + RegisterRoutes(r, h) + + body := bytes.NewBufferString(`{"model":"gpt-4o","input":"hello"}`) + req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", body) + req.Header.Set("Authorization", "Bearer test-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusNotImplemented { + t.Fatalf("expected 501, got %d body=%s", rec.Code, rec.Body.String()) + } + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode response failed: %v", err) + } + errObj, _ := out["error"].(map[string]any) + if _, ok := errObj["code"]; !ok { + t.Fatalf("expected error.code in response: %#v", out) + } + if _, ok := errObj["param"]; !ok { + t.Fatalf("expected error.param in response: %#v", out) + } +} diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go index a2a1c4d..fadca38 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -91,24 +91,11 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { writeOpenAIError(w, http.StatusBadRequest, "invalid json") return } - model, _ := req["model"].(string) - messagesRaw, _ := req["messages"].([]any) - if model == "" || len(messagesRaw) == 0 { - writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") + stdReq, err := normalizeOpenAIChatRequest(h.Store, req) + if err != nil { + 
writeOpenAIError(w, http.StatusBadRequest, err.Error()) return } - resolvedModel, ok := config.ResolveModel(h.Store, model) - if !ok { - writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model)) - return - } - thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) - responseModel := strings.TrimSpace(model) - if responseModel == "" { - responseModel = resolvedModel - } - - finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"]) sessionID, err := h.DS.CreateSession(r.Context(), a, 3) if err != nil { @@ -124,25 +111,17 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") return } - payload := map[string]any{ - "chat_session_id": sessionID, - "parent_message_id": nil, - "prompt": finalPrompt, - "ref_file_ids": []any{}, - "thinking_enabled": thinkingEnabled, - "search_enabled": searchEnabled, - } - applyOpenAIChatPassThrough(req, payload) + payload := stdReq.CompletionPayload(sessionID) resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) if err != nil { writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.") return } - if util.ToBool(req["stream"]) { - h.handleStream(w, r, resp, sessionID, responseModel, finalPrompt, thinkingEnabled, searchEnabled, toolNames) + if stdReq.Stream { + h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) return } - h.handleNonStream(w, r.Context(), resp, sessionID, responseModel, finalPrompt, thinkingEnabled, toolNames) + h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames) } func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) 
{ @@ -208,7 +187,8 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt created := time.Now().Unix() firstChunkSent := false - bufferToolContent := len(toolNames) > 0 + bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled() + emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence() var toolSieve toolStreamSieveState toolCallsEmitted := false streamToolCallIDs := map[int]string{} @@ -377,6 +357,9 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt } for _, evt := range events { if len(evt.ToolCallDeltas) > 0 { + if !emitEarlyToolDeltas { + continue + } toolCallsEmitted = true tcDelta := map[string]any{ "tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs), @@ -568,17 +551,23 @@ func openAIErrorCode(status int) string { } func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) { - for _, k := range []string{ - "temperature", - "top_p", - "max_tokens", - "max_completion_tokens", - "presence_penalty", - "frequency_penalty", - "stop", - } { - if v, ok := req[k]; ok { - payload[k] = v - } + for k, v := range collectOpenAIChatPassThrough(req) { + payload[k] = v } } + +func (h *Handler) toolcallFeatureMatchEnabled() bool { + if h == nil || h.Store == nil { + return true + } + mode := strings.TrimSpace(strings.ToLower(h.Store.ToolcallMode())) + return mode == "" || mode == "feature_match" +} + +func (h *Handler) toolcallEarlyEmitHighConfidence() bool { + if h == nil || h.Store == nil { + return true + } + level := strings.TrimSpace(strings.ToLower(h.Store.ToolcallEarlyEmitConfidence())) + return level == "" || level == "high" +} diff --git a/internal/adapter/openai/response_store.go b/internal/adapter/openai/response_store.go index 4f51dfa..63ebbaa 100644 --- a/internal/adapter/openai/response_store.go +++ b/internal/adapter/openai/response_store.go @@ -3,9 +3,12 @@ package openai import ( "sync" "time" + + "ds2api/internal/auth" ) 
type storedResponse struct { + Owner string Value map[string]any ExpiresAt time.Time } @@ -26,32 +29,47 @@ func newResponseStore(ttl time.Duration) *responseStore { } } -func (s *responseStore) put(id string, value map[string]any) { - if s == nil || id == "" || value == nil { +func responseStoreKey(owner, id string) string { + return owner + "\x00" + id +} + +func responseStoreOwner(a *auth.RequestAuth) string { + if a == nil { + return "" + } + return a.CallerID +} + +func (s *responseStore) put(owner, id string, value map[string]any) { + if s == nil || owner == "" || id == "" || value == nil { return } now := time.Now() s.mu.Lock() defer s.mu.Unlock() s.sweepLocked(now) - s.items[id] = storedResponse{ + s.items[responseStoreKey(owner, id)] = storedResponse{ + Owner: owner, Value: cloneAnyMap(value), ExpiresAt: now.Add(s.ttl), } } -func (s *responseStore) get(id string) (map[string]any, bool) { - if s == nil || id == "" { +func (s *responseStore) get(owner, id string) (map[string]any, bool) { + if s == nil || owner == "" || id == "" { return nil, false } now := time.Now() s.mu.Lock() defer s.mu.Unlock() s.sweepLocked(now) - item, ok := s.items[id] + item, ok := s.items[responseStoreKey(owner, id)] if !ok { return nil, false } + if item.Owner != owner { + return nil, false + } return cloneAnyMap(item.Value), true } diff --git a/internal/adapter/openai/responses_embeddings_test.go b/internal/adapter/openai/responses_embeddings_test.go index b23597d..d270e1a 100644 --- a/internal/adapter/openai/responses_embeddings_test.go +++ b/internal/adapter/openai/responses_embeddings_test.go @@ -54,8 +54,8 @@ func TestDeterministicEmbeddingStable(t *testing.T) { func TestResponseStorePutGet(t *testing.T) { st := newResponseStore(100 * time.Millisecond) - st.put("resp_1", map[string]any{"id": "resp_1"}) - got, ok := st.get("resp_1") + st.put("owner_1", "resp_1", map[string]any{"id": "resp_1"}) + got, ok := st.get("owner_1", "resp_1") if !ok { t.Fatal("expected stored response") 
} @@ -63,3 +63,11 @@ func TestResponseStorePutGet(t *testing.T) { t.Fatalf("unexpected response payload: %#v", got) } } + +func TestResponseStoreTenantIsolation(t *testing.T) { + st := newResponseStore(100 * time.Millisecond) + st.put("owner_a", "resp_1", map[string]any{"id": "resp_1"}) + if _, ok := st.get("owner_b", "resp_1"); ok { + t.Fatal("expected owner_b to be isolated from owner_a response") + } +} diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index 8fbb132..b70fe0b 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -12,19 +12,35 @@ import ( "github.com/google/uuid" "ds2api/internal/auth" - "ds2api/internal/config" "ds2api/internal/sse" "ds2api/internal/util" ) func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) { + a, err := h.Auth.Determine(r) + if err != nil { + status := http.StatusUnauthorized + detail := err.Error() + if err == auth.ErrNoAccount { + status = http.StatusTooManyRequests + } + writeOpenAIError(w, status, detail) + return + } + defer h.Auth.Release(a) + id := strings.TrimSpace(chi.URLParam(r, "response_id")) if id == "" { writeOpenAIError(w, http.StatusBadRequest, "response_id is required.") return } + owner := responseStoreOwner(a) + if owner == "" { + writeOpenAIError(w, http.StatusUnauthorized, "unauthorized") + return + } st := h.getResponseStore() - item, ok := st.get(id) + item, ok := st.get(owner, id) if !ok { writeOpenAIError(w, http.StatusNotFound, "Response not found.") return @@ -45,32 +61,22 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) { } defer h.Auth.Release(a) r = r.WithContext(auth.WithAuth(r.Context(), a)) + owner := responseStoreOwner(a) + if owner == "" { + writeOpenAIError(w, http.StatusUnauthorized, "unauthorized") + return + } var req map[string]any if err := json.NewDecoder(r.Body).Decode(&req); err != nil { writeOpenAIError(w, http.StatusBadRequest, 
"invalid json") return } - - model, _ := req["model"].(string) - model = strings.TrimSpace(model) - if model == "" { - writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model'.") + stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req) + if err != nil { + writeOpenAIError(w, http.StatusBadRequest, err.Error()) return } - resolvedModel, ok := config.ResolveModel(h.Store, model) - if !ok { - writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model)) - return - } - thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) - - messagesRaw := responsesMessagesFromRequest(req) - if len(messagesRaw) == 0 { - writeOpenAIError(w, http.StatusBadRequest, "Request must include 'input' or 'messages'.") - return - } - finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"]) sessionID, err := h.DS.CreateSession(r.Context(), a, 3) if err != nil { @@ -86,15 +92,7 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) { writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") return } - payload := map[string]any{ - "chat_session_id": sessionID, - "parent_message_id": nil, - "prompt": finalPrompt, - "ref_file_ids": []any{}, - "thinking_enabled": thinkingEnabled, - "search_enabled": searchEnabled, - } - applyOpenAIChatPassThrough(req, payload) + payload := stdReq.CompletionPayload(sessionID) resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) if err != nil { writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.") @@ -102,14 +100,14 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) { } responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "") - if util.ToBool(req["stream"]) { - h.handleResponsesStream(w, r, resp, responseID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) + if stdReq.Stream { + h.handleResponsesStream(w, r, resp, owner, responseID, 
stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) return } - h.handleResponsesNonStream(w, resp, responseID, model, finalPrompt, thinkingEnabled, toolNames) + h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames) } -func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) { +func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) @@ -118,11 +116,11 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res } result := sse.CollectStream(resp, thinkingEnabled, true) responseObj := buildResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames) - h.getResponseStore().put(responseID, responseObj) + h.getResponseStore().put(owner, responseID, responseObj) writeJSON(w, http.StatusOK, responseObj) } -func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { +func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) @@ -160,7 +158,8 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, initialType = "thinking" } parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType) - bufferToolContent := len(toolNames) > 0 + bufferToolContent := 
len(toolNames) > 0 && h.toolcallFeatureMatchEnabled() + emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence() var sieve toolStreamSieveState thinking := strings.Builder{} text := strings.Builder{} @@ -194,7 +193,7 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, if toolCallsEmitted { obj["status"] = "completed" } - h.getResponseStore().put(responseID, obj) + h.getResponseStore().put(owner, responseID, obj) sendEvent("response.completed", map[string]any{ "type": "response.completed", "response": obj, @@ -259,6 +258,9 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, }) } if len(evt.ToolCallDeltas) > 0 { + if !emitEarlyToolDeltas { + continue + } toolCallsEmitted = true sendEvent("response.output_tool_call.delta", map[string]any{ "type": "response.output_tool_call.delta", diff --git a/internal/adapter/openai/responses_route_test.go b/internal/adapter/openai/responses_route_test.go new file mode 100644 index 0000000..6db0c23 --- /dev/null +++ b/internal/adapter/openai/responses_route_test.go @@ -0,0 +1,125 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/account" + "ds2api/internal/auth" + "ds2api/internal/config" +) + +func newDirectTokenResolver(t *testing.T) (*config.Store, *auth.Resolver) { + t.Helper() + t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`) + store := config.LoadStore() + pool := account.NewPool(store) + resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "unused", nil + }) + return store, resolver +} + +func authForToken(t *testing.T, resolver *auth.Resolver, token string) *auth.RequestAuth { + t.Helper() + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + req.Header.Set("Authorization", "Bearer "+token) + a, err := resolver.Determine(req) + if err != nil { + 
t.Fatalf("determine auth failed: %v", err) + } + return a +} + +func TestGetResponseByIDRequiresAuthAndIsTenantIsolated(t *testing.T) { + store, resolver := newDirectTokenResolver(t) + h := &Handler{Store: store, Auth: resolver} + r := chi.NewRouter() + RegisterRoutes(r, h) + + ownerA := responseStoreOwner(authForToken(t, resolver, "token-a")) + h.getResponseStore().put(ownerA, "resp_test", map[string]any{ + "id": "resp_test", + "object": "response", + }) + + t.Run("unauthorized", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String()) + } + }) + + t.Run("cross-tenant-not-found", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + req.Header.Set("Authorization", "Bearer token-b") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusNotFound { + t.Fatalf("expected 404, got %d body=%s", rec.Code, rec.Body.String()) + } + }) + + t.Run("same-tenant-ok", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + req.Header.Set("Authorization", "Bearer token-a") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("decode body failed: %v", err) + } + if body["id"] != "resp_test" { + t.Fatalf("unexpected body: %#v", body) + } + }) +} + +func TestResponsesRouteValidationContract(t *testing.T) { + store, resolver := newDirectTokenResolver(t) + h := &Handler{Store: store, Auth: resolver} + r := chi.NewRouter() + RegisterRoutes(r, h) + + tests := []struct { + name string + body string + }{ + {name: "missing_model", body: 
`{"input":"hello"}`}, + {name: "missing_input_and_messages", body: `{"model":"gpt-4o"}`}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewBufferString(tc.body)) + req.Header.Set("Authorization", "Bearer token-a") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String()) + } + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode response failed: %v", err) + } + errObj, _ := out["error"].(map[string]any) + if _, ok := errObj["code"]; !ok { + t.Fatalf("expected error.code: %#v", out) + } + if _, ok := errObj["param"]; !ok { + t.Fatalf("expected error.param: %#v", out) + } + }) + } +} diff --git a/internal/adapter/openai/standard_request.go b/internal/adapter/openai/standard_request.go new file mode 100644 index 0000000..52344d4 --- /dev/null +++ b/internal/adapter/openai/standard_request.go @@ -0,0 +1,104 @@ +package openai + +import ( + "fmt" + "strings" + + "ds2api/internal/config" + "ds2api/internal/util" +) + +func normalizeOpenAIChatRequest(store *config.Store, req map[string]any) (util.StandardRequest, error) { + model, _ := req["model"].(string) + messagesRaw, _ := req["messages"].([]any) + if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 { + return util.StandardRequest{}, fmt.Errorf("Request must include 'model' and 'messages'.") + } + resolvedModel, ok := config.ResolveModel(store, model) + if !ok { + return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", model) + } + thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) + responseModel := strings.TrimSpace(model) + if responseModel == "" { + responseModel = resolvedModel + } + finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"]) + 
passThrough := collectOpenAIChatPassThrough(req) + + return util.StandardRequest{ + Surface: "openai_chat", + RequestedModel: strings.TrimSpace(model), + ResolvedModel: resolvedModel, + ResponseModel: responseModel, + Messages: messagesRaw, + FinalPrompt: finalPrompt, + ToolNames: toolNames, + Stream: util.ToBool(req["stream"]), + Thinking: thinkingEnabled, + Search: searchEnabled, + PassThrough: passThrough, + }, nil +} + +func normalizeOpenAIResponsesRequest(store *config.Store, req map[string]any) (util.StandardRequest, error) { + model, _ := req["model"].(string) + model = strings.TrimSpace(model) + if model == "" { + return util.StandardRequest{}, fmt.Errorf("Request must include 'model'.") + } + resolvedModel, ok := config.ResolveModel(store, model) + if !ok { + return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", model) + } + thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) + + // Keep width-control as an explicit policy hook even if current default is true. 
+ allowWideInput := true + if store != nil { + allowWideInput = store.CompatWideInputStrictOutput() + } + var messagesRaw []any + if allowWideInput { + messagesRaw = responsesMessagesFromRequest(req) + } else if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 { + messagesRaw = msgs + } + if len(messagesRaw) == 0 { + return util.StandardRequest{}, fmt.Errorf("Request must include 'input' or 'messages'.") + } + finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"]) + passThrough := collectOpenAIChatPassThrough(req) + + return util.StandardRequest{ + Surface: "openai_responses", + RequestedModel: model, + ResolvedModel: resolvedModel, + ResponseModel: model, + Messages: messagesRaw, + FinalPrompt: finalPrompt, + ToolNames: toolNames, + Stream: util.ToBool(req["stream"]), + Thinking: thinkingEnabled, + Search: searchEnabled, + PassThrough: passThrough, + }, nil +} + +func collectOpenAIChatPassThrough(req map[string]any) map[string]any { + out := map[string]any{} + for _, k := range []string{ + "temperature", + "top_p", + "max_tokens", + "max_completion_tokens", + "presence_penalty", + "frequency_penalty", + "stop", + } { + if v, ok := req[k]; ok { + out[k] = v + } + } + return out +} diff --git a/internal/adapter/openai/standard_request_test.go b/internal/adapter/openai/standard_request_test.go new file mode 100644 index 0000000..f3453a2 --- /dev/null +++ b/internal/adapter/openai/standard_request_test.go @@ -0,0 +1,60 @@ +package openai + +import ( + "testing" + + "ds2api/internal/config" +) + +func newEmptyStoreForNormalizeTest(t *testing.T) *config.Store { + t.Helper() + t.Setenv("DS2API_CONFIG_JSON", `{}`) + return config.LoadStore() +} + +func TestNormalizeOpenAIChatRequest(t *testing.T) { + store := newEmptyStoreForNormalizeTest(t) + req := map[string]any{ + "model": "gpt-5-codex", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + "temperature": 0.3, + "stream": true, + } + n, err := 
normalizeOpenAIChatRequest(store, req) + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + if n.ResolvedModel != "deepseek-reasoner" { + t.Fatalf("unexpected resolved model: %s", n.ResolvedModel) + } + if !n.Stream { + t.Fatalf("expected stream=true") + } + if _, ok := n.PassThrough["temperature"]; !ok { + t.Fatalf("expected temperature passthrough") + } + if n.FinalPrompt == "" { + t.Fatalf("expected non-empty final prompt") + } +} + +func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) { + store := newEmptyStoreForNormalizeTest(t) + req := map[string]any{ + "model": "gpt-4o", + "input": "ping", + "instructions": "system", + } + n, err := normalizeOpenAIResponsesRequest(store, req) + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + if n.ResolvedModel != "deepseek-chat" { + t.Fatalf("unexpected resolved model: %s", n.ResolvedModel) + } + if len(n.Messages) != 2 { + t.Fatalf("expected 2 normalized messages, got %d", len(n.Messages)) + } +} diff --git a/internal/adapter/openai/vercel_stream.go b/internal/adapter/openai/vercel_stream.go index be8a590..c8bd6d0 100644 --- a/internal/adapter/openai/vercel_stream.go +++ b/internal/adapter/openai/vercel_stream.go @@ -56,24 +56,15 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque writeOpenAIError(w, http.StatusBadRequest, "stream must be true") return } - model, _ := req["model"].(string) - messagesRaw, _ := req["messages"].([]any) - if model == "" || len(messagesRaw) == 0 { - writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") + stdReq, err := normalizeOpenAIChatRequest(h.Store, req) + if err != nil { + writeOpenAIError(w, http.StatusBadRequest, err.Error()) return } - resolvedModel, ok := config.ResolveModel(h.Store, model) - if !ok { - writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model)) + if !stdReq.Stream { + writeOpenAIError(w, http.StatusBadRequest, "stream must be true") 
return } - thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) - responseModel := strings.TrimSpace(model) - if responseModel == "" { - responseModel = resolvedModel - } - - finalPrompt, _ := buildOpenAIFinalPrompt(messagesRaw, req["tools"]) sessionID, err := h.DS.CreateSession(r.Context(), a, 3) if err != nil { @@ -94,15 +85,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque return } - payload := map[string]any{ - "chat_session_id": sessionID, - "parent_message_id": nil, - "prompt": finalPrompt, - "ref_file_ids": []any{}, - "thinking_enabled": thinkingEnabled, - "search_enabled": searchEnabled, - } - applyOpenAIChatPassThrough(req, payload) + payload := stdReq.CompletionPayload(sessionID) leaseID := h.holdStreamLease(a) if leaseID == "" { writeOpenAIError(w, http.StatusInternalServerError, "failed to create stream lease") @@ -112,10 +95,10 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque writeJSON(w, http.StatusOK, map[string]any{ "session_id": sessionID, "lease_id": leaseID, - "model": responseModel, - "final_prompt": finalPrompt, - "thinking_enabled": thinkingEnabled, - "search_enabled": searchEnabled, + "model": stdReq.ResponseModel, + "final_prompt": stdReq.FinalPrompt, + "thinking_enabled": stdReq.Thinking, + "search_enabled": stdReq.Search, "deepseek_token": a.DeepSeekToken, "pow_header": powHeader, "payload": payload, diff --git a/internal/auth/request.go b/internal/auth/request.go index ea3d7f1..d7faf8d 100644 --- a/internal/auth/request.go +++ b/internal/auth/request.go @@ -2,6 +2,8 @@ package auth import ( "context" + "crypto/sha256" + "encoding/hex" "errors" "net/http" "strings" @@ -22,6 +24,7 @@ var ( type RequestAuth struct { UseConfigToken bool DeepSeekToken string + CallerID string AccountID string Account config.Account TriedAccounts map[string]bool @@ -45,9 +48,16 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) { if callerKey == "" 
// callerTokenID derives the caller identifier for a raw API key or token.
//
// Leading and trailing whitespace is ignored. A blank token yields "".
// Otherwise the result is "caller:" followed by the lowercase hex encoding of
// the first 8 bytes of the token's SHA-256 digest (16 hex characters), so the
// raw token itself never appears in the identifier.
func callerTokenID(token string) string {
	trimmed := strings.TrimSpace(token)
	if len(trimmed) == 0 {
		return ""
	}
	digest := sha256.Sum256([]byte(trimmed))
	prefix := digest[:8]
	return "caller:" + hex.EncodeToString(prefix)
}
callerTokenID("token-b") + if a == "" || b == "" || c == "" { + t.Fatalf("expected non-empty caller ids") + } + if a != b { + t.Fatalf("expected stable caller id, got %q and %q", a, b) + } + if a == c { + t.Fatalf("expected different caller id for different tokens") + } } func TestDetermineMissingToken(t *testing.T) { diff --git a/internal/testsuite/runner.go b/internal/testsuite/runner.go index b48bce5..e6ae9a6 100644 --- a/internal/testsuite/runner.go +++ b/internal/testsuite/runner.go @@ -755,11 +755,15 @@ func (r *Runner) cases() []caseDef { {ID: "healthz_ok", Run: r.caseHealthz}, {ID: "readyz_ok", Run: r.caseReadyz}, {ID: "models_openai", Run: r.caseModelsOpenAI}, + {ID: "model_openai_by_id", Run: r.caseModelOpenAIByID}, {ID: "models_claude", Run: r.caseModelsClaude}, {ID: "admin_login_verify", Run: r.caseAdminLoginVerify}, {ID: "admin_queue_status", Run: r.caseAdminQueueStatus}, {ID: "chat_nonstream_basic", Run: r.caseChatNonstream}, {ID: "chat_stream_basic", Run: r.caseChatStream}, + {ID: "responses_nonstream_basic", Run: r.caseResponsesNonstream}, + {ID: "responses_stream_basic", Run: r.caseResponsesStream}, + {ID: "embeddings_contract", Run: r.caseEmbeddings}, {ID: "reasoner_stream", Run: r.caseReasonerStream}, {ID: "toolcall_nonstream", Run: r.caseToolcallNonstream}, {ID: "toolcall_stream", Run: r.caseToolcallStream}, @@ -817,6 +821,19 @@ func (r *Runner) caseModelsOpenAI(ctx context.Context, cc *caseContext) error { return nil } +func (r *Runner) caseModelOpenAIByID(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/v1/models/gpt-4o", Retryable: true}) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + cc.assert("object_model", asString(m["object"]) == "model", fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("id_deepseek_chat", 
asString(m["id"]) == "deepseek-chat", fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + func (r *Runner) caseModelsClaude(ctx context.Context, cc *caseContext) error { resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/anthropic/v1/models", Retryable: true}) if err != nil { @@ -942,6 +959,115 @@ func (r *Runner) caseChatStream(ctx context.Context, cc *caseContext) error { return nil } +func (r *Runner) caseResponsesNonstream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/responses", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "gpt-4o", + "input": "请简要回答 hello", + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + cc.assert("object_response", asString(m["object"]) == "response", fmt.Sprintf("body=%s", string(resp.Body))) + responseID := asString(m["id"]) + cc.assert("response_id_present", responseID != "", fmt.Sprintf("body=%s", string(resp.Body))) + if responseID != "" { + getResp, getErr := cc.request(ctx, requestSpec{ + Method: http.MethodGet, + Path: "/v1/responses/" + responseID, + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Retryable: true, + }) + if getErr != nil { + return getErr + } + cc.assert("get_status_200", getResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", getResp.StatusCode)) + } + return nil +} + +func (r *Runner) caseResponsesStream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/responses", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "gpt-4o", + "input": "请流式回答 hello", + "stream": true, + }, + Stream: true, 
+ Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + frames, done := parseSSEFrames(resp.Body) + cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames))) + hasCreated := false + hasCompleted := false + for _, f := range frames { + switch asString(f["type"]) { + case "response.created": + hasCreated = true + case "response.completed": + hasCompleted = true + } + } + cc.assert("has_response_created", hasCreated, fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("has_response_completed", hasCompleted, fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("done_terminated", done, "expected [DONE]") + return nil +} + +func (r *Runner) caseEmbeddings(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/embeddings", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "gpt-4o", + "input": []string{"hello", "world"}, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200_or_501", resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusNotImplemented, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + if resp.StatusCode == http.StatusOK { + cc.assert("object_list", asString(m["object"]) == "list", fmt.Sprintf("body=%s", string(resp.Body))) + data, _ := m["data"].([]any) + cc.assert("data_non_empty", len(data) > 0, fmt.Sprintf("body=%s", string(resp.Body))) + return nil + } + errObj, _ := m["error"].(map[string]any) + _, hasCode := errObj["code"] + _, hasParam := errObj["param"] + cc.assert("error_has_code", hasCode, fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("error_has_param", hasParam, fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + func (r *Runner) caseReasonerStream(ctx 
// StandardRequest is the surface-independent form of an incoming request,
// produced by the per-surface normalizers and consumed when building the
// upstream completion call.
type StandardRequest struct {
	// Surface names the API surface the request arrived on,
	// e.g. "openai_chat" or "openai_responses".
	Surface string
	// RequestedModel is the model name supplied by the client.
	RequestedModel string
	// ResolvedModel is the model name the requested model was resolved to.
	ResolvedModel string
	// ResponseModel is the model name reported back to the client.
	ResponseModel string
	// Messages holds the request messages the prompt was built from.
	Messages []any
	// FinalPrompt is the prompt text sent upstream as "prompt".
	FinalPrompt string
	// ToolNames lists the names of the tools declared by the request.
	ToolNames []string
	// Stream reports whether the client asked for a streamed response.
	Stream bool
	// Thinking is sent upstream as "thinking_enabled".
	Thinking bool
	// Search is sent upstream as "search_enabled".
	Search bool
	// PassThrough carries extra client parameters that are merged into the
	// completion payload as-is.
	PassThrough map[string]any
}

// CompletionPayload builds the body of the upstream completion call for the
// given chat session.
//
// The result contains every PassThrough entry plus the six core fields
// "chat_session_id", "parent_message_id" (always nil), "prompt",
// "ref_file_ids" (always an empty list), "thinking_enabled" and
// "search_enabled". The core fields always take precedence: a PassThrough
// entry using one of those names is ignored rather than allowed to replace
// the session id, prompt or feature switches.
//
// A new map is returned on every call; r.PassThrough is not modified and may
// be nil.
func (r StandardRequest) CompletionPayload(sessionID string) map[string]any {
	const coreFields = 6
	payload := make(map[string]any, coreFields+len(r.PassThrough))
	for k, v := range r.PassThrough {
		payload[k] = v
	}
	// Core fields are written last so that no pass-through key can clobber them.
	payload["chat_session_id"] = sessionID
	payload["parent_message_id"] = nil
	payload["prompt"] = r.FinalPrompt
	payload["ref_file_ids"] = []any{}
	payload["thinking_enabled"] = r.Thinking
	payload["search_enabled"] = r.Search
	return payload
}