TurtleTrade/mysql_database.py

260 lines
9.9 KiB
Python
Raw Permalink Normal View History

2025-04-02 22:55:29 +08:00
#!/usr/local/bin/python
# -*- coding: utf-8 -*-
import os
import platform
import sys
import time
import traceback
import urllib.parse
from datetime import datetime

import akshare as ak
# apk add py-mysqldb or
import MySQLdb
import pandas as pd
from sqlalchemy import create_engine
from sqlalchemy import inspect
from sqlalchemy.sql import text
from sqlalchemy.types import NVARCHAR
# Database settings come from environment variables so the same code works both in
# local development and in docker; each setting falls back to a development default.
MYSQL_HOST = os.environ.get('MYSQL_HOST', "localhost")
MYSQL_USER = os.environ.get('MYSQL_USER', "root")
MYSQL_PWD = os.environ.get('MYSQL_PWD', "1212")
MYSQL_DB = os.environ.get('MYSQL_DB', "stock_data")
MYSQL_PORT = os.environ.get('MYSQL_PORT', "3307")

print("MYSQL_HOST :", MYSQL_HOST, ",MYSQL_USER :", MYSQL_USER, ",MYSQL_DB :", MYSQL_DB)

# Percent-encode the credentials so characters such as '@', ':', '/' or ' ' in the
# user name or password cannot corrupt the URL; SQLAlchemy decodes them when parsing.
_MYSQL_URL_USER = urllib.parse.quote(MYSQL_USER, safe="")
_MYSQL_URL_TARGET = f"{MYSQL_HOST}:{MYSQL_PORT}/{MYSQL_DB}?charset=utf8mb4"
MYSQL_CONN_URL = (f"mysql+mysqldb://{_MYSQL_URL_USER}:"
                  f"{urllib.parse.quote(MYSQL_PWD, safe='')}@{_MYSQL_URL_TARGET}")
# Print the URL with the password masked so credentials do not end up in logs.
print("MYSQL_CONN_URL :", f"mysql+mysqldb://{_MYSQL_URL_USER}:***@{_MYSQL_URL_TARGET}")

__version__ = "2.0.0"  # bump on every release
def engine():
    """Create a SQLAlchemy engine for the default database.

    Connects with MYSQL_CONN_URL, i.e. MYSQL_DB on the configured host and port.
    A new Engine object is created on every call.
    """
    return create_engine(MYSQL_CONN_URL)
def engine_to_db(to_db):
    """Create a SQLAlchemy engine for database ``to_db`` on the configured MySQL server.

    Host, port and credentials are the same as for the default database; only the
    database name differs. User name and password are percent-encoded so special
    characters ('@', ':', '/', ' ', ...) cannot corrupt the URL.

    Args:
        to_db: name of the database to connect to.

    Returns:
        A new SQLAlchemy Engine (one is created on every call).
    """
    user = urllib.parse.quote(MYSQL_USER, safe="")
    password = urllib.parse.quote(MYSQL_PWD, safe="")
    conn_url = (f"mysql+mysqldb://{user}:{password}"
                f"@{MYSQL_HOST}:{MYSQL_PORT}/{to_db}?charset=utf8mb4")
    return create_engine(conn_url)
def conn():
    """Open a new MySQLdb connection to MYSQL_DB and return a cursor on it.

    The connection is switched to autocommit, so statements executed through the
    cursor are committed immediately. Callers in this module use the returned
    cursor as a context manager (``with conn() as cursor: ...``).

    NOTE(review): only the cursor is returned, so nothing here closes the
    underlying connection explicitly; confirm it is released once the cursor is
    discarded, or return a wrapper that closes both.

    Raises:
        Exception: whatever ``MySQLdb.connect`` raises when the connection cannot
            be established; it is printed and then re-raised so the real cause
            is visible to the caller.
    """
    try:
        db = MySQLdb.connect(host=MYSQL_HOST,
                             user=MYSQL_USER,
                             passwd=MYSQL_PWD,
                             db=MYSQL_DB,
                             port=int(MYSQL_PORT),  # environment values are strings
                             # utf8mb4 matches the SQLAlchemy URLs and, unlike MySQL's
                             # 3-byte "utf8", can store every Unicode character.
                             charset="utf8mb4")
    except Exception as e:
        print("conn error :", e)
        raise
    db.autocommit(on=True)
    return db.cursor()
