#!/usr/bin/env python

# Edit this script to add your team's training code.
# Some functions are *required*, but you can edit most parts of the required functions, remove non-required functions, and add your own functions.

################################################################################
#
# Imported functions and variables
#
################################################################################

from challengePackage import model_training_code as mtc
from challengePackage import prediction_code as prc
from challengePackage import processing_code as pc
from challengePackage import eval_code as ec
from helper_code import *

# Import multilabel stratification to perform K fold CV
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from sklearn.preprocessing import MultiLabelBinarizer

import numpy as np, os, sys
import tensorflow as tf  # Tensorflow
from scipy.io import loadmat  # Required to load .mat files
from scipy import signal # Library for preprocessing
from tqdm.auto import tqdm  # For progress bars
import pandas as pd
import math
import json
import random
import warnings

# Set the SEED for partial reproducibility: seeds both the TensorFlow and the
# global NumPy random generators at import time.
SEED = 1234
tf.random.set_seed(SEED)
np.random.seed(seed=SEED)

# Kaggle key: when True, json_directory below points at the Kaggle working
# directory instead of a local folder.
kaggle = False

# configure GPU
# NOTE(review): the returned device list is discarded, so this call has no
# visible effect on its own.
tf.config.list_physical_devices('GPU')

# Report whether TensorFlow sees a GPU (gpu_device_name() is presumably an
# empty, falsy string when no GPU is available).
if tf.test.gpu_device_name():
    print('Default GPU Device:{}'.format(tf.test.gpu_device_name()))
else:
    print("Please install GPU version of TF")


# Directory for the JSON files written during training (training_code passes it
# to pc.dataset_generator with json_save=True). It was originally meant to hold
# pickle files, which did not work in Challenge submissions.
if kaggle:
    json_directory = '/kaggle/working/json_files'
else:
    json_directory = './json_files'

# Challenge lead sets. These variables are not required; change or remove them.
# The limb-lead subsets are prefixes of the standard 12-lead order, so they are
# derived by slicing it; the reduced sets that add V2 append it explicitly.
twelve_leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF',
                'V1', 'V2', 'V3', 'V4', 'V5', 'V6')
six_leads = twelve_leads[:6]
four_leads = twelve_leads[:3] + ('V2',)
three_leads = twelve_leads[:2] + ('V2',)
two_leads = twelve_leads[:2]
lead_sets = (twelve_leads, six_leads, four_leads, three_leads, two_leads)

# Set to True to import the plotting libraries used for development figures.
developing = False
if developing:
    import seaborn as sns
    import matplotlib.pyplot as plt

################################################################################
#
# Training model function
#
################################################################################

# Train your model. This function is *required*. You should edit this function to add your code, but do *not* change the arguments of this function.
def _first_fold_split(dataset, classes_flat, header_data_dict, n_splits=5):
    """Return the first train/validation split of ``dataset``.

    The labels are binarised over ``classes_flat`` and split with a multilabel
    stratified K-fold of ``n_splits`` folds (80/20 for the default of 5). Only
    the first split is used; full K-fold CV is a development-time activity.

    Returns:
        ``(fold, train_names, valid_names, train_dict, valid_dict)``: ``fold`` is
        the 1-based index of the split used, the name lists come from
        ``pc.name_list`` and the dicts are the matching entries of
        ``header_data_dict``.

    Raises:
        ValueError: if the K-fold splitter yields no split at all.
    """
    # NOTE(review): ``dataset.file_names`` may be a pandas Series; the fold
    # indices are positional, so index a plain list rather than relying on the
    # Series index being 0..n-1.
    file_names = list(dataset.file_names)
    mlb = MultiLabelBinarizer(classes=classes_flat)
    y = mlb.fit_transform(dataset.labels)
    print(mlb.classes_)
    print(dataset.labels)
    print(y)

    print('Train-Validation splitting ...')
    kfold = MultilabelStratifiedKFold(n_splits=n_splits)
    for fold, (train_index, valid_index) in enumerate(kfold.split(file_names, y), start=1):
        train_names = pc.name_list([file_names[i] for i in train_index])
        valid_names = pc.name_list([file_names[i] for i in valid_index])
        train_dict = {key: header_data_dict[key] for key in train_names}
        valid_dict = {key: header_data_dict[key] for key in valid_names}
        # At submission time only the first train/validation split is needed.
        return fold, train_names, valid_names, train_dict, valid_dict
    raise ValueError('MultilabelStratifiedKFold produced no train/validation split.')


