AIPentest/promptinjection/result_analysis.py
AiShell 75564e12e5
Add files via upload
for test and result analysis
2025-06-07 16:47:17 +08:00

161 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
import numpy as np
# Global configuration: the label string the classifier emits for a
# safe/benign prediction. Model-dependent — e.g. "LABEL_0" for Llama
# Prompt Guard, "LEGIT" for Deepset, "SAFE" for ProtectAI (see __main__).
SAFE_LABEL = "LABEL_0"  # label meaning safe/harmless
def analyze_classification_results(csv_file_path, safe_label=None):
"""
分析二分类结果,计算混淆矩阵和召回率
参数:
csv_file_path: CSV文件路径
safe_label: 表示安全/无害的标签默认使用全局变量SAFE_LABEL
返回:
包含统计结果的字典
"""
# 如果没有指定safe_label使用全局变量
if safe_label is None:
safe_label = SAFE_LABEL
# 读取CSV文件
try:
df = pd.read_csv(csv_file_path)
print(f"成功读取数据,共 {len(df)} 行记录")
print(f"列名: {list(df.columns)}")
except Exception as e:
print(f"读取文件出错: {e}")
return None
# 检查数据类型并进行相应处理
print(f"\nsource_label数据类型: {df['source_label'].dtype}")
print(f"source_label唯一值: {df['source_label'].unique()}")
print(f"classified_label数据类型: {df['classified_label'].dtype}")
print(f"classified_label唯一值: {df['classified_label'].unique()}")
# 数据预处理:转换标签为二进制
# 根据实际数据source_label已经是0和1直接使用
# 0表示benign无害1表示有害
df['true_label'] = df['source_label']
# classified_label: safe_label=0, 其他=1
df['pred_label'] = df['classified_label'].apply(lambda x: 0 if str(x).strip().upper() == safe_label.upper() else 1)
# 打印标签分布
print("\n=== 标签分布 ===")
print("真实标签分布:")
print(f" benign (0): {sum(df['true_label'] == 0)}")
print(f" 有害 (1): {sum(df['true_label'] == 1)}")
print("\n预测标签分布:")
print(f" {safe_label} (0): {sum(df['pred_label'] == 0)}")
print(f"{safe_label} (1): {sum(df['pred_label'] == 1)}")
# 计算混淆矩阵
# TP: 真实为1预测为1 (正确识别为有害)
# TN: 真实为0预测为0 (正确识别为无害)
# FP: 真实为0预测为1 (误判为有害)
# FN: 真实为1预测为0 (漏判为无害)
TP = sum((df['true_label'] == 1) & (df['pred_label'] == 1))
TN = sum((df['true_label'] == 0) & (df['pred_label'] == 0))
FP = sum((df['true_label'] == 0) & (df['pred_label'] == 1))
FN = sum((df['true_label'] == 1) & (df['pred_label'] == 0))
# 计算性能指标
recall = TP / (TP + FN) if (TP + FN) > 0 else 0 # 召回率/敏感度
precision = TP / (TP + FP) if (TP + FP) > 0 else 0 # 精确率
accuracy = (TP + TN) / (TP + TN + FP + FN) # 准确率
specificity = TN / (TN + FP) if (TN + FP) > 0 else 0 # 特异性
false_positive_rate = FP / (FP + TN) if (FP + TN) > 0 else 0 # 误报率
f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
# 输出结果
print("\n=== 混淆矩阵 ===")
print(f"True Positive (TP): {TP}")
print(f"True Negative (TN): {TN}")
print(f"False Positive (FP): {FP}")
print(f"False Negative (FN): {FN}")
print("\n=== 性能指标 ===")
print(f"召回率 (Recall): {recall:.4f}")
print(f"精确率 (Precision): {precision:.4f}")
print(f"准确率 (Accuracy): {accuracy:.4f}")
print(f"特异性 (Specificity): {specificity:.4f}")
print(f"误报率 (FPR): {false_positive_rate:.4f}")
print(f"F1分数: {f1_score:.4f}")
print("\n=== 混淆矩阵表格 ===")
print(" 预测")
print(f" {safe_label}(0) 非{safe_label}(1)")
print(f"真实 benign(0) {TN:4d} {FP:4d}")
print(f" 有害(1) {FN:4d} {TP:4d}")
# 详细分析
print("\n=== 详细分析 ===")
if FP > 0:
print(f"误报(FP): {FP} 个良性样本被错误分类为有害")
fp_samples = df[(df['true_label'] == 0) & (df['pred_label'] == 1)]
print("误报样本示例:")
for idx, row in fp_samples.head(3).iterrows():
print(f" - 文本: {row['text'][:100]}...")
print(f" 预测分数: {row['classified_score']:.4f}")
if FN > 0:
print(f"\n漏报(FN): {FN} 个有害样本被错误分类为无害")
fn_samples = df[(df['true_label'] == 1) & (df['pred_label'] == 0)]
print("漏报样本示例:")
for idx, row in fn_samples.head(3).iterrows():
print(f" - 文本: {row['text'][:100]}...")
print(f" 预测分数: {row['classified_score']:.4f}")
# 分析预测分数分布
print("\n=== 预测分数分析 ===")
print(f"{safe_label}预测的平均分数: {df[df['pred_label'] == 0]['classified_score'].mean():.4f}")
print(f"{safe_label}预测的平均分数: {df[df['pred_label'] == 1]['classified_score'].mean():.4f}")
# 返回结果字典
results = {
'TP': TP, 'TN': TN, 'FP': FP, 'FN': FN,
'recall': recall, 'precision': precision,
'accuracy': accuracy, 'specificity': specificity,
'false_positive_rate': false_positive_rate, 'f1_score': f1_score,
'data_summary': {
'total_samples': len(df),
'benign_samples': sum(df['true_label'] == 0),
'harmful_samples': sum(df['true_label'] == 1),
'safe_predictions': sum(df['pred_label'] == 0),
'unsafe_predictions': sum(df['pred_label'] == 1)
}
}
return results
# Usage example: run the analysis on one results file from the command line.
if __name__ == "__main__":
    # Global configuration — pick the safe label matching the model under test.
    #SAFE_LABEL = "LABEL_0" # Llama Prompt Guard
    #SAFE_LABEL = "LEGIT"   # Deepset
    SAFE_LABEL = "SAFE"     # ProtectAI

    # Path to the results CSV (adjust to the actual file location).
    results_path = "result/deepset_protectai_results.csv"

    print("开始分析分类结果...")
    print(f"当前安全标签配置: {SAFE_LABEL}")

    # The global configuration can be overridden per call:
    # results = analyze_classification_results(results_path, safe_label="SAFE")
    results = analyze_classification_results(results_path)

    if results:
        print(f"\n分析完成!")
        # Print the headline metrics via a caption/key table.
        for caption, key in (
            ("准确率", "accuracy"),
            ("召回率", "recall"),
            ("精确率", "precision"),
            ("误报率", "false_positive_rate"),
            ("F1分数", "f1_score"),
        ):
            print(f"{caption}: {results[key]:.4f}")
    else:
        print("分析失败,请检查文件路径和格式。")