import numpy as np
import matplotlib.pyplot as plt
import os, termcolor, matplotlib, configparser
from pydarm.darm import DARMModel
import scipy.signal as signal
import foton, ipdb
import codecs
from gwpy.timeseries import *

# Global plot styling shared by every figure this script produces.
matplotlib.rcParams.update({
    'font.size': 14,
    'figure.figsize': (14, 9),
})

def _load_analyzer_asds(data_dir, keywords, skiprows=23):
    """Load and RMS-average the analyzer exports in *data_dir*, grouped by filename prefix.

    A file belongs to group ``kw`` when the text before the first underscore in
    its name equals ``kw``; files whose prefix is not in *keywords* are ignored.
    Files are visited in sorted name order so results do not depend on
    ``os.listdir`` ordering.

    Returns:
        (freqs, asds): dicts keyed by prefix. ``freqs[kw]`` is column 0 of the
        first file of the group; ``asds[kw]`` is the RMS average of column 1
        over all files of the group.

    Raises:
        FileNotFoundError: if some keyword has no matching file.
    """
    wanted = set(keywords)
    freqs, columns = {}, {}
    for name in sorted(os.listdir(data_dir)):
        kw = name.split('_')[0]
        if kw not in wanted:
            continue
        # The exports are read as cp1252 text with a `skiprows`-line header.
        with open(os.path.join(data_dir, name), encoding='cp1252') as fh:
            table = np.loadtxt(fh, skiprows=skiprows)
        freqs.setdefault(kw, table[:, 0])
        columns.setdefault(kw, []).append(table[:, 1])

    missing = wanted.difference(freqs)
    if missing:
        raise FileNotFoundError(
            f"no analyzer exports for {sorted(missing)} in {data_dir}")

    asds = {kw: np.sqrt(np.mean(np.square(cols), axis=0))
            for kw, cols in columns.items()}
    return freqs, asds


def _stitch_spectra(segments):
    """Concatenate ``(freq, asd)`` segments into one spectrum.

    The first non-empty segment is kept whole; from each later segment only the
    bins strictly above the current maximum frequency are appended, so the
    result is ordered by segment and free of overlapping bins.
    """
    fr = sp = None
    for seg_f, seg_a in segments:
        seg_f, seg_a = np.asarray(seg_f), np.asarray(seg_a)
        if fr is None or fr.size == 0:
            fr, sp = seg_f, seg_a
            continue
        keep = seg_f > fr.max()
        fr = np.concatenate([fr, seg_f[keep]])
        sp = np.concatenate([sp, seg_a[keep]])
    return fr, sp


def get_cheta_rin(
        data_dir='/ligo/home/matthewrichard.todd/Projects/cheta_rin/darm_to_esd/20240815_CHETA_RIN',
        gps=1379395092, duration=300, co2_power=0.5):
    """Build the CHETA RIN projection on the analyzer frequency grid plus an H1 DARM spectrum.

    Args:
        data_dir: directory holding the analyzer exports (``<span>_*`` files)
            and ``CHETA_RIN_PWM.txt``.
        gps: start GPS time of the H1 strain segment.
        duration: length of the strain segment in seconds.
        co2_power: the ``P`` factor of the coupling model (labelled "CO2 power"
            in the original code).

    Returns:
        dict with keys
          ``'Frequencies'``: stitched analyzer frequency grid,
          ``'CHETA_RIN_OPEN_raw'``: column 1 of ``CHETA_RIN_PWM.txt`` divided by
          0.5 and interpolated onto that grid,
          ``'CHETA_RIN_OPEN_calibrated'``: the raw RIN scaled by the coupling
          model below (divides by frequency, so a 0 Hz bin yields inf),
          ``'NLN_DARM'``: 2 x N array ``[welch_freqs, 4000 * ASD]`` of
          ``H1:GDS-CALIB_STRAIN_CLEAN`` (NOTE(review): 4000 presumably converts
          strain to metres via the arm length).
    """
    # Analyzer spans, stitched from low to high frequency. Only the resulting
    # frequency grid is used below; the averaged analyzer ASDs are not returned.
    keywords = ['6.25Hz', '25Hz', '100Hz', '400Hz', '3.2kHz', '25.6kHz']
    freqs, asds = _load_analyzer_asds(data_dir, keywords)
    fr, _ = _stitch_spectra((freqs[k], asds[k]) for k in keywords)

    # Load new RIN measurements and resample onto the stitched grid.
    pwm = np.loadtxt(os.path.join(data_dir, 'CHETA_RIN_PWM.txt'))
    rin_open = np.interp(fr, pwm[:, 0], pwm[:, 1] / 0.5)

    # H1 strain reference spectrum: 5 s segments, 50% overlap, median average.
    strain = TimeSeries.get('H1:GDS-CALIB_STRAIN_CLEAN', gps, gps + duration)
    fs = strain.sample_rate.value
    nperseg = int(5 * fs)
    darm_fr, darm_psd = signal.welch(strain.value, nperseg=nperseg,
                                     noverlap=nperseg // 2, fs=fs,
                                     average='median')
    darm_asd = np.sqrt(darm_psd)

    # Empirical RIN -> displacement coupling model (constants as in the original).
    rin_calibrated = (np.sqrt(2) * 2 * 1e-20 * (150 / fr)
                      * (co2_power / 25e-3) * (rin_open / 4e-6))

    return {
        'Frequencies': fr,
        'CHETA_RIN_OPEN_raw': rin_open,
        'CHETA_RIN_OPEN_calibrated': rin_calibrated,
        'NLN_DARM': np.array([darm_fr, 4000 * darm_asd]),
    }

