TurtleTrade/回测/TurtleOnTime_ai.py

645 lines
23 KiB
Python
Raw Normal View History

2025-04-02 23:06:11 +08:00
"""海龟实时
"""
import numpy as np
import math
import akshare as ak
import os
from datetime import datetime, timedelta
import pandas as pd
import mplfinance as mpf
import TurtleClassNew
# -----------------------------------
# Portfolio ("combination") of instruments to watch. Start with a single test
# item; for now the list is an empty placeholder.
#
# Design notes:
# - Do we need information about the current portfolio? Yes.
# - Instrument kind: stock or ETF.
# - Risk coefficient ``risk_coef``; ATR-based position unit.
# - What the coefficient means: money at risk per 1% move; with ATR, the
#   total cost of buying 4 units.
# - Total capital the portfolio will consume.
# - Attributes every item should carry: code, ATR, price, risk_coef, capital.
# - Re-tune risk_coef and capital once a month.
# Planned functions:
# - initialisation: fill in the data held in ``conbinations``;
# - market monitoring;
# - data tidy-up and persistence, run after the market closes.
conbinations = list()

for item in conbinations:
    # TODO: create a Turtle instance per item -- ETFs and stocks need
    # different data-fetching code.
    continue

# Fetch data every 5 minutes.
# Compute the Donchian channel once per day at the close.
# https://akshare.akfamily.xyz/data/stock/stock.html#id9
import akshare as ak
import pandas as pd
import numpy as np
import sqlite3
from datetime import datetime
import smtplib
from email.mime.text import MIMEText
from email.header import Header
import json
from typing import Dict, List, Tuple
import logging
from decimal import Decimal
class DatabaseManager:
    """SQLite-backed store for trade signals, executed trades and positions.

    The schema (three tables) is created on construction if missing:

    * ``signals``   -- suggested trades. ``status`` is written as
      ``'PENDING'`` by :meth:`save_signal` and set to ``'EXECUTED'`` by
      :meth:`save_trade`.
    * ``trades``    -- actual fills, each linked to a ``signals.id``.
    * ``positions`` -- holdings. ``status`` is ``'ACTIVE'`` or ``'CLOSED'``.

    Every mutating method commits immediately.
    """

    def __init__(self, db_path: str = "turtle_trading.db"):
        """Open (or create) the database at *db_path* and ensure the schema.

        Args:
            db_path: SQLite file path; ``":memory:"`` gives a throwaway DB.
        """
        self.conn = sqlite3.connect(db_path)
        self.create_tables()

    @staticmethod
    def _now() -> str:
        """Return the current local time as ``YYYY-MM-DD HH:MM:SS.ffffff``.

        This is exactly the text the sqlite3 module's default datetime
        adapter writes (``isoformat(" ")``). That adapter is deprecated since
        Python 3.12, so the conversion is done explicitly to keep old and new
        rows in the same format.
        """
        return datetime.now().isoformat(" ")

    def create_tables(self):
        """Create the signals, trades and positions tables if they are absent."""
        # Trade signals suggested by the strategy.
        self.conn.execute('''
            CREATE TABLE IF NOT EXISTS signals (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                stock_code TEXT,
                signal_type TEXT,
                suggested_price REAL,
                suggested_quantity INTEGER,
                timestamp DATETIME,
                status TEXT
            )''')
        # Actual executions reported back for a signal.
        self.conn.execute('''
            CREATE TABLE IF NOT EXISTS trades (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                signal_id INTEGER,
                actual_price REAL,
                actual_quantity INTEGER,
                timestamp DATETIME,
                FOREIGN KEY (signal_id) REFERENCES signals (id)
            )''')
        # Position state.
        self.conn.execute('''
            CREATE TABLE IF NOT EXISTS positions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                stock_code TEXT,
                quantity INTEGER,
                avg_cost REAL,
                entry_price REAL,
                last_price REAL,
                stop_loss REAL,
                target_price REAL,
                position_type TEXT,
                timestamp DATETIME,
                status TEXT
            )''')
        self.conn.commit()

    def save_signal(self, stock_code: str, signal_type: str,
                    suggested_price: float, suggested_quantity: int) -> int:
        """Record a strategy signal with status ``'PENDING'``.

        Args:
            stock_code: instrument code.
            signal_type: e.g. ``'BUY'``, ``'ADD'``, ``'STOP_LOSS'``.
            suggested_price: price suggested by the strategy (may be a numpy
                scalar; it is converted to ``float`` for binding).
            suggested_quantity: quantity suggested by the strategy.

        Returns:
            The new ``signals.id``, to be passed back to :meth:`save_trade`.
        """
        price = None if suggested_price is None else float(suggested_price)
        quantity = None if suggested_quantity is None else int(suggested_quantity)
        cursor = self.conn.execute('''
            INSERT INTO signals (
                stock_code, signal_type, suggested_price,
                suggested_quantity, timestamp, status
            ) VALUES (?, ?, ?, ?, ?, ?)
        ''', (stock_code, signal_type, price, quantity, self._now(), 'PENDING'))
        self.conn.commit()
        return cursor.lastrowid

    def save_trade(self, signal_id: int, actual_price: float,
                   actual_quantity: int) -> int:
        """Record an actual fill for *signal_id* and mark that signal executed.

        Returns:
            The new ``trades.id``.
        """
        cursor = self.conn.execute('''
            INSERT INTO trades (signal_id, actual_price, actual_quantity, timestamp)
            VALUES (?, ?, ?, ?)
        ''', (signal_id, actual_price, actual_quantity, self._now()))
        # PerformanceAnalyzer only analyses signals whose status is 'EXECUTED'.
        self.conn.execute(
            "UPDATE signals SET status = 'EXECUTED' WHERE id = ?", (signal_id,))
        self.conn.commit()
        return cursor.lastrowid

    def get_position(self, stock_code: str) -> Dict:
        """Return the ACTIVE position for *stock_code* as a dict, or ``None``.

        If several ACTIVE rows exist, the first row SQLite returns is used.
        """
        cursor = self.conn.execute('''
            SELECT id, stock_code, quantity, avg_cost, entry_price, last_price,
                   stop_loss, target_price, position_type, timestamp, status
            FROM positions
            WHERE stock_code = ? AND status = 'ACTIVE'
        ''', (stock_code,))
        row = cursor.fetchone()
        if row is None:
            return None
        keys = ('id', 'stock_code', 'quantity', 'avg_cost', 'entry_price',
                'last_price', 'stop_loss', 'target_price', 'position_type',
                'timestamp', 'status')
        return dict(zip(keys, row))

    def update_position(self, stock_code: str, last_price: float):
        """Store *last_price* as the latest price of the ACTIVE position."""
        self.conn.execute('''
            UPDATE positions
            SET last_price = ?, timestamp = ?
            WHERE stock_code = ? AND status = 'ACTIVE'
        ''', (last_price, self._now(), stock_code))
        self.conn.commit()

    def create_position(self, stock_code: str, quantity: int,
                        entry_price: float, position_type: str,
                        stop_loss: float, target_price: float):
        """Insert a new ACTIVE position.

        ``avg_cost``, ``entry_price`` and ``last_price`` all start at
        *entry_price*.
        """
        self.conn.execute('''
            INSERT INTO positions (
                stock_code, quantity, avg_cost, entry_price, last_price,
                stop_loss, target_price, position_type, timestamp, status
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        ''', (stock_code, quantity, entry_price, entry_price, entry_price,
              stop_loss, target_price, position_type, self._now(), 'ACTIVE'))
        self.conn.commit()

    def close_position(self, stock_code: str):
        """Mark every ACTIVE position for *stock_code* as CLOSED."""
        self.conn.execute('''
            UPDATE positions
            SET status = 'CLOSED', timestamp = ?
            WHERE stock_code = ? AND status = 'ACTIVE'
        ''', (self._now(), stock_code))
        self.conn.commit()
