From dbae110d2bb09d8360e0c46b50ec57c80769b39d Mon Sep 17 00:00:00 2001 From: CJACK Date: Sun, 1 Feb 2026 16:05:10 +0800 Subject: [PATCH] feat: Implement OpenAI-compatible tool calling by injecting tool prompts and parsing tool calls from model responses. --- routes/openai.py | 176 +++++++++++++++++++++++++++++++-- tests/test_all.py | 238 +++++++++++++++++++++++++++++++++++++++++++++ tests/test_unit.py | 106 ++++++++++++++++++++ 3 files changed, 512 insertions(+), 8 deletions(-) diff --git a/routes/openai.py b/routes/openai.py index e4fe73b..ab67b9b 100644 --- a/routes/openai.py +++ b/routes/openai.py @@ -2,6 +2,7 @@ """OpenAI 兼容路由""" import json import queue +import random import re import threading import time @@ -29,6 +30,7 @@ from core.sse_parser import ( extract_content_from_chunk, extract_content_recursive, should_filter_citation, + parse_tool_calls, ) from core.constants import ( KEEP_ALIVE_TIMEOUT, @@ -83,6 +85,59 @@ async def chat_completions(request: Request): status_code=400, detail="Request must include 'model' and 'messages'." ) + # 解析工具调用参数(OpenAI 格式) + tools_requested = req_data.get("tools") or [] + has_tools = len(tools_requested) > 0 + + # 如果有工具定义,构建工具提示并注入到消息中 + messages_with_tools = messages.copy() + if has_tools: + tool_schemas = [] + for tool in tools_requested: + # OpenAI 格式: {"type": "function", "function": {"name": ..., "description": ..., "parameters": ...}} + func = tool.get("function", tool) # 兼容简化格式 + tool_name = func.get("name", "unknown") + tool_desc = func.get("description", "No description available") + schema = func.get("parameters", {}) + + tool_info = f"Tool: {tool_name}\nDescription: {tool_desc}" + if "properties" in schema: + props = [] + required = schema.get("required", []) + for prop_name, prop_info in schema["properties"].items(): + prop_type = prop_info.get("type", "string") + is_req = " (required)" if prop_name in required else "" + props.append(f" - {prop_name}: {prop_type}{is_req}") + if props: + tool_info += f"\nParameters:\n" + "\n".join(props) + tool_schemas.append(tool_info) + + # 检查是否已有系统消息 + has_system = any(m.get("role") == "system" for m in messages_with_tools) + tool_prompt = f"""You have access to these tools: + +{chr(10).join(tool_schemas)} + +When you need to use tools, output ONLY this JSON format (no other text): +{{"tool_calls": [ + {{"name": "tool_name", "input": {{"param": "value"}}}} +]}} + +IMPORTANT: If calling tools, output ONLY the JSON. The response must start with {{ and end with }}""" + + if has_system: + # 追加到现有系统消息 + for i, m in enumerate(messages_with_tools): + if m.get("role") == "system": + messages_with_tools[i] = { + "role": "system", + "content": m.get("content", "") + "\n\n" + tool_prompt + } + break + else: + # 添加新的系统消息 + messages_with_tools.insert(0, {"role": "system", "content": tool_prompt}) + # 使用会话管理器获取模型配置 thinking_enabled, search_enabled = get_model_config(model) if thinking_enabled is None: @@ -90,8 +145,8 @@ async def chat_completions(request: Request): status_code=503, detail=f"Model '{model}' is not available." ) - # 使用 messages_prepare 函数构造最终 prompt - final_prompt = messages_prepare(messages) + # 使用 messages_prepare 函数构造最终 prompt(使用带工具提示的消息) + final_prompt = messages_prepare(messages_with_tools) session_id = create_session(request) if not session_id: raise HTTPException(status_code=401, detail="invalid token.") @@ -255,12 +310,42 @@ async def chat_completions(request: Request): "total_tokens": prompt_tokens + thinking_tokens + completion_tokens, "completion_tokens_details": {"reasoning_tokens": thinking_tokens}, } + + # 检测工具调用 + detected_tools = [] + finish_reason = "stop" + if has_tools: + detected_tools = parse_tool_calls(final_text, [{"name": t.get("function", t).get("name")} for t in tools_requested]) + if detected_tools: + finish_reason = "tool_calls" + + if detected_tools: + # 发送工具调用响应 + tool_calls_data = [] + for idx, tool_info in enumerate(detected_tools): + tool_calls_data.append({ + "id": f"call_{int(time.time())}_{random.randint(1000,9999)}_{idx}", + "type": "function", + "function": { + "name": tool_info["name"], + "arguments": json.dumps(tool_info["input"], ensure_ascii=False) + } + }) + tool_chunk = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created_time, + "model": model, + "choices": [{"delta": {"tool_calls": tool_calls_data}, "index": 0}], + } + yield f"data: {json.dumps(tool_chunk, ensure_ascii=False)}\n\n" + finish_chunk = { "id": completion_id, "object": "chat.completion.chunk", "created": created_time, "model": model, - "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}], + "choices": [{"delta": {}, "index": 0, "finish_reason": finish_reason}], "usage": usage, } yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n" @@ -325,12 +410,41 @@ async def chat_completions(request: Request): "total_tokens": prompt_tokens + thinking_tokens + completion_tokens, "completion_tokens_details": {"reasoning_tokens": thinking_tokens}, } + + # 检测工具调用 + detected_tools = [] + finish_reason = "stop" + if has_tools: + detected_tools = parse_tool_calls(final_text, [{"name": t.get("function", t).get("name")} for t in tools_requested]) + if detected_tools: + finish_reason = "tool_calls" + + if detected_tools: + tool_calls_data = [] + for idx, tool_info in enumerate(detected_tools): + tool_calls_data.append({ + "id": f"call_{int(time.time())}_{random.randint(1000,9999)}_{idx}", + "type": "function", + "function": { + "name": tool_info["name"], + "arguments": json.dumps(tool_info["input"], ensure_ascii=False) + } + }) + tool_chunk = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created_time, + "model": model, + "choices": [{"delta": {"tool_calls": tool_calls_data}, "index": 0}], + } + yield f"data: {json.dumps(tool_chunk, ensure_ascii=False)}\n\n" + finish_chunk = { "id": completion_id, "object": "chat.completion.chunk", "created": created_time, "model": model, - "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}], + "choices": [{"delta": {}, "index": 0, "finish_reason": finish_reason}], "usage": usage, } yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n" @@ -401,14 +515,37 @@ async def chat_completions(request: Request): prompt_tokens = len(final_prompt) // 4 reasoning_tokens = len(final_reasoning) // 4 completion_tokens = len(final_content) // 4 + + # 检测工具调用 + detected_tools = [] + finish_reason = "stop" + if has_tools: + detected_tools = parse_tool_calls(final_content, [{"name": t.get("function", t).get("name")} for t in tools_requested]) + if detected_tools: + finish_reason = "tool_calls" + # 构建 message 对象 message_obj = { "role": "assistant", - "content": final_content, + "content": final_content if not detected_tools else None, } # 只有启用思考模式时才包含 reasoning_content if thinking_enabled and final_reasoning: message_obj["reasoning_content"] = final_reasoning + # 添加工具调用 + if detected_tools: + tool_calls_data = [] + for idx, tool_info in enumerate(detected_tools): + tool_calls_data.append({ + "id": f"call_{int(time.time())}_{random.randint(1000,9999)}_{idx}", + "type": "function", + "function": { + "name": tool_info["name"], + "arguments": json.dumps(tool_info["input"], ensure_ascii=False) + } + }) + message_obj["tool_calls"] = tool_calls_data + message_obj["content"] = None result = { "id": completion_id, @@ -418,7 +555,7 @@ async def chat_completions(request: Request): "choices": [{ "index": 0, "message": message_obj, - "finish_reason": "stop", + "finish_reason": finish_reason, }], "usage": { "prompt_tokens": prompt_tokens, @@ -452,14 +589,37 @@ async def chat_completions(request: Request): prompt_tokens = len(final_prompt) // 4 reasoning_tokens = len(final_reasoning) // 4 completion_tokens = len(final_content) // 4 + + # 检测工具调用 + detected_tools = [] + finish_reason = "stop" + if has_tools: + detected_tools = parse_tool_calls(final_content, [{"name": t.get("function", t).get("name")} for t in tools_requested]) + if detected_tools: + finish_reason = "tool_calls" + # 构建 message 对象 message_obj = { "role": "assistant", - "content": final_content, + "content": final_content if not detected_tools else None, } # 只有启用思考模式时才包含 reasoning_content if thinking_enabled and final_reasoning: message_obj["reasoning_content"] = final_reasoning + # 添加工具调用 + if detected_tools: + tool_calls_data = [] + for idx, tool_info in enumerate(detected_tools): + tool_calls_data.append({ + "id": f"call_{int(time.time())}_{random.randint(1000,9999)}_{idx}", + "type": "function", + "function": { + "name": tool_info["name"], + "arguments": json.dumps(tool_info["input"], ensure_ascii=False) + } + }) + message_obj["tool_calls"] = tool_calls_data + message_obj["content"] = None result = { "id": completion_id, @@ -469,7 +629,7 @@ async def chat_completions(request: Request): "choices": [{ "index": 0, "message": message_obj, - "finish_reason": "stop", + "finish_reason": finish_reason, }], "usage": { "prompt_tokens": prompt_tokens, diff --git a/tests/test_all.py b/tests/test_all.py index 1ef14e4..8f548b9 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -632,6 +632,237 @@ class TestRunner: "details": {"response_time": data["response_time"]} } + # ===================================================================== + # 工具调用测试 + # ===================================================================== + + def test_openai_tool_calling(self) -> dict: + """测试 OpenAI 工具调用""" + payload = { + "model": "deepseek-chat", + "messages": [ + {"role": "user", "content": "What's the weather in Beijing? Use the get_weather tool."} + ], + "tools": [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "City name"} + }, + "required": ["location"] + } + } + }], + "stream": False + } + + resp = requests.post( + f"{self.endpoint}/v1/chat/completions", + headers=self.get_headers(), + json=payload, + timeout=TEST_TIMEOUT + ) + + if resp.status_code != 200: + return {"success": False, "message": f"状态码: {resp.status_code}", "details": {"response": resp.text}} + + data = resp.json() + if "error" in data: + return {"success": False, "message": data["error"]} + + message = data.get("choices", [{}])[0].get("message", {}) + tool_calls = message.get("tool_calls", []) + finish_reason = data.get("choices", [{}])[0].get("finish_reason", "") + content = message.get("content", "") + + # AI 可能调用工具,也可能直接回复 + if tool_calls: + return { + "success": True, + "message": f"检测到 {len(tool_calls)} 个工具调用, finish_reason={finish_reason}", + "details": {"tool_calls": tool_calls} + } + else: + return { + "success": True, + "message": f"AI 直接回复而非调用工具: {content[:50]}...", + "details": {"content": content[:100]} + } + + def test_openai_tool_calling_stream(self) -> dict: + """测试 OpenAI 流式工具调用""" + payload = { + "model": "deepseek-chat", + "messages": [ + {"role": "user", "content": "Use get_time tool to check current time in Tokyo."} + ], + "tools": [{ + "type": "function", + "function": { + "name": "get_time", + "description": "Get current time for a timezone", + "parameters": { + "type": "object", + "properties": { + "timezone": {"type": "string"} + }, + "required": ["timezone"] + } + } + }], + "stream": True + } + + resp = requests.post( + f"{self.endpoint}/v1/chat/completions", + headers=self.get_headers(), + json=payload, + stream=True, + timeout=TEST_TIMEOUT + ) + + if resp.status_code != 200: + return {"success": False, "message": f"状态码: {resp.status_code}"} + + chunks = [] + tool_calls_found = False + finish_reason = None + + for line in resp.iter_lines(): + if line: + line_str = line.decode("utf-8") + if line_str.startswith("data: "): + data_str = line_str[6:] + if data_str == "[DONE]": + break + try: + chunk = json.loads(data_str) + chunks.append(chunk) + delta = chunk.get("choices", [{}])[0].get("delta", {}) + if "tool_calls" in delta: + tool_calls_found = True + fr = chunk.get("choices", [{}])[0].get("finish_reason") + if fr: + finish_reason = fr + except json.JSONDecodeError: + pass + + return { + "success": True, + "message": f"收到 {len(chunks)} 个数据块, 工具调用: {tool_calls_found}, finish: {finish_reason}", + "details": {"chunk_count": len(chunks), "tool_calls_found": tool_calls_found} + } + + def test_claude_tool_calling(self) -> dict: + """测试 Claude 工具调用""" + payload = { + "model": "claude-sonnet-4-20250514", + "max_tokens": 200, + "messages": [ + {"role": "user", "content": "Use the calculator tool to compute 15 * 23"} + ], + "tools": [{ + "name": "calculator", + "description": "Perform arithmetic calculations", + "input_schema": { + "type": "object", + "properties": { + "expression": {"type": "string", "description": "Math expression"} + }, + "required": ["expression"] + } + }], + "stream": False + } + + resp = requests.post( + f"{self.endpoint}/anthropic/v1/messages", + headers=self.get_headers(is_claude=True), + json=payload, + timeout=TEST_TIMEOUT + ) + + if resp.status_code != 200: + return {"success": False, "message": f"状态码: {resp.status_code}", "details": {"response": resp.text}} + + data = resp.json() + if "error" in data: + return {"success": False, "message": str(data["error"])} + + content_blocks = data.get("content", []) + stop_reason = data.get("stop_reason", "") + + tool_use_blocks = [b for b in content_blocks if b.get("type") == "tool_use"] + text_blocks = [b for b in content_blocks if b.get("type") == "text"] + + if tool_use_blocks: + return { + "success": True, + "message": f"检测到 {len(tool_use_blocks)} 个工具调用, stop_reason={stop_reason}", + "details": {"tool_use": tool_use_blocks} + } + else: + text_content = "".join(b.get("text", "") for b in text_blocks) + return { + "success": True, + "message": f"AI 直接回复: {text_content[:50]}...", + "details": {"content": text_content[:100]} + } + + # ===================================================================== + # 搜索模式测试 + # ===================================================================== + + def test_openai_search_mode(self) -> dict: + """测试 OpenAI 搜索模式""" + payload = { + "model": "deepseek-chat-search", + "messages": [ + {"role": "user", "content": "今天的新闻有哪些?"} + ], + "stream": True + } + + resp = requests.post( + f"{self.endpoint}/v1/chat/completions", + headers=self.get_headers(), + json=payload, + stream=True, + timeout=TEST_TIMEOUT + ) + + if resp.status_code != 200: + return {"success": False, "message": f"状态码: {resp.status_code}"} + + content = "" + for line in resp.iter_lines(): + if line: + line_str = line.decode("utf-8") + if line_str.startswith("data: "): + data_str = line_str[6:] + if data_str == "[DONE]": + break + try: + chunk = json.loads(data_str) + delta = chunk.get("choices", [{}])[0].get("delta", {}) + if "content" in delta: + content += delta["content"] + except json.JSONDecodeError: + pass + + if not content: + return {"success": False, "message": "搜索模式无响应内容"} + + return { + "success": True, + "message": f"搜索模式正常,收到 {len(content)} 字符", + "details": {"content_preview": content[:100]} + } + # ===================================================================== # 运行测试 # ===================================================================== @@ -672,6 +903,13 @@ class TestRunner: if not quick: self.run_test("多轮对话", self.test_multi_turn_conversation) self.run_test("长输入处理", self.test_long_input) + self.run_test("OpenAI 搜索模式", self.test_openai_search_mode) + + # 工具调用测试 + if not quick: + self.run_test("OpenAI 工具调用", self.test_openai_tool_calling) + self.run_test("OpenAI 流式工具调用", self.test_openai_tool_calling_stream) + self.run_test("Claude 工具调用", self.test_claude_tool_calling) # 管理 API 测试 self.run_test("管理配置 API", self.test_admin_config) diff --git a/tests/test_unit.py b/tests/test_unit.py index c8a6326..c9ed401 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -451,6 +451,112 @@ class TestStreamParsing(unittest.TestCase): self.assertTrue(check_response_started(response_fragment)) # RESPONSE 触发 +class TestToolCallParsing(unittest.TestCase): + """工具调用解析测试""" + + def test_parse_tool_calls_simple(self): + """测试简单工具调用解析""" + from core.sse_parser import parse_tool_calls + + response_text = '{"tool_calls": [{"name": "get_weather", "input": {"location": "Beijing"}}]}' + tools = [{"name": "get_weather"}] + + result = parse_tool_calls(response_text, tools) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0]["name"], "get_weather") + self.assertEqual(result[0]["input"]["location"], "Beijing") + + def test_parse_tool_calls_multiple(self): + """测试多工具调用解析""" + from core.sse_parser import parse_tool_calls + + response_text = '''{"tool_calls": [ + {"name": "get_weather", "input": {"location": "Beijing"}}, + {"name": "get_time", "input": {"timezone": "Asia/Shanghai"}} + ]}''' + tools = [{"name": "get_weather"}, {"name": "get_time"}] + + result = parse_tool_calls(response_text, tools) + + self.assertEqual(len(result), 2) + self.assertEqual(result[0]["name"], "get_weather") + self.assertEqual(result[1]["name"], "get_time") + + def test_parse_tool_calls_no_match(self): + """测试无工具调用时返回空列表""" + from core.sse_parser import parse_tool_calls + + response_text = "这是一个普通的回复,没有工具调用。" + tools = [{"name": "get_weather"}] + + result = parse_tool_calls(response_text, tools) + + self.assertEqual(result, []) + + def test_parse_tool_calls_with_surrounding_text(self): + """测试带有周围文本的工具调用""" + from core.sse_parser import parse_tool_calls + + response_text = '''好的,我来帮你查询天气。 +{"tool_calls": [{"name": "get_weather", "input": {"location": "Shanghai"}}]}''' + tools = [{"name": "get_weather"}] + + result = parse_tool_calls(response_text, tools) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0]["name"], "get_weather") + + def test_parse_tool_calls_empty_input(self): + """测试空输入""" + from core.sse_parser import parse_tool_calls + + result = parse_tool_calls("", []) + self.assertEqual(result, []) + + result = parse_tool_calls("some text", []) + self.assertEqual(result, []) + + def test_parse_tool_calls_invalid_json(self): + """测试无效 JSON""" + from core.sse_parser import parse_tool_calls + + response_text = '{"tool_calls": [{"name": "get_weather", invalid json here}' + tools = [{"name": "get_weather"}] + + result = parse_tool_calls(response_text, tools) + + # 应该返回空列表而不是抛出异常 + self.assertEqual(result, []) + + +class TestTokenEstimation(unittest.TestCase): + """Token 估算测试""" + + def test_estimate_tokens_string(self): + """测试字符串 token 估算""" + from core.utils import estimate_tokens + + # 8个字符应该约等于2个token + result = estimate_tokens("12345678") + self.assertEqual(result, 2) + + # 空字符串应该返回1 + result = estimate_tokens("") + self.assertEqual(result, 1) + + def test_estimate_tokens_list(self): + """测试列表 token 估算""" + from core.utils import estimate_tokens + + content = [ + {"text": "Hello"}, + {"text": "World"} + ] + result = estimate_tokens(content) + self.assertGreater(result, 0) + + if __name__ == "__main__": # 设置环境变量避免配置警告 os.environ.setdefault("DS2API_CONFIG_PATH",