import corner
import emcee
import matplotlib.pyplot as plt
import numpy as np
import scipy.optimize as opt
import sys
import tqdm

from gwpy.timeseries import TimeSeriesDict
from uncertainties import ufloat as uf
from uncertainties import unumpy as unp

start = '2016-09-14 23:31:00'
stop = '2016-09-15 00:16:00'
# RMS-monitor channels for PI modes 15, 31 and 32.
channels = ['H1:SUS-PI_PROC_COMPUTE_MODE{}_RMSMON'.format(ii)
            for ii in [15, 31, 32]]

dd = 6.2e-6  # m; coating thickness
# Per-mode coefficients for the loss model in lnlike(); order matches
# names_list / ff_arr below.
eratio_arr = np.array([13.3, 13.9, 3.3])

# Number of consecutive raw samples averaged into one time bin.
nbin = 64


def _bin_stats(values, nbin):
    """Return (means, stds) of ``values`` over consecutive ``nbin``-sample bins.

    Trailing samples that do not fill a whole bin are dropped.
    """
    values = np.asarray(values)
    nrows = len(values) // nbin
    binned = values[:nrows*nbin].reshape(nrows, nbin)
    return np.mean(binned, axis=1), np.std(binned, axis=1)


print('Fetching data...')
data_ts_dict = TimeSeriesDict.fetch(channels, start, stop)
# channels[0] (MODE15) -> drumhead 2, [1] (MODE31) -> butterfly,
# [2] (MODE32) -> drumhead 1.
dh2_ts = data_ts_dict[channels[0]]
bf_ts = data_ts_dict[channels[1]]
dh1_ts = data_ts_dict[channels[2]]

# Sample times relative to the first butterfly sample. Subtract out of place:
# ``times.value`` may be a view of the series' own time index, which an
# in-place ``-=`` would overwrite.
tt_raw = bf_ts.times.value
tt_raw = tt_raw - tt_raw[0]

# Bin-averaged time axis and per-bin mean/std of each mode's amplitude.
# Every channel is binned by sample index against the butterfly time axis.
# NOTE(review): this assumes all three channels share the same sample rate
# and start time -- confirm.
tt, _ = _bin_stats(tt_raw, nbin)
bf_means, bf_stds = _bin_stats(bf_ts.value, nbin)
dh1_means, dh1_stds = _bin_stats(dh1_ts.value, nbin)
dh2_means, dh2_stds = _bin_stats(dh2_ts.value, nbin)

means_list = [bf_means, dh1_means, dh2_means]
stds_list = [bf_stds, dh1_stds, dh2_stds]
names_list = ['Butterfly', 'Drumhead 1', 'Drumhead 2 vert.']
ff_arr = np.array([6054., 8158., 9830.])  # Hz; mode frequencies

def curve(tt, amp, tau):
    """Exponential ringdown model evaluated on the time array ``tt``.

    Returns ``amp * exp(-(tt - tt[0]) / tau)``: ``amp`` is the amplitude at
    the first sample and ``tau`` is the e-folding time, in the units of ``tt``.
    """
    elapsed = tt - tt[0]
    return amp * np.exp(-elapsed / tau)

def curve_unc(tt, amp, tau):
    """Exponential ringdown ``amp * exp(-(tt - tt[0]) / tau)`` using ``unp.exp``.

    Because the exponential comes from ``uncertainties.unumpy``, ``amp`` and
    ``tau`` may be ufloats, and the returned values then carry propagated
    uncertainties.
    """
    elapsed = tt - tt[0]
    return amp * unp.exp(-elapsed / tau)

def chisq(theta, tt, means, stds):
    """Return -chi^2/2 of the exponential ringdown model against binned data.

    Despite the name, this is a Gaussian log-likelihood (up to a constant):
    ``-sum((model - means)**2 / (2*stds**2))``, where
    ``model = amp*exp(-(tt - tt[0])/tau)``.

    Args:
        theta: sequence whose first two entries are (amp, tau).
        tt: sample times.
        means: measured amplitudes at ``tt``.
        stds: 1-sigma uncertainties on ``means``.

    Returns:
        -chi^2/2, or ``-np.inf`` when amp < 0, tau <= 0,
        amp > 2*means[0], or tau > 10*(tt[-1] - tt[0]).
    """
    amp = theta[0]
    tau = theta[1]
    # tau == 0 would divide by zero in the model, so reject it with negatives.
    if amp < 0 or tau <= 0:
        return -np.inf
    if abs(amp) > 2*means[0] or abs(tau) > 10*(tt[-1]-tt[0]):
        return -np.inf
    model = amp*np.exp(-(tt-tt[0])/tau)
    # Residuals are model minus data.
    return -np.sum((model-means)**2/(2*stds**2))

