Update clientdeepseek.py

add sse support
This commit is contained in:
AiShell 2025-04-16 14:10:45 +08:00 committed by GitHub
parent f7ede79533
commit 7bbc02c761
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,9 +4,11 @@ from contextlib import AsyncExitStack
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client
from openai import OpenAI
import os
import json
from dotenv import load_dotenv
load_dotenv() # load environment variables from .env
@ -24,6 +26,7 @@ class MCPClient:
api_key=api_key ,
base_url=base_url ,
)
self.deepseek_client = OpenAI(api_key=api_key,base_url=base_url)
async def connect_to_server(self, server_script_path: str):
"""Connect to an MCP server
@ -54,6 +57,34 @@ class MCPClient:
tools = response.tools
print("\nConnected to server with tools:", [tool.name for tool in tools])
async def connect_to_sse_server(self, server_url: str, timeout=10.0, retries=1):
"""Connect to an MCP server running with SSE transport"""
try:
# Store the context managers so they stay alive
self._streams_context = sse_client(url=server_url)
streams = await self._streams_context.__aenter__()
self._session_context = ClientSession(*streams)
self.session: ClientSession = await self._session_context.__aenter__()
# Initialize
await self.session.initialize()
# List available tools to verify connection
print("Initialized SSE client...")
print("Listing tools...")
response = await self.session.list_tools()
tools = response.tools
print("\nConnected to server with tools:", [tool.name for tool in tools])
return True
except Exception as e:
print(f"Connection error: {str(e)}")
if self.debug:
import traceback
traceback.print_exc()
return False
async def process_query(self, query: str) -> str:
"""Process a query using OpenAI and available tools"""
messages = [
@ -72,39 +103,55 @@ class MCPClient:
"parameters": tool.inputSchema
}
} for tool in response.tools]
print(f"available_tools: {available_tools}")
# Initial OpenAI API call
response = self.openai.chat.completions.create(
model="deepseek-chat", # Use appropriate OpenAI model
messages=messages,
tools=available_tools,
tool_choice="auto"
)
#print(response)
#response = self.openai.chat.completions.create(
# model="deepseek-chat", # Use appropriate OpenAI model
# messages=messages,
# tools=available_tools,
# tool_choice="auto"
#)
response = self.deepseek_client.chat.completions.create(
model=modelname, # deepseek官网用的是deepseek-chat
messages=messages,
tools=available_tools,
tool_choice="auto",
temperature=0
)
print(response)
# Process response and handle tool calls
final_text = []
assistant_message = response.choices[0].message
print(f"assistant_message: {assistant_message}")
final_text.append(assistant_message.content or "")
print (assistant_message)
print (assistant_message.tool_calls)
#print (f"assistant_message: {assistant_message}")
print (f"assistant_message.tool_calls: {assistant_message.tool_calls}")
# Handle tool calls if present
if hasattr(assistant_message, 'tool_calls') and assistant_message.tool_calls:
for tool_call in assistant_message.tool_calls:
tool_name = tool_call.function.name
tool_args = tool_call.function.arguments
print(tool_name)
print(tool_args)
# Convert string arguments to JSON if needed
import json
print(f"tool_name: {tool_name}")
print(f"tool_args: {tool_args}")
if isinstance(tool_args, str):
try:
# 尝试解析JSON字符串成字典
tool_args = json.loads(tool_args)
except:
pass
print (tool_args)
except json.JSONDecodeError:
# 如果无法解析为JSON创建一个包含原始字符串的字典
tool_args = {"input": tool_args}
# 再次确认类型,如果仍然不是字典,则创建一个包含它的字典
if not isinstance(tool_args, dict):
tool_args = {"value": str(tool_args)}
print(f"converted tool_args: {tool_args}")
# Convert string arguments to JSON if needed
# Execute tool call
result = await self.session.call_tool(tool_name, tool_args)
print(result)
print(f"result: {result}")
final_text.append(f"[Calling tool {tool_name} with args {tool_args}]")
result_content = result.content
if isinstance(result_content, list):
@ -121,13 +168,13 @@ class MCPClient:
"name": tool_name,
"content": result_content
})
print(messages)
print(f"messages: {messages}")
# Get next response from OpenAI
response = self.openai.chat.completions.create(
model="deepseek-chat", # Use appropriate OpenAI model
messages=messages,
)
print(response)
print(f"response: {response}")
final_text.append(response.choices[0].message.content or "")
return "\n".join(final_text)
@ -156,16 +203,18 @@ class MCPClient:
async def main():
if len(sys.argv) < 2:
print("Usage: python client.py <path_to_server_script> [api_key] [base_url]")
print("Usage: python dsclient.py <path_to_server_script> ")
sys.exit(1)
# Optional command line arguments for API key and base URL
api_key = sys.argv[2] if len(sys.argv) > 2 else None
base_url = sys.argv[3] if len(sys.argv) > 3 else None
client = MCPClient(api_key=api_key, base_url=base_url)
try:
await client.connect_to_server(sys.argv[1])
if sys.argv[1].endswith('.py'):
await client.connect_to_server(sys.argv[1])
else:
await client.connect_to_sse_server(sys.argv[1])
await client.chat_loop()
finally:
await client.cleanup()