diff --git a/internal/adapter/claude/handler.go b/internal/adapter/claude/handler.go index 593c1bc..2fa6796 100644 --- a/internal/adapter/claude/handler.go +++ b/internal/adapter/claude/handler.go @@ -38,6 +38,10 @@ func RegisterRoutes(r chi.Router, h *Handler) { r.Get("/anthropic/v1/models", h.ListModels) r.Post("/anthropic/v1/messages", h.Messages) r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens) + r.Post("/v1/messages", h.Messages) + r.Post("/messages", h.Messages) + r.Post("/v1/messages/count_tokens", h.CountTokens) + r.Post("/messages/count_tokens", h.CountTokens) } func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) { @@ -167,7 +171,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") rc := http.NewResponseController(w) - canFlush := rc.Flush() == nil + _, canFlush := w.(http.Flusher) if !canFlush { config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered") } diff --git a/internal/adapter/claude/route_alias_test.go b/internal/adapter/claude/route_alias_test.go new file mode 100644 index 0000000..f01e5e3 --- /dev/null +++ b/internal/adapter/claude/route_alias_test.go @@ -0,0 +1,44 @@ +package claude + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/auth" +) + +type routeAliasAuthStub struct{} + +func (routeAliasAuthStub) Determine(_ *http.Request) (*auth.RequestAuth, error) { + return nil, auth.ErrUnauthorized +} + +func (routeAliasAuthStub) Release(_ *auth.RequestAuth) {} + +func TestClaudeRouteAliasesDoNot404(t *testing.T) { + h := &Handler{ + Auth: routeAliasAuthStub{}, + } + r := chi.NewRouter() + RegisterRoutes(r, h) + + paths := []string{ + "/anthropic/v1/messages", + "/v1/messages", + "/messages", + "/anthropic/v1/messages/count_tokens", + "/v1/messages/count_tokens", + "/messages/count_tokens", + } + for _, path := range paths { + req := httptest.NewRequest(http.MethodPost, path, nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code == http.StatusNotFound { + t.Fatalf("expected route %s to be registered, got 404", path) + } + } +} diff --git a/internal/adapter/gemini/convert.go b/internal/adapter/gemini/convert.go new file mode 100644 index 0000000..3f63579 --- /dev/null +++ b/internal/adapter/gemini/convert.go @@ -0,0 +1,313 @@ +package gemini + +import ( + "encoding/json" + "fmt" + "strings" + + "ds2api/internal/adapter/openai" + "ds2api/internal/config" + "ds2api/internal/util" +) + +func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[string]any, stream bool) (util.StandardRequest, error) { + requestedModel := strings.TrimSpace(routeModel) + if requestedModel == "" { + return util.StandardRequest{}, fmt.Errorf("model is required in request path") + } + + resolvedModel, ok := config.ResolveModel(store, requestedModel) + if !ok { + return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", requestedModel) + } + thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) + + messagesRaw := geminiMessagesFromRequest(req) + if len(messagesRaw) == 0 { + return util.StandardRequest{}, fmt.Errorf("Request must include non-empty contents.") + } + + toolsRaw := convertGeminiTools(req["tools"]) + finalPrompt, toolNames := openai.BuildPromptForAdapter(messagesRaw, toolsRaw, "") + passThrough := collectGeminiPassThrough(req) + + return util.StandardRequest{ + Surface: "google_gemini", + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + ResponseModel: requestedModel, + Messages: messagesRaw, + FinalPrompt: finalPrompt, + ToolNames: toolNames, + Stream: stream, + Thinking: thinkingEnabled, + Search: searchEnabled, + PassThrough: passThrough, + }, nil +} + +func geminiMessagesFromRequest(req map[string]any) []any { + out := make([]any, 0, 8) + if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" { + out = append(out, map[string]any{ + "role": "system", + "content": sys, + }) + } + + contents, _ := req["contents"].([]any) + for _, item := range contents { + content, ok := item.(map[string]any) + if !ok { + continue + } + role := mapGeminiRole(content["role"]) + if role == "" { + role = "user" + } + parts, _ := content["parts"].([]any) + if len(parts) == 0 { + if text := strings.TrimSpace(asString(content["text"])); text != "" { + out = append(out, map[string]any{ + "role": role, + "content": text, + }) + } + continue + } + + textParts := make([]string, 0, len(parts)) + flushText := func() { + if len(textParts) == 0 { + return + } + out = append(out, map[string]any{ + "role": role, + "content": strings.Join(textParts, "\n"), + }) + textParts = textParts[:0] + } + + for _, rawPart := range parts { + part, ok := rawPart.(map[string]any) + if !ok { + continue + } + if text := strings.TrimSpace(asString(part["text"])); text != "" { + textParts = append(textParts, text) + continue + } + + if fnCall, ok := part["functionCall"].(map[string]any); ok { + flushText() + if name := strings.TrimSpace(asString(fnCall["name"])); name != "" { + callID := strings.TrimSpace(asString(fnCall["id"])) + if callID == "" { + callID = "call_gemini" + } + out = append(out, map[string]any{ + "role": "assistant", + "tool_calls": []any{ + map[string]any{ + "id": callID, + "type": "function", + "function": map[string]any{ + "name": name, + "arguments": stringifyJSON(fnCall["args"]), + }, + }, + }, + }) + } + continue + } + + if fnResp, ok := part["functionResponse"].(map[string]any); ok { + flushText() + name := strings.TrimSpace(asString(fnResp["name"])) + callID := strings.TrimSpace(asString(fnResp["id"])) + if callID == "" { + callID = strings.TrimSpace(asString(fnResp["callId"])) + } + if callID == "" { + callID = strings.TrimSpace(asString(fnResp["tool_call_id"])) + } + if callID == "" { + callID = "call_gemini" + } + content := fnResp["response"] + if content == nil { + content = fnResp["output"] + } + if content == nil { + content = "" + } + msg := map[string]any{ + "role": "tool", + "tool_call_id": callID, + "content": content, + } + if name != "" { + msg["name"] = name + } + out = append(out, msg) + } + } + flushText() + } + return out +} + +func normalizeGeminiSystemInstruction(raw any) string { + switch v := raw.(type) { + case string: + return strings.TrimSpace(v) + case map[string]any: + if parts, ok := v["parts"].([]any); ok { + texts := make([]string, 0, len(parts)) + for _, item := range parts { + part, ok := item.(map[string]any) + if !ok { + continue + } + if text := strings.TrimSpace(asString(part["text"])); text != "" { + texts = append(texts, text) + } + } + return strings.Join(texts, "\n") + } + if text := strings.TrimSpace(asString(v["text"])); text != "" { + return text + } + } + return "" +} + +func mapGeminiRole(v any) string { + switch strings.ToLower(strings.TrimSpace(asString(v))) { + case "user": + return "user" + case "model", "assistant": + return "assistant" + case "system": + return "system" + default: + return "" + } +} + +func convertGeminiTools(raw any) []any { + tools, _ := raw.([]any) + if len(tools) == 0 { + return nil + } + out := make([]any, 0, len(tools)) + for _, item := range tools { + tool, ok := item.(map[string]any) + if !ok { + continue + } + + if fnDecls, ok := tool["functionDeclarations"].([]any); ok && len(fnDecls) > 0 { + for _, declRaw := range fnDecls { + decl, ok := declRaw.(map[string]any) + if !ok { + continue + } + name := strings.TrimSpace(asString(decl["name"])) + if name == "" { + continue + } + function := map[string]any{ + "name": name, + } + if desc := strings.TrimSpace(asString(decl["description"])); desc != "" { + function["description"] = desc + } + if params, ok := decl["parameters"].(map[string]any); ok { + function["parameters"] = params + } + out = append(out, map[string]any{ + "type": "function", + "function": function, + }) + } + continue + } + + // OpenAI-style passthrough fallback. + if _, ok := tool["function"].(map[string]any); ok { + out = append(out, tool) + continue + } + + // Loose fallback for flattened function schema objects. + name := strings.TrimSpace(asString(tool["name"])) + if name == "" { + continue + } + fn := map[string]any{"name": name} + if desc := strings.TrimSpace(asString(tool["description"])); desc != "" { + fn["description"] = desc + } + if params, ok := tool["parameters"].(map[string]any); ok { + fn["parameters"] = params + } + out = append(out, map[string]any{ + "type": "function", + "function": fn, + }) + } + if len(out) == 0 { + return nil + } + return out +} + +func collectGeminiPassThrough(req map[string]any) map[string]any { + cfg, _ := req["generationConfig"].(map[string]any) + if len(cfg) == 0 { + return nil + } + out := map[string]any{} + if v, ok := cfg["temperature"]; ok { + out["temperature"] = v + } + if v, ok := cfg["topP"]; ok { + out["top_p"] = v + } + if v, ok := cfg["maxOutputTokens"]; ok { + out["max_tokens"] = v + } + if v, ok := cfg["stopSequences"]; ok { + out["stop"] = v + } + if len(out) == 0 { + return nil + } + return out +} + +func asString(v any) string { + s, _ := v.(string) + return s +} + +func stringifyJSON(v any) string { + switch x := v.(type) { + case nil: + return "{}" + case string: + s := strings.TrimSpace(x) + if s == "" { + return "{}" + } + return s + default: + b, err := json.Marshal(x) + if err != nil || len(b) == 0 { + return "{}" + } + return string(b) + } +} diff --git a/internal/adapter/gemini/deps.go b/internal/adapter/gemini/deps.go new file mode 100644 index 0000000..312114a --- /dev/null +++ b/internal/adapter/gemini/deps.go @@ -0,0 +1,29 @@ +package gemini + +import ( + "context" + "net/http" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/deepseek" +) + +type AuthResolver interface { + Determine(req *http.Request) (*auth.RequestAuth, error) + Release(a *auth.RequestAuth) +} + +type DeepSeekCaller interface { + CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) + CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) +} + +type ConfigReader interface { + ModelAliases() map[string]string +} + +var _ AuthResolver = (*auth.Resolver)(nil) +var _ DeepSeekCaller = (*deepseek.Client)(nil) +var _ ConfigReader = (*config.Store)(nil) diff --git a/internal/adapter/gemini/handler.go b/internal/adapter/gemini/handler.go new file mode 100644 index 0000000..8daaeda --- /dev/null +++ b/internal/adapter/gemini/handler.go @@ -0,0 +1,348 @@ +package gemini + +import ( + "encoding/json" + "io" + "net/http" + "strings" + "time" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/auth" + "ds2api/internal/deepseek" + "ds2api/internal/sse" + streamengine "ds2api/internal/stream" + "ds2api/internal/util" +) + +var writeJSON = util.WriteJSON + +type Handler struct { + Store ConfigReader + Auth AuthResolver + DS DeepSeekCaller +} + +func RegisterRoutes(r chi.Router, h *Handler) { + r.Post("/v1beta/models/{model}:generateContent", h.GenerateContent) + r.Post("/v1beta/models/{model}:streamGenerateContent", h.StreamGenerateContent) + r.Post("/v1/models/{model}:generateContent", h.GenerateContent) + r.Post("/v1/models/{model}:streamGenerateContent", h.StreamGenerateContent) +} + +func (h *Handler) GenerateContent(w http.ResponseWriter, r *http.Request) { + h.handleGenerateContent(w, r, false) +} + +func (h *Handler) StreamGenerateContent(w http.ResponseWriter, r *http.Request) { + h.handleGenerateContent(w, r, true) +} + +func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) { + a, err := h.Auth.Determine(r) + if err != nil { + status := http.StatusUnauthorized + detail := err.Error() + if err == auth.ErrNoAccount { + status = http.StatusTooManyRequests + } + writeGeminiError(w, status, detail) + return + } + defer h.Auth.Release(a) + + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeGeminiError(w, http.StatusBadRequest, "invalid json") + return + } + + routeModel := strings.TrimSpace(chi.URLParam(r, "model")) + stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream) + if err != nil { + writeGeminiError(w, http.StatusBadRequest, err.Error()) + return + } + + sessionID, err := h.DS.CreateSession(r.Context(), a, 3) + if err != nil { + if a.UseConfigToken { + writeGeminiError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.") + } else { + writeGeminiError(w, http.StatusUnauthorized, "Invalid token.") + } + return + } + pow, err := h.DS.GetPow(r.Context(), a, 3) + if err != nil { + writeGeminiError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") + return + } + payload := stdReq.CompletionPayload(sessionID) + resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) + if err != nil { + writeGeminiError(w, http.StatusInternalServerError, "Failed to get completion.") + return + } + + if stream { + h.handleStreamGenerateContent(w, r, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) + return + } + h.handleNonStreamGenerateContent(w, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames) +} + +func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *http.Response, model, finalPrompt string, thinkingEnabled bool, toolNames []string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body))) + return + } + + result := sse.CollectStream(resp, thinkingEnabled, true) + writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames)) +} + +func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body))) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + + rc := http.NewResponseController(w) + _, canFlush := w.(http.Flusher) + runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) + + initialType := "text" + if thinkingEnabled { + initialType = "thinking" + } + streamengine.ConsumeSSE(streamengine.ConsumeConfig{ + Context: r.Context(), + Body: resp.Body, + ThinkingEnabled: thinkingEnabled, + InitialType: initialType, + KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second, + IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second, + MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount, + }, streamengine.ConsumeHooks{ + OnParsed: runtime.onParsed, + OnFinalize: func(_ streamengine.StopReason, _ error) { + runtime.finalize() + }, + }) +} + +func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { + parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames) + usage := buildGeminiUsage(finalPrompt, finalThinking, finalText) + return map[string]any{ + "candidates": []map[string]any{ + { + "index": 0, + "content": map[string]any{ + "role": "model", + "parts": parts, + }, + "finishReason": "STOP", + }, + }, + "modelVersion": model, + "usageMetadata": usage, + } +} + +func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any { + promptTokens := util.EstimateTokens(finalPrompt) + reasoningTokens := util.EstimateTokens(finalThinking) + completionTokens := util.EstimateTokens(finalText) + return map[string]any{ + "promptTokenCount": promptTokens, + "candidatesTokenCount": reasoningTokens + completionTokens, + "totalTokenCount": promptTokens + reasoningTokens + completionTokens, + } +} + +func buildGeminiPartsFromFinal(finalText, finalThinking string, toolNames []string) []map[string]any { + detected := util.ParseToolCalls(finalText, toolNames) + if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" { + detected = util.ParseToolCalls(finalThinking, toolNames) + } + if len(detected) > 0 { + parts := make([]map[string]any, 0, len(detected)) + for _, tc := range detected { + parts = append(parts, map[string]any{ + "functionCall": map[string]any{ + "name": tc.Name, + "args": tc.Input, + }, + }) + } + return parts + } + + text := finalText + if strings.TrimSpace(text) == "" { + text = finalThinking + } + return []map[string]any{{"text": text}} +} + +type geminiStreamRuntime struct { + w http.ResponseWriter + rc *http.ResponseController + canFlush bool + + model string + finalPrompt string + + thinkingEnabled bool + searchEnabled bool + bufferContent bool + toolNames []string + + thinking strings.Builder + text strings.Builder +} + +func newGeminiStreamRuntime( + w http.ResponseWriter, + rc *http.ResponseController, + canFlush bool, + model string, + finalPrompt string, + thinkingEnabled bool, + searchEnabled bool, + toolNames []string, +) *geminiStreamRuntime { + return &geminiStreamRuntime{ + w: w, + rc: rc, + canFlush: canFlush, + model: model, + finalPrompt: finalPrompt, + thinkingEnabled: thinkingEnabled, + searchEnabled: searchEnabled, + bufferContent: len(toolNames) > 0, + toolNames: toolNames, + } +} + +func (s *geminiStreamRuntime) sendChunk(payload map[string]any) { + b, _ := json.Marshal(payload) + _, _ = s.w.Write([]byte("data: ")) + _, _ = s.w.Write(b) + _, _ = s.w.Write([]byte("\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision { + if !parsed.Parsed { + return streamengine.ParsedDecision{} + } + if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop { + return streamengine.ParsedDecision{Stop: true} + } + + contentSeen := false + for _, p := range parsed.Parts { + if p.Text == "" { + continue + } + if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) { + continue + } + contentSeen = true + if p.Type == "thinking" { + if s.thinkingEnabled { + s.thinking.WriteString(p.Text) + } + continue + } + s.text.WriteString(p.Text) + if s.bufferContent { + continue + } + s.sendChunk(map[string]any{ + "candidates": []map[string]any{ + { + "index": 0, + "content": map[string]any{ + "role": "model", + "parts": []map[string]any{{"text": p.Text}}, + }, + }, + }, + "modelVersion": s.model, + }) + } + return streamengine.ParsedDecision{ContentSeen: contentSeen} +} + +func (s *geminiStreamRuntime) finalize() { + finalThinking := s.thinking.String() + finalText := s.text.String() + + if s.bufferContent { + parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames) + s.sendChunk(map[string]any{ + "candidates": []map[string]any{ + { + "index": 0, + "content": map[string]any{ + "role": "model", + "parts": parts, + }, + }, + }, + "modelVersion": s.model, + }) + } + + s.sendChunk(map[string]any{ + "candidates": []map[string]any{ + { + "index": 0, + "finishReason": "STOP", + }, + }, + "modelVersion": s.model, + "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText), + }) +} + +func writeGeminiError(w http.ResponseWriter, status int, message string) { + errorStatus := "INVALID_ARGUMENT" + switch status { + case http.StatusUnauthorized: + errorStatus = "UNAUTHENTICATED" + case http.StatusForbidden: + errorStatus = "PERMISSION_DENIED" + case http.StatusTooManyRequests: + errorStatus = "RESOURCE_EXHAUSTED" + case http.StatusNotFound: + errorStatus = "NOT_FOUND" + default: + if status >= 500 { + errorStatus = "INTERNAL" + } + } + writeJSON(w, status, map[string]any{ + "error": map[string]any{ + "code": status, + "message": message, + "status": errorStatus, + }, + }) +} diff --git a/internal/adapter/gemini/handler_test.go b/internal/adapter/gemini/handler_test.go new file mode 100644 index 0000000..862750a --- /dev/null +++ b/internal/adapter/gemini/handler_test.go @@ -0,0 +1,174 @@ +package gemini + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/auth" +) + +type testGeminiConfig struct{} + +func (testGeminiConfig) ModelAliases() map[string]string { return nil } + +type testGeminiAuth struct { + a *auth.RequestAuth + err error +} + +func (m testGeminiAuth) Determine(_ *http.Request) (*auth.RequestAuth, error) { + if m.err != nil { + return nil, m.err + } + if m.a != nil { + return m.a, nil + } + return &auth.RequestAuth{ + UseConfigToken: false, + DeepSeekToken: "direct-token", + CallerID: "caller:test", + TriedAccounts: map[string]bool{}, + }, nil +} + +func (testGeminiAuth) Release(_ *auth.RequestAuth) {} + +type testGeminiDS struct { + resp *http.Response + err error +} + +func (m testGeminiDS) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) { + return "session-id", nil +} + +func (m testGeminiDS) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) { + return "pow", nil +} + +func (m testGeminiDS) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) { + if m.err != nil { + return nil, m.err + } + return m.resp, nil +} + +func makeGeminiUpstreamResponse(lines ...string) *http.Response { + body := strings.Join(lines, "\n") + if !strings.HasSuffix(body, "\n") { + body += "\n" + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func TestGeminiRoutesRegistered(t *testing.T) { + h := &Handler{ + Store: testGeminiConfig{}, + Auth: testGeminiAuth{err: auth.ErrUnauthorized}, + } + r := chi.NewRouter() + RegisterRoutes(r, h) + + paths := []string{ + "/v1beta/models/gemini-2.5-pro:generateContent", + "/v1beta/models/gemini-2.5-pro:streamGenerateContent", + "/v1/models/gemini-2.5-pro:generateContent", + "/v1/models/gemini-2.5-pro:streamGenerateContent", + } + for _, path := range paths { + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`)) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code == http.StatusNotFound { + t.Fatalf("expected route %s to be registered, got 404", path) + } + } +} + +func TestGenerateContentReturnsFunctionCallParts(t *testing.T) { + upstream := makeGeminiUpstreamResponse( + `data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`, + `data: [DONE]`, + ) + h := &Handler{ + Store: testGeminiConfig{}, + Auth: testGeminiAuth{}, + DS: testGeminiDS{resp: upstream}, + } + r := chi.NewRouter() + RegisterRoutes(r, h) + + body := `{ + "contents":[{"role":"user","parts":[{"text":"call tool"}]}], + "tools":[{"functionDeclarations":[{"name":"eval_javascript","description":"eval","parameters":{"type":"object","properties":{"code":{"type":"string"}}}}]}] + }` + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer direct-token") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode response failed: %v", err) + } + candidates, _ := out["candidates"].([]any) + if len(candidates) == 0 { + t.Fatalf("expected non-empty candidates: %#v", out) + } + c0, _ := candidates[0].(map[string]any) + content, _ := c0["content"].(map[string]any) + parts, _ := content["parts"].([]any) + if len(parts) == 0 { + t.Fatalf("expected non-empty parts: %#v", content) + } + part0, _ := parts[0].(map[string]any) + functionCall, _ := part0["functionCall"].(map[string]any) + if functionCall["name"] != "eval_javascript" { + t.Fatalf("expected functionCall name eval_javascript, got %#v", functionCall) + } +} + +func TestStreamGenerateContentEmitsSSE(t *testing.T) { + upstream := makeGeminiUpstreamResponse( + `data: {"p":"response/content","v":"hello "}`, + `data: {"p":"response/content","v":"world"}`, + `data: [DONE]`, + ) + h := &Handler{ + Store: testGeminiConfig{}, + Auth: testGeminiAuth{}, + DS: testGeminiDS{resp: upstream}, + } + r := chi.NewRouter() + RegisterRoutes(r, h) + + body := `{"contents":[{"role":"user","parts":[{"text":"hello"}]}]}` + req := httptest.NewRequest(http.MethodPost, "/v1/models/gemini-2.5-pro:streamGenerateContent?alt=sse", strings.NewReader(body)) + req.Header.Set("Authorization", "Bearer direct-token") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "data: ") { + t.Fatalf("expected SSE data frames, got body=%s", rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), `"finishReason":"STOP"`) { + t.Fatalf("expected stream finish frame, got body=%s", rec.Body.String()) + } +} diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go index 517c88a..391a035 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -154,7 +154,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") rc := http.NewResponseController(w) - canFlush := rc.Flush() == nil + _, canFlush := w.(http.Flusher) if !canFlush { config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered") } diff --git a/internal/adapter/openai/prompt_build.go b/internal/adapter/openai/prompt_build.go index 890e3dc..76739ed 100644 --- a/internal/adapter/openai/prompt_build.go +++ b/internal/adapter/openai/prompt_build.go @@ -12,3 +12,10 @@ func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any, traceID string) (st } return deepseek.MessagesPrepare(messages), toolNames } + +// BuildPromptForAdapter exposes the OpenAI-compatible prompt building flow so +// other protocol adapters (for example Gemini) can reuse the same tool/history +// normalization logic and remain behavior-compatible with chat/completions. +func BuildPromptForAdapter(messagesRaw []any, toolsRaw any, traceID string) (string, []string) { + return buildOpenAIFinalPrompt(messagesRaw, toolsRaw, traceID) +} diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index bd9ff3a..e71cafe 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -128,7 +128,7 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") rc := http.NewResponseController(w) - canFlush := rc.Flush() == nil + _, canFlush := w.(http.Flusher) initialType := "text" if thinkingEnabled { diff --git a/internal/adapter/openai/responses_stream_runtime.go b/internal/adapter/openai/responses_stream_runtime.go index d059ca1..050965c 100644 --- a/internal/adapter/openai/responses_stream_runtime.go +++ b/internal/adapter/openai/responses_stream_runtime.go @@ -114,13 +114,7 @@ func (s *responsesStreamRuntime) finalize() { // Compatibility fallback: some streams only emit incremental tool deltas. // Ensure final function_call_arguments.done is emitted at least once. if s.toolCallsEmitted { - detected := util.ParseStandaloneToolCalls(finalText, s.toolNames) - if len(detected) == 0 { - detected = util.ParseToolCalls(finalText, s.toolNames) - } - if len(detected) == 0 { - detected = util.ParseStandaloneToolCalls(finalThinking, s.toolNames) - } + detected := util.ParseToolCalls(finalText, s.toolNames) if len(detected) == 0 { detected = util.ParseToolCalls(finalThinking, s.toolNames) } diff --git a/internal/adapter/openai/responses_stream_test.go b/internal/adapter/openai/responses_stream_test.go index a47903c..a513e6f 100644 --- a/internal/adapter/openai/responses_stream_test.go +++ b/internal/adapter/openai/responses_stream_test.go @@ -381,6 +381,48 @@ func TestHandleResponsesStreamMultiToolCallFromThinkingChannel(t *testing.T) { } } +func TestHandleResponsesStreamCompletedFollowsChatToolCallSemantics(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine("我来调用工具\n") + + sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}) + + completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed") + if !ok { + t.Fatalf("expected response.completed event, body=%s", rec.Body.String()) + } + responseObj, _ := completed["response"].(map[string]any) + output, _ := responseObj["output"].([]any) + hasFunctionCall := false + for _, item := range output { + m, _ := item.(map[string]any) + if m != nil && m["type"] == "function_call" { + hasFunctionCall = true + break + } + } + if !hasFunctionCall { + t.Fatalf("expected completed output to include function_call when mixed prose contains tool_calls payload, output=%#v", output) + } +} + func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) { scanner := bufio.NewScanner(strings.NewReader(body)) matched := false diff --git a/internal/auth/request.go b/internal/auth/request.go index 25980cf..23acb40 100644 --- a/internal/auth/request.go +++ b/internal/auth/request.go @@ -187,7 +187,12 @@ func extractCallerToken(req *http.Request) string { return token } } - return strings.TrimSpace(req.Header.Get("x-api-key")) + if key := strings.TrimSpace(req.Header.Get("x-api-key")); key != "" { + return key + } + // Gemini AI Studio compatibility: allow query key fallback only when no + // header-based credential is present. + return strings.TrimSpace(req.URL.Query().Get("key")) } func callerTokenID(token string) string { diff --git a/internal/auth/request_test.go b/internal/auth/request_test.go index c292856..2eca44b 100644 --- a/internal/auth/request_test.go +++ b/internal/auth/request_test.go @@ -114,6 +114,40 @@ func TestDetermineMissingToken(t *testing.T) { } } +func TestDetermineWithQueryKeyUsesDirectToken(t *testing.T) { + r := newTestResolver(t) + req, _ := http.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent?key=direct-query-key", nil) + + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + if a.UseConfigToken { + t.Fatalf("expected direct token mode") + } + if a.DeepSeekToken != "direct-query-key" { + t.Fatalf("unexpected token: %q", a.DeepSeekToken) + } +} + +func TestDetermineHeaderTokenPrecedenceOverQueryKey(t *testing.T) { + r := newTestResolver(t) + req, _ := http.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent?key=query-key", nil) + req.Header.Set("x-api-key", "managed-key") + + a, err := r.Determine(req) + if err != nil { + t.Fatalf("determine failed: %v", err) + } + defer r.Release(a) + if !a.UseConfigToken { + t.Fatalf("expected managed key mode from header token") + } + if a.AccountID == "" { + t.Fatalf("expected managed account to be acquired") + } +} + func TestDetermineCallerMissingToken(t *testing.T) { r := newTestResolver(t) req, _ := http.NewRequest(http.MethodGet, "/v1/responses/resp_1", nil) diff --git a/internal/format/openai/render.go b/internal/format/openai/render.go index 3f0519a..2107d4e 100644 --- a/internal/format/openai/render.go +++ b/internal/format/openai/render.go @@ -44,12 +44,11 @@ func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalT } func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { - // Responses output should only be treated as tool calls when the model - // produced a standalone structured payload. This prevents accidental - // empty output_text on normal prose that merely contains tool_call-like text. - detected := util.ParseStandaloneToolCalls(finalText, toolNames) + // Align responses tool-call semantics with chat/completions: + // mixed prose + tool_call payloads should still be interpreted as tool calls. + detected := util.ParseToolCalls(finalText, toolNames) if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" { - detected = util.ParseStandaloneToolCalls(finalThinking, toolNames) + detected = util.ParseToolCalls(finalThinking, toolNames) } exposedOutputText := finalText output := make([]any, 0, 2) diff --git a/internal/format/openai/render_test.go b/internal/format/openai/render_test.go index 2e36903..e3bf0dd 100644 --- a/internal/format/openai/render_test.go +++ b/internal/format/openai/render_test.go @@ -70,7 +70,7 @@ func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) { } } -func TestBuildResponseObjectKeepsOutputTextForMixedProse(t *testing.T) { +func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T) { obj := BuildResponseObject( "resp_test", "gpt-4o", @@ -81,17 +81,41 @@ func TestBuildResponseObjectKeepsOutputTextForMixedProse(t *testing.T) { ) outputText, _ := obj["output_text"].(string) - if outputText == "" { - t.Fatalf("expected output_text to be preserved for mixed prose") + if outputText != "" { + t.Fatalf("expected output_text hidden once tool calls are detected, got %q", outputText) } + output, _ := obj["output"].([]any) + if len(output) != 2 { + t.Fatalf("expected function_call + tool_calls wrapper, got %#v", obj["output"]) + } + first, _ := output[0].(map[string]any) + if first["type"] != "function_call" { + t.Fatalf("expected first output type function_call, got %#v", first["type"]) + } +} + +func TestBuildResponseObjectFencedToolPayloadRemainsText(t *testing.T) { + obj := BuildResponseObject( + "resp_test", + "gpt-4o", + "prompt", + "", + "```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"golang\"}}]}\n```", + []string{"search"}, + ) + + outputText, _ := obj["output_text"].(string) + if outputText == "" { + t.Fatalf("expected output_text preserved for fenced example") + } output, _ := obj["output"].([]any) if len(output) != 1 { - t.Fatalf("expected one output item, got %#v", obj["output"]) + t.Fatalf("expected one message output item, got %#v", obj["output"]) } first, _ := output[0].(map[string]any) if first["type"] != "message" { - t.Fatalf("expected output type message, got %#v", first["type"]) + t.Fatalf("expected message output type, got %#v", first["type"]) } } diff --git a/internal/server/router.go b/internal/server/router.go index a81f0cb..ae3108e 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -12,6 +12,7 @@ import ( "ds2api/internal/account" "ds2api/internal/adapter/claude" + "ds2api/internal/adapter/gemini" "ds2api/internal/adapter/openai" "ds2api/internal/admin" "ds2api/internal/auth" @@ -44,6 +45,7 @@ func NewApp() *App { openaiHandler := &openai.Handler{Store: store, Auth: resolver, DS: dsClient} claudeHandler := &claude.Handler{Store: store, Auth: resolver, DS: dsClient} + geminiHandler := &gemini.Handler{Store: store, Auth: resolver, DS: dsClient} adminHandler := &admin.Handler{Store: store, Pool: pool, DS: dsClient} webuiHandler := webui.NewHandler() @@ -67,6 +69,7 @@ func NewApp() *App { }) openai.RegisterRoutes(r, openaiHandler) claude.RegisterRoutes(r, claudeHandler) + gemini.RegisterRoutes(r, geminiHandler) r.Route("/admin", func(ar chi.Router) { admin.RegisterRoutes(ar, adminHandler) })