Image Recognition in Practice (1) -- Dataset Preprocessing
1. Importing the Modules
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms,models,datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image
2. Reading the Dataset
data_dir = './flower_data'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'
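ImageFolder, used in step 4, infers class labels from subdirectory names, so each split directory is expected to contain one folder per flower class. Here is a minimal sketch (not in the original post) to confirm that layout, assuming the data has already been unpacked into ./flower_data:

# Sketch: list the class subdirectories that ImageFolder will treat as labels.
for split_dir in (train_dir, valid_dir):
    classes = sorted(d for d in os.listdir(split_dir)
                     if os.path.isdir(os.path.join(split_dir, d)))
    print(split_dir, '->', len(classes), 'classes')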
3. Preprocessing the Dataset
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomRotation(45),           # rotate randomly within ±45 degrees
        transforms.CenterCrop(224),              # crop the central 224x224 region
        transforms.RandomHorizontalFlip(p=0.5),  # flip horizontally with probability 0.5
        transforms.RandomVerticalFlip(p=0.5),    # flip vertically with probability 0.5
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.RandomGrayscale(p=0.025),     # convert to grayscale with probability 0.025
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean/std
    ]),
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
transforms.Compose() chains the image transformation operations together, much like a list: each image is passed through the transforms in the order they are listed.
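To see what a composed pipeline produces, you can push a single image through the 'valid' transform. This is a small sketch (not part of the original post) that assumes the valid directory holds only class subfolders, each containing image files:

# Sketch: run one image through the validation pipeline and inspect the result.
some_class = sorted(os.listdir(valid_dir))[0]
some_image = sorted(os.listdir(os.path.join(valid_dir, some_class)))[0]
img = Image.open(os.path.join(valid_dir, some_class, some_image)).convert('RGB')
tensor = data_transforms['valid'](img)
print(tensor.shape)                              # torch.Size([3, 224, 224])
print(tensor.min().item(), tensor.max().item())  # roughly [-2.1, 2.6] after normalization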
4. Organizing and Loading the Dataset
batch_size = 8

# Build an ImageFolder dataset and a shuffled DataLoader for each split.
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True)
               for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes
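Pulling a single batch from one of the DataLoaders shows exactly what the training loop will receive later (a quick sketch using the objects defined above):

# Sketch: fetch one batch and check its shapes and labels.
images, labels = next(iter(dataloaders['train']))
print(images.shape)                                  # torch.Size([8, 3, 224, 224]) with batch_size = 8
print(labels.shape)                                  # torch.Size([8]): one class index per image
print([class_names[i.item()] for i in labels[:4]])   # the corresponding class names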
For reference, the constructor signature of torchvision.datasets.ImageFolder is:

dataset = torchvision.datasets.ImageFolder(
    root,                      # root directory, containing one subdirectory per class
    transform=None,            # transform applied to each loaded image
    target_transform=None,     # transform applied to each label
    loader=<function default_loader>,  # function used to read an image from disk (PIL by default)
    is_valid_file=None)        # optional check deciding whether a file is a valid image
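The dataset objects that ImageFolder returns expose a few attributes worth knowing about; the exact values printed below depend on your copy of the data (a short sketch):

# Sketch: inspect the attributes ImageFolder derives from the folder structure.
train_set = image_datasets['train']
print(train_set.classes[:5])                     # class names, i.e. the subdirectory names
print(list(train_set.class_to_idx.items())[:5])  # class name -> integer label
print(train_set.samples[0])                      # (image path, class index) of the first sample
print(len(train_set))                            # number of images, same as dataset_sizes['train']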
Printing image_datasets gives the following:
{'train': Dataset ImageFolder
     Number of datapoints: 3614
     Root location: F:/flower_data/train
     StandardTransform
     Transform: Compose(
         RandomRotation(degrees=(-45, 45), resample=False, expand=False)
         CenterCrop(size=(224, 224))
         RandomHorizontalFlip(p=0.5)
         RandomVerticalFlip(p=0.5)
         ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
         RandomGrayscale(p=0.025)
         ToTensor()
         Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
     ),
 'valid': Dataset ImageFolder
     Number of datapoints: 56
     Root location: F:/flower_data/valid
     StandardTransform
     Transform: Compose(
         Resize(size=256, interpolation=PIL.Image.BILINEAR)
         CenterCrop(size=(224, 224))
         ToTensor()
         Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
     )}
Printing dataset_sizes helps clarify the logic of the {...} dictionary comprehensions:
{'train': 3614, 'valid': 56}
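The same result can be produced with an explicit loop over the split names, which makes the pattern behind the comprehensions easier to follow (a sketch that rebuilds dataset_sizes only):

# Sketch: dataset_sizes written as a plain loop instead of a dict comprehension.
dataset_sizes_loop = {}
for x in ['train', 'valid']:
    dataset_sizes_loop[x] = len(image_datasets[x])
print(dataset_sizes_loop)   # {'train': 3614, 'valid': 56}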
5. Visualizing Images from the Dataset
def im_convert(tensor):
    """Convert a normalized image tensor back into a displayable numpy array."""
    image = tensor.to('cpu').clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1, 2, 0)                 # C x H x W  ->  H x W x C
    # Undo the Normalize transform: x * std + mean
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)
    return image
fig = plt.figure(figsize=(20, 12))
columns = 4
rows = 2

dataiter = iter(dataloaders['valid'])
inputs, classes = next(dataiter)   # one batch of (images, labels)

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    ax.set_title(class_names[classes[idx].item()])
    plt.imshow(im_convert(inputs[idx]))
plt.show()