def _new_model_json(leads, fold, data, classes_flat):
    """Build the metadata dict saved next to a freshly trained model."""
    return {
        'leads': leads,
        'ix_fold': fold,
        'ix_dataset': str(data + 1),
        'classes': classes_flat,
    }


def _normalize_confusion_matrix(A):
    """Scale ``A`` so its largest entry is 1.

    Normalising by the global maximum (rather than per column, which gave NaNs
    for reduced lead sets) avoids empty columns; an all-zero matrix is returned
    as zeros instead of the NaNs a 0/0 division would produce.
    """
    A = np.asarray(A, dtype=float)
    peak = A.max() if A.size else 0.0
    return A / peak if peak > 0 else np.zeros_like(A)


def _save_training_plots(history_t, A_normed, prefix):
    """Save per-epoch metric curves and the confusion-matrix heatmap.

    Only meaningful in developing mode: it uses the module-level ``plt`` and
    ``sns`` names that are imported only when ``developing`` is True. Files are
    written as ``prefix + '_<metric>.jpg'`` and ``prefix + '_confusion_matrix.png'``.
    """
    # (history key, plot title, legend location)
    curves = (
        ('accuracy', 'model accuracy', 'best'),
        ('loss', 'model loss', 'upper left'),
        ('recall', 'model recall', 'upper left'),
        ('precision', 'model precision', 'upper left'),
    )
    for metric, title, legend_loc in curves:
        plt.figure(figsize=(8, 6))
        plt.plot(history_t.history[metric])
        plt.plot(history_t.history['val_' + metric])
        plt.title(title)
        plt.ylabel(metric)
        plt.xlabel('epoch')
        plt.legend(['train', 'validation'], loc=legend_loc)
        plt.grid(False)
        plt.savefig(prefix + '_' + metric + '.jpg')
        plt.close()

    cm_labels = ['AF', 'AFL', 'BBB', 'Brady', 'CLBBB|LBBB', 'CRBBB|RBBB', 'IAVB', 'IRBBB',
                 'LAD', 'LAnFB', 'LPR', 'LQRSV', 'LQT', 'NSIVCB', 'NSR', 'PAC|SVPB', 'PR',
                 'PRWP', 'PVC|VPB', 'QAb', 'RAD', 'SA', 'SB', 'STach', 'TAb', 'TInv']
    plt.figure(figsize=(20, 20))
    h = sns.heatmap(A_normed, annot=True, cmap="Blues")
    h.set_xticklabels(labels=cm_labels, rotation=60)
    h.set_yticklabels(labels=cm_labels, rotation=0)
    plt.title('Confusion Matrix')
    plt.ylabel('Actual Values')
    plt.xlabel('Predicted Values')
    plt.savefig(prefix + '_confusion_matrix.png')
    # Close the figure so heatmaps do not accumulate across lead sets.
    plt.close()


