IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Resnet18卷积神经网络实现图片分类算法(代码全注释) -> 正文阅读

[人工智能]Resnet18卷积神经网络实现图片分类算法(代码全注释)

1.类的定义

import torch.nn as nn
import torch


class BasicBlock(nn.Module):
    expansion = 1#是否可以调用

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:#残差结构实虚判断
            identity = self.downsample(x)#虚函数

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self,
                 block,#残差结构
                 blocks_num,#残差结构数目
                 num_classes=1000,#训练集分类数目
                 include_top=True,#复杂可选
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64#卷积核个数

        self.groups = groups#不用
        self.width_per_group = width_per_group#不用

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,#输入通道数 卷积核个数
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)#通道个数和卷核个数一样
        self.relu = nn.ReLU(inplace=True)#激活
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)#池化
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:#默认为Ture
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)#平均池化
            self.fc = nn.Linear(512 * block.expansion, num_classes)#全连接层 512*1 分类数

        for m in self.modules():#卷积层初始化!!!
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):#通过上面定义好的残差结构造层 残差结构 卷积核个数 残差结构个数 步长默认为一
        downsample = None#默认实
        if stride != 1 or self.in_channel != channel * block.expansion:#用不到
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel,#第一层虚线残差结构
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):#后面全是实线残差结构
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)#*列表去括号

    def forward(self, x):#正向传播
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)#压平 变成一维矩阵
            x = self.fc(x)#全连接

        return x


def resnet18(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)

2.训练分类模型

?

import os
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm

from sun import resnet18


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#优先使用Gpu0如果有 没有则cpu
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),#随机裁剪224*224
                                     transforms.RandomHorizontalFlip(),#随机翻转
                                     transforms.ToTensor(),#转化成矩阵吧 !!
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),#标准化处理
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    #data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  #os.getcwd()获取当前文件目录 ..返回上一级目录../..返回上两级目录 abspath()返回绝对路径
    #data_root=r"D:\artificial intelligence\Cat and dog recognition\project_data"
    #image_path = os.path.join(data_root, "train1")  # flower data set path
    image_trpath_folder=r"D:\artificial intelligence\Cat and dog recognition\project_data\train"
    #assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=image_trpath_folder,#图片文件夹加载
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx#获取图片类型的字典
    cla_dict = dict((val, key) for key, val in flower_list.items())#将字典反过来
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)#编码成json
    with open('class_indices.json', 'w') as json_file:#写入json文件
        json_file.write(json_str)

    batch_size = 16#批处理数量处理
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,#加载图片
                                               batch_size=batch_size, shuffle=True,#批处理数量
                                               # 16
                                               num_workers=nw)#!!!

    image_vapath_folder=r"D:\artificial intelligence\Cat and dog recognition\project_data\test"#测试文件夹加载
    validate_dataset = datasets.ImageFolder(root=image_vapath_folder,#预处理
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,#图片加载
                                                  batch_size=batch_size, shuffle=False,#不洗牌
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    net = resnet18()#resnet网络对象
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet18_pre.pth"
    #assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    #net.load_state_dict(torch.load(model_weight_path, map_location=device))
    # for param in net.parameters():
    #     param.requires_grad = False

    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 2)#全连接层
    net.to(device)

    # 损失函数
    loss_function = nn.CrossEntropyLoss()

    #
    #params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(net.parameters(), lr=0.0001)#学习rate优化器

    epochs = 20#训练次数
    best_acc = 0.0#准确rate初始化
    save_path = './resNet18.pth'#权重保存
    train_steps = len(train_loader)
    for epoch in range(epochs):#
        # train
        net.train()#训练
        running_loss = 0.0
        train_bar = tqdm(train_loader)#添加训练进度条 返回迭代器
        """ 
        enumerate()
        names = ["Alice","Bob","Carl"]
        for index,value in enumerate(names):
            print(f'{index}: {value}')
        

        """
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()#清空之前的梯度信息
            logits = net(images.to(device))#这参数
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        net.eval()# 启用dropout方法
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # loss = loss_function(outputs, test_labels)
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

3.利用训练好的模型进行图片的识别和分类?

import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import json
from sun import resnet18
data_transform=transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])
preimage_path=r"D:\artificial intelligence\Cat and dog recognition\project_data\test\test1\182.jpg"
image=Image.open(preimage_path)
plt.imshow(image)#显示图片格式
image=data_transform(image)
image=torch.unsqueeze(image,dim=0)#扩充维度batch 第0个维度一张图片3通道 224*224
try:
    json_file=open('./class_indices.json','r')
    class_indict=json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

model=resnet18(num_classes=2)
model_weight_path="./resNet18.pth"
model.load_state_dict(torch.load(model_weight_path))#载入模型
model.eval()#关闭dropout

with torch.no_grad():#默认进行反向传播 这个不进行
    output=torch.squeeze(model(image))#模型载入图片计算并压缩
    pre_list=torch.softmax(output,dim=0)#在列上进行概率计算
    pre=torch.argmax(pre_list).numpy()#取最大
print(class_indict[str(pre)],pre_list[pre].item())
plt.show()

总结:本人第一次写博文,希望大家多多支持,有不懂得地方可以找我交流或者评论区留言。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-03 11:10:51  更:2021-08-03 11:13:56 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/22 14:47:48-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码