#!/usr/bin/env python

# Edit this script to add your team's training code.
# Some functions are *required*, but you can edit most parts of the required functions, remove non-required functions, and add your own functions.

from helper_code import *
import numpy as np, os, sys, joblib
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestClassifier
import neurokit2 as nk
from skmultilearn.ensemble import LabelSpacePartitioningClassifier
from skmultilearn.cluster import FixedLabelSpaceClusterer
from sklearn.ensemble import RandomForestClassifier
from skmultilearn.problem_transform import ClassifierChain
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer
from ECGfeaturizer import featurize as ef
import time

# Lead sets for the 12-, 6-, 4-, 3- and 2-lead models. The reduced sets are
# written as slices/extensions of the 12-lead tuple so the lead spelling lives
# in one place.
twelve_leads = (
    'I', 'II', 'III', 'aVR', 'aVL', 'aVF',
    'V1', 'V2', 'V3', 'V4', 'V5', 'V6',
)
six_leads = twelve_leads[:6]
two_leads = twelve_leads[:2]
three_leads = two_leads + ('V2',)
four_leads = twelve_leads[:3] + ('V2',)

# Ordered from the largest lead set to the smallest.
lead_sets = (twelve_leads, six_leads, four_leads, three_leads, two_leads)


# One model filename per lead-set size.
twelve_lead_model_filename = '{}_lead_model.sav'.format(12)
six_lead_model_filename = '{}_lead_model.sav'.format(6)
four_lead_model_filename = '{}_lead_model.sav'.format(4)
three_lead_model_filename = '{}_lead_model.sav'.format(3)
two_lead_model_filename = '{}_lead_model.sav'.format(2)


################################################################################
#
# Training function
#
################################################################################


# Feature-name suffixes for a lead whose featurizer is configured with only the
# *_peak options enabled (every lead except II). Taken verbatim, in order, from
# the hand-written column list this code replaces.
_PEAK_FEATURE_NAMES = (
    'mean_r_peak', 'sd_r_peak', 'mean_p_peak', 'sd_p_peak', 'mean_t_peak',
    'sd_t_peak', 'mean_q_peak', 'sd_q_peak', 'mean_s_peak', 'sd_s_peak',
)

# Feature-name suffixes for lead II, whose featurizer has every option enabled.
_LEAD_II_FEATURE_NAMES = (
    'mean_rr_interval', 'sd_rr_interval', 'mean_r_peak', 'sd_r_peak',
    'mean_pp_interval', 'sd_pp_interval', 'mean_p_peak', 'sd_p_peak',
    'mean_tt_interval', 'sd_tt_interval', 'mean_t_peak', 'sd_t_peak',
    'mean_qq_interval', 'sd_qq_interval', 'mean_q_peak', 'sd_q_peak',
    'mean_s_int', 'sd_s_int', 'mean_s_peak', 'sd_s_peak',
    'qrs_mean', 'qrs_std', 'qt_mean', 'qt_std', 'pr_mean', 'pr_std',
)


def _lead_feature_names(lead):
    """Return the feature column names for one lead, e.g. 'Lead_V2_mean_r_peak'."""
    suffixes = _LEAD_II_FEATURE_NAMES if lead == 'II' else _PEAK_FEATURE_NAMES
    return ['Lead_{}_{}'.format(lead, suffix) for suffix in suffixes]


def _make_featurizers(leads):
    """Build one ECG featurizer per lead, keyed by lead name.

    Lead II gets every interval/duration option switched on; all other leads
    only get the five *_peak options.
    """
    featurizers = {}
    for lead in leads:
        full = lead == 'II'
        featurizers[lead] = ef.get_features(
            r_peak=True, r_int=full, p_peak=True, p_int=full,
            t_peak=True, t_int=full, q_peak=True, q_int=full,
            s_peak=True, s_int=full,
            qrs_dur=full, qt_dur=full, pr_dur=full)
    return featurizers


