#!/usr/bin/env python

# Edit this script to add your team's training code.
# Some functions are *required*, but you can edit most parts of required functions, remove non-required functions, and add your own functions.
import random
from helper_code import *
import numpy as np, os, sys, joblib
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestClassifier
import physionet_challenge_utility_script as pc
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
from tensorflow.keras.preprocessing.sequence import pad_sequences
import pandas as pd
import scipy

# Legacy fixed filenames for per-lead-set models. get_model_filename() builds
# filenames from the lead names instead; these constants are not referenced
# anywhere else in this file.
twelve_lead_model_filename = '12_lead_model.sav'
six_lead_model_filename = '6_lead_model.sav'
three_lead_model_filename = '3_lead_model.sav'
two_lead_model_filename = '2_lead_model.sav'

# Challenge lead sets. These are not required by the Challenge API and may be
# changed or removed; training_code() trains one model per entry of `lead_sets`.
twelve_leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')
six_leads = twelve_leads[:6]    # I, II, III, aVR, aVL, aVF
four_leads = ('I', 'II', 'III', 'V2')
three_leads = ('I', 'II', 'V2')
two_leads = twelve_leads[:2]    # I, II
lead_sets = (twelve_leads, six_leads, four_leads, three_leads, two_leads)


################################################################################
#
# Training function
#
################################################################################

