import logging
import pandas as pd
import numpy as np
import glob
import re
from datetime import datetime
from numba import cuda, int32, float32

# -------------------------------
# 1) Configure Logging
# -------------------------------
logger = logging.getLogger("GPUBacktest")
# The script reports its progress through logger.info(), so the logger must
# be enabled at INFO; the default effective level (WARNING) would hide it all.
logger.setLevel(logging.INFO)

ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(
    logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
)

# Attach the console handler exactly once so that re-importing this module
# does not produce duplicated log lines.
if not logger.handlers:
    logger.addHandler(ch)

# -------------------------------
# 2) Hardcoded JSON-Like Conditions
# -------------------------------
# Entry thresholds
# "Bullish": PCR above this value opens a PE position (position_side = +1).
BULLISH_ENTRY_PCR = 0.7
# "Bearish": PCR below this value opens a CE position (position_side = -1).
BEARISH_ENTRY_PCR = 1.1875

# Exit thresholds
# "Heavily Bullish": PCR above this value closes an open CE position.
HEAVILY_BULLISH_PCR = 1.2375
# "Heavily Bearish": PCR below this value closes an open PE position.
HEAVILY_BEARISH_PCR = 0.65

# -------------------------------
# 3) GPU Kernel: One Thread Per Expiry
# -------------------------------
@cuda.jit
def backtest_multiple_expiries_kernel(
    pcr_array, close_array,  # flat arrays holding every expiry back to back
    expiry_offsets,  # shape (num_expiries, 2): [ [start_index, length], ... ]
    bull_entry, heavy_bull_exit,
    bear_entry, heavy_bear_exit,
    result_pnl  # shape (num_expiries,)
):
    """
    Backtest the PCR threshold strategy, one GPU thread per expiry.

    Each thread (thread_id = e) processes one expiry's slice of data.
    pcr_array, close_array: float32 arrays combining all expiries.
    expiry_offsets[e, 0] = start index for expiry e
    expiry_offsets[e, 1] = length (number of bars) for expiry e
    We do a time-sequential loop for that expiry (PCR/Close data).
    The final PnL is stored in result_pnl[e].

    Parameters:
        pcr_array:       PCR value per bar, all expiries concatenated.
        close_array:     close price per bar, same layout as pcr_array.
        expiry_offsets:  (start_index, length) pair per expiry.
        bull_entry:      open a PE (+1) position when PCR > bull_entry.
        heavy_bull_exit: close a CE (-1) position when PCR > heavy_bull_exit.
        bear_entry:      open a CE (-1) position when PCR < bear_entry.
        heavy_bear_exit: close a PE (+1) position when PCR < heavy_bear_exit.
        result_pnl:      output array; element e receives expiry e's total PnL.
    """
    e = cuda.grid(1)  # expiry_id
    # The launch grid is rounded up to whole blocks, so surplus threads
    # must leave before touching any array.
    if e >= result_pnl.size:
        return

    start_idx = expiry_offsets[e, 0]
    length    = expiry_offsets[e, 1]

    in_position = False
    position_side = 0.0  # +1 => PE, -1 => CE
    entry_price = 0.0
    total_pnl = 0.0

    # Time-sequential loop over this expiry's slice
    for i in range(length):
        idx = start_idx + i
        pcr_val   = pcr_array[idx]
        close_val = close_array[idx]

        # If in position, check exit conditions
        if in_position:
            # PE => exit if pcr_val < heavy_bear_exit
            if position_side > 0 and pcr_val < heavy_bear_exit:
                current_pnl = (close_val - entry_price) * position_side
                total_pnl += current_pnl
                in_position = False

            # CE => exit if pcr_val > heavy_bull_exit
            elif position_side < 0 and pcr_val > heavy_bull_exit:
                current_pnl = (close_val - entry_price) * position_side
                total_pnl += current_pnl
                in_position = False

        # If not in position, check entry signals. This is a separate `if`
        # (not `else`), so a bar that closes a position may immediately
        # open a new one at the same close price.
        if not in_position:
            # Bullish => PCR > bull_entry => open PE.
            # NOTE(review): the bullish test runs first, so when the two entry
            # ranges overlap the bearish branch is only reached for
            # PCR <= bull_entry - confirm this priority is intended.
            if pcr_val > bull_entry:
                in_position   = True
                position_side = 1.0
                entry_price   = close_val
            # Bearish => PCR < bear_entry => open CE
            elif pcr_val < bear_entry:
                in_position   = True
                position_side = -1.0
                entry_price   = close_val

    # End of loop, close if still in position (mark to the last bar's close)
    if in_position and length > 0:
        last_close = close_array[start_idx + length - 1]
        final_pnl = (last_close - entry_price) * position_side
        total_pnl += final_pnl

    result_pnl[e] = total_pnl

