#!/usr/bin/env python

# Edit this script to add your team's training code.
# Some functions are *required*, but you can edit most parts of the required functions, remove non-required functions, and add your own functions.

################################################################################
#
# Imported functions and variables
#
#
################################################################################

# See https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/sequence/TimeseriesGenerator
# Import functions. These functions are not required. You can change or remove them.
from helper_code import *
import numpy as np, os, sys, joblib
import numpy
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
import h5py

from sklearn.preprocessing import MinMaxScaler

import tensorflow as tf

from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, LSTM, Dense, Dropout, TimeDistributed, Flatten,Bidirectional
from tensorflow.keras.optimizers import Adam
import keras
tf.__version__

# Lead sets used by the Challenge, one model is trained per entry of `lead_sets`.
# These names are not required; you may change or remove them.
twelve_leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')
six_leads = twelve_leads[:6]  # first six entries of the 12-lead tuple
four_leads = ('I', 'II', 'III', 'V2')
three_leads = ('I', 'II', 'V2')
two_leads = twelve_leads[:2]  # first two entries of the 12-lead tuple
lead_sets = (twelve_leads, six_leads, four_leads, three_leads, two_leads)

################################################################################
#
# Training model function
#
################################################################################

# Train your model. This function is *required*. You should edit this function to add your code, but do *not* change the arguments of this function.
def training_code(data_directory, model_directory):
    """Train and save one dense network per lead set in ``lead_sets``.

    For every lead set, each recording is loaded, only that set's leads are
    calibrated and averaged into one trace by ``pre_processing`` (the same
    preprocessing ``run_model`` applies at inference), and the trace is
    zero-padded or truncated to a common length: the longest recording in the
    dataset, capped at 10000 samples. Labels form a multi-hot matrix over every
    class found in the headers (a recording may carry several labels).

    Args:
        data_directory: directory searched by ``find_challenge_files``.
        model_directory: output directory (created if missing); receives one
            model file and one ``meta-`` file per lead set.

    Raises:
        Exception: if no recordings are found.
    """
    # Find header and recording files.
    print('Finding header and recording files...')

    header_files, recording_files = find_challenge_files(data_directory)
    num_recordings = len(recording_files)

    if not num_recordings:
        raise Exception('No data was provided.')

    # Create a folder for the model if it does not already exist.
    os.makedirs(model_directory, exist_ok=True)

    # Headers are small text records; load them once and reuse them for class
    # extraction, length scanning and label encoding.
    headers = [load_header(header_file) for header_file in header_files]

    # Extract the classes from the dataset.
    print('Extracting classes...')

    classes = set()
    for header in headers:
        classes |= set(get_labels(header))
    if all(is_integer(x) for x in classes):
        classes = sorted(classes, key=lambda x: int(x))  # Sort classes numerically if numbers.
    else:
        classes = sorted(classes)  # Sort classes alphanumerically if not numbers.

    # Common input length: longest recording, capped to bound memory use.
    print('Find the biggest Number of samples on all dataset')
    memory = 10000
    biggest_N = int(min(max(get_num_samples(header) for header in headers), memory))
    print("biggest header file is: ", biggest_N)

    # Train a model for each lead set.
    for leads in lead_sets:
        print('Training model for {}-lead set: {}...'.format(len(leads), ', '.join(leads)))

        # Build the inputs from *this* lead set only, so training matches the
        # per-lead-set preprocessing that run_model performs at inference.
        training_padded, training_labels = _build_dataset(
            headers, recording_files, leads, classes, biggest_N)

        print("training_padded shape: ", training_padded.shape)
        print("training_labels shape: ", training_labels.shape)

        print("Training model.. ")
        classifier = dnn_model(training_padded, training_labels.shape[1])

        # dnn_model ends in a linear Dense layer (unbounded scores) and the
        # labels are multi-hot, so use per-class sigmoid cross-entropy on
        # logits. categorical_crossentropy expects a probability distribution
        # and gives a meaningless loss on raw, possibly negative outputs.
        # A logit of 0 corresponds to probability 0.5, hence threshold=0.0.
        classifier.compile(
            loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
            optimizer="adam",
            metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.0)],
        )
        classifier.fit(training_padded,
                       training_labels,
                       validation_split=0.2,
                       epochs=5,
                       batch_size=20,
                       verbose=2)

        # Save the model and the metadata needed at inference time.
        save_tf_model(model_directory, leads, classes, classifier)
        save_data_from_set(model_directory, leads, classes)

        # Release this lead set's arrays before the next one is built, so peak
        # memory holds only one dataset at a time.
        del training_padded, training_labels, classifier


