# Python port of J. Driggers' MATLAB ringdown script (aLOG 50550)
#Contains functions for analysing ringdowns in OMC-PI_DCPD and MODE#_RMSMON channels
#Author G.Wallace 18th Sept 2019

import numpy as np
import math, nds2
import scipy.signal as sp
import scipy.optimize as so
from tqdm import tqdm, trange
from scipy import array
from gwpy.time import to_gps
import matplotlib.pyplot as plt

def ringdown(UTC_excite, UTC_ringdown_start, UTC_ringdown_stop, BP_Low, BP_High, Graph=False):
    """Fit an exponential decay to a bandpassed OMC-PI_DCPD ringdown.

    Fetches H1:OMC-PI_DCPD_64KHZ_AHF_DQ from the NDS2 server in 500 s
    chunks, bandpasses it around the excited mode, takes the Hilbert
    envelope, and fits ``a * exp(-b * t)`` to the envelope.

    Parameters
    ----------
    UTC_excite : str
        UTC time the excitation started (e.g. 'Jul 03 2019 16:38:00.00').
        Data fetching begins ``pretime`` (400 s) before this.
    UTC_ringdown_start : str
        UTC time the excitation stopped; this is t = 0 on the time axis.
    UTC_ringdown_stop : str
        UTC time the ringdown ended. The last fetched chunk starts before this
        and may extend up to one chunk past it.
    BP_Low, BP_High : float
        Bandpass edges in Hz.
    Graph : bool
        If true, plot the bandpassed data together with the fitted envelope.

    Returns
    -------
    (Excitation_Frequency, phi)
        Centre of the passband in Hz, and the loss angle
        ``phi = 1/Q = b / (pi * f)``. Here ``b`` is the fitted amplitude
        decay rate (amplitude decay time ``tau = 1/b``, ``Q = pi * f * tau``).

    Raises
    ------
    ValueError
        If the requested interval contains no data chunk to fetch.
    """
    server   = 'nds.ligo-wa.caltech.edu'
    channels = ['H1:OMC-PI_DCPD_64KHZ_AHF_DQ']
    srate    = 65536  # samples per second of the 64 kHz DCPD channel
    pretime  = 400    # start data taking a little beforehand, so can window / ignore it.
    duration = 500    # 500 sec okay, 1000 sec not okay at one time.  Can concat them after fetching.
    decimate = 128    # the envelope varies slowly, so plain subsampling is enough for the fit

    # Convert UTC to GPS seconds. Originally 'Jul 03 2019 16:38:00.00',
    # 'Jul 03 2019 16:41:15.00' and 'Jul 03 2019 17:12:33.00'.
    time_excite         = float(to_gps(UTC_excite))
    time_ringdown_start = float(to_gps(UTC_ringdown_start))
    time_ringdown_end   = float(to_gps(UTC_ringdown_stop))

    # Chunk start times, stepped by the chunk duration so the chunks tile contiguously.
    times_fetch = np.arange(time_excite - pretime, time_ringdown_end, duration)
    if len(times_fetch) == 0:
        raise ValueError('UTC_ringdown_stop must be later than UTC_excite minus the pretime')

    chunk_len = srate * duration
    total_len = chunk_len * len(times_fetch)
    DCPDdata  = np.zeros(total_len, dtype=float)
    times     = np.zeros(total_len, dtype=float)
    t_chunk   = np.arange(chunk_len) / srate

    # Set connection to nds
    conn = nds2.connection(server, 31200)
    conn.set_parameter('ALLOW_DATA_ON_TAPE', '1')

    # Fetch data in chunks and concatenate it into the preallocated arrays.
    for ii, start in enumerate(tqdm(times_fetch)):
        gps_start = int(start)
        print( f"Connecting to NDS server {server}, Fetching {len(channels)} channel, Start GPS {gps_start}, Duration {duration} sec" )
        buffers = conn.fetch(gps_start, gps_start + duration, channels)
        chunk = slice(ii * chunk_len, (ii + 1) * chunk_len)
        DCPDdata[chunk] = buffers[0].data
        times[chunk]    = t_chunk + gps_start - int(time_ringdown_start)
    del buffers

    # Bandpass data (8th-order Butterworth, second-order sections for stability).
    nyq    = 0.5 * srate
    sos    = sp.butter(8, [BP_Low / nyq, BP_High / nyq], analog=False, btype='band', output='sos')
    BPData = sp.sosfilt(sos, DCPDdata)
    del DCPDdata

    # The magnitude of the analytic signal is the amplitude envelope. Copy the
    # subsampled slice so the full-rate envelope can actually be freed.
    envelope             = np.abs(sp.hilbert(BPData))
    envelope_downsampled = envelope[::decimate].copy()
    times_downsampled    = times[::decimate]
    del envelope

    # Skip the first 20 s after the excitation stops and the last 100 s of data.
    t_last = times.max()
    keep   = (times_downsampled > 20) & (times_downsampled < t_last - 100)

    def fun(timepoints, a, b):
        return a * np.exp(-b * timepoints)

    # Initial guesses (from a fit of the log of the data). b is positive for a
    # decay because fun already carries the minus sign.
    popt, _ = so.curve_fit(fun, times_downsampled[keep], envelope_downsampled[keep], p0=[4.8, 0.0039])
    amplitude, decay_rate = popt

    Excitation_Frequency = BP_Low + 0.5 * (BP_High - BP_Low)
    # tau = 1/b (amplitude decay time), Q = pi * f * tau, phi = 1/Q.
    phi = decay_rate / (math.pi * Excitation_Frequency)

    # Plot final figure
    if Graph:
        fit_t   = times_downsampled[keep]
        fit_env = fun(fit_t, amplitude, decay_rate)
        plt.figure(4)
        plt.plot(times, BPData, label='Bandpassed DCPD')
        plt.plot(fit_t, fit_env, label='Fit: %.6f*exp(-%.6f*t)' % (amplitude, decay_rate))
        plt.plot(fit_t, -fit_env, label='-1*Fit: -%.6f*exp(-%.6f*t)' % (amplitude, decay_rate))

        plt.ylim(-150, 150)
        plt.xlim(-100, t_last)

        plt.xlabel('Time since excitation stopped [sec]')
        plt.ylabel('Bandpassed OMC DCPD [arb]')
        plt.title('OMC DCPD 64kHz channel, bandpassed %s Hz around %.1f Hz' % ((BP_High - BP_Low), Excitation_Frequency))

        plt.legend()
        plt.show()

    return Excitation_Frequency, phi

