feat: add stream parsing for DeepSeek SSE responses, centralize model management, and refactor tool call handling.

This commit is contained in:
CJACK
2026-02-01 03:53:01 +08:00
parent 4193336dd8
commit b90901d5a0
10 changed files with 359 additions and 192 deletions

View File

@@ -5,6 +5,7 @@ from fastapi import HTTPException, Request
from .config import CONFIG, logger
from .deepseek import login_deepseek_via_account, BASE_HEADERS
from .utils import get_account_identifier
# -------------------------- 全局账号队列 --------------------------
# 使用列表实现轮询队列,配合线程锁保证并发安全
@@ -37,12 +38,7 @@ init_account_queue()
init_claude_api_key_queue()
# ----------------------------------------------------------------------
# 辅助函数:获取账号唯一标识(优先 email否则 mobile
# ----------------------------------------------------------------------
def get_account_identifier(account: dict) -> str:
"""返回账号的唯一标识,优先使用 email否则使用 mobile"""
return account.get("email", "").strip() or account.get("mobile", "").strip()
# get_account_identifier 已移至 core.utils
def get_queue_status() -> dict:
@@ -176,12 +172,7 @@ def get_auth_headers(request: Request) -> dict:
return {**BASE_HEADERS, "authorization": f"Bearer {request.state.deepseek_token}"}
# ----------------------------------------------------------------------
# Claude 认证相关函数
# ----------------------------------------------------------------------
def determine_claude_mode_and_token(request: Request):
"""Claude认证沿用现有的OpenAI接口认证逻辑"""
determine_mode_and_token(request)
# determine_claude_mode_and_token 已移除(直接使用 determine_mode_and_token
# ----------------------------------------------------------------------

View File

@@ -5,6 +5,7 @@ from curl_cffi import requests
from fastapi import HTTPException
from .config import CONFIG, save_config, logger
from .utils import get_account_identifier
# ----------------------------------------------------------------------
# DeepSeek 相关常量
@@ -28,9 +29,7 @@ BASE_HEADERS = {
}
def get_account_identifier(account: dict) -> str:
"""返回账号的唯一标识,优先使用 email否则使用 mobile"""
return account.get("email", "").strip() or account.get("mobile", "").strip()
# get_account_identifier 已移至 core.utils
# ----------------------------------------------------------------------

90
core/models.py Normal file
View File

@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
"""模型定义模块 - 集中管理所有支持的模型"""
# DeepSeek model list (official model names).
# Each entry is a model descriptor dict (id / object / created / owned_by /
# permission); get_openai_models_response() returns this list as the "data"
# field of its response.
DEEPSEEK_MODELS = [
{
"id": "deepseek-chat",
"object": "model",
"created": 1677610602,
"owned_by": "deepseek",
"permission": [],
},
{
"id": "deepseek-reasoner",
"object": "model",
"created": 1677610602,
"owned_by": "deepseek",
"permission": [],
},
{
"id": "deepseek-chat-search",
"object": "model",
"created": 1677610602,
"owned_by": "deepseek",
"permission": [],
},
{
"id": "deepseek-reasoner-search",
"object": "model",
"created": 1677610602,
"owned_by": "deepseek",
"permission": [],
},
]
# Claude model mapping list.
# These are the model ids advertised by the Claude-compatible models endpoint;
# get_claude_models_response() returns this list as the "data" field.
CLAUDE_MODELS = [
{
"id": "claude-sonnet-4-20250514",
"object": "model",
"created": 1715635200,
"owned_by": "anthropic",
},
{
"id": "claude-sonnet-4-20250514-fast",
"object": "model",
"created": 1715635200,
"owned_by": "anthropic",
},
{
"id": "claude-sonnet-4-20250514-slow",
"object": "model",
"created": 1715635200,
"owned_by": "anthropic",
},
]
def get_model_config(model: str) -> tuple[bool, bool]:
"""根据模型名称获取配置
Args:
model: 模型名称
Returns:
(thinking_enabled, search_enabled) 元组
"""
model_lower = model.lower()
if model_lower == "deepseek-chat":
return False, False
elif model_lower == "deepseek-reasoner":
return True, False
elif model_lower == "deepseek-chat-search":
return False, True
elif model_lower == "deepseek-reasoner-search":
return True, True
else:
return None, None # 不支持的模型
def get_openai_models_response() -> dict:
    """Build the OpenAI-format model list response.

    Returns:
        A dict with ``object`` set to ``"list"`` and ``data`` set to the
        module-level ``DEEPSEEK_MODELS`` list (the same list object, not a copy).
    """
    payload = dict(object="list", data=DEEPSEEK_MODELS)
    return payload
