Background
The top-ranked Keras implementation currently found on CSDN is: https://blog.csdn.net/xiaosongshine/article/details/86595847
That article has quite a few problems, though. Its implementation is incomplete (the LayerNorm + residual part is missing), and its multi-head attention also lacks the final fully connected projection $W_o$. On top of that, the hyperparameters it uses differ considerably from the original paper, so the code does not really match the encoder block described there; reading it side by side with the paper can be misleading. I therefore gathered the missing pieces from various sources and assembled a complete Keras transformer-encoder implementation that basically lines up with the paper. Along the way, this post also revisits how the transformer works, combining theory and code (with comments written as clearly as possible).
Keras version
To stay compatible with the code found on CSDN, this post uses Keras 2.2.4 (standalone Keras, not tf.keras). If you need a newer version or tf.keras, some changes will be required; CyberZHG's code on GitHub is a good reference for how to adapt it.
Main references
Theory references:
- https://zhuanlan.zhihu.com/p/44121378
- https://zhuanlan.zhihu.com/p/44731789
- https://blog.csdn.net/u012526436/article/details/86295971
- The original paper: https://arxiv.org/pdf/1706.03762.pdf
Code references:
- https://github.com/CyberZHG/keras-transformer
- https://blog.csdn.net/xiaosongshine/article/details/86595847
- https://blog.csdn.net/qq_40742298/article/details/115011147
Overall model structure
Since the goal is text classification, we only care about the encoder (the left half of the architecture figure in the paper). The encoder starts with the input + embedding stage, followed by a stack of N identical blocks; in the paper N = 6. Each block consists of multi-head attention, add & norm, and a feed-forward sublayer, each wrapped in a residual connection. We will take these apart step by step.
Input layer
The original input is the word embedding plus a position embedding, just like ordinary text input; assume the input shape is (batch_size, seq_len, embedding_size). Note that this embedding_size has to stay the same throughout the whole network so that the residual connections can be applied later. In the paper, the embedding size and the output dimension of every sublayer are all 512, denoted $d_{model}=512$.
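As a quick shape illustration (the vocabulary size here is a made-up number; only $d_{model}=512$ follows the paper), a minimal sketch:

from keras.layers import Input, Embedding
from keras.models import Model

vocab_size, seq_len, d_model = 10000, 512, 512          # vocab_size is arbitrary; d_model = 512 as in the paper
words = Input(shape=(seq_len,), dtype='int32')           # (batch_size, seq_len) of token ids
embeddings = Embedding(vocab_size, d_model)(words)       # (batch_size, seq_len, d_model)
print(Model(words, embeddings).output_shape)             # (None, 512, 512)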
Position embedding layer
Unlike an RNN, the transformer has no built-in notion of token order, so positional information has to be injected: each position is first mapped to a position embedding, which is then summed with the word embedding and fed into the blocks. Note that the original paper, "Attention Is All You Need", tried both the sin/cos encoding and trainable position embeddings, found no difference between them in experiments, and went with sin/cos; BERT, in contrast, uses trainable position embeddings. Without re-deriving it, the encoding is roughly $PE_{(pos,2i)}=\sin(pos/10000^{2i/d_{model}})$ and $PE_{(pos,2i+1)}=\cos(pos/10000^{2i/d_{model}})$. The code and comments are below:
from __future__ import print_function
import keras                              # used later by MultiHeadAttention / LayerNorm
from keras import backend as K
from keras.engine.topology import Layer
class Position_Embedding(Layer):
    def __init__(self, size=None, mode='sum', **kwargs):
        self.size = size    # dimension of the position encoding (must be even)
        self.mode = mode    # 'sum': add to the word embedding; 'concat': concatenate to it
        super(Position_Embedding, self).__init__(**kwargs)

    def call(self, x):
        if (self.size is None) or (self.mode == 'sum'):
            # in 'sum' mode the position encoding must match the word-embedding dimension
            self.size = int(x.shape[-1])
        batch_size, seq_len = K.shape(x)[0], K.shape(x)[1]
        # 1 / 10000^(2j / size), shape (1, size/2)
        position_j = 1. / K.pow(10000., 2 * K.arange(self.size / 2, dtype='float32') / self.size)
        position_j = K.expand_dims(position_j, 0)
        # position indices 0, 1, ..., seq_len-1, expanded to shape (batch_size, seq_len, 1)
        position_i = K.cumsum(K.ones_like(x[:, :, 0]), 1) - 1
        position_i = K.expand_dims(position_i, 2)
        # outer product -> (batch_size, seq_len, size/2)
        position_ij = K.dot(position_i, position_j)
        # interleave sin (even indices) and cos (odd indices)
        position_ij_2i = K.expand_dims(K.sin(position_ij), -1)
        position_ij_2i_1 = K.expand_dims(K.cos(position_ij), -1)
        position_ij = K.concatenate([position_ij_2i, position_ij_2i_1])
        position_ij = K.reshape(position_ij, (batch_size, seq_len, self.size))
        if self.mode == 'sum':
            return position_ij + x
        elif self.mode == 'concat':
            return K.concatenate([position_ij, x], 2)

    def compute_output_shape(self, input_shape):
        if self.mode == 'sum':
            return input_shape
        elif self.mode == 'concat':
            return (input_shape[0], input_shape[1], input_shape[2] + self.size)
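A short usage sketch for the layer above (toy vocabulary and sequence sizes); in 'sum' mode the output keeps the input shape:

from keras.layers import Input, Embedding
from keras.models import Model

inp = Input(shape=(20,), dtype='int32')
emb = Embedding(1000, 64)(inp)                  # (None, 20, 64)
out = Position_Embedding(mode='sum')(emb)       # same shape: (None, 20, 64)
print(Model(inp, out).output_shape)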
Implementing the pieces of a single block
Multi-head attention
First we need a single attention. If you would rather not build attention head by head, the attention layer in https://blog.csdn.net/xiaosongshine/article/details/86595847 computes several heads at once, but it still needs an extra $W_o$ projection to match the paper exactly. To stay faithful to the paper and keep the decomposition clear, we implement a single attention first.
Scaled dot-product attention
The scaled dot-product attention diagram boils down to the formula $\mathrm{Attention}(Q,K,V)=\mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V$:
- Define three weight matrices $W_q$, $W_k$, $W_v$
- Multiply the input by each of them to obtain Q, K and V
- Take the dot product of Q and K to get scores, and turn them into weights with a softmax
- Multiply the weights by V to get the final weighted value matrix (the H matrix)
- In particular, the scores are divided by $\sqrt{d_k}$ before the softmax; see https://blog.csdn.net/qq_37430422/article/details/105042303 for why. A small NumPy sanity check of the whole formula follows this list.
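Before the Keras layer, a tiny NumPy check of the formula (toy shapes, random values) makes the steps concrete:

# Scaled dot-product attention in plain NumPy (toy example)
import numpy as np

np.random.seed(0)
seq_len, d_k, d_v = 4, 8, 8
Q = np.random.randn(seq_len, d_k)
K_ = np.random.randn(seq_len, d_k)
V = np.random.randn(seq_len, d_v)

scores = Q @ K_.T / np.sqrt(d_k)                      # (4, 4) scaled scores
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)        # softmax over the key axis
H = weights @ V                                       # (4, 8) weighted values
print(weights.sum(axis=-1))                           # each row sums to 1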
The Keras layer implementation:
class ScaledDotProductAttention(Layer):
    r"""The attention layer that takes three inputs representing queries, keys and values.

    \text{Attention}(Q, K, V) = \text{softmax}(\frac{Q K^T}{\sqrt{d_k}}) V

    See: https://arxiv.org/pdf/1706.03762.pdf
    """

    def __init__(self,
                 return_attention=False,
                 history_only=False,
                 **kwargs):
        """Initialize the layer.
        :param return_attention: Whether to return attention weights.
        :param history_only: Whether to only use history data.
        :param kwargs: Arguments for parent class.
        """
        super(ScaledDotProductAttention, self).__init__(**kwargs)
        self.supports_masking = True
        self.return_attention = return_attention
        self.history_only = history_only
        self.intensity = self.attention = None

    def get_config(self):
        config = {
            'return_attention': self.return_attention,
            'history_only': self.history_only,
        }
        base_config = super(ScaledDotProductAttention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        if isinstance(input_shape, list):
            query_shape, key_shape, value_shape = input_shape
        else:
            query_shape = key_shape = value_shape = input_shape
        output_shape = query_shape[:-1] + value_shape[-1:]
        if self.return_attention:
            attention_shape = query_shape[:2] + (key_shape[1],)
            return [output_shape, attention_shape]
        return output_shape

    def compute_mask(self, inputs, mask=None):
        if isinstance(mask, list):
            mask = mask[0]
        if self.return_attention:
            return [mask, None]
        return mask

    def call(self, inputs, mask=None, **kwargs):
        if isinstance(inputs, list):
            query, key, value = inputs
        else:
            query = key = value = inputs
        if isinstance(mask, list):
            mask = mask[1]
        feature_dim = K.shape(query)[-1]
        e = K.batch_dot(query, key, axes=2) / K.sqrt(K.cast(feature_dim, dtype=K.floatx()))
        if self.history_only:
            query_len, key_len = K.shape(query)[1], K.shape(key)[1]
            indices = K.expand_dims(K.arange(0, key_len), axis=0)
            upper = K.expand_dims(K.arange(0, query_len), axis=-1)
            e -= 10000.0 * K.expand_dims(K.cast(indices > upper, K.floatx()), axis=0)
        if mask is not None:
            e -= 10000.0 * (1.0 - K.cast(K.expand_dims(mask, axis=-2), K.floatx()))
        self.intensity = e
        e = K.exp(e - K.max(e, axis=-1, keepdims=True))
        self.attention = e / K.sum(e, axis=-1, keepdims=True)
        v = K.batch_dot(self.attention, value)
        if self.return_attention:
            return [v, self.attention]
        return v
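A quick self-attention usage sketch for this layer (toy shapes; the same tensor is used as Q, K and V):

from keras.layers import Input
from keras.models import Model

x = Input(shape=(10, 64))
att = ScaledDotProductAttention()(x)     # single input -> query = key = value
print(Model(x, att).output_shape)        # (None, 10, 64)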
Multi-head attention
With the single attention in place, this part is relatively simple: project Q, K and V once, split each into num_head chunks, run every chunk through the scaled dot-product attention implemented above, concatenate the results, and apply one final projection. Taking Q as an example: (1) suppose Q (bs=1, seq_len=10, dim=512) has already gone through a projection layer, giving Q_; (2) likewise obtain K_ and compute the dot-product attention matrix between Q_ and K_; (3) likewise obtain V_ and compute the weighted sum to get the outputs; (4) reshape back to (bs, seq_len, dim); (5) finally, apply the output projection $W_o$. Code:
class MultiHeadAttention(Layer):
    """Multi-head attention layer.
    See: https://arxiv.org/pdf/1706.03762.pdf
    """

    def __init__(self,
                 head_num,
                 activation='relu',
                 use_bias=True,
                 kernel_initializer='glorot_normal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 history_only=False,
                 **kwargs):
        """Initialize the layer.
        :param head_num: Number of heads.
        :param activation: Activations for linear mappings.
        :param use_bias: Whether to use bias term.
        :param kernel_initializer: Initializer for linear mappings.
        :param bias_initializer: Initializer for linear mappings.
        :param kernel_regularizer: Regularizer for linear mappings.
        :param bias_regularizer: Regularizer for linear mappings.
        :param kernel_constraint: Constraints for linear mappings.
        :param bias_constraint: Constraints for linear mappings.
        :param history_only: Whether to only use history in attention layer.
        """
        self.supports_masking = True
        self.head_num = head_num
        self.activation = keras.activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)
        self.kernel_constraint = keras.constraints.get(kernel_constraint)
        self.bias_constraint = keras.constraints.get(bias_constraint)
        self.history_only = history_only

        self.Wq = self.Wk = self.Wv = self.Wo = None
        self.bq = self.bk = self.bv = self.bo = None
        self.intensity = self.attention = None
        super(MultiHeadAttention, self).__init__(**kwargs)

    def get_config(self):
        config = {
            'head_num': self.head_num,
            'activation': keras.activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': keras.initializers.serialize(self.kernel_initializer),
            'bias_initializer': keras.initializers.serialize(self.bias_initializer),
            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': keras.regularizers.serialize(self.bias_regularizer),
            'kernel_constraint': keras.constraints.serialize(self.kernel_constraint),
            'bias_constraint': keras.constraints.serialize(self.bias_constraint),
            'history_only': self.history_only,
        }
        base_config = super(MultiHeadAttention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        if isinstance(input_shape, list):
            q, k, v = input_shape
            return q[:-1] + (v[-1],)
        return input_shape

    def compute_mask(self, inputs, input_mask=None):
        if isinstance(input_mask, list):
            return input_mask[0]
        return input_mask

    def build(self, input_shape):
        if isinstance(input_shape, list):
            q, k, v = input_shape
        else:
            q = k = v = input_shape
        feature_dim = int(v[-1])
        if feature_dim % self.head_num != 0:
            raise IndexError('Invalid head number %d with the given input dim %d' % (self.head_num, feature_dim))
        self.Wq = self.add_weight(
            shape=(int(q[-1]), feature_dim),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            name='%s_Wq' % self.name,
        )
        if self.use_bias:
            self.bq = self.add_weight(
                shape=(feature_dim,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                name='%s_bq' % self.name,
            )
        self.Wk = self.add_weight(
            shape=(int(k[-1]), feature_dim),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            name='%s_Wk' % self.name,
        )
        if self.use_bias:
            self.bk = self.add_weight(
                shape=(feature_dim,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                name='%s_bk' % self.name,
            )
        self.Wv = self.add_weight(
            shape=(int(v[-1]), feature_dim),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            name='%s_Wv' % self.name,
        )
        if self.use_bias:
            self.bv = self.add_weight(
                shape=(feature_dim,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                name='%s_bv' % self.name,
            )
        self.Wo = self.add_weight(
            shape=(feature_dim, feature_dim),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            name='%s_Wo' % self.name,
        )
        if self.use_bias:
            self.bo = self.add_weight(
                shape=(feature_dim,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                name='%s_bo' % self.name,
            )
        super(MultiHeadAttention, self).build(input_shape)

    @staticmethod
    def _reshape_to_batches(x, head_num):
        input_shape = K.shape(x)
        batch_size, seq_len, feature_dim = input_shape[0], input_shape[1], input_shape[2]
        head_dim = feature_dim // head_num
        x = K.reshape(x, (batch_size, seq_len, head_num, head_dim))
        x = K.permute_dimensions(x, [0, 2, 1, 3])
        return K.reshape(x, (batch_size * head_num, seq_len, head_dim))

    @staticmethod
    def _reshape_attention_from_batches(x, head_num):
        input_shape = K.shape(x)
        batch_size, seq_len, feature_dim = input_shape[0], input_shape[1], input_shape[2]
        x = K.reshape(x, (batch_size // head_num, head_num, seq_len, feature_dim))
        return K.permute_dimensions(x, [0, 2, 1, 3])

    @staticmethod
    def _reshape_from_batches(x, head_num):
        input_shape = K.shape(x)
        batch_size, seq_len, feature_dim = input_shape[0], input_shape[1], input_shape[2]
        x = K.reshape(x, (batch_size // head_num, head_num, seq_len, feature_dim))
        x = K.permute_dimensions(x, [0, 2, 1, 3])
        return K.reshape(x, (batch_size // head_num, seq_len, feature_dim * head_num))

    @staticmethod
    def _reshape_mask(mask, head_num):
        if mask is None:
            return mask
        seq_len = K.shape(mask)[1]
        mask = K.expand_dims(mask, axis=1)
        mask = K.tile(mask, [1, head_num, 1])
        return K.reshape(mask, (-1, seq_len))
    def call(self, inputs, mask=None):
        if isinstance(inputs, list):
            q, k, v = inputs
        else:
            q = k = v = inputs
        if isinstance(mask, list):
            q_mask, k_mask, v_mask = mask
        else:
            q_mask = k_mask = v_mask = mask
        # linear projections of Q, K and V
        q = K.dot(q, self.Wq)
        k = K.dot(k, self.Wk)
        v = K.dot(v, self.Wv)
        if self.use_bias:
            q += self.bq
            k += self.bk
            v += self.bv
        if self.activation is not None:
            q = self.activation(q)
            k = self.activation(k)
            v = self.activation(v)
        # split into heads, run scaled dot-product attention, then merge back
        scaled_dot_product_attention = ScaledDotProductAttention(
            history_only=self.history_only,
            name='%s-Attention' % self.name,
        )
        y = scaled_dot_product_attention(
            inputs=[
                self._reshape_to_batches(q, self.head_num),
                self._reshape_to_batches(k, self.head_num),
                self._reshape_to_batches(v, self.head_num),
            ],
            mask=[
                self._reshape_mask(q_mask, self.head_num),
                self._reshape_mask(k_mask, self.head_num),
                self._reshape_mask(v_mask, self.head_num),
            ],
        )
        y = self._reshape_from_batches(y, self.head_num)
        # final output projection W_o
        y = K.dot(y, self.Wo)
        if self.use_bias:
            y += self.bo
        if self.activation is not None:
            y = self.activation(y)
        # add static shape information back to the output tensor
        input_shape = [K.int_shape(q), K.int_shape(k), K.int_shape(v)]
        output_shape = self.compute_output_shape(input_shape)
        if output_shape[1] is not None:
            output_shape = (-1,) + output_shape[1:]
            y = K.reshape(y, output_shape)
        return y
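A usage sketch (toy shapes; the feature dimension must be divisible by head_num):

from keras.layers import Input
from keras.models import Model

x = Input(shape=(10, 512))
y = MultiHeadAttention(head_num=8)(x)    # 8 heads of size 512 / 8 = 64 each
print(Model(x, y).output_shape)          # (None, 10, 512)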
LayerNorm
Code:
class LayerNorm(Layer):
    def __init__(self,
                 center=True,
                 scale=False,
                 epsilon=None,
                 gamma_initializer='ones',
                 beta_initializer='zeros',
                 gamma_regularizer=None,
                 beta_regularizer=None,
                 gamma_constraint=None,
                 beta_constraint=None,
                 **kwargs):
        super(LayerNorm, self).__init__(**kwargs)
        self.supports_masking = True
        self.center = center          # whether to add a learnable offset beta
        self.scale = scale            # whether to multiply by a learnable gain gamma
        if epsilon is None:
            epsilon = K.epsilon() * K.epsilon()
        self.epsilon = epsilon
        self.gamma_initializer = keras.initializers.get(gamma_initializer)
        self.beta_initializer = keras.initializers.get(beta_initializer)
        self.gamma_regularizer = keras.regularizers.get(gamma_regularizer)
        self.beta_regularizer = keras.regularizers.get(beta_regularizer)
        self.gamma_constraint = keras.constraints.get(gamma_constraint)
        self.beta_constraint = keras.constraints.get(beta_constraint)
        self.gamma, self.beta = None, None

    def build(self, input_shape):
        # gamma / beta are per-feature parameters over the last dimension
        shape = input_shape[-1:]
        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint,
                                         name='gamma')
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint,
                                        name='beta')
        super(LayerNorm, self).build(input_shape)

    def call(self, inputs, **kwargs):
        # normalize over the last (feature) dimension
        mean = K.mean(inputs, axis=-1, keepdims=True)
        variance = K.mean(K.square(inputs - mean), axis=-1, keepdims=True)
        std = K.sqrt(variance + self.epsilon)
        outputs = (inputs - mean) / std
        if self.scale:
            outputs *= self.gamma
        if self.center:
            outputs += self.beta
        return outputs
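A quick numeric check that each position ends up with roughly zero mean and unit variance over the feature dimension:

import numpy as np
from keras.layers import Input
from keras.models import Model

x = Input(shape=(5, 16))
m = Model(x, LayerNorm()(x))
out = m.predict(np.random.randn(2, 5, 16))
print(out.mean(axis=-1).round(6), out.std(axis=-1).round(3))   # ~0 and ~1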
Adding the residual connections (Add) and the feed-forward network (FFN) gives a complete transformer block:
def transformer_block(x, prefix):
    # MODEL_DIM (= d_model = 512) is defined in the full model section below
    # Sublayer 1: multi-head self-attention + residual + LayerNorm
    O_seq = MultiHeadAttention(head_num=8, name=f'{prefix}_att1')(x)
    O_seq_Add1 = Add(name=f'{prefix}_add1')([x, O_seq])
    O_seq_LN1 = LayerNorm(name=f'{prefix}_LN1')(O_seq_Add1)
    # Sublayer 2: position-wise feed-forward + residual + LayerNorm
    O_seq_fc1 = Dense(MODEL_DIM * 4, activation='relu', name=f'{prefix}_fc1')(O_seq_LN1)
    O_seq_fc2 = Dense(MODEL_DIM, name=f'{prefix}_fc2')(O_seq_fc1)
    O_seq_Add2 = Add(name=f'{prefix}_add2')([O_seq_LN1, O_seq_fc2])
    O_seq_LN2 = LayerNorm(name=f'{prefix}_LN2')(O_seq_Add2)
    return O_seq_LN2
Full model definition
import numpy as np
import keras
from keras import backend as K
from keras.models import Model
from keras.layers import Input, Embedding, Dropout, Dense, Add, GlobalAveragePooling1D
from tqdm import tqdm

MAX_LEN = 512
MODEL_DIM = 512

def load_word_embedding(filepath):
    """Load pretrained word vectors into a {word: vector} dict."""
    embeddings_index = {}
    f = open(filepath, encoding='utf8')
    for line in tqdm(f):
        values = line.split()
        word = ''.join(values[:-MODEL_DIM])
        coefs = np.asarray(values[-MODEL_DIM:], dtype='float32')
        embeddings_index[word] = coefs
    f.close()
    return embeddings_index

def build_matrix(word_index, path):
    """Build the embedding matrix aligned with the tokenizer's word_index."""
    embedding_index = load_word_embedding(path)
    embedding_matrix = np.zeros((len(word_index) + 1, MODEL_DIM))
    for word, i in word_index.items():
        if word in embedding_index:
            embedding_matrix[i] = embedding_index[word]
    return embedding_matrix

def transformer_block(x, prefix):
    # Sublayer 1: multi-head self-attention + dropout + residual + LayerNorm
    O_seq = MultiHeadAttention(head_num=8, name=f'{prefix}_att1')(x)
    O_seq = Dropout(0.1, name=f'{prefix}_do1')(O_seq)
    O_seq_Add1 = Add(name=f'{prefix}_add1')([x, O_seq])
    O_seq_LN1 = LayerNorm(name=f'{prefix}_LN1')(O_seq_Add1)
    # Sublayer 2: position-wise feed-forward + dropout + residual + LayerNorm
    O_seq_fc1 = Dense(MODEL_DIM * 4, activation='relu', name=f'{prefix}_fc1')(O_seq_LN1)
    O_seq_fc2 = Dense(MODEL_DIM, name=f'{prefix}_fc2')(O_seq_fc1)
    O_seq_fc2 = Dropout(0.1, name=f'{prefix}_do2')(O_seq_fc2)
    O_seq_Add2 = Add(name=f'{prefix}_add2')([O_seq_LN1, O_seq_fc2])
    O_seq_LN2 = LayerNorm(name=f'{prefix}_LN2')(O_seq_Add2)
    return O_seq_LN2

def build_model(embedding_matrix, num_class=2):
    words = Input(shape=(MAX_LEN,), name='inputs', dtype='int32')
    embeddings = Embedding(*embedding_matrix.shape, weights=[embedding_matrix], trainable=True)(words)
    embeddings = Position_Embedding()(embeddings)   # add sinusoidal position information
    embeddings = Dropout(0.1)(embeddings)
    seq_len = K.shape(words)[1]   # (unused)
    # stack of 6 encoder blocks, as in the paper
    O_seq1 = transformer_block(embeddings, prefix='t1')
    O_seq2 = transformer_block(O_seq1, prefix='t2')
    O_seq3 = transformer_block(O_seq2, prefix='t3')
    O_seq4 = transformer_block(O_seq3, prefix='t4')
    O_seq5 = transformer_block(O_seq4, prefix='t5')
    O_seq6 = transformer_block(O_seq5, prefix='t6')
    # sum the outputs of the last three blocks, then pool over the sequence
    O_seq = Add()([O_seq4, O_seq5, O_seq6])
    O_seq = GlobalAveragePooling1D()(O_seq)
    O_seq = Dropout(0.1)(O_seq)
    result = Dense(num_class, activation='softmax', name='outputs')(O_seq)
    model = Model(inputs=words, outputs=result)
    opt = keras.optimizers.Adam(lr=5e-5)
    model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['acc'])
    model.summary()
    return model
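A usage sketch: normally you would pass build_matrix(word_index, path) the word_index of a fitted tokenizer and the path to a pretrained-vector file; here a random matrix stands in just to show the call (vocab_size is a made-up number):

# Build the model with a random embedding matrix (replace with build_matrix(word_index, path))
vocab_size = 5000
embedding_matrix = np.random.normal(size=(vocab_size + 1, MODEL_DIM)).astype('float32')
model = build_model(embedding_matrix, num_class=2)

# Training would then look like (X: int ids padded to MAX_LEN, y: one-hot labels):
# model.fit(X, y, batch_size=16, epochs=5, validation_split=0.1)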
Aside
If you train with nothing but the code above, you may find that the model struggles to converge, because there is no learning-rate warm-up, and warm-up really matters here. If the model does not converge, you can try moving LayerNorm in front of the attention and FFN sublayers (the pre-LN variant, sketched below), lowering the learning rate (5e-5 or less), and adding a warm-up schedule.
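A minimal sketch of the pre-LN variant mentioned above, reusing the MultiHeadAttention and LayerNorm layers defined earlier (whether it actually helps should be verified on your own data):

def pre_ln_transformer_block(x, prefix):
    # LayerNorm is applied before each sublayer; the residual adds the raw input
    x_ln1 = LayerNorm(name=f'{prefix}_preLN1')(x)
    att = MultiHeadAttention(head_num=8, name=f'{prefix}_att1')(x_ln1)
    att = Dropout(0.1, name=f'{prefix}_do1')(att)
    x = Add(name=f'{prefix}_add1')([x, att])

    x_ln2 = LayerNorm(name=f'{prefix}_preLN2')(x)
    ffn = Dense(MODEL_DIM * 4, activation='relu', name=f'{prefix}_fc1')(x_ln2)
    ffn = Dense(MODEL_DIM, name=f'{prefix}_fc2')(ffn)
    ffn = Dropout(0.1, name=f'{prefix}_do2')(ffn)
    return Add(name=f'{prefix}_add2')([x, ffn])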
For background on warm-up see https://zhuanlan.zhihu.com/p/84614490. Below is a Keras warm-up implementation, taken from https://gitee.com/yangyin2020/keras_classfication/blob/master/warmup_cosine_decay_scheduler.py, which you can adapt as needed:
import numpy as np
import keras
from keras import backend as K

def cosine_decay_with_warmup(global_step,
                             learning_rate_base,
                             total_steps,
                             warmup_learning_rate=0.0,
                             warmup_steps=0,
                             hold_base_rate_steps=0):
    """Cosine decay schedule with warm up period.

    Cosine annealing learning rate as described in:
    Loshchilov and Hutter, SGDR: Stochastic Gradient Descent with Warm Restarts.
    ICLR 2017. https://arxiv.org/abs/1608.03983
    In this schedule, the learning rate grows linearly from warmup_learning_rate
    to learning_rate_base for warmup_steps, then transitions to a cosine decay
    schedule.

    Arguments:
        global_step {int} -- global step.
        learning_rate_base {float} -- base learning rate.
        total_steps {int} -- total number of training steps.
    Keyword Arguments:
        warmup_learning_rate {float} -- initial learning rate for warm up. (default: {0.0})
        warmup_steps {int} -- number of warmup steps. (default: {0})
        hold_base_rate_steps {int} -- Optional number of steps to hold base learning rate
                                      before decaying. (default: {0})
    Returns:
        a float representing learning rate.
    Raises:
        ValueError: if warmup_learning_rate is larger than learning_rate_base,
                    or if warmup_steps is larger than total_steps.
    """
    if total_steps < warmup_steps:
        raise ValueError('total_steps must be larger or equal to '
                         'warmup_steps.')
    learning_rate = 0.5 * learning_rate_base * (1 + np.cos(
        np.pi *
        (global_step - warmup_steps - hold_base_rate_steps
         ) / float(total_steps - warmup_steps - hold_base_rate_steps)))
    if hold_base_rate_steps > 0:
        learning_rate = np.where(global_step > warmup_steps + hold_base_rate_steps,
                                 learning_rate, learning_rate_base)
    if warmup_steps > 0:
        if learning_rate_base < warmup_learning_rate:
            raise ValueError('learning_rate_base must be larger or equal to '
                             'warmup_learning_rate.')
        slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
        warmup_rate = slope * global_step + warmup_learning_rate
        learning_rate = np.where(global_step < warmup_steps, warmup_rate,
                                 learning_rate)
    return np.where(global_step > total_steps, 0.0, learning_rate)

class WarmUpCosineDecayScheduler(keras.callbacks.Callback):
    """Cosine decay with warmup learning rate scheduler."""

    def __init__(self,
                 learning_rate_base,
                 total_steps,
                 global_step_init=0,
                 warmup_learning_rate=0.0,
                 warmup_steps=0,
                 hold_base_rate_steps=0,
                 verbose=0):
        """Constructor for cosine decay with warmup learning rate scheduler.

        Arguments:
            learning_rate_base {float} -- base learning rate.
            total_steps {int} -- total number of training steps.
        Keyword Arguments:
            global_step_init {int} -- initial global step, e.g. from previous checkpoint.
            warmup_learning_rate {float} -- initial learning rate for warm up. (default: {0.0})
            warmup_steps {int} -- number of warmup steps. (default: {0})
            hold_base_rate_steps {int} -- Optional number of steps to hold base learning rate
                                          before decaying. (default: {0})
            verbose {int} -- 0: quiet, 1: update messages. (default: {0})
        """
        super(WarmUpCosineDecayScheduler, self).__init__()
        self.learning_rate_base = learning_rate_base
        self.total_steps = total_steps
        self.global_step = global_step_init
        self.warmup_learning_rate = warmup_learning_rate
        self.warmup_steps = warmup_steps
        self.hold_base_rate_steps = hold_base_rate_steps
        self.verbose = verbose
        self.learning_rates = []

    def on_batch_end(self, batch, logs=None):
        self.global_step = self.global_step + 1
        lr = K.get_value(self.model.optimizer.lr)
        self.learning_rates.append(lr)

    def on_batch_begin(self, batch, logs=None):
        lr = cosine_decay_with_warmup(global_step=self.global_step,
                                      learning_rate_base=self.learning_rate_base,
                                      total_steps=self.total_steps,
                                      warmup_learning_rate=self.warmup_learning_rate,
                                      warmup_steps=self.warmup_steps,
                                      hold_base_rate_steps=self.hold_base_rate_steps)
        K.set_value(self.model.optimizer.lr, lr)
        if self.verbose > 0:
            print('\nBatch %05d: setting learning '
                  'rate to %s.' % (self.global_step + 1, lr))

if __name__ == '__main__':
    from keras.models import Sequential
    from keras.layers import Dense

    # Toy model and random data, just to visualize the schedule
    model = Sequential()
    model.add(Dense(32, activation='relu', input_dim=100))
    model.add(Dense(10, activation='softmax'))
    model.compile(optimizer='rmsprop',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    sample_count = 12608
    epochs = 50
    warmup_epoch = 10
    batch_size = 16
    learning_rate_base = 0.0001
    total_steps = int(epochs * sample_count / batch_size)
    warmup_steps = int(warmup_epoch * sample_count / batch_size)

    data = np.random.random((sample_count, 100))
    labels = np.random.randint(10, size=(sample_count, 1))
    one_hot_labels = keras.utils.to_categorical(labels, num_classes=10)

    warm_up_lr = WarmUpCosineDecayScheduler(learning_rate_base=learning_rate_base,
                                            total_steps=total_steps,
                                            warmup_learning_rate=4e-06,
                                            warmup_steps=warmup_steps,
                                            hold_base_rate_steps=5)
    model.fit(data, one_hot_labels, epochs=epochs, batch_size=batch_size,
              verbose=0, callbacks=[warm_up_lr])

    import matplotlib.pyplot as plt
    plt.plot(warm_up_lr.learning_rates)
    plt.xlabel('Step', fontsize=20)
    plt.ylabel('lr', fontsize=20)
    plt.axis([0, total_steps, 0, learning_rate_base * 1.1])
    plt.xticks(np.arange(0, epochs, 1))
    plt.grid()
    plt.title('Cosine decay with warmup', fontsize=20)
    plt.show()
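To plug this into the classification model above, compute total_steps and warmup_steps from your own dataset size and pass the callback to model.fit; the numbers below are placeholders:

# Hooking the warm-up scheduler into the transformer classifier (placeholder numbers)
sample_count, batch_size, epochs, warmup_epochs = 20000, 16, 10, 1
total_steps = int(epochs * sample_count / batch_size)
warmup_steps = int(warmup_epochs * sample_count / batch_size)

warm_up_lr = WarmUpCosineDecayScheduler(learning_rate_base=5e-5,
                                        total_steps=total_steps,
                                        warmup_learning_rate=1e-6,
                                        warmup_steps=warmup_steps)
# model.fit(X, y, batch_size=batch_size, epochs=epochs, callbacks=[warm_up_lr])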