####### PROCESSING_CODE.PY ######### 
# Updated by Stefano Magni 5 Apr 2021. 17:22

# This module contains the following functions: 
# extract_windows(ecg_rec, ecg_fs, window_len, channels, mode_use = 'train')
# dataset_generator(train_set, json_directory, json_save = True)
# create_header_label_data(header_list, classes, label_df = True)

# filtering(ecg, current_fs)
# from_ecg_to_rr(norm_filt_ecg, current_fs, len_time_signal)
# time_domain_features(rr_list, current_dir_filename)
# signal_features(header_list, recording_list, model_dir)

# name_list(head_list)
# move_to_test_folder(test_dir, list_header_dirs)
# move_data_back(directory_from, directory_to)
# new_list_test(test_dir)
# define_classes(weights_dir)

# POTENTIAL ISSUES 
# TO DOs 


#### LOAD REQUIREDs
import os, pickle
import biosppy
import shutil
import tensorflow as tf 
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
import math  # Needed only when taking age and gender from headers
import json
from scipy import signal
from helper_code import *
import challengePackage.eval_code as ec
import challengePackage.prediction_code as prc
import challengePackage.model_training_code as mtc
from hrvanalysis import get_time_domain_features
from sklearn import preprocessing
 
############ Creation of list of classes ############
def classes_list(classes):
    '''
        In processing_code (pc)

        Flatten the collection of class groups into a single list of strings,
        so that the stratification step (MLB / MSKF) gets hashable labels.

        Codes that are scored as equivalent are replaced by one merged label
        of the form 'codeA|codeB'; every label is kept only once, in order of
        first appearance.

        Args:
            classes: iterable of iterables of SNOMED CT codes (strings)
        Returns:
            list of unique label strings
    '''
    # Every code of an equivalent pair points at the merged label replacing it.
    merged_label = {
        '733534002': '164909002|733534002',
        '164909002': '164909002|733534002',
        '713427006': '59118001|713427006',
        '59118001': '59118001|713427006',
        '284470004': '284470004|63593006',
        '63593006': '284470004|63593006',
        '17338001': '17338001|427172004',
        '427172004': '17338001|427172004',
    }

    unique_labels = []
    already_seen = set()
    for group in classes:
        for code in group:
            label = merged_label.get(code, code)
            if label in already_seen:
                continue
            already_seen.add(label)
            unique_labels.append(label)

    return unique_labels


####### SAVING RESULTS OF TRAINING #######
def save_fold_model(filename_dir, model, labels_map, leads, scores, ix_fold, ix_dataset):
    '''
        In processing_code (pc)

        Store the outcome of one training fold as two files sharing the same
        base path: '<filename_dir>.json' with the fold metadata and
        '<filename_dir>.h5' written by model.save.

        Args:
            filename_dir: base path (no extension) of the two output files
            model: object exposing a save(path) method
            labels_map: iterable of classes; its items are stored under 'classes'
            leads: leads used by the model
            scores: scores obtained on the fold
            ix_fold: index of the fold
            ix_dataset: index of the dataset partition
    '''
    fold_summary = {
        'leads': leads,
        'classes': list(labels_map),
        'scores': scores,
        'ix_fold': ix_fold,
        'ix_dataset': ix_dataset,
    }

    # Metadata first, then the model itself.
    json_path = filename_dir + '.json'
    with open(json_path, 'w') as json_file:
        json.dump(fold_summary, json_file)

    model_path = filename_dir + '.h5'
    model.save(model_path)

