之前在做 conv 和 BN 算子融合时，偶然得知在 PyTorch 1.10 中可以借助 torch.fx 完成这类模型变换操作，故写下此学习记录。torch.fx 的主要功能是对 nn.Module 实例进行变换，也就是以编程的方式修改模型。
torch.fx 主要包含三个组件：符号追踪器（symbolic tracer）、中间表示（intermediate representation）和 Python 代码生成（Python code generation）。
import torch
class MyModule(torch.nn.Module):
    """Toy module: add a learned (3, 4) parameter, apply Linear(4, 5), clamp to [0, 1]."""

    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        # Python-level temporaries do not show up in the traced graph; node
        # names come from the traced operations (add / linear / clamp).
        shifted = x + self.param
        projected = self.linear(shifted)
        return projected.clamp(min=0.0, max=1.0)


module = MyModule()

from torch.fx import symbolic_trace

# Symbolically trace the instance. The resulting GraphModule exposes the IR
# through `.graph` and the regenerated Python source through `.code`.
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
print(symbolic_traced.graph)
"""
graph():
%x : [#users=1] = placeholder[target=x]
%param : [#users=1] = get_attr[target=param]
%add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
%linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
return clamp
"""
print(symbolic_traced.code)
"""
def forward(self, x):
param = self.param
add = x + param; x = param = None
linear = self.linear(add); add = None
clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
return clamp
"""
符号追踪器对模块的 forward 代码进行符号执行：传入的不是真实数据，而是称为 Proxy 的假输入，代码对它们执行的所有操作都会被记录下来。追踪的结果是计算图的中间表示 torch.fx.Graph，其中记录了所有的操作。具体来说，一个 Graph 由一系列 torch.fx.Node 组成。Node 是 Graph 的基本单元，对应一个操作；Node.op 记录了具体的操作类型，主要包括以下几种：placeholder、get_attr、call_function、call_module、call_method 和 output。
- 有关图的例子
import torch
import torch.fx
class MyModule(torch.nn.Module):
    """Example module whose traced graph exercises every common Node opcode."""

    def __init__(self):
        super().__init__()
        # Registered but never read by forward(), so it does not appear in the graph.
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        # `self.linear.weight` is read directly, which traces to a get_attr node.
        activated = self.linear(x + self.linear.weight).relu()
        row_sums = torch.sum(activated, dim=-1)
        return torch.topk(row_sums, 3)


m = MyModule()
gm = torch.fx.symbolic_trace(m)
# Prints one row per Node: opcode, name, target, args, kwargs.
gm.graph.print_tabular()
输出:
opcode name target args kwargs
------------- ------------- ----------------------------------------------------------- ------------------ -----------
placeholder x x () {}
get_attr linear_weight linear.weight () {}
call_function add <built-in function add> (x, linear_weight) {}
call_module linear linear (add,) {}
call_method relu relu (linear,) {}
call_function sum_1 <built-in method sum of type object at 0x00007FF80165E360> (relu,) {'dim': -1}
call_function topk <built-in method topk of type object at 0x00007FF80165E360> (sum_1, 3) {}
output output output (topk,) {}
定义一个模块 MyModule，实例化并追踪，然后调用 graph.print_tabular 方法把图的各个节点以表格形式打印出来。从打印结果可以看到，每个 Node 除了 op 之外，还有 name、target、args 和 kwargs。对于不同的 op，它们的含义略有区别：
- placeholder 是 Graph 的输入，output 是 Graph 的输出，它们的 target 和 name 相同。
- get_attr 用于获取 module 的参数（如 linear.weight）。
- call_function 是调用一个函数，target 指明具体的函数（如 add、torch.sum）。
- call_module 是调用子 module，target 就是子 module 的名称。
- call_method 是调用某个值（通常是 Tensor）上的方法，target 是方法名（如 relu）。

