feat: Implement OpenAI-compatible tool calling by injecting tool prompts and parsing tool calls from model responses.

This commit is contained in:
CJACK
2026-02-01 16:05:10 +08:00
parent 1cf52502bb
commit dbae110d2b
3 changed files with 512 additions and 8 deletions

View File

@@ -632,6 +632,237 @@ class TestRunner:
"details": {"response_time": data["response_time"]}
}
# =====================================================================
# 工具调用测试
# =====================================================================
def test_openai_tool_calling(self) -> dict:
    """Exercise non-streaming tool calling on the OpenAI-compatible endpoint.

    Posts a chat completion request that advertises a single ``get_weather``
    function tool and inspects the first choice of the reply.

    Returns:
        A result dict with ``success`` and ``message`` keys, plus ``details``
        on most paths. The check passes both when the reply carries
        ``tool_calls`` and when the model answers in plain text, because the
        model may legitimately do either. It fails on a non-200 status or
        when the body contains an ``error`` key.
    """
    payload = {
        "model": "deepseek-chat",
        "messages": [
            {"role": "user", "content": "What's the weather in Beijing? Use the get_weather tool."}
        ],
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get current weather for a location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "City name"}
                    },
                    "required": ["location"]
                }
            }
        }],
        "stream": False
    }
    resp = requests.post(
        f"{self.endpoint}/v1/chat/completions",
        headers=self.get_headers(),
        json=payload,
        timeout=TEST_TIMEOUT
    )
    if resp.status_code != 200:
        return {"success": False, "message": f"状态码: {resp.status_code}", "details": {"response": resp.text}}

    data = resp.json()
    if "error" in data:
        # The error payload may be a dict; stringify it so "message" is always text.
        return {"success": False, "message": str(data["error"])}

    # "choices" may be missing, null or an empty list; fall back to an empty choice
    # instead of raising IndexError.
    choices = data.get("choices") or [{}]
    first_choice = choices[0] or {}
    message = first_choice.get("message") or {}
    tool_calls = message.get("tool_calls") or []
    finish_reason = first_choice.get("finish_reason", "")
    # "content" can be null (typically when tool calls are returned); normalise it
    # so the slicing below cannot fail.
    content = message.get("content") or ""

    # The model may call the tool or may answer directly; both count as success.
    if tool_calls:
        return {
            "success": True,
            "message": f"检测到 {len(tool_calls)} 个工具调用, finish_reason={finish_reason}",
            "details": {"tool_calls": tool_calls}
        }
    return {
        "success": True,
        "message": f"AI 直接回复而非调用工具: {content[:50]}...",
        "details": {"content": content[:100]}
    }
def test_openai_tool_calling_stream(self) -> dict:
    """Exercise streaming tool calling on the OpenAI-compatible endpoint.

    Posts a streaming chat completion request that advertises a single
    ``get_time`` function tool, then reads the server-sent ``data: `` lines
    until ``[DONE]``.

    Returns:
        A result dict with ``success`` and ``message`` keys. On a 200
        response it reports how many JSON chunks were received, whether any
        delta carried tool calls, and the last non-empty ``finish_reason``.
        It fails only on a non-200 status.
    """
    payload = {
        "model": "deepseek-chat",
        "messages": [
            {"role": "user", "content": "Use get_time tool to check current time in Tokyo."}
        ],
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_time",
                "description": "Get current time for a timezone",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "timezone": {"type": "string"}
                    },
                    "required": ["timezone"]
                }
            }
        }],
        "stream": True
    }

    chunk_count = 0
    tool_calls_found = False
    finish_reason = None

    # A streamed response holds its connection until it is closed; the context
    # manager releases it on the early return and on the [DONE] break as well.
    with requests.post(
        f"{self.endpoint}/v1/chat/completions",
        headers=self.get_headers(),
        json=payload,
        stream=True,
        timeout=TEST_TIMEOUT
    ) as resp:
        if resp.status_code != 200:
            return {"success": False, "message": f"状态码: {resp.status_code}"}

        for line in resp.iter_lines():
            if not line:
                continue
            line_str = line.decode("utf-8")
            if not line_str.startswith("data: "):
                continue
            data_str = line_str[6:]
            if data_str == "[DONE]":
                break
            try:
                chunk = json.loads(data_str)
            except json.JSONDecodeError:
                # Best effort: lines that are not valid JSON are ignored.
                continue
            chunk_count += 1

            # A chunk may carry an empty "choices" list; treat it as an empty
            # choice rather than raising IndexError.
            choices = chunk.get("choices") or [{}]
            first_choice = choices[0] or {}
            delta = first_choice.get("delta") or {}
            # Require a non-empty value so that "tool_calls": null is not counted.
            if delta.get("tool_calls"):
                tool_calls_found = True
            fr = first_choice.get("finish_reason")
            if fr:
                finish_reason = fr

    return {
        "success": True,
        "message": f"收到 {chunk_count} 个数据块, 工具调用: {tool_calls_found}, finish: {finish_reason}",
        "details": {"chunk_count": chunk_count, "tool_calls_found": tool_calls_found}
    }
