####### PREDICTION_CODE.PY ######### 
# Updated by Stefano Magni 5 Apr. 2021. 

# This module (prc package) contains the following functions:
#
# save_fold_model(filename_dir, model, labels_map, leads, scores, ix_fold, ix_dataset)
# add_extension(list_no_ext)
# listToString(a_list)
# pred_to_file(output_dir, file_name, snowmed_labels_red, bin_pred, proba_pred)
# threshold_optimization(name_valid_list, features_dict, model, selected_leads)
# thr_to_challenge_metric(thr, weights, classes, sinus_rhythm, Y_test, predictions)
# pred_to_binary(predictions,ths)
# get_windows(N, W, O)
# average_prediction(pred)
# find_models(model_directory, model_name)
# select_best_model(files_dirs)
# ensemble_evaluation(all_models, all_json, header, recording, thr_opt_models, json_directory)
# load_ens_models(model_files, json_files)
# extract_model_thresholds(all_thrs)
#
#
# POTENTIAL ISSUES 
 
# TO DOs 

import numpy as np, os, os.path, sys

import challengePackage.eval_code as ec
import challengePackage.processing_code as pc
import challengePackage.model_training_code as mtc
from helper_code import *
from tqdm.auto import tqdm
import pickle
import json
import tensorflow as tf
from scipy.optimize import fmin
import pandas as pd

###### DEFAULTS ########
# Default threshold vector: 26 entries, every one set to 0.1.
# NOTE(review): none of the functions in this module reads this module-level
# value - confirm whether external code imports it before removing it.
ths_dummy = np.full(26, 0.1)
########################


def save_fold_model(filename_dir, model, labels_map, leads, scores, ix_fold, ix_dataset): 
    '''
        Persist the outcome of one training fold.

        Two files sharing the same base path are written:
        '<filename_dir>.json' with the fold metadata and '<filename_dir>.h5'
        with the model (through model.save).

        Dependences: 
            json

        Args: 
            filename_dir: base path (no extension) of the two output files
            model: trained model; only its save(path) method is used here
            labels_map: dictionary with label encoding into integers; its keys
                        are stored, in iteration order, under 'classes'
            leads: leads available
            scores: results on validation set
            ix_fold, ix_dataset: indexes of fold and dataset of the training

    '''
    fold_info = {
        'leads': leads,
        'classes': list(labels_map),
        'scores': scores,
        'ix_fold': ix_fold,
        'ix_dataset': ix_dataset,
    }

    # metadata goes to disk first, the model afterwards
    json_path = filename_dir + '.json'
    with open(json_path, 'w') as metadata_file:
        json.dump(fold_info, metadata_file)

    model.save(filename_dir + '.h5')

####### FUNCTIONS TO PERFORM AND EVALUATE PREDICTION #######
def get_windows(N, W, O):

    '''
        In prediction code (prc)

        Compute the start and end indexes of overlapping windows over a signal.

        Consecutive windows are W samples long and share O samples. When a
        window would run past the end of the signal it is moved back so that it
        ends exactly at N (no padding), which means the last window may overlap
        the previous one by more than O samples.

        Dependences: 
            np 
        Args: 
            N: length of the signal 
            W: window length 
            O: overlap (samples)
        
        Returns: 
            starts: start indexes 
            ends: end indexes 

        Raises:
            ValueError: if W == O (the step between windows would be zero).

        Signals shorter than W are not handled specially: the result may be
        empty or contain a negative start, so callers should make sure N >= W.
        O is expected to be smaller than W.

        Created 22 Feb 2021. 12:58 - Stefano Magni 
    '''

    step = W - O
    if step == 0:
        # Previously this case only printed the message and then failed on an
        # undefined variable; fail explicitly instead.
        raise ValueError('error: W-O is equal to 0! Cannot divide by 0')

    # Number of windows needed to reach the end of the signal.
    num_windows = int(np.ceil((N - W) / step)) + 1

    starts = []
    ends = []

    start = 0
    for _ in range(num_windows):
        end = start + W
        if start > N:
            break
        if end > N:
            # The window would overrun the signal: anchor it to the end.
            end = N
            start = N - W

        starts.append(start)
        ends.append(end)

        start = end - O

    return starts, ends