def insert_db(data, table_name, write_index, primary_keys):
    """Append ``data`` to ``table_name`` in the default database (MYSQL_DB).

    A primary key on ``primary_keys`` is added if the table does not have one yet,
    so rows stay unique when data is fetched again.

    Args:
        data: DataFrame to write.
        table_name: target table name.
        write_index: whether the DataFrame index is written as well.
        primary_keys: primary-key column specification, used only when the table
            has no primary key.
    """
    insert_other_db(to_db=MYSQL_DB,
                    data=data,
                    table_name=table_name,
                    write_index=write_index,
                    primary_keys=primary_keys)
def _quote_mysql_identifier(name):
    """Return ``name`` as a backtick-quoted MySQL identifier (embedded backticks doubled)."""
    return "`" + str(name).replace("`", "``") + "`"


def _primary_key_columns_sql(primary_keys):
    """Return the quoted, comma-separated column list for ``ADD PRIMARY KEY (...)``.

    ``primary_keys`` may be a single column name (``"date"``), a comma-separated
    string with or without backticks (``"date,code"`` or ``"`date`,`code`"``), or a
    list/tuple of names. Each name is stripped of surrounding whitespace and
    backticks and re-quoted, so ``"date"`` yields exactly ``"`date`"``.
    """
    if isinstance(primary_keys, str):
        names = primary_keys.split(",")
    else:
        names = list(primary_keys)
    cleaned = (str(name).strip().strip("`") for name in names)
    return ",".join(_quote_mysql_identifier(name) for name in cleaned if name)


def insert_other_db(to_db, data, table_name, write_index, primary_keys):
    """Append a DataFrame to ``to_db``.``table_name`` and add a primary key if missing.

    Every written column (and, with ``write_index``, every index level) is stored as
    NVARCHAR(255). After the append, if the table has no primary key yet, one is
    added on ``primary_keys`` so rows stay unique when data is fetched again.

    Args:
        to_db: target database name.
        data: pandas DataFrame to write; the table is created by ``to_sql`` if missing.
        table_name: target table name.
        write_index: if true, the DataFrame index is written as column(s) too.
        primary_keys: one column name, comma-separated names (backticks optional),
            or a list/tuple of names.

    Errors from ``to_sql`` propagate; a failure to add the primary key is printed
    and ignored. The per-call engine is always disposed before returning.
    """
    engine_mysql = engine_to_db(to_db)
    try:
        col_name_list = data.columns.tolist()
        if write_index:
            # Index levels are written before the data columns; a MultiIndex
            # contributes one name per level (``index.name`` would be None).
            col_name_list = list(data.index.names) + col_name_list
        print(col_name_list)
        data.to_sql(name=table_name, con=engine_mysql, schema=to_db, if_exists='append',
                    dtype={col_name: NVARCHAR(length=255) for col_name in col_name_list},
                    index=write_index)
        # Inspect after the write so a table just created by to_sql is visible.
        insp = inspect(engine_mysql)
        if not insp.get_pk_constraint(table_name)['constrained_columns']:
            try:
                sql = 'ALTER TABLE %s ADD PRIMARY KEY (%s)' % (
                    _quote_mysql_identifier(table_name),
                    _primary_key_columns_sql(primary_keys))
                # begin() commits explicitly on success instead of relying on
                # MySQL's implicit commit for DDL.
                with engine_mysql.begin() as con:
                    con.execute(text(sql))
            except Exception as e:
                print("################## ADD PRIMARY KEY ERROR :", e)
    finally:
        # The engine is created per call; close its pooled connections now.
        engine_mysql.dispose()
def fetch_all_data(table_name):
    """Read the whole ``table_name`` table of the default database into a DataFrame.

    Returns:
        The DataFrame produced by ``pandas.read_sql_table``.
    """
    engine_mysql = engine_to_db(MYSQL_DB)
    try:
        return pd.read_sql_table(table_name, engine_mysql)
    finally:
        # The engine is created per call; close its pooled connections now rather
        # than leaving them to garbage collection.
        engine_mysql.dispose()