class TurtleStrategy:
    """Turtle-style breakout strategy that turns daily bars into trade signals.

    ``calculate_signals`` expects a DataFrame with ``high``, ``low`` and
    ``close`` columns, ordered oldest row first (the last row is treated as
    the current bar). It adds ``high_20``, ``low_10``, ``tr`` and ``atr``
    columns to that DataFrame in place.

    Every signal is a dict with keys ``type`` (``None`` means "do nothing"),
    ``price``, ``quantity``, ``stop_loss`` and ``target_price``.
    """

    def __init__(self, lookback_days: int = 20):
        """Args:
            lookback_days: window of the breakout high (default 20 bars).
        """
        self.lookback_days = lookback_days

    @staticmethod
    def _blank_signal(price=None, stop_loss=None, target_price=None) -> Dict:
        """Return a no-action signal dict (``type`` None, quantity 0)."""
        return {
            'type': None,
            'price': price,
            'quantity': 0,
            'stop_loss': stop_loss,
            'target_price': target_price,
        }

    def calculate_signals(self, df: pd.DataFrame,
                          current_position: Dict = None) -> Dict:
        """Compute the signal for the latest bar, given the current position.

        Args:
            df: daily bars (see class docstring); mutated in place.
            current_position: ACTIVE position dict, or ``None``/empty if flat.

        Returns:
            A signal dict. If fewer than two bars are available, a no-action
            signal is returned.
        """
        # The column keeps its historical name even when lookback_days != 20.
        df['high_20'] = df['high'].rolling(self.lookback_days).max()
        df['low_10'] = df['low'].rolling(10).min()
        df['atr'] = self._calculate_atr(df)

        # Both the current bar and the previous bar are needed.
        if len(df) < 2:
            return self._blank_signal()

        current = df.iloc[-1]
        prev = df.iloc[-2]
        if current_position:
            return self._calculate_position_signals(
                current_position, current, prev, df
            )
        return self._calculate_entry_signals(current, prev, df)

    def _calculate_position_signals(self, position: Dict,
                                    current: pd.Series,
                                    prev: pd.Series,
                                    df: pd.DataFrame) -> Dict:
        """Signal for an open position: exit, add or reduce.

        Stop-loss and take-profit checks are direction-aware. A SHORT is
        stopped out when the bar's high reaches the stop, and takes profit
        when the low reaches the target. Any other position type is treated
        as long. A missing (``None``) stop or target is never triggered.
        """
        stop = position['stop_loss']
        target = position['target_price']
        signal = self._blank_signal(current['close'], stop, target)
        is_short = position['position_type'] == 'SHORT'

        if is_short:
            stop_hit = stop is not None and current['high'] >= stop
            target_hit = target is not None and current['low'] <= target
        else:
            stop_hit = stop is not None and current['low'] <= stop
            target_hit = target is not None and current['high'] >= target

        if stop_hit:
            signal['type'] = 'STOP_LOSS'
            signal['quantity'] = position['quantity']
            return signal
        if target_hit:
            signal['type'] = 'TAKE_PROFIT'
            signal['quantity'] = position['quantity']
            return signal

        atr = df['atr'].iloc[-1]
        if position['position_type'] == 'LONG':
            # Consider adding once the long is more than 5% in profit.
            if current['close'] > position['entry_price'] * 1.05:
                signal['type'] = 'ADD'
                signal['quantity'] = self._calculate_position_size(
                    current['close'], atr
                )
                # Trail the stop to the 5-bar low.
                signal['stop_loss'] = df['low'].rolling(5).min().iloc[-1]
        elif is_short:
            # Consider reducing once the short is more than 5% in profit.
            if current['close'] < position['entry_price'] * 0.95:
                signal['type'] = 'REDUCE'
                signal['quantity'] = self._calculate_position_size(
                    current['close'], atr
                )
                # Trail the stop to the 5-bar high.
                signal['stop_loss'] = df['high'].rolling(5).max().iloc[-1]
        return signal

    def _calculate_entry_signals(self, current: pd.Series,
                                 prev: pd.Series,
                                 df: pd.DataFrame) -> Dict:
        """Entry signal when flat.

        A close above the previous bar's breakout high gives BUY; a close
        below the previous bar's 10-bar low gives SELL. The stop is placed
        2 ATR away from the close and the target 4 ATR away.
        """
        signal = self._blank_signal(current['close'])
        atr = df['atr'].iloc[-1]

        if current['close'] > prev['high_20']:
            signal['type'] = 'BUY'
            signal['quantity'] = self._calculate_position_size(
                current['close'], atr
            )
            signal['stop_loss'] = current['close'] - 2 * atr
            signal['target_price'] = current['close'] + 4 * atr
        elif current['close'] < prev['low_10']:
            signal['type'] = 'SELL'
            signal['quantity'] = self._calculate_position_size(
                current['close'], atr
            )
            signal['stop_loss'] = current['close'] + 2 * atr
            signal['target_price'] = current['close'] - 4 * atr
        return signal

    def _calculate_atr(self, df: pd.DataFrame) -> pd.Series:
        """Return the 20-bar simple mean of the true range.

        The per-bar true range is stored in ``df['tr']``. The first bar has
        no previous close, so its TR is NaN.
        """
        prev_close = df['close'].shift(1)
        df['tr'] = np.maximum(
            df['high'] - df['low'],
            np.maximum(
                abs(df['high'] - prev_close),
                abs(df['low'] - prev_close)
            )
        )
        return df['tr'].rolling(20).mean()

    def _calculate_position_size(self, price: float, atr: float,
                                 account_size: float = 100000,
                                 risk_pct: float = 0.01) -> int:
        """Return the number of units to trade, risking *risk_pct* of equity per ATR.

        Formula: ``int(account_size * risk_pct / (atr * 100))``. The ``* 100``
        presumably converts to A-share 100-share lots (TODO confirm).
        *price* is currently unused.

        Returns 0 when ATR is missing, non-finite or non-positive (for
        example, too few bars), instead of raising.
        """
        if atr is None or not np.isfinite(atr) or atr <= 0:
            return 0
        risk_per_trade = account_size * risk_pct
        return int(risk_per_trade / (atr * 100))
