分类: Python/Ruby
2021-03-05 17:07:11
常规操作, 没什么好解释的. 缺模块的同学自行pip -install.
import numpy as np
import time
from matplotlib import pyplot as plt
import json
import copy
import os
import torch
from torch import nn
from torch import optim
from torchvision import transforms, models, datasets
数据读取与预处理
数据预处理部分:
数据增强: torchvision 中 transforms 模块自带功能, 用于扩充数据样本
数据预处理: torchvision 中 transforms 也帮我们实现好了
数据分批: DataLoader 模块直接读取 batch 数据
# ----------------1. 数据读取与预处理------------------
# 路径
data_dir = './flower_data/'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'
# 制作数据源
data_transforms = {
'train': transforms.Compose([transforms.RandomRotation(45), #随机旋转,-45到45度之间随机选
transforms.CenterCrop(224), #从中心开始裁剪
transforms.RandomHorizontalFlip(p=0.5), #随机水平翻转 选择一个概率概率
transforms.RandomVerticalFlip(p=0.5), #随机垂直翻转
transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1), #参数1为亮度, 参数2为对比度,参数3为饱和度,参数4为色相
transforms.RandomGrayscale(p=0.025), #概率转换成灰度率, 3通道就是R=G=B
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) #均值, 标准差
]),
'valid': transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
batch_size = 8
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes
# 调试输出
print(image_datasets)
print(dataloaders)
print(dataset_sizes)
print(class_names)
# 读取标签对应的实际名字
with open('cat_to_name.json', 'r') as f:
cat_to_name = json.load(f)
print(cat_to_name)
输出结果:
{'train': Dataset ImageFolder
Number of datapoints: 6552
Root location: ./flower_data/train
StandardTransform
Transform: Compose(
RandomRotation(degrees=(-45, 45), resample=False, expand=False)
CenterCrop(size=(224, 224))
RandomHorizontalFlip(p=0.5)
RandomVerticalFlip(p=0.5)
ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
RandomGrayscale(p=0.025)
ToTensor()
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
), 'valid': Dataset ImageFolder
Number of datapoints: 818
Root location: ./flower_data/valid
StandardTransform
Transform: Compose(
Resize(size=256, interpolation=PIL.Image.BILINEAR)
CenterCrop(size=(224, 224))
ToTensor()
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)}
{'train':
{'train': 6552, 'valid': 818}
['1', '10', '100', '101', '102', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '3', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '4', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '5', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '6', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '7', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '8', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '9', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99']
{'21': 'fire lily', '3': 'canterbury bells', '45': 'bolero deep blue', '1': 'pink primrose', '34': 'mexican aster',