mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-18 23:25:10 +08:00
feat: Implement OpenAI-compatible tool calling by injecting tool prompts and parsing tool calls from model responses.
This commit is contained in:
@@ -632,6 +632,237 @@ class TestRunner:
|
||||
"details": {"response_time": data["response_time"]}
|
||||
}
|
||||
|
||||
# =====================================================================
|
||||
# 工具调用测试
|
||||
# =====================================================================
|
||||
|
||||
def test_openai_tool_calling(self) -> dict:
|
||||
"""测试 OpenAI 工具调用"""
|
||||
payload = {
|
||||
"model": "deepseek-chat",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather in Beijing? Use the get_weather tool."}
|
||||
],
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get current weather for a location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "City name"}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}
|
||||
}],
|
||||
"stream": False
|
||||
}
|
||||
|
||||
resp = requests.post(
|
||||
f"{self.endpoint}/v1/chat/completions",
|
||||
headers=self.get_headers(),
|
||||
json=payload,
|
||||
timeout=TEST_TIMEOUT
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
return {"success": False, "message": f"状态码: {resp.status_code}", "details": {"response": resp.text}}
|
||||
|
||||
data = resp.json()
|
||||
if "error" in data:
|
||||
return {"success": False, "message": data["error"]}
|
||||
|
||||
message = data.get("choices", [{}])[0].get("message", {})
|
||||
tool_calls = message.get("tool_calls", [])
|
||||
finish_reason = data.get("choices", [{}])[0].get("finish_reason", "")
|
||||
content = message.get("content", "")
|
||||
|
||||
# AI 可能调用工具,也可能直接回复
|
||||
if tool_calls:
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"检测到 {len(tool_calls)} 个工具调用, finish_reason={finish_reason}",
|
||||
"details": {"tool_calls": tool_calls}
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"AI 直接回复而非调用工具: {content[:50]}...",
|
||||
"details": {"content": content[:100]}
|
||||
}
|
||||
|
||||
def test_openai_tool_calling_stream(self) -> dict:
|
||||
"""测试 OpenAI 流式工具调用"""
|
||||
payload = {
|
||||
"model": "deepseek-chat",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Use get_time tool to check current time in Tokyo."}
|
||||
],
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_time",
|
||||
"description": "Get current time for a timezone",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timezone": {"type": "string"}
|
||||
},
|
||||
"required": ["timezone"]
|
||||
}
|
||||
}
|
||||
}],
|
||||
"stream": True
|
||||
}
|
||||
|
||||
resp = requests.post(
|
||||
f"{self.endpoint}/v1/chat/completions",
|
||||
headers=self.get_headers(),
|
||||
json=payload,
|
||||
stream=True,
|
||||
timeout=TEST_TIMEOUT
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
return {"success": False, "message": f"状态码: {resp.status_code}"}
|
||||
|
||||
chunks = []
|
||||
tool_calls_found = False
|
||||
finish_reason = None
|
||||
|
||||
for line in resp.iter_lines():
|
||||
if line:
|
||||
line_str = line.decode("utf-8")
|
||||
if line_str.startswith("data: "):
|
||||
data_str = line_str[6:]
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data_str)
|
||||
chunks.append(chunk)
|
||||
delta = chunk.get("choices", [{}])[0].get("delta", {})
|
||||
if "tool_calls" in delta:
|
||||
tool_calls_found = True
|
||||
fr = chunk.get("choices", [{}])[0].get("finish_reason")
|
||||
if fr:
|
||||
finish_reason = fr
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"收到 {len(chunks)} 个数据块, 工具调用: {tool_calls_found}, finish: {finish_reason}",
|
||||
"details": {"chunk_count": len(chunks), "tool_calls_found": tool_calls_found}
|
||||
}
|
||||
|
||||
def test_claude_tool_calling(self) -> dict:
|
||||
"""测试 Claude 工具调用"""
|
||||
payload = {
|
||||
"model": "claude-sonnet-4-20250514",
|
||||
"max_tokens": 200,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Use the calculator tool to compute 15 * 23"}
|
||||
],
|
||||
"tools": [{
|
||||
"name": "calculator",
|
||||
"description": "Perform arithmetic calculations",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {"type": "string", "description": "Math expression"}
|
||||
},
|
||||
"required": ["expression"]
|
||||
}
|
||||
}],
|
||||
"stream": False
|
||||
}
|
||||
|
||||
resp = requests.post(
|
||||
f"{self.endpoint}/anthropic/v1/messages",
|
||||
headers=self.get_headers(is_claude=True),
|
||||
json=payload,
|
||||
timeout=TEST_TIMEOUT
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
return {"success": False, "message": f"状态码: {resp.status_code}", "details": {"response": resp.text}}
|
||||
|
||||
data = resp.json()
|
||||
if "error" in data:
|
||||
return {"success": False, "message": str(data["error"])}
|
||||
|
||||
content_blocks = data.get("content", [])
|
||||
stop_reason = data.get("stop_reason", "")
|
||||
|
||||
tool_use_blocks = [b for b in content_blocks if b.get("type") == "tool_use"]
|
||||
text_blocks = [b for b in content_blocks if b.get("type") == "text"]
|
||||
|
||||
if tool_use_blocks:
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"检测到 {len(tool_use_blocks)} 个工具调用, stop_reason={stop_reason}",
|
||||
"details": {"tool_use": tool_use_blocks}
|
||||
}
|
||||
else:
|
||||
text_content = "".join(b.get("text", "") for b in text_blocks)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"AI 直接回复: {text_content[:50]}...",
|
||||
"details": {"content": text_content[:100]}
|
||||
}
|
||||
|
||||
# =====================================================================
|
||||
# 搜索模式测试
|
||||
# =====================================================================
|
||||
|
||||
def test_openai_search_mode(self) -> dict:
|
||||
"""测试 OpenAI 搜索模式"""
|
||||
payload = {
|
||||
"model": "deepseek-chat-search",
|
||||
"messages": [
|
||||
{"role": "user", "content": "今天的新闻有哪些?"}
|
||||
],
|
||||
"stream": True
|
||||
}
|
||||
|
||||
resp = requests.post(
|
||||
f"{self.endpoint}/v1/chat/completions",
|
||||
headers=self.get_headers(),
|
||||
json=payload,
|
||||
stream=True,
|
||||
timeout=TEST_TIMEOUT
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
return {"success": False, "message": f"状态码: {resp.status_code}"}
|
||||
|
||||
content = ""
|
||||
for line in resp.iter_lines():
|
||||
if line:
|
||||
line_str = line.decode("utf-8")
|
||||
if line_str.startswith("data: "):
|
||||
data_str = line_str[6:]
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data_str)
|
||||
delta = chunk.get("choices", [{}])[0].get("delta", {})
|
||||
if "content" in delta:
|
||||
content += delta["content"]
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
if not content:
|
||||
return {"success": False, "message": "搜索模式无响应内容"}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"搜索模式正常,收到 {len(content)} 字符",
|
||||
"details": {"content_preview": content[:100]}
|
||||
}
|
||||
|
||||
# =====================================================================
|
||||
# 运行测试
|
||||
# =====================================================================
|
||||
@@ -672,6 +903,13 @@ class TestRunner:
|
||||
if not quick:
|
||||
self.run_test("多轮对话", self.test_multi_turn_conversation)
|
||||
self.run_test("长输入处理", self.test_long_input)
|
||||
self.run_test("OpenAI 搜索模式", self.test_openai_search_mode)
|
||||
|
||||
# 工具调用测试
|
||||
if not quick:
|
||||
self.run_test("OpenAI 工具调用", self.test_openai_tool_calling)
|
||||
self.run_test("OpenAI 流式工具调用", self.test_openai_tool_calling_stream)
|
||||
self.run_test("Claude 工具调用", self.test_claude_tool_calling)
|
||||
|
||||
# 管理 API 测试
|
||||
self.run_test("管理配置 API", self.test_admin_config)
|
||||
|
||||
Reference in New Issue
Block a user