feat: Implement OpenAI-compatible tool calling by injecting tool prompts and parsing tool calls from model responses.

This commit is contained in:
CJACK
2026-02-01 16:05:10 +08:00
parent 1cf52502bb
commit dbae110d2b
3 changed files with 512 additions and 8 deletions

View File

@@ -451,6 +451,112 @@ class TestStreamParsing(unittest.TestCase):
self.assertTrue(check_response_started(response_fragment)) # RESPONSE 触发
class TestToolCallParsing(unittest.TestCase):
"""工具调用解析测试"""
def test_parse_tool_calls_simple(self):
"""测试简单工具调用解析"""
from core.sse_parser import parse_tool_calls
response_text = '{"tool_calls": [{"name": "get_weather", "input": {"location": "Beijing"}}]}'
tools = [{"name": "get_weather"}]
result = parse_tool_calls(response_text, tools)
self.assertEqual(len(result), 1)
self.assertEqual(result[0]["name"], "get_weather")
self.assertEqual(result[0]["input"]["location"], "Beijing")
def test_parse_tool_calls_multiple(self):
"""测试多工具调用解析"""
from core.sse_parser import parse_tool_calls
response_text = '''{"tool_calls": [
{"name": "get_weather", "input": {"location": "Beijing"}},
{"name": "get_time", "input": {"timezone": "Asia/Shanghai"}}
]}'''
tools = [{"name": "get_weather"}, {"name": "get_time"}]
result = parse_tool_calls(response_text, tools)
self.assertEqual(len(result), 2)
self.assertEqual(result[0]["name"], "get_weather")
self.assertEqual(result[1]["name"], "get_time")
def test_parse_tool_calls_no_match(self):
"""测试无工具调用时返回空列表"""
from core.sse_parser import parse_tool_calls
response_text = "这是一个普通的回复,没有工具调用。"
tools = [{"name": "get_weather"}]
result = parse_tool_calls(response_text, tools)
self.assertEqual(result, [])
def test_parse_tool_calls_with_surrounding_text(self):
"""测试带有周围文本的工具调用"""
from core.sse_parser import parse_tool_calls
response_text = '''好的,我来帮你查询天气。
{"tool_calls": [{"name": "get_weather", "input": {"location": "Shanghai"}}]}'''
tools = [{"name": "get_weather"}]
result = parse_tool_calls(response_text, tools)
self.assertEqual(len(result), 1)
self.assertEqual(result[0]["name"], "get_weather")
def test_parse_tool_calls_empty_input(self):
"""测试空输入"""
from core.sse_parser import parse_tool_calls
result = parse_tool_calls("", [])
self.assertEqual(result, [])
result = parse_tool_calls("some text", [])
self.assertEqual(result, [])
def test_parse_tool_calls_invalid_json(self):
"""测试无效 JSON"""
from core.sse_parser import parse_tool_calls
response_text = '{"tool_calls": [{"name": "get_weather", invalid json here}'
tools = [{"name": "get_weather"}]
result = parse_tool_calls(response_text, tools)
# 应该返回空列表而不是抛出异常
self.assertEqual(result, [])
class TestTokenEstimation(unittest.TestCase):
"""Token 估算测试"""
def test_estimate_tokens_string(self):
"""测试字符串 token 估算"""
from core.utils import estimate_tokens
# 8个字符应该约等于2个token
result = estimate_tokens("12345678")
self.assertEqual(result, 2)
# 空字符串应该返回1
result = estimate_tokens("")
self.assertEqual(result, 1)
def test_estimate_tokens_list(self):
"""测试列表 token 估算"""
from core.utils import estimate_tokens
content = [
{"text": "Hello"},
{"text": "World"}
]
result = estimate_tokens(content)
self.assertGreater(result, 0)
if __name__ == "__main__":
# 设置环境变量避免配置警告
os.environ.setdefault("DS2API_CONFIG_PATH",