PyTorch 与 C++ 互相调用
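只需安装 PyTorch 即可（torch.utils.cpp_extension.load 还依赖 ninja 和一个 C++ 编译器），已在 Linux 环境下测试通过。torch.utils.cpp_extension 借助 pybind11 实现 C++ 与 Python 之间的相互调用。它使用 ninja 构建系统以即时编译（JIT）的方式构建扩展：C++ 代码只在第一次运行时编译，之后若源码未改动，则直接复用已编译好的结果。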
代码示例
该示例展示了 Python 与 C++ 后端如何相互调用，以及如何在两者之间传递列表和 torch.Tensor。
CPP.cpp
#include <torch/extension.h>
#include <iostream>
#include <string>
#include <iterator>
/// A minimal value type holding a mutable name; bound to Python as CPP.Pet.
struct Pet {
    /// Creates a Pet called `name`.
    /// `explicit` so a std::string is never silently converted into a Pet.
    explicit Pet(const std::string &name) : name(name) { }

    /// Replaces the pet's name.
    void setName(const std::string &name_) { name = name_; }

    /// Returns the current name (a reference into this object; it is
    /// invalidated if the Pet is destroyed).
    const std::string &getName() const { return name; }

    std::string name;
};
// Vector of Pets; registered with Python as the PetList class in PYBIND11_MODULE.
using PetList = std::vector<Pet>;
// Declare std::vector<Pet> opaque: pybind11 will not convert it to/from a Python
// list by copying, so Python code that receives a PetList operates on the same
// C++ vector (this is what lets PY.addPet's push_back be seen from C++).
// Per pybind11 docs this must be at top level, before any binding code uses the type.
PYBIND11_MAKE_OPAQUE(std::vector<Pet>)
void addAndprintPet()
{
PetList petlist;
petlist.push_back(Pet("CatCpp"));
py::object addPet=py::module::import("PY").attr("addPet");
addPet(&petlist);
for (auto pet:petlist)
{
std::cout<<"from CPP "<<pet.getName()<<std::endl;
}
}
void printList()
{
py::list a;
a.append(123);
py::object addNumer=py::module::import("PY").attr("addNumer");
py::list b = addNumer(a);
for (auto number:b)
{
std::cout<<"from CPP "<<number.cast<int>()<<std::endl;
}
}
/// Sum of two tensors; exposed to Python as CPP.TensorAdd.
///
/// Calls Tensor::add, the same call the tensor `operator+` performs, so the
/// result follows torch's add semantics (including broadcasting).
torch::Tensor TensorAdd(const torch::Tensor &a, const torch::Tensor &b)
{
    torch::Tensor sum = a.add(b);
    return sum;
}
/// Entry point exposed to Python as CPP.mainFun.
///
/// Runs the two C++ -> Python demos in order: the PetList round-trip
/// (addAndprintPet), then the py::list round-trip (printList). If the first
/// demo throws, the exception propagates out and the second is not run.
void mainFun()
{
    using Demo = void (*)();
    const Demo demos[] = {&addAndprintPet, &printList};
    for (const Demo runDemo : demos)
    {
        runDemo();
    }
}
/// Python bindings for the extension module (named by TORCH_EXTENSION_NAME;
/// PY.py loads it as "CPP").
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Register Pet before PetList so that the signatures pybind11 generates
    // for PetList's methods (push_back takes a Pet) can refer to the
    // registered Python type instead of a raw C++ type name.
    py::class_<Pet>(m, "Pet")
        .def(py::init<const std::string &>())
        .def("setName", &Pet::setName)
        .def("getName", &Pet::getName)
        .def("__repr__", [](const Pet &u) { return u.getName(); });

    // Opaque std::vector<Pet>: Python holds a reference to the C++ vector
    // itself, so mutations are visible on both sides.
    py::class_<PetList>(m, "PetList")
        .def(py::init<>())
        // std::vector::pop_back on an empty vector is undefined behaviour
        // (a crash from Python); raise IndexError like list.pop() instead.
        .def("pop_back", [](PetList &v) {
            if (v.empty()) {
                throw py::index_error("pop_back from empty PetList");
            }
            v.pop_back();
        })
        // A lambda instead of casting &PetList::push_back: push_back is
        // overloaded, and taking the address of a standard-library member
        // function is not portable.
        .def("push_back", [](PetList &v, const Pet &pet) { v.push_back(pet); })
        .def("__len__", [](const PetList &v) { return v.size(); })
        .def("__iter__", [](PetList &v) {
            return py::make_iterator(v.begin(), v.end());
        }, py::keep_alive<0, 1>());  // keep the list alive while the iterator is in use

    m.def("mainFun", &mainFun, "mainFun");
    m.def("TensorAdd", &TensorAdd, "TensorAdd");
}
PY.py
import os
import torch
from torch.utils.cpp_extension import load
# Directory containing this file; CPP.cpp is expected to sit next to it.
# Named _here (not `dir`) so the builtin dir() is not shadowed in this module.
_here = os.path.dirname(os.path.realpath(__file__))

# JIT-compile CPP.cpp into an extension module named "CPP" and import it.
# Set verbose=True to see the build commands and compiler output.
CPP = load(
    name="CPP",
    sources=[os.path.join(_here, "CPP.cpp")],
    verbose=False)
def addPet(petlist):
    """Print every Pet in ``petlist``, then append a new Pet named 'CatPy'.

    Called from C++ (addAndprintPet) with a pointer to a C++ PetList. Since
    PetList is bound as an opaque type, the appended Pet is visible to the
    C++ caller afterwards.

    Args:
        petlist: a CPP.PetList (the C++ std::vector<Pet> exposed to Python).
    """
    # petlist.pop_back()  # optional: remove the last Pet first
    for pet in petlist:
        print('from PY', pet)  # printed via Pet.__repr__, i.e. the pet's name
    new_pet = CPP.Pet('CatPy')
    petlist.push_back(new_pet)
def addNumer(numlist):
    """Print ``numlist`` and return a new list with 1234 appended.

    Called from C++ (printList) with a py::list; the input list itself is
    left unchanged because ``+`` builds a new list.

    Args:
        numlist: a list of numbers.

    Returns:
        ``numlist + [1234]``.
    """
    print('from PY', numlist)
    suffix = [1234]
    return numlist + suffix
if __name__ == '__main__':
    # Run the C++ demos (addAndprintPet, printList); both call back into this
    # file by importing it as module "PY".
    CPP.mainFun()

    # Construct and use the C++-backed Pet class directly from Python.
    pet = CPP.Pet("Molly")
    print(pet)
    print(pet.getName())
    pet.setName("Charly")
    print(pet.getName())

    # Tensor addition performed in C++: zeros + ones.
    lhs = torch.zeros((3, 3))
    rhs = torch.ones((3, 3))
    print(CPP.TensorAdd(lhs, rhs))
运行结果（将 load 的 verbose 参数改为 True 时）：
|