####### MODEL_TRAINING_CODE.PY ######### 
# Updated 24 Mar 2021. 17:36 - Stefano Magni
# This module contains the following functions: 
#
# class DataGenerator(tf.keras.utils.Sequence)
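# model_training_steps(model_name, training_gen, validation_gen, model_weights='', fine_tuning=False,
#     freeze_u=110, num_features=num_feats, es=True, sd=True, num_ResBs=8, channels=nch,
#     window_len=wl, num_classes=ncl, bs=b_s, epochs=50, train_deep=False)
# define_callbacks(fine_tuning, earlyStopping=True, step_d=True)
# step_decay(epoch)
# step_decay_fine_tuning(epoch) : schedule used for fine-tuning (currently the same values as step_decay)
# resnet_se_modified_wide(N=8, ch=12, win_len=4096, num_wide_features=16, classes=26)
# resnet_se_modified_nowide(N=8, ch=12, win_len=4096, classes=26)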
# ResBs_Conv(block_input, num_filters)
# ResBs_Identity(block_input, num_filters)
# se_block(block_input, num_filters, ratio=16)
# subset_ch_recordings(available_leads, recording, lead_selected)
#
#
# POTENTIAL ISSUES 
# TO DOs 

import tensorflow as tf
import os
import numpy as np
from scipy.io import loadmat      # Required to load .mat files
from scipy import signal
import math
import challengePackage.processing_code as pc
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv1D, MaxPooling1D, Dense, Flatten, Dropout, concatenate
from tensorflow.keras.layers import BatchNormalization, Activation, Reshape, GlobalAveragePooling1D, GlobalMaxPooling1D, multiply, Add
from keras.regularizers import l2
from keras.callbacks import LearningRateScheduler

from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession


##### DEFAULTS FOR MODEL TRAINING #####
# Forwarded as `use_multiprocessing` to every Keras fit() call in model_training_steps.
use_gpu = True
# Defaults shared by model_training_steps, define_callbacks and DataGenerator.
data_directory = './dataset'   # default DataGenerator data_dir
pa = 10          # EarlyStopping patience (epochs), used when early stopping is enabled
b_s = 8          # batch size
wl = 4096        # window length in samples
nch = 12         # number of ECG channels (leads)
ncl = 26         # number of output classes
fs = 257         # sampling frequency (Hz) stored by DataGenerator
num_feats = 16   # number of wide (hand-crafted) features
twelve_leads = (
    'I', 'II', 'III',
    'aVR', 'aVL', 'aVF',
    'V1', 'V2', 'V3', 'V4', 'V5', 'V6',
)
########################################

#### FUNCTIONS FOR MODEL TRAINING #####
def model_training_steps(model_name, training_gen, validation_gen, model_weights = '', fine_tuning= False, freeze_u = 110, num_features = num_feats, es = True, sd = True, num_ResBs= 8, channels = nch, window_len = wl, num_classes = ncl,  bs= b_s, epochs= 50, train_deep = False): 
    '''
        Build, compile and fit the wide+deep ResNet-SE model.

        In model training code (mtc)

        Dependences: 
            tensorflow, numpy

        Training modes (checked in this order):
            1. Fine-tuning (fine_tuning=True and model_weights != ''):
               the deep-only model is loaded from `model_weights`, the weights of
               its leading layers are copied into the wide model, the first
               `freeze_u` layers of the wide model are frozen and the wide model
               is fitted for `epochs` epochs. If loading/copying fails, a message
               is printed and the wide model is fitted with no frozen layers.
            2. Two-step training (train_deep=True): the deep-only model is fitted
               for ceil(0.7 * epochs) epochs, the weights of its leading layers
               are copied into the wide model, the first `freeze_u` layers are
               frozen and the wide model is fitted for the remaining epochs.
            3. Otherwise the wide model is fitted end to end for `epochs` epochs.

        Args: 
            model_name: not used at the moment; can be useful for saving the model inside this function.
            training_gen: training generator.
            validation_gen: validation generator.
            model_weights: path of the deep-only model weights used for fine-tuning. ('')
            fine_tuning: enables fine-tuning and the fine-tuning LR schedule. (False)
            freeze_u: number of leading layers of the wide model to freeze. (110)
            num_features: number of wide features. (num_feats = 16)
            es: boolean for early stopping. (True)
            sd: boolean for step decay. (True)
            num_ResBs: number of resnet blocks. (8)
            channels: number of channels. (nch = 12)
            window_len: window length. (wl = 4096)
            num_classes: classes on which to perform prediction. (ncl = 26)
            bs: batch size; not used here, batches come from the generators. (b_s = 8)
            epochs: total number of epochs. (50)
            train_deep: enables the two-step training. (False)

        Returns: 
            train_history: Keras History of the last fit() call (on the wide model).
            model: the trained wide model.
    '''
    if fine_tuning and model_weights != '':
        print('Trying finetuning of a previous model')
        # Deep-only model (to receive the pretrained weights) and the deep+wide model.
        model_nowide = resnet_se_modified_nowide(N=num_ResBs, ch=channels, win_len=window_len, classes=num_classes)
        model_wide = resnet_se_modified_wide(N=num_ResBs, ch=channels, num_wide_features=num_features, win_len=window_len, classes=num_classes)

        try:
            print("Loading weights ...")
            model_nowide.load_weights(model_weights)
            _copy_leading_weights(model_nowide, model_wide, _FINE_TUNING_COPIED_LAYERS)
        except Exception as err:  # best effort: any load/copy failure falls back to training from scratch
            print("Could not load weights, training from scratch ...")
            print(f'Reason: {err}')
        else:
            # Freeze only when the pretrained weights are in place; freezing
            # randomly initialised layers would leave them untrained.
            _freeze_leading_layers(model_wide, freeze_u)

        _compile_model(model_wide)
        train_history = _fit_model(model_wide, training_gen, validation_gen, epochs, fine_tuning, es, sd)
        return train_history, model_wide

    print("Training the model from scratch ...")

    if not train_deep:
        # Single-step training of the whole wide+deep network.
        print('Training the net all together...')
        model_wide = resnet_se_modified_wide(N=num_ResBs, ch=channels, num_wide_features=num_features, win_len=window_len, classes=num_classes)
        _compile_model(model_wide)
        train_history = _fit_model(model_wide, training_gen, validation_gen, epochs, fine_tuning, es, sd)
        return train_history, model_wide

    # Two-step training: deep-only model first, then the wide model on top of it.
    model_nowide = resnet_se_modified_nowide(N=num_ResBs, ch=channels, win_len=window_len, classes=num_classes)
    model_wide = resnet_se_modified_wide(N=num_ResBs, ch=channels, num_wide_features=num_features, win_len=window_len, classes=num_classes)

    print('Starting with training of the ResNet SE...')
    deep_epochs = int(np.ceil(0.7*epochs))
    print(f'The number of epochs for Deep network are: {deep_epochs}')
    _compile_model(model_nowide)
    _fit_model(model_nowide, training_gen, validation_gen, deep_epochs, fine_tuning, es, sd)

    print('ResNET already trained. Training the wide branch...')
    print('Loading ResNET SE weights...')
    _copy_leading_weights(model_nowide, model_wide, _TWO_STEP_COPIED_LAYERS)
    # Trainable flags must be set before compile() to take effect.
    _freeze_leading_layers(model_wide, freeze_u)

    wide_epochs = int(np.ceil(epochs - deep_epochs))
    _compile_model(model_wide)
    train_history = _fit_model(model_wide, training_gen, validation_gen, wide_epochs, fine_tuning, es, sd)
    return train_history, model_wide


# Number of leading layers whose weights are copied from the deep-only model into
# the wide model, one value per training path of model_training_steps.
# NOTE(review): the two paths copy a different number of layers (112 vs 110), and
# both counts were chosen for num_ResBs=8. Confirm with model.summary() that they
# cover exactly the shared ECG trunk if the architecture changes.
_FINE_TUNING_COPIED_LAYERS = 112
_TWO_STEP_COPIED_LAYERS = 110


def _compile_model(model):
    '''Compile `model` with the loss, optimizer and metrics shared by every training path.'''
    model.compile(
        loss=tf.keras.losses.BinaryCrossentropy(),
        optimizer=tf.keras.optimizers.Adam(),
        metrics=['accuracy', tf.keras.metrics.Recall(name = 'recall'), tf.keras.metrics.Precision(name='precision')],
    )


def _fit_model(model, training_gen, validation_gen, n_epochs, fine_tuning, es, sd):
    '''Fit `model` on the generators with a freshly built callback list and return the History.'''
    callbacks = define_callbacks(fine_tuning, earlyStopping=es, step_d=sd)
    return model.fit(
        training_gen,
        epochs=n_epochs,
        verbose=1,
        callbacks=callbacks,
        validation_data=validation_gen,
        shuffle=False,
        use_multiprocessing=use_gpu,
    )


def _copy_leading_weights(source, target, num_layers):
    '''Copy the weights of the first `num_layers` layers of `source` into the same-index layers of `target`.'''
    for src_layer, dst_layer in zip(source.layers[:num_layers], target.layers[:num_layers]):
        dst_layer.set_weights(src_layer.get_weights())


def _freeze_leading_layers(model, num_layers):
    '''Mark the first `num_layers` layers of `model` as non-trainable.'''
    for layer in model.layers[:num_layers]:
        layer.trainable = False

def define_callbacks(fine_tuning, earlyStopping=True, step_d=True): 
    '''
        Build the list of Keras callbacks used by model_training_steps.

        In model training code (mtc)

        Dependences: 
            tensorflow, step_decay, step_decay_fine_tuning
        Args: 
            fine_tuning: if True the learning-rate scheduler uses
                step_decay_fine_tuning, otherwise step_decay.
            earlyStopping: adds EarlyStopping on 'val_loss' with patience `pa`.
            step_d: adds a LearningRateScheduler with the step-decay schedule.
        Return: 
            callbacks: list of keras callbacks (scheduler first, when enabled).
    '''
    # Step-decay schedule (see step_decay / step_decay_fine_tuning for the values).
    schedule = step_decay_fine_tuning if fine_tuning else step_decay
    callbacks = [tf.keras.callbacks.LearningRateScheduler(schedule)] if step_d else []

    if earlyStopping:
        callbacks.append(tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=pa))

    return callbacks

def step_decay(epoch):
    '''
        In model training code (mtc)

        Dependences: 
            math

        Step-decay learning-rate schedule for keras' LearningRateScheduler.
        lr(epoch) = 0.003 * 0.1 ** floor((epoch + 1) / 10), so the rate is
        divided by 10 every 10 epochs (first drop at epoch index 9).

        Keep the single `epoch` parameter: LearningRateScheduler first tries
        schedule(epoch, lr), and a second positional parameter would receive
        the current learning rate.
    '''
    base_lr, decay_factor, drop_every = 0.003, 0.1, 10.0
    completed_drops = math.floor((epoch + 1) / drop_every)
    return base_lr * math.pow(decay_factor, completed_drops)

def step_decay_fine_tuning(epoch):
    '''
        In model training code (mtc)

        Dependences: 
            math

        Step-decay learning-rate schedule used when fine-tuning.
        lr(epoch) = 0.003 * 0.1 ** floor((epoch + 1) / 10): the rate is divided
        by 10 every 10 epochs (first drop at epoch index 9).

        NOTE(review): these values are currently identical to step_decay. If
        fine-tuning is meant to use a smaller learning rate, lower base_lr here.
    '''
    base_lr, decay_factor, drop_every = 0.003, 0.1, 10.0
    completed_drops = math.floor((epoch + 1) / drop_every)
    return base_lr * math.pow(decay_factor, completed_drops)

##### FUNCTIONS TO DEFINE THE MODEL #####
def resnet_se_modified_wide(N=8, ch=12, win_len=4096, num_wide_features= 16, classes=26): 
    ''' 
        Build the wide+deep ResNet-SE classifier (ECG branch + wide-feature branch).

        In model training code (mtc)

        Update 7 Apr 2021 20:05 - Stefano Magni 

        Dependences: 
            Input, Dense, Flatten, Conv1D, BatchNormalization, Activation, MaxPooling1D,
            GlobalMaxPooling1D, concatenate, Model, ResBs_Conv, ResBs_Identity

        ResNet with Squeeze and excitation blocks modified as suggested by: 
        Adaptive lead Weighted ResNet Trained With Different Duration Signals for 
        Classifying 12-lead ECGs. Zhao, Wong et al. 2020

        Args:
            N: number of residual blocks. Two identity blocks with 64 filters are
               followed by int((N - 2) / 2) (conv block, identity block) pairs;
               each pair doubles the filters and halves the length.
            ch: number of ECG channels (leads).
            win_len: window length in samples.
            num_wide_features: length of the wide feature vector.
            classes: number of sigmoid outputs.

        Returns:
            Keras Model with inputs [ecg_signal (win_len, ch), wide_features (num_wide_features,)]
            and output layer 'sigmoid_classifier'.
    '''
    # NOTE: model_training_steps copies weights between this model and the deep-only
    # model by layer index, so keep the layer construction order and graph unchanged.
    # NOTE(review): an earlier comment stated that the wide-features layer sits at
    # index 111 of model.layers; verify with model.summary() before relying on it.

    # A. Wide branch: a single Dense(10, relu) on the feature vector.
    features_in = Input(shape= (num_wide_features, ), name = 'wide_features')
    wide = Dense(10, activation='relu')(features_in)
    wide = Flatten()(wide)

    # B. Deep branch on the channel-last ECG window (batch, win_len, ch).
    signal_in = Input(shape=(win_len, ch), name='ecg_signal')
    deep = Conv1D(filters=64, kernel_size=15, padding='same')(signal_in)
    deep = BatchNormalization()(deep)
    deep = Activation('relu')(deep)
    deep = MaxPooling1D(pool_size=2, strides=2)(deep)

    # Two identity blocks at 64 filters ...
    for _ in range(2):
        deep = ResBs_Identity(deep, 64)
    # ... then (conv, identity) pairs at 128, 256, ... filters.
    n_pairs = int((N - 2) / 2)
    for stage in range(1, n_pairs + 1):
        width = 64 * 2 ** stage
        deep = ResBs_Conv(deep, width)
        deep = ResBs_Identity(deep, width)

    deep = GlobalMaxPooling1D(name='gmp_layer')(deep)
    deep = Flatten()(deep)

    # C. Shared head: concatenate both branches and classify with a sigmoid per class.
    merged = concatenate([deep, wide], name='concat_layer')
    scores = Dense(classes, activation='sigmoid', name='sigmoid_classifier')(merged)
    return Model(inputs=[signal_in,  features_in], outputs=scores)

def resnet_se_modified_nowide(N=8, ch=12, win_len=4096, classes=26): 
    ''' 
        Build the deep-only ResNet-SE classifier (ECG branch, no wide features).

        In model training code (mtc)

        Update 7 Apr 2021 20:05 - Stefano Magni 

        Dependences: 
            Input, Flatten, Conv1D, BatchNormalization, Activation, MaxPooling1D,
            GlobalMaxPooling1D, Dense, Model, ResBs_Conv, ResBs_Identity

        ResNet with Squeeze and excitation blocks modified as suggested by: 
        Adaptive lead Weighted ResNet Trained With Different Duration Signals for 
        Classifying 12-lead ECGs. Zhao, Wong et al. 2020

        Args:
            N: number of residual blocks. Two identity blocks with 64 filters are
               followed by int((N - 2) / 2) (conv block, identity block) pairs;
               each pair doubles the filters and halves the length.
            ch: number of ECG channels (leads).
            win_len: window length in samples.
            classes: number of sigmoid outputs.

        Returns:
            Keras Model with input ecg_signal (win_len, ch) and output layer 'sigmoid_classifier'.
    '''
    # NOTE: model_training_steps copies this model's leading layers into the wide
    # model by index, so keep the trunk identical to resnet_se_modified_wide.

    # Deep branch on the channel-last ECG window (batch, win_len, ch).
    signal_in = Input(shape=(win_len, ch), name='ecg_signal')
    deep = Conv1D(filters=64, kernel_size=15, padding='same')(signal_in)
    deep = BatchNormalization()(deep)
    deep = Activation('relu')(deep)
    deep = MaxPooling1D(pool_size=2, strides=2)(deep)

    # Two identity blocks at 64 filters ...
    for _ in range(2):
        deep = ResBs_Identity(deep, 64)
    # ... then (conv, identity) pairs at 128, 256, ... filters.
    n_pairs = int((N - 2) / 2)
    for stage in range(1, n_pairs + 1):
        width = 64 * 2 ** stage
        deep = ResBs_Conv(deep, width)
        deep = ResBs_Identity(deep, width)

    deep = GlobalMaxPooling1D(name='gmp_layer')(deep)
    deep = Flatten()(deep)

    # Head: one sigmoid per class.
    scores = Dense(classes, activation='sigmoid', name='sigmoid_classifier')(deep)
    return Model(inputs=[signal_in], outputs=scores)

def ResBs_Conv(block_input, num_filters): 
    '''
        Down-sampling residual block with squeeze-and-excitation.

        In model training code (mtc)

        Created 16 feb 2021 - Stefano Magni
        Dependences: 
            Conv1D, BatchNormalization, ResBs_Identity
        Args:
            block_input: input tensor to the ResNet block (any channel count).
            num_filters: number of output channels of the block.

        Returns:
            Activated tensor after the residual addition. Its length is
            ceil(len / 2) (stride-2 'same' convolution) and it has num_filters channels.

        The block is a stride-2 Conv1D + BatchNorm projection followed by an
        identity residual block (ResBs_Identity) on the projected tensor.
        NOTE(review): unlike a standard ResNet projection block, both the residual
        path and the shortcut start from the projected tensor, so there is no
        separate shortcut projection of the original input.
    '''
    # 0. Projection: halves the length and sets the channel count to num_filters.
    projected = Conv1D(num_filters, kernel_size=7, strides = 2,  padding = 'same')(block_input)
    projected = BatchNormalization()(projected)

    # 1-5. Conv/BN/ReLU/Dropout, Conv/BN, SE, residual Add and ReLU, exactly as in the identity block.
    return ResBs_Identity(projected, num_filters)

def ResBs_Identity(block_input, num_filters): 
    '''
        Identity residual block with squeeze-and-excitation.

        In model training code (mtc)

        Created 16 feb 2021 - Stefano Magni
        Dependences: 
            Conv1D, BatchNormalization, Activation, Dropout, Add, se_block
        Args:
            block_input: input tensor to the ResNet block. Its channel count must
                already equal num_filters, because it is added unchanged to the
                residual path.
            num_filters: number of filters of both convolutions.

        Returns:
            ReLU(block_input + SE(BN(Conv(Dropout(ReLU(BN(Conv(block_input))))))))
    '''
    # 1. First convolution -> BatchNorm -> ReLU -> Dropout(0.2).
    conv1 = Conv1D(filters=num_filters, kernel_size=7, padding= 'same')(block_input)
    norm1 = BatchNormalization()(conv1)
    relu1 = Activation('relu')(norm1)
    dropout = Dropout(0.2)(relu1)

    # 2. Second convolution -> BatchNorm. 'same' padding keeps the length so the sum below is valid.
    conv2 = Conv1D(num_filters, kernel_size=7, padding= 'same')(dropout)
    norm2 = BatchNormalization()(conv2)

    # 3. Squeeze-and-excitation channel re-weighting.
    se = se_block(norm2, num_filters=num_filters)

    # 4. Residual connection (named `merged` so it does not shadow the builtin `sum`).
    merged = Add()([block_input, se])

    # 5. Final activation.
    return Activation('relu')(merged)

# Define SE Block
def se_block(block_input, num_filters, ratio=16):
    '''
        Squeeze-and-excitation block: rescale each channel of block_input by a
        learned weight in (0, 1).

        In model training code (mtc)

        Created 16 feb 2021 - Stefano Magni
        Dependences: 
            GlobalAveragePooling1D, Reshape, Dense, multiply
        Args:
            block_input: channel-last tensor. Its last axis is assumed to have
                num_filters channels; the final multiply requires it.
            num_filters: number of channels of block_input.
            ratio: reduction factor of the bottleneck (num_filters // ratio units).

        Returns:
            se_scale: block_input multiplied channel-wise by the excitation weights.
    '''
    # Squeeze: average each channel over the time axis (channel-last default),
    # then reshape to (batch, 1, num_filters) so it broadcasts over time.
    squeezed = GlobalAveragePooling1D()(block_input)
    squeezed = Reshape((1, num_filters))(squeezed)

    # Excitation: bottleneck Dense(relu), then Dense(sigmoid) back to num_filters.
    bottleneck = Dense(num_filters // ratio, activation='relu')(squeezed)
    channel_weights = Dense(num_filters, activation='sigmoid')(bottleneck)

    # Scale: channel-wise product with the input.
    return multiply([block_input, channel_weights])


#### CODE FOR CHANNEL EXTRACTION ####

def subset_ch_recordings(available_leads, recording, lead_selected): 
    '''
        In model training code (mtc)

        Select and reorder the rows (leads) of a recording so that they follow
        `lead_selected`.
        Dependences: 
            None

        Args: 
            available_leads: sequence of lead names, in the row order of `recording`
                (taken from the header).
            recording: 2-D array whose rows correspond to available_leads.
            lead_selected: leads to keep, in the desired output order.
        Returns: 
            recording: recording[positions, :], with one row per entry of lead_selected.
        Raises:
            ValueError: if a selected lead is not in available_leads.
    '''
    positions = [available_leads.index(name) for name in lead_selected]
    return recording[positions, :]

class DataGenerator(tf.keras.utils.Sequence):
    
    '''
        Keras Sequence that yields batches of fixed-length ECG windows, wide
        features and labels for the wide+deep model.

        In model training code (mtc)

        Update 24 Mar 2021 13:23 - Stefano Magni
        Dependences: 
            numpy, loadmat, subset_ch_recordings

        Args:
            header_dict: dict keyed by str(ID). Each entry is read for 'leads'
                (lead names in row order of the recording), 'labels' and 'bounds'.
            feature_dict: dict keyed by str(ID). The values of each entry, in
                insertion order, are used as that record's wide features.
            list_IDs: IDs of the records, each a file path without the '.mat'
                extension (split into train and validation by the caller).
            lead_selected: leads to keep, in order. (twelve_leads)
            data_dir: stored only; file paths are built from the IDs. ('./dataset')
            batch_size: records per batch. (b_s)
            dim: window length in samples. (wl = 4096)
            n_classes: stored only. (ncl)
            sampling_freq: stored only; see the NOTE in __data_generation. (fs = 257)
            shuffle: shuffle the record order at every epoch end. (True)
                Set to False for reproducibility.

        Each batch is a tuple (inputs, targets):
            inputs: {'ecg_signal': float array (batch, dim, len(lead_selected)),
                     'wide_features': np.array of the per-record feature lists}
            targets: {'sigmoid_classifier': np.array of the header 'labels' entries}
        The last incomplete batch of an epoch is dropped (see __len__).
    '''

    def __init__(self, header_dict,feature_dict, list_IDs, lead_selected=twelve_leads, data_dir=data_directory, batch_size=b_s, dim= wl,
                 n_classes=ncl, sampling_freq=fs, shuffle=True):
        '''Store the configuration and build the first epoch's record order.'''
        self.dim = dim
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.list_IDs = list_IDs
        self.lead_selected = lead_selected
        self.header_dict = header_dict
        self.feature_dict = feature_dict
        self.n_classes = n_classes
        self.sampling_freq = sampling_freq
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        '''Number of full batches per epoch (a trailing partial batch is dropped).'''
        return len(self.list_IDs) // self.batch_size

    def __getitem__(self, index):
        '''Return batch number `index` as (inputs, targets).'''
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        return self.__data_generation(list_IDs_temp)

    def on_epoch_end(self):
        '''Rebuild (and optionally shuffle) the record order after each epoch.'''
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        '''Load the given records and assemble one batch (inputs dict, targets dict).'''
        full_windows = []
        wide_features = []
        full_labels = []

        for ID in list_IDs_temp:
            key = str(ID)

            # 1. Load the recording. 'val' holds the (num_channels, num_samples) signal.
            current_signal = loadmat(key + '.mat')['val']
            current_header = self.header_dict[key]
            # Keep only the selected leads, in the selected order.
            current_signal = subset_ch_recordings(current_header['leads'], current_signal, self.lead_selected)

            # NOTE(review): the window is cut from the recording at its native
            # sampling rate; self.sampling_freq is not applied. If 'bounds' were
            # computed on a signal resampled to self.sampling_freq, resample here
            # (scipy.signal.resample along axis=1) before slicing.

            # 2. Copy recording[:, start:end] into the first `left_eq` samples of a
            #    zero-filled window. Shorter segments are zero-padded on the right,
            #    and the two lengths must match.
            bounds = current_header['bounds']
            left_eq = int(bounds[0])
            my_start = int(bounds[1])
            my_end = int(bounds[2])
            window = np.zeros((len(current_signal), self.dim))
            window[:, 0:left_eq] = current_signal[:, my_start:my_end]
            # Channel-first (leads, dim) -> channel-last (dim, leads) for 'ecg_signal'.
            full_windows.append(window.T)

            # 3. Labels are passed through exactly as stored in the header.
            full_labels.append(current_header['labels'])

            # 4. Wide features: the values of the per-record feature dict.
            wide_features.append(list(self.feature_dict[key].values()))

        inputs = {'ecg_signal': np.array(full_windows), 'wide_features': np.array(wide_features)}
        targets = {'sigmoid_classifier': np.array(full_labels)}
        return inputs, targets
    

