import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import pipeline
import os
import time
import json
from datetime import datetime


def load_checkpoint(checkpoint_file):
    """Load checkpoint information, if present."""
    if os.path.exists(checkpoint_file):
        try:
            with open(checkpoint_file, 'r', encoding='utf-8') as f:
                checkpoint = json.load(f)
            print(f"Checkpoint found: resuming after record {checkpoint['last_processed'] + 1}")
            return checkpoint
        except Exception as e:
            print(f"Failed to read checkpoint file: {e}")
            return None
    return None


def save_checkpoint(checkpoint_file, last_processed, total_records):
    """Persist checkpoint information as JSON."""
    checkpoint = {
        'last_processed': last_processed,
        'total_records': total_records,
        'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    }
    try:
        with open(checkpoint_file, 'w', encoding='utf-8') as f:
            json.dump(checkpoint, f, ensure_ascii=False, indent=2)
    except Exception as e:
        print(f"Failed to save checkpoint file: {e}")


def save_batch_results(results, output_file, mode='w'):
    """Write or append a batch of results to the output CSV."""
    try:
        results_df = pd.DataFrame(results)
        if mode == 'w':
            results_df.to_csv(output_file, index=False, encoding='utf-8')
        else:  # append mode: no header, the file already has one
            results_df.to_csv(output_file, mode='a', header=False, index=False, encoding='utf-8')
        return True
    except Exception as e:
        print(f"Failed to save results: {e}")
        return False


def main():
    # File paths; ProtectAI, deepset, or metaguard86M models can be used here
    local_path = "./models/metaguard86M"
    input_file = "deepset_prompt_injection_data.csv"
    output_file = "./result/deepset_86M_results.csv"
    checkpoint_file = "86M_checkpoint.json"
    batch_size = 10  # number of records per saved batch

    # Verify that the model path exists
    if not os.path.exists(local_path):
        print(f"Error: model path {local_path} does not exist")
        return

    # Load checkpoint information, if any
    checkpoint = load_checkpoint(checkpoint_file)
    start_idx = 0
    if checkpoint:
        start_idx = checkpoint['last_processed'] + 1

    # Load the model and tokenizer
    print("Loading model and tokenizer...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(local_path)
        model = AutoModelForSequenceClassification.from_pretrained(local_path)
        # Build the classifier; long inputs are truncated to 512 tokens
        classifier = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            truncation=True,
            max_length=512,
        )
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Failed to load model: {e}")
        return

    # Read the input CSV file
    if not os.path.exists(input_file):
        print(f"Error: input file {input_file} does not exist")
        return
    try:
        print(f"Reading {input_file}...")
        df = pd.read_csv(input_file)
        print(f"Read {len(df)} rows")
        # Check that the required column exists
        if 'text' not in df.columns:
            print("Error: CSV file is missing a 'text' column")
            return
        # If there is no label column, create an empty one
        if 'label' not in df.columns:
            df['label'] = ''
    except Exception as e:
        print(f"Failed to read CSV file: {e}")
        return

    # When resuming, check that the record count matches the checkpoint
    if checkpoint and checkpoint['total_records'] != len(df):
        print(f"Warning: current record count ({len(df)}) does not match checkpoint ({checkpoint['total_records']})")
        response = input("Restart from the beginning? (y/n): ")
        if response.lower() == 'y':
            start_idx = 0
            if os.path.exists(output_file):
                os.remove(output_file)

    # When starting from scratch, clear any previous output file
    if start_idx == 0 and os.path.exists(output_file):
        os.remove(output_file)
        print("Removed previous output file")

    batch_results = []
    total_processed = start_idx

    # Record the classification start time
    start_time = time.time()
    start_datetime = datetime.now()
    print(f"Classification started at: {start_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
    if start_idx > 0:
        print(f"Resuming from record {start_idx + 1}...")
    else:
        print("Starting classification...")

    try:
        for idx in range(start_idx, len(df)):
            row = df.iloc[idx]
            text = str(row['text'])  # make sure the input is a string
            source_label = row.get('label', '')  # original label, empty if absent
            try:
                # Classify the text
                prediction = classifier(text)
                predicted_label = prediction[0]['label']
                predicted_score = prediction[0]['score']
                # Add to the batch results
                batch_results.append({
                    'text': text,
                    'source_label': source_label,
                    'classified_label': predicted_label,
                    'classified_score': predicted_score
                })
                total_processed = idx
                # Save every batch_size records, or at the last record
                if len(batch_results) >= batch_size or idx == len(df) - 1:
                    # Write the header only when the file does not exist yet;
                    # this also covers resuming and the first flush after a restart
                    mode = 'w' if not os.path.exists(output_file) else 'a'
                    if save_batch_results(batch_results, output_file, mode):
                        print(f"Processed and saved {idx + 1}/{len(df)} records")
                        save_checkpoint(checkpoint_file, idx, len(df))
                        batch_results = []
                    else:
                        print(f"Failed to save batch at record {idx + 1}; stopping")
                        break
                # Print progress every 10 records
                elif (idx + 1) % 10 == 0:
                    print(f"Processed {idx + 1}/{len(df)} records")
            except Exception as e:
                print(f"Error while processing row {idx + 1}: {e}")
                # Record the error so the row is not silently dropped
                batch_results.append({
                    'text': text,
                    'source_label': source_label,
                    'classified_label': 'ERROR',
                    'classified_score': 0.0
                })
                total_processed = idx

        # Flush anything still buffered (e.g. when the last record raised an
        # error, in which case the in-loop save above was skipped)
        if batch_results:
            mode = 'w' if not os.path.exists(output_file) else 'a'
            if save_batch_results(batch_results, output_file, mode):
                save_checkpoint(checkpoint_file, total_processed, len(df))
                batch_results = []
    except KeyboardInterrupt:
        print(f"\nInterrupted by user; processed up to record {total_processed + 1}")
        # Save any remaining batch results
        if batch_results:
            mode = 'w' if not os.path.exists(output_file) else 'a'
            save_batch_results(batch_results, output_file, mode)
        save_checkpoint(checkpoint_file, total_processed, len(df))
        return
    except Exception as e:
        print(f"\nError during processing: {e}")
        # Save any remaining batch results
        if batch_results:
            mode = 'w' if not os.path.exists(output_file) else 'a'
            save_batch_results(batch_results, output_file, mode)
        save_checkpoint(checkpoint_file, total_processed, len(df))
        return

    # Record the end time and compute performance statistics
    end_time = time.time()
    end_datetime = datetime.now()
    total_time = end_time - start_time
    processed_count = total_processed - start_idx + 1
    print(f"\nClassification finished at: {end_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Elapsed time for this run: {total_time:.2f} s")
    print(f"Records processed in this run: {processed_count}")
    if processed_count > 0:
        print(f"Average time per record: {total_time/processed_count:.3f} s")
        print(f"Throughput: {processed_count/total_time:.2f} records/s")

    # Show final summary statistics
    try:
        if os.path.exists(output_file):
            results_df = pd.read_csv(output_file)
            print(f"\nClassified {len(results_df)} records in total")
            print("\nClassification result counts:")
            print(results_df['classified_label'].value_counts())
            # Show the first few rows as a sample
            print("\nPreview of the first 5 rows:")
            print(results_df.head())
            # If everything is done, remove the checkpoint file
            if len(results_df) == len(df):
                if os.path.exists(checkpoint_file):
                    os.remove(checkpoint_file)
                    print("\nAll records processed; checkpoint file removed")
    except Exception as e:
        print(f"Failed to read final results: {e}")


if __name__ == "__main__":
    main()