def check_db_table(table_name):
    """Return True if ``table_name`` exists in the default database (MYSQL_DB), else False."""
    engine_mysql = engine_to_db(MYSQL_DB)
    try:
        return bool(inspect(engine_mysql).has_table(table_name))
    finally:
        # The engine is created per call; close its pooled connections now rather
        # than leaving them to garbage collection.
        engine_mysql.dispose()
def check_db_table_last_date(table_name):
    """Return the latest value of the ``日期`` (date) column of ``table_name`` as a date.

    Looks in the default database (MYSQL_DB). Text values are parsed from their
    leading ``YYYY-MM-DD`` part (anything after the first space, such as a time,
    is ignored); DATE/DATETIME values returned by the driver are converted directly.

    Returns:
        A ``datetime.date``, or None if the table does not exist, the query fails
        or returns no row, the value is empty, or it cannot be parsed.

    NOTE(review): rows are ordered by ``日期`` as stored; for a text column this is
    a string sort, which is chronological only for zero-padded ``YYYY-MM-DD`` text.
    """
    from datetime import date  # the module imports only the ``datetime`` class

    engine_mysql = engine_to_db(MYSQL_DB)
    try:
        table_exists = inspect(engine_mysql).has_table(table_name)
    finally:
        # The engine is created per call; release its pooled connections right away.
        engine_mysql.dispose()
    if not table_exists:
        return None

    sql = f"SELECT `日期` FROM `{table_name}` ORDER BY `日期` DESC LIMIT 1"
    result = select(sql)
    if not result:
        return None
    last_value = result[0][0]
    if not last_value:
        return None
    # datetime is a subclass of date, so it must be checked first.
    if isinstance(last_value, datetime):
        return last_value.date()
    if isinstance(last_value, date):
        return last_value
    try:
        return datetime.strptime(str(last_value).split(' ')[0], "%Y-%m-%d").date()
    except ValueError:
        return None
def insert(sql, params=()):
    """Execute a data-modifying statement on a fresh autocommit cursor from ``conn()``.

    The statement text is echoed first. An error raised while executing it is
    printed and swallowed, so the caller is not told about the failure.
    """
    with conn() as cursor:
        print("insert sql:" + sql)
        try:
            cursor.execute(sql, params)
        except Exception as exc:
            print("error :", exc)
def select(sql, params=None):
    """Run a query on a fresh cursor from ``conn()`` and return all of its rows.

    Args:
        sql: query text.
        params: substitution parameters passed to ``cursor.execute`` (None for none).

    Returns:
        The result of ``fetchall()``, or None if executing or fetching raised
        (the error is printed).
    """
    with conn() as cursor:
        try:
            cursor.execute(sql, params)
            rows = cursor.fetchall()
        except Exception as exc:
            print(f"Error executing SQL query: {exc}")
            rows = None
    return rows
def select_count(sql, params=()):
    """Run a single-value query (e.g. ``SELECT COUNT(*) ...``) and return it as an int.

    The query text is echoed first.

    Returns:
        ``int()`` of the first column when the query returns exactly one row;
        0 when it returns any other number of rows, or when executing/fetching
        fails (the error is printed).
    """
    with conn() as cursor:
        print("select sql:" + sql)
        try:
            cursor.execute(sql, params)
            rows = cursor.fetchall()
        except Exception as e:
            # Without a successful execute there is nothing to fetch; report 0
            # instead of letting fetchall() fail on the unexecuted cursor.
            print("error :", e)
            return 0
    # Only a single-row result carries the count, in its first column.
    if len(rows) == 1:
        return int(rows[0][0])
    return 0
def delete_table(table_name):
    """Drop ``table_name`` from the default database if it exists, then print a notice.

    Errors from the DROP statement are not caught.
    """
    with conn() as cursor:
        cursor.execute(f"DROP TABLE IF EXISTS `{table_name}`")
        # Notice text: "<table> deleted, preparing to re-fetch data".
        print(f"{table_name} 已删除,准备重新获取数据")
