feat: Implement OpenAI-compatible tool calling by injecting tool prompts and parsing tool calls from model responses.

This commit is contained in:
CJACK
2026-02-01 16:05:10 +08:00
parent 1cf52502bb
commit dbae110d2b
3 changed files with 512 additions and 8 deletions

View File

@@ -2,6 +2,7 @@
"""OpenAI 兼容路由"""
import json
import queue
import random
import re
import threading
import time
@@ -29,6 +30,7 @@ from core.sse_parser import (
extract_content_from_chunk,
extract_content_recursive,
should_filter_citation,
parse_tool_calls,
)
from core.constants import (
KEEP_ALIVE_TIMEOUT,
@@ -83,6 +85,59 @@ async def chat_completions(request: Request):
status_code=400, detail="Request must include 'model' and 'messages'."
)
# 解析工具调用参数(OpenAI 格式)
tools_requested = req_data.get("tools") or []
has_tools = len(tools_requested) > 0
# 如果有工具定义,构建工具提示并注入到消息中
messages_with_tools = messages.copy()
if has_tools:
tool_schemas = []
for tool in tools_requested:
# OpenAI 格式: {"type": "function", "function": {"name": ..., "description": ..., "parameters": ...}}
func = tool.get("function", tool) # 兼容简化格式
tool_name = func.get("name", "unknown")
tool_desc = func.get("description", "No description available")
schema = func.get("parameters", {})
tool_info = f"Tool: {tool_name}\nDescription: {tool_desc}"
if "properties" in schema:
props = []
required = schema.get("required", [])
for prop_name, prop_info in schema["properties"].items():
prop_type = prop_info.get("type", "string")
is_req = " (required)" if prop_name in required else ""
props.append(f" - {prop_name}: {prop_type}{is_req}")
if props:
tool_info += f"\nParameters:\n" + "\n".join(props)
tool_schemas.append(tool_info)
# 检查是否已有系统消息
has_system = any(m.get("role") == "system" for m in messages_with_tools)
tool_prompt = f"""You have access to these tools:
{chr(10).join(tool_schemas)}
When you need to use tools, output ONLY this JSON format (no other text):
{{"tool_calls": [
{{"name": "tool_name", "input": {{"param": "value"}}}}
]}}
IMPORTANT: If calling tools, output ONLY the JSON. The response must start with {{ and end with }}"""
if has_system:
# 追加到现有系统消息
for i, m in enumerate(messages_with_tools):
if m.get("role") == "system":
messages_with_tools[i] = {
"role": "system",
"content": m.get("content", "") + "\n\n" + tool_prompt
}
break
else:
# 添加新的系统消息
messages_with_tools.insert(0, {"role": "system", "content": tool_prompt})
# 使用会话管理器获取模型配置
thinking_enabled, search_enabled = get_model_config(model)
if thinking_enabled is None:
@@ -90,8 +145,8 @@ async def chat_completions(request: Request):
status_code=503, detail=f"Model '{model}' is not available."
)
# 使用 messages_prepare 函数构造最终 prompt
final_prompt = messages_prepare(messages)
# 使用 messages_prepare 函数构造最终 prompt(使用带工具提示的消息)
final_prompt = messages_prepare(messages_with_tools)
session_id = create_session(request)
if not session_id:
raise HTTPException(status_code=401, detail="invalid token.")
@@ -255,12 +310,42 @@ async def chat_completions(request: Request):
"total_tokens": prompt_tokens + thinking_tokens + completion_tokens,
"completion_tokens_details": {"reasoning_tokens": thinking_tokens},
}
# 检测工具调用
detected_tools = []
finish_reason = "stop"
if has_tools:
detected_tools = parse_tool_calls(final_text, [{"name": t.get("function", t).get("name")} for t in tools_requested])
if detected_tools:
finish_reason = "tool_calls"
if detected_tools:
# 发送工具调用响应
tool_calls_data = []
for idx, tool_info in enumerate(detected_tools):
tool_calls_data.append({
"id": f"call_{int(time.time())}_{random.randint(1000,9999)}_{idx}",
"type": "function",
"function": {
"name": tool_info["name"],
"arguments": json.dumps(tool_info["input"], ensure_ascii=False)
}
})
tool_chunk = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": model,
"choices": [{"delta": {"tool_calls": tool_calls_data}, "index": 0}],
}
yield f"data: {json.dumps(tool_chunk, ensure_ascii=False)}\n\n"
finish_chunk = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": model,
"choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}],
"choices": [{"delta": {}, "index": 0, "finish_reason": finish_reason}],
"usage": usage,
}
yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n"
@@ -325,12 +410,41 @@ async def chat_completions(request: Request):
"total_tokens": prompt_tokens + thinking_tokens + completion_tokens,
"completion_tokens_details": {"reasoning_tokens": thinking_tokens},
}
# 检测工具调用
detected_tools = []
finish_reason = "stop"
if has_tools:
detected_tools = parse_tool_calls(final_text, [{"name": t.get("function", t).get("name")} for t in tools_requested])
if detected_tools:
finish_reason = "tool_calls"
if detected_tools:
tool_calls_data = []
for idx, tool_info in enumerate(detected_tools):
tool_calls_data.append({
"id": f"call_{int(time.time())}_{random.randint(1000,9999)}_{idx}",
"type": "function",
"function": {
"name": tool_info["name"],
"arguments": json.dumps(tool_info["input"], ensure_ascii=False)
}
})
tool_chunk = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": model,
"choices": [{"delta": {"tool_calls": tool_calls_data}, "index": 0}],
}
yield f"data: {json.dumps(tool_chunk, ensure_ascii=False)}\n\n"
finish_chunk = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": model,
"choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}],
"choices": [{"delta": {}, "index": 0, "finish_reason": finish_reason}],
"usage": usage,
}
yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n"
@@ -401,14 +515,37 @@ async def chat_completions(request: Request):
prompt_tokens = len(final_prompt) // 4
reasoning_tokens = len(final_reasoning) // 4
completion_tokens = len(final_content) // 4
# 检测工具调用
detected_tools = []
finish_reason = "stop"
if has_tools:
detected_tools = parse_tool_calls(final_content, [{"name": t.get("function", t).get("name")} for t in tools_requested])
if detected_tools:
finish_reason = "tool_calls"
# 构建 message 对象
message_obj = {
"role": "assistant",
"content": final_content,
"content": final_content if not detected_tools else None,
}
# 只有启用思考模式时才包含 reasoning_content
if thinking_enabled and final_reasoning:
message_obj["reasoning_content"] = final_reasoning
# 添加工具调用
if detected_tools:
tool_calls_data = []
for idx, tool_info in enumerate(detected_tools):
tool_calls_data.append({
"id": f"call_{int(time.time())}_{random.randint(1000,9999)}_{idx}",
"type": "function",
"function": {
"name": tool_info["name"],
"arguments": json.dumps(tool_info["input"], ensure_ascii=False)
}
})
message_obj["tool_calls"] = tool_calls_data
message_obj["content"] = None
result = {
"id": completion_id,
@@ -418,7 +555,7 @@ async def chat_completions(request: Request):
"choices": [{
"index": 0,
"message": message_obj,
"finish_reason": "stop",
"finish_reason": finish_reason,
}],
"usage": {
"prompt_tokens": prompt_tokens,
@@ -452,14 +589,37 @@ async def chat_completions(request: Request):
prompt_tokens = len(final_prompt) // 4
reasoning_tokens = len(final_reasoning) // 4
completion_tokens = len(final_content) // 4
# 检测工具调用
detected_tools = []
finish_reason = "stop"
if has_tools:
detected_tools = parse_tool_calls(final_content, [{"name": t.get("function", t).get("name")} for t in tools_requested])
if detected_tools:
finish_reason = "tool_calls"
# 构建 message 对象
message_obj = {
"role": "assistant",
"content": final_content,
"content": final_content if not detected_tools else None,
}
# 只有启用思考模式时才包含 reasoning_content
if thinking_enabled and final_reasoning:
message_obj["reasoning_content"] = final_reasoning
# 添加工具调用
if detected_tools:
tool_calls_data = []
for idx, tool_info in enumerate(detected_tools):
tool_calls_data.append({
"id": f"call_{int(time.time())}_{random.randint(1000,9999)}_{idx}",
"type": "function",
"function": {
"name": tool_info["name"],
"arguments": json.dumps(tool_info["input"], ensure_ascii=False)
}
})
message_obj["tool_calls"] = tool_calls_data
message_obj["content"] = None
result = {
"id": completion_id,
@@ -469,7 +629,7 @@ async def chat_completions(request: Request):
"choices": [{
"index": 0,
"message": message_obj,
"finish_reason": "stop",
"finish_reason": finish_reason,
}],
"usage": {
"prompt_tokens": prompt_tokens,

View File

@@ -632,6 +632,237 @@ class TestRunner:
"details": {"response_time": data["response_time"]}
}
# =====================================================================
# 工具调用测试
# =====================================================================
def test_openai_tool_calling(self) -> dict:
    """Exercise OpenAI-style tool calling on /v1/chat/completions (non-stream).

    Sends one function-tool definition plus a prompt that nudges the model
    toward calling it. Either outcome — a tool call or a plain text reply —
    counts as success, since the model is free to answer directly.

    Returns:
        dict in the TestRunner result shape: "success", "message" and an
        optional "details" mapping.
    """
    payload = {
        "model": "deepseek-chat",
        "messages": [
            {"role": "user", "content": "What's the weather in Beijing? Use the get_weather tool."}
        ],
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get current weather for a location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "City name"}
                    },
                    "required": ["location"]
                }
            }
        }],
        "stream": False
    }
    resp = requests.post(
        f"{self.endpoint}/v1/chat/completions",
        headers=self.get_headers(),
        json=payload,
        timeout=TEST_TIMEOUT
    )
    if resp.status_code != 200:
        return {"success": False, "message": f"状态码: {resp.status_code}", "details": {"response": resp.text}}
    data = resp.json()
    if "error" in data:
        return {"success": False, "message": data["error"]}
    # BUG FIX: guard against an empty "choices" list before indexing [0].
    choices = data.get("choices") or [{}]
    message = choices[0].get("message", {})
    tool_calls = message.get("tool_calls", [])
    finish_reason = choices[0].get("finish_reason", "")
    # BUG FIX: "content" may be present but explicitly null (OpenAI sets it
    # to null on tool-call responses); normalize None to "" before slicing.
    content = message.get("content") or ""
    # The model may call the tool or answer in plain text; both are accepted.
    if tool_calls:
        return {
            "success": True,
            "message": f"检测到 {len(tool_calls)} 个工具调用, finish_reason={finish_reason}",
            "details": {"tool_calls": tool_calls}
        }
    else:
        return {
            "success": True,
            "message": f"AI 直接回复而非调用工具: {content[:50]}...",
            "details": {"content": content[:100]}
        }
