?#2021SC@SDUSC
这部分将分析rnn_cell.py和rnn_layers.py部分
RNN cells. py
RNNCell 表示 .NestedMap 中的循环状态。 zero_state(theta, batch_size) 返回初始状态,由每个子类定义。 从状态中,每个子类都定义了GetOutput() 来提取输出张量。 RNNCell.FProp 定义了前向函数:: (theta, state0, 输入) -> state1, extras 所有参数和返回值都是.NestedMap 。 每个子类都定义了这些 .NestedMap 应该具有的字段。 extras 是一个 .NestedMap ,其中包含一些 FProp 计算以促进反向传播的中间结果。 zero_state(theta, batch_size) 、state0 和 state1 都是兼容的 .NestedMap (参见 .NestedMap.IsCompatible )。 即,它们递归地具有相同的键。 此外,这些 .NestedMap 中相应的张量具有相同的形状和数据类型。
@classmethod
def Params(cls):
p = super().Params()
p.Define('inputs_arity', 1,
'number of tensors expected for the inputs.act to FProp.')
p.Define('num_input_nodes', 0, 'Number of input nodes.')
p.Define(
'num_output_nodes', 0,
'Number of output nodes. If num_hidden_nodes is 0, also used as '
'cell size.')
p.Define(
'reset_cell_state', False,
('Set True to support resetting cell state in scenarios where multiple '
'inputs are packed into a single training example. The RNN layer '
'should provide reset_mask inputs in addition to act and padding if '
'this flag is set.'))
p.Define(
'zero_state_init_params', py_utils.DefaultRNNCellStateInit(),
'Parameters that define how the initial state values are set '
'for each cell. Must be one of the static functions defined in '
'py_utils.RNNCellStateInit.')
return p
def FProp(self, theta, state0, inputs): 这里的默认实现假设 cell forward 函数由两个函数组成: _Gates(_Mix(theta,state0,inputs),theta,state0,inputs) _Mix 的结果存放在 extras 中以方便反向传播。 如果 reset_cell_state 为 True,则可选地应用 _ResetState 。除了其他输入之外,RNN 层还应提供“reset_mask”输入。 reset_mask 输入在运行 _Mix() 和 _Gates() 之前应该被重置为默认值(零)的时间步长为 0,否则为 1。这是为了支持诸如打包输入之类的用例,其中在单个输入示例序列中输入多个样本,并且需要相互屏蔽。例如,如果打包在一起的两个例子是 [‘good’, ‘day’] -> [‘guten-tag’] 和 [‘thanks’] -> [‘danke’] 产生 [‘good’, 'day ', ‘thanks’] -> [‘guten-tag’, ‘danke’],源 reset_mask 将为 [1, 1, 0],目标重置掩码将为 [1, 0]。这些 id 旨在为彼此不同的示例启用屏蔽计算。 参数: theta:一个.NestedMap 对象,包含该层及其子层的权重值。 state0:之前的循环状态。一个.NestedMap 。 输入:单元格的输入。一个.NestedMap 。 返回: 元组 (state1, extras)。 - state1:??下一个循环状态。一个.NestedMap 。 - 附加:中间结果以促进反向传播。一个.NestedMap 。
assert isinstance(inputs.act, list)
assert self.params.inputs_arity == len(inputs.act)
if self.params.reset_cell_state:
state0_modified = self._ResetState(state0.DeepCopy(), inputs)
else:
state0_modified = state0
xmw = self._Mix(theta, state0_modified, inputs)
state1 = self._Gates(xmw, theta, state0_modified, inputs)
return state1, py_utils.NestedMap()
def _GetBias(self, theta): 获取要添加的偏置向量。 包括forget_gate_bias 之类的调整。 直接使用 this 而不是 ‘b’ 变量,因为以这种方式包含调整允许 const-prop 在推理时消除调整。 参数: theta:一个.NestedMap 对象,包含该层及其子层的权重值。 返回: 偏置向量。
p = self.params
if p.enable_lstm_bias:
b = theta.b
else:
b = tf.zeros([self.num_gates * self.hidden_size], dtype=p.dtype)
if p.forget_gate_bias != 0.0:
# Apply the forget gate bias directly to the bias vector.
if not p.couple_input_forget_gates:
# Normal 4 gate bias (i_i, i_g, f_g, o_g).
adjustment = (
tf.ones([4, self.hidden_size], dtype=p.dtype) * tf.expand_dims(
tf.constant([0., 0., p.forget_gate_bias, 0.], dtype=p.dtype),
axis=1))
else:
# 3 gates with coupled input/forget (i_i, f_g, o_g).
adjustment = (
tf.ones([3, self.hidden_size], dtype=p.dtype) * tf.expand_dims(
tf.constant([0., p.forget_gate_bias, 0.], dtype=p.dtype),
axis=1))
adjustment = tf.reshape(adjustment, [self.num_gates * self.hidden_size])
b = b + adjustment
return b
rnn_layers.py
函数GeneratePackedInputResetMask从 segment_id 生成 RNN 单元的掩码输入。 参数: segment_id:形状为 [time, batch_size, 1] 的张量。 is_reverse:如果输入以相反的顺序馈送到 RNN,则为真。 返回: reset_mask - 形状为 [time, batch_size, 1] 的张量。 对于样本设置为 0 需要重置状态的地方(在示例边界处),否则为 1。
segment_id_left = segment_id[:-1]
segment_id_right = segment_id[1:]
# Mask is a [t-1, bs, 1] tensor.
reset_mask = tf.cast(
tf.equal(segment_id_left, segment_id_right), dtype=segment_id.dtype)
mask_padding_shape = tf.concat(
[tf.ones([1], dtype=tf.int32),
tf.shape(segment_id)[1:]], axis=0)
mask_padding = tf.ones(mask_padding_shape, dtype=segment_id.dtype)
if is_reverse:
reset_mask = tf.concat([reset_mask, mask_padding], axis=0)
else:
reset_mask = tf.concat([mask_padding, reset_mask], axis=0)
return reset_mask
Class RNN: 静态展开的RNN 形参:
def Params(cls):
p = super().Params()
p.Define('cell', rnn_cell.LSTMCellSimple.Params(),
'Configs for the RNN cell.')
p.Define(
'sequence_length', 0,
'Sequence length to unroll. If > 0, then will unroll to this fixed '
'size. If 0, then will unroll to accommodate the size of the inputs '
'for each call to FProp.')
p.Define('reverse', False,
'Whether or not to unroll the sequence in reversed order.')
p.Define('packed_input', False, 'To reset states for packed inputs.')
return p
初始化:
def __init__(self, params):
super().__init__(params)
p = self.params
assert not p.packed_input, ('Packed inputs are currently not supported by '
'Static RNN')
p.cell.reset_cell_state = p.packed_input
assert p.sequence_length >= 0
self.CreateChild('cell', p.cell)
函数FProp计算 RNN 前向传递。 参数: theta:一个.NestedMap 对象,包含该层及其子层的权重值。 输入:单个张量或基数等于的张量元组 rnn_cell.inputs_arity。 对于每个输入张量,假设第一维是时间、第二维批次和第三维深度。 填充:张量。 第一个暗淡是时间,第二个暗淡是批次,第三个暗淡预计为 1。 state0:如果不是 None,则为 .NestedMap 中的初始 rnn 状态。 默认为单元格的零状态。 返回: [时间、batch、dim]的张量。 最终的循环状态。
p = self.params
assert isinstance(self.cell, rnn_cell.RNNCell)
if p.sequence_length == 0:
if isinstance(inputs, (tuple, list)):
sequence_length = len(inputs)
else:
sequence_length = py_utils.GetShape(inputs)[0]
else:
sequence_length = p.sequence_length
assert sequence_length >= 1, ('Sequence length must be defined or inputs '
'must have fixed shapes.')
with tf.name_scope(p.name):
inputs_sequence = tf.unstack(inputs, num=sequence_length)
paddings_sequence = tf.unstack(paddings, num=sequence_length)
# We start from all 0 states.
if state0:
state = state0
else:
inputs0 = py_utils.NestedMap(
act=[inputs_sequence[0]], padding=paddings_sequence[0])
state = self.cell.zero_state(theta.cell, self.cell.batch_size(inputs0))
outputs = [None] * sequence_length
if p.reverse:
sequence = range(sequence_length - 1, -1, -1)
else:
sequence = range(0, sequence_length, 1)
for idx in sequence:
cur_input = py_utils.NestedMap(act=[inputs[idx]], padding=paddings[idx])
state, _ = self.cell.FProp(theta.cell, state, cur_input)
outputs[idx] = self.cell.GetOutput(state)
return tf.stack(outputs), state
|