IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> [Pytorch 源码阅读] —— 谈谈 dispatcher -> 正文阅读

[人工智能][Pytorch 源码阅读] —— 谈谈 dispatcher

前言

这篇文章的内容主要还是基于 EdWard z. yang 的 Let’s talk about the PyTorch dispatcher 来梳理一下 Pytorch dispatcher 相关的内容学习以及源码阅读。其中会涉及到很多类和内容,也会出现很多源代码,所以文章篇幅会很长,主要还是记录一下我在阅读源码,包括梳理类之间关系的一个过程,源代码中我基本都标注了所在文件位置,方便有兴趣的读者可以沿着我的这个过程一起探索神秘的 dispatcher。

概念介绍

dispatcher 可以理解为分发器,可以根据关于 tensor 输入的一些信息来决定要调用哪一块的程序。其主要是通过分发表(dispatch table)的形式来实现的,如下图:

分发表中包含了相关的 dispatch key 和对应的函数指针,可以看到 dispatch key 不仅有硬件后端,例如 CPU,GPU 等,还有一些更抽象的概念, 例如 autograd 和 tracing。dispatcher 的工作就是根据输入的 tensor 还有一些其他因素(参数个数,返回值类型等)去计算得到一个 dispatch key ,然后根据 dispatch table 找到对应对函数指针,一般我们也称为 kernel,然后间接跳转调用它。

概念介绍起来可能上面的一小段话就可以把 dispatcher 介绍完毕了,但是深入细节,其 dispatch key 是何种形式表现的?如何去进行 dispatch key 计算的?包括如何将它融入到 Pytorch 整个系统等问题,还是有很多内容值得一说的。

diapatch key 的表示和计算

diapatch key 的计算是通过一个 dispatch key set 的结构来实现的, dispatch key set 可以理解为一个 64bit 的数组,每一个 bit 都代表了一个 dispatch key,然后从左到右有优先级关系,同样一个算子,可能有针对不同 dispatch key 的实现,针对这些散落在 Pytorch各处的注册实现,然后也可能会将某些 key 排除的情况,这个数组通过最终将关于同一个算子的所有可用实现都结合到一起,调用其中优先级最高的 dispatch key 对应的 kernel 实现。

dispatch table 注册

下一个问题就是这些函数指针是如何出现在 dispatch table 中的呢?这个主要是在算子注册的时候一步步添加到 dispatch table 里面的,主要是使用了 C++ 代码,注册算子的整个过程官方有对应的文档 。主要算子注册有一下 3 种算子注册交互方式:

首先需要 m.def 定义一个关于算子的 schema(后面会介绍),然后 m.impl 将带有 dispatch key 信息的算子实现注册到 dispatch table 中,最后还有一个名为 m.fallback 的方式,在为所有的算子都注册上同一个 dispatch key。想象一下 dispatch table 和不同的算子间的对应关系可以表示为网格形式,当我们使用 m.impl 注册一个算子实现时就会传入相关的 函数指针 以及 diapatch key,假设这里我们用 C++ 实现了一个 cpu kernel 的 aten::mul 算子,对应到下图就是:

Pytorch 中也可以使用 ”catch-all“ 的操作来将同一个 kernel 注册到所有 dispatch key 上:

或者可以通过 fallback 为所有算子添加一个 dispatch key:

对于上面不同的注册形式,会有一个优先级的关系, 特定的实现 > catch all > fallback。

boxing 和 unboxing

更进一步,谈到函数调用,需要介绍另一个 dispatcher 中另一个重要的概念,boxing 和 unboxing,我们可以从数据结构的角度来对这个概念进行理解。在 C++ 中,我们都知道数据有不同的数据类型,int,float,double 等,包括一些类对象,这些多种多样的数据类型我们就可以理解为是一种 Unboxing 的行为,在 Pytorch 中定义了一种叫 IValue 的数据结构,它可以用来表示很多种数据类型,对外的表现就都是 IValue 这一种类型,这样将很多种元素打包,对外表现成一种就可以理解为是 boxing 的行为。那么对于 unboxing 的输入和 boxing 的输入,Pytorch 自然就要根据输入来进行不同形式的调用/转换(从 boxing 转换成 unboxing,或者从 unboxing 转换成 boxing),这部分将在下面源码部分做详细介绍。

源码分析

