#!/usr/bin/env python

# Edit this script to add your team's training code.
# Some functions are *required*, but you can edit most parts of required functions, remove non-required functions, and add your own functions.
import os
import time
import datetime
import random
from random import shuffle
import warnings

import numpy as np
import pandas as pd
import pickle
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Conv1D, MaxPooling1D, BatchNormalization, Dropout, Dense, Flatten, AveragePooling1D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping

from helper_code import *
from create_model import *
from preprocess_features import *


# Names of the sub-folders (inside the model directory) that hold each trained model.
# load_model() rebuilds the folder name as str(len(leads)), so each name must stay
# equal to the number of leads of the model it stores.
twelve_lead_model_filename = '12'
six_lead_model_filename = '6'
four_lead_model_filename = '4'
three_lead_model_filename = '3'
two_lead_model_filename = '2'

# Seed the NumPy and TensorFlow random number generators.
np.random.seed(1234)
tf.random.set_seed(1234)

# REMOVE FOR SUBMISSION
# Sets CUDA_VISIBLE_DEVICES to "0", limits the TensorFlow logger to errors and
# silences every Python warning.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
tf.get_logger().setLevel('ERROR')
warnings.filterwarnings("ignore")

# Windowing parameters; the same values (5 and 1) are what training_code() and
# run_model() hand to process_recording().
# NOTE(review): units are not shown in this file (presumably seconds) - confirm in preprocess_features.
window_len = 5
overlap = 1

# Set the leads
# The position of a lead inside twelve_leads is the channel index that
# load_reduced_dataset() uses to select that lead, so the order matters.
twelve_leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')
six_leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF')
four_leads = ('I', 'II', 'III', 'V2')
three_leads = ('I', 'II', 'V2')
two_leads = ('I', 'II')
# All lead sets, from the most leads to the fewest.
lead_sets = (twelve_leads, six_leads, four_leads, three_leads, two_leads)


################################################################################
#
# Training function
#
################################################################################
def sort_leads_signals(leads, signals):
    """Reorder lead names and their signals into the standard 12-lead order.

    The standard order is I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5, V6.
    Leads that share the same name keep their relative input order (the sort
    is stable).

    Args:
        leads: Iterable of lead names.
        signals: Iterable with one entry per lead, in the same order as
            ``leads``.

    Returns:
        A ``(leads, signals)`` pair: ``leads`` is a tuple of names in standard
        order and ``signals`` is a NumPy array whose first axis follows that
        order. Empty input yields ``((), np.asarray([]))``.

    Raises:
        ValueError: If a lead name is not one of the twelve standard leads.
            Skipping such a lead silently would shift every following
            name/signal pairing.
    """
    standard_order = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF',
                      'V1', 'V2', 'V3', 'V4', 'V5', 'V6')
    rank = {name: position for position, name in enumerate(standard_order)}

    # Materialise once so that one-shot iterables can be validated and zipped.
    leads = list(leads)
    unknown = [lead for lead in leads if lead not in rank]
    if unknown:
        raise ValueError('Unknown lead name(s): {}'.format(unknown))

    # Sort on the rank only, so the signal payloads are never compared.
    paired = sorted(zip(leads, signals), key=lambda pair: rank[pair[0]])
    if not paired:
        return (), np.asarray([])

    sorted_leads, sorted_signals = zip(*paired)
    return sorted_leads, np.asarray(sorted_signals)