def training_code(data_directory, model_directory):
    """Train one model per Challenge lead set and save it to ``model_directory``.

    For each lead set in ``lead_sets`` the first dataset returned by
    ``pc.dataset_generator`` is split into train/validation (first fold of a
    multilabel stratified 5-fold). The model is then either fine-tuned from
    pretrained weights (when ``FINE_TUNING`` is on and the weights file exists)
    or trained from scratch. In both cases it is evaluated on the validation
    split, decision thresholds are optimised, and the model plus its JSON
    metadata are saved via ``save_model``.

    Args:
        data_directory: folder searched by ``find_challenge_files`` for header
            and recording files.
        model_directory: output folder for the trained models (created if
            missing).

    Raises:
        Exception: if no recordings are found in ``data_directory``.
    """
    #################################### DIRECTORIES AND SETTINGS ##############################
    CURRENT_BS = 64
    CURRENT_EPOCHS = 12
    CURRENT_ES = False
    FINE_TUNING = True

    weights_file = './challengePackage/weights_new.csv'
    existing_folder = './model_weights'  # directory with pretrained models (only weights and jsons without threshold)
    n_datasets = 1  # only the first dataset is used at submission time (len(list_datasets) in development)
    ################################################################################################

    #### STEP 0 ####
    # Create the output directories (makedirs also handles nested paths).
    if not os.path.isdir(model_directory):
        os.makedirs(model_directory)
        print('Model directory created!')
    if not os.path.isdir(json_directory):
        os.makedirs(json_directory)
        print('Json dir created!')

    #### STEP 1 ####
    print('Finding header and recording files...')
    header_files, recording_files = find_challenge_files(data_directory)
    num_recordings = len(recording_files)
    print(num_recordings)
    if not num_recordings:
        raise Exception('No data was provided.')

    #### STEP 2 ####
    # The challenge weights file groups equivalent classes into the same set.
    print('Extracting classes...')
    classes, _ = ec.load_weights(weights_file)
    classes_flat = pc.classes_list(classes)

    #### STEP 3 ####
    print('Extracting data from headers ...')
    label_dataframe, header_data_dict = pc.create_header_label_data(header_files, classes, label_df=True)
    list_datasets = pc.dataset_generator(label_dataframe, json_directory, json_save=True)

    #### STEP 4 ####
    # Features for the wide branch, only for records whose classes belong to the challenge.
    print('Extracting time features from signals...')
    header_list = [name + '.hea' for name in header_data_dict]
    recording_list = [name + '.mat' for name in header_data_dict]
    features_dict = pc.signal_features(header_list, recording_list, model_directory)

    #### STEP 6 ####
    # Train a model for each lead set.
    for leads in lead_sets:
        print('Training model for {}-lead set: {}...'.format(len(leads), ', '.join(leads)))
        print('Actual leads considered are:\n')
        print(leads)
        filename = get_model_filename(leads)
        print(f'The model will be saved with the name: {filename}')

        #### STEP 7 ####
        for data in range(n_datasets):
            print(f'Working on Dataset no. {data + 1} ...')
            dataset_tag = '_D' + str(data + 1)

            #### STEP 8 ####
            # Recomputed per lead set on purpose: the splitter's tie-breaking may
            # use the global NumPy RNG, so hoisting it could change the splits.
            fold, train_names, valid_names, train_dict, valid_dict = _first_fold_split(
                list_datasets[data], classes_flat, header_data_dict)

            #### STEP 9 ####
            print("Defining data generators...")
            training_generator = mtc.DataGenerator(train_dict, features_dict, train_names,
                                                   leads, batch_size=CURRENT_BS, shuffle=False)
            validation_generator = mtc.DataGenerator(valid_dict, features_dict, valid_names,
                                                     leads, batch_size=CURRENT_BS, shuffle=False)

            filename_ver_data = os.path.join(existing_folder, filename) + dataset_tag
            # NOTE(review): this path is relative to the working directory, not
            # existing_folder; confirm where pretrained weights are expected.
            pretrained_weights = filename + '_weights.h5'

            if os.path.isfile(pretrained_weights) and FINE_TUNING:
                print("There was already a pretrained network, using it for fine tuning!")
                #### STEP 10 A ####
                history_t, model = mtc.model_training_steps(filename_ver_data, training_generator,
                                                            validation_generator,
                                                            model_weights=pretrained_weights,
                                                            freeze_u=50,
                                                            num_features=16,
                                                            fine_tuning=FINE_TUNING,
                                                            es=CURRENT_ES,
                                                            channels=len(leads),
                                                            bs=CURRENT_BS,
                                                            epochs=CURRENT_EPOCHS,
                                                            train_deep=False)
                existing_json = filename_ver_data + '.json'
                if os.path.isfile(existing_json):
                    with open(existing_json) as json_dict:
                        current_json = json.load(json_dict)
                else:
                    current_json = _new_model_json(leads, fold, data, classes_flat)
            else:
                print("No pretrained network, complete training in progress ...")
                #### STEP 10 B ####
                history_t, model = mtc.model_training_steps(filename_ver_data, training_generator,
                                                            validation_generator,
                                                            num_features=16,
                                                            fine_tuning=False,
                                                            es=CURRENT_ES,
                                                            freeze_u=110,
                                                            channels=len(leads),
                                                            bs=CURRENT_BS,
                                                            epochs=CURRENT_EPOCHS,
                                                            train_deep=True)
                current_json = _new_model_json(leads, fold, data, classes_flat)

            # The remaining steps apply to BOTH branches. Previously they ran only
            # after full training, so fine-tuned models were never saved.
            filename_chosen_dir = os.path.join(model_directory, filename) + dataset_tag
            # Checkpoint before threshold optimisation; save_model rewrites it with the JSON.
            model.save(filename_chosen_dir + '.h5')

            #### STEP 11 ####
            print("Evaluating the model on the remaining left over data of the fold ...")
            scores_fold = model.evaluate(validation_generator, verbose=0)
            print(scores_fold)
            print('Score: ' + '; '.join(f'{name} of {score}'
                                        for name, score in zip(model.metrics_names, scores_fold)))
            current_json['valid_scores'] = scores_fold

            # Optimise per-class decision thresholds on the validation split.
            thresholds, A = prc.threshold_optimization(list(valid_dict), features_dict, model, leads)
            A_normed = _normalize_confusion_matrix(A)

            if developing:
                _save_training_plots(history_t, A_normed, filename_chosen_dir)

            save_model(filename_chosen_dir, model, current_json, thresholds, A_normed, model_directory)
            print("Saved the model!")
        

