IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> (刘二大人)PyTorch深度学习实践-卷积网络(Residual) -> 正文阅读

[人工智能](刘二大人)PyTorch深度学习实践-卷积网络(Residual)

1. Residual Block 的实现

?

1.1 代码展示

import torch
import torch.nn.functional as F

class Residual_Block(torch.nn.Module):
    def __init__(self,channels):
        super(Residual_Block, self).__init__()
        self.channels = channels #因为输入的x与输出的y要进行加法,需要保证他们的 C、W、H都一样

        self.conv_1 = torch.nn.Conv2d(channels,channels,kernel_size=3,padding=1)
        self.conv_2 = torch.nn.Conv2d(channels,channels,kernel_size=3,padding=1)

    def forward(self,x):
        y = F.relu(self.conv_1(x))
        y = self.conv_2(y)
        return F.relu(y+x) #对x和y的和再做激活
    

?2. 使用Residual Model 对 Minist数据集进行训练

2.1 代码展示

import torch
from Inception import InceptinA
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchvision import datasets,transforms

#追踪日志
writer = SummaryWriter(log_dir='../LEDR')

#准备数据集
trans = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3801,))])
train_set = datasets.MNIST(root='E:\learn_pytorch\LE',train=True,transform=trans,download=True)
test_set = datasets.MNIST(root='E:\learn_pytorch\LE',train=False,transform=trans,download=True)

#下载数据集
train_data = DataLoader(dataset=train_set,batch_size=64,shuffle=True)
test_data = DataLoader(dataset=test_set,batch_size=64,shuffle=False)

#构建模型
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv_1 = torch.nn.Conv2d(1,10,kernel_size=5)#输出变成 10x24x24
        self.conv_2 = torch.nn.Conv2d(88,20,kernel_size=5)# 输出变成 20x12x12
        self.mp = torch.nn.MaxPool2d(2)

        self.incept1 = InceptinA(channels=10)
        self.incept2 = InceptinA(channels=20)

        self.fc = torch.nn.Linear(1408,10)

    def forward(self,x):
        x = F.relu(self.mp(self.conv_1(x)))# 输出为 10x12x12
        x = self.incept1(x) #输出是88x12x12
        x = F.relu(self.mp(self.conv_2(x)))# 输出是 20x4x4
        x = self.incept2(x) #输出是 88x4x4
        x = x.view(-1,1408)
        x = self.fc(x)
        return x

#实例化模型
huihui = Net()

#定义损失函数和优化函数
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=huihui.parameters(),lr=0.01,momentum=0.5)

#开始训练
def train(epoch):
    run_loss = 0.0
    for batch_id , data in enumerate(train_data,0):
        inputs , targets = data
        outputs = huihui(inputs)
        loss = criterion(outputs, targets)

        #归零,反馈,更新
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        run_loss += loss.item()
        if batch_id % 300 == 299:
            print("[%d,%d] loss:%.3f" %(epoch+1,batch_id+1,run_loss/300))
            run_loss = 0.0

def test():
    total = 0
    correct = 0
    with torch.no_grad():
        for data in test_data:
            inputs , labels = data
            outputs = huihui(inputs)
            _,predict = torch.max(outputs,dim=1)
            total += labels.size(0)
            correct += (predict==labels).sum().item()
        writer.add_scalar("The Accuracy1",correct/total,epoch)
        print('[Accuracy] %d %%' % (100*correct/total))

if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()

writer.close()



2.2 结果展示(15轮,其实到10轮时就达到99%了)

D:\Anaconda3\envs\pytorch\python.exe E:/learn_pytorch/LE/Residul_Model.py
[1,300] loss: 0.568
[1,600] loss: 0.162
[1,900] loss: 0.110
[Accuracy] 96 %
[2,300] loss: 0.083
[2,600] loss: 0.081
[2,900] loss: 0.076
[Accuracy] 98 %
[3,300] loss: 0.062
[3,600] loss: 0.057
[3,900] loss: 0.056
[Accuracy] 98 %
[4,300] loss: 0.048
[4,600] loss: 0.047
[4,900] loss: 0.047
[Accuracy] 98 %
[5,300] loss: 0.040
[5,600] loss: 0.042
[5,900] loss: 0.040
[Accuracy] 98 %
[6,300] loss: 0.035
[6,600] loss: 0.035
[6,900] loss: 0.035
[Accuracy] 98 %
[7,300] loss: 0.030
[7,600] loss: 0.032
[7,900] loss: 0.032
[Accuracy] 99 %
[8,300] loss: 0.027
[8,600] loss: 0.028
[8,900] loss: 0.027
[Accuracy] 98 %
[9,300] loss: 0.025
[9,600] loss: 0.027
[9,900] loss: 0.024
[Accuracy] 98 %
[10,300] loss: 0.020
[10,600] loss: 0.023
[10,900] loss: 0.024
[Accuracy] 99 %
[11,300] loss: 0.020
[11,600] loss: 0.018
[11,900] loss: 0.022
[Accuracy] 99 %
[12,300] loss: 0.019
[12,600] loss: 0.017
[12,900] loss: 0.019
[Accuracy] 99 %
[13,300] loss: 0.017
[13,600] loss: 0.016
[13,900] loss: 0.017
[Accuracy] 99 %
[14,300] loss: 0.014
[14,600] loss: 0.017
[14,900] loss: 0.016
[Accuracy] 99 %
[15,300] loss: 0.014
[15,600] loss: 0.014
[15,900] loss: 0.015
[Accuracy] 99 %

Process finished with exit code 0

2.3 图像展示?

?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-11-05 00:28:48  更:2022-11-05 00:29:39 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/28 2:53:20-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计