from tensorflow.keras import backend as K
#import tensorflow as tf

# def _bn_relu(layer, dropout=0, **params):
def _bn_relu(layer, dropout=0.5, **params):
    """Apply batch normalization, an activation and optional dropout.

    Args:
        layer: Keras tensor to transform.
        dropout: On/off switch for the dropout stage; dropout is added
            only when this value is greater than zero.
            NOTE: the value itself is not used as the rate; the rate is
            always read from ``params["conv_dropout"]``.
        **params: Must contain ``"conv_activation"`` (activation passed to
            ``Activation``) and, when dropout is enabled,
            ``"conv_dropout"`` (rate passed to ``Dropout``).

    Returns:
        The transformed Keras tensor.
    """
    from tensorflow.keras.layers import Activation
    from tensorflow.keras.layers import BatchNormalization

    # Normalize over axis 2 of the incoming tensor.
    normalized = BatchNormalization(axis=2)(layer)
    activated = Activation(params["conv_activation"])(normalized)

    if not dropout > 0:
        return activated

    from tensorflow.keras.layers import Dropout
    return Dropout(params["conv_dropout"])(activated)

def add_conv_weight(layer,
        filter_length,
        num_filters,
        subsample_length=1,
        **params):
    """Append a ``Conv1D`` layer with ``'same'`` padding to ``layer``.

    Args:
        layer: Keras tensor the convolution is applied to.
        filter_length: Kernel size of the convolution.
        num_filters: Number of convolution filters.
        subsample_length: Stride of the convolution (defaults to 1).
        **params: Must contain ``"conv_init"``, used as the kernel
            initializer; any other keys are ignored here.

    Returns:
        The Keras tensor produced by the convolution.
    """
    from tensorflow.keras.layers import Conv1D

    # Build the layer first, then apply it to the incoming tensor.
    convolution = Conv1D(
        filters=num_filters,
        kernel_size=filter_length,
        strides=subsample_length,
        padding='same',
        kernel_initializer=params["conv_init"])
    return convolution(layer)


def add_conv_layers(layer, **params):
    """Stack plain (non-residual) convolution stages on ``layer``.

    One stage is added per entry of ``params["conv_subsample_lengths"]``;
    each stage is a convolution (via ``add_conv_weight``) using that entry
    as stride, followed by ``_bn_relu``.

    Args:
        layer: Keras tensor to build on.
        **params: Must contain ``"conv_subsample_lengths"``,
            ``"conv_filter_length"`` and ``"conv_num_filters_start"``, plus
            whatever ``add_conv_weight`` and ``_bn_relu`` read.

    Returns:
        The Keras tensor after the last stage (``layer`` unchanged when
        the list of subsample lengths is empty).
    """
    current = layer
    for stride in params["conv_subsample_lengths"]:
        convolved = add_conv_weight(
            current,
            params["conv_filter_length"],
            params["conv_num_filters_start"],
            subsample_length=stride,
            **params)
        current = _bn_relu(convolved, **params)
    return current

def resnet_block(
        layer,
        num_filters,
        subsample_length,
        block_index,
        **params):
    """Build one residual block: a pooled shortcut added to a conv stack.

    The shortcut path max-pools ``layer`` by ``subsample_length`` and, on
    blocks where the filter count grows, doubles axis 2 by appending
    zeros. The main path applies ``params["conv_num_skip"]`` convolutions;
    only the first one uses ``subsample_length`` as stride. The two paths
    are summed with ``Add``.

    Args:
        layer: Keras tensor entering the block.
        num_filters: Number of filters for every convolution in the block.
        subsample_length: Pool size of the shortcut and stride of the
            first convolution.
        block_index: Zero-based position of this block in the network.
        **params: Must contain ``"conv_increase_channels_at"``,
            ``"conv_num_skip"``, ``"conv_dropout"`` and
            ``"conv_filter_length"``, plus whatever ``_bn_relu`` and
            ``add_conv_weight`` read.

    Returns:
        The Keras tensor produced by adding shortcut and main path.
    """
    from tensorflow.keras.layers import Add
    from tensorflow.keras.layers import MaxPooling1D
    from tensorflow.keras.layers import Lambda

    def zeropad(x):
        # Double the size of axis 2 by concatenating an all-zero copy.
        y = K.zeros_like(x)
        return K.concatenate([x, y], axis=2)

    def zeropad_output_shape(input_shape):
        # Shape reported to Keras for the Lambda: axis 2 is doubled.
        shape = list(input_shape)
        shape[2] *= 2
        return tuple(shape)

    # padding='same' makes the pooled length ceil(L / subsample_length),
    # matching the strided 'same'-padded Conv1D on the main path. With the
    # default 'valid' padding the shortcut is floor(L / subsample_length)
    # long, so Add() fails whenever L is not a multiple of the stride.
    # For lengths that are multiples the result is unchanged.
    shortcut = MaxPooling1D(
        pool_size=subsample_length, padding='same')(layer)

    # Zero-pad the shortcut on every block whose index is a positive
    # multiple of conv_increase_channels_at.
    zero_pad = (block_index % params["conv_increase_channels_at"]) == 0 \
        and block_index > 0
    if zero_pad:
        shortcut = Lambda(zeropad, output_shape=zeropad_output_shape)(shortcut)

    for i in range(params["conv_num_skip"]):
        # The very first convolution of block 0 is not preceded by
        # BN/activation; dropout is requested only from the second
        # convolution of a block onwards.
        if not (block_index == 0 and i == 0):
            layer = _bn_relu(
                layer,
                dropout=params["conv_dropout"] if i > 0 else 0,
                **params)
        layer = add_conv_weight(layer,
            params["conv_filter_length"],
            num_filters,
            subsample_length if i == 0 else 1,
            **params)

    layer = Add()([shortcut, layer])
    return layer

