
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

from tensorflow.keras import backend as K
import keras
from keras.layers import BatchNormalization, Activation, Dropout
from keras.layers import GlobalMaxPooling1D, GlobalAveragePooling1D
from keras.layers import Conv1D, Dense, Add, MaxPooling1D, LSTM, Input, CuDNNLSTM
from keras.layers import AveragePooling1D, Concatenate, Bidirectional
from keras.layers import Reshape, Dot, Multiply, Layer
from keras.optimizers import Adam
from keras.models import Model
from keras import regularizers
from keras.losses import binary_crossentropy

import tensorflow as tf

class SeqWeightedAttention(Layer):
    """Collapse a sequence into a single vector with learned attention weights.

    For an input ``x`` of shape ``(batch, time_steps, channels)`` the layer
    computes ``tanh(x . W + b)``, projects every time step to a scalar score
    with ``U``, turns the scores into weights that sum to one over the time
    axis, and returns the weighted sum of ``x`` over time, shaped
    ``(batch, channels)``.

    Args:
        use_bias: add the bias ``b`` before the ``tanh`` when True.
        return_attention: when True, ``call`` returns
            ``[result, attention_weights]`` instead of only ``result``.
        **kwargs: forwarded to the base ``Layer`` constructor.

    references:
    [1] https://github.com/CyberZHG/keras-self-attention/blob/master/keras_self_attention/seq_weighted_attention.py
    """

    def __init__(self, use_bias=True, return_attention=False, **kwargs):
        super(SeqWeightedAttention, self).__init__(**kwargs)
        self.supports_masking = True
        self.use_bias = use_bias
        self.return_attention = return_attention
        # Weights are created in build(), once the channel count is known.
        self.W, self.b, self.U = None, None, None

    def get_config(self):
        """Return the constructor arguments merged into the base layer config."""
        base_config = super(SeqWeightedAttention, self).get_config()
        own_config = {
            'use_bias': self.use_bias,
            'return_attention': self.return_attention,
        }
        return dict(base_config, **own_config)

    def build(self, input_shape):
        """Create ``W`` (channels x channels), optional ``b`` and ``U`` (channels x 1)."""
        # input_shape : (batch, time_steps, channels)
        channels = int(input_shape[-1])
        self.W = self.add_weight(shape=(channels, channels),
                                 name='{}_W'.format(self.name),
                                 initializer=keras.initializers.get('glorot_uniform'))
        if self.use_bias:
            self.b = self.add_weight(shape=(channels,),
                                     name='{}_b'.format(self.name),
                                     initializer=keras.initializers.get('zeros'))

        self.U = self.add_weight(shape=(channels, 1),
                                 name='{}_U'.format(self.name),
                                 initializer=keras.initializers.get('glorot_uniform'))

        super(SeqWeightedAttention, self).build(input_shape)

    def call(self, x, mask=None):
        """Return the attention-weighted sum of ``x`` over the time axis.

        Args:
            x: tensor of shape ``(batch, time_steps, channels)``.
            mask: optional ``(batch, time_steps)`` mask propagated by Keras.
                Masked time steps receive zero attention weight.

        Returns:
            ``result`` of shape ``(batch, channels)``, or
            ``[result, att_weights]`` when ``return_attention`` is set, where
            ``att_weights`` has shape ``(batch, time_steps)``.
        """
        uit = K.dot(x, self.W)
        if self.use_bias:
            uit = uit + self.b
        uit = K.tanh(uit)
        logits = K.dot(uit, self.U)
        x_shape = K.shape(x)
        # Drop the trailing singleton axis: (batch, time_steps, 1) -> (batch, time_steps).
        logits = K.reshape(logits, (x_shape[0], x_shape[1]))
        # Subtract the per-sample maximum before exponentiating for numerical stability.
        ai = K.exp(logits - K.max(logits, axis=-1, keepdims=True))
        if mask is not None:
            # The layer declares supports_masking, so padded steps must not
            # contribute to the normalisation or to the weighted sum.
            ai = ai * K.cast(mask, K.floatx())

        # epsilon guards against division by zero (e.g. a fully masked sample).
        att_weights = ai / (K.sum(ai, axis=1, keepdims=True) + K.epsilon())
        weighted_input = x * K.expand_dims(att_weights)
        result = K.sum(weighted_input, axis=1)
        if self.return_attention:
            return [result, att_weights]
        return result

    def compute_output_shape(self, input_shape):
        """Return ``(batch, channels)``, plus ``(batch, time_steps)`` for the weights if requested."""
        output_len = input_shape[2]
        if self.return_attention:
            return [(input_shape[0], output_len), (input_shape[0], input_shape[1])]
        return input_shape[0], output_len

    def compute_mask(self, _, input_mask=None):
        """Stop mask propagation: the time axis is reduced away by this layer."""
        if self.return_attention:
            return [None, None]
        return None

    @staticmethod
    def get_custom_objects():
        """Return the name-to-class mapping for this layer (usable as ``custom_objects``)."""
        return {'SeqWeightedAttention': SeqWeightedAttention}


