import numpy as np
import nds2 as nds
import gpstime
import matplotlib.pyplot as plt

def myround(x, base=60):
    """Return ``x`` rounded to the nearest multiple of ``base``, as an int.

    Ties are resolved by Python's ``round`` applied to the quotient
    ``x / base`` (round-half-to-even), so ``myround(90) == 120`` and
    ``myround(30) == 0``.
    """
    quotient = float(x) / base
    return int(round(quotient) * base)


def _select_samples(lock_raw, ms_raw, clean, remove_ranges):
    """Build consistent (days, values) pairs for the lock-state and microseism series.

    Parameters
    ----------
    lock_raw, ms_raw : array_like
        Guardian-state and microseism samples, one per minute starting at
        the (rounded) request start time.
    clean : bool
        If true, drop guardian samples whose capped value lies in
        ``[10, 600)``, i.e. the states between the low values and the
        capped value 600.
    remove_ranges : iterable of (first, stop)
        Half-open minute-index ranges ``[first, stop)`` on the raw minute
        grid that are dropped from BOTH series.

    Returns
    -------
    lock_days, ms_days, lock_data, ms_data : numpy.ndarray
        Time since the start in days and the kept sample values.  Each
        ``*_days`` array always has the same length as its ``*_data`` array.
    """
    minutes_per_day = 60 * 24
    # Cap the guardian state at 600 so every state above it counts as 600.
    lock_all = np.minimum(np.asarray(lock_raw), 600)
    ms_all = np.asarray(ms_raw)

    def keep_outside_ranges(n_samples):
        # Boolean mask that is False inside every removed minute range.
        keep = np.ones(n_samples, dtype=bool)
        for first, stop in remove_ranges:
            keep[first:stop] = False
        return keep

    lock_keep = keep_outside_ranges(len(lock_all))
    ms_keep = keep_outside_ranges(len(ms_all))
    if clean:
        # Filter on the sample VALUE (guardian state), not on its position.
        lock_keep &= (lock_all < 10) | (lock_all >= 600)

    lock_days = np.arange(len(lock_all)) / minutes_per_day
    ms_days = np.arange(len(ms_all)) / minutes_per_day
    return lock_days[lock_keep], ms_days[ms_keep], lock_all[lock_keep], ms_all[ms_keep]


def get_data(start, end, clean):
    """Fetch H1 guardian lock state and microseism minute trends for a GPS span.

    Both channels are fetched from the ``h1nds1`` NDS server (port 8088)
    between ``myround(start)`` and ``myround(end)``, with the
    ``STATIC_HANDLER_ZERO`` gap handler.

    Parameters
    ----------
    start, end : int
        GPS start and end times in seconds.
    clean : bool
        If true, drop guardian samples in transitional states, i.e. values
        in ``[10, 600)`` after capping the state at 600.

    Returns
    -------
    lock_days, ms_days, lock_data, ms_data : numpy.ndarray
        Time in days since the rounded start, and the sample values, for the
        guardian state and the microseism BLRMS.  ``lock_days`` always has the
        same length as ``lock_data``, and ``ms_days`` the same length as
        ``ms_data``.  The lock pair can be shorter than the microseism pair
        when ``clean`` is true.  When ``start == start_O2`` (a module-level
        constant), minutes ``[32000, 50000)`` and ``[228000, 255000)`` are
        dropped from both pairs.
    """
    conn = nds.connection('h1nds1', 8088)
    try:
        conn.set_parameter('GAP_HANDLER', 'STATIC_HANDLER_ZERO')
        gps_start, gps_end = myround(start), myround(end)
        # guardian state, samples once per minute
        buff_lock = conn.fetch(gps_start, gps_end, ['H1:GRD-ISC_LOCK_STATE_N.mean,m-trend'])
        # microseism level, samples once per minute
        buff_ms = conn.fetch(gps_start, gps_end, ['H1:ISI-GND_STS_ITMY_Z_BLRMS_100M_300M.mean,m-trend'])
    finally:
        # Release the server connection even if a fetch fails.
        conn.close()

    # Large data gaps present in O2, given as raw minute-index ranges [first, stop).
    if start == start_O2:
        remove_ranges = ((32000, 50000), (228000, 255000))
    else:
        remove_ranges = ()

    return _select_samples(buff_lock[0].data, buff_ms[0].data, clean, remove_ranges)
    