在源码部分,主要还是介绍类及各个类之间的关系为主,本文尝试由点及面的来进行源码阅读。所以还要从 Pytorch 统一的数据结构 IValue 类(interpreter value)说起。

IValue 类

前面我们已经提到,它是 Pytorch 中定义对数据的一个统一表达。从概念上,一个 16-byte IValue 类型由 3 个字段组成,一个 8-byte 的payload 类型,可以简单理解为指向相关数据的指针,4-byte 的 tag 则是表示 Ivalue 中包含的值是何种类型,最后一个是 1-byte 的 bool 类型,说明是否是 intrusive_ptr。

//  aten/src/ATen/core/ivalue.h
struct TORCH_API IValue final {
  IValue(const IValue& rhs)
      : IValue(rhs.payload, rhs.tag, rhs.is_intrusive_ptr) {
    if (is_intrusive_ptr && payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
      c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr);
    }
  }
  ...
}

其中 Payload 是一个嵌套联合体,为了使非 tensor 类拷贝起来更简单快捷,其定义为:

union Payload {
    union TriviallyCopyablePayload {
      TriviallyCopyablePayload() : as_int(0) {}
      int64_t as_int;
      double as_double;
      bool as_bool;
      c10::intrusive_ptr_target* as_intrusive_ptr;
      struct {
        DeviceType type;
        DeviceIndex index;
      } as_device;
    } u;
    at::Tensor as_tensor;
    Payload() : u() {}
    ~Payload() {}
  };

接下来是 Tag 则是对 IValue 可以包含数据类型的一个枚举:

 enum class Tag : uint32_t {
#define DEFINE_TAG(x) x,
    TORCH_FORALL_TAGS(DEFINE_TAG)
#undef DEFINE_TAG
  };

#define TORCH_FORALL_TAGS(_) \
  _(None)                    \
  _(Tensor)                  \
  _(Storage)                 \
  _(Double)                  \
  _(ComplexDouble)           \
  _(Int)                     \
  _(Bool)                    \
  _(Tuple)                   \
  _(String)                  \
  _(Blob)                    \
  _(GenericList)             \
  _(GenericDict)             \
  _(Future)                  \
  _(Device)                  \
  _(Stream)                  \
  _(Object)                  \
  _(PyObject)                \
  _(Uninitialized)           \
  _(Capsule)                 \
  _(RRef)                    \
  _(Quantizer)               \
  _(Generator)               \
  _(Enum)

可以看到 IValue 可以包含很多种不同的类型,在 IValue 的定义中,设置了对应类型来初始化或者获取其中真实类型的操作相关:

  // aten::Tensor 类型
  IValue(at::Tensor t) : tag(Tag::Tensor), is_intrusive_ptr(false) {
    new (&payload.as_tensor) at::Tensor(std::move(t));
  }

  bool isTensor() const {
    return Tag::Tensor == tag;
  }
  at::Tensor toTensor() &&;
  at::Tensor& toTensor() &;
  const at::Tensor& toTensor() const&;

  // Double
  IValue(double d) : tag(Tag::Double), is_intrusive_ptr(false) {
    payload.u.as_double = d;
  }
  bool isDouble() const {
    return Tag::Double == tag;
  }
  double toDouble() const {
    AT_ASSERT(isDouble());
    return payload.u.as_double;
  }

在 aten/src/ATen/core/ivalue_inl.h 中有对应成员函数的具体实现,感兴趣的读者可以自行阅读。通过 IValue 的统一数据表示,引出了 pytorch 中 boxing 和 unboxing 的概念。我们将 IValue 就可以看做是 Boxing,字面理解就是它会把各种各样的数据类型都打包起来,对外看起来是一致的,而相对的各种各样的类型就是 Unboxing 了。因为 unboxing 是对很多不同类型的统称,所以一般的 unboxing 形式的函数都是采用了模板形式来实现的

schema

在 Pytorch 中一个算子都要有一个对应的 schema,基本上所有算子的schema 都定义在了 aten/src/ATen/native/native_functions.yaml 文件中,以字符串的形式呈现,下面以 torch.add 为例:

– func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor

在 native_functions.yaml 中可以找到上述定义,最终通过脚本分析和代码生成,这个算子定义会被翻译成下面的形式:

