'''
Description: TF-IDF learning for NLP
Author: 365JHWZGo
Date: 2021-11-16 14:48:45
LastEditors: 365JHWZGo
LastEditTime: 2021-11-16 18:27:02
'''
import numpy as np
from collections import Counter
import itertools
from visual import show_tfidf
docs = [
"Born on Oct. 21st, 1992, in Shijiazhuang, Hebei province, Allen Deng is one of the most promising actors of the new generation. ",
"Brought up by his grandparents, Deng was greatly influenced by his grandpa, who used to be a professor in college.",
"It was his beloved grandpa who shaped his character and made him what he is today. Good family education nurtured him into a person who knows fairly well how to respond to various situations.",
"In the second year of his college, as a sophomore, he was chosen by Qiong Yao, a writer in Taiwan who is famous for her romantic novels, to play the role Xu Hao in a TV series Flowers in Fog, marking the beginning of his acting career.",
"Having graduated from ShanghaiTheatreAcademy, Deng went to Beijing with great passion for acting. He played in some TV series but didn’t get known by many.",
"it is a good day, I like to stay here",
"I am happy to be here",
"I am bob",
"it is sunny today",
"I have a party today",
"it is a dog and that is a cat",
"there are dog and cat on the tree",
"I study hard this morning",
"today is a good day"
]
docs_words = [d.replace(",", "").split(" ") for d in docs]  # simple whitespace tokenisation
vocab = set(itertools.chain(*docs_words))  # unique words across all docs
v2i = {v: i for i, v in enumerate(vocab)}  # word -> index
i2v = {i: v for v, i in v2i.items()}  # index -> word
def safe_log(x):
    # apply log only to the non-zero entries, in place, leaving zeros untouched
    mask = x != 0
    x[mask] = np.log(x[mask])
    return x
idf_methods = {
    "log": lambda x: 1 + np.log(len(docs) / (x + 1)),
    "prob": lambda x: np.maximum(0, np.log((len(docs) - x) / (x + 1))),
    "len_norm": lambda x: x / (np.sum(np.square(x)) + 1),
}
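# For reference, the default "log" option implements
#   idf(t) = 1 + log(N / (df_t + 1))
# where N is the number of documents and df_t is the number of documents containing t.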
def get_idf(method="log"):
df = np.zeros((len(i2v),1))
for i in range(len(i2v)):
d_count = 0
for d in docs_words:
d_count+=1 if i2v[i] in d else 0
df[i,0]=d_count
idf_fn = idf_methods.get(method,None)
if idf_fn is None:
raise ValueError
return idf_fn(df)
tf_methods = {
    "log": lambda x: np.log(1 + x),
    "augmented": lambda x: 0.5 + 0.5 * x / np.max(x, axis=1, keepdims=True),
    "boolean": lambda x: np.minimum(x, 1),
    "log_avg": lambda x: (1 + safe_log(x)) / (1 + safe_log(np.mean(x, axis=1, keepdims=True))),
}
def get_tf(method="log"):
_tf = np.zeros((len(vocab),len(docs)),dtype=np.float64)
for i,d in enumerate(docs_words):
counter = Counter(d)
for v in counter.keys():
_tf[v2i[v],i]=counter[v]/counter.most_common(1)[0][1]
weighted_tf = tf_methods.get(method,None)
if weighted_tf is None:
raise ValueError
return weighted_tf(_tf)
def cosine_similarity(q,_tf_idf):
unit_q = q/np.sqrt(np.sum(np.square(q),axis=0,keepdims=True))
unit_ds = _tf_idf/np.sqrt(np.sum(np.square(_tf_idf),axis=0,keepdims=True))
similarity = unit_ds.T.dot(unit_q).ravel()
return similarity
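# cosine_similarity() above L2-normalises the query vector and every tf-idf
# column, so each score is the cosine of the angle between the query and one
# document: 1 means identical term profiles, 0 means no shared weighted terms.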
def docs_score(q, len_norm=False):
    q_words = q.replace(",", "").split(" ")
    # first add unseen query words to the vocabulary
    unknown_v = 0
    for v in set(q_words):
        if v not in v2i:
            v2i[v] = len(v2i)
            i2v[len(v2i) - 1] = v
            unknown_v += 1
    if unknown_v > 0:
        # grow idf and tf-idf with zero rows for the unseen words
        _idf = np.concatenate((idf, np.zeros((unknown_v, 1), dtype=np.float64)), axis=0)
        _tf_idf = np.concatenate((tf_idf, np.zeros((unknown_v, tf_idf.shape[1]), dtype=np.float64)), axis=0)
    else:
        _idf, _tf_idf = idf, tf_idf
    counter = Counter(q_words)
    q_tf = np.zeros((len(_idf), 1), dtype=np.float64)  # np.float was removed from recent NumPy
    for v in counter.keys():
        q_tf[v2i[v], 0] = counter[v]
    q_vec = q_tf * _idf  # tf-idf vector of the query
    q_scores = cosine_similarity(q_vec, _tf_idf)
    if len_norm:
        len_docs = [len(d) for d in docs_words]
        q_scores = q_scores / np.array(len_docs)
    return q_scores
def get_keywords(n=2):
    # print the n highest tf-idf words for each of the first three documents
    for c in range(3):
        col = tf_idf[:, c]
        idx = np.argsort(col)[-n:]
        print("doc{}, top{} keywords: {}".format(c, n, [i2v[i] for i in idx]))
if __name__ == '__main__':
tf = get_tf()
idf = get_idf()
tf_idf = tf*idf
print("tf shape(vecb in each docs): ", tf.shape)
print("\ntf samples:\n", tf[:2])
print("\nidf shape(vecb in all docs): ", idf.shape)
print("\nidf samples:\n", idf[:2])
print("\ntf_idf shape: ", tf_idf.shape)
print("\ntf_idf sample:\n", tf_idf[:2])
get_keywords()
q = 'who is Allen Deng'
scores = docs_score(q)
d_ids = scores.argsort()[-3:][::-1]
print('\ntop 3 docs for "{}":\n{}'.format(q,[docs[i] for i in d_ids]))
show_tfidf(tf_idf.T,[i2v[i] for i in range(tf_idf.shape[0])],"tfidf_matrix")
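As a quick cross-check (a sketch, not part of the original script, assuming scikit-learn is installed; its TfidfVectorizer uses smoothed idf and L2 normalisation, so expect the same ranking rather than identical scores), the top-3 retrieval above can be reproduced in a few lines:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity as sk_cosine

vec = TfidfVectorizer()
doc_mat = vec.fit_transform(docs)             # (n_docs, n_terms), sparse
q_vec = vec.transform(["who is Allen Deng"])  # reuse the fitted vocabulary
scores = sk_cosine(q_vec, doc_mat).ravel()
print([docs[i] for i in scores.argsort()[-3:][::-1]])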
Dependencies: the code needs requests and pandas, installed here with conda install requests and conda install pandas.
(Screenshots omitted: nlp directory layout, word-distribution result, code run output.)
utils.py
'''
Description: utils.py, a dependency of the scripts above
Author: 365JHWZGo
Date: 2021-11-16 16:56:48
LastEditors: 365JHWZGo
LastEditTime: 2021-11-16 17:08:37
'''
import numpy as np
import datetime
import os
import requests
import pandas as pd
import re
import itertools
PAD_ID = 0
class DateData:
def __init__(self, n):
np.random.seed(1)
self.date_cn = []
self.date_en = []
for timestamp in np.random.randint(143835585, 2043835585, n):
date = datetime.datetime.fromtimestamp(timestamp)
self.date_cn.append(date.strftime("%y-%m-%d"))
self.date_en.append(date.strftime("%d/%b/%Y"))
self.vocab = set(
[str(i) for i in range(0, 10)] + ["-", "/", "<GO>", "<EOS>"] + [
i.split("/")[1] for i in self.date_en])
self.v2i = {v: i for i, v in enumerate(sorted(list(self.vocab)), start=1)}
self.v2i["<PAD>"] = PAD_ID
self.vocab.add("<PAD>")
self.i2v = {i: v for v, i in self.v2i.items()}
self.x, self.y = [], []
for cn, en in zip(self.date_cn, self.date_en):
self.x.append([self.v2i[v] for v in cn])
self.y.append(
[self.v2i["<GO>"], ] + [self.v2i[v] for v in en[:3]] + [
self.v2i[en[3:6]], ] + [self.v2i[v] for v in en[6:]] + [
self.v2i["<EOS>"], ])
self.x, self.y = np.array(self.x), np.array(self.y)
self.start_token = self.v2i["<GO>"]
self.end_token = self.v2i["<EOS>"]
def sample(self, n=64):
bi = np.random.randint(0, len(self.x), size=n)
bx, by = self.x[bi], self.y[bi]
decoder_len = np.full((len(bx),), by.shape[1] - 1, dtype=np.int32)
return bx, by, decoder_len
def idx2str(self, idx):
x = []
for i in idx:
x.append(self.i2v[i])
if i == self.end_token:
break
return "".join(x)
@property
def num_word(self):
return len(self.vocab)
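# Usage sketch (added for illustration): every "yy-mm-dd" source date is 8
# characters, so bx has shape (n, 8); the "<GO> dd/Mon/yyyy <EOS>" target
# sequence gives by shape (n, 11).
# d = DateData(400)
# bx, by, decoder_len = d.sample(4)
# print(d.idx2str(bx[0]), "->", d.idx2str(by[0]))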
def pad_zero(seqs, max_len):
    padded = np.full((len(seqs), max_len), fill_value=PAD_ID, dtype=np.int64)  # np.long no longer exists in recent NumPy
for i, seq in enumerate(seqs):
padded[i, :len(seq)] = seq
return padded
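# Example (added for illustration):
# pad_zero([[3, 1], [4, 1, 5]], max_len=5)
# -> array([[3, 1, 0, 0, 0],
#           [4, 1, 5, 0, 0]])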
def maybe_download_mrpc(save_dir="./MRPC/", proxy=None):
train_url = 'https://mofanpy.com/static/files/MRPC/msr_paraphrase_train.txt'
test_url = 'https://mofanpy.com/static/files/MRPC/msr_paraphrase_test.txt'
os.makedirs(save_dir, exist_ok=True)
proxies = {"http": proxy, "https": proxy}
for url in [train_url, test_url]:
raw_path = os.path.join(save_dir, url.split("/")[-1])
if not os.path.isfile(raw_path):
print("downloading from %s" % url)
r = requests.get(url, proxies=proxies)
with open(raw_path, "w", encoding="utf-8") as f:
f.write(r.text.replace('"', "<QUOTE>"))
print("completed")
def _text_standardize(text):
text = re.sub(r'—', '-', text)
text = re.sub(r'–', '-', text)
text = re.sub(r'―', '-', text)
text = re.sub(r" \d+(,\d+)?(\.\d+)? ", " <NUM> ", text)
text = re.sub(r" \d+-+?\d*", " <NUM>-", text)
return text.strip()
def _process_mrpc(dir="./MRPC", rows=None):
data = {"train": None, "test": None}
files = os.listdir(dir)
for f in files:
df = pd.read_csv(os.path.join(dir, f), sep='\t', nrows=rows)
k = "train" if "train" in f else "test"
data[k] = {"is_same": df.iloc[:, 0].values, "s1": df["#1 String"].values, "s2": df["#2 String"].values}
vocab = set()
for n in ["train", "test"]:
for m in ["s1", "s2"]:
for i in range(len(data[n][m])):
data[n][m][i] = _text_standardize(data[n][m][i].lower())
cs = data[n][m][i].split(" ")
vocab.update(set(cs))
v2i = {v: i for i, v in enumerate(sorted(vocab), start=1)}
v2i["<PAD>"] = PAD_ID
v2i["<MASK>"] = len(v2i)
v2i["<SEP>"] = len(v2i)
v2i["<GO>"] = len(v2i)
i2v = {i: v for v, i in v2i.items()}
for n in ["train", "test"]:
for m in ["s1", "s2"]:
data[n][m+"id"] = [[v2i[v] for v in c.split(" ")] for c in data[n][m]]
return data, v2i, i2v
class MRPCData:
num_seg = 3
pad_id = PAD_ID
def __init__(self, data_dir="./MRPC/", rows=None, proxy=None):
maybe_download_mrpc(save_dir=data_dir, proxy=proxy)
data, self.v2i, self.i2v = _process_mrpc(data_dir, rows)
self.max_len = max(
[len(s1) + len(s2) + 3 for s1, s2 in zip(
data["train"]["s1id"] + data["test"]["s1id"], data["train"]["s2id"] + data["test"]["s2id"])])
self.xlen = np.array([
[
len(data["train"]["s1id"][i]), len(data["train"]["s2id"][i])
] for i in range(len(data["train"]["s1id"]))], dtype=int)
x = [
[self.v2i["<GO>"]] + data["train"]["s1id"][i] + [self.v2i["<SEP>"]] + data["train"]["s2id"][i] + [self.v2i["<SEP>"]]
for i in range(len(self.xlen))
]
self.x = pad_zero(x, max_len=self.max_len)
self.nsp_y = data["train"]["is_same"][:, None]
self.seg = np.full(self.x.shape, self.num_seg-1, np.int32)
for i in range(len(x)):
si = self.xlen[i][0] + 2
self.seg[i, :si] = 0
si_ = si + self.xlen[i][1] + 1
self.seg[i, si:si_] = 1
self.word_ids = np.array(list(set(self.i2v.keys()).difference(
[self.v2i[v] for v in ["<PAD>", "<MASK>", "<SEP>"]])))
def sample(self, n):
bi = np.random.randint(0, self.x.shape[0], size=n)
bx, bs, bl, by = self.x[bi], self.seg[bi], self.xlen[bi], self.nsp_y[bi]
return bx, bs, bl, by
@property
def num_word(self):
return len(self.v2i)
@property
def mask_id(self):
return self.v2i["<MASK>"]
class MRPCSingle:
pad_id = PAD_ID
def __init__(self, data_dir="./MRPC/", rows=None, proxy=None):
maybe_download_mrpc(save_dir=data_dir, proxy=proxy)
data, self.v2i, self.i2v = _process_mrpc(data_dir, rows)
self.max_len = max([len(s) + 2 for s in data["train"]["s1id"] + data["train"]["s2id"]])
x = [
[self.v2i["<GO>"]] + data["train"]["s1id"][i] + [self.v2i["<SEP>"]]
for i in range(len(data["train"]["s1id"]))
]
x += [
[self.v2i["<GO>"]] + data["train"]["s2id"][i] + [self.v2i["<SEP>"]]
for i in range(len(data["train"]["s2id"]))
]
self.x = pad_zero(x, max_len=self.max_len)
self.word_ids = np.array(list(set(self.i2v.keys()).difference([self.v2i["<PAD>"]])))
def sample(self, n):
bi = np.random.randint(0, self.x.shape[0], size=n)
bx = self.x[bi]
return bx
@property
def num_word(self):
return len(self.v2i)
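# Usage sketch (downloads the two MRPC text files on the first run):
# data = MRPCSingle("./MRPC/", rows=100)
# print(data.num_word, data.x.shape)  # vocabulary size, padded id matrix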
class Dataset:
def __init__(self, x, y, v2i, i2v):
self.x, self.y = x, y
self.v2i, self.i2v = v2i, i2v
self.vocab = v2i.keys()
def sample(self, n):
b_idx = np.random.randint(0, len(self.x), n)
bx, by = self.x[b_idx], self.y[b_idx]
return bx, by
@property
def num_word(self):
return len(self.v2i)
def process_w2v_data(corpus, skip_window=2, method="skip_gram"):
all_words = [sentence.split(" ") for sentence in corpus]
all_words = np.array(list(itertools.chain(*all_words)))
vocab, v_count = np.unique(all_words, return_counts=True)
vocab = vocab[np.argsort(v_count)[::-1]]
print("all vocabularies sorted from more frequent to less frequent:\n", vocab)
v2i = {v: i for i, v in enumerate(vocab)}
i2v = {i: v for v, i in v2i.items()}
pairs = []
js = [i for i in range(-skip_window, skip_window + 1) if i != 0]
for c in corpus:
words = c.split(" ")
w_idx = [v2i[w] for w in words]
if method == "skip_gram":
for i in range(len(w_idx)):
for j in js:
if i + j < 0 or i + j >= len(w_idx):
continue
pairs.append((w_idx[i], w_idx[i + j]))
elif method.lower() == "cbow":
for i in range(skip_window, len(w_idx) - skip_window):
context = []
for j in js:
context.append(w_idx[i + j])
pairs.append(context + [w_idx[i]])
else:
raise ValueError
pairs = np.array(pairs)
print("5 example pairs:\n", pairs[:5])
if method.lower() == "skip_gram":
x, y = pairs[:, 0], pairs[:, 1]
elif method.lower() == "cbow":
x, y = pairs[:, :-1], pairs[:, -1]
else:
raise ValueError
return Dataset(x, y, v2i, i2v)
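# Usage sketch: build (center, context) skip-gram pairs from a toy corpus.
# ds = process_w2v_data(["i like cats", "i like dogs"], skip_window=2, method="skip_gram")
# bx, by = ds.sample(8)  # bx: center-word ids, by: context-word ids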
def set_soft_gpu(soft_gpu):
import tensorflow as tf
if soft_gpu:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
logical_gpus = tf.config.experimental.list_logical_devices('GPU')
print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
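# set_soft_gpu(True) makes TensorFlow allocate GPU memory on demand (memory
# growth) instead of reserving all of it at start-up.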
visual.py
'''
Description: visual.py, a dependency of the TF-IDF script
Author: 365JHWZGo
Date: 2021-11-16 16:58:26
LastEditors: 365JHWZGo
LastEditTime: 2021-11-16 17:15:05
'''
import matplotlib.pyplot as plt
import numpy as np
import pickle
from matplotlib.pyplot import cm
import os
import utils
def show_tfidf(tfidf, vocab, filename):
    # heatmap of the tf-idf matrix: rows are documents, columns are words
    plt.imshow(tfidf, cmap="YlGn", vmin=tfidf.min(), vmax=tfidf.max())
    plt.xticks(np.arange(tfidf.shape[1]), vocab, fontsize=6, rotation=90)
    plt.yticks(np.arange(tfidf.shape[0]), np.arange(1, tfidf.shape[0] + 1), fontsize=6)
    plt.tight_layout()
    os.makedirs("./nlp/images", exist_ok=True)  # make sure the output folder exists
    plt.savefig("./nlp/images/%s.png" % filename, format="png", dpi=500)
    plt.show()
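# show_tfidf() is the only function the TF-IDF script imports from this file;
# the functions below visualise later models in the series (word2vec,
# seq2seq attention, Transformer, BERT/GPT).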
def show_w2v_word_embedding(model, data: utils.Dataset, path):
word_emb = model.embeddings.get_weights()[0]
for i in range(data.num_word):
c = "blue"
try:
int(data.i2v[i])
except ValueError:
c = "red"
plt.text(word_emb[i, 0], word_emb[i, 1], s=data.i2v[i], color=c, weight="bold")
plt.xlim(word_emb[:, 0].min() - .5, word_emb[:, 0].max() + .5)
plt.ylim(word_emb[:, 1].min() - .5, word_emb[:, 1].max() + .5)
plt.xticks(())
plt.yticks(())
plt.xlabel("embedding dim1")
plt.ylabel("embedding dim2")
plt.savefig(path, dpi=300, format="png")
plt.show()
def seq2seq_attention():
with open("./visual/tmp/attention_align.pkl", "rb") as f:
data = pickle.load(f)
i2v, x, y, align = data["i2v"], data["x"], data["y"], data["align"]
plt.rcParams['xtick.bottom'] = plt.rcParams['xtick.labelbottom'] = False
plt.rcParams['xtick.top'] = plt.rcParams['xtick.labeltop'] = True
for i in range(6):
plt.subplot(2, 3, i + 1)
x_vocab = [i2v[j] for j in np.ravel(x[i])]
y_vocab = [i2v[j] for j in y[i, 1:]]
plt.imshow(align[i], cmap="YlGn", vmin=0., vmax=1.)
plt.yticks([j for j in range(len(y_vocab))], y_vocab)
plt.xticks([j for j in range(len(x_vocab))], x_vocab)
if i == 0 or i == 3:
plt.ylabel("Output")
if i >= 3:
plt.xlabel("Input")
plt.tight_layout()
plt.savefig("./visual/results/seq2seq_attention.png", format="png", dpi=200)
plt.show()
def all_mask_kinds():
seqs = ["I love you", "My name is M", "This is a very long seq", "Short one"]
vocabs = set((" ".join(seqs)).split(" "))
i2v = {i: v for i, v in enumerate(vocabs, start=1)}
i2v["<PAD>"] = 0
v2i = {v: i for i, v in i2v.items()}
id_seqs = [[v2i[v] for v in seq.split(" ")] for seq in seqs]
padded_id_seqs = np.array([l + [0] * (6 - len(l)) for l in id_seqs])
pmask = np.where(padded_id_seqs == 0, np.ones_like(padded_id_seqs), np.zeros_like(padded_id_seqs))
pmask = np.repeat(pmask[:, None, :], pmask.shape[-1], axis=1)
plt.rcParams['xtick.bottom'] = plt.rcParams['xtick.labelbottom'] = False
plt.rcParams['xtick.top'] = plt.rcParams['xtick.labeltop'] = True
for i in range(1, 5):
plt.subplot(2, 2, i)
plt.imshow(pmask[i-1], vmax=1, vmin=0, cmap="YlGn")
plt.xticks(range(6), seqs[i - 1].split(" "), rotation=45)
plt.yticks(range(6), seqs[i - 1].split(" "),)
plt.grid(which="minor", c="w", lw=0.5, linestyle="-")
plt.tight_layout()
plt.savefig("./visual/results/transformer_pad_mask.png", dpi=200)
plt.show()
max_len = pmask.shape[-1]
    omask = ~np.triu(np.ones((max_len, max_len), dtype=bool), 1)  # np.bool no longer exists in recent NumPy
omask = np.tile(np.expand_dims(omask, axis=0), [np.shape(seqs)[0], 1, 1])
omask = np.where(omask, pmask, 1)
plt.rcParams['xtick.bottom'] = plt.rcParams['xtick.labelbottom'] = False
plt.rcParams['xtick.top'] = plt.rcParams['xtick.labeltop'] = True
for i in range(1, 5):
plt.subplot(2, 2, i)
plt.imshow(omask[i - 1], vmax=1, vmin=0, cmap="YlGn")
plt.xticks(range(6), seqs[i - 1].split(" "), rotation=45)
plt.yticks(range(6), seqs[i - 1].split(" "), )
plt.grid(which="minor", c="w", lw=0.5, linestyle="-")
plt.tight_layout()
plt.savefig("./visual/results/transformer_look_ahead_mask.png", dpi=200)
plt.show()
def position_embedding():
max_len = 500
model_dim = 512
pos = np.arange(max_len)[:, None]
pe = pos / np.power(10000, 2. * np.arange(model_dim)[None, :] / model_dim)
pe[:, 0::2] = np.sin(pe[:, 0::2])
pe[:, 1::2] = np.cos(pe[:, 1::2])
plt.imshow(pe, vmax=1, vmin=-1, cmap="rainbow")
plt.ylabel("word position")
plt.xlabel("embedding dim")
plt.savefig("./visual/results/transformer_position_embedding.png", dpi=200)
plt.show()
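# Note on position_embedding(): column j of `pe` scales each position by
# 1 / 10000^(2j / model_dim); sin is then applied to even columns and cos to
# odd ones, a close variant of the sinusoidal encoding in "Attention Is All
# You Need" (which uses the pair index 2i rather than the raw column index).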
def transformer_attention_matrix(case=0):
with open("./visual/tmp/transformer_attention_matrix.pkl", "rb") as f:
data = pickle.load(f)
src = data["src"][case]
tgt = data["tgt"][case]
attentions = data["attentions"]
encoder_atten = attentions["encoder"]
decoder_tgt_atten = attentions["decoder"]["mh1"]
decoder_src_atten = attentions["decoder"]["mh2"]
plt.rcParams['xtick.bottom'] = plt.rcParams['xtick.labelbottom'] = False
plt.rcParams['xtick.top'] = plt.rcParams['xtick.labeltop'] = True
plt.figure(0, (7, 7))
plt.suptitle("Encoder self-attention")
for i in range(3):
for j in range(4):
plt.subplot(3, 4, i * 4 + j + 1)
plt.imshow(encoder_atten[i][case, j][:len(src), :len(src)], vmax=1, vmin=0, cmap="rainbow")
plt.xticks(range(len(src)), src)
plt.yticks(range(len(src)), src)
if j == 0:
plt.ylabel("layer %i" % (i+1))
if i == 2:
plt.xlabel("head %i" % (j+1))
plt.tight_layout()
plt.subplots_adjust(top=0.9)
plt.savefig("./visual/results/transformer%d_encoder_self_attention.png" % case, dpi=200)
plt.show()
plt.figure(1, (7, 7))
plt.suptitle("Decoder self-attention")
for i in range(3):
for j in range(4):
plt.subplot(3, 4, i * 4 + j + 1)
plt.imshow(decoder_tgt_atten[i][case, j][:len(tgt), :len(tgt)], vmax=1, vmin=0, cmap="rainbow")
plt.xticks(range(len(tgt)), tgt, rotation=90, fontsize=7)
plt.yticks(range(len(tgt)), tgt, fontsize=7)
if j == 0:
plt.ylabel("layer %i" % (i+1))
if i == 2:
plt.xlabel("head %i" % (j+1))
plt.tight_layout()
plt.subplots_adjust(top=0.9)
plt.savefig("./visual/results/transformer%d_decoder_self_attention.png" % case, dpi=200)
plt.show()
plt.figure(2, (7, 8))
plt.suptitle("Decoder-Encoder attention")
for i in range(3):
for j in range(4):
plt.subplot(3, 4, i*4+j+1)
plt.imshow(decoder_src_atten[i][case, j][:len(tgt), :len(src)], vmax=1, vmin=0, cmap="rainbow")
plt.xticks(range(len(src)), src, fontsize=7)
plt.yticks(range(len(tgt)), tgt, fontsize=7)
if j == 0:
plt.ylabel("layer %i" % (i+1))
if i == 2:
plt.xlabel("head %i" % (j+1))
plt.tight_layout()
plt.subplots_adjust(top=0.9)
plt.savefig("./visual/results/transformer%d_decoder_encoder_attention.png" % case, dpi=200)
plt.show()
def transformer_attention_line(case=0):
with open("./visual/tmp/transformer_attention_matrix.pkl", "rb") as f:
data = pickle.load(f)
src = data["src"][case]
tgt = data["tgt"][case]
attentions = data["attentions"]
decoder_src_atten = attentions["decoder"]["mh2"]
tgt_label = tgt[1:11][::-1]
src_label = ["" for _ in range(2)] + src[::-1]
fig, ax = plt.subplots(nrows=2, ncols=2, sharex=True, figsize=(7, 14))
for i in range(2):
for j in range(2):
ax[i, j].set_yticks(np.arange(len(src_label)))
ax[i, j].set_yticklabels(src_label, fontsize=9)
ax[i, j].set_ylim(0, len(src_label)-1)
ax_ = ax[i, j].twinx()
ax_.set_yticks(np.linspace(ax_.get_yticks()[0], ax_.get_yticks()[-1], len(ax[i, j].get_yticks())))
ax_.set_yticklabels(tgt_label, fontsize=9)
img = decoder_src_atten[-1][case, i + j][:10, :8]
color = cm.rainbow(np.linspace(0, 1, img.shape[0]))
left_top, right_top = img.shape[1], img.shape[0]
for ri, c in zip(range(right_top), color):
for li in range(left_top):
alpha = (img[ri, li] / img[ri].max()) ** 8
ax[i, j].plot([0, 1], [left_top - li + 1, right_top - 1 - ri], alpha=alpha, c=c)
ax[i, j].set_xticks(())
ax[i, j].set_xlabel("head %i" % (j + 1 + i * 2))
ax[i, j].set_xlim(0, 1)
plt.subplots_adjust(top=0.9)
plt.tight_layout()
plt.savefig("./visual/results/transformer%d_encoder_decoder_attention_line.png" % case, dpi=100)
def self_attention_matrix(bert_or_gpt="bert", case=0):
with open("./visual/tmp/"+bert_or_gpt+"_attention_matrix.pkl", "rb") as f:
data = pickle.load(f)
src = data["src"]
attentions = data["attentions"]
encoder_atten = attentions["encoder"]
plt.rcParams['xtick.bottom'] = plt.rcParams['xtick.labelbottom'] = False
plt.rcParams['xtick.top'] = plt.rcParams['xtick.labeltop'] = True
s_len = 0
for s in src[case]:
if s == "<SEP>":
break
s_len += 1
plt.figure(0, (7, 28))
for j in range(4):
plt.subplot(4, 1, j + 1)
img = encoder_atten[-1][case, j][:s_len-1, :s_len-1]
plt.imshow(img, vmax=img.max(), vmin=0, cmap="rainbow")
plt.xticks(range(s_len-1), src[case][:s_len-1], rotation=90, fontsize=9)
plt.yticks(range(s_len-1), src[case][1:s_len], fontsize=9)
plt.xlabel("head %i" % (j+1))
plt.subplots_adjust(top=0.9)
plt.tight_layout()
plt.savefig("./visual/results/"+bert_or_gpt+"%d_self_attention.png" % case, dpi=500)
def self_attention_line(bert_or_gpt="bert", case=0):
with open("./visual/tmp/"+bert_or_gpt+"_attention_matrix.pkl", "rb") as f:
data = pickle.load(f)
src = data["src"][case]
attentions = data["attentions"]
encoder_atten = attentions["encoder"]
s_len = 0
print(" ".join(src))
for s in src:
if s == "<SEP>":
break
s_len += 1
y_label = src[:s_len][::-1]
fig, ax = plt.subplots(nrows=2, ncols=2, sharex=True, figsize=(7, 14))
for i in range(2):
for j in range(2):
ax[i, j].set_yticks(np.arange(len(y_label)))
ax[i, j].tick_params(labelright=True)
ax[i, j].set_yticklabels(y_label, fontsize=9)
img = encoder_atten[-1][case, i+j][:s_len - 1, :s_len - 1]
color = cm.rainbow(np.linspace(0, 1, img.shape[0]))
for row, c in zip(range(img.shape[0]), color):
for col in range(img.shape[1]):
alpha = (img[row, col] / img[row].max()) ** 5
ax[i, j].plot([0, 1], [img.shape[1]-col, img.shape[0]-row-1], alpha=alpha, c=c)
ax[i, j].set_xticks(())
ax[i, j].set_xlabel("head %i" % (j+1+i*2))
ax[i, j].set_xlim(0, 1)
plt.subplots_adjust(top=0.9)
plt.tight_layout()
plt.savefig("./visual/results/"+bert_or_gpt+"%d_self_attention_line.png" % case, dpi=100)
if __name__ == "__main__":
os.makedirs("./visual/results", exist_ok=True)
transformer_attention_matrix(case=0)
transformer_attention_line(case=0)