Adapted from 《自然语言处理-基于预训练模型的方法》 (Natural Language Processing: Methods Based on Pre-trained Models) by Che Wanxiang et al., Harbin Institute of Technology.
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
from tqdm import tqdm
import torch.optim as optim
BOS_TOKEN = "<bos>"
EOS_TOKEN = "<eos>"
PAD_TOKEN = "<pad>"
def load_reuters():
    from nltk.corpus import reuters
    text = reuters.sents()
    text = [[word.lower() for word in sentence] for sentence in text]
    vocab = Vocab.build(text, reserved_tokens=[PAD_TOKEN, BOS_TOKEN, EOS_TOKEN])
    corpus = [vocab.convert_tokens_to_ids(sentence) for sentence in text]
    return corpus, vocab
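# Note (not part of the original listing): load_reuters() assumes the NLTK Reuters
# corpus is already available locally. If it is not, a one-time download such as
#
#   import nltk
#   nltk.download("reuters")
#
# is typically required before the corpus can be read.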
class Vocab:
    def __init__(self, tokens=None):
        self.idx_to_token = list()
        self.token_to_idx = dict()
        if tokens is not None:
            if "<unk>" not in tokens:
                tokens = tokens + ["<unk>"]
            for token in tokens:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1
            self.unk = self.token_to_idx["<unk>"]
    @classmethod
    def build(cls, text, min_freq=1, reserved_tokens=None):
        token_freqs = defaultdict(int)
        for sentence in text:
            for token in sentence:
                token_freqs[token] += 1
        uniq_tokens = ["<unk>"] + (reserved_tokens if reserved_tokens else [])
        uniq_tokens += [token for token, freq in token_freqs.items() if freq >= min_freq and token != "<unk>"]
        return cls(uniq_tokens)
    def __len__(self):
        return len(self.idx_to_token)
    def __getitem__(self, token):
        return self.token_to_idx.get(token, self.unk)
    def convert_tokens_to_ids(self, tokens):
        return [self[token] for token in tokens]
    def convert_ids_to_tokens(self, indices):
        return [self.idx_to_token[index] for index in indices]
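# A minimal usage sketch of Vocab (illustrative, not part of the original listing):
#
#   vocab = Vocab.build([["hello", "world"], ["hello"]], reserved_tokens=[PAD_TOKEN])
#   vocab.convert_tokens_to_ids(["hello", "unseen"])  # unknown tokens map to the <unk> id
#   vocab.convert_ids_to_tokens([vocab["hello"]])     # -> ["hello"]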
class SGNSDataset(Dataset):
    def __init__(self, corpus, vocab, context_size=2, n_negatives=5, ns_dist=None):
        self.data = []
        self.bos = vocab[BOS_TOKEN]
        self.eos = vocab[EOS_TOKEN]
        self.pad = vocab[PAD_TOKEN]
        for sentence in tqdm(corpus, desc="Dataset Construction"):
            sentence = [self.bos] + sentence + [self.eos]
            for i in range(1, len(sentence) - 1):
                w = sentence[i]
                left_context_index = max(0, i - context_size)
                right_context_index = min(len(sentence), i + context_size)
                context = sentence[left_context_index:i] + sentence[i+1:right_context_index+1]
                # Pad the context to a fixed width of 2 * context_size
                context += [self.pad] * (2 * context_size - len(context))
                self.data.append((w, context))
        self.n_negatives = n_negatives
        # Fall back to a uniform negative-sampling distribution if none is given
        self.ns_dist = ns_dist if ns_dist is not None else torch.ones(len(vocab))
    def __len__(self):
        return len(self.data)
    def __getitem__(self, i):
        return self.data[i]
    def collate_fn(self, examples):
        words = torch.tensor([ex[0] for ex in examples], dtype=torch.long)
        contexts = torch.tensor([ex[1] for ex in examples], dtype=torch.long)
        batch_size, context_size = contexts.shape
        neg_contexts = []
        # Sample negatives for each example, excluding its own context words
        for i in range(batch_size):
            ns_dist = self.ns_dist.index_fill(0, contexts[i], 0.0)
            neg_contexts.append(torch.multinomial(ns_dist, self.n_negatives * context_size, replacement=True))
        neg_contexts = torch.stack(neg_contexts, dim=0)
        return words, contexts, neg_contexts
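# Shape sketch for one batch produced by collate_fn (illustrative comment, with
# context_size referring to the dataset's window parameter):
#   words:        (batch_size,)                                  target word ids
#   contexts:     (batch_size, 2 * context_size)                 positive context ids (padded)
#   neg_contexts: (batch_size, 2 * context_size * n_negatives)   sampled negative ids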
class SGNSModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(SGNSModel, self).__init__()
        # Separate embedding tables for target words (w) and context words (c)
        self.w_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.c_embeddings = nn.Embedding(vocab_size, embedding_dim)
    def forward_w(self, words):
        w_embeds = self.w_embeddings(words)
        return w_embeds
    def forward_c(self, contexts):
        c_embeds = self.c_embeddings(contexts)
        return c_embeds
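# For reference (added comment, not present in the original listing): the per-word
# objective optimized below is the standard SGNS loss,
#
#   log sigmoid(v_c . v_w) + sum over the K sampled negatives c' of log sigmoid(-v_c' . v_w)
#
# where v_w comes from w_embeddings and v_c / v_c' from c_embeddings; the training
# loop minimizes the negation of this quantity averaged over the batch.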
def get_unigram_distribution(corpus, vocab_size):
    token_counts = torch.tensor([0] * vocab_size)
    total_count = 0
    for sentence in corpus:
        total_count += len(sentence)
        for token in sentence:
            token_counts[token] += 1
    unigram_dist = torch.div(token_counts.float(), total_count)
    return unigram_dist
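# An equivalent vectorized sketch (an alternative, not the book's implementation):
#
#   flat = torch.tensor([token for sentence in corpus for token in sentence])
#   token_counts = torch.bincount(flat, minlength=vocab_size)
#   unigram_dist = token_counts.float() / token_counts.sum()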
def save_pretrained(vocab, embeds, save_path):
    # Write embeddings in word2vec text format: a "num_vectors dim" header line,
    # then one "token value1 value2 ..." line per word
    with open(save_path, "w") as writer:
        writer.write(f"{embeds.shape[0]} {embeds.shape[1]}\n")
        for idx, token in enumerate(vocab.idx_to_token):
            vec = " ".join([f"{x}" for x in embeds[idx]])
            writer.write(f"{token} {vec}\n")
def main():
    embedding_dim = 128
    context_size = 3
    batch_size = 1024
    n_negatives = 5
    num_epoch = 10
    # Build the corpus and vocabulary, then the negative-sampling distribution:
    # the unigram distribution raised to the 0.75 power and renormalized
    corpus, vocab = load_reuters()
    unigram_dist = get_unigram_distribution(corpus, len(vocab))
    negative_sampling_dist = unigram_dist ** 0.75
    negative_sampling_dist /= negative_sampling_dist.sum()
    dataset = SGNSDataset(corpus, vocab, context_size=context_size, n_negatives=n_negatives, ns_dist=negative_sampling_dist)
    # The dataset's collate_fn performs the per-batch negative sampling
    data_loader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=True)
    model = SGNSModel(len(vocab), embedding_dim)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    model.train()
    for epoch in range(num_epoch):
        total_loss = 0
        for batch in tqdm(data_loader, desc=f"Training Epoch {epoch}"):
            words, contexts, neg_contexts = [x.to(device) for x in batch]
            optimizer.zero_grad()
            batch_size = words.shape[0]
            # word_embeds: (batch_size, embedding_dim, 1) for batched dot products
            word_embeds = model.forward_w(words).unsqueeze(dim=2)
            context_embeds = model.forward_c(contexts)
            neg_context_embeds = model.forward_c(neg_contexts)
            # Log-likelihood of the observed (positive) context words
            context_loss = F.logsigmoid(torch.bmm(context_embeds, word_embeds).squeeze(dim=2))
            context_loss = context_loss.mean(dim=1)
            # Log-likelihood of the sampled negative context words
            neg_context_loss = F.logsigmoid(torch.bmm(neg_context_embeds, word_embeds).squeeze(dim=2).neg())
            neg_context_loss = neg_context_loss.view(batch_size, -1, n_negatives).sum(dim=2)
            neg_context_loss = neg_context_loss.mean(dim=1)
            # Negative sampling loss: the negated sum of both log-likelihood terms
            loss = -(context_loss + neg_context_loss).mean()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Loss: {total_loss:.2f}")
    # Use the sum of word and context embeddings as the final word vectors
    combined_embeds = model.w_embeddings.weight + model.c_embeddings.weight
    save_pretrained(vocab, combined_embeds.data, "sgns.vec")
if __name__ == "__main__":
    main()