def _process_all_recordings(header_files, recording_files, window_len, overlap, classes, leads, with_features):
    """Run process_recording() on every recording and collect the usable results.

    Returns three lists ``(signals, labels, features)``; ``features`` stays
    empty unless ``with_features`` is true.
    """
    num_recordings = len(recording_files)
    signals, labels, features = [], [], []

    # Progress is reported once per batch of 5000 files.
    previous_time = time.time()
    total_batches = (num_recordings // 5000) + 1

    for i in range(num_recordings):
        if i % 5000 == 0 and i != 0:
            current_time = time.time()
            elapsed_time = current_time - previous_time
            total_batches = total_batches - 1

            # Estimate: duration of the last batch times the batches still to do.
            time_remaining = str(datetime.timedelta(seconds=elapsed_time * total_batches))
            previous_time = current_time
            print('Processed {} of {} files. Batches Remaining: {}, Est. time remaining: {}'.format(i, num_recordings, total_batches, time_remaining))

        header = load_header(header_files[i])
        recording = load_recording(recording_files[i])

        # In training mode process_recording() returns three values for the
        # twelve-lead set and two values for every other lead set.
        processed = process_recording(header, recording, window_len, overlap, classes, True, leads)
        if with_features:
            processed_signal, processed_label, processed_features = processed
        else:
            processed_signal, processed_label = processed

        # A str in the first slot marks a recording without scored classes; skip it.
        if type(processed_signal[0]) is str:
            continue

        signals.append(np.asarray(processed_signal, dtype='float32'))
        labels.append(processed_label)
        if with_features:
            features.append(processed_features)

    return signals, labels, features


def load_dataset(data_directory, window_len, overlap, classes, leads):
    """Load and preprocess every recording found in ``data_directory``.

    Args:
        data_directory: Directory searched by find_challenge_files().
        window_len: Window length forwarded to process_recording().
        overlap: Window overlap forwarded to process_recording().
        classes: Scored classes forwarded to process_recording().
        leads: Lead set to load. The twelve-lead set additionally yields
            hand-crafted features.

    Returns:
        For ``leads == twelve_leads``:
        ``(signals, labels, features, feature_scaler, pca)`` where ``features``
        have been standardised and reduced to 10 PCA components, and
        ``feature_scaler`` / ``pca`` are the fitted transformers.
        For every other lead set: ``(signals, labels)``.

    Raises:
        ValueError: If the directory holds no recordings, or none of the
            recordings produced a usable (scored) signal. Previously these
            cases returned ``None`` or empty lists and failed later with an
            unrelated error.
    """
    header_files, recording_files = find_challenge_files(data_directory)
    num_recordings = len(recording_files)
    print('Found ' + str(num_recordings) + ' files')

    if num_recordings == 0:
        raise ValueError('No recordings found in {!r}'.format(data_directory))

    start_time = time.time()

    # Feature extraction is only performed for the twelve-lead set, so that it
    # does not have to be repeated for the reduced lead sets.
    with_features = leads == twelve_leads
    signals, labels, features = _process_all_recordings(
        header_files, recording_files, window_len, overlap, classes, leads, with_features)

    if not signals:
        raise ValueError('None of the {} recordings in {!r} contained scored classes'.format(
            num_recordings, data_directory))

    # Stack the per-recording results into single arrays.
    signals = np.concatenate(signals)
    labels = np.concatenate(labels)

    if not with_features:
        print('Total processing time took {} seconds'.format(round(time.time() - start_time, 0)))
        return signals, labels

    features = np.concatenate(features)

    # Scale here because the features are not modified any further.
    feature_scaler = StandardScaler()
    features = feature_scaler.fit_transform(features)
    # PCA cannot handle NaNs: replace them with 0 in the scaled space.
    features[np.isnan(features)] = 0
    pca = PCA(n_components=10)
    features = pca.fit_transform(features)

    print('Total processing time took {} seconds'.format(round(time.time() - start_time, 0)))
    return signals, labels, features, feature_scaler, pca
    

def load_reduced_dataset(signals, leads):
    """Keep only the requested leads and standardise the resulting signals.

    Args:
        signals: 3-D array indexed on its last axis by position in
            ``twelve_leads``.
            NOTE(review): axes 0 and 1 are presumably (windows, samples) - confirm.
        leads: Lead names to keep; each must appear in ``twelve_leads``.

    Returns:
        ``(signals, signal_scaler)``: the scaled array with the same three
        dimensions as the selected input, and the fitted StandardScaler.
    """
    if leads != twelve_leads:
        # Translate lead names into channel positions of the 12-lead layout.
        channel_indices = []
        for lead in leads:
            channel_indices.append(twelve_leads.index(lead))
        signals = signals[:, :, channel_indices]

    dim0 = signals.shape[0]
    dim1 = signals.shape[1]
    dim2 = signals.shape[2]

    # A separate scaler is fitted per lead set because the flattened width
    # (dim1 * dim2) depends on how many leads were kept.
    signal_scaler = StandardScaler()

    # StandardScaler works on 2-D input: flatten axes 1 and 2, scale, restore.
    flattened = signal_scaler.fit_transform(signals.reshape(dim0, dim1 * dim2))
    return flattened.reshape(dim0, dim1, dim2), signal_scaler


def train_model(data_directory, classes, leads, signals, features, labels):
    """Train one wide-and-deep model for the given lead set.

    Args:
        data_directory: Not used by this function.
        classes: Not used by this function.
        leads: Lead set the model is trained for.
        signals: Signal array; reduced to ``leads`` and scaled before training.
        features: Feature array, indexed row-wise in step with ``signals``.
        labels: Label array, indexed row-wise in step with ``signals``.

    Returns:
        ``(model, signal_scaler)``: the fitted Keras model and the scaler that
        load_reduced_dataset() fitted on the signals.
    """
    started_at = time.time()

    signals, signal_scaler = load_reduced_dataset(signals, leads)

    # Split row indices rather than the arrays themselves to reduce memory usage.
    print('createing split idx')
    train_fraction = 0.85
    row_ids = list(range(signals.shape[0]))
    train_rows, val_rows = train_test_split(row_ids, train_size=train_fraction)

    print('get val')
    signals_val = signals[val_rows]
    features_val = features[val_rows]
    labels_val = labels[val_rows]

    # Rebinding the names lets the full-size arrays be released.
    print('get training')
    signals = signals[train_rows]
    features = features[train_rows]
    labels = labels[train_rows]

    model = create_model_wide_deep(signals, leads)
    # Stop on validation loss, using the callback's default patience.
    early_stopping = EarlyStopping(monitor='val_loss', mode='min', verbose=1)

    print('... starting training')
    training_inputs = {'signal_input': signals, 'feature_input': features}
    validation_inputs = {'signal_input': signals_val, 'feature_input': features_val}
    model.fit(
        training_inputs,
        labels,
        validation_data=(validation_inputs, labels_val),
        batch_size=64,
        epochs=30,
        verbose=1,
        callbacks=[early_stopping],
    )

    print(f"Total training took {round(time.time() - started_at, 0)} seconds")
    return model, signal_scaler

# Train your model. This function is *required*. Do *not* change the arguments of this function.
def training_code(data_directory, model_directory):
    """Train and save one model per lead set (12, 6, 4, 3 and 2 leads).

    Args:
        data_directory: Directory containing the training recordings.
        model_directory: Directory that receives one sub-folder per model;
            created (including missing parents) if it does not exist.

    The scored classes are read from ``dx_mapping_scored.csv`` in the current
    working directory.
    """
    # Get the scored classes, sorted numerically by SNOMED CT code.
    classes = list(pd.read_csv('dx_mapping_scored.csv')['SNOMEDCTCode'])
    classes = sorted(classes, key=int)

    # makedirs also creates missing parents and tolerates an existing directory.
    os.makedirs(model_directory, exist_ok=True)

    # The twelve-lead set must come first: it is the only one for which
    # load_dataset() returns the features, feature scaler and PCA, and those
    # are reused for all reduced lead sets.
    model_specs = (
        (twelve_leads, twelve_lead_model_filename),
        (six_leads, six_lead_model_filename),
        (four_leads, four_lead_model_filename),
        (three_leads, three_lead_model_filename),
        (two_leads, two_lead_model_filename),
    )

    features = feature_scaler = pca = None
    for leads, folder_name in model_specs:
        if leads == twelve_leads:
            signals, labels, features, feature_scaler, pca = load_dataset(
                data_directory, window_len, overlap, classes, leads)
        else:
            # NOTE(review): the 12-lead features are indexed with the rows of
            # this reload, which assumes load_dataset() yields the same windows
            # in the same order for every lead set - confirm.
            signals, labels = load_dataset(data_directory, window_len, overlap, classes, leads)

        filename = os.path.join(model_directory, folder_name)
        print('\n . . . training {} leads model'.format(len(leads)))
        classifier, signal_scaler = train_model(data_directory, classes, leads, signals, features, labels)
        save_model(filename, classes, leads, classifier, signal_scaler, feature_scaler, pca)


################################################################################
#
# File I/O functions
#
################################################################################

# Save your trained models.
def save_model(filename, classes, leads, classifier, signal_scaler, feature_scaler, pca):
    """Write a trained model and its preprocessing objects into one folder.

    Args:
        filename: Target folder. Receives ``model`` (written by
            ``classifier.save``), ``classes.csv``, ``leads.csv`` and
            ``scaler.pkl``.
        classes: Classes the model predicts; stored as a one-column CSV.
        leads: Leads the model was trained on; stored as a one-column CSV.
        classifier: Object exposing ``save(path)``.
        signal_scaler: Fitted signal scaler.
        feature_scaler: Fitted feature scaler.
        pca: Fitted PCA transformer.
    """
    # Create the folder explicitly instead of relying on classifier.save()
    # to do so as a side effect.
    os.makedirs(filename, exist_ok=True)

    classifier.save(filename + '/model')
    pd.DataFrame(classes).to_csv(filename + '/classes.csv', sep=',', index=False)
    pd.DataFrame(leads).to_csv(filename + '/leads.csv', sep=',', index=False)

    scaler = {'signal_scaler': signal_scaler, 'feature_scaler': feature_scaler, 'pca': pca}
    # The context manager guarantees the pickle is flushed and the handle closed.
    with open(filename + '/scaler.pkl', 'wb') as scaler_file:
        pickle.dump(scaler, scaler_file)



# Generic function for loading a model.
def load_model(filename, leads):
    """Load the model that was saved for the given number of leads.

    Args:
        filename: Model directory; the model is read from the sub-folder
            named after ``len(leads)``.
        leads: Leads available in the recordings to classify.

    Returns:
        Dict with the keys ``classes``, ``leads`` (after ``sort_leads``),
        ``classifier``, ``signal_scaler``, ``feature_scaler`` and ``pca``.
    """
    filename = filename + '/' + str(len(leads))

    # sort_leads is not defined in this file; presumably provided by helper_code.
    leads = sort_leads(leads)
    classes = pd.read_csv(filename + '/' + 'classes.csv')['0'].to_list()
    classifier = tf.keras.models.load_model(filename + '/' + 'model')

    # Read the pickle once and close the handle deterministically.
    # NOTE(security): unpickling can execute arbitrary code - only load model
    # folders that come from a trusted source.
    with open(filename + '/scaler.pkl', 'rb') as scaler_file:
        scalers = pickle.load(scaler_file)

    model = {
        'classes': classes,
        'leads': leads,
        'classifier': classifier,
        'signal_scaler': scalers['signal_scaler'],
        'feature_scaler': scalers['feature_scaler'],
        'pca': scalers['pca'],
    }
    return model

################################################################################
#
# Running trained model functions
#
################################################################################



# Generic function for running a trained model.
def run_model(model, header, recording):
    """Classify one recording with a model returned by load_model().

    Args:
        model: Dict with the keys ``classes``, ``leads``, ``classifier``,
            ``signal_scaler``, ``feature_scaler`` and ``pca``.
        header: Header of the recording, forwarded to process_recording().
        recording: Recording data, forwarded to process_recording().

    Returns:
        ``(classes, labels, probabilities)``. ``labels`` holds, per class, 1
        if any window predicts the class (rounded probability) and 0
        otherwise; ``probabilities`` holds the per-class maximum over all
        windows as float32.
    """
    window_len = 5
    overlap = 1

    # Unpack the model.
    classes = model['classes']
    leads = model['leads']
    classifier = model['classifier']
    signal_scaler = model['signal_scaler']
    feature_scaler = model['feature_scaler']
    pca = model['pca']

    # Process the recording (inference mode: no labels are returned).
    signals, features = process_recording(header, recording, window_len, overlap, classes, False, leads)

    # Scale the signals: flatten axes 1 and 2 for the scaler, then restore.
    s0, s1, s2 = signals.shape[0], signals.shape[1], signals.shape[2]
    reshaped = signals.reshape(s0, s1 * s2)
    reshaped = signal_scaler.transform(reshaped)
    signals = reshaped.reshape(s0, s1, s2)

    # Scale the features, THEN zero the NaNs, exactly as load_dataset() does
    # during training. Zeroing first would turn a missing value into
    # (0 - mean) / std instead of 0 in the scaled space.
    features = feature_scaler.transform(features)
    features[np.isnan(features)] = 0
    features = pca.transform(features)

    # Repeat the feature vector so that every signal window gets a copy.
    features = np.tile(features, (s0, 1))

    # Predict once; labels and probabilities are both derived from this output.
    predictions = classifier.predict({'signal_input': signals, 'feature_input': features})

    # A class is positive if any window predicts it. The builtin int replaces
    # the np.int alias, which was removed in NumPy 1.24.
    labels = np.max(np.asarray(np.around(predictions), dtype=int), axis=0)
    probabilities = np.max(np.asarray(predictions, dtype=np.float32), axis=0)

    return classes, labels, probabilities