# Train your model. This function is *required*. Do *not* change the arguments of this function.
def training_code(data_directory, model_directory):
    """Train, tune and save one classifier per lead set in ``lead_sets``.

    For each lead set a model built by ``pc.shapelet_encoder_model`` is fitted
    on the first split returned by ``pc.split_data``. The weights with the best
    ``val_AUC`` are checkpointed to the ``.h5`` file that ``load_model`` reads.
    Per-class decision thresholds are then tuned on the validation split using
    those same weights, and the remaining metadata is written by ``save_model``.

    Args:
        data_directory: dataset location, passed to ``pc.import_key_data``.
        model_directory: output folder; created (including parents) if missing.

    Raises:
        Exception: if no recordings are found.
    """
    # Find header and recording files.
    print('Finding header and recording files...')

    gender, age, labels, ecg_filenames = pc.import_key_data(data_directory)
    ecg_filenames = np.asarray(ecg_filenames)

    num_recordings = len(ecg_filenames)
    if not num_recordings:
        raise Exception('No data was provided.')

    # Create a folder for the model if it does not already exist.
    os.makedirs(model_directory, exist_ok=True)

    # Extract classes from dataset.
    print('Extracting classes...')

    # All diagnoses are encoded with SNOMED-CT codes; these CSV files decode them.
    # NOTE(review): SNOMED_scored is loaded but not used below.
    SNOMED_scored = pd.read_csv("dx_mapping_scored.csv", sep=",")
    SNOMED_unscored = pd.read_csv("dx_mapping_unscored.csv", sep=",")
    df_labels = pc.make_undefined_class(labels, SNOMED_unscored)

    y, snomed_classes = pc.onehot_encode(df_labels)
    classes = snomed_classes
    num_classes = len(classes)
    print(num_classes)

    # Only the first train/validation split is used.
    y_all_comb = pc.get_labels_for_all_combinations(y)
    folds = pc.split_data(labels, y_all_comb)
    order_array_train = folds[0][0]
    order_array_valid = folds[0][1]

    # Weights for class imbalance, keyed by class index. Key on the real number
    # of classes: a hard-coded 30 makes zip() silently drop the weight of every
    # class beyond the 30th.
    new_weights = pc.calculating_class_weights(y)
    weight_dictionary = dict(zip(np.arange(num_classes), new_weights.T[1]))

    # Model hyper-parameters.
    batch_size = 20
    n_epochs = 20
    lr = 0.001
    sh_len = 0.05
    init = 'matrix_profile'
    ecg_freq = 500
    ecg_time_len = 10
    ecg_signal_len = ecg_freq * ecg_time_len
    norm_type = 'adc'  # adc, minmax

    # Keras expects integer step counts. Rounding up keeps the final, partial
    # batch; the validation dataset is not repeated, so it yields exactly this
    # many batches.
    steps_per_epoch = int(np.ceil(len(order_array_train) / batch_size))
    validation_steps = int(np.ceil(len(order_array_valid) / batch_size))

    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_AUC', factor=0.1, patience=1, verbose=1, mode='max',
        min_delta=0.0001, cooldown=0, min_lr=0
    )

    early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_AUC', mode='max', verbose=1, patience=2)

    def make_dataset(order_array, leads):
        # Stream (signal, label) pairs for the recordings indexed by order_array.
        return tf.data.Dataset.from_generator(
            generate_xy_shuffle_physio,
            args=[ecg_filenames, y, order_array, leads, norm_type, ecg_freq, ecg_time_len],
            output_types=(tf.float32, tf.float32),
            output_shapes=(tf.TensorShape((None, None)), tf.TensorShape((None,))))

    # Train a model for each lead set.
    for leads in lead_sets:
        print('Training model for {}-lead set: {}...'.format(len(leads), ', '.join(leads)))

        filename = os.path.join(model_directory, get_model_filename(leads))
        # Same derivation as in load_model(), which reads these weights back.
        weights_filename = filename.replace('.sav', '.h5')
        ckpt = tf.keras.callbacks.ModelCheckpoint(weights_filename,
                                                  save_best_only=True,
                                                  save_weights_only=True,
                                                  monitor='val_AUC', mode='max')

        train_data = make_dataset(order_array_train, leads)
        valid_data = make_dataset(order_array_valid, leads)

        imputer = None  # no imputation; stored only to keep the model-dict layout
        classifier = pc.shapelet_encoder_model(input_shape=(ecg_signal_len, len(leads)), data=ecg_filenames,
                                               df_labels=df_labels, leads=leads, init=init, lr=lr,
                                               classes=num_classes, sh_len=sh_len, freq=ecg_freq, norm_type=norm_type)
        train_data_ds = train_data.repeat().batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
        valid_data_ds = valid_data.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)  # do not repeat
        classifier.fit(x=train_data_ds,
                       epochs=n_epochs,
                       steps_per_epoch=steps_per_epoch,
                       validation_data=valid_data_ds,
                       validation_steps=validation_steps,
                       validation_freq=1,
                       class_weight=weight_dictionary,
                       callbacks=[reduce_lr, early_stop, ckpt])

        # load_model() restores the best checkpoint, so tune the thresholds on
        # those weights rather than on whatever weights fit() finished with.
        if os.path.exists(weights_filename):
            classifier.load_weights(weights_filename)

        # Build validation inputs with the same preprocessing settings as training.
        x_valid = generate_validation_data(ecg_filenames, y, order_array_valid, leads,
                                           norm_type=norm_type, ecg_freq=ecg_freq,
                                           ecg_time_len=ecg_time_len)[0]
        y_pred = classifier.predict(x=x_valid)
        thres = get_threshold(y_pred, ecg_filenames, y, order_array_valid, leads, num_classes, norm_type)
        save_model(model_directory, leads, classes, imputer, thres, lr, sh_len, norm_type, ecg_freq, ecg_time_len)


################################################################################
#
# File I/O functions
#
################################################################################

# Save a trained model. This function is not required. You can change or remove it.
def save_model(model_directory, leads, classes, imputer, threshold, lr, sh_len, norm_type, ecg_freq, ecg_time_len):
    """Write the metadata of a trained model with joblib.

    The dict stored here is what ``load_model`` reads back. The network weights
    are not part of it; they live in the matching ``.h5`` checkpoint.

    Args:
        model_directory: folder that receives the file.
        leads: lead names the model was trained on (also determines the filename).
        classes: class labels, in model-output order.
        imputer: imputer object or None.
        threshold: per-class decision thresholds.
        lr, sh_len, norm_type, ecg_freq, ecg_time_len: settings needed to
            rebuild the network and preprocess inputs at inference time.
    """
    payload = dict(
        leads=leads,
        classes=classes,
        imputer=imputer,
        threshold=threshold,
        lr=lr,
        sh_len=sh_len,
        norm_type=norm_type,
        ecg_freq=ecg_freq,
        ecg_time_len=ecg_time_len,
    )
    target_path = os.path.join(model_directory, get_model_filename(leads))
    joblib.dump(payload, target_path, protocol=0)