def _build_dataset(headers, recording_files, leads, classes, length):
    """Load each recording, reduce it to one trace over ``leads``, and multi-hot its labels.

    Returns ``(data, labels)`` with shapes ``(num_recordings, length)`` and
    ``(num_recordings, len(classes))``, both float32.
    """
    num_recordings = len(recording_files)
    data = np.zeros((num_recordings, length), dtype=np.float32)
    labels = np.zeros((num_recordings, len(classes)), dtype=np.float32)
    class_index = {label: j for j, label in enumerate(classes)}

    for i, (header, recording_file) in enumerate(zip(headers, recording_files)):
        print('    {}/{}...'.format(i + 1, num_recordings))
        recording = load_recording(recording_file)
        signal, _ = pre_processing(header, recording, leads)
        data[i, :] = _fit_to_length(signal, length)

        for label in get_labels(header):
            j = class_index.get(label)
            if j is not None:
                labels[i, j] = 1.0

    return data, labels


def _fit_to_length(signal, length):
    """Zero-pad (at the end) or truncate a 1-D signal to exactly ``length`` samples."""
    signal = np.asarray(signal, dtype=np.float32).ravel()
    if len(signal) >= length:
        return signal[:length]
    return np.concatenate((signal, np.zeros(length - len(signal), dtype=np.float32)))


################################################################################
#
# Running trained model function
#
################################################################################

# Run your trained model. This function is *required*. You should edit this function to add your code, but do *not* change the arguments of this function.
def run_model(model, header, recording):
    """Classify one recording with a loaded lead-set model.

    Args:
        model: the ``(meta_data, classifier)`` pair returned by ``load_model``;
            ``meta_data`` holds ``'classes'`` and ``'leads'``.
        header: header of the recording.
        recording: raw recording array.

    Returns:
        ``(classes, labels, probabilities)`` where ``probabilities`` are the
        network scores min-max normalised to [0, 1] and ``labels`` are those
        scores rounded to 0/1 as int64.
    """
    meta_data = model[0]
    classifier = model[1]  # Keras network loaded from the .h5 file

    classes = meta_data['classes']
    leads = meta_data['leads']

    # The first Dense layer fixes the input length. Read it from the network
    # rather than hard-coding it: training uses the longest recording length
    # (capped at 10000), which can be shorter than 10000 on small datasets.
    default_length = 10000
    input_length = classifier.input_shape[-1]
    if input_length is None:
        input_length = default_length

    signal, num_samples = pre_processing(header, recording, leads)
    signal = np.asarray(signal).ravel()

    # Zero-pad at the end or truncate, matching the training-time layout.
    if num_samples < input_length:
        data = np.concatenate((signal, np.zeros(input_length - len(signal))), axis=None)
    else:
        data = signal[:input_length]
    data = data.reshape(1, -1)  # batch of one

    # Predict labels and probabilities.
    probabilities = np.array(min_max_norm(classifier.predict(data)[0]))
    print("probs: ", probabilities)
    labels = probabilities.round().flatten().astype("int64")
    print("labels: ", labels)
    return classes, labels, probabilities

################################################################################
#
# File I/O functions
#
################################################################################

