mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-13 12:47:41 +08:00
feat: add model type support for file uploads with automatic resolution and header propagation
This commit is contained in:
@@ -94,6 +94,9 @@ func TestPreprocessInlineFileInputsReplacesDataURLAndCollectsRefFileIDs(t *testi
|
||||
if len(ds.uploadCalls) != 1 {
|
||||
t.Fatalf("expected 1 upload, got %d", len(ds.uploadCalls))
|
||||
}
|
||||
if ds.uploadCalls[0].ModelType != "default" {
|
||||
t.Fatalf("expected default model type when request omits model, got %q", ds.uploadCalls[0].ModelType)
|
||||
}
|
||||
if ds.lastCtx != ctx {
|
||||
t.Fatalf("expected upload to use request context")
|
||||
}
|
||||
@@ -149,7 +152,7 @@ func TestPreprocessInlineFileInputsDeduplicatesIdenticalPayloads(t *testing.T) {
|
||||
func TestChatCompletionsUploadsInlineFilesBeforeCompletion(t *testing.T) {
|
||||
ds := &inlineUploadDSStub{}
|
||||
h := &openAITestSurface{Store: mockOpenAIConfig{wideInput: true}, Auth: streamStatusAuthStub{}, DS: ds}
|
||||
reqBody := `{"model":"deepseek-v4-flash","messages":[{"role":"user","content":[{"type":"input_text","text":"hi"},{"type":"image_url","image_url":{"url":"data:image/png;base64,QUJDRA=="}}]}],"stream":false}`
|
||||
reqBody := `{"model":"deepseek-v4-vision","messages":[{"role":"user","content":[{"type":"input_text","text":"hi"},{"type":"image_url","image_url":{"url":"data:image/png;base64,QUJDRA=="}}]}],"stream":false}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
@@ -163,6 +166,9 @@ func TestChatCompletionsUploadsInlineFilesBeforeCompletion(t *testing.T) {
|
||||
if len(ds.uploadCalls) != 1 {
|
||||
t.Fatalf("expected 1 upload call, got %d", len(ds.uploadCalls))
|
||||
}
|
||||
if ds.uploadCalls[0].ModelType != "vision" {
|
||||
t.Fatalf("expected vision model type for vision request, got %q", ds.uploadCalls[0].ModelType)
|
||||
}
|
||||
if ds.completionReq == nil {
|
||||
t.Fatal("expected completion payload to be captured")
|
||||
}
|
||||
@@ -177,7 +183,7 @@ func TestResponsesUploadsInlineFilesBeforeCompletion(t *testing.T) {
|
||||
h := &openAITestSurface{Store: mockOpenAIConfig{wideInput: true}, Auth: streamStatusAuthStub{}, DS: ds}
|
||||
r := chi.NewRouter()
|
||||
registerOpenAITestRoutes(r, h)
|
||||
reqBody := `{"model":"deepseek-v4-flash","input":[{"role":"user","content":[{"type":"input_text","text":"hi"},{"type":"input_image","image_url":{"url":"data:image/png;base64,QUJDRA=="}}]}],"stream":false}`
|
||||
reqBody := `{"model":"deepseek-v4-pro","input":[{"role":"user","content":[{"type":"input_text","text":"hi"},{"type":"input_image","image_url":{"url":"data:image/png;base64,QUJDRA=="}}]}],"stream":false}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
@@ -191,6 +197,9 @@ func TestResponsesUploadsInlineFilesBeforeCompletion(t *testing.T) {
|
||||
if len(ds.uploadCalls) != 1 {
|
||||
t.Fatalf("expected 1 upload call, got %d", len(ds.uploadCalls))
|
||||
}
|
||||
if ds.uploadCalls[0].ModelType != "expert" {
|
||||
t.Fatalf("expected expert model type for pro request, got %q", ds.uploadCalls[0].ModelType)
|
||||
}
|
||||
refIDs, _ := ds.completionReq["ref_file_ids"].([]any)
|
||||
if len(refIDs) != 1 || refIDs[0] != "file-inline-1" {
|
||||
t.Fatalf("unexpected completion ref_file_ids: %#v", ds.completionReq["ref_file_ids"])
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
dsclient "ds2api/internal/deepseek/client"
|
||||
"ds2api/internal/httpapi/openai/shared"
|
||||
"ds2api/internal/promptcompat"
|
||||
@@ -42,6 +43,7 @@ type inlineUploadState struct {
|
||||
ctx context.Context
|
||||
handler *Handler
|
||||
auth *auth.RequestAuth
|
||||
modelType string
|
||||
uploadedByID map[string]string
|
||||
uploadCount int
|
||||
inlineFileBytes int
|
||||
@@ -58,10 +60,19 @@ func (h *Handler) PreprocessInlineFileInputs(ctx context.Context, a *auth.Reques
|
||||
if h == nil || h.DS == nil || len(req) == 0 {
|
||||
return nil
|
||||
}
|
||||
modelType := "default"
|
||||
if requestedModel, ok := req["model"].(string); ok {
|
||||
if resolvedModel, ok := config.ResolveModel(h.Store, requestedModel); ok {
|
||||
if resolvedType, ok := config.GetModelType(resolvedModel); ok {
|
||||
modelType = resolvedType
|
||||
}
|
||||
}
|
||||
}
|
||||
state := &inlineUploadState{
|
||||
ctx: ctx,
|
||||
handler: h,
|
||||
auth: a,
|
||||
modelType: modelType,
|
||||
uploadedByID: map[string]string{},
|
||||
}
|
||||
for _, key := range []string{"messages", "input", "attachments"} {
|
||||
@@ -174,6 +185,7 @@ func (s *inlineUploadState) uploadInlineFile(file inlineDecodedFile) (string, er
|
||||
result, err := s.handler.DS.UploadFile(s.ctx, s.auth, dsclient.UploadFileRequest{
|
||||
Filename: file.Filename,
|
||||
ContentType: contentType,
|
||||
ModelType: s.modelType,
|
||||
Data: file.Data,
|
||||
}, 3)
|
||||
if err != nil {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/chathistory"
|
||||
"ds2api/internal/config"
|
||||
dsclient "ds2api/internal/deepseek/client"
|
||||
"ds2api/internal/httpapi/openai/shared"
|
||||
)
|
||||
@@ -66,10 +67,12 @@ func (h *Handler) UploadFile(w http.ResponseWriter, r *http.Request) {
|
||||
if contentType == "" && len(data) > 0 {
|
||||
contentType = http.DetectContentType(data)
|
||||
}
|
||||
modelType := resolveUploadModelType(h.Store, r)
|
||||
result, err := h.DS.UploadFile(r.Context(), a, dsclient.UploadFileRequest{
|
||||
Filename: header.Filename,
|
||||
ContentType: contentType,
|
||||
Purpose: strings.TrimSpace(r.FormValue("purpose")),
|
||||
ModelType: modelType,
|
||||
Data: data,
|
||||
}, 3)
|
||||
if err != nil {
|
||||
@@ -82,6 +85,32 @@ func (h *Handler) UploadFile(w http.ResponseWriter, r *http.Request) {
|
||||
shared.WriteJSON(w, http.StatusOK, buildOpenAIFileObject(result))
|
||||
}
|
||||
|
||||
func resolveUploadModelType(store shared.ConfigReader, r *http.Request) string {
|
||||
for _, candidate := range []string{r.FormValue("model_type"), r.Header.Get("X-Model-Type")} {
|
||||
if modelType := normalizeUploadModelType(candidate); modelType != "" {
|
||||
return modelType
|
||||
}
|
||||
}
|
||||
requestedModel := strings.TrimSpace(r.FormValue("model"))
|
||||
if requestedModel != "" {
|
||||
if resolvedModel, ok := config.ResolveModel(store, requestedModel); ok {
|
||||
if modelType, ok := config.GetModelType(resolvedModel); ok {
|
||||
return modelType
|
||||
}
|
||||
}
|
||||
}
|
||||
return "default"
|
||||
}
|
||||
|
||||
func normalizeUploadModelType(raw string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case "default", "expert", "vision":
|
||||
return strings.ToLower(strings.TrimSpace(raw))
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func buildOpenAIFileObject(result *dsclient.UploadFileResult) map[string]any {
|
||||
if result == nil {
|
||||
obj := map[string]any{
|
||||
|
||||
@@ -77,7 +77,7 @@ func (m *filesRouteDSStub) DeleteAllSessionsForToken(_ context.Context, _ string
|
||||
return nil
|
||||
}
|
||||
|
||||
func newMultipartUploadRequest(t *testing.T, purpose string, filename string, data []byte) *http.Request {
|
||||
func newMultipartUploadRequest(t *testing.T, purpose string, filename string, data []byte, model string) *http.Request {
|
||||
t.Helper()
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
@@ -86,6 +86,11 @@ func newMultipartUploadRequest(t *testing.T, purpose string, filename string, da
|
||||
t.Fatalf("write purpose failed: %v", err)
|
||||
}
|
||||
}
|
||||
if model != "" {
|
||||
if err := writer.WriteField("model", model); err != nil {
|
||||
t.Fatalf("write model failed: %v", err)
|
||||
}
|
||||
}
|
||||
part, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
t.Fatalf("create form file failed: %v", err)
|
||||
@@ -108,7 +113,7 @@ func TestFilesRouteUploadSuccess(t *testing.T) {
|
||||
r := chi.NewRouter()
|
||||
registerOpenAITestRoutes(r, h)
|
||||
|
||||
req := newMultipartUploadRequest(t, "assistants", "notes.txt", []byte("hello world"))
|
||||
req := newMultipartUploadRequest(t, "assistants", "notes.txt", []byte("hello world"), "deepseek-v4-vision")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
@@ -121,6 +126,9 @@ func TestFilesRouteUploadSuccess(t *testing.T) {
|
||||
if ds.lastReq.Purpose != "assistants" {
|
||||
t.Fatalf("expected purpose assistants, got %q", ds.lastReq.Purpose)
|
||||
}
|
||||
if ds.lastReq.ModelType != "vision" {
|
||||
t.Fatalf("expected vision model type, got %q", ds.lastReq.ModelType)
|
||||
}
|
||||
if string(ds.lastReq.Data) != "hello world" {
|
||||
t.Fatalf("unexpected uploaded data: %q", string(ds.lastReq.Data))
|
||||
}
|
||||
@@ -145,7 +153,7 @@ func TestFilesRouteUploadIncludesAccountIDForManagedAccount(t *testing.T) {
|
||||
r := chi.NewRouter()
|
||||
registerOpenAITestRoutes(r, h)
|
||||
|
||||
req := newMultipartUploadRequest(t, "assistants", "notes.txt", []byte("hello world"))
|
||||
req := newMultipartUploadRequest(t, "assistants", "notes.txt", []byte("hello world"), "deepseek-v4-vision")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
dsclient "ds2api/internal/deepseek/client"
|
||||
"ds2api/internal/httpapi/openai/shared"
|
||||
"ds2api/internal/promptcompat"
|
||||
@@ -35,10 +36,15 @@ func (s Service) ApplyCurrentInputFile(ctx context.Context, a *auth.RequestAuth,
|
||||
if strings.TrimSpace(fileText) == "" {
|
||||
return stdReq, errors.New("current user input file produced empty transcript")
|
||||
}
|
||||
modelType := "default"
|
||||
if resolvedType, ok := config.GetModelType(stdReq.ResolvedModel); ok {
|
||||
modelType = resolvedType
|
||||
}
|
||||
result, err := s.DS.UploadFile(ctx, a, dsclient.UploadFileRequest{
|
||||
Filename: currentInputFilename,
|
||||
ContentType: currentInputContentType,
|
||||
Purpose: currentInputPurpose,
|
||||
ModelType: modelType,
|
||||
Data: []byte(fileText),
|
||||
}, 3)
|
||||
if err != nil {
|
||||
|
||||
@@ -227,7 +227,7 @@ func TestApplyCurrentInputFileDisabledPassThrough(t *testing.T) {
|
||||
DS: ds,
|
||||
}
|
||||
req := map[string]any{
|
||||
"model": "deepseek-v4-flash",
|
||||
"model": "deepseek-v4-vision",
|
||||
"messages": historySplitTestMessages(),
|
||||
}
|
||||
stdReq, err := promptcompat.NormalizeOpenAIChatRequest(h.Store, req, "")
|
||||
@@ -332,7 +332,7 @@ func TestApplyCurrentInputFilePreservesFullContextPromptForTokenCounting(t *test
|
||||
DS: ds,
|
||||
}
|
||||
req := map[string]any{
|
||||
"model": "deepseek-v4-flash",
|
||||
"model": "deepseek-v4-vision",
|
||||
"messages": historySplitTestMessages(),
|
||||
}
|
||||
stdReq, err := promptcompat.NormalizeOpenAIChatRequest(h.Store, req, "")
|
||||
@@ -378,7 +378,7 @@ func TestApplyCurrentInputFileUploadsFullContextFile(t *testing.T) {
|
||||
DS: ds,
|
||||
}
|
||||
req := map[string]any{
|
||||
"model": "deepseek-v4-flash",
|
||||
"model": "deepseek-v4-vision",
|
||||
"messages": historySplitTestMessages(),
|
||||
}
|
||||
stdReq, err := promptcompat.NormalizeOpenAIChatRequest(h.Store, req, "")
|
||||
@@ -400,6 +400,9 @@ func TestApplyCurrentInputFileUploadsFullContextFile(t *testing.T) {
|
||||
if upload.Filename != "DS2API_HISTORY.txt" {
|
||||
t.Fatalf("expected DS2API_HISTORY.txt upload, got %q", upload.Filename)
|
||||
}
|
||||
if upload.ModelType != "vision" {
|
||||
t.Fatalf("expected vision model type for vision request, got %q", upload.ModelType)
|
||||
}
|
||||
uploadedText := string(upload.Data)
|
||||
for _, want := range []string{"# DS2API_HISTORY.txt", "=== 1. SYSTEM ===", "=== 2. USER ===", "=== 3. ASSISTANT ===", "=== 4. TOOL ===", "=== 5. USER ===", "system instructions", "first user turn", "hidden reasoning", "tool result", "latest user turn", promptcompat.ThinkingInjectionMarker} {
|
||||
if !strings.Contains(uploadedText, want) {
|
||||
|
||||
Reference in New Issue
Block a user