#!/usr/bin/env python

# Edit this script to add your team's training code.
# Some functions are *required*, but you can edit most parts of the required functions, remove non-required functions, and add your own functions.

################################################################################
#
# Imported functions and variables
#
################################################################################

# Import functions. These functions are not required. You can change or remove them.
from helper_code import *
import numpy as np, os, sys, joblib
from sklearn.ensemble import RandomForestClassifier

import pandas as pd
from scipy import signal
import autosklearn.classification
import autosklearn.metrics
from evaluate_model import load_weights, compute_challenge_metric

# Define the Challenge lead sets. These variables are not required. You can change or remove them.
# Each tuple lists lead names; the reduced sets only use names that also appear in twelve_leads.
twelve_leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')
six_leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF')
four_leads = ('I', 'II', 'III', 'V2')
three_leads = ('I', 'II', 'V2')
two_leads = ('I', 'II')
# All lead sets, from largest to smallest; training_code trains one model per entry.
lead_sets = (twelve_leads, six_leads, four_leads, three_leads, two_leads)

################################################################################
#
# Training model function
#
################################################################################

def c_metric(solution, prediction):
    """Score predictions against the reference labels with the Challenge metric.

    The class names and the weight matrix are read from 'weights.csv'
    (resolved against the current working directory) on every call and
    handed to ``compute_challenge_metric`` together with the code that is
    treated as sinus rhythm.

    Args:
        solution: Reference labels, passed through unchanged.
        prediction: Predicted labels, passed through unchanged.

    Returns:
        Whatever ``compute_challenge_metric`` returns for these inputs.
    """
    # Code passed to the metric as the sinus-rhythm class.
    normal_rhythm = {'426783006'}

    scored_classes, weight_matrix = load_weights('weights.csv')

    return compute_challenge_metric(
        weight_matrix, solution, prediction, scored_classes, normal_rhythm
    )

# Train your model. This function is *required*. Do *not* change the arguments of this function.
def training_code(data_directory, model_directory):
    """Train and save one auto-sklearn classifier per Challenge lead set.

    Features are extracted once for all twelve leads; each lead set then
    trains on the rows belonging to its own leads, flattened to one vector
    per recording.

    Args:
        data_directory: Folder searched by ``find_challenge_files`` for the
            header and recording files.
        model_directory: Folder the trained models are written to. It is
            created, including missing parent folders, if it does not exist.

    Raises:
        Exception: If no recordings are found in ``data_directory``.
    """
    # Find header and recording files.
    print('Finding header and recording files...')

    header_files, recording_files = find_challenge_files(data_directory)
    num_recordings = len(recording_files)

    if not num_recordings:
        raise Exception('No data was provided.')

    # Create a folder for the model if it does not already exist.
    # makedirs also creates missing parents and tolerates an existing folder.
    os.makedirs(model_directory, exist_ok=True)

    # Use only the scored classes. The last three codes are the secondary
    # members of the equivalent pairs below; their columns are dropped again
    # once the labels have been filled in.
    classes = ["270492004", "164889003", "164890007", "426627000", "713427006", "713426002", "445118002", "39732003", "164909002", "251146004",
               "698252002", "10370003", "284470004", "427172004", "164947007", "111975006", "164917005", "47665007", "427393009",
               "426177001", "426783006", "427084000", "164934002", "59931005", "59118001", "63593006", "17338001"]

    equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]
    num_equivalent = len(equivalent_classes)

    num_classes = len(classes)

    # Constant-time lookup of a class column instead of repeated list.index scans.
    class_index = {label: j for j, label in enumerate(classes)}

    # Extract the features and labels from the dataset.
    print('Extracting features and labels...')

    data = np.zeros((num_recordings, 12, 625), dtype=np.float32) # features: 5s per 12 leads at 125Hz
    # Builtin bool: the np.bool alias was removed from NumPy and raises AttributeError.
    labels = np.zeros((num_recordings, num_classes), dtype=bool) # Multi-hot encoding of classes

    for i in range(num_recordings):
        print('    {}/{}...'.format(i+1, num_recordings))

        # Load header and recording.
        header = load_header(header_files[i])
        recording = load_recording(recording_files[i])

        # Get the pre-processed twelve-lead signal for this recording.
        data[i, :] = get_features(header, recording, twelve_leads)

        for label in get_labels(header):
            if label in class_index:
                labels[i, class_index[label]] = True
            # A label from an equivalent pair switches on both members, so the
            # primary code is set even when only the secondary one is present.
            for group in equivalent_classes:
                if label in group:
                    for member in group:
                        labels[i, class_index[member]] = True

    # Drop the columns of the secondary equivalent codes.
    labels = labels[:, :num_classes - num_equivalent]

    # redefine new classes
    classes = classes[:-num_equivalent]

    # The scorer does not depend on the lead set, so build it once.
    challenge_scorer = autosklearn.metrics.make_scorer(
        name="challenge_metric",
        score_func=c_metric,
        optimum=1,
        greater_is_better=True,
        needs_proba=False,
        needs_threshold=False,
    )

    # Train a model for each lead set.
    for leads in lead_sets:
        print('Training model for {}-lead set: {}...'.format(len(leads), ', '.join(leads)))

        # Select the rows of this lead set and flatten each recording to one vector.
        feature_indices = [twelve_leads.index(lead) for lead in leads]
        features = data[:, feature_indices, :].reshape(num_recordings, -1)

        # Time budget in seconds: 2.5 hours per lead in the set.
        exec_time = int(len(leads) * 2.5 * 3600)

        # Train the model.
        classifier = autosklearn.classification.AutoSklearnClassifier(
            metric=challenge_scorer,
            include_estimators=["mlp"],
            include_preprocessors=["pca", "kernel_pca"],
            memory_limit=28000,
            ensemble_size=1,
            time_left_for_this_task=exec_time,
        ).fit(features, labels)
        print(classifier.sprint_statistics())
        print(classifier.show_models())

        # Save the model.
        save_model(model_directory, leads, classes, classifier)

