mirror of
https://github.com/SunZhimin2021/AIPentest.git
synced 2025-11-05 19:04:12 +00:00
225 lines
8.5 KiB
Python
225 lines
8.5 KiB
Python
import asyncio
|
||
from typing import Optional
|
||
from contextlib import AsyncExitStack
|
||
|
||
from mcp import ClientSession, StdioServerParameters
|
||
from mcp.client.stdio import stdio_client
|
||
from mcp.client.sse import sse_client
|
||
|
||
from openai import OpenAI
|
||
import os
|
||
import json
|
||
from dotenv import load_dotenv
|
||
|
||
# Merge variables from a local .env file into the process environment.
load_dotenv()

# Settings for the OpenAI-compatible endpoint, read from the environment.
# Each one falls back to an empty string when its variable is not set.
api_key, base_url, modelname = (
    os.getenv(variable, '') for variable in ('api_key', 'base_url', 'modelname')
)
|
||
class MCPClient:
    """Chat client that lets an OpenAI-compatible model call tools on an MCP server.

    The client connects to an MCP server (over stdio or SSE), advertises the
    server's tools to the chat model, executes the tool calls the model asks
    for, and feeds the tool output back to the model for a final answer.
    """

    def __init__(self, api_key=None, base_url=None, debug=False):
        """Create the client without connecting to any MCP server yet.

        Args:
            api_key: API key passed to the OpenAI client.
            base_url: Base URL of the OpenAI-compatible endpoint.
            debug: When true, connection failures also print a traceback.
        """
        # Initialize session and client objects
        self.session: Optional[ClientSession] = None
        self.exit_stack = AsyncExitStack()
        # Read by connect_to_sse_server's error handler; it must always exist.
        self.debug = debug

        # Both attribute names are kept for existing callers; they were built
        # from identical arguments, so a single shared client is sufficient.
        self.openai = OpenAI(api_key=api_key, base_url=base_url)
        self.deepseek_client = self.openai

    async def connect_to_server(self, server_script_path: str):
        """Connect to an MCP server launched as a subprocess over stdio.

        Args:
            server_script_path: Path to the server script (.py or .js)

        Raises:
            ValueError: If the path ends with neither ``.py`` nor ``.js``.
        """
        is_python = server_script_path.endswith('.py')
        is_js = server_script_path.endswith('.js')
        if not (is_python or is_js):
            raise ValueError("Server script must be a .py or .js file")

        command = "python" if is_python else "node"
        server_params = StdioServerParameters(
            command=command,
            args=[server_script_path],
            env=None,
        )

        # Register both contexts on the exit stack so cleanup() closes them.
        stdio_transport = await self.exit_stack.enter_async_context(
            stdio_client(server_params)
        )
        self.stdio, self.write = stdio_transport
        self.session = await self.exit_stack.enter_async_context(
            ClientSession(self.stdio, self.write)
        )

        await self.session.initialize()

        # List available tools
        response = await self.session.list_tools()
        tools = response.tools
        print("\nConnected to server with tools:", [tool.name for tool in tools])

    async def connect_to_sse_server(self, server_url: str, timeout=10.0, retries=1):
        """Connect to an MCP server running with SSE transport.

        Args:
            server_url: URL of the server's SSE endpoint.
            timeout: Accepted for interface compatibility; currently unused.
            retries: Accepted for interface compatibility; currently unused.

        Returns:
            True when the session was initialized and the tools were listed,
            False when any step raised an exception.
        """
        try:
            # Enter the contexts through the exit stack so that cleanup()
            # closes them, including after a partially failed connection.
            streams = await self.exit_stack.enter_async_context(
                sse_client(url=server_url)
            )
            session = await self.exit_stack.enter_async_context(
                ClientSession(*streams)
            )

            # Initialize
            await session.initialize()

            # List available tools to verify connection
            print("Initialized SSE client...")
            print("Listing tools...")
            response = await session.list_tools()
            tools = response.tools
            print("\nConnected to server with tools:", [tool.name for tool in tools])

            # Publish the session only once it is known to be usable.
            self.session = session
            return True

        except Exception as e:
            print(f"Connection error: {str(e)}")
            if self.debug:
                import traceback
                traceback.print_exc()
            return False

    @staticmethod
    def _parse_tool_args(raw_args):
        """Normalize the model-supplied tool arguments into a dict.

        Missing arguments (``None``, blank string, JSON ``null``) become an
        empty dict. A string that is not valid JSON is wrapped as
        ``{"input": <string>}``; any other non-dict value is wrapped as
        ``{"value": str(<value>)}``.
        """
        if raw_args is None:
            return {}
        if isinstance(raw_args, str):
            if not raw_args.strip():
                return {}
            try:
                # Try to parse the JSON string into a dictionary.
                raw_args = json.loads(raw_args)
            except json.JSONDecodeError:
                # Not JSON: keep the original string under a single key.
                return {"input": raw_args}
            if raw_args is None:
                return {}
        # Final type check: anything that is still not a dict gets wrapped.
        if not isinstance(raw_args, dict):
            return {"value": str(raw_args)}
        return raw_args

    @staticmethod
    def _extract_result_text(content):
        """Flatten a tool result's ``content`` into a string for the chat API.

        A list contributes the ``text`` of every item that has a string
        ``text`` attribute, joined with newlines. An empty list yields an
        empty string, and a list with no text items falls back to
        ``str(content)``. Non-list, non-string values are converted with
        ``str``.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            if not content:
                return ""
            texts = [
                item.text
                for item in content
                if isinstance(getattr(item, "text", None), str)
            ]
            return "\n".join(texts) if texts else str(content)
        return str(content)

    async def process_query(self, query: str) -> str:
        """Process a query using OpenAI and available tools.

        The model is called once with the server's tools. If it requests tool
        calls, every call is executed and all results are sent back in a
        single follow-up request.

        Args:
            query: The user's message.

        Returns:
            The model's text, the tool-call notes and the follow-up answer,
            joined with newlines.

        Raises:
            RuntimeError: If no MCP session has been established.
        """
        if self.session is None:
            raise RuntimeError(
                "Not connected to an MCP server; call connect_to_server() "
                "or connect_to_sse_server() first"
            )

        messages = [
            {
                "role": "user",
                "content": query
            }
        ]

        response = await self.session.list_tools()
        available_tools = [{
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.inputSchema
            }
        } for tool in response.tools]
        print(f"available_tools: {available_tools}")

        # Initial chat completion. The tool fields are only sent when there
        # is at least one tool, so an empty tools list is never submitted.
        request = {
            "model": modelname,
            "messages": messages,
            "temperature": 0,
        }
        if available_tools:
            request["tools"] = available_tools
            request["tool_choice"] = "auto"
        response = self.deepseek_client.chat.completions.create(**request)
        print(response)

        # Process response and handle tool calls
        final_text = []
        assistant_message = response.choices[0].message
        print(f"assistant_message: {assistant_message}")
        final_text.append(assistant_message.content or "")
        print(f"assistant_message.tool_calls: {assistant_message.tool_calls}")

        tool_calls = assistant_message.tool_calls
        if not tool_calls:
            return "\n".join(final_text)

        # The assistant message is recorded exactly once, followed by one
        # "tool" message per requested call, so every tool_call_id is
        # answered before the model is asked to continue.
        messages.append(assistant_message)
        for tool_call in tool_calls:
            tool_name = tool_call.function.name
            raw_args = tool_call.function.arguments
            print(f"tool_name: {tool_name}")
            print(f"tool_args: {raw_args}")

            tool_args = self._parse_tool_args(raw_args)
            print(f"converted tool_args: {tool_args}")

            # Execute tool call
            result = await self.session.call_tool(tool_name, tool_args)
            print(f"result: {result}")
            final_text.append(f"[Calling tool {tool_name} with args {tool_args}]")

            messages.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": tool_name,
                "content": self._extract_result_text(result.content)
            })

        print(f"messages: {messages}")
        # Follow-up completion with the tool results, on the same configured
        # model as the first request.
        response = self.openai.chat.completions.create(
            model=modelname,
            messages=messages,
        )
        print(f"response: {response}")
        final_text.append(response.choices[0].message.content or "")

        return "\n".join(final_text)

    async def chat_loop(self):
        """Run an interactive chat loop until 'quit' or end of input."""
        print("\nMCP Client Started!")
        print("Type your queries or 'quit' to exit.")

        while True:
            try:
                query = input("\nQuery: ").strip()
            except EOFError:
                # stdin is closed; retrying would raise again forever.
                break

            if query.lower() == 'quit':
                break
            if not query:
                # Nothing to send; prompt again.
                continue

            try:
                response = await self.process_query(query)
                print("\n" + response)
            except Exception as e:
                # Keep the session alive after a failed query.
                print(f"\nError: {str(e)}")

    async def cleanup(self):
        """Clean up resources"""
        await self.exit_stack.aclose()
        # The session's context has been exited; drop the stale reference.
        self.session = None
|
||
|
||
async def main():
    """Connect to the MCP server named on the command line and start chatting.

    ``sys.argv[1]`` is treated as a server script when it ends with ``.py``
    or ``.js`` (stdio transport) and as an SSE server URL otherwise. The
    process exits with status 1 when the argument is missing or the SSE
    connection fails. Client resources are released on every exit path.
    """
    # Imported locally because the module only imports sys under the
    # __main__ guard; this keeps main() usable when the module is imported.
    import sys

    if len(sys.argv) < 2:
        print("Usage: python dsclient.py <path_to_server_script> ")
        sys.exit(1)

    target = sys.argv[1]

    client = MCPClient(api_key=api_key, base_url=base_url)
    try:
        # connect_to_server supports both .py and .js scripts, so both are
        # routed to the stdio transport.
        if target.endswith(('.py', '.js')):
            await client.connect_to_server(target)
        elif not await client.connect_to_sse_server(target):
            # A chat loop without a session would fail on every query.
            sys.exit(1)
        await client.chat_loop()
    finally:
        await client.cleanup()
|
||
|
||
# Script entry point: runs only when this file is executed directly.
if __name__ == "__main__":
    import sys  # bound at module level, where main() looks up sys.argv

    entry_coroutine = main()
    asyncio.run(entry_coroutine)
|