SQLite是一个软件库,实现了自给自足的、无服务器的、零配置的、事务性的SQL数据库引擎。它轻量且无需单独部署数据库服务,适合在股票分析系统中使用。
一. SQLite安装与基本使用命令
1.1 Windows上安装SQLite
SQLite的下载地址为SQLite Download Page(https://www.sqlite.org/download.html)。在Windows下需要下载sqlite-dll和sqlite-tools两个压缩包。
将两个zip包解压,放到自己的安装目录(如C:\sqlite3)下,得到sqlite3.def、sqlite3.dll和sqlite3.exe文件。
将该安装目录(如C:\sqlite3)添加到环境变量PATH中,然后在CMD中执行sqlite3命令。安装成功时会显示SQLite的版本信息,并进入sqlite>交互提示符。
1.2 SQLite基本操作
| 语法 | 功能 |
| --- | --- |
| `sqlite3 DatabaseName.db` | 创建(或打开)数据库 |
| `CREATE TABLE database_name.table_name(column1 datatype, column2 datatype, column3 datatype, ..., columnN datatype, PRIMARY KEY(one or more columns));` | 创建表 |
| `DROP TABLE database_name.table_name;` | 删除表 |
| `INSERT INTO table_name [(column1, column2, column3, ..., columnN)] VALUES (value1, value2, value3, ..., valueN);` | 插入新的数据行 |
| `SELECT column1, column2, columnN FROM table_name;` | 查找数据 |
| `SELECT column1, column2, columnN FROM table_name WHERE [condition];` | 条件查找 |
1.3 Python SQLite接口
Python 3标准库内置了sqlite3数据库接口模块,使用其API操作数据库主要有以下几步:
- 连接数据库:db = sqlite3.connect(path_to_database),如果该数据库文件不存在则会自动创建。
- 获取游标:cursor = db.cursor()
- 执行SQL语句:cursor.execute('sql statement')。执行INSERT等修改数据的语句后,需要调用db.commit()提交事务,修改才会写入数据库文件。
- 关闭数据库:db.close()
二. 使用SQLite搭建股票数据库
2.1 创建数据表
创建一个名为stocks的数据表,表中存储每天所有A股的开盘价、最高价、最低价、收盘价、成交量等数据。
# Build the DDL from one string piece per column so each column reads at a
# glance; the pieces join into exactly one CREATE TABLE statement.
cursor.execute(
    'CREATE TABLE IF NOT EXISTS stocks (\n'
    'ts_code varchar(10),\n'
    'trade_date varchar(50),\n'
    'open REAL,\n'
    'high REAL,\n'
    'low REAL,\n'
    'close REAL,\n'
    'pre_close REAL,\n'
    'change REAL,\n'
    'pct_chg REAL,\n'
    'vol REAL,\n'
    'amount REAL,\n'
    # one row per (stock, trading day): the pair is unique, ts_code alone is not
    'primary key(ts_code, trade_date)\n'
    ')'
)
注意,一个数据表要存储同一只股票很多天的数据。如果只用ts_code作为主键,同一只股票其他交易日的数据会因主键冲突而无法插入。因此这里使用联合主键,将ts_code和trade_date共同构成主键:primary key(ts_code, trade_date)。
2.2 获取每日股票数据
获取每日股票数据可以使用tushare或者akshare这样的Python包,本文使用的是tushare。将tushare获取到的每日数据插入到数据库:
pd_data = get_all_stocks(date=day_str)
if len(pd_data) == 0:
continue
data = pd_data.values.tolist()
data = [tuple(x) for x in data]
try:
self.cursor.executemany('''INSERT INTO stocks (ts_code,trade_date,open,high,low,close,pre_close,change,pct_chg,vol,amount)
VALUES(?,?,?,?,?,?,?,?,?,?,?);''', data)
except sqlite3.Error as e:
# print(e)
pass
# Save the changes
self.db.commit()
为了更新所有历史数据,将上述代码放在按日期遍历的循环中即可:
# Resume from the newest date already stored and walk day by day to today.
start_date = self._get_last_updated_date()
print("从{}开始更新...".format(str(start_date)))
end_date = datetime.date.today()
num_days = (end_date - start_date).days + 1  # inclusive of both end dates
for offset in tqdm(range(num_days)):
    day = start_date + datetime.timedelta(days=offset)
    day_str = day.isoformat().replace('-', '')  # 'YYYY-MM-DD' -> 'YYYYMMDD'
    ...  # per-day fetch-and-insert code from section 2.2 goes here
2.3 查询数据
更新完数据库后,使用SQL的SELECT语句即可查询股票数据。例如,要查询某一只股票的历史股价信息,可以使用如下代码:
@database
def get_stock_data(self, code):
    """Return all stored daily rows of one stock as a DataFrame.

    Args:
        code: the stock's ts_code, e.g. '688607.SH'.

    Returns:
        pandas.DataFrame with columns ``self.colums``; empty when nothing
        matches or the query fails.
    """
    try:
        # Bind code as a parameter instead of formatting it into the SQL:
        # string formatting allowed SQL injection and broke on quotes.
        self.cursor.execute("SELECT * FROM stocks WHERE ts_code=?", (code,))
        rows = self.cursor.fetchall()
    except sqlite3.Error as e:
        print("查询{}数据失败: {}".format(code, e))
        rows = []
    return pandas.DataFrame(rows, columns=self.colums)
上述代码中使用了@database装饰器。由于需要实现的方法大多都要先连接数据库、操作完后再关闭,因此用该装饰器自动完成这些操作:
def database(func):
    """Decorator: connect before calling *func* and always close afterwards.

    The wrapped function is called as ``func(self, *args, **kwargs)`` and its
    return value is passed through. The connection is closed even when *func*
    raises, so a failing query does not leak the SQLite connection.
    """
    import functools  # local import keeps this snippet self-contained

    @functools.wraps(func)  # keep func's __name__ / __doc__ on the wrapper
    def wrapper(self, *args, **kwargs):
        self.connect()
        try:
            return func(self, *args, **kwargs)
        finally:
            self._close()
    return wrapper
附录
实现SQLite3存储并自动更新数据库的基本Python代码如下:
import pandas
import os
import sqlite3
import sys
import datetime
from tqdm import tqdm
import numpy as np
cur_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(cur_path + "/../")
from shares import stock_index
class DataSource:
    """SQLite-backed store of daily stock quotes, refreshed via ``stock_index``.

    Each public query opens a connection and closes it again (see the
    ``database`` decorator below), so no connection stays open between calls.
    """

    def __init__(self) -> None:
        self.db = None  # sqlite3 connection while one is open, else None
        self.cursor = None  # cursor on self.db while open, else None
        self.stocks = {}  # stock data (not read by any method of this class)
        self.indexs = {}  # index pool (not read by any method of this class)
        # Column order of the ``stocks`` table; also the column labels of
        # every DataFrame returned by this class.
        self.colums = ['ts_code', 'trade_date', 'open', 'high', 'low', 'close',
                       'pre_close', 'change', 'pct_chg', 'vol', 'amount']
        self.database_name = cur_path + "/../../test/stock_database_2021.db"
        self.stock_index_ = stock_index.StockIndex()
        # Make sure the schema exists before the first query.
        self._create_database()

    def _close(self):
        """Close the cursor and connection (if any) and forget them.

        Resetting both to None lets ``_get_last_updated_date`` detect that no
        connection is open instead of failing on an already-closed cursor.
        """
        if self.cursor is not None:
            self.cursor.close()
        if self.db is not None:
            self.db.close()
        self.cursor = None
        self.db = None

    def _create_database(self):
        """Create the ``stocks`` table if it does not exist yet.

        The primary key is (ts_code, trade_date): one row per stock per day.
        """
        self.connect()
        try:
            # ts_code trade_date open high low close pre_close change pct_chg vol amount
            self.cursor.execute('''CREATE TABLE IF NOT EXISTS stocks (
                ts_code varchar(10),
                trade_date varchar(50),
                open REAL,
                high REAL,
                low REAL,
                close REAL,
                pre_close REAL,
                change REAL,
                pct_chg REAL,
                vol REAL,
                amount REAL,
                primary key(ts_code, trade_date)
            )''')
        finally:
            self._close()

    def connect(self):
        """Open ``self.database_name`` and a cursor on it.

        ``sqlite3.connect`` creates the file if it does not exist; any error it
        raises propagates unchanged.

        Raises:
            ValueError: if no connection or cursor was obtained.
        """
        self.db = sqlite3.connect(self.database_name)
        self.cursor = self.db.cursor()
        if self.db is None or self.cursor is None:
            raise ValueError("Data base connect error")

    def database(func):
        """Decorator applied to query methods in this class body.

        Opens a connection before the wrapped method runs and always closes it
        afterwards, even if the method raises. Not meant to be called on an
        instance.
        """
        import functools

        @functools.wraps(func)  # keep the method's __name__ / __doc__
        def wrapper(self, *args, **kwargs):
            self.connect()
            try:
                return func(self, *args, **kwargs)
            finally:
                self._close()
        return wrapper

    def _get_last_updated_date(self):
        """Return the newest stored ``trade_date`` as a ``datetime.date``.

        Falls back to 2021-01-01 when the table is empty. A connection must
        already be open.

        Raises:
            ValueError: if no connection is open.
        """
        if self.cursor is None:
            raise ValueError("Open Database firstly!")
        # trade_date is stored as a 'YYYYMMDD' string. Let SQLite compute the
        # maximum instead of loading every row into Python.
        self.cursor.execute("SELECT MAX(CAST(trade_date AS INTEGER)) FROM stocks")
        row = self.cursor.fetchone()
        max_date = row[0] if row is not None and row[0] is not None else 20210101
        return datetime.datetime.strptime(str(max_date), "%Y%m%d").date()

    def update(self):
        """Fetch and store daily data from the newest stored date up to today.

        The newest stored day is fetched again so a partially stored day gets
        completed; rows whose (ts_code, trade_date) already exist are skipped.
        Each day is committed on its own, so an interrupted run keeps its
        progress. The connection is always closed on exit.
        """
        self.connect()
        try:
            print("更新同步股票数据:")
            start_date = self._get_last_updated_date()
            print("从{}开始更新...".format(str(start_date)))
            end_date = datetime.date.today()
            for offset in tqdm(range((end_date - start_date).days + 1)):
                day = start_date + datetime.timedelta(days=offset)
                day_str = str(day).replace('-', '')
                pd_data = self.stock_index_.get_all_stocks(date=day_str)
                if pd_data is None or len(pd_data) == 0:
                    continue  # nothing returned (presumably a non-trading day)
                # Positional tuples in INSERT column order (assumes get_all_stocks
                # returns exactly these 11 columns in this order).
                rows = [tuple(row) for row in pd_data.values.tolist()]
                try:
                    # OR IGNORE: a duplicate key no longer stops executemany
                    # before the remaining rows of the day are inserted.
                    self.cursor.executemany('''INSERT OR IGNORE INTO stocks (ts_code,trade_date,open,high,low,close,pre_close,change,pct_chg,vol,amount)
                    VALUES(?,?,?,?,?,?,?,?,?,?,?);''', rows)
                except sqlite3.Error as e:
                    # Best effort: report and move on to the next day.
                    print("写入{}数据失败: {}".format(day_str, e))
                # Save the changes for this day.
                self.db.commit()
        finally:
            self._close()

    @database
    def dump_all_data(self):
        """Print every row of the ``stocks`` table (debug helper)."""
        self.cursor.execute("SELECT * FROM stocks")
        print(self.cursor.fetchall())

    @database
    def get_stock_data(self, code):
        """Return all stored daily rows of one stock.

        Args:
            code: the stock's ts_code, e.g. '688607.SH'.

        Returns:
            pandas.DataFrame with columns ``self.colums``; empty when nothing
            matches or the query fails.
        """
        try:
            # Parameter binding: formatting code into the SQL text allowed SQL
            # injection and broke on codes containing a quote.
            self.cursor.execute("SELECT * FROM stocks WHERE ts_code=?", (code,))
            rows = self.cursor.fetchall()
        except sqlite3.Error as e:
            print("查询{}数据失败: {}".format(code, e))
            rows = []
        return pandas.DataFrame(rows, columns=self.colums)

    @database
    def get_all_index(self, date=""):
        """Return all stocks' rows for one day, joined with basic stock info.

        Args:
            date: trade date as 'YYYYMMDD'. Defaults to the newest stored date.

        Returns:
            That day's rows left-joined with ``stock_index_.get_stock_info()``
            on ts_code, sorted by pct_chg ascending, with every row containing
            a missing value dropped.

        Raises:
            ValueError: if the SQLite query fails.
        """
        if date:
            trade_date = str(date)
        else:
            trade_date = str(self._get_last_updated_date()).replace('-', '')
        try:
            self.cursor.execute("SELECT * FROM stocks WHERE trade_date=?", (trade_date,))
            rows = self.cursor.fetchall()
        except sqlite3.Error as e:
            raise ValueError("can't get today stock data") from e
        pd_data = pandas.DataFrame(rows, columns=self.colums)
        stock_basic = self.stock_index_.get_stock_info()
        all_data = pandas.merge(pd_data, stock_basic, how='left', on='ts_code')
        all_data = all_data.sort_values('pct_chg')
        return all_data.dropna(axis=0, how='any')

    def get_amain_index(self):
        """Return the newest day's rows whose ts_code starts with 000, 001 or 002.

        NOTE(review): codes starting with 6 (e.g. 600000.SH) are not matched;
        confirm this is intended.
        """
        all_data = self.get_all_index()
        return all_data[all_data.ts_code.str.contains('^000|^001|^002')]

    def get_gem_index(self):
        """Return the newest day's rows whose ts_code starts with 30 or 68."""
        data = self.get_all_index()
        return data[data.ts_code.str.contains('^30|^68')]
if __name__ == "__main__":
    # Sync the local database, then print one stock's history as a smoke check.
    source = DataSource()
    source.update()
    print("get stock data:")
    stock_history = source.get_stock_data(code='688607.SH')
    print(stock_history)