IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Pytorch随笔--Tensor与Variable -> 正文阅读

[人工智能]Pytorch随笔--Tensor与Variable

前言

该文章是我回过头来看之前的pytorch笔记的一些记录,以自己的思路将一些相关的内容集合起来进行记录,算是对之前阶段学习的总结。

问题描述

刚接触pytorch的时候,一个很常见的操作就是参考别人的代码或者观看别人的教学视频。但随着pytorch的更新,一些特性便会发生变化,这就很明显地会体现在代码上,比如之前的一些代码和现在的代码将会变得不一样,尽管它们实现的功能基本是一样的。尤其是对于训练模型这一部分来说,基本的流程会很相似,但几年前的代码和当前的代码却会有很多地方不同,这对于初学者来说会是一个令人疑惑的点,至少对我来说有过这个困扰。所以下面就是记录一些经过pytorch的更新,而导致不同时间段内的代码不一样的情况,同时也对比分析指出目前应该使用的形式。

Tensor和Variable

Tensor是pytorch中非常重要且常见的数据结构,相较于numpy数组,Tensor能加载到GPU中,从而有效地利用GPU进行加速计算。但是普通的Tensor对于构建神经网络还远远不够,我们需要能够构建计算图的 tensor,这就是 Variable。Variable 是对 tensor 的封装,操作和 tensor 是一样的,但是每个 Variabel都有三个属性,分别是data表示变量中的具体值, grad表示这个变量反向传播的梯度, grad_fn表示是通过什么操作得到这个变量,例如( 加减乘除、卷积、反置卷积)。所以在构建计算图时,需要将tensor封装成Variable变量,这样才能在后续阶段进行反向传播计算梯度。很常见的代码示例如下:

import torch
from torch.autograd import Variable

x = Variable(torch.randn(10, 20), requires_grad=True)
y = Variable(torch.randn(10, 5), requires_grad=True)
w = Variable(torch.randn(20, 5), requires_grad=True)

out = torch.mean(y - torch.matmul(x, w))
out.backward()

注意,以上是pytorch 0.4之前的代码风格。但是在pytorch 0.4版本后,torch.Tensor 和torch.autograd.Variable变成同一个类了。torch.Tensor 能够像之前的Variable一样追踪历史和反向传播,也即不用使用torch.autograd.Variable来封装tensor以使其能够进行反向传播。但是Variable仍能够正常工作,只是返回的依旧是Tensor。目前常用的代码形式如下所示:

import torch

x =torch.randn((10, 20), requires_grad=True)
y = torch.randn((10, 5), requires_grad=True)
w = torch.randn((20, 5), requires_grad=True)

out = torch.mean(y - torch.matmul(x, w))
out.backward()

也就是说使用Variable封装tensor是老版本的使用方式了,虽然不会报错,但属实是可以但没必要。

tensor.data和tensor.detach

在使用pytorch训练网络的过程中,遇到要将GPU上的tensor变量转换成CPU上的numpy数组这种情况是很常见的,比较常见的作法就是使用a.data.cpu().numpy()(具体细节可以参考这里,这不是本文重点),但是这就牵扯出一个安全性问题,使用.data真的安全吗?

首先,我们需要先知道这三个方法的意义:

  • tensor .data?返回和 x 的相同数据 tensor,而且这个新的tensor和原来的tensor是共用数据的,一者改变,另一者也会跟着改变,而且新分离得到的tensor的require s_grad = False, 即不可求导的。
  • tensor .detach() 返回和 x 的相同数据 tensor,而且这个新的tensor和原来的tensor是共用数据的,一者改变,另一者也会跟着改变,而且新分离得到的tensor的require s_grad = False, 即不可求导的。(这一点其实和?tensor.data 是一样的)

由上面的描述会感觉这两者很相似,也没什么不同。但其实这两者会有一个不同的机制,即当y=x.data属性来访问数据时,pytorch不会记录数据是否改变,此时改变了y的值,意味着也要改变x的值,而在自动求导时会使用更改后的值,这会导致错误求导结果。而使用y=x.detach()时,如果了y值,也意味着改变了x值,此时调用x.backword()会报错,这也从侧面表明.detach()方法会记录数据的变化状态。也就是说tensor.data和tensor.detach都是复制计算图中的一份数据,此数据与原数据共享内存,当复制下来的数据因为某些操作变化时,就会影响到计算图梯度的计算,此时tensor.detach会记录下数据的变化状态,当数据有变化时会直接报错已防止计算图梯度被影响,而tensor.data则存在这个隐患,所以使用tensor.data是不安全的,会比较推荐tensor.detach操作。

具体代码示例如下,使用tensor.data分离计算图时,如果复制的值变化了,则梯度的计算会出错,但程序并不会知晓,依旧将其视为正确结果。而使用tensor.detach时,复制的值发生变化就会报错,这样就能避免这种难以发现的错误。

a = torch.tensor([1,2,3.], requires_grad = True)
out = a.sigmoid()

c = out.data
c.zero_()
out.sum().backward()
a.grad  # The result is very, very wrong because `out` changed!
# out: tensor([ 0.,  0.,  0.])

c = out.detach()
c.zero_()
out.sum().backward()  # error
# out: RuntimeError: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

loss.data[0]和loss.item()

在pytorch 0.4之前,广泛的loss使用是:total_loss += loss.data[0], 这是因为loss是variable,其中size为(1,),而在0.4之后,loss是一个0维的向量(标量)。此时就可以可以使用loss.item()从标量中获得Python number,也即获取的是python的基本数据类型。

注意,当不转换为Python number来计算loss的累加值时,程序将会有大量的内存使用,这是因为total_loss += loss.data[0]式子的右边是0维tensor。这样,总损失就会包含loss和梯度历史,这些梯度历史会在大的autograd graph中保存更长时间,带来大量的内存消耗。换句话说,pytorch使用Variable计算的时候,会记录下新产生的Variable的运算符号,在反向传播求导的时候进行使用。如果这里使用loss.data[0]相加,系统会认为这也是计算图的一部分,也就是说网络会一直延伸变大,那么消耗的显存也就越来越多。而使用loss.item()则不会遇到这种问题,所以推荐使用loss.item()来获取每个step的loss。(注意,loss.item()只能针对只有一个元素的tensor使用,如何结果是一个tensor列表,可以使用loss.tolist(),这样获取的依然是Python number)

上述loss积累的式子可以变为:
total_loss += loss.item()

总结

经过上面的讨论,可以知道经过pytorch的不断更新,代码的形式相应的也会有一些变化,总结来看,可以归纳为三点:

  • 不再需要将tensor包装进Variable中,torch.Tensor 和torch.autograd.Variable是同一个类了
  • 推荐使用tensor.detach(),因为其比tensor.data更安全
  • 推荐使用loss.item()获取loss,应为使用loss.data[0]会有显存消耗越来越大的隐患

参考链接

pytorch入门一(Tensor、Variable)_朴素.无恙的博客-CSDN博客

pytorch入门之第一章Variable的理解_YYLin-CSDN博客_pytorch variable

pytorch中的.detach和.data深入详解_MIss-Y的博客-CSDN博客

pytorch学习:loss为什么要加item()_github_38148039的博客-CSDN博客

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-23 11:28:02  更:2021-09-23 11:29:34 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/27 12:29:38-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码