1. 什么是word embedding
(1)从word到num
我们的自然语言,不管是中文还是英文都不能直接在机器中表达,此时就要将自然语言映射为数字。要映射成数字就要有字典,所以一般会先构建词典,举例如下:
word_dict = {"我":0, "你":1, "他":2, "她":3, "是":4, "好":5,
"坏":6, "人":7, "天":8, "第":9, "气":10, "今":11,
"怎":12, "么":13, "样":14, "啊":15}
我们假设词典的大小为15,即voc_size=15,我们的sentences为"今天天气怎么样"和"他人怎么样",这样的话通过查表就可以得到如下的表示:
word2index = [[11, 8, 8, 10, 12, 13, 14],
[2, 7, 12, 13, 14]]
(2)pad
可以看到因为各个句子的长度不一样,所以生成的矩阵不整齐,这样也不利于进行矩阵计算,所以需要进行pad,将各个向量转为长度相同的向量。最简单的方法就是填0:
padded = [[11, 8, 8, 10, 12, 13, 14, 0, 0, 0],
[2, 7, 12, 13, 14, 0, 0, 0, 0, 0]]
如上所示就是将input_length设为10,如果原本长度小于10的补0,大于10的截断。
上面的matrix看起来并不是one_hot的形式,但实际上上式跟下面的表示是等价的:
[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
[[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]
(3)embedding
由于one_hot编码的稀疏性,且这种编码无法描述两个元素之间的相关性,所以可以用embedding编码。比如在上面的one_hot编码中我们是用15维的特征来描述一个字的,上面的“我”和“你”两个向量点乘的话结果为0,完全是没有关系的,但如果用另一种方式编码的话就可以有关系了:
人称代词 名词 动词 形容词
我 0.9 0.5 0.2 0.2
你 0.8 0.6 0.3. 0.1
这样的话就可以将10维的特征描述转为4维的特征描述,且效果看起来会更好一些。
当然在进行embedding的时候不是人为设定的特征,而是人为设定好想要的特征维数之后通过语料训练得到的。
2. 实现
用的是tensorflow2.0
from tensorflow.keras.preprocessing.text import one_hot
sentences=['the glass of milk',
'the glass of juice',
'the cup of tea',
'I am a good boy']
voc_size=10000
onehot_repr=[one_hot(words, voc_size) for words in sentences]
print(onehot_repr)
from tensorflow.keras.layers import Embedding
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
import numpy as np
sent_length=8
embedded_docs=pad_sequences(onehot_repr, padding='pre', maxlen=sent_length)
print(embedded_docs)
dim = 10
model = Sequential()
model.add(Embedding(voc_size, dim, input_length=sent_length))
model.compile('adam', 'mse')
print(model.summary())
vector = model.predict(embedded_docs)
print(model.predict(embedded_docs))
print(vector.shape)
结果:
>>> print(onehot_repr)
[[8174, 3076, 416, 6851], [8174, 3076, 416, 6687], [8174, 9660, 416, 8721], [5223, 4222, 5952, 6180, 5440]]
>>> print(embedded_docs)
[[ 0 0 0 0 8174 3076 416 6851]
[ 0 0 0 0 8174 3076 416 6687]
[ 0 0 0 0 8174 9660 416 8721]
[ 0 0 0 5223 4222 5952 6180 5440]]
>>> print(model.summary())
_________________________________________________________________
Layer (type) Output Shape Param
=================================================================
embedding (Embedding) (None, 8, 10) 100000
=================================================================
Total params: 100,000
Trainable params: 100,000
Non-trainable params: 0
_________________________________________________________________
None
>>> print(model.predict(embedded_docs))
[[[ 0.04522029 -0.01226036 0.00501914 -0.04841211 -0.03839918
-0.01737207 0.01646307 -0.02959521 0.03347592 0.0029294 ]
[ 0.04522029 -0.01226036 0.00501914 -0.04841211 -0.03839918
-0.01737207 0.01646307 -0.02959521 0.03347592 0.0029294 ]
[ 0.04522029 -0.01226036 0.00501914 -0.04841211 -0.03839918
-0.01737207 0.01646307 -0.02959521 0.03347592 0.0029294 ]
[ 0.04522029 -0.01226036 0.00501914 -0.04841211 -0.03839918
-0.01737207 0.01646307 -0.02959521 0.03347592 0.0029294 ]
[-0.02706006 0.01063529 -0.01942388 0.02701591 -0.04124977
-0.00983888 0.01273515 0.03012211 0.04841721 -0.01894962]
[ 0.00510359 0.01853384 -0.02409974 -0.02285388 0.04018563
-0.04754727 0.02264073 -0.01251531 -0.04369598 0.03063634]
[-0.03827395 0.0083343 -0.03649645 0.00391301 -0.0283778
0.04224857 0.03885354 -0.01442292 0.01358733 -0.03044585]
[-0.02544751 -0.02753698 0.00250997 -0.01593918 0.04284723
0.03717153 0.01787357 0.01125566 -0.0267596 -0.0248112 ]]
[[ 0.04522029 -0.01226036 0.00501914 -0.04841211 -0.03839918
-0.01737207 0.01646307 -0.02959521 0.03347592 0.0029294 ]
[ 0.04522029 -0.01226036 0.00501914 -0.04841211 -0.03839918
-0.01737207 0.01646307 -0.02959521 0.03347592 0.0029294 ]
[ 0.04522029 -0.01226036 0.00501914 -0.04841211 -0.03839918
-0.01737207 0.01646307 -0.02959521 0.03347592 0.0029294 ]
[ 0.04522029 -0.01226036 0.00501914 -0.04841211 -0.03839918
-0.01737207 0.01646307 -0.02959521 0.03347592 0.0029294 ]
[-0.02706006 0.01063529 -0.01942388 0.02701591 -0.04124977
-0.00983888 0.01273515 0.03012211 0.04841721 -0.01894962]
[ 0.00510359 0.01853384 -0.02409974 -0.02285388 0.04018563
-0.04754727 0.02264073 -0.01251531 -0.04369598 0.03063634]
[-0.03827395 0.0083343 -0.03649645 0.00391301 -0.0283778
0.04224857 0.03885354 -0.01442292 0.01358733 -0.03044585]
[ 0.03599827 -0.00697263 0.01096133 0.01282989 0.04026625
-0.0409615 -0.03822895 0.03571489 -0.03869583 0.0247351 ]]
[[ 0.04522029 -0.01226036 0.00501914 -0.04841211 -0.03839918
-0.01737207 0.01646307 -0.02959521 0.03347592 0.0029294 ]
[ 0.04522029 -0.01226036 0.00501914 -0.04841211 -0.03839918
-0.01737207 0.01646307 -0.02959521 0.03347592 0.0029294 ]
[ 0.04522029 -0.01226036 0.00501914 -0.04841211 -0.03839918
-0.01737207 0.01646307 -0.02959521 0.03347592 0.0029294 ]
[ 0.04522029 -0.01226036 0.00501914 -0.04841211 -0.03839918
-0.01737207 0.01646307 -0.02959521 0.03347592 0.0029294 ]
[-0.02706006 0.01063529 -0.01942388 0.02701591 -0.04124977
-0.00983888 0.01273515 0.03012211 0.04841721 -0.01894962]
[-0.04196328 -0.0178645 0.01629119 0.00710867 0.03742753
0.04766042 -0.01144195 -0.00392986 0.04960826 0.01370332]
[-0.03827395 0.0083343 -0.03649645 0.00391301 -0.0283778
0.04224857 0.03885354 -0.01442292 0.01358733 -0.03044585]
[ 0.03060566 -0.01925355 -0.01740856 0.00497576 -0.04157882
0.01061495 0.04219753 -0.02456384 0.03463561 -0.01594185]]
[[ 0.04522029 -0.01226036 0.00501914 -0.04841211 -0.03839918
-0.01737207 0.01646307 -0.02959521 0.03347592 0.0029294 ]
[ 0.04522029 -0.01226036 0.00501914 -0.04841211 -0.03839918
-0.01737207 0.01646307 -0.02959521 0.03347592 0.0029294 ]
[ 0.04522029 -0.01226036 0.00501914 -0.04841211 -0.03839918
-0.01737207 0.01646307 -0.02959521 0.03347592 0.0029294 ]
[-0.01688796 0.02185148 0.01407048 0.01172693 -0.04144372
-0.02081727 0.02715001 0.01198126 0.00415362 0.02064079]
[ 0.01959873 0.03910967 -0.03127551 -0.04483137 -0.01185248
0.03648222 0.04708296 -0.00957827 -0.002679 0.03122015]
[ 0.00080504 0.00700544 0.02628921 0.0229356 0.04947283
0.01667294 0.03602554 -0.01248958 -0.00070317 -0.03361555]
[-0.0488179 -0.02457787 -0.03306667 -0.03750541 -0.03436396
-0.04636976 -0.03443474 0.00712519 -0.02974316 0.03063191]
[ 0.0434726 0.04021135 -0.03558815 0.04452255 0.04240603
-0.011404 0.00316377 -0.01917359 0.03822576 -0.01635139]]]
>>> print(vector.shape)
(4, 8, 10)
|