mirror of
https://github.com/SunZhimin2021/AIPentest.git
synced 2025-05-05 18:17:02 +00:00
Create clientdeepseek.py
init
This commit is contained in:
parent
f6bd8d46fe
commit
549ed3517f
175
mcp/clientdeepseek.py
Normal file
175
mcp/clientdeepseek.py
Normal file
@ -0,0 +1,175 @@
|
||||
import asyncio
import json
import os
import sys
from contextlib import AsyncExitStack
from typing import Optional

from dotenv import load_dotenv
from openai import OpenAI

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
|
||||
|
||||
# Pull configuration from a .env file into the process environment.
load_dotenv()

# Endpoint settings for the OpenAI-compatible (DeepSeek) API; each defaults
# to the empty string when the variable is absent.
api_key = os.getenv('api_key', '')
base_url = os.getenv('base_url', '')
modelname = os.getenv('modelname', '')
|
||||
class MCPClient:
    """Bridge between an OpenAI-compatible chat API (DeepSeek) and an MCP server.

    The client launches a local MCP server over stdio, advertises the server's
    tools to the chat model, executes any tool calls the model requests, and
    feeds the results back to the model for a final answer.
    """

    def __init__(self, api_key=None, base_url=None):
        """Initialize the client.

        Args:
            api_key: Key for the OpenAI-compatible endpoint (None lets the
                OpenAI SDK fall back to its own environment lookup).
            base_url: Base URL of the endpoint (e.g. the DeepSeek API URL).
        """
        # MCP session is established later by connect_to_server().
        self.session: Optional[ClientSession] = None
        # Owns the stdio transport and session lifetimes; closed in cleanup().
        self.exit_stack = AsyncExitStack()

        # DeepSeek exposes an OpenAI-compatible surface, so the stock SDK works.
        self.openai = OpenAI(
            api_key=api_key,
            base_url=base_url,
        )

    async def connect_to_server(self, server_script_path: str):
        """Launch an MCP server over stdio and connect to it.

        Args:
            server_script_path: Path to the server script (.py or .js)

        Raises:
            ValueError: If the script is neither a .py nor a .js file.
        """
        is_python = server_script_path.endswith('.py')
        is_js = server_script_path.endswith('.js')
        if not (is_python or is_js):
            raise ValueError("Server script must be a .py or .js file")

        # Pick the interpreter matching the script type.
        command = "python" if is_python else "node"
        server_params = StdioServerParameters(
            command=command,
            args=[server_script_path],
            env=None
        )

        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
        self.stdio, self.write = stdio_transport
        self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))

        await self.session.initialize()

        # Show the user which tools the server offers.
        response = await self.session.list_tools()
        tools = response.tools
        print("\nConnected to server with tools:", [tool.name for tool in tools])

    async def process_query(self, query: str) -> str:
        """Process one user query, executing any tool calls the model requests.

        Returns:
            Assistant text, interleaved with "[Calling tool ...]" notices.
        """
        messages = [
            {
                "role": "user",
                "content": query
            }
        ]

        # Advertise the MCP server's tools in OpenAI function-calling format.
        response = await self.session.list_tools()
        available_tools = [{
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.inputSchema
            }
        } for tool in response.tools]

        # FIX: use the configured model name from the environment (it was
        # loaded at module level but never used); fall back to the original
        # hard-coded default for backward compatibility.
        model = modelname or "deepseek-chat"

        # Initial chat completion.
        response = self.openai.chat.completions.create(
            model=model,
            messages=messages,
            tools=available_tools,
            tool_choice="auto"
        )

        final_text = []
        assistant_message = response.choices[0].message
        final_text.append(assistant_message.content or "")

        # Handle tool calls if present.
        if getattr(assistant_message, 'tool_calls', None):
            # FIX: append the assistant turn exactly once, before the tool
            # results. The original appended it inside the per-call loop,
            # duplicating it whenever the model issued several tool calls.
            messages.append(assistant_message)

            for tool_call in assistant_message.tool_calls:
                tool_name = tool_call.function.name
                tool_args = tool_call.function.arguments

                # Arguments usually arrive as a JSON string; decode when valid.
                if isinstance(tool_args, str):
                    try:
                        tool_args = json.loads(tool_args)
                    except json.JSONDecodeError:
                        # FIX: narrowed from a bare `except:`; the raw string
                        # passes through unchanged when it is not valid JSON.
                        pass

                # Execute the tool on the MCP server.
                result = await self.session.call_tool(tool_name, tool_args)
                final_text.append(f"[Calling tool {tool_name} with args {tool_args}]")

                result_content = result.content
                if isinstance(result_content, list):
                    # Lists may carry TextContent objects; extract the text.
                    result_content = result_content[0].text if hasattr(result_content[0], 'text') else str(result_content)
                elif not isinstance(result_content, str):
                    # Anything else: coerce to a string for the tool message.
                    result_content = str(result_content)

                messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "name": tool_name,
                    "content": result_content
                })

            # FIX: one follow-up completion after ALL tool results are in
            # (the original re-queried the model inside the loop, once per
            # tool call, with an incomplete transcript).
            response = self.openai.chat.completions.create(
                model=model,
                messages=messages,
            )
            final_text.append(response.choices[0].message.content or "")

        return "\n".join(final_text)

    async def chat_loop(self):
        """Run an interactive chat loop until the user types 'quit'."""
        print("\nMCP Client Started!")
        print("Type your queries or 'quit' to exit.")

        while True:
            try:
                query = input("\nQuery: ").strip()

                if query.lower() == 'quit':
                    break

                response = await self.process_query(query)
                print("\n" + response)

            except Exception as e:
                # Keep the loop alive on per-query failures.
                print(f"\nError: {str(e)}")

    async def cleanup(self):
        """Release the MCP session and stdio transport."""
        await self.exit_stack.aclose()
|
||||
|
||||
async def main():
    """CLI entry point: connect to the given MCP server and start the chat loop.

    Usage: python clientdeepseek.py <path_to_server_script> [api_key] [base_url]
    """
    if len(sys.argv) < 2:
        print("Usage: python client.py <path_to_server_script> [api_key] [base_url]")
        sys.exit(1)

    # Optional command-line overrides for API key and base URL.
    # FIX: fall back to the module-level values loaded from .env — the
    # original passed None when no CLI args were given, so the .env
    # credentials were silently ignored. Empty strings are mapped to None
    # so the OpenAI SDK's own environment lookup still applies.
    cli_api_key = sys.argv[2] if len(sys.argv) > 2 else (api_key or None)
    cli_base_url = sys.argv[3] if len(sys.argv) > 3 else (base_url or None)

    client = MCPClient(api_key=cli_api_key, base_url=cli_base_url)
    try:
        await client.connect_to_server(sys.argv[1])
        await client.chat_loop()
    finally:
        # Always release the stdio transport and session, even on error.
        await client.cleanup()
|
||||
|
||||
if __name__ == "__main__":
    # FIX: `sys` is now imported at the top of the file rather than here —
    # main() depended on this late, guard-local import binding the global.
    asyncio.run(main())
|
Loading…
x
Reference in New Issue
Block a user