IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 使用LibTorch训练模型(pytorch c++) -> 正文阅读

[人工智能]使用LibTorch训练模型(pytorch c++)

libtorch需要使用c++14的编译器

下载LibTorch,官网下载 https://pytorch.org/

调试的时候建议使用cpu+debug版本,到实际使用部署的时候再切换为cuda版本。

helloworld

#include "torch/library.h"
#include "torch/script.h"

int main()
{
	torch::Tensor output = torch::randn({ 3,2 });
	std::cout << output;
	
	return 0;
}

output

-1.9173 -0.5073
 1.5614 -0.0566
-0.0884  0.9237
[ CPUFloatType{3,2} ]

定义module

class LeNet5 : public torch::nn::Module
{
public:
	torch::nn::Conv2d C1;
	torch::nn::Conv2d C3;
	torch::nn::Linear F5;
	torch::nn::Linear F6;
	torch::nn::Linear OUTPUT;

public:
	LeNet5() 
		:
		// (inputchannel, outputchannel, kernel_size)
		C1(torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 6, 5).padding(2).padding_mode(torch::kCircular))),
		C3(torch::nn::Conv2d(6, 16, 5)),
		F5(torch::nn::Linear(16 * 5 * 5, 120)),
		F6(torch::nn::Linear(120, 84)),
		OUTPUT(torch::nn::Linear(84, 10))
	{
		register_module("C1", C1);
		register_module("C3", C3);
		register_module("F5", F5);
		register_module("F6", F6);
		register_module("OUTPUT", OUTPUT);
	}

	torch::Tensor forward(torch::Tensor input)
	{
		namespace F = torch::nn::functional;

		auto op_maxpool = F::MaxPool2dFuncOptions(2);

		// forward
		auto 
		x = F::max_pool2d(F::relu(C1(input)), op_maxpool);
		x = F::max_pool2d(F::relu(C3(x)), op_maxpool);
		x = x.view({ -1, num_flat_features(x) });
		x = F::relu(F5(x));
		x = F::relu(F6(x));
		x = OUTPUT(x);
		
		return x;
	}
	

private:
	long num_flat_features(torch::Tensor x) {
		// x.size()[1:]  # all dimensions except the batch dimension
		auto size = x.sizes(); // TODO: Revisit this
		auto num_features = 1;
		for (auto s : size) {
			num_features *= s;
		}
		return num_features;
	}
};

定义Dataset

class img_loader : public torch::data::datasets::Dataset<img_loader>
{
public:
    img_loader(std::string maping_filename, int channel, int height, int width);
    torch::data::Example<> get(size_t index);
    torch::optional<size_t> size() const;
    long get_num_classes();

    torch::Tensor img_to_tensor(cv::Mat img);

private:
    std::vector<std::string> images_list;
    std::vector<long> labels_list;
    long num_classes;
    int out_channels;
    int out_height;
    int out_width;
    int flag_open;
};
img_loader::img_loader(std::string maping_filename, int channel, int height, int width)
{
    out_channels = channel;
    out_height = height;
    out_width = width;

    assert(channel == 1 || channel == 3);

    flag_open = 0;
    if (channel > 1) flag_open = 1;

    std::ifstream input_file;
    try
    {
        input_file.open(maping_filename);
        if (input_file.is_open())
        {
            std::string image_path;
            long label;
            num_classes = 0;
            long current_class = -1;
            while (input_file >> image_path >> label)
            {
                if (current_class != label)
                {
                    num_classes++;
                    current_class = label;
                }

                images_list.push_back(image_path);
                labels_list.push_back(label);
            }
        }
        else
        {
            std::cerr << "Error: can't open file please make sure input_file path is correct\\n";
            exit(-2);
        }
    }
    catch (const std::exception& e)
    {
        std::cerr << e.what() << '\\n';
        exit(-3);
    }
}

torch::data::Example<> img_loader::get(size_t index)
{
    cv::Mat img = cv::imread(images_list[index], flag_open);
    auto tdata = img_to_tensor(img);
    auto tlabel = torch::from_blob(&labels_list[index], { 1 }, torch::kLong);

    return { tdata, tlabel };
}

torch::optional<size_t> img_loader::size() const
{
    return images_list.size();
}

long img_loader::get_num_classes()
{
    return num_classes;
}

torch::Tensor img_loader::img_to_tensor(cv::Mat img)
{
    assert(!img.empty());
    cv::resize(img, img, cv::Size(out_width, out_height));

    img.convertTo(img, CV_32FC1);
    img = (img - 127.5) / 128.0;

    if (img.channels() == 1)
    {
        auto channel = torch::from_blob(
            img.ptr(),
            { out_height, out_width },
            torch::kFloat);

        /*return torch::cat({ channel, channel, channel })
            .view({ 3, out_height, out_width })
            .to(torch::kFloat);*/

        return channel;
    }

    std::vector<cv::Mat> channels(3);
    cv::split(img, channels);

    auto R = torch::from_blob(
        channels[2].ptr(),
        { out_height, out_width },
        torch::kFloat);
    auto G = torch::from_blob(
        channels[1].ptr(),
        { out_height, out_width },
        torch::kFloat);
    auto B = torch::from_blob(
        channels[0].ptr(),
        { out_height, out_width },
        torch::kFloat);

    return torch::cat({ B, G, R })
        .view({ 3, out_height, out_width })
        .to(torch::kFloat);
}

训练流程

auto dataset2 = img_loader(R"(F:\\Workdatas\\mnist\\train.txt)", 3, 28, 28);
auto data_loader = torch::data::make_data_loader(std::move(dataset2), torch::data::DataLoaderOptions(1).workers(4));
auto criterion = torch::nn::CrossEntropyLoss();
auto optimizer = torch::optim::SGD(net.parameters(), torch::optim::SGDOptions(0.001).momentum(0.9));

训练循环

for(auto& batch : *data_loader)

optimizer.zero_grad();
auto outputs = net.forward(inputs);
auto loss = criterion(outputs, labels);

loss.backward();
optimizer.step();
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-04 17:31:26  更:2021-09-04 17:32:07 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 20:45:20-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码