from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, validator
from openai import AsyncOpenAI
import os
from typing import List, Optional, Dict, Set, Tuple
import aiohttp
import logging
from pydantic_settings import BaseSettings
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from enum import Enum
from urllib.parse import urlparse
import asyncio
import zipfile
from pathlib import Path
import tempfile
import json
# Configuration classes
class ModelType(str, Enum):
    GPT35 = "gpt-3.5-turbo"
    GPT4 = "gpt-4"
    CLAUDE = "claude-2"
    CLAUDE3 = "claude-3"

class Settings(BaseSettings):
    OPENAI_API_KEY: str = "sk-LshpT3NTpaT3HDw317634b552f22499c9aE9E75e003b9fA3"
    OPENAI_API_BASE: str = "https://free.v36.cm/v1"
    OPENAI_MODEL: str = "gpt-4o-mini-2024-07-18"

    # Available models, grouped by API domain and model type
    AVAILABLE_MODELS: Dict[str, Dict[str, List[str]]] = {
        "guji.ai": {
            "Chat": [
                "gpt-3.5-turbo",
                "gpt-3.5-turbo-1106",
                "gpt-3.5-turbo-0125",
                "gpt-3.5-turbo-16k",
                "gpt-4",
                "gpt-4-1106-preview",
                "gpt-4-0125-preview",
                "gpt-4-32k",
                "claude-2",
                "claude-3",
                "01-ai/Yi-1.5-6B-Chat",
                "01-ai/Yi-1.5-9B-Chat-16K",
                "01-ai/Yi-1.5-34B-Chat-16K",
                "THUDM/chatglm3-6b",
                "THUDM/glm-4-9b-chat",
                "Qwen/Qwen2-7B-Instruct",
                "Qwen/Qwen2-1.5B-Instruct",
                "internlm/internlm2_5-7b-chat"
            ],
            "Embedding": [
                "BAAI/bge-large-en-v1.5",
                "BAAI/bge-large-zh-v1.5",
                "BAAI/bge-m3",
                "netease-youdao/bce-embedding-base_v1"
            ],
            "Image": [
                "stabilityai/stable-diffusion-xl-base-1.0",
                "stabilityai/stable-diffusion-2-1",
                "stabilityai/stable-diffusion-3-medium",
                "stabilityai/stable-diffusion-3-5-large",
                "stabilityai/stable-diffusion-3-5-large-turbo"
            ],
            "Audio": [
                "FunAudioLLM/SenseVoiceSmall",
                "fishaudio/fish-speech-1.4",
                "fishaudio/fish-speech-1.5",
                "FunAudioLLM/CosyVoice2-0.5B"
            ],
            "Pro": [
                "Pro/Qwen/Qwen2-7B-Instruct",
                "Pro/Qwen/Qwen2-1.5B-Instruct",
                "Pro/THUDM/glm-4-9b-chat",
                "Pro/BAAI/bge-m3",
                "Pro/OpenGVLab/InternVL2-8B"
            ]
        },
        "360.com": {
            "Chat": [
                "360GPT_S2_V9",
                "360GPT_S2_V9.4",
                "360GPT_S2_V9.4-4K",
                "360GPT_S2_V9.4-8K"
            ]
        },
        "v36.cm": {
            "Chat": [
                "360GPT_S2_V9",
                "360GPT_S2_V9.4",
                "360GPT_S2_V9.4-4K",
                "360GPT_S2_V9.4-8K"
            ]
        },
        "api.siliconflow.cn": {
            "Chat": [
                "gpt-3.5-turbo",
                "gpt-3.5-turbo-1106",
                "gpt-3.5-turbo-0125",
                "gpt-3.5-turbo-16k",
                "gpt-4",
                "gpt-4-1106-preview",
                "gpt-4-0125-preview",
                "gpt-4-32k",
                "claude-2",
                "claude-3",
                "360GPT_S2_V9",
                "360GPT_S2_V9.4",
                "360GPT_S2_V9.4-4K",
                "360GPT_S2_V9.4-8K",
                "01-ai/Yi-1.5-6B-Chat",
                "01-ai/Yi-1.5-9B-Chat-16K",
                "01-ai/Yi-1.5-34B-Chat-16K",
                "THUDM/chatglm3-6b",
                "THUDM/glm-4-9b-chat",
                "Qwen/Qwen2-7B-Instruct",
                "Qwen/Qwen2-1.5B-Instruct",
                "internlm/internlm2_5-7b-chat"
            ]
        }
    }

    @validator('OPENAI_API_BASE')
    def validate_api_base(cls, v):
        """Normalise the API base URL so it always ends with /v1."""
        v = v.rstrip('/')
        if not v.endswith('/v1'):
            v = v + '/v1'
        return v

    async def fetch_models_from_api(self, api_base: str, api_key: str) -> Dict[str, List[str]]:
        """Fetch the list of available models from the API."""
        try:
            async with aiohttp.ClientSession() as session:
                headers = {
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json"
                }
                async with session.get(f"{api_base}/models", headers=headers) as response:
                    if response.status == 200:
                        data = await response.json()
                        models = data.get("data", [])
                        # Group models by type
                        grouped_models = {
                            "Chat": [],
                            "Embedding": [],
                            "Image": [],
                            "Audio": [],
                            "Pro": []
                        }
                        for model in models:
                            model_id = model.get("id")
                            if not model_id:
                                continue
                            # Classify by model ID
                            if model_id.startswith("360GPT"):
                                grouped_models["Chat"].append(model_id)
                            elif "embedding" in model_id.lower():
                                grouped_models["Embedding"].append(model_id)
                            elif any(img_key in model_id.lower() for img_key in ["diffusion", "dall-e"]):
                                grouped_models["Image"].append(model_id)
                            elif any(audio_key in model_id.lower() for audio_key in ["speech", "voice", "audio"]):
                                grouped_models["Audio"].append(model_id)
                            elif model_id.startswith("Pro/"):
                                grouped_models["Pro"].append(model_id)
                            else:
                                grouped_models["Chat"].append(model_id)
                        # Drop empty categories
                        return {k: v for k, v in grouped_models.items() if v}
                    else:
                        error_data = await response.text()
                        logger.error(f"获取模型列表失败: {error_data}")
                        return self.get_default_models(api_base)
        except Exception as e:
            logger.error(f"获取模型列表时出错: {str(e)}")
            return self.get_default_models(api_base)

    def get_default_models(self, api_base: str) -> Dict[str, List[str]]:
        """Return the preconfigured model list for the given API base URL."""
        try:
            parsed_url = urlparse(api_base)
            domain = parsed_url.netloc.lower()
            # Extract the base domain (last two labels)
            base_domain = '.'.join(domain.split('.')[-2:])
            # Try an exact domain match first
            if domain in self.AVAILABLE_MODELS:
                logger.info(f"直接匹配到域名: {domain}")
                return self.AVAILABLE_MODELS[domain]
            # Then try the base domain
            if base_domain in self.AVAILABLE_MODELS:
                logger.info(f"匹配到基本域名: {base_domain}")
                return self.AVAILABLE_MODELS[base_domain]
            # Fall back to partial matching
            for configured_domain in self.AVAILABLE_MODELS.keys():
                if (configured_domain in domain or
                        domain in configured_domain or
                        configured_domain in base_domain or
                        base_domain in configured_domain):
                    logger.info(f"部分匹配到域名: {configured_domain}")
                    return self.AVAILABLE_MODELS[configured_domain]
            # Known domains with dedicated configurations
            if "v36.cm" in domain:
                logger.info("匹配到 v36.cm 域名")
                return self.AVAILABLE_MODELS["v36.cm"]
            if "360.com" in domain:
                logger.info("匹配到 360.com 域名")
                return self.AVAILABLE_MODELS["360.com"]
            if "siliconflow.cn" in domain:
                logger.info("匹配到 siliconflow.cn 域名")
                return self.AVAILABLE_MODELS["api.siliconflow.cn"]
            logger.warning(f"未找到匹配的域名配置: {domain}")
            return {"Chat": ["gpt-3.5-turbo"]}
        except Exception as e:
            logger.error(f"获取默认模型列表时出错: {str(e)}")
            return {"Chat": ["gpt-3.5-turbo"]}

    async def get_models_for_api(self, api_base: str, api_key: str = None) -> Dict[str, List[str]]:
        """Return the models supported by the given API base URL."""
        try:
            # Try to fetch the model list from the API first
            if api_key:
                try:
                    models = await self.fetch_models_from_api(api_base, api_key)
                    if models:
                        logger.info(f"从API获取到模型列表: {models}")
                        return models
                except Exception as e:
                    logger.warning(f"从API获取模型列表失败: {str(e)}")
            # Fall back to the preconfigured model list if the API lookup fails
            return self.get_default_models(api_base)
        except Exception as e:
            logger.error(f"获取模型列表时出错: {str(e)}")
            return {"Chat": ["gpt-3.5-turbo"]}

    class Config:
        env_file = ".env"
        case_sensitive = True

settings = Settings()

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

app = FastAPI(title="代码审计工具API")

# CORS configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount the static files directory (if the frontend lives under static/)
static_dir = os.path.join(os.path.dirname(__file__), "static")
if os.path.exists(static_dir):
    app.mount("/static", StaticFiles(directory=static_dir), name="static")

# Additional data models
class ProjectAuditResult(BaseModel):
    file_path: str
    language: str
    vulnerabilities: List[dict]
    related_files: List[str]
    context_analysis: str


class ProjectAnalysis:
    def __init__(self):
        self.file_dependencies: Dict[str, Set[str]] = {}  # file dependency graph
        self.shared_variables: Dict[str, Set[str]] = {}   # shared variables per file
        self.function_calls: Dict[str, Set[str]] = {}     # function call relationships
        self.vulnerability_context: Dict[str, List[dict]] = {}  # vulnerability context

    def add_dependency(self, file: str, depends_on: str):
        if file not in self.file_dependencies:
            self.file_dependencies[file] = set()
        self.file_dependencies[file].add(depends_on)

    def add_shared_variable(self, file: str, variable: str):
        if file not in self.shared_variables:
            self.shared_variables[file] = set()
        self.shared_variables[file].add(variable)

    def add_function_call(self, source_file: str, target_file: str):
        if source_file not in self.function_calls:
            self.function_calls[source_file] = set()
        self.function_calls[source_file].add(target_file)

    def get_related_files(self, file: str) -> Set[str]:
        """Return every file related to the given file."""
        related = set()
        if file in self.file_dependencies:
            related.update(self.file_dependencies[file])
        if file in self.function_calls:
            related.update(self.function_calls[file])
        return related
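

# A minimal usage sketch (illustrative only; this helper is not called anywhere in the
# service). It shows how ProjectAnalysis records edges and how get_related_files()
# aggregates both dependency and call relationships. The file names below are made up.
def _project_analysis_example() -> Set[str]:
    demo = ProjectAnalysis()
    demo.add_dependency("app.py", "db.py")
    demo.add_function_call("app.py", "utils.py")
    demo.add_shared_variable("app.py", "SESSION")
    # Expected result: {"db.py", "utils.py"}
    return demo.get_related_files("app.py")
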

class CodeAuditService:
    def __init__(self):
        """Initialise the service."""
        self.client = None
        self.openai_api_key = None
        self.api_base = None
        self.model = None
        # Set up basic configuration here; async work is deferred to ensure_initialized()
        self._init_config(settings.OPENAI_API_KEY, settings.OPENAI_API_BASE)
        self.project_analysis = ProjectAnalysis()
        self.supported_languages = {
            '.php': 'php',
            '.java': 'java',
            '.py': 'python',
            '.js': 'javascript'
        }

    def _init_config(self, api_key: str, api_base: str):
        """Initialise the basic configuration."""
        self.openai_api_key = api_key
        api_base = api_base.rstrip('/')
        if not api_base.endswith('/v1'):
            api_base = api_base + '/v1'
        self.api_base = api_base
        self.model = "gpt-3.5-turbo"  # default model

    async def ensure_initialized(self):
        """Make sure the service is fully initialised."""
        if not self.client:
            await self.configure_openai(self.openai_api_key, self.api_base)

    async def configure_openai(self, api_key: str = None, api_base: str = None, model: str = None):
        """Configure the OpenAI API settings."""
        if not api_key and not self.openai_api_key:
            raise ValueError("未设置OPENAI_API_KEY")
        # Update the configuration
        if api_key:
            self.openai_api_key = api_key
        if api_base:
            api_base = api_base.rstrip('/')
            if not api_base.endswith('/v1'):
                api_base = api_base + '/v1'
            self.api_base = api_base
        elif not self.api_base:
            self.api_base = settings.OPENAI_API_BASE
        # Fetch the models supported by the current API base URL
        available_models = await settings.get_models_for_api(self.api_base, self.openai_api_key)
        # Pick the model
        if model:
            model_found = False
            for category_models in available_models.values():
                if model in category_models:
                    self.model = model
                    model_found = True
                    break
            if not model_found:
                raise ValueError(f"该API地址不支持模型: {model}")
        elif not self.model:
            if "Chat" in available_models and available_models["Chat"]:
                self.model = available_models["Chat"][0]
            else:
                first_category = next(iter(available_models))
                if available_models[first_category]:
                    self.model = available_models[first_category][0]
                else:
                    self.model = "gpt-3.5-turbo"
        try:
            self.client = AsyncOpenAI(
                api_key=self.openai_api_key,
                base_url=self.api_base,
                timeout=120.0,  # extend the timeout to 120 seconds
                max_retries=5   # allow more retries
            )
            logger.info(f"OpenAI API已配置: {self.api_base}, 使用模型: {self.model}")
        except Exception as e:
            logger.error(f"OpenAI客户端配置失败: {str(e)}")
            raise ValueError(f"API配置失败: {str(e)}")

    async def analyze_code(self, code: str, language: str, api_key: str = None, api_base: str = None) -> dict:
        """Analyse code; supports custom API settings."""
        try:
            # Reconfigure if new API settings were provided
            if api_key or api_base:
                await self.configure_openai(api_key, api_base)
            # Validate that the language is supported
            if language not in self.supported_languages.values():
                raise HTTPException(
                    status_code=400,
                    detail=f"不支持的语言类型。支持的语言包括: {', '.join(self.supported_languages.values())}"
                )
            # First AI analysis pass
            logger.info(f"开始第一轮{language}代码分析")
            first_response = await self._get_openai_response(
                self._generate_first_prompt(code, language)
            )
            # Second AI verification pass
            logger.info("开始第二轮验证分析")
            second_response = await self._get_openai_response(
                self._generate_second_prompt(code, first_response)
            )
            return {
                "first_analysis": first_response,
                "second_analysis": second_response
            }
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"代码分析过程中发生错误: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e))

    def _generate_first_prompt(self, code: str, language: str) -> str:
        """Generate the prompt for the first analysis round."""
        language_specific_prompts = {
            'php': """
请特别注意以下PHP特有的安全问题:
1. 文件包含漏洞(include/require)
2. SQL注入(mysql_*函数)
3. 命令注入(system, exec, shell_exec)
4. 文件上传漏洞
5. 反序列化漏洞(unserialize)
6. XSS(echo, print)
7. SSRF漏洞
8. 目录遍历
9. 会话管理问题
10. 配置文件泄露
""",
            'java': """
请特别注意以下Java特有的安全问题:
1. SQL注入(PreparedStatement相关)
2. 命令注入(Runtime.exec, ProcessBuilder)
3. XXE漏洞(XML解析器配置)
4. 反序列化漏洞(readObject)
5. 不安全的文件操作
6. CSRF/XSS防护
7. 权限控制缺陷
8. 线程安全问题
9. 密码学实现缺陷
10. 日志信息泄露
""",
            'python': """
请特别注意以下Python特有的安全问题:
1. 不安全的反序列化(如pickle.loads)
2. 命令注入(如os.system, eval)
3. 不安全的模块导入
4. SQL注入(特别是字符串格式化)
5. 路径遍历
""",
            'javascript': """
请特别注意以下JavaScript特有的安全问题:
1. XSS(DOM型和反射型)
2. 原型污染
3. 不安全的第三方依赖
4. 客户端存储安全
5. 不安全的随机数生成
"""
        }
        base_prompt = f"""请分析以下{language}代码中的安全漏洞:
{code}
请详细说明每个潜在的安全问题,包括:
1. 漏洞类型
2. 漏洞位置
3. 可能的影响
4. 修复建议
"""
        if language in language_specific_prompts:
            base_prompt += f"\n{language_specific_prompts[language]}"
        return base_prompt

    def _check_file_type(self, filename: str) -> str:
        """Check whether the file type is supported and return its language."""
        ext = os.path.splitext(filename)[1].lower()
        if ext not in self.supported_languages:
            supported_exts = ', '.join(self.supported_languages.keys())
            raise HTTPException(
                status_code=400,
                detail=f"不支持的文件类型。支持的文件类型包括: {supported_exts}"
            )
        return self.supported_languages[ext]

    def _generate_second_prompt(self, code: str, first_response: str) -> str:
        """Generate the prompt for the second (verification) round."""
        return f"""请验证以下代码审计结果的准确性,并提供可能的payload:
{first_response}
代码:
{code}"""

    async def _get_openai_response(self, prompt: str) -> str:
        if not self.client:
            raise ValueError("OpenAI客户端未初始化")
        try:
            logger.info(f"正在发送请求到: {self.api_base}, 使用模型: {self.model}")
            # Retry with exponential backoff
            max_retries = 3
            retry_count = 0
            retry_delay = 1  # initial delay of 1 second
            while retry_count < max_retries:
                try:
                    response = await self.client.chat.completions.create(
                        model=self.model,
                        messages=[
                            {"role": "system", "content": "你是一个专业的代码安全审计专家。"},
                            {"role": "user", "content": prompt}
                        ],
                        temperature=0.7,
                        max_tokens=2000
                    )
                    if hasattr(response, 'choices') and len(response.choices) > 0:
                        content = response.choices[0].message.content
                        logger.debug(f"收到响应: {content[:100]}...")
                        return content
                    else:
                        raise ValueError("API响应格式错误")
                except Exception as e:
                    retry_count += 1
                    if retry_count >= max_retries:
                        raise
                    logger.warning(f"请求失败,正在进行第{retry_count}次重试: {str(e)}")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2  # exponential backoff
        except Exception as e:
            error_msg = str(e)
            logger.error(f"OpenAI API调用失败: {error_msg}")
            if "401" in error_msg:
                raise HTTPException(status_code=401, detail="API密钥无效或未授权")
            elif "timeout" in error_msg.lower():
                raise HTTPException(status_code=504, detail="API请求超时,请稍后重试")
            elif "404" in error_msg:
                raise HTTPException(status_code=404, detail="API端点不存在,请检查API基础URL")
            else:
                raise HTTPException(status_code=500, detail=f"AI分析服务错误: {error_msg}")

    async def analyze_project(self, zip_file: UploadFile) -> Dict[str, ProjectAuditResult]:
        """Analyse an entire project."""
        results = {}
        # Work inside a temporary directory
        with tempfile.TemporaryDirectory() as temp_dir:
            # Save the uploaded ZIP file
            zip_path = Path(temp_dir) / "project.zip"
            with open(zip_path, "wb") as f:
                content = await zip_file.read()
                f.write(content)
            # Extract the archive
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(temp_dir)
            # Analyse the project structure first
            await self._analyze_project_structure(temp_dir)
            # Analyse each file
            for file_path in Path(temp_dir).rglob('*'):
                if file_path.suffix in self.supported_languages:
                    rel_path = str(file_path.relative_to(temp_dir))
                    try:
                        result = await self._analyze_file_with_context(
                            file_path,
                            self.project_analysis.get_related_files(rel_path)
                        )
                        results[rel_path] = result
                    except Exception as e:
                        logger.error(f"分析文件 {rel_path} 时出错: {str(e)}")
            # Cross-validate and correlate the per-file results
            await self._cross_validate_results(results)
        return results

    async def _analyze_project_structure(self, project_dir: str):
        """Analyse the project structure and build the dependency graph."""
        for file_path in Path(project_dir).rglob('*'):
            if file_path.suffix not in self.supported_languages:
                continue
            rel_path = str(file_path.relative_to(project_dir))
            content = file_path.read_text(errors='ignore')
            # File dependencies
            await self._analyze_dependencies(rel_path, content)
            # Shared variables
            await self._analyze_shared_variables(rel_path, content)
            # Function calls
            await self._analyze_function_calls(rel_path, content)

    async def _analyze_file_with_context(
        self,
        file_path: Path,
        related_files: Set[str]
    ) -> ProjectAuditResult:
        """Analyse a single file, taking its context into account."""
        content = file_path.read_text(errors='ignore')
        language = file_path.suffix[1:]  # drop the leading dot
        # Build a prompt that includes the file's context
        context_prompt = f"""请分析以下{language}代码,重点关注安全漏洞。请提供详细的JSON格式分析结果:
代码内容:
{content}
相关文件:
{', '.join(related_files)}
请提供以下格式的分析结果:
{{
    "vulnerabilities": [
        {{
            "type": "漏洞类型(如SQL注入、XSS等)",
            "location": "具体代码行号和代码片段",
            "severity": "严重程度(高/中/低)",
            "description": "详细的漏洞描述",
            "impact": "潜在影响",
            "fix": "修复建议",
            "related_context": "相关的上下文信息"
        }}
    ],
    "context_analysis": "整体代码安全性分析",
    "related_files": {{
        "dependencies": ["相关的依赖文件"],
        "includes": ["包含的文件"],
        "functions": ["调用的函数"],
        "affected_by": ["受影响的文件"],
        "affects": ["可能影响的文件"]
    }}
}}
请特别注意:
1. 详细分析每个可能的漏洞
2. 提供具体的代码位置
3. 给出可行的修复建议
4. 分析代码与其他文件的关联
5. 考虑整体的安全影响
"""
        # Get the AI analysis result
        analysis_result = await self._get_openai_response(context_prompt)
        # Parse the result
        try:
            result_dict = json.loads(analysis_result)
            # Make sure all required fields are present
            if 'vulnerabilities' not in result_dict:
                result_dict['vulnerabilities'] = []
            if 'context_analysis' not in result_dict:
                result_dict['context_analysis'] = "未提供分析结果"
            if 'related_files' not in result_dict:
                result_dict['related_files'] = {
                    "dependencies": [],
                    "includes": [],
                    "functions": [],
                    "affected_by": [],
                    "affects": []
                }
            # Flatten the related-files mapping into a list of strings
            related_files_list = []
            for category, files in result_dict['related_files'].items():
                if files:
                    related_files_list.extend([f"{category}: {file}" for file in files])
            return ProjectAuditResult(
                file_path=str(file_path),
                language=language,
                vulnerabilities=result_dict['vulnerabilities'],
                related_files=related_files_list,
                context_analysis=result_dict['context_analysis']
            )
        except json.JSONDecodeError:
            # If the result is not valid JSON, fall back to a basic structure
            return ProjectAuditResult(
                file_path=str(file_path),
                language=language,
                vulnerabilities=[],
                related_files=[],
                context_analysis=analysis_result
            )

    async def _cross_validate_results(self, results: Dict[str, ProjectAuditResult]):
        """Cross-validate the analysis results."""
        # Collect all reported vulnerabilities
        all_vulnerabilities = []
        for result in results.values():
            all_vulnerabilities.extend(result.vulnerabilities)
        # Build the cross-validation prompt
        validation_prompt = f"""请验证以下项目漏洞分析结果的准确性和完整性:
发现的漏洞:
{json.dumps(all_vulnerabilities, indent=2, ensure_ascii=False)}
请考虑:
1. 漏洞之间的关联性
2. 漏洞的优先级
3. 误报可能性
4. 修复建议的可行性
"""
        # Get the validation result
        validation_result = await self._get_openai_response(validation_prompt)
        # Append it to every file's context analysis
        for result in results.values():
            result.context_analysis += f"\n\n交叉验证结果:\n{validation_result}"

    async def _analyze_dependencies(self, file_path: str, content: str):
        """Analyse file dependencies."""
        try:
            language = Path(file_path).suffix[1:]
            # Dispatch by language
            if language == 'php':
                await self._analyze_php_dependencies(file_path, content)
            elif language == 'java':
                await self._analyze_java_dependencies(file_path, content)
            elif language == 'py':
                await self._analyze_python_dependencies(file_path, content)
            elif language == 'js':
                await self._analyze_js_dependencies(file_path, content)
        except Exception as e:
            logger.error(f"分析文件依赖时出错 {file_path}: {str(e)}")

    async def _analyze_php_dependencies(self, file_path: str, content: str):
        """Analyse PHP file dependencies."""
        import re
        patterns = [
            r'(?:include|require|include_once|require_once)\s*[\'"]([^\'"]+)[\'"]',
            r'use\s+([^;]+)',
            r'namespace\s+([^;{\s]+)'
        ]
        for pattern in patterns:
            matches = re.finditer(pattern, content)
            for match in matches:
                dependency = match.group(1)
                self.project_analysis.add_dependency(file_path, dependency)

    async def _analyze_java_dependencies(self, file_path: str, content: str):
        """Analyse Java file dependencies."""
        import re
        patterns = [
            r'import\s+([^;]+)',
            r'extends\s+([^\s{]+)',
            r'implements\s+([^{]+)'
        ]
        for pattern in patterns:
            matches = re.finditer(pattern, content)
            for match in matches:
                dependency = match.group(1)
                self.project_analysis.add_dependency(file_path, dependency)

    async def _analyze_python_dependencies(self, file_path: str, content: str):
        """Analyse Python file dependencies."""
        import re
        patterns = [
            r'(?:from|import)\s+([^\s]+)',
            r'__import__\([\'"]([^\'"]+)[\'"]\)'
        ]
        for pattern in patterns:
            matches = re.finditer(pattern, content)
            for match in matches:
                dependency = match.group(1)
                self.project_analysis.add_dependency(file_path, dependency)

    async def _analyze_js_dependencies(self, file_path: str, content: str):
        """Analyse JavaScript file dependencies."""
        import re
        patterns = [
            r'(?:import|require)\s*\([\'"]([^\'"]+)[\'"]\)',
            r'import\s+.*\s+from\s+[\'"]([^\'"]+)[\'"]',
            r'import\s+[\'"]([^\'"]+)[\'"]'
        ]
        for pattern in patterns:
            matches = re.finditer(pattern, content)
            for match in matches:
                dependency = match.group(1)
                self.project_analysis.add_dependency(file_path, dependency)

    async def _analyze_shared_variables(self, file_path: str, content: str):
        """Analyse shared variables."""
        try:
            language = Path(file_path).suffix[1:]
            # Dispatch by language
            if language == 'php':
                await self._analyze_php_shared_vars(file_path, content)
            elif language == 'java':
                await self._analyze_java_shared_vars(file_path, content)
            elif language == 'py':
                await self._analyze_python_shared_vars(file_path, content)
            elif language == 'js':
                await self._analyze_js_shared_vars(file_path, content)
        except Exception as e:
            logger.error(f"分析共享变量时出错 {file_path}: {str(e)}")

    async def _analyze_php_shared_vars(self, file_path: str, content: str):
        """Analyse PHP shared variables."""
        import re
        patterns = [
            r'\$GLOBALS\[[\'"](\w+)[\'"]\]',
            r'\$_(?:GET|POST|REQUEST|SESSION|COOKIE)\[[\'"](\w+)[\'"]\]',
            r'global\s+\$(\w+)'
        ]
        for pattern in patterns:
            matches = re.finditer(pattern, content)
            for match in matches:
                var_name = match.group(1)
                self.project_analysis.add_shared_variable(file_path, var_name)

    async def _analyze_java_shared_vars(self, file_path: str, content: str):
        """Analyse Java shared variables."""
        import re
        patterns = [
            r'static\s+(?:final\s+)?(?:\w+)\s+(\w+)',
            r'public\s+(?:static\s+)?(?:\w+)\s+(\w+)'
        ]
        for pattern in patterns:
            matches = re.finditer(pattern, content)
            for match in matches:
                var_name = match.group(1)
                self.project_analysis.add_shared_variable(file_path, var_name)

    async def _analyze_python_shared_vars(self, file_path: str, content: str):
        """Analyse Python shared variables."""
        import re
        patterns = [
            r'global\s+(\w+)',
            r'(\w+)\s*=\s*[^=]'  # naive match for module-level assignments
        ]
        for pattern in patterns:
            matches = re.finditer(pattern, content)
            for match in matches:
                var_name = match.group(1)
                self.project_analysis.add_shared_variable(file_path, var_name)

    async def _analyze_js_shared_vars(self, file_path: str, content: str):
        """Analyse JavaScript shared variables."""
        import re
        patterns = [
            r'(?:var|let|const)\s+(\w+)\s*=',
            r'window\.(\w+)\s*=',
            r'global\.(\w+)\s*='
        ]
        for pattern in patterns:
            matches = re.finditer(pattern, content)
            for match in matches:
                var_name = match.group(1)
                self.project_analysis.add_shared_variable(file_path, var_name)

    async def _analyze_function_calls(self, file_path: str, content: str):
        """Analyse function call relationships."""
        try:
            language = Path(file_path).suffix[1:]
            # Dispatch by language
            if language == 'php':
                await self._analyze_php_function_calls(file_path, content)
            elif language == 'java':
                await self._analyze_java_function_calls(file_path, content)
            elif language == 'py':
                await self._analyze_python_function_calls(file_path, content)
            elif language == 'js':
                await self._analyze_js_function_calls(file_path, content)
        except Exception as e:
            logger.error(f"分析函数调用时出错 {file_path}: {str(e)}")

    async def _analyze_php_function_calls(self, file_path: str, content: str):
        """Analyse PHP function calls."""
        import re
        pattern = r'(?:function\s+(\w+)|(\w+)\s*\()'
        matches = re.finditer(pattern, content)
        for match in matches:
            func_name = match.group(1) or match.group(2)
            # Look for files in the project that call this function
            await self._find_function_callers(file_path, func_name)

    async def _analyze_java_function_calls(self, file_path: str, content: str):
        """Analyse Java function calls."""
        import re
        pattern = r'(?:public|private|protected)?\s*(?:static)?\s*\w+\s+(\w+)\s*\([^)]*\)'
        matches = re.finditer(pattern, content)
        for match in matches:
            func_name = match.group(1)
            await self._find_function_callers(file_path, func_name)

    async def _analyze_python_function_calls(self, file_path: str, content: str):
        """Analyse Python function calls."""
        import re
        pattern = r'def\s+(\w+)\s*\('
        matches = re.finditer(pattern, content)
        for match in matches:
            func_name = match.group(1)
            await self._find_function_callers(file_path, func_name)

    async def _analyze_js_function_calls(self, file_path: str, content: str):
        """Analyse JavaScript function calls."""
        import re
        pattern = r'(?:function\s+(\w+)|(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?function)'
        matches = re.finditer(pattern, content)
        for match in matches:
            func_name = match.group(1) or match.group(2)
            await self._find_function_callers(file_path, func_name)

    async def _find_function_callers(self, source_file: str, function_name: str):
        """Find the files that call the given function."""
        # A more sophisticated call-graph analysis could be implemented here;
        # for now the function name is simply recorded against its defining file.
        self.project_analysis.add_function_call(source_file, function_name)
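

# Illustrative sketch only (not wired into the service): one way the caller search hinted
# at in _find_function_callers could be fleshed out is a plain regex scan over the
# extracted project files. The "name(" heuristic and the extension filter below are
# assumptions made for this example, not project policy.
def find_callers_in_dir(project_dir: str, function_name: str) -> Set[str]:
    import re
    callers: Set[str] = set()
    call_pattern = re.compile(r'\b' + re.escape(function_name) + r'\s*\(')
    for path in Path(project_dir).rglob('*'):
        if path.suffix not in {'.php', '.java', '.py', '.js'}:
            continue
        try:
            if call_pattern.search(path.read_text(errors='ignore')):
                callers.add(str(path.relative_to(project_dir)))
        except OSError:
            continue
    return callers
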

code_audit_service = CodeAuditService()


# Request body models
class ConfigureRequest(BaseModel):
    api_key: str
    api_base: Optional[str] = None
    model: Optional[str] = None  # optional model selection
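

# Client-side usage sketch (illustrative only; nothing in the service calls this).
# It assumes the app is served locally at http://localhost:8000 and that the target
# endpoint/model shown in the payload are reachable -- adjust both as needed.
async def _example_configure_client(api_key: str, base_url: str = "http://localhost:8000") -> dict:
    async with aiohttp.ClientSession() as session:
        # Point the service at a compatible endpoint and pick a model.
        payload = {
            "api_key": api_key,
            "api_base": "https://api.siliconflow.cn",
            "model": "gpt-3.5-turbo"
        }
        async with session.post(f"{base_url}/api/configure", json=payload) as resp:
            configure_result = await resp.json()
        # List the models the configured endpoint advertises.
        async with session.get(f"{base_url}/api/models") as resp:
            models = await resp.json()
    return {"configure": configure_result, "models": models}
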

# Endpoint for listing available models
@app.get("/api/models")
async def get_available_models():
    """Return the models supported by the currently configured API base URL."""
    try:
        # Make sure the service is initialised
        await code_audit_service.ensure_initialized()
        api_base = code_audit_service.api_base or settings.OPENAI_API_BASE
        available_models = await settings.get_models_for_api(
            api_base,
            code_audit_service.openai_api_key or settings.OPENAI_API_KEY
        )
        current_model = code_audit_service.model or settings.OPENAI_MODEL
        logger.info(f"当前API地址: {api_base}")
        logger.info(f"可用模型: {available_models}")
        logger.info(f"当前使用的模型: {current_model}")
        return {
            "models": available_models,
            "current_model": current_model
        }
    except Exception as e:
        logger.error(f"获取模型列表失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


# Configuration endpoint
@app.post("/api/configure")
async def configure_api(config: ConfigureRequest):
    """Configure the OpenAI API settings."""
    try:
        await code_audit_service.configure_openai(
            config.api_key,
            config.api_base,
            config.model
        )
        return {
            "status": "success",
            "message": "API配置已更新",
            "model": code_audit_service.model
        }
    except Exception as e:
        logger.error(f"配置更新失败: {str(e)}")
        raise HTTPException(status_code=400, detail=str(e))
@app.post("/api/audit")
async def audit_code(
file: UploadFile = File(...),
api_key: str = None,
api_base: str = None
):
"""审计代码支持自定义API设置"""
try:
# 确保服务已初始化
await code_audit_service.ensure_initialized()
content = await file.read()
code = content.decode()
# 获取文件扩展名并检查支持的类型
file_extension = os.path.splitext(file.filename)[1].lower()
language = code_audit_service._check_file_type(file.filename)
logger.info(f"开始分析{file.filename}")
result = await code_audit_service.analyze_code(code, language, api_key, api_base)
return result
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="文件编码错误")
except Exception as e:
logger.error(f"处理文件时发生错误: {str(e)}")
if isinstance(e, HTTPException):
raise e
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/audit/project")
async def audit_project(
project: UploadFile = File(...),
api_key: str = None,
api_base: str = None
):
"""审计整个项目代码"""
try:
# 确保服务已初始化
await code_audit_service.ensure_initialized()
# 验证文件类型
if not project.filename.endswith('.zip'):
raise HTTPException(status_code=400, detail="请上传ZIP格式的项目文件")
# 分析项目
results = await code_audit_service.analyze_project(project)
return {
"status": "success",
"results": results
}
except Exception as e:
logger.error(f"项目审计过程中发生错误: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
return {"status": "healthy"}
@app.get("/")
async def root():
"""
根路径处理程序返回API基本信息
"""
return {
"name": "代码审计工具API",
"version": "1.0.0",
"status": "running",
"endpoints": {
"audit": "/api/audit",
"configure": "/api/configure",
"health": "/health",
"docs": "/docs"
}
}


@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """
    Global exception handler.
    """
    error_msg = str(exc)
    logger.error(f"发生错误: {error_msg}")
    return JSONResponse(
        status_code=500,
        content={
            "error": "内部服务器错误",
            "detail": error_msg
        }
    )


@app.middleware("http")
async def log_requests(request, call_next):
    """
    Request logging middleware.
    """
    logger.info(f"收到请求: {request.method} {request.url}")
    response = await call_next(request)
    logger.info(f"响应状态码: {response.status_code}")
    return response
@app.get("/ui")
async def serve_spa():
"""
服务前端单页应用
"""
return FileResponse(os.path.join(static_dir, "index.html"))
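

# Local run sketch (an assumption, not part of the original file): uvicorn is the usual
# way to serve a FastAPI app during development; the host and port below are arbitrary.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)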