def get_claude_models_response() -> dict:
    """Build the Claude-format model list response.

    Returns:
        A dict with ``object`` set to ``"list"`` and ``data`` set to the
        module-level ``CLAUDE_MODELS`` list (the same list object, not a copy).
    """
    payload = dict(object="list", data=CLAUDE_MODELS)
    return payload

View File

@@ -11,6 +11,7 @@ from curl_cffi import requests
from wasmtime import Engine, Linker, Module, Store
from .config import CONFIG, WASM_PATH, logger
from .utils import get_account_identifier
# ----------------------------------------------------------------------
# WASM 模块缓存 - 避免每次请求都重新加载
@@ -51,10 +52,7 @@ try:
except Exception as e:
logger.warning(f"[WASM] 启动时预加载失败(将在首次使用时重试): {e}")
def get_account_identifier(account: dict) -> str:
"""返回账号的唯一标识"""
return account.get("email", "").strip() or account.get("mobile", "").strip()
# get_account_identifier 已移至 core.utils
# ----------------------------------------------------------------------
@@ -152,17 +150,24 @@ def compute_pow_answer(
return int(value)
# ----------------------------------------------------------------------
# 获取 PoW 响应,融合计算 answer 逻辑
# ----------------------------------------------------------------------
def get_pow_response(request, get_auth_headers_func, choose_new_account_func,
login_func, pow_url: str, max_attempts: int = 3):
"""获取 PoW 响应"""
from .deepseek import BASE_HEADERS
def get_pow_response(request, max_attempts: int = 3):
"""获取 PoW 响应
Args:
request: FastAPI 请求对象
max_attempts: 最大重试次数
Returns:
Base64 编码的 PoW 响应,如果失败返回 None
"""
from .auth import get_auth_headers, choose_new_account
from .deepseek import BASE_HEADERS, login_deepseek_via_account, DEEPSEEK_CREATE_POW_URL
pow_url = DEEPSEEK_CREATE_POW_URL
attempts = 0
while attempts < max_attempts:
headers = get_auth_headers_func(request)
headers = get_auth_headers(request)
try:
resp = requests.post(
pow_url,
@@ -227,11 +232,11 @@ def get_pow_response(request, get_auth_headers_func, choose_new_account_func,
request.state.tried_accounts = []
if current_id not in request.state.tried_accounts:
request.state.tried_accounts.append(current_id)
new_account = choose_new_account_func(request.state.tried_accounts)
new_account = choose_new_account(request.state.tried_accounts)
if new_account is None:
break
try:
login_func(new_account)
login_deepseek_via_account(new_account)
except Exception as e:
logger.error(
f"[get_pow_response] 账号 {get_account_identifier(new_account)} 登录失败:{e}"
@@ -245,3 +250,4 @@ def get_pow_response(request, get_auth_headers_func, choose_new_account_func,
continue
attempts += 1
return None

View File

@@ -4,10 +4,11 @@ from curl_cffi import requests as cffi_requests
from fastapi import HTTPException, Request
from .config import logger
from .utils import get_account_identifier
from .models import get_model_config
from .auth import (
get_auth_headers,
choose_new_account,
get_account_identifier,
release_account,
refresh_account_token,
)
@@ -114,14 +115,7 @@ def get_pow(request: Request, max_attempts: int = 3) -> str | None:
Returns:
Base64 编码的 PoW 响应,如果失败返回 None
"""
return get_pow_response(
request,
get_auth_headers,
choose_new_account,
login_deepseek_via_account,
DEEPSEEK_CREATE_POW_URL,
max_attempts,
)
return get_pow_response(request, max_attempts)
def prepare_completion_request(
@@ -162,27 +156,7 @@ def prepare_completion_request(
return call_completion_endpoint(payload, headers, max_attempts)
def get_model_config(model: str) -> tuple[bool, bool]:
"""根据模型名称获取配置
Args:
model: 模型名称
Returns:
(thinking_enabled, search_enabled) 元组
"""
model_lower = model.lower()
if model_lower in ["deepseek-v3", "deepseek-chat"]:
return False, False
elif model_lower in ["deepseek-r1", "deepseek-reasoner"]:
return True, False
elif model_lower in ["deepseek-v3-search", "deepseek-chat-search"]:
return False, True
elif model_lower in ["deepseek-r1-search", "deepseek-reasoner-search"]:
return True, True
else:
return None, None # 不支持的模型
# get_model_config 已移至 core.models
def cleanup_account(request: Request):

186
core/stream_parser.py Normal file
View File

@@ -0,0 +1,186 @@
# -*- coding: utf-8 -*-
"""流解析模块 - 处理 DeepSeek SSE 流响应"""
import json
import re
from .config import logger
# Precompiled regular expressions.
# _TOOL_CALL_PATTERN: matches an object of the form {"tool_calls": [ ... ]}
# (key in single or double quotes). Group 1 captures the array contents
# non-greedily; DOTALL lets the match span multiple lines.
_TOOL_CALL_PATTERN = re.compile(r'\{\s*["\']tool_calls["\']\s*:\s*\[(.*?)\]\s*\}', re.DOTALL)
# _CITATION_PATTERN: matches text that begins with the literal "[citation:".
_CITATION_PATTERN = re.compile(r"^\[citation:")
def parse_deepseek_sse_line(raw_line: bytes) -> dict | None:
"""解析 DeepSeek SSE 行
Args:
raw_line: 原始字节行
Returns:
解析后的 chunk 字典,如果解析失败或应跳过则返回 None
"""
try:
line = raw_line.decode("utf-8")
except Exception as e:
logger.warning(f"[parse_deepseek_sse_line] 解码失败: {e}")
return None
if not line or not line.startswith("data:"):
return None
data_str = line[5:].strip()
if data_str == "[DONE]":
return {"type": "done"}
try:
chunk = json.loads(data_str)
return chunk
except json.JSONDecodeError as e:
logger.warning(f"[parse_deepseek_sse_line] JSON解析失败: {e}")
return None
def extract_content_from_chunk(chunk: dict) -> tuple[str, str, bool]:
    """Extract the content carried by one parsed DeepSeek chunk.

    Args:
        chunk: A chunk dict as returned by the SSE line parser.

    Returns:
        A ``(content, content_type, is_finished)`` tuple where
        ``content_type`` is ``"thinking"`` or ``"text"`` and
        ``is_finished`` is ``True`` when the response has ended
        (done sentinel, error / content filter, or a FINISHED status).
    """
    if chunk.get("type") == "done":
        return "", "text", True

    # Moderation / sensitive-content block: treat it as the end of the response.
    if "error" in chunk or chunk.get("code") == "content_filter":
        logger.warning(f"[extract_content_from_chunk] 检测到内容过滤: {chunk}")
        return "", "text", True

    if "v" not in chunk:
        return "", "text", False

    # The path field decides which kind of content this chunk carries.
    path = chunk.get("p", "")
    if path == "response/search_status":
        return "", "text", False  # search status updates are not content
    ptype = "thinking" if path == "response/thinking_content" else "text"

    v_value = chunk["v"]
    if isinstance(v_value, str):
        # NOTE(review): a bare "FINISHED" string ends the response whatever
        # the path is; confirm a content token can never equal this value.
        if v_value == "FINISHED":
            return "", ptype, True
        return v_value, ptype, False

    if isinstance(v_value, list):
        # Only dict items can carry a status; skip anything else instead of
        # raising AttributeError on item.get.
        finished = any(
            isinstance(item, dict)
            and item.get("p") == "status"
            and item.get("v") == "FINISHED"
            for item in v_value
        )
        return "", ptype, finished

    return "", ptype, False
def collect_deepseek_response(response) -> tuple[str, str]:
    """Drain a DeepSeek streaming response and return its full content.

    Iterates ``response.iter_lines()``, parses each line, and accumulates
    the non-empty content pieces until a chunk reports that the response is
    finished. Any exception raised while reading is logged and whatever was
    collected so far is returned. The response is always closed, and errors
    from ``close()`` are ignored.

    Args:
        response: DeepSeek streaming response object.

    Returns:
        A ``(reasoning_content, text_content)`` tuple of joined strings.
    """
    buckets = {"thinking": [], "text": []}

    try:
        for line_bytes in response.iter_lines():
            parsed = parse_deepseek_sse_line(line_bytes)
            if not parsed:
                continue

            piece, kind, done = extract_content_from_chunk(parsed)
            if done:
                break
            if not piece:
                continue

            # Anything that is not explicitly "thinking" counts as text.
            target = "thinking" if kind == "thinking" else "text"
            buckets[target].append(piece)
    except Exception as e:
        logger.error(f"[collect_deepseek_response] 收集响应失败: {e}")
    finally:
        # Best-effort close: a failure here must not mask the collected data.
        try:
            response.close()
        except Exception:
            pass

    reasoning = "".join(buckets["thinking"])
    answer = "".join(buckets["text"])
    return reasoning, answer
def _select_requested_calls(tool_calls, allowed_names: set) -> list[dict]:
    """Keep only well-formed tool calls whose name was offered in the request."""
    if not isinstance(tool_calls, list):
        return []
    selected = []
    for call in tool_calls:
        if not isinstance(call, dict):
            continue  # model emitted something that is not a call object
        name = call.get("name")
        if isinstance(name, str) and name in allowed_names:
            selected.append({"name": name, "input": call.get("input", {})})
    return selected


def parse_tool_calls(text: str, tools_requested: list) -> list[dict]:
    """Parse tool calls out of a model response.

    The whole response is tried first as a ``{"tool_calls": [...]}`` JSON
    document. If that yields nothing, every ``{"tool_calls": [...]}``
    fragment embedded in the text is tried in turn. Malformed entries
    (non-dict calls, a non-list ``tool_calls`` value, invalid JSON) are
    skipped instead of raising.

    Args:
        text: Response text produced by the model.
        tools_requested: Tool definitions from the request; only calls whose
            ``name`` matches one of these are returned. ``None`` is treated
            as an empty list.

    Returns:
        A list of detected tool calls, each a dict with ``name`` and ``input``.
    """
    allowed_names = {
        tool["name"]
        for tool in (tools_requested or [])
        if isinstance(tool, dict) and isinstance(tool.get("name"), str)
    }
    if not allowed_names or not isinstance(text, str):
        return []

    cleaned_text = text.strip()

    # Fast path: the whole response is the tool-call JSON document.
    if cleaned_text.startswith('{"tool_calls":') and cleaned_text.endswith("]}"):
        try:
            tool_data = json.loads(cleaned_text)
        except json.JSONDecodeError:
            tool_data = None
        if isinstance(tool_data, dict):
            detected = _select_requested_calls(tool_data.get("tool_calls", []), allowed_names)
            if detected:
                return detected

    # Fallback: look for tool-call fragments embedded in surrounding text.
    detected_tools = []
    for match in _TOOL_CALL_PATTERN.findall(cleaned_text):
        try:
            tool_data = json.loads(f'{{"tool_calls": [{match}]}}')
        except json.JSONDecodeError:
            continue
        detected_tools.extend(
            _select_requested_calls(tool_data.get("tool_calls", []), allowed_names)
        )
    return detected_tools
def should_filter_citation(text: str, search_enabled: bool) -> bool:
    """Tell whether a content piece is a citation marker to be dropped.

    Args:
        text: Content text of the current chunk.
        search_enabled: Whether search is enabled for this request.

    Returns:
        ``True`` only when search is enabled and ``text`` starts with
        ``"[citation:"``. Non-string content is never filtered.
    """
    if not search_enabled or not isinstance(text, str):
        return False
    # A plain prefix test is equivalent to matching r"^\[citation:" at the
    # start of the string, and needs no regex.
    return text.startswith("[citation:")

29
core/utils.py Normal file
View File

@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
"""公共工具函数模块"""
def get_account_identifier(account: dict) -> str:
    """Return the unique identifier of an account.

    The stripped ``email`` is preferred; when it is missing, null or blank
    the stripped ``mobile`` is used instead.

    Args:
        account: Account dict, expected to carry ``email`` and/or ``mobile``.

    Returns:
        The identifier string, or ``""`` when neither field has a value.
    """
    for key in ("email", "mobile"):
        value = account.get(key)
        if value is None:
            # Key absent or explicitly null; ``.get(key, "")`` would hand
            # back None for the latter and crash on ``.strip()``.
            continue
        # str() tolerates numeric values such as an unquoted mobile number.
        identifier = str(value).strip()
        if identifier:
            return identifier
    return ""
def estimate_tokens(text) -> int:
    """Roughly estimate how many tokens a value occupies (characters // 4).

    Args:
        text: A string, a list of content items, or any other value.

    Returns:
        * For a list: the sum of the estimates of its items, where a dict
          item contributes its ``"text"`` value (default ``""``) and any
          other item contributes its ``str()`` form. An empty list gives 0.
        * For anything else: ``len // 4`` of the string (or of ``str(text)``),
          never less than 1.
    """
    if isinstance(text, list):
        total = 0
        for entry in text:
            payload = entry.get("text", "") if isinstance(entry, dict) else str(entry)
            total += estimate_tokens(payload)
        return total

    as_string = text if isinstance(text, str) else str(text)
    return max(1, len(as_string) // 4)

View File

@@ -2,7 +2,6 @@
"""Claude API 路由"""
import json
import random
import re
import time
from curl_cffi import requests as cffi_requests
@@ -11,16 +10,23 @@ from fastapi.responses import JSONResponse, StreamingResponse
from core.config import CONFIG, logger
from core.auth import (
determine_claude_mode_and_token,
determine_mode_and_token,
get_auth_headers,
)
from core.deepseek import call_completion_endpoint
from core.session_manager import (
create_session,
get_pow,
get_model_config,
cleanup_account,
)
from core.models import get_model_config, get_claude_models_response
from core.stream_parser import (
parse_deepseek_sse_line,
extract_content_from_chunk,
collect_deepseek_response,
parse_tool_calls,
)
from core.utils import estimate_tokens
from core.messages import (
messages_prepare,
convert_claude_to_deepseek,
@@ -29,9 +35,6 @@ from core.messages import (
router = APIRouter()
# 预编译正则表达式(性能优化)
_TOOL_CALL_PATTERN = re.compile(r'\{\s*["\']tool_calls["\']\s*:\s*\[(.*?)\]\s*\}', re.DOTALL)
# ----------------------------------------------------------------------
# 通过 OpenAI 接口调用 Claude
@@ -87,27 +90,7 @@ async def call_claude_via_openai(request: Request, claude_payload: dict):
# ----------------------------------------------------------------------
@router.get("/anthropic/v1/models")
def list_claude_models():
models_list = [
{
"id": "claude-sonnet-4-20250514",
"object": "model",
"created": 1715635200,
"owned_by": "anthropic",
},
{
"id": "claude-sonnet-4-20250514-fast",
"object": "model",
"created": 1715635200,
"owned_by": "anthropic",
},
{
"id": "claude-sonnet-4-20250514-slow",
"object": "model",
"created": 1715635200,
"owned_by": "anthropic",
},
]
data = {"object": "list", "data": models_list}
data = get_claude_models_response()
return JSONResponse(content=data, status_code=200)
@@ -118,13 +101,13 @@ def list_claude_models():
async def claude_messages(request: Request):
try:
try:
determine_claude_mode_and_token(request)
determine_mode_and_token(request)
except HTTPException as exc:
return JSONResponse(
status_code=exc.status_code, content={"error": exc.detail}
)
except Exception as exc:
logger.error(f"[claude_messages] determine_claude_mode_and_token 异常: {exc}")
logger.error(f"[claude_messages] determine_mode_and_token 异常: {exc}")
return JSONResponse(
status_code=500, content={"error": "Claude authentication failed."}
)
@@ -290,34 +273,8 @@ Remember: Output ONLY the JSON, no other text. The response must start with {{ a
yield f"data: {json.dumps(message_start)}\n\n"
# 检查工具调用
detected_tools = []
cleaned_response = full_response_text.strip()
if cleaned_response.startswith('{"tool_calls":') and cleaned_response.endswith("]}"):
try:
tool_data = json.loads(cleaned_response)
for tool_call in tool_data.get("tool_calls", []):
tool_name = tool_call.get("name")
tool_input = tool_call.get("input", {})
if any(tool.get("name") == tool_name for tool in tools_requested):
detected_tools.append({"name": tool_name, "input": tool_input})
except json.JSONDecodeError:
pass
if not detected_tools:
# 使用预编译的正则表达式
matches = _TOOL_CALL_PATTERN.findall(cleaned_response)
for match in matches:
try:
tool_calls_json = f'{{"tool_calls": [{match}]}}'
tool_data = json.loads(tool_calls_json)
for tool_call in tool_data.get("tool_calls", []):
tool_name = tool_call.get("name")
tool_input = tool_call.get("input", {})
if any(tool.get("name") == tool_name for tool in tools_requested):
detected_tools.append({"name": tool_name, "input": tool_input})
except json.JSONDecodeError:
continue
# 使用公共函数检测工具调用
detected_tools = parse_tool_calls(full_response_text, tools_requested)
content_index = 0
if detected_tools:
@@ -415,34 +372,7 @@ Remember: Output ONLY the JSON, no other text. The response must start with {{ a
logger.warning(f"[claude_messages] 关闭响应异常: {e}")
# 检查工具调用
detected_tools = []
cleaned_content = final_content.strip()
if cleaned_content.startswith('{"tool_calls":') and cleaned_content.endswith("]}"):
try:
tool_data = json.loads(cleaned_content)
for tool_call in tool_data.get("tool_calls", []):
tool_name = tool_call.get("name")
tool_input = tool_call.get("input", {})
if any(tool.get("name") == tool_name for tool in tools_requested):
detected_tools.append({"name": tool_name, "input": tool_input})
except json.JSONDecodeError:
pass
if not detected_tools:
# 使用预编译的正则表达式
matches = _TOOL_CALL_PATTERN.findall(cleaned_content)
for match in matches:
try:
tool_calls_json = f'{{"tool_calls": [{match}]}}'
tool_data = json.loads(tool_calls_json)
for tool_call in tool_data.get("tool_calls", []):
tool_name = tool_call.get("name")
tool_input = tool_call.get("input", {})
if any(tool.get("name") == tool_name for tool in tools_requested):
detected_tools.append({"name": tool_name, "input": tool_input})
except json.JSONDecodeError:
continue
detected_tools = parse_tool_calls(final_content, tools_requested)
# 构造响应
claude_response = {
@@ -513,11 +443,11 @@ Remember: Output ONLY the JSON, no other text. The response must start with {{ a
async def claude_count_tokens(request: Request):
try:
try:
determine_claude_mode_and_token(request)
determine_mode_and_token(request)
except HTTPException as exc:
return JSONResponse(status_code=exc.status_code, content={"error": exc.detail})
except Exception as exc:
logger.error(f"[claude_count_tokens] determine_claude_mode_and_token 异常: {exc}")
logger.error(f"[claude_count_tokens] determine_mode_and_token 异常: {exc}")
return JSONResponse(status_code=500, content={"error": "Claude authentication failed."})
req_data = await request.json()
@@ -530,19 +460,6 @@ async def claude_count_tokens(request: Request):
status_code=400, detail="Request must include 'model' and 'messages'."
)
def estimate_tokens(text):
if isinstance(text, str):
return len(text) // 4
elif isinstance(text, list):
return sum(
estimate_tokens(item.get("text", ""))
if isinstance(item, dict)
else estimate_tokens(str(item))
for item in text
)
else:
return len(str(text)) // 4
input_tokens = 0
if system:

View File

@@ -20,9 +20,14 @@ from core.deepseek import call_completion_endpoint
from core.session_manager import (
create_session,
get_pow,
get_model_config,
cleanup_account,
)
from core.models import get_model_config, get_openai_models_response
from core.stream_parser import (
parse_deepseek_sse_line,
extract_content_from_chunk,
should_filter_citation,
)
from core.messages import messages_prepare
router = APIRouter()
@@ -39,37 +44,7 @@ _CITATION_PATTERN = re.compile(r"^\[citation:")
# ----------------------------------------------------------------------
@router.get("/v1/models")
def list_models():
models_list = [
{
"id": "deepseek-chat",
"object": "model",
"created": 1677610602,
"owned_by": "deepseek",
"permission": [],
},
{
"id": "deepseek-reasoner",
"object": "model",
"created": 1677610602,
"owned_by": "deepseek",
"permission": [],
},
{
"id": "deepseek-chat-search",
"object": "model",
"created": 1677610602,
"owned_by": "deepseek",
"permission": [],
},
{
"id": "deepseek-reasoner-search",
"object": "model",
"created": 1677610602,
"owned_by": "deepseek",
"permission": [],
},
]
data = {"object": "list", "data": models_list}
data = get_openai_models_response()
return JSONResponse(content=data, status_code=200)

View File

@@ -158,7 +158,7 @@ class TestPow(unittest.TestCase):
def test_get_account_identifier(self):
"""测试账号标识获取"""
from core.pow import get_account_identifier
from core.utils import get_account_identifier
# 测试邮箱
account1 = {"email": "test@example.com"}