def average_prediction(pred): 
    '''
        In prediction code (prc)

        Collapse the per-window predictions of one signal into a single vector
        by averaging over the windows (axis 0), class by class.

        The number of classes is taken from pred itself, so the function is
        not tied to 26 classes.

        Dependences: 
            np
        Args: 
            pred: predictions as returned by the model.predict, indexed as
                  [window, class] (2-D array)
        
        Return: 
            avg_pred: 1-D array with one value per class. With a single window
                      that window's row is returned as is.

        Raises:
            IndexError: if pred has fewer than two dimensions or has no rows.
    '''
    if pred.ndim < 2:
        # The original code failed with IndexError on such input (it read
        # pred.shape[1]); keep failing instead of silently averaging classes.
        raise IndexError('pred must be indexed as [window, class]')

    num_w = pred.shape[0]

    # More than one window: mean among windows, not over the class values.
    if num_w > 1:
        return np.mean(pred, axis=0)

    # A single window: its prediction row is the result.
    return pred[0]

def pred_to_binary(predictions,ths):
    '''
        In prediction code (prc)

        Binarise class probabilities: an entry becomes True when its
        probability is strictly greater than the threshold of its class.

        Dependences: 
            np
        Args:
            predictions: probabilities, either one signal (num_classes,) or
                        several signals (num_signals, num_classes). Lists are
                        converted to arrays.
            ths: vector of per-class thresholds with at least num_classes
                 entries; only the first num_classes are used.
        
        Return:
            binary_predictions: boolean array with the same shape as
                                predictions.

        Raises:
            IndexError: if ths has fewer entries than there are classes.

        Created 22 Feb 2021 10:30 - Stefano Magni
    '''
    predictions = np.asarray(predictions)
    ths = np.asarray(ths)

    # The class axis is the last one for both the 1-D and the 2-D layout.
    num_classes = predictions.shape[-1]
    if ths.shape[0] < num_classes:
        # Same exception type the element-wise indexing used to raise.
        raise IndexError('ths has %d entries but predictions have %d classes'
                         % (ths.shape[0], num_classes))

    # Vectorised comparison; the result is already of the builtin bool dtype,
    # so the np.bool alias (deprecated in NumPy 1.20-1.23, missing in
    # 1.24-1.26) is no longer needed.
    return predictions > ths[:num_classes]


def add_extension(list_no_ext):
    '''
        In prediction code (prc)

        Build the header ('.hea') and recording ('.mat') file names that
        correspond to each extension-less path.

        Dependences: 
            os
        Args: 
            list_no_ext: list of strings of directories without extension 
        Returns: 
            header_files: the input paths with '.hea' appended
            recording_files: the input paths with '.mat' appended
            (both in the same order as the input)
    '''
    header_files = [os.path.join(base + '.hea') for base in list_no_ext]
    recording_files = [os.path.join(base + '.mat') for base in list_no_ext]
    return header_files, recording_files

def thr_to_challenge_metric(thr, weights, classes, sinus_rhythm, Y_test, predictions):
    '''
        In prediction code (prc)

        Objective function for threshold tuning: binarise the predictions with
        thr, score them with the challenge metric and return the score with
        its sign flipped, so that minimising this function maximises the
        metric.

        Dependences: 
            compute_challenge_metric from eval_code, pred_to_binary
        Args: 
            thr: threshold vector used to binarise the predictions
            weights: defined by the challenge, they represent the metric 
            classes: classes of the challenge that are evaluated
            sinus_rhythm: class defined as reference (NSR)
            Y_test: true labels
            predictions: predicted probabilities
        Returns: 
            the challenge metric multiplied by -1 (the second value returned
            by compute_challenge_metric is discarded)
    '''
    binary_predictions = pred_to_binary(predictions, thr)
    challenge_score, _ = ec.compute_challenge_metric(
        weights, Y_test, binary_predictions, classes, sinus_rhythm
    )
    return (-1) * challenge_score

