#%%
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue May 13 15:41:51 2025

This script is a prototype for an OSEM calibration routine.
The idea is to drive the ISI in the translational cartesian degrees of freedom, 
and use the results to calibrate the upper mass OSEMs of various suspensions

- May 2025, Edgard Bonilla

+ May 14, 2025 Implemented as an SR3 calibration script. Requires measurements to have been taken
@author: controls
"""

import numpy as np
import matplotlib.pyplot as plt
import scipy.signal as signal
import scipy.io as spio
import ezca
from gwpy.timeseries import TimeSeries
from gwpy.time import tconvert
from datetime import datetime, timezone

# Create the Ezca interface object. In this script it is only referenced by
# commented-out channel reads further down.
# NOTE: this rebinds the name `ezca` from the imported module to the instance.
ezca=ezca.Ezca()

## Auxiliary functions
# Clean the angle for plots
def cleanAngle(TF):
    """Return the phase of TF in degrees with a shifted wrapping point.

    The phase is evaluated after rotating TF by 0.1*pi rad (18 deg) and the
    rotation is subtracted again afterwards, so the result lies in
    (-198, 162] degrees instead of (-180, 180]. This keeps phases that sit
    near -180 deg from flipping to +180 deg in the plots.
    """
    shift_fraction = 0.1  # fraction of pi by which the wrapping point is moved
    rotated_phase = np.angle(TF * np.exp(1j * np.pi * shift_fraction))
    unrotated_phase = rotated_phase - np.pi * shift_fraction
    return unrotated_phase * 180 / np.pi

# Returns the codename for the type of a particular suspension given its name. For example,
# PR3 returns 'HLTS', IM3 returns 'HAUX', etc.
def findSusType(given_susName):
    """Return the suspension-type codename for a suspension name.

    The lookup is case-insensitive, e.g. 'pr3' -> 'HLTS', 'IM3' -> 'HAUX'.

    Parameters
    ----------
    given_susName : str
        Name of the suspension (e.g. 'SR3').

    Returns
    -------
    str
        The suspension type under which the name is listed.

    Raises
    ------
    NameError
        If the name is not listed in the table below.
    """
    susName = given_susName.upper()

    # Table of suspension type -> suspension names.
    # Single-name entries need the trailing comma: ('BS') is just the string
    # 'BS', and iterating over it would compare against 'B' and 'S'.
    possible_types = {
        'QUAD': ('QUAD',),
        'BSFM': ('BS',),
        'HLTS': ('PR3', 'SR3'),
        'HSTS': ('FC1', 'FC2', 'MC1', 'MC2', 'MC3', 'PR2', 'PRM', 'SR2', 'SRM'),
        'TMTS': ('TMSX', 'TMSY'),
        'OMCS': ('OMC',),
        'OFIS': ('OFI',),
        'OPOS': ('OPO',),
        'HAUX': ('IM1', 'IM2', 'IM3', 'IM4'),
        'HXDS': ('OM2', 'ZM1', 'ZM2', 'ZM3', 'ZM4', 'ZM5', 'ZM6'),
        'HTTS': ('OM1', 'OM3', 'RM1', 'RM2'),
    }

    for susType, members in possible_types.items():
        if susName in members:
            return susType  # Exit early when we find the correct suspension type

    # Only reached when no type lists this suspension
    raise NameError(f'The suspension name {given_susName} is not listed in this script')

# Returns the name of the top stage of a given SusType 
#TODO: implement this for more than a few suspensions
#TODO: implement this by using a dictionary when we have more suspensions
def findTopStageName(susType):
    """Return the name of the top stage for a suspension type.

    The lookup is case-insensitive. Only QUAD, HLTS and HSTS are known;
    any other type raises NotImplementedError.
    """
    top_stage_by_type = {
        'QUAD': 'M0',
        'HLTS': 'M1',
        'HSTS': 'M1',
    }
    susType = susType.upper()
    top_stage = top_stage_by_type.get(susType)
    if top_stage is None:
        raise NotImplementedError(f'Sorry, unrecognized susType {susType}. Try QUAD, HLTS, or HSTS.')
    return top_stage

def calibPairings(drivedofs, euldofs,osemdofs,OSEM2EUL):
    """Pair each driven Euler DOF with the OSEMs that contribute to it.

    For every entry of drivedofs, the matching row of OSEM2EUL is located
    through its position in euldofs, and the names in osemdofs whose column
    has a non-zero entry in that row are collected.

    Returns a dictionary {driven dof: tuple of OSEM names}.
    Raises ValueError if a driven DOF is not present in euldofs.
    """
    pairings = {}
    for driven in drivedofs:
        row = euldofs.index(driven)
        pairings[driven] = tuple(
            osem for col, osem in enumerate(osemdofs) if OSEM2EUL[row, col] != 0
        )
    return pairings
    
# Evaluates the frequency response of a single output/input pair of a MIMO model.
# This function is here because Scipy does not have a native way of dealing with MIMO systems,
# and I cannot guarantee that LHO/LLO have python-control as part of their suite.
# This assumes MATLAB indexing (which starts at 1) for Out and In.
def SISOfreqresp(ss,Out,In,wlist):
    """Frequency response of one output/input pair of a state-space model.

    Evaluates C[Out, :] (1j*w*I - A)^-1 B[:, In] + D[Out, In] for every w in
    wlist.

    Parameters
    ----------
    ss : object
        Anything exposing the 2-D state-space matrices as attributes
        ``A``, ``B``, ``C`` and ``D``.
    Out, In : int
        Output and input channel numbers using MATLAB indexing (first
        channel is 1).
    wlist : iterable of float
        Angular frequencies at which to evaluate the response.

    Returns
    -------
    numpy.ndarray
        1-D complex array with one response value per entry of wlist.

    Raises
    ------
    IndexError
        If Out or In falls outside the model's channels. In particular 0 is
        rejected: after the 1-based to 0-based shift it would become -1 and
        silently pick the last channel.
    """
    A = np.asarray(ss.A)
    B = np.asarray(ss.B)
    C = np.asarray(ss.C)
    D = np.asarray(ss.D)
    n = A.shape[0]

    # Convert from MATLAB (1-based) to Python (0-based) indexing.
    in_idx = int(np.asarray(In).item()) - 1
    out_idx = int(np.asarray(Out).item()) - 1
    if not 0 <= in_idx < B.shape[1]:
        raise IndexError(f'Input index {In} is outside 1..{B.shape[1]} (MATLAB indexing)')
    if not 0 <= out_idx < C.shape[0]:
        raise IndexError(f'Output index {Out} is outside 1..{C.shape[0]} (MATLAB indexing)')

    b_col = B[:, in_idx].reshape(n, 1)
    c_row = C[out_idx, :].reshape(1, n)
    d = D[out_idx, in_idx]
    identity = np.eye(n)

    # Solve the linear system at each frequency rather than forming the
    # explicit inverse of (1j*w*I - A).
    resp = [
        (c_row @ np.linalg.solve(1j * w * identity - A, b_col))[0, 0] + d
        for w in wlist
    ]

    return np.array(resp, dtype=complex)

def getGPSdatetime(timestring):
    """Convert a 'YYYY-MM-DD_hhmm' time string to GPS time.

    The string is interpreted as UTC. Only the first four characters of the
    hhmm field are used (two for the hour, two for the minute). The
    timezone-aware datetime is handed to tconvert and its result returned.
    """
    year, month, day_and_clock = timestring.split('-')
    day, clock = day_and_clock.split('_')

    fields = [int(token) for token in (year, month, day, clock[0:2], clock[2:4])]
    utc_date = datetime(*fields, tzinfo=timezone.utc)

    return tconvert(utc_date)

def getGainAtTime(chan_name,timestring):
    """Return the mean value of a channel just after a given time.

    Parameters
    ----------
    chan_name : str
        Full channel name; the text before the first ':' selects the NDS
        server ('H1' or 'L1').
    timestring : str
        Time in the 'YYYY-MM-DD_hhmm' format accepted by getGPSdatetime.

    Returns
    -------
    float
        Mean of the channel over the 0.5 s starting at the requested time.

    Raises
    ------
    ValueError
        If the channel prefix is not a known interferometer. This is checked
        before any data is requested.
    """
    nds_hosts = {
        'H1': 'nds.ligo-wa.caltech.edu',
        'L1': 'nds.ligo-la.caltech.edu',
    }
    ifo = chan_name.split(':')[0]
    if ifo not in nds_hosts:
        known = ', '.join(sorted(nds_hosts))
        raise ValueError(
            f'Unknown interferometer prefix {ifo!r} in channel {chan_name!r}; expected one of: {known}'
        )
    hostname = nds_hosts[ifo]

    gps = getGPSdatetime(timestring)
    duration = 0.5  # [s] length of data averaged to obtain the gain

    data = TimeSeries.get(chan_name, gps, gps + duration, allow_tape=True, host=hostname)

    return np.mean(data.value)
#%% 
# Plot defaults: thick lines, 8 in x 8 in figures, grid always on
plt.rcParams.update({
    'lines.linewidth': 2,
    'figure.figsize': [8, 8],
    'axes.grid': True,
})

# Coordinates used to locate the measurement data and the model
svnDir = '/home/controls/SusSVN/sus/trunk/'
ifo = 'H1'
isiName = 'HAM5'
isiStage = 'ST1'
susName = 'SR3'
measurement_datetime = '2025-05-07_0000'

# Suspension type codename (e.g. 'HLTS'); it selects the data folder below
susType = findSusType(susName)

# Both the models and the projection matrix master file live in the
# "ExportedModels" directory inside MatlabTools
toolsDir = f'{svnDir}Common/MatlabTools/'
importDir = f'{toolsDir}ExportedModels/'
modelDir = projDir = importDir

# Folder holding the transfer function measurements of this suspension
Data_folder = f'{svnDir}{susType}/{ifo}/{susName}/Common/Data/'
#%% 
# Load the OSEM2EUL matrix (inverted below to get EUL2OSEM) and the (ordered) dof names
# NOTE: The README gets jumbled up, but the projection matrices are kept intact thanks to simplify_cells
projection_struct = spio.loadmat(f'{projDir}SUS_projection_file', simplify_cells=True)
type_projections = projection_struct['SUS_projections'][susType]

# Stage of the suspension that we are calibrating: first entry of the 'Sensor' stage list
susStage = type_projections['susStages']['Sensor'][0]
stage_projections = type_projections[susStage]

#TODO: Deal with incomplete bases of OSEM/EUL, data will need to be imported per OSEM
#NOTE: The code assumes sensalign is the identity matrix
S_osem2eul = stage_projections['OSEM2EUL']
S_eul2osem = np.linalg.inv(S_osem2eul)

# Tuples rather than lists so the dof orderings cannot be mutated by accident
osemdofs = tuple(stage_projections['osemdof'])
euldofs = tuple(stage_projections['euldof'])
drivedofs = ('L', 'T', 'V')  # We always drive the translational euler degrees of freedom

# For each driven dof, find the OSEMs with a non-zero entry in its OSEM2EUL row
#TODO: Deal with not observing T or V
calibdofs = calibPairings(drivedofs, euldofs, osemdofs, S_osem2eul)

#%% 
# TODO: convert this into a function to import EUL data
# One exported file per ISI drive direction; between them they provide L T V R P Y
LY_filename =  f'{measurement_datetime}_{ifo}ISI{isiName}_{isiStage}_WhiteNoise_ISO_Y_0p05to40Hz_calibration.txt'
VRP_filename = f'{measurement_datetime}_{ifo}ISI{isiName}_{isiStage}_WhiteNoise_ISO_Z_0p05to40Hz_calibration.txt'
T_filename =   f'{measurement_datetime}_{ifo}ISI{isiName}_{isiStage}_WhiteNoise_ISO_X_0p05to40Hz_calibration.txt'

# Read each file as a numeric matrix
LY_mat = np.loadtxt(f'{Data_folder}{LY_filename}')
VRP_mat = np.loadtxt(f'{Data_folder}{VRP_filename}')
T_mat = np.loadtxt(f'{Data_folder}{T_filename}')

# Column layout of every file: column 0 is the frequency vector, followed by
# one (real, imaginary) column pair per dof, in the order listed here.
file_layout = (
    (LY_mat, ('L', 'Y')),
    (VRP_mat, ('V', 'R', 'P')),
    (T_mat, ('T',)),
)

# raw_data[dof] holds the complex transfer function of that dof and
# raw_data['freqs'][dof] the frequency vector it was measured on.
raw_data = {'freqs': {}}
for tf_mat, file_dofs in file_layout:
    for slot, dof in enumerate(file_dofs):
        real_col = 2 * slot + 1
        raw_data['freqs'][dof] = tf_mat[:, 0]
        raw_data[dof] = tf_mat[:, real_col] + 1j * tf_mat[:, real_col + 1]

#Calibration for the GS13s SUSPOINT outputs is always in nm.
calibration_to_um = 10**3  # To change the TF from [um/nm] to [um/um]
for dof in euldofs:
    raw_data[dof] = calibration_to_um * raw_data[dof]
    
#%%
# Project the Euler-basis transfer functions onto the individual OSEMs:
# L Y data into LF RT | V R P data into T1 T2 T3 data | T data into SD data

# OSEM_data[osem] is the complex transfer function seen by that OSEM and
# OSEM_data['freqs'][osem] is its frequency vector. The 'freqs' dictionary is
# created once, up front: creating it when an OSEM is first seen would discard
# the frequency vectors of every OSEM processed before it.
OSEM_data = {'freqs': {}}
for ii, outdof in enumerate(osemdofs):
    for jj, indof in enumerate(euldofs):
        weight = S_eul2osem[ii, jj]
        if weight == 0:
            continue  # this Euler dof does not project onto this OSEM

        if outdof not in OSEM_data:
            # First contribution to this OSEM: start from zero and take the
            # frequency vector from the first contributing Euler dof
            OSEM_data[outdof] = np.zeros_like(raw_data[indof])
            OSEM_data['freqs'][outdof] = raw_data['freqs'][indof]

        OSEM_data[outdof] += weight * raw_data[indof]
#%%
# Load the (undamped) model
model_struct = spio.loadmat(f'{modelDir}{susType}_model', simplify_cells=True)
model_ss = model_struct['susModel']

susModel = signal.StateSpace(model_ss['A'], model_ss['B'], model_ss['C'], model_ss['D'])

# Lookup tables giving the (MATLAB-indexed) input/output channel numbers
model_in = model_ss['in']
model_out = model_ss['out']

# Model frequency response for L T V, evaluated on the measurement frequencies
# NOTE(review): the output stage key 'm1' is hard-coded here; confirm it is
# the right key before using this with other suspension types.
modelResp = {}
for dof in drivedofs:
    in_dof = model_in['gnd']['disp'][dof]
    out_dof = model_out['m1']['disp'][dof]
    omega = 2 * np.pi * raw_data['freqs'][dof]
    # Subtract 1 because we want the response to OSEM motion
    modelResp[dof] = SISOfreqresp(susModel, out_dof, in_dof, omega) - 1

# Fit to a suitable window
fmin = 5  # [Hz] minimum frequency for fitting
fmax = 15  # [Hz] maximum frequency for fitting
calib = {}
calib_OSEM_data = {}

for indof in drivedofs:
    fit_freqs = raw_data['freqs'][indof]
    mask = (fit_freqs > fmin) & (fit_freqs < fmax)  # Select frequencies of interest
    model_mag = np.abs(modelResp[indof][mask])
    for outdof in calibdofs[indof]:
        osem_mag = np.abs(OSEM_data[outdof][mask])
        # Least-squares scale factor that maps the OSEM magnitude onto the model magnitude
        calib[outdof] = np.dot(osem_mag, model_mag) / np.dot(osem_mag, osem_mag)

        # Store the calibrated OSEM data for plotting
        calib_OSEM_data[outdof] = OSEM_data[outdof] * calib[outdof]

# Construct the multiplicative factor for the DAMP gains
# The factors are the diagonal terms of inv(O2E * Cal * pinv(O2E)).
# where no assumption is made about the number of OSEMs
calib_vec = np.diag([calib[dof] for dof in osemdofs])
calibrated_osem2eul = S_osem2eul.dot(calib_vec.dot(np.linalg.pinv(S_osem2eul)))
eul2eul_calib = np.linalg.inv(calibrated_osem2eul)

#%%
# Plot the results to check if it all makes sense
def _plot_isi_to_osem(calib_state, osem_tfs):
    """Draw one magnitude/phase figure per driven dof, model against OSEMs.

    calib_state : text inserted in the figure title ('BEFORE' or 'AFTER').
    osem_tfs    : dictionary of complex transfer functions keyed by OSEM name.

    Reads drivedofs, calibdofs, raw_data, modelResp and measurement_datetime
    from the script namespace. Both panels get the same minor grid, so the
    BEFORE and AFTER figures are formatted identically.
    """
    for indof in drivedofs:
        freq = raw_data['freqs'][indof]
        plt.figure()

        # AMPLITUDE
        plt.subplot(2, 1, 1)
        plt.loglog(freq, np.abs(modelResp[indof]), 'k')
        plt.title(f'ISI to OSEM transfer functions [{calib_state} calibration] \n {measurement_datetime}')
        legText = ['Model']
        for outdof in calibdofs[indof]:
            plt.loglog(freq, np.abs(osem_tfs[outdof]))
            legText.append(f'({outdof} OSEM)/(Suspoint {indof})')
        plt.ylabel('Amplitude [(OSEM m)/(GS13 m)]')
        plt.grid(which='minor', linestyle='-', alpha=0.5)
        plt.xlim([0.1, 20])
        plt.ylim([1e-4,1e2])
        plt.legend(legText, loc='lower right')

        # ANGLE
        plt.subplot(2, 1, 2)
        plt.semilogx(freq, cleanAngle(modelResp[indof]), 'k')
        for outdof in calibdofs[indof]:
            plt.semilogx(freq, cleanAngle(osem_tfs[outdof]))
        plt.grid(which='minor', linestyle='-', alpha=0.5)
        plt.xlabel('Frequency [Hz]')
        plt.ylabel('Angle [deg]')
        plt.xlim([0.1, 20])
        plt.ylim([-200,200])


# BEFORE plots: OSEM data as measured
_plot_isi_to_osem('BEFORE', OSEM_data)

# AFTER PLOTS (same as above but with calibrated data)
_plot_isi_to_osem('AFTER', calib_OSEM_data)
    
#%%
#Find the input filter numbers

# #TODO: Change this so it actually reads from the site
# ifo_2='S1'
# susName_2='PR3'
# prev_controllerGain={};
# next_controllerGain={};

# OSEMINF gains that were in place at the time of the measurement,
# rounded to 3 decimals
prev_sensInfGain = {
    dof: round(
        getGainAtTime(f'{ifo}:SUS-{susName}_{susStage}_OSEMINF_{dof}_GAIN', measurement_datetime),
        3)
    for dof in osemdofs
}

# Collect the NEW Input Filter gains: old (rounded) gain times the fitted calibration factor
next_sensInfGain = {
    dof: round(calib[dof] * old_gain, 3)
    for dof, old_gain in prev_sensInfGain.items()
}

# for ii, dof in enumerate(euldofs):
#     chan_name=f'{ifo_2}:SUS-{susName_2}_{susStage}_DAMP_{dof}_GAIN'
#     ezca.connect(chan_name)
#     prev_controllerGain[dof]=ezca.read(chan_name)
#     next_controllerGain[dof]=round(eul2eul_calib[ii,ii]*prev_controllerGain[dof],3)
    
#%%    
# Assemble the summary of the calibration as a list of text pieces and
# print it in one go at the end.
message=[]
message.append(f'We have estimated a OSEM calibration of {ifo} {susName} {susStage} using {isiName} {isiStage} drives from {measurement_datetime} (UTC).\n')
message.append(f'We fit the response {susStage}_DAMP/{isiName}_SUSPOINT between {fmin} and {fmax} Hz to get a calibration in [OSEM m]/[GS13 m]\n\n')
message.append(f'The {ifo}:SUS-{susName}_{susStage}_OSEMINF gains at the time of measurement were:\n')
for dof in osemdofs:
    message.append(f'(old) {dof}: {prev_sensInfGain[dof]:.3f} \n')

message.append(f'\nThe suggested (calibrated) {susStage} OSEMINF gains are\n')
for dof in osemdofs:
    # Same 3-decimal format as the old gains so the two lists line up
    message.append(f'(new) {dof}: {next_sensInfGain[dof]:.3f} \n')

message.append(f'\nTo compensate for the OSEM gain changes, we estimate that the {ifo}:SUS-{susName}_{susStage}_DAMP loops must be changed by a factors of: \n')
for ii, dof in enumerate(euldofs):
    # The DAMP factor of each euler dof is the matching diagonal term of eul2eul_calib
    gg=round(eul2eul_calib[ii,ii],3)
    message.append(f'{dof} gain = {gg:.3f} * (old {dof} gain)\n')

# __file__ is not defined when this cell is run in an interactive session,
# so fall back to a placeholder instead of failing with a NameError.
this_filename=globals().get('__file__', '<interactive session>').split('/')[-1]
message.append(f'\nThis message was generated automatically by {this_filename} on {datetime.now(timezone.utc)} UTC')
print(''.join(message))