# Trad / ml_engine/strategies.py
# Last commit 4cdcb03 by Riy777: "Update ml_engine/strategies.py" (verified)
# ml_engine/strategies.py (Updated to use LearningHub for weights)
import asyncio
# (Import from internal modules)
from .patterns import ChartPatternAnalyzer
class PatternEnhancedStrategyEngine:
    """Boost selected strategy scores when a sufficiently confident chart pattern is present."""

    # Pattern names that receive an extra x1.1 boost in _calculate_pattern_enhancement.
    _HIGH_RELIABILITY_PATTERNS = ('Double Top', 'Double Bottom', 'Head & Shoulders', 'Cup and Handle')
    # Pattern families used by _get_pattern_appropriate_strategies to pick strategies.
    _REVERSAL_PATTERNS = ('Double Top', 'Double Bottom', 'Head & Shoulders', 'Triple Top', 'Triple Bottom')
    _CONTINUATION_PATTERNS = ('Flags', 'Pennants', 'Triangles', 'Rectangles')
    # Minimum pattern confidence before any strategy score is enhanced.
    _MIN_PATTERN_CONFIDENCE = 0.6

    def __init__(self, data_manager, learning_hub):
        self.data_manager = data_manager
        # The learning hub (replaces the former learning_engine dependency).
        self.learning_hub = learning_hub
        self.pattern_analyzer = ChartPatternAnalyzer()

    async def enhance_strategy_with_patterns(self, strategy_scores, pattern_analysis, symbol):
        """Scale up the scores of strategies that suit the detected chart pattern.

        Args:
            strategy_scores: Dict of strategy name -> score. Modified in place.
            pattern_analysis: Dict with 'pattern_detected', 'pattern_confidence'
                and 'predicted_direction'; may be falsy.
            symbol: Not used by this method.

        Returns:
            The same ``strategy_scores`` dict. Boosted entries are capped at 1.0.
            It is returned unchanged when there is no usable pattern or the
            confidence is below the threshold.
        """
        if not pattern_analysis:
            return strategy_scores
        pattern_name = pattern_analysis.get('pattern_detected', '')
        if pattern_name in ('no_clear_pattern', 'insufficient_data'):
            return strategy_scores

        # A present-but-None confidence is treated as "no confidence" instead of
        # raising TypeError in the comparison below.
        pattern_confidence = pattern_analysis.get('pattern_confidence') or 0
        if not pattern_confidence >= self._MIN_PATTERN_CONFIDENCE:
            return strategy_scores

        predicted_direction = pattern_analysis.get('predicted_direction', '')
        enhancement_factor = self._calculate_pattern_enhancement(pattern_confidence, pattern_name)
        for strategy in self._get_pattern_appropriate_strategies(pattern_name, predicted_direction):
            if strategy in strategy_scores:
                strategy_scores[strategy] = min(strategy_scores[strategy] * enhancement_factor, 1.0)
        return strategy_scores

    def _calculate_pattern_enhancement(self, pattern_confidence, pattern_name):
        """Return the multiplicative boost for a pattern.

        The boost is ``1 + 0.3 * confidence``, multiplied by 1.1 for
        high-reliability patterns, and capped at 1.5.
        """
        enhancement = 1.0 + (pattern_confidence * 0.3)
        if pattern_name in self._HIGH_RELIABILITY_PATTERNS:
            enhancement *= 1.1
        return min(enhancement, 1.5)

    def _get_pattern_appropriate_strategies(self, pattern_name, direction):
        """Return the strategy names to boost for a given pattern and predicted direction."""
        if pattern_name in self._REVERSAL_PATTERNS:
            if direction == 'down':
                return ['breakout_momentum', 'trend_following']
            return ['mean_reversion', 'breakout_momentum']
        if pattern_name in self._CONTINUATION_PATTERNS:
            return ['trend_following', 'breakout_momentum']
        return ['breakout_momentum', 'hybrid_ai']