def threshold_optimization(name_valid_list, features_dict, model, selected_leads):
    '''
        In prediction code (prc)
        
        Function that performs threshold optimisation as proposed by: 
        Adaptive lead Weighted ResNet Trained With Different Duration Signals for 
        Classifying 12-lead ECGs. Zhao, Wong et al. 2020

        Steps:
          1. load the challenge classes/weights and the true labels of the
             validation recordings;
          2. predict every validation recording and average the prediction
             over its windows;
          3. coarse grid search (0 to 0.39, step 0.01) of one threshold per
             class, then refinement of the whole vector with scipy's fmin on
             the negated challenge metric.

        Dependences: 
            scipy, np, eval_code (in challengePackage), helper code, model_training_code (in challengePackage)
        Args: 
          name_valid_list: paths (without extension) of the validation
                           recordings; also the keys of features_dict
          features_dict: maps each entry of name_valid_list to a dictionary
                         of features; its values, in order, are fed to the model
          model: trained model; predict() is called with
                 [windows, features]
          selected_leads: leads the model was trained on
        Returns: 
          best_threshold: threshold vector returned by fmin
          A: second value returned by ec.compute_challenge_metric for
             best_threshold

    '''
    print('Optimizing threshold...')  
    ###STEP1###
      
    print('Extracting classes...')
    weights_file = './challengePackage/weights_new.csv'
    classes, weights = ec.load_weights(weights_file)
    sinus_rhythm = set(['426783006'])

    print("printing type of classes for debug...")

    # Get the lists for header and recording files
    header_files, recording_files = add_extension(name_valid_list)
    print('Extension added!')

    # load the true labels of the validation set
    labels = ec.load_labels(header_files, classes)    
    print('True labels loaded!')

    ###STEP 2### 
    # Predict every recording and keep one averaged probability vector each.
    W = 4096
    CH = len(selected_leads)
    num_recordings = len(recording_files)

    prob_saved = []
    for i in tqdm(range(num_recordings)):
        # Load header and recording.
        header = load_header(header_files[i])
        recording = load_recording(recording_files[i])
        leads = get_leads(header)

        current_features = np.array(list(features_dict[name_valid_list[i]].values()))
        # get_frequency is defined inside helper_code.py and provided by the challenge
        current_fs = get_frequency(header)

        # Same preprocessing is applied to the test set
        current_signal = mtc.subset_ch_recordings(leads, recording, selected_leads)
        reshaped_window = np.array(pc.extract_windows(current_signal, current_fs, W, CH, mode_use = 'train'))

        # NOTE(review): a single feature row is passed regardless of how many
        # windows extract_windows returned - presumably mode_use='train'
        # yields exactly one window; confirm.
        preds = model.predict([reshaped_window, np.array([current_features])])
        # the average over the windows is computed
        prob_saved.append(average_prediction(preds))

    # Converted once here instead of on every metric evaluation below.
    prob_saved = np.array(prob_saved)

    #### STEP 3 #### 
    # all attempts in the grid search
    possible_thrs = np.arange(0, 0.4, 0.01)
    num_classes = len(classes)
    # One threshold per class (sized from the classes instead of a fixed 26).
    grid_thrs = np.full(num_classes, 0.01)

    # current_metric[c, t]: metric obtained with candidate t for class c
    current_metric = np.zeros([num_classes, len(possible_thrs)])

    # NOTE(review): after a class has been scanned its entry in grid_thrs
    # stays at the last candidate (possible_thrs[-1]) while the following
    # classes are scanned; it is neither reset nor set to the best value.
    # Kept as is to preserve the results - confirm whether this is intended.
    for c in tqdm(range(num_classes)): 
        for t, candidate in enumerate(possible_thrs): 
            grid_thrs[c] = candidate
            current_metric[c, t], _ = ec.compute_challenge_metric(weights, labels, pred_to_binary(prob_saved, grid_thrs), classes, sinus_rhythm)

    # Best candidate of each class is the starting point of the refinement.
    initialized_thr = possible_thrs[np.argmax(current_metric, axis=1)]
        
    best_threshold = fmin(thr_to_challenge_metric, args=(weights, classes, sinus_rhythm, labels, prob_saved), x0=initialized_thr)
    
    _, A = ec.compute_challenge_metric(weights, labels, pred_to_binary(prob_saved, best_threshold), classes, sinus_rhythm)

    return best_threshold, A

