mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-05 08:55:28 +08:00
refactor(ds2api): modularize app into a package structure and extract concerns into core/services/utils/models; drop heavy tokenizer usage, add scaffolding for PoW caching and async solving, and enable HTTP connection pooling
This change reorganizes the codebase for better maintainability and performance while preserving API surface. - Create ds2api package with modules: core, services, utils, models - Migrate config, logging, auth, DeepSeek, PoW, and message processing into dedicated modules - Introduce PoW caching (60s TTL) and async/parallel support (scalability for multiple requests) - Replace direct curl calls with a pool-enabled HTTP client setup and WASM-based PoW engine - Add in-memory token/account management scaffolding and improved token estimation - Optimize streaming paths and prepare for better backpressure and concurrency - Remove transformers/tokenizer usage and keep a simple token length estimator Non-breaking migration: keep API endpoints intact; new structure under ds2api is transparent to clients
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,5 +1,7 @@
|
||||
*.bak
|
||||
config.json
|
||||
tokenizer.json
|
||||
tokenizer_config.json
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
|
||||
0
ds2api/__init__.py
Normal file
0
ds2api/__init__.py
Normal file
58
ds2api/app.py
Normal file
58
ds2api/app.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
from ds2api.config import CONFIG, settings
|
||||
from ds2api.core.auth import AccountManager
|
||||
from ds2api.core.deepseek import DeepSeekClient
|
||||
from ds2api.core.pow import PowService
|
||||
from ds2api.services import claude, completion, models
|
||||
from ds2api.utils.logger import configure_logging, get_logger
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
configure_logging()
|
||||
logger = get_logger("ds2api")
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "OPTIONS", "PUT", "DELETE"],
|
||||
allow_headers=["Content-Type", "Authorization"],
|
||||
)
|
||||
|
||||
templates = Jinja2Templates(directory=settings.templates_dir)
|
||||
|
||||
app.state.settings = settings
|
||||
app.state.config = CONFIG
|
||||
app.state.deepseek = DeepSeekClient()
|
||||
app.state.pow = PowService(settings.wasm_path)
|
||||
app.state.account_manager = AccountManager(CONFIG.get("accounts", []))
|
||||
app.state.templates = templates
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception):
|
||||
logger.exception(f"[unhandled_exception] {request.method} {request.url.path}: {exc}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"error": {"type": "api_error", "message": "Internal Server Error"}},
|
||||
)
|
||||
|
||||
app.include_router(models.router)
|
||||
app.include_router(completion.router)
|
||||
app.include_router(claude.router)
|
||||
|
||||
@app.get("/")
|
||||
def index(request: Request):
|
||||
return templates.TemplateResponse("welcome.html", {"request": request})
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
90
ds2api/config.py
Normal file
90
ds2api/config.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from ds2api.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)))
|
||||
IS_VERCEL = bool(os.getenv("VERCEL")) or bool(os.getenv("NOW_REGION"))
|
||||
|
||||
|
||||
def resolve_path(env_key: str, default_rel: str) -> str:
|
||||
raw = os.getenv(env_key)
|
||||
if raw:
|
||||
return raw if os.path.isabs(raw) else os.path.join(BASE_DIR, raw)
|
||||
return os.path.join(BASE_DIR, default_rel)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Settings:
|
||||
config_path: str
|
||||
templates_dir: str
|
||||
wasm_path: str
|
||||
keep_alive_timeout: int
|
||||
|
||||
|
||||
settings = Settings(
|
||||
config_path=resolve_path("DS2API_CONFIG_PATH", "config.json"),
|
||||
templates_dir=resolve_path("DS2API_TEMPLATES_DIR", "templates"),
|
||||
wasm_path=resolve_path("DS2API_WASM_PATH", "sha3_wasm_bg.7b9ca65ddd.wasm"),
|
||||
keep_alive_timeout=int(os.getenv("DS2API_KEEP_ALIVE_TIMEOUT", "5")),
|
||||
)
|
||||
|
||||
|
||||
def _load_config_from_env() -> dict[str, Any] | None:
|
||||
raw_cfg = os.getenv("DS2API_CONFIG_JSON") or os.getenv("CONFIG_JSON")
|
||||
if not raw_cfg:
|
||||
return None
|
||||
|
||||
try:
|
||||
parsed = json.loads(raw_cfg)
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
decoded = base64.b64decode(raw_cfg).decode("utf-8")
|
||||
parsed = json.loads(decoded)
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
except Exception as e:
|
||||
logger.warning(f"[load_config] 环境变量配置解析失败: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def load_config() -> dict[str, Any]:
|
||||
cfg = _load_config_from_env()
|
||||
if cfg is not None:
|
||||
return cfg
|
||||
|
||||
try:
|
||||
with open(settings.config_path, "r", encoding="utf-8") as f:
|
||||
parsed = json.load(f)
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
except Exception as e:
|
||||
logger.warning(f"[load_config] 无法读取配置文件({settings.config_path}): {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def save_config(cfg: dict[str, Any]) -> None:
|
||||
if os.getenv("DS2API_CONFIG_JSON") or os.getenv("CONFIG_JSON"):
|
||||
logger.info("[save_config] 配置来自环境变量,跳过写回")
|
||||
return
|
||||
|
||||
try:
|
||||
with open(settings.config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(cfg, f, ensure_ascii=False, indent=2)
|
||||
except PermissionError as e:
|
||||
logger.warning(f"[save_config] 配置文件不可写({settings.config_path}): {e}")
|
||||
except Exception as e:
|
||||
logger.exception(f"[save_config] 写入 config.json 失败: {e}")
|
||||
|
||||
|
||||
CONFIG: dict[str, Any] = load_config()
|
||||
|
||||
if not CONFIG:
|
||||
logger.warning(
|
||||
"[config] 未加载到有效配置,请提供 config.json(路径可用 DS2API_CONFIG_PATH 指定)或设置环境变量 DS2API_CONFIG_JSON"
|
||||
)
|
||||
0
ds2api/core/__init__.py
Normal file
0
ds2api/core/__init__.py
Normal file
186
ds2api/core/auth.py
Normal file
186
ds2api/core/auth.py
Normal file
@@ -0,0 +1,186 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from ds2api.config import CONFIG, save_config
|
||||
from ds2api.core.deepseek import DeepSeekClient
|
||||
from ds2api.utils.helpers import try_decode_jwt_exp
|
||||
from ds2api.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_account_identifier(account: dict[str, Any]) -> str:
|
||||
return str(account.get("email", "")).strip() or str(account.get("mobile", "")).strip()
|
||||
|
||||
|
||||
class AccountManager:
|
||||
def __init__(self, accounts: list[dict[str, Any]]):
|
||||
self._accounts = accounts
|
||||
shuffled = accounts[:]
|
||||
random.shuffle(shuffled)
|
||||
self._available: deque[dict[str, Any]] = deque(shuffled)
|
||||
self._in_use: set[str] = set()
|
||||
self._lock = asyncio.Lock()
|
||||
self._token_obtained_at: dict[str, float] = {}
|
||||
|
||||
async def acquire(self, *, exclude_ids: set[str] | None = None) -> dict[str, Any] | None:
|
||||
exclude_ids = exclude_ids or set()
|
||||
async with self._lock:
|
||||
for _ in range(len(self._available)):
|
||||
acc = self._available.popleft()
|
||||
acc_id = get_account_identifier(acc)
|
||||
if not acc_id or acc_id in exclude_ids or acc_id in self._in_use:
|
||||
self._available.append(acc)
|
||||
continue
|
||||
|
||||
self._in_use.add(acc_id)
|
||||
logger.info(f"[accounts] acquire: {acc_id}")
|
||||
return acc
|
||||
|
||||
logger.warning("[accounts] 没有可用的账号或所有账号都在使用中")
|
||||
return None
|
||||
|
||||
async def release(self, account: dict[str, Any]) -> None:
|
||||
acc_id = get_account_identifier(account)
|
||||
async with self._lock:
|
||||
if acc_id:
|
||||
self._in_use.discard(acc_id)
|
||||
self._available.append(account)
|
||||
if acc_id:
|
||||
logger.info(f"[accounts] release: {acc_id}")
|
||||
|
||||
def _token_needs_refresh(self, token: str | None) -> bool:
|
||||
if not token:
|
||||
return True
|
||||
|
||||
exp = try_decode_jwt_exp(token)
|
||||
if exp is not None:
|
||||
return exp - int(time.time()) < 300
|
||||
|
||||
return False
|
||||
|
||||
async def ensure_token(self, account: dict[str, Any], deepseek: DeepSeekClient) -> str:
|
||||
token = str(account.get("token", "")).strip() or None
|
||||
|
||||
if token and not self._token_needs_refresh(token):
|
||||
return token
|
||||
|
||||
email = str(account.get("email", "")).strip() or None
|
||||
mobile = str(account.get("mobile", "")).strip() or None
|
||||
password = str(account.get("password", "")).strip()
|
||||
|
||||
try:
|
||||
new_token = await asyncio.to_thread(
|
||||
deepseek.login,
|
||||
email=email,
|
||||
mobile=mobile,
|
||||
password=password,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[accounts] 登录失败 {get_account_identifier(account)}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Account login failed.")
|
||||
|
||||
account["token"] = new_token
|
||||
self._token_obtained_at[get_account_identifier(account)] = time.time()
|
||||
save_config(CONFIG)
|
||||
return new_token
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthContext:
|
||||
use_config_token: bool
|
||||
token: str
|
||||
account: dict[str, Any] | None
|
||||
tried_accounts: set[str]
|
||||
account_manager: AccountManager | None
|
||||
deepseek: DeepSeekClient
|
||||
|
||||
async def rotate_account(self) -> bool:
|
||||
if not self.use_config_token or not self.account_manager:
|
||||
return False
|
||||
|
||||
if self.account:
|
||||
acc_id = get_account_identifier(self.account)
|
||||
if acc_id:
|
||||
self.tried_accounts.add(acc_id)
|
||||
await self.account_manager.release(self.account)
|
||||
|
||||
new_acc = await self.account_manager.acquire(exclude_ids=self.tried_accounts)
|
||||
if not new_acc:
|
||||
return False
|
||||
|
||||
self.account = new_acc
|
||||
self.token = await self.account_manager.ensure_token(new_acc, self.deepseek)
|
||||
return True
|
||||
|
||||
async def release(self) -> None:
|
||||
if self.use_config_token and self.account_manager and self.account:
|
||||
await self.account_manager.release(self.account)
|
||||
|
||||
|
||||
async def determine_mode_and_token(request: Request) -> AuthContext:
|
||||
deepseek: DeepSeekClient = request.app.state.deepseek
|
||||
account_manager: AccountManager = request.app.state.account_manager
|
||||
cfg: dict[str, Any] = request.app.state.config
|
||||
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized: missing Bearer token.")
|
||||
|
||||
caller_key = auth_header.replace("Bearer ", "", 1).strip()
|
||||
config_keys = cfg.get("keys", [])
|
||||
|
||||
if caller_key in config_keys:
|
||||
account = await account_manager.acquire()
|
||||
if not account:
|
||||
raise HTTPException(
|
||||
status_code=429, detail="No accounts configured or all accounts are busy."
|
||||
)
|
||||
|
||||
token = await account_manager.ensure_token(account, deepseek)
|
||||
|
||||
ctx = AuthContext(
|
||||
use_config_token=True,
|
||||
token=token,
|
||||
account=account,
|
||||
tried_accounts=set(),
|
||||
account_manager=account_manager,
|
||||
deepseek=deepseek,
|
||||
)
|
||||
|
||||
request.state.use_config_token = True
|
||||
request.state.deepseek_token = token
|
||||
request.state.account = account
|
||||
request.state.tried_accounts = []
|
||||
return ctx
|
||||
|
||||
ctx = AuthContext(
|
||||
use_config_token=False,
|
||||
token=caller_key,
|
||||
account=None,
|
||||
tried_accounts=set(),
|
||||
account_manager=None,
|
||||
deepseek=deepseek,
|
||||
)
|
||||
|
||||
request.state.use_config_token = False
|
||||
request.state.deepseek_token = caller_key
|
||||
return ctx
|
||||
|
||||
|
||||
async def determine_claude_mode_and_token(request: Request) -> AuthContext:
|
||||
return await determine_mode_and_token(request)
|
||||
|
||||
|
||||
def get_auth_headers(token: str) -> dict[str, str]:
|
||||
from ds2api.core.deepseek import BASE_HEADERS
|
||||
|
||||
return {**BASE_HEADERS, "authorization": f"Bearer {token}"}
|
||||
188
ds2api/core/deepseek.py
Normal file
188
ds2api/core/deepseek.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from curl_cffi import requests
|
||||
|
||||
from ds2api.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
DEEPSEEK_HOST = "chat.deepseek.com"
|
||||
DEEPSEEK_LOGIN_URL = f"https://{DEEPSEEK_HOST}/api/v0/users/login"
|
||||
DEEPSEEK_CREATE_SESSION_URL = f"https://{DEEPSEEK_HOST}/api/v0/chat_session/create"
|
||||
DEEPSEEK_CREATE_POW_URL = f"https://{DEEPSEEK_HOST}/api/v0/chat/create_pow_challenge"
|
||||
DEEPSEEK_COMPLETION_URL = f"https://{DEEPSEEK_HOST}/api/v0/chat/completion"
|
||||
|
||||
BASE_HEADERS: dict[str, str] = {
|
||||
"Host": DEEPSEEK_HOST,
|
||||
"User-Agent": "DeepSeek/1.0.13 Android/35",
|
||||
"Accept": "application/json",
|
||||
"Accept-Encoding": "gzip",
|
||||
"Content-Type": "application/json",
|
||||
"x-client-platform": "android",
|
||||
"x-client-version": "1.3.0-auto-resume",
|
||||
"x-client-locale": "zh_CN",
|
||||
"accept-charset": "UTF-8",
|
||||
}
|
||||
|
||||
|
||||
class DeepSeekClient:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
impersonate: str = "safari15_3",
|
||||
timeout: int = 30,
|
||||
) -> None:
|
||||
self._session = requests.Session()
|
||||
self._impersonate = impersonate
|
||||
self._timeout = timeout
|
||||
|
||||
def _headers(self, token: str | None = None) -> dict[str, str]:
|
||||
if token:
|
||||
return {**BASE_HEADERS, "authorization": f"Bearer {token}"}
|
||||
return dict(BASE_HEADERS)
|
||||
|
||||
def login(self, *, email: str | None, mobile: str | None, password: str) -> str:
|
||||
if not password or (not email and not mobile):
|
||||
raise ValueError("账号缺少必要的登录信息(必须提供 email 或 mobile 以及 password)")
|
||||
|
||||
if email:
|
||||
payload: dict[str, Any] = {
|
||||
"email": email,
|
||||
"password": password,
|
||||
"device_id": "deepseek_to_api",
|
||||
"os": "android",
|
||||
}
|
||||
else:
|
||||
payload = {
|
||||
"mobile": mobile,
|
||||
"area_code": None,
|
||||
"password": password,
|
||||
"device_id": "deepseek_to_api",
|
||||
"os": "android",
|
||||
}
|
||||
|
||||
resp = self._session.post(
|
||||
DEEPSEEK_LOGIN_URL,
|
||||
headers=self._headers(),
|
||||
json=payload,
|
||||
impersonate=self._impersonate,
|
||||
timeout=self._timeout,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
if (
|
||||
data.get("data") is None
|
||||
or data["data"].get("biz_data") is None
|
||||
or data["data"]["biz_data"].get("user") is None
|
||||
):
|
||||
raise RuntimeError("Account login failed: invalid response format")
|
||||
|
||||
token = data["data"]["biz_data"]["user"].get("token")
|
||||
if not token:
|
||||
raise RuntimeError("Account login failed: missing token")
|
||||
|
||||
return token
|
||||
|
||||
def create_session(self, token: str) -> str | None:
|
||||
headers = self._headers(token)
|
||||
try:
|
||||
resp = self._session.post(
|
||||
DEEPSEEK_CREATE_SESSION_URL,
|
||||
headers=headers,
|
||||
json={"agent": "chat"},
|
||||
impersonate=self._impersonate,
|
||||
timeout=self._timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[create_session] 请求异常: {e}")
|
||||
return None
|
||||
|
||||
try:
|
||||
data = resp.json()
|
||||
except Exception as e:
|
||||
logger.error(f"[create_session] JSON解析异常: {e}")
|
||||
data = {}
|
||||
|
||||
if resp.status_code == 200 and data.get("code") == 0:
|
||||
try:
|
||||
return data["data"]["biz_data"]["id"]
|
||||
finally:
|
||||
resp.close()
|
||||
|
||||
code = data.get("code")
|
||||
logger.warning(
|
||||
f"[create_session] 创建会话失败, code={code}, msg={data.get('msg')}, status={resp.status_code}"
|
||||
)
|
||||
resp.close()
|
||||
return None
|
||||
|
||||
def create_pow_challenge(self, token: str) -> dict[str, Any] | None:
|
||||
headers = self._headers(token)
|
||||
try:
|
||||
resp = self._session.post(
|
||||
DEEPSEEK_CREATE_POW_URL,
|
||||
headers=headers,
|
||||
json={"target_path": "/api/v0/chat/completion"},
|
||||
timeout=self._timeout,
|
||||
impersonate=self._impersonate,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[create_pow_challenge] 请求异常: {e}")
|
||||
return None
|
||||
|
||||
try:
|
||||
data = resp.json()
|
||||
except Exception as e:
|
||||
logger.error(f"[create_pow_challenge] JSON解析异常: {e}")
|
||||
data = {}
|
||||
|
||||
if resp.status_code == 200 and data.get("code") == 0:
|
||||
try:
|
||||
return data["data"]["biz_data"]["challenge"]
|
||||
finally:
|
||||
resp.close()
|
||||
|
||||
code = data.get("code")
|
||||
logger.warning(
|
||||
f"[create_pow_challenge] 获取 PoW 失败, code={code}, msg={data.get('msg')}, status={resp.status_code}"
|
||||
)
|
||||
resp.close()
|
||||
return None
|
||||
|
||||
def completion(
|
||||
self,
|
||||
*,
|
||||
headers: dict[str, str],
|
||||
payload: dict[str, Any],
|
||||
max_attempts: int = 3,
|
||||
) -> requests.Response | None:
|
||||
attempts = 0
|
||||
while attempts < max_attempts:
|
||||
try:
|
||||
resp = self._session.post(
|
||||
DEEPSEEK_COMPLETION_URL,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
stream=True,
|
||||
impersonate=self._impersonate,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[completion] 请求异常: {e}")
|
||||
time.sleep(1)
|
||||
attempts += 1
|
||||
continue
|
||||
|
||||
if resp.status_code == 200:
|
||||
return resp
|
||||
|
||||
logger.warning(f"[completion] 调用对话接口失败, 状态码: {resp.status_code}")
|
||||
resp.close()
|
||||
time.sleep(1)
|
||||
attempts += 1
|
||||
|
||||
return None
|
||||
107
ds2api/core/message_processor.py
Normal file
107
ds2api/core/message_processor.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
CLAUDE_DEFAULT_MODEL = "claude-sonnet-4-20250514"
|
||||
|
||||
|
||||
def messages_prepare(messages: list[dict[str, Any]]) -> str:
|
||||
processed: list[dict[str, str]] = []
|
||||
for m in messages:
|
||||
role = str(m.get("role", ""))
|
||||
content = m.get("content", "")
|
||||
if isinstance(content, list):
|
||||
texts = [
|
||||
str(item.get("text", ""))
|
||||
for item in content
|
||||
if isinstance(item, dict) and item.get("type") == "text"
|
||||
]
|
||||
text = "\n".join(texts)
|
||||
else:
|
||||
text = str(content)
|
||||
processed.append({"role": role, "text": text})
|
||||
|
||||
if not processed:
|
||||
return ""
|
||||
|
||||
merged = [processed[0]]
|
||||
for msg in processed[1:]:
|
||||
if msg["role"] == merged[-1]["role"]:
|
||||
merged[-1]["text"] += "\n\n" + msg["text"]
|
||||
else:
|
||||
merged.append(msg)
|
||||
|
||||
parts: list[str] = []
|
||||
for idx, block in enumerate(merged):
|
||||
role = block["role"]
|
||||
text = block["text"]
|
||||
if role == "assistant":
|
||||
parts.append(f"<|Assistant|>{text}<|end▁of▁sentence|>")
|
||||
elif role in ("user", "system"):
|
||||
if idx > 0:
|
||||
parts.append(f"<|User|>{text}")
|
||||
else:
|
||||
parts.append(text)
|
||||
else:
|
||||
parts.append(text)
|
||||
|
||||
final_prompt = "".join(parts)
|
||||
return re.sub(r"!\[(.*?)\]\((.*?)\)", r"[\1](\2)", final_prompt)
|
||||
|
||||
|
||||
def normalize_claude_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
normalized_messages: list[dict[str, Any]] = []
|
||||
for message in messages:
|
||||
normalized_message = dict(message)
|
||||
if isinstance(message.get("content"), list):
|
||||
content_parts: list[str] = []
|
||||
for content_block in message["content"]:
|
||||
if not isinstance(content_block, dict):
|
||||
continue
|
||||
if content_block.get("type") == "text" and "text" in content_block:
|
||||
content_parts.append(str(content_block["text"]))
|
||||
elif content_block.get("type") == "tool_result" and "content" in content_block:
|
||||
content_parts.append(str(content_block["content"]))
|
||||
|
||||
if content_parts:
|
||||
normalized_message["content"] = "\n".join(content_parts)
|
||||
normalized_messages.append(normalized_message)
|
||||
|
||||
return normalized_messages
|
||||
|
||||
|
||||
def convert_claude_to_deepseek(
|
||||
claude_request: dict[str, Any],
|
||||
*,
|
||||
default_model: str = CLAUDE_DEFAULT_MODEL,
|
||||
model_mapping: dict[str, str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
messages = claude_request.get("messages", [])
|
||||
model = claude_request.get("model", default_model)
|
||||
|
||||
mapping = model_mapping or {"fast": "deepseek-chat", "slow": "deepseek-chat"}
|
||||
|
||||
model_lower = str(model).lower()
|
||||
if "opus" in model_lower or "reasoner" in model_lower or "slow" in model_lower:
|
||||
deepseek_model = mapping.get("slow", "deepseek-chat")
|
||||
else:
|
||||
deepseek_model = mapping.get("fast", "deepseek-chat")
|
||||
|
||||
deepseek_request: dict[str, Any] = {"model": deepseek_model, "messages": list(messages)}
|
||||
|
||||
if "system" in claude_request:
|
||||
system_msg = {"role": "system", "content": claude_request["system"]}
|
||||
deepseek_request["messages"].insert(0, system_msg)
|
||||
|
||||
if "temperature" in claude_request:
|
||||
deepseek_request["temperature"] = claude_request["temperature"]
|
||||
if "top_p" in claude_request:
|
||||
deepseek_request["top_p"] = claude_request["top_p"]
|
||||
if "stop_sequences" in claude_request:
|
||||
deepseek_request["stop"] = claude_request["stop_sequences"]
|
||||
if "stream" in claude_request:
|
||||
deepseek_request["stream"] = claude_request["stream"]
|
||||
|
||||
return deepseek_request
|
||||
193
ds2api/core/pow.py
Normal file
193
ds2api/core/pow.py
Normal file
@@ -0,0 +1,193 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import ctypes
|
||||
import hashlib
|
||||
import struct
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from wasmtime import Engine, Linker, Module, Store
|
||||
|
||||
from ds2api.utils.helpers import compact_json_dumps
|
||||
from ds2api.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PowSolver:
|
||||
def __init__(self, wasm_path: str) -> None:
|
||||
self._wasm_path = wasm_path
|
||||
self._engine = Engine()
|
||||
with open(wasm_path, "rb") as f:
|
||||
wasm_bytes = f.read()
|
||||
self._module = Module(self._engine, wasm_bytes)
|
||||
|
||||
def compute_answer(
|
||||
self,
|
||||
*,
|
||||
algorithm: str,
|
||||
challenge_str: str,
|
||||
salt: str,
|
||||
difficulty: int,
|
||||
expire_at: int,
|
||||
) -> int | None:
|
||||
if algorithm != "DeepSeekHashV1":
|
||||
raise ValueError(f"不支持的算法:{algorithm}")
|
||||
|
||||
prefix = f"{salt}_{expire_at}_"
|
||||
store = Store(self._engine)
|
||||
linker = Linker(store.engine)
|
||||
instance = linker.instantiate(store, self._module)
|
||||
exports = instance.exports(store)
|
||||
|
||||
try:
|
||||
memory = exports["memory"]
|
||||
add_to_stack = exports["__wbindgen_add_to_stack_pointer"]
|
||||
alloc = exports["__wbindgen_export_0"]
|
||||
wasm_solve = exports["wasm_solve"]
|
||||
except KeyError as e:
|
||||
raise RuntimeError(f"缺少 wasm 导出函数: {e}")
|
||||
|
||||
def write_memory(offset: int, data: bytes) -> None:
|
||||
base_addr = ctypes.cast(memory.data_ptr(store), ctypes.c_void_p).value
|
||||
ctypes.memmove(base_addr + offset, data, len(data))
|
||||
|
||||
def read_memory(offset: int, size: int) -> bytes:
|
||||
base_addr = ctypes.cast(memory.data_ptr(store), ctypes.c_void_p).value
|
||||
return ctypes.string_at(base_addr + offset, size)
|
||||
|
||||
def encode_string(text: str) -> tuple[int, int]:
|
||||
data = text.encode("utf-8")
|
||||
length = len(data)
|
||||
ptr_val = alloc(store, length, 1)
|
||||
ptr = int(ptr_val.value) if hasattr(ptr_val, "value") else int(ptr_val)
|
||||
write_memory(ptr, data)
|
||||
return ptr, length
|
||||
|
||||
retptr = add_to_stack(store, -16)
|
||||
ptr_challenge, len_challenge = encode_string(challenge_str)
|
||||
ptr_prefix, len_prefix = encode_string(prefix)
|
||||
|
||||
wasm_solve(
|
||||
store,
|
||||
retptr,
|
||||
ptr_challenge,
|
||||
len_challenge,
|
||||
ptr_prefix,
|
||||
len_prefix,
|
||||
float(difficulty),
|
||||
)
|
||||
|
||||
status_bytes = read_memory(retptr, 4)
|
||||
if len(status_bytes) != 4:
|
||||
add_to_stack(store, 16)
|
||||
raise RuntimeError("读取状态字节失败")
|
||||
|
||||
status = struct.unpack("<i", status_bytes)[0]
|
||||
value_bytes = read_memory(retptr + 8, 8)
|
||||
if len(value_bytes) != 8:
|
||||
add_to_stack(store, 16)
|
||||
raise RuntimeError("读取结果字节失败")
|
||||
|
||||
value = struct.unpack("<d", value_bytes)[0]
|
||||
add_to_stack(store, 16)
|
||||
|
||||
if status == 0:
|
||||
return None
|
||||
|
||||
return int(value)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _CacheEntry:
|
||||
value: int
|
||||
expire_at: float
|
||||
|
||||
|
||||
class PowCache:
|
||||
def __init__(self, ttl_seconds: int = 60) -> None:
|
||||
self._ttl = ttl_seconds
|
||||
self._lock = asyncio.Lock()
|
||||
self._data: dict[str, _CacheEntry] = {}
|
||||
|
||||
async def get(self, key: str) -> int | None:
|
||||
async with self._lock:
|
||||
entry = self._data.get(key)
|
||||
if not entry:
|
||||
return None
|
||||
if entry.expire_at < time.time():
|
||||
self._data.pop(key, None)
|
||||
return None
|
||||
return entry.value
|
||||
|
||||
async def set(self, key: str, value: int, *, ttl: int | None = None) -> None:
|
||||
async with self._lock:
|
||||
ttl_seconds = self._ttl if ttl is None else ttl
|
||||
self._data[key] = _CacheEntry(value=value, expire_at=time.time() + ttl_seconds)
|
||||
|
||||
|
||||
class PowService:
|
||||
def __init__(self, wasm_path: str, *, cache_ttl_seconds: int = 60) -> None:
|
||||
self._solver = PowSolver(wasm_path)
|
||||
self._cache = PowCache(ttl_seconds=cache_ttl_seconds)
|
||||
|
||||
@staticmethod
|
||||
def _make_cache_key(challenge: dict[str, Any]) -> str:
|
||||
challenge_str = str(challenge.get("challenge", ""))
|
||||
difficulty = str(challenge.get("difficulty", ""))
|
||||
raw = f"{challenge_str}|{difficulty}".encode("utf-8")
|
||||
return hashlib.sha256(raw).hexdigest()
|
||||
|
||||
async def solve_encoded_response(self, challenge: dict[str, Any]) -> str | None:
|
||||
key = self._make_cache_key(challenge)
|
||||
cached = await self._cache.get(key)
|
||||
if cached is not None:
|
||||
return self._encode_response(challenge, cached)
|
||||
|
||||
algorithm = challenge.get("algorithm")
|
||||
challenge_str = challenge.get("challenge")
|
||||
salt = challenge.get("salt")
|
||||
difficulty = int(challenge.get("difficulty", 144000))
|
||||
expire_at = int(challenge.get("expire_at", 0))
|
||||
|
||||
if not all([algorithm, challenge_str, salt, expire_at]):
|
||||
logger.warning("[pow] challenge 字段不完整")
|
||||
return None
|
||||
|
||||
try:
|
||||
answer = await asyncio.to_thread(
|
||||
self._solver.compute_answer,
|
||||
algorithm=algorithm,
|
||||
challenge_str=challenge_str,
|
||||
salt=salt,
|
||||
difficulty=difficulty,
|
||||
expire_at=expire_at,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[pow] PoW 答案计算异常: {e}")
|
||||
return None
|
||||
|
||||
if answer is None:
|
||||
return None
|
||||
|
||||
ttl = 60
|
||||
if expire_at:
|
||||
ttl = max(1, min(60, expire_at - int(time.time())))
|
||||
await self._cache.set(key, answer, ttl=ttl)
|
||||
return self._encode_response(challenge, answer)
|
||||
|
||||
@staticmethod
|
||||
def _encode_response(challenge: dict[str, Any], answer: int) -> str:
|
||||
pow_dict = {
|
||||
"algorithm": challenge.get("algorithm"),
|
||||
"challenge": challenge.get("challenge"),
|
||||
"salt": challenge.get("salt"),
|
||||
"answer": answer,
|
||||
"signature": challenge.get("signature"),
|
||||
"target_path": challenge.get("target_path"),
|
||||
}
|
||||
pow_str = compact_json_dumps(pow_dict)
|
||||
return base64.b64encode(pow_str.encode("utf-8")).decode("utf-8").rstrip()
|
||||
0
ds2api/models/__init__.py
Normal file
0
ds2api/models/__init__.py
Normal file
20
ds2api/models/schemas.py
Normal file
20
ds2api/models/schemas.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Literal["system", "user", "assistant", "tool"] | str
|
||||
content: Any = ""
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: list[ChatMessage]
|
||||
stream: bool = False
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
error: dict[str, Any] = Field(default_factory=dict)
|
||||
0
ds2api/services/__init__.py
Normal file
0
ds2api/services/__init__.py
Normal file
283
ds2api/services/claude.py
Normal file
283
ds2api/services/claude.py
Normal file
@@ -0,0 +1,283 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from starlette.background import BackgroundTask
|
||||
|
||||
from ds2api.core.auth import AuthContext, determine_claude_mode_and_token, get_auth_headers
|
||||
from ds2api.core.message_processor import (
|
||||
convert_claude_to_deepseek,
|
||||
messages_prepare,
|
||||
normalize_claude_messages,
|
||||
)
|
||||
from ds2api.services.claude_streaming import (
|
||||
claude_sse_stream,
|
||||
collect_deepseek_content_and_reasoning,
|
||||
detect_tool_calls,
|
||||
)
|
||||
from ds2api.services.token_counter import count_claude_tokens
|
||||
from ds2api.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def _create_session_with_retry(ctx: AuthContext, *, max_attempts: int = 3) -> str | None:
|
||||
for _ in range(max_attempts):
|
||||
session_id = ctx.deepseek.create_session(ctx.token)
|
||||
if session_id:
|
||||
return session_id
|
||||
if ctx.use_config_token and await ctx.rotate_account():
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
async def _get_pow_with_retry(ctx: AuthContext, request: Request, *, max_attempts: int = 3) -> str | None:
|
||||
pow_service = request.app.state.pow
|
||||
|
||||
for _ in range(max_attempts):
|
||||
challenge = ctx.deepseek.create_pow_challenge(ctx.token)
|
||||
if not challenge:
|
||||
if ctx.use_config_token and await ctx.rotate_account():
|
||||
continue
|
||||
continue
|
||||
|
||||
pow_resp = await pow_service.solve_encoded_response(challenge)
|
||||
if pow_resp:
|
||||
return pow_resp
|
||||
|
||||
if ctx.use_config_token and await ctx.rotate_account():
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _call_deepseek_for_claude(ctx: AuthContext, request: Request, claude_payload: dict[str, Any]):
|
||||
cfg: dict[str, Any] = request.app.state.config
|
||||
|
||||
deepseek_payload = convert_claude_to_deepseek(
|
||||
claude_payload,
|
||||
model_mapping=cfg.get("claude_model_mapping"),
|
||||
)
|
||||
|
||||
model = deepseek_payload.get("model", "deepseek-chat")
|
||||
model_lower = str(model).lower()
|
||||
if model_lower in ["deepseek-v3", "deepseek-chat"]:
|
||||
thinking_enabled = False
|
||||
search_enabled = False
|
||||
elif model_lower in ["deepseek-r1", "deepseek-reasoner"]:
|
||||
thinking_enabled = True
|
||||
search_enabled = False
|
||||
elif model_lower in ["deepseek-v3-search", "deepseek-chat-search"]:
|
||||
thinking_enabled = False
|
||||
search_enabled = True
|
||||
elif model_lower in ["deepseek-r1-search", "deepseek-reasoner-search"]:
|
||||
thinking_enabled = True
|
||||
search_enabled = True
|
||||
else:
|
||||
thinking_enabled = False
|
||||
search_enabled = False
|
||||
|
||||
final_prompt = messages_prepare(deepseek_payload.get("messages", []))
|
||||
|
||||
session_id = await _create_session_with_retry(ctx)
|
||||
if not session_id:
|
||||
raise HTTPException(status_code=401, detail="invalid token.")
|
||||
|
||||
pow_resp = await _get_pow_with_retry(ctx, request)
|
||||
if not pow_resp:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Failed to get PoW (invalid token or unknown error).",
|
||||
)
|
||||
|
||||
headers = {**get_auth_headers(ctx.token), "x-ds-pow-response": pow_resp}
|
||||
payload = {
|
||||
"chat_session_id": session_id,
|
||||
"parent_message_id": None,
|
||||
"prompt": final_prompt,
|
||||
"ref_file_ids": [],
|
||||
"thinking_enabled": thinking_enabled,
|
||||
"search_enabled": search_enabled,
|
||||
}
|
||||
|
||||
resp = ctx.deepseek.completion(headers=headers, payload=payload, max_attempts=3)
|
||||
return resp
|
||||
|
||||
|
||||
@router.post("/anthropic/v1/messages")
|
||||
async def claude_messages(request: Request):
|
||||
ctx: AuthContext | None = None
|
||||
try:
|
||||
try:
|
||||
ctx = await determine_claude_mode_and_token(request)
|
||||
except HTTPException as exc:
|
||||
return JSONResponse(status_code=exc.status_code, content={"error": exc.detail})
|
||||
|
||||
req_data = await request.json()
|
||||
model = req_data.get("model")
|
||||
messages = req_data.get("messages", [])
|
||||
if not model or not messages:
|
||||
raise HTTPException(status_code=400, detail="Request must include 'model' and 'messages'.")
|
||||
|
||||
normalized_messages = normalize_claude_messages(messages)
|
||||
tools_requested = req_data.get("tools") or []
|
||||
|
||||
payload = dict(req_data)
|
||||
payload["messages"] = list(normalized_messages)
|
||||
|
||||
if tools_requested and not any(m.get("role") == "system" for m in payload["messages"]):
|
||||
tool_schemas: list[str] = []
|
||||
for tool in tools_requested:
|
||||
tool_name = tool.get("name", "unknown")
|
||||
tool_desc = tool.get("description", "No description available")
|
||||
schema = tool.get("input_schema", {})
|
||||
|
||||
tool_info = f"Tool: {tool_name}\nDescription: {tool_desc}"
|
||||
if isinstance(schema, dict) and "properties" in schema:
|
||||
props = []
|
||||
required = schema.get("required", [])
|
||||
for prop_name, prop_info in schema["properties"].items():
|
||||
prop_type = prop_info.get("type", "string") if isinstance(prop_info, dict) else "string"
|
||||
is_req = " (required)" if prop_name in required else ""
|
||||
props.append(f" - {prop_name}: {prop_type}{is_req}")
|
||||
if props:
|
||||
tool_info += f"\nParameters:\n{chr(10).join(props)}"
|
||||
tool_schemas.append(tool_info)
|
||||
|
||||
system_message = {
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are Claude, a helpful AI assistant. You have access to these tools:\n\n"
|
||||
+ "\n".join(tool_schemas)
|
||||
+ "\n\nWhen you need to use tools, respond ONLY a JSON object with a tool_calls array."
|
||||
),
|
||||
}
|
||||
payload["messages"].insert(0, system_message)
|
||||
|
||||
deepseek_resp = await _call_deepseek_for_claude(ctx, request, payload)
|
||||
if not deepseek_resp:
|
||||
raise HTTPException(status_code=500, detail="Failed to get Claude response.")
|
||||
|
||||
if deepseek_resp.status_code != 200:
|
||||
deepseek_resp.close()
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"error": {"type": "api_error", "message": "Failed to get response"}},
|
||||
)
|
||||
|
||||
background = BackgroundTask(ctx.release) if ctx else None
|
||||
|
||||
if bool(req_data.get("stream", False)):
|
||||
return StreamingResponse(
|
||||
claude_sse_stream(
|
||||
deepseek_resp=deepseek_resp,
|
||||
model=str(model),
|
||||
messages=messages,
|
||||
tools_requested=tools_requested,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={"Content-Type": "text/event-stream"},
|
||||
background=background,
|
||||
)
|
||||
|
||||
try:
|
||||
final_content, final_reasoning = collect_deepseek_content_and_reasoning(deepseek_resp)
|
||||
finally:
|
||||
try:
|
||||
deepseek_resp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cleaned_content = final_content.strip()
|
||||
detected_tools = detect_tool_calls(cleaned_content, tools_requested)
|
||||
|
||||
claude_response: dict[str, Any] = {
|
||||
"id": f"msg_{int(time.time())}_{random.randint(1000, 9999)}",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": [],
|
||||
"stop_reason": "tool_use" if detected_tools else "end_turn",
|
||||
"stop_sequence": None,
|
||||
"usage": {
|
||||
"input_tokens": len(str(normalized_messages)) // 4,
|
||||
"output_tokens": (len(final_content) + len(final_reasoning)) // 4,
|
||||
},
|
||||
}
|
||||
|
||||
if final_reasoning:
|
||||
claude_response["content"].append({"type": "thinking", "thinking": final_reasoning})
|
||||
|
||||
if detected_tools:
|
||||
for i, tool_info in enumerate(detected_tools):
|
||||
tool_use_id = f"toolu_{int(time.time())}_{random.randint(1000, 9999)}_{i}"
|
||||
claude_response["content"].append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": tool_use_id,
|
||||
"name": tool_info["name"],
|
||||
"input": tool_info["input"],
|
||||
}
|
||||
)
|
||||
else:
|
||||
claude_response["content"].append(
|
||||
{"type": "text", "text": final_content or "抱歉,没有生成有效的响应内容。"}
|
||||
)
|
||||
|
||||
return JSONResponse(content=claude_response, status_code=200, background=background)
|
||||
|
||||
except HTTPException as exc:
|
||||
if ctx:
|
||||
await ctx.release()
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"error": {"type": "invalid_request_error", "message": exc.detail}},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"[claude_messages] 未知异常: {exc}")
|
||||
if ctx:
|
||||
await ctx.release()
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"error": {"type": "api_error", "message": "Internal Server Error"}},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/anthropic/v1/messages/count_tokens")
|
||||
async def claude_count_tokens(request: Request):
|
||||
ctx: AuthContext | None = None
|
||||
try:
|
||||
try:
|
||||
ctx = await determine_claude_mode_and_token(request)
|
||||
except HTTPException as exc:
|
||||
return JSONResponse(status_code=exc.status_code, content={"error": exc.detail})
|
||||
|
||||
req_data = await request.json()
|
||||
if not req_data.get("model") or not req_data.get("messages"):
|
||||
raise HTTPException(status_code=400, detail="Request must include 'model' and 'messages'.")
|
||||
|
||||
response = {"input_tokens": count_claude_tokens(req_data)}
|
||||
background = BackgroundTask(ctx.release) if ctx else None
|
||||
return JSONResponse(content=response, status_code=200, background=background)
|
||||
|
||||
except HTTPException as exc:
|
||||
if ctx:
|
||||
await ctx.release()
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"error": {"type": "invalid_request_error", "message": exc.detail}},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"[claude_count_tokens] 未知异常: {exc}")
|
||||
if ctx:
|
||||
await ctx.release()
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"error": {"type": "api_error", "message": "Internal Server Error"}},
|
||||
)
|
||||
230
ds2api/services/claude_streaming.py
Normal file
230
ds2api/services/claude_streaming.py
Normal file
@@ -0,0 +1,230 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Iterator
|
||||
|
||||
from ds2api.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def detect_tool_calls(cleaned_response: str, tools_requested: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
detected_tools: list[dict[str, Any]] = []
|
||||
tool_detected = False
|
||||
|
||||
if cleaned_response.startswith('{"tool_calls":') and cleaned_response.endswith(']}'):
|
||||
try:
|
||||
tool_data = json.loads(cleaned_response)
|
||||
for tool_call in tool_data.get("tool_calls", []):
|
||||
tool_name = tool_call.get("name")
|
||||
tool_input = tool_call.get("input", {})
|
||||
if any(tool.get("name") == tool_name for tool in tools_requested):
|
||||
detected_tools.append({"name": tool_name, "input": tool_input})
|
||||
tool_detected = True
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
if not tool_detected:
|
||||
tool_call_pattern = r"\{\s*[\"\']tool_calls[\"\']\s*:\s*\[(.*?)\]\s*\}"
|
||||
matches = re.findall(tool_call_pattern, cleaned_response, re.DOTALL)
|
||||
for match in matches:
|
||||
try:
|
||||
tool_calls_json = f'{{"tool_calls": [{match}]}}'
|
||||
tool_data = json.loads(tool_calls_json)
|
||||
for tool_call in tool_data.get("tool_calls", []):
|
||||
tool_name = tool_call.get("name")
|
||||
tool_input = tool_call.get("input", {})
|
||||
if any(tool.get("name") == tool_name for tool in tools_requested):
|
||||
detected_tools.append({"name": tool_name, "input": tool_input})
|
||||
tool_detected = True
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
return detected_tools
|
||||
|
||||
|
||||
def collect_deepseek_text(deepseek_resp) -> str:
|
||||
full_response_text = ""
|
||||
for line in deepseek_resp.iter_lines():
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
line_str = line.decode("utf-8")
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not line_str.startswith("data:"):
|
||||
continue
|
||||
|
||||
data_str = line_str[5:].strip()
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
|
||||
try:
|
||||
chunk = json.loads(data_str)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if "v" in chunk and isinstance(chunk["v"], str):
|
||||
full_response_text += chunk["v"]
|
||||
elif "v" in chunk and isinstance(chunk["v"], list):
|
||||
for item in chunk["v"]:
|
||||
if item.get("p") == "status" and item.get("v") == "FINISHED":
|
||||
break
|
||||
|
||||
return full_response_text
|
||||
|
||||
|
||||
def collect_deepseek_content_and_reasoning(deepseek_resp) -> tuple[str, str]:
|
||||
final_content = ""
|
||||
final_reasoning = ""
|
||||
ptype = "text"
|
||||
|
||||
for raw_line in deepseek_resp.iter_lines():
|
||||
if not raw_line:
|
||||
continue
|
||||
try:
|
||||
line = raw_line.decode("utf-8")
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not line.startswith("data:"):
|
||||
continue
|
||||
|
||||
data_str = line[5:].strip()
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
|
||||
try:
|
||||
chunk = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if "v" not in chunk:
|
||||
continue
|
||||
|
||||
v_value = chunk["v"]
|
||||
if chunk.get("p") == "response/thinking_content":
|
||||
ptype = "thinking"
|
||||
elif chunk.get("p") == "response/content":
|
||||
ptype = "text"
|
||||
|
||||
if isinstance(v_value, str):
|
||||
if ptype == "thinking":
|
||||
final_reasoning += v_value
|
||||
else:
|
||||
final_content += v_value
|
||||
elif isinstance(v_value, list):
|
||||
for item in v_value:
|
||||
if item.get("p") == "status" and item.get("v") == "FINISHED":
|
||||
break
|
||||
|
||||
return final_content, final_reasoning
|
||||
|
||||
|
||||
def claude_sse_stream(
|
||||
*,
|
||||
deepseek_resp,
|
||||
model: str,
|
||||
messages: list[dict[str, Any]],
|
||||
tools_requested: list[dict[str, Any]],
|
||||
) -> Iterator[str]:
|
||||
message_id = f"msg_{int(time.time())}_{random.randint(1000, 9999)}"
|
||||
input_tokens = max(1, sum(len(str(m.get("content", ""))) for m in messages) // 4)
|
||||
|
||||
try:
|
||||
full_response_text = collect_deepseek_text(deepseek_resp)
|
||||
cleaned_response = full_response_text.strip()
|
||||
detected_tools = detect_tool_calls(cleaned_response, tools_requested)
|
||||
|
||||
message_start = {
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": message_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": [],
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
"usage": {"input_tokens": input_tokens, "output_tokens": 0},
|
||||
},
|
||||
}
|
||||
yield f"data: {json.dumps(message_start)}\n\n"
|
||||
|
||||
content_index = 0
|
||||
if detected_tools:
|
||||
stop_reason = "tool_use"
|
||||
for tool_info in detected_tools:
|
||||
tool_use_id = f"toolu_{int(time.time())}_{random.randint(1000, 9999)}_{content_index}"
|
||||
yield (
|
||||
"data: "
|
||||
+ json.dumps(
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": content_index,
|
||||
"content_block": {
|
||||
"type": "tool_use",
|
||||
"id": tool_use_id,
|
||||
"name": tool_info["name"],
|
||||
"input": tool_info["input"],
|
||||
},
|
||||
}
|
||||
)
|
||||
+ "\n\n"
|
||||
)
|
||||
yield (
|
||||
"data: "
|
||||
+ json.dumps({"type": "content_block_stop", "index": content_index})
|
||||
+ "\n\n"
|
||||
)
|
||||
content_index += 1
|
||||
else:
|
||||
stop_reason = "end_turn"
|
||||
yield (
|
||||
"data: "
|
||||
+ json.dumps(
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": ""},
|
||||
}
|
||||
)
|
||||
+ "\n\n"
|
||||
)
|
||||
if cleaned_response:
|
||||
yield (
|
||||
"data: "
|
||||
+ json.dumps(
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": cleaned_response},
|
||||
}
|
||||
)
|
||||
+ "\n\n"
|
||||
)
|
||||
yield "data: " + json.dumps({"type": "content_block_stop", "index": 0}) + "\n\n"
|
||||
|
||||
output_tokens = max(1, len(cleaned_response) // 4)
|
||||
yield (
|
||||
"data: "
|
||||
+ json.dumps(
|
||||
{
|
||||
"type": "message_delta",
|
||||
"delta": {"stop_reason": stop_reason, "stop_sequence": None},
|
||||
"usage": {"output_tokens": output_tokens},
|
||||
}
|
||||
)
|
||||
+ "\n\n"
|
||||
)
|
||||
yield "data: " + json.dumps({"type": "message_stop"}) + "\n\n"
|
||||
|
||||
finally:
|
||||
try:
|
||||
deepseek_resp.close()
|
||||
except Exception:
|
||||
pass
|
||||
156
ds2api/services/completion.py
Normal file
156
ds2api/services/completion.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from starlette.background import BackgroundTask
|
||||
|
||||
from ds2api.core.auth import AuthContext, determine_mode_and_token, get_auth_headers
|
||||
from ds2api.core.message_processor import messages_prepare
|
||||
from ds2api.services.openai_streaming import openai_json_response_stream, openai_sse_stream
|
||||
from ds2api.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def _create_session_with_retry(ctx: AuthContext, *, max_attempts: int = 3) -> str | None:
|
||||
for _ in range(max_attempts):
|
||||
session_id = ctx.deepseek.create_session(ctx.token)
|
||||
if session_id:
|
||||
return session_id
|
||||
if ctx.use_config_token and await ctx.rotate_account():
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
async def _get_pow_with_retry(ctx: AuthContext, request: Request, *, max_attempts: int = 3) -> str | None:
|
||||
pow_service = request.app.state.pow
|
||||
|
||||
for _ in range(max_attempts):
|
||||
challenge = ctx.deepseek.create_pow_challenge(ctx.token)
|
||||
if not challenge:
|
||||
if ctx.use_config_token and await ctx.rotate_account():
|
||||
continue
|
||||
continue
|
||||
|
||||
pow_resp = await pow_service.solve_encoded_response(challenge)
|
||||
if pow_resp:
|
||||
return pow_resp
|
||||
|
||||
if ctx.use_config_token and await ctx.rotate_account():
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
async def chat_completions(request: Request):
|
||||
ctx: AuthContext | None = None
|
||||
try:
|
||||
try:
|
||||
ctx = await determine_mode_and_token(request)
|
||||
except HTTPException as exc:
|
||||
return JSONResponse(status_code=exc.status_code, content={"error": exc.detail})
|
||||
|
||||
req_data = await request.json()
|
||||
model = req_data.get("model")
|
||||
messages = req_data.get("messages", [])
|
||||
if not model or not messages:
|
||||
raise HTTPException(status_code=400, detail="Request must include 'model' and 'messages'.")
|
||||
|
||||
model_lower = str(model).lower()
|
||||
if model_lower in ["deepseek-v3", "deepseek-chat"]:
|
||||
thinking_enabled = False
|
||||
search_enabled = False
|
||||
elif model_lower in ["deepseek-r1", "deepseek-reasoner"]:
|
||||
thinking_enabled = True
|
||||
search_enabled = False
|
||||
elif model_lower in ["deepseek-v3-search", "deepseek-chat-search"]:
|
||||
thinking_enabled = False
|
||||
search_enabled = True
|
||||
elif model_lower in ["deepseek-r1-search", "deepseek-reasoner-search"]:
|
||||
thinking_enabled = True
|
||||
search_enabled = True
|
||||
else:
|
||||
raise HTTPException(status_code=503, detail=f"Model '{model}' is not available.")
|
||||
|
||||
final_prompt = messages_prepare(messages)
|
||||
|
||||
session_id = await _create_session_with_retry(ctx)
|
||||
if not session_id:
|
||||
raise HTTPException(status_code=401, detail="invalid token.")
|
||||
|
||||
pow_resp = await _get_pow_with_retry(ctx, request)
|
||||
if not pow_resp:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Failed to get PoW (invalid token or unknown error).",
|
||||
)
|
||||
|
||||
headers = {**get_auth_headers(ctx.token), "x-ds-pow-response": pow_resp}
|
||||
payload: dict[str, Any] = {
|
||||
"chat_session_id": session_id,
|
||||
"parent_message_id": None,
|
||||
"prompt": final_prompt,
|
||||
"ref_file_ids": [],
|
||||
"thinking_enabled": thinking_enabled,
|
||||
"search_enabled": search_enabled,
|
||||
}
|
||||
|
||||
deepseek_resp = ctx.deepseek.completion(headers=headers, payload=payload, max_attempts=3)
|
||||
if not deepseek_resp:
|
||||
raise HTTPException(status_code=500, detail="Failed to get completion.")
|
||||
|
||||
created_time = int(time.time())
|
||||
completion_id = f"{session_id}"
|
||||
keep_alive_timeout = request.app.state.settings.keep_alive_timeout
|
||||
|
||||
background = BackgroundTask(ctx.release) if ctx else None
|
||||
|
||||
if bool(req_data.get("stream", False)):
|
||||
if deepseek_resp.status_code != 200:
|
||||
deepseek_resp.close()
|
||||
return JSONResponse(content=deepseek_resp.content, status_code=deepseek_resp.status_code)
|
||||
|
||||
return StreamingResponse(
|
||||
openai_sse_stream(
|
||||
deepseek_resp=deepseek_resp,
|
||||
model=str(model),
|
||||
completion_id=completion_id,
|
||||
created_time=created_time,
|
||||
final_prompt=final_prompt,
|
||||
thinking_enabled=thinking_enabled,
|
||||
search_enabled=search_enabled,
|
||||
keep_alive_timeout=keep_alive_timeout,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={"Content-Type": "text/event-stream"},
|
||||
background=background,
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
openai_json_response_stream(
|
||||
deepseek_resp=deepseek_resp,
|
||||
model=str(model),
|
||||
completion_id=completion_id,
|
||||
created_time=created_time,
|
||||
final_prompt=final_prompt,
|
||||
search_enabled=search_enabled,
|
||||
),
|
||||
media_type="application/json",
|
||||
background=background,
|
||||
)
|
||||
|
||||
except HTTPException as exc:
|
||||
if ctx:
|
||||
await ctx.release()
|
||||
return JSONResponse(status_code=exc.status_code, content={"error": exc.detail})
|
||||
except Exception as exc:
|
||||
logger.error(f"[chat_completions] 未知异常: {exc}")
|
||||
if ctx:
|
||||
await ctx.release()
|
||||
return JSONResponse(status_code=500, content={"error": "Internal Server Error"})
|
||||
64
ds2api/services/models.py
Normal file
64
ds2api/services/models.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/v1/models")
|
||||
def list_models():
|
||||
models_list = [
|
||||
{
|
||||
"id": "deepseek-chat",
|
||||
"object": "model",
|
||||
"created": 1677610602,
|
||||
"owned_by": "deepseek",
|
||||
"permission": [],
|
||||
},
|
||||
{
|
||||
"id": "deepseek-reasoner",
|
||||
"object": "model",
|
||||
"created": 1677610602,
|
||||
"owned_by": "deepseek",
|
||||
"permission": [],
|
||||
},
|
||||
{
|
||||
"id": "deepseek-chat-search",
|
||||
"object": "model",
|
||||
"created": 1677610602,
|
||||
"owned_by": "deepseek",
|
||||
"permission": [],
|
||||
},
|
||||
{
|
||||
"id": "deepseek-reasoner-search",
|
||||
"object": "model",
|
||||
"created": 1677610602,
|
||||
"owned_by": "deepseek",
|
||||
"permission": [],
|
||||
},
|
||||
]
|
||||
return JSONResponse(content={"object": "list", "data": models_list}, status_code=200)
|
||||
|
||||
|
||||
@router.get("/anthropic/v1/models")
|
||||
def list_claude_models():
|
||||
models_list = [
|
||||
{
|
||||
"id": "claude-sonnet-4-20250514",
|
||||
"object": "model",
|
||||
"created": 1715635200,
|
||||
"owned_by": "anthropic",
|
||||
},
|
||||
{
|
||||
"id": "claude-sonnet-4-20250514-fast",
|
||||
"object": "model",
|
||||
"created": 1715635200,
|
||||
"owned_by": "anthropic",
|
||||
},
|
||||
{
|
||||
"id": "claude-sonnet-4-20250514-slow",
|
||||
"object": "model",
|
||||
"created": 1715635200,
|
||||
"owned_by": "anthropic",
|
||||
},
|
||||
]
|
||||
return JSONResponse(content={"object": "list", "data": models_list}, status_code=200)
|
||||
379
ds2api/services/openai_streaming.py
Normal file
379
ds2api/services/openai_streaming.py
Normal file
@@ -0,0 +1,379 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Iterator
|
||||
|
||||
from ds2api.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def openai_sse_stream(
|
||||
*,
|
||||
deepseek_resp,
|
||||
model: str,
|
||||
completion_id: str,
|
||||
created_time: int,
|
||||
final_prompt: str,
|
||||
thinking_enabled: bool,
|
||||
search_enabled: bool,
|
||||
keep_alive_timeout: int,
|
||||
) -> Iterator[str]:
|
||||
final_text = ""
|
||||
final_thinking = ""
|
||||
first_chunk_sent = False
|
||||
result_queue: queue.Queue[dict[str, Any] | None] = queue.Queue()
|
||||
last_send_time = time.time()
|
||||
|
||||
def process_data() -> None:
|
||||
ptype = "text"
|
||||
try:
|
||||
for raw_line in deepseek_resp.iter_lines():
|
||||
try:
|
||||
line = raw_line.decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.warning(f"[sse_stream] 解码失败: {e}")
|
||||
error_type = "thinking" if ptype == "thinking" else "text"
|
||||
result_queue.put(
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": "解码失败,请稍候再试",
|
||||
"type": error_type,
|
||||
},
|
||||
}
|
||||
],
|
||||
"model": "",
|
||||
"chunk_token_usage": 1,
|
||||
"created": 0,
|
||||
"message_id": -1,
|
||||
"parent_id": -1,
|
||||
}
|
||||
)
|
||||
result_queue.put(None)
|
||||
break
|
||||
|
||||
if not line:
|
||||
continue
|
||||
|
||||
if not line.startswith("data:"):
|
||||
continue
|
||||
|
||||
data_str = line[5:].strip()
|
||||
if data_str == "[DONE]":
|
||||
result_queue.put(None)
|
||||
break
|
||||
|
||||
try:
|
||||
chunk = json.loads(data_str)
|
||||
if "v" not in chunk:
|
||||
continue
|
||||
|
||||
if chunk.get("p") == "response/search_status":
|
||||
continue
|
||||
|
||||
if chunk.get("p") == "response/thinking_content":
|
||||
ptype = "thinking"
|
||||
elif chunk.get("p") == "response/content":
|
||||
ptype = "text"
|
||||
|
||||
v_value = chunk["v"]
|
||||
if isinstance(v_value, str):
|
||||
content = v_value
|
||||
elif isinstance(v_value, list):
|
||||
for item in v_value:
|
||||
if item.get("p") == "status" and item.get("v") == "FINISHED":
|
||||
result_queue.put({"choices": [{"index": 0, "finish_reason": "stop"}]})
|
||||
result_queue.put(None)
|
||||
return
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
result_queue.put(
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": content, "type": ptype},
|
||||
}
|
||||
],
|
||||
"model": "",
|
||||
"chunk_token_usage": len(content) // 4,
|
||||
"created": 0,
|
||||
"message_id": -1,
|
||||
"parent_id": -1,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[sse_stream] 无法解析: {data_str}, 错误: {e}")
|
||||
error_type = "thinking" if ptype == "thinking" else "text"
|
||||
result_queue.put(
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": "解析失败,请稍候再试",
|
||||
"type": error_type,
|
||||
},
|
||||
}
|
||||
],
|
||||
"model": "",
|
||||
"chunk_token_usage": 1,
|
||||
"created": 0,
|
||||
"message_id": -1,
|
||||
"parent_id": -1,
|
||||
}
|
||||
)
|
||||
result_queue.put(None)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"[sse_stream] 错误: {e}")
|
||||
result_queue.put(
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": "服务器错误,请稍候再试",
|
||||
"type": "text",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
result_queue.put(None)
|
||||
finally:
|
||||
try:
|
||||
deepseek_resp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
threading.Thread(target=process_data, daemon=True).start()
|
||||
|
||||
while True:
|
||||
current_time = time.time()
|
||||
if current_time - last_send_time >= keep_alive_timeout:
|
||||
yield ": keep-alive\n\n"
|
||||
last_send_time = current_time
|
||||
continue
|
||||
|
||||
try:
|
||||
chunk = result_queue.get(timeout=0.05)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
if chunk is None:
|
||||
prompt_tokens = len(final_prompt) // 4
|
||||
thinking_tokens = len(final_thinking) // 4
|
||||
completion_tokens = len(final_text) // 4
|
||||
usage = {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": thinking_tokens + completion_tokens,
|
||||
"total_tokens": prompt_tokens + thinking_tokens + completion_tokens,
|
||||
"completion_tokens_details": {"reasoning_tokens": thinking_tokens},
|
||||
}
|
||||
finish_chunk = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
"choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}],
|
||||
"usage": usage,
|
||||
}
|
||||
yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
break
|
||||
|
||||
new_choices = []
|
||||
for choice in chunk.get("choices", []):
|
||||
delta = choice.get("delta", {})
|
||||
ctype = delta.get("type")
|
||||
ctext = delta.get("content", "")
|
||||
if choice.get("finish_reason") == "backend_busy":
|
||||
ctext = "服务器繁忙,请稍候再试"
|
||||
if search_enabled and isinstance(ctext, str) and ctext.startswith("[citation:"):
|
||||
ctext = ""
|
||||
|
||||
delta_obj: dict[str, Any] = {}
|
||||
if not first_chunk_sent:
|
||||
delta_obj["role"] = "assistant"
|
||||
first_chunk_sent = True
|
||||
|
||||
if ctype == "thinking":
|
||||
if thinking_enabled:
|
||||
final_thinking += ctext
|
||||
delta_obj["reasoning_content"] = ctext
|
||||
elif ctype == "text":
|
||||
final_text += ctext
|
||||
delta_obj["content"] = ctext
|
||||
|
||||
if delta_obj:
|
||||
new_choices.append({"delta": delta_obj, "index": choice.get("index", 0)})
|
||||
|
||||
if new_choices:
|
||||
out_chunk = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
"choices": new_choices,
|
||||
}
|
||||
yield f"data: {json.dumps(out_chunk, ensure_ascii=False)}\n\n"
|
||||
last_send_time = current_time
|
||||
|
||||
|
||||
def openai_json_response_stream(
|
||||
*,
|
||||
deepseek_resp,
|
||||
model: str,
|
||||
completion_id: str,
|
||||
created_time: int,
|
||||
final_prompt: str,
|
||||
search_enabled: bool,
|
||||
) -> Iterator[str]:
|
||||
think_list: list[str] = []
|
||||
text_list: list[str] = []
|
||||
result: dict[str, Any] | None = None
|
||||
|
||||
def collect_data() -> None:
|
||||
nonlocal result
|
||||
ptype = "text"
|
||||
try:
|
||||
for raw_line in deepseek_resp.iter_lines():
|
||||
try:
|
||||
line = raw_line.decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.warning(f"[chat_completions] 解码失败: {e}")
|
||||
if ptype == "thinking":
|
||||
think_list.append("解码失败,请稍候再试")
|
||||
else:
|
||||
text_list.append("解码失败,请稍候再试")
|
||||
break
|
||||
|
||||
if not line:
|
||||
continue
|
||||
|
||||
if not line.startswith("data:"):
|
||||
continue
|
||||
|
||||
data_str = line[5:].strip()
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
|
||||
try:
|
||||
chunk = json.loads(data_str)
|
||||
if "v" not in chunk:
|
||||
continue
|
||||
|
||||
if chunk.get("p") == "response/search_status":
|
||||
continue
|
||||
|
||||
if chunk.get("p") == "response/thinking_content":
|
||||
ptype = "thinking"
|
||||
elif chunk.get("p") == "response/content":
|
||||
ptype = "text"
|
||||
|
||||
v_value = chunk["v"]
|
||||
if isinstance(v_value, str):
|
||||
if search_enabled and v_value.startswith("[citation:"):
|
||||
continue
|
||||
if ptype == "thinking":
|
||||
think_list.append(v_value)
|
||||
else:
|
||||
text_list.append(v_value)
|
||||
elif isinstance(v_value, list):
|
||||
for item in v_value:
|
||||
if item.get("p") == "status" and item.get("v") == "FINISHED":
|
||||
final_reasoning = "".join(think_list)
|
||||
final_content = "".join(text_list)
|
||||
prompt_tokens = len(final_prompt) // 4
|
||||
reasoning_tokens = len(final_reasoning) // 4
|
||||
completion_tokens = len(final_content) // 4
|
||||
result = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion",
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": final_content,
|
||||
"reasoning_content": final_reasoning,
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": reasoning_tokens + completion_tokens,
|
||||
"total_tokens": prompt_tokens + reasoning_tokens + completion_tokens,
|
||||
"completion_tokens_details": {
|
||||
"reasoning_tokens": reasoning_tokens
|
||||
},
|
||||
},
|
||||
}
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"[collect_data] 无法解析: {data_str}, 错误: {e}")
|
||||
if ptype == "thinking":
|
||||
think_list.append("解析失败,请稍候再试")
|
||||
else:
|
||||
text_list.append("解析失败,请稍候再试")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"[collect_data] 错误: {e}")
|
||||
if ptype == "thinking":
|
||||
think_list.append("处理失败,请稍候再试")
|
||||
else:
|
||||
text_list.append("处理失败,请稍候再试")
|
||||
finally:
|
||||
try:
|
||||
deepseek_resp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if result is None:
|
||||
final_content = "".join(text_list)
|
||||
final_reasoning = "".join(think_list)
|
||||
prompt_tokens = len(final_prompt) // 4
|
||||
reasoning_tokens = len(final_reasoning) // 4
|
||||
completion_tokens = len(final_content) // 4
|
||||
result = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion",
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": final_content,
|
||||
"reasoning_content": final_reasoning,
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": reasoning_tokens + completion_tokens,
|
||||
"total_tokens": prompt_tokens + reasoning_tokens + completion_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
t = threading.Thread(target=collect_data, daemon=True)
|
||||
t.start()
|
||||
|
||||
while t.is_alive():
|
||||
time.sleep(0.1)
|
||||
|
||||
yield json.dumps(result)
|
||||
53
ds2api/services/token_counter.py
Normal file
53
ds2api/services/token_counter.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
|
||||
def estimate_tokens(text: Any) -> int:
|
||||
if text is None:
|
||||
return 0
|
||||
if isinstance(text, str):
|
||||
return len(text) // 4
|
||||
if isinstance(text, (bytes, bytearray)):
|
||||
return len(text) // 4
|
||||
if isinstance(text, list):
|
||||
return sum(estimate_tokens(item) for item in text)
|
||||
if isinstance(text, dict):
|
||||
if text.get("type") == "text":
|
||||
return estimate_tokens(text.get("text", ""))
|
||||
if text.get("type") == "tool_result":
|
||||
return estimate_tokens(text.get("content", ""))
|
||||
return estimate_tokens(json.dumps(text, ensure_ascii=False))
|
||||
return len(str(text)) // 4
|
||||
|
||||
|
||||
def count_claude_tokens(payload: dict[str, Any]) -> int:
|
||||
messages = payload.get("messages", [])
|
||||
system = payload.get("system", "")
|
||||
tools = payload.get("tools", [])
|
||||
|
||||
input_tokens = 0
|
||||
if system:
|
||||
input_tokens += estimate_tokens(system)
|
||||
|
||||
for message in messages:
|
||||
role = message.get("role", "")
|
||||
content = message.get("content", "")
|
||||
|
||||
input_tokens += 2
|
||||
input_tokens += estimate_tokens(role)
|
||||
|
||||
if isinstance(content, list):
|
||||
for content_block in content:
|
||||
input_tokens += estimate_tokens(content_block)
|
||||
else:
|
||||
input_tokens += estimate_tokens(content)
|
||||
|
||||
if tools:
|
||||
for tool in tools:
|
||||
input_tokens += estimate_tokens(tool.get("name", ""))
|
||||
input_tokens += estimate_tokens(tool.get("description", ""))
|
||||
input_tokens += estimate_tokens(json.dumps(tool.get("input_schema", {}), ensure_ascii=False))
|
||||
|
||||
return max(1, input_tokens)
|
||||
0
ds2api/utils/__init__.py
Normal file
0
ds2api/utils/__init__.py
Normal file
39
ds2api/utils/helpers.py
Normal file
39
ds2api/utils/helpers.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import base64
|
||||
import binascii
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
|
||||
def try_parse_json(raw: str) -> dict[str, Any] | None:
|
||||
try:
|
||||
val = json.loads(raw)
|
||||
except Exception:
|
||||
return None
|
||||
return val if isinstance(val, dict) else None
|
||||
|
||||
|
||||
def safe_b64decode(raw: str) -> bytes | None:
|
||||
try:
|
||||
padding = "=" * (-len(raw) % 4)
|
||||
return base64.b64decode(raw + padding)
|
||||
except (binascii.Error, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def try_decode_jwt_exp(token: str) -> int | None:
|
||||
parts = token.split(".")
|
||||
if len(parts) < 2:
|
||||
return None
|
||||
payload = safe_b64decode(parts[1])
|
||||
if not payload:
|
||||
return None
|
||||
try:
|
||||
data = json.loads(payload)
|
||||
except Exception:
|
||||
return None
|
||||
exp = data.get("exp")
|
||||
return int(exp) if isinstance(exp, (int, float)) else None
|
||||
|
||||
|
||||
def compact_json_dumps(data: Any) -> str:
|
||||
return json.dumps(data, separators=(",", ":"), ensure_ascii=False)
|
||||
16
ds2api/utils/logger.py
Normal file
16
ds2api/utils/logger.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def configure_logging() -> None:
|
||||
logging.basicConfig(
|
||||
level=os.getenv("LOG_LEVEL", "INFO").upper(),
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
force=True,
|
||||
)
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
return logging.getLogger(name)
|
||||
@@ -1,6 +1,5 @@
|
||||
fastapi>=0.110.0,<1.0.0
|
||||
uvicorn>=0.24.0,<1.0.0
|
||||
curl_cffi>=0.7.0,<1.0.0
|
||||
transformers>=4.39.0,<5.0.0
|
||||
wasmtime>=14.0.0,<20.0.0
|
||||
jinja2>=3.1.0,<4.0.0
|
||||
|
||||
263174
tokenizer.json
263174
tokenizer.json
File diff suppressed because it is too large
Load Diff
@@ -1,35 +0,0 @@
|
||||
{
|
||||
"add_bos_token": false,
|
||||
"add_eos_token": false,
|
||||
"bos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<|begin▁of▁sentence|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"clean_up_tokenization_spaces": false,
|
||||
"eos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<|end▁of▁sentence|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"legacy": true,
|
||||
"model_max_length": 16384,
|
||||
"pad_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<|end▁of▁sentence|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"sp_model_kwargs": {},
|
||||
"unk_token": null,
|
||||
"tokenizer_class": "LlamaTokenizerFast",
|
||||
"chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' in message %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- else %}{{'<|Assistant|>' + message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- endfor %}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' not in message %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}"
|
||||
}
|
||||
Reference in New Issue
Block a user