# %%

import numpy as np
import gwpy.timeseries
import matplotlib.pyplot as plt

from gwpy.astro import sensemon_range_psd

%matplotlib inline
# %%
# Choose two start times at which to compare DARM, one before and one after
# the calibration change. Each is followed by more than 2 hours of clean,
# glitch-free data.
gps_before = 1423272193
gps_after = 1424957186

# Amount of data to fetch from each start time, in seconds (30 minutes).
dt = 1800

strain_chans = ['H1:GDS-CALIB_STRAIN',
                'H1:GDS-CALIB_STRAIN_NOLINES',
                'H1:GDS-CALIB_STRAIN_CLEAN']

# NDS server used for both fetches.
nds_host = 'nds.ligo-wa.caltech.edu'


def _fetch_strain(gps_start):
    """Fetch ``dt`` seconds of every channel in ``strain_chans`` from ``gps_start``."""
    return gwpy.timeseries.TimeSeriesDict.get(
        strain_chans,
        gps_start,
        gps_start + dt,
        verbose=True,
        host=nds_host,
    )


strain_data_before = _fetch_strain(gps_before)
strain_data_after = _fetch_strain(gps_after)


# %%
# Load the measured GDS/PCAL transfer functions. Column 0 is frequency [Hz],
# column 1 the magnitude and column 2 the phase in degrees; magnitude and
# phase are combined into a complex transfer function.
data_before = np.loadtxt('PCAL_strain_cal_before_cal_change.txt')
data_after = np.loadtxt('PCAL_strain_cal_after_TDCF_burnin.txt')

ff_before, mag_before, phase_before = data_before[:,0], data_before[:,1], data_before[:,2]
tf_before = mag_before * np.exp(np.pi * 1j * phase_before / 180)

ff_after, mag_after, phase_after = data_after[:,0], data_after[:,1], data_after[:,2]
tf_after = mag_after * np.exp(np.pi * 1j * phase_after / 180)

# Bode plot: magnitude on top, phase below, sharing the frequency axis.
fig, ax = plt.subplots(2,1, sharex=True)
mag_ax, phase_ax = ax[0], ax[1]
tf_curves = (
    (ff_before, tf_before, 'before calibration change'),
    (ff_after, tf_after, 'after TDCF burn in'),
)
for tf_freqs, tf_values, tf_label in tf_curves:
    mag_ax.semilogx(tf_freqs, np.abs(tf_values), label=tf_label)
    phase_ax.semilogx(tf_freqs, np.angle(tf_values, deg=True))

mag_ax.legend()
mag_ax.set_ylabel('GDS/PCAL [m/m]')
mag_ax.set_ylim(0.9, 1.1)
phase_ax.set_xlabel('Frequency [Hz]')
phase_ax.set_ylabel('Phase [deg]')
phase_ax.set_ylim(-5,5)
# %%
# ASDs of the line-subtracted strain, multiplied by 4000 m (see the y-axis
# label) and cropped to the 9-450.125 Hz band. Median-averaged, Hann-windowed
# 8 s FFTs with 4 s overlap, i.e. 0.125 Hz frequency resolution.
nolines_chan = 'H1:GDS-CALIB_STRAIN_NOLINES'
asd_kwargs = dict(fftlength=8, overlap=4, window='hann', method='median')

nolines_before = 4000 * strain_data_before[nolines_chan].asd(**asd_kwargs).crop(9.0, 450.125)
nolines_after = 4000 * strain_data_after[nolines_chan].asd(**asd_kwargs).crop(9.0, 450.125)

plt.figure()
for spectrum, spectrum_label in ((nolines_before, 'GDS before'),
                                 (nolines_after, 'GDS after')):
    plt.loglog(spectrum.frequencies.value, spectrum.value, label=spectrum_label)
plt.xlim(9, 450)
plt.ylim(1e-20, 1e-15)
plt.ylabel('GDS Strain * 4000 m')
plt.xlabel('Frequency [Hz]')
plt.legend()
# %%
# Convert both ASDs from GDS metres (strain * 4000 m) into PCAL metres by
# dividing out the magnitude of the measured GDS/PCAL transfer function.
# The division is element-wise, bin by bin, so the transfer-function file
# must be sampled on the same frequency grid as the cropped ASD. That is
# checked explicitly, so misaligned files fail loudly instead of silently
# pairing the wrong frequency bins.


def _check_tf_grid(asd_freqs, tf_freqs, name):
    """Raise ValueError unless ``tf_freqs`` lines up bin-for-bin with ``asd_freqs``.

    Frequencies must agree to within half an ASD bin, so a file that stores
    slightly rounded frequencies is still accepted.
    """
    asd_freqs = np.asarray(asd_freqs, dtype=float)
    tf_freqs = np.asarray(tf_freqs, dtype=float)
    if asd_freqs.shape != tf_freqs.shape:
        raise ValueError(
            f"{name}: transfer function has {tf_freqs.size} points but the "
            f"ASD has {asd_freqs.size}; they must share a frequency grid")
    half_bin = 0.5 * np.min(np.diff(asd_freqs)) if asd_freqs.size > 1 else 0.0
    if not np.allclose(tf_freqs, asd_freqs, rtol=0, atol=half_bin):
        raise ValueError(
            f"{name}: transfer-function frequencies do not match the ASD "
            f"frequency bins")