args 和 kwargs 就是该 op 对应的位置参数（tuple）和关键字参数（dict）。可以看到，很多 op 的 args 引用的正是其他 Node，各个 Node 由此建立起联系，从而构成了 Graph。
- 图的变换操作
最后一个组件是 Python 代码生成，即根据 Graph 的语义自动生成相应的可执行代码。torch.fx 做的是把一个 Module 转换为静态图，这和变换 Module 有什么关系呢？如果我们对追踪得到的 Graph 进行变换，再借助 Python 代码生成工具，不就可以达到变换一个 Module 的目的了吗？整个流程是 symbolic tracing -> intermediate representation -> transforms -> Python code generation，这就实现了从一个 Module 到另一个 Module 的 Python-to-Python 转换。流程如下所示：
import torch
import torch.fx
def transform(m: torch.nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    """
    Template for an FX Python-to-Python transform.

    The pipeline is: symbolic tracing -> Graph (IR) -> transform the Graph ->
    Python code generation (wrapping the Graph in a new GraphModule).

    Args:
        m: the module to transform.
        tracer_class: Tracer class used to turn `m` into a Graph.

    Returns:
        A GraphModule whose forward executes the (transformed) graph. As written,
        no transformation is applied, so the result behaves like `m`.
    """
    # Step 1: symbolic tracing -> intermediate representation.
    graph : torch.fx.Graph = tracer_class().trace(m)
    # Step 2: transform `graph` here (retarget, insert or erase nodes).
    # The template leaves it untouched, i.e. an identity transform.
    # Step 3: Python code generation -> a new nn.Module.
    return torch.fx.GraphModule(m, graph)
这里最终得到的torch.fx.GraphModule除了包含graph和code属性外就和正常的nn.Module一样,它的forward执行的就是graph的语义代码。这里来看一个修改Module的简单例子,这个例子中我们将模块中所有的torch.add()操作替换成 torch.mul() :
import torch
import torch.fx
class M(torch.nn.Module):
    """Minimal module whose forward is a single `torch.add` call."""

    def forward(self, x, y):
        result = torch.add(x, y)
        return result
def transform(m: torch.nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    """
    Return a GraphModule equivalent to `m` except that every `torch.add` call
    is replaced by `torch.mul`.

    Only `call_function` nodes whose target is exactly `torch.add` are
    rewritten; `x + y` (traced as `operator.add`) and `x.add(y)` (a
    `call_method` node) are left unchanged.

    Args:
        m: the module to transform.
        tracer_class: Tracer class used to turn `m` into a Graph.
    """
    graph : torch.fx.Graph = tracer_class().trace(m)
    # Retargeting a node in place keeps its args/kwargs, so only the op changes.
    for node in graph.nodes:
        if node.op == 'call_function' and node.target == torch.add:
            node.target = torch.mul
    # Sanity-check the mutated graph before generating code from it.
    graph.lint()
    return torch.fx.GraphModule(m, graph)
① 直接操作图：删除或添加节点（例如在某个节点之后插入一个 torch.relu() 调用）；② 子图重写：replace_pattern() 是 FX 提供的用于编辑图的查找-替换 API。
- 图变换的相关案例
① Replace one op ② Conv/Batch Norm fusion ③ replace_pattern: Basic usage ④ Quantization ⑤ Invert Transformation
结合之前做的 conv 和 BN 融合案例：在推理阶段把 BN 融合进 Conv、合成一个操作，可以提升推理速度（注意融合要求 conv 和 BN 都处于 eval 模式）。当时查到 torch.fx 可以很方便地实现这一点，具体代码如下（与 PyTorch 源码中的 torch/fx/experimental/optimization.py 基本一致）：
import torch.fx as fx
from torch.fx.node import Argument, Target
from torch.nn.utils.fusion import fuse_conv_bn_eval
from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx.passes.shape_prop import ShapeProp
import copy
from collections import defaultdict
import torch.utils.mkldnn as th_mkldnn
import operator
import time
import logging
from enum import Enum
def _parent_name(target : str) -> Tuple[str, str]:
"""
Splits a qualname into parent path and last atom.
For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
"""
*parent, name = target.rsplit('.', 1)
return parent[0] if parent else '', name
def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]):
    """
    Check whether `node` and its first argument form a chain of module calls
    whose module types are exactly those in `pattern`.

    The pair is compared as (node.args[0], node), i.e. producer first, so a
    pattern such as (nn.Conv2d, nn.BatchNorm2d) matches a BatchNorm node fed
    directly by a Conv node. Only the first two entries of `pattern` are used.
    Types are compared with `is`, so subclasses do not match.

    Args:
        pattern: expected module types, producer first.
        node: candidate consumer node.
        modules: qualified name -> module, e.g. dict(gm.named_modules()).
    """
    if not node.args:
        return False
    candidates: Tuple[Any, fx.Node] = (node.args[0], node)
    for wanted_type, candidate in zip(pattern, candidates):
        is_known_module_call = (
            isinstance(candidate, fx.Node)
            and candidate.op == 'call_module'
            and isinstance(candidate.target, str)
            and candidate.target in modules
        )
        if not is_known_module_call or type(modules[candidate.target]) is not wanted_type:
            return False
    return True
