AIPentest/mcp/clientdeepseek.py
AiShell 7bbc02c761
Update clientdeepseek.py
add sse support
2025-04-16 14:10:45 +08:00

225 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
from typing import Optional
from contextlib import AsyncExitStack
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client
from openai import OpenAI
import os
import json
from dotenv import load_dotenv
# Read a local .env file (if present) into the process environment before
# the settings below are looked up.
load_dotenv()

# Connection settings for the OpenAI-compatible endpoint. Each one falls back
# to an empty string when the corresponding environment variable is unset.
api_key = os.environ.get("api_key", "")
base_url = os.environ.get("base_url", "")
modelname = os.environ.get("modelname", "")
class MCPClient:
    """Interactive client that lets an OpenAI-compatible chat model call MCP tools.

    The client connects to one MCP server (over stdio or SSE), advertises the
    server's tools to the chat model, executes the tool calls the model asks
    for and feeds the results back to obtain the final answer.
    """

    def __init__(self, api_key=None, base_url=None, debug=False):
        """Create the client; no MCP connection is opened yet.

        Args:
            api_key: API key passed to the OpenAI client.
            base_url: Base URL passed to the OpenAI client.
            debug: When true, connection failures also print a traceback.
        """
        # Set by connect_to_server() / connect_to_sse_server().
        self.session: Optional[ClientSession] = None
        # Owns every transport/session context so cleanup() can close them all.
        self.exit_stack = AsyncExitStack()
        self.debug = debug
        self.openai = OpenAI(
            api_key=api_key,
            base_url=base_url,
        )
        # Both attribute names are used in this class; they share one client
        # because they were always configured identically.
        self.deepseek_client = self.openai

    async def connect_to_server(self, server_script_path: str):
        """Spawn an MCP server script and connect to it over stdio.

        Args:
            server_script_path: Path to the server script (.py or .js).

        Raises:
            ValueError: If the path ends in neither ``.py`` nor ``.js``.
        """
        is_python = server_script_path.endswith('.py')
        is_js = server_script_path.endswith('.js')
        if not (is_python or is_js):
            raise ValueError("Server script must be a .py or .js file")

        server_params = StdioServerParameters(
            command="python" if is_python else "node",
            args=[server_script_path],
            env=None,
        )
        stdio_transport = await self.exit_stack.enter_async_context(
            stdio_client(server_params)
        )
        self.stdio, self.write = stdio_transport
        self.session = await self.exit_stack.enter_async_context(
            ClientSession(self.stdio, self.write)
        )
        await self.session.initialize()

        # List the available tools to confirm the connection works.
        response = await self.session.list_tools()
        print("\nConnected to server with tools:", [tool.name for tool in response.tools])

    async def connect_to_sse_server(self, server_url: str, timeout=10.0, retries=1):
        """Connect to an MCP server running with SSE transport.

        Args:
            server_url: URL of the server's SSE endpoint.
            timeout: Accepted for interface compatibility; not currently used.
            retries: Accepted for interface compatibility; not currently used.

        Returns:
            True when the session was initialized and the tools were listed,
            False when any step failed (the error is printed).
        """
        try:
            # Register both contexts on the exit stack so that cleanup() closes
            # them; entering them by hand left the connection open forever.
            self._streams_context = sse_client(url=server_url)
            streams = await self.exit_stack.enter_async_context(self._streams_context)
            self._session_context = ClientSession(*streams)
            self.session = await self.exit_stack.enter_async_context(self._session_context)

            await self.session.initialize()

            # List the available tools to verify the connection.
            print("Initialized SSE client...")
            print("Listing tools...")
            response = await self.session.list_tools()
            print("\nConnected to server with tools:", [tool.name for tool in response.tools])
            return True
        except Exception as e:
            print(f"Connection error: {str(e)}")
            if self.debug:
                import traceback
                traceback.print_exc()
            return False

    @staticmethod
    def _parse_tool_args(tool_args):
        """Normalize the model-supplied tool arguments into a dict."""
        if isinstance(tool_args, str):
            try:
                # Arguments normally arrive as a JSON-encoded object.
                tool_args = json.loads(tool_args)
            except json.JSONDecodeError:
                # Not valid JSON: pass the raw string through under "input".
                tool_args = {"input": tool_args}
        # Anything that still is not a dict (e.g. a JSON list) gets wrapped.
        if not isinstance(tool_args, dict):
            tool_args = {"value": str(tool_args)}
        return tool_args

    @staticmethod
    def _tool_result_text(result_content):
        """Flatten a tool result's content into one string for the chat model."""
        if isinstance(result_content, str):
            return result_content
        if isinstance(result_content, list):
            # Use each item's ``text`` when it has one, otherwise its string
            # form; an empty list yields an empty string instead of failing.
            return "\n".join(
                item.text if hasattr(item, 'text') else str(item)
                for item in result_content
            )
        return str(result_content)

    async def process_query(self, query: str) -> str:
        """Answer a query with the chat model, running any tools it requests.

        Args:
            query: The user's message.

        Returns:
            The model's text, a ``[Calling tool ...]`` line per executed tool
            and the model's follow-up answer, joined by newlines.

        Raises:
            RuntimeError: If no MCP session has been established yet.
        """
        if self.session is None:
            raise RuntimeError(
                "Not connected to an MCP server; call connect_to_server() "
                "or connect_to_sse_server() first"
            )

        messages = [{"role": "user", "content": query}]

        response = await self.session.list_tools()
        available_tools = [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.inputSchema,
                },
            }
            for tool in response.tools
        ]
        print(f"available_tools: {available_tools}")

        # The official DeepSeek API uses the model name "deepseek-chat"; the
        # actual value comes from the module-level ``modelname`` setting.
        request = {"model": modelname, "messages": messages, "temperature": 0}
        if available_tools:
            # Only advertise tools when there are some: the chat completions
            # API rejects an empty ``tools`` array.
            request["tools"] = available_tools
            request["tool_choice"] = "auto"
        response = self.deepseek_client.chat.completions.create(**request)
        print(response)

        final_text = []
        assistant_message = response.choices[0].message
        print(f"assistant_message: {assistant_message}")
        final_text.append(assistant_message.content or "")

        tool_calls = getattr(assistant_message, 'tool_calls', None)
        print(f"assistant_message.tool_calls: {tool_calls}")
        if not tool_calls:
            return "\n".join(final_text)

        # The assistant turn goes into the history exactly once, followed by
        # one "tool" message per tool_call id; the API requires every id to be
        # answered before the next completion request.
        messages.append(assistant_message)
        for tool_call in tool_calls:
            tool_name = tool_call.function.name
            print(f"tool_name: {tool_name}")
            print(f"tool_args: {tool_call.function.arguments}")
            tool_args = self._parse_tool_args(tool_call.function.arguments)
            print(f"converted tool_args: {tool_args}")

            result = await self.session.call_tool(tool_name, tool_args)
            print(f"result: {result}")
            final_text.append(f"[Calling tool {tool_name} with args {tool_args}]")

            messages.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": tool_name,
                "content": self._tool_result_text(result.content),
            })
        print(f"messages: {messages}")

        # Ask the same model to continue now that all tool results are known.
        response = self.openai.chat.completions.create(
            model=modelname,
            messages=messages,
        )
        print(f"response: {response}")
        final_text.append(response.choices[0].message.content or "")
        return "\n".join(final_text)

    async def chat_loop(self):
        """Run an interactive prompt until the user types 'quit' or stdin ends."""
        print("\nMCP Client Started!")
        print("Type your queries or 'quit' to exit.")
        while True:
            try:
                query = input("\nQuery: ").strip()
            except EOFError:
                # stdin was closed; retrying would fail the same way forever.
                print()
                break
            if query.lower() == 'quit':
                break
            if not query:
                # Nothing to send to the model.
                continue
            try:
                response = await self.process_query(query)
                print("\n" + response)
            except Exception as e:
                # Keep the session alive after a failed query.
                print(f"\nError: {str(e)}")

    async def cleanup(self):
        """Close every transport and session registered on the exit stack."""
        await self.exit_stack.aclose()
async def main():
    """Command-line entry point: connect to an MCP server and start chatting.

    The first command-line argument selects the server. A path ending in
    ``.py`` or ``.js`` is launched as a stdio server script; anything else is
    treated as the URL of an SSE server. Exits with status 1 when the argument
    is missing or the SSE connection cannot be established.
    """
    # Imported locally so main() also works when this module is imported and
    # the coroutine is started from elsewhere.
    import sys

    if len(sys.argv) < 2:
        print(f"Usage: python {sys.argv[0]} <path_to_server_script | sse_server_url>")
        sys.exit(1)

    target = sys.argv[1]
    # The API key and base URL come from the module-level settings.
    client = MCPClient(api_key=api_key, base_url=base_url)
    try:
        if target.endswith(('.py', '.js')):
            await client.connect_to_server(target)
        elif not await client.connect_to_sse_server(target):
            # The failure reason was already printed; do not start a chat
            # loop without a session. cleanup() still runs via ``finally``.
            sys.exit(1)
        await client.chat_loop()
    finally:
        await client.cleanup()
if __name__ == "__main__":
    # Script entry point: sys is only pulled in when the file is executed
    # directly, then the async main() is driven to completion.
    import sys

    entry_point = main()
    asyncio.run(entry_point)