Category: Python/Ruby

2021-03-05 17:07:11

The usual imports; nothing much to explain. If you are missing a module, install it yourself with pip install.

import numpy as np

import time

from matplotlib import pyplot as plt

import json

import copy

import os

import torch

from torch import nn

from torch import optim

from torchvision import transforms, models, datasets

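Data Loading and Preprocessing

The data preprocessing part:

Data augmentation: built into the torchvision transforms module; used to expand the set of training samples

Data preprocessing: torchvision transforms also implements this for us

Batching: the DataLoader module reads the data batch by batch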
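To make the preprocessing and batching steps concrete, here is a minimal NumPy sketch of the arithmetic involved. It is a re-implementation for illustration, not torchvision's own code. The random image and the helper num_batches are assumptions introduced here. The mean, standard deviation, batch size, and dataset sizes are the ones used in this post.

```python
import math
import numpy as np

# ImageNet channel statistics, as passed to transforms.Normalize in this post.
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

# Hypothetical stand-in for one decoded RGB image: height x width x channels, uint8.
rng = np.random.default_rng(0)
image_hwc = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)

# What ToTensor does to a uint8 image: scale to [0, 1] and put channels first.
image_chw = image_hwc.astype(np.float32).transpose(2, 0, 1) / 255.0

# What Normalize does: (x - mean) / std, applied to each channel separately.
normalized = (image_chw - mean[:, None, None]) / std[:, None, None]


def num_batches(num_samples, batch_size, drop_last=False):
    """Number of batches a loader yields per epoch (hypothetical helper)."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


print(normalized.shape)      # (3, 224, 224)
print(num_batches(6552, 8))  # 819 training batches
print(num_batches(818, 8))   # 103 validation batches; the last one holds 2 images
```

With the default drop_last=False, the final validation batch is smaller than batch_size, so code that assumes a fixed batch dimension of 8 should allow for it.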

# ----------------1. Data loading and preprocessing------------------

# Paths

data_dir = './flower_data/'

train_dir = os.path.join(data_dir, 'train')

valid_dir = os.path.join(data_dir, 'valid')

# Build the data sources

data_transforms = {

    'train': transforms.Compose([transforms.RandomRotation(45),  # random rotation by an angle chosen between -45 and 45 degrees

        transforms.CenterCrop(224),  # crop a 224x224 region from the center

        transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip with probability p

        transforms.RandomVerticalFlip(p=0.5),  # random vertical flip with probability p

        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),  # arguments: brightness, contrast, saturation, hue

        transforms.RandomGrayscale(p=0.025),  # convert to grayscale with probability p; the result keeps 3 channels with R=G=B

        transforms.ToTensor(),

        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # mean, standard deviation

    ]),

    'valid': transforms.Compose([transforms.Resize(256),

        transforms.CenterCrop(224),

        transforms.ToTensor(),

        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    ]),

}

batch_size = 8

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}

class_names = image_datasets['train'].classes

# Debug output

print(image_datasets)

print(dataloaders)

print(dataset_sizes)

print(class_names)

# Load the actual flower names that correspond to the labels

with open('cat_to_name.json', 'r') as f:

    cat_to_name = json.load(f)

print(cat_to_name)

Output:

{'train': Dataset ImageFolder

    Number of datapoints: 6552

    Root location: ./flower_data/train

    StandardTransform

Transform: Compose(

               RandomRotation(degrees=(-45, 45), resample=False, expand=False)

               CenterCrop(size=(224, 224))

               RandomHorizontalFlip(p=0.5)

               RandomVerticalFlip(p=0.5)

               ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])

               RandomGrayscale(p=0.025)

               ToTensor()

               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

           ), 'valid': Dataset ImageFolder

    Number of datapoints: 818

    Root location: ./flower_data/valid

    StandardTransform

Transform: Compose(

               Resize(size=256, interpolation=PIL.Image.BILINEAR)

               CenterCrop(size=(224, 224))

               ToTensor()

               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

           )}

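{'train': <torch.utils.data.dataloader.DataLoader object at 0x...>, 'valid': <torch.utils.data.dataloader.DataLoader object at 0x...>}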

{'train': 6552, 'valid': 818}

['1', '10', '100', '101', '102', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '3', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '4', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '5', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '6', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '7', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '8', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '9', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99']

{'21': 'fire lily', '3': 'canterbury bells', '45': 'bolero deep blue', '1': 'pink primrose', '34': 'mexican aster',
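One detail in the output above is easy to miss: the class list is sorted as strings ('1', '10', '100', ...), so a class index is not the same as the folder number. The sketch below reproduces that ordering in plain Python and maps an index back to a flower name. The helper index_to_flower is hypothetical, and cat_to_name here holds only the entries visible in the truncated output.

```python
# Folder names in the dataset: '1' through '102', one folder per class.
folder_names = [str(i) for i in range(1, 103)]

# Sorting the names as strings reproduces the class order printed above.
class_names = sorted(folder_names)
class_to_idx = {name: idx for idx, name in enumerate(class_names)}

# Only the cat_to_name.json entries visible in the truncated output above.
cat_to_name = {
    '21': 'fire lily',
    '3': 'canterbury bells',
    '45': 'bolero deep blue',
    '1': 'pink primrose',
    '34': 'mexican aster',
}


def index_to_flower(class_index):
    """Map a class index to a flower name (hypothetical helper)."""
    folder = class_names[class_index]
    return cat_to_name.get(folder, 'unknown (not in this subset)')


print(class_names[:6])      # ['1', '10', '100', '101', '102', '11']
print(class_to_idx['21'])   # 16
print(index_to_flower(16))  # fire lily
```

When reading predictions, go from the index to the folder name through class_names first, and only then look the folder name up in cat_to_name.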