def _collect_classes(header_files, unscored_mapping_file='./dx_mapping_unscored.csv'):
    """Return the sorted list of labels found in the headers, minus unscored codes.

    Unscored codes are read from the second column of `unscored_mapping_file`
    and removed by exact match. Every label of every header is considered.
    """
    unscored = pd.read_csv(unscored_mapping_file, sep=',')
    unscored_codes = {str(code) for code in unscored.iloc[:, 1]}

    found = set()
    for header_file in header_files:
        found |= set(get_labels(load_header(header_file)))
    found -= unscored_codes

    if all(is_integer(x) for x in found):
        return sorted(found, key=lambda x: int(x))  # Sort classes numerically if numbers.
    return sorted(found)  # Sort classes alphanumerically otherwise.


def _extract_record_features(header, recording, leads, featurizers):
    """Return (age, sex, ecg_features) for one recording.

    `ecg_features` is the concatenation, in `leads` order, of what each lead's
    featurizer returns. Raises whatever the helpers/featurizers raise (e.g.
    ValueError when a requested lead is missing from the header).
    """
    # Extract age.
    age = get_age(header)
    if age is None:
        age = float('nan')

    # Extract sex. Encode as 0 for female, 1 for male, and NaN for other.
    sex = get_sex(header)
    if sex in ('Female', 'female', 'F', 'f'):
        sex = 0
    elif sex in ('Male', 'male', 'M', 'm'):
        sex = 1
    else:
        sex = float('nan')

    # Reorder/reselect leads in the recording (fancy indexing makes a copy).
    available_leads = get_leads(header)
    rows = [available_leads.index(lead) for lead in leads]
    recording = recording[rows, :]

    adc_gains = get_adcgains(header, leads)
    baselines = get_baselines(header, leads)
    sample_freq = int(header.split()[2])

    per_lead = []
    for k, lead in enumerate(leads):
        # Remove the baseline and apply the ADC gain, in place as before.
        # NOTE(review): if load_recording returns an integer array this
        # assignment truncates; left unchanged so it matches run_model.
        recording[k, :] = (recording[k, :] - baselines[k]) / adc_gains[k]
        lead_features, _ = featurizers[lead].featurize_ecg(recording[k, :], sample_freq)
        per_lead.append(lead_features)

    return age, sex, np.hstack(per_lead)


