IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 第五章:循环神经网络语言模型 -> 正文阅读

[人工智能]第五章:循环神经网络语言模型

第五章:循环神经网络语言模型(使用RNN实现静态词向量的预训练)

数据

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@Filename :5-2rnnlm.py
@Description :
@Datatime :2021/08/26 11:14:21
@Author :qtxu
@Version :v1.0
'''

from vocab import Vocab
from utils import BOS_TOKEN,EOS_TOKEN,BOW_TOKEN,EOW_TOKEN,PAD_TOKEN
from tqdm.auto import tqdm
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.nn.utils.rnn import pad_sequence
from utils import load_reuters, load_pretrained, save_pretrained, get_loader, init_weights


#  ----------------- 数据 --------------------------------
class RnnlmDataset(Dataset):
    def __init__(self, corpus, vocab):
        self.data = []
        self.bos = vocab[BOS_TOKEN]
        self.eos = vocab[EOS_TOKEN]
        self.pad = vocab[PAD_TOKEN]
        for sentence in tqdm(corpus, desc=f"Dataset Construction"):
            #模型输入序列:BOS_TOKEN,w_1,w_2,....,w_n
            input = [self.bos]+sentence
            #模型输出序列:w_1,w_2,....,w_n,EOS_TOKEN
            targets = sentence+[self.eos]
            self.data.append((input, targets))

    def __len__(self):
        return len(self.data)

    def __getitem__(self,i):
        return self.data[i]    

    def collate_fn(self, example):
        #从独立样本集合中构建批次输入输出
        inputs = [torch.tensor(ex[0]) for ex in example]
        targets = [torch.tensor(ex[1]) for ex in example]    
        #对批次内的样本进行长度补齐
        inputs = pad_sequence(inputs, batch_first= True, padding_value=self.pad)
        targets = pad_sequence(targets, batch_first=True, padding_value=self.pad)
        return (inputs, targets)


#  ----------------- 模型 --------------------------------
class RNNLM(nn.Module):
    def __init__(self,vocab_size, embedding_dim, hidden_dim):
        super(RNNLM,self).__init__()
        #词向量层
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        #循环神经网络:这里使用LSTM
        self.rnn = nn.LSTM(embedding_dim, hidden_dim,batch_first=True)
        #输出层
        self.output = nn.Linear(hidden_dim,vocab_size)

    def forward(self,inputs):
        embeds = self.embeddings(inputs)
        #计算每一时刻的隐含层表示
        hidden, _ = self.rnn(embeds)
        output = self.output(hidden)
        log_probs = F.log_softmax(output,dim=2) #注意此处dim=2
        return log_probs
    

#  ----------------- 训练 --------------------------------

# 超参数设置
embedding_dim = 64
context_size = 2
hidden_dim = 128
batch_size = 16
num_epoch = 10

corpus, vocab = load_reuters()
dataset = RnnlmDataset(corpus, vocab)
dataloader = get_loader(dataset, batch_size)

nll_loss = nn.NLLLoss(ignore_index = dataset.pad)
#构建RNNLM,并加载至相应设备

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = RNNLM(len(vocab),embedding_dim,hidden_dim)
model.to(device)

#使用Adam优化器
optimizer = optim.Adam(model.parameters(), lr = 0.001)

model.train()
for epoch in range(num_epoch):
    total_loss = 0
    for batch in tqdm(dataloader,desc=f"Training Epoch {epoch}"):
        inputs, targets = [x.to(device) for x in batch]
        optimizer.zero_grad()
        log_probs = model(inputs)
        loss = nll_loss(log_probs.view(-1, log_probs.shape[-1]),targets.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"loss: {total_loss:.2f}")

#将词向量保存至rnnlm.vec
save_pretrained(vocab,model.embeddings.weight.data,"5-2rnnlm.vec")


vocab

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@Filename :vocab.py
@Description :
@Datatime :2021/08/24 17:02:51
@Author :qtxu
@Version :v1.0
'''

from collections import defaultdict,Counter