def _window_index(days, N):
    """Return, for each sample time in days, the index of its N-minute window."""
    minutes = np.rint(np.asarray(days, dtype=float) * (60 * 24)).astype(np.int64)
    return minutes // int(N)


def _chunk_averages(lock_days, ms_days, lock_data, ms_data, N, n_chunks):
    """Average lock state and microseism over the first ``n_chunks`` N-minute windows.

    Samples are assigned to windows by their timestamps rather than their
    array positions. This keeps series of different lengths aligned, e.g. a
    cleaned lock-state series next to an uncleaned microseism series.
    Windows that contain no microseism samples (e.g. a removed data gap) are
    skipped. Both averages divide by the number of microseism samples in the
    window, which is N for a complete window. Samples after the last of the
    ``n_chunks`` windows are ignored.

    Returns
    -------
    avg_microseism, frac_locked : list of float
    """
    lock_bins = _window_index(lock_days, N)
    ms_bins = _window_index(ms_days, N)
    lock_vals = np.asarray(lock_data, dtype=float)
    ms_vals = np.asarray(ms_data, dtype=float)

    lock_in = lock_bins < n_chunks
    ms_in = ms_bins < n_chunks
    lock_sum = np.bincount(lock_bins[lock_in], weights=lock_vals[lock_in], minlength=n_chunks)
    ms_sum = np.bincount(ms_bins[ms_in], weights=ms_vals[ms_in], minlength=n_chunks)
    ms_count = np.bincount(ms_bins[ms_in], minlength=n_chunks)

    has_data = ms_count > 0
    minutes_present = ms_count[has_data]
    avg_microseism = (ms_sum[has_data] / minutes_present).tolist()
    # 600 is the capped guardian state, so a window locked throughout gives 1.0.
    frac_locked = (lock_sum[has_data] / (600 * minutes_present)).tolist()
    return avg_microseism, frac_locked


def get_fractional_data(start, end, clean, N):
    """Per-window mean microseism and fraction of time locked for a GPS span.

    Parameters
    ----------
    start, end : int
        GPS start and end times in seconds.
    clean : bool
        Passed through to ``get_data``.
    N : int
        Window length in minutes, e.g. 60*24 (daily), 60*24*7 (weekly) or
        60*24*30 (30-day). Any positive length is accepted.

    Returns
    -------
    avg_microseism, frac_locked : list of float
        One entry per complete window that contains data. ``frac_locked`` is
        the summed guardian state divided by 600 times the number of minutes
        with data in the window.

    Raises
    ------
    ValueError
        If ``N`` is not positive.
    """
    if N <= 0:
        raise ValueError("window length N must be positive, got {!r}".format(N))
    lock_days, ms_days, lock_data, ms_data = get_data(start, end, clean)
    # Number of complete N-minute windows in the requested span.
    n_chunks = max(int((end - start) / (60 * N)), 0)
    return _chunk_averages(lock_days, ms_days, lock_data, ms_data, N, n_chunks)


# GPS start/end times (seconds) of each observing run analysed below.
start_O1 = 1126623617
end_O1 = 1136649617

start_O2 = 1164556817
end_O2 = 1187733618

start_O3a = 1238166018
end_O3a = 1253977218

start_O3b = 1256655618
end_O3b = 1269363618

start_O4a = 1368975618
end_O4a = 1389456018  # alternative end noted originally: gpstime.gpsnow()


# make figures

# Figure 1: O1 guardian lock state (left axis, red) and microseism level
# (right axis, blue) against time in days.
lock_days_O1, ms_days_O1, lock_data_O1, ms_data_O1 = get_data(start_O1, end_O1, clean=True)


def _draw_trace(target, ylabel, colour, xs, ys, ymax):
    """Plot one trace on ``target`` with a matching label/tick colour and a 0..ymax y-range."""
    target.set_ylabel(ylabel, color=colour)
    target.plot(xs, ys, color=colour)
    target.tick_params(axis='y', labelcolor=colour)
    target.set_ylim(0, ymax)


