import logging
import pandas as pd
import numpy as np
from numba import cuda

# -------------------------------
# 1) Configure Logging
# -------------------------------
logger = logging.getLogger("GPUBacktest")
# main() reports its progress with logger.info(); without an explicit level the
# logger inherits WARNING from the root logger and those messages are dropped.
logger.setLevel(logging.INFO)

ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(
    logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
)
# Attach the handler only when the logger has none yet, so re-executing this
# module in the same interpreter does not duplicate every log line.
if not logger.handlers:
    logger.addHandler(ch)

# -------------------------------
# 2) Hardcoded JSON Conditions
#    (Simplified from your example)
# -------------------------------
# Entry thresholds on the put/call ratio (PCR).
# "Bullish": PCR above this value => enter a PE position.
BULLISH_ENTRY_PCR = 0.7
# "Bearish": PCR below this value => enter a CE position.
BEARISH_ENTRY_PCR = 1.1875

# Exit thresholds on the PCR.
# "Heavily Bullish": PCR above this value => exit an open CE position.
HEAVILY_BULLISH_PCR = 1.2375
# "Heavily Bearish": PCR below this value => exit an open PE position.
HEAVILY_BEARISH_PCR = 0.65

# Universal stop-loss and profit target.
# NOTE(review): units (points vs. percent) are not shown here -- confirm.
UNIVERSAL_STOP_LOSS = -4.0
UNIVERSAL_TARGET = 10.0

# We map "PE" to position_side = +1, "CE" to position_side = -1
# So if we are in a PE, that means we want to exit if "Heavily Bearish".
# If we are in a CE, that means we want to exit if "Heavily Bullish".

# -------------------------------
# 3) GPU Kernel
#    One thread = one parameter variant
#    (Here, we only demonstrate one expiry for simplicity.)
# -------------------------------
@cuda.jit
def backtest_single_expiry_kernel(pcr_array, option_close_array,
                                  # JSON threshold constants
                                  bull_entry, heavy_bull_exit,
                                  bear_entry, heavy_bear_exit,
                                  # output
                                  pnl_array):
    """
    Backtest the PCR entry/exit rules over one expiry's bar series.

    The kernel walks the bars sequentially (open/close logic is inherently
    serial) and holds at most one position at a time. Only the thread with
    global index 0 does any work; every other thread returns immediately.
    With several parameter variants you would instead store each variant's
    thresholds in arrays, index them with the thread id and write the result
    to ``pnl_array[thread_id]``.

    Rules applied on every bar, in this order:
      1. Exit: an open PE (+1) is closed when PCR < ``heavy_bear_exit``;
         an open CE (-1) is closed when PCR > ``heavy_bull_exit``.
      2. Entry (only when flat, which includes the bar of an exit):
         PCR > ``bull_entry`` opens a PE; otherwise PCR < ``bear_entry``
         opens a CE. The entry price is that bar's close.
    A position still open after the last bar is closed at the last close.
    Realised PnL per trade is ``(exit_close - entry_close) * position_side``.

    Parameters:
        pcr_array: 1-D array of PCR values, one per bar.
        option_close_array: 1-D array of close prices aligned with
            ``pcr_array``. If the lengths differ, only the common prefix
            is processed.
        bull_entry, heavy_bull_exit, bear_entry, heavy_bear_exit: scalar
            PCR thresholds described above.
        pnl_array: output array; the total PnL is written to element 0.

    NOTE: no universal stop-loss / target check is applied; the kernel
    takes no SL/target parameters.
    """
    idx = cuda.grid(1)
    # Single-variant demonstration: only thread 0 runs the backtest.
    if idx != 0:
        return

    # Use the shorter length so a mismatched input can never be read out of
    # bounds (out-of-range device reads are not checked on the GPU).
    n = min(pcr_array.size, option_close_array.size)

    in_position = False
    position_side = 0.0  # +1 = PE, -1 = CE
    entry_price = 0.0
    total_pnl = 0.0

    # Loop over each bar/time in this single expiry
    for i in range(n):
        pcr_val = pcr_array[i]
        close_px = option_close_array[i]

        # Exit signals
        #   - In PE (+1): exit if "Heavily Bearish" => PCR < heavy_bear_exit
        #   - In CE (-1): exit if "Heavily Bullish" => PCR > heavy_bull_exit
        if in_position:
            exit_pe = position_side > 0 and pcr_val < heavy_bear_exit
            exit_ce = position_side < 0 and pcr_val > heavy_bull_exit
            if exit_pe or exit_ce:
                total_pnl += (close_px - entry_price) * position_side
                in_position = False

        # Entry signals (also evaluated on the bar where a position closed)
        #   - Bullish => PCR > bull_entry => open PE
        #   - Bearish => PCR < bear_entry => open CE
        if not in_position:
            if pcr_val > bull_entry:
                in_position = True
                position_side = 1.0
                entry_price = close_px
            elif pcr_val < bear_entry:
                in_position = True
                position_side = -1.0
                entry_price = close_px

    # At the end, if still in position, close at the last bar.
    # in_position can only be True when n >= 1, so n - 1 is a valid index.
    if in_position:
        last_close = option_close_array[n - 1]
        total_pnl += (last_close - entry_price) * position_side

    pnl_array[0] = total_pnl

