IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 使用RNN模型构建字符串批量转换功能seq2seq -> 正文阅读

[人工智能]使用RNN模型构建字符串批量转换功能seq2seq

使用RNN Module构建的一个字符串转换功能:

import torch
import torch.optim as optim

class Model(torch.nn.Module):
    """
    RNN
    """
    def __init__(self, input_size, hidden_size, batch_size,num_layers):
        super(Model, self).__init__()
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_layers = num_layers
        #反复使用rnncell, 权重共享
        self.rnn = torch.nn.RNN(
                input_size=self.input_size,
                hidden_size=self.hidden_size)
        

    def forward(self,input, **args):
        if 'batch_size' in args:
            self.batch_size = args['batch_size']
        hidden = torch.zeros(
                self.num_layers, 
                self.batch_size,
                self.hidden_size)
        out, _= self.rnn(input, hidden)
        return out.view(-1, self.hidden_size)



if __name__ == "__main__":

    num_layers = 1 # RNN层数

    #idx2char = ['e','h','l','o','n','a','b','c'] #构建词典
    idx2char = [chr(x) for x in range(ord('A'),ord('Z')+1)] + [chr(x) for x in range(ord('0'),ord('9')+1)] + ['+', '-', '*', '/', '=', ' ']

    input_size = len(idx2char) #输入序列每一元素的特征维度
    hidden_size = len(idx2char) #隐藏状态维度

    print(idx2char)

    #输入与标签数据
    #x_data = [1,0,5,2,2,3,2,2,4,5] #hellollnnaa
    #y_data = [3,1,4,2,3,2,3,3,5,4] #ohlolooaann

    x_data = ["xihuanliaojiexuexilehuatuan"]
    y_data = ["hifuanliaogaihaxolaofuatuen"]

    batch_size = len(x_data) #批次大小
    seq_len = len(x_data[0]) #每一批量的序列长度


    x_data = [idx2char.index(x) for item in x_data for x in item.upper() ]
    y_data = [idx2char.index(x)  for item in y_data for x in item.upper()]
    print(x_data, y_data)

    #词典转换为one-hot对照表
    one_hot_lookup = torch.diag(torch.ones(input_size,dtype=torch.int32))
    """
    one_hot_lookup = [
            [1,0,0,0,0,0],
            [0,1,0,0,0,0],
            [0,0,1,0,0,0],
            [0,0,0,1,0,0],
            [0,0,0,0,1,0],
            [0,0,0,0,0,1],
            ]
    #x_one_hot = [one_hot_lookup[x] for x in x_data]
    """

    x_one_hot = one_hot_lookup[x_data]
    print(x_one_hot)

    inputs = (x_one_hot.float()).view(seq_len, batch_size, input_size)
    labels = torch.LongTensor(y_data)


    model = Model(input_size, hidden_size, batch_size, num_layers)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),lr=0.01)

    #测试过程
    for epoch in range(100):
        optimizer.zero_grad()#梯度数据重置
        outputs = model(inputs)

        loss = criterion(outputs, labels)
        #反馈
        loss.backward()#反向传播
        #更新
        optimizer.step()#更新参数
        
        _,idx =outputs.max(dim=1)
        print("EPOCH: ", epoch+1, loss.item(), end=" ")
        print("Predicted String: ", end=" ")
        print("".join([idx2char[x] for x in idx]))


    batch_size = 1
    myinput = input("请输入你要转换的序列:")
    test_x_data = [idx2char.index(x) for x in myinput.upper()]

    #新数据
    with torch.no_grad(): #无需计算梯度
        x_one_hot = one_hot_lookup[test_x_data]
        inp = (x_one_hot.float()).view(len(test_x_data), batch_size, input_size)
        outputs = model(inp, **{"batch_size":batch_size})
        _,idx = outputs.max(dim=1)
        print("".join([idx2char[x] for x in idx]),end="")

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-29 10:15:46  更:2021-09-29 10:17:40 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/27 12:41:47-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码