############## ENSEMBLE MODEL #################
# Ensemble predictions package with all corresponding functions
# Starting function to find model files ####

def find_models(model_directory, model_name):
    '''
        Collect the saved models of a given name found in a directory.

        An entry of model_directory is selected when its extension is '.h5',
        its name starts with '<model_name>_' (and not with '.'), and both
        '<root>.h5' and '<root>.json' are existing files.

        Dependences: 
            os
        Args: 
            model_directory: path to './modeldir'
            model_name:  string that contains the model name
        Returns: 
            model_files: list with the paths of the selected .h5 files
            json_files: list with the paths of the matching .json files
                        (same order as model_files)
    '''
    model_files = []
    json_files = []

    for entry in os.listdir(model_directory):
        stem, suffix = os.path.splitext(entry)

        # skip hidden entries, non-model files and models with another name
        if stem.startswith('.'):
            continue
        if suffix != '.h5':
            continue
        if not stem.startswith(model_name+'_'):
            continue

        h5_path = os.path.join(model_directory, stem + '.h5')
        json_path = os.path.join(model_directory, stem + '.json')

        # a model is usable only together with its json companion
        if not (os.path.isfile(h5_path) and os.path.isfile(json_path)):
            continue

        model_files.append(h5_path)
        json_files.append(json_path)

    return model_files, json_files


def load_ens_models(model_files, json_files):
    '''
        Load every model of the ensemble together with its json file, so
        that they can all be used at the same time.

        Dependences: 
            tensorflow, json
        Args: 
            model_files: a list of the paths leading to .h5 models
            json_files: a list of the paths leading to json files; entry i
                        belongs to model_files[i]
        Returns:
            all_models: the list of LOADED models
            all_json: the list of loaded jsons (same order as all_models)
    
    '''
    all_models = []
    all_json = []

    for position, model_path in enumerate(model_files):
        # the model is loaded first, then its json companion
        all_models.append(tf.keras.models.load_model(model_path))

        with open(json_files[position]) as json_handle:
            all_json.append(json.load(json_handle))

    return all_models, all_json