_check_tf_grid(nolines_before.frequencies.value, ff_before, 'before')
_check_tf_grid(nolines_after.frequencies.value, ff_after, 'after')

nolines_before_pcal = nolines_before / np.abs(tf_before)
nolines_after_pcal = nolines_after / np.abs(tf_after)

plt.figure()
plt.loglog(nolines_before.frequencies.value, nolines_before_pcal.value, label='before')
plt.loglog(nolines_after.frequencies.value, nolines_after_pcal.value, label='after')
plt.xlim(9, 450)
plt.ylim(1e-20, 1e-15)
plt.ylabel('PCAL m')
plt.xlabel('Frequency [Hz]')
plt.legend()
# %%

def cum_range_diff(before_psd, after_psd, seg):
    """Return the cumulative-over-frequency range difference (after - before).

    Parameters
    ----------
    before_psd, after_psd : FrequencySeries
        PSDs handed to ``sensemon_range_psd``. They are presumably on the same
        frequency grid, since the results are subtracted element-wise.
    seg : float
        Presumably the FFT length in seconds (callers pass 8, matching
        ``fftlength=8``). ``1/seg`` is used as the bin width when the range
        integrand is integrated by cumulative sum.

    Returns
    -------
    FrequencySeries
        ``(R_after(f)**2 - R_before(f)**2) / (R_after + R_before)``, where
        ``R(f)**2`` is the cumulative range-squared up to frequency ``f`` and
        ``R`` is the total range (square root of the largest cumulative
        value). Because a**2 - b**2 = (a - b)(a + b), for a non-negative
        integrand the last value equals ``R_after - R_before``.
    """
    df = 1 / seg
    # Compute each integrand once and reuse it. Both sides use the same bin
    # width; previously the "after" side was hard-coded to 1/8.
    before_cum_range_sq = sensemon_range_psd(before_psd).cumsum() * df
    after_cum_range_sq = sensemon_range_psd(after_psd).cumsum() * df

    before_range = np.sqrt(np.max(before_cum_range_sq.value))
    after_range = np.sqrt(np.max(after_cum_range_sq.value))

    return (after_cum_range_sq - before_cum_range_sq) / (after_range + before_range)

def total_range(psd, seg):
    """Return the total sensemon range for ``psd``.

    The ``sensemon_range_psd`` integrand is integrated by cumulative sum with
    bin width ``1/seg``. ``seg`` is presumably the FFT length in seconds;
    callers pass 8, matching ``fftlength=8``. The result is the square root
    of the largest cumulative value.
    """
    cum_range_sq = sensemon_range_psd(psd).cumsum() * (1 / seg)
    # np.max reduces in C and propagates NaN. Builtin max would loop in
    # Python, and its NaN handling depends on element order.
    return np.sqrt(np.max(cum_range_sq.value))
# %%
# The nolines ASDs were multiplied by 4000 m (y-axis label "GDS Strain *
# 4000 m"), but sensemon_range_psd expects a strain PSD. Divide that factor
# back out so the ranges come out in physical units rather than 4000x too
# small. The normalized percentage plot is unaffected: the range difference
# and the total range both scale by the same 1/4000.
arm_length = 4000  # [m]; the factor applied to the GDS strain ASDs
fft_length = 8     # [s]; must match fftlength= used for the ASDs

strain_psd_before = (nolines_before / arm_length) ** 2
strain_psd_after = (nolines_after / arm_length) ** 2
strain_psd_before_pcal = (nolines_before_pcal / arm_length) ** 2
strain_psd_after_pcal = (nolines_after_pcal / arm_length) ** 2

r_diff_uncorrected = cum_range_diff(strain_psd_before, strain_psd_after, fft_length)
uncorrected_range = total_range(strain_psd_before, fft_length)

r_diff_corrected = cum_range_diff(strain_psd_before_pcal, strain_psd_after_pcal, fft_length)
corrected_range = total_range(strain_psd_before_pcal, fft_length)
# %%
# Cumulative range difference (after - before) as a percentage of the
# "before" total range, without and with the PCAL correction applied.
plot_freqs = nolines_before.frequencies.value
range_curves = (
    (r_diff_uncorrected, uncorrected_range, 'calib-uncorrected range difference'),
    (r_diff_corrected, corrected_range, 'calib-corrected range difference'),
)

plt.figure()
for range_delta, reference_range, curve_label in range_curves:
    plt.semilogx(plot_freqs, 100*range_delta/reference_range, label=curve_label)
plt.legend()
plt.xlabel('Frequency [Hz]')
plt.ylabel('% normalized range difference')
plt.title('After calib update - Before calib update')
plt.savefig('range_difference_nolines.jpg')


# %%
