IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【手撕LSTM】LSTM的numpy实现 -> 正文阅读

[人工智能]【手撕LSTM】LSTM的numpy实现

在这里插入图片描述

详细理论部分参考我博文(2020李宏毅)机器学习-Recurrent Neural Network

LSTM原理图

在这里插入图片描述

公式
F t = σ ( X t W x f + H t ? 1 W h f + b f ) \mathbf{F}_{t}=\sigma\left(\mathbf{X}_{t} \mathbf{W}_{x f}+\mathbf{H}_{t-1} \mathbf{W}_{h f}+\mathbf{b}_{f}\right) Ft?=σ(Xt?Wxf?+Ht?1?Whf?+bf?)
I t = σ ( X t W x i + H t ? 1 W h i + b i ) \mathbf{I}_{t}=\sigma\left(\mathbf{X}_{t} \mathbf{W}_{x i}+\mathbf{H}_{t-1} \mathbf{W}_{h i}+\mathbf{b}_{i}\right) It?=σ(Xt?Wxi?+Ht?1?Whi?+bi?)
C ~ t = tanh ? ( X t W x c + H t ? 1 W h c + b c ) \tilde{\mathbf{C}}_{t}=\tanh \left(\mathbf{X}_{t} \mathbf{W}_{x c}+\mathbf{H}_{t-1} \mathbf{W}_{h c}+\mathbf{b}_{c}\right) C~t?=tanh(Xt?Wxc?+Ht?1?Whc?+bc?)
C t = F t ⊙ C t ? 1 + I t ⊙ C ~ t \mathbf{C}_{t}=\mathbf{F}_{t} \odot \mathbf{C}_{t-1}+\mathbf{I}_{t} \odot \tilde{\mathbf{C}}_{t} Ct?=Ft?Ct?1?+It?C~t?
O t = σ ( X t W x o + H t ? 1 W h o + b o ) \mathbf{O}_{t}=\sigma\left(\mathbf{X}_{t} \mathbf{W}_{x o}+\mathbf{H}_{t-1} \mathbf{W}_{h o}+\mathbf{b}_{o}\right) Ot?=σ(Xt?Wxo?+Ht?1?Who?+bo?)
H t = O t ⊙ tanh ? ( C t ) \mathbf{H}_{t}=\mathbf{O}_{t} \odot \tanh \left(\mathbf{C}_{t}\right) Ht?=Ot?tanh(Ct?)

便于程序实现的公式(简化版公式)

F t = σ ( W f [ H t ? 1 , X t ] + b f ) \mathbf{F}_{t}=\sigma\left(\mathbf{W}_{f}[\mathbf{H}_{t-1},\mathbf{X}_{t}] + \mathbf{b}_{f}\right) Ft?=σ(Wf?[Ht?1?,Xt?]+bf?)
I t = σ ( W i [ H t ? 1 , X t ] + b i ) \mathbf{I}_{t}=\sigma\left(\mathbf{W}_{i}[\mathbf{H}_{t-1},\mathbf{X}_{t}]+\mathbf{b}_{i}\right) It?=σ(Wi?[Ht?1?,Xt?]+bi?)
C ~ t = tanh ? ( W c [ H t ? 1 , X t ] + b c ) \tilde{\mathbf{C}}_{t}=\tanh \left(\mathbf{W}_{c}[\mathbf{H}_{t-1},\mathbf{X}_{t}] +\mathbf{b}_{c}\right) C~t?=tanh(Wc?[Ht?1?,Xt?]+bc?)
C t = F t ⊙ C t ? 1 + I t ⊙ C ~ t \mathbf{C}_{t}=\mathbf{F}_{t} \odot \mathbf{C}_{t-1}+\mathbf{I}_{t} \odot \tilde{\mathbf{C}}_{t} Ct?=Ft?Ct?1?+It?C~t?
O t = σ ( W o [ H t ? 1 , X t ] + b o ) \mathbf{O}_{t}=\sigma\left(\mathbf{W}_{o}[\mathbf{H}_{t-1},\mathbf{X}_{t}]+\mathbf{b}_{o}\right) Ot?=σ(Wo?[Ht?1?,Xt?]+bo?)
H t = O t ⊙ tanh ? ( C t ) \mathbf{H}_{t}=\mathbf{O}_{t} \odot \tanh \left(\mathbf{C}_{t}\right) Ht?=Ot?tanh(Ct?)

关于“门”

遗忘门

