IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> LSTM实现股票预测 -> 正文阅读

[人工智能]LSTM实现股票预测

1、传统RNN的缺点

??RNN 面临的较大问题是无法解决长跨度依赖问题,即后面节点相对于跨度很大的前面时间节点的信息感知能力太弱。如下图中的两句话:左上角的句子中 sky 可以由较短跨度的词预测出来,而右下角句子中的 French 与较长跨度之前的 France 有关系,即长跨度依赖,比较难预测。
在这里插入图片描述

??图片来源:https://www.jianshu.com/p/9dc9f41f0b29

??长跨度依赖的根本问题在于,多阶段的反向传播后会导致梯度消失、梯度爆炸。可以使用梯度截断去解决梯度爆炸问题,但无法轻易解决梯度消失问题。

2、LSTM(长短时记忆网络)

2.1 原理

??为了解决长期依赖问题,长短记忆网络(Long Short Term Memory,LSTM)应运而生。之所以LSTM 能解决 RNN 的长期依赖问题,是因为 LSTM 使用门(gate)机制对信息的流通和损失进行控制
在这里插入图片描述
??如图上图所示,LSTM 引入了三个门限:输入门 i t i_t it?、遗忘门 f t f_t ft?、输出门 o t o_t ot?;引入了表征长期记忆的细胞态 C t C_t Ct?;引入了等待存入长期记忆的候选态 C t ~ \widetilde{C_t} Ct? ?;

??三个门限都是当前时刻的输入特征 x t x_t xt?和上个时刻的短期记忆 h t ? 1 h_{t-1} ht?1?的函数,分别表示为:

  • 输入门(门限) i t = σ ( W i . [ h t ? 1 , x t ] + b j ) i_t=\sigma (W_i.[h_{t-1},x_t]+b_j) it?=σ(Wi?.[ht?1?,xt?]+bj?),决定了多少比例的信息会被存入当前细胞态;
  • 遗忘门(门限) f t = σ ( W f . [ h t ? 1 , x t ] + b f ) f_t=\sigma (W_f.[h_{t-1},x_t]+b_f) ft?=σ(Wf?.[ht?1?,xt?]+bf?),将细胞态中的信息选择性的遗忘;
  • 输出门(门限): o t = σ ( W o . [ h t ? 1 , x t ] + b o ) o_t=\sigma (W_o.[h_{t-1},x_t]+b_o) ot?=σ(Wo?.[ht?1?,xt?]+bo?),将细胞态中的信息选择性的进行输出;

三个公式中 W i W_i Wi? W f W_f Wf? W o W_o Wo?是待训练参数矩阵, b i b_i bi? b f b_f bf? b o b_o bo?是待训练偏置项。 σ \sigma σsigmoid 激活函数,它可以使门限的范围在 0 到 1 之间。
??定义 h t h_t ht?为记忆体,它表征短期记忆,是当前细胞态经过输出门得到的:
记 忆 体 ( 短 期 记 忆 ) : h t = o t ? t a n h ( C t ) 记忆体(短期记忆): h_t=o_t*tanh(C_t) ):ht?=ot??tanh(Ct?)
??候选态表示归纳出的待存入细胞态的新知识,是当前时刻的输入特征 x t x_t xt?和上个时刻的短期记忆 h t ? 1 h_{t-1} ht?1?的函数:
候 选 态 ( 归 纳 出 的 新 知 识 ) : C t ~ = t a n h ( W c . [ h t ? 1 , x t ] + b c ) 候选态(归纳出的新知识):\widetilde{C_t}=tanh(W_c.[h_{t-1},x_t]+b_c) ()Ct? ?=tanh(Wc?.[ht?1?,xt?]+bc?)
??细胞态 C t C_t Ct?表示长期记忆,它等于上个时刻的长期记忆 C t ? 1 C_{t-1} Ct?1?通过遗忘门的值和当前时刻归纳出的新知识 C t ~ \widetilde{C_t} Ct? ?通过输入门的值之和:
细 胞 态 ( 长 期 记 忆 ) : C t = f t ? C t ? 1 + i t ? C t ~ 细胞态(长期记忆):C_t=f_t*C_{t-1}+i_t*\widetilde{C_t} ()Ct?=ft??Ct?1?+it??Ct? ?
在这里插入图片描述

