a=1
print(id(a))
a=2
print(id(a))
并不是在1的空间删除填上2,而是新开辟了空间。
a=[1]
print(id(a[0]))
a[0]=1
print(id(a[0]))
这个是Inplace操作。
embedding=nn.Parameter(torch.rand(2,3))
d=nn.Parameter(torch.rand(3,3))
user_embeddings=embedding.clone()
user_embedding_input = user_embeddings[0]
a=user_embedding_input*3
print(a)
a=torch.matmul(d,user_embedding_input)
print(a)
user_embeddings[0]=a
loss=a.sum()
loss.backward()
报错。
这里涉及一个概念,你直接[0]这样索引,这种属于selectbackward 。不会创建新的内存空间,类似的还有slicebackward (例如b[:2,:1]),其也不会创建新的内存空间。然后在后面又进行了赋值,这样,在计算d的梯度的时候显然会报错。
embedding=nn.Parameter(torch.rand(2,3))
d=nn.Parameter(torch.rand(3,3))
user_embeddings=embedding.clone()
user_embedding_input = user_embeddings[[0],:]
a=torch.matmul(user_embedding_input,d)
print(a)
user_embeddings[[0],:]=a
loss=a.sum()
loss.backward()
不报错,上面的索引是indexbackward ,这个相当于创建了一个新的变量,然后index操作,梯度回传即可。虽然后面user_embeddings改了,但是那个属于中间节点,把user_embedding_input的梯度传过来即可,然后再传给前面的embedding,可以发现,user_embeddings改不改都没有关系。这并不会导致什么错误,而且反向传播之后会清空中间节点的梯度。(补充:indexbackward 取出的时候会创建新变量,并和原来脱离关系,但是如果是要更新vv,则三种索引都会改变vv。vv[select]=1,vv[slice]=1,vv[index]=1这三者都会改变vv。这可能是pytorch出于方便考虑的。总之index和前两者只有在取出来的时候会不一样。)
embedding=nn.Parameter(torch.rand(2,3))
d=nn.Parameter(torch.rand(3,3))
user_embeddings=embedding.clone()
user_embedding_input = user_embeddings[[0],:]
a=torch.matmul(user_embedding_input,d)
print(a)
user_embeddings[[0],:]=a
user_embedding_input=3#question line
loss=a.sum()
loss.backward()#是否报错?
不报错。
这里有人有疑问了,为什么user_embedding_input改了还是不报错,这是因为计算梯度有缓存,而且这个改也不是Inplace的,Pytorch已经缓存了那个原来的空间,所以不报错。
user_embedding_input[0]=2
如果你这么操作,那么就会报错了。
另外一个知识点,中间节点的赋值会连带上之前的计算图。
a=nn.Parameter(torch.tensor([[2.]]))#叶子节点。
b=a.clone()#中间节点。
print(b)
d=nn.Parameter(torch.tensor(3.))
print(d)
e=b[[0],:]*d
b[0]=e#赋值,会带上e的历史。而不仅仅是一个数据。
print(e)
loss1=e.sum()
e=b[[0],:]*d
b[0]=e
print(e)
loss1+=e.sum()
loss1.backward()
d.grad#14
a.grad#12
上面这样其实有点类似于RNN了,这个你能否计算对呢?
补充,上面画得有点不对,那个b节点也应该分裂成两个。我们假设b分成b1,b2,然后下面那两个b1,b2改名叫做bb1,bb2。
这里的难点在于,b1被Inplace了,讲道理b1是从a那里复制过来的,所以是cloneback。然后bb1是从b1那里index过来的。那b1的grad是怎么记录的呢?因为b1已经不存在了,缓存也失效了,因为直接被Inplace了。既然无法记录,如果继续反向传播到a节点呢?
补充
- 一个tensor不可导,对其部分进行赋值inplace一个可导的,整个tensor都会变得可导,也就是说pytorch里面计算梯度是以对象为单位的。
|