#!/usr/bin/env python

import numpy as np, os, sys, joblib
from scipy.io import loadmat
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestClassifier
from get_12ECG_features import get_12ECG_features


from scipy.signal import butter,filtfilt,resample
import matplotlib.pyplot as plt
from scipy import signal
from biosppy import storage
from biosppy.signals import ecg
import pandas as pd
import scipy
from scipy import optimize


import tensorflow as tf
from tensorflow import keras
from keras.optimizers import Adam
from keras.callbacks import (ModelCheckpoint,
                             TensorBoard, ReduceLROnPlateau,
                             CSVLogger, EarlyStopping)
from keras.backend.tensorflow_backend import set_session
from tensorflow.keras import activations

import argparse
from keras.utils import HDF5Matrix
import h5py
from keras.layers import (Input, Conv1D, MaxPooling1D, Dropout,
                          BatchNormalization, Activation, Add,
                          Flatten, Dense)
from keras.models import Model


import numpy as np, os, sys, joblib
from scipy.io import loadmat
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestClassifier
from get_12ECG_features import get_12ECG_features

# ----- Model ----- #
from keras.layers.core import Dense, Activation, Dropout
from keras.layers.recurrent import LSTM
from keras.models import Sequential
from keras.layers import Bidirectional


def butter_highpass_filter(data, cutoff, fs, order):
    """Apply a zero-phase Butterworth high-pass filter.

    Parameters
    ----------
    data : array_like
        Signal to filter.
    cutoff : float
        Cut-off frequency in Hz.
    fs : float
        Sampling frequency in Hz.
    order : int
        Order of the Butterworth filter.

    Returns
    -------
    numpy.ndarray
        The forward-backward (``filtfilt``) filtered signal.
    """
    # butter() expects the cut-off as a fraction of the Nyquist frequency.
    wn = cutoff / (fs * 0.5)
    coeff_b, coeff_a = butter(order, wn, btype='high', analog=False)
    return filtfilt(coeff_b, coeff_a, data)

def load_data(header_file):
    """Read a header file and the ``.mat`` recording that sits next to it.

    The recording path is the header path with its extension replaced by
    ``.mat``. Only the extension is changed, so a directory name that
    contains ``.hea`` is left intact.

    Parameters
    ----------
    header_file : str
        Path to the ``.hea`` header file.

    Returns
    -------
    recording : numpy.ndarray
        The ``'val'`` entry of the ``.mat`` file, converted to float64.
    header : list of str
        Lines of the header file, newline characters included.
    """
    with open(header_file, 'r') as f:
        header = f.readlines()
    # str.replace('.hea', '.mat') would also rewrite '.hea' inside directory
    # names; swap only the file extension.
    mat_file = os.path.splitext(header_file)[0] + '.mat'
    x = loadmat(mat_file)
    recording = np.asarray(x['val'], dtype=np.float64)
    return recording, header