# Save a trained model. This function is not required. You can change or remove it.
def save_data_from_set(model_directory, leads, classes):
    """Persist the lead set and class list that go with a trained network.

    The dict ``{'leads': ..., 'classes': ...}`` is written with joblib into
    ``model_directory`` under the model file name prefixed with ``meta-``.
    """
    meta_path = os.path.join(model_directory, "meta-" + get_model_filename(leads))
    joblib.dump({'leads': leads, 'classes': classes}, meta_path, protocol=0)

# Save a trained model. This function is not required. You can change or remove it.
def save_tf_model(model_directory, leads, classes, classifier):
    """Write ``classifier`` into ``model_directory`` under the lead set's model file name.

    ``classes`` is accepted for signature compatibility but is not used here.
    """
    target = os.path.join(model_directory, get_model_filename(leads))
    classifier.save(target)

# Load a trained model. This function is *required*. You should edit this function to add your code, but do *not* change the arguments of this function.
def load_model(model_directory, leads):
    """Load the metadata and Keras network saved for ``leads``.

    Returns:
        A ``(meta_data, classifier)`` tuple: ``meta_data`` is the joblib dict
        written next to the model (file name prefixed with ``meta-``), and
        ``classifier`` is the network loaded from the model file.
    """
    model_filename = get_model_filename(leads)
    meta_path = os.path.join(model_directory, "meta-" + model_filename)
    model_path = os.path.join(model_directory, model_filename)

    print("loading model & metadata ", meta_path)
    meta_data = joblib.load(meta_path)
    # Inference only needs predict(); skipping compilation avoids having to
    # deserialise the training loss/metrics/optimizer state.
    classifier = tf.keras.models.load_model(model_path, compile=False)

    return meta_data, classifier

# Define the filename(s) for the trained models. This function is not required. You can change or remove it.
def get_model_filename(leads):
    """Return the ``.h5`` file name for a model trained on ``leads``.

    The leads are passed through ``sort_leads`` first (presumably a fixed
    canonical ordering — confirm in helper_code), then joined with ``-``.
    """
    return 'model_{}.h5'.format('-'.join(sort_leads(leads)))

################################################################################
#
# Feature extraction function
#
################################################################################

# Extract features from the header and recording. This function is not required. You can change or remove it.
def get_features(header, recording, leads):
    """Extract ``(age, sex, rms)`` features from a header and recording.

    Args:
        header: recording header.
        recording: raw recording array.
        leads: leads to select (in this order) before computing RMS.

    Returns:
        age: float, NaN when missing.
        sex: 0 for female, 1 for male, NaN otherwise.
        rms: array with the root mean square of each selected, calibrated lead.
    """
    # Extract age.
    age = get_age(header)
    if age is None:
        age = float('nan')

    # Extract sex. Encode as 0 for female, 1 for male, and NaN for other.
    sex = get_sex(header)
    if sex in ('Female', 'female', 'F', 'f'):
        sex = 0
    elif sex in ('Male', 'male', 'M', 'm'):
        sex = 1
    else:
        sex = float('nan')

    # Reorder/reselect leads in recordings. Cast to float before scaling: the
    # raw samples may be integer (pre_processing applies the same cast), and
    # writing calibrated values back into an integer array would truncate them.
    recording = choose_leads(recording, header, leads).astype(np.float64)

    # Remove baseline and divide by ADC gain, lead by lead.
    adc_gains = get_adc_gains(header, leads)
    baselines = get_baselines(header, leads)
    num_leads = len(leads)
    for i in range(num_leads):
        recording[i, :] = (recording[i, :] - baselines[i]) / adc_gains[i]

    # Compute the root mean square of each ECG lead signal.
    rms = np.array([np.sqrt(np.mean(recording[i, :] ** 2)) for i in range(num_leads)],
                   dtype=np.float64)

    return age, sex, rms