def main():
    """
    Run the single-expiry PCR backtest on the GPU and log the final PnL.

    Steps:
      A) load the PCR, expiry and option-chain pickles from ``data/``;
      B) pick the first expiry and filter the option rows for it;
      C) left-merge the PCR values on (Date, Time);
      D) copy the PCR and Close columns to the device as float32 arrays;
      E) launch ``backtest_single_expiry_kernel`` with one block / one thread;
      F) copy the one-element PnL array back and log it.

    Returns early (after logging a warning) when the expiry table is empty
    or when no usable rows remain after the merge.
    """
    # --------------------------------------------------------
    # A) Read the pickled data from disk
    # --------------------------------------------------------
    logger.info("Reading pickle files...")

    # NOTE(security): pd.read_pickle unpickles arbitrary Python objects.
    # Only load files from a trusted source.
    # Forward slashes work on Windows and POSIX alike; the previous
    # "data\MAIN_..." literals relied on the invalid escape sequence "\M".
    pcr_df = pd.read_pickle("data/MAIN_NIFTY50_PCR.pkl")         # (Date, Time, PCR)
    expiry_df = pd.read_pickle("data/MAIN_NIFTY50_EXPIRIES.pkl")   # (Expiries, Status)
    options_df = pd.read_pickle(
        "data/MAIN_NIFTY50_OPTIONS_2020_NIFTY50_2020_01_09_OptionChain.pkl"
    )  # (Date, Time, Open, ..., ExpiryDate)

    logger.info("PCR DF shape: %s, columns: %s", pcr_df.shape, pcr_df.columns)
    logger.info("Expiry DF shape: %s, columns: %s", expiry_df.shape, expiry_df.columns)
    logger.info("Options DF shape: %s, columns: %s", options_df.shape, options_df.columns)

    # --------------------------------------------------------
    # B) Select the first expiry for demonstration
    #    (once working, you can loop over all expiries)
    # --------------------------------------------------------
    if expiry_df.empty:
        logger.warning("Expiry table is empty; nothing to backtest.")
        return

    first_expiry = expiry_df.iloc[0]['Expiries']  # e.g. "2025-12-25" or similar
    logger.info("Selected expiry for test: %s", first_expiry)

    # Filter options data for that expiry
    mask = (options_df['ExpiryDate'] == first_expiry)
    df_expiry = options_df[mask].copy()

    # Sort by Date, Time (important for time-series)
    df_expiry.sort_values(['Date', 'Time'], inplace=True)

    logger.info("Filtered option data shape: %s", df_expiry.shape)

    # --------------------------------------------------------
    # C) Merge with PCR Data on (Date, Time)
    #    This ensures we have a PCR value for each row
    # --------------------------------------------------------
    df_merged = pd.merge(df_expiry, pcr_df, on=['Date', 'Time'], how='left')
    df_merged.sort_values(['Date', 'Time'], inplace=True)
    # Expected columns after the merge:
    #  [Date, Time, Open, High, Low, Close, Volume, OI, OptionType,
    #   StrikePrice, Ticker, Delta, Close_Index, ExpiryDate, PCR]
    logger.info("Merged shape: %s", df_merged.shape)

    # NOTE(review): the option chain appears to carry OptionType/StrikePrice,
    # so several contracts may share one (Date, Time). The kernel treats
    # 'Close' as a single price series -- confirm whether one contract should
    # be selected before the arrays are built.

    # Drop rows without a PCR (no match in the left merge) or without a Close:
    # a NaN in either would propagate into the accumulated PnL.
    df_merged = df_merged.dropna(subset=['PCR', 'Close']).reset_index(drop=True)
    n_rows = len(df_merged)
    if n_rows == 0:
        logger.warning(
            "No rows with both PCR and Close for expiry %s; skipping backtest.",
            first_expiry,
        )
        return

    # --------------------------------------------------------
    # D) Prepare NumPy arrays for GPU
    # --------------------------------------------------------
    pcr_np = df_merged['PCR'].values.astype(np.float32)
    close_np = df_merged['Close'].values.astype(np.float32)

    # Copy to GPU
    pcr_gpu = cuda.to_device(pcr_np)
    close_gpu = cuda.to_device(close_np)

    # We'll store the result in a small array of length 1 (for one variant)
    # If you had multiple variants, you'd create a larger array: length = #variants
    pnl_gpu = cuda.device_array(1, dtype=np.float32)

    # --------------------------------------------------------
    # E) Launch the GPU kernel
    # --------------------------------------------------------
    # We use 1 block, 1 thread for demonstration (since we only do 1 variant)
    threads_per_block = 1
    blocks = 1

    logger.info("Launching GPU kernel for single-expiry demonstration...")

    backtest_single_expiry_kernel[blocks, threads_per_block](
        pcr_gpu, close_gpu,
        # JSON thresholds
        BULLISH_ENTRY_PCR, HEAVILY_BULLISH_PCR,
        BEARISH_ENTRY_PCR, HEAVILY_BEARISH_PCR,
        # Output
        pnl_gpu,
    )

    # Synchronize: kernel launches are asynchronous, so wait for completion
    # before reading the result back.
    cuda.synchronize()

    # --------------------------------------------------------
    # F) Retrieve the result
    # --------------------------------------------------------
    pnl_result = pnl_gpu.copy_to_host()[0]
    logger.info("Backtest completed. Final PnL: %.2f", pnl_result)

    # --------------------------------------------------------
    # G) Next Steps
    # --------------------------------------------------------
    # - If this works for the first expiry, you can loop over all expiries in expiry_df.
    # - If you have multiple variants (param sets), you can create bigger arrays for
    #   your thresholds, launch more threads, and do "thread_id = cuda.grid(1)"
    #   indexing inside the kernel.
    # - Ensure your memory usage and performance are tested for large data.


# Run the backtest only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()