分类: Python/Ruby
2021-05-05 17:27:05
from pycocotools.coco import COCO
import numpy as np
import skimage.io as io
import matplotlib.pyplot as plt
import os
from PIL import Image
from PIL import ImageDraw
import csv
import shutil
def create_coco_maps(ann_handle):
    """Build lookup tables between COCO category names and category ids.

    Args:
        ann_handle: a COCO annotation handle exposing ``getCatIds()`` and
            ``loadCats(ids)`` (each loaded category is a dict with ``'name'``
            and ``'id'`` keys).

    Returns:
        A tuple ``(name_to_id, id_to_name)``. If a name or id appears more
        than once, the first occurrence is kept.
    """
    name_to_id = {}
    id_to_name = {}
    for category in ann_handle.loadCats(ann_handle.getCatIds()):
        # setdefault only inserts when the key is absent -> first one wins.
        name_to_id.setdefault(category['name'], category['id'])
        id_to_name.setdefault(category['id'], category['name'])
    return name_to_id, id_to_name
def get_need_cls_ids(need_cls_names, coco_name_maps):
    """Return the category ids whose names appear in ``need_cls_names``.

    Args:
        need_cls_names: collection of wanted class names.
        coco_name_maps: mapping of category name -> category id.

    Returns:
        A list of ids, ordered as the names appear in ``coco_name_maps``
        (not as they appear in ``need_cls_names``).
    """
    return [
        cat_id
        for cat_name, cat_id in coco_name_maps.items()
        if cat_name in need_cls_names
    ]
def get_new_label_id(name, need_cls_names):
    """Return the position of ``name`` in ``need_cls_names``, or None.

    The position is used as the new, contiguous class label. If ``name``
    occurs several times, the first index is returned.
    """
    matches = (
        position
        for position, candidate in enumerate(need_cls_names)
        if name == candidate
    )
    return next(matches, None)
if __name__ == '__main__':
    # Convert COCO detection annotations for a subset of classes into
    # YOLOv5-style label files: one "<label> <x_c> <y_c> <w> <h>" line per
    # box, with coordinates normalized by the image width/height. Images that
    # contain at least one kept box are copied next to their label file.
    need_cls_names = ['person', 'bicycle', 'car', 'motorcycle', 'bus', 'truck', 'traffic light']
    dst_img_dir = '/dataset/coco_traffic_yolov5/images/val/'
    dst_label_dir = '/dataset/coco_traffic_yolov5/labels/val/'
    # A box is kept only if both normalized sides exceed this value:
    # at a 224x224 input the minimum side is ~10 px (0.04464 = 10 / 224).
    min_side = 0.04464
    dataDir = '/dataset/COCO/'
    dataType = 'val2017'
    annFile = os.path.join(dataDir, 'annotations', 'instances_{}.json'.format(dataType))

    # create coco ann handle
    ann_handle = COCO(annFile)

    # The copy/open calls below write into these directories; create them up
    # front so a fresh output tree does not fail on the first image.
    os.makedirs(dst_img_dir, exist_ok=True)
    os.makedirs(dst_label_dir, exist_ok=True)

    # create coco maps for id and name
    coco_name_maps, coco_id_maps = create_coco_maps(ann_handle)
    # COCO category ids of the wanted classes; used to narrow the annotation
    # lookup (the name check below still does the authoritative filtering).
    need_cls_ids = get_need_cls_ids(need_cls_names, coco_name_maps)
    # Precompute class name -> new label id once instead of a linear search
    # per annotation.
    new_label_maps = {name: get_new_label_id(name, need_cls_names) for name in need_cls_names}

    # get all imgids
    img_ids = ann_handle.getImgIds()
    num_imgs = len(img_ids)
    for i, img_id in enumerate(img_ids):
        print('process img: %d/%d' % (i + 1, num_imgs))
        img_info = ann_handle.loadImgs(img_id)[0]
        img_name = img_info['file_name']
        img_height = img_info['height']
        img_width = img_info['width']

        obj_infos = []
        # NOTE(review): iscrowd=None keeps crowd annotations as well; confirm
        # that crowd regions are wanted as ordinary YOLO boxes.
        ann_ids = ann_handle.getAnnIds(imgIds=img_id, catIds=need_cls_ids, iscrowd=None)
        for ann in ann_handle.loadAnns(ann_ids):
            obj_name = coco_id_maps[ann['category_id']]
            if obj_name not in new_label_maps:
                continue
            new_label = new_label_maps[obj_name]
            # COCO bbox is [x_min, y_min, width, height] in pixels; convert to
            # center/size normalized by the image dimensions.
            x1, y1, w, h = ann['bbox']
            x_c_norm = (x1 + w / 2.0) / img_width
            y_c_norm = (y1 + h / 2.0) / img_height
            w_norm = w / img_width
            h_norm = h / img_height
            if w_norm > min_side and h_norm > min_side:
                obj_infos.append('%d %.4f %.4f %.4f %.4f\n' % (new_label, x_c_norm, y_c_norm, w_norm, h_norm))

        if obj_infos:
            print(' this img has need cls')
            shutil.copy(os.path.join(dataDir, dataType, img_name), os.path.join(dst_img_dir, img_name))
            # Swap only the real extension (str.replace would hit any
            # '.jpg' substring and miss other extensions).
            label_name = os.path.splitext(img_name)[0] + '.txt'
            with open(os.path.join(dst_label_dir, label_name), 'w') as f:
                f.writelines(obj_infos)
        else:
            print(' this img has no need cls')