import torch
import torchvision
from torch import nn
import MyFunction as MF
import sys
import datetime
import time
from tqdm import tqdm
# DataLoader worker processes: multiprocessing workers are problematic on
# Windows (spawn + script re-import), so fall back to 0 there.
num_workers = 0 if sys.platform.startswith("win") else 4
def load_cifar10(is_train, augs, batch_size):
    """Return a DataLoader over CIFAR-10 with the given transform pipeline.

    Args:
        is_train: True for the training split (also enables shuffling).
        augs: torchvision transform applied to each image.
        batch_size: samples per batch.

    Downloads the dataset into ./data on first use.
    """
    dataset = torchvision.datasets.CIFAR10(
        root="./data",
        train=is_train,
        transform=augs,
        download=True,
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,  # shuffle only the training split
        num_workers=num_workers,
    )
# ImageNet channel statistics — required because the backbone below is
# ImageNet-pretrained and expects inputs normalized with these values.
normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                  [0.229, 0.224, 0.225])
# Training pipeline: random 224x224 crop + horizontal flip for augmentation.
train_augs = torchvision.transforms.Compose([
        torchvision.transforms.RandomResizedCrop(224),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        normalize])
# Evaluation pipeline: deterministic resize + center crop, no augmentation.
test_augs = torchvision.transforms.Compose([
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        normalize])
# ImageNet-pretrained ResNet-18; replace the 1000-way head with a fresh
# 10-way linear layer for CIFAR-10 and re-initialize its weights.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision in
# favor of the `weights=` argument — confirm the installed version.
finetune_net = torchvision.models.resnet18(pretrained=True)
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 10)
nn.init.xavier_uniform_(finetune_net.fc.weight);
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def train_batch_ch13(net, X, y, loss, trainer, device):
    """Run one optimization step on a single mini-batch.

    Args:
        net: the model to train (switched to train mode here).
        X: input batch, a tensor or a list of tensors; moved to `device`.
        y: target labels; moved to `device`.
        loss: element-wise loss (e.g. CrossEntropyLoss(reduction="none")).
        trainer: optimizer whose gradients are zeroed and stepped here.
        device: device to run the step on.

    Returns:
        (loss_sum, acc_sum): the summed batch loss as a Python float and
        the batch accuracy count from MF.accuracy.
    """
    if isinstance(X, list):
        # Some models (e.g. BERT-style) take a list of input tensors.
        X = [x.to(device) for x in X]
    else:
        X = X.to(device)
    y = y.to(device)
    net.train()
    trainer.zero_grad()
    pred = net(X)
    l = loss(pred, y)
    # Compute the sum once; the original called l.sum() twice and returned a
    # tensor still attached to the autograd graph, so the caller's running
    # total retained every batch's graph (a steady memory leak). .item()
    # detaches and returns a plain float.
    loss_sum = l.sum()
    loss_sum.backward()
    trainer.step()
    train_loss_sum = loss_sum.item()
    train_acc_sum = MF.accuracy(pred, y)
    return train_loss_sum, train_acc_sum