def temp_generate(ecg_signal, fs, plot_peaks, show_template=False):
    """Build a normalised 400-sample average-beat template from one ECG lead.

    Steps:

    1. Mean-centre the lead and scale it to unit peak amplitude.
    2. High-pass filter it (0.5 Hz, order 2) and rescale to unit peak.
    3. Detect R peaks with ``biosppy.signals.ecg.ecg``.
    4. Resample every R-to-R segment to 400 samples and average them
       column-wise.
    5. Re-arrange the averaged beat around its peaks.
    6. Resample the result to 400 samples and scale it to unit peak.

    Parameters
    ----------
    ecg_signal : numpy.ndarray
        One ECG lead (1-D).
    fs : int or float
        Sampling frequency in Hz.
    plot_peaks : bool
        If true, plot the filtered signal with the detected R peaks.
    show_template : bool, optional
        If true, plot the final template.

    Returns
    -------
    numpy.ndarray
        Template of shape ``(400,)``. If any step of the template construction
        raises (for example, fewer than two R peaks are detected), an all-NaN
        ``(400,)`` array is returned instead.
    """
    # Normalise, high-pass filter (presumably to remove baseline wander) and
    # renormalise.
    ecg_signal = ecg_signal - np.mean(ecg_signal)
    ecg_signal = ecg_signal / np.max(np.abs(ecg_signal))
    ecg_signal = butter_highpass_filter(ecg_signal, 0.5, fs, 2)
    ecg_signal = ecg_signal / np.max(np.abs(ecg_signal))
    try:
        out = ecg.ecg(signal=ecg_signal, sampling_rate=fs, show=False)
        R_loc = out[2]  # third field of biosppy's result, used as R-peak sample indices
        if plot_peaks:
            time = np.arange(len(ecg_signal)) / fs
            plt.figure(figsize=(20, 5))
            plt.plot(time, ecg_signal, label='ECG')
            plt.plot(time[R_loc], ecg_signal[R_loc], 'o', c='r', markeredgewidth=1.5,
                     label='Detected R Peaks')
            plt.legend(loc='lower left')
            plt.xlabel('Time')
            plt.ylabel('Normalized Amplitude')
            plt.title('R peaks')
            plt.xlim([0, time[-1]])
            plt.show()

        # Resample every R-to-R segment to a common length of 400 samples.
        interpolated_cycles = [resample(ecg_signal[start:stop], 400)
                               for start, stop in zip(R_loc[:-1], R_loc[1:])]
        # With fewer than two R peaks the list is empty and np.vstack raises
        # ValueError, which is turned into a NaN template below.
        interpolated_cycles_matrix = np.vstack(interpolated_cycles)
        # Column-wise mean over all beats, kept as a (400, 1) column.
        average_template = interpolated_cycles_matrix.mean(axis=0).reshape(-1, 1)
        average_template = average_template / np.max(np.abs(average_template))

        # Peak in the first half (absolute index) and in the second half
        # (offset relative to index 199, so always in [0, 199]).
        swap2_loc = int(np.argmax(average_template[0:198]))
        swap1_loc = int(np.argmax(average_template[199:399]))
        # Rotate the beat: [199 .. just before the second-half peak] followed
        # by [first-half peak .. 197]. Because swap1_loc < 200 always, this is
        # the only ordering that can occur.
        # NOTE(review): the bounds look off by one or two samples. swap1 stops
        # two samples before the second-half peak (absolute 199 + swap1_loc),
        # and index 198 is never used. They are kept unchanged to preserve the
        # template definition. When swap1_loc < 2 the swap1 slice is empty,
        # np.hstack raises, and a NaN template is returned.
        swap1 = np.hstack(average_template[199:swap1_loc + 198])
        swap2 = np.hstack(average_template[swap2_loc:198])
        swapped_template = np.concatenate((swap1, swap2))
        final_template = resample(swapped_template, 400)
        final_template = final_template / np.max(np.abs(final_template))
        if show_template:
            plt.figure(figsize=(10, 10))
            plt.plot(final_template)
            plt.show()
    except Exception:
        # Best effort: any failure yields an all-NaN template that the caller
        # can drop. KeyboardInterrupt/SystemExit are no longer swallowed.
        final_template = np.full((400,), np.nan)
    return final_template


def generate_single_template(data, header_data):
    """Return the beat template of the second row of one recording.

    The sampling frequency is the third space-separated field of the first
    header line. ``data[1]`` is passed to ``temp_generate``; presumably this
    is lead II, so confirm against the lead order listed in the header.
    """
    sampling_rate = int(header_data[0].split(' ')[2])
    return temp_generate(data[1], sampling_rate, plot_peaks=False, show_template=False)

def _parse_dx_codes(header):
    """Return the stripped diagnosis field of the first ``#Dx`` header line.

    ``header`` is the list of header lines returned by ``load_data``. Returns
    an empty string when the header has no ``#Dx`` line.
    """
    for line in header:
        if line.startswith('#Dx'):
            # "#Dx: code1,code2\n" -> "code1,code2"
            return line.split(':', 1)[1].strip()
    return ''