namespace at {
TORCH_LIBRARY(aten, m) {
  m.def("aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor");
}

这里就是上面提到的算子注册到 dispatch table 的第一步?🏻。这里我们可以将 schema 看做是对一个算子的总体描述,这个描述包含了:算子名称,输入个数和类型,参数个数个类型,返回值类型等信息。继续深入,这里的字符串会在 m.def 里面被解析成 FunctionSchema 的类对象,让我们来看看相关源码:

// torch/include/torch/library.h  
template <typename Schema>
  Library& def(Schema&& raw_schema) & {
    // 完美转发,调用 schema() 函数
    c10::FunctionSchema s = schema(std::forward<Schema>(raw_schema));
    return _def(std::move(s));
}

// schema() 函数的实现 
inline c10::FunctionSchema schema(const char* s) {
  return schema(s, c10::AliasAnalysisKind::FROM_SCHEMA);
}

inline c10::FunctionSchema schema(const char* str, c10::AliasAnalysisKind k) {
  c10::FunctionSchema s = torch::jit::parseSchema(str);
  s.setAliasAnalysis(k);
  return s;
}

继续:

// torch/csrc/jit/frontend/function_schema_parser.cpp
C10_EXPORT FunctionSchema parseSchema(const std::string& schema) {
  auto parsed = parseSchemaOrName(schema);
  return parsed.right();
}

C10_EXPORT either<OperatorName, FunctionSchema> parseSchemaOrName(
    const std::string& schemaOrName) {
  return SchemaParser(schemaOrName).parseDeclarations().at(0);
}

讲过上面层层调用,最后在 SchemaParser(const std::string& str),函数中对具体的 schema 字符串进行解析,最终返回的是 FunctionSchema 类型。下面看一下 FunctionSchema 类型的定义:

struct FunctionSchema {
  FunctionSchema(
      std::string name,
      std::string overload_name,
      std::vector<Argument> arguments,
      std::vector<Argument> returns,
      bool is_vararg = false,
      bool is_varret = false)
      : name_({std::move(name), std::move(overload_name)}),
        arguments_(std::move(arguments)),
        returns_(std::move(returns)),
        is_vararg_(is_vararg),
        is_varret_(is_varret) {
    checkSchema();
  }
  ...
};

从上面的初始化参数可以看出字符串根据其内容,被解析出了对算子的一个更具体的描述。这里简单为 name 和 overload_name 的区别做一个说明,上述 add 算子的 name 为 aten::name,overload_name 为空。

-func: arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)

上述是定义在 native_functions.yaml 中的 arange 算子的某一个 schema,这里解析出来的 name 就是 aten::arange,overload_name 则是 start_out,所以一个算子定义的全名为:name+"."+overload_name 。所以最终我们字符串定义的 schema 会变成 c10::FunctionSchema 这个类,在后面的 dispatch 中会起到很大的索引作用。

OperatorHandle

前面介绍数据类型和算子的定义,下面就是要怎么实现算子和使用算子了。这个 Handle 类主要是用来处理一些已经注册了 schema 的算子,其主要有接口可以查询到已经注册 op 的 operator_name,包括查询/返回它的 FunctionSchema,还有就是注册函数的调用,其主要通过 OperatorDef 类及相关 list 迭代器实现,对外的接口。

// aten/src/ATen/core/dispatch/Dispatcher.h
class TORCH_API OperatorHandle {
public:
  const OperatorName& operator_name() const {
    return operatorDef_->op.operator_name();
  }

  bool hasSchema() const {
    return operatorDef_->op.hasSchema();
  }

  const FunctionSchema& schema() const {
    return operatorDef_->op.schema();
  }
  ...
  // 以 boxed 的方式调用函数
  void callBoxed(Stack* stack) const {
    c10::Dispatcher::singleton().callBoxed(*this, stack);
  }
  // 为了 unbox 的形式调用,这里还有一个 TypedOperatorHandle 类
  template<class FuncType>
  TypedOperatorHandle<FuncType> typed() const {
     return TypedOperatorHandle<FuncType>(operatorIterator_);
  }
  
  void redispatchBoxed(DispatchKeySet ks, Stack* stack) const {
    c10::Dispatcher::singleton().redispatchBoxed(*this, ks, stack);
  }
    
 private:
   Dispatcher::OperatorDef* operatorDef_;
   std::list<Dispatcher::OperatorDef>::iterator operatorIterator_;
};

这里的 TypedOperatorHandle 则是与 OperatorHandle 同样的功能,只是把 op 的参数模板化了,并且可以用 unboxed 的方式调用相关实现函数:

