1 TensorRT C++ API支持的模型输入维度
在TensorRT 7.0及以上版本,我们通常使用以下语句指定输入维度:
const std::string input_name = "input";
const std::string output_name = "output";
const int inputIndex = m_TensorRT_Engine->getBindingIndex(input_name.c_str());
const int outputIndex = m_TensorRT_Engine->getBindingIndex(output_name.c_str());
m_TensorRT_Context->setBindingDimensions(inputIndex, Dims3(3, 100, 20));
其中Dims3代表该深度学习模型的输入Tensor的维度为三维tensor,shape为(3,100,20)
一般的深度学习模型,一般的输入维度为(C,H,W),这种输入的维度数据为三维tensor。
另外TensorRT C++ API最高支持Dims4,用于支持4维tensor数据的模型输入。但是随着深度学习框架目前发展的越来越复杂,更多的深度的学习模型需要5维,6维甚至更高维度的tensor作为网络输入,那么如何在现有的TensorRT API去扩展更高维度的输入tensor以满足我们自己的需要呢?
2 扩展TensorRT C++ API 模型输入维度
在TensorRT C++ API的include目录下的NvInferRuntimeCommon.h文件定义了类Class Dims32,
class Dims32
{
public:
static constexpr int32_t MAX_DIMS{8};
int32_t nbDims;
int32_t d[MAX_DIMS];
};
该类用于定义tensor的输入维度,从类定义上看,该类支持的最大维度为8。
在TensorRT C++ API的include目录下的NvInferLegacyDims.h定义了目前TensorRT所指的输入维度:
#ifndef NV_INFER_LEGACY_DIMS_H
#define NV_INFER_LEGACY_DIMS_H
#include "NvInferRuntimeCommon.h"
namespace nvinfer1
{
class Dims2 : public Dims
{
public:
Dims2()
: Dims{2, {}}
{
}
Dims2(int32_t d0, int32_t d1)
: Dims{2, {d0, d1}}
{
}
};
class DimsHW : public Dims2
{
public:
DimsHW()
: Dims2()
{
}
DimsHW(int32_t height, int32_t width)
: Dims2(height, width)
{
}
int32_t& h()
{
return d[0];
}
int32_t h() const
{
return d[0];
}
int32_t& w()
{
return d[1];
}
int32_t w() const
{
return d[1];
}
};
class Dims3 : public Dims
{
public:
Dims3()
: Dims{3, {}}
{
}
Dims3(int32_t d0, int32_t d1, int32_t d2)
: Dims{3, {d0, d1, d2}}
{
}
};
class Dims4 : public Dims
{
public:
Dims4()
: Dims{4, {}}
{
}
Dims4(int32_t d0, int32_t d1, int32_t d2, int32_t d3)
: Dims{4, {d0, d1, d2, d3}}
{
}
};
}
#endif
从上述文件的代码看,构建输入维度只需要继承类Dims,然后按定义进行初始化即可。所以为了TensortRT可以支持Dims5,Dims6,Dims7,Dims8等高输入维度,那么需要自定义扩展以上维度,扩展后的NvInferLegacyDims.h文件内容如下所示:
#ifndef NV_INFER_LEGACY_DIMS_H
#define NV_INFER_LEGACY_DIMS_H
#include "NvInferRuntimeCommon.h"
namespace nvinfer1
{
class Dims2 : public Dims
{
public:
Dims2()
: Dims{2, {}}
{
}
Dims2(int32_t d0, int32_t d1)
: Dims{2, {d0, d1}}
{
}
};
class DimsHW : public Dims2
{
public:
DimsHW()
: Dims2()
{
}
DimsHW(int32_t height, int32_t width)
: Dims2(height, width)
{
}
int32_t& h()
{
return d[0];
}
int32_t h() const
{
return d[0];
}
int32_t& w()
{
return d[1];
}
int32_t w() const
{
return d[1];
}
};
class Dims3 : public Dims
{
public:
Dims3()
: Dims{3, {}}
{
}
Dims3(int32_t d0, int32_t d1, int32_t d2)
: Dims{3, {d0, d1, d2}}
{
}
};
class Dims4 : public Dims
{
public:
Dims4()
: Dims{4, {}}
{
}
Dims4(int32_t d0, int32_t d1, int32_t d2, int32_t d3)
: Dims{4, {d0, d1, d2, d3}}
{
}
};
class Dims5 : public Dims
{
public:
Dims5()
{
nbDims = 5;
for (int32_t i = 0; i < MAX_DIMS; ++i)
{
d[i] = 0;
}
}
Dims5(int32_t d0, int32_t d1, int32_t d2, int32_t d3, int32_t d4)
{
nbDims = 5;
d[0] = d0;
d[1] = d1;
d[2] = d2;
d[3] = d3;
d[4] = d4;
for (int32_t i = nbDims; i < MAX_DIMS; ++i)
{
d[i] = 0;
}
}
};
class Dims6 : public Dims
{
public:
Dims6()
{
nbDims = 6;
for (int32_t i = 0; i < MAX_DIMS; ++i)
{
d[i] = 0;
}
}
Dims6(int32_t d0, int32_t d1, int32_t d2, int32_t d3, int32_t d4, int32_t d5)
{
nbDims = 6;
d[0] = d0;
d[1] = d1;
d[2] = d2;
d[3] = d3;
d[4] = d4;
d[5] = d5;
for (int32_t i = nbDims; i < MAX_DIMS; ++i)
{
d[i] = 0;
}
}
};
class Dims7 : public Dims
{
public:
Dims7()
{
nbDims = 7;
for (int32_t i = 0; i < MAX_DIMS; ++i)
{
d[i] = 0;
}
}
Dims7(int32_t d0, int32_t d1, int32_t d2, int32_t d3, int32_t d4, int32_t d5, int32_t d6)
{
nbDims = 7;
d[0] = d0;
d[1] = d1;
d[2] = d2;
d[3] = d3;
d[4] = d4;
d[5] = d5;
d[6] = d6;
for (int32_t i = nbDims; i < MAX_DIMS; ++i)
{
d[i] = 0;
}
}
};
class Dims8 : public Dims
{
public:
Dims8()
{
nbDims = 8;
for (int32_t i = 0; i < MAX_DIMS; ++i)
{
d[i] = 0;
}
}
Dims8(int32_t d0, int32_t d1, int32_t d2, int32_t d3, int32_t d4, int32_t d5, int32_t d6, int32_t d7)
{
nbDims = 8;
d[0] = d0;
d[1] = d1;
d[2] = d2;
d[3] = d3;
d[4] = d4;
d[5] = d5;
d[6] = d6;
d[7] = d7;
for (int32_t i = nbDims; i < MAX_DIMS; ++i)
{
d[i] = 0;
}
}
};
}
#endif
将NvInferLegacyDims.h修改之后,重新编译即可使用所扩展的Dims5、Dims6、Dims7、Dims8的5维,6维,7维,8维网络输入维度。
如果有兴趣,可以访问我的个站:https://www.stubbornhuang.com/,更多干货!
|