def extract_windows(ecg_rec, ecg_fs, window_len, channels, mode_use='train'):
    '''
        In processing_code (pc)

        Resample a recording to the 257 Hz working frequency and cut it into
        windows of window_len samples. Windows are built channel-first and
        returned channel-last.

        The windows are taken from the RESAMPLED signal, so all indices refer
        to the 257 Hz time base.

        Dependences:
            signal, np, prediction_code (in challengePackage)
        Args:
            ecg_rec: recording of ecg, array indexed as (channel, sample)
            ecg_fs: sampling frequency of ecg_rec
            window_len: number of samples of each window
            channels: number of channels (leads) of each window
            mode_use: 'train' -> a single window placed at a random position;
                      'test'  -> every window given by prc.get_windows.
                      It is ignored when the resampled signal fits in one
                      window (the signal is zero padded instead).
        Returns:
            windows: nested list with shape (num_windows, window_len, channels)
        Raises:
            ValueError: the resampled signal is longer than a window and
                        mode_use is neither 'train' nor 'test'.
    '''
    target_fs = 257  # working sampling frequency
    overlap = 256    # overlap handed to prc.get_windows in 'test' mode

    # STEP 1. Resample to the working frequency
    duration_sec = ecg_rec.shape[1] / ecg_fs
    num_samples_new = int(duration_sec * target_fs)
    resampled = signal.resample(ecg_rec, num_samples_new, axis=1)

    def _padded(segment):
        # Copy the segment into a zero window; short segments stay zero padded.
        window = np.zeros((channels, window_len))
        window[:, :segment.shape[1]] = segment
        return window

    # STEP 2. Build the windows (channel first)
    if num_samples_new <= window_len:
        # Shared path between train and test: one zero-padded window.
        windows = [_padded(resampled)]
    elif mode_use == 'train':
        # One window even if the signal is longer; every valid start is eligible.
        start = np.random.randint(0, num_samples_new - window_len + 1)
        windows = [_padded(resampled[:, start:start + window_len])]
    elif mode_use == 'test':
        starts, ends = prc.get_windows(num_samples_new, window_len, overlap)
        windows = [_padded(resampled[:, st:en]) for st, en in zip(starts, ends)]
    else:
        raise ValueError(f"mode_use must be 'train' or 'test', got {mode_use!r}")

    # STEP 3. Channel first -> channel last, as expected by the model
    return np.moveaxis(np.array(windows), 1, -1).tolist()


######## CREATE DATA STRUCTURE ##########
def dataset_generator(train_set, json_directory, json_save=True):
    '''
        In processing_code (pc)

        Build three training partitions in which the two over-represented
        classes are under-sampled: records labelled ONLY as Normal Sinus
        Rhythm (NSR, '426783006') or ONLY as Sinus Bradycardia
        (SB, '426177001') are shuffled and split in three parts, and each
        partition receives one part of each plus all the remaining records:
        - Dataset 1: 1/3 of NSR + 1/3 of SB + rest of the dataset
        - Dataset 2: 1/3 of NSR + 1/3 of SB + rest of the dataset
        - Dataset 3: 1/3 of NSR + 1/3 of SB + rest of the dataset

        A label entry is accepted both as a plain code string and as a
        one-element set holding the code.

        Dependences:
            json, os, pandas, numpy
        Args:
            train_set: dataframe with columns 'file_names' and 'labels'
                       ('labels' holds one list of labels per record)
            json_directory: where to save datasets (created if missing)
            json_save: boolean. Set to true to save json files with partitions
        Returns:
            dataset_1, dataset_2, dataset_3: shuffled dataframes with a fresh
            0..n-1 index
    '''
    # STEP 0. Normal class and Sinus Bradycardia are defined (NSR and SB)
    sinus_rhythm = '426783006'
    sinus_bradycardia = '426177001'

    # dataframe in input is shuffled and index is reset
    train_set = train_set.sample(frac=1).reset_index(drop=True)

    # STEP 1. Find the records whose only label is NSR / SB
    print('Extracting normal class indexes ...')
    labels = train_set['labels']
    only_nsr = labels.map(lambda lab: _is_only_label(lab, sinus_rhythm)).astype(bool)
    only_sb = labels.map(lambda lab: _is_only_label(lab, sinus_bradycardia)).astype(bool)
    sinusal_df = train_set[only_nsr]
    brady_df = train_set[only_sb]
    print('Index extraction done.')

    # STEP 2. Everything else is shared by the three partitions
    rest_df = train_set[~(only_nsr | only_sb)]

    # STEP 3. NSR and SB records are shuffled and divided in 3 parts
    sinusal_parts = _split_in_three(sinusal_df)
    brady_parts = _split_in_three(brady_df)

    # STEP 4. Each partition = one SB part + one NSR part + the shared rest
    datasets = [
        _merge_and_shuffle(brady_part, sinusal_part, rest_df)
        for sinusal_part, brady_part in zip(sinusal_parts, brady_parts)
    ]
    dataset_1, dataset_2, dataset_3 = datasets
    print(f'{len(dataset_1)} is dataset len')

    # if the boolean specified is true then the partitions are saved
    if json_save:
        os.makedirs(json_directory, exist_ok=True)
        for number, dataset in enumerate(datasets, start=1):
            json_path = os.path.join(json_directory, f'dataset_{number}.json')
            with open(json_path, 'w') as json_file:
                json.dump(dataset['file_names'].to_dict(), json_file)

    # return 3 datasets.
    return dataset_1, dataset_2, dataset_3


