IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Lingvo分析(十) -> 正文阅读

[人工智能]Lingvo分析(十)

?在这里插入图片描述#2021SC@SDUSC

这部分将分析rnn_cell.py和rnn_layers.py部分

RNN cells. py

RNNCell 表示 .NestedMap 中的循环状态。 zero_state(theta, batch_size) 返回初始状态,由每个子类定义。 从状态中,每个子类都定义了GetOutput() 来提取输出张量。 RNNCell.FProp 定义了前向函数:: (theta, state0, 输入) -> state1, extras 所有参数和返回值都是.NestedMap。 每个子类都定义了这些 .NestedMap 应该具有的字段。 extras 是一个 .NestedMap,其中包含一些 FProp 计算以促进反向传播的中间结果。 zero_state(theta, batch_size)state0state1 都是兼容的 .NestedMap(参见 .NestedMap.IsCompatible)。 即,它们递归地具有相同的键。 此外,这些 .NestedMap 中相应的张量具有相同的形状和数据类型。

@classmethod
  def Params(cls):
    p = super().Params()
    p.Define('inputs_arity', 1,
             'number of tensors expected for the inputs.act to FProp.')
    p.Define('num_input_nodes', 0, 'Number of input nodes.')
    p.Define(
        'num_output_nodes', 0,
        'Number of output nodes. If num_hidden_nodes is 0, also used as '
        'cell size.')
    p.Define(
        'reset_cell_state', False,
        ('Set True to support resetting cell state in scenarios where multiple '
         'inputs are packed into a single training example. The RNN layer '
         'should provide reset_mask inputs in addition to act and padding if '
         'this flag is set.'))
    p.Define(
        'zero_state_init_params', py_utils.DefaultRNNCellStateInit(),
        'Parameters that define how the initial state values are set '
        'for each cell. Must be one of the static functions defined in '
        'py_utils.RNNCellStateInit.')
    return p

def FProp(self, theta, state0, inputs):
这里的默认实现假设 cell forward 函数由两个函数组成: _Gates(_Mix(theta,state0,inputs),theta,state0,inputs) _Mix 的结果存放在 extras 中以方便反向传播。 如果 reset_cell_state 为 True,则可选地应用 _ResetState。除了其他输入之外,RNN 层还应提供“reset_mask”输入。 reset_mask 输入在运行 _Mix()_Gates() 之前应该被重置为默认值(零)的时间步长为 0,否则为 1。这是为了支持诸如打包输入之类的用例,其中在单个输入示例序列中输入多个样本,并且需要相互屏蔽。例如,如果打包在一起的两个例子是 [‘good’, ‘day’] -> [‘guten-tag’] 和 [‘thanks’] -> [‘danke’] 产生 [‘good’, 'day ', ‘thanks’] -> [‘guten-tag’, ‘danke’],源 reset_mask 将为 [1, 1, 0],目标重置掩码将为 [1, 0]。这些 id 旨在为彼此不同的示例启用屏蔽计算。 参数: theta:一个.NestedMap 对象,包含该层及其子层的权重值。 state0:之前的循环状态。一个.NestedMap。 输入:单元格的输入。一个.NestedMap。 返回: 元组 (state1, extras)。 - state1:??下一个循环状态。一个.NestedMap。 - 附加:中间结果以促进反向传播。一个.NestedMap

assert isinstance(inputs.act, list)
    assert self.params.inputs_arity == len(inputs.act)
    if self.params.reset_cell_state:
      state0_modified = self._ResetState(state0.DeepCopy(), inputs)
    else:
      state0_modified = state0
    xmw = self._Mix(theta, state0_modified, inputs)
    state1 = self._Gates(xmw, theta, state0_modified, inputs)
    return state1, py_utils.NestedMap()

def _GetBias(self, theta): 获取要添加的偏置向量。
包括forget_gate_bias 之类的调整。 直接使用 this 而不是 ‘b’ 变量,因为以这种方式包含调整允许 const-prop 在推理时消除调整。 参数: theta:一个.NestedMap 对象,包含该层及其子层的权重值。 返回: 偏置向量。

p = self.params
    if p.enable_lstm_bias:
      b = theta.b
    else:
      b = tf.zeros([self.num_gates * self.hidden_size], dtype=p.dtype)
    if p.forget_gate_bias != 0.0:
      # Apply the forget gate bias directly to the bias vector.
      if not p.couple_input_forget_gates:
        # Normal 4 gate bias (i_i, i_g, f_g, o_g).
        adjustment = (
            tf.ones([4, self.hidden_size], dtype=p.dtype) * tf.expand_dims(
                tf.constant([0., 0., p.forget_gate_bias, 0.], dtype=p.dtype),
                axis=1))
      else:
        # 3 gates with coupled input/forget (i_i, f_g, o_g).
        adjustment = (
            tf.ones([3, self.hidden_size], dtype=p.dtype) * tf.expand_dims(
                tf.constant([0., p.forget_gate_bias, 0.], dtype=p.dtype),
                axis=1))
      adjustment = tf.reshape(adjustment, [self.num_gates * self.hidden_size])
      b = b + adjustment

    return b

