Mirror of https://github.com/CJackHwang/ds2api.git
Synced 2026-05-04 00:15:28 +08:00

Commit: feat: Add support for parsing dictionary-based SSE fragments and prevent duplicate stream termination messages.
This commit is contained in:
@@ -32,9 +32,14 @@ logger = logging.getLogger("ds2api")

 # -------------------------- 初始化 tokenizer --------------------------
 chat_tokenizer_dir = resolve_path("DS2API_TOKENIZER_DIR", "")
+# 抑制 Mistral tokenizer regex 警告(不影响 DeepSeek tokenization)
+_tf_logger = logging.getLogger("transformers")
+_tf_log_level = _tf_logger.level
+_tf_logger.setLevel(logging.ERROR)
 tokenizer = transformers.AutoTokenizer.from_pretrained(
     chat_tokenizer_dir, trust_remote_code=True
 )
+_tf_logger.setLevel(_tf_log_level)

 # ----------------------------------------------------------------------
 # 配置文件的读写函数
@@ -255,6 +255,26 @@ def parse_sse_chunk_for_content(
                 return ([], True, new_fragment_type)
             contents.extend(result)

+        # 处理字典值(初始响应 chunk,包含 response.fragments)
+        elif isinstance(v_value, dict):
+            response_obj = v_value.get("response", v_value)
+            fragments = response_obj.get("fragments", [])
+            if isinstance(fragments, list):
+                for frag in fragments:
+                    if isinstance(frag, dict):
+                        frag_type = frag.get("type", "").upper()
+                        frag_content = frag.get("content", "")
+                        if frag_type == "THINK" or frag_type == "THINKING":
+                            new_fragment_type = "thinking"
+                            if frag_content:
+                                contents.append((frag_content, "thinking"))
+                        elif frag_type == "RESPONSE":
+                            new_fragment_type = "text"
+                            if frag_content:
+                                contents.append((frag_content, "text"))
+                        elif frag_content:
+                            contents.append((frag_content, ptype))

     return (contents, False, new_fragment_type)
@@ -194,6 +194,7 @@ IMPORTANT: If calling tools, output ONLY the JSON. The response must start with
     last_content_time = time.time()  # 最后收到有效内容的时间
     keepalive_count = 0  # 连续 keepalive 计数
     has_content = False  # 是否收到过内容
+    stream_finished = False  # 是否已发送过结束标记

     def process_data():
         """处理 DeepSeek SSE 数据流 - 使用 sse_parser 模块"""
@@ -343,6 +344,7 @@ IMPORTANT: If calling tools, output ONLY the JSON. The response must start with
                 yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n"
                 yield "data: [DONE]\n\n"
                 last_send_time = current_time
+                stream_finished = True
                 break

             new_choices = []
@@ -391,8 +393,8 @@ IMPORTANT: If calling tools, output ONLY the JSON. The response must start with
         except queue.Empty:
             continue

-    # 如果是超时退出,也发送结束标记
-    if has_content:
+    # 如果是超时退出且尚未发送结束标记,补发结束标记
+    if has_content and not stream_finished:
         prompt_tokens = len(final_prompt) // 4
         thinking_tokens = len(final_thinking) // 4
         completion_tokens = len(final_text) // 4
Reference in New Issue
Block a user