class EmailManager:
    """Send signal and position-update e-mails over SMTP over SSL.

    The JSON config must provide ``sender``, ``receiver`` (a string or a list
    of addresses), ``smtp_server``, ``smtp_port``, ``username`` and
    ``password``.
    """

    def __init__(self, config_path: str = "email_config.json"):
        """Load the SMTP configuration from the UTF-8 JSON file *config_path*."""
        with open(config_path, encoding='utf-8') as f:
            self.config = json.load(f)

    def send_signal(self, stock_code: str, signal_type: str,
                    suggested_price: float, suggested_quantity: int,
                    stop_loss: float = None, target_price: float = None) -> bool:
        """E-mail a trade signal and ask for the actual fill in reply.

        A stop or target of ``None`` is shown as blank. ``0.0`` is still
        printed.

        Returns:
            True if the mail was sent, False otherwise (see ``_send_email``).
        """
        subject = f"交易信号: {stock_code} - {signal_type}"
        content = f"""
股票代码: {stock_code}
信号类型: {signal_type}
建议价格: {suggested_price}
建议数量: {suggested_quantity}
止损价位: {'' if stop_loss is None else stop_loss}
目标价位: {'' if target_price is None else target_price}
请回复实际成交价格和数量, 格式:
价格,数量
例如: 10.5,100
"""
        return self._send_email(subject, content)

    def send_position_update(self, position: Dict,
                             current_price: float) -> bool:
        """E-mail a summary of *position* marked to *current_price*.

        Profit is sign-adjusted for ``position_type == 'SHORT'``, since a
        short gains when the price falls. When ``avg_cost`` is 0, the
        percentage is reported as 0 instead of dividing by zero.

        Returns:
            True if the mail was sent, False otherwise.
        """
        subject = f"持仓更新: {position['stock_code']}"

        avg_cost = position['avg_cost']
        direction = -1 if position['position_type'] == 'SHORT' else 1
        profit = (current_price - avg_cost) * position['quantity'] * direction
        profit_pct = ((current_price / avg_cost - 1) * 100 * direction
                      if avg_cost else 0.0)

        content = f"""
股票代码: {position['stock_code']}
当前价格: {current_price}
持仓数量: {position['quantity']}
平均成本: {position['avg_cost']}
止损价位: {position['stop_loss']}
目标价位: {position['target_price']}
当前收益: {profit:.2f} ({profit_pct:.2f}%)
持仓类型: {position['position_type']}
"""
        return self._send_email(subject, content)

    def _send_email(self, subject: str, content: str) -> bool:
        """Send one plain-text UTF-8 mail to the configured receiver(s).

        This is deliberately best-effort. Any failure, including a missing
        config key, is logged and reported as ``False`` so that the trading
        loop keeps running.
        """
        try:
            receivers = self.config['receiver']
            if isinstance(receivers, str):
                receivers = [receivers]
            msg = MIMEText(content, 'plain', 'utf-8')
            msg['Subject'] = Header(subject, 'utf-8')
            msg['From'] = self.config['sender']
            msg['To'] = ', '.join(receivers)
            with smtplib.SMTP_SSL(self.config['smtp_server'],
                                  self.config['smtp_port']) as server:
                server.login(self.config['username'], self.config['password'])
                server.sendmail(self.config['sender'],
                                receivers,
                                msg.as_string())
            return True
        except Exception as e:
            logging.error(f"发送邮件失败: {str(e)}")
            return False
