Chinaunix首页 | 论坛 | 博客
  • 博客访问: 1014373
  • 博文数量: 157
  • 博客积分: 0
  • 博客等级: 民兵
  • 技术积分: 1388
  • 用 户 组: 普通用户
  • 注册时间: 2015-04-09 15:37
文章分类

全部博文(157)

文章存档

2023年(9)

2022年(2)

2021年(18)

2020年(7)

2017年(13)

2016年(53)

2015年(55)

我的朋友

分类: Python/Ruby

2022-12-28 17:57:03

1. 读取+遍历处理+写入

点击(此处)折叠或打开

  1.     all_rate_data_list = list()
  2.     better_rate_list = list()
  3.     rate_info_headers = ["threshold", "real_recall", "fake_recall"]
  4.     while threshold <= 1:
  5.         real_judge_real = 0
  6.         real_total = 0
  7.         fake_judge_fake = 0
  8.         fake_total = 0
  9.         per_rate_data_dict = dict()
  10.         for root, dirs, files in os.walk(csv_dir): # root: 当前目录路径 ; dirs :当前路径下所有子目录 ; files : 文件夹下所有文件名
  11.             for csv_file in files:
  12.                 file_str = os.path.join(root, csv_file)
  13.                 with open(file_str, "rU", encoding='utf-8') as myFile:
  14.                     reader = csv.reader(myFile)
  15.                     # sum_num = sum(1 for row in myFile) - 1
  16.                     myFile.seek(0)
  17.                     next(reader) # 从第二行开始获取数据
  18.                     for row in reader:
  19.                         row = list(map(str.strip, row))
  20.                         branch = row[1]
  21.                         file_path = row[0]
  22.                         xxx业务逻辑xxx

  23.         real_recall= round(real_judge_real / real_total, 8) if real_total else 0
  24.         fake_recall= round(fake_judge_fake / fake_total,8) if fake_total else 0
  25.    
  26.         per_rate_data_dict["threshold"] = round(threshold, 8)
  27.         per_rate_data_dict["real_recall"] = real_recall
  28.         per_rate_data_dict["fake_recall"] = fake_fomm_recall
  29.        
  30.         if per_rate_data_dict["real_recall"] >=target_real_recall:
  31.             better_rate_list.append(per_rate_data_dict)

  32.         all_rate_data_list.append(per_rate_data_dict)
  33.         threshold += step

  34.     if all_rate_data_list:
  35.         with open(os.path.join(csv_path, '{0}.csv'.format(model_name + "_" + model_type + "_logic1不同阈值下的准确率")),
  36.                   'w', newline='', encoding='utf-8') as f:
  37.             f_csv = csv.DictWriter(f, rate_info_headers)
  38.             f_csv.writeheader()
  39.             f_csv.writerows(all_rate_data_list)
  40.     if better_rate_list:
  41.         with open(os.path.join(csv_path, '{0}.csv'.format(model_name + "_" + model_type + "_logic1{BANNED}最佳优准确率下的阈值串")),
  42.                   'w', newline='', encoding='utf-8') as f:
  43.             f_csv = csv.DictWriter(f, rate_info_headers)
  44.             f_csv.writeheader()
  45.             f_csv.writerows(better_rate_list)
    


阅读(165) | 评论(0) | 转发(0) |
给主人留下些什么吧!~~