IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 使用DeepAR实现股价预测 -> 正文阅读

[人工智能]使用DeepAR实现股价预测

使用DeepAR实现股价预测


以往的RNN时间序列预测往往是强调一支股票的股价预测,当提取的一支其他股票的特征时,用于另一支股票预测时就显得捉襟见肘了;当需要对多只股票进行训练及预测时,通常的做法是将他们归类,再进行分别的预测及训练,更重要的是,以往的RNN神经网络(如LSTM等),给出的都是单点预测,在结果服从连续分布的情形下,单点预测的概率其实是0的,我们更希望知道结果的走向,或者框定一个结果走向的范围;

本实验采用DeepAR这个新兴的时间序列预测算法,对78支上市公司股票价格进行训练,训练好的结果可以应用在任意一支股票的预测上(但是本文未给出相关过程),测试集上的表现比较理想。本文仅供学习参考,不作为投资依据;完全原创,转载请注明出处

本文的数据采用了Tushare的大数据接口,感谢Tushare的开发者,为Quanters提供了持续精良的服务

本文的模型采用了mxnet的Deepar模型, deepar模型已经为我们封装好了大多数处理方法,这使得我们的分析过程更加简单快捷,在此一并感谢

import pandas as pd
import tushare as ts
import numpy as np
# 初始化pro接口(该tokens请在tushare个人主页获取)
pro = ts.pro_api('xxx')
np.random.seed(42)

获取股票列表


# 拉取数据
df = pro.stock_basic(**{
    "ts_code": "",
    "name": "",
    "exchange": "",
    "market": "",
    "is_hs": "",
    "list_status": "L",
    "limit": "",
    "offset": ""
}, fields=[
    "ts_code",
    "symbol",
    "name",
    "area",
    "industry",
    "market",
    "list_date"
])
df.to_csv('./Stock-data/股票代码.csv')
df.head()

从众多股票中采样100支

stock_code = pd.read_csv('./Stock-data/股票代码.csv')

name = []
ts_code = []
i = 0
while i < 100:
    sample = stock_code.sample()
    if sample['list_date'].values < 20150731 and 'ST' not in sample['name'].values[0]:
        ts_code.append(sample['ts_code'].values[0])
        name.append(sample['name'].values[0])
        i += 1
print(len(name),name)

日期处理函数

def deal_date(date):
        temp = [date[0:4],date[4:6],date[6:]]
        new_date = '-'.join(temp)
        return new_date

拉取等长度的股票,并保存

stock_list = []
for i,j in zip(name,ts_code):
    # 拉取数据
    df = pro.daily(**{
        "ts_code": f"{j}",
        "trade_date": "",
        "start_date": "20190731",
        "end_date": "20220404",
        "offset": "",
        "limit": ""
    }, fields=[
        "ts_code",
        "trade_date",
        "open",
        "high",
        "low",
        "close",
        "pre_close",
        "change",
        "pct_chg",
        "vol",
        "amount"
    ])
    if len(df) == 649:
        df['Date'] = df['trade_date'].apply(deal_date)
        stock_list.append(i)
        df['name'] = f'{i}'
        df.to_csv(f'./Stock-data/{i}.csv')
print(stock_list)
df.head()

各指标解释

  • open 开盘价
  • high 最高价
  • low 最低价
  • close 收盘价
  • pre_close 昨收价
  • change 涨跌额
  • pct_chg 涨跌幅
  • vol 成交量
  • amount 成交额

我将采用open,high,low,close,change,pct_chg,vol,amount及公司所属行业进行时间序列预测

%matplotlib inline
import mxnet as mx
from mxnet import gluon
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json
import os
from tqdm.autonotebook import tqdm
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

预测区间长度及上下文选取