template<class Return, class... Args>
class TypedOperatorHandle<Return (Args...)> final : public OperatorHandle {
public:
  // 以 unboxed 的方式调用函数
  C10_ALWAYS_INLINE Return call(Args... args) const {
    return c10::Dispatcher::singleton().call<Return, Args...>(*this, std::forward<Args>(args)...);
  }
  
  C10_ALWAYS_INLINE Return redispatch(DispatchKeySet currentDispatchKeySet, Args... args) const {
    return c10::Dispatcher::singleton().redispatch<Return, Args...>(*this, currentDispatchKeySet, std::forward<Args>(args)...);
  }
 private:
  explicit TypedOperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
  : OperatorHandle(operatorIterator) {}
  friend class OperatorHandle; 
};

可以看到其就定义了两个 unbox 形式的调用函数。

KernelFunction

根据官方介绍这个类相当于 std::function,但是表示的是算子的 kernel 函数,可以从一个 Boxed/unboxed 的函数,仿函数,lambda 函数创建一个 kernekFunction,包括有需要它会自动适配到对应 boxed/unboxed。

// aten/src/ATen/core/boxing/KernelFunction.h
class TORCH_API KernelFunction final {
public:
  // 3 种不同的函数形式,boxed 所有输入的都是 Stack,即 vector<IValue>;
	using InternalBoxedKernelFunction = void(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*);
  using BoxedKernelFunction = void(const OperatorHandle&, Stack*);
  using BoxedKernelFunction_withDispatchKeys = void(const OperatorHandle&, DispatchKeySet, Stack*);
  // 以 boxed 的方式调用函数
  void callBoxed(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Stack* stack) const;
  
// 以 unboxed 的方式调用函数
template<class Return, class... Args>
  Return call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const;

 // 以 boxed/unboxed 的方式从 函数/仿函数/lambda 函数创建 KernelFunction
  template<BoxedKernelFunction* func>
  static KernelFunction makeFromBoxedFunction();
  
  template<class FuncPtr, bool AllowLegacyTypes = false>
  static KernelFunction makeFromUnboxedFunction(FuncPtr);
  
  template<bool AllowLegacyTypes = false, class KernelFunctor>
  static KernelFunction makeFromUnboxedFunctor(std::unique_ptr<OperatorKernel> kernelFunctor);
  ....
private:
  OperatorKernel* getFunctor_() const;
  std::shared_ptr<OperatorKernel> functor_;
  InternalBoxedKernelFunction* boxed_kernel_func_; // 内部定义的 boxed 的函数指针
  void* unboxed_kernel_func_; // 常规的函数指针
};

KernelFunction 更多的成员函数实现是在 aten/src/ATen/core/boxing/KernelFunction_impl.h 中,有兴趣的读者可以做额外展开阅读。

OperatorEntry

这个类是内部使用的,用户一般是不会直接访问到,是更底层记录算子的一些信息的,上层的相关类都是依赖 OperatorEntry 类实现的。

// aten/src/ATen/core/dispatch/OperatorEntry.h
class TORCH_API OperatorEntry final {
public:
  // 使用 op name 即可初始化
   explicit OperatorEntry(OperatorName&& operator_name);
  
  // 获取相关算子的 schema
   const FunctionSchema& schema() const {
     return schema_->schema;
   }
  
  // 注册 schema
  void registerSchema(FunctionSchema&&, std::string&& debug);
  void deregisterSchema();
  
  // 注册实现
  std::list<AnnotatedKernel>::iterator registerKernel(
    const Dispatcher& dispatcher,
    c10::optional<DispatchKey> dispatch_key,
    KernelFunction kernel,
    c10::optional<CppSignature> cpp_signature,
    std::unique_ptr<FunctionSchema> inferred_function_schema,
    std::string debug
  );
  
  // 根据 dispatch key 查找相关 kernel
  const KernelFunction& lookup(DispatchKey k) const {
    const auto& kernel = dispatchTable_[static_cast<uint8_t>(k)];
    if (C10_UNLIKELY(!kernel.isValidUnboxed())) {
      if (!kernel.isValid()) {
        reportError(k);
      }
    }
    return kernel;
  }
  ...
private:
    OperatorName name_;
    c10::optional<AnnotatedSchema> schema_;