def create_template_dataframe(input_directory):
    """Build a table of beat templates and multi-hot labels for a directory.

    Every non-hidden file in ``input_directory`` whose lower-cased name ends
    with ``hea`` is loaded with ``load_data``. Row 1 of each recording is
    turned into a 400-sample template with ``temp_generate`` and scaled by
    100.

    Diagnoses are read from the header's ``#Dx`` line. Three codes are folded
    into an equivalent class (17338001 -> 427172004, 63593006 -> 284470004,
    59118001 -> 713427006). The codes are then multi-hot encoded over the 27
    classes listed below.

    Parameters
    ----------
    input_directory : str
        Directory containing ``.hea`` headers and their ``.mat`` recordings.

    Returns
    -------
    df_label_template : pandas.DataFrame
        Indexed by record id, which is the header file name up to its first
        ``.``. It has 27 label columns (named by class-code strings) followed
        by 400 template columns (0..399). Rows containing NaN, i.e. failed
        templates, are dropped.
    class_names : list of str
        The 27 class codes, in label-column order.
    """
    header_files = []
    for name in sorted(os.listdir(input_directory)):
        path = os.path.join(input_directory, name)
        if not name.lower().startswith('.') and name.lower().endswith('hea') and os.path.isfile(path):
            header_files.append(path)

    classes_real = [270492004, 164889003, 164890007, 426627000, 713427006,
                    713426002, 445118002, 39732003, 164909002, 251146004, 698252002,
                    10370003, 284470004, 427172004, 164947007, 111975006, 164917005,
                    47665007, 59118001, 427393009, 426177001, 426783006, 427084000,
                    63593006, 164934002, 59931005, 17338001]
    class_names = [str(code) for code in classes_real]
    class_dict = {code: i for i, code in enumerate(class_names)}
    # Codes that are folded into an equivalent class before encoding.
    equivalent_codes = {'17338001': '427172004',
                        '63593006': '284470004',
                        '59118001': '713427006'}

    ids = []
    templates = []
    label_rows = []
    for header_file in header_files:
        # Load each record exactly once (header + .mat).
        recording, header = load_data(header_file)
        fs = int(header[0].split()[2])  # third field of the first header line
        # Use the file name itself, not a fixed path component, so the id is
        # correct however many directory levels input_directory has.
        ids.append(os.path.basename(header_file).split('.')[0])
        templates.append(temp_generate(recording[1], fs, plot_peaks=False, show_template=False))

        label_row = np.zeros(len(class_names))
        for code in _parse_dx_codes(header).split(','):
            code = code.strip()
            code = equivalent_codes.get(code, code)
            if code in class_dict:
                label_row[class_dict[code]] = 1
        label_rows.append(label_row)

    dataset_df = pd.DataFrame(templates, columns=np.arange(400), index=ids)
    dataset_df = dataset_df.multiply(100)
    labels = np.array(label_rows).reshape(len(ids), len(class_names))
    df_label = pd.DataFrame(labels, columns=class_names, index=ids)
    df_label_template = df_label.merge(dataset_df, left_index=True, right_index=True)
    df_label_template.dropna(inplace=True)
    return df_label_template, class_names


def train_12ECG_classifier(input_directory, output_directory):
    """Train the template classifier and save it to ``output_directory``.

    Writes two files:

    - ``finalized_model.h5`` via ``model.save``.
    - ``finalized_model.sav``, a joblib dump (protocol 0) of a dict with keys
      ``'model'``, ``'classes'`` (the 27 class codes) and ``'best_thr'``.
    """
    print('Loading data...')
    templates_df, class_codes = create_template_dataframe(input_directory)
    train_x, train_y, test_x, test_y = get_train_test_data(templates_df)

    print('Training model...')
    model, best_thr = get_machine_learning_model(train_x, train_y, test_x, test_y, num_epochs=1000)

    print('Saving model...')
    model.save(os.path.join(output_directory, 'finalized_model.h5'))
    bundle = {'model': model, 'classes': class_codes, 'best_thr': best_thr}
    joblib.dump(bundle, os.path.join(output_directory, 'finalized_model.sav'), protocol=0)
    
    
    
    
    

