IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Pytorch中tqdm进度条的使用 -> 正文阅读

[人工智能]Pytorch中tqdm进度条的使用

参考文章

  • https://zhuanlan.zhihu.com/p/378474516

代码示例

def train(args):
    # 获取训练数据
    datasets = pre_process()
    data = datasets[0].to(device)
    # 定义模型
    model = GAT(datasets.num_features, datasets.num_classes, [200, 100]).to(device)
    # 定义损失函数
    criterion = nn.CrossEntropyLoss()
    # 定义优化器
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
    for epoch in range(args.epochs):

        loop = tqdm(range(300), total=300)
        for iteration in loop:
            model.train()
            # 初始化梯度
            optimizer.zero_grad()
            out = model(data.x, data.edge_index)
            # 计算loss
            loss = criterion(out[data.train_mask], data.y[data.train_mask])
            # 反向传播
            loss.backward()
            # 更新梯度
            optimizer.step()
            # 更新信息
            loop.set_description(f'Epoch [{epoch}/{args.epochs}]')
            loop.set_postfix(loss=loss.item())

案例1

# -*- coding: utf-8 -*-
# @Time    : 2022/5/5 16:35
# @Author  : 王天赐
# @Email   : 15565946702@163.com
# @File    : model.py
# @Software: PyCharm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn
from dataset import get_data
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
from tqdm import tqdm


class GraphCNN(nn.Module):

    def __init__(self, in_channels, out_channels, hidden_channels):
        super(GraphCNN, self).__init__()
        self.conv1 = pyg_nn.GCNConv(in_channels, hidden_channels)
        self.conv2 = pyg_nn.GCNConv(hidden_channels, out_channels)
        pass

    def forward(self, data):
        # data.x, data.edge_index
        x = data.x  # [N, F] N个节点,F个特征
        edge_index = data.edge_index  # [2, E] E个边
        hid = self.conv1(x=x, edge_index=edge_index)  # [N, hidden_channels]
        hid = F.relu(hid)
        out = self.conv2(x=hid, edge_index=edge_index)  # [N, out_channels]
        out = F.log_softmax(out, dim=1)  # softmax 在第一维进行归一化, 生成一个 [out_channels] 的向量, 其中每个元素的值表示该类别的概率
        return out


def train():
    cora_dataset = get_data()
    model = GraphCNN(in_channels=cora_dataset.num_features, out_channels=cora_dataset.num_classes, hidden_channels=16)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    data = cora_dataset[0].to(device)

    # 定义优化器
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    # 定义损失函数
    loss_fn = nn.CrossEntropyLoss()  # 对于多分类问题, 可以使用 nn.CrossEntropyLoss()

    # 训练
    for item in range(10):
        # 存储 epoch 和 loss
        epochs = []
        losses = []
        accs = []

        # 进度条

        loop = tqdm(range(100), total=100, desc='train')
        for epoch in loop:
            model.train()
            optimizer.zero_grad()
            out = model(data)
            # 计算损失
            loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
            loss.backward()
            optimizer.step()

            epochs.append(epoch + 1)
            losses.append(loss.item())

            # 测试
            model.eval()
            _, pred = model(data).max(dim=1)  # 获取预测结果类别概率中的最大值
            acc = metrics.accuracy_score(y_true=data.y[data.test_mask].cpu(), y_pred=pred[data.test_mask].cpu())
            accs.append(acc)

            # 更新信息
            loop.set_description(f'Item [{item + 1}/{10}] Epoch [{epoch + 1}/{100}]')
            loop.set_postfix(loss=loss.item(), acc=acc)

        # 绘制 loss 曲线
        plt.plot(epochs, losses)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        # 绘制 acc 曲线
        plt.plot(epochs, accs)
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.show()


if __name__ == '__main__':
    train()

效果如下 :
在这里插入图片描述

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-08 08:05:07  更:2022-05-08 08:09:09 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 6:27:14-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码