IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 基于libtorch的迁移学习的C++实现 -> 正文阅读

[人工智能]基于libtorch的迁移学习的C++实现

基于libtorch的迁移学习的C++实现

本文探索了使用C++语言进行迁移学习的方法, 基于Resnet18实现的猫狗大战.

准备数据集

首先将数据集划分为训练集与测试集两部分, 训练集共4554张图片, 测试集共446张图片, 如下所示:

.
├── test
│   ├── cat
│   │   ├── cat.1000.jpg
│   │   ├── ...
│   │   └── cat.99.jpg
│   └── dog
│       ├── dog.1005.jpg
│       ├── ...
│       └── dog.988.jpg
└── train
    ├── cat
    │   ├── cat.0.jpg
    │   ├── ...
    │   └── cat.9.jpg
    └── dog
        ├── dog.0.jpg
        ├── ...
        └── dog.9.jpg

数据集加载

数据集是torch::data::Dataset的一个子类, 需要重写get和size函数, 完整类如下:

class CatDogDataset :  public torch::data::Dataset<CatDogDataset>
{
private:
	std::vector<torch::Tensor> images, labels;
public:
	explicit CatDogDataset(const std::string& path);
	torch::data::Example<> get(size_t index) override;
	[[nodiscard]] torch::optional<size_t> size() const override;

private:
	static std::vector<torch::Tensor> ProcessImages(const std::vector<std::string>& imageList);
	static std::vector<torch::Tensor> ProcessLabels(const std::vector<int>& labelList);
	void LoadDataFromFolder(const std::vector<std::string>& foldersPaths);
};

检测数据集目录下的全部图片, 并将图片转化为{224, 244, 3}的torch::kF32张量, 将标签转换为{1}的torch::kLong张量, 代码如下:

CatDogDataset::CatDogDataset(const string& path)
{
	string newPath;
	if(*(path.end()-1) != '/')
	{
		newPath = path + '/';
	}
	vector<string> paths{newPath+"cat", newPath+"dog"};
	LoadDataFromFolder(paths);
}
void CatDogDataset::LoadDataFromFolder(const std::vector<std::string> &foldersPaths)
{
	vector<string> imagesList;
	vector<int> labelsList;
	int label = 0;

	for(const auto& pathStr : foldersPaths)
	{
		filesystem::path path(pathStr);
		cout << "正在读取目录: " << pathStr << '/' << endl;
		if(!filesystem::exists(path))
		{
			cout << "该路径不存在: " << pathStr << '/' <<  endl;
			continue;
		}
		filesystem::directory_entry dir(path);
		if(dir.status().type() != filesystem::file_type::directory)
		{
			cout << "该路径对应的不是一个目录: " << pathStr << '/' <<  endl;
			continue;
		}
		filesystem::directory_iterator fileList(path);
		for(auto& item : fileList)
		{
			string fileName = item.path().filename();
			if(fileName.length() > 4 && fileName.substr(fileName.length() - 3) == "jpg")
			{
				imagesList.push_back(item.path().string());
				labelsList.push_back(label);
			}
		}
		label += 1;
	}
	images = ProcessImages(imagesList);
	labels = ProcessLabels(labelsList);
	cout << "数据集已建立, 包含" << labels.size() << "张图片." << endl;
}

vector<torch::Tensor> CatDogDataset::ProcessImages(const vector<string>& imageList)
{
	vector<torch::Tensor> states;
	for(auto& image : imageList)
	{
		Mat img = cv::imread(image, IMREAD_UNCHANGED);
		resize(img, img, cv::Size(224, 224), cv::INTER_CUBIC);
		torch::Tensor imgTensor = torch::from_blob(img.data, {img.rows, img.cols, 3}, torch::kByte);
		imgTensor = imgTensor.toType(torch::kF16);
		imgTensor = imgTensor.permute({2, 0, 1});
		states.push_back(imgTensor.clone());
	}
	return states;
}

vector<torch::Tensor> CatDogDataset::ProcessLabels(const vector<int>& labelList)
{
	vector<torch::Tensor> labels;
	for(auto label : labelList)
	{
		labels.push_back(torch::full({1}, label, torch::kLong));
	}
	return labels;
}

完成get与size函数:

torch::data::Example<> CatDogDataset::get(size_t index)
{
	torch::Tensor sampleImg = images.at(index);
	torch::Tensor sampleLabel = labels.at(index);
	return {sampleImg.clone(), sampleLabel.clone()};
}

torch::optional<size_t> CatDogDataset::size() const
{
	return labels.size();
}

准备训练的网络

准备去掉最后一层的Resnet18网络

使用python与pytorch加载Resnet18网络, 去掉最后一层并导出, 代码如下:

import torch
from torchvision import models

model = models.resnet18(pretrained=True)

for param in model.parameters():
	param.requires_grad = False

resnet18 = torch.nn.Sequential(*list(model.children())[:-1])

example_input = torch.rand(1, 3, 224, 224)
script_module = torch.jit.trace(resnet18, example_input)
script_module.save('resnet18_without_last_layer.pt')

