import tensorflow as tf
import numpy as np
import tensorflow.keras.backend as K
from tensorflow.keras import initializers, regularizers

@tf.keras.utils.register_keras_serializable()
class PysionetChallengeMetric(tf.keras.metrics.Metric):
    '''
    Streaming, weighted "challenge" score for multi-label classification.

    Over all batches seen since the last reset, three
    (class_count, class_count) modified confusion matrices are accumulated:

    * observed - labels vs. predictions thresholded at 0.5,
    * correct  - labels vs. the labels themselves,
    * inactive - labels vs. an output that predicts only `inactive_index`.

    Each matrix is multiplied element-wise by the weight matrix and summed.
    `result()` returns (observed - inactive) / (correct - inactive), or 0
    when correct == inactive.

    Args:
        inactive_index: column of the class used as the baseline prediction.
        name: metric name.
        class_count: number of label columns.
        weights: optional (class_count, class_count) weight matrix; all ones
            when omitted.
    '''
    def __init__(self, inactive_index, name='challenge_metric', class_count=9, weights=None, **kwargs):
        super(PysionetChallengeMetric, self).__init__(name=name, **kwargs)
        matrix_shape = (class_count, class_count)
        self.observed_score = self.add_weight(name='observed_score', initializer='zeros', shape=matrix_shape)
        self.correct_score = self.add_weight(name='correct_score', initializer='zeros', shape=matrix_shape)
        self.inactive_score = self.add_weight(name='inactive_score', initializer='zeros', shape=matrix_shape)

        # Record the weights exactly as the caller passed them (None stays None),
        # as a JSON-friendly nested list, so get_config() can be serialized and
        # round-trips to an equivalent metric.
        self.weights_input = None if weights is None else np.asarray(weights).tolist()
        if weights is None:
            weights = tf.ones(matrix_shape)
        self.dx_weights = tf.cast(weights, tf.float32)
        # One-hot vector of length class_count selecting the inactive class.
        self.inactive_index = tf.cast(tf.scatter_nd([[inactive_index]], [1], [class_count]), tf.float32)

        self.class_count = class_count
        self.inactive_index_input = inactive_index

    def update_state(self, y_true, y_pred, **kwargs):
        """Accumulate the three confusion matrices for one batch.

        `y_pred` is binarised with a fixed 0.5 threshold. Extra keyword
        arguments (e.g. `sample_weight`) are accepted but ignored.
        """
        threshold = 0.5
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(tf.math.greater(y_pred, threshold), tf.float32)
        # Broadcast the one-hot inactive vector to every row of the batch.
        inactive_outputs = tf.ones_like(y_true) * self.inactive_index

        self.observed_score.assign_add(self._compute_modified_cm(y_true, y_pred))
        self.correct_score.assign_add(self._compute_modified_cm(y_true, y_true))
        self.inactive_score.assign_add(self._compute_modified_cm(y_true, inactive_outputs))

    def result(self):
        """Return the normalised challenge score, or 0 when correct == inactive."""
        correct_score = tf.reduce_sum(self.correct_score * self.dx_weights)
        observed_score = tf.reduce_sum(self.observed_score * self.dx_weights)
        inactive_score = tf.reduce_sum(self.inactive_score * self.dx_weights)
        # divide_no_nan yields 0 for a zero denominator (the same as the old
        # Python `if`), but stays a pure tensor op that is safe in graph mode.
        return tf.math.divide_no_nan(observed_score - inactive_score,
                                     correct_score - inactive_score)

    def reset_state(self):
        """Zero all accumulated confusion matrices."""
        for variable in (self.observed_score, self.correct_score, self.inactive_score):
            variable.assign(tf.zeros_like(variable))

    def reset_states(self):
        # Backwards-compatible alias: older Keras versions call reset_states(),
        # newer ones call reset_state().
        self.reset_state()

    def _compute_modified_cm(self, labels, outputs):
        """Compute the modified confusion matrix for multi-class, multi-label tasks.

        Each row's contribution is divided by the row sum of
        min(labels + outputs, 1), floored at 1. For 0/1 inputs, that sum is
        the number of classes positive in either tensor. Entry [i, j] of the
        result sums labels[:, i] * outputs[:, j] / normalization over the batch.
        """
        normalization = tf.maximum(tf.reduce_sum(tf.minimum(labels + outputs, 1), axis=1, keepdims=True), 1)
        return tf.linalg.matmul(labels, (outputs / normalization), transpose_a=True)

    def get_config(self):
        config = super().get_config()
        config.update({"inactive_index": self.inactive_index_input,
                       "class_count": self.class_count,
                       "weights": self.weights_input})
        return config


@tf.keras.utils.register_keras_serializable()
class F1Macro(tf.keras.metrics.Metric):
    '''
    Custom metric subclass to compute the macro-averaged F1 score.

    Per-class true positives, false positives and false negatives are
    accumulated across batches (predictions thresholded at 0.5). `result()`
    computes F1 for each class and returns their unweighted mean.

    Args:
        name: metric name.
        class_count: number of label columns.
    '''
    def __init__(self, name='f1_macro', class_count=9, **kwargs):
        super(F1Macro, self).__init__(name=name, **kwargs)
        self.tp = self.add_weight(name='tp', initializer='zeros', shape=(1, class_count))
        self.fp = self.add_weight(name='fp', initializer='zeros', shape=(1, class_count))
        self.fn = self.add_weight(name='fn', initializer='zeros', shape=(1, class_count))

        self.class_count = class_count

    def update_state(self, y_true, y_pred, **kwargs):
        """Add this batch's per-class counts; extra kwargs (e.g. sample_weight) are ignored."""
        threshold = 0.5
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(tf.math.greater(y_pred, threshold), tf.float32)
        # Sum over the batch axis, keeping a (1, class_count) shape to match
        # the accumulators.
        self.tp.assign_add(tf.reduce_sum(y_true * y_pred, axis=0, keepdims=True))
        self.fp.assign_add(tf.reduce_sum((1 - y_true) * y_pred, axis=0, keepdims=True))
        self.fn.assign_add(tf.reduce_sum(y_true * (1 - y_pred), axis=0, keepdims=True))

    def result(self):
        """Return the mean over classes of the per-class F1 score."""
        p = self.tp / (self.tp + self.fp + K.epsilon())
        r = self.tp / (self.tp + self.fn + K.epsilon())
        f1 = 2 * p * r / (p + r + K.epsilon())
        f1 = tf.where(tf.math.is_nan(f1), tf.zeros_like(f1), f1)
        return tf.reduce_mean(f1)

    def reset_state(self):
        """Zero the tp/fp/fn accumulators."""
        for variable in (self.tp, self.fp, self.fn):
            variable.assign(tf.zeros_like(variable))

    def reset_states(self):
        # Backwards-compatible alias: older Keras versions call reset_states(),
        # newer ones call reset_state().
        self.reset_state()

    def get_config(self):
        config = super().get_config()
        config.update({"class_count": self.class_count})
        return config

@tf.keras.utils.register_keras_serializable()
class NoamSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Noam learning-rate schedule.

    lr(step) = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5).
    The rate rises linearly for `warmup_steps` steps, then decays with the
    inverse square root of the step.

    Args:
        d_model: model width used to scale the rate.
        warmup_steps: number of linear warm-up steps.
    """
    def __init__(self, d_model=512, warmup_steps=4000):
        super(NoamSchedule, self).__init__()

        # Keep the raw value for get_config(): `self.d_model` is a tf.Tensor,
        # which is not JSON-serializable.
        self._d_model_input = d_model
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        # Optimizers may pass an integer iteration counter; rsqrt needs floats.
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

    def get_config(self):
        return {
            "d_model": self._d_model_input,
            "warmup_steps": self.warmup_steps
        }

@tf.keras.utils.register_keras_serializable()
class ExponentialDecayWithWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warm-up followed by exponential decay.

    For step < warmup_steps the rate is step / warmup_steps * initial_learning_rate.
    After that it is
    initial_learning_rate * decay_rate ** ((step - warmup_steps) / decay_steps).
    Both curves meet at step == warmup_steps, so the minimum of the two is taken.

    Args:
        initial_learning_rate: peak rate reached at the end of warm-up.
        decay_steps: steps over which the rate is multiplied by `decay_rate`.
        decay_rate: multiplicative decay factor.
        warmup_steps: length of the linear warm-up; <= 0 disables warm-up.
    """
    def __init__(self, initial_learning_rate, decay_steps, decay_rate, warmup_steps=4000):
        super(ExponentialDecayWithWarmup, self).__init__()
        self.initial_learning_rate = initial_learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        # An integer step would otherwise promote the result to float64.
        step = tf.cast(step, tf.float32)
        decayed = self.initial_learning_rate * self.decay_rate ** ((step - self.warmup_steps) / self.decay_steps)
        if self.warmup_steps <= 0:
            # No warm-up: avoid step / 0, which produces inf/NaN.
            return decayed
        warmup = step / self.warmup_steps * self.initial_learning_rate
        return tf.math.minimum(warmup, decayed)

    def get_config(self):
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "decay_steps": self.decay_steps,
            "decay_rate": self.decay_rate,
            "warmup_steps": self.warmup_steps}

