IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> paddle.summary不显示网络的基础结构和参数数量都为0的问题 -> 正文阅读

[人工智能]paddle.summary不显示网络的基础结构和参数数量都为0的问题

因为自己在这个坑里折腾了很长时间,才找到原因,所以希望能提示帮到大家。

先写结论

  • 网络层必须定义在调用forward前定义完成(比如在__init__中定义),在forward中调用,才能打印出网络结构图和参数。
  • 虽然,不能打印出网络的基础结构和参数信息,但不影响使用。

方式1,如下代码,打印不出网络结构和参数,参数数量都是0,但模型代码是可以正常训练和使用的。
import paddle
from paddle.nn import Linear

class S4(paddle.nn.Layer):
    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        output_d = Linear(90,4)(inputs)
        output_d = Linear(4,1)(output_d)
        return output_d

model = S4()

params_info = paddle.summary(model, (2,90))
print(params_info)

运行耗时: 12毫秒
---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
     S           [[2, 90]]              [2, 1]               0       
===========================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
---------------------------------------------------------------------------

{'total_params': 0, 'trainable_params': 0}
方式2,如下写的代码,可以打印出网络结构图和参数。在__init__中定义好各网络层。
import paddle
from paddle.nn import Linear

class S(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.l1= Linear(90,4)
        self.l2= Linear(4,1)

    def forward(self, inputs):
        output_d = self.l1(inputs)
        output_d = self.l2(output_d)
        return output_d

model = S()

params_info = paddle.summary(model, (2,90))
print(params_info)
---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
  Linear-128         [[2, 90]]              [2, 4]              364      
  Linear-129          [[2, 4]]              [2, 1]               5       
===========================================================================
Total params: 369
Trainable params: 369
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
---------------------------------------------------------------------------

{'total_params': 369, 'trainable_params': 369}

深入summary的源码

  • 主要通过model(模型的实例)的sublayers()方法,返回网络层的列表。
model.sublayers()

[Linear(in_features=90, out_features=4, dtype=float32),
 Linear(in_features=4, out_features=1, dtype=float32)]
  • 遍历列表中的每个层,在源代码中主要看hook函数源码

  • 每个网络层的属性

vars(model.sublayers()[0])

{'training': True,
 '_full_name': 'linear_140',
 '_helper': <paddle.fluid.dygraph.layer_object_helper.LayerObjectHelper at 0x7f65a26d0150>,
 '_built': True,
 '_dtype': 'float32',
 '_init_in_dynamic_mode': True,
 '_parameters': OrderedDict([('weight', Parameter containing:
               Tensor(shape=[90, 4], dtype=float32, place=CPUPlace, stop_gradient=False,
                      [[-0.20621595,  0.08401936, -0.14570992,  0.11183006],
                       [-0.01728971, -0.13487765, -0.23390889,  0.24098989],
                       [-0.13978893, -0.20424819, -0.04855666, -0.16199186],
                       [-0.19478002, -0.13951692,  0.14115235, -0.14658877],
                       [ 0.15643370,  0.02453366,  0.19980761, -0.02117334],
                       [-0.04742610, -0.01464723, -0.15753657,  0.02658525],
                       [ 0.12675306,  0.21732125,  0.18672404, -0.06408432],
                       [ 0.16140866,  0.22746661, -0.16478351, -0.00010374],
                       [-0.22843692, -0.11098161,  0.03810865, -0.24758199],
                       [-0.11715272,  0.08546367,  0.18643376,  0.00941786],
                       [ 0.04385796, -0.13718715, -0.13291588, -0.12744799],
                       [ 0.20507479,  0.08363530,  0.08122411,  0.14772725],
                       [ 0.12475023,  0.10399523,  0.22499707,  0.04691744],
                       [ 0.16640592,  0.05567488,  0.25142434,  0.05884814],
                       [ 0.12723908, -0.21167704, -0.05146568, -0.07546510],
                       [-0.18740901, -0.13138834,  0.09158534,  0.11389682],
                       [-0.23955993, -0.04748055,  0.01112020, -0.07551017],
                       [-0.19948956,  0.07887903, -0.11123639, -0.24801619],
                       [ 0.17985007, -0.13626945,  0.03615150, -0.09616728],
                       [ 0.11204135, -0.13762784, -0.10694183, -0.24507296],
                       [ 0.14161596, -0.18264475,  0.21488816, -0.02595139],
                       [ 0.01281962, -0.23533244, -0.05832650, -0.18845016],
                       [ 0.01685080,  0.21017367,  0.04473603, -0.07200867],
                       [-0.16427593,  0.10090092, -0.08228512,  0.22937572],
                       [ 0.02535132,  0.16946691,  0.01132673,  0.22891155],
                       [-0.17698936,  0.14869818,  0.24249542,  0.19769028],
                       [ 0.01254764,  0.19422016, -0.07282487,  0.09596258],
                       [ 0.16854528, -0.06393975,  0.20858878, -0.06705859],
                       [ 0.18959373, -0.08939679,  0.07643029, -0.05605973],
                       [-0.12243937, -0.14132866,  0.19494379, -0.06451888],
                       [ 0.19235328,  0.10110179, -0.22526026, -0.21696852],
                       [-0.07262844, -0.01888850,  0.16245791,  0.22750211],
                       [-0.22919488, -0.18954048,  0.03602901,  0.21270525],
                       [-0.13936907,  0.12847206,  0.12586927, -0.16047193],
                       [-0.23191750, -0.13488717,  0.24885699,  0.08321038],
                       [-0.13570510, -0.07265873, -0.10631232,  0.09162298],
                       [-0.12932971,  0.18517035, -0.15347271, -0.02739608],
                       [ 0.22921789, -0.15096828,  0.23355103,  0.21406570],
                       [ 0.08929682,  0.15368345, -0.06598282, -0.13409325],
                       [-0.04817520, -0.10624386, -0.01299433,  0.16373715],
                       [ 0.06702435, -0.01464282,  0.15239137, -0.06154861],
                       [-0.08851637, -0.15216920,  0.14518383,  0.03232276],
                       [-0.00448877,  0.10262579, -0.24696153,  0.06273824],
                       [ 0.19433418, -0.01608880, -0.06958365,  0.21542296],
                       [ 0.17663121,  0.24140605, -0.16102412, -0.13844144],
                       [-0.25249654, -0.11238430, -0.10917042,  0.22930911],
                       [ 0.18360525, -0.03554909, -0.20884866,  0.10487282],
                       [-0.24110393,  0.12529787, -0.00518186,  0.09543869],
                       [ 0.19364417,  0.18657467, -0.05774385,  0.06694761],
                       [ 0.06235594,  0.18447849, -0.23734707,  0.02031776],
                       [-0.07527824,  0.06170449,  0.02800348, -0.12043267],
                       [-0.06529415,  0.02134055, -0.18275535, -0.13008605],
                       [-0.20935142,  0.05829439, -0.04963452,  0.24142167],
                       [-0.00459233, -0.17473610, -0.12816842, -0.05939011],
                       [ 0.25015548,  0.10697207,  0.24913797,  0.23499483],
                       [-0.09882846, -0.07380529, -0.00138485, -0.05095109],
                       [ 0.24390355,  0.14776850, -0.23272485, -0.14770921],
                       [ 0.10985476,  0.04561046,  0.12682661, -0.18175307],
                       [-0.07764107, -0.23298098,  0.11743194,  0.03062549],
                       [-0.16420993,  0.06812567, -0.22719657, -0.13849448],
                       [-0.20762315,  0.15431166, -0.24009264, -0.04427037],
                       [-0.15257362, -0.11478184,  0.21630144,  0.01226106],
                       [ 0.20274323,  0.01106474,  0.17747423,  0.19533655],
                       [-0.09216593,  0.21642286,  0.16395739,  0.18711695],
                       [-0.16654512, -0.01774178, -0.21099040, -0.15609533],
                       [ 0.16119608,  0.16110554,  0.22074458, -0.16093451],
                       [ 0.13975376, -0.13630739, -0.03922784,  0.11074531],
                       [ 0.17509770,  0.00062934, -0.06153570,  0.14687642],
                       [ 0.22287810,  0.18240044,  0.14542899, -0.20445868],
                       [-0.21342079,  0.20837983, -0.05815056, -0.03875078],
                       [-0.24014582, -0.15421730, -0.12463945, -0.11398430],
                       [ 0.09004262, -0.02769084, -0.19755839, -0.10233013],
                       [ 0.19232154,  0.04774907,  0.18493724,  0.25028470],
                       [-0.12067989, -0.15001965, -0.02605258,  0.22234175],
                       [-0.00589231, -0.15065691,  0.07578629, -0.23713773],
                       [ 0.24967501, -0.00311911, -0.11508700, -0.11503294],
                       [-0.05981220,  0.11477354, -0.03898008, -0.11416386],
                       [ 0.12355697,  0.00971353,  0.03545964,  0.03065020],
                       [-0.22309546, -0.07022245,  0.04698494,  0.06996807],
                       [ 0.00341684, -0.04152401, -0.04922824, -0.19472414],
                       [-0.07824557, -0.24881583, -0.02940276, -0.09540446],
                       [-0.07309537, -0.19851863,  0.18788418, -0.14266992],
                       [ 0.19145533,  0.18813542, -0.13308662,  0.23938963],
                       [ 0.19193968,  0.15658599,  0.11875460, -0.12196723],
                       [-0.06095690,  0.14294901,  0.10374212, -0.08170357],
                       [ 0.15833840, -0.15739962, -0.09807962,  0.09703702],
                       [-0.15394495,  0.05613059,  0.21851090,  0.02185619],
                       [ 0.17367855,  0.13805395,  0.12828121, -0.12902705],
                       [-0.05945474,  0.09741652,  0.21674320,  0.06598496],
                       [-0.05226740, -0.21138683,  0.21320057,  0.14421508]])),
              ('bias', Parameter containing:
               Tensor(shape=[4], dtype=float32, place=CPUPlace, stop_gradient=False,
                      [0., 0., 0., 0.]))]),
 '_buffers': OrderedDict(),
 '_non_persistable_buffer_names_set': set(),
 '_sub_layers': OrderedDict(),
 '_loaddict_holder': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_forward_post_hooks': OrderedDict(),
 '_weight_attr': None,
 '_bias_attr': None,
 'name': None}

方式1的问题,在于sublayers()返回的列表是空列表。

  • sublayers()的源码核心是遍历named_sublayers
  • named_sublayers的源码是用类内部属性_sub_layers获取所有的子网络的列表。
    例如:
model._sub_layers
OrderedDict([('l1', Linear(in_features=90, out_features=4, dtype=float32)),
             ('l2', Linear(in_features=4, out_features=1, dtype=float32))])
    Examplex:
        .. code-block:: python
            import paddle
            import numpy as np
            from collections import OrderedDict
            sublayers = OrderedDict([
                ('conv1d', paddle.nn.Conv1D(3, 2, 3)),
                ('conv2d', paddle.nn.Conv2D(3, 2, 3)),
                ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))),
            ])
            layers_dict = paddle.nn.LayerDict(sublayers=sublayers)
            l = layers_dict['conv1d']
            for k in layers_dict:
                l = layers_dict[k]
            len(layers_dict)
            #3
            del layers_dict['conv2d']
            len(layers_dict)
            #2
            conv1d = layers_dict.pop('conv1d')
            len(layers_dict)
            #1
            layers_dict.clear()
            len(layers_dict)
            #0

