参考:https://pytorch.org/docs/stable/fx.html
Intro
? FX 是针对 torch.nn.module 而开发的工具,其能动态地获取 model 前向传播的执行过程,以便动态地增加、删除、改动、检查运算操作。其由三个主要组件组成:符号追踪器(Symbolic Tracer)、中间表示(Intermediate Representation, IR)和 Python 代码生成。这三个组件常常同时出现,如下面的例子:
import torch
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
module = MyModule()
from torch.fx import symbolic_trace
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
print(symbolic_traced.graph)
"""
graph():
%x : [#users=1] = placeholder[target=x]
%param : [#users=1] = get_attr[target=param]
%add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
%linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
return clamp
"""
print(symbolic_traced.code)
"""
def forward(self, x):
param = self.param
add = x + param; x = param = None
linear = self.linear(add); add = None
clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
return clamp
"""
- 符号追踪器会对代码执行“符号”。其喂入(自己生成的)假数据(Proxy),来执行代码。由 Proxy 经过、执行的代码会被记录下来。
- 中间表示是在 Trace 期间记录各种操作的“容器”。其由一个节点(Node)列表组成,这些节点表示了函数的输入、名字和返回值。
- Python 代码生成是一种代码生成工具,可以根据当前 IR 图的内容生成正确、可执行的 Python 代码。这代码是可以复制出来黏贴使用的,可以用于进一步配置模型的(
forward )定义。
? 总的来说,FX 的使用流程为:符号跟踪->中间表示->转换->Python代码生成。这是一种 Python-to-Python 的方法。FX 的精髓在于“Dynamic Transformation”,即当你需要对模型进行额外改动设计(如插入量化节点、算子 Fusion)时,不需要繁琐地针对模型的每一个部分来修改代码,只需要按照 FX 的流程来高效自动化地实现。
FX 定义的类对象
- GraphModule:是由
fx.Graph 生成而来的 nn.Module ,其有对应的 graph 、code 成员变量。当 graph 成员变量被重新赋值过,code 变量和 forward() 函数回自动重新生成。如果你编辑过 graph 的内容却没有重新赋值过,那你必须调用 recompile() 函数来更新信息。torch.fx.symbolic_trace() 函数作用完后 return 的就是 GraphModule 。 - Graph:是 FX 的 IR 图的主要数据结构,由一系列有序的
Node 组成。这一一系列的 Node 就构成了执行逻辑。torch.fx.Tracer.trace() 函数作用完后 return 的就是 Graph 。 - Node:是
graph 中操作的单位数据结构。大多数情况下,Node 代表了各种实体的调用方式,如输入(Input)、输出(Output)、算子(Operator)、已执行的成员函数(Method)和子模型(Module)。每个 Node 都有一个 op 属性,具体分类如下:
placeholder :表示整个模型的输入。get_attr :表示从模型层次结构中检索参数。call_function :表示将自由函数应用于某些值。call_module :表示将模型层次结构的 forward() 成员函数中的子模块应用于给定参数。call_method :表示对某值调用成员函数。output :这与打印 graph 输出中的 return 语句内容相对应。 - Proxy:在符号追踪期间会用到。其本质上是一个
Node Wrapper ,用于流经程序的执行过程并记录下所有的操作(被调用的 torch function、method 和 operator)。若没有主动设置的话,Pytorch 会生成默认的 Proxy 用于符号追踪 。
Example for Transformation
? 对模型的图进行额外改动的方法有很多,如直接获取图并修改图(Direct Graph Manipulation),或通过在 GraphModule 模型上间接获取图来修改图(GraphModule Modification)。
Direct Graph Manipulation
简单替换 Node (利用 Pattern)
- 遍历
GraphModule 的 Graph 中的所有 Node 。 - 判断当前
Node 是否满足替换要求(可以用 target 属性作为判断条件)。 - 创建一个新的
Node 并插入到 Graph 中。 - 使用 FX 内置的 replace_all_uses_with 函数来将要被替换
Node 的输入输出流(flow)重新定向到新 Node 身上。 - 从
Graph 中删除旧 Node。 - 调用
recompile() 函数来更新 GraphModule 。
? 下面一个例子展示 FX 如何将任何加法操作替换成二进制与(AND)运算:
import torch
from torch.fx import symbolic_trace
import operator
class M(torch.nn.Module):
def forward(self, x, y):
return x + y, torch.add(x, y), x.add(y)
traced = symbolic_trace(M())
patterns = set([operator.add, torch.add, "add"])
for n in traced.graph.nodes:
if any(n.target == pattern for pattern in patterns):
with traced.graph.inserting_after(n):
new_node = traced.graph.call_function(torch.bitwise_and, n.args, n.kwargs)
n.replace_all_uses_with(new_node)
traced.graph.erase_node(n)
traced.recompile()
复杂替换 Node(利用 Proxy)
? 另一个修改 Graph 的方式是利用 Proxy ,再在一次主动 Trace 的过程中复制 Node 、构建新 Node 来组成新的 Graph 。
import torch
import torch.fx as fx
import torch.nn.functional as F
class M(torch.nn.Module):
def forward(self, x, y):
o = F.relu(x) + F.relu(y)
return o
def relu_decomposition(x):
return (x > 0) * x
decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition
def decompose_relu(model: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph : fx.Graph = tracer_class().trace(model)
new_graph = fx.Graph()
mapping_table = {}
for node in graph.nodes:
if node.op == 'call_function' and node.target in decomposition_rules:
proxy_args = []
for x in node.args:
if isinstance(x, fx.Node):
proxy_args.append(fx.Proxy(mapping_table[x.name]))
else:
proxy_args.append(x)
output_proxy = decomposition_rules[node.target](*proxy_args)
new_node = output_proxy.node
mapping_table[node.name] = new_node
else:
def node_mapping(x):
return mapping_table[x.name]
new_node = new_graph.node_copy(node, node_mapping)
mapping_table[node.name] = new_node
return fx.GraphModule(model, new_graph)
decompose_relu(M())
? Proxy 可以想象为一个“穿线器”:绑定 Node 后,在经过新的 Node 时能自动“串”好连接关系并加入到原 Graph 中。能记录此时的“线头”,即记录访问到的 Node 。
GraphModule Modification
下面一个例子展示 FX 是如何通过 GraphModule 间接替换 torch.add() 为 torch.mul() 的:
import torch
import torch.fx as fx
class M(torch.nn.Module):
def forward(self, x, y):
return torch.add(x, y)
def transform(m: torch.nn.Module) -> torch.nn.Module:
gm : fx.GraphModule = fx.symbolic_trace(m)
for node in gm.graph.nodes:
if node.op == 'call_function':
if node.target == torch.add:
node.target = torch.mul
gm.recompile()
gm.graph.lint()
return gm
transform(M())
符号追踪的局限性(注意事项)
控制流(Control Flow)
? PyTorch 官方将 if 语句、循环语句等具有选择/判断性质的语句称为控制流。在 FX语境中,控制流又可以分为动态控制流(Dynamic Control Flow)和静态控制流(Static Control Flow)。
? FX 无法 trace 动态控制流,但可以 trace 判断条件明确的静态控制流。
动态控制流
? 若控制流的判断条件含有运算变量(Input Tensor)参与,那么该控制流就称为动态控制流,如:
def func_to_trace(x):
if x.sum() > 0:
return torch.relu(x)
else:
return torch.neg(x)
? 此时对该函数使用 trace 功能就会报错:
"""
raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""
静态控制流
? 类推可知,若控制流的判断条件无运算变量参与,也即判断条件的变量不参与流(Flow)计算,那么该控制流就称为静态控制流,如:
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self, do_activation : bool = False):
super().__init__()
self.do_activation = do_activation
self.linear = torch.nn.Linear(512, 512)
def forward(self, x):
x = self.linear(x)
if self.do_activation:
x = torch.relu(x)
return x
? 若想 trace 静态控制流,就需要明确判断条件,即给判断变量显式赋值:
without_activation = MyModule(do_activation=False)
traced_without_activation = torch.fx.symbolic_trace(without_activation)
非 torch 函数
? 有些函数没有__torch_function__ 属性,例如 Python 自带的函数或 math 库中的函数,无法被 trace 追踪。例如,当你的模型里调用了 len() 函数,那么进行 trace 时会报错:
"""
raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want ")
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""
? 那么需要使用 wrap() API 来将普通函数包装成 torch 性质的函数:
torch.fx.wrap('len')
traced = torch.fx.symbolic_trace(normalize)
查看 Graph 内容
通过 print() 函数
如:
print(traced_model.graph)
"""
graph():
%x : [#users=1] = placeholder[target=x]
%param : [#users=1] = get_attr[target=param]
%add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
%linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
return clamp
"""
通过 print_tabular() 函数
? 通过调用 print_tabular() 函数就可以以 tabular 的格式输出 IR 图:
traced_model.graph.print_tabular()
"""
opcode name target args kwargs
------------- -------- ----------------------- ----------- ------------------------
placeholder x x () {}
get_attr param param () {}
call_function add_1 <built-in function add> (x, param) {}
call_module linear_1 linear (add_1,) {}
call_method clamp_1 clamp (linear_1,) {'min': 0.0, 'max': 1.0}
output output output (clamp_1,) {}
"""
总结
? 目前,FX 没有提供任何方式来保证/验证运算符在语法上是有效的。也就是说,任何新(定义)加入的运算符都必须由用户自己来保证其正确性。
? 最后官网建议的一点是,你在对 Graph 做变换时,应该让整个程序的输入 torch.nn.Module ,然后获取对应的 Graph ,做出修改,最后再返回一个 torch.nn.Module 。这样更方便后续工作,比如又传入下一段 FX 代码中。
? 以上总结如有谬误,还请包涵、指正。
|