def _is_only_label(label, code):
    # True when `label` holds exactly one entry and that entry is `code`
    # (either the code string itself or a one-element set containing it).
    if len(label) != 1:
        return False
    entry = label[0]
    return entry == code or entry == {code}


def _split_in_three(frame):
    # Shuffle the rows and cut them in three nearly equal consecutive parts.
    shuffled = frame.sample(frac=1)
    positions = np.array_split(np.arange(len(shuffled)), 3)
    return [shuffled.iloc[chunk] for chunk in positions]


def _merge_and_shuffle(*parts):
    # Concatenate the non-empty parts, shuffle the rows and renumber the index.
    non_empty = [part for part in parts if len(part)]
    if not non_empty:
        return parts[0].iloc[0:0].reset_index(drop=True)
    merged = pd.concat(non_empty, ignore_index=True)
    return merged.sample(frac=1).reset_index(drop=True)


####### FUNCTIONS TO DEAL WITH HEADER DATA IN THE TRAINING PHASE ####### 
"""
def define_classes(weights_dir):
    '''
        Inside processing code (pc)

        If you want to obtain list of ordered 26 classes just type [*labels_map]
        This function should be used whenever the list of classes is needed

        Dependences: 
            eval_code from challengePackage
        Args: 
            weights_dir: directory where weights.csv resides (challengePackage)
        Returns: 
            labels_map: dictionary that maps snowmed codes to integers
            
    '''
    # print('Extracting classes...')
    # Defining NSR class and equivalent classes 
    sinus_rhythm = '426783006'
    #equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]
    # Lists of snowmed codes and abbreviations taken from dx mapping scored
    #Abbreviation = ['IAVB', 'AF', 'AFL', 'Brady', 'CRBBB', 'IRBBB', 'LAnFB', 'LAD', 'LBBB', 'LQRSV', 'NSIVCB', 'PR',
    #                'PAC', 'PVC', 'LPR', 'LQT', 'QAb', 'RAD', 'RBBB', 'SA', 'SB', 'NSR', 'STach', 'SVPB', 'TAb', 'TInv',
    #                'VPB']
    Abbreviation = ['AF', 'AFL' , 'BBB' , 'Brady' , 'CLBBB|LBBB', 'CRBBB|RBBB', 'IAVB', 'IRBBB', 'LAD', 'LAnFB', 'LPR', 'LQRSV', 'LQT', 'NSIVCB', 'NSR', 'PAC|SVPB', 'PR', 'PRWP', 'PVC|VPB', 'QAb', 'RAD', 'SA', 'SB', 'STach', 'TAb', 'TInv']
    snowmed_codes = ['164889003', '164890007', '6374002', '426627000', '733534002|164909002', '713427006|59118001', '270492004', '713426002', '39732003', '445118002', '164947007', '251146004', '111975006', '698252002', '426783006', '284470004|63593006', '10370003', '365413008', '427172004|17338001', '164917005', '47665007', '427393009', '426177001', '427084000', '164934002', '59931005']
    # snowmed_codes = ['270492004', '164889003', '164890007', '426627000', '713427006', '713426002', '445118002',
    #                  '39732003', '164909002', '251146004', '698252002', '10370003', '284470004', '427172004',
    #                  '164947007', '111975006', '164917005', '47665007', '59118001', '427393009', '426177001',
    #                  '426783006', '427084000', '63593006', '164934002', '59931005', '17338001']
    #
    #  Load the scored classes for the Challenge metric only
    classes = ec.load_26_classes(weights_dir)
    # Sort classes 
    if all(is_integer(x) for x in classes):
        classes = sorted(classes, key=lambda x: int(x))  # Sort classes numerically if numbers.
    else:
        # print('Non all integer!')
        classes = sorted(classes)  # Sort classes alphanumerically otherwise.

    # print('Computing dictionary maps ...')
    # map from labels to integers
    labels_map = {classes[i]: i for i in range(len(classes))}

    snowmed_names = {}
    i = 0
    for sn_code in snowmed_codes:
        snowmed_names[sn_code] = Abbreviation[i]
        i += 1

    #inv_labels_map = {}
    #for k in range(len(classes)):
        #row_dict = {}
        #row_dict['snowmed'] = classes[k]
        #row_dict['abbrev'] = snowmed_names[classes[k]]
        #inv_labels_map[k] = row_dict

        # map from integers to labels
    # inv_labels_map = {i:classes[i] for i in range(len(classes))}

    #return labels_map, inv_labels_map
    return labels_map
    """
    