# Load a trained model. This function is *required*. You should edit this function to add your code, but do *not* change the arguments of this function.
def load_model(model_directory, leads):
    """Rebuild the trained model for `leads` stored in `model_directory`.

    Reads the metadata dict written by ``save_model``, re-creates the network
    with the stored hyper-parameters, and loads its weights from the ``.h5``
    file that sits next to the ``.sav`` file.

    Returns:
        The metadata dict, with the Keras model added under ``'classifier'``.
    """
    sav_path = os.path.join(model_directory, get_model_filename(leads))
    model = joblib.load(sav_path)

    signal_len = model['ecg_freq'] * model['ecg_time_len']
    # No training data is passed here ('ones' init, data/df_labels None);
    # presumably the initial values do not matter because load_weights()
    # overwrites them right after construction.
    classifier = pc.shapelet_encoder_model(
        input_shape=(signal_len, len(leads)),
        data=None,
        df_labels=None,
        leads=leads,
        init='ones',
        lr=model['lr'],
        classes=len(model['classes']),
        sh_len=model['sh_len'],
        freq=model['ecg_freq'],
        norm_type=model['norm_type'],
    )
    classifier.load_weights(sav_path.replace('.sav', '.h5'))

    model['classifier'] = classifier
    return model


# Define the filename(s) for the trained models. This function is not required. You can change or remove it.
def get_model_filename(leads):
    """Return 'model_' + the lead names ordered by sort_leads() and joined with '-' + '.sav'."""
    joined = '-'.join(sort_leads(leads))
    return 'model_' + joined + '.sav'


################################################################################
#
# Running trained model function
#
################################################################################

# Run your trained model. This function is *required*. You should edit this function to add your code, but do *not* change the arguments of this function.
def run_model(model, header, recording):
    """Classify one recording with a model dict returned by ``load_model``.

    Args:
        model: metadata dict from ``load_model`` (including ``'classifier'``).
        header: header text of the recording.
        recording: raw signal array indexed as ``recording[lead, sample]``.

    Returns:
        (classes, labels, probabilities). ``probabilities`` are the network
        outputs rounded to 3 decimals. ``labels`` is 1 where the rounded
        probability exceeds the stored per-class threshold and 0 otherwise.
    """
    classes = model['classes']
    leads = model['leads']
    imputer = model['imputer']  # not used; kept from the Challenge template
    thresholds = model['threshold']
    classifier = model['classifier']
    norm_type = model['norm_type']
    freq = model['ecg_freq']
    duration = model['ecg_time_len']
    n_samples = freq * duration
    n_leads = len(leads)

    # Crop to the first n_samples raw samples (before any resampling), matching
    # the crop used by generate_validation_data().
    signal = get_processed_ecg_record(header, recording[:, :n_samples], leads, freq, duration,
                                      norm_type=norm_type)
    signal = pad_sequences(signal, maxlen=n_samples, truncating='post', padding="post", dtype='float64')
    # NOTE(review): pad_sequences returns (n_leads, n_samples); reshape()
    # reinterprets that buffer instead of transposing it, so rows below are not
    # time steps. Training and validation use the same reshape, so any change
    # (e.g. to .T) must be made everywhere at once and the model retrained.
    batch = signal.reshape((1, n_samples, n_leads))

    raw_scores = classifier.predict(batch)[0]
    probabilities = np.round(raw_scores.astype(np.float32), 3)
    labels = (probabilities > thresholds) * 1

    return classes, labels, probabilities


################################################################################
#
# Other functions
#
################################################################################

