"""
@author: easyG

CinC 2020 - Collection of useful functions
"""

import os
import numpy as np
import scipy as sp
import scipy.signal


from scipy.io import loadmat


import keras as ks
import keras.backend as K
from keras.models import Sequential, Model
from keras.layers.core import Dense
from keras.layers import Layer, concatenate, Convolution1D, GRU, CuDNNGRU, BatchNormalization, LeakyReLU, Dropout, Bidirectional
from keras import initializers, regularizers, constraints



from scipy.signal import butter, lfilter

def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter.

    Parameters
    ----------
    lowcut, highcut : float
        Lower and upper cut-off frequencies, in the same unit as ``fs``.
    fs : float
        Sampling frequency.
    order : int, optional
        Order passed to ``scipy.signal.butter`` (default 5).

    Returns
    -------
    tuple
        The ``(b, a)`` transfer-function coefficients returned by ``butter``.
    """
    # butter() expects the band edges as fractions of the Nyquist frequency.
    nyquist = 0.5 * fs
    normalized_band = [lowcut / nyquist, highcut / nyquist]
    return butter(order, normalized_band, btype='band')


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass filter ``data`` with a Butterworth filter.

    The coefficients come from ``butter_bandpass`` and are applied with
    ``scipy.signal.lfilter`` (a single forward pass).

    Parameters
    ----------
    data : array_like
        Signal to filter.
    lowcut, highcut : float
        Band edges, in the same unit as ``fs``.
    fs : float
        Sampling frequency.
    order : int, optional
        Filter order handed to ``butter_bandpass`` (default 5).

    Returns
    -------
    ndarray
        The filtered signal as returned by ``lfilter``.
    """
    numerator, denominator = butter_bandpass(lowcut, highcut, fs, order=order)
    return lfilter(numerator, denominator, data)