def create_header_label_data(header_list, classes, label_df = True):
    '''
        In processing_code (pc)

        This function takes all the headers and generates the structures used
        by the DataGenerator, so that headers are not reloaded at training time.

        TO DO: pass to the function window len and sampling frequency.

        Dependences:
            tqdm, numpy, pandas, os, helper_code (module in root dir),
            classes_list (this module)

        Args:
            header_list: list with header files with ext.
            classes: ordered list of sets of SNOMED CT codes; a record belongs
                     to class j when its labels intersect classes[j]
            label_df: Boolean to return also the dataframe from the function.
        Returns:
            df: (only when label_df is True, returned first) dataframe with two
                cols, 'file_names' and 'labels' (list of class labels, not one
                hot encoded). Records without any of the classes are dropped.
            header_data: dictionary with the header path without extension as
                key and as value a dictionary containing:
                - bounds: [left_eq, start, end] of the window
                - time_sec: duration in seconds
                - fs: sampling frequency
                - labels (list, one hot encoded)
                - leads (list)
                Only records with at least one of the classes are stored.

        Headers reporting a sampling frequency of 0 cannot be windowed: they
        are reported and left out of both outputs.
    '''
    target_fs = 257    # working sampling frequency used to size the window
    window_len = 4096  # window length in samples

    # STEP 0. Initialize dictionaries header_data for the DataGenerator
    # label_dict is used to save the labels for stratified K fold
    header_data = {}
    label_dict = {}

    # Loop invariants: number of classes and their flat (hashable) names.
    num_classes = len(classes)
    classes_flat = classes_list(classes)

    # STEP 1. Dataframe is generated from the header list (file_names)
    df_dirs_labels = pd.DataFrame(header_list, columns=['file_names'])
    print('Extracting labels ...')
    for current_dir in tqdm(header_list):
        # load current header from which all the information can be extracted.
        current_header = load_header(current_dir)

        # STEP 2. Time duration of the signal is computed. [s]
        current_fs = get_frequency(current_header)
        current_sig_len = get_num_samples(current_header)
        if current_fs == 0:
            # No duration can be computed: empty labels -> dropped below.
            print(f'Skipping {current_dir}: sampling frequency is 0')
            label_dict[str(current_dir)] = []
            continue
        current_time_sec = current_sig_len / current_fs
        num_samples_new = int(current_time_sec * target_fs)

        # STEP 3. Compute start and end samples of the window.
        if num_samples_new <= window_len:
            # only one window, covering the whole signal
            left_eq, start, end = num_samples_new, 0, num_samples_new
        else:
            left_eq = window_len
            # NOTE(review): the branch is chosen on the resampled length while
            # the start is drawn against the original length, as before;
            # confirm against the DataGenerator which signal the bounds index.
            max_start = current_sig_len - window_len
            start = np.random.randint(1, max_start) if max_start > 1 else 0
            end = start + window_len

        row_list = {
            'bounds': [left_eq, start, end],
            'time_sec': current_time_sec,
            'fs': current_fs,
        }

        # STEP 4. One hot encoding, ordered like classes
        header_labels = set(get_labels(current_header))
        one_hot_labels = np.zeros(num_classes, dtype=bool)
        for j, class_codes in enumerate(classes):
            if class_codes & header_labels:
                one_hot_labels[j] = True

        # STEP 5. Data is stored only when it belongs to one of the classes
        if one_hot_labels.any():
            row_list['labels'] = one_hot_labels.tolist()
            row_list['leads'] = get_leads(current_header)
            # key equal to the file name dir without extension
            root_hea, _ = os.path.splitext(current_dir)
            header_data[root_hea] = row_list

        # STEP 6. Labels for the dataframe: not one hot encoded
        # (needed for stratified k-fold)
        label_dict[str(current_dir)] = [
            name for j, name in enumerate(classes_flat) if one_hot_labels[j]
        ]

    # create column of the dataframe with the dictionary just created
    df_dirs_labels['labels'] = df_dirs_labels['file_names'].map(label_dict)
    # Data without any of the labels of the evaluation metric are dropped;
    # the surviving rows keep their original index.
    has_labels = df_dirs_labels['labels'].map(len) > 0
    df_dirs_labels = df_dirs_labels[has_labels]

    # I can decide not to return the label dataframe
    if label_df:
        return df_dirs_labels, header_data
    return header_data