# Train your model. This function is *required*. Do *not* change the arguments of this function.
def training_code(data_directory, model_directory):
    """Train and save one model per entry in `lead_sets`.

    Features are extracted once for all twelve leads; each reduced-lead model
    is trained on the matching column subset plus age and sex.

    Raises:
        Exception: if no recordings are found, no scored class is present, or
            feature extraction fails for every recording.
    """
    # Find header and recording files.
    print('Finding header and recording files...')

    header_files, recording_files = find_challenge_files(data_directory)
    num_recordings = len(recording_files)

    if not num_recordings:
        raise Exception('No data was provided.')

    # Create a folder for the model if it does not already exist.
    os.makedirs(model_directory, exist_ok=True)

    # Extract classes from dataset.
    print('Extracting classes...')
    classes = _collect_classes(header_files)
    num_classes = len(classes)
    if not num_classes:
        raise Exception('No scored classes were found in the data.')
    print("classes:", num_classes)
    class_index = {label: j for j, label in enumerate(classes)}

    # Extract features and labels from dataset.
    print('Extracting features and labels...')

    featurizers = _make_featurizers(twelve_leads)

    # Names are generated in twelve_leads order so that every column name
    # matches the lead whose features actually fill that column.
    ecg_feature_names = [name for lead in twelve_leads for name in _lead_feature_names(lead)]
    num_ecg_features = len(ecg_feature_names)

    # ECG features for all twelve leads, followed by age and sex.
    data = np.zeros((num_recordings, num_ecg_features + 2), dtype=np.float32)
    labels = np.zeros((num_recordings, num_classes), dtype=bool)  # One-hot encoding of classes
    usable = np.zeros(num_recordings, dtype=bool)  # Rows whose features were extracted.

    for rec_idx in range(num_recordings):
        print('    {}/{}...'.format(rec_idx + 1, num_recordings))

        # Load header and recording.
        header = load_header(header_files[rec_idx])
        recording = load_recording(recording_files[rec_idx])

        try:
            age, sex, ecg_features = _extract_record_features(
                header, recording, twelve_leads, featurizers)
        except Exception as exc:
            # Later, hand-annotated features for "difficult" ECGs are meant to
            # be added here; for now such recordings are left out of training.
            print('    Skipping {}: {}'.format(header_files[rec_idx], exc))
            continue

        data[rec_idx, 0:num_ecg_features] = ecg_features
        data[rec_idx, num_ecg_features] = age
        data[rec_idx, num_ecg_features + 1] = sex

        for label in get_labels(header):
            j = class_index.get(label)
            if j is not None:
                labels[rec_idx, j] = 1
        usable[rec_idx] = True

    if not usable.any():
        raise Exception('Feature extraction failed for every recording.')

    # Drop recordings that could not be featurized instead of training on
    # all-zero rows.
    data = data[usable]
    labels = labels[usable]

    # Make clusters: for each class, the labels that occur in the recordings
    # where that class is present.
    ohe = labels.astype(int)
    my_cluster = [
        np.unique(np.nonzero(ohe[ohe[:, j] == 1])[1])
        for j in range(ohe.shape[1])
    ]

    # Train models.

    # Define parameters for random forest classifier.
    n_estimators = 3     # Number of trees in the forest.

    frame = pd.DataFrame(data=data, columns=ecg_feature_names + ['age', 'sex'])

    for leads in lead_sets:
        num_leads = len(leads)
        print('Training {}-lead ECG model...'.format(num_leads))

        columns = [name for lead in leads for name in _lead_feature_names(lead)]
        columns += ['age', 'sex']
        features = frame[columns].to_numpy()

        # Missing values are handled by the imputer for every lead set, so the
        # saved imputer behaves the same way at training and at prediction time.
        imputer = SimpleImputer().fit(features)
        features = imputer.transform(features)

        print("Making the {}-lead model".format(num_leads))
        classifier = LabelSpacePartitioningClassifier(
            classifier=ClassifierChain(
                classifier=RandomForestClassifier(n_jobs=-1, n_estimators=n_estimators, verbose=1),
                require_dense=[False, True]
            ),
            require_dense=[True, True],
            clusterer=FixedLabelSpaceClusterer(clusters=my_cluster)
        )
        classifier.fit(features, labels)

        save_model(model_directory, leads, classes, imputer, classifier)
################################################################################
#
# File I/O functions
#
################################################################################

# Save a trained model. This function is not required. You can change or remove it.
def save_model(model_directory, leads, classes, imputer, classifier):
    """Dump the model parts as one dict to the lead-set-specific file in `model_directory`."""
    target = os.path.join(model_directory, get_model_filename(leads))
    bundle = dict(leads=leads, classes=classes, imputer=imputer, classifier=classifier)
    joblib.dump(bundle, target, protocol=0)

# Load a trained model. This function is *required*. You should edit this function to add your code, but do *not* change the arguments of this function.
def load_model(model_directory, leads):
    """Load and return the object stored in the model file for `leads`."""
    model_path = os.path.join(model_directory, get_model_filename(leads))
    model = joblib.load(model_path)
    return model

# Define the filename(s) for the trained models. This function is not required. You can change or remove it.
def get_model_filename(leads):
    """Return 'model_<lead>-<lead>-....sav', with the leads ordered by `sort_leads`."""
    joined = '-'.join(sort_leads(leads))
    return 'model_' + joined + '.sav'

################################################################################
#
# Running trained model functions
#
################################################################################

# Run your trained 12-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
def run_twelve_lead_model(model, header, recording):
    """Delegate to `run_model` and return its result unchanged."""
    result = run_model(model, header, recording)
    return result

# Run your trained 6-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
def run_six_lead_model(model, header, recording):
    """Delegate to `run_model` and return its result unchanged."""
    result = run_model(model, header, recording)
    return result

# Run your trained 4-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
def run_four_lead_model(model, header, recording):
    """Delegate to `run_model` and return its result unchanged."""
    result = run_model(model, header, recording)
    return result

# Run your trained 3-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
def run_three_lead_model(model, header, recording):
    """Delegate to `run_model` and return its result unchanged."""
    result = run_model(model, header, recording)
    return result

# Run your trained 2-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
def run_two_lead_model(model, header, recording):
    """Delegate to `run_model` and return its result unchanged."""
    result = run_model(model, header, recording)
    return result