2.2 举例

??当明确了这些概念,这里举一个简单的例子理解一下LSTM

??假设 LSTM 就是我们听老师讲课的过程,目前老师讲到了第 45 页 PPT。我们的脑袋里记住的内容,是 PPT 第 1 页到第 45 页的长期记忆 C t C_t Ct?。它由两部分组成:一部分是 PPT 第 1 页到第 44 页的内容,也就是上一时刻的长期记忆 C t ? 1 C_{t-1} Ct?1?。我们不可能一字不差的记住全部内容,会不自觉地忘记了一些,所以上个时刻的长期记忆 C t ? 1 C_{t-1} Ct?1?要乘以遗忘门,这个乘积项就表示留存在我们脑中的对过去的记忆;另一部分是当前我们归纳出的新知识 C t ~ \widetilde{C_t} Ct? ?,,它由老师正在讲的第 45 页 PPT(当前时刻的输入 x t x_t xt?)和第 44 页 PPT 的短期记忆留存(上一时刻的短期记忆 h t ? 1 h_{t-1} ht?1?)组成。将现在的记忆 C t ~ \widetilde{C_t} Ct? ?乘以输入门后与过去的记忆一同存储为当前的长期记忆 C t C_t Ct?。接下来,如果我们想把我们学到的知识(当前的长期记忆 C t C_t Ct?)复述给朋友,我们不可能一字不落的讲出来,所以 C t C_t Ct?需要经过输出门筛选后才成为了输出 h t h_t ht?
??当有多层循环网络时,第二层循环网络的输入 x t x_t xt?就是第一层循环网络的输出 h t h_t ht?,即输入第二层网络的是第一层网络提取出的精华。可以这么想,老师现在扮演的就是第一层循环网络,每一页 PPT 都是老师从一篇一篇论文中提取出的精华,输出给我们。作为第二层循环网络的我们,接收到的数据就是老师的长期记忆 C t C_t Ct?tanh 激活函数后乘以输出门提取出的短期记忆 h t h_t ht?

2.3 Tensorflow2描述LSTM层

tf.keras.layers.LSTM(
记忆体个数,
return_sequences=是否返回输出)

??return_sequences=True 各时间步输出ht
??return_sequences=False 仅最后时间步输出ht(默认)
??例如:

model = tf.keras.Sequential([
    LSTM(80, return_sequences=True),
    Dropout(0.2),
    LSTM(100),
    Dropout(0.2),
    Dense(1)
])

3、LSTM实现股票预测

3.1 数据源

??SH600519.csv 是用 tushare 模块下载的 SH600519 贵州茅台的日 k 线数据,本次例子中只用它的 C 列数据(如图 所示):
用连续 60 天的开盘价,预测第 61 天的开盘价。
在这里插入图片描述
在这里插入图片描述

3.2 代码实现

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dropout, Dense, LSTM
import matplotlib.pyplot as plt
import os
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error
import math

maotai = pd.read_csv('./SH600519.csv')  # 读取股票文件

training_set = maotai.iloc[0:2426 - 300, 2:3].values  # 前(2426-300=2126)天的开盘价作为训练集,表格从0开始计数,2:3 是提取[2:3)列,前闭后开,故提取出C列开盘价
test_set = maotai.iloc[2426 - 300:, 2:3].values  # 后300天的开盘价作为测试集

# 归一化
sc = MinMaxScaler(feature_range=(0, 1))  # 定义归一化:归一化到(0,1)之间
training_set_scaled = sc.fit_transform(training_set)  # 求得训练集的最大值,最小值这些训练集固有的属性,并在训练集上进行归一化
test_set = sc.transform(test_set)  # 利用训练集的属性对测试集进行归一化

x_train = []
y_train = []

x_test = []
y_test = []