fig1, axs1 = plt.subplots(figsize=[15, 5])
axs1.set_xlabel('Time [d]')
_draw_trace(axs1, 'H1 lock state', 'tab:red', lock_days_O1, lock_data_O1, 650)

axs2 = axs1.twinx()
_draw_trace(axs2, 'Microseism level [nm/s]', 'tab:blue', ms_days_O1, ms_data_O1, 2200)

axs1.grid()

fig1.tight_layout()
fig1.savefig("figure1.png", bbox_inches="tight")


# Figure 2: daily mean microseism against the fraction of each day H1 was
# locked, one scatter series per observing run.
fig2 = plt.figure()
ax2 = fig2.gca()

day_minutes = 60 * 24

xO1, yO1 = get_fractional_data(start_O1, end_O1, clean=True, N=day_minutes)
xO2, yO2 = get_fractional_data(start_O2, end_O2, clean=True, N=day_minutes)
xO3a, yO3a = get_fractional_data(start_O3a, end_O3a, clean=True, N=day_minutes)
xO3b, yO3b = get_fractional_data(start_O3b, end_O3b, clean=True, N=day_minutes)
xO4a, yO4a = get_fractional_data(start_O4a, end_O4a, clean=True, N=day_minutes)

for run_x, run_y, run_name, run_colour in (
    (xO1, yO1, "O1", "red"),
    (xO2, yO2, "O2", "orange"),
    (xO3a, yO3a, "O3a", "skyblue"),
    (xO3b, yO3b, "O3b", "blue"),
    (xO4a, yO4a, "O4a", "indigo"),
):
    ax2.scatter(run_x, run_y, label=run_name, color=run_colour, alpha=0.5)

ax2.set_xlabel("Average microseism level per day [nm/s]")
ax2.set_ylabel("H1 fraction of time locked per day")
ax2.legend(loc=3)

fig2.savefig("figure2.png", bbox_inches="tight")


# Figure 3: weekly mean microseism against weekly lock fraction; each panel
# overlays one earlier run on O4a, with a log-scaled microseism axis.
fig3, axs3 = plt.subplots(2, 2, figsize=[10, 8], sharex=True, sharey=True)

week_minutes = 60 * 24 * 7

xO1, yO1 = get_fractional_data(start_O1, end_O1, clean=True, N=week_minutes)
xO2, yO2 = get_fractional_data(start_O2, end_O2, clean=True, N=week_minutes)
xO3a, yO3a = get_fractional_data(start_O3a, end_O3a, clean=True, N=week_minutes)
xO3b, yO3b = get_fractional_data(start_O3b, end_O3b, clean=True, N=week_minutes)
xO4a, yO4a = get_fractional_data(start_O4a, end_O4a, clean=True, N=week_minutes)

weekly_panels = {
    (0, 0): (xO1, yO1, "O1", "red"),
    (0, 1): (xO2, yO2, "O2", "orange"),
    (1, 0): (xO3a, yO3a, "O3a", "mediumseagreen"),
    (1, 1): (xO3b, yO3b, "O3b", "dodgerblue"),
}
for panel_pos, (panel_x, panel_y, panel_name, panel_colour) in weekly_panels.items():
    axs3[panel_pos].scatter(panel_x, panel_y, label=panel_name, color=panel_colour, alpha=0.5)
    axs3[panel_pos].scatter(xO4a, yO4a, label="O4a", color="indigo", alpha=0.5)

# Axis labels only on the bottom row (x) and the left column (y).
for panel_col in (0, 1):
    axs3[1, panel_col].set_xlabel("Average microseism level per week [nm/s]")
for panel_row in (0, 1):
    axs3[panel_row, 0].set_ylabel("H1 fraction of time locked per week")

for panel_ax in axs3.flat:
    panel_ax.set_xscale("log")
for panel_ax in axs3.flat:
    panel_ax.legend()

fig3.tight_layout()
fig3.savefig("figure3.png", bbox_inches="tight")