mirror of
https://github.com/SunZhimin2021/AIPentest.git
synced 2025-06-21 18:30:41 +00:00
Add files via upload
for test and result analysis
This commit is contained in:
parent
662f1f9694
commit
75564e12e5
249
promptinjection/guardtest.py
Normal file
249
promptinjection/guardtest.py
Normal file
@ -0,0 +1,249 @@
|
|||||||
|
import pandas as pd
|
||||||
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||||
|
from transformers import pipeline
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
def load_checkpoint(checkpoint_file):
|
||||||
|
"""加载断点信息"""
|
||||||
|
if os.path.exists(checkpoint_file):
|
||||||
|
try:
|
||||||
|
with open(checkpoint_file, 'r', encoding='utf-8') as f:
|
||||||
|
checkpoint = json.load(f)
|
||||||
|
print(f"发现断点文件,从第 {checkpoint['last_processed'] + 1} 条记录开始继续处理")
|
||||||
|
return checkpoint
|
||||||
|
except Exception as e:
|
||||||
|
print(f"读取断点文件失败: {e}")
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def save_checkpoint(checkpoint_file, last_processed, total_records):
|
||||||
|
"""保存断点信息"""
|
||||||
|
checkpoint = {
|
||||||
|
'last_processed': last_processed,
|
||||||
|
'total_records': total_records,
|
||||||
|
'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
with open(checkpoint_file, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(checkpoint, f, ensure_ascii=False, indent=2)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"保存断点文件失败: {e}")
|
||||||
|
|
||||||
|
def save_batch_results(results, output_file, mode='w'):
|
||||||
|
"""保存批量结果到CSV"""
|
||||||
|
try:
|
||||||
|
results_df = pd.DataFrame(results)
|
||||||
|
if mode == 'w':
|
||||||
|
results_df.to_csv(output_file, index=False, encoding='utf-8')
|
||||||
|
else: # append mode
|
||||||
|
results_df.to_csv(output_file, mode='a', header=False, index=False, encoding='utf-8')
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"保存结果失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 设置文件路径,可使用ProtectAI,deepset,metaguard86M
|
||||||
|
local_path = "./models/metaguard86M"
|
||||||
|
input_file = "deepset_prompt_injection_data.csv"
|
||||||
|
output_file = "./result/deepset_86M_results.csv"
|
||||||
|
checkpoint_file = "86M_checkpoint.json"
|
||||||
|
batch_size = 10 # 每批处理的记录数
|
||||||
|
|
||||||
|
# 检查模型路径是否存在
|
||||||
|
if not os.path.exists(local_path):
|
||||||
|
print(f"错误: 模型路径 {local_path} 不存在")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 加载断点信息
|
||||||
|
checkpoint = load_checkpoint(checkpoint_file)
|
||||||
|
start_idx = 0
|
||||||
|
if checkpoint:
|
||||||
|
start_idx = checkpoint['last_processed'] + 1
|
||||||
|
|
||||||
|
# 加载模型和分词器
|
||||||
|
print("正在加载模型和分词器...")
|
||||||
|
try:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(local_path)
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(local_path)
|
||||||
|
|
||||||
|
# 创建分类器
|
||||||
|
classifier = pipeline(
|
||||||
|
"text-classification",
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
truncation=True,
|
||||||
|
max_length=512,
|
||||||
|
)
|
||||||
|
print("模型加载成功!")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"模型加载失败: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 读取CSV文件
|
||||||
|
if not os.path.exists(input_file):
|
||||||
|
print(f"错误: 输入文件 {input_file} 不存在")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(f"正在读取 {input_file}...")
|
||||||
|
df = pd.read_csv(input_file)
|
||||||
|
print(f"成功读取 {len(df)} 行数据")
|
||||||
|
|
||||||
|
# 检查必要的列是否存在
|
||||||
|
if 'text' not in df.columns:
|
||||||
|
print("错误: CSV文件中缺少 'text' 列")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 如果没有label列,创建一个空的
|
||||||
|
if 'label' not in df.columns:
|
||||||
|
df['label'] = ''
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"读取CSV文件失败: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 如果是断点续传,检查总记录数是否一致
|
||||||
|
if checkpoint and checkpoint['total_records'] != len(df):
|
||||||
|
print(f"警告: 当前文件记录数({len(df)})与断点记录数({checkpoint['total_records']})不一致")
|
||||||
|
response = input("是否要重新开始处理? (y/n): ")
|
||||||
|
if response.lower() == 'y':
|
||||||
|
start_idx = 0
|
||||||
|
if os.path.exists(output_file):
|
||||||
|
os.remove(output_file)
|
||||||
|
|
||||||
|
# 如果从头开始,清空之前的输出文件
|
||||||
|
if start_idx == 0 and os.path.exists(output_file):
|
||||||
|
os.remove(output_file)
|
||||||
|
print("已清空之前的输出文件")
|
||||||
|
|
||||||
|
# 创建结果列表
|
||||||
|
batch_results = []
|
||||||
|
total_processed = start_idx
|
||||||
|
|
||||||
|
# 记录分类开始时间
|
||||||
|
start_time = time.time()
|
||||||
|
start_datetime = datetime.now()
|
||||||
|
print(f"分类开始时间: {start_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
|
|
||||||
|
if start_idx > 0:
|
||||||
|
print(f"从第 {start_idx + 1} 条记录开始继续处理...")
|
||||||
|
else:
|
||||||
|
print("开始分类处理...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
for idx in range(start_idx, len(df)):
|
||||||
|
row = df.iloc[idx]
|
||||||
|
text = str(row['text']) # 确保是字符串类型
|
||||||
|
source_label = row.get('label', '') # 获取原始标签,如果不存在则为空
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 进行分类
|
||||||
|
prediction = classifier(text)
|
||||||
|
|
||||||
|
# 提取预测结果
|
||||||
|
predicted_label = prediction[0]['label']
|
||||||
|
predicted_score = prediction[0]['score']
|
||||||
|
|
||||||
|
# 添加到批量结果列表
|
||||||
|
batch_results.append({
|
||||||
|
'text': text,
|
||||||
|
'source_label': source_label,
|
||||||
|
'classified_label': predicted_label,
|
||||||
|
'classified_score': predicted_score
|
||||||
|
})
|
||||||
|
|
||||||
|
total_processed = idx
|
||||||
|
|
||||||
|
# 每处理batch_size条记录或到达最后一条记录时保存
|
||||||
|
if len(batch_results) >= batch_size or idx == len(df) - 1:
|
||||||
|
# 保存批量结果
|
||||||
|
mode = 'w' if idx < batch_size and start_idx == 0 else 'a'
|
||||||
|
if save_batch_results(batch_results, output_file, mode):
|
||||||
|
print(f"已处理并保存 {idx + 1}/{len(df)} 条数据")
|
||||||
|
|
||||||
|
# 保存断点
|
||||||
|
save_checkpoint(checkpoint_file, idx, len(df))
|
||||||
|
|
||||||
|
# 清空批量结果列表
|
||||||
|
batch_results = []
|
||||||
|
else:
|
||||||
|
print(f"保存第 {idx + 1} 批数据失败,停止处理")
|
||||||
|
break
|
||||||
|
|
||||||
|
# 打印进度(每10条显示一次)
|
||||||
|
elif (idx + 1) % 10 == 0:
|
||||||
|
print(f"已处理 {idx + 1}/{len(df)} 条数据")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理第 {idx + 1} 行时出错: {e}")
|
||||||
|
# 添加错误记录
|
||||||
|
batch_results.append({
|
||||||
|
'text': text,
|
||||||
|
'source_label': source_label,
|
||||||
|
'classified_label': 'ERROR',
|
||||||
|
'classified_score': 0.0
|
||||||
|
})
|
||||||
|
total_processed = idx
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print(f"\n用户中断处理,已处理到第 {total_processed + 1} 条记录")
|
||||||
|
# 保存剩余的批量结果
|
||||||
|
if batch_results:
|
||||||
|
mode = 'a' if total_processed > 0 else 'w'
|
||||||
|
save_batch_results(batch_results, output_file, mode)
|
||||||
|
# 保存断点
|
||||||
|
save_checkpoint(checkpoint_file, total_processed, len(df))
|
||||||
|
return
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n处理过程中发生错误: {e}")
|
||||||
|
# 保存剩余的批量结果
|
||||||
|
if batch_results:
|
||||||
|
mode = 'a' if total_processed > 0 else 'w'
|
||||||
|
save_batch_results(batch_results, output_file, mode)
|
||||||
|
# 保存断点
|
||||||
|
save_checkpoint(checkpoint_file, total_processed, len(df))
|
||||||
|
return
|
||||||
|
|
||||||
|
# 记录分类结束时间并计算性能统计
|
||||||
|
end_time = time.time()
|
||||||
|
end_datetime = datetime.now()
|
||||||
|
total_time = end_time - start_time
|
||||||
|
processed_count = total_processed - start_idx + 1
|
||||||
|
|
||||||
|
print(f"\n分类结束时间: {end_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
|
print(f"本次处理耗时: {total_time:.2f} 秒")
|
||||||
|
print(f"本次处理记录数: {processed_count}")
|
||||||
|
if processed_count > 0:
|
||||||
|
print(f"平均每条数据耗时: {total_time/processed_count:.3f} 秒")
|
||||||
|
print(f"处理速度: {processed_count/total_time:.2f} 条/秒")
|
||||||
|
|
||||||
|
# 显示最终统计信息
|
||||||
|
try:
|
||||||
|
if os.path.exists(output_file):
|
||||||
|
results_df = pd.read_csv(output_file)
|
||||||
|
print(f"\n总共完成 {len(results_df)} 条记录的分类")
|
||||||
|
|
||||||
|
# 打印统计信息
|
||||||
|
print("\n分类结果统计:")
|
||||||
|
print(results_df['classified_label'].value_counts())
|
||||||
|
|
||||||
|
# 显示前几行结果作为示例
|
||||||
|
print(f"\n前5行结果预览:")
|
||||||
|
print(results_df.head())
|
||||||
|
|
||||||
|
# 如果全部完成,删除断点文件
|
||||||
|
if len(results_df) == len(df):
|
||||||
|
if os.path.exists(checkpoint_file):
|
||||||
|
os.remove(checkpoint_file)
|
||||||
|
print(f"\n所有记录处理完成,已删除断点文件")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"读取最终结果失败: {e}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
161
promptinjection/result_analysis.py
Normal file
161
promptinjection/result_analysis.py
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# 全局配置变量
|
||||||
|
SAFE_LABEL = "LABEL_0" # 表示安全/无害的标签
|
||||||
|
|
||||||
|
def analyze_classification_results(csv_file_path, safe_label=None):
|
||||||
|
"""
|
||||||
|
分析二分类结果,计算混淆矩阵和召回率
|
||||||
|
|
||||||
|
参数:
|
||||||
|
csv_file_path: CSV文件路径
|
||||||
|
safe_label: 表示安全/无害的标签,默认使用全局变量SAFE_LABEL
|
||||||
|
|
||||||
|
返回:
|
||||||
|
包含统计结果的字典
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 如果没有指定safe_label,使用全局变量
|
||||||
|
if safe_label is None:
|
||||||
|
safe_label = SAFE_LABEL
|
||||||
|
|
||||||
|
# 读取CSV文件
|
||||||
|
try:
|
||||||
|
df = pd.read_csv(csv_file_path)
|
||||||
|
print(f"成功读取数据,共 {len(df)} 行记录")
|
||||||
|
print(f"列名: {list(df.columns)}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"读取文件出错: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 检查数据类型并进行相应处理
|
||||||
|
print(f"\nsource_label数据类型: {df['source_label'].dtype}")
|
||||||
|
print(f"source_label唯一值: {df['source_label'].unique()}")
|
||||||
|
print(f"classified_label数据类型: {df['classified_label'].dtype}")
|
||||||
|
print(f"classified_label唯一值: {df['classified_label'].unique()}")
|
||||||
|
|
||||||
|
# 数据预处理:转换标签为二进制
|
||||||
|
# 根据实际数据,source_label已经是0和1,直接使用
|
||||||
|
# 0表示benign(无害),1表示有害
|
||||||
|
df['true_label'] = df['source_label']
|
||||||
|
|
||||||
|
# classified_label: safe_label=0, 其他=1
|
||||||
|
df['pred_label'] = df['classified_label'].apply(lambda x: 0 if str(x).strip().upper() == safe_label.upper() else 1)
|
||||||
|
|
||||||
|
# 打印标签分布
|
||||||
|
print("\n=== 标签分布 ===")
|
||||||
|
print("真实标签分布:")
|
||||||
|
print(f" benign (0): {sum(df['true_label'] == 0)} 个")
|
||||||
|
print(f" 有害 (1): {sum(df['true_label'] == 1)} 个")
|
||||||
|
|
||||||
|
print("\n预测标签分布:")
|
||||||
|
print(f" {safe_label} (0): {sum(df['pred_label'] == 0)} 个")
|
||||||
|
print(f" 非{safe_label} (1): {sum(df['pred_label'] == 1)} 个")
|
||||||
|
|
||||||
|
# 计算混淆矩阵
|
||||||
|
# TP: 真实为1,预测为1 (正确识别为有害)
|
||||||
|
# TN: 真实为0,预测为0 (正确识别为无害)
|
||||||
|
# FP: 真实为0,预测为1 (误判为有害)
|
||||||
|
# FN: 真实为1,预测为0 (漏判为无害)
|
||||||
|
|
||||||
|
TP = sum((df['true_label'] == 1) & (df['pred_label'] == 1))
|
||||||
|
TN = sum((df['true_label'] == 0) & (df['pred_label'] == 0))
|
||||||
|
FP = sum((df['true_label'] == 0) & (df['pred_label'] == 1))
|
||||||
|
FN = sum((df['true_label'] == 1) & (df['pred_label'] == 0))
|
||||||
|
|
||||||
|
# 计算性能指标
|
||||||
|
recall = TP / (TP + FN) if (TP + FN) > 0 else 0 # 召回率/敏感度
|
||||||
|
precision = TP / (TP + FP) if (TP + FP) > 0 else 0 # 精确率
|
||||||
|
accuracy = (TP + TN) / (TP + TN + FP + FN) # 准确率
|
||||||
|
specificity = TN / (TN + FP) if (TN + FP) > 0 else 0 # 特异性
|
||||||
|
false_positive_rate = FP / (FP + TN) if (FP + TN) > 0 else 0 # 误报率
|
||||||
|
f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
|
||||||
|
|
||||||
|
# 输出结果
|
||||||
|
print("\n=== 混淆矩阵 ===")
|
||||||
|
print(f"True Positive (TP): {TP}")
|
||||||
|
print(f"True Negative (TN): {TN}")
|
||||||
|
print(f"False Positive (FP): {FP}")
|
||||||
|
print(f"False Negative (FN): {FN}")
|
||||||
|
|
||||||
|
print("\n=== 性能指标 ===")
|
||||||
|
print(f"召回率 (Recall): {recall:.4f}")
|
||||||
|
print(f"精确率 (Precision): {precision:.4f}")
|
||||||
|
print(f"准确率 (Accuracy): {accuracy:.4f}")
|
||||||
|
print(f"特异性 (Specificity): {specificity:.4f}")
|
||||||
|
print(f"误报率 (FPR): {false_positive_rate:.4f}")
|
||||||
|
print(f"F1分数: {f1_score:.4f}")
|
||||||
|
|
||||||
|
print("\n=== 混淆矩阵表格 ===")
|
||||||
|
print(" 预测")
|
||||||
|
print(f" {safe_label}(0) 非{safe_label}(1)")
|
||||||
|
print(f"真实 benign(0) {TN:4d} {FP:4d}")
|
||||||
|
print(f" 有害(1) {FN:4d} {TP:4d}")
|
||||||
|
|
||||||
|
# 详细分析
|
||||||
|
print("\n=== 详细分析 ===")
|
||||||
|
if FP > 0:
|
||||||
|
print(f"误报(FP): {FP} 个良性样本被错误分类为有害")
|
||||||
|
fp_samples = df[(df['true_label'] == 0) & (df['pred_label'] == 1)]
|
||||||
|
print("误报样本示例:")
|
||||||
|
for idx, row in fp_samples.head(3).iterrows():
|
||||||
|
print(f" - 文本: {row['text'][:100]}...")
|
||||||
|
print(f" 预测分数: {row['classified_score']:.4f}")
|
||||||
|
|
||||||
|
if FN > 0:
|
||||||
|
print(f"\n漏报(FN): {FN} 个有害样本被错误分类为无害")
|
||||||
|
fn_samples = df[(df['true_label'] == 1) & (df['pred_label'] == 0)]
|
||||||
|
print("漏报样本示例:")
|
||||||
|
for idx, row in fn_samples.head(3).iterrows():
|
||||||
|
print(f" - 文本: {row['text'][:100]}...")
|
||||||
|
print(f" 预测分数: {row['classified_score']:.4f}")
|
||||||
|
|
||||||
|
# 分析预测分数分布
|
||||||
|
print("\n=== 预测分数分析 ===")
|
||||||
|
print(f"{safe_label}预测的平均分数: {df[df['pred_label'] == 0]['classified_score'].mean():.4f}")
|
||||||
|
print(f"非{safe_label}预测的平均分数: {df[df['pred_label'] == 1]['classified_score'].mean():.4f}")
|
||||||
|
|
||||||
|
# 返回结果字典
|
||||||
|
results = {
|
||||||
|
'TP': TP, 'TN': TN, 'FP': FP, 'FN': FN,
|
||||||
|
'recall': recall, 'precision': precision,
|
||||||
|
'accuracy': accuracy, 'specificity': specificity,
|
||||||
|
'false_positive_rate': false_positive_rate, 'f1_score': f1_score,
|
||||||
|
'data_summary': {
|
||||||
|
'total_samples': len(df),
|
||||||
|
'benign_samples': sum(df['true_label'] == 0),
|
||||||
|
'harmful_samples': sum(df['true_label'] == 1),
|
||||||
|
'safe_predictions': sum(df['pred_label'] == 0),
|
||||||
|
'unsafe_predictions': sum(df['pred_label'] == 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 全局配置:可以根据不同模型修改安全标签
|
||||||
|
#SAFE_LABEL = "LABEL_0" # 用于llama prompt guard的模型
|
||||||
|
#SAFE_LABEL = "LEGIT" # 用于Deepset模型
|
||||||
|
SAFE_LABEL = "SAFE" # 用于ProtectAI模型
|
||||||
|
|
||||||
|
# 替换为您的CSV文件路径
|
||||||
|
csv_file_path = "result/deepset_protectai_results.csv" # 请修改为实际的文件路径
|
||||||
|
|
||||||
|
print("开始分析分类结果...")
|
||||||
|
print(f"当前安全标签配置: {SAFE_LABEL}")
|
||||||
|
|
||||||
|
# 可以通过参数覆盖全局配置
|
||||||
|
# results = analyze_classification_results(csv_file_path, safe_label="SAFE")
|
||||||
|
results = analyze_classification_results(csv_file_path)
|
||||||
|
|
||||||
|
if results:
|
||||||
|
print(f"\n分析完成!")
|
||||||
|
print(f"准确率: {results['accuracy']:.4f}")
|
||||||
|
print(f"召回率: {results['recall']:.4f}")
|
||||||
|
print(f"精确率: {results['precision']:.4f}")
|
||||||
|
print(f"误报率: {results['false_positive_rate']:.4f}")
|
||||||
|
print(f"F1分数: {results['f1_score']:.4f}")
|
||||||
|
else:
|
||||||
|
print("分析失败,请检查文件路径和格式。")
|
Loading…
x
Reference in New Issue
Block a user