# 测试集:csv表格中前2426-300=2126天数据
# 利用for循环,遍历整个训练集,提取训练集中连续60天的开盘价作为输入特征x_train,第61天的数据作为标签,for循环共构建2426-300-60=2066组数据。
for i in range(60, len(training_set_scaled)):
    x_train.append(training_set_scaled[i - 60:i, 0])
    y_train.append(training_set_scaled[i, 0])
# 对训练集进行打乱
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)
# 将训练集由list格式变为array格式
x_train, y_train = np.array(x_train), np.array(y_train)

# 使x_train符合RNN输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]。
# 此处整个数据集送入,送入样本数为x_train.shape[0]即2066组数据;输入60个开盘价,预测出第61天的开盘价,循环核时间展开步数为60; 每个时间步送入的特征是某一天的开盘价,只有1个数据,故每个时间步输入特征个数为1
x_train = np.reshape(x_train, (x_train.shape[0], 60, 1))
# 测试集:csv表格中后300天数据
# 利用for循环,遍历整个测试集,提取测试集中连续60天的开盘价作为输入特征x_train,第61天的数据作为标签,for循环共构建300-60=240组数据。
for i in range(60, len(test_set)):
    x_test.append(test_set[i - 60:i, 0])
    y_test.append(test_set[i, 0])
# 测试集变array并reshape为符合RNN输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]
x_test, y_test = np.array(x_test), np.array(y_test)
x_test = np.reshape(x_test, (x_test.shape[0], 60, 1))

model = tf.keras.Sequential([
    LSTM(80, return_sequences=True),
    Dropout(0.2),
    LSTM(100),
    Dropout(0.2),
    Dense(1)
])

model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
              loss='mean_squared_error')  # 损失函数用均方误差
# 该应用只观测loss数值,不观测准确率,所以删去metrics选项,一会在每个epoch迭代显示时只显示loss值

checkpoint_save_path = "./checkpoint/LSTM_stock.ckpt"

if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True,
                                                 monitor='val_loss')

history = model.fit(x_train, y_train, batch_size=64, epochs=50, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])

model.summary()

file = open('./weights.txt', 'w')  # 参数提取
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

loss = history.history['loss']
val_loss = history.history['val_loss']

plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

################## predict ######################
# 测试集输入模型进行预测
predicted_stock_price = model.predict(x_test)
# 对预测数据还原---从(0,1)反归一化到原始范围
predicted_stock_price = sc.inverse_transform(predicted_stock_price)
# 对真实数据还原---从(0,1)反归一化到原始范围
real_stock_price = sc.inverse_transform(test_set[60:])
# 画出真实数据和预测数据的对比曲线
plt.plot(real_stock_price, color='red', label='MaoTai Stock Price')
plt.plot(predicted_stock_price, color='blue', label='Predicted MaoTai Stock Price')
plt.title('MaoTai Stock Price Prediction')
plt.xlabel('Time')
plt.ylabel('MaoTai Stock Price')
plt.legend()
plt.show()

##########evaluate##############
# calculate MSE 均方误差 ---> E[(预测值-真实值)^2] (预测值减真实值求平方后求均值)
mse = mean_squared_error(predicted_stock_price, real_stock_price)
# calculate RMSE 均方根误差--->sqrt[MSE]    (对均方误差开方)
rmse = math.sqrt(mean_squared_error(predicted_stock_price, real_stock_price))
# calculate MAE 平均绝对误差----->E[|预测值-真实值|](预测值减真实值求绝对值后求均值)
mae = mean_absolute_error(predicted_stock_price, real_stock_price)
print('均方误差: %.6f' % mse)
print('均方根误差: %.6f' % rmse)
print('平均绝对误差: %.6f' % mae)

??LSTM 股票预测 loss 曲线:
在这里插入图片描述
??LSTM 股票预测曲线:
在这里插入图片描述
??LSTM股票预测评价指标
在这里插入图片描述
??模型摘要:
在这里插入图片描述

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-24 18:10:12  更:2022-05-24 18:12:10 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 5:25:25-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码