#!/usr/bin/python

# Brute force coherence (Gabriele Vajente, July 4th 2014)
# 
# Command line arguments (with default values)
#
# --ifo=L1                    interferometer prefix
# --channel=OAF-CAL_DARM_DQ   name of the main channel
# --gpsb=1090221600           starting GPS time
# --length=600                amount of data to use (in seconds)
# --outfs=8192                sampling frequency of the output results (coherence will be computed up to outfs/2 if possible)
# --minfs=32                  skip all channels with sampling frequency smaller than this
# --naver=300                 number of averages used to compute the coherence
# --dir=bruco_1090221600      output directory
# --top=100                   for each frequency, save at most this number of top-coherence channels to cohtab.txt and idxtab.txt
# --webtop=20                 show this number of coherence channels per frequency in the web page summary
#
# Example:
# ./bruco.py --channel=OAF-CAL_DARM_DQ --gpsb=1106463004 --length=512 --outfs=8192 --naver=100 --dir=/home/gabriele.vajente/public_html/bruco_1106463004b --top=100 --webtop=20 --minfs=32
# ./bruco.py --channel=LSC-DARM_IN1_DQ --gpsb=1107680416 --length=600 --outfs=4096 --naver=300 --dir=/home/gabriele.vajente/public_html/bruco_1107680416 --top=100 --webtop=20 --minfs=32 --ifo=H1
# ./bruco.py --channel=LSC-DARM_IN1_DQ --gpsb=1107760396 --length=600 --outfs=4096 --naver=300 --dir=/home/gabriele.vajente/public_html/bruco_1107760396 --top=100 --webtop=20 --minfs=32 --ifo=H1
#
# CHANGELOG:
# 
# 2015-01-29 added linear detrending in PSD and coherence to improve low frequency bins
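#
# Processing outline:
#   1. build a frame-file cache covering [gpsb, gpsb+length] and read the target channel
#   2. extract the list of auxiliary channels from the first frame file, dropping
#      slow channels, excluded channels and site-wide (L0/H0) channels
#   3. for each auxiliary channel: resample to outfs, compute the coherence with
#      the target channel and update the per-frequency ranking tables
#   4. save one coherence/projection plot per channel and an HTML summary page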

import nds2
import numpy
import os
import matplotlib
matplotlib.use("Agg")
from optparse import OptionParser
from pylab import *
import time
from functions import *
import markup
import fnmatch
import scipy.stats
import sys
import subprocess
from glue import lal
from pylal import frutils, Fr
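# note: nds2, glue.lal, pylal and markup are external dependencies; decimate,
# gps2str, newline_name and cohe_color are expected to be provided by the
# local 'functions' module imported above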

# some auxiliary files
style = 'general.css'
# this contains the list of channels to exclude
exc = 'bruco_excluded_channels.txt'
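# (one pattern per line, channel name without the IFO prefix; unix-shell-style
#  wildcards are allowed, see the fnmatch-based filtering below)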

start_time = time.time()  

# some timing
t_readdata = 0
t_decimate = 0
t_coherence = 0
t_plot = 0
t_saveplot = 0

# command line options
parser = OptionParser()
parser.add_option("-c", "--channel", dest="channel",
                  default='OAF-CAL_DARM_DQ',
                  help="target channel", metavar="Channel")
parser.add_option("-i", "--ifo", dest="ifo",
                  default="L1",
                  help="interferometer", metavar="IFO")
parser.add_option("-g", "--gpsb", dest="gpsb",
                  default='1090221600',
                  help="start GPS time (-1 means now)", metavar="GpsTime")
parser.add_option("-l", "--lenght", dest="dt",
                  default='600',
                  help="duration in seconds", metavar="Duration")
parser.add_option("-o", "--outfs", dest="outfs",
                  default='8192',
                  help="sampling frequency", metavar="OutFs")
parser.add_option("-n", "--naver", dest="nav",
                  default='300',
                  help="number of averages", metavar="NumAver")
parser.add_option("-d", "--dir", dest="dir",
                  default='bruco_1090221600',
                  help="output directory", metavar="DestDir")
parser.add_option("-t", "--top", dest="ntop",
                  default='100',
                  help="number of top coherences saved in the datafile", metavar="NumTop")
parser.add_option("-w", "--webtop", dest="wtop",
                  default='20',
                  help="number of top coherences written to the web page", metavar="NumTop")
parser.add_option("-m", "--minfs", dest="minfs",
                  default='32',
                  help="minimum sampling frequency of aux channels", metavar="MinFS")
(opt,args) = parser.parse_args()

gpsb = int(opt.gpsb)
gpse = gpsb + int(opt.dt)
dt = int(opt.dt)
outfs = int(opt.outfs)
nav = int(opt.nav)
ntop = int(opt.ntop)
wtop = int(opt.wtop)
minfs = int(opt.minfs)

print "Analyzing from gps %d to %d.\n" % (gpsb, gpse)

# determine which are the useful frame files and create the cache
if opt.ifo == 'L1':
    file_pref = '/archive/frames/A6/raw/' + opt.ifo + '/L-' + opt.ifo + '_R-'
if opt.ifo == 'H1':
    file_pref = '/archive/frames/A6/raw/' + opt.ifo + '/H-' + opt.ifo + '_R-'
#file_pref = '/archive/frames/A6/L0/' + opt.ifo + '/L-' + opt.ifo + '_R-'
#file_pref = '/archive/frames/ER6/raw/' + opt.ifo + '/L-' + opt.ifo + '_R-'
dir1 = str(gpsb)[0:5]
dir2 = str(gpsb+dt)[0:5]
if dir1 == dir2:
    subprocess.call('ls ' + file_pref + dir1 + '/*.gwf | /usr/bin/lalapps_path2cache > bruco.cache', shell=True)
else:
    subprocess.call('ls ' + file_pref + dir1 + '/*.gwf ' + file_pref + dir2 + '/*.gwf | /usr/bin/lalapps_path2cache > bruco.cache', shell=True)
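# (lalapps_path2cache converts the list of .gwf paths into a LAL cache file,
#  one entry per frame file with its GPS segment)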

# read from the cache file
c = lal.Cache.fromfile(open('bruco.cache'))
d = frutils.FrameCache(c, scratchdir="/home/gabriele.vajente/tmp", verbose=True)
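# d.fetch(channel, gps_start, gps_stop) reads the requested channel from the
# cached frame files, copying them to scratchdir as needed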

print ">>>>> Extracting list of channels...."
# read the list of channels from the first file
firstfile = c[0].path
os.system('/usr/bin/FrChannels ' + firstfile + ' > bruco.channels')
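# FrChannels prints one line per channel: the channel name followed by its
# sampling rate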
f = open('bruco.channels')
lines = f.readlines()
channels = []
sample_rate = []
for l in lines:
    ll = l.split()
    # keep only interferometer channels, skip the site-wide L0/H0 ones
    if ll[0][1] != '0':
        channels.append(ll[0])
        sample_rate.append(int(ll[1]))
channels = array(channels)
sample_rate = array(sample_rate)


# keep only channels with high sampling rate
idx = sample_rate >= minfs
channels = channels[idx]
sample_rate = sample_rate[idx]

# load exclusion list
f = open(exc, 'r')
L = f.readlines()
excluded = []
for line in L:
    line = line.strip()
    if len(line):
        excluded.append(line.split()[0])
f.close()

# delete excluded channels, allowing for unix-shell-like wildcards
idx = ones(shape(channels), dtype='bool')
for c,i in zip(channels, arange(len(channels))):
    for e in excluded:
        if fnmatch.fnmatch(c, opt.ifo + ':' + e):
            idx[i] = False

channels = channels[idx]

# make list unique
channels = unique(channels)

# create the output directory if it does not exist yet
try:
    os.stat(opt.dir)
except OSError:
    os.mkdir(opt.dir)

# save the reduced channel list to a text file
f = open(opt.dir + '/channels.txt', 'w')
for c in channels:
    f.write("%s\n" % (c))
f.close()
nch = len(channels)

print "Found %d channels\n\n" % nch

print ">>>>> Processing all channels...."

# load target channel
a = time.time()
buffer = d.fetch(opt.ifo + ':' + opt.channel, gpsb, gpse)
ch1 = numpy.array(buffer)
fs1 = len(ch1) / dt
t_readdata = t_readdata + (time.time() - a)

# downsample the target channel if needed
a = time.time()
if fs1 > outfs:
    ch1 = decimate(ch1, int(fs1 / outfs))
npoints = pow(2,int(log((gpse - gpsb) * outfs / nav) / log(2)))
t_decimate = t_decimate + (time.time() - a)

print "Number of points = %d\n" % npoints

# create table of top ntop channels for each frequency bin
cohtab = zeros((npoints/2+1, ntop))
idxtab = zeros((npoints/2+1, ntop), dtype=int)
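# each row of cohtab holds, for one frequency bin, the ntop largest coherence
# values found so far, sorted in ascending order (largest value in the last
# column); idxtab holds the indices of the corresponding channels in 'channels'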

errchan = []
errdata = []
flatchan = []

# compute PSD of main channel

psd1,f = psd(ch1, NFFT=npoints, Fs=outfs, 
                     window=mlab.window_hanning, noverlap=npoints/2, detrend=detrend_linear)
psd1 = sqrt(psd1)
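# psd1 is the amplitude spectral density of the target channel, used below as
# the reference spectrum and to project the contribution of each auxiliary
# channel (psd1 * sqrt(coherence))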

# compute the confidence level
s = scipy.stats.f.ppf(0.95, 2, 2*nav)
s = s/(nav - 1 + s)
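# s is the 95% significance level for the coherence estimate with nav averages
# (based on the F distribution): coherence values below s are consistent with
# no correlation and are masked in the plots and in the summary web page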

# analyze every channel in the list
for channel2,i in zip(channels, arange(len(channels))):
    print "%d / %d : %s" % (i+1, len(channels), channel2)

    # read auxiliary channel
    a = time.time()
    try:
        buffer = d.fetch(channel2, gpsb, gpse)
        ch2 = numpy.array(buffer)
        fs2 = len(ch2) / dt
        t_readdata = t_readdata + (time.time() - a)
    except:
        print "   Some error occurred...", sys.exc_info()
        errdata.append(channel2)
        continue
    
    # check if the channel is flat
    if min(ch2) == max(ch2):
        print "   Flat channel, skipping"
        flatchan.append(channel2)
        continue

    # resample the auxiliary channel to outfs if needed: slower channels are
    # upsampled by sample repetition, faster channels are decimated
    a = time.time()
    if fs2 < outfs:
        ch2 = numpy.repeat(ch2, int(outfs / fs2))
    if fs2 > outfs:
        ch2 = decimate(ch2, int(fs2 / outfs))
    t_decimate = t_decimate + (time.time() - a)

    # compute coherence
    try:
        a = time.time()
        c,f = cohere(ch1, ch2, NFFT=npoints, Fs=outfs, 
                     window=mlab.window_hanning, noverlap=npoints/2,
                     detrend=detrend_linear)
        # remove coherence points above the Nyquist frequency of the auxiliary channel
        c[f > fs2/2] = 0
        
        # save coherence in summary table
        for cx,j in zip(c,arange(len(c))):
            top = cohtab[j, :]
            idx = idxtab[j, :]
            if cx > min(top):
                ttop = concatenate((top, [cx]))
                iidx = concatenate((idx, [i]))
                ii = ttop.argsort()
                ii = ii[1:]
                cohtab[j, :] = ttop[ii]
                idxtab[j, :] = iidx[ii]
        
        t_coherence = t_coherence + (time.time() - a)
        
        # plot: top panel shows the coherence and the 95% confidence level,
        # bottom panel shows the target spectrum and the projected contribution
        # of the auxiliary channel (ASD * sqrt(coherence)) where the coherence
        # is significant
        a = time.time()
        figure()
        subplot(211)
        title('Coherence %s vs %s - GPS %d' % (opt.channel, channel2, gpsb), fontsize='smaller')
        loglog(f, c, f, ones(shape(f))*s, 'r--', linewidth=0.5)
        axis(xmax=outfs/2)
        axis(ymin=1e-3, ymax=1)
        grid(True)
        ylabel('Coherence')
        subplot(212)
        loglog(f, psd1)
        mask = ones(shape(f))
        mask[c<s] = nan
        loglog(f, psd1 * sqrt(c) * mask, 'r')
        axis(xmax=outfs/2)
        xlabel('Frequency [Hz]')
        grid(True)

        t_plot = t_plot + (time.time() - a)
        a = time.time()
        savefig(opt.dir + '/%s.pdf' % channel2.split(':')[1], format='pdf')
        close()

        t_saveplot = t_saveplot + (time.time() - a)
        del ch2, c, f
    except:
        print "   Some error occurred...", sys.exc_info()
        errchan.append(channel2)
        del ch2
        pass
       
    el = (time.time() - start_time)
    print "   elapsed time %d min  -- expected remaining %d min" % \
          (el / 60, el / float(i+1) * float(nch - i) / 60)
    

# save some results
numpy.savetxt(opt.dir + '/cohtab.txt', cohtab)
numpy.savetxt(opt.dir + '/idxtab.txt', idxtab)

# save error lists
f = open(opt.dir + '/errorchannels.txt', 'w')
for c in errchan:
    f.write("%s\n" % (c))
f.close()
f = open(opt.dir + '/errordata.txt', 'w')
for c in errdata:
    f.write("%s\n" % (c))
f.close()
f = open(opt.dir + '/flatchannels.txt', 'w')
for c in flatchan:
    f.write("%s\n" % (c))
f.close()

print ">>>>> Generating report...."

# get the list of generated plots (channel names, without the .pdf extension)
files = sorted(fn[:-4] for fn in os.listdir(opt.dir) if fn.endswith('.pdf'))

# open web page
page = markup.page( )
page.init( title="Brute force Coherences", \
           css=( style ), \
           footer="(2014)  <a href=mailto:vajente@caltech.edu>vajente@caltech.edu</a>" )


# first section, top channels per frequency bin
nf,nt = shape(cohtab)
freq = linspace(0,outfs/2,nf)

page.h1('Top %d coherences at all frequencies' % wtop)
page.h2('GPS %d (%s) + %d s' % (gpsb, gps2str(gpsb), dt))

page.table(border=1, style='font-size:12px')
page.tr()
page.td(bgcolor="#5dadf1")
page.h3('Frequency [Hz]')
page.td.close()
page.td(colspan=wtop, bgcolor="#5dadf1")
page.h3('Top channels')
page.td.close()
page.tr.close()
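# one row per frequency bin: the wtop largest coherences are read from the end
# of the (ascending) cohtab row; each significant cell links to the per-channel
# PDF plot and is colored according to the coherence value (cohe_color)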

for i in range(nf):
    page.tr()
    page.td(bgcolor="#5dadf1")
    page.add("%.2f" % freq[i])
    page.td.close()
    for j in range(wtop):
        
        if cohtab[i,-(j+1)] > s:
            page.td(bgcolor=cohe_color(cohtab[i,-(j+1)]))
            ch = (channels[int(idxtab[i,-(j+1)])]).split(':')[1]
            page.add("<a target=_blank href=%s.pdf>%s</a><br>(%.2f)" \
                         % (ch, newline_name(ch), cohtab[i,-(j+1)]))
        else:
            page.td(bgcolor=cohe_color(0))

        page.td.close()
    page.tr.close()

page.table.close()

# second section, links to all coherences
page.h1('Coherence with all channels ')
page.h2('GPS %d (%s) + %d s' % (gpsb, gps2str(gpsb), dt))

N = len(files)
m = 6     # number of channels per row
M = (N + m - 1) / m     # number of rows needed

page.table(border=1)
for i in range(M):
    page.tr()
    for j in range(m):
        if i*m+j < N:
            page.td()
            page.add('<a target=_blank href=%s.pdf>%s</a>' % (files[i*m+j], files[i*m+j]))
            page.td.close()
        else:
            page.td()
            page.td.close()            
    
    page.tr.close()

page.table.close()
page.br()

# save page
page.savehtml(opt.dir  + '/index.html')


el = time.time() - start_time
print "\n\nTotal elapsed time %d s" % int(el)

print "    Data access  = %d s" % t_readdata
print "    Decimation   = %d s" % t_decimate
print "    Coh. comput. = %d s" % t_coherence
print "    Plotting     = %d s" % t_plot
print "    Saving plots = %d s" % t_saveplot
