IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 利用RNN构建语言模型 -> 正文阅读

[人工智能]利用RNN构建语言模型

这里省略了文本数据和数据集构建。具体可以查看这里。(这个是之前我按照李沐老师的课写的)

RNN输入输出

输入为当前向量 x x x(词),输出为预测向量 y y y。隐藏状态设为 h h h。其和上一个输入和上一个隐藏状态相关。RNN具体输入输出公式为:
h t = ? ( x t W x h + h t ? 1 W h h + b h ) y t = ? ( h t W h y + b y ) h_t = \phi( x_tW_{xh}+h_{t-1}W_{hh}+b_h)\\ y_t = \phi(h_tW_{hy}+b_y) ht?=?(xt?Wxh?+ht?1?Whh?+bh?)yt?=?(ht?Why?+by?)
在第一个 x x x输入时隐藏状态这里设为torch.zero((1, batch_size, num_hidden))

定义参数

import math
import torch
from torch import nn
from torch.nn import functional as F
# load_text函数参考链接
batch_size, num_steps = 32, 35
train_iter, vocab = load_text(tokens, batch_size, num_steps)
print(vocab.idx_to_token)
# 隐藏层大小
num_hidden = 256
rnn_layer = nn.RNN(len(vocab), num_hidden)
# 定义隐状态 (隐层数,批量大小, 隐藏单元数)
state = torch.zeros((1, batch_size, num_hidden))

定义RNN网络

class RNNmodel(nn.Module):
    def __init__(self, rnn_layer, vocab_size, **kwargs):
        super(RNNmodel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size
        if not self.rnn.bidirectional:
            self.num_directions = 1
            self.liner = nn.Linear(self.num_hiddens, self.vocab_size)
        else:
            self.num_directions = 2
            self.liner = nn.Linear(self.num_hiddens * 2, self.vocab_size)

    def forward(self, inputs, state):
        X = F.one_hot(inputs.T.long(), self.vocab_size)
        X = X.to(torch.float32)
        Y, state = self.rnn(X, state)
        output = self.liner(Y.reshape((-1, Y.shape[-1])))
        return output, state

    def begin_state(self, device, batch_size=1):
        if not isinstance(self.rnn, nn.LSTM):
            # `nn.GRU` 以张量作为隐藏状态
            return torch.zeros((self.num_directions * self.rnn.num_layers,
                                batch_size, self.num_hiddens),
                               device=device)
        else:
            return (torch.zeros((
                self.num_directions * self.rnn.num_layers,
                batch_size, self.num_hiddens), device=device),
                    torch.zeros((
                        self.num_directions * self.rnn.num_layers,
                        batch_size, self.num_hiddens), device=device))

定义预测函数

def predict(prefix, num_preds, net, vocab, device):
    state = net.begin_state(batch_size=1, device=device)
    outputs = [vocab[prefix[0]]]
    get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))
    for y in prefix[1:]:
        _, state = net(get_input(), state)
        outputs.append(vocab[y])
    for _ in range(num_preds):
        y, state = net(get_input(), state)
        outputs.append(int(y.argmax(dim=1).reshape(1)))
    return ''.join(vocab.to_tokens(outputs))
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
net = RNNmodel(rnn_layer, len(vocab))
net = net.to(device)
predict('time traveller', 10, net, vocab, device)

输出:

'time travellerbbfbxbbbbf'

定义训练函数

def train_epoch(net, train_iter, loss, updater, device, use_random_iter):
    state = None
    L = torch.tensor([])
    for X, Y in train_iter:
        if state is None or use_random_iter:
            state = net.begin_state(batch_size=X.shape[0], device=device)
        else:
            if isinstance(net, nn.Module) and not isinstance(state, tuple):
                # `state` is a tensor for `nn.GRU`
                state.detach_()
            else:
                # `state` is a tuple of tensors for `nn.LSTM` and
                # for our custom scratch implementation
                for s in state:
                    s.detach_()
        y = Y.T.reshape(-1)
        X, y = X.to(device), y.to(device)
        y_hat, state = net(X, state)
        l = loss(y_hat, y.long()).mean()
        L = torch.cat((L, torch.tensor([l])), 0)
        updater.zero_grad()
        l.backward()
        if isinstance(net, nn.Module):
            params = [p for p in net.parameters() if p.requires_grad]
        else:
            params = net.params
        norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))
        if norm > 1:  # 这里1可以自定义
            for param in params:
                param.grad[:] *= 1 / norm
        updater.step()
    return torch.mean(L)


def train(net, train_iter, vocab, lr, num_epochs, device, use_radom=False):
    loss = nn.CrossEntropyLoss()
    if isinstance(net, nn.Module):
        updater = torch.optim.SGD(net.parameters(), lr)
    else:
        return 'wrong'
    pre = lambda prefix: predict(prefix, 50, net, vocab, device)
    # 训练
    for epoch in range(num_epochs):
        l = train_epoch(net, train_iter, loss, updater, device, use_radom)
        if (epoch + 1) % 10 == 0:
            print(pre('time traveller'))
            print("[{}|{}]:loss={}".format(epoch+1, num_epochs, l))
    print(pre('time traveller'), '\n', pre('time machine'))
num_epochs, lr = 200, 2
train(net, train_iter, vocab, lr, num_epochs, device)

截取部分输出
在这里插入图片描述
在这里插入图片描述

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-03 16:13:20  更:2022-03-03 16:17:50 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 18:25:41-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码