def test_claude_tool_calling(self) -> dict:
    """Exercise non-streaming tool calling on the Claude-compatible endpoint.

    Posts a messages request that advertises a single ``calculator`` tool
    and splits the reply's content blocks into ``tool_use`` and ``text``.

    Returns:
        A result dict with ``success`` and ``message`` keys, plus ``details``
        on most paths. The check passes whether the reply contains
        ``tool_use`` blocks or only text. It fails on a non-200 status or
        when the body contains an ``error`` key.
    """
    calculator_tool = {
        "name": "calculator",
        "description": "Perform arithmetic calculations",
        "input_schema": {
            "type": "object",
            "properties": {
                "expression": {"type": "string", "description": "Math expression"}
            },
            "required": ["expression"]
        }
    }
    request_body = {
        "model": "claude-sonnet-4-20250514",
        "max_tokens": 200,
        "messages": [
            {"role": "user", "content": "Use the calculator tool to compute 15 * 23"}
        ],
        "tools": [calculator_tool],
        "stream": False
    }
    url = f"{self.endpoint}/anthropic/v1/messages"
    resp = requests.post(
        url,
        headers=self.get_headers(is_claude=True),
        json=request_body,
        timeout=TEST_TIMEOUT
    )

    # Guard clauses: transport-level failure first, then an API-level error body.
    if resp.status_code != 200:
        return {
            "success": False,
            "message": f"状态码: {resp.status_code}",
            "details": {"response": resp.text}
        }
    data = resp.json()
    if "error" in data:
        return {"success": False, "message": str(data["error"])}

    stop_reason = data.get("stop_reason", "")
    blocks = data.get("content", [])

    # Partition the content blocks by their declared type.
    tool_use_blocks = []
    text_blocks = []
    for block in blocks:
        block_type = block.get("type")
        if block_type == "tool_use":
            tool_use_blocks.append(block)
        elif block_type == "text":
            text_blocks.append(block)

    if not tool_use_blocks:
        # No tool call: report the concatenated text of the reply instead.
        text_content = "".join(block.get("text", "") for block in text_blocks)
        return {
            "success": True,
            "message": f"AI 直接回复: {text_content[:50]}...",
            "details": {"content": text_content[:100]}
        }
    return {
        "success": True,
        "message": f"检测到 {len(tool_use_blocks)} 个工具调用, stop_reason={stop_reason}",
        "details": {"tool_use": tool_use_blocks}
    }
# =====================================================================
# 搜索模式测试
# =====================================================================
def test_openai_search_mode(self) -> dict:
    """Exercise the streaming search model on the OpenAI-compatible endpoint.

    Posts a streaming chat completion request for ``deepseek-chat-search``
    and concatenates the ``content`` of every delta received before
    ``[DONE]``.

    Returns:
        A result dict with ``success`` and ``message`` keys. The check
        passes when at least one character of content was streamed. It
        fails on a non-200 status or when no content arrived.
    """
    payload = {
        "model": "deepseek-chat-search",
        "messages": [
            {"role": "user", "content": "今天的新闻有哪些?"}
        ],
        "stream": True
    }

    pieces = []
    # A streamed response holds its connection until it is closed; the context
    # manager releases it on the early return and on the [DONE] break as well.
    with requests.post(
        f"{self.endpoint}/v1/chat/completions",
        headers=self.get_headers(),
        json=payload,
        stream=True,
        timeout=TEST_TIMEOUT
    ) as resp:
        if resp.status_code != 200:
            return {"success": False, "message": f"状态码: {resp.status_code}"}

        for line in resp.iter_lines():
            if not line:
                continue
            line_str = line.decode("utf-8")
            if not line_str.startswith("data: "):
                continue
            data_str = line_str[6:]
            if data_str == "[DONE]":
                break
            try:
                chunk = json.loads(data_str)
            except json.JSONDecodeError:
                # Best effort: lines that are not valid JSON are ignored.
                continue

            # A chunk may carry an empty "choices" list; treat it as an empty
            # choice rather than raising IndexError.
            choices = chunk.get("choices") or [{}]
            delta = (choices[0] or {}).get("delta") or {}
            piece = delta.get("content")
            # "content" may be present but null; only text is accumulated.
            if isinstance(piece, str):
                pieces.append(piece)

    content = "".join(pieces)
    if not content:
        return {"success": False, "message": "搜索模式无响应内容"}
    return {
        "success": True,
        "message": f"搜索模式正常,收到 {len(content)} 字符",
        "details": {"content_preview": content[:100]}
    }