def main():
    """
    Load PCR and option-chain data, run the GPU backtest, and log the results.

    Steps:
      A) read the master PCR table,
      B) for each option-chain file, keep the rows of the expiry named in the
         file name and merge them with PCR on [Date, Time],
      C) concatenate all expiries into flat float32 arrays plus an offsets table,
      D/E) copy the arrays to the GPU and launch one thread per expiry,
      F/G) copy the per-expiry PnL back and log it.

    Files whose name cannot be parsed, or that yield no usable rows, are
    skipped with a warning. Returns None; exits early if nothing is usable.
    """
    # ---------------------------------------------------
    # A) Read master data: PCR
    # ---------------------------------------------------
    logger.info("Reading PCR data...")
    # NOTE: pickle files can execute arbitrary code when loaded; only point
    # this script at data files from a trusted source.
    pcr_df = pd.read_pickle("data/MAIN_NIFTY50_PCR.pkl")  # columns: [Date, Time, PCR]

    # For demonstration, we won't use expiry_df here. We'll rely on file-naming for expiries.
    # But if you do have an expiry_df, you can cross-check or store statuses, etc.

    # ---------------------------------------------------
    # B) Gather OptionChain files for multiple expiries
    # ---------------------------------------------------
    pattern = "data/MAIN_NIFTY50_OPTIONS_*_OptionChain.pkl"
    files = glob.glob(pattern)
    logger.info(f"Found {len(files)} option chain files: {files}")

    # We want to parse the date from each filename: "YYYY_MM_DD"
    # We assume it ends with "..._YYYY_MM_DD_OptionChain.pkl".
    # The pattern is anchored on the trailing "_YYYY_MM_DD_OptionChain.pkl"
    # only: requiring "NIFTY50_" directly before the date would reject the
    # "..._NIFTY50_OPTIONS_<date>_OptionChain.pkl" names the glob above finds.
    date_pattern = re.compile(r"_(\d{4})_(\d{2})_(\d{2})_OptionChain\.pkl$")

    # Total length is unknown up front, so per-expiry arrays are collected in
    # lists and concatenated once at the end.
    pcr_segments = []
    close_segments = []
    expiry_offsets = []  # (start_index, length) per expiry, in kernel order
    expiry_labels = []  # track the string date for reporting

    current_start_idx = 0  # cumulative offset

    # Process each file
    for file_path in sorted(files):
        match = date_pattern.search(file_path)
        if not match:
            logger.warning(f"Skipping file (date parse failed): {file_path}")
            continue

        yyyy, mm, dd = match.groups()
        expiry_str = f"{yyyy}-{mm}-{dd}"  # e.g. "2021-01-07"
        logger.info(f"Processing {file_path}, parsed expiry date: {expiry_str}")

        # 1) Load option data
        options_df = pd.read_pickle(file_path)

        # 2) Convert the expiry date to "DD-MM-YYYY" if your data uses that format
        dt = datetime.strptime(expiry_str, "%Y-%m-%d")
        expiry_ddmmyyyy = dt.strftime("%d-%m-%Y")

        # 3) Filter DataFrame for just that expiry
        df_exp = options_df[options_df['ExpiryDate'] == expiry_ddmmyyyy].copy()
        if df_exp.empty:
            logger.warning(f"No matching rows for expiry {expiry_ddmmyyyy}, skipping.")
            continue
        # NOTE(review): assumes sorting by the raw Date/Time columns yields
        # chronological order - confirm against the column dtypes.
        df_exp.sort_values(['Date','Time'], inplace=True)

        # 4) Merge with PCR on [Date, Time]; bars without a PCR value are dropped
        df_merged = pd.merge(df_exp, pcr_df, on=['Date','Time'], how='left')
        df_merged.sort_values(['Date','Time'], inplace=True)
        df_merged.dropna(subset=['PCR'], inplace=True)

        if df_merged.empty:
            logger.warning(f"After merging with PCR, no data remains for {expiry_str}, skipping.")
            continue

        pcr_vals = df_merged['PCR'].values.astype(np.float32)
        close_vals = df_merged['Close'].values.astype(np.float32)
        segment_len = len(pcr_vals)

        # 5) Store segment arrays
        pcr_segments.append(pcr_vals)
        close_segments.append(close_vals)

        # Build offset info; labels stay index-aligned with the offsets so
        # result i can be reported against expiry i.
        expiry_offsets.append((current_start_idx, segment_len))
        expiry_labels.append(expiry_str)

        current_start_idx += segment_len

    # If no valid data found at all, bail out (np.concatenate would raise
    # on an empty list)
    num_expiries = len(expiry_offsets)
    if num_expiries == 0:
        logger.info("No valid expiry data found. Exiting.")
        return

    # ---------------------------------------------------
    # C) Combine all data into big arrays
    # ---------------------------------------------------
    logger.info("Combining data for all expiries into single arrays...")

    combined_pcr = np.concatenate(pcr_segments)     # shape (sum_of_lengths,)
    combined_close = np.concatenate(close_segments) # same shape
    offsets_np = np.array(expiry_offsets, dtype=np.int32)  # shape (num_expiries, 2)

    # ---------------------------------------------------
    # D) Copy to GPU
    # ---------------------------------------------------
    logger.info(f"Total bars across all {num_expiries} expiries: {combined_pcr.shape[0]}")
    pcr_gpu = cuda.to_device(combined_pcr)
    close_gpu = cuda.to_device(combined_close)
    offsets_gpu = cuda.to_device(offsets_np)

    # We'll have 1 thread per expiry => result array of length num_expiries
    result_gpu = cuda.device_array(num_expiries, dtype=np.float32)

    # ---------------------------------------------------
    # E) Launch the kernel: one thread per expiry
    # ---------------------------------------------------
    threads_per_block = 128
    # Ceiling division so every expiry gets a thread
    blocks = (num_expiries + threads_per_block - 1) // threads_per_block

    logger.info(f"Launching kernel with {blocks} blocks, {threads_per_block} threads/block, for {num_expiries} expiries.")

    backtest_multiple_expiries_kernel[blocks, threads_per_block](
        pcr_gpu, close_gpu, offsets_gpu,
        BULLISH_ENTRY_PCR, HEAVILY_BULLISH_PCR,
        BEARISH_ENTRY_PCR, HEAVILY_BEARISH_PCR,
        result_gpu,
    )

    # ---------------------------------------------------
    # F) Retrieve results
    # ---------------------------------------------------
    result_cpu = result_gpu.copy_to_host()  # shape (num_expiries,)

    # ---------------------------------------------------
    # G) Reporting
    # ---------------------------------------------------
    logger.info("Backtest results for each expiry:")
    for exp_date, pnl in zip(expiry_labels, result_cpu):
        logger.info(f"  {exp_date} => PnL: {pnl:.2f}")

# Run the backtest only when executed as a script, not when imported.
if __name__ == "__main__":
    main()