"""
def label_encoding(classes, current_label):
    '''
    In processing code (pc)
    Dependences: 
        eval_code (in challengePackage), numpy
    Args: 
        labels_map: dictionary with classes mapping
        current_label: label taken from the header 

    Returns: 
        label: one hot encoded label

    Updated version to adapt to the new code

    '''
    my_labels = []
    # Define equivalent classes
    #equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]
    # current classes is ordered like labels_map
    num_classes = len(classes)
    # equally evaluated classes are made as one in classes by making them belong to the same set

    # initialize label vector 
    label_enc = np.zeros((num_classes), dtype=np.bool)
    # if the class is present 1 is put in the right position
    for j, x in enumerate(classes):
        # if x in labels:
        if x in current_label:
            # print('a class was found to be in the challenge')
            label_enc[j] = 1
            # print(j)
            my_labels.append(x)
            # print(x)

    return label_enc, my_labels
"""

def name_list(head_list):
    '''
        In processing_code (pc)

        Strip the extension from every path of the input list.

        Dependences:
            os
        Args:
            head_list: list of file paths (e.g. header files with extension)
        Returns:
            list with the same paths without their extension, in the same order
    '''
    return [os.path.splitext(file_path)[0] for file_path in head_list]


####### CODE ONLY FOR DEVELOPING MODE #######
def from_model_to_weights(dir_models):
    '''
        In processing_code (pc)

        For every '<root>.h5' model file inside dir_models, load the model and
        save its weights next to it as '<root>_weights.h5'.

        Hidden files (root starting with '.') are ignored. Files whose root
        already ends with '_weights' are ignored too: they are the output of a
        previous run, and skipping them keeps the function re-runnable on the
        same directory.

        Dependences:
            os, tensorflow
        Args:
            dir_models: directory containing the saved models
    '''
    for entry in os.listdir(dir_models):
        root, extension = os.path.splitext(entry)
        if root.startswith('.') or extension != '.h5':
            continue
        if root.endswith('_weights'):
            # weights exported by an earlier call, not a full model
            continue

        model_path = os.path.join(dir_models, entry)
        if not os.path.isfile(model_path):
            continue

        weights_path = os.path.join(dir_models, root + '_weights.h5')
        curr_model = tf.keras.models.load_model(model_path)
        curr_model.save_weights(weights_path)

    
def move_to_test_folder(test_dir, list_header_dirs):
    '''
        In processing_code (pc)

        Move header ('.hea') and recording ('.mat') of every listed record
        into the hold-out test directory. Nothing is moved unless the test
        directory is empty, so that it ends up holding exactly the desired
        files.

        Dependences:
            shutil, os
        Args:
            test_dir: directory where hold-out test data is placed.
            list_header_dirs: list of file paths without extension to be moved.
    '''
    # A non-empty destination is only reported; no file is touched.
    if os.listdir(test_dir):
        print("Not empty directory. Exiting ...")
        return

    print("Empty directory")
    for record_root in list_header_dirs:
        # header first, then the recording
        for suffix in ('.hea', '.mat'):
            shutil.move(record_root + suffix, test_dir)


def move_data_back(directory_from, directory_to):
    '''
        In processing_code (pc)

        Move every entry of directory_from into directory_to. Created to move
        the files out of the test folder when developing mode is turned off.

        A message is printed when the source holds something to move, and a
        second one when the source is empty at the end.

        Dependences:
            shutil, os
        Args:
            directory_from: directory where to take the files from
            directory_to: directory where to put the files in
    '''
    entries = os.listdir(directory_from)
    if entries:
        print("Non Empty directory specified. Proceding ...")
        for entry in entries:
            source_path = '/'.join((directory_from, entry))
            shutil.move(source_path, directory_to)

    # Confirm the outcome by listing the source again.
    if not os.listdir(directory_from):
        print("Now it is an Empty directory")


