mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-02 07:25:26 +08:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1c749b6803 | ||
|
|
c329bf26b6 | ||
|
|
3ae5b57ebe | ||
|
|
0bf5d5440c | ||
|
|
d731a1fd4f | ||
|
|
93e9fb531d | ||
|
|
6daeb2553d | ||
|
|
321b8a89ee | ||
|
|
d84875e466 | ||
|
|
ea8c9a28a9 | ||
|
|
a302fb3c25 |
@@ -1,5 +1,5 @@
|
||||
<p align="center">
|
||||
<img src="assets/ds2api-icon.svg" width="128" height="128" alt="DS2API icon" />
|
||||
<img src="webui/public/ds2api-favicon.svg" width="128" height="128" alt="DS2API icon" />
|
||||
</p>
|
||||
|
||||
# DS2API
|
||||
@@ -10,6 +10,7 @@
|
||||
[](https://github.com/CJackHwang/ds2api/releases)
|
||||
[](DEPLOY.md)
|
||||
[](https://zeabur.com/templates/L4CFHP)
|
||||
[](https://vercel.com/new/clone?repository-url=https://github.com/CJackHwang/ds2api)
|
||||
|
||||
语言 / Language: [中文](README.MD) | [English](README.en.md)
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<p align="center">
|
||||
<img src="assets/ds2api-icon.svg" width="128" height="128" alt="DS2API icon" />
|
||||
<img src="webui/public/ds2api-favicon.svg" width="128" height="128" alt="DS2API icon" />
|
||||
</p>
|
||||
|
||||
# DS2API
|
||||
@@ -10,6 +10,7 @@
|
||||
[](https://github.com/CJackHwang/ds2api/releases)
|
||||
[](DEPLOY.en.md)
|
||||
[](https://zeabur.com/templates/L4CFHP)
|
||||
[](https://vercel.com/new/clone?repository-url=https://github.com/CJackHwang/ds2api)
|
||||
|
||||
Language: [中文](README.MD) | [English](README.en.md)
|
||||
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
<svg width="512" height="512" viewBox="0 0 512 512" fill="none" xmlns="http://www.w3.org/2000/svg" role="img" aria-label="DS2API icon">
|
||||
<defs>
|
||||
<linearGradient id="bg" x1="96" y1="96" x2="416" y2="416" gradientUnits="userSpaceOnUse">
|
||||
<stop offset="0" stop-color="#06162D" />
|
||||
<stop offset="0.6" stop-color="#0A3A6A" />
|
||||
<stop offset="1" stop-color="#00B4D8" />
|
||||
</linearGradient>
|
||||
<radialGradient id="glow" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(256 180) rotate(90) scale(260)">
|
||||
<stop offset="0" stop-color="#FFFFFF" stop-opacity="0.18" />
|
||||
<stop offset="1" stop-color="#FFFFFF" stop-opacity="0" />
|
||||
</radialGradient>
|
||||
<linearGradient id="whale" x1="180" y1="140" x2="360" y2="360" gradientUnits="userSpaceOnUse">
|
||||
<stop offset="0" stop-color="#EAF7FF" />
|
||||
<stop offset="1" stop-color="#BDEBFF" />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
|
||||
<circle cx="256" cy="256" r="240" fill="url(#bg)" />
|
||||
<circle cx="256" cy="256" r="240" fill="url(#glow)" />
|
||||
<circle cx="256" cy="256" r="240" stroke="#FFFFFF" stroke-opacity="0.14" stroke-width="8" />
|
||||
|
||||
<!-- subtle waves -->
|
||||
<path d="M104 338 C156 308 204 366 256 334 C308 302 356 360 408 330" stroke="#FFFFFF" stroke-opacity="0.16" stroke-width="12" stroke-linecap="round" />
|
||||
<path d="M124 372 C174 344 212 396 256 372 C300 348 338 396 388 368" stroke="#FFFFFF" stroke-opacity="0.12" stroke-width="10" stroke-linecap="round" />
|
||||
|
||||
<!-- whale tail (DeepSeek-inspired element, original design) -->
|
||||
<path
|
||||
d="M256 162
|
||||
C228 124 184 118 156 146
|
||||
C132 170 138 206 162 230
|
||||
C190 262 230 252 252 220
|
||||
C254 218 255 216 256 214
|
||||
C257 216 258 218 260 220
|
||||
C282 252 322 262 350 230
|
||||
C374 206 380 170 356 146
|
||||
C328 118 284 124 256 162 Z"
|
||||
fill="url(#whale)"
|
||||
/>
|
||||
<rect x="236" y="214" width="40" height="168" rx="20" fill="url(#whale)" />
|
||||
|
||||
<!-- API nodes -->
|
||||
<g opacity="0.55" stroke="#FFFFFF" stroke-opacity="0.35" stroke-width="6" stroke-linecap="round">
|
||||
<path d="M156 236 L208 206" />
|
||||
<path d="M356 236 L304 206" />
|
||||
<path d="M208 206 L232 172" />
|
||||
<circle cx="156" cy="236" r="10" fill="#FFFFFF" fill-opacity="0.28" />
|
||||
<circle cx="208" cy="206" r="10" fill="#FFFFFF" fill-opacity="0.28" />
|
||||
<circle cx="232" cy="172" r="10" fill="#FFFFFF" fill-opacity="0.28" />
|
||||
<circle cx="304" cy="206" r="10" fill="#FFFFFF" fill-opacity="0.28" />
|
||||
<circle cx="356" cy="236" r="10" fill="#FFFFFF" fill-opacity="0.28" />
|
||||
</g>
|
||||
|
||||
<!-- tiny sparkle -->
|
||||
<path
|
||||
d="M378 164
|
||||
C372 170 366 174 358 176
|
||||
C366 178 372 182 378 188
|
||||
C380 180 384 176 392 176
|
||||
C384 174 380 170 378 164 Z"
|
||||
fill="#FFFFFF"
|
||||
fill-opacity="0.32"
|
||||
/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 2.7 KiB |
@@ -1,6 +1,6 @@
|
||||
services:
|
||||
ds2api:
|
||||
image: crpi-cnazxqmg4avmg4fq.cn-beijing.personal.cr.aliyuncs.com/ronghuaxueleng/ds2api:latest
|
||||
services:
|
||||
ds2api:
|
||||
image: ghcr.io/cjackhwang/ds2api:latest
|
||||
container_name: ds2api
|
||||
restart: always
|
||||
ports:
|
||||
|
||||
@@ -98,7 +98,7 @@ func (s *chatStreamRuntime) sendDone() {
|
||||
func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := s.text.String()
|
||||
detected := util.ParseToolCalls(finalText, s.toolNames)
|
||||
detected := util.ParseStandaloneToolCalls(finalText, s.toolNames)
|
||||
if len(detected) > 0 && !s.toolCallsDoneEmitted {
|
||||
finishReason = "tool_calls"
|
||||
delta := map[string]any{
|
||||
|
||||
@@ -3,6 +3,7 @@ package openai
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -210,7 +211,7 @@ func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamEmbeddedToolCallExampleIntercepted(t *testing.T) {
|
||||
func TestHandleNonStreamEmbeddedToolCallExampleRemainsText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"下面是示例:"}`,
|
||||
@@ -228,16 +229,16 @@ func TestHandleNonStreamEmbeddedToolCallExampleIntercepted(t *testing.T) {
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
|
||||
if choice["finish_reason"] != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
toolCalls, _ := msg["tool_calls"].([]any)
|
||||
if len(toolCalls) == 0 {
|
||||
t.Fatalf("expected tool_calls field for embedded example: %#v", msg["tool_calls"])
|
||||
if _, ok := msg["tool_calls"]; ok {
|
||||
t.Fatalf("did not expect tool_calls field for embedded example: %#v", msg["tool_calls"])
|
||||
}
|
||||
if msg["content"] != nil {
|
||||
t.Fatalf("expected content nil when tool_calls detected, got %#v", msg["content"])
|
||||
content, _ := msg["content"].(string)
|
||||
if !strings.Contains(content, "下面是示例:") || !strings.Contains(content, "请勿执行。") || !strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected embedded example to remain plain text, got %#v", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -315,6 +316,36 @@ func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallLargeArgumentsStillIntercepted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
large := strings.Repeat("a", 9000)
|
||||
payload := fmt.Sprintf(`{"tool_calls":[{"name":"search","input":{"q":"%s"}}]}`, large)
|
||||
splitAt := len(payload) / 2
|
||||
resp := makeSSEHTTPResponse(
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, payload[:splitAt]),
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, payload[splitAt:]),
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid3-large", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
@@ -482,8 +513,8 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -500,15 +531,15 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") {
|
||||
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(got), `"tool_calls"`) {
|
||||
t.Fatalf("expected no raw tool_calls json leak in content, got=%q", got)
|
||||
if !strings.Contains(strings.ToLower(got), `"tool_calls"`) {
|
||||
t.Fatalf("expected embedded tool json to remain text in strict mode, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String())
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop for mixed prose, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) {
|
||||
func TestHandleStreamToolCallAfterLeadingTextRemainsText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"我将调用工具。"}`,
|
||||
@@ -524,8 +555,8 @@ func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -542,15 +573,15 @@ func TestHandleStreamToolCallAfterLeadingTextStillIntercepted(t *testing.T) {
|
||||
if !strings.Contains(got, "我将调用工具。") {
|
||||
t.Fatalf("expected leading text to keep streaming, got=%q", got)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("unexpected raw tool json leak, got=%q", got)
|
||||
if !strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected tool_calls example text preserved, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testing.T) {
|
||||
func TestHandleStreamToolCallWithSameChunkTrailingTextRemainsText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}接下来我会继续说明。"}`,
|
||||
@@ -565,8 +596,8 @@ func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testin
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -583,15 +614,15 @@ func TestHandleStreamToolCallWithSameChunkTrailingTextStillIntercepted(t *testin
|
||||
if !strings.Contains(got, "接下来我会继续说明。") {
|
||||
t.Fatalf("expected trailing plain text to be preserved, got=%q", got)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("unexpected raw tool json leak, got=%q", got)
|
||||
if !strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected tool_calls example text preserved, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) {
|
||||
func TestHandleStreamToolCallKeyAppearsLateRemainsText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
spaces := strings.Repeat(" ", 200)
|
||||
resp := makeSSEHTTPResponse(
|
||||
@@ -609,11 +640,8 @@ func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -627,14 +655,14 @@ func TestHandleStreamToolCallKeyAppearsLateStillNoPrefixLeak(t *testing.T) {
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if strings.Contains(got, "{") {
|
||||
t.Fatalf("unexpected suspicious prefix leak in content: %q", got)
|
||||
if !strings.Contains(strings.ToLower(got), "tool_calls") || !strings.Contains(got, "{") {
|
||||
t.Fatalf("expected embedded tool json to remain in text, got=%q", got)
|
||||
}
|
||||
if !strings.Contains(got, "后置正文C。") {
|
||||
t.Fatalf("expected stream to continue after tool json convergence, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -712,7 +740,7 @@ func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testin
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) {
|
||||
func TestHandleStreamToolCallArgumentsEmitAsSingleCompletedChunk(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`,
|
||||
@@ -735,8 +763,8 @@ func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
}
|
||||
argChunks := streamToolCallArgumentChunks(frames)
|
||||
if len(argChunks) < 2 {
|
||||
t.Fatalf("expected incremental arguments chunks, got=%v body=%s", argChunks, rec.Body.String())
|
||||
if len(argChunks) == 0 {
|
||||
t.Fatalf("expected tool call arguments chunk, got=%v body=%s", argChunks, rec.Body.String())
|
||||
}
|
||||
joined := strings.Join(argChunks, "")
|
||||
if !strings.Contains(joined, `"q":"golang"`) || !strings.Contains(joined, `"page":1`) {
|
||||
|
||||
@@ -3,7 +3,6 @@ package openai
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
@@ -175,30 +174,11 @@ func normalizeToolArgumentString(raw string) string {
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if !looksLikeConcatenatedJSON(trimmed) {
|
||||
return trimmed
|
||||
if looksLikeConcatenatedJSON(trimmed) {
|
||||
// Keep original payload to avoid silent argument rewrites.
|
||||
return raw
|
||||
}
|
||||
dec := json.NewDecoder(strings.NewReader(trimmed))
|
||||
values := make([]any, 0, 2)
|
||||
for {
|
||||
var v any
|
||||
if err := dec.Decode(&v); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
values = append(values, v)
|
||||
}
|
||||
if len(values) < 2 {
|
||||
return trimmed
|
||||
}
|
||||
last := values[len(values)-1]
|
||||
b, err := json.Marshal(last)
|
||||
if err != nil || len(b) == 0 {
|
||||
return trimmed
|
||||
}
|
||||
return string(b)
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func marshalToPromptString(v any) string {
|
||||
|
||||
@@ -168,7 +168,7 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSepara
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIMessagesForPrompt_RepairsConcatenatedToolArguments(t *testing.T) {
|
||||
func TestNormalizeOpenAIMessagesForPrompt_PreservesConcatenatedToolArguments(t *testing.T) {
|
||||
raw := []any{
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
@@ -189,10 +189,7 @@ func TestNormalizeOpenAIMessagesForPrompt_RepairsConcatenatedToolArguments(t *te
|
||||
t.Fatalf("expected one normalized message, got %d", len(normalized))
|
||||
}
|
||||
content, _ := normalized[0]["content"].(string)
|
||||
if !strings.Contains(content, `function.arguments: {"query":"测试工具调用"}`) {
|
||||
t.Fatalf("expected repaired arguments in tool history, got %q", content)
|
||||
}
|
||||
if strings.Contains(content, `{}{"query":"测试工具调用"}`) {
|
||||
t.Fatalf("expected concatenated JSON to be repaired, got %q", content)
|
||||
if !strings.Contains(content, `function.arguments: {}{"query":"测试工具调用"}`) {
|
||||
t.Fatalf("expected original concatenated arguments in tool history, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,7 +135,7 @@ func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesInputAsMessagesFunctionCallItemRepairsConcatenatedArguments(t *testing.T) {
|
||||
func TestNormalizeResponsesInputAsMessagesFunctionCallItemPreservesConcatenatedArguments(t *testing.T) {
|
||||
msgs := normalizeResponsesInputAsMessages([]any{
|
||||
map[string]any{
|
||||
"type": "function_call",
|
||||
@@ -151,8 +151,8 @@ func TestNormalizeResponsesInputAsMessagesFunctionCallItemRepairsConcatenatedArg
|
||||
toolCalls, _ := m["tool_calls"].([]any)
|
||||
call, _ := toolCalls[0].(map[string]any)
|
||||
fn, _ := call["function"].(map[string]any)
|
||||
if fn["arguments"] != `{"q":"golang"}` {
|
||||
t.Fatalf("expected concatenated call arguments repaired, got %#v", fn["arguments"])
|
||||
if fn["arguments"] != `{}{"q":"golang"}` {
|
||||
t.Fatalf("expected original concatenated call arguments preserved, got %#v", fn["arguments"])
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -113,15 +113,10 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
|
||||
return
|
||||
}
|
||||
result := sse.CollectStream(resp, thinkingEnabled, true)
|
||||
textParsed := util.ParseToolCallsDetailed(result.Text, toolNames)
|
||||
thinkingParsed := util.ParseToolCallsDetailed(result.Thinking, toolNames)
|
||||
textParsed := util.ParseStandaloneToolCallsDetailed(result.Text, toolNames)
|
||||
logResponsesToolPolicyRejection(traceID, toolChoice, textParsed, "text")
|
||||
logResponsesToolPolicyRejection(traceID, toolChoice, thinkingParsed, "thinking")
|
||||
|
||||
callCount := len(textParsed.Calls)
|
||||
if callCount == 0 {
|
||||
callCount = len(thinkingParsed.Calls)
|
||||
}
|
||||
if toolChoice.IsRequired() && callCount == 0 {
|
||||
writeOpenAIErrorWithCode(w, http.StatusUnprocessableEntity, "tool_choice requires at least one valid tool call.", "tool_choice_violation")
|
||||
return
|
||||
|
||||
@@ -102,16 +102,11 @@ func (s *responsesStreamRuntime) finalize() {
|
||||
|
||||
if s.bufferToolContent {
|
||||
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
|
||||
s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false)
|
||||
}
|
||||
|
||||
textParsed := util.ParseToolCallsDetailed(finalText, s.toolNames)
|
||||
thinkingParsed := util.ParseToolCallsDetailed(finalThinking, s.toolNames)
|
||||
textParsed := util.ParseStandaloneToolCallsDetailed(finalText, s.toolNames)
|
||||
detected := textParsed.Calls
|
||||
if len(detected) == 0 {
|
||||
detected = thinkingParsed.Calls
|
||||
}
|
||||
s.logToolPolicyRejections(textParsed, thinkingParsed)
|
||||
s.logToolPolicyRejections(textParsed)
|
||||
|
||||
if len(detected) > 0 {
|
||||
s.toolCallsEmitted = true
|
||||
@@ -157,7 +152,7 @@ func (s *responsesStreamRuntime) finalize() {
|
||||
s.sendDone()
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingParsed util.ToolCallParseResult) {
|
||||
func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed util.ToolCallParseResult) {
|
||||
logRejected := func(parsed util.ToolCallParseResult, channel string) {
|
||||
rejected := filteredRejectedToolNamesForLog(parsed.RejectedToolNames)
|
||||
if !parsed.RejectedByPolicy || len(rejected) == 0 {
|
||||
@@ -172,7 +167,6 @@ func (s *responsesStreamRuntime) logToolPolicyRejections(textParsed, thinkingPar
|
||||
)
|
||||
}
|
||||
logRejected(textParsed, "text")
|
||||
logRejected(thinkingParsed, "thinking")
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) hasFunctionCallDone() bool {
|
||||
@@ -207,9 +201,6 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
|
||||
}
|
||||
s.thinking.WriteString(p.Text)
|
||||
s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
|
||||
if s.bufferToolContent {
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -99,9 +99,6 @@ func TestHandleResponsesStreamUsesOfficialOutputItemEvents(t *testing.T) {
|
||||
if !strings.Contains(body, "event: response.output_item.done") {
|
||||
t.Fatalf("expected response.output_item.done event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.delta") {
|
||||
t.Fatalf("expected response.function_call_arguments.delta event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("expected response.function_call_arguments.done event, body=%s", body)
|
||||
}
|
||||
@@ -266,7 +263,7 @@ func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *testing.T) {
|
||||
func TestHandleResponsesStreamThinkingAndMixedToolExampleRemainMessageOnly(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -291,23 +288,12 @@ func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *tes
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
|
||||
addedPayloads := extractAllSSEEventPayloads(rec.Body.String(), "response.output_item.added")
|
||||
if len(addedPayloads) < 2 {
|
||||
t.Fatalf("expected message + function_call output_item.added events, got %d body=%s", len(addedPayloads), rec.Body.String())
|
||||
if len(addedPayloads) != 1 {
|
||||
t.Fatalf("expected only one message output_item.added event, got %d body=%s", len(addedPayloads), rec.Body.String())
|
||||
}
|
||||
|
||||
indexes := map[int]struct{}{}
|
||||
typeByIndex := map[int]string{}
|
||||
addedIDs := map[string]string{}
|
||||
for _, payload := range addedPayloads {
|
||||
item, _ := payload["item"].(map[string]any)
|
||||
itemType := strings.TrimSpace(asString(item["type"]))
|
||||
outputIndex := int(asFloat(payload["output_index"]))
|
||||
if _, exists := indexes[outputIndex]; exists {
|
||||
t.Fatalf("found duplicated output_index=%d for item types=%q and %q payload=%#v", outputIndex, typeByIndex[outputIndex], itemType, payload)
|
||||
}
|
||||
indexes[outputIndex] = struct{}{}
|
||||
typeByIndex[outputIndex] = itemType
|
||||
addedIDs[itemType] = strings.TrimSpace(asString(payload["item_id"]))
|
||||
item, _ := addedPayloads[0]["item"].(map[string]any)
|
||||
if asString(item["type"]) != "message" {
|
||||
t.Fatalf("expected only message output item in strict mode, got %#v", item)
|
||||
}
|
||||
|
||||
completedPayload, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
|
||||
@@ -316,21 +302,15 @@ func TestHandleResponsesStreamThinkingTextAndToolUseDistinctOutputIndexes(t *tes
|
||||
}
|
||||
responseObj, _ := completedPayload["response"].(map[string]any)
|
||||
output, _ := responseObj["output"].([]any)
|
||||
found := map[string]bool{}
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
itemType := strings.TrimSpace(asString(m["type"]))
|
||||
itemID := strings.TrimSpace(asString(m["id"]))
|
||||
if itemType == "" || itemID == "" {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if wantID := strings.TrimSpace(addedIDs[itemType]); wantID != "" && wantID == itemID {
|
||||
found[itemType] = true
|
||||
if asString(m["type"]) == "function_call" {
|
||||
t.Fatalf("did not expect function_call output for mixed prose tool example, output=%#v", output)
|
||||
}
|
||||
}
|
||||
if !found["message"] || !found["function_call"] {
|
||||
t.Fatalf("expected completed output to contain streamed message/function_call item ids, found=%#v output=%#v", found, output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
|
||||
@@ -360,7 +340,7 @@ func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *testing.T) {
|
||||
func TestHandleResponsesStreamMalformedToolJSONFallsBackToText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -373,7 +353,7 @@ func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *t
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
// invalid JSON (NaN) can still trigger incremental tool deltas before final parse rejects it
|
||||
// invalid JSON (NaN) should remain plain text in strict mode.
|
||||
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
@@ -382,14 +362,11 @@ func TestHandleResponsesStreamMalformedToolJSONClosesInProgressFunctionItem(t *t
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.delta") {
|
||||
t.Fatalf("expected response.function_call_arguments.delta event for malformed payload, body=%s", body)
|
||||
if strings.Contains(body, "event: response.function_call_arguments.delta") || strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("did not expect function_call events for malformed payload in strict mode, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.function_call_arguments.done") {
|
||||
t.Fatalf("expected runtime to close in-progress function_call with done event, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.output_item.done") {
|
||||
t.Fatalf("expected runtime to close function output item, body=%s", body)
|
||||
if !strings.Contains(body, "event: response.output_text.delta") {
|
||||
t.Fatalf("expected response.output_text.delta for malformed payload, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "event: response.completed") {
|
||||
t.Fatalf("expected response.completed event, body=%s", body)
|
||||
@@ -430,6 +407,42 @@ func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamRequiredToolChoiceIgnoresThinkingToolPayload(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(path, value string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": path,
|
||||
"v": value,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) +
|
||||
sseLine("response/content", "plain text only") +
|
||||
"data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
policy := util.ToolChoicePolicy{
|
||||
Mode: util.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, false, []string{"read_file"}, policy, "")
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "event: response.failed") {
|
||||
t.Fatalf("expected response.failed event for required tool_choice violation, body=%s", body)
|
||||
}
|
||||
if strings.Contains(body, "event: response.completed") {
|
||||
t.Fatalf("did not expect response.completed after failure, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
@@ -516,6 +529,33 @@ func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayload(t *testing.T) {
|
||||
h := &Handler{}
|
||||
rec := httptest.NewRecorder()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`data: {"p":"response/thinking_content","v":"{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}"}` + "\n" +
|
||||
`data: {"p":"response/content","v":"plain text only"}` + "\n" +
|
||||
`data: [DONE]` + "\n",
|
||||
)),
|
||||
}
|
||||
policy := util.ToolChoicePolicy{
|
||||
Mode: util.ToolChoiceRequired,
|
||||
Allowed: map[string]struct{}{"read_file": {}},
|
||||
}
|
||||
|
||||
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, []string{"read_file"}, policy, "")
|
||||
if rec.Code != http.StatusUnprocessableEntity {
|
||||
t.Fatalf("expected 422 for required tool_choice violation, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
errObj, _ := out["error"].(map[string]any)
|
||||
if asString(errObj["code"]) != "tool_choice_violation" {
|
||||
t.Fatalf("expected code=tool_choice_violation, got %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
|
||||
h := &Handler{}
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
@@ -167,19 +167,15 @@ func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) {
|
||||
t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String())
|
||||
}
|
||||
outputText, _ := out["output_text"].(string)
|
||||
if outputText != "" {
|
||||
t.Fatalf("expected output_text hidden for tool call payload, got %q", outputText)
|
||||
if outputText == "" {
|
||||
t.Fatalf("expected output_text preserved for mixed prose payload")
|
||||
}
|
||||
output, _ := out["output"].([]any)
|
||||
hasFunctionCall := false
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m != nil && m["type"] == "function_call" {
|
||||
hasFunctionCall = true
|
||||
break
|
||||
}
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected one output item, got %#v", output)
|
||||
}
|
||||
if !hasFunctionCall {
|
||||
t.Fatalf("expected function_call output item, got %#v", output)
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "message" {
|
||||
t.Fatalf("expected message output item, got %#v", output)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,21 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
state.pending.WriteString(chunk)
|
||||
}
|
||||
events := make([]toolStreamEvent, 0, 2)
|
||||
if len(state.pendingToolCalls) > 0 {
|
||||
pending := state.pending.String()
|
||||
if strings.TrimSpace(pending) != "" {
|
||||
content := state.pendingToolRaw + pending
|
||||
state.pending.Reset()
|
||||
state.pendingToolRaw = ""
|
||||
state.pendingToolCalls = nil
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
} else {
|
||||
// Wait for either more non-whitespace content (demote to plain text)
|
||||
// or stream flush (promote to executable tool calls).
|
||||
return events
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
if state.capturing {
|
||||
@@ -21,32 +36,23 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
state.capture.WriteString(state.pending.String())
|
||||
state.pending.Reset()
|
||||
}
|
||||
if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 {
|
||||
events = append(events, toolStreamEvent{ToolCallDeltas: deltas})
|
||||
}
|
||||
prefix, calls, suffix, ready := consumeToolCapture(state, toolNames)
|
||||
if !ready {
|
||||
if state.capture.Len() > toolSieveCaptureLimit {
|
||||
content := state.capture.String()
|
||||
state.capture.Reset()
|
||||
state.capturing = false
|
||||
state.resetIncrementalToolState()
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
captured := state.capture.String()
|
||||
state.capture.Reset()
|
||||
state.capturing = false
|
||||
state.resetIncrementalToolState()
|
||||
if len(calls) > 0 {
|
||||
state.pendingToolRaw = captured
|
||||
state.pendingToolCalls = calls
|
||||
continue
|
||||
}
|
||||
if prefix != "" {
|
||||
state.noteText(prefix)
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
if len(calls) > 0 {
|
||||
events = append(events, toolStreamEvent{ToolCalls: calls})
|
||||
}
|
||||
if suffix != "" {
|
||||
state.pending.WriteString(suffix)
|
||||
}
|
||||
@@ -89,6 +95,11 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea
|
||||
return nil
|
||||
}
|
||||
events := processToolSieveChunk(state, "", toolNames)
|
||||
if len(state.pendingToolCalls) > 0 {
|
||||
events = append(events, toolStreamEvent{ToolCalls: state.pendingToolCalls})
|
||||
state.pendingToolRaw = ""
|
||||
state.pendingToolCalls = nil
|
||||
}
|
||||
if state.capturing {
|
||||
consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames)
|
||||
if ready {
|
||||
@@ -200,6 +211,11 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
|
||||
if insideCodeFence(state.recentTextTail + prefixPart) {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
// Strict mode: only standalone tool payloads are executable. If the
|
||||
// payload is wrapped by non-whitespace prose, keep it as plain text.
|
||||
if strings.TrimSpace(state.recentTextTail) != "" || strings.TrimSpace(prefixPart) != "" || strings.TrimSpace(suffixPart) != "" {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
parsed := util.ParseStandaloneToolCallsDetailed(obj, toolNames)
|
||||
if len(parsed.Calls) == 0 {
|
||||
if parsed.SawToolCallSyntax && parsed.RejectedByPolicy {
|
||||
|
||||
@@ -7,17 +7,19 @@ import (
|
||||
)
|
||||
|
||||
type toolStreamSieveState struct {
|
||||
pending strings.Builder
|
||||
capture strings.Builder
|
||||
capturing bool
|
||||
recentTextTail string
|
||||
disableDeltas bool
|
||||
toolNameSent bool
|
||||
toolName string
|
||||
toolArgsStart int
|
||||
toolArgsSent int
|
||||
toolArgsString bool
|
||||
toolArgsDone bool
|
||||
pending strings.Builder
|
||||
capture strings.Builder
|
||||
capturing bool
|
||||
recentTextTail string
|
||||
pendingToolRaw string
|
||||
pendingToolCalls []util.ParsedToolCall
|
||||
disableDeltas bool
|
||||
toolNameSent bool
|
||||
toolName string
|
||||
toolArgsStart int
|
||||
toolArgsSent int
|
||||
toolArgsString bool
|
||||
toolArgsDone bool
|
||||
}
|
||||
|
||||
type toolStreamEvent struct {
|
||||
@@ -32,7 +34,6 @@ type toolCallDelta struct {
|
||||
Arguments string
|
||||
}
|
||||
|
||||
const toolSieveCaptureLimit = 8 * 1024
|
||||
const toolSieveContextTailLimit = 256
|
||||
|
||||
func (s *toolStreamSieveState) resetIncrementalToolState() {
|
||||
|
||||
@@ -1,128 +1,133 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
page := intFromQuery(r, "page", 1)
|
||||
pageSize := intFromQuery(r, "page_size", 10)
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 {
|
||||
pageSize = 1
|
||||
}
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
accounts := h.Store.Snapshot().Accounts
|
||||
reverseAccounts(accounts)
|
||||
q := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("q")))
|
||||
if q != "" {
|
||||
filtered := make([]config.Account, 0, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
id := strings.ToLower(acc.Identifier())
|
||||
if strings.Contains(id, q) ||
|
||||
strings.Contains(strings.ToLower(acc.Email), q) ||
|
||||
strings.Contains(strings.ToLower(acc.Mobile), q) {
|
||||
filtered = append(filtered, acc)
|
||||
}
|
||||
}
|
||||
accounts = filtered
|
||||
}
|
||||
total := len(accounts)
|
||||
totalPages := 1
|
||||
if total > 0 {
|
||||
totalPages = (total + pageSize - 1) / pageSize
|
||||
}
|
||||
start := (page - 1) * pageSize
|
||||
if start > total {
|
||||
start = total
|
||||
}
|
||||
end := start + pageSize
|
||||
if end > total {
|
||||
end = total
|
||||
}
|
||||
items := make([]map[string]any, 0, end-start)
|
||||
for _, acc := range accounts[start:end] {
|
||||
token := strings.TrimSpace(acc.Token)
|
||||
preview := ""
|
||||
if token != "" {
|
||||
if len(token) > 20 {
|
||||
preview = token[:20] + "..."
|
||||
} else {
|
||||
preview = token
|
||||
}
|
||||
}
|
||||
items = append(items, map[string]any{
|
||||
"identifier": acc.Identifier(),
|
||||
"email": acc.Email,
|
||||
"mobile": acc.Mobile,
|
||||
"has_password": acc.Password != "",
|
||||
"has_token": token != "",
|
||||
"token_preview": preview,
|
||||
"test_status": acc.TestStatus,
|
||||
})
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages})
|
||||
}
|
||||
|
||||
func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
acc := toAccount(req)
|
||||
if acc.Identifier() == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"})
|
||||
return
|
||||
}
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
for _, a := range c.Accounts {
|
||||
if acc.Email != "" && a.Email == acc.Email {
|
||||
return fmt.Errorf("邮箱已存在")
|
||||
}
|
||||
if acc.Mobile != "" && a.Mobile == acc.Mobile {
|
||||
return fmt.Errorf("手机号已存在")
|
||||
}
|
||||
}
|
||||
c.Accounts = append(c.Accounts, acc)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
h.Pool.Reset()
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
|
||||
}
|
||||
|
||||
func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
identifier := chi.URLParam(r, "identifier")
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
idx := -1
|
||||
for i, a := range c.Accounts {
|
||||
if accountMatchesIdentifier(a, identifier) {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("账号不存在")
|
||||
}
|
||||
c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
h.Pool.Reset()
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
|
||||
}
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
page := intFromQuery(r, "page", 1)
|
||||
pageSize := intFromQuery(r, "page_size", 10)
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 {
|
||||
pageSize = 1
|
||||
}
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
accounts := h.Store.Snapshot().Accounts
|
||||
reverseAccounts(accounts)
|
||||
q := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("q")))
|
||||
if q != "" {
|
||||
filtered := make([]config.Account, 0, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
id := strings.ToLower(acc.Identifier())
|
||||
if strings.Contains(id, q) ||
|
||||
strings.Contains(strings.ToLower(acc.Email), q) ||
|
||||
strings.Contains(strings.ToLower(acc.Mobile), q) {
|
||||
filtered = append(filtered, acc)
|
||||
}
|
||||
}
|
||||
accounts = filtered
|
||||
}
|
||||
total := len(accounts)
|
||||
totalPages := 1
|
||||
if total > 0 {
|
||||
totalPages = (total + pageSize - 1) / pageSize
|
||||
}
|
||||
start := (page - 1) * pageSize
|
||||
if start > total {
|
||||
start = total
|
||||
}
|
||||
end := start + pageSize
|
||||
if end > total {
|
||||
end = total
|
||||
}
|
||||
items := make([]map[string]any, 0, end-start)
|
||||
for _, acc := range accounts[start:end] {
|
||||
token := strings.TrimSpace(acc.Token)
|
||||
preview := ""
|
||||
if token != "" {
|
||||
if len(token) > 20 {
|
||||
preview = token[:20] + "..."
|
||||
} else {
|
||||
preview = token
|
||||
}
|
||||
}
|
||||
items = append(items, map[string]any{
|
||||
"identifier": acc.Identifier(),
|
||||
"email": acc.Email,
|
||||
"mobile": acc.Mobile,
|
||||
"has_password": acc.Password != "",
|
||||
"has_token": token != "",
|
||||
"token_preview": preview,
|
||||
"test_status": acc.TestStatus,
|
||||
})
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages})
|
||||
}
|
||||
|
||||
func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
acc := toAccount(req)
|
||||
if acc.Identifier() == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"})
|
||||
return
|
||||
}
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
mobileKey := config.CanonicalMobileKey(acc.Mobile)
|
||||
for _, a := range c.Accounts {
|
||||
if acc.Email != "" && a.Email == acc.Email {
|
||||
return fmt.Errorf("邮箱已存在")
|
||||
}
|
||||
if mobileKey != "" && config.CanonicalMobileKey(a.Mobile) == mobileKey {
|
||||
return fmt.Errorf("手机号已存在")
|
||||
}
|
||||
}
|
||||
c.Accounts = append(c.Accounts, acc)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
h.Pool.Reset()
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
|
||||
}
|
||||
|
||||
func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
identifier := chi.URLParam(r, "identifier")
|
||||
if decoded, err := url.PathUnescape(identifier); err == nil {
|
||||
identifier = decoded
|
||||
}
|
||||
err := h.Store.Update(func(c *config.Config) error {
|
||||
idx := -1
|
||||
for i, a := range c.Accounts {
|
||||
if accountMatchesIdentifier(a, identifier) {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("账号不存在")
|
||||
}
|
||||
c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
|
||||
return
|
||||
}
|
||||
h.Pool.Reset()
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -102,6 +103,45 @@ func TestDeleteAccountSupportsMobileAlias(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAccountSupportsEncodedPlusMobile(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"accounts":[{"mobile":"+8613800138000","password":"pwd"}]
|
||||
}`)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Delete("/admin/accounts/{identifier}", h.deleteAccount)
|
||||
req := httptest.NewRequest(http.MethodDelete, "/admin/accounts/"+url.PathEscape("+8613800138000"), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if got := len(h.Store.Accounts()); got != 0 {
|
||||
t.Fatalf("expected account removed, remaining=%d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddAccountRejectsCanonicalMobileDuplicate(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"accounts":[{"mobile":"+8613800138000","password":"pwd"}]
|
||||
}`)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Post("/admin/accounts", h.addAccount)
|
||||
body := []byte(`{"mobile":"13800138000","password":"pwd2"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/admin/accounts", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if got := len(h.Store.Accounts()); got != 1 {
|
||||
t.Fatalf("expected no duplicate insert, got=%d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"accounts":[
|
||||
@@ -117,6 +157,13 @@ func TestFindAccountByIdentifierSupportsMobileAndTokenOnly(t *testing.T) {
|
||||
if accByMobile.Email != "u@example.com" {
|
||||
t.Fatalf("unexpected account by mobile: %#v", accByMobile)
|
||||
}
|
||||
accByMobileWithCountryCode, ok := findAccountByIdentifier(h.Store, "+8613800138000")
|
||||
if !ok {
|
||||
t.Fatal("expected find by +86 mobile")
|
||||
}
|
||||
if accByMobileWithCountryCode.Email != "u@example.com" {
|
||||
t.Fatalf("unexpected account by +86 mobile: %#v", accByMobileWithCountryCode)
|
||||
}
|
||||
|
||||
tokenOnlyID := ""
|
||||
for _, acc := range h.Store.Accounts() {
|
||||
|
||||
@@ -1,209 +1,212 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
authn "ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/sse"
|
||||
)
|
||||
|
||||
func (h *Handler) testSingleAccount(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
identifier, _ := req["identifier"].(string)
|
||||
if strings.TrimSpace(identifier) == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(identifier / email / mobile)"})
|
||||
return
|
||||
}
|
||||
acc, ok := findAccountByIdentifier(h.Store, identifier)
|
||||
if !ok {
|
||||
writeJSON(w, http.StatusNotFound, map[string]any{"detail": "账号不存在"})
|
||||
return
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
if model == "" {
|
||||
model = "deepseek-chat"
|
||||
}
|
||||
message, _ := req["message"].(string)
|
||||
result := h.testAccount(r.Context(), acc, model, message)
|
||||
writeJSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (h *Handler) testAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
model, _ := req["model"].(string)
|
||||
if model == "" {
|
||||
model = "deepseek-chat"
|
||||
}
|
||||
accounts := h.Store.Snapshot().Accounts
|
||||
if len(accounts) == 0 {
|
||||
writeJSON(w, http.StatusOK, map[string]any{"total": 0, "success": 0, "failed": 0, "results": []any{}})
|
||||
return
|
||||
}
|
||||
|
||||
// Concurrent testing with a semaphore to limit parallelism.
|
||||
const maxConcurrency = 5
|
||||
results := runAccountTestsConcurrently(accounts, maxConcurrency, func(_ int, account config.Account) map[string]any {
|
||||
return h.testAccount(r.Context(), account, model, "")
|
||||
})
|
||||
|
||||
success := 0
|
||||
for _, res := range results {
|
||||
if ok, _ := res["success"].(bool); ok {
|
||||
success++
|
||||
}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"total": len(accounts), "success": success, "failed": len(accounts) - success, "results": results})
|
||||
}
|
||||
|
||||
func runAccountTestsConcurrently(accounts []config.Account, maxConcurrency int, testFn func(int, config.Account) map[string]any) []map[string]any {
|
||||
if maxConcurrency <= 0 {
|
||||
maxConcurrency = 1
|
||||
}
|
||||
sem := make(chan struct{}, maxConcurrency)
|
||||
results := make([]map[string]any, len(accounts))
|
||||
var wg sync.WaitGroup
|
||||
for i, acc := range accounts {
|
||||
wg.Add(1)
|
||||
go func(idx int, account config.Account) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{} // acquire
|
||||
defer func() { <-sem }() // release
|
||||
results[idx] = testFn(idx, account)
|
||||
}(i, acc)
|
||||
}
|
||||
wg.Wait()
|
||||
return results
|
||||
}
|
||||
|
||||
func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, message string) map[string]any {
|
||||
start := time.Now()
|
||||
identifier := acc.Identifier()
|
||||
result := map[string]any{"account": identifier, "success": false, "response_time": 0, "message": "", "model": model}
|
||||
defer func() {
|
||||
status := "failed"
|
||||
if ok, _ := result["success"].(bool); ok {
|
||||
status = "ok"
|
||||
}
|
||||
_ = h.Store.UpdateAccountTestStatus(identifier, status)
|
||||
}()
|
||||
token := strings.TrimSpace(acc.Token)
|
||||
if token == "" {
|
||||
newToken, err := h.DS.Login(ctx, acc)
|
||||
if err != nil {
|
||||
result["message"] = "登录失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
token = newToken
|
||||
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
||||
}
|
||||
authCtx := &authn.RequestAuth{UseConfigToken: false, DeepSeekToken: token}
|
||||
sessionID, err := h.DS.CreateSession(ctx, authCtx, 1)
|
||||
if err != nil {
|
||||
newToken, loginErr := h.DS.Login(ctx, acc)
|
||||
if loginErr != nil {
|
||||
result["message"] = "创建会话失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
token = newToken
|
||||
authCtx.DeepSeekToken = token
|
||||
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
||||
sessionID, err = h.DS.CreateSession(ctx, authCtx, 1)
|
||||
if err != nil {
|
||||
result["message"] = "创建会话失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(message) == "" {
|
||||
message = "你是谁?"
|
||||
}
|
||||
thinking, search, ok := config.GetModelConfig(model)
|
||||
if !ok {
|
||||
thinking, search = false, false
|
||||
}
|
||||
_ = search
|
||||
pow, err := h.DS.GetPow(ctx, authCtx, 1)
|
||||
if err != nil {
|
||||
result["message"] = "获取 PoW 失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
payload := map[string]any{"chat_session_id": sessionID, "prompt": "<|User|>" + message, "ref_file_ids": []any{}, "thinking_enabled": thinking, "search_enabled": search}
|
||||
resp, err := h.DS.CallCompletion(ctx, authCtx, payload, pow, 1)
|
||||
if err != nil {
|
||||
result["message"] = "请求失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
result["message"] = fmt.Sprintf("请求失败: HTTP %d", resp.StatusCode)
|
||||
return result
|
||||
}
|
||||
collected := sse.CollectStream(resp, thinking, true)
|
||||
result["success"] = true
|
||||
result["response_time"] = int(time.Since(start).Milliseconds())
|
||||
if collected.Text != "" {
|
||||
result["message"] = collected.Text
|
||||
} else {
|
||||
result["message"] = "(无回复内容)"
|
||||
}
|
||||
if collected.Thinking != "" {
|
||||
result["thinking"] = collected.Thinking
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (h *Handler) testAPI(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
model, _ := req["model"].(string)
|
||||
message, _ := req["message"].(string)
|
||||
apiKey, _ := req["api_key"].(string)
|
||||
if model == "" {
|
||||
model = "deepseek-chat"
|
||||
}
|
||||
if message == "" {
|
||||
message = "你好"
|
||||
}
|
||||
if apiKey == "" {
|
||||
keys := h.Store.Snapshot().Keys
|
||||
if len(keys) == 0 {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "没有可用的 API Key"})
|
||||
return
|
||||
}
|
||||
apiKey = keys[0]
|
||||
}
|
||||
host := r.Host
|
||||
scheme := "http"
|
||||
if strings.Contains(strings.ToLower(host), "vercel") || strings.Contains(strings.ToLower(r.Header.Get("X-Forwarded-Proto")), "https") {
|
||||
scheme = "https"
|
||||
}
|
||||
payload := map[string]any{"model": model, "messages": []map[string]any{{"role": "user", "content": message}}, "stream": false}
|
||||
b, _ := json.Marshal(payload)
|
||||
request, _ := http.NewRequestWithContext(r.Context(), http.MethodPost, fmt.Sprintf("%s://%s/v1/chat/completions", scheme, host), bytes.NewReader(b))
|
||||
request.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
resp, err := (&http.Client{Timeout: 60 * time.Second}).Do(request)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": false, "error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
var parsed any
|
||||
_ = json.Unmarshal(body, &parsed)
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "status_code": resp.StatusCode, "response": parsed})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": false, "status_code": resp.StatusCode, "response": string(body)})
|
||||
}
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
authn "ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/sse"
|
||||
)
|
||||
|
||||
func (h *Handler) testSingleAccount(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
identifier, _ := req["identifier"].(string)
|
||||
if strings.TrimSpace(identifier) == "" {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要账号标识(identifier / email / mobile)"})
|
||||
return
|
||||
}
|
||||
acc, ok := findAccountByIdentifier(h.Store, identifier)
|
||||
if !ok {
|
||||
writeJSON(w, http.StatusNotFound, map[string]any{"detail": "账号不存在"})
|
||||
return
|
||||
}
|
||||
model, _ := req["model"].(string)
|
||||
if model == "" {
|
||||
model = "deepseek-chat"
|
||||
}
|
||||
message, _ := req["message"].(string)
|
||||
result := h.testAccount(r.Context(), acc, model, message)
|
||||
writeJSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (h *Handler) testAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
model, _ := req["model"].(string)
|
||||
if model == "" {
|
||||
model = "deepseek-chat"
|
||||
}
|
||||
accounts := h.Store.Snapshot().Accounts
|
||||
if len(accounts) == 0 {
|
||||
writeJSON(w, http.StatusOK, map[string]any{"total": 0, "success": 0, "failed": 0, "results": []any{}})
|
||||
return
|
||||
}
|
||||
|
||||
// Concurrent testing with a semaphore to limit parallelism.
|
||||
const maxConcurrency = 5
|
||||
results := runAccountTestsConcurrently(accounts, maxConcurrency, func(_ int, account config.Account) map[string]any {
|
||||
return h.testAccount(r.Context(), account, model, "")
|
||||
})
|
||||
|
||||
success := 0
|
||||
for _, res := range results {
|
||||
if ok, _ := res["success"].(bool); ok {
|
||||
success++
|
||||
}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"total": len(accounts), "success": success, "failed": len(accounts) - success, "results": results})
|
||||
}
|
||||
|
||||
func runAccountTestsConcurrently(accounts []config.Account, maxConcurrency int, testFn func(int, config.Account) map[string]any) []map[string]any {
|
||||
if maxConcurrency <= 0 {
|
||||
maxConcurrency = 1
|
||||
}
|
||||
sem := make(chan struct{}, maxConcurrency)
|
||||
results := make([]map[string]any, len(accounts))
|
||||
var wg sync.WaitGroup
|
||||
for i, acc := range accounts {
|
||||
wg.Add(1)
|
||||
go func(idx int, account config.Account) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{} // acquire
|
||||
defer func() { <-sem }() // release
|
||||
results[idx] = testFn(idx, account)
|
||||
}(i, acc)
|
||||
}
|
||||
wg.Wait()
|
||||
return results
|
||||
}
|
||||
|
||||
func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, message string) map[string]any {
|
||||
start := time.Now()
|
||||
identifier := acc.Identifier()
|
||||
result := map[string]any{"account": identifier, "success": false, "response_time": 0, "message": "", "model": model}
|
||||
defer func() {
|
||||
status := "failed"
|
||||
if ok, _ := result["success"].(bool); ok {
|
||||
status = "ok"
|
||||
}
|
||||
_ = h.Store.UpdateAccountTestStatus(identifier, status)
|
||||
}()
|
||||
token := strings.TrimSpace(acc.Token)
|
||||
if token == "" {
|
||||
newToken, err := h.DS.Login(ctx, acc)
|
||||
if err != nil {
|
||||
result["message"] = "登录失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
token = newToken
|
||||
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
||||
}
|
||||
authCtx := &authn.RequestAuth{UseConfigToken: false, DeepSeekToken: token}
|
||||
sessionID, err := h.DS.CreateSession(ctx, authCtx, 1)
|
||||
if err != nil {
|
||||
newToken, loginErr := h.DS.Login(ctx, acc)
|
||||
if loginErr != nil {
|
||||
result["message"] = "创建会话失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
token = newToken
|
||||
authCtx.DeepSeekToken = token
|
||||
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
|
||||
sessionID, err = h.DS.CreateSession(ctx, authCtx, 1)
|
||||
if err != nil {
|
||||
result["message"] = "创建会话失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(message) == "" {
|
||||
result["success"] = true
|
||||
result["message"] = "API 测试成功(仅会话创建)"
|
||||
result["response_time"] = int(time.Since(start).Milliseconds())
|
||||
return result
|
||||
}
|
||||
thinking, search, ok := config.GetModelConfig(model)
|
||||
if !ok {
|
||||
thinking, search = false, false
|
||||
}
|
||||
_ = search
|
||||
pow, err := h.DS.GetPow(ctx, authCtx, 1)
|
||||
if err != nil {
|
||||
result["message"] = "获取 PoW 失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
payload := map[string]any{"chat_session_id": sessionID, "prompt": "<|User|>" + message, "ref_file_ids": []any{}, "thinking_enabled": thinking, "search_enabled": search}
|
||||
resp, err := h.DS.CallCompletion(ctx, authCtx, payload, pow, 1)
|
||||
if err != nil {
|
||||
result["message"] = "请求失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
result["message"] = fmt.Sprintf("请求失败: HTTP %d", resp.StatusCode)
|
||||
return result
|
||||
}
|
||||
collected := sse.CollectStream(resp, thinking, true)
|
||||
result["success"] = true
|
||||
result["response_time"] = int(time.Since(start).Milliseconds())
|
||||
if collected.Text != "" {
|
||||
result["message"] = collected.Text
|
||||
} else {
|
||||
result["message"] = "(无回复内容)"
|
||||
}
|
||||
if collected.Thinking != "" {
|
||||
result["thinking"] = collected.Thinking
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (h *Handler) testAPI(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
model, _ := req["model"].(string)
|
||||
message, _ := req["message"].(string)
|
||||
apiKey, _ := req["api_key"].(string)
|
||||
if model == "" {
|
||||
model = "deepseek-chat"
|
||||
}
|
||||
if message == "" {
|
||||
message = "你好"
|
||||
}
|
||||
if apiKey == "" {
|
||||
keys := h.Store.Snapshot().Keys
|
||||
if len(keys) == 0 {
|
||||
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "没有可用的 API Key"})
|
||||
return
|
||||
}
|
||||
apiKey = keys[0]
|
||||
}
|
||||
host := r.Host
|
||||
scheme := "http"
|
||||
if strings.Contains(strings.ToLower(host), "vercel") || strings.Contains(strings.ToLower(r.Header.Get("X-Forwarded-Proto")), "https") {
|
||||
scheme = "https"
|
||||
}
|
||||
payload := map[string]any{"model": model, "messages": []map[string]any{{"role": "user", "content": message}}, "stream": false}
|
||||
b, _ := json.Marshal(payload)
|
||||
request, _ := http.NewRequestWithContext(r.Context(), http.MethodPost, fmt.Sprintf("%s://%s/v1/chat/completions", scheme, host), bytes.NewReader(b))
|
||||
request.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
resp, err := (&http.Client{Timeout: 60 * time.Second}).Do(request)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": false, "error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
var parsed any
|
||||
_ = json.Unmarshal(body, &parsed)
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true, "status_code": resp.StatusCode, "response": parsed})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": false, "status_code": resp.StatusCode, "response": string(body)})
|
||||
}
|
||||
|
||||
76
internal/admin/handler_accounts_testing_test.go
Normal file
76
internal/admin/handler_accounts_testing_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
)
|
||||
|
||||
type testingDSMock struct {
|
||||
loginCalls int
|
||||
createSessionCalls int
|
||||
getPowCalls int
|
||||
callCompletionCalls int
|
||||
}
|
||||
|
||||
func (m *testingDSMock) Login(_ context.Context, _ config.Account) (string, error) {
|
||||
m.loginCalls++
|
||||
return "new-token", nil
|
||||
}
|
||||
|
||||
func (m *testingDSMock) CreateSession(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
|
||||
m.createSessionCalls++
|
||||
return "session-id", nil
|
||||
}
|
||||
|
||||
func (m *testingDSMock) GetPow(_ context.Context, _ *auth.RequestAuth, _ int) (string, error) {
|
||||
m.getPowCalls++
|
||||
return "", errors.New("should not call GetPow in this test")
|
||||
}
|
||||
|
||||
func (m *testingDSMock) CallCompletion(_ context.Context, _ *auth.RequestAuth, _ map[string]any, _ string, _ int) (*http.Response, error) {
|
||||
m.callCompletionCalls++
|
||||
return nil, errors.New("should not call CallCompletion in this test")
|
||||
}
|
||||
|
||||
func TestTestAccount_BatchModeOnlyCreatesSession(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"accounts":[{"email":"batch@example.com","password":"pwd","token":""}]}`)
|
||||
store := config.LoadStore()
|
||||
ds := &testingDSMock{}
|
||||
h := &Handler{Store: store, DS: ds}
|
||||
acc, ok := store.FindAccount("batch@example.com")
|
||||
if !ok {
|
||||
t.Fatal("expected test account")
|
||||
}
|
||||
|
||||
result := h.testAccount(context.Background(), acc, "deepseek-chat", "")
|
||||
|
||||
if ok, _ := result["success"].(bool); !ok {
|
||||
t.Fatalf("expected success=true, got %#v", result)
|
||||
}
|
||||
msg, _ := result["message"].(string)
|
||||
if !strings.Contains(msg, "仅会话创建") {
|
||||
t.Fatalf("expected session-only success message, got %q", msg)
|
||||
}
|
||||
if ds.loginCalls != 1 || ds.createSessionCalls != 1 {
|
||||
t.Fatalf("unexpected Login/CreateSession calls: login=%d createSession=%d", ds.loginCalls, ds.createSessionCalls)
|
||||
}
|
||||
if ds.getPowCalls != 0 || ds.callCompletionCalls != 0 {
|
||||
t.Fatalf("expected no completion flow calls, got getPow=%d callCompletion=%d", ds.getPowCalls, ds.callCompletionCalls)
|
||||
}
|
||||
updated, ok := store.FindAccount("batch@example.com")
|
||||
if !ok {
|
||||
t.Fatal("expected updated account")
|
||||
}
|
||||
if updated.Token != "new-token" {
|
||||
t.Fatalf("expected refreshed token to be persisted, got %q", updated.Token)
|
||||
}
|
||||
if updated.TestStatus != "ok" {
|
||||
t.Fatalf("expected test status ok, got %q", updated.TestStatus)
|
||||
}
|
||||
}
|
||||
@@ -49,6 +49,7 @@ func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
|
||||
next := c.Clone()
|
||||
if mode == "replace" {
|
||||
next = incoming.Clone()
|
||||
next.Accounts = normalizeAndDedupeAccounts(next.Accounts)
|
||||
next.VercelSyncHash = c.VercelSyncHash
|
||||
next.VercelSyncTime = c.VercelSyncTime
|
||||
importedKeys = len(next.Keys)
|
||||
@@ -73,17 +74,22 @@ func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
existingAccounts := map[string]struct{}{}
|
||||
for _, acc := range next.Accounts {
|
||||
existingAccounts[acc.Identifier()] = struct{}{}
|
||||
acc = normalizeAccountForStorage(acc)
|
||||
key := accountDedupeKey(acc)
|
||||
if key != "" {
|
||||
existingAccounts[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
for _, acc := range incoming.Accounts {
|
||||
id := acc.Identifier()
|
||||
if id == "" {
|
||||
acc = normalizeAccountForStorage(acc)
|
||||
key := accountDedupeKey(acc)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := existingAccounts[id]; ok {
|
||||
if _, ok := existingAccounts[key]; ok {
|
||||
continue
|
||||
}
|
||||
existingAccounts[id] = struct{}{}
|
||||
existingAccounts[key] = struct{}{}
|
||||
next.Accounts = append(next.Accounts, acc)
|
||||
importedAccounts++
|
||||
}
|
||||
|
||||
@@ -25,17 +25,28 @@ func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) {
|
||||
if accountsRaw, ok := req["accounts"].([]any); ok {
|
||||
existing := map[string]config.Account{}
|
||||
for _, a := range old.Accounts {
|
||||
existing[a.Identifier()] = a
|
||||
a = normalizeAccountForStorage(a)
|
||||
key := accountDedupeKey(a)
|
||||
if key != "" {
|
||||
existing[key] = a
|
||||
}
|
||||
}
|
||||
seen := map[string]struct{}{}
|
||||
accounts := make([]config.Account, 0, len(accountsRaw))
|
||||
for _, item := range accountsRaw {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
acc := toAccount(m)
|
||||
id := acc.Identifier()
|
||||
if prev, ok := existing[id]; ok {
|
||||
acc := normalizeAccountForStorage(toAccount(m))
|
||||
key := accountDedupeKey(acc)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
if prev, ok := existing[key]; ok {
|
||||
if strings.TrimSpace(acc.Password) == "" {
|
||||
acc.Password = prev.Password
|
||||
}
|
||||
@@ -43,6 +54,7 @@ func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) {
|
||||
acc.Token = prev.Token
|
||||
}
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
accounts = append(accounts, acc)
|
||||
}
|
||||
c.Accounts = accounts
|
||||
@@ -138,20 +150,24 @@ func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) {
|
||||
if accounts, ok := req["accounts"].([]any); ok {
|
||||
existing := map[string]bool{}
|
||||
for _, a := range c.Accounts {
|
||||
existing[a.Identifier()] = true
|
||||
a = normalizeAccountForStorage(a)
|
||||
key := accountDedupeKey(a)
|
||||
if key != "" {
|
||||
existing[key] = true
|
||||
}
|
||||
}
|
||||
for _, item := range accounts {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
acc := toAccount(m)
|
||||
id := acc.Identifier()
|
||||
if id == "" || existing[id] {
|
||||
acc := normalizeAccountForStorage(toAccount(m))
|
||||
key := accountDedupeKey(acc)
|
||||
if key == "" || existing[key] {
|
||||
continue
|
||||
}
|
||||
c.Accounts = append(c.Accounts, acc)
|
||||
existing[id] = true
|
||||
existing[key] = true
|
||||
importedAccounts++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -265,3 +265,57 @@ func TestConfigImportRejectsMergedRuntimeConflict(t *testing.T) {
|
||||
t.Fatalf("runtime should remain unchanged, runtime=%+v", snap.Runtime)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigImportMergeDedupesMobileAliases(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"keys":["k1"],
|
||||
"accounts":[{"mobile":"+8613800138000","password":"p1"}]
|
||||
}`)
|
||||
|
||||
merge := map[string]any{
|
||||
"mode": "merge",
|
||||
"config": map[string]any{
|
||||
"accounts": []any{
|
||||
map[string]any{"mobile": "13800138000", "password": "p2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(merge)
|
||||
req := httptest.NewRequest(http.MethodPost, "/admin/config/import?mode=merge", bytes.NewReader(b))
|
||||
rec := httptest.NewRecorder()
|
||||
h.configImport(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if got := len(h.Store.Accounts()); got != 1 {
|
||||
t.Fatalf("expected merge dedupe by canonical mobile, got=%d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateConfigDedupesMobileAliases(t *testing.T) {
|
||||
h := newAdminTestHandler(t, `{
|
||||
"keys":["k1"],
|
||||
"accounts":[{"mobile":"+8613800138000","password":"old"}]
|
||||
}`)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"accounts": []any{
|
||||
map[string]any{"mobile": "+8613800138000"},
|
||||
map[string]any{"mobile": "13800138000"},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest(http.MethodPost, "/admin/config", bytes.NewReader(b))
|
||||
rec := httptest.NewRecorder()
|
||||
h.updateConfig(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
accounts := h.Store.Accounts()
|
||||
if len(accounts) != 1 {
|
||||
t.Fatalf("expected update dedupe by canonical mobile, got=%d", len(accounts))
|
||||
}
|
||||
if accounts[0].Identifier() != "+8613800138000" {
|
||||
t.Fatalf("unexpected identifier: %q", accounts[0].Identifier())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,9 +59,11 @@ func toStringSlice(v any) ([]string, bool) {
|
||||
}
|
||||
|
||||
func toAccount(m map[string]any) config.Account {
|
||||
email := fieldString(m, "email")
|
||||
mobile := config.NormalizeMobileForStorage(fieldString(m, "mobile"))
|
||||
return config.Account{
|
||||
Email: fieldString(m, "email"),
|
||||
Mobile: fieldString(m, "mobile"),
|
||||
Email: email,
|
||||
Mobile: mobile,
|
||||
Password: fieldString(m, "password"),
|
||||
Token: fieldString(m, "token"),
|
||||
}
|
||||
@@ -90,12 +92,52 @@ func accountMatchesIdentifier(acc config.Account, identifier string) bool {
|
||||
if strings.TrimSpace(acc.Email) == id {
|
||||
return true
|
||||
}
|
||||
if strings.TrimSpace(acc.Mobile) == id {
|
||||
if mobileKey := config.CanonicalMobileKey(id); mobileKey != "" && mobileKey == config.CanonicalMobileKey(acc.Mobile) {
|
||||
return true
|
||||
}
|
||||
return acc.Identifier() == id
|
||||
}
|
||||
|
||||
func normalizeAccountForStorage(acc config.Account) config.Account {
|
||||
acc.Email = strings.TrimSpace(acc.Email)
|
||||
acc.Mobile = config.NormalizeMobileForStorage(acc.Mobile)
|
||||
return acc
|
||||
}
|
||||
|
||||
func accountDedupeKey(acc config.Account) string {
|
||||
if email := strings.TrimSpace(acc.Email); email != "" {
|
||||
return "email:" + email
|
||||
}
|
||||
if mobile := config.CanonicalMobileKey(acc.Mobile); mobile != "" {
|
||||
return "mobile:" + mobile
|
||||
}
|
||||
if id := strings.TrimSpace(acc.Identifier()); id != "" {
|
||||
return "id:" + id
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func normalizeAndDedupeAccounts(accounts []config.Account) []config.Account {
|
||||
if len(accounts) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]config.Account, 0, len(accounts))
|
||||
seen := make(map[string]struct{}, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
acc = normalizeAccountForStorage(acc)
|
||||
key := accountDedupeKey(acc)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, acc)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func findAccountByIdentifier(store ConfigStore, identifier string) (config.Account, bool) {
|
||||
id := strings.TrimSpace(identifier)
|
||||
if id == "" {
|
||||
|
||||
@@ -182,7 +182,7 @@ func TestToAccountAllFields(t *testing.T) {
|
||||
if acc.Email != "user@test.com" {
|
||||
t.Fatalf("unexpected email: %q", acc.Email)
|
||||
}
|
||||
if acc.Mobile != "13800138000" {
|
||||
if acc.Mobile != "+8613800138000" {
|
||||
t.Fatalf("unexpected mobile: %q", acc.Mobile)
|
||||
}
|
||||
if acc.Password != "secret" {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/sse"
|
||||
@@ -67,6 +68,7 @@ func TestGoCompatToolcallFixtures(t *testing.T) {
|
||||
var fixture struct {
|
||||
Text string `json:"text"`
|
||||
ToolNames []string `json:"tool_names"`
|
||||
Mode string `json:"mode"`
|
||||
}
|
||||
mustLoadJSON(t, fixturePath, &fixture)
|
||||
|
||||
@@ -75,7 +77,13 @@ func TestGoCompatToolcallFixtures(t *testing.T) {
|
||||
}
|
||||
mustLoadJSON(t, expectedPath, &expected)
|
||||
|
||||
got := util.ParseToolCalls(fixture.Text, fixture.ToolNames)
|
||||
var got []util.ParsedToolCall
|
||||
switch strings.ToLower(strings.TrimSpace(fixture.Mode)) {
|
||||
case "standalone":
|
||||
got = util.ParseStandaloneToolCalls(fixture.Text, fixture.ToolNames)
|
||||
default:
|
||||
got = util.ParseToolCalls(fixture.Text, fixture.ToolNames)
|
||||
}
|
||||
if len(got) == 0 && len(expected.Calls) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -10,8 +10,8 @@ func (a Account) Identifier() string {
|
||||
if strings.TrimSpace(a.Email) != "" {
|
||||
return strings.TrimSpace(a.Email)
|
||||
}
|
||||
if strings.TrimSpace(a.Mobile) != "" {
|
||||
return strings.TrimSpace(a.Mobile)
|
||||
if mobile := NormalizeMobileForStorage(a.Mobile); mobile != "" {
|
||||
return mobile
|
||||
}
|
||||
// Backward compatibility: old configs may contain token-only accounts.
|
||||
// Use a stable non-sensitive synthetic id so they can still join the pool.
|
||||
|
||||
@@ -202,7 +202,7 @@ func TestConfigCloneNilMaps(t *testing.T) {
|
||||
|
||||
func TestAccountIdentifierPreferenceMobileOverToken(t *testing.T) {
|
||||
acc := Account{Mobile: "13800138000", Token: "tok"}
|
||||
if acc.Identifier() != "13800138000" {
|
||||
if acc.Identifier() != "+8613800138000" {
|
||||
t.Fatalf("expected mobile identifier, got %q", acc.Identifier())
|
||||
}
|
||||
}
|
||||
|
||||
82
internal/config/mobile.go
Normal file
82
internal/config/mobile.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package config
|
||||
|
||||
import "strings"
|
||||
|
||||
// NormalizeMobileForStorage normalizes user input to a stable storage format.
|
||||
// It keeps existing country codes and auto-prefixes mainland China numbers with +86.
|
||||
func NormalizeMobileForStorage(raw string) string {
|
||||
digits, hasPlus := extractMobileDigits(raw)
|
||||
if digits == "" {
|
||||
return ""
|
||||
}
|
||||
if hasPlus {
|
||||
return "+" + digits
|
||||
}
|
||||
if isChinaMobileWithCountryCode(digits) {
|
||||
return "+86" + digits[2:]
|
||||
}
|
||||
if isChinaMainlandMobileDigits(digits) {
|
||||
return "+86" + digits
|
||||
}
|
||||
// For non-China numbers without a leading +, preserve semantics by adding it.
|
||||
return "+" + digits
|
||||
}
|
||||
|
||||
// CanonicalMobileKey returns the comparison key used by dedupe/matching logic.
|
||||
func CanonicalMobileKey(raw string) string {
|
||||
return NormalizeMobileForStorage(raw)
|
||||
}
|
||||
|
||||
func extractMobileDigits(raw string) (digits string, hasPlus bool) {
|
||||
s := strings.TrimSpace(raw)
|
||||
if s == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
for _, r := range s {
|
||||
switch {
|
||||
case r >= '0' && r <= '9':
|
||||
goto collect
|
||||
case isMobileSeparator(r):
|
||||
continue
|
||||
case r == '+':
|
||||
hasPlus = true
|
||||
goto collect
|
||||
default:
|
||||
goto collect
|
||||
}
|
||||
}
|
||||
|
||||
collect:
|
||||
var b strings.Builder
|
||||
b.Grow(len(s))
|
||||
for _, r := range s {
|
||||
if r >= '0' && r <= '9' {
|
||||
b.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return b.String(), hasPlus
|
||||
}
|
||||
|
||||
func isChinaMainlandMobileDigits(digits string) bool {
|
||||
if len(digits) != 11 || digits[0] != '1' {
|
||||
return false
|
||||
}
|
||||
return digits[1] >= '3' && digits[1] <= '9'
|
||||
}
|
||||
|
||||
func isChinaMobileWithCountryCode(digits string) bool {
|
||||
if len(digits) != 13 || !strings.HasPrefix(digits, "86") {
|
||||
return false
|
||||
}
|
||||
return isChinaMainlandMobileDigits(digits[2:])
|
||||
}
|
||||
|
||||
func isMobileSeparator(r rune) bool {
|
||||
switch r {
|
||||
case ' ', '\t', '\n', '\r', '-', '(', ')', '.', '/':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
36
internal/config/mobile_test.go
Normal file
36
internal/config/mobile_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package config
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeMobileForStorageChinaMainlandAddsPlus86(t *testing.T) {
|
||||
if got := NormalizeMobileForStorage("13800138000"); got != "+8613800138000" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeMobileForStorageChinaWithCountryCode(t *testing.T) {
|
||||
if got := NormalizeMobileForStorage("8613800138000"); got != "+8613800138000" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeMobileForStorageKeepsExistingCountryCode(t *testing.T) {
|
||||
if got := NormalizeMobileForStorage(" +1 (415) 555-2671 "); got != "+14155552671" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanonicalMobileKeyMatchesChinaAliases(t *testing.T) {
|
||||
a := CanonicalMobileKey("+8613800138000")
|
||||
b := CanonicalMobileKey("13800138000")
|
||||
c := CanonicalMobileKey("86 13800138000")
|
||||
if a == "" || a != b || b != c {
|
||||
t.Fatalf("alias mismatch: a=%q b=%q c=%q", a, b, c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanonicalMobileKeyEmptyForInvalidInput(t *testing.T) {
|
||||
if got := CanonicalMobileKey("() --"); got != "" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
@@ -20,8 +21,9 @@ func (c *Client) Login(ctx context.Context, acc config.Account) (string, error)
|
||||
if email := strings.TrimSpace(acc.Email); email != "" {
|
||||
payload["email"] = email
|
||||
} else if mobile := strings.TrimSpace(acc.Mobile); mobile != "" {
|
||||
payload["mobile"] = mobile
|
||||
payload["area_code"] = nil
|
||||
loginMobile, areaCode := normalizeMobileForLogin(mobile)
|
||||
payload["mobile"] = loginMobile
|
||||
payload["area_code"] = areaCode
|
||||
} else {
|
||||
return "", errors.New("missing email/mobile")
|
||||
}
|
||||
@@ -151,3 +153,26 @@ func isTokenInvalid(status int, code int, msg string) bool {
|
||||
}
|
||||
return strings.Contains(msg, "token") || strings.Contains(msg, "unauthorized")
|
||||
}
|
||||
|
||||
func normalizeMobileForLogin(raw string) (mobile string, areaCode any) {
|
||||
s := strings.TrimSpace(raw)
|
||||
if s == "" {
|
||||
return "", nil
|
||||
}
|
||||
hasPlus := strings.HasPrefix(s, "+")
|
||||
var b strings.Builder
|
||||
b.Grow(len(s))
|
||||
for _, r := range s {
|
||||
if unicode.IsDigit(r) {
|
||||
b.WriteRune(r)
|
||||
}
|
||||
}
|
||||
digits := b.String()
|
||||
if digits == "" {
|
||||
return "", nil
|
||||
}
|
||||
if (hasPlus || strings.HasPrefix(digits, "86")) && strings.HasPrefix(digits, "86") && len(digits) == 13 {
|
||||
return digits[2:], nil
|
||||
}
|
||||
return digits, nil
|
||||
}
|
||||
|
||||
33
internal/deepseek/client_auth_mobile_test.go
Normal file
33
internal/deepseek/client_auth_mobile_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package deepseek
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeMobileForLogin_ChinaWithPlus86(t *testing.T) {
|
||||
mobile, areaCode := normalizeMobileForLogin("+8613800138000")
|
||||
if mobile != "13800138000" {
|
||||
t.Fatalf("unexpected mobile: %q", mobile)
|
||||
}
|
||||
if areaCode != nil {
|
||||
t.Fatalf("expected nil areaCode, got %#v", areaCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeMobileForLogin_ChinaWith86Prefix(t *testing.T) {
|
||||
mobile, areaCode := normalizeMobileForLogin("8613800138000")
|
||||
if mobile != "13800138000" {
|
||||
t.Fatalf("unexpected mobile: %q", mobile)
|
||||
}
|
||||
if areaCode != nil {
|
||||
t.Fatalf("expected nil areaCode, got %#v", areaCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeMobileForLogin_KeepPlainDigits(t *testing.T) {
|
||||
mobile, areaCode := normalizeMobileForLogin("13800138000")
|
||||
if mobile != "13800138000" {
|
||||
t.Fatalf("unexpected mobile: %q", mobile)
|
||||
}
|
||||
if areaCode != nil {
|
||||
t.Fatalf("expected nil areaCode, got %#v", areaCode)
|
||||
}
|
||||
}
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
detected := util.ParseStandaloneToolCalls(finalText, toolNames)
|
||||
finishReason := "stop"
|
||||
messageObj := map[string]any{"role": "assistant", "content": finalText}
|
||||
if strings.TrimSpace(finalThinking) != "" {
|
||||
|
||||
@@ -11,12 +11,9 @@ import (
|
||||
)
|
||||
|
||||
func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
|
||||
// Align responses tool-call semantics with chat/completions:
|
||||
// mixed prose + tool_call payloads should still be interpreted as tool calls.
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
|
||||
detected = util.ParseToolCalls(finalThinking, toolNames)
|
||||
}
|
||||
// Strict mode: only standalone, structured tool-call payloads are treated
|
||||
// as executable tool calls.
|
||||
detected := util.ParseStandaloneToolCalls(finalText, toolNames)
|
||||
exposedOutputText := finalText
|
||||
output := make([]any, 0, 2)
|
||||
if len(detected) > 0 {
|
||||
|
||||
@@ -45,7 +45,7 @@ func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T) {
|
||||
func TestBuildResponseObjectTreatsMixedProseToolPayloadAsText(t *testing.T) {
|
||||
obj := BuildResponseObject(
|
||||
"resp_test",
|
||||
"gpt-4o",
|
||||
@@ -56,17 +56,16 @@ func TestBuildResponseObjectTreatsMixedProseToolPayloadAsToolCall(t *testing.T)
|
||||
)
|
||||
|
||||
outputText, _ := obj["output_text"].(string)
|
||||
if outputText != "" {
|
||||
t.Fatalf("expected output_text hidden once tool calls are detected, got %q", outputText)
|
||||
if outputText == "" {
|
||||
t.Fatalf("expected output_text preserved for mixed prose payload")
|
||||
}
|
||||
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected function_call output only, got %#v", obj["output"])
|
||||
t.Fatalf("expected one message output item, got %#v", obj["output"])
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "function_call" {
|
||||
t.Fatalf("expected first output type function_call, got %#v", first["type"])
|
||||
if first["type"] != "message" {
|
||||
t.Fatalf("expected message output type, got %#v", first["type"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,7 +126,7 @@ func TestBuildResponseObjectReasoningOnlyFallsBackToOutputText(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
|
||||
func TestBuildResponseObjectIgnoresToolCallFromThinkingChannel(t *testing.T) {
|
||||
obj := BuildResponseObject(
|
||||
"resp_test",
|
||||
"gpt-4o",
|
||||
@@ -139,10 +138,10 @@ func TestBuildResponseObjectDetectsToolCallFromThinkingChannel(t *testing.T) {
|
||||
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected function_call output only, got %#v", obj["output"])
|
||||
t.Fatalf("expected one message output item, got %#v", obj["output"])
|
||||
}
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "function_call" {
|
||||
t.Fatalf("expected output function_call, got %#v", first["type"])
|
||||
if first["type"] != "message" {
|
||||
t.Fatalf("expected output message, got %#v", first["type"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,8 +10,10 @@ const {
|
||||
} = require('./sse_parse');
|
||||
const {
|
||||
resolveToolcallPolicy,
|
||||
formatIncrementalToolCallDeltas,
|
||||
normalizePreparedToolNames,
|
||||
boolDefaultTrue,
|
||||
filterIncrementalToolCallDeltasByAllowed,
|
||||
} = require('./toolcall_policy');
|
||||
const {
|
||||
estimateTokens,
|
||||
@@ -82,7 +84,9 @@ module.exports.__test = {
|
||||
shouldSkipPath,
|
||||
asString,
|
||||
resolveToolcallPolicy,
|
||||
formatIncrementalToolCallDeltas,
|
||||
normalizePreparedToolNames,
|
||||
boolDefaultTrue,
|
||||
filterIncrementalToolCallDeltasByAllowed,
|
||||
estimateTokens,
|
||||
};
|
||||
|
||||
@@ -68,6 +68,47 @@ function formatIncrementalToolCallDeltas(deltas, idStore) {
|
||||
return out;
|
||||
}
|
||||
|
||||
function filterIncrementalToolCallDeltasByAllowed(deltas, allowedNames, seenNames) {
|
||||
if (!Array.isArray(deltas) || deltas.length === 0) {
|
||||
return [];
|
||||
}
|
||||
const seen = seenNames instanceof Map ? seenNames : new Map();
|
||||
const allowed = new Set((allowedNames || []).filter((name) => asString(name) !== ''));
|
||||
if (allowed.size === 0) {
|
||||
for (const d of deltas) {
|
||||
if (d && typeof d === 'object' && asString(d.name)) {
|
||||
const index = Number.isInteger(d.index) ? d.index : 0;
|
||||
seen.set(index, '__blocked__');
|
||||
}
|
||||
}
|
||||
return [];
|
||||
}
|
||||
|
||||
const out = [];
|
||||
for (const d of deltas) {
|
||||
if (!d || typeof d !== 'object') {
|
||||
continue;
|
||||
}
|
||||
const index = Number.isInteger(d.index) ? d.index : 0;
|
||||
const name = asString(d.name);
|
||||
if (name) {
|
||||
if (!allowed.has(name)) {
|
||||
seen.set(index, '__blocked__');
|
||||
continue;
|
||||
}
|
||||
seen.set(index, name);
|
||||
out.push(d);
|
||||
continue;
|
||||
}
|
||||
const existing = asString(seen.get(index));
|
||||
if (!existing || existing === '__blocked__') {
|
||||
continue;
|
||||
}
|
||||
out.push(d);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function ensureStreamToolCallID(idStore, index) {
|
||||
const key = Number.isInteger(index) ? index : 0;
|
||||
const existing = idStore.get(key);
|
||||
@@ -104,4 +145,5 @@ module.exports = {
|
||||
normalizePreparedToolNames,
|
||||
boolDefaultTrue,
|
||||
formatIncrementalToolCallDeltas,
|
||||
filterIncrementalToolCallDeltasByAllowed,
|
||||
};
|
||||
|
||||
@@ -5,7 +5,7 @@ const {
|
||||
createToolSieveState,
|
||||
processToolSieveChunk,
|
||||
flushToolSieve,
|
||||
parseToolCalls,
|
||||
parseStandaloneToolCalls,
|
||||
formatOpenAIStreamToolCalls,
|
||||
} = require('../helpers/stream-tool-sieve');
|
||||
const {
|
||||
@@ -24,7 +24,6 @@ const {
|
||||
} = require('./token_usage');
|
||||
const {
|
||||
resolveToolcallPolicy,
|
||||
formatIncrementalToolCallDeltas,
|
||||
} = require('./toolcall_policy');
|
||||
const {
|
||||
createChatCompletionEmitter,
|
||||
@@ -130,7 +129,6 @@ async function handleVercelStream(req, res, rawBody, payload) {
|
||||
let thinkingText = '';
|
||||
let outputText = '';
|
||||
const toolSieveEnabled = toolPolicy.toolSieveEnabled;
|
||||
const emitEarlyToolDeltas = toolPolicy.emitEarlyToolDeltas;
|
||||
const toolSieveState = createToolSieveState();
|
||||
let toolCallsEmitted = false;
|
||||
const streamToolCallIDs = new Map();
|
||||
@@ -155,13 +153,18 @@ async function handleVercelStream(req, res, rawBody, payload) {
|
||||
await releaseLease();
|
||||
return;
|
||||
}
|
||||
const detected = parseToolCalls(outputText, toolNames);
|
||||
const detected = parseStandaloneToolCalls(outputText, toolNames);
|
||||
if (detected.length > 0 && !toolCallsEmitted) {
|
||||
toolCallsEmitted = true;
|
||||
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(detected) });
|
||||
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(detected, streamToolCallIDs) });
|
||||
} else if (toolSieveEnabled) {
|
||||
const tailEvents = flushToolSieve(toolSieveState, toolNames);
|
||||
for (const evt of tailEvents) {
|
||||
if (evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0) {
|
||||
toolCallsEmitted = true;
|
||||
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls, streamToolCallIDs) });
|
||||
continue;
|
||||
}
|
||||
if (evt.text) {
|
||||
sendDeltaFrame({ content: evt.text });
|
||||
}
|
||||
@@ -252,17 +255,9 @@ async function handleVercelStream(req, res, rawBody, payload) {
|
||||
}
|
||||
const events = processToolSieveChunk(toolSieveState, p.text, toolNames);
|
||||
for (const evt of events) {
|
||||
if (evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0) {
|
||||
if (!emitEarlyToolDeltas) {
|
||||
continue;
|
||||
}
|
||||
toolCallsEmitted = true;
|
||||
sendDeltaFrame({ tool_calls: formatIncrementalToolCallDeltas(evt.deltas, streamToolCallIDs) });
|
||||
continue;
|
||||
}
|
||||
if (evt.type === 'tool_calls') {
|
||||
toolCallsEmitted = true;
|
||||
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls) });
|
||||
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls, streamToolCallIDs) });
|
||||
continue;
|
||||
}
|
||||
if (evt.text) {
|
||||
|
||||
@@ -2,13 +2,13 @@
|
||||
|
||||
const crypto = require('crypto');
|
||||
|
||||
function formatOpenAIStreamToolCalls(calls) {
|
||||
function formatOpenAIStreamToolCalls(calls, idStore) {
|
||||
if (!Array.isArray(calls) || calls.length === 0) {
|
||||
return [];
|
||||
}
|
||||
return calls.map((c, idx) => ({
|
||||
index: idx,
|
||||
id: `call_${newCallID()}`,
|
||||
id: ensureStreamToolCallID(idStore, idx),
|
||||
type: 'function',
|
||||
function: {
|
||||
name: c.name,
|
||||
@@ -17,6 +17,20 @@ function formatOpenAIStreamToolCalls(calls) {
|
||||
}));
|
||||
}
|
||||
|
||||
function ensureStreamToolCallID(idStore, index) {
|
||||
if (!(idStore instanceof Map)) {
|
||||
return `call_${newCallID()}`;
|
||||
}
|
||||
const key = Number.isInteger(index) ? index : 0;
|
||||
const existing = idStore.get(key);
|
||||
if (existing) {
|
||||
return existing;
|
||||
}
|
||||
const next = `call_${newCallID()}`;
|
||||
idStore.set(key, next);
|
||||
return next;
|
||||
}
|
||||
|
||||
function newCallID() {
|
||||
if (typeof crypto.randomUUID === 'function') {
|
||||
return crypto.randomUUID().replace(/-/g, '');
|
||||
|
||||
@@ -1,226 +0,0 @@
|
||||
'use strict';
|
||||
|
||||
const {
|
||||
looksLikeToolExampleContext,
|
||||
insideCodeFence,
|
||||
} = require('./state');
|
||||
const {
|
||||
findObjectFieldValueStart,
|
||||
parseJSONStringLiteral,
|
||||
skipSpaces,
|
||||
} = require('./jsonscan');
|
||||
|
||||
function buildIncrementalToolDeltas(state) {
|
||||
const captured = state.capture || '';
|
||||
if (!captured) {
|
||||
return [];
|
||||
}
|
||||
if (looksLikeToolExampleContext(state.recentTextTail)) {
|
||||
return [];
|
||||
}
|
||||
const lower = captured.toLowerCase();
|
||||
const keyIdx = lower.indexOf('tool_calls');
|
||||
if (keyIdx < 0) {
|
||||
return [];
|
||||
}
|
||||
const start = captured.slice(0, keyIdx).lastIndexOf('{');
|
||||
if (start < 0) {
|
||||
return [];
|
||||
}
|
||||
if (insideCodeFence((state.recentTextTail || '') + captured.slice(0, start))) {
|
||||
return [];
|
||||
}
|
||||
const callStart = findFirstToolCallObjectStart(captured, keyIdx);
|
||||
if (callStart < 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const deltas = [];
|
||||
if (!state.toolName) {
|
||||
const name = extractToolCallName(captured, callStart);
|
||||
if (!name) {
|
||||
return [];
|
||||
}
|
||||
state.toolName = name;
|
||||
}
|
||||
|
||||
if (state.toolArgsStart < 0) {
|
||||
const args = findToolCallArgsStart(captured, callStart);
|
||||
if (args) {
|
||||
state.toolArgsString = Boolean(args.stringMode);
|
||||
state.toolArgsStart = state.toolArgsString ? args.start + 1 : args.start;
|
||||
state.toolArgsSent = state.toolArgsStart;
|
||||
}
|
||||
}
|
||||
if (!state.toolNameSent) {
|
||||
if (state.toolArgsStart < 0) {
|
||||
return [];
|
||||
}
|
||||
state.toolNameSent = true;
|
||||
deltas.push({ index: 0, name: state.toolName });
|
||||
}
|
||||
if (state.toolArgsStart < 0 || state.toolArgsDone) {
|
||||
return deltas;
|
||||
}
|
||||
const progress = scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString);
|
||||
if (!progress) {
|
||||
return deltas;
|
||||
}
|
||||
if (progress.end > state.toolArgsSent) {
|
||||
deltas.push({
|
||||
index: 0,
|
||||
arguments: captured.slice(state.toolArgsSent, progress.end),
|
||||
});
|
||||
state.toolArgsSent = progress.end;
|
||||
}
|
||||
if (progress.complete) {
|
||||
state.toolArgsDone = true;
|
||||
}
|
||||
return deltas;
|
||||
}
|
||||
|
||||
function findFirstToolCallObjectStart(text, keyIdx) {
|
||||
const arrStart = findToolCallsArrayStart(text, keyIdx);
|
||||
if (arrStart < 0) {
|
||||
return -1;
|
||||
}
|
||||
const i = skipSpaces(text, arrStart + 1);
|
||||
if (i >= text.length || text[i] !== '{') {
|
||||
return -1;
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
function findToolCallsArrayStart(text, keyIdx) {
|
||||
let i = keyIdx + 'tool_calls'.length;
|
||||
while (i < text.length && text[i] !== ':') {
|
||||
i += 1;
|
||||
}
|
||||
if (i >= text.length) {
|
||||
return -1;
|
||||
}
|
||||
i = skipSpaces(text, i + 1);
|
||||
if (i >= text.length || text[i] !== '[') {
|
||||
return -1;
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
function extractToolCallName(text, callStart) {
|
||||
let valueStart = findObjectFieldValueStart(text, callStart, ['name']);
|
||||
if (valueStart < 0 || text[valueStart] !== '"') {
|
||||
const fnStart = findFunctionObjectStart(text, callStart);
|
||||
if (fnStart < 0) {
|
||||
return '';
|
||||
}
|
||||
valueStart = findObjectFieldValueStart(text, fnStart, ['name']);
|
||||
if (valueStart < 0 || text[valueStart] !== '"') {
|
||||
return '';
|
||||
}
|
||||
}
|
||||
const parsed = parseJSONStringLiteral(text, valueStart);
|
||||
if (!parsed) {
|
||||
return '';
|
||||
}
|
||||
return parsed.value;
|
||||
}
|
||||
|
||||
function findToolCallArgsStart(text, callStart) {
|
||||
const keys = ['input', 'arguments', 'args', 'parameters', 'params'];
|
||||
let valueStart = findObjectFieldValueStart(text, callStart, keys);
|
||||
if (valueStart < 0) {
|
||||
const fnStart = findFunctionObjectStart(text, callStart);
|
||||
if (fnStart < 0) {
|
||||
return null;
|
||||
}
|
||||
valueStart = findObjectFieldValueStart(text, fnStart, keys);
|
||||
if (valueStart < 0) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
if (valueStart >= text.length) {
|
||||
return null;
|
||||
}
|
||||
const ch = text[valueStart];
|
||||
if (ch === '{' || ch === '[') {
|
||||
return { start: valueStart, stringMode: false };
|
||||
}
|
||||
if (ch === '"') {
|
||||
return { start: valueStart, stringMode: true };
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function scanToolCallArgsProgress(text, start, stringMode) {
|
||||
if (start < 0 || start > text.length) {
|
||||
return null;
|
||||
}
|
||||
if (stringMode) {
|
||||
let escaped = false;
|
||||
for (let i = start; i < text.length; i += 1) {
|
||||
const ch = text[i];
|
||||
if (escaped) {
|
||||
escaped = false;
|
||||
continue;
|
||||
}
|
||||
if (ch === '\\') {
|
||||
escaped = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === '"') {
|
||||
return { end: i, complete: true };
|
||||
}
|
||||
}
|
||||
return { end: text.length, complete: false };
|
||||
}
|
||||
if (start >= text.length || (text[start] !== '{' && text[start] !== '[')) {
|
||||
return null;
|
||||
}
|
||||
let depth = 0;
|
||||
let quote = '';
|
||||
let escaped = false;
|
||||
for (let i = start; i < text.length; i += 1) {
|
||||
const ch = text[i];
|
||||
if (quote) {
|
||||
if (escaped) {
|
||||
escaped = false;
|
||||
continue;
|
||||
}
|
||||
if (ch === '\\') {
|
||||
escaped = true;
|
||||
continue;
|
||||
}
|
||||
if (ch === quote) {
|
||||
quote = '';
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (ch === '"' || ch === "'") {
|
||||
quote = ch;
|
||||
continue;
|
||||
}
|
||||
if (ch === '{' || ch === '[') {
|
||||
depth += 1;
|
||||
continue;
|
||||
}
|
||||
if (ch === '}' || ch === ']') {
|
||||
depth -= 1;
|
||||
if (depth === 0) {
|
||||
return { end: i + 1, complete: true };
|
||||
}
|
||||
}
|
||||
}
|
||||
return { end: text.length, complete: false };
|
||||
}
|
||||
|
||||
function findFunctionObjectStart(text, callStart) {
|
||||
const valueStart = findObjectFieldValueStart(text, callStart, ['function']);
|
||||
if (valueStart < 0 || valueStart >= text.length || text[valueStart] !== '{') {
|
||||
return -1;
|
||||
}
|
||||
return valueStart;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
buildIncrementalToolDeltas,
|
||||
};
|
||||
@@ -10,7 +10,9 @@ const {
|
||||
const {
|
||||
extractToolNames,
|
||||
parseToolCalls,
|
||||
parseToolCallsDetailed,
|
||||
parseStandaloneToolCalls,
|
||||
parseStandaloneToolCallsDetailed,
|
||||
} = require('./parse');
|
||||
const {
|
||||
formatOpenAIStreamToolCalls,
|
||||
@@ -22,6 +24,8 @@ module.exports = {
|
||||
processToolSieveChunk,
|
||||
flushToolSieve,
|
||||
parseToolCalls,
|
||||
parseToolCallsDetailed,
|
||||
parseStandaloneToolCalls,
|
||||
parseStandaloneToolCallsDetailed,
|
||||
formatOpenAIStreamToolCalls,
|
||||
};
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
'use strict';
|
||||
|
||||
const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s;
|
||||
|
||||
const {
|
||||
toStringSafe,
|
||||
looksLikeToolExampleContext,
|
||||
} = require('./state');
|
||||
const {
|
||||
extractJSONObjectFrom,
|
||||
} = require('./jsonscan');
|
||||
stripFencedCodeBlocks,
|
||||
buildToolCallCandidates,
|
||||
parseToolCallsPayload,
|
||||
} = require('./parse_payload');
|
||||
|
||||
function extractToolNames(tools) {
|
||||
if (!Array.isArray(tools) || tools.length === 0) {
|
||||
@@ -29,245 +29,144 @@ function extractToolNames(tools) {
|
||||
}
|
||||
|
||||
function parseToolCalls(text, toolNames) {
|
||||
return parseToolCallsDetailed(text, toolNames).calls;
|
||||
}
|
||||
|
||||
function parseToolCallsDetailed(text, toolNames) {
|
||||
const result = emptyParseResult();
|
||||
if (!toStringSafe(text)) {
|
||||
return [];
|
||||
return result;
|
||||
}
|
||||
const sanitized = stripFencedCodeBlocks(text);
|
||||
if (!toStringSafe(sanitized)) {
|
||||
return [];
|
||||
return result;
|
||||
}
|
||||
result.sawToolCallSyntax = sanitized.toLowerCase().includes('tool_calls');
|
||||
|
||||
const candidates = buildToolCallCandidates(sanitized);
|
||||
let parsed = [];
|
||||
for (const c of candidates) {
|
||||
parsed = parseToolCallsPayload(c);
|
||||
if (parsed.length > 0) {
|
||||
result.sawToolCallSyntax = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (parsed.length === 0) {
|
||||
return [];
|
||||
return result;
|
||||
}
|
||||
return filterToolCalls(parsed, toolNames);
|
||||
}
|
||||
|
||||
function stripFencedCodeBlocks(text) {
|
||||
const t = typeof text === 'string' ? text : '';
|
||||
if (!t) {
|
||||
return '';
|
||||
}
|
||||
return t.replace(/```[\s\S]*?```/g, ' ');
|
||||
const filtered = filterToolCallsDetailed(parsed, toolNames);
|
||||
result.calls = filtered.calls;
|
||||
result.rejectedToolNames = filtered.rejectedToolNames;
|
||||
result.rejectedByPolicy = filtered.rejectedToolNames.length > 0 && filtered.calls.length === 0;
|
||||
return result;
|
||||
}
|
||||
|
||||
function parseStandaloneToolCalls(text, toolNames) {
|
||||
return parseStandaloneToolCallsDetailed(text, toolNames).calls;
|
||||
}
|
||||
|
||||
function parseStandaloneToolCallsDetailed(text, toolNames) {
|
||||
const result = emptyParseResult();
|
||||
const trimmed = toStringSafe(text);
|
||||
if (!trimmed) {
|
||||
return [];
|
||||
}
|
||||
if ((trimmed.startsWith('```') && trimmed.endsWith('```')) || trimmed.includes('```')) {
|
||||
return [];
|
||||
return result;
|
||||
}
|
||||
if (looksLikeToolExampleContext(trimmed)) {
|
||||
return [];
|
||||
return result;
|
||||
}
|
||||
const candidates = [trimmed];
|
||||
if (trimmed.startsWith('```') && trimmed.endsWith('```')) {
|
||||
const m = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
|
||||
if (m && m[1]) {
|
||||
candidates.push(toStringSafe(m[1]));
|
||||
}
|
||||
result.sawToolCallSyntax = trimmed.toLowerCase().includes('tool_calls');
|
||||
if (!trimmed.startsWith('{') && !trimmed.startsWith('[')) {
|
||||
return result;
|
||||
}
|
||||
for (const candidate of candidates) {
|
||||
const c = toStringSafe(candidate);
|
||||
if (!c) {
|
||||
continue;
|
||||
}
|
||||
if (!c.startsWith('{') && !c.startsWith('[')) {
|
||||
continue;
|
||||
}
|
||||
const parsed = parseToolCallsPayload(c);
|
||||
if (parsed.length > 0) {
|
||||
return filterToolCalls(parsed, toolNames);
|
||||
}
|
||||
|
||||
const parsed = parseToolCallsPayload(trimmed);
|
||||
if (parsed.length === 0) {
|
||||
return result;
|
||||
}
|
||||
return [];
|
||||
|
||||
result.sawToolCallSyntax = true;
|
||||
const filtered = filterToolCallsDetailed(parsed, toolNames);
|
||||
result.calls = filtered.calls;
|
||||
result.rejectedToolNames = filtered.rejectedToolNames;
|
||||
result.rejectedByPolicy = filtered.rejectedToolNames.length > 0 && filtered.calls.length === 0;
|
||||
return result;
|
||||
}
|
||||
|
||||
function buildToolCallCandidates(text) {
|
||||
const trimmed = toStringSafe(text);
|
||||
const candidates = [trimmed];
|
||||
const fenced = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/gi) || [];
|
||||
for (const block of fenced) {
|
||||
const m = block.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
|
||||
if (m && m[1]) {
|
||||
candidates.push(toStringSafe(m[1]));
|
||||
}
|
||||
}
|
||||
for (const candidate of extractToolCallObjects(trimmed)) {
|
||||
candidates.push(toStringSafe(candidate));
|
||||
}
|
||||
const first = trimmed.indexOf('{');
|
||||
const last = trimmed.lastIndexOf('}');
|
||||
if (first >= 0 && last > first) {
|
||||
candidates.push(toStringSafe(trimmed.slice(first, last + 1)));
|
||||
}
|
||||
const m = trimmed.match(TOOL_CALL_PATTERN);
|
||||
if (m && m[1]) {
|
||||
candidates.push(`{"tool_calls":[${m[1]}]}`);
|
||||
}
|
||||
return [...new Set(candidates.filter(Boolean))];
|
||||
}
|
||||
|
||||
function extractToolCallObjects(text) {
|
||||
const raw = toStringSafe(text);
|
||||
if (!raw) {
|
||||
return [];
|
||||
}
|
||||
const lower = raw.toLowerCase();
|
||||
const out = [];
|
||||
let offset = 0;
|
||||
// eslint-disable-next-line no-constant-condition
|
||||
while (true) {
|
||||
let idx = lower.indexOf('tool_calls', offset);
|
||||
if (idx < 0) {
|
||||
break;
|
||||
}
|
||||
let start = raw.slice(0, idx).lastIndexOf('{');
|
||||
while (start >= 0) {
|
||||
const obj = extractJSONObjectFrom(raw, start);
|
||||
if (obj.ok) {
|
||||
out.push(raw.slice(start, obj.end).trim());
|
||||
offset = obj.end;
|
||||
idx = -1;
|
||||
break;
|
||||
}
|
||||
start = raw.slice(0, start).lastIndexOf('{');
|
||||
}
|
||||
if (idx >= 0) {
|
||||
offset = idx + 'tool_calls'.length;
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function parseToolCallsPayload(payload) {
|
||||
let decoded;
|
||||
try {
|
||||
decoded = JSON.parse(payload);
|
||||
} catch (_err) {
|
||||
return [];
|
||||
}
|
||||
if (Array.isArray(decoded)) {
|
||||
return parseToolCallList(decoded);
|
||||
}
|
||||
if (!decoded || typeof decoded !== 'object') {
|
||||
return [];
|
||||
}
|
||||
if (decoded.tool_calls) {
|
||||
return parseToolCallList(decoded.tool_calls);
|
||||
}
|
||||
const one = parseToolCallItem(decoded);
|
||||
return one ? [one] : [];
|
||||
}
|
||||
|
||||
function parseToolCallList(v) {
|
||||
if (!Array.isArray(v)) {
|
||||
return [];
|
||||
}
|
||||
const out = [];
|
||||
for (const item of v) {
|
||||
if (!item || typeof item !== 'object') {
|
||||
continue;
|
||||
}
|
||||
const one = parseToolCallItem(item);
|
||||
if (one) {
|
||||
out.push(one);
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function parseToolCallItem(m) {
|
||||
let name = toStringSafe(m.name);
|
||||
let inputRaw = m.input;
|
||||
let hasInput = Object.prototype.hasOwnProperty.call(m, 'input');
|
||||
const fn = m.function && typeof m.function === 'object' ? m.function : null;
|
||||
if (fn) {
|
||||
if (!name) {
|
||||
name = toStringSafe(fn.name);
|
||||
}
|
||||
if (!hasInput && Object.prototype.hasOwnProperty.call(fn, 'arguments')) {
|
||||
inputRaw = fn.arguments;
|
||||
hasInput = true;
|
||||
}
|
||||
}
|
||||
if (!hasInput) {
|
||||
for (const k of ['arguments', 'args', 'parameters', 'params']) {
|
||||
if (Object.prototype.hasOwnProperty.call(m, k)) {
|
||||
inputRaw = m[k];
|
||||
hasInput = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!name) {
|
||||
return null;
|
||||
}
|
||||
function emptyParseResult() {
|
||||
return {
|
||||
name,
|
||||
input: parseToolCallInput(inputRaw),
|
||||
calls: [],
|
||||
sawToolCallSyntax: false,
|
||||
rejectedByPolicy: false,
|
||||
rejectedToolNames: [],
|
||||
};
|
||||
}
|
||||
|
||||
function parseToolCallInput(v) {
|
||||
if (v == null) {
|
||||
return {};
|
||||
}
|
||||
if (typeof v === 'string') {
|
||||
const raw = toStringSafe(v);
|
||||
if (!raw) {
|
||||
return {};
|
||||
function filterToolCallsDetailed(parsed, toolNames) {
|
||||
const sourceNames = Array.isArray(toolNames) ? toolNames : [];
|
||||
const allowed = new Set();
|
||||
const allowedCanonical = new Map();
|
||||
for (const item of sourceNames) {
|
||||
const name = toStringSafe(item);
|
||||
if (!name) {
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
const parsed = JSON.parse(raw);
|
||||
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
|
||||
return parsed;
|
||||
}
|
||||
return { _raw: raw };
|
||||
} catch (_err) {
|
||||
return { _raw: raw };
|
||||
allowed.add(name);
|
||||
const lower = name.toLowerCase();
|
||||
if (!allowedCanonical.has(lower)) {
|
||||
allowedCanonical.set(lower, name);
|
||||
}
|
||||
}
|
||||
if (typeof v === 'object' && !Array.isArray(v)) {
|
||||
return v;
|
||||
}
|
||||
try {
|
||||
const parsed = JSON.parse(JSON.stringify(v));
|
||||
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
|
||||
return parsed;
|
||||
}
|
||||
} catch (_err) {
|
||||
return {};
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
function filterToolCalls(parsed, toolNames) {
|
||||
const allowed = new Set((toolNames || []).filter(Boolean));
|
||||
const out = [];
|
||||
if (allowed.size === 0) {
|
||||
const rejected = [];
|
||||
const seen = new Set();
|
||||
for (const tc of parsed) {
|
||||
if (!tc || !tc.name) {
|
||||
continue;
|
||||
}
|
||||
if (seen.has(tc.name)) {
|
||||
continue;
|
||||
}
|
||||
seen.add(tc.name);
|
||||
rejected.push(tc.name);
|
||||
}
|
||||
return { calls: [], rejectedToolNames: rejected };
|
||||
}
|
||||
|
||||
const calls = [];
|
||||
const rejected = [];
|
||||
const seenRejected = new Set();
|
||||
for (const tc of parsed) {
|
||||
if (!tc || !tc.name) {
|
||||
continue;
|
||||
}
|
||||
if (allowed.size > 0 && !allowed.has(tc.name)) {
|
||||
let matchedName = '';
|
||||
if (allowed.has(tc.name)) {
|
||||
matchedName = tc.name;
|
||||
} else {
|
||||
matchedName = allowedCanonical.get(tc.name.toLowerCase()) || '';
|
||||
}
|
||||
if (!matchedName) {
|
||||
if (!seenRejected.has(tc.name)) {
|
||||
seenRejected.add(tc.name);
|
||||
rejected.push(tc.name);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
out.push({ name: tc.name, input: tc.input || {} });
|
||||
calls.push({
|
||||
name: matchedName,
|
||||
input: tc.input && typeof tc.input === 'object' && !Array.isArray(tc.input) ? tc.input : {},
|
||||
});
|
||||
}
|
||||
return out;
|
||||
return { calls, rejectedToolNames: rejected };
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
extractToolNames,
|
||||
parseToolCalls,
|
||||
parseToolCallsDetailed,
|
||||
parseStandaloneToolCalls,
|
||||
parseStandaloneToolCallsDetailed,
|
||||
};
|
||||
|
||||
196
internal/js/helpers/stream-tool-sieve/parse_payload.js
Normal file
196
internal/js/helpers/stream-tool-sieve/parse_payload.js
Normal file
@@ -0,0 +1,196 @@
|
||||
'use strict';
|
||||
|
||||
const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s;
|
||||
|
||||
const {
|
||||
toStringSafe,
|
||||
} = require('./state');
|
||||
const {
|
||||
extractJSONObjectFrom,
|
||||
} = require('./jsonscan');
|
||||
|
||||
function stripFencedCodeBlocks(text) {
|
||||
const t = typeof text === 'string' ? text : '';
|
||||
if (!t) {
|
||||
return '';
|
||||
}
|
||||
return t.replace(/```[\s\S]*?```/g, ' ');
|
||||
}
|
||||
|
||||
function buildToolCallCandidates(text) {
|
||||
const trimmed = toStringSafe(text);
|
||||
const candidates = [trimmed];
|
||||
|
||||
const fenced = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/gi) || [];
|
||||
for (const block of fenced) {
|
||||
const m = block.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
|
||||
if (m && m[1]) {
|
||||
candidates.push(toStringSafe(m[1]));
|
||||
}
|
||||
}
|
||||
|
||||
for (const candidate of extractToolCallObjects(trimmed)) {
|
||||
candidates.push(toStringSafe(candidate));
|
||||
}
|
||||
|
||||
const first = trimmed.indexOf('{');
|
||||
const last = trimmed.lastIndexOf('}');
|
||||
if (first >= 0 && last > first) {
|
||||
candidates.push(toStringSafe(trimmed.slice(first, last + 1)));
|
||||
}
|
||||
|
||||
const m = trimmed.match(TOOL_CALL_PATTERN);
|
||||
if (m && m[1]) {
|
||||
candidates.push(`{"tool_calls":[${m[1]}]}`);
|
||||
}
|
||||
|
||||
return [...new Set(candidates.filter(Boolean))];
|
||||
}
|
||||
|
||||
function extractToolCallObjects(text) {
|
||||
const raw = toStringSafe(text);
|
||||
if (!raw) {
|
||||
return [];
|
||||
}
|
||||
const lower = raw.toLowerCase();
|
||||
const out = [];
|
||||
let offset = 0;
|
||||
|
||||
// eslint-disable-next-line no-constant-condition
|
||||
while (true) {
|
||||
let idx = lower.indexOf('tool_calls', offset);
|
||||
if (idx < 0) {
|
||||
break;
|
||||
}
|
||||
let start = raw.slice(0, idx).lastIndexOf('{');
|
||||
while (start >= 0) {
|
||||
const obj = extractJSONObjectFrom(raw, start);
|
||||
if (obj.ok) {
|
||||
out.push(raw.slice(start, obj.end).trim());
|
||||
offset = obj.end;
|
||||
idx = -1;
|
||||
break;
|
||||
}
|
||||
start = raw.slice(0, start).lastIndexOf('{');
|
||||
}
|
||||
if (idx >= 0) {
|
||||
offset = idx + 'tool_calls'.length;
|
||||
}
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
function parseToolCallsPayload(payload) {
|
||||
let decoded;
|
||||
try {
|
||||
decoded = JSON.parse(payload);
|
||||
} catch (_err) {
|
||||
return [];
|
||||
}
|
||||
|
||||
if (Array.isArray(decoded)) {
|
||||
return parseToolCallList(decoded);
|
||||
}
|
||||
if (!decoded || typeof decoded !== 'object') {
|
||||
return [];
|
||||
}
|
||||
if (decoded.tool_calls) {
|
||||
return parseToolCallList(decoded.tool_calls);
|
||||
}
|
||||
|
||||
const one = parseToolCallItem(decoded);
|
||||
return one ? [one] : [];
|
||||
}
|
||||
|
||||
function parseToolCallList(v) {
|
||||
if (!Array.isArray(v)) {
|
||||
return [];
|
||||
}
|
||||
const out = [];
|
||||
for (const item of v) {
|
||||
if (!item || typeof item !== 'object') {
|
||||
continue;
|
||||
}
|
||||
const one = parseToolCallItem(item);
|
||||
if (one) {
|
||||
out.push(one);
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function parseToolCallItem(m) {
|
||||
let name = toStringSafe(m.name);
|
||||
let inputRaw = m.input;
|
||||
let hasInput = Object.prototype.hasOwnProperty.call(m, 'input');
|
||||
const fn = m.function && typeof m.function === 'object' ? m.function : null;
|
||||
|
||||
if (fn) {
|
||||
if (!name) {
|
||||
name = toStringSafe(fn.name);
|
||||
}
|
||||
if (!hasInput && Object.prototype.hasOwnProperty.call(fn, 'arguments')) {
|
||||
inputRaw = fn.arguments;
|
||||
hasInput = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!hasInput) {
|
||||
for (const k of ['arguments', 'args', 'parameters', 'params']) {
|
||||
if (Object.prototype.hasOwnProperty.call(m, k)) {
|
||||
inputRaw = m[k];
|
||||
hasInput = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!name) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
name,
|
||||
input: parseToolCallInput(inputRaw),
|
||||
};
|
||||
}
|
||||
|
||||
function parseToolCallInput(v) {
|
||||
if (v == null) {
|
||||
return {};
|
||||
}
|
||||
if (typeof v === 'string') {
|
||||
const raw = toStringSafe(v);
|
||||
if (!raw) {
|
||||
return {};
|
||||
}
|
||||
try {
|
||||
const parsed = JSON.parse(raw);
|
||||
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
|
||||
return parsed;
|
||||
}
|
||||
return { _raw: raw };
|
||||
} catch (_err) {
|
||||
return { _raw: raw };
|
||||
}
|
||||
}
|
||||
if (typeof v === 'object' && !Array.isArray(v)) {
|
||||
return v;
|
||||
}
|
||||
try {
|
||||
const parsed = JSON.parse(JSON.stringify(v));
|
||||
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
|
||||
return parsed;
|
||||
}
|
||||
} catch (_err) {
|
||||
return {};
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
stripFencedCodeBlocks,
|
||||
buildToolCallCandidates,
|
||||
parseToolCallsPayload,
|
||||
};
|
||||
@@ -1,16 +1,12 @@
|
||||
'use strict';
|
||||
|
||||
const {
|
||||
TOOL_SIEVE_CAPTURE_LIMIT,
|
||||
resetIncrementalToolState,
|
||||
noteText,
|
||||
insideCodeFence,
|
||||
} = require('./state');
|
||||
const {
|
||||
buildIncrementalToolDeltas,
|
||||
} = require('./incremental');
|
||||
const {
|
||||
parseStandaloneToolCalls,
|
||||
parseStandaloneToolCallsDetailed,
|
||||
} = require('./parse');
|
||||
const {
|
||||
extractJSONObjectFrom,
|
||||
@@ -24,6 +20,21 @@ function processToolSieveChunk(state, chunk, toolNames) {
|
||||
state.pending += chunk;
|
||||
}
|
||||
const events = [];
|
||||
|
||||
if (Array.isArray(state.pendingToolCalls) && state.pendingToolCalls.length > 0) {
|
||||
const pending = state.pending || '';
|
||||
if (pending.trim() !== '') {
|
||||
const content = (state.pendingToolRaw || '') + pending;
|
||||
state.pending = '';
|
||||
state.pendingToolRaw = '';
|
||||
state.pendingToolCalls = [];
|
||||
noteText(state, content);
|
||||
events.push({ type: 'text', text: content });
|
||||
} else {
|
||||
return events;
|
||||
}
|
||||
}
|
||||
|
||||
// eslint-disable-next-line no-constant-condition
|
||||
while (true) {
|
||||
if (state.capturing) {
|
||||
@@ -31,57 +42,50 @@ function processToolSieveChunk(state, chunk, toolNames) {
|
||||
state.capture += state.pending;
|
||||
state.pending = '';
|
||||
}
|
||||
const deltas = buildIncrementalToolDeltas(state);
|
||||
if (deltas.length > 0) {
|
||||
events.push({ type: 'tool_call_deltas', deltas });
|
||||
}
|
||||
const consumed = consumeToolCapture(state, toolNames);
|
||||
if (!consumed.ready) {
|
||||
if (state.capture.length > TOOL_SIEVE_CAPTURE_LIMIT) {
|
||||
noteText(state, state.capture);
|
||||
events.push({ type: 'text', text: state.capture });
|
||||
state.capture = '';
|
||||
state.capturing = false;
|
||||
resetIncrementalToolState(state);
|
||||
continue;
|
||||
}
|
||||
break;
|
||||
}
|
||||
const captured = state.capture;
|
||||
state.capture = '';
|
||||
state.capturing = false;
|
||||
resetIncrementalToolState(state);
|
||||
|
||||
if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
|
||||
state.pendingToolRaw = captured;
|
||||
state.pendingToolCalls = consumed.calls;
|
||||
continue;
|
||||
}
|
||||
if (consumed.prefix) {
|
||||
noteText(state, consumed.prefix);
|
||||
events.push({ type: 'text', text: consumed.prefix });
|
||||
}
|
||||
if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
|
||||
events.push({ type: 'tool_calls', calls: consumed.calls });
|
||||
}
|
||||
if (consumed.suffix) {
|
||||
state.pending += consumed.suffix;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!state.pending) {
|
||||
const pending = state.pending || '';
|
||||
if (!pending) {
|
||||
break;
|
||||
}
|
||||
|
||||
const start = findToolSegmentStart(state.pending);
|
||||
const start = findToolSegmentStart(pending);
|
||||
if (start >= 0) {
|
||||
const prefix = state.pending.slice(0, start);
|
||||
const prefix = pending.slice(0, start);
|
||||
if (prefix) {
|
||||
noteText(state, prefix);
|
||||
events.push({ type: 'text', text: prefix });
|
||||
}
|
||||
state.capture = state.pending.slice(start);
|
||||
state.pending = '';
|
||||
state.capture += pending.slice(start);
|
||||
state.capturing = true;
|
||||
resetIncrementalToolState(state);
|
||||
continue;
|
||||
}
|
||||
|
||||
const [safe, hold] = splitSafeContentForToolDetection(state.pending);
|
||||
const [safe, hold] = splitSafeContentForToolDetection(pending);
|
||||
if (!safe) {
|
||||
break;
|
||||
}
|
||||
@@ -97,6 +101,13 @@ function flushToolSieve(state, toolNames) {
|
||||
return [];
|
||||
}
|
||||
const events = processToolSieveChunk(state, '', toolNames);
|
||||
|
||||
if (Array.isArray(state.pendingToolCalls) && state.pendingToolCalls.length > 0) {
|
||||
events.push({ type: 'tool_calls', calls: state.pendingToolCalls });
|
||||
state.pendingToolRaw = '';
|
||||
state.pendingToolCalls = [];
|
||||
}
|
||||
|
||||
if (state.capturing) {
|
||||
const consumed = consumeToolCapture(state, toolNames);
|
||||
if (consumed.ready) {
|
||||
@@ -119,11 +130,13 @@ function flushToolSieve(state, toolNames) {
|
||||
state.capturing = false;
|
||||
resetIncrementalToolState(state);
|
||||
}
|
||||
|
||||
if (state.pending) {
|
||||
noteText(state, state.pending);
|
||||
events.push({ type: 'text', text: state.pending });
|
||||
state.pending = '';
|
||||
}
|
||||
|
||||
return events;
|
||||
}
|
||||
|
||||
@@ -163,11 +176,10 @@ function findToolSegmentStart(s) {
|
||||
let offset = 0;
|
||||
// eslint-disable-next-line no-constant-condition
|
||||
while (true) {
|
||||
const keyRel = lower.indexOf('tool_calls', offset);
|
||||
if (keyRel < 0) {
|
||||
const keyIdx = lower.indexOf('tool_calls', offset);
|
||||
if (keyIdx < 0) {
|
||||
return -1;
|
||||
}
|
||||
const keyIdx = keyRel;
|
||||
const start = s.slice(0, keyIdx).lastIndexOf('{');
|
||||
const candidateStart = start >= 0 ? start : keyIdx;
|
||||
if (!insideCodeFence(s.slice(0, candidateStart))) {
|
||||
@@ -178,7 +190,7 @@ function findToolSegmentStart(s) {
|
||||
}
|
||||
|
||||
function consumeToolCapture(state, toolNames) {
|
||||
const captured = state.capture;
|
||||
const captured = state.capture || '';
|
||||
if (!captured) {
|
||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||
}
|
||||
@@ -195,8 +207,10 @@ function consumeToolCapture(state, toolNames) {
|
||||
if (!obj.ok) {
|
||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||
}
|
||||
|
||||
const prefixPart = captured.slice(0, start);
|
||||
const suffixPart = captured.slice(obj.end);
|
||||
|
||||
if (insideCodeFence((state.recentTextTail || '') + prefixPart)) {
|
||||
return {
|
||||
ready: true,
|
||||
@@ -205,18 +219,19 @@ function consumeToolCapture(state, toolNames) {
|
||||
suffix: '',
|
||||
};
|
||||
}
|
||||
const rawParsed = parseStandaloneToolCalls(captured.slice(start, obj.end), []);
|
||||
const parsed = parseStandaloneToolCalls(captured.slice(start, obj.end), toolNames);
|
||||
if (parsed.length === 0) {
|
||||
if (rawParsed.length > 0 && Array.isArray(toolNames) && toolNames.length > 0) {
|
||||
return {
|
||||
ready: true,
|
||||
prefix: prefixPart,
|
||||
calls: [],
|
||||
suffix: suffixPart,
|
||||
};
|
||||
}
|
||||
if (state.toolNameSent) {
|
||||
|
||||
if ((state.recentTextTail || '').trim() !== '' || prefixPart.trim() !== '' || suffixPart.trim() !== '') {
|
||||
return {
|
||||
ready: true,
|
||||
prefix: captured,
|
||||
calls: [],
|
||||
suffix: '',
|
||||
};
|
||||
}
|
||||
|
||||
const parsed = parseStandaloneToolCallsDetailed(captured.slice(start, obj.end), toolNames);
|
||||
if (!Array.isArray(parsed.calls) || parsed.calls.length === 0) {
|
||||
if (parsed.sawToolCallSyntax && parsed.rejectedByPolicy) {
|
||||
return {
|
||||
ready: true,
|
||||
prefix: prefixPart,
|
||||
@@ -231,26 +246,11 @@ function consumeToolCapture(state, toolNames) {
|
||||
suffix: '',
|
||||
};
|
||||
}
|
||||
if (state.toolNameSent) {
|
||||
if (parsed.length > 1) {
|
||||
return {
|
||||
ready: true,
|
||||
prefix: prefixPart,
|
||||
calls: parsed.slice(1),
|
||||
suffix: suffixPart,
|
||||
};
|
||||
}
|
||||
return {
|
||||
ready: true,
|
||||
prefix: prefixPart,
|
||||
calls: [],
|
||||
suffix: suffixPart,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
ready: true,
|
||||
prefix: prefixPart,
|
||||
calls: parsed,
|
||||
calls: parsed.calls,
|
||||
suffix: suffixPart,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
'use strict';
|
||||
|
||||
const TOOL_SIEVE_CAPTURE_LIMIT = 8 * 1024;
|
||||
const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 256;
|
||||
|
||||
function createToolSieveState() {
|
||||
@@ -9,6 +8,9 @@ function createToolSieveState() {
|
||||
capture: '',
|
||||
capturing: false,
|
||||
recentTextTail: '',
|
||||
pendingToolRaw: '',
|
||||
pendingToolCalls: [],
|
||||
disableDeltas: false,
|
||||
toolNameSent: false,
|
||||
toolName: '',
|
||||
toolArgsStart: -1,
|
||||
@@ -19,6 +21,7 @@ function createToolSieveState() {
|
||||
}
|
||||
|
||||
function resetIncrementalToolState(state) {
|
||||
state.disableDeltas = false;
|
||||
state.toolNameSent = false;
|
||||
state.toolName = '';
|
||||
state.toolArgsStart = -1;
|
||||
@@ -78,7 +81,6 @@ function toStringSafe(v) {
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
TOOL_SIEVE_CAPTURE_LIMIT,
|
||||
TOOL_SIEVE_CONTEXT_TAIL_LIMIT,
|
||||
createToolSieveState,
|
||||
resetIncrementalToolState,
|
||||
|
||||
@@ -57,16 +57,20 @@ func NewApp() *App {
|
||||
r.Use(cors)
|
||||
r.Use(timeout(0))
|
||||
|
||||
r.Get("/healthz", func(w http.ResponseWriter, _ *http.Request) {
|
||||
healthzHandler := func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"status":"ok"}`))
|
||||
})
|
||||
r.Get("/readyz", func(w http.ResponseWriter, _ *http.Request) {
|
||||
}
|
||||
readyzHandler := func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"status":"ready"}`))
|
||||
})
|
||||
}
|
||||
r.Get("/healthz", healthzHandler)
|
||||
r.Head("/healthz", healthzHandler)
|
||||
r.Get("/readyz", readyzHandler)
|
||||
r.Head("/readyz", readyzHandler)
|
||||
openai.RegisterRoutes(r, openaiHandler)
|
||||
claude.RegisterRoutes(r, claudeHandler)
|
||||
gemini.RegisterRoutes(r, geminiHandler)
|
||||
|
||||
20
internal/server/router_health_test.go
Normal file
20
internal/server/router_health_test.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHealthEndpointsSupportHEAD(t *testing.T) {
|
||||
app := NewApp()
|
||||
|
||||
for _, path := range []string{"/healthz", "/readyz"} {
|
||||
req := httptest.NewRequest(http.MethodHead, path, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.Router.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected %s HEAD status 200, got %d", path, rec.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,12 @@ func (r *Runner) caseHealthz(ctx context.Context, cc *caseContext) error {
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(resp.Body, &m)
|
||||
cc.assert("status_ok", asString(m["status"]) == "ok", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
|
||||
headResp, headErr := cc.request(ctx, requestSpec{Method: http.MethodHead, Path: "/healthz", Retryable: true})
|
||||
if headErr != nil {
|
||||
return headErr
|
||||
}
|
||||
cc.assert("head_status_200", headResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", headResp.StatusCode))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -29,6 +35,12 @@ func (r *Runner) caseReadyz(ctx context.Context, cc *caseContext) error {
|
||||
var m map[string]any
|
||||
_ = json.Unmarshal(resp.Body, &m)
|
||||
cc.assert("status_ready", asString(m["status"]) == "ready", fmt.Sprintf("body=%s", string(resp.Body)))
|
||||
|
||||
headResp, headErr := cc.request(ctx, requestSpec{Method: http.MethodHead, Path: "/readyz", Retryable: true})
|
||||
if headErr != nil {
|
||||
return headErr
|
||||
}
|
||||
cc.assert("head_status_200", headResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", headResp.StatusCode))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ internal/js/helpers/stream-tool-sieve.js
|
||||
internal/js/helpers/stream-tool-sieve/index.js
|
||||
internal/js/helpers/stream-tool-sieve/state.js
|
||||
internal/js/helpers/stream-tool-sieve/sieve.js
|
||||
internal/js/helpers/stream-tool-sieve/incremental.js
|
||||
internal/js/helpers/stream-tool-sieve/jsonscan.js
|
||||
internal/js/helpers/stream-tool-sieve/parse.js
|
||||
internal/js/helpers/stream-tool-sieve/format.js
|
||||
|
||||
@@ -105,7 +105,6 @@ internal/js/helpers/stream-tool-sieve.js
|
||||
internal/js/helpers/stream-tool-sieve/index.js
|
||||
internal/js/helpers/stream-tool-sieve/state.js
|
||||
internal/js/helpers/stream-tool-sieve/sieve.js
|
||||
internal/js/helpers/stream-tool-sieve/incremental.js
|
||||
internal/js/helpers/stream-tool-sieve/jsonscan.js
|
||||
internal/js/helpers/stream-tool-sieve/parse.js
|
||||
internal/js/helpers/stream-tool-sieve/format.js
|
||||
|
||||
3
tests/compat/expected/toolcalls_allowlist_empty.json
Normal file
3
tests/compat/expected/toolcalls_allowlist_empty.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"calls": []
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"calls": [
|
||||
{
|
||||
"name": "read_file",
|
||||
"input": {
|
||||
"path": "README.MD"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"calls": []
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"calls": []
|
||||
}
|
||||
10
tests/compat/expected/toolcalls_standalone_pure.json
Normal file
10
tests/compat/expected/toolcalls_standalone_pure.json
Normal file
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"calls": [
|
||||
{
|
||||
"name": "read_file",
|
||||
"input": {
|
||||
"path": "README.MD"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
4
tests/compat/fixtures/toolcalls/allowlist_empty.json
Normal file
4
tests/compat/fixtures/toolcalls/allowlist_empty.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"text": "{\"tool_calls\":[{\"name\":\"unknown_tool\",\"input\":{\"x\":1}}]}",
|
||||
"tool_names": []
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"text": "{\"tool_calls\":[{\"name\":\"Read_File\",\"input\":{\"path\":\"README.MD\"}}]}",
|
||||
"tool_names": ["read_file"]
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"mode": "standalone",
|
||||
"text": "```json\n{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}\n```",
|
||||
"tool_names": ["read_file"]
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"mode": "standalone",
|
||||
"text": "下面是示例:{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}请勿执行。",
|
||||
"tool_names": ["read_file"]
|
||||
}
|
||||
5
tests/compat/fixtures/toolcalls/standalone_pure.json
Normal file
5
tests/compat/fixtures/toolcalls/standalone_pure.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"mode": "standalone",
|
||||
"text": "{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}",
|
||||
"tool_names": ["read_file"]
|
||||
}
|
||||
@@ -13,8 +13,10 @@ const {
|
||||
const {
|
||||
parseChunkForContent,
|
||||
resolveToolcallPolicy,
|
||||
formatIncrementalToolCallDeltas,
|
||||
normalizePreparedToolNames,
|
||||
boolDefaultTrue,
|
||||
filterIncrementalToolCallDeltasByAllowed,
|
||||
} = handler.__test;
|
||||
|
||||
test('chat-stream exposes parser test hooks', () => {
|
||||
@@ -56,6 +58,46 @@ test('boolDefaultTrue keeps false only when explicitly false', () => {
|
||||
assert.equal(boolDefaultTrue(undefined), true);
|
||||
});
|
||||
|
||||
test('filterIncrementalToolCallDeltasByAllowed blocks unknown name and follow-up args', () => {
|
||||
const seen = new Map();
|
||||
const filtered = filterIncrementalToolCallDeltasByAllowed(
|
||||
[
|
||||
{ index: 0, name: 'not_in_schema' },
|
||||
{ index: 0, arguments: '{"x":1}' },
|
||||
],
|
||||
['read_file'],
|
||||
seen,
|
||||
);
|
||||
assert.deepEqual(filtered, []);
|
||||
assert.equal(seen.get(0), '__blocked__');
|
||||
});
|
||||
|
||||
test('filterIncrementalToolCallDeltasByAllowed keeps allowed name and args', () => {
|
||||
const seen = new Map();
|
||||
const filtered = filterIncrementalToolCallDeltasByAllowed(
|
||||
[
|
||||
{ index: 0, name: 'read_file' },
|
||||
{ index: 0, arguments: '{"path":"README.MD"}' },
|
||||
],
|
||||
['read_file'],
|
||||
seen,
|
||||
);
|
||||
assert.deepEqual(filtered, [
|
||||
{ index: 0, name: 'read_file' },
|
||||
{ index: 0, arguments: '{"path":"README.MD"}' },
|
||||
]);
|
||||
});
|
||||
|
||||
test('incremental and final tool formatting share stable id via idStore', () => {
|
||||
const idStore = new Map();
|
||||
const incremental = formatIncrementalToolCallDeltas([{ index: 0, name: 'read_file' }], idStore);
|
||||
const { formatOpenAIStreamToolCalls } = require('../../internal/js/helpers/stream-tool-sieve.js');
|
||||
const finalCalls = formatOpenAIStreamToolCalls([{ name: 'read_file', input: { path: 'README.MD' } }], idStore);
|
||||
assert.equal(incremental.length, 1);
|
||||
assert.equal(finalCalls.length, 1);
|
||||
assert.equal(incremental[0].id, finalCalls[0].id);
|
||||
});
|
||||
|
||||
test('parseChunkForContent keeps split response/content fragments inside response array', () => {
|
||||
const chunk = {
|
||||
p: 'response',
|
||||
|
||||
@@ -6,7 +6,7 @@ const fs = require('node:fs');
|
||||
const path = require('node:path');
|
||||
|
||||
const chatStream = require('../../api/chat-stream.js');
|
||||
const { parseToolCalls } = require('../../internal/js/helpers/stream-tool-sieve.js');
|
||||
const { parseToolCalls, parseStandaloneToolCalls } = require('../../internal/js/helpers/stream-tool-sieve.js');
|
||||
|
||||
const { parseChunkForContent, estimateTokens } = chatStream.__test;
|
||||
|
||||
@@ -41,12 +41,14 @@ test('js compat: toolcall fixtures', () => {
|
||||
|
||||
for (const file of files) {
|
||||
const name = file.replace(/\.json$/i, '');
|
||||
const fixture = readJSON(path.join(fixtureDir, file));
|
||||
const expected = readJSON(path.join(expectedDir, `toolcalls_${name}.json`));
|
||||
const got = parseToolCalls(fixture.text, fixture.tool_names || []);
|
||||
assert.deepEqual(got, expected.calls, `${name}: calls mismatch`);
|
||||
}
|
||||
});
|
||||
const fixture = readJSON(path.join(fixtureDir, file));
|
||||
const expected = readJSON(path.join(expectedDir, `toolcalls_${name}.json`));
|
||||
const mode = typeof fixture.mode === 'string' ? fixture.mode.trim().toLowerCase() : '';
|
||||
const parser = mode === 'standalone' ? parseStandaloneToolCalls : parseToolCalls;
|
||||
const got = parser(fixture.text, fixture.tool_names || []);
|
||||
assert.deepEqual(got, expected.calls, `${name}: calls mismatch`);
|
||||
}
|
||||
});
|
||||
|
||||
test('js compat: token fixtures', () => {
|
||||
const fixture = readJSON(path.join(compatRoot, 'fixtures', 'token_cases.json'));
|
||||
|
||||
@@ -9,7 +9,9 @@ const {
|
||||
processToolSieveChunk,
|
||||
flushToolSieve,
|
||||
parseToolCalls,
|
||||
parseToolCallsDetailed,
|
||||
parseStandaloneToolCalls,
|
||||
formatOpenAIStreamToolCalls,
|
||||
} = require('../../internal/js/helpers/stream-tool-sieve.js');
|
||||
|
||||
function runSieve(chunks, toolNames) {
|
||||
@@ -60,13 +62,25 @@ test('parseToolCalls drops unknown schema names when toolNames is provided', ()
|
||||
assert.equal(calls.length, 0);
|
||||
});
|
||||
|
||||
test('parseToolCalls keeps unknown names when toolNames is empty', () => {
|
||||
test('parseToolCalls matches tool name case-insensitively and canonicalizes', () => {
|
||||
const payload = JSON.stringify({
|
||||
tool_calls: [{ name: 'Read_File', input: { path: 'README.MD' } }],
|
||||
});
|
||||
const calls = parseToolCalls(payload, ['read_file']);
|
||||
assert.deepEqual(calls, [{ name: 'read_file', input: { path: 'README.MD' } }]);
|
||||
});
|
||||
|
||||
test('parseToolCalls rejects all names when toolNames is empty (Go strict parity)', () => {
|
||||
const payload = JSON.stringify({
|
||||
tool_calls: [{ name: 'not_in_schema', input: { q: 'go' } }],
|
||||
});
|
||||
const calls = parseToolCalls(payload, []);
|
||||
assert.equal(calls.length, 1);
|
||||
assert.equal(calls[0].name, 'not_in_schema');
|
||||
assert.equal(calls.length, 0);
|
||||
|
||||
const detailed = parseToolCallsDetailed(payload, []);
|
||||
assert.equal(detailed.sawToolCallSyntax, true);
|
||||
assert.equal(detailed.rejectedByPolicy, true);
|
||||
assert.deepEqual(detailed.rejectedToolNames, ['not_in_schema']);
|
||||
});
|
||||
|
||||
test('parseToolCalls supports fenced json and function.arguments string payload', () => {
|
||||
@@ -95,7 +109,7 @@ test('parseStandaloneToolCalls ignores fenced code block tool_call examples', ()
|
||||
assert.equal(calls.length, 0);
|
||||
});
|
||||
|
||||
test('sieve emits tool_calls and does not leak suspicious prefix on late key convergence', () => {
|
||||
test('sieve keeps late key convergence payload as plain text in strict mode', () => {
|
||||
const events = runSieve(
|
||||
[
|
||||
'{"',
|
||||
@@ -107,9 +121,9 @@ test('sieve emits tool_calls and does not leak suspicious prefix on late key con
|
||||
const leakedText = collectText(events);
|
||||
const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0);
|
||||
const hasToolDelta = events.some((evt) => evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0);
|
||||
assert.equal(hasToolCall || hasToolDelta, true);
|
||||
assert.equal(leakedText.includes('{'), false);
|
||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
||||
assert.equal(hasToolCall || hasToolDelta, false);
|
||||
assert.equal(leakedText.includes('{'), true);
|
||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), true);
|
||||
assert.equal(leakedText.includes('后置正文C。'), true);
|
||||
});
|
||||
|
||||
@@ -141,6 +155,20 @@ test('sieve flushes incomplete captured tool json as text on stream finalize', (
|
||||
assert.equal(leakedText.includes('{'), true);
|
||||
});
|
||||
|
||||
test('sieve still intercepts large tool json payloads over previous capture limit', () => {
|
||||
const large = 'a'.repeat(9000);
|
||||
const payload = `{"tool_calls":[{"name":"read_file","input":{"path":"${large}"}}]}`;
|
||||
const events = runSieve(
|
||||
[payload.slice(0, 3000), payload.slice(3000, 7000), payload.slice(7000)],
|
||||
['read_file'],
|
||||
);
|
||||
const leakedText = collectText(events);
|
||||
const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && evt.calls?.length > 0);
|
||||
const hasToolDelta = events.some((evt) => evt.type === 'tool_call_deltas' && evt.deltas?.length > 0);
|
||||
assert.equal(hasToolCall || hasToolDelta, true);
|
||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
||||
});
|
||||
|
||||
test('sieve keeps plain text intact in tool mode when no tool call appears', () => {
|
||||
const events = runSieve(
|
||||
['你好,', '这是普通文本回复。', '请继续。'],
|
||||
@@ -166,7 +194,7 @@ test('sieve intercepts rejected unknown tool payload (no args) without raw leak'
|
||||
assert.equal(leakedText.includes('后置正文G。'), true);
|
||||
});
|
||||
|
||||
test('sieve emits incremental tool_call_deltas for split arguments payload', () => {
|
||||
test('sieve emits final tool_calls for split arguments payload without incremental deltas', () => {
|
||||
const state = createToolSieveState();
|
||||
const first = processToolSieveChunk(
|
||||
state,
|
||||
@@ -181,37 +209,43 @@ test('sieve emits incremental tool_call_deltas for split arguments payload', ()
|
||||
const tail = flushToolSieve(state, ['read_file']);
|
||||
const events = [...first, ...second, ...tail];
|
||||
const deltaEvents = events.filter((evt) => evt.type === 'tool_call_deltas');
|
||||
assert.equal(deltaEvents.length > 0, true);
|
||||
const merged = deltaEvents.flatMap((evt) => evt.deltas || []);
|
||||
const hasName = merged.some((d) => d.name === 'read_file');
|
||||
const argsJoined = merged
|
||||
.map((d) => d.arguments || '')
|
||||
.join('');
|
||||
assert.equal(hasName, true);
|
||||
assert.equal(argsJoined.includes('"path":"README.MD"'), true);
|
||||
assert.equal(argsJoined.includes('"mode":"head"'), true);
|
||||
assert.equal(deltaEvents.length, 0);
|
||||
const finalCalls = events.filter((evt) => evt.type === 'tool_calls').flatMap((evt) => evt.calls || []);
|
||||
assert.equal(finalCalls.length, 1);
|
||||
assert.equal(finalCalls[0].name, 'read_file');
|
||||
assert.deepEqual(finalCalls[0].input, { path: 'README.MD', mode: 'head' });
|
||||
});
|
||||
|
||||
test('sieve still intercepts tool call after leading plain text without suffix', () => {
|
||||
test('sieve keeps tool json as text when leading prose exists (strict mode)', () => {
|
||||
const events = runSieve(
|
||||
['我将调用工具。', '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}'],
|
||||
['read_file'],
|
||||
);
|
||||
const hasTool = events.some((evt) => (evt.type === 'tool_calls' && evt.calls?.length > 0) || (evt.type === 'tool_call_deltas' && evt.deltas?.length > 0));
|
||||
const leakedText = collectText(events);
|
||||
assert.equal(hasTool, true);
|
||||
assert.equal(hasTool, false);
|
||||
assert.equal(leakedText.includes('我将调用工具。'), true);
|
||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), true);
|
||||
});
|
||||
|
||||
test('sieve intercepts tool call and preserves trailing same-chunk text', () => {
|
||||
test('sieve keeps same-chunk trailing prose payload as text in strict mode', () => {
|
||||
const events = runSieve(
|
||||
['{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}然后继续解释。'],
|
||||
['read_file'],
|
||||
);
|
||||
const hasTool = events.some((evt) => (evt.type === 'tool_calls' && evt.calls?.length > 0) || (evt.type === 'tool_call_deltas' && evt.deltas?.length > 0));
|
||||
const leakedText = collectText(events);
|
||||
assert.equal(hasTool, true);
|
||||
assert.equal(hasTool, false);
|
||||
assert.equal(leakedText.includes('然后继续解释。'), true);
|
||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), false);
|
||||
assert.equal(leakedText.toLowerCase().includes('tool_calls'), true);
|
||||
});
|
||||
|
||||
test('formatOpenAIStreamToolCalls reuses ids with the same idStore', () => {
|
||||
const idStore = new Map();
|
||||
const calls = [{ name: 'read_file', input: { path: 'README.MD' } }];
|
||||
const first = formatOpenAIStreamToolCalls(calls, idStore);
|
||||
const second = formatOpenAIStreamToolCalls(calls, idStore);
|
||||
assert.equal(first.length, 1);
|
||||
assert.equal(second.length, 1);
|
||||
assert.equal(first[0].id, second[0].id);
|
||||
});
|
||||
|
||||
@@ -24,9 +24,8 @@
|
||||
<meta name="apple-mobile-web-app-status-bar-style" content="black-translucent" />
|
||||
<meta name="apple-mobile-web-app-title" content="DS2API" />
|
||||
|
||||
<!-- Favicon - using data URI for orange-yellow gradient icon -->
|
||||
<link rel="icon" type="image/svg+xml"
|
||||
href="data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'%3E%3Cdefs%3E%3ClinearGradient id='g' x1='0%25' y1='0%25' x2='100%25' y2='100%25'%3E%3Cstop offset='0%25' stop-color='%23f59e0b'/%3E%3Cstop offset='100%25' stop-color='%23ef4444'/%3E%3C/linearGradient%3E%3C/defs%3E%3Crect rx='20' width='100' height='100' fill='url(%23g)'/%3E%3Ctext x='50' y='68' font-family='Arial,sans-serif' font-size='48' font-weight='bold' fill='white' text-anchor='middle'%3EDS%3C/text%3E%3C/svg%3E" />
|
||||
<!-- Favicon -->
|
||||
<link rel="icon" type="image/svg+xml" href="/ds2api-favicon.svg" />
|
||||
|
||||
<!-- Fonts -->
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||
|
||||
20
webui/public/ds2api-favicon.svg
Normal file
20
webui/public/ds2api-favicon.svg
Normal file
@@ -0,0 +1,20 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 100 100" role="img" aria-label="DS2API icon">
|
||||
<defs>
|
||||
<linearGradient id="g" x1="0%" y1="0%" x2="100%" y2="100%">
|
||||
<stop offset="0%" stop-color="#f59e0b" />
|
||||
<stop offset="100%" stop-color="#ef4444" />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
<rect width="100" height="100" rx="20" fill="url(#g)" />
|
||||
<text
|
||||
x="50"
|
||||
y="68"
|
||||
text-anchor="middle"
|
||||
font-family="Arial,sans-serif"
|
||||
font-size="48"
|
||||
font-weight="700"
|
||||
fill="#ffffff"
|
||||
>
|
||||
DS
|
||||
</text>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 539 B |
Reference in New Issue
Block a user