    std::array<KernelFunction, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)> dispatchTable_;
    DispatchKeyExtractor dispatchKeyExtractor_;
    ska::flat_hash_map<DispatchKey, std::list<AnnotatedKernel>> kernels_;
    ...
};

Dispatcher

这个是动态分发的主要类,但是不是用户直接可以用的。注册相关 op 函数kernel 可以使用 aten/src/ATen/core/op_registration/op_registration.h 的 RegisterOperators 类。

// aten/src/ATen/core/dispatch/Dispatcher.h
class TORCH_API Dispatcher final {
private:
  // For direct access to backend fallback information
  friend class impl::OperatorEntry;

  struct OperatorDef final {
    explicit OperatorDef(OperatorName&& op_name)
    : op(std::move(op_name)) {}
    impl::OperatorEntry op;
    size_t def_count = 0;
    size_t def_and_impl_count = 0;
   };
   friend class OperatorHandle;
   template<class> friend class TypedOperatorHandle;
  public:
   ~Dispatcher();
   static Dispatcher& realSingleton();
  // 全局只需要一个 diapatch table
  C10_ALWAYS_INLINE static Dispatcher& singleton() {
    return realSingleton();
  }
  // 通过 schema 查找来访问 operator
  c10::optional<OperatorHandle> findSchema(const OperatorName& operator_name);
  OperatorHandle findSchemaOrThrow(const char* name, const char* overload_name);
  c10::optional<OperatorHandle> findOp(const OperatorName& operator_name);
  
    // 调用算子
  template<class Return, class... Args>
  Return call(const TypedOperatorHandle<Return (Args...)>& op, Args... args) const;
  
  template<class Return, class... Args>
  static Return callWithDispatchKeySlowPath(const TypedOperatorHandle<Return (Args...)>& op, bool pre_sampled, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args);
  
  void callBoxed(const OperatorHandle& op, Stack* stack) const;
    // 注册一个新的算子 schema
  RegistrationHandleRAII registerDef(FunctionSchema schema, std::string debug);
    // 注册一个算子 kernel 到 dispatch table 上
  RegistrationHandleRAII registerImpl(OperatorName op_name, c10::optional<DispatchKey> dispatch_key, KernelFunction kernel, c10::optional<impl::CppSignature> cpp_signature, std::unique_ptr<FunctionSchema> inferred_function_schema, std::string debug);
 ...
}

内部调用逻辑

针对不同的算子 schem,它们都有一个最终的算子调用过程,内部实现一个很常规的方法是使用 dispatcher 类的方法,下面以 abs 算子为例,因为其输入是 unboxed 的 tensor 类型,所以需要以 Unboxed 的形式:

at::Tensor & Tensor::abs_() const {
    static auto op = c10::Dispatcher::singleton() // 返回一个 static 的 dispatcher 对象
        .findSchemaOrThrow("aten::abs_", "") // 返回 OperatorHandle 类
        .typed<at::Tensor & (at::Tensor &)>(); // 返回 TypedOperatorHandle 类
    return op.call(const_cast<Tensor&>(*this)); // 最后调用 TypedOperatorHandle 的 call 函数
}

下面我们就深入追踪一下这个 op.call 函数:

 // aten/src/ATen/core/dispatch/Dispatcher.h
 // 返回来调用 dispatcher 类中定义的 call 函数
  C10_ALWAYS_INLINE Return call(Args... args) const {
    return c10::Dispatcher::singleton().call<Return, Args...>(*this, std::forward<Args>(args)...);
  }

diapatcher 类对 call 的定义为:

template<class Return, class... Args>
C10_DISPATCHER_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorHandle<Return(Args...)>& op, Args... args) const {
  detail::unused_arg_(args...); 
  // 获取当前 dispatchKeySet
  auto dispatchKeySet = op.operatorDef_->op.dispatchKeyExtractor()
    .template getDispatchKeySetUnboxed<Args...>(args...);
 // 根据 最好优先级 key 来找到当前应该派发的 KernelFunction
  const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet.highestPriorityTypeId());
  // 最后调用 KernelFunction 类的 call 函数
  return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
}

最后又调用了 KernelFunction 中的 call 函数:

