"""
Created by: Matthew Todd

Script to project the IMC WFS and MC2 Trans signals, using couplings measured
during various injections, onto the ISS OUTER RIN witness in order to estimate
their ambient noise contribution.

Part of building an IMC noise budget.
"""
# %%
from gwpy.timeseries import TimeSeriesDict, TimeSeries
import numpy as np 
import matplotlib.pyplot as plt
from dttxml import dtt_read
import os, ipdb
import scipy.constants as scc

# Figure styling shared by every plot in this script: size-24 text for axis
# labels, titles and tick labels, a smaller legend font, and the default line
# width / marker size used for all curves.
plt.rcParams.update({
    **dict.fromkeys(
        (
            "axes.labelsize",
            "axes.titlesize",
            "xtick.labelsize",
            "ytick.labelsize",
        ),
        24,
    ),
    "legend.fontsize": 14,
    "legend.markerscale": 1.0,
    "lines.linewidth": 2,
    "lines.markersize": 3,
})
# %%

###  Create dictionary with all the asds  ###

# GPS time of the quiet-time data the reference spectra below belong to.
quiet_gpstime = 1446823543

# Maps a curve name to everything needed to draw it: frequency vector, ASD,
# legend label and line styling.
data_dict = {}

# Quiet-time RIN spectrum of the ISS outer sensor.  The text file holds two
# space-separated columns (frequency, ASD); lines starting with "#" are skipped.
f, asd = np.loadtxt(
    os.path.abspath("dtt_injections/quiet_data_iss_outer_rin.txt"),
    comments="#",
    delimiter=" ",
    unpack=True,
)
data_dict["iss_outer_rin"] = dict(
    freq=f,
    asd=asd,
    label="ISS_OUTER_RIN",
    color="C0",
    linestyle="-",
)

# Same measurement for the ISS inner sensor.
f, asd = np.loadtxt(
    os.path.abspath("dtt_injections/quiet_data_iss_inner_rin.txt"),
    comments="#",
    delimiter=" ",
    unpack=True,
)
data_dict["iss_inner_rin"] = dict(
    freq=f,
    asd=asd,
    label="ISS_INNER_RIN",
    color="C1",
    linestyle="-",
)

# Directory holding the DTT (.xml) injection measurements.
dirname = os.path.abspath("./dtt_injections")

# One record per injection measurement, keyed by its position in the list:
#   filename            DTT xml file inside ``dirname``
#   calibrated_channel  channel whose PSD/CSD results are used to compute the
#                       coupling to the ISS outer RIN
#   label               name of the projected curve; records that share a
#                       label end up in the same curve
#   color / linestyle / marker
#                       plot styling of the curve
#   Xref / Yref         indices into the references stored in the DTT file
#   bandlimits          optional (low, high) frequency band in Hz; without it
#                       the band is taken from the "<low>_<high>_Hz" suffix of
#                       the file name
file_dict = dict(enumerate([
    dict(
        filename="pzt_pit_injection_1_15_Hz.xml",
        calibrated_channel="H1:IMC-WFS_A_DC_PIT_OUT_DQ",
        label="IMC_PZT_PIT",
        color="C2",
        linestyle="None",
        marker="o",
        Xref=3,
        Yref=4,
    ),
    dict(
        filename="pzt_pit_injection_15_1000_Hz.xml",
        calibrated_channel="H1:IMC-WFS_A_DC_PIT_OUT_DQ",
        label="IMC_PZT_PIT",
        color="C3",
        linestyle="None",
        marker="o",
        Xref=3,
        Yref=4,
    ),
    # NOTE(review): both yaw PZT injections read back a PIT channel --
    # confirm this is intended and not a copy/paste leftover.
    dict(
        filename="pzt_yaw_injection_1_15_Hz.xml",
        calibrated_channel="H1:IMC-WFS_A_DC_PIT_OUT_DQ",
        label="IMC_PZT_YAW",
        color="C4",
        linestyle="None",
        marker="s",
        Xref=3,
        Yref=4,
    ),
    dict(
        filename="pzt_yaw_injection_15_1000_Hz.xml",
        calibrated_channel="H1:IMC-WFS_A_DC_PIT_OUT_DQ",
        label="IMC_PZT_YAW",
        color="C5",
        linestyle="None",
        marker="s",
        Xref=3,
        Yref=4,
    ),
    dict(
        filename="imc_dof_1_pit_injection_1_4_Hz.xml",
        calibrated_channel="H1:IMC-WFS_B_DC_PIT_OUT_DQ",
        label="IMC_DOF_1_P",
        color="C6",
        linestyle="None",
        marker="o",
        Xref=6,
        Yref=7,
        bandlimits=(0.2, 5),
    ),
    dict(
        filename="imc_dof_1_yaw_injection_1_4_Hz.xml",
        calibrated_channel="H1:IMC-WFS_B_DC_YAW_OUT_DQ",
        label="IMC_DOF_1_Y",
        color="C7",
        linestyle="None",
        marker="s",
        Xref=5,
        Yref=6,
        bandlimits=(0.2, 4),
    ),
    dict(
        filename="imc_dof_2_pit_injection_1_4_Hz.xml",
        calibrated_channel="H1:IMC-MC2_TRANS_PIT_OUT_DQ",
        label="IMC_DOF_2_P",
        color="C8",
        linestyle="None",
        marker="o",
        Xref=4,
        Yref=7,
        bandlimits=(0.2, 4),
    ),
    dict(
        filename="imc_dof_2_yaw_injection_1_4_Hz.xml",
        calibrated_channel="H1:IMC-WFS_A_DC_YAW_OUT_DQ",
        label="IMC_DOF_2_Y",
        color="C9",
        linestyle="None",
        marker="s",
        Xref=4,
        Yref=7,
        bandlimits=(0.2, 4),
    ),
]))

