diff --git a/api/chat-stream.js b/api/chat-stream.js index 680651d..1a8e896 100644 --- a/api/chat-stream.js +++ b/api/chat-stream.js @@ -85,7 +85,8 @@ module.exports = async function handler(req, res) { const finalPrompt = asString(prep.body.final_prompt); const thinkingEnabled = toBool(prep.body.thinking_enabled); const searchEnabled = toBool(prep.body.search_enabled); - const toolNames = extractToolNames(payload.tools); + const toolPolicy = resolveToolcallPolicy(prep.body, payload.tools); + const toolNames = toolPolicy.toolNames; if (!model || !leaseID || !deepseekToken || !powHeader || !completionPayload) { writeOpenAIError(res, 500, 'invalid vercel prepare response'); @@ -156,7 +157,8 @@ module.exports = async function handler(req, res) { let currentType = thinkingEnabled ? 'thinking' : 'text'; let thinkingText = ''; let outputText = ''; - const toolSieveEnabled = toolNames.length > 0; + const toolSieveEnabled = toolPolicy.toolSieveEnabled; + const emitEarlyToolDeltas = toolPolicy.emitEarlyToolDeltas; const toolSieveState = createToolSieveState(); let toolCallsEmitted = false; const streamToolCallIDs = new Map(); @@ -297,6 +299,9 @@ module.exports = async function handler(req, res) { const events = processToolSieveChunk(toolSieveState, p.text, toolNames); for (const evt of events) { if (evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0) { + if (!emitEarlyToolDeltas) { + continue; + } toolCallsEmitted = true; sendDeltaFrame({ tool_calls: formatIncrementalToolCallDeltas(evt.deltas, streamToolCallIDs) }); continue; @@ -407,6 +412,37 @@ function relayPreparedFailure(res, prep) { writeOpenAIError(res, prep.status || 500, 'vercel prepare failed'); } +function resolveToolcallPolicy(prepBody, payloadTools) { + const preparedToolNames = normalizePreparedToolNames(prepBody && prepBody.tool_names); + const toolNames = preparedToolNames.length > 0 ? preparedToolNames : extractToolNames(payloadTools); + const featureMatchEnabled = boolDefaultTrue(prepBody && prepBody.toolcall_feature_match); + const emitEarlyToolDeltas = boolDefaultTrue(prepBody && prepBody.toolcall_early_emit_high); + return { + toolNames, + toolSieveEnabled: toolNames.length > 0 && featureMatchEnabled, + emitEarlyToolDeltas, + }; +} + +function normalizePreparedToolNames(v) { + if (!Array.isArray(v) || v.length === 0) { + return []; + } + const out = []; + for (const item of v) { + const name = asString(item); + if (!name) { + continue; + } + out.push(name); + } + return out; +} + +function boolDefaultTrue(v) { + return v !== false; +} + async function safeReadText(resp) { if (!resp) { return ''; @@ -933,4 +969,7 @@ module.exports.__test = { extractContentRecursive, shouldSkipPath, asString, + resolveToolcallPolicy, + normalizePreparedToolNames, + boolDefaultTrue, }; diff --git a/api/chat-stream.test.js b/api/chat-stream.test.js index c849f7c..7424df2 100644 --- a/api/chat-stream.test.js +++ b/api/chat-stream.test.js @@ -10,10 +10,50 @@ const { flushToolSieve, } = require('./helpers/stream-tool-sieve'); -const { parseChunkForContent } = handler.__test; +const { + parseChunkForContent, + resolveToolcallPolicy, + normalizePreparedToolNames, + boolDefaultTrue, +} = handler.__test; test('chat-stream exposes parser test hooks', () => { assert.equal(typeof parseChunkForContent, 'function'); + assert.equal(typeof resolveToolcallPolicy, 'function'); +}); + +test('resolveToolcallPolicy defaults to feature-match + early emit when prepare flags missing', () => { + const policy = resolveToolcallPolicy( + {}, + [{ type: 'function', function: { name: 'read_file', parameters: { type: 'object' } } }], + ); + assert.deepEqual(policy.toolNames, ['read_file']); + assert.equal(policy.toolSieveEnabled, true); + assert.equal(policy.emitEarlyToolDeltas, true); +}); + +test('resolveToolcallPolicy respects prepare flags and prepared tool names', () => { + const policy = resolveToolcallPolicy( + { + tool_names: [' prepped_tool ', '', null], + toolcall_feature_match: false, + toolcall_early_emit_high: false, + }, + [{ type: 'function', function: { name: 'fallback_tool', parameters: { type: 'object' } } }], + ); + assert.deepEqual(policy.toolNames, ['prepped_tool']); + assert.equal(policy.toolSieveEnabled, false); + assert.equal(policy.emitEarlyToolDeltas, false); +}); + +test('normalizePreparedToolNames filters empty values', () => { + assert.deepEqual(normalizePreparedToolNames([' a ', '', null, 'b']), ['a', 'b']); +}); + +test('boolDefaultTrue keeps false only when explicitly false', () => { + assert.equal(boolDefaultTrue(false), false); + assert.equal(boolDefaultTrue(true), true); + assert.equal(boolDefaultTrue(undefined), true); }); test('parseChunkForContent keeps split response/content fragments inside response array', () => { diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index 9aaa7cd..ff324b4 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -16,17 +16,11 @@ import ( ) func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) { - a, err := h.Auth.Determine(r) + a, err := h.Auth.DetermineCaller(r) if err != nil { - status := http.StatusUnauthorized - detail := err.Error() - if err == auth.ErrNoAccount { - status = http.StatusTooManyRequests - } - writeOpenAIError(w, status, detail) + writeOpenAIError(w, http.StatusUnauthorized, err.Error()) return } - defer h.Auth.Release(a) id := strings.TrimSpace(chi.URLParam(r, "response_id")) if id == "" { diff --git a/internal/adapter/openai/responses_route_test.go b/internal/adapter/openai/responses_route_test.go index 6db0c23..574c6fa 100644 --- a/internal/adapter/openai/responses_route_test.go +++ b/internal/adapter/openai/responses_route_test.go @@ -26,6 +26,22 @@ func newDirectTokenResolver(t *testing.T) (*config.Store, *auth.Resolver) { return store, resolver } +func newManagedKeyResolver(t *testing.T) (*config.Store, *auth.Resolver) { + t.Helper() + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["managed-key"], + "accounts":[{"email":"acc@example.com","password":"pwd","token":"account-token"}] + }`) + t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1") + t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "0") + store := config.LoadStore() + pool := account.NewPool(store) + resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) { + return "unused", nil + }) + return store, resolver +} + func authForToken(t *testing.T, resolver *auth.Resolver, token string) *auth.RequestAuth { t.Helper() req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) @@ -123,3 +139,38 @@ func TestResponsesRouteValidationContract(t *testing.T) { }) } } + +func TestGetResponseByIDManagedKeySkipsAccountPoolPressure(t *testing.T) { + store, resolver := newManagedKeyResolver(t) + h := &Handler{Store: store, Auth: resolver} + r := chi.NewRouter() + RegisterRoutes(r, h) + + ownerReq := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + ownerReq.Header.Set("Authorization", "Bearer managed-key") + ownerAuth, err := resolver.DetermineCaller(ownerReq) + if err != nil { + t.Fatalf("determine caller failed: %v", err) + } + owner := responseStoreOwner(ownerAuth) + h.getResponseStore().put(owner, "resp_test", map[string]any{ + "id": "resp_test", + "object": "response", + }) + + occupyReq := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + occupyReq.Header.Set("Authorization", "Bearer managed-key") + occupied, err := resolver.Determine(occupyReq) + if err != nil { + t.Fatalf("expected first acquire to succeed: %v", err) + } + defer resolver.Release(occupied) + + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil) + req.Header.Set("Authorization", "Bearer managed-key") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200 under pool pressure, got %d body=%s", rec.Code, rec.Body.String()) + } +} diff --git a/internal/adapter/openai/vercel_stream.go b/internal/adapter/openai/vercel_stream.go index c8bd6d0..65006c4 100644 --- a/internal/adapter/openai/vercel_stream.go +++ b/internal/adapter/openai/vercel_stream.go @@ -93,15 +93,18 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque } leased = true writeJSON(w, http.StatusOK, map[string]any{ - "session_id": sessionID, - "lease_id": leaseID, - "model": stdReq.ResponseModel, - "final_prompt": stdReq.FinalPrompt, - "thinking_enabled": stdReq.Thinking, - "search_enabled": stdReq.Search, - "deepseek_token": a.DeepSeekToken, - "pow_header": powHeader, - "payload": payload, + "session_id": sessionID, + "lease_id": leaseID, + "model": stdReq.ResponseModel, + "final_prompt": stdReq.FinalPrompt, + "thinking_enabled": stdReq.Thinking, + "search_enabled": stdReq.Search, + "tool_names": stdReq.ToolNames, + "toolcall_feature_match": h.toolcallFeatureMatchEnabled(), + "toolcall_early_emit_high": h.toolcallEarlyEmitHighConfidence(), + "deepseek_token": a.DeepSeekToken, + "pow_header": powHeader, + "payload": payload, }) } diff --git a/internal/auth/request.go b/internal/auth/request.go index d7faf8d..25980cf 100644 --- a/internal/auth/request.go +++ b/internal/auth/request.go @@ -83,6 +83,26 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) { return a, nil } +// DetermineCaller resolves caller identity without acquiring any pooled account. +// Use this for local-cache lookup routes that only need tenant isolation. +func (r *Resolver) DetermineCaller(req *http.Request) (*RequestAuth, error) { + callerKey := extractCallerToken(req) + if callerKey == "" { + return nil, ErrUnauthorized + } + callerID := callerTokenID(callerKey) + a := &RequestAuth{ + UseConfigToken: false, + CallerID: callerID, + resolver: r, + TriedAccounts: map[string]bool{}, + } + if r == nil || r.Store == nil || !r.Store.HasAPIKey(callerKey) { + a.DeepSeekToken = callerKey + } + return a, nil +} + func WithAuth(ctx context.Context, a *RequestAuth) context.Context { return context.WithValue(ctx, authCtxKey, a) } diff --git a/internal/auth/request_test.go b/internal/auth/request_test.go index ee74092..c292856 100644 --- a/internal/auth/request_test.go +++ b/internal/auth/request_test.go @@ -66,6 +66,26 @@ func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) { } } +func TestDetermineCallerWithManagedKeySkipsAccountAcquire(t *testing.T) { + r := newTestResolver(t) + req, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil) + req.Header.Set("x-api-key", "managed-key") + + a, err := r.DetermineCaller(req) + if err != nil { + t.Fatalf("determine caller failed: %v", err) + } + if a.CallerID == "" { + t.Fatalf("expected caller id to be populated") + } + if a.UseConfigToken { + t.Fatalf("expected no config-token lease for caller-only auth") + } + if a.AccountID != "" { + t.Fatalf("expected empty account id, got %q", a.AccountID) + } +} + func TestCallerTokenIDStable(t *testing.T) { a := callerTokenID("token-a") b := callerTokenID("token-a") @@ -93,3 +113,16 @@ func TestDetermineMissingToken(t *testing.T) { t.Fatalf("unexpected error: %v", err) } } + +func TestDetermineCallerMissingToken(t *testing.T) { + r := newTestResolver(t) + req, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil) + + _, err := r.DetermineCaller(req) + if err == nil { + t.Fatal("expected unauthorized error") + } + if err != ErrUnauthorized { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index d583159..d391462 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -73,7 +73,7 @@ type Config struct { } type CompatConfig struct { - WideInputStrictOutput bool `json:"wide_input_strict_output,omitempty"` + WideInputStrictOutput *bool `json:"wide_input_strict_output,omitempty"` } type ToolcallConfig struct { @@ -109,7 +109,7 @@ func (c Config) MarshalJSON() ([]byte, error) { if len(c.ModelAliases) > 0 { m["model_aliases"] = c.ModelAliases } - if c.Compat.WideInputStrictOutput { + if c.Compat.WideInputStrictOutput != nil { m["compat"] = c.Compat } if strings.TrimSpace(c.Toolcall.Mode) != "" || strings.TrimSpace(c.Toolcall.EarlyEmitConfidence) != "" { @@ -194,12 +194,14 @@ func (c *Config) UnmarshalJSON(b []byte) error { func (c Config) Clone() Config { clone := Config{ - Keys: slices.Clone(c.Keys), - Accounts: slices.Clone(c.Accounts), - ClaudeMapping: cloneStringMap(c.ClaudeMapping), - ClaudeModelMap: cloneStringMap(c.ClaudeModelMap), - ModelAliases: cloneStringMap(c.ModelAliases), - Compat: c.Compat, + Keys: slices.Clone(c.Keys), + Accounts: slices.Clone(c.Accounts), + ClaudeMapping: cloneStringMap(c.ClaudeMapping), + ClaudeModelMap: cloneStringMap(c.ClaudeModelMap), + ModelAliases: cloneStringMap(c.ModelAliases), + Compat: CompatConfig{ + WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput), + }, Toolcall: c.Toolcall, Responses: c.Responses, Embeddings: c.Embeddings, @@ -224,6 +226,14 @@ func cloneStringMap(in map[string]string) map[string]string { return out } +func cloneBoolPtr(in *bool) *bool { + if in == nil { + return nil + } + v := *in + return &v +} + type Store struct { mu sync.RWMutex cfg Config @@ -569,9 +579,12 @@ func (s *Store) ModelAliases() map[string]string { } func (s *Store) CompatWideInputStrictOutput() bool { - // Current default policy is always wide-input / strict-output. - // Kept as a method so callers do not depend on storage shape. - return true + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Compat.WideInputStrictOutput == nil { + return true + } + return *s.cfg.Compat.WideInputStrictOutput } func (s *Store) ToolcallMode() string { diff --git a/internal/config/config_edge_test.go b/internal/config/config_edge_test.go index 81cc7ec..1138867 100644 --- a/internal/config/config_edge_test.go +++ b/internal/config/config_edge_test.go @@ -320,6 +320,39 @@ func TestStoreFindAccountNotFound(t *testing.T) { } } +func TestStoreCompatWideInputStrictOutputDefaultTrue(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) + store := LoadStore() + if !store.CompatWideInputStrictOutput() { + t.Fatal("expected default wide_input_strict_output=true when unset") + } +} + +func TestStoreCompatWideInputStrictOutputCanDisable(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[],"compat":{"wide_input_strict_output":false}}`) + store := LoadStore() + if store.CompatWideInputStrictOutput() { + t.Fatal("expected wide_input_strict_output=false when explicitly configured") + } + + snap := store.Snapshot() + data, err := snap.MarshalJSON() + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + var out map[string]any + if err := json.Unmarshal(data, &out); err != nil { + t.Fatalf("decode failed: %v", err) + } + rawCompat, ok := out["compat"].(map[string]any) + if !ok { + t.Fatalf("expected compat in marshaled output, got %#v", out) + } + if rawCompat["wide_input_strict_output"] != false { + t.Fatalf("expected explicit false in compat, got %#v", rawCompat) + } +} + func TestStoreIsEnvBacked(t *testing.T) { t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`) store := LoadStore()