def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    """
    Swap the submodule called by the `call_module` node `node` for `new_module`.

    Both the `modules` lookup table and the attribute on the parent module
    (looked up in `modules` by its qualified name) are updated, so later
    lookups and the generated code see the new module.
    """
    target = node.target
    assert isinstance(target, str)
    owner_path, attr_name = _parent_name(target)
    modules[target] = new_module
    setattr(modules[owner_path], attr_name, new_module)
def fuse(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
    """
    Fuse Conv{1,2,3}d -> BatchNorm{1,2,3}d pairs into single convolutions for
    inference.

    When `inplace` is False (the default) the model is deep-copied before
    tracing; `inplace=True` skips that copy. A pair is fused only when:
      * the BatchNorm's input is the Conv's output and nothing else consumes
        that output, and
      * the BatchNorm tracks running statistics (otherwise it normalizes with
        batch statistics and cannot be folded into constant weights).
    Both modules must be in eval mode; `fuse_conv_bn_eval` enforces this.

    Returns:
        A new fx.GraphModule with the fused convolutions and without the
        fused BatchNorm nodes.
    """
    patterns = [(nn.Conv1d, nn.BatchNorm1d),
                (nn.Conv2d, nn.BatchNorm2d),
                (nn.Conv3d, nn.BatchNorm3d)]
    if not inplace:
        model = copy.deepcopy(model)
    fx_model = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    new_graph = copy.deepcopy(fx_model.graph)

    for pattern in patterns:
        for node in new_graph.nodes:
            if not matches_module_pattern(pattern, node, modules):
                continue
            conv_node = node.args[0]
            # Other consumers of the conv output still need the un-normalized
            # activation, so folding the BN into the conv would be wrong.
            if len(conv_node.users) > 1:
                continue
            conv = modules[conv_node.target]
            bn = modules[node.target]
            # Without running statistics (running_mean/var are None) the BN
            # has nothing constant to fold; fuse_conv_bn_eval would fail.
            if not bn.track_running_stats:
                continue
            fused_conv = fuse_conv_bn_eval(conv, bn)
            replace_node_module(conv_node, modules, fused_conv)
            # Route the BN's consumers to the (now fused) conv and drop the BN.
            node.replace_all_uses_with(conv_node)
            new_graph.erase_node(node)
    return fx.GraphModule(fx_model, new_graph)
def remove_dropout(model: nn.Module) -> nn.Module:
    """
    Removes all dropout layers from the module.

    Each `call_module` node whose submodule is a dropout layer (nn.Dropout,
    nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout or nn.FeatureAlphaDropout) is
    replaced by its input. These layers are identities in eval mode, so this
    pass is intended for inference-only models.

    Returns:
        A new fx.GraphModule produced by an fx.Transformer over the traced model.
    """
    dropout_types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d,
                     nn.AlphaDropout, nn.FeatureAlphaDropout)
    fx_model = fx.symbolic_trace(model)

    class DropoutRemover(torch.fx.Transformer):
        def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
            if isinstance(self.submodules[target], dropout_types):
                # Dropout layers take a single input; forward it unchanged.
                assert len(args) == 1
                return args[0]
            return super().call_module(target, args, kwargs)

    return DropoutRemover(fx_model).transform()
