diff --git a/Wargame/common/llmcaller.py b/Wargame/common/llmcaller.py
new file mode 100644
index 0000000..5756114
--- /dev/null
+++ b/Wargame/common/llmcaller.py
@@ -0,0 +1,159 @@
+import os
+from typing import Optional, Dict, Any
+import requests
+import json
+
+
+class LLMCaller:
+    """Unified LLM caller class supporting multiple model providers."""
+
+    def __init__(self, provider: str, api_key: Optional[str] = None,
+                 base_url: Optional[str] = None, model: Optional[str] = None):
+        """
+        Initialize the LLM caller.
+
+        Args:
+            provider: Service provider ('deepseek', 'claude', 'openai', 'qwen', 'ollama')
+            api_key: API key (not required for ollama)
+            base_url: API base URL (optional, for custom endpoints)
+            model: Model name (optional; a provider default is used otherwise)
+        """
+        self.provider = provider.lower()
+        self.api_key = api_key or os.environ.get(f"{provider.upper()}_API_KEY")
+
+        # Default configurations
+        self.configs = {
+            'deepseek': {
+                'base_url': base_url or 'https://api.deepseek.com/v1',
+                'model': model or 'deepseek-chat',
+                'endpoint': '/chat/completions'
+            },
+            'claude': {
+                'base_url': base_url or 'https://api.anthropic.com/v1',
+                'model': model or 'claude-3-5-sonnet-20241022',
+                'endpoint': '/messages'
+            },
+            'openai': {
+                'base_url': base_url or 'https://api.openai.com/v1',
+                'model': model or 'gpt-4o-mini',
+                'endpoint': '/chat/completions'
+            },
+            'qwen': {
+                'base_url': base_url or 'https://dashscope.aliyuncs.com/compatible-mode/v1',
+                'model': model or 'qwen-plus',
+                'endpoint': '/chat/completions'
+            },
+            'ollama': {
+                'base_url': base_url or 'http://localhost:11434',
+                'model': model or 'llama3.2',
+                'endpoint': '/api/chat'
+            }
+        }
+
+        if self.provider not in self.configs:
+            raise ValueError(f"Unsupported provider: {self.provider}")
+
+        self.config = self.configs[self.provider]
+
+    def call(self, system_prompt: str, user_prompt: str, model: Optional[str] = None, **kwargs) -> str:
+        """
+        Call the LLM.
+
+        Args:
+            system_prompt: System prompt
+            user_prompt: User prompt
+            model: Model name (optional; falls back to the default model if omitted)
+            **kwargs: Additional parameters (temperature, max_tokens, etc.)
+
+        Returns:
+            The LLM's response text
+        """
+        # If a model argument is passed, use it temporarily for this call
+        original_model = self.config['model']
+        if model:
+            self.config['model'] = model
+
+        try:
+            if self.provider == 'claude':
+                return self._call_claude(system_prompt, user_prompt, **kwargs)
+            elif self.provider == 'ollama':
+                return self._call_ollama(system_prompt, user_prompt, **kwargs)
+            else:
+                return self._call_openai_compatible(system_prompt, user_prompt, **kwargs)
+        finally:
+            # Restore the original model setting
+            self.config['model'] = original_model
+
+    def _call_openai_compatible(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
+        """Call an OpenAI-compatible API (OpenAI, DeepSeek, Qwen)."""
+        headers = {
+            'Content-Type': 'application/json',
+            'Authorization': f'Bearer {self.api_key}'
+        }
+
+        data = {
+            'model': self.config['model'],
+            'messages': [
+                {'role': 'system', 'content': system_prompt},
+                {'role': 'user', 'content': user_prompt}
+            ],
+            **kwargs
+        }
+
+        response = requests.post(
+            f"{self.config['base_url']}{self.config['endpoint']}",
+            headers=headers,
+            json=data
+        )
+        response.raise_for_status()
+
+        return response.json()['choices'][0]['message']['content']
+
+    def _call_claude(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
+        """Call the Claude API."""
+        headers = {
+            'Content-Type': 'application/json',
+            'X-API-Key': self.api_key,
+            'anthropic-version': '2023-06-01'
+        }
+
+        data = {
+            'model': self.config['model'],
+            'system': system_prompt,
+            'messages': [
+                {'role': 'user', 'content': user_prompt}
+            ],
+            'max_tokens': kwargs.get('max_tokens', 4096),
+            **{k: v for k, v in kwargs.items() if k != 'max_tokens'}
+        }
+
+        response = requests.post(
+            f"{self.config['base_url']}{self.config['endpoint']}",
+            headers=headers,
+            json=data
+        )
+        response.raise_for_status()
+
+        return response.json()['content'][0]['text']
+
+    def _call_ollama(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
+        """Call the Ollama API."""
+        data = {
+            'model': self.config['model'],
+            'messages': [
+                {'role': 'system', 'content': system_prompt},
+                {'role': 'user', 'content': user_prompt}
+            ],
+            'stream': False,
+            **kwargs
+        }
+
+        response = requests.post(
+            f"{self.config['base_url']}{self.config['endpoint']}",
+            json=data
+        )
+        response.raise_for_status()
+
+        return response.json()['message']['content']
+