# -*- coding: utf-8 -*-
"""
Created on Wed Mar 17 12:20:32 2021

@author: Maurice
"""
from helper_code import *
import numpy as np, os, sys, joblib
import scipy
from scipy.interpolate import interp1d

#import biosppy
import neurokit2 as nk

import traceback

def feature_names():
    """Return the comma-separated names of all features computed for one lead.

    The string consists of the seven statistic names produced by
    ``string_statistics`` for each of the thirteen per-beat series, followed
    by the four scalar features RMS, RMSSD, SDSD and pNN50.
    """
    series_labels = (
        'P_R', 'Q_T', 'S_T', 'R_T', 'Q_S',
        'RR', 'RR-d',
        'R/P', 'R/Q', 'R/T', 'R/S',
        'P_ON', 'T_OFF',
    )
    statistic_names = ','.join(string_statistics(label) for label in series_labels)
    return statistic_names + ', RMS, RMSSD, SDSD, pNN50'

def get_feature_names(indices):
    """Look up feature names by position in the full feature vector.

    The full name list starts with 'age' and 'sex', followed by every name
    from ``feature_names()`` prefixed with the lead name, for each lead in
    ``twelve_leads`` in order.

    Parameters
    ----------
    indices : iterable of int
        Positions into the full name list.

    Returns
    -------
    list of str
        The names found at the requested positions, in the given order.
    """
    lead_feature_names = [
        lead + feat_name
        for lead in twelve_leads
        for feat_name in feature_names().split(',')
    ]
    full_name_list = ['age', 'sex'] + lead_feature_names
    return [full_name_list[position] for position in indices]
    
    
def string_statistics(name):
    """Build the comma-separated statistic labels for the series ``name``.

    The result lists mean, median, std, max, min, skew and kurtosis, each
    wrapped around ``name``, and is padded with one leading and one trailing
    space.
    """
    stat_labels = ("mean", "median", "std", "max", "min", "skew", "kurtosis")
    joined = ", ".join(label + "(" + name + ")" for label in stat_labels)
    return " " + joined + " "