# Custom Attention Layer (CPSC2018)
class AttentionWithContext(Layer):
    """
    Custom Attention Layer from Chen et al. 2020
    Tsai-Min Chen, Chih-Han Huang, Edward S.C. Shih, Yu-Feng Hu, Ming-Jing Hwang,
    "Detection and Classification of Cardiac Arrhythmias by a Challenge-Best Deep Learning Neural Network Model"
    iScience, Volume 23, Issue 3, 2020
    http://2018.icbeb.org/Challenge.html

    The layer takes a 3-D input, computes one attention weight per position
    along axis 1 and returns the attention-weighted sum over that axis, i.e.
    an output of shape ``(input_shape[0], input_shape[-1])``.
    """

    def __init__(self,
                 W_regularizer=None, u_regularizer=None, b_regularizer=None,
                 W_constraint=None, u_constraint=None, b_constraint=None,
                 bias=True, **kwargs):
        """Store regularizers/constraints for the weights W, u and b.

        ``bias`` controls whether the bias vector ``b`` is created and added
        before the tanh non-linearity. Remaining keyword arguments are
        forwarded to the Keras ``Layer`` constructor.
        """
        self.supports_masking = True
        self.init = initializers.get('glorot_uniform')
        self.W_regularizer = regularizers.get(W_regularizer)
        self.u_regularizer = regularizers.get(u_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.W_constraint = constraints.get(W_constraint)
        self.u_constraint = constraints.get(u_constraint)
        self.b_constraint = constraints.get(b_constraint)
        self.bias = bias
        super(AttentionWithContext, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the weights W (square), optional bias b and context vector u."""
        assert len(input_shape) == 3
        self.W = self.add_weight(shape=(input_shape[-1], input_shape[-1],),
                                 initializer=self.init,
                                 name='{}_W'.format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        if self.bias:
            self.b = self.add_weight(shape=(input_shape[-1],),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)
        # The context vector is used unconditionally in call(), so it must be
        # created even when the layer is built with bias=False.
        self.u = self.add_weight(shape=(input_shape[-1],),
                                 initializer=self.init,
                                 name='{}_u'.format(self.name),
                                 regularizer=self.u_regularizer,
                                 constraint=self.u_constraint)
        super(AttentionWithContext, self).build(input_shape)

    def compute_mask(self, input, input_mask=None):
        """Return None: no mask is propagated to the following layers."""
        return None

    def call(self, x, mask=None):
        """Return the attention-weighted sum of ``x`` over axis 1."""
        # Hidden representation: tanh(x . W [+ b])
        uit = dot_product(x, self.W)
        if self.bias:
            uit += self.b
        uit = K.tanh(uit)
        # Unnormalised attention scores against the context vector u
        ait = dot_product(uit, self.u)
        a = K.exp(ait)
        if mask is not None:
            # Zero the scores of masked positions before normalising
            a *= K.cast(mask, K.floatx())
        # Normalise over axis 1; epsilon guards against division by zero
        a /= K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())
        a = K.expand_dims(a)
        weighted_input = x * a
        return K.sum(weighted_input, axis=1)

    def compute_output_shape(self, input_shape):
        """Axis 1 is summed out, leaving (first dim, last dim)."""
        return input_shape[0], input_shape[-1]

# Dot product calculation (CPSC2018)
def dot_product(x, kernel):
    """Backend-aware wrapper around ``K.dot``.

    On the TensorFlow backend the kernel gets an extra trailing axis before
    the product and that axis is squeezed away afterwards; every other
    backend calls ``K.dot`` directly.
    """
    if K.backend() != 'tensorflow':
        return K.dot(x, kernel)
    projected = K.dot(x, K.expand_dims(kernel))
    return K.squeeze(projected, axis=-1)
   
# Load challenge data.
def load_challenge_data(header_file):
    """Read a header file and the ``.mat`` recording stored next to it.

    Parameters
    ----------
    header_file : str
        Path to the header (``.hea``) file. The recording is expected under
        the same path with the extension swapped for ``.mat``.

    Returns
    -------
    recording : ndarray
        Content of the ``'val'`` variable of the MAT file as ``float64``.
    header : list of str
        Lines of the header file, as returned by ``readlines()``.
    """
    with open(header_file, 'r') as f:
        header = f.readlines()
    # Swap only the file extension; a plain str.replace would also rewrite
    # any '.hea' that happens to appear in a directory name.
    mat_file = os.path.splitext(header_file)[0] + '.mat'
    x = loadmat(mat_file)
    recording = np.asarray(x['val'], dtype=np.float64)
    return recording, header

# Decode labels from one-hot-vectors to categorical (for stratified k-fold shuffle split)
def decode_labels(labels):
    """Map every row of ``labels`` to the index of its unique pattern.

    Parameters
    ----------
    labels : array_like
        One label vector per row.

    Returns
    -------
    ndarray of float
        One entry per row of ``labels``: the position of that row within
        ``np.unique(labels, axis=0)`` (rows sorted as ``np.unique`` sorts
        them).
    """
    labels = np.asarray(labels)
    if labels.shape[0] == 0:
        return np.zeros(0)
    # return_inverse yields, for each row, its index in the unique rows.
    _, inverse = np.unique(labels, axis=0, return_inverse=True)
    # reshape(-1) keeps the result 1-D regardless of the NumPy version.
    return inverse.reshape(-1).astype(np.float64)

# Map reduced labels to full 27 diagnosis labels to have standardized vector for metric
def label_mapping(classes, scores, labels):
    """Expand reduced score/label matrices to the full scored-class layout.

    The full class list is read from column 1 of ``dx_mapping_scored.csv``
    (first row skipped as header). Column ``j`` of ``scores``/``labels`` is
    copied to the column of ``classes[j]`` in that list; all other columns
    stay zero.

    Returns
    -------
    scores_full, labels_full : ndarray
        Arrays with one column per entry of the CSV class list.
    """
    mapping_table = np.loadtxt('dx_mapping_scored.csv', delimiter = ',', dtype = object)
    all_codes = list(mapping_table[1:, 1])
    n_codes = len(all_codes)

    scores_full = np.zeros((scores.shape[0], n_codes))
    labels_full = np.zeros((labels.shape[0], n_codes))

    # Target column in the full layout for every reduced class
    target_columns = [all_codes.index(code) for code in classes]

    for row in range(scores.shape[0]):
        for source, target in enumerate(target_columns):
            scores_full[row, target] = scores[row, source]
            labels_full[row, target] = labels[row, source]

    return scores_full, labels_full

# Get raw signals, features and labels for a defined subject list
def prepare_data(header_files):
    """Load, preprocess and label every recording in ``header_files``.

    Only classes listed in ``dx_mapping_scored.csv`` whose count (from
    ``count_diagnoses``) exceeds 1 % of the number of files are kept.

    Parameters
    ----------
    header_files : sequence of str
        Paths to the header files of the recordings.

    Returns
    -------
    recordings : ndarray, shape (num_files, 5000, 12)
        Output of ``preprocess_ECG`` for each recording.
    features : ndarray, shape (num_files, 434, 12)
        Always zeros here; no feature extraction is performed in this function.
    labels : ndarray
        Multi-hot labels expanded to the full scored-class layout by
        ``label_mapping``.
    classes : list
        The class codes that were kept.
    """
    num_files = len(header_files)

    # Meta variables
    sequence_length = 5000
    downscale_fs = 1
    percent_to_keep = 1
    threshold_classes = round(num_files / 100 * percent_to_keep)

    # Clean classes to only include SCORED classes and those with enough samples to simplify the problem
    classes_meta = np.loadtxt('dx_mapping_scored.csv', delimiter = ',', dtype = object)
    scored_classes = classes_meta[1:, 1]

    counts = count_diagnoses(header_files)
    to_keep = counts > threshold_classes
    classes = list(scored_classes[to_keep])

    # Pre-allocation to speed up the code
    recordings = np.zeros((num_files, sequence_length, 12))
    features = np.zeros((num_files, 434, 12))
    labels = np.zeros((num_files, len(classes)))

    # Get data and header infos
    for i, header_file in enumerate(header_files):
        recording, header = load_challenge_data(header_file)
        # Second header line, third field, text before '/' -> scaling divisor
        res = int(header[1].split(" ")[2].split("/")[0])
        # First header line, third field -> sampling frequency
        fs = int(header[0].split(" ")[2])

        recordings[i] = preprocess_ECG(recording, sequence_length, fs, res, downscale_fs)

        # Start from an all-zero label vector for every record so a header
        # without a '#Dx:' line never inherits the previous record's labels.
        labels_act = np.zeros(len(classes))
        for l in header:
            if not l.startswith('#Dx:'):
                continue
            labels_act = np.zeros(len(classes))
            arrs = l.strip().split(' ')
            if len(arrs) < 2:
                # '#Dx:' line without any code
                continue
            for arr in arrs[1].split(','):
                code = arr.rstrip()
                # Codes outside the kept classes are ignored on purpose
                if code in classes:
                    labels_act[classes.index(code)] = 1
        labels[i] = labels_act

    # label_mapping returns (scores_full, labels_full); the labels are passed
    # for both inputs, so the first result is the expanded label matrix.
    labels, _ = label_mapping(classes, labels, labels)

    return recordings, features, labels, classes

# Find unique classes.
def get_classes(input_directory, filenames):
    """Collect the distinct diagnosis codes found in the given header files.

    Every line starting with ``#Dx`` is split at ``': '`` and the part after
    it is split at commas; the stripped pieces are the codes.

    Parameters
    ----------
    input_directory : str
        Not used by this function.
    filenames : iterable of str
        Paths of the header files to scan.

    Returns
    -------
    list of str
        Sorted unique codes.
    """
    found = set()
    for path in filenames:
        with open(path, 'r') as handle:
            dx_lines = [line for line in handle if line.startswith('#Dx')]
        for line in dx_lines:
            found.update(code.strip() for code in line.split(': ')[1].split(','))
    return sorted(found)

def filtering(data, fs):
    """Band-pass ``data`` between 0.1 and 30 with a 2nd-order Butterworth filter.

    ``fs`` is the sampling frequency of ``data``; the cut-offs are in the
    same unit.
    """
    return butter_bandpass_filter(data, 0.1, 30, fs, order=2)

def zero_padding(data, sequence_length):
    """Left-pad ``data`` with zeros along axis 1 up to ``sequence_length``.

    Parameters
    ----------
    data : ndarray, 2-D
        Array whose second axis is padded.
    sequence_length : int
        Target length of the second axis.

    Returns
    -------
    ndarray
        ``data`` unchanged if it already has at least ``sequence_length``
        columns, otherwise a new array with zeros prepended.
    """
    diff = sequence_length - data.shape[1]
    if diff > 0:
        # Row count follows the input instead of being fixed to 12.
        padding = np.zeros((data.shape[0], diff))
        data = np.hstack((padding, data))
    return data

def preprocess_ECG(data, sequence_length, fs, res, downscale_fs = 1):
    """Crop, scale, filter and pad one recording.

    Parameters
    ----------
    data : ndarray, 2-D
        Raw recording; the second axis is cropped and padded.
    sequence_length : int
        Length of the second axis after preprocessing.
    fs : float
        Sampling frequency of ``data``.
    res : float
        Divisor applied to the raw values.
    downscale_fs : int, optional
        Decimation factor; values above 1 trigger ``scipy.signal.decimate``.

    Returns
    -------
    ndarray
        The processed recording with the first two axes swapped and a new
        leading axis of length 1.
    """
    data = data[:, :sequence_length*downscale_fs]

    # Scale data to mV. Done out of place: the slice above is a view, so an
    # in-place division would modify the caller's array (and it fails for
    # integer input).
    data = data / res

    # Downscale if wanted
    if downscale_fs > 1:
        print("downsampled")
        data = sp.signal.decimate(data, downscale_fs)

    # Filtering at the (possibly reduced) sampling frequency
    data = filtering(data, fs/downscale_fs)

    # Zero pad for shorter ECG recordings
    data = zero_padding(data, sequence_length)

    # Switch array dimensions and add the leading axis
    data = data.swapaxes(0, 1)
    data = np.expand_dims(data, 0)

    return data

#def preprocess_features(features):
    


# Create Multilayer-Perceptron for the feature input, expand in future submissions
def mlp(dim):
    """Build the feature-stream MLP.

    Three stages of Dense -> LeakyReLU -> Dropout -> BatchNormalization with
    128, 64 and 24 units. ``dim`` is the input dimension of the first Dense
    layer. No output/classification layer is attached; the model ends at the
    last BatchNormalization.
    """
    model = Sequential()

    # (units, LeakyReLU alpha, dropout rate) for each stage
    stages = ((128, 0.4, 0.3), (64, 0.3, 0.2), (24, 0.3, 0.1))

    for position, (units, alpha, rate) in enumerate(stages):
        if position == 0:
            # Only the first layer declares the input dimension
            model.add(Dense(units, input_dim = dim))
        else:
            model.add(Dense(units))
        model.add(LeakyReLU(alpha = alpha))
        model.add(Dropout(rate))
        model.add(BatchNormalization())

    return model

# Create Deep Learning Network for raw signal input (modelled after CPSC2018 Winner)
def dnn(sequence_length, n_leads=12, n_classes=27):
    """Build the raw-signal network (CuDNN GRU variant).

    Architecture: five convolutional blocks, each made of two kernel-3
    convolutions followed by a stride-2 convolution (kernel 24, 48 in the
    last block) and dropout; then a bidirectional CuDNNGRU, the
    AttentionWithContext layer and a sigmoid Dense output.

    Parameters
    ----------
    sequence_length : int
        Number of time steps of the input.
    n_leads : int, optional
        Number of input channels (default 12).
    n_classes : int, optional
        Number of sigmoid outputs (default 27).

    Returns
    -------
    keras.models.Sequential
        The uncompiled model.
    """
    model = Sequential()
    n_filters = 12

    # Kernel size of the strided convolution that closes each block
    strided_kernels = (24, 24, 24, 24, 48)

    for block_idx, strided_kernel in enumerate(strided_kernels):
        for conv_idx in range(2):
            if block_idx == 0 and conv_idx == 0:
                # The very first layer declares the input shape
                model.add(Convolution1D(filters = n_filters, kernel_size = 3, padding = 'same',
                                        kernel_initializer = 'he_normal',
                                        input_shape = (sequence_length, n_leads)))
            else:
                model.add(Convolution1D(n_filters, 3, padding = 'same', kernel_initializer = 'he_normal'))
            model.add(LeakyReLU(alpha = 0.3))
        model.add(Convolution1D(n_filters, strided_kernel, strides = 2, padding = 'same',
                                kernel_initializer = 'he_normal'))
        model.add(LeakyReLU(alpha = 0.3))
        model.add(Dropout(0.2))

    # Bidirectional Recurrent Layer
    model.add(Bidirectional(CuDNNGRU(15, return_sequences=True, return_state=False, kernel_initializer = 'orthogonal')))
    model.add(LeakyReLU(alpha = 0.3))
    model.add(Dropout(0.2))

    # Custom Attention layer for dimensionality reduction
    model.add(AttentionWithContext())
    model.add(LeakyReLU(alpha = 0.3))
    model.add(Dropout(0.2))
    model.add(BatchNormalization())

    model.add(Dense(n_classes, activation = 'sigmoid'))

    return model

def dnn_cpu(sequence_length, n_leads=12, n_classes=27):
    """Build the raw-signal network with a plain GRU instead of CuDNNGRU.

    Architecture: five convolutional blocks, each made of two kernel-3
    convolutions followed by a stride-2 convolution (kernel 24, 48 in the
    last block) and dropout; then a bidirectional GRU, the
    AttentionWithContext layer and a sigmoid Dense output.

    Parameters
    ----------
    sequence_length : int
        Number of time steps of the input.
    n_leads : int, optional
        Number of input channels (default 12).
    n_classes : int, optional
        Number of sigmoid outputs (default 27).

    Returns
    -------
    keras.models.Sequential
        The uncompiled model.
    """
    model = Sequential()
    n_filters = 12

    # Kernel size of the strided convolution that closes each block
    strided_kernels = (24, 24, 24, 24, 48)

    for block_idx, strided_kernel in enumerate(strided_kernels):
        for conv_idx in range(2):
            if block_idx == 0 and conv_idx == 0:
                # The very first layer declares the input shape
                model.add(Convolution1D(filters = n_filters, kernel_size = 3, padding = 'same',
                                        kernel_initializer = 'he_normal',
                                        input_shape = (sequence_length, n_leads)))
            else:
                model.add(Convolution1D(n_filters, 3, padding = 'same', kernel_initializer = 'he_normal'))
            model.add(LeakyReLU(alpha = 0.3))
        model.add(Convolution1D(n_filters, strided_kernel, strides = 2, padding = 'same',
                                kernel_initializer = 'he_normal'))
        model.add(LeakyReLU(alpha = 0.3))
        model.add(Dropout(0.2))

    # Bidirectional Recurrent Layer
    # reset_after=True is the CuDNN-compatible GRU convention (see Keras GRU docs)
    model.add(Bidirectional(GRU(15, return_sequences=True, reset_after = True, return_state=False, kernel_initializer = 'orthogonal')))
    model.add(LeakyReLU(alpha = 0.3))
    model.add(Dropout(0.2))

    # Custom Attention layer for dimensionality reduction
    model.add(AttentionWithContext())
    model.add(LeakyReLU(alpha = 0.3))
    model.add(Dropout(0.2))
    model.add(BatchNormalization())

    model.add(Dense(n_classes, activation = 'sigmoid'))

    return model

# Create mixed-data Neural Network combining the MLP and DNN
def mixed_nn(dim, sequence_length):
    """Combine the feature MLP and the raw-signal DNN into one model.

    The outputs of ``mlp(dim)`` and ``dnn(sequence_length)`` are concatenated
    and fed to a 27-unit sigmoid Dense layer. The resulting model takes two
    inputs, in the order [MLP input, DNN input].
    """
    # Build both streams (feature stream first, then signal stream)
    feature_branch = mlp(dim)
    signal_branch = dnn(sequence_length)

    # Joint classification layer on top of the merged stream outputs
    merged = concatenate([feature_branch.output, signal_branch.output])
    prediction = Dense(27, activation="sigmoid")(merged)

    return Model(inputs=[feature_branch.input, signal_branch.input], outputs = prediction)

# Create mixed-data Neural Network combining the MLP and the CPU-compatible DNN (dnn_cpu)
def mixed_nn_cpu(dim, sequence_length):
    """Combine the feature MLP and the ``dnn_cpu`` signal network.

    The outputs of ``mlp(dim)`` and ``dnn_cpu(sequence_length)`` are
    concatenated and fed to a 27-unit sigmoid Dense layer. The resulting
    model takes two inputs, in the order [MLP input, DNN input].
    """
    # Build both streams (feature stream first, then signal stream)
    feature_branch = mlp(dim)
    signal_branch = dnn_cpu(sequence_length)

    # Joint classification layer on top of the merged stream outputs
    merged = concatenate([feature_branch.output, signal_branch.output])
    prediction = Dense(27, activation="sigmoid")(merged)

    return Model(inputs=[feature_branch.input, signal_branch.input], outputs = prediction)

def count_diagnoses(header_files):
    """Count how often each scored class occurs in the given header files.

    Diagnosis codes are read from every line that starts with ``#Dx``
    (comma-separated codes after the colon), wherever that line sits in the
    header. The scored classes come from column 1 of
    ``dx_mapping_scored.csv`` (first row skipped as header).

    Parameters
    ----------
    header_files : iterable of str
        Paths of the header files to scan.

    Returns
    -------
    ndarray of float
        One count per scored class, in the order of the CSV file.
    """
    tally = {}

    for header_file in header_files:
        with open(header_file, 'r') as f:
            for line in f:
                # Locate the diagnosis line by its prefix rather than by a
                # fixed line number, which depends on the header layout.
                if not line.startswith('#Dx'):
                    continue
                payload = line.partition(':')[2]
                for code in payload.split(','):
                    code = code.strip()
                    if code:
                        tally[code] = tally.get(code, 0) + 1

    classes_meta = np.loadtxt('dx_mapping_scored.csv', delimiter = ',', dtype = object)
    scored_classes = classes_meta[1:, 1]

    return np.array([tally.get(code, 0) for code in scored_classes], dtype=float)
    
    

def scores_to_label(scores, threshold=0.4):
    """Binarise scores: 1 where a score is strictly above ``threshold``.

    Parameters
    ----------
    scores : array_like
        Scores of any shape.
    threshold : float, optional
        Decision threshold (default 0.4).

    Returns
    -------
    ndarray of int
        Array of 0/1 with the same shape as ``scores``.
    """
    scores = np.asarray(scores)
    return (scores > threshold).astype(int)

def convert_model(dataset_id, directory, epoch):
    """Re-save a trained model as its ``dnn_cpu`` counterpart.

    Loads ``<directory><dataset_id>_model_<epoch>.h5``, builds a ``dnn_cpu``
    model with the same sequence length, copies the weights over via a
    temporary weights file and saves the result as
    ``<directory><dataset_id>_model_<epoch>_cpu.h5``.

    Parameters
    ----------
    dataset_id : str
        Prefix of the model file name; also used as the name of the new model.
    directory : str
        Path prefix, concatenated as-is with the file name (so it needs its
        own trailing separator).
    epoch : int or str
        Epoch identifier that is part of the file name.
    """
    model_file = directory + dataset_id + "_model_" + str(epoch) + ".h5"

    with ks.utils.CustomObjectScope({'AttentionWithContext': AttentionWithContext}):
        model = ks.models.load_model(model_file)

    # K.int_shape returns plain ints, so this does not rely on the
    # Dimension.value attribute.
    sequence_length = K.int_shape(model.inputs[0])[1]

    # Only the sequence length is passed; lead and class counts are left to
    # dnn_cpu's own settings. A mismatch with the loaded model would show up
    # as a shape error in load_weights below.
    cpu_model = dnn_cpu(sequence_length)

    weights_file = 'cuda.h5'
    model.save_weights(weights_file)
    try:
        cpu_model.load_weights(weights_file)
    finally:
        # Remove the temporary weights file even if loading fails
        os.remove(weights_file)

    cpu_model.name = dataset_id
    cpu_model.save(directory + dataset_id + '_model_' + str(epoch) + '_cpu.h5')
    
"""
if __name__ == '__main__':
    input_directory = "C:/data/cinc-2020/3_stpetersburg/"
    header_files = []
    res = []
    for f in os.listdir(input_directory):
        g = os.path.join(input_directory, f)
        if not f.lower().startswith('.') and f.lower().endswith('hea') and os.path.isfile(g):
            header_files.append(g)
    num_files = len(header_files)
    for header in header_files:
        with open(header,'r') as f:
            print(header)
            header_data=f.readlines()
            res.append(header_data[1].split(" ")[2].split("/")[0])
"""