"""
Signal Engine - Multi-timeframe signal generation
Timeframes: 5s, 10s, 30s, 1m, 5m
Indicators: RSI, MACD, Bollinger Bands, EMA, Stochastic, CCI
"""
import asyncio
import random
from datetime import datetime
from typing import Dict, List, Optional, Any
from collections import deque
import math


class TechnicalIndicators:
    """Pure Python technical indicator calculations.

    Every method is a ``staticmethod`` that takes plain lists of floats
    ordered oldest-first and returns ``None`` when there is not enough data
    to compute the indicator.
    """

    @staticmethod
    def sma(prices: List[float], period: int) -> Optional[float]:
        """Return the simple moving average of the last ``period`` prices."""
        if len(prices) < period:
            return None
        return sum(prices[-period:]) / period

    @staticmethod
    def _ema_series(prices: List[float], period: int) -> List[float]:
        """Return the EMA of ``prices`` at every index from ``period - 1`` on.

        The EMA is seeded with the SMA of the first ``period`` prices and then
        smoothed with ``k = 2 / (period + 1)``. Returns ``[]`` when there are
        fewer than ``period`` prices or ``period`` is not positive.
        """
        if period <= 0 or len(prices) < period:
            return []
        k = 2 / (period + 1)
        value = sum(prices[:period]) / period
        series = [value]
        for price in prices[period:]:
            value = price * k + value * (1 - k)
            series.append(value)
        return series

    @staticmethod
    def ema(prices: List[float], period: int) -> Optional[float]:
        """Return the latest exponential moving average, or ``None`` if too short."""
        series = TechnicalIndicators._ema_series(prices, period)
        return series[-1] if series else None

    @staticmethod
    def rsi(prices: List[float], period: int = 14) -> Optional[float]:
        """Return the RSI over the last ``period`` price changes.

        Uses simple (not Wilder-smoothed) averages of gains and losses.
        A window with no movement at all yields the neutral value 50; a window
        with gains but no losses yields 100.
        """
        if period <= 0 or len(prices) < period + 1:
            return None
        window = prices[-(period + 1):]
        total_gain = 0.0
        total_loss = 0.0
        for prev, cur in zip(window, window[1:]):
            diff = cur - prev
            if diff > 0:
                total_gain += diff
            else:
                total_loss -= diff
        avg_gain = total_gain / period
        avg_loss = total_loss / period
        if avg_loss == 0:
            # Flat series: neither overbought nor oversold.
            return 50.0 if avg_gain == 0 else 100.0
        rs = avg_gain / avg_loss
        return 100 - (100 / (1 + rs))

    @staticmethod
    def macd(prices: List[float], fast=12, slow=26, signal=9) -> Optional[Dict]:
        """Return the MACD line, its signal line and the histogram.

        The MACD line is ``EMA(fast) - EMA(slow)``. The signal line is the
        ``signal``-period EMA of the MACD line's history, and
        ``histogram = macd - signal``. Requires at least ``slow + signal`` prices.
        """
        if len(prices) < slow + signal:
            return None
        fast_series = TechnicalIndicators._ema_series(prices, fast)
        slow_series = TechnicalIndicators._ema_series(prices, slow)
        if not fast_series or not slow_series:
            return None
        # Align both EMA series on their most recent values, where both exist.
        overlap = min(len(fast_series), len(slow_series))
        macd_series = [
            f - s for f, s in zip(fast_series[-overlap:], slow_series[-overlap:])
        ]
        signal_line = TechnicalIndicators.ema(macd_series, signal)
        if signal_line is None:
            return None
        macd_line = macd_series[-1]
        histogram = macd_line - signal_line
        return {"macd": macd_line, "signal": signal_line, "histogram": histogram}

    @staticmethod
    def bollinger_bands(prices: List[float], period: int = 20, std_dev: float = 2.0) -> Optional[Dict]:
        """Return upper/middle/lower Bollinger Bands over the last ``period`` prices.

        The middle band is the SMA; the bands are ``std_dev`` population
        standard deviations away from it.
        """
        if len(prices) < period:
            return None
        recent = prices[-period:]
        sma = sum(recent) / period
        variance = sum((p - sma) ** 2 for p in recent) / period
        std = math.sqrt(variance)
        return {
            "upper": sma + std_dev * std,
            "middle": sma,
            "lower": sma - std_dev * std,
        }

    @staticmethod
    def _percent_k(highs: List[float], lows: List[float], closes: List[float],
                   end: int, k_period: int) -> Optional[float]:
        """Return stochastic %K for the ``k_period`` bars ending before index ``end``."""
        high_max = max(highs[end - k_period:end])
        low_min = min(lows[end - k_period:end])
        if high_max == low_min:
            return None
        return ((closes[end - 1] - low_min) / (high_max - low_min)) * 100

    @staticmethod
    def stochastic(highs: List[float], lows: List[float], closes: List[float], k_period=14,
                   d_period=3) -> Optional[Dict]:
        """Return the stochastic oscillator ``{"k": %K, "d": %D}``.

        %K is measured on the latest ``k_period`` bars. %D is the mean of the
        last ``d_period`` %K values; it uses fewer when history is short or an
        older window has zero range. Assumes the three series are parallel and
        of equal length. Returns ``None`` if the latest window has zero range.
        """
        if k_period <= 0 or len(closes) < k_period:
            return None
        n = len(closes)
        current = TechnicalIndicators._percent_k(highs, lows, closes, n, k_period)
        if current is None:
            return None
        k_values = [current]
        for lag in range(1, d_period):
            end = n - lag
            if end < k_period:
                break
            value = TechnicalIndicators._percent_k(highs, lows, closes, end, k_period)
            if value is not None:
                k_values.append(value)
        return {"k": current, "d": sum(k_values) / len(k_values)}

    @staticmethod
    def cci(highs: List[float], lows: List[float], closes: List[float], period=20) -> Optional[float]:
        """Return the Commodity Channel Index over the last ``period`` bars.

        Returns 0.0 when the typical price has no mean deviation.
        """
        if len(closes) < period:
            return None
        typical_prices = [(h + l + c) / 3 for h, l, c in zip(highs[-period:], lows[-period:], closes[-period:])]
        mean = sum(typical_prices) / period
        mean_dev = sum(abs(p - mean) for p in typical_prices) / period
        if mean_dev == 0:
            return 0.0
        return (typical_prices[-1] - mean) / (0.015 * mean_dev)


