feat: Introduce DetermineCaller for auth without account pooling and make wide_input_strict_output configurable.

This commit is contained in:
CJACK
2026-02-18 23:53:50 +08:00
parent 51c543631b
commit 2dcc230852
9 changed files with 257 additions and 31 deletions

View File

@@ -85,7 +85,8 @@ module.exports = async function handler(req, res) {
const finalPrompt = asString(prep.body.final_prompt);
const thinkingEnabled = toBool(prep.body.thinking_enabled);
const searchEnabled = toBool(prep.body.search_enabled);
const toolNames = extractToolNames(payload.tools);
const toolPolicy = resolveToolcallPolicy(prep.body, payload.tools);
const toolNames = toolPolicy.toolNames;
if (!model || !leaseID || !deepseekToken || !powHeader || !completionPayload) {
writeOpenAIError(res, 500, 'invalid vercel prepare response');
@@ -156,7 +157,8 @@ module.exports = async function handler(req, res) {
let currentType = thinkingEnabled ? 'thinking' : 'text';
let thinkingText = '';
let outputText = '';
const toolSieveEnabled = toolNames.length > 0;
const toolSieveEnabled = toolPolicy.toolSieveEnabled;
const emitEarlyToolDeltas = toolPolicy.emitEarlyToolDeltas;
const toolSieveState = createToolSieveState();
let toolCallsEmitted = false;
const streamToolCallIDs = new Map();
@@ -297,6 +299,9 @@ module.exports = async function handler(req, res) {
const events = processToolSieveChunk(toolSieveState, p.text, toolNames);
for (const evt of events) {
if (evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0) {
if (!emitEarlyToolDeltas) {
continue;
}
toolCallsEmitted = true;
sendDeltaFrame({ tool_calls: formatIncrementalToolCallDeltas(evt.deltas, streamToolCallIDs) });
continue;
@@ -407,6 +412,37 @@ function relayPreparedFailure(res, prep) {
writeOpenAIError(res, prep.status || 500, 'vercel prepare failed');
}
// Decide how tool-call handling behaves for this stream: which tool names to
// match against, whether the feature-match sieve runs at all, and whether
// tool-call deltas are emitted early. Prepared names from the prepare step
// win over names extracted from the raw payload; both flags default to on.
function resolveToolcallPolicy(prepBody, payloadTools) {
  const preparedNames = normalizePreparedToolNames(prepBody ? prepBody.tool_names : undefined);
  const toolNames = preparedNames.length ? preparedNames : extractToolNames(payloadTools);
  const featureMatch = boolDefaultTrue(prepBody ? prepBody.toolcall_feature_match : undefined);
  const earlyEmit = boolDefaultTrue(prepBody ? prepBody.toolcall_early_emit_high : undefined);
  return {
    toolNames,
    // The sieve only makes sense when there are names to match AND the
    // prepare step did not explicitly disable feature matching.
    toolSieveEnabled: featureMatch && toolNames.length > 0,
    emitEarlyToolDeltas: earlyEmit,
  };
}
// Coerce a prepare-response `tool_names` value into a clean string array.
// Non-arrays yield []; each entry is run through asString and blank results
// are dropped.
function normalizePreparedToolNames(v) {
  if (!Array.isArray(v)) {
    return [];
  }
  const names = [];
  v.forEach((entry) => {
    const name = asString(entry);
    if (name) {
      names.push(name);
    }
  });
  return names;
}
// Opt-out boolean: only an explicit `false` disables the feature; any other
// value (true, undefined, null, 0, strings, ...) counts as enabled.
function boolDefaultTrue(v) {
  if (v === false) {
    return false;
  }
  return true;
}
async function safeReadText(resp) {
if (!resp) {
return '';
@@ -933,4 +969,7 @@ module.exports.__test = {
extractContentRecursive,
shouldSkipPath,
asString,
resolveToolcallPolicy,
normalizePreparedToolNames,
boolDefaultTrue,
};

View File

@@ -10,10 +10,50 @@ const {
flushToolSieve,
} = require('./helpers/stream-tool-sieve');
// Test hooks exported by the chat-stream handler. The old single-hook
// destructure was left in by a bad merge, declaring parseChunkForContent
// twice (a SyntaxError); only the combined destructure survives.
const {
  parseChunkForContent,
  resolveToolcallPolicy,
  normalizePreparedToolNames,
  boolDefaultTrue,
} = handler.__test;
// Smoke-check that the handler module actually exports its test hooks.
test('chat-stream exposes parser test hooks', () => {
  for (const hook of [parseChunkForContent, resolveToolcallPolicy]) {
    assert.equal(typeof hook, 'function');
  }
});
// With no prepare flags, both the sieve and early emission must default on,
// and tool names must fall back to the raw payload tools.
test('resolveToolcallPolicy defaults to feature-match + early emit when prepare flags missing', () => {
  const payloadTools = [
    { type: 'function', function: { name: 'read_file', parameters: { type: 'object' } } },
  ];
  const policy = resolveToolcallPolicy({}, payloadTools);
  assert.deepEqual(policy.toolNames, ['read_file']);
  assert.equal(policy.toolSieveEnabled, true);
  assert.equal(policy.emitEarlyToolDeltas, true);
});
// Explicit prepare flags win: prepared tool names override payload tools
// (after trimming/filtering) and both booleans can be switched off.
test('resolveToolcallPolicy respects prepare flags and prepared tool names', () => {
  const prepBody = {
    tool_names: [' prepped_tool ', '', null],
    toolcall_feature_match: false,
    toolcall_early_emit_high: false,
  };
  const fallbackTools = [
    { type: 'function', function: { name: 'fallback_tool', parameters: { type: 'object' } } },
  ];
  const policy = resolveToolcallPolicy(prepBody, fallbackTools);
  assert.deepEqual(policy.toolNames, ['prepped_tool']);
  assert.equal(policy.toolSieveEnabled, false);
  assert.equal(policy.emitEarlyToolDeltas, false);
});
// Blank and non-string entries are dropped; surviving names are trimmed.
test('normalizePreparedToolNames filters empty values', () => {
  const cleaned = normalizePreparedToolNames([' a ', '', null, 'b']);
  assert.deepEqual(cleaned, ['a', 'b']);
});
// Table-driven check: only a literal false opts out; everything else is on.
test('boolDefaultTrue keeps false only when explicitly false', () => {
  const cases = [
    [false, false],
    [true, true],
    [undefined, true],
  ];
  for (const [input, expected] of cases) {
    assert.equal(boolDefaultTrue(input), expected);
  }
});
test('parseChunkForContent keeps split response/content fragments inside response array', () => {

View File

@@ -16,17 +16,11 @@ import (
)
func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) {
a, err := h.Auth.Determine(r)
a, err := h.Auth.DetermineCaller(r)
if err != nil {
status := http.StatusUnauthorized
detail := err.Error()
if err == auth.ErrNoAccount {
status = http.StatusTooManyRequests
}
writeOpenAIError(w, status, detail)
writeOpenAIError(w, http.StatusUnauthorized, err.Error())
return
}
defer h.Auth.Release(a)
id := strings.TrimSpace(chi.URLParam(r, "response_id"))
if id == "" {

View File

@@ -26,6 +26,22 @@ func newDirectTokenResolver(t *testing.T) (*config.Store, *auth.Resolver) {
return store, resolver
}
// newManagedKeyResolver builds a store + resolver backed by one managed API
// key and a single-account pool with max inflight 1 and no queue, so a second
// pooled acquire fails immediately (useful for pool-pressure tests).
func newManagedKeyResolver(t *testing.T) (*config.Store, *auth.Resolver) {
	t.Helper()
	t.Setenv("DS2API_CONFIG_JSON", `{
"keys":["managed-key"],
"accounts":[{"email":"acc@example.com","password":"pwd","token":"account-token"}]
}`)
	t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1")
	t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "0")
	cfgStore := config.LoadStore()
	accountPool := account.NewPool(cfgStore)
	// The login callback is never expected to matter for these tests.
	tokenResolver := auth.NewResolver(cfgStore, accountPool, func(_ context.Context, _ config.Account) (string, error) {
		return "unused", nil
	})
	return cfgStore, tokenResolver
}
func authForToken(t *testing.T, resolver *auth.Resolver, token string) *auth.RequestAuth {
t.Helper()
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
@@ -123,3 +139,38 @@ func TestResponsesRouteValidationContract(t *testing.T) {
})
}
}
// A cached-response lookup must succeed even while the account pool is fully
// occupied, because GetResponseByID uses caller-only auth (DetermineCaller)
// and never leases a pooled account.
func TestGetResponseByIDManagedKeySkipsAccountPoolPressure(t *testing.T) {
	store, resolver := newManagedKeyResolver(t)
	handler := &Handler{Store: store, Auth: resolver}
	router := chi.NewRouter()
	RegisterRoutes(router, handler)

	// Seed the response cache under the owner derived from caller-only auth.
	ownerReq := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
	ownerReq.Header.Set("Authorization", "Bearer managed-key")
	ownerAuth, err := resolver.DetermineCaller(ownerReq)
	if err != nil {
		t.Fatalf("determine caller failed: %v", err)
	}
	owner := responseStoreOwner(ownerAuth)
	handler.getResponseStore().put(owner, "resp_test", map[string]any{
		"id":     "resp_test",
		"object": "response",
	})

	// Occupy the single pool slot via the full Determine path.
	occupyReq := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
	occupyReq.Header.Set("Authorization", "Bearer managed-key")
	occupied, err := resolver.Determine(occupyReq)
	if err != nil {
		t.Fatalf("expected first acquire to succeed: %v", err)
	}
	defer resolver.Release(occupied)

	// The lookup must still return 200 despite the exhausted pool.
	lookupReq := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
	lookupReq.Header.Set("Authorization", "Bearer managed-key")
	recorder := httptest.NewRecorder()
	router.ServeHTTP(recorder, lookupReq)
	if recorder.Code != http.StatusOK {
		t.Fatalf("expected 200 under pool pressure, got %d body=%s", recorder.Code, recorder.Body.String())
	}
}

View File

@@ -93,15 +93,18 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
}
leased = true
writeJSON(w, http.StatusOK, map[string]any{
"session_id": sessionID,
"lease_id": leaseID,
"model": stdReq.ResponseModel,
"final_prompt": stdReq.FinalPrompt,
"thinking_enabled": stdReq.Thinking,
"search_enabled": stdReq.Search,
"deepseek_token": a.DeepSeekToken,
"pow_header": powHeader,
"payload": payload,
"session_id": sessionID,
"lease_id": leaseID,
"model": stdReq.ResponseModel,
"final_prompt": stdReq.FinalPrompt,
"thinking_enabled": stdReq.Thinking,
"search_enabled": stdReq.Search,
"tool_names": stdReq.ToolNames,
"toolcall_feature_match": h.toolcallFeatureMatchEnabled(),
"toolcall_early_emit_high": h.toolcallEarlyEmitHighConfidence(),
"deepseek_token": a.DeepSeekToken,
"pow_header": powHeader,
"payload": payload,
})
}

View File

@@ -83,6 +83,26 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
return a, nil
}
// DetermineCaller resolves caller identity without acquiring any pooled account.
// Use this for local-cache lookup routes that only need tenant isolation.
func (r *Resolver) DetermineCaller(req *http.Request) (*RequestAuth, error) {
	token := extractCallerToken(req)
	if token == "" {
		return nil, ErrUnauthorized
	}
	ra := &RequestAuth{
		UseConfigToken: false,
		CallerID:       callerTokenID(token),
		resolver:       r,
		TriedAccounts:  map[string]bool{},
	}
	// A key managed by the store is not itself a DeepSeek token; anything
	// else is passed straight through as a direct token.
	managed := r != nil && r.Store != nil && r.Store.HasAPIKey(token)
	if !managed {
		ra.DeepSeekToken = token
	}
	return ra, nil
}
// WithAuth returns a child context carrying the resolved request auth under
// authCtxKey, so downstream handlers can read it back without re-resolving.
func WithAuth(ctx context.Context, a *RequestAuth) context.Context {
return context.WithValue(ctx, authCtxKey, a)
}

View File

@@ -66,6 +66,26 @@ func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) {
}
}
// Caller-only auth with a managed key must identify the tenant (CallerID)
// while leaving config-token and account-lease state untouched.
func TestDetermineCallerWithManagedKeySkipsAccountAcquire(t *testing.T) {
	resolver := newTestResolver(t)
	req, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil)
	req.Header.Set("x-api-key", "managed-key")
	got, err := resolver.DetermineCaller(req)
	if err != nil {
		t.Fatalf("determine caller failed: %v", err)
	}
	if got.CallerID == "" {
		t.Fatalf("expected caller id to be populated")
	}
	if got.UseConfigToken {
		t.Fatalf("expected no config-token lease for caller-only auth")
	}
	if got.AccountID != "" {
		t.Fatalf("expected empty account id, got %q", got.AccountID)
	}
}
func TestCallerTokenIDStable(t *testing.T) {
a := callerTokenID("token-a")
b := callerTokenID("token-a")
@@ -93,3 +113,16 @@ func TestDetermineMissingToken(t *testing.T) {
t.Fatalf("unexpected error: %v", err)
}
}
// A request with no credentials must be rejected with the ErrUnauthorized
// sentinel, not some wrapped or generic error.
func TestDetermineCallerMissingToken(t *testing.T) {
	resolver := newTestResolver(t)
	req, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil)
	_, err := resolver.DetermineCaller(req)
	switch {
	case err == nil:
		t.Fatal("expected unauthorized error")
	case err != ErrUnauthorized:
		t.Fatalf("unexpected error: %v", err)
	}
}

View File

@@ -73,7 +73,7 @@ type Config struct {
}
// CompatConfig holds opt-out compatibility switches. WideInputStrictOutput is
// a *bool so "unset" (nil, defaults to enabled) is distinguishable from an
// explicit false; the old plain-bool field was left in by a bad merge,
// duplicating the field name, and is removed here.
type CompatConfig struct {
	WideInputStrictOutput *bool `json:"wide_input_strict_output,omitempty"`
}
type ToolcallConfig struct {
@@ -109,7 +109,7 @@ func (c Config) MarshalJSON() ([]byte, error) {
if len(c.ModelAliases) > 0 {
m["model_aliases"] = c.ModelAliases
}
if c.Compat.WideInputStrictOutput {
if c.Compat.WideInputStrictOutput != nil {
m["compat"] = c.Compat
}
if strings.TrimSpace(c.Toolcall.Mode) != "" || strings.TrimSpace(c.Toolcall.EarlyEmitConfidence) != "" {
@@ -194,12 +194,14 @@ func (c *Config) UnmarshalJSON(b []byte) error {
func (c Config) Clone() Config {
clone := Config{
Keys: slices.Clone(c.Keys),
Accounts: slices.Clone(c.Accounts),
ClaudeMapping: cloneStringMap(c.ClaudeMapping),
ClaudeModelMap: cloneStringMap(c.ClaudeModelMap),
ModelAliases: cloneStringMap(c.ModelAliases),
Compat: c.Compat,
Keys: slices.Clone(c.Keys),
Accounts: slices.Clone(c.Accounts),
ClaudeMapping: cloneStringMap(c.ClaudeMapping),
ClaudeModelMap: cloneStringMap(c.ClaudeModelMap),
ModelAliases: cloneStringMap(c.ModelAliases),
Compat: CompatConfig{
WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput),
},
Toolcall: c.Toolcall,
Responses: c.Responses,
Embeddings: c.Embeddings,
@@ -224,6 +226,14 @@ func cloneStringMap(in map[string]string) map[string]string {
return out
}
// cloneBoolPtr returns nil for nil input, otherwise a pointer to a fresh
// copy of the pointed-to value so the clone never aliases the original.
func cloneBoolPtr(in *bool) *bool {
	if in == nil {
		return nil
	}
	copied := *in
	return &copied
}
type Store struct {
mu sync.RWMutex
cfg Config
@@ -569,9 +579,12 @@ func (s *Store) ModelAliases() map[string]string {
}
// CompatWideInputStrictOutput reports whether wide-input/strict-output mode
// is active. Unset (nil) defaults to true; an explicit false disables it.
// The previous hard-coded `return true` (from the removed version of this
// method) had been left in above the real body by a bad merge, making the
// config lookup unreachable; it is deleted here.
func (s *Store) CompatWideInputStrictOutput() bool {
	s.mu.RLock()
	defer s.mu.RUnlock()
	if s.cfg.Compat.WideInputStrictOutput == nil {
		return true
	}
	return *s.cfg.Compat.WideInputStrictOutput
}
func (s *Store) ToolcallMode() string {

View File

@@ -320,6 +320,39 @@ func TestStoreFindAccountNotFound(t *testing.T) {
}
}
// When compat is absent from the config JSON, the policy must default to on.
func TestStoreCompatWideInputStrictOutputDefaultTrue(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`)
	s := LoadStore()
	if got := s.CompatWideInputStrictOutput(); !got {
		t.Fatal("expected default wide_input_strict_output=true when unset")
	}
}
// An explicit false must both disable the policy at runtime and survive a
// snapshot round-trip through MarshalJSON as a literal false.
func TestStoreCompatWideInputStrictOutputCanDisable(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[],"compat":{"wide_input_strict_output":false}}`)
	s := LoadStore()
	if s.CompatWideInputStrictOutput() {
		t.Fatal("expected wide_input_strict_output=false when explicitly configured")
	}
	snapshot := s.Snapshot()
	payload, err := snapshot.MarshalJSON()
	if err != nil {
		t.Fatalf("marshal failed: %v", err)
	}
	var decoded map[string]any
	if err := json.Unmarshal(payload, &decoded); err != nil {
		t.Fatalf("decode failed: %v", err)
	}
	compatSection, ok := decoded["compat"].(map[string]any)
	if !ok {
		t.Fatalf("expected compat in marshaled output, got %#v", decoded)
	}
	if compatSection["wide_input_strict_output"] != false {
		t.Fatalf("expected explicit false in compat, got %#v", compatSection)
	}
}
func TestStoreIsEnvBacked(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`)
store := LoadStore()