[Building SegNet in PyTorch] Loading pretrained weights and upsampling with max unpooling (recording indices)
Note: this article is a set of personal notes and may be incomplete, please bear with it. Paper: https://arxiv.org/pdf/1511.00561.pdf
1 What this covers
- 1. Loading a pretrained model (transfer learning);
- 2. Upsampling with max unpooling (nn.MaxUnpool2d), which requires recording the indices returned by the encoder's max-pooling layers (a standalone sketch follows this list).
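To see the mechanism in isolation: with return_indices=True a MaxPool2d layer returns both the pooled tensor and the positions of the maxima, and nn.MaxUnpool2d uses those positions to put values back where they came from. A minimal standalone sketch (my own example, not part of the model code below):

import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)  # also returns the argmax positions
unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
pooled, indices = pool(x)                                  # pooled: (1, 1, 2, 2)
restored = unpool(pooled, indices, output_size=x.size())   # zeros everywhere except at the max positions
print(pooled.shape, restored.shape)                        # (1, 1, 2, 2) and (1, 1, 4, 4)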
2 SegNet reference architecture
(Figure from the paper: a VGG16 convolutional encoder and a mirrored decoder that upsamples using the max-pooling indices stored during encoding.)
3 Recording the indices (the key part)
import torch
import torch.nn as nn
from torchvision.models import vgg16

# ---------- 1 Load the pretrained model (transfer learning) ----------
vgg16_pretrained = vgg16(pretrained=True)
print(vgg16_pretrained)  # inspect the VGG16 architecture

# ---------- 2 Make the max-pooling layers return their indices ----------
for index in [4, 9, 16, 23, 30]:  # the MaxPool2d layers inside vgg16.features
    vgg16_pretrained.features[index].return_indices = True


class SegNet(nn.Module):  # SegNet model
    def __init__(self):
        super(SegNet, self).__init__()
        # encoder layers; the pooling layers are kept separate so forward() can read their indices
        self.encode1 = vgg16_pretrained.features[:4]  # half-open slice, layer 4 is not included
        self.pool1 = vgg16_pretrained.features[4]     # layer 4 is the MaxPool2d, split out to get its indices
        ......
        self.encode5 = vgg16_pretrained.features[24:30]
        self.pool5 = vgg16_pretrained.features[30]

        self.decode5 = decoder(512, 512)
        self.unpool5 = nn.MaxUnpool2d(2, 2)
        ....
        self.decode1 = decoder(64, 12, 2)
        self.unpool1 = nn.MaxUnpool2d(2, 2)

    def forward(self, x):  # forward pass
        encode1 = self.encode1(x)
        encode1_size = encode1.size()          # 3 save the pre-pooling size for the matching unpooling step
        pool1, indices1 = self.pool1(encode1)  # 4 grab the indices produced during encoding
        ...
        encode5 = self.encode5(pool4)
        encode5_size = encode5.size()
        pool5, indices5 = self.pool5(encode5)

        # 5 unpooling: pass both the saved indices and the saved output size
        unpool5 = self.unpool5(input=pool5, indices=indices5, output_size=encode5_size)
        decoder5 = self.decode5(unpool5)
        ...
        unpool1 = self.unpool1(input=decoder2, indices=indices1, output_size=encode1_size)
        decoder1 = self.decode1(unpool1)
        return decoder1
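Why output_size is passed to the unpooling layers: when a spatial dimension before pooling is odd, the unpooled size is ambiguous (both 44 and 45 pool down to 22), so the encoder feature-map size is saved and handed to nn.MaxUnpool2d. A small standalone sketch (my own example with made-up sizes):

import torch
import torch.nn as nn

pool = nn.MaxPool2d(2, 2, return_indices=True)
unpool = nn.MaxUnpool2d(2, 2)

x = torch.randn(1, 64, 45, 60)   # odd height, as happens with real image sizes
size_before = x.size()            # what the model saves as encodeN_size
pooled, idx = pool(x)             # -> (1, 64, 22, 30)
restored = unpool(pooled, idx, output_size=size_before)
print(restored.shape)             # (1, 64, 45, 60); without output_size it would come back as 44 x 60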
4 Building the full model
import torch
import torch.nn as nn
from torchvision.models import vgg16

# ---------- 1 Load the pretrained model ----------
vgg16_pretrained = vgg16(pretrained=True)
print(vgg16_pretrained)  # inspect the VGG16 architecture

# ---------- 2 Make the max-pooling layers return their indices ----------
for index in [4, 9, 16, 23, 30]:  # the MaxPool2d layers inside vgg16.features
    vgg16_pretrained.features[index].return_indices = True
def decoder(input_channels, output_channels, num=3):  # decoder block: num 3x3 convolutions
    if num == 3:
        block = nn.Sequential(
            nn.Conv2d(input_channels, input_channels, kernel_size=(3, 3), padding=1),
            nn.Conv2d(input_channels, input_channels, kernel_size=(3, 3), padding=1),
            nn.Conv2d(input_channels, output_channels, kernel_size=(3, 3), padding=1)
        )
    elif num == 2:
        block = nn.Sequential(
            nn.Conv2d(input_channels, input_channels, kernel_size=(3, 3), padding=1),
            nn.Conv2d(input_channels, output_channels, kernel_size=(3, 3), padding=1)
        )
    return block
class SegNet(nn.Module):  # SegNet model
    def __init__(self):
        super(SegNet, self).__init__()
        # encoder layers; the pooling layers are kept separate so forward() can read their indices
        self.encode1 = vgg16_pretrained.features[:4]  # half-open slice, layer 4 is not included
        self.pool1 = vgg16_pretrained.features[4]     # layer 4 is the MaxPool2d, split out to get its indices
        self.encode2 = vgg16_pretrained.features[5:9]
        self.pool2 = vgg16_pretrained.features[9]
        self.encode3 = vgg16_pretrained.features[10:16]
        self.pool3 = vgg16_pretrained.features[16]
        self.encode4 = vgg16_pretrained.features[17:23]
        self.pool4 = vgg16_pretrained.features[23]
        self.encode5 = vgg16_pretrained.features[24:30]
        self.pool5 = vgg16_pretrained.features[30]

        # decoder layers: each unpooling doubles the spatial size, each decoder block reduces channels
        self.decode5 = decoder(512, 512)
        self.unpool5 = nn.MaxUnpool2d(2, 2)
        self.decode4 = decoder(512, 256)
        self.unpool4 = nn.MaxUnpool2d(2, 2)
        self.decode3 = decoder(256, 128)
        self.unpool3 = nn.MaxUnpool2d(2, 2)
        self.decode2 = decoder(128, 64, 2)
        self.unpool2 = nn.MaxUnpool2d(2, 2)
        self.decode1 = decoder(64, 12, 2)  # 12 output channels = number of segmentation classes
        self.unpool1 = nn.MaxUnpool2d(2, 2)
    def forward(self, x):  # forward pass
        # encoder: save the pre-pooling sizes and the pooling indices
        encode1 = self.encode1(x)
        print('encode1:', encode1.size())      # debug print of the first encoder output
        encode1_size = encode1.size()          # saved so the matching unpooling step can restore this size
        pool1, indices1 = self.pool1(encode1)  # indices of the maxima, recorded during encoding
        encode2 = self.encode2(pool1)
        encode2_size = encode2.size()
        pool2, indices2 = self.pool2(encode2)
        encode3 = self.encode3(pool2)
        encode3_size = encode3.size()
        pool3, indices3 = self.pool3(encode3)
        encode4 = self.encode4(pool3)
        encode4_size = encode4.size()
        pool4, indices4 = self.pool4(encode4)
        encode5 = self.encode5(pool4)
        encode5_size = encode5.size()
        pool5, indices5 = self.pool5(encode5)

        # decoder: unpool with the saved indices and sizes, then refine with convolutions
        unpool5 = self.unpool5(input=pool5, indices=indices5, output_size=encode5_size)
        decoder5 = self.decode5(unpool5)
        unpool4 = self.unpool4(input=decoder5, indices=indices4, output_size=encode4_size)
        decoder4 = self.decode4(unpool4)
        unpool3 = self.unpool3(input=decoder4, indices=indices3, output_size=encode3_size)
        decoder3 = self.decode3(unpool3)
        unpool2 = self.unpool2(input=decoder3, indices=indices2, output_size=encode2_size)
        decoder2 = self.decode2(unpool2)
        unpool1 = self.unpool1(input=decoder2, indices=indices1, output_size=encode1_size)
        decoder1 = self.decode1(unpool1)
        return decoder1
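As a quick sanity check (my own addition, not in the original notes), the model can be instantiated and run on a dummy batch; the 224 x 224 input size and the 12 output classes are assumptions matching the decoder above:

if __name__ == '__main__':
    model = SegNet()
    dummy = torch.randn(2, 3, 224, 224)  # batch of two RGB images (size assumed)
    out = model(dummy)
    print(out.shape)                     # expected: torch.Size([2, 12, 224, 224])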