#!/usr/bin/env python

import numpy as np, os
from get_12ECG_features import get_12ECG_features
from keras import initializers
from keras.layers import Layer
from bert4keras.backend import keras, K
from bert4keras.backend import sequence_masking
from bert4keras.backend import recompute_grad
from keras.preprocessing import sequence
import scipy
from keras.preprocessing import sequence
import tensorflow as tf
from biosppy.signals import ecg
from keras.layers import Dense,Input,GlobalAveragePooling1D,Dropout,Add,Activation
from keras.models import Model
from keras.layers.normalization import BatchNormalization

def _prepare_ecg_window(ecg_data, fs):
    """Convert one raw ECG excerpt into a model-ready, time-major array.

    Steps, in order:
      1. Resample along axis 1 to 500 Hz when ``fs`` is not already 500.
      2. Detect and correct R-peaks on row 1 of the excerpt
         (presumably lead II -- confirm against the data layout).
      3. If a peak was found and the first one lies within the first 1000
         samples, drop everything before it so the window starts on a beat.
      4. Pad/truncate every row to 5000 samples with ``pad_sequences``
         (Keras default padding/truncating settings) and transpose.

    Args:
        ecg_data: 2-D array indexed as ``[lead, sample]``.
        fs: Sampling frequency of ``ecg_data`` in Hz.

    Returns:
        A float64 array of shape ``(5000, n_leads)``.
    """
    # Imported locally so the ``scipy.signal`` submodule is guaranteed to be
    # loaded; a bare ``import scipy`` does not promise that on its own.
    from scipy.signal import resample

    if fs != 500:
        target_len = round(np.size(ecg_data, 1) * 500 / fs)
        ecg_data = resample(ecg_data, target_len, axis=1)

    rpeaks, = ecg.hamilton_segmenter(signal=ecg_data[1, :], sampling_rate=500)
    rpeaks, = ecg.correct_rpeaks(ecg_data[1, :], rpeaks=rpeaks,
                                 sampling_rate=500)
    if len(rpeaks) != 0 and rpeaks[0] < 1000:
        ecg_data = ecg_data[:, rpeaks[0]:]

    padded = sequence.pad_sequences(ecg_data, 5000, dtype='float64')
    return np.transpose(padded)


