IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> PYTORCH中CNN网络参数计算和模型文件大小预估 -> 正文阅读

[人工智能]PYTORCH中CNN网络参数计算和模型文件大小预估

网络参数计算

先定义好网络结构,然后统计网络参数。

网络定义

以LeNet-5为例,参考之前的博客《PyTorch构建网络示例:LeNet-5》,网络结构设计代码如下:

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import os

class lenet5(nn.Module):
    def __init__(self):
        super(lenet5,self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None)
        self.conv2 = nn.Conv2d(6,16,5,1)
        self.pool2 = nn.AvgPool2d(2)
        self.fc1 = nn.Linear(4*4*16,120)#注:按原始的minst数据集,输入为图示中的32*32时,此处应该是5*5*16.但是按torchvision.datasets中的输入大小则是28*28,此处为4*4*16。或者直接在Data.DataLoader时将输入transforms到32*32大小
        self.fc2 = nn.Linear(120,84)
        self.out = nn.Linear(84,10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(x.shape[0], -1)#flatten the output of pool2 to (batch_size, 16 * 4 * 4),x.shape[0]为batch_size,-1为自适应调整大小
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.out(x),dim=-1)
        return x

net = lenet5()

参数计算

定义参数统计函数并传入实例化的net:

def cnn_paras_count(net):
    """cnn参数量统计, 使用方式cnn_paras_count(net)"""
    # Find total parameters and trainable parameters
    total_params = sum(p.numel() for p in net.parameters())
    print(f'{total_params:,} total parameters.')
    total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f'{total_trainable_params:,} training parameters.')
    return total_params, total_trainable_params

cnn_paras_count(net)

输出结果:

模型文件大小预估

需要注意的是,一般模型中参数是以float32保存的,也就是一个参数由4个bytes表示,那么就可以将参数量转化为存储大小。
例如:

44426个参数*4 / 1024 ≈ 174KB

查看网络训练中实际存储的模型文件大小:

相差的部分是模型文件中除了需要存储实际参数外还需要存储一些网络信息。总体差异不大。
更进一步可以转化为MBGB等单位。

后记

CNN网络参数计算和模型文件大小预估对实际的网络设计很有帮助,可以提前了解所设计的网络的复杂度,并根据模型文件大小提前规划硬盘存储空间。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-12 16:27:40  更:2022-05-12 16:29:43 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/1 23:16:36-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码