AIPentest/promptinjection/guardtest.py
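
"""Batch-test a local prompt-injection guard model on a labeled dataset.

Reads prompts from a CSV file, classifies each one with a Hugging Face
text-classification pipeline, and writes the results to CSV in batches,
saving JSON checkpoints so an interrupted run can resume where it left off.
"""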

import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import pipeline
import os
import time
import json
from datetime import datetime


def load_checkpoint(checkpoint_file):
    """Load checkpoint information."""
    if os.path.exists(checkpoint_file):
        try:
            with open(checkpoint_file, 'r', encoding='utf-8') as f:
                checkpoint = json.load(f)
            print(f"Found checkpoint file; resuming from record {checkpoint['last_processed'] + 1}")
            return checkpoint
        except Exception as e:
            print(f"Failed to read checkpoint file: {e}")
            return None
    return None


def save_checkpoint(checkpoint_file, last_processed, total_records):
    """Save checkpoint information."""
    checkpoint = {
        'last_processed': last_processed,
        'total_records': total_records,
        'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    }
    try:
        with open(checkpoint_file, 'w', encoding='utf-8') as f:
            json.dump(checkpoint, f, ensure_ascii=False, indent=2)
    except Exception as e:
        print(f"Failed to save checkpoint file: {e}")


def save_batch_results(results, output_file, mode='w'):
    """Save a batch of results to CSV."""
    try:
        results_df = pd.DataFrame(results)
        if mode == 'w':
            results_df.to_csv(output_file, index=False, encoding='utf-8')
        else:  # append mode
            results_df.to_csv(output_file, mode='a', header=False, index=False, encoding='utf-8')
        return True
    except Exception as e:
        print(f"Failed to save results: {e}")
        return False
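
# Each element of `results` is a dict, so the output CSV gets the columns
# text, source_label, classified_label, classified_score (see the loop in main()).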


def main():
    # File paths; ProtectAI, deepset, or metaguard86M models can be used here
    local_path = "./models/metaguard86M"
    input_file = "deepset_prompt_injection_data.csv"
    output_file = "./result/deepset_86M_results.csv"
    checkpoint_file = "86M_checkpoint.json"
    batch_size = 10  # number of records to process per batch

    # Make sure the output directory exists (pandas.to_csv does not create it)
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    # Check that the model path exists
    if not os.path.exists(local_path):
        print(f"Error: model path {local_path} does not exist")
        return

    # Load checkpoint information
    checkpoint = load_checkpoint(checkpoint_file)
    start_idx = 0
    if checkpoint:
        start_idx = checkpoint['last_processed'] + 1

    # Load the model and tokenizer
    print("Loading model and tokenizer...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(local_path)
        model = AutoModelForSequenceClassification.from_pretrained(local_path)
        # Create the classifier
        classifier = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            truncation=True,
            max_length=512,
        )
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Failed to load model: {e}")
        return
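
    # Note: for a single input string, the pipeline returns a list holding one
    # dict, e.g. [{'label': ..., 'score': ...}]; the label strings themselves
    # depend on which guard model is loaded.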

    # Read the CSV file
    if not os.path.exists(input_file):
        print(f"Error: input file {input_file} does not exist")
        return
    try:
        print(f"Reading {input_file}...")
        df = pd.read_csv(input_file)
        print(f"Successfully read {len(df)} rows")
        # Check that the required column exists
        if 'text' not in df.columns:
            print("Error: CSV file is missing the 'text' column")
            return
        # If there is no label column, create an empty one
        if 'label' not in df.columns:
            df['label'] = ''
    except Exception as e:
        print(f"Failed to read CSV file: {e}")
        return
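
    # The input CSV is expected to have a `text` column and, optionally, a
    # `label` column with ground-truth labels; the rows below are illustrative
    # only:
    #
    #   text,label
    #   "Ignore all previous instructions and ...",1
    #   "What is the capital of France?",0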

    # When resuming, check that the record count matches the checkpoint
    if checkpoint and checkpoint['total_records'] != len(df):
        print(f"Warning: current record count ({len(df)}) differs from checkpoint record count ({checkpoint['total_records']})")
        response = input("Restart processing from the beginning? (y/n): ")
        if response.lower() == 'y':
            start_idx = 0
            if os.path.exists(output_file):
                os.remove(output_file)

    # If starting from scratch, clear any previous output file
    if start_idx == 0 and os.path.exists(output_file):
        os.remove(output_file)
        print("Cleared previous output file")

    # Accumulator for the current batch of results
    batch_results = []
    total_processed = start_idx

    # Record the classification start time
    start_time = time.time()
    start_datetime = datetime.now()
    print(f"Classification start time: {start_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
    if start_idx > 0:
        print(f"Resuming from record {start_idx + 1}...")
    else:
        print("Starting classification...")

    try:
        for idx in range(start_idx, len(df)):
            row = df.iloc[idx]
            text = str(row['text'])  # ensure string type
            source_label = row.get('label', '')  # original label, empty if absent
            try:
                # Classify the text
                prediction = classifier(text)
                # Extract the prediction
                predicted_label = prediction[0]['label']
                predicted_score = prediction[0]['score']
                # Add to the batch results
                batch_results.append({
                    'text': text,
                    'source_label': source_label,
                    'classified_label': predicted_label,
                    'classified_score': predicted_score
                })
                total_processed = idx
                # Save every batch_size records, and at the last record
                if len(batch_results) >= batch_size or idx == len(df) - 1:
                    # Write the header only for the very first batch of a fresh run
                    mode = 'w' if idx < batch_size and start_idx == 0 else 'a'
                    if save_batch_results(batch_results, output_file, mode):
                        print(f"Processed and saved {idx + 1}/{len(df)} records")
                        # Save the checkpoint
                        save_checkpoint(checkpoint_file, idx, len(df))
                        # Clear the batch
                        batch_results = []
                    else:
                        print(f"Failed to save batch at record {idx + 1}; stopping")
                        break
                # Print progress every 10 records
                elif (idx + 1) % 10 == 0:
                    print(f"Processed {idx + 1}/{len(df)} records")
            except Exception as e:
                print(f"Error while processing row {idx + 1}: {e}")
                # Record the failure so the row is not silently dropped
                batch_results.append({
                    'text': text,
                    'source_label': source_label,
                    'classified_label': 'ERROR',
                    'classified_score': 0.0
                })
                total_processed = idx
    except KeyboardInterrupt:
        print(f"\nInterrupted by user; processed up to record {total_processed + 1}")
        # Save any remaining batch results
        if batch_results:
            mode = 'a' if total_processed > 0 else 'w'
            save_batch_results(batch_results, output_file, mode)
        # Save the checkpoint
        save_checkpoint(checkpoint_file, total_processed, len(df))
        return
    except Exception as e:
        print(f"\nError during processing: {e}")
        # Save any remaining batch results
        if batch_results:
            mode = 'a' if total_processed > 0 else 'w'
            save_batch_results(batch_results, output_file, mode)
        # Save the checkpoint
        save_checkpoint(checkpoint_file, total_processed, len(df))
        return

    # Record the end time and compute performance statistics
    end_time = time.time()
    end_datetime = datetime.now()
    total_time = end_time - start_time
    processed_count = total_processed - start_idx + 1
    print(f"\nClassification end time: {end_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Elapsed time this run: {total_time:.2f} s")
    print(f"Records processed this run: {processed_count}")
    if processed_count > 0:
        print(f"Average time per record: {total_time/processed_count:.3f} s")
        print(f"Throughput: {processed_count/total_time:.2f} records/s")

    # Show final statistics
    try:
        if os.path.exists(output_file):
            results_df = pd.read_csv(output_file)
            print(f"\nClassified {len(results_df)} records in total")
            # Print label counts
            print("\nClassification result statistics:")
            print(results_df['classified_label'].value_counts())
            # Show the first few rows as a preview
            print("\nPreview of the first 5 rows:")
            print(results_df.head())
            # If everything is done, delete the checkpoint file
            if len(results_df) == len(df):
                if os.path.exists(checkpoint_file):
                    os.remove(checkpoint_file)
                    print("\nAll records processed; checkpoint file deleted")
    except Exception as e:
        print(f"Failed to read final results: {e}")


if __name__ == "__main__":
    main()
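
# Usage (assuming the model directory and input CSV described above exist):
#
#   python guardtest.py
#
# Re-running the same command after an interruption resumes from the JSON
# checkpoint; delete 86M_checkpoint.json to force a fresh run.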