def pre_processing(header, recording, leads):
    """Calibrate the selected leads and average them into one 1-D trace.

    Args:
        header: recording header.
        recording: raw recording array.
        leads: leads to use; each must be listed in the header.

    Returns:
        ``(signal, N)``: the sample-wise average of the calibrated leads (float32)
        and its number of samples.

    Raises:
        ValueError: if a requested lead is not present in the header.
    """
    available_leads = get_leads(header)
    missing = [lead for lead in leads if lead not in available_leads]
    if missing:
        raise ValueError('Leads {} not found in header (available: {})'.format(
            missing, available_leads))

    # Reorder/reselect leads, then convert to float32 so the in-place
    # calibration below does not truncate (raw samples may be int16).
    recording = choose_leads(recording, header, leads).astype(np.float32)

    # Remove baseline and divide by ADC gain, lead by lead.
    adc_gains = get_adc_gains(header, leads)
    baselines = get_baselines(header, leads)
    for i in range(len(leads)):
        recording[i, :] = (recording[i, :] - baselines[i]) / adc_gains[i]

    # Collapse all selected leads into a single trace by averaging sample-wise.
    signal = np.average(recording, axis=0)
    return signal, len(signal)

def min_max_norm(probs):
    """Linearly rescale scores so the smallest maps to 0 and the largest to 1.

    Args:
        probs: array-like of scores.

    Returns:
        A list of the rescaled scores. If every score is equal (including a
        single score) there is no range to divide by, so all entries map to 0.0
        instead of NaN. Empty input returns ``[]``.
    """
    values = np.asarray(probs, dtype=np.float64)
    if values.size == 0:
        return []
    lowest = values.min()
    span = values.max() - lowest
    if span == 0:
        return np.zeros_like(values).tolist()
    return ((values - lowest) / span).tolist()

def train_using_gini(X_train, y_train):
    """Fit and return a decision tree that splits on Gini impurity.

    Gini impurity measures how often a randomly chosen sample would be
    misclassified if labelled according to the class distribution of its node;
    the tree chooses splits that reduce it the most.
    """
    # Imported locally: the module's top-level imports do not include it.
    from sklearn.tree import DecisionTreeClassifier

    # `presort` and `min_impurity_split` were removed from scikit-learn and now
    # raise TypeError; the remaining arguments are the library defaults, spelled
    # out explicitly.
    clf_gini = DecisionTreeClassifier(
        criterion="gini",
        splitter="best",
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        min_weight_fraction_leaf=0.0,
        max_features=None,
        max_leaf_nodes=None,
        class_weight=None,
        random_state=None,
    )

    clf_gini.fit(X_train, y_train)
    return clf_gini

def dnn_model(training_padded,NUM_CLASSES):
    """Build an uncompiled fully connected network for fixed-length traces.

    The input width is taken from ``training_padded.shape[1]``. Layers:
    Dense(1000, relu) -> Dense(32, relu) -> Dense(NUM_CLASSES) with no
    activation, so the outputs are raw (unbounded) scores.
    """
    input_width = training_padded.shape[1]
    print("dimentiosn to model:  ", input_width)

    return Sequential([
        Dense(1000, input_dim=input_width, activation='relu'),
        Dense(32, activation='relu'),
        Dense(NUM_CLASSES),
    ])

def lstm_model(N_INPUTS,NUM_CLASSES):
    """Build an uncompiled bidirectional-LSTM classifier with a softmax head.

    NOTE(review): ``input_shape`` is set on the inner LSTM rather than on the
    ``Bidirectional`` wrapper, and uses ``N_INPUTS.shape[0]`` as the time
    dimension — confirm both are intended (this function is not called in
    this file).
    """
    timesteps = N_INPUTS.shape[0]
    recurrent = LSTM(32, input_shape=(timesteps, 1))
    return Sequential([
        Bidirectional(recurrent),
        Dense(NUM_CLASSES, activation='softmax'),
    ])

