diff --git a/core/auth.py b/core/auth.py index 707e94d..cddeb1a 100644 --- a/core/auth.py +++ b/core/auth.py @@ -5,6 +5,7 @@ from fastapi import HTTPException, Request from .config import CONFIG, logger from .deepseek import login_deepseek_via_account, BASE_HEADERS +from .utils import get_account_identifier # -------------------------- 全局账号队列 -------------------------- # 使用列表实现轮询队列,配合线程锁保证并发安全 @@ -37,12 +38,7 @@ init_account_queue() init_claude_api_key_queue() -# ---------------------------------------------------------------------- -# 辅助函数:获取账号唯一标识(优先 email,否则 mobile) -# ---------------------------------------------------------------------- -def get_account_identifier(account: dict) -> str: - """返回账号的唯一标识,优先使用 email,否则使用 mobile""" - return account.get("email", "").strip() or account.get("mobile", "").strip() +# get_account_identifier 已移至 core.utils def get_queue_status() -> dict: @@ -176,12 +172,7 @@ def get_auth_headers(request: Request) -> dict: return {**BASE_HEADERS, "authorization": f"Bearer {request.state.deepseek_token}"} -# ---------------------------------------------------------------------- -# Claude 认证相关函数 -# ---------------------------------------------------------------------- -def determine_claude_mode_and_token(request: Request): - """Claude认证:沿用现有的OpenAI接口认证逻辑""" - determine_mode_and_token(request) +# determine_claude_mode_and_token 已移除(直接使用 determine_mode_and_token) # ---------------------------------------------------------------------- diff --git a/core/deepseek.py b/core/deepseek.py index 9ffb62e..041442d 100644 --- a/core/deepseek.py +++ b/core/deepseek.py @@ -5,6 +5,7 @@ from curl_cffi import requests from fastapi import HTTPException from .config import CONFIG, save_config, logger +from .utils import get_account_identifier # ---------------------------------------------------------------------- # DeepSeek 相关常量 @@ -28,9 +29,7 @@ BASE_HEADERS = { } -def get_account_identifier(account: dict) -> str: - """返回账号的唯一标识,优先使用 email,否则使用 mobile""" - return account.get("email", "").strip() or account.get("mobile", "").strip() +# get_account_identifier 已移至 core.utils # ---------------------------------------------------------------------- diff --git a/core/models.py b/core/models.py new file mode 100644 index 0000000..f1dddaa --- /dev/null +++ b/core/models.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- +"""模型定义模块 - 集中管理所有支持的模型""" + +# DeepSeek 模型列表(官方模型名称) +DEEPSEEK_MODELS = [ + { + "id": "deepseek-chat", + "object": "model", + "created": 1677610602, + "owned_by": "deepseek", + "permission": [], + }, + { + "id": "deepseek-reasoner", + "object": "model", + "created": 1677610602, + "owned_by": "deepseek", + "permission": [], + }, + { + "id": "deepseek-chat-search", + "object": "model", + "created": 1677610602, + "owned_by": "deepseek", + "permission": [], + }, + { + "id": "deepseek-reasoner-search", + "object": "model", + "created": 1677610602, + "owned_by": "deepseek", + "permission": [], + }, +] + +# Claude 模型映射列表 +CLAUDE_MODELS = [ + { + "id": "claude-sonnet-4-20250514", + "object": "model", + "created": 1715635200, + "owned_by": "anthropic", + }, + { + "id": "claude-sonnet-4-20250514-fast", + "object": "model", + "created": 1715635200, + "owned_by": "anthropic", + }, + { + "id": "claude-sonnet-4-20250514-slow", + "object": "model", + "created": 1715635200, + "owned_by": "anthropic", + }, +] + + +def get_model_config(model: str) -> tuple[bool, bool]: + """根据模型名称获取配置 + + Args: + model: 模型名称 + + Returns: + (thinking_enabled, search_enabled) 元组 + """ + model_lower = model.lower() + + if model_lower == "deepseek-chat": + return False, False + elif model_lower == "deepseek-reasoner": + return True, False + elif model_lower == "deepseek-chat-search": + return False, True + elif model_lower == "deepseek-reasoner-search": + return True, True + else: + return None, None # 不支持的模型 + + +def get_openai_models_response() -> dict: + """获取 OpenAI 格式的模型列表响应""" + return {"object": "list", "data": DEEPSEEK_MODELS} + + +def get_claude_models_response() -> dict: + """获取 Claude 格式的模型列表响应""" + return {"object": "list", "data": CLAUDE_MODELS} + diff --git a/core/pow.py b/core/pow.py index 84d1f22..d8d1a72 100644 --- a/core/pow.py +++ b/core/pow.py @@ -11,6 +11,7 @@ from curl_cffi import requests from wasmtime import Engine, Linker, Module, Store from .config import CONFIG, WASM_PATH, logger +from .utils import get_account_identifier # ---------------------------------------------------------------------- # WASM 模块缓存 - 避免每次请求都重新加载 @@ -51,10 +52,7 @@ try: except Exception as e: logger.warning(f"[WASM] 启动时预加载失败(将在首次使用时重试): {e}") - -def get_account_identifier(account: dict) -> str: - """返回账号的唯一标识""" - return account.get("email", "").strip() or account.get("mobile", "").strip() +# get_account_identifier 已移至 core.utils # ---------------------------------------------------------------------- @@ -152,17 +150,24 @@ def compute_pow_answer( return int(value) -# ---------------------------------------------------------------------- -# 获取 PoW 响应,融合计算 answer 逻辑 -# ---------------------------------------------------------------------- -def get_pow_response(request, get_auth_headers_func, choose_new_account_func, - login_func, pow_url: str, max_attempts: int = 3): - """获取 PoW 响应""" - from .deepseek import BASE_HEADERS +def get_pow_response(request, max_attempts: int = 3): + """获取 PoW 响应 + + Args: + request: FastAPI 请求对象 + max_attempts: 最大重试次数 + + Returns: + Base64 编码的 PoW 响应,如果失败返回 None + """ + from .auth import get_auth_headers, choose_new_account + from .deepseek import BASE_HEADERS, login_deepseek_via_account, DEEPSEEK_CREATE_POW_URL + + pow_url = DEEPSEEK_CREATE_POW_URL attempts = 0 while attempts < max_attempts: - headers = get_auth_headers_func(request) + headers = get_auth_headers(request) try: resp = requests.post( pow_url, @@ -227,11 +232,11 @@ def get_pow_response(request, get_auth_headers_func, choose_new_account_func, request.state.tried_accounts = [] if current_id not in request.state.tried_accounts: request.state.tried_accounts.append(current_id) - new_account = choose_new_account_func(request.state.tried_accounts) + new_account = choose_new_account(request.state.tried_accounts) if new_account is None: break try: - login_func(new_account) + login_deepseek_via_account(new_account) except Exception as e: logger.error( f"[get_pow_response] 账号 {get_account_identifier(new_account)} 登录失败:{e}" @@ -245,3 +250,4 @@ def get_pow_response(request, get_auth_headers_func, choose_new_account_func, continue attempts += 1 return None + diff --git a/core/session_manager.py b/core/session_manager.py index 8849042..ca87816 100644 --- a/core/session_manager.py +++ b/core/session_manager.py @@ -4,10 +4,11 @@ from curl_cffi import requests as cffi_requests from fastapi import HTTPException, Request from .config import logger +from .utils import get_account_identifier +from .models import get_model_config from .auth import ( get_auth_headers, choose_new_account, - get_account_identifier, release_account, refresh_account_token, ) @@ -114,14 +115,7 @@ def get_pow(request: Request, max_attempts: int = 3) -> str | None: Returns: Base64 编码的 PoW 响应,如果失败返回 None """ - return get_pow_response( - request, - get_auth_headers, - choose_new_account, - login_deepseek_via_account, - DEEPSEEK_CREATE_POW_URL, - max_attempts, - ) + return get_pow_response(request, max_attempts) def prepare_completion_request( @@ -162,27 +156,7 @@ def prepare_completion_request( return call_completion_endpoint(payload, headers, max_attempts) -def get_model_config(model: str) -> tuple[bool, bool]: - """根据模型名称获取配置 - - Args: - model: 模型名称 - - Returns: - (thinking_enabled, search_enabled) 元组 - """ - model_lower = model.lower() - - if model_lower in ["deepseek-v3", "deepseek-chat"]: - return False, False - elif model_lower in ["deepseek-r1", "deepseek-reasoner"]: - return True, False - elif model_lower in ["deepseek-v3-search", "deepseek-chat-search"]: - return False, True - elif model_lower in ["deepseek-r1-search", "deepseek-reasoner-search"]: - return True, True - else: - return None, None # 不支持的模型 +# get_model_config 已移至 core.models def cleanup_account(request: Request): diff --git a/core/stream_parser.py b/core/stream_parser.py new file mode 100644 index 0000000..2e2229b --- /dev/null +++ b/core/stream_parser.py @@ -0,0 +1,186 @@ +# -*- coding: utf-8 -*- +"""流解析模块 - 处理 DeepSeek SSE 流响应""" +import json +import re + +from .config import logger + +# 预编译正则表达式 +_TOOL_CALL_PATTERN = re.compile(r'\{\s*["\']tool_calls["\']\s*:\s*\[(.*?)\]\s*\}', re.DOTALL) +_CITATION_PATTERN = re.compile(r"^\[citation:") + + +def parse_deepseek_sse_line(raw_line: bytes) -> dict | None: + """解析 DeepSeek SSE 行 + + Args: + raw_line: 原始字节行 + + Returns: + 解析后的 chunk 字典,如果解析失败或应跳过则返回 None + """ + try: + line = raw_line.decode("utf-8") + except Exception as e: + logger.warning(f"[parse_deepseek_sse_line] 解码失败: {e}") + return None + + if not line or not line.startswith("data:"): + return None + + data_str = line[5:].strip() + + if data_str == "[DONE]": + return {"type": "done"} + + try: + chunk = json.loads(data_str) + return chunk + except json.JSONDecodeError as e: + logger.warning(f"[parse_deepseek_sse_line] JSON解析失败: {e}") + return None + + +def extract_content_from_chunk(chunk: dict) -> tuple[str, str, bool]: + """从 DeepSeek chunk 中提取内容 + + Args: + chunk: 解析后的 chunk 字典 + + Returns: + (content, content_type, is_finished) 元组 + content_type 为 "thinking" 或 "text" + is_finished 为 True 表示响应结束 + """ + if chunk.get("type") == "done": + return "", "text", True + + # 检测内容审核/敏感词阻止 + if "error" in chunk or chunk.get("code") == "content_filter": + logger.warning(f"[extract_content_from_chunk] 检测到内容过滤: {chunk}") + return "", "text", True + + if "v" not in chunk: + return "", "text", False + + v_value = chunk["v"] + ptype = "text" + + # 检查路径确定类型 + path = chunk.get("p", "") + if path == "response/search_status": + return "", "text", False # 跳过搜索状态 + elif path == "response/thinking_content": + ptype = "thinking" + elif path == "response/content": + ptype = "text" + + if isinstance(v_value, str): + if v_value == "FINISHED": + return "", ptype, True + return v_value, ptype, False + elif isinstance(v_value, list): + for item in v_value: + if item.get("p") == "status" and item.get("v") == "FINISHED": + return "", ptype, True + return "", ptype, False + + return "", ptype, False + + +def collect_deepseek_response(response) -> tuple[str, str]: + """收集 DeepSeek 流响应的完整内容 + + Args: + response: DeepSeek 流响应对象 + + Returns: + (reasoning_content, text_content) 元组 + """ + thinking_parts = [] + text_parts = [] + + try: + for raw_line in response.iter_lines(): + chunk = parse_deepseek_sse_line(raw_line) + if not chunk: + continue + + content, content_type, is_finished = extract_content_from_chunk(chunk) + + if is_finished: + break + + if content: + if content_type == "thinking": + thinking_parts.append(content) + else: + text_parts.append(content) + except Exception as e: + logger.error(f"[collect_deepseek_response] 收集响应失败: {e}") + finally: + try: + response.close() + except Exception: + pass + + return "".join(thinking_parts), "".join(text_parts) + + +def parse_tool_calls(text: str, tools_requested: list) -> list[dict]: + """从响应文本中解析工具调用 + + Args: + text: 响应文本 + tools_requested: 请求中定义的工具列表 + + Returns: + 检测到的工具调用列表,每项包含 name 和 input + """ + detected_tools = [] + cleaned_text = text.strip() + + # 尝试直接解析完整 JSON + if cleaned_text.startswith('{"tool_calls":') and cleaned_text.endswith("]}"): + try: + tool_data = json.loads(cleaned_text) + for tool_call in tool_data.get("tool_calls", []): + tool_name = tool_call.get("name") + tool_input = tool_call.get("input", {}) + if any(tool.get("name") == tool_name for tool in tools_requested): + detected_tools.append({"name": tool_name, "input": tool_input}) + if detected_tools: + return detected_tools + except json.JSONDecodeError: + pass + + # 使用正则匹配 + matches = _TOOL_CALL_PATTERN.findall(cleaned_text) + for match in matches: + try: + tool_calls_json = f'{{"tool_calls": [{match}]}}' + tool_data = json.loads(tool_calls_json) + for tool_call in tool_data.get("tool_calls", []): + tool_name = tool_call.get("name") + tool_input = tool_call.get("input", {}) + if any(tool.get("name") == tool_name for tool in tools_requested): + detected_tools.append({"name": tool_name, "input": tool_input}) + except json.JSONDecodeError: + continue + + return detected_tools + + +def should_filter_citation(text: str, search_enabled: bool) -> bool: + """检查是否应该过滤引用内容 + + Args: + text: 内容文本 + search_enabled: 是否启用搜索 + + Returns: + 是否应该过滤 + """ + if not search_enabled: + return False + return _CITATION_PATTERN.match(text) is not None diff --git a/core/utils.py b/core/utils.py new file mode 100644 index 0000000..78cb59a --- /dev/null +++ b/core/utils.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +"""公共工具函数模块""" + + +def get_account_identifier(account: dict) -> str: + """返回账号的唯一标识,优先使用 email,否则使用 mobile""" + return account.get("email", "").strip() or account.get("mobile", "").strip() + + +def estimate_tokens(text) -> int: + """估算文本的 token 数量(简单估算:字符数/4) + + Args: + text: 字符串或其他类型 + + Returns: + 估算的 token 数量,最小为 1 + """ + if isinstance(text, str): + return max(1, len(text) // 4) + elif isinstance(text, list): + return sum( + estimate_tokens(item.get("text", "")) + if isinstance(item, dict) + else estimate_tokens(str(item)) + for item in text + ) + else: + return max(1, len(str(text)) // 4) diff --git a/routes/claude.py b/routes/claude.py index c0ae735..cdb331f 100644 --- a/routes/claude.py +++ b/routes/claude.py @@ -2,7 +2,6 @@ """Claude API 路由""" import json import random -import re import time from curl_cffi import requests as cffi_requests @@ -11,16 +10,23 @@ from fastapi.responses import JSONResponse, StreamingResponse from core.config import CONFIG, logger from core.auth import ( - determine_claude_mode_and_token, + determine_mode_and_token, get_auth_headers, ) from core.deepseek import call_completion_endpoint from core.session_manager import ( create_session, get_pow, - get_model_config, cleanup_account, ) +from core.models import get_model_config, get_claude_models_response +from core.stream_parser import ( + parse_deepseek_sse_line, + extract_content_from_chunk, + collect_deepseek_response, + parse_tool_calls, +) +from core.utils import estimate_tokens from core.messages import ( messages_prepare, convert_claude_to_deepseek, @@ -29,9 +35,6 @@ from core.messages import ( router = APIRouter() -# 预编译正则表达式(性能优化) -_TOOL_CALL_PATTERN = re.compile(r'\{\s*["\']tool_calls["\']\s*:\s*\[(.*?)\]\s*\}', re.DOTALL) - # ---------------------------------------------------------------------- # 通过 OpenAI 接口调用 Claude @@ -87,27 +90,7 @@ async def call_claude_via_openai(request: Request, claude_payload: dict): # ---------------------------------------------------------------------- @router.get("/anthropic/v1/models") def list_claude_models(): - models_list = [ - { - "id": "claude-sonnet-4-20250514", - "object": "model", - "created": 1715635200, - "owned_by": "anthropic", - }, - { - "id": "claude-sonnet-4-20250514-fast", - "object": "model", - "created": 1715635200, - "owned_by": "anthropic", - }, - { - "id": "claude-sonnet-4-20250514-slow", - "object": "model", - "created": 1715635200, - "owned_by": "anthropic", - }, - ] - data = {"object": "list", "data": models_list} + data = get_claude_models_response() return JSONResponse(content=data, status_code=200) @@ -118,13 +101,13 @@ def list_claude_models(): async def claude_messages(request: Request): try: try: - determine_claude_mode_and_token(request) + determine_mode_and_token(request) except HTTPException as exc: return JSONResponse( status_code=exc.status_code, content={"error": exc.detail} ) except Exception as exc: - logger.error(f"[claude_messages] determine_claude_mode_and_token 异常: {exc}") + logger.error(f"[claude_messages] determine_mode_and_token 异常: {exc}") return JSONResponse( status_code=500, content={"error": "Claude authentication failed."} ) @@ -290,34 +273,8 @@ Remember: Output ONLY the JSON, no other text. The response must start with {{ a yield f"data: {json.dumps(message_start)}\n\n" # 检查工具调用 - detected_tools = [] - cleaned_response = full_response_text.strip() - - if cleaned_response.startswith('{"tool_calls":') and cleaned_response.endswith("]}"): - try: - tool_data = json.loads(cleaned_response) - for tool_call in tool_data.get("tool_calls", []): - tool_name = tool_call.get("name") - tool_input = tool_call.get("input", {}) - if any(tool.get("name") == tool_name for tool in tools_requested): - detected_tools.append({"name": tool_name, "input": tool_input}) - except json.JSONDecodeError: - pass - - if not detected_tools: - # 使用预编译的正则表达式 - matches = _TOOL_CALL_PATTERN.findall(cleaned_response) - for match in matches: - try: - tool_calls_json = f'{{"tool_calls": [{match}]}}' - tool_data = json.loads(tool_calls_json) - for tool_call in tool_data.get("tool_calls", []): - tool_name = tool_call.get("name") - tool_input = tool_call.get("input", {}) - if any(tool.get("name") == tool_name for tool in tools_requested): - detected_tools.append({"name": tool_name, "input": tool_input}) - except json.JSONDecodeError: - continue + # 使用公共函数检测工具调用 + detected_tools = parse_tool_calls(full_response_text, tools_requested) content_index = 0 if detected_tools: @@ -415,34 +372,7 @@ Remember: Output ONLY the JSON, no other text. The response must start with {{ a logger.warning(f"[claude_messages] 关闭响应异常: {e}") # 检查工具调用 - detected_tools = [] - cleaned_content = final_content.strip() - - if cleaned_content.startswith('{"tool_calls":') and cleaned_content.endswith("]}"): - try: - tool_data = json.loads(cleaned_content) - for tool_call in tool_data.get("tool_calls", []): - tool_name = tool_call.get("name") - tool_input = tool_call.get("input", {}) - if any(tool.get("name") == tool_name for tool in tools_requested): - detected_tools.append({"name": tool_name, "input": tool_input}) - except json.JSONDecodeError: - pass - - if not detected_tools: - # 使用预编译的正则表达式 - matches = _TOOL_CALL_PATTERN.findall(cleaned_content) - for match in matches: - try: - tool_calls_json = f'{{"tool_calls": [{match}]}}' - tool_data = json.loads(tool_calls_json) - for tool_call in tool_data.get("tool_calls", []): - tool_name = tool_call.get("name") - tool_input = tool_call.get("input", {}) - if any(tool.get("name") == tool_name for tool in tools_requested): - detected_tools.append({"name": tool_name, "input": tool_input}) - except json.JSONDecodeError: - continue + detected_tools = parse_tool_calls(final_content, tools_requested) # 构造响应 claude_response = { @@ -513,11 +443,11 @@ Remember: Output ONLY the JSON, no other text. The response must start with {{ a async def claude_count_tokens(request: Request): try: try: - determine_claude_mode_and_token(request) + determine_mode_and_token(request) except HTTPException as exc: return JSONResponse(status_code=exc.status_code, content={"error": exc.detail}) except Exception as exc: - logger.error(f"[claude_count_tokens] determine_claude_mode_and_token 异常: {exc}") + logger.error(f"[claude_count_tokens] determine_mode_and_token 异常: {exc}") return JSONResponse(status_code=500, content={"error": "Claude authentication failed."}) req_data = await request.json() @@ -530,19 +460,6 @@ async def claude_count_tokens(request: Request): status_code=400, detail="Request must include 'model' and 'messages'." ) - def estimate_tokens(text): - if isinstance(text, str): - return len(text) // 4 - elif isinstance(text, list): - return sum( - estimate_tokens(item.get("text", "")) - if isinstance(item, dict) - else estimate_tokens(str(item)) - for item in text - ) - else: - return len(str(text)) // 4 - input_tokens = 0 if system: diff --git a/routes/openai.py b/routes/openai.py index 1b44d08..ffa3ef1 100644 --- a/routes/openai.py +++ b/routes/openai.py @@ -20,9 +20,14 @@ from core.deepseek import call_completion_endpoint from core.session_manager import ( create_session, get_pow, - get_model_config, cleanup_account, ) +from core.models import get_model_config, get_openai_models_response +from core.stream_parser import ( + parse_deepseek_sse_line, + extract_content_from_chunk, + should_filter_citation, +) from core.messages import messages_prepare router = APIRouter() @@ -39,37 +44,7 @@ _CITATION_PATTERN = re.compile(r"^\[citation:") # ---------------------------------------------------------------------- @router.get("/v1/models") def list_models(): - models_list = [ - { - "id": "deepseek-chat", - "object": "model", - "created": 1677610602, - "owned_by": "deepseek", - "permission": [], - }, - { - "id": "deepseek-reasoner", - "object": "model", - "created": 1677610602, - "owned_by": "deepseek", - "permission": [], - }, - { - "id": "deepseek-chat-search", - "object": "model", - "created": 1677610602, - "owned_by": "deepseek", - "permission": [], - }, - { - "id": "deepseek-reasoner-search", - "object": "model", - "created": 1677610602, - "owned_by": "deepseek", - "permission": [], - }, - ] - data = {"object": "list", "data": models_list} + data = get_openai_models_response() return JSONResponse(content=data, status_code=200) diff --git a/tests/test_unit.py b/tests/test_unit.py index 09c3160..6bdf6b4 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -158,7 +158,7 @@ class TestPow(unittest.TestCase): def test_get_account_identifier(self): """测试账号标识获取""" - from core.pow import get_account_identifier + from core.utils import get_account_identifier # 测试邮箱 account1 = {"email": "test@example.com"}