# Load challenge data.
def load_challenge_data(header_file):
    """Read a challenge header file and the ``.mat`` recording next to it.

    The recording path is the header path with its extension replaced by
    ``.mat``. Only the extension is changed, so a directory name that
    contains ``.hea`` is left intact.

    Parameters
    ----------
    header_file : str
        Path to the ``.hea`` header file.

    Returns
    -------
    recording : numpy.ndarray
        The ``'val'`` entry of the ``.mat`` file, converted to float64.
    header : list of str
        Lines of the header file, newline characters included.
    """
    with open(header_file, 'r') as f:
        header = f.readlines()
    # str.replace('.hea', '.mat') would also rewrite '.hea' inside directory
    # names; swap only the file extension.
    mat_file = os.path.splitext(header_file)[0] + '.mat'
    x = loadmat(mat_file)
    recording = np.asarray(x['val'], dtype=np.float64)
    return recording, header

# Find unique classes.
def get_classes(input_directory, filenames):
    """Return the sorted, de-duplicated diagnosis codes found in header files.

    For every line starting with ``#Dx``, the text after the first ``': '``
    is split on commas and each whitespace-stripped code is collected.
    ``input_directory`` is not used; it is kept for interface compatibility.
    """
    found = set()
    for path in filenames:
        with open(path, 'r') as handle:
            dx_lines = [line for line in handle if line.startswith('#Dx')]
        for line in dx_lines:
            found.update(code.strip() for code in line.split(': ')[1].split(','))
    return sorted(found)



class ResidualUnit(object):
    """Residual unit block (unidimensional).

    Calling an instance with ``[x, y]`` (main-path tensor, skip-path tensor)
    returns the next ``[x, y]`` pair.

    Parameters
    ----------
    n_samples_out: int
        Number of output samples. The downsampling factor is
        ``n_samples_in // n_samples_out``.
    n_filters_out: int
        Number of output filters.
    kernel_initializer: str, optional
        Initializer for the weights matrices. See Keras initializers. By default it uses
        'he_normal'.
    dropout_rate: float [0, 1), optional
        Dropout rate used in all Dropout layers. Default is 0.8.
    kernel_size: int, optional
        Kernel size for convolutional layers. Default is 17.
    preactivation: bool, optional
        When preactivation is true use full preactivation architecture proposed
        in [1]. Otherwise, use architecture proposed in the original ResNet
        paper [2]. By default it is true.
    postactivation_bn: bool, optional
        Defines whether batch normalization comes after or before the
        activation layer (there seems to be some advantages in some cases:
        https://github.com/ducha-aiki/caffenet-benchmark/blob/master/batchnorm.md).
        If true, the activation function is applied first and is followed by
        a batch normalization without centering or scaling. Otherwise, batch
        normalization comes first, as it is usually done. By default it is
        false.
    activation_function: string, optional
        Keras activation function to be used. By default 'relu'.

    References
    ----------
    .. [1] K. He, X. Zhang, S. Ren, and J. Sun, "Identity Mappings in Deep Residual Networks,"
           arXiv:1603.05027 [cs], Mar. 2016. https://arxiv.org/pdf/1603.05027.pdf.
    .. [2] K. He, X. Zhang, S. Ren, and J. Sun, "Deep Residual Learning for Image Recognition," in 2016 IEEE Conference
           on Computer Vision and Pattern Recognition (CVPR), 2016, pp. 770-778. https://arxiv.org/pdf/1512.03385.pdf
    """

    def __init__(self, n_samples_out, n_filters_out, kernel_initializer='he_normal',
                 dropout_rate=0.8, kernel_size=17, preactivation=True,
                 postactivation_bn=False, activation_function='relu'):
        self.n_samples_out = n_samples_out
        self.n_filters_out = n_filters_out
        self.kernel_initializer = kernel_initializer
        self.dropout_rate = dropout_rate
        self.kernel_size = kernel_size
        self.preactivation = preactivation
        self.postactivation_bn = postactivation_bn
        self.activation_function = activation_function

    @staticmethod
    def _dim_value(dim):
        """Return a static shape entry as a plain value.

        TF1-style shapes hold ``Dimension`` objects, read through ``.value``.
        TF2-style shapes already hold ints or None. Both are accepted.
        """
        return getattr(dim, 'value', dim)

    def _skip_connection(self, y, downsample, n_filters_in):
        """Implement skip connection.

        Max-pools ``y`` by ``downsample`` (if > 1) and projects it with a 1x1
        convolution when the number of filters changes.

        Raises
        ------
        ValueError
            If ``downsample < 1``.
        """
        if downsample > 1:
            y = MaxPooling1D(downsample, strides=downsample, padding='same')(y)
        elif downsample < 1:
            raise ValueError("Number of samples should always decrease.")
        # Deal with n_filters dimension increase
        if n_filters_in != self.n_filters_out:
            # This is one of the two alternatives presented in ResNet paper
            # Other option is to just fill the matrix with zeros.
            y = Conv1D(self.n_filters_out, 1, padding='same',
                       use_bias=False, kernel_initializer=self.kernel_initializer)(y)
        return y

    def _batch_norm_plus_activation(self, x):
        """Apply batch normalization and the activation, in the configured order."""
        if self.postactivation_bn:
            x = Activation(self.activation_function)(x)
            x = BatchNormalization(center=False, scale=False)(x)
        else:
            x = BatchNormalization()(x)
            x = Activation(self.activation_function)(x)
        return x

    def __call__(self, inputs):
        """Residual unit."""
        x, y = inputs
        # Support both TF1 (Dimension) and TF2 (int) static shapes.
        n_samples_in = self._dim_value(y.shape[1])
        downsample = n_samples_in // self.n_samples_out
        n_filters_in = self._dim_value(y.shape[2])
        y = self._skip_connection(y, downsample, n_filters_in)
        # 1st layer
        x = Conv1D(self.n_filters_out, self.kernel_size, padding='same',
                   use_bias=False, kernel_initializer=self.kernel_initializer)(x)
        x = self._batch_norm_plus_activation(x)
        if self.dropout_rate > 0:
            x = Dropout(self.dropout_rate)(x)

        # 2nd layer (strided, so it matches the downsampled skip path)
        x = Conv1D(self.n_filters_out, self.kernel_size, strides=downsample,
                   padding='same', use_bias=False,
                   kernel_initializer=self.kernel_initializer)(x)
        if self.preactivation:
            x = Add()([x, y])  # Sum skip connection and main connection
            y = x
            x = self._batch_norm_plus_activation(x)
            if self.dropout_rate > 0:
                x = Dropout(self.dropout_rate)(x)
        else:
            x = BatchNormalization()(x)
            x = Add()([x, y])  # Sum skip connection and main connection
            x = Activation(self.activation_function)(x)
            if self.dropout_rate > 0:
                x = Dropout(self.dropout_rate)(x)
            y = x
        return [x, y]