template<class Return, class... Args>
C10_ALWAYS_INLINE Return KernelFunction::call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const {
   // 如果是 unboxed 形式的函数
    if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
        return callUnboxedKernelFunction<Return, Args...>(unboxed_kernel_func_, functor_.get(), dispatchKeySet, std::forward<Args>(args)...);
    }

    // 如果是 boxed 形式的函数
    return impl::BoxedKernelWrapper<Return(Args...)>::call(
        boxed_kernel_func_,
        functor_.get(),
        opHandle,
        dispatchKeySet,
        std::forward<Args>(args)...
    );
}

如果是 unboxed 的行为:

// aten/src/ATen/core/boxing/KernelFunction_impl.h
template<class Return, class... Args>
inline Return callUnboxedKernelFunction(void* unboxed_kernel_func, OperatorKernel* functor, DispatchKeySet dispatchKeySet, Args&&... args) {
    using ActualSignature = Return (OperatorKernel*, DispatchKeySet, Args...);
    ActualSignature* func = reinterpret_cast<ActualSignature*>(unboxed_kernel_func);
    // 将相关参数传入,unboxed_kernel_func 运行
    return (*func)(functor, dispatchKeySet, std::forward<Args>(args)...);
}

如果是 boxed 的行为:

// aten/src/ATen/core/boxing/KernelFunction_impl.h
 static Result call(
    KernelFunction::InternalBoxedKernelFunction* boxed_kernel_func,
    OperatorKernel* functor,
    const OperatorHandle& opHandle,
    DispatchKeySet dispatchKeySet,
    Args... args
  ) {
    torch::jit::Stack stack = boxArgs<Args...>(std::forward<Args>(args)...);
    // 将相关参数传入,boxed_kernel_func 运行
    (*boxed_kernel_func)(functor, opHandle, dispatchKeySet, &stack);
    ......
    );
  }

所以整理上述过程,要调用一个算子底层实现的过程是:

  1. 通过 dispatcher 类 + op name 的形式来查找对应的算子 schema 。因为 schema 中定义了相关的算子输入、输出、参数的相关信息。
  • 其中 FunctionSchema 类只是记录,想要具体访问还是要使用 OperatorHandle 类,所以上面返回的是 OperatorHandle 类对象。
  1. 因为算子的输入基本都是 scalar / tensor 这种 Unboxed 类型的参数,所以要进一步根据输入参数和返回类型来获取 TypedOperatorHandle 类,并调用相关的 call 函数
  • TypedOperatorHandle::call (已经获得了函数的返回类型和输入参数个数及类型)

  • dispatcher::call (通过 dispatcher 中的 dispatchKetSet 等,找到当前最高优先级的 Key 并找到对应的 KernelFunction 类)

  • KernelFunction::call (KernelFunction 中有 unboxed_kernel_func_ 和 boxed_kernel_func_两个成员变量,用来代表其记录的相关的函数指针,这里根据当前带有的是 unboxed 的还是 boxed 的 kernel function 来决定最后的调用方式)

kernel 是如何注册上的

前面提到过,kernel 的注册主要是使用了 m.impl 接口,这里就从源码的角度来看一下 m.impl 是如何将 kernel 塞进 dispatch table 的。这里以 conv 算子为例,重温一下注册 kernel 的语法:

TORCH_LIBRARY_IMPL(aten, CompositeImplicitAutograd, m)  {
    m.impl("conv2d", TORCH_FN(wrapper__conv2d));
  }

这里主要是 TORCH_LIBRARY_IMPL 这个宏中定义的 Library 类:

// torch/library.h
template <typename Name, typename Func>
  Library& impl(Name name, Func&& raw_f) & {
    CppFunction f(std::forward<Func>(raw_f));
      return _impl(name, std::move(f));
  }

// 注册实现
RegistrationHandleRAII Dispatcher::registerImpl(
  OperatorName op_name,
  c10::optional<DispatchKey> dispatch_key,
  KernelFunction kernel,
  c10::optional<impl::CppSignature> cpp_signature,
  std::unique_ptr<FunctionSchema> inferred_function_schema,
  std::string debug
) {
  std::lock_guard<std::mutex> lock(mutex_);

  auto op = findOrRegisterName_(op_name);

  auto handle = op.operatorDef_->op.registerKernel(
    *this,
    dispatch_key,
    std::move(kernel),
    // NOLINTNEXTLINE(performance-move-const-arg)
    std::move(cpp_signature),
    std::move(inferred_function_schema),
    std::move(debug)
  );

  ++op.operatorDef_->def_and_impl_count;

  return RegistrationHandleRAII([this, op, op_name, dispatch_key, handle] {
    deregisterImpl_(op, op_name, dispatch_key, handle);
  });
}