def new_list_test(test_dir):
    '''
        In processing_code (pc)

        List the records found in the test folder as paths without extension.

        Args:
             test_dir: directory of test from which the list is defined
        Returns:
            part_test: header paths returned by find_challenge_files, each one
                       without its extension
    '''
    # Only the headers are needed; the recordings list is discarded.
    header_files, _recording_files = find_challenge_files(test_dir)
    return [os.path.splitext(header_file)[0] for header_file in header_files]


###### DEPRECATED #######
# this becomes load labels

def one_hot_snowmed(labels_map, labels):
    '''
        In processing_code (pc)

        This function is necessary to perform encoding item by item.

        Args:
            labels_map: dictionary that maps SNOMED codes to integer positions
            labels: list of labels to encode.

        Returns:
            encoding: boolean array of length len(labels_map). True signifies
                      the presence of the class, False otherwise.
        Raises:
            KeyError: a label is not a key of labels_map.
    '''
    # The builtin bool is used as dtype: the np.bool alias no longer exists
    # in recent NumPy releases.
    encoding = np.zeros(len(labels_map), dtype=bool)
    for label in labels:
        encoding[labels_map[label]] = True

    return encoding


########### TIME FEATURES ##################
######## FUNCTIONS TO DEFINE SIGNAL FEATURES ##########
def signal_features(header_list, recording_list, model_dir):
    '''
        Creates a dictionary of features to be given to the DataGenerator.

        For every record the selected lead is filtered, converted to an RR
        series and described by time domain features. The features are then
        min-max scaled; the fitted scaler is pickled as 'minmax_scaler.pkl'
        inside model_dir.

        Dependences:
            helper code, model training code, os, pickle, pandas,
            sklearn.preprocessing, tqdm
        Args:
            header_list: Header dirs
            recording_list: recording dirs, in the same order as header_list
            model_dir: directory where the fitted scaler is saved
        Returns:
            features_dict: dictionary with the header path without extension
                           as key and the dictionary of scaled features as value
    '''
    # NOTE(review): ('II') is the plain string 'II', not a one-element tuple.
    # Left unchanged because the behaviour of mtc.subset_ch_recordings with a
    # string is not visible from here - confirm that lead II is what it selects.
    my_leads = ('II')

    feature_rows = []
    # STEP 1. Iterate over all files in input list.
    for i in tqdm(range(len(header_list))):
        # STEP 2. Load header and recording
        current_header = load_header(header_list[i])
        current_ecg = load_recording(recording_list[i])
        # STEP 3. Time duration of the signal is computed. [s]
        current_fs = get_frequency(current_header)
        current_sig_len = get_num_samples(current_header)
        current_time_sec = current_sig_len / current_fs
        # STEP 4. Preprocessing of the signal before filtering.
        ecg_ord = mtc.subset_ch_recordings(get_leads(current_header), current_ecg, my_leads)
        # STEP 5. Filtering and normalization
        filtered_norm_ecg = filtering(ecg_ord, current_fs)
        # STEP 6. RR series
        rr_series_list = from_ecg_to_rr(filtered_norm_ecg, current_fs, current_time_sec)
        # STEP 7. Feature Extraction
        root_hea, _ = os.path.splitext(header_list[i])
        # STEP 8. Collect the row; the dataframe is built once after the loop
        feature_rows.append(time_domain_features(rr_series_list, root_hea))

    # STEP 9. Outside the for cycle the dataframe is built and the name
    # column is set as index
    # NOTE(review): the column order defines the feature order of the saved
    # scaler; here it follows the key order of the feature dictionaries -
    # confirm it matches the order used where the scaler is loaded.
    df_features = pd.DataFrame(feature_rows)
    df_features = df_features.set_index('name')

    # STEP 10. Fit the min-max scaler and save it
    pickle_filename = os.path.join(model_dir, 'minmax_scaler.pkl')
    fitted_scaler = preprocessing.MinMaxScaler().fit(df_features[:].values)
    with open(pickle_filename, 'wb') as pickle_file:
        pickle.dump(fitted_scaler, pickle_file)

    # STEP 11. Scale the features in place and convert to dictionary
    df_features[:] = fitted_scaler.transform(df_features[:].values)
    features_dict = df_features.to_dict('index')

    return features_dict