def test_openai_tool_calling_stream(self) -> dict:
    """Exercise OpenAI-style tool calling over the streaming endpoint.

    Collects every SSE data chunk, records whether any delta carried
    "tool_calls", and captures the final finish_reason. The test passes
    regardless of whether the model actually called the tool, since that
    choice belongs to the model.

    Returns:
        dict in the TestRunner result shape ("success"/"message"/"details").
    """
    payload = {
        "model": "deepseek-chat",
        "messages": [
            {"role": "user", "content": "Use get_time tool to check current time in Tokyo."}
        ],
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_time",
                "description": "Get current time for a timezone",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "timezone": {"type": "string"}
                    },
                    "required": ["timezone"]
                }
            }
        }],
        "stream": True
    }
    resp = requests.post(
        f"{self.endpoint}/v1/chat/completions",
        headers=self.get_headers(),
        json=payload,
        stream=True,
        timeout=TEST_TIMEOUT
    )
    if resp.status_code != 200:
        return {"success": False, "message": f"状态码: {resp.status_code}"}
    chunks = []
    tool_calls_found = False
    finish_reason = None
    for line in resp.iter_lines():
        if not line:
            continue
        line_str = line.decode("utf-8")
        if not line_str.startswith("data: "):
            continue
        data_str = line_str[6:]
        if data_str == "[DONE]":
            break
        try:
            chunk = json.loads(data_str)
        except json.JSONDecodeError:
            # Tolerate malformed / keep-alive frames.
            continue
        chunks.append(chunk)
        # BUG FIX: a final usage-only chunk may carry "choices": [];
        # indexing [0] unguarded raised IndexError, which escaped the
        # JSONDecodeError-only except clause and aborted the test.
        choices = chunk.get("choices") or [{}]
        delta = choices[0].get("delta", {})
        if "tool_calls" in delta:
            tool_calls_found = True
        fr = choices[0].get("finish_reason")
        if fr:
            finish_reason = fr
    return {
        "success": True,
        "message": f"收到 {len(chunks)} 个数据块, 工具调用: {tool_calls_found}, finish: {finish_reason}",
        "details": {"chunk_count": len(chunks), "tool_calls_found": tool_calls_found}
    }
