| |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
-> 人工智能 -> Pytorch随笔--Tensor与Variable -> 正文阅读 |
|
[人工智能]Pytorch随笔--Tensor与Variable |
前言该文章是我回过头来看之前的pytorch笔记的一些记录,以自己的思路将一些相关的内容集合起来进行记录,算是对之前阶段学习的总结。 问题描述刚接触pytorch的时候,一个很常见的操作就是参考别人的代码或者观看别人的教学视频。但随着pytorch的更新,一些特性便会发生变化,这就很明显地会体现在代码上,比如之前的一些代码和现在的代码将会变得不一样,尽管它们实现的功能基本是一样的。尤其是对于训练模型这一部分来说,基本的流程会很相似,但几年前的代码和当前的代码却会有很多地方不同,这对于初学者来说会是一个令人疑惑的点,至少对我来说有过这个困扰。所以下面就是记录一些经过pytorch的更新,而导致不同时间段内的代码不一样的情况,同时也对比分析指出目前应该使用的形式。 Tensor和VariableTensor是pytorch中非常重要且常见的数据结构,相较于numpy数组,Tensor能加载到GPU中,从而有效地利用GPU进行加速计算。但是普通的Tensor对于构建神经网络还远远不够,我们需要能够构建计算图的 tensor,这就是 Variable。Variable 是对 tensor 的封装,操作和 tensor 是一样的,但是每个 Variabel都有三个属性,分别是data表示变量中的具体值, grad表示这个变量反向传播的梯度, grad_fn表示是通过什么操作得到这个变量,例如( 加减乘除、卷积、反置卷积)。所以在构建计算图时,需要将tensor封装成Variable变量,这样才能在后续阶段进行反向传播计算梯度。很常见的代码示例如下:
注意,以上是pytorch 0.4之前的代码风格。但是在pytorch 0.4版本后,torch.Tensor 和torch.autograd.Variable变成同一个类了。torch.Tensor 能够像之前的Variable一样追踪历史和反向传播,也即不用使用torch.autograd.Variable来封装tensor以使其能够进行反向传播。但是Variable仍能够正常工作,只是返回的依旧是Tensor。目前常用的代码形式如下所示:
也就是说使用Variable封装tensor是老版本的使用方式了,虽然不会报错,但属实是可以但没必要。 tensor.data和tensor.detach在使用pytorch训练网络的过程中,遇到要将GPU上的tensor变量转换成CPU上的numpy数组这种情况是很常见的,比较常见的作法就是使用a.data.cpu().numpy()(具体细节可以参考这里,这不是本文重点),但是这就牵扯出一个安全性问题,使用.data真的安全吗? 首先,我们需要先知道这三个方法的意义:
由上面的描述会感觉这两者很相似,也没什么不同。但其实这两者会有一个不同的机制,即当y=x.data属性来访问数据时,pytorch不会记录数据是否改变,此时改变了y的值,意味着也要改变x的值,而在自动求导时会使用更改后的值,这会导致错误求导结果。而使用y=x.detach()时,如果了y值,也意味着改变了x值,此时调用x.backword()会报错,这也从侧面表明.detach()方法会记录数据的变化状态。也就是说tensor.data和tensor.detach都是复制计算图中的一份数据,此数据与原数据共享内存,当复制下来的数据因为某些操作变化时,就会影响到计算图梯度的计算,此时tensor.detach会记录下数据的变化状态,当数据有变化时会直接报错已防止计算图梯度被影响,而tensor.data则存在这个隐患,所以使用tensor.data是不安全的,会比较推荐tensor.detach操作。 具体代码示例如下,使用tensor.data分离计算图时,如果复制的值变化了,则梯度的计算会出错,但程序并不会知晓,依旧将其视为正确结果。而使用tensor.detach时,复制的值发生变化就会报错,这样就能避免这种难以发现的错误。
loss.data[0]和loss.item()在pytorch 0.4之前,广泛的loss使用是: 注意,当不转换为Python number来计算loss的累加值时,程序将会有大量的内存使用,这是因为 上述loss积累的式子可以变为: 总结经过上面的讨论,可以知道经过pytorch的不断更新,代码的形式相应的也会有一些变化,总结来看,可以归纳为三点:
参考链接pytorch入门一(Tensor、Variable)_朴素.无恙的博客-CSDN博客 pytorch入门之第一章Variable的理解_YYLin-CSDN博客_pytorch variable |
|
|
上一篇文章 下一篇文章 查看所有文章 |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 | -2024/11/27 12:29:38- |
|
网站联系: qq:121756557 email:121756557@qq.com IT数码 |