IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 网络模型中层的修改,删除与添加 -> 正文阅读

[人工智能]网络模型中层的修改,删除与添加

1. 权重参数保存的两种方式

第一种: 将网络模型和对应的参数保存在一起;

第二种: 模型和参数分离, 单独的保存模型的权重参数,

方式二官方推荐, 方便于网络模型修改后, 提取出对应层的参数权重;

1.1 模型和参数保存在一起

  1. 保存时:
 net123 = module.CustomModel()
# CustomModel 是自己定义的模型类, 放在 module 的文件中;

torch.save( net123,"./weights/All_in.pth")

  1. 加载模型时:
 net123 = torch.load("./weights/All_in.pth")

1.1 单独保存模型的参数-state_dict()

  1. 保存时:
net123 = module.CustomModel()
# CustomModel 是自己定义的模型类, 放在 module 的文件中;

torch.save(net123.state_dict(),'./weights/epoch_weight.pth')

  1. 加载模型时:
net123.load_state_dict(torch.load('epoch_weight.pth'))

2. 移除网络中的层

2.1 思路

具体的逻辑思路:

  1. 先删除网络所对应的层;
  2. 然后拷贝一份原始网络权重参数, 去除 步骤1中删除层所对应的权重参数;

2.2 删除网络层

具体的个人实现方式是:

  1. 拷贝原始网络模型的类, 将拷贝过来的类重命名为自己的类 class CustomModel(nn.Module)

  2. 在这个重命名类的基础上,进行修改,删除网络层;

其中重点修改该类中两个函数, 1. 初始化函数,2. 前向传播函数:

  1. 初始化函数部分, 注释掉自己不需要的属性;
  2. 前向传播函数部分, 注释掉,自己想要删除的那个网络层;
class CustomModle(nn.Module)

1. 初始化函数,`def __init__(self)`: 
2. 前向传播函数, `def   forward():`

举例讲来,假设这里需要移除 self.patch_embed 层,
那么在前向传播函数中, 便注释掉该层的 前向传播;

   def forward_features(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        # x = self.patch_embed(x)  # [B, 196, 768]
        # [1, 1, 768] -> [B, 1, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)

        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

2.3 删除网络层所对应的权重参数

  1. 加载原始的网络权重文件;
  2. 写出,需要删除网络层所对应的参数名称,
    (这一步,可以通过调试到 self.net = your_model()时,查看网络中每一层所对应的名字)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block
  1. 通过字典索引的方式,删除层所对应的权重参数;
    key: 删除网络层的名称, value: 该层所对应的权重;
       if self.args.weights is not None:
                assert os.path.exists(self.args.weights), "weights file: '{}' not exist.".format(self.args.weights)
                weights_dict = torch.load(self.args.weights, map_location='cuda:0')
                # 删除不需要的权重
                del_keys_patch_embed = ['patch_embed.proj.bias', 'patch_embed.proj.weight']
                for k in del_keys_patch_embed:
                    del weights_dict[k]


                del_keys = ['head.weight', 'head.bias'] if self.net.has_logits \
                    else ['pre_logits.fc.weight', 'pre_logits.fc.bias', 'head.weight', 'head.bias']
                for k in del_keys:
                    del weights_dict[k]
                print(self.net.load_state_dict(weights_dict, strict=False))

2.4 另外一种移除权重参数的方法

import torch
from collections import OrderedDict
import os
import torch.nn as nn
import torch.nn.init as init
from xxx import new_VGG
 
def init_weight(modules):
    for m in modules:
        if isinstance(m, nn.Conv2d):
            init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal(0,0.01)
            m.bias.data.zero_()
 
def copyStateDict(state_dict):
    if list(state_dict.keys())[0].startswith('module'):
        start_idx = 1
    else:
        start_idx = 0
    new_state_dict = OrderedDict()
    for k,v in state_dict.items():
        name = ','.join(k.split('.')[state_idx:])
        new_state_dict[name] = v
    return new_state_dict
 
#加载pretrain model
state_dict = torch.load('/users/xxx/xxx.pth')
 
new_dict = copyStateDict(state_dict)
keys = []
for k,v in new_dict.items():
  #将‘conv_cls’开头的key过滤掉,这里是要去除的层的key
    if k.startswith('conv_cls'):  
        continue
    keys.append(k)
 
#去除指定层后的模型
new_dict = {k:new_dict[k] for k in keys}


#自己定义的模型,但要保证前面保存的层和自定义的模型中的层一致
net = new_VGG()  
 
#加载pretrain model中的参数到新的模型中,
#此时自定义的层中是没有参数的,在使用的时候需要init_weight一下
net.state_dict().update(new_dict)
 
#保存去除指定层后的模型
torch.save(net.state_dict(), '/users/xxx/xxx.pth')

3. 修改特定层

def Net(nn.Module):
	def __init__(self, input_ch, num_class,pretrained=True):
		super(Net,self).__init__()
		self.model = models.resnet50(pretrained=pretrained)
		conv1 = nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False) #新的conv1层
		self.model.conv1 = conv1 #替换原来的conv1
	def forward(self,x):
		x = self.model(x)
		return x	

按照上方的操作,则conv1的预训练权重无法被利用。为了能够利用到conv1的预训练权重,我们沿着dim=1取平局,拓展平均后的权重至与新conv1权重维度一致。

def Net(nn.Module):
	def __init__(self, input_ch, num_class,pretrained=True):
		super(Net,self).__init__()
		self.model = models.resnet50(pretrained=pretrained)
		conv1_weight = torch.mean(self.model.conv1.weight,dim=1,keepdim=True).repeat(1,input_ch,1,1)#取出从conv1权重并进行平均和拓展
		conv1 = nn.Conv2d(input_ch, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False) #新的conv1层
		model_dict = self.model.state_dict()#获取整个网络的预训练权重
		self.model.conv1 = conv1 #替换原来的conv1
		model_dict['conv1.weight'] = conv1_weight #将conv1权重替换为新conv1权重
		model_dict.update(model_dict)#更新整个网络的预训练权重
		self.model.load_state_dict(model_dict)#载入新预训练权重
		
	def forward(self,x):
		x = self.model(x)
		return x	

4. 使用网络层所对应的序号

#为true时,网络参数为在数据集上训练好的
vgg16=tv.models.vgg16(True)
#在features的后面增加一层,classifier为一属性
vgg16.features.add_module('extra',nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
#修改指定层多的信息:可用序号修改指定的层:
vgg16.features[31]=nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#删除掉features的最后一层:
vgg16.features=nn.Sequential(*list(vgg16.features.children())[:-1])
print(vgg16)
#冻结指定层的预训练参数:
vgg16.features[26].weight.requires_grad = False
#optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.1)

reference:

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-11 16:27:03  更:2022-05-11 16:30:13 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 6:45:06-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码