def test_claude_tool_calling(self) -> dict:
    """Verify tool calling through the Anthropic-compatible endpoint.

    Posts a calculator tool definition to /anthropic/v1/messages and
    inspects the returned content blocks. A "tool_use" block and a plain
    text reply are both treated as success.

    Returns:
        dict in the TestRunner result shape ("success"/"message"/"details").
    """
    calculator_tool = {
        "name": "calculator",
        "description": "Perform arithmetic calculations",
        "input_schema": {
            "type": "object",
            "properties": {
                "expression": {"type": "string", "description": "Math expression"}
            },
            "required": ["expression"]
        }
    }
    request_body = {
        "model": "claude-sonnet-4-20250514",
        "max_tokens": 200,
        "messages": [
            {"role": "user", "content": "Use the calculator tool to compute 15 * 23"}
        ],
        "tools": [calculator_tool],
        "stream": False
    }
    resp = requests.post(
        f"{self.endpoint}/anthropic/v1/messages",
        headers=self.get_headers(is_claude=True),
        json=request_body,
        timeout=TEST_TIMEOUT
    )
    if resp.status_code != 200:
        return {"success": False, "message": f"状态码: {resp.status_code}", "details": {"response": resp.text}}
    data = resp.json()
    if "error" in data:
        return {"success": False, "message": str(data["error"])}
    blocks = data.get("content", [])
    stop_reason = data.get("stop_reason", "")
    tool_use_blocks = [blk for blk in blocks if blk.get("type") == "tool_use"]
    if tool_use_blocks:
        return {
            "success": True,
            "message": f"检测到 {len(tool_use_blocks)} 个工具调用, stop_reason={stop_reason}",
            "details": {"tool_use": tool_use_blocks}
        }
    # No tool_use block: concatenate the text blocks into a plain reply.
    text_content = "".join(blk.get("text", "") for blk in blocks if blk.get("type") == "text")
    return {
        "success": True,
        "message": f"AI 直接回复: {text_content[:50]}...",
        "details": {"content": text_content[:100]}
    }
