Files
ds2api/internal/adapter/openai/tool_sieve.go

714 lines
15 KiB
Go

package openai
import (
"strings"
"ds2api/internal/util"
)
// toolStreamSieveState is the mutable state of the streaming tool-call sieve.
// It buffers streamed text, decides which parts are plain content and which
// belong to a candidate "tool_calls" JSON block, and tracks how much of a
// single tool call has already been emitted as incremental deltas.
//
// resetIncrementalToolState sets the tool* fields to their "nothing found"
// values; processToolSieveChunk calls it whenever a capture starts or ends.
type toolStreamSieveState struct {
	// pending holds text that has been received but not yet emitted or captured.
	pending strings.Builder
	// capture holds the text of the candidate tool-call segment while capturing is true.
	capture strings.Builder
	// capturing is true once findToolSegmentStart located a candidate segment.
	capturing bool
	// recentTextTail is the trailing part of already-emitted text (see noteText);
	// it is prepended to capture prefixes when checking for an open code fence.
	recentTextTail string
	// disableDeltas is set when more than one tool call is detected, which
	// stops incremental delta generation for the current capture.
	disableDeltas bool
	// toolNameSent records that the name delta has already been emitted.
	toolNameSent bool
	// toolName is the tool name extracted from the captured call, "" until known.
	toolName string
	// toolArgsStart is the byte offset into the captured text where the
	// arguments payload begins (just past the opening quote when the arguments
	// are a JSON string); -1 while unknown.
	toolArgsStart int
	// toolArgsSent is the byte offset up to which arguments were already emitted.
	toolArgsSent int
	// toolArgsString is true when the arguments value is a JSON string
	// rather than an object or array.
	toolArgsString bool
	// toolArgsDone is true once the arguments payload has been seen in full.
	toolArgsDone bool
}
// toolStreamEvent is one unit of sieve output. The code in this file sets
// exactly one of the fields on every event it creates.
type toolStreamEvent struct {
	// Content is plain text that is safe to forward to the client.
	Content string
	// ToolCalls carries fully parsed tool calls from a completed capture.
	ToolCalls []util.ParsedToolCall
	// ToolCallDeltas carries incremental name/argument fragments.
	ToolCallDeltas []toolCallDelta
}
// toolCallDelta is an incremental fragment of a tool call. Deltas built by
// buildIncrementalToolDeltas set either Name or Arguments, never both.
type toolCallDelta struct {
	// Index is the position of the tool call; buildIncrementalToolDeltas always uses 0.
	Index int
	// Name is the tool name, set on the first delta for a call.
	Name string
	// Arguments is a raw slice of the captured arguments text.
	Arguments string
}
// Buffer limits for the tool-call sieve, in bytes.
const (
	// toolSieveCaptureLimit bounds the capture buffer: once an unparsed
	// capture grows beyond it, the captured text is released as plain content.
	toolSieveCaptureLimit = 8 * 1024
	// toolSieveContextTailLimit bounds how much emitted text noteText keeps
	// in recentTextTail.
	toolSieveContextTailLimit = 256
)
// resetIncrementalToolState clears everything tracked for incremental
// tool-call deltas so the next capture starts from scratch. The argument
// offsets are set to -1, which the delta builder treats as "not found yet".
func (s *toolStreamSieveState) resetIncrementalToolState() {
	s.toolName = ""
	s.toolArgsStart, s.toolArgsSent = -1, -1
	s.disableDeltas, s.toolNameSent = false, false
	s.toolArgsString, s.toolArgsDone = false, false
}
// processToolSieveChunk feeds one streamed chunk into the sieve and returns
// the events that are ready to be emitted: plain text (Content), parsed tool
// calls (ToolCalls) or incremental fragments (ToolCallDeltas).
//
// A nil state yields nil. An empty chunk is allowed and just re-evaluates the
// text that is already buffered. Text that might still turn into a tool call
// stays buffered in state and is not returned.
func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent {
	if state == nil {
		return nil
	}
	if chunk != "" {
		state.pending.WriteString(chunk)
	}
	events := make([]toolStreamEvent, 0, 2)
	// Each iteration either makes progress (emits, captures, or releases text)
	// and continues, or breaks because more input is needed.
	for {
		if state.capturing {
			// While capturing, all pending text belongs to the candidate segment.
			if state.pending.Len() > 0 {
				state.capture.WriteString(state.pending.String())
				state.pending.Reset()
			}
			// Deltas are produced before the capture is consumed, because a
			// successful consume resets the incremental state below.
			if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 {
				events = append(events, toolStreamEvent{ToolCallDeltas: deltas})
			}
			prefix, calls, suffix, ready := consumeToolCapture(state, toolNames)
			if !ready {
				// The capture cannot be resolved yet. Give up and release it as
				// plain text once it exceeds the limit; otherwise wait for more input.
				if state.capture.Len() > toolSieveCaptureLimit {
					content := state.capture.String()
					state.capture.Reset()
					state.capturing = false
					state.resetIncrementalToolState()
					state.noteText(content)
					events = append(events, toolStreamEvent{Content: content})
					continue
				}
				break
			}
			state.capture.Reset()
			state.capturing = false
			state.resetIncrementalToolState()
			if prefix != "" {
				state.noteText(prefix)
				events = append(events, toolStreamEvent{Content: prefix})
			}
			if len(calls) > 0 {
				events = append(events, toolStreamEvent{ToolCalls: calls})
			}
			// Text after the consumed object goes back to pending so the next
			// iteration scans it like fresh input.
			if suffix != "" {
				state.pending.WriteString(suffix)
			}
			continue
		}
		pending := state.pending.String()
		if pending == "" {
			break
		}
		start := findToolSegmentStart(pending)
		if start >= 0 {
			// A candidate segment begins at start: emit what precedes it and
			// switch to capture mode for the rest.
			prefix := pending[:start]
			if prefix != "" {
				state.noteText(prefix)
				events = append(events, toolStreamEvent{Content: prefix})
			}
			state.pending.Reset()
			state.capture.WriteString(pending[start:])
			state.capturing = true
			state.resetIncrementalToolState()
			continue
		}
		// No candidate yet: emit the part that cannot start one and keep the
		// suspicious remainder buffered.
		safe, hold := splitSafeContentForToolDetection(pending)
		if safe == "" {
			break
		}
		state.pending.Reset()
		state.pending.WriteString(hold)
		state.noteText(safe)
		events = append(events, toolStreamEvent{Content: safe})
	}
	return events
}
// flushToolSieve drains the sieve at end of stream. It first re-runs the
// normal processing with an empty chunk, then resolves any capture that is
// still open (parsed tool calls if the capture is ready, otherwise the raw
// captured text), and finally emits whatever is left in pending as content.
// A nil state yields nil.
func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStreamEvent {
	if state == nil {
		return nil
	}
	events := processToolSieveChunk(state, "", toolNames)

	// emitText records non-empty text in the context tail and queues it.
	emitText := func(text string) {
		if text == "" {
			return
		}
		state.noteText(text)
		events = append(events, toolStreamEvent{Content: text})
	}

	if state.capturing {
		prefix, calls, suffix, ready := consumeToolCapture(state, toolNames)
		if !ready {
			// Unresolvable capture: hand the raw text back as content.
			prefix, calls, suffix = state.capture.String(), nil, ""
		}
		emitText(prefix)
		if len(calls) > 0 {
			events = append(events, toolStreamEvent{ToolCalls: calls})
		}
		emitText(suffix)

		state.capture.Reset()
		state.capturing = false
		state.resetIncrementalToolState()
	}

	if state.pending.Len() > 0 {
		emitText(state.pending.String())
		state.pending.Reset()
	}
	return events
}
// splitSafeContentForToolDetection divides s into a leading part that can be
// emitted right away (safe) and a trailing part that must stay buffered
// because it begins with a marker that could open a tool-call block (hold).
func splitSafeContentForToolDetection(s string) (safe, hold string) {
	if len(s) == 0 {
		return "", ""
	}
	switch cut := findSuspiciousPrefixStart(s); {
	case cut < 0:
		// Nothing suspicious: everything may be emitted.
		return s, ""
	case cut == 0:
		// The marker is the very first byte, so keep holding until a complete
		// tool JSON block can be parsed or the stream is flushed.
		return "", s
	default:
		return s[:cut], s[cut:]
	}
}
// findSuspiciousPrefixStart returns the byte offset of the right-most marker
// in s that could begin a tool-call block ("{", "[" or a triple backtick),
// or -1 when s contains none of them.
func findSuspiciousPrefixStart(s string) int {
	latest := -1
	for _, marker := range [...]string{"{", "[", "```"} {
		if pos := strings.LastIndex(s, marker); pos > latest {
			latest = pos
		}
	}
	return latest
}
// findToolSegmentStart returns the byte offset in s where a candidate
// tool-call segment begins, or -1 when there is none.
//
// A candidate is an occurrence of "tool_calls" (ASCII case-insensitive). Its
// start is the last "{" before the key, or the key itself when no "{"
// precedes it. Occurrences whose start lies inside an open code fence are
// skipped and the search continues after them.
func findToolSegmentStart(s string) int {
	if s == "" {
		return -1
	}
	// Fold ASCII letters byte-for-byte instead of using strings.ToLower:
	// ToLower can change the byte length of the text (length-changing case
	// mappings, invalid UTF-8 replaced by U+FFFD), which would make offsets
	// found in the folded copy invalid for slicing s. The key is pure ASCII,
	// so an ASCII fold finds exactly the same occurrences.
	folded := []byte(s)
	for i, c := range folded {
		if 'A' <= c && c <= 'Z' {
			folded[i] = c + ('a' - 'A')
		}
	}
	lower := string(folded)
	offset := 0
	for {
		keyRel := strings.Index(lower[offset:], "tool_calls")
		if keyRel < 0 {
			return -1
		}
		keyIdx := offset + keyRel
		start := strings.LastIndex(s[:keyIdx], "{")
		if start < 0 {
			start = keyIdx
		}
		if !insideCodeFence(s[:start]) {
			return start
		}
		offset = keyIdx + len("tool_calls")
	}
}
// consumeToolCapture tries to resolve the current capture buffer.
//
// ready is false while the capture cannot be decided yet: it is empty, holds
// no "tool_calls" key, has no "{" before the key, or the JSON object starting
// at that "{" is not complete.
//
// When ready is true there are two outcomes:
//   - the object parsed into tool calls: prefix is the text before the
//     object, calls the parsed calls, suffix the text after the object;
//   - the object sits inside an open code fence or yields no tool calls:
//     prefix is the whole captured text and calls/suffix are empty.
func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
	captured := state.capture.String()
	if captured == "" {
		return "", nil, "", false
	}
	// Fold ASCII letters byte-for-byte instead of using strings.ToLower:
	// ToLower can change the byte length of the text (length-changing case
	// mappings, invalid UTF-8 replaced by U+FFFD), so an index found in its
	// result is not a valid offset into captured. The key is pure ASCII, so
	// an ASCII fold finds exactly the same occurrences.
	folded := []byte(captured)
	for i, c := range folded {
		if 'A' <= c && c <= 'Z' {
			folded[i] = c + ('a' - 'A')
		}
	}
	keyIdx := strings.Index(string(folded), "tool_calls")
	if keyIdx < 0 {
		return "", nil, "", false
	}
	start := strings.LastIndex(captured[:keyIdx], "{")
	if start < 0 {
		return "", nil, "", false
	}
	obj, end, ok := extractJSONObjectFrom(captured, start)
	if !ok {
		return "", nil, "", false
	}
	prefixPart := captured[:start]
	suffixPart := captured[end:]
	// An object inside an open code fence is an example, not a call; the
	// recently emitted text is included so fences opened earlier are counted.
	if insideCodeFence(state.recentTextTail + prefixPart) {
		return captured, nil, "", true
	}
	parsed := util.ParseStandaloneToolCalls(obj, toolNames)
	if len(parsed) == 0 {
		return captured, nil, "", true
	}
	return prefixPart, parsed, suffixPart, true
}
// extractJSONObjectFrom returns the brace-balanced object that starts at
// text[start], the offset just past its closing brace, and true. It returns
// ("", 0, false) when start is out of range, text[start] is not '{', or the
// object is not closed within text.
//
// Braces inside double- or single-quoted strings are ignored, and a backslash
// inside a string escapes the following byte.
func extractJSONObjectFrom(text string, start int) (string, int, bool) {
	if start < 0 || start >= len(text) || text[start] != '{' {
		return "", 0, false
	}
	var (
		nesting  int
		inQuote  byte // active quote character, 0 when outside a string
		skipNext bool // previous byte was an escaping backslash
	)
	for pos := start; pos < len(text); pos++ {
		c := text[pos]
		if inQuote != 0 {
			switch {
			case skipNext:
				skipNext = false
			case c == '\\':
				skipNext = true
			case c == inQuote:
				inQuote = 0
			}
			continue
		}
		switch c {
		case '"', '\'':
			inQuote = c
		case '{':
			nesting++
		case '}':
			nesting--
			if nesting == 0 {
				return text[start : pos+1], pos + 1, true
			}
		}
	}
	return "", 0, false
}
// buildIncrementalToolDeltas inspects the current capture and returns the
// tool-call fragments that have not been emitted yet: first one delta with
// the tool name, then deltas with newly available argument text.
//
// It returns nil when deltas are disabled, the capture has no usable
// "tool_calls" object, the object sits inside a code fence, or it is not yet
// certain that the array holds exactly one call. Detecting several calls
// sets state.disableDeltas, which disables deltas until the incremental
// state is reset. Progress is stored in the tool* fields of state.
//
// NOTE(review): argument deltas are raw slices of the captured text. For
// string-typed arguments that means escape sequences are passed through
// un-decoded — confirm the consumer expects that.
func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
	if state.disableDeltas {
		return nil
	}
	captured := state.capture.String()
	if captured == "" {
		return nil
	}
	// Fold ASCII letters byte-for-byte instead of using strings.ToLower:
	// ToLower can change the byte length of the text (length-changing case
	// mappings, invalid UTF-8 replaced by U+FFFD), so an index found in its
	// result is not a valid offset into captured. The key is pure ASCII, so
	// an ASCII fold finds exactly the same occurrences.
	folded := []byte(captured)
	for i, c := range folded {
		if 'A' <= c && c <= 'Z' {
			folded[i] = c + ('a' - 'A')
		}
	}
	keyIdx := strings.Index(string(folded), "tool_calls")
	if keyIdx < 0 {
		return nil
	}
	start := strings.LastIndex(captured[:keyIdx], "{")
	if start < 0 {
		return nil
	}
	if insideCodeFence(state.recentTextTail + captured[:start]) {
		return nil
	}
	certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx)
	if hasMultiple {
		state.disableDeltas = true
		return nil
	}
	if !certainSingle {
		// In uncertain phases (e.g. first call arrived but array not closed yet),
		// avoid speculative deltas and wait for final parsed tool_calls payload.
		return nil
	}
	callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
	if !ok {
		return nil
	}
	deltas := make([]toolCallDelta, 0, 2)
	if state.toolName == "" {
		name, ok := extractToolCallName(captured, callStart)
		if !ok || name == "" {
			return nil
		}
		state.toolName = name
	}
	if state.toolArgsStart < 0 {
		argsStart, stringMode, ok := findToolCallArgsStart(captured, callStart)
		if ok {
			state.toolArgsString = stringMode
			if stringMode {
				// Skip the opening quote so only the string's content is emitted.
				state.toolArgsStart = argsStart + 1
			} else {
				state.toolArgsStart = argsStart
			}
			state.toolArgsSent = state.toolArgsStart
		}
	}
	if !state.toolNameSent {
		// The name is held back until the arguments have been located.
		if state.toolArgsStart < 0 {
			return nil
		}
		state.toolNameSent = true
		deltas = append(deltas, toolCallDelta{Index: 0, Name: state.toolName})
	}
	if state.toolArgsStart < 0 || state.toolArgsDone {
		return deltas
	}
	end, complete, ok := scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString)
	if !ok {
		return deltas
	}
	if end > state.toolArgsSent {
		deltas = append(deltas, toolCallDelta{
			Index:     0,
			Arguments: captured[state.toolArgsSent:end],
		})
		state.toolArgsSent = end
	}
	if complete {
		state.toolArgsDone = true
	}
	return deltas
}
// classifyToolCallsIncrementalSafety decides whether the "tool_calls" array
// that follows keyIdx can be streamed incrementally.
//
// certainSingle is true only when the array is already closed and holds
// exactly one top-level object. hasMultiple is true as soon as a second
// top-level object or a top-level comma is seen. Both are false while the
// array is missing, does not start with an object, or is still open.
func classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) {
	arrStart, found := findToolCallsArrayStart(text, keyIdx)
	if !found {
		return false, false
	}
	pos := skipSpaces(text, arrStart+1)
	if pos >= len(text) || text[pos] != '{' {
		return false, false
	}
	var (
		topLevelObjects int
		nesting         int
		inQuote         byte // active quote character, 0 when outside a string
		skipNext        bool // previous byte was an escaping backslash
	)
	for ; pos < len(text); pos++ {
		c := text[pos]
		if inQuote != 0 {
			switch {
			case skipNext:
				skipNext = false
			case c == '\\':
				skipNext = true
			case c == inQuote:
				inQuote = 0
			}
			continue
		}
		switch c {
		case '"', '\'':
			inQuote = c
		case '{':
			if nesting == 0 {
				topLevelObjects++
				if topLevelObjects > 1 {
					return false, true
				}
			}
			nesting++
		case '}':
			if nesting > 0 {
				nesting--
			}
		case ',':
			// A separator between array elements means another tool call
			// exists or is expected, so incremental deltas must stop.
			if nesting == 0 {
				return false, true
			}
		case ']':
			if nesting == 0 {
				return topLevelObjects == 1, false
			}
		}
	}
	// The array is still open, so more calls may yet arrive.
	return false, false
}
// findFirstToolCallObjectStart returns the offset of the '{' that opens the
// first element of the "tool_calls" array following keyIdx. It returns
// (-1, false) when the array is missing or its first non-space element is
// not an object.
func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) {
	bracket, found := findToolCallsArrayStart(text, keyIdx)
	if found {
		if first := skipSpaces(text, bracket+1); first < len(text) && text[first] == '{' {
			return first, true
		}
	}
	return -1, false
}
// findToolCallsArrayStart returns the offset of the '[' that opens the value
// of the "tool_calls" key located at keyIdx. It looks for the first ':' after
// the key and requires the next non-space byte to be '['; otherwise it
// returns (-1, false).
func findToolCallsArrayStart(text string, keyIdx int) (int, bool) {
	afterKey := keyIdx + len("tool_calls")
	if afterKey >= len(text) {
		return -1, false
	}
	colon := strings.IndexByte(text[afterKey:], ':')
	if colon < 0 {
		return -1, false
	}
	bracket := skipSpaces(text, afterKey+colon+1)
	if bracket < len(text) && text[bracket] == '[' {
		return bracket, true
	}
	return -1, false
}
// extractToolCallName reads the tool name of the call object that starts at
// callStart. It first looks for a string-valued "name" field on the call
// itself and, failing that, on the nested "function" object. It returns
// ("", false) when no complete string value is found in either place.
func extractToolCallName(text string, callStart int) (string, bool) {
	nameKey := []string{"name"}
	startsString := func(pos int) bool {
		return pos < len(text) && text[pos] == '"'
	}

	pos, found := findObjectFieldValueStart(text, callStart, nameKey)
	if !found || !startsString(pos) {
		fnStart, hasFn := findFunctionObjectStart(text, callStart)
		if !hasFn {
			return "", false
		}
		pos, found = findObjectFieldValueStart(text, fnStart, nameKey)
		if !found || !startsString(pos) {
			return "", false
		}
	}

	name, _, parsed := parseJSONStringLiteral(text, pos)
	if !parsed {
		return "", false
	}
	return name, true
}
// findToolCallArgsStart locates the arguments value of the call object that
// starts at callStart. It accepts the keys "input", "arguments", "args",
// "parameters" and "params", first on the call itself and then on the nested
// "function" object.
//
// It returns the offset of the value's first byte, whether that value is a
// JSON string (true) or an object/array (false), and whether a usable value
// was found. Any other value type yields (-1, false, false).
func findToolCallArgsStart(text string, callStart int) (int, bool, bool) {
	argKeys := []string{"input", "arguments", "args", "parameters", "params"}

	pos, found := findObjectFieldValueStart(text, callStart, argKeys)
	if !found {
		fnStart, hasFn := findFunctionObjectStart(text, callStart)
		if !hasFn {
			return -1, false, false
		}
		if pos, found = findObjectFieldValueStart(text, fnStart, argKeys); !found {
			return -1, false, false
		}
	}
	if pos >= len(text) {
		return -1, false, false
	}

	switch text[pos] {
	case '{', '[':
		return pos, false, true
	case '"':
		return pos, true, true
	default:
		return -1, false, false
	}
}
// scanToolCallArgsProgress reports how far the arguments payload beginning at
// start extends in text. It returns (end, complete, ok).
//
// In string mode start points just past the opening quote: end is the offset
// of the closing unescaped quote (complete) or len(text) if it has not
// arrived yet. Otherwise text[start] must be '{' or '[': end is the offset
// just past the matching closer (complete) or len(text) while the value is
// still open. ok is false when start is out of range, when there is nothing
// to scan yet in object mode, or when the value does not start with '{'/'['.
func scanToolCallArgsProgress(text string, start int, stringMode bool) (int, bool, bool) {
	if start < 0 || start > len(text) {
		return 0, false, false
	}

	if stringMode {
		for pos := start; pos < len(text); pos++ {
			switch text[pos] {
			case '\\':
				pos++ // the escaped byte can never terminate the string
			case '"':
				return pos, true, true
			}
		}
		return len(text), false, true
	}

	if start == len(text) {
		return start, false, false
	}
	if first := text[start]; first != '{' && first != '[' {
		return 0, false, false
	}

	nesting := 0
	var inQuote byte // active quote character, 0 when outside a string
	for pos := start; pos < len(text); pos++ {
		c := text[pos]
		if inQuote != 0 {
			switch {
			case c == '\\':
				pos++ // skip the escaped byte
			case c == inQuote:
				inQuote = 0
			}
			continue
		}
		switch c {
		case '"', '\'':
			inQuote = c
		case '{', '[':
			nesting++
		case '}', ']':
			nesting--
			if nesting == 0 {
				return pos + 1, true, true
			}
		}
	}
	return len(text), false, true
}
// findObjectFieldValueStart scans the object that starts at text[objStart]
// for the first direct (depth 1) field whose key is listed in keys and
// returns the offset of the first non-space byte of that field's value.
//
// It returns (0, false) when objStart does not point at '{', when no listed
// key is found before the object closes or text ends, when text ends right
// after a key's colon, or when a quote at depth 1 cannot be parsed as a
// complete double-quoted string literal.
func findObjectFieldValueStart(text string, objStart int, keys []string) (int, bool) {
	if objStart < 0 || objStart >= len(text) || text[objStart] != '{' {
		return 0, false
	}
	depth := 0
	quote := byte(0)
	escaped := false
	for i := objStart; i < len(text); i++ {
		ch := text[i]
		if quote != 0 {
			// Inside a string below depth 1: honor backslash escapes and wait
			// for the matching closing quote.
			if escaped {
				escaped = false
				continue
			}
			if ch == '\\' {
				escaped = true
				continue
			}
			if ch == quote {
				quote = 0
			}
			continue
		}
		if ch == '"' || ch == '\'' {
			if depth == 1 {
				// A quote directly inside the object starts either a key or a
				// string value; both are parsed as a whole literal here.
				key, end, ok := parseJSONStringLiteral(text, i)
				if !ok {
					return 0, false
				}
				j := skipSpaces(text, end)
				if j >= len(text) || text[j] != ':' {
					// No colon follows, so this literal was a value: skip it.
					i = end - 1
					continue
				}
				j = skipSpaces(text, j+1)
				if j >= len(text) {
					return 0, false
				}
				if containsKey(keys, key) {
					return j, true
				}
				// Not a wanted key: resume scanning at the start of its value
				// (the loop increment moves i back onto j).
				i = j - 1
				continue
			}
			quote = ch
			continue
		}
		if ch == '{' {
			depth++
			continue
		}
		if ch == '}' {
			depth--
			if depth == 0 {
				// The object that started at objStart is closed.
				break
			}
		}
	}
	return 0, false
}
// findFunctionObjectStart returns the offset of the '{' that opens the
// "function" field of the call object starting at callStart, or (-1, false)
// when that field is missing or its value is not an object.
func findFunctionObjectStart(text string, callStart int) (int, bool) {
	pos, found := findObjectFieldValueStart(text, callStart, []string{"function"})
	if found && pos < len(text) && text[pos] == '{' {
		return pos, true
	}
	return -1, false
}
// parseJSONStringLiteral decodes the double-quoted string literal that starts
// at text[start]. It returns the decoded value, the offset just past the
// closing quote, and true. It returns ("", 0, false) when start is out of
// range, text[start] is not '"', or the literal is not terminated.
//
// The JSON escapes \" \\ \/ \b \f \n \r \t and \uXXXX (including surrogate
// pairs) are decoded. Any other escaped byte, and a \u that is not followed
// by four hex digits, is kept as the byte after the backslash so malformed
// model output is tolerated rather than rejected.
func parseJSONStringLiteral(text string, start int) (string, int, bool) {
	if start < 0 || start >= len(text) || text[start] != '"' {
		return "", 0, false
	}
	var b strings.Builder
	for i := start + 1; i < len(text); i++ {
		ch := text[i]
		switch ch {
		case '"':
			return b.String(), i + 1, true
		case '\\':
			i++
			if i >= len(text) {
				// Dangling backslash: the literal is incomplete.
				return "", 0, false
			}
			switch esc := text[i]; esc {
			case 'b':
				b.WriteByte('\b')
			case 'f':
				b.WriteByte('\f')
			case 'n':
				b.WriteByte('\n')
			case 'r':
				b.WriteByte('\r')
			case 't':
				b.WriteByte('\t')
			case 'u':
				r, ok := parseJSONHex4(text, i+1)
				if !ok {
					b.WriteByte(esc)
					break
				}
				i += 4 // i now sits on the last hex digit
				// A high surrogate followed by an escaped low surrogate
				// encodes one code point outside the BMP.
				if 0xD800 <= r && r <= 0xDBFF && i+2 < len(text) && text[i+1] == '\\' && text[i+2] == 'u' {
					if lo, ok := parseJSONHex4(text, i+3); ok && 0xDC00 <= lo && lo <= 0xDFFF {
						r = 0x10000 + ((r - 0xD800) << 10) + (lo - 0xDC00)
						i += 6
					}
				}
				// WriteRune emits U+FFFD for an unpaired surrogate.
				b.WriteRune(r)
			default:
				// Covers \" \\ \/ and tolerates unknown escapes.
				b.WriteByte(esc)
			}
		default:
			b.WriteByte(ch)
		}
	}
	return "", 0, false
}

