diff --git a/internal/httpapi/openai/citation_links_test.go b/internal/httpapi/openai/citation_links_test.go index 1cdaf90..a7f10d0 100644 --- a/internal/httpapi/openai/citation_links_test.go +++ b/internal/httpapi/openai/citation_links_test.go @@ -26,3 +26,31 @@ func TestReplaceCitationMarkersWithLinksKeepsUnknownIndex(t *testing.T) { t.Fatalf("expected %q, got %q", want, got) } } + +func TestReplaceCitationMarkersWithLinksSupportsReferenceMarker(t *testing.T) { + raw := "新闻摘要[reference:1],详情[reference:2]。" + links := map[int]string{ + 1: "https://example.com/r1", + 2: "https://example.com/r2", + } + + got := replaceCitationMarkersWithLinks(raw, links) + want := "新闻摘要[1](https://example.com/r1),详情[2](https://example.com/r2)。" + if got != want { + t.Fatalf("expected %q, got %q", want, got) + } +} + +func TestReplaceCitationMarkersWithLinksSupportsReferenceZeroBased(t *testing.T) { + raw := "来源[reference:0] 与 [reference:1]。" + links := map[int]string{ + 1: "https://example.com/first", + 2: "https://example.com/second", + } + + got := replaceCitationMarkersWithLinks(raw, links) + want := "来源[0](https://example.com/first) 与 [1](https://example.com/second)。" + if got != want { + t.Fatalf("expected %q, got %q", want, got) + } +} diff --git a/internal/httpapi/openai/shared/citation_links.go b/internal/httpapi/openai/shared/citation_links.go index 60d7408..9b2b77f 100644 --- a/internal/httpapi/openai/shared/citation_links.go +++ b/internal/httpapi/openai/shared/citation_links.go @@ -7,22 +7,27 @@ import ( "strings" ) -var citationMarkerPattern = regexp.MustCompile(`(?i)\[citation:\s*(\d+)\]`) +var citationMarkerPattern = regexp.MustCompile(`(?i)\[(citation|reference):\s*(\d+)\]`) func ReplaceCitationMarkersWithLinks(text string, links map[int]string) string { if strings.TrimSpace(text) == "" || len(links) == 0 { return text } + zeroBased := strings.Contains(strings.ToLower(text), "[reference:0]") return citationMarkerPattern.ReplaceAllStringFunc(text, func(match string) string { sub := citationMarkerPattern.FindStringSubmatch(match) - if len(sub) < 2 { + if len(sub) < 3 { return match } - idx, err := strconv.Atoi(strings.TrimSpace(sub[1])) - if err != nil || idx <= 0 { + idx, err := strconv.Atoi(strings.TrimSpace(sub[2])) + if err != nil || idx < 0 { return match } - url := strings.TrimSpace(links[idx]) + lookupIdx := idx + if zeroBased { + lookupIdx = idx + 1 + } + url := strings.TrimSpace(links[lookupIdx]) if url == "" { return match }