LibTorch 需要支持 C++14 的编译器(较新版本的 LibTorch 要求 C++17)。
从官网下载 LibTorch:https://pytorch.org/
调试的时候建议使用cpu+debug版本,到实际使用部署的时候再切换为cuda版本。
helloworld
#include "torch/library.h"
#include "torch/script.h"
// Minimal LibTorch smoke test: build a random 3x2 tensor and print it.
int main()
{
    // torch::randn fills the tensor with samples from a standard normal distribution.
    const auto sample = torch::randn({ 3, 2 });
    std::cout << sample;
    return 0;
}
output
-1.9173 -0.5073
1.5614 -0.0566
-0.0884 0.9237
[ CPUFloatType{3,2} ]
定义module
/// LeNet-5 style convolutional classifier built from LibTorch modules.
///
/// Layer stack: C1 (conv 3->6, 5x5, circular padding 2) -> ReLU -> 2x2 max-pool
///           -> C3 (conv 6->16, 5x5) -> ReLU -> 2x2 max-pool
///           -> F5 (400->120) -> ReLU -> F6 (120->84) -> ReLU -> OUTPUT (84->10).
///
/// F5 expects 16 * 5 * 5 flattened features per sample.
/// NOTE(review): that corresponds to 28x28 inputs — confirm against the caller.
class LeNet5 : public torch::nn::Module
{
public:
    torch::nn::Conv2d C1;
    torch::nn::Conv2d C3;
    torch::nn::Linear F5;
    torch::nn::Linear F6;
    torch::nn::Linear OUTPUT;

public:
    /// Builds all layers and registers them so that parameters() can see them.
    LeNet5()
        :
        // (inputchannel, outputchannel, kernel_size)
        C1(torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 6, 5).padding(2).padding_mode(torch::kCircular))),
        C3(torch::nn::Conv2d(6, 16, 5)),
        F5(torch::nn::Linear(16 * 5 * 5, 120)),
        F6(torch::nn::Linear(120, 84)),
        OUTPUT(torch::nn::Linear(84, 10))
    {
        register_module("C1", C1);
        register_module("C3", C3);
        register_module("F5", F5);
        register_module("F6", F6);
        register_module("OUTPUT", OUTPUT);
    }

    /// Runs the network on a batch.
    /// @param input image batch; C1 requires 3 input channels.
    /// @return one row of 10 raw scores (no softmax applied) per sample.
    torch::Tensor forward(torch::Tensor input)
    {
        namespace F = torch::nn::functional;
        const auto op_maxpool = F::MaxPool2dFuncOptions(2);

        auto x = F::max_pool2d(F::relu(C1(input)), op_maxpool);
        x = F::max_pool2d(F::relu(C3(x)), op_maxpool);
        // Flatten everything except the batch dimension: one row per sample.
        x = x.view({ x.size(0), num_flat_features(x) });
        x = F::relu(F5(x));
        x = F::relu(F6(x));
        x = OUTPUT(x);
        return x;
    }

private:
    /// Number of elements in one sample, i.e. the product of all dimensions
    /// except the leading (batch) dimension. Returns 1 for tensors with fewer
    /// than two dimensions.
    long num_flat_features(const torch::Tensor& x)
    {
        long num_features = 1;
        // Start at 1: dimension 0 is the batch dimension and must not be folded in.
        for (int dim = 1; dim < x.dim(); ++dim) {
            num_features *= static_cast<long>(x.size(dim));
        }
        return num_features;
    }
};
定义Dataset
/// Image dataset for the LibTorch data API.
///
/// The constructor reads a text mapping file whose rows are
/// "<image path> <integer label>"; get() loads the image for a row with OpenCV
/// and returns it together with its label.
class img_loader : public torch::data::datasets::Dataset<img_loader>
{
public:
    /// @param maping_filename path of the mapping file to read.
    /// @param channel number of image channels to load; must be 1 or 3.
    /// @param height  height every image is resized to.
    /// @param width   width every image is resized to.
    img_loader(std::string maping_filename, int channel, int height, int width);

    /// Loads the example at `index` as an (image tensor, label tensor) pair.
    torch::data::Example<> get(size_t index) override;

    /// Number of examples read from the mapping file.
    torch::optional<size_t> size() const override;

    /// Class count computed while reading the mapping file.
    long get_num_classes();