rnn_layers.py

函数GeneratePackedInputResetMask从 segment_id 生成 RNN 单元的掩码输入。
参数:
segment_id:形状为 [time, batch_size, 1] 的张量。
is_reverse:如果输入以相反的顺序馈送到 RNN,则为真。
返回:
reset_mask - 形状为 [time, batch_size, 1] 的张量。 对于样本设置为 0
需要重置状态的地方(在示例边界处),否则为 1。

segment_id_left = segment_id[:-1]
  segment_id_right = segment_id[1:]

  # Mask is a [t-1, bs, 1] tensor.
  reset_mask = tf.cast(
      tf.equal(segment_id_left, segment_id_right), dtype=segment_id.dtype)
  mask_padding_shape = tf.concat(
      [tf.ones([1], dtype=tf.int32),
       tf.shape(segment_id)[1:]], axis=0)
  mask_padding = tf.ones(mask_padding_shape, dtype=segment_id.dtype)
  if is_reverse:
    reset_mask = tf.concat([reset_mask, mask_padding], axis=0)
  else:
    reset_mask = tf.concat([mask_padding, reset_mask], axis=0)
  return reset_mask

Class RNN:
静态展开的RNN
形参:

def Params(cls):
    p = super().Params()
    p.Define('cell', rnn_cell.LSTMCellSimple.Params(),
             'Configs for the RNN cell.')
    p.Define(
        'sequence_length', 0,
        'Sequence length to unroll. If > 0, then will unroll to this fixed '
        'size. If 0, then will unroll to accommodate the size of the inputs '
        'for each call to FProp.')
    p.Define('reverse', False,
             'Whether or not to unroll the sequence in reversed order.')
    p.Define('packed_input', False, 'To reset states for packed inputs.')
    return p

初始化:

def __init__(self, params):
    super().__init__(params)
    p = self.params
    assert not p.packed_input, ('Packed inputs are currently not supported by '
                                'Static RNN')
    p.cell.reset_cell_state = p.packed_input
    assert p.sequence_length >= 0
    self.CreateChild('cell', p.cell)

函数FProp计算 RNN 前向传递。
参数: theta:一个.NestedMap 对象,包含该层及其子层的权重值。 输入:单个张量或基数等于的张量元组 rnn_cell.inputs_arity。 对于每个输入张量,假设第一维是时间、第二维批次和第三维深度。 填充:张量。 第一个暗淡是时间,第二个暗淡是批次,第三个暗淡预计为 1。 state0:如果不是 None,则为 .NestedMap 中的初始 rnn 状态。 默认为单元格的零状态。 返回: [时间、batch、dim]的张量。 最终的循环状态。

 p = self.params
    assert isinstance(self.cell, rnn_cell.RNNCell)
    if p.sequence_length == 0:
      if isinstance(inputs, (tuple, list)):
        sequence_length = len(inputs)
      else:
        sequence_length = py_utils.GetShape(inputs)[0]
    else:
      sequence_length = p.sequence_length
    assert sequence_length >= 1, ('Sequence length must be defined or inputs '
                                  'must have fixed shapes.')
    with tf.name_scope(p.name):
      inputs_sequence = tf.unstack(inputs, num=sequence_length)
      paddings_sequence = tf.unstack(paddings, num=sequence_length)
      # We start from all 0 states.
      if state0:
        state = state0
      else:
        inputs0 = py_utils.NestedMap(
            act=[inputs_sequence[0]], padding=paddings_sequence[0])
        state = self.cell.zero_state(theta.cell, self.cell.batch_size(inputs0))
      outputs = [None] * sequence_length
      if p.reverse:
        sequence = range(sequence_length - 1, -1, -1)
      else:
        sequence = range(0, sequence_length, 1)
      for idx in sequence:
        cur_input = py_utils.NestedMap(act=[inputs[idx]], padding=paddings[idx])
        state, _ = self.cell.FProp(theta.cell, state, cur_input)
        outputs[idx] = self.cell.GetOutput(state)
      return tf.stack(outputs), state
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-12-28 22:55:34  更:2021-12-28 22:55:54 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 20:52:52-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码