mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-08 18:35:35 +08:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
410efbd70b | ||
|
|
7179b995bb | ||
|
|
fef3798e5e | ||
|
|
00fe18b505 | ||
|
|
9b746e32d8 | ||
|
|
ace440481a | ||
|
|
66e0fa568f | ||
|
|
fa489248bc | ||
|
|
657b9379ed | ||
|
|
9062330104 | ||
|
|
d0d61a5d77 | ||
|
|
ffef451f7a | ||
|
|
a68a79e087 | ||
|
|
c8db66615c | ||
|
|
79ae9c8970 | ||
|
|
2378f0fbe7 |
18
API.en.md
18
API.en.md
@@ -18,6 +18,7 @@ Docs: [Overview](README.en.md) / [Architecture](docs/ARCHITECTURE.en.md) / [Depl
|
||||
- [OpenAI-Compatible API](#openai-compatible-api)
|
||||
- [Claude-Compatible API](#claude-compatible-api)
|
||||
- [Gemini-Compatible API](#gemini-compatible-api)
|
||||
- [Ollama API](#ollama-api)
|
||||
- [Admin API](#admin-api)
|
||||
- [Error Payloads](#error-payloads)
|
||||
- [cURL Examples](#curl-examples)
|
||||
@@ -123,6 +124,9 @@ Gemini-compatible clients can also send `x-goog-api-key`, `?key=`, or `?api_key=
|
||||
| POST | `/v1beta/models/{model}:streamGenerateContent` | Business | Gemini stream |
|
||||
| POST | `/v1/models/{model}:generateContent` | Business | Gemini non-stream compat path |
|
||||
| POST | `/v1/models/{model}:streamGenerateContent` | Business | Gemini stream compat path |
|
||||
| GET | `/api/version` | None | Ollama version endpoint |
|
||||
| GET | `/api/tags` | None | Ollama model list |
|
||||
| POST | `/api/show` | None | Ollama model capability query (returns `id` + `capabilities`) |
|
||||
| POST | `/admin/login` | None | Admin login |
|
||||
| GET | `/admin/verify` | JWT | Verify admin JWT |
|
||||
| GET | `/admin/vercel/config` | Admin | Read preconfigured Vercel creds |
|
||||
@@ -617,6 +621,20 @@ Returns SSE (`text/event-stream`), each chunk as `data: <json>`:
|
||||
|
||||
---
|
||||
|
||||
## Ollama API
|
||||
|
||||
- `POST /api/show` request body: `{"model":"<model-id>"}`.
|
||||
- Response uses lowercase `id` (not `ID`) and includes `capabilities` for Ollama-style clients and strict schemas.
|
||||
|
||||
Example response:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "deepseek-v4-flash",
|
||||
"capabilities": ["tools", "thinking"]
|
||||
}
|
||||
```
|
||||
|
||||
## Admin API
|
||||
|
||||
### `POST /admin/login`
|
||||
|
||||
18
API.md
18
API.md
@@ -18,6 +18,7 @@
|
||||
- [OpenAI 兼容接口](#openai-兼容接口)
|
||||
- [Claude 兼容接口](#claude-兼容接口)
|
||||
- [Gemini 兼容接口](#gemini-兼容接口)
|
||||
- [Ollama 兼容接口](#ollama-兼容接口)
|
||||
- [Admin 接口](#admin-接口)
|
||||
- [错误响应格式](#错误响应格式)
|
||||
- [cURL 示例](#curl-示例)
|
||||
@@ -125,6 +126,9 @@ Gemini 兼容客户端还可以使用 `x-goog-api-key`、`?key=` 或 `?api_key=`
|
||||
| POST | `/v1beta/models/{model}:streamGenerateContent` | 业务 | Gemini 流式 |
|
||||
| POST | `/v1/models/{model}:generateContent` | 业务 | Gemini 非流式兼容路径 |
|
||||
| POST | `/v1/models/{model}:streamGenerateContent` | 业务 | Gemini 流式兼容路径 |
|
||||
| GET | `/api/version` | 无 | Ollama 版本接口 |
|
||||
| GET | `/api/tags` | 无 | Ollama 模型列表 |
|
||||
| POST | `/api/show` | 无 | Ollama 单模型能力查询(返回 `id` 与 `capabilities`) |
|
||||
| POST | `/admin/login` | 无 | 管理登录 |
|
||||
| GET | `/admin/verify` | JWT | 校验管理 JWT |
|
||||
| GET | `/admin/vercel/config` | Admin | 读取 Vercel 预配置 |
|
||||
@@ -628,6 +632,20 @@ data: {"type":"message_stop"}
|
||||
|
||||
---
|
||||
|
||||
## Ollama 兼容接口
|
||||
|
||||
- `POST /api/show` 请求体:`{"model":"<model-id>"}`。
|
||||
- 响应字段使用小写 `id`(不是 `ID`),并返回 `capabilities` 数组,便于与 Ollama 风格客户端/严格 schema 对齐。
|
||||
|
||||
示例响应:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "deepseek-v4-flash",
|
||||
"capabilities": ["tools", "thinking"]
|
||||
}
|
||||
```
|
||||
|
||||
## Admin 接口
|
||||
|
||||
### `POST /admin/login`
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package config
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ModelInfo struct {
|
||||
ID string `json:"id"`
|
||||
@@ -9,6 +12,16 @@ type ModelInfo struct {
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Permission []any `json:"permission,omitempty"`
|
||||
}
|
||||
type OllamaModelInfo struct {
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
Size int64 `json:"size"`
|
||||
ModifiedAt string `json:"modified_at"`
|
||||
}
|
||||
type OllamaCapabilitiesModelInfo struct {
|
||||
ID string `json:"id"`
|
||||
Capabilities []string `json:"capabilities"`
|
||||
}
|
||||
|
||||
type ModelAliasReader interface {
|
||||
ModelAliases() map[string]string
|
||||
@@ -24,8 +37,21 @@ var deepSeekBaseModels = []ModelInfo{
|
||||
{ID: "deepseek-v4-vision", Object: "model", Created: 1677610602, OwnedBy: "deepseek", Permission: []any{}},
|
||||
}
|
||||
|
||||
var DeepSeekModels = appendNoThinkingVariants(deepSeekBaseModels)
|
||||
var OllamaCapabilitiesModels = []OllamaCapabilitiesModelInfo{
|
||||
{ID: "deepseek-v4-flash", Capabilities: []string{"tools", "thinking"}},
|
||||
{ID: "deepseek-v4-pro", Capabilities: []string{"tools", "thinking"}},
|
||||
{ID: "deepseek-v4-flash-search", Capabilities: []string{"tools", "thinking"}},
|
||||
{ID: "deepseek-v4-pro-search", Capabilities: []string{"tools", "thinking"}},
|
||||
{ID: "deepseek-v4-vision", Capabilities: []string{"tools", "thinking", "vision"}},
|
||||
{ID: "deepseek-v4-flash-nothinking", Capabilities: []string{"tools"}},
|
||||
{ID: "deepseek-v4-pro-nothinking", Capabilities: []string{"tools"}},
|
||||
{ID: "deepseek-v4-flash-search-nothinking", Capabilities: []string{"tools"}},
|
||||
{ID: "deepseek-v4-pro-search-nothinking", Capabilities: []string{"tools"}},
|
||||
{ID: "deepseek-v4-vision-nothinking", Capabilities: []string{"tools", "vision"}},
|
||||
}
|
||||
|
||||
var DeepSeekModels = appendNoThinkingVariants(deepSeekBaseModels)
|
||||
var OllamaModels = mapToOllamaModels(DeepSeekModels)
|
||||
var claudeBaseModels = []ModelInfo{
|
||||
// Current aliases
|
||||
{ID: "claude-opus-4-6", Object: "model", Created: 1715635200, OwnedBy: "anthropic"},
|
||||
@@ -247,6 +273,23 @@ func OpenAIModelByID(store ModelAliasReader, id string) (ModelInfo, bool) {
|
||||
return ModelInfo{}, false
|
||||
}
|
||||
|
||||
func OllamaModelsResponse() map[string]any {
|
||||
return map[string]any{"models": OllamaModels}
|
||||
}
|
||||
|
||||
func OllamaModelByID(store ModelAliasReader, id string) (OllamaCapabilitiesModelInfo, bool) {
|
||||
canonical, ok := ResolveModel(store, id)
|
||||
if !ok {
|
||||
return OllamaCapabilitiesModelInfo{}, false
|
||||
}
|
||||
for _, model := range OllamaCapabilitiesModels {
|
||||
if model.ID == canonical {
|
||||
return model, true
|
||||
}
|
||||
}
|
||||
return OllamaCapabilitiesModelInfo{}, false
|
||||
}
|
||||
|
||||
func ClaudeModelsResponse() map[string]any {
|
||||
resp := map[string]any{"object": "list", "data": ClaudeModels}
|
||||
if len(ClaudeModels) > 0 {
|
||||
@@ -270,6 +313,23 @@ func appendNoThinkingVariants(models []ModelInfo) []ModelInfo {
|
||||
}
|
||||
return out
|
||||
}
|
||||
func mapToOllamaModels(models []ModelInfo) []OllamaModelInfo {
|
||||
out := make([]OllamaModelInfo, 0, len(models))
|
||||
for _, model := range models {
|
||||
var modifiedAt string
|
||||
if model.Created > 0 {
|
||||
modifiedAt = time.Unix(model.Created, 0).Format(time.RFC3339)
|
||||
}
|
||||
ollamaModel := OllamaModelInfo{
|
||||
Name: model.ID,
|
||||
Model: model.ID,
|
||||
Size: 0,
|
||||
ModifiedAt: modifiedAt,
|
||||
}
|
||||
out = append(out, ollamaModel)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func splitNoThinkingModel(model string) (string, bool) {
|
||||
model = lower(strings.TrimSpace(model))
|
||||
|
||||
@@ -58,6 +58,11 @@ func RawStreamSampleRoot() string {
|
||||
}
|
||||
|
||||
func ChatHistoryPath() string {
|
||||
// On Vercel, /var/task is read-only at runtime. If no explicit path is set,
|
||||
// default to /tmp/chat_history.json (the only writable directory).
|
||||
if IsVercel() && strings.TrimSpace(os.Getenv("DS2API_CHAT_HISTORY_PATH")) == "" {
|
||||
return "/tmp/chat_history.json"
|
||||
}
|
||||
return ResolvePath("DS2API_CHAT_HISTORY_PATH", "data/chat_history.json")
|
||||
}
|
||||
|
||||
|
||||
58
internal/httpapi/ollama/handler_routes.go
Normal file
58
internal/httpapi/ollama/handler_routes.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/util"
|
||||
"encoding/json"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var WriteJSON = util.WriteJSON
|
||||
|
||||
type ConfigReader interface {
|
||||
ModelAliases() map[string]string
|
||||
}
|
||||
|
||||
type Handler struct {
|
||||
Store ConfigReader
|
||||
}
|
||||
|
||||
type OllamaModelRequest struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
r.Get("/api/version", h.GetVersion)
|
||||
r.Get("/api/tags", h.ListOllamaModels)
|
||||
r.Post("/api/show", h.GetOllamaModel)
|
||||
}
|
||||
|
||||
func (h *Handler) GetVersion(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"version":"0.23.1"}`))
|
||||
}
|
||||
func (h *Handler) ListOllamaModels(w http.ResponseWriter, r *http.Request) {
|
||||
WriteJSON(w, http.StatusOK, config.OllamaModelsResponse())
|
||||
}
|
||||
func (h *Handler) GetOllamaModel(w http.ResponseWriter, r *http.Request) {
|
||||
var payload OllamaModelRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
http.Error(w, "Invalid JSON body: "+err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := r.Body.Close(); err != nil {
|
||||
slog.Warn("[ollama] failed to close request body", "error", err)
|
||||
}
|
||||
}()
|
||||
modelID := payload.Model
|
||||
model, ok := config.OllamaModelByID(h.Store, modelID)
|
||||
if !ok {
|
||||
http.Error(w, "Model not found.", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
WriteJSON(w, http.StatusOK, model)
|
||||
}
|
||||
127
internal/httpapi/ollama/handler_routes_test.go
Normal file
127
internal/httpapi/ollama/handler_routes_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type ollamaTestSurface struct {
|
||||
Store ConfigReader
|
||||
handler *Handler
|
||||
}
|
||||
|
||||
func (h *ollamaTestSurface) apiHandler() *Handler {
|
||||
if h.handler == nil {
|
||||
h.handler = &Handler{Store: h.Store}
|
||||
}
|
||||
return h.handler
|
||||
}
|
||||
|
||||
func registerOllamaTestRoutes(r chi.Router, h *ollamaTestSurface) {
|
||||
r.Get("/api/version", h.apiHandler().GetVersion)
|
||||
r.Get("/api/tags", h.apiHandler().ListOllamaModels)
|
||||
r.Post("/api/show", h.apiHandler().GetOllamaModel)
|
||||
}
|
||||
|
||||
func TestGetOllamaVersionRoute(t *testing.T) {
|
||||
h := &ollamaTestSurface{}
|
||||
r := chi.NewRouter()
|
||||
registerOllamaTestRoutes(r, h)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/version", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOllamaModelsRoute(t *testing.T) {
|
||||
h := &ollamaTestSurface{}
|
||||
r := chi.NewRouter()
|
||||
registerOllamaTestRoutes(r, h)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/tags", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOllamaModelRoute(t *testing.T) {
|
||||
h := &ollamaTestSurface{}
|
||||
r := chi.NewRouter()
|
||||
registerOllamaTestRoutes(r, h)
|
||||
|
||||
t.Run("direct", func(t *testing.T) {
|
||||
body := `{"model":"deepseek-v4-flash"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/show", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("expected valid json body, got err=%v body=%s", err, rec.Body.String())
|
||||
}
|
||||
if _, ok := payload["id"]; !ok {
|
||||
t.Fatalf("expected response has lowercase id field, body=%s", rec.Body.String())
|
||||
}
|
||||
if _, ok := payload["ID"]; ok {
|
||||
t.Fatalf("expected response does not expose uppercase ID field, body=%s", rec.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("direct_nothinking", func(t *testing.T) {
|
||||
body := `{"model":"deepseek-v4-flash-nothinking"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/show", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("direct_expert", func(t *testing.T) {
|
||||
body := `{"model":"deepseek-v4-pro"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/show", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("direct_vision", func(t *testing.T) {
|
||||
body := `{"model":"deepseek-v4-vision"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/show", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetOllamaModelRouteNotFound(t *testing.T) {
|
||||
h := &ollamaTestSurface{}
|
||||
r := chi.NewRouter()
|
||||
registerOllamaTestRoutes(r, h)
|
||||
|
||||
body := `{"model":"not-exists"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/show", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Fatalf("expected 404, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
}
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"ds2api/internal/httpapi/admin"
|
||||
"ds2api/internal/httpapi/claude"
|
||||
"ds2api/internal/httpapi/gemini"
|
||||
"ds2api/internal/httpapi/ollama"
|
||||
"ds2api/internal/httpapi/openai/chat"
|
||||
"ds2api/internal/httpapi/openai/embeddings"
|
||||
"ds2api/internal/httpapi/openai/files"
|
||||
@@ -68,6 +69,7 @@ func NewApp() (*App, error) {
|
||||
claudeHandler := &claude.Handler{Store: store, Auth: resolver, DS: dsClient, OpenAI: chatHandler, ChatHistory: chatHistoryStore}
|
||||
geminiHandler := &gemini.Handler{Store: store, Auth: resolver, DS: dsClient, OpenAI: chatHandler, ChatHistory: chatHistoryStore}
|
||||
adminHandler := &admin.Handler{Store: store, Pool: pool, DS: dsClient, OpenAI: chatHandler, ChatHistory: chatHistoryStore}
|
||||
ollamaHandler := &ollama.Handler{Store: store}
|
||||
webuiHandler := webui.NewHandler()
|
||||
|
||||
r := chi.NewRouter()
|
||||
@@ -112,6 +114,7 @@ func NewApp() (*App, error) {
|
||||
r.Post("/embeddings", embeddingsHandler.Embeddings)
|
||||
claude.RegisterRoutes(r, claudeHandler)
|
||||
gemini.RegisterRoutes(r, geminiHandler)
|
||||
ollama.RegisterRoutes(r, ollamaHandler)
|
||||
r.Route("/admin", func(ar chi.Router) {
|
||||
admin.RegisterRoutes(ar, adminHandler)
|
||||
})
|
||||
|
||||
@@ -17,11 +17,10 @@ func rewriteDSMLToolMarkupOutsideIgnored(text string) string {
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
lower := strings.ToLower(text)
|
||||
var b strings.Builder
|
||||
b.Grow(len(text))
|
||||
for i := 0; i < len(text); {
|
||||
next, advanced, blocked := skipXMLIgnoredSection(text, lower, i)
|
||||
next, advanced, blocked := skipXMLIgnoredSection(text, i)
|
||||
if blocked {
|
||||
b.WriteString(text[i:])
|
||||
break
|
||||
|
||||
@@ -144,7 +144,7 @@ func findXMLStartTagOutsideCDATA(text, tag string, from int) (start, bodyStart i
|
||||
lower := strings.ToLower(text)
|
||||
target := "<" + strings.ToLower(tag)
|
||||
for i := maxInt(from, 0); i < len(text); {
|
||||
next, advanced, blocked := skipXMLIgnoredSection(text, lower, i)
|
||||
next, advanced, blocked := skipXMLIgnoredSection(text, i)
|
||||
if blocked {
|
||||
return -1, -1, "", false
|
||||
}
|
||||
@@ -170,7 +170,7 @@ func findMatchingXMLEndTagOutsideCDATA(text, tag string, from int) (closeStart,
|
||||
closeTarget := "</" + strings.ToLower(tag)
|
||||
depth := 1
|
||||
for i := maxInt(from, 0); i < len(text); {
|
||||
next, advanced, blocked := skipXMLIgnoredSection(text, lower, i)
|
||||
next, advanced, blocked := skipXMLIgnoredSection(text, i)
|
||||
if blocked {
|
||||
return -1, -1, false
|
||||
}
|
||||
@@ -206,16 +206,19 @@ func findMatchingXMLEndTagOutsideCDATA(text, tag string, from int) (closeStart,
|
||||
return -1, -1, false
|
||||
}
|
||||
|
||||
func skipXMLIgnoredSection(text, lower string, i int) (next int, advanced bool, blocked bool) {
|
||||
func skipXMLIgnoredSection(text string, i int) (next int, advanced bool, blocked bool) {
|
||||
if i < 0 || i >= len(text) {
|
||||
return i, false, false
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(lower[i:], "<![cdata["):
|
||||
end := findToolCDATAEnd(text, lower, i+len("<![cdata["))
|
||||
case hasASCIIPrefixFoldAt(text, i, "<![cdata["):
|
||||
end := findToolCDATAEnd(text, i+len("<![cdata["))
|
||||
if end < 0 {
|
||||
return 0, false, true
|
||||
}
|
||||
return end + len("]]>"), true, false
|
||||
case strings.HasPrefix(lower[i:], "<!--"):
|
||||
end := strings.Index(lower[i+len("<!--"):], "-->")
|
||||
case strings.HasPrefix(text[i:], "<!--"):
|
||||
end := strings.Index(text[i+len("<!--"):], "-->")
|
||||
if end < 0 {
|
||||
return 0, false, true
|
||||
}
|
||||
@@ -225,14 +228,33 @@ func skipXMLIgnoredSection(text, lower string, i int) (next int, advanced bool,
|
||||
}
|
||||
}
|
||||
|
||||
func findToolCDATAEnd(text, lower string, from int) int {
|
||||
if from < 0 || from > len(text) {
|
||||
func hasASCIIPrefixFoldAt(text string, start int, prefix string) bool {
|
||||
if start < 0 || len(text)-start < len(prefix) {
|
||||
return false
|
||||
}
|
||||
for j := 0; j < len(prefix); j++ {
|
||||
if asciiLower(text[start+j]) != asciiLower(prefix[j]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func asciiLower(b byte) byte {
|
||||
if b >= 'A' && b <= 'Z' {
|
||||
return b + ('a' - 'A')
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func findToolCDATAEnd(text string, from int) int {
|
||||
if from < 0 || from >= len(text) {
|
||||
return -1
|
||||
}
|
||||
const closeMarker = "]]>"
|
||||
firstNonFenceEnd := -1
|
||||
for searchFrom := from; searchFrom < len(text); {
|
||||
rel := strings.Index(lower[searchFrom:], closeMarker)
|
||||
rel := strings.Index(text[searchFrom:], closeMarker)
|
||||
if rel < 0 {
|
||||
break
|
||||
}
|
||||
@@ -241,27 +263,28 @@ func findToolCDATAEnd(text, lower string, from int) int {
|
||||
if cdataOffsetIsInsideMarkdownFence(text[from:end]) {
|
||||
continue
|
||||
}
|
||||
if cdataEndLooksStructural(text, searchFrom) {
|
||||
return end
|
||||
}
|
||||
if firstNonFenceEnd < 0 {
|
||||
firstNonFenceEnd = end
|
||||
}
|
||||
if cdataEndLooksStructural(lower, searchFrom) {
|
||||
return end
|
||||
}
|
||||
}
|
||||
return firstNonFenceEnd
|
||||
}
|
||||
|
||||
func cdataEndLooksStructural(lower string, after int) bool {
|
||||
for after < len(lower) {
|
||||
switch lower[after] {
|
||||
case ' ', '\t', '\r', '\n':
|
||||
func cdataEndLooksStructural(text string, after int) bool {
|
||||
for after < len(text) {
|
||||
switch {
|
||||
case text[after] == ' ' || text[after] == '\t' || text[after] == '\r' || text[after] == '\n':
|
||||
after++
|
||||
continue
|
||||
case after+1 < len(text) && text[after] == '<' && text[after+1] == '/':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
break
|
||||
}
|
||||
return strings.HasPrefix(lower[after:], "</")
|
||||
return false
|
||||
}
|
||||
|
||||
func cdataOffsetIsInsideMarkdownFence(fragment string) bool {
|
||||
|
||||
@@ -28,9 +28,8 @@ type ToolMarkupTag struct {
|
||||
}
|
||||
|
||||
func ContainsToolMarkupSyntaxOutsideIgnored(text string) (hasDSML, hasCanonical bool) {
|
||||
lower := strings.ToLower(text)
|
||||
for i := 0; i < len(text); {
|
||||
next, advanced, blocked := skipXMLIgnoredSection(text, lower, i)
|
||||
next, advanced, blocked := skipXMLIgnoredSection(text, i)
|
||||
if blocked {
|
||||
return hasDSML, hasCanonical
|
||||
}
|
||||
@@ -56,9 +55,8 @@ func ContainsToolMarkupSyntaxOutsideIgnored(text string) (hasDSML, hasCanonical
|
||||
}
|
||||
|
||||
func ContainsToolCallWrapperSyntaxOutsideIgnored(text string) (hasDSML, hasCanonical bool) {
|
||||
lower := strings.ToLower(text)
|
||||
for i := 0; i < len(text); {
|
||||
next, advanced, blocked := skipXMLIgnoredSection(text, lower, i)
|
||||
next, advanced, blocked := skipXMLIgnoredSection(text, i)
|
||||
if blocked {
|
||||
return hasDSML, hasCanonical
|
||||
}
|
||||
@@ -88,9 +86,8 @@ func ContainsToolCallWrapperSyntaxOutsideIgnored(text string) (hasDSML, hasCanon
|
||||
}
|
||||
|
||||
func FindToolMarkupTagOutsideIgnored(text string, start int) (ToolMarkupTag, bool) {
|
||||
lower := strings.ToLower(text)
|
||||
for i := maxInt(start, 0); i < len(text); {
|
||||
next, advanced, blocked := skipXMLIgnoredSection(text, lower, i)
|
||||
next, advanced, blocked := skipXMLIgnoredSection(text, i)
|
||||
if blocked {
|
||||
return ToolMarkupTag{}, false
|
||||
}
|
||||
@@ -107,7 +104,7 @@ func FindToolMarkupTagOutsideIgnored(text string, start int) (ToolMarkupTag, boo
|
||||
}
|
||||
|
||||
func FindMatchingToolMarkupClose(text string, open ToolMarkupTag) (ToolMarkupTag, bool) {
|
||||
if text == "" || open.Name == "" || open.Closing {
|
||||
if text == "" || open.Name == "" || open.Closing || open.End >= len(text) {
|
||||
return ToolMarkupTag{}, false
|
||||
}
|
||||
depth := 1
|
||||
|
||||
@@ -892,3 +892,139 @@ func TestParseToolCallsSkipsProseMentionOfSameWrapperVariant(t *testing.T) {
|
||||
t.Fatalf("expected command to parse, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTurkishILowercaseMapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
text string
|
||||
start int
|
||||
wantOk bool
|
||||
wantName string
|
||||
}{
|
||||
{"turkish_i_at_name_start", "İ<tool>", 0, false, ""},
|
||||
{"turkish_i_at_name_end", "<toolİ>", 0, false, ""},
|
||||
{"turkish_i_before_tag", "İ<tool>", 0, false, ""},
|
||||
{"normal_tool_calls", "<tool_calls>", 0, true, "tool_calls"},
|
||||
{"normal_invoke", "<invoke name=\"test\">", 0, true, "invoke"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, ok := FindToolMarkupTagOutsideIgnored(tt.text, tt.start)
|
||||
if ok != tt.wantOk {
|
||||
t.Errorf("FindToolMarkupTagOutsideIgnored(%q, %d) ok = %v, want %v", tt.text, tt.start, ok, tt.wantOk)
|
||||
return
|
||||
}
|
||||
if ok && got.Name != tt.wantName {
|
||||
t.Errorf("FindToolMarkupTagOutsideIgnored(%q, %d) name = %q, want %q", tt.text, tt.start, got.Name, tt.wantName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkipXMLIgnoredSectionBoundaryConditions(t *testing.T) {
|
||||
text := "hello"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
i int
|
||||
wantNext int
|
||||
wantAdv bool
|
||||
wantBlk bool
|
||||
}{
|
||||
{"valid_index", 2, 2, false, false},
|
||||
{"at_end_equal_len", 5, 5, false, false},
|
||||
{"beyond_end", 6, 6, false, false},
|
||||
{"negative", -1, -1, false, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
next, adv, blk := skipXMLIgnoredSection(text, tt.i)
|
||||
if next != tt.wantNext || adv != tt.wantAdv || blk != tt.wantBlk {
|
||||
t.Errorf("skipXMLIgnoredSection(%q, %d) = (%d, %v, %v), want (%d, %v, %v)",
|
||||
text, tt.i, next, adv, blk, tt.wantNext, tt.wantAdv, tt.wantBlk)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkipXMLIgnoredSectionCommentWithUnicodeKeepsByteOffset(t *testing.T) {
|
||||
text := "<!-- İ -->x<tool_calls>"
|
||||
|
||||
next, adv, blk := skipXMLIgnoredSection(text, 0)
|
||||
if blk || !adv {
|
||||
t.Fatalf("skipXMLIgnoredSection() = (%d, %v, %v), want advanced unblocked comment", next, adv, blk)
|
||||
}
|
||||
if want := len("<!-- İ -->"); next != want {
|
||||
t.Fatalf("skipXMLIgnoredSection() next = %d, want %d", next, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkipXMLIgnoredSectionMatchesCDATAWithoutAllocatingTail(t *testing.T) {
|
||||
text := "<![cDaTa[<tool_calls>]]><tool_calls>"
|
||||
|
||||
next, adv, blk := skipXMLIgnoredSection(text, 0)
|
||||
if blk || !adv {
|
||||
t.Fatalf("skipXMLIgnoredSection() = (%d, %v, %v), want advanced unblocked CDATA", next, adv, blk)
|
||||
}
|
||||
if want := len("<![cDaTa[<tool_calls>]]>"); next != want {
|
||||
t.Fatalf("skipXMLIgnoredSection() next = %d, want %d", next, want)
|
||||
}
|
||||
|
||||
tag, ok := FindToolMarkupTagOutsideIgnored(text, 0)
|
||||
if !ok {
|
||||
t.Fatal("expected tool tag after skipped CDATA")
|
||||
}
|
||||
if tag.Start != next {
|
||||
t.Fatalf("FindToolMarkupTagOutsideIgnored() start = %d, want %d", tag.Start, next)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindToolCDATAEndBoundaryConditions(t *testing.T) {
|
||||
text := "<![CDATA[hello]]>"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
from int
|
||||
wantResult int
|
||||
}{
|
||||
{"valid", 12, 14},
|
||||
{"at_end", 17, -1},
|
||||
{"beyond_end", 18, -1},
|
||||
{"negative", -1, -1},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := findToolCDATAEnd(text, tt.from)
|
||||
if got != tt.wantResult {
|
||||
t.Errorf("findToolCDATAEnd(%q, %d) = %d, want %d",
|
||||
text, tt.from, got, tt.wantResult)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindMatchingToolMarkupCloseBoundaryConditions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
text string
|
||||
open ToolMarkupTag
|
||||
wantOk bool
|
||||
}{
|
||||
{"empty_text", "", ToolMarkupTag{Name: "tool_calls", End: 0}, false},
|
||||
{"open_end_beyond_text", "hello", ToolMarkupTag{Name: "tool_calls", End: 100}, false},
|
||||
{"open_end_equals_len", "hello", ToolMarkupTag{Name: "tool_calls", End: 5}, false},
|
||||
{"valid_simple", "<tool_calls></tool_calls>", ToolMarkupTag{Name: "tool_calls", End: 11}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, ok := FindMatchingToolMarkupClose(tt.text, tt.open)
|
||||
if ok != tt.wantOk {
|
||||
t.Errorf("FindMatchingToolMarkupClose(%q, %+v) ok = %v, want %v", tt.text, tt.open, ok, tt.wantOk)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user