mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-05 00:45:29 +08:00
feat: Introduce standard request normalization and response building for OpenAI and Claude, enhance tool call streaming, and improve caller identification.
This commit is contained in:
@@ -309,7 +309,7 @@ data: [DONE]
|
||||
|
||||
### `GET /v1/responses/{response_id}`
|
||||
|
||||
Business auth required. Fetches cached responses created by `POST /v1/responses`.
|
||||
Business auth required. Fetches cached responses created by `POST /v1/responses` (caller-scoped; only the same key/token can read).
|
||||
|
||||
> Backed by in-memory TTL store. Default TTL is `900s` (configurable via `responses.store_ttl_seconds`).
|
||||
|
||||
|
||||
2
API.md
2
API.md
@@ -309,7 +309,7 @@ data: [DONE]
|
||||
|
||||
### `GET /v1/responses/{response_id}`
|
||||
|
||||
需要业务鉴权。查询 `POST /v1/responses` 生成并缓存的 response 对象。
|
||||
需要业务鉴权。查询 `POST /v1/responses` 生成并缓存的 response 对象(按调用方鉴权隔离,仅同一 key/token 可读取)。
|
||||
|
||||
> 当前为内存 TTL 存储,默认过期时间 `900s`(可用 `responses.store_ttl_seconds` 调整)。
|
||||
|
||||
|
||||
@@ -63,32 +63,12 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||
writeClaudeError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
messagesRaw, _ := req["messages"].([]any)
|
||||
if model == "" || len(messagesRaw) == 0 {
|
||||
writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
|
||||
norm, err := normalizeClaudeRequest(h.Store, req)
|
||||
if err != nil {
|
||||
writeClaudeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
if _, ok := req["max_tokens"]; !ok {
|
||||
req["max_tokens"] = 8192
|
||||
}
|
||||
|
||||
normalized := normalizeClaudeMessages(messagesRaw)
|
||||
payload := cloneMap(req)
|
||||
payload["messages"] = normalized
|
||||
toolsRequested, _ := req["tools"].([]any)
|
||||
if len(toolsRequested) > 0 && !hasSystemMessage(normalized) {
|
||||
payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalized...)
|
||||
}
|
||||
|
||||
dsPayload := util.ConvertClaudeToDeepSeek(payload, h.Store)
|
||||
dsModel, _ := dsPayload["model"].(string)
|
||||
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel)
|
||||
if !ok {
|
||||
thinkingEnabled = false
|
||||
searchEnabled = false
|
||||
}
|
||||
finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
|
||||
stdReq := norm.Standard
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
@@ -100,14 +80,7 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||
writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW")
|
||||
return
|
||||
}
|
||||
requestPayload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
"parent_message_id": nil,
|
||||
"prompt": finalPrompt,
|
||||
"ref_file_ids": []any{},
|
||||
"thinking_enabled": thinkingEnabled,
|
||||
"search_enabled": searchEnabled,
|
||||
}
|
||||
requestPayload := stdReq.CompletionPayload(sessionID)
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3)
|
||||
if err != nil {
|
||||
writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.")
|
||||
@@ -120,15 +93,14 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
toolNames := extractClaudeToolNames(toolsRequested)
|
||||
if util.ToBool(req["stream"]) {
|
||||
h.handleClaudeStreamRealtime(w, r, resp, model, normalized, thinkingEnabled, searchEnabled, toolNames)
|
||||
if stdReq.Stream {
|
||||
h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
|
||||
return
|
||||
}
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
result := sse.CollectStream(resp, stdReq.Thinking, true)
|
||||
fullText := result.Text
|
||||
fullThinking := result.Thinking
|
||||
detected := util.ParseToolCalls(fullText, toolNames)
|
||||
detected := util.ParseToolCalls(fullText, stdReq.ToolNames)
|
||||
content := make([]map[string]any, 0, 4)
|
||||
if fullThinking != "" {
|
||||
content = append(content, map[string]any{"type": "thinking", "thinking": fullThinking})
|
||||
@@ -154,12 +126,12 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||
"id": fmt.Sprintf("msg_%d", time.Now().UnixNano()),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"model": stdReq.ResponseModel,
|
||||
"content": content,
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]any{
|
||||
"input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalized)),
|
||||
"input_tokens": util.EstimateTokens(fmt.Sprintf("%v", norm.NormalizedMessages)),
|
||||
"output_tokens": util.EstimateTokens(fullThinking) + util.EstimateTokens(fullText),
|
||||
},
|
||||
})
|
||||
|
||||
58
internal/adapter/claude/standard_request.go
Normal file
58
internal/adapter/claude/standard_request.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package claude

import (
	"errors"
	"strings"

	"ds2api/internal/config"
	"ds2api/internal/util"
)

// errClaudeMissingFields is the client-facing validation error; the exact
// text is written to the HTTP response via writeClaudeError, so it must stay
// stable. Using a sentinel (instead of fmt.Errorf with no format directives)
// also lets callers match it with errors.Is.
var errClaudeMissingFields = errors.New("Request must include 'model' and 'messages'.")

// claudeNormalizedRequest pairs the surface-independent StandardRequest with
// the Claude-specific normalized message list, which the handler still needs
// for token estimation and streaming.
type claudeNormalizedRequest struct {
	// Standard is the normalized, surface-agnostic view of the request.
	Standard util.StandardRequest
	// NormalizedMessages is the Claude message list after normalization,
	// before any synthetic tool system prompt is prepended.
	NormalizedMessages []any
}

