diff --git a/internal/adapter/gemini/handler_generate.go b/internal/adapter/gemini/handler_generate.go index 5703d0b..a7de92d 100644 --- a/internal/adapter/gemini/handler_generate.go +++ b/internal/adapter/gemini/handler_generate.go @@ -90,6 +90,11 @@ func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, stream defer res.Body.Close() body, _ := io.ReadAll(res.Body) if res.StatusCode < 200 || res.StatusCode >= 300 { + for k, vv := range res.Header { + for _, v := range vv { + w.Header().Add(k, v) + } + } writeGeminiErrorFromOpenAI(w, res.StatusCode, body) return true } diff --git a/internal/adapter/gemini/handler_test.go b/internal/adapter/gemini/handler_test.go index 20cc0e6..fdb4b79 100644 --- a/internal/adapter/gemini/handler_test.go +++ b/internal/adapter/gemini/handler_test.go @@ -64,9 +64,13 @@ func (m testGeminiDS) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ m type geminiOpenAIErrorStub struct { status int body string + headers map[string]string } func (s geminiOpenAIErrorStub) ChatCompletions(w http.ResponseWriter, _ *http.Request) { + for k, v := range s.headers { + w.Header().Set(k, v) + } w.Header().Set("Content-Type", "application/json") w.WriteHeader(s.status) _, _ = w.Write([]byte(s.body)) @@ -244,7 +248,15 @@ func TestStreamGenerateContentEmitsSSE(t *testing.T) { func TestGenerateContentOpenAIProxyErrorUsesGeminiEnvelope(t *testing.T) { h := &Handler{ Store: testGeminiConfig{}, - OpenAI: geminiOpenAIErrorStub{status: http.StatusUnauthorized, body: `{"error":{"message":"invalid api key"}}`}, + OpenAI: geminiOpenAIErrorStub{ + status: http.StatusUnauthorized, + body: `{"error":{"message":"invalid api key"}}`, + headers: map[string]string{ + "WWW-Authenticate": `Bearer realm="example"`, + "Retry-After": "30", + "X-RateLimit-Remaining": "0", + }, + }, } r := chi.NewRouter() RegisterRoutes(r, h) @@ -267,6 +279,15 @@ func TestGenerateContentOpenAIProxyErrorUsesGeminiEnvelope(t *testing.T) { if errObj["message"] != "invalid api key" { t.Fatalf("expected parsed error message, got=%v", errObj["message"]) } + if got := rec.Header().Get("WWW-Authenticate"); got == "" { + t.Fatalf("expected WWW-Authenticate header to be preserved") + } + if got := rec.Header().Get("Retry-After"); got != "30" { + t.Fatalf("expected Retry-After header 30, got=%q", got) + } + if got := rec.Header().Get("X-RateLimit-Remaining"); got != "0" { + t.Fatalf("expected X-RateLimit-Remaining header 0, got=%q", got) + } } func extractGeminiSSEFrames(t *testing.T, body string) []map[string]any {