通过 TORCH_LIBRARY_IMPL 宏来访问 Library 类的 _impl 函数,在 _impl 函数中对 op name 及相关的 dispatch_key 进行 check,最后调用 Dispatcher 类的 registerImpl 接口,在内部调用 OperatorEntry 类的 registerKernel 接口将函数塞进 Dispatcher 类,并更新相关的 DispatchTable。

at::Tensor wrapper__conv2d(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups) {
  return at::native::conv2d(input, weight, bias, stride, padding, dilation, groups);
}

// 层层调用
// aten/src/ATen/native/Convolution.cpp
at::Tensor conv2d(
    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias,
    IntArrayRef stride, c10::string_view padding, IntArrayRef dilation,
    int64_t groups) {
  return at::_convolution_mode(
      input, weight, bias, stride, std::move(padding), dilation, groups);
}

// build/aten/src/ATen/Functions.cpp
at::Tensor _convolution_mode(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, c10::string_view padding, at::IntArrayRef dilation, int64_t groups) {
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::_convolution_mode", "")
        .typed<at::Tensor (const at::Tensor &, const at::Tensor &, const c10::optional<at::Tensor> &, at::IntArrayRef, c10::string_view, at::IntArrayRef, int64_t)>();
    return op.call(input, weight, bias, stride, padding, dilation, groups);
}

上面就是通过相关接口 dispatch 函数的调用。

首先需要注册相关 schema 及 impl(kernel),调用的时候就是先实例化一个 dispatcher 对象,然后通过 op name 获取这个算子的 operatorHandle,并且根据 box/unbox 和 dispatchkey 来最后调用实现的对应算子 kernel。

根据数据类型再次分发

之前在介绍 Tensor C++ 相关实现中提到,一个算子在调用过程中可能还会根据具体的数值类型进行再次分发,当时是用下面的图表示的:

我们上面介绍的 dispacther 相关内容都是绿色部分,根据 device type,layout 等信息进行动态分发,下面简单对红色的部分进行一个源码定位。

Pytorch 中有很多带有 AT_DISPATCH_ALL_TYPES 关键字的宏来做相关的根据类型分发的内容,下面以 logical_xor 算子为例进行说明,首先来看一下这个算子的实现函数:

// aten/src/ATen/native/BinaryOps.cpp
Tensor logical_xor(const Tensor& self, const Tensor& other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::logical_xor_out));
}
template <typename OutImpl>
Tensor comparison_op(
    const Tensor& self,
    const Scalar& other,
    OutImpl& out_impl) {
  return comparison_op(
      self, wrapped_scalar_tensor_and_check_convert(other, self), out_impl);
}
static Tensor wrapped_scalar_tensor_and_check_convert(
    const Scalar& scalar,
    Tensor tensor) {
  check_convert(scalar, tensor.scalar_type());
  return wrapped_scalar_tensor(scalar);
}

最后是在 check_convert 函数中调用了相关的分发的宏:

static void check_convert(const Scalar& scalar, ScalarType scalarType) {
  // Validate that is possible to convert scalar to tensor dtype without
  // overflow
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      at::ScalarType::Bool,
      at::ScalarType::BFloat16,
      at::ScalarType::Half,
      scalarType,
      "check_convert",
      [&] { scalar.to<scalar_t>(); });
}
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(                             \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...)                 \
  [&] {                                                                     \
    const auto& the_type = TYPE;                                            \
    /* don't use TYPE again in case it is an expensive or side-effect op*/  \
    at::ScalarType _st = ::detail::scalar_type(the_type);                   \
    RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st);                                \
    switch (_st) {                                                          \
      AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Byte, uint8_t, __VA_ARGS__)      \
      AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Char, int8_t, __VA_ARGS__)       \
      AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Double, double, __VA_ARGS__)     \
      AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Float, float, __VA_ARGS__)       \
      AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Int, int32_t, __VA_ARGS__)       \
      AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Long, int64_t, __VA_ARGS__)      \
      AT_PRIVATE_CASE_TYPE(NAME, at::ScalarType::Short, int16_t, __VA_ARGS__)     \ 
     ...

可以看到其是采用宏的方式来代替冗长的 switch…case 的操作。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-08 11:20:38  更:2021-08-08 11:23:58 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/12 1:38:31-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码