# =====================================================================
# 搜索模式测试
# =====================================================================
def test_openai_search_mode(self) -> dict:
    """Smoke-test the search-enabled model via the streaming endpoint.

    Streams a news question to "deepseek-chat-search" and accumulates the
    content deltas; the test fails when no content was received at all.

    Returns:
        dict in the TestRunner result shape ("success"/"message"/"details").
    """
    payload = {
        "model": "deepseek-chat-search",
        "messages": [
            {"role": "user", "content": "今天的新闻有哪些?"}
        ],
        "stream": True
    }
    resp = requests.post(
        f"{self.endpoint}/v1/chat/completions",
        headers=self.get_headers(),
        json=payload,
        stream=True,
        timeout=TEST_TIMEOUT
    )
    if resp.status_code != 200:
        return {"success": False, "message": f"状态码: {resp.status_code}"}
    content = ""
    for line in resp.iter_lines():
        if not line:
            continue
        line_str = line.decode("utf-8")
        if not line_str.startswith("data: "):
            continue
        data_str = line_str[6:]
        if data_str == "[DONE]":
            break
        try:
            chunk = json.loads(data_str)
        except json.JSONDecodeError:
            continue
        # BUG FIX: guard against an empty "choices" list (usage-only chunk)
        # and a null "content" value; both previously raised exceptions that
        # escaped the JSONDecodeError-only except clause.
        choices = chunk.get("choices") or [{}]
        delta = choices[0].get("delta", {})
        if delta.get("content"):
            content += delta["content"]
    if not content:
        return {"success": False, "message": "搜索模式无响应内容"}
    return {
        "success": True,
        "message": f"搜索模式正常,收到 {len(content)} 字符",
        "details": {"content_preview": content[:100]}
    }
