IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Libtorch之一元线性拟合(C++/Qt) -> 正文阅读

[人工智能]Libtorch之一元线性拟合(C++/Qt)

?简介

?????? 一元线性回归是机器学习算法中最简单、直观的算法,对于学习深度学习,虽然简单但必学,是学习深度学习的 “hello word”,本节通过Libtorch来实现这个算法,所参考的Pytorch例子为书上的案例。至于一些基本的理论知识在博客中不会提到。一是因为网上有许多大佬的帖子或博客写的很好。二是本人目的想通过C++实现深度学习(目前主要是复现pytorch案例),兴趣于此。所以基本上都是直接给代码。希望能给大家带来一点参考的价值。

1.神经网络的创建:

class LinearRegression : public torch::nn::Module
{
public:
    //线性回归网络层,即网络构造函数
    LinearRegression() {
        linear = register_module("linear", torch::nn::Linear(torch::nn::LinearOptions(1, 1)));
    }

    // 向前传播,即平差中的误差方程的构建
    torch::Tensor forward(torch::Tensor x) {
        x = linear(x);
        return x;
    }
    
private:
    torch::nn::Linear linear{ nullptr };
};

2.训练数据的创建

Y与X符合下面的公式,b带有偶然误差。

y = 5*x+b

//创建数据
    torch::Tensor x = torch::unsqueeze(torch::linspace(0, 19, 20), 1);
    auto y = 5 * x + torch::randint(1, 20,x.sizes());

3.训练,模型的生成

// 新建模型、误差函数、优化器
    auto model = std::make_shared<LinearRegression>();
    torch::optim::SGD optimizer(model->parameters(), 0.001);
    torch::nn::MSELoss criterion;
    
    //开始训练 
    int num_epoches = 10;
    for (int i = 0; i < num_epoches; i++)
    {
        auto out = model->forward(x);
        auto loss = criterion(out, y);
        optimizer.zero_grad();
        loss.backward();
        optimizer.step();

        std::cout << "Epche: " << i + 1 << "/" << num_epoches << "\tloss: " << loss.item<float>() << std::endl;
        if ((i + 1) % 2 == 0)
        {
            QVector<torch::Tensor > Ys;
            Ys.append(out);
            Ys.append(y);
            QString winTile = "Epche: " + QString::number(i + 1) + "\tloss: " + QString::number(loss.item<float>());
            drawChart(x, Ys,winTile);
        }
        
    }

4.拟合结果(打印的loss与拟合线)

?总结: 上面只给出了模型的创建与训练生成,至于模型的保存与使用方法在上一篇博客已经给出。另外画拟合线的工具仍然用的Qt的QCustomplot库,上文中我自己写的drawChart()函数如果有需要的可以通过下面链接自己下载,虽然是一个画图函数,但里面也包括了tensor与其他类型变量的转换。下一篇写Libtorch之多项式拟合.......。

《使用QCustomplot画拟合线代码》

?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-11-26 08:51:52  更:2021-11-26 08:52:55 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 4:20:25-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码