diff --git a/internal/adapter/openai/citation_links.go b/internal/adapter/openai/citation_links.go new file mode 100644 index 0000000..009d728 --- /dev/null +++ b/internal/adapter/openai/citation_links.go @@ -0,0 +1,31 @@ +package openai + +import ( + "fmt" + "regexp" + "strconv" + "strings" +) + +var citationMarkerPattern = regexp.MustCompile(`(?i)\[citation:\s*(\d+)\]`) + +func replaceCitationMarkersWithLinks(text string, links map[int]string) string { + if strings.TrimSpace(text) == "" || len(links) == 0 { + return text + } + return citationMarkerPattern.ReplaceAllStringFunc(text, func(match string) string { + sub := citationMarkerPattern.FindStringSubmatch(match) + if len(sub) < 2 { + return match + } + idx, err := strconv.Atoi(strings.TrimSpace(sub[1])) + if err != nil || idx <= 0 { + return match + } + url := strings.TrimSpace(links[idx]) + if url == "" { + return match + } + return fmt.Sprintf("[%d](%s)", idx, url) + }) +} diff --git a/internal/adapter/openai/citation_links_test.go b/internal/adapter/openai/citation_links_test.go new file mode 100644 index 0000000..1cdaf90 --- /dev/null +++ b/internal/adapter/openai/citation_links_test.go @@ -0,0 +1,28 @@ +package openai + +import "testing" + +func TestReplaceCitationMarkersWithLinks(t *testing.T) { + raw := "这是一条更新[citation:1],更多信息见[citation:2]。" + links := map[int]string{ + 1: "https://example.com/news-1", + 2: "https://example.com/news-2", + } + + got := replaceCitationMarkersWithLinks(raw, links) + want := "这是一条更新[1](https://example.com/news-1),更多信息见[2](https://example.com/news-2)。" + if got != want { + t.Fatalf("expected %q, got %q", want, got) + } +} + +func TestReplaceCitationMarkersWithLinksKeepsUnknownIndex(t *testing.T) { + raw := "只有一个来源[citation:1],未知来源[citation:3]。" + links := map[int]string{1: "https://example.com/a"} + + got := replaceCitationMarkersWithLinks(raw, links) + want := "只有一个来源[1](https://example.com/a),未知来源[citation:3]。" + if got != want { + t.Fatalf("expected %q, got %q", want, got) + } +} diff --git a/internal/adapter/openai/handler_chat.go b/internal/adapter/openai/handler_chat.go index 5599eec..e2c5f3c 100644 --- a/internal/adapter/openai/handler_chat.go +++ b/internal/adapter/openai/handler_chat.go @@ -88,7 +88,7 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) return } - h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames) + h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) } func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAuth, sessionID string) { @@ -124,7 +124,7 @@ func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAu } } -func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) { +func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { if resp.StatusCode != http.StatusOK { defer func() { _ = resp.Body.Close() }() body, _ := io.ReadAll(resp.Body) @@ -137,6 +137,9 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re stripReferenceMarkers := h.compatStripReferenceMarkers() finalThinking := cleanVisibleOutput(result.Thinking, stripReferenceMarkers) finalText := cleanVisibleOutput(result.Text, stripReferenceMarkers) + if searchEnabled { + finalText = replaceCitationMarkersWithLinks(finalText, result.CitationLinks) + } if writeUpstreamEmptyOutputError(w, finalText, result.ContentFilter) { return } diff --git a/internal/adapter/openai/handler_toolcall_test.go b/internal/adapter/openai/handler_toolcall_test.go index bad8820..a274d5b 100644 --- a/internal/adapter/openai/handler_toolcall_test.go +++ b/internal/adapter/openai/handler_toolcall_test.go @@ -94,7 +94,7 @@ func TestHandleNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T) { ) rec := httptest.NewRecorder() - h.handleNonStream(rec, context.Background(), resp, "cid-empty", "deepseek-chat", "prompt", false, nil) + h.handleNonStream(rec, context.Background(), resp, "cid-empty", "deepseek-chat", "prompt", false, false, nil) if rec.Code != http.StatusTooManyRequests { t.Fatalf("expected status 429 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String()) } @@ -113,7 +113,7 @@ func TestHandleNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWithoutOutp ) rec := httptest.NewRecorder() - h.handleNonStream(rec, context.Background(), resp, "cid-empty-filtered", "deepseek-chat", "prompt", false, nil) + h.handleNonStream(rec, context.Background(), resp, "cid-empty-filtered", "deepseek-chat", "prompt", false, false, nil) if rec.Code != http.StatusBadRequest { t.Fatalf("expected status 400 for filtered upstream output, got %d body=%s", rec.Code, rec.Body.String()) } @@ -132,7 +132,7 @@ func TestHandleNonStreamReturns429WhenUpstreamHasOnlyThinking(t *testing.T) { ) rec := httptest.NewRecorder() - h.handleNonStream(rec, context.Background(), resp, "cid-thinking-only", "deepseek-reasoner", "prompt", true, nil) + h.handleNonStream(rec, context.Background(), resp, "cid-thinking-only", "deepseek-reasoner", "prompt", true, false, nil) if rec.Code != http.StatusTooManyRequests { t.Fatalf("expected status 429 for thinking-only upstream output, got %d body=%s", rec.Code, rec.Body.String()) } diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index 6494157..35c616b 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -112,10 +112,10 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) { h.handleResponsesStream(w, r, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID) return } - h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames, stdReq.ToolChoice, traceID) + h.handleResponsesNonStream(w, resp, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolChoice, traceID) } -func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled bool, toolNames []string, toolChoice util.ToolChoicePolicy, traceID string) { +func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolChoice util.ToolChoicePolicy, traceID string) { defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) @@ -126,6 +126,9 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res stripReferenceMarkers := h.compatStripReferenceMarkers() sanitizedThinking := cleanVisibleOutput(result.Thinking, stripReferenceMarkers) sanitizedText := cleanVisibleOutput(result.Text, stripReferenceMarkers) + if searchEnabled { + sanitizedText = replaceCitationMarkersWithLinks(sanitizedText, result.CitationLinks) + } if writeUpstreamEmptyOutputError(w, sanitizedText, result.ContentFilter) { return } diff --git a/internal/adapter/openai/responses_stream_test.go b/internal/adapter/openai/responses_stream_test.go index 2e139d3..f9f170e 100644 --- a/internal/adapter/openai/responses_stream_test.go +++ b/internal/adapter/openai/responses_stream_test.go @@ -196,7 +196,7 @@ func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) { Allowed: map[string]struct{}{"read_file": {}}, } - h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, []string{"read_file"}, policy, "") + h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, policy, "") if rec.Code != http.StatusUnprocessableEntity { t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String()) } @@ -223,7 +223,7 @@ func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayload(t Allowed: map[string]struct{}{"read_file": {}}, } - h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, []string{"read_file"}, policy, "") + h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, false, []string{"read_file"}, policy, "") if rec.Code != http.StatusUnprocessableEntity { t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String()) } @@ -245,7 +245,7 @@ func TestHandleResponsesNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T) )), } - h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, util.DefaultToolChoicePolicy(), "") + h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, util.DefaultToolChoicePolicy(), "") if rec.Code != http.StatusTooManyRequests { t.Fatalf("expected 429 for empty upstream output, got %d body=%s", rec.Code, rec.Body.String()) } @@ -267,7 +267,7 @@ func TestHandleResponsesNonStreamReturnsContentFilterErrorWhenUpstreamFilteredWi )), } - h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, util.DefaultToolChoicePolicy(), "") + h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, util.DefaultToolChoicePolicy(), "") if rec.Code != http.StatusBadRequest { t.Fatalf("expected 400 for filtered empty upstream output, got %d body=%s", rec.Code, rec.Body.String()) } @@ -289,7 +289,7 @@ func TestHandleResponsesNonStreamReturns429WhenUpstreamHasOnlyThinking(t *testin )), } - h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, nil, util.DefaultToolChoicePolicy(), "") + h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, nil, util.DefaultToolChoicePolicy(), "") if rec.Code != http.StatusTooManyRequests { t.Fatalf("expected 429 for thinking-only upstream output, got %d body=%s", rec.Code, rec.Body.String()) } diff --git a/internal/sse/citation_links.go b/internal/sse/citation_links.go new file mode 100644 index 0000000..978d02d --- /dev/null +++ b/internal/sse/citation_links.go @@ -0,0 +1,124 @@ +package sse + +import ( + "strconv" + "strings" +) + +type citationLinkCollector struct { + ordered []string + seen map[string]struct{} + explicit map[int]string +} + +func newCitationLinkCollector() *citationLinkCollector { + return &citationLinkCollector{ + seen: map[string]struct{}{}, + explicit: map[int]string{}, + } +} + +func (c *citationLinkCollector) ingestChunk(chunk map[string]any) { + if c == nil || len(chunk) == 0 { + return + } + c.walkValue(chunk) +} + +func (c *citationLinkCollector) build() map[int]string { + out := make(map[int]string, len(c.explicit)+len(c.ordered)) + for idx, u := range c.explicit { + if idx > 0 && strings.TrimSpace(u) != "" { + out[idx] = u + } + } + for i, u := range c.ordered { + idx := i + 1 + if _, exists := out[idx]; !exists { + out[idx] = u + } + } + return out +} + +func (c *citationLinkCollector) walkValue(v any) { + switch x := v.(type) { + case []any: + for _, item := range x { + c.walkValue(item) + } + case map[string]any: + c.captureURLAndIndex(x) + for _, vv := range x { + c.walkValue(vv) + } + } +} + +func (c *citationLinkCollector) captureURLAndIndex(m map[string]any) { + url := strings.TrimSpace(asString(m["url"])) + if !isWebURL(url) { + return + } + c.addOrdered(url) + + idx, hasIdx := citationIndexFromAny(m["cite_index"]) + if !hasIdx { + return + } + if idx <= 0 { + idx = idx + 1 + } + if idx <= 0 { + return + } + if existing, ok := c.explicit[idx]; ok && strings.TrimSpace(existing) != "" { + return + } + c.explicit[idx] = url +} + +func (c *citationLinkCollector) addOrdered(url string) { + if _, ok := c.seen[url]; ok { + return + } + c.seen[url] = struct{}{} + c.ordered = append(c.ordered, url) +} + +func citationIndexFromAny(v any) (int, bool) { + switch x := v.(type) { + case int: + return x, true + case int32: + return int(x), true + case int64: + return int(x), true + case float32: + return int(x), true + case float64: + return int(x), true + case string: + s := strings.TrimSpace(x) + if s == "" { + return 0, false + } + n, err := strconv.Atoi(s) + if err != nil { + return 0, false + } + return n, true + default: + return 0, false + } +} + +func isWebURL(v string) bool { + v = strings.ToLower(strings.TrimSpace(v)) + return strings.HasPrefix(v, "http://") || strings.HasPrefix(v, "https://") +} + +func asString(v any) string { + s, _ := v.(string) + return s +} diff --git a/internal/sse/consumer.go b/internal/sse/consumer.go index 341db2b..83d66e9 100644 --- a/internal/sse/consumer.go +++ b/internal/sse/consumer.go @@ -13,6 +13,7 @@ type CollectResult struct { Text string Thinking string ContentFilter bool + CitationLinks map[int]string } // CollectStream fully consumes a DeepSeek SSE response and separates @@ -28,11 +29,15 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co text := strings.Builder{} thinking := strings.Builder{} contentFilter := false + collector := newCitationLinkCollector() currentType := "text" if thinkingEnabled { currentType = "thinking" } _ = deepseek.ScanSSELines(resp, func(line []byte) bool { + if chunk, done, parsed := ParseDeepSeekSSELine(line); parsed && !done { + collector.ingestChunk(chunk) + } result := ParseDeepSeekContentLine(line, thinkingEnabled, currentType) currentType = result.NextType if !result.Parsed { @@ -59,5 +64,6 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co Text: text.String(), Thinking: thinking.String(), ContentFilter: contentFilter, + CitationLinks: collector.build(), } } diff --git a/internal/sse/consumer_edge_test.go b/internal/sse/consumer_edge_test.go index 54f841b..d0da4bf 100644 --- a/internal/sse/consumer_edge_test.go +++ b/internal/sse/consumer_edge_test.go @@ -115,6 +115,22 @@ func TestCollectStreamWithCitation(t *testing.T) { } } +func TestCollectStreamExtractsCitationLinks(t *testing.T) { + resp := makeHTTPResponse( + "data: {\"p\":\"response/fragments/-1/results\",\"v\":[{\"url\":\"https://example.com/a\",\"cite_index\":0},{\"url\":\"https://example.com/b\",\"cite_index\":1}]}\n" + + "data: {\"p\":\"response/content\",\"v\":\"结论[citation:1][citation:2]\"}\n" + + "data: [DONE]\n", + ) + result := CollectStream(resp, false, false) + + if got := result.CitationLinks[1]; got != "https://example.com/a" { + t.Fatalf("expected citation 1 link, got %q", got) + } + if got := result.CitationLinks[2]; got != "https://example.com/b" { + t.Fatalf("expected citation 2 link, got %q", got) + } +} + func TestCollectStreamMultipleThinkingChunks(t *testing.T) { resp := makeHTTPResponse( "data: {\"p\":\"response/thinking_content\",\"v\":\"part1\"}\n" +