IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 1014长短期记忆网络(LSTM) -> 正文阅读

[人工智能]1014长短期记忆网络(LSTM)

长短期记忆网络(LSTM)

  • 长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题,解决这个问题最早的方法之一就是 LSTM
  • 发明于90年代
  • 使用的效果和 GRU 相差不大,但是使用的东西更加复杂

  • 长短期记忆网络的设计灵感来自于计算机的逻辑门
  • 长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)(有些文献认为记忆元是隐状态的一种特殊类型,它们与隐状态具有相同的形状,其设计目的是用于记录附加的信息)
  • 长短期记忆网络有三个门:忘记门(重置单元的内容,通过专用机制决定什么时候记忆或忽略隐状态中的输入)、输入门(决定何时将数据读入单元)、输出门(从单元中输出条目),门的计算和 GRU 中相同,但是命名不同
  • 忘记门(forget gate):将值朝 0 减少
  • 输入门(input gate):决定是否忽略掉输入数据
  • 输出门(output gate):决定是否使用隐状态

  • 类似于门控循环单元,当前时间步的输入前一个时间步的隐状态作为数据送入长短期记忆网络的门中,由三个具有 sigmoid 激活函数的全连接层处理,以计算输入门、遗忘门和输出门的值(这三个门的值都在 0~1 的范围内)

候选记忆单元(candidate memory cell)

  • 候选记忆元的计算与输入门、遗忘门、输出门的计算类似,但是使用了 tanh 函数作为激活函数,函数的值在 -1~1 之间

记忆单元

  • 在长短期记忆网络中,通过输入门遗忘门来控制输入和遗忘(或跳过):输入门 It 控制采用多少来自 Ct tilde 的新数据,而遗忘门 Ft 控制保留多少过去的记忆元 C(t-1) 的内容
  • 如果遗忘门始终为 1 且输入门始终为 0 ,则过去的记忆元 C(t-1) 将随时间被保存并传递到当前时间步(引入这种设计是为了缓解梯度消失的问题,并更好地捕获序列中的长距离依赖关系)
  • 上一时刻的记忆单元会作为状态输入到模型中
  • LSTM 和 RNN/GRU 的不同之处在于: LSTM 中的状态有两个, C 和 H

隐状态

  • 在长短期记忆网络中,隐状态 Ht 仅仅是记忆元 Ct 的 tanh 的门控版本,因此确保了 Ht 的值始终在 -1~1 之间
  • tanh 的作用:将 Ct 的值限制在 -1 和 1 之间
  • Ot 控制是否输出, Ot 接近 1 ,则能有效地将所有记忆信息传递给预测部分; Ot 接近 0 ,表示丢弃当前的 Xt 和过去所有的信息,只保留记忆元内的所有信息,而不需要更新隐状态

总结

1 LSTM 和 GRU 所想要实现的效果是差不多的,但是结构更加复杂

  • C :一个数值可能比较大的辅助记忆单元
  • C 中包含两项: 当前的 Xt过去的状态(在 GRU 中只能二选一,这里可以实现两个都选)

2 长短期记忆网络包含三种类型的门:输入门遗忘门输出门

3 长短期记忆网络的隐藏层输出包括“隐状态”“记忆元”。只有隐状态会传递到输出层,而记忆元完全属于内部信息

4 长短期记忆网络可以缓解梯度消失和梯度爆炸

5 长短期记忆网络是典型的具有重要状态控制的隐变量自回归模型。多年来已经提出了其他许多变体,例如,多层、残差连接、不同类型的正则化。但是由于序列的长距离依赖性,训练长短期记忆网络和其他序列模型(如门控循环单元)的成本较高


代码:

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size
    
    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01
    
    def three():
        return (normal(
            (num_inputs, num_hiddens)), normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))
    
    W_xi, W_hi, b_i = three()
    W_xf, W_hf, b_f = three()
    W_xo, W_ho, b_o = three()
    W_xc, W_hc, b_c = three()
    W_hq = normal((num_hiddens, num_outputs))
    b_q  = torch.zeros(num_outputs, device=device)
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q]    
    for param in params:
        param.requires_grad_(True)
    return params

# 初始化函数
def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
           torch.zeros((batch_size, num_hiddens),device=device))

# 实际模型
def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)

# 训练
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params, init_lstm_state, lstm)       
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

# 简洁实现
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
mode = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

?

# 简洁实现
num_inputs = vocab_size
lstm_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
mode = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

?

# 简洁实现
num_inputs = vocab_size
lstm_layer = nn.RNN(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
mode = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-10-17 12:33:42  更:2022-10-17 12:35:47 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/28 2:37:06-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计