#!/usr/bin/env python

# Edit this script to add your team's training code.
# Some functions are *required*, but you can edit most parts of the required functions, remove non-required functions, and add your own functions.

from helper_code import *
from utils import *
from tf_utils import *
import numpy as np, os, random, gc, joblib
# from sklearn.impute import SimpleImputer
# from sklearn.ensemble import RandomForestClassifier

# Challenge lead sets, ordered from the full 12-lead configuration down to 2 leads.
# These variables are not required by the Challenge; they can be changed or removed.
twelve_leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')
six_leads = twelve_leads[:6]  # I, II, III, aVR, aVL, aVF
four_leads = ('I', 'II', 'III', 'V2')
three_leads = ('I', 'II', 'V2')
two_leads = twelve_leads[:2]  # I, II
lead_sets = (twelve_leads, six_leads, four_leads, three_leads, two_leads)


# Output file names keyed by the number of leads in each lead set:
# `model_sav` names the joblib file holding the thresholds and lead names,
# `model_weights` names the file holding the network weights.
model_sav = {count: '{}_lead_model.sav'.format(count) for count in map(len, lead_sets)}
model_weights = {count: '{}_lead_model_weights.h5'.format(count) for count in map(len, lead_sets)}

# The class definitions and scoring weights are read from `weights.csv`,
# located in the same directory as this script.
weights_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'weights.csv')
classes_set, weights = load_weights(weights_file)

# Class code whose index is handed to the challenge loss/metric as `inactive_index`.
normal_class = '426783006'

# Map every class code to the index of the group it belongs to in `classes_set`;
# codes inside the same group share one output index.
y_map = dict((code, index) for index, group in enumerate(classes_set) for code in group)
classes = y_map.keys()
class_count = 1 + max(y_map.values())

# Sampling frequency (samples per second) that recordings are decimated to.
new_freq = 125