解决方法(只是可行,不一定是优化)

  • 在forward调用前把网络定义完毕。
  1. 生成网络顺序列表。
  2. 用paddle.nn.LayerList(把1打包成paddle的层列表。
  3. 在forward中顺序调用。
    其中1,2的步骤类似于paddle.nn.Sequential的作用。只不过,可以更加灵活,比如在网络层之间进行特殊的张量操作,而Sequential不行。

例如:

import paddle
from paddle.nn import Linear, BatchNorm, ReLU

class S(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.layer_list  = paddle.nn.LayerList(self.make_list())

    def linear(self,input_dim, output_dim):
        return Linear(input_dim, output_dim)
    
    ##可以根据网络参数不同,重复调用。返回网络列表。
    def linear_g(self, input_dim, output_dim):
        l_list = []
        l = self.linear(input_dim, output_dim)
        l_list.append(l)
        l = BatchNorm(l.weight.shape[1])
        l_list.append(l)
        l_list.append(ReLU())
        return l_list

    def make_list(self):
        l_list = []
        l_list.extend(self.linear_g(90,4))
        return l_list

    def forward(self, x):
        for i, l in enumerate(self.layer_list):
            x = self.layer_list[i](x)
        return x

model = S()

params_info = paddle.summary(model, (1,90))
print(params_info)
print(list(model.named_sublayers()))
print(model.named_children())

---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
  Linear-231         [[1, 90]]              [1, 4]              364      
 BatchNorm-22         [[1, 4]]              [1, 4]              16       
    ReLU-41           [[1, 4]]              [1, 4]               0       
===========================================================================
Total params: 380
Trainable params: 364
Non-trainable params: 16
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
---------------------------------------------------------------------------

{'total_params': 380, 'trainable_params': 364}
[('layer_list', LayerList(
  (0): Linear(in_features=90, out_features=4, dtype=float32)
  (1): BatchNorm()
  (2): ReLU()
)), ('layer_list.0', Linear(in_features=90, out_features=4, dtype=float32)), ('layer_list.1', BatchNorm()), ('layer_list.2', ReLU())]
<generator object Layer.named_children at 0x7f65a427dc50>

定义网络,如果能知道前一层的输出形状?

可以查看各自预定层的源码,例如:paddle.nn.Linear的源码中,

    def extra_repr(self):
        name_str = ', name={}'.format(self.name) if self.name else ''
        return 'in_features={}, out_features={}, dtype={}{}'.format(
            self.weight.shape[0], self.weight.shape[1], self._dtype, name_str)

可以看出,本层的输出性质是在self.weight.shape[1]中。
其他的层类似,查看extra_repr函数中的内容。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-07-04 19:55:33  更:2021-07-04 19:56:02 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年4日历 -2024/4/25 15:29:11-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码