Mirror of https://github.com/SunZhimin2021/AIPentest.git, synced 2025-06-21 10:21:03 +00:00
Add files via upload
This commit is contained in:
parent be0bd74e27
commit a33b34e720
159 Wargame/common/llmcaller.py Normal file
@@ -0,0 +1,159 @@
import os
from typing import Optional, Dict, Any
import requests
import json


class LLMCaller:
    """Unified LLM caller class supporting multiple model services."""

    def __init__(self, provider: str, api_key: Optional[str] = None,
                 base_url: Optional[str] = None, model: Optional[str] = None):
        """
        Initialize the LLM caller.

        Args:
            provider: service provider ('deepseek', 'claude', 'openai', 'qwen', 'ollama')
            api_key: API key (not needed for ollama)
            base_url: base API URL (optional, for custom endpoints)
            model: model name (optional; a per-provider default is used otherwise)
        """
        self.provider = provider.lower()
        self.api_key = api_key or os.environ.get(f"{provider.upper()}_API_KEY")

        # Default configuration for each supported provider
        self.configs = {
            'deepseek': {
                'base_url': base_url or 'https://api.deepseek.com/v1',
                'model': model or 'deepseek-chat',
                'endpoint': '/chat/completions'
            },
            'claude': {
                'base_url': base_url or 'https://api.anthropic.com/v1',
                'model': model or 'claude-3-5-sonnet-20241022',
                'endpoint': '/messages'
            },
            'openai': {
                'base_url': base_url or 'https://api.openai.com/v1',
                'model': model or 'gpt-4o-mini',
                'endpoint': '/chat/completions'
            },
            'qwen': {
                'base_url': base_url or 'https://dashscope.aliyuncs.com/compatible-mode/v1',
                'model': model or 'qwen-plus',
                'endpoint': '/chat/completions'
            },
            'ollama': {
                'base_url': base_url or 'http://localhost:11434',
                'model': model or 'llama3.2',
                'endpoint': '/api/chat'
            }
        }

        if self.provider not in self.configs:
            raise ValueError(f"Unsupported provider: {self.provider}")

        self.config = self.configs[self.provider]

    def call(self, system_prompt: str, user_prompt: str, model: Optional[str] = None, **kwargs) -> str:
        """
        Call the LLM.

        Args:
            system_prompt: system prompt
            user_prompt: user prompt
            model: model name (optional; the default model is used if omitted)
            **kwargs: additional parameters (temperature, max_tokens, etc.)

        Returns:
            The LLM's response text
        """
        # If a model was passed in, use it temporarily for this call
        original_model = self.config['model']
        if model:
            self.config['model'] = model

        try:
            if self.provider == 'claude':
                return self._call_claude(system_prompt, user_prompt, **kwargs)
            elif self.provider == 'ollama':
                return self._call_ollama(system_prompt, user_prompt, **kwargs)
            else:
                return self._call_openai_compatible(system_prompt, user_prompt, **kwargs)
        finally:
            # Restore the original model setting
            self.config['model'] = original_model

    def _call_openai_compatible(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
        """Call an OpenAI-compatible API (OpenAI, DeepSeek, Qwen)."""
        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.api_key}'
        }

        data = {
            'model': self.config['model'],
            'messages': [
                {'role': 'system', 'content': system_prompt},
                {'role': 'user', 'content': user_prompt}
            ],
            **kwargs
        }

        response = requests.post(
            f"{self.config['base_url']}{self.config['endpoint']}",
            headers=headers,
            json=data
        )
        response.raise_for_status()

        return response.json()['choices'][0]['message']['content']

    def _call_claude(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
        """Call the Claude Messages API."""
        headers = {
            'Content-Type': 'application/json',
            'X-API-Key': self.api_key,
            'anthropic-version': '2023-06-01'
        }

        # Claude takes the system prompt as a top-level field and requires max_tokens
        data = {
            'model': self.config['model'],
            'system': system_prompt,
            'messages': [
                {'role': 'user', 'content': user_prompt}
            ],
            'max_tokens': kwargs.get('max_tokens', 4096),
            **{k: v for k, v in kwargs.items() if k != 'max_tokens'}
        }

        response = requests.post(
            f"{self.config['base_url']}{self.config['endpoint']}",
            headers=headers,
            json=data
        )
        response.raise_for_status()

        return response.json()['content'][0]['text']

    def _call_ollama(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
        """Call the Ollama chat API (no API key required)."""
        data = {
            'model': self.config['model'],
            'messages': [
                {'role': 'system', 'content': system_prompt},
                {'role': 'user', 'content': user_prompt}
            ],
            'stream': False,
            **kwargs
        }

        response = requests.post(
            f"{self.config['base_url']}{self.config['endpoint']}",
            json=data
        )
        response.raise_for_status()

        return response.json()['message']['content']
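A minimal usage sketch (not part of this commit), assuming the module is importable as llmcaller, a DEEPSEEK_API_KEY environment variable is set, and a local Ollama server is running on its default port:

from llmcaller import LLMCaller

# Hosted provider: the API key is picked up from DEEPSEEK_API_KEY.
caller = LLMCaller('deepseek')
reply = caller.call(
    system_prompt='You are a concise assistant.',
    user_prompt='Name one common web vulnerability.',
    temperature=0.2,   # forwarded to the API via **kwargs
    max_tokens=256,
)
print(reply)

# Local provider: no API key needed; defaults to http://localhost:11434.
local = LLMCaller('ollama', model='llama3.2')
print(local.call('You are terse.', 'Say hello.'))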
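Since every request path ends in response.raise_for_status(), HTTP failures (bad key, wrong base_url, unknown model) surface to the caller as requests.exceptions.HTTPError. A sketch of graceful handling, with the fallback message chosen here purely for illustration:

import requests

try:
    reply = caller.call('You are a helpful assistant.', 'Ping?')
except requests.exceptions.HTTPError as err:
    # e.g. 401 on an invalid key, 404 on a mistyped endpoint
    reply = f"LLM call failed: {err}"
print(reply)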