Chinaunix首页 | 论坛 | 博客
  • 博客访问: 3665920
  • 博文数量: 365
  • 博客积分: 0
  • 博客等级: 民兵
  • 技术积分: 2522
  • 用 户 组: 普通用户
  • 注册时间: 2019-10-28 13:40
文章分类

全部博文(365)

文章存档

2023年(8)

2022年(130)

2021年(155)

2020年(50)

2019年(22)

我的朋友

分类: Python/Ruby

2021-05-05 17:27:05

from pycocotools.coco import COCO

import numpy as np

import skimage.io as io

import matplotlib.pyplot as plt

import os

from PIL import Image

from PIL import ImageDraw

import csv

import shutil

def create_coco_maps(ann_handle):
    """Build lookup tables between COCO category names and category ids.

    Args:
        ann_handle: a COCO annotation handle exposing ``getCatIds()`` and
            ``loadCats(ids)`` (here a ``pycocotools.coco.COCO`` instance).

    Returns:
        A tuple ``(name_to_id, id_to_name)``. If a name (or id) occurs more
        than once, the first category seen wins.
    """
    name_to_id = {}
    id_to_name = {}
    for category in ann_handle.loadCats(ann_handle.getCatIds()):
        name, cid = category['name'], category['id']
        # setdefault only inserts when the key is absent -> first entry wins.
        name_to_id.setdefault(name, cid)
        id_to_name.setdefault(cid, name)
    return name_to_id, id_to_name

def get_need_cls_ids(need_cls_names, coco_name_maps):
    """Return the COCO category ids whose names appear in *need_cls_names*.

    Args:
        need_cls_names: collection of class names to keep.
        coco_name_maps: mapping of category name -> category id.

    Returns:
        A list of ids, ordered by the iteration order of *coco_name_maps*
        (not by the order of *need_cls_names*).
    """
    return [
        cat_id
        for cat_name, cat_id in coco_name_maps.items()
        if cat_name in need_cls_names
    ]

def get_new_label_id(name, need_cls_names):
    """Return the position of *name* in *need_cls_names*, used as the new label id.

    Args:
        name: class name to look up.
        need_cls_names: ordered iterable of class names.

    Returns:
        The index of the first element equal to *name*, or ``None`` if there
        is no such element.
    """
    matches = (idx for idx, candidate in enumerate(need_cls_names)
               if name == candidate)
    return next(matches, None)

def _to_yolo_line(label, bbox, img_width, img_height, min_side):
    """Convert one COCO ``[x, y, w, h]`` pixel box into a YOLO label line.

    Args:
        label: integer class id written as the first field.
        bbox: COCO box as ``[x_top_left, y_top_left, width, height]`` in pixels.
        img_width: image width in pixels.
        img_height: image height in pixels.
        min_side: minimum normalized width/height; boxes whose normalized
            width or height is not strictly larger are dropped.

    Returns:
        ``"<label> <x_center> <y_center> <w> <h>\\n"`` with coordinates
        normalized to [0, 1] by the image size, or ``None`` if the box is
        too small.
    """
    x1, y1, w, h = bbox
    w_norm = w / img_width
    h_norm = h / img_height
    if not (w_norm > min_side and h_norm > min_side):
        return None
    # YOLO expects the box center, not the top-left corner that COCO stores.
    x_c_norm = (x1 + w / 2.0) / img_width
    y_c_norm = (y1 + h / 2.0) / img_height
    return '%d %.4f %.4f %.4f %.4f\n' % (label, x_c_norm, y_c_norm, w_norm, h_norm)


def main():
    """Export the selected COCO classes as a YOLOv5 images/labels dataset.

    Every image that has at least one kept box is copied to ``dst_img_dir``.
    A matching ``.txt`` label file is written to ``dst_label_dir``. Images
    without kept boxes are skipped.
    """
    # Classes to keep; the position in this list becomes the YOLO class id.
    need_cls_names = ['person', 'bicycle', 'car', 'motorcycle', 'bus', 'truck', 'traffic light']

    dst_img_dir = '/dataset/coco_traffic_yolov5/images/val/'
    dst_label_dir = '/dataset/coco_traffic_yolov5/labels/val/'

    # On a 224x224 input the minimum side is 10 px: 0.04464 = 10 / 224.
    min_side = 0.04464

    dataDir = '/dataset/COCO/'
    dataType = 'val2017'
    annFile = os.path.join(dataDir, 'annotations', 'instances_{}.json'.format(dataType))

    # The original script assumed these already existed; create them so
    # shutil.copy / open do not fail on a fresh output tree.
    os.makedirs(dst_img_dir, exist_ok=True)
    os.makedirs(dst_label_dir, exist_ok=True)

    # Create the COCO annotation handle and the category id -> name map.
    ann_handle = COCO(annFile)
    _, coco_id_maps = create_coco_maps(ann_handle)

    # Fixed: the original line referenced a garbled, undefined identifier.
    img_ids = ann_handle.getImgIds()
    num_imgs = len(img_ids)

    for i, img_id in enumerate(img_ids):
        print('process img: %d/%d' % (i, num_imgs))

        img_info = ann_handle.loadImgs(img_id)[0]
        img_name = img_info['file_name']

        # iscrowd=None keeps both crowd and non-crowd annotations, as before.
        ann_ids = ann_handle.getAnnIds(imgIds=img_id, iscrowd=None)

        obj_infos = []
        for ann in ann_handle.loadAnns(ann_ids):
            obj_name = coco_id_maps[ann['category_id']]
            if obj_name not in need_cls_names:
                continue
            line = _to_yolo_line(
                get_new_label_id(obj_name, need_cls_names),
                ann['bbox'],
                img_info['width'],
                img_info['height'],
                min_side,
            )
            if line is not None:
                obj_infos.append(line)

        if obj_infos:
            print('  this img has need cls')
            shutil.copy(os.path.join(dataDir, dataType, img_name),
                        os.path.join(dst_img_dir, img_name))
            # splitext only swaps the real extension, unlike str.replace('.jpg', ...).
            label_name = os.path.splitext(img_name)[0] + '.txt'
            with open(os.path.join(dst_label_dir, label_name), 'w') as f:
                f.writelines(obj_infos)
        else:
            print('  this img has no need cls')


if __name__ == '__main__':
    main()

阅读(1069) | 评论(0) | 转发(0) |
给主人留下些什么吧!~~