################################################################################
#
# Running trained model function
#
################################################################################

# Run your trained model. This function is *required*. Do *not* change the arguments of this function.
def run_model(model, header, recording):
    """Classify one recording with a previously trained model.

    Args:
        model: Mapping with the keys 'classes', 'leads' and 'classifier', as
            written by ``save_model``.
        header: Header of the recording, passed to ``get_features``.
        recording: Signal data of the recording, passed to ``get_features``.

    Returns:
        A tuple ``(classes, labels, probabilities)``: the model's class list,
        the integer predictions for this recording, and the first row of the
        classifier's ``predict_proba`` output.
    """
    classes = model['classes']
    leads = model['leads']
    classifier = model['classifier']

    # Load features. They are cast to float32 because training_code stores
    # its training features in a float32 array.
    data = np.asarray(get_features(header, recording, leads), dtype=np.float32)

    # One flattened feature vector, shaped as a single-sample batch.
    features = data.reshape(1, -1)

    # Predict labels and probabilities.
    labels = classifier.predict(features)
    # Builtin int: the np.int alias was removed from NumPy and raises AttributeError.
    labels = np.asarray(labels, dtype=int)[0]

    probabilities = classifier.predict_proba(features)[0]

    return classes, labels, probabilities

################################################################################
#
# File I/O functions
#
################################################################################

# Save a trained model. This function is not required. You can change or remove it.
def save_model(model_directory, leads, classes, classifier):
    """Write the classifier with its leads and classes to the model folder.

    The three values are bundled in a dict under the keys 'leads', 'classes'
    and 'classifier' and dumped with joblib to the file name that
    ``get_model_filename`` derives from ``leads``.
    """
    target_path = os.path.join(model_directory, get_model_filename(leads))
    bundle = dict(leads=leads, classes=classes, classifier=classifier)
    joblib.dump(bundle, target_path, protocol=0)

# Load a trained model. This function is *required*. Do *not* change the arguments of this function.
def load_model(model_directory, leads):
    """Return the object stored by joblib for the given lead set.

    The file is looked up in ``model_directory`` under the name that
    ``get_model_filename`` derives from ``leads``.
    """
    model_file = get_model_filename(leads)
    model_path = os.path.join(model_directory, model_file)
    return joblib.load(model_path)

# Define the filename(s) for the trained models. This function is not required. You can change or remove it.
def get_model_filename(leads):
    """Build the model file name: 'model_', the sorted leads joined by '-', then '.sav'."""
    lead_part = '-'.join(sort_leads(leads))
    return 'model_' + lead_part + '.sav'

################################################################################
#
# Feature extraction function
#
################################################################################

# Extract features from the header and recording. This function is not required. You can change or remove it.
def get_features(header, recording, leads):
    """Return the first 5 seconds of the selected leads, resampled and filtered.

    Steps: select the leads, remove baseline and ADC gain, resample towards
    125 Hz, keep the first 5 seconds, apply a 5-sample median filter and a
    0.67-30 Hz band-pass, then zero-pad short signals.

    Args:
        header: Header of the recording; gains, baselines, sampling frequency
            and sample count are read from it.
        recording: Signal data of the recording.
        leads: Names of the leads to extract, in output row order.

    Returns:
        A 2-D array with 625 columns (5 s at 125 Hz).
        NOTE(review): presumably one row per entry of ``leads``; this relies
        on ``choose_leads`` - confirm.
    """
    # Reorder/reselect leads in recordings.
    recording = choose_leads(recording, header, leads)

    # Pre-process recordings: remove the baseline and scale by the ADC gain.
    adc_gains = get_adc_gains(header, leads)
    baselines = get_baselines(header, leads)
    num_leads = len(leads)
    for i in range(num_leads):
        recording[i, :] = (recording[i, :] - baselines[i]) / adc_gains[i]

    fs = get_frequency(header)
    length_in_samples = int(get_num_samples(header))
    length_in_seconds = int(length_in_samples / fs)
    target_fs = 125.0
    # Number of samples in the 5-second output window.
    target_num_samples = int(target_fs) * 5

    if fs > target_fs:
        # Downsample by the integer ratio of the two rates (fractional part is dropped).
        recording = signal.decimate(recording, int(fs / target_fs))
    elif fs < target_fs:
        # Upsample by linear interpolation. interp1d is imported here because
        # this file does not import it at module level.
        from scipy.interpolate import interp1d

        measured_time = np.linspace(0, length_in_samples, length_in_samples)
        linear_interp = interp1d(measured_time, recording)
        target_len_samples = int(length_in_seconds * target_fs)
        interpolation_time = np.linspace(0, length_in_samples, target_len_samples)
        recording = linear_interp(interpolation_time)

    # cutting to 5 seconds
    recording = recording[:, :target_num_samples]

    # median filtering along time only (kernel of 1 across leads)
    recording = signal.medfilt(recording, kernel_size=(1, 5))

    # band-pass filtering (3rd-order Butterworth, 0.67-30 Hz, zero-phase)
    b, a = signal.butter(3, 2 * np.array([0.67, 30]) / target_fs, btype='bandpass')
    recording = signal.filtfilt(b, a, recording)

    # zero padding if needed
    if recording.shape[1] != target_num_samples:
        shape = np.shape(recording)
        padded_array = np.zeros((len(leads), target_num_samples))
        padded_array[:shape[0], :shape[1]] = recording
        recording = padded_array
    assert recording.shape[1] == target_num_samples
    return recording