# Extract features from the header and recording.
def get_features(header, recording, leads):
    """Extract age, sex and per-lead RMS features from one recording.

    Args:
        header: header text of the recording.
        recording: raw signal indexed as ``recording[lead, sample]`` in header
            lead order.
        leads: lead names to use, in output order.

    Returns:
        (age, sex, rms). ``age`` comes from the header, or NaN if absent. ``sex``
        is encoded as 0 (female), 1 (male) or NaN. ``rms`` is a float32 array
        holding the root mean square of each selected lead after baseline and
        gain correction.

    Raises:
        ValueError: if a requested lead is not listed in the header.
    """
    # Extract age.
    age = get_age(header)
    if age is None:
        age = float('nan')

    # Extract sex. Encode as 0 for female, 1 for male, and NaN for other.
    sex = get_sex(header)
    if sex in ('Female', 'female', 'F', 'f'):
        sex = 0
    elif sex in ('Male', 'male', 'M', 'm'):
        sex = 1
    else:
        sex = float('nan')

    # Reorder/reselect leads in recordings (fancy indexing returns a copy).
    available_leads = get_leads(header)
    indices = [available_leads.index(lead) for lead in leads]
    recording = recording[indices, :]
    # The per-lead update below assigns into the array in place; with an
    # integer recording the float result would be silently truncated.
    if not np.issubdtype(recording.dtype, np.floating):
        recording = recording.astype(np.float64)

    # Pre-process recordings: convert to physical units.
    adc_gains = get_adc_gains(header, leads)
    baselines = get_baselines(header, leads)
    num_leads = len(leads)
    for i in range(num_leads):
        recording[i, :] = (recording[i, :] - baselines[i]) / adc_gains[i]

    # Compute the root mean square of each ECG lead signal.
    rms = np.zeros(num_leads, dtype=np.float32)
    for i in range(num_leads):
        x = recording[i, :]
        rms[i] = np.sqrt(np.sum(x ** 2) / np.size(x))

    return age, sex, rms


def get_processed_ecg_record(header, recording, leads, out_freq=500, out_time=10, norm_type='adc'):
    """Select, normalise, resample and band-pass filter the requested leads.

    Args:
        header: header text; provides lead names, sampling frequency, ADC
            gains and baselines.
        recording: raw signal indexed as ``recording[lead, sample]`` in header
            lead order.
        leads: lead names to keep, in output order. Entries may be ``bytes``
            (e.g. when they arrive through tf.data ``from_generator`` args).
        out_freq: target sampling frequency in Hz.
        out_time: maximum output duration in seconds.
        norm_type: 'adc' converts each lead with the header gain and baseline;
            'minmax' rescales each lead to roughly [-1, 1] with ``ab_norm``.
            Any other value skips normalisation. ``bytes`` is also accepted.

    Returns:
        numpy array with one row per requested lead, each at most
        ``out_freq * out_time`` samples long.
    """
    def as_text(value):
        # tf.data.Dataset.from_generator hands string arguments over as bytes
        # (possibly wrapped in a 0-d array). Comparing b'adc' == 'adc' is
        # False, which silently disabled normalisation during training.
        if isinstance(value, np.ndarray) and value.ndim == 0:
            value = value.item()
        return value.decode('utf-8') if isinstance(value, bytes) else value

    # Use a plain list of str so the header lookups below (which compare with
    # str names and call .index()) also work for bytes/ndarray input.
    leads = [as_text(lead) for lead in leads]
    norm_type = as_text(norm_type)

    # Reorder/reselect leads in recordings.
    available_leads = get_leads(header)
    indices = [available_leads.index(lead) for lead in leads]
    recording = recording[indices, :]

    # Replace missing samples before resampling/filtering, which would
    # otherwise spread NaNs across the whole lead.
    recording = np.nan_to_num(recording, nan=0.1)

    if norm_type == 'adc':
        adc_gains = get_adc_gains(header, leads)
        baselines = get_baselines(header, leads)

    freq = get_frequency(header)
    duration = recording.shape[1] / freq
    processed = []
    for i in range(len(leads)):
        lead_signal = recording[i, :]
        if norm_type == 'adc':
            # Convert to physical units *before* filtering: subtracting the
            # baseline after the band-pass would re-introduce a constant offset
            # of -baseline/gain that the filter had just removed.
            lead_signal = (lead_signal - baselines[i]) / adc_gains[i]
        lead_signal = preprocess_recording(lead_signal, time=duration, minf=3, maxf=45, maxhz=out_freq,
                                           T=out_time, resample=freq != out_freq)
        if norm_type == 'minmax':
            lead_signal = ab_norm(lead_signal, -1, 1)
        processed.append(lead_signal)

    return np.array(processed)