class TurtleTrader:
    """Glue between market data (akshare), the strategy, the DB and e-mail."""

    # akshare's stock_zh_a_hist returns Chinese column headers. The strategy
    # reads English ones (high/low/close), so they are renamed on fetch.
    _COLUMN_MAP = {
        '日期': 'date',
        '开盘': 'open',
        '收盘': 'close',
        '最高': 'high',
        '最低': 'low',
        '成交量': 'volume',
    }

    def __init__(self, config_path: str = "config.json"):
        """Wire up the collaborators and load the UTF-8 JSON run config.

        ``main`` reads ``stock_codes`` and ``check_interval`` from this config.
        """
        self.db = DatabaseManager()
        self.strategy = TurtleStrategy()
        self.email = EmailManager()
        with open(config_path, encoding='utf-8') as f:
            self.config = json.load(f)

    @classmethod
    def _normalize_columns(cls, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of *df* with akshare's Chinese headers renamed to English."""
        return df.rename(columns=cls._COLUMN_MAP)

    def process_stock(self, stock_code: str):
        """Evaluate one stock: refresh its position, then store and mail any signal.

        Errors are logged instead of raised, so one bad symbol does not stop
        the monitoring loop.
        """
        try:
            position = self.db.get_position(stock_code)

            # 120 calendar days is far more than the ~21 bars the 20-bar
            # indicators need. Previously all history since 2023 was
            # re-downloaded on every cycle.
            df = self._get_stock_data(stock_code, days=120)
            if df is None or df.empty:
                logging.warning(f"未获取到股票 {stock_code} 的行情数据")
                return

            signal = self.strategy.calculate_signals(df, position)
            current_price = df['close'].iloc[-1]

            if position:
                self.db.update_position(stock_code, current_price)
                # NOTE(review): this runs on every check cycle, so an update
                # e-mail is sent every cycle for each open position.
                self.email.send_position_update(position, current_price)

            if signal['type']:
                self.db.save_signal(
                    stock_code, signal['type'],
                    signal['price'], signal['quantity']
                )
                self.email.send_signal(
                    stock_code, signal['type'],
                    signal['price'], signal['quantity'],
                    signal['stop_loss'], signal['target_price']
                )
        except Exception as e:
            logging.error(f"处理股票 {stock_code} 时发生错误: {str(e)}")

    def process_feedback(self, signal_id: int, actual_price: float,
                         actual_quantity: int):
        """Apply a reported fill for *signal_id* to the positions table.

        BUY/ADD either averages into the ACTIVE position or opens a new LONG
        one, with the stop at 2 ATR below and the target at 4 ATR above the
        fill. SELL/REDUCE/STOP_LOSS/TAKE_PROFIT shrink or close the ACTIVE
        position.

        Raises:
            ValueError: if *signal_id* does not exist.
            Exception: any other failure. Uncommitted position changes are
                rolled back first, then the exception is re-raised.
        """
        try:
            cursor = self.db.conn.execute('''
                SELECT stock_code, signal_type, suggested_price
                FROM signals WHERE id = ?
            ''', (signal_id,))
            row = cursor.fetchone()
            if not row:
                raise ValueError(f"Signal ID {signal_id} not found")
            stock_code, signal_type, _suggested_price = row

            self.db.save_trade(signal_id, actual_price, actual_quantity)

            current_position = self.db.get_position(stock_code)
            now_text = datetime.now().isoformat(" ")

            if signal_type in ('BUY', 'ADD'):
                if current_position:
                    # Weighted-average the new fill into the existing cost.
                    total_cost = (current_position['avg_cost'] *
                                  current_position['quantity'] +
                                  actual_price * actual_quantity)
                    total_quantity = (current_position['quantity'] +
                                      actual_quantity)
                    new_avg_cost = total_cost / total_quantity
                    self.db.conn.execute('''
                        UPDATE positions
                        SET quantity = ?, avg_cost = ?, last_price = ?,
                            timestamp = ?
                        WHERE id = ?
                    ''', (total_quantity, new_avg_cost, actual_price,
                          now_text, current_position['id']))
                else:
                    # The 20-bar ATR needs 21 rows (the first TR is NaN).
                    # 30 calendar days only gives about 20 trading days, so
                    # fetch 90.
                    df = self._get_stock_data(stock_code, days=90)
                    atr = self.strategy._calculate_atr(df).iloc[-1]
                    stop_loss = actual_price - 2 * atr
                    target_price = actual_price + 4 * atr
                    self.db.create_position(
                        stock_code, actual_quantity, actual_price,
                        'LONG', stop_loss, target_price
                    )
            elif signal_type in ('SELL', 'REDUCE', 'STOP_LOSS', 'TAKE_PROFIT'):
                if current_position:
                    remaining_quantity = (current_position['quantity'] -
                                          actual_quantity)
                    if remaining_quantity > 0:
                        # Partial exit.
                        self.db.conn.execute('''
                            UPDATE positions
                            SET quantity = ?, last_price = ?, timestamp = ?
                            WHERE id = ?
                        ''', (remaining_quantity, actual_price,
                              now_text, current_position['id']))
                    else:
                        # Full exit.
                        self.db.close_position(stock_code)

            self.db.conn.commit()
        except Exception as e:
            # Do not leave a half-applied UPDATE pending on the shared connection.
            self.db.conn.rollback()
            logging.error(f"处理交易反馈时发生错误: {str(e)}")
            raise

    def _get_stock_data(self, stock_code: str, days: int = 30) -> pd.DataFrame:
        """Fetch forward-adjusted (qfq) daily bars for the last *days* calendar days.

        Columns are renamed to English (see ``_COLUMN_MAP``) so the result
        can be passed straight to the strategy.
        """
        end_date = datetime.now()
        start_date = end_date - timedelta(days=days)
        df = ak.stock_zh_a_hist(
            symbol=stock_code,
            period="daily",
            start_date=start_date.strftime("%Y%m%d"),
            end_date=end_date.strftime("%Y%m%d"),
            adjust="qfq"
        )
        return self._normalize_columns(df)
class PerformanceAnalyzer:
def __init__(self, db_manager: DatabaseManager):
self.db = db_manager
def analyze(self) -> Dict:
"""分析交易和持仓表现"""
# 分析交易执行质量
trade_metrics = self._analyze_trade_execution()
# 分析持仓表现
position_metrics = self._analyze_positions()
return {
"trade_execution": trade_metrics,
"position_performance": position_metrics
}
def _analyze_trade_execution(self) -> Dict:
"""分析交易执行质量"""
cursor = self.db.conn.execute('''
SELECT s.stock_code, s.signal_type, s.suggested_price,
s.suggested_quantity, t.actual_price, t.actual_quantity
FROM signals s
JOIN trades t ON s.id = t.signal_id
WHERE s.status = 'EXECUTED'
''')
trades = cursor.fetchall()
if not trades:
return {"message": "没有足够的交易数据进行分析"}
# 计算关键指标
price_slippage = []
quantity_fill = []
execution_delay = []
for trade in trades:
price_diff = (trade[4] - trade[2]) / trade[2] * 100
quantity_diff = trade[5] / trade[3] * 100
price_slippage.append(price_diff)
quantity_fill.append(quantity_diff)
return {
"total_trades": len(trades),
"avg_price_slippage": np.mean(price_slippage),
"max_price_slippage": max(price_slippage),
"avg_quantity_fill": np.mean(quantity_fill),
"price_slippage_std": np.std(price_slippage)
}
def _analyze_positions(self) -> Dict:
"""分析持仓表现"""
cursor = self.db.conn.execute('''
SELECT stock_code, quantity, avg_cost, entry_price,
last_price, stop_loss, target_price, position_type,
timestamp, status
FROM positions
''')
positions = cursor.fetchall()
if not positions:
return {"message": "没有持仓数据进行分析"}
active_positions = []
closed_positions = []
total_profit = 0
win_count = 0
for pos in positions:
profit = (pos[4] - pos[2]) * pos[1] # (last_price - avg_cost) * quantity
profit_pct = (pos[4] / pos[2] - 1) * 100
if pos[9] == 'ACTIVE':
active_positions.append({
'stock_code': pos[0],
'profit': profit,
'profit_pct': profit_pct
})
else:
closed_positions.append({
'stock_code': pos[0],
'profit': profit,
'profit_pct': profit_pct
})
if profit > 0:
win_count += 1
total_profit += profit
return {
"active_positions": len(active_positions),
"closed_positions": len(closed_positions),
"total_profit": total_profit,
"win_rate": win_count / len(closed_positions) if closed_positions else 0,
"avg_profit_active": np.mean([p['profit'] for p in active_positions]) if active_positions else 0,
"avg_profit_closed": np.mean([p['profit'] for p in closed_positions]) if closed_positions else 0,
"best_position": max([p['profit_pct'] for p in active_positions + closed_positions]) if positions else 0,
"worst_position": min([p['profit_pct'] for p in active_positions + closed_positions]) if positions else 0
}
def main():
    """Run the monitoring loop until interrupted.

    On each cycle, every code in ``config['stock_codes']`` is processed, then
    the loop sleeps ``config['check_interval']`` seconds. The performance
    report is logged once per calendar day, on the first cycle that runs
    during 15:00.
    """
    # ``time`` is not imported at module level; without this line
    # time.sleep raised NameError.
    import time

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        filename='turtle_trader.log'
    )
    # Also echo log records to the console.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    console_handler.setFormatter(formatter)
    logging.getLogger().addHandler(console_handler)

    try:
        trader = TurtleTrader()
        analyzer = PerformanceAnalyzer(trader.db)
        last_analysis_date = None
        while True:
            for stock_code in trader.config['stock_codes']:
                trader.process_stock(stock_code)

            # Daily report after the close. The date guard stops it from
            # repeating on every cycle during the 15:00 hour.
            now = datetime.now()
            if now.hour == 15 and last_analysis_date != now.date():
                analysis = analyzer.analyze()
                logging.info(f"每日性能分析报告: {json.dumps(analysis, indent=2, ensure_ascii=False)}")
                last_analysis_date = now.date()

            time.sleep(trader.config['check_interval'])
    except KeyboardInterrupt:
        logging.info("系统正常关闭")
    except Exception as e:
        logging.error(f"系统运行出错: {str(e)}")
        raise


if __name__ == "__main__":
    main()