# Generic function for running a trained model.
def run_model(model, header, recording):
    """Run a trained model on one recording and return its prediction.

    Args:
        model: dict with the keys 'classes', 'leads', 'imputer' and
            'classifier' (the layout written when the model is saved).
        header: header text of the recording. Age, sex, lead names, ADC gains
            and baselines are read from it through the helper functions, and
            its third whitespace-separated token is parsed as the integer
            sampling frequency.
        recording: 2-D array indexed as [lead, sample].

    Returns:
        A tuple (classes, labels, probabilities). `labels` is an integer
        vector with one entry per class. `probabilities` is the same vector
        as float32, because the classifier's probability output is not used.

    Any failure while featurizing or predicting is reported on stdout and
    answered with the fallback prediction (see _fallback_prediction), so a
    single bad recording does not abort a whole run. KeyboardInterrupt and
    SystemExit are deliberately not intercepted.
    """
    classes = model['classes']
    leads = model['leads']
    imputer = model['imputer']
    classifier = model['classifier']

    # Number of ECG features produced for each supported lead-set size. The
    # feature vector additionally carries age and sex in its last two slots.
    ecg_feature_counts = {2: 36, 3: 46, 4: 56, 6: 76, 12: 136}
    feature_count = ecg_feature_counts.get(len(leads))
    if feature_count is None:
        # No feature layout exists for this lead count, so nothing can be
        # predicted; answer with the fallback right away instead of doing the
        # featurization work first.
        print("Undefined number of leads")
        return _fallback_prediction(classes)

    # Lead II gets the full feature set; every other lead only gets the peak
    # features.
    peaks_only = dict(r_peak=True, r_int=False, p_peak=True, p_int=False,
                      t_peak=True, t_int=False, q_peak=True, q_int=False,
                      s_peak=True, s_int=False, qrs_dur=False, qt_dur=False,
                      pr_dur=False)
    all_features = dict.fromkeys(peaks_only, True)
    # One featurizer instance per lead, as before, in case instances keep state.
    featurizers = {}
    for lead_name in ('I', 'II', 'III', 'aVR', 'aVF', 'aVL',
                      'V1', 'V2', 'V3', 'V4', 'V5', 'V6'):
        settings = all_features if lead_name == 'II' else peaks_only
        featurizers[lead_name] = ef.get_features(**settings)

    data = np.zeros(feature_count + 2, dtype=np.float32)

    try:
        age = get_age(header)
        if age is None:
            age = float('nan')

        # Extract sex. Encode as 0 for female, 1 for male, and NaN for other.
        sex = get_sex(header)
        if sex in ('Female', 'female', 'F', 'f'):
            sex = 0
        elif sex in ('Male', 'male', 'M', 'm'):
            sex = 1
        else:
            sex = float('nan')

        # Reorder/reselect leads in recordings. Fancy indexing copies, so the
        # caller's array is not modified by the scaling below.
        available_leads = get_leads(header)
        indices = [available_leads.index(lead) for lead in leads]
        recording = recording[indices, :]

        # Pre-process recordings and featurize each lead.
        adc_gains = get_adcgains(header, leads)
        baselines = get_baselines(header, leads)
        sample_freq = int(header.split()[2])
        per_lead_features = []
        for i, lead in enumerate(leads):
            recording[i, :] = (recording[i, :] - baselines[i]) / adc_gains[i]
            # An unknown lead name raises KeyError here and ends up in the
            # fallback, instead of silently reusing the previous lead's features.
            lead_features, _ = featurizers[lead].featurize_ecg(recording[i, :], sample_freq)
            per_lead_features.append(lead_features)
        ecg_features = np.hstack(per_lead_features)

        # A feature vector of unexpected length raises here (shape mismatch).
        data[0:feature_count] = ecg_features
        data[feature_count] = age
        data[feature_count + 1] = sex

        # Impute missing data.
        features = data.reshape(1, -1)
        features = np.nan_to_num(features)
        features = imputer.transform(features)

        # Predict labels. Sparse results are densified first; dense results
        # are used as they are.
        labels = classifier.predict(features)
        if hasattr(labels, 'todense'):
            labels = labels.todense()
        # The builtin int replaces the np.int alias, which newer NumPy removed.
        labels = np.asarray(labels, dtype=int).ravel()

        # No probability estimates are requested from the classifier; the
        # binary labels double as scores.
        probabilities = np.asarray(labels * 1.0, dtype=np.float32).ravel()

        return classes, labels, probabilities
    except Exception as err:
        print('run_model failed ({}: {}); returning the fallback prediction.'.format(
            type(err).__name__, err))
        return _fallback_prediction(classes)


