IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 整理经典网络并计算其flops等 -> 正文阅读

[人工智能]整理经典网络并计算其flops等

利用tensorboard显示模型示意图使用到的函数add_graph中的参数

add_graph(model, input_to_model=None, verbose=False, **kwargs)

参数

??? model (torch.nn.Module): 待可视化的网络模型
??? input_to_model (torch.Tensor or list of torch.Tensor, optional): 待输入神经网络的变量或一组变量
???verbose表示详细信息,verbose=FALSE,意思就是设置运行的时候不显示详细信息

1.AlexNet

import torch
from torch import nn
from torchstat import stat
class AlexNet(nn.Module):
    def __init__(self,num_classes):
        super(AlexNet,self).__init__()
        self.features=nn.Sequential(
            nn.Conv2d(3,64,kernel_size=11,stride=4,padding=2),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3,stride=2),
        
        
            nn.Conv2d(64,192,kernel_size=5,padding=2),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3,stride=2),
        
        
            nn.Conv2d(192, 384, kernel_size=3, padding=1),   # b, 384, 13, 13
            nn.ReLU(True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),   # b, 256, 13, 13
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),   # b, 256, 13, 13
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3, stride=2))    # b, 256, 6, 6
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256*6*6, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Linear(4096, num_classes))
    def forward(self,x):
        x=self.features(x)
        print(x.size())
        x=x.view(x.size(0),256*6*6)
        x=self.classifier(x)
        return x

    
model=AlexNet(10)
stat(model,(3,224,224))

?

?

?2.vgg net

# VGG-16模型

from torch import nn
from torchstat import stat
class VGG(nn.Module):
    def __init__(self, num_classes):
        super(VGG, self).__init__()     # b, 3, 224, 224
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),   # b, 64, 224, 224
            nn.ReLU(True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),    # b, 64, 224, 224
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),     # b, 64, 112, 112

            nn.Conv2d(64, 128, kernel_size=3, padding=1),  # b, 128, 112, 112
            nn.ReLU(True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),   # b, 128, 112, 112
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),   # b, 128, 56, 56

            nn.Conv2d(128, 256, kernel_size=3, padding=1),    # b, 256, 56, 56
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),    # b, 256, 56, 56
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),  # b, 256, 56, 56
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),    # b, 256, 28, 28

            nn.Conv2d(256, 512, kernel_size=3, padding=1),  # b, 512, 28, 28
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # b, 512, 28, 28
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # b, 512, 28, 28
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # b, 512, 14, 14

            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # b, 512, 14, 14
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # b, 512, 14, 14
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # b, 512, 14, 14
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2))  # b, 512, 7, 7
        self.classifier = nn.Sequential(
            nn.Linear(512*7*7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes))

    def forward(self, x):
        x = self.features(x)
        print(x.size())
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
model = VGG(1000)
stat(model, (3, 224, 224))
torch.Size([1, 512, 7, 7])
[MAdd]: Dropout is not supported!
[Flops]: Dropout is not supported!
[Memory]: Dropout is not supported!
[MAdd]: Dropout is not supported!
[Flops]: Dropout is not supported!
[Memory]: Dropout is not supported!
        module name  input shape output shape       params memory(MB)              MAdd             Flops   MemRead(B)  MemWrite(B) duration[%]    MemR+W(B)