加载网络并增加一个新的全连接层

libtorch中网络是torch::nn::Module的子类, 增加新的一层全连接层本身即为torch::nn::Module的子类, 不在单独完成一个torch::nn::Module的子类.

代码实现如下:

auto resnet18WithoutLastLayer = torch::jit::load("model/resnet18_without_last_layer.pt");
torch::nn::Linear lastLayer(512, 2);
torch::optim::Adam opt(lastLayer->parameters(), torch::optim::AdamOptions(1e-3));

完成训练函数

将网络与数据转至torch::kCUDA设备与torch::kF32格式, 从data_loader中加载数据进行进行正向传递, 对误差进行反向传递, 训练五轮, 当模型准确度提高时, 保存模型.

代码实现如下:

template<typename Dataloader>
void train(torch::jit::script::Module net, torch::nn::Linear lin, Dataloader& data_loader, torch::optim::Optimizer& optimizer, size_t dataset_size, torch::Device device)
{
	float best_accuracy = 0.0;
	int batch_index = 0;

	net.to(device);
	lin->to(device);
	net.to(torch::kF32);
	lin->to(torch::kF32);

	for(int i=0; i < 5; i++)
	{
		float mse = 0;
		float Acc = 0.0;

		for(auto& batch: *data_loader)
		{
			auto data = batch.data;
			auto target = batch.target.squeeze();

			data = data.to(torch::kF32).to(device);
			target = target.to(torch::kLong).to(device);

			std::vector<torch::jit::IValue> input;
			input.push_back(data);
			optimizer.zero_grad();

			auto output = net.forward(input).toTensor();
			output = output.view({output.size(0), -1});
			output = lin(output);

			auto loss = torch::nll_loss(torch::log_softmax(output, 1), target);

			loss.backward();
			optimizer.step();

			auto acc = output.argmax(1).eq(target).sum();

			Acc += acc.template item<float>();
			mse += loss.template item<float>();

			batch_index += 1;
		}

		mse = mse/float(batch_index);
		std::cout << "Epoch: " << i  << ", " << "Accuracy: " << Acc/dataset_size << ", " << "MSE: " << mse << std::endl;

		if(Acc/dataset_size > best_accuracy)
		{
			best_accuracy = Acc/dataset_size;
			std::cout << "Saving model" << std::endl;
			net.save("model/model.pt");
			torch::save(lin, "model/model_linear.pt");
		}
	}
}

完成测试函数

从测试集中加载数据, 转至torch::kCUDA设备与torch::kF32格式, 进行正向传递, 对输出求最大参数, 与标签进行对比, 将正确的数量与测试集大小进行对比, 得到准确率.

代码实现如下:

template<typename Dataloader>
void test(torch::jit::script::Module network, torch::nn::Linear lin, Dataloader& loader, size_t data_size, torch::Device device)
{
	network.eval();
	network.to(device);
	lin->to(device);
	network.to(torch::kF32);
	lin->to(torch::kF32);

	float Acc = 0;

	for (const auto& batch : *loader)
	{
		auto data = batch.data;
		auto targets = batch.target.squeeze();

		data = data.to(torch::kF32).to(device);
		targets = targets.to(torch::kLong).to(device);

		std::vector<torch::jit::IValue> input;
		input.push_back(data);

		auto output = network.forward(input).toTensor();
		output = output.view({output.size(0), -1});
		output = lin(output);

		auto acc = output.argmax(1).eq(targets).sum();
		Acc += acc.template item<float>();
	}

	std::cout << "Accuracy:" << Acc/data_size << std::endl;
}

进行训练

书写主函数, 加载数据集, 进行训练, 代码实现如下:

int main(int argc, char** argv)
{
	auto trainDataset = CatDogDataset("image/train").map(torch::data::transforms::Stack<>());
	int trainDatasetSize = trainDataset.size().value();

	torch::manual_seed(1);

	auto resnet18WithoutLastLayer = torch::jit::load("model/resnet18_without_last_layer.pt");
	torch::nn::Linear lastLayer(512, 2);
	torch::optim::Adam opt(lastLayer->parameters(), torch::optim::AdamOptions(1e-3));

	auto trainLoader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(std::move(trainDataset), 256);

	train(resnet18WithoutLastLayer, lastLayer, trainLoader, opt, trainDatasetSize, torch::kCUDA);
	return 0;
}

进行测试

书写主函数, 加载数据集, 进行测试.

int main(int argc, char** argv)
{
	auto testDataset = CatDogDataset("image/test").map(torch::data::transforms::Stack<>());
	int testDatasetSize = testDataset.size().value();

	torch::jit::script::Module model;
	model = torch::jit::load("model/model.pt");

	torch::nn::Linear model_linear(512, 2);
	torch::load(model_linear, "model/model_linear.pt");

	auto testLoader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(std::move(testDataset), 256);

	test(model, model_linear, testLoader, testDatasetSize, torch::kCUDA);

	return 0;
}

附录

完整工程文件见gitee仓库

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-02-19 01:09:21  更:2022-02-19 01:09:50 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 11:07:28-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码