IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 机器学习笔记4-1:手写数字分类实战 -> 正文阅读

[人工智能]机器学习笔记4-1:手写数字分类实战


本章利用前述PyTorch的基本使用方法,来完成一个对MNIST数据集的手写数字图片分类任务。

1 模型构建

1.1 模型定义

首先我们继承torch.nn.Module类,创建一个自定义的网络模型,继承时至少需要重写两个方法:__init__()forward(x),前者用于模型初始化,定义模型结构;后者定义前向传播的计算步骤,即输入x如何得到输出y。我们的实现如下:

class Model(nn.Module):
    def __init__(self):
        # 调用父类的方法
        super(Model, self).__init__()
        # 定义模型结构
        self.layers = nn.Sequential(
            nn.Flatten(),  # 将训练样本输入转化为一维向量
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.layers(x)

上述模型结构的隐藏层的层数和神经元数量、激活函数等,都可以自由设置,经过实验选择分类准确率最高的配置。

1.2 训练步骤

下面定义每个epoch所需要做的工作。一共分为4步:前向传播得到输出、根据输出计算损失、反向传播计算导数、根据导数更新模型参数。在实际的实现中,为了消除上一个epoch所计算的导数对当前epoch计算倒数的影响,需要清空之前的epoch累计的梯度。实现如下:

def train_loop(dataloader: DataLoader, model: nn.Module, loss_fn, optimizer: torch.optim.Optimizer):
    size = len(dataloader.dataset)
    for batch, (x, y) in enumerate(dataloader):
        # 前向传播
        y1 = model(x)
        # 计算损失
        loss = loss_fn(y1, y)
        # 清空累计梯度
        optimizer.zero_grad()
        # 反向传播
        loss.backward()
        # 更新参数
        optimizer.step()
        # 输出结果
        if batch % 200 == 0:
            loss, current = loss.item(), batch * len(x)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

1.3 测试步骤

模型测试时的步骤与训练相似,只是不用计算梯度、更新参数,于是在这一步可以禁用自动求导,使得运算更加迅速。实现如下:

def test_loop(dataloader: DataLoader, model: nn.Module, loss_fn, optimizer: torch.optim.Optimizer):
    # 初始化损失函数值、分类准确的样本数
    test_loss, correct = 0, 0
    # 获取dataloader的batch总数
    num_batches = len(dataloader)
    # 获取测试数据样本总数
    size = len(dataloader.dataset)
    # 禁用梯度,提高计算速度
    with torch.no_grad():
        # 遍历dataloader中的所有数据,注意x,y依然是一组batch而不是单个样本
        for x, y in dataloader:
            # 前向传播
            y1 = model(x)
            # 计算损失
            test_loss += loss_fn(y1, y).item()
            # 计算预测正确的数量,对一组batch的分类正确数量求和
            correct += (y1.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

1.4 数据加载

加载训练集和测试集的数据,注意将PIL转换为Tensor:

# 读取训练数据
train_data = datasets.MNIST(
    root="./datas/",
    download=False,
    train=True,
    transform=transforms.ToTensor()
)
# 读取测试数据
test_data = datasets.MNIST(
    root="./datas/",
    download=False,
    train=False,
    transform=transforms.ToTensor()
)
# 装载训练和测试数据,并设置batch_size,每次读取乱序
train_loader = DataLoader(train_data, shuffle=True, batch_size=64)
test_loader = DataLoader(test_data, shuffle=True, batch_size=64)

1.5 模型初始化

初始化模型实例,并设置运行的设备,如果计算机上安装了gpu版本的PyTorch,将会使用GPU加速运算:

# 设置要运行的设备,若安装了cuda则使用gpu,否则使用cpu
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 初始化模型,并设置运行的设备
model = Model().to(device)

1.6 优化器

初始化优化器对象,这里使用SGD,除此之外还有Adam、Adagrad等可以使用。

# 初始化优化器,使用SGD优化器
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

1.7 损失函数

损失函数使用交叉熵损失函数,注意使用交叉熵损失函数时,数据集中的label不必映射为one-hot编码,保持原样使用一个整型值表示第几类即可。

# 设置损失函数,使用交叉熵
loss_fn = nn.CrossEntropyLoss()

1.8 启动训练

训练模型只需要使用一个for循环,按设定好的迭代次数执行即可(下面仅显示最后训练完成输出的结果):

# 开始训练
for i in range(epoch):
    print(f"Epoch {i+1}\n-------------------------------")
    train_loop(train_loader, model, loss_fn, optimizer)
    test_loop(test_loader, model, loss_fn, optimizer)
...
Epoch 5
-------------------------------
 Accuracy: 97.2%, Avg loss: 0.089015 

完整代码请参见MNIST分类-CSDN完整代码资源

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-07-23 10:47:03  更:2021-07-23 10:50:02 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/22 11:25:05-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码