网络参数计算
先定义好网络结构,然后统计网络参数。
网络定义
以LeNet-5为例,参考之前的博客《PyTorch构建网络示例:LeNet-5》,网络结构设计代码如下:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import os
class lenet5(nn.Module):
def __init__(self):
super(lenet5,self).__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')
self.pool1 = nn.AvgPool2d(kernel_size=2, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None)
self.conv2 = nn.Conv2d(6,16,5,1)
self.pool2 = nn.AvgPool2d(2)
self.fc1 = nn.Linear(4*4*16,120)#注:按原始的minst数据集,输入为图示中的32*32时,此处应该是5*5*16.但是按torchvision.datasets中的输入大小则是28*28,此处为4*4*16。或者直接在Data.DataLoader时将输入transforms到32*32大小
self.fc2 = nn.Linear(120,84)
self.out = nn.Linear(84,10)
def forward(self, x):
x = self.pool1(F.relu(self.conv1(x)))
x = self.pool2(F.relu(self.conv2(x)))
x = x.view(x.shape[0], -1)#flatten the output of pool2 to (batch_size, 16 * 4 * 4),x.shape[0]为batch_size,-1为自适应调整大小
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.softmax(self.out(x),dim=-1)
return x
net = lenet5()
参数计算
定义参数统计函数并传入实例化的net:
def cnn_paras_count(net):
"""cnn参数量统计, 使用方式cnn_paras_count(net)"""
# Find total parameters and trainable parameters
total_params = sum(p.numel() for p in net.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')
return total_params, total_trainable_params
cnn_paras_count(net)
输出结果:
模型文件大小预估
需要注意的是,一般模型中参数是以float32 保存的,也就是一个参数由4个bytes 表示,那么就可以将参数量转化为存储大小。 例如:
44426个参数*4 / 1024 ≈ 174KB
查看网络训练中实际存储的模型文件大小:
相差的部分是模型文件中除了需要存储实际参数外还需要存储一些网络信息。总体差异不大。 更进一步可以转化为MB 、GB 等单位。
后记
CNN网络参数计算和模型文件大小预估对实际的网络设计很有帮助,可以提前了解所设计的网络的复杂度,并根据模型文件大小提前规划硬盘存储空间。
|