def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[fx.Node], outputs: List[fx.Node]):
    """
    Build a standalone GraphModule that runs a slice of an existing graph.

    Args:
        orig_module: module that owns the attributes/submodules the copied
            nodes refer to.
        nodes: nodes to copy, in topological order.
        inputs: nodes whose values become placeholders of the new graph
            (each placeholder reuses the source node's name).
        outputs: copied nodes whose values are returned, as a list.
    """
    sub_graph = fx.Graph()
    # Maps each source node to its counterpart in `sub_graph`.
    remap: Dict[fx.Node, fx.Node] = {src: sub_graph.placeholder(src.name) for src in inputs}
    for src in nodes:
        remap[src] = sub_graph.node_copy(src, remap.__getitem__)
    sub_graph.output([remap[dst] for dst in outputs])
    sub_graph.lint()
    return fx.GraphModule(orig_module, sub_graph)
# Targets the MKLDNN layout pass treats as supported on MKLDNN tensors.
mkldnn_supported = [
    # Module types, matched against type(module) for `call_module` nodes.
    nn.Conv2d, nn.Linear, nn.BatchNorm2d,
    nn.ReLU, nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d,
    # Functions, matched against the target of `call_function` nodes.
    torch.relu, torch.transpose, torch.sigmoid,
    F.relu, F.avg_pool2d, F.adaptive_avg_pool2d,
]
# Functions converted only when at least one input already comes from a
# `to_dense` node, i.e. the neighbouring op is already in MKLDNN layout.
mkldnn_supported_unknown = [operator.add, operator.mul]
# Dense module type -> factory called as factory(module, dtype) that builds the
# MKLDNN version. The BatchNorm factory ignores the dtype argument.
mkldnn_map = {
    nn.Conv2d: th_mkldnn.MkldnnConv2d,
    nn.Linear: th_mkldnn.MkldnnLinear,
    nn.BatchNorm2d: lambda bn_module, _dtype: th_mkldnn.MkldnnBatchNorm(bn_module),
}
def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]):
    """
    Replace every `call_module` target whose type has an entry in `mkldnn_map`
    with its MKLDNN counterpart, built with dtype torch.float.

    `modules` and the parent module attributes are updated in place via
    `replace_node_module`.

    Returns:
        A dict mapping each new MKLDNN module to a deep copy of the dense
        module it replaced, so the swap can later be undone.
    """
    dense_backups: Dict[nn.Module, nn.Module] = {}
    for node in (n for n in nodes if n.op == 'call_module'):
        assert isinstance(node.target, str)
        dense_module = modules[node.target]
        factory = mkldnn_map.get(type(dense_module))
        if factory is None:
            continue
        mkl_module = factory(dense_module, torch.float)
        assert isinstance(mkl_module, nn.Module)
        dense_backups[mkl_module] = copy.deepcopy(dense_module)
        replace_node_module(node, modules, mkl_module)
    return dense_backups
def reset_modules(nodes: List[fx.Node], modules: Dict[str, nn.Module], old_modules: Dict[nn.Module, nn.Module]):
    """
    Undo `modules_to_mkldnn` for the given nodes.

    For each `call_module` node whose current module is a key of `old_modules`,
    put the recorded original module back (in `modules` and on the parent).
    """
    for node in nodes:
        if node.op != 'call_module':
            continue
        assert isinstance(node.target, str)
        original = old_modules.get(modules[node.target])
        if original is not None:
            replace_node_module(node, modules, original)
class MklSubgraph:
    """A connected region of an FX graph considered for MKLDNN layout."""

    def __init__(self, fx_graph: fx.Graph):
        # Graph the region belongs to.
        self.fx_graph = fx_graph
        # Conversion nodes entering the region (`to_mkldnn` in optimize_for_inference).
        self.start_nodes: List[fx.Node] = []
        # Conversion nodes leaving the region (`to_dense` in optimize_for_inference).
        self.end_nodes: List[fx.Node] = []
        # Remaining nodes that belong to the region.
        self.nodes: List[fx.Node] = []