prediction_length = 10
context_length = 20
stock = ['荣盛石化', '新联电子', '栖霞建设', '星徽股份', '中国武夷', 
'飞凯材料', '河钢资源', '平高电气', '新北洋', '亚太科技', '杭州高新', 
'海立股份', '燕塘乳业', '杰瑞股份', '广电电气', '赛摩智能', '成都路桥', 
'恒邦股份', '石化油服', '金隅集团', '青龙管业', '同德化工', '科华数据', 
'中国东航', '澳柯玛', '亚太药业', '冠城大通', '白云机场', '华东医药', 
'全筑股份', '菲利华', '和而泰', '潮宏基', '岭南股份', '索菲亚', '长江证券', 
'炬华科技', '嘉事堂', '西藏珠峰', '聆达股份', '北大荒', '七匹狼', '先河环保', 
'中国汽研', '鸿博股份', '合金投资', '华银电力', '世纪瑞尔', '东方日升', '新开普', 
'亚光科技', '电科院', '粤电力A', '东方雨虹', '普莱柯', '上海机电', '天利科技', 
'奥维通信', '华邦健康', '春秋航空', '杰瑞股份', '海峡股份', '京蓝科技', '中海油服', 
'温州宏丰', '御银股份', '芒果超媒', '太平洋', '泰豪科技', '申达股份', '众合科技', 
'华帝股份', '财信发展', '大金重工', '协鑫集成', '保利联合', '平高电气', '黄河旋风', 
'凌云股份',]
stock = list(set(stock))
data = pd.read_csv('./Stock-data/上海梅林.csv', index_col=False,usecols=['Date','high','low','open','close','change','pct_chg','vol','amount','name'])
print(len(data))
for i in stock:
    temp = pd.read_csv(f'./Stock-data/{i}.csv', index_col=False,usecols=['Date','high','low','open','close','change','pct_chg','vol','amount','name'])
    data = pd.concat([data,temp],ignore_index=True)
print(len(data))
data.head(10)

给这78支股票所在行业进行归类

total = data.copy()
stock_list = sorted(list(set(total["name"])))
date_list = sorted(list(set(total["Date"])))
data_dic = {"name": stock_list}
industry = {}
for i in stock_list:
    temp = stock_code[stock_code['name'] == i]['industry'].values[0]
    if temp not in industry:
        industry[temp] = [i]
    else:
        industry[temp].append(i)
industry
stat_cat_features = []
company_ind = {}
for i,key in enumerate(industry):
    for com in industry[key]:
        company_ind[com] = i
cat_cardinality = [i+1]
print(company_ind)
for i in stock_list:
    stat_cat_features.append([company_ind[i]])
print(stat_cat_features)

目标变量处理

stock_list = sorted(list(set(total["name"])))
date_list = sorted(list(set(total["Date"])))
data_dic = {"name": stock_list}
for date in date_list:
    tmp = total[total["Date"]==date][["name", "Date", "close"]]
    tmp = tmp.pivot(index="name", columns="Date", values="close")
    tmp_values = tmp[date].values
    data_dic[date] = tmp_values
new_df = pd.DataFrame(data_dic)
new_df.head()

协变量处理

def deal_cov_variables(date_list,var_name):
    feature_dict = {}
    for date in date_list:
        tmp = total[total["Date"]==date][["name", "Date", var_name]]
        tmp = tmp.pivot(index="name", columns="Date", values=var_name)
        tmp_values = tmp[date].values
        feature_dict[date] = tmp_values
    feature_df = pd.DataFrame(feature_dict)
    return feature_df
cov_variables = ['high','low','open','close','change','pct_chg','vol','amount']
feature_df_list = []

for i in cov_variables:
    feature_df_list.append(deal_cov_variables(date_list,i))

feature_df_list[0].head()

协变量归一化操作

def min_max_scale(lst):
    '''
    # 基于日期级别的归一化
    :input shape (bank_num,days)
    :output shape (bank_num,days)
    '''
    new = []
    for i in range(len(lst[0])):
        minimum = min(lst[:,i])
        maximum = max(lst[:,i])
        new.append((lst[:,i] - minimum) / (maximum - minimum))
    return np.array(new).T
dynamic_feats = []
for i in range(len(feature_df_list)):
    one_feature = min_max_scale(np.array(feature_df_list[i]))
    dynamic_feats.append(one_feature)
print(one_feature.shape)
dynamic_feats = np.array(dynamic_feats).reshape(-1,len(feature_df_list),len(date_list))
print(dynamic_feats.shape) # (stock_num, feature_num, date_num)

训练、测试数据划分

from gluonts.dataset.common import ListDataset
from gluonts.dataset.field_names import FieldName

# test_target_values是649天的实际结果y
train_df = new_df.drop(["name"], axis=1).values
train_df.reshape(-1,len(date_list))
test_target_values = train_df.copy()
print(len(train_df[0]))
# train_target_values是639天的实际结果y,不能让模型训练到后10天,这样才能看出效果 (将649天shift10天)
train_target_values = [ts[:-prediction_length] for ts in train_df]
print(len(train_target_values[0]))
start_date = [pd.Timestamp("2019-07-31", freq='B') for _ in range(len(new_df))]
train_ds = ListDataset([
    {
        FieldName.TARGET: target,
        FieldName.START: start,
        FieldName.FEAT_DYNAMIC_REAL: dynamic_feat[:,:-prediction_length],
        FieldName.FEAT_STATIC_CAT:cat_feature,
    }
    for (target, start,dynamic_feat,cat_feature) in zip(train_target_values,
                                        start_date,
                                        dynamic_feats,
                                        stat_cat_features)
], freq="1B")

