#!/usr/bin/env python

# Edit this script to add your team's code. Some functions are *required*, but you can edit most parts of the required functions,
# change or remove non-required functions, and add your own functions.

################################################################################
#
# Import libraries and functions. You can change or remove them.
#
################################################################################

from helper_code import *
import numpy as np, scipy as sp, scipy.stats, os, sys, joblib
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import normalize
from sklearn.model_selection import KFold
from sklearn.linear_model import LogisticRegression
import pickle as pk
from scipy import signal
from scipy.stats import multivariate_normal
from scipy.io import loadmat
from scipy.fft import rfft, rfftfreq
import pywt
from functools import partial
from multiprocessing import Pool

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.models import Sequential, Model
from tensorflow.python.keras.models import Input
from tensorflow.keras.layers import Dense, Dropout, Flatten, BatchNormalization, LeakyReLU, LSTM, concatenate, Conv2D, MaxPooling2D
from tensorflow.keras.optimizers import Adam

import matplotlib.pyplot as plt

################################################################################
#
# Required functions. Edit these functions to add your code, but do not change the arguments.
#
################################################################################

# Train your model.
def train_challenge_model(data_folder, model_folder, verbose):
    """Extract features for every patient, train both classifiers and save them.

    For each patient file found in ``data_folder`` this builds:
      * a 336 * 4 crafted-feature vector (one 336-wide slot per AV/PV/TV/MV
        location; zeros where a location has no recording or no features),
      * a spectrogram tensor from get_cnn_features (or get_rnn_features),
      * one-hot murmur (Present/Unknown/Absent) and outcome (Abnormal/Normal)
        labels (all zeros if the label is not one of those classes).
    Two networks are then trained with run_parent_net and written to
    ``model_folder`` via save_challenge_model.

    Args:
        data_folder: directory holding the patient files and recordings.
        model_folder: output directory; created if it does not exist.
        verbose: progress messages are printed when >= 1.

    Raises:
        Exception: if no patient files are found in ``data_folder``.
    """
    if verbose >= 1:
        print('Finding data files...')

    patient_files = find_patient_files(data_folder)
    if len(patient_files) == 0:
        raise Exception('No data was provided.')

    os.makedirs(model_folder, exist_ok=True)

    murmur_classes = ['Present', 'Unknown', 'Absent']
    outcome_classes = ['Abnormal', 'Normal']

    if verbose >= 1:
        print('Extracting features and labels from the Challenge data...')
    # Process-wide setting: silence numpy divide-by-zero warnings (presumably
    # triggered inside the feature extractors).
    np.seterr(divide='ignore')

    alg = 'cnn'
    location_tags = ('AV', 'PV', 'TV', 'MV')  # slot order in the crafted vector
    n_crafted = 336  # crafted features per recording location

#%% Get all features for parent neural net
    spectrograms = list()
    crafted_features = list()
    murmurs = list()
    outcomes = list()

    for patient_file in patient_files:
        current_patient_data = load_patient_data(patient_file)
        current_recordings = load_recordings(data_folder, current_patient_data)

        # Crafted features: one fixed-width slot per known location.
        current_crafted_features = np.zeros(n_crafted * len(location_tags))
        for j, location in enumerate(get_locations(current_patient_data)):
            location_ind = next(
                (k for k, tag in enumerate(location_tags) if tag in location), None)
            if location_ind is None:
                # No slot for this location. Without this guard the code either
                # raised UnboundLocalError (first recording) or silently wrote
                # into the previous recording's slot.
                continue
            current_features = get_crafted_features(current_patient_data, current_recordings[j])
            if current_features is not None:
                start = n_crafted * location_ind
                current_crafted_features[start:start + n_crafted] = current_features
        crafted_features.append(current_crafted_features)

        # Spectrogram features for the network's image branch.
        if alg == 'cnn':
            spectrograms.append(get_cnn_features(current_patient_data, current_recordings))
        elif alg == 'rnn':
            spectrograms.append(get_rnn_features(current_patient_data, current_recordings))

        # One-hot labels (all zeros when the label is not a known class).
        current_murmur = np.zeros(len(murmur_classes), dtype=int)
        murmur = get_murmur(current_patient_data)
        if murmur in murmur_classes:
            current_murmur[murmur_classes.index(murmur)] = 1
        murmurs.append(current_murmur)

        current_outcome = np.zeros(len(outcome_classes), dtype=int)
        outcome = get_outcome(current_patient_data)
        if outcome in outcome_classes:
            current_outcome[outcome_classes.index(outcome)] = 1
        outcomes.append(current_outcome)

    # Stack once at the end: concatenating inside the loop was quadratic.
    features = np.stack(spectrograms)  # batch axis first
    crafted_features = np.vstack(crafted_features)
    murmurs = np.vstack(murmurs)
    outcomes = np.vstack(outcomes)

    if verbose >= 1:
        print('Done.')

#%% Train model
    # cross_validate(features, crafted_features, labels) can be run here to
    # estimate performance before the final fits.
    murmur_classifier, _ = run_parent_net(features, crafted_features, murmurs)
    outcome_classifier, _ = run_parent_net(features, crafted_features, outcomes)

    imputer = list()  # no imputation is applied; stored so the saved model has an 'imputer' entry
    save_challenge_model(model_folder, imputer, murmur_classes, murmur_classifier, outcome_classes, outcome_classifier)

#%% Load your trained model. This function is *required*. You should edit this function to add your code, but do *not* change the arguments of this function.

def load_challenge_model(model_folder, verbose):
    """Load a model directory previously written by save_challenge_model.

    Args:
        model_folder: directory containing 'model.sav' and the two saved
            Keras models 'murmur_classifier' and 'outcome_classifier'.
        verbose: unused.

    Returns:
        dict: the joblib-loaded metadata dict from 'model.sav', with the
        loaded Keras models added under 'murmur_classifier' and
        'outcome_classifier'.
    """
    model = joblib.load(os.path.join(model_folder, 'model.sav'))
    for key in ('murmur_classifier', 'outcome_classifier'):
        # Each classifier was saved to a sub-path named after its dict key.
        model[key] = tf.keras.models.load_model(os.path.join(model_folder, key))
    return model

#%% Run your trained model. This function is *required*. You should edit this function to add your code, but do *not* change the
# arguments of this function.
def run_challenge_model(model, data, recordings, verbose):
    """Classify one patient with the loaded murmur and outcome networks.

    Args:
        model: dict returned by load_challenge_model (needs 'murmur_classes',
            'murmur_classifier', 'outcome_classes', 'outcome_classifier').
        data: the patient's metadata as loaded by the helper code.
        recordings: the patient's recordings, in the order of get_locations(data).
        verbose: unused.

    Returns:
        (classes, labels, probabilities): the murmur classes followed by the
        outcome classes, a one-hot label vector with one 1 per task (argmax),
        and the concatenated float32 class probabilities.
    """
    murmur_classes = model['murmur_classes']
    murmur_classifier = model['murmur_classifier']
    outcome_classes = model['outcome_classes']
    outcome_classifier = model['outcome_classifier']

    # Crafted features: one 336-wide slot per location, built exactly as in training.
    location_tags = ('AV', 'PV', 'TV', 'MV')
    n_crafted = 336
    crafted_features = np.zeros(n_crafted * len(location_tags))
    for j, location in enumerate(get_locations(data)):
        location_ind = next(
            (k for k, tag in enumerate(location_tags) if tag in location), None)
        if location_ind is None:
            # Unknown location: no slot. (Previously UnboundLocalError or a
            # silent overwrite of the previous recording's slot.)
            continue
        current_features = get_crafted_features(data, recordings[j])
        if current_features is not None:
            start = n_crafted * location_ind
            crafted_features[start:start + n_crafted] = current_features

    # Both inputs get a leading batch axis of size 1.
    features = np.expand_dims(get_cnn_features(data, recordings), axis=0)
    crafted_features = np.expand_dims(crafted_features, axis=0)
    inputs = [features, crafted_features]

    murmur_probabilities = np.asarray(murmur_classifier.predict(inputs), dtype=np.float32)[0]
    outcome_probabilities = np.asarray(outcome_classifier.predict(inputs), dtype=np.float32)[0]

    # Choose the label with the highest probability for each task.
    murmur_labels = np.zeros(len(murmur_classes), dtype=np.int_)
    murmur_labels[np.argmax(murmur_probabilities)] = 1
    outcome_labels = np.zeros(len(outcome_classes), dtype=np.int_)
    outcome_labels[np.argmax(outcome_probabilities)] = 1

    classes = murmur_classes + outcome_classes
    labels = np.concatenate((murmur_labels, outcome_labels))
    probabilities = np.concatenate((murmur_probabilities, outcome_probabilities))

    return classes, labels, probabilities


#%%###############################################################################
#
# Optional functions. You can change or remove these functions and/or add new functions.
#
################################################################################

# Save your trained model.
def save_challenge_model(model_folder, imputer, murmur_classes, murmur_classifier, outcome_classes, outcome_classifier):
    """Persist the trained classifiers and their metadata to ``model_folder``.

    The imputer and class lists go into 'model.sav' (joblib, protocol 0);
    each Keras classifier is saved with its own ``save`` method under
    'murmur_classifier' / 'outcome_classifier'.
    """
    metadata = dict(imputer=imputer,
                    murmur_classes=murmur_classes,
                    outcome_classes=outcome_classes)
    joblib.dump(metadata, os.path.join(model_folder, 'model.sav'), protocol=0)

    named_classifiers = (('murmur_classifier', murmur_classifier),
                         ('outcome_classifier', outcome_classifier))
    for name, classifier in named_classifiers:
        classifier.save(os.path.join(model_folder, name))

    
def downsample_recording(x, n):
    """Downsample ``x`` by averaging non-overlapping blocks of ``n`` samples.

    Trailing samples that do not fill a complete block are discarded.

    Args:
        x: 1-D array-like signal (lists are accepted).
        n: integer decimation factor, must be >= 1.

    Returns:
        ndarray of length ``len(x) // n`` holding the block means.

    Raises:
        ValueError: if ``n`` < 1, e.g. when the recording's sampling rate is
            below the target rate so the integer factor rounds down to 0.
    """
    if n < 1:
        raise ValueError(f'downsampling factor must be >= 1, got {n}')
    x = np.asarray(x)
    end = n * (len(x) // n)
    return np.mean(x[:end].reshape(-1, n), axis=1)


# Extract features for CNN
def get_cnn_features(data, recordings):
    """Build the 4-channel spectrogram tensor used as the CNN input.

    Each recording goes through these steps:
      * downsampled to 1 kHz (block averaging);
      * band-pass filtered and despiked by preprocess_data;
      * cut or zero-padded to 64512 samples;
      * converted with scipy.signal.spectrogram (nperseg=256, noverlap=32).
    The result is a 129 x 287 image. It is written into channel 0/1/2/3
    when the location string contains 'AV'/'PV'/'TV'/'MV' (checked in that
    order). Recordings at any other location are skipped. Channels without
    a recording stay zero. If several recordings share a location, the
    last one wins.

    Returns:
        float32 ndarray of shape (129, 287, 4).
    """
    locations = get_locations(data)
    fs = get_frequency(data)
    fs_resamp = 1000
    n_samples = 64512         # fixed length after resampling (64.512 s at 1 kHz)
    nperseg = 256
    noverlap = nperseg // 8   # 32 samples

    n_freq_samples = nperseg // 2 + 1                                  # 129
    n_time_samples = (n_samples - noverlap) // (nperseg - noverlap)    # 287
    spect_current_patient = np.zeros((n_freq_samples, n_time_samples, 4))

    location_tags = ('AV', 'PV', 'TV', 'MV')
    for i, location in enumerate(locations):
        location_ind = next(
            (k for k, tag in enumerate(location_tags) if tag in location), None)
        if location_ind is None:
            # No channel for this location. Previously this raised
            # UnboundLocalError or overwrote the previous recording's channel.
            continue

        x = downsample_recording(recordings[i], int(fs / fs_resamp))
        x = preprocess_data(x, fs_resamp)  # band-pass + spike removal

        # Truncate or zero-pad to exactly n_samples so every image has the same width.
        if len(x) >= n_samples:
            x = x[:n_samples]
        else:
            x = np.pad(x, (0, n_samples - len(x)), mode='constant', constant_values=0)

        _, _, spectro = signal.spectrogram(x, fs=fs_resamp, nperseg=nperseg, noverlap=noverlap)
        spect_current_patient[:, :, location_ind] = spectro

    return np.asarray(spect_current_patient, dtype=np.float32)

# Extract features for RNN
def get_rnn_features(data, recordings):
    """Stack the four per-location spectrograms vertically for the RNN model.

    Each recording is handled as follows:
      * only downsampled to 1 kHz (no band-pass or spike removal);
      * cut or zero-padded to 64.512 s;
      * converted with scipy.signal.spectrogram (nperseg=256, noverlap=32).
    Rows 0-128, 129-257, 258-386 and 387-515 receive the spectrogram of the
    recording whose location contains 'AV', 'PV', 'TV' or 'MV'
    respectively. Other locations are computed but not stored. If several
    recordings share a location, the last one wins.

    Returns:
        float32 ndarray of shape (516, 287).
    """
    site_order = ('AV', 'PV', 'TV', 'MV')
    sites = get_locations(data)
    source_fs = get_frequency(data)
    target_fs = 1000
    clip_seconds = 64512 / target_fs
    target_len = clip_seconds * target_fs

    seg_len = 256
    seg_overlap = seg_len / 8
    rows_per_site = int((seg_len / 2 + 1))
    n_frames = int((target_len - seg_overlap) / (seg_len - seg_overlap))
    stacked = np.zeros((int(rows_per_site * len(site_order)), n_frames))

    for rec_idx, site in enumerate(sites):
        # Block-average down to 1 kHz.
        clip = downsample_recording(recordings[rec_idx], int(source_fs / target_fs))

        # Zero-pad short clips, truncate long ones.
        if len(clip) < target_len:
            clip = np.pad(clip, (0, int(target_len - len(clip))), mode='constant', constant_values=0)
        else:
            clip = clip[:int(target_len)]

        _, _, spec = signal.spectrogram(clip, fs=target_fs, nperseg=seg_len, noverlap=seg_overlap)

        slot = next((k for k, tag in enumerate(site_order) if tag in site), None)
        if slot is not None:
            top = slot * rows_per_site
            stacked[top:top + rows_per_site, :n_frames] = spec

    return np.asarray(stacked, dtype=np.float32)

def screening_cost(labels, outputs):
    """Mean screening cost of integer-coded murmur decisions, as a tensor expression.

    ``labels`` (ground truth) and ``outputs`` (algorithm decisions) are cast to
    float32. Both are expected to hold class indices 0, 1 or 2, read here as
    positive / unknown / negative. This presumably matches the order of
    ``murmur_classes`` (Present, Unknown, Absent) — TODO confirm against the
    caller.

    Counts are named n_<output><label>. For example, n_np means "algorithm
    says negative, patient is positive". That is the case the
    diagnostic-error term c_error * (n_np + alpha * n_nu) must penalise, and
    alpha (the fraction of unknown *cases* that are positive) must scale the
    true-unknown counts.

    Returns:
        Scalar tensor: total cost divided by the number of samples (NaN for
        empty input).
    """
    outputs = tf.cast(outputs, tf.float32)
    labels = tf.cast(labels, tf.float32)
    # Define costs. Better to load these costs from an external file instead of defining them here.
    c_algorithm  =     1 # Cost for algorithmic prescreening.
    c_gp         =   250 # Cost for screening from a general practitioner (GP).
    c_specialist =   500 # Cost for screening from a specialist.
    c_treatment  =  1000 # Cost for treatment.
    c_error      = 10000 # Cost for diagnostic error.
    alpha        =   0.5 # Fraction of murmur unknown cases that are positive.

    def class_indicators(v):
        # Lagrange basis polynomials on the nodes {0, 1, 2}: each one is exactly
        # 1 at its own node and 0 at the other two, i.e. a one-hot encoding
        # written with tensor arithmetic. (The previous polynomials evaluated to
        # -2, -1 and 2, which scaled and sign-flipped the counts.)
        return ((v - 1.0) * (v - 2.0) / 2.0,   # positive (0)
                v * (2.0 - v),                 # unknown  (1)
                v * (v - 1.0) / 2.0)           # negative (2)

    out_p, out_u, out_n = class_indicators(outputs)
    lab_p, lab_u, lab_n = class_indicators(labels)

    n_pp = K.sum(out_p * lab_p)
    n_pu = K.sum(out_p * lab_u)
    n_pn = K.sum(out_p * lab_n)
    n_up = K.sum(out_u * lab_p)
    n_uu = K.sum(out_u * lab_u)
    n_un = K.sum(out_u * lab_n)
    n_np = K.sum(out_n * lab_p)
    n_nu = K.sum(out_n * lab_u)
    n_nn = K.sum(out_n * lab_n)

    n_total = n_pp + n_pu + n_pn + n_up + n_uu + n_un + n_np + n_nu + n_nn

    total_score = c_algorithm * n_total \
        + c_gp * (n_pp + n_pu + n_pn) \
        + c_specialist * (n_pu + n_up + n_uu + n_un) \
        + c_treatment * (n_pp + alpha * n_pu + n_up + alpha * n_uu) \
        + c_error * (n_np + alpha * n_nu)
    mean_score = total_score / n_total

    return mean_score
    
def cross_validate(train_x_all, crafted_features, train_y_all, split_size=6, alg='cnn'):
    """K-fold cross-validate run_parent_net.

    The data is shuffled with a fixed seed (random_state=1) and split into
    ``split_size`` folds. For each fold a fresh network is trained on the
    remaining folds, with the held-out fold passed as validation data.

    Args:
        train_x_all: spectrogram features, indexed by sample on axis 0.
        crafted_features: crafted features, indexed by sample on axis 0.
        train_y_all: one-hot labels, indexed by sample on axis 0.
        split_size: number of folds.
        alg: accepted for compatibility; not used.

    Returns:
        list: one Keras History object per fold.
    """
    folds = KFold(n_splits=split_size, shuffle=True, random_state=1)
    histories = []
    for fit_idx, holdout_idx in folds.split(train_x_all, train_y_all):
        _, fold_history = run_parent_net(
            train_x_all[fit_idx],
            crafted_features[fit_idx],
            train_y_all[fit_idx],
            train_x_all[holdout_idx],
            crafted_features[holdout_idx],
            train_y_all[holdout_idx],
        )
        histories.append(fold_history)
    return histories

def _build_parent_net(cnn_shape, crafted_shape, num_classes):
    """Return the uncompiled two-branch model (spectrogram CNN + crafted-feature MLP)."""
    # tf.keras.Input rather than the private tensorflow.python.keras Input, so the
    # input tensors come from the same Keras implementation as the layers.
    cnn_input = tf.keras.Input(shape=cnn_shape)
    cnn_output = Sequential([
        Conv2D(24, kernel_size=(3, 3), activation='linear', padding='same'),
        BatchNormalization(),
        LeakyReLU(alpha=0.1),
        MaxPooling2D((2, 2), padding='same'),
        Dropout(0.25),
        Conv2D(48, (3, 3), activation='linear', padding='same'),
        BatchNormalization(),
        LeakyReLU(alpha=0.1),
        MaxPooling2D(pool_size=(2, 2), padding='same'),
        Dropout(0.25),
        Conv2D(96, (3, 3), activation='linear', padding='same'),
        BatchNormalization(),
        LeakyReLU(alpha=0.1),
        MaxPooling2D(pool_size=(2, 2), padding='same'),
        Dropout(0.25),
        Flatten(),
        Dense(96, activation='relu'),
        BatchNormalization(),
        Dense(48, activation='relu'),
    ])(cnn_input)

    crafted_features_input = tf.keras.Input(shape=crafted_shape)
    crafted_features_output = Dense(256, activation='relu')(crafted_features_input)
    crafted_features_output = Dense(48, activation='relu')(crafted_features_output)

    lubdub_output = concatenate([cnn_output, crafted_features_output])
    lubdub_output = Dense(24, activation='relu')(lubdub_output)
    lubdub_output = Dense(num_classes, activation='softmax')(lubdub_output)

    return Model([cnn_input, crafted_features_input], outputs=lubdub_output)


def run_parent_net(features, crafted_features, labels, val_x=None, val_crafted_features=None, val_y=None,
                   class_weight=None):
    """Build, compile and train the two-branch classifier.

    Args:
        features: spectrogram batch with the sample axis first. The per-sample
            shape sets the CNN input (get_cnn_features produces (129, 287, 4)).
        crafted_features: crafted-feature matrix, shape (N, D).
        labels: one-hot targets, shape (N, num_classes).
        val_x, val_crafted_features, val_y: optional validation set, used only
            when ``val_x`` is not None.
        class_weight: optional {class_index: weight}. Defaults to weights
            5 / 3 / 1 for classes 0 / 1 / 2 and 1 for any further class,
            with keys only for classes present in ``labels``.

    Returns:
        (model, history): the trained Keras model and its History callback.
    """
    batch_size = 64
    epochs = 40
    num_classes = np.size(labels, 1)
    lr = 0.1

    if class_weight is None:
        base_weights = {0: 5, 1: 3, 2: 1}
        class_weight = {k: base_weights.get(k, 1) for k in range(num_classes)}

    lubdub_model = _build_parent_net(tuple(np.shape(features)[1:]),
                                     tuple(np.shape(crafted_features)[1:]),
                                     num_classes)

    # Same schedule as the legacy Adam(decay=lr/epochs), i.e.
    # lr / (1 + decay * iteration). Newer TF optimizers reject the `decay` keyword.
    schedule = tf.keras.optimizers.schedules.InverseTimeDecay(lr, decay_steps=1, decay_rate=lr / epochs)
    lubdub_model.compile(loss=tf.keras.losses.categorical_crossentropy,
                         optimizer=Adam(learning_rate=schedule),
                         metrics=['accuracy'])

    validation_data = None if val_x is None else ([val_x, val_crafted_features], val_y)
    lubdub_model.fit([features, crafted_features], labels, batch_size=batch_size, epochs=epochs,
                     class_weight=class_weight, verbose=1, validation_data=validation_data)

    return lubdub_model, lubdub_model.history


def run_cnn(features, labels, val_x=None, val_y=None):
    """Build, compile and train a spectrogram-only CNN classifier.

    Args:
        features: spectrogram batch with the sample axis first. The per-sample
            shape sets the input layer (get_cnn_features produces (129, 287, 4)).
        labels: one-hot targets, shape (N, num_classes).
        val_x, val_y: optional validation set, used only when ``val_x`` is not None.

    Returns:
        (model, history): the trained Keras model and its History callback.
    """
    batch_size = 64
    epochs = 20
    num_classes = np.size(labels, 1)
    lr = 0.1

    lubdub_model = Sequential([
        Conv2D(24, kernel_size=(3, 3), activation='linear',
               input_shape=tuple(np.shape(features)[1:]), padding='same'),
        BatchNormalization(),
        LeakyReLU(alpha=0.1),
        MaxPooling2D((2, 2), padding='same'),
        Dropout(0.25),
        Conv2D(48, (3, 3), activation='linear', padding='same'),
        BatchNormalization(),
        LeakyReLU(alpha=0.1),
        MaxPooling2D(pool_size=(2, 2), padding='same'),
        Dropout(0.25),
        Conv2D(96, (3, 3), activation='linear', padding='same'),
        BatchNormalization(),
        LeakyReLU(alpha=0.1),
        MaxPooling2D(pool_size=(2, 2), padding='same'),
        Dropout(0.25),
        Flatten(),
        Dense(96, activation='linear'),
        BatchNormalization(),
        LeakyReLU(alpha=0.1),
        Dense(num_classes, activation='softmax'),
    ])

    # Same schedule as the legacy Adam(decay=lr/epochs); newer TF optimizers
    # reject the `decay` keyword.
    schedule = tf.keras.optimizers.schedules.InverseTimeDecay(lr, decay_steps=1, decay_rate=lr / epochs)
    lubdub_model.compile(loss=tf.keras.losses.categorical_crossentropy,
                         optimizer=Adam(learning_rate=schedule),
                         metrics=['accuracy'])

    # Weights 5 / 3 / 1 for classes 0 / 1 / 2, 1 for any further class; keys only for existing classes.
    base_weights = {0: 5, 1: 3, 2: 1}
    class_weight = {k: base_weights.get(k, 1) for k in range(num_classes)}

    validation_data = None if val_x is None else (val_x, val_y)
    lubdub_model.fit(features, labels, batch_size=batch_size, class_weight=class_weight,
                     epochs=epochs, verbose=1, validation_data=validation_data)

    return lubdub_model, lubdub_model.history


def run_rnn(features, labels, val_x=None, val_y=None):
    """Build, compile and train the LSTM classifier on stacked spectrograms.

    Args:
        features: batch with the sample axis first. The per-sample shape sets
            the LSTM input (get_rnn_features produces (516, 287), i.e. 516
            steps of 287 values).
        labels: one-hot targets, shape (N, num_classes).
        val_x, val_y: optional validation set, used only when ``val_x`` is not None.

    Returns:
        (model, history): the trained Keras model and its History callback.
    """
    batch_size = 64
    epochs = 10
    num_classes = np.size(labels, 1)
    lr = 0.1

    lubdub_model = Sequential([
        LSTM(516, input_shape=tuple(np.shape(features)[1:])),
        Dropout(0.2),
        BatchNormalization(),
        Dense(516, activation='relu'),
        Dense(256, activation='relu'),
        Dropout(0.4),
        BatchNormalization(),
        Dense(64, activation='relu'),
        Dropout(0.4),
        BatchNormalization(),
        Dense(32, activation='relu'),
        Dense(num_classes, activation='softmax'),
    ])

    # Same schedule as the legacy Adam(decay=lr/epochs); newer TF optimizers
    # reject the `decay` keyword.
    schedule = tf.keras.optimizers.schedules.InverseTimeDecay(lr, decay_steps=1, decay_rate=lr / epochs)
    lubdub_model.compile(loss=tf.keras.losses.categorical_crossentropy,
                         optimizer=Adam(learning_rate=schedule),
                         metrics=['accuracy'])

    validation_data = None if val_x is None else (val_x, val_y)
    lubdub_model.fit(features, labels, batch_size=batch_size, epochs=epochs, verbose=1,
                     validation_data=validation_data)

    return lubdub_model, lubdub_model.history

#%% Preprocess_data

def preprocess_data(recording, Fs=4000, figures=False):
    """Band-limit a heart-sound recording to 25-400 Hz and remove spikes.

    A 400 Hz low-pass is followed by a 25 Hz high-pass. Both are 2nd-order
    Butterworth filters applied forward-backward, so the effective order is
    doubled and the phase is zero. Schmidt spike removal is applied last.

    Args:
        recording: 1-D signal.
        Fs: sampling frequency in Hz.
        figures: if True, the filter helpers display diagnostic plots.

    Returns:
        The band-limited, despiked signal.
    """
    low_passed = butterworth_low_pass_filter(recording, 2, 400, Fs, figures)
    band_passed = butterworth_high_pass_filter(low_passed, 2, 25, Fs, figures)
    return schmidt_spike_removal(band_passed, Fs)
         
#%% butterworth_low_pass
# Low-pass filter a given signal using a forward-backward, zero-phase
# butterworth low-pass filter.

# INPUTS:
# original_signal: The 1D signal to be filtered
# order: The order of the filter (1,2,3,4 etc). NOTE: This order is
# effectively doubled as this function uses a forward-backward filter that
# ensures zero phase distortion
# cutoff: The frequency cutoff for the low-pass filter (in Hz)
# sampling_frequency: The sampling frequency of the signal being filtered
# (in Hz).
# figures (optional): boolean variable dictating the display of figures

# OUTPUTS:
# low_pass_filtered_signal: the low-pass filtered signal.

# This code is derived from the paper:
# S. E. Schmidt et al., "Segmentation of heart sound recordings by a
# duration-dependent hidden Markov model," Physiol. Meas., vol. 31,
# no. 4, pp. 513-29, Apr. 2010.

# Developed by David Springer for comparison purposes in the paper:
# D. Springer et al., "Logistic Regression-HSMM-based Heart Sound
# Segmentation," IEEE Trans. Biomed. Eng., In Press, 2015.

#% Copyright (C) 2016  David Springer
# dave.springer@gmail.com

def butterworth_low_pass_filter(original_signal, order, cutoff, sampling_frequency, figures=False):
    """Zero-phase Butterworth low-pass filter.

    Args:
        original_signal: 1-D signal to filter.
        order: Butterworth order; its effect is doubled by the forward-backward pass.
        cutoff: cutoff frequency in Hz.
        sampling_frequency: sampling rate of ``original_signal`` in Hz.
        figures: if True, plot the filter's frequency response and a
            before/after comparison.

    Returns:
        ndarray: the low-pass filtered signal, same length as the input.
    """
    # signal.butter expects the cutoff normalised to the Nyquist frequency.
    B_low, A_low = signal.butter(order, 2 * cutoff / sampling_frequency, 'low')

    # Forward-backward filtering cancels the phase response (zero phase distortion).
    low_pass_filtered_signal = signal.filtfilt(B_low, A_low, original_signal)

    if figures:
        # butter() returned *digital* coefficients, so the response must come
        # from freqz (digital), not freqs (analog); fs= puts the axis in Hz.
        w, h = signal.freqz(B_low, A_low, fs=sampling_frequency)
        plt.semilogx(w, 20 * np.log10(abs(h)))
        plt.title('Butterworth low-pass filter frequency response')
        plt.xlabel('Frequency [Hz]')
        plt.ylabel('Amplitude [dB]')
        plt.margins(0, 0.1)
        plt.grid(which='both', axis='both')
        plt.axvline(cutoff, color='green')  # cutoff frequency
        plt.show()

        plt.title('Filtfilt low-pass')
        fpad = signal.filtfilt(B_low, A_low, original_signal, padlen=50)
        plt.plot(original_signal, 'k-', label='input')
        plt.plot(fpad, 'c-', linewidth=1.5, label='pad')
        plt.legend(loc='best')
        plt.show()

    return low_pass_filtered_signal

#%% butterworth_high_pass

# High-pass filter a given signal using a forward-backward, zero-phase
# Butterworth high-pass filter.

# INPUTS:
# original_signal: The 1D signal to be filtered
# order: The order of the filter (1,2,3,4 etc). NOTE: This order is
# effectively doubled as this function uses a forward-backward filter that
# ensures zero phase distortion
# cutoff: The frequency cutoff for the high-pass filter (in Hz)
# sampling_frequency: The sampling frequency of the signal being filtered
# (in Hz).
# figures (optional): boolean variable dictating the display of figures

# OUTPUTS:
# high_pass_filtered_signal: the high-pass filtered signal.

# This code is derived from the paper:
# S. E. Schmidt et al., "Segmentation of heart sound recordings by a
# duration-dependent hidden Markov model," Physiol. Meas., vol. 31,
# no. 4, pp. 513-29, Apr. 2010.

# Developed by David Springer for comparison purposes in the paper:
# D. Springer et al., "Logistic Regression-HSMM-based Heart Sound
# Segmentation," IEEE Trans. Biomed. Eng., In Press, 2015.

#% Copyright (C) 2016  David Springer
# dave.springer@gmail.com

def butterworth_high_pass_filter(original_signal, order, cutoff, sampling_frequency, figures=False):
    """Zero-phase Butterworth high-pass filter.

    Args:
        original_signal: 1-D signal to filter.
        order: Butterworth order; its effect is doubled by the forward-backward pass.
        cutoff: cutoff frequency in Hz.
        sampling_frequency: sampling rate of ``original_signal`` in Hz.
        figures: if True, plot the filter's frequency response and a
            before/after comparison.

    Returns:
        ndarray: the high-pass filtered signal, same length as the input.
    """
    # signal.butter expects the cutoff normalised to the Nyquist frequency.
    B_high, A_high = signal.butter(order, 2 * cutoff / sampling_frequency, 'high')

    # Forward-backward filtering cancels the phase response (zero phase distortion).
    high_pass_filtered_signal = signal.filtfilt(B_high, A_high, original_signal)

    if figures:
        # butter() returned *digital* coefficients, so the response must come
        # from freqz (digital), not freqs (analog); fs= puts the axis in Hz.
        w, h = signal.freqz(B_high, A_high, fs=sampling_frequency)
        plt.semilogx(w, 20 * np.log10(abs(h)))
        plt.title('Butterworth high-pass filter frequency response')
        plt.xlabel('Frequency [Hz]')
        plt.ylabel('Amplitude [dB]')
        plt.margins(0, 0.1)
        plt.grid(which='both', axis='both')
        plt.axvline(cutoff, color='green')  # cutoff frequency
        plt.show()

        plt.title('Filtfilt high-pass')
        fpad = signal.filtfilt(B_high, A_high, original_signal, padlen=50)
        plt.plot(original_signal, 'k-', label='input')
        plt.plot(fpad, 'c-', linewidth=1.5, label='pad')
        plt.legend(loc='best')
        plt.show()

    return high_pass_filtered_signal

#%% schmidt_spike_removal
# This function removes the spikes in a signal as done by Schmidt et al in
# the paper:
# Schmidt, S. E., Holst-Hansen, C., Graff, C., Toft, E., & Struijk, J. J.
# (2010). Segmentation of heart sound recordings by a duration-dependent
# hidden Markov model. Physiological Measurement, 31(4), 513-29.

# The spike removal process works as follows:
# (1) The recording is divided into 500 ms windows.
# (2) The maximum absolute amplitude (MAA) in each window is found.
# (3) If at least one MAA exceeds three times the median value of the MAA's,
# the following steps were carried out. If not continue to point 4.
#   (a) The window with the highest MAA was chosen.
#   (b) In the chosen window, the location of the MAA point was identified as the top of the noise spike.
#   (c) The beginning of the noise spike was defined as the last zero-crossing point before the MAA point.
#   (d) The end of the spike was defined as the first zero-crossing point after the maximum point.
#   (e) The defined noise spike was replaced by zeroes.
#   (f) Resume at step 2.
# (4) Procedure completed.

# Inputs:
# original_signal: The original (1D) audio signal array
# fs: the sampling frequency (Hz)

# Outputs:
# despiked_signal: the audio signal with any spikes removed.

# This code is derived from the paper:
# S. E. Schmidt et al., "Segmentation of heart sound recordings by a
# duration-dependent hidden Markov model," Physiol. Meas., vol. 31,
# no. 4, pp. 513-29, Apr. 2010.

# Developed by David Springer for comparison purposes in the paper:
# D. Springer et al., "Logistic Regression-HSMM-based Heart Sound
# Segmentation," IEEE Trans. Biomed. Eng., In Press, 2015.

# Copyright (C) 2016  David Springer
# dave.springer@gmail.com

def schmidt_spike_removal(original_signal, fs, windowsize=500):
    """Remove impulsive noise spikes from a PCG recording (Schmidt et al., 2010).

    The signal is split into consecutive, non-overlapping windows of
    ``windowsize`` samples. While some window's maximum absolute amplitude (MAA)
    exceeds three times the median MAA, the largest spike in the loudest window
    (delimited by the zero crossings around it) is overwritten with 0.0001.
    Samples after the last full window are passed through unchanged.

    Parameters
    ----------
    original_signal : array_like
        The audio signal. It is flattened to 1-D and is NOT modified.
    fs : int or float
        Sampling frequency in Hz. Not used by the computation; kept for
        interface compatibility.
    windowsize : int, optional
        Window length in samples. The default of 500 equals the paper's 500 ms
        windows only when ``fs`` is 1000 Hz (the MATLAB reference uses
        ``round(fs/2)``).

    Returns
    -------
    numpy.ndarray
        1-D despiked copy of the signal, with the same length as the input.
    """
    fill_value = 0.0001

    # Work on a private copy: slicing + reshaping would otherwise produce a
    # view, and the spike filling below would write into the caller's array.
    despiked = np.array(original_signal, copy=True).ravel()
    if not np.issubdtype(despiked.dtype, np.floating):
        # Integer input cannot represent the 0.0001 fill value.
        despiked = despiked.astype(float)

    num_windows = despiked.size // windowsize
    if num_windows == 0:
        # Shorter than one window: nothing to compare against.
        return despiked
    full_length = num_windows * windowsize

    # Column k holds the contiguous samples [k*windowsize, (k+1)*windowsize).
    # Fortran order mirrors MATLAB's reshape(x, windowsize, []); C order would
    # make each column a strided decimation of the signal instead of a window.
    sampleframes = despiked[:full_length].reshape((windowsize, num_windows), order='F')
    MAAs = np.max(np.abs(sampleframes), axis=0)

    while np.any(MAAs > 3 * np.median(MAAs)):
        window_num = int(np.argmax(MAAs))
        peak = MAAs[window_num]
        if peak <= fill_value:
            # Filling cannot lower this window's MAA any further (e.g. a mostly
            # silent recording whose median MAA is 0): stop instead of looping.
            break
        frame = sampleframes[:, window_num]  # view: writes go into sampleframes

        # Locate the spike; the positive peak wins when both +peak and -peak occur.
        matches = np.flatnonzero(frame == peak)
        if matches.size == 0:
            matches = np.flatnonzero(frame == -peak)
        spike_position = int(matches[0])

        # zero_crossings[i] is True when the sign flips between samples i and i+1.
        zero_crossings = np.abs(np.diff(np.sign(frame))) > 1

        # Start: the last crossing before the spike, else the window start.
        before = np.flatnonzero(zero_crossings[:spike_position])
        spike_start = int(before[-1]) if before.size else 0

        # End (exclusive): one past the sample preceding the first crossing
        # after the spike, else the window end.
        after = np.flatnonzero(zero_crossings[spike_position + 1:])
        spike_end = spike_position + 2 + int(after[0]) if after.size else windowsize

        frame[spike_start:spike_end] = fill_value

        # Only this window changed, so only its MAA needs recomputing.
        MAAs[window_num] = np.max(np.abs(frame))

    # Undo the Fortran-order framing and re-attach the trailing samples.
    return np.concatenate((sampleframes.ravel(order='F'), despiked[full_length:]))

#%% Homomorphic_Envelope_with_Hilbert
# This function finds the homomorphic envelope of a signal, using the method
# described in the following publications:

# S. E. Schmidt et al., "Segmentation of heart sound recordings by a
# duration-dependent hidden Markov model," Physiol. Meas., vol. 31, no. 4,
# pp. 513-29, Apr. 2010.
 
# C. Gupta et al., "Neural network classification of homomorphic segmented
# heart sounds," Appl. Soft Comput., vol. 7, no. 1, pp. 286-297, Jan. 2007.

# D. Gill et al., "Detection and identification of heart sounds using
# homomorphic envelogram and self-organizing probabilistic model," in
# Computers in Cardiology, 2005, pp. 957-960.
# (However, these researchers found the homomorphic envelope of shannon
# energy.)

# In I. Rezek and S. Roberts, "Envelope Extraction via Complex Homomorphic
# Filtering. Technical Report TR-98-9," London, 1998, the researchers state
# that the singularity at 0 when using the natural logarithm (resulting in
# values of -inf) can be fixed by using a complex valued signal. They
# motivate the use of the Hilbert transform to find the analytic signal,
# which is a conversion of a real-valued signal to a complex-valued
# signal, which is unaffected by the singularity. (Note that the
# implementation below takes the logarithm of the analytic signal's
# magnitude, so an exactly-zero magnitude would still produce -inf.)

# A zero-phase low-pass Butterworth filter is used to extract the envelope.

## Inputs:
# input_signal: the original (1D) signal
# sampling_frequency: the signal's sampling frequency (Hz)
# lpf_frequency: the frequency cut-off of the low-pass filter to be used in
# the envelope extraction (Default = 8 Hz as in Schmidt's publication).
# figures: (optional) boolean variable dictating the display of a figure of
# both the original signal and the extracted envelope:

## Outputs:
# homomorphic_envelope: The homomorphic envelope of the original
# signal (not normalised).

# This code was developed by David Springer for comparison purposes in the
# paper:
# D. Springer et al., "Logistic Regression-HSMM-based Heart Sound
# Segmentation," IEEE Trans. Biomed. Eng., In Press, 2015.

# Copyright (C) 2016  David Springer
# dave.springer@gmail.com


def Homomorphic_Envelope_with_Hilbert(input_signal, sampling_frequency, lpf_frequency=8, figures=0):
    """Return the homomorphic envelope of ``input_signal``.

    The envelope is exp(LPF(log|hilbert(x)|)), where LPF is a zero-phase
    (``filtfilt``) first-order Butterworth low-pass filter with cut-off
    ``lpf_frequency`` Hz. The first sample is replaced by the second one to
    suppress an edge spike. The result is not normalised.

    When ``figures == 1`` the signal and envelope are drawn on the current
    matplotlib axes (no ``plt.show()`` call).
    """
    # Butterworth cut-off is expressed as a fraction of the Nyquist frequency.
    normalised_cutoff = 2 * lpf_frequency / sampling_frequency
    b_coeffs, a_coeffs = signal.butter(1, normalised_cutoff, 'low')

    analytic_magnitude = np.abs(signal.hilbert(input_signal))
    smoothed_log = signal.filtfilt(b_coeffs, a_coeffs, np.log(analytic_magnitude))
    envelope = np.exp(smoothed_log)

    # The first sample tends to carry a spurious spike; reuse its neighbour.
    envelope[0] = envelope[1]

    if figures == 1:
        plt.title('Homomorphic Envelope')
        plt.plot(input_signal, label='Original Signal')
        plt.plot(envelope, 'r', label='Homomorphic Envelope')
        plt.legend(loc='best')

    return envelope

#%% Hilbert_Envelope

#This function finds the Hilbert envelope of a signal. This is taken from:

# Choi et al, Comparison of envelope extraction algorithms for cardiac sound
# signal segmentation, Expert Systems with Applications, 2008

## Inputs:
# input_signal: the original signal
# sampling_frequency: the signal's sampling frequency (not used in the computation)
# figures: (optional) boolean variable to display a figure of both the
# original and normalised signal

# Outputs:
# hilbert_envelope is the hilbert envelope of the original signal
#
# This code was developed by David Springer for comparison purposes in the
# paper:
# D. Springer et al., "Logistic Regression-HSMM-based Heart Sound
# Segmentation," IEEE Trans. Biomed. Eng., In Press, 2015.

# Copyright (C) 2016  David Springer
# dave.springer@gmail.com

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

def Hilbert_Envelope(input_signal, sampling_frequency, figures=0):
    """Return the Hilbert envelope of ``input_signal``.

    The envelope is the magnitude of the analytic signal produced by
    ``scipy.signal.hilbert``. ``sampling_frequency`` is accepted for interface
    compatibility but is not used. When ``figures == 1`` the signal and its
    envelope are drawn on the current matplotlib axes.
    """
    analytic_signal = signal.hilbert(input_signal)
    envelope = np.abs(analytic_signal)

    if figures == 1:
        plt.title('Hilbert Envelope')
        plt.plot(input_signal, label='Original Signal')
        plt.plot(envelope, 'r', label='Hilbert Envelope')
        plt.legend(loc='best')

    return envelope

#%% normalize_signal

# This function subtracts the mean and divides by the standard deviation of
# a (1D) signal in order to normalise it for machine learning applications.

# Inputs:
# signal: the original signal
#
# Outputs:
# normalised_signal: the original signal, minus the mean and divided by
# the standard deviation.

# Developed by David Springer for the paper:
# D. Springer et al., "Logistic Regression-HSMM-based Heart Sound
# Segmentation," IEEE Trans. Biomed. Eng., In Press, 2015.

# Copyright (C) 2016  David Springer
# dave.springer@gmail.com

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

def normalize_signal(signal):
    """Standardise a 1-D signal to zero mean and unit standard deviation.

    Parameters
    ----------
    signal : array_like
        The original signal. (This parameter name shadows the ``scipy.signal``
        module inside this function only; it is kept so keyword callers work.)

    Returns
    -------
    numpy.ndarray
        ``(signal - mean) / std``. A constant or empty signal has no spread to
        divide by, so it is returned as zeros of the same shape rather than
        NaNs that would propagate into downstream features.
    """
    values = np.asarray(signal)

    # Exact check on the range: a constant signal can still have a tiny non-zero
    # np.std due to summation rounding, but its max and min are always equal.
    if values.size == 0 or np.max(values) == np.min(values):
        return np.zeros(values.shape)

    mean_of_signal = np.mean(values)
    standard_deviation = np.std(values)
    return (values - mean_of_signal) / standard_deviation

#%% get_PSD_feature_Springer_HMM

#PSD-based feature extraction for heart sound segmentation.

# INPUTS:
# data: this is the audio waveform
# sampling_frequency is self-explanatory
# frequency_limit_low is the lower-bound on the frequency range you want to
# analyse
# frequency_limit_high is the upper-bound on the frequency range
# figures: (optional) boolean variable to display figures

# OUTPUTS:
# psd is the array of mean PSD values between the min and max frequency
# limits, with one value per spectrogram time segment (not resampled here).

# This code was developed by David Springer in the paper:
# D. Springer et al., "Logistic Regression-HSMM-based Heart Sound
# Segmentation," IEEE Trans. Biomed. Eng., In Press, 2015.

# Copyright (C) 2016  David Springer
# dave.springer@gmail.com

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

def next_power_of_2(x):
    """Return the smallest power of two that is >= ``x`` (1 when ``x`` is 0).

    ``x`` must be an integer, since ``int.bit_length`` is used.
    """
    if x == 0:
        return 1
    return 1 << (x - 1).bit_length()


def get_PSD_feature_Springer_HMM(data, sampling_frequency, frequency_limit_low, frequency_limit_high, figures=0):
    """Mean power spectral density in a frequency band, per spectrogram segment.

    Parameters
    ----------
    data : numpy.ndarray
        The audio waveform.
    sampling_frequency : int
        Sampling frequency in Hz.
    frequency_limit_low, frequency_limit_high : float
        Band limits in Hz. The spectrogram bins nearest to each limit are used,
        and both end bins are included.
    figures : int, optional
        When 1, plot the spectrogram and the resulting feature.

    Returns
    -------
    numpy.ndarray
        One value per spectrogram time segment: the mean PSD over the selected
        band. It is not resampled to ``len(data)``; callers do that.
    """
    # Hann windows of 25 ms with 50 % overlap; FFT length rounded up to a power of 2.
    WINDOW = round(sampling_frequency / 40)
    NOVERLAP = round(sampling_frequency / 80)
    NFFT = next_power_of_2(WINDOW)

    F, T, P = signal.spectrogram(data, fs=sampling_frequency, window='hann', nperseg=WINDOW,
                                 noverlap=NOVERLAP, nfft=NFFT, detrend=False)

    if figures == 1:
        plt.pcolormesh(T, F, P)
        plt.ylabel('Frequency [Hz]')
        plt.xlabel('Time [sec]')
        plt.show()

    # Bins closest to each limit. Both need argmin: argmax on the upper limit
    # picked the bin FARTHEST from it (Nyquist), averaging almost the whole
    # spectrum. The +1 keeps the upper bin, as the MATLAB reference does.
    low_limit_position = np.argmin(np.abs(F - frequency_limit_low))
    high_limit_position = np.argmin(np.abs(F - frequency_limit_high))
    psd = np.mean(P[low_limit_position:high_limit_position + 1, :], axis=0)

    if figures == 1:
        t_data = np.arange(len(data)) / sampling_frequency
        plt.title('PSD Feature')
        plt.plot(t_data, (data - np.mean(data)) / np.std(data), 'c')
        # psd has one value per spectrogram segment, so plot it against T.
        plt.plot(T, (psd - np.mean(psd)) / np.std(psd), 'k')

    return psd

#%% getSpringerPCGFeatures

# Get the features used in the Springer segmentation algorithm. These 
# features include:
# -The homomorphic envelope (as performed in Schmidt et al's paper)
# -The Hilbert envelope
# -A wavelet-based feature (level-3 'db7' detail coefficients; always included)
# -A PSD-based feature
# This function was developed for use in the paper:
# D. Springer et al., "Logistic Regression-HSMM-based Heart Sound 
# Segmentation," IEEE Trans. Biomed. Eng., In Press, 2015.

# INPUTS:
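# despiked_data: array of (despiked) audio data from which to extract features
# Fs: the sampling frequency of the audio data
# audio_segmentation_Fs: the sampling frequency (Hz) of the returned features (default 50)
# figures (optional): boolean variable dictating the display of figures (currently unused)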

# OUTPUTS:
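# PCG_Features: list of four normalised features [homomorphic envelope,
# Hilbert envelope, PSD, wavelet], each resampled to audio_segmentation_Fs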

# Copyright (C) 2016  David Springer
# dave.springer@gmail.com

def getSpringerPCGFeatures(despiked_data, Fs, audio_segmentation_Fs=50, figures=False):
    """Build the Springer segmentation feature set for one PCG recording.

    Four features are computed from ``despiked_data`` (sampled at ``Fs`` Hz).
    Each is brought to ``audio_segmentation_Fs`` Hz and z-scored with
    ``normalize_signal``:

    1. the homomorphic envelope (resampled with ``signal.resample``);
    2. the Hilbert envelope (resampled with ``signal.resample``);
    3. the 40-60 Hz mean PSD (stretched with ``signal.resample_poly``);
    4. the magnitude of the level-3 'db7' wavelet detail coefficients
       (stretched with ``signal.resample_poly``).

    ``figures`` is accepted for interface compatibility but is not used.

    Returns
    -------
    list of numpy.ndarray
        ``[homomorphic, hilbert, psd, wavelet]``, all of the same length.
    """
    # Homomorphic envelope, resampled to the segmentation rate.
    homomorphic = Homomorphic_Envelope_with_Hilbert(despiked_data, Fs)
    num_samps = round(np.size(homomorphic) / Fs * audio_segmentation_Fs)
    homomorphic_feature = normalize_signal(signal.resample(homomorphic, num_samps))
    n_frames = len(homomorphic_feature)

    # Hilbert envelope, same target length.
    hilbert = Hilbert_Envelope(despiked_data, Fs, figures=0)
    hilbert_feature = normalize_signal(signal.resample(hilbert, num_samps))

    # Band PSD: one value per spectrogram segment, stretched to n_frames.
    band_psd = get_PSD_feature_Springer_HMM(despiked_data, Fs, 40, 60)
    psd_feature = normalize_signal(signal.resample_poly(band_psd, n_frames, len(band_psd)))

    # Wavelet feature: |detail coefficients| at the deepest level of a
    # zero-padded db7 decomposition (wavedec returns [cA3, cD3, cD2, cD1]).
    wavelet_level = 3
    wavelet_name = 'db7'
    _, detail_coeffs, _, _ = pywt.wavedec(despiked_data, wavelet_name, mode='zero', level=wavelet_level)
    wavelet_feature = normalize_signal(
        signal.resample_poly(abs(detail_coeffs), n_frames, len(detail_coeffs)))

    return [homomorphic_feature, hilbert_feature, psd_feature, wavelet_feature]

#%% Get Heart Rate 

# Derive the heart rate and the systolic time interval from a PCG recording.
# This is used in the duration-dependent HMM-based segmentation of the PCG
# recording.

# This method is based on analysis of the autocorrelation function, and the
# positions of the peaks therein.

# This code is derived from the paper:
# S. E. Schmidt et al., "Segmentation of heart sound recordings by a 
# duration-dependent hidden Markov model," Physiol. Meas., vol. 31,
# no. 4, pp. 513-29, Apr. 2010.

# Developed by David Springer for comparison purposes in the paper:
# D. Springer et al., "Logistic Regression-HSMM-based Heart Sound 
# Segmentation," IEEE Trans. Biomed. Eng., In Press, 2015.

# INPUTS:
# audio_data: The raw audio data from the PCG recording
# Fs: the sampling frequency of the audio recording
# figures: optional boolean to display figures

# OUTPUTS:
# heartRate: the heart rate of the PCG in beats per minute
# systolicTimeInterval: the duration of systole, as derived from the
# autocorrelation function, in seconds
#
# Copyright (C) 2016  David Springer
# dave.springer@gmail.com

def getHeartRateSchmidt(audio_data, Fs, figures=False):
    """Estimate heart rate and systolic interval from the envelope autocorrelation.

    Following Schmidt et al. (2010), the recording is processed in these steps:

    1. Filter it with ``butterworth_low_pass_filter`` at 400 Hz, then
       ``butterworth_high_pass_filter`` at 25 Hz (order argument 2 for each).
    2. Remove spikes.
    3. Take the homomorphic envelope.
    4. Compute the envelope's autocorrelation, normalised to 1 at lag 0.

    "The duration of the heart cycle is estimated as the time from lag zero to
    the highest peaks between 500 and 2000 ms in the resulting autocorrelation";
    "the systolic duration is defined as the time from lag zero to the highest
    peak in the interval between 200 ms and half of the heart cycle duration".
    Both search ranges include their end points.

    Parameters
    ----------
    audio_data : numpy.ndarray
        Raw PCG audio.
    Fs : int
        Sampling frequency in Hz.
    figures : bool, optional
        Plot the autocorrelation and the selected peaks.

    Returns
    -------
    heartRate : float
        Heart rate in beats per minute.
    systolicTimeInterval : float
        Systolic duration in seconds.
    """
    audio_data = butterworth_low_pass_filter(audio_data, 2, 400, Fs, False)
    audio_data = butterworth_high_pass_filter(audio_data, 2, 25, Fs)

    # Spike removal from the original paper.
    audio_data = schmidt_spike_removal(audio_data, Fs)

    homomorphic_envelope = Homomorphic_Envelope_with_Hilbert(audio_data, Fs)
    y = homomorphic_envelope - np.mean(homomorphic_envelope)

    # FFT-based correlation equals np.correlate(y, y, 'full') up to rounding,
    # but costs O(N log N) instead of O(N^2) on multi-second recordings.
    full_autocorrelation = signal.correlate(y, y, mode='full', method='fft')
    zero_lag = len(y) - 1
    # signal_autocorrelation[k] is the normalised autocorrelation at lag k.
    signal_autocorrelation = full_autocorrelation[zero_lag:] / full_autocorrelation[zero_lag]

    min_index = round(0.5 * Fs)
    max_index = round(2 * Fs)

    # Entries are indexed by lag, so the peak lag is min_index + offset. The
    # former "- 1" was a MATLAB 1-based leftover that shortened the lag by one.
    index = np.argmax(signal_autocorrelation[min_index:max_index + 1])
    true_index = index + min_index

    heartRate = 60 / (true_index / Fs)

    max_sys_duration = round(((60 / heartRate) * Fs) / 2)
    min_sys_duration = round(0.2 * Fs)

    pos = np.argmax(signal_autocorrelation[min_sys_duration:max_sys_duration + 1])
    systolicTimeInterval = (min_sys_duration + pos) / Fs

    if figures:
        plt.title('Heart rate calculation figure')
        plt.plot(signal_autocorrelation, label='Autocorrelation')
        plt.plot(true_index, signal_autocorrelation[true_index], 'ro', label='Position of max peak used to calculate HR')
        plt.plot(min_sys_duration + pos, signal_autocorrelation[min_sys_duration + pos], 'mo', label='Position of max peak within systolic interval')
        plt.xlabel('Samples')
        plt.legend(loc='best')

    return heartRate, systolicTimeInterval

#%% get_duration_distributions
#
# This function calculates the duration distributions for each heart cycle
# state, and the minimum and maximum times for each state.
#
## Inputs:
# heartrate is the calculated average heart rate over the entire recording
# systolic_time is the systolic time interval
#
## Outputs:
# d_distributions is a 4x2 array (one row per state, in the order S1,
# systole, S2, diastole) holding the mean (column 0) and the variance
# (column 1) of the Gaussian duration model of each state, in samples at 50 Hz.
#
# The max and min values are self-explanatory.
#
# This code is implemented as outlined in the paper:
# S. E. Schmidt et al., "Segmentation of heart sound recordings by a
# duration-dependent hidden Markov model," Physiol. Meas., vol. 31,
# no. 4, pp. 513-29, Apr. 2010.
#
# Developed by David Springer for comparison purposes in the paper:
# D. Springer et al., "Logistic Regression-HSMM-based Heart Sound
# Segmentation," IEEE Trans. Biomed. Eng., In Press, 2015.
#
## Copyright (C) 2016  David Springer
# dave.springer@gmail.com
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
def get_duration_distributions(heartrate=None, systolic_time=None):
    """Gaussian duration models and duration bounds for the four heart-cycle states.

    All durations are expressed in samples at a fixed 50 Hz segmentation rate.

    Parameters
    ----------
    heartrate : float
        Average heart rate in beats per minute.
    systolic_time : float
        Systolic time interval in seconds.

    Returns
    -------
    tuple
        ``(d_distributions, max_S1, min_S1, max_S2, min_S2, max_systole,
        min_systole, max_diastole, min_diastole)``, where ``d_distributions``
        is a (4, 2) array of [mean, variance] rows in the order S1, systole,
        S2, diastole. Bounds are mean +/- 3 std (systole also widens by the S1
        std). The S1/S2 minima are floored at ``50 / 50`` samples.
    """
    seg_fs = 50

    mean_S1 = np.round(0.122 * seg_fs)
    std_S1 = np.round(0.022 * seg_fs)
    mean_S2 = np.round(0.094 * seg_fs)
    std_S2 = np.round(0.022 * seg_fs)

    # Systole is measured from S1 onset, so the S1 duration is subtracted.
    mean_systole = np.round(systolic_time * seg_fs) - mean_S1
    std_systole = (25 / 1000) * seg_fs

    # Diastole fills the rest of the heart cycle after systole and S2 (0.094 s).
    mean_diastole = ((60 / heartrate) - systolic_time - 0.094) * seg_fs
    std_diastole = 0.07 * mean_diastole + (6 / 1000) * seg_fs

    d_distributions = np.array([
        [mean_S1, std_S1 ** 2],
        [mean_systole, std_systole ** 2],
        [mean_S2, std_S2 ** 2],
        [mean_diastole, std_diastole ** 2],
    ], dtype=float)

    systole_spread = 3 * (std_systole + std_S1)
    min_systole = mean_systole - systole_spread
    max_systole = mean_systole + systole_spread

    min_diastole = mean_diastole - 3 * std_diastole
    max_diastole = mean_diastole + 3 * std_diastole

    # Heart-sound minima never drop below a 50th of the segmentation rate.
    floor = seg_fs / 50
    min_S1 = max(mean_S1 - 3 * std_S1, floor)
    min_S2 = max(mean_S2 - 3 * std_S2, floor)

    max_S1 = mean_S1 + 3 * std_S1
    max_S2 = mean_S2 + 3 * std_S2

    return d_distributions, max_S1, min_S1, max_S2, min_S2, max_systole, min_systole, max_diastole, min_diastole

#%% viterbiDecodePCG_Springer
# function [delta, psi, qt] = viterbiDecodePCG_Springer(observation_sequence, pi_vector, b_matrix, total_obs_distribution, heartrate, systolic_time, Fs, figures)
#
# This function calculates the delta, psi and qt matrices associated with
# the Viterbi decoding algorithm from:
# L. R. Rabiner, "A tutorial on hidden Markov models and selected
# applications in speech recognition," Proc. IEEE, vol. 77, no. 2, pp.
# 257-286, Feb. 1989.
# using equations 32a - 35, and equations 68 - 69 to include duration
# dependency of the states.
#
# This decoding is performed after the observation probabilities have been
# derived from the logistic regression model of Springer et al:
# D. Springer et al., "Logistic Regression-HSMM-based Heart Sound
# Segmentation," IEEE Trans. Biomed. Eng., In Press, 2015.
#
# Further, this function is extended to allow the duration distributions to extend
# past the beginning and end of the sequence. Without this, the label
# sequence has to start and stop with an "entire" state duration being
# fulfilled. This extension takes away that requirement, by allowing the
# duration distributions to extend past the beginning and end, but only
# considering the observations within the sequence for emission probability
# estimation. More detail can be found in the publication by Springer et
# al., mentioned above.
#
## Inputs:
# observation_sequence: The observed features
# pi_vector: the array of initial state probabilities, derived from
# "trainSpringerSegmentationAlgorithm".
# b_matrix: the per-state logistic-regression coefficients (row 0: intercept,
# following rows: feature weights; one column per state), derived from
# "trainSpringerSegmentationAlgorithm".
# heartrate: the heart rate of the PCG, extracted using
# "getHeartRateSchmidt"
# systolic_time: the duration of systole, extracted using
# "getHeartRateSchmidt"
# Fs: the sampling frequency of the observation_sequence
# figures: optional boolean variable to show figures
#
## Outputs:
# delta: the Viterbi log-scores, of size (T + max_duration_D - 1) x N
# psi: the most likely preceding state for each (time, state) entry
# qt: the decoded state sequence, of length T
# As Springer et al's algorithm is a duration-dependent HMM, there is no
# need to calculate the A_matrix, as the transition between states is only
# dependent on the state durations.
#
## Copyright (C) 2016  David Springer
# dave.springer@gmail.com
    
def viterbiDecodePCG_Springer(observation_sequence, pi_vector, b_matrix, total_obs_distribution, heartrate, systolic_time, Fs, figures=False):
    """Duration-dependent (extended) Viterbi decoding of PCG heart-cycle states.

    Parameters
    ----------
    observation_sequence : sequence of equal-length 1-D feature arrays
        The arrays are stacked column-wise into a (T, n_features) matrix.
    pi_vector : array_like
        Initial state probabilities, indexed as ``pi_vector[0][n]``.
    b_matrix : numpy.ndarray
        Column n holds the logistic-regression intercept (row 0) and feature
        weights (rows 1:) for state n.
    total_obs_distribution : array_like
        Indexed as ``[0, 0][0, :]`` (mean) and ``[1, 0]`` (covariance) of the
        multivariate normal over all observations.
    heartrate, systolic_time : float
        As returned by ``getHeartRateSchmidt``.
    Fs : int
        Sampling frequency of the observation sequence.
    figures : bool, optional
        Plot the normalised state-duration distributions.

    Returns
    -------
    delta : numpy.ndarray
        (T + max_duration_D - 1, 4) Viterbi log-scores.
    psi : numpy.ndarray
        Same shape; the best preceding state for each (t, state).
    qt : numpy.ndarray
        Length-T decoded states 0..3 (0=S1, 1=systole, 2=S2, 3=diastole, per
        the duration limits used below).
    """
    from scipy.special import expit

    N = 4
    T = len(observation_sequence[0])
    observation_sequence = np.vstack(observation_sequence).T
    realmin = np.finfo(np.double).tiny

    # A single state may last at most one heart cycle.
    max_duration_D = int(np.round((1 * (60 / heartrate)) * Fs))
    # Rows are extended by max_duration_D - 1 so durations may run past the end
    # of the sequence (the "extended" Viterbi of Springer et al.).
    n_rows = T + max_duration_D - 1
    delta = np.full((n_rows, N), -np.inf)  # best log-score ending in state j at t
    psi = np.zeros((n_rows, N))            # best preceding state
    psi_duration = np.zeros((n_rows, N))   # duration achieving that score

    # --- Observation probabilities --------------------------------------------
    # The logistic model gives P(state | obs); Bayes gives
    # P(obs | state) = P(state | obs) * P(obs) / P(state), with P(obs) from a
    # multivariate normal over all observations (identical for every state).
    Po_correction = multivariate_normal.pdf(observation_sequence, mean=total_obs_distribution[0, 0][0, :], cov=total_obs_distribution[1, 0])
    observation_probs = np.zeros((T, N))
    for n in range(N):
        # Same value as the previous sklearn predict_proba(...) flipped column 1,
        # i.e. P(class 0) = 1 - sigmoid(intercept + x . w), computed directly
        # instead of fitting a throw-away model on random labels.
        decision = observation_sequence @ b_matrix[1:, n] + b_matrix[0, n]
        pihat = 1.0 - expit(decision)
        observation_probs[:, n] = (pihat * Po_correction) / pi_vector[0][n]

    # --- State duration probabilities (Gaussian per state) -------------------
    (d_distributions, max_S1, min_S1, max_S2, min_S2,
     max_systole, min_systole, max_diastole, min_diastole) = get_duration_distributions(heartrate, systolic_time)
    duration_limits = ((min_S1, max_S1), (min_systole, max_systole),
                       (min_S2, max_S2), (min_diastole, max_diastole))
    durations = np.arange(1, max_duration_D + 1)
    duration_probs = np.zeros((N, 3 * Fs))
    duration_sum = np.zeros(N)
    for state_j, (low, high) in enumerate(duration_limits):
        pdf = multivariate_normal.pdf(durations, d_distributions[state_j, 0], d_distributions[state_j, 1])
        # Implausible durations get a tiny non-zero probability so log() stays finite.
        pdf = np.where((durations < low) | (durations > high), realmin, pdf)
        duration_probs[state_j, 1:max_duration_D + 1] = pdf
        duration_sum[state_j] = np.sum(duration_probs[state_j, :])

    if figures:
        plt.figure('Duration probabilities')
        plt.plot(duration_probs[0, :] / duration_sum[0], linewidth=2, label='S1 Duration')
        plt.plot(duration_probs[1, :] / duration_sum[1], 'r', linewidth=2, label='Systolic Duration')
        plt.plot(duration_probs[2, :] / duration_sum[2], 'g', linewidth=2, label='S2 Duration')
        plt.plot(duration_probs[3, :] / duration_sum[3], 'k', linewidth=2, label='Diastolic Duration')
        plt.legend()
        plt.show()

    # --- Viterbi recursion -----------------------------------------------------
    qt = np.zeros(n_rows)

    # Equations 32a/69a without the "1 sample in state i" term, since the state
    # could have started before t = 0.
    delta[0, :] = np.log(pi_vector) + np.log(observation_probs[0, :])
    psi[0, :] = -1  # Equation 32b

    # Zero diagonal: state changes follow the fixed order S1 -> systole -> S2 ->
    # diastole -> S1 and are driven only by durations and observations.
    a_matrix = np.array([[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0]])
    with np.errstate(divide='ignore'):
        log_a_matrix = np.log(a_matrix)              # -inf for forbidden moves
        log_duration_probs = np.log(duration_probs)  # column 0 is never read
    log_duration_sum = np.log(duration_sum)

    for t in range(1, n_rows):
        end_t = min(t, T - 1)
        for j in range(N):
            for d in range(1, max_duration_D + 1):
                start_t = min(max(t - d, 0), T - 2)
                scores = delta[start_t, :] + log_a_matrix[:, j]
                max_index = np.argmax(scores)
                max_delta = scores[max_index]

                # Joint emission probability over the state's span; it can
                # underflow to 0, which is replaced by realmin.
                probs = np.prod(observation_probs[start_t:end_t + 1, j])
                if probs == 0:
                    probs = realmin
                emission_probs = np.log(probs)
                if emission_probs == 0 or np.isnan(emission_probs):
                    emission_probs = realmin

                delta_temp = max_delta + emission_probs + log_duration_probs[j, d] - log_duration_sum[j]
                if delta_temp > delta[t, j]:
                    delta[t, j] = delta_temp
                    psi[t, j] = max_index
                    psi_duration[t, j] = d

    # --- Termination -----------------------------------------------------------
    # Best score among the rows after the end of the real signal. np.argmax on a
    # 2-D array returns a row-major flat index, so the row is idx // N;
    # unravel_index does that (dividing by the row count picked wrong rows).
    temp_delta = delta[T:, :]
    pos_row, _ = np.unravel_index(np.argmax(temp_delta), temp_delta.shape)
    pos = pos_row + T

    state = np.argmax(delta[pos, :])
    offset = pos
    preceding_state = psi[offset, state]
    onset = offset - psi_duration[offset, state] + 1
    qt[int(onset):int(offset + 1)] = state

    # --- Backtracking ----------------------------------------------------------
    # Each earlier state ends where the later one started.
    # NOTE(review): the thresholds 2/1 below are kept from the MATLAB (1-based)
    # code, so qt[0] (and possibly qt[1]) can keep the initial value 0.
    state = preceding_state
    count = 0
    while onset > 2:
        offset = onset - 1
        preceding_state = psi[int(offset), int(state)]
        onset = offset - psi_duration[int(offset), int(state)] + 1
        if onset < 2:
            onset = 1
        qt[int(onset):int(offset + 1)] = state
        state = preceding_state
        count += 1
        if count > 1000:
            # Guard against cells never filled (psi_duration 0), which would
            # otherwise stall the backtrack.
            break

    qt = qt[:T]
    return delta, psi, qt


#%% expand_qt
# 
# Function to expand the derived HMM states to a higher sampling frequency. 
#
# Developed by David Springer for comparison purposes in the paper:
# D. Springer et al., "Logistic Regression-HSMM-based Heart Sound 
# Segmentation," IEEE Trans. Biomed. Eng., In Press, 2015.

# INPUTS:
# original_qt: the original derived states from the HMM
# old_fs: the old sampling frequency of the original_qt
# new_fs: the desired sampling frequency
# new_length: the desired length of the qt signal

# Outputs:
# expanded_qt: the expanded qt, to the new FS and length
#
# Copyright (C) 2016  David Springer
# dave.springer@gmail.com
# 
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.
 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

def expand_qt(original_qt, old_fs, new_fs, new_length):
    """Expand a state sequence decoded at ``old_fs`` to the sampling rate ``new_fs``.

    Port of D. Springer's ``expandQt.m``. Every run of identical states in
    ``original_qt`` is mapped to the corresponding span of samples at
    ``new_fs`` and filled with that state.

    Parameters
    ----------
    original_qt : array_like
        Decoded states at ``old_fs``; flattened to 1-D.
    old_fs : float
        Sampling frequency of ``original_qt``.
    new_fs : float
        Desired sampling frequency.
    new_length : int
        Length of the returned signal. Expanded runs are clipped to it.

    Returns
    -------
    ndarray, shape (new_length, 1)
        Expanded states. Samples not covered by any run keep the value 0.
    """
    original_qt = np.ravel(original_qt)
    expanded_qt = np.zeros((new_length, 1))
    if original_qt.size == 0:
        return expanded_qt

    # np.diff is non-zero at the LAST sample of each run, so +1 turns that into
    # the exclusive end index of the run (MATLAB's 1-based find() result
    # already had this meaning). The final run ends at the end of the signal.
    run_ends = np.append(np.flatnonzero(np.diff(original_qt)) + 1, len(original_qt))

    start_index = 0
    for end_index in run_ends:
        # The run is constant, so any sample in [start, end) gives its state.
        mid_point = start_index + (end_index - start_index) // 2
        value_at_mid_point = original_qt[mid_point]

        expanded_start_index = int(round(start_index / old_fs * new_fs))
        expanded_end_index = min(int(round(end_index / old_fs * new_fs)), new_length)

        expanded_qt[expanded_start_index:expanded_end_index] = value_at_mid_point
        start_index = end_index

    return expanded_qt

#%% extractPhysioFeaturesFromHsIntervals
#
# This function calculates 28 features from the assigned_states produced by the Springer segmentation (originally "runSpringerSegmentationAlgorithm.m").
#
# INPUTS:
# assigned_states: the array of state values assigned to the sound recording.
# PCG: sound recording 
# (Fs and audio_Fs were inputs of the original version; this Python version does not take them
#  and performs no downsampling, so PCG must be sample-aligned with assigned_states.)
#
# OUTPUTS:
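# features: the 28 features obtained for the current sound recording
# A: N x 4 array of the sample indices where S1, systole, S2 and diastole begin in each complete heart cycle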
#
# Written by: Chengyu Liu, January 22 2016
#             chengyu.liu@emory.edu
# Edited by: Julia Ding, Aug 10 2022
#               julia.ding@emory.edu 
#
# $$$$$$ IMPORTANT
# Please note: the calculated features are only pilot features. Some may be
# helpful for classifying normal/abnormal heart sounds and some may
# not; you may need to re-construct the features for a more accurate classification.

def extractPhysioFeaturesFromHsIntervals(assigned_states, PCG):
    """Compute 28 interval and amplitude features from a segmented PCG recording.

    Parameters
    ----------
    assigned_states : array_like, shape (n,) or (n, 1)
        Per-sample heart-sound state. Presumably 0 = S1, 1 = systole,
        2 = S2, 3 = diastole, cycling in that order. This is what the
        alignment step below and the column layout of ``A`` assume.
    PCG : ndarray, shape (n,)
        Sound recording, sample-aligned with ``assigned_states``.
        No downsampling is performed here.

    Returns
    -------
    features : ndarray, shape (28,)
        In order:
        * mean/median/SD (rounded, in samples) of the RR, S1, S2, systole
          and diastole intervals;
        * mean/median/SD of the per-beat systole/RR, diastole/RR and
          systole/diastole ratios (percent);
        * mean/SD of the per-beat systole/S1 and diastole/S2 mean-absolute
          amplitude ratios (percent), restricted to ratios in (0, 100).
          Both are 0 when fewer than two beats qualify.
    A : ndarray of int, shape (N, 4)
        One row per complete heart cycle. The columns hold the sample index
        where S1, systole, S2 and diastole begin.

    Raises
    ------
    ValueError
        If the first state is not one of 0, 1, 2, 3.

    Notes
    -----
    At least two complete heart cycles are assumed. With fewer, the
    statistics are computed over empty arrays and are undefined.
    """
    states = np.asarray(assigned_states)

    # First sample of every new state (np.diff is non-zero one sample earlier).
    change_points = np.argwhere(np.abs(np.diff(states, axis=0)) > 0)[:, 0] + 1

    # Skip the transitions before the first S1 onset: starting in state s,
    # the transition with 0-based position (3 - s) is the first one into state 0.
    first_state = int(np.ravel(states)[0])
    if first_state not in (0, 1, 2, 3):
        raise ValueError('unexpected initial state %r; expected 0, 1, 2 or 3' % first_state)
    onsets = change_points[3 - first_state:]

    # Keep whole cycles only: one row of four segment onsets per heart cycle.
    num_cycles = len(onsets) // 4
    A = np.reshape(onsets[:4 * num_cycles], (num_cycles, 4))

    def _rounded_stats(values):
        # Mean, median and SD, each rounded to the nearest integer.
        return round(np.mean(values)), round(np.median(values)), round(np.std(values))

    def _mean_sd_in_range(ratios):
        # Ratios outside (0, 100) are ignored ("avoid the flat line signal");
        # with fewer than two usable beats both statistics are 0.
        usable = ratios[(ratios > 0) & (ratios < 100)]
        if len(usable) > 1:
            return np.mean(usable), np.std(usable)
        return 0, 0

    # Interval lengths in samples.
    rr = np.diff(A[:, 0])
    s1 = A[:, 1] - A[:, 0]
    systole = A[:, 2] - A[:, 1]
    s2 = A[:, 3] - A[:, 2]
    # Diastole of cycle i runs from its diastole onset to the S1 onset of
    # cycle i+1, so every pair of consecutive cycles contributes one interval
    # (MATLAB: A(2:end,1) - A(1:end-1,4)).
    diastole = A[1:, 0] - A[:-1, 3]

    # Per-beat interval ratios (percent), one entry per pair of consecutive cycles.
    beat_rr = A[1:, 0] - A[:-1, 0]
    ratio_sys_rr = systole[:-1] / beat_rr * 100
    ratio_dia_rr = diastole / beat_rr * 100
    ratio_sys_dia = ratio_sys_rr / ratio_dia_rr * 100

    # Per-beat mean absolute amplitude of each segment.
    abs_pcg = np.abs(np.asarray(PCG))

    def _mean_abs(start, stop):
        return np.sum(abs_pcg[start:stop]) / (stop - start)

    beats = range(num_cycles - 1)
    amp_s1 = np.array([_mean_abs(A[i, 0], A[i, 1]) for i in beats])
    amp_sys = np.array([_mean_abs(A[i, 1], A[i, 2]) for i in beats])
    amp_s2 = np.array([_mean_abs(A[i, 2], A[i, 3]) for i in beats])
    amp_dia = np.array([_mean_abs(A[i, 3], A[i + 1, 0]) for i in beats])

    # Amplitude ratios (percent); 0 where the reference segment is silent.
    amp_sys_s1 = np.array([s / r * 100 if r > 0 else 0.0 for s, r in zip(amp_sys, amp_s1)])
    amp_dia_s2 = np.array([d / r * 100 if r > 0 else 0.0 for d, r in zip(amp_dia, amp_s2)])

    features = np.array([
        *_rounded_stats(rr),
        *_rounded_stats(s1),
        *_rounded_stats(s2),
        *_rounded_stats(systole),
        *_rounded_stats(diastole),
        np.mean(ratio_sys_rr), np.median(ratio_sys_rr), np.std(ratio_sys_rr),
        np.mean(ratio_dia_rr), np.median(ratio_dia_rr), np.std(ratio_dia_rr),
        np.mean(ratio_sys_dia), np.median(ratio_sys_dia), np.std(ratio_sys_dia),
        *_mean_sd_in_range(amp_sys_s1),
        *_mean_sd_in_range(amp_dia_s2),
    ])

    return features, A
    
#%% extractFreqDomainFeatures

# This function calculates the DFT spectrum features for one recording. 
#
# INPUTS:
# A: array from segmentation code that contains indices for each segment of the cardiac cycle
# recording: a single audio recording
# Fs: sampling frequency of the recording

# OUTPUTS:
# features: 77x4 array of spectral magnitudes, averaged over heart cycles, where:
#       77 is the number of frequencies sampled every 10 Hz from 30 Hz to 790 Hz and
#       4 is the segment of the cardiac cycle (S1, Systole, S2, Diastole)

# Written by: Julia Ding, August 10, 2022

def find_nearest(array, value):
    """Return the index of the element of ``array`` closest to ``value``.

    ``array`` must support elementwise subtraction (e.g. a NumPy array).
    Ties resolve to the lowest index, following ``np.argmin``.
    """
    gap = abs(array - value)
    return np.argmin(gap)

def get_mean_freqspectrum(segment, Fs, freq_intervals_start):
    """Sample the magnitude spectrum of ``segment`` at the requested frequencies.

    Despite the name, nothing is averaged. For each frequency in
    ``freq_intervals_start``, the magnitude of the single nearest rFFT bin is
    taken. Ties go to the lower bin.

    Parameters
    ----------
    segment : array_like
        Time-domain samples (non-empty).
    Fs : float
        Sampling frequency of ``segment``.
    freq_intervals_start : array_like
        Frequencies (Hz) at which to read the spectrum.

    Returns
    -------
    ndarray of float64, shape (len(freq_intervals_start), 1)
    """
    magnitudes = np.abs(rfft(segment))
    bin_freqs = rfftfreq(len(segment), 1 / Fs)
    targets = np.asarray(freq_intervals_start)

    # Rows: rFFT bins; columns: requested frequencies. argmin over rows picks
    # the nearest bin for every requested frequency.
    nearest_bins = np.abs(bin_freqs[:, np.newaxis] - targets[np.newaxis, :]).argmin(axis=0)

    return np.asarray(magnitudes[nearest_bins], dtype=float).reshape(-1, 1)
            
def extractFreqDomainFeatures(A, recording, Fs):
    """Average per-segment spectra over all complete heart cycles.

    For every heart cycle (row ``i`` of ``A`` up to the second-to-last), the
    S1, systole, S2 and diastole segments of ``recording`` are cut out. The
    diastole ends at the next cycle's S1 onset ``A[i+1, 0]``. Each segment's
    magnitude spectrum is read at 30, 40, ..., 790 Hz via
    ``get_mean_freqspectrum``.

    Parameters
    ----------
    A : ndarray of int, shape (N, 4)
        Segment onsets per cycle (S1, systole, S2, diastole).
    recording : ndarray, shape (n,)
        Audio samples indexed by ``A``.
    Fs : float
        Sampling frequency of ``recording``.

    Returns
    -------
    ndarray, shape (77, 4)
        Spectral magnitudes (frequency x segment) averaged over the N-1 cycles.

    Raises
    ------
    ValueError
        If ``A`` has fewer than two rows, i.e. no complete heart cycle.
    """
    freq_intervals_start = np.arange(30, 800, 10)

    num_cycles = len(A) - 1
    if num_cycles < 1:
        raise ValueError('extractFreqDomainFeatures needs at least two rows in A '
                         '(one complete heart cycle), got %d' % len(A))

    per_cycle_spectra = []
    for cycle in range(num_cycles):
        s1_start, sys_start, s2_start, dia_start = A[cycle]
        next_s1_start = A[cycle + 1, 0]
        segments = (
            recording[s1_start:sys_start],      # S1
            recording[sys_start:s2_start],      # systole
            recording[s2_start:dia_start],      # S2
            recording[dia_start:next_s1_start], # diastole
        )
        # 77 (frequencies) x 4 (segments of the cardiac cycle)
        per_cycle_spectra.append(
            np.hstack([get_mean_freqspectrum(seg, Fs, freq_intervals_start) for seg in segments]))

    # Stacking along a new last axis keeps this 3-D even for a single cycle
    # (np.dstack on one 2-D array left no axis 2 to average over).
    return np.mean(np.stack(per_cycle_spectra, axis=2), axis=2)

#%% get_crafted_features

def get_crafted_features(data, recording):
    """Compute the hand-crafted feature vector for one PCG recording.

    Pipeline:
    1. Filter and despike the recording (``preprocess_data``).
    2. Segment it into heart-sound states with Springer's HSMM decoder at 50 Hz,
       using fixed pretrained model parameters (Potes et al. 2016).
    3. Expand the states back to the recording's sampling rate.
    4. Extract interval/amplitude features and per-segment spectral features.

    Parameters
    ----------
    data : patient metadata accepted by ``get_frequency``.
    recording : ndarray
        Raw audio samples.

    Returns
    -------
    ndarray of float32 or None
        Interval/amplitude features followed by the flattened spectral
        features, or None when the segmentation assigns the same state to
        every sample.

    NOTE(review): the spectral features are taken from the raw ``recording``
    while the segment indices come from the despiked signal; this assumes
    ``preprocess_data`` preserves length — confirm.
    """
    sampling_rate = get_frequency(data)

    # Low-pass/high-pass filtering plus spike removal.
    filtered = preprocess_data(recording, sampling_rate)

    # Springer PCG envelopes at the downsampled rate used in the Springer paper.
    segmentation_fs = 50
    springer_features = getSpringerPCGFeatures(filtered, sampling_rate, segmentation_fs)

    heart_rate, systolic_interval = getHeartRateSchmidt(filtered, sampling_rate, figures=False)

    # Pretrained logistic-regression / HSMM parameters (Potes et al. 2016),
    # originally loaded from Springer_*.mat files.
    b_matrix = np.array([
        [0.63857214, 0.09177584, 0.24603384, 0.26193571],
        [-2.21245101, 1.58873217, -0.43129579, 2.80462864],
        [1.28041437, -0.75384565, -0.13265229, -1.26759852],
        [-0.07861093, 0.54823381, 0.2668253, 0.02630179],
        [-0.04995835, 0.26066891, -0.24690963, 0.44108376],
    ])
    pi_vector = np.array([[0.25, 0.25, 0.25, 0.25]])
    total_obs_distribution = np.array(
        [
            [np.array([[3.05826000e-15, 4.93091067e-16, -6.21845104e-18, -1.53149158e-15]])],
            [np.array([
                [0.99841395, 0.90354361, 0.7425632, 0.70693174],
                [0.90354361, 0.99841395, 0.85038686, 0.79414174],
                [0.7425632, 0.85038686, 0.99841395, 0.706813],
                [0.70693174, 0.79414174, 0.706813, 0.99841395],
            ])],
        ],
        dtype=object,
    )

    _delta, _psi, decoded_states = viterbiDecodePCG_Springer(
        springer_features, pi_vector, b_matrix, total_obs_distribution,
        heart_rate, systolic_interval, segmentation_fs)

    states = expand_qt(decoded_states, segmentation_fs, sampling_rate, len(filtered))

    # A single state everywhere means no usable segmentation.
    state_column = states.T[0]
    if np.all(state_column == state_column[0]):
        return None

    physio_features, onsets = extractPhysioFeaturesFromHsIntervals(states, filtered)

    # Frequency-domain per-state features (idea from Tang et al. 2018,
    # "PCG Classification Using Multidomain Features and SVM Classifier").
    spectral_features = extractFreqDomainFeatures(onsets, recording, sampling_rate).reshape(-1)

    combined = np.concatenate((physio_features, spectral_features))
    return np.asarray(combined, dtype=np.float32)
    

# Features produced per recording by get_crafted_features
# (28 interval/amplitude features + 77 x 4 spectral features).
_CRAFTED_FEATURES_PER_LOCATION = 336
# Auscultation locations, in the order of their feature slots within a patient row.
_CRAFTED_LOCATIONS = ('AV', 'PV', 'TV', 'MV')


def one_patient_crafted_features(i, data_folder, patient_files, num_patient_files, crafted_features=None):
    """Compute the crafted-feature row for patient ``i``.

    Each recording whose location contains 'AV', 'PV', 'TV' or 'MV' (checked
    in that order) fills its 336-wide slot. Recordings at any other location
    are skipped. Slots without usable features stay 0. If several recordings
    share a location, the last one wins.

    Parameters
    ----------
    i : int
        Index into ``patient_files``.
    data_folder : str
        Folder holding the recordings.
    patient_files : sequence of str
        Patient metadata files.
    num_patient_files : int
        Total number of patients (progress output only).
    crafted_features : ndarray or None, optional
        If given, the computed slots are also written into row ``i`` of this
        array (legacy in-place behaviour; has no effect across processes).

    Returns
    -------
    ndarray, shape (1344,)
        Feature row for this patient.
    """
    print(str(i) + '/' + str(num_patient_files))
    # Load the current patient data and recordings.
    current_patient_data = load_patient_data(patient_files[i])
    current_recordings = load_recordings(data_folder, current_patient_data)

    row = np.zeros(_CRAFTED_FEATURES_PER_LOCATION * len(_CRAFTED_LOCATIONS))

    locations = get_locations(current_patient_data)
    for j, location in enumerate(locations):
        location_ind = next(
            (k for k, name in enumerate(_CRAFTED_LOCATIONS) if name in location), None)
        if location_ind is None:
            # No slot for this location. Previously this reused the prior
            # recording's slot or raised UnboundLocalError.
            continue

        current_features = get_crafted_features(current_patient_data, current_recordings[j])
        if current_features is None:
            continue

        start = _CRAFTED_FEATURES_PER_LOCATION * location_ind
        stop = start + _CRAFTED_FEATURES_PER_LOCATION
        row[start:stop] = current_features
        if crafted_features is not None:
            crafted_features[i, start:stop] = current_features

    return row


def calculate_and_save_crafted_features(data_folder):
    """Compute crafted features for every patient and pickle them.

    Writes an array of shape (num_patients, 1344) to ``crafted_features.pk``
    in the current working directory.
    """
    # Find the patient data files.
    patient_files = find_patient_files(data_folder)
    num_patient_files = len(patient_files)

    worker = partial(one_patient_crafted_features, data_folder=data_folder,
                     patient_files=patient_files, num_patient_files=num_patient_files)

    # Pool workers operate on pickled copies of their arguments, so writes to a
    # shared array would be lost; each row is returned to the parent instead.
    with Pool(4) as p:
        rows = p.map(worker, range(num_patient_files))

    crafted_features = np.zeros((num_patient_files, _CRAFTED_FEATURES_PER_LOCATION * len(_CRAFTED_LOCATIONS)))
    for i, row in enumerate(rows):
        crafted_features[i] = row

    with open('crafted_features.pk', 'wb') as file:
        pk.dump(crafted_features, file)
