mirror of
https://github.com/SunZhimin2021/AIPentest.git
synced 2025-06-20 18:00:18 +00:00
161 lines
6.5 KiB
Python
161 lines
6.5 KiB
Python
import pandas as pd
import numpy as np

# Global configuration: label string that marks a prediction as safe/benign.
SAFE_LABEL = "LABEL_0"
def analyze_classification_results(csv_file_path, safe_label=None):
    """Analyze binary classification results: confusion matrix and recall.

    Reads a CSV that must contain the columns ``source_label`` (ground
    truth: 0 = benign, 1 = harmful) and ``classified_label`` (raw label
    string emitted by the classifier). A row whose ``classified_label``
    equals *safe_label* (case-insensitive, whitespace-stripped) is treated
    as predicted-safe (0); anything else is predicted-harmful (1).
    The optional columns ``text`` and ``classified_score`` are only used
    for the example/score printouts and may be absent.

    Parameters:
        csv_file_path: path to the results CSV file
        safe_label: label denoting a safe/benign prediction; defaults to
            the module-level SAFE_LABEL

    Returns:
        dict with confusion-matrix counts, performance metrics and a data
        summary, or None if the file cannot be read or lacks the required
        columns.
    """
    # Fall back to the module-level default when no label is supplied.
    if safe_label is None:
        safe_label = SAFE_LABEL

    # Read the CSV; any read failure is reported and signalled with None.
    try:
        df = pd.read_csv(csv_file_path)
        print(f"成功读取数据,共 {len(df)} 行记录")
        print(f"列名: {list(df.columns)}")
    except Exception as e:
        print(f"读取文件出错: {e}")
        return None

    # Fail early (instead of raising KeyError below) when the mandatory
    # columns are missing, keeping the "return None on bad input" contract.
    missing = {'source_label', 'classified_label'} - set(df.columns)
    if missing:
        print(f"Error: missing required columns: {sorted(missing)}")
        return None

    # Inspect dtypes and value sets to help diagnose label mismatches.
    print(f"\nsource_label数据类型: {df['source_label'].dtype}")
    print(f"source_label唯一值: {df['source_label'].unique()}")
    print(f"classified_label数据类型: {df['classified_label'].dtype}")
    print(f"classified_label唯一值: {df['classified_label'].unique()}")

    # source_label is already 0/1 (0 = benign, 1 = harmful); use it as-is.
    df['true_label'] = df['source_label']

    # Map classified_label to binary: safe_label -> 0, everything else -> 1.
    # The target is uppercased once, outside the per-row lambda.
    target = safe_label.upper()
    df['pred_label'] = df['classified_label'].apply(
        lambda x: 0 if str(x).strip().upper() == target else 1)

    # Hoist the per-class counts (plain ints, computed once).
    n_benign = int((df['true_label'] == 0).sum())
    n_harmful = int((df['true_label'] == 1).sum())
    n_pred_safe = int((df['pred_label'] == 0).sum())
    n_pred_unsafe = int((df['pred_label'] == 1).sum())

    # Print the label distributions.
    print("\n=== 标签分布 ===")
    print("真实标签分布:")
    print(f" benign (0): {n_benign} 个")
    print(f" 有害 (1): {n_harmful} 个")

    print("\n预测标签分布:")
    print(f" {safe_label} (0): {n_pred_safe} 个")
    print(f" 非{safe_label} (1): {n_pred_unsafe} 个")

    # Confusion matrix:
    #   TP: true 1, predicted 1 (correctly flagged as harmful)
    #   TN: true 0, predicted 0 (correctly flagged as benign)
    #   FP: true 0, predicted 1 (benign misflagged as harmful)
    #   FN: true 1, predicted 0 (harmful missed as benign)
    TP = int(((df['true_label'] == 1) & (df['pred_label'] == 1)).sum())
    TN = int(((df['true_label'] == 0) & (df['pred_label'] == 0)).sum())
    FP = int(((df['true_label'] == 0) & (df['pred_label'] == 1)).sum())
    FN = int(((df['true_label'] == 1) & (df['pred_label'] == 0)).sum())

    # Performance metrics; each ratio is guarded against a zero denominator.
    recall = TP / (TP + FN) if (TP + FN) > 0 else 0          # recall / sensitivity
    precision = TP / (TP + FP) if (TP + FP) > 0 else 0       # precision
    accuracy = (TP + TN) / (TP + TN + FP + FN)               # accuracy
    specificity = TN / (TN + FP) if (TN + FP) > 0 else 0     # specificity
    false_positive_rate = FP / (FP + TN) if (FP + TN) > 0 else 0  # FPR
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    # Report the confusion matrix.
    print("\n=== 混淆矩阵 ===")
    print(f"True Positive (TP): {TP}")
    print(f"True Negative (TN): {TN}")
    print(f"False Positive (FP): {FP}")
    print(f"False Negative (FN): {FN}")

    print("\n=== 性能指标 ===")
    print(f"召回率 (Recall): {recall:.4f}")
    print(f"精确率 (Precision): {precision:.4f}")
    print(f"准确率 (Accuracy): {accuracy:.4f}")
    print(f"特异性 (Specificity): {specificity:.4f}")
    print(f"误报率 (FPR): {false_positive_rate:.4f}")
    print(f"F1分数: {f1_score:.4f}")

    print("\n=== 混淆矩阵表格 ===")
    print(" 预测")
    print(f" {safe_label}(0) 非{safe_label}(1)")
    print(f"真实 benign(0) {TN:4d} {FP:4d}")
    print(f" 有害(1) {FN:4d} {TP:4d}")

    # Detailed analysis: show a few misclassified examples. The text/score
    # columns are optional, so each line is guarded instead of crashing.
    has_text = 'text' in df.columns
    has_score = 'classified_score' in df.columns

    print("\n=== 详细分析 ===")
    if FP > 0:
        print(f"误报(FP): {FP} 个良性样本被错误分类为有害")
        fp_samples = df[(df['true_label'] == 0) & (df['pred_label'] == 1)]
        print("误报样本示例:")
        for idx, row in fp_samples.head(3).iterrows():
            if has_text:
                print(f" - 文本: {row['text'][:100]}...")
            if has_score:
                print(f" 预测分数: {row['classified_score']:.4f}")

    if FN > 0:
        print(f"\n漏报(FN): {FN} 个有害样本被错误分类为无害")
        fn_samples = df[(df['true_label'] == 1) & (df['pred_label'] == 0)]
        print("漏报样本示例:")
        for idx, row in fn_samples.head(3).iterrows():
            if has_text:
                print(f" - 文本: {row['text'][:100]}...")
            if has_score:
                print(f" 预测分数: {row['classified_score']:.4f}")

    # Mean prediction score per predicted class (only if the column exists).
    if has_score:
        print("\n=== 预测分数分析 ===")
        print(f"{safe_label}预测的平均分数: {df[df['pred_label'] == 0]['classified_score'].mean():.4f}")
        print(f"非{safe_label}预测的平均分数: {df[df['pred_label'] == 1]['classified_score'].mean():.4f}")

    # Assemble the result dictionary.
    results = {
        'TP': TP, 'TN': TN, 'FP': FP, 'FN': FN,
        'recall': recall, 'precision': precision,
        'accuracy': accuracy, 'specificity': specificity,
        'false_positive_rate': false_positive_rate, 'f1_score': f1_score,
        'data_summary': {
            'total_samples': len(df),
            'benign_samples': n_benign,
            'harmful_samples': n_harmful,
            'safe_predictions': n_pred_safe,
            'unsafe_predictions': n_pred_unsafe
        }
    }

    return results
|
||
|
||
# Usage example
if __name__ == "__main__":
    # Global configuration: choose the safe label matching the model under test.
    # SAFE_LABEL = "LABEL_0"  # for the llama prompt guard model
    # SAFE_LABEL = "LEGIT"    # for the Deepset model
    SAFE_LABEL = "SAFE"  # for the ProtectAI model

    # Point this at the CSV produced by the classification run.
    csv_file_path = "result/deepset_protectai_results.csv"

    print("开始分析分类结果...")
    print(f"当前安全标签配置: {SAFE_LABEL}")

    # The global default can also be overridden per call:
    # stats = analyze_classification_results(csv_file_path, safe_label="SAFE")
    stats = analyze_classification_results(csv_file_path)

    if stats:
        print("\n分析完成!")
        for caption, key in (
            ("准确率", "accuracy"),
            ("召回率", "recall"),
            ("精确率", "precision"),
            ("误报率", "false_positive_rate"),
            ("F1分数", "f1_score"),
        ):
            print(f"{caption}: {stats[key]:.4f}")
    else:
        print("分析失败,请检查文件路径和格式。")