#!/usr/bin/env python

# Edit this script to add your team's training code.
# Some functions are *required*, but you can edit most parts of required functions, remove non-required functions, and add your own functions.

from helper_code import *
import os
import numpy as np, sys,os
import pandas as pd
from scipy.io import loadmat
#import wdfb
import pickle
import json
import random

from scipy import optimize
from scipy.signal import resample, decimate, resample_poly

from sklearn.preprocessing import MultiLabelBinarizer
import math
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.utils.class_weight import compute_class_weight

import tensorflow as tf
from tensorflow import keras

from keras import layers, Sequential
from keras.layers import Input, Add, Dense, Activation, ZeroPadding1D, BatchNormalization, Flatten, Conv1D, AveragePooling1D, MaxPooling1D, GlobalMaxPooling1D, Dropout, GlobalAveragePooling1D, Multiply
from keras.models import Model, load_model #save
from keras import preprocessing
from keras.preprocessing import sequence
from keras.preprocessing.sequence import pad_sequences
from keras.initializers import glorot_uniform

# Fix the global NumPy and TensorFlow random seeds so repeated runs draw the
# same random numbers (np.random.shuffle below relies on the NumPy one).
SEED = 1234
np.random.seed(SEED)
tf.random.set_seed(SEED)


def load_data(filename, channels):
    """Load an ECG recording from a ``.mat`` file together with its header.

    Args:
        filename: Path to a ``.mat`` file holding the signal matrix under the
            key ``'val'``. Rows are selected as leads (see ``lead_slices``).
        channels: Number of leads to keep: 12, 6, 4, 3 or 2.

    Returns:
        (data, header): ``data`` is a float32 array with one row per selected
        lead; ``header`` is the list of lines of the ``.hea`` file that sits
        next to ``filename``.

    Raises:
        ValueError: If ``channels`` is not one of the supported lead counts
            (previously this surfaced as an UnboundLocalError).
    """
    # Row slices of 'val' to concatenate for every supported lead layout.
    lead_slices = {
        12: (slice(None),),
        6: (slice(0, 6),),
        4: (slice(0, 3), slice(7, 8)),
        3: (slice(0, 2), slice(7, 8)),
        2: (slice(0, 2),),
    }
    if channels not in lead_slices:
        raise ValueError(
            "Unsupported number of channels: {} (expected one of {})".format(
                channels, sorted(lead_slices)))

    signal = loadmat(filename)['val']
    data = np.concatenate([signal[rows] for rows in lead_slices[channels]]).astype(np.float32)

    # Swap only the extension; str.replace would also rewrite any '.mat'
    # occurring in a directory name.
    header_path = os.path.splitext(filename)[0] + '.hea'
    with open(header_path, 'r') as f:
        header = f.readlines()
    return data, header
    
def resample_ecg(signal, input_freq, output_freq):
    """Resample ``signal`` along its last axis from ``input_freq`` to ``output_freq``.

    Integer-ratio downsampling uses ``scipy.signal.decimate`` (zero-phase IIR);
    every other ratio uses polyphase ``resample_poly``. Equal frequencies
    return the (float-converted) input unchanged.

    Args:
        signal: Array-like; resampled along ``axis=-1``.
        input_freq: Original sampling frequency. Must be a positive integral
            value (``500`` or ``500.0``).
        output_freq: Target sampling frequency, same constraints.

    Returns:
        float64 numpy array.

    Raises:
        ValueError: If a frequency is not integral or not positive.
    """
    signal = np.atleast_1d(signal).astype(float)
    if input_freq != int(input_freq):
        raise ValueError("input_freq must be an integer")
    if output_freq != int(output_freq):
        raise ValueError("output_freq must be an integer")
    # Cast after validation: decimate() rejects a float ``q`` such as 2.0.
    input_freq = int(input_freq)
    output_freq = int(output_freq)
    if input_freq <= 0 or output_freq <= 0:
        raise ValueError("input_freq and output_freq must be positive")

    if input_freq == output_freq:
        return signal
    if input_freq % output_freq == 0:
        return decimate(signal, q=input_freq // output_freq,
                        ftype='iir', zero_phase=True, axis=-1)
    return resample_poly(signal, up=output_freq, down=input_freq, axis=-1)

def remove_unscored_classes(labels, df_unscored):
    """Normalise comma-separated diagnosis-code strings.

    Every code listed in column 1 of ``df_unscored`` becomes
    ``"undefined_class"``, and the second code of each hard-coded equivalent
    pair is rewritten to the first (e.g. 59118001 -> 713427006).

    Matching is done per comma-separated token, so a code is never rewritten
    just because it *contains* another code as a substring (the previous
    regex-based ``DataFrame.replace`` did that). Surrounding whitespace of
    each token (including a trailing ``"\\r"``) is stripped.

    Args:
        labels: Sequence of label strings (anything ``pd.DataFrame`` accepts).
        df_unscored: Table whose column at position 1 holds the unscored codes.

    Returns:
        DataFrame built from ``labels`` with every string cell normalised;
        non-string cells are left untouched.
    """
    equivalent = {
        "59118001": "713427006",
        "63593006": "284470004",
        "17338001": "427172004",
        "164909002": "733534002",
    }
    unscored = {str(code).strip() for code in df_unscored.iloc[:, 1]}

    def _normalise(value):
        # Rewrite one "code,code,..." string token by token.
        if not isinstance(value, str):
            return value
        codes = []
        for code in value.split(","):
            code = code.strip()
            if code in unscored:
                code = "undefined_class"
            else:
                code = equivalent.get(code, code)
            codes.append(code)
        return ",".join(codes)

    labels_corrected = pd.DataFrame(labels)
    return labels_corrected.apply(lambda column: column.map(_normalise))

def load_table(table_file):
    """Parse a CSV table with a header row of column labels and a label column.

    The table should have the following form::

        ,    a,   b,   c
        a, 1.2, 2.3, 3.4
        b, 4.5, 5.6, 6.7
        c, 7.8, 8.9, 9.0

    Blank lines are ignored. Cells that cannot be parsed by ``float()`` become
    ``nan``.

    Returns:
        (rows, cols, values): ``rows`` are the labels in the first column,
        ``cols`` the labels in the header row, and ``values`` a float64 array
        of shape ``(len(rows), len(cols))``.

    Raises:
        Exception: If the table is empty or its lines have different lengths.
    """
    table = []
    with open(table_file, 'r') as f:
        for line in f:
            if not line.strip():
                continue
            table.append([cell.strip() for cell in line.split(',')])

    # Define the numbers of rows and columns and check for errors.
    num_rows = len(table) - 1
    if num_rows < 1:
        raise Exception('The table {} is empty.'.format(table_file))

    # Every line, the header included, must have the same number of cells.
    row_lengths = {len(line_cells) - 1 for line_cells in table}
    if len(row_lengths) != 1:
        raise Exception('The table {} has rows with different lengths.'.format(table_file))
    num_cols = row_lengths.pop()
    if num_cols < 1:
        raise Exception('The table {} is empty.'.format(table_file))

    # Row labels live in the first column, column labels in the header row.
    rows = [table[i + 1][0] for i in range(num_rows)]
    cols = [table[0][j + 1] for j in range(num_cols)]

    values = np.full((num_rows, num_cols), np.nan, dtype=np.float64)
    for i in range(num_rows):
        for j in range(num_cols):
            try:
                values[i, j] = float(table[i + 1][j + 1])
            except ValueError:
                pass  # non-numeric cell stays nan

    return rows, cols, values