def time_domain_features(rr_list, current_dir_filename):
    '''
        Compute the time domain features of an RR series and tag them with the
        record they belong to.

        Dependences:
            os,
            from hrvanalysis import get_time_domain_features
        Args:
            rr_list: rr series list
            current_dir_filename: path of the record file; any extension is
                                  removed
        Returns:
            dictionary returned by get_time_domain_features with the extra key
            'name' holding the path without extension

    '''
    # TO DO: Add any other time feature based on RR series
    features = get_time_domain_features(rr_list)
    features['name'] = os.path.splitext(current_dir_filename)[0]
    return features

def from_ecg_to_rr(norm_filt_ecg, current_fs, len_time_signal):
    '''
        Take the ecg signal as input and return the RR series using find peaks.
        The detection of R peaks is based on the definition of a time and
        amplitude thresholds. Only the first row of the input is used.

        Dependences:
            biosppy, signal, numpy
        Args:
            norm_filt_ecg: filtered and normalized ecg signal; row 0 is analysed
            current_fs: sampling frequency of the signal
            len_time_signal: time duration of the signal in seconds

        Returns:
            rr_series: list of RR intervals [ms] without outliers, for
                       feature extraction (empty when fewer than two peaks
                       are found)

    '''
    lead = norm_filt_ecg[0]

    # STEP 1. Define the height threshold: mean + 2 standard deviations.
    peak_height = np.mean(lead) + 2 * np.std(lead)
    # STEP 2. Define time threshold: 0.15 s between peaks, in samples.
    time_th = 0.15 * current_fs
    # STEP 3. Find R peaks
    rpeaks, _ = signal.find_peaks(lead, height=peak_height, distance=time_th)

    # STEP 4. With fewer than 0.3 peaks per second, or fewer than 3 peaks, a
    # different method is used. (Pan Tompkins based)
    if len(rpeaks) / len_time_signal < 0.3 or len(rpeaks) < 3:
        rpeaks = biosppy.signals.ecg.ecg(lead, sampling_rate=current_fs, show=False)[2]

    # STEP 5. Time at which the R peaks happen [ms] and their differences
    rpeaks_ms = np.array([x / current_fs * 1000 for x in rpeaks])
    rr_series = np.diff(rpeaks_ms)
    if rr_series.size == 0:
        return []

    # STEP 6. Outlier removal using a threshold. Taken From:
    # https://www.frontiersin.org/articles/10.3389/fphys.2012.00045/full
    # Mean and std refer to the uncorrected series and are computed once.
    deviation = np.abs(rr_series - np.mean(rr_series))
    is_outlier = deviation > 5 * np.std(rr_series)

    return rr_series[~is_outlier].tolist()

def filtering(ecg, current_fs):
    '''
        Filter the input signal for QRS complex detection and normalize it in
        the range [-1, 1].
        IIR butterworth 2nd order filter bandpass [5, 15] Hz, applied forward
        and backward along axis 1.

        Dependences:
            signal, numpy
        Args:
            ecg: input signal to be filtered, 2D array filtered along axis 1
            current_fs: sampling frequency of the signal

        Returns:
            norm_filt_ecg: filtered and normalized version of the input signal.
                           Minimum and maximum are taken over the whole array.
                           A filtered signal with no excursion (max == min) is
                           returned as all zeros.
    '''
    ####### PARAMETERS OF THE FILTER ##########
    butter_order = 2
    f_cut_low = 5
    f_cut_high = 15
    band_type = 'bandpass'
    ###########################################

    # STEP 1. Definition of the filter using sos to avoid numerical errors.
    current_sos = signal.butter(butter_order, [f_cut_low, f_cut_high], band_type,
                                fs=current_fs, output='sos', analog=False)

    # STEP 2. Zero-phase filtering of the signal along axis 1
    filtered_ecg = signal.sosfiltfilt(current_sos, ecg, axis=1)

    # STEP 3. Signal normalization range [-1, 1]
    centered = filtered_ecg - np.mean(filtered_ecg)
    my_max = np.amax(centered)
    my_min = np.amin(centered)
    span = my_max - my_min
    if span == 0:
        # Flat signal: the rescaling would be 0/0, return zeros instead of NaN.
        return np.zeros_like(centered)

    return 2 * ((centered - my_min) / span) - 1
