#!/usr/bin/env python

import numpy as np
import os, sys
from scipy.io import loadmat
import scipy
import scipy.signal

# model package
import keras
from keras.layers import LeakyReLU, Dense, Dropout, Input, Convolution1D, Layer, Flatten
from keras.models import Model
from keras.layers.normalization import BatchNormalization
from keras import regularizers, initializers, constraints
from keras import backend as K
from keras.layers import AveragePooling1D
from keras.layers.merge import concatenate
from keras_multi_head import MultiHead


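# Data generator
# Each recording is cropped or zero-padded to a length of 30000 samples.
# Recordings not sampled at 500 Hz are resampled to 500 Hz first.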

# %% Batch generator class

class DataGenerator_resample(keras.utils.Sequence):
    """Keras ``Sequence`` that yields batches of prepared ECG recordings.

    For every record ID the generator loads the ``'val'`` array from the
    record file, reads the sampling frequency from the first line of the
    matching ``.hea`` header, resamples the signal to ``TARGET_FS`` when the
    rates differ, crops or front-zero-pads it to ``dim[0]`` samples and
    z-scores each lead.

    Args:
        list_IDs: Record file names, joined onto the input directory.
        labels: Container indexed by record ID; each entry fills one row of
            the ``(batch_size, n_classes)`` label array.
        batch_size: Number of recordings per batch.
        dim: ``(n_samples, n_leads)`` shape of one prepared recording.
        sequence_length: Stored on the instance; the batch shape itself is
            taken from ``dim``.
        n_classes: Width of the label array.
        shuffle: Reshuffle the sample order at the end of every epoch.
        input_directory: Directory containing the records. When ``None``
            (the default) the first command-line argument is used.
    """

    TARGET_FS = 500  # Hz; recordings at any other rate are resampled to this

    def __init__(self, list_IDs, labels, batch_size=64, dim=(30000, 12), sequence_length=30000,
                 n_classes=24, shuffle=True, input_directory=None):
        """Store the configuration and build the first epoch's sample order."""
        self.list_IDs = list_IDs
        self.labels = labels
        self.batch_size = batch_size
        self.dim = dim
        self.sequence_length = sequence_length
        self.n_classes = n_classes
        self.shuffle = shuffle
        # Only fall back to the command line when no directory was passed in,
        # so the generator can also be used outside a CLI entry point.
        self.input_directory = sys.argv[1] if input_directory is None else input_directory
        self.on_epoch_end()

    def on_epoch_end(self):
        """Rebuild the sample order, shuffling it when ``shuffle`` is set."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __len__(self):
        """Return the number of full batches per epoch (remainder dropped)."""
        return len(self.list_IDs) // self.batch_size

    def __getitem__(self, index):
        """Return the ``(X, y)`` pair for batch number ``index``."""
        start = index * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]
        batch_ids = [self.list_IDs[k] for k in batch_indexes]
        return self.__data_generation(batch_ids)

    def __data_generation(self, list_IDs_temp):
        """Load and prepare the given records.

        Returns:
            ``X`` of shape ``(batch_size, *dim)`` and ``y`` of shape
            ``(batch_size, n_classes)``.
        """
        X = np.empty((self.batch_size, *self.dim), dtype=float)
        y = np.empty((self.batch_size, self.n_classes), dtype=int)

        for i, ID in enumerate(list_IDs_temp):
            X[i, :, :] = self._load_recording(ID)
            y[i] = self.labels[ID]

        return X, y

    def _load_recording(self, ID):
        """Load one record and return it as a prepared ``dim``-shaped array."""
        record_path = os.path.join(self.input_directory, ID)
        data = np.asarray(loadmat(record_path)['val'], dtype=np.float64)
        fs = self._read_sampling_rate(self._header_path(record_path))

        if fs == self.TARGET_FS:
            ecg = data.T
        else:
            ecg = self._resample(data, fs)

        ecg = self._fit_length(ecg)
        self._zscore_inplace(ecg)
        return ecg

    @staticmethod
    def _header_path(record_path):
        """Return the ``.hea`` path that belongs to ``record_path``."""
        # Swap only the trailing extension; a plain str.replace would also
        # rewrite '.mat' occurring in a directory name.
        if record_path.endswith('.mat'):
            return record_path[:-len('.mat')] + '.hea'
        return record_path + '.hea'

    @staticmethod
    def _read_sampling_rate(header_path):
        """Return the integer in the third field of the header's first line."""
        with open(header_path, 'r') as f:
            # split() without an argument tolerates repeated whitespace
            return int(f.readline().split()[2])

    def _resample(self, data, fs):
        """Resample ``data`` (leads x samples) from ``fs`` to ``TARGET_FS``.

        Returns a float32 array laid out as samples x leads.
        """
        n_leads = self.dim[1]
        secs = data.shape[1] / fs  # duration of the recording in seconds
        samps = int(secs * self.TARGET_FS)  # sample count at the target rate
        ecg = np.zeros((samps, n_leads), dtype=np.float32)
        for lead in range(n_leads):
            ecg[:, lead] = scipy.signal.resample(data[lead, :], samps)
        return ecg

    def _fit_length(self, ecg):
        """Crop ``ecg`` to ``dim[0]`` samples or zero-pad it at the front."""
        n_samples, n_leads = self.dim
        if ecg.shape[0] > n_samples:
            return ecg[:n_samples, :]

        padded = np.zeros((n_samples, n_leads), dtype=np.float32)
        # padded[-0:] would address the whole array, so skip empty signals
        if ecg.shape[0]:
            padded[-ecg.shape[0]:, :] = ecg
        return padded

    def _zscore_inplace(self, ecg):
        """Standardise every lead of ``ecg`` in place; flat leads are left as is."""
        for lead in range(self.dim[1]):
            column = ecg[:, lead]
            sigma = np.std(column)
            if sigma == 0:
                continue  # constant lead: avoid dividing by zero
            ecg[:, lead] = (column - np.mean(column)) / sigma


# Module-level input tensor used by Inception_model: 30000 samples x 12 channels.
main_input = Input(shape=(30000, 12), dtype='float32', name='main_input')


def _conv_bn_lrelu(tensor, kernel_size, **conv_kwargs):
    """Apply a 12-filter Conv1D, batch normalisation and LeakyReLU(0.3)."""
    tensor = Convolution1D(12, kernel_size, **conv_kwargs)(tensor)
    tensor = BatchNormalization()(tensor)
    return LeakyReLU(alpha=0.3)(tensor)


def _inception_block(tensor, dropout_rate):
    """Build one four-branch block on ``tensor`` and return its dropout output.

    Every branch downsamples with a stride of 2 (branch c through average
    pooling) and ends in 12 filters; the branches are concatenated along the
    last axis. The shape of each branch and of the concatenation is printed.
    """
    # branch a: strided 1-wide convolution
    branch_a = _conv_bn_lrelu(tensor, 1, strides=2, padding='same')
    print(branch_a.shape)

    # branch b: 1-wide convolution followed by a strided 3-wide convolution
    branch_b = _conv_bn_lrelu(tensor, 1)
    branch_b = _conv_bn_lrelu(branch_b, 3, strides=2, padding='same')
    print(branch_b.shape)

    # branch c: strided average pooling followed by a 3-wide convolution
    branch_c = AveragePooling1D(3, strides=2, padding='same')(tensor)
    branch_c = _conv_bn_lrelu(branch_c, 3, padding='same')
    print(branch_c.shape)

    # branch d: two 3-wide convolutions, then a strided 24-wide convolution
    branch_d = _conv_bn_lrelu(tensor, 3, padding='same')
    branch_d = _conv_bn_lrelu(branch_d, 3, padding='same')
    branch_d = _conv_bn_lrelu(branch_d, 12 * 2, strides=2, padding='same')
    print(branch_d.shape)

    output = concatenate([branch_a, branch_b, branch_c, branch_d], axis=-1)
    print(output.shape)

    return Dropout(dropout_rate)(output)


def Inception_model(number):
    """Build the classifier: ``number`` Inception blocks plus attention heads.

    The blocks are stacked on the module-level ``main_input``. Their output
    goes through 24 ``AttentionWithContext`` heads wrapped in ``MultiHead``,
    then batch normalisation, LeakyReLU, dropout, flattening and a 24-unit
    sigmoid dense layer. The model summary is printed before returning.

    Args:
        number: How many Inception blocks to stack. With ``number <= 0`` no
            block is added and the attention heads read ``main_input``
            directly.

    Returns:
        The uncompiled Keras ``Model``.
    """
    dropout_ = 0.2

    # Starting from the input lets the first block share the code path of
    # every later block, and keeps `x` defined when no block is requested.
    x = main_input
    for _ in range(number):
        x = _inception_block(x, dropout_)

    x = MultiHead([AttentionWithContext() for _ in range(24)], name='Multi-Atts')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.3)(x)
    x = Dropout(0.2)(x)
    x = Flatten()(x)

    main_output = Dense(24, activation='sigmoid')(x)
    model = Model(inputs=main_input, outputs=main_output)
    # summary() prints by itself and returns None, so it is not wrapped in print()
    model.summary()

    return model


# attention_layer
def dot_product(x, kernel):
    """Return the backend dot product of ``x`` and ``kernel``.

    On the TensorFlow backend the kernel is given an extra trailing axis
    before the product, and that axis is squeezed out of the result again.
    On any other backend ``K.dot`` is applied directly.
    """
    if K.backend() != 'tensorflow':
        return K.dot(x, kernel)

    expanded_kernel = K.expand_dims(kernel)
    product = K.dot(x, expanded_kernel)
    return K.squeeze(product, axis=-1)


class AttentionWithContext(Layer):
    """Attention pooling over axis 1 of a 3D input using a context vector.

    For an input ``x`` the layer computes ``uit = tanh(x . W + b)``, scores
    every step with ``ait = uit . u``, turns the scores into weights with an
    exponential normalised over axis 1 (masked steps get weight zero), and
    returns the weighted sum of ``x`` over axis 1. The output shape is
    ``(input_shape[0], input_shape[-1])``.

    Args:
        W_regularizer, u_regularizer, b_regularizer: Regularizer identifiers
            for ``W``, ``u`` and ``b``.
        W_constraint, u_constraint, b_constraint: Constraint identifiers for
            ``W``, ``u`` and ``b``.
        bias: Whether to add the bias ``b`` before the ``tanh``.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self,
                 W_regularizer=None, u_regularizer=None, b_regularizer=None,
                 W_constraint=None, u_constraint=None, b_constraint=None,
                 bias=True, **kwargs):
        self.supports_masking = True
        self.init = initializers.get('glorot_uniform')
        self.W_regularizer = regularizers.get(W_regularizer)
        self.u_regularizer = regularizers.get(u_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.W_constraint = constraints.get(W_constraint)
        self.u_constraint = constraints.get(u_constraint)
        self.b_constraint = constraints.get(b_constraint)
        self.bias = bias
        super(AttentionWithContext, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create ``W``, the optional bias ``b`` and the context vector ``u``."""
        assert len(input_shape) == 3
        feature_dim = input_shape[-1]

        # The shape is passed by keyword so the call does not depend on the
        # positional order of add_weight's parameters.
        self.W = self.add_weight(shape=(feature_dim, feature_dim),
                                 initializer=self.init,
                                 name='{}_W'.format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        if self.bias:
            self.b = self.add_weight(shape=(feature_dim,),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)
        # call() always uses the context vector, so it must exist even when
        # the layer is built with bias=False.
        self.u = self.add_weight(shape=(feature_dim,),
                                 initializer=self.init,
                                 name='{}_u'.format(self.name),
                                 regularizer=self.u_regularizer,
                                 constraint=self.u_constraint)
        super(AttentionWithContext, self).build(input_shape)

    def compute_mask(self, input, input_mask=None):
        """Return ``None``: axis 1 is summed away in ``call``, so no mask is passed on."""
        return None

    def call(self, x, mask=None):
        """Return the attention-weighted sum of ``x`` over axis 1."""
        uit = dot_product(x, self.W)
        if self.bias:
            uit += self.b
        uit = K.tanh(uit)

        ait = dot_product(uit, self.u)
        a = K.exp(ait)
        if mask is not None:
            # zero the weight of masked steps before normalising
            a *= K.cast(mask, K.floatx())
        # epsilon keeps the division finite when every weight is zero
        a /= K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())

        a = K.expand_dims(a)
        weighted_input = x * a
        return K.sum(weighted_input, axis=1)

    def compute_output_shape(self, input_shape):
        """Return ``(input_shape[0], input_shape[-1])``."""
        return input_shape[0], input_shape[-1]

    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        config = {
            'W_regularizer': regularizers.serialize(self.W_regularizer),
            'u_regularizer': regularizers.serialize(self.u_regularizer),
            'b_regularizer': regularizers.serialize(self.b_regularizer),
            'W_constraint': constraints.serialize(self.W_constraint),
            'u_constraint': constraints.serialize(self.u_constraint),
            'b_constraint': constraints.serialize(self.b_constraint),
            'bias': self.bias,
        }
        base_config = super(AttentionWithContext, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

