前言
这篇文章的内容主要还是基于 EdWard z. yang 的 Let’s talk about the PyTorch dispatcher 来梳理一下 Pytorch dispatcher 相关的内容学习以及源码阅读。其中会涉及到很多类和内容,也会出现很多源代码,所以文章篇幅会很长,主要还是记录一下我在阅读源码,包括梳理类之间关系的一个过程,源代码中我基本都标注了所在文件位置,方便有兴趣的读者可以沿着我的这个过程一起探索神秘的 dispatcher。
概念介绍
dispatcher 可以理解为分发器,可以根据关于 tensor 输入的一些信息来决定要调用哪一块的程序。其主要是通过分发表(dispatch table)的形式来实现的,如下图:
分发表中包含了相关的 dispatch key 和对应的函数指针,可以看到 dispatch key 不仅有硬件后端,例如 CPU,GPU 等,还有一些更抽象的概念, 例如 autograd 和 tracing。dispatcher 的工作就是根据输入的 tensor 还有一些其他因素(参数个数,返回值类型等)去计算得到一个 dispatch key ,然后根据 dispatch table 找到对应对函数指针,一般我们也称为 kernel,然后间接跳转调用它。
概念介绍起来可能上面的一小段话就可以把 dispatcher 介绍完毕了,但是深入细节,其 dispatch key 是何种形式表现的?如何去进行 dispatch key 计算的?包括如何将它融入到 Pytorch 整个系统等问题,还是有很多内容值得一说的。
diapatch key 的表示和计算
diapatch key 的计算是通过一个 dispatch key set 的结构来实现的, dispatch key set 可以理解为一个 64bit 的数组,每一个 bit 都代表了一个 dispatch key,然后从左到右有优先级关系,同样一个算子,可能有针对不同 dispatch key 的实现,针对这些散落在 Pytorch各处的注册实现,然后也可能会将某些 key 排除的情况,这个数组通过最终将关于同一个算子的所有可用实现都结合到一起,调用其中优先级最高的 dispatch key 对应的 kernel 实现。
dispatch table 注册
下一个问题就是这些函数指针是如何出现在 dispatch table 中的呢?这个主要是在算子注册的时候一步步添加到 dispatch table 里面的,主要是使用了 C++ 代码,注册算子的整个过程官方有对应的文档 。主要算子注册有一下 3 种算子注册交互方式:
首先需要 m.def 定义一个关于算子的 schema(后面会介绍),然后 m.impl 将带有 dispatch key 信息的算子实现注册到 dispatch table 中,最后还有一个名为 m.fallback 的方式,在为所有的算子都注册上同一个 dispatch key。想象一下 dispatch table 和不同的算子间的对应关系可以表示为网格形式,当我们使用 m.impl 注册一个算子实现时就会传入相关的 函数指针 以及 diapatch key,假设这里我们用 C++ 实现了一个 cpu kernel 的 aten::mul 算子,对应到下图就是:
Pytorch 中也可以使用 ”catch-all“ 的操作来将同一个 kernel 注册到所有 dispatch key 上:
或者可以通过 fallback 为所有算子添加一个 dispatch key:
对于上面不同的注册形式,会有一个优先级的关系, 特定的实现 > catch all > fallback。
boxing 和 unboxing
更进一步,谈到函数调用,需要介绍另一个 dispatcher 中另一个重要的概念,boxing 和 unboxing,我们可以从数据结构的角度来对这个概念进行理解。在 C++ 中,我们都知道数据有不同的数据类型,int,float,double 等,包括一些类对象,这些多种多样的数据类型我们就可以理解为是一种 Unboxing 的行为,在 Pytorch 中定义了一种叫 IValue 的数据结构,它可以用来表示很多种数据类型,对外的表现就都是 IValue 这一种类型,这样将很多种元素打包,对外表现成一种就可以理解为是 boxing 的行为。那么对于 unboxing 的输入和 boxing 的输入,Pytorch 自然就要根据输入来进行不同形式的调用/转换(从 boxing 转换成 unboxing,或者从 unboxing 转换成 boxing),这部分将在下面源码部分做详细介绍。
源码分析
在源码部分,主要还是介绍类及各个类之间的关系为主,本文尝试由点及面的来进行源码阅读。所以还要从 Pytorch 统一的数据结构 IValue 类(interpreter value)说起。
IValue 类
前面我们已经提到,它是 Pytorch 中定义对数据的一个统一表达。从概念上,一个 16-byte IValue 类型由 3 个字段组成,一个 8-byte 的payload 类型,可以简单理解为指向相关数据的指针,4-byte 的 tag 则是表示 Ivalue 中包含的值是何种类型,最后一个是 1-byte 的 bool 类型,说明是否是 intrusive_ptr。
struct TORCH_API IValue final {
IValue(const IValue& rhs)
: IValue(rhs.payload, rhs.tag, rhs.is_intrusive_ptr) {
if (is_intrusive_ptr && payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr);
}
}
...
};
其中 Payload 是一个嵌套联合体,为了使非 tensor 类拷贝起来更简单快捷,其定义为:
union Payload {
union TriviallyCopyablePayload {
TriviallyCopyablePayload() : as_int(0) {}
int64_t as_int;
double as_double;
bool as_bool;
c10::intrusive_ptr_target* as_intrusive_ptr;
struct {
DeviceType type;
DeviceIndex index;
} as_device;
} u;
at::Tensor as_tensor;
Payload() : u() {}
~Payload() {}
};
接下来是 Tag 则是对 IValue 可以包含数据类型的一个枚举:
enum class Tag : uint32_t {
#define DEFINE_TAG(x) x,
TORCH_FORALL_TAGS(DEFINE_TAG)
#undef DEFINE_TAG
};
#define TORCH_FORALL_TAGS(_) \
_(None) \
_(Tensor) \
_(Storage) \
_(Double) \
_(ComplexDouble) \
_(Int) \
_(Bool) \
_(Tuple) \
_(String) \
_(Blob) \
_(GenericList) \
_(GenericDict) \
_(Future) \
_(Device) \
_(Stream) \
_(Object) \
_(PyObject) \
_(Uninitialized) \
_(Capsule) \
_(RRef) \
_(Quantizer) \
_(Generator) \
_(Enum)
可以看到 IValue 可以包含很多种不同的类型,在 IValue 的定义中,设置了对应类型来初始化或者获取其中真实类型的操作相关:
IValue(at::Tensor t) : tag(Tag::Tensor), is_intrusive_ptr(false) {
new (&payload.as_tensor) at::Tensor(std::move(t));
}
bool isTensor() const {
return Tag::Tensor == tag;
}
at::Tensor toTensor() &&;
at::Tensor& toTensor() &;
const at::Tensor& toTensor() const&;
IValue(double d) : tag(Tag::Double), is_intrusive_ptr(false) {
payload.u.as_double = d;
}
bool isDouble() const {
return Tag::Double == tag;
}
double toDouble() const {
AT_ASSERT(isDouble());
return payload.u.as_double;
}
在 aten/src/ATen/core/ivalue_inl.h 中有对应成员函数的具体实现,感兴趣的读者可以自行阅读。通过 IValue 的统一数据表示,引出了 pytorch 中 boxing 和 unboxing 的概念。我们将 IValue 就可以看做是 Boxing,字面理解就是它会把各种各样的数据类型都打包起来,对外看起来是一致的,而相对的各种各样的类型就是 Unboxing 了。因为 unboxing 是对很多不同类型的统称,所以一般的 unboxing 形式的函数都是采用了模板形式来实现的。
schema
在 Pytorch 中一个算子都要有一个对应的 schema,基本上所有算子的schema 都定义在了 aten/src/ATen/native/native_functions.yaml 文件中,以字符串的形式呈现,下面以 torch.add 为例:
– func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
在 native_functions.yaml 中可以找到上述定义,最终通过脚本分析和代码生成,这个算子定义会被翻译成下面的形式:
namespace at {
TORCH_LIBRARY(aten, m) {
m.def("aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor");
}
这里就是上面提到的算子注册到 dispatch table 的第一步?🏻。这里我们可以将 schema 看做是对一个算子的总体描述,这个描述包含了:算子名称,输入个数和类型,参数个数个类型,返回值类型等信息。继续深入,这里的字符串会在 m.def 里面被解析成 FunctionSchema 的类对象,让我们来看看相关源码:
template <typename Schema>
Library& def(Schema&& raw_schema) & {
c10::FunctionSchema s = schema(std::forward<Schema>(raw_schema));
return _def(std::move(s));
}
inline c10::FunctionSchema schema(const char* s) {
return schema(s, c10::AliasAnalysisKind::FROM_SCHEMA);
}
inline c10::FunctionSchema schema(const char* str, c10::AliasAnalysisKind k) {
c10::FunctionSchema s = torch::jit::parseSchema(str);
s.setAliasAnalysis(k);
return s;
}
继续:
C10_EXPORT FunctionSchema parseSchema(const std::string& schema) {
auto parsed = parseSchemaOrName(schema);
return parsed.right();
}
C10_EXPORT either<OperatorName, FunctionSchema> parseSchemaOrName(
const std::string& schemaOrName) {
return SchemaParser(schemaOrName).parseDeclarations().at(0);
}
讲过上面层层调用,最后在 SchemaParser(const std::string& str) ,函数中对具体的 schema 字符串进行解析,最终返回的是 FunctionSchema 类型。下面看一下 FunctionSchema 类型的定义:
struct FunctionSchema {
FunctionSchema(
std::string name,
std::string overload_name,
std::vector<Argument> arguments,
std::vector<Argument> returns,
bool is_vararg = false,
bool is_varret = false)
: name_({std::move(name), std::move(overload_name)}),
arguments_(std::move(arguments)),
returns_(std::move(returns)),
is_vararg_(is_vararg),
is_varret_(is_varret) {
checkSchema();
}
...
};
从上面的初始化参数可以看出字符串根据其内容,被解析出了对算子的一个更具体的描述。这里简单为 name 和 overload_name 的区别做一个说明,上述 add 算子的 name 为 aten::name ,overload_name 为空。
-func: arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)
上述是定义在 native_functions.yaml 中的 arange 算子的某一个 schema,这里解析出来的 name 就是 aten::arange ,overload_name 则是 start_out ,所以一个算子定义的全名为:name+"."+overload_name 。所以最终我们字符串定义的 schema 会变成 c10::FunctionSchema 这个类,在后面的 dispatch 中会起到很大的索引作用。
OperatorHandle
前面介绍数据类型和算子的定义,下面就是要怎么实现算子和使用算子了。这个 Handle 类主要是用来处理一些已经注册了 schema 的算子,其主要有接口可以查询到已经注册 op 的 operator_name,包括查询/返回它的 FunctionSchema,还有就是注册函数的调用,其主要通过 OperatorDef 类及相关 list 迭代器实现,对外的接口。
class TORCH_API OperatorHandle {
public:
const OperatorName& operator_name() const {
return operatorDef_->op.operator_name();
}
bool hasSchema() const {
return operatorDef_->op.hasSchema();
}
const FunctionSchema& schema() const {
return operatorDef_->op.schema();
}
...
void callBoxed(Stack* stack) const {
c10::Dispatcher::singleton().callBoxed(*this, stack);
}
template<class FuncType>
TypedOperatorHandle<FuncType> typed() const {
return TypedOperatorHandle<FuncType>(operatorIterator_);
}
void redispatchBoxed(DispatchKeySet ks, Stack* stack) const {
c10::Dispatcher::singleton().redispatchBoxed(*this, ks, stack);
}
private:
Dispatcher::OperatorDef* operatorDef_;
std::list<Dispatcher::OperatorDef>::iterator operatorIterator_;
};
这里的 TypedOperatorHandle 则是与 OperatorHandle 同样的功能,只是把 op 的参数模板化了,并且可以用 unboxed 的方式调用相关实现函数:
template<class Return, class... Args>
class TypedOperatorHandle<Return (Args...)> final : public OperatorHandle {
public:
C10_ALWAYS_INLINE Return call(Args... args) const {
return c10::Dispatcher::singleton().call<Return, Args...>(*this, std::forward<Args>(args)...);
}
C10_ALWAYS_INLINE Return redispatch(DispatchKeySet currentDispatchKeySet, Args... args) const {
return c10::Dispatcher::singleton().redispatch<Return, Args...>(*this, currentDispatchKeySet, std::forward<Args>(args)...);
}
private:
explicit TypedOperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
: OperatorHandle(operatorIterator) {}
friend class OperatorHandle;
};
可以看到其就定义了两个 unbox 形式的调用函数。
KernelFunction
根据官方介绍这个类相当于 std::function,但是表示的是算子的 kernel 函数,可以从一个 Boxed/unboxed 的函数,仿函数,lambda 函数创建一个 kernekFunction,包括有需要它会自动适配到对应 boxed/unboxed。
class TORCH_API KernelFunction final {
public:
using InternalBoxedKernelFunction = void(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
using BoxedKernelFunction = void(const OperatorHandle&, Stack*);
using BoxedKernelFunction_withDispatchKeys = void(const OperatorHandle&, DispatchKeySet, Stack*);
void callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const;
template<class Return, class... Args>
Return call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const;
template<BoxedKernelFunction* func>
static KernelFunction makeFromBoxedFunction();
template<class FuncPtr, bool AllowLegacyTypes = false>
static KernelFunction makeFromUnboxedFunction(FuncPtr);
template<bool AllowLegacyTypes = false, class KernelFunctor>
static KernelFunction makeFromUnboxedFunctor(std::unique_ptr<OperatorKernel> kernelFunctor);
....
private:
OperatorKernel* getFunctor_() const;
std::shared_ptr<OperatorKernel> functor_;
InternalBoxedKernelFunction* boxed_kernel_func_;
void* unboxed_kernel_func_;
};
KernelFunction 更多的成员函数实现是在 aten/src/ATen/core/boxing/KernelFunction_impl.h 中,有兴趣的读者可以做额外展开阅读。
OperatorEntry
这个类是内部使用的,用户一般是不会直接访问到,是更底层记录算子的一些信息的,上层的相关类都是依赖 OperatorEntry 类实现的。
class TORCH_API OperatorEntry final {
public:
explicit OperatorEntry(OperatorName&& operator_name);
const FunctionSchema& schema() const {
return schema_->schema;
}
void registerSchema(FunctionSchema&&, std::string&& debug);
void deregisterSchema();
std::list<AnnotatedKernel>::iterator registerKernel(
const Dispatcher& dispatcher,
c10::optional<DispatchKey> dispatch_key,
KernelFunction kernel,
c10::optional<CppSignature> cpp_signature,
std::unique_ptr<FunctionSchema> inferred_function_schema,
std::string debug
);
const KernelFunction& lookup(DispatchKey k) const {
const auto& kernel = dispatchTable_[static_cast<uint8_t>(k)];
if (C10_UNLIKELY(!kernel.isValidUnboxed())) {
if (!kernel.isValid()) {
reportError(k);
}
}
return kernel;
}
...
private:
OperatorName name_;
c10::optional<AnnotatedSchema> schema_;
std::array<KernelFunction, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)> dispatchTable_;
DispatchKeyExtractor dispatchKeyExtractor_;
ska::flat_hash_map<DispatchKey, std::list<AnnotatedKernel>> kernels_;
...
};
Dispatcher
这个是动态分发的主要类,但是不是用户直接可以用的。注册相关 op 函数kernel 可以使用 aten/src/ATen/core/op_registration/op_registration.h 的 RegisterOperators 类。
class TORCH_API Dispatcher final {
private:
friend class impl::OperatorEntry;
struct OperatorDef final {
explicit OperatorDef(OperatorName&& op_name)
: op(std::move(op_name)) {}
impl::OperatorEntry op;
size_t def_count = 0;
size_t def_and_impl_count = 0;
};
friend class OperatorHandle;
template<class> friend class TypedOperatorHandle;
public:
~Dispatcher();
static Dispatcher& realSingleton();
C10_ALWAYS_INLINE static Dispatcher& singleton() {
return realSingleton();
}
c10::optional<OperatorHandle> findSchema(const OperatorName& operator_name);
OperatorHandle findSchemaOrThrow(const char* name, const char* overload_name);
c10::optional<OperatorHandle> findOp(const OperatorName& operator_name);
template<class Return, class... Args>
Return call(const TypedOperatorHandle<Return (Args...)>& op, Args... args) const;
template<class Return, class... Args>
static Return callWithDispatchKeySlowPath(const TypedOperatorHandle<Return (Args...)>& op, bool pre_sampled, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args);
void callBoxed(const OperatorHandle& op, Stack* stack) const;
RegistrationHandleRAII registerDef(FunctionSchema schema, std::string debug);
RegistrationHandleRAII registerImpl(OperatorName op_name, c10::optional<DispatchKey> dispatch_key, KernelFunction kernel, c10::optional<impl::CppSignature> cpp_signature, std::unique_ptr<FunctionSchema> inferred_function_schema, std::string debug);
...
};
内部调用逻辑
针对不同的算子 schem,它们都有一个最终的算子调用过程,内部实现一个很常规的方法是使用 dispatcher 类的方法,下面以 abs 算子为例,因为其输入是 unboxed 的 tensor 类型,所以需要以 Unboxed 的形式:
at::Tensor & Tensor::abs_() const {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("aten::abs_", "")
.typed<at::Tensor & (at::Tensor &)>();
return op.call(const_cast<Tensor&>(*this));
}
下面我们就深入追踪一下这个 op.call 函数:
C10_ALWAYS_INLINE Return call(Args... args) const {
return c10::Dispatcher::singleton().call<Return, Args...>(*this, std::forward<Args>(args)...);
}
diapatcher 类对 call 的定义为:
template<class Return, class... Args>
C10_DISPATCHER_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorHandle<Return(Args...)>& op, Args... args) const {
detail::unused_arg_(args...);
auto dispatchKeySet = op.operatorDef_->op.dispatchKeyExtractor()
.template getDispatchKeySetUnboxed<Args...>(args...);
const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet.highestPriorityTypeId());
return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
}
最后又调用了 KernelFunction 中的 call 函数:
template<class Return, class... Args>
C10_ALWAYS_INLINE Return KernelFunction::call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const {
if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
return callUnboxedKernelFunction<Return, Args...>(unboxed_kernel_func_, functor_.get(), dispatchKeySet, std::forward<Args>(args)...);
}
return impl::BoxedKernelWrapper<Return(Args...)>::call(
boxed_kernel_func_,
functor_.get(),
opHandle,
dispatchKeySet,
std::forward<Args>(args)...
);
}
如果是 unboxed 的行为:
template<class Return, class... Args>
inline Return callUnboxedKernelFunction(void* unboxed_kernel_func, OperatorKernel* functor, DispatchKeySet dispatchKeySet, Args&&... args) {
using ActualSignature = Return (OperatorKernel*, DispatchKeySet, Args...);
ActualSignature* func = reinterpret_cast<ActualSignature*>(unboxed_kernel_func);
return (*func)(functor, dispatchKeySet, std::forward<Args>(args)...);
}
如果是 boxed 的行为:
static Result call(
KernelFunction::InternalBoxedKernelFunction* boxed_kernel_func,
OperatorKernel* functor,
const OperatorHandle& opHandle,
DispatchKeySet dispatchKeySet,
Args... args
) {
torch::jit::Stack stack = boxArgs<Args...>(std::forward<Args>(args)...);
(*boxed_kernel_func)(functor, opHandle, dispatchKeySet, &stack);
......
);
}
所以整理上述过程,要调用一个算子底层实现的过程是:
- 通过 dispatcher 类 + op name 的形式来查找对应的算子 schema 。因为 schema 中定义了相关的算子输入、输出、参数的相关信息。
- 其中 FunctionSchema 类只是记录,想要具体访问还是要使用 OperatorHandle 类,所以上面返回的是 OperatorHandle 类对象。
- 因为算子的输入基本都是 scalar / tensor 这种 Unboxed 类型的参数,所以要进一步根据输入参数和返回类型来获取 TypedOperatorHandle 类,并调用相关的 call 函数
-
TypedOperatorHandle::call (已经获得了函数的返回类型和输入参数个数及类型) -
dispatcher::call (通过 dispatcher 中的 dispatchKetSet 等,找到当前最高优先级的 Key 并找到对应的 KernelFunction 类) -
KernelFunction::call (KernelFunction 中有 unboxed_kernel_func_ 和 boxed_kernel_func_两个成员变量,用来代表其记录的相关的函数指针,这里根据当前带有的是 unboxed 的还是 boxed 的 kernel function 来决定最后的调用方式)
kernel 是如何注册上的
前面提到过,kernel 的注册主要是使用了 m.impl 接口,这里就从源码的角度来看一下 m.impl 是如何将 kernel 塞进 dispatch table 的。这里以 conv 算子为例,重温一下注册 kernel 的语法:
TORCH_LIBRARY_IMPL(aten, CompositeImplicitAutograd, m) {
m.impl("conv2d", TORCH_FN(wrapper__conv2d));
}
这里主要是 TORCH_LIBRARY_IMPL 这个宏中定义的 Library 类:
template <typename Name, typename Func>
Library& impl(Name name, Func&& raw_f) & {
CppFunction f(std::forward<Func>(raw_f));
return _impl(name, std::move(f));
}
RegistrationHandleRAII Dispatcher::registerImpl(
OperatorName op_name,
c10::optional<DispatchKey> dispatch_key,
KernelFunction kernel,
c10::optional<impl::CppSignature> cpp_signature,
std::unique_ptr<FunctionSchema> inferred_function_schema,
std::string debug
) {
std::lock_guard<std::mutex> lock(mutex_);
auto op = findOrRegisterName_(op_name);
auto handle = op.operatorDef_->op.registerKernel(
*this,
dispatch_key,
std::move(kernel),
std::move(cpp_signature),
std::move(inferred_function_schema),
std::move(debug)
);
++op.operatorDef_->def_and_impl_count;
return RegistrationHandleRAII([this, op, op_name, dispatch_key, handle] {
deregisterImpl_(op, op_name, dispatch_key, handle);
});
}
通过 TORCH_LIBRARY_IMPL 宏来访问 Library 类的 _impl 函数,在 _impl 函数中对 op name 及相关的 dispatch_key 进行 check,最后调用 Dispatcher 类的 registerImpl 接口,在内部调用 OperatorEntry 类的 registerKernel 接口将函数塞进 Dispatcher 类,并更新相关的 DispatchTable。
at::Tensor wrapper__conv2d(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups) {
return at::native::conv2d(input, weight, bias, stride, padding, dilation, groups);
}
at::Tensor conv2d(
const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias,
IntArrayRef stride, c10::string_view padding, IntArrayRef dilation,
int64_t groups) {
return at::_convolution_mode(
input, weight, bias, stride, std::move(padding), dilation, groups);
}
at::Tensor _convolution_mode(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, c10::string_view padding, at::IntArrayRef dilation, int64_t groups) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("aten::_convolution_mode", "")
.typed<at::Tensor (const at::Tensor &, const at::Tensor &, const c10::optional<at::Tensor> &, at::IntArrayRef, c10::string_view, at::IntArrayRef, int64_t)>();
return op.call(input, weight, bias, stride, padding, dilation, groups);
}
上面就是通过相关接口 dispatch 函数的调用。
首先需要注册相关 schema 及 impl(kernel),调用的时候就是先实例化一个 dispatcher 对象,然后通过 op name 获取这个算子的 operatorHandle,并且根据 box/unbox 和 dispatchkey 来最后调用实现的对应算子 kernel。
根据数据类型再次分发
之前在介绍 Tensor C++ 相关实现中提到,一个算子在调用过程中可能还会根据具体的数值类型进行再次分发,当时是用下面的图表示的:
我们上面介绍的 dispacther 相关内容都是绿色部分,根据 device type,layout 等信息进行动态分发,下面简单对红色的部分进行一个源码定位。
Pytorch 中有很多带有 AT_DISPATCH_ALL_TYPES 关键字的宏来做相关的根据类型分发的内容,下面以 logical_xor 算子为例进行说明,首先来看一下这个算子的实现函数:
Tensor logical_xor(const Tensor& self, const Tensor& other) {
return comparison_op(self, other, static_cast<OutFunc>(at::logical_xor_out));
}
template <typename OutImpl>
Tensor comparison_op(
const Tensor& self,
const Scalar& other,
OutImpl& out_impl) {
return comparison_op(
self, wrapped_scalar_tensor_and_check_convert(other, self), out_impl);
}
static Tensor wrapped_scalar_tensor_and_check_convert(
const Scalar& scalar,
Tensor tensor) {
check_convert(scalar, tensor.scalar_type());
return wrapped_scalar_tensor(scalar);
}
最后是在 check_convert 函数中调用了相关的分发的宏:
static void check_convert(const Scalar& scalar, ScalarType scalarType) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Bool,
at::ScalarType::BFloat16,
at::ScalarType::Half,
scalarType,
"check_convert",
[&] { scalar.to<scalar_t>(); });
}
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
\
at::ScalarType _st = ::detail::scalar_type(the_type); \
RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__) \
...
可以看到其是采用宏的方式来代替冗长的 switch…case 的操作。
|