0        features.0    3 224 224   64 224 224       1792.0      12.00     173,408,256.0      89,915,392.0     609280.0   12845056.0       6.76%   13454336.0
1        features.1   64 224 224   64 224 224          0.0      12.00       3,211,264.0       3,211,264.0   12845056.0   12845056.0       1.76%   25690112.0
2        features.2   64 224 224   64 224 224      36928.0      12.00   3,699,376,128.0   1,852,899,328.0   12992768.0   12845056.0       7.65%   25837824.0
3        features.3   64 224 224   64 224 224          0.0      12.00       3,211,264.0       3,211,264.0   12845056.0   12845056.0       1.18%   25690112.0
4        features.4   64 224 224   64 112 112          0.0       3.00       2,408,448.0       3,211,264.0   12845056.0    3211264.0      11.47%   16056320.0
5        features.5   64 112 112  128 112 112      73856.0       6.00   1,849,688,064.0     926,449,664.0    3506688.0    6422528.0       4.41%    9929216.0
6        features.6  128 112 112  128 112 112          0.0       6.00       1,605,632.0       1,605,632.0    6422528.0    6422528.0       0.59%   12845056.0
7        features.7  128 112 112  128 112 112     147584.0       6.00   3,699,376,128.0   1,851,293,696.0    7012864.0    6422528.0       7.06%   13435392.0
8        features.8  128 112 112  128 112 112          0.0       6.00       1,605,632.0       1,605,632.0    6422528.0    6422528.0       0.59%   12845056.0
9        features.9  128 112 112  128  56  56          0.0       1.00       1,204,224.0       1,605,632.0    6422528.0    1605632.0       3.82%    8028160.0
10      features.10  128  56  56  256  56  56     295168.0       3.00   1,849,688,064.0     925,646,848.0    2786304.0    3211264.0       3.82%    5997568.0
11      features.11  256  56  56  256  56  56          0.0       3.00         802,816.0         802,816.0    3211264.0    3211264.0       0.29%    6422528.0
12      features.12  256  56  56  256  56  56     590080.0       3.00   3,699,376,128.0   1,850,490,880.0    5571584.0    3211264.0       6.47%    8782848.0
13      features.13  256  56  56  256  56  56          0.0       3.00         802,816.0         802,816.0    3211264.0    3211264.0       0.29%    6422528.0
14      features.14  256  56  56  256  56  56     590080.0       3.00   3,699,376,128.0   1,850,490,880.0    5571584.0    3211264.0       6.47%    8782848.0
15      features.15  256  56  56  256  56  56          0.0       3.00         802,816.0         802,816.0    3211264.0    3211264.0       0.29%    6422528.0
16      features.16  256  56  56  256  28  28          0.0       0.00         602,112.0         802,816.0    3211264.0     802816.0       2.35%    4014080.0
17      features.17  256  28  28  512  28  28    1180160.0       1.00   1,849,688,064.0     925,245,440.0    5523456.0    1605632.0       3.82%    7129088.0
18      features.18  512  28  28  512  28  28          0.0       1.00         401,408.0         401,408.0    1605632.0    1605632.0       0.29%    3211264.0
19      features.19  512  28  28  512  28  28    2359808.0       1.00   3,699,376,128.0   1,850,089,472.0   11044864.0    1605632.0       6.47%   12650496.0
20      features.20  512  28  28  512  28  28          0.0       1.00         401,408.0         401,408.0    1605632.0    1605632.0       0.00%    3211264.0
21      features.21  512  28  28  512  28  28    2359808.0       1.00   3,699,376,128.0   1,850,089,472.0   11044864.0    1605632.0       7.35%   12650496.0
22      features.22  512  28  28  512  28  28          0.0       1.00         401,408.0         401,408.0    1605632.0    1605632.0       0.00%    3211264.0
23      features.23  512  28  28  512  14  14          0.0       0.00         301,056.0         401,408.0    1605632.0     401408.0       1.47%    2007040.0
24      features.24  512  14  14  512  14  14    2359808.0       0.00     924,844,032.0     462,522,368.0    9840640.0     401408.0       2.35%   10242048.0
25      features.25  512  14  14  512  14  14          0.0       0.00         100,352.0         100,352.0     401408.0     401408.0       0.29%     802816.0
26      features.26  512  14  14  512  14  14    2359808.0       0.00     924,844,032.0     462,522,368.0    9840640.0     401408.0       1.76%   10242048.0
27      features.27  512  14  14  512  14  14          0.0       0.00         100,352.0         100,352.0     401408.0     401408.0       0.00%     802816.0
28      features.28  512  14  14  512  14  14    2359808.0       0.00     924,844,032.0     462,522,368.0    9840640.0     401408.0       2.35%   10242048.0
29      features.29  512  14  14  512  14  14          0.0       0.00         100,352.0         100,352.0     401408.0     401408.0       0.00%     802816.0
30      features.30  512  14  14  512   7   7          0.0       0.00          75,264.0         100,352.0     401408.0     100352.0       0.00%     501760.0
31     classifier.0        25088         4096  102764544.0       0.00     205,516,800.0     102,760,448.0  411158528.0      16384.0       6.47%  411174912.0
32     classifier.1         4096         4096          0.0       0.00           4,096.0           4,096.0      16384.0      16384.0       0.00%      32768.0
33     classifier.2         4096         4096          0.0       0.00               0.0               0.0          0.0          0.0       0.00%          0.0
34     classifier.3         4096         4096   16781312.0       0.00      33,550,336.0      16,777,216.0   67141632.0      16384.0       1.47%   67158016.0
35     classifier.4         4096         4096          0.0       0.00           4,096.0           4,096.0      16384.0      16384.0       0.29%      32768.0
36     classifier.5         4096         4096          0.0       0.00               0.0               0.0          0.0          0.0       0.00%          0.0
37     classifier.6         4096         1000    4097000.0       0.00       8,191,000.0       4,096,000.0   16404384.0       4000.0       0.29%   16408384.0
total                                          138357544.0     100.00  30,958,666,264.0  15,503,489,024.0   16404384.0       4000.0     100.00%  783170624.0
============================================================================================================================================================
Total params: 138,357,544
------------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 100.00MB
Total MAdd: 30.96GMAdd
Total Flops: 15.5GFlops
Total MemR+W: 746.89MB

?

?

3.GoogLeNet