// normalizeClaudeRequest validates a raw Claude /v1/messages request body and
// converts it into a StandardRequest ready for the DeepSeek completion call.
//
// Side effect: when the caller omitted "max_tokens", the default 8192 is
// written back into req before the payload is cloned.
//
// Returns an error (client-facing message) when "model" or "messages" is
// missing or empty.
func normalizeClaudeRequest(store *config.Store, req map[string]any) (claudeNormalizedRequest, error) {
	model, _ := req["model"].(string)
	messagesRaw, _ := req["messages"].([]any)
	if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 {
		return claudeNormalizedRequest{}, errClaudeMissingFields
	}
	// Claude requires max_tokens; apply the default when the caller omitted it.
	if _, ok := req["max_tokens"]; !ok {
		req["max_tokens"] = 8192
	}
	normalizedMessages := normalizeClaudeMessages(messagesRaw)
	payload := cloneMap(req)
	payload["messages"] = normalizedMessages
	toolsRequested, _ := req["tools"].([]any)
	// When tools are requested and no system message exists, inject a
	// synthetic system prompt that describes the tools.
	if len(toolsRequested) > 0 && !hasSystemMessage(normalizedMessages) {
		payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalizedMessages...)
	}

	dsPayload := util.ConvertClaudeToDeepSeek(payload, store)
	dsModel, _ := dsPayload["model"].(string)
	thinkingEnabled, searchEnabled, ok := config.GetModelConfig(dsModel)
	if !ok {
		// Unknown model: fall back to the most conservative feature set.
		thinkingEnabled = false
		searchEnabled = false
	}
	finalPrompt := util.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
	toolNames := extractClaudeToolNames(toolsRequested)

	return claudeNormalizedRequest{
		Standard: util.StandardRequest{
			Surface:        "anthropic_messages",
			RequestedModel: strings.TrimSpace(model),
			ResolvedModel:  dsModel,
			// Claude responses echo the model the caller asked for.
			ResponseModel: strings.TrimSpace(model),
			Messages:      payload["messages"].([]any),
			FinalPrompt:   finalPrompt,
			ToolNames:     toolNames,
			Stream:        util.ToBool(req["stream"]),
			Thinking:      thinkingEnabled,
			Search:        searchEnabled,
		},
		NormalizedMessages: normalizedMessages,
	}, nil
}
|
||||
38
internal/adapter/claude/standard_request_test.go
Normal file
38
internal/adapter/claude/standard_request_test.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package claude

import (
	"testing"

	"ds2api/internal/config"
)

// TestNormalizeClaudeRequest exercises the happy path: a streaming request
// with one user message and one tool must normalize without error and carry
// a resolved model, the stream flag, tool names, and a non-empty prompt.
func TestNormalizeClaudeRequest(t *testing.T) {
	t.Setenv("DS2API_CONFIG_JSON", `{}`)
	store := config.LoadStore()

	body := map[string]any{
		"model":  "claude-opus-4-6",
		"stream": true,
		"messages": []any{
			map[string]any{"role": "user", "content": "hello"},
		},
		"tools": []any{
			map[string]any{"name": "search", "description": "Search"},
		},
	}

	got, err := normalizeClaudeRequest(store, body)
	if err != nil {
		t.Fatalf("normalize failed: %v", err)
	}

	std := got.Standard
	switch {
	case std.ResolvedModel == "":
		t.Fatalf("expected resolved model")
	case !std.Stream:
		t.Fatalf("expected stream=true")
	case len(std.ToolNames) == 0:
		t.Fatalf("expected tool names")
	case std.FinalPrompt == "":
		t.Fatalf("expected non-empty final prompt")
	}
}
|
||||
96
internal/adapter/openai/embeddings_route_test.go
Normal file
96
internal/adapter/openai/embeddings_route_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package openai

import (
	"bytes"
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/go-chi/chi/v5"

	"ds2api/internal/account"
	"ds2api/internal/auth"
	"ds2api/internal/config"
)

// newResolverWithConfigJSON builds a config store and auth resolver from the
// given config JSON, with the account session provider stubbed out.
func newResolverWithConfigJSON(t *testing.T, cfgJSON string) (*config.Store, *auth.Resolver) {
	t.Helper()
	t.Setenv("DS2API_CONFIG_JSON", cfgJSON)
	store := config.LoadStore()
	pool := account.NewPool(store)
	resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
		return "unused", nil
	})
	return store, resolver
}

// postEmbeddings posts rawBody to /v1/embeddings with an optional bearer
// token and returns the recorder for assertions.
func postEmbeddings(router http.Handler, token, rawBody string) *httptest.ResponseRecorder {
	req := httptest.NewRequest(http.MethodPost, "/v1/embeddings", bytes.NewBufferString(rawBody))
	req.Header.Set("Content-Type", "application/json")
	if token != "" {
		req.Header.Set("Authorization", "Bearer "+token)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	return rec
}

// TestEmbeddingsRouteContract checks the route's auth and response-shape
// contract when a deterministic embeddings provider is configured.
func TestEmbeddingsRouteContract(t *testing.T) {
	store, resolver := newResolverWithConfigJSON(t, `{"embeddings":{"provider":"deterministic"}}`)
	h := &Handler{Store: store, Auth: resolver}
	router := chi.NewRouter()
	RegisterRoutes(router, h)

	t.Run("unauthorized", func(t *testing.T) {
		rec := postEmbeddings(router, "", `{"model":"gpt-4o","input":"hello"}`)
		if rec.Code != http.StatusUnauthorized {
			t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String())
		}
	})

	t.Run("ok", func(t *testing.T) {
		rec := postEmbeddings(router, "test-token", `{"model":"gpt-4o","input":["a","b"]}`)
		if rec.Code != http.StatusOK {
			t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
		}
		var out map[string]any
		if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
			t.Fatalf("decode response failed: %v", err)
		}
		if out["object"] != "list" {
			t.Fatalf("unexpected object: %#v", out["object"])
		}
		data, _ := out["data"].([]any)
		if len(data) != 2 {
			t.Fatalf("expected 2 embeddings, got %d", len(data))
		}
	})
}

