IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> pytorch:ResNet50做新冠肺炎CT照片是否确诊分类 -> 正文阅读

[人工智能]pytorch:ResNet50做新冠肺炎CT照片是否确诊分类

完整项目代码:https://github.com/SPECTRELWF/pytorch-cnn-study
个人主页:liuweifeng.top:8090

ResNet网络结构

在这里插入图片描述

ResNet是何恺明大神在CVPR2016的工作,也拿到了当年的最佳论文。是为了解决深层网络的梯度消失的问题,引入了残差块连接。

数据集描述

数据集使用的是来自格物钛的一个公开数据集,数据集下载地址:https://gas.graviti.cn/dataset/data-decorators/COVID_CT,里面包含715张图片,包含确诊和未确诊的,比例大概一比一,图像是处理过的CT图像。
在这里插入图片描述

网络结构

使用pytorch的torchvision里面提供的resnet50(),未使用预训练模型。在后面再加上一层全连接层:

# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/9 下午4:57

import torchvision
import torch.nn as nn



class my_resnet50(nn.Module):
    def __init__(self):
        super(my_resnet50, self).__init__()
        self.backbone = torchvision.models.resnet50(pretrained=False)
        self.fc2 = nn.Linear(1000,512)
        self.fc3 = nn.Linear(512,2)

    def forward(self,x):
        x = self.backbone(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

train:

# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/9 下午4:48

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import torch.utils.data as data
from torch.utils.data import DataLoader
from dataload.COVID_Dataload import COVID
from resnet50 import my_resnet50
from torch import nn,optim

transforms = transforms.Compose([
    transforms.Resize([224,224]),
    transforms.RandomHorizontalFlip(),
    # transforms.RandomCrop(224),
    transforms.ToTensor(),

])

batch_size = 32
train_set = COVID(transformer=transforms,train=True)
train_loader = DataLoader(train_set,
                          batch_size = batch_size,
                          shuffle = True,
                          )

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#设置超参数
epochs = 200
lr = 1e-4

net = my_resnet50().cuda(device)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),lr=lr,momentum=0.9)
train_loss = []

for epoch in range(epochs):
    sum_loss = 0
    for batch_idx,(x,y) in enumerate(train_loader):
        x = x.to(device)
        y = y.to(device)
        pred = net(x)

        optimizer.zero_grad()
        loss = loss_func(pred, y)
        loss.backward()
        optimizer.step()

        sum_loss += loss.item()
        train_loss.append(loss.item())

        print(["epoch:%d , batch:%d , loss:%.3f" % (epoch, batch_idx,loss.item())])
    torch.save(net.state_dict(),'model/no_pretrain/epoch' + str(epoch+1) + '.pth')
from utils import plot_curve
plot_curve(train_loss)

test:

# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/4 下午1:29

import torch
import torchvision
from dataload.COVID_Dataload import COVID
# 定义使用GPU
from torch.utils.data import DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import torchvision.transforms as transforms
from resnet50 import my_resnet50
transform = transforms.Compose([
    transforms.Resize([224,224]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5]),
    ])

test_dataset = COVID(train=False,transformer=transform)
test_loader = DataLoader(test_dataset,
                         batch_size = 32,
                         shuffle = False,
                         )




def predict():
    net = my_resnet50().to(device)
    net.load_state_dict(torch.load('/home/lwf/code/pytorch学习/ResNet/resnet新冠病毒确诊的预测/model/no_pretrain/epoch200.pth'))
    print(net)
    total_correct = 0
    for batch_idx, (x, y) in enumerate(test_loader):
        # x = x.view(x.size(0),28*28)
        # x = x.view(256,28,28)
        x = x.to(device)
        print(x.shape)
        y = y.to(device)
        print('y',y)
        out = net(x)
        # print(out)
        pred = out.argmax(dim=1)
        print('pred',pred)
        correct = pred.eq(y).sum().float().item()
        total_correct += correct
    total_num = len(test_loader.dataset)

    acc = total_correct / total_num
    print("test acc:", acc)


predict()

在这里插入图片描述

predict

# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/4 下午2:38

##读入文件,显示正确分类和预测分类
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from resnet50 import my_resnet50

transform = transforms.Compose([
    transforms.Resize([224,224]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5]),
    ])


file_name = input("输入要预测的文件名:")
img = Image.open(file_name).convert("RGB")
show_img = img
img = transform(img)
#
# print(img)
# print(img.shape)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
img = img.to(device)
img = img.unsqueeze(0)
net = my_resnet50().to(device)
net.load_state_dict(torch.load(r'model/no_pretrain/epoch200.pth'))

pred = net(img)
print(pred)
print(pred.argmax(dim = 1).cpu().numpy()[0])
res = ''
if pred.argmax(dim = 1) == 0:
    res += 'pred:no_covid'
else:
    res += 'pred:covid'

plt.figure("Predict")
plt.imshow(show_img)
plt.axis("off")
plt.title(res)
plt.show()

在这里插入图片描述

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-11-10 12:23:57  更:2021-11-10 12:26:08 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 7:00:24-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码