def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
    """
    This generates a heuristic that can be passed into `optimize_for_inference` that
    determines whether a subgraph should be run in MKL by running it with the example_inputs.

    Args:
        example_inputs: passed as ONE positional argument to
            `ShapeProp(fx_model).propagate(...)` the first time the heuristic
            runs. NOTE(review): `propagate` takes `*args`, so a tuple of several
            inputs is not unpacked here -- confirm against the model's forward.
        iters: number of timed calls per variant.
        warmup: number of untimed calls before each timed run.

    Returns:
        A callable taking an MklSubgraph that returns True when the subgraph ran
        faster with its MKLDNN modules and MKLDNN inputs than after its modules
        were reset to the dense originals.

    Example usage:
        heuristic = gen_mkl_autotuner(example_inputs, iters=10)
        fast_model = optimization.optimize_for_inference(
            model, {"mkldnn_layout_optimize": {"heuristic": heuristic}})
    """
    # Filled in lazily on the first call and then reused for every subgraph.
    fx_model = None
    old_modules = None

    def use_mkl_heuristic(graph: MklSubgraph) -> bool:
        nonlocal fx_model, old_modules
        input_nodes = graph.start_nodes
        if fx_model is None:
            fx_model = graph.fx_graph.owning_module
            # The MKLDNN -> dense module mapping that optimize_for_inference
            # stores on the graph as `old_modules`.
            old_modules = graph.fx_graph.old_modules
            # Run the example inputs once so nodes get shape information.
            ShapeProp(fx_model).propagate(example_inputs)
        # NOTE(review): newer ShapeProp versions record results in
        # node.meta['tensor_meta'] rather than `node.shape`; confirm this
        # attribute exists for the PyTorch version in use.
        sample_inputs = [torch.randn(node.shape) for node in input_nodes]
        # Subgraph outputs are the inputs of its end (conversion) nodes.
        output_args = cast(List[fx.Node], [node.args[0] for node in graph.end_nodes])
        submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args)

        def benchmark(f):
            # `warmup` untimed calls, then wall-clock time for `iters` calls.
            for _ in range(warmup):
                f()
            begin = time.time()
            for _ in range(iters):
                out = f()
            return time.time() - begin

        # Timed with the submodule's current (MKLDNN) modules, converting the
        # inputs to MKLDNN and the outputs back to dense.
        mkl_time = benchmark(lambda: [i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])])

        # Restore the dense modules inside `submodule` only, then time again on dense inputs.
        reset_modules(submodule.graph.nodes, dict(submodule.named_modules()), old_modules)
        no_mkl_time = benchmark(lambda: submodule(*sample_inputs))
        return mkl_time < no_mkl_time
    return use_mkl_heuristic
def use_mkl_length(graph: MklSubgraph) -> bool:
    """
    Default heuristic for `optimize_for_inference`: keep a subgraph in MKL
    layout only if it contains at least three nodes (presumably so the
    surrounding to_mkldnn/to_dense conversions are worth paying for).
    """
    node_count = len(graph.nodes)
    return node_count >= 3
class UnionFind:
    """
    Disjoint-set forest over the integers 0..n-1 with path compression and
    union by size.

    Every slot starts unset; call `make_set(v)` before using `v` with
    `find` or `join`.
    """

    def __init__(self, n):
        # parent[v] is None until make_set(v); a root is its own parent.
        self.parent: List[Optional[int]] = [None] * n
        # size[r] counts the members of the set rooted at r (meaningful for roots only).
        self.size: List[int] = [0] * n

    def make_set(self, v: int):
        """Turn `v` into a singleton set."""
        self.parent[v] = v
        self.size[v] = 1

    def find(self, v: int) -> int:
        """Return the root of v's set, pointing every node on the path at it."""
        par = self.parent[v]
        if v == par:
            return v
        assert(par is not None)
        self.parent[v] = self.find(par)
        return cast(int, self.parent[v])

    def join(self, a: int, b: int) -> int:
        """Merge the sets containing `a` and `b`; return the root of the merged set."""
        a, b = self.find(a), self.find(b)
        if a == b:
            return a
        # Union by size: hang the smaller tree under the larger tree's root.
        if self.size[a] < self.size[b]:
            a, b = b, a
        self.parent[b] = a
        self.size[a] += self.size[b]
        return a