def asd_to_rms(f, asd, reverse=False):
    """Return the cumulative RMS of an amplitude spectral density.

    The RMS is ``sqrt(integral |asd|**2 df)``, evaluated with the trapezoid rule
    so that non-uniform frequency grids (e.g. stitched analyzer spans) are
    weighted by their actual bin widths.

    Args:
        f: 1-D ascending frequency array.
        asd: ASD values at *f* (same shape); complex values use their magnitude.
        reverse: if False (default), element ``i`` is the RMS accumulated from
            ``f[i]`` up to the last frequency; if True, from the first
            frequency up to ``f[i]``.

    Returns:
        1-D float array of the same length as *f*.

    Raises:
        ValueError: if *f* is not 1-D or its shape differs from *asd*'s.
    """
    f = np.asarray(f, dtype=float)
    power = np.abs(np.asarray(asd)) ** 2
    if f.ndim != 1 or power.shape != f.shape:
        raise ValueError(
            f"f and asd must be 1-D arrays of equal length, got {f.shape} and {power.shape}")
    if f.size == 0:
        return np.zeros(0)

    # Power contained in each interval [f[k], f[k+1]].
    interval_power = 0.5 * (power[1:] + power[:-1]) * np.diff(f)
    if reverse:
        cumulative = np.concatenate(([0.0], np.cumsum(interval_power)))
    else:
        cumulative = np.concatenate((np.cumsum(interval_power[::-1])[::-1], [0.0]))
    return np.sqrt(cumulative)

def read_tf_txt(filename, comment_char='#',
                delimiter = " ", freq_sep = " ",
                ):
    """Read a transfer-function table from a text file.

    Comment lines (first non-blank text starts with *comment_char*) that come
    before the first data line form the header. Blank lines, and comment lines
    anywhere else, are skipped. Each data line is split on *freq_sep*; empty
    fields (from repeated separators) are dropped. The first field is the
    frequency and the remaining fields are data columns.

    Args:
        filename: path of the text file.
        comment_char: prefix marking comment lines.
        delimiter: unused; kept for backward compatibility.
        freq_sep: separator used to split every data line.

    Returns:
        (freq, data): ``freq`` is a 1-D float array with one entry per data
        line; ``data`` is a 2-D float array of shape (n_lines, n_columns - 1).

    Raises:
        ValueError: if there are no data lines, rows have differing column
            counts, or a field is not a number.
    """
    with open(filename, 'r') as file:
        lines = file.readlines()

    header = []
    rows = []  # (1-based line number, parsed values)
    for lineno, line in enumerate(lines, start=1):
        stripped = line.strip()
        if not stripped:
            continue
        if comment_char and stripped.startswith(comment_char):
            if not rows:
                header.append(stripped[len(comment_char):].strip())
            continue
        fields = [tok.strip() for tok in line.split(freq_sep)]
        fields = [tok for tok in fields if tok]
        if not fields:
            continue
        rows.append((lineno, [float(tok) for tok in fields]))

    print(termcolor.colored(f"\nReading from {file.name} ...", color='green'))
    print(termcolor.colored(f"{len(rows)} data points found.", color='white'))

    # The last header line is not printed (NOTE(review): presumably the
    # column-name row).
    print('Comments:')
    for each in header[:-1]:
        print('\t', each)

    if not rows:
        raise ValueError(f"{filename}: no data lines found")
    width = len(rows[0][1])
    for lineno, values in rows:
        if len(values) != width:
            raise ValueError(
                f"{filename}, line {lineno}: expected {width} columns, found {len(values)}")

    table = np.array([values for _, values in rows], dtype=float)
    return table[:, 0].copy(), table[:, 1:].copy()