test_ds = ListDataset([
    {
        FieldName.TARGET: target,
        FieldName.START: start,
        FieldName.FEAT_DYNAMIC_REAL: dynamic_feat,
        FieldName.FEAT_STATIC_CAT:cat_feature,
    }
    for (target, start,dynamic_feat,cat_feature) in zip(test_target_values,
                                        start_date,
                                        dynamic_feats,
                                        stat_cat_features)
], freq="1B")
sample_trian = next(iter(train_ds))

训练模型

from gluonts.model.deepar import DeepAREstimator
from gluonts.mx.distribution.gaussian import GaussianOutput
from gluonts.mx.trainer import Trainer

n = 100
estimator = DeepAREstimator(
    prediction_length=prediction_length,
    context_length=context_length,
    freq="1B",
    distr_output = GaussianOutput(),
    use_feat_dynamic_real=True,
    dropout_rate=0.1,
    use_feat_static_cat=True,
    cardinality=cat_cardinality,
    trainer=Trainer(
        learning_rate=1e-3,
        epochs=n,
        num_batches_per_epoch=50,
        batch_size=32
    )
)
predictor = estimator.train(train_ds)

预测过程

from gluonts.evaluation.backtest import make_evaluation_predictions

forecast_it, ts_it = make_evaluation_predictions(
    dataset=test_ds,
    predictor=predictor,
    num_samples=100
)

print("Obtaining time series conditioning values ...")
tss = list(tqdm(ts_it, total=len(test_ds)))
print("Obtaining time series predictions ...")
forecasts = list(tqdm(forecast_it, total=len(test_ds)))

模型评估

from gluonts.evaluation import Evaluator


class CustomEvaluator(Evaluator):

    def get_metrics_per_ts(self, time_series, forecast):
        successive_diff = np.diff(time_series.values.reshape(len(time_series)))
        successive_diff = successive_diff ** 2
        successive_diff = successive_diff[:-prediction_length]
        denom = np.mean(successive_diff)
        pred_values = forecast.samples.mean(axis=0)
        true_values = time_series.values.reshape(len(time_series))[-prediction_length:]
        num = np.mean((pred_values - true_values) ** 2)
        rmsse = num / denom
        metrics = super().get_metrics_per_ts(time_series, forecast)
        metrics["RMSSE"] = rmsse
        return metrics

    def get_aggregate_metrics(self, metric_per_ts):
        wrmsse = metric_per_ts["RMSSE"].mean()
        agg_metric, _ = super().get_aggregate_metrics(metric_per_ts)
        agg_metric["MRMSSE"] = wrmsse
        return agg_metric, metric_per_ts


evaluator = CustomEvaluator(quantiles=[0.5, 0.67, 0.95, 0.99])
agg_metrics, item_metrics = evaluator(iter(tss), iter(forecasts), num_series=len(test_ds))
print(json.dumps(agg_metrics, indent=4))

结果查看

a = forecasts[0]
print(a.mean)
print(a.quantile(0.95))
import warnings
warnings.filterwarnings("ignore")
plot_log_path = "./plots/"
directory = os.path.dirname(plot_log_path)
if not os.path.exists(directory):
    os.makedirs(directory)
    

def plot_prob_forecasts(ts_entry, forecast_entry, path, sample_id, name, inline=True):
    plot_length = 150
    prediction_intervals = (50, 67, 95, 99)
    legend = ["observations", "median prediction"] + [f"{k}% prediction interval" for k in prediction_intervals][::-1]

    _, ax = plt.subplots(1, 1, figsize=(10, 7))
    ts_entry[-plot_length:].plot(ax=ax)
    forecast_entry.plot(prediction_intervals=prediction_intervals, color='g')
    ax.axvline(ts_entry.index[-prediction_length], color='r')
    plt.legend(legend, loc="upper left")
    plt.title(f'{name} Price series and predict results')
    if inline:
        plt.show()
        plt.clf()
    else:
        plt.savefig('{}forecast_{}.pdf'.format(path, sample_id))
        plt.close()

print("Plotting time series predictions ...")
for i in tqdm(range(20,30)):
    ts_entry = tss[i]
    forecast_entry = forecasts[i]
    name = stock_list[i]
    plot_prob_forecasts(ts_entry, forecast_entry, plot_log_path, i, name)

绘图结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

欢迎交流,实验不易,转载请注明出处!!!

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-06 23:10:15  更:2022-04-06 23:13:42 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/8 4:48:26-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码