IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 深入理解pytorch中计算图的inplace操作 -> 正文阅读

[人工智能]深入理解pytorch中计算图的inplace操作

a=1
print(id(a))
a=2
print(id(a))

并不是在1的空间删除填上2,而是新开辟了空间。

a=[1]
print(id(a[0]))
a[0]=1
print(id(a[0]))

这个是Inplace操作。

embedding=nn.Parameter(torch.rand(2,3))
d=nn.Parameter(torch.rand(3,3))
user_embeddings=embedding.clone()
user_embedding_input = user_embeddings[0] 
a=user_embedding_input*3#option1
print(a)
a=torch.matmul(d,user_embedding_input)#option2
print(a)
user_embeddings[0]=a
loss=a.sum()
loss.backward()#是否报错?

报错。

这里涉及一个概念,你直接[0]这样索引,这种属于selectbackward。不会创建新的内存空间,类似的还有slicebackward(例如b[:2,:1]),其也不会创建新的内存空间。然后在后面又进行了赋值,这样,在计算d的梯度的时候显然会报错。

embedding=nn.Parameter(torch.rand(2,3))
d=nn.Parameter(torch.rand(3,3))
user_embeddings=embedding.clone()
user_embedding_input = user_embeddings[[0],:] 
# a=user_embedding_input*3#option1
# print(a)
a=torch.matmul(user_embedding_input,d)#option2
print(a)
user_embeddings[[0],:]=a
loss=a.sum()
loss.backward()#是否报错?

不报错,上面的索引是indexbackward,这个相当于创建了一个新的变量,然后index操作,梯度回传即可。虽然后面user_embeddings改了,但是那个属于中间节点,把user_embedding_input的梯度传过来即可,然后再传给前面的embedding,可以发现,user_embeddings改不改都没有关系。这并不会导致什么错误,而且反向传播之后会清空中间节点的梯度。(补充:indexbackward取出的时候会创建新变量,并和原来脱离关系,但是如果是要更新vv,则三种索引都会改变vv。vv[select]=1,vv[slice]=1,vv[index]=1这三者都会改变vv。这可能是pytorch出于方便考虑的。总之index和前两者只有在取出来的时候会不一样。)

embedding=nn.Parameter(torch.rand(2,3))
d=nn.Parameter(torch.rand(3,3))
user_embeddings=embedding.clone()
user_embedding_input = user_embeddings[[0],:] 
a=torch.matmul(user_embedding_input,d)
print(a)
user_embeddings[[0],:]=a
user_embedding_input=3#question line
loss=a.sum()
loss.backward()#是否报错?

不报错。

这里有人有疑问了,为什么user_embedding_input改了还是不报错,这是因为计算梯度有缓存,而且这个改也不是Inplace的,Pytorch已经缓存了那个原来的空间,所以不报错。

user_embedding_input[0]=2

如果你这么操作,那么就会报错了。

另外一个知识点,中间节点的赋值会连带上之前的计算图。

a=nn.Parameter(torch.tensor([[2.]]))#叶子节点。
b=a.clone()#中间节点。
print(b)
d=nn.Parameter(torch.tensor(3.))
print(d)
e=b[[0],:]*d
b[0]=e#赋值,会带上e的历史。而不仅仅是一个数据。
print(e)
loss1=e.sum()
e=b[[0],:]*d
b[0]=e
print(e)
loss1+=e.sum()
loss1.backward()
d.grad#14
a.grad#12

上面这样其实有点类似于RNN了,这个你能否计算对呢?

在这里插入图片描述
补充,上面画得有点不对,那个b节点也应该分裂成两个。我们假设b分成b1,b2,然后下面那两个b1,b2改名叫做bb1,bb2。

这里的难点在于,b1被Inplace了,讲道理b1是从a那里复制过来的,所以是cloneback。然后bb1是从b1那里index过来的。那b1的grad是怎么记录的呢?因为b1已经不存在了,缓存也失效了,因为直接被Inplace了。既然无法记录,如果继续反向传播到a节点呢?

补充

  1. 一个tensor不可导,对其部分进行赋值inplace一个可导的,整个tensor都会变得可导,也就是说pytorch里面计算梯度是以对象为单位的。
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-12 17:30:27  更:2022-03-12 17:33:21 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/9 15:46:57-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码