def get_f1(y_true, y_pred): #taken from old keras source code
    """Return the F1 score of rounded predictions, for use as a Keras metric.

    True positives, actual positives and predicted positives are counted over
    every element of the tensors passed in, after clipping to [0, 1] and
    rounding. ``K.epsilon()`` keeps every division finite when a count is
    zero.

    Parameters
    ----------
    y_true : tensor
        Binary ground-truth labels.
    y_pred : tensor
        Predicted probabilities, same shape as ``y_true``.

    Returns
    -------
    tensor
        Scalar F1 score.
    """
    # ``K`` is not imported anywhere at module level, so this metric raised
    # NameError when evaluated. Import the (standalone) Keras backend here,
    # matching the keras.layers / keras.models used by the models.
    from keras import backend as K

    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    recall = true_positives / (possible_positives + K.epsilon())
    f1_val = 2*(precision*recall)/(precision+recall+K.epsilon())
    return f1_val
    

from keras.preprocessing import sequence
from keras.models import Sequential
from keras.layers.wrappers import TimeDistributed, Bidirectional
from keras.layers import Dense, Dropout, Embedding, LSTM, Bidirectional,TimeDistributed
from keras.layers.core import Flatten
import tensorflow as tf 
from keras.layers import BatchNormalization, Lambda
from keras import initializers, regularizers
# Leaky ReLU (negative slope 0.1) wrapped in a Keras Lambda layer. No live
# code in this file uses it.
lrelu = Lambda(lambda tensor: tf.keras.activations.relu(tensor, alpha=0.1))
from sklearn.metrics import hamming_loss
from keras.layers import Input, Dense, LSTM, MaxPooling1D, Conv1D
from keras.models import Model
from keras.callbacks import History 
# Module-level Keras History callback instance; nothing in this file reads it.
history = History()
import numpy as np
from keras.layers.advanced_activations import PReLU


