1.问题描述
最近在看语义分割的源码,但是预测阶段有一行代码看的我头大,不知道在索引些什么东西。后来经过查阅资料和实验发现这是numpy的高级索引,故写这篇博文记录一下自己是如何理解的。这行代码长这样。意思是把前向传播得到的预测值取最大值,并把最大值的索引值赋值给pre_label((352,480)的矩阵)。这个索引值代表了一张图经过模型的预测后,每一个像素值分别属于哪一个类别。通过这张全图都是索引值的图,去cm(color map)中找属于这个索引值的相应颜色(每个索引值对应了RGB3个通道的颜色),其中最关键的一部分就是pre = cm[pre_label]这一行代码。
for i, sample in enumerate(test_data):
valImg = sample['img'].to(device)
valLabel = sample['label'].long().to(device)
out = net(valImg)
out = F.log_softmax(out, dim=1)
pre_label = out.max(1)[1].squeeze().cpu().data.numpy()
pre = cm[pre_label]
pre1 = Image.fromarray(pre)
pre1.save(dir + str(i) + '.png')
print('Done')
那么问题来了!(12,3)的矩阵的索引是一个(352,480)的矩阵,索引出来是个什么玩意?
2.查阅资料
资料:根据用户手册,numpy数组支持数组索引。返回的数组与索引数组具有相同的形状,与原数组元素具有相同的类型和值(被索引位置)。针对你的问题,也就是你理解的:返回的还是一个二维数组,返回数组的值是以二维数组每个元素作为一维数组索引在一维数组中的值。
资料看半天,也不像是人话。所以自己动手做实验。
3.实验
首先定义数组,label模拟上面的cm(color map),index模拟上面的pre_label((352,480)的矩阵),但是我这边设置这么大的矩阵不利于查看结果,所以我只设置了(4,5)的矩阵,方便后面查看结果。
先看一下经过这种特殊索引后的shape 从这个shape可以看出,由于label[index]索引的时候并没有逗号,所以是对第一个维度的数据进行操作。再结合第二部分的资料就很好理解了:返回的数组与索引数组具有相同的形状,与原数组元素具有相同的类型和值(被索引位置)。资料中是针对的索引是一维数组的情况下。那么拓展一下,把一维(数组)拓展到二维(矩阵)。那么这条也是成立的。即把label原本(12,3)的shape的第一个维度拿出来,也就是12。被(4,5)的矩阵索引后。12这个维度就变成了(4,5)。其他维度不变。所以shape是(4,5,3)。那么这个(4, 5,3)的矩阵里面究竟是什么呢?通过之前学到numpy高级索引很容易就得到结果了。就是将index里的每一个元素的值(我这里设置的是1)当做label第一个维度(也就是行)进行索引。所以组成的(4,5,3)这个矩阵里全部由label中第二行(0对应第一行,1对应是第二行)组成。
|