def load_weights(weight_file, equivalent_classes):
    """Load the scoring weight matrix and derive the ordered class list from it.

    Args:
        weight_file: CSV file readable by ``load_table``. Labels may list
            equivalent codes joined by ``'|'``.
        equivalent_classes: Groups of interchangeable codes, forwarded to
            ``replace_equivalent_classes``.

    Returns:
        (classes, weights): ``classes`` holds one code per distinct row label
        (the first ``'|'``-separated code), in table order; ``weights`` is the
        numeric matrix from ``load_table``.

    Raises:
        ValueError: If the row labels and column labels describe different
            sets of classes.
    """
    rows, cols, values = load_table(weight_file)

    # Pass a copy: replace_equivalent_classes may rewrite its argument in
    # place, and ``rows`` must stay intact for the consistency check below.
    rows_classes = replace_equivalent_classes(list(rows), equivalent_classes)
    # Order-preserving de-duplication.
    classes_set = list(dict.fromkeys(rows_classes))
    classes = [single_class.split('|')[0] for single_class in classes_set]

    # Split the equivalent classes and make sure rows and columns agree.
    row_sets = [set(row.split('|')) for row in rows]
    col_sets = [set(col.split('|')) for col in cols]
    if row_sets != col_sets:
        raise ValueError(
            'The weight table {} has different row and column labels.'.format(weight_file))

    return classes, values

def is_number(x):
    """Return True if ``float(x)`` succeeds, False otherwise.

    Both ValueError (e.g. ``"abc"``) and TypeError (e.g. ``None``) count as
    "not a number" instead of propagating.
    """
    try:
        float(x)
    except (TypeError, ValueError):
        return False
    return True

def replace_equivalent_classes(classes, equivalent_classes, representative_index=1):
    """Map each class that belongs to an equivalence group onto one representative.

    Args:
        classes: Sequence of class codes.
        equivalent_classes: Groups of interchangeable codes, e.g.
            ``[['713427006', '59118001'], ...]``.
        representative_index: Position inside each group of the code to keep.
            Defaults to 1 (the second code), this function's historical
            behaviour. NOTE(review): other code in this file collapses
            equivalent pairs onto the *first* code (59118001 -> 713427006);
            confirm which representative is intended.

    Returns:
        A new list. ``classes`` itself is no longer modified in place.
        If a code appears in several groups, the last group wins.
    """
    group_of = {}
    for group in equivalent_classes:
        for member in group:
            group_of[member] = group
    return [group_of[code][representative_index] if code in group_of else code
            for code in classes]
      
def compute_challenge_metric(weights, labels, outputs, classes, sinus_rhythm):
    """Compute the normalised weighted-confusion score of ``outputs``.

    score = (observed - inactive) / (correct - inactive), where each term is
    ``nansum(weights * compute_modified_confusion_matrix(labels, X))`` for
    X = ``outputs``, X = ``labels`` (perfect predictions) and X = "only the
    sinus rhythm class is positive", respectively.

    Args:
        weights: (num_classes, num_classes) weight matrix.
        labels: (num_recordings, num_classes) ground-truth indicators.
        outputs: Same shape as ``labels``; binary predictions.
        classes: Class names/codes in column order (list or array).
        sinus_rhythm: Entry of ``classes`` used for the baseline model.

    Returns:
        The normalised score as a float; 0.0 when the perfect and baseline
        scores coincide.

    Raises:
        ValueError: If ``sinus_rhythm`` is not in ``classes``.
    """
    num_recordings, num_classes = np.shape(labels)
    classes = list(classes)  # allows numpy arrays, which have no .index()
    if sinus_rhythm not in classes:
        raise ValueError('The sinus rhythm class is not available.')
    sinus_rhythm_index = classes.index(sinus_rhythm)

    def _score(predicted):
        return np.nansum(weights * compute_modified_confusion_matrix(labels, predicted))

    observed_score = _score(outputs)
    correct_score = _score(labels)

    # Baseline: predict only the sinus rhythm class for every recording.
    # (builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24)
    inactive_outputs = np.zeros((num_recordings, num_classes), dtype=bool)
    inactive_outputs[:, sinus_rhythm_index] = True
    inactive_score = _score(inactive_outputs)

    if correct_score == inactive_score:
        return 0.0
    return float(observed_score - inactive_score) / float(correct_score - inactive_score)

    
def compute_modified_confusion_matrix(labels, outputs):
    """Build the multi-label, partial-credit confusion matrix.

    Rows index true labels and columns index outputs. For every recording,
    each (positive label j, positive output k) pair adds
    ``1 / max(#classes positive in labels OR outputs, 1)`` to ``A[j, k]``.
    Entries are tested for truthiness, so 0/1, bool or probabilities work.

    This is the vectorised form of the per-recording double loop; the result
    is identical up to floating-point summation order.

    Args:
        labels: (num_recordings, num_classes) array-like.
        outputs: Array-like with the same shape as ``labels``.

    Returns:
        float64 array of shape (num_classes, num_classes).
    """
    labels = np.asarray(labels).astype(bool)
    outputs = np.asarray(outputs).astype(bool)
    num_recordings, num_classes = labels.shape

    # Per-recording normaliser; at least 1 to avoid division by zero.
    normalization = np.maximum(np.sum(labels | outputs, axis=1), 1).astype(float)
    # A[j, k] = sum_i labels[i, j] * outputs[i, k] / normalization[i]
    weighted_labels = labels / normalization[:, None]
    return weighted_labels.T @ outputs.astype(float)
    
def th_optimization(thr, weights, label, output_prob, classes_list):
    """Objective for threshold search: the negated challenge score at ``thr``.

    ``output_prob > thr`` (broadcast element-wise) gives the binary
    predictions; the sinus rhythm class is looked up under the name 'NSR'.
    Returning the negative lets a minimiser maximise the score.
    """
    predicted = np.array(output_prob > thr)
    score = compute_challenge_metric(weights, label, predicted, classes_list, 'NSR')
    return -score

def initialize_th(weights, label, y_pred, classes_list):
    """Score a single global threshold taken from the grid 0.00, 0.01, ..., 0.99.

    For each grid value the same threshold is applied to every column of
    ``y_pred``. The challenge score (sinus rhythm class 'NSR') is printed
    and collected.

    Returns:
        numpy array with one score per grid threshold, in grid order.
    """
    collected = []
    for threshold in np.arange(0, 1, 0.01):
        binary_prediction = (y_pred > threshold) * 1
        score = compute_challenge_metric(weights, label, binary_prediction, classes_list, 'NSR')
        print(score)
        collected.append(score)
    return np.asarray(collected)

def label_encode(correct_labeles, classes_input):
    """Multi-hot encode comma-separated label strings and drop the last class.

    Args:
        correct_labeles: DataFrame whose column ``0`` holds comma-separated
            class codes (e.g. the output of ``remove_unscored_classes``).
        classes_input: Class vocabulary in column order. Its LAST entry is
            treated as a catch-all (e.g. ``'undefined_class'``) and removed
            from the result.

    Returns:
        (y, classes): ``y`` is a 0/1 array with ``len(classes_input) - 1``
        columns; ``classes`` are the kept class names in column order.
    """
    # ``classes`` must be passed by keyword: scikit-learn >= 1.0 makes
    # estimator constructor parameters keyword-only.
    multi_label = MultiLabelBinarizer(classes=classes_input)
    y = multi_label.fit_transform(correct_labeles[0].str.split(pat=','))
    print("Classes evaluated:")
    print(multi_label.classes_)
    # Drop the final (catch-all) column.
    y = np.delete(y, -1, axis=1)
    print("classes: {}".format(y.shape[1]))
    return y, multi_label.classes_[0:-1]
    
def labels_of_combinations(y):
    """Give every distinct row (label combination) of ``y`` its own integer id.

    Rows are keyed by their ``str()`` text and encoded with ``LabelEncoder``,
    so ids follow the sorted order of those strings.
    """
    combination_keys = [str(row) for row in y]
    encoder = LabelEncoder()
    return encoder.fit_transform(combination_keys)
  
