mirror of
https://github.com/SunZhimin2021/AIPentest.git
synced 2025-05-05 10:06:57 +00:00
Update clientdeepseek.py
Add SSE support
This commit is contained in:
parent
f7ede79533
commit
7bbc02c761
@ -4,9 +4,11 @@ from contextlib import AsyncExitStack
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from openai import OpenAI
|
||||
import os
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv() # load environment variables from .env
|
||||
@ -24,6 +26,7 @@ class MCPClient:
|
||||
api_key=api_key ,
|
||||
base_url=base_url ,
|
||||
)
|
||||
self.deepseek_client = OpenAI(api_key=api_key,base_url=base_url)
|
||||
|
||||
async def connect_to_server(self, server_script_path: str):
|
||||
"""Connect to an MCP server
|
||||
@ -54,6 +57,34 @@ class MCPClient:
|
||||
tools = response.tools
|
||||
print("\nConnected to server with tools:", [tool.name for tool in tools])
|
||||
|
||||
async def connect_to_sse_server(self, server_url: str, timeout=10.0, retries=1):
    """Connect to an MCP server running with SSE transport.

    Enters the SSE stream and the client session context managers
    manually (via ``__aenter__``) and stores the context objects on
    ``self`` so they stay alive for the lifetime of the connection;
    their ``__aexit__`` must be called later during cleanup.

    Args:
        server_url: URL of the SSE endpoint to connect to.
        timeout: Currently unused — TODO: wire into the SSE client or
            retry logic; kept for interface compatibility.
        retries: Currently unused — TODO: implement reconnection
            attempts; kept for interface compatibility.

    Returns:
        True when the session initialized and tools were listed
        successfully; False on any failure (the error is printed,
        not raised).
    """
    try:
        # Store the context managers so they stay alive; cleanup is
        # responsible for exiting them.
        self._streams_context = sse_client(url=server_url)
        streams = await self._streams_context.__aenter__()

        self._session_context = ClientSession(*streams)
        self.session: ClientSession = await self._session_context.__aenter__()

        # Perform the MCP handshake.
        await self.session.initialize()

        # List available tools to verify the connection end-to-end.
        print("Initialized SSE client...")
        print("Listing tools...")
        response = await self.session.list_tools()
        tools = response.tools
        print("\nConnected to server with tools:", [tool.name for tool in tools])
        return True

    except Exception as e:
        print(f"Connection error: {str(e)}")
        # Guarded attribute read: the original `self.debug` raised
        # AttributeError when the attribute was never set, masking the
        # real connection error with a secondary failure.
        if getattr(self, "debug", False):
            import traceback
            traceback.print_exc()
        return False
|
||||
|
||||
async def process_query(self, query: str) -> str:
|
||||
"""Process a query using OpenAI and available tools"""
|
||||
messages = [
|
||||
@ -72,39 +103,55 @@ class MCPClient:
|
||||
"parameters": tool.inputSchema
|
||||
}
|
||||
} for tool in response.tools]
|
||||
|
||||
print(f"available_tools: {available_tools}")
|
||||
# Initial OpenAI API call
|
||||
response = self.openai.chat.completions.create(
|
||||
model="deepseek-chat", # Use appropriate OpenAI model
|
||||
messages=messages,
|
||||
tools=available_tools,
|
||||
tool_choice="auto"
|
||||
)
|
||||
#print(response)
|
||||
#response = self.openai.chat.completions.create(
|
||||
# model="deepseek-chat", # Use appropriate OpenAI model
|
||||
# messages=messages,
|
||||
# tools=available_tools,
|
||||
# tool_choice="auto"
|
||||
#)
|
||||
response = self.deepseek_client.chat.completions.create(
|
||||
model=modelname, # deepseek官网用的是deepseek-chat
|
||||
messages=messages,
|
||||
tools=available_tools,
|
||||
tool_choice="auto",
|
||||
temperature=0
|
||||
)
|
||||
print(response)
|
||||
# Process response and handle tool calls
|
||||
final_text = []
|
||||
assistant_message = response.choices[0].message
|
||||
print(f"assistant_message: {assistant_message}")
|
||||
final_text.append(assistant_message.content or "")
|
||||
print (assistant_message)
|
||||
print (assistant_message.tool_calls)
|
||||
#print (f"assistant_message: {assistant_message}")
|
||||
print (f"assistant_message.tool_calls: {assistant_message.tool_calls}")
|
||||
# Handle tool calls if present
|
||||
if hasattr(assistant_message, 'tool_calls') and assistant_message.tool_calls:
|
||||
for tool_call in assistant_message.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = tool_call.function.arguments
|
||||
print(tool_name)
|
||||
print(tool_args)
|
||||
# Convert string arguments to JSON if needed
|
||||
import json
|
||||
print(f"tool_name: {tool_name}")
|
||||
print(f"tool_args: {tool_args}")
|
||||
if isinstance(tool_args, str):
|
||||
try:
|
||||
# 尝试解析JSON字符串成字典
|
||||
tool_args = json.loads(tool_args)
|
||||
except:
|
||||
pass
|
||||
print (tool_args)
|
||||
except json.JSONDecodeError:
|
||||
# 如果无法解析为JSON,创建一个包含原始字符串的字典
|
||||
tool_args = {"input": tool_args}
|
||||
|
||||
# 再次确认类型,如果仍然不是字典,则创建一个包含它的字典
|
||||
if not isinstance(tool_args, dict):
|
||||
tool_args = {"value": str(tool_args)}
|
||||
|
||||
print(f"converted tool_args: {tool_args}")
|
||||
# Convert string arguments to JSON if needed
|
||||
|
||||
|
||||
# Execute tool call
|
||||
result = await self.session.call_tool(tool_name, tool_args)
|
||||
print(result)
|
||||
print(f"result: {result}")
|
||||
final_text.append(f"[Calling tool {tool_name} with args {tool_args}]")
|
||||
result_content = result.content
|
||||
if isinstance(result_content, list):
|
||||
@ -121,13 +168,13 @@ class MCPClient:
|
||||
"name": tool_name,
|
||||
"content": result_content
|
||||
})
|
||||
print(messages)
|
||||
print(f"messages: {messages}")
|
||||
# Get next response from OpenAI
|
||||
response = self.openai.chat.completions.create(
|
||||
model="deepseek-chat", # Use appropriate OpenAI model
|
||||
messages=messages,
|
||||
)
|
||||
print(response)
|
||||
print(f"response: {response}")
|
||||
final_text.append(response.choices[0].message.content or "")
|
||||
|
||||
return "\n".join(final_text)
|
||||
@ -156,16 +203,18 @@ class MCPClient:
|
||||
|
||||
async def main():
    """CLI entry point: connect to an MCP server and run the chat loop.

    The first argument selects the transport: a path ending in ``.py``
    is launched as a local stdio server script; anything else is
    treated as an SSE server URL. Optional second and third arguments
    override the API key and base URL.

    Note: this reconstructs the post-commit version of the function —
    the scraped diff had merged the removed lines (old usage string,
    unconditional ``connect_to_server`` call) into the body, which
    would have connected twice and printed a stale usage message.
    """
    if len(sys.argv) < 2:
        print("Usage: python dsclient.py <path_to_server_script> ")
        sys.exit(1)

    # Optional command line arguments for API key and base URL
    api_key = sys.argv[2] if len(sys.argv) > 2 else None
    base_url = sys.argv[3] if len(sys.argv) > 3 else None

    client = MCPClient(api_key=api_key, base_url=base_url)
    try:
        # A .py path spawns a stdio server; otherwise assume an SSE URL.
        if sys.argv[1].endswith('.py'):
            await client.connect_to_server(sys.argv[1])
        else:
            await client.connect_to_sse_server(sys.argv[1])
        await client.chat_loop()
    finally:
        # Always release sessions/transports, even if the chat loop raised.
        await client.cleanup()
|
||||
|
Loading…
x
Reference in New Issue
Block a user