@tf.keras.utils.register_keras_serializable()
class CustomChallengeLoss_v1(tf.keras.losses.Loss):
    """Binary cross-entropy plus a negated, normalised challenge score (v1).

    The challenge term is
    -(observed - inactive) / (correct - inactive + epsilon), where each score
    is the weighted sum of a soft modified confusion matrix (see
    `_compute_loss`). The baselines compare y_true with itself ("correct")
    and with an output predicting only the inactive class ("inactive").
    This version has no special case for correct == inactive.

    Args:
        inactive_index: column of the class used as the baseline prediction.
        class_count: number of label columns.
        weights: optional (class_count, class_count) weight matrix; None means
            an unweighted sum.
        name: loss name.
    """
    def __init__(self, inactive_index, class_count, weights=None, name="custom_loss"):
        super().__init__(name=name)
        self.name = name
        self.weights = weights
        # One-hot vector of length class_count selecting the inactive class.
        self.inactive_index = tf.cast(tf.scatter_nd([[inactive_index]], [1], [class_count]), tf.float32)

        self.class_count = class_count
        self.inactive_index_input = inactive_index

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)

        bce = tf.keras.losses.BinaryCrossentropy()(y_true, y_pred)

        inactive_outputs = tf.ones_like(y_true) * self.inactive_index
        observed_score = self._compute_loss(y_true, y_pred)
        correct_score = self._compute_loss(y_true, y_true)
        inactive_score = self._compute_loss(y_true, inactive_outputs)
        custom_loss = -1 * (observed_score - inactive_score) / (correct_score - inactive_score + K.epsilon())

        return bce + custom_loss

    def _compute_loss(self, labels, outputs):
        """Return the weighted sum of the soft modified confusion matrix.

        labels + outputs - labels * outputs acts as a soft OR per class. Each
        row is normalised by its sum, floored at 1.
        """
        normalization = tf.maximum(tf.reduce_sum(labels + outputs - labels * outputs, axis=1, keepdims=True), 1)
        confusion = tf.linalg.matmul(labels, (outputs / normalization), transpose_a=True)
        if self.weights is not None:
            # The default weights=None used to crash on `matrix * None`; it now
            # means "no weighting". An explicit cast avoids dtype mismatches
            # with integer or float64 weight arrays.
            confusion = confusion * tf.cast(self.weights, confusion.dtype)
        return tf.reduce_sum(confusion)

    def get_config(self):
        return {
            "inactive_index": self.inactive_index_input,
            "class_count": self.class_count,
            "name": self.name,
            "weights": self.weights
        }

@tf.keras.utils.register_keras_serializable()
class CustomChallengeLoss(tf.keras.losses.Loss):
    """Down-weighted binary cross-entropy plus a negated, normalised challenge score.

    Compared with the v1 loss, the normalisation uses a soft OR, and the
    challenge term is skipped (set to 0) when correct == inactive, e.g. when
    every label is the inactive class.

    Args:
        inactive_index: column of the class used as the baseline prediction.
        class_count: number of label columns.
        weights: optional (class_count, class_count) weight matrix; None means
            an unweighted sum.
        name: loss name.
    """
    def __init__(self, inactive_index, class_count, weights=None, name="custom_loss"):
        super().__init__(name=name)
        self.name = name
        self.weights = weights
        # One-hot vector of length class_count selecting the inactive class.
        self.inactive_index = tf.cast(tf.scatter_nd([[inactive_index]], [1], [class_count]), tf.float32)

        self.class_count = class_count
        self.inactive_index_input = inactive_index

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)

        # Element-wise BCE, scaled by 0.1 where |y_true - y_pred| < 0.3
        # (already close to the target) and by 1.0 elsewhere.
        difference = y_true - y_pred
        ignore = tf.cast(tf.less(tf.abs(difference), 0.3), tf.float32)
        mul_factor = 0.1 * ignore + 1.0 - ignore
        bce = y_true * K.log(y_pred + K.epsilon())
        bce += (1 - y_true) * K.log(1 - y_pred + K.epsilon())
        bce *= mul_factor
        bce = -K.mean(bce)

        inactive_outputs = tf.ones_like(y_true) * self.inactive_index
        observed_score = self._compute_loss(y_true, y_pred)
        correct_score = self._compute_loss(y_true, y_true)
        inactive_score = self._compute_loss(y_true, inactive_outputs)

        # Skip the challenge term when correct == inactive. tf.where keeps this
        # a tensor op instead of relying on a Python `if` over a tensor. The
        # unselected branch stays finite thanks to epsilon, so gradients are clean.
        custom_loss = tf.where(
            tf.not_equal(correct_score, inactive_score),
            -1 * (observed_score - inactive_score) / (correct_score - inactive_score + K.epsilon()),
            tf.zeros_like(observed_score))

        return bce + custom_loss

    def _compute_loss(self, labels, outputs):
        """Return the weighted sum of the soft modified confusion matrix.

        labels + outputs - labels * outputs acts as a soft OR per class. Each
        row is normalised by its sum, floored at 1.
        """
        normalization = tf.maximum(tf.reduce_sum(labels + outputs - labels * outputs, axis=1, keepdims=True), 1)
        confusion = tf.linalg.matmul(labels, (outputs / normalization), transpose_a=True)
        if self.weights is not None:
            # The default weights=None used to crash on `matrix * None`; it now
            # means "no weighting".
            confusion = confusion * tf.cast(self.weights, confusion.dtype)
        return tf.reduce_sum(confusion)

    def get_config(self):
        return {
            "inactive_index": self.inactive_index_input,
            "class_count": self.class_count,
            "name": self.name,
            "weights": self.weights
        }