################################################################################
#
# Running trained model function
#
################################################################################

# Run your trained model. This function is *required*. You should edit this function to add your code, but do *not* change the arguments of this function.
def run_model(model, header, recording):
    """Classify one recording with a loaded model ensemble.

    ``model`` is indexed as a pair: ``model[0]`` is handed to
    ``prc.ensemble_evaluation`` as the ensemble, and ``model[1]`` is a sequence
    of per-model metadata dicts whose first entry supplies ``'model_dir'``.
    (Presumably this is the value returned by ``load_model``; confirm.)

    Returns:
        The ``(classes, labels, probabilities)`` triple produced by
        ``prc.ensemble_evaluation``.
    """
    ensemble = model[0]
    model_jsons = model[1]
    thresholds = prc.extract_model_thresholds(model_jsons)

    feature_dir = model_jsons[0]['model_dir']
    classes, labels, probabilities = prc.ensemble_evaluation(
        ensemble, model_jsons, header, recording, thresholds, feature_dir)
    return classes, labels, probabilities

################################################################################
#
# File I/O functions
#
################################################################################

# Save trained models.
def save_model(filename_dir, model, current_json, thresholds, A, model_dir):
    """Write a model's metadata JSON and the model itself.

    ``current_json`` is updated in place with the thresholds, the confusion
    matrix (under ``'Conf_matrix'``) and ``model_dir``. It is then dumped to
    ``filename_dir + '.json'``, and the model is saved to ``filename_dir + '.h5'``.

    Args:
        filename_dir: output path without extension.
        model: object exposing ``save(path)``.
        current_json: metadata dict to extend and serialise (mutated).
        thresholds: array-like with ``tolist()`` (per-class thresholds).
        A: array-like with ``tolist()`` (confusion matrix).
        model_dir: directory recorded in the metadata.
    """
    current_json.update({
        'thresholds': thresholds.tolist(),
        'Conf_matrix': A.tolist(),
        'model_dir': model_dir,
    })

    json_path = filename_dir + '.json'
    with open(json_path, 'w') as handle:
        json.dump(current_json, handle)

    model.save(filename_dir + '.h5')
    
    
# Load a trained model. This function is *required*. You should edit this function to add your code, but do *not* change the arguments of this function.
def load_model(model_directory, leads):
    """Load the trained model(s) for ``leads`` from ``model_directory``.

    The expected base filename comes from ``get_model_filename(leads)``. The
    first two items returned by ``prc.find_models`` are passed on to
    ``prc.load_ens_models``, and its result is returned unchanged.
    """
    # ISSUE HERE: the filename is just the directory of the model without ext. One model only
    found = prc.find_models(model_directory, get_model_filename(leads))
    return prc.load_ens_models(found[0], found[1])

# Define the filename(s) for the trained models. This function is not required. You can change or remove it.
def get_model_filename(leads):
    """Return ``'model_'`` followed by the leads (ordered by ``sort_leads``) joined with ``'-'``."""
    return 'model_{}'.format('-'.join(sort_leads(leads)))