def run_12ECG_classifier(data, header_data, loaded_model):
    """Classify one ECG recording with the loaded model.

    Records of at least 20 seconds are cut into consecutive, non-overlapping
    10-second windows (a trailing partial window is ignored); shorter records
    are processed as a single window. Each window is preprocessed, scored by
    the model, and the per-window scores are averaged.

    Args:
        data: 2-D array indexed as ``[lead, sample]``.
        header_data: Iterable of header lines; the third space-separated
            field of the first line is read as the sampling frequency.
        loaded_model: Object exposing ``predict`` (as returned by
            ``load_12ECG_model``).

    Returns:
        Tuple ``(current_label, current_score, classes)`` where
        ``current_label`` is an int array with 1 for every class that was the
        top prediction in at least one window, ``current_score`` is the mean
        of the per-window model outputs, and ``classes`` is the list of class
        names read from ``classes.npy`` in the working directory.

    Raises:
        ValueError: If ``header_data`` contains no lines.
    """
    model = loaded_model
    classes = [str(name) for name in np.load('classes.npy')]
    current_label = np.zeros(len(classes), dtype=int)
    current_score = np.zeros(len(classes))

    first_line = next(iter(header_data), None)
    if first_line is None:
        raise ValueError(
            'header_data is empty; cannot determine the sampling frequency')
    fs = int(first_line.split(' ')[2])

    n_samples = np.size(data, 1)
    if n_samples >= 20 * fs:
        window = 10 * fs
        segments = [data[:, k * window:(k + 1) * window]
                    for k in range(n_samples // window)]
    else:
        segments = [data]

    for segment in segments:
        ecg_window = _prepare_ecg_window(segment, fs)
        score = model.predict(np.array([ecg_window]))
        current_label[np.argmax(score[0])] = 1
        current_score += score[0]

    # Average over the number of windows actually scored. (Dividing by the
    # last loop index, as was done previously, is off by one and leaves a
    # two-window record un-averaged.)
    current_score /= len(segments)

    return current_label, current_score, classes

def load_12ECG_model(input_directory):
    """Rebuild the network architecture and load its trained weights.

    The graph is: ResNet feature extractor -> additive position encoding ->
    8-head self-attention (head size 64) -> global average pooling ->
    dropout -> 64-unit ReLU layer -> 27-way softmax.

    Args:
        input_directory: Directory containing the ``model-27`` weights.

    Returns:
        The Keras ``Model`` with weights loaded.
    """
    weights_path = os.path.join(input_directory, 'model-27')

    ecg_input = Input(shape=(5000, 12), name="inputs")
    features = add_resnet_layers(ecg_input)

    # Add the position encoding onto the convolutional features.
    positions = PositionEncoding(128)(features)
    positioned = Add()([features, positions])

    # Self-attention: the same tensor serves as query, key and value.
    attended = MultiHeadAttention(8, 64)([positioned, positioned, positioned])

    pooled = GlobalAveragePooling1D()(attended)
    pooled = Dropout(0.2)(pooled)
    hidden = Dense(64, activation='relu')(pooled)
    probabilities = Dense(27, activation='softmax')(hidden)

    model = Model(inputs=ecg_input, outputs=probabilities)
    model.load_weights(weights_path)
    return model
def _bn_relu(layer, dropout=0, **params):
    """Apply a ReLU activation, optionally followed by dropout.

    Despite the name, no batch normalisation is added here; the callers in
    this file insert ``BatchNormalization`` themselves.

    Args:
        layer: Input Keras tensor.
        dropout: Dropout rate. Dropout is only added when this is > 0.
        **params: Accepted for call-site uniformity; unused.

    Returns:
        The resulting Keras tensor.
    """
    layer = Activation('relu')(layer)

    if dropout > 0:
        # Honour the requested rate instead of a hard-coded 0.5.
        layer = Dropout(dropout)(layer)

    return layer

def add_conv_weight(layer, filter_length, num_filters,
                    subsample_length=1, **params):
    """Apply a 'same'-padded 1-D convolution to ``layer``.

    Args:
        layer: Input Keras tensor.
        filter_length: Convolution kernel size.
        num_filters: Number of output filters.
        subsample_length: Convolution stride.
        **params: Accepted for call-site uniformity; unused.

    Returns:
        The convolved Keras tensor.
    """
    from keras.layers import Conv1D

    conv = Conv1D(num_filters, filter_length,
                  strides=subsample_length, padding='same')
    return conv(layer)

def resnet_block(layer, num_filters, subsample_length, block_index, **params):
    """Append one residual block (two convolutions plus a shortcut).

    The shortcut is a max-pooled copy of the input. On every fourth block
    after the first (``block_index`` 4, 8, ...) its channel axis is doubled
    by concatenating zeros.

    Args:
        layer: Input Keras tensor.
        num_filters: Number of filters for both convolutions.
        subsample_length: Stride of the first convolution and pool size of
            the shortcut.
        block_index: Position of this block in the stack.
        **params: Forwarded to ``_bn_relu`` and ``add_conv_weight``.

    Returns:
        The sum of the shortcut and the convolutional branch.
    """
    from keras.layers import Add
    from keras.layers import MaxPooling1D
    from keras.layers.core import Lambda

    def _append_zero_channels(tensor):
        # Double the last axis by concatenating an all-zero copy.
        return K.concatenate([tensor, K.zeros_like(tensor)], axis=2)

    def _append_zero_channels_shape(input_shape):
        dims = list(input_shape)
        assert len(dims) == 3
        dims[2] = dims[2] * 2
        return tuple(dims)

    skip = MaxPooling1D(pool_size=subsample_length, padding="same")(layer)
    pad_shortcut = (block_index % 4) == 0 and block_index > 0
    if pad_shortcut is True:
        skip = Lambda(_append_zero_channels,
                      output_shape=_append_zero_channels_shape)(skip)

    for conv_idx in range(2):
        # The very first convolution of the whole stack gets no activation.
        if block_index != 0 or conv_idx != 0:
            rate = 0.5 if conv_idx > 0 else 0
            layer = _bn_relu(layer, dropout=rate, **params)
        # Only the first convolution of a block subsamples.
        stride = subsample_length if conv_idx == 0 else 1
        layer = add_conv_weight(layer, 16, num_filters, stride, **params)
        layer = BatchNormalization()(layer)

    return Add()([skip, layer])

def get_num_filters_at_index(index, num_start_filters, **params):
    """Return the filter count for the block at ``index``.

    The count starts at ``num_start_filters`` and doubles every four blocks.
    ``**params`` is accepted for call-site uniformity and is unused.
    """
    doublings = int(index / 4)
    return (2 ** doublings) * num_start_filters

def add_resnet_layers(layer, **params):
    """Build the convolutional feature extractor on top of ``layer``.

    A stem (convolution, batch normalisation, ReLU) is followed by ten
    residual blocks whose strides alternate 1, 2, and a final ReLU.

    Args:
        layer: Input Keras tensor.
        **params: Forwarded to the helper builders.

    Returns:
        The output Keras tensor of the stack.
    """
    # Stem: 32 filters of width 16, no subsampling.
    stem = add_conv_weight(layer, 16, 32, subsample_length=1, **params)
    stem = BatchNormalization()(stem)
    features = _bn_relu(stem, **params)

    block_strides = (1, 2) * 5
    for block_idx, stride in enumerate(block_strides):
        width = get_num_filters_at_index(block_idx, 32, **params)
        features = resnet_block(features, width, stride, block_idx, **params)

    return _bn_relu(features, **params)
class MultiHeadAttention(Layer):
    """Multi-head attention mechanism.

    Args:
        heads: Number of attention heads.
        head_size: Size of each head's value vectors; the output width is
            ``heads * head_size``.
        key_size: Size of each head's query/key vectors; defaults to
            ``head_size``.
        use_bias: Whether the projection layers use a bias.
        attention_scale: Whether to divide the logits by ``sqrt(key_size)``.
        kernel_initializer: Initializer for the projection kernels.
    """
    def __init__(
        self,
        heads,
        head_size,
        key_size=None,
        use_bias=True,
        attention_scale=True,
        kernel_initializer='glorot_uniform',
        **kwargs
    ):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.heads = heads
        self.head_size = head_size
        self.out_dim = heads * head_size
        self.key_size = key_size or head_size
        self.use_bias = use_bias
        self.attention_scale = attention_scale
        self.kernel_initializer = initializers.get(kernel_initializer)

    def build(self, input_shape):
        """Create the query, key, value and output projection layers."""
        super(MultiHeadAttention, self).build(input_shape)
        self.q_dense = Dense(
            units=self.key_size * self.heads,
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer
        )
        self.k_dense = Dense(
            units=self.key_size * self.heads,
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer
        )
        self.v_dense = Dense(
            units=self.out_dim,
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer
        )
        self.o_dense = Dense(
            units=self.out_dim,
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer
        )

    @recompute_grad
    def call(self, inputs, mask=None, a_mask=None, p_bias=None):
        """Compute multi-head attention.

        ``inputs`` is a list whose first three items are the query, key and
        value tensors. When ``a_mask`` is truthy the next item is the
        attention-mask tensor, and when ``p_bias`` is set the item after
        that holds the position embeddings.

        q_mask: mask for the input query sequence.
                Mainly used to zero out the padded part of the output.
        v_mask: mask for the input value sequence.
                Mainly used to keep attention from reading padding.
        a_mask: mask applied to the attention matrix.
                Different attention masks serve different applications.
        p_bias: position bias inside the attention.
                Usually selects the kind of relative position encoding
                ('typical_relative' or 't5_relative').
        """
        q, k, v = inputs[:3]
        q_mask, v_mask, n = None, None, 3
        if mask is not None:
            if mask[0] is not None:
                q_mask = K.cast(mask[0], K.floatx())
            if mask[2] is not None:
                v_mask = K.cast(mask[2], K.floatx())
        if a_mask:
            a_mask = inputs[n]
            n += 1
        else:
            # Normalise any falsy flag (e.g. False) to None; otherwise the
            # ``is not None`` test below would subtract 1e12 from every
            # logit and wipe out the attention scores.
            a_mask = None
        # Linear projections
        qw = self.q_dense(q)
        kw = self.k_dense(k)
        vw = self.v_dense(v)
        # Reshape to (batch, seq, heads, size)
        qw = K.reshape(qw, (-1, K.shape(q)[1], self.heads, self.key_size))
        kw = K.reshape(kw, (-1, K.shape(k)[1], self.heads, self.key_size))
        vw = K.reshape(vw, (-1, K.shape(v)[1], self.heads, self.head_size))
        # Attention
        a = tf.einsum('bjhd,bkhd->bhjk', qw, kw)
        # Handle the position encoding
        if p_bias == 'typical_relative':
            pos_embeddings = inputs[n]
            a = a + tf.einsum('bjhd,jkd->bhjk', qw, pos_embeddings)
        elif p_bias == 't5_relative':
            pos_embeddings = K.permute_dimensions(inputs[n], (2, 0, 1))
            a = a + K.expand_dims(pos_embeddings, 0)
        # Attention (continued)
        if self.attention_scale:
            a = a / self.key_size**0.5
        a = sequence_masking(a, v_mask, 1, -1)
        if a_mask is not None:
            a = a - (1 - a_mask) * 1e12
        a = K.softmax(a)
        # Produce the output
        o = tf.einsum('bhjk,bkhd->bjhd', a, vw)
        if p_bias == 'typical_relative':
            o = o + tf.einsum('bhjk,jkd->bjhd', a, pos_embeddings)
        o = K.reshape(o, (-1, K.shape(o)[1], self.out_dim))
        o = self.o_dense(o)
        # Return the result
        o = sequence_masking(o, q_mask, 0)
        return o

    def compute_output_shape(self, input_shape):
        """Return (batch, query length, heads * head_size)."""
        return (input_shape[0][0], input_shape[0][1], self.out_dim)

    def compute_mask(self, inputs, mask=None):
        """Propagate the query mask, or None when no mask was given."""
        if mask is not None:
            return mask[0]
        return None

    def get_config(self):
        """Return the constructor arguments merged into the base config."""
        config = {
            'heads': self.heads,
            'head_size': self.head_size,
            'key_size': self.key_size,
            'use_bias': self.use_bias,
            'attention_scale': self.attention_scale,
            'kernel_initializer':
                initializers.serialize(self.kernel_initializer),
        }
        base_config = super(MultiHeadAttention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
class PositionEncoding(Layer):
    """Fixed sinusoidal position encoding.

    ``call`` ignores the values of its input and uses only its static
    sequence length. It returns a ``(seq_length, model_dim)`` float32 tensor
    with no batch axis, so adding it to a batched tensor relies on
    broadcasting.

    Args:
        model_dim: Width of the encoding (number of columns).
    """

    def __init__(self, model_dim, **kwargs):
        self._model_dim = model_dim
        super(PositionEncoding, self).__init__(**kwargs)

    def call(self, inputs):
        """Build the encoding table for the input's sequence length."""
        # int() accepts both plain ints and TensorFlow Dimension objects.
        seq_length = int(inputs.shape[1])

        # Vectorised form of pos / 10000 ** ((i - i % 2) / model_dim):
        # each even/odd column pair shares one frequency.
        positions = np.arange(seq_length, dtype='float64')[:, np.newaxis]
        dims = np.arange(self._model_dim)
        scales = np.power(10000, (dims - dims % 2) / self._model_dim)
        position_encodings = positions / scales

        position_encodings[:, 0::2] = np.sin(position_encodings[:, 0::2])  # 2i
        position_encodings[:, 1::2] = np.cos(position_encodings[:, 1::2])  # 2i+1
        return K.cast(position_encodings, 'float32')

    def compute_output_shape(self, input_shape):
        """Report the input shape unchanged."""
        return input_shape

    def get_config(self):
        """Return the config including ``model_dim`` for serialisation."""
        config = {'model_dim': self._model_dim}
        base_config = super(PositionEncoding, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))