"""Loads and processes raw data from spectrometer and power meter and returns and saves scaled spectral distributions used for calibration of Pysilsub."""
# modules
import os
import pathlib as pl
import re
from datetime import date
from warnings import warn

import numpy as np
import pandas as pd

# functions
def get_led_files(in_dir: str | pl.Path):
    """Get files from directory"""
    in_dir = pl.Path(in_dir)
    spectra_files = list(in_dir.glob("*.csv"))
    filename_pattern = r".*[RL]\d.*\.csv$"

    file_list = [f for f in spectra_files if re.match(filename_pattern, f.name)]
    return file_list

def load_thorlabs_power(filepath: pl.Path | str) -> pd.DataFrame:
    """Load file of thorlabs power measurements of one LED at multiple
    intensities, measured in W.

        filepath (pl.Path | str): filepath to CSV

        pd.DataFrame: returns raw data of power measurements at mulitpl
    df = pd.read_csv(filepath, sep=",", skiprows=14)
    return df

def load_jeti_spectra(filepath: pl.Path | str) -> pd.DataFrame:
    """Load spectra measured bei Jeti from one CSV file; spectral radiance
    is measured in W/(sr*sqm*nm)

        filepath (str): filepath to CSV file with spectra of multiple measurements

        pd.DataFrame: raw data without meta data
    column_names = pd.read_csv(filepath, sep=";", skiprows=9).columns
    data = pd.read_csv(filepath, sep=";", skiprows=18, names=column_names)

    # change dtype of spectra columns
    data = data.set_index("Name")
    data = data.apply(lambda x: x.str.replace(",", "."))
    data = data.apply(pd.to_numeric)
    return data

def calc_spectra_peaks(data: pd.DataFrame) -> pd.Series:
    """Identify the peak wavelength of multiple spectra in a dataframe, per column.

    Args:
        data (pd.DataFrame): One spectrum per column; the index labels are
            presumably the wavelengths.

    Returns:
        pd.Series: For each column, the index label at which that column is
        largest (first occurrence on ties; NaN values are skipped).
    """
    return data.idxmax(axis=0)

def normalize_spectra(data: pd.DataFrame) -> pd.DataFrame:
    """Normalize each spectrum (column) by its sum, so each column sums to 1.

    Args:
        data (pd.DataFrame): Dataset with one spectrum per column, e.g. at
            different light ratios.

    Returns:
        pd.DataFrame: Dataset with normalized spectra. A column whose sum is
        zero yields NaN/inf values, because the division is not guarded.
    """
    column_sums = data.sum()
    # DataFrame.div with a Series aligns on column labels, so each column is
    # divided by its own sum.
    return data.div(column_sums)

def get_dict_of_spectra(spectra_dir: pl.Path | str) -> dict:
    """Load and normalize the spectra of every LED file in *spectra_dir*.

    Each file found by ``get_led_files`` is loaded with ``load_jeti_spectra``
    and normalized with ``normalize_spectra``. Columns are then renamed to the
    light ratio given by the last three digits of the original column name,
    as a float (a name ending in ``050`` becomes ``50.0``).

    Args:
        spectra_dir (pl.Path | str): Directory containing the spectra files.

    Returns:
        dict: Maps each file stem to its normalized spectra DataFrame, whose
        columns are the light ratios. Keys are inserted in sorted file-path
        order.

    Raises:
        ValueError: If a column name does not end with three digits.
    """
    light_ratio_pattern = re.compile(r"(\d{3})$")
    spectra_dict = {}
    file_list = get_led_files(pl.Path(spectra_dir))
    for filepath in sorted(file_list):
        # process data
        df = load_jeti_spectra(filepath)
        # TODO evaluate peaks, e.g. with calc_spectra_peaks(df)
        normdf = normalize_spectra(df)
        light_ratios = []
        for col in normdf.columns:
            match = light_ratio_pattern.search(str(col))
            if match is None:
                raise ValueError(
                    f"Column {col!r} in {filepath.name} does not end with a "
                    "three-digit light ratio"
                )
            light_ratios.append(float(match.group(1)))
        normdf.columns = light_ratios
        # collect in dict, keyed by file stem
        spectra_dict[filepath.stem] = normdf
    return spectra_dict

