Add files via upload

This commit is contained in:
AiShell 2025-05-26 15:59:20 +08:00 committed by GitHub
parent be0bd74e27
commit a33b34e720
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

159
Wargame/common/llmcaller.py Normal file
View File

@ -0,0 +1,159 @@
import os
from typing import Optional, Dict, Any
import requests
import json
class LLMCaller:
"""统一的LLM调用类支持多种模型服务"""
def __init__(self, provider: str, api_key: Optional[str] = None,
base_url: Optional[str] = None, model: Optional[str] = None):
"""
初始化LLM调用器
Args:
provider: 服务提供商 ('deepseek', 'claude', 'openai', 'qwen', 'ollama')
api_key: API密钥ollama不需要
base_url: API基础URL可选用于自定义端点
model: 模型名称可选使用默认值
"""
self.provider = provider.lower()
self.api_key = api_key or os.environ.get(f"{provider.upper()}_API_KEY")
# 默认配置
self.configs = {
'deepseek': {
'base_url': base_url or 'https://api.deepseek.com/v1',
'model': model or 'deepseek-chat',
'endpoint': '/chat/completions'
},
'claude': {
'base_url': base_url or 'https://api.anthropic.com/v1',
'model': model or 'claude-3-5-sonnet-20241022',
'endpoint': '/messages'
},
'openai': {
'base_url': base_url or 'https://api.openai.com/v1',
'model': model or 'gpt-4o-mini',
'endpoint': '/chat/completions'
},
'qwen': {
'base_url': base_url or 'https://dashscope.aliyuncs.com/compatible-mode/v1',
'model': model or 'qwen-plus',
'endpoint': '/chat/completions'
},
'ollama': {
'base_url': base_url or 'http://localhost:11434',
'model': model or 'llama3.2',
'endpoint': '/api/chat'
}
}
if self.provider not in self.configs:
raise ValueError(f"不支持的provider: {self.provider}")
self.config = self.configs[self.provider]
def call(self, system_prompt: str, user_prompt: str, model: Optional[str] = None, **kwargs) -> str:
"""
调用LLM
Args:
system_prompt: 系统提示词
user_prompt: 用户提示词
model: 模型名称可选不传则使用默认模型
**kwargs: 其他参数temperature, max_tokens等
Returns:
LLM的响应文本
"""
# 如果传入了model参数临时使用该模型
original_model = self.config['model']
if model:
self.config['model'] = model
try:
if self.provider == 'claude':
return self._call_claude(system_prompt, user_prompt, **kwargs)
elif self.provider == 'ollama':
return self._call_ollama(system_prompt, user_prompt, **kwargs)
else:
return self._call_openai_compatible(system_prompt, user_prompt, **kwargs)
finally:
# 恢复原始模型设置
self.config['model'] = original_model
def _call_openai_compatible(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
"""调用OpenAI兼容的APIOpenAI, DeepSeek, Qwen"""
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}'
}
data = {
'model': self.config['model'],
'messages': [
{'role': 'system', 'content': system_prompt},
{'role': 'user', 'content': user_prompt}
],
**kwargs
}
response = requests.post(
f"{self.config['base_url']}{self.config['endpoint']}",
headers=headers,
json=data
)
response.raise_for_status()
return response.json()['choices'][0]['message']['content']
def _call_claude(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
"""调用Claude API"""
headers = {
'Content-Type': 'application/json',
'X-API-Key': self.api_key,
'anthropic-version': '2023-06-01'
}
data = {
'model': self.config['model'],
'system': system_prompt,
'messages': [
{'role': 'user', 'content': user_prompt}
],
'max_tokens': kwargs.get('max_tokens', 4096),
**{k: v for k, v in kwargs.items() if k != 'max_tokens'}
}
response = requests.post(
f"{self.config['base_url']}{self.config['endpoint']}",
headers=headers,
json=data
)
response.raise_for_status()
return response.json()['content'][0]['text']
def _call_ollama(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
"""调用Ollama API"""
data = {
'model': self.config['model'],
'messages': [
{'role': 'system', 'content': system_prompt},
{'role': 'user', 'content': user_prompt}
],
'stream': False,
**kwargs
}
response = requests.post(
f"{self.config['base_url']}{self.config['endpoint']}",
json=data
)
response.raise_for_status()
return response.json()['message']['content']