IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> CBOW模型,源码实现 -> 正文阅读

[Python知识库]CBOW模型,源码实现

作者:recommend-item-box type_blog clearfix

在已经写好vocab.py和utils.py两个函数的前提下(可参阅另一篇博客),只需要构建下面的函数即可。

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@Filename :5-3Cbow.py
@Description :
@Datatime :2021/08/28 15:41:21
@Author :qtxu
@Version :v1.0
'''


# import torch
# import torch.nn.functional as F
# import torch.nn as nn
# from utils import BOS_TOKEN,BOW_TOKEN,EOS_TOKEN,EOW_TOKEN,PAD_TOKEN
# from utils import init_weights
# from vocab import read_vocab, save_vocab, 
# from tqdm.auto import tqdm
# from torch.optim as optimizer
# from torch.utils.data import Dataset

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from tqdm.auto import tqdm
from utils import BOS_TOKEN, EOS_TOKEN, PAD_TOKEN
from utils import load_reuters, save_pretrained, get_loader, init_weights



class CbowDataset(Dataset):
    def __init__(self, corpus, vocab, context_size=2):
        self.data = []
        self.bos = vocab[BOS_TOKEN]
        self.eos = vocab[EOS_TOKEN]
        for sentence in tqdm(corpus, desc="Dataset Construction"):
            sentence = [self.bos] + sentence+ [self.eos]
            if len(sentence) < context_size * 2 + 1:
                continue
            for i in range(context_size, len(sentence) - context_size):
                # 模型输入:左右分别取context_size长度的上下文
                context = sentence[i-context_size:i] + sentence[i+1:i+context_size+1]
                # 模型输出:当前词
                target = sentence[i]
                self.data.append((context, target))


    def __len__(self):
        return len(self.data)

    def __getitem__(self,i):
        return self.data[i]

    def collate_fn(self,examples):
        inputs = torch.tensor([ex[0] for ex in examples])
        targets = torch.tensor([ex[1] for ex in examples])
        return (inputs, targets)


class CbowModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(CbowModel, self).__init__()
        # 词嵌入层
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        # 线性变换:隐含层->输出层
        self.output = nn.Linear(embedding_dim, vocab_size)
        init_weights(self)

    def forward(self, inputs):
        embeds = self.embeddings(inputs)
        # 计算隐含层:对上下文词向量求平均
        hidden = embeds.mean(dim=1)
        output = self.output(hidden)
        log_probs = F.log_softmax(output, dim=1)
        return log_probs
    
embedding_dim =64
context_size = 2
hidden_dim = 128
batch_size = 32
num_epoch = 10

#读取文本数据,构建CBOW模型,训练数据集
corpus, vocab = load_reuters()
dataset = CbowDataset(corpus, vocab, context_size=2)
data_loader = get_loader(dataset,batch_size)


nll_loss = nn.NLLLoss()
#构建CBOW模型,并加载至Device
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
model = CbowModel(len(vocab),embedding_dim)
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)


model.train()
for epoch in range(num_epoch):
    total_loss = 0
    for batch in tqdm(data_loader,desc=f"Training Epoch {epoch}"):  #如果不加f,则输出结果时,epoch是固定的,不能动态更改
        inputs,targets = [x.to(device) for x in batch]
        optimizer.zero_grad()
        log_probs = model(inputs)
        loss = nll_loss(log_probs,targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Loss:{total_loss:.2f}")

#保存词向量(model.embeddings)
save_pretrained(vocab,model.embeddings.weight.data,"5-3Cbow.vec")     
  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2021-08-31 15:24:46  更:2021-08-31 15:27:01 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/26 22:52:40-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计