IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 梯度剪裁: torch.nn.utils.clip_grad_norm_() -> 正文阅读

[人工智能]梯度剪裁: torch.nn.utils.clip_grad_norm_()


前言

当神经网络深度逐渐增加,网络参数量增多的时候,反向传播过程中链式法则里的梯度连乘项数便会增多,更易引起梯度消失和梯度爆炸。对于梯度爆炸问题,解决方法之一便是进行梯度剪裁,即设置一个梯度大小的上限。本文介绍了pytorch中梯度剪裁方法的原理和使用方法。


一、原理

注:为了防止混淆,本文对神经网络中的参数称为“网络参数”,其他程序相关参数成为“参数”。

pytorch中梯度剪裁方法为 torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2)1。三个参数:

parameters:希望实施梯度裁剪的可迭代网络参数
max_norm:该组网络参数梯度的范数上限
norm_type:范数类型

官方对该方法的描述为:

"Clips gradient norm of an iterable of parameters. The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place."

“对一组可迭代(网络)参数的梯度范数进行裁剪。效果如同将所有参数连接成单个向量来计算范数。梯度原位修改。”

我们来逐段分析其实现代码:

def clip_grad_norm_(parameters, max_norm, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    max_norm = float(max_norm)
    norm_type = float(norm_type)

该部分处理了传入的三个参数。首先将parameters中的非空网络参数存入一个列表,然后将max_normnorm_type类型强制为浮点数。

    if norm_type == inf:
        total_norm = max(p.grad.data.abs().max() for p in parameters)

该句对无穷范数进行了单独计算,即取所有网络参数梯度范数中的最大值,定义为total_norm
t o t a l _ n o r m ∞ = max ? p i ∈ P ∣ g r a d ( p i ) ∣ {total\_norm}^{\infty}=\max_{pi\in {P}}|grad(p_i)| total_norm=piPmax?grad(pi?)

    else:
        total_norm = 0
        for p in parameters:
            param_norm = p.grad.data.norm(norm_type)
            total_norm += param_norm.item() ** norm_type
        total_norm = total_norm ** (1. / norm_type)

对于其他范数,我们计算所有网络参数梯度范数之和,再归一化,即等价于把所有网络参数放入一个向量,再对向量计算范数。将结果定义为total_norm
t o t a l _ n o r m n o r m _ t y p e = { ∑ p i ∈ P [ g r a d ( p i ) ] n o r m _ t y p e } 1 n o r m _ t y p e {total\_norm}^{norm\_type}=\{\sum_{pi\in {P}}[grad(p_i)]^{norm\_type}\}^{\frac{1}{norm\_type}} total_normnorm_type={piP?[grad(pi?)]norm_type}norm_type1?

    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.data.mul_(clip_coef)
    return total_norm

最后定义了一个“裁剪系数”变量clip_coef,为传入参数max_normtotal_norm的比值(+1e-6防止分母为0的情况)。如果max_norm > total_norm,即没有溢出预设上限,则不对梯度进行修改。反之则以clip_coef为系数对全部梯度进行惩罚,使最后的全部梯度范数归一化至max_norm的值。注意该方法返回了一个 total_norm,实际应用时可以通过该方法得到网络参数梯度的范数,以便确定合理的max_norm值。
t o t a l _ n o r m ′ = { ∑ p i ∈ P [ g r a d ( p i ) c l i p _ c o e f ] n o r m _ t y p e } 1 n o r m _ t y p e = { ∑ p i ∈ P [ g r a d ( p i ) m a x _ n o r m t o t a l _ n o r m ] n o r m _ t y p e } 1 n o r m _ t y p e = { ∑ p i ∈ P [ g r a d ( p i ) ] n o r m _ t y p e } 1 n o r m _ t y p e t o t a l _ n o r m ? m a x _ n o r m = m a x _ n o r m \begin{aligned} {total\_norm}'&=\{\sum_{pi\in {P}}[\frac{grad(p_i)}{clip\_coef}]^{norm\_type}\}^{\frac{1}{norm\_type}} \\&=\{\sum_{pi\in {P}}[\frac{grad(p_i)}{\frac{max\_norm}{total\_norm}}]^{norm\_type}\}^{\frac{1}{norm\_type}} \\&=\frac{\{\sum_{pi\in {P}}[grad(p_i)]^{norm\_type}\}^{\frac{1}{norm\_type}}}{total\_norm} \cdot max\_norm \\&=max\_norm \end{aligned} total_norm?={piP?[clip_coefgrad(pi?)?]norm_type}norm_type1?={piP?[total_normmax_norm?grad(pi?)?]norm_type}norm_type1?=total_norm{piP?[grad(pi?)]norm_type}norm_type1???max_norm=max_norm?

二、使用方法

每一次迭代中,梯度处理的过程应该是:

计算梯度
裁剪梯度
更新网络参数

因此 torch.nn.utils.clip_grad_norm_() 的使用应该在loss.backward() 之后,**optimizer.step()**之前:

...
loss = crit(...)

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=10, norm_type=2)
optimizer.step()
...

总结

本文从实现代码角度分析了pytorch中梯度裁剪方法 torch.nn.utils.clip_grad_norm_() 的原理和使用方法。


  1. 旧版为torch.nn.utils.clip_grad_norm() ??

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-09 10:14:30  更:2021-08-09 10:14:52 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/12 1:54:04-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码