@tf.keras.utils.register_keras_serializable()
class AttentionPooling(tf.keras.layers.Layer):
    """
    Attention Pooling Layer for Multiple Instance Learning.

    For each sequence, computes `units` attention distributions over the time
    axis:

        scores = softmax_over_time(tanh(x @ V) @ W)    # (batch, time_step, units)

    and returns one attention-weighted sum of the features per unit, with
    shape (batch, units, features). Masked time steps get zero weight.

    Args:
        units: number of pooled vectors (attention heads), K.
        L: hidden size of the tanh projection.
        kernel_initializer: initializer for both V and W.
        kernel_regularizer: regularizer for both V and W.
        return_attention_scores: if True, `call` returns [pooled, scores].
        time_major: if True, inputs (and mask) are laid out as
            (time_step, batch, ...). The pooled output is always batch-major;
            returned scores follow the input layout.
    """
    def __init__(self, units, L, kernel_initializer='glorot_uniform', kernel_regularizer=None, return_attention_scores=False, time_major=False, **kwargs):
        super(AttentionPooling, self).__init__(**kwargs)

        self.L = L
        self.K = units
        self.time_major = time_major
        self.return_attention_scores = return_attention_scores
        self.v_initializer = initializers.get(kernel_initializer)
        self.w_initializer = initializers.get(kernel_initializer)
        self.v_regularizer = regularizers.get(kernel_regularizer)
        self.w_regularizer = regularizers.get(kernel_regularizer)

    def build(self, input_shape):
        assert len(input_shape) == 3
        self.V = self.add_weight("V", shape=(input_shape[-1], self.L), trainable=True,
                                 initializer=self.v_initializer, regularizer=self.v_regularizer)
        self.W = self.add_weight("W", shape=(self.L, self.K), trainable=True,
                                 initializer=self.w_initializer, regularizer=self.w_regularizer)

    def call(self, inputs, mask=None):
        x = inputs
        if self.time_major:
            # Work in batch-major layout internally. Otherwise the softmax and
            # the pooled sum below would run over the batch axis.
            x = tf.transpose(x, [1, 0, 2])
            if mask is not None:
                mask = tf.transpose(mask, [1, 0])
        float_mask = None
        if mask is not None:
            float_mask = tf.expand_dims(tf.cast(mask, x.dtype), -1)
            # Zero padded steps so they contribute nothing to the pooled sum.
            x = x * float_mask
        # (batch, time_step, features) -> (batch, time_step, L) -> (batch, time_step, units)
        a = tf.matmul(tf.tanh(tf.matmul(x, self.V)), self.W)
        if float_mask is not None:
            # Push padded steps towards -inf so the softmax gives them ~0 weight.
            large_negative = tf.float16.min if a.dtype == tf.float16 else -1e9
            a = a + (1.0 - float_mask) * large_negative
        # tf.nn.softmax subtracts the max internally, so large logits cannot
        # overflow exp() to inf and turn the scores into NaN.
        a = tf.nn.softmax(a, axis=1)
        if float_mask is not None:
            # Exact zeros on padded steps; fully-masked rows pool to zeros.
            a = a * float_mask
        # (batch, units, time_step) @ (batch, time_step, features) -> (batch, units, features)
        pooled = tf.matmul(a, x, transpose_a=True)
        if self.return_attention_scores:
            if self.time_major:
                a = tf.transpose(a, [1, 0, 2])
            return [pooled, a]
        return pooled

    def compute_output_shape(self, input_shape):
        shape = tf.TensorShape(input_shape).as_list()
        assert len(shape) == 3
        batch_dim = shape[1] if self.time_major else shape[0]
        output_shape = tf.TensorShape([batch_dim, self.K, shape[-1]])
        attention_score_shape = tf.TensorShape(shape[:-1] + [self.K])
        if self.return_attention_scores:
            return [output_shape, attention_score_shape]
        return output_shape

    def compute_mask(self, inputs, mask=None):
        # Pooling removes the time axis, so no mask is propagated.
        return None

    def get_config(self):
        config = super().get_config()
        config.update({
            "units": self.K,
            "time_major": self.time_major,
            "L": self.L,
            "return_attention_scores": self.return_attention_scores,
            # Without these a reloaded layer would fall back to the defaults
            # and silently lose its kernel regularizer.
            "kernel_initializer": (None if self.v_initializer is None
                                   else initializers.serialize(self.v_initializer)),
            "kernel_regularizer": (None if self.v_regularizer is None
                                   else regularizers.serialize(self.v_regularizer)),
        })
        return config


class TemporalSoftmax(tf.keras.layers.Layer):
    """Softmax over axis 1 (time) that gives masked steps zero weight.

    Assumes inputs of shape (batch, time_step, channels) and a
    (batch, time_step) mask. TODO: confirm against callers.
    """
    def call(self, inputs, mask=None):
        x = inputs
        float_mask = None
        if mask is not None:
            float_mask = tf.expand_dims(tf.cast(mask, x.dtype), -1)
            # Push masked steps towards -inf so they get ~0 probability.
            large_negative = tf.float16.min if x.dtype == tf.float16 else -1e9
            x = x + (1.0 - float_mask) * large_negative
        # tf.nn.softmax subtracts the max internally, so large inputs cannot
        # overflow exp() to inf and produce NaN.
        x = tf.nn.softmax(x, axis=1)
        if float_mask is not None:
            # Exact zeros on masked steps (fully-masked rows become 0, not NaN).
            x = x * float_mask
        return x

@tf.keras.utils.register_keras_serializable()
class Predict_v1(tf.keras.layers.Layer):
    """
    Final output layer for Multiple Instance Learning (one Dense per class).

    Expects inputs of shape (batch, label_classes, features) with
    label_classes == units. Slice i along axis 1 goes through its own
    Dense(1, activation) layer, and the `units` results are concatenated into
    (batch, units).
    """
    def __init__(self, units, kernel_initializer='glorot_uniform', kernel_regularizer=None, use_bias=True,
                 bias_initializer='zeros', activation='sigmoid', bias_regularizer=None, **kwargs):
        super(Predict_v1, self).__init__(**kwargs)
        self.units = units
        self.use_bias = use_bias
        self.activation = activation
        # Resolved copies are kept only so get_config() can round-trip them;
        # the Dense layers still receive the raw arguments.
        self._kernel_initializer = initializers.get(kernel_initializer)
        self._bias_initializer = initializers.get(bias_initializer)
        self._kernel_regularizer = regularizers.get(kernel_regularizer)
        self._bias_regularizer = regularizers.get(bias_regularizer)
        self.dense_layers = [tf.keras.layers.Dense(1, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer) for _ in range(units)]

    def build(self, input_shape):
        assert len(input_shape) == 3
        assert input_shape[1] == self.units

    def call(self, inputs, mask=None):
        # (batch, label_classes, features) -> label_classes tensors of (batch, features)
        per_class = tf.unstack(inputs, num=self.units, axis=1)
        # Each Dense(1) yields (batch, 1); concatenating gives (batch, units).
        return tf.concat([dense(v) for dense, v in zip(self.dense_layers, per_class)], axis=-1)

    def compute_output_shape(self, input_shape):
        shape = list(input_shape)
        assert len(shape) == 3
        # Build the tuple directly: `shape[0] + [units]` raises TypeError
        # (int/None + list).
        return (shape[0], self.units)

    def get_config(self):
        def _serialize(obj, serializer):
            return None if obj is None else serializer(obj)

        config = super().get_config()
        config.update({"units": self.units,
                       "use_bias": self.use_bias,
                       "activation": self.activation,
                       "kernel_initializer": _serialize(self._kernel_initializer, initializers.serialize),
                       "bias_initializer": _serialize(self._bias_initializer, initializers.serialize),
                       "kernel_regularizer": _serialize(self._kernel_regularizer, regularizers.serialize),
                       "bias_regularizer": _serialize(self._bias_regularizer, regularizers.serialize)})
        return config

