Dataset
We use a dataset of Pokemon sprite images. Dataset download:
- Link: https://pan.baidu.com/s/1zDERMsV1AvwfZudhuae6Ew
- Extraction code: rs4h
Each class's images are stored in their own folder, and the dataset contains 5 classes. The list of image paths is shuffled and then split as follows (the slicing is sketched below):
- first 60% as the training set
- 60%~80% as the validation set
- 80%~100% as the test set
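A minimal sketch of the split arithmetic on a shuffled list of image paths (the file names and the count of 1000 are illustrative; the Pokemon class below does exactly this slicing):

# Illustrative only: splitting a shuffled path list into 60% / 20% / 20%.
paths = ['img_%04d.png' % i for i in range(1000)]   # hypothetical shuffled file list
n = len(paths)
train_paths = paths[:int(0.6 * n)]                  # first 60%  -> 600
val_paths   = paths[int(0.6 * n):int(0.8 * n)]      # 60% ~ 80%  -> 200
test_paths  = paths[int(0.8 * n):]                  # last 20%   -> 200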
Dataset processing: pokemon.py
'''
Load the Pokemon image dataset.
'''
import torch
import os, glob
import random, csv
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
class Pokemon(Dataset):

    def __init__(self, root, resize, mode):
        '''
        :param root: dataset root directory
        :param resize: output size of the images
        :param mode: train/val/test
        '''
        super(Pokemon, self).__init__()

        self.root = root
        self.resize = resize

        # Map each class-folder name to an integer label.
        self.name2label = {}
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label.keys())

        # Load the image paths and labels.
        self.images, self.labels = self.load_csv('images.csv')

        # Split into train / val / test.
        if mode == 'train':  # first 60%
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'val':  # 60% ~ 80%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # last 20%
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]
    def load_csv(self, filename):
        '''
        Loading every image into memory at once could exhaust RAM, so we only
        cache the image paths and labels in a csv file.
        :param filename: name of the csv file
        :return: lists of image paths and integer labels
        '''
        if not os.path.exists(os.path.join(self.root, filename)):
            # Collect all file paths into one list; the class of each file is
            # determined by its parent folder name.
            images = []
            for name in self.name2label.keys():
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            print(len(images), images)

            random.shuffle(images)
            # Write the shuffled paths and their labels to the csv file.
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    writer.writerow([img, label])
                print('written into csv file:', filename)

        # Read the paths and labels back from the csv file.
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                label = int(label)
                images.append(img)
                labels.append(label)
        assert len(images) == len(labels)
        return images, labels
    def __len__(self):
        '''
        Return the total number of samples.
        '''
        return len(self.images)

    def denormalize(self, x_hat):
        '''
        Undo the ImageNet normalization (used for visualization).
        :param x_hat: normalized tensor
        :return: denormalized tensor
        '''
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # [3] => [3, 1, 1] so the statistics broadcast over H and W.
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        x = x_hat * std + mean
        return x
    def __getitem__(self, idx):
        '''
        Return the image tensor and label at the given index.
        :param idx: sample index
        '''
        img, label = self.images[idx], self.labels[idx]

        # Data augmentation, then convert the image to a normalized tensor.
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),
            # Resize a bit larger than the target so the rotated image can be
            # center-cropped without leaving black borders.
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        img = tf(img)
        label = torch.tensor(label)
        return img, label
def main():
    '''
    Visualize the dataset.
    Visdom must be installed and running:
    install: pip install visdom
    start:   python -m visdom.server
    '''
    import visdom
    import time

    viz = visdom.Visdom()

    db = Pokemon('pokemon', 64, 'train')

    x, y = next(iter(db))
    print('sample:', x.shape, y.shape, y)
    viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))

    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
    for x, y in loader:
        viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
        time.sleep(10)


if __name__ == '__main__':
    main()
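A quick sanity check of the Dataset without Visdom (a minimal sketch; it assumes the dataset has been unpacked into a ./pokemon directory next to pokemon.py):

from torch.utils.data import DataLoader
from pokemon import Pokemon

db = Pokemon('pokemon', 224, 'train')
print('classes:', db.name2label)          # folder name -> integer label
print('train samples:', len(db))

loader = DataLoader(db, batch_size=32, shuffle=True)
x, y = next(iter(loader))
print(x.shape, y.shape)                   # e.g. torch.Size([32, 3, 224, 224]) torch.Size([32])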
Transfer learning network
Principle
Both the Pokemon task and ImageNet involve extracting features from images, so the two tasks share some common knowledge. We can therefore reuse a model trained on the much larger and more general ImageNet to help solve our specific classification task.
We take the pretrained resnet18 from torchvision.models, use its pretrained convolutional part as the feature extractor, and train a new classifier on top of the extracted features.
Since the feature extractor does not have to be learned from scratch, far less training is needed; the convolutional layers can even be frozen so that only the new classifier is updated, as in the sketch below (the training script further down keeps all layers trainable).
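A minimal sketch of the optional freezing step, assuming the same resnet18 backbone as in the training script below:

from torch import nn, optim
from torchvision.models import resnet18

trained_model = resnet18(pretrained=True)
for param in trained_model.parameters():
    param.requires_grad = False            # freeze the pretrained feature extractor

model = nn.Sequential(*list(trained_model.children())[:-1],  # conv part: [b, 3, 224, 224] -> [b, 512, 1, 1]
                      nn.Flatten(),                          # -> [b, 512]
                      nn.Linear(512, 5))                     # new 5-way classifier

# Only parameters that still require gradients (the new Linear layer) are optimized.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)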
Code implementation
Helper file: utils.py
from matplotlib import pyplot as plt
import torch
from torch import nn
'''
A layer that keeps the first (batch) dimension and flattens all remaining
dimensions into one.
'''
class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


'''
Plot a grid of images with matplotlib.
'''
def plot_image(img, label, name):
    # Note: 0.1307 / 0.3081 are the MNIST mean/std kept from an earlier
    # tutorial; train_transfer.py only imports Flatten from this file.
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()
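Note that recent PyTorch versions already ship nn.Flatten, which by default keeps dimension 0 and flattens the rest, so it behaves the same as the custom layer above. A minimal check:

import torch
from torch import nn
from utils import Flatten

x = torch.randn(4, 512, 1, 1)
print(Flatten()(x).shape)      # torch.Size([4, 512])
print(nn.Flatten()(x).shape)   # torch.Size([4, 512]); built-in equivalent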
File that builds the network and performs training and evaluation: train_transfer.py
'''
Transfer learning.
torchvision provides pretrained resnet18, resnet34, resnet50, and so on.
Visdom must be installed and running:
install: pip install visdom
start:   python -m visdom.server
'''
import torch
from torch import optim, nn
import visdom
from torch.utils.data import DataLoader
from pokemon import Pokemon
from utils import Flatten
from torchvision.models import resnet18
batchsz = 32
lr = 1e-3
epochs = 10
device = torch.device('cuda')
torch.manual_seed(1234)
train_db = Pokemon('pokemon', 224, mode='train')
val_db = Pokemon('pokemon', 224, mode='val')
test_db = Pokemon('pokemon', 224, mode='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)
viz = visdom.Visdom()
def evalute(model, loader):
    model.eval()
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()
    return correct / total
def main():
    # Build the network: the pretrained resnet18 without its final fc layer,
    # followed by a new 5-way classifier.
    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(*list(trained_model.children())[:-1],  # [b, 512, 1, 1]
                          Flatten(),                             # [b, 512]
                          nn.Linear(512, 5)
                          ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    # Book-keeping for the best validation result.
    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    # Training and evaluation.
    for epoch in range(epochs):

        # Train for one epoch.
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)

            model.train()
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        # Evaluate on the validation set and keep the best checkpoint.
        if epoch % 1 == 0:
            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                torch.save(model.state_dict(), 'best.mdl')
                viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch)

    # Load the best checkpoint and evaluate on the test set.
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')

    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)


if __name__ == '__main__':
    main()
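Two environment-dependent notes (assumptions about your local torchvision/CUDA setup, not part of the original script): resnet18(pretrained=True) still works but is deprecated since torchvision 0.13, and device = torch.device('cuda') requires a GPU. On newer installs the rough equivalents would be:

import torch
from torchvision.models import resnet18, ResNet18_Weights

trained_model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)       # replaces pretrained=True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # fall back to CPU if no GPU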