IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> Pytorch——循环神经网络预测正弦函数 -> 正文阅读

[人工智能]Pytorch——循环神经网络预测正弦函数

循环神经网络

在这里插入图片描述
RNN单元会自我更新之后再输出,其前向传播函数如下:
在这里插入图片描述
在这里插入图片描述

Pytorch实现RNN单元的两种方式

nn.RNN()

初始化:
在这里插入图片描述
使用:
在这里插入图片描述

  • x给的是x不是x(t),就是把x一次全部喂进去
    X.shape:[seq_len, batch_size, input_size]
    input_size就是word_vec
  • h0可以给也可以不给,不给的话就默认给0
  • ht:最后一个时刻的所有层cnn的状态
    h.shape:[num_layers, batch_size, h_dim]
    h_dim就是hidden_len
  • out:所有时刻的最后一层cnn的状态
    out_size:[seq_len, batch_size, h_dim]

在这里插入图片描述在这里插入图片描述

单层的RNN网络:

import torch
import torch.nn as nn

rnn = nn.RNN(input_size=100, hidden_size=20, num_layers=1)
print(rnn)  # RNN(100, 20)

x = torch.randn(10, 3, 100)  # [seq_length, b, input_size]
out, h = rnn(x, torch.zeros(1, 3, 20))  # [x, h0],不给h0参数也行
print(out.shape)  # torch.Size([10, 3, 20]) 所有时间的状态
print(h.shape)  # torch.Size([1, 3, 20]) 最后一个时刻的状态

多层RNN网络:

import torch
import torch.nn as nn

rnn = nn.RNN(input_size=100, hidden_size=20, num_layers=4)
print(rnn)  # RNN(100, 20, num_layers=4)

x = torch.randn(10, 3, 100)  # [seq_length, b, input_size]
out, h = rnn(x)  # [x, h0],不给h0参数也行
print(out.shape)  # torch.Size([10, 3, 20]) 所有时间戳上最后一层rnn的状态
print(h.shape)  # torch.Size([4, 3, 20]) 最后时间戳上所有rnn层的状态

nn.RNNCell()

在这里插入图片描述
在这里插入图片描述

  • x给的是x(t)。
    x(t):[batch_size, input_size]
  • ht:t时刻某一层的状态
    [batch_size, hidden_size]

单层RNN:

x = torch.randn(10, 3, 100)  # [句子长度, b, 每个单词用几维向量表示]
cell1 = nn.RNNCell(100,20)  # input_size, hidden_size, num_layer = 1
h1 = torch.zeros(3, 20)  # batch_size, hidden_size
for xt in x:
    h1 = cell1(xt,h1)
print(h1.shape)  # torch.Size([3, 20])

双层RNN:

x = torch.randn(10, 3, 100)  # [句子长度, b, 每个单词用几维向量表示]
cell1 = nn.RNNCell(100,30)  # [input_size, hidden_size, num_layer = 1]
cell2 = nn.RNNCell(30,20)
h1 = torch.zeros(3, 30)  # [batch_size, hidden_size]
h2 = torch.zeros(3, 20)
for xt in x:
    h1 = cell1(xt,h1)
    h2 = cell2(h1,h2)
print(h2.shape)  # torch.Size([3, 20])

RNN预测正弦函数

在这里插入图片描述

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2021/7/31 21:18
# @Author  : Liu Lihao
# @File    : test.py

import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from matplotlib import pyplot as plt


'''超参数'''
num_time_steps = 50
input_size = 1
hidden_size = 16
output_size = 1
num_layers = 1
lr=0.01


'''定义网络结构'''
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True  # [batch, seq, feature]
        )

        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden_prev):  # x, h0
        # x: [batch_size, seq_len, input_size]
        # hidden_prev: [num_layers, batch_size, h_dim]  h_dim就是hidden_size
        out, hidden_prev = self.rnn(x, hidden_prev)
        # out: [batch_size, seq_len, h_dim]
        # hidden_prev: [num_layers, batch_size, h_dim]

        # [batch, seq, hidden_size] => [batch * seq, hidden_size]
        out = out.view(-1, hidden_size)

        # [batch * seq_len, hidden_size] => [batch * seq_len, output_size]
        out = self.linear(out)

        # [batch * seq_len, output_size] => [1, batch * seq_len, output_size]  这里batch=1
        out = out.unsqueeze(dim=0)

        return out, hidden_prev


'''声明网络。loss,优化器'''
model = Net()
loss_function = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr)

hidden_prev = torch.zeros(num_layers, 1, hidden_size)


'''训练'''
for iter in range(6000):
    start = np.random.randint(3, size=1)[0]  # 开始的时刻,会在0-3之随机初始化。
    time_steps = np.linspace(start, start + 10, num_time_steps)  # 训练的数据:从start时刻到start+10时刻
    data = np.sin(time_steps)
    data = data.reshape(num_time_steps, 1)

    # 只往后预测一个点
    x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)  # 去掉最后一个元素,作为输入
    # x.shape:  torch.Size([1, 49, 1]) [batch_size, seq_len, input_size]

    y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)  # 去掉第一个元素,作为label
    # y.shape:  torch.Size([1, 49, 1])

    # 神经网络前传
    output, hidden_prev = model(x, hidden_prev)

    hidden_prev = hidden_prev.detach()  # 将variable参数从网络中隔离开,不参与参数更新。

    loss = loss_function(output, y)
    model.zero_grad()
    loss.backward()
    optimizer.step()

    if iter % 100 == 0:
        print("Iteration: {} loss {}".format(iter, loss.item()))


'''测试'''
start = np.random.randint(3, size=1)[0]
time_steps = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_steps)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)

predictions = []
input = x[:, 0, :]
for _ in range(x.shape[1]):
  input = input.view(1, 1, 1)
  pred, hidden_prev = model(input, hidden_prev)
  input = pred
  predictions.append(pred.detach().numpy().ravel()[0])


'''绘图'''
x = x.data.numpy().ravel()
y = y.data.numpy()
plt.scatter(time_steps[:-1], x.ravel(), s=90)
plt.plot(time_steps[:-1], x.ravel())

plt.scatter(time_steps[1:], predictions)
plt.show()
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-08 11:20:38  更:2021-08-08 11:21:24 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/12 4:00:10-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码