################################################################################
#
# Training function
#
################################################################################
def build_model_and_train(header_files,leads,model_directory):
    """Train one model for a lead set, tune its thresholds and save both to disk.

    The first 30% of `header_files` is held out as a development set and the
    rest is used for training. The weights of the epoch with the highest
    `val_challenge_metric` are written to the lead-specific weights file, then
    reloaded to optimise the decision thresholds on the development set. The
    thresholds and the lead names are stored in the lead-specific `.sav` file.

    Args:
        header_files: list of header file paths. The split is positional, so
            the caller is expected to have shuffled the list already.
        leads: tuple of lead names the model is trained on; its length selects
            the output file names.
        model_directory: directory the weights and `.sav` files are written to.

    Returns:
        0 when the best `val_challenge_metric` reached `min_challenge_score`,
        1 otherwise (training is considered unsuccessful).
    """
    # Minimum validation challenge metric for a run to count as successful.
    # The same value is given to StopIfTrainingFails (defined elsewhere;
    # presumably it aborts a run that is still below it at `cutoff_epoch`).
    min_challenge_score = 0.2

    # set common training settings
    tf.random.set_seed(1)

    model_filename = os.path.join(model_directory, model_weights[len(leads)])
    model_sav_filename = os.path.join(model_directory, model_sav[len(leads)])

    epochs = 100
    batch_size = 32
    learning_rate = 0.001

    # Positional 30/70 dev/train split.
    # NOTE(review): with fewer than 4 header files the dev set is empty and the
    # np.vstack below fails - confirm callers always pass enough recordings.
    num_valid_recordings = len(header_files)
    dev_size = int(0.3*num_valid_recordings)
    dev_set = header_files[:dev_size]
    train_set = header_files[dev_size:]

    # make datasets; the dev set is not shuffled so that its predictions stay
    # aligned with the labels collected from `dev_set` below
    dataset_maker = MakeDataset(y_map=y_map,freq=new_freq,leads=leads)
    train_dataset = dataset_maker(train_set,batch_size=batch_size,shuffle=True)
    dev_dataset = dataset_maker(dev_set,batch_size=batch_size,shuffle=False)

    # build model
    model = build_model(class_count=class_count,lead_count=len(leads))

    # compile model
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
    challenge_metric_obj = PysionetChallengeMetric(inactive_index=y_map[normal_class],
                            name='challenge_metric',class_count=class_count,weights=weights)
    # keep only the weights of the epoch with the highest validation challenge metric
    save_best_callback = tf.keras.callbacks.ModelCheckpoint(model_filename, monitor='val_challenge_metric',
                                                save_best_only=True,save_weights_only=True, mode='max')
    loss_obj = CustomChallengeLoss(inactive_index=y_map[normal_class],
                            class_count=class_count,weights=weights)
    model.compile(optimizer=optimizer, loss=loss_obj,
                metrics=[F1Macro(name='f1_macro',class_count=class_count),challenge_metric_obj])
    history = model.fit(train_dataset, epochs=epochs,verbose=1,validation_data=dev_dataset,
                    callbacks=[save_best_callback,
                               StopIfTrainingFails(monitor='val_challenge_metric',cutoff_epoch=5,
                                                   min_value=min_challenge_score)])

    # Read the best score now so that `history` can be released during cleanup.
    best_val_score = np.max(history.history['val_challenge_metric'])

    # Threshold optimization on the dev set, using the best checkpoint
    model.load_weights(model_filename)
    Y_true_dev = np.vstack([get_y(get_labels(load_header(i)),y_map) for i in dev_set])
    Y_pred_dev = model.predict(dev_dataset,verbose=True)

    challenge_metric_objective_function = ChallengeMetricOpt(Y_true_dev,Y_pred_dev,weights,y_map[normal_class])
    threshold = get_optim_thresholds(challenge_metric_objective_function,class_count)

    save_model_data(model_sav_filename, leads, threshold)

    # Reset keras and drop every local that references the model or the data
    # so that the garbage collection below can reclaim them before the next run.
    # `history` and `save_best_callback` are Keras callbacks and keep a
    # reference to the model, so they are released as well.
    tf.keras.backend.clear_session()
    del model, history, save_best_callback
    del Y_true_dev, Y_pred_dev, challenge_metric_objective_function
    del loss_obj, optimizer, challenge_metric_obj
    del dataset_maker, train_dataset, dev_dataset, dev_set, train_set
    print(gc.collect())

    # Report failure when the model never reached the minimum challenge score.
    if best_val_score < min_challenge_score:
        # training not successful
        return 1
    return 0


# Train your model. This function is *required*. Do *not* change the arguments of this function.
def training_code(data_directory, model_directory):
    """Train one model per Challenge lead set and store them in `model_directory`.

    Only recordings that carry at least one label contained in `classes` are
    used. The list of usable header files is shuffled once and shared by all
    lead sets. Training of a lead set is repeated while
    `build_model_and_train` reports failure, up to four attempts in total.

    Args:
        data_directory: directory that contains the header and recording files.
        model_directory: output directory; created if it does not exist.

    Raises:
        Exception: if no recordings are found, or if none of them has a label
            contained in `classes`.
    """
    # Find header and recording files.
    print('Finding header and recording files...')

    header_files, recording_files = find_challenge_files(data_directory)
    num_recordings = len(recording_files)

    if not num_recordings:
        raise Exception('No data was provided.')

    # Create a folder for the model if it does not already exist. makedirs also
    # creates missing parent directories and tolerates an existing directory.
    os.makedirs(model_directory, exist_ok=True)

    # Make a list of valid header file names.
    print('Extracting valid header file names...')
    valid_header_files = []
    for header_file in header_files:
        header = load_header(header_file)
        if any(label in classes for label in get_labels(header)):
            valid_header_files.append(header_file)

    num_valid_recordings = len(valid_header_files)
    print('number of valid recordings found : ',num_valid_recordings)
    if not num_valid_recordings:
        # Fail early with a clear message instead of deep inside training.
        raise Exception('No recordings with a known class label were found.')

    # Shuffle once; build_model_and_train splits the list positionally into
    # development and training sets.
    random.shuffle(valid_header_files)

    # Train one model for every lead set.
    max_retries = 3  # one initial attempt plus up to three retries
    for leads in lead_sets:
        print('Training {}-lead ECG model...'.format(len(leads)))
        for _ in range(max_retries + 1):
            failed = build_model_and_train(header_files=valid_header_files,leads=leads,
                                           model_directory=model_directory)
            if not failed:
                break
        else:
            # Every attempt failed; the files written by the last attempt are kept.
            print('Warning: {}-lead ECG model did not reach the minimum score after {} attempts.'.format(
                len(leads), max_retries + 1))

