#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue May 13 15:41:51 2025

This script is a prototype for an OSEM calibration routine.
The idea is to drive the ISI in the translational cartesian degrees of freedom, 
and use the results to calibrate the upper mass OSEMs of various suspensions

- May 2025, Edgard Bonilla

+ May 14, 2025 Implemented as an SR3 calibration script. Requires the measurements to have been taken beforehand.
@author: controls
"""

import numpy as np
import matplotlib.pyplot as plt
import scipy.signal as signal
import scipy.io as spio
import ezca
from gwpy.timeseries import TimeSeries
from gwpy.time import tconvert
from datetime import datetime, timezone

ezca=ezca.Ezca()

## Auxiliary functions
# Clean the angle for plots
def cleanAngle(TF):
    """Return the phase of ``TF`` in degrees with a shifted wrap point.

    The response is rotated by ``0.1*pi`` before taking the angle and rotated
    back afterwards, so the result lies in (-198, 162] deg instead of
    (-180, 180] deg. This moves the phase-wrap discontinuity away from
    +/-180 deg in the plots.

    Parameters
    ----------
    TF : complex scalar or array
        Transfer function value(s).

    Returns
    -------
    float or numpy.ndarray
        Phase in degrees, same shape as ``TF``.
    """
    shift = np.pi * 0.1  # extra rotation that relocates the wrap point
    wrapped = np.angle(TF * np.exp(1j * shift))
    return (wrapped - shift) * 180 / np.pi

# Returns the codename for the type of a particular suspension given its name.
# For example, PR3 returns 'HLTS', IM3 returns 'HAUX', etc.
def findSusType(given_susName):
    """Return the suspension-type codename for a suspension name.

    The lookup is case-insensitive: ``'sr3'`` and ``'SR3'`` both give
    ``'HLTS'``.

    Parameters
    ----------
    given_susName : str
        Suspension name, e.g. ``'PR3'`` or ``'IM3'``.

    Returns
    -------
    str
        Suspension-type codename, e.g. ``'HLTS'`` or ``'HAUX'``.

    Raises
    ------
    NameError
        If the name is not listed in the table below.
    """
    susName = given_susName.upper()

    # Suspension type -> suspension names. Every value must be a real tuple
    # (note the trailing commas on one-element entries): a bare ('BS') is just
    # the string 'BS', and iterating over it yields the single characters
    # 'B' and 'S', so the full name would never be found.
    possible_types = {
        'QUAD': ('QUAD', 'ETMX', 'ETMY', 'ITMX', 'ITMY'),
        'BSFM': ('BS',),
        'HLTS': ('PR3', 'SR3'),
        'HSTS': ('FC1', 'FC2', 'MC1', 'MC2', 'MC3', 'PR2', 'PRM', 'SR2', 'SRM'),
        'TMTS': ('TMSX', 'TMSY'),
        'OMCS': ('OMC',),
        'OFIS': ('OFI',),
        'OPOS': ('OPO',),
        'HAUX': ('IM1', 'IM2', 'IM3', 'IM4'),
        'HXDS': ('OM2', 'ZM1', 'ZM2', 'ZM3', 'ZM4', 'ZM5', 'ZM6'),
        'HTTS': ('OM1', 'OM3', 'RM1', 'RM2'),
    }

    for susType, names in possible_types.items():
        if susName in names:
            return susType  # Exit early when we find the correct suspension type
    # Only reached when no type lists this suspension name
    raise NameError(f'The suspension name {given_susName} is not listed in this script')

# Returns the name of the top stage of a given suspension type.
# TODO: extend the table below to more suspension types as they are needed.
def findTopStageName(susType):
    """Return the top-stage name (e.g. ``'M1'``) for a suspension type.

    Parameters
    ----------
    susType : str
        Suspension-type codename (case-insensitive), e.g. ``'HLTS'``.

    Returns
    -------
    str
        ``'M0'`` for QUAD, ``'M1'`` for HLTS and HSTS.

    Raises
    ------
    NotImplementedError
        If the suspension type is not in the table.
    """
    top_stage_names = {
        'QUAD': 'M0',
        'HLTS': 'M1',
        'HSTS': 'M1',
    }
    susType = susType.upper()
    try:
        return top_stage_names[susType]
    except KeyError:
        raise NotImplementedError(
            f'Sorry, unrecognized susType {susType}. Try QUAD, HLTS, or HSTS.'
        ) from None

def calibPairings(drivedofs, euldofs, osemdofs, OSEM2EUL):
    """Pair each driven Euler DOF with the OSEMs that contribute to it.

    An OSEM is paired with a drive DOF when its entry in the drive DOF's row
    of ``OSEM2EUL`` is nonzero.

    Parameters
    ----------
    drivedofs : sequence of str
        Euler DOFs that are driven, e.g. ``('L', 'T', 'V')``.
    euldofs : tuple or list of str
        Euler DOF names in the row order of ``OSEM2EUL``.
    osemdofs : sequence of str
        OSEM names in the column order of ``OSEM2EUL``.
    OSEM2EUL : numpy.ndarray
        OSEM-to-Euler projection matrix indexed as ``[euler_row, osem_col]``.

    Returns
    -------
    dict
        ``{drivedof: tuple of paired OSEM names}``.

    Raises
    ------
    ValueError
        If a drive DOF is not present in ``euldofs``.
    """
    pairings = {}
    for drive in drivedofs:
        row = euldofs.index(drive)
        pairings[drive] = tuple(
            osem for col, osem in enumerate(osemdofs) if OSEM2EUL[row, col] != 0
        )
    return pairings
    
# SciPy has no native helper for the frequency response of a single channel of
# a MIMO state-space model, and python-control is not guaranteed to be part of
# the LHO/LLO software suite, so the response is computed by hand here.
# Indices follow the MATLAB convention (they start at 1).
def SISOfreqresp(ss, Out, In, wlist):
    """Return the frequency response of one output/input pair of a state-space model.

    Evaluates ``C[Out, :] (j w I - A)^-1 B[:, In] + D[Out, In]`` at every
    angular frequency ``w`` in ``wlist``.

    Parameters
    ----------
    ss : object
        State-space model exposing ``A``, ``B``, ``C`` and ``D`` matrices
        (e.g. ``scipy.signal.StateSpace``).
    Out, In : int
        1-based (MATLAB-style) output and input indices.
    wlist : iterable of float
        Angular frequencies [rad/s].

    Returns
    -------
    numpy.ndarray
        Complex response, one entry per element of ``wlist``.

    Raises
    ------
    IndexError
        If ``Out`` or ``In`` is not a positive integer.
    """
    out_idx = int(Out) - 1
    in_idx = int(In) - 1
    # A 0 (or negative) index would silently wrap to the last row/column in
    # NumPy; that is almost always a 0-based index passed by mistake.
    if out_idx < 0 or in_idx < 0 or out_idx + 1 != Out or in_idx + 1 != In:
        raise IndexError(
            f'SISOfreqresp expects positive 1-based integer indices, got Out={Out}, In={In}'
        )

    A = np.asarray(ss.A)
    b = np.asarray(ss.B)[:, in_idx].reshape(-1, 1)
    c = np.asarray(ss.C)[out_idx, :].reshape(1, -1)
    d = np.asarray(ss.D)[out_idx, in_idx]
    identity = np.eye(len(A))

    # Solving (j w I - A) x = b avoids forming an explicit inverse, which is
    # both cheaper and numerically better conditioned.
    resp = [(c @ np.linalg.solve(1j * w * identity - A, b)).item() + d
            for w in wlist]
    return np.array(resp, dtype=complex).flatten()

def getGPSdatetime(timestring):
    """Convert a ``'YYYY-MM-DD_hhmm'`` UTC time string to GPS time.

    The string is split on ``'-'`` and ``'_'``; the first two characters of
    ``hhmm`` are the hour and the next two the minute. The resulting
    timezone-aware UTC ``datetime`` is converted with gwpy's ``tconvert``.

    Raises
    ------
    ValueError
        If the string does not split into the expected fields or a field is
        not a valid integer / calendar value.
    """
    year, month, day_and_time = timestring.split('-')
    day, hhmm = day_and_time.split('_')
    # Converted in the same order as the datetime() arguments
    fields = [int(part) for part in (year, month, day, hhmm[0:2], hhmm[2:4])]
    stamp = datetime(*fields, tzinfo=timezone.utc)
    return tconvert(stamp)

def eulUnits(dof):
    """Return the base unit of an Euler degree of freedom.

    Parameters
    ----------
    dof : str
        Euler DOF name (case-insensitive): one of L T V R P Y.

    Returns
    -------
    str
        ``'m'`` for the translations L, T, V and ``'rad'`` for the rotations
        R, P, Y.

    Raises
    ------
    NameError
        If ``dof`` is not an Euler degree of freedom.
    """
    dof = dof.upper()
    if dof in ('L', 'T', 'V'):
        return 'm'
    if dof in ('R', 'P', 'Y'):
        return 'rad'
    raise NameError(f'The degree of freedom {dof} is not an Euler degree of freedom. Options are L T V R P Y.\n')
    
def getGainAtTime(chan_name, timestring):
    """Return the mean value of a channel over a short window at a given time.

    Despite the name, nothing here is gain-specific: the channel is fetched
    through NDS for 0.5 s starting at ``timestring`` and its samples are
    averaged.

    Parameters
    ----------
    chan_name : str
        Full channel name starting with the IFO prefix, ``'H1:...'`` or
        ``'L1:...'``; the prefix selects the NDS host.
    timestring : str
        UTC time as ``'YYYY-MM-DD_hhmm'`` (see ``getGPSdatetime``).

    Returns
    -------
    float
        Mean of the channel samples in the window.

    Raises
    ------
    ValueError
        If the IFO prefix is not H1 or L1.
    """
    nds_hosts = {
        'H1': 'nds.ligo-wa.caltech.edu',
        'L1': 'nds.ligo-la.caltech.edu',
    }
    ifo = chan_name.split(':')[0]
    try:
        hostname = nds_hosts[ifo]
    except KeyError:
        # Previously this left `hostname` unbound and failed later with an
        # opaque UnboundLocalError.
        raise ValueError(
            f'Unrecognized IFO prefix {ifo!r} in channel {chan_name!r}; '
            f'expected one of {sorted(nds_hosts)}'
        ) from None

    gps = getGPSdatetime(timestring)
    duration = 0.5  # [s] length of the data that is averaged

    data = TimeSeries.get(chan_name, gps, gps + duration,
                          allow_tape=True, host=hostname)
    return np.mean(data.value)
#%%
# Initialize plot settings
plt.rcParams.update({
    'lines.linewidth': 2,      # thick lines
    'figure.figsize': [8, 8],  # width = 8 inches, height = 8 inches
    'axes.grid': True,
})

# Measurement and suspension being calibrated
svnDir = '/home/controls/SusSVN/sus/trunk/'
ifo = 'H1'
isiName = 'HAM5'
isiStage = 'ST1'
susName = 'SR3'
measurement_datetime = '2025-05-07_0000'

susType = findSusType(susName)

# Locations of the tools, exported models, projection matrices and data.
# Both the models and the projection-matrix master file live in the
# "ExportedModels" directory inside MatlabTools.
toolsDir = svnDir + 'Common/MatlabTools/'
importDir = toolsDir + 'ExportedModels/'
modelDir = importDir
projDir = importDir
Data_folder = f'{svnDir}{susType}/{ifo}/{susName}/Common/Data/'
#%%
# Load the EUL2OSEM matrix and the (ordered) dof names.
# NOTE: loadmat scrambles the README text, but the projection matrices are kept
# intact thanks to simplify_cells.
projection_struct = spio.loadmat(f'{projDir}SUS_projection_file', simplify_cells=True)
sus_projections = projection_struct['SUS_projections'][susType]

# Stage of the suspension that we are calibrating: the first 'Sensor' entry.
# NOTE(review): if loadmat squeezes a single-stage list into a plain string,
# [0] would return its first character instead - confirm for each susType.
susStage = sus_projections['susStages']['Sensor'][0]
stage_projections = sus_projections[susStage]

# TODO: Deal with incomplete bases of OSEM/EUL, data will need to be imported per OSEM
# NOTE: The code assumes sensalign is the identity matrix
S_osem2eul = stage_projections['OSEM2EUL']
S_eul2osem = np.linalg.inv(S_osem2eul)

# Tuples, so the DOF orderings cannot be mutated by accident
osemdofs = tuple(stage_projections['osemdof'])
euldofs = tuple(stage_projections['euldof'])
drivedofs = ('L', 'T', 'V')  # We always drive the translational Euler degrees of freedom

# For each driven DOF, find the OSEMs that contribute to it through OSEM2EUL
# TODO: Deal with not observing T or V
calibdofs = calibPairings(drivedofs, euldofs, osemdofs, S_osem2eul)

#%%
# TODO: convert this into a function to import EUL data
# Grab the ISI-drive transfer functions and organize them per Euler DOF (LTVRPY).
# Each text file holds a frequency column followed by (real, imaginary) column
# pairs, one pair per measured Euler DOF.
file_prefix = f'{measurement_datetime}_{ifo}ISI{isiName}_{isiStage}_WhiteNoise_ISO_'
file_suffix = '_0p05to40Hz_calibration.txt'
LY_filename = f'{file_prefix}Y{file_suffix}'   # ISO Y drive file: L and Y responses
VRP_filename = f'{file_prefix}Z{file_suffix}'  # ISO Z drive file: V, R and P responses
T_filename = f'{file_prefix}X{file_suffix}'    # ISO X drive file: T response

# Extract the matrices
LY_mat = np.loadtxt(Data_folder + LY_filename)
VRP_mat = np.loadtxt(Data_folder + VRP_filename)
T_mat = np.loadtxt(Data_folder + T_filename)

# Euler DOF -> (measurement matrix, column of the real part); the imaginary
# part is always in the following column and the frequencies in column 0.
dof_columns = {
    'L': (LY_mat, 1),
    'Y': (LY_mat, 3),
    'V': (VRP_mat, 1),
    'R': (VRP_mat, 3),
    'P': (VRP_mat, 5),
    'T': (T_mat, 1),
}

raw_data = {'freqs': {}}  # Dictionary to organize the dataset
for dof, (mat, col) in dof_columns.items():
    raw_data['freqs'][dof] = mat[:, 0]
    raw_data[dof] = mat[:, col] + 1j * mat[:, col + 1]

# The GS13 SUSPOINT outputs are always calibrated in nm, so scale the
# transfer functions from [um/nm] to [um/um].
calibration_to_um = 10**3
for dof in euldofs:
    raw_data[dof] = calibration_to_um * raw_data[dof]

    
#%%
# Transform the Euler-basis measurements into the OSEM basis with EUL2OSEM,
# e.g. L, Y data -> LF, RT | V, R, P data -> T1, T2, T3 | T data -> SD.
# All Euler DOFs summed into one OSEM must share the same frequency vector,
# otherwise the weighted sum would mix different frequencies.
OSEM_data = {'freqs': {}}  # created once so every OSEM keeps its frequency vector
for ii, outdof in enumerate(osemdofs):
    for jj, indof in enumerate(euldofs):
        weight = S_eul2osem[ii, jj]
        if weight == 0:
            continue

        indof_freqs = raw_data['freqs'][indof]
        if outdof not in OSEM_data:
            OSEM_data[outdof] = np.zeros_like(raw_data[indof])
            OSEM_data['freqs'][outdof] = indof_freqs
        else:
            outdof_freqs = OSEM_data['freqs'][outdof]
            if (np.shape(outdof_freqs) != np.shape(indof_freqs)
                    or not np.allclose(outdof_freqs, indof_freqs)):
                raise ValueError(
                    f'Cannot combine the {indof} measurement into the {outdof} OSEM: '
                    f'its frequency vector differs from the one already used for {outdof}'
                )

        OSEM_data[outdof] += weight * raw_data[indof]
#%%
# Load the (undamped) model
model_struct = spio.loadmat(f'{modelDir}{susType}_model', simplify_cells=True)
sus_model_data = model_struct['susModel']

susModel = signal.StateSpace(sus_model_data['A'],
                             sus_model_data['B'],
                             sus_model_data['C'],
                             sus_model_data['D'])

# Input/output index tables of the exported model; SISOfreqresp treats these
# indices as 1-based (MATLAB) indices.
model_in = sus_model_data['in']
model_out = sus_model_data['out']

# Model frequency response from ground displacement to top-mass displacement
# for L, T, V.
# NOTE(review): the output stage key 'm1' is hard-coded; confirm it for
# suspension types whose top stage is not M1.
modelResp = {}
for dof in drivedofs:
    in_dof = model_in['gnd']['disp'][dof]
    out_dof = model_out['m1']['disp'][dof]
    modelResp[dof] = SISOfreqresp(susModel, out_dof, in_dof, 2 * np.pi * raw_data['freqs'][dof])
    # We want the response to OSEM motion: (x_mass - x_gnd)/x_gnd = x_mass/x_gnd - 1
    modelResp[dof] = modelResp[dof] - 1

# Fit one real scale factor per OSEM over a suitable frequency window
fmin = 5  # [Hz] minimum frequency for fitting
fmax = 15  # [Hz] maximum frequency for fitting
calib = {}
calib_OSEM_data = {}

for indof in drivedofs:
    fit_freqs = raw_data['freqs'][indof]
    mask = (fit_freqs > fmin) & (fit_freqs < fmax)  # Select frequencies of interest
    if not np.any(mask):
        raise ValueError(
            f'No {indof} measurement points between {fmin} and {fmax} Hz; '
            'cannot fit the OSEM calibration'
        )
    Y = np.abs(modelResp[indof][mask])
    for outdof in calibdofs[indof]:
        X = np.abs(OSEM_data[outdof][mask])
        XX = np.dot(X, X)
        if XX == 0:
            raise ValueError(
                f'The {outdof} OSEM response is zero between {fmin} and {fmax} Hz; '
                'cannot fit the OSEM calibration'
            )
        # Least-squares scale c minimizing ||c*X - Y||^2 is (X.Y)/(X.X)
        calib[outdof] = np.dot(X, Y) / XX

        # Store the calibrated OSEM data for plotting
        calib_OSEM_data[outdof] = OSEM_data[outdof] * calib[outdof]

uncalibrated = [dof for dof in osemdofs if dof not in calib]
if uncalibrated:
    raise ValueError(
        f'The OSEM(s) {uncalibrated} are not paired with any driven DOF {drivedofs}; '
        'they cannot be calibrated'
    )

# Construct the multiplicative factors for the DAMP gains.
# The calibrated OSEMs read diag(calib) times the old values, so the Euler
# signals become (O2E * Cal * pinv(O2E)) times the old ones; the factors are
# the diagonal terms of inv(O2E * Cal * pinv(O2E)). pinv makes no assumption
# about the number of OSEMs. (calib_vec is the diagonal matrix Cal.)
calib_vec = np.diag([calib[dof] for dof in osemdofs])
eul2eul_calib_matrix = np.linalg.inv(S_osem2eul.dot(calib_vec.dot(np.linalg.pinv(S_osem2eul))))

#%%
# Plot the results to check if it all makes sense
def _plot_calibration_tfs(tf_data, stage_label, file_tag):
    """Plot model vs. measured ISI-to-OSEM TFs for each driven DOF and save a PNG.

    Parameters
    ----------
    tf_data : dict
        OSEM name -> complex transfer function on the
        ``raw_data['freqs'][drive dof]`` frequency grid.
    stage_label : str
        Word shown in the title, e.g. ``'BEFORE'`` or ``'AFTER'``.
    file_tag : str
        Word used in the output file name, e.g. ``'before'`` or ``'after'``.
    """
    for indof in drivedofs:
        freq = raw_data['freqs'][indof]
        outdofs = calibdofs[indof]
        figu = plt.figure()

        # AMPLITUDE
        plt.subplot(2, 1, 1)
        plt.loglog(freq, np.abs(modelResp[indof]), 'k')
        for outdof in outdofs:
            plt.loglog(freq, np.abs(tf_data[outdof]))
        plt.title(f'ISI to OSEM transfer functions [{stage_label} calibration] \n {measurement_datetime}')
        plt.ylabel('Amplitude [(OSEM m)/(GS13 m)]')
        plt.grid(which='minor', linestyle='-', alpha=0.5)
        plt.xlim([0.1, 20])
        plt.ylim([1e-4, 1e2])
        legText = ['Model'] + [f'({outdof} OSEM)/(Suspoint {indof})' for outdof in outdofs]
        plt.legend(legText, loc='lower right')

        # ANGLE
        plt.subplot(2, 1, 2)
        plt.semilogx(freq, cleanAngle(modelResp[indof]), 'k')
        for outdof in outdofs:
            plt.semilogx(freq, cleanAngle(tf_data[outdof]))
        plt.grid(which='minor', linestyle='-', alpha=0.5)
        plt.xlabel('Frequency [Hz]')
        plt.ylabel('Angle [deg]')
        plt.xlim([0.1, 20])
        plt.ylim([-200, 200])
        figu.savefig(f'{ifo}SUS{susName}_{file_tag}_{measurement_datetime}_calibration_{indof}.png')


# BEFORE plots (raw OSEM data), then AFTER plots (same but with calibrated data)
_plot_calibration_tfs(OSEM_data, 'BEFORE', 'before')
_plot_calibration_tfs(calib_OSEM_data, 'AFTER', 'after')

#%%
# Read the OSEM input-filter gains at the measurement time and scale them by
# the fitted calibration factors
prev_sensInfGain = {}
next_sensInfGain = {}

# Apparent Euler alignments (DAMP INMON averages) before and after calibration
prev_eulAlign = {}
next_eulAlign = {}

for dof in osemdofs:
    chan_name = f'{ifo}:SUS-{susName}_{susStage}_OSEMINF_{dof}_GAIN'
    prev_sensInfGain[dof] = round(getGainAtTime(chan_name, measurement_datetime), 3)
    next_sensInfGain[dof] = round(calib[dof] * prev_sensInfGain[dof], 3)

for dof in euldofs:
    chan_name = f'{ifo}:SUS-{susName}_{susStage}_DAMP_{dof}_INMON'
    prev_eulAlign[dof] = round(getGainAtTime(chan_name, measurement_datetime), 3)

# With the new OSEMINF gains every OSEM reads calib[osem] times its old value,
# so the Euler readouts become
#     new EUL = osem2eul * diag(calib) * pinv(osem2eul) * old EUL.
# eul2eul_calib_matrix is the INVERSE of this matrix (its diagonal gives the
# DAMP gain compensation factors), so it must not be used to propagate the
# alignment offsets.
eul_alignment_matrix = S_osem2eul.dot(
    np.diag([calib[dof] for dof in osemdofs]).dot(np.linalg.pinv(S_osem2eul)))
new_eul_vec = eul_alignment_matrix.dot([prev_eulAlign[dof] for dof in euldofs])

for ii, dof in enumerate(euldofs):
    next_eulAlign[dof] = round(new_eul_vec[ii], 3)

# TODO: read the current DAMP gains (e.g. with ezca) and report the suggested
# new gains directly, not only the multiplicative factors.

#%%    PRINT THE MESSAGE TO INFORM ABOUT THE CALIBRATION
message = []
message.append(f'OSEM calibration of {ifo}:SUS-{susName}\n')
message.append(f'Stage: {susStage}\n{measurement_datetime} (UTC).\n')

message.append(f'\nThe suggested (calibrated) {susStage} OSEMINF gains are\n')
for dof in osemdofs:
    message.append(f'(new {dof}) = {calib[dof]:.3f} * (old {dof}) = {next_sensInfGain[dof]:.3f} \n')

message.append(f'\nTo compensate for the OSEM gain changes, we estimate that the {ifo}:SUS-{susName}_{susStage}_DAMP loops must be changed by factors of: \n')
for ii, dof in enumerate(euldofs):
    gg = round(eul2eul_calib_matrix[ii, ii], 3)
    message.append(f'{dof} gain = {gg:.3f} * (old {dof} gain)\n')

message.append(f'\nThe calibration will change the apparent alignment of the suspension as seen by the {susStage} OSEMs\n')
message.append('NOTE: The actual alignment of the suspension will NOT change as a result of the calibration process\n\n')

message.append('The changes are computed as (osem2eul) * gain * inv(osem2eul).\n')
message.append(f'Using the alignments from {measurement_datetime} (UTC) as a reference, the new apparent alignments are:\n\n')

# Alignment table: DOF, previous value, new value, apparent change
row = "{:<10} {:<20} {:<20} {:>10}"
header_row = row
names = ['DOF', 'Previous value', 'New value', 'Apparent change']
prev_a = [f'{prev_eulAlign[dof]: .1f} u{eulUnits(dof)}' for dof in euldofs]
next_a = [f'{next_eulAlign[dof]: .1f} u{eulUnits(dof)}' for dof in euldofs]
delta_a = [f'{next_eulAlign[dof]-prev_eulAlign[dof]:+.1f} u{eulUnits(dof)}' for dof in euldofs]

header = header_row.format(*names)
message.append(header + '\n')
message.append('-' * len(header) + '\n')
for cells in zip(euldofs, prev_a, next_a, delta_a):
    message.append(row.format(*cells) + '\n')

message.append(f'\nWe have estimated an OSEM calibration of {ifo} {susName} {susStage} using {isiName} {isiStage} drives from {measurement_datetime} (UTC).\n')
message.append(f'We fit the response {susStage}_DAMP/{isiName}_SUSPOINT between {fmin} and {fmax} Hz to get a calibration in [OSEM m]/[GS13 m]\n')

# __file__ is undefined when the cells are executed interactively
# (e.g. IPython / Spyder cell mode), so fall back to a placeholder.
this_filename = globals().get('__file__', 'an interactive session').split('/')[-1]
message.append(f'\nThis message was generated automatically by {this_filename} on {datetime.now(timezone.utc)} UTC\n')

message.append('\n%%%%%%%%%%%%%%%%%%%%%%%%%%%% \n\n\n')
message.append('EXTRA INFORMATION \n\n')
message.append(f'The {ifo}:SUS-{susName}_{susStage}_OSEMINF gains at the time of measurement were:\n')
for dof in osemdofs:
    message.append(f'(old) {dof}: {prev_sensInfGain[dof]:.3f} \n')

message.append('\nThe matrix to convert from the old Euler dofs to the (calibrated) new Euler dofs is:\n\n')
message.append('\n'.join('\t'.join('{:+3}'.format(item) for item in matrix_row)
                         for matrix_row in np.round(eul_alignment_matrix, 3)))
message.append('\n\nThe matrix is used as (M) * (old EUL dof) = (new EUL dof)\n')
message.append(f'The dof ordering is {euldofs}\n')

print(''.join(message))