# Project every injection measurement onto the ISS outer RIN and collect the
# result in ``data_dict`` (one curve per label).
for key, info in file_dict.items():
    data = dtt_read(os.path.join(dirname, info["filename"]))

    # Frequency band to keep: the explicit override when one is given,
    # otherwise the "<low>_<high>_Hz" suffix of the file name.
    if "bandlimits" in info:
        bandlimits = info["bandlimits"]
    else:
        name = os.path.splitext(info["filename"])[0]
        bandlimits = [float(limit) for limit in name.split("_")[-3:-1]]

    # Saved reference spectrum that gets scaled by the measured coupling.
    x_reference = data.references[info["Xref"]]
    FHz = x_reference.FHz
    X_PSD = x_reference.PSD[0]

    # Keep one extra bin on either side of the band.  The start is clamped at
    # 0: without the clamp a lower limit at or below the first frequency bin
    # yields index -1, which slicing reads as "last element" and silently
    # empties (or truncates) the selected band.
    first, last = np.searchsorted(FHz, bandlimits)
    band = slice(max(int(first) - 1, 0), min(int(last) + 1, len(FHz)))

    # Coupling magnitude |CSD_xy / PSD_xx| from the injected channel to the
    # ISS outer RIN.  PSD[0] is squared before use, presumably because dttxml
    # stores it as an amplitude spectrum -- TODO confirm.
    channel = info["calibrated_channel"]
    csd_result = data.results.CSD[channel]
    CSDxy_index = csd_result.channelB_inv['H1:PSL-ISS_SECONDLOOP_RIN_OUTER_OUT_DQ']
    CSDxy = csd_result.CSD[CSDxy_index]
    CSDxx = data.results.PSD[channel].PSD[0]**2
    TF = np.abs(CSDxy/CSDxx)

    # Assumes the reference and the CSD/PSD results share one frequency
    # vector (the element-wise product below requires equal lengths).
    calibrated_x_asd = TF*X_PSD

    label = info["label"]
    band_freq = FHz[band]
    band_asd = calibrated_x_asd[band]

    if label in data_dict:
        # A second file for the same curve: append its band to the existing
        # data.  The styling of the first file is kept.
        curve = data_dict[label]
        curve["freq"] = np.concatenate((curve["freq"], band_freq))
        curve["asd"] = np.concatenate((curve["asd"], band_asd))
    else:
        data_dict[label] = {
            "freq": band_freq,
            "asd": band_asd,
            "label": label,
            "color": info.get("color", f"C{key+2}"),
            "linestyle": info.get("linestyle", "-"),
            "marker": info.get("marker", None),
        }

# %%

### Plot all the curves ###
fig, axs = plt.subplots(figsize=(8, 8), tight_layout=True)

# Plot projections and RIN
for value in data_dict.values():
    axs.loglog(
        value["freq"],
        value["asd"],
        label=value["label"],
        color=value["color"],
        ls=value["linestyle"],
        marker=value.get("marker"),
    )

# Plot shot noise limit of outer
# Median PD sum over 360 s starting at the quiet time.  The 1e-3 factor
# presumably converts the channel's mW into W -- TODO confirm calibration.
dc_power = TimeSeries.fetch(
    channel="H1:PSL-ISS_SECONDLOOP_PDSUMOUTER_OUT_DQ",
    start=quiet_gpstime,
    end=quiet_gpstime+360,
    ).median().value*1e-3
lambda0 = 1064e-9
# Shot-noise-limited RIN, sqrt(2*h*nu / P), drawn as a flat line across the
# frequency range of the outer RIN curve.
outer_freq = data_dict["iss_outer_rin"]['freq']
shot_noise_limit = np.sqrt(2*scc.h*(scc.c/lambda0) / dc_power)
shot_noise_limit = np.full_like(outer_freq, shot_noise_limit)
axs.loglog(
    outer_freq,
    shot_noise_limit,
    color="black",
    ls="--"
)
axs.text(
    x=0.3,
    y=shot_noise_limit[0]*0.99,
    s='Shot Noise Limit',
    color='k',
    verticalalignment='top',  # top edge of the text at y, so it hangs just below the line
    horizontalalignment='center',  # centers the text at the x position
    fontsize=12,
)


axs.set_xlabel("Frequency [Hz]")
# Every curve here (measured RIN, projections, shot-noise limit) is a
# relative intensity noise, which is dimensionless -- hence 1/rtHz, not V/rtHz.
axs.set_ylabel("RIN ASD [1/rtHz]")
axs.grid(True, 'major', 'both', alpha=0.5)
axs.grid(True, 'minor', 'both', alpha=0.2)
axs.set_xlim(1e-1, 5000)
axs.legend(loc="upper right")
axs.set_title("ISS Noise Budget for LHO O4")
fig.savefig("iss_noise_budget_LHO_O4.pdf")
fig.savefig("iss_noise_budget_LHO_O4.png")
# %%
