利用torch.fx提取PyTorch网络结构信息绘制网络结构图


torch.fx是一个用于转换(transform)PyTorch模型(即nn.Module)的工具包。从torch 1.10开始,工具包不再处于beta阶段,torch.fx成为了PyTorch的稳定功能。

最近我比较闲,按照文档随便试了一下torch.fx的功能,立马意识到,这玩意真的挺有用!torch.fx能将nn.Module转换为一个图结构,图的节点保存着当前网络节点前向时的输入,输出和参数,以及网络结构本身。这个图结构保存的信息足够多,api丰富。我一直苦于看不懂使用Tensorboard的add_graph方法得到的网络结构图,就尝试用graphviz可视化torch.fx得到的图,发现效果确实不错,比Tensorboard的结果清晰不少。

比如可视化ResNet18网络中的某个残差结构:

torch.fxgraphviz的帮助下,写一个绘制不同网络结构的通用程序只需要100多行代码。其中包含将nn.Module网络转换为图结构、提取图结构中每个节点和边的信息、可视化节点与边这三部分代码。

将网络转换为图结构

为了实现对nn.Module的编辑转换,torch.fx包含三个主要的类,分别负责符号追踪(symbolic tracer),中间表示(intermediate representation),生成Python代码(Python code generation)。其中符号追踪类torch.fx.Tracer就负责将nn.Module网络转换为中间表示,即torch.fx.Graph类实例。类torch.fx.Graph表示计算图,而torch.fx.Node表示计算图中的节点。

让我们首先定义一个包含多种操作的nn.Module,用于测试。

import torch
import torch.nn as nn
import torch.nn.functional as F

class TestModel(nn.Module):
    def __init__(self):
        super(TestModel, self).__init__()
        self.bias = nn.Parameter(torch.randn(1))
        self.main = nn.Sequential(nn.Conv2d(3, 4, 1), nn.ReLU(True))
        self.skip = nn.Conv2d(2, 4, 3, stride=1, padding=1)

    def forward(self, x, y):
        x = self.main(x)
        y = (self.skip(y)+self.bias).clamp(0, 1)
        x_size = x.size()[-2:]
        y = F.interpolate(y, x_size, mode="bilinear", align_corners=False)
        return torch.sigmoid(x) + y

接着,利用 torch.fx.Tracer转换我们定义的网络:

tm = TestModel()
# 调用`torch.fx.symbolic_trace(m)`与调用
# `torch.fx.Tracer().trace(m)`等价
symbolic_traced : torch.fx.GraphModule = torch.fx.symbolic_trace(tm)

torch.fx.GraphModule是从torch.fx.Graph转换生成的nn.Module子类,可以当作一个正常的nn.Module来使用,除此以外,其.graph属性就是我们要获取的网络计算图(torch.fx.Graph类)。

总结一下,转换一个nn.Module为计算图torch.fx.Graph只需要调用函数 torch.fx.symbolic_trace(model).graph

提取计算图信息

torch.fx.Graphtorch.fx中间表示最关键的数据结构,存储着计算图中计算节点torch.fx.Node列表,每个节点结构保存着节点的计算细节以及节点之间的调用关系。

将我们得到的计算图print出来:

