IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> pytorch识别mnist手写数据集 -> 正文阅读

[人工智能]pytorch识别mnist手写数据集

对比多方教程,自己总结如何入门使用pytorch学习搭建基础网络模型进行训练和测试

总结:

1.准备数据

这部分将会用到相关dataset和Dataloader2

#.1准备数据,这些需要准备dataset.DataLoader
  #数据的准备
train_data =datasets.MNIST("./data/train",train=True,transform= torchvision.transforms.ToTensor(),
                                       download=True)

test_data = datasets.MNIST("./data/test",train=False,transform= torchvision.transforms.ToTensor()
                            ,download=True)

print("训练集的长度:{}".format(len(train_data)))
print("测试集的长度:{}".format(len(test_data)))

  #数据迭代器的准备
train_loader=  DataLoader(train_data,batch_size=4,shuffle=True)
test_loader = DataLoader(test_data,batch_size=4)

2.超参数的设定

?

#2.超参数的设定
epoch = 20
bactch_size =16
device = "cuda" if torch.cuda.is_available() else "cpu"
start_time = time.time()

3.创建网络?

#3.构建模型,这里可以使用torch构造一个深层的神经网络
class MnistNet(nn.Module):
    def __init__(self):
        super(MnistNet, self).__init__()
        # self.fc1 = nn.Linear(28 * 28 * 1, 28)  # 定义Linear的输入和输出的形状
        # self.fc2 = nn.Linear(28, 10)  # 定义Linear的输入和输出的形状

        self.flatten = nn.Flatten()  #从第一维到最后一维
        self.liner_relu_nn = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def  forward(self,x):
        # x = x.view(-1,28*28*1)  #拉直  展平相当于flatten
        x = self.flatten(x)
        Y = self.liner_relu_nn(x)
        return  Y



 #实例化模型
modelNet = MnistNet().to(device)
# print(model)

?4.优化器的定义

#4.优化器的定义
loss_fn = nn.CrossEntropyLoss().to(device)   #交叉熵损失函数
optimizer = torch.optim.SGD(modelNet.parameters() , lr = 1e-4)  #SGD优化器

?5.训练定义

#5. 模型训练的定义
def train_mnist(train_loader,model,loss_fn,optimizer):
    '''训练过程写成函数'''
    modelNet.train(True)  #nn.module.train()方法表示训练中会使用到Dropout,BN等方法

    total_train_step = 0  #记录训练轮次
    #idx次数,也就是img的索引
    for data in train_loader:
        img , target = data
        img,target = img.to(device),target.to(device)

        pred = model(img)    #将图片输入网络预测

        train_loss = loss_fn(pred,target)  #损失函数

        #反向传播
        optimizer.zero_grad()  #梯度清零
        train_loss.backward()
        optimizer.step()  #参数的优化更新

        total_train_step += 1
        if total_train_step %1000 == 0:
            print("训练的轮次:{},loss:{}".format(total_train_step,train_loss.item()))

6.测试定义


#6.模型的验证
def test_mnist(test_loader,model,device):
    #模型的验证
    model.eval()  #eval不会使用dropout,BN等方法,符合测试的要求
    correct = 0   #总和的正确率
    test_loss = 0 #test总的测试损失
    with torch.no_grad():   #不用也不会计算梯度,也不会进行反向传播
        for data , target in test_loader:
            data,target = data.to(device) , target.to(device)
            #得到测试结果
            pred = model(data)
            #计算测试损失
            test_loss  += loss_fn(pred,target).item()  #必须加item

            #找到概率最大的下标
            # max_pred = torch.max(pred,dim=1)
            # max_pred = pred.argmax(dim=1)
            #累计正确的值
            correct =correct +  (pred.argmax(dim=1) == target)   #土堆方法
            # correct += max_pred.eq(target.view_as(max_pred)).sum().item() #官方方法
            #整个测试集的正确率
        test_loss = test_loss /len(test_loader)
        test_correct =  correct / len(test_loader)
        print("Test---Average:{}   Accuracy:{}".format(test_loss,test_correct))

?7.测试模型

#7.调用train_mnist和test_mnist方法
for i in range(epoch):
    print("----------第{}轮训练开始-------".format(i + 1))
    train_mnist(train_loader, modelNet, loss_fn, optimizer)
    test_mnist(test_loader,modelNet,device)


#5.模型的保存
# torch.save(MnistNet,"Mnist.pth")

?

?

?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-06 16:17:41  更:2022-04-06 16:18:12 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/8 4:20:52-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码