from keras import Model
from keras.layers import Conv1D, MaxPooling1D, BatchNormalization, Activation, Add, Input, ZeroPadding1D
from keras.layers import Flatten, Dense, GlobalAveragePooling1D, Dropout, GRU, Bidirectional, LeakyReLU

class ResNet:
    """Builder for a 1-D residual network (ResNet-style) using Keras layers.

    All layers use Keras' default ``channels_last`` layout, so tensors flowing
    through the network are shaped ``(batch, steps, channels)``. Batch
    normalization is therefore applied over the last (channel) axis.
    """

    # Conv1D defaults to channels_last, so the feature axis is the last one.
    # Normalizing over axis=1 would instead normalize over time steps and
    # allocate one gamma/beta pair per step.
    BN_AXIS = -1
    BN_EPSILON = 1e-5

    def __init__(self):
        pass

    def _BatchNorm(self, x):
        """Apply BatchNormalization over the channel axis and return the result."""
        return BatchNormalization(axis=self.BN_AXIS, epsilon=self.BN_EPSILON)(x)

    def Conv1DBlock(self, x, nFilters, width=3, stride=1, pool=2, activation='relu'):
        """Conv1D ('same' padding) with activation, optionally followed by max pooling.

        Args:
            x: Input tensor.
            nFilters: Number of convolution filters.
            width: Kernel size.
            stride: Convolution stride.
            pool: Max-pooling size, or None to skip pooling.
            activation: Activation passed to Conv1D.

        Returns:
            The output tensor.
        """
        x = Conv1D(nFilters, width, strides=stride, activation=activation, padding='same')(x)
        if pool is not None:
            x = MaxPooling1D(pool)(x)
        return x

    def ResidUnit(self, x, nFilters, width=3, stride=1, extraConv=False):
        """Basic residual unit with two conv layers and a shortcut.

        Main path: pad -> Conv1D(stride) -> BN -> ReLU -> pad -> Conv1D -> BN.
        The shortcut is the identity unless ``extraConv`` is set. It is also
        replaced when the main path changes the length (``stride != 1``) or the
        channel count. In those cases a 1x1 Conv1D + BN projection is used so
        both branches have the same shape for the ``Add``.

        Args:
            x: Input tensor.
            nFilters: Number of filters in both convolutions.
            width: Kernel size of the main-path convolutions.
            stride: Stride of the first convolution (and of the projection).
            extraConv: Force a projection shortcut even if shapes already match.

        Returns:
            ReLU(main path + shortcut).
        """
        # Pad by width-1 in total so a 'valid' conv keeps the length at stride 1.
        # For width=3 this is the symmetric padding of 1 on each side.
        padding = ((width - 1) // 2, width // 2)

        y = ZeroPadding1D(padding=padding)(x)
        y = Conv1D(nFilters, width, strides=stride, use_bias=False)(y)
        y = self._BatchNorm(y)
        y = Activation('relu')(y)
        y = ZeroPadding1D(padding=padding)(y)
        # Only the first conv downsamples; striding again would make the main
        # path shorter than the shortcut.
        y = Conv1D(nFilters, width, strides=1, use_bias=False)(y)
        y = self._BatchNorm(y)

        needsProjection = extraConv or stride != 1 or x.shape[-1] != nFilters
        if needsProjection:
            x = Conv1D(nFilters, 1, strides=stride, use_bias=False)(x)
            x = self._BatchNorm(x)

        z = Add()([y, x])
        return Activation('relu')(z)

    def ConvBlock(self, x, nFilters, width=3):
        """Three-conv residual block with a convolutional (width x width) shortcut.

        Every convolution uses 'same' padding and stride 1, so the main path
        and the shortcut keep the input length and can be added directly.

        Args:
            x: Input tensor.
            nFilters: Number of filters in every convolution.
            width: Kernel size of every convolution, including the shortcut.

        Returns:
            ReLU(main path + shortcut).
        """
        y = Conv1D(nFilters, width, strides=1, activation='linear', padding='same')(x)
        y = self._BatchNorm(y)
        y = Activation('relu')(y)
        y = Conv1D(nFilters, width, strides=1, activation='linear', padding='same')(y)
        y = self._BatchNorm(y)
        y = Activation('relu')(y)
        y = Conv1D(nFilters, width, strides=1, activation='linear', padding='same')(y)
        y = self._BatchNorm(y)
        xShort = Conv1D(nFilters, width, strides=1, activation='linear', padding='same')(x)
        xShort = self._BatchNorm(xShort)
        z = Add()([y, xShort])
        return Activation('relu')(z)

    def ResidStack(self, x, nFilters, width=3, pool=2, stage=0):
        """Two residual units followed by optional max pooling.

        The first unit always uses a projection shortcut so the stack can change
        the channel count. The second unit uses an identity shortcut.

        Args:
            x: Input tensor.
            nFilters: Number of filters for both residual units.
            width: Kernel size for both residual units.
            pool: Max-pooling size, or None to skip pooling.
            stage: Stage index. Accepted for API compatibility; every stage is
                currently built identically.

        Returns:
            The output tensor.
        """
        x = self.ResidUnit(x, nFilters, width=width, stride=1, extraConv=True)
        x = self.ResidUnit(x, nFilters, width=width, stride=1)
        if pool is not None:
            x = MaxPooling1D(pool)(x)
        return x

    def ResNet1D(self, inputShape=(72000,12), nFilters=12, nClasses=9):
        """Build the full 1-D ResNet classifier.

        Architecture: stem (pad, 7-wide stride-2 conv, BN, ReLU, 3-wide
        stride-2 max pool), then four residual stacks with nFilters,
        2*nFilters, 4*nFilters and 8*nFilters filters. Each stack ends with a
        2x max pool. Finally come global average pooling, dropout 0.5 and a
        softmax Dense layer.

        Args:
            inputShape: Per-sample input shape, (steps, channels) under the
                default channels_last layout.
            nFilters: Base filter count for the residual stacks.
            nClasses: Number of output classes.

        Returns:
            An uncompiled keras ``Model`` named 'ResNet1D'.
        """
        inputs = Input(inputShape)

        # Stage 1 (stem). Its width is fixed at 12 filters, independent of
        # nFilters; stage 2's projection shortcut adapts the channel count.
        x = ZeroPadding1D(padding=3)(inputs)
        x = Conv1D(12, 7, strides=2, use_bias=False, name="conv1")(x)
        x = self._BatchNorm(x)
        x = Activation("relu")(x)
        x = MaxPooling1D(3, strides=2, padding="same")(x)

        # Stages 2-5: the filter count doubles at each stage.
        x = self.ResidStack(x, nFilters=nFilters, stage=0)
        x = self.ResidStack(x, nFilters=2*nFilters, stage=1)
        x = self.ResidStack(x, nFilters=4*nFilters, stage=2)
        x = self.ResidStack(x, nFilters=8*nFilters, stage=3)

        # Classification head.
        x = GlobalAveragePooling1D()(x)
        x = Dropout(0.5)(x)
        x = Dense(nClasses, activation='softmax')(x)

        return Model(inputs=inputs, outputs=x, name='ResNet1D')

#######################################################################################
if __name__ == '__main__':
    # Smoke test: build the default model and print its architecture.
    myResid = ResNet()
    inputShape = (72000,12)
    model = myResid.ResNet1D(inputShape)
    model.summary()