# Fit an exponential ringdown to each mode, record amplitude/decay time as
# ufloats, and save one ringdown plot per mode.
fit_params_list = []
amp_list = []
tau_list = []
redchisq_list = []
for name, ff, means, stds in zip(names_list, ff_arr, means_list, stds_list):
    # First-order estimate of tau from the fractional drop across the span.
    tau_guess = (tt[-1]-tt[0])/(1-means[-1]/means[0])
    fit_params = opt.curve_fit(curve, tt, means, p0=(means[0], tau_guess),
            sigma=stds, absolute_sigma=True)
    fit_params_list.append(fit_params)
    popt, pcov = fit_params
    amp_fit, tau_fit = popt[0], popt[1]
    amp_unc = np.sqrt(pcov[0, 0])
    tau_unc = np.sqrt(pcov[1, 1])
    # Reduced chi-square of the best fit: two fitted parameters -> N-2 dof.
    # Computed directly from the residuals so it does not depend on the
    # prior bounds encoded in chisq().
    resid = (curve(tt, amp_fit, tau_fit) - means)/stds
    redchisq = np.sum(resid**2)/(len(tt) - 2)
    redchisq_list.append(redchisq)
    # Rescale the absolute-sigma uncertainties by sqrt(reduced chi-square).
    amp_unc *= np.sqrt(redchisq)
    tau_unc *= np.sqrt(redchisq)
    amp_list.append(uf(amp_fit, amp_unc))
    tau_list.append(uf(tau_fit, tau_unc))

    hh, ax = plt.subplots()
    ax.errorbar(tt/60, means, stds, fmt='o', ms=0, c=(0, 0, 0.5, 0.5))
    ax.plot(tt/60, curve(tt, amp_fit, tau_fit), c=(0.5, 0, 0, 0.5))
    ax.set_xlabel('Time [min.]')
    ax.set_ylabel('Amplitude [arb.]')
    ax.set_title('{} mode, {:.0f} Hz'.format(name, ff))
    hh.savefig('ringdown_{}.pdf'.format(name.replace(' ', '_')))
    # Release the figure; otherwise one stays open per mode.
    plt.close(hh)

# Quality factors Q = pi*f*tau from the fitted ringdown times (ufloats, so
# uncertainties propagate), then loss angles phi = 1/Q.
tau_arr = np.asarray(tau_list)
Q_arr = (np.pi*ff_arr)*tau_arr
phi_arr = 1./Q_arr

print('Q values: ', Q_arr)

def lnlike(theta, phi_vals, phi_uncs, eratios):
    """Log-posterior (up to a constant) for substrate and coating loss angles.

    For each mode the model loss is ``phi_sub + eratios*dd*phi_coat``. It is
    compared with the measured loss angles under independent Gaussian errors,
    and a ``-log(phi_sub + phi_coat)`` prior term is added.

    Args:
        theta: (phi_sub, phi_coat).
        phi_vals: measured loss angles, one per mode.
        phi_uncs: 1-sigma uncertainties on ``phi_vals``.
        eratios: per-mode coefficient multiplying ``dd*phi_coat``
            (presumably a coating-to-substrate energy ratio -- confirm).

    Returns:
        The log-posterior value, or ``-np.inf`` if either loss angle is
        negative or both are zero.
    """
    phi_sub, phi_coat = theta
    # Reject negative loss angles. At phi_sub == phi_coat == 0 the
    # -log(phi_sub + phi_coat) term would be +inf, so exclude that point too.
    if phi_sub < 0 or phi_coat < 0 or phi_sub + phi_coat <= 0:
        return -np.inf
    phi_eqs = phi_sub + eratios*dd*phi_coat
    return (-np.sum((phi_eqs - phi_vals)**2/(2*phi_uncs**2))
            - np.log(phi_sub + phi_coat))

nwalkers, ndim, nthreads = 50, 2, 1
samp = emcee.EnsembleSampler(nwalkers, ndim, lnlike,
        args=(unp.nominal_values(phi_arr), unp.std_devs(phi_arr), eratio_arr),
        threads=nthreads)
# Initial walker positions, shape (nwalkers, ndim): a 10% Gaussian scatter
# around phi_sub = 1e-8 and phi_coat = 1e-4.
pos = np.column_stack([
    1e-8*np.random.normal(1, 0.1, nwalkers),
    1e-4*np.random.normal(1, 0.1, nwalkers),
    ])
burnin = 200
iters = 200
nsteps = burnin + iters
# Run ONE chain of nsteps steps from pos. Iterating the sampler's generator
# advances one step per loop pass, so tqdm shows real progress.
for _ in tqdm.tqdm(samp.sample(pos, iterations=nsteps), total=nsteps):
    pass

# chain is (nwalkers, nsteps, ndim): drop burn-in steps and flatten walkers.
# Scale phi_sub by 1e8 and phi_coat by 1e4 for plotting. Multiplying (instead
# of scaling in place) makes a new array, so the sampler's stored chain is
# never modified.
samples = samp.chain[:, burnin:, :].reshape((-1, ndim))*np.array([1e8, 1e4])

print('Plotting...')
# Axis labels reflect the 1e8 / 1e4 scaling applied to ``samples``.
post_labels = [r'$\phi_\mathrm{s}\times10^8$', r'$\phi_\mathrm{c}\times10^4$']
hpost = corner.corner(
    samples,
    labels=post_labels,
    smooth=2,
    range=[(0, 4), (0, 6)],
    # Median plus roughly the +/-1 sigma quantiles.
    quantiles=[0.17, 0.5, 0.83],
    verbose=True,
)
hpost.savefig('loss_posterior.pdf')