def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, device=device):
    """Train `net` for `num_epochs`, logging per-epoch progress with tqdm.

    Saves a checkpoint after every epoch and evaluates test accuracy with
    MF.evaluate_accuracy_ch13. Loss/accuracy shown are running averages
    over the samples seen so far in the epoch.
    """
    print("Start Training...")
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("==========" * 8 + "%s" % nowtime)
    # Bug fix: torch.cuda.get_device_name() raises when CUDA is unavailable,
    # so only query the GPU name when actually training on a CUDA device.
    if torch.cuda.is_available() and str(device).startswith("cuda"):
        print(f'training on {device}:{torch.cuda.get_device_name()}')
    else:
        print(f'training on {device}')
    net.to(device)
    timer = MF.Timer()
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        start = time.time()
        with tqdm(train_iter) as t:
            for features, labels in t:
                timer.start()
                l, acc = train_batch_ch13(net, features, labels, loss, trainer,
                                          device=device)
                n += labels.shape[0]
                # float() detaches in case the batch loss is still a tensor
                # carrying its autograd graph; otherwise the running sum
                # would keep every batch's graph alive.
                train_l_sum += float(l)
                train_acc_sum += acc
                train_l = train_l_sum / n
                train_acc = train_acc_sum / n
                timer.stop()
                t.set_description(f"epoch:{epoch+1}/{num_epochs}")
                t.set_postfix(loss="%.3f" % train_l, train_acc="%.3f" % train_acc, epoch_time="%.3f sec" % (time.time() - start))
        # NOTE(review): checkpoint path is machine-specific — consider
        # making it a parameter.
        torch.save(net.state_dict(), "C:/Users/52xj/Desktop/pytorch/data/finetuning88-%d.pth" %(epoch))
        test_acc = MF.evaluate_accuracy_ch13(net, test_iter)
        print(f'epoch:{epoch+1},loss {train_l:.3f}, train_acc {train_acc:.3f}, test_acc {test_acc:.3f}')
    print(f'{n * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("==========" * 8 + "%s" % nowtime)
    print('Finished Training...')
def train_fine_tuning(net, learning_rate, num_epochs=20, param_group=True):
    """Fine-tune `net` on CIFAR-10.

    When `param_group` is True, the pretrained backbone trains at
    `learning_rate` while the freshly initialized `fc` head trains 10x
    faster; otherwise a single learning rate is used for all parameters.
    """
    batch_size = 256
    train_iter = load_cifar10(True, train_augs, batch_size)
    test_iter = load_cifar10(False, test_augs, batch_size)
    criterion = nn.CrossEntropyLoss(reduction="none")
    if param_group:
        # Everything except the new classification head gets the base LR.
        backbone_params = [
            param
            for name, param in net.named_parameters()
            if name not in ["fc.weight", "fc.bias"]
        ]
        trainer = torch.optim.SGD(
            [
                {'params': backbone_params},
                {'params': net.fc.parameters(), 'lr': learning_rate * 10},
            ],
            lr=learning_rate,
            weight_decay=0.001,
        )
    else:
        trainer = torch.optim.SGD(
            net.parameters(), lr=learning_rate, weight_decay=0.001
        )
    train_ch13(net, train_iter, test_iter, criterion, trainer, num_epochs,
               device=device)
# Entry point: fine-tune with a small base LR (the new fc head uses 10x this).
train_fine_tuning(finetune_net, 5e-5)
# Training transcript (console output from a previous run), kept for
# reference. Wrapped in a string literal so the file remains valid Python.
"""
Start Training...
================================================================================2021-08-08 19:49:58
training on cuda:GeForce RTX 2060 SUPER
epoch:1/20: 100%|██████████| 196/196 [02:40<00:00, 1.22it/s, epoch_time=160.859 sec, loss=0.310, train_acc=0.891]
epoch:2/20: 100%|██████████| 196/196 [02:35<00:00, 1.26it/s, epoch_time=155.550 sec, loss=0.314, train_acc=0.890]
epoch:3/20: 100%|██████████| 196/196 [02:36<00:00, 1.25it/s, epoch_time=156.974 sec, loss=0.307, train_acc=0.892]
epoch:4/20: 100%|██████████| 196/196 [02:37<00:00, 1.25it/s, epoch_time=157.125 sec, loss=0.306, train_acc=0.891]
epoch:5/20: 100%|██████████| 196/196 [02:37<00:00, 1.24it/s, epoch_time=157.536 sec, loss=0.307, train_acc=0.893]
epoch:6/20: 100%|██████████| 196/196 [02:37<00:00, 1.24it/s, epoch_time=157.464 sec, loss=0.299, train_acc=0.894]
epoch:7/20: 100%|██████████| 196/196 [02:41<00:00, 1.22it/s, epoch_time=161.027 sec, loss=0.296, train_acc=0.896]
epoch:8/20: 100%|██████████| 196/196 [02:41<00:00, 1.21it/s, epoch_time=161.367 sec, loss=0.290, train_acc=0.898]
epoch:9/20: 100%|██████████| 196/196 [02:40<00:00, 1.22it/s, epoch_time=160.566 sec, loss=0.293, train_acc=0.898]
epoch:10/20: 100%|██████████| 196/196 [02:40<00:00, 1.22it/s, epoch_time=160.857 sec, loss=0.288, train_acc=0.899]
epoch:11/20: 100%|██████████| 196/196 [02:40<00:00, 1.22it/s, epoch_time=160.416 sec, loss=0.291, train_acc=0.898]
epoch:12/20: 100%|██████████| 196/196 [02:38<00:00, 1.24it/s, epoch_time=158.641 sec, loss=0.290, train_acc=0.900]
epoch:13/20: 100%|██████████| 196/196 [02:37<00:00, 1.24it/s, epoch_time=157.849 sec, loss=0.281, train_acc=0.901]
epoch:14/20: 100%|██████████| 196/196 [02:37<00:00, 1.24it/s, epoch_time=157.642 sec, loss=0.287, train_acc=0.900]
epoch:15/20: 100%|██████████| 196/196 [02:37<00:00, 1.25it/s, epoch_time=157.286 sec, loss=0.273, train_acc=0.903]
epoch:16/20: 100%|██████████| 196/196 [02:39<00:00, 1.23it/s, epoch_time=159.261 sec, loss=0.282, train_acc=0.901]
epoch:17/20: 100%|██████████| 196/196 [02:38<00:00, 1.23it/s, epoch_time=158.784 sec, loss=0.281, train_acc=0.902]
epoch:18/20: 100%|██████████| 196/196 [02:37<00:00, 1.24it/s, epoch_time=157.865 sec, loss=0.275, train_acc=0.903]
epoch:19/20: 100%|██████████| 196/196 [02:38<00:00, 1.24it/s, epoch_time=158.309 sec, loss=0.271, train_acc=0.905]
epoch:20/20: 100%|██████████| 196/196 [02:37<00:00, 1.24it/s, epoch_time=157.908 sec, loss=0.271, train_acc=0.905]
epoch:20,loss 0.271, train_acc 0.905, test_acc 0.956
482.4 examples/sec on cuda
================================================================================2021-08-08 20:48:34
Finished Training...
"""
import torch
from torch import nn
import torchvision
from PIL import Image
import torchvision.transforms as transforms
# --- Inference-side setup (duplicates the training section above, so this
# part can be run as a standalone script). The pretrained weights are only
# a placeholder: predict_ overwrites them with the fine-tuned checkpoint.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
finetune_net = torchvision.models.resnet18(pretrained=True)
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 10)
nn.init.xavier_uniform_(finetune_net.fc.weight);
def predict_(img):
    """Classify a PIL image with the fine-tuned CIFAR-10 ResNet-18.

    Loads the epoch-19 checkpoint (every call — acceptable for a one-off
    script), runs a forward pass, and returns the predicted class name
    (Chinese label string).
    """
    # Bug fix: inference must use the same preprocessing as evaluation
    # during training (Resize(256) -> CenterCrop(224) -> ToTensor ->
    # ImageNet normalization). The original resized to 32 and skipped
    # normalization, which does not match how the checkpoint was trained.
    data_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])
    img = data_transform(img)
    img = torch.unsqueeze(img, dim=0)  # add batch dimension: (1, C, H, W)
    model = finetune_net
    model.load_state_dict(torch.load("./data/finetuning88-19.pth"))
    model.eval()
    # CIFAR-10 index -> Chinese class name.
    classes = {'0': '飞机', '1': '汽车', '2': '鸟', '3': '猫', '4': '鹿', '5': '狗', '6': '青蛙', '7': '马', '8': '船', '9': '卡车'}
    with torch.no_grad():
        output = model(img)
        print("output",output)
        pred = output.argmax(axis=1)
        print(pred)
        return classes[str(pred.item())]
# Script entry: classify a sample image with the fine-tuned checkpoint.
img = Image.open("./img/test3.jpg")
net = predict_(img)  # NOTE(review): returns a class-name string, not a network
print(net)
# NOTE: Only trained for 20 epochs due to a weak GPU; recognition of cats
# and birds turned out relatively good.