import numpy as np
import matplotlib.pyplot as plt
import os, termcolor, matplotlib, configparser, codecs
from pydarm.darm import DARMModel
from gwpy.timeseries import *
from scipy import signal

# Default plot styling for every figure produced by this script.
matplotlib.rcParams.update({
    'font.size': 14,
    'figure.figsize': (14, 9),
})

def _stitch_spectra(segments):
    """Concatenate ``(freq, value)`` segments into a single spectrum.

    Segments are consumed in the order given. The first non-empty segment is
    kept whole; each later segment contributes only the points whose frequency
    is strictly above the highest frequency collected so far.

    Parameters
    ----------
    segments : iterable of (array_like, array_like)
        ``(freq, value)`` pairs of equal length.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Stitched frequency and value arrays (both empty if nothing was given).
    """
    freq = np.array([])
    vals = np.array([])
    for seg_freq, seg_vals in segments:
        seg_freq = np.asarray(seg_freq)
        seg_vals = np.asarray(seg_vals)
        if freq.size == 0:
            freq, vals = seg_freq, seg_vals
        else:
            keep = seg_freq > freq.max()
            freq = np.concatenate([freq, seg_freq[keep]])
            vals = np.concatenate([vals, seg_vals[keep]])
    return freq, vals


def get_cheta_rin(data_dir='/ligo/home/matthewrichard.todd/Projects/cheta_rin/darm_to_esd/20240815_CHETA_RIN',
                  gps=1379395092, dt=300,
                  channel='H1:GDS-CALIB_STRAIN_CLEAN'):
    """Load the CHETA RIN measurement and a reference DARM spectrum.

    Steps:
    1. Read the spectrum files in ``data_dir`` whose name prefix (text before
       the first ``'_'``) is one of the spans listed below. Files are decoded as
       cp1252 with 23 header rows skipped.
    2. Average files sharing a prefix in quadrature.
    3. Stitch the spans into one frequency axis.
    4. Interpolate ``CHETA_RIN_PWM.txt`` onto that axis and scale it to
       displacement.
    5. Fetch ``channel`` for ``[gps, gps + dt)`` and compute a median-averaged
       Welch ASD from it.

    Parameters
    ----------
    data_dir : str
        Directory holding the spectrum files and ``CHETA_RIN_PWM.txt``.
    gps, dt : int
        Start GPS time and duration [s] of the strain data to fetch.
    channel : str
        Strain channel whose ASD is returned under ``'NLN_DARM'``.

    Returns
    -------
    dict
        ``'Frequencies'``: stitched frequency axis.
        ``'CHETA_RIN_OPEN_raw'``: column 1 of ``CHETA_RIN_PWM.txt`` divided by
        0.5 and interpolated onto ``'Frequencies'``.
        ``'CHETA_RIN_OPEN_calibrated'``: the above scaled by the model below.
        ``'NLN_DARM'``: 2 x N array ``[freq, 4000 * ASD(channel)]``.
    """
    # Spans to stitch, in order; the name is the file-name prefix.
    kws = ['6.25Hz', '25Hz', '100Hz', '400Hz', '3.2kHz', '25.6kHz']

    freqs = {}
    psd = {}
    for fname in os.listdir(data_dir):
        kw = fname.split('_')[0]
        if kw not in kws:
            # Other files in the directory (e.g. SCRN*.TXT, CHETA_RIN_PWM.txt) are not spectra to stitch.
            continue
        with codecs.open(os.path.join(data_dir, fname), encoding='cp1252') as fh:
            x = np.loadtxt(fh, skiprows=23)
        if kw not in freqs:
            # assumes all files with the same prefix share one frequency axis — TODO confirm
            freqs[kw] = x[:, 0]
            psd[kw] = []
        psd[kw].append(x[:, 1])

    # Average repeated measurements of a span in quadrature (RMS of the ASDs).
    avg = {kw: np.sqrt((np.array(spectra) ** 2).mean(axis=0))
           for kw, spectra in psd.items()}

    fr, _ = _stitch_spectra((freqs[k], avg[k]) for k in kws)

    # Load new RIN measurements.
    R = np.loadtxt(os.path.join(data_dir, 'CHETA_RIN_PWM.txt'))
    # NOTE(review): column 1 is divided by 0.5; the meaning of this factor is not documented here.
    Ropen = np.interp(fr, R[:, 0], R[:, 1] / 0.5)

    data = TimeSeriesDict.get([channel], gps, gps + dt)
    fs = 16384  # assumed sample rate of `channel` [Hz] — TODO confirm
    Np = 5 * fs
    fr_, psd_darm = signal.welch(data[channel].value, nperseg=Np, noverlap=Np // 2,
                                 fs=fs, average='median')
    asd_darm = np.sqrt(psd_darm)

    P = 0.5    # CO2 power
    # NOTE(review): the reference constants of this RIN -> displacement scaling are not documented in this file.
    TRo = np.sqrt(2) * 2 * 1e-20 * (150/fr) * (P/25e-3)*(Ropen/4e-6)

    return {
        'Frequencies': fr,
        'CHETA_RIN_OPEN_raw': Ropen,
        'CHETA_RIN_OPEN_calibrated': TRo,
        # presumably strain -> metres via the 4 km arm length — TODO confirm
        'NLN_DARM': np.array([fr_, 4000 * asd_darm]),
    }

