1. RNN and mRNN
The original mRNN (Multiplicative Recurrent Neural Network) paper is reference [1]; its improvement, quoting the paper, is to:
introduce a new RNN variant that uses multiplicative (or “gated”) connections which allow the current input character to determine the transition matrix from one hidden state vector to the next
The structural difference between the two is as follows.
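As a rough sketch (notation mine; see [1] for the exact formulation), a standard RNN uses one shared hidden-to-hidden matrix, while the mRNN routes the transition through an input-gated intermediate state $m_t$:

$$h_t = \tanh(W_{hx} x_t + W_{hh} h_{t-1} + b) \qquad \text{(RNN)}$$

$$m_t = (W_{mx} x_t) \odot (W_{mh} h_{t-1}), \qquad h_t = \tanh(W_{hx} x_t + W_{hm} m_t + b) \qquad \text{(mRNN)}$$

Equivalently, the effective transition matrix factorizes as $W_{hh}^{(x_t)} = W_{hm}\,\mathrm{diag}(W_{mx} x_t)\,W_{mh}$: each input effectively selects its own transition matrix.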
2. LSTM and mLSTM
The original mLSTM paper is reference [3]. Its main improvement over the traditional LSTM, quoting the paper:
This mLSTM architecture was motivated by its ability to have both controlled and flexible input-dependent transitions, to allow for fast changes to the distributed hidden representation without erasing information.
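Concretely (again a sketch in the notation above, not the paper's exact symbols), mLSTM keeps the usual LSTM gates but replaces every occurrence of $h_{t-1}$ in the gate and candidate pre-activations with the multiplicative intermediate state

$$m_t = (W_{mx} x_t) \odot (W_{mh} h_{t-1}),$$

so the recurrent transition becomes input-dependent exactly as in the mRNN, while the cell state $c_t$ still preserves long-range information.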
Code implementations are given in references [4] and [5]. The core of the TensorFlow version is as follows:
```python
def call(self, inputs, state):
    c_prev, h_prev = state  # previous cell and hidden states
    # Multiplicative intermediate state: the key mLSTM change,
    # replacing h_prev in the usual LSTM pre-activations.
    m = tf.matmul(inputs, wmx) * tf.matmul(h_prev, wmh)
    # Gate pre-activations; wx/wh/b are created elsewhere in the cell
    # with output dimension 4 * units (one block per gate).
    z = tf.matmul(inputs, wx) + tf.matmul(m, wh) + b
    i, f, o, u = tf.split(z, 4, 1)
    i = tf.nn.sigmoid(i)  # input gate
    f = tf.nn.sigmoid(f)  # forget gate
    o = tf.nn.sigmoid(o)  # output gate
    u = tf.tanh(u)        # candidate update
    c = f * c_prev + i * u
    h = o * tf.tanh(c)
    return h, (c, h)
```
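To make the snippet runnable end to end, here is a minimal self-contained sketch that wraps it as a tf.keras RNN cell. The class name MinimalMLSTMCell, the build() scaffolding, and the usage lines are my additions rather than code from references [4]/[5]; only the call() body follows the snippet above.

```python
import tensorflow as tf

class MinimalMLSTMCell(tf.keras.layers.Layer):
    # Hypothetical wrapper; weight names mirror the snippet above.
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.state_size = (units, units)  # (c, h)

    def build(self, input_shape):
        dim = input_shape[-1]
        # Projections for the multiplicative intermediate state m.
        self.wmx = self.add_weight(name="wmx", shape=(dim, self.units))
        self.wmh = self.add_weight(name="wmh", shape=(self.units, self.units))
        # Standard LSTM weights; 4 * units covers the i, f, o, u blocks.
        self.wx = self.add_weight(name="wx", shape=(dim, 4 * self.units))
        self.wh = self.add_weight(name="wh", shape=(self.units, 4 * self.units))
        self.b = self.add_weight(name="b", shape=(4 * self.units,),
                                 initializer="zeros")

    def call(self, inputs, state):
        c_prev, h_prev = state
        m = tf.matmul(inputs, self.wmx) * tf.matmul(h_prev, self.wmh)
        z = tf.matmul(inputs, self.wx) + tf.matmul(m, self.wh) + self.b
        i, f, o, u = tf.split(z, 4, axis=1)
        c = tf.nn.sigmoid(f) * c_prev + tf.nn.sigmoid(i) * tf.tanh(u)
        h = tf.nn.sigmoid(o) * tf.tanh(c)
        return h, (c, h)

# Usage: the cell drops into tf.keras.layers.RNN like any built-in cell.
layer = tf.keras.layers.RNN(MinimalMLSTMCell(64))
out = layer(tf.random.normal((2, 10, 32)))  # batch=2, time=10, features=32
```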
The relationship between the two is essentially the same as that between RNN and mRNN. Compact yet powerful!
To be continued.
References
[1] Multiplicative RNN
[2] Written Memories: Understanding, Deriving and Extending the LSTM
[3] Multiplicative LSTM for Sequence Modelling
[4] mLSTM TensorFlow implementation
[5] mLSTM JAX implementation