@tf.keras.utils.register_keras_serializable()
class Predict(tf.keras.layers.Layer):
    """
    Final output layer for Multiple Instance Learning (shared-tensor version).

    Expects inputs of shape (batch, label_classes, features) with
    label_classes == units. Each class vector is dotted with its own row of
    W (shape (units, features)), an optional per-class bias is added, and
    the activation is applied, giving (batch, units).
    """
    def __init__(self, units, kernel_initializer='glorot_uniform', kernel_regularizer=None, use_bias=True,
                 bias_initializer='zeros', activation='sigmoid', bias_regularizer=None, **kwargs):
        super(Predict, self).__init__(**kwargs)
        self.units = units
        self.use_bias = use_bias
        self.activation = activation
        self.activator = tf.keras.layers.Activation(activation)
        self.w_initializer = initializers.get(kernel_initializer)
        self.w_regularizer = regularizers.get(kernel_regularizer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

    def build(self, input_shape):
        assert len(input_shape) == 3
        assert input_shape[1] == self.units
        self.W = self.add_weight("W", shape=(input_shape[1:]), trainable=True,
                                 initializer=self.w_initializer, regularizer=self.w_regularizer)
        if self.use_bias:
            self.b = self.add_weight('b', shape=(self.units,), trainable=True,
                                     initializer=self.bias_initializer, regularizer=self.bias_regularizer)

    def call(self, inputs, mask=None):
        # (batch, label_classes, features) * (label_classes, features), summed over features
        x = tf.math.multiply(inputs, self.W)
        x = tf.reduce_sum(x, axis=-1)
        # (batch, label_classes)
        if self.use_bias:
            x = x + self.b
        x = self.activator(x)
        return x

    def compute_output_shape(self, input_shape):
        shape = list(input_shape)
        assert len(shape) == 3
        # Build the tuple directly: `shape[0] + [units]` raises TypeError
        # (int/None + list).
        return (shape[0], self.units)

    def get_config(self):
        def _serialize(obj, serializer):
            return None if obj is None else serializer(obj)

        config = super().get_config()
        config.update({"units": self.units,
                       "use_bias": self.use_bias,
                       "activation": self.activation,
                       "kernel_initializer": _serialize(self.w_initializer, initializers.serialize),
                       "bias_initializer": _serialize(self.bias_initializer, initializers.serialize),
                       "kernel_regularizer": _serialize(self.w_regularizer, regularizers.serialize),
                       "bias_regularizer": _serialize(self.bias_regularizer, regularizers.serialize)})
        return config

@tf.keras.utils.register_keras_serializable()
class MaskedConv1D(tf.keras.layers.Layer):
    """Conv1D ('same' padding) -> optional BatchNorm -> activation -> zero masked steps.

    `activation='thsig'` selects tanh(x) * sigmoid(x); any other value is
    passed to `tf.keras.layers.Activation`. Extra kwargs go to the inner
    Conv1D. The incoming mask is passed through unchanged.

    NOTE(review): with strides > 1 the output has fewer time steps than the
    incoming mask, so masking in `call` will fail to broadcast. Confirm that
    strides > 1 is never combined with a mask.
    """
    def __init__(self, units, kernel_size=3, strides=1, activation='relu', use_batch_norm=True, **kwargs):
        super(MaskedConv1D, self).__init__()
        self.units = units
        self.use_batch_norm = use_batch_norm
        self.activation = activation
        if activation != 'thsig':
            self.activator = tf.keras.layers.Activation(activation)
        self.kernel_size = kernel_size
        self.strides = strides
        if use_batch_norm:
            self.normalizer = tf.keras.layers.BatchNormalization()
        self.conv1d = tf.keras.layers.Conv1D(units, kernel_size=kernel_size, strides=strides, padding='same', activation=None, **kwargs)

    def build(self, input_shape):
        assert len(input_shape) == 3

    def call(self, inputs, mask=None, training=False):
        x = inputs
        x = self.conv1d(x)
        if self.use_batch_norm:
            x = self.normalizer(x, training=training)
        if self.activation == 'thsig':
            x = tf.nn.tanh(x) * tf.nn.sigmoid(x)
        else:
            x = self.activator(x)
        if mask is not None:
            x = x * tf.expand_dims(tf.cast(mask, "float32"), -1)
        return x

    def compute_output_shape(self, input_shape):
        # TensorShape() accepts tuples, lists and TensorShapes alike.
        shape = tf.TensorShape(input_shape).as_list()
        assert len(shape) == 3
        stride = self.strides[0] if isinstance(self.strides, (tuple, list)) else self.strides
        steps = shape[1]
        if steps is not None:
            # 'same' padding yields ceil(steps / stride) output steps.
            steps = -(-steps // stride)
        return tf.TensorShape([shape[0], steps, self.units])

    def compute_mask(self, inputs, mask=None):
        # Just pass the received mask from previous layer, to the next layer
        return mask

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units,
                       "use_batch_norm": self.use_batch_norm,
                       "activation": self.activation,
                       "strides": self.strides,
                       "kernel_size": self.kernel_size})
        return config

@tf.keras.utils.register_keras_serializable()
class MaskedConv2D(tf.keras.layers.Layer):
    """Conv2D ('same' padding) -> optional BatchNorm -> activation -> zero masked steps.

    The mask is expanded with two trailing axes, so it is assumed to be
    (batch, time_step) for channels-last input (batch, time_step, H, C).
    TODO: confirm. Extra kwargs go to the inner Conv2D; the incoming mask is
    passed through unchanged.
    """
    def __init__(self, units, kernel_size=(3, 1), strides=(1, 1), activation='relu', use_batch_norm=True, **kwargs):
        super(MaskedConv2D, self).__init__()
        self.units = units
        self.use_batch_norm = use_batch_norm
        self.activation = activation
        self.activator = tf.keras.layers.Activation(activation)
        self.kernel_size = kernel_size
        self.strides = strides
        if use_batch_norm:
            self.normalizer = tf.keras.layers.BatchNormalization()
        self.conv2d = tf.keras.layers.Conv2D(units, kernel_size=kernel_size, strides=strides, padding='same', activation=None, **kwargs)
        self.concat = tf.keras.layers.Concatenate()

    def build(self, input_shape):
        assert len(input_shape) == 4

    def call(self, inputs, mask=None, training=False):
        x = inputs
        x = self.conv2d(x)
        if self.use_batch_norm:
            x = self.normalizer(x, training=training)
        x = self.activator(x)
        if mask is not None:
            x = x * tf.cast(mask, "float32")[..., tf.newaxis, tf.newaxis]
        return x

    def compute_output_shape(self, input_shape):
        # TensorShape() accepts tuples, lists and TensorShapes alike.
        shape = tf.TensorShape(input_shape).as_list()
        assert len(shape) == 4
        strides = self.strides if isinstance(self.strides, (tuple, list)) else (self.strides, self.strides)
        # 'same' padding: each spatial dim becomes ceil(dim / stride)
        # (assumes channels-last data_format).
        spatial = [None if dim is None else -(-dim // stride) for dim, stride in zip(shape[1:3], strides)]
        return tf.TensorShape([shape[0]] + spatial + [self.units])

    def compute_mask(self, inputs, mask=None):
        # Just pass the received mask from previous layer, to the next layer
        return mask

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units,
                       "use_batch_norm": self.use_batch_norm,
                       "activation": self.activation,
                       "strides": self.strides,
                       "kernel_size": self.kernel_size})
        return config

@tf.keras.utils.register_keras_serializable()
class Conv1DInception(tf.keras.layers.Layer):
    """Parallel Conv1D branches (one per kernel size) concatenated on channels.

    Each branch is a stride-1, 'same'-padded Conv1D with `units` filters.
    With `shrink_channels`, each branch is preceded by its own 1x1 Conv1D down
    to a single channel. The concatenation (units * len(kernel_size) channels)
    then goes through optional BatchNorm and the activation, and masked steps
    are zeroed. `strides` is stored for get_config() but is not applied to
    the convolutions.
    """
    def __init__(self, units, kernel_size=(3, 5, 7), strides=1, shrink_channels=False, activation='relu', use_batch_norm=True, **kwargs):
        super(Conv1DInception, self).__init__()
        self.units = units
        self.kernels = len(kernel_size)
        self.shrink_channels = shrink_channels
        self.use_batch_norm = use_batch_norm
        self.activation = activation
        self.kernel_size = kernel_size
        self.strides = strides
        self.activator = tf.keras.layers.Activation(activation)
        if use_batch_norm:
            self.normalizer = tf.keras.layers.BatchNormalization()
        self.conv1d = [tf.keras.layers.Conv1D(units, kernel_size=k, strides=1, padding='same', activation=None, **kwargs) for k in kernel_size]
        if shrink_channels:
            self.shrink = [tf.keras.layers.Conv1D(1, kernel_size=1, strides=1, padding='same', activation=None, **kwargs) for _ in kernel_size]
        self.concat = tf.keras.layers.Concatenate()

    def build(self, input_shape):
        assert len(input_shape) == 3

    def call(self, inputs, mask=None, training=None):
        x = inputs
        if self.shrink_channels:
            x = self.concat([conv(shrink(x)) for conv, shrink in zip(self.conv1d, self.shrink)])
        else:
            x = self.concat([conv(x) for conv in self.conv1d])
        if self.use_batch_norm:
            x = self.normalizer(x, training=training)
        x = self.activator(x)
        if mask is not None:
            x = x * tf.expand_dims(tf.cast(mask, "float32"), -1)
        return x

    def compute_output_shape(self, input_shape):
        # TensorShape() accepts tuples, lists and TensorShapes alike.
        shape = tf.TensorShape(input_shape).as_list()
        assert len(shape) == 3
        return tf.TensorShape(shape[:2] + [self.units * self.kernels])

    def compute_mask(self, inputs, mask=None):
        # Just pass the received mask from previous layer, to the next layer
        return mask

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units,
                       "use_batch_norm": self.use_batch_norm,
                       "activation": self.activation,
                       "shrink_channels": self.shrink_channels,
                       "strides": self.strides,
                       "kernel_size": self.kernel_size})
        return config