def compute_statistical_features(ecg_recording,sampling_rate,lead_names,resampling_freq=250):
    '''
    Compute statistical features for every lead of an ECG recording with neurokit2.

    Each lead is cleaned, linearly resampled to ``resampling_freq``, R-peaks
    are detected and the P/Q/S/T fiducial points are delineated. Seven
    statistics (see ``get_statistics``) are computed for each of 13 per-beat
    series, and the signal RMS plus the HRV values RMSSD, SDSD and pNN50 are
    appended.

    Parameters
    ----------
    ecg_recording : ndarray
        ECG recordings, indexed as ``ecg_recording[lead_idx, :]``
    sampling_rate : float
        sampling rate of ``ecg_recording``
    lead_names : list
        list of lead names; only its length is used, rows of
        ``ecg_recording`` are taken in the same order
    resampling_freq : float, optional
        rate the cleaned signal is resampled to before peak detection

    Returns
    -------
    numpy array
        flattened feature vector of length ``len(lead_names) * 95``
    numpy array
        start offset of each lead's features inside the flattened vector
        (``len(lead_names) + 1`` entries)

    If anything fails for any lead, the whole recording is skipped: the
    feature vector is filled with -1 and the offsets are all zero.

    '''
    n_leads=len(lead_names)
    n_features= 13*7+1+3 # 13 time&peak features per segment times 7 statistics, 1 RMS, 3 HRV time features
    features = np.zeros((n_leads,n_features))
    try:
        # for each lead separately
        #TODO use information from I and II lead for segmentation of others
        for lead_idx in range(n_leads):

            # preprocess ecg signal (alternative cleaning method: 'pantompkins1985')
            clean_ecg_signal = nk.ecg_clean(ecg_recording[lead_idx,:],sampling_rate,method='neurokit')

            # resample signal by linear interpolation onto a grid with spacing 1/resampling_freq
            t_orig = np.arange((clean_ecg_signal.shape[0]))/sampling_rate
            t_res = np.arange(t_orig[0],t_orig[-1],1/resampling_freq)
            resample = interp1d(t_orig,clean_ecg_signal)
            clean_ecg_signal_r = resample(t_res)

            # find R-peaks / QRS-complexes
            r_peaks = nk.ecg_findpeaks(clean_ecg_signal_r,resampling_freq,method='neurokit',show=False)['ECG_R_Peaks']

            # find P wave, T wave and the Q/S peaks around every R-peak
            _, signals_delin = nk.ecg_delineate(clean_ecg_signal_r,rpeaks=r_peaks,sampling_rate=resampling_freq,method='peaks',show=False)

            #TODO check if PQRST found correctly

            p_onsets = np.array(signals_delin['ECG_P_Onsets'])
            p_peaks = np.array(signals_delin['ECG_P_Peaks'])
            q_peaks  = np.array(signals_delin['ECG_Q_Peaks'])
            s_peaks  = np.array(signals_delin['ECG_S_Peaks'])
            t_offsets = np.array(signals_delin['ECG_T_Offsets'])
            t_peaks = np.array(signals_delin['ECG_T_Peaks'])
            # maybe leave out first and last beat (qrs-detectors like pan-tompkins need 1 or 2 beats to tune)
            # a beat is usable only if every fiducial point was found (non-finite marks a missing one)
            beat_is_ok = np.all(np.array((np.isfinite(p_onsets), np.isfinite(p_peaks), np.isfinite(q_peaks), np.isfinite(s_peaks), np.isfinite(t_offsets), np.isfinite(t_peaks))),0)

            # diagnostic only: a mismatch makes the boolean indexing below fail,
            # which skips the recording through the outer handler
            if not (len(r_peaks) == len(p_peaks) == len(q_peaks) == len(s_peaks) == len(t_peaks)):
                print("PQRST differ in length: ")
                print(len(r_peaks), len(p_peaks),len(q_peaks),len(s_peaks), len(t_peaks))

            hrv_feats = nk.hrv_time(r_peaks, sampling_rate=resampling_freq, show=False)    # time hrv features

            # presumably hrv_time returns a one-row table (confirm against neurokit2 docs);
            # take the single value of each column so plain scalars enter the feature vector
            RMSSD = np.ravel(hrv_feats['HRV_RMSSD'])[0]
            SDSD = np.ravel(hrv_feats['HRV_SDSD'])[0]
            pNN50 = np.ravel(hrv_feats['HRV_pNN50'])[0]

            # peak-to-peak distances of the usable beats, sample index differences divided by resampling_freq
            pr_p2p = (r_peaks[beat_is_ok] - p_peaks[beat_is_ok])/resampling_freq
            qt_p2p = (t_peaks[beat_is_ok] - q_peaks[beat_is_ok])/resampling_freq
            st_p2p = (t_peaks[beat_is_ok] - s_peaks[beat_is_ok])/resampling_freq
            rt_p2p = (t_peaks[beat_is_ok] - r_peaks[beat_is_ok])/resampling_freq
            qs_p2p = (s_peaks[beat_is_ok] - q_peaks[beat_is_ok])/resampling_freq

            rr = np.diff(r_peaks)/resampling_freq
            # NOTE(review): rr is already divided by resampling_freq, so this second
            # division looks unintended; kept unchanged so feature values stay
            # comparable with previously extracted features
            rr_diff = np.diff(rr)/resampling_freq

            #TODO normalize signal and use absolute values
            # delineated peaks are float arrays (missing points are non-finite), so the
            # usable ones are cast to built-in int before indexing (np.int no longer exists)
            rp_ratio = clean_ecg_signal_r[r_peaks[beat_is_ok]]/clean_ecg_signal_r[p_peaks[beat_is_ok].astype(int)]
            rq_ratio = clean_ecg_signal_r[r_peaks[beat_is_ok]]/clean_ecg_signal_r[q_peaks[beat_is_ok].astype(int)]
            rt_ratio = clean_ecg_signal_r[r_peaks[beat_is_ok]]/clean_ecg_signal_r[t_peaks[beat_is_ok].astype(int)]
            rs_ratio = clean_ecg_signal_r[r_peaks[beat_is_ok]]/clean_ecg_signal_r[s_peaks[beat_is_ok].astype(int)]

            # root mean square of the resampled, cleaned signal
            RMS = np.sqrt(np.sum(clean_ecg_signal_r**2) / np.size(clean_ecg_signal_r))

            # use statistics of features; the order must match feature_names()
            features[lead_idx,:] = np.concatenate((get_statistics(pr_p2p),get_statistics(qt_p2p),get_statistics(st_p2p),
                                                   get_statistics(rt_p2p),get_statistics(qs_p2p),get_statistics(rr),get_statistics(rr_diff),get_statistics(rp_ratio),
                                                   get_statistics(rq_ratio),get_statistics(rt_ratio),get_statistics(rs_ratio),get_statistics(p_onsets),get_statistics(t_offsets),np.array((RMS, RMSSD, SDSD, pNN50))),axis=0)
        return features.flatten() , np.arange(n_leads+1)*n_features
    except Exception as e:
        # best effort: a failure in any lead discards the whole recording, but the
        # cause is reported so systematic errors do not go unnoticed
        print('skipped recording')
        print(repr(e))
        return np.ones((n_leads,n_features)).flatten()*-1 , np.zeros(n_leads+1)