class ResSENet:
    """Builder for a 1-D residual CNN with squeeze-and-excitation (SE) blocks.

    The class keeps no state. Every method receives the hyper-parameters as
    keyword arguments (``**params``) and attaches Keras layers to the tensor
    it is given.

    Keys read from ``params`` by the methods of this class:
        ``input_shape``, ``conv_activation``, ``conv_dropout``, ``conv_init``,
        ``conv_filter_length1``, ``conv_filter_length2``,
        ``conv_filter_length3``, ``conv_num_filters_start``,
        ``num_categories``, ``learning_rate`` and, optionally,
        ``pool_stride`` (default 2), ``clipnorm`` (default 1) and
        ``compile`` (default True).

    Several methods print the shape of the tensor they produce while the
    graph is being built.
    """

    def __init__(self):
        """Nothing to initialise; the builder is stateless."""

    def _bn_relu(self, layer, dropout=0, **params):
        """Apply BatchNormalization, the configured activation and optional dropout.

        Args:
            layer: input tensor.
            dropout: on/off switch only. Any value > 0 enables dropout, but
                the rate itself is always taken from ``params['conv_dropout']``.
            **params: must contain ``conv_activation``; ``conv_dropout`` is
                read when dropout is enabled.

        Returns:
            The transformed tensor.
        """
        layer = BatchNormalization()(layer)
        layer = Activation(params['conv_activation'])(layer)

        if dropout > 0:
            layer = Dropout(params['conv_dropout'])(layer)
        return layer

    def add_conv_weight(self, layer, filter_length, num_filters,
                        stride=1, padding='same', **params):
        """Append a Conv1D layer initialised with ``params['conv_init']``.

        Args:
            layer: input tensor.
            filter_length: convolution kernel size.
            num_filters: number of output channels.
            stride: convolution stride.
            padding: Conv1D padding mode; defaults to ``'same'``.
            **params: must contain ``conv_init``. Because callers splat the
                whole hyper-parameter dict, a ``stride`` or ``padding`` key in
                it binds to the arguments above.

        Returns:
            The convolved tensor.
        """
        return Conv1D(filters=num_filters,
                      kernel_size=filter_length,
                      strides=stride,
                      padding=padding,
                      kernel_initializer=params['conv_init'])(layer)

    def conv_layer(self, layer, num_filters,
                   filter_length, dropout=0, stride=1, **params):
        """Pre-activation convolution: BN -> activation -> (dropout) -> Conv1D."""
        conv = self._bn_relu(layer, dropout=dropout, **params)
        conv = self.add_conv_weight(conv,
                                    filter_length=filter_length,
                                    num_filters=num_filters,
                                    stride=stride, **params)
        return conv

    @staticmethod
    def _se_bottleneck_units(num_channels, reduction):
        """Return the SE bottleneck width, never smaller than one unit."""
        return max(1, num_channels // reduction)

    def se_block(self, layer, num_channels, reduction=8):
        """Re-weight the channels of ``layer`` with a squeeze-and-excitation gate.

        The tensor is averaged over time, passed through a bottleneck of
        ``num_channels // reduction`` units (at least one) and expanded back
        to ``num_channels`` gate values that multiply the input.

        Args:
            layer: input tensor whose last axis has ``num_channels`` entries.
            num_channels: channel count of ``layer``.
            reduction: bottleneck reduction ratio.

        Returns:
            The gated tensor, same shape as ``layer``.
        """
        squeezed = GlobalAveragePooling1D()(layer)
        gate = Dense(units=self._se_bottleneck_units(num_channels, reduction),
                     use_bias=False)(squeezed)
        # NOTE(review): the SE paper uses ReLU for this first activation;
        # sigmoid is kept so previously trained models behave the same.
        gate = Activation('sigmoid')(gate)
        gate = Dense(units=num_channels, use_bias=False)(gate)
        gate = Activation('sigmoid')(gate)
        # Add a length-1 time axis so the gate multiplies across every time step.
        gate = Reshape((1, num_channels))(gate)
        outlayer = Multiply()([layer, gate])
        print('the seblock output shape is {}'.format(outlayer.shape))
        return outlayer

    def _main_path(self, layer, num_filters, filter_length, pool_stride, **params):
        """Build the shared residual branch: two conv layers, average pooling, SE gate."""
        branch = self.conv_layer(layer, num_filters, filter_length, dropout=0, **params)
        branch = self.conv_layer(branch, num_filters, filter_length,
                                 dropout=params['conv_dropout'], **params)
        branch = AveragePooling1D(pool_size=pool_stride)(branch)
        return self.se_block(branch, num_filters)

    def conv_block1(self, layer, num_filters, filter_length, pool_stride=2, **params):
        """Residual block with an identity (max-pooled) shortcut.

        The shortcut carries no convolution, so ``layer`` must already have
        ``num_filters`` channels for the final addition to be valid.
        """
        conv_layer = self._main_path(layer, num_filters, filter_length, pool_stride, **params)
        short_cut = MaxPooling1D(pool_size=pool_stride)(layer)
        out_layer = Add()([conv_layer, short_cut])
        print('the out layer of block1, shape={}'.format(out_layer.shape))
        return out_layer

    def conv_block2(self, layer, num_filters, filter_length, pool_stride=2, **params):
        """Residual block whose shortcut is a convolution followed by max pooling.

        The shortcut convolution uses kernel size
        ``params['conv_filter_length3']`` and outputs ``num_filters`` channels.
        """
        conv_layer = self._main_path(layer, num_filters, filter_length, pool_stride, **params)
        short_cut = self.add_conv_weight(layer,
                                         filter_length=params['conv_filter_length3'],
                                         num_filters=num_filters,
                                         **params)
        short_cut = MaxPooling1D(pool_size=pool_stride)(short_cut)
        out_layer = Add()([conv_layer, short_cut])
        print('the out layer of block2, shape={}'.format(out_layer.shape))
        return out_layer

    def add_resnet_layers(self, layer, **params):
        """Build the convolutional trunk and return its pooled feature vector.

        A stem convolution is followed by nine residual SE blocks, a final
        BN/activation, one more SE gate, and the concatenation of global
        average and global max pooling over time.

        ``params['pool_stride']``, when present, reaches every block through
        ``**params``; otherwise the blocks use their default of 2.
        """
        len1 = params['conv_filter_length1']
        len2 = params['conv_filter_length2']

        layer = self.add_conv_weight(layer,
                                     len1,
                                     params['conv_num_filters_start'],
                                     **params)

        # (block builder, output channels, kernel size), applied in order.
        stages = (
            (self.conv_block2, 32, len1),
            (self.conv_block1, 32, len1),
            (self.conv_block2, 64, len1),
            (self.conv_block1, 64, len2),
            (self.conv_block2, 128, len2),
            (self.conv_block1, 128, len2),
            (self.conv_block2, 256, len2),
            (self.conv_block1, 256, len1),
            (self.conv_block2, 256, len1),
        )
        resnet_layer = layer
        for block, num_filters, filter_length in stages:
            resnet_layer = block(resnet_layer,
                                 num_filters=num_filters,
                                 filter_length=filter_length,
                                 **params)

        resnet_layer = self._bn_relu(resnet_layer, dropout=0, **params)

        # The gate width matches the channel count of the last stage.
        resnet_layer = self.se_block(resnet_layer, stages[-1][1])

        layer1 = GlobalAveragePooling1D()(resnet_layer)
        layer2 = GlobalMaxPooling1D()(resnet_layer)
        concat_layer = Concatenate()([layer1, layer2])
        print('the out layer of concat_layer, shape={}'.format(concat_layer.shape))
        return concat_layer

    def add_lstm_layer(self, layer, hidden_units, **params):
        """Append a bidirectional CuDNNLSTM that returns the full sequence.

        Each direction gets ``hidden_units / 2`` units. ``**params`` is
        accepted for signature parity with the other builders and is unused.

        NOTE: CuDNNLSTM is a GPU-only Keras layer.
        """
        layer = Bidirectional(CuDNNLSTM(units=int(hidden_units / 2), return_sequences=True))(layer)
        print('bilstm layer output shape is ' + str(layer.shape))
        return layer

    def add_output_layer(self, layer, **params):
        """Append a Dense layer of ``params['num_categories']`` units with sigmoid output."""
        layer = Dense(params["num_categories"])(layer)
        print('Dense layer output shape is ' + str(layer.shape))
        return Activation('sigmoid')(layer)

    def add_compile(self, model, **params):
        """Compile ``model`` with Adam, binary cross-entropy and accuracy.

        Reads ``params['learning_rate']`` and the optional
        ``params['clipnorm']`` (default 1).
        """
        optimizer = Adam(lr=params['learning_rate'],
                         clipnorm=params.get('clipnorm', 1))
        # TODO: improve the loss function here later.
        model.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])

    def build_network(self, **params):
        """Assemble the full model: input -> residual SE trunk -> sigmoid output.

        Args:
            **params: hyper-parameters; ``input_shape`` defines the input
                tensor and ``compile`` (default True) controls whether the
                model is compiled before being returned.

        Returns:
            The Keras ``Model``.
        """
        inputs = Input(shape=params['input_shape'],
                       dtype='float32',
                       name='inputs')
        layer = self.add_resnet_layers(inputs, **params)

        # The recurrent/attention head (add_lstm_layer + SeqWeightedAttention)
        # is currently not wired in.
        output = self.add_output_layer(layer, **params)

        model = Model(inputs=[inputs], outputs=[output])
        # Compile unless the caller explicitly passes compile=False.
        if params.get('compile', True):
            self.add_compile(model, **params)
        return model