   -> 正文阅读


torch.fx主要包含三个组件:符号追踪器(symbolic tracer)、中间表示(intermediate representation)和Python代码生成(Python code generation)。
import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
    """Demo module: add a learnable (3, 4) parameter, apply Linear(4, 5), clamp to [0, 1]."""

    def __init__(self):
        # nn.Module.__init__ must run before any Parameter or submodule is
        # assigned; otherwise attribute registration raises AttributeError.
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

# Instantiate the demo module defined above.
module = MyModule()

from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module.
# symbolic_trace returns a torch.fx.GraphModule that holds the traced graph.
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

# Code generation - valid Python code
print(symbolic_traced.code)
# The printed source is:
def forward(self, x):
    # Source generated by torch.fx for the traced MyModule, reproduced verbatim.
    # It follows the graph IR one node per line: get_attr -> add -> linear -> clamp.
    param = self.param
    # The generated code rebinds inputs/intermediates to None right after their
    # last use (presumably so tensor references are released early).
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp


  1. 有关图的例子
import torch
import torch.fx

class MyModule(torch.nn.Module):
    """Demo module whose traced graph is printed below: add -> linear -> relu -> sum -> topk."""

    def __init__(self):
        # Required before registering Parameters/submodules on an nn.Module.
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        # ``x`` is added to ``linear.weight`` (shape (5, 4) for Linear(4, 5)),
        # so ``x`` must broadcast against (5, 4).
        return torch.topk(torch.sum(
            self.linear(x + self.linear.weight).relu(), dim=-1), 3)

# Build the module and trace it into a torch.fx.GraphModule (which carries
# both the graph and the generated code).
m = MyModule()
gm = torch.fx.symbolic_trace(m)

# 打印graph的所有node
gm.graph.print_tabular()
"""
opcode         name           target                                                       args                kwargs
-------------  -------------  -----------------------------------------------------------  ------------------  -----------
placeholder    x              x                                                            ()                  {}
get_attr       linear_weight  linear.weight                                                ()                  {}
call_function  add            <built-in function add>                                      (x, linear_weight)  {}
call_module    linear         linear                                                       (add,)              {}
call_method    relu           relu                                                         (linear,)           {}
call_function  sum_1          <built-in method sum of type object at 0x00007FF80165E360>   (relu,)             {'dim': -1}
call_function  topk           <built-in method topk of type object at 0x00007FF80165E360>  (sum_1, 3)          {}
output         output         output                                                       (topk,)             {}
"""


  1. 图形操作
    torch.fx做的就是将一个Module转换为静态图,这和变换Module有什么关系呢?如果我们对一个Module追踪得到的Graph进行变换,再借助Python代码生成工具,就可以达到变换这个Module的目的。整个流程是 symbolic tracing -> intermediate representation -> transforms -> Python code generation,由此实现从一个Module到另一个Module的Python-to-Python转换,流程如下所示:
import torch
import torch.fx

def transform(m: torch.nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    """
    Template for an FX transform: trace ``m``, rewrite its Graph, rebuild a module.

    Step 2 is left as an identity rewrite so the template runs as-is; put the
    actual graph edits there.

    Args:
        m: module to transform.
        tracer_class: ``torch.fx.Tracer`` subclass used for tracing; pass a
            custom subclass to customize tracing behavior.

    Returns:
        A ``torch.fx.GraphModule`` whose ``forward`` executes the (rewritten) graph.
    """
    # Step 1: Acquire a Graph representing the code in `m`.
    # NOTE: torch.fx.symbolic_trace is a wrapper around a call to
    # fx.Tracer.trace and constructing a GraphModule. We split that out here
    # so the caller can customize tracing behavior.
    graph : torch.fx.Graph = tracer_class().trace(m)

    # Step 2: Modify this Graph or create a new one.
    # Identity placeholder: the graph is kept unchanged. (Assigning a
    # non-Graph value such as `...` here would make Step 3 fail.)
    graph.lint()  # sanity-check that the graph is still well-formed

    # Step 3: Construct a Module to return.
    return torch.fx.GraphModule(m, graph)

这里最终得到的torch.fx.GraphModule除了包含graph和code属性外就和正常的nn.Module一样,它的forward执行的就是graph的语义代码。这里来看一个修改Module的简单例子,这个例子中我们将模块中所有的torch.add()操作替换成 torch.mul() :

import torch
import torch.fx

# Sample module
class M(torch.nn.Module):
    """Toy module whose forward is a single ``torch.add`` call; the add -> mul rewrite example targets it."""
    def forward(self, x, y):
        return torch.add(x, y)

def transform(m: torch.nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    """
    Return a GraphModule equivalent to ``m`` with every ``torch.add`` call replaced by ``torch.mul``.

    Args:
        m: module to rewrite.
        tracer_class: ``torch.fx.Tracer`` subclass used to trace ``m``.

    Returns:
        A new ``torch.fx.GraphModule``; ``m``'s own ``forward`` is not changed.
    """
    graph : torch.fx.Graph = tracer_class().trace(m)
    # FX represents its Graph as an ordered list of nodes, so we can iterate
    # through them directly.
    for node in graph.nodes:
        # For call_function nodes, ``target`` is the function being called
        # (e.g. torch.add); swapping it rewrites the call in place.
        if node.op == 'call_function' and node.target == torch.add:
            node.target = torch.mul

    graph.lint()  # Does some checks to make sure the Graph is well-formed.

    return torch.fx.GraphModule(m, graph)

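① 删除或添加节点:FX 的 Graph API 提供了相应的工具函数,例如用 graph.inserting_after() 和 graph.call_function() 在某个节点之后追加一个 torch.relu() 调用,再用 node.replace_all_uses_with() 让原来的使用者改用新节点的输出
② replace_pattern():用于在图上做查找-替换(find-and-replace)式编辑的 API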

  1. 图形操作的相关案例
    ① Replace one op
    ② Conv/Batch Norm fusion
    ③ replace_pattern: Basic usage
    ④ Quantization
    ⑤ Invert Transformation


import torch.fx as fx
from torch.fx.node import Argument, Target
from torch.nn.utils.fusion import fuse_conv_bn_eval
from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx.passes.shape_prop import ShapeProp
import copy
from collections import defaultdict
import torch.utils.mkldnn as th_mkldnn
import operator
import time
import logging
from enum import Enum

def _parent_name(target : str) -> Tuple[str, str]:
    Splits a qualname into parent path and last atom.
    For example, `` -> (``, `baz`)
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name

# Works for length 2 patterns with 2 modules
def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]):
    """
    Return True if ``node`` and its first argument form a module chain matching ``pattern``.

    ``pattern`` is a (producer_type, consumer_type) pair. The producer is
    ``node.args[0]`` and the consumer is ``node``; both must be ``call_module``
    nodes whose string target resolves in ``modules`` to a module of exactly
    the expected type (subclasses do not match).
    """
    if len(node.args) == 0:
        return False
    nodes: Tuple[Any, fx.Node] = (node.args[0], node)
    for expected_type, current_node in zip(pattern, nodes):
        if not isinstance(current_node, fx.Node):
            return False
        if current_node.op != 'call_module':
            return False
        if not isinstance(current_node.target, str):
            return False
        if current_node.target not in modules:
            return False
        if type(modules[current_node.target]) is not expected_type:
            return False
    return True

def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    """
    Swap the module called by a ``call_module`` node for ``new_module``.

    Both the ``modules`` lookup table and the parent module's attribute are
    updated, so the node's target name now resolves to ``new_module``.
    ``modules`` must contain the parent's qualified name ('' for the root).
    """
    assert(isinstance(node.target, str))
    parent_name, name = _parent_name(node.target)
    modules[node.target] = new_module
    setattr(modules[parent_name], name, new_module)

def fuse(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
    """
    Fuses convolution/BN layers for inference purposes. Will deepcopy your
    model by default, but can modify the model inplace as well.

    A Conv -> BN pair is fused only when the conv output feeds nothing but
    the BN, and BN layers without running statistics are left untouched.

    Args:
        model: module to fuse. ``fuse_conv_bn_eval`` is intended for
            eval-mode modules (as its name says), so call ``model.eval()`` first.
        inplace: if False (default) a deep copy of ``model`` is fused instead.

    Returns:
        A ``GraphModule`` in which each fused BN call is removed and the
        preceding conv is replaced by the fused conv.
    """
    patterns = [(nn.Conv1d, nn.BatchNorm1d),
                (nn.Conv2d, nn.BatchNorm2d),
                (nn.Conv3d, nn.BatchNorm3d)]
    if not inplace:
        model = copy.deepcopy(model)
    fx_model = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    new_graph = copy.deepcopy(fx_model.graph)

    for pattern in patterns:
        for node in new_graph.nodes:
            # Target node: a BN call_module whose first argument is a matching Conv call_module.
            if not matches_module_pattern(pattern, node, modules):
                continue
            if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                continue
            conv = modules[node.args[0].target]
            bn = modules[node.target]
            if not bn.track_running_stats:
                # No running mean/var means there are no fixed statistics to fold in.
                continue
            fused_conv = fuse_conv_bn_eval(conv, bn)
            # Install the fused module under the conv node's target name.
            replace_node_module(node.args[0], modules, fused_conv)
            # Route every consumer of the BN output to the (fused) conv node,
            # then drop the now-unused BN node.
            node.replace_all_uses_with(node.args[0])
            new_graph.erase_node(node)
    return fx.GraphModule(fx_model, new_graph)

def remove_dropout(model: nn.Module) -> nn.Module:
    """
    Removes all dropout layers from the module.

    Every ``nn.Dropout`` call is replaced by its input (what dropout computes
    in eval mode). Note that this is applied regardless of whether ``model``
    is in training mode. All other module calls are kept as-is.

    Returns:
        A new ``GraphModule`` without the dropout calls.
    """
    fx_model = fx.symbolic_trace(model)

    class DropoutRemover(torch.fx.Transformer):
        def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
            if isinstance(self.submodules[target], nn.Dropout):
                assert len(args) == 1
                # Forward the dropout's input unchanged.
                return args[0]
            # Every other module call must be re-emitted into the new graph.
            return super().call_module(target, args, kwargs)
    return DropoutRemover(fx_model).transform()

def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[fx.Node], outputs: List[fx.Node]):
    """
    Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph.

    Args:
        orig_module: module owning any submodules/attributes the nodes reference.
        nodes: nodes to copy, in topological order.
        inputs: nodes whose values become placeholders (in this order).
        outputs: nodes whose values are returned, as a list.

    Every argument of a copied node must be in ``inputs`` or earlier in
    ``nodes``; otherwise a KeyError is raised.
    """
    new_graph = fx.Graph()
    env: Dict[fx.Node, fx.Node] = {}
    for inp in inputs:
        env[inp] = new_graph.placeholder(inp.name)
    for node in nodes:
        env[node] = new_graph.node_copy(node, lambda x: env[x])
    new_graph.output([env[output] for output in outputs])
    new_graph.lint()
    return fx.GraphModule(orig_module, new_graph)

# Module types and functions that have MKLDNN implementations; nodes that call
# them are always treated as MKLDNN-capable by `optimize_for_inference`.
mkldnn_supported = [
    nn.Conv2d, nn.Linear, nn.BatchNorm2d, nn.ReLU, nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d,
    torch.relu, torch.transpose, torch.sigmoid,
    F.relu, F.avg_pool2d, F.adaptive_avg_pool2d
]
# These are operators that may not be convertible into MKLDNN ops (e.g. the
# args are scalar values). Thus, we only include them in the subgraph if their
# arguments are already in MKLDNN.
# TODO: Determine whether this can be removed after type inference.
mkldnn_supported_unknown = [operator.add, operator.mul]
# Module type -> constructor of its MKLDNN counterpart, invoked as
# ctor(module, dtype) by `modules_to_mkldnn`.
mkldnn_map = {
    nn.Conv2d: th_mkldnn.MkldnnConv2d,
    nn.Linear: th_mkldnn.MkldnnLinear,
    # The lambda discards the dtype argument for MkldnnBatchNorm.
    nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a)
}

def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]):
    """
    For each node, if it's a module that can be preconverted into MKLDNN,
    then we do so and create a mapping to allow us to convert from the MKLDNN
    version of the module to the original.

    Args:
        nodes: graph nodes to scan; only ``call_module`` nodes are considered.
        modules: qualified name -> module table. It is updated in place, and
            the parent modules are updated too (via ``replace_node_module``).

    Returns:
        Dict mapping each new MKLDNN module to a deep copy of the module it replaced.
    """
    old_modules: Dict[nn.Module, nn.Module] = {}
    for node in nodes:
        if node.op != 'call_module':
            continue
        assert(isinstance(node.target, str))
        cur_module = modules[node.target]
        if type(cur_module) not in mkldnn_map:
            continue
        new_module = mkldnn_map[type(cur_module)](cur_module, torch.float)
        assert(isinstance(new_module, nn.Module))
        old_modules[new_module] = copy.deepcopy(cur_module)
        replace_node_module(node, modules, new_module)
    return old_modules

def reset_modules(nodes: List[fx.Node], modules: Dict[str, nn.Module], old_modules: Dict[nn.Module, nn.Module]):
    """
    Maps each module that's been changed with `modules_to_mkldnn` back to its
    original.

    Args:
        nodes: graph nodes to scan; only ``call_module`` nodes are considered.
        modules: qualified name -> module table, updated in place.
        old_modules: MKLDNN module -> original module mapping, as returned by
            ``modules_to_mkldnn``.
    """
    for node in nodes:
        if node.op != 'call_module':
            continue
        assert(isinstance(node.target, str))
        cur_module = modules[node.target]
        if cur_module in old_modules:
            replace_node_module(node, modules, old_modules[cur_module])

class MklSubgraph:
    """
    Bookkeeping for one candidate MKLDNN region of an FX graph.

    ``nodes`` collects the region's interior nodes, ``start_nodes`` its input
    nodes and ``end_nodes`` its output nodes. All three start empty and are
    filled in by the caller.
    """

    def __init__(self, fx_graph: fx.Graph):
        # The graph this region belongs to; heuristics read it back later.
        self.fx_graph = fx_graph
        # Three independent, initially empty lists.
        self.nodes: List[fx.Node] = list()
        self.start_nodes: List[fx.Node] = list()
        self.end_nodes: List[fx.Node] = list()

def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
    """
    This generates a heuristic that can be passed into `optimize_for_inference` that
    determines whether a subgraph should be run in MKL by running it with the example_inputs.

    The returned callable takes an ``MklSubgraph``. It returns True when the
    MKLDNN version of that subgraph ran faster than the dense version over
    ``iters`` timed runs, each timing preceded by ``warmup`` untimed runs.

    Example usage:
        heuristic = gen_mkl_autotuner(example_inputs, iters=10)
        fast_model = optimization.optimize_for_inference(
            model, {"mkldnn_layout_optimize": {"heuristic": heuristic}})
    """
    fx_model = None
    old_modules = None

    def use_mkl_heuristic(graph: "MklSubgraph") -> bool:
        nonlocal fx_model, old_modules
        input_nodes = graph.start_nodes
        if fx_model is None:
            # First call: grab the module being optimized and annotate its
            # nodes with shapes by running `example_inputs` through ShapeProp.
            fx_model = graph.fx_graph.owning_module
            old_modules = graph.fx_graph.old_modules  # type: ignore[attr-defined]
            ShapeProp(fx_model).propagate(example_inputs)
        # NOTE(review): relies on ShapeProp exposing `node.shape`; some FX
        # versions store shapes in `node.meta['tensor_meta']` instead — confirm
        # against the installed torch version.
        sample_inputs = [torch.randn(node.shape) for node in input_nodes]  # type: ignore[attr-defined]
        output_args = cast(List[fx.Node], [node.args[0] for node in graph.end_nodes])
        submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args)

        def benchmark(f):
            for _ in range(warmup):
                f()
            # perf_counter is monotonic and high-resolution, unlike time.time.
            begin = time.perf_counter()
            for _ in range(iters):
                f()
            return time.perf_counter() - begin

        mkl_time = benchmark(lambda: [i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])])

        # Put the original (dense) modules back and time the same subgraph again.
        reset_modules(submodule.graph.nodes, dict(submodule.named_modules()), old_modules)
        no_mkl_time = benchmark(lambda: submodule(*sample_inputs))
        return mkl_time < no_mkl_time
    return use_mkl_heuristic

def use_mkl_length(graph: "MklSubgraph") -> bool:
    """
    This is a heuristic that can be passed into `optimize_for_inference` that
    determines whether a subgraph should be run in MKL by checking if there
    are more than 2 nodes in it.
    """
    return len(graph.nodes) > 2

class UnionFind:
    """
    Disjoint-set forest over the integers ``0 .. n-1`` with path compression
    (in ``find``) and union by size (in ``join``).

    Elements start out unregistered; call ``make_set(v)`` before using ``v``.
    """

    def __init__(self, n):
        # parent[v] is None until make_set(v); a root is its own parent.
        self.parent: List[Optional[int]] = [None] * n
        # size[r] is the element count of the set rooted at r (meaningful for roots only).
        self.size: List[int] = [0] * n

    def make_set(self, v: int):
        """Register ``v`` as a singleton set."""
        self.parent[v] = v
        self.size[v] = 1

    def find(self, v: int) -> int:
        """Return the root of ``v``'s set, compressing the path along the way."""
        par = self.parent[v]
        if v == par:
            return v
        assert par is not None, f"element {v} was never registered with make_set"
        self.parent[v] = self.find(par)
        return cast(int, self.parent[v])

    def join(self, a: int, b: int) -> int:
        """Merge the sets containing ``a`` and ``b`` and return the resulting root."""
        a, b = self.find(a), self.find(b)
        if a == b:
            return a
        # Attach the smaller tree under the larger one to keep trees shallow.
        if self.size[a] < self.size[b]:
            a, b = b, a
        self.parent[b] = a
        self.size[a] += self.size[b]
        return a

def optimize_for_inference(
    model: torch.nn.Module,
    pass_config: Optional[Dict[str, Any]] = None,
    tracer: Type[fx.Tracer] = fx.Tracer
) -> torch.nn.Module:
    """
    Performs a set of optimization passes to optimize a model for the
    purposes of inference. Specifically, the passes that are run are:
    1. Conv/BN fusion
    2. Dropout removal
    3. MKL layout optimizations

    The third optimization takes a function `use_mkl_heuristic` that's used
    to determine whether a subgraph should be explicitly run in MKL layout.

    Args:
        model: the model to optimize.
        pass_config: entries that override the defaults:
            "conv_bn_fuse" (bool, default True),
            "remove_dropout" (bool, default True),
            "mkldnn_layout_optimize" (False to skip the pass, or a dict whose
            "heuristic" maps an MklSubgraph to bool; default `use_mkl_length`).
        tracer: fx.Tracer subclass used to trace the model for the MKL pass.

    Returns:
        The optimized module.

    Raises:
        RuntimeError: if "mkldnn_layout_optimize" is neither False nor a dict
            containing "heuristic".
        AssertionError: if an MKLDNN-eligible module has non-float or non-CPU
            parameters.

    Note: As FX does not currently handle aliasing, this pass currently
    assumes nothing aliases. If that isn't true, use at your own risk.
    NOTE(review): with "conv_bn_fuse" disabled, the MKL pass swaps modules
    via setattr on modules taken from `model` itself, so nested submodules of
    the caller's model may be modified — confirm whether that is acceptable.
    """
    default_pass_config = {
        "conv_bn_fuse": True,
        "remove_dropout": True,
        "mkldnn_layout_optimize": {'heuristic': use_mkl_length},
    }
    if pass_config is None:
        pass_config = {}
    # Caller-supplied entries override the defaults.
    default_pass_config.update(pass_config)

    if default_pass_config["conv_bn_fuse"]:
        model = fuse(model)
    if default_pass_config["remove_dropout"]:
        model = remove_dropout(model)
    if default_pass_config["mkldnn_layout_optimize"] is False:
        return model
    if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict):
        raise RuntimeError("mkldnn_layout_optimize config is not a dict")
    if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]:
        raise RuntimeError("Heuristic not found in mkldnn_layout_optimize config")
    use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"]["heuristic"]

    cur_tracer = tracer()
    fx_graph = cur_tracer.trace(copy.deepcopy(model))
    # Wrapping the graph in a GraphModule sets `fx_graph.owning_module`,
    # which heuristics such as `gen_mkl_autotuner` read back.
    fx_model = fx.GraphModule(cur_tracer.root, fx_graph)  # noqa: F841
    modules: Dict[str, nn.Module] = dict(model.named_modules())

    class MklSupport(Enum):
        NO = 1
        YES = 2
        UNKNOWN = 3

    # Inserts to_mkldnn and to_dense around every node we want to be a MKLDNN node.
    # If the op is in `mkldnn_supported` then we always treat it as a MKLDNN node.
    # However, if it's in `mkldnn_supported_unknown`, then we only treat it as
    # a MKLDNN node if its inputs are MKLDNN nodes.
    for node in list(fx_graph.nodes):
        supports_mkldnn = MklSupport.NO
        if node.op == 'call_module':
            cur_module = modules[node.target]
            if type(cur_module) in mkldnn_supported:
                supports_mkldnn = MklSupport.YES
                sample_parameter = next(cur_module.parameters(), None)
                if sample_parameter is not None:
                    assert(sample_parameter.dtype == torch.float), "this pass is only for torch.float modules"
                    assert(sample_parameter.device == torch.device('cpu')), "this pass is only for CPU modules"
        elif node.op == 'call_function':
            if node.target in mkldnn_supported:
                supports_mkldnn = MklSupport.YES
            elif node.target in mkldnn_supported_unknown:
                supports_mkldnn = MklSupport.UNKNOWN

        if supports_mkldnn == MklSupport.NO:
            continue
        if supports_mkldnn == MklSupport.UNKNOWN:
            # Non-Node args (e.g. scalar constants) have no `target`, so test the type first.
            if not any(isinstance(arg, fx.Node) and arg.target == 'to_dense' for arg in node.args):
                continue
        with fx_graph.inserting_before(node):
            mkldnn_args = fx.map_arg(node.args, lambda n: fx_graph.call_method('to_mkldnn', (n, )))

        node.args = cast(Tuple[fx.node.Argument], mkldnn_args)

        with fx_graph.inserting_after(node):
            dense_x = fx_graph.create_node('call_method', 'to_dense', (node,))
            node.replace_all_uses_with(dense_x)
            # replace_all_uses_with also rewired dense_x's own input to itself; restore it.
            dense_x.args = (node,)

    # Does pre-conversion of all modules into MKLDNN (when possible)
    old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules)
    fx_graph.old_modules = old_modules  # type: ignore[attr-defined]

    # optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b
    for node in fx_graph.nodes:
        if node.op == 'call_method' and node.target == 'to_dense':
            prv_node = node.args[0]
            for user in list(node.users):
                if user.op == 'call_method' and user.target == 'to_mkldnn':
                    user.replace_all_uses_with(prv_node)
                    fx_graph.erase_node(user)
            if len(node.users) == 0:
                fx_graph.erase_node(node)

    num_nodes = len(fx_graph.nodes)
    uf = UnionFind(num_nodes)

    def get_color(n):
        if hasattr(n, 'color'):  # Current node is part of a MKL subgraph
            return uf.find(n.color)
        if hasattr(n, 'start_color'):  # Current node is input to MKL subgraph
            return uf.find(n.start_color)
        return None

    # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists
    # of input nodes (which are only `to_mkldnn` calls), output nodes
    # (`to_dense` calls), and intermediate nodes, which are run entirely on
    # MKLDNN layout tensors.
    # Specifically, this code does a flood fill on a directed acyclic graph
    # (DAG), starting from each possible "start node" (i.e: `to_mkldnn` nodes).
    # If every node only had one input, this would be sufficient. However, in
    # the case that a node has multiple inputs coming from different start
    # nodes (i.e. colors), we need to join these 2 colors into 1. That's done
    # using a Disjoint Set Union.
    for cur_idx, node in enumerate(fx_graph.nodes):
        if node.op == 'call_method' and node.target == 'to_mkldnn':
            node.start_color = cur_idx
            uf.make_set(cur_idx)
        elif node.op == 'call_method' and node.target == 'to_dense':
            assert(get_color(node.args[0]) is not None)
            node.end_color = get_color(node.args[0])
        else:
            cur_colors = [c for c in (get_color(i) for i in node.all_input_nodes) if c is not None]
            if not cur_colors:
                continue
            cur_colors = sorted(cur_colors)
            node.color = cur_colors[0]
            for other_color in cur_colors[1:]:
                uf.join(cur_colors[0], other_color)

    # Group nodes by the root of their color to form the subgraphs.
    mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph))
    for node in fx_graph.nodes:
        if hasattr(node, 'color'):
            mkldnn_graphs[uf.find(node.color)].nodes.append(node)
        if hasattr(node, 'start_color'):
            mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node)
        if hasattr(node, 'end_color'):
            mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node)

    # Now that we have all the subgraphs, we need to decide which MKLDNN
    # subgraphs we actually want to keep in MKLDNN.
    for graph in mkldnn_graphs.values():
        if not use_mkl_heuristic(graph):
            # Rejected: bypass its to_mkldnn/to_dense conversions and restore
            # the original dense modules.
            for node in graph.start_nodes + graph.end_nodes:
                prv = node.args[0]
                node.replace_all_uses_with(prv)
                fx_graph.erase_node(node)
            reset_modules(graph.nodes, modules, old_modules)

    mkldnn_conversions = sum(
        1 for node in fx_graph.nodes
        if node.target == 'to_mkldnn' or node.target == 'to_dense'
    )
    logging.getLogger(__name__).info("mkldnn conversions: %d", mkldnn_conversions)
    fx_graph.lint()
    result = fx.GraphModule(model, fx_graph)
    return result


加:2022-03-22 20:35:19  更:2022-03-22 20:39:36 