def asd_to_rms(f, asd, reverse=False):
    """Return the cumulative RMS of an amplitude spectral density.

    The RMS is obtained by integrating the power spectral density ``|asd|**2``
    over frequency. Each sample is weighted by its local bin width
    (``|numpy.gradient(f)|``), so non-uniform frequency grids are handled.

    Parameters
    ----------
    f : array_like
        Frequency axis, one value per ``asd`` sample (monotonic).
    asd : array_like
        Amplitude spectral density sampled at ``f`` (units X/rtHz).
    reverse : bool, optional
        If False (default), element ``i`` is the RMS accumulated from ``f[i]``
        up to the last frequency, so element 0 is the total RMS. If True,
        element ``i`` is the RMS accumulated from the first frequency up to
        and including ``f[i]``, so the last element is the total RMS.

    Returns
    -------
    numpy.ndarray
        Cumulative RMS (units X), same length as ``asd``.

    Raises
    ------
    ValueError
        If ``f`` and ``asd`` differ in shape, or only one point is given
        (the bin width cannot be inferred).
    """
    f = np.asarray(f, dtype=float)
    asd = np.asarray(asd)
    if f.shape != asd.shape:
        raise ValueError("f and asd must have the same shape")
    if asd.size == 0:
        return np.zeros_like(asd, dtype=float)
    if asd.size == 1:
        raise ValueError("asd_to_rms needs at least two frequency points to infer bin widths")

    # Power in each bin = PSD * bin width.
    power = np.abs(asd) ** 2 * np.abs(np.gradient(f))
    if reverse:
        return np.sqrt(np.cumsum(power))
    # Accumulate from the high-frequency end downwards.
    return np.sqrt(np.cumsum(power[::-1])[::-1])

def read_tf_txt(filename, comment_char='#',
                delimiter=" ", freq_sep=" ",
                ):
    """Read a text file of transfer-function columns.

    The file starts with header lines beginning with ``comment_char``. These
    are followed by rows of ``freq_sep``-separated numbers: a frequency column
    and then the data columns. Blank lines, and comment lines after the
    header, are ignored.

    NOTE(review): only ``os.path.basename(filename)`` is used; the file is
    always looked up under ``<cwd>/Projects/darm_to_esd``. Kept for
    compatibility with existing callers.

    Parameters
    ----------
    filename : str
        Name of the file (its directory part is ignored, see above).
    comment_char : str
        Prefix marking comment/header lines.
    delimiter : str
        Unused; kept for backward compatibility.
    freq_sep : str
        Column separator. Empty fields (e.g. from repeated separators) are skipped.

    Returns
    -------
    freq : numpy.ndarray, shape (nlines,)
    data : numpy.ndarray, shape (nlines, 2 * ntfs)
        ``ntfs = (ncolumns - 1) // 2`` is taken from the first data row; every
        row must then have exactly ``2 * ntfs`` data values.
    """
    path = os.path.join(os.getcwd(), 'Projects/darm_to_esd',
                        os.path.basename(filename))
    with open(path, 'r') as file:
        lines = [line for line in file if line.strip()]

    data_lines = [line for line in lines if not line.startswith(comment_char)]
    nlines = len(data_lines)
    print(termcolor.colored(f"\nReading from {path} ...", color='green'))
    print(termcolor.colored(f"{nlines} data points found.", color='white'))

    # The header is the run of comment lines before the first data line.
    comments = []
    for line in lines:
        if not line.startswith(comment_char):
            break
        comments.append(line[len(comment_char):].strip())

    print('Comments:')
    # The last header line is intentionally not echoed (original behaviour).
    for each in comments[:-1]:
        print('\t', each)

    def _fields(line):
        """Split a data row into its non-empty, stripped fields."""
        return [tok.strip() for tok in line.split(freq_sep) if tok.strip()]

    ntfs = (len(_fields(data_lines[0])) - 1) // 2 if data_lines else 0
    freq = np.empty(nlines)
    data = np.empty((nlines, ntfs * 2))
    for j, line in enumerate(data_lines):
        fields = _fields(line)
        freq[j] = float(fields[0])
        data[j] = [float(value) for value in fields[1:]]

    return freq, data