# # 通用函数。获得日期参数。
# def run_with_args(run_fun):
# tmp_datetime_show = datetime.datetime.now() # 修改成默认是当日执行 + datetime.timedelta()
# tmp_hour_int = int(tmp_datetime_show.strftime("%H"))
# if tmp_hour_int < 12 :
# # 判断如果是每天 中午 12 点之前运行,跑昨天的数据。
# tmp_datetime_show = (tmp_datetime_show + datetime.timedelta(days=-1))
# tmp_datetime_str = tmp_datetime_show.strftime("%Y-%m-%d %H:%M:%S.%f")
# print("\n######################### hour_int %d " % tmp_hour_int)
# str_db = "MYSQL_HOST :" + MYSQL_HOST + ", MYSQL_USER :" + MYSQL_USER + ", MYSQL_DB :" + MYSQL_DB
# print("\n######################### " + str_db + " ######################### ")
# print("\n######################### begin run %s %s #########################" % (run_fun, tmp_datetime_str))
# start = time.time()
# # 要支持数据重跑机制,将日期传入。循环次数
# if len(sys.argv) == 3:
# # python xxx.py 2017-07-01 10
# tmp_year, tmp_month, tmp_day = sys.argv[1].split("-")
# loop = int(sys.argv[2])
# tmp_datetime = datetime.datetime(int(tmp_year), int(tmp_month), int(tmp_day))
# for i in range(0, loop):
# # 循环插入多次数据,重复跑历史数据使用。
# # time.sleep(5)
# tmp_datetime_new = tmp_datetime + datetime.timedelta(days=i)
# try:
# run_fun(tmp_datetime_new)
# except Exception as e:
# print("error :", e)
# traceback.print_exc()
# elif len(sys.argv) == 2:
# # python xxx.py 2017-07-01
# tmp_year, tmp_month, tmp_day = sys.argv[1].split("-")
# tmp_datetime = datetime.datetime(int(tmp_year), int(tmp_month), int(tmp_day))
# try:
# run_fun(tmp_datetime)
# except Exception as e:
# print("error :", e)
# traceback.print_exc()
# else:
# # tmp_datetime = datetime.datetime.now() + datetime.timedelta(days=-1)
# try:
# run_fun(tmp_datetime_show) # 使用当前时间
# except Exception as e:
# print("error :", e)
# traceback.print_exc()
# print("######################### finish %s , use time: %s #########################" % (
# tmp_datetime_str, time.time() - start))
# # 设置基础目录,每次加载使用。
# bash_stock_tmp = "/data/cache/hist_data_cache/%s/%s/"
# if not os.path.exists(bash_stock_tmp):
# os.makedirs(bash_stock_tmp) # 创建多个文件夹结构。
# print("######################### init tmp dir #########################")
# # 增加读取股票缓存方法。加快处理速度。
# def get_hist_data_cache(code, date_start, date_end):
# cache_dir = bash_stock_tmp % (date_end[0:7], date_end)
# # 如果没有文件夹创建一个。月文件夹和日文件夹。方便删除。
# # print("cache_dir:", cache_dir)
# if not os.path.exists(cache_dir):
# os.makedirs(cache_dir)
# cache_file = cache_dir + "%s^%s.gzip.pickle" % (date_end, code)
# # 如果缓存存在就直接返回缓存数据。压缩方式。
# if os.path.isfile(cache_file):
# print("######### read from cache #########", cache_file)
# return pd.read_pickle(cache_file, compression="gzip")
# else:
# print("######### get data, write cache #########", code, date_start, date_end)
# stock = ak.stock_zh_a_hist(symbol= code, start_date=date_start, end_date=date_end, adjust="")
# stock.columns = ['date', 'open', 'close', 'high', 'low', 'volume', 'amount', 'amplitude', 'quote_change',
# 'ups_downs', 'turnover']
# if stock is None:
# return None
# stock = stock.sort_index(0) # 将数据按照日期排序下。
# print(stock)
# stock.to_pickle(cache_file, compression="gzip")
# return stock