def ringdown_RMS(UTC_excite, UTC_ringdown_start, UTC_ringdown_stop, Excitation_Frequency, Excitation_Mode, Graph=False):
    """Fit an exponential decay to a SUS-PI MODE<N>_RMSMON ringdown.

    Fetches H1:SUS-PI_PROC_COMPUTE_MODE<Excitation_Mode>_RMSMON from the NDS2
    server in 500 s chunks and fits ``a * exp(-b * t)`` directly to it.

    Parameters
    ----------
    UTC_excite : str
        UTC time the excitation started (e.g. 'Jul 03 2019 16:38:00.00').
        Data fetching begins ``pretime`` (400 s) before this.
    UTC_ringdown_start : str
        UTC time the excitation stopped; this is t = 0 on the time axis.
    UTC_ringdown_stop : str
        UTC time the ringdown ended. The last fetched chunk starts before this
        and may extend up to one chunk past it.
    Excitation_Frequency : float
        Mode frequency in Hz, used to convert the decay rate to a loss angle.
    Excitation_Mode : int or str
        Mode number inserted into the channel name.
    Graph : bool
        If true, plot the RMS channel together with the fit.

    Returns
    -------
    (Excitation_Frequency, phi)
        The frequency passed in, and the loss angle ``phi = 1/Q = b / (pi * f)``.
        Here ``b`` is the fitted decay rate.
        NOTE(review): this treats the RMSMON output as proportional to mode
        amplitude (not power). Confirm against the RMS monitor definition.

    Raises
    ------
    ValueError
        If the requested interval contains no data chunk to fetch.
    """
    server   = 'nds.ligo-wa.caltech.edu'
    channels = ['H1:SUS-PI_PROC_COMPUTE_MODE'+ str(Excitation_Mode) +'_RMSMON']
    srate    = 16   # assumed channel sample rate; the slice assignment below fails if the fetched length differs
    pretime  = 400  # start data taking a little beforehand, so can window / ignore it.
    duration = 500  # 500 sec okay, 1000 sec not okay at one time.  Can concat them after fetching.

    # Convert UTC to GPS seconds. Originally 'Jul 03 2019 16:38:00.00',
    # 'Jul 03 2019 16:41:15.00' and 'Jul 03 2019 17:12:33.00'.
    time_excite         = float(to_gps(UTC_excite))
    time_ringdown_start = float(to_gps(UTC_ringdown_start))
    time_ringdown_end   = float(to_gps(UTC_ringdown_stop))

    # Chunk start times, stepped by the chunk duration so the chunks tile contiguously.
    times_fetch = np.arange(time_excite - pretime, time_ringdown_end, duration)
    if len(times_fetch) == 0:
        raise ValueError('UTC_ringdown_stop must be later than UTC_excite minus the pretime')

    chunk_len = srate * duration
    total_len = chunk_len * len(times_fetch)
    rms_data  = np.zeros(total_len, dtype=float)
    times     = np.zeros(total_len, dtype=float)
    t_chunk   = np.arange(chunk_len) / srate

    # Set connection to nds
    conn = nds2.connection(server, 31200)
    conn.set_parameter('ALLOW_DATA_ON_TAPE', '1')

    # Fetch data in chunks and concatenate it into the preallocated arrays.
    for ii, start in enumerate(tqdm(times_fetch)):
        gps_start = int(start)
        print( f"Connecting to NDS server {server}, Fetching {len(channels)} channel, Start GPS {gps_start}, Duration {duration} sec" )
        buffers = conn.fetch(gps_start, gps_start + duration, channels)
        chunk = slice(ii * chunk_len, (ii + 1) * chunk_len)
        rms_data[chunk] = buffers[0].data
        times[chunk]    = t_chunk + gps_start - int(time_ringdown_start)

    # Skip the first 20 s after the excitation stops and the last 100 s of data.
    t_last = times.max()
    keep   = (times > 20) & (times < t_last - 100)

    def fun(timepoints, a, b):
        return a * np.exp(-b * timepoints)

    # Initial guesses (from a fit of the log of the data). b is positive for a
    # decay because fun already carries the minus sign.
    popt, _ = so.curve_fit(fun, times[keep], rms_data[keep], p0=[4.8, 0.0039])
    amplitude, decay_rate = popt

    # tau = 1/b (decay time), Q = pi * f * tau, phi = 1/Q.
    phi = decay_rate / (math.pi * Excitation_Frequency)

    # Plot final figure
    if Graph:
        plt.figure(1)
        plt.plot(times, rms_data, label='RMS DCPD Channel')
        plt.plot(times[keep], fun(times[keep], amplitude, decay_rate),
                 label='Fit: %.6f*exp(-%.6f*t)' % (amplitude, decay_rate))

        plt.ylim(-10, 500)
        plt.xlim(-100, t_last)

        plt.xlabel('Time since excitation stopped [sec]')
        plt.ylabel('RMS Mode Channel [arb]')
        plt.title('RMS channel output, at frequency %.1f Hz' % Excitation_Frequency)

        plt.legend()
        plt.show()

    return Excitation_Frequency, phi