# =====================================================================
# 运行测试
# =====================================================================
@@ -672,6 +903,13 @@ class TestRunner:
if not quick:
self.run_test("多轮对话", self.test_multi_turn_conversation)
self.run_test("长输入处理", self.test_long_input)
self.run_test("OpenAI 搜索模式", self.test_openai_search_mode)
# 工具调用测试
if not quick:
self.run_test("OpenAI 工具调用", self.test_openai_tool_calling)
self.run_test("OpenAI 流式工具调用", self.test_openai_tool_calling_stream)
self.run_test("Claude 工具调用", self.test_claude_tool_calling)
# 管理 API 测试
self.run_test("管理配置 API", self.test_admin_config)

View File

@@ -451,6 +451,112 @@ class TestStreamParsing(unittest.TestCase):
self.assertTrue(check_response_started(response_fragment)) # RESPONSE 触发
class TestToolCallParsing(unittest.TestCase):
    """Unit tests for ``parse_tool_calls`` from ``core.sse_parser``."""

    @staticmethod
    def _parse(text, tools):
        """Run ``parse_tool_calls`` on ``text`` with the given tool list.

        The import is deferred to call time, so a failing import is reported
        by the individual test that triggered it.
        """
        from core.sse_parser import parse_tool_calls

        return parse_tool_calls(text, tools)

    def test_parse_tool_calls_simple(self):
        """A single tool call in a bare JSON object is parsed."""
        parsed = self._parse(
            '{"tool_calls": [{"name": "get_weather", "input": {"location": "Beijing"}}]}',
            [{"name": "get_weather"}],
        )
        self.assertEqual(len(parsed), 1)
        only_call = parsed[0]
        self.assertEqual(only_call["name"], "get_weather")
        self.assertEqual(only_call["input"]["location"], "Beijing")

    def test_parse_tool_calls_multiple(self):
        """Several tool calls in one JSON object are parsed in order."""
        raw = (
            '{"tool_calls": [\n'
            '{"name": "get_weather", "input": {"location": "Beijing"}},\n'
            '{"name": "get_time", "input": {"timezone": "Asia/Shanghai"}}\n'
            ']}'
        )
        parsed = self._parse(raw, [{"name": "get_weather"}, {"name": "get_time"}])
        self.assertEqual(len(parsed), 2)
        self.assertEqual(parsed[0]["name"], "get_weather")
        self.assertEqual(parsed[1]["name"], "get_time")

    def test_parse_tool_calls_no_match(self):
        """Plain text without any tool call yields an empty list."""
        parsed = self._parse(
            "这是一个普通的回复,没有工具调用。",
            [{"name": "get_weather"}],
        )
        self.assertEqual(parsed, [])

    def test_parse_tool_calls_with_surrounding_text(self):
        """A tool call preceded by ordinary text is still found."""
        raw = (
            '好的,我来帮你查询天气。\n'
            '{"tool_calls": [{"name": "get_weather", "input": {"location": "Shanghai"}}]}'
        )
        parsed = self._parse(raw, [{"name": "get_weather"}])
        self.assertEqual(len(parsed), 1)
        self.assertEqual(parsed[0]["name"], "get_weather")

    def test_parse_tool_calls_empty_input(self):
        """An empty tool list yields an empty list, with or without text."""
        self.assertEqual(self._parse("", []), [])
        self.assertEqual(self._parse("some text", []), [])

    def test_parse_tool_calls_invalid_json(self):
        """Malformed JSON yields an empty list."""
        parsed = self._parse(
            '{"tool_calls": [{"name": "get_weather", invalid json here}',
            [{"name": "get_weather"}],
        )
        # The parser is expected to return an empty list rather than raise.
        self.assertEqual(parsed, [])
class TestTokenEstimation(unittest.TestCase):
    """Unit tests for ``estimate_tokens`` from ``core.utils``."""

    def test_estimate_tokens_string(self):
        """String input: eight characters give 2, the empty string gives 1."""
        from core.utils import estimate_tokens

        # Eight characters are expected to come out as two tokens.
        self.assertEqual(estimate_tokens("12345678"), 2)
        # The empty string is expected to come out as one.
        self.assertEqual(estimate_tokens(""), 1)

    def test_estimate_tokens_list(self):
        """List input: a list of text blocks gives a positive estimate."""
        from core.utils import estimate_tokens

        blocks = [{"text": "Hello"}, {"text": "World"}]
        self.assertGreater(estimate_tokens(blocks), 0)
if __name__ == "__main__":
# 设置环境变量避免配置警告
os.environ.setdefault("DS2API_CONFIG_PATH",