def get_num_filters_at_index(index, num_start_filters, **params):
    """Return the filter count for the residual block at ``index``.

    The count starts at ``num_start_filters`` and is multiplied by two
    once for every ``params["conv_increase_channels_at"]`` blocks.

    Args:
        index: Zero-based block index.
        num_start_filters: Filter count of the first blocks.
        **params: Must contain ``"conv_increase_channels_at"``.

    Returns:
        ``num_start_filters`` times a power of two.
    """
    # Truncate the true-division quotient to get the number of doublings.
    doublings = int(index / params["conv_increase_channels_at"])
    scale = 2 ** doublings
    return scale * num_start_filters

def add_resnet_layers(layer, **params):
    """Build the residual trunk of the network on top of ``layer``.

    The trunk is: one stride-1 convolution followed by ``_bn_relu``, then
    one ``resnet_block`` per entry of ``params["conv_subsample_lengths"]``
    (filter counts come from ``get_num_filters_at_index``), then a final
    ``_bn_relu``.

    Args:
        layer: Keras tensor to build on.
        **params: Must contain ``"conv_filter_length"``,
            ``"conv_num_filters_start"`` and ``"conv_subsample_lengths"``,
            plus whatever the helper functions read.

    Returns:
        The Keras tensor at the end of the trunk.
    """
    # Stem: a single convolution with stride 1, then BN + activation.
    stem = add_conv_weight(
        layer,
        params["conv_filter_length"],
        params["conv_num_filters_start"],
        subsample_length=1,
        **params)
    trunk = _bn_relu(stem, **params)

    start_filters = params["conv_num_filters_start"]
    for block_index, stride in enumerate(params["conv_subsample_lengths"]):
        block_filters = get_num_filters_at_index(
            block_index, start_filters, **params)
        trunk = resnet_block(
            trunk,
            block_filters,
            stride,
            block_index,
            **params)

    return _bn_relu(trunk, **params)

def add_output_layer(layer, **params):
    """Attach the classification head to ``layer``.

    The head is ``GlobalAveragePooling1D`` followed by a 27-unit ``Dense``
    layer and a separate sigmoid ``Activation``.

    Args:
        layer: Keras tensor produced by the convolutional trunk.
        **params: Accepted for signature parity with the other builders;
            not read by this function (the unit count is hard-coded).

    Returns:
        The Keras tensor holding the sigmoid outputs.
    """
    # TimeDistributed is imported but unused; it belonged to an earlier
    # per-timestep variant of this head.
    from tensorflow.keras.layers import Activation
    from tensorflow.keras.layers import Dense
    from tensorflow.keras.layers import GlobalAveragePooling1D
    from tensorflow.keras.layers import TimeDistributed

    pooled = GlobalAveragePooling1D()(layer)
    scores = Dense(27)(pooled)
    return Activation('sigmoid')(scores)

def add_compile(model, **params):
    """Compile ``model`` with Adam and binary cross-entropy.

    Args:
        model: Object exposing a Keras-style ``compile`` method.
        **params: Must contain ``"learning_rate"``. ``"clipnorm"`` is
            optional and defaults to 1.

    Returns:
        None. ``model`` is compiled in place.
    """
    from tensorflow.keras.optimizers import Adam

    # Use `learning_rate`: the `lr` alias is deprecated in tf.keras and is
    # rejected by newer Keras releases.
    optimizer = Adam(
        learning_rate=params["learning_rate"],
        clipnorm=params.get("clipnorm", 1))

    model.compile(loss='binary_crossentropy',
                  optimizer=optimizer,
                  metrics=['accuracy'])


def CNNetR(inputs):
    """Build the residual network graph on ``inputs`` and return its output.

    Hyper-parameters are loaded from ``config_resnet.json``, resolved
    against the current working directory. The function returns the output
    tensor only; it does not create or compile a Keras ``Model``.

    Args:
        inputs: Keras tensor used as the network input.

    Returns:
        The Keras tensor produced by the output head after ``Flatten``.

    Raises:
        OSError: If ``config_resnet.json`` cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    print('Inside resnet-network: Input size = ', inputs)

    import json
    from tensorflow.keras.layers import Flatten

    # Use a context manager so the config file handle is always closed.
    with open("config_resnet.json") as config_file:
        params = json.load(config_file)

    layer = add_resnet_layers(inputs, **params)

    output = add_output_layer(layer, **params)

    output = Flatten()(output)

    return output