class MultiStrategyEngine:
    """Score a symbol against several trading strategies and combine the results.

    Each strategy produces a raw score capped at 1.0, or ``None`` when it has
    nothing to contribute. Raw scores are multiplied by per-strategy weights.
    The weights come from the learning hub when it is initialized, and from
    get_default_weights otherwise. The weighted scores are then optionally
    boosted by the pattern enhancer.
    """

    def __init__(self, data_manager, learning_hub):
        self.data_manager = data_manager
        # Source of optimized weights (replaces the former learning_engine dependency).
        self.learning_hub = learning_hub
        # The enhancer shares the same data manager and learning hub.
        self.pattern_enhancer = PatternEnhancedStrategyEngine(data_manager, learning_hub)
        # 'hybrid_ai' is listed for completeness, but evaluate_all_strategies runs it
        # separately because it needs the other strategies' raw scores.
        self.strategies = {
            'trend_following': self._trend_following_strategy,
            'mean_reversion': self._mean_reversion_strategy,
            'breakout_momentum': self._breakout_momentum_strategy,
            'volume_spike': self._volume_spike_strategy,
            'whale_tracking': self._whale_tracking_strategy,
            'pattern_recognition': self._pattern_recognition_strategy,
            'hybrid_ai': self._hybrid_ai_strategy,
        }

    @staticmethod
    def _num(values, key, default):
        """Return ``values[key]``, or *default* when the key is missing or its value is None."""
        value = values.get(key)
        return default if value is None else value

    async def _resolve_weights(self, market_context):
        """Return strategy weights from the learning hub, falling back to the defaults."""
        hub = self.learning_hub
        if hub and getattr(hub, 'initialized', False):
            try:
                market_condition = market_context.get('market_trend', 'sideways_market')
                weights = await hub.get_optimized_weights(market_condition)
                if isinstance(weights, dict):
                    return weights
                print("⚠️ Learning hub returned no usable weights. Using defaults.")
            except Exception as e:
                print(f"⚠️ Error getting optimized weights from hub: {e}. Using defaults.")
        return await self.get_default_weights()

    async def evaluate_all_strategies(self, symbol_data, market_context):
        """Evaluate all trading strategies for one symbol.

        Args:
            symbol_data: Per-symbol dict. This method and the strategies read
                'advanced_indicators', 'current_price', 'pattern_analysis',
                'whale_data', 'symbol' and 'monte_carlo_probability'. When at
                least one strategy scores, 'recommended_strategy' and
                'strategy_confidence' are written back into it. Both are
                chosen from the raw (unweighted) scores.
            market_context: Dict whose 'market_trend' value (default
                'sideways_market') is passed to the hub to select weights.

        Returns:
            A ``(strategy_scores, base_scores)`` tuple. ``strategy_scores`` holds
            the weighted and possibly pattern-enhanced scores; ``base_scores``
            holds the raw scores. Both are keyed by strategy name. Returns
            ``({}, {})`` on an unexpected error.
        """
        try:
            optimized_weights = await self._resolve_weights(market_context)

            strategy_scores = {}
            base_scores = {}
            for strategy_name, strategy_function in self.strategies.items():
                if strategy_name == 'hybrid_ai':
                    continue  # evaluated below, once the other raw scores exist
                try:
                    base_score = await strategy_function(symbol_data, market_context)
                    if base_score is None:
                        continue
                    base_scores[strategy_name] = base_score
                    weight = optimized_weights.get(strategy_name, 0.1)
                    strategy_scores[strategy_name] = min(base_score * weight, 1.0)
                except Exception as error:
                    print(f"❌ Error evaluating strategy {strategy_name}: {error}")

            try:
                hybrid_score = await self._hybrid_ai_strategy(symbol_data, market_context, base_scores)
                if hybrid_score is not None:
                    base_scores['hybrid_ai'] = hybrid_score
                    weight = optimized_weights.get('hybrid_ai', 0.1)
                    strategy_scores['hybrid_ai'] = min(hybrid_score * weight, 1.0)
            except Exception as e:
                print(f"❌ Error in hybrid_ai strategy: {e}")

            pattern_analysis = symbol_data.get('pattern_analysis')
            if pattern_analysis:
                strategy_scores = await self.pattern_enhancer.enhance_strategy_with_patterns(
                    strategy_scores, pattern_analysis, symbol_data.get('symbol')
                )

            if base_scores:
                best_name, best_score = max(base_scores.items(), key=lambda item: item[1])
                symbol_data['recommended_strategy'] = best_name
                symbol_data['strategy_confidence'] = best_score

            return strategy_scores, base_scores
        except Exception as error:
            print(f"❌ Error in evaluate_all_strategies: {error}")
            return {}, {}

    async def get_default_weights(self):
        """Return the fallback per-strategy weights (they sum to 1.0)."""
        return {
            'trend_following': 0.15,
            'mean_reversion': 0.12,
            'breakout_momentum': 0.20,
            'volume_spike': 0.13,
            'whale_tracking': 0.20,
            'pattern_recognition': 0.10,
            'hybrid_ai': 0.10
        }

    async def _trend_following_strategy(self, symbol_data, market_context):
        """Score bullish EMA alignment on the 1h, 15m and 5m timeframes.

        For each timeframe where EMA21 > EMA50, add 0.2. Add a further 0.1 if
        ADX > 20, and 0.05 if the current price is above EMA21. The total is
        capped at 1.0; returns None on error.
        """
        try:
            score = 0.0
            indicators = symbol_data.get('advanced_indicators') or {}
            for timeframe in ('1h', '15m', '5m'):
                if timeframe not in indicators:
                    continue
                tf_indicators = indicators[timeframe]
                ema_21 = tf_indicators.get('ema_21')
                ema_50 = tf_indicators.get('ema_50')
                if ema_21 is None or ema_50 is None or not ema_21 > ema_50:
                    continue
                score += 0.2
                if self._num(tf_indicators, 'adx', 0) > 20:
                    score += 0.1
                if symbol_data['current_price'] > ema_21:
                    score += 0.05
            return min(score, 1.0)
        except Exception:
            return None

    def _check_ema_alignment(self, indicators):
        """Return True when EMA9 > EMA21 > EMA50; False if any of them is missing.

        Not referenced elsewhere in this class.
        """
        required_emas = ['ema_9', 'ema_21', 'ema_50']
        if all(ema in indicators for ema in required_emas):
            return (indicators['ema_9'] > indicators['ema_21'] > indicators['ema_50'])
        return False

    async def _mean_reversion_strategy(self, symbol_data, market_context):
        """Score oversold conditions (RSI < 25 and/or price near the lower Bollinger band).

        Per 1h/15m timeframe, add 0.4 if either condition holds and 0.2 more if
        both hold. The total is capped at 1.0; returns None on error.
        """
        try:
            score = 0.0
            current_price = symbol_data['current_price']
            indicators = symbol_data.get('advanced_indicators') or {}
            for timeframe in ('1h', '15m'):
                if timeframe not in indicators:
                    continue
                tf_indicators = indicators[timeframe]
                rsi_value = self._num(tf_indicators, 'rsi', 50)
                bb_lower = tf_indicators.get('bb_lower')
                bb_upper = tf_indicators.get('bb_upper')
                if bb_lower is None or bb_upper is None:
                    continue
                band_width = bb_upper - bb_lower
                # A zero- or negative-width band is treated as "price at mid-band".
                position_in_band = (current_price - bb_lower) / band_width if band_width > 0 else 0.5
                is_rsi_oversold = rsi_value < 25
                is_bb_oversold = position_in_band < 0.1
                if is_rsi_oversold or is_bb_oversold:
                    score += 0.4
                if is_rsi_oversold and is_bb_oversold:
                    score += 0.2
            return min(score, 1.0)
        except Exception:
            return None

    async def _breakout_momentum_strategy(self, symbol_data, market_context):
        """Score volume-confirmed breakouts on the 1h, 15m and 5m timeframes.

        Only timeframes with volume_ratio >= 1.5 count. Each one adds 0.2, then
        0.1 for a positive MACD histogram, 0.1 for ATR% > 1.5, and 0.05 when the
        price is above VWAP. The total is capped at 1.0; returns None on error.
        """
        try:
            score = 0.0
            current_price = symbol_data['current_price']
            indicators = symbol_data.get('advanced_indicators') or {}
            for timeframe in ('1h', '15m', '5m'):
                if timeframe not in indicators:
                    continue
                tf_indicators = indicators[timeframe]
                if self._num(tf_indicators, 'volume_ratio', 0) < 1.5:
                    continue
                score += 0.2
                if self._num(tf_indicators, 'macd_hist', 0) > 0:
                    score += 0.1
                if self._num(tf_indicators, 'atr_percent', 0) > 1.5:
                    score += 0.1
                vwap = tf_indicators.get('vwap')
                if vwap and current_price > vwap:
                    score += 0.05
            return min(score, 1.0)
        except Exception:
            return None

    async def _volume_spike_strategy(self, symbol_data, market_context):
        """Score volume spikes: +0.45 / +0.25 / +0.15 per timeframe for ratios above 3 / 2 / 1.5.

        The total is capped at 1.0; returns None on error.
        """
        try:
            score = 0.0
            indicators = symbol_data.get('advanced_indicators') or {}
            for timeframe in ('1h', '15m', '5m'):
                if timeframe not in indicators:
                    continue
                volume_ratio = self._num(indicators[timeframe], 'volume_ratio', 0)
                if volume_ratio > 3.0:
                    score += 0.45
                elif volume_ratio > 2.0:
                    score += 0.25
                elif volume_ratio > 1.5:
                    score += 0.15
            return min(score, 1.0)
        except Exception:
            return None

    async def _whale_tracking_strategy(self, symbol_data, market_context):
        """Score buy-side whale signals from the data manager.

        Returns ``min(confidence * 1.2, 1.0)`` for 'BUY'/'STRONG_BUY' actions.
        Returns None when whale data is unavailable, for any other action, or
        on error.
        """
        try:
            whale_data = symbol_data.get('whale_data') or {}
            if not whale_data.get('data_available', False):
                return None
            whale_signal = await self.data_manager.get_whale_trading_signal(
                symbol_data['symbol'], whale_data, market_context
            )
            if not whale_signal:
                return None
            if whale_signal.get('action') in ('STRONG_BUY', 'BUY'):
                confidence = whale_signal.get('confidence', 0)
                return min(confidence * 1.2, 1.0)
            return None
        except Exception:
            return None

    async def _pattern_recognition_strategy(self, symbol_data, market_context):
        """Score an upward chart pattern, or fall back to a 1h momentum check.

        If the pattern confidence is above 0.6, only an 'up' prediction
        contributes (confidence * 0.8). Otherwise, 1h RSI > 60 together with a
        positive MACD histogram adds 0.3. The total is capped at 1.0; returns
        None on error.
        """
        try:
            score = 0.0
            pattern_analysis = symbol_data.get('pattern_analysis')
            # A present-but-None confidence is treated as 0 rather than raising.
            confidence = (pattern_analysis or {}).get('pattern_confidence') or 0
            if pattern_analysis and confidence > 0.6:
                if pattern_analysis.get('predicted_direction') == 'up':
                    score += confidence * 0.8
            else:
                indicators = symbol_data.get('advanced_indicators') or {}
                if '1h' in indicators:
                    tf_indicators = indicators['1h']
                    if (self._num(tf_indicators, 'rsi', 50) > 60
                            and self._num(tf_indicators, 'macd_hist', 0) > 0):
                        score += 0.3
            return min(score, 1.0)
        except Exception:
            return None

    async def _hybrid_ai_strategy(self, symbol_data, market_context, base_scores):
        """Combine the Monte Carlo probability with agreement between the other raw scores.

        Returns a score clamped to [0, 1], or None on error.
        """
        try:
            score = 0.0
            monte_carlo_prob = symbol_data.get('monte_carlo_probability')
            if monte_carlo_prob is not None:
                score += monte_carlo_prob * 0.4
            breakout_score = base_scores.get('breakout_momentum', 0)
            volume_score = base_scores.get('volume_spike', 0)
            whale_score = base_scores.get('whale_tracking', 0)
            pattern_score = base_scores.get('pattern_recognition', 0)
            if breakout_score > 0.7 and volume_score > 0.6:
                score += 0.3
            if breakout_score > 0.6 and whale_score > 0.7:
                score += 0.4
            if pattern_score > 0.7 and volume_score > 0.5:
                score += 0.2
            # Strong agreement between breakout, volume and whales saturates the score.
            if breakout_score > 0.7 and whale_score > 0.7 and volume_score > 0.7:
                score = 1.0
            return max(0.0, min(score, 1.0))
        except Exception:
            return None
# Module-load banner. The emoji was mis-decoded (mojibake "βœ…"); restored to ✅
# to match the correctly encoded ⚠️/❌ used elsewhere in this file.
print("✅ ML Module: Strategy Engine loaded (V3 - Integrated LearningHub for weights)")