def calibrate(freq_data, data, freq_cal, data_cal):
    """Scale ``data`` by a calibration curve linearly interpolated onto ``freq_data``.

    For each ``freq_data[i]`` the calibration value is linearly interpolated
    between the two ``freq_cal`` nodes bracketing it. Points outside the
    calibration range are linearly extrapolated from the first or last
    segment. ``data[i]`` (a scalar or a row) is multiplied by that value.

    Parameters
    ----------
    freq_data : array_like, shape (n,)
        Frequencies at which ``data`` is sampled.
    data : array_like, shape (n, ...)
        Data to calibrate; trailing axes are scaled by the same factor.
    freq_cal : array_like, shape (m,)
        Calibration frequencies, ascending, with ``m >= 2``.
    data_cal : array_like, shape (m,)
        Calibration values (real or complex) at ``freq_cal``.

    Returns
    -------
    numpy.ndarray
        Calibrated data. Its dtype is the result type of the multiplication,
        so integer data is not truncated and complex calibrations are kept.

    Raises
    ------
    ValueError
        If fewer than two calibration points are given.
    """
    freq_data = np.asarray(freq_data, dtype=float)
    data = np.asarray(data)
    freq_cal = np.asarray(freq_cal, dtype=float)
    data_cal = np.asarray(data_cal)
    if freq_cal.size < 2:
        raise ValueError("calibrate needs at least two calibration points")

    # searchsorted puts x in segment [idx-1, idx]; clipping makes points outside
    # the calibration range reuse the first/last segment (linear extrapolation).
    idx = np.clip(np.searchsorted(freq_cal, freq_data), 1, freq_cal.size - 1)
    f0, f1 = freq_cal[idx - 1], freq_cal[idx]
    cal0, cal1 = data_cal[idx - 1], data_cal[idx]
    cal = cal0 + (cal1 - cal0) / (f1 - f0) * (freq_data - f0)

    # Broadcast the per-frequency factor over any trailing axes of `data`.
    cal = cal.reshape(cal.shape + (1,) * max(data.ndim - 1, 0))
    return cal * data

def _save_figure(fig, name):
    """Save *fig* as ``figures/<name>`` next to this script, creating the folder if needed."""
    out_dir = os.path.join(os.path.dirname(__file__), 'figures')
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, name))


def _plot_bode(freq, tf, label, mag_ylabel, title, name):
    """Plot magnitude and phase of the complex response *tf* and save it as ``figures/<name>``."""
    fig, axs = plt.subplots(2, 1, sharex=True, figsize=(12, 9), tight_layout=True)
    axs[0].loglog(freq, np.abs(tf), label=label)
    axs[0].set_ylabel(mag_ylabel)
    axs[0].grid(True, 'major', alpha=.5)
    axs[0].grid(True, 'minor', alpha=.2)
    axs[0].legend()

    axs[1].semilogx(freq, np.angle(tf, deg=True))
    axs[1].set_xlabel('Frequency [Hz]')
    axs[1].set_ylabel('Phase [deg]')
    fig.suptitle(title)
    axs[1].grid(True, 'both')
    axs[1].set_xlim(.1, max(freq))

    _save_figure(fig, name)