class Vocab:
    def __init__(self, tokens=None):
        self.idx_to_token = list()
        self.token_to_idx = dict()

        if tokens is not None:
            if "<unk>" not in tokens:
                tokens = tokens +["<unk>"]

            for token in tokens:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1
            self.unk = self.token_to_idx['<unk>']        



    @classmethod
    def build(cls, text, min_freq=1, reserved_tokens=None):
        token_freqs = defaultdict(int)
        for sentence in text:
            for token in sentence:
                token_freqs[token] += 1
        uniq_tokens = ["<unk>"] + (reserved_tokens if reserved_tokens else [])
        uniq_tokens += [token for token, freq in token_freqs.items() \
                        if freq >= min_freq and token != "<unk>"]
        return cls(uniq_tokens)

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, token):
        return self.token_to_idx.get(token, self.unk)

    def convert_tokens_to_ids(self, tokens):
        return [self[token] for token in tokens]

    def convert_ids_to_tokens(self, indices):
        return [self.idx_to_token[index] for index in indices]


def save_vocab(vocab, path):
    with open(path, 'w') as writer:
        writer.write("\n".join(vocab.idx_to_token))


def read_vocab(path):
    with open(path, 'r') as f:
        tokens = f.read().split('\n')
    return Vocab(tokens)

utils

# -*- coding: utf-8 -*-
# @Time    : 2021-08-24 16:39
# @Author  : XAT
# @FileName: utils.py
# @Software: PyCharm

import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset
from vocab import Vocab
from nltk.corpus import reuters  # 从nltk中导入Reuters数据处理模块


# Constants
BOS_TOKEN = "<bos>"  # 句首标记
EOS_TOKEN = "<eos>"  # 句尾标记
PAD_TOKEN = "<pad>"  # 补齐序列长度的标记
BOW_TOKEN = "<bow>"
EOW_TOKEN = "<eow>"

WEIGHT_INIT_RANGE = 0.1


def load_reuters():
    text = reuters.sents()  # 获取Reuters数据中的所有句子(已完成标记解析)
    text = [[word.lower() for word in sentence]
            for sentence in text]  # 将语料中的词转换为小写(可选)
    vocab = Vocab.build(text, reserved_tokens=[
                        PAD_TOKEN, BOS_TOKEN, EOS_TOKEN])  # 构建词表,并传入预留标记
    corpus = [vocab.convert_tokens_to_ids(
        sentence) for sentence in text]  # 利用词表将文本数据转换为id表示

    return corpus, vocab


def save_pretrained(vocab, embeds, save_path):
    with open(save_path, "w") as writer:
        #记录词表大小
        writer.write(f"{embeds.shape[0]} {embeds.shape[1]}\n")
        for idx, token in enumerate(vocab.idx_to_token):
            vec = " ".join(["{:.4f}".format(x) for x in embeds[idx]])
            #每一行对应一个单词以及由空格分隔的词向量
            writer.write(f"{token} {vec}\n")
    print(f"Pretrained embeddings saved to:{save_path}")


def load_pretrained(load_path):
    with open(load_path, "r") as fin:
        n, d = map(int, fin.readline().split())
        tokens = []
        embeds = []
        for line in fin:
            line = line.rstrip().split(' ')
            token, embeds = line[0], list(map(float, line[1:]))
            tokens.append(token)
            embeds.append(embeds)
        vocab = Vocab(tokens)
        embeds = torch.tensor(embeds, dtype=torch.float)
    return vocab, embeds


def get_loader(dataset, batch_size, shuffle=True):
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=dataset.collate_fn,
        shuffle=shuffle
    )
    return data_loader


def init_weights(model):
    for name, param in model.named_parameters():
        # print("------------------------------------------------------")
        # print("model.named_parameters()",model.named_parameters())
        if "embedding" not in name:
            torch.nn.init.uniform_(
                param, a=-WEIGHT_INIT_RANGE, b=WEIGHT_INIT_RANGE)

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-27 11:51:08  更:2021-08-27 11:53:06 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 22:50:23-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码