def split_data(labels, y_comb, n_splits=6, random_state=42):
    """Create shuffled stratified K-fold splits.

    Args:
        labels: Samples passed to ``StratifiedKFold.split`` as ``X``; only
            their count matters.
        y_comb: Integer stratification target (one id per label combination).
        n_splits: Number of folds (default 6, the previous hard-coded value).
        random_state: Shuffle seed (default 42, the previous hard-coded value).

    Returns:
        List of ``(train_indices, validation_indices)`` tuples. The sizes of
        the first fold are printed.
    """
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    folds = list(splitter.split(labels, y_comb))
    print("Training split: {}".format(len(folds[0][0])))
    print("Validation split: {}".format(len(folds[0][1])))
    return folds

def generate_validation_data_only(ecg_filenames, y, test_order_array, channels):
    """Load, resample and pad the recordings selected by ``test_order_array``.

    Args:
        ecg_filenames: numpy array of ``.mat`` paths.
        y: Label matrix aligned with ``ecg_filenames``.
        test_order_array: Indices of the validation recordings.
        channels: Number of leads passed to ``load_data``.

    Returns:
        (X_val, y_val): ``X_val`` has shape ``(len(test_order_array), 4096,
        channels)``; ``y_val`` are the matching rows of ``y``.
    """
    y_val = y[test_order_array]
    ecg_filenames_val = ecg_filenames[test_order_array]

    validation_set = []
    for names in ecg_filenames_val:
        data, header_data = load_data(names, channels)
        # Sampling frequency = 3rd field of the first header line. split()
        # (not split(' ')) tolerates tabs and repeated spaces.
        freq = int(header_data[0].split()[2])
        data = resample_ecg(data, freq, 257)
        # Each lead padded/truncated to 4096 samples -> (channels, 4096).
        X = pad_sequences(data, maxlen=4096, dtype='float32', truncating='post', padding="post")
        validation_set.append(X)
    X_val = np.asarray(validation_set)
    # NOTE(review): this is a reshape, not a transpose, of the
    # (n, channels, 4096) stack, so samples of different leads get
    # interleaved. generate_X_shuffle applies the same reshape, so train and
    # validation agree; X_val.transpose(0, 2, 1) is probably what was meant,
    # but it must be changed together with training and inference code.
    X_val = X_val.reshape(ecg_filenames_val.shape[0], 4096, channels)
    return X_val, y_val

def shuffle_batch_generator_data(channels, classes, order_train, batch_size, x_gen, y_gen):
    """Endlessly yield ``([features], labels)`` batches drawn from two sample generators.

    When iteration starts, ``order_train`` is shuffled in place once. In
    ``training_code`` both ``x_gen`` and ``y_gen`` iterate that same array, so
    features and labels stay aligned.

    Each batch uses freshly allocated arrays of shape
    ``(batch_size, 4096, channels)`` and ``(batch_size, len(classes))``.
    Previously one buffer was reused, so every earlier batch a consumer still
    held was silently overwritten by the next one.
    """
    np.random.shuffle(order_train)
    num_classes = classes.shape[0]

    while True:
        batch_features = np.zeros((batch_size, 4096, channels))
        batch_labels = np.zeros((batch_size, num_classes))
        for i in range(batch_size):
            batch_features[i] = next(x_gen)
            batch_labels[i] = next(y_gen)
        yield [batch_features], batch_labels
        

def generate_y_shuffle(order_train, y_train):
    """Endlessly yield ``y_train[index]`` for each ``index`` in ``order_train``.

    ``order_train`` is iterated afresh on every pass over the data.
    """
    while True:
        yield from (y_train[index] for index in order_train)


def generate_X_shuffle(order_train, X_train, channels):
    """Endlessly yield one preprocessed recording per index in ``order_train``.

    Each ``X_train[i]`` path is loaded with ``load_data``, resampled to 257 Hz,
    padded/truncated to 4096 samples per lead and reshaped to
    ``(4096, channels)``.
    """
    while True:
        for i in order_train:
            data, header_data = load_data(X_train[i], channels)
            # Sampling frequency = 3rd field of the first header line; split()
            # tolerates tabs and repeated spaces.
            freq = int(header_data[0].split()[2])
            data = resample_ecg(data, freq, 257)
            X_train_new = pad_sequences(data, maxlen=4096, dtype='float32', truncating='post', padding="post")
            # NOTE(review): reshape (not transpose) of a (channels, 4096) array,
            # which interleaves leads. It is kept because validation and
            # inference must apply the identical transform; switch all of them
            # to a transpose together if this is unintended.
            X_train_new = X_train_new.reshape(4096, channels)
            yield X_train_new
                    
def step_decay(epoch):
    """Step learning-rate schedule: 3e-3, divided by 10 every 10 epochs.

    The drop happens once ``epoch + 1`` reaches a multiple of 10, i.e. from
    0-based epochs 9, 19, 29, ...
    """
    base_rate, factor, period = 0.003, 0.1, 10.0
    completed_steps = math.floor((epoch + 1) / period)
    return base_rate * math.pow(factor, completed_steps)





