IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> pytorch 梯度相关知识点 -> 正文阅读

[人工智能]pytorch 梯度相关知识点

1. requires_grad

如果需要为张量计算所需的梯度,那么我们就需要对张量设置requires_grad=True;张量创建的时候默认requires_grad=False

  • 如果不设置requires_grad=True,后续计算梯度的时候就会报错
    (1)requires_grad=False&默认设置
import torch
from torch import nn

# 创建一个输入x,默认设置
x = torch.ones(5)
# y = 2*x**2
y = 2*torch.dot(x,x)
# y 进行梯度返传
y.backward()
# 打印x的梯度,即x.grad
print(f"x.grad={x.grad}")
  • 结果
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

(2)requires_grad=False

import torch
from torch import nn

x_false = torch.ones(5, requires_grad=False)
y_false = 2 * torch.dot(x_false, x_false)
y_false.backward()
print(f"x_false.grad={x_false.grad}")
  • 结果
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

(3)requires_grad=True

import torch
from torch import nn


x_true = torch.ones(5, requires_grad=True)
y_true = 2 * torch.dot(x_true,x_true)
y_true.backward()
print(f"x_true.grad={x_true.grad}")
print(f"4*x_true={4*x_true}")
print(f"x_true.grad==4*x_true={x_true.grad==4*x_true}")
  • 结果:
x_true.grad=tensor([4., 4., 4., 4., 4.])
4*x_true=tensor([4., 4., 4., 4., 4.], grad_fn=<MulBackward0>)
x_true.grad==4*x_true=tensor([True, True, True, True, True])

2. grad_fn,grad

grad:表示当执行完y.backward()后,可以通过x.grad计算x变量的梯度
grad_fn是用来记录变量是怎么来的,记录图节点的方式,为了后续反向传播做准备
z = 2 ? x 2 + 6 z=2*x^2+6 z=2?x2+6
由上述公式可得:

  • x:最底层的生物,牛马如我;故x.grad_fn=None
  • y = 2 ? x 2 y=2*x^2 y=2?x2:来源于乘法,故y.grad_fn = MulBackward
  • z = y + 6 z=y+6 z=y+6:来源于加法,故z.grad_fn = AddBackward
x_true = torch.ones(5, requires_grad=True)
y_true = 2 * torch.dot(x_true, x_true)
z_true = y_true + 6
z_true.backward()
print(f"x_true.grad={x_true.grad}")
print(f"x_true.grad_fn={x_true.grad_fn}")
print(f"y_true.grad_fn={y_true.grad_fn}")
print(f"z_true.grad_fn={z_true.grad_fn}")

结果:

x_true.grad=tensor([4., 4., 4., 4., 4.])
x_true.grad_fn=None
y_true.grad_fn=<MulBackward0 object at 0x00000180E0DF3550>
z_true.grad_fn=<AddBackward0 object at 0x00000180E0DF3550>

3. with torch.no_grad()

torch.no_grad
禁用梯度计算的上下文管理器;当您确定不会调用Tensor.backward()时,禁用梯度计算对推断很有用。它将减少原本需要requires_grad=True的计算的内存消耗
有两种方式设置:

  • with torch.no_grad()
  • @torch.no_grad()
# 定义一个张量x
x = torch.tensor([1.0], requires_grad=True)

# 用with torch.no_grad()防止计算梯度
with torch.no_grad():
	y = x * 2

print(f"y.requires_grad={y.requires_grad}")
#输出结果: y.requires_grad=False
# 用@torch.no_grad()

@torch.no_grad()
def doubler(x):
	return x * 2


z = doubler(x)
print(f"z.requires_grad={z.requires_grad}")
# z.requires_grad=False

4. tensor.detach()

返回一个新的张量,从当前图分离出来。

import torch
from torch import nn

x = torch.ones(5)
# destination
y = torch.zeros(3)
w = torch.randn((5, 3), requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z,y)

print(f"z.grad_fn={z.grad_fn}")
print(f"loss.grad_fn={loss.grad_fn}")
loss.backward()
print(f"w.grad={w.grad}")
print(f"b.grad={b.grad}")
print(f"z.requires_grad={z.requires_grad}")
# with torch.no_grad():

# 新建一个张量z_dect
z_dect = z.detach()
print(f"after_detach:z.requires_grad={z_dect.requires_grad}")

5. 小结

有如下原因需要禁用梯度:
(1)将神经网络中的一些参数标记为冻结参数。这是对预先训练的网络进行微调的一个非常常见的场景;
(2)在只进行正向传递的情况下加快计算速度,因为在不跟踪梯度的张量上的计算将更加有效

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-15 22:31:54  更:2022-03-15 22:35:31 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/9 15:28:22-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码