import warnings
warnings.filterwarnings('ignore')
from numpy import *
import numpy
from pylab import *
import os
from IPython import display
import time
from scipy.signal import *
import nds2
import cPickle as pickle
import sys

# Global plot style for every figure in this script: 16 pt fonts and a
# default figure size of 10x6 inches.
matplotlib.rcParams.update({
    'font.size': 16,
    'figure.figsize': (10, 6),
})

def save_fig(fig_id, tight_layout=True):
    """Write the current matplotlib figure to '<fig_id>.png' at 100 dpi.

    fig_id       -- file name without extension (a string)
    tight_layout -- when true, call plt.tight_layout() before saving
    """
    out_path = fig_id + '.png'
    if tight_layout:
        plt.tight_layout()
    plt.savefig(out_path, dpi=100, format='png')

## parse parameters
if len(sys.argv) < 3:
    print('Usage: violin_psd.py start_gps end_gps [resolution]')
    # a usage error is a failure: report it with a non-zero exit status
    sys.exit(1)

gps0 = int(sys.argv[1])
gps1 = int(sys.argv[2])

# optional third argument is the frequency resolution in Hz; dt is the
# corresponding FFT length in seconds (default 5000 s = 0.2 mHz)
if len(sys.argv) == 4:
    dt = 1.0/float(sys.argv[3])
else:
    dt = 5000.


def _decimation_stages(q, max_stage=12):
    """Split an integer decimation factor q into a list of stage factors.

    Each stage is at most max_stage whenever q can be factored that way; a
    leftover factor with no divisor in [2, max_stage] is kept as one stage.
    The product of the returned factors equals q (empty list for q <= 1).
    """
    stages = []
    while q > max_stage:
        for d in range(max_stage, 1, -1):
            if q % d == 0:
                stages.append(d)
                q //= d
                break
        else:
            # no small divisor: do the remaining factor in one step
            break
    if q > 1:
        stages.append(q)
    return stages