def get_dict_of_power_measurements(power_dir: pl.Path | str) -> dict:
    """Load the power measurements of every LED file in *power_dir*.

    From each file (loaded with ``load_thorlabs_power``), every other row is
    kept, starting with the second row (index 1). The kept rows are labelled
    with the input light ratios 0, 1, 2, 5, 10, ..., 100, in that order.

    Args:
        power_dir (pl.Path | str): Directory containing the power files.

    Returns:
        dict: Maps each file stem to a DataFrame with the columns ``LR``
        (light ratio, float) and ``Power (W)``. Keys are inserted in sorted
        file-path order.

    Raises:
        ValueError: If a file does not yield exactly one kept row per light
            ratio.
    """
    power_dict = {}
    # input light ratios: 0, 1, 2, then 5 to 100 in steps of 5 (23 values)
    light_ratios = np.insert(np.arange(5, 101, 5), 0, [0, 1, 2]).astype(float)
    file_list = get_led_files(pl.Path(power_dir))
    for filepath in sorted(file_list):
        df = load_thorlabs_power(filepath)
        # remove every other row (due to measurement logic)
        warn(
            "Extracting every other row from power measurement file. Starting with second measurement (index 1)"
        )
        df = df.iloc[1::2].reset_index()
        if len(df) != len(light_ratios):
            raise ValueError(
                f"{filepath.name}: expected {len(light_ratios)} power "
                f"measurements after keeping every other row, got {len(df)}"
            )
        df["LR"] = light_ratios

        # collect in dict, keyed by file stem
        power_dict[filepath.stem] = df[["LR", "Power (W)"]]
    return power_dict

def scale_spectra(spectra_dir: pl.Path | str, power_dir: pl.Path | str) -> dict:
    """Scale spectra of Jeti according to power from Thorlabs powermeter.

    Power measurements and spectra are matched by file stem. Each normalized
    spectrum is multiplied by the power measured at the same light ratio.

    Args:
        spectra_dir (pl.Path | str): Directory of spectra of different LEDs;
            each file contains spectra at different intensities.
        power_dir (pl.Path | str): Directory of power measurements of
            different LEDs; each file has power measurements at different
            intensities.

    Returns:
        dict: Maps each file stem to a DataFrame with one row per light
        ratio. The first column, ``Setting``, holds the light ratio; it is
        followed by one column per index label of the spectra file's ``Name``
        column (presumably wavelengths).

    Raises:
        KeyError: If a power file has no spectra file with the same stem.
        ValueError: If the light ratios of a spectra file and its power file
            differ.
    """
    # get jeti spectra and normalize
    spectra_dict = get_dict_of_spectra(pl.Path(spectra_dir))

    # powermeter measurements
    power_dict = get_dict_of_power_measurements(pl.Path(power_dir))

    scaled_dict = {}
    for key, power_df in power_dict.items():
        try:
            spectra_df = spectra_dict[key]
        except KeyError:
            raise KeyError(f"No spectra file matching power file {key!r}") from None

        power = power_df.set_index("LR")["Power (W)"]
        if sorted(spectra_df.columns) != sorted(power.index):
            raise ValueError(
                f"{key}: light ratios of spectra {list(spectra_df.columns)} "
                f"do not match power measurements {list(power.index)}"
            )
        # Match power to spectra by light-ratio label rather than by position,
        # so a different column order in the spectra file cannot mis-scale.
        scaled_df = spectra_df.mul(power.reindex(spectra_df.columns).to_numpy()).T
        scaled_df.insert(loc=0, column="Setting", value=scaled_df.index)
        scaled_dict[key] = scaled_df
    return scaled_dict

def depr_reorganize_spectra_per_light_source(spectra_dict: dict) -> dict:
    """Merge spectra of multiple LEDs (e.g. left vs. right, or A vs. B vs. C)
    into one dataframe for each light source and return a dictionary of
    dataframes for light sources.

    A key's light source is its first character, e.g. ``"L1"`` and ``"L2"``
    both belong to source ``"L"``.

    Args:
        spectra_dict (dict): Maps primary/LED keys to spectra DataFrames.

    Returns:
        dict: Maps each light source to the row-wise concatenation of its
        primaries' DataFrames, in ``spectra_dict`` iteration order. Sources
        appear in order of first occurrence.
    """
    # Group on the first character only. A substring test would also pick up
    # keys that merely contain the letter elsewhere (e.g. "R2_L" for "L").
    grouped: dict[str, list[pd.DataFrame]] = {}
    for prim, df in spectra_dict.items():
        grouped.setdefault(prim[0], []).append(df)
    # Concatenate once per source instead of repeatedly growing a DataFrame.
    return {sourc: pd.concat(frames) for sourc, frames in grouped.items()}

