mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-04 16:35:27 +08:00
feat: Introduce DetermineCaller for auth without account pooling and make wide_input_strict_output configurable.
This commit is contained in:
@@ -85,7 +85,8 @@ module.exports = async function handler(req, res) {
|
||||
const finalPrompt = asString(prep.body.final_prompt);
|
||||
const thinkingEnabled = toBool(prep.body.thinking_enabled);
|
||||
const searchEnabled = toBool(prep.body.search_enabled);
|
||||
const toolNames = extractToolNames(payload.tools);
|
||||
const toolPolicy = resolveToolcallPolicy(prep.body, payload.tools);
|
||||
const toolNames = toolPolicy.toolNames;
|
||||
|
||||
if (!model || !leaseID || !deepseekToken || !powHeader || !completionPayload) {
|
||||
writeOpenAIError(res, 500, 'invalid vercel prepare response');
|
||||
@@ -156,7 +157,8 @@ module.exports = async function handler(req, res) {
|
||||
let currentType = thinkingEnabled ? 'thinking' : 'text';
|
||||
let thinkingText = '';
|
||||
let outputText = '';
|
||||
const toolSieveEnabled = toolNames.length > 0;
|
||||
const toolSieveEnabled = toolPolicy.toolSieveEnabled;
|
||||
const emitEarlyToolDeltas = toolPolicy.emitEarlyToolDeltas;
|
||||
const toolSieveState = createToolSieveState();
|
||||
let toolCallsEmitted = false;
|
||||
const streamToolCallIDs = new Map();
|
||||
@@ -297,6 +299,9 @@ module.exports = async function handler(req, res) {
|
||||
const events = processToolSieveChunk(toolSieveState, p.text, toolNames);
|
||||
for (const evt of events) {
|
||||
if (evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0) {
|
||||
if (!emitEarlyToolDeltas) {
|
||||
continue;
|
||||
}
|
||||
toolCallsEmitted = true;
|
||||
sendDeltaFrame({ tool_calls: formatIncrementalToolCallDeltas(evt.deltas, streamToolCallIDs) });
|
||||
continue;
|
||||
@@ -407,6 +412,37 @@ function relayPreparedFailure(res, prep) {
|
||||
writeOpenAIError(res, prep.status || 500, 'vercel prepare failed');
|
||||
}
|
||||
|
||||
function resolveToolcallPolicy(prepBody, payloadTools) {
|
||||
const preparedToolNames = normalizePreparedToolNames(prepBody && prepBody.tool_names);
|
||||
const toolNames = preparedToolNames.length > 0 ? preparedToolNames : extractToolNames(payloadTools);
|
||||
const featureMatchEnabled = boolDefaultTrue(prepBody && prepBody.toolcall_feature_match);
|
||||
const emitEarlyToolDeltas = boolDefaultTrue(prepBody && prepBody.toolcall_early_emit_high);
|
||||
return {
|
||||
toolNames,
|
||||
toolSieveEnabled: toolNames.length > 0 && featureMatchEnabled,
|
||||
emitEarlyToolDeltas,
|
||||
};
|
||||
}
|
||||
|
||||
function normalizePreparedToolNames(v) {
|
||||
if (!Array.isArray(v) || v.length === 0) {
|
||||
return [];
|
||||
}
|
||||
const out = [];
|
||||
for (const item of v) {
|
||||
const name = asString(item);
|
||||
if (!name) {
|
||||
continue;
|
||||
}
|
||||
out.push(name);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function boolDefaultTrue(v) {
|
||||
return v !== false;
|
||||
}
|
||||
|
||||
async function safeReadText(resp) {
|
||||
if (!resp) {
|
||||
return '';
|
||||
@@ -933,4 +969,7 @@ module.exports.__test = {
|
||||
extractContentRecursive,
|
||||
shouldSkipPath,
|
||||
asString,
|
||||
resolveToolcallPolicy,
|
||||
normalizePreparedToolNames,
|
||||
boolDefaultTrue,
|
||||
};
|
||||
|
||||
@@ -10,10 +10,50 @@ const {
|
||||
flushToolSieve,
|
||||
} = require('./helpers/stream-tool-sieve');
|
||||
|
||||
const { parseChunkForContent } = handler.__test;
|
||||
const {
|
||||
parseChunkForContent,
|
||||
resolveToolcallPolicy,
|
||||
normalizePreparedToolNames,
|
||||
boolDefaultTrue,
|
||||
} = handler.__test;
|
||||
|
||||
test('chat-stream exposes parser test hooks', () => {
|
||||
assert.equal(typeof parseChunkForContent, 'function');
|
||||
assert.equal(typeof resolveToolcallPolicy, 'function');
|
||||
});
|
||||
|
||||
test('resolveToolcallPolicy defaults to feature-match + early emit when prepare flags missing', () => {
|
||||
const policy = resolveToolcallPolicy(
|
||||
{},
|
||||
[{ type: 'function', function: { name: 'read_file', parameters: { type: 'object' } } }],
|
||||
);
|
||||
assert.deepEqual(policy.toolNames, ['read_file']);
|
||||
assert.equal(policy.toolSieveEnabled, true);
|
||||
assert.equal(policy.emitEarlyToolDeltas, true);
|
||||
});
|
||||
|
||||
test('resolveToolcallPolicy respects prepare flags and prepared tool names', () => {
|
||||
const policy = resolveToolcallPolicy(
|
||||
{
|
||||
tool_names: [' prepped_tool ', '', null],
|
||||
toolcall_feature_match: false,
|
||||
toolcall_early_emit_high: false,
|
||||
},
|
||||
[{ type: 'function', function: { name: 'fallback_tool', parameters: { type: 'object' } } }],
|
||||
);
|
||||
assert.deepEqual(policy.toolNames, ['prepped_tool']);
|
||||
assert.equal(policy.toolSieveEnabled, false);
|
||||
assert.equal(policy.emitEarlyToolDeltas, false);
|
||||
});
|
||||
|
||||
test('normalizePreparedToolNames filters empty values', () => {
|
||||
assert.deepEqual(normalizePreparedToolNames([' a ', '', null, 'b']), ['a', 'b']);
|
||||
});
|
||||
|
||||
test('boolDefaultTrue keeps false only when explicitly false', () => {
|
||||
assert.equal(boolDefaultTrue(false), false);
|
||||
assert.equal(boolDefaultTrue(true), true);
|
||||
assert.equal(boolDefaultTrue(undefined), true);
|
||||
});
|
||||
|
||||
test('parseChunkForContent keeps split response/content fragments inside response array', () => {
|
||||
|
||||
@@ -16,17 +16,11 @@ import (
|
||||
)
|
||||
|
||||
func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) {
|
||||
a, err := h.Auth.Determine(r)
|
||||
a, err := h.Auth.DetermineCaller(r)
|
||||
if err != nil {
|
||||
status := http.StatusUnauthorized
|
||||
detail := err.Error()
|
||||
if err == auth.ErrNoAccount {
|
||||
status = http.StatusTooManyRequests
|
||||
}
|
||||
writeOpenAIError(w, status, detail)
|
||||
writeOpenAIError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
|
||||
id := strings.TrimSpace(chi.URLParam(r, "response_id"))
|
||||
if id == "" {
|
||||
|
||||
@@ -26,6 +26,22 @@ func newDirectTokenResolver(t *testing.T) (*config.Store, *auth.Resolver) {
|
||||
return store, resolver
|
||||
}
|
||||
|
||||
func newManagedKeyResolver(t *testing.T) (*config.Store, *auth.Resolver) {
|
||||
t.Helper()
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["managed-key"],
|
||||
"accounts":[{"email":"acc@example.com","password":"pwd","token":"account-token"}]
|
||||
}`)
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1")
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_QUEUE", "0")
|
||||
store := config.LoadStore()
|
||||
pool := account.NewPool(store)
|
||||
resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
|
||||
return "unused", nil
|
||||
})
|
||||
return store, resolver
|
||||
}
|
||||
|
||||
func authForToken(t *testing.T, resolver *auth.Resolver, token string) *auth.RequestAuth {
|
||||
t.Helper()
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
|
||||
@@ -123,3 +139,38 @@ func TestResponsesRouteValidationContract(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetResponseByIDManagedKeySkipsAccountPoolPressure(t *testing.T) {
|
||||
store, resolver := newManagedKeyResolver(t)
|
||||
h := &Handler{Store: store, Auth: resolver}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
ownerReq := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
|
||||
ownerReq.Header.Set("Authorization", "Bearer managed-key")
|
||||
ownerAuth, err := resolver.DetermineCaller(ownerReq)
|
||||
if err != nil {
|
||||
t.Fatalf("determine caller failed: %v", err)
|
||||
}
|
||||
owner := responseStoreOwner(ownerAuth)
|
||||
h.getResponseStore().put(owner, "resp_test", map[string]any{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
})
|
||||
|
||||
occupyReq := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
occupyReq.Header.Set("Authorization", "Bearer managed-key")
|
||||
occupied, err := resolver.Determine(occupyReq)
|
||||
if err != nil {
|
||||
t.Fatalf("expected first acquire to succeed: %v", err)
|
||||
}
|
||||
defer resolver.Release(occupied)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
|
||||
req.Header.Set("Authorization", "Bearer managed-key")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 under pool pressure, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,15 +93,18 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
|
||||
}
|
||||
leased = true
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"session_id": sessionID,
|
||||
"lease_id": leaseID,
|
||||
"model": stdReq.ResponseModel,
|
||||
"final_prompt": stdReq.FinalPrompt,
|
||||
"thinking_enabled": stdReq.Thinking,
|
||||
"search_enabled": stdReq.Search,
|
||||
"deepseek_token": a.DeepSeekToken,
|
||||
"pow_header": powHeader,
|
||||
"payload": payload,
|
||||
"session_id": sessionID,
|
||||
"lease_id": leaseID,
|
||||
"model": stdReq.ResponseModel,
|
||||
"final_prompt": stdReq.FinalPrompt,
|
||||
"thinking_enabled": stdReq.Thinking,
|
||||
"search_enabled": stdReq.Search,
|
||||
"tool_names": stdReq.ToolNames,
|
||||
"toolcall_feature_match": h.toolcallFeatureMatchEnabled(),
|
||||
"toolcall_early_emit_high": h.toolcallEarlyEmitHighConfidence(),
|
||||
"deepseek_token": a.DeepSeekToken,
|
||||
"pow_header": powHeader,
|
||||
"payload": payload,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -83,6 +83,26 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// DetermineCaller resolves caller identity without acquiring any pooled account.
|
||||
// Use this for local-cache lookup routes that only need tenant isolation.
|
||||
func (r *Resolver) DetermineCaller(req *http.Request) (*RequestAuth, error) {
|
||||
callerKey := extractCallerToken(req)
|
||||
if callerKey == "" {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
callerID := callerTokenID(callerKey)
|
||||
a := &RequestAuth{
|
||||
UseConfigToken: false,
|
||||
CallerID: callerID,
|
||||
resolver: r,
|
||||
TriedAccounts: map[string]bool{},
|
||||
}
|
||||
if r == nil || r.Store == nil || !r.Store.HasAPIKey(callerKey) {
|
||||
a.DeepSeekToken = callerKey
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func WithAuth(ctx context.Context, a *RequestAuth) context.Context {
|
||||
return context.WithValue(ctx, authCtxKey, a)
|
||||
}
|
||||
|
||||
@@ -66,6 +66,26 @@ func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetermineCallerWithManagedKeySkipsAccountAcquire(t *testing.T) {
|
||||
r := newTestResolver(t)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil)
|
||||
req.Header.Set("x-api-key", "managed-key")
|
||||
|
||||
a, err := r.DetermineCaller(req)
|
||||
if err != nil {
|
||||
t.Fatalf("determine caller failed: %v", err)
|
||||
}
|
||||
if a.CallerID == "" {
|
||||
t.Fatalf("expected caller id to be populated")
|
||||
}
|
||||
if a.UseConfigToken {
|
||||
t.Fatalf("expected no config-token lease for caller-only auth")
|
||||
}
|
||||
if a.AccountID != "" {
|
||||
t.Fatalf("expected empty account id, got %q", a.AccountID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallerTokenIDStable(t *testing.T) {
|
||||
a := callerTokenID("token-a")
|
||||
b := callerTokenID("token-a")
|
||||
@@ -93,3 +113,16 @@ func TestDetermineMissingToken(t *testing.T) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetermineCallerMissingToken(t *testing.T) {
|
||||
r := newTestResolver(t)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil)
|
||||
|
||||
_, err := r.DetermineCaller(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected unauthorized error")
|
||||
}
|
||||
if err != ErrUnauthorized {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ type Config struct {
|
||||
}
|
||||
|
||||
type CompatConfig struct {
|
||||
WideInputStrictOutput bool `json:"wide_input_strict_output,omitempty"`
|
||||
WideInputStrictOutput *bool `json:"wide_input_strict_output,omitempty"`
|
||||
}
|
||||
|
||||
type ToolcallConfig struct {
|
||||
@@ -109,7 +109,7 @@ func (c Config) MarshalJSON() ([]byte, error) {
|
||||
if len(c.ModelAliases) > 0 {
|
||||
m["model_aliases"] = c.ModelAliases
|
||||
}
|
||||
if c.Compat.WideInputStrictOutput {
|
||||
if c.Compat.WideInputStrictOutput != nil {
|
||||
m["compat"] = c.Compat
|
||||
}
|
||||
if strings.TrimSpace(c.Toolcall.Mode) != "" || strings.TrimSpace(c.Toolcall.EarlyEmitConfidence) != "" {
|
||||
@@ -194,12 +194,14 @@ func (c *Config) UnmarshalJSON(b []byte) error {
|
||||
|
||||
func (c Config) Clone() Config {
|
||||
clone := Config{
|
||||
Keys: slices.Clone(c.Keys),
|
||||
Accounts: slices.Clone(c.Accounts),
|
||||
ClaudeMapping: cloneStringMap(c.ClaudeMapping),
|
||||
ClaudeModelMap: cloneStringMap(c.ClaudeModelMap),
|
||||
ModelAliases: cloneStringMap(c.ModelAliases),
|
||||
Compat: c.Compat,
|
||||
Keys: slices.Clone(c.Keys),
|
||||
Accounts: slices.Clone(c.Accounts),
|
||||
ClaudeMapping: cloneStringMap(c.ClaudeMapping),
|
||||
ClaudeModelMap: cloneStringMap(c.ClaudeModelMap),
|
||||
ModelAliases: cloneStringMap(c.ModelAliases),
|
||||
Compat: CompatConfig{
|
||||
WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput),
|
||||
},
|
||||
Toolcall: c.Toolcall,
|
||||
Responses: c.Responses,
|
||||
Embeddings: c.Embeddings,
|
||||
@@ -224,6 +226,14 @@ func cloneStringMap(in map[string]string) map[string]string {
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneBoolPtr(in *bool) *bool {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
v := *in
|
||||
return &v
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
mu sync.RWMutex
|
||||
cfg Config
|
||||
@@ -569,9 +579,12 @@ func (s *Store) ModelAliases() map[string]string {
|
||||
}
|
||||
|
||||
func (s *Store) CompatWideInputStrictOutput() bool {
|
||||
// Current default policy is always wide-input / strict-output.
|
||||
// Kept as a method so callers do not depend on storage shape.
|
||||
return true
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.cfg.Compat.WideInputStrictOutput == nil {
|
||||
return true
|
||||
}
|
||||
return *s.cfg.Compat.WideInputStrictOutput
|
||||
}
|
||||
|
||||
func (s *Store) ToolcallMode() string {
|
||||
|
||||
@@ -320,6 +320,39 @@ func TestStoreFindAccountNotFound(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreCompatWideInputStrictOutputDefaultTrue(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`)
|
||||
store := LoadStore()
|
||||
if !store.CompatWideInputStrictOutput() {
|
||||
t.Fatal("expected default wide_input_strict_output=true when unset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreCompatWideInputStrictOutputCanDisable(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[],"compat":{"wide_input_strict_output":false}}`)
|
||||
store := LoadStore()
|
||||
if store.CompatWideInputStrictOutput() {
|
||||
t.Fatal("expected wide_input_strict_output=false when explicitly configured")
|
||||
}
|
||||
|
||||
snap := store.Snapshot()
|
||||
data, err := snap.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("decode failed: %v", err)
|
||||
}
|
||||
rawCompat, ok := out["compat"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected compat in marshaled output, got %#v", out)
|
||||
}
|
||||
if rawCompat["wide_input_strict_output"] != false {
|
||||
t.Fatalf("expected explicit false in compat, got %#v", rawCompat)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreIsEnvBacked(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[]}`)
|
||||
store := LoadStore()
|
||||
|
||||
Reference in New Issue
Block a user