def calibrate(freq_data, data, freq_cal, data_cal):
    """Multiply *data* by a calibration curve linearly interpolated at *freq_data*.

    This is essentially an interpolation function for multiplying transfer
    functions that do not share the same frequency array. Inside the
    calibration table each point is interpolated on the segment that brackets
    it. Outside the table it is extrapolated along the first or last segment.
    Complex calibration values are interpolated in real and imaginary parts.

    Args:
        freq_data: 1-D frequencies at which *data* is sampled.
        data: values to calibrate; the first axis must match *freq_data*.
        freq_cal: 1-D ascending calibration frequencies (at least two).
        data_cal: 1-D calibration values at *freq_cal* (real or complex).

    Returns:
        Array of ``interp(freq_data) * data``, with dtype promoted to hold both
        operands (complex calibration applied to real data stays complex).

    Raises:
        ValueError: if fewer than two calibration points are given.
    """
    x = np.asarray(freq_data, dtype=float)
    xp = np.asarray(freq_cal, dtype=float)
    fp = np.asarray(data_cal)
    values = np.asarray(data)
    if xp.size < 2:
        raise ValueError("calibrate needs at least two calibration points")

    # searchsorted (side='left') returns the first knot >= x, so the bracketing
    # segment is [idx-1, idx]. Clipping to [1, len-1] makes points below/above
    # the table use the first/last segment.
    idx = np.clip(np.searchsorted(xp, x), 1, xp.size - 1)
    f0, f1 = xp[idx - 1], xp[idx]
    c0, c1 = fp[idx - 1], fp[idx]
    cal = c0 + (c1 - c0) / (f1 - f0) * (x - f0)

    # Broadcast the per-frequency factor over any trailing axes of `data`.
    cal = cal.reshape(cal.shape + (1,) * (values.ndim - 1))
    return cal * values

def get_foton_filters(freq, file, module, fms = None):
    """Return the complex frequency response of (part of) a foton filter module.

    Args:
        freq: frequencies at which to evaluate the response.
        file: path of the foton filter file (opened read-only).
        module: name of the filter module inside *file*.
        fms: which sections to include:
            ``None``: every section of the module (product of all responses);
            an int (Python or numpy): that single section index;
            a slice: the sections selected by slicing the module's section list;
            a list/tuple/array of ints: those section indices.

    Returns:
        complex128 array shaped like *freq* holding the product of the selected
        sections' ``freqresp(freq)``.

    Raises:
        TypeError: if *fms* is of an unsupported type (previously such values
            silently yielded a unity response).
    """
    ff = foton.FilterFile(file, read_only=True)
    fm = ff[module]

    if fms is None:
        sections = list(fm)
    elif isinstance(fms, (int, np.integer)):
        sections = [fm[int(fms)]]
    elif isinstance(fms, slice):
        # Slice the module's sections (the module itself is iterable).
        sections = list(fm)[fms]
    elif isinstance(fms, (list, tuple, np.ndarray)):
        sections = [fm[index] for index in fms]
    else:
        raise TypeError(
            f"fms must be None, an int, a slice or a sequence of ints, got {type(fms).__name__}")

    tf = np.ones_like(freq, dtype=np.complex128)
    for section in sections:
        tf *= section.freqresp(freq)
    return tf