@tf.keras.utils.register_keras_serializable()
class Conv2DInception(tf.keras.layers.Layer):
    """Parallel Conv2D branches (one per kernel size) concatenated on channels.

    Each branch is a stride-1, 'same'-padded Conv2D with `units` filters.
    With `shrink_channels`, each branch is preceded by its own 1x1 Conv2D down
    to a single channel. The concatenation (units * len(kernel_size) channels)
    then goes through optional BatchNorm and the activation, and masked steps
    are zeroed. Conv2D needs rank-4 input. The mask gets two trailing axes,
    so it is assumed to be (batch, time_step) for channels-last input; TODO
    confirm. `strides` is stored for get_config() but not applied.
    """
    def __init__(self, units, kernel_size=(3, 5, 7), strides=1, shrink_channels=False, activation='relu', use_batch_norm=True, **kwargs):
        super(Conv2DInception, self).__init__()
        self.units = units
        self.kernels = len(kernel_size)
        self.shrink_channels = shrink_channels
        self.use_batch_norm = use_batch_norm
        self.activation = activation
        self.kernel_size = kernel_size
        self.strides = strides
        self.activator = tf.keras.layers.Activation(activation)
        if use_batch_norm:
            self.normalizer = tf.keras.layers.BatchNormalization()
        self.conv2d = [tf.keras.layers.Conv2D(units, kernel_size=k, strides=1, padding='same', activation=None, **kwargs) for k in kernel_size]
        if shrink_channels:
            self.shrink = [tf.keras.layers.Conv2D(1, kernel_size=1, strides=1, padding='same', activation=None, **kwargs) for _ in kernel_size]
        self.concat = tf.keras.layers.Concatenate()

    def build(self, input_shape):
        # Conv2D requires (batch, rows, cols, channels); asserting rank 3 made
        # the layer unusable.
        assert len(input_shape) == 4

    def call(self, inputs, mask=None, training=False):
        x = inputs
        if self.shrink_channels:
            x = self.concat([conv(shrink(x)) for conv, shrink in zip(self.conv2d, self.shrink)])
        else:
            x = self.concat([conv(x) for conv in self.conv2d])
        if self.use_batch_norm:
            x = self.normalizer(x, training=training)
        x = self.activator(x)
        if mask is not None:
            x = x * tf.cast(mask, "float32")[..., tf.newaxis, tf.newaxis]
        return x

    def compute_output_shape(self, input_shape):
        # TensorShape() accepts tuples, lists and TensorShapes alike.
        shape = tf.TensorShape(input_shape).as_list()
        assert len(shape) == 4
        return tf.TensorShape(shape[:3] + [self.units * self.kernels])

    def compute_mask(self, inputs, mask=None):
        # Just pass the received mask from previous layer, to the next layer
        return mask

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units,
                       "use_batch_norm": self.use_batch_norm,
                       "activation": self.activation,
                       "shrink_channels": self.shrink_channels,
                       "strides": self.strides,
                       "kernel_size": self.kernel_size})
        return config

class SumPooling1D(tf.keras.layers.Layer):
    """Sum pooling over axis 1 that zeroes masked steps before summing.

    Windows of `pool_size` steps are taken every `strides` steps (`strides`
    defaults to `pool_size`). The end of the sequence is zero-padded, so the
    last partial window is kept. `padding` is stored for get_config() but is
    not used.
    """
    def __init__(self, pool_size=2, strides=None, padding='valid', **kwargs):
        super(SumPooling1D, self).__init__(**kwargs)
        if strides is None:
            strides = pool_size
        self.pool_size = pool_size
        self.strides = strides
        self.padding = padding

    def call(self, inputs, mask=None):
        x = inputs
        if mask is not None:
            # Cast to the input dtype so float16/mixed-precision inputs work.
            x = x * tf.expand_dims(tf.cast(mask, x.dtype), -1)
        # (batch, frames, pool_size, ...) -> sum within each frame
        x = tf.signal.frame(x, frame_length=self.pool_size, frame_step=self.strides, axis=1, pad_end=True)
        return tf.reduce_sum(x, axis=2)

    def compute_mask(self, inputs, mask=None):
        # A pooled step is valid if any step in its window was valid.
        if mask is not None:
            mask = tf.signal.frame(tf.cast(mask, tf.float32), frame_length=self.pool_size, frame_step=self.strides, axis=1, pad_end=True)
            mask = tf.cast(tf.reduce_sum(mask, axis=-1), tf.bool)
        return mask

    def get_config(self):
        config = super().get_config()
        config.update({
            "strides": self.strides,
            "pool_size": self.pool_size,
            "padding": self.padding})
        return config