from tensorflow.keras import initializers


def compute_modified_confusion_matrix(labels, outputs):
    """Compute the multi-label "modified" confusion matrix.

    Rows are indexed by true class and columns by output class. For each
    recording ``i``, every pair (true class ``j`` with ``label[i, j] == 1``,
    output class ``k`` with a non-zero ``output[i, k]``) adds ``1 / n_i`` to
    ``A[j, k]``. Here ``n_i`` is the number of classes that are non-zero in
    the labels or the outputs of recording ``i``, with a minimum of 1.

    Parameters
    ----------
    labels : pandas.DataFrame or array_like, shape (num_recordings, num_classes)
        Ground-truth labels.
    outputs : pandas.DataFrame or array_like, same shape as ``labels``
        Binary (or truthy) model outputs.

    Returns
    -------
    numpy.ndarray, shape (num_classes, num_classes)
    """
    # Accept DataFrames (via .values) as well as plain arrays.
    label_values = np.asarray(getattr(labels, 'values', labels))
    output_values = np.asarray(getattr(outputs, 'values', outputs))

    is_label = (label_values == 1).astype(float)
    is_output = output_values != 0
    # Per-recording normalisation: number of classes active in labels or outputs.
    normalization = np.maximum(np.sum((label_values != 0) | is_output, axis=1), 1).astype(float)
    # A[j, k] = sum_i [label_ij == 1] * [output_ik != 0] / normalization_i.
    # This replaces the O(recordings * classes^2) Python loop, which is hot
    # inside the threshold search.
    return is_label.T @ (is_output / normalization[:, None])

# Compute the evaluation metric for the Challenge.
def compute_challenge_metric(th, weights, labels, outputss, classes, normal_class):
    """Return the *negated* normalised challenge score at threshold ``th``.

    The sign is flipped so that a minimiser (``scipy.optimize.fmin``) can be
    used to search for the threshold that maximises the score.

    Parameters
    ----------
    th : float or array_like
        Decision threshold; an output counts as positive when ``output > th``.
    weights : pandas.DataFrame or array_like, (num_classes, num_classes)
        Reward matrix. It is multiplied element-wise, by position, with the
        modified confusion matrix.
    labels : pandas.DataFrame, (num_recordings, num_classes)
        Binary ground-truth labels.
    outputss : pandas.DataFrame, (num_recordings, num_classes)
        Model scores.
    classes : list
        Class identifiers in column order.
    normal_class
        Element of ``classes`` used by the "always normal" baseline.

    Returns
    -------
    float
        ``-(observed - inactive) / (correct - inactive)``, or 0.0 when the
        correct and inactive scores are equal.
    """
    outputs = pd.DataFrame((outputss.values > th).astype(int))
    num_recordings, num_classes = np.shape(labels)
    normal_index = classes.index(normal_class)

    # Score of the thresholded model outputs.
    A = compute_modified_confusion_matrix(labels, outputs)
    observed_score = np.nansum(weights * A)

    # Score of a model that always outputs exactly the true labels.
    A = compute_modified_confusion_matrix(labels, labels)
    correct_score = np.nansum(weights * A)

    # Score of a model that always predicts only the normal class.
    # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the same dtype.
    inactive_outputs = np.zeros((num_recordings, num_classes), dtype=bool)
    inactive_outputs[:, normal_index] = 1
    A = compute_modified_confusion_matrix(labels, pd.DataFrame(inactive_outputs))
    inactive_score = np.nansum(weights * A)

    if correct_score != inactive_score:
        # Negated for minimisation.
        normalized_score = -(float(observed_score - inactive_score) / float(correct_score - inactive_score))
    else:
        normalized_score = 0.0

    return normalized_score

