Parameters of add_graph, the function used to display a model graph in TensorBoard
add_graph(model, input_to_model=None, verbose=False, **kwargs)
Parameters:
- model (torch.nn.Module): the network model to visualize
- input_to_model (torch.Tensor or list of torch.Tensor, optional): a variable or list of variables to feed into the network
- verbose (bool): whether to print detailed information; verbose=False means no detailed output is shown during the run
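A minimal usage sketch of add_graph (the log directory name runs/demo and the tiny placeholder model are arbitrary choices for illustration):

import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

# Any nn.Module works; a tiny placeholder model is used here
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
writer = SummaryWriter('runs/demo')      # log directory (arbitrary name)
dummy = torch.randn(1, 3, 224, 224)      # dummy input used to trace the graph
writer.add_graph(model, input_to_model=dummy, verbose=False)
writer.close()
# View the graph with: tensorboard --logdir runs/demo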
1. AlexNet
import torch
from torch import nn
from torchstat import stat

class AlexNet(nn.Module):
    def __init__(self, num_classes):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),  # b, 64, 55, 55
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # b, 64, 27, 27
            nn.Conv2d(64, 192, kernel_size=5, padding=2),           # b, 192, 27, 27
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # b, 192, 13, 13
            nn.Conv2d(192, 384, kernel_size=3, padding=1),          # b, 384, 13, 13
            nn.ReLU(True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),          # b, 256, 13, 13
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),          # b, 256, 13, 13
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3, stride=2))                  # b, 256, 6, 6
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Linear(4096, num_classes))

    def forward(self, x):
        x = self.features(x)
        print(x.size())                     # expect torch.Size([1, 256, 6, 6]) for a 224x224 input
        x = x.view(x.size(0), 256 * 6 * 6)  # flatten the feature maps
        x = self.classifier(x)
        return x

model = AlexNet(10)
stat(model, (3, 224, 224))
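To tie this back to add_graph: a sketch of writing this AlexNet's graph to TensorBoard (the runs/alexnet log directory is an arbitrary name):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/alexnet')   # arbitrary log directory
dummy = torch.randn(1, 3, 224, 224)      # dummy batch used to trace the graph
writer.add_graph(model, input_to_model=dummy)
writer.close()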
2. VGG-16
# VGG-16 model
from torch import nn
from torchstat import stat

class VGG(nn.Module):
    def __init__(self, num_classes):
        super(VGG, self).__init__()                         # b, 3, 224, 224
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),     # b, 64, 224, 224
            nn.ReLU(True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),    # b, 64, 224, 224
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),          # b, 64, 112, 112
            nn.Conv2d(64, 128, kernel_size=3, padding=1),   # b, 128, 112, 112
            nn.ReLU(True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),  # b, 128, 112, 112
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),          # b, 128, 56, 56
            nn.Conv2d(128, 256, kernel_size=3, padding=1),  # b, 256, 56, 56
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),  # b, 256, 56, 56
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),  # b, 256, 56, 56
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),          # b, 256, 28, 28
            nn.Conv2d(256, 512, kernel_size=3, padding=1),  # b, 512, 28, 28
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # b, 512, 28, 28
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # b, 512, 28, 28
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),          # b, 512, 14, 14
            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # b, 512, 14, 14
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # b, 512, 14, 14
            nn.ReLU(True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # b, 512, 14, 14
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2))          # b, 512, 7, 7
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes))

    def forward(self, x):
        x = self.features(x)
        print(x.size())              # expect torch.Size([1, 512, 7, 7])
        x = x.view(x.size(0), -1)    # flatten the feature maps
        x = self.classifier(x)
        return x

model = VGG(1000)
stat(model, (3, 224, 224))
torch.Size([1, 512, 7, 7])
[MAdd]: Dropout is not supported!
[Flops]: Dropout is not supported!
[Memory]: Dropout is not supported!
[MAdd]: Dropout is not supported!
[Flops]: Dropout is not supported!
[Memory]: Dropout is not supported!
module name input shape output shape params memory(MB) MAdd Flops MemRead(B) MemWrite(B) duration[%] MemR+W(B)
0 features.0 3 224 224 64 224 224 1792.0 12.00 173,408,256.0 89,915,392.0 609280.0 12845056.0 6.76% 13454336.0
1 features.1 64 224 224 64 224 224 0.0 12.00 3,211,264.0 3,211,264.0 12845056.0 12845056.0 1.76% 25690112.0
2 features.2 64 224 224 64 224 224 36928.0 12.00 3,699,376,128.0 1,852,899,328.0 12992768.0 12845056.0 7.65% 25837824.0
3 features.3 64 224 224 64 224 224 0.0 12.00 3,211,264.0 3,211,264.0 12845056.0 12845056.0 1.18% 25690112.0
4 features.4 64 224 224 64 112 112 0.0 3.00 2,408,448.0 3,211,264.0 12845056.0 3211264.0 11.47% 16056320.0
5 features.5 64 112 112 128 112 112 73856.0 6.00 1,849,688,064.0 926,449,664.0 3506688.0 6422528.0 4.41% 9929216.0
6 features.6 128 112 112 128 112 112 0.0 6.00 1,605,632.0 1,605,632.0 6422528.0 6422528.0 0.59% 12845056.0
7 features.7 128 112 112 128 112 112 147584.0 6.00 3,699,376,128.0 1,851,293,696.0 7012864.0 6422528.0 7.06% 13435392.0
8 features.8 128 112 112 128 112 112 0.0 6.00 1,605,632.0 1,605,632.0 6422528.0 6422528.0 0.59% 12845056.0
9 features.9 128 112 112 128 56 56 0.0 1.00 1,204,224.0 1,605,632.0 6422528.0 1605632.0 3.82% 8028160.0
10 features.10 128 56 56 256 56 56 295168.0 3.00 1,849,688,064.0 925,646,848.0 2786304.0 3211264.0 3.82% 5997568.0
11 features.11 256 56 56 256 56 56 0.0 3.00 802,816.0 802,816.0 3211264.0 3211264.0 0.29% 6422528.0
12 features.12 256 56 56 256 56 56 590080.0 3.00 3,699,376,128.0 1,850,490,880.0 5571584.0 3211264.0 6.47% 8782848.0
13 features.13 256 56 56 256 56 56 0.0 3.00 802,816.0 802,816.0 3211264.0 3211264.0 0.29% 6422528.0
14 features.14 256 56 56 256 56 56 590080.0 3.00 3,699,376,128.0 1,850,490,880.0 5571584.0 3211264.0 6.47% 8782848.0
15 features.15 256 56 56 256 56 56 0.0 3.00 802,816.0 802,816.0 3211264.0 3211264.0 0.29% 6422528.0
16 features.16 256 56 56 256 28 28 0.0 0.00 602,112.0 802,816.0 3211264.0 802816.0 2.35% 4014080.0
17 features.17 256 28 28 512 28 28 1180160.0 1.00 1,849,688,064.0 925,245,440.0 5523456.0 1605632.0 3.82% 7129088.0
18 features.18 512 28 28 512 28 28 0.0 1.00 401,408.0 401,408.0 1605632.0 1605632.0 0.29% 3211264.0
19 features.19 512 28 28 512 28 28 2359808.0 1.00 3,699,376,128.0 1,850,089,472.0 11044864.0 1605632.0 6.47% 12650496.0
20 features.20 512 28 28 512 28 28 0.0 1.00 401,408.0 401,408.0 1605632.0 1605632.0 0.00% 3211264.0
21 features.21 512 28 28 512 28 28 2359808.0 1.00 3,699,376,128.0 1,850,089,472.0 11044864.0 1605632.0 7.35% 12650496.0
22 features.22 512 28 28 512 28 28 0.0 1.00 401,408.0 401,408.0 1605632.0 1605632.0 0.00% 3211264.0
23 features.23 512 28 28 512 14 14 0.0 0.00 301,056.0 401,408.0 1605632.0 401408.0 1.47% 2007040.0
24 features.24 512 14 14 512 14 14 2359808.0 0.00 924,844,032.0 462,522,368.0 9840640.0 401408.0 2.35% 10242048.0
25 features.25 512 14 14 512 14 14 0.0 0.00 100,352.0 100,352.0 401408.0 401408.0 0.29% 802816.0
26 features.26 512 14 14 512 14 14 2359808.0 0.00 924,844,032.0 462,522,368.0 9840640.0 401408.0 1.76% 10242048.0
27 features.27 512 14 14 512 14 14 0.0 0.00 100,352.0 100,352.0 401408.0 401408.0 0.00% 802816.0
28 features.28 512 14 14 512 14 14 2359808.0 0.00 924,844,032.0 462,522,368.0 9840640.0 401408.0 2.35% 10242048.0
29 features.29 512 14 14 512 14 14 0.0 0.00 100,352.0 100,352.0 401408.0 401408.0 0.00% 802816.0
30 features.30 512 14 14 512 7 7 0.0 0.00 75,264.0 100,352.0 401408.0 100352.0 0.00% 501760.0
31 classifier.0 25088 4096 102764544.0 0.00 205,516,800.0 102,760,448.0 411158528.0 16384.0 6.47% 411174912.0
32 classifier.1 4096 4096 0.0 0.00 4,096.0 4,096.0 16384.0 16384.0 0.00% 32768.0
33 classifier.2 4096 4096 0.0 0.00 0.0 0.0 0.0 0.0 0.00% 0.0
34 classifier.3 4096 4096 16781312.0 0.00 33,550,336.0 16,777,216.0 67141632.0 16384.0 1.47% 67158016.0
35 classifier.4 4096 4096 0.0 0.00 4,096.0 4,096.0 16384.0 16384.0 0.29% 32768.0
36 classifier.5 4096 4096 0.0 0.00 0.0 0.0 0.0 0.0 0.00% 0.0
37 classifier.6 4096 1000 4097000.0 0.00 8,191,000.0 4,096,000.0 16404384.0 4000.0 0.29% 16408384.0
total 138357544.0 100.00 30,958,666,264.0 15,503,489,024.0 16404384.0 4000.0 100.00% 783170624.0
============================================================================================================================================================
Total params: 138,357,544
------------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 100.00MB
Total MAdd: 30.96GMAdd
Total Flops: 15.5GFlops
Total MemR+W: 746.89MB
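The reported total of 138,357,544 parameters matches the well-known figure for VGG-16. As a sanity check, a small sketch comparing the count against torchvision's reference implementation (this assumes a recent torchvision is installed; weights=None is the current-API way to request an untrained model):

import torchvision

# Count trainable parameters of our VGG and torchvision's reference VGG-16
ours = sum(p.numel() for p in VGG(1000).parameters())
ref = sum(p.numel() for p in torchvision.models.vgg16(weights=None).parameters())
print(ours, ref)   # both should print 138357544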
3. GoogLeNet
from torch import nn
import torch
from torchstat import stat

# Convolution + BatchNorm
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

class Inception(nn.Module):
    def __init__(self, in_channels):
        super(Inception, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)
        self.branch3x3_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)
        # The 3x3 average pool (stride 1, no padding) shrinks H and W by 2;
        # padding=1 on the following 1x1 conv restores the original size
        self.branchpool_1 = nn.AvgPool2d(kernel_size=3, stride=1)
        self.branch_pool = BasicConv2d(in_channels, 64, kernel_size=1, padding=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        print('branch1x1_size: ', branch1x1.size())
        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)
        print('branch5x5_size: ', branch5x5.size())
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)
        branch3x3 = self.branch3x3_3(branch3x3)
        print('branch3x3_size: ', branch3x3.size())
        branch_pool = self.branchpool_1(x)
        branch_pool = self.branch_pool(branch_pool)
        print('branch_pool_size: ', branch_pool.size())
        outputs = [branch1x1, branch5x5, branch3x3, branch_pool]
        return torch.cat(outputs, 1)   # concatenate branches along the channel dimension

model = Inception(3)
stat(model, (3, 224, 224))
branch1x1_size: torch.Size([1, 64, 224, 224])
branch5x5_size: torch.Size([1, 64, 224, 224])
branch3x3_size: torch.Size([1, 96, 224, 224])
branch_pool_size: torch.Size([1, 64, 224, 224])
module name input shape output shape params memory(MB) MAdd Flops MemRead(B) MemWrite(B) duration[%] MemR+W(B)
0 branch1x1.conv 3 224 224 64 224 224 192.0 12.00 16,056,320.0 9,633,792.0 602880.0 12845056.0 34.88% 13447936.0
1 branch1x1.bn 64 224 224 64 224 224 128.0 12.00 12,845,056.0 6,422,528.0 12845568.0 12845056.0 17.03% 25690624.0
2 branch5x5_1.conv 3 224 224 48 224 224 144.0 9.00 12,042,240.0 7,225,344.0 602688.0 9633792.0 1.98% 10236480.0
3 branch5x5_1.bn 48 224 224 48 224 224 96.0 9.00 9,633,792.0 4,816,896.0 9634176.0 9633792.0 0.40% 19267968.0
4 branch5x5_2.conv 48 224 224 64 224 224 76800.0 12.00 7,703,822,336.0 3,853,516,800.0 9940992.0 12845056.0 3.64% 22786048.0
5 branch5x5_2.bn 64 224 224 64 224 224 128.0 12.00 12,845,056.0 6,422,528.0 12845568.0 12845056.0 0.48% 25690624.0
6 branch3x3_1.conv 3 224 224 64 224 224 192.0 12.00 16,056,320.0 9,633,792.0 602880.0 12845056.0 1.27% 13447936.0
7 branch3x3_1.bn 64 224 224 64 224 224 128.0 12.00 12,845,056.0 6,422,528.0 12845568.0 12845056.0 0.48% 25690624.0
8 branch3x3_2.conv 64 224 224 96 224 224 55296.0 18.00 5,544,247,296.0 2,774,532,096.0 13066240.0 19267584.0 3.40% 32333824.0
9 branch3x3_2.bn 96 224 224 96 224 224 192.0 18.00 19,267,584.0 9,633,792.0 19268352.0 19267584.0 0.71% 38535936.0
10 branch3x3_3.conv 96 224 224 96 224 224 82944.0 18.00 8,318,779,392.0 4,161,798,144.0 19599360.0 19267584.0 4.28% 38866944.0
11 branch3x3_3.bn 96 224 224 96 224 224 192.0 18.00 19,267,584.0 9,633,792.0 19268352.0 19267584.0 0.48% 38535936.0
12 branchpool_1 3 224 224 3 222 222 0.0 0.00 1,330,668.0 150,528.0 602112.0 591408.0 0.24% 1193520.0
13 branch_pool.conv 3 222 222 64 224 224 192.0 12.00 16,056,320.0 9,633,792.0 592176.0 12845056.0 30.28% 13437232.0
14 branch_pool.bn 64 224 224 64 224 224 128.0 12.00 12,845,056.0 6,422,528.0 12845568.0 12845056.0 0.48% 25690624.0
total 216752.0 186.00 21,727,940,076.0 10,875,898,880.0 12845568.0 12845056.0 100.00% 344852256.0
============================================================================================================================================================
Total params: 216,752
------------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 186.00MB
Total MAdd: 21.73GMAdd
Total Flops: 10.88GFlops
Total MemR+W: 328.88MB
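Since the four branches emit 64, 64, 96, and 64 channels, the torch.cat along dimension 1 yields 64 + 64 + 96 + 64 = 288 output channels. A quick sketch to confirm the concatenated shape (a single random batch; eval mode so BatchNorm uses its running statistics):

model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.size())   # expected: torch.Size([1, 288, 224, 224])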
4. ResNet
from torch import nn
from torchstat import stat

def conv3x3(in_planes, out_planes, stride=1):
    # 3x3 convolution with padding
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class BasicBlock(nn.Module):
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # downsample is an optional module that matches the shortcut's
        # shape to the main path (e.g. a strided 1x1 convolution)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual        # add the shortcut first; ReLU comes only after
        out = self.relu(out)   # the addition, as in the original ResNet design
        return out
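A brief usage sketch: when a block changes the spatial resolution or channel count, the shortcut needs a downsample module with a matching output shape; a strided 1x1 convolution plus BatchNorm is the usual choice. The 64-to-128, stride-2 configuration below is only an illustrative assumption:

# Shortcut projection: 1x1 conv with stride 2, so the residual
# matches the main path's 128 channels and halved resolution
downsample = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(128))
block = BasicBlock(64, 128, stride=2, downsample=downsample)
stat(block, (64, 56, 56))   # torchstat summary of one residual block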