def _fallback_prediction(classes):
    """Build the prediction returned when a recording cannot be classified.

    Every class is labelled 0 except class "426783006", which is labelled 1
    when it is among `classes`. Returns (classes, labels, probabilities) with
    integer labels and the same values as float32 probabilities.
    """
    # NOTE(review): 426783006 is presumably the SNOMED CT code for sinus
    # rhythm, i.e. "assume a normal rhythm when nothing else is known" -- confirm.
    fallback_class = "426783006"
    labels = np.array([1 if c == fallback_class else 0 for c in classes], dtype=int)
    probabilities = labels.astype(np.float32)
    return classes, labels, probabilities

################################################################################
#
# Other functions
#
################################################################################

# Extract features from the header and recording.
def get_features(header, recording, leads, featurize_lead_I, featurize_lead_II, featurize_lead_III,
                 featurize_lead_aVR, featurize_lead_aVL, featurize_lead_aVF, featurize_lead_V1,
                 featurize_lead_V2, featurize_lead_V3, featurize_lead_V4, featurize_lead_V5,
                 featurize_lead_V6):
    """Compute age, sex and the stacked per-lead ECG features of one recording.

    Args:
        header: header text of the recording; its third whitespace-separated
            token is parsed as the integer sampling frequency.
        recording: 2-D array indexed as [lead, sample].
        leads: names of the leads to use, in the order their features are
            stacked.
        featurize_lead_*: one featurizer per lead name; the `featurize_ecg`
            method of the matching one is called with the scaled signal and
            the sampling frequency.

    Returns:
        (age, sex, features). `age` is the header age or NaN when missing,
        `sex` is 0 for female, 1 for male and NaN otherwise, and `features`
        is the horizontal stack of the per-lead feature vectors.
    """
    # Age, with NaN standing in for a missing value.
    age = get_age(header)
    if age is None:
        age = float('nan')

    # Sex: 0 for female, 1 for male, NaN for anything else.
    reported_sex = get_sex(header)
    if reported_sex in ('Female', 'female', 'F', 'f'):
        sex = 0
    else:
        sex = 1 if reported_sex in ('Male', 'male', 'M', 'm') else float('nan')

    # Select the requested leads, in the requested order.
    available_leads = get_leads(header)
    row_order = [available_leads.index(lead_name) for lead_name in leads]
    recording = recording[row_order, :]

    # Scaling parameters and sampling frequency.
    adc_gains = get_adcgains(header, leads)
    baselines = get_baselines(header, leads)
    sample_freq = int(header.split()[2])

    featurizer_by_lead = {
        "I": featurize_lead_I,
        "II": featurize_lead_II,
        "III": featurize_lead_III,
        "aVR": featurize_lead_aVR,
        "aVL": featurize_lead_aVL,
        "aVF": featurize_lead_aVF,
        "V1": featurize_lead_V1,
        "V2": featurize_lead_V2,
        "V3": featurize_lead_V3,
        "V4": featurize_lead_V4,
        "V5": featurize_lead_V5,
        "V6": featurize_lead_V6,
    }

    for row, lead_name in enumerate(leads):
        # Remove the baseline and divide by the ADC gain.
        recording[row, :] = (recording[row, :] - baselines[row]) / adc_gains[row]

        if lead_name in featurizer_by_lead:
            lead_features, _ = featurizer_by_lead[lead_name].featurize_ecg(recording[row, :], sample_freq)
        else:
            # NOTE: only a message is printed here; the features of the
            # previous lead (if any) are then stacked a second time below.
            print("undefined lead")

        if row:
            features = np.hstack([features, lead_features])
        else:
            features = lead_features

    return age, sex, features