def get_calibration_data(
    spectra_dir: pl.Path, power_dir: pl.Path, out_dir: pl.Path
) -> dict:
    """Provide data from Jeti and Thorlabs powermeters to get one data frame
    per LED file with scaled spectra based on power, save each one as CSV,
    and return them.

    Args:
        spectra_dir (pl.Path): Directory with the Jeti spectra files.
        power_dir (pl.Path): Directory with the Thorlabs power files.
        out_dir (pl.Path): Existing directory where each scaled spectrum is
            written as ``<key>.csv`` (without the DataFrame index). A ``str``
            path is accepted too.

    Returns:
        dict: The scaled spectra as returned by ``scale_spectra``, keyed by
        file stem.
    """
    out_dir = pl.Path(out_dir)
    # scaled spectra for all LEDs from different light sources
    scaled_spectra_dict = scale_spectra(spectra_dir, power_dir)

    # save one file of scaled spectra per LED
    for key, scaled_spectrum in scaled_spectra_dict.items():
        scaled_spectrum.to_csv(out_dir / f"{key}.csv", index=False)

    return scaled_spectra_dict

def scaled_spectra_to_calibration_csv(
    out_dir: pl.Path, scaled_dir: pl.Path, prefix_id: list = ["L", "R"]
    """Reorganize scaled spectra into calibration csv for pysilsub.

        out_dir (pl.Path): to save
        scaled_dir (pl.Path): input directory with csv with each one spectrum
        prefix_id (list, optional): _description_. Defaults to ["L", "R"].
    for _, source_flag in enumerate(prefix_id):
        prim_pathlist = [
            scaled_dir / prim
            for prim in os.listdir(scaled_dir)
            if prim.startswith(source_flag)

        big_df = pd.DataFrame()
        for pidx, prim_path in enumerate(prim_pathlist):
            primdf = pd.read_csv(prim_path)
            primdf.insert(loc=0, column="Primary", value=pidx)
            big_df = pd.concat([big_df, primdf])

        big_df = big_df.reset_index(drop=True)
        out_path = out_dir / f"calibration_{source_flag}.csv"
        big_df.to_csv(out_path, index=False)
        print(f"Saved to {out_path}")

def plot_spectra(big_df):
    """Plot every spectrum (row) of a calibration table as a separate line.

    Args:
        big_df (pd.DataFrame): Table with one spectrum per row. The first two
            columns are skipped; presumably these are the ``Primary`` and
            ``Setting`` columns of a calibration table — confirm for other
            inputs. The remaining column labels are used as x-tick labels.
    """
    from matplotlib import pyplot as plt

    df = big_df.T
    # Rows of the transposed table after the first two: one per data column
    # of big_df. Each column is one spectrum (one row of big_df).
    spectra = df.iloc[2:, :]
    x_positions = np.arange(len(spectra.index))

    # Create a line plot
    plt.figure(figsize=(10, 6))  # Adjust the figure size as needed

    # Plot each spectrum as a separate line (positional iloc so duplicate
    # labels cannot return a DataFrame)
    for i, column in enumerate(spectra.columns):
        plt.plot(x_positions, spectra.iloc[:, i].to_numpy(dtype=float), label=f"{column}")

    # Customize the plot
    plt.ylabel("Scaled spectrum")

    # Show every tenth tick, labelled with the original column labels
    plt.xticks(x_positions[::10], spectra.index[::10], rotation=45)

    # Show the plot
    plt.show()
if __name__ == "__main__":
    # input
    data_dir = pl.Path("data")
    spectra_dir = data_dir / "calibration_spectra_jeti_0.0OD" / "20230725"
    power_dir = data_dir / "calibration_power_thorlabs_0.0OD" / "20230802"
    # output
    today_str = date.today().strftime("%Y%m%d")
    scaled_dir = data_dir / "calibration_scaled_spectra_0.0OD" / today_str
    out_dir = data_dir / "pysilsub_calibration" / today_str
    scaled_dir.mkdir(parents=True, exist_ok=True)
    out_dir.mkdir(parents=True, exist_ok=True)

    # scale spectra for all LEDs and save one CSV per LED to scaled_dir
    get_calibration_data(spectra_dir, power_dir, scaled_dir)

    # reorganize scaled spectra into calibration csv for pysilsub
    scaled_spectra_to_calibration_csv(
        out_dir=out_dir, scaled_dir=scaled_dir, prefix_id=["L", "R"]
    )