mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-05 00:45:29 +08:00
feat: Implement OpenAI-compatible tool calling by injecting tool prompts and parsing tool calls from model responses.
This commit is contained in:
176
routes/openai.py
176
routes/openai.py
@@ -2,6 +2,7 @@
|
||||
"""OpenAI 兼容路由"""
|
||||
import json
|
||||
import queue
|
||||
import random
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
@@ -29,6 +30,7 @@ from core.sse_parser import (
|
||||
extract_content_from_chunk,
|
||||
extract_content_recursive,
|
||||
should_filter_citation,
|
||||
parse_tool_calls,
|
||||
)
|
||||
from core.constants import (
|
||||
KEEP_ALIVE_TIMEOUT,
|
||||
@@ -83,6 +85,59 @@ async def chat_completions(request: Request):
|
||||
status_code=400, detail="Request must include 'model' and 'messages'."
|
||||
)
|
||||
|
||||
# 解析工具调用参数(OpenAI 格式)
|
||||
tools_requested = req_data.get("tools") or []
|
||||
has_tools = len(tools_requested) > 0
|
||||
|
||||
# 如果有工具定义,构建工具提示并注入到消息中
|
||||
messages_with_tools = messages.copy()
|
||||
if has_tools:
|
||||
tool_schemas = []
|
||||
for tool in tools_requested:
|
||||
# OpenAI 格式: {"type": "function", "function": {"name": ..., "description": ..., "parameters": ...}}
|
||||
func = tool.get("function", tool) # 兼容简化格式
|
||||
tool_name = func.get("name", "unknown")
|
||||
tool_desc = func.get("description", "No description available")
|
||||
schema = func.get("parameters", {})
|
||||
|
||||
tool_info = f"Tool: {tool_name}\nDescription: {tool_desc}"
|
||||
if "properties" in schema:
|
||||
props = []
|
||||
required = schema.get("required", [])
|
||||
for prop_name, prop_info in schema["properties"].items():
|
||||
prop_type = prop_info.get("type", "string")
|
||||
is_req = " (required)" if prop_name in required else ""
|
||||
props.append(f" - {prop_name}: {prop_type}{is_req}")
|
||||
if props:
|
||||
tool_info += f"\nParameters:\n" + "\n".join(props)
|
||||
tool_schemas.append(tool_info)
|
||||
|
||||
# 检查是否已有系统消息
|
||||
has_system = any(m.get("role") == "system" for m in messages_with_tools)
|
||||
tool_prompt = f"""You have access to these tools:
|
||||
|
||||
{chr(10).join(tool_schemas)}
|
||||
|
||||
When you need to use tools, output ONLY this JSON format (no other text):
|
||||
{{"tool_calls": [
|
||||
{{"name": "tool_name", "input": {{"param": "value"}}}}
|
||||
]}}
|
||||
|
||||
IMPORTANT: If calling tools, output ONLY the JSON. The response must start with {{ and end with }}"""
|
||||
|
||||
if has_system:
|
||||
# 追加到现有系统消息
|
||||
for i, m in enumerate(messages_with_tools):
|
||||
if m.get("role") == "system":
|
||||
messages_with_tools[i] = {
|
||||
"role": "system",
|
||||
"content": m.get("content", "") + "\n\n" + tool_prompt
|
||||
}
|
||||
break
|
||||
else:
|
||||
# 添加新的系统消息
|
||||
messages_with_tools.insert(0, {"role": "system", "content": tool_prompt})
|
||||
|
||||
# 使用会话管理器获取模型配置
|
||||
thinking_enabled, search_enabled = get_model_config(model)
|
||||
if thinking_enabled is None:
|
||||
@@ -90,8 +145,8 @@ async def chat_completions(request: Request):
|
||||
status_code=503, detail=f"Model '{model}' is not available."
|
||||
)
|
||||
|
||||
# 使用 messages_prepare 函数构造最终 prompt
|
||||
final_prompt = messages_prepare(messages)
|
||||
# 使用 messages_prepare 函数构造最终 prompt(使用带工具提示的消息)
|
||||
final_prompt = messages_prepare(messages_with_tools)
|
||||
session_id = create_session(request)
|
||||
if not session_id:
|
||||
raise HTTPException(status_code=401, detail="invalid token.")
|
||||
@@ -255,12 +310,42 @@ async def chat_completions(request: Request):
|
||||
"total_tokens": prompt_tokens + thinking_tokens + completion_tokens,
|
||||
"completion_tokens_details": {"reasoning_tokens": thinking_tokens},
|
||||
}
|
||||
|
||||
# 检测工具调用
|
||||
detected_tools = []
|
||||
finish_reason = "stop"
|
||||
if has_tools:
|
||||
detected_tools = parse_tool_calls(final_text, [{"name": t.get("function", t).get("name")} for t in tools_requested])
|
||||
if detected_tools:
|
||||
finish_reason = "tool_calls"
|
||||
|
||||
if detected_tools:
|
||||
# 发送工具调用响应
|
||||
tool_calls_data = []
|
||||
for idx, tool_info in enumerate(detected_tools):
|
||||
tool_calls_data.append({
|
||||
"id": f"call_{int(time.time())}_{random.randint(1000,9999)}_{idx}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_info["name"],
|
||||
"arguments": json.dumps(tool_info["input"], ensure_ascii=False)
|
||||
}
|
||||
})
|
||||
tool_chunk = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
"choices": [{"delta": {"tool_calls": tool_calls_data}, "index": 0}],
|
||||
}
|
||||
yield f"data: {json.dumps(tool_chunk, ensure_ascii=False)}\n\n"
|
||||
|
||||
finish_chunk = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
"choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}],
|
||||
"choices": [{"delta": {}, "index": 0, "finish_reason": finish_reason}],
|
||||
"usage": usage,
|
||||
}
|
||||
yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n"
|
||||
@@ -325,12 +410,41 @@ async def chat_completions(request: Request):
|
||||
"total_tokens": prompt_tokens + thinking_tokens + completion_tokens,
|
||||
"completion_tokens_details": {"reasoning_tokens": thinking_tokens},
|
||||
}
|
||||
|
||||
# 检测工具调用
|
||||
detected_tools = []
|
||||
finish_reason = "stop"
|
||||
if has_tools:
|
||||
detected_tools = parse_tool_calls(final_text, [{"name": t.get("function", t).get("name")} for t in tools_requested])
|
||||
if detected_tools:
|
||||
finish_reason = "tool_calls"
|
||||
|
||||
if detected_tools:
|
||||
tool_calls_data = []
|
||||
for idx, tool_info in enumerate(detected_tools):
|
||||
tool_calls_data.append({
|
||||
"id": f"call_{int(time.time())}_{random.randint(1000,9999)}_{idx}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_info["name"],
|
||||
"arguments": json.dumps(tool_info["input"], ensure_ascii=False)
|
||||
}
|
||||
})
|
||||
tool_chunk = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
"choices": [{"delta": {"tool_calls": tool_calls_data}, "index": 0}],
|
||||
}
|
||||
yield f"data: {json.dumps(tool_chunk, ensure_ascii=False)}\n\n"
|
||||
|
||||
finish_chunk = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
"choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}],
|
||||
"choices": [{"delta": {}, "index": 0, "finish_reason": finish_reason}],
|
||||
"usage": usage,
|
||||
}
|
||||
yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n"
|
||||
@@ -401,14 +515,37 @@ async def chat_completions(request: Request):
|
||||
prompt_tokens = len(final_prompt) // 4
|
||||
reasoning_tokens = len(final_reasoning) // 4
|
||||
completion_tokens = len(final_content) // 4
|
||||
|
||||
# 检测工具调用
|
||||
detected_tools = []
|
||||
finish_reason = "stop"
|
||||
if has_tools:
|
||||
detected_tools = parse_tool_calls(final_content, [{"name": t.get("function", t).get("name")} for t in tools_requested])
|
||||
if detected_tools:
|
||||
finish_reason = "tool_calls"
|
||||
|
||||
# 构建 message 对象
|
||||
message_obj = {
|
||||
"role": "assistant",
|
||||
"content": final_content,
|
||||
"content": final_content if not detected_tools else None,
|
||||
}
|
||||
# 只有启用思考模式时才包含 reasoning_content
|
||||
if thinking_enabled and final_reasoning:
|
||||
message_obj["reasoning_content"] = final_reasoning
|
||||
# 添加工具调用
|
||||
if detected_tools:
|
||||
tool_calls_data = []
|
||||
for idx, tool_info in enumerate(detected_tools):
|
||||
tool_calls_data.append({
|
||||
"id": f"call_{int(time.time())}_{random.randint(1000,9999)}_{idx}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_info["name"],
|
||||
"arguments": json.dumps(tool_info["input"], ensure_ascii=False)
|
||||
}
|
||||
})
|
||||
message_obj["tool_calls"] = tool_calls_data
|
||||
message_obj["content"] = None
|
||||
|
||||
result = {
|
||||
"id": completion_id,
|
||||
@@ -418,7 +555,7 @@ async def chat_completions(request: Request):
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": message_obj,
|
||||
"finish_reason": "stop",
|
||||
"finish_reason": finish_reason,
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
@@ -452,14 +589,37 @@ async def chat_completions(request: Request):
|
||||
prompt_tokens = len(final_prompt) // 4
|
||||
reasoning_tokens = len(final_reasoning) // 4
|
||||
completion_tokens = len(final_content) // 4
|
||||
|
||||
# 检测工具调用
|
||||
detected_tools = []
|
||||
finish_reason = "stop"
|
||||
if has_tools:
|
||||
detected_tools = parse_tool_calls(final_content, [{"name": t.get("function", t).get("name")} for t in tools_requested])
|
||||
if detected_tools:
|
||||
finish_reason = "tool_calls"
|
||||
|
||||
# 构建 message 对象
|
||||
message_obj = {
|
||||
"role": "assistant",
|
||||
"content": final_content,
|
||||
"content": final_content if not detected_tools else None,
|
||||
}
|
||||
# 只有启用思考模式时才包含 reasoning_content
|
||||
if thinking_enabled and final_reasoning:
|
||||
message_obj["reasoning_content"] = final_reasoning
|
||||
# 添加工具调用
|
||||
if detected_tools:
|
||||
tool_calls_data = []
|
||||
for idx, tool_info in enumerate(detected_tools):
|
||||
tool_calls_data.append({
|
||||
"id": f"call_{int(time.time())}_{random.randint(1000,9999)}_{idx}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_info["name"],
|
||||
"arguments": json.dumps(tool_info["input"], ensure_ascii=False)
|
||||
}
|
||||
})
|
||||
message_obj["tool_calls"] = tool_calls_data
|
||||
message_obj["content"] = None
|
||||
|
||||
result = {
|
||||
"id": completion_id,
|
||||
@@ -469,7 +629,7 @@ async def chat_completions(request: Request):
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": message_obj,
|
||||
"finish_reason": "stop",
|
||||
"finish_reason": finish_reason,
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
|
||||
@@ -632,6 +632,237 @@ class TestRunner:
|
||||
"details": {"response_time": data["response_time"]}
|
||||
}
|
||||
|
||||
# =====================================================================
|
||||
# 工具调用测试
|
||||
# =====================================================================
|
||||
|
||||
def test_openai_tool_calling(self) -> dict:
|
||||
"""测试 OpenAI 工具调用"""
|
||||
payload = {
|
||||
"model": "deepseek-chat",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather in Beijing? Use the get_weather tool."}
|
||||
],
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get current weather for a location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "City name"}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}
|
||||
}],
|
||||
"stream": False
|
||||
}
|
||||
|
||||
resp = requests.post(
|
||||
f"{self.endpoint}/v1/chat/completions",
|
||||
headers=self.get_headers(),
|
||||
json=payload,
|
||||
timeout=TEST_TIMEOUT
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
return {"success": False, "message": f"状态码: {resp.status_code}", "details": {"response": resp.text}}
|
||||
|
||||
data = resp.json()
|
||||
if "error" in data:
|
||||
return {"success": False, "message": data["error"]}
|
||||
|
||||
message = data.get("choices", [{}])[0].get("message", {})
|
||||
tool_calls = message.get("tool_calls", [])
|
||||
finish_reason = data.get("choices", [{}])[0].get("finish_reason", "")
|
||||
content = message.get("content", "")
|
||||
|
||||
# AI 可能调用工具,也可能直接回复
|
||||
if tool_calls:
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"检测到 {len(tool_calls)} 个工具调用, finish_reason={finish_reason}",
|
||||
"details": {"tool_calls": tool_calls}
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"AI 直接回复而非调用工具: {content[:50]}...",
|
||||
"details": {"content": content[:100]}
|
||||
}
|
||||
|
||||
def test_openai_tool_calling_stream(self) -> dict:
|
||||
"""测试 OpenAI 流式工具调用"""
|
||||
payload = {
|
||||
"model": "deepseek-chat",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Use get_time tool to check current time in Tokyo."}
|
||||
],
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_time",
|
||||
"description": "Get current time for a timezone",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timezone": {"type": "string"}
|
||||
},
|
||||
"required": ["timezone"]
|
||||
}
|
||||
}
|
||||
}],
|
||||
"stream": True
|
||||
}
|
||||
|
||||
resp = requests.post(
|
||||
f"{self.endpoint}/v1/chat/completions",
|
||||
headers=self.get_headers(),
|
||||
json=payload,
|
||||
stream=True,
|
||||
timeout=TEST_TIMEOUT
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
return {"success": False, "message": f"状态码: {resp.status_code}"}
|
||||
|
||||
chunks = []
|
||||
tool_calls_found = False
|
||||
finish_reason = None
|
||||
|
||||
for line in resp.iter_lines():
|
||||
if line:
|
||||
line_str = line.decode("utf-8")
|
||||
if line_str.startswith("data: "):
|
||||
data_str = line_str[6:]
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data_str)
|
||||
chunks.append(chunk)
|
||||
delta = chunk.get("choices", [{}])[0].get("delta", {})
|
||||
if "tool_calls" in delta:
|
||||
tool_calls_found = True
|
||||
fr = chunk.get("choices", [{}])[0].get("finish_reason")
|
||||
if fr:
|
||||
finish_reason = fr
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"收到 {len(chunks)} 个数据块, 工具调用: {tool_calls_found}, finish: {finish_reason}",
|
||||
"details": {"chunk_count": len(chunks), "tool_calls_found": tool_calls_found}
|
||||
}
|
||||
|
||||
def test_claude_tool_calling(self) -> dict:
|
||||
"""测试 Claude 工具调用"""
|
||||
payload = {
|
||||
"model": "claude-sonnet-4-20250514",
|
||||
"max_tokens": 200,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Use the calculator tool to compute 15 * 23"}
|
||||
],
|
||||
"tools": [{
|
||||
"name": "calculator",
|
||||
"description": "Perform arithmetic calculations",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {"type": "string", "description": "Math expression"}
|
||||
},
|
||||
"required": ["expression"]
|
||||
}
|
||||
}],
|
||||
"stream": False
|
||||
}
|
||||
|
||||
resp = requests.post(
|
||||
f"{self.endpoint}/anthropic/v1/messages",
|
||||
headers=self.get_headers(is_claude=True),
|
||||
json=payload,
|
||||
timeout=TEST_TIMEOUT
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
return {"success": False, "message": f"状态码: {resp.status_code}", "details": {"response": resp.text}}
|
||||
|
||||
data = resp.json()
|
||||
if "error" in data:
|
||||
return {"success": False, "message": str(data["error"])}
|
||||
|
||||
content_blocks = data.get("content", [])
|
||||
stop_reason = data.get("stop_reason", "")
|
||||
|
||||
tool_use_blocks = [b for b in content_blocks if b.get("type") == "tool_use"]
|
||||
text_blocks = [b for b in content_blocks if b.get("type") == "text"]
|
||||
|
||||
if tool_use_blocks:
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"检测到 {len(tool_use_blocks)} 个工具调用, stop_reason={stop_reason}",
|
||||
"details": {"tool_use": tool_use_blocks}
|
||||
}
|
||||
else:
|
||||
text_content = "".join(b.get("text", "") for b in text_blocks)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"AI 直接回复: {text_content[:50]}...",
|
||||
"details": {"content": text_content[:100]}
|
||||
}
|
||||
|
||||
# =====================================================================
|
||||
# 搜索模式测试
|
||||
# =====================================================================
|
||||
|
||||
def test_openai_search_mode(self) -> dict:
|
||||
"""测试 OpenAI 搜索模式"""
|
||||
payload = {
|
||||
"model": "deepseek-chat-search",
|
||||
"messages": [
|
||||
{"role": "user", "content": "今天的新闻有哪些?"}
|
||||
],
|
||||
"stream": True
|
||||
}
|
||||
|
||||
resp = requests.post(
|
||||
f"{self.endpoint}/v1/chat/completions",
|
||||
headers=self.get_headers(),
|
||||
json=payload,
|
||||
stream=True,
|
||||
timeout=TEST_TIMEOUT
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
return {"success": False, "message": f"状态码: {resp.status_code}"}
|
||||
|
||||
content = ""
|
||||
for line in resp.iter_lines():
|
||||
if line:
|
||||
line_str = line.decode("utf-8")
|
||||
if line_str.startswith("data: "):
|
||||
data_str = line_str[6:]
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data_str)
|
||||
delta = chunk.get("choices", [{}])[0].get("delta", {})
|
||||
if "content" in delta:
|
||||
content += delta["content"]
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
if not content:
|
||||
return {"success": False, "message": "搜索模式无响应内容"}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"搜索模式正常,收到 {len(content)} 字符",
|
||||
"details": {"content_preview": content[:100]}
|
||||
}
|
||||
|
||||
# =====================================================================
|
||||
# 运行测试
|
||||
# =====================================================================
|
||||
@@ -672,6 +903,13 @@ class TestRunner:
|
||||
if not quick:
|
||||
self.run_test("多轮对话", self.test_multi_turn_conversation)
|
||||
self.run_test("长输入处理", self.test_long_input)
|
||||
self.run_test("OpenAI 搜索模式", self.test_openai_search_mode)
|
||||
|
||||
# 工具调用测试
|
||||
if not quick:
|
||||
self.run_test("OpenAI 工具调用", self.test_openai_tool_calling)
|
||||
self.run_test("OpenAI 流式工具调用", self.test_openai_tool_calling_stream)
|
||||
self.run_test("Claude 工具调用", self.test_claude_tool_calling)
|
||||
|
||||
# 管理 API 测试
|
||||
self.run_test("管理配置 API", self.test_admin_config)
|
||||
|
||||
@@ -451,6 +451,112 @@ class TestStreamParsing(unittest.TestCase):
|
||||
self.assertTrue(check_response_started(response_fragment)) # RESPONSE 触发
|
||||
|
||||
|
||||
class TestToolCallParsing(unittest.TestCase):
|
||||
"""工具调用解析测试"""
|
||||
|
||||
def test_parse_tool_calls_simple(self):
|
||||
"""测试简单工具调用解析"""
|
||||
from core.sse_parser import parse_tool_calls
|
||||
|
||||
response_text = '{"tool_calls": [{"name": "get_weather", "input": {"location": "Beijing"}}]}'
|
||||
tools = [{"name": "get_weather"}]
|
||||
|
||||
result = parse_tool_calls(response_text, tools)
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0]["name"], "get_weather")
|
||||
self.assertEqual(result[0]["input"]["location"], "Beijing")
|
||||
|
||||
def test_parse_tool_calls_multiple(self):
|
||||
"""测试多工具调用解析"""
|
||||
from core.sse_parser import parse_tool_calls
|
||||
|
||||
response_text = '''{"tool_calls": [
|
||||
{"name": "get_weather", "input": {"location": "Beijing"}},
|
||||
{"name": "get_time", "input": {"timezone": "Asia/Shanghai"}}
|
||||
]}'''
|
||||
tools = [{"name": "get_weather"}, {"name": "get_time"}]
|
||||
|
||||
result = parse_tool_calls(response_text, tools)
|
||||
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result[0]["name"], "get_weather")
|
||||
self.assertEqual(result[1]["name"], "get_time")
|
||||
|
||||
def test_parse_tool_calls_no_match(self):
|
||||
"""测试无工具调用时返回空列表"""
|
||||
from core.sse_parser import parse_tool_calls
|
||||
|
||||
response_text = "这是一个普通的回复,没有工具调用。"
|
||||
tools = [{"name": "get_weather"}]
|
||||
|
||||
result = parse_tool_calls(response_text, tools)
|
||||
|
||||
self.assertEqual(result, [])
|
||||
|
||||
def test_parse_tool_calls_with_surrounding_text(self):
|
||||
"""测试带有周围文本的工具调用"""
|
||||
from core.sse_parser import parse_tool_calls
|
||||
|
||||
response_text = '''好的,我来帮你查询天气。
|
||||
{"tool_calls": [{"name": "get_weather", "input": {"location": "Shanghai"}}]}'''
|
||||
tools = [{"name": "get_weather"}]
|
||||
|
||||
result = parse_tool_calls(response_text, tools)
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0]["name"], "get_weather")
|
||||
|
||||
def test_parse_tool_calls_empty_input(self):
|
||||
"""测试空输入"""
|
||||
from core.sse_parser import parse_tool_calls
|
||||
|
||||
result = parse_tool_calls("", [])
|
||||
self.assertEqual(result, [])
|
||||
|
||||
result = parse_tool_calls("some text", [])
|
||||
self.assertEqual(result, [])
|
||||
|
||||
def test_parse_tool_calls_invalid_json(self):
|
||||
"""测试无效 JSON"""
|
||||
from core.sse_parser import parse_tool_calls
|
||||
|
||||
response_text = '{"tool_calls": [{"name": "get_weather", invalid json here}'
|
||||
tools = [{"name": "get_weather"}]
|
||||
|
||||
result = parse_tool_calls(response_text, tools)
|
||||
|
||||
# 应该返回空列表而不是抛出异常
|
||||
self.assertEqual(result, [])
|
||||
|
||||
|
||||
class TestTokenEstimation(unittest.TestCase):
|
||||
"""Token 估算测试"""
|
||||
|
||||
def test_estimate_tokens_string(self):
|
||||
"""测试字符串 token 估算"""
|
||||
from core.utils import estimate_tokens
|
||||
|
||||
# 8个字符应该约等于2个token
|
||||
result = estimate_tokens("12345678")
|
||||
self.assertEqual(result, 2)
|
||||
|
||||
# 空字符串应该返回1
|
||||
result = estimate_tokens("")
|
||||
self.assertEqual(result, 1)
|
||||
|
||||
def test_estimate_tokens_list(self):
|
||||
"""测试列表 token 估算"""
|
||||
from core.utils import estimate_tokens
|
||||
|
||||
content = [
|
||||
{"text": "Hello"},
|
||||
{"text": "World"}
|
||||
]
|
||||
result = estimate_tokens(content)
|
||||
self.assertGreater(result, 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 设置环境变量避免配置警告
|
||||
os.environ.setdefault("DS2API_CONFIG_PATH",
|
||||
|
||||
Reference in New Issue
Block a user