def get_machine_learning_model(trainX, trainy, testX, testy, num_epochs=1000, verbose=0, kernel_size=16,
                               output_size=27, weights_file='weights.csv'):
    """Train the Conv1D + LSTM template classifier and tune its threshold.

    Architecture:

    - Two Conv1D layers (64 filters, kernel 4, 'same' padding, no activation).
    - LSTM(32) and LSTM(12), both returning sequences.
    - Flatten, then a sigmoid Dense layer with ``output_size`` units.

    The model is compiled with the default 'adam' optimizer and binary
    cross-entropy. It is trained with early stopping on ``val_loss``
    (patience 8), using ``(testX, testy)`` as validation data. A single
    global decision threshold is then searched with ``scipy.optimize.fmin``,
    starting at 0.2, by minimising ``compute_challenge_metric`` on the
    validation predictions.

    Parameters
    ----------
    trainX, testX : numpy.ndarray, shape (n, 400, 1)
        Beat templates.
    trainy, testy : numpy.ndarray, shape (n, output_size)
        Multi-hot labels.
    num_epochs : int
        Maximum number of training epochs.
    verbose, kernel_size :
        Currently unused; kept for interface compatibility.
    output_size : int
        Number of output classes.
    weights_file : str, optional
        Path of the challenge reward-matrix CSV, read with ``index_col=0``.
        Its rows and columns are used by position, so they must follow the
        label column order.

    Returns
    -------
    model : keras Model
        The trained model.
    best_thr : numpy.ndarray
        The threshold returned by ``fmin``, multiplied by -1 (see NOTE below).
    """
    batch_size = 256
    input_layer = Input(shape=(400, 1))
    conv1 = Conv1D(filters=64, kernel_size=4, strides=1, padding='same')(input_layer)
    conv2 = Conv1D(filters=64, kernel_size=4, strides=1, padding='same')(conv1)
    lstm1 = LSTM(32, return_sequences=True)(conv2)
    lstm2 = LSTM(12, return_sequences=True)(lstm1)
    flatten = Flatten()(lstm2)
    output_layer = Dense(output_size, activation='sigmoid')(flatten)
    model = Model(inputs=input_layer, outputs=output_layer)

    # Default Adam settings; no learning-rate schedule is applied.
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
    # Use keras.callbacks.EarlyStopping (imported at the top of the file),
    # the same package as the layers/Model above and as
    # get_machine_learning_model_sa, instead of the tf.keras one.
    early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=8, verbose=0, mode='min')
    model.fit(trainX, trainy, validation_data=(testX, testy), epochs=num_epochs,
              batch_size=batch_size, callbacks=[early_stopping])

    ypred = model.predict(testX)
    weights_df = pd.read_csv(weights_file, index_col=0)
    labels = pd.DataFrame(testy)
    outputss = pd.DataFrame(ypred)
    classes = list(labels.keys())
    # Label column 21 is SNOMED 426783006 (sinus rhythm), the challenge's
    # normal class; column 20 is 426177001 (sinus bradycardia). The class
    # order is the one built in create_template_dataframe.
    normal_class = 21

    best_thr = scipy.optimize.fmin(compute_challenge_metric, x0=0.2,
                                   args=(weights_df, labels, outputss, classes, normal_class))
    # NOTE(review): fmin returns the minimising *threshold*, not the score.
    # compute_challenge_metric already negates the score, so negating the
    # threshold here yields a negative value, which every sigmoid output
    # exceeds. Kept because the consumer of 'best_thr' is not visible in
    # this file; confirm whether inference compensates before removing.
    best_thr = best_thr * -1
    return model, best_thr