@tf.keras.utils.register_keras_serializable()
class MaskedAvgPooling1D(tf.keras.layers.Layer):
    """Average pooling over axis 1 that averages only valid (unmasked) steps.

    Windows of `pool_size` steps are taken every `strides` steps (`strides`
    defaults to `pool_size`). The end is zero-padded, so there are
    ceil(time_step / strides) output steps. Each window is averaged over its
    real, unmasked steps only, and windows with none give 0. `padding` is
    stored for get_config() but is not used.
    """
    def __init__(self, pool_size=2, strides=None, padding='valid', **kwargs):
        super(MaskedAvgPooling1D, self).__init__(**kwargs)
        if strides is None:
            strides = pool_size
        self.pool_size = pool_size
        self.strides = strides
        self.padding = padding

    def build(self, input_shape):
        assert len(input_shape) == 3

    def compute_output_shape(self, input_shape):
        shape = tf.TensorShape(input_shape).as_list()
        assert len(shape) == 3
        if shape[1] is not None:
            # pad_end framing yields ceil(steps / strides) windows.
            shape[1] = -(-shape[1] // self.strides)
        return tf.TensorShape(shape)

    def call(self, inputs, mask=None):
        x = inputs
        if mask is None:
            # Every real step counts; end padding added by framing does not.
            valid = tf.ones_like(x[..., :1])
        else:
            valid = tf.expand_dims(tf.cast(mask, x.dtype), -1)
            x = x * valid

        def _frame(t):
            return tf.signal.frame(t, frame_length=self.pool_size, frame_step=self.strides, axis=1, pad_end=True)

        # Sum of values / number of valid steps per window (0 if none).
        return tf.math.divide_no_nan(tf.reduce_sum(_frame(x), axis=2),
                                     tf.reduce_sum(_frame(valid), axis=2))

    def compute_mask(self, inputs, mask=None):
        # A pooled step is valid if any step in its window was valid.
        if mask is not None:
            mask = tf.signal.frame(tf.cast(mask, tf.float32), frame_length=self.pool_size, frame_step=self.strides, axis=1, pad_end=True)
            mask = tf.cast(tf.reduce_sum(mask, axis=-1), tf.bool)
        return mask

    def get_config(self):
        config = super().get_config()
        config.update({
            "strides": self.strides,
            "pool_size": self.pool_size,
            "padding": self.padding})
        return config

@tf.keras.utils.register_keras_serializable()
class MaskedMaxPooling1D(tf.keras.layers.Layer):
    """MaxPooling1D wrapper that also pools the mask.

    All kwargs go to the inner `tf.keras.layers.MaxPooling1D`. The mask is
    max-pooled with the same settings, so a pooled step is valid if any step
    in its window was valid. Input values are pooled without applying the mask.
    """
    def __init__(self, **kwargs):
        super(MaskedMaxPooling1D, self).__init__()
        self.pool = tf.keras.layers.MaxPooling1D(**kwargs)

    def call(self, inputs, mask=None):
        return self.pool(inputs)

    def compute_mask(self, inputs, mask=None):
        if mask is not None:
            x = tf.expand_dims(tf.cast(mask, tf.int8), -1)
            x = self.pool(x)
            return tf.cast(tf.squeeze(x, -1), tf.bool)
        return mask

    def get_config(self):
        # The pooling settings live on the inner layer; without them a reloaded
        # layer would silently fall back to MaxPooling1D defaults.
        config = super().get_config()
        pool_config = self.pool.get_config()
        for key in ("pool_size", "strides", "padding", "data_format"):
            if key in pool_config:
                config[key] = pool_config[key]
        return config

@tf.keras.utils.register_keras_serializable()
class MaskedMaxPooling2D(tf.keras.layers.Layer):
    """MaxPooling2D wrapper that also pools a (batch, time_step) mask.

    All kwargs go to the inner `tf.keras.layers.MaxPooling2D`. The mask is
    assumed to cover axis 1 of channels-last input (TODO: confirm). It is
    max-pooled along that axis only, using the first pool_size/strides
    entries and the same padding.
    """
    def __init__(self, **kwargs):
        super(MaskedMaxPooling2D, self).__init__()
        self.pool = tf.keras.layers.MaxPooling2D(**kwargs)

    def call(self, inputs, mask=None):
        return self.pool(inputs)

    def compute_mask(self, inputs, mask=None):
        if mask is None:
            return mask
        # Pool the mask as a 1-D signal over time. Running it through the 2-D
        # pool as (batch, time, 1, 1) collapsed the width to 0 for 'valid'
        # padding with pool_size[1] > 1.
        x = tf.cast(mask, tf.float32)[..., tf.newaxis]  # (batch, time_step, 1)
        x = tf.nn.max_pool1d(x,
                             ksize=self.pool.pool_size[0],
                             strides=self.pool.strides[0],
                             padding=self.pool.padding.upper())
        return tf.cast(tf.squeeze(x, -1), tf.bool)

    def get_config(self):
        # The pooling settings live on the inner layer; without them a reloaded
        # layer would silently fall back to MaxPooling2D defaults.
        config = super().get_config()
        pool_config = self.pool.get_config()
        for key in ("pool_size", "strides", "padding", "data_format"):
            if key in pool_config:
                config[key] = pool_config[key]
        return config

@tf.keras.utils.register_keras_serializable()
class SandE1D(tf.keras.layers.Layer):
    """Squeeze-and-Excitation block for sequences (http://arxiv.org/abs/1709.01507).

    Each channel of a ``(batch, time_step, channel)`` input is squeezed to its
    mean over the valid time steps. The result passes through a bottleneck
    ``channel -> channel // reduction_ratio -> channel`` (ReLU, then sigmoid),
    and the resulting per-channel gates rescale the input.

    Args:
        reduction_ratio: Factor by which the bottleneck shrinks the channels.
        kernel_initializer: Initializer of both bottleneck matrices.
        kernel_regularizer: Regularizer of both bottleneck matrices.
        return_channel_scores: If True, ``call`` returns ``[outputs, scores]``
            with ``scores`` of shape ``(batch, channel)``.
        bias_initializer: Initializer of both biases.
        bias_regularizer: Regularizer of both biases.
        use_bias: Whether the bottleneck layers add a bias.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """
    def __init__(self, reduction_ratio, kernel_initializer='glorot_uniform',
                 kernel_regularizer=None, return_channel_scores=False,
                 bias_initializer='zeros', bias_regularizer=None,
                 use_bias=True, **kwargs):
        super(SandE1D, self).__init__(**kwargs)

        self.reduction_ratio = reduction_ratio
        self.use_bias = use_bias
        self.return_channel_scores = return_channel_scores
        # The same kernel initializer/regularizer config drives W1 and W2.
        self.w1_initializer = initializers.get(kernel_initializer)
        self.w2_initializer = initializers.get(kernel_initializer)
        self.w1_regularizer = regularizers.get(kernel_regularizer)
        self.w2_regularizer = regularizers.get(kernel_regularizer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

    def build(self, input_shape):
        if len(input_shape) != 3:
            raise ValueError(
                'SandE1D expects a 3D input (batch, time_step, channel), '
                'got shape {}'.format(input_shape))
        channels = input_shape[-1]
        bottleneck = channels // self.reduction_ratio
        self.W1 = self.add_weight(name="W1", shape=(channels, bottleneck), trainable=True,
                                  initializer=self.w1_initializer, regularizer=self.w1_regularizer)
        self.W2 = self.add_weight(name="W2", shape=(bottleneck, channels), trainable=True,
                                  initializer=self.w2_initializer, regularizer=self.w2_regularizer)
        if self.use_bias:
            self.b1 = self.add_weight(name='b1', shape=(bottleneck,), trainable=True,
                                      initializer=self.bias_initializer, regularizer=self.bias_regularizer)
            self.b2 = self.add_weight(name='b2', shape=(channels,), trainable=True,
                                      initializer=self.bias_initializer, regularizer=self.bias_regularizer)

    def call(self, inputs, mask=None):
        # Squeeze: (batch, time_step, channel) -> (batch, channel).
        if mask is not None:
            step_mask = tf.expand_dims(tf.cast(mask, inputs.dtype), -1)  # (batch, time_step, 1)
            # Only valid steps enter the mean; divide_no_nan yields 0 instead
            # of NaN for a sequence without any valid step.
            squeezed = tf.math.divide_no_nan(
                tf.reduce_sum(inputs * step_mask, axis=1),
                tf.reduce_sum(step_mask, axis=1))
        else:
            squeezed = tf.reduce_mean(inputs, axis=1)

        # Excitation: bottleneck MLP producing one gate in (0, 1) per channel.
        hidden = tf.matmul(squeezed, self.W1)
        if self.use_bias:
            hidden = hidden + self.b1
        hidden = tf.nn.relu(hidden)  # (batch, bottleneck)
        scores = tf.matmul(hidden, self.W2)
        if self.use_bias:
            scores = scores + self.b2
        scores = tf.sigmoid(scores)  # (batch, channel)

        outputs = inputs * tf.expand_dims(scores, 1)  # (batch, time_step, channel)
        if self.return_channel_scores:
            return [outputs, scores]
        return outputs

    def compute_output_shape(self, input_shape):
        shape = tf.TensorShape(input_shape)
        if shape.rank != 3:
            raise ValueError(
                'SandE1D expects a 3D input (batch, time_step, channel), '
                'got shape {}'.format(input_shape))
        if self.return_channel_scores:
            return [shape, tf.TensorShape([shape[0], shape[-1]])]
        return shape

    def compute_mask(self, inputs, mask=None):
        # The time axis is unchanged, so the incoming mask still applies.
        return mask

    def get_config(self):
        config = super().get_config()
        config.update({
            "reduction_ratio": self.reduction_ratio,
            "use_bias": self.use_bias,
            "return_channel_scores": self.return_channel_scores,
            "kernel_initializer": initializers.serialize(self.w1_initializer),
            "kernel_regularizer": regularizers.serialize(self.w1_regularizer),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
        })
        return config

@tf.keras.utils.register_keras_serializable()
class ChannelAttention1D(tf.keras.layers.Layer):
    """Channel gating driven by per-channel temporal statistics.

    For a ``(batch, time_step, channel)`` input, each channel is summarised by
    the mean, max, RMS and standard deviation of its valid time steps. These
    4 features are linearly mapped to a single sigmoid score per channel,
    which rescales that channel.

    Args:
        use_bias: Whether the scoring projection adds a bias.
        kernel_initializer: Initializer of the (4, 1) projection matrix.
        kernel_regularizer: Regularizer of the projection matrix.
        return_channel_scores: If True, ``call`` returns ``[outputs, scores]``
            with ``scores`` of shape ``(batch, 1, channel)``.
        bias_initializer: Initializer of the bias.
        bias_regularizer: Regularizer of the bias.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """
    def __init__(self, use_bias=True, kernel_initializer='glorot_uniform',
                 kernel_regularizer=None, return_channel_scores=False,
                 bias_initializer='zeros', bias_regularizer=None,
                 **kwargs):
        super(ChannelAttention1D, self).__init__(**kwargs)

        self.attention_heads = 1

        self.use_bias = use_bias
        self.return_channel_scores = return_channel_scores
        self.bias_initializer = initializers.get(bias_initializer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.w_initializer = initializers.get(kernel_initializer)
        self.w_regularizer = regularizers.get(kernel_regularizer)

    def build(self, input_shape):
        if len(input_shape) != 3:
            raise ValueError(
                'ChannelAttention1D expects a 3D input (batch, time_step, channel), '
                'got shape {}'.format(input_shape))
        feature_count = 4  # mean, max, rms, sd -- must match _compute_features
        self.W = self.add_weight(name="W", shape=(feature_count, self.attention_heads), trainable=True,
                                 initializer=self.w_initializer, regularizer=self.w_regularizer)
        if self.use_bias:
            self.b = self.add_weight(name='b', shape=(self.attention_heads,), trainable=True,
                                     initializer=self.bias_initializer, regularizer=self.bias_regularizer)

    def _compute_features(self, inputs, mask=None):
        """Masked summary statistics of every channel over time.

        Args:
            inputs: Tensor laid out as (channel, batch, time_step).
            mask: Optional (batch, time_step) mask; None means all steps valid.

        Returns:
            Tensor (channel, batch, 4) with mean, max, RMS and standard
            deviation of the valid steps, in that order. Rows with no valid
            step yield 0 for mean and max instead of NaN / a sentinel.
        """
        if mask is None:
            mask = tf.ones(tf.shape(inputs)[1:], dtype=inputs.dtype)
        mask = tf.expand_dims(tf.cast(mask, inputs.dtype), 0)  # (1, batch, time_step)
        x = mask * inputs

        x_length = tf.reduce_sum(mask, axis=-1)  # (1, batch)
        mean = tf.math.divide_no_nan(tf.reduce_sum(x, axis=-1), x_length)

        # Padded steps must not take part in the max (they would contribute
        # 0, which dominates channels whose valid values are all negative).
        lowest = tf.constant(inputs.dtype.min, dtype=inputs.dtype)
        x_max = tf.reduce_max(tf.where(mask > 0, inputs, lowest), axis=-1)
        x_max = tf.where(x_length > 0, x_max, tf.zeros_like(x_max))

        # Second-order features. The epsilon keeps sqrt away from 0, whose
        # gradient is infinite and turns into NaN for constant channels.
        x_div = (x - tf.expand_dims(mean, -1)) * mask
        rms = tf.sqrt(tf.math.divide_no_nan(tf.reduce_sum(tf.square(x), axis=-1), x_length)
                      + K.epsilon())
        sd = tf.sqrt(tf.math.divide_no_nan(tf.reduce_sum(tf.square(x_div), axis=-1), x_length)
                     + K.epsilon())

        return tf.stack([mean, x_max, rms, sd], axis=-1)

    def call(self, inputs, mask=None):
        a = tf.transpose(inputs, [2, 0, 1])  # (channel, batch, time_step)
        a = self._compute_features(a, mask)  # (channel, batch, features)
        a = tf.matmul(a, self.W)  # (channel, batch, 1)
        if self.use_bias:
            a = a + self.b
        a = tf.sigmoid(a)
        a = tf.transpose(a, [1, 2, 0])  # (batch, 1, channel)
        outputs = inputs * a  # (batch, time_step, channel)
        if self.return_channel_scores:
            return [outputs, a]
        return outputs

    def compute_output_shape(self, input_shape):
        shape = tf.TensorShape(input_shape)
        if shape.rank != 3:
            raise ValueError(
                'ChannelAttention1D expects a 3D input (batch, time_step, channel), '
                'got shape {}'.format(input_shape))
        if self.return_channel_scores:
            # Scores are returned as (batch, 1, channel), see call().
            return [shape, tf.TensorShape([shape[0], 1, shape[-1]])]
        return shape

    def compute_mask(self, inputs, mask=None):
        # The time axis is unchanged, so the incoming mask still applies.
        return mask

    def get_config(self):
        config = super().get_config()
        config.update({
            "use_bias": self.use_bias,
            "return_channel_scores": self.return_channel_scores,
            "kernel_initializer": initializers.serialize(self.w_initializer),
            "kernel_regularizer": regularizers.serialize(self.w_regularizer),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
        })
        return config

@tf.keras.utils.register_keras_serializable()
class ChannelSelfAttention1D(tf.keras.layers.Layer):
    """Channel gating driven by learned temporal features.

    Each channel of a ``(batch, time_step, channel)`` input is convolved
    independently with a shared bank of ``filters`` 1D kernels (ReLU), and the
    responses are averaged over the valid time steps. The resulting
    ``filters`` features are linearly mapped to one sigmoid score per
    channel, which rescales that channel.

    Args:
        filters: Number of 1D kernels, i.e. features per channel.
        kernel_size: Width of each 1D kernel.
        strides: Convolution stride. NOTE: masking requires the convolution
            to keep the time length (see ``_compute_features``).
        padding: 'SAME' or 'VALID' (case-insensitive).
        use_bias: Whether the convolution and the scoring projection add biases.
        kernel_initializer: Initializer of the conv kernel and projection matrix.
        kernel_regularizer: Regularizer of the conv kernel and projection matrix.
        return_channel_scores: If True, ``call`` returns ``[outputs, scores]``
            with ``scores`` of shape ``(batch, 1, channel)``.
        bias_initializer: Initializer of the biases.
        bias_regularizer: Regularizer of the biases.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """
    def __init__(self, filters, kernel_size,
                 strides=1, padding='SAME',
                 use_bias=True, kernel_initializer='glorot_uniform',
                 kernel_regularizer=None, return_channel_scores=False,
                 bias_initializer='zeros', bias_regularizer=None,
                 **kwargs):
        super(ChannelSelfAttention1D, self).__init__(**kwargs)

        self.attention_heads = 1
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding

        self.use_bias = use_bias
        self.return_channel_scores = return_channel_scores

        self.conv_kernel_initializer = initializers.get(kernel_initializer)
        self.conv_bias_initializer = initializers.get(bias_initializer)
        self.conv_kernel_regularizer = regularizers.get(kernel_regularizer)
        self.conv_bias_regularizer = regularizers.get(bias_regularizer)

        self.bias_initializer = initializers.get(bias_initializer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.w_initializer = initializers.get(kernel_initializer)
        self.w_regularizer = regularizers.get(kernel_regularizer)

    def build(self, input_shape):
        if len(input_shape) != 3:
            raise ValueError(
                'ChannelSelfAttention1D expects a 3D input (batch, time_step, channel), '
                'got shape {}'.format(input_shape))
        feature_count = self.filters
        kernel_shape = (self.kernel_size, 1, self.filters)

        self.kernel_conv = self.add_weight(
            name='kernel_conv',
            shape=kernel_shape,
            initializer=self.conv_kernel_initializer,
            regularizer=self.conv_kernel_regularizer,
            trainable=True,
            dtype=self.dtype)
        if self.use_bias:
            self.bias_conv = self.add_weight(
                name='bias_conv',
                shape=(self.filters,),
                initializer=self.conv_bias_initializer,
                regularizer=self.conv_bias_regularizer,
                trainable=True,
                dtype=self.dtype)

        self.W = self.add_weight(name="W", shape=(feature_count, self.attention_heads), trainable=True,
                                 initializer=self.w_initializer, regularizer=self.w_regularizer)
        if self.use_bias:
            self.b = self.add_weight(name='b', shape=(self.attention_heads,), trainable=True,
                                     initializer=self.bias_initializer, regularizer=self.bias_regularizer)

    def _compute_features(self, inputs, mask=None):
        """Masked mean of the ReLU conv responses of every channel.

        Args:
            inputs: Tensor laid out as (channel, batch, time_step).
            mask: Optional (batch, time_step) mask; None means all steps valid.

        Returns:
            Tensor (channel, batch, filters); rows without valid steps give 0.
        """
        if mask is None:
            mask = tf.ones(tf.shape(inputs)[1:], dtype=inputs.dtype)
        # Match the activation dtype (it can differ from self.dtype under mixed precision).
        mask = tf.cast(mask, inputs.dtype)  # (batch, time_step)
        x_length = tf.reduce_sum(mask, axis=-1, keepdims=True)  # (batch, 1)
        x = tf.expand_dims(inputs, -1)  # (channel, batch, time_step, 1)
        # tf.nn.conv1d only accepts upper-case padding names.
        padding = self.padding.upper() if isinstance(self.padding, str) else self.padding
        x = tf.nn.conv1d(x, self.kernel_conv, stride=self.strides,
                         padding=padding, data_format='NWC')
        if self.use_bias:
            x = tf.nn.bias_add(x, self.bias_conv)
        x = tf.nn.relu(x)  # (channel, batch, time_step, filters)
        # NOTE: the (batch, time_step) mask is applied to the conv output, so the
        # shapes only line up when the convolution keeps the time length
        # (strides=1 with 'SAME' padding, or kernel_size=1).
        x = x * tf.expand_dims(tf.expand_dims(mask, 0), -1)
        x_length = tf.expand_dims(x_length, 0)  # (1, batch, 1)
        return tf.math.divide_no_nan(tf.reduce_sum(x, axis=-2), x_length)

    def call(self, inputs, mask=None):
        a = tf.transpose(inputs, [2, 0, 1])  # (channel, batch, time_step)
        a = self._compute_features(a, mask)  # (channel, batch, filters)
        a = tf.matmul(a, self.W)  # (channel, batch, 1)
        if self.use_bias:
            a = a + self.b
        a = tf.sigmoid(a)
        a = tf.transpose(a, [1, 2, 0])  # (batch, 1, channel)
        outputs = inputs * a  # (batch, time_step, channel)
        if self.return_channel_scores:
            return [outputs, a]
        return outputs

    def compute_output_shape(self, input_shape):
        shape = tf.TensorShape(input_shape)
        if shape.rank != 3:
            raise ValueError(
                'ChannelSelfAttention1D expects a 3D input (batch, time_step, channel), '
                'got shape {}'.format(input_shape))
        if self.return_channel_scores:
            # Scores are returned as (batch, 1, channel), see call().
            return [shape, tf.TensorShape([shape[0], 1, shape[-1]])]
        return shape

    def compute_mask(self, inputs, mask=None):
        # The time axis is unchanged, so the incoming mask still applies.
        return mask

    def get_config(self):
        config = super().get_config()
        config.update({
            "use_bias": self.use_bias,
            "return_channel_scores": self.return_channel_scores,
            "filters": self.filters, "kernel_size": self.kernel_size,
            "strides": self.strides, "padding": self.padding,
            "kernel_initializer": initializers.serialize(self.conv_kernel_initializer),
            "kernel_regularizer": regularizers.serialize(self.conv_kernel_regularizer),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
        })
        return config

@tf.keras.utils.register_keras_serializable()
class ChannelSelfAttention1D_v2(tf.keras.layers.Layer):
    """Channel gating driven by learned temporal features and a tanh hidden layer.

    Each channel of a ``(batch, time_step, channel)`` input is convolved
    independently with a shared bank of ``filters`` 1D kernels (ReLU), and the
    responses are averaged over the valid time steps. The ``filters`` features
    are projected to ``L`` units with tanh, then to one sigmoid score per
    channel, which rescales that channel.

    Args:
        filters: Number of 1D kernels, i.e. features per channel.
        kernel_size: Width of each 1D kernel.
        L: Width of the tanh hidden projection.
        strides: Convolution stride. NOTE: masking requires the convolution
            to keep the time length (see ``_compute_features``).
        padding: 'SAME' or 'VALID' (case-insensitive).
        use_bias: Whether the convolution and projections add biases.
        kernel_initializer: Initializer of the conv kernel, V and W.
        kernel_regularizer: Regularizer of the conv kernel, V and W.
        return_channel_scores: If True, ``call`` returns ``[outputs, scores]``
            with ``scores`` of shape ``(batch, 1, channel)``.
        bias_initializer: Initializer of the biases.
        bias_regularizer: Regularizer of the biases.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """
    def __init__(self, filters, kernel_size, L,
                 strides=1, padding='SAME',
                 use_bias=True, kernel_initializer='glorot_uniform',
                 kernel_regularizer=None, return_channel_scores=False,
                 bias_initializer='zeros', bias_regularizer=None,
                 **kwargs):
        super(ChannelSelfAttention1D_v2, self).__init__(**kwargs)

        self.L = L
        self.attention_heads = 1
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding

        self.use_bias = use_bias
        self.return_channel_scores = return_channel_scores

        self.conv_kernel_initializer = initializers.get(kernel_initializer)
        self.conv_bias_initializer = initializers.get(bias_initializer)
        self.conv_kernel_regularizer = regularizers.get(kernel_regularizer)
        self.conv_bias_regularizer = regularizers.get(bias_regularizer)

        self.v_initializer = initializers.get(kernel_initializer)
        self.v_regularizer = regularizers.get(kernel_regularizer)
        self.b1_initializer = initializers.get(bias_initializer)
        self.b1_regularizer = regularizers.get(bias_regularizer)

        self.w_initializer = initializers.get(kernel_initializer)
        self.w_regularizer = regularizers.get(kernel_regularizer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

    def build(self, input_shape):
        if len(input_shape) != 3:
            raise ValueError(
                'ChannelSelfAttention1D_v2 expects a 3D input (batch, time_step, channel), '
                'got shape {}'.format(input_shape))
        feature_count = self.filters
        kernel_shape = (self.kernel_size, 1, self.filters)

        self.kernel_conv = self.add_weight(
            name='kernel_conv',
            shape=kernel_shape,
            initializer=self.conv_kernel_initializer,
            regularizer=self.conv_kernel_regularizer,
            trainable=True,
            dtype=self.dtype)
        if self.use_bias:
            self.bias_conv = self.add_weight(
                name='bias_conv',
                shape=(self.filters,),
                initializer=self.conv_bias_initializer,
                regularizer=self.conv_bias_regularizer,
                trainable=True,
                dtype=self.dtype)

        self.V = self.add_weight(name="V", shape=(feature_count, self.L), trainable=True,
                                 initializer=self.v_initializer, regularizer=self.v_regularizer)

        self.W = self.add_weight(name="W", shape=(self.L, self.attention_heads), trainable=True,
                                 initializer=self.w_initializer, regularizer=self.w_regularizer)
        if self.use_bias:
            self.b = self.add_weight(name='b', shape=(self.attention_heads,), trainable=True,
                                     initializer=self.bias_initializer, regularizer=self.bias_regularizer)
            self.b1 = self.add_weight(name='b1', shape=(self.L,), trainable=True,
                                      initializer=self.b1_initializer, regularizer=self.b1_regularizer)

    def _compute_features(self, inputs, mask=None):
        """Masked mean of the ReLU conv responses of every channel.

        Args:
            inputs: Tensor laid out as (channel, batch, time_step).
            mask: Optional (batch, time_step) mask; None means all steps valid.

        Returns:
            Tensor (channel, batch, filters); rows without valid steps give 0.
        """
        if mask is None:
            mask = tf.ones(tf.shape(inputs)[1:], dtype=inputs.dtype)
        # Match the activation dtype (it can differ from self.dtype under mixed precision).
        mask = tf.cast(mask, inputs.dtype)  # (batch, time_step)
        x_length = tf.reduce_sum(mask, axis=-1, keepdims=True)  # (batch, 1)
        x = tf.expand_dims(inputs, -1)  # (channel, batch, time_step, 1)
        # tf.nn.conv1d only accepts upper-case padding names.
        padding = self.padding.upper() if isinstance(self.padding, str) else self.padding
        x = tf.nn.conv1d(x, self.kernel_conv, stride=self.strides,
                         padding=padding, data_format='NWC')
        if self.use_bias:
            x = tf.nn.bias_add(x, self.bias_conv)
        x = tf.nn.relu(x)  # (channel, batch, time_step, filters)
        # NOTE: the (batch, time_step) mask is applied to the conv output, so the
        # shapes only line up when the convolution keeps the time length
        # (strides=1 with 'SAME' padding, or kernel_size=1).
        x = x * tf.expand_dims(tf.expand_dims(mask, 0), -1)
        x_length = tf.expand_dims(x_length, 0)  # (1, batch, 1)
        return tf.math.divide_no_nan(tf.reduce_sum(x, axis=-2), x_length)

    def call(self, inputs, mask=None):
        a = tf.transpose(inputs, [2, 0, 1])  # (channel, batch, time_step)
        a = self._compute_features(a, mask)  # (channel, batch, filters)
        a = tf.matmul(a, self.V)  # (channel, batch, L)
        if self.use_bias:
            a = a + self.b1
        a = tf.tanh(a)
        a = tf.matmul(a, self.W)  # (channel, batch, 1)
        if self.use_bias:
            a = a + self.b
        a = tf.sigmoid(a)
        a = tf.transpose(a, [1, 2, 0])  # (batch, 1, channel)
        outputs = inputs * a  # (batch, time_step, channel)
        if self.return_channel_scores:
            return [outputs, a]
        return outputs

    def compute_output_shape(self, input_shape):
        shape = tf.TensorShape(input_shape)
        if shape.rank != 3:
            raise ValueError(
                'ChannelSelfAttention1D_v2 expects a 3D input (batch, time_step, channel), '
                'got shape {}'.format(input_shape))
        if self.return_channel_scores:
            # Scores are returned as (batch, 1, channel), see call().
            return [shape, tf.TensorShape([shape[0], 1, shape[-1]])]
        return shape

    def compute_mask(self, inputs, mask=None):
        # The time axis is unchanged, so the incoming mask still applies.
        return mask

    def get_config(self):
        config = super().get_config()
        config.update({
            "use_bias": self.use_bias, "L": self.L,
            "return_channel_scores": self.return_channel_scores,
            "filters": self.filters, "kernel_size": self.kernel_size,
            "strides": self.strides, "padding": self.padding,
            "kernel_initializer": initializers.serialize(self.conv_kernel_initializer),
            "kernel_regularizer": regularizers.serialize(self.conv_kernel_regularizer),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
        })
        return config