# =====================================================================
# 运行测试
# =====================================================================
@@ -672,6 +903,13 @@ class TestRunner:
if not quick:
self.run_test("多轮对话", self.test_multi_turn_conversation)
self.run_test("长输入处理", self.test_long_input)
self.run_test("OpenAI 搜索模式", self.test_openai_search_mode)
# 工具调用测试
if not quick:
self.run_test("OpenAI 工具调用", self.test_openai_tool_calling)
self.run_test("OpenAI 流式工具调用", self.test_openai_tool_calling_stream)
self.run_test("Claude 工具调用", self.test_claude_tool_calling)
# 管理 API 测试
self.run_test("管理配置 API", self.test_admin_config)

View File

@@ -451,6 +451,112 @@ class TestStreamParsing(unittest.TestCase):
self.assertTrue(check_response_started(response_fragment)) # RESPONSE 触发
class TestToolCallParsing(unittest.TestCase):
    """Unit tests for core.sse_parser.parse_tool_calls."""

    def test_parse_tool_calls_simple(self):
        """A bare tool_calls JSON object yields exactly one parsed call."""
        from core.sse_parser import parse_tool_calls
        text = '{"tool_calls": [{"name": "get_weather", "input": {"location": "Beijing"}}]}'
        calls = parse_tool_calls(text, [{"name": "get_weather"}])
        self.assertEqual(len(calls), 1)
        self.assertEqual(calls[0]["name"], "get_weather")
        self.assertEqual(calls[0]["input"]["location"], "Beijing")

    def test_parse_tool_calls_multiple(self):
        """Several entries in one tool_calls array are all returned, in order."""
        from core.sse_parser import parse_tool_calls
        text = '''{"tool_calls": [
{"name": "get_weather", "input": {"location": "Beijing"}},
{"name": "get_time", "input": {"timezone": "Asia/Shanghai"}}
]}'''
        declared = [{"name": "get_weather"}, {"name": "get_time"}]
        calls = parse_tool_calls(text, declared)
        self.assertEqual(len(calls), 2)
        self.assertEqual(calls[0]["name"], "get_weather")
        self.assertEqual(calls[1]["name"], "get_time")

    def test_parse_tool_calls_no_match(self):
        """Plain prose with no tool-call JSON parses to an empty list."""
        from core.sse_parser import parse_tool_calls
        calls = parse_tool_calls("这是一个普通的回复,没有工具调用。", [{"name": "get_weather"}])
        self.assertEqual(calls, [])

    def test_parse_tool_calls_with_surrounding_text(self):
        """A tool-call JSON object embedded in surrounding prose is still found."""
        from core.sse_parser import parse_tool_calls
        text = '''好的,我来帮你查询天气。
{"tool_calls": [{"name": "get_weather", "input": {"location": "Shanghai"}}]}'''
        calls = parse_tool_calls(text, [{"name": "get_weather"}])
        self.assertEqual(len(calls), 1)
        self.assertEqual(calls[0]["name"], "get_weather")

    def test_parse_tool_calls_empty_input(self):
        """Empty text or an empty tool list both yield an empty result."""
        from core.sse_parser import parse_tool_calls
        self.assertEqual(parse_tool_calls("", []), [])
        self.assertEqual(parse_tool_calls("some text", []), [])

    def test_parse_tool_calls_invalid_json(self):
        """Malformed JSON is tolerated: empty list, no exception raised."""
        from core.sse_parser import parse_tool_calls
        broken = '{"tool_calls": [{"name": "get_weather", invalid json here}'
        self.assertEqual(parse_tool_calls(broken, [{"name": "get_weather"}]), [])
class TestTokenEstimation(unittest.TestCase):
    """Unit tests for core.utils.estimate_tokens."""

    def test_estimate_tokens_string(self):
        """An 8-char string estimates to 2 tokens; "" still counts as 1."""
        from core.utils import estimate_tokens
        self.assertEqual(estimate_tokens("12345678"), 2)
        self.assertEqual(estimate_tokens(""), 1)

    def test_estimate_tokens_list(self):
        """A list of text parts yields a strictly positive estimate."""
        from core.utils import estimate_tokens
        parts = [
            {"text": "Hello"},
            {"text": "World"}
        ]
        self.assertGreater(estimate_tokens(parts), 0)
if __name__ == "__main__":
# 设置环境变量避免配置警告
os.environ.setdefault("DS2API_CONFIG_PATH",