def se_block(in_block, ch, ratio=16):
    """Squeeze-and-excitation block for 1-D feature maps.

    Global-average-pools ``in_block`` over time, passes the result through a
    ReLU bottleneck of ``ch // ratio`` units (at least 1) and a ``ch``-unit
    sigmoid layer, then rescales ``in_block`` by those per-channel gates.

    Args:
        in_block: Keras tensor; its last dimension is expected to be ``ch``.
        ch: Number of channels of ``in_block``.
        ratio: Bottleneck reduction factor.

    Returns:
        Keras tensor with the same shape as ``in_block``.
    """
    x = GlobalAveragePooling1D()(in_block)
    # Guard against a zero-unit Dense layer when ch < ratio.
    x = Dense(max(1, ch // ratio), activation='relu')(x)
    x = Dense(ch, activation='sigmoid')(x)
    # Use the Multiply layer imported at the top of the file, like every other
    # layer here, instead of reaching through tf.keras.
    return Multiply()([in_block, x])

def identity_block(X, kernel, filters, stage, block):
    """Residual block whose shortcut is the unmodified input.

    Main path: Conv1D(F1, 'same') -> BatchNorm -> ReLU -> Dropout(0.2) ->
    Conv1D(F2, 'same') -> BatchNorm -> ``se_block``; the result is added to
    ``X`` and passed through a final ReLU.

    Args:
        X: Input Keras tensor. It is added to the main path unchanged, so its
            channel count presumably has to equal F2 (both convolutions use
            stride 1 and 'same' padding, so the length is preserved).
        kernel: Kernel size of both convolutions.
        filters: ``(F1, F2)`` filter counts.
        stage, block: Only used to build layer names such as
            ``res{stage}{block}_branch2a``.

    Returns:
        Output Keras tensor.
    """

    # defining name basis
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    F1, F2 = filters

    # Keep the input for the residual addition.
    X_shortcut = X

    X = Conv1D(filters=F1, kernel_size=kernel, strides=1, padding='same', name=conv_name_base + '2a', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(name=bn_name_base + '2a')(X) #axis=3
    X = Activation('relu')(X)
    
    X =  Dropout (0.2) (X)

    X = Conv1D(filters=F2, kernel_size=kernel, strides=1, padding='same', name=conv_name_base + '2c', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(name=bn_name_base + '2c')(X)
    # Channel-wise re-weighting before the residual sum.
    X = se_block(X, F2, ratio=16)

    X = Add()([X, X_shortcut])
    X = Activation('relu')(X)


    return X


def convolutional_block(X, kernel, filters, stage, block, s=2):
    """Residual block with a learned (projection) shortcut that can downsample.

    Main path: Conv1D(F1, stride s, 'valid') -> BatchNorm -> ReLU ->
    Dropout(0.2) -> Conv1D(F2, stride 1, 'same') -> BatchNorm -> ``se_block``.
    Shortcut: Conv1D(F2, same kernel, stride s, 'valid') -> BatchNorm. The
    two are added and passed through a final ReLU. Because both strided
    convolutions share kernel size, stride and padding, their output lengths
    match.

    Args:
        X: Input Keras tensor.
        kernel: Kernel size of all three convolutions.
        filters: ``(F1, F2)`` filter counts; the output has F2 channels.
        stage, block: Only used to build layer names such as
            ``res{stage}{block}_branch1``.
        s: Stride of the first main-path convolution and of the shortcut.

    Returns:
        Output Keras tensor.
    """

    # defining name basis
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    F1, F2 = filters

    X_shortcut = X

    X = Conv1D(filters=F1, kernel_size=kernel, strides=s, padding='valid', name=conv_name_base + '2a', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(name=bn_name_base + '2a')(X) #axis=3
    X = Activation('relu')(X)
    
    X =  Dropout (0.2) (X)

    X = Conv1D(filters=F2, kernel_size=kernel, strides=1, padding='same', name=conv_name_base + '2c', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(name=bn_name_base + '2c')(X)
    X = se_block(X, F2, ratio=16)

    # Projection shortcut: match the main path's channel count and length.
    X_shortcut = Conv1D(filters=F2, kernel_size=kernel, strides=s, padding='valid', name=conv_name_base + '1', kernel_initializer=glorot_uniform(seed=0))(X_shortcut)
    X_shortcut = BatchNormalization(name=bn_name_base + '1')(X_shortcut)

    X = Add()([X, X_shortcut])
    X = Activation('relu')(X)

    return X


def ResNetmod(channels, classes, input_length=4096):
    """Build and compile the 1-D SE-ResNet used for multi-label classification.

    Architecture: Conv1D(64, 15) stem + max-pool, then four stages of residual
    blocks (64, 128, 256, 512 filters; stages 3-5 downsample by 2), global max
    pooling and a ``classes``-unit sigmoid output.

    Args:
        channels: Number of input leads.
        classes: Number of output units (one sigmoid per class).
        input_length: Samples per lead. The default, 4096, is what the data
            generators in this file produce.

    Returns:
        A Keras ``Model`` compiled with Adam and binary cross-entropy.
    """
    X1_input = Input(shape=(input_length, channels))

    X = Conv1D(64, 15, strides=1, name='conv1', kernel_initializer=glorot_uniform(seed=0))(X1_input)
    X = BatchNormalization(name='bn_conv1')(X)
    X = Activation('relu')(X)
    X = MaxPooling1D(2, strides=2)(X)

    X = identity_block(X, 7, [64, 64], stage=2, block='a')
    X = identity_block(X, 7, [64, 64], stage=2, block='b')

    X = convolutional_block(X, 7, filters=[128, 128], stage=3, block='a', s=2)
    X = identity_block(X, 7, [128, 128], stage=3, block='b')

    X = convolutional_block(X, 7, filters=[256, 256], stage=4, block='a', s=2)
    X = identity_block(X, 7, [256, 256], stage=4, block='b')

    X = convolutional_block(X, 7, filters=[512, 512], stage=5, block='a', s=2)
    X = identity_block(X, 7, [512, 512], stage=5, block='b')

    # GlobalMaxPooling1D is the layer tf.keras also exposes as GlobalMaxPool1D;
    # take it from the top-of-file import like every other layer here.
    X = GlobalMaxPooling1D()(X)
    # The pooled tensor is already 2-D; Flatten is a no-op kept so the layer
    # list of the model is unchanged.
    X = Flatten()(X)

    final_layer = Dense(classes, activation="sigmoid")(X)
    model = Model(inputs=[X1_input], outputs=final_layer)

    opt = keras.optimizers.Adam()
    model.compile(optimizer=opt, loss='binary_crossentropy', metrics=['accuracy', 'Precision', 'Recall'])
    return model



           
        
        

################################################################################
#
# Training function
#
################################################################################

# Train your model. This function is *required*. Do *not* change the arguments of this function.
def training_code(data_directory, model_directory):
    # Find header and recording files.
    print('Finding header and recording files...')

    # Create a folder for the model if it does not already exist.
    if not os.path.isdir(model_directory):
        os.mkdir(model_directory)
        
    df_scored = pd.read_csv('dx_mapping_scored.csv')
    df_unscored = pd.read_csv('dx_mapping_unscored.csv')

    channels = 12

    path = data_directory
    labels = []
    filenames = []
    samplenames = []
    sample_path = os.listdir(path)
    for sample in sample_path:
    	if sample.endswith('.mat'):
                signal_name = os.path.join(path, sample)
                data, header_data = load_data(signal_name, channels)
                labels.append(header_data[15][5:-1])
                filenames.append(signal_name)
   
    filenames_array = np.array (filenames)

    

    labels_corrected = remove_unscored_classes(labels,df_unscored)

    weights_file = 'weights.csv'
    normal_class = '426783006'
    equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001'], ['733534002', '164909002']]
    classes, weights = load_weights(weights_file, equivalent_classes)
       
    classes_input = np.append(classes, 'undefined_class')

    y_tot , classes = label_encode(labels_corrected, classes_input)

    indexes = []
    for i in range (y_tot.shape[0]):
      if np.all(y_tot[i] == 0):
        indexes.append(i)
    filenames_array = np.delete(filenames_array, indexes, 0)
    y_tot = np.delete(y_tot, indexes, 0)

    y_comb = labels_of_combinations(y_tot)
    
    sinus_index = list(np.where((y_tot[:, 14]) & (y_tot[:, 22]==0))[0])
    sb_index = list(np.where((y_tot[:, 22]))[0])
    sinus_index = random.Random(42).sample(sinus_index, len(sinus_index))
    sb_index = random.Random(42).sample(sb_index, len(sb_index))                                 
    other_index = list(np.where((y_tot[:, 14]==0) & (y_tot[:, 22]==0))[0]) 
    first_sinus_segment = sinus_index[0:9538] + sb_index[0:6306]
    second_sinus_segment = sinus_index[9538:19075] + sb_index[6306:12612]
    third_sinus_segment = sinus_index[19075:] + sb_index[12612:]
    first_fold = []
    second_fold = []
    third_fold = []
    first_fold.extend(other_index)
    second_fold.extend(other_index)
    third_fold.extend(other_index)
    first_fold.extend(first_sinus_segment)
    second_fold.extend(second_sinus_segment)
    third_fold.extend(third_sinus_segment)
    y1 = y_tot [first_fold] 
    y2 = y_tot [second_fold] 
    y3 = y_tot [third_fold] 
    
    y1_comb = labels_of_combinations(y1)
    y2_comb = labels_of_combinations(y2)
    y3_comb = labels_of_combinations(y3)
    
    folds1 = split_data(y1, y1_comb)
    folds2 = split_data(y2, y2_comb)
    folds3 = split_data(y3, y3_comb)

    first_fold_array = np.array (first_fold)
    second_fold_array = np.array (second_fold)
    third_fold_array = np.array (third_fold)

    classes_name = np.zeros((classes.shape), dtype='object')
    for j in range(len(classes)):
        for i in range(len(df_scored.iloc[:,1])):
                if (str(df_scored.iloc[:,1][i]) == classes[j]):
                        classes_name[j] = df_scored.iloc[:,2][i]
    
    cb = [
        keras.callbacks.LearningRateScheduler(step_decay, verbose=1)
    ]
    
    prediction_array = []
    th_array = []
    score_array = []
    classes_list = classes_name.tolist()                                 
    y = y_tot
    i = 1
    batchsize = 64
    epochs = 30

    print (classes_list)
    

    
    # Train 12-lead ECG model.
    print('Training 12-lead ECG model...')

    channels = 12
    twelve_lead_model_filename = 'twelve_lead_model_filename'
    folder_name = model_directory + '/' + twelve_lead_model_filename
    os.makedirs(folder_name + '/final_models/')
    os.makedirs(folder_name + '/final_thrs/') 
    
    order_train = first_fold_array[folds1[i][0]]
    order_valid = first_fold_array[folds1[i][1]]
    
    model1 = ResNetmod(channels, classes=26)
    history = model1.fit(x=shuffle_batch_generator_data(batch_size=batchsize, channels=channels, classes=classes, order_train=order_train, x_gen=generate_X_shuffle(order_train=order_train, X_train=filenames_array, channels=channels), y_gen=generate_y_shuffle(order_train=order_train, y_train=y)), epochs=epochs, callbacks=cb, steps_per_epoch=(len(order_train)/batchsize), validation_data=generate_validation_data_only(filenames_array,y, order_valid, channels), validation_freq=1)
    filename = folder_name + '/final_models/model_fold_1' + '.h5'                                   
    model1.save(filename)
    print('>Saved %s' % filename)
                                     
    y_pred = model1.predict(x=generate_validation_data_only(filenames_array,y, order_valid, channels)[0])
    labels = generate_validation_data_only(filenames_array, y, order_valid, channels)[1]
    th_init = np.arange(0,1,0.01)
    scores = initialize_th(weights, labels, y_pred, classes_list)
    new_thr_1 = optimize.fmin(th_optimization, args=(weights, labels, y_pred, classes_list), x0=th_init[scores.argmax()]*np.ones(26))
    filenames = folder_name + '/final_thrs/thr_1' + '.pickle'                                 
    with open(filenames, 'wb') as filehandle:
        pickle.dump(new_thr_1, filehandle)
    
                                     
    order_train = second_fold_array[folds2[i][0]]
    order_valid = second_fold_array[folds2[i][1]]
                                     
    model2 = ResNetmod(channels, classes=26)
    history = model2.fit(x=shuffle_batch_generator_data(batch_size=batchsize, channels=channels, classes=classes, order_train=order_train, x_gen=generate_X_shuffle(order_train=order_train, X_train=filenames_array, channels=channels), y_gen=generate_y_shuffle(order_train=order_train, y_train=y)), epochs=epochs, callbacks=cb, steps_per_epoch=(len(order_train)/batchsize), validation_data=generate_validation_data_only(filenames_array,y, order_valid, channels), validation_freq=1)                               
    filename = folder_name + '/final_models/model_fold_2' + '.h5'
    model2.save(filename)
    print('>Saved %s' % filename)

    y_pred = model2.predict(x=generate_validation_data_only(filenames_array,y, order_valid, channels)[0])
    labels = generate_validation_data_only(filenames_array, y, order_valid, channels)[1]
    th_init = np.arange(0,1,0.01)
    scores = initialize_th(weights, labels, y_pred, classes_list)
    new_thr_2 = optimize.fmin(th_optimization, args=(weights, labels, y_pred, classes_list), x0=th_init[scores.argmax()]*np.ones(26))                                 
    filenames = folder_name + '/final_thrs/thr_2' + '.pickle'
    with open(filenames, 'wb') as filehandle:
        pickle.dump(new_thr_2, filehandle)
            
        
    order_train = third_fold_array[folds3[i][0]]
    order_valid = third_fold_array[folds3[i][1]]
                                     
    model3 = ResNetmod(channels, classes=26)
    history = model3.fit(x=shuffle_batch_generator_data(batch_size=batchsize, channels=channels, classes=classes, order_train=order_train, x_gen=generate_X_shuffle(order_train=order_train, X_train=filenames_array, channels=channels), y_gen=generate_y_shuffle(order_train=order_train, y_train=y)), epochs=epochs, callbacks=cb, steps_per_epoch=(len(order_train)/batchsize), validation_data=generate_validation_data_only(filenames_array,y, order_valid, channels), validation_freq=1)                               
    filename = folder_name + '/final_models/model_fold_3' + '.h5'
    model3.save(filename)
    print('>Saved %s' % filename)

    y_pred = model3.predict(x=generate_validation_data_only(filenames_array,y, order_valid, channels)[0])
    labels = generate_validation_data_only(filenames_array, y, order_valid, channels)[1]
    th_init = np.arange(0,1,0.01)
    scores = initialize_th(weights, labels, y_pred, classes_list)
    new_thr_3 = optimize.fmin(th_optimization, args=(weights, labels, y_pred, classes_list), x0=th_init[scores.argmax()]*np.ones(26))                                     
    filenames = folder_name + '/final_thrs/thr_3' + '.pickle'
    with open(filenames, 'wb') as filehandle:
        pickle.dump(new_thr_3, filehandle)
        
    # Train 6-lead ECG model.
    print('Training 6-lead ECG model...')

    channels = 6
    six_lead_model_filename = 'six_lead_model_filename'
    folder_name = model_directory + '/' + six_lead_model_filename
    os.makedirs(folder_name + '/final_models/') 
    os.makedirs(folder_name + '/final_thrs/') 

    order_train = first_fold_array[folds1[i][0]]
    order_valid = first_fold_array[folds1[i][1]]
    
    model1 = ResNetmod(channels, classes=26)
    history = model1.fit(x=shuffle_batch_generator_data(batch_size=batchsize, channels=channels, classes=classes, order_train=order_train, x_gen=generate_X_shuffle(order_train=order_train, X_train=filenames_array, channels=channels), y_gen=generate_y_shuffle(order_train=order_train, y_train=y)), epochs=epochs, callbacks=cb, steps_per_epoch=(len(order_train)/batchsize), validation_data=generate_validation_data_only(filenames_array,y, order_valid, channels), validation_freq=1)
    filename = folder_name + '/final_models/model_fold_1' + '.h5'                                   
    model1.save(filename)
    print('>Saved %s' % filename)
                                     
    y_pred = model1.predict(x=generate_validation_data_only(filenames_array,y, order_valid, channels)[0])
    labels = generate_validation_data_only(filenames_array, y, order_valid, channels)[1]
    th_init = np.arange(0,1,0.01)
    scores = initialize_th(weights, labels, y_pred, classes_list)
    new_thr_1 = optimize.fmin(th_optimization, args=(weights, labels, y_pred, classes_list), x0=th_init[scores.argmax()]*np.ones(26))
    filenames = folder_name + '/final_thrs/thr_1' + '.pickle'                                 
    with open(filenames, 'wb') as filehandle:
        pickle.dump(new_thr_1, filehandle)
    
                                     
    order_train = second_fold_array[folds2[i][0]]
    order_valid = second_fold_array[folds2[i][1]]
                                     
    model2 = ResNetmod(channels, classes=26)
    history = model2.fit(x=shuffle_batch_generator_data(batch_size=batchsize, channels=channels, classes=classes, order_train=order_train, x_gen=generate_X_shuffle(order_train=order_train, X_train=filenames_array, channels=channels), y_gen=generate_y_shuffle(order_train=order_train, y_train=y)), epochs=epochs, callbacks=cb, steps_per_epoch=(len(order_train)/batchsize), validation_data=generate_validation_data_only(filenames_array,y, order_valid, channels), validation_freq=1)                               
    filename = folder_name + '/final_models/model_fold_2' + '.h5'
    model2.save(filename)
    print('>Saved %s' % filename)

    y_pred = model2.predict(x=generate_validation_data_only(filenames_array,y, order_valid, channels)[0])
    labels = generate_validation_data_only(filenames_array, y, order_valid, channels)[1]
    th_init = np.arange(0,1,0.01)
    scores = initialize_th(weights, labels, y_pred, classes_list)
    new_thr_2 = optimize.fmin(th_optimization, args=(weights, labels, y_pred, classes_list), x0=th_init[scores.argmax()]*np.ones(26))                                 
    filenames = folder_name + '/final_thrs/thr_2' + '.pickle'
    with open(filenames, 'wb') as filehandle:
        pickle.dump(new_thr_2, filehandle)
            
        
    order_train = third_fold_array[folds3[i][0]]
    order_valid = third_fold_array[folds3[i][1]]
                                     
    model3 = ResNetmod(channels, classes=26)
    history = model3.fit(x=shuffle_batch_generator_data(batch_size=batchsize, channels=channels, classes=classes, order_train=order_train, x_gen=generate_X_shuffle(order_train=order_train, X_train=filenames_array, channels=channels), y_gen=generate_y_shuffle(order_train=order_train, y_train=y)), epochs=epochs, callbacks=cb, steps_per_epoch=(len(order_train)/batchsize), validation_data=generate_validation_data_only(filenames_array,y, order_valid, channels), validation_freq=1)                               
    filename = folder_name + '/final_models/model_fold_3' + '.h5'
    model3.save(filename)
    print('>Saved %s' % filename)

    y_pred = model3.predict(x=generate_validation_data_only(filenames_array,y, order_valid, channels)[0])
    labels = generate_validation_data_only(filenames_array, y, order_valid, channels)[1]
    th_init = np.arange(0,1,0.01)
    scores = initialize_th(weights, labels, y_pred, classes_list)
    new_thr_3 = optimize.fmin(th_optimization, args=(weights, labels, y_pred, classes_list), x0=th_init[scores.argmax()]*np.ones(26))                                     
    filenames = folder_name + '/final_thrs/thr_3' + '.pickle'
    with open(filenames, 'wb') as filehandle:
        pickle.dump(new_thr_3, filehandle)
 


    # Train 4-lead ECG model.
    print('Training 4-lead ECG model...')

    channels = 4
    four_lead_model_filename = 'four_lead_model_filename'
    folder_name = model_directory + '/' + four_lead_model_filename
    os.makedirs(folder_name + '/final_models/') 
    os.makedirs(folder_name + '/final_thrs/') 

    order_train = first_fold_array[folds1[i][0]]
    order_valid = first_fold_array[folds1[i][1]]
                                     
    model1 = ResNetmod(channels, classes=26)
    history = model1.fit(x=shuffle_batch_generator_data(batch_size=batchsize, channels=channels, classes=classes, order_train=order_train, x_gen=generate_X_shuffle(order_train=order_train, X_train=filenames_array, channels=channels), y_gen=generate_y_shuffle(order_train=order_train, y_train=y)), epochs=epochs, callbacks=cb, steps_per_epoch=(len(order_train)/batchsize), validation_data=generate_validation_data_only(filenames_array,y, order_valid, channels), validation_freq=1)
    filename = folder_name + '/final_models/model_fold_1' + '.h5'                                   
    model1.save(filename)
    print('>Saved %s' % filename)
                                     
    y_pred = model1.predict(x=generate_validation_data_only(filenames_array,y, order_valid, channels)[0])
    labels = generate_validation_data_only(filenames_array, y, order_valid, channels)[1]
    th_init = np.arange(0,1,0.01)
    scores = initialize_th(weights, labels, y_pred, classes_list)
    new_thr_1 = optimize.fmin(th_optimization, args=(weights, labels, y_pred, classes_list), x0=th_init[scores.argmax()]*np.ones(26))
    filenames = folder_name + '/final_thrs/thr_1' + '.pickle'                                 
    with open(filenames, 'wb') as filehandle:
        pickle.dump(new_thr_1, filehandle)
    
                                     
    order_train = second_fold_array[folds2[i][0]]
    order_valid = second_fold_array[folds2[i][1]]
                                     
    model2 = ResNetmod(channels, classes=26)
    history = model2.fit(x=shuffle_batch_generator_data(batch_size=batchsize, channels=channels, classes=classes, order_train=order_train, x_gen=generate_X_shuffle(order_train=order_train, X_train=filenames_array, channels=channels), y_gen=generate_y_shuffle(order_train=order_train, y_train=y)), epochs=epochs, callbacks=cb, steps_per_epoch=(len(order_train)/batchsize), validation_data=generate_validation_data_only(filenames_array,y, order_valid, channels), validation_freq=1)                               
    filename = folder_name + '/final_models/model_fold_2' + '.h5'
    model2.save(filename)
    print('>Saved %s' % filename)

    y_pred = model2.predict(x=generate_validation_data_only(filenames_array,y, order_valid, channels)[0])
    labels = generate_validation_data_only(filenames_array, y, order_valid, channels)[1]
    th_init = np.arange(0,1,0.01)
    scores = initialize_th(weights, labels, y_pred, classes_list)
    new_thr_2 = optimize.fmin(th_optimization, args=(weights, labels, y_pred, classes_list), x0=th_init[scores.argmax()]*np.ones(26))                                 
    filenames = folder_name + '/final_thrs/thr_2' + '.pickle'
    with open(filenames, 'wb') as filehandle:
        pickle.dump(new_thr_2, filehandle)
            
        
    order_train = third_fold_array[folds3[i][0]]
    order_valid = third_fold_array[folds3[i][1]]
                                     
    model3 = ResNetmod(channels, classes=26)
    history = model3.fit(x=shuffle_batch_generator_data(batch_size=batchsize, channels=channels, classes=classes, order_train=order_train, x_gen=generate_X_shuffle(order_train=order_train, X_train=filenames_array, channels=channels), y_gen=generate_y_shuffle(order_train=order_train, y_train=y)), epochs=epochs, callbacks=cb, steps_per_epoch=(len(order_train)/batchsize), validation_data=generate_validation_data_only(filenames_array,y, order_valid, channels), validation_freq=1)                               
    filename = folder_name + '/final_models/model_fold_3' + '.h5'
    model3.save(filename)
    print('>Saved %s' % filename)

    y_pred = model3.predict(x=generate_validation_data_only(filenames_array,y, order_valid, channels)[0])
    labels = generate_validation_data_only(filenames_array, y, order_valid, channels)[1]
    th_init = np.arange(0,1,0.01)
    scores = initialize_th(weights, labels, y_pred, classes_list)
    new_thr_3 = optimize.fmin(th_optimization, args=(weights, labels, y_pred, classes_list), x0=th_init[scores.argmax()]*np.ones(26))                                     
    filenames = folder_name + '/final_thrs/thr_3' + '.pickle'
    with open(filenames, 'wb') as filehandle:
        pickle.dump(new_thr_3, filehandle)
                                     
        
    # Train 3-lead ECG model.
    print('Training 3-lead ECG model...')

    channels = 3
    three_lead_model_filename = 'three_lead_model_filename'    
    folder_name = model_directory + '/' + three_lead_model_filename
    os.makedirs(folder_name + '/final_models/') 
    os.makedirs(folder_name + '/final_thrs/') 

    order_train = first_fold_array[folds1[i][0]]
    order_valid = first_fold_array[folds1[i][1]]
                                     
    model1 = ResNetmod(channels, classes=26)
    history = model1.fit(x=shuffle_batch_generator_data(batch_size=batchsize, channels=channels, classes=classes, order_train=order_train, x_gen=generate_X_shuffle(order_train=order_train, X_train=filenames_array, channels=channels), y_gen=generate_y_shuffle(order_train=order_train, y_train=y)), epochs=epochs, callbacks=cb, steps_per_epoch=(len(order_train)/batchsize), validation_data=generate_validation_data_only(filenames_array,y, order_valid, channels), validation_freq=1)
    filename = folder_name + '/final_models/model_fold_1' + '.h5'                                   
    model1.save(filename)
    print('>Saved %s' % filename)
                                     
    y_pred = model1.predict(x=generate_validation_data_only(filenames_array,y, order_valid, channels)[0])
    labels = generate_validation_data_only(filenames_array, y, order_valid, channels)[1]
    th_init = np.arange(0,1,0.01)
    scores = initialize_th(weights, labels, y_pred, classes_list)
    new_thr_1 = optimize.fmin(th_optimization, args=(weights, labels, y_pred, classes_list), x0=th_init[scores.argmax()]*np.ones(26))
    filenames = folder_name + '/final_thrs/thr_1' + '.pickle'                                 
    with open(filenames, 'wb') as filehandle:
        pickle.dump(new_thr_1, filehandle)
    
                                     
    order_train = second_fold_array[folds2[i][0]]
    order_valid = second_fold_array[folds2[i][1]]
                                     
    model2 = ResNetmod(channels, classes=26)
    history = model2.fit(x=shuffle_batch_generator_data(batch_size=batchsize, channels=channels, classes=classes, order_train=order_train, x_gen=generate_X_shuffle(order_train=order_train, X_train=filenames_array, channels=channels), y_gen=generate_y_shuffle(order_train=order_train, y_train=y)), epochs=epochs, callbacks=cb, steps_per_epoch=(len(order_train)/batchsize), validation_data=generate_validation_data_only(filenames_array,y, order_valid, channels), validation_freq=1)                               
    filename = folder_name + '/final_models/model_fold_2' + '.h5'
    model2.save(filename)
    print('>Saved %s' % filename)

    y_pred = model2.predict(x=generate_validation_data_only(filenames_array,y, order_valid, channels)[0])
    labels = generate_validation_data_only(filenames_array, y, order_valid, channels)[1]
    th_init = np.arange(0,1,0.01)
    scores = initialize_th(weights, labels, y_pred, classes_list)
    new_thr_2 = optimize.fmin(th_optimization, args=(weights, labels, y_pred, classes_list), x0=th_init[scores.argmax()]*np.ones(26))                                 
    filenames = folder_name + '/final_thrs/thr_2' + '.pickle'
    with open(filenames, 'wb') as filehandle:
        pickle.dump(new_thr_2, filehandle)
            
        
    order_train = third_fold_array[folds3[i][0]]
    order_valid = third_fold_array[folds3[i][1]]
                                     
    model3 = ResNetmod(channels, classes=26)
    history = model3.fit(x=shuffle_batch_generator_data(batch_size=batchsize, channels=channels, classes=classes, order_train=order_train, x_gen=generate_X_shuffle(order_train=order_train, X_train=filenames_array, channels=channels), y_gen=generate_y_shuffle(order_train=order_train, y_train=y)), epochs=epochs, callbacks=cb, steps_per_epoch=(len(order_train)/batchsize), validation_data=generate_validation_data_only(filenames_array,y, order_valid, channels), validation_freq=1)                               
    filename = folder_name + '/final_models/model_fold_3' + '.h5'
    model3.save(filename)
    print('>Saved %s' % filename)

    y_pred = model3.predict(x=generate_validation_data_only(filenames_array,y, order_valid, channels)[0])
    labels = generate_validation_data_only(filenames_array, y, order_valid, channels)[1]
    th_init = np.arange(0,1,0.01)
    scores = initialize_th(weights, labels, y_pred, classes_list)
    new_thr_3 = optimize.fmin(th_optimization, args=(weights, labels, y_pred, classes_list), x0=th_init[scores.argmax()]*np.ones(26))                                     
    filenames = folder_name + '/final_thrs/thr_3' + '.pickle'
    with open(filenames, 'wb') as filehandle:
        pickle.dump(new_thr_3, filehandle)
                                     
                                     
    
    # Train 2-lead ECG model.
    print('Training 2-lead ECG model...')

    channels = 2
    two_lead_model_filename = 'two_lead_model_filename'
    folder_name = model_directory + '/' + two_lead_model_filename
    os.makedirs(folder_name + '/final_models/') 
    os.makedirs(folder_name + '/final_thrs/')
        
    order_train = first_fold_array[folds1[i][0]]
    order_valid = first_fold_array[folds1[i][1]]
                                     
    model1 = ResNetmod(channels, classes=26)
    history = model1.fit(x=shuffle_batch_generator_data(batch_size=batchsize, channels=channels, classes=classes, order_train=order_train, x_gen=generate_X_shuffle(order_train=order_train, X_train=filenames_array, channels=channels), y_gen=generate_y_shuffle(order_train=order_train, y_train=y)), epochs=epochs, callbacks=cb, steps_per_epoch=(len(order_train)/batchsize), validation_data=generate_validation_data_only(filenames_array,y, order_valid, channels), validation_freq=1)
    filename = folder_name + '/final_models/model_fold_1' + '.h5'                                   
    model1.save(filename)
    print('>Saved %s' % filename)
                                     
    y_pred = model1.predict(x=generate_validation_data_only(filenames_array,y, order_valid, channels)[0])
    labels = generate_validation_data_only(filenames_array, y, order_valid, channels)[1]
    th_init = np.arange(0,1,0.01)
    scores = initialize_th(weights, labels, y_pred, classes_list)
    new_thr_1 = optimize.fmin(th_optimization, args=(weights, labels, y_pred, classes_list), x0=th_init[scores.argmax()]*np.ones(26))
    filenames = folder_name + '/final_thrs/thr_1' + '.pickle'                                 
    with open(filenames, 'wb') as filehandle:
        pickle.dump(new_thr_1, filehandle)
    
                                     
    order_train = second_fold_array[folds2[i][0]]
    order_valid = second_fold_array[folds2[i][1]]
                                     
    model2 = ResNetmod(channels, classes=26)
    history = model2.fit(x=shuffle_batch_generator_data(batch_size=batchsize, channels=channels, classes=classes, order_train=order_train, x_gen=generate_X_shuffle(order_train=order_train, X_train=filenames_array, channels=channels), y_gen=generate_y_shuffle(order_train=order_train, y_train=y)), epochs=epochs, callbacks=cb, steps_per_epoch=(len(order_train)/batchsize), validation_data=generate_validation_data_only(filenames_array,y, order_valid, channels), validation_freq=1)                               
    filename = folder_name + '/final_models/model_fold_2' + '.h5'
    model2.save(filename)
    print('>Saved %s' % filename)

    y_pred = model2.predict(x=generate_validation_data_only(filenames_array,y, order_valid, channels)[0])
    labels = generate_validation_data_only(filenames_array, y, order_valid, channels)[1]
    th_init = np.arange(0,1,0.01)
    scores = initialize_th(weights, labels, y_pred, classes_list)
    new_thr_2 = optimize.fmin(th_optimization, args=(weights, labels, y_pred, classes_list), x0=th_init[scores.argmax()]*np.ones(26))                                 
    filenames = folder_name + '/final_thrs/thr_2' + '.pickle'
    with open(filenames, 'wb') as filehandle:
        pickle.dump(new_thr_2, filehandle)
            
        
    order_train = third_fold_array[folds3[i][0]]
    order_valid = third_fold_array[folds3[i][1]]
                                     
    model3 = ResNetmod(channels, classes=26)
    history = model3.fit(x=shuffle_batch_generator_data(batch_size=batchsize, channels=channels, classes=classes, order_train=order_train, x_gen=generate_X_shuffle(order_train=order_train, X_train=filenames_array, channels=channels), y_gen=generate_y_shuffle(order_train=order_train, y_train=y)), epochs=epochs, callbacks=cb, steps_per_epoch=(len(order_train)/batchsize), validation_data=generate_validation_data_only(filenames_array,y, order_valid, channels), validation_freq=1)                               
    filename = folder_name + '/final_models/model_fold_3' + '.h5'
    model3.save(filename)
    print('>Saved %s' % filename)

    y_pred = model3.predict(x=generate_validation_data_only(filenames_array,y, order_valid, channels)[0])
    labels = generate_validation_data_only(filenames_array, y, order_valid, channels)[1]
    th_init = np.arange(0,1,0.01)
    scores = initialize_th(weights, labels, y_pred, classes_list)
    new_thr_3 = optimize.fmin(th_optimization, args=(weights, labels, y_pred, classes_list), x0=th_init[scores.argmax()]*np.ones(26))                                     
    filenames = folder_name + '/final_thrs/thr_3' + '.pickle'
    with open(filenames, 'wb') as filehandle:
        pickle.dump(new_thr_3, filehandle)
        

################################################################################
#
# File I/O functions
#
################################################################################

# Load your trained ECG model for the given set of leads (twelve, six, four, three, or two leads). This function is *required*. Do *not* change the arguments of this function.
def load_model(model_directory, leads):
    """Load the three fold models and their decision thresholds for ``leads``.

    Files are read from ``<model_directory>/<get_model_filename(leads)>/``:
    ``final_models/model_fold_{1,2,3}.h5`` (Keras models) and
    ``final_thrs/thr_{1,2,3}.pickle`` (pickled threshold values).

    Args:
        model_directory: Directory the models were trained into.
        leads: Lead names identifying which model set to load.

    Returns:
        list: ``[model_1, model_2, model_3, thr_1, thr_2, thr_3]`` -- the three
        Keras models followed by the three thresholds, in fold order.
        ``run_model`` indexes into this list positionally, so the layout must
        not change.
    """
    model_folder = os.path.join(model_directory, get_model_filename(leads))
    folds = (1, 2, 3)

    # All networks first, then all thresholds: run_model depends on this order.
    models = [
        keras.models.load_model(
            os.path.join(model_folder, 'final_models', 'model_fold_%d.h5' % fold))
        for fold in folds
    ]
    for fold in folds:
        threshold_path = os.path.join(model_folder, 'final_thrs', 'thr_%d.pickle' % fold)
        # pickle is acceptable only because these files are written by this
        # project's own training code; never load thresholds from untrusted files.
        with open(threshold_path, 'rb') as filehandle:
            models.append(pickle.load(filehandle))

    return models

# Define the filename(s) for the trained models. This function is not required. You can change or remove it.
def get_model_filename(leads):
    """Return the sub-folder name under which the model for ``leads`` is stored.

    The leads are put into canonical order with ``sort_leads`` and then mapped to
    a word by ``leads_name``. For example, the ('I', 'II') lead set yields
    ``'two_lead_model_filename'``.
    """
    canonical_leads = sort_leads(leads)
    lead_word = leads_name(canonical_leads)
    return '{}_lead_model_filename'.format(lead_word)

def leads_name(leads):
    """Map a canonically ordered lead set to the word used in model folder names.

    Args:
        leads: Sequence of lead names in canonical order (as produced by
            ``sort_leads``). Tuples, lists and other iterables are accepted;
            the lead set is compared as a tuple.

    Returns:
        str: One of ``'twelve'``, ``'six'``, ``'four'``, ``'three'`` or ``'two'``.

    Raises:
        ValueError: If ``leads`` is not one of the five supported lead sets.
            Previously this surfaced as an opaque ``UnboundLocalError``.
    """
    lead_set_names = {
        ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'): 'twelve',
        ('I', 'II', 'III', 'aVR', 'aVL', 'aVF'): 'six',
        ('I', 'II', 'III', 'V2'): 'four',
        ('I', 'II', 'V2'): 'three',
        ('I', 'II'): 'two',
    }
    # Normalise to a tuple so a list of leads matches the same way a tuple does.
    key = tuple(leads)
    try:
        return lead_set_names[key]
    except KeyError:
        raise ValueError('Unsupported lead set: %r' % (key,)) from None

################################################################################
#
# Running trained model functions
#
################################################################################


# Generic function for running a trained model.
def run_model(model, header, recording):
    """Classify one recording with the three-fold ensemble produced by ``load_model``.

    Args:
        model: Sequence laid out as ``[model_1, model_2, model_3, thr_1, thr_2, thr_3]``.
        header: Raw header text. The sampling frequency is read from its third
            space-separated token, and the lead names come from ``get_leads``.
        recording: Signal array for the record; presumably
            (num_leads, num_samples) -- TODO confirm against caller.

    Returns:
        tuple: ``(classes, labels, probabilities)``.
        - ``classes`` is the JSON list stored in ``'list classes.txt'``.
        - ``probabilities`` is, per class, the fraction of the three fold models
          whose output exceeds that fold's threshold.
        - ``labels`` is 1 where that fraction is above 0.5 (majority vote),
          else 0.
    """
    with open('list classes.txt', 'r') as class_file:
        classes = json.load(class_file)

    print('resampling')
    sampling_rate = int(header.split(' ')[2])
    # Target rate 257 is passed to resample_ecg (presumably Hz -- confirm there).
    signal = resample_ecg(recording, sampling_rate, 257)
    signal = pad_sequences(signal, maxlen=4096, dtype='float32', truncating='post', padding="post")
    lead_count = len(get_leads(header))
    # NOTE(review): this is a plain reshape, not a transpose, of the padded
    # (leads, 4096) array. Keep it unless the training pipeline is confirmed
    # to transpose: the models must see the same layout they were trained on.
    signal = signal.reshape(1, 4096, lead_count)

    print('load models')
    fold_models = tuple(model[k] for k in range(3))

    print('load thresholds')
    fold_thresholds = tuple(model[k] for k in range(3, 6))

    print('makes prediction')
    raw_scores = [net.predict(signal) for net in fold_models]
    # Binarise each fold's scores against its own per-class thresholds.
    votes = [(score > threshold) * 1 for score, threshold in zip(raw_scores, fold_thresholds)]

    probabilities = np.sum(votes[0] + votes[1] + votes[2], axis=0) / 3
    labels = (probabilities > 0.5) * 1

    return classes, labels, probabilities


################################################################################
#
# Other functions
#
################################################################################

# Extract features from the header and recording.
def get_features(header, recording, leads):
    """Extract demographic and per-lead amplitude features from one record.

    Args:
        header: Raw header text, parsed with the ``helper_code`` accessors.
        recording: Signal array indexed as ``recording[lead_index, sample]``.
            Integer (raw ADC count) arrays are accepted.
        leads: Lead names to select, in the order the RMS features should follow.

    Returns:
        tuple: ``(age, sex, rms)``.
        - ``age`` is the header value, or NaN when missing.
        - ``sex`` is 0 (female), 1 (male) or NaN.
        - ``rms`` is a float32 array holding the root-mean-square of each
          selected lead after ``(raw - baseline) / gain`` correction.

    Raises:
        ValueError: If a requested lead is not listed in the header
            (from ``.index``).
    """
    # Extract age.
    age = get_age(header)
    if age is None:
        age = float('nan')

    # Extract sex. Encode as 0 for female, 1 for male, and NaN for other.
    sex = get_sex(header)
    if sex in ('Female', 'female', 'F', 'f'):
        sex = 0
    elif sex in ('Male', 'male', 'M', 'm'):
        sex = 1
    else:
        sex = float('nan')

    # Reorder/reselect leads. Fancy indexing returns a copy, so the in-place
    # normalisation below never touches the caller's array.
    available_leads = get_leads(header)
    indices = [available_leads.index(lead) for lead in leads]
    recording = np.asarray(recording)[indices, :]
    # Raw signals may be integer ADC counts. Writing float physical units back
    # into an integer array would silently truncate them, so promote first.
    if not np.issubdtype(recording.dtype, np.floating):
        recording = recording.astype(np.float32)

    # Pre-process recordings: convert each lead to physical units.
    adc_gains = get_adcgains(header, leads)
    baselines = get_baselines(header, leads)
    num_leads = len(leads)
    for i in range(num_leads):
        recording[i, :] = (recording[i, :] - baselines[i]) / adc_gains[i]

    # Compute the root mean square of each ECG lead signal.
    rms = np.zeros(num_leads, dtype=np.float32)
    for i in range(num_leads):
        x = recording[i, :]
        rms[i] = np.sqrt(np.sum(x**2) / np.size(x))

    return age, sex, rms