def get_statistics(values):
    """Summarise ``values`` along axis 0 with seven statistics.

    Returns a 7-tuple of (mean, median, std, max, min, skew, kurtosis).
    The first five use numpy's NaN-ignoring reductions; skew and kurtosis
    use scipy with ``nan_policy='omit'``. If any of the computations raises,
    a tuple of seven -1 values is returned instead.
    """
    try:
        summary = (
            np.nanmean(values, axis=0),
            np.nanmedian(values, axis=0),
            np.nanstd(values, axis=0),
            np.nanmax(values, axis=0),
            np.nanmin(values, axis=0),
            scipy.stats.skew(values, axis=0, nan_policy='omit'),
            scipy.stats.kurtosis(values, axis=0, nan_policy='omit'),
        )
    except Exception:
        # best-effort fallback: one -1 per statistic
        return (-1,) * 7
    return summary


# Extract features from the header and recording.
def get_features_simple(header, recording, leads):
    """Extract age, sex and the per-lead RMS from a header and recording.

    Parameters
    ----------
    header :
        Header passed unchanged to ``get_age``, ``get_sex``, ``get_leads``,
        ``get_adcgains`` and ``get_baselines`` (helpers not defined in this
        file; presumably provided by ``helper_code``).
    recording : ndarray
        Signal array indexed as ``recording[lead, sample]`` with rows in the
        order reported by ``get_leads(header)``.
    leads : sequence
        Leads to select, in the desired output order.

    Returns
    -------
    age
        Value from ``get_age``, or NaN when that is ``None``.
    sex
        0 for female, 1 for male, NaN otherwise.
    rms : ndarray of float32
        Root mean square of every selected lead after removing the baseline
        and dividing by the ADC gain.

    Raises
    ------
    ValueError
        If a requested lead is not among the available leads.
    """
    # Extract age.
    age = get_age(header)
    if age is None:
        age = float('nan')

    # Extract sex. Encode as 0 for female, 1 for male, and NaN for other.
    sex = get_sex(header)
    if sex in ('Female', 'female', 'F', 'f'):
        sex = 0
    elif sex in ('Male', 'male', 'M', 'm'):
        sex = 1
    else:
        sex = float('nan')

    # Reorder/reselect leads in recordings (list indexing yields a copy, so
    # the caller's array is not modified by the scaling below).
    available_leads = get_leads(header)
    indices = [available_leads.index(lead) for lead in leads]
    recording = recording[indices, :]

    # The scaling below writes back into ``recording``; with an integer dtype
    # the scaled values would be truncated and the squares could overflow.
    # Floating-point input is left untouched.
    if not np.issubdtype(recording.dtype, np.floating):
        recording = recording.astype(np.float64)

    # Pre-process recordings.
    adc_gains = get_adcgains(header, leads)
    baselines = get_baselines(header, leads)
    num_leads = len(leads)
    for i in range(num_leads):
        recording[i, :] = (recording[i, :] - baselines[i]) / adc_gains[i]

    # Compute the root mean square of each ECG lead signal.
    rms = np.zeros(num_leads, dtype=np.float32)
    for i in range(num_leads):
        x = recording[i, :]
        rms[i] = np.sqrt(np.sum(x**2) / np.size(x))

    return age, sex, rms