'''Pull ASC coupling functions from real data and calibrate'''
# %%
import numpy as np
import matplotlib.pyplot as plt
import h5py
import gwpy.timeseries
from gwpy.plot import BodePlot
import scipy.signal as sig
from matplotlib.ticker import LogLocator
import matplotlib as mpl

from functions import *

from QUAD_suspension import QUAD_suspension

fontsize = 12

# Global plot styling: LaTeX-rendered serif text, every text element at
# `fontsize`, compact semi-transparent legends, and small PDF output.
_text_style = {
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": "georgia",
    # 'mathtext.fontset': 'cm',
    "font.size": fontsize,
    "xtick.labelsize": fontsize,
    "ytick.labelsize": fontsize,
}
_legend_style = {
    "legend.fancybox": True,
    "legend.fontsize": fontsize,
    "legend.framealpha": 0.7,
    "legend.handletextpad": 0.5,
    "legend.labelspacing": 0.2,
    "legend.loc": "best",
}
_figure_style = {
    "figure.figsize": (12, 9),
    "lines.linewidth": 2,
    "savefig.dpi": 80,
    "pdf.compression": 9,
}
mpl.rcParams.update({**_text_style, **_legend_style, **_figure_style})


%matplotlib inline
# %%
# PUM drive calibration, DAC counts -> torque, following G1100968 (20-bit DACs).
# Factors, in the order they are multiplied:
#   20 V / 2**20 ct        [V/ct]
#   0.268 mA / V           [A/V]
#   0.0309 N / A           [N/A]
#   70.7 mm                lever arm [m]; pit/yaw lever arm is the same for the PUM
#   4                      DAC correction (BS)
#   3.5355                 eul2osem
#   4                      trailing factor from the original G1100968 expression
_pum_V_per_ct = 20 / 2**20
_pum_A_per_V = 0.268e-3
_pum_N_per_A = 0.0309
_pum_lever_m = 70.7e-3
_pum_dac_corr = 4
_pum_eul2osem = 3.5355
to_Nm = (_pum_V_per_ct * _pum_A_per_V * _pum_N_per_A * _pum_lever_m
         * _pum_dac_corr * _pum_eul2osem * 4)  # Nm/ct
# %%
# import the data from the noise budget injections for HARD loops and DARM
dt = 60  # length [s] of each fetched data segment

# NOTE(review): not referenced anywhere in this script (channel names are
# built from the gps_dict keys below), and CSOFT_P has no entry in gps_dict.
asc_chans = [
    'H1:ASC-DHARD_P_OUT_DQ',
    'H1:ASC-CHARD_P_OUT_DQ',
    'H1:ASC-DHARD_Y_OUT_DQ',
    'H1:ASC-CHARD_Y_OUT_DQ',
    'H1:ASC-CSOFT_P_OUT_DQ'
]

# GPS start times from the ASC noise budget injections. 'inj' is the
# injection segment; all loops share the same quiet (reference) segment.
quiet_gps = 1420227133
gps_dict = {
    'CHARD_P': {'inj': 1420222806, 'quiet': quiet_gps},
    'DHARD_P': {'inj': 1420222915, 'quiet': quiet_gps},
    'CHARD_Y': {'inj': 1420223207, 'quiet': quiet_gps},
    'DHARD_Y': {'inj': 1420223021, 'quiet': quiet_gps},
}

keys = gps_dict.keys()

# plot colors
colors_dict = {'CHARD_P': 'xkcd:grass green', 'DHARD_P': 'xkcd:cornflower',
               'CHARD_Y': 'xkcd:tangerine', 'DHARD_Y': 'xkcd:crimson',
               }

# %%
# Fetch the ASC loop output and calibrated strain for every injection and its
# quiet reference segment. Strain is scaled by 3995 (presumably the arm
# length in m) to give DARM displacement in meters.
_nds_host = 'nds.ligo-wa.caltech.edu'
asc_data = {key: dict.fromkeys(('inj', 'quiet')) for key in keys}
darm_data = {key: dict.fromkeys(('inj', 'quiet')) for key in keys}

for key in keys:
    for segment in ('inj', 'quiet'):
        seg_start = gps_dict[key][segment]
        seg_end = seg_start + dt
        asc_data[key][segment] = gwpy.timeseries.TimeSeries.get(
            'H1:ASC-' + key + '_OUT_DQ', seg_start, seg_end,
            verbose=True, host=_nds_host)
        darm_data[key][segment] = 3995 * gwpy.timeseries.TimeSeries.get(
            'H1:GDS-CALIB_STRAIN', seg_start, seg_end,
            verbose=True, host=_nds_host)

# %%
# Welch-PSD settings for the excess power coupling
fftLen = 8  # s
overlap = 0.75  # fractional overlap between consecutive segments
average_type = 'median'


def welch_psd(time_data, fft_length, overlap, average_type, window='hann'):
    """Return the one-sided Welch PSD of a gwpy-style time series.

    Parameters
    ----------
    time_data : gwpy.timeseries.TimeSeries
        Any object exposing ``.value`` (samples) and ``.sample_rate.value``
        (Hz) works.
    fft_length : float
        Segment length in seconds.
    overlap : float
        Fractional overlap between consecutive segments.
    average_type : str
        Segment averaging passed to ``scipy.signal.welch`` ('mean'/'median').
    window : str or tuple or array_like, optional
        Window passed to ``scipy.signal.welch`` (default 'hann').

    Returns
    -------
    ff : ndarray
        Frequencies [Hz].
    psd : ndarray
        PSD in (units of ``time_data``)**2 / Hz.
    """
    fs = time_data.sample_rate.value
    # scipy needs integer sample counts; round instead of letting it truncate
    # a float product such as 4095.9999 down to 4095.
    nperseg = int(round(fs * fft_length))
    noverlap = int(round(nperseg * overlap))
    ff, psd = sig.welch(time_data.value,
                        fs=fs,
                        window=window,
                        scaling='density',
                        nperseg=nperseg,
                        noverlap=noverlap,
                        detrend='constant',
                        return_onesided=True,
                        average=average_type,
                        )
    return ff, psd


def excess_power_coupling_darm(freq, quiet_wit, quiet_tar, inj_wit, inj_tar,
                               fft_length=None, overlap_frac=None,
                               avg_type=None):
    """Excess-power coupling from witness to target, evaluated at ``freq``.

    coupling = sqrt((P_tar,inj - P_tar,quiet) / (P_wit,inj - P_wit,quiet))

    Each PSD is computed with :func:`welch_psd` and linearly interpolated
    onto ``freq``. The result is in (target units) / (witness units).

    Bins where the injection did not raise the witness PSD, or where the
    target PSD decreased, are returned as NaN. The ratio is undefined
    there; in particular, two negative excesses would otherwise give a
    spurious positive coupling.

    Parameters
    ----------
    freq : array_like
        Frequencies [Hz] at which to evaluate the coupling.
    quiet_wit, quiet_tar, inj_wit, inj_tar : gwpy.timeseries.TimeSeries
        Witness/target data for the quiet and injection segments.
    fft_length, overlap_frac, avg_type : optional
        Welch settings; default to the module-level ``fftLen``, ``overlap``
        and ``average_type``, read at call time.
    """
    if fft_length is None:
        fft_length = fftLen
    if overlap_frac is None:
        overlap_frac = overlap
    if avg_type is None:
        avg_type = average_type

    def _psd_at_freq(ts):
        # PSD of ts interpolated onto the requested frequencies
        f_native, psd = welch_psd(ts, fft_length, overlap_frac, avg_type)
        return np.interp(freq, f_native, psd)

    excess_wit = _psd_at_freq(inj_wit) - _psd_at_freq(quiet_wit)
    excess_tar = _psd_at_freq(inj_tar) - _psd_at_freq(quiet_tar)

    valid = (excess_wit > 0) & (excess_tar >= 0)
    ratio = np.full(np.shape(excess_wit), np.nan)
    np.divide(excess_tar, excess_wit, out=ratio, where=valid)

    return np.sqrt(ratio)

# %%
# resample the darm data to 512 Hz, the rate used for the transfer-function
# estimate below (tfe is called with fs=512)
darm_data_512 = {key: {'inj': None, 'quiet': None} for key in keys}
for key in keys:
    for seg in ('inj', 'quiet'):
        src = darm_data[key][seg]
        rate_ratio = src.sample_rate.value / 512.0
        down = int(round(rate_ratio))
        # derive the decimation factor instead of assuming a 16384 Hz input
        if down < 1 or not np.isclose(rate_ratio, down):
            raise ValueError(
                f'{key} {seg}: cannot decimate {src.sample_rate.value} Hz '
                'to 512 Hz by an integer factor')
        # each segment keeps its own start time (quiet must not use the inj t0)
        darm_data_512[key][seg] = gwpy.timeseries.TimeSeries(
            sig.resample_poly(src.value, 1, down),
            t0=src.t0, sample_rate=512.0)
# %%
# DARM / ASC-drive transfer function from the injection segments: tfe gives
# m/ASC ct, and dividing by to_Nm [Nm/ct] converts to m/Nm of PUM torque.
# NOTE(review): assumes tfe(x, y, ...) from `functions` returns
# (y/x transfer function, frequencies, coherence), matching its use here.
freq = dict.fromkeys(keys)
tf = dict.fromkeys(keys)
cohe = dict.fromkeys(keys)
for key in keys:
    tf_cts, f_key, coh_key = tfe(asc_data[key]['inj'].value,
                                 darm_data_512[key]['inj'].value,
                                 fs=512, window='hann',
                                 nperseg=8 * 512, noverlap=6 * 512)
    freq[key] = f_key
    cohe[key] = coh_key
    tf[key] = tf_cts / to_Nm

# %%
# 1-sigma error bars from coherence: mag_std is a relative magnitude error,
# ph_std a phase error in radians (the plots below scale it by 180/pi).
# NOTE(review): with dt = 60 s, nperseg = 8 s and noverlap = 6 s the TF uses
# (60 - 8) / 2 + 1 = 27 overlapping segments, not 15; presumably 15 is an
# effective number of independent averages -- confirm.
N_avg = 15
mag_std, ph_std = {}, {}
for key in keys:
    coh = cohe[key]
    mag_std[key] = np.sqrt(csd_variance_mag(coh, N_avg))
    ph_std[key] = np.sqrt(csd_variance_phase(coh, N_avg))

# %%
# Bode plot (magnitude, phase) and coherence of the NB injection TFs in m/Nm
fig, axs = plt.subplots(3, 1, sharex=True)
mag_ax, phase_ax, coh_ax = axs
for key in keys:
    line_color = colors_dict[key]
    mag_ax.loglog(freq[key], np.abs(tf[key]), label=key, color=line_color)
    phase_ax.semilogx(freq[key], np.angle(tf[key], deg=True), color=line_color)
    coh_ax.semilogx(freq[key], cohe[key], color=line_color)

mag_ax.set_ylabel('DARM/ASC drive [m/Nm]')
mag_ax.set_ylim(1e-12, 1e-7)
phase_ax.set_ylim(-180, 180)
phase_ax.set_yticks([-180, -90, 0, 90, 180])
phase_ax.set_ylabel('Phase [deg]')
coh_ax.set_ylim(0, 1)
coh_ax.set_yticks([0, 0.25, 0.5, 0.75, 1.0])
coh_ax.set_xlim(10, 100)
coh_ax.set_ylabel('Coherence')
coh_ax.set_xlabel('Frequency [Hz]')
mag_ax.legend()
mag_ax.set_title('ASC Noise Budget Injections')
fig.savefig('plots/DARM_PUM_ASC_drive.pdf')

# %%
# Free (no-power) QUAD L2 -> L3 angular transfer functions [rad/Nm] from the
# state-space model, evaluated on the measured TF frequency grid.
freqs = freq['CHARD_P']
PUM_TST_P, PUM_TST_Y = (QUAD_suspension(freqs, 'L2', dof, damped=True)
                        for dof in ('P', 'Y'))

# %%
# HARD plants from Gabriele, as analog zero/pole/gain fits (s-plane, rad/s),
# evaluated on the measurement frequency grid.

# pitch
Z = [-0.37656602]
P = [-0.00292406 +0.j,-0.11496697 +6.49363092j,
 -0.11496697 -6.49363092j,-0.19955184+16.53674456j,
 -0.19955184-16.53674456j]
K = 2176.925725327256

__, hard_p = sig.freqresp((Z,P,K), 2.*np.pi*freqs)

# yaw
z = [-4.27268475+16.14989614j, -4.27268475-16.14989614j,
        -0.84677811+11.7910939j , -0.84677811-11.7910939j ,
        -0.15001959 +3.1577398j , -0.15001959 -3.1577398j ,
        -0.14073526 +2.60978661j, -0.14073526 -2.60978661j,
        -0.89432965 +1.06440904j, -0.89432965 -1.06440904j]
p = [-1.26172189+18.71400182j, -1.26172189-18.71400182j,
        -0.42125296+16.04217932j, -0.42125296-16.04217932j,
        -1.88894857+14.52668064j, -1.88894857-14.52668064j,
        -0.17624191 +6.42823571j, -0.17624191 -6.42823571j,
        -0.08126954 +3.12803462j, -0.08126954 -3.12803462j,
        -0.33177813 +2.74361234j, -0.33177813 -2.74361234j,
        -1.01199783 +0.j        , -0.20377166 +0.j        ]
k = -2608.840762444897

__, hard_y = sig.freqresp((z,p,k), 2.*np.pi*freqs)

# Calibrate the fitted high-power plants by matching their magnitude to the
# no-power suspension model at a single frequency bin.
# NOTE(review): an earlier version of this comment said 10 Hz, but the code
# uses 20 Hz; confirm which is intended.
cal_freq = 20.0  # Hz
# Nearest bin, not exact float equality (which raises IndexError when
# cal_freq is not exactly on the grid). Name idx10 kept for compatibility.
idx10 = int(np.argmin(np.abs(freqs - cal_freq)))

np_scale_p = np.abs(PUM_TST_P[idx10])
hp_scale_p = np.abs(hard_p[idx10])
np_scale_y = np.abs(PUM_TST_Y[idx10])
hp_scale_y = np.abs(hard_y[idx10])

p_scale = np_scale_p/hp_scale_p
y_scale = np_scale_y/hp_scale_y

plt.figure()
plt.loglog(freqs, np.abs(PUM_TST_P), label='plant model, no power')
plt.loglog(freqs, np.abs(hard_p*np_scale_p/hp_scale_p), label='fitted plant, high power')
plt.legend()
plt.title('HARD P')

plt.figure()
plt.loglog(freqs, np.abs(PUM_TST_Y), label='plant model, no power')
plt.loglog(freqs, np.abs(hard_y*np_scale_y/hp_scale_y), label='fitted plant, high power')
plt.legend()
plt.title('HARD Y')

# calibrate the high power plant [rad/Nm]
hard_p_cal = p_scale * hard_p

hard_y_cal = y_scale * hard_y

# %%
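# This section is Lee's stiffness-based calibration method (see T1100595).
# The HARD plants above are instead calibrated by matching the no-power
# suspension model at a single frequency bin (20 Hz in the code), so this
# section is kept for reference only.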
import scipy.constants as scc
def test_hard_soft_modes(Parm_W=350e3):
    """Return radiation-pressure hard/soft angular stiffness (see T1100595).

    Parameters
    ----------
    Parm_W : float
        Arm power [W].

    Returns
    -------
    (kh, ks) : tuple of float
        Hard- and soft-mode torsional stiffness [Nm/rad].
    """
    # Arm length and mirror radii of curvature [m]; the i/e suffixes follow
    # the original Ri/Re naming (presumably input/end test masses).
    arm_len = 4e3
    roc_i = 1934
    roc_e = 2245
    g_i = 1 - arm_len / roc_i
    g_e = 1 - arm_len / roc_e

    k0 = 2 * Parm_W * arm_len / (scc.c * (g_i * g_e - 1))
    g_sum = g_e + g_i
    g_split = np.sqrt((g_e - g_i)**2 + 4)
    kh = k0 * (g_sum - g_split) / 2
    ks = k0 * (g_sum + g_split) / 2
    # For reference (not returned): hard-mode resonance [Hz] is
    # (kh / I * 2)**0.5 / (2 * pi) with I = 2.73 kg m^2 (Q = 100 was noted).
    return kh, ks


kh, ks = test_hard_soft_modes()

# add in the stiffness of the quad without RPN
kh_p = kh + 10  # 10 Nm/rad stiffness of free sus in pitch
kh_y = kh + 20  # 20 Nm/rad stiffness of free sus in yaw

# %%
# Convert DARM/PUM-torque TFs [m/Nm] into test-mass angle-to-length coupling
# by dividing out the calibrated HARD plant [rad/Nm]; 1e3 converts m -> mm.
_hard_plant_cal = {
    'CHARD_P': hard_p_cal,
    'DHARD_P': hard_p_cal,
    'CHARD_Y': hard_y_cal,
    'DHARD_Y': hard_y_cal,
}
coupling = dict.fromkeys(keys)
for key, plant in _hard_plant_cal.items():
    coupling[key] = 1e3 * tf[key] / plant

# %%
# Pitch: test-mass angle-to-length coupling [mm/rad] with +/-1 sigma bands
# derived from the coherence-based mag_std / ph_std.
fig, axs = plt.subplots(2,1, sharex=True)

for key in ['CHARD_P', 'DHARD_P']:
    axs[0].loglog(freqs, np.abs(coupling[key]), label=key, color=colors_dict[key])
    # raw string: '\p' and '\s' are invalid escapes in a normal literal
    axs[0].fill_between(freqs, np.abs(coupling[key])*(1+mag_std[key]), np.abs(coupling[key])*(1-mag_std[key]), color= colors_dict[key], label = r'$\pm 1$ $\sigma$', alpha=0.3)

for key in ['CHARD_P', 'DHARD_P']:
    axs[1].semilogx(freqs, np.angle(coupling[key], deg=True), label=key, color=colors_dict[key])
    # ph_std is in radians; convert the band to degrees
    axs[1].fill_between(freqs, np.angle(coupling[key], deg=True) + (180/np.pi)*ph_std[key], np.angle(coupling[key], deg=True) - (180/np.pi)*ph_std[key], label=key, color=colors_dict[key], alpha=0.3)

axs[0].set_xlim(10, 100)
axs[0].set_ylim(0.1,10)
axs[1].set_xlabel('Freq [Hz]')
axs[1].set_yticks([-180, -90, 0, 90, 180])
axs[1].set_ylim(-180, 180)
axs[0].legend(loc='upper right')
axs[0].set_ylabel('DARM/ASC motion [mm/rad]')
axs[1].set_ylabel('Phase [deg]')
axs[0].set_title('Test Mass Angle to Length coupling')
fig.savefig('plots/DARM_ASC_TST_coupling_pitch.pdf')
plt.show()

# %%
# Yaw: test-mass angle-to-length coupling [mm/rad] with +/-1 sigma bands
# derived from the coherence-based mag_std / ph_std.
fig, axs = plt.subplots(2,1, sharex=True)

for key in ['CHARD_Y', 'DHARD_Y']:
    axs[0].loglog(freqs, np.abs(coupling[key]), label=key, color=colors_dict[key])
    # raw string: '\p' and '\s' are invalid escapes in a normal literal
    axs[0].fill_between(freqs, np.abs(coupling[key])*(1+mag_std[key]), np.abs(coupling[key])*(1-mag_std[key]), color= colors_dict[key], label = r'$\pm 1$ $\sigma$', alpha=0.3)

for key in ['CHARD_Y', 'DHARD_Y']:
    axs[1].semilogx(freqs, np.angle(coupling[key], deg=True), label=key, color=colors_dict[key])
    # ph_std is in radians; convert the band to degrees
    axs[1].fill_between(freqs, np.angle(coupling[key], deg=True) + (180/np.pi)*ph_std[key], np.angle(coupling[key], deg=True) - (180/np.pi)*ph_std[key], label=key, color=colors_dict[key], alpha=0.3)

axs[0].set_xlim(10, 100)
axs[0].set_ylim(0.01,100)
axs[1].set_ylim(-180, 180)
axs[1].set_yticks([-180, -90, 0, 90, 180])
axs[1].set_xlabel('Freq [Hz]')
axs[0].legend(loc='upper right')
axs[0].set_ylabel('DARM/ASC motion [mm/rad]')
axs[1].set_ylabel('Phase [deg]')
axs[0].set_title('Test Mass Angle to Length coupling')
fig.savefig('plots/DARM_ASC_TST_coupling_yaw.pdf')
plt.show()
# %%
# Excess-power coupling DARM/ASC [m/ct] per injection. This rebinds
# `coupling`, which previously held the linear test-mass coupling.
coupling = {
    key: excess_power_coupling_darm(freq[key],
                                    asc_data[key]['quiet'],
                                    darm_data_512[key]['quiet'],
                                    asc_data[key]['inj'],
                                    darm_data_512[key]['inj'])
    for key in keys
}
# %%
# Linear TF vs excess-power coupling, both in m/Nm of PUM drive.
# Keep a handle on THIS figure: fig.savefig previously saved the stale yaw
# figure created earlier.
fig = plt.figure()
for key in keys:
    plt.loglog(freq[key], np.abs(tf[key]), label=key + ' linear', color = colors_dict[key], alpha = 0.5)
    # coupling is m/ct; dividing by to_Nm [Nm/ct] gives m/Nm
    plt.loglog(freq[key], coupling[key]/to_Nm, label=key + ' excess power', color = colors_dict[key], linestyle = '--')
plt.xlim(10,100)
plt.ylim(1e-13, 1e-7)
plt.ylabel('DARM/ASC drive [m/Nm]')
plt.xlabel('Frequency [Hz]')
plt.legend()
plt.title('ASC coupling')
fig.savefig('plots/DARM_ASC_linear_excess_power_coupling.pdf')

# %%
# Convert both coupling estimates to test-mass angle-to-length coupling
# [mm/rad] by dividing out the calibrated HARD plant [rad/Nm] (1e3: m -> mm).
# The excess-power estimate is a magnitude in m/ct, so it uses |plant| and
# is converted to m/Nm with to_Nm.
_hard_plant_cal = {
    'CHARD_P': hard_p_cal,
    'DHARD_P': hard_p_cal,
    'CHARD_Y': hard_y_cal,
    'DHARD_Y': hard_y_cal,
}
coupling_linear = dict.fromkeys(keys)
coupling_excess_power = dict.fromkeys(keys)
for key, plant in _hard_plant_cal.items():
    coupling_linear[key] = 1e3 * tf[key] / plant
    coupling_excess_power[key] = 1e3 * coupling[key] / np.abs(plant) / to_Nm

# %%
# Linear vs excess-power test-mass coupling [mm/rad], pitch then yaw.
# Each figure is bound to `fig` so savefig writes the figure just drawn
# (previously both savefig calls wrote a stale, earlier figure).
fig = plt.figure()
for key in ['CHARD_P', 'DHARD_P']:
    plt.loglog(freq[key], np.abs(coupling_linear[key]), label=key + ' linear', color = colors_dict[key], alpha = 0.5)
    plt.loglog(freq[key], coupling_excess_power[key], label=key + ' excess power', color = colors_dict[key], linestyle = '--')
plt.xlim(10,100)
plt.ylim(0.1, 10)
plt.ylabel('DARM/ASC motion [mm/rad]')
plt.xlabel('Frequency [Hz]')
plt.legend()
plt.title('ASC Pitch coupling')
fig.savefig('plots/pitch_linear_excess_power_coupling.pdf')

fig = plt.figure()
for key in ['CHARD_Y', 'DHARD_Y']:
    plt.loglog(freq[key], np.abs(coupling_linear[key]), label=key + ' linear', color = colors_dict[key], alpha = 0.5)
    plt.loglog(freq[key], coupling_excess_power[key], label=key + ' excess power', color = colors_dict[key], linestyle = '--')
plt.xlim(10,100)
plt.ylim(0.1, 10)
plt.ylabel('DARM/ASC motion [mm/rad]')
plt.xlabel('Frequency [Hz]')
plt.legend()
plt.title('ASC Yaw coupling')
fig.savefig('plots/yaw_linear_excess_power_coupling.pdf')
# %%
