完整项目代码:https://github.com/SPECTRELWF/pytorch-cnn-study 个人主页:liuweifeng.top:8090
ResNet网络结构
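ResNet 是何恺明等人在 CVPR 2016 上发表的工作(Deep Residual Learning for Image Recognition),并获得了当年的最佳论文奖。它通过引入残差连接(shortcut connection),缓解了深层网络难以训练的问题(梯度消失以及"网络越深、训练误差反而越大"的退化问题),使训练上百层的网络成为可能。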
数据集描述
数据集采用格物钛(Graviti)平台上的一个公开数据集,下载地址:https://gas.graviti.cn/dataset/data-decorators/COVID_CT 。该数据集共包含 715 张经过预处理的 CT 图像,分为确诊和未确诊两类,两类样本数量之比大约为 1:1。
网络结构
使用 PyTorch 的 torchvision 中提供的 resnet50(),不加载预训练权重。在其 1000 维输出之后再接两层全连接层(1000→512→2),输出二分类结果:
import torchvision
import torch.nn as nn
class my_resnet50(nn.Module):
    """ResNet-50 image classifier with an extra fully connected head.

    The torchvision ResNet-50 backbone is built without pretrained weights
    and keeps its own 1000-way final layer. Two additional linear layers then
    map those 1000 features to 512 and finally to ``num_classes`` logits.

    Args:
        num_classes: number of output logits. Defaults to 2, the binary
            covid / no_covid setup used by the training and predict scripts.
    """

    def __init__(self, num_classes=2):
        super(my_resnet50, self).__init__()
        # NOTE: `pretrained=` is deprecated since torchvision 0.13 in favour of
        # `weights=None`; kept as-is so older torchvision versions still work.
        self.backbone = torchvision.models.resnet50(pretrained=False)
        self.fc2 = nn.Linear(1000, 512)
        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return raw (unnormalized) class logits for the image batch ``x``.

        The logits are meant to be fed to ``nn.CrossEntropyLoss`` or argmax'd.
        Presumably ``x`` is a float tensor of shape (N, 3, 224, 224) as
        produced by the scripts' transforms -- TODO confirm for other callers.
        """
        x = self.backbone(x)
        # NOTE(review): there is no non-linearity between fc2 and fc3, so the
        # pair composes to a single affine map (1000 -> num_classes) and the
        # 512-wide hidden layer adds no expressive power. Inserting e.g.
        # nn.ReLU() would not change the state_dict keys, but it WOULD change
        # the outputs of checkpoints already trained without it.
        x = self.fc2(x)
        x = self.fc3(x)
        return x
训练代码(train):
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import torch.utils.data as data
from torch.utils.data import DataLoader
from dataload.COVID_Dataload import COVID
from resnet50 import my_resnet50
from torch import nn,optim
import os

# Training-time preprocessing: resize to 224x224 and apply random horizontal
# flips as light data augmentation. Bound to `transform` (not `transforms`)
# so the torchvision.transforms module imported above is not shadowed.
transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

batch_size = 32
train_set = COVID(transformer=transform, train=True)
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
)

# Use the GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 200
lr = 1e-4
# .to(device) rather than .cuda(device): .cuda() raises on CPU-only machines,
# which would defeat the CPU fallback selected above.
net = my_resnet50().to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)

# A checkpoint is written after every epoch; create the target directory up
# front so torch.save does not fail on a fresh checkout.
save_dir = 'model/no_pretrain'
os.makedirs(save_dir, exist_ok=True)

train_loss = []  # per-batch loss history, plotted after training
for epoch in range(epochs):
    net.train()
    sum_loss = 0
    num_batches = 0
    for batch_idx, (x, y) in enumerate(train_loader):
        x = x.to(device)
        y = y.to(device)
        pred = net(x)
        optimizer.zero_grad()
        loss = loss_func(pred, y)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        sum_loss += batch_loss
        num_batches += 1
        train_loss.append(batch_loss)
        print("epoch:%d , batch:%d , loss:%.3f" % (epoch, batch_idx, batch_loss))
    # Report the epoch-average loss (guarded against an empty loader).
    if num_batches:
        print("epoch:%d , mean loss:%.3f" % (epoch, sum_loss / num_batches))
    torch.save(net.state_dict(),
               os.path.join(save_dir, 'epoch' + str(epoch + 1) + '.pth'))

from utils import plot_curve
plot_curve(train_loss)
测试代码(test):
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from dataload.COVID_Dataload import COVID
from resnet50 import my_resnet50

# Use the GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Evaluation preprocessing must be deterministic: resize + tensor conversion
# only. Random augmentation such as RandomHorizontalFlip would make the
# reported test accuracy vary from run to run.
transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
])

test_dataset = COVID(train=False, transformer=transform)
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
)
def predict(checkpoint_path='/home/lwf/code/pytorch学习/ResNet/resnet新冠病毒确诊的预测/model/no_pretrain/epoch200.pth'):
    """Evaluate a trained my_resnet50 checkpoint on the test split.

    Loads the weights from ``checkpoint_path`` and runs the model over the
    module-level ``test_loader``. It prints per-batch debug information and
    the final accuracy.

    Args:
        checkpoint_path: path to a state_dict saved by the training script.
            Defaults to the original hard-coded location.

    Returns:
        The test accuracy: the number of samples whose argmax prediction
        equals the label, divided by ``len(test_loader.dataset)``.
    """
    net = my_resnet50().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    net.load_state_dict(torch.load(checkpoint_path, map_location=device))
    # eval() makes the backbone's BatchNorm layers use their running
    # statistics instead of per-batch statistics, and stops updating them.
    net.eval()
    print(net)
    total_correct = 0
    # No gradients are needed for evaluation; this saves memory and time.
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            print(x.shape)
            y = y.to(device)
            print('y', y)
            out = net(x)
            pred = out.argmax(dim=1)
            print('pred', pred)
            total_correct += pred.eq(y).sum().float().item()
    total_num = len(test_loader.dataset)
    acc = total_correct / total_num
    print("test acc:", acc)
    return acc


predict()
单张图片预测(predict):
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from resnet50 import my_resnet50
# Inference preprocessing must be deterministic: resize + tensor conversion
# only. RandomHorizontalFlip is a training-time augmentation and is dropped.
transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
])

file_name = input("输入要预测的文件名:")
img = Image.open(file_name).convert("RGB")
show_img = img  # keep the PIL image for display; `img` is rebound below
img = transform(img)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
img = img.to(device)
img = img.unsqueeze(0)  # prepend a batch dimension of size 1

net = my_resnet50().to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
net.load_state_dict(torch.load(r'model/no_pretrain/epoch200.pth', map_location=device))
# eval(): BatchNorm layers use running statistics, not the stats of this single image.
net.eval()
with torch.no_grad():
    pred = net(img)
print(pred)

# Class index 0 means no_covid; any other index means covid.
label = pred.argmax(dim=1).item()
print(label)
res = 'pred:no_covid' if label == 0 else 'pred:covid'

plt.figure("Predict")
plt.imshow(show_img)
plt.axis("off")
plt.title(res)
plt.show()