IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 小黑公式与代码力量积蓄:WordLSTMCell -> 正文阅读

[人工智能]小黑公式与代码力量积蓄:WordLSTMCell

1.WordCell原理图

在这里插入图片描述

2.WordCell门控单元原始公式:

i b , e w = σ ( W i , e x b , e w + W i , b h b , c + b i ) i_{b,e}^{w} = \sigma(W_{i,e}x_{b,e}^{w} + W_{i,b}h_{b,c} + b_{i}) ib,ew?=σ(Wi,e?xb,ew?+Wi,b?hb,c?+bi?)
f b , e w = σ ( W f , e x b , e w + W f , b h b , c + b f ) f_{b,e}^{w} = \sigma(W_{f,e}x_{b,e}^{w} + W_{f,b}h_{b,c} + b_{f}) fb,ew?=σ(Wf,e?xb,ew?+Wf,b?hb,c?+bf?)
c ~ b , e w = t a n h ( W f , e x b , e w + W f , b h b , c + b f ) \widetilde{c}_{b,e}^{w} = tanh(W_{f,e}x_{b,e}^{w} + W_{f,b}h_{b,c} + b_{f}) c b,ew?=tanh(Wf,e?xb,ew?+Wf,b?hb,c?+bf?)

维度分析:
W x x W_{xx} Wxx?.shape:[hidden_size,hidden_size]
x b , e w x_{b,e}^{w} xb,ew?.shape:[hidden_size,1](已经经过了变化,从input_size变成hidden_size)

3.WordCell门控单元简化公式:

在这里插入图片描述
维度分析:
W W T W^{W^{T}} WWT与原始公式 [ W i , e , W i , b ; W f , e , W f , b ; W f , e , W f , b ] [W_{i,e}, W_{i,b};W_{f,e},W_{f,b};W_{f,e},W_{f,b}] [Wi,e?,Wi,b?;Wf,e?,Wf,b?;Wf,e?,Wf,b?]
x b , e w x_{b,e}^{w} xb,ew?.shape:[hidden_size,1]
h b c h_{b}^{c} hbc?.shape:[hidden_size,1]
W W T W^{W^{T}} WWT.shape:[3 x hidden_size,2 x hidden_size]

4.代码实现

import torch
from torch import nn
import torch.autograd as autograd
from torch.autograd import Variable
from torch.nn import functional,init
import numpy as np

class WordLSTMCell(nn.Module):
    
    def __init__(self,input_size,hidden_size,use_bias = True):
        super(WordLSTMCell,self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        self.weight_ih = nn.Parameter(
            torch.FloatTensor(input_size,3 * hidden_size)
        )
        self.weight_hh = nn.Parameter(
            torch.FloatTensor(hidden_size,3 * hidden_size)
        )
        if use_bias:
            self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size))
        else:
            self.register_parameter('bias',None)
        self.reset_parameters()
    def reset_parameters(self):
        # 正交化参数
        init.orthogonal_(self.weight_ih.data)
        weight_hh_data = torch.eye(self.hidden_size)    # [hidden_size,hidden_size]
        weight_hh_data = weight_hh_data.repeat(1,3)   # [hidden_size,3 * hidden_size]
        with torch.no_grad():
            self.weight_hh.set_(weight_hh_data)
        if self.use_bias:
            init.constant_(self.bias.data,val = 0)
    
    def forward(self,input_,hx):
        # input_:[num_words,word_emb_dim]
        # h_0,c_0:[1,hidden_size]
        h_0,c_0 = hx
        batch_size = h_0.size(0)
        # bias_batch:[1,3 * hidden_size]
        bias_batch = self.bias.unsqueeze(0).expand(batch_size,*self.bias.size())
        # wh_b:[1,3 * hidden_size]
        wh_b = torch.addmm(bias_batch,h_0,self.weight_hh)
        # wi:[num_words,3 * hidden_size]
        wi = torch.mm(input_,self.weight_ih)
        # f,i,g:[num_words,hidden_size]
        f,i,g = torch.split(wh_b + wi,split_size_or_sections = self.hidden_size,dim = 1)
        # c_l:[num_words,hidden_size]
        c_l = torch.sigmoid(f) * c_0 + torch.sigmoid(i) * torch.tanh(g)
        return c_l

num_words = 5
word_emb_dim = 100
hidden_size = 100
input_ = torch.randn([num_words,word_emb_dim])
h_0 = torch.randn([1,hidden_size])
c_0 = torch.randn([1,hidden_size])
hx = (h_0,c_0)
word_lstm_cell = WordLSTMCell(word_emb_dim,hidden_size)
print(word_lstm_cell(input_,hx).shape)
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-15 22:31:54  更:2022-03-15 22:32:31 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/9 16:00:07-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码