mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-13 20:57:41 +08:00
feat: Implement OpenAI-compatible tool calling by injecting tool prompts and parsing tool calls from model responses.
This commit is contained in:
@@ -451,6 +451,112 @@ class TestStreamParsing(unittest.TestCase):
|
||||
self.assertTrue(check_response_started(response_fragment)) # RESPONSE 触发
|
||||
|
||||
|
||||
class TestToolCallParsing(unittest.TestCase):
|
||||
"""工具调用解析测试"""
|
||||
|
||||
def test_parse_tool_calls_simple(self):
|
||||
"""测试简单工具调用解析"""
|
||||
from core.sse_parser import parse_tool_calls
|
||||
|
||||
response_text = '{"tool_calls": [{"name": "get_weather", "input": {"location": "Beijing"}}]}'
|
||||
tools = [{"name": "get_weather"}]
|
||||
|
||||
result = parse_tool_calls(response_text, tools)
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0]["name"], "get_weather")
|
||||
self.assertEqual(result[0]["input"]["location"], "Beijing")
|
||||
|
||||
def test_parse_tool_calls_multiple(self):
|
||||
"""测试多工具调用解析"""
|
||||
from core.sse_parser import parse_tool_calls
|
||||
|
||||
response_text = '''{"tool_calls": [
|
||||
{"name": "get_weather", "input": {"location": "Beijing"}},
|
||||
{"name": "get_time", "input": {"timezone": "Asia/Shanghai"}}
|
||||
]}'''
|
||||
tools = [{"name": "get_weather"}, {"name": "get_time"}]
|
||||
|
||||
result = parse_tool_calls(response_text, tools)
|
||||
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result[0]["name"], "get_weather")
|
||||
self.assertEqual(result[1]["name"], "get_time")
|
||||
|
||||
def test_parse_tool_calls_no_match(self):
|
||||
"""测试无工具调用时返回空列表"""
|
||||
from core.sse_parser import parse_tool_calls
|
||||
|
||||
response_text = "这是一个普通的回复,没有工具调用。"
|
||||
tools = [{"name": "get_weather"}]
|
||||
|
||||
result = parse_tool_calls(response_text, tools)
|
||||
|
||||
self.assertEqual(result, [])
|
||||
|
||||
def test_parse_tool_calls_with_surrounding_text(self):
|
||||
"""测试带有周围文本的工具调用"""
|
||||
from core.sse_parser import parse_tool_calls
|
||||
|
||||
response_text = '''好的,我来帮你查询天气。
|
||||
{"tool_calls": [{"name": "get_weather", "input": {"location": "Shanghai"}}]}'''
|
||||
tools = [{"name": "get_weather"}]
|
||||
|
||||
result = parse_tool_calls(response_text, tools)
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0]["name"], "get_weather")
|
||||
|
||||
def test_parse_tool_calls_empty_input(self):
|
||||
"""测试空输入"""
|
||||
from core.sse_parser import parse_tool_calls
|
||||
|
||||
result = parse_tool_calls("", [])
|
||||
self.assertEqual(result, [])
|
||||
|
||||
result = parse_tool_calls("some text", [])
|
||||
self.assertEqual(result, [])
|
||||
|
||||
def test_parse_tool_calls_invalid_json(self):
|
||||
"""测试无效 JSON"""
|
||||
from core.sse_parser import parse_tool_calls
|
||||
|
||||
response_text = '{"tool_calls": [{"name": "get_weather", invalid json here}'
|
||||
tools = [{"name": "get_weather"}]
|
||||
|
||||
result = parse_tool_calls(response_text, tools)
|
||||
|
||||
# 应该返回空列表而不是抛出异常
|
||||
self.assertEqual(result, [])
|
||||
|
||||
|
||||
class TestTokenEstimation(unittest.TestCase):
|
||||
"""Token 估算测试"""
|
||||
|
||||
def test_estimate_tokens_string(self):
|
||||
"""测试字符串 token 估算"""
|
||||
from core.utils import estimate_tokens
|
||||
|
||||
# 8个字符应该约等于2个token
|
||||
result = estimate_tokens("12345678")
|
||||
self.assertEqual(result, 2)
|
||||
|
||||
# 空字符串应该返回1
|
||||
result = estimate_tokens("")
|
||||
self.assertEqual(result, 1)
|
||||
|
||||
def test_estimate_tokens_list(self):
|
||||
"""测试列表 token 估算"""
|
||||
from core.utils import estimate_tokens
|
||||
|
||||
content = [
|
||||
{"text": "Hello"},
|
||||
{"text": "World"}
|
||||
]
|
||||
result = estimate_tokens(content)
|
||||
self.assertGreater(result, 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 设置环境变量避免配置警告
|
||||
os.environ.setdefault("DS2API_CONFIG_PATH",
|
||||
|
||||
Reference in New Issue
Block a user