from torch import nn
import torch
# 卷积 + BN
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv= nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

class Inception(nn.Module):
    def __init__(self, in_channels):
        super(Inception, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)

        self.branch3x3_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)

        self.branchpool_1 = nn.AvgPool2d(kernel_size=3, stride=1)
        self.branch_pool = BasicConv2d(in_channels, 64, kernel_size=1, padding=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        print('branch1x1_size: ', branch1x1.size())
        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)
        print('branch5x5_size: ', branch5x5.size())
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)
        branch3x3 = self.branch3x3_3(branch3x3)
        print('branch3x3_size: ', branch3x3.size())
        branch_pool = self.branchpool_1(x)
        branch_pool = self.branch_pool(branch_pool)
        print('branch_pool_size: ', branch_pool.size())
        outputs = [branch1x1, branch5x5, branch3x3, branch_pool]

        return torch.cat(outputs, 1)

model = Inception(3)
from torchstat import stat
stat(model, (3, 224, 224))
branch1x1_size:  torch.Size([1, 64, 224, 224])
branch5x5_size:  torch.Size([1, 64, 224, 224])
branch3x3_size:  torch.Size([1, 96, 224, 224])
branch_pool_size:  torch.Size([1, 64, 224, 224])
            module name  input shape output shape    params memory(MB)              MAdd             Flops  MemRead(B)  MemWrite(B) duration[%]    MemR+W(B)
0        branch1x1.conv    3 224 224   64 224 224     192.0      12.00      16,056,320.0       9,633,792.0    602880.0   12845056.0      34.88%   13447936.0
1          branch1x1.bn   64 224 224   64 224 224     128.0      12.00      12,845,056.0       6,422,528.0  12845568.0   12845056.0      17.03%   25690624.0
2      branch5x5_1.conv    3 224 224   48 224 224     144.0       9.00      12,042,240.0       7,225,344.0    602688.0    9633792.0       1.98%   10236480.0
3        branch5x5_1.bn   48 224 224   48 224 224      96.0       9.00       9,633,792.0       4,816,896.0   9634176.0    9633792.0       0.40%   19267968.0
4      branch5x5_2.conv   48 224 224   64 224 224   76800.0      12.00   7,703,822,336.0   3,853,516,800.0   9940992.0   12845056.0       3.64%   22786048.0
5        branch5x5_2.bn   64 224 224   64 224 224     128.0      12.00      12,845,056.0       6,422,528.0  12845568.0   12845056.0       0.48%   25690624.0
6      branch3x3_1.conv    3 224 224   64 224 224     192.0      12.00      16,056,320.0       9,633,792.0    602880.0   12845056.0       1.27%   13447936.0
7        branch3x3_1.bn   64 224 224   64 224 224     128.0      12.00      12,845,056.0       6,422,528.0  12845568.0   12845056.0       0.48%   25690624.0
8      branch3x3_2.conv   64 224 224   96 224 224   55296.0      18.00   5,544,247,296.0   2,774,532,096.0  13066240.0   19267584.0       3.40%   32333824.0
9        branch3x3_2.bn   96 224 224   96 224 224     192.0      18.00      19,267,584.0       9,633,792.0  19268352.0   19267584.0       0.71%   38535936.0
10     branch3x3_3.conv   96 224 224   96 224 224   82944.0      18.00   8,318,779,392.0   4,161,798,144.0  19599360.0   19267584.0       4.28%   38866944.0
11       branch3x3_3.bn   96 224 224   96 224 224     192.0      18.00      19,267,584.0       9,633,792.0  19268352.0   19267584.0       0.48%   38535936.0
12         branchpool_1    3 224 224    3 222 222       0.0       0.00       1,330,668.0         150,528.0    602112.0     591408.0       0.24%    1193520.0
13     branch_pool.conv    3 222 222   64 224 224     192.0      12.00      16,056,320.0       9,633,792.0    592176.0   12845056.0      30.28%   13437232.0
14       branch_pool.bn   64 224 224   64 224 224     128.0      12.00      12,845,056.0       6,422,528.0  12845568.0   12845056.0       0.48%   25690624.0
total                                              216752.0     186.00  21,727,940,076.0  10,875,898,880.0  12845568.0   12845056.0     100.00%  344852256.0
============================================================================================================================================================
Total params: 216,752
------------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 186.00MB
Total MAdd: 21.73GMAdd
Total Flops: 10.88GFlops
Total MemR+W: 328.88MB

?

4.ResNet

from torch import nn
from torchstat import stat

def conv3x3(in_planes, out_planes, stride=1):
    # 3x3 convolution with padding
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

class BasicBlock(nn.Module):
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # downsample对应着一个下采样函数
        self.downsample = downsample
        self.stride = stride
        
    def forward(self, x):
        residual = x
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-07 10:48:58  更:2021-09-07 10:50:30 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/27 15:30:01-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码