## read data
print('Reading data from %d to %d...' % (gps0, gps1))
channels = ['H1:CAL-DELTAL_EXTERNAL_DQ']
conn = nds2.connection('nds.ligo-wa.caltech.edu')
data = conn.fetch(gps0, gps1, channels)
h = data[0].data
# pre-decimate; the factor assumes the channel is sampled at 16384 Hz.
# Integer division: decimate() needs an int factor (16384/4096 is a float on Python 3)
from scipy.signal import decimate
h = decimate(h, 16384 // 4096)
fs = 4096

## Heterodyne
print('Heterodyning...')
f0 = 507.0   # heterodyne (center) frequency [Hz]
bw = 32.     # after downsampling to 2*bw Hz the complex data covers f0 +/- bw
# build LO
t  = arange(h.shape[0], dtype=float)/fs
lo = exp(-2j*pi*f0*t)
# heterodyne: shifts f0 to 0 Hz
hlo = h * lo
# downsample; scipy recommends splitting IIR decimation factors above 13
# into several decimate() calls
q = int(fs/(2*bw))
for stage in _decimation_stages(q):
    hlo = decimate(hlo, stage)
# actual output rate (equal to 2*bw when fs/(2*bw) is an integer)
fso = fs / float(q)

## Compute PSD
print('Computing PSD (resolution %.2fmHz, length %ds)...' % (1000./dt, dt))
# segment length in samples; named to avoid shadowing the `np` numpy alias
nperseg = int(fso * dt)
# complex input -> welch returns a two-sided spectrum, hence the fftshift below
fr,sp = welch(hlo, fs=fso, window='hann', nperseg=nperseg, noverlap=nperseg//2, nfft=nperseg)
fr = fftshift(fr) + f0
sp = fftshift(sp)

## Plot and save
print('Plotting and saving...')
figure()
semilogy(fr,sp)
grid()
xlim([500, 517])
xlabel('Frequency [Hz]')
ylabel('PSD uncalibrated [a.u.]')
save_fig('violin_modes_%.2fmHz_resolution_%d_%d' % (1000./dt, gps0, gps1))

X = c_[fr,sp]
savetxt('violin_modes_%.2fmHz_resolution_%d_%d.txt' % (1000./dt, gps0, gps1), X)

## Code to find peaks (adapted from C.Ri.Me. lab code https://git.ligo.org/gabriele-vajente/pycrime)

## Find peaks
def smooth_bg_log(fr, sp, nbands, threshold):
    """
    Compute a smooth approximation of the spectral background noise by
    removing lines. The algorithm divides the whole frequency range into
    bands and computes the mean and std of the log-PSD in each band. Bins
    whose log-PSD deviates from the band mean by threshold*std or more are
    removed and the mean is recomputed from the remaining bins. The band
    means are then interpolated linearly (in linear PSD units) to all
    frequency bins.

    If len(sp) is not a multiple of nbands, the last len(sp) % nbands bins
    are not part of any band; they still get an interpolated background.
    A band in which no bin survives the cut (e.g. a perfectly flat band,
    whose std is 0) uses its plain mean.

    Inputs:
       fr        = frequency vector, assumed sorted in increasing order
       sp        = PSD vector (strictly positive; the log is taken)
       nbands    = number of bands into which the whole frequency range is
                   divided (clamped to len(sp))
       threshold = exclude bins deviating by more than threshold*std
    Outputs:
       bg        = smoothed background, same length as sp
    Raises:
       ValueError if sp is empty or nbands < 1
    """
    fr = numpy.asarray(fr, dtype=float)
    sp = numpy.asarray(sp, dtype=float)
    n = len(sp)
    if n == 0:
        raise ValueError('smooth_bg_log: empty spectrum')
    if nbands < 1:
        raise ValueError('smooth_bg_log: nbands must be >= 1')
    # more bands than bins would produce empty bands (NaN means)
    nbands = min(int(nbands), n)
    npoints = n // nbands

    # one extra node at each end so the interpolation spans all of fr
    band_avg = numpy.zeros(nbands + 2)
    fr_avg = numpy.zeros(nbands + 2)
    for i in range(nbands):
        lo, hi = npoints * i, npoints * (i + 1)
        spec = numpy.log(sp[lo:hi])
        fr_avg[i + 1] = numpy.mean(fr[lo:hi])
        avg = numpy.mean(spec)
        keep = numpy.abs(spec - avg) < numpy.std(spec) * threshold
        # with std == 0 the strict cut keeps nothing; fall back to the mean
        band_avg[i + 1] = numpy.mean(spec[keep]) if keep.any() else avg
    band_avg[0] = band_avg[1]
    band_avg[-1] = band_avg[-2]
    # anchor the end nodes on the actual frequency range (a fixed 0 would
    # break monotonicity of the grid for spectra extending below 0 Hz)
    fr_avg[0] = fr[0]
    fr_avg[-1] = fr[-1]
    return numpy.interp(fr, fr_avg, numpy.exp(band_avg))

def find_peaks(fr, swx, sbgx, minsnr=40, minfr=500):
    """
    Find spectral lines standing out above a smooth background.

    The PSD is whitened by dividing it by the background estimate; every bin
    whose whitened value exceeds minsnr and whose frequency exceeds minfr is
    selected. Runs of adjacent selected bins form one cluster (one line).
    For each cluster the line frequency is the mean of the bin frequencies
    weighted by the squared whitened PSD, and the SNR is the square root of
    the sum of the squared whitened PSD values.

    Inputs:
       fr     = frequency vector
       swx    = PSD vector
       sbgx   = background PSD (e.g. from smooth_bg_log), same length as swx
       minsnr = minimum whitened PSD value for a bin to belong to a line
       minfr  = only bins with frequency strictly above minfr are considered
    Outputs:
       freqsx = list of line frequencies (one per cluster)
       snrx   = list of line SNRs (one per cluster)
    """
    fr = numpy.asarray(fr)
    # whiten with background
    white = numpy.asarray(swx) / numpy.asarray(sbgx)

    idx = numpy.where(numpy.logical_and(white > minsnr, fr > minfr))[0]
    if len(idx) == 0:
        return [], []

    # split the selected indices wherever consecutive entries are not
    # adjacent bins; a single contiguous run stays one whole cluster
    clusters = numpy.split(idx, numpy.where(numpy.diff(idx) > 1)[0] + 1)

    # lists (not map iterators) so callers can traverse them repeatedly
    freqsx = []
    snrx = []
    for cl in clusters:
        w2 = white[cl] ** 2
        total = w2.sum()
        freqsx.append((fr[cl] * w2).sum() / total)
        snrx.append(numpy.sqrt(total))
    return freqsx, snrx

## Find peaks
bg = smooth_bg_log(fr, sp, 1000, 2)
freq, snr = find_peaks(fr, sp, bg, minsnr=50, minfr=500)
# materialize: the results are traversed several times below, and map()
# returns one-shot iterators on Python 3
freq = list(freq)
snr = list(snr)
# PSD bin closest to each detected line, used to mark the lines on the plot
idx = [argmin(abs(f - fr)) for f in freq]

figure()
semilogy(fr,sp)
semilogy(fr[idx], sp[idx], 'rx')
grid()
xlim([500, 517])
xlabel('Frequency [Hz]')
ylabel('PSD uncalibrated [a.u.]')
save_fig('violin_modes_%.2fmHz_resolution_peaks_%d_%d' % (1000./dt, gps0, gps1))

## Print and save peaks
# format the table once and reuse it for both the console and the file
header = 'Frequency [Hz]\t\tSNR'
rows = ['%.4f\t\t%.0f' % (f, s) for f, s in zip(freq, snr)]

print('\n\n' + header)
for row in rows:
    print(row)

with open('violin_modes_%.2fmHz_resolution_peaks_%d_%d.txt' % (1000./dt, gps0, gps1), 'w') as outfile:
    outfile.write(header + '\n')
    for row in rows:
        outfile.write(row + '\n')