在LSTM中,遗忘门可以实现操作:
F t = σ ( W f [ H t ? 1 , X t ] + b f ) \mathbf{F}_{t}=\sigma\left(\mathbf{W}_{f}[\mathbf{H}_{t-1},\mathbf{X}_{t}] + \mathbf{b}_{f}\right) Ft?=σ(Wf?[Ht?1?,Xt?]+bf?)
在这里, W f W_f Wf?是控制遗忘门行为的权重。将 [ H t ? 1 , X t ] [\mathbf{H}_{t-1},\mathbf{X}_{t}] [Ht?1?,Xt?]连接起来,然后乘以 W f W_f Wf?。上面的等式使得向量 F t \mathbf{F}_{t} Ft?的值介于0到1之间。该遗忘门向量将逐元素乘以先前的单元状态 C t ? 1 \mathbf{C}_{t-1} Ct?1?。因此,如果 F t \mathbf{F}_{t} Ft?的其中一个值为0(或接近于0),则表示LSTM应该移除 C t ? 1 \mathbf{C}_{t-1} Ct?1?中的一部分信息,如果其中一个值为1,则它将保留信息。

输入门

输入门的公式:
I t = σ ( W i [ H t ? 1 , X t ] + b i ) \mathbf{I}_{t}=\sigma\left(\mathbf{W}_{i}[\mathbf{H}_{t-1},\mathbf{X}_{t}]+\mathbf{b}_{i}\right) It?=σ(Wi?[Ht?1?,Xt?]+bi?)
类似于遗忘门,在这里 I t \mathbf{I}_{t} It?也是值为0到1之间的向量。这将与 C ~ t \tilde{\mathbf{C}}_{t} C~t?逐元素相乘以计算 C t \mathbf{C}_{t} Ct?

更新memory

新的输入向量:
C ~ t = tanh ? ( W c [ H t ? 1 , X t ] + b c ) \tilde{\mathbf{C}}_{t}=\tanh \left(\mathbf{W}_{c}[\mathbf{H}_{t-1},\mathbf{X}_{t}] +\mathbf{b}_{c}\right) C~t?=tanh(Wc?[Ht?1?,Xt?]+bc?)
最后,新的memory状态为:
C t = F t ⊙ C t ? 1 + I t ⊙ C ~ t \mathbf{C}_{t}=\mathbf{F}_{t} \odot \mathbf{C}_{t-1}+\mathbf{I}_{t} \odot \tilde{\mathbf{C}}_{t} Ct?=Ft?Ct?1?+It?C~t?

输出门

为了确定接下来将使用哪些输出,使用以下两个公式:

O t = σ ( W o [ H t ? 1 , X t ] + b o ) \mathbf{O}_{t}=\sigma\left(\mathbf{W}_{o}[\mathbf{H}_{t-1},\mathbf{X}_{t}]+\mathbf{b}_{o}\right) Ot?=σ(Wo?[Ht?1?,Xt?]+bo?)
H t = O t ⊙ tanh ? ( C t ) \mathbf{H}_{t}=\mathbf{O}_{t} \odot \tanh \left(\mathbf{C}_{t}\right) Ht?=Ot?tanh(Ct?)

LSTM单元

实现上图中描述的LSTM单元。

说明

  1. H t ? 1 \mathbf{H}_{t-1} Ht?1? X t \mathbf{X}_{t} Xt?连接在一个矩阵中: c o n c a t = [ H t ? 1 X t ] concat = \begin{bmatrix} \mathbf{H}_{t-1} \\ \mathbf{X}_{t}\end{bmatrix} concat=[Ht?1?Xt??]
  2. 计算以上公式,使用sigmoid()np.tanh()
  3. 计算预测 y ? t ? y^{\langle t \rangle} y?t?,使用softmax()
  4. 预测 y ^ \hat y y^?公式为 y ^ = s o f t m a x ( W y H t + b y ) \hat y=softmax(W_yH_t+b_y) y^?=softmax(Wy?Ht?+by?)
import numpy as np
def sigmoid(x):
    return 1/(1+np.exp(-x))

def softmax(x):
    e_x = np.exp(x-np.max(x))# 防溢出
    return e_x/e_x.sum(axis=0)