if __name__ == "__main__":

    # Paths and constants shared by the steps below.
    here = os.path.dirname(__file__)
    fig_dir = os.path.join(here, 'figures')
    os.makedirs(fig_dir, exist_ok=True)  # savefig does not create missing directories
    omc_filter_file = '/opt/rtcds/lho/h1/chans/H1IOPOMC0.txt'
    omc_module = 'OMC_DCPD_A0'
    aw_section = 1  # filter-section index whose response is applied as "AW" below
    n_keep = 855    # number of leading bins of the CHETA frequency grid that are analysed

    # reading RIN data for CHETA (calibrated to DELTAL already)
    data = get_cheta_rin()
    freq_deltaL = data['Frequencies'][:n_keep]
    rin_deltaL = data['CHETA_RIN_OPEN_calibrated'][:n_keep]
    rin_raw = data['CHETA_RIN_OPEN_raw'][:n_keep]
    NLN_fr, NLN_deltaL = data['NLN_DARM']

    # DARM model from the pydarm report
    darm = DARMModel(os.path.join('/ligo/groups/cal/H1/reports/20250123T211118Z/',
                                  'pydarm_H1.ini'))
    G = darm.compute_darm_olg(freq_deltaL)

    # Measured OMC DCPD A -> DARM_ERR transfer function (magnitude, phase in degrees).
    freqM1, M1 = read_tf_txt(os.path.join(here, 'omc_dcpd_a_to_darm_err_tf.txt'))
    M1_tf = M1[:, 0] * np.exp(1j * np.deg2rad(M1[:, 1]))

    def deltaL_to_dcpd_counts(freq):
        """Return the DELTAL -> DCPD ADC-count transfer function C / M * AW at *freq*."""
        C = darm.sensing.compute_sensing(freq)
        M2 = get_foton_filters(freq, omc_filter_file, omc_module)
        AW = get_foton_filters(freq, omc_filter_file, omc_module, aw_section)
        M = calibrate(freq, M2, freqM1, M1_tf)
        return C / M * AW

    # CHETA noise in DCPD counts, including loop suppression.
    deltaLctrl_to_DCPDcts = deltaL_to_dcpd_counts(freq_deltaL)
    # NOTE(review): pydarm's response function is (1 + G) / C, which would make
    # the closed-loop suppression 1 / (1 + G); confirm the sign convention here.
    DCPDcts_cheta = rin_deltaL * np.abs(1 / (1 - G)) * np.abs(deltaLctrl_to_DCPDcts)
    DCPDcts_cheta_rms = asd_to_rms(freq_deltaL, DCPDcts_cheta)

    # calibrating DARM at NLN to DCPD counts (no loop suppression applied)
    DCPDcts_nln = NLN_deltaL * np.abs(deltaL_to_dcpd_counts(NLN_fr))
    DCPDcts_nln_rms = asd_to_rms(NLN_fr, DCPDcts_nln)

    ########################   PLOTTING     ##################################
    # plot olg tf
    fig, axs = plt.subplots(2, 1, sharex=True, figsize=(12, 9), tight_layout=True)
    axs[0].loglog(freq_deltaL, np.abs(G), label='DARM OLG')
    axs[0].set_ylabel('Magnitude')
    axs[0].grid(True, 'major', alpha=.5)
    axs[0].grid(True, 'minor', alpha=.2)
    axs[0].legend()

    axs[1].semilogx(freq_deltaL, np.angle(G, deg=True))
    axs[1].set_xlabel('Frequency [Hz]')
    axs[1].set_ylabel('Phase [deg]')
    fig.suptitle("OLG Measurement 2025-01-23 T21:11:18")
    axs[1].grid(True, 'both')
    axs[1].set_xlim(.1, max(freq_deltaL))

    fig.savefig(os.path.join(fig_dir, 'darm_olg.pdf'))
    plt.close(fig)

    # plot part Sensing Function tf
    fig, axs = plt.subplots(2, 1, sharex=True, figsize=(12, 9), tight_layout=True)
    axs[0].loglog(freq_deltaL, np.abs(deltaLctrl_to_DCPDcts),
                  label='DARM TO DCPD cts Sensing Function')
    axs[0].set_ylabel('Magnitude [cts/m]')
    axs[0].grid(True, 'major', alpha=.5)
    axs[0].grid(True, 'minor', alpha=.2)
    axs[0].legend()

    axs[1].semilogx(freq_deltaL, np.angle(deltaLctrl_to_DCPDcts, deg=True))
    axs[1].set_xlabel('Frequency [Hz]')
    axs[1].set_ylabel('Phase [deg]')
    fig.suptitle("DARM to DCPD cts Trans. Func. Measurement 2025-01-23 T21:11:18")
    axs[1].grid(True, 'both')
    axs[1].set_xlim(.1, max(freq_deltaL))

    fig.savefig(os.path.join(fig_dir, 'darm_deltaL_to_dcpd.pdf'))
    plt.close(fig)

    # plot CHETA data in DCPD counts
    # RMS accumulated from the second bin upward, as a percentage of the
    # 2**15-count saturation level. NOTE(review): index 1 presumably skips a
    # 0 Hz bin, where the 150/f coupling model diverges.
    ratio = 2**-15 * DCPDcts_cheta_rms[1] * 100
    fig = plt.figure(figsize=(14, 9), tight_layout=True)
    plt.hlines(2**15, freq_deltaL[0], freq_deltaL[-1], color=['C3'], ls='--',
               label='Saturation Level')
    plt.loglog(freq_deltaL, DCPDcts_cheta, color='C0', label='CHETA noise in DCPD cts/rtHz',
               linewidth=.8)
    plt.loglog(freq_deltaL, DCPDcts_cheta_rms, color='C0', ls='--',
               label=f'CHETA noise in DCPD cts RMS\n% of Saturation Limit: {ratio:.3f}%',
               linewidth=.8)
    ## add DARM at NLN plot ##
    plt.loglog(NLN_fr, DCPDcts_nln, color='C1', label='DARM at NLN in DCPD cts/rtHz',
               linewidth=.8)
    plt.loglog(NLN_fr, DCPDcts_nln_rms, color='C1', ls='--',
               label='DARM at NLN in DCPD cts RMS',
               linewidth=.8)

    plt.xlabel("Frequency [Hz]")
    plt.ylabel("CHETA in DCPDs [cts/rtHz]")
    plt.title(r"Calibrating $\delta l_{CHETA}$ to counts before DCPD")
    plt.grid(True, 'major', 'both', alpha=.5)
    plt.grid(True, 'minor', 'both', alpha=.2)
    plt.xlim(10, max(freq_deltaL))
    print(f"\nPlotting Frequencies: {10} to {max(freq_deltaL)}")
    plt.ylim(1e-3, 1e8)
    plt.legend(loc='lower left')
    fig.savefig(os.path.join(fig_dir, 'cheta_in_dcpds.pdf'))
    plt.close(fig)

    # plot CHETA RIN raw
    fig = plt.figure(figsize=(14, 9), tight_layout=True)

    plt.loglog(freq_deltaL, rin_raw, color='C0', label='CHETA RIN open',
               linewidth=.8)
    plt.loglog(freq_deltaL, asd_to_rms(freq_deltaL, rin_raw),
               color='C0', ls='--',
               linewidth=.8)

    plt.xlabel("Frequency [Hz]")
    plt.ylabel("CHETA RIN raw [1/rtHz]")
    plt.title(r"Relative Intensity Noise of CHETA CO2 Laser")
    plt.grid(True, 'major', 'both', alpha=.5)
    plt.grid(True, 'minor', 'both', alpha=.2)
    plt.xlim(10, max(freq_deltaL))
    print(f"\nPlotting Frequencies: {10} to {max(freq_deltaL)}")
    plt.legend(loc='lower left')
    fig.savefig(os.path.join(fig_dir, 'cheta_rin_raw.pdf'))
    plt.close(fig)