def optimize_for_inference(
    model: torch.nn.Module,
    pass_config: Optional[Dict[str, Any]] = None,
    tracer: Type[fx.Tracer] = fx.Tracer
) -> torch.nn.Module:
    """
    Performs a set of optimization passes to optimize a model for the
    purposes of inference. Specifically, the passes that are run are:
    1. Conv/BN fusion
    2. Dropout removal
    3. MKL layout optimizations

    The third optimization takes a function `use_mkl_heuristic` that's used
    to determine whether a subgraph should be explicity run in MKL layout.

    Args:
        model: the module to optimize.
        pass_config: overrides merged into the defaults
            {"conv_bn_fuse": True, "remove_dropout": True,
             "mkldnn_layout_optimize": {"heuristic": use_mkl_length}}.
            Set "mkldnn_layout_optimize" to False to skip the MKL pass;
            otherwise it must be a dict whose "heuristic" is a callable that
            receives an MklSubgraph and returns True to keep it in MKL layout.
        tracer: Tracer class used to retrace the model for the MKL pass.

    Returns:
        The optimized module. If the MKL pass is disabled this is the result
        of the first two passes (or `model` itself if they are disabled too).

    Raises:
        RuntimeError: if "mkldnn_layout_optimize" is neither False nor a dict,
            or the dict lacks a "heuristic" entry.
        AssertionError: if a supported module with parameters is not a
            torch.float module on CPU.

    Note: As FX does not currently handle aliasing, this pass currently
    assumes nothing aliases. If that isn't true, use at your own risk.
    """
    default_pass_config = {
        "conv_bn_fuse": True,
        "remove_dropout": True,
        "mkldnn_layout_optimize": {'heuristic': use_mkl_length},
    }
    if pass_config is None:
        pass_config = {}
    default_pass_config.update(pass_config)

    if default_pass_config["conv_bn_fuse"]:
        model = fuse(model)
    if default_pass_config["remove_dropout"]:
        model = remove_dropout(model)
    if default_pass_config["mkldnn_layout_optimize"] is False:
        return model
    if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict):
        raise RuntimeError("mkldnn_layout_optimize config is not a dict")
    if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]:
        raise RuntimeError("Heuristic not found in mkldnn_layout_optimize config")
    use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"]["heuristic"]

    cur_tracer = tracer()
    fx_graph = cur_tracer.trace(copy.deepcopy(model))
    # Not otherwise used here: wrapping the graph makes fx_model its
    # owning_module, which heuristics such as gen_mkl_autotuner read.
    fx_model = fx.GraphModule(cur_tracer.root, fx_graph)
    modules: Dict[str, nn.Module] = dict(model.named_modules())

    class MklSupport(Enum):
        NO = 1
        YES = 2
        UNKNOWN = 3

    # Step 1: surround every MKLDNN-capable op with to_mkldnn (on its inputs)
    # and to_dense (on its output). UNKNOWN ops are wrapped only when at least
    # one input already comes from a to_dense node.
    for node in list(fx_graph.nodes):
        supports_mkldnn = MklSupport.NO
        if node.op == 'call_module':
            cur_module = modules[node.target]
            if type(cur_module) in mkldnn_supported:
                supports_mkldnn = MklSupport.YES
                sample_parameter = next(cur_module.parameters(), None)
                if sample_parameter is not None:
                    assert(sample_parameter.dtype == torch.float), "this pass is only for torch.float modules"
                    assert(sample_parameter.device == torch.device('cpu')), "this pass is only for CPU modules"
        elif node.op == 'call_function':
            if node.target in mkldnn_supported:
                supports_mkldnn = MklSupport.YES
            elif node.target in mkldnn_supported_unknown:
                supports_mkldnn = MklSupport.UNKNOWN

        if supports_mkldnn != MklSupport.NO:
            if supports_mkldnn == MklSupport.UNKNOWN:
                # Args may be non-Node constants (e.g. the `1` in `x + 1`),
                # which have no `.target`; only Nodes can be to_dense calls.
                if not any(isinstance(arg, fx.Node) and arg.target == 'to_dense' for arg in node.args):
                    continue
            with fx_graph.inserting_before(node):
                mkldnn_args = fx.map_arg(node.args, lambda n: fx_graph.call_method('to_mkldnn', (n, )))

            node.args = cast(Tuple[fx.node.Argument], mkldnn_args)

            with fx_graph.inserting_after(node):
                dense_x = fx_graph.create_node('call_method', 'to_dense', (node,))
                node.replace_all_uses_with(dense_x)
                # replace_all_uses_with also rewired dense_x's own input to
                # dense_x itself; point it back at `node`.
                dense_x.args = (node,)

    # Step 2: swap supported modules for MKLDNN versions and keep the
    # originals on the graph so heuristics can restore them.
    old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules)
    fx_graph.old_modules = old_modules  # type: ignore[attr-defined]

    # Step 3: drop to_dense -> to_mkldnn round trips between adjacent MKLDNN ops.
    for node in fx_graph.nodes:
        if node.op == 'call_method' and node.target == 'to_dense':
            prv_node = node.args[0]
            users = list(node.users)
            for user in users:
                if user.op == 'call_method' and user.target == 'to_mkldnn':
                    user.replace_all_uses_with(prv_node)
                    fx_graph.erase_node(user)
            if len(node.users) == 0:
                fx_graph.erase_node(node)

    num_nodes = len(fx_graph.nodes)
    uf = UnionFind(num_nodes)

    def get_color(n):
        if hasattr(n, 'color'):  # n is inside an MKL subgraph
            return uf.find(n.color)
        if hasattr(n, 'start_color'):  # n is a to_mkldnn entry into a subgraph
            return uf.find(n.start_color)
        return None

    # Step 4: group nodes into connected MKLDNN subgraphs. Each to_mkldnn node
    # opens a new color; a node with colored inputs takes the smallest of
    # their colors and merges the others into it; a to_dense node records
    # the color of the subgraph it closes.
    for cur_idx, node in enumerate(fx_graph.nodes):
        if node.op == 'call_method' and node.target == 'to_mkldnn':
            node.start_color = cur_idx
            uf.make_set(cur_idx)
        elif node.op == 'call_method' and node.target == 'to_dense':
            assert(get_color(node.args[0]) is not None)
            node.end_color = get_color(node.args[0])
        else:
            cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None]

            if len(cur_colors) == 0:
                continue
            assert(not any(i is None for i in cur_colors))
            cur_colors = sorted(cur_colors)
            node.color = cur_colors[0]
            for other_color in cur_colors[1:]:
                uf.join(cur_colors[0], other_color)

    mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph))
    for node in fx_graph.nodes:
        if hasattr(node, 'color'):
            mkldnn_graphs[uf.find(node.color)].nodes.append(node)
        if hasattr(node, 'start_color'):
            mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node)
        if hasattr(node, 'end_color'):
            mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node)

    # Step 5: for every subgraph the heuristic rejects, remove its conversion
    # nodes and restore the dense modules.
    for graph in mkldnn_graphs.values():
        if not use_mkl_heuristic(graph):
            for node in graph.start_nodes + graph.end_nodes:
                prv = node.args[0]
                node.replace_all_uses_with(prv)
                fx_graph.erase_node(node)
            reset_modules(graph.nodes, modules, old_modules)

    mkldnn_conversions = 0
    for node in fx_graph.nodes:
        if node.target == 'to_mkldnn' or node.target == 'to_dense':
            mkldnn_conversions += 1

    logging.info(f"mkldnn conversions: {mkldnn_conversions}")
    fx_graph.lint()
    result = fx.GraphModule(model, fx_graph)
    return result
注: 参考文章: https://zhuanlan.zhihu.com/p/428735136 文档地址:https://pytorch.org/docs/stable/fx.html
|