本意是想大概跑一下CNN对掌纹进行识别分类的代码,了解一下流程和框架。基本内容参考基于CNN对掌纹图片进行分类。
1-1.RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x100820 and 408980x230)
1-2 RuntimeError: shape ‘[-1, 72000]’ is invalid for input of size 3226240
这两个问题都是因为没有正确理解和计算卷积层、池化层的输入和输出大小而直接套用别人的CNN网络导致的。 制错代码:
class net(nn.Module):
def __init__(self):
super().__init__()
self.conv1=nn.Conv2d(1,10,5)#进行2d卷积,因为是灰度图像,所以初试通道是1。输入通道为1,输出通道是10,进行5*5的2d卷积
self.conv2=nn.Conv2d(10,20,3)
self.fc1=nn.Linear(20*60*60,500)
self.fc2=nn.Linear(500,386)#要分386类
def forward(self,x):
input_size=x.size(0) #batch_size)
x=self.conv1(x) # 输入:batch *1*128*128 输出 batch *10*124*124
x=F.relu(x) #size保持不变
x=F.max_pool2d(x,2,2) #输入 batch *10*124*124 输出batch*10*62*62,对图片进行压缩,减少运算。
x=self.conv2(x) #输入 batch*10*62*62 输出 batch*20*60*60
x=F.relu(x)#激活层,不改变图片的shape,每次卷积之后进行一次激活,输出一个非线性函数,增强神经元的表达能力
x=x.view(input_size,-1) #将图片转化为一维线性
x=self.fc1(x)#输入 batch*30*60*60 输出 batch*500
x=F.relu(x)
x=self.fc2(x)#输入 batch*500 输出 batch*386
out_put=F.log_softmax(x,dim=1)#计算损失函数,输出概率最大的类别
return out_put
我使用的图像数据大小为150*150,灰度图像是1通道。在卷积层的参数不变的情况下,由于输入大小不同,输出尺寸也不同,所以需要根据公式重新计算,参考卷积神经网络中各个卷积层的设置及输出大小计算的详细讲解,卷积神将网络的计算公式为:N=(W-F+2P)/S+1 其中N:输出大小;W:输入大小;F:卷积核大小;P:填充值的大小;S:步长大小。
卷积层和池化层会改变图像尺寸,ReLU()不会。将卷积层进一步打包,重新计算输入输出得到改正的CNN:
class net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(1, 10, 5), # 输入:1*150*150 输出:10*146*146 (146=(150-5)/1+1)
nn.ReLU(),
nn.MaxPool2d(2),) # 输入:10*146*146 输出:10*73*73 (73=(146-2)/2+1)
self.conv2 = nn.Sequential(
nn.Conv2d(10, 20, 3), # 输入:10*73*73 输出:20*71*71 (143=(73-3)/1+1)
nn.ReLU(),)
self.fc1 = nn.Linear(20*71*71, 230) # 要分230类
#self.fc2 = nn.Linear(500, 230)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1) # 将图片转化为一维线性
x = self.fc1(x) # 输入 batch*20*71*71 输出 batch*230
out_put = F.log_softmax(x, dim=1) # 计算损失函数,输出概率最大的类别
return out_put
问题解决!
2. IndexError: Target 230 is out of bounds.
这个数据集的ROI来自230个人,所以多分类为230类:1-230。 报错是指标签230超出范围,linear层的输出分类也没有错误,参考这篇文章 ,将label改为0-229:
train_label.append(int(re.findall(r"\d+", filename)[0])-1)#将图片名的第一个数字减1(0-229)作为label
运行成功!
在未修改网络结构和参数的情况下,跑的结果:
|