def _plot_dac_spectrum(freq, dac_asd, dac_rms, nln_freq, nln_asd, nln_rms,
                       asd_label, rms_label, nln_label, ylabel, title, name):
    """Plot a CHETA DAC-count ASD and RMS against a 2**19-count line and DARM at NLN.

    The RMS legend entry reports ``dac_rms[1]`` as a percentage of 2**19 counts.
    The figure is saved as ``figures/<name>``.
    """
    ratio = 2**-19 * dac_rms[1] * 100
    fig = plt.figure(figsize=(14, 9), tight_layout=True)
    plt.hlines(2**19, freq[0], freq[-1], color='C3', ls='--',
               label="Saturation Level")
    plt.loglog(freq, dac_asd, color='C0', label=asd_label, linewidth=.8)
    plt.loglog(freq, dac_rms, color='C0', ls='--',
               label=f'{rms_label}\n% of Saturation Limit: {ratio:.3f}%',
               linewidth=.8)

    ## add DARM at NLN plot ##
    plt.loglog(nln_freq, nln_asd, color='C1', label=nln_label, linewidth=.8)
    plt.loglog(nln_freq, nln_rms, color='C1', ls='--', linewidth=.8)

    plt.xlabel("Frequency [Hz]")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True, 'major', 'both', alpha=.5)
    plt.grid(True, 'minor', 'both', alpha=.2)
    plt.xlim(10, max(freq))
    print(f"\nPlotting Frequencies: {10} to {max(freq)}")
    plt.ylim(1e-11, 2**20)
    plt.legend(loc='lower left')
    _save_figure(fig, name)