    /// Resizes and normalises an OpenCV image and converts it to a float tensor.
    torch::Tensor img_to_tensor(cv::Mat img);

private:
    std::vector<std::string> images_list; // image path of each example
    std::vector<long> labels_list;        // label of each example, parallel to images_list
    long num_classes = 0;
    int out_channels = 0;
    int out_height = 0;
    int out_width = 0;
    int flag_open = 0;                    // flag passed to cv::imread (0 for 1 channel, 1 for 3)
};
/// Reads the mapping file and records every (image path, label) row.
///
/// @param maping_filename text file with one "<image path> <label>" pair per row.
/// @param channel number of channels to load images with; must be 1 or 3.
/// @param height  target image height used by img_to_tensor.
/// @param width   target image width used by img_to_tensor.
///
/// Terminates the process with exit code -2 when the file cannot be opened and
/// with -3 when an exception is raised while reading it.
img_loader::img_loader(std::string maping_filename, int channel, int height, int width)
    : num_classes(0),
      out_channels(channel),
      out_height(height),
      out_width(width),
      // cv::imread flag: 0 loads grayscale, 1 loads 3-channel colour.
      flag_open(channel > 1 ? 1 : 0)
{
    assert(channel == 1 || channel == 3);

    std::ifstream input_file(maping_filename);
    if (!input_file.is_open())
    {
        std::cerr << "Error: can't open file please make sure input_file path is correct\n";
        exit(-2);
    }

    try
    {
        // Distinct labels seen so far. Counting distinct values (rather than
        // label changes between consecutive rows) keeps num_classes correct
        // even when the mapping file is not sorted by label. The linear scan
        // is over the set of classes, which is expected to be small.
        std::vector<long> seen_labels;
        std::string image_path;
        long label = 0;
        while (input_file >> image_path >> label)
        {
            bool already_seen = false;
            for (const long known : seen_labels)
            {
                if (known == label)
                {
                    already_seen = true;
                    break;
                }
            }
            if (!already_seen)
            {
                seen_labels.push_back(label);
            }
            images_list.push_back(image_path);
            labels_list.push_back(label);
        }
        num_classes = static_cast<long>(seen_labels.size());
    }
    catch (const std::exception& e)
    {
        std::cerr << e.what() << '\n';
        exit(-3);
    }
}
/// Loads one example.
///
/// @param index position in the lists filled by the constructor (not bounds-checked).
/// @return the image converted by img_to_tensor, and the label as a kLong tensor of shape {1}.
torch::data::Example<> img_loader::get(size_t index)
{
    cv::Mat img = cv::imread(images_list[index], flag_open);
    auto tdata = img_to_tensor(img);

    // torch::kLong is a 64-bit integer while `long` is only 32 bits on some
    // platforms, so copy the label into a 64-bit local before viewing it.
    // clone() then gives the tensor its own storage instead of a pointer
    // into a local or into labels_list.
    long long label_value = labels_list[index];
    auto tlabel = torch::from_blob(&label_value, { 1 }, torch::kLong).clone();

    return { tdata, tlabel };
}
/// Reports how many examples the dataset holds (one per stored image path).
torch::optional<size_t> img_loader::size() const
{
    const size_t example_count = images_list.size();
    return example_count;
}
long img_loader::get_num_classes()
{
return num_classes;
}
/// Converts an OpenCV image into a normalised float tensor.
///
/// The image is resized to out_width x out_height, converted to 32-bit float
/// and mapped with (pixel - 127.5) / 128.0.
///
/// @param img non-empty image (checked with assert only).
/// @return for a single-channel image, a tensor of shape {out_height, out_width};
///         otherwise a tensor of shape {3, out_height, out_width} whose planes
///         are in OpenCV split order (channels[0], channels[1], channels[2]).
///         The returned tensor always owns its data.
torch::Tensor img_loader::img_to_tensor(cv::Mat img)
{
    assert(!img.empty());
    cv::resize(img, img, cv::Size(out_width, out_height));
    // convertTo only changes the depth; the channel count of img is preserved.
    img.convertTo(img, CV_32F);
    img = (img - 127.5) / 128.0;

    // View over one image plane. from_blob does NOT copy or own the pixels,
    // so the result must be copied before the cv::Mat goes out of scope.
    auto plane_view = [this](cv::Mat& plane) {
        return torch::from_blob(
            plane.ptr(),
            { out_height, out_width },
            torch::kFloat);
    };

    if (img.channels() == 1)
    {
        // clone() detaches the tensor from img's buffer, which is released
        // when this function returns.
        return plane_view(img).clone();
    }

    std::vector<cv::Mat> channels(3);
    cv::split(img, channels);
    // torch::cat copies the three planes into fresh storage, so the result
    // does not depend on `channels` after this statement.
    return torch::cat({ plane_view(channels[0]), plane_view(channels[1]), plane_view(channels[2]) })
        .view({ 3, out_height, out_width });
}
训练流程
auto dataset2 = img_loader(R"(F:\\Workdatas\\mnist\\train.txt)", 3, 28, 28);
auto data_loader = torch::data::make_data_loader(std::move(dataset2), torch::data::DataLoaderOptions(1).workers(4));
auto criterion = torch::nn::CrossEntropyLoss();
auto optimizer = torch::optim::SGD(net.parameters(), torch::optim::SGDOptions(0.001).momentum(0.9));
训练循环
// One pass over the data loader: a forward/backward/update step per batch.
// Assumes the loader was created without a collating transform, so that each
// batch is a std::vector of examples — as in the make_data_loader call in this file.
for (auto& batch : *data_loader)
{
    // Collate the examples of this batch into one input and one label tensor.
    std::vector<torch::Tensor> batch_images;
    std::vector<torch::Tensor> batch_targets;
    batch_images.reserve(batch.size());
    batch_targets.reserve(batch.size());
    for (const auto& example : batch)
    {
        batch_images.push_back(example.data);
        batch_targets.push_back(example.target);
    }
    auto inputs = torch::stack(batch_images);  // adds a leading batch dimension
    auto labels = torch::cat(batch_targets);   // joins the shape-{1} labels into shape {N}

    optimizer.zero_grad();
    auto outputs = net.forward(inputs);
    auto loss = criterion(outputs, labels);
    loss.backward();
    optimizer.step();
}
|