import torch
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torchvision
# CIFAR-10 training split (10-class image classification), downloaded on
# first run into ./lab_dir/data. ToTensor() converts each image to a float
# tensor laid out as (channels, height, width).
train_dataset = datasets.CIFAR10(
    './lab_dir/data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Shuffled mini-batches of 64 images.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Display names, indexed by integer label (see classes[labels[j]] below).
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
def imshow(img, *, mean=0.5, std=0.5):
    """Undo normalization on a (channels, height, width) image tensor and show it.

    Args:
        img: image tensor laid out as (channels, height, width), e.g. the
            output of ``torchvision.utils.make_grid``.
        mean: value subtracted during normalization. Defaults to 0.5.
        std: value divided by during normalization. Defaults to 0.5.

    With the defaults this computes ``img * 0.5 + 0.5``. That inverts
    ``transforms.Normalize`` with mean=0.5 / std=0.5, mapping [-1, 1] back
    to [0, 1].

    NOTE(review): ``train_dataset`` in this file uses only
    ``transforms.ToTensor()`` (no Normalize), so its values are already in
    [0, 1]. The defaults then squeeze them into [0.5, 1] and the image looks
    washed out. Pass ``mean=0.0, std=1.0`` for such inputs.
    """
    img = img * std + mean
    # detach()/cpu() so GPU tensors and tensors that require grad also work.
    npimg = img.detach().cpu().numpy()
    # Tensors are (channels, height, width); plt.imshow expects
    # (height, width, channels), so move the channel axis last.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
# Fetch one shuffled mini-batch of images and their integer labels.
dataIter = iter(train_loader)
# Use the built-in next(): recent PyTorch DataLoader iterators have no
# ``.next()`` method (it raises AttributeError).
images, labels = next(dataIter)
# Tile the whole batch into a single image with make_grid and display it.
imshow(torchvision.utils.make_grid(images))
# Print the class name of each image in grid order. Iterating over the labels
# themselves (not range(64)) also handles a smaller final batch.
print(' '.join('%5s' % classes[label] for label in labels.tolist()))
# Inspect a single image from the batch on its own.
pic = images[0].numpy()
# The tensor is (channels, height, width), i.e. (3, 32, 32) for CIFAR-10
# loaded with ToTensor(). plt.imshow expects (height, width, channels), so
# move the channel axis last.
print(pic.shape)
pic = np.transpose(pic, (1, 2, 0))
print(pic.shape)
# No un-normalization is needed here: ToTensor() alone leaves values in [0, 1].
plt.imshow(pic)
plt.show()