if __name__ == "__main__":

    # reading RIN data for CHETA (calibrated to DELTAL already)
    data = get_cheta_rin()
    # NOTE(review): only the first 855 stitched frequency points are analysed;
    # the reason for this cut is not documented here.
    n_keep = 855
    freq_deltaL = data['Frequencies'][:n_keep]
    rin_deltaL = data['CHETA_RIN_OPEN_calibrated'][:n_keep]
    rin_raw = data['CHETA_RIN_OPEN_raw'][:n_keep]
    NLN_fr, NLN_deltaL = data['NLN_DARM']

    # darm coupled cavity optical gain from pydarm report
    darm = DARMModel(os.path.join('/ligo/groups/cal/H1/reports/20250123T211118Z/', 'pydarm_H1.ini'))
    G = darm.compute_darm_olg(freq_deltaL)
    C = darm.sensing.compute_sensing(freq_deltaL)
    D = darm.digital.compute_response(freq_deltaL)
    L1DA, L2DA, L3DA = darm.actuation.xarm.sus_digital_filters_response(freq_deltaL)

    # CHETA displacement -> H1:CAL-DELTAL_CTRL, shared by all three stages.
    # NOTE(review): the closed-loop factor is written 1/(1 - G); pydarm's response
    # function is commonly R = (1 + G)/C, i.e. suppression 1/(1 + G). Confirm the sign.
    cheta_ctrl = rin_deltaL * np.abs(1 / (1 - G)) * np.abs(C) * np.abs(D)

    # Mapping H1:CAL-DELTAL_CTRL to H1:SUS-ETMX_L3_ESDOUTF_UL, L2_COILOUTF_UL, L1_COILOUTF_UL
    L3_DAC = cheta_ctrl * np.abs(L3DA)
    L3_DAC_rms = asd_to_rms(freq_deltaL, L3_DAC)
    L2_DAC = cheta_ctrl * np.abs(L2DA)
    L2_DAC_rms = asd_to_rms(freq_deltaL, L2_DAC)
    L1_DAC = cheta_ctrl * np.abs(L1DA)
    L1_DAC_rms = asd_to_rms(freq_deltaL, L1_DAC)

    # calibrating DARM_NLN to ESD, L2, L1 (no loop factor applied to this spectrum)
    Cnln = darm.sensing.compute_sensing(NLN_fr)
    Dnln = darm.digital.compute_response(NLN_fr)
    L1DAnln, L2DAnln, L3DAnln = darm.actuation.xarm.sus_digital_filters_response(NLN_fr)
    nln_ctrl = NLN_deltaL * np.abs(Cnln) * np.abs(Dnln)

    L3_DAC_from_nln = nln_ctrl * np.abs(L3DAnln)
    L3_DAC_from_nln_rms = asd_to_rms(NLN_fr, L3_DAC_from_nln)
    L2_DAC_from_nln = nln_ctrl * np.abs(L2DAnln)
    L2_DAC_from_nln_rms = asd_to_rms(NLN_fr, L2_DAC_from_nln)
    L1_DAC_from_nln = nln_ctrl * np.abs(L1DAnln)
    L1_DAC_from_nln_rms = asd_to_rms(NLN_fr, L1_DAC_from_nln)

    ###################################################################################

    _plot_bode(freq_deltaL, G, 'DARM OLG', 'Magnitude',
               "OLG Measurement 2025-01-23 T21:11:18", 'darm_olg.pdf')
    _plot_bode(freq_deltaL, C, 'DARM Sensing Function', 'Magnitude [cts/m]',
               "DARM Sensing Function Measurement 2025-01-23 T21:11:18", 'darm_sensing.pdf')
    _plot_bode(freq_deltaL, D, 'DARM Digitals', 'Magnitude [cts/cts]',
               "DARM Digitals Measurement 2025-01-23 T21:11:18", 'darm_digitals.pdf')
    _plot_bode(freq_deltaL, L3DA, 'L3 DA', 'Magnitude [cts/cts]',
               "DARM-ctrl to L3 counts Measurement 2025-01-23 T21:11:18", 'darm_l3da.pdf')

    # plot CHETA data in ESD counts
    _plot_dac_spectrum(freq_deltaL, L3_DAC, L3_DAC_rms,
                       NLN_fr, L3_DAC_from_nln, L3_DAC_from_nln_rms,
                       'CHETA noise in ESD cts/rtHz', 'CHETA noise in ESD cts RMS',
                       'DARM at NLN in ESD cts/rtHz', "CHETA in ESDs [cts/rtHz]",
                       r"Calibrating $\delta l_{CHETA}$ to counts before L3_DAC",
                       'cheta_in_esds.pdf')

    _plot_bode(freq_deltaL, L2DA, 'L2 DA', 'Magnitude [cts/cts]',
               "DARM-ctrl to L2 counts Measurement 2025-01-23 T21:11:18", 'darm_l2da.pdf')

    # plot CHETA data in COILsL2 counts
    _plot_dac_spectrum(freq_deltaL, L2_DAC, L2_DAC_rms,
                       NLN_fr, L2_DAC_from_nln, L2_DAC_from_nln_rms,
                       'CHETA noise in COILsL2 cts/rtHz', 'CHETA noise in COILs cts RMS',
                       'DARM at NLN in L2coils cts/rtHz', "CHETA in COILs for L2 [cts/rtHz]",
                       r"Calibrating $\delta l_{CHETA}$ to counts before L2_DAC",
                       'cheta_in_coilsL2.pdf')

    _plot_bode(freq_deltaL, L1DA, 'L1 DA', 'Magnitude [cts/cts]',
               "DARM-ctrl to L1 counts Measurement 2025-01-23 T21:11:18", 'darm_l1da.pdf')

    # plot CHETA data in COILsL1 counts
    _plot_dac_spectrum(freq_deltaL, L1_DAC, L1_DAC_rms,
                       NLN_fr, L1_DAC_from_nln, L1_DAC_from_nln_rms,
                       'CHETA noise in COILsL1 cts/rtHz', 'CHETA noise in COILs cts RMS',
                       'DARM at NLN in L1coils cts/rtHz', "CHETA in COILs for L1 [cts/rtHz]",
                       r"Calibrating $\delta l_{CHETA}$ to counts before L1_DAC",
                       'cheta_in_coilsL1.pdf')

    # plot CHETA RIN raw
    fig = plt.figure(figsize=(14, 9), tight_layout=True)
    plt.loglog(freq_deltaL, rin_raw, color='C0', label='CHETA RIN open')
    plt.loglog(freq_deltaL, asd_to_rms(freq_deltaL, rin_raw), color='C0', ls='--')
    plt.xlabel("Frequency [Hz]")
    plt.ylabel("CHETA RIN raw [1/rtHz]")
    plt.title(r"Relative Intensity Noise of CHETA CO2 Laser")
    plt.grid(True, 'major', 'both', alpha=.5)
    plt.grid(True, 'minor', 'both', alpha=.2)
    plt.xlim(10, max(freq_deltaL))
    print(f"\nPlotting Frequencies: {10} to {max(freq_deltaL)}")
    plt.legend(loc='lower left')
    _save_figure(fig, 'cheta_rin_raw.pdf')