class SignalEngine:
    """Multi-indicator CALL/PUT signal generator for a fixed asset list.

    ``run_signal_loop`` advances one base tick every ``BASE_TICK_SECONDS``.
    On each tick it:

    - appends a new bar per asset, from the Quotex client when connected or
      from a simulated random walk otherwise;
    - recomputes the indicators on each asset's bar history;
    - runs the voting analysis for every timeframe whose period is due on
      that tick.

    Resulting signals are cached, broadcast through
    ``ws_manager.broadcast_signal`` and persisted in the background.

    NOTE(review): every timeframe is analysed on the same base-tick bar
    series, so timeframes differ only in how often they emit, not in the
    bar size the indicators see.
    """

    # Seconds per loop iteration; TIMEFRAMES values are multiples of this.
    BASE_TICK_SECONDS = 5
    # Maximum number of signals kept in the in-memory history (newest first).
    MAX_SIGNAL_HISTORY = 500

    TIMEFRAMES = {
        "5s": 5,
        "10s": 10,
        "30s": 30,
        "1m": 60,
        "5m": 300,
    }

    ASSETS = ["EURUSD", "GBPUSD", "USDJPY", "BTCUSD", "ETHUSD", "GOLD", "AAPL"]

    def __init__(self, ws_manager):
        """Create an idle engine that publishes signals through ``ws_manager``."""
        self.ws_manager = ws_manager
        self.quotex_client = None
        self.running = False
        # Rolling bar history per asset; each bar is a dict with
        # close/high/low/volume/timestamp.
        self._price_history: Dict[str, deque] = {
            asset: deque(maxlen=200) for asset in self.ASSETS
        }
        # Most recent signal per "<asset>_<timeframe>" key.
        self._latest_signals: Dict[str, Dict] = {}
        # Emitted signals, newest first, capped at MAX_SIGNAL_HISTORY.
        self._signal_history: List[Dict] = []
        self._indicators = TechnicalIndicators()
        # Strong references to fire-and-forget tasks. The event loop keeps only
        # weak references, so an unreferenced task can be collected mid-run.
        self._background_tasks: set = set()

    def set_quotex_client(self, client):
        """Attach the client used for live prices (``None`` keeps simulating)."""
        self.quotex_client = client

    async def run_signal_loop(self):
        """Seed history, then update prices and emit signals until ``running`` is cleared.

        An error in one iteration is reported, and the loop resumes after one
        base tick.
        """
        self.running = True
        print("📊 Signal engine loop started")

        # Initialize price history
        await self._seed_price_history()

        tick = 0
        while self.running:
            try:
                await self._update_prices()
                await self._generate_signals(tick)
                tick += 1
                await asyncio.sleep(self.BASE_TICK_SECONDS)
            except Exception as e:
                print(f"Signal loop error: {e}")
                await asyncio.sleep(self.BASE_TICK_SECONDS)

    async def _seed_price_history(self):
        """Append 100 synthetic bars per asset around a fixed base price.

        High and low are derived from each bar's own close, so every bar
        satisfies ``low <= close <= high``.
        """
        base_prices = {
            "EURUSD": 1.0850, "GBPUSD": 1.2650, "USDJPY": 149.50,
            "BTCUSD": 43500.0, "ETHUSD": 2650.0, "GOLD": 2050.0, "AAPL": 192.0,
        }
        for asset in self.ASSETS:
            base = base_prices.get(asset, 1.0)
            for _ in range(100):
                close = base * (1 + random.uniform(-0.01, 0.01))
                self._price_history[asset].append({
                    "close": close,
                    "high": close * (1 + random.uniform(0.001, 0.005)),
                    "low": close * (1 - random.uniform(0.001, 0.005)),
                    "volume": random.uniform(1000, 50000),
                    "timestamp": datetime.utcnow().isoformat(),
                })

    async def _update_prices(self):
        """Append one new bar per asset, from the live client or a simulation.

        If the live client raises for an asset, that asset is skipped for this
        tick. Fake prices are never mixed into a live feed. A falsy price
        (e.g. ``None``) is also skipped.
        """
        for asset in self.ASSETS:
            bars = self._price_history[asset]
            if self.quotex_client and self.quotex_client.is_connected:
                try:
                    price = await self.quotex_client.get_realtime_price(asset)
                except Exception as e:
                    print(f"Price fetch error for {asset}: {e}")
                    continue
            else:
                # Simulate realistic price movement
                last_price = bars[-1]["close"] if bars else 1.0
                price = last_price * (1 + random.gauss(0, 0.001))

            if price:
                bars.append({
                    "close": price,
                    "high": price * (1 + random.uniform(0, 0.001)),
                    "low": price * (1 - random.uniform(0, 0.001)),
                    "volume": random.uniform(1000, 50000),
                    "timestamp": datetime.utcnow().isoformat(),
                })

    async def _generate_signals(self, tick: int):
        """Analyse every asset and emit signals for the timeframes due on ``tick``.

        A timeframe is due when ``tick`` is a multiple of its length in base
        ticks. Assets with fewer than 30 bars are skipped.
        """
        for asset in self.ASSETS:
            history = self._price_history[asset]
            if len(history) < 30:
                continue

            closes = [bar["close"] for bar in history]
            highs = [bar["high"] for bar in history]
            lows = [bar["low"] for bar in history]

            # Compute indicators once; all timeframes share them.
            rsi = self._indicators.rsi(closes, 14)
            macd = self._indicators.macd(closes)
            bb = self._indicators.bollinger_bands(closes, 20)
            stoch = self._indicators.stochastic(highs, lows, closes, 14)
            cci = self._indicators.cci(highs, lows, closes, 20)
            ema9 = self._indicators.ema(closes, 9)
            ema21 = self._indicators.ema(closes, 21)

            for tf_name, tf_seconds in self.TIMEFRAMES.items():
                # max(1, ...) guards against a timeframe shorter than one tick.
                ticks_per_bar = max(1, tf_seconds // self.BASE_TICK_SECONDS)
                if tick % ticks_per_bar != 0:
                    continue

                signal = self._analyze_signal(
                    asset=asset,
                    closes=closes,
                    rsi=rsi,
                    macd=macd,
                    bb=bb,
                    stoch=stoch,
                    cci=cci,
                    ema9=ema9,
                    ema21=ema21,
                    timeframe=tf_name,
                )
                if not signal:
                    continue

                self._record_signal(asset, tf_name, signal)

                # A failing broadcast must not stop the remaining signals.
                try:
                    await self.ws_manager.broadcast_signal(signal)
                except Exception as e:
                    print(f"Signal broadcast error: {e}")

                # Save to DB (non-blocking)
                self._spawn_background(self._save_signal(signal))

    def _record_signal(self, asset: str, timeframe: str, signal: Dict) -> None:
        """Store ``signal`` as the latest for its key and prepend it to the bounded history."""
        self._latest_signals[f"{asset}_{timeframe}"] = signal
        self._signal_history.insert(0, signal)
        del self._signal_history[self.MAX_SIGNAL_HISTORY:]

    def _spawn_background(self, coro) -> None:
        """Schedule ``coro`` as a task and hold a reference until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    def _analyze_signal(
        self, asset, closes, rsi, macd, bb, stoch, cci, ema9, ema21, timeframe
    ) -> Optional[Dict]:
        """Combine indicator votes into a CALL/PUT signal, or ``None`` if unclear.

        Each available indicator adds its weight to ``total_weight`` and may
        vote for CALL or PUT. Exact ties (EMA9 == EMA21, price == BB middle)
        cast no vote. A direction is emitted only when it gathers more than
        60% of the total weight.
        """
        call_votes = 0
        put_votes = 0
        total_weight = 0

        current_price = closes[-1] if closes else 0

        # RSI (weight 2)
        if rsi is not None:
            total_weight += 2
            if rsi < 30:
                call_votes += 2  # Oversold → CALL
            elif rsi > 70:
                put_votes += 2   # Overbought → PUT
            elif rsi < 45:
                call_votes += 1
            elif rsi > 55:
                put_votes += 1

        # MACD (weight 2)
        if macd:
            total_weight += 2
            if macd["histogram"] > 0 and macd["macd"] > macd["signal"]:
                call_votes += 2
            elif macd["histogram"] < 0 and macd["macd"] < macd["signal"]:
                put_votes += 2

        # Bollinger Bands (weight 2)
        if bb and current_price:
            total_weight += 2
            if current_price < bb["lower"]:
                call_votes += 2  # Price below lower → bounce up
            elif current_price > bb["upper"]:
                put_votes += 2   # Price above upper → bounce down
            elif current_price < bb["middle"]:
                call_votes += 1
            elif current_price > bb["middle"]:
                put_votes += 1

        # Stochastic (weight 1)
        if stoch:
            total_weight += 1
            if stoch["k"] < 20:
                call_votes += 1
            elif stoch["k"] > 80:
                put_votes += 1

        # CCI (weight 1)
        if cci is not None:
            total_weight += 1
            if cci < -100:
                call_votes += 1
            elif cci > 100:
                put_votes += 1

        # EMA cross (weight 2)
        if ema9 is not None and ema21 is not None:
            total_weight += 2
            if ema9 > ema21:
                call_votes += 2
            elif ema9 < ema21:
                put_votes += 2

        if total_weight == 0:
            return None

        call_pct = call_votes / total_weight
        put_pct = put_votes / total_weight

        if call_pct > 0.60:
            direction = "CALL"
            strength = call_pct
        elif put_pct > 0.60:
            direction = "PUT"
            strength = put_pct
        else:
            return None  # No clear signal

        return {
            "asset": asset,
            "platform": "quotex",
            "direction": direction,
            "timeframe": timeframe,
            "strength": round(strength * 100, 1),
            "current_price": round(current_price, 5),
            "indicators": {
                # `is not None` so that a legitimate 0.0 reading is reported.
                "rsi": round(rsi, 2) if rsi is not None else None,
                "macd": round(macd["macd"], 6) if macd else None,
                "macd_histogram": round(macd["histogram"], 6) if macd else None,
                "bb_upper": round(bb["upper"], 5) if bb else None,
                "bb_lower": round(bb["lower"], 5) if bb else None,
                "stoch_k": round(stoch["k"], 2) if stoch else None,
                "cci": round(cci, 2) if cci is not None else None,
                "ema9": round(ema9, 5) if ema9 is not None else None,
                "ema21": round(ema21, 5) if ema21 is not None else None,
            },
            "timestamp": datetime.utcnow().isoformat(),
        }

    def get_latest_signals(self, asset=None, timeframe=None) -> List[Dict]:
        """Return the latest signal per asset/timeframe, optionally filtered."""
        signals = list(self._latest_signals.values())
        if asset:
            signals = [s for s in signals if s["asset"] == asset]
        if timeframe:
            signals = [s for s in signals if s["timeframe"] == timeframe]
        return signals

    def get_signal_history(self, limit=50, asset=None) -> List[Dict]:
        """Return up to ``limit`` most recent signals (newest first), optionally for one asset."""
        history = self._signal_history
        if asset:
            history = [s for s in history if s["asset"] == asset]
        return history[:limit]

    async def _save_signal(self, signal: Dict):
        """Persist ``signal`` via ``app.database.save_signal``; failures are reported, not raised."""
        try:
            from app.database import save_signal
            await save_signal(signal)
        except Exception as e:
            # Persistence is best-effort, but failures should be visible.
            print(f"Signal save error: {e}")
