torch.fx
是一个用于转换(transform)PyTorch模型(即nn.Module
)的工具包。从torch 1.10
开始,工具包不再处于beta阶段,torch.fx
成为了PyTorch的稳定功能。
最近我比较闲,按照文档随便试了一下torch.fx
的功能,立马意识到,这玩意真的挺有用!torch.fx
能将nn.Module
转换为一个图结构,图的节点保存着当前网络节点前向时的输入,输出和参数,以及网络结构本身。这个图结构保存的信息足够多,api丰富。我一直苦于看不懂使用Tensorboard的add_graph
方法得到的网络结构图,就尝试用graphviz可视化torch.fx
得到的图,发现效果确实不错,比Tensorboard的结果清晰不少。
比如可视化ResNet18网络中的某个残差结构:
在torch.fx
和graphviz
的帮助下,写一个绘制不同网络结构的通用程序只需要100多行代码。其中包含将nn.Module
网络转换为图结构、提取图结构中每个节点和边的信息、可视化节点与边这三部分代码。
将网络转换为图结构#
为了实现对nn.Module
的编辑转换,torch.fx
包含三个主要的类,分别负责符号追踪(symbolic tracer),中间表示(intermediate representation),生成Python代码(Python code generation)。其中符号追踪类torch.fx.Tracer
就负责将nn.Module
网络转换为中间表示,即torch.fx.Graph
类实例。类torch.fx.Graph
表示计算图,而torch.fx.Node
表示计算图中的节点。
让我们首先定义一个包含多种操作的nn.Module
,用于测试。
import torch
import torch.nn as nn
import torch.nn.functional as F
class TestModel(nn.Module):
def __init__(self):
super(TestModel, self).__init__()
self.bias = nn.Parameter(torch.randn(1))
self.main = nn.Sequential(nn.Conv2d(3, 4, 1), nn.ReLU(True))
self.skip = nn.Conv2d(2, 4, 3, stride=1, padding=1)
def forward(self, x, y):
x = self.main(x)
y = (self.skip(y)+self.bias).clamp(0, 1)
x_size = x.size()[-2:]
y = F.interpolate(y, x_size, mode="bilinear", align_corners=False)
return torch.sigmoid(x) + y
接着,利用
torch.fx.Tracer
转换我们定义的网络:
tm = TestModel()
# 调用`torch.fx.symbolic_trace(m)`与调用
# `torch.fx.Tracer().trace(m)`等价
symbolic_traced : torch.fx.GraphModule = torch.fx.symbolic_trace(tm)
torch.fx.GraphModule
是从torch.fx.Graph
转换生成的nn.Module
子类,可以当作一个正常的nn.Module
来使用,除此以外,其.graph
属性就是我们要获取的网络计算图(torch.fx.Graph
类)。
总结一下,转换一个nn.Module
为计算图torch.fx.Graph
只需要调用函数
torch.fx.symbolic_trace(model).graph
。
提取计算图信息#
torch.fx.Graph
是torch.fx
中间表示最关键的数据结构,存储着计算图中计算节点torch.fx.Node
列表,每个节点结构保存着节点的计算细节以及节点之间的调用关系。
将我们得到的计算图print出来:
print(symbolic_traced.graph)
"""
graph():
%x : [#users=1] = placeholder[target=x]
%y : [#users=1] = placeholder[target=y]
%main_0 : [#users=1] = call_module[target=main.0](args = (%x,), kwargs = {})
%main_1 : [#users=2] = call_module[target=main.1](args = (%main_0,), kwargs = {})
%skip : [#users=1] = call_module[target=skip](args = (%y,), kwargs = {})
%bias : [#users=1] = get_attr[target=bias]
%add : [#users=1] = call_function[target=operator.add](args = (%skip, %bias), kwargs = {})
%clamp : [#users=1] = call_method[target=clamp](args = (%add, 0, 1), kwargs = {})
%size : [#users=1] = call_method[target=size](args = (%main_1,), kwargs = {})
%getitem : [#users=1] = call_function[target=operator.getitem](args = (%size, slice(-2, None, None)), kwargs = {})
%interpolate : [#users=1] = call_function[target=torch.nn.functional.interpolate](args = (%clamp,), kwargs = {size: %getitem, scale_factor: None, mode: bilinear, align_corners: False, recompute_scale_factor: None})
%sigmoid : [#users=1] = call_function[target=torch.sigmoid](args = (%main_1,), kwargs = {})
%add_1 : [#users=1] = call_function[target=operator.add](args = (%sigmoid, %interpolate), kwargs = {})
return add_1
"""
除了第一行的graph()
和最后一行的return
,中间每行都是一个单独的计算图节点,每个以%
开头的变量都是计算图节点。每行节点名后:
展示的就是节点的主要信息。在args
和kwargs
经常会出现以%
开头的变量名,这些计算节点就是当前计算节点的参数,反应了节点与节点之间的依赖关系,即图中的边。
可以发现,torch.fx.Graph
结构中最关键的就是torch.fx.Node
列表,所有信息都在这些计算节点中保存。
torch.fx.Node#
torch.fx.Node
保存着计算图中单个计算节点的所有信息,说明当前要完成什么类型的计算(.op
),计算是调用什么完成的(.target
),以及完成计算时的参数是什么(.args
和.kwargs
)。每一个Node有一个.name
属性,唯一地标识当前节点,即不同节点的.name
属性一定是不同的。
下表展示了不同类型(.op
)的计算节点的不同属性的值代表什么。
.op |
描述 | .name |
.target |
.args |
.kwargs |
---|---|---|---|---|---|
placeholder |
网络输入 | 输入变量名 | 输入变量名 | 主网络前向时的参数数组 | |
get_attr |
取网络参数 | 参数名称 | 参数在主网络拓扑中的完整名 | ||
call_function |
调用函数 | 函数名称 | 要调用的函数本身 | 调用函数时的参数数组 | 调用函数时的参数字典 |
call_module |
子网络前向 | 子网络在主网络拓扑中的完整名 | 子网络在主网络拓扑中的完整名 | 网络前向时的参数数组 | 网络前向时的参数字典 |
call_method |
调用类方法 | 类方法名 | 类方法名 | 调用类方法时的参数数组 | 调用类方法时的参数字典 |
output |
网络输出 | args[0] 就是主网络输出 |
将由自定义的TestModel
类得到的计算图的不同节点的关键属性值打印出来:
symbolic_traced.graph.print_tabular()
"""
opcode name target args kwargs
------------- ----------- -------------------------------------------------------------- ----------------------------- -------------------------------------------------------------------------------------------------------------------
placeholder x x () {}
placeholder y y () {}
call_module main_0 main.0 (x,) {}
call_module main_1 main.1 (main_0,) {}
call_module skip skip (y,) {}
get_attr bias bias () {}
call_function add <built-in function add> (skip, bias) {}
call_method clamp clamp (add, 0, 1) {}
call_method size size (main_1,) {}
call_function getitem <built-in function getitem> (size, slice(-2, None, None)) {}
call_function interpolate <function interpolate at 0x000002A12F1E18B0> (clamp,) {'size': getitem, 'scale_factor': None, 'mode': 'bilinear', 'align_corners': False, 'recompute_scale_factor': None}
call_function sigmoid <built-in method sigmoid of type object at 0x00007FFF9EC8F530> (main_1,) {}
call_function add_1 <built-in function add> (sigmoid, interpolate) {}
output output output (add_1,) {}
"""
每个计算图节点的.name
都和.target
对应,为了避免.name
一致,后出现的节点的.name
会再增加一个出现次数项_1
。
显然,可以用.name
来作为绘制网络结构图时每个节点的唯一标识符。
但用什么来表示当前的节点进行的计算呢(用作绘制网络节点时的label)?call_function
、call_method
和get_attr
以及输入输出节点似乎只需要把.target
整理一下就可以作为label,但call_module
类型的节点,除了当前节点子网络在主网络中完整名称,还需要知道网络本身的信息。幸运的是,torch.nn.Module
类提供一个
.get_submodule
方法,可以根据名称获取子网络。比如A
网络有子网络net_b
,而net_b
又有子网络net_c
,net_c
是一个简单的nn.Conv2d
,要想获取到这个nn.Conv2d
这个类实例本身,可以用A.get_submodule(".net_b.net_c")
,这里的".net_b.net_c"
就是我们一直再说的“子网络在主网络拓扑中的完整名称”。获取到网络本身后,将其转换为字符串(str(net)
)就能得到网络本身的描述了,可以用作label。
另外call_function
类型节点的.target
是可调用的函数本身,为了更用户友好地print这个函数,我研究了半天,最后发现打印torch.nn.Graph
的结果里,[target=<...>]
中的结果就挺好,看了一下torch.fx
的源码,发现torch.fx.Node
的_pretty_print_target
方法就是做这个事情的。直接调用node._pretty_print_target()
,就可以得到一个节点很用户友好的label啦。
除了节点的计算本身的信息,节点的输入和输出也很关键。节点的输入节点按照我们刚才的分析比较好找:迭代节点的.args
和.kwargs
属性中的存储的调用参数,参数中的节点就是当前节点的输入节点。torch.fx.Node
类干脆专门添加了一个属性.all_input_nodes
,在内部执行了迭代过程,返回所有作为输入的计算节点。显然,每个输入节点和当前节点之间就构成了有向图中的一条边。
而网络的输出却很难获取,尤其是输出的Tensor的形状等等信息,这对我们分析网络很有帮助,但torch.fx.Node
本身却没有存储相关信息。说来也确实是,转换网络为计算图的过程中,根本没有指定输入Tensor的这一步啊?那每个计算图节点的输出的Tensor尺寸啥的信息怎么获取呢?这里就要引入torch.fx.Interpreter
类了。
torch.fx.Interpreter
#
torch.fx.Interpreter
负责按照torch.fx.Graph
描述的计算图,一个节点一个节点地执行计算。具体而言,torch.fx.Interpreter
在使用一个torch.fx.GraphModule
初始化后,当执行其.run()
方法时,按照计算图,分别用.run_node()
执行一个又一个torch.fx.Node
节点。执行计算图节点时,按照不同节点类型,分别调用.placeholder()
、.get_attr()
…等方法执行节点。这一过程中的映射执行关系可以可视化为:
run()
└─run_node()
├─placeholder()
├─get_attr()
├─call_function()
├─call_method()
├─call_module()
└─output()
因此,可以覆盖torch.fx.Interpreter
的.run_node()
方法,记录调用不同op对应的方法后的结果:
class ResultProbe(torch.fx.Interpreter):
def run_node(self, n: torch.fx.Node) -> Any:
try:
# 执行计算图节点,并保存结果
result = super().run_node(n)
except Exception:
traceback.print_exc()
raise RuntimeError(
f"ShapeProp error for: "
f"node={n.format_node()} with " f"meta={n.meta}"
)
find_tensor_in_result = False
def extract_tensor_meta(obj):
# 找到结果中的Tensor,记录其主要信息
if isinstance(obj, torch.Tensor):
nonlocal find_tensor_in_result
find_tensor_in_result = True
# _extract_tensor_metadata函数的细节省略,
# 可以参考文章最后的完整代码链接。
return _extract_tensor_metadata(obj)
else:
return obj
# torch.fx.node.map_aggregate会迭代result,
# 对result中的每个元素执行给定的函数
# 比如result是list时会对其每个元素调用指定的函数,
# result是dict是会对其每个value执行给定的函数。
n.meta["result"] = torch.fx.node.map_aggregate(
result,
extract_tensor_meta
)
n.meta["find_tensor_in_result"] = find_tensor_in_result
return result
自定义类ResultProbe
覆盖了torch.fx.Interpreter
的run_node
方法,然后将结果存储到torch.fx.Node
的.meta
属性中。另外,还保存了一个Flag:find_tensor_in_result
,以此指示当前计算节点的输出中是否包含Tensor,因为我们可能更关注那些真正处理Tensor的计算节点,而对如同求Tensor尺寸以供后续操作这样的计算节点兴趣乏乏。
接着,用ResultProbe
执行网络前向:
args = (torch.randn(1, 3, 16, 16), torch.randn(1, 2, 8, 8))
kwargs = dict()
# run的参数就是TestModel的参数
ResultProbe(symbolic_traced).run(*args, **kwargs)
执行完毕后,每个Node节点的.meta
属性就包含了"result"
和"find_tensor_in_result"
这两个信息,后续直接将其打印出来即可。
可视化节点与边#
我们利用Graphviz来实现图的可视化。 Graphviz是一个开源的图可视化软件,它定义了一种叫做DOT的图描述语言,能将用dot语言描述的图转换为svg、pdf、png等多种格式,将图可视化。可以按照 Graphviz Download/中的介绍安装Graphviz。Python库 graphviz则提供了Graphviz的API接口,能通过调用Python命令生成dot文件,接着利用Graphviz转换为图片。我们就利用graphviz包实现网络结构图的可视化。
在Ubuntu系统上,可以用下面的命令安装所有依赖:
# 安装Graphviz二进制文件
sudo apt install graphviz
# 安装Python库graphviz
pip install graphviz
GraphvizOnline则可以在线实现dot语言与图片的相互转换,可以在这个网页上大概浏览下dot语言是多么简洁。
Python库graphviz的使用方法也很简单:
import graphviz
dot = graphviz.Digraph()
# 只需要不断调用`.node`和`.edge`就能增加节点和边
dot.node('A', 'King Arthur')
dot.node('B', 'Sir Bedevere the Wise')
dot.node('L', 'Sir Lancelot the Brave')
dot.edge('B', 'L', constraint='false')
dot.edges(['AB', 'AL'])
# 生成图片
dot.render('test/round-table.gv', view=False)
dot语言接受HTML格式的label,因此可以生成HTML表格作为图中每个节点的label,以此表示更多信息。
借助graphviz可视化我们先前得到的计算图就很简单了:
# 可视化单个节点的核心代码:
def single_node(model: torch.nn.Module, graph: graphviz.Digraph, node: torch.fx.Node):
node_label = node_label_html(model, node) # 生成当前节点的label
node_kwargs = dict(shape="plaintext")
graph.node(node.name, node_label, **node_kwargs) # 在Graphviz图中添加当前节点
# 遍历当前节点的所有输入节点,添加Graphviz图中的边
for in_node in node.all_input_nodes:
edge_kwargs = dict()
if (
not node.meta["find_tensor_in_result"]
or not in_node.meta["find_tensor_in_result"]
):
# 如果当前节点的输入和输出中都没有Tensor,就把当前边置为浅灰色虚线,弱化显示
edge_kwargs.update(dict(style="dashed", color="lightgrey"))
# 添加当前边
graph.edge(in_node.name, node.name, **edge_kwargs)
其中node_label_html
函数就是一个把前文得到的Node的信息拼接在一个HTML表格中,然后返回HTML代码的函数。就不详细讲解如何拼接HTML字符串了。
最后,生成网络结构图的总体代码就很简单了:
def model_graph(model: torch.nn.Module, *args, **kwargs) -> graphviz.Digraph:
# 将nn.Module转换为torch.fx.GraphModule,获取计算图
symbolic_traced: torch.fx.GraphModule = torch.fx.symbolic_trace(model)
# 执行一下网络,以此获取每个节点输入输出的具体信息
ResultProbe(symbolic_traced).run(*args, **kwargs)
# 定义一个Graphviz网络
graph = graphviz.Digraph("model", format="svg", node_attr={"shape": "plaintext"})
for node in symbolic_traced.graph.nodes: # 遍历所有节点
single_node(model, graph, node)
return graph
model = TestModel()
graph = model_graph(model, torch.randn(1, 3, 16, 16), torch.randn(1, 2, 8, 8))
graph.render(directory="test", view=False)
通过上述代码得到的TestModel
的可视化结果如图所示:
图中虚线部分表示x
在送入Conv2d
和ReLU
后,被调用.size
方法,获取特征图尺寸,用于后续缩放y
经过一系列变换后的结果。
而其它实线部分就是我们想要得到的PyTorch网络计算图了。
总结#
到目前为止,我们仅仅使用了torch.fx
最基础的功能。torch.fx
更为关键的nn.Module
编程转换功能还都没有涉及到。
我们所完成的绘制网络结构图的小工具也有很多还可以改进的地方,
比如通过覆盖torch.fx.Tracer
的is_leaf_module
方法实现对网络可视化层级的控制,
避免可视化一些如ResidualBlock,Transformer等一些很基础的网络组件,以防生成的计算图非常庞大,难以辨识。
可视化网络结构的完整代码可以在我的 gist上找到。最后再用我们实现的小工具可视化一些经典的网络吧: UNet, ResNet18, Mobilenet v2。
最后,因为本文这个小工具会可视化出所有节点,类似于DenseNet这样的网络,连接过于稠密,画出来的计算图也基本无法肉眼辨识。等后续用重写is_leaf_module
的方法改进小工具后再说。现在先来欣赏下DenseNet121的结构吧:
DenseNet121,注意,要按住Ctrl
键后狂滑鼠标滚轮缩小网站才能看到网络全貌。