diff --git a/README.MD b/README.MD
index ecb1c39..048311f 100644
--- a/README.MD
+++ b/README.MD
@@ -4,6 +4,19 @@
 MirrorFlower (镜花) is an AI-based code security audit tool. It supports code analysis for multiple programming languages and helps developers quickly find potential security vulnerabilities. Several large models are supported, including DeepSeek-R1 and ChatGPT-4o.
 
+## Changelog
+
+### 2024-02-11
+- Improved the Python code analysis:
+  - Added full dependency analysis, tracking import relationships and aliases
+  - Enhanced function-call analysis, tracking calls to class methods and instance methods
+  - Added variable-usage analysis, tracking global variables and instance variables
+  - Improved class-inheritance analysis, supporting multi-level inheritance paths
+- Refined the analyzer architecture:
+  - Refactored the code-analysis logic around the visitor pattern
+  - Added type hints and detailed documentation
+  - Improved the error-handling mechanism
+
 ## Supported APIs
 
 FREEGPTAPI: https://github.com/popjane/free_chatgpt_api
 
@@ -143,11 +156,6 @@ OPENAI_MODEL=your_preferred_model
 uvicorn backend.app:app --reload
 ```
 
-5. Open Mirror-Flower
-```bash
-http://localhost:8000/ui
-```
-
 ## Notes
 
 1. File size limit: 10MB
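For illustration only (not part of the patch): the changelog bullets above map onto the `CodeAnalyzer` added further down in this diff. A minimal sketch of the alias-aware dependency tracking, assuming `code_analyzer.py` from this patch is on the import path, looks roughly like this:

```python
# Sketch of the dependency/alias tracking described in the changelog.
# Assumes code_analyzer.py from this diff is importable; names are illustrative.
from code_analyzer import CodeAnalyzer

snippet = """
import numpy as np
from os import path as p

x = np.zeros(3)
print(p.join("a", "b"))
"""

analyzer = CodeAnalyzer()
analyzer.analyze_file(snippet, "demo.py")
print(analyzer.get_file_dependencies("demo.py"))  # roughly: {'numpy', 'os.path'}
```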
diff --git a/backend/__pycache__/app.cpython-313.pyc b/backend/__pycache__/app.cpython-313.pyc
index 7cecd47..741ceac 100644
Binary files a/backend/__pycache__/app.cpython-313.pyc and b/backend/__pycache__/app.cpython-313.pyc differ
diff --git a/code_analyzer.py b/code_analyzer.py
new file mode 100644
index 0000000..1b1c62a
--- /dev/null
+++ b/code_analyzer.py
@@ -0,0 +1,281 @@
+import ast
+from typing import Dict, Set, List, Optional, Union
+from pathlib import Path
+from dataclasses import dataclass
+
+@dataclass
+class AnalyzerConfig:
+    max_workers: int = 4  # maximum number of threads for parallel processing
+    ignore_patterns: Optional[List[str]] = None  # file patterns to ignore
+    follow_imports: bool = True  # whether to analyze imported modules
+    max_depth: int = 3  # maximum analysis depth
+    encoding: str = 'utf-8'  # file encoding
+
+class CodeAnalyzer:
+    def __init__(self, config: Optional[AnalyzerConfig] = None):
+        self.config = config or AnalyzerConfig()
+        self.dependencies: Dict[str, Set[str]] = {}
+        self.globals: Dict[str, Set[str]] = {}
+        self.function_calls: Dict[str, Set[str]] = {}
+        self.class_hierarchy: Dict[str, List[str]] = {}
+        self.variable_usages: Dict[str, Set[str]] = {}  # variable usage locations
+
+    def _analyze_dependencies(self, content: str, file_path: str) -> None:
+        """Analyze the file's import dependencies."""
+        tree = ast.parse(content)
+
+        class ImportVisitor(ast.NodeVisitor):
+            def __init__(self, analyzer, file_path):
+                self.analyzer = analyzer
+                self.file_path = file_path
+                self.aliases = {}  # recorded import aliases
+
+            def visit_Import(self, node):
+                for name in node.names:
+                    self.analyzer.dependencies.setdefault(self.file_path, set()).add(name.name)
+                    if name.asname:
+                        self.aliases[name.asname] = name.name
+
+            def visit_ImportFrom(self, node):
+                module = node.module if node.module else ''
+                for name in node.names:
+                    full_name = f"{module}.{name.name}" if module else name.name
+                    self.analyzer.dependencies.setdefault(self.file_path, set()).add(full_name)
+                    if name.asname:
+                        self.aliases[name.asname] = full_name
+
+            def visit_Name(self, node):
+                # Check whether an imported alias is being used
+                if node.id in self.aliases:
+                    self.analyzer.dependencies.setdefault(self.file_path, set()).add(self.aliases[node.id])
+                self.generic_visit(node)
+
+        visitor = ImportVisitor(self, file_path)
+        visitor.visit(tree)
+
+    def _analyze_globals(self, content: str, file_path: str) -> None:
+        """Analyze global variables."""
+        tree = ast.parse(content)
+
+        class GlobalVisitor(ast.NodeVisitor):
+            def __init__(self, analyzer, file_path):
+                self.analyzer = analyzer
+                self.file_path = file_path
+                self.current_scope = None
+
+            def visit_Module(self, node):
+                old_scope = self.current_scope
+                self.current_scope = 'module'
+                self.generic_visit(node)
+                self.current_scope = old_scope
+
+            def visit_Global(self, node):
+                for name in node.names:
+                    self.analyzer.globals.setdefault(self.file_path, set()).add(name)
+
+            def visit_Assign(self, node):
+                if self.current_scope == 'module' and isinstance(node.targets[0], ast.Name):
+                    self.analyzer.globals.setdefault(self.file_path, set()).add(node.targets[0].id)
+                self.generic_visit(node)
+
+        visitor = GlobalVisitor(self, file_path)
+        visitor.visit(tree)
+    def _analyze_function_calls(self, content: str, file_path: str) -> None:
+        """Analyze function-call relationships."""
+        tree = ast.parse(content)
+
+        class FunctionCallVisitor(ast.NodeVisitor):
+            def __init__(self, analyzer, file_path):
+                self.analyzer = analyzer
+                self.file_path = file_path
+                self.current_function = None
+                self.current_class = None
+
+            def visit_ClassDef(self, node):
+                old_class = self.current_class
+                self.current_class = node.name
+                self.generic_visit(node)
+                self.current_class = old_class
+
+            def visit_FunctionDef(self, node):
+                old_function = self.current_function
+                if self.current_class:
+                    self.current_function = f"{self.current_class}.{node.name}"
+                else:
+                    self.current_function = node.name
+                self.generic_visit(node)
+                self.current_function = old_function
+
+            def visit_Call(self, node):
+                if not self.current_function:
+                    return
+
+                caller = f"{self.file_path}:{self.current_function}"
+
+                if isinstance(node.func, ast.Name):
+                    callee = node.func.id
+                elif isinstance(node.func, ast.Attribute):
+                    # Handle method calls
+                    if isinstance(node.func.value, ast.Name):
+                        callee = f"{node.func.value.id}.{node.func.attr}"
+                    else:
+                        callee = node.func.attr
+                else:
+                    return
+
+                self.analyzer.function_calls.setdefault(caller, set()).add(callee)
+                self.generic_visit(node)
+
+        visitor = FunctionCallVisitor(self, file_path)
+        visitor.visit(tree)
+
+    def _analyze_class_hierarchy(self, content: str, file_path: str) -> None:
+        """Analyze class-inheritance relationships."""
+        tree = ast.parse(content)
+
+        class ClassVisitor(ast.NodeVisitor):
+            def __init__(self, analyzer, file_path):
+                self.analyzer = analyzer
+                self.file_path = file_path
+                self.current_module = None
+
+            def visit_Module(self, node):
+                for imp in node.body:
+                    if isinstance(imp, ast.ImportFrom):
+                        self.current_module = imp.module
+                self.generic_visit(node)
+
+            def visit_ClassDef(self, node):
+                class_name = f"{self.file_path}:{node.name}"
+                bases = []
+                for base in node.bases:
+                    if isinstance(base, ast.Name):
+                        bases.append(base.id)
+                    elif isinstance(base, ast.Attribute):
+                        # Handle fully qualified module paths
+                        parts = []
+                        current = base
+                        while isinstance(current, ast.Attribute):
+                            parts.append(current.attr)
+                            current = current.value
+                        if isinstance(current, ast.Name):
+                            parts.append(current.id)
+                        bases.append('.'.join(reversed(parts)))
+                if bases:
+                    self.analyzer.class_hierarchy[class_name] = bases
+
+        visitor = ClassVisitor(self, file_path)
+        visitor.visit(tree)
+
+    def _analyze_variable_usage(self, content: str, file_path: str) -> None:
+        """Analyze where variables are used."""
+        tree = ast.parse(content)
+
+        class VariableVisitor(ast.NodeVisitor):
+            def __init__(self, analyzer, file_path):
+                self.analyzer = analyzer
+                self.file_path = file_path
+                self.current_function = None
+                self.current_class = None
+
+            def visit_ClassDef(self, node):
+                old_class = self.current_class
+                self.current_class = node.name
+                self.generic_visit(node)
+                self.current_class = old_class
+
+            def visit_FunctionDef(self, node):
+                old_function = self.current_function
+                scope = f"{self.current_class}.{node.name}" if self.current_class else node.name
+                self.current_function = scope
+                self.generic_visit(node)
+                self.current_function = old_function
+
+            def visit_Name(self, node):
+                if isinstance(node.ctx, (ast.Load, ast.Store)):
+                    scope = f"{self.file_path}:{self.current_function}" if self.current_function else self.file_path
+                    self.analyzer.variable_usages.setdefault(node.id, set()).add(scope)
+                self.generic_visit(node)
+
+            def visit_Attribute(self, node):
+                if isinstance(node.ctx, (ast.Load, ast.Store)) and isinstance(node.value, ast.Name):
+                    if node.value.id == 'self' and self.current_class:
+                        # Record instance variables
+                        var_name = f"{self.current_class}.{node.attr}"
+                        scope = f"{self.file_path}:{self.current_function}"
+                        self.analyzer.variable_usages.setdefault(var_name, set()).add(scope)
+                self.generic_visit(node)
+
+        visitor = VariableVisitor(self, file_path)
+        visitor.visit(tree)
+
+    def get_file_dependencies(self, file_path: str) -> set:
+        """Return the recorded dependencies of the given file."""
+        return self.dependencies.get(file_path, set())
+
+    def get_file_globals(self, file_path: str) -> set:
+        """Return the global variables of the given file."""
+        return self.globals.get(file_path, set())
+
+    def get_function_calls(self, function_name: str) -> set:
+        """Return the functions called by the given function."""
+        return self.function_calls.get(function_name, set())
+
+    def get_class_bases(self, class_name: str) -> list:
+        """Return the base classes of the given class."""
+        return self.class_hierarchy.get(class_name, [])
+
+    def get_variable_usages(self, variable_name: str) -> set:
+        """Return every recorded usage location of a variable."""
+        return self.variable_usages.get(variable_name, set())
+
+    def analyze_file(self, content: str, file_path: Union[str, Path]) -> None:
+        """Analyze all relationships within a single file.
+
+        Args:
+            content (str): File content
+            file_path (Union[str, Path]): File path
+
+        Note:
+            Syntax errors and other analysis failures are caught and
+            reported rather than re-raised.
+        """
+        if isinstance(file_path, Path):
+            file_path = str(file_path)
+        try:
+            # Analyze file dependencies
+            self._analyze_dependencies(content, file_path)
+
+            # Analyze global variables
+            self._analyze_globals(content, file_path)
+
+            # Analyze function-call relationships
+            self._analyze_function_calls(content, file_path)
+
+            # Analyze class-inheritance relationships
+            self._analyze_class_hierarchy(content, file_path)
+
+            # Analyze variable usage locations
+            self._analyze_variable_usage(content, file_path)
+        except SyntaxError:
+            print(f"Syntax error: {file_path}")
+        except Exception as e:
+            print(f"Analysis error in {file_path}: {str(e)}")
+
+    def get_file_analysis(self, file_path: str) -> dict:
+        """Return the complete analysis result for the given file."""
+        return {
+            'dependencies': self.get_file_dependencies(file_path),
+            'globals': self.get_file_globals(file_path),
+            'function_calls': {
+                caller: self.get_function_calls(caller)
+                for caller in self.function_calls
+                if caller.startswith(f"{file_path}:")
+            },
+            'class_hierarchy': {
+                class_name: self.get_class_bases(class_name)
+                for class_name in self.class_hierarchy
+                if class_name.startswith(f"{file_path}:")
+            }
+        }
\ No newline at end of file
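For illustration only (not part of the patch): a minimal sketch of using the new `CodeAnalyzer` directly, assuming the file above is saved as `code_analyzer.py`. The analyzed snippet and the expected output shapes are illustrative:

```python
# Sketch: per-file analysis with the CodeAnalyzer added above.
from code_analyzer import AnalyzerConfig, CodeAnalyzer

source = """
class Base:
    pass

class Child(Base):
    def run(self):
        self.helper()

    def helper(self):
        return 42
"""

analyzer = CodeAnalyzer(AnalyzerConfig(max_workers=2))
analyzer.analyze_file(source, "child.py")

result = analyzer.get_file_analysis("child.py")
print(result["class_hierarchy"])  # {'child.py:Child': ['Base']}
print(result["function_calls"])   # {'child.py:Child.run': {'self.helper'}}
```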
diff --git a/core/analyzers/__init__.py b/core/analyzers/__init__.py
new file mode 100644
index 0000000..b164872
--- /dev/null
+++ b/core/analyzers/__init__.py
@@ -0,0 +1,3 @@
+from .context_analyzer import ContextAnalyzer
+
+__all__ = ['ContextAnalyzer']
\ No newline at end of file
diff --git a/core/analyzers/context_analyzer.py b/core/analyzers/context_analyzer.py
index 2660b7d..dc76892 100644
--- a/core/analyzers/context_analyzer.py
+++ b/core/analyzers/context_analyzer.py
@@ -1,47 +1,230 @@
-class ContextAnalyzer:
-    def __init__(self):
-        self.file_dependencies = {}
-        self.global_variables = {}
-        self.function_calls = {}
-        self.class_hierarchy = {}
-
-    def analyze_project_context(self, files):
-        """Analyze the overall project context."""
-        for file_path in files:
-            self._analyze_file_context(file_path)
-
-    def _analyze_file_context(self, file_path):
-        """Analyze the context of a single file."""
-        with open(file_path, 'r') as f:
-            content = f.read()
-
-        # Analyze file dependencies
-        self._analyze_dependencies(content, file_path)
-
-        # Analyze global variables
-        self._analyze_globals(content, file_path)
-
-        # Analyze function-call relationships
-        self._analyze_function_calls(content, file_path)
-
-        # Analyze class-inheritance relationships
-        self._analyze_class_hierarchy(content, file_path)
-
-    def get_call_graph(self, function_name):
-        """Build the call graph for a function."""
-        call_graph = {
-            'name': function_name,
-            'calls': self.function_calls.get(function_name, []),
-            'called_by': self._find_callers(function_name)
-        }
-        return call_graph
-
-    def get_variable_scope(self, variable_name):
-        """Resolve the scope of a variable."""
-        if variable_name in self.global_variables:
-            return {
-                'type': 'global',
-                'defined_in': self.global_variables[variable_name],
-                'used_in': self._find_variable_usage(variable_name)
-            }
-        return None
\ No newline at end of file
+from code_analyzer import AnalyzerConfig, CodeAnalyzer
+from typing import Dict, Set, List, Any, Optional, Union
+from pathlib import Path
+from concurrent.futures import ThreadPoolExecutor
+import logging
+from functools import lru_cache
+from tqdm import tqdm
+import json
+
+logger = logging.getLogger(__name__)
+
+class ContextAnalyzer:
+    def __init__(self, config: Optional[AnalyzerConfig] = None):
+        self.code_analyzer = CodeAnalyzer(config)
+        self._cache: Dict[str, bool] = {}  # cache of already-analyzed files
+
+    def analyze_project_context(self, files: List[Union[str, Path]]) -> None:
+        """Analyze the overall project context.
+
+        Args:
+            files: List of file paths to analyze
+
+        Example:
+            analyzer = ContextAnalyzer()
+            analyzer.analyze_project_context(['file1.py', 'file2.py'])
+        """
+        with ThreadPoolExecutor(max_workers=self.code_analyzer.config.max_workers) as executor:
+            list(tqdm(
+                executor.map(self._analyze_file_context, files),
+                total=len(files),
+                desc="Analyzing project files"
+            ))
+
+    def _analyze_file_context(self, file_path: Union[str, Path]) -> None:
+        """Analyze the context of a single file."""
+        try:
+            # Check the cache
+            if file_path in self._cache:
+                return
+
+            # Check that the file exists
+            if not Path(file_path).exists():
+                raise FileNotFoundError(f"File does not exist: {file_path}")
+
+            # Check that it is a Python file
+            if not str(file_path).endswith('.py'):
+                raise ValueError(f"Not a Python file: {file_path}")
+
+            with open(file_path, 'r', encoding='utf-8') as f:
+                content = f.read()
+
+            self.code_analyzer.analyze_file(content, file_path)
+            self._cache[file_path] = True
+
+        except (UnicodeDecodeError, FileNotFoundError, ValueError) as e:
+            logger.error(f"Failed to analyze {file_path}: {str(e)}")
+        except Exception as e:
+            logger.exception(f"Unexpected error in {file_path}: {str(e)}")
+
+    @lru_cache(maxsize=128)
+    def get_call_graph(self, function_name: str) -> Dict[str, Any]:
+        """Build the call graph for a function.
+
+        Args:
+            function_name: Name of the function to analyze
+
+        Returns:
+            A dictionary describing the call relationships:
+            {
+                'name': the function name,
+                'calls': set of functions this function calls,
+                'called_by': set of functions that call this function
+            }
+        """
+        calls = self.code_analyzer.get_function_calls(function_name)
+        called_by = set()
+
+        # Find the functions that call this function
+        for caller, callees in self.code_analyzer.function_calls.items():
+            if function_name in callees:
+                called_by.add(caller)
+
+        return {
+            'name': function_name,
+            'calls': calls,
+            'called_by': called_by
+        }
+
+    @lru_cache(maxsize=128)
+    def get_variable_scope(self, variable_name: str) -> Optional[Dict[str, Any]]:
+        """Resolve the scope of a variable."""
+        # Look for definitions of the variable across all files
+        defined_in = set()
+        for file_path, globals_set in self.code_analyzer.globals.items():
+            if variable_name in globals_set:
+                defined_in.add(file_path)
+
+        if defined_in:
+            # Collect usage information for the variable
+            usage_info = self._find_variable_usage(variable_name)
+            return {
+                'type': 'global',
+                'defined_in': list(defined_in),
+                'used_in': usage_info
+            }
+        return None
+    def _find_variable_usage(self, variable_name: str) -> Dict[str, Any]:
+        """Find where a variable is used."""
+        usages = self.code_analyzer.get_variable_usages(variable_name)
+
+        # Group the usage locations by file
+        usage_by_file = {}
+        for usage in usages:
+            if ':' in usage:
+                file_path, function_name = usage.split(':', 1)
+                usage_by_file.setdefault(file_path, set()).add(function_name)
+            else:
+                # Module-level usage
+                usage_by_file.setdefault(usage, set()).add('module_level')
+
+        return {
+            'files': list(usage_by_file.keys()),
+            'details': {
+                file_path: {
+                    'module_level': 'module_level' in functions,
+                    'functions': [f for f in functions if f != 'module_level']
+                }
+                for file_path, functions in usage_by_file.items()
+            }
+        }
+
+    def get_file_context(self, file_path: str) -> dict:
+        """Return the complete context information for a file."""
+        return {
+            'code_analysis': self.code_analyzer.get_file_analysis(file_path)
+        }
+
+    def get_project_analysis(self) -> dict:
+        """Return the project-wide analysis results."""
+        return {
+            'all_dependencies': self.code_analyzer.dependencies,
+            'all_globals': self.code_analyzer.globals,
+            'function_call_graph': self.code_analyzer.function_calls,
+            'class_hierarchy_graph': self.code_analyzer.class_hierarchy,
+            'variable_usage_map': self.code_analyzer.variable_usages
+        }
+
+    def clear_cache(self):
+        """Clear the file cache."""
+        self._cache.clear()
+
+    def validate_analysis(self) -> List[str]:
+        """Validate the completeness and consistency of the analysis results.
+
+        Returns:
+            A list of detected issues
+        """
+        issues = []
+
+        # Check the consistency of function calls
+        for caller, callees in self.code_analyzer.function_calls.items():
+            if ':' not in caller:
+                issues.append(f"Invalid caller format: {caller}")
+
+        # Check the validity of class-hierarchy entries
+        for class_name, bases in self.code_analyzer.class_hierarchy.items():
+            if ':' not in class_name:
+                issues.append(f"Invalid class name format: {class_name}")
+
+        # Check the validity of variable usages
+        for var_name, usages in self.code_analyzer.variable_usages.items():
+            for usage in usages:
+                if ':' not in usage and not usage.endswith('.py'):
+                    issues.append(f"Invalid variable usage location: {usage}")
+
+        return issues
+
+    def clear_analysis(self) -> None:
+        """Clear all analysis results."""
+        self._cache.clear()
+        self.code_analyzer.dependencies.clear()
+        self.code_analyzer.globals.clear()
+        self.code_analyzer.function_calls.clear()
+        self.code_analyzer.class_hierarchy.clear()
+        self.code_analyzer.variable_usages.clear()
+
+    def save_analysis(self, output_path: Union[str, Path]) -> None:
+        """Save the analysis results to a file."""
+        result = {
+            'dependencies': {k: list(v) for k, v in self.code_analyzer.dependencies.items()},
+            'globals': {k: list(v) for k, v in self.code_analyzer.globals.items()},
+            'function_calls': {k: list(v) for k, v in self.code_analyzer.function_calls.items()},
+            'class_hierarchy': self.code_analyzer.class_hierarchy,
+            'variable_usages': {k: list(v) for k, v in self.code_analyzer.variable_usages.items()}
+        }
+
+        with open(output_path, 'w', encoding='utf-8') as f:
+            json.dump(result, f, indent=2)
+
+    def load_analysis(self, input_path: Union[str, Path]) -> None:
+        """Load analysis results from a file."""
+        with open(input_path, 'r', encoding='utf-8') as f:
+            data = json.load(f)
+
+        self.code_analyzer.dependencies = {k: set(v) for k, v in data['dependencies'].items()}
+        self.code_analyzer.globals = {k: set(v) for k, v in data['globals'].items()}
+        self.code_analyzer.function_calls = {k: set(v) for k, v in data['function_calls'].items()}
+        self.code_analyzer.class_hierarchy = data['class_hierarchy']
+        self.code_analyzer.variable_usages = {k: set(v) for k, v in data['variable_usages'].items()}
+
+    def get_analysis_stats(self) -> Dict[str, Any]:
+        """Return summary statistics for the analysis results."""
+        return {
+            'total_files': len(self.code_analyzer.dependencies),
+            'total_functions': len({
+                func.split(':')[1]
+                for func in self.code_analyzer.function_calls.keys()
+            }),
+            'total_classes': len(self.code_analyzer.class_hierarchy),
+            'total_globals': sum(len(vars) for vars in self.code_analyzer.globals.values()),
+            'dependencies_stats': {
+                'total': sum(len(deps) for deps in self.code_analyzer.dependencies.values()),
+                'by_file': {
+                    file: len(deps)
+                    for file, deps in self.code_analyzer.dependencies.items()
+                }
+            }
+        }
\ No newline at end of file
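For illustration only (not part of the patch): a sketch of the new persistence and statistics helpers on `ContextAnalyzer`, assuming it is run from the repository root; the globbed paths and output file name are illustrative:

```python
# Sketch: analyze a package, inspect stats, and round-trip the results as JSON.
from pathlib import Path
from core.analyzers.context_analyzer import ContextAnalyzer

analyzer = ContextAnalyzer()
analyzer.analyze_project_context(list(Path("core").rglob("*.py")))

print(analyzer.get_analysis_stats()["total_files"])
print(analyzer.validate_analysis())      # [] when the results are consistent

analyzer.save_analysis("analysis.json")  # serialized as plain JSON

fresh = ContextAnalyzer()
fresh.load_analysis("analysis.json")
print(fresh.get_project_analysis()["all_dependencies"].keys())
```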
diff --git a/examples/basic_usage.py b/examples/basic_usage.py
new file mode 100644
index 0000000..6b1b168
--- /dev/null
+++ b/examples/basic_usage.py
@@ -0,0 +1,24 @@
+from core.analyzers.context_analyzer import ContextAnalyzer
+from code_analyzer import AnalyzerConfig
+
+def analyze_single_file():
+    """Single-file analysis example."""
+    config = AnalyzerConfig()
+    analyzer = ContextAnalyzer(config)
+
+    # Analyze a single file
+    analyzer.analyze_project_context(['example.py'])
+
+    # Get the analysis result
+    context = analyzer.get_file_context('example.py')
+    print("File analysis result:", context)
+
+    # Get the call graph
+    call_graph = analyzer.get_call_graph('main')
+    print("Call graph:", call_graph)
+
+    # Clean up
+    analyzer.clear_analysis()
+
+if __name__ == '__main__':
+    analyze_single_file()
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index c5f293b..2c6c7c6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1 @@
-fastapi>=0.68.0
-uvicorn>=0.15.0
-python-multipart>=0.0.5
-openai>=1.0.0
-javalang>=0.13.0
-aiohttp>=3.8.1
-python-dotenv>=0.19.0
-pydantic>=2.0.0
-pydantic-settings>=2.0.0
-php-ast>=1.1.0
\ No newline at end of file
+tqdm>=4.65.0
\ No newline at end of file
diff --git a/tests/test_analyzers.py b/tests/test_analyzers.py
new file mode 100644
index 0000000..0448edc
--- /dev/null
+++ b/tests/test_analyzers.py
@@ -0,0 +1,20 @@
+import unittest
+from pathlib import Path
+from core.analyzers.context_analyzer import ContextAnalyzer
+from code_analyzer import AnalyzerConfig
+
+class TestContextAnalyzer(unittest.TestCase):
+    def setUp(self):
+        self.analyzer = ContextAnalyzer(AnalyzerConfig())
+
+    def test_analyze_file(self):
+        test_file = Path(__file__).parent / 'test_data' / 'simple.py'
+        self.analyzer.analyze_project_context([test_file])
+        context = self.analyzer.get_file_context(str(test_file))
+        self.assertIsNotNone(context)
+
+    def tearDown(self):
+        self.analyzer.clear_analysis()
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
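Note that the test above reads `tests/test_data/simple.py`, which is not included in this diff. A hypothetical fixture along these lines (purely illustrative, not part of the patch) would be enough to satisfy it, since any parseable file exercises the analysis paths:

```python
# tests/test_data/simple.py -- hypothetical fixture, not included in this diff.
# A few constructs exercise the dependency, call, and inheritance analysis.
import os


GREETING = "hello"


class Greeter:
    def greet(self, name):
        return os.path.join(GREETING, name)


def main():
    return Greeter().greet("world")
```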