#!/usr/bin/env python3

from pydarm.cmd._report import list_all_reports
import h5py
import os
import numpy as np


def compare_fitted_sensing(report):
    """Gather each fitted sensing parameter from the three report sources.

    For every parameter the value is taken from the JSON MCMC results
    (``report.get_sens_mcmc_results()['map']``), from the median of the
    matching column of the ``posteriors`` dataset in
    ``sensing_mcmc_chain.hdf5`` under ``report.path``, and from the
    sensing model attached to the report (``report.model.sensing``).

    Parameters
    ----------
    report : object
        Report providing ``get_sens_mcmc_results()``, ``path`` and
        ``model.sensing``.

    Returns
    -------
    dict
        Maps ``'Hc'``, ``'Fcc'``, ``'Fs'`` and ``'Qs'`` to a tuple
        ``(json value, median of HDF5 chain column, model value)``.
    """
    # Column layout of the sensing chain, from pydarm/measurement.py:
    #   0 - gain in units of counts per meter of DARM
    #   1 - coupled cavity pole frequency in units of Hz
    #   2 - optical spring frequency in units of Hz
    #   3 - quality factor of the optical spring (Q, unitless)
    #   4 - residual time delay in units of seconds (not compared here)
    chain_columns = {'Hc': 0, 'Fcc': 1, 'Fs': 2, 'Qs': 3}

    json_values = report.get_sens_mcmc_results()['map']

    # Open read-only and close the handle as soon as the samples are copied
    # into memory, so looping over many reports does not leak open files.
    chain_path = os.path.join(report.path, "sensing_mcmc_chain.hdf5")
    with h5py.File(chain_path, "r") as chain_file:
        posteriors = chain_file['posteriors'][:]

    sensing = report.model.sensing
    ini_values = {'Hc': sensing.coupled_cavity_optical_gain,
                  'Fcc': sensing.coupled_cavity_pole_frequency,
                  'Qs': sensing.detuned_spring_q,
                  'Fs': sensing.detuned_spring_frequency}

    # json results, h5py results, ini results
    return {name: (json_values[name],
                   np.median(posteriors[:, column]),
                   ini_values[name])
            for name, column in chain_columns.items()}

def compare_fitted_act(report):
    """Gather the fitted actuation parameters from the three report sources.

    For every stage/optic pair the values are taken from the JSON MCMC
    results (``report.get_act_mcmc_results()[key]['map']``), from the median
    of the matching column of the ``posteriors`` dataset in
    ``actuation_<stage>_<optic>_mcmc_chain.hdf5`` under ``report.path``, and
    (for the gain only) from the actuation model attached to the report.

    Parameters
    ----------
    report : object
        Report providing ``get_act_mcmc_results()``, ``path`` and
        ``model.actuation``.

    Returns
    -------
    dict
        Maps ``"<stage>/<optic>"`` (e.g. ``"L1/EX"``) to a dict with keys
        ``'H_A'`` and ``'tau_A'``. Each value is a tuple
        ``(json value, median of HDF5 chain column, model value)``; the
        model value is ``''`` for ``'tau_A'`` because no model reference is
        looked up for it.
    """
    # Column layout of the actuation chain:
    #   0 - gain in units of newtons per driver output (amps or volts**2).
    #       For example, for TST the units will be in N/V**2.
    #   1 - residual time delay in units of seconds
    optics = ("EX",)
    stages = ("L1", "L2", "L3")
    # Attribute on the arm model that holds the reference gain of each stage.
    ini_gain_attr = {"L1": "uim_npa",
                     "L2": "pum_npa",
                     "L3": "tst_npv2"}

    json_results = report.get_act_mcmc_results()
    comparison = {}
    for stage in stages:
        for optic in optics:
            key = f"{stage}/{optic}"

            # Look the arm model up for *this* optic ("EX" -> "xarm") on
            # every iteration, so the reference gain always belongs to the
            # optic being compared even if more optics are added later.
            arm_model = getattr(report.model.actuation,
                                f"{optic[1].lower()}arm")
            ini_gain = getattr(arm_model, ini_gain_attr[stage])

            # Open read-only and close the handle once the samples are in
            # memory, so looping over many reports does not leak open files.
            chain_path = os.path.join(
                report.path, f"actuation_{stage}_{optic}_mcmc_chain.hdf5")
            with h5py.File(chain_path, "r") as chain_file:
                posteriors = chain_file['posteriors'][:]

            json_map = json_results[key]['map']
            # json results, h5py results, ini results
            comparison[key] = {
                'H_A': (json_map['H_A'], np.median(posteriors[:, 0]), ini_gain),
                'tau_A': (json_map['tau_A'], np.median(posteriors[:, 1]), ''),
            }
    return comparison


if __name__ == '__main__':
    # Print the JSON / HDF5 / INI comparison for every report returned by
    # list_all_reports(): sensing parameters first, then one section per
    # actuation stage/optic key.
    column_header = "# JSON, H5PY, INI"
    for rep in list_all_reports():
        print(f"Report {rep.id}")
        sensing_cmp = compare_fitted_sensing(rep)
        actuation_cmp = compare_fitted_act(rep)

        print(column_header)
        for name, triple in sensing_cmp.items():
            print(f"{name}: ", triple)

        for stage_key, params in actuation_cmp.items():
            print(f"{stage_key}: ")
            print(column_header)
            for param, triple in params.items():
                print(f"\t{param} ", triple)

        # Blank lines to separate consecutive reports.
        print('\n\n')