def LSTM_CELL_Forward(xt,h_prev,C_prev,parameters):
    """
    Arguments:
    xt:时间步“t”处输入的数据 shape(n_x,m)
    h_prev:时间步“t-1”的隐藏状态 shape(n_h,m)
    C_prev:时间步“t-1”的memory状态 shape(n_h,m)
    parameters
        Wf 遗忘门的权重矩阵 shape(n_h,n_h+n_x)
        bf 遗忘门的偏置 shape(n_h,1)
        Wi 输入门的权重矩阵 shape(n_h,n_h+n_x)
        bi 输入门的偏置 shape(n_h,1)
        Wc 第一个“tanh”的权重矩阵 shape(n_h,n_h+n_x)
        bc 第一个“tanh”的偏差 shape(n_h,1)
        Wo 输出门的权重矩阵 shape(n_h,n_h+n_x)
        bo 输出门的偏置 shape(n_h,1)
        Wy 将隐藏状态与输出关联的权重矩阵 shape(n_y,n_h)
        by 隐藏状态与输出相关的偏置 shape(n_y,1)
    Returns:
    h_next -- 下一个隐藏状态 shape(n_h,m)
    c_next -- 下一个memory状态 shape(n_h,m)
    yt_pred -- 时间步长“t”的预测 shape(n_y,m)
    """
    # 获取参数字典中各个参数
    Wf = parameters["Wf"]
    bf = parameters["bf"]
    Wi = parameters["Wi"]
    bi = parameters["bi"]
    Wc = parameters["Wc"]
    bc = parameters["bc"]
    Wo = parameters["Wo"]
    bo = parameters["bo"]
    Wy = parameters["Wy"]
    by = parameters["by"]
    
    # 获取 xt 和 Wy 的维度参数
    n_x, m = xt.shape
    n_y, n_h = Wy.shape
    
    #拼接 h_prev 和 xt
    concat = np.zeros((n_x+n_h,m))
    concat[: n_h, :] = h_prev
    concat[n_h :, :] = xt
    
    # 计算遗忘门、输入门、记忆细胞候选值、下一时间步的记忆细胞、输出门和下一时间步的隐状态值
    ft = sigmoid(np.dot(Wf,concat)+bf)
    it = sigmoid(np.dot(Wi,concat)+bi)
    cct = np.tanh(np.dot(Wc,concat)+bc)
    c_next = ft*c_prev + it*cct
    ot = sigmoid(np.dot(Wo,concat)+bo)
    h_next = ot*np.tanh(c_next)
    
    # LSTM单元的计算预测
    yt_pred = softmax(np.dot(Wy, h_next) + by)
    
    return h_next,c_next,yt_pred
np.random.seed(1)
xt = np.random.randn(3,10)
h_prev = np.random.randn(5,10)
c_prev = np.random.randn(5,10)
Wf = np.random.randn(5, 5+3)
bf = np.random.randn(5,1)
Wi = np.random.randn(5, 5+3)
bi = np.random.randn(5,1)
Wo = np.random.randn(5, 5+3)
bo = np.random.randn(5,1)
Wc = np.random.randn(5, 5+3)
bc = np.random.randn(5,1)
Wy = np.random.randn(2,5)
by = np.random.randn(2,1)

parameters = {"Wf": Wf, "Wi": Wi, "Wo": Wo, "Wc": Wc, "Wy": Wy, "bf": bf, "bi": bi, "bo": bo, "bc": bc, "by": by}

h_next, c_next, yt = LSTM_CELL_Forward(xt, h_prev, c_prev, parameters)
print("a_next[4] = ", h_next[4])
print("a_next.shape = ", c_next.shape)
print("c_next[2] = ", c_next[2])
print("c_next.shape = ", c_next.shape)
print("yt[1] =", yt[1])
print("yt.shape = ", yt.shape)
a_next[4] =  [-0.66408471  0.0036921   0.02088357  0.22834167 -0.85575339  0.00138482
  0.76566531  0.34631421 -0.00215674  0.43827275]
a_next.shape =  (5, 10)
c_next[2] =  [ 0.63267805  1.00570849  0.35504474  0.20690913 -1.64566718  0.11832942
  0.76449811 -0.0981561  -0.74348425 -0.26810932]
c_next.shape =  (5, 10)
yt[1] = [0.79913913 0.15986619 0.22412122 0.15606108 0.97057211 0.31146381
 0.00943007 0.12666353 0.39380172 0.07828381]
yt.shape =  (2, 10)

预期输出:
a_next[4] = [-0.66408471 0.0036921 0.02088357 0.22834167 -0.85575339 0.00138482
0.76566531 0.34631421 -0.00215674 0.43827275]

a_next.shape = (5, 10)

c_next[2] = [ 0.63267805 1.00570849 0.35504474 0.20690913 -1.64566718 0.11832942
0.76449811 -0.0981561 -0.74348425 -0.26810932]

c_next.shape = (5, 10)

yt[1] = [0.79913913 0.15986619 0.22412122 0.15606108 0.97057211 0.31146381
0.00943007 0.12666353 0.39380172 0.07828381]

yt.shape = (2, 10)

参考
https://zh-v2.d2l.ai/chapter_recurrent-modern/lstm.html
https://www.heywhale.com/mw/project/6174b96ef7e7c300175739cc

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-10-25 12:32:33  更:2021-10-25 12:35:08 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 10:05:27-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码