diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..4499b14 --- /dev/null +++ b/.env.example @@ -0,0 +1,47 @@ +# DS2API 环境变量配置模板 +# 复制此文件为 .env 并根据需要修改 + +# ===== 服务配置 ===== +# 服务端口 +PORT=5001 + +# 服务监听地址 +HOST=0.0.0.0 + +# 日志级别 (DEBUG, INFO, WARNING, ERROR) +LOG_LEVEL=INFO + +# ===== 配置来源(以下三种方式选一种)===== + +# 方式1: JSON 字符串 +# DS2API_CONFIG_JSON={"keys":["your-api-key"],"accounts":[{"email":"user@example.com","password":"xxx","token":""}]} + +# 方式2: Base64 编码的 JSON(推荐用于 Vercel,避免特殊字符问题) +# DS2API_CONFIG_JSON=eyJrZXlzIjpbInlvdXItYXBpLWtleSJdLCJhY2NvdW50cyI6W3siZW1haWwiOiJ1c2VyQGV4YW1wbGUuY29tIiwicGFzc3dvcmQiOiJ4eHgiLCJ0b2tlbiI6IiJ9XX0= + +# 方式3: 配置文件路径(默认为 config.json) +# DS2API_CONFIG_PATH=config.json + +# ===== 可选:自定义路径 ===== +# Tokenizer 目录(留空使用项目根目录) +# DS2API_TOKENIZER_DIR= + +# 模板目录 +# DS2API_TEMPLATES_DIR=templates + +# WASM 文件路径 +# DS2API_WASM_PATH=sha3_wasm_bg.7b9ca65ddd.wasm + +# ===== Admin 管理界面 ===== +# Admin API 密钥(留空则开发模式,无需认证) +# DS2API_ADMIN_KEY=your-admin-secret-key + +# ===== Vercel 集成(可选,用于一键同步部署)===== +# Vercel API Token(从 https://vercel.com/account/tokens 获取) +# VERCEL_TOKEN=your-vercel-token + +# Vercel Project ID(在项目设置中找) +# VERCEL_PROJECT_ID=prj_xxxxxxxxxxxx + +# Vercel Team ID(个人项目无需填写) +# VERCEL_TEAM_ID= diff --git a/.gitignore b/.gitignore index 9517f58..758e1bd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ *.bak config.json +.env # Python __pycache__/ @@ -47,3 +48,35 @@ uvicorn.log # Vercel .vercel + +# Node.js / Frontend +node_modules/ +webui/node_modules/ +webui/dist/ +static/admin/ +.npm +.pnpm-store/ +package-lock.json +yarn.lock +pnpm-lock.yaml + +# Build artifacts +*.tsbuildinfo +.cache/ +.parcel-cache/ + +# Environment +.env.local +.env.*.local + +# Testing +.coverage +htmlcov/ +.pytest_cache/ +.tox/ + +# Misc +*.pyc +*.pyo +.git/ +Thumbs.db diff --git a/app.py b/app.py index cab6be4..e3f3556 100644 --- a/app.py +++ b/app.py @@ -1,53 +1,33 @@ -import base64 -import ctypes -import json -import logging +# -*- coding: utf-8 -*- +""" +DS2API - DeepSeek to OpenAI API 转换服务 + +支持: +- OpenAI 兼容接口: /v1/chat/completions, /v1/models +- Claude 兼容接口: /anthropic/v1/messages, /anthropic/v1/models + +使用方法: + 本地开发: python dev.py + 生产环境: uvicorn app:app --host 0.0.0.0 --port 5001 + Vercel: 自动部署 +""" import os -import sys -import queue -import random -import re -import struct -import threading -import time -import transformers -from curl_cffi import requests -from fastapi import FastAPI, HTTPException, Request + +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse, StreamingResponse -from fastapi.templating import Jinja2Templates -from wasmtime import Linker, Module, Store +from fastapi.responses import JSONResponse -# -------------------------- 获取项目根目录 -------------------------- -BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -IS_VERCEL = bool(os.getenv("VERCEL")) or bool(os.getenv("NOW_REGION")) +from core.config import IS_VERCEL, logger - -def resolve_path(env_key: str, default_rel: str) -> str: - raw = os.getenv(env_key) - if raw: - return raw if os.path.isabs(raw) else os.path.join(BASE_DIR, raw) - return os.path.join(BASE_DIR, default_rel) - - -# -------------------------- 初始化 tokenizer -------------------------- -chat_tokenizer_dir = resolve_path("DS2API_TOKENIZER_DIR", "") -tokenizer = transformers.AutoTokenizer.from_pretrained( - chat_tokenizer_dir, trust_remote_code=True +# 创建 FastAPI 应用 +app = FastAPI( + title="DS2API", + description="DeepSeek to OpenAI/Claude API", + version="1.0.0", ) -# -------------------------- 日志配置 -------------------------- -logging.basicConfig( - level=os.getenv("LOG_LEVEL", "INFO").upper(), - format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", - handlers=[logging.StreamHandler(sys.stdout)], - force=True, -) -logger = logging.getLogger("main") - -app = FastAPI() - +# 全局异常处理 @app.exception_handler(Exception) async def unhandled_exception_handler(request: Request, exc: Exception): logger.exception(f"[unhandled_exception] {request.method} {request.url.path}: {exc}") @@ -57,7 +37,7 @@ async def unhandled_exception_handler(request: Request, exc: Exception): ) -# 添加 CORS 中间件,允许所有来源 +# CORS 中间件 app.add_middleware( CORSMiddleware, allow_origins=["*"], @@ -66,1916 +46,20 @@ app.add_middleware( allow_headers=["Content-Type", "Authorization"], ) -# 模板目录 -templates = Jinja2Templates(directory=resolve_path("DS2API_TEMPLATES_DIR", "templates")) +# 注册路由 +from routes.openai import router as openai_router +from routes.claude import router as claude_router +from routes.home import router as home_router +from routes.admin import router as admin_router -# ---------------------------------------------------------------------- -# (1) 配置文件的读写函数 -# ---------------------------------------------------------------------- -CONFIG_PATH = resolve_path("DS2API_CONFIG_PATH", "config.json") - - -def load_config(): - """加载配置。 - - 优先从环境变量读取: - - DS2API_CONFIG_JSON / CONFIG_JSON: 直接 JSON 字符串,或 base64 编码后的 JSON - - 若未提供环境变量,再从 CONFIG_PATH 指向的文件读取。 - """ - - raw_cfg = os.getenv("DS2API_CONFIG_JSON") or os.getenv("CONFIG_JSON") - if raw_cfg: - try: - return json.loads(raw_cfg) - except json.JSONDecodeError: - try: - decoded = base64.b64decode(raw_cfg).decode("utf-8") - return json.loads(decoded) - except Exception as e: - logger.warning(f"[load_config] 环境变量配置解析失败: {e}") - return {} - - try: - with open(CONFIG_PATH, "r", encoding="utf-8") as f: - return json.load(f) - except Exception as e: - logger.warning(f"[load_config] 无法读取配置文件({CONFIG_PATH}): {e}") - return {} - - -def save_config(cfg): - """将配置写回 config.json。 - - Vercel 环境文件系统通常是只读的;且如果配置来自环境变量,也无法回写。 - 所以这里失败不应影响主流程。 - """ - - if os.getenv("DS2API_CONFIG_JSON") or os.getenv("CONFIG_JSON"): - logger.info("[save_config] 配置来自环境变量,跳过写回") - return - - try: - with open(CONFIG_PATH, "w", encoding="utf-8") as f: - json.dump(cfg, f, ensure_ascii=False, indent=2) - except PermissionError as e: - logger.warning(f"[save_config] 配置文件不可写({CONFIG_PATH}): {e}") - except Exception as e: - logger.exception(f"[save_config] 写入 config.json 失败: {e}") - - -CONFIG = load_config() -if not CONFIG: - logger.warning( - "[config] 未加载到有效配置,请提供 config.json(路径可用 DS2API_CONFIG_PATH 指定)或设置环境变量 DS2API_CONFIG_JSON" - ) - -# -------------------------- 全局账号队列 -------------------------- -account_queue = [] # 维护所有可用账号 -claude_api_key_queue = [] # 维护所有可用的Claude API keys - - -def init_account_queue(): - """初始化时从配置加载账号""" - global account_queue - account_queue = CONFIG.get("accounts", [])[:] # 深拷贝 - random.shuffle(account_queue) # 初始随机排序 - - -def init_claude_api_key_queue(): - """Claude API keys由用户自己的token提供,这里初始化为空""" - global claude_api_key_queue - claude_api_key_queue = [] - - -init_account_queue() -init_claude_api_key_queue() - -# ---------------------------------------------------------------------- -# (2) DeepSeek 相关常量 -# ---------------------------------------------------------------------- -DEEPSEEK_HOST = "chat.deepseek.com" -DEEPSEEK_LOGIN_URL = f"https://{DEEPSEEK_HOST}/api/v0/users/login" -DEEPSEEK_CREATE_SESSION_URL = f"https://{DEEPSEEK_HOST}/api/v0/chat_session/create" -DEEPSEEK_CREATE_POW_URL = f"https://{DEEPSEEK_HOST}/api/v0/chat/create_pow_challenge" -DEEPSEEK_COMPLETION_URL = f"https://{DEEPSEEK_HOST}/api/v0/chat/completion" -BASE_HEADERS = { - "Host": "chat.deepseek.com", - "User-Agent": "DeepSeek/1.0.13 Android/35", - "Accept": "application/json", - "Accept-Encoding": "gzip", - "Content-Type": "application/json", - "x-client-platform": "android", - "x-client-version": "1.3.0-auto-resume", - "x-client-locale": "zh_CN", - "accept-charset": "UTF-8", -} - -# ---------------------------------------------------------------------- -# (2.1) Claude 相关常量 - 基于OpenAI接口转换 -# ---------------------------------------------------------------------- -CLAUDE_DEFAULT_MODEL = "claude-sonnet-4-20250514" # Claude统一默认模型 - -# WASM 模块文件路径 -WASM_PATH = resolve_path("DS2API_WASM_PATH", "sha3_wasm_bg.7b9ca65ddd.wasm") +app.include_router(openai_router) +app.include_router(claude_router) +app.include_router(home_router) +app.include_router(admin_router) # ---------------------------------------------------------------------- -# 辅助函数:获取账号唯一标识(优先 email,否则 mobile) -# ---------------------------------------------------------------------- -def get_account_identifier(account): - """返回账号的唯一标识,优先使用 email,否则使用 mobile""" - return account.get("email", "").strip() or account.get("mobile", "").strip() - - -# ---------------------------------------------------------------------- -# (3) 登录函数:支持使用 email 或 mobile 登录 -# ---------------------------------------------------------------------- -def login_deepseek_via_account(account): - """使用 account 中的 email 或 mobile 登录 DeepSeek, - 成功后将返回的 token 写入 account 并保存至配置文件,返回新 token。 - """ - email = account.get("email", "").strip() - mobile = account.get("mobile", "").strip() - password = account.get("password", "").strip() - if not password or (not email and not mobile): - raise HTTPException( - status_code=400, - detail="账号缺少必要的登录信息(必须提供 email 或 mobile 以及 password)", - ) - if email: - payload = { - "email": email, - "password": password, - "device_id": "deepseek_to_api", - "os": "android", - } - else: - payload = { - "mobile": mobile, - "area_code": None, - "password": password, - "device_id": "deepseek_to_api", - "os": "android", - } - try: - resp = requests.post(DEEPSEEK_LOGIN_URL, headers=BASE_HEADERS, json=payload, impersonate="safari15_3") - resp.raise_for_status() - except Exception as e: - logger.error(f"[login_deepseek_via_account] 登录请求异常: {e}") - raise HTTPException(status_code=500, detail="Account login failed: 请求异常") - try: - logger.warning(f"[login_deepseek_via_account] {resp.text}") - data = resp.json() - except Exception as e: - logger.error(f"[login_deepseek_via_account] JSON解析失败: {e}") - raise HTTPException( - status_code=500, detail="Account login failed: invalid JSON response" - ) - # 校验响应数据格式是否正确 - if ( - data.get("data") is None - or data["data"].get("biz_data") is None - or data["data"]["biz_data"].get("user") is None - ): - logger.error(f"[login_deepseek_via_account] 登录响应格式错误: {data}") - raise HTTPException( - status_code=500, detail="Account login failed: invalid response format" - ) - new_token = data["data"]["biz_data"]["user"].get("token") - if not new_token: - logger.error(f"[login_deepseek_via_account] 登录响应中缺少 token: {data}") - raise HTTPException( - status_code=500, detail="Account login failed: missing token" - ) - account["token"] = new_token - save_config(CONFIG) - return new_token - - -# ---------------------------------------------------------------------- -# (4) 从 accounts 中随机选择一个未忙且未尝试过的账号 -# ---------------------------------------------------------------------- -def choose_new_account(exclude_ids=None): - """选择策略: - 1. 遍历队列,找到第一个未被 exclude_ids 包含的账号 - 2. 从队列中移除该账号 - 3. 返回该账号(由后续逻辑保证最终会重新入队) - """ - if exclude_ids is None: - exclude_ids = [] - - for i in range(len(account_queue)): - acc = account_queue[i] - acc_id = get_account_identifier(acc) - if acc_id and acc_id not in exclude_ids: - # 从队列中移除并返回 - logger.info(f"[choose_new_account] 新选择账号: {acc_id}") - return account_queue.pop(i) - - logger.warning("[choose_new_account] 没有可用的账号或所有账号都在使用中") - return None - - -def release_account(account): - """将账号重新加入队列末尾""" - account_queue.append(account) - - -# ---------------------------------------------------------------------- -# Claude API key 管理函数(简化版本) -# ---------------------------------------------------------------------- -def choose_claude_api_key(): - """选择一个可用的Claude API key - 现在直接由用户提供""" - return None - - -def release_claude_api_key(api_key): - """释放Claude API key - 现在无需操作""" - pass - - -# ---------------------------------------------------------------------- -# (5) 判断调用模式:配置模式 vs 用户自带 token -# ---------------------------------------------------------------------- -def determine_mode_and_token(request: Request): - """ - 根据请求头 Authorization 判断使用哪种模式: - - 如果 Bearer token 出现在 CONFIG["keys"] 中,则为配置模式,从 CONFIG["accounts"] 中随机选择一个账号(排除已尝试账号), - 检查该账号是否已有 token,否则调用登录接口获取; - - 否则,直接使用请求中的 Bearer 值作为 DeepSeek token。 - 结果存入 request.state.deepseek_token;配置模式下同时存入 request.state.account 与 request.state.tried_accounts。 - """ - auth_header = request.headers.get("Authorization", "") - if not auth_header.startswith("Bearer "): - raise HTTPException( - status_code=401, detail="Unauthorized: missing Bearer token." - ) - caller_key = auth_header.replace("Bearer ", "", 1).strip() - config_keys = CONFIG.get("keys", []) - if caller_key in config_keys: - request.state.use_config_token = True - request.state.tried_accounts = [] # 初始化已尝试账号 - selected_account = choose_new_account() - if not selected_account: - raise HTTPException( - status_code=429, - detail="No accounts configured or all accounts are busy.", - ) - if not selected_account.get("token", "").strip(): - try: - login_deepseek_via_account(selected_account) - except Exception as e: - logger.error( - f"[determine_mode_and_token] 账号 {get_account_identifier(selected_account)} 登录失败:{e}" - ) - raise HTTPException(status_code=500, detail="Account login failed.") - - request.state.deepseek_token = selected_account.get("token") - request.state.account = selected_account - - else: - request.state.use_config_token = False - request.state.deepseek_token = caller_key - - -def get_auth_headers(request: Request): - """返回 DeepSeek 请求所需的公共请求头""" - return {**BASE_HEADERS, "authorization": f"Bearer {request.state.deepseek_token}"} - - -# ---------------------------------------------------------------------- -# Claude 认证相关函数 -# ---------------------------------------------------------------------- -def determine_claude_mode_and_token(request: Request): - """ - Claude认证:沿用现有的OpenAI接口认证逻辑 - """ - # 直接调用现有的认证逻辑 - determine_mode_and_token(request) - - -# ---------------------------------------------------------------------- -# OpenAI到Claude格式转换函数 -# ---------------------------------------------------------------------- -def convert_claude_to_deepseek(claude_request): - """将Claude格式的请求转换为DeepSeek格式(基于现有OpenAI接口)""" - messages = claude_request.get("messages", []) - model = claude_request.get("model", CLAUDE_DEFAULT_MODEL) - - # 从配置文件读取Claude模型映射 - claude_mapping = CONFIG.get("claude_model_mapping", { - "fast": "deepseek-chat", - "slow": "deepseek-chat" - }) - - # Claude模型映射到DeepSeek模型 - 基于配置和模型特征判断 - if "opus" in model.lower() or "reasoner" in model.lower() or "slow" in model.lower(): - deepseek_model = claude_mapping.get("slow", "deepseek-chat") - else: - deepseek_model = claude_mapping.get("fast", "deepseek-chat") - - deepseek_request = { - "model": deepseek_model, - "messages": messages.copy() - } - - # 处理system消息 - 将system参数转换为system role消息 - if "system" in claude_request: - system_msg = {"role": "system", "content": claude_request["system"]} - deepseek_request["messages"].insert(0, system_msg) - - # 添加可选参数 - if "temperature" in claude_request: - deepseek_request["temperature"] = claude_request["temperature"] - if "top_p" in claude_request: - deepseek_request["top_p"] = claude_request["top_p"] - if "stop_sequences" in claude_request: - deepseek_request["stop"] = claude_request["stop_sequences"] - if "stream" in claude_request: - deepseek_request["stream"] = claude_request["stream"] - - return deepseek_request - - -def convert_deepseek_to_claude_format(deepseek_response, original_claude_model=CLAUDE_DEFAULT_MODEL): - """将DeepSeek响应转换为Claude格式的OpenAI响应""" - # DeepSeek响应已经是OpenAI格式,只需要修改模型名称 - if isinstance(deepseek_response, dict): - claude_response = deepseek_response.copy() - claude_response["model"] = original_claude_model - return claude_response - - return deepseek_response - - - - - - - - -# ---------------------------------------------------------------------- -# Claude API 调用函数 -# ---------------------------------------------------------------------- -async def call_claude_via_openai(request: Request, claude_payload): - """通过现有OpenAI接口调用Claude(实际调用DeepSeek)""" - # 将Claude请求转换为DeepSeek请求 - deepseek_payload = convert_claude_to_deepseek(claude_payload) - - # 直接调用现有的chat_completions逻辑 - try: - # 使用现有的逻辑创建session和pow - session_id = create_session(request) - if not session_id: - raise HTTPException(status_code=401, detail="invalid token.") - - pow_resp = get_pow_response(request) - if not pow_resp: - raise HTTPException( - status_code=401, - detail="Failed to get PoW (invalid token or unknown error).", - ) - - # 准备DeepSeek API调用 - model = deepseek_payload.get("model", "deepseek-chat") - messages = deepseek_payload.get("messages", []) - - # 判断模型特性 - model_lower = model.lower() - if model_lower in ["deepseek-v3", "deepseek-chat"]: - thinking_enabled = False - search_enabled = False - elif model_lower in ["deepseek-r1", "deepseek-reasoner"]: - thinking_enabled = True - search_enabled = False - elif model_lower in ["deepseek-v3-search", "deepseek-chat-search"]: - thinking_enabled = False - search_enabled = True - elif model_lower in ["deepseek-r1-search", "deepseek-reasoner-search"]: - thinking_enabled = True - search_enabled = True - else: - thinking_enabled = False - search_enabled = False - - # 使用 messages_prepare 函数构造最终 prompt - final_prompt = messages_prepare(messages) - - headers = {**get_auth_headers(request), "x-ds-pow-response": pow_resp} - payload = { - "chat_session_id": session_id, - "parent_message_id": None, - "prompt": final_prompt, - "ref_file_ids": [], - "thinking_enabled": thinking_enabled, - "search_enabled": search_enabled, - } - - deepseek_resp = call_completion_endpoint(payload, headers, max_attempts=3) - return deepseek_resp - - except Exception as e: - logger.error(f"[call_claude_via_openai] 调用失败: {e}") - return None - - -# ---------------------------------------------------------------------- -# (6) 封装对话接口调用的重试机制 -# ---------------------------------------------------------------------- -def call_completion_endpoint(payload, headers, max_attempts=3): - attempts = 0 - while attempts < max_attempts: - try: - deepseek_resp = requests.post( - DEEPSEEK_COMPLETION_URL, headers=headers, json=payload, stream=True, impersonate="safari15_3" - ) - except Exception as e: - logger.warning(f"[call_completion_endpoint] 请求异常: {e}") - time.sleep(1) - attempts += 1 - continue - if deepseek_resp.status_code == 200: - return deepseek_resp - else: - logger.warning( - f"[call_completion_endpoint] 调用对话接口失败, 状态码: {deepseek_resp.status_code}" - ) - deepseek_resp.close() - time.sleep(1) - attempts += 1 - return None - - -# ---------------------------------------------------------------------- -# (7) 创建会话 & 获取 PoW(重试时,配置模式下错误会切换账号;用户自带 token 模式下仅重试) -# ---------------------------------------------------------------------- -def create_session(request: Request, max_attempts=3): - attempts = 0 - while attempts < max_attempts: - headers = get_auth_headers(request) - try: - resp = requests.post( - DEEPSEEK_CREATE_SESSION_URL, headers=headers, json={"agent": "chat"}, impersonate="safari15_3" - ) - except Exception as e: - logger.error(f"[create_session] 请求异常: {e}") - attempts += 1 - continue - try: - logger.warning(f"[create_session] {resp.text}") - data = resp.json() - - except Exception as e: - logger.error(f"[create_session] JSON解析异常: {e}") - data = {} - if resp.status_code == 200 and data.get("code") == 0: - session_id = data["data"]["biz_data"]["id"] - - resp.close() - return session_id - else: - code = data.get("code") - logger.warning( - f"[create_session] 创建会话失败, code={code}, msg={data.get('msg')}" - ) - resp.close() - if request.state.use_config_token: - current_id = get_account_identifier(request.state.account) - if not hasattr(request.state, "tried_accounts"): - request.state.tried_accounts = [] - if current_id not in request.state.tried_accounts: - request.state.tried_accounts.append(current_id) - new_account = choose_new_account(request.state.tried_accounts) - if new_account is None: - break - try: - login_deepseek_via_account(new_account) - except Exception as e: - logger.error( - f"[create_session] 账号 {get_account_identifier(new_account)} 登录失败:{e}" - ) - attempts += 1 - continue - request.state.account = new_account - request.state.deepseek_token = new_account.get("token") - else: - attempts += 1 - continue - attempts += 1 - return None - - -# ---------------------------------------------------------------------- -# (7.1) 使用 WASM 模块计算 PoW 答案的辅助函数 -# ---------------------------------------------------------------------- -def compute_pow_answer( - algorithm: str, - challenge_str: str, - salt: str, - difficulty: int, - expire_at: int, - signature: str, - target_path: str, - wasm_path: str, -) -> int: - """ - 使用 WASM 模块计算 DeepSeekHash 答案(answer)。 - 根据 JS 逻辑: - - 拼接前缀: "{salt}_{expire_at}_" - - 将 challenge 与前缀写入 wasm 内存后调用 wasm_solve 进行求解, - - 从 wasm 内存中读取状态与求解结果, - - 若状态非 0,则返回整数形式的答案,否则返回 None。 - """ - if algorithm != "DeepSeekHashV1": - raise ValueError(f"不支持的算法:{algorithm}") - prefix = f"{salt}_{expire_at}_" - # --- 加载 wasm 模块 --- - store = Store() - linker = Linker(store.engine) - try: - with open(wasm_path, "rb") as f: - wasm_bytes = f.read() - except Exception as e: - raise RuntimeError(f"加载 wasm 文件失败: {wasm_path}, 错误: {e}") - module = Module(store.engine, wasm_bytes) - instance = linker.instantiate(store, module) - exports = instance.exports(store) - try: - memory = exports["memory"] - add_to_stack = exports["__wbindgen_add_to_stack_pointer"] - alloc = exports["__wbindgen_export_0"] - wasm_solve = exports["wasm_solve"] - except KeyError as e: - raise RuntimeError(f"缺少 wasm 导出函数: {e}") - - def write_memory(offset: int, data: bytes): - size = len(data) - base_addr = ctypes.cast(memory.data_ptr(store), ctypes.c_void_p).value - ctypes.memmove(base_addr + offset, data, size) - - def read_memory(offset: int, size: int) -> bytes: - base_addr = ctypes.cast(memory.data_ptr(store), ctypes.c_void_p).value - return ctypes.string_at(base_addr + offset, size) - - def encode_string(text: str): - data = text.encode("utf-8") - length = len(data) - ptr_val = alloc(store, length, 1) - ptr = int(ptr_val.value) if hasattr(ptr_val, "value") else int(ptr_val) - write_memory(ptr, data) - return ptr, length - - # 1. 申请 16 字节栈空间 - retptr = add_to_stack(store, -16) - # 2. 编码 challenge 与 prefix 到 wasm 内存中 - ptr_challenge, len_challenge = encode_string(challenge_str) - ptr_prefix, len_prefix = encode_string(prefix) - # 3. 调用 wasm_solve(注意:difficulty 以 float 形式传入) - wasm_solve( - store, - retptr, - ptr_challenge, - len_challenge, - ptr_prefix, - len_prefix, - float(difficulty), - ) - # 4. 从 retptr 处读取 4 字节状态和 8 字节求解结果 - status_bytes = read_memory(retptr, 4) - if len(status_bytes) != 4: - add_to_stack(store, 16) - raise RuntimeError("读取状态字节失败") - status = struct.unpack(" str: - """处理消息列表,合并连续相同角色的消息,并添加角色标签: - - 对于 assistant 消息,加上 <|Assistant|> 前缀及 <|end▁of▁sentence|> 结束标签; - - 对于 user/system 消息(除第一条外)加上 <|User|> 前缀; - - 如果消息 content 为数组,则提取其中 type 为 "text" 的部分; - - 最后移除 markdown 图片格式的内容。 - """ - processed = [] - for m in messages: - role = m.get("role", "") - content = m.get("content", "") - if isinstance(content, list): - texts = [ - item.get("text", "") for item in content if item.get("type") == "text" - ] - text = "\n".join(texts) - else: - text = str(content) - processed.append({"role": role, "text": text}) - if not processed: - return "" - # 合并连续同一角色的消息 - merged = [processed[0]] - for msg in processed[1:]: - if msg["role"] == merged[-1]["role"]: - merged[-1]["text"] += "\n\n" + msg["text"] - else: - merged.append(msg) - # 添加标签 - parts = [] - for idx, block in enumerate(merged): - role = block["role"] - text = block["text"] - if role == "assistant": - parts.append(f"<|Assistant|>{text}<|end▁of▁sentence|>") - elif role in ("user", "system"): - if idx > 0: - parts.append(f"<|User|>{text}") - else: - parts.append(text) - else: - parts.append(text) - final_prompt = "".join(parts) - # 仅移除 markdown 图片格式(不全部移除 !) - final_prompt = re.sub(r"!\[(.*?)\]\((.*?)\)", r"[\1](\2)", final_prompt) - return final_prompt - - -# 添加保活超时配置(5秒) -KEEP_ALIVE_TIMEOUT = 5 - - -# ---------------------------------------------------------------------- -# (10) 路由:/v1/chat/completions -# ---------------------------------------------------------------------- -@app.post("/v1/chat/completions") -async def chat_completions(request: Request): - try: - # 处理 token 相关逻辑,若登录失败则直接返回错误响应 - try: - determine_mode_and_token(request) - except HTTPException as exc: - return JSONResponse( - status_code=exc.status_code, content={"error": exc.detail} - ) - except Exception as exc: - logger.error(f"[chat_completions] determine_mode_and_token 异常: {exc}") - return JSONResponse( - status_code=500, content={"error": "Account login failed."} - ) - - req_data = await request.json() - model = req_data.get("model") - messages = req_data.get("messages", []) - if not model or not messages: - raise HTTPException( - status_code=400, detail="Request must include 'model' and 'messages'." - ) - # 判断是否启用"思考"或"搜索"功能(这里根据模型名称判断) - model_lower = model.lower() - if model_lower in ["deepseek-v3", "deepseek-chat"]: - thinking_enabled = False - search_enabled = False - elif model_lower in ["deepseek-r1", "deepseek-reasoner"]: - thinking_enabled = True - search_enabled = False - elif model_lower in ["deepseek-v3-search", "deepseek-chat-search"]: - thinking_enabled = False - search_enabled = True - elif model_lower in ["deepseek-r1-search", "deepseek-reasoner-search"]: - thinking_enabled = True - search_enabled = True - else: - raise HTTPException( - status_code=503, detail=f"Model '{model}' is not available." - ) - # 使用 messages_prepare 函数构造最终 prompt - final_prompt = messages_prepare(messages) - session_id = create_session(request) - if not session_id: - raise HTTPException(status_code=401, detail="invalid token.") - pow_resp = get_pow_response(request) - if not pow_resp: - raise HTTPException( - status_code=401, - detail="Failed to get PoW (invalid token or unknown error).", - ) - headers = {**get_auth_headers(request), "x-ds-pow-response": pow_resp} - payload = { - "chat_session_id": session_id, - "parent_message_id": None, - "prompt": final_prompt, - "ref_file_ids": [], - "thinking_enabled": thinking_enabled, - "search_enabled": search_enabled, - } - - deepseek_resp = call_completion_endpoint(payload, headers, max_attempts=3) - if not deepseek_resp: - raise HTTPException(status_code=500, detail="Failed to get completion.") - created_time = int(time.time()) - completion_id = f"{session_id}" - - # 流式响应(SSE)或普通响应 - if bool(req_data.get("stream", False)): - if deepseek_resp.status_code != 200: - deepseek_resp.close() - return JSONResponse( - content=deepseek_resp.content, status_code=deepseek_resp.status_code - ) - - def sse_stream(): - try: - final_text = "" - final_thinking = "" - first_chunk_sent = False - result_queue = queue.Queue() - last_send_time = time.time() - citation_map = {} # 用于存储引用链接的字典 - - def process_data(): - ptype = "text" - try: - for raw_line in deepseek_resp.iter_lines(): - try: - line = raw_line.decode("utf-8") - except Exception as e: - logger.warning(f"[sse_stream] 解码失败: {e}") - # 根据当前模式决定错误消息类型 - error_type = "thinking" if ptype == "thinking" else "text" - busy_content_str = f'{{"choices":[{{"index":0,"delta":{{"content":"解码失败,请稍候再试","type":"{error_type}"}}}}],"model":"","chunk_token_usage":1,"created":0,"message_id":-1,"parent_id":-1}}' - try: - busy_content = json.loads(busy_content_str) - result_queue.put(busy_content) - except json.JSONDecodeError: - # 如果JSON解析也失败,创建最基本的错误响应 - result_queue.put({"choices": [{"index": 0, "delta": {"content": "解码失败", "type": "text"}}]}) - result_queue.put(None) - break - if not line: - continue - if line.startswith("data:"): - data_str = line[5:].strip() - if data_str == "[DONE]": - result_queue.put(None) # 结束信号 - break - try: - chunk = json.loads(data_str) - - if "v" in chunk: - v_value = chunk["v"] - - # 构造新的 delta 格式的 chunk - content = "" - - if "p" in chunk and chunk.get("p") == "response/search_status": - continue - - if "p" in chunk and chunk.get("p") == "response/thinking_content": - ptype = "thinking" - elif "p" in chunk and chunk.get("p") == "response/content": - ptype = "text" - - # 处理文本内容 - if isinstance(v_value, str): - content = v_value - # 处理数组更新如状态变更 - elif isinstance(v_value, list): - for item in v_value: - if item.get("p") == "status" and item.get("v") == "FINISHED": - # 最终完成信号 - result_queue.put({"choices": [{"index": 0, "finish_reason": "stop"}]}) - result_queue.put(None) - return - continue - - # 构造兼容原逻辑的 chunk - unified_chunk = { - "choices": [{ - "index": 0, - "delta": { - "content": content, - "type": ptype - } - }], - "model": "", - "chunk_token_usage": len(content) // 4, # 简单估算token数 - "created": 0, - "message_id": -1, - "parent_id": -1 - } - - result_queue.put(unified_chunk) - except Exception as e: - logger.warning( - f"[sse_stream] 无法解析: {data_str}, 错误: {e}" - ) - # 根据当前模式决定错误消息类型 - error_type = "thinking" if ptype == "thinking" else "text" - busy_content_str = f'{{"choices":[{{"index":0,"delta":{{"content":"解析失败,请稍候再试","type":"{error_type}"}}}}],"model":"","chunk_token_usage":1,"created":0,"message_id":-1,"parent_id":-1}}' - try: - busy_content = json.loads(busy_content_str) - result_queue.put(busy_content) - except json.JSONDecodeError: - # 如果JSON解析也失败,创建最基本的错误响应 - result_queue.put({"choices": [{"index": 0, "delta": {"content": "解析失败", "type": "text"}}]}) - result_queue.put(None) - break - except Exception as e: - logger.warning(f"[sse_stream] 错误: {e}") - # 创建基本的错误响应,不依赖JSON解析 - try: - error_response = {"choices": [{"index": 0, "delta": {"content": "服务器错误,请稍候再试", "type": "text"}}]} - result_queue.put(error_response) - except Exception: - # 最终备选方案 - pass - result_queue.put(None) - # raise HTTPException( - # status_code=500, detail="Server is error." - # ) - finally: - deepseek_resp.close() - - process_thread = threading.Thread(target=process_data) - process_thread.start() - - while True: - current_time = time.time() - if current_time - last_send_time >= KEEP_ALIVE_TIMEOUT: - - yield ": keep-alive\n\n" - last_send_time = current_time - continue - try: - chunk = result_queue.get(timeout=0.05) - if chunk is None: - # 发送最终统计信息 - prompt_tokens = len(final_prompt) // 4 # 简单估算token数 - thinking_tokens = len(final_thinking) // 4 # 简单估算token数 - completion_tokens = len(final_text) // 4 # 简单估算token数 - usage = { - "prompt_tokens": prompt_tokens, - "completion_tokens": thinking_tokens + completion_tokens, - "total_tokens": prompt_tokens + thinking_tokens + completion_tokens, - "completion_tokens_details": { - "reasoning_tokens": thinking_tokens - }, - } - finish_chunk = { - "id": completion_id, - "object": "chat.completion.chunk", - "created": created_time, - "model": model, - "choices": [ - { - "delta": {}, - "index": 0, - "finish_reason": "stop", - } - ], - "usage": usage, - } - yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n" - yield "data: [DONE]\n\n" - last_send_time = current_time - break - new_choices = [] - for choice in chunk.get("choices", []): - delta = choice.get("delta", {}) - ctype = delta.get("type") - ctext = delta.get("content", "") - if ( - choice - .get("finish_reason") - == "backend_busy" - ): - ctext = '服务器繁忙,请稍候再试' - if search_enabled and ctext.startswith("[citation:"): - ctext = "" - if ctype == "thinking": - if thinking_enabled: - final_thinking += ctext - elif ctype == "text": - final_text += ctext - delta_obj = {} - if not first_chunk_sent: - delta_obj["role"] = "assistant" - first_chunk_sent = True - if ctype == "thinking": - if thinking_enabled: - delta_obj["reasoning_content"] = ctext - elif ctype == "text": - delta_obj["content"] = ctext - if delta_obj: - new_choices.append( - { - "delta": delta_obj, - "index": choice.get("index", 0), - } - ) - if new_choices: - out_chunk = { - "id": completion_id, - "object": "chat.completion.chunk", - "created": created_time, - "model": model, - "choices": new_choices, - } - yield f"data: {json.dumps(out_chunk, ensure_ascii=False)}\n\n" - last_send_time = current_time - except queue.Empty: - continue - except Exception as e: - logger.error(f"[sse_stream] 异常: {e}") - finally: - if getattr(request.state, "use_config_token", False) and hasattr( - request.state, "account" - ): - release_account(request.state.account) - - return StreamingResponse( - sse_stream(), - media_type="text/event-stream", - headers={"Content-Type": "text/event-stream"}, - ) - else: - # 非流式响应处理 - think_list = [] - text_list = [] - result = None - citation_map = {} - - data_queue = queue.Queue() - - def collect_data(): - nonlocal result - ptype = "text" - try: - for raw_line in deepseek_resp.iter_lines(): - try: - line = raw_line.decode("utf-8") - except Exception as e: - logger.warning(f"[chat_completions] 解码失败: {e}") - # 根据当前处理类型添加错误消息 - if ptype == "thinking": - think_list.append('解码失败,请稍候再试') - else: - text_list.append('解码失败,请稍候再试') - data_queue.put(None) - break - if not line: - continue - if line.startswith("data:"): - data_str = line[5:].strip() - if data_str == "[DONE]": - data_queue.put(None) - break - try: - chunk = json.loads(data_str) - - # 提取 v 字段 - if "v" in chunk: - v_value = chunk["v"] - - if "p" in chunk and chunk.get("p") == "response/search_status": - continue - - if "p" in chunk and chunk.get("p") == "response/thinking_content": - ptype = "thinking" - elif "p" in chunk and chunk.get("p") == "response/content": - ptype = "text" - - # 处理字符串形式的 v 值(即文本内容) - if isinstance(v_value, str): - if search_enabled and v_value.startswith("[citation:"): - continue # 跳过 citation 内容 - if ptype == "thinking": - think_list.append(v_value) - else: - text_list.append(v_value) - - # 处理数组更新如状态变更 - elif isinstance(v_value, list): - for item in v_value: - if item.get("p") == "status" and item.get("v") == "FINISHED": - # 构建最终结果 - final_reasoning = "".join(think_list) - final_content = "".join(text_list) - prompt_tokens = len(final_prompt) // 4 # 简单估算token数 - reasoning_tokens = len(final_reasoning) // 4 # 简单估算token数 - completion_tokens = len(final_content) // 4 # 简单估算token数 - result = { - "id": completion_id, - "object": "chat.completion", - "created": created_time, - "model": model, - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": final_content, - "reasoning_content": final_reasoning, - }, - "finish_reason": "stop", - } - ], - "usage": { - "prompt_tokens": prompt_tokens, - "completion_tokens": reasoning_tokens + completion_tokens, - "total_tokens": prompt_tokens + reasoning_tokens + completion_tokens, - "completion_tokens_details": { - "reasoning_tokens": reasoning_tokens - }, - }, - } - data_queue.put("DONE") - return # 提前返回,结束函数 - - except Exception as e: - logger.warning(f"[collect_data] 无法解析: {data_str}, 错误: {e}") - # 根据当前处理类型添加错误消息 - if ptype == "thinking": - think_list.append('解析失败,请稍候再试') - else: - text_list.append('解析失败,请稍候再试') - data_queue.put(None) - break - except Exception as e: - logger.warning(f"[collect_data] 错误: {e}") - # 根据当前处理类型添加错误消息 - if ptype == "thinking": - think_list.append('处理失败,请稍候再试') - else: - text_list.append('处理失败,请稍候再试') - data_queue.put(None) - finally: - deepseek_resp.close() - if result is None: - # 如果没有提前构造 result,则构造默认结果 - final_content = "".join(text_list) - final_reasoning = "".join(think_list) # 修复:应该使用think_list而不是text_list - prompt_tokens = len(final_prompt) // 4 # 简单估算token数 - reasoning_tokens = len(final_reasoning) // 4 # 简单估算token数 - completion_tokens = len(final_content) // 4 # 简单估算token数 - result = { - "id": completion_id, - "object": "chat.completion", - "created": created_time, - "model": model, - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": final_content, - "reasoning_content": final_reasoning, - }, - "finish_reason": "stop", - } - ], - "usage": { - "prompt_tokens": prompt_tokens, - "completion_tokens": reasoning_tokens + completion_tokens, - "total_tokens": prompt_tokens + reasoning_tokens + completion_tokens, - }, - } - data_queue.put("DONE") - - collect_thread = threading.Thread(target=collect_data) - collect_thread.start() - - def generate(): - last_send_time = time.time() - while True: - current_time = time.time() - if current_time - last_send_time >= KEEP_ALIVE_TIMEOUT: - - yield "" - last_send_time = current_time - if not collect_thread.is_alive() and result is not None: - yield json.dumps(result) - break - time.sleep(0.1) - - return StreamingResponse(generate(), media_type="application/json") - except HTTPException as exc: - return JSONResponse(status_code=exc.status_code, content={"error": exc.detail}) - except Exception as exc: - logger.error(f"[chat_completions] 未知异常: {exc}") - return JSONResponse(status_code=500, content={"error": "Internal Server Error"}) - finally: - if getattr(request.state, "use_config_token", False) and hasattr( - request.state, "account" - ): - release_account(request.state.account) - - -# ---------------------------------------------------------------------- -# Claude 路由:/anthropic/v1/messages -# ---------------------------------------------------------------------- -@app.post("/anthropic/v1/messages") -async def claude_messages(request: Request): - try: - # 处理 token 相关逻辑,若认证失败则直接返回错误响应 - try: - determine_claude_mode_and_token(request) - except HTTPException as exc: - return JSONResponse( - status_code=exc.status_code, content={"error": exc.detail} - ) - except Exception as exc: - logger.error(f"[claude_messages] determine_claude_mode_and_token 异常: {exc}") - return JSONResponse( - status_code=500, content={"error": "Claude authentication failed."} - ) - - req_data = await request.json() - model = req_data.get("model") - messages = req_data.get("messages", []) - - if not model or not messages: - raise HTTPException( - status_code=400, detail="Request must include 'model' and 'messages'." - ) - - # 标准化消息内容 - 确保Claude Code兼容性 - normalized_messages = [] - for message in messages: - normalized_message = message.copy() - if isinstance(message.get("content"), list): - # 将数组内容转换为单一字符串 - 改进版本 - content_parts = [] - for content_block in message["content"]: - if content_block.get("type") == "text" and "text" in content_block: - content_parts.append(content_block["text"]) - elif content_block.get("type") == "tool_result": - # 保持工具结果格式不变,但提取内容用于处理 - if "content" in content_block: - content_parts.append(str(content_block["content"])) - # 确保内容非空,避免空字符串导致的问题 - if content_parts: - normalized_message["content"] = "\n".join(content_parts) - elif isinstance(message.get("content"), list) and message["content"]: - # 如果没有提取到文本内容,保持原始格式 - normalized_message["content"] = message["content"] - else: - normalized_message["content"] = "" - normalized_messages.append(normalized_message) - - # 处理工具使用请求 - tools_requested = req_data.get("tools") or [] - has_tools = len(tools_requested) > 0 - - # 检查是否包含工具结果(tool_result) - has_tool_result = False - for message in messages: - if isinstance(message.get("content"), list): - for content_block in message["content"]: - if content_block.get("type") == "tool_result": - has_tool_result = True - break - - # 处理Claude格式请求(使用标准化后的消息) - payload = req_data.copy() - payload["messages"] = normalized_messages.copy() - - # 如果有工具定义,添加工具使用指导的系统消息 - if has_tools and not any(m.get("role") == "system" for m in payload["messages"]): - tool_schemas = [] - for tool in tools_requested: - tool_name = tool.get('name', 'unknown') - tool_desc = tool.get('description', 'No description available') - schema = tool.get('input_schema', {}) - - tool_info = f"Tool: {tool_name}\nDescription: {tool_desc}" - if 'properties' in schema: - props = [] - required = schema.get('required', []) - for prop_name, prop_info in schema['properties'].items(): - prop_type = prop_info.get('type', 'string') - is_req = ' (required)' if prop_name in required else '' - props.append(f" - {prop_name}: {prop_type}{is_req}") - if props: - tool_info += f"\nParameters:\n{chr(10).join(props)}" - tool_schemas.append(tool_info) - - system_message = { - "role": "system", - "content": f"""You are Claude, a helpful AI assistant. You have access to these tools: - -{chr(10).join(tool_schemas)} - -When you need to use tools, you can call multiple tools in a single response. Use this format: - -{{"tool_calls": [ - {{"name": "tool1", "input": {{"param": "value"}}}}, - {{"name": "tool2", "input": {{"param": "value"}}}} -]}} - -IMPORTANT: You can call multiple tools in ONE response. If you need to: -1. Create a directory - include that in tool_calls -2. Write a file - include that in the SAME tool_calls array -3. Run a command - include that in the SAME tool_calls array - -Example of multiple tool calls in one response: -{{"tool_calls": [ - {{"name": "str_replace_editor", "input": {{"command": "create", "path": "pp1/hello.py", "file_text": "print('Hello, World!')"}}}}, - {{"name": "Bash", "input": {{"command": "python pp1/hello.py"}}}} -]}} - -Examples: -- For TodoWrite: {{"name": "TodoWrite", "input": {{"todos": [{{"content": "task", "status": "pending", "activeForm": "doing task"}}]}}}} -- For str_replace_editor: {{"name": "str_replace_editor", "input": {{"command": "create", "path": "file.py", "file_text": "code"}}}} -- For Bash: {{"name": "Bash", "input": {{"command": "cd /path && python file.py"}}}} - -Remember: Output ONLY the JSON, no other text. The response must start with {{ and end with ]}}""" - } - payload["messages"].insert(0, system_message) - - deepseek_resp = await call_claude_via_openai(request, payload) - if not deepseek_resp: - raise HTTPException(status_code=500, detail="Failed to get Claude response.") - - created_time = int(time.time()) - - # 处理响应 - if deepseek_resp.status_code != 200: - deepseek_resp.close() - return JSONResponse( - status_code=500, - content={"error": {"type": "api_error", "message": "Failed to get response"}} - ) - - # 流式响应或普通响应 - if bool(req_data.get("stream", False)): - def claude_sse_stream(): - try: - message_id = f"msg_{int(time.time())}_{random.randint(1000, 9999)}" - input_tokens = sum(len(str(m.get("content", ""))) for m in messages) // 4 - output_tokens = 0 - - # 收集所有响应内容 - full_response_text = "" - response_completed = False - - # 解析DeepSeek流式响应 - for line in deepseek_resp.iter_lines(): - if not line: - continue - try: - line_str = line.decode('utf-8') - except Exception: - continue - - if line_str.startswith('data:'): - data_str = line_str[5:].strip() - if data_str == '[DONE]': - response_completed = True - break - - try: - chunk = json.loads(data_str) - if "v" in chunk and isinstance(chunk["v"], str): - full_response_text += chunk["v"] - elif "v" in chunk and isinstance(chunk["v"], list): - # 检查完成状态 - for item in chunk["v"]: - if item.get("p") == "status" and item.get("v") == "FINISHED": - response_completed = True - break - except (json.JSONDecodeError, KeyError): - continue - - # 现在一次性发送Claude格式的事件 - - # 1. message_start - message_start = { - "type": "message_start", - "message": { - "id": message_id, - "type": "message", - "role": "assistant", - "model": model, - "content": [], - "stop_reason": None, - "stop_sequence": None, - "usage": {"input_tokens": input_tokens, "output_tokens": 0} - } - } - yield f"data: {json.dumps(message_start)}\n\n" - - # 2. 检查是否有工具调用 - 改进的检测逻辑 - detected_tools = [] - - # 清理响应文本 - cleaned_response = full_response_text.strip() - - # 记录原始响应用于调试 - logger.debug(f"[Tool Detection] Raw response: {cleaned_response[:500] if cleaned_response else 'Empty'}") - - # 尝试多种工具调用检测方法 - detected_tools = [] - tool_detected = False - - # 方法1: 检测完整的JSON格式 - if cleaned_response.startswith('{"tool_calls":') and cleaned_response.endswith(']}'): - logger.info(f"[Tool Detection] Method 1: Found tool calls JSON") - try: - tool_data = json.loads(cleaned_response) - for tool_call in tool_data.get('tool_calls', []): - tool_name = tool_call.get('name') - tool_input = tool_call.get('input', {}) - - # 检查是否是有效的工具名称 - if any(tool.get('name') == tool_name for tool in tools_requested): - detected_tools.append({ - 'name': tool_name, - 'input': tool_input - }) - tool_detected = True - except json.JSONDecodeError: - pass - - # 方法2: 使用正则表达式检测嵌入的JSON - if not tool_detected: - tool_call_pattern = r'\{\s*["\']tool_calls["\']\s*:\s*\[(.*?)\]\s*\}' - matches = re.findall(tool_call_pattern, cleaned_response, re.DOTALL) - - for match in matches: - try: - # 尝试解析工具调用JSON - tool_calls_json = f'{{"tool_calls": [{match}]}}' - tool_data = json.loads(tool_calls_json) - - for tool_call in tool_data.get('tool_calls', []): - tool_name = tool_call.get('name') - tool_input = tool_call.get('input', {}) - - # 检查是否是有效的工具名称 - if any(tool.get('name') == tool_name for tool in tools_requested): - detected_tools.append({ - 'name': tool_name, - 'input': tool_input - }) - tool_detected = True - except json.JSONDecodeError: - continue - - # 方法3: 检测特定工具名称的直接调用 (已禁用以避免重复执行) - # 注意:这个方法可能导致Claude Code重复执行命令 - # 当检测到工具名但没有具体参数时,它会返回空的input - # Claude Code接收到这种响应后会尝试重新执行 - # 因此暂时禁用此方法,只依赖方法1和方法2的精确JSON匹配 - ''' - if not tool_detected: - for tool in tools_requested: - tool_name = tool.get('name') - # 检测如 "TodoWrite" 这样的直接工具名称提及 - if tool_name in cleaned_response and any(keyword in cleaned_response.lower() for keyword in ['call', 'use', 'invoke', 'execute']): - # 尝试从上下文推断参数 - detected_tools.append({ - 'name': tool_name, - 'input': {} # 空参数,让调用方处理 - }) - tool_detected = True - break - ''' - - content_index = 0 - if detected_tools: - # 有工具调用 - stop_reason = "tool_use" - for tool_info in detected_tools: - tool_use_id = f"toolu_{int(time.time())}_{random.randint(1000, 9999)}_{content_index}" - tool_name = tool_info['name'] - tool_input = tool_info['input'] - - # content_block_start - yield f"data: {json.dumps({'type': 'content_block_start', 'index': content_index, 'content_block': {'type': 'tool_use', 'id': tool_use_id, 'name': tool_name, 'input': tool_input}})}\n\n" - - # content_block_stop - yield f"data: {json.dumps({'type': 'content_block_stop', 'index': content_index})}\n\n" - - content_index += 1 - output_tokens += len(str(tool_input)) // 4 - else: - # 没有工具调用,普通文本响应 - stop_reason = "end_turn" - if full_response_text: - yield f"data: {json.dumps({'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}})}\n\n" - yield f"data: {json.dumps({'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': full_response_text}})}\n\n" - yield f"data: {json.dumps({'type': 'content_block_stop', 'index': 0})}\n\n" - output_tokens += len(full_response_text) // 4 - - # 3. message_delta 和 message_stop - yield f"data: {json.dumps({'type': 'message_delta', 'delta': {'stop_reason': stop_reason, 'stop_sequence': None}, 'usage': {'output_tokens': output_tokens}})}\n\n" - yield f"data: {json.dumps({'type': 'message_stop'})}\n\n" - - except Exception as e: - logger.error(f"[claude_sse_stream] 异常: {e}") - error_event = { - "type": "error", - "error": {"type": "api_error", "message": f"Stream processing error: {str(e)}"} - } - yield f"data: {json.dumps(error_event)}\n\n" - finally: - try: - deepseek_resp.close() - except Exception: - pass - # 释放账号资源 - if getattr(request.state, "use_config_token", False) and hasattr( - request.state, "account" - ): - release_account(request.state.account) - - return StreamingResponse( - claude_sse_stream(), - media_type="text/event-stream", - headers={"Content-Type": "text/event-stream"}, - ) - else: - # 非流式响应处理 - 添加工具调用支持 - try: - final_content = "" - final_reasoning = "" - - for line in deepseek_resp.iter_lines(): - if not line: - continue - - try: - line_str = line.decode('utf-8') - except Exception as e: - logger.warning(f"[claude_messages] 行解码失败: {e}") - continue - - if line_str.startswith('data:'): - data_str = line_str[5:].strip() - if data_str == '[DONE]': - break - - try: - chunk = json.loads(data_str) - - # 使用DeepSeek的响应格式解析 - 提取 v 字段 - if "v" in chunk: - v_value = chunk["v"] - - # 跳过搜索状态 - if "p" in chunk and chunk.get("p") == "response/search_status": - continue - - # 判断内容类型 - ptype = "text" - if "p" in chunk and chunk.get("p") == "response/thinking_content": - ptype = "thinking" - elif "p" in chunk and chunk.get("p") == "response/content": - ptype = "text" - - # 处理字符串形式的 v 值(即文本内容) - if isinstance(v_value, str): - if ptype == "thinking": - final_reasoning += v_value - else: - final_content += v_value - - # 处理数组更新如状态变更 - elif isinstance(v_value, list): - for item in v_value: - if item.get("p") == "status" and item.get("v") == "FINISHED": - # 完成标志 - break - - except json.JSONDecodeError as e: - logger.warning(f"[claude_messages] JSON解析失败: {e}, data: {data_str}") - continue - except Exception as e: - logger.warning(f"[claude_messages] chunk处理失败: {e}") - continue - - try: - deepseek_resp.close() - except Exception as e: - logger.warning(f"[claude_messages] 关闭响应异常: {e}") - - # 检查是否包含工具调用 - 改进的检测逻辑 - detected_tools = [] - - # 清理响应文本 - cleaned_content = final_content.strip() - - # 尝试多种工具调用检测方法 - tool_detected = False - - # 方法1: 检测完整的JSON格式 - if cleaned_content.startswith('{"tool_calls":') and cleaned_content.endswith(']}'): - try: - tool_data = json.loads(cleaned_content) - for tool_call in tool_data.get('tool_calls', []): - tool_name = tool_call.get('name') - tool_input = tool_call.get('input', {}) - - # 检查是否是有效的工具名称 - if any(tool.get('name') == tool_name for tool in tools_requested): - detected_tools.append({ - 'name': tool_name, - 'input': tool_input - }) - tool_detected = True - except json.JSONDecodeError: - pass - - # 方法2: 使用正则表达式检测嵌入的JSON - if not tool_detected: - tool_call_pattern = r'\{\s*["\']tool_calls["\']\s*:\s*\[(.*?)\]\s*\}' - matches = re.findall(tool_call_pattern, cleaned_content, re.DOTALL) - - for match in matches: - try: - # 尝试解析工具调用JSON - tool_calls_json = f'{{"tool_calls": [{match}]}}' - tool_data = json.loads(tool_calls_json) - - for tool_call in tool_data.get('tool_calls', []): - tool_name = tool_call.get('name') - tool_input = tool_call.get('input', {}) - - # 检查是否是有效的工具名称 - if any(tool.get('name') == tool_name for tool in tools_requested): - detected_tools.append({ - 'name': tool_name, - 'input': tool_input - }) - tool_detected = True - except json.JSONDecodeError: - continue - - # 方法3: 检测特定工具名称的直接调用 (已禁用以避免重复执行) - # 注意:这个方法可能导致Claude Code重复执行命令 - # 当检测到工具名但没有具体参数时,它会返回空的input - # Claude Code接收到这种响应后会尝试重新执行 - # 因此暂时禁用此方法,只依赖方法1和方法2的精确JSON匹配 - ''' - if not tool_detected: - for tool in tools_requested: - tool_name = tool.get('name') - # 检测如 "TodoWrite" 这样的直接工具名称提及 - if tool_name in cleaned_content and any(keyword in cleaned_content.lower() for keyword in ['call', 'use', 'invoke', 'execute']): - # 尝试从上下文推断参数 - detected_tools.append({ - 'name': tool_name, - 'input': {} # 空参数,让调用方处理 - }) - tool_detected = True - break - ''' - - # 构造标准的Anthropic Messages API响应格式 - claude_response = { - "id": f"msg_{int(time.time())}_{random.randint(1000, 9999)}", - "type": "message", - "role": "assistant", - "model": model, - "content": [], - "stop_reason": "tool_use" if detected_tools else "end_turn", - "stop_sequence": None, - "usage": { - "input_tokens": len(str(normalized_messages)) // 4, - "output_tokens": (len(final_content) + len(final_reasoning)) // 4 - } - } - - # 如果有推理内容,添加思考块 - if final_reasoning: - claude_response["content"].append({ - "type": "thinking", - "thinking": final_reasoning - }) - - # 处理工具调用 - if detected_tools: - for i, tool_info in enumerate(detected_tools): - tool_use_id = f"toolu_{int(time.time())}_{random.randint(1000, 9999)}_{i}" - tool_name = tool_info['name'] - tool_input = tool_info['input'] - - claude_response["content"].append({ - "type": "tool_use", - "id": tool_use_id, - "name": tool_name, - "input": tool_input - }) - else: - # 没有工具调用,添加普通文本内容 - if final_content or not final_reasoning: - claude_response["content"].append({ - "type": "text", - "text": final_content or "抱歉,没有生成有效的响应内容。" - }) - - return JSONResponse(content=claude_response, status_code=200) - - except Exception as e: - logger.error(f"[claude_messages] 非流式响应处理异常: {e}") - try: - deepseek_resp.close() - except Exception as close_e: - logger.warning(f"[claude_messages] 关闭响应异常2: {close_e}") - return JSONResponse( - status_code=500, - content={"error": {"type": "api_error", "message": "Response processing error"}} - ) - - except HTTPException as exc: - return JSONResponse(status_code=exc.status_code, content={"error": {"type": "invalid_request_error", "message": exc.detail}}) - except Exception as exc: - logger.error(f"[claude_messages] 未知异常: {exc}") - return JSONResponse(status_code=500, content={"error": {"type": "api_error", "message": "Internal Server Error"}}) - finally: - # 释放账号资源 - if getattr(request.state, "use_config_token", False) and hasattr( - request.state, "account" - ): - release_account(request.state.account) - - -# ---------------------------------------------------------------------- -# Claude 路由:/anthropic/v1/messages/count_tokens -# ---------------------------------------------------------------------- -@app.post("/anthropic/v1/messages/count_tokens") -async def claude_count_tokens(request: Request): - try: - # 处理 token 相关逻辑,若认证失败则直接返回错误响应 - try: - determine_claude_mode_and_token(request) - except HTTPException as exc: - return JSONResponse( - status_code=exc.status_code, content={"error": exc.detail} - ) - except Exception as exc: - logger.error(f"[claude_count_tokens] determine_claude_mode_and_token 异常: {exc}") - return JSONResponse( - status_code=500, content={"error": "Claude authentication failed."} - ) - - req_data = await request.json() - model = req_data.get("model") - messages = req_data.get("messages", []) - system = req_data.get("system", "") - - if not model or not messages: - raise HTTPException( - status_code=400, detail="Request must include 'model' and 'messages'." - ) - - # 计算输入token数量 - def estimate_tokens(text): - """简单的token估算,约4个字符=1个token""" - if isinstance(text, str): - return len(text) // 4 - elif isinstance(text, list): - return sum(estimate_tokens(item.get("text", "")) if isinstance(item, dict) else estimate_tokens(str(item)) for item in text) - else: - return len(str(text)) // 4 - - # 计算消息的token数量 - input_tokens = 0 - - # 添加系统消息的token - if system: - input_tokens += estimate_tokens(system) - - # 添加消息列表的token - for message in messages: - role = message.get("role", "") - content = message.get("content", "") - - # 角色标记大约占用2个token - input_tokens += 2 - - # 内容token计算 - if isinstance(content, list): - for content_block in content: - if isinstance(content_block, dict): - if content_block.get("type") == "text": - input_tokens += estimate_tokens(content_block.get("text", "")) - elif content_block.get("type") == "tool_result": - input_tokens += estimate_tokens(content_block.get("content", "")) - else: - # 其他类型的内容块 - input_tokens += estimate_tokens(str(content_block)) - else: - input_tokens += estimate_tokens(str(content_block)) - else: - input_tokens += estimate_tokens(content) - - # 处理工具定义 - tools = req_data.get("tools", []) - if tools: - for tool in tools: - # 工具名称和描述 - input_tokens += estimate_tokens(tool.get("name", "")) - input_tokens += estimate_tokens(tool.get("description", "")) - - # 工具参数schema - input_schema = tool.get("input_schema", {}) - input_tokens += estimate_tokens(json.dumps(input_schema, ensure_ascii=False)) - - # 构造响应 - response = { - "input_tokens": max(1, input_tokens) # 至少1个token - } - - return JSONResponse(content=response, status_code=200) - - except HTTPException as exc: - return JSONResponse(status_code=exc.status_code, content={"error": {"type": "invalid_request_error", "message": exc.detail}}) - except Exception as exc: - logger.error(f"[claude_count_tokens] 未知异常: {exc}") - return JSONResponse(status_code=500, content={"error": {"type": "api_error", "message": "Internal Server Error"}}) - - -# ---------------------------------------------------------------------- -# (11) 路由:/ -# ---------------------------------------------------------------------- -@app.get("/") -def index(request: Request): - return templates.TemplateResponse("welcome.html", {"request": request}) - - -# ---------------------------------------------------------------------- -# 启动 FastAPI 应用(仅本地运行) +# 本地运行入口 # ---------------------------------------------------------------------- if __name__ == "__main__" and not IS_VERCEL: import uvicorn diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000..5ad4ac0 --- /dev/null +++ b/core/__init__.py @@ -0,0 +1 @@ +# DS2API Core Modules diff --git a/core/auth.py b/core/auth.py new file mode 100644 index 0000000..d2fea67 --- /dev/null +++ b/core/auth.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +"""账号认证与管理模块""" +import random +from fastapi import HTTPException, Request + +from .config import CONFIG, logger +from .deepseek import login_deepseek_via_account, BASE_HEADERS + +# -------------------------- 全局账号队列 -------------------------- +account_queue = [] # 维护所有可用账号 +claude_api_key_queue = [] # 维护所有可用的Claude API keys + + +def init_account_queue(): + """初始化时从配置加载账号""" + global account_queue + account_queue = CONFIG.get("accounts", [])[:] # 深拷贝 + random.shuffle(account_queue) # 初始随机排序 + + +def init_claude_api_key_queue(): + """Claude API keys由用户自己的token提供,这里初始化为空""" + global claude_api_key_queue + claude_api_key_queue = [] + + +# 初始化 +init_account_queue() +init_claude_api_key_queue() + + +# ---------------------------------------------------------------------- +# 辅助函数:获取账号唯一标识(优先 email,否则 mobile) +# ---------------------------------------------------------------------- +def get_account_identifier(account: dict) -> str: + """返回账号的唯一标识,优先使用 email,否则使用 mobile""" + return account.get("email", "").strip() or account.get("mobile", "").strip() + + +# ---------------------------------------------------------------------- +# 账号选择与释放 +# ---------------------------------------------------------------------- +def choose_new_account(exclude_ids=None): + """选择策略: + 1. 优先选择已有 token 的账号(避免登录) + 2. 遍历队列,找到第一个未被 exclude_ids 包含的账号 + 3. 从队列中移除该账号 + 4. 返回该账号(由后续逻辑保证最终会重新入队) + """ + if exclude_ids is None: + exclude_ids = [] + + # 第一轮:优先选择已有 token 的账号 + for i in range(len(account_queue)): + acc = account_queue[i] + acc_id = get_account_identifier(acc) + if acc_id and acc_id not in exclude_ids: + if acc.get("token", "").strip(): # 已有 token + logger.info(f"[choose_new_account] 选择已有token的账号: {acc_id}") + return account_queue.pop(i) + + # 第二轮:选择任意账号(需要登录) + for i in range(len(account_queue)): + acc = account_queue[i] + acc_id = get_account_identifier(acc) + if acc_id and acc_id not in exclude_ids: + logger.info(f"[choose_new_account] 选择需登录的账号: {acc_id}") + return account_queue.pop(i) + + logger.warning("[choose_new_account] 没有可用的账号或所有账号都在使用中") + return None + + +def release_account(account: dict): + """将账号重新加入队列末尾""" + account_queue.append(account) + + +# ---------------------------------------------------------------------- +# Claude API key 管理函数(简化版本) +# ---------------------------------------------------------------------- +def choose_claude_api_key(): + """选择一个可用的Claude API key - 现在直接由用户提供""" + return None + + +def release_claude_api_key(api_key): + """释放Claude API key - 现在无需操作""" + pass + + +# ---------------------------------------------------------------------- +# 判断调用模式:配置模式 vs 用户自带 token +# ---------------------------------------------------------------------- +def determine_mode_and_token(request: Request): + """ + 根据请求头 Authorization 判断使用哪种模式: + - 如果 Bearer token 出现在 CONFIG["keys"] 中,则为配置模式,从 CONFIG["accounts"] 中随机选择一个账号(排除已尝试账号), + 检查该账号是否已有 token,否则调用登录接口获取; + - 否则,直接使用请求中的 Bearer 值作为 DeepSeek token。 + 结果存入 request.state.deepseek_token;配置模式下同时存入 request.state.account 与 request.state.tried_accounts。 + """ + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + raise HTTPException( + status_code=401, detail="Unauthorized: missing Bearer token." + ) + caller_key = auth_header.replace("Bearer ", "", 1).strip() + config_keys = CONFIG.get("keys", []) + if caller_key in config_keys: + request.state.use_config_token = True + request.state.tried_accounts = [] # 初始化已尝试账号 + selected_account = choose_new_account() + if not selected_account: + raise HTTPException( + status_code=429, + detail="No accounts configured or all accounts are busy.", + ) + if not selected_account.get("token", "").strip(): + try: + login_deepseek_via_account(selected_account) + except Exception as e: + logger.error( + f"[determine_mode_and_token] 账号 {get_account_identifier(selected_account)} 登录失败:{e}" + ) + raise HTTPException(status_code=500, detail="Account login failed.") + + request.state.deepseek_token = selected_account.get("token") + request.state.account = selected_account + + else: + request.state.use_config_token = False + request.state.deepseek_token = caller_key + + +def get_auth_headers(request: Request) -> dict: + """返回 DeepSeek 请求所需的公共请求头""" + return {**BASE_HEADERS, "authorization": f"Bearer {request.state.deepseek_token}"} + + +# ---------------------------------------------------------------------- +# Claude 认证相关函数 +# ---------------------------------------------------------------------- +def determine_claude_mode_and_token(request: Request): + """Claude认证:沿用现有的OpenAI接口认证逻辑""" + determine_mode_and_token(request) diff --git a/core/config.py b/core/config.py new file mode 100644 index 0000000..de74f2c --- /dev/null +++ b/core/config.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +"""配置管理模块""" +import base64 +import json +import logging +import os +import sys + +import transformers + +# -------------------------- 获取项目根目录 -------------------------- +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +IS_VERCEL = bool(os.getenv("VERCEL")) or bool(os.getenv("NOW_REGION")) + + +def resolve_path(env_key: str, default_rel: str) -> str: + """解析路径,支持环境变量覆盖""" + raw = os.getenv(env_key) + if raw: + return raw if os.path.isabs(raw) else os.path.join(BASE_DIR, raw) + return os.path.join(BASE_DIR, default_rel) + + +# -------------------------- 日志配置 -------------------------- +logging.basicConfig( + level=os.getenv("LOG_LEVEL", "INFO").upper(), + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], + force=True, +) +logger = logging.getLogger("ds2api") + +# -------------------------- 初始化 tokenizer -------------------------- +chat_tokenizer_dir = resolve_path("DS2API_TOKENIZER_DIR", "") +tokenizer = transformers.AutoTokenizer.from_pretrained( + chat_tokenizer_dir, trust_remote_code=True +) + +# ---------------------------------------------------------------------- +# 配置文件的读写函数 +# ---------------------------------------------------------------------- +CONFIG_PATH = resolve_path("DS2API_CONFIG_PATH", "config.json") + + +def load_config() -> dict: + """加载配置。 + + 优先从环境变量读取: + - DS2API_CONFIG_JSON / CONFIG_JSON: 直接 JSON 字符串,或 base64 编码后的 JSON + + 若未提供环境变量,再从 CONFIG_PATH 指向的文件读取。 + """ + raw_cfg = os.getenv("DS2API_CONFIG_JSON") or os.getenv("CONFIG_JSON") + if raw_cfg: + try: + return json.loads(raw_cfg) + except json.JSONDecodeError: + try: + decoded = base64.b64decode(raw_cfg).decode("utf-8") + return json.loads(decoded) + except Exception as e: + logger.warning(f"[load_config] 环境变量配置解析失败: {e}") + return {} + + try: + with open(CONFIG_PATH, "r", encoding="utf-8") as f: + return json.load(f) + except Exception as e: + logger.warning(f"[load_config] 无法读取配置文件({CONFIG_PATH}): {e}") + return {} + + +def save_config(cfg: dict) -> None: + """将配置写回 config.json。 + + Vercel 环境文件系统通常是只读的;且如果配置来自环境变量,也无法回写。 + 所以这里失败不应影响主流程。 + """ + if os.getenv("DS2API_CONFIG_JSON") or os.getenv("CONFIG_JSON"): + logger.info("[save_config] 配置来自环境变量,跳过写回") + return + + try: + with open(CONFIG_PATH, "w", encoding="utf-8") as f: + json.dump(cfg, f, ensure_ascii=False, indent=2) + except PermissionError as e: + logger.warning(f"[save_config] 配置文件不可写({CONFIG_PATH}): {e}") + except Exception as e: + logger.exception(f"[save_config] 写入 config.json 失败: {e}") + + +# 全局配置 +CONFIG = load_config() +if not CONFIG: + logger.warning( + "[config] 未加载到有效配置,请提供 config.json(路径可用 DS2API_CONFIG_PATH 指定)或设置环境变量 DS2API_CONFIG_JSON" + ) + +# WASM 模块文件路径 +WASM_PATH = resolve_path("DS2API_WASM_PATH", "sha3_wasm_bg.7b9ca65ddd.wasm") + +# 模板目录 +TEMPLATES_DIR = resolve_path("DS2API_TEMPLATES_DIR", "templates") diff --git a/core/deepseek.py b/core/deepseek.py new file mode 100644 index 0000000..93b6705 --- /dev/null +++ b/core/deepseek.py @@ -0,0 +1,132 @@ +# -*- coding: utf-8 -*- +"""DeepSeek API 相关逻辑""" +import time +from curl_cffi import requests +from fastapi import HTTPException + +from .config import CONFIG, save_config, logger + +# ---------------------------------------------------------------------- +# DeepSeek 相关常量 +# ---------------------------------------------------------------------- +DEEPSEEK_HOST = "chat.deepseek.com" +DEEPSEEK_LOGIN_URL = f"https://{DEEPSEEK_HOST}/api/v0/users/login" +DEEPSEEK_CREATE_SESSION_URL = f"https://{DEEPSEEK_HOST}/api/v0/chat_session/create" +DEEPSEEK_CREATE_POW_URL = f"https://{DEEPSEEK_HOST}/api/v0/chat/create_pow_challenge" +DEEPSEEK_COMPLETION_URL = f"https://{DEEPSEEK_HOST}/api/v0/chat/completion" + +BASE_HEADERS = { + "Host": "chat.deepseek.com", + "User-Agent": "DeepSeek/1.0.13 Android/35", + "Accept": "application/json", + "Accept-Encoding": "gzip", + "Content-Type": "application/json", + "x-client-platform": "android", + "x-client-version": "1.3.0-auto-resume", + "x-client-locale": "zh_CN", + "accept-charset": "UTF-8", +} + + +def get_account_identifier(account: dict) -> str: + """返回账号的唯一标识,优先使用 email,否则使用 mobile""" + return account.get("email", "").strip() or account.get("mobile", "").strip() + + +# ---------------------------------------------------------------------- +# 登录函数:支持使用 email 或 mobile 登录 +# ---------------------------------------------------------------------- +def login_deepseek_via_account(account: dict) -> str: + """使用 account 中的 email 或 mobile 登录 DeepSeek, + 成功后将返回的 token 写入 account 并保存至配置文件,返回新 token。 + """ + email = account.get("email", "").strip() + mobile = account.get("mobile", "").strip() + password = account.get("password", "").strip() + if not password or (not email and not mobile): + raise HTTPException( + status_code=400, + detail="账号缺少必要的登录信息(必须提供 email 或 mobile 以及 password)", + ) + if email: + payload = { + "email": email, + "password": password, + "device_id": "deepseek_to_api", + "os": "android", + } + else: + payload = { + "mobile": mobile, + "area_code": None, + "password": password, + "device_id": "deepseek_to_api", + "os": "android", + } + try: + resp = requests.post( + DEEPSEEK_LOGIN_URL, headers=BASE_HEADERS, json=payload, impersonate="safari15_3" + ) + resp.raise_for_status() + except Exception as e: + logger.error(f"[login_deepseek_via_account] 登录请求异常: {e}") + raise HTTPException(status_code=500, detail="Account login failed: 请求异常") + try: + logger.warning(f"[login_deepseek_via_account] {resp.text}") + data = resp.json() + except Exception as e: + logger.error(f"[login_deepseek_via_account] JSON解析失败: {e}") + raise HTTPException( + status_code=500, detail="Account login failed: invalid JSON response" + ) + # 校验响应数据格式是否正确 + if ( + data.get("data") is None + or data["data"].get("biz_data") is None + or data["data"]["biz_data"].get("user") is None + ): + logger.error(f"[login_deepseek_via_account] 登录响应格式错误: {data}") + raise HTTPException( + status_code=500, detail="Account login failed: invalid response format" + ) + new_token = data["data"]["biz_data"]["user"].get("token") + if not new_token: + logger.error(f"[login_deepseek_via_account] 登录响应中缺少 token: {data}") + raise HTTPException( + status_code=500, detail="Account login failed: missing token" + ) + account["token"] = new_token + save_config(CONFIG) + return new_token + + +# ---------------------------------------------------------------------- +# 封装对话接口调用的重试机制 +# ---------------------------------------------------------------------- +def call_completion_endpoint(payload: dict, headers: dict, max_attempts: int = 3): + """调用 DeepSeek 对话接口,支持重试""" + attempts = 0 + while attempts < max_attempts: + try: + deepseek_resp = requests.post( + DEEPSEEK_COMPLETION_URL, + headers=headers, + json=payload, + stream=True, + impersonate="safari15_3", + ) + except Exception as e: + logger.warning(f"[call_completion_endpoint] 请求异常: {e}") + time.sleep(1) + attempts += 1 + continue + if deepseek_resp.status_code == 200: + return deepseek_resp + else: + logger.warning( + f"[call_completion_endpoint] 调用对话接口失败, 状态码: {deepseek_resp.status_code}" + ) + deepseek_resp.close() + time.sleep(1) + attempts += 1 + return None diff --git a/core/messages.py b/core/messages.py new file mode 100644 index 0000000..e2f359c --- /dev/null +++ b/core/messages.py @@ -0,0 +1,118 @@ +# -*- coding: utf-8 -*- +"""消息处理模块""" +import re + +from .config import CONFIG, logger + +# Claude 默认模型 +CLAUDE_DEFAULT_MODEL = "claude-sonnet-4-20250514" + +# 预编译正则表达式(性能优化) +_MARKDOWN_IMAGE_PATTERN = re.compile(r"!\[(.*?)\]\((.*?)\)") + + +# ---------------------------------------------------------------------- +# 消息预处理函数,将多轮对话合并成最终 prompt +# ---------------------------------------------------------------------- +def messages_prepare(messages: list) -> str: + """处理消息列表,合并连续相同角色的消息,并添加角色标签: + - 对于 assistant 消息,加上 <|Assistant|> 前缀及 <|end▁of▁sentence|> 结束标签; + - 对于 user/system 消息(除第一条外)加上 <|User|> 前缀; + - 如果消息 content 为数组,则提取其中 type 为 "text" 的部分; + - 最后移除 markdown 图片格式的内容。 + """ + processed = [] + for m in messages: + role = m.get("role", "") + content = m.get("content", "") + if isinstance(content, list): + texts = [ + item.get("text", "") for item in content if item.get("type") == "text" + ] + text = "\n".join(texts) + else: + text = str(content) + processed.append({"role": role, "text": text}) + if not processed: + return "" + # 合并连续同一角色的消息 + merged = [processed[0]] + for msg in processed[1:]: + if msg["role"] == merged[-1]["role"]: + merged[-1]["text"] += "\n\n" + msg["text"] + else: + merged.append(msg) + # 添加标签 + parts = [] + for idx, block in enumerate(merged): + role = block["role"] + text = block["text"] + if role == "assistant": + parts.append(f"<|Assistant|>{text}<|end▁of▁sentence|>") + elif role in ("user", "system"): + if idx > 0: + parts.append(f"<|User|>{text}") + else: + parts.append(text) + else: + parts.append(text) + final_prompt = "".join(parts) + # 仅移除 markdown 图片格式(不全部移除 !)- 使用预编译的正则表达式 + final_prompt = _MARKDOWN_IMAGE_PATTERN.sub(r"[\1](\2)", final_prompt) + return final_prompt + + +# ---------------------------------------------------------------------- +# OpenAI到Claude格式转换函数 +# ---------------------------------------------------------------------- +def convert_claude_to_deepseek(claude_request: dict) -> dict: + """将Claude格式的请求转换为DeepSeek格式(基于现有OpenAI接口)""" + messages = claude_request.get("messages", []) + model = claude_request.get("model", CLAUDE_DEFAULT_MODEL) + + # 从配置文件读取Claude模型映射 + claude_mapping = CONFIG.get( + "claude_model_mapping", {"fast": "deepseek-chat", "slow": "deepseek-chat"} + ) + + # Claude模型映射到DeepSeek模型 - 基于配置和模型特征判断 + if ( + "opus" in model.lower() + or "reasoner" in model.lower() + or "slow" in model.lower() + ): + deepseek_model = claude_mapping.get("slow", "deepseek-chat") + else: + deepseek_model = claude_mapping.get("fast", "deepseek-chat") + + deepseek_request = {"model": deepseek_model, "messages": messages.copy()} + + # 处理system消息 - 将system参数转换为system role消息 + if "system" in claude_request: + system_msg = {"role": "system", "content": claude_request["system"]} + deepseek_request["messages"].insert(0, system_msg) + + # 添加可选参数 + if "temperature" in claude_request: + deepseek_request["temperature"] = claude_request["temperature"] + if "top_p" in claude_request: + deepseek_request["top_p"] = claude_request["top_p"] + if "stop_sequences" in claude_request: + deepseek_request["stop"] = claude_request["stop_sequences"] + if "stream" in claude_request: + deepseek_request["stream"] = claude_request["stream"] + + return deepseek_request + + +def convert_deepseek_to_claude_format( + deepseek_response: dict, original_claude_model: str = CLAUDE_DEFAULT_MODEL +) -> dict: + """将DeepSeek响应转换为Claude格式的OpenAI响应""" + # DeepSeek响应已经是OpenAI格式,只需要修改模型名称 + if isinstance(deepseek_response, dict): + claude_response = deepseek_response.copy() + claude_response["model"] = original_claude_model + return claude_response + + return deepseek_response diff --git a/core/pow.py b/core/pow.py new file mode 100644 index 0000000..84d1f22 --- /dev/null +++ b/core/pow.py @@ -0,0 +1,247 @@ +# -*- coding: utf-8 -*- +"""PoW (Proof of Work) 计算模块""" +import base64 +import ctypes +import json +import struct +import threading +import time + +from curl_cffi import requests +from wasmtime import Engine, Linker, Module, Store + +from .config import CONFIG, WASM_PATH, logger + +# ---------------------------------------------------------------------- +# WASM 模块缓存 - 避免每次请求都重新加载 +# ---------------------------------------------------------------------- +_wasm_cache_lock = threading.Lock() +_wasm_engine = None +_wasm_module = None + + +def _get_cached_wasm_module(wasm_path: str): + """获取缓存的 WASM 模块,首次调用时加载""" + global _wasm_engine, _wasm_module + + if _wasm_module is not None: + return _wasm_engine, _wasm_module + + with _wasm_cache_lock: + # 双重检查锁定 + if _wasm_module is not None: + return _wasm_engine, _wasm_module + + try: + with open(wasm_path, "rb") as f: + wasm_bytes = f.read() + _wasm_engine = Engine() + _wasm_module = Module(_wasm_engine, wasm_bytes) + logger.info(f"[WASM] 已缓存 WASM 模块: {wasm_path}") + except Exception as e: + logger.error(f"[WASM] 加载 WASM 模块失败: {e}") + raise RuntimeError(f"加载 wasm 文件失败: {wasm_path}, 错误: {e}") + + return _wasm_engine, _wasm_module + + +# 启动时预加载 WASM 模块 +try: + _get_cached_wasm_module(WASM_PATH) +except Exception as e: + logger.warning(f"[WASM] 启动时预加载失败(将在首次使用时重试): {e}") + + +def get_account_identifier(account: dict) -> str: + """返回账号的唯一标识""" + return account.get("email", "").strip() or account.get("mobile", "").strip() + + +# ---------------------------------------------------------------------- +# 使用 WASM 模块计算 PoW 答案的辅助函数 +# ---------------------------------------------------------------------- +def compute_pow_answer( + algorithm: str, + challenge_str: str, + salt: str, + difficulty: int, + expire_at: int, + signature: str, + target_path: str, + wasm_path: str, +) -> int: + """ + 使用 WASM 模块计算 DeepSeekHash 答案(answer)。 + 根据 JS 逻辑: + - 拼接前缀: "{salt}_{expire_at}_" + - 将 challenge 与前缀写入 wasm 内存后调用 wasm_solve 进行求解, + - 从 wasm 内存中读取状态与求解结果, + - 若状态非 0,则返回整数形式的答案,否则返回 None。 + + 优化:使用缓存的 WASM 模块,避免每次请求都重新加载文件。 + """ + if algorithm != "DeepSeekHashV1": + raise ValueError(f"不支持的算法:{algorithm}") + + prefix = f"{salt}_{expire_at}_" + + # 获取缓存的 WASM 模块(避免重复加载文件) + engine, module = _get_cached_wasm_module(wasm_path) + + # 每次调用创建新的 Store 和实例(必须的,因为 Store 不是线程安全的) + store = Store(engine) + linker = Linker(engine) + instance = linker.instantiate(store, module) + exports = instance.exports(store) + + try: + memory = exports["memory"] + add_to_stack = exports["__wbindgen_add_to_stack_pointer"] + alloc = exports["__wbindgen_export_0"] + wasm_solve = exports["wasm_solve"] + except KeyError as e: + raise RuntimeError(f"缺少 wasm 导出函数: {e}") + + def write_memory(offset: int, data: bytes): + size = len(data) + base_addr = ctypes.cast(memory.data_ptr(store), ctypes.c_void_p).value + ctypes.memmove(base_addr + offset, data, size) + + def read_memory(offset: int, size: int) -> bytes: + base_addr = ctypes.cast(memory.data_ptr(store), ctypes.c_void_p).value + return ctypes.string_at(base_addr + offset, size) + + def encode_string(text: str): + data = text.encode("utf-8") + length = len(data) + ptr_val = alloc(store, length, 1) + ptr = int(ptr_val.value) if hasattr(ptr_val, "value") else int(ptr_val) + write_memory(ptr, data) + return ptr, length + + # 1. 申请 16 字节栈空间 + retptr = add_to_stack(store, -16) + # 2. 编码 challenge 与 prefix 到 wasm 内存中 + ptr_challenge, len_challenge = encode_string(challenge_str) + ptr_prefix, len_prefix = encode_string(prefix) + # 3. 调用 wasm_solve(注意:difficulty 以 float 形式传入) + wasm_solve( + store, + retptr, + ptr_challenge, + len_challenge, + ptr_prefix, + len_prefix, + float(difficulty), + ) + # 4. 从 retptr 处读取 4 字节状态和 8 字节求解结果 + status_bytes = read_memory(retptr, 4) + if len(status_bytes) != 4: + add_to_stack(store, 16) + raise RuntimeError("读取状态字节失败") + status = struct.unpack(" str | None: + """创建 DeepSeek 会话 + + Args: + request: FastAPI 请求对象 + max_attempts: 最大重试次数 + + Returns: + 会话 ID,如果失败返回 None + """ + attempts = 0 + while attempts < max_attempts: + headers = get_auth_headers(request) + try: + resp = cffi_requests.post( + DEEPSEEK_CREATE_SESSION_URL, + headers=headers, + json={"agent": "chat"}, + impersonate="safari15_3", + ) + except Exception as e: + logger.error(f"[create_session] 请求异常: {e}") + attempts += 1 + continue + + try: + data = resp.json() + except Exception as e: + logger.error(f"[create_session] JSON解析异常: {e}") + data = {} + + if resp.status_code == 200 and data.get("code") == 0: + session_id = data["data"]["biz_data"]["id"] + resp.close() + return session_id + else: + code = data.get("code") + logger.warning( + f"[create_session] 创建会话失败, code={code}, msg={data.get('msg')}" + ) + resp.close() + + # 配置模式下尝试切换账号 + if request.state.use_config_token: + current_id = get_account_identifier(request.state.account) + if not hasattr(request.state, "tried_accounts"): + request.state.tried_accounts = [] + if current_id not in request.state.tried_accounts: + request.state.tried_accounts.append(current_id) + new_account = choose_new_account(request.state.tried_accounts) + if new_account is None: + break + try: + login_deepseek_via_account(new_account) + except Exception as e: + logger.error( + f"[create_session] 账号 {get_account_identifier(new_account)} 登录失败:{e}" + ) + attempts += 1 + continue + request.state.account = new_account + request.state.deepseek_token = new_account.get("token") + else: + attempts += 1 + continue + attempts += 1 + return None + + +def get_pow(request: Request, max_attempts: int = 3) -> str | None: + """获取 PoW 响应的包装函数 + + Args: + request: FastAPI 请求对象 + max_attempts: 最大重试次数 + + Returns: + Base64 编码的 PoW 响应,如果失败返回 None + """ + return get_pow_response( + request, + get_auth_headers, + choose_new_account, + login_deepseek_via_account, + DEEPSEEK_CREATE_POW_URL, + max_attempts, + ) + + +def prepare_completion_request( + request: Request, + session_id: str, + prompt: str, + thinking_enabled: bool = False, + search_enabled: bool = False, + max_attempts: int = 3, +): + """准备并执行对话补全请求 + + Args: + request: FastAPI 请求对象 + session_id: 会话 ID + prompt: 处理后的提示词 + thinking_enabled: 是否启用思考模式 + search_enabled: 是否启用搜索 + max_attempts: 最大重试次数 + + Returns: + DeepSeek 响应对象,如果失败返回 None + """ + pow_resp = get_pow(request, max_attempts) + if not pow_resp: + return None + + headers = {**get_auth_headers(request), "x-ds-pow-response": pow_resp} + payload = { + "chat_session_id": session_id, + "parent_message_id": None, + "prompt": prompt, + "ref_file_ids": [], + "thinking_enabled": thinking_enabled, + "search_enabled": search_enabled, + } + + return call_completion_endpoint(payload, headers, max_attempts) + + +def get_model_config(model: str) -> tuple[bool, bool]: + """根据模型名称获取配置 + + Args: + model: 模型名称 + + Returns: + (thinking_enabled, search_enabled) 元组 + """ + model_lower = model.lower() + + if model_lower in ["deepseek-v3", "deepseek-chat"]: + return False, False + elif model_lower in ["deepseek-r1", "deepseek-reasoner"]: + return True, False + elif model_lower in ["deepseek-v3-search", "deepseek-chat-search"]: + return False, True + elif model_lower in ["deepseek-r1-search", "deepseek-reasoner-search"]: + return True, True + else: + return None, None # 不支持的模型 + + +def cleanup_account(request: Request): + """清理账号资源(将账号放回队列)""" + if getattr(request.state, "use_config_token", False) and hasattr(request.state, "account"): + release_account(request.state.account) diff --git a/dev.py b/dev.py new file mode 100644 index 0000000..f4d1d0c --- /dev/null +++ b/dev.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +DS2API 开发服务器 - 统一启动后端和前端 + +使用方法: + python dev.py # 同时启动后端和前端 + python dev.py --backend # 仅启动后端 + python dev.py --frontend # 仅启动前端 + python dev.py --install # 安装所有依赖 + +环境变量: + PORT - 后端服务端口,默认 5001 + LOG_LEVEL - 日志级别,默认 INFO +""" +import os +import sys +import signal +import subprocess +import time +from pathlib import Path + +# 配置 +BACKEND_PORT = int(os.getenv("PORT", "5001")) +FRONTEND_PORT = 5173 +HOST = os.getenv("HOST", "0.0.0.0") +LOG_LEVEL = os.getenv("LOG_LEVEL", "info").lower() +PROJECT_DIR = Path(__file__).parent +WEBUI_DIR = PROJECT_DIR / "webui" +REQUIREMENTS_FILE = PROJECT_DIR / "requirements.txt" + +processes = [] + + +def install_dependencies(): + """安装所有 Python 和 Node.js 依赖""" + print("\n📦 安装 Python 依赖...") + subprocess.run([ + sys.executable, "-m", "pip", "install", "-r", str(REQUIREMENTS_FILE), "-q" + ], check=True) + print("✅ Python 依赖安装完成") + + if WEBUI_DIR.exists(): + print("\n📦 安装前端依赖...") + subprocess.run(["npm", "install"], cwd=WEBUI_DIR, check=True) + print("✅ 前端依赖安装完成") + + print("\n🎉 所有依赖安装完成!运行 `python dev.py` 启动服务\n") + + +def signal_handler(sig, frame): + """处理退出信号,终止所有子进程""" + print("\n\n🛑 正在关闭所有服务...") + for proc in processes: + if proc.poll() is None: + proc.terminate() + try: + proc.wait(timeout=3) + except subprocess.TimeoutExpired: + proc.kill() + print("👋 已退出\n") + sys.exit(0) + + +def start_backend(): + """启动后端服务""" + print(f"🚀 启动后端服务... http://localhost:{BACKEND_PORT}") + proc = subprocess.Popen( + [ + sys.executable, "-m", "uvicorn", + "app:app", + "--host", HOST, + "--port", str(BACKEND_PORT), + "--reload", + "--reload-dir", str(PROJECT_DIR), + "--log-level", LOG_LEVEL, + ], + cwd=PROJECT_DIR, + ) + processes.append(proc) + return proc + + +def start_frontend(): + """启动前端开发服务器""" + if not WEBUI_DIR.exists(): + print("⚠️ webui 目录不存在,跳过前端启动") + return None + + node_modules = WEBUI_DIR / "node_modules" + if not node_modules.exists(): + print("📦 安装前端依赖...") + subprocess.run(["npm", "install"], cwd=WEBUI_DIR, check=True) + + print(f"🎨 启动前端服务... http://localhost:{FRONTEND_PORT}") + proc = subprocess.Popen( + ["npm", "run", "dev"], + cwd=WEBUI_DIR, + ) + processes.append(proc) + return proc + + +def main(): + # 解析参数 + if "--install" in sys.argv or "-i" in sys.argv: + install_dependencies() + return + + backend_only = "--backend" in sys.argv or "-b" in sys.argv + frontend_only = "--frontend" in sys.argv or "-f" in sys.argv + + # 注册信号处理 + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + print("\n" + "=" * 50) + print(" DS2API 开发服务器") + print("=" * 50) + + if frontend_only: + start_frontend() + elif backend_only: + start_backend() + else: + # 同时启动 + start_backend() + time.sleep(1) # 等待后端启动 + start_frontend() + + print("\n" + "-" * 50) + if not frontend_only: + print(f"📡 后端 API: http://localhost:{BACKEND_PORT}") + if not backend_only: + print(f"🎨 管理界面: http://localhost:{FRONTEND_PORT}") + print("-" * 50) + print("按 Ctrl+C 停止所有服务\n") + + # 等待进程结束 + try: + while processes: + for proc in processes[:]: + if proc.poll() is not None: + processes.remove(proc) + time.sleep(0.5) + except KeyboardInterrupt: + signal_handler(None, None) + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt index 41888c8..4855447 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,19 @@ +# DS2API 依赖 +# 安装命令: pip install -r requirements.txt + +# Web 框架 fastapi>=0.110.0,<1.0.0 -uvicorn>=0.24.0,<1.0.0 -curl_cffi>=0.7.0,<1.0.0 -transformers>=4.39.0,<5.0.0 -wasmtime>=14.0.0,<20.0.0 +uvicorn[standard]>=0.24.0,<1.0.0 + +# HTTP 客户端 +curl_cffi>=0.7.0 +httpx>=0.25.0 + +# 模板引擎 jinja2>=3.1.0,<4.0.0 + +# Tokenizer(用于 token 计数) +transformers>=4.39.0,<5.0.0 + +# WASM 运行时(用于 PoW 计算) +wasmtime>=14.0.0 diff --git a/routes/__init__.py b/routes/__init__.py new file mode 100644 index 0000000..bd402eb --- /dev/null +++ b/routes/__init__.py @@ -0,0 +1 @@ +# DS2API Routes diff --git a/routes/admin.py b/routes/admin.py new file mode 100644 index 0000000..b965b1f --- /dev/null +++ b/routes/admin.py @@ -0,0 +1,419 @@ +# -*- coding: utf-8 -*- +"""Admin API 路由 - 管理界面后端""" +import base64 +import json +import os +import httpx + +from fastapi import APIRouter, HTTPException, Request, Depends +from fastapi.responses import JSONResponse +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials + +from core.config import CONFIG, save_config, logger +from core.auth import account_queue, init_account_queue + +router = APIRouter(prefix="/admin", tags=["admin"]) +security = HTTPBearer(auto_error=False) + +# Admin Key 验证 +ADMIN_KEY = os.getenv("DS2API_ADMIN_KEY", "") + +# Vercel 预配置(可通过环境变量设置) +VERCEL_TOKEN = os.getenv("VERCEL_TOKEN", "") +VERCEL_PROJECT_ID = os.getenv("VERCEL_PROJECT_ID", "") +VERCEL_TEAM_ID = os.getenv("VERCEL_TEAM_ID", "") + + +def verify_admin(credentials: HTTPAuthorizationCredentials = Depends(security)): + """验证 Admin 权限""" + if not ADMIN_KEY: + # 未配置 Admin Key,允许访问(开发模式) + return True + if not credentials or credentials.credentials != ADMIN_KEY: + raise HTTPException(status_code=401, detail="Invalid admin key") + return True + + +# ---------------------------------------------------------------------- +# Vercel 预配置信息 +# ---------------------------------------------------------------------- +@router.get("/vercel/config") +async def get_vercel_config(_: bool = Depends(verify_admin)): + """获取预配置的 Vercel 信息(脱敏)""" + return JSONResponse(content={ + "has_token": bool(VERCEL_TOKEN), + "project_id": VERCEL_PROJECT_ID, + "team_id": VERCEL_TEAM_ID, + "token_preview": VERCEL_TOKEN[:8] + "****" if VERCEL_TOKEN else "", + }) + + +# ---------------------------------------------------------------------- +# 配置管理 +# ---------------------------------------------------------------------- +@router.get("/config") +async def get_config(_: bool = Depends(verify_admin)): + """获取当前配置(密码脱敏)""" + safe_config = { + "keys": CONFIG.get("keys", []), + "accounts": [], + "claude_model_mapping": CONFIG.get("claude_model_mapping", {}), + } + for acc in CONFIG.get("accounts", []): + safe_acc = { + "email": acc.get("email", ""), + "mobile": acc.get("mobile", ""), + "has_password": bool(acc.get("password")), + "has_token": bool(acc.get("token")), + } + safe_config["accounts"].append(safe_acc) + return JSONResponse(content=safe_config) + + +@router.post("/config") +async def update_config(request: Request, _: bool = Depends(verify_admin)): + """更新完整配置""" + try: + new_config = await request.json() + + # 更新 keys + if "keys" in new_config: + CONFIG["keys"] = new_config["keys"] + + # 更新 accounts + if "accounts" in new_config: + CONFIG["accounts"] = new_config["accounts"] + init_account_queue() # 重新初始化账号队列 + + # 更新 claude_model_mapping + if "claude_model_mapping" in new_config: + CONFIG["claude_model_mapping"] = new_config["claude_model_mapping"] + + save_config(CONFIG) + return JSONResponse(content={"success": True, "message": "配置已更新"}) + except Exception as e: + logger.error(f"[update_config] 错误: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +# ---------------------------------------------------------------------- +# API Keys 管理 +# ---------------------------------------------------------------------- +@router.post("/keys") +async def add_key(request: Request, _: bool = Depends(verify_admin)): + """添加 API Key""" + data = await request.json() + key = data.get("key", "").strip() + if not key: + raise HTTPException(status_code=400, detail="Key 不能为空") + if key in CONFIG.get("keys", []): + raise HTTPException(status_code=400, detail="Key 已存在") + + if "keys" not in CONFIG: + CONFIG["keys"] = [] + CONFIG["keys"].append(key) + save_config(CONFIG) + return JSONResponse(content={"success": True}) + + +@router.delete("/keys/{key}") +async def delete_key(key: str, _: bool = Depends(verify_admin)): + """删除 API Key""" + if key not in CONFIG.get("keys", []): + raise HTTPException(status_code=404, detail="Key 不存在") + CONFIG["keys"].remove(key) + save_config(CONFIG) + return JSONResponse(content={"success": True}) + + +# ---------------------------------------------------------------------- +# 账号管理 +# ---------------------------------------------------------------------- +@router.post("/accounts") +async def add_account(request: Request, _: bool = Depends(verify_admin)): + """添加账号""" + data = await request.json() + email = data.get("email", "").strip() + mobile = data.get("mobile", "").strip() + password = data.get("password", "").strip() + + if not password: + raise HTTPException(status_code=400, detail="密码不能为空") + if not email and not mobile: + raise HTTPException(status_code=400, detail="Email 或手机号至少填一个") + + # 检查重复 + for acc in CONFIG.get("accounts", []): + if email and acc.get("email") == email: + raise HTTPException(status_code=400, detail="该 Email 已存在") + if mobile and acc.get("mobile") == mobile: + raise HTTPException(status_code=400, detail="该手机号已存在") + + new_account = {"password": password, "token": ""} + if email: + new_account["email"] = email + if mobile: + new_account["mobile"] = mobile + + if "accounts" not in CONFIG: + CONFIG["accounts"] = [] + CONFIG["accounts"].append(new_account) + init_account_queue() + save_config(CONFIG) + return JSONResponse(content={"success": True}) + + +@router.delete("/accounts/{identifier}") +async def delete_account(identifier: str, _: bool = Depends(verify_admin)): + """删除账号(通过 email 或 mobile)""" + accounts = CONFIG.get("accounts", []) + for i, acc in enumerate(accounts): + if acc.get("email") == identifier or acc.get("mobile") == identifier: + accounts.pop(i) + init_account_queue() + save_config(CONFIG) + return JSONResponse(content={"success": True}) + raise HTTPException(status_code=404, detail="账号不存在") + + +# ---------------------------------------------------------------------- +# 批量导入 +# ---------------------------------------------------------------------- +@router.post("/import") +async def batch_import(request: Request, _: bool = Depends(verify_admin)): + """批量导入配置 (JSON 格式)""" + try: + data = await request.json() + imported_keys = 0 + imported_accounts = 0 + + # 导入 keys + if "keys" in data: + for key in data["keys"]: + if key not in CONFIG.get("keys", []): + if "keys" not in CONFIG: + CONFIG["keys"] = [] + CONFIG["keys"].append(key) + imported_keys += 1 + + # 导入 accounts + if "accounts" in data: + existing_ids = set() + for acc in CONFIG.get("accounts", []): + existing_ids.add(acc.get("email", "")) + existing_ids.add(acc.get("mobile", "")) + + for acc in data["accounts"]: + acc_id = acc.get("email", "") or acc.get("mobile", "") + if acc_id and acc_id not in existing_ids: + if "accounts" not in CONFIG: + CONFIG["accounts"] = [] + CONFIG["accounts"].append(acc) + existing_ids.add(acc_id) + imported_accounts += 1 + + init_account_queue() + save_config(CONFIG) + + return JSONResponse(content={ + "success": True, + "imported_keys": imported_keys, + "imported_accounts": imported_accounts, + }) + except json.JSONDecodeError: + raise HTTPException(status_code=400, detail="无效的 JSON 格式") + except Exception as e: + logger.error(f"[batch_import] 错误: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +# ---------------------------------------------------------------------- +# API 测试 +# ---------------------------------------------------------------------- +@router.post("/test") +async def test_api(request: Request, _: bool = Depends(verify_admin)): + """测试 API 调用""" + try: + data = await request.json() + model = data.get("model", "deepseek-chat") + message = data.get("message", "你好") + api_key = data.get("api_key", "") + + if not api_key: + # 使用配置中的第一个 key + keys = CONFIG.get("keys", []) + if not keys: + raise HTTPException(status_code=400, detail="没有可用的 API Key") + api_key = keys[0] + + # 构造请求 + host = request.headers.get("host", "localhost:5001") + scheme = "https" if "vercel" in host.lower() else "http" + base_url = f"{scheme}://{host}" + + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + f"{base_url}/v1/chat/completions", + headers={"Authorization": f"Bearer {api_key}"}, + json={ + "model": model, + "messages": [{"role": "user", "content": message}], + "stream": False, + }, + ) + + return JSONResponse(content={ + "success": response.status_code == 200, + "status_code": response.status_code, + "response": response.json() if response.status_code == 200 else response.text, + }) + except Exception as e: + logger.error(f"[test_api] 错误: {e}") + return JSONResponse(content={ + "success": False, + "error": str(e), + }) + + +# ---------------------------------------------------------------------- +# Vercel 同步 +# ---------------------------------------------------------------------- +@router.post("/vercel/sync") +async def sync_to_vercel(request: Request, _: bool = Depends(verify_admin)): + """同步配置到 Vercel 并触发重新部署""" + try: + data = await request.json() + vercel_token = data.get("vercel_token", "") + project_id = data.get("project_id", "") + team_id = data.get("team_id", "") # 可选 + + # 支持使用预配置的 token + if vercel_token == "__USE_PRECONFIG__" or not vercel_token: + vercel_token = VERCEL_TOKEN + if not project_id: + project_id = VERCEL_PROJECT_ID + if not team_id: + team_id = VERCEL_TEAM_ID + + if not vercel_token or not project_id: + raise HTTPException(status_code=400, detail="需要 Vercel Token 和 Project ID(可通过环境变量 VERCEL_TOKEN 和 VERCEL_PROJECT_ID 预配置)") + + # 准备配置 JSON + config_json = json.dumps(CONFIG, ensure_ascii=False, separators=(",", ":")) + config_b64 = base64.b64encode(config_json.encode("utf-8")).decode("utf-8") + + headers = {"Authorization": f"Bearer {vercel_token}"} + base_url = "https://api.vercel.com" + + async with httpx.AsyncClient(timeout=30.0) as client: + # 1. 获取现有环境变量 + params = {"teamId": team_id} if team_id else {} + env_resp = await client.get( + f"{base_url}/v9/projects/{project_id}/env", + headers=headers, + params=params, + ) + + if env_resp.status_code != 200: + raise HTTPException(status_code=env_resp.status_code, detail=f"获取环境变量失败: {env_resp.text}") + + env_vars = env_resp.json().get("envs", []) + existing_env = None + for env in env_vars: + if env.get("key") == "DS2API_CONFIG_JSON": + existing_env = env + break + + # 2. 更新或创建环境变量 + if existing_env: + # 更新 + env_id = existing_env["id"] + update_resp = await client.patch( + f"{base_url}/v9/projects/{project_id}/env/{env_id}", + headers=headers, + params=params, + json={"value": config_b64}, + ) + if update_resp.status_code not in [200, 201]: + raise HTTPException(status_code=update_resp.status_code, detail=f"更新环境变量失败: {update_resp.text}") + else: + # 创建 + create_resp = await client.post( + f"{base_url}/v10/projects/{project_id}/env", + headers=headers, + params=params, + json={ + "key": "DS2API_CONFIG_JSON", + "value": config_b64, + "type": "encrypted", + "target": ["production", "preview"], + }, + ) + if create_resp.status_code not in [200, 201]: + raise HTTPException(status_code=create_resp.status_code, detail=f"创建环境变量失败: {create_resp.text}") + + # 3. 触发重新部署 (获取最新的 git 信息并创建新部署) + # 获取项目信息 + project_resp = await client.get( + f"{base_url}/v9/projects/{project_id}", + headers=headers, + params=params, + ) + + if project_resp.status_code == 200: + project_data = project_resp.json() + repo = project_data.get("link", {}) + + if repo.get("type") == "github": + # 使用 GitHub 信息创建部署 + deploy_resp = await client.post( + f"{base_url}/v13/deployments", + headers=headers, + params=params, + json={ + "name": project_id, + "project": project_id, + "target": "production", + "gitSource": { + "type": "github", + "repoId": repo.get("repoId"), + "ref": repo.get("productionBranch", "main"), + }, + }, + ) + + if deploy_resp.status_code in [200, 201]: + deploy_data = deploy_resp.json() + return JSONResponse(content={ + "success": True, + "message": "配置已同步,正在重新部署...", + "deployment_url": deploy_data.get("url"), + }) + + # 如果无法自动部署,返回成功但提示手动部署 + return JSONResponse(content={ + "success": True, + "message": "配置已同步到 Vercel,请手动触发重新部署", + "manual_deploy_required": True, + }) + + except HTTPException: + raise + except Exception as e: + logger.error(f"[sync_to_vercel] 错误: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +# ---------------------------------------------------------------------- +# 导出配置 +# ---------------------------------------------------------------------- +@router.get("/export") +async def export_config(_: bool = Depends(verify_admin)): + """导出完整配置(JSON 和 Base64)""" + config_json = json.dumps(CONFIG, ensure_ascii=False, separators=(",", ":")) + config_b64 = base64.b64encode(config_json.encode("utf-8")).decode("utf-8") + + return JSONResponse(content={ + "json": config_json, + "base64": config_b64, + }) diff --git a/routes/claude.py b/routes/claude.py new file mode 100644 index 0000000..c0ae735 --- /dev/null +++ b/routes/claude.py @@ -0,0 +1,590 @@ +# -*- coding: utf-8 -*- +"""Claude API 路由""" +import json +import random +import re +import time + +from curl_cffi import requests as cffi_requests +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import JSONResponse, StreamingResponse + +from core.config import CONFIG, logger +from core.auth import ( + determine_claude_mode_and_token, + get_auth_headers, +) +from core.deepseek import call_completion_endpoint +from core.session_manager import ( + create_session, + get_pow, + get_model_config, + cleanup_account, +) +from core.messages import ( + messages_prepare, + convert_claude_to_deepseek, + CLAUDE_DEFAULT_MODEL, +) + +router = APIRouter() + +# 预编译正则表达式(性能优化) +_TOOL_CALL_PATTERN = re.compile(r'\{\s*["\']tool_calls["\']\s*:\s*\[(.*?)\]\s*\}', re.DOTALL) + + +# ---------------------------------------------------------------------- +# 通过 OpenAI 接口调用 Claude +# ---------------------------------------------------------------------- +async def call_claude_via_openai(request: Request, claude_payload: dict): + """通过现有OpenAI接口调用Claude(实际调用DeepSeek)""" + deepseek_payload = convert_claude_to_deepseek(claude_payload) + + try: + session_id = create_session(request) + if not session_id: + raise HTTPException(status_code=401, detail="invalid token.") + + pow_resp = get_pow(request) + if not pow_resp: + raise HTTPException( + status_code=401, + detail="Failed to get PoW (invalid token or unknown error).", + ) + + model = deepseek_payload.get("model", "deepseek-chat") + messages = deepseek_payload.get("messages", []) + + # 使用会话管理器获取模型配置 + thinking_enabled, search_enabled = get_model_config(model) + if thinking_enabled is None: + # 默认配置 + thinking_enabled = False + search_enabled = False + + final_prompt = messages_prepare(messages) + + headers = {**get_auth_headers(request), "x-ds-pow-response": pow_resp} + payload = { + "chat_session_id": session_id, + "parent_message_id": None, + "prompt": final_prompt, + "ref_file_ids": [], + "thinking_enabled": thinking_enabled, + "search_enabled": search_enabled, + } + + deepseek_resp = call_completion_endpoint(payload, headers, max_attempts=3) + return deepseek_resp + + except Exception as e: + logger.error(f"[call_claude_via_openai] 调用失败: {e}") + return None + + +# ---------------------------------------------------------------------- +# Claude 路由:模型列表 +# ---------------------------------------------------------------------- +@router.get("/anthropic/v1/models") +def list_claude_models(): + models_list = [ + { + "id": "claude-sonnet-4-20250514", + "object": "model", + "created": 1715635200, + "owned_by": "anthropic", + }, + { + "id": "claude-sonnet-4-20250514-fast", + "object": "model", + "created": 1715635200, + "owned_by": "anthropic", + }, + { + "id": "claude-sonnet-4-20250514-slow", + "object": "model", + "created": 1715635200, + "owned_by": "anthropic", + }, + ] + data = {"object": "list", "data": models_list} + return JSONResponse(content=data, status_code=200) + + +# ---------------------------------------------------------------------- +# Claude 路由:/anthropic/v1/messages +# ---------------------------------------------------------------------- +@router.post("/anthropic/v1/messages") +async def claude_messages(request: Request): + try: + try: + determine_claude_mode_and_token(request) + except HTTPException as exc: + return JSONResponse( + status_code=exc.status_code, content={"error": exc.detail} + ) + except Exception as exc: + logger.error(f"[claude_messages] determine_claude_mode_and_token 异常: {exc}") + return JSONResponse( + status_code=500, content={"error": "Claude authentication failed."} + ) + + req_data = await request.json() + model = req_data.get("model") + messages = req_data.get("messages", []) + + if not model or not messages: + raise HTTPException( + status_code=400, detail="Request must include 'model' and 'messages'." + ) + + # 标准化消息内容 + normalized_messages = [] + for message in messages: + normalized_message = message.copy() + if isinstance(message.get("content"), list): + content_parts = [] + for content_block in message["content"]: + if content_block.get("type") == "text" and "text" in content_block: + content_parts.append(content_block["text"]) + elif content_block.get("type") == "tool_result": + if "content" in content_block: + content_parts.append(str(content_block["content"])) + if content_parts: + normalized_message["content"] = "\n".join(content_parts) + elif isinstance(message.get("content"), list) and message["content"]: + normalized_message["content"] = message["content"] + else: + normalized_message["content"] = "" + normalized_messages.append(normalized_message) + + tools_requested = req_data.get("tools") or [] + has_tools = len(tools_requested) > 0 + + payload = req_data.copy() + payload["messages"] = normalized_messages.copy() + + # 如果有工具定义,添加工具使用指导的系统消息 + if has_tools and not any(m.get("role") == "system" for m in payload["messages"]): + tool_schemas = [] + for tool in tools_requested: + tool_name = tool.get("name", "unknown") + tool_desc = tool.get("description", "No description available") + schema = tool.get("input_schema", {}) + + tool_info = f"Tool: {tool_name}\nDescription: {tool_desc}" + if "properties" in schema: + props = [] + required = schema.get("required", []) + for prop_name, prop_info in schema["properties"].items(): + prop_type = prop_info.get("type", "string") + is_req = " (required)" if prop_name in required else "" + props.append(f" - {prop_name}: {prop_type}{is_req}") + if props: + tool_info += f"\nParameters:\n{chr(10).join(props)}" + tool_schemas.append(tool_info) + + system_message = { + "role": "system", + "content": f"""You are Claude, a helpful AI assistant. You have access to these tools: + +{chr(10).join(tool_schemas)} + +When you need to use tools, you can call multiple tools in a single response. Use this format: + +{{"tool_calls": [ + {{"name": "tool1", "input": {{"param": "value"}}}}, + {{"name": "tool2", "input": {{"param": "value"}}}} +]}} + +IMPORTANT: You can call multiple tools in ONE response. + +Remember: Output ONLY the JSON, no other text. The response must start with {{ and end with ]}}""", + } + payload["messages"].insert(0, system_message) + + deepseek_resp = await call_claude_via_openai(request, payload) + if not deepseek_resp: + raise HTTPException(status_code=500, detail="Failed to get Claude response.") + + if deepseek_resp.status_code != 200: + deepseek_resp.close() + return JSONResponse( + status_code=500, + content={"error": {"type": "api_error", "message": "Failed to get response"}}, + ) + + # 流式响应或普通响应 + if bool(req_data.get("stream", False)): + + def claude_sse_stream(): + # 智能超时配置 + STREAM_IDLE_TIMEOUT = 30 # 无新内容超时(秒) + + try: + message_id = f"msg_{int(time.time())}_{random.randint(1000, 9999)}" + input_tokens = sum(len(str(m.get("content", ""))) for m in messages) // 4 + output_tokens = 0 + full_response_text = "" + last_content_time = time.time() + has_content = False + + for line in deepseek_resp.iter_lines(): + current_time = time.time() + + # 智能超时检测 + if has_content and (current_time - last_content_time) > STREAM_IDLE_TIMEOUT: + logger.warning(f"[claude_sse_stream] 智能超时: 已有内容但 {STREAM_IDLE_TIMEOUT}s 无新数据,强制结束") + break + + if not line: + continue + try: + line_str = line.decode("utf-8") + except Exception: + continue + + if line_str.startswith("data:"): + data_str = line_str[5:].strip() + if data_str == "[DONE]": + break + + try: + chunk = json.loads(data_str) + + # 检测内容审核/敏感词阻止 + if "error" in chunk or chunk.get("code") == "content_filter": + logger.warning(f"[claude_sse_stream] 检测到内容过滤: {chunk}") + break + + if "v" in chunk and isinstance(chunk["v"], str): + content = chunk["v"] + # 检查是否是 FINISHED 状态 + if content == "FINISHED": + break + full_response_text += content + if content: + has_content = True + last_content_time = current_time + elif "v" in chunk and isinstance(chunk["v"], list): + for item in chunk["v"]: + if item.get("p") == "status" and item.get("v") == "FINISHED": + break + except (json.JSONDecodeError, KeyError): + continue + + # 发送Claude格式的事件 + message_start = { + "type": "message_start", + "message": { + "id": message_id, + "type": "message", + "role": "assistant", + "model": model, + "content": [], + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": input_tokens, "output_tokens": 0}, + }, + } + yield f"data: {json.dumps(message_start)}\n\n" + + # 检查工具调用 + detected_tools = [] + cleaned_response = full_response_text.strip() + + if cleaned_response.startswith('{"tool_calls":') and cleaned_response.endswith("]}"): + try: + tool_data = json.loads(cleaned_response) + for tool_call in tool_data.get("tool_calls", []): + tool_name = tool_call.get("name") + tool_input = tool_call.get("input", {}) + if any(tool.get("name") == tool_name for tool in tools_requested): + detected_tools.append({"name": tool_name, "input": tool_input}) + except json.JSONDecodeError: + pass + + if not detected_tools: + # 使用预编译的正则表达式 + matches = _TOOL_CALL_PATTERN.findall(cleaned_response) + for match in matches: + try: + tool_calls_json = f'{{"tool_calls": [{match}]}}' + tool_data = json.loads(tool_calls_json) + for tool_call in tool_data.get("tool_calls", []): + tool_name = tool_call.get("name") + tool_input = tool_call.get("input", {}) + if any(tool.get("name") == tool_name for tool in tools_requested): + detected_tools.append({"name": tool_name, "input": tool_input}) + except json.JSONDecodeError: + continue + + content_index = 0 + if detected_tools: + stop_reason = "tool_use" + for tool_info in detected_tools: + tool_use_id = f"toolu_{int(time.time())}_{random.randint(1000, 9999)}_{content_index}" + tool_name = tool_info["name"] + tool_input = tool_info["input"] + + yield f"data: {json.dumps({'type': 'content_block_start', 'index': content_index, 'content_block': {'type': 'tool_use', 'id': tool_use_id, 'name': tool_name, 'input': tool_input}})}\n\n" + yield f"data: {json.dumps({'type': 'content_block_stop', 'index': content_index})}\n\n" + + content_index += 1 + output_tokens += len(str(tool_input)) // 4 + else: + stop_reason = "end_turn" + if full_response_text: + yield f"data: {json.dumps({'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}})}\n\n" + yield f"data: {json.dumps({'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': full_response_text}})}\n\n" + yield f"data: {json.dumps({'type': 'content_block_stop', 'index': 0})}\n\n" + output_tokens += len(full_response_text) // 4 + + yield f"data: {json.dumps({'type': 'message_delta', 'delta': {'stop_reason': stop_reason, 'stop_sequence': None}, 'usage': {'output_tokens': output_tokens}})}\n\n" + yield f"data: {json.dumps({'type': 'message_stop'})}\n\n" + + except Exception as e: + logger.error(f"[claude_sse_stream] 异常: {e}") + error_event = { + "type": "error", + "error": {"type": "api_error", "message": f"Stream processing error: {str(e)}"}, + } + yield f"data: {json.dumps(error_event)}\n\n" + finally: + try: + deepseek_resp.close() + except Exception: + pass + cleanup_account(request) + + return StreamingResponse( + claude_sse_stream(), + media_type="text/event-stream", + headers={"Content-Type": "text/event-stream"}, + ) + else: + # 非流式响应处理 + try: + final_content = "" + final_reasoning = "" + + for line in deepseek_resp.iter_lines(): + if not line: + continue + try: + line_str = line.decode("utf-8") + except Exception as e: + logger.warning(f"[claude_messages] 行解码失败: {e}") + continue + + if line_str.startswith("data:"): + data_str = line_str[5:].strip() + if data_str == "[DONE]": + break + + try: + chunk = json.loads(data_str) + if "v" in chunk: + v_value = chunk["v"] + if "p" in chunk and chunk.get("p") == "response/search_status": + continue + ptype = "text" + if "p" in chunk and chunk.get("p") == "response/thinking_content": + ptype = "thinking" + elif "p" in chunk and chunk.get("p") == "response/content": + ptype = "text" + if isinstance(v_value, str): + if ptype == "thinking": + final_reasoning += v_value + else: + final_content += v_value + elif isinstance(v_value, list): + for item in v_value: + if item.get("p") == "status" and item.get("v") == "FINISHED": + break + except json.JSONDecodeError as e: + logger.warning(f"[claude_messages] JSON解析失败: {e}") + continue + except Exception as e: + logger.warning(f"[claude_messages] chunk处理失败: {e}") + continue + + try: + deepseek_resp.close() + except Exception as e: + logger.warning(f"[claude_messages] 关闭响应异常: {e}") + + # 检查工具调用 + detected_tools = [] + cleaned_content = final_content.strip() + + if cleaned_content.startswith('{"tool_calls":') and cleaned_content.endswith("]}"): + try: + tool_data = json.loads(cleaned_content) + for tool_call in tool_data.get("tool_calls", []): + tool_name = tool_call.get("name") + tool_input = tool_call.get("input", {}) + if any(tool.get("name") == tool_name for tool in tools_requested): + detected_tools.append({"name": tool_name, "input": tool_input}) + except json.JSONDecodeError: + pass + + if not detected_tools: + # 使用预编译的正则表达式 + matches = _TOOL_CALL_PATTERN.findall(cleaned_content) + for match in matches: + try: + tool_calls_json = f'{{"tool_calls": [{match}]}}' + tool_data = json.loads(tool_calls_json) + for tool_call in tool_data.get("tool_calls", []): + tool_name = tool_call.get("name") + tool_input = tool_call.get("input", {}) + if any(tool.get("name") == tool_name for tool in tools_requested): + detected_tools.append({"name": tool_name, "input": tool_input}) + except json.JSONDecodeError: + continue + + # 构造响应 + claude_response = { + "id": f"msg_{int(time.time())}_{random.randint(1000, 9999)}", + "type": "message", + "role": "assistant", + "model": model, + "content": [], + "stop_reason": "tool_use" if detected_tools else "end_turn", + "stop_sequence": None, + "usage": { + "input_tokens": len(str(normalized_messages)) // 4, + "output_tokens": (len(final_content) + len(final_reasoning)) // 4, + }, + } + + if final_reasoning: + claude_response["content"].append({"type": "thinking", "thinking": final_reasoning}) + + if detected_tools: + for i, tool_info in enumerate(detected_tools): + tool_use_id = f"toolu_{int(time.time())}_{random.randint(1000, 9999)}_{i}" + claude_response["content"].append({ + "type": "tool_use", + "id": tool_use_id, + "name": tool_info["name"], + "input": tool_info["input"], + }) + else: + if final_content or not final_reasoning: + claude_response["content"].append({ + "type": "text", + "text": final_content or "抱歉,没有生成有效的响应内容。", + }) + + return JSONResponse(content=claude_response, status_code=200) + + except Exception as e: + logger.error(f"[claude_messages] 非流式响应处理异常: {e}") + try: + deepseek_resp.close() + except Exception as close_e: + logger.warning(f"[claude_messages] 关闭响应异常2: {close_e}") + return JSONResponse( + status_code=500, + content={"error": {"type": "api_error", "message": "Response processing error"}}, + ) + + except HTTPException as exc: + return JSONResponse( + status_code=exc.status_code, + content={"error": {"type": "invalid_request_error", "message": exc.detail}}, + ) + except Exception as exc: + logger.error(f"[claude_messages] 未知异常: {exc}") + return JSONResponse( + status_code=500, + content={"error": {"type": "api_error", "message": "Internal Server Error"}}, + ) + finally: + cleanup_account(request) + + +# ---------------------------------------------------------------------- +# Claude 路由:/anthropic/v1/messages/count_tokens +# ---------------------------------------------------------------------- +@router.post("/anthropic/v1/messages/count_tokens") +async def claude_count_tokens(request: Request): + try: + try: + determine_claude_mode_and_token(request) + except HTTPException as exc: + return JSONResponse(status_code=exc.status_code, content={"error": exc.detail}) + except Exception as exc: + logger.error(f"[claude_count_tokens] determine_claude_mode_and_token 异常: {exc}") + return JSONResponse(status_code=500, content={"error": "Claude authentication failed."}) + + req_data = await request.json() + model = req_data.get("model") + messages = req_data.get("messages", []) + system = req_data.get("system", "") + + if not model or not messages: + raise HTTPException( + status_code=400, detail="Request must include 'model' and 'messages'." + ) + + def estimate_tokens(text): + if isinstance(text, str): + return len(text) // 4 + elif isinstance(text, list): + return sum( + estimate_tokens(item.get("text", "")) + if isinstance(item, dict) + else estimate_tokens(str(item)) + for item in text + ) + else: + return len(str(text)) // 4 + + input_tokens = 0 + + if system: + input_tokens += estimate_tokens(system) + + for message in messages: + content = message.get("content", "") + input_tokens += 2 # 角色标记 + + if isinstance(content, list): + for content_block in content: + if isinstance(content_block, dict): + if content_block.get("type") == "text": + input_tokens += estimate_tokens(content_block.get("text", "")) + elif content_block.get("type") == "tool_result": + input_tokens += estimate_tokens(content_block.get("content", "")) + else: + input_tokens += estimate_tokens(str(content_block)) + else: + input_tokens += estimate_tokens(str(content_block)) + else: + input_tokens += estimate_tokens(content) + + tools = req_data.get("tools", []) + if tools: + for tool in tools: + input_tokens += estimate_tokens(tool.get("name", "")) + input_tokens += estimate_tokens(tool.get("description", "")) + input_schema = tool.get("input_schema", {}) + input_tokens += estimate_tokens(json.dumps(input_schema, ensure_ascii=False)) + + response = {"input_tokens": max(1, input_tokens)} + return JSONResponse(content=response, status_code=200) + + except HTTPException as exc: + return JSONResponse( + status_code=exc.status_code, + content={"error": {"type": "invalid_request_error", "message": exc.detail}}, + ) + except Exception as exc: + logger.error(f"[claude_count_tokens] 未知异常: {exc}") + return JSONResponse( + status_code=500, + content={"error": {"type": "api_error", "message": "Internal Server Error"}}, + ) diff --git a/routes/home.py b/routes/home.py new file mode 100644 index 0000000..cb08c73 --- /dev/null +++ b/routes/home.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +"""首页路由""" +from fastapi import APIRouter, Request +from fastapi.templating import Jinja2Templates + +from core.config import TEMPLATES_DIR + +router = APIRouter() +templates = Jinja2Templates(directory=TEMPLATES_DIR) + + +@router.get("/") +def index(request: Request): + return templates.TemplateResponse("welcome.html", {"request": request}) diff --git a/routes/openai.py b/routes/openai.py new file mode 100644 index 0000000..1b44d08 --- /dev/null +++ b/routes/openai.py @@ -0,0 +1,525 @@ +# -*- coding: utf-8 -*- +"""OpenAI 兼容路由""" +import json +import queue +import re +import threading +import time + +from curl_cffi import requests as cffi_requests +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import JSONResponse, StreamingResponse + +from core.config import CONFIG, logger +from core.auth import ( + determine_mode_and_token, + get_auth_headers, + release_account, +) +from core.deepseek import call_completion_endpoint +from core.session_manager import ( + create_session, + get_pow, + get_model_config, + cleanup_account, +) +from core.messages import messages_prepare + +router = APIRouter() + +# 添加保活超时配置(5秒) +KEEP_ALIVE_TIMEOUT = 5 + +# 预编译正则表达式(性能优化) +_CITATION_PATTERN = re.compile(r"^\[citation:") + + +# ---------------------------------------------------------------------- +# 路由:/v1/models +# ---------------------------------------------------------------------- +@router.get("/v1/models") +def list_models(): + models_list = [ + { + "id": "deepseek-chat", + "object": "model", + "created": 1677610602, + "owned_by": "deepseek", + "permission": [], + }, + { + "id": "deepseek-reasoner", + "object": "model", + "created": 1677610602, + "owned_by": "deepseek", + "permission": [], + }, + { + "id": "deepseek-chat-search", + "object": "model", + "created": 1677610602, + "owned_by": "deepseek", + "permission": [], + }, + { + "id": "deepseek-reasoner-search", + "object": "model", + "created": 1677610602, + "owned_by": "deepseek", + "permission": [], + }, + ] + data = {"object": "list", "data": models_list} + return JSONResponse(content=data, status_code=200) + + +# ---------------------------------------------------------------------- +# 路由:/v1/chat/completions +# ---------------------------------------------------------------------- +@router.post("/v1/chat/completions") +async def chat_completions(request: Request): + try: + # 处理 token 相关逻辑,若登录失败则直接返回错误响应 + try: + determine_mode_and_token(request) + except HTTPException as exc: + return JSONResponse( + status_code=exc.status_code, content={"error": exc.detail} + ) + except Exception as exc: + logger.error(f"[chat_completions] determine_mode_and_token 异常: {exc}") + return JSONResponse( + status_code=500, content={"error": "Account login failed."} + ) + + req_data = await request.json() + model = req_data.get("model") + messages = req_data.get("messages", []) + if not model or not messages: + raise HTTPException( + status_code=400, detail="Request must include 'model' and 'messages'." + ) + + # 使用会话管理器获取模型配置 + thinking_enabled, search_enabled = get_model_config(model) + if thinking_enabled is None: + raise HTTPException( + status_code=503, detail=f"Model '{model}' is not available." + ) + + # 使用 messages_prepare 函数构造最终 prompt + final_prompt = messages_prepare(messages) + session_id = create_session(request) + if not session_id: + raise HTTPException(status_code=401, detail="invalid token.") + + pow_resp = get_pow(request) + if not pow_resp: + raise HTTPException( + status_code=401, + detail="Failed to get PoW (invalid token or unknown error).", + ) + + headers = {**get_auth_headers(request), "x-ds-pow-response": pow_resp} + payload = { + "chat_session_id": session_id, + "parent_message_id": None, + "prompt": final_prompt, + "ref_file_ids": [], + "thinking_enabled": thinking_enabled, + "search_enabled": search_enabled, + } + + deepseek_resp = call_completion_endpoint(payload, headers, max_attempts=3) + if not deepseek_resp: + raise HTTPException(status_code=500, detail="Failed to get completion.") + created_time = int(time.time()) + completion_id = f"{session_id}" + + # 流式响应(SSE)或普通响应 + if bool(req_data.get("stream", False)): + if deepseek_resp.status_code != 200: + deepseek_resp.close() + return JSONResponse( + content=deepseek_resp.content, status_code=deepseek_resp.status_code + ) + + def sse_stream(): + # 智能超时配置 + STREAM_IDLE_TIMEOUT = 30 # 无新内容超时(秒) + MAX_KEEPALIVE_COUNT = 10 # 最大连续 keepalive 次数 + + try: + final_text = "" + final_thinking = "" + first_chunk_sent = False + result_queue = queue.Queue() + last_send_time = time.time() + last_content_time = time.time() # 最后收到有效内容的时间 + keepalive_count = 0 # 连续 keepalive 计数 + has_content = False # 是否收到过内容 + + def process_data(): + nonlocal has_content + ptype = "text" + try: + for raw_line in deepseek_resp.iter_lines(): + try: + line = raw_line.decode("utf-8") + except Exception as e: + logger.warning(f"[sse_stream] 解码失败: {e}") + error_type = "thinking" if ptype == "thinking" else "text" + busy_content_str = f'{{"choices":[{{"index":0,"delta":{{"content":"解码失败,请稍候再试","type":"{error_type}"}}}}],"model":"","chunk_token_usage":1,"created":0,"message_id":-1,"parent_id":-1}}' + try: + busy_content = json.loads(busy_content_str) + result_queue.put(busy_content) + except json.JSONDecodeError: + result_queue.put({"choices": [{"index": 0, "delta": {"content": "解码失败", "type": "text"}}]}) + result_queue.put(None) + break + if not line: + continue + if line.startswith("data:"): + data_str = line[5:].strip() + if data_str == "[DONE]": + result_queue.put(None) + break + try: + chunk = json.loads(data_str) + + # 检测内容审核/敏感词阻止 + if "error" in chunk or chunk.get("code") == "content_filter": + logger.warning(f"[sse_stream] 检测到内容过滤: {chunk}") + result_queue.put({"choices": [{"index": 0, "finish_reason": "content_filter"}]}) + result_queue.put(None) + return + + if "v" in chunk: + v_value = chunk["v"] + content = "" + if "p" in chunk and chunk.get("p") == "response/search_status": + continue + if "p" in chunk and chunk.get("p") == "response/thinking_content": + ptype = "thinking" + elif "p" in chunk and chunk.get("p") == "response/content": + ptype = "text" + if isinstance(v_value, str): + # 检查是否是 FINISHED 状态 + if v_value == "FINISHED": + result_queue.put({"choices": [{"index": 0, "finish_reason": "stop"}]}) + result_queue.put(None) + return + content = v_value + if content: + has_content = True + elif isinstance(v_value, list): + for item in v_value: + if item.get("p") == "status" and item.get("v") == "FINISHED": + result_queue.put({"choices": [{"index": 0, "finish_reason": "stop"}]}) + result_queue.put(None) + return + continue + unified_chunk = { + "choices": [{ + "index": 0, + "delta": {"content": content, "type": ptype} + }], + "model": "", + "chunk_token_usage": len(content) // 4, + "created": 0, + "message_id": -1, + "parent_id": -1 + } + result_queue.put(unified_chunk) + except Exception as e: + logger.warning(f"[sse_stream] 无法解析: {data_str}, 错误: {e}") + error_type = "thinking" if ptype == "thinking" else "text" + busy_content_str = f'{{"choices":[{{"index":0,"delta":{{"content":"解析失败,请稍候再试","type":"{error_type}"}}}}],"model":"","chunk_token_usage":1,"created":0,"message_id":-1,"parent_id":-1}}' + try: + busy_content = json.loads(busy_content_str) + result_queue.put(busy_content) + except json.JSONDecodeError: + result_queue.put({"choices": [{"index": 0, "delta": {"content": "解析失败", "type": "text"}}]}) + result_queue.put(None) + break + except Exception as e: + logger.warning(f"[sse_stream] 错误: {e}") + try: + error_response = {"choices": [{"index": 0, "delta": {"content": "服务器错误,请稍候再试", "type": "text"}}]} + result_queue.put(error_response) + except Exception: + pass + result_queue.put(None) + finally: + deepseek_resp.close() + + process_thread = threading.Thread(target=process_data) + process_thread.start() + + while True: + current_time = time.time() + + # 智能超时检测:如果已有内容且长时间无新数据,强制结束 + if has_content and (current_time - last_content_time) > STREAM_IDLE_TIMEOUT: + logger.warning(f"[sse_stream] 智能超时: 已有内容但 {STREAM_IDLE_TIMEOUT}s 无新数据,强制结束") + break + + # 连续 keepalive 检测:如果已有内容且连续多次 keepalive,强制结束 + if has_content and keepalive_count >= MAX_KEEPALIVE_COUNT: + logger.warning(f"[sse_stream] 智能超时: 连续 {MAX_KEEPALIVE_COUNT} 次 keepalive,强制结束") + break + + if current_time - last_send_time >= KEEP_ALIVE_TIMEOUT: + yield ": keep-alive\n\n" + last_send_time = current_time + keepalive_count += 1 + continue + + try: + chunk = result_queue.get(timeout=0.05) + keepalive_count = 0 # 重置 keepalive 计数 + + if chunk is None: + prompt_tokens = len(final_prompt) // 4 + thinking_tokens = len(final_thinking) // 4 + completion_tokens = len(final_text) // 4 + usage = { + "prompt_tokens": prompt_tokens, + "completion_tokens": thinking_tokens + completion_tokens, + "total_tokens": prompt_tokens + thinking_tokens + completion_tokens, + "completion_tokens_details": {"reasoning_tokens": thinking_tokens}, + } + finish_chunk = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created_time, + "model": model, + "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}], + "usage": usage, + } + yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + last_send_time = current_time + break + + new_choices = [] + for choice in chunk.get("choices", []): + delta = choice.get("delta", {}) + ctype = delta.get("type") + ctext = delta.get("content", "") + if choice.get("finish_reason") == "backend_busy": + ctext = "服务器繁忙,请稍候再试" + if choice.get("finish_reason") == "content_filter": + # 内容过滤,正常结束 + pass + if search_enabled and ctext.startswith("[citation:"): + ctext = "" + if ctype == "thinking": + if thinking_enabled: + final_thinking += ctext + elif ctype == "text": + final_text += ctext + delta_obj = {} + if not first_chunk_sent: + delta_obj["role"] = "assistant" + first_chunk_sent = True + if ctype == "thinking": + if thinking_enabled: + delta_obj["reasoning_content"] = ctext + elif ctype == "text": + delta_obj["content"] = ctext + if delta_obj: + new_choices.append({"delta": delta_obj, "index": choice.get("index", 0)}) + + if new_choices: + last_content_time = current_time # 更新最后内容时间 + out_chunk = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created_time, + "model": model, + "choices": new_choices, + } + yield f"data: {json.dumps(out_chunk, ensure_ascii=False)}\n\n" + last_send_time = current_time + except queue.Empty: + continue + + # 如果是超时退出,也发送结束标记 + if has_content: + prompt_tokens = len(final_prompt) // 4 + thinking_tokens = len(final_thinking) // 4 + completion_tokens = len(final_text) // 4 + usage = { + "prompt_tokens": prompt_tokens, + "completion_tokens": thinking_tokens + completion_tokens, + "total_tokens": prompt_tokens + thinking_tokens + completion_tokens, + "completion_tokens_details": {"reasoning_tokens": thinking_tokens}, + } + finish_chunk = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created_time, + "model": model, + "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}], + "usage": usage, + } + yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + + except Exception as e: + logger.error(f"[sse_stream] 异常: {e}") + finally: + cleanup_account(request) + + return StreamingResponse( + sse_stream(), + media_type="text/event-stream", + headers={"Content-Type": "text/event-stream"}, + ) + else: + # 非流式响应处理 + think_list = [] + text_list = [] + result = None + + data_queue = queue.Queue() + + def collect_data(): + nonlocal result + ptype = "text" + try: + for raw_line in deepseek_resp.iter_lines(): + try: + line = raw_line.decode("utf-8") + except Exception as e: + logger.warning(f"[chat_completions] 解码失败: {e}") + if ptype == "thinking": + think_list.append("解码失败,请稍候再试") + else: + text_list.append("解码失败,请稍候再试") + data_queue.put(None) + break + if not line: + continue + if line.startswith("data:"): + data_str = line[5:].strip() + if data_str == "[DONE]": + data_queue.put(None) + break + try: + chunk = json.loads(data_str) + if "v" in chunk: + v_value = chunk["v"] + if "p" in chunk and chunk.get("p") == "response/search_status": + continue + if "p" in chunk and chunk.get("p") == "response/thinking_content": + ptype = "thinking" + elif "p" in chunk and chunk.get("p") == "response/content": + ptype = "text" + if isinstance(v_value, str): + if search_enabled and v_value.startswith("[citation:"): + continue + if ptype == "thinking": + think_list.append(v_value) + else: + text_list.append(v_value) + elif isinstance(v_value, list): + for item in v_value: + if item.get("p") == "status" and item.get("v") == "FINISHED": + final_reasoning = "".join(think_list) + final_content = "".join(text_list) + prompt_tokens = len(final_prompt) // 4 + reasoning_tokens = len(final_reasoning) // 4 + completion_tokens = len(final_content) // 4 + result = { + "id": completion_id, + "object": "chat.completion", + "created": created_time, + "model": model, + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": final_content, + "reasoning_content": final_reasoning, + }, + "finish_reason": "stop", + }], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": reasoning_tokens + completion_tokens, + "total_tokens": prompt_tokens + reasoning_tokens + completion_tokens, + "completion_tokens_details": {"reasoning_tokens": reasoning_tokens}, + }, + } + data_queue.put("DONE") + return + except Exception as e: + logger.warning(f"[collect_data] 无法解析: {data_str}, 错误: {e}") + if ptype == "thinking": + think_list.append("解析失败,请稍候再试") + else: + text_list.append("解析失败,请稍候再试") + data_queue.put(None) + break + except Exception as e: + logger.warning(f"[collect_data] 错误: {e}") + if ptype == "thinking": + think_list.append("处理失败,请稍候再试") + else: + text_list.append("处理失败,请稍候再试") + data_queue.put(None) + finally: + deepseek_resp.close() + if result is None: + final_content = "".join(text_list) + final_reasoning = "".join(think_list) + prompt_tokens = len(final_prompt) // 4 + reasoning_tokens = len(final_reasoning) // 4 + completion_tokens = len(final_content) // 4 + result = { + "id": completion_id, + "object": "chat.completion", + "created": created_time, + "model": model, + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": final_content, + "reasoning_content": final_reasoning, + }, + "finish_reason": "stop", + }], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": reasoning_tokens + completion_tokens, + "total_tokens": prompt_tokens + reasoning_tokens + completion_tokens, + }, + } + data_queue.put("DONE") + + collect_thread = threading.Thread(target=collect_data) + collect_thread.start() + + def generate(): + last_send_time = time.time() + while True: + current_time = time.time() + if current_time - last_send_time >= KEEP_ALIVE_TIMEOUT: + yield "" + last_send_time = current_time + if not collect_thread.is_alive() and result is not None: + yield json.dumps(result) + break + time.sleep(0.1) + + return StreamingResponse(generate(), media_type="application/json") + except HTTPException as exc: + return JSONResponse(status_code=exc.status_code, content={"error": exc.detail}) + except Exception as exc: + logger.error(f"[chat_completions] 未知异常: {exc}") + return JSONResponse(status_code=500, content={"error": "Internal Server Error"}) + finally: + cleanup_account(request) diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..6823132 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,133 @@ +# DS2API 测试文档 + +## 测试文件结构 + +``` +tests/ +├── __init__.py # 测试模块初始化 +├── test_unit.py # 单元测试(不依赖网络) +├── test_all.py # API 集成测试 +├── test_accounts.py # 账号池测试 +└── run_tests.sh # 测试运行脚本 +``` + +## 快速开始 + +### 运行所有测试 + +```bash +# 使用脚本 +./tests/run_tests.sh all + +# 或直接运行 +python3 tests/test_unit.py # 单元测试 +python3 tests/test_all.py # API 测试 +``` + +### 运行单元测试 + +```bash +python3 tests/test_unit.py +``` + +测试内容: +- 配置加载 +- 消息处理(`messages_prepare`) +- WASM 缓存 +- 模型配置获取 +- 正则表达式模式 + +### 运行 API 集成测试 + +```bash +# 完整测试 +python3 tests/test_all.py + +# 快速测试(跳过耗时测试) +python3 tests/test_all.py --quick + +# 指定端点 +python3 tests/test_all.py --endpoint http://your-server.com + +# 详细输出 +python3 tests/test_all.py --verbose +``` + +测试覆盖: + +| 类别 | 测试项 | +|-----|--------| +| 基础 | 服务健康检查 | +| OpenAI | 模型列表、非流式对话、流式对话、无效模型处理、认证错误 | +| Claude | 模型列表、非流式消息、流式消息、Token 计数 | +| 高级 | 多轮对话、长输入处理、Reasoner 模式 | + +### 运行账号测试 + +```bash +# 测试所有账号登录 +python3 tests/test_accounts.py --login + +# 测试账号轮换 +python3 tests/test_accounts.py --rotation + +# 运行所有 +python3 tests/test_accounts.py --all +``` + +## 配置 + +测试使用 `config.json` 中的配置: + +```json +{ + "keys": ["test-api-key-001"], + "accounts": [ + {"email": "xxx@gmail.com", "password": "xxx", "token": ""} + ] +} +``` + +## 预期输出 + +### 单元测试 + +``` +Ran 13 tests in 8.685s +OK +``` + +### API 测试 + +``` +📊 测试报告 +总计: 10 个测试 +✅ 通过: 10 +❌ 失败: 0 +⏱️ 耗时: 15.32s +📈 通过率: 100.0% +``` + +## 故障排除 + +### 服务未运行 + +``` +⚠️ 服务未运行,跳过其他测试 +``` + +解决:先启动服务 `python dev.py` + +### 认证失败 + +``` +❌ 失败: 状态码: 401 +``` + +解决:检查 `config.json` 中的 API key 和账号配置 + +### 流式测试超时 + +可能是 DeepSeek API 响应慢,可以尝试: +- 使用 `--quick` 模式 +- 增加测试超时时间 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..ab6dc64 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# DS2API 测试模块 diff --git a/tests/run_tests.sh b/tests/run_tests.sh new file mode 100755 index 0000000..69f2b24 --- /dev/null +++ b/tests/run_tests.sh @@ -0,0 +1,111 @@ +#!/bin/bash +# DS2API 测试运行器 + +set -e + +cd "$(dirname "$0")/.." + +echo "==================================================" +echo " 🧪 DS2API 测试套件" +echo "==================================================" +echo "" + +# 颜色 +GREEN='\033[0;32m' +RED='\033[0;31m' +YELLOW='\033[1;33m' +NC='\033[0m' + +# 检查服务是否运行 +check_service() { + echo -e "${YELLOW}检查服务状态...${NC}" + if curl -s http://localhost:5001/ > /dev/null 2>&1; then + echo -e "${GREEN}✅ 服务运行中${NC}" + return 0 + else + echo -e "${RED}❌ 服务未运行${NC}" + echo "请先启动服务: python dev.py" + return 1 + fi +} + +# 运行单元测试 +run_unit_tests() { + echo "" + echo "==================================================" + echo " 📋 单元测试" + echo "==================================================" + python3 -m pytest tests/test_unit.py -v --tb=short 2>/dev/null || python3 tests/test_unit.py +} + +# 运行 API 测试 +run_api_tests() { + echo "" + echo "==================================================" + echo " 🌐 API 集成测试" + echo "==================================================" + python3 tests/test_all.py "$@" +} + +# 运行账号测试 +run_account_tests() { + echo "" + echo "==================================================" + echo " 🔑 账号测试" + echo "==================================================" + python3 tests/test_accounts.py --all +} + +# 显示帮助 +show_help() { + echo "用法: $0 [选项]" + echo "" + echo "选项:" + echo " unit 只运行单元测试" + echo " api 只运行 API 测试" + echo " api --quick 快速 API 测试" + echo " accounts 只运行账号测试" + echo " all 运行所有测试" + echo " help 显示此帮助" + echo "" + echo "示例:" + echo " $0 unit" + echo " $0 api --quick" + echo " $0 all" +} + +# 主逻辑 +case "${1:-all}" in + unit) + run_unit_tests + ;; + api) + if check_service; then + shift + run_api_tests "$@" + fi + ;; + accounts) + run_account_tests + ;; + all) + run_unit_tests + echo "" + if check_service; then + run_api_tests --quick + fi + ;; + help|--help|-h) + show_help + ;; + *) + echo "未知选项: $1" + show_help + exit 1 + ;; +esac + +echo "" +echo "==================================================" +echo " ✨ 测试完成" +echo "==================================================" diff --git a/tests/test_accounts.py b/tests/test_accounts.py new file mode 100644 index 0000000..7721c02 --- /dev/null +++ b/tests/test_accounts.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +DS2API 账号池测试 + +测试账号登录和轮换功能 +""" +import argparse +import json +import os +import sys +import time +from dataclasses import dataclass +from typing import Optional + +# 添加项目根目录到路径 +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +@dataclass +class AccountTestResult: + email: str + login_success: bool + has_token: bool + token_preview: str + error: Optional[str] = None + + +def test_account_login(account: dict) -> AccountTestResult: + """测试单个账号登录""" + from core.deepseek import login_deepseek_via_account + from core.config import logger + + email = account.get("email", account.get("mobile", "unknown")) + print(f"\n📧 测试账号: {email}") + print("-" * 40) + + try: + login_deepseek_via_account(account) + token = account.get("token", "") + + if token: + print(f"✅ 登录成功") + print(f" Token: {token[:30]}...{token[-10:]}") + return AccountTestResult( + email=email, + login_success=True, + has_token=True, + token_preview=f"{token[:30]}...{token[-10:]}" + ) + else: + print(f"⚠️ 登录完成但无 Token") + return AccountTestResult( + email=email, + login_success=True, + has_token=False, + token_preview="" + ) + except Exception as e: + print(f"❌ 登录失败: {e}") + return AccountTestResult( + email=email, + login_success=False, + has_token=False, + token_preview="", + error=str(e) + ) + + +def test_account_pool(): + """测试整个账号池""" + from core.config import CONFIG, logger + + accounts = CONFIG.get("accounts", []) + + if not accounts: + print("⚠️ 配置中没有账号") + return + + print("\n" + "=" * 60) + print(" 🔑 DS2API 账号池测试") + print("=" * 60) + print(f"共 {len(accounts)} 个账号\n") + + results = [] + for account in accounts: + result = test_account_login(account) + results.append(result) + time.sleep(1) # 避免请求过快 + + # 打印汇总 + print("\n" + "=" * 60) + print(" 📊 测试结果汇总") + print("=" * 60) + + success_count = sum(1 for r in results if r.login_success) + token_count = sum(1 for r in results if r.has_token) + + print(f"\n总计: {len(results)} 个账号") + print(f"✅ 登录成功: {success_count}") + print(f"🔑 获取Token: {token_count}") + print(f"❌ 登录失败: {len(results) - success_count}") + + if any(not r.login_success for r in results): + print("\n失败的账号:") + for r in results: + if not r.login_success: + print(f" • {r.email}: {r.error}") + + print("\n" + "=" * 60) + + # 保存更新后的配置(如果获取了新 token) + if token_count > 0: + print("\n💾 更新配置文件中的 token...") + from core.config import save_config + save_config(CONFIG) + print("✅ 配置已保存") + + return results + + +def test_account_rotation(): + """测试账号轮换功能""" + from core.auth import choose_account, release_account, account_queue + from core.config import CONFIG + + accounts = CONFIG.get("accounts", []) + if len(accounts) < 2: + print("⚠️ 需要至少 2 个账号来测试轮换") + return + + print("\n" + "=" * 60) + print(" 🔄 账号轮换测试") + print("=" * 60) + + # 测试选择账号 + print("\n选择账号 (连续3次):") + selected = [] + for i in range(3): + account = choose_account() + if account: + email = account.get("email", account.get("mobile", "unknown")) + selected.append(email) + print(f" 第{i+1}次: {email}") + else: + print(f" 第{i+1}次: 无可用账号") + + # 释放账号 + print("\n释放账号:") + for i, email in enumerate(selected): + for acc in accounts: + if acc.get("email") == email: + release_account(acc) + print(f" 已释放: {email}") + break + + # 再次选择 + print("\n释放后再选择:") + for i in range(2): + account = choose_account() + if account: + email = account.get("email", account.get("mobile", "unknown")) + print(f" 第{i+1}次: {email}") + release_account(account) + + print("\n✅ 账号轮换功能正常") + + +def main(): + parser = argparse.ArgumentParser(description="DS2API 账号测试") + parser.add_argument("--login", action="store_true", help="测试账号登录") + parser.add_argument("--rotation", action="store_true", help="测试账号轮换") + parser.add_argument("--all", action="store_true", help="运行所有测试") + + args = parser.parse_args() + + if args.all or args.login: + test_account_pool() + + if args.all or args.rotation: + test_account_rotation() + + if not (args.all or args.login or args.rotation): + parser.print_help() + print("\n使用 --all 运行所有测试") + + +if __name__ == "__main__": + main() diff --git a/tests/test_all.py b/tests/test_all.py new file mode 100644 index 0000000..4424f55 --- /dev/null +++ b/tests/test_all.py @@ -0,0 +1,653 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +DS2API 全面自动化测试套件 + +测试覆盖: +- 配置加载和认证 +- 会话创建 +- PoW 计算 +- OpenAI 兼容 API +- Claude 兼容 API +- 流式和非流式响应 +- 错误处理 +- Token 计数 + +使用方法: + python tests/test_all.py # 运行所有测试 + python tests/test_all.py --quick # 快速测试(跳过耗时测试) + python tests/test_all.py --verbose # 详细输出 + python tests/test_all.py --endpoint URL # 指定测试端点 +""" +import argparse +import json +import os +import sys +import time +from dataclasses import dataclass +from typing import Optional +import requests + +# 添加项目根目录到路径 +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# 测试配置 +DEFAULT_ENDPOINT = "http://localhost:5001" +TEST_API_KEY = "test-api-key-001" # 配置中的 API key +TEST_TIMEOUT = 120 # 超时时间(秒) + + +@dataclass +class TestResult: + """测试结果""" + name: str + passed: bool + duration: float + message: str = "" + details: Optional[dict] = None + + +class TestRunner: + """测试运行器""" + + def __init__(self, endpoint: str, api_key: str, verbose: bool = False): + self.endpoint = endpoint.rstrip("/") + self.api_key = api_key + self.verbose = verbose + self.results: list[TestResult] = [] + + def log(self, message: str, level: str = "INFO"): + """日志输出""" + colors = { + "INFO": "\033[94m", + "SUCCESS": "\033[92m", + "WARNING": "\033[93m", + "ERROR": "\033[91m", + "RESET": "\033[0m" + } + if self.verbose or level in ("ERROR", "SUCCESS"): + print(f"{colors.get(level, '')}{message}{colors['RESET']}") + + def run_test(self, name: str, test_func): + """运行单个测试""" + print(f"\n{'='*60}") + print(f"🧪 测试: {name}") + print('='*60) + + start_time = time.time() + try: + result = test_func() + duration = time.time() - start_time + + if result.get("success", False): + self.log(f"✅ 通过 ({duration:.2f}s)", "SUCCESS") + self.results.append(TestResult( + name=name, + passed=True, + duration=duration, + message=result.get("message", ""), + details=result.get("details") + )) + else: + self.log(f"❌ 失败: {result.get('message', '未知错误')}", "ERROR") + self.results.append(TestResult( + name=name, + passed=False, + duration=duration, + message=result.get("message", ""), + details=result.get("details") + )) + except Exception as e: + duration = time.time() - start_time + self.log(f"❌ 异常: {e}", "ERROR") + self.results.append(TestResult( + name=name, + passed=False, + duration=duration, + message=str(e) + )) + + def get_headers(self, is_claude: bool = False) -> dict: + """获取请求头""" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}" + } + if is_claude: + headers["anthropic-version"] = "2024-01-01" + return headers + + # ===================================================================== + # 基础测试 + # ===================================================================== + + def test_health_check(self) -> dict: + """测试服务健康状态""" + try: + resp = requests.get(f"{self.endpoint}/", timeout=10) + if resp.status_code == 200: + return {"success": True, "message": "服务运行正常"} + return {"success": False, "message": f"状态码: {resp.status_code}"} + except requests.exceptions.ConnectionError: + return {"success": False, "message": "无法连接到服务"} + + # ===================================================================== + # OpenAI 兼容 API 测试 + # ===================================================================== + + def test_openai_models_list(self) -> dict: + """测试 OpenAI /v1/models 端点""" + resp = requests.get( + f"{self.endpoint}/v1/models", + headers=self.get_headers(), + timeout=TEST_TIMEOUT + ) + if resp.status_code != 200: + return {"success": False, "message": f"状态码: {resp.status_code}"} + + data = resp.json() + if data.get("object") != "list": + return {"success": False, "message": "响应格式错误"} + + models = [m["id"] for m in data.get("data", [])] + expected_models = ["deepseek-chat", "deepseek-reasoner", "deepseek-chat-search", "deepseek-reasoner-search"] + + for model in expected_models: + if model not in models: + return {"success": False, "message": f"缺少模型: {model}"} + + return { + "success": True, + "message": f"返回 {len(models)} 个模型", + "details": {"models": models} + } + + def test_openai_chat_non_stream(self) -> dict: + """测试 OpenAI 非流式对话""" + payload = { + "model": "deepseek-chat", + "messages": [ + {"role": "user", "content": "请用一句话回答:1+1等于多少?"} + ], + "stream": False + } + + resp = requests.post( + f"{self.endpoint}/v1/chat/completions", + headers=self.get_headers(), + json=payload, + timeout=TEST_TIMEOUT + ) + + if resp.status_code != 200: + return {"success": False, "message": f"状态码: {resp.status_code}", "details": {"response": resp.text}} + + data = resp.json() + if "error" in data: + return {"success": False, "message": data["error"]} + + content = data.get("choices", [{}])[0].get("message", {}).get("content", "") + if not content: + return {"success": False, "message": "响应内容为空"} + + return { + "success": True, + "message": f"收到 {len(content)} 字符响应", + "details": { + "content_preview": content[:100] + "..." if len(content) > 100 else content, + "usage": data.get("usage", {}) + } + } + + def test_openai_chat_stream(self) -> dict: + """测试 OpenAI 流式对话""" + payload = { + "model": "deepseek-chat", + "messages": [ + {"role": "user", "content": "说'你好'"} + ], + "stream": True + } + + resp = requests.post( + f"{self.endpoint}/v1/chat/completions", + headers=self.get_headers(), + json=payload, + stream=True, + timeout=TEST_TIMEOUT + ) + + if resp.status_code != 200: + return {"success": False, "message": f"状态码: {resp.status_code}"} + + chunks = [] + content = "" + for line in resp.iter_lines(): + if line: + line_str = line.decode("utf-8") + if line_str.startswith("data: "): + data_str = line_str[6:] + if data_str == "[DONE]": + break + try: + chunk = json.loads(data_str) + chunks.append(chunk) + delta = chunk.get("choices", [{}])[0].get("delta", {}) + if "content" in delta: + content += delta["content"] + except json.JSONDecodeError: + pass + + if not chunks: + return {"success": False, "message": "未收到任何流式数据块"} + + return { + "success": True, + "message": f"收到 {len(chunks)} 个数据块,内容: {content[:50]}", + "details": {"chunk_count": len(chunks), "content": content} + } + + def test_openai_reasoner_stream(self) -> dict: + """测试 OpenAI Reasoner 模式(思考链)""" + payload = { + "model": "deepseek-reasoner", + "messages": [ + {"role": "user", "content": "1加2等于多少?"} + ], + "stream": True + } + + resp = requests.post( + f"{self.endpoint}/v1/chat/completions", + headers=self.get_headers(), + json=payload, + stream=True, + timeout=TEST_TIMEOUT + ) + + if resp.status_code != 200: + return {"success": False, "message": f"状态码: {resp.status_code}"} + + content = "" + reasoning = "" + for line in resp.iter_lines(): + if line: + line_str = line.decode("utf-8") + if line_str.startswith("data: "): + data_str = line_str[6:] + if data_str == "[DONE]": + break + try: + chunk = json.loads(data_str) + delta = chunk.get("choices", [{}])[0].get("delta", {}) + if "content" in delta: + content += delta["content"] + if "reasoning_content" in delta: + reasoning += delta["reasoning_content"] + except json.JSONDecodeError: + pass + + return { + "success": True, + "message": f"思考: {len(reasoning)}字, 回答: {len(content)}字", + "details": { + "reasoning_preview": reasoning[:100] + "..." if len(reasoning) > 100 else reasoning, + "content": content + } + } + + def test_openai_invalid_model(self) -> dict: + """测试无效模型错误处理""" + payload = { + "model": "invalid-model-name", + "messages": [{"role": "user", "content": "test"}], + "stream": False + } + + resp = requests.post( + f"{self.endpoint}/v1/chat/completions", + headers=self.get_headers(), + json=payload, + timeout=TEST_TIMEOUT + ) + + # 应该返回 503 或 400 + if resp.status_code in (503, 400): + return {"success": True, "message": f"正确返回错误状态码 {resp.status_code}"} + + return {"success": False, "message": f"期望 503/400,实际: {resp.status_code}"} + + def test_openai_missing_auth(self) -> dict: + """测试缺少认证的错误处理""" + payload = { + "model": "deepseek-chat", + "messages": [{"role": "user", "content": "test"}] + } + + resp = requests.post( + f"{self.endpoint}/v1/chat/completions", + headers={"Content-Type": "application/json"}, # 无 Authorization + json=payload, + timeout=TEST_TIMEOUT + ) + + if resp.status_code == 401: + return {"success": True, "message": "正确返回 401 未认证"} + + return {"success": False, "message": f"期望 401,实际: {resp.status_code}"} + + # ===================================================================== + # Claude 兼容 API 测试 + # ===================================================================== + + def test_claude_models_list(self) -> dict: + """测试 Claude /anthropic/v1/models 端点""" + resp = requests.get( + f"{self.endpoint}/anthropic/v1/models", + headers=self.get_headers(is_claude=True), + timeout=TEST_TIMEOUT + ) + + if resp.status_code != 200: + return {"success": False, "message": f"状态码: {resp.status_code}"} + + data = resp.json() + models = [m["id"] for m in data.get("data", [])] + + if not models: + return {"success": False, "message": "模型列表为空"} + + return { + "success": True, + "message": f"返回 {len(models)} 个 Claude 模型", + "details": {"models": models} + } + + def test_claude_messages_non_stream(self) -> dict: + """测试 Claude 非流式消息""" + payload = { + "model": "claude-sonnet-4-20250514", + "max_tokens": 100, + "messages": [ + {"role": "user", "content": "Say 'Hello' in Chinese"} + ], + "stream": False + } + + resp = requests.post( + f"{self.endpoint}/anthropic/v1/messages", + headers=self.get_headers(is_claude=True), + json=payload, + timeout=TEST_TIMEOUT + ) + + if resp.status_code != 200: + return {"success": False, "message": f"状态码: {resp.status_code}", "details": {"response": resp.text}} + + data = resp.json() + if "error" in data: + return {"success": False, "message": str(data["error"])} + + content_blocks = data.get("content", []) + text_content = "" + for block in content_blocks: + if block.get("type") == "text": + text_content += block.get("text", "") + + if not text_content: + return {"success": False, "message": "响应内容为空"} + + return { + "success": True, + "message": f"收到 Claude 格式响应: {len(text_content)} 字符", + "details": { + "content": text_content[:100], + "stop_reason": data.get("stop_reason"), + "usage": data.get("usage", {}) + } + } + + def test_claude_messages_stream(self) -> dict: + """测试 Claude 流式消息""" + payload = { + "model": "claude-sonnet-4-20250514", + "max_tokens": 50, + "messages": [ + {"role": "user", "content": "Reply with just 'OK'"} + ], + "stream": True + } + + resp = requests.post( + f"{self.endpoint}/anthropic/v1/messages", + headers=self.get_headers(is_claude=True), + json=payload, + stream=True, + timeout=TEST_TIMEOUT + ) + + if resp.status_code != 200: + return {"success": False, "message": f"状态码: {resp.status_code}"} + + events = [] + for line in resp.iter_lines(): + if line: + line_str = line.decode("utf-8") + if line_str.startswith("data: "): + try: + event = json.loads(line_str[6:]) + events.append(event) + except json.JSONDecodeError: + pass + + event_types = [e.get("type") for e in events] + + # 检查必要的事件类型 + required_types = ["message_start", "message_stop"] + for rt in required_types: + if rt not in event_types: + return {"success": False, "message": f"缺少事件类型: {rt}"} + + return { + "success": True, + "message": f"收到 {len(events)} 个 Claude 流事件", + "details": {"event_types": event_types} + } + + def test_claude_count_tokens(self) -> dict: + """测试 Claude token 计数""" + payload = { + "model": "claude-sonnet-4-20250514", + "messages": [ + {"role": "user", "content": "Hello, how are you today?"} + ] + } + + resp = requests.post( + f"{self.endpoint}/anthropic/v1/messages/count_tokens", + headers=self.get_headers(is_claude=True), + json=payload, + timeout=TEST_TIMEOUT + ) + + if resp.status_code != 200: + return {"success": False, "message": f"状态码: {resp.status_code}"} + + data = resp.json() + input_tokens = data.get("input_tokens", 0) + + if input_tokens <= 0: + return {"success": False, "message": f"token 计数无效: {input_tokens}"} + + return { + "success": True, + "message": f"Token 计数: {input_tokens}", + "details": data + } + + # ===================================================================== + # 高级功能测试 + # ===================================================================== + + def test_multi_turn_conversation(self) -> dict: + """测试多轮对话""" + payload = { + "model": "deepseek-chat", + "messages": [ + {"role": "system", "content": "你是一个数学助手"}, + {"role": "user", "content": "我有3个苹果"}, + {"role": "assistant", "content": "好的,你有3个苹果。"}, + {"role": "user", "content": "我又买了2个,现在有多少?"} + ], + "stream": False + } + + resp = requests.post( + f"{self.endpoint}/v1/chat/completions", + headers=self.get_headers(), + json=payload, + timeout=TEST_TIMEOUT + ) + + if resp.status_code != 200: + return {"success": False, "message": f"状态码: {resp.status_code}"} + + data = resp.json() + content = data.get("choices", [{}])[0].get("message", {}).get("content", "") + + # 检查是否包含"5" + if "5" in content: + return {"success": True, "message": f"AI 正确理解上下文", "details": {"content": content[:100]}} + + return { + "success": True, # 即使没有5也算通过,因为测试的是多轮对话功能 + "message": f"多轮对话功能正常", + "details": {"content": content[:100]} + } + + def test_long_input(self) -> dict: + """测试长输入处理""" + # 生成约 1000 字的输入 + long_text = "这是一段测试文本。" * 100 + + payload = { + "model": "deepseek-chat", + "messages": [ + {"role": "user", "content": f"请总结以下内容的主题:{long_text}"} + ], + "stream": False + } + + resp = requests.post( + f"{self.endpoint}/v1/chat/completions", + headers=self.get_headers(), + json=payload, + timeout=TEST_TIMEOUT + ) + + if resp.status_code != 200: + return {"success": False, "message": f"状态码: {resp.status_code}"} + + data = resp.json() + if "error" in data: + return {"success": False, "message": str(data.get("error"))} + + return { + "success": True, + "message": f"成功处理 {len(long_text)} 字符输入", + "details": {"input_length": len(long_text)} + } + + # ===================================================================== + # 运行测试 + # ===================================================================== + + def run_all_tests(self, quick: bool = False): + """运行所有测试""" + print("\n" + "="*70) + print(" 🚀 DS2API 全面自动化测试") + print("="*70) + print(f"端点: {self.endpoint}") + print(f"API Key: {self.api_key[:10]}...") + print(f"模式: {'快速' if quick else '完整'}") + + # 基础测试 + self.run_test("服务健康检查", self.test_health_check) + + if not self.results[-1].passed: + print("\n⚠️ 服务未运行,跳过其他测试") + return + + # OpenAI API 测试 + self.run_test("OpenAI 模型列表", self.test_openai_models_list) + self.run_test("OpenAI 非流式对话", self.test_openai_chat_non_stream) + self.run_test("OpenAI 流式对话", self.test_openai_chat_stream) + self.run_test("OpenAI 无效模型处理", self.test_openai_invalid_model) + self.run_test("OpenAI 缺少认证处理", self.test_openai_missing_auth) + + if not quick: + self.run_test("OpenAI Reasoner 模式", self.test_openai_reasoner_stream) + + # Claude API 测试 + self.run_test("Claude 模型列表", self.test_claude_models_list) + self.run_test("Claude 非流式消息", self.test_claude_messages_non_stream) + self.run_test("Claude 流式消息", self.test_claude_messages_stream) + self.run_test("Claude Token 计数", self.test_claude_count_tokens) + + # 高级功能测试 + if not quick: + self.run_test("多轮对话", self.test_multi_turn_conversation) + self.run_test("长输入处理", self.test_long_input) + + # 输出测试报告 + self.print_report() + + def print_report(self): + """打印测试报告""" + print("\n" + "="*70) + print(" 📊 测试报告") + print("="*70) + + passed = sum(1 for r in self.results if r.passed) + failed = len(self.results) - passed + total_time = sum(r.duration for r in self.results) + + print(f"\n总计: {len(self.results)} 个测试") + print(f"✅ 通过: {passed}") + print(f"❌ 失败: {failed}") + print(f"⏱️ 耗时: {total_time:.2f}s") + print(f"📈 通过率: {passed/len(self.results)*100:.1f}%") + + if failed > 0: + print("\n❌ 失败的测试:") + for r in self.results: + if not r.passed: + print(f" • {r.name}: {r.message}") + + print("\n" + "="*70) + + # 返回退出码 + return 0 if failed == 0 else 1 + + +def main(): + parser = argparse.ArgumentParser(description="DS2API 自动化测试") + parser.add_argument("--endpoint", default=DEFAULT_ENDPOINT, help="API 端点") + parser.add_argument("--api-key", default=TEST_API_KEY, help="API Key") + parser.add_argument("--quick", action="store_true", help="快速测试模式") + parser.add_argument("--verbose", "-v", action="store_true", help="详细输出") + + args = parser.parse_args() + + runner = TestRunner( + endpoint=args.endpoint, + api_key=args.api_key, + verbose=args.verbose + ) + + exit_code = runner.run_all_tests(quick=args.quick) + sys.exit(exit_code) + + +if __name__ == "__main__": + main() diff --git a/tests/test_unit.py b/tests/test_unit.py new file mode 100644 index 0000000..09c3160 --- /dev/null +++ b/tests/test_unit.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +DS2API 单元测试 + +测试核心模块的功能,不依赖网络请求 +""" +import json +import os +import sys +import unittest + +# 添加项目根目录到路径 +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +class TestConfig(unittest.TestCase): + """配置模块测试""" + + def test_config_loading(self): + """测试配置加载""" + from core.config import load_config, CONFIG + + # 测试加载函数不会抛出异常 + config = load_config() + self.assertIsInstance(config, dict) + + def test_config_paths(self): + """测试配置路径""" + from core.config import WASM_PATH, CONFIG_PATH + + # 路径应该是字符串 + self.assertIsInstance(WASM_PATH, str) + self.assertIsInstance(CONFIG_PATH, str) + + +class TestMessages(unittest.TestCase): + """消息处理模块测试""" + + def test_messages_prepare_simple(self): + """测试简单消息处理""" + from core.messages import messages_prepare + + messages = [ + {"role": "user", "content": "Hello"} + ] + result = messages_prepare(messages) + self.assertIn("Hello", result) + + def test_messages_prepare_multi_turn(self): + """测试多轮对话消息处理""" + from core.messages import messages_prepare + + messages = [ + {"role": "system", "content": "You are a helper."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + {"role": "user", "content": "How are you?"} + ] + result = messages_prepare(messages) + + # 检查助手消息标签 + self.assertIn("<|Assistant|>", result) + self.assertIn("<|end▁of▁sentence|>", result) + # 检查用户消息标签 + self.assertIn("<|User|>", result) + + def test_messages_prepare_array_content(self): + """测试数组格式内容处理""" + from core.messages import messages_prepare + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "First part"}, + {"type": "text", "text": "Second part"}, + {"type": "image", "url": "http://example.com/image.png"} + ] + } + ] + result = messages_prepare(messages) + + self.assertIn("First part", result) + self.assertIn("Second part", result) + + def test_markdown_image_removal(self): + """测试 markdown 图片格式移除""" + from core.messages import messages_prepare + + messages = [ + {"role": "user", "content": "Check this ![alt](http://example.com/image.png) image"} + ] + result = messages_prepare(messages) + + # 图片格式应该被改为链接格式 + self.assertNotIn("![alt]", result) + self.assertIn("[alt]", result) + + def test_merge_consecutive_messages(self): + """测试连续相同角色消息合并""" + from core.messages import messages_prepare + + messages = [ + {"role": "user", "content": "Part 1"}, + {"role": "user", "content": "Part 2"}, + {"role": "user", "content": "Part 3"} + ] + result = messages_prepare(messages) + + self.assertIn("Part 1", result) + self.assertIn("Part 2", result) + self.assertIn("Part 3", result) + + def test_convert_claude_to_deepseek(self): + """测试 Claude 到 DeepSeek 格式转换""" + from core.messages import convert_claude_to_deepseek + + claude_request = { + "model": "claude-sonnet-4-20250514", + "messages": [{"role": "user", "content": "Hi"}], + "system": "You are helpful.", + "temperature": 0.7, + "stream": True + } + + result = convert_claude_to_deepseek(claude_request) + + # 检查模型映射 + self.assertIn("deepseek", result.get("model", "").lower()) + + # 检查 system 消息插入 + self.assertEqual(result["messages"][0]["role"], "system") + self.assertEqual(result["messages"][0]["content"], "You are helpful.") + + # 检查其他参数 + self.assertEqual(result.get("temperature"), 0.7) + self.assertEqual(result.get("stream"), True) + + +class TestPow(unittest.TestCase): + """PoW 模块测试""" + + def test_wasm_caching(self): + """测试 WASM 缓存功能""" + from core.pow import _get_cached_wasm_module, _wasm_module, _wasm_engine + from core.config import WASM_PATH + + # 首次调用 + engine1, module1 = _get_cached_wasm_module(WASM_PATH) + self.assertIsNotNone(engine1) + self.assertIsNotNone(module1) + + # 再次调用应该返回相同的实例 + engine2, module2 = _get_cached_wasm_module(WASM_PATH) + self.assertIs(engine1, engine2) + self.assertIs(module1, module2) + + def test_get_account_identifier(self): + """测试账号标识获取""" + from core.pow import get_account_identifier + + # 测试邮箱 + account1 = {"email": "test@example.com"} + self.assertEqual(get_account_identifier(account1), "test@example.com") + + # 测试手机号 + account2 = {"mobile": "13800138000"} + self.assertEqual(get_account_identifier(account2), "13800138000") + + # 邮箱优先 + account3 = {"email": "test@example.com", "mobile": "13800138000"} + self.assertEqual(get_account_identifier(account3), "test@example.com") + + +class TestSessionManager(unittest.TestCase): + """会话管理器模块测试""" + + def test_get_model_config(self): + """测试模型配置获取""" + from core.session_manager import get_model_config + + # deepseek-chat + thinking, search = get_model_config("deepseek-chat") + self.assertEqual(thinking, False) + self.assertEqual(search, False) + + # deepseek-reasoner + thinking, search = get_model_config("deepseek-reasoner") + self.assertEqual(thinking, True) + self.assertEqual(search, False) + + # deepseek-chat-search + thinking, search = get_model_config("deepseek-chat-search") + self.assertEqual(thinking, False) + self.assertEqual(search, True) + + # deepseek-reasoner-search + thinking, search = get_model_config("deepseek-reasoner-search") + self.assertEqual(thinking, True) + self.assertEqual(search, True) + + # 大小写不敏感 + thinking, search = get_model_config("DeepSeek-CHAT") + self.assertEqual(thinking, False) + self.assertEqual(search, False) + + # 无效模型 + thinking, search = get_model_config("invalid-model") + self.assertIsNone(thinking) + self.assertIsNone(search) + + +class TestAuth(unittest.TestCase): + """认证模块测试""" + + def test_auth_key_check(self): + """测试 API Key 检查""" + from core.config import CONFIG + + # 检查配置中是否有 keys + keys = CONFIG.get("keys", []) + self.assertIsInstance(keys, list) + + +class TestRegexPatterns(unittest.TestCase): + """正则表达式测试""" + + def test_markdown_image_pattern(self): + """测试 markdown 图片正则""" + from core.messages import _MARKDOWN_IMAGE_PATTERN + + text = "Check ![alt text](http://example.com/image.png) here" + match = _MARKDOWN_IMAGE_PATTERN.search(text) + + self.assertIsNotNone(match) + self.assertEqual(match.group(1), "alt text") + self.assertEqual(match.group(2), "http://example.com/image.png") + + +if __name__ == "__main__": + # 设置环境变量避免配置警告 + os.environ.setdefault("DS2API_CONFIG_PATH", + os.path.join(os.path.dirname(os.path.dirname(__file__)), "config.json")) + + unittest.main(verbosity=2) diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 0000000..424986d --- /dev/null +++ b/tools/__init__.py @@ -0,0 +1 @@ +# DS2API Tools diff --git a/tools/config_generator.py b/tools/config_generator.py new file mode 100644 index 0000000..0d88e9b --- /dev/null +++ b/tools/config_generator.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +DS2API 配置生成器 + +交互式工具,用于批量配置账号和 API Keys。 +支持导出为 JSON 和 Base64 格式,方便 Vercel 部署配置。 + +使用方法: + python tools/config_generator.py +""" +import base64 +import json +import os +import sys + +# 默认配置结构 +DEFAULT_CONFIG = {"keys": [], "accounts": []} + + +def clear_screen(): + """清屏""" + os.system("cls" if os.name == "nt" else "clear") + + +def print_header(): + """打印标题""" + print("\n" + "=" * 50) + print(" DS2API 配置生成器") + print("=" * 50) + + +def print_menu(): + """打印菜单""" + print("\n📋 请选择操作:") + print(" 1. 添加 API Key") + print(" 2. 添加账号 (Email)") + print(" 3. 添加账号 (手机号)") + print(" 4. 删除 API Key") + print(" 5. 删除账号") + print(" 6. 查看当前配置") + print(" 7. 导出 JSON (可直接用于环境变量)") + print(" 8. 导出 Base64 (推荐用于 Vercel)") + print(" 9. 从 config.json 导入") + print(" 10. 保存到 config.json") + print(" 0. 退出") + print() + + +def add_api_key(config): + """添加 API Key""" + print("\n➕ 添加 API Key") + print(" 提示:API Key 是你自定义的密钥,用于调用此 API 服务") + key = input(" 请输入 API Key: ").strip() + if key: + if key in config["keys"]: + print(" ⚠️ 该 Key 已存在") + else: + config["keys"].append(key) + print(f" ✅ 已添加 Key: {key[:8]}...") + else: + print(" ❌ 输入为空,未添加") + + +def add_account_email(config): + """添加 Email 账号""" + print("\n➕ 添加 DeepSeek 账号 (Email)") + email = input(" Email: ").strip() + password = input(" 密码: ").strip() + if email and password: + # 检查是否已存在 + for acc in config["accounts"]: + if acc.get("email") == email: + print(" ⚠️ 该账号已存在") + return + config["accounts"].append({"email": email, "password": password, "token": ""}) + print(f" ✅ 已添加账号: {email}") + else: + print(" ❌ 输入不完整,未添加") + + +def add_account_mobile(config): + """添加手机号账号""" + print("\n➕ 添加 DeepSeek 账号 (手机号)") + mobile = input(" 手机号: ").strip() + password = input(" 密码: ").strip() + if mobile and password: + # 检查是否已存在 + for acc in config["accounts"]: + if acc.get("mobile") == mobile: + print(" ⚠️ 该账号已存在") + return + config["accounts"].append({"mobile": mobile, "password": password, "token": ""}) + print(f" ✅ 已添加账号: {mobile}") + else: + print(" ❌ 输入不完整,未添加") + + +def delete_api_key(config): + """删除 API Key""" + if not config["keys"]: + print("\n ⚠️ 当前没有 API Key") + return + print("\n🗑️ 删除 API Key") + for i, key in enumerate(config["keys"], 1): + print(f" {i}. {key[:8]}...") + try: + idx = int(input(" 选择要删除的序号 (0 取消): ")) + if 0 < idx <= len(config["keys"]): + removed = config["keys"].pop(idx - 1) + print(f" ✅ 已删除: {removed[:8]}...") + elif idx != 0: + print(" ❌ 无效选择") + except ValueError: + print(" ❌ 无效输入") + + +def delete_account(config): + """删除账号""" + if not config["accounts"]: + print("\n ⚠️ 当前没有账号") + return + print("\n🗑️ 删除账号") + for i, acc in enumerate(config["accounts"], 1): + identifier = acc.get("email") or acc.get("mobile", "未知") + print(f" {i}. {identifier}") + try: + idx = int(input(" 选择要删除的序号 (0 取消): ")) + if 0 < idx <= len(config["accounts"]): + removed = config["accounts"].pop(idx - 1) + identifier = removed.get("email") or removed.get("mobile", "未知") + print(f" ✅ 已删除: {identifier}") + elif idx != 0: + print(" ❌ 无效选择") + except ValueError: + print(" ❌ 无效输入") + + +def view_config(config): + """查看当前配置""" + print("\n📄 当前配置") + print("-" * 40) + print(f" API Keys ({len(config['keys'])}个):") + for key in config["keys"]: + print(f" • {key[:8]}...") + print(f"\n 账号 ({len(config['accounts'])}个):") + for acc in config["accounts"]: + identifier = acc.get("email") or acc.get("mobile", "未知") + token_status = "✓ 有Token" if acc.get("token") else "✗ 无Token" + print(f" • {identifier} [{token_status}]") + print("-" * 40) + + +def export_json(config): + """导出 JSON""" + json_str = json.dumps(config, ensure_ascii=False, separators=(",", ":")) + print("\n📤 JSON 格式 (可直接设置为 DS2API_CONFIG_JSON 环境变量):") + print("-" * 50) + print(json_str) + print("-" * 50) + + # 复制到剪贴板(如果可用) + try: + import subprocess + process = subprocess.Popen(["pbcopy"], stdin=subprocess.PIPE) + process.communicate(json_str.encode("utf-8")) + print(" ✅ 已复制到剪贴板 (macOS)") + except Exception: + pass + + +def export_base64(config): + """导出 Base64""" + json_str = json.dumps(config, ensure_ascii=False, separators=(",", ":")) + b64_str = base64.b64encode(json_str.encode("utf-8")).decode("utf-8") + print("\n📤 Base64 格式 (推荐用于 Vercel 环境变量):") + print("-" * 50) + print(b64_str) + print("-" * 50) + + # 复制到剪贴板(如果可用) + try: + import subprocess + process = subprocess.Popen(["pbcopy"], stdin=subprocess.PIPE) + process.communicate(b64_str.encode("utf-8")) + print(" ✅ 已复制到剪贴板 (macOS)") + except Exception: + pass + + +def import_from_file(config): + """从 config.json 导入""" + # 尝试多个可能的路径 + paths = [ + "config.json", + "../config.json", + os.path.join(os.path.dirname(__file__), "..", "config.json"), + ] + + for path in paths: + if os.path.exists(path): + try: + with open(path, "r", encoding="utf-8") as f: + loaded = json.load(f) + config["keys"] = loaded.get("keys", []) + config["accounts"] = loaded.get("accounts", []) + print(f"\n ✅ 已从 {path} 导入配置") + print(f" Keys: {len(config['keys'])}个, 账号: {len(config['accounts'])}个") + return + except Exception as e: + print(f"\n ❌ 导入失败: {e}") + return + + print("\n ⚠️ 未找到 config.json 文件") + + +def save_to_file(config): + """保存到 config.json""" + # 确定保存路径 + path = "config.json" + if not os.path.exists(path): + parent_path = os.path.join(os.path.dirname(__file__), "..", "config.json") + if os.path.exists(os.path.dirname(parent_path)): + path = parent_path + + try: + with open(path, "w", encoding="utf-8") as f: + json.dump(config, f, ensure_ascii=False, indent=2) + print(f"\n ✅ 已保存到 {path}") + except Exception as e: + print(f"\n ❌ 保存失败: {e}") + + +def main(): + """主函数""" + config = DEFAULT_CONFIG.copy() + config["keys"] = [] + config["accounts"] = [] + + print_header() + print("\n💡 提示:此工具帮助你生成 DS2API 配置") + print(" 生成的配置可用于本地 config.json 或 Vercel 环境变量") + + while True: + print_menu() + choice = input("请输入选项: ").strip() + + if choice == "1": + add_api_key(config) + elif choice == "2": + add_account_email(config) + elif choice == "3": + add_account_mobile(config) + elif choice == "4": + delete_api_key(config) + elif choice == "5": + delete_account(config) + elif choice == "6": + view_config(config) + elif choice == "7": + export_json(config) + elif choice == "8": + export_base64(config) + elif choice == "9": + import_from_file(config) + elif choice == "10": + save_to_file(config) + elif choice == "0": + print("\n👋 再见!\n") + break + else: + print("\n ❌ 无效选项,请重新选择") + + input("\n按 Enter 继续...") + + +if __name__ == "__main__": + main() diff --git a/webui/index.html b/webui/index.html new file mode 100644 index 0000000..5bf5917 --- /dev/null +++ b/webui/index.html @@ -0,0 +1,12 @@ + + + + + + DS2API Admin + + +
+ + + diff --git a/webui/package.json b/webui/package.json new file mode 100644 index 0000000..697401d --- /dev/null +++ b/webui/package.json @@ -0,0 +1,19 @@ +{ + "name": "ds2api-admin", + "private": true, + "version": "1.0.0", + "type": "module", + "scripts": { + "dev": "vite", + "build": "vite build", + "preview": "vite preview" + }, + "dependencies": { + "react": "^18.2.0", + "react-dom": "^18.2.0" + }, + "devDependencies": { + "@vitejs/plugin-react": "^4.2.1", + "vite": "^5.0.0" + } +} \ No newline at end of file diff --git a/webui/src/App.jsx b/webui/src/App.jsx new file mode 100644 index 0000000..ef94ee1 --- /dev/null +++ b/webui/src/App.jsx @@ -0,0 +1,106 @@ +import { useState, useEffect } from 'react' +import AccountManager from './components/AccountManager' +import ApiTester from './components/ApiTester' +import BatchImport from './components/BatchImport' +import VercelSync from './components/VercelSync' + +const TABS = [ + { id: 'accounts', label: '🔑 账号管理' }, + { id: 'test', label: '🧪 API 测试' }, + { id: 'import', label: '📦 批量导入' }, + { id: 'vercel', label: '☁️ Vercel 同步' }, +] + +export default function App() { + const [activeTab, setActiveTab] = useState('accounts') + const [config, setConfig] = useState({ keys: [], accounts: [] }) + const [loading, setLoading] = useState(true) + const [message, setMessage] = useState(null) + + const fetchConfig = async () => { + try { + setLoading(true) + const res = await fetch('/admin/config') + if (res.ok) { + const data = await res.json() + setConfig(data) + } + } catch (e) { + console.error('获取配置失败:', e) + } finally { + setLoading(false) + } + } + + useEffect(() => { + fetchConfig() + }, []) + + const showMessage = (type, text) => { + setMessage({ type, text }) + setTimeout(() => setMessage(null), 5000) + } + + const renderTab = () => { + switch (activeTab) { + case 'accounts': + return + case 'test': + return + case 'import': + return + case 'vercel': + return + default: + return null + } + } + + return ( +
+
+

DS2API Admin

+

账号管理 · API 测试 · Vercel 部署

+
+ + {message && ( +
+ {message.text} +
+ )} + +
+
+
{config.keys?.length || 0}
+
API Keys
+
+
+
{config.accounts?.length || 0}
+
账号
+
+
+ +
+ {TABS.map(tab => ( + + ))} +
+ + {loading ? ( +
+
+ 加载中... +
+
+ ) : ( + renderTab() + )} +
+ ) +} diff --git a/webui/src/components/AccountManager.jsx b/webui/src/components/AccountManager.jsx new file mode 100644 index 0000000..f9c517e --- /dev/null +++ b/webui/src/components/AccountManager.jsx @@ -0,0 +1,219 @@ +import { useState } from 'react' + +export default function AccountManager({ config, onRefresh, onMessage }) { + const [showAddKey, setShowAddKey] = useState(false) + const [showAddAccount, setShowAddAccount] = useState(false) + const [newKey, setNewKey] = useState('') + const [newAccount, setNewAccount] = useState({ email: '', mobile: '', password: '' }) + const [loading, setLoading] = useState(false) + + const addKey = async () => { + if (!newKey.trim()) return + setLoading(true) + try { + const res = await fetch('/admin/keys', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ key: newKey.trim() }), + }) + if (res.ok) { + onMessage('success', 'API Key 添加成功') + setNewKey('') + setShowAddKey(false) + onRefresh() + } else { + const data = await res.json() + onMessage('error', data.detail || '添加失败') + } + } catch (e) { + onMessage('error', '网络错误') + } finally { + setLoading(false) + } + } + + const deleteKey = async (key) => { + if (!confirm('确定删除此 API Key?')) return + try { + const res = await fetch(`/admin/keys/${encodeURIComponent(key)}`, { method: 'DELETE' }) + if (res.ok) { + onMessage('success', '删除成功') + onRefresh() + } else { + onMessage('error', '删除失败') + } + } catch (e) { + onMessage('error', '网络错误') + } + } + + const addAccount = async () => { + if (!newAccount.password || (!newAccount.email && !newAccount.mobile)) { + onMessage('error', '请填写密码和邮箱/手机号') + return + } + setLoading(true) + try { + const res = await fetch('/admin/accounts', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(newAccount), + }) + if (res.ok) { + onMessage('success', '账号添加成功') + setNewAccount({ email: '', mobile: '', password: '' }) + setShowAddAccount(false) + onRefresh() + } else { + const data = await res.json() + onMessage('error', data.detail || '添加失败') + } + } catch (e) { + onMessage('error', '网络错误') + } finally { + setLoading(false) + } + } + + const deleteAccount = async (id) => { + if (!confirm('确定删除此账号?')) return + try { + const res = await fetch(`/admin/accounts/${encodeURIComponent(id)}`, { method: 'DELETE' }) + if (res.ok) { + onMessage('success', '删除成功') + onRefresh() + } else { + onMessage('error', '删除失败') + } + } catch (e) { + onMessage('error', '网络错误') + } + } + + return ( +
+ {/* API Keys */} +
+
+ 🔑 API Keys + +
+ + {config.keys?.length > 0 ? ( +
+ {config.keys.map((key, i) => ( +
+ {key.slice(0, 16)}**** + +
+ ))} +
+ ) : ( +
暂无 API Key
+ )} +
+ + {/* Accounts */} +
+
+ 👤 DeepSeek 账号 + +
+ + {config.accounts?.length > 0 ? ( +
+ {config.accounts.map((acc, i) => ( +
+
+ {acc.email || acc.mobile} + + {acc.has_token ? '已登录' : '未登录'} + +
+ +
+ ))} +
+ ) : ( +
暂无账号
+ )} +
+ + {/* Add Key Modal */} + {showAddKey && ( +
setShowAddKey(false)}> +
e.stopPropagation()}> +
+ 添加 API Key + +
+
+ + setNewKey(e.target.value)} + /> +
+
+ + +
+
+
+ )} + + {/* Add Account Modal */} + {showAddAccount && ( +
setShowAddAccount(false)}> +
e.stopPropagation()}> +
+ 添加 DeepSeek 账号 + +
+
+ + setNewAccount({ ...newAccount, email: e.target.value })} + /> +
+
+ + setNewAccount({ ...newAccount, mobile: e.target.value })} + /> +
+
+ + setNewAccount({ ...newAccount, password: e.target.value })} + /> +
+
+ + +
+
+
+ )} +
+ ) +} diff --git a/webui/src/components/ApiTester.jsx b/webui/src/components/ApiTester.jsx new file mode 100644 index 0000000..a2daa08 --- /dev/null +++ b/webui/src/components/ApiTester.jsx @@ -0,0 +1,162 @@ +import { useState } from 'react' + +const MODELS = [ + { id: 'deepseek-chat', name: 'DeepSeek V3 (Chat)' }, + { id: 'deepseek-reasoner', name: 'DeepSeek R1 (Reasoner)' }, + { id: 'deepseek-chat-search', name: 'DeepSeek V3 + 搜索' }, + { id: 'deepseek-reasoner-search', name: 'DeepSeek R1 + 搜索' }, +] + +export default function ApiTester({ config, onMessage }) { + const [model, setModel] = useState('deepseek-chat') + const [message, setMessage] = useState('你好,请用一句话介绍你自己。') + const [apiKey, setApiKey] = useState('') + const [response, setResponse] = useState(null) + const [loading, setLoading] = useState(false) + + const testApi = async () => { + setLoading(true) + setResponse(null) + try { + const res = await fetch('/admin/test', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + model, + message, + api_key: apiKey || (config.keys?.[0] || ''), + }), + }) + const data = await res.json() + setResponse(data) + if (data.success) { + onMessage('success', 'API 调用成功') + } else { + onMessage('error', data.error || 'API 调用失败') + } + } catch (e) { + onMessage('error', '网络错误') + setResponse({ error: e.message }) + } finally { + setLoading(false) + } + } + + const directTest = async () => { + setLoading(true) + setResponse(null) + try { + const key = apiKey || (config.keys?.[0] || '') + if (!key) { + onMessage('error', '请提供 API Key') + setLoading(false) + return + } + + const res = await fetch('/v1/chat/completions', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${key}`, + }, + body: JSON.stringify({ + model, + messages: [{ role: 'user', content: message }], + stream: false, + }), + }) + const data = await res.json() + setResponse({ + success: res.ok, + status_code: res.status, + response: data, + }) + if (res.ok) { + onMessage('success', 'API 调用成功') + } else { + onMessage('error', data.error || 'API 调用失败') + } + } catch (e) { + onMessage('error', '网络错误') + setResponse({ error: e.message }) + } finally { + setLoading(false) + } + } + + return ( +
+
+
🧪 API 测试
+ +
+ + +
+ +
+ + setApiKey(e.target.value)} + /> +
+ +
+ +