
[Artificial Intelligence] [Transformer from Scratch] Transformer Input/Output Details and Code Implementation (PyTorch)


Walking through the Transformer's input/output details with an example

How data flows from the input, through the encoder, to the decoder output (using machine translation as the example):

encoder

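For machine translation, one sample consists of a source sentence and its translation. For example, if the source sentence is “我爱机器学习”, its translation is 'i love machine learning', so the sample is made up of “我爱机器学习” and “i love machine learning”.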

The source sentence of this sample has length=4 words: ‘我’ ‘爱’ ‘机器’ ‘学习’. If each word is embedded as a 512-dimensional vector, the embedded sentence “我爱机器学习” has shape [4, 512] (for batched input, the shape after embedding is [batch, 4, 512]).
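A minimal PyTorch sketch of this shape bookkeeping, assuming a hypothetical 10-word vocabulary in which 我/爱/机器/学习 have ids 1-4:

```python
import torch
import torch.nn as nn

d_model = 512          # embedding size used in the text
vocab_size = 10        # hypothetical toy vocabulary size

embedding = nn.Embedding(vocab_size, d_model)

# "我 / 爱 / 机器 / 学习" mapped to hypothetical token ids 1-4
sentence = torch.tensor([1, 2, 3, 4])          # [length=4]
single = embedding(sentence)
print(single.shape)                            # torch.Size([4, 512])

batch = torch.tensor([[1, 2, 3, 4],
                      [4, 3, 2, 1]])           # [batch=2, length=4]
batched = embedding(batch)
print(batched.shape)                           # torch.Size([2, 4, 512])
```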

padding

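Suppose the maximum sentence length in the data is 10. Sentences shorter than 10 are padded to length 10, so the shape becomes [10, 512]. The embeddings at the padded positions are meant to carry no information and are often zero vectors; note, however, that in PyTorch this only holds if nn.Embedding is created with padding_idx set to the pad id, which the code later in this post does not do (it relies on the padding mask instead).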
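A sketch of padding to length 10, assuming pad id 0 (the same convention as 'P': 0 in the code below) and using `nn.Embedding`'s `padding_idx` argument so that the padded rows really are zero vectors. The token ids and vocabulary size are hypothetical:

```python
import torch
import torch.nn as nn

PAD = 0        # pad id, matching the 'P': 0 convention in the code later in this post
max_len = 10
d_model = 512

ids = [1, 2, 3, 4]                                           # hypothetical ids for 我/爱/机器/学习
padded = torch.tensor(ids + [PAD] * (max_len - len(ids)))    # [10]

# padding_idx makes the PAD row of the embedding table all zeros (and it receives no gradient)
embedding = nn.Embedding(10, d_model, padding_idx=PAD)
x = embedding(padded)                                        # [10, 512]

print(x.shape)                                  # torch.Size([10, 512])
print(bool((x[len(ids):] == 0).all()))          # True: the padded rows are zero vectors
```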

Padding Mask

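Input sequences are usually padded: a common length N is chosen and shorter sequences are filled with 0s up to length N. Attention should not focus on these padded positions, so they need special handling: the attention scores at these positions are pushed to a very large negative number (effectively negative infinity), so that after softmax their weights are close to 0. The padding mask is simply a Boolean tensor marking which positions need this treatment. Conventions differ on which value marks them; in the code later in this post, True marks the positions to be masked.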
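A small sketch of the masking trick on one row of attention scores. The score values are made up; the mask uses True for padded positions and -1e9 in place of negative infinity, matching the code later in this post:

```python
import torch

# Raw scores of one query against 5 keys (hypothetical values); the last 2 keys are padding
scores = torch.tensor([[2.0, 1.0, 0.5, 0.3, 0.1]])
pad_mask = torch.tensor([[False, False, False, True, True]])   # True = padded position

masked_scores = scores.masked_fill(pad_mask, -1e9)   # -1e9 stands in for negative infinity
weights = torch.softmax(masked_scores, dim=-1)

print(weights[0, 3:].tolist())   # [0.0, 0.0]: exp(-1e9) underflows to exactly 0
print(float(weights.sum()))      # the remaining weights still sum to 1 (up to rounding)
```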

Positional Embedding

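Once we have the padded sentence embeddings, feeding them straight into the encoder would ignore the order of the words, because attention by itself has no notion of position. So a positional vector is added to each embedding. The positional vectors are constructed in a specific way so that they can represent each word's position or the distance between words; the core idea is to give attention usable distance information.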

attention

See my blog post: (2021 Hung-yi Lee) Machine Learning - Self-attention

FeedForward

Omitted; it is straightforward.

add/Norm

The hidden output after add/norm still has shape [10, 512].

Encoder input and output

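Going through a single encoder once more, starting from the input:

  1. Input x
  2. Layer-normalize x: x1 = norm(x)
  3. Multi-head self-attention: x2 = self_attention(x1)
  4. Residual connection: x3 = x + x2
  5. Layer-normalize again: x4 = norm(x3)
  6. Feed-forward network: x5 = feed_forward(x4)
  7. Residual connection: x6 = x3 + x5
  8. Output x6

This is what the Encoder does. Note that this ordering applies LayerNorm before each sub-layer (Pre-LN). The original Transformer paper, and the PyTorch code below, apply it after the residual addition instead (Post-LN): x3 = norm(x + self_attention(x)), x6 = norm(x3 + feed_forward(x3)).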
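A sketch of steps 1-8 as a PyTorch module. This is an illustrative assumption, not the implementation used later in this post: it uses `torch.nn.MultiheadAttention` (with `batch_first=True`, available in recent PyTorch versions) instead of hand-written attention, and the class name `PreNormEncoderBlock` is hypothetical.

```python
import torch
import torch.nn as nn

class PreNormEncoderBlock(nn.Module):
    """Hypothetical encoder block that follows steps 1-8 above literally (Pre-LN)."""
    def __init__(self, d_model=512, n_heads=8, d_ff=2048):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

    def forward(self, x, key_padding_mask=None):
        x1 = self.norm1(x)                                   # step 2: x1 = norm(x)
        x2, _ = self.self_attn(x1, x1, x1,
                               key_padding_mask=key_padding_mask)  # step 3: self-attention
        x3 = x + x2                                          # step 4: residual
        x4 = self.norm2(x3)                                  # step 5: x4 = norm(x3)
        x5 = self.feed_forward(x4)                           # step 6: feed-forward
        x6 = x3 + x5                                         # step 7: residual
        return x6                                            # step 8: output

torch.manual_seed(0)
block = PreNormEncoderBlock()
x = torch.randn(2, 10, 512)                      # [batch, seq_len=10, d_model=512]
pad = torch.zeros(2, 10, dtype=torch.bool)
pad[:, 4:] = True                                # positions 4-9 are padding (True = ignore this key)
out = block(x, key_padding_mask=pad)
print(out.shape)                                 # torch.Size([2, 10, 512])
```

Because padded keys are masked, and LayerNorm and the feed-forward network act on each position separately, changing the padded positions of `x` should leave the outputs at the real positions 0-3 unchanged.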

decoder

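Note that the encoder output is not used directly as the decoder's input; it supplies the K and V of the decoder's encoder-decoder (cross) attention.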

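During training:

  1. At the decoder's first time step (the first input it receives), the input is a special token. It may be a token marking the start of the target sequence, a token marking the end of the source sequence, or some other task-dependent input; implementations differ slightly here. The goal is to predict the first word (token) of the translation.
  2. That input, together with the first word of the target sentence, is fed into the decoder again to predict the second word. (During training the ground-truth word is fed rather than the model's own prediction, as in the example below; this is known as teacher forcing.)
  3. And so on for the remaining words.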

A concrete example:

Sample: “我/爱/机器/学习” and “i/ love /machine/ learning”
Training:

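  1. Embed “我/爱/机器/学习” and feed it into the encoder. The last encoder layer produces outputs of shape [10, 512] (assuming an embedding size of 512 and batch size = 1). Multiplied by new parameter matrices, these outputs serve as the K and V used in every decoder layer.

  2. Feed the start token as the decoder's initial input, and compute the cross-entropy error between the decoder's output distribution (whose most probable word is A1) and ‘i’.

  3. Feed the start token, “i” as the decoder input, and compute the cross-entropy error between the output distribution (most probable word A2) and ‘love’.

  4. Feed the start token, “i”, “love” as the decoder input, and compute the cross-entropy error between the output distribution (most probable word A3) and ‘machine’.

  5. Feed the start token, “i”, “love”, “machine” as the decoder input, and compute the cross-entropy error between the output distribution (most probable word A4) and ‘learning’.

  6. Feed the start token, “i”, “love”, “machine”, “learning” as the decoder input, and compute the cross-entropy error between the output distribution (most probable word A5) and the end token.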
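Steps 2-6 amount to shifting the target sentence by one position: the decoder input is the start token followed by the target words, and the label at each position is the next word, ending with the end token. A minimal sketch in plain Python, using hypothetical marker strings `<start>` and `<end>` (the code later in this post uses 'S' and 'E' for the same roles):

```python
# Hypothetical marker strings; the code later in this post uses 'S' and 'E' for these roles
START, END = "<start>", "<end>"

target = ["i", "love", "machine", "learning"]

dec_input = [START] + target      # what the decoder is fed
dec_output = target + [END]       # the word each position must predict

# Training steps 2-6: the prefix seen so far -> the word it is trained to predict
for step in range(1, len(dec_input) + 1):
    print(dec_input[:step], "->", dec_output[step - 1])
```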

Sequence Mask

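The training procedure above runs word by word, serially. Can it be done in parallel? It can. For the single sentence above, the decoder inputs at each step were:

start token

start token, “i”

start token, “i”, “love”

start token, “i”, “love”, “machine”

start token, “i”, “love”, “machine”, “learning”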

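So why not stack these inputs into a matrix and feed them in at once? As a matrix they look like this:

[start token

start token, “i”

start token, “i”, “love”

start token, “i”, “love”, “machine”

start token, “i”, “love”, “machine”, “learning”]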

How do we obtain this matrix?

Pad the decoder inputs of steps 2-6 above out to the complete sentence:

[start token, “i”, “love”, “machine”, “learning”
start token, “i”, “love”, “machine”, “learning”
start token, “i”, “love”, “machine”, “learning”
start token, “i”, “love”, “machine”, “learning”
start token, “i”, “love”, “machine”, “learning”]

Then multiply this matrix element-wise by a mask matrix:

[1 0 0 0 0
 1 1 0 0 0
 1 1 1 0 0
 1 1 1 1 0
 1 1 1 1 1]

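Doesn't this give exactly the matrix we wanted?

[start token

start token, “i”

start token, “i”, “love”

start token, “i”, “love”, “machine”

start token, “i”, “love”, “machine”, “learning”]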

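Feed this matrix into the decoder. (You can think of it as a kind of batch: each row is one sample, only the rows have different lengths. Each row yields one output probability distribution, so feeding the whole matrix gives all 5 output distributions at once.)
This is what lets training run in parallel. In practice, as in the code below, the mask is not multiplied into the tokens; it is applied to the attention scores, where masked positions are set to a very large negative number before softmax, so each position can only attend to itself and earlier positions.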
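A sketch checking this claim numerically, under simplifying assumptions: a single head with Q = K = V = x and no learned projections, on random stand-in features. It builds the 0/1 mask matrix above with `torch.tril` and shows that one masked pass over all 5 rows gives the same outputs as feeding the 5 prefixes one at a time:

```python
import torch

torch.manual_seed(0)
tgt_len, d = 5, 8
# Random stand-in features for the 5 decoder inputs: start token, i, love, machine, learning
x = torch.randn(tgt_len, d)

def causal_self_attention(x):
    """Single-head self-attention with Q = K = V = x (no learned projections), for illustration only."""
    n = x.size(0)
    scores = x @ x.T / d ** 0.5
    future = torch.triu(torch.ones(n, n), diagonal=1).bool()   # True strictly above the diagonal
    scores = scores.masked_fill(future, -1e9)
    return torch.softmax(scores, dim=-1) @ x

# The 0/1 mask matrix from the text: 1 = visible, 0 = future
keep = torch.tril(torch.ones(tgt_len, tgt_len)).long()
print(keep)

# Parallel: one pass over the full matrix of decoder inputs
parallel = causal_self_attention(x)
# Serial: feed each prefix separately and keep only its last position
serial = torch.stack([causal_self_attention(x[:t])[-1] for t in range(1, tgt_len + 1)])

print(torch.allclose(parallel, serial, atol=1e-5))   # True
```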

Testing

Once the model is trained, at test time take, for example, '机器学习很有趣' as a test sample and obtain its English translation.

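The sentence is passed through the encoder, and the resulting output tensor is fed to the decoder (not as the decoder's direct input, but as K and V for its cross-attention):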

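  1. Use the start symbol as the decoder input; the output is machine

  2. Use start symbol + machine as the input; the output is learning

  3. Use start symbol + machine + learning as the input; the output is is

  4. Use start symbol + machine + learning + is as the input; the output is interesting

  5. Use start symbol + machine + learning + is + interesting as the input; the output is the end symbol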

This gives the complete translation ‘machine learning is interesting’.

So at test time the output can only be produced one word at a time, serially.
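The serial loop above can be sketched as greedy decoding. Here a hypothetical lookup table stands in for "run the decoder on the prefix and take the most probable word"; in the real model each call would be one decoder forward pass over the encoder output:

```python
from typing import Callable, List

def greedy_decode(predict_next: Callable[[List[str]], str],
                  start: str, end: str, max_len: int = 50) -> List[str]:
    """Serial decoding: feed back everything generated so far and append the most probable next word."""
    generated = [start]
    for _ in range(max_len):
        word = predict_next(generated)   # in a real model: one decoder pass over the whole prefix
        if word == end:
            break
        generated.append(word)
    return generated[1:]                 # drop the start marker

# Hypothetical stand-in for "encoder output of '机器学习很有趣' + decoder + argmax"
lookup = {
    ("<start>",): "machine",
    ("<start>", "machine"): "learning",
    ("<start>", "machine", "learning"): "is",
    ("<start>", "machine", "learning", "is"): "interesting",
    ("<start>", "machine", "learning", "is", "interesting"): "<end>",
}
print(greedy_decode(lambda prefix: lookup[tuple(prefix)], "<start>", "<end>"))
# ['machine', 'learning', 'is', 'interesting']
```

The `max_len` cap guards against a model that never emits the end symbol.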

Transformer PyTorch code implementation

Data preprocessing

import math
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

# S: Symbol that shows starting of decoding input
# E: Symbol that shows ending of decoding output
# P: Symbol that will fill in blank sequence if current batch data size is shorter than time steps
sentences = [
        # enc_input           dec_input         dec_output
        ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
        ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
]

# Padding Should be Zero
src_vocab = {'P' : 0, 'ich' : 1, 'mochte' : 2, 'ein' : 3, 'bier' : 4, 'cola' : 5}
src_vocab_size = len(src_vocab)

tgt_vocab = {'P' : 0, 'i' : 1, 'want' : 2, 'a' : 3, 'beer' : 4, 'coke' : 5, 'S' : 6, 'E' : 7, '.' : 8}
idx2word = {i: w for i, w in enumerate(tgt_vocab)}
tgt_vocab_size = len(tgt_vocab)

src_len = 5 # enc_input max sequence length
tgt_len = 6 # dec_input(=dec_output) max sequence length

# Define hyperparameters
# Transformer Parameters
d_model = 512  # Embedding Size: dimension of the token embedding and the positional embedding (they are equal, so one variable suffices)
d_ff = 2048 # FeedForward dimension: number of hidden units in the FeedForward layer
d_k = d_v = 64  # dimension of K(=Q), V: Q and K must have the same dimension; V's is unconstrained, but for convenience all are set to 64
n_layers = 6  # number of Encoder / Decoder layers
n_heads = 8  # number of heads in Multi-Head Attention

def make_data(sentences):
    enc_inputs, dec_inputs, dec_outputs = [], [], []
    for i in range(len(sentences)):
      enc_input = [[src_vocab[n] for n in sentences[i][0].split()]] # [[1, 2, 3, 4, 0], [1, 2, 3, 5, 0]]
      dec_input = [[tgt_vocab[n] for n in sentences[i][1].split()]] # [[6, 1, 2, 3, 4, 8], [6, 1, 2, 3, 5, 8]]
      dec_output = [[tgt_vocab[n] for n in sentences[i][2].split()]] # [[1, 2, 3, 4, 8, 7], [1, 2, 3, 5, 8, 7]]

      enc_inputs.extend(enc_input)
      dec_inputs.extend(dec_input)
      dec_outputs.extend(dec_output)

    return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)

enc_inputs, dec_inputs, dec_outputs = make_data(sentences)



class MyDataSet(Data.Dataset):
    def __init__(self,enc_inputs,dec_inputs,dec_outputs):
        super(MyDataSet,self).__init__()
        self.enc_inputs = enc_inputs
        self.dec_inputs = dec_inputs
        self.dec_outputs = dec_outputs
        
    def __len__(self):
        return self.enc_inputs.shape[0]
    
    def __getitem__(self,idx):
        return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]
loader = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), 2, True)

Positional Encoding

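Below are the formulas from the paper. In short, this layer is how the model recovers the positional information that attention alone would lose: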
$$
\begin{gathered}
PE(pos, 2i) = \sin\left(pos / 10000^{2i / d_{\text{model}}}\right) \\
PE(pos, 2i+1) = \cos\left(pos / 10000^{2i / d_{\text{model}}}\right)
\end{gathered}
$$

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
        x: [seq_len, batch_size, d_model]
        self.pe: [max_len, 1, d_model]; max_len caps how many words a sentence may contain
        '''
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
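A numerical check, done in float64 for precision, that the `div_term` in `PositionalEncoding` (computed as exp(-2i·ln(10000)/d_model)) gives the same angles as the formula's pos / 10000^(2i/d_model), and that even dimensions get sin while odd dimensions get cos:

```python
import math
import torch

d_model, max_len = 512, 50

pos = torch.arange(max_len, dtype=torch.float64).unsqueeze(1)   # [max_len, 1]
two_i = torch.arange(0, d_model, 2, dtype=torch.float64)        # 0, 2, 4, ..., d_model - 2

# Direct transcription of the formula: pos / 10000^(2i / d_model)
angle_direct = pos / torch.pow(10000.0, two_i / d_model)

# The exp/log form used in PositionalEncoding.__init__
div_term = torch.exp(two_i * (-math.log(10000.0) / d_model))
angle_class = pos * div_term

print(torch.allclose(angle_direct, angle_class))   # True

pe = torch.zeros(max_len, d_model, dtype=torch.float64)
pe[:, 0::2] = torch.sin(angle_class)   # even dimensions: sin
pe[:, 1::2] = torch.cos(angle_class)   # odd dimensions: cos
print(pe[0, :4].tolist())              # [0.0, 1.0, 0.0, 1.0]: sin(0) = 0, cos(0) = 1 at position 0
```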

Pad Mask

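This handles sentences of different lengths: the input must have a fixed length, so short sentences are padded, but the computation should ignore the padding, so it is masked out.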

The key line of this function is seq_k.data.eq(0). It returns a tensor of the same size as seq_k containing only True and False: True where seq_k equals 0, False elsewhere. For example, for seq_data = [1, 2, 3, 4, 0], seq_data.data.eq(0) returns [False, False, False, False, True].

def get_attn_pad_mask(seq_q, seq_k):
    '''
    seq_q: [batch_size, seq_len] # only its length is used, to expand the mask to the shape of the attention score matrix
    seq_k: [batch_size, seq_len]
    seq_len could be src_len or it could be tgt_len
    seq_len in seq_q and seq_len in seq_k maybe not equal
    '''
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(zero) is PAD token
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_k], True is masked
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]
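A standalone sketch that applies the same operations as `get_attn_pad_mask` (without calling it) to the two encoder inputs produced by `make_data` above, to show the resulting shape and which entries are masked:

```python
import torch

# The two encoder inputs produced by make_data above; 0 is the pad id 'P'
enc_inputs = torch.tensor([[1, 2, 3, 4, 0],
                           [1, 2, 3, 5, 0]])

batch_size, len_k = enc_inputs.shape
len_q = len_k                          # encoder self-attention: queries and keys are the same sequence

pad = enc_inputs.eq(0)                                       # [2, 5], True at pad positions
mask = pad.unsqueeze(1).expand(batch_size, len_q, len_k)     # [2, 5, 5], True = masked

print(pad[0].tolist())                  # [False, False, False, False, True]
print(mask.shape)                       # torch.Size([2, 5, 5])
print(bool(mask[:, :, -1].all()))       # True: every query masks the pad key
print(bool(mask[:, -1, :-1].any()))     # False: the pad query's own row is not masked
```

Only the key dimension (the last one) is masked; rows belonging to pad queries are left as they are.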

Subsequence Mask

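Subsequence Mask is used only by the Decoder; its job is to hide the words at future time steps. np.ones() first builds an all-ones array of shape [batch_size, tgt_len, tgt_len], and np.triu() with k=1 then keeps only the part strictly above the diagonal: an upper triangular matrix whose 1s mark the future positions. The figure below shows how np.triu() is used: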
[Figure: usage of np.triu()]

def get_attn_subsequence_mask(seq):
    '''
    seq: [batch_size, tgt_len]
    '''
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    subsequence_mask = np.triu(np.ones(attn_shape), k=1) # Upper triangular matrix
    subsequence_mask = torch.from_numpy(subsequence_mask).byte()
    return subsequence_mask # [batch_size, tgt_len, tgt_len]
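A short NumPy sketch of what `np.triu(..., k=1)` produces here: 1s strictly above the diagonal mark the future positions (1 = masked, the complement of the 0/1 keep-matrix in the Sequence Mask section), and for a 3-D array it applies to the last two axes, so every sequence in the batch gets the same mask:

```python
import numpy as np

tgt_len = 5
future = np.triu(np.ones((tgt_len, tgt_len)), k=1)   # 1 strictly above the diagonal = future positions
print(future.astype(int))

# Complement: the 0/1 "keep" matrix from the Sequence Mask section
keep = 1 - future
print(np.array_equal(keep, np.tril(np.ones((tgt_len, tgt_len)))))   # True

# With a leading batch axis (as in attn_shape), triu applies to the last two axes
batched = np.triu(np.ones((2, tgt_len, tgt_len)), k=1)
print(np.array_equal(batched[0], future) and np.array_equal(batched[1], future))  # True
```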

ScaledDotProductAttention

Scaled dot-product attention: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        '''
        Q: [batch_size, n_heads, len_q, d_k]
        K: [batch_size, n_heads, len_k, d_k]
        V: [batch_size, n_heads, len_v(=len_k), d_v]
        attn_mask: [batch_size, n_heads, seq_len, seq_len]
        '''
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # scores : [batch_size, n_heads, len_q, len_k]
        scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is True.
        
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V) # [batch_size, n_heads, len_q, d_v]
        return context, attn

MultiHeadAttention


class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
        # Create the LayerNorm once so its weight and bias are registered and trained
        # (building a new nn.LayerNorm inside forward() would reset them on every call)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, input_Q, input_K, input_V, attn_mask):
        '''
        input_Q: [batch_size, len_q, d_model]
        input_K: [batch_size, len_k, d_model]
        input_V: [batch_size, len_v(=len_k), d_model]
        attn_mask: [batch_size, seq_len, seq_len]
        '''
        residual, batch_size = input_Q, input_Q.size(0)
        # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # Q: [batch_size, n_heads, len_q, d_k]
        K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # K: [batch_size, n_heads, len_k, d_k]
        V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1,2)  # V: [batch_size, n_heads, len_v(=len_k), d_v]

        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size, n_heads, seq_len, seq_len]

        # context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]
        context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)
        context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v) # context: [batch_size, len_q, n_heads * d_v]
        output = self.fc(context) # [batch_size, len_q, d_model]
        return self.layer_norm(output + residual), attn
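A shape-only sketch, on random data, of the split into heads and the merge back that `MultiHeadAttention.forward` performs. The merge uses `reshape` because the tensor is no longer contiguous after `transpose`, so `view` would raise an error there:

```python
import torch

batch_size, seq_len, n_heads, d_k = 2, 5, 8, 64
d_model = n_heads * d_k                                  # 512, as in the hyperparameters above

x = torch.randn(batch_size, seq_len, d_model)            # stands in for the output of W_Q

# Split: [B, S, H*W] -> [B, S, H, W] -> [B, H, S, W]
heads = x.view(batch_size, -1, n_heads, d_k).transpose(1, 2)
print(heads.shape)                                       # torch.Size([2, 8, 5, 64])

# Merge (as done on `context`): [B, H, S, W] -> [B, S, H, W] -> [B, S, H*W]
merged = heads.transpose(1, 2).reshape(batch_size, -1, n_heads * d_k)
print(torch.equal(merged, x))                            # True: split + merge is lossless

# Head h works on a contiguous 64-wide slice of the projected features
print(torch.equal(heads[:, 3], x[..., 3 * d_k:4 * d_k])) # True
```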

FeedForward Layer

Two linear transformations with a ReLU in between, then a residual connection followed by a Layer Norm

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False)
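        )
        # LayerNorm created once so its parameters are trained with the rest of the model
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inputs):
        '''
        inputs: [batch_size, seq_len, d_model]
        '''
        residual = inputs
        output = self.fc(inputs)
        return self.layer_norm(output + residual) # [batch_size, seq_len, d_model]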
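"Position-wise" means the same two linear layers are applied to every position independently. A sketch checking this with a freshly initialised network laid out like `self.fc` above (random weights, so only the equality matters):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 512, 2048
ffn = nn.Sequential(nn.Linear(d_model, d_ff, bias=False),
                    nn.ReLU(),
                    nn.Linear(d_ff, d_model, bias=False))   # same layout as PoswiseFeedForwardNet.fc

x = torch.randn(2, 5, d_model)        # [batch_size, seq_len, d_model]
whole = ffn(x)                        # all positions in one call

# The same network applied to each position separately
per_position = torch.stack([ffn(x[:, t]) for t in range(x.size(1))], dim=1)

print(torch.allclose(whole, per_position, atol=1e-5))   # True
```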

Encoder Layer

A reusable block

class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        '''
        enc_inputs: [batch_size, src_len, d_model]
        enc_self_attn_mask: [batch_size, src_len, src_len]
        '''
        # enc_outputs: [batch_size, src_len, d_model], attn: [batch_size, n_heads, src_len, src_len]
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size, src_len, d_model]
        return enc_outputs, attn

Encoder


class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])

    def forward(self, enc_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        '''
        enc_outputs = self.src_emb(enc_inputs) # [batch_size, src_len, d_model]
        enc_outputs = self.pos_emb(enc_outputs.transpose(0, 1)).transpose(0, 1) # [batch_size, src_len, d_model]
        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs) # [batch_size, src_len, src_len]
        enc_self_attns = []
        for layer in self.layers:
            # enc_outputs: [batch_size, src_len, d_model], enc_self_attn: [batch_size, n_heads, src_len, src_len]
            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
            enc_self_attns.append(enc_self_attn)
        return enc_outputs, enc_self_attns

Decoder Layer


class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention() # masked self-attention
        self.dec_enc_attn = MultiHeadAttention() # cross attention
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        '''
        dec_inputs: [batch_size, tgt_len, d_model]
        enc_outputs: [batch_size, src_len, d_model]
        dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        dec_enc_attn_mask: [batch_size, tgt_len, src_len]
        '''
        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        # dec_outputs: [batch_size, tgt_len, d_model], dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        dec_outputs = self.pos_ffn(dec_outputs) # [batch_size, tgt_len, d_model]
        return dec_outputs, dec_self_attn, dec_enc_attn

Decoder


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        '''
        dec_inputs: [batch_size, tgt_len]
        enc_inputs: [batch_size, src_len]
        enc_outputs: [batch_size, src_len, d_model]
        '''
        dec_outputs = self.tgt_emb(dec_inputs) # [batch_size, tgt_len, d_model]
        dec_outputs = self.pos_emb(dec_outputs.transpose(0, 1)).transpose(0, 1).cuda() # [batch_size, tgt_len, d_model]
        # Pad mask for the decoder input sequence (this example's decoder input has no padding, but real data usually does)
        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).cuda() # [batch_size, tgt_len, tgt_len]
        # Masked self-attention: the current time step must not see future information
        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).cuda() # [batch_size, tgt_len, tgt_len]
        # Combine the two masks in the Decoder (hides both the pad positions and the future positions)
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequence_mask), 0).cuda() # [batch_size, tgt_len, tgt_len]

        # This mask is used by the encoder-decoder attention layer.
        # get_attn_pad_mask builds the pad mask of enc_inputs: the encoder side supplies K and V, and attention
        # takes a weighted sum of v1, v2, ..., vm, so the weights of pad vectors v_i must be 0 for attention to ignore them.
        # dec_inputs only supplies the size to expand to.
        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) # [batch_size, tgt_len, src_len]

        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns
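A sketch of how the two decoder masks combine, assuming a hypothetical decoder input 'S i want a' (ids 6, 1, 2, 3 from `tgt_vocab`) followed by two pads. Adding the bool pad mask to the uint8 subsequence mask and testing `> 0` amounts to a logical OR, so a key is hidden if it is padding or lies in the future:

```python
import torch

# Hypothetical decoder input: 'S i want a' followed by two pads (id 0)
dec_inputs = torch.tensor([[6, 1, 2, 3, 0, 0]])
B, T = dec_inputs.shape

pad_mask = dec_inputs.eq(0).unsqueeze(1).expand(B, T, T)           # bool, True at pad keys
subsequent = torch.triu(torch.ones(B, T, T), diagonal=1).byte()    # uint8, 1 at future keys

combined = torch.gt(pad_mask + subsequent, 0)                      # the expression used in Decoder.forward
print(torch.equal(combined, pad_mask | subsequent.bool()))         # True: sum > 0 is a logical OR
print(combined[0].int())
```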

Transformer

class Transformer(nn.Module):
    def __init__(self):
        super(Transformer, self).__init__()
        self.encoder = Encoder().cuda()
        self.decoder = Decoder().cuda()
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False).cuda()
    def forward(self, enc_inputs, dec_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        dec_inputs: [batch_size, tgt_len]
        '''
        # enc_outputs: [batch_size, src_len, d_model], enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attns: [n_layers, batch_size, n_heads, tgt_len, tgt_len], dec_enc_attns: [n_layers, batch_size, n_heads, tgt_len, src_len]
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
        dec_logits = self.projection(dec_outputs) # dec_logits: [batch_size, tgt_len, tgt_vocab_size]
        return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns
model = Transformer().cuda()
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)
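`ignore_index=0` tells `nn.CrossEntropyLoss` to skip targets equal to the pad id 0 and to average only over the remaining positions. In this post's toy data no target is a pad (neither `dec_outputs` row contains a 0), so it only matters once padded targets appear. A sketch with random logits and a hypothetical target row ending in two pads:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(6, 9)                   # [tgt_len=6, tgt_vocab_size=9] for one sentence
targets = torch.tensor([1, 2, 3, 4, 0, 0])   # hypothetical targets; the last two are pads (id 0)

loss_ignore = nn.CrossEntropyLoss(ignore_index=0)(logits, targets)
loss_real_only = nn.CrossEntropyLoss()(logits[:4], targets[:4])   # mean over the 4 non-pad positions

print(torch.allclose(loss_ignore, loss_real_only))   # True
```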

Training

for epoch in range(1000):
    for enc_inputs, dec_inputs, dec_outputs in loader:
      '''
      enc_inputs: [batch_size, src_len]
      dec_inputs: [batch_size, tgt_len]
      dec_outputs: [batch_size, tgt_len]
      '''
      enc_inputs, dec_inputs, dec_outputs = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs.cuda()
      # outputs: [batch_size * tgt_len, tgt_vocab_size]
      outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
      loss = criterion(outputs, dec_outputs.view(-1))
      print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
Epoch: 0001 loss = 2.224822
Epoch: 0002 loss = 2.030541
Epoch: 0003 loss = 1.781882
Epoch: 0004 loss = 1.530569
Epoch: 0005 loss = 1.350049
Epoch: 0006 loss = 1.171564
Epoch: 0007 loss = 0.969550
Epoch: 0008 loss = 0.834517
Epoch: 0009 loss = 0.642394
Epoch: 0010 loss = 0.483570
Epoch: 0011 loss = 0.327683
Epoch: 0012 loss = 0.269178
Epoch: 0013 loss = 0.211673
Epoch: 0014 loss = 0.160442
Epoch: 0015 loss = 0.172483
Epoch: 0016 loss = 0.134756
Epoch: 0017 loss = 0.137094
Epoch: 0018 loss = 0.100166
Epoch: 0019 loss = 0.111751
Epoch: 0020 loss = 0.077123
Epoch: 0021 loss = 0.057618
Epoch: 0022 loss = 0.041987
Epoch: 0023 loss = 0.032597
Epoch: 0024 loss = 0.043513
Epoch: 0025 loss = 0.040776
Epoch: 0026 loss = 0.058874
Epoch: 0027 loss = 0.045291
Epoch: 0028 loss = 0.050369
Epoch: 0029 loss = 0.034177
Epoch: 0030 loss = 0.029485
Epoch: 0031 loss = 0.019343
Epoch: 0032 loss = 0.016446
Epoch: 0033 loss = 0.008740
Epoch: 0034 loss = 0.013603
Epoch: 0035 loss = 0.009225
Epoch: 0036 loss = 0.010144
Epoch: 0037 loss = 0.023320
Epoch: 0038 loss = 0.013741
Epoch: 0039 loss = 0.018951
Epoch: 0040 loss = 0.011962
Epoch: 0041 loss = 0.019953
Epoch: 0042 loss = 0.013491
Epoch: 0043 loss = 0.009314
Epoch: 0044 loss = 0.008740
Epoch: 0045 loss = 0.011457
Epoch: 0046 loss = 0.005985
Epoch: 0047 loss = 0.006249
Epoch: 0048 loss = 0.004624
Epoch: 0049 loss = 0.002393
Epoch: 0050 loss = 0.002759
Epoch: 0051 loss = 0.002254
Epoch: 0052 loss = 0.001777
Epoch: 0053 loss = 0.001847
Epoch: 0054 loss = 0.001632
Epoch: 0055 loss = 0.002175
Epoch: 0056 loss = 0.001579
Epoch: 0057 loss = 0.002228
Epoch: 0058 loss = 0.002176
Epoch: 0059 loss = 0.001884
Epoch: 0060 loss = 0.002646
Epoch: 0061 loss = 0.002424
Epoch: 0062 loss = 0.003926
Epoch: 0063 loss = 0.003388
Epoch: 0064 loss = 0.001827
Epoch: 0065 loss = 0.003487
Epoch: 0066 loss = 0.003174
Epoch: 0067 loss = 0.003894
Epoch: 0068 loss = 0.004140
Epoch: 0069 loss = 0.005471
Epoch: 0070 loss = 0.001934
Epoch: 0071 loss = 0.002330
Epoch: 0072 loss = 0.003117
Epoch: 0073 loss = 0.007589
Epoch: 0074 loss = 0.004772
Epoch: 0075 loss = 0.003417
Epoch: 0076 loss = 0.002385
Epoch: 0077 loss = 0.002436
Epoch: 0078 loss = 0.002571
Epoch: 0079 loss = 0.001841
Epoch: 0080 loss = 0.002677
Epoch: 0081 loss = 0.001102
Epoch: 0082 loss = 0.000963
Epoch: 0083 loss = 0.000974
Epoch: 0084 loss = 0.001100
Epoch: 0085 loss = 0.001042
Epoch: 0086 loss = 0.000707
Epoch: 0087 loss = 0.000699
Epoch: 0088 loss = 0.001113
Epoch: 0089 loss = 0.000647
Epoch: 0090 loss = 0.000762
Epoch: 0091 loss = 0.000446
Epoch: 0092 loss = 0.000515
Epoch: 0093 loss = 0.000477
Epoch: 0094 loss = 0.000579
Epoch: 0095 loss = 0.000418
Epoch: 0096 loss = 0.000406
Epoch: 0097 loss = 0.000368
Epoch: 0098 loss = 0.000349
Epoch: 0099 loss = 0.000188
Epoch: 0100 loss = 0.000249
Epoch: 0101 loss = 0.000259
Epoch: 0102 loss = 0.000257
Epoch: 0103 loss = 0.000219
Epoch: 0104 loss = 0.000236
Epoch: 0105 loss = 0.000220
Epoch: 0106 loss = 0.000192
Epoch: 0107 loss = 0.000183
Epoch: 0108 loss = 0.000125
Epoch: 0109 loss = 0.000160
Epoch: 0110 loss = 0.000182
Epoch: 0111 loss = 0.000209
Epoch: 0112 loss = 0.000161
Epoch: 0113 loss = 0.000108
Epoch: 0114 loss = 0.000154
Epoch: 0115 loss = 0.000121
Epoch: 0116 loss = 0.000149
Epoch: 0117 loss = 0.000144
Epoch: 0118 loss = 0.000164
Epoch: 0119 loss = 0.000098
Epoch: 0120 loss = 0.000093
Epoch: 0121 loss = 0.000105
Epoch: 0122 loss = 0.000121
Epoch: 0123 loss = 0.000124
Epoch: 0124 loss = 0.000109
Epoch: 0125 loss = 0.000143
Epoch: 0126 loss = 0.000154
Epoch: 0127 loss = 0.000145
Epoch: 0128 loss = 0.000116
Epoch: 0129 loss = 0.000083
Epoch: 0130 loss = 0.000131
Epoch: 0131 loss = 0.000114
Epoch: 0132 loss = 0.000146
Epoch: 0133 loss = 0.000184
Epoch: 0134 loss = 0.000127
Epoch: 0135 loss = 0.000120
Epoch: 0136 loss = 0.000144
Epoch: 0137 loss = 0.000113
Epoch: 0138 loss = 0.000167
Epoch: 0139 loss = 0.000110
Epoch: 0140 loss = 0.000170
Epoch: 0141 loss = 0.000133
Epoch: 0142 loss = 0.000142
Epoch: 0143 loss = 0.000173
Epoch: 0144 loss = 0.000103
Epoch: 0145 loss = 0.000107
Epoch: 0146 loss = 0.000151
Epoch: 0147 loss = 0.000201
Epoch: 0148 loss = 0.000169
Epoch: 0149 loss = 0.000106
Epoch: 0150 loss = 0.000159
Epoch: 0151 loss = 0.000234
Epoch: 0152 loss = 0.000223
Epoch: 0153 loss = 0.000241
Epoch: 0154 loss = 0.000248
Epoch: 0155 loss = 0.000239
Epoch: 0156 loss = 0.000193
Epoch: 0157 loss = 0.000255
Epoch: 0158 loss = 0.000146
Epoch: 0159 loss = 0.000257
Epoch: 0160 loss = 0.000320
Epoch: 0161 loss = 0.000345
Epoch: 0162 loss = 0.000246
Epoch: 0163 loss = 0.000277
Epoch: 0164 loss = 0.000298
Epoch: 0165 loss = 0.000321
Epoch: 0166 loss = 0.000326
Epoch: 0167 loss = 0.000260
Epoch: 0168 loss = 0.000271
Epoch: 0169 loss = 0.000250
Epoch: 0170 loss = 0.000382
Epoch: 0171 loss = 0.000621
Epoch: 0172 loss = 0.000630
Epoch: 0173 loss = 0.000532
Epoch: 0174 loss = 0.000563
Epoch: 0175 loss = 0.000660
Epoch: 0176 loss = 0.000237
Epoch: 0177 loss = 0.000606
Epoch: 0178 loss = 0.000761
Epoch: 0179 loss = 0.000233
Epoch: 0180 loss = 0.000590
Epoch: 0181 loss = 0.000484
Epoch: 0182 loss = 0.000533
Epoch: 0183 loss = 0.000734
Epoch: 0184 loss = 0.000308
Epoch: 0185 loss = 0.000789
Epoch: 0186 loss = 0.000449
Epoch: 0187 loss = 0.000585
Epoch: 0188 loss = 0.000568
Epoch: 0189 loss = 0.000344
Epoch: 0190 loss = 0.000458
Epoch: 0191 loss = 0.000492
Epoch: 0192 loss = 0.000347
Epoch: 0193 loss = 0.000430
Epoch: 0194 loss = 0.000340
Epoch: 0195 loss = 0.000354
Epoch: 0196 loss = 0.000266
Epoch: 0197 loss = 0.000359
Epoch: 0198 loss = 0.000423
Epoch: 0199 loss = 0.000473
Epoch: 0200 loss = 0.000407
Epoch: 0201 loss = 0.000290
Epoch: 0202 loss = 0.000357
Epoch: 0203 loss = 0.000336
Epoch: 0204 loss = 0.000287
Epoch: 0205 loss = 0.000299
Epoch: 0206 loss = 0.000307
Epoch: 0207 loss = 0.000362
Epoch: 0208 loss = 0.000327
Epoch: 0209 loss = 0.000232
Epoch: 0210 loss = 0.000182
Epoch: 0211 loss = 0.000193
Epoch: 0212 loss = 0.000253
Epoch: 0213 loss = 0.000261
Epoch: 0214 loss = 0.000288
Epoch: 0215 loss = 0.000204
Epoch: 0216 loss = 0.000340
Epoch: 0217 loss = 0.000233
Epoch: 0218 loss = 0.000201
Epoch: 0219 loss = 0.000204
Epoch: 0220 loss = 0.000182
Epoch: 0221 loss = 0.000181
Epoch: 0222 loss = 0.000137
Epoch: 0223 loss = 0.000159
Epoch: 0224 loss = 0.000186
Epoch: 0225 loss = 0.000233
Epoch: 0226 loss = 0.000171
Epoch: 0227 loss = 0.000153
Epoch: 0228 loss = 0.000167
Epoch: 0229 loss = 0.000138
Epoch: 0230 loss = 0.000167
Epoch: 0231 loss = 0.000128
Epoch: 0232 loss = 0.000116
Epoch: 0233 loss = 0.000194
Epoch: 0234 loss = 0.000149
Epoch: 0235 loss = 0.000130
Epoch: 0236 loss = 0.000115
Epoch: 0237 loss = 0.000170
Epoch: 0238 loss = 0.000186
Epoch: 0239 loss = 0.000168
Epoch: 0240 loss = 0.000096
Epoch: 0241 loss = 0.000121
Epoch: 0242 loss = 0.000119
Epoch: 0243 loss = 0.000115
Epoch: 0244 loss = 0.000126
Epoch: 0245 loss = 0.000073
Epoch: 0246 loss = 0.000093
Epoch: 0247 loss = 0.000144
Epoch: 0248 loss = 0.000072
Epoch: 0249 loss = 0.000096
Epoch: 0250 loss = 0.000100
Epoch: 0251 loss = 0.000087
Epoch: 0252 loss = 0.000110
Epoch: 0253 loss = 0.000077
Epoch: 0254 loss = 0.000084
Epoch: 0255 loss = 0.000076
Epoch: 0256 loss = 0.000066
Epoch: 0257 loss = 0.000077
Epoch: 0258 loss = 0.000104
Epoch: 0259 loss = 0.000079
Epoch: 0260 loss = 0.000078
Epoch: 0261 loss = 0.000065
Epoch: 0262 loss = 0.000070
Epoch: 0263 loss = 0.000077
Epoch: 0264 loss = 0.000088
Epoch: 0265 loss = 0.000078
Epoch: 0266 loss = 0.000069
Epoch: 0267 loss = 0.000059
Epoch: 0268 loss = 0.000049
Epoch: 0269 loss = 0.000055
Epoch: 0270 loss = 0.000055
Epoch: 0271 loss = 0.000092
Epoch: 0272 loss = 0.000072
Epoch: 0273 loss = 0.000058
Epoch: 0274 loss = 0.000063
Epoch: 0275 loss = 0.000065
Epoch: 0276 loss = 0.000061
Epoch: 0277 loss = 0.000048
Epoch: 0278 loss = 0.000058
Epoch: 0279 loss = 0.000045
Epoch: 0280 loss = 0.000045
Epoch: 0281 loss = 0.000053
Epoch: 0282 loss = 0.000041
Epoch: 0283 loss = 0.000056
Epoch: 0284 loss = 0.000065
Epoch: 0285 loss = 0.000039
Epoch: 0286 loss = 0.000042
Epoch: 0287 loss = 0.000049
Epoch: 0288 loss = 0.000041
Epoch: 0289 loss = 0.000049
Epoch: 0290 loss = 0.000043
Epoch: 0291 loss = 0.000045
Epoch: 0292 loss = 0.000031
Epoch: 0293 loss = 0.000073
Epoch: 0294 loss = 0.000053
Epoch: 0295 loss = 0.000033
Epoch: 0296 loss = 0.000043
Epoch: 0297 loss = 0.000032
Epoch: 0298 loss = 0.000042
Epoch: 0299 loss = 0.000048
Epoch: 0300 loss = 0.000045
Epoch: 0301 loss = 0.000057
Epoch: 0302 loss = 0.000039
Epoch: 0303 loss = 0.000041
Epoch: 0304 loss = 0.000037
Epoch: 0305 loss = 0.000048
Epoch: 0306 loss = 0.000055
Epoch: 0307 loss = 0.000039
Epoch: 0308 loss = 0.000040
Epoch: 0309 loss = 0.000045
Epoch: 0310 loss = 0.000048
Epoch: 0311 loss = 0.000036
Epoch: 0312 loss = 0.000030
Epoch: 0313 loss = 0.000040
Epoch: 0314 loss = 0.000039
Epoch: 0315 loss = 0.000055
Epoch: 0316 loss = 0.000067
Epoch: 0317 loss = 0.000031
Epoch: 0318 loss = 0.000055
Epoch: 0319 loss = 0.000034
Epoch: 0320 loss = 0.000036
Epoch: 0321 loss = 0.000026
Epoch: 0322 loss = 0.000032
Epoch: 0323 loss = 0.000032
Epoch: 0324 loss = 0.000064
Epoch: 0325 loss = 0.000033
Epoch: 0326 loss = 0.000030
Epoch: 0327 loss = 0.000040
Epoch: 0328 loss = 0.000027
Epoch: 0329 loss = 0.000029
Epoch: 0330 loss = 0.000028
Epoch: 0331 loss = 0.000032
Epoch: 0332 loss = 0.000037
Epoch: 0333 loss = 0.000028
Epoch: 0334 loss = 0.000028
Epoch: 0335 loss = 0.000027
Epoch: 0336 loss = 0.000040
Epoch: 0337 loss = 0.000027
Epoch: 0338 loss = 0.000029
Epoch: 0339 loss = 0.000024
Epoch: 0340 loss = 0.000041
Epoch: 0341 loss = 0.000025
Epoch: 0342 loss = 0.000023
Epoch: 0343 loss = 0.000038
Epoch: 0344 loss = 0.000022
Epoch: 0345 loss = 0.000029
Epoch: 0346 loss = 0.000026
Epoch: 0347 loss = 0.000032
Epoch: 0348 loss = 0.000029
Epoch: 0349 loss = 0.000036
Epoch: 0350 loss = 0.000034
Epoch: 0351 loss = 0.000021
Epoch: 0352 loss = 0.000024
Epoch: 0353 loss = 0.000019
Epoch: 0354 loss = 0.000022
Epoch: 0355 loss = 0.000033
Epoch: 0356 loss = 0.000026
Epoch: 0357 loss = 0.000032
Epoch: 0358 loss = 0.000023
Epoch: 0359 loss = 0.000022
Epoch: 0360 loss = 0.000023
Epoch: 0361 loss = 0.000029
Epoch: 0362 loss = 0.000023
Epoch: 0363 loss = 0.000037
Epoch: 0364 loss = 0.000023
Epoch: 0365 loss = 0.000023
Epoch: 0366 loss = 0.000027
Epoch: 0367 loss = 0.000021
Epoch: 0368 loss = 0.000025
Epoch: 0369 loss = 0.000022
Epoch: 0370 loss = 0.000017
Epoch: 0371 loss = 0.000031
Epoch: 0372 loss = 0.000020
Epoch: 0373 loss = 0.000024
Epoch: 0374 loss = 0.000019
Epoch: 0375 loss = 0.000021
Epoch: 0376 loss = 0.000023
Epoch: 0377 loss = 0.000028
Epoch: 0378 loss = 0.000018
Epoch: 0379 loss = 0.000027
Epoch: 0380 loss = 0.000018
Epoch: 0381 loss = 0.000025
Epoch: 0382 loss = 0.000027
Epoch: 0383 loss = 0.000018
Epoch: 0384 loss = 0.000020
Epoch: 0385 loss = 0.000019
Epoch: 0386 loss = 0.000021
Epoch: 0387 loss = 0.000019
Epoch: 0388 loss = 0.000017
Epoch: 0389 loss = 0.000022
Epoch: 0390 loss = 0.000027
Epoch: 0391 loss = 0.000023
Epoch: 0392 loss = 0.000022
Epoch: 0393 loss = 0.000020
Epoch: 0394 loss = 0.000019
Epoch: 0395 loss = 0.000016
Epoch: 0396 loss = 0.000018
Epoch: 0397 loss = 0.000015
Epoch: 0398 loss = 0.000021
Epoch: 0399 loss = 0.000013
Epoch: 0400 loss = 0.000016
Epoch: 0401 loss = 0.000017
Epoch: 0402 loss = 0.000024
Epoch: 0403 loss = 0.000018
Epoch: 0404 loss = 0.000017
Epoch: 0405 loss = 0.000018
Epoch: 0406 loss = 0.000015
Epoch: 0407 loss = 0.000017
Epoch: 0408 loss = 0.000022
Epoch: 0409 loss = 0.000017
Epoch: 0410 loss = 0.000020
Epoch: 0411 loss = 0.000023
Epoch: 0412 loss = 0.000020
Epoch: 0413 loss = 0.000017
Epoch: 0414 loss = 0.000018
Epoch: 0415 loss = 0.000015
Epoch: 0416 loss = 0.000014
Epoch: 0417 loss = 0.000018
Epoch: 0418 loss = 0.000015
Epoch: 0419 loss = 0.000012
Epoch: 0420 loss = 0.000017
Epoch: 0421 loss = 0.000014
Epoch: 0422 loss = 0.000016
Epoch: 0423 loss = 0.000014
Epoch: 0424 loss = 0.000018
Epoch: 0425 loss = 0.000023
Epoch: 0426 loss = 0.000017
Epoch: 0427 loss = 0.000016
Epoch: 0428 loss = 0.000014
Epoch: 0429 loss = 0.000014
Epoch: 0430 loss = 0.000017
Epoch: 0431 loss = 0.000013
Epoch: 0432 loss = 0.000014
Epoch: 0433 loss = 0.000014
Epoch: 0434 loss = 0.000020
Epoch: 0435 loss = 0.000016
Epoch: 0436 loss = 0.000013
Epoch: 0437 loss = 0.000015
Epoch: 0438 loss = 0.000014
Epoch: 0439 loss = 0.000015
Epoch: 0440 loss = 0.000021
Epoch: 0441 loss = 0.000016
Epoch: 0442 loss = 0.000014
Epoch: 0443 loss = 0.000014
Epoch: 0444 loss = 0.000013
Epoch: 0445 loss = 0.000017
Epoch: 0446 loss = 0.000016
Epoch: 0447 loss = 0.000019
Epoch: 0448 loss = 0.000018
Epoch: 0449 loss = 0.000020
Epoch: 0450 loss = 0.000017
Epoch: 0451 loss = 0.000015
Epoch: 0452 loss = 0.000014
Epoch: 0453 loss = 0.000016
Epoch: 0454 loss = 0.000012
Epoch: 0455 loss = 0.000014
Epoch: 0456 loss = 0.000011
Epoch: 0457 loss = 0.000017
Epoch: 0458 loss = 0.000016
Epoch: 0459 loss = 0.000015
Epoch: 0460 loss = 0.000013
Epoch: 0461 loss = 0.000012
Epoch: 0462 loss = 0.000011
Epoch: 0463 loss = 0.000022
Epoch: 0464 loss = 0.000018
Epoch: 0465 loss = 0.000015
Epoch: 0466 loss = 0.000016
Epoch: 0467 loss = 0.000010
Epoch: 0468 loss = 0.000013
Epoch: 0469 loss = 0.000012
Epoch: 0470 loss = 0.000014
Epoch: 0471 loss = 0.000015
Epoch: 0472 loss = 0.000010
Epoch: 0473 loss = 0.000011
Epoch: 0474 loss = 0.000014
Epoch: 0475 loss = 0.000011
Epoch: 0476 loss = 0.000012
Epoch: 0477 loss = 0.000015
Epoch: 0478 loss = 0.000012
Epoch: 0479 loss = 0.000014
Epoch: 0480 loss = 0.000011
Epoch: 0481 loss = 0.000010
Epoch: 0482 loss = 0.000014
Epoch: 0483 loss = 0.000013
Epoch: 0484 loss = 0.000014
Epoch: 0485 loss = 0.000014
Epoch: 0486 loss = 0.000014
Epoch: 0487 loss = 0.000014
Epoch: 0488 loss = 0.000012
Epoch: 0489 loss = 0.000014
Epoch: 0490 loss = 0.000012
Epoch: 0491 loss = 0.000010
Epoch: 0492 loss = 0.000012
Epoch: 0493 loss = 0.000014
Epoch: 0494 loss = 0.000014
Epoch: 0495 loss = 0.000009
Epoch: 0496 loss = 0.000011
Epoch: 0497 loss = 0.000015
Epoch: 0498 loss = 0.000016
Epoch: 0499 loss = 0.000011
Epoch: 0500 loss = 0.000013
Epoch: 0501 loss = 0.000011
Epoch: 0502 loss = 0.000010
Epoch: 0503 loss = 0.000011
Epoch: 0504 loss = 0.000010
Epoch: 0505 loss = 0.000012
Epoch: 0506 loss = 0.000012
Epoch: 0507 loss = 0.000010
Epoch: 0508 loss = 0.000012
Epoch: 0509 loss = 0.000011
Epoch: 0510 loss = 0.000008
Epoch: 0511 loss = 0.000012
Epoch: 0512 loss = 0.000014
Epoch: 0513 loss = 0.000008
Epoch: 0514 loss = 0.000008
Epoch: 0515 loss = 0.000012
Epoch: 0516 loss = 0.000009
Epoch: 0517 loss = 0.000012
Epoch: 0518 loss = 0.000013
Epoch: 0519 loss = 0.000012
Epoch: 0520 loss = 0.000007
Epoch: 0521 loss = 0.000012
Epoch: 0522 loss = 0.000012
Epoch: 0523 loss = 0.000009
Epoch: 0524 loss = 0.000008
Epoch: 0525 loss = 0.000011
Epoch: 0526 loss = 0.000014
Epoch: 0527 loss = 0.000014
Epoch: 0528 loss = 0.000013
Epoch: 0529 loss = 0.000011
Epoch: 0530 loss = 0.000012
Epoch: 0531 loss = 0.000012
Epoch: 0532 loss = 0.000011
Epoch: 0533 loss = 0.000009
Epoch: 0534 loss = 0.000012
Epoch: 0535 loss = 0.000011
Epoch: 0536 loss = 0.000011
Epoch: 0537 loss = 0.000012
Epoch: 0538 loss = 0.000009
Epoch: 0539 loss = 0.000009
Epoch: 0540 loss = 0.000013
Epoch: 0541 loss = 0.000009
Epoch: 0542 loss = 0.000011
Epoch: 0543 loss = 0.000009
Epoch: 0544 loss = 0.000011
Epoch: 0545 loss = 0.000010
Epoch: 0546 loss = 0.000009
Epoch: 0547 loss = 0.000014
Epoch: 0548 loss = 0.000010
Epoch: 0549 loss = 0.000009
Epoch: 0550 loss = 0.000011
Epoch: 0551 loss = 0.000015
Epoch: 0552 loss = 0.000012
Epoch: 0553 loss = 0.000011
Epoch: 0554 loss = 0.000010
Epoch: 0555 loss = 0.000010
Epoch: 0556 loss = 0.000008
Epoch: 0557 loss = 0.000009
Epoch: 0558 loss = 0.000011
Epoch: 0559 loss = 0.000007
Epoch: 0560 loss = 0.000009
Epoch: 0561 loss = 0.000009
Epoch: 0562 loss = 0.000010
Epoch: 0563 loss = 0.000012
Epoch: 0564 loss = 0.000008
Epoch: 0565 loss = 0.000011
Epoch: 0566 loss = 0.000009
Epoch: 0567 loss = 0.000007
Epoch: 0568 loss = 0.000009
Epoch: 0569 loss = 0.000009
Epoch: 0570 loss = 0.000008
Epoch: 0571 loss = 0.000009
Epoch: 0572 loss = 0.000007
Epoch: 0573 loss = 0.000010
Epoch: 0574 loss = 0.000007
Epoch: 0575 loss = 0.000007
Epoch: 0576 loss = 0.000009
Epoch: 0577 loss = 0.000008
Epoch: 0578 loss = 0.000009
Epoch: 0579 loss = 0.000007
Epoch: 0580 loss = 0.000008
Epoch: 0581 loss = 0.000010
Epoch: 0582 loss = 0.000007
Epoch: 0583 loss = 0.000007
Epoch: 0584 loss = 0.000007
Epoch: 0585 loss = 0.000008
Epoch: 0586 loss = 0.000012
Epoch: 0587 loss = 0.000013
Epoch: 0588 loss = 0.000010
Epoch: 0589 loss = 0.000007
Epoch: 0590 loss = 0.000006
Epoch: 0591 loss = 0.000009
Epoch: 0592 loss = 0.000010
Epoch: 0593 loss = 0.000008
Epoch: 0594 loss = 0.000010
Epoch: 0595 loss = 0.000006
Epoch: 0596 loss = 0.000008
Epoch: 0597 loss = 0.000007
Epoch: 0598 loss = 0.000009
Epoch: 0599 loss = 0.000008
Epoch: 0600 loss = 0.000008
Epoch: 0601 loss = 0.000008
Epoch: 0602 loss = 0.000009
Epoch: 0603 loss = 0.000009
Epoch: 0604 loss = 0.000008
Epoch: 0605 loss = 0.000007
Epoch: 0606 loss = 0.000008
Epoch: 0607 loss = 0.000009
Epoch: 0608 loss = 0.000008
Epoch: 0609 loss = 0.000010
Epoch: 0610 loss = 0.000009
Epoch: 0611 loss = 0.000011
Epoch: 0612 loss = 0.000008
Epoch: 0613 loss = 0.000010
Epoch: 0614 loss = 0.000009
Epoch: 0615 loss = 0.000008
Epoch: 0616 loss = 0.000007
Epoch: 0617 loss = 0.000007
Epoch: 0618 loss = 0.000008
Epoch: 0619 loss = 0.000008
Epoch: 0620 loss = 0.000009
Epoch: 0621 loss = 0.000008
Epoch: 0622 loss = 0.000009
Epoch: 0623 loss = 0.000006
Epoch: 0624 loss = 0.000008
Epoch: 0625 loss = 0.000009
Epoch: 0626 loss = 0.000008
Epoch: 0627 loss = 0.000009
Epoch: 0628 loss = 0.000010
Epoch: 0629 loss = 0.000008
Epoch: 0630 loss = 0.000010
Epoch: 0631 loss = 0.000007
Epoch: 0632 loss = 0.000008
Epoch: 0633 loss = 0.000007
Epoch: 0634 loss = 0.000007
Epoch: 0635 loss = 0.000008
Epoch: 0636 loss = 0.000008
Epoch: 0637 loss = 0.000008
Epoch: 0638 loss = 0.000011
Epoch: 0639 loss = 0.000009
Epoch: 0640 loss = 0.000007
Epoch: 0641 loss = 0.000008
Epoch: 0642 loss = 0.000006
Epoch: 0643 loss = 0.000006
Epoch: 0644 loss = 0.000006
Epoch: 0645 loss = 0.000005
Epoch: 0646 loss = 0.000007
Epoch: 0647 loss = 0.000006
Epoch: 0648 loss = 0.000007
Epoch: 0649 loss = 0.000007
Epoch: 0650 loss = 0.000007
Epoch: 0651 loss = 0.000006
Epoch: 0652 loss = 0.000009
Epoch: 0653 loss = 0.000007
Epoch: 0654 loss = 0.000008
Epoch: 0655 loss = 0.000008
Epoch: 0656 loss = 0.000008
Epoch: 0657 loss = 0.000006
Epoch: 0658 loss = 0.000005
Epoch: 0659 loss = 0.000008
Epoch: 0660 loss = 0.000009
Epoch: 0661 loss = 0.000009
Epoch: 0662 loss = 0.000009
Epoch: 0663 loss = 0.000006
Epoch: 0664 loss = 0.000006
Epoch: 0665 loss = 0.000008
Epoch: 0666 loss = 0.000006
Epoch: 0667 loss = 0.000009
Epoch: 0668 loss = 0.000007
Epoch: 0669 loss = 0.000007
Epoch: 0670 loss = 0.000007
Epoch: 0671 loss = 0.000008
Epoch: 0672 loss = 0.000007
Epoch: 0673 loss = 0.000009
Epoch: 0674 loss = 0.000005
Epoch: 0675 loss = 0.000006
Epoch: 0676 loss = 0.000010
Epoch: 0677 loss = 0.000008
Epoch: 0678 loss = 0.000008
Epoch: 0679 loss = 0.000006
Epoch: 0680 loss = 0.000006
Epoch: 0681 loss = 0.000006
Epoch: 0682 loss = 0.000008
Epoch: 0683 loss = 0.000012
Epoch: 0684 loss = 0.000007
Epoch: 0685 loss = 0.000007
Epoch: 0686 loss = 0.000008
Epoch: 0687 loss = 0.000006
Epoch: 0688 loss = 0.000007
Epoch: 0689 loss = 0.000006
Epoch: 0690 loss = 0.000009
Epoch: 0691 loss = 0.000008
Epoch: 0692 loss = 0.000006
Epoch: 0693 loss = 0.000006
Epoch: 0694 loss = 0.000005
Epoch: 0695 loss = 0.000006
Epoch: 0696 loss = 0.000007
Epoch: 0697 loss = 0.000007
Epoch: 0698 loss = 0.000006
Epoch: 0699 loss = 0.000006
Epoch: 0700 loss = 0.000006
Epoch: 0701 loss = 0.000004
Epoch: 0702 loss = 0.000008
Epoch: 0703 loss = 0.000007
Epoch: 0704 loss = 0.000007
Epoch: 0705 loss = 0.000006
Epoch: 0706 loss = 0.000006
Epoch: 0707 loss = 0.000006
Epoch: 0708 loss = 0.000008
Epoch: 0709 loss = 0.000005
Epoch: 0710 loss = 0.000009
Epoch: 0711 loss = 0.000006
Epoch: 0712 loss = 0.000005
Epoch: 0713 loss = 0.000008
Epoch: 0714 loss = 0.000009
Epoch: 0715 loss = 0.000007
Epoch: 0716 loss = 0.000009
Epoch: 0717 loss = 0.000007
Epoch: 0718 loss = 0.000008
Epoch: 0719 loss = 0.000007
Epoch: 0720 loss = 0.000005
Epoch: 0721 loss = 0.000007
Epoch: 0722 loss = 0.000005
Epoch: 0723 loss = 0.000006
Epoch: 0724 loss = 0.000009
Epoch: 0725 loss = 0.000006
Epoch: 0726 loss = 0.000007
Epoch: 0727 loss = 0.000010
Epoch: 0728 loss = 0.000006
Epoch: 0729 loss = 0.000006
Epoch: 0730 loss = 0.000007
Epoch: 0731 loss = 0.000007
Epoch: 0732 loss = 0.000006
Epoch: 0733 loss = 0.000004
Epoch: 0734 loss = 0.000006
Epoch: 0735 loss = 0.000006
Epoch: 0736 loss = 0.000006
Epoch: 0737 loss = 0.000009
Epoch: 0738 loss = 0.000007
Epoch: 0739 loss = 0.000006
Epoch: 0740 loss = 0.000006
Epoch: 0741 loss = 0.000007
Epoch: 0742 loss = 0.000005
Epoch: 0743 loss = 0.000006
Epoch: 0744 loss = 0.000005
Epoch: 0745 loss = 0.000005
Epoch: 0746 loss = 0.000004
Epoch: 0747 loss = 0.000007
Epoch: 0748 loss = 0.000006
Epoch: 0749 loss = 0.000006
Epoch: 0750 loss = 0.000007
Epoch: 0751 loss = 0.000005
Epoch: 0752 loss = 0.000006
Epoch: 0753 loss = 0.000006
Epoch: 0754 loss = 0.000004
Epoch: 0755 loss = 0.000007
Epoch: 0756 loss = 0.000006
Epoch: 0757 loss = 0.000008
Epoch: 0758 loss = 0.000005
Epoch: 0759 loss = 0.000007
Epoch: 0760 loss = 0.000005
Epoch: 0761 loss = 0.000006
Epoch: 0762 loss = 0.000006
Epoch: 0763 loss = 0.000005
Epoch: 0764 loss = 0.000006
Epoch: 0765 loss = 0.000005
Epoch: 0766 loss = 0.000005
Epoch: 0767 loss = 0.000004
Epoch: 0768 loss = 0.000008
Epoch: 0769 loss = 0.000006
Epoch: 0770 loss = 0.000008
Epoch: 0771 loss = 0.000004
Epoch: 0772 loss = 0.000007
Epoch: 0773 loss = 0.000008
Epoch: 0774 loss = 0.000008
Epoch: 0775 loss = 0.000006
Epoch: 0776 loss = 0.000006
Epoch: 0777 loss = 0.000008
Epoch: 0778 loss = 0.000006
Epoch: 0779 loss = 0.000006
Epoch: 0780 loss = 0.000005
Epoch: 0781 loss = 0.000005
Epoch: 0782 loss = 0.000005
Epoch: 0783 loss = 0.000006
Epoch: 0784 loss = 0.000006
Epoch: 0785 loss = 0.000005
Epoch: 0786 loss = 0.000005
Epoch: 0787 loss = 0.000005
Epoch: 0788 loss = 0.000005
Epoch: 0789 loss = 0.000005
Epoch: 0790 loss = 0.000007
Epoch: 0791 loss = 0.000008
Epoch: 0792 loss = 0.000006
Epoch: 0793 loss = 0.000004
Epoch: 0794 loss = 0.000008
Epoch: 0795 loss = 0.000005
Epoch: 0796 loss = 0.000005
Epoch: 0797 loss = 0.000004
Epoch: 0798 loss = 0.000005
Epoch: 0799 loss = 0.000005
Epoch: 0800 loss = 0.000006
Epoch: 0801 loss = 0.000006
Epoch: 0802 loss = 0.000004
Epoch: 0803 loss = 0.000006
Epoch: 0804 loss = 0.000007
Epoch: 0805 loss = 0.000004
Epoch: 0806 loss = 0.000005
Epoch: 0807 loss = 0.000007
Epoch: 0808 loss = 0.000006
Epoch: 0809 loss = 0.000006
Epoch: 0810 loss = 0.000006
Epoch: 0811 loss = 0.000007
Epoch: 0812 loss = 0.000005
Epoch: 0813 loss = 0.000005
Epoch: 0814 loss = 0.000006
Epoch: 0815 loss = 0.000005
Epoch: 0816 loss = 0.000005
Epoch: 0817 loss = 0.000007
Epoch: 0818 loss = 0.000005
Epoch: 0819 loss = 0.000004
Epoch: 0820 loss = 0.000007
Epoch: 0821 loss = 0.000006
Epoch: 0822 loss = 0.000005
Epoch: 0823 loss = 0.000005
Epoch: 0824 loss = 0.000006
Epoch: 0825 loss = 0.000005
Epoch: 0826 loss = 0.000006
Epoch: 0827 loss = 0.000006
Epoch: 0828 loss = 0.000005
Epoch: 0829 loss = 0.000005
Epoch: 0830 loss = 0.000005
Epoch: 0831 loss = 0.000006
Epoch: 0832 loss = 0.000007
Epoch: 0833 loss = 0.000005
Epoch: 0834 loss = 0.000005
Epoch: 0835 loss = 0.000005
Epoch: 0836 loss = 0.000005
Epoch: 0837 loss = 0.000006
Epoch: 0838 loss = 0.000006
Epoch: 0839 loss = 0.000004
Epoch: 0840 loss = 0.000005
Epoch: 0841 loss = 0.000005
Epoch: 0842 loss = 0.000005
Epoch: 0843 loss = 0.000005
Epoch: 0844 loss = 0.000004
Epoch: 0845 loss = 0.000006
Epoch: 0846 loss = 0.000006
Epoch: 0847 loss = 0.000006
Epoch: 0848 loss = 0.000005
Epoch: 0849 loss = 0.000005
Epoch: 0850 loss = 0.000006
Epoch: 0851 loss = 0.000007
Epoch: 0852 loss = 0.000004
Epoch: 0853 loss = 0.000005
Epoch: 0854 loss = 0.000005
Epoch: 0855 loss = 0.000005
Epoch: 0856 loss = 0.000006
Epoch: 0857 loss = 0.000005
Epoch: 0858 loss = 0.000006
Epoch: 0859 loss = 0.000005
Epoch: 0860 loss = 0.000005
Epoch: 0861 loss = 0.000004
Epoch: 0862 loss = 0.000004
Epoch: 0863 loss = 0.000004
Epoch: 0864 loss = 0.000004
Epoch: 0865 loss = 0.000006
Epoch: 0866 loss = 0.000004
Epoch: 0867 loss = 0.000004
Epoch: 0868 loss = 0.000004
Epoch: 0869 loss = 0.000005
Epoch: 0870 loss = 0.000006
Epoch: 0871 loss = 0.000005
Epoch: 0872 loss = 0.000004
Epoch: 0873 loss = 0.000004
Epoch: 0874 loss = 0.000005
Epoch: 0875 loss = 0.000004
Epoch: 0876 loss = 0.000004
Epoch: 0877 loss = 0.000005
Epoch: 0878 loss = 0.000005
Epoch: 0879 loss = 0.000006
Epoch: 0880 loss = 0.000005
Epoch: 0881 loss = 0.000006
Epoch: 0882 loss = 0.000005
Epoch: 0883 loss = 0.000006
Epoch: 0884 loss = 0.000006
Epoch: 0885 loss = 0.000005
Epoch: 0886 loss = 0.000007
Epoch: 0887 loss = 0.000005
Epoch: 0888 loss = 0.000004
Epoch: 0889 loss = 0.000006
Epoch: 0890 loss = 0.000005
Epoch: 0891 loss = 0.000006
Epoch: 0892 loss = 0.000006
Epoch: 0893 loss = 0.000006
Epoch: 0894 loss = 0.000005
Epoch: 0895 loss = 0.000006
Epoch: 0896 loss = 0.000006
Epoch: 0897 loss = 0.000004
Epoch: 0898 loss = 0.000006
Epoch: 0899 loss = 0.000004
Epoch: 0900 loss = 0.000005
Epoch: 0901 loss = 0.000004
Epoch: 0902 loss = 0.000006
Epoch: 0903 loss = 0.000004
Epoch: 0904 loss = 0.000006
Epoch: 0905 loss = 0.000006
Epoch: 0906 loss = 0.000005
Epoch: 0907 loss = 0.000007
Epoch: 0908 loss = 0.000004
Epoch: 0909 loss = 0.000006
Epoch: 0910 loss = 0.000005
Epoch: 0911 loss = 0.000006
Epoch: 0912 loss = 0.000005
Epoch: 0913 loss = 0.000006
Epoch: 0914 loss = 0.000005
Epoch: 0915 loss = 0.000007
Epoch: 0916 loss = 0.000005
Epoch: 0917 loss = 0.000005
Epoch: 0918 loss = 0.000003
Epoch: 0919 loss = 0.000005
Epoch: 0920 loss = 0.000006
Epoch: 0921 loss = 0.000006
Epoch: 0922 loss = 0.000006
Epoch: 0923 loss = 0.000006
Epoch: 0924 loss = 0.000004
Epoch: 0925 loss = 0.000003
Epoch: 0926 loss = 0.000003
Epoch: 0927 loss = 0.000004
Epoch: 0928 loss = 0.000005
Epoch: 0929 loss = 0.000004
Epoch: 0930 loss = 0.000004
Epoch: 0931 loss = 0.000005
Epoch: 0932 loss = 0.000004
Epoch: 0933 loss = 0.000005
Epoch: 0934 loss = 0.000004
Epoch: 0935 loss = 0.000005
Epoch: 0936 loss = 0.000006
Epoch: 0937 loss = 0.000006
Epoch: 0938 loss = 0.000006
Epoch: 0939 loss = 0.000005
Epoch: 0940 loss = 0.000004
Epoch: 0941 loss = 0.000006
Epoch: 0942 loss = 0.000004
Epoch: 0943 loss = 0.000004
Epoch: 0944 loss = 0.000005
Epoch: 0945 loss = 0.000004
Epoch: 0946 loss = 0.000004
Epoch: 0947 loss = 0.000006
Epoch: 0948 loss = 0.000005
Epoch: 0949 loss = 0.000006
Epoch: 0950 loss = 0.000005
Epoch: 0951 loss = 0.000004
Epoch: 0952 loss = 0.000004
Epoch: 0953 loss = 0.000004
Epoch: 0954 loss = 0.000004
Epoch: 0955 loss = 0.000005
Epoch: 0956 loss = 0.000005
Epoch: 0957 loss = 0.000006
Epoch: 0958 loss = 0.000004
Epoch: 0959 loss = 0.000005
Epoch: 0960 loss = 0.000005
Epoch: 0961 loss = 0.000004
Epoch: 0962 loss = 0.000007
Epoch: 0963 loss = 0.000007
Epoch: 0964 loss = 0.000004
Epoch: 0965 loss = 0.000005
Epoch: 0966 loss = 0.000004
Epoch: 0967 loss = 0.000006
Epoch: 0968 loss = 0.000005
Epoch: 0969 loss = 0.000004
Epoch: 0970 loss = 0.000004
Epoch: 0971 loss = 0.000004
Epoch: 0972 loss = 0.000004
Epoch: 0973 loss = 0.000005
Epoch: 0974 loss = 0.000004
Epoch: 0975 loss = 0.000004
Epoch: 0976 loss = 0.000006
Epoch: 0977 loss = 0.000005
Epoch: 0978 loss = 0.000005
Epoch: 0979 loss = 0.000004
Epoch: 0980 loss = 0.000004
Epoch: 0981 loss = 0.000004
Epoch: 0982 loss = 0.000004
Epoch: 0983 loss = 0.000004
Epoch: 0984 loss = 0.000004
Epoch: 0985 loss = 0.000005
Epoch: 0986 loss = 0.000006
Epoch: 0987 loss = 0.000005
Epoch: 0988 loss = 0.000005
Epoch: 0989 loss = 0.000005
Epoch: 0990 loss = 0.000004
Epoch: 0991 loss = 0.000004
Epoch: 0992 loss = 0.000005
Epoch: 0993 loss = 0.000004
Epoch: 0994 loss = 0.000005
Epoch: 0995 loss = 0.000004
Epoch: 0996 loss = 0.000004
Epoch: 0997 loss = 0.000005
Epoch: 0998 loss = 0.000006
Epoch: 0999 loss = 0.000006
Epoch: 1000 loss = 0.000005

Testing

To be implemented…
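The original leaves testing unimplemented. As a starting point, below is a minimal sketch of greedy autoregressive decoding. It follows the decoder procedure described earlier:

1. Feed a start token.
2. Predict the next token.
3. Append that token to the decoder input.
4. Repeat until an end token appears or a length limit is reached.

The sketch is written in plain Python so it runs without a trained model. The names `greedy_decode`, `step_fn` and `toy_step`, and the token ids (1 = start, 2 = end), are hypothetical and chosen for illustration. They are not part of the article's code.

```python
from typing import Callable, List, Sequence


def greedy_decode(
    step_fn: Callable[[List[int]], Sequence[float]],
    start_id: int,
    end_id: int,
    max_len: int,
) -> List[int]:
    """Greedy autoregressive decoding (illustrative sketch).

    step_fn is a hypothetical stand-in for the trained model: it receives the
    decoder input generated so far and returns one score per vocabulary id
    for the NEXT token.
    """
    dec_input = [start_id]  # time step 1: the decoder sees only the start token
    output: List[int] = []
    for _ in range(max_len):
        scores = step_fn(dec_input)
        # Greedy choice: take the highest-scoring id (argmax)
        next_id = max(range(len(scores)), key=scores.__getitem__)
        if next_id == end_id:  # stop once the end token is predicted
            break
        output.append(next_id)
        dec_input.append(next_id)  # feed the prediction back as decoder input
    return output


# Toy stand-in for a trained model: vocabulary of 10 ids, 1 = start, 2 = end.
# It "predicts" a fixed target sequence, one token per step.
TARGET = [5, 6, 7, 8, 2]


def toy_step(prefix: List[int]) -> List[float]:
    scores = [0.0] * 10
    scores[TARGET[len(prefix) - 1]] = 1.0
    return scores


print(greedy_decode(toy_step, start_id=1, end_id=2, max_len=10))  # [5, 6, 7, 8]
```

With the PyTorch model trained above, `step_fn` would presumably do two things: wrap a forward pass on the encoder input plus the current decoder input, and return the logits at the last decoder position. That wiring depends on the model's actual interface and is left as an assumption here. Greedy decoding keeps only the single best token at each step. Beam search is a common alternative when output quality matters more than speed.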

Added: 2021-10-30 12:33:56  Updated: 2021-10-30 12:34:48