// parseJSONHex4 parses the four hex digits at text[at:at+4] into a rune.
func parseJSONHex4(text string, at int) (rune, bool) {
	if at < 0 || at+4 > len(text) {
		return 0, false
	}
	var r rune
	for k := at; k < at+4; k++ {
		c := text[k]
		var digit byte
		switch {
		case '0' <= c && c <= '9':
			digit = c - '0'
		case 'a' <= c && c <= 'f':
			digit = c - 'a' + 10
		case 'A' <= c && c <= 'F':
			digit = c - 'A' + 10
		default:
			return 0, false
		}
		r = r<<4 | rune(digit)
	}
	return r, true
}
// containsKey reports whether value is exactly equal (case-sensitive) to one
// of the entries in keys.
func containsKey(keys []string, value string) bool {
	for i := 0; i < len(keys); i++ {
		if keys[i] == value {
			return true
		}
	}
	return false
}
// skipSpaces returns the offset of the first byte at or after i that is not
// a space, tab, newline or carriage return. It returns i unchanged when i is
// already at or beyond the end of text.
func skipSpaces(text string, i int) int {
	for ; i < len(text); i++ {
		if c := text[i]; c != ' ' && c != '\t' && c != '\n' && c != '\r' {
			break
		}
	}
	return i
}
// noteText appends emitted content to the recent-text tail, keeping at most
// toolSieveContextTailLimit bytes. Content that consists only of whitespace
// is ignored.
func (s *toolStreamSieveState) noteText(content string) {
	if len(strings.TrimSpace(content)) > 0 {
		s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit)
	}
}
// appendTail returns the trailing part of prev+next, at most max bytes long.
// A non-positive max yields "".
//
// When the byte cut would land inside a multi-byte UTF-8 sequence, the cut is
// moved forward to the next rune boundary, so the result never starts with
// the tail of a split character and may be up to three bytes shorter than max.
func appendTail(prev, next string, max int) string {
	if max <= 0 {
		return ""
	}
	combined := prev + next
	if len(combined) <= max {
		return combined
	}
	cut := len(combined) - max
	// Skip UTF-8 continuation bytes (10xxxxxx). A valid sequence has at most
	// three, which also bounds the skip on malformed input.
	for skipped := 0; skipped < 3 && cut < len(combined) && combined[cut]&0xC0 == 0x80; skipped++ {
		cut++
	}
	return combined[cut:]
}
// looksLikeToolExampleContext reports whether text ends inside an open code
// fence, in which case a following tool-call block is treated as an example.
func looksLikeToolExampleContext(text string) bool {
	fenced := insideCodeFence(text)
	return fenced
}
// insideCodeFence reports whether text contains an odd number of triple
// backticks, i.e. whether a code fence opened in text is still open at its
// end. Empty text is never inside a fence.
func insideCodeFence(text string) bool {
	fences := strings.Count(text, "```")
	return fences&1 == 1
}