mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-19 23:47:45 +08:00
Refactor admin handlers into specialized files and introduce OpenAI tool sieving and Vercel streaming capabilities.
This commit is contained in:
@@ -3,15 +3,10 @@ package openai
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -39,17 +34,6 @@ type streamLease struct {
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type toolStreamSieveState struct {
|
||||
pending strings.Builder
|
||||
capture strings.Builder
|
||||
capturing bool
|
||||
}
|
||||
|
||||
type toolStreamEvent struct {
|
||||
Content string
|
||||
ToolCalls []util.ParsedToolCall
|
||||
}
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
r.Get("/v1/models", h.ListModels)
|
||||
r.Post("/v1/chat/completions", h.ChatCompletions)
|
||||
@@ -140,142 +124,6 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
h.handleNonStream(w, r.Context(), resp, sessionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
|
||||
}
|
||||
|
||||
func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Request) {
|
||||
if !config.IsVercel() {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
h.sweepExpiredStreamLeases()
|
||||
internalSecret := vercelInternalSecret()
|
||||
internalToken := strings.TrimSpace(r.Header.Get("X-Ds2-Internal-Token"))
|
||||
if internalSecret == "" || subtle.ConstantTimeCompare([]byte(internalToken), []byte(internalSecret)) != 1 {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "unauthorized internal request")
|
||||
return
|
||||
}
|
||||
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
status := http.StatusUnauthorized
|
||||
if err == auth.ErrNoAccount {
|
||||
status = http.StatusTooManyRequests
|
||||
}
|
||||
writeOpenAIError(w, status, err.Error())
|
||||
return
|
||||
}
|
||||
leased := false
|
||||
defer func() {
|
||||
if !leased {
|
||||
h.Auth.Release(a)
|
||||
}
|
||||
}()
|
||||
r = r.WithContext(auth.WithAuth(r.Context(), a))
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if !toBool(req["stream"]) {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
|
||||
return
|
||||
}
|
||||
if tools, ok := req["tools"].([]any); ok && len(tools) > 0 {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "tools are not supported by vercel stream prepare")
|
||||
return
|
||||
}
|
||||
|
||||
model, _ := req["model"].(string)
|
||||
messagesRaw, _ := req["messages"].([]any)
|
||||
if model == "" || len(messagesRaw) == 0 {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
|
||||
return
|
||||
}
|
||||
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model)
|
||||
if !ok {
|
||||
writeOpenAIError(w, http.StatusServiceUnavailable, fmt.Sprintf("Model '%s' is not available.", model))
|
||||
return
|
||||
}
|
||||
|
||||
messages := normalizeMessages(messagesRaw)
|
||||
finalPrompt := util.MessagesPrepare(messages)
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
if a.UseConfigToken {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
|
||||
} else {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
|
||||
}
|
||||
return
|
||||
}
|
||||
powHeader, err := h.DS.GetPow(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(a.DeepSeekToken) == "" {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
|
||||
return
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
"parent_message_id": nil,
|
||||
"prompt": finalPrompt,
|
||||
"ref_file_ids": []any{},
|
||||
"thinking_enabled": thinkingEnabled,
|
||||
"search_enabled": searchEnabled,
|
||||
}
|
||||
leaseID := h.holdStreamLease(a)
|
||||
if leaseID == "" {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "failed to create stream lease")
|
||||
return
|
||||
}
|
||||
leased = true
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"session_id": sessionID,
|
||||
"lease_id": leaseID,
|
||||
"model": model,
|
||||
"final_prompt": finalPrompt,
|
||||
"thinking_enabled": thinkingEnabled,
|
||||
"search_enabled": searchEnabled,
|
||||
"deepseek_token": a.DeepSeekToken,
|
||||
"pow_header": powHeader,
|
||||
"payload": payload,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleVercelStreamRelease(w http.ResponseWriter, r *http.Request) {
|
||||
if !config.IsVercel() {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
h.sweepExpiredStreamLeases()
|
||||
internalSecret := vercelInternalSecret()
|
||||
internalToken := strings.TrimSpace(r.Header.Get("X-Ds2-Internal-Token"))
|
||||
if internalSecret == "" || subtle.ConstantTimeCompare([]byte(internalToken), []byte(internalSecret)) != 1 {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "unauthorized internal request")
|
||||
return
|
||||
}
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
leaseID, _ := req["lease_id"].(string)
|
||||
leaseID = strings.TrimSpace(leaseID)
|
||||
if leaseID == "" {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "lease_id is required")
|
||||
return
|
||||
}
|
||||
if !h.releaseStreamLease(leaseID) {
|
||||
writeOpenAIError(w, http.StatusNotFound, "stream lease not found")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true})
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
@@ -700,367 +548,3 @@ func openAIErrorType(status int) string {
|
||||
return "invalid_request_error"
|
||||
}
|
||||
}
|
||||
|
||||
func isVercelStreamPrepareRequest(r *http.Request) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(r.URL.Query().Get("__stream_prepare")) == "1"
|
||||
}
|
||||
|
||||
func isVercelStreamReleaseRequest(r *http.Request) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(r.URL.Query().Get("__stream_release")) == "1"
|
||||
}
|
||||
|
||||
func vercelInternalSecret() string {
|
||||
if v := strings.TrimSpace(os.Getenv("DS2API_VERCEL_INTERNAL_SECRET")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")); v != "" {
|
||||
return v
|
||||
}
|
||||
return "admin"
|
||||
}
|
||||
|
||||
func shouldEmitBufferedToolProbeContent(buffered string) bool {
|
||||
trimmed := strings.TrimSpace(buffered)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
normalized := normalizeToolProbePrefix(trimmed)
|
||||
if normalized == "" {
|
||||
return false
|
||||
}
|
||||
first := normalized[0]
|
||||
switch first {
|
||||
case '{', '[', '`':
|
||||
lower := strings.ToLower(normalized)
|
||||
if strings.Contains(lower, "tool_calls") {
|
||||
return false
|
||||
}
|
||||
// Keep a short hold window for JSON-ish starts to avoid leaking tool JSON.
|
||||
if len([]rune(normalized)) < 20 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
default:
|
||||
// Natural language starts can be streamed immediately.
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeToolProbePrefix(s string) string {
|
||||
t := strings.TrimSpace(s)
|
||||
if strings.HasPrefix(t, "```") {
|
||||
t = strings.TrimPrefix(t, "```")
|
||||
t = strings.TrimSpace(t)
|
||||
t = strings.TrimPrefix(strings.ToLower(t), "json")
|
||||
t = strings.TrimSpace(t)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent {
|
||||
if state == nil || chunk == "" {
|
||||
return nil
|
||||
}
|
||||
state.pending.WriteString(chunk)
|
||||
events := make([]toolStreamEvent, 0, 2)
|
||||
|
||||
for {
|
||||
if state.capturing {
|
||||
if state.pending.Len() > 0 {
|
||||
state.capture.WriteString(state.pending.String())
|
||||
state.pending.Reset()
|
||||
}
|
||||
prefix, calls, suffix, ready := consumeToolCapture(state.capture.String(), toolNames)
|
||||
if !ready {
|
||||
break
|
||||
}
|
||||
state.capture.Reset()
|
||||
state.capturing = false
|
||||
if prefix != "" {
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
if len(calls) > 0 {
|
||||
events = append(events, toolStreamEvent{ToolCalls: calls})
|
||||
}
|
||||
if suffix != "" {
|
||||
state.pending.WriteString(suffix)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
pending := state.pending.String()
|
||||
if pending == "" {
|
||||
break
|
||||
}
|
||||
start := findToolSegmentStart(pending)
|
||||
if start >= 0 {
|
||||
prefix := pending[:start]
|
||||
if prefix != "" {
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
state.pending.Reset()
|
||||
state.capture.WriteString(pending[start:])
|
||||
state.capturing = true
|
||||
continue
|
||||
}
|
||||
|
||||
safe, hold := splitSafeContentForToolDetection(pending)
|
||||
if safe == "" {
|
||||
break
|
||||
}
|
||||
state.pending.Reset()
|
||||
state.pending.WriteString(hold)
|
||||
events = append(events, toolStreamEvent{Content: safe})
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStreamEvent {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
events := processToolSieveChunk(state, "", toolNames)
|
||||
if state.capturing {
|
||||
raw := state.capture.String()
|
||||
state.capture.Reset()
|
||||
state.capturing = false
|
||||
if raw != "" {
|
||||
events = append(events, toolStreamEvent{Content: raw})
|
||||
}
|
||||
}
|
||||
if state.pending.Len() > 0 {
|
||||
events = append(events, toolStreamEvent{Content: state.pending.String()})
|
||||
state.pending.Reset()
|
||||
}
|
||||
return events
|
||||
}
|
||||
|
||||
func splitSafeContentForToolDetection(s string) (safe, hold string) {
|
||||
if s == "" {
|
||||
return "", ""
|
||||
}
|
||||
suspiciousStart := findSuspiciousPrefixStart(s)
|
||||
if suspiciousStart < 0 {
|
||||
return s, ""
|
||||
}
|
||||
if suspiciousStart > 0 {
|
||||
return s[:suspiciousStart], s[suspiciousStart:]
|
||||
}
|
||||
runes := []rune(s)
|
||||
const maxHold = 128
|
||||
if len(runes) <= maxHold {
|
||||
return "", s
|
||||
}
|
||||
return string(runes[:len(runes)-maxHold]), string(runes[len(runes)-maxHold:])
|
||||
}
|
||||
|
||||
func findSuspiciousPrefixStart(s string) int {
|
||||
start := -1
|
||||
indices := []int{
|
||||
strings.LastIndex(s, "{"),
|
||||
strings.LastIndex(s, "["),
|
||||
strings.LastIndex(s, "```"),
|
||||
}
|
||||
for _, idx := range indices {
|
||||
if idx > start {
|
||||
start = idx
|
||||
}
|
||||
}
|
||||
return start
|
||||
}
|
||||
|
||||
func findToolSegmentStart(s string) int {
|
||||
if s == "" {
|
||||
return -1
|
||||
}
|
||||
lower := strings.ToLower(s)
|
||||
keyIdx := strings.Index(lower, "tool_calls")
|
||||
if keyIdx < 0 {
|
||||
return -1
|
||||
}
|
||||
if start := strings.LastIndex(s[:keyIdx], "{"); start >= 0 {
|
||||
return start
|
||||
}
|
||||
return keyIdx
|
||||
}
|
||||
|
||||
func consumeToolCapture(captured string, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
|
||||
if captured == "" {
|
||||
return "", nil, "", false
|
||||
}
|
||||
lower := strings.ToLower(captured)
|
||||
keyIdx := strings.Index(lower, "tool_calls")
|
||||
if keyIdx < 0 {
|
||||
if len([]rune(captured)) >= 256 {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
return "", nil, "", false
|
||||
}
|
||||
start := strings.LastIndex(captured[:keyIdx], "{")
|
||||
if start < 0 {
|
||||
if len([]rune(captured)) >= 512 {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
return "", nil, "", false
|
||||
}
|
||||
obj, end, ok := extractJSONObjectFrom(captured, start)
|
||||
if !ok {
|
||||
if len([]rune(captured)) >= 4096 {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
return "", nil, "", false
|
||||
}
|
||||
parsed := util.ParseToolCalls(obj, toolNames)
|
||||
if len(parsed) == 0 {
|
||||
return captured[:end], nil, captured[end:], true
|
||||
}
|
||||
return captured[:start], parsed, captured[end:], true
|
||||
}
|
||||
|
||||
func extractJSONObjectFrom(text string, start int) (string, int, bool) {
|
||||
if start < 0 || start >= len(text) || text[start] != '{' {
|
||||
return "", 0, false
|
||||
}
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
for i := start; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '{' {
|
||||
depth++
|
||||
continue
|
||||
}
|
||||
if ch == '}' {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
end := i + 1
|
||||
return text[start:end], end, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func (h *Handler) holdStreamLease(a *auth.RequestAuth) string {
|
||||
if a == nil {
|
||||
return ""
|
||||
}
|
||||
now := time.Now()
|
||||
ttl := streamLeaseTTL()
|
||||
if ttl <= 0 {
|
||||
ttl = 15 * time.Minute
|
||||
}
|
||||
|
||||
h.leaseMu.Lock()
|
||||
expired := h.popExpiredLeasesLocked(now)
|
||||
if h.streamLeases == nil {
|
||||
h.streamLeases = make(map[string]streamLease)
|
||||
}
|
||||
leaseID := newLeaseID()
|
||||
h.streamLeases[leaseID] = streamLease{
|
||||
Auth: a,
|
||||
ExpiresAt: now.Add(ttl),
|
||||
}
|
||||
h.leaseMu.Unlock()
|
||||
h.releaseExpiredAuths(expired)
|
||||
return leaseID
|
||||
}
|
||||
|
||||
func (h *Handler) releaseStreamLease(leaseID string) bool {
|
||||
leaseID = strings.TrimSpace(leaseID)
|
||||
if leaseID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
h.leaseMu.Lock()
|
||||
expired := h.popExpiredLeasesLocked(time.Now())
|
||||
lease, ok := h.streamLeases[leaseID]
|
||||
if ok {
|
||||
delete(h.streamLeases, leaseID)
|
||||
}
|
||||
h.leaseMu.Unlock()
|
||||
h.releaseExpiredAuths(expired)
|
||||
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if h.Auth != nil {
|
||||
h.Auth.Release(lease.Auth)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *Handler) popExpiredLeasesLocked(now time.Time) []*auth.RequestAuth {
|
||||
if len(h.streamLeases) == 0 {
|
||||
return nil
|
||||
}
|
||||
expired := make([]*auth.RequestAuth, 0)
|
||||
for leaseID, lease := range h.streamLeases {
|
||||
if now.After(lease.ExpiresAt) {
|
||||
delete(h.streamLeases, leaseID)
|
||||
expired = append(expired, lease.Auth)
|
||||
}
|
||||
}
|
||||
return expired
|
||||
}
|
||||
|
||||
func (h *Handler) releaseExpiredAuths(expired []*auth.RequestAuth) {
|
||||
if h.Auth == nil || len(expired) == 0 {
|
||||
return
|
||||
}
|
||||
for _, a := range expired {
|
||||
h.Auth.Release(a)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) sweepExpiredStreamLeases() {
|
||||
h.leaseMu.Lock()
|
||||
expired := h.popExpiredLeasesLocked(time.Now())
|
||||
h.leaseMu.Unlock()
|
||||
h.releaseExpiredAuths(expired)
|
||||
}
|
||||
|
||||
func streamLeaseTTL() time.Duration {
|
||||
raw := strings.TrimSpace(os.Getenv("DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS"))
|
||||
if raw == "" {
|
||||
return 15 * time.Minute
|
||||
}
|
||||
seconds, err := strconv.Atoi(raw)
|
||||
if err != nil || seconds <= 0 {
|
||||
return 15 * time.Minute
|
||||
}
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
func newLeaseID() string {
|
||||
buf := make([]byte, 16)
|
||||
if _, err := rand.Read(buf); err == nil {
|
||||
return hex.EncodeToString(buf)
|
||||
}
|
||||
return fmt.Sprintf("lease-%d", time.Now().UnixNano())
|
||||
}
|
||||
|
||||
236
internal/adapter/openai/tool_sieve.go
Normal file
236
internal/adapter/openai/tool_sieve.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
type toolStreamSieveState struct {
|
||||
pending strings.Builder
|
||||
capture strings.Builder
|
||||
capturing bool
|
||||
}
|
||||
|
||||
type toolStreamEvent struct {
|
||||
Content string
|
||||
ToolCalls []util.ParsedToolCall
|
||||
}
|
||||
|
||||
func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
if chunk != "" {
|
||||
state.pending.WriteString(chunk)
|
||||
}
|
||||
events := make([]toolStreamEvent, 0, 2)
|
||||
|
||||
for {
|
||||
if state.capturing {
|
||||
if state.pending.Len() > 0 {
|
||||
state.capture.WriteString(state.pending.String())
|
||||
state.pending.Reset()
|
||||
}
|
||||
prefix, calls, suffix, ready := consumeToolCapture(state.capture.String(), toolNames)
|
||||
if !ready {
|
||||
break
|
||||
}
|
||||
state.capture.Reset()
|
||||
state.capturing = false
|
||||
if prefix != "" {
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
if len(calls) > 0 {
|
||||
events = append(events, toolStreamEvent{ToolCalls: calls})
|
||||
}
|
||||
if suffix != "" {
|
||||
state.pending.WriteString(suffix)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
pending := state.pending.String()
|
||||
if pending == "" {
|
||||
break
|
||||
}
|
||||
start := findToolSegmentStart(pending)
|
||||
if start >= 0 {
|
||||
prefix := pending[:start]
|
||||
if prefix != "" {
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
state.pending.Reset()
|
||||
state.capture.WriteString(pending[start:])
|
||||
state.capturing = true
|
||||
continue
|
||||
}
|
||||
|
||||
safe, hold := splitSafeContentForToolDetection(pending)
|
||||
if safe == "" {
|
||||
break
|
||||
}
|
||||
state.pending.Reset()
|
||||
state.pending.WriteString(hold)
|
||||
events = append(events, toolStreamEvent{Content: safe})
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStreamEvent {
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
events := processToolSieveChunk(state, "", toolNames)
|
||||
if state.capturing {
|
||||
consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state.capture.String(), toolNames)
|
||||
if ready {
|
||||
if consumedPrefix != "" {
|
||||
events = append(events, toolStreamEvent{Content: consumedPrefix})
|
||||
}
|
||||
if len(consumedCalls) > 0 {
|
||||
events = append(events, toolStreamEvent{ToolCalls: consumedCalls})
|
||||
}
|
||||
if consumedSuffix != "" {
|
||||
events = append(events, toolStreamEvent{Content: consumedSuffix})
|
||||
}
|
||||
} else {
|
||||
raw := state.capture.String()
|
||||
if raw != "" {
|
||||
events = append(events, toolStreamEvent{Content: raw})
|
||||
}
|
||||
}
|
||||
state.capture.Reset()
|
||||
state.capturing = false
|
||||
}
|
||||
if state.pending.Len() > 0 {
|
||||
events = append(events, toolStreamEvent{Content: state.pending.String()})
|
||||
state.pending.Reset()
|
||||
}
|
||||
return events
|
||||
}
|
||||
|
||||
func splitSafeContentForToolDetection(s string) (safe, hold string) {
|
||||
if s == "" {
|
||||
return "", ""
|
||||
}
|
||||
suspiciousStart := findSuspiciousPrefixStart(s)
|
||||
if suspiciousStart < 0 {
|
||||
return s, ""
|
||||
}
|
||||
if suspiciousStart > 0 {
|
||||
return s[:suspiciousStart], s[suspiciousStart:]
|
||||
}
|
||||
runes := []rune(s)
|
||||
const maxHold = 128
|
||||
if len(runes) <= maxHold {
|
||||
return "", s
|
||||
}
|
||||
return string(runes[:len(runes)-maxHold]), string(runes[len(runes)-maxHold:])
|
||||
}
|
||||
|
||||
func findSuspiciousPrefixStart(s string) int {
|
||||
start := -1
|
||||
indices := []int{
|
||||
strings.LastIndex(s, "{"),
|
||||
strings.LastIndex(s, "["),
|
||||
strings.LastIndex(s, "```"),
|
||||
}
|
||||
for _, idx := range indices {
|
||||
if idx > start {
|
||||
start = idx
|
||||
}
|
||||
}
|
||||
return start
|
||||
}
|
||||
|
||||
func findToolSegmentStart(s string) int {
|
||||
if s == "" {
|
||||
return -1
|
||||
}
|
||||
lower := strings.ToLower(s)
|
||||
keyIdx := strings.Index(lower, "tool_calls")
|
||||
if keyIdx < 0 {
|
||||
return -1
|
||||
}
|
||||
if start := strings.LastIndex(s[:keyIdx], "{"); start >= 0 {
|
||||
return start
|
||||
}
|
||||
return keyIdx
|
||||
}
|
||||
|
||||
func consumeToolCapture(captured string, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
|
||||
if captured == "" {
|
||||
return "", nil, "", false
|
||||
}
|
||||
lower := strings.ToLower(captured)
|
||||
keyIdx := strings.Index(lower, "tool_calls")
|
||||
if keyIdx < 0 {
|
||||
if len([]rune(captured)) >= 256 {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
return "", nil, "", false
|
||||
}
|
||||
start := strings.LastIndex(captured[:keyIdx], "{")
|
||||
if start < 0 {
|
||||
if len([]rune(captured)) >= 512 {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
return "", nil, "", false
|
||||
}
|
||||
obj, end, ok := extractJSONObjectFrom(captured, start)
|
||||
if !ok {
|
||||
if len([]rune(captured)) >= 4096 {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
return "", nil, "", false
|
||||
}
|
||||
parsed := util.ParseToolCalls(obj, toolNames)
|
||||
if len(parsed) == 0 {
|
||||
return captured[:end], nil, captured[end:], true
|
||||
}
|
||||
return captured[:start], parsed, captured[end:], true
|
||||
}
|
||||
|
||||
func extractJSONObjectFrom(text string, start int) (string, int, bool) {
|
||||
if start < 0 || start >= len(text) || text[start] != '{' {
|
||||
return "", 0, false
|
||||
}
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
for i := start; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '{' {
|
||||
depth++
|
||||
continue
|
||||
}
|
||||
if ch == '}' {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
end := i + 1
|
||||
return text[start:end], end, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
277
internal/adapter/openai/vercel_stream.go
Normal file
277
internal/adapter/openai/vercel_stream.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Request) {
|
||||
if !config.IsVercel() {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
h.sweepExpiredStreamLeases()
|
||||
internalSecret := vercelInternalSecret()
|
||||
internalToken := strings.TrimSpace(r.Header.Get("X-Ds2-Internal-Token"))
|
||||
if internalSecret == "" || subtle.ConstantTimeCompare([]byte(internalToken), []byte(internalSecret)) != 1 {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "unauthorized internal request")
|
||||
return
|
||||
}
|
||||
|
||||
a, err := h.Auth.Determine(r)
|
||||
if err != nil {
|
||||
status := http.StatusUnauthorized
|
||||
if err == auth.ErrNoAccount {
|
||||
status = http.StatusTooManyRequests
|
||||
}
|
||||
writeOpenAIError(w, status, err.Error())
|
||||
return
|
||||
}
|
||||
leased := false
|
||||
defer func() {
|
||||
if !leased {
|
||||
h.Auth.Release(a)
|
||||
}
|
||||
}()
|
||||
r = r.WithContext(auth.WithAuth(r.Context(), a))
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
if !toBool(req["stream"]) {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "stream must be true")
|
||||
return
|
||||
}
|
||||
if tools, ok := req["tools"].([]any); ok && len(tools) > 0 {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "tools are not supported by vercel stream prepare")
|
||||
return
|
||||
}
|
||||
|
||||
model, _ := req["model"].(string)
|
||||
messagesRaw, _ := req["messages"].([]any)
|
||||
if model == "" || len(messagesRaw) == 0 {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
|
||||
return
|
||||
}
|
||||
thinkingEnabled, searchEnabled, ok := config.GetModelConfig(model)
|
||||
if !ok {
|
||||
writeOpenAIError(w, http.StatusServiceUnavailable, fmt.Sprintf("Model '%s' is not available.", model))
|
||||
return
|
||||
}
|
||||
|
||||
messages := normalizeMessages(messagesRaw)
|
||||
finalPrompt := util.MessagesPrepare(messages)
|
||||
|
||||
sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
if a.UseConfigToken {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
|
||||
} else {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
|
||||
}
|
||||
return
|
||||
}
|
||||
powHeader, err := h.DS.GetPow(r.Context(), a, 3)
|
||||
if err != nil {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(a.DeepSeekToken) == "" {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
|
||||
return
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
"parent_message_id": nil,
|
||||
"prompt": finalPrompt,
|
||||
"ref_file_ids": []any{},
|
||||
"thinking_enabled": thinkingEnabled,
|
||||
"search_enabled": searchEnabled,
|
||||
}
|
||||
leaseID := h.holdStreamLease(a)
|
||||
if leaseID == "" {
|
||||
writeOpenAIError(w, http.StatusInternalServerError, "failed to create stream lease")
|
||||
return
|
||||
}
|
||||
leased = true
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"session_id": sessionID,
|
||||
"lease_id": leaseID,
|
||||
"model": model,
|
||||
"final_prompt": finalPrompt,
|
||||
"thinking_enabled": thinkingEnabled,
|
||||
"search_enabled": searchEnabled,
|
||||
"deepseek_token": a.DeepSeekToken,
|
||||
"pow_header": powHeader,
|
||||
"payload": payload,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleVercelStreamRelease(w http.ResponseWriter, r *http.Request) {
|
||||
if !config.IsVercel() {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
h.sweepExpiredStreamLeases()
|
||||
internalSecret := vercelInternalSecret()
|
||||
internalToken := strings.TrimSpace(r.Header.Get("X-Ds2-Internal-Token"))
|
||||
if internalSecret == "" || subtle.ConstantTimeCompare([]byte(internalToken), []byte(internalSecret)) != 1 {
|
||||
writeOpenAIError(w, http.StatusUnauthorized, "unauthorized internal request")
|
||||
return
|
||||
}
|
||||
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "invalid json")
|
||||
return
|
||||
}
|
||||
leaseID, _ := req["lease_id"].(string)
|
||||
leaseID = strings.TrimSpace(leaseID)
|
||||
if leaseID == "" {
|
||||
writeOpenAIError(w, http.StatusBadRequest, "lease_id is required")
|
||||
return
|
||||
}
|
||||
if !h.releaseStreamLease(leaseID) {
|
||||
writeOpenAIError(w, http.StatusNotFound, "stream lease not found")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"success": true})
|
||||
}
|
||||
|
||||
func isVercelStreamPrepareRequest(r *http.Request) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(r.URL.Query().Get("__stream_prepare")) == "1"
|
||||
}
|
||||
|
||||
func isVercelStreamReleaseRequest(r *http.Request) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(r.URL.Query().Get("__stream_release")) == "1"
|
||||
}
|
||||
|
||||
func vercelInternalSecret() string {
|
||||
if v := strings.TrimSpace(os.Getenv("DS2API_VERCEL_INTERNAL_SECRET")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")); v != "" {
|
||||
return v
|
||||
}
|
||||
return "admin"
|
||||
}
|
||||
|
||||
func (h *Handler) holdStreamLease(a *auth.RequestAuth) string {
|
||||
if a == nil {
|
||||
return ""
|
||||
}
|
||||
now := time.Now()
|
||||
ttl := streamLeaseTTL()
|
||||
if ttl <= 0 {
|
||||
ttl = 15 * time.Minute
|
||||
}
|
||||
|
||||
h.leaseMu.Lock()
|
||||
expired := h.popExpiredLeasesLocked(now)
|
||||
if h.streamLeases == nil {
|
||||
h.streamLeases = make(map[string]streamLease)
|
||||
}
|
||||
leaseID := newLeaseID()
|
||||
h.streamLeases[leaseID] = streamLease{
|
||||
Auth: a,
|
||||
ExpiresAt: now.Add(ttl),
|
||||
}
|
||||
h.leaseMu.Unlock()
|
||||
h.releaseExpiredAuths(expired)
|
||||
return leaseID
|
||||
}
|
||||
|
||||
func (h *Handler) releaseStreamLease(leaseID string) bool {
|
||||
leaseID = strings.TrimSpace(leaseID)
|
||||
if leaseID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
h.leaseMu.Lock()
|
||||
expired := h.popExpiredLeasesLocked(time.Now())
|
||||
lease, ok := h.streamLeases[leaseID]
|
||||
if ok {
|
||||
delete(h.streamLeases, leaseID)
|
||||
}
|
||||
h.leaseMu.Unlock()
|
||||
h.releaseExpiredAuths(expired)
|
||||
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if h.Auth != nil {
|
||||
h.Auth.Release(lease.Auth)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *Handler) popExpiredLeasesLocked(now time.Time) []*auth.RequestAuth {
|
||||
if len(h.streamLeases) == 0 {
|
||||
return nil
|
||||
}
|
||||
expired := make([]*auth.RequestAuth, 0)
|
||||
for leaseID, lease := range h.streamLeases {
|
||||
if now.After(lease.ExpiresAt) {
|
||||
delete(h.streamLeases, leaseID)
|
||||
expired = append(expired, lease.Auth)
|
||||
}
|
||||
}
|
||||
return expired
|
||||
}
|
||||
|
||||
func (h *Handler) releaseExpiredAuths(expired []*auth.RequestAuth) {
|
||||
if h.Auth == nil || len(expired) == 0 {
|
||||
return
|
||||
}
|
||||
for _, a := range expired {
|
||||
h.Auth.Release(a)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) sweepExpiredStreamLeases() {
|
||||
h.leaseMu.Lock()
|
||||
expired := h.popExpiredLeasesLocked(time.Now())
|
||||
h.leaseMu.Unlock()
|
||||
h.releaseExpiredAuths(expired)
|
||||
}
|
||||
|
||||
func streamLeaseTTL() time.Duration {
|
||||
raw := strings.TrimSpace(os.Getenv("DS2API_VERCEL_STREAM_LEASE_TTL_SECONDS"))
|
||||
if raw == "" {
|
||||
return 15 * time.Minute
|
||||
}
|
||||
seconds, err := strconv.Atoi(raw)
|
||||
if err != nil || seconds <= 0 {
|
||||
return 15 * time.Minute
|
||||
}
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
func newLeaseID() string {
|
||||
buf := make([]byte, 16)
|
||||
if _, err := rand.Read(buf); err == nil {
|
||||
return hex.EncodeToString(buf)
|
||||
}
|
||||
return fmt.Sprintf("lease-%d", time.Now().UnixNano())
|
||||
}
|
||||
Reference in New Issue
Block a user