print(symbolic_traced.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %y : [#users=1] = placeholder[target=y]
    %main_0 : [#users=1] = call_module[target=main.0](args = (%x,), kwargs = {})
    %main_1 : [#users=2] = call_module[target=main.1](args = (%main_0,), kwargs = {})
    %skip : [#users=1] = call_module[target=skip](args = (%y,), kwargs = {})
    %bias : [#users=1] = get_attr[target=bias]
    %add : [#users=1] = call_function[target=operator.add](args = (%skip, %bias), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%add, 0, 1), kwargs = {})
    %size : [#users=1] = call_method[target=size](args = (%main_1,), kwargs = {})
    %getitem : [#users=1] = call_function[target=operator.getitem](args = (%size, slice(-2, None, None)), kwargs = {})
    %interpolate : [#users=1] = call_function[target=torch.nn.functional.interpolate](args = (%clamp,), kwargs = {size: %getitem, scale_factor: None, mode: bilinear, align_corners: False, recompute_scale_factor: None})
    %sigmoid : [#users=1] = call_function[target=torch.sigmoid](args = (%main_1,), kwargs = {})
    %add_1 : [#users=1] = call_function[target=operator.add](args = (%sigmoid, %interpolate), kwargs = {})
    return add_1
"""

除了第一行的graph()和最后一行的return,中间每行都是一个单独的计算图节点,每个以%开头的变量都是计算图节点。每行节点名后:展示的就是节点的主要信息。在argskwargs经常会出现以%开头的变量名,这些计算节点就是当前计算节点的参数,反应了节点与节点之间的依赖关系,即图中的边。

可以发现,torch.fx.Graph结构中最关键的就是torch.fx.Node列表,所有信息都在这些计算节点中保存。

torch.fx.Node

torch.fx.Node保存着计算图中单个计算节点的所有信息,说明当前要完成什么类型的计算(.op),计算是调用什么完成的(.target),以及完成计算时的参数是什么(.args.kwargs)。每一个Node有一个.name属性,唯一地标识当前节点,即不同节点的.name属性一定是不同的。

下表展示了不同类型(.op)的计算节点的不同属性的值代表什么。

.op 描述 .name .target .args .kwargs
placeholder 网络输入 输入变量名 输入变量名 主网络前向时的参数数组
get_attr 取网络参数 参数名称 参数在主网络拓扑中的完整名
call_function 调用函数 函数名称 要调用的函数本身 调用函数时的参数数组 调用函数时的参数字典
call_module 子网络前向 子网络在主网络拓扑中的完整名 子网络在主网络拓扑中的完整名 网络前向时的参数数组 网络前向时的参数字典
call_method 调用类方法 类方法名 类方法名 调用类方法时的参数数组 调用类方法时的参数字典
output 网络输出 args[0]就是主网络输出

将由自定义的TestModel类得到的计算图的不同节点的关键属性值打印出来:

symbolic_traced.graph.print_tabular()
"""
opcode         name         target                                                          args                           kwargs
-------------  -----------  --------------------------------------------------------------  -----------------------------  -------------------------------------------------------------------------------------------------------------------
placeholder    x            x                                                               ()                             {}
placeholder    y            y                                                               ()                             {}
call_module    main_0       main.0                                                          (x,)                           {}
call_module    main_1       main.1                                                          (main_0,)                      {}
call_module    skip         skip                                                            (y,)                           {}
get_attr       bias         bias                                                            ()                             {}
call_function  add          <built-in function add>                                         (skip, bias)                   {}
call_method    clamp        clamp                                                           (add, 0, 1)                    {}
call_method    size         size                                                            (main_1,)                      {}
call_function  getitem      <built-in function getitem>                                     (size, slice(-2, None, None))  {}
call_function  interpolate  <function interpolate at 0x000002A12F1E18B0>                    (clamp,)                       {'size': getitem, 'scale_factor': None, 'mode': 'bilinear', 'align_corners': False, 'recompute_scale_factor': None}
call_function  sigmoid      <built-in method sigmoid of type object at 0x00007FFF9EC8F530>  (main_1,)                      {}
call_function  add_1        <built-in function add>                                         (sigmoid, interpolate)         {}
output         output       output                                                          (add_1,)                       {}
"""

每个计算图节点的.name都和.target对应,为了避免.name一致,后出现的节点的.name会再增加一个出现次数项_1。 显然,可以用.name来作为绘制网络结构图时每个节点的唯一标识符。

但用什么来表示当前的节点进行的计算呢(用作绘制网络节点时的label)?call_functioncall_methodget_attr以及输入输出节点似乎只需要把.target整理一下就可以作为label,但call_module类型的节点,除了当前节点子网络在主网络中完整名称,还需要知道网络本身的信息。幸运的是,torch.nn.Module类提供一个 .get_submodule方法,可以根据名称获取子网络。比如A网络有子网络net_b,而net_b又有子网络net_cnet_c是一个简单的nn.Conv2d,要想获取到这个nn.Conv2d这个类实例本身,可以用A.get_submodule(".net_b.net_c"),这里的".net_b.net_c"就是我们一直再说的“子网络在主网络拓扑中的完整名称”。获取到网络本身后,将其转换为字符串(str(net))就能得到网络本身的描述了,可以用作label。

另外call_function类型节点的.target是可调用的函数本身,为了更用户友好地print这个函数,我研究了半天,最后发现打印torch.nn.Graph的结果里,[target=<...>]中的结果就挺好,看了一下torch.fx的源码,发现torch.fx.Node_pretty_print_target方法就是做这个事情的。直接调用node._pretty_print_target(),就可以得到一个节点很用户友好的label啦。

除了节点的计算本身的信息,节点的输入和输出也很关键。节点的输入节点按照我们刚才的分析比较好找:迭代节点的.args.kwargs属性中的存储的调用参数,参数中的节点就是当前节点的输入节点。torch.fx.Node类干脆专门添加了一个属性.all_input_nodes,在内部执行了迭代过程,返回所有作为输入的计算节点。显然,每个输入节点和当前节点之间就构成了有向图中的一条边。

而网络的输出却很难获取,尤其是输出的Tensor的形状等等信息,这对我们分析网络很有帮助,但torch.fx.Node本身却没有存储相关信息。说来也确实是,转换网络为计算图的过程中,根本没有指定输入Tensor的这一步啊?那每个计算图节点的输出的Tensor尺寸啥的信息怎么获取呢?这里就要引入torch.fx.Interpreter类了。

torch.fx.Interpreter

torch.fx.Interpreter负责按照torch.fx.Graph描述的计算图,一个节点一个节点地执行计算。具体而言,torch.fx.Interpreter在使用一个torch.fx.GraphModule初始化后,当执行其.run()方法时,按照计算图,分别用.run_node()执行一个又一个torch.fx.Node节点。执行计算图节点时,按照不同节点类型,分别调用.placeholder().get_attr()…等方法执行节点。这一过程中的映射执行关系可以可视化为:

run()
 └─run_node()
    ├─placeholder()
    ├─get_attr()
    ├─call_function()
    ├─call_method()
    ├─call_module()
    └─output()

因此,可以覆盖torch.fx.Interpreter.run_node()方法,记录调用不同op对应的方法后的结果:

class ResultProbe(torch.fx.Interpreter):
    def run_node(self, n: torch.fx.Node) -> Any:
        try:
            # 执行计算图节点,并保存结果
            result = super().run_node(n)
        except Exception:
            traceback.print_exc()
            raise RuntimeError(
                f"ShapeProp error for: "
                f"node={n.format_node()} with " f"meta={n.meta}"
            )
        
        find_tensor_in_result = False

        def extract_tensor_meta(obj):
            # 找到结果中的Tensor,记录其主要信息
            if isinstance(obj, torch.Tensor):
                nonlocal find_tensor_in_result
                find_tensor_in_result = True
                # _extract_tensor_metadata函数的细节省略,
                # 可以参考文章最后的完整代码链接。
                return _extract_tensor_metadata(obj)
            else:
                return obj

        # torch.fx.node.map_aggregate会迭代result,
        # 对result中的每个元素执行给定的函数
        # 比如result是list时会对其每个元素调用指定的函数,
        # result是dict是会对其每个value执行给定的函数。
        n.meta["result"] = torch.fx.node.map_aggregate(
            result, 
            extract_tensor_meta
        )
        n.meta["find_tensor_in_result"] = find_tensor_in_result
        return result

自定义类ResultProbe覆盖了torch.fx.Interpreterrun_node方法,然后将结果存储到torch.fx.Node.meta属性中。另外,还保存了一个Flag:find_tensor_in_result,以此指示当前计算节点的输出中是否包含Tensor,因为我们可能更关注那些真正处理Tensor的计算节点,而对如同求Tensor尺寸以供后续操作这样的计算节点兴趣乏乏。

接着,用ResultProbe执行网络前向:

args = (torch.randn(1, 3, 16, 16), torch.randn(1, 2, 8, 8))
kwargs = dict()
# run的参数就是TestModel的参数
ResultProbe(symbolic_traced).run(*args, **kwargs)

执行完毕后,每个Node节点的.meta属性就包含了"result""find_tensor_in_result"这两个信息,后续直接将其打印出来即可。

可视化节点与边

我们利用Graphviz来实现图的可视化。 Graphviz是一个开源的图可视化软件,它定义了一种叫做DOT的图描述语言,能将用dot语言描述的图转换为svg、pdf、png等多种格式,将图可视化。可以按照 Graphviz Download/中的介绍安装Graphviz。Python库 graphviz则提供了Graphviz的API接口,能通过调用Python命令生成dot文件,接着利用Graphviz转换为图片。我们就利用graphviz包实现网络结构图的可视化。

在Ubuntu系统上,可以用下面的命令安装所有依赖:

# 安装Graphviz二进制文件
sudo apt install graphviz
# 安装Python库graphviz
pip install graphviz

GraphvizOnline则可以在线实现dot语言与图片的相互转换,可以在这个网页上大概浏览下dot语言是多么简洁。

Python库graphviz的使用方法也很简单:

import graphviz
dot = graphviz.Digraph()
# 只需要不断调用`.node`和`.edge`就能增加节点和边
dot.node('A', 'King Arthur')
dot.node('B', 'Sir Bedevere the Wise')
dot.node('L', 'Sir Lancelot the Brave')
dot.edge('B', 'L', constraint='false')
dot.edges(['AB', 'AL'])
# 生成图片
dot.render('test/round-table.gv', view=False)

dot语言接受HTML格式的label,因此可以生成HTML表格作为图中每个节点的label,以此表示更多信息。

借助graphviz可视化我们先前得到的计算图就很简单了:

# 可视化单个节点的核心代码:
def single_node(model: torch.nn.Module, graph: graphviz.Digraph, node: torch.fx.Node):
    node_label = node_label_html(model, node) # 生成当前节点的label
    node_kwargs = dict(shape="plaintext")
    graph.node(node.name, node_label, **node_kwargs) # 在Graphviz图中添加当前节点
    
    # 遍历当前节点的所有输入节点,添加Graphviz图中的边
    for in_node in node.all_input_nodes:
        edge_kwargs = dict()
        if (
            not node.meta["find_tensor_in_result"]
            or not in_node.meta["find_tensor_in_result"]
        ):
            # 如果当前节点的输入和输出中都没有Tensor,就把当前边置为浅灰色虚线,弱化显示
            edge_kwargs.update(dict(style="dashed", color="lightgrey"))
        # 添加当前边
        graph.edge(in_node.name, node.name, **edge_kwargs)

其中node_label_html函数就是一个把前文得到的Node的信息拼接在一个HTML表格中,然后返回HTML代码的函数。就不详细讲解如何拼接HTML字符串了。

最后,生成网络结构图的总体代码就很简单了:

def model_graph(model: torch.nn.Module, *args, **kwargs) -> graphviz.Digraph:
    # 将nn.Module转换为torch.fx.GraphModule,获取计算图
    symbolic_traced: torch.fx.GraphModule = torch.fx.symbolic_trace(model)
    # 执行一下网络,以此获取每个节点输入输出的具体信息
    ResultProbe(symbolic_traced).run(*args, **kwargs)
    # 定义一个Graphviz网络
    graph = graphviz.Digraph("model", format="svg", node_attr={"shape": "plaintext"})
    for node in symbolic_traced.graph.nodes: # 遍历所有节点
        single_node(model, graph, node)
    return graph

model = TestModel()

graph = model_graph(model, torch.randn(1, 3, 16, 16), torch.randn(1, 2, 8, 8))
graph.render(directory="test", view=False)

通过上述代码得到的TestModel的可视化结果如图所示:

图中虚线部分表示x在送入Conv2dReLU后,被调用.size方法,获取特征图尺寸,用于后续缩放y经过一系列变换后的结果。 而其它实线部分就是我们想要得到的PyTorch网络计算图了。

总结

到目前为止,我们仅仅使用了torch.fx最基础的功能。torch.fx更为关键的nn.Module编程转换功能还都没有涉及到。 我们所完成的绘制网络结构图的小工具也有很多还可以改进的地方, 比如通过覆盖torch.fx.Traceris_leaf_module方法实现对网络可视化层级的控制, 避免可视化一些如ResidualBlock,Transformer等一些很基础的网络组件,以防生成的计算图非常庞大,难以辨识。

可视化网络结构的完整代码可以在我的 gist上找到。最后再用我们实现的小工具可视化一些经典的网络吧: UNetResNet18Mobilenet v2

最后,因为本文这个小工具会可视化出所有节点,类似于DenseNet这样的网络,连接过于稠密,画出来的计算图也基本无法肉眼辨识。等后续用重写is_leaf_module的方法改进小工具后再说。现在先来欣赏下DenseNet121的结构吧: DenseNet121,注意,要按住Ctrl键后狂滑鼠标滚轮缩小网站才能看到网络全貌。