def get_machine_learning_model_sa(trainX, trainy, num_epochs=2, verbose=1, kernel_size=16, output_size=27):
    """Build and train the residual Conv1D + LSTM template classifier.

    Architecture:

    - Stem: Conv1D(64, ``kernel_size``) without bias, batch normalization and
      a sigmoid activation.
    - Four ``ResidualUnit`` blocks (128, 196, 256 and 300 filters).
    - Two LSTM layers (32 and 12 units, returning sequences).
    - Flatten, then a sigmoid Dense layer with ``output_size`` units.

    Training uses Adam (lr 0.001) on binary cross-entropy with 20% of the
    training data held out for validation. Callbacks cover learning-rate
    reduction, early stopping, TensorBoard/CSV logging and last/best
    checkpoints.

    Returns
    -------
    keras Model
        The trained model.
    """
    init = 'he_normal'
    ecg_in = Input(shape=(400, 1), dtype=np.float32, name='signal')
    stem = Conv1D(64, kernel_size, padding='same', use_bias=False,
                  kernel_initializer=init)(ecg_in)
    stem = BatchNormalization()(stem)
    stem = Activation('sigmoid')(stem)

    main, skip = ResidualUnit(300, 128, kernel_size=kernel_size,
                              kernel_initializer=init)([stem, stem])
    main, skip = ResidualUnit(256, 196, kernel_size=kernel_size,
                              kernel_initializer=init)([main, skip])
    main, skip = ResidualUnit(64, 256, kernel_size=kernel_size,
                              kernel_initializer=init)([main, skip])
    main, _ = ResidualUnit(48, 300, kernel_size=kernel_size,
                           kernel_initializer=init)([main, skip])

    seq = LSTM(32, return_sequences=True)(main)
    seq = LSTM(12, return_sequences=True)(seq)
    flat = Flatten()(seq)
    diagn = Dense(output_size, activation='sigmoid', kernel_initializer=init)(flat)
    model = Model(ecg_in, diagn)

    lr = 0.001
    batch_size = 256
    model.compile(loss='binary_crossentropy', optimizer=Adam(lr), metrics=['accuracy'])
    callbacks = [
        ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=7, min_lr=lr / 100),
        # Patience should be larger than the one in ReduceLROnPlateau.
        EarlyStopping(patience=9, min_delta=0.00001),
        TensorBoard(log_dir='./logs', batch_size=batch_size, write_graph=False),
        CSVLogger('training.log', append=False),  # set append=True when continuing training
        # Keep both the most recent and the best model.
        ModelCheckpoint('./backup_model_last.hdf5'),
        ModelCheckpoint('./backup_model_best.hdf5', save_best_only=True),
    ]

    model.fit(trainX, trainy,
              batch_size=batch_size,
              epochs=num_epochs,
              initial_epoch=0,  # raise when continuing an interrupted run
              validation_split=0.2,
              shuffle='batch',
              callbacks=callbacks,
              verbose=verbose)
    return model


def get_train_test_data(df1, nclasses=27, test_size=1, random_state=None, template_length=400):
    """Split a label/template table into Keras-ready train and test arrays.

    ``df1`` must hold the ``nclasses`` label columns first and the
    ``template_length`` template columns last. This is the layout produced by
    ``create_template_dataframe``.

    Parameters
    ----------
    df1 : pandas.DataFrame
        Table of labels and templates, one row per record.
    nclasses : int, optional
        Number of leading label columns.
    test_size : float or int, optional
        Passed to ``sklearn.model_selection.train_test_split``. A float is a
        fraction of the rows; an int is an absolute number of rows.
    random_state : int or None, optional
        Seed for a reproducible split.
    template_length : int, optional
        Number of trailing template columns.

    Returns
    -------
    trainX, trainy, testX, testy : numpy.ndarray
        The X arrays have shape ``(n, template_length, 1)`` and the y arrays
        have shape ``(n, nclasses)``.
    """
    from sklearn.model_selection import train_test_split

    # NOTE(review): the default test_size=1 is an *integer*, so exactly ONE
    # row is held out. That single row then drives early stopping and
    # threshold tuning. If that is not deliberate, pass a fraction (e.g. 0.1).
    train, test = train_test_split(df1, test_size=test_size, random_state=random_state)

    def _to_arrays(frame):
        # Features get a trailing channel axis for Conv1D; labels stay 2-D.
        x = frame[frame.columns[-template_length:]].values
        y = frame[frame.columns[:nclasses]].values
        return x.reshape((x.shape[0], x.shape[1], 1)), y.reshape(y.shape[0], nclasses)

    trainX, trainy = _to_arrays(train)
    testX, testy = _to_arrays(test)
    print(trainX.shape, trainy.shape, testX.shape, testy.shape)

    return trainX, trainy, testX, testy



