matplotlib绘图imshow()函数报错“TypeError: Invalid dimensions for image data”
错误代码
plt.imshow((img[6, :, :, :].moveaxis(0, 2)))
改为
plt.imshow((img[6, :, :, :]))
报错 TypeError: Invalid dimensions for image data”
修改为:
plt.imshow((img[6, :, :, :].squeeze().numpy().transpose(1,2,0)))
参考 解决这个问题的关键就是理解了imshow函数的参数。 matplotlib.pyplot.imshow()函数的输入需要是二维的numpy或者是第三维度是3或4的numpy,
- 当第3维深度是1时,使用np.squeeze()函数压缩数据成为二维数组。
- 因为我在pytorch环境下使用,得到结果的输出是(batch_size,channel,width,height)的tensor,因此我首先需要detach()函数切断反向传播。
- 需要指出的是,imshow不支持显示tensor,因此,我需要使用.cpu()函数转移到cpu上来。
- 正如前面说到的,imshow函数的输入需要是二维的numpy或者第三维度是3或4的numpy,
- 因为我的使用情况比较特殊,还多了一个batch_size维度,不过还好,我设置batch_size仅为1,这时候可以使用.squeeze()函数把1给去掉,得到了是一个(channel,widht,height)的numpy,这显然与imshow的输入要求不符。因此,我们需要使用transpose函数把channel(=3)移动到最后,这也是为什么才有了.transpose(1,2,0)这种用法。当然,如果待显示的图像本身就是channel=1,那么完全可以使用squeeze()函数把其搞掉,直接输入给imshow函数一个二维的numpy.
|