// TestEmbeddingsRouteProviderMissing checks that an unconfigured provider
// yields 501 with an OpenAI-shaped error object (code and param present).
func TestEmbeddingsRouteProviderMissing(t *testing.T) {
	store, resolver := newResolverWithConfigJSON(t, `{}`)
	h := &Handler{Store: store, Auth: resolver}
	router := chi.NewRouter()
	RegisterRoutes(router, h)

	rec := postEmbeddings(router, "test-token", `{"model":"gpt-4o","input":"hello"}`)
	if rec.Code != http.StatusNotImplemented {
		t.Fatalf("expected 501, got %d body=%s", rec.Code, rec.Body.String())
	}
	var out map[string]any
	if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
		t.Fatalf("decode response failed: %v", err)
	}
	errObj, _ := out["error"].(map[string]any)
	if _, ok := errObj["code"]; !ok {
		t.Fatalf("expected error.code in response: %#v", out)
	}
	if _, ok := errObj["param"]; !ok {
		t.Fatalf("expected error.param in response: %#v", out)
	}
}
|
||||
@@ -91,24 +91,11 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
messagesRaw, _ := req["messages"].([]any)
|
||||
if model == "" || len(messagesRaw) == 0 {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
|
||||
stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
resolvedModel, ok := config.ResolveModel(h.Store, model)
|
||||
if !ok {
|
||||
writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model))
|
||||
return
|
||||
}
|
||||
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
|
||||
responseModel := strings.TrimSpace(model)
|
||||
if responseModel == "" {
|
||||
responseModel = resolvedModel
|
||||
}
|
||||
|
||||
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
@@ -124,25 +111,17 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
|
||||
return
|
||||
}
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
"parent_message_id": nil,
|
||||
"prompt": finalPrompt,
|
||||
"ref_file_ids": []any{},
|
||||
"thinking_enabled": thinkingEnabled,
|
||||
"search_enabled": searchEnabled,
|
||||
}
|
||||
applyOpenAIChatPassThrough(req, payload)
|
||||
payload := stdReq.CompletionPayload(sessionID)
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
|
||||
return
|
||||
}
|
||||
if util.ToBool(req["stream"]) {
|
||||
h.handleStream(w, r, resp, sessionID, responseModel, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
|
||||
if stdReq.Stream {
|
||||
h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
|
||||
return
|
||||
}
|
||||
h.handleNonStream(w, r.Context(), resp, sessionID, responseModel, finalPrompt, thinkingEnabled, toolNames)
|
||||
h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
|
||||
@@ -208,7 +187,8 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
|
||||
created := time.Now().Unix()
|
||||
firstChunkSent := false
|
||||
bufferToolContent := len(toolNames) > 0
|
||||
bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
|
||||
emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
|
||||
var toolSieve toolStreamSieveState
|
||||
toolCallsEmitted := false
|
||||
streamToolCallIDs := map[int]string{}
|
||||
@@ -377,6 +357,9 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
}
|
||||
for _, evt := range events {
|
||||
if len(evt.ToolCallDeltas) > 0 {
|
||||
if !emitEarlyToolDeltas {
|
||||
continue
|
||||
}
|
||||
toolCallsEmitted = true
|
||||
tcDelta := map[string]any{
|
||||
"tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs),
|
||||
@@ -568,17 +551,23 @@ func openAIErrorCode(status int) string {
|
||||
}
|
||||
|
||||
func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) {
|
||||
for _, k := range []string{
|
||||
"temperature",
|
||||
"top_p",
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"stop",
|
||||
} {
|
||||
if v, ok := req[k]; ok {
|
||||
payload[k] = v
|
||||
}
|
||||
for k, v := range collectOpenAIChatPassThrough(req) {
|
||||
payload[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) toolcallFeatureMatchEnabled() bool {
|
||||
if h == nil || h.Store == nil {
|
||||
return true
|
||||
}
|
||||
mode := strings.TrimSpace(strings.ToLower(h.Store.ToolcallMode()))
|
||||
return mode == "" || mode == "feature_match"
|
||||
}
|
||||
|
||||
func (h *Handler) toolcallEarlyEmitHighConfidence() bool {
|
||||
if h == nil || h.Store == nil {
|
||||
return true
|
||||
}
|
||||
level := strings.TrimSpace(strings.ToLower(h.Store.ToolcallEarlyEmitConfidence()))
|
||||
return level == "" || level == "high"
|
||||
}
|
||||
|
||||
@@ -3,9 +3,12 @@ package openai
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
)
|
||||
|
||||
// storedResponse is one cached /v1/responses object, scoped to the caller
// that created it.
type storedResponse struct {
	// Owner is the caller identity allowed to read this entry.
	Owner string
	// Value is the cached response object.
	Value map[string]any
	// ExpiresAt marks when the entry becomes eligible for sweeping.
	ExpiresAt time.Time
}
|
||||
@@ -26,32 +29,47 @@ func newResponseStore(ttl time.Duration) *responseStore {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responseStore) put(id string, value map[string]any) {
|
||||
if s == nil || id == "" || value == nil {
|
||||
func responseStoreKey(owner, id string) string {
|
||||
return owner + "\x00" + id
|
||||
}
|
||||
|
||||
func responseStoreOwner(a *auth.RequestAuth) string {
|
||||
if a == nil {
|
||||
return ""
|
||||
}
|
||||
return a.CallerID
|
||||
}
|
||||
|
||||
func (s *responseStore) put(owner, id string, value map[string]any) {
|
||||
if s == nil || owner == "" || id == "" || value == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sweepLocked(now)
|
||||
s.items[id] = storedResponse{
|
||||
s.items[responseStoreKey(owner, id)] = storedResponse{
|
||||
Owner: owner,
|
||||
Value: cloneAnyMap(value),
|
||||
ExpiresAt: now.Add(s.ttl),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responseStore) get(id string) (map[string]any, bool) {
|
||||
if s == nil || id == "" {
|
||||
func (s *responseStore) get(owner, id string) (map[string]any, bool) {
|
||||
if s == nil || owner == "" || id == "" {
|
||||
return nil, false
|
||||
}
|
||||
now := time.Now()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sweepLocked(now)
|
||||
item, ok := s.items[id]
|
||||
item, ok := s.items[responseStoreKey(owner, id)]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if item.Owner != owner {
|
||||
return nil, false
|
||||
}
|
||||
return cloneAnyMap(item.Value), true
|
||||
}
|
||||
|
||||
|
||||
@@ -54,8 +54,8 @@ func TestDeterministicEmbeddingStable(t *testing.T) {
|
||||
|
||||
func TestResponseStorePutGet(t *testing.T) {
|
||||
st := newResponseStore(100 * time.Millisecond)
|
||||
st.put("resp_1", map[string]any{"id": "resp_1"})
|
||||
got, ok := st.get("resp_1")
|
||||
st.put("owner_1", "resp_1", map[string]any{"id": "resp_1"})
|
||||
got, ok := st.get("owner_1", "resp_1")
|
||||
if !ok {
|
||||
t.Fatal("expected stored response")
|
||||
}
|
||||
@@ -63,3 +63,11 @@ func TestResponseStorePutGet(t *testing.T) {
|
||||
t.Fatalf("unexpected response payload: %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseStoreTenantIsolation(t *testing.T) {
|
||||
st := newResponseStore(100 * time.Millisecond)
|
||||
st.put("owner_a", "resp_1", map[string]any{"id": "resp_1"})
|
||||
if _, ok := st.get("owner_b", "resp_1"); ok {
|
||||
t.Fatal("expected owner_b to be isolated from owner_a response")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,19 +12,35 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/sse"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func (h *Handler) GetResponseByID(w http.ResponseWriter, r *http.Request) {
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
status := http.StatusUnauthorized
|
||||
detail := err.Error()
|
||||
if err == auth.ErrNoAccount {
|
||||
status = http.StatusTooManyRequests
|
||||
}
|
||||
writeOpenAIError(w, status, detail)
|
||||
return
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
|
||||
id := strings.TrimSpace(chi.URLParam(r, "response_id"))
|
||||
if id == "" {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "response_id is required.")
|
||||
return
|
||||
}
|
||||
owner := responseStoreOwner(a)
|
||||
if owner == "" {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "unauthorized")
|
||||
return
|
||||
}
|
||||
st := h.getResponseStore()
|
||||
item, ok := st.get(id)
|
||||
item, ok := st.get(owner, id)
|
||||
if !ok {
|
||||
writeOpenAIError(w, http.StatusNotFound, "Response not found.")
|
||||
return
|
||||
@@ -45,32 +61,22 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
defer h.Auth.Release(a)
|
||||
r = r.WithContext(auth.WithAuth(r.Context(), a))
|
||||
owner := responseStoreOwner(a)
|
||||
if owner == "" {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
|
||||
model, _ := req["model"].(string)
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model'.")
|
||||
stdReq, err := normalizeOpenAIResponsesRequest(h.Store, req)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
resolvedModel, ok := config.ResolveModel(h.Store, model)
|
||||
if !ok {
|
||||
writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model))
|
||||
return
|
||||
}
|
||||
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
|
||||
|
||||
messagesRaw := responsesMessagesFromRequest(req)
|
||||
if len(messagesRaw) == 0 {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'input' or 'messages'.")
|
||||
return
|
||||
}
|
||||
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
@@ -86,15 +92,7 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
|
||||
return
|
||||
}
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
"parent_message_id": nil,
|
||||
"prompt": finalPrompt,
|
||||
"ref_file_ids": []any{},
|
||||
"thinking_enabled": thinkingEnabled,
|
||||
"search_enabled": searchEnabled,
|
||||
}
|
||||
applyOpenAIChatPassThrough(req, payload)
|
||||
payload := stdReq.CompletionPayload(sessionID)
|
||||
resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
|
||||
@@ -102,14 +100,14 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
if util.ToBool(req["stream"]) {
|
||||
h.handleResponsesStream(w, r, resp, responseID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
|
||||
if stdReq.Stream {
|
||||
h.handleResponsesStream(w, r, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
|
||||
return
|
||||
}
|
||||
h.handleResponsesNonStream(w, resp, responseID, model, finalPrompt, thinkingEnabled, toolNames)
|
||||
h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
|
||||
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -118,11 +116,11 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
|
||||
}
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
responseObj := buildResponseObject(responseID, model, finalPrompt, result.Thinking, result.Text, toolNames)
|
||||
h.getResponseStore().put(responseID, responseObj)
|
||||
h.getResponseStore().put(owner, responseID, responseObj)
|
||||
writeJSON(w, http.StatusOK, responseObj)
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -160,7 +158,8 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
|
||||
initialType = "thinking"
|
||||
}
|
||||
parsedLines, done := sse.StartParsedLinePump(r.Context(), resp.Body, thinkingEnabled, initialType)
|
||||
bufferToolContent := len(toolNames) > 0
|
||||
bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
|
||||
emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
|
||||
var sieve toolStreamSieveState
|
||||
thinking := strings.Builder{}
|
||||
text := strings.Builder{}
|
||||
@@ -194,7 +193,7 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
|
||||
if toolCallsEmitted {
|
||||
obj["status"] = "completed"
|
||||
}
|
||||
h.getResponseStore().put(responseID, obj)
|
||||
h.getResponseStore().put(owner, responseID, obj)
|
||||
sendEvent("response.completed", map[string]any{
|
||||
"type": "response.completed",
|
||||
"response": obj,
|
||||
@@ -259,6 +258,9 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
|
||||
})
|
||||
}
|
||||
if len(evt.ToolCallDeltas) > 0 {
|
||||
if !emitEarlyToolDeltas {
|
||||
continue
|
||||
}
|
||||
toolCallsEmitted = true
|
||||
sendEvent("response.output_tool_call.delta", map[string]any{
|
||||
"type": "response.output_tool_call.delta",
|
||||
|
||||
125
internal/adapter/openai/responses_route_test.go
Normal file
125
internal/adapter/openai/responses_route_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package openai

import (
	"bytes"
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/go-chi/chi/v5"

	"ds2api/internal/account"
	"ds2api/internal/auth"
	"ds2api/internal/config"
)

// newDirectTokenResolver builds an empty-config store plus an auth resolver
// whose account session provider is stubbed out, so bearer tokens act as
// direct callers.
func newDirectTokenResolver(t *testing.T) (*config.Store, *auth.Resolver) {
	t.Helper()
	t.Setenv("DS2API_CONFIG_JSON", `{"keys":[],"accounts":[]}`)
	store := config.LoadStore()
	pool := account.NewPool(store)
	resolver := auth.NewResolver(store, pool, func(_ context.Context, _ config.Account) (string, error) {
		return "unused", nil
	})
	return store, resolver
}

// authForToken resolves the RequestAuth a given bearer token maps to.
func authForToken(t *testing.T, resolver *auth.Resolver, token string) *auth.RequestAuth {
	t.Helper()
	probe := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
	probe.Header.Set("Authorization", "Bearer "+token)
	a, err := resolver.Determine(probe)
	if err != nil {
		t.Fatalf("determine auth failed: %v", err)
	}
	return a
}

// fetchResponseByID issues GET /v1/responses/resp_test with an optional
// bearer token and returns the recorder.
func fetchResponseByID(router http.Handler, token string) *httptest.ResponseRecorder {
	req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test", nil)
	if token != "" {
		req.Header.Set("Authorization", "Bearer "+token)
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	return rec
}

// TestGetResponseByIDRequiresAuthAndIsTenantIsolated checks the read path:
// no token -> 401, wrong tenant -> 404, owning tenant -> 200 with the cached
// object.
func TestGetResponseByIDRequiresAuthAndIsTenantIsolated(t *testing.T) {
	store, resolver := newDirectTokenResolver(t)
	h := &Handler{Store: store, Auth: resolver}
	router := chi.NewRouter()
	RegisterRoutes(router, h)

	ownerA := responseStoreOwner(authForToken(t, resolver, "token-a"))
	h.getResponseStore().put(ownerA, "resp_test", map[string]any{
		"id":     "resp_test",
		"object": "response",
	})

	t.Run("unauthorized", func(t *testing.T) {
		rec := fetchResponseByID(router, "")
		if rec.Code != http.StatusUnauthorized {
			t.Fatalf("expected 401, got %d body=%s", rec.Code, rec.Body.String())
		}
	})

	t.Run("cross-tenant-not-found", func(t *testing.T) {
		rec := fetchResponseByID(router, "token-b")
		if rec.Code != http.StatusNotFound {
			t.Fatalf("expected 404, got %d body=%s", rec.Code, rec.Body.String())
		}
	})

	t.Run("same-tenant-ok", func(t *testing.T) {
		rec := fetchResponseByID(router, "token-a")
		if rec.Code != http.StatusOK {
			t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
		}
		var body map[string]any
		if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
			t.Fatalf("decode body failed: %v", err)
		}
		if body["id"] != "resp_test" {
			t.Fatalf("unexpected body: %#v", body)
		}
	})
}

// TestResponsesRouteValidationContract checks that validation failures on
// POST /v1/responses return 400 with an OpenAI-shaped error object.
func TestResponsesRouteValidationContract(t *testing.T) {
	store, resolver := newDirectTokenResolver(t)
	h := &Handler{Store: store, Auth: resolver}
	router := chi.NewRouter()
	RegisterRoutes(router, h)

	cases := []struct {
		name string
		body string
	}{
		{name: "missing_model", body: `{"input":"hello"}`},
		{name: "missing_input_and_messages", body: `{"model":"gpt-4o"}`},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewBufferString(tc.body))
			req.Header.Set("Authorization", "Bearer token-a")
			req.Header.Set("Content-Type", "application/json")
			rec := httptest.NewRecorder()
			router.ServeHTTP(rec, req)
			if rec.Code != http.StatusBadRequest {
				t.Fatalf("expected 400, got %d body=%s", rec.Code, rec.Body.String())
			}
			var out map[string]any
			if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
				t.Fatalf("decode response failed: %v", err)
			}
			errObj, _ := out["error"].(map[string]any)
			if _, ok := errObj["code"]; !ok {
				t.Fatalf("expected error.code: %#v", out)
			}
			if _, ok := errObj["param"]; !ok {
				t.Fatalf("expected error.param: %#v", out)
			}
		})
	}
}
|
||||
104
internal/adapter/openai/standard_request.go
Normal file
104
internal/adapter/openai/standard_request.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package openai

import (
	"errors"
	"fmt"
	"strings"

	"ds2api/internal/config"
	"ds2api/internal/util"
)

// errChatMissingFields is the client-facing validation error for chat
// completions; the exact text is returned in the HTTP error body, so it must
// stay stable. A sentinel (instead of fmt.Errorf with no format directives)
// also lets callers match it with errors.Is.
var errChatMissingFields = errors.New("Request must include 'model' and 'messages'.")

// normalizeOpenAIChatRequest validates a raw /v1/chat/completions body and
// converts it into a surface-independent StandardRequest.
//
// Errors carry client-facing messages: missing model/messages, or a model
// name that cannot be resolved against the configured model table.
func normalizeOpenAIChatRequest(store *config.Store, req map[string]any) (util.StandardRequest, error) {
	model, _ := req["model"].(string)
	messagesRaw, _ := req["messages"].([]any)
	if strings.TrimSpace(model) == "" || len(messagesRaw) == 0 {
		return util.StandardRequest{}, errChatMissingFields
	}
	resolvedModel, ok := config.ResolveModel(store, model)
	if !ok {
		return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", model)
	}
	thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
	// The guard above guarantees a non-empty trimmed model, so it is used
	// directly as the response model (the previous empty-string fallback to
	// resolvedModel was unreachable).
	responseModel := strings.TrimSpace(model)
	finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
	passThrough := collectOpenAIChatPassThrough(req)

	return util.StandardRequest{
		Surface:        "openai_chat",
		RequestedModel: responseModel,
		ResolvedModel:  resolvedModel,
		ResponseModel:  responseModel,
		Messages:       messagesRaw,
		FinalPrompt:    finalPrompt,
		ToolNames:      toolNames,
		Stream:         util.ToBool(req["stream"]),
		Thinking:       thinkingEnabled,
		Search:         searchEnabled,
		PassThrough:    passThrough,
	}, nil
}
|
||||
|
||||
func normalizeOpenAIResponsesRequest(store *config.Store, req map[string]any) (util.StandardRequest, error) {
|
||||
model, _ := req["model"].(string)
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return util.StandardRequest{}, fmt.Errorf("Request must include 'model'.")
|
||||
}
|
||||
resolvedModel, ok := config.ResolveModel(store, model)
|
||||
if !ok {
|
||||
return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", model)
|
||||
}
|
||||
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
|
||||
|
||||
// Keep width-control as an explicit policy hook even if current default is true.
|
||||
allowWideInput := true
|
||||
if store != nil {
|
||||
allowWideInput = store.CompatWideInputStrictOutput()
|
||||
}
|
||||
var messagesRaw []any
|
||||
if allowWideInput {
|
||||
messagesRaw = responsesMessagesFromRequest(req)
|
||||
} else if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 {
|
||||
messagesRaw = msgs
|
||||
}
|
||||
if len(messagesRaw) == 0 {
|
||||
return util.StandardRequest{}, fmt.Errorf("Request must include 'input' or 'messages'.")
|
||||
}
|
||||
finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
|
||||
passThrough := collectOpenAIChatPassThrough(req)
|
||||
|
||||
return util.StandardRequest{
|
||||
Surface: "openai_responses",
|
||||
RequestedModel: model,
|
||||
ResolvedModel: resolvedModel,
|
||||
ResponseModel: model,
|
||||
Messages: messagesRaw,
|
||||
FinalPrompt: finalPrompt,
|
||||
ToolNames: toolNames,
|
||||
Stream: util.ToBool(req["stream"]),
|
||||
Thinking: thinkingEnabled,
|
||||
Search: searchEnabled,
|
||||
PassThrough: passThrough,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// collectOpenAIChatPassThrough extracts the whitelisted OpenAI sampling
// parameters from the raw request so they can be forwarded to the upstream
// completion payload unchanged. Keys absent from the request are omitted;
// the returned map is never nil.
func collectOpenAIChatPassThrough(req map[string]any) map[string]any {
	allowed := map[string]struct{}{
		"temperature":           {},
		"top_p":                 {},
		"max_tokens":            {},
		"max_completion_tokens": {},
		"presence_penalty":      {},
		"frequency_penalty":     {},
		"stop":                  {},
	}
	out := map[string]any{}
	for key, value := range req {
		if _, ok := allowed[key]; ok {
			out[key] = value
		}
	}
	return out
}
|
||||
60
internal/adapter/openai/standard_request_test.go
Normal file
60
internal/adapter/openai/standard_request_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
// newEmptyStoreForNormalizeTest builds a config.Store backed by an empty JSON
// config, so model resolution in these tests exercises the default mappings.
func newEmptyStoreForNormalizeTest(t *testing.T) *config.Store {
	t.Helper()
	// t.Setenv restores the previous value automatically when the test ends.
	t.Setenv("DS2API_CONFIG_JSON", `{}`)
	return config.LoadStore()
}
|
||||
|
||||
func TestNormalizeOpenAIChatRequest(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
"model": "gpt-5-codex",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
"temperature": 0.3,
|
||||
"stream": true,
|
||||
}
|
||||
n, err := normalizeOpenAIChatRequest(store, req)
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
if n.ResolvedModel != "deepseek-reasoner" {
|
||||
t.Fatalf("unexpected resolved model: %s", n.ResolvedModel)
|
||||
}
|
||||
if !n.Stream {
|
||||
t.Fatalf("expected stream=true")
|
||||
}
|
||||
if _, ok := n.PassThrough["temperature"]; !ok {
|
||||
t.Fatalf("expected temperature passthrough")
|
||||
}
|
||||
if n.FinalPrompt == "" {
|
||||
t.Fatalf("expected non-empty final prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIResponsesRequestInput(t *testing.T) {
|
||||
store := newEmptyStoreForNormalizeTest(t)
|
||||
req := map[string]any{
|
||||
"model": "gpt-4o",
|
||||
"input": "ping",
|
||||
"instructions": "system",
|
||||
}
|
||||
n, err := normalizeOpenAIResponsesRequest(store, req)
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
if n.ResolvedModel != "deepseek-chat" {
|
||||
t.Fatalf("unexpected resolved model: %s", n.ResolvedModel)
|
||||
}
|
||||
if len(n.Messages) != 2 {
|
||||
t.Fatalf("expected 2 normalized messages, got %d", len(n.Messages))
|
||||
}
|
||||
}
|
||||
@@ -56,24 +56,15 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
|
||||
writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
|
||||
return
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
messagesRaw, _ := req["messages"].([]any)
|
||||
if model == "" || len(messagesRaw) == 0 {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
|
||||
stdReq, err := normalizeOpenAIChatRequest(h.Store, req)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
resolvedModel, ok := config.ResolveModel(h.Store, model)
|
||||
if !ok {
|
||||
writeOpenAIError(w, http.StatusBadRequest, fmt.Sprintf("Model '%s' is not available.", model))
|
||||
if !stdReq.Stream {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
|
||||
return
|
||||
}
|
||||
thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
|
||||
responseModel := strings.TrimSpace(model)
|
||||
if responseModel == "" {
|
||||
responseModel = resolvedModel
|
||||
}
|
||||
|
||||
finalPrompt, _ := buildOpenAIFinalPrompt(messagesRaw, req["tools"])
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
@@ -94,15 +85,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
|
||||
return
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
"parent_message_id": nil,
|
||||
"prompt": finalPrompt,
|
||||
"ref_file_ids": []any{},
|
||||
"thinking_enabled": thinkingEnabled,
|
||||
"search_enabled": searchEnabled,
|
||||
}
|
||||
applyOpenAIChatPassThrough(req, payload)
|
||||
payload := stdReq.CompletionPayload(sessionID)
|
||||
leaseID := h.holdStreamLease(a)
|
||||
if leaseID == "" {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "failed to create stream lease")
|
||||
@@ -112,10 +95,10 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"session_id": sessionID,
|
||||
"lease_id": leaseID,
|
||||
"model": responseModel,
|
||||
"final_prompt": finalPrompt,
|
||||
"thinking_enabled": thinkingEnabled,
|
||||
"search_enabled": searchEnabled,
|
||||
"model": stdReq.ResponseModel,
|
||||
"final_prompt": stdReq.FinalPrompt,
|
||||
"thinking_enabled": stdReq.Thinking,
|
||||
"search_enabled": stdReq.Search,
|
||||
"deepseek_token": a.DeepSeekToken,
|
||||
"pow_header": powHeader,
|
||||
"payload": payload,
|
||||
|
||||
@@ -2,6 +2,8 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -22,6 +24,7 @@ var (
|
||||
type RequestAuth struct {
|
||||
UseConfigToken bool
|
||||
DeepSeekToken string
|
||||
CallerID string
|
||||
AccountID string
|
||||
Account config.Account
|
||||
TriedAccounts map[string]bool
|
||||
@@ -45,9 +48,16 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
|
||||
if callerKey == "" {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
callerID := callerTokenID(callerKey)
|
||||
ctx := req.Context()
|
||||
if !r.Store.HasAPIKey(callerKey) {
|
||||
return &RequestAuth{UseConfigToken: false, DeepSeekToken: callerKey, resolver: r, TriedAccounts: map[string]bool{}}, nil
|
||||
return &RequestAuth{
|
||||
UseConfigToken: false,
|
||||
DeepSeekToken: callerKey,
|
||||
CallerID: callerID,
|
||||
resolver: r,
|
||||
TriedAccounts: map[string]bool{},
|
||||
}, nil
|
||||
}
|
||||
target := strings.TrimSpace(req.Header.Get("X-Ds2-Target-Account"))
|
||||
acc, ok := r.Pool.AcquireWait(ctx, target, nil)
|
||||
@@ -56,6 +66,7 @@ func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) {
|
||||
}
|
||||
a := &RequestAuth{
|
||||
UseConfigToken: true,
|
||||
CallerID: callerID,
|
||||
AccountID: acc.Identifier(),
|
||||
Account: acc,
|
||||
TriedAccounts: map[string]bool{},
|
||||
@@ -158,3 +169,12 @@ func extractCallerToken(req *http.Request) string {
|
||||
}
|
||||
return strings.TrimSpace(req.Header.Get("x-api-key"))
|
||||
}
|
||||
|
||||
func callerTokenID(token string) string {
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return "caller:" + hex.EncodeToString(sum[:8])
|
||||
}
|
||||
|
||||
@@ -37,6 +37,9 @@ func TestDetermineWithXAPIKeyUsesDirectToken(t *testing.T) {
|
||||
if auth.DeepSeekToken != "direct-token" {
|
||||
t.Fatalf("unexpected token: %q", auth.DeepSeekToken)
|
||||
}
|
||||
if auth.CallerID == "" {
|
||||
t.Fatalf("expected caller id to be populated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) {
|
||||
@@ -58,6 +61,24 @@ func TestDetermineWithXAPIKeyManagedKeyAcquiresAccount(t *testing.T) {
|
||||
if auth.DeepSeekToken != "account-token" {
|
||||
t.Fatalf("unexpected account token: %q", auth.DeepSeekToken)
|
||||
}
|
||||
if auth.CallerID == "" {
|
||||
t.Fatalf("expected caller id to be populated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallerTokenIDStable(t *testing.T) {
|
||||
a := callerTokenID("token-a")
|
||||
b := callerTokenID("token-a")
|
||||
c := callerTokenID("token-b")
|
||||
if a == "" || b == "" || c == "" {
|
||||
t.Fatalf("expected non-empty caller ids")
|
||||
}
|
||||
if a != b {
|
||||
t.Fatalf("expected stable caller id, got %q and %q", a, b)
|
||||
}
|
||||
if a == c {
|
||||
t.Fatalf("expected different caller id for different tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetermineMissingToken(t *testing.T) {
|
||||
|
||||
@@ -755,11 +755,15 @@ func (r *Runner) cases() []caseDef {
|
||||
{ID: "healthz_ok", Run: r.caseHealthz},
|
||||
{ID: "readyz_ok", Run: r.caseReadyz},
|
||||
{ID: "models_openai", Run: r.caseModelsOpenAI},
|
||||
{ID: "model_openai_by_id", Run: r.caseModelOpenAIByID},
|
||||
{ID: "models_claude", Run: r.caseModelsClaude},
|
||||
{ID: "admin_login_verify", Run: r.caseAdminLoginVerify},
|
||||
{ID: "admin_queue_status", Run: r.caseAdminQueueStatus},
|
||||
{ID: "chat_nonstream_basic", Run: r.caseChatNonstream},
|
||||
{ID: "chat_stream_basic", Run: r.caseChatStream},
|
||||
{ID: "responses_nonstream_basic", Run: r.caseResponsesNonstream},
|
||||
{ID: "responses_stream_basic", Run: r.caseResponsesStream},
|
||||
{ID: "embeddings_contract", Run: r.caseEmbeddings},
|
||||
{ID: "reasoner_stream", Run: r.caseReasonerStream},
|
||||
{ID: "toolcall_nonstream", Run: r.caseToolcallNonstream},
|
||||
{ID: "toolcall_stream", Run: r.caseToolcallStream},
|
||||
@@ -817,6 +821,19 @@ func (r *Runner) caseModelsOpenAI(ctx context.Context, cc *caseContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseModelOpenAIByID(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/v1/models/gpt-4o", Retryable: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(resp.Body, &m)
|
||||
cc.assert("object_model", asString(m["object"]) == "model", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
cc.assert("id_deepseek_chat", asString(m["id"]) == "deepseek-chat", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) caseModelsClaude(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/anthropic/v1/models", Retryable: true})
|
||||
if err != nil {
|
||||
@@ -942,6 +959,115 @@ func (r *Runner) caseChatStream(ctx context.Context, cc *caseContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// caseResponsesNonstream posts a non-streaming /v1/responses request, asserts
// the response-object contract, then fetches the cached response back by id
// to cover the GET /v1/responses/{id} round trip.
func (r *Runner) caseResponsesNonstream(ctx context.Context, cc *caseContext) error {
	resp, err := cc.request(ctx, requestSpec{
		Method: http.MethodPost,
		Path:   "/v1/responses",
		Headers: map[string]string{
			"Authorization": "Bearer " + r.apiKey,
		},
		Body: map[string]any{
			"model": "gpt-4o",
			"input": "请简要回答 hello",
		},
		Retryable: true,
	})
	if err != nil {
		return err
	}
	cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
	var m map[string]any
	// Best-effort decode; failed assertions report the raw body.
	_ = json.Unmarshal(resp.Body, &m)
	cc.assert("object_response", asString(m["object"]) == "response", fmt.Sprintf("body=%s", string(resp.Body)))
	responseID := asString(m["id"])
	cc.assert("response_id_present", responseID != "", fmt.Sprintf("body=%s", string(resp.Body)))
	// Only attempt the read-back when an id was actually returned.
	if responseID != "" {
		getResp, getErr := cc.request(ctx, requestSpec{
			Method: http.MethodGet,
			Path:   "/v1/responses/" + responseID,
			Headers: map[string]string{
				// Same caller key as the POST; the store is caller-scoped.
				"Authorization": "Bearer " + r.apiKey,
			},
			Retryable: true,
		})
		if getErr != nil {
			return getErr
		}
		cc.assert("get_status_200", getResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", getResp.StatusCode))
	}
	return nil
}
|
||||
|
||||
// caseResponsesStream posts a streaming /v1/responses request and asserts the
// SSE contract: at least one frame, a response.created and a
// response.completed event, and a terminating [DONE] sentinel.
func (r *Runner) caseResponsesStream(ctx context.Context, cc *caseContext) error {
	resp, err := cc.request(ctx, requestSpec{
		Method: http.MethodPost,
		Path:   "/v1/responses",
		Headers: map[string]string{
			"Authorization": "Bearer " + r.apiKey,
		},
		Body: map[string]any{
			"model":  "gpt-4o",
			"input":  "请流式回答 hello",
			"stream": true,
		},
		Stream:    true,
		Retryable: true,
	})
	if err != nil {
		return err
	}
	cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
	// frames are the decoded SSE data payloads; done reports the [DONE] marker.
	frames, done := parseSSEFrames(resp.Body)
	cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames)))
	hasCreated := false
	hasCompleted := false
	for _, f := range frames {
		switch asString(f["type"]) {
		case "response.created":
			hasCreated = true
		case "response.completed":
			hasCompleted = true
		}
	}
	cc.assert("has_response_created", hasCreated, fmt.Sprintf("body=%s", string(resp.Body)))
	cc.assert("has_response_completed", hasCompleted, fmt.Sprintf("body=%s", string(resp.Body)))
	cc.assert("done_terminated", done, "expected [DONE]")
	return nil
}
|
||||
|
||||
// caseEmbeddings posts /v1/embeddings and accepts either a working backend
// (200 with an OpenAI list object) or an explicit 501 Not Implemented whose
// error envelope must still carry "code" and "param" fields.
func (r *Runner) caseEmbeddings(ctx context.Context, cc *caseContext) error {
	resp, err := cc.request(ctx, requestSpec{
		Method: http.MethodPost,
		Path:   "/v1/embeddings",
		Headers: map[string]string{
			"Authorization": "Bearer " + r.apiKey,
		},
		Body: map[string]any{
			"model": "gpt-4o",
			"input": []string{"hello", "world"},
		},
		Retryable: true,
	})
	if err != nil {
		return err
	}
	cc.assert("status_200_or_501", resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusNotImplemented, fmt.Sprintf("status=%d", resp.StatusCode))
	var m map[string]any
	// Best-effort decode; failed assertions report the raw body.
	_ = json.Unmarshal(resp.Body, &m)
	if resp.StatusCode == http.StatusOK {
		// Success branch: OpenAI embeddings list contract.
		cc.assert("object_list", asString(m["object"]) == "list", fmt.Sprintf("body=%s", string(resp.Body)))
		data, _ := m["data"].([]any)
		cc.assert("data_non_empty", len(data) > 0, fmt.Sprintf("body=%s", string(resp.Body)))
		return nil
	}
	// 501 branch: the structured error contract still applies.
	errObj, _ := m["error"].(map[string]any)
	_, hasCode := errObj["code"]
	_, hasParam := errObj["param"]
	cc.assert("error_has_code", hasCode, fmt.Sprintf("body=%s", string(resp.Body)))
	cc.assert("error_has_param", hasParam, fmt.Sprintf("body=%s", string(resp.Body)))
	return nil
}
|
||||
|
||||
func (r *Runner) caseReasonerStream(ctx context.Context, cc *caseContext) error {
|
||||
resp, err := cc.request(ctx, requestSpec{
|
||||
Method: http.MethodPost,
|
||||
|
||||
30
internal/util/standard_request.go
Normal file
30
internal/util/standard_request.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package util
|
||||
|
||||
// StandardRequest is the surface-agnostic form of an inbound chat request
// after normalization (OpenAI chat, OpenAI responses, Claude, ...). It carries
// the resolved model information, the flattened prompt, and any sampling
// parameters to forward upstream unchanged.
type StandardRequest struct {
	Surface        string         // which API surface produced this request
	RequestedModel string         // model name as supplied by the caller
	ResolvedModel  string         // internal model the request maps to
	ResponseModel  string         // model name echoed back in responses
	Messages       []any          // normalized message list
	FinalPrompt    string         // fully flattened prompt sent upstream
	ToolNames      []string       // names of tools declared in the request
	Stream         bool           // caller asked for streaming output
	Thinking       bool           // thinking mode enabled for the resolved model
	Search         bool           // search enabled for the resolved model
	PassThrough    map[string]any // whitelisted sampling params forwarded as-is
}

// CompletionPayload builds the upstream completion payload for the given chat
// session, laying pass-through sampling parameters on top of the fixed fields.
func (r StandardRequest) CompletionPayload(sessionID string) map[string]any {
	payload := make(map[string]any, 6+len(r.PassThrough))
	payload["chat_session_id"] = sessionID
	payload["parent_message_id"] = nil
	payload["prompt"] = r.FinalPrompt
	payload["ref_file_ids"] = []any{}
	payload["thinking_enabled"] = r.Thinking
	payload["search_enabled"] = r.Search
	for key, value := range r.PassThrough {
		payload[key] = value
	}
	return payload
}
|
||||
Reference in New Issue
Block a user