def ab_norm(record, a, b):
    """Linearly rescale `record` from its own [min, max] range to roughly [a, b].

    A small epsilon is added to numerator and denominator so that a constant
    record does not divide by zero; such a record maps to `b`. The minimum and
    maximum are taken over the whole array, so inputs of any shape are accepted.

    Args:
        record: numpy array to rescale.
        a: target lower bound.
        b: target upper bound.

    Returns:
        numpy array with the same shape as `record`.
    """
    eps = 10 ** -7
    # np.min/np.max run in C; builtin min/max iterate element by element in
    # Python and raise on arrays with more than one dimension.
    lo = np.min(record)
    hi = np.max(record)
    return ((b - a) * (record - lo + eps) / (hi - lo + eps)) + a


def bandpassFilter(signal, minf, maxf, order=2, fs=500):
    """Apply a zero-phase Butterworth band-pass filter along axis 0.

    Args:
        signal: samples to filter.
        minf: lower pass-band edge in Hz.
        maxf: upper pass-band edge in Hz.
        order: Butterworth order passed to ``scipy.signal.butter``.
        fs: sampling frequency of `signal` in Hz. Defaults to 500, the value
            that used to be hard-coded here.

    Returns:
        Filtered signal with the same shape as `signal`.
    """
    # `import scipy` alone does not load submodules on older SciPy releases.
    import scipy.signal

    nyq = 0.5 * fs
    low = minf / nyq
    high = maxf / nyq
    b, a = scipy.signal.butter(order, [low, high], 'bandpass', analog=False)
    return scipy.signal.filtfilt(b, a, signal, axis=0)


def preprocess_recording(recording, time, minf, maxf, maxhz, T, resample):
    """Resample one lead to `maxhz`, band-pass filter it and crop it to `T` seconds.

    Args:
        recording: 1-D signal of a single lead.
        time: duration of `recording` in seconds.
        minf: lower band-pass edge in Hz.
        maxf: upper band-pass edge in Hz.
        maxhz: output sampling frequency in Hz.
        T: maximum output duration in seconds.
        resample: accepted for API compatibility but ignored; the signal is
            always passed through ``scipy.signal.resample``.

    Returns:
        Filtered signal with at most ``T * maxhz`` samples.
    """
    import scipy.signal

    # round() rather than int(): a float product such as maxhz * (n / freq) can
    # land just below an integer, and truncation would then drop a sample.
    n_out = int(round(maxhz * time))
    y = scipy.signal.resample(recording, n_out)
    # Filter at the rate the signal now has, not an assumed 500 Hz.
    y = bandpassFilter(y, minf, maxf, fs=maxhz)
    return y[:T * maxhz]


def generate_xy_shuffle_physio(X_train, y_train, order_array, leads, norm_type, ecg_freq=500, ecg_time_len=10):
    """Yield (signal, label) training pairs in random order.

    Used as the generator for ``tf.data.Dataset.from_generator``. Each recording
    is loaded and, if longer, cropped to a random window of
    ``ecg_freq * ecg_time_len`` raw samples. It is then preprocessed with
    ``get_processed_ecg_record`` and zero-padded to that length.

    Args:
        X_train: recording file names, indexed by the entries of `order_array`.
        y_train: label rows, indexed the same way.
        order_array: indices of the recordings to yield. The caller's array is
            not modified.
        leads: lead names; ``bytes`` entries are decoded to ``str``.
        norm_type: normalisation mode for ``get_processed_ecg_record``; ``bytes``
            is decoded to ``str``.
        ecg_freq: target sampling frequency in Hz.
        ecg_time_len: target duration in seconds.

    Yields:
        (x, y): x has shape (ecg_freq * ecg_time_len, len(leads)); y is
        ``y_train[i]``.
    """
    def as_text(value):
        # from_generator passes string args as bytes (possibly as 0-d arrays).
        # Left undecoded, norm_type == 'adc' is False and training silently
        # skips the normalisation that validation and inference apply.
        if isinstance(value, np.ndarray) and value.ndim == 0:
            value = value.item()
        return value.decode('utf-8') if isinstance(value, bytes) else value

    leads = [as_text(lead) for lead in leads]
    norm_type = as_text(norm_type)
    ecg_signal_len = ecg_freq * ecg_time_len

    # Shuffle a copy so a caller's index array is left untouched.
    order = np.random.permutation(order_array)

    for i in order:
        data, header_data = pc.load_challenge_data(X_train[i], read_lines=False)
        if data.shape[1] > ecg_signal_len:
            data = pc.random_window(data, ecg_signal_len)
        x = get_processed_ecg_record(header_data, data, leads, ecg_freq, ecg_time_len, norm_type=norm_type)
        x = pad_sequences(x, maxlen=ecg_signal_len, truncating='post', padding="post", dtype='float64')
        # NOTE(review): x is (n_leads, n_samples) here; reshape() does not
        # transpose it. run_model and generate_validation_data do the same, so
        # any change must be applied to all three together.
        x = x.reshape(ecg_signal_len, len(leads))

        yield x, y_train[i]