################################################################################
#
# File I/O functions
#
################################################################################

def save_model_data(filename, leads, threshold):
    """Write the decision thresholds and lead names of one model to `filename`.

    The two values are stored with joblib as a dict with the keys
    'threshold' and 'leads'.
    """
    joblib.dump({'threshold': threshold, 'leads': leads}, filename, protocol=0)

def load_model(model_directory,leads):
    """Rebuild the trained model for `leads` from the files in `model_directory`.

    Returns the dict stored by `save_model_data` (keys 'threshold' and
    'leads') extended with the key 'model', which holds the network with its
    saved weights loaded.
    """
    lead_count = len(leads)
    weights_path = os.path.join(model_directory, model_weights[lead_count])
    sav_path = os.path.join(model_directory, model_sav[lead_count])

    # Recreate the architecture first, then restore the trained weights into it.
    network = build_model(class_count=class_count,lead_count=lead_count)
    network.load_weights(weights_path)

    bundle = joblib.load(sav_path)
    bundle["model"] = network
    return bundle


################################################################################
#
# Running trained model functions
#
################################################################################

def run_model(model, header, recording):
    """Classify one recording with a model returned by `load_model`.

    Args:
        model: dict with the keys 'model' (the network), 'leads' and
            'threshold' (one threshold per output index).
        header: contents of the header file as a single string.
        recording: signal data of the recording, passed on to `Patient`.

    Returns:
        Tuple `(classes, labels, probabilities)`: the class codes (keys of
        `y_map`), an integer 0/1 array and the predicted probabilities, the
        two arrays being in the same order as `classes`.
    """
    classes = y_map.keys()
    leads = model['leads']
    group_thresholds = model['threshold']
    # Expand to one threshold per class code; codes that share an output index
    # in `y_map` get the same threshold.
    threshold = np.array([group_thresholds[y_map[i]] for i in classes])
    m = model['model']
    # Load data.
    p = Patient(header.split('\n'),recording,leads,training=False)
    # Decimate to roughly `new_freq` by keeping every dr-th sample. The step is
    # clamped to 1 because a sampling rate below new_freq/2 would round to 0
    # and make the slice fail.
    dr = max(1, int(round(p.fs/new_freq)))
    p.ecg = p.ecg[:,::dr]  # presumably shape (leads, T) - confirm against Patient
    p.ecg = p.ecg.T  # time axis first
    # crop signals if too long
    max_length = int(120*new_freq)  # 120s
    if p.ecg.shape[0]>max_length:
        p.ecg = p.ecg[:max_length,:]
    x = p.ecg[np.newaxis,...]  # add a batch axis of size 1
    # Predict labels and probabilities.
    y = m.predict(x)[0]
    probabilities = np.array([y[y_map[i]] for i in classes])
    # `np.int` was removed in NumPy 1.24; the builtin `int` is the same dtype.
    labels = (probabilities>threshold).astype(int)

    return classes, labels, probabilities

################################################################################
#
# Other functions
#
################################################################################
