mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-14 13:15:07 +08:00
714 lines
15 KiB
Go
714 lines
15 KiB
Go
package openai
|
|
|
|
import (
|
|
"strings"
|
|
|
|
"ds2api/internal/util"
|
|
)
|
|
|
|
// toolStreamSieveState carries the buffers and flags the tool-call sieve keeps
// between streamed chunks.
type toolStreamSieveState struct {
	// pending collects streamed text that has not yet been emitted or moved
	// into capture.
	pending strings.Builder
	// capture collects text from the start of a suspected tool-call segment.
	capture strings.Builder
	// capturing is true while incoming text is routed into capture.
	capturing bool
	// recentTextTail is the bounded tail of text recorded through noteText; it
	// is consulted for code-fence detection.
	recentTextTail string
	// disableDeltas turns off incremental deltas for the current capture.
	disableDeltas bool
	// toolNameSent records that the name delta was already produced.
	toolNameSent bool
	// toolName is the tool name extracted from the current capture.
	toolName string
	// toolArgsStart is the offset in capture where argument content begins,
	// or -1 while unknown.
	toolArgsStart int
	// toolArgsSent is the offset in capture up to which argument content has
	// been emitted, or -1 while unknown.
	toolArgsSent int
	// toolArgsString is true when the arguments value is a JSON string.
	toolArgsString bool
	// toolArgsDone is true once the arguments value has been fully emitted.
	toolArgsDone bool
}
|
|
|
|
type toolStreamEvent struct {
|
|
Content string
|
|
ToolCalls []util.ParsedToolCall
|
|
ToolCallDeltas []toolCallDelta
|
|
}
|
|
|
|
// toolCallDelta is an incremental fragment of a single tool call.
type toolCallDelta struct {
	// Index is the position of the tool call the fragment belongs to.
	Index int
	// Name is the tool name; it is set on the delta that announces the call.
	Name string
	// Arguments is a raw slice of the argument text taken from the capture.
	Arguments string
}
|
|
|
|
const (
	// toolSieveCaptureLimit is the capture size in bytes above which an
	// unparseable capture is released as plain text.
	toolSieveCaptureLimit = 8 << 10
	// toolSieveContextTailLimit is the maximum number of bytes of recent text
	// retained for code-fence detection.
	toolSieveContextTailLimit = 256
)
|
|
|
|
func (s *toolStreamSieveState) resetIncrementalToolState() {
|
|
s.disableDeltas = false
|
|
s.toolNameSent = false
|
|
s.toolName = ""
|
|
s.toolArgsStart = -1
|
|
s.toolArgsSent = -1
|
|
s.toolArgsString = false
|
|
s.toolArgsDone = false
|
|
}
|
|
|
|
func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent {
|
|
if state == nil {
|
|
return nil
|
|
}
|
|
if chunk != "" {
|
|
state.pending.WriteString(chunk)
|
|
}
|
|
events := make([]toolStreamEvent, 0, 2)
|
|
|
|
for {
|
|
if state.capturing {
|
|
if state.pending.Len() > 0 {
|
|
state.capture.WriteString(state.pending.String())
|
|
state.pending.Reset()
|
|
}
|
|
if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 {
|
|
events = append(events, toolStreamEvent{ToolCallDeltas: deltas})
|
|
}
|
|
prefix, calls, suffix, ready := consumeToolCapture(state, toolNames)
|
|
if !ready {
|
|
if state.capture.Len() > toolSieveCaptureLimit {
|
|
content := state.capture.String()
|
|
state.capture.Reset()
|
|
state.capturing = false
|
|
state.resetIncrementalToolState()
|
|
state.noteText(content)
|
|
events = append(events, toolStreamEvent{Content: content})
|
|
continue
|
|
}
|
|
break
|
|
}
|
|
state.capture.Reset()
|
|
state.capturing = false
|
|
state.resetIncrementalToolState()
|
|
if prefix != "" {
|
|
state.noteText(prefix)
|
|
events = append(events, toolStreamEvent{Content: prefix})
|
|
}
|
|
if len(calls) > 0 {
|
|
events = append(events, toolStreamEvent{ToolCalls: calls})
|
|
}
|
|
if suffix != "" {
|
|
state.pending.WriteString(suffix)
|
|
}
|
|
continue
|
|
}
|
|
|
|
pending := state.pending.String()
|
|
if pending == "" {
|
|
break
|
|
}
|
|
start := findToolSegmentStart(pending)
|
|
if start >= 0 {
|
|
prefix := pending[:start]
|
|
if prefix != "" {
|
|
state.noteText(prefix)
|
|
events = append(events, toolStreamEvent{Content: prefix})
|
|
}
|
|
state.pending.Reset()
|
|
state.capture.WriteString(pending[start:])
|
|
state.capturing = true
|
|
state.resetIncrementalToolState()
|
|
continue
|
|
}
|
|
|
|
safe, hold := splitSafeContentForToolDetection(pending)
|
|
if safe == "" {
|
|
break
|
|
}
|
|
state.pending.Reset()
|
|
state.pending.WriteString(hold)
|
|
state.noteText(safe)
|
|
events = append(events, toolStreamEvent{Content: safe})
|
|
}
|
|
|
|
return events
|
|
}
|
|
|
|
func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStreamEvent {
|
|
if state == nil {
|
|
return nil
|
|
}
|
|
events := processToolSieveChunk(state, "", toolNames)
|
|
if state.capturing {
|
|
consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames)
|
|
if ready {
|
|
if consumedPrefix != "" {
|
|
state.noteText(consumedPrefix)
|
|
events = append(events, toolStreamEvent{Content: consumedPrefix})
|
|
}
|
|
if len(consumedCalls) > 0 {
|
|
events = append(events, toolStreamEvent{ToolCalls: consumedCalls})
|
|
}
|
|
if consumedSuffix != "" {
|
|
state.noteText(consumedSuffix)
|
|
events = append(events, toolStreamEvent{Content: consumedSuffix})
|
|
}
|
|
} else {
|
|
content := state.capture.String()
|
|
if content != "" {
|
|
state.noteText(content)
|
|
events = append(events, toolStreamEvent{Content: content})
|
|
}
|
|
}
|
|
state.capture.Reset()
|
|
state.capturing = false
|
|
state.resetIncrementalToolState()
|
|
}
|
|
if state.pending.Len() > 0 {
|
|
content := state.pending.String()
|
|
state.noteText(content)
|
|
events = append(events, toolStreamEvent{Content: content})
|
|
state.pending.Reset()
|
|
}
|
|
return events
|
|
}
|
|
|
|
func splitSafeContentForToolDetection(s string) (safe, hold string) {
|
|
if s == "" {
|
|
return "", ""
|
|
}
|
|
suspiciousStart := findSuspiciousPrefixStart(s)
|
|
if suspiciousStart < 0 {
|
|
return s, ""
|
|
}
|
|
if suspiciousStart > 0 {
|
|
return s[:suspiciousStart], s[suspiciousStart:]
|
|
}
|
|
// If suspicious content starts at position 0, keep holding until we can
|
|
// parse a complete tool JSON block or reach stream flush.
|
|
return "", s
|
|
}
|
|
|
|
// findSuspiciousPrefixStart returns the largest byte offset at which the last
// occurrence of "{", "[" or a triple backtick begins in s, or -1 when none of
// these markers is present.
func findSuspiciousPrefixStart(s string) int {
	latest := -1
	for _, marker := range [...]string{"{", "[", "```"} {
		if at := strings.LastIndex(s, marker); at > latest {
			latest = at
		}
	}
	return latest
}
|
|
|
|
func findToolSegmentStart(s string) int {
|
|
if s == "" {
|
|
return -1
|
|
}
|
|
lower := strings.ToLower(s)
|
|
offset := 0
|
|
for {
|
|
keyRel := strings.Index(lower[offset:], "tool_calls")
|
|
if keyRel < 0 {
|
|
return -1
|
|
}
|
|
keyIdx := offset + keyRel
|
|
start := strings.LastIndex(s[:keyIdx], "{")
|
|
if start < 0 {
|
|
start = keyIdx
|
|
}
|
|
if !insideCodeFence(s[:start]) {
|
|
return start
|
|
}
|
|
offset = keyIdx + len("tool_calls")
|
|
}
|
|
}
|
|
|
|
func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
|
|
captured := state.capture.String()
|
|
if captured == "" {
|
|
return "", nil, "", false
|
|
}
|
|
lower := strings.ToLower(captured)
|
|
keyIdx := strings.Index(lower, "tool_calls")
|
|
if keyIdx < 0 {
|
|
return "", nil, "", false
|
|
}
|
|
start := strings.LastIndex(captured[:keyIdx], "{")
|
|
if start < 0 {
|
|
return "", nil, "", false
|
|
}
|
|
obj, end, ok := extractJSONObjectFrom(captured, start)
|
|
if !ok {
|
|
return "", nil, "", false
|
|
}
|
|
prefixPart := captured[:start]
|
|
suffixPart := captured[end:]
|
|
if insideCodeFence(state.recentTextTail + prefixPart) {
|
|
return captured, nil, "", true
|
|
}
|
|
parsed := util.ParseStandaloneToolCalls(obj, toolNames)
|
|
if len(parsed) == 0 {
|
|
return captured, nil, "", true
|
|
}
|
|
return prefixPart, parsed, suffixPart, true
|
|
}
|
|
|
|
// extractJSONObjectFrom scans text from start, which must index a '{', and
// returns the balanced object beginning there together with the offset just
// past its closing '}'. Braces inside double- or single-quoted strings are
// ignored and a backslash escapes the next byte inside a string. The boolean
// is false when start is invalid or the object is not closed within text.
func extractJSONObjectFrom(text string, start int) (string, int, bool) {
	if start < 0 || start >= len(text) || text[start] != '{' {
		return "", 0, false
	}

	var (
		nesting  int
		inQuote  byte // active quote character, 0 when outside a string
		skipNext bool // previous byte was a backslash inside a string
	)
	for pos := start; pos < len(text); pos++ {
		c := text[pos]
		switch {
		case inQuote != 0:
			switch {
			case skipNext:
				skipNext = false
			case c == '\\':
				skipNext = true
			case c == inQuote:
				inQuote = 0
			}
		case c == '"' || c == '\'':
			inQuote = c
		case c == '{':
			nesting++
		case c == '}':
			nesting--
			if nesting == 0 {
				return text[start : pos+1], pos + 1, true
			}
		}
	}
	return "", 0, false
}
|
|
|
|
func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
|
|
if state.disableDeltas {
|
|
return nil
|
|
}
|
|
captured := state.capture.String()
|
|
if captured == "" {
|
|
return nil
|
|
}
|
|
lower := strings.ToLower(captured)
|
|
keyIdx := strings.Index(lower, "tool_calls")
|
|
if keyIdx < 0 {
|
|
return nil
|
|
}
|
|
start := strings.LastIndex(captured[:keyIdx], "{")
|
|
if start < 0 {
|
|
return nil
|
|
}
|
|
if insideCodeFence(state.recentTextTail + captured[:start]) {
|
|
return nil
|
|
}
|
|
certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx)
|
|
if hasMultiple {
|
|
state.disableDeltas = true
|
|
return nil
|
|
}
|
|
if !certainSingle {
|
|
// In uncertain phases (e.g. first call arrived but array not closed yet),
|
|
// avoid speculative deltas and wait for final parsed tool_calls payload.
|
|
return nil
|
|
}
|
|
callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
deltas := make([]toolCallDelta, 0, 2)
|
|
if state.toolName == "" {
|
|
name, ok := extractToolCallName(captured, callStart)
|
|
if !ok || name == "" {
|
|
return nil
|
|
}
|
|
state.toolName = name
|
|
}
|
|
if state.toolArgsStart < 0 {
|
|
argsStart, stringMode, ok := findToolCallArgsStart(captured, callStart)
|
|
if ok {
|
|
state.toolArgsString = stringMode
|
|
if stringMode {
|
|
state.toolArgsStart = argsStart + 1
|
|
} else {
|
|
state.toolArgsStart = argsStart
|
|
}
|
|
state.toolArgsSent = state.toolArgsStart
|
|
}
|
|
}
|
|
if !state.toolNameSent {
|
|
if state.toolArgsStart < 0 {
|
|
return nil
|
|
}
|
|
state.toolNameSent = true
|
|
deltas = append(deltas, toolCallDelta{Index: 0, Name: state.toolName})
|
|
}
|
|
if state.toolArgsStart < 0 || state.toolArgsDone {
|
|
return deltas
|
|
}
|
|
end, complete, ok := scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString)
|
|
if !ok {
|
|
return deltas
|
|
}
|
|
if end > state.toolArgsSent {
|
|
deltas = append(deltas, toolCallDelta{
|
|
Index: 0,
|
|
Arguments: captured[state.toolArgsSent:end],
|
|
})
|
|
state.toolArgsSent = end
|
|
}
|
|
if complete {
|
|
state.toolArgsDone = true
|
|
}
|
|
return deltas
|
|
}
|
|
|
|
func classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) {
|
|
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
|
|
if !ok {
|
|
return false, false
|
|
}
|
|
i := skipSpaces(text, arrStart+1)
|
|
if i >= len(text) || text[i] != '{' {
|
|
return false, false
|
|
}
|
|
count := 0
|
|
depth := 0
|
|
quote := byte(0)
|
|
escaped := false
|
|
for ; i < len(text); i++ {
|
|
ch := text[i]
|
|
if quote != 0 {
|
|
if escaped {
|
|
escaped = false
|
|
continue
|
|
}
|
|
if ch == '\\' {
|
|
escaped = true
|
|
continue
|
|
}
|
|
if ch == quote {
|
|
quote = 0
|
|
}
|
|
continue
|
|
}
|
|
if ch == '"' || ch == '\'' {
|
|
quote = ch
|
|
continue
|
|
}
|
|
if ch == '{' {
|
|
if depth == 0 {
|
|
count++
|
|
if count > 1 {
|
|
return false, true
|
|
}
|
|
}
|
|
depth++
|
|
continue
|
|
}
|
|
if ch == '}' {
|
|
if depth > 0 {
|
|
depth--
|
|
}
|
|
continue
|
|
}
|
|
if ch == ',' && depth == 0 {
|
|
// top-level separator means at least one more tool call exists
|
|
// (or is expected). Treat as multi-call and stop incremental deltas.
|
|
return false, true
|
|
}
|
|
if ch == ']' && depth == 0 {
|
|
return count == 1, false
|
|
}
|
|
}
|
|
// array not closed yet: still uncertain whether more calls will appear
|
|
return false, false
|
|
}
|
|
|
|
func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) {
|
|
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
|
|
if !ok {
|
|
return -1, false
|
|
}
|
|
i := skipSpaces(text, arrStart+1)
|
|
if i >= len(text) || text[i] != '{' {
|
|
return -1, false
|
|
}
|
|
return i, true
|
|
}
|
|
|
|
func findToolCallsArrayStart(text string, keyIdx int) (int, bool) {
|
|
i := keyIdx + len("tool_calls")
|
|
for i < len(text) && text[i] != ':' {
|
|
i++
|
|
}
|
|
if i >= len(text) {
|
|
return -1, false
|
|
}
|
|
i = skipSpaces(text, i+1)
|
|
if i >= len(text) || text[i] != '[' {
|
|
return -1, false
|
|
}
|
|
return i, true
|
|
}
|
|
|
|
func extractToolCallName(text string, callStart int) (string, bool) {
|
|
valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"name"})
|
|
if !ok || valueStart >= len(text) || text[valueStart] != '"' {
|
|
fnStart, fnOK := findFunctionObjectStart(text, callStart)
|
|
if !fnOK {
|
|
return "", false
|
|
}
|
|
valueStart, ok = findObjectFieldValueStart(text, fnStart, []string{"name"})
|
|
if !ok || valueStart >= len(text) || text[valueStart] != '"' {
|
|
return "", false
|
|
}
|
|
}
|
|
name, _, ok := parseJSONStringLiteral(text, valueStart)
|
|
if !ok {
|
|
return "", false
|
|
}
|
|
return name, true
|
|
}
|
|
|
|
func findToolCallArgsStart(text string, callStart int) (int, bool, bool) {
|
|
keys := []string{"input", "arguments", "args", "parameters", "params"}
|
|
valueStart, ok := findObjectFieldValueStart(text, callStart, keys)
|
|
if !ok {
|
|
fnStart, fnOK := findFunctionObjectStart(text, callStart)
|
|
if !fnOK {
|
|
return -1, false, false
|
|
}
|
|
valueStart, ok = findObjectFieldValueStart(text, fnStart, keys)
|
|
if !ok {
|
|
return -1, false, false
|
|
}
|
|
}
|
|
if valueStart >= len(text) {
|
|
return -1, false, false
|
|
}
|
|
ch := text[valueStart]
|
|
if ch == '{' || ch == '[' {
|
|
return valueStart, false, true
|
|
}
|
|
if ch == '"' {
|
|
return valueStart, true, true
|
|
}
|
|
return -1, false, false
|
|
}
|
|
|
|
// scanToolCallArgsProgress measures how much of an arguments value is
// available in text. It returns the end offset of the usable argument bytes,
// whether the value is complete, and whether the scan was valid.
//
// In string mode start points just past the opening quote; the end offset is
// the closing quote itself when found, otherwise len(text). In object/array
// mode start must index '{' or '['; the end offset is just past the matching
// closer when found, otherwise len(text).
func scanToolCallArgsProgress(text string, start int, stringMode bool) (int, bool, bool) {
	if start < 0 || start > len(text) {
		return 0, false, false
	}

	if stringMode {
		skipNext := false
		for pos := start; pos < len(text); pos++ {
			switch c := text[pos]; {
			case skipNext:
				skipNext = false
			case c == '\\':
				skipNext = true
			case c == '"':
				return pos, true, true
			}
		}
		return len(text), false, true
	}

	if start >= len(text) {
		return start, false, false
	}
	if first := text[start]; first != '{' && first != '[' {
		return 0, false, false
	}

	var (
		nesting  int
		inQuote  byte // active quote character, 0 when outside a string
		skipNext bool // previous byte was a backslash inside a string
	)
	for pos := start; pos < len(text); pos++ {
		c := text[pos]
		switch {
		case inQuote != 0:
			switch {
			case skipNext:
				skipNext = false
			case c == '\\':
				skipNext = true
			case c == inQuote:
				inQuote = 0
			}
		case c == '"' || c == '\'':
			inQuote = c
		case c == '{' || c == '[':
			nesting++
		case c == '}' || c == ']':
			nesting--
			if nesting == 0 {
				return pos + 1, true, true
			}
		}
	}
	return len(text), false, true
}
|
|
|
|
func findObjectFieldValueStart(text string, objStart int, keys []string) (int, bool) {
|
|
if objStart < 0 || objStart >= len(text) || text[objStart] != '{' {
|
|
return 0, false
|
|
}
|
|
depth := 0
|
|
quote := byte(0)
|
|
escaped := false
|
|
for i := objStart; i < len(text); i++ {
|
|
ch := text[i]
|
|
if quote != 0 {
|
|
if escaped {
|
|
escaped = false
|
|
continue
|
|
}
|
|
if ch == '\\' {
|
|
escaped = true
|
|
continue
|
|
}
|
|
if ch == quote {
|
|
quote = 0
|
|
}
|
|
continue
|
|
}
|
|
if ch == '"' || ch == '\'' {
|
|
if depth == 1 {
|
|
key, end, ok := parseJSONStringLiteral(text, i)
|
|
if !ok {
|
|
return 0, false
|
|
}
|
|
j := skipSpaces(text, end)
|
|
if j >= len(text) || text[j] != ':' {
|
|
i = end - 1
|
|
continue
|
|
}
|
|
j = skipSpaces(text, j+1)
|
|
if j >= len(text) {
|
|
return 0, false
|
|
}
|
|
if containsKey(keys, key) {
|
|
return j, true
|
|
}
|
|
i = j - 1
|
|
continue
|
|
}
|
|
quote = ch
|
|
continue
|
|
}
|
|
if ch == '{' {
|
|
depth++
|
|
continue
|
|
}
|
|
if ch == '}' {
|
|
depth--
|
|
if depth == 0 {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return 0, false
|
|
}
|
|
|
|
func findFunctionObjectStart(text string, callStart int) (int, bool) {
|
|
valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"function"})
|
|
if !ok || valueStart >= len(text) || text[valueStart] != '{' {
|
|
return -1, false
|
|
}
|
|
return valueStart, true
|
|
}
|
|
|
|
// parseJSONStringLiteral decodes the double-quoted JSON string that begins at
// text[start]. It returns the decoded value, the offset just past the closing
// quote, and whether a complete literal was found.
//
// The escapes \b \f \n \r \t and \uXXXX (including UTF-16 surrogate pairs) are
// decoded to the characters they denote. Any other escaped byte, such as
// \" \\ or \/, stands for itself. A malformed \u sequence is kept leniently as
// a literal 'u' followed by the remaining bytes.
func parseJSONStringLiteral(text string, start int) (string, int, bool) {
	if start < 0 || start >= len(text) || text[start] != '"' {
		return "", 0, false
	}

	var b strings.Builder
	for i := start + 1; i < len(text); i++ {
		ch := text[i]
		switch ch {
		case '"':
			return b.String(), i + 1, true
		case '\\':
			i++
			if i >= len(text) {
				// Dangling backslash: the literal is not complete yet.
				return "", 0, false
			}
			switch esc := text[i]; esc {
			case 'b':
				b.WriteByte('\b')
			case 'f':
				b.WriteByte('\f')
			case 'n':
				b.WriteByte('\n')
			case 'r':
				b.WriteByte('\r')
			case 't':
				b.WriteByte('\t')
			case 'u':
				if r, n, ok := decodeJSONUnicodeEscape(text, i+1); ok {
					b.WriteRune(r)
					i += n
				} else {
					b.WriteByte('u')
				}
			default:
				b.WriteByte(esc)
			}
		default:
			b.WriteByte(ch)
		}
	}
	return "", 0, false
}

// decodeJSONUnicodeEscape decodes the hex digits that follow "\u" at text[pos].
// It returns the rune, the number of bytes consumed starting at pos (4, or 10
// when a surrogate pair was combined), and whether four hex digits were
// present. Unpaired surrogates decode to U+FFFD.
func decodeJSONUnicodeEscape(text string, pos int) (rune, int, bool) {
	hi, ok := parseJSONHex4(text, pos)
	if !ok {
		return 0, 0, false
	}
	switch {
	case hi >= 0xD800 && hi <= 0xDBFF:
		if pos+6 <= len(text) && text[pos+4] == '\\' && text[pos+5] == 'u' {
			if lo, ok := parseJSONHex4(text, pos+6); ok && lo >= 0xDC00 && lo <= 0xDFFF {
				return 0x10000 + (hi-0xD800)<<10 + (lo - 0xDC00), 10, true
			}
		}
		return '\uFFFD', 4, true
	case hi >= 0xDC00 && hi <= 0xDFFF:
		return '\uFFFD', 4, true
	default:
		return hi, 4, true
	}
}

// parseJSONHex4 reads exactly four hexadecimal digits at text[pos].
func parseJSONHex4(text string, pos int) (rune, bool) {
	if pos < 0 || pos+4 > len(text) {
		return 0, false
	}
	var r rune
	for i := pos; i < pos+4; i++ {
		c := text[i]
		var d byte
		switch {
		case c >= '0' && c <= '9':
			d = c - '0'
		case c >= 'a' && c <= 'f':
			d = c - 'a' + 10
		case c >= 'A' && c <= 'F':
			d = c - 'A' + 10
		default:
			return 0, false
		}
		r = r<<4 | rune(d)
	}
	return r, true
}
|
|
|
|
// containsKey reports whether value equals one of the entries of keys
// (exact, case-sensitive comparison).
func containsKey(keys []string, value string) bool {
	for i := range keys {
		if keys[i] == value {
			return true
		}
	}
	return false
}
|
|
|
|
// skipSpaces returns the first offset at or after i whose byte is not a
// space, tab, line feed or carriage return; it returns i unchanged when i is
// already at or past the end of text.
func skipSpaces(text string, i int) int {
	for ; i < len(text); i++ {
		if c := text[i]; c != ' ' && c != '\t' && c != '\n' && c != '\r' {
			break
		}
	}
	return i
}
|
|
|
|
func (s *toolStreamSieveState) noteText(content string) {
|
|
if strings.TrimSpace(content) == "" {
|
|
return
|
|
}
|
|
s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit)
|
|
}
|
|
|
|
// appendTail concatenates prev and next and keeps at most the last max bytes
// of the result. A non-positive max yields the empty string.
func appendTail(prev, next string, max int) string {
	if max <= 0 {
		return ""
	}
	joined := prev + next
	if excess := len(joined) - max; excess > 0 {
		return joined[excess:]
	}
	return joined
}
|
|
|
|
func looksLikeToolExampleContext(text string) bool {
|
|
return insideCodeFence(text)
|
|
}
|
|
|
|
// insideCodeFence reports whether text contains an odd number of
// non-overlapping triple-backtick markers, i.e. whether a fence opened in
// text has not been closed by its end.
func insideCodeFence(text string) bool {
	fences := strings.Count(text, "```")
	return fences&1 == 1
}
|