def generate_validation_data(ecg_filenames, y, test_order_array, leads, norm_type='adc', ecg_freq=500, ecg_time_len=10):
    """Load and preprocess a fixed (non-random) evaluation set.

    Unlike the training generator, every recording is cropped to its first
    ``ecg_freq * ecg_time_len`` raw samples rather than a random window.

    Args:
        ecg_filenames: numpy array of recording file names.
        y: label array, indexed like `ecg_filenames`.
        test_order_array: indices of the recordings to load.
        leads: lead names to keep.
        norm_type: normalisation mode for ``get_processed_ecg_record``.
        ecg_freq: target sampling frequency in Hz.
        ecg_time_len: target duration in seconds.

    Returns:
        (X, labels): X has shape
        (len(test_order_array), ecg_freq * ecg_time_len, len(leads)), and
        labels is ``y[test_order_array]``.
    """
    n_samples = ecg_freq * ecg_time_len
    selected_labels = y[test_order_array]
    selected_files = ecg_filenames[test_order_array]

    def load_one(name):
        # Load, crop, preprocess and pad a single recording.
        signal, header = pc.load_challenge_data(name, read_lines=False)
        signal = get_processed_ecg_record(header, signal[:, :n_samples], leads, ecg_freq, ecg_time_len,
                                          norm_type=norm_type)
        return pad_sequences(signal, maxlen=n_samples, truncating='post', padding="post", dtype='float64')

    stacked = np.asarray([load_one(name) for name in selected_files])
    # NOTE(review): each item is (n_leads, n_samples); reshape() does not
    # transpose it. This matches the training generator and run_model.
    inputs = stacked.reshape((selected_files.shape[0], n_samples, len(leads)))

    return inputs, selected_labels


def get_threshold(y_pred, ecg_filenames, y, test_order_array, leads, num_classes, norm_type='adc'):
    """Tune per-class decision thresholds on a validation split.

    Starts every class from the best global threshold found via
    ``pc.iterate_threshold``, then refines the threshold vector with
    ``scipy.optimize.fmin`` (Nelder-Mead) to maximise
    ``pc.compute_challenge_metric_for_opt``.

    Args:
        y_pred: model outputs for the validation recordings, in
            `test_order_array` order.
        ecg_filenames: recording file names (forwarded to pc.iterate_threshold).
        y: full label array.
        test_order_array: indices of the validation recordings.
        leads: lead names (forwarded to pc.iterate_threshold).
        num_classes: number of output classes.
        norm_type: normalisation mode (forwarded to pc.iterate_threshold).

    Returns:
        numpy array of `num_classes` thresholds.
    """
    import scipy.optimize  # `import scipy` alone does not load submodules on older SciPy

    def negative_challenge_score(thr, label, output_prob):
        # fmin minimises, so negate the Challenge metric.
        return -pc.compute_challenge_metric_for_opt(label, np.array(output_prob > thr))

    # NOTE(review): assumed to be the same grid that pc.iterate_threshold
    # scans, so that argmax() below indexes into it correctly; confirm in pc.
    init_thresholds = np.arange(0, 1, 0.05)
    all_scores = pc.iterate_threshold(y_pred, ecg_filenames, y, test_order_array, leads=leads, norm_type=norm_type)

    # Only the labels are needed, and they are exactly y[test_order_array]
    # (the second value generate_validation_data returns). Indexing directly
    # avoids re-loading and re-preprocessing every validation recording.
    y_valid = y[test_order_array]

    x0 = init_thresholds[all_scores.argmax()] * np.ones(num_classes)
    best_thr = scipy.optimize.fmin(negative_challenge_score, args=(y_valid, y_pred), x0=x0)

    return best_thr
