IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【SegNet搭建-pytorch】反池化上采样indices的记录 -> 正文阅读

[人工智能]【SegNet搭建-pytorch】反池化上采样indices的记录

【SegNet搭建-pytorch】加载预训练权重+利用反池化上采样记录indices

说明:该文章为个人笔记,存在不完整,敬请谅解~~~
论文地址:https://arxiv.org/pdf/1511.00561.pdf

1 掌握

  • 1、加载预训练模型(迁移学习);
  • 2、采用反池化(max pooling)进行上采样,记录使用indices方法;

2 SegNet参考模型图

在这里插入图片描述

3、记录indices技巧(关键部分)

import torch
import torch.nn as nn
from torchvision.models import vgg16

#############1 加载预训练模型################
vgg16_pretrained = vgg16(pretrained=True) 
print(vgg16_pretrained)        #查看vgg16网络架构

#############2 定义需要返回maxpooling中indices索引值###########
for index in [4, 9, 16, 23, 30]:  # 定义返回maxpooling-indices的层
    vgg16_pretrained.features[index].return_indices = True

class SegNet(nn.Module):  # SegNet模型搭建
    def __init__(self):
        super(SegNet, self).__init__()
        # encode layers, maxpool_indices
        self.encode1 = vgg16_pretrained.features[:4]  # 左闭右开,第4层取不到。
        self.pool1 = vgg16_pretrained.features[4]     # 第四层 maxpooling层,单独写出来为了forward方面得到indices
		
		......
		
        self.encode5 = vgg16_pretrained.features[24:30]
        self.pool5 = vgg16_pretrained.features[30]

        self.decode5 = decoder(512, 512)
        self.uppool5 = nn.MaxUnpool2d(2,2)
		
		....
        
        self.decode1 = decoder(64, 12, 2)
        self.unpool1 = nn.MaxUnpool2d(2, 2)
        
    def forward(self, x):  # 前向传播
        encode1 = self.encode1(x)
        encode1_size = encode1.size()  # 3 此处定义size,为方面后面使用反池化的feature map大小
        pool1, indices1 = self.pool1(encode1)  ##### 4 获取encode阶段的indices#################
		...
        encode5 = self.encode4(pool4)
        encode5_size = encode5.size()
        pool5, indices5 = self.pool5(encode5)
        
		############ 5 反池化indices和输出的size大小#########################
        
        unpool5 = self.uppool5(input=pool5, indices=indices5, output_size=encode5_size)   
        decoder5 = self.decode5(input=unpool5)
		...
        unpool1 = self.unpool1(input=decoder2, indices=indices1, output_size=encode1_size)
        decoder1 = self.decoder1(unpool1)

        return decoder1

4、模型搭建

import torch
import torch.nn as nn
from torchvision.models import vgg16

#############1 加载预训练模型################
vgg16_pretrained = vgg16(pretrained=True) 
print(vgg16_pretrained)        #查看vgg16网络架构

#############2 定义需要返回maxpooling中indices索引值###########
for index in [4, 9, 16, 23, 30]:  # 定义返回maxpooling-indices的层
    vgg16_pretrained.features[index].return_indices = True

def decoder(input_channels, output_channels, num=3):  # 解码器
    if num == 3:
        block = nn.Sequential(
            nn.Conv2d(input_channels, input_channels, kernel_size=(3, 3), padding=1),
            nn.Conv2d(input_channels, input_channels, kernel_size=(3, 3), padding=1),
            nn.Conv2d(input_channels, output_channels, kernel_size=(3, 3), padding=1)
        )
    elif num == 2:
        block = nn.Sequential(
            nn.Conv2d(input_channels, input_channels, kernel_size=(3, 3), padding=1),
            nn.Conv2d(input_channels, output_channels, kernel_size=(3, 3), padding=1)
        )

    return block

class SegNet(nn.Module):  # SegNet模型搭建
    def __init__(self):
        super(SegNet, self).__init__()
        # encode layers, maxpool_indices
        self.encode1 = vgg16_pretrained.features[:4]  # 左闭右开,第4层取不到。
        self.pool1 = vgg16_pretrained.features[4]     # 第四层 maxpooling层,单独写出来为了forward方面得到indices

        self.encode2 = vgg16_pretrained.features[5:9]
        self.pool2 =vgg16_pretrained.featuresp[9]

        self.encode3 = vgg16_pretrained.features[10:16]
        self.pool3 = vgg16_pretrained.features[16]

        self.encode4 = vgg16_pretrained.features[17:23]
        self.pool4 = vgg16_pretrained.features[23]

        self.encode5 = vgg16_pretrained.features[24:30]
        self.pool5 = vgg16_pretrained.features[30]

        self.decode5 = decoder(512, 512)
        self.uppool5 = nn.MaxUnpool2d(2,2)

        self.decode4 = decoder(512,256)
        self.uppool4 = nn.MaxUnpool2d(2,2)

        self.decode3 =decoder(256, 128)
        self.uppool3 = nn.MaxUnpool2d(2,2)

        self.decode2 = decoder(128, 64, 2)
        self.unpool2 = nn.MaxUnpool2d(2, 2)

        self.decode1 = decoder(64, 12, 2)
        self.unpool1 = nn.MaxUnpool2d(2, 2)
        
    def forward(self, x):  # 前向传播
        encode1 = self.encode1(x);print('encode1:', encoder1.size())
        encode1_size = encode1.size()  # 此处定义size,为方面后面使用反池化的feature map大小
        pool1, indices1 = self.pool1(encode1)  ##### 获取encode阶段的indices

        encode2 = self.encode2(pool1)
        encode2_size = encode2.size()
        pool2, indices2 = self.pool2(encode2)

        encode3 = self.encode3(pool2)
        encode3_size = encode3.size()
        pool3, indices3 = self.pool3(encode3)

        encode4 = self.encode4(pool3)
        encode4_size = encode4.size()
        pool4, indices4 = self.pool4(encode4)

        encode5 = self.encode4(pool4)
        encode5_size = encode5.size()
        pool5, indices5 = self.pool5(encode5)

        unpool5 = self.uppool5(input=pool5, indices=indices5, output_size=encode5_size)   # 反池化注意indices和输出的size大小
        decoder5 = self.decode5(input=unpool5)

        unpool4 = self.uppool4(input=decoder5,indices = indices4, output_size = encode4_size)
        decoder4 = self.decode4(unpool4)

        unpool3 = self.uppool3(input=decoder4, indices=indices3, output_size=encode3_size)
        decoder3 = self.decode3(unpool3)

        unpool2 = self.unpool2(input=decoder3, indices=indices2, output_size=encode2_size)
        decoder2 = self.decoder2(unpool2)

        unpool1 = self.unpool1(input=decoder2, indices=indices1, output_size=encode1_size)
        decoder1 = self.decoder1(unpool1)

        return decoder1
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-07-10 11:32:56  更:2021-07-10 11:33:48 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/22 9:46:43-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码