介绍
在使用Pytorch时,我们经常需要对一些tensor进行形状的改变以满足神经网络对输入数据的维度要求,我们最常用的两种方式就是.view() 以及 .reshape() , 除此之外,还有一种方法是.resize_() , 这种方法不仅可以改变数据的形状,同时还可以做到数据的部分截取。
在这篇博文中,我会将前两种方式,即.view() 和.reshape() ,作为主要的介绍对象。
1.view()
我们通过调用.view() 方法返回的结果是原始张量数据的另一个视图,即返回的张量与原始张量共享基础数据(但是不是共享内存地址!)。同时当我们使用.view() 对原始张量进行reshape, 一个必要的条件是原始的张量必须在内存空间中连续分布(contiguous)。比如当我们在调用.view() 方法前使用了诸如.permute() 或者.transpose() 等张量维度调换的操作,就会产生该张量在内存空间分布不连续的情况,从而当我们进行.view() 操作时,Pytorch会报如下错误
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
当我们遇到如上情况时,我们可以先将tensor进行tensor.contiguous() 变换后再进行.view() ,这样就不会报错了。或者也可以用.reshape() 方法来代替.view() 方法,同样也是可以的。下面我们就来看一看.reshape() 方法。
2.reshape()
从作用上来说,.reshape() 方法和.view() 方法是一致的,都是想要在不改变数据总量的情况下改变原有张量的形状。但是正如前文所说,使用.view() 方法进行形状重构需要保证原始tensor在内存空间中分布连续,不然无法进行重构。而.reshape() 方法却不需要这个条件,因为.reshape() 方法既可以返回原始tensor的视图(view)也可以返回原始tensor的一个copy, 具体哪种情况取决于这个原始tensor的内存空间分布是否连续。即如果发现这个原始tensor因为经过.permute() 或.transpose() 等操作变得在内存空间不连续,那么.reshape() 方法会返回copy,而如果是正常情况,使用.reshape() 也和.view() 一样返回view。
以上是我的个人理解,如果存在错误欢迎大家批评指正。
|