IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 四十五.门控循环单元(GRU)简介和keras实现 -> 正文阅读

[人工智能]四十五.门控循环单元(GRU)简介和keras实现

1.网络结构

U是LSTM的一种变体,可以说是简化版本的LSTM,但是预测效果也很不错,因此更常用。
GRU使记忆体 h t h^{t} ht融合了长期记忆和短期记忆。

(1)记忆体 h t h^{t} ht

h t = z t ⊙ h t ? 1 + ( 1 ? z t ) ⊙ h t ^ h^{t}=z^{t}\odot h^{t-1}+(1-z^{t})\odot \widehat{h^{t}} ht=ztht?1+(1?zt)ht
其中, z t z^{t} zt为更新门,控制当前状态需要从历史状态中保留多少信息,以及需要从候选状态中接受多少新信息。

(2)候选状态 h t ^ \widehat{h^{t}} ht

h t ^ = tanh ? ( W h x t + U ( r t ⊙ h t ? 1 ) + b h ) \widehat{h^{t}}=\tanh(\mathbf{W}_{h}x^{t}+\mathbf{U}(r^{t}\odot h^{t-1})+\mathbf{b}_{h}) ht =tanh(Wh?xt+U(rtht?1)+bh?)
其中, r t r^{t} rt t t t时刻的重置门。

(3)重置门和更新门

重置门用来控制候选状态 h t ^ \widehat{h^{t}} ht 的计算是否依赖上一时刻的记忆体:
r t = s i g m o i d ( W r x t + U r h t ? 1 + b r ) r^{t}=sigmoid(\mathbf{W}_{r}x^{t}+\mathbf{U}_{r}h^{t-1}+\mathbf{b}_{r}) rt=sigmoid(Wr?xt+Ur?ht?1+br?)
从计算公式可知, r t ∈ [ 0 , 1 ] r^{t}\in [0,1] rt[0,1]。当 r t = 0 r^{t}=0 rt=0时,候选状态和历史状态无关;当 r t = 1 r^{t}=1 rt=1时,候选状态和简单循环网络一致。
更新门的作用和求解方式与重置门相同,计算过程如下:
z t = s i g m o i d ( W z x t + U z h t ? 1 + b z ) z^{t}=sigmoid(\mathbf{W}_{z}x^{t}+\mathbf{U}_{z}h^{t-1}+\mathbf{b}_{z}) zt=sigmoid(Wz?xt+Uz?ht?1+bz?)

2.前向传播过程

(1)计算更新门和重置门:
r t = s i g m o i d ( W r x t + U r h t ? 1 + b r ) z t = s i g m o i d ( W z x t + U z h t ? 1 + b z ) r^{t}=sigmoid(\mathbf{W}_{r}x^{t}+\mathbf{U}_{r}h^{t-1}+\mathbf{b}_{r})\\ z^{t}=sigmoid(\mathbf{W}_{z}x^{t}+\mathbf{U}_{z}h^{t-1}+\mathbf{b}_{z}) rt=sigmoid(Wr?xt+Ur?ht?1+br?)zt=sigmoid(Wz?xt+Uz?ht?1+bz?)
(2)通过重置门和上一时刻记忆体,更新候选状态:
h t ^ = tanh ? ( W h x t + U ( r t ⊙ h t ? 1 ) + b h ) \widehat{h^{t}}=\tanh(\mathbf{W}_{h}x^{t}+\mathbf{U}(r^{t}\odot h^{t-1})+\mathbf{b}_{h}) ht =tanh(Wh?xt+U(rtht?1)+bh?)
(3)计算当前时刻记忆体:
h t = z t ⊙ h t ? 1 + ( 1 ? z t ) ⊙ h t ^ h^{t}=z^{t}\odot h^{t-1}+(1-z^{t})\odot \widehat{h^{t}} ht=ztht?1+(1?zt)ht

3.keras+GPU茅台股票预测

#导入工具包
import pandas as pd
maotai=pd.read_csv('./SH600519.csv')
training_set = maotai.iloc[0:2126,2:3].values
test_set = maotai.iloc[2126:,2:3].values
print(training_set.shape,test_set.shape)
#数据归一化
from sklearn.preprocessing import MinMaxScaler
print(training_set.max(),training_set.min())
sc=MinMaxScaler(feature_range=(0,1))
training_set=sc.fit_transform(training_set)
test_set=sc.fit_transform(test_set)
print(training_set.max(),training_set.min())
#划分数据集测试集
#调整数据维度
import numpy as np
import tensorflow as tf
x_train,y_train,x_test,y_test=[],[],[],[]
for i in range(60,len(training_set)):
    x_train.append(training_set[i-60:i,0])
    y_train.append(training_set[i,0])
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)
x_train,y_train = np.array(x_train),np.array(y_train)
x_train = np.reshape(x_train, (x_train.shape[0], 60, 1))
for i in range(60, len(test_set)):
    x_test.append(test_set[i - 60:i, 0])
    y_test.append(test_set[i, 0])
x_test, y_test = np.array(x_test), np.array(y_test)
x_test = np.reshape(x_test, (x_test.shape[0], 60, 1))
#搭建GRU网络
from tensorflow.keras.layers import GRU,Dropout,Dense
model = tf.keras.Sequential([
    GRU(80,return_sequences=True),
    Dropout(0.2),
    GRU(100),
    Dropout(0.2),
    Dense(1)
])
#配置网络
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
              loss='mean_squared_error')
#训练网络
history = model.fit(x_train, y_train, batch_size=64, epochs=50, 
                    validation_data=(x_test, y_test), validation_freq=1)
#绘制Loss曲线
loss =  history.history['loss']
val_loss = history.history['val_loss']
plt.plot(loss,label='Training Loss')
plt.plot(val_loss,label='Validation Loss')
plt.legend()
plt.title('Loss')
plt.show()
#预测test并且和真实标签对比
predict_price = model.predict(x_test)
predict_price = sc.inverse_transform(predict_price)
real_price = sc.inverse_transform(test_set[60:])
plt.plot(real_price, color='red', label='MaoTai Stock Price')
plt.plot(predict_price, color='blue', label='Predicted MaoTai Stock Price')
plt.title('MaoTai Stock Price Prediction')
plt.xlabel('Time')
plt.ylabel('MaoTai Stock Price')
plt.legend()
plt.show()
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-08 10:43:36  更:2021-09-08 10:44:10 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 19:55:02-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码