读取pandas数据 官方文档:https://www.backtrader.com/docu/pandas-datafeed/pandas-datafeed/
DataFeed 开发 官方文档:https://www.backtrader.com/docu/datafeed-develop-general/datafeed-develop-general/
SQL读取类
官方暂时没有一键读取的功能类,因此需要自己写,一个简单的例子如下,运用请参考最下面的示例代码:
class SQLiteData(DataBase):
"""自定义的SQL lite数据格式"""
params = (
('dataname', None),
('name', ''),
('timeframe', TimeFrame.Days),
('fromdate', None),
('todate', None),
)
def __init__(self):
self.engine = create_engine('sqlite:///local_sql_lite.db')
self.tabel_name = "my_stock_code"
def start(self):
self.conn = self.engine.connect()
sql_query = "SELECT `date`,`open`,`high`,`low`,`close`,`volume`,`turnover` FROM `{}` ORDER BY `date` ASC" \
.format(self.tabel_name)
self.result = self.conn.execute(sql_query)
def stop(self):
self.engine.dispose()
def _load(self):
one_row = self.result.fetchone()
if one_row is None:
return False
self.lines.datetime[0] = date2num(dt.datetime.strptime(str(one_row[0]), '%Y-%m-%d %H:%M:%S'))
self.lines.open[0] = float(one_row[1])
self.lines.high[0] = float(one_row[2])
self.lines.low[0] = float(one_row[3])
self.lines.close[0] = float(one_row[4])
self.lines.volume[0] = int(one_row[5])
self.lines.turnover[0] = float(one_row[6])
self.lines.openinterest[0] = -1
return True
其中有几个比较重要的函数:
def start() :在加载数据前执行一次,常用于初始化参数def stop() :结束数据加载程序之后执行一次,常用于关闭数据库链接def _load() :在策略中的next()拿到的数据其实就是这里传过去的数据,会循环执行多次这个函数,直到取得全部数据或接收到False 或None 的返回值params :这是定义数据集本身的一些参数,重要参数已在示例代码中解释,更多参数请参考官网
注意:def _load() 函数中:
self.lines :表示第一个数据集的列,等同于 self.datas[0].lines ,是一种简写形式self.lines.open :指代数据中的开盘价那一列self.lines.open[0] :特指当天的开盘价,如果是self.lines.open[-1],就是昨天的,如果是self.lines.open[1] 就是明天的
示例代码
import backtrader
import efinance
import pandas as pd
from datetime import datetime
import sqlite3
import datetime as dt
from backtrader import TimeFrame
from backtrader.feed import DataBase
from backtrader import date2num
from sqlalchemy import create_engine
def get_k_data(stock_code, begin: datetime, end: datetime) -> pd.DataFrame:
"""根据efinance工具包获取股票数据
:param stock_code:股票代码
:param begin: 开始日期
:param end: 结束日期
"""
k_dataframe: pd.DataFrame = efinance.stock.get_quote_history(
stock_code, beg=begin.strftime("%Y%m%d"), end=end.strftime("%Y%m%d"))
k_dataframe = k_dataframe.iloc[:, :9]
k_dataframe.columns = ['name', 'code', 'date', 'open', 'close', 'high', 'low', 'volume', 'turnover']
k_dataframe.index = pd.to_datetime(k_dataframe.date)
k_dataframe.drop(['name', 'code', "date"], axis=1, inplace=True)
return k_dataframe
def write_sql_lite_from_pandas(stock_code, begin: datetime, end: datetime):
"""获取K线数据,并保存到SQL lite数据库"""
conn = sqlite3.connect('local_sql_lite.db')
dataframe = get_k_data(stock_code, begin=begin, end=end)
dataframe.to_sql("my_stock_code", conn, if_exists="replace")
class SQLiteData(DataBase):
"""自定义的SQL lite数据格式"""
params = (
('dataname', None),
('name', ''),
('timeframe', TimeFrame.Days),
('fromdate', None),
('todate', None),
('turnover', -1),
)
lines = ('turnover',)
def __init__(self):
self.engine = create_engine('sqlite:///local_sql_lite.db')
self._timeframe = self.p.timeframe
self._compression = self.p.compression
self._dataname = "my_stock_code"
def start(self):
self.conn = self.engine.connect()
sql_query = "SELECT `date`,`open`,`high`,`low`,`close`,`volume`,`turnover` FROM `{}` ORDER BY `date` ASC" \
.format(self._dataname)
self.result = self.conn.execute(sql_query)
def stop(self):
self.engine.dispose()
def _load(self):
one_row = self.result.fetchone()
if one_row is None:
return False
self.lines.datetime[0] = date2num(dt.datetime.strptime(str(one_row[0]), '%Y-%m-%d %H:%M:%S'))
self.lines.open[0] = float(one_row[1])
self.lines.high[0] = float(one_row[2])
self.lines.low[0] = float(one_row[3])
self.lines.close[0] = float(one_row[4])
self.lines.volume[0] = int(one_row[5])
self.lines.turnover[0] = float(one_row[6])
self.lines.openinterest[0] = -1
return True
class MyStrategy1(backtrader.Strategy):
def __init__(self):
self.close_price = self.datas[0].close
this_data = self.getdatabyname("stock_600519")
print("全部列名:", this_data.getlinealiases())
def next(self):
print('=======================')
print("今天是:", self.datetime.date())
print("当前的值:", dict(zip(self.datas[0].getlinealiases(), [i[0] for i in list(self.datas[0].lines)])))
def main():
start_time = datetime(2015, 1, 1)
end_time = datetime(2015, 1, 10)
write_sql_lite_from_pandas("600519", start_time, end_time)
data = SQLiteData()
cerebral_system = backtrader.Cerebro()
cerebral_system.adddata(data, name="stock_600519")
cerebral_system.addstrategy(MyStrategy1)
cerebral_system.run()
if __name__ == '__main__':
main()
|