tensor与vector的转换,是通过数据的指针来完成的。下面以ATen为例讲解,其他的如torch等只是命名空间不一样,其他的是一样的。
tensor转vector
#include<ATen/ATen.h>//引入头文件
#include<iostream>
using namespace std;
int main(){
at::Tensor t=at::ones({2,2},at::kInt);//建立一个2X2的tensor
vector<int> v(t.data_ptr<int>(),t.data_ptr<int>()+t.numel());//将tensor转换为vector
//输出转换后的结果
for(auto val:v){
cout<<val<<" ";
}
cout<<endl;
}
t是一个类型为at::kInt的tensor,其中kInt可以用其他数据类型替换如kFloat等,t.data_ptr<int>()返回int类型的指针,返回的地址是数据存储的起始位置。t.numel()返回t中的元素个数。
vector转tensor
#include<ATen/ATen.h>
#include<iostream>
using namespace std;
int main(){
vector<int> v={1,2,3,4};
at::TensorOptions opts=at::TensorOptions().dtype(at::kInt);
c10::IntArrayRef s={2,2};//设置返回的tensor的大小
at::Tensor t=at::from_blob(v.data(),s,opts).clone();
cout<<t<<endl;
}
opts用来对返回的tensor作一些额外的解释,例如类型。s用来指定返回的tensor的维度。clone是为了深复制,让后面对t的操作不会收到v的影响。v.data()返回vector中的数据的指针。
|