本文使用 PyTorch 框架来搭建一个多分支(三输入)的神经网络,编程时借鉴了 Inception 的编程思想。准备实现的网络结构如下图所示:
(图画得有点丑)
每个分支的输入图片大小设置为64*64,卷积层和池化层的参数设置如下表所示
| 层 | 参数 |
|---|---|
| CONV1 | 3, 16, kernel_size=3, stride=1, padding=1 |
| Pooling1 | kernel_size=2, stride=2 |
| CONV2 | 16, 32, kernel_size=3, stride=1, padding=1 |
| Pooling2 | kernel_size=2, stride=2 |
| CONV3 | 32, 64, kernel_size=3, stride=1, padding=1 |
| Pooling3 | kernel_size=2, stride=2 |
| CONV4 | 64, 128, kernel_size=3, stride=1, padding=1 |
| Pooling4 | kernel_size=2, stride=2 |
首先,导入需要的库
import torch
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch import nn, optim
接下来,定义网络模型
class ThreeInputsNet(nn.Module):
    """CNN that takes three images and produces one 3-way output per sample.

    Each input is passed through the same four conv + max-pool stages
    (``conv1`` .. ``conv4`` with ``pooling1``), so the three branches share
    their convolution weights. The three feature maps are concatenated along
    the channel dimension, flattened, and fed to three linear layers.

    ``outlayer1`` expects ``3 * 128 * 4 * 4`` input features, which is what
    three ``(N, 3, 64, 64)`` inputs produce after four stride-2 poolings.

    NOTE(review): there are no activation functions between layers, so the
    three linear layers compose into a single affine map; confirm this is
    intended.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept as in the original listing so that
        # parameter initialisation under a fixed seed is unchanged.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        # MaxPool2d has no parameters, so one instance is reused at every stage.
        self.pooling1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.outlayer1 = nn.Linear(3 * 128 * 4 * 4, 128 * 3)
        self.outlayer2 = nn.Linear(128 * 3, 256)
        self.outlayer3 = nn.Linear(256, 3)

    def _branch(self, x: torch.Tensor) -> torch.Tensor:
        """Run one input through the four shared conv + pool stages."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.pooling1(conv(x))
        return x

    def forward(
        self,
        input1: torch.Tensor,
        input2: torch.Tensor,
        input3: torch.Tensor,
    ) -> torch.Tensor:
        """Return a ``(N, 3)`` tensor computed from three image batches.

        Args:
            input1: First image batch, ``(N, 3, 64, 64)``.
            input2: Second image batch, same shape as ``input1``.
            input3: Third image batch, same shape as ``input1``.
        """
        features = [self._branch(x) for x in (input1, input2, input3)]
        # Concatenate along channels: 3 branches * 128 channels each.
        out = torch.cat(features, dim=1)
        out = torch.flatten(out, start_dim=1)
        out = self.outlayer1(out)
        out = self.outlayer2(out)
        out = self.outlayer3(out)
        return out
输入一些数据测试一下网络能否跑通
if __name__ == '__main__':
    # Smoke test: feed three dummy batches of 8 RGB 64x64 images through the
    # network and print the output shape.
    input1 = torch.ones(8, 3, 64, 64)
    input2 = torch.ones(8, 3, 64, 64)
    input3 = torch.ones(8, 3, 64, 64)
    net = ThreeInputsNet()
    output = net(input1, input2, input3)
    # The last layer has 3 output features, so this prints torch.Size([8, 3]).
    print("out.shape:{}".format(output.shape))
完整代码
import torch
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch import nn, optim
class ThreeInputsNet(nn.Module):
    """Three-input CNN whose branches share one convolutional stack.

    Every input goes through ``conv1`` .. ``conv4``, each followed by the
    same 2x2 max-pool, so the branches use identical convolution weights.
    The resulting feature maps are concatenated on the channel axis,
    flattened, and passed through three linear layers ending in 3 outputs.

    ``outlayer1`` takes ``3 * 128 * 4 * 4`` features, i.e. the inputs must be
    ``(N, 3, 64, 64)`` so that four stride-2 poolings leave a 4x4 map.

    NOTE(review): the hidden width here is ``128 * 5`` while the step-by-step
    listing earlier in the article uses ``128 * 3``; confirm which is meant.
    NOTE(review): no activation functions are applied between layers, so the
    three linear layers compose into a single affine map.
    """

    def __init__(self):
        super().__init__()
        # Same registration order as the original listing, so seeded
        # initialisation and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        # Parameter-free, hence safely shared by all stages and branches.
        self.pooling1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.outlayer1 = nn.Linear(3 * 128 * 4 * 4, 128 * 5)
        self.outlayer2 = nn.Linear(128 * 5, 256)
        self.outlayer3 = nn.Linear(256, 3)

    def _branch(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the four shared conv + pool stages to a single input."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.pooling1(conv(x))
        return x

    def forward(
        self,
        input1: torch.Tensor,
        input2: torch.Tensor,
        input3: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the ``(N, 3)`` output for three ``(N, 3, 64, 64)`` batches."""
        branch_maps = [self._branch(x) for x in (input1, input2, input3)]
        # Channel-wise concatenation gives 3 * 128 channels at 4x4 resolution.
        out = torch.cat(branch_maps, dim=1)
        out = torch.flatten(out, start_dim=1)
        out = self.outlayer1(out)
        out = self.outlayer2(out)
        out = self.outlayer3(out)
        return out
if __name__ == '__main__':
    # Smoke test for the complete listing: three dummy batches of 8 RGB
    # 64x64 images, one per network input.
    input1 = torch.ones(8, 3, 64, 64)
    input2 = torch.ones(8, 3, 64, 64)
    input3 = torch.ones(8, 3, 64, 64)
    net = ThreeInputsNet()
    output = net(input1, input2, input3)
    # With 3 output features this prints torch.Size([8, 3]).
    print("out.shape:{}".format(output.shape))
|