def ensemble_evaluation(all_models, all_json, header, recording, thr_opt_models, json_directory):
    '''
        Predict one recording with every model of the ensemble, average the
        class probabilities over the models and binarise the result.
        Should be iteratively applied on a single header-recording couple each time. Used in run_model()
        
        Dependences: 
            numpy, pandas, pickle, os, helper_code, model_training_code and
            processing_code (inside challengePackage), average_prediction and
            pred_to_binary (this module)
        Args:
            all_models: list() containing all LOADED models
            all_json: list() with the loaded json dictionary of each model,
                      in the same order as all_models; the 'leads' and
                      'classes' entries are read
            header: of the signal
            recording: of the signal 
            thr_opt_models: thresholds used to binarise the averaged
                            probabilities (passed to pred_to_binary)
            json_directory: directory that contains 'minmax_scaler.pkl', the
                            fitted scaler applied to the extracted features
        Returns:
            classes: 'classes' entry of the json of the LAST model in the list
            labels: binary labels obtained from signal_probs and thr_opt_models
            signal_probs: np.array with the MEAN of all the probabilities estimated from the models taken into account
        Raises:
            ValueError: if all_models is empty.
    
    '''
    if len(all_models) == 0:
        # Without models there are no classes to return; the old code failed
        # later with an UnboundLocalError.
        raise ValueError('ensemble_evaluation needs at least one model')

    pickle_filename = os.path.join(json_directory, 'minmax_scaler.pkl')
    # SECURITY NOTE: unpickling executes code stored in the file, so
    # json_directory must only contain trusted artefacts.
    # The context manager closes the file, which the previous bare open() did not.
    with open(pickle_filename, 'rb') as scaler_file:
        fitted_scaler = pickle.load(scaler_file)
    
    W = 4096
    signal_probs = np.zeros(shape=(len(all_models), 26))
    for idx in range(len(all_models)):
        # Load info of json file of current model
        dictionary_model = all_json[idx]
        # need to define varying number of leads depending on model.
        CH = len(dictionary_model['leads'])
        classes = dictionary_model['classes']
        current_fs = get_frequency(header)
        current_sig_len = get_num_samples(header)
        current_time_sec = current_sig_len / current_fs
        leads = get_leads(header)

        #### STEP 1 ####
        # Features from the RR series.
        # NOTE(review): ('II') is the plain string 'II', not a one-element
        # tuple. If subset_ch_recordings iterates over its last argument it
        # receives 'I' twice - confirm against that helper before changing.
        # NOTE(review): this feature extraction does not depend on the model
        # and could be computed once before the loop if the pc/mtc helpers
        # have no side effects.
        my_leads = ('II')
        # STEP. Preprocessing of the signal before filtering. 
        ecg_ord = mtc.subset_ch_recordings(leads, recording, my_leads)
        # STEP. Filtering and normalization 
        filtered_norm_ecg = pc.filtering(ecg_ord, current_fs)
        # STEP. RR series 
        rr_series_list = pc.from_ecg_to_rr(filtered_norm_ecg, current_fs, current_time_sec)
        # STEP. Feature Extraction
        current_feature = pd.DataFrame()
        time_features = pc.time_domain_features(rr_series_list, 'prediction_code')
        # NOTE(review): DataFrame.append was removed in pandas 2.0. It is kept
        # because the column order it produces feeds the fitted scaler; replace
        # it with pd.concat once the type of time_features is confirmed.
        current_feature = current_feature.append(time_features, ignore_index = True)

        current_feature =  current_feature.set_index('name')
        # Scale the features with the scaler fitted at training time.
        current_feature[:] = fitted_scaler.transform(current_feature[:].values) 
        feature_dict = current_feature.to_dict('index')
        feats = list(feature_dict['prediction_code'].values())

        #### STEP 1.5 ####
        # Extract windows from recording to be predicted
        current_signal = mtc.subset_ch_recordings(leads, recording, dictionary_model['leads'])
        reshaped_window = pc.extract_windows(current_signal, current_fs, W, CH, mode_use='test')
        
        # The same feature vector accompanies every window of the signal.
        pred_feats = [np.array(feats) for _ in range(len(reshaped_window))]
        
        #### STEP 2 ####
        # predictions are performed on all the windows of the signal.
        preds = all_models[idx].predict([np.array(reshaped_window), np.array(pred_feats)])
        #### STEP 3 ####
        # the average over the windows is computed
        signal_probs[idx, :] = average_prediction(preds)

    # We take the mean probability of the prediction of all models
    signal_probs = np.mean(signal_probs, axis=0)
    labels = pred_to_binary(signal_probs, thr_opt_models)

    return classes, labels, signal_probs


def extract_model_thresholds(all_thrs):
    '''
        Average the thresholds of all the models of the ensemble.

        Dependences: 
            numpy
        Args: 
            all_thrs: list of json_dictionary of all the models in the ensemble;
                      the 'thresholds' entry of each one is read
        Returns: 
            thr_opt: array of 26 values, the mean threshold over the models 

    '''
    # one row of 26 thresholds per model
    stacked_thrs = np.zeros(shape=(len(all_thrs), 26))
    for row, model_info in enumerate(all_thrs):
        stacked_thrs[row] = model_info['thresholds']

    # mean over the models, column by column
    return np.mean(stacked_thrs, axis=0)

# Find all models inside the directory that have both json and h5 

def select_best_model(files_dirs): 
    '''
        Pick, for each dataset, the saved model with the best f-measure and
        load it together with its json file.

        For every json file the f-measure is computed from its 'scores'
        entry (index 2 is used as recall, index 3 as precision). The files
        are grouped by their 'ix_dataset' value ('1', '2' or '3'; any other
        value is ignored) and the best one of each non-empty group is loaded.

        Dependences: 
            json, tensorflow
        Args: 
            files_dirs: files_dirs[0] is the list of model paths and
                        files_dirs[1] the list of json paths, in the same order
        Returns:  
            model_D and json_D: loaded best models and their json
                                dictionaries, ordered by dataset ('1', '2', '3')
    '''
    json_paths = files_dirs[1]

    # one table {file position (as string): f-measure} per dataset
    f_measures = {'1': {}, '2': {}, '3': {}}

    for position, json_path in enumerate(json_paths):
        with open(json_path) as json_file:
            info = json.load(json_file)

        scores = info['scores']
        recall = scores[2]
        precision = scores[3]
        # 0.001 keeps the denominator away from zero
        f_measure = (2*precision*recall)/(precision+recall+0.001)

        for dataset_id, table in f_measures.items():
            if info['ix_dataset'] == dataset_id:
                table[str(position)] = f_measure
                break

    # best file of every dataset that has at least one candidate
    best_keys = [max(table, key=table.get) for table in f_measures.values() if table]

    model_D = []
    json_D = []
    for best_key in best_keys:
        chosen = int(best_key)
        model_D.append(tf.keras.models.load_model(files_dirs[0][chosen]))

        with open(files_dirs[1][chosen]) as json_file:
            json_D.append(json.load(json_file))

    return model_D, json_D 

###### THESE FUNCTIONS ARE NOT USED AT THE MOMENT. KEPT FOR FUTURE USE ######
def pred_to_file(output_dir, file_name, snowmed_labels_red, bin_pred, proba_pred): 
    '''
        In prediction code (prc)

        Write the prediction of one recording to
        '<output_dir>/<file_name>.csv' as four lines:

            # Recording <file_name>
            diagnosis_1, diagnosis_2, diagnosis_3
                    0,           1,           1
                    0.12,        0.34,        0.56

        Dependences: 
            listToString
        Args: 
            output_dir: directory where to save the output files
            file_name: name of the file. 
            snowmed_labels_red: classes for which the prediction is performed 
            bin_pred: prediction after binarisation
            proba_pred: mean probabilities of prediction 
                    
    '''
    # header line followed by labels, binary predictions and probabilities
    rows = [f'# Recording {file_name}']
    for values in (snowmed_labels_red, bin_pred, proba_pred):
        rows.append(listToString(values))

    # one row per line of the output file
    target = str(output_dir)+'/'+str(file_name) +'.csv'
    with open(target, 'w') as out_file:
        for row in rows:
            out_file.write("%s\n" % row)

def listToString(a_list):  
    '''
        In prediction code (prc)

        Render the elements of a list as one comma-separated string.

        Dependences: 
            None
        Args: 
            a_list: list as input; elements that are not strings are
                    converted with str()
        Returns: 
            joined_string: the elements joined by ', ' (empty string for an
                           empty list)
        
    '''
    return ", ".join(map(str, a_list))