# ResNet based on https://www.nature.com/articles/s41467-020-15432-4.pdf

import tensorflow
from tensorflow import keras as K

# Dropout *rate* (fraction of units dropped, in [0, 1) — not a percentage
# despite the name) shared by the Dropout layers defined below.
dropoutPerc = 0.50

class Conv1Dblock(K.layers.Layer):
    """Conv1D -> BatchNormalization -> ReLU, used as the network's input stem.

    Args:
        nFilters: number of output channels of the convolution.
        width: convolution kernel size.
        stride: convolution stride. 'same' padding is used, so the output
            length is ceil(input_length / stride).
        **kwargs: forwarded to ``K.layers.Layer`` (e.g. ``name``).

    This derives from ``Layer`` rather than ``Sequential`` because it defines
    its own ``call``. A ``Sequential`` subclass would also try to build an
    implicit chain graph from its attributes (TF2 Keras). In Keras 3 it
    refuses to build at all because no layers were ``add()``-ed.
    """

    def __init__(self, nFilters, width=20, stride=1, **kwargs):
        super().__init__(**kwargs)
        self.conv = K.layers.Conv1D(nFilters, width, strides=stride, padding='same')
        self.batchnorm = K.layers.BatchNormalization()
        self.act = K.layers.Activation('relu')

    def call(self, x, training=None):
        """Apply conv, batch norm (train/inference aware) and ReLU to ``x``."""
        x = self.conv(x)
        # Pass `training` explicitly so BN uses batch vs. moving statistics
        # correctly even when the block is called outside fit()/predict().
        x = self.batchnorm(x, training=training)
        return self.act(x)

class ResidBlock(K.layers.Layer):
    """Two-stream residual block that downsamples its input length by 4.

    ``call(top, bot)`` returns ``(newTop, newBot)``:

    * skip path:  ``top`` -> MaxPool1D(pool) -> 1x1 Conv1D (projects to
      ``nFilters`` channels)
    * main path:  ``bot`` -> Conv1D(stride 2) -> BN -> ReLU -> Dropout
      -> Conv1D(stride 2)
    * ``merge = main + skip``
    * ``newTop = merge`` (un-normalized; feeds the next block's skip path)
    * ``newBot = Dropout(ReLU(BN(merge)))`` (feeds the next block's main path)

    Args:
        nFilters: number of output channels of every convolution.
        width: kernel size of the two main-path convolutions.
        pool: pooling factor of the skip path. The main path always
            downsamples by 2 * 2 = 4, so ``pool`` must stay 4 (the default)
            for the two paths to have the same length.
        **kwargs: forwarded to ``K.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, nFilters, width=20, pool=4, **kwargs):
        super().__init__(**kwargs)
        # 'same' padding gives ceil(n / pool), which matches the main path's
        # ceil(ceil(n / 2) / 2) == ceil(n / 4). With the default 'valid'
        # padding (floor(n / pool)), the residual add failed for lengths not
        # divisible by 4.
        self.maxpool = K.layers.MaxPool1D(pool, padding='same')
        self.conv1x1 = K.layers.Conv1D(nFilters, 1, strides=1, padding='same')
        self.conv1 = K.layers.Conv1D(nFilters, width, strides=2, activation='linear', padding='same')
        self.conv2 = K.layers.Conv1D(nFilters, width, strides=2, activation='linear', padding='same')
        self.batchnorm1 = K.layers.BatchNormalization()
        self.act1 = K.layers.Activation('relu')
        self.dropout1 = K.layers.Dropout(dropoutPerc)
        self.batchnorm2 = K.layers.BatchNormalization()
        self.act2 = K.layers.Activation('relu')
        self.dropout2 = K.layers.Dropout(dropoutPerc)

    def call(self, top, bot, training=None):
        """Run both paths and return ``(newTop, newBot)`` as described above."""
        # Skip path: downsample, then project the channels to nFilters.
        skip = self.conv1x1(self.maxpool(top))
        # Main path.
        bot = self.conv1(bot)
        bot = self.batchnorm1(bot, training=training)
        bot = self.act1(bot)
        bot = self.dropout1(bot, training=training)
        bot = self.conv2(bot)
        # Merge the two paths.
        merge = bot + skip
        # New top is the raw sum; new bottom is its normalized/activated form.
        newBot = self.batchnorm2(merge, training=training)
        newBot = self.act2(newBot)
        newBot = self.dropout2(newBot, training=training)
        return merge, newBot

class NatureResid1D(K.Model):
    """Subclassed 1-D residual classifier with a multi-label sigmoid head.

    Layout: Conv1Dblock stem (64 filters) -> four ResidBlocks (64 filters,
    each downsampling the length by 4) -> Flatten -> Dense(128, relu)
    -> Dropout -> Dense(nClasses, sigmoid).

    Args:
        nClasses: number of independent sigmoid outputs.
        **kwargs: forwarded to ``K.Model`` (e.g. ``name``).

    This derives from ``Model`` rather than ``Sequential`` because it defines
    its own ``call``, which threads two streams through the residual blocks.
    A ``Sequential`` subclass would try to chain its attributes as
    single-input layers.
    """

    def __init__(self, nClasses=27, **kwargs):
        super().__init__(**kwargs)
        self.preamble = Conv1Dblock(64, width=20)
        # Every block uses 64 filters. The trailing numbers are alternative
        # filter counts that are currently NOT used.
        self.resid1 = ResidBlock(64, width=20)
        self.resid2 = ResidBlock(64, width=20) #128
        self.resid3 = ResidBlock(64, width=20) #182
        self.resid4 = ResidBlock(64, width=20) #256
        self.flatten = K.layers.Flatten()
        self.dense1 = K.layers.Dense(128, activation='relu')
        self.dropout1 = K.layers.Dropout(dropoutPerc)
        self.dense2 = K.layers.Dense(nClasses, activation='sigmoid')

    def call(self, x, training=None):
        """Return per-class sigmoid probabilities for input batch ``x``."""
        x = self.preamble(x, training=training)
        # Both streams start from the stem output.
        top, bot = self.resid1(x, x, training=training)
        top, bot = self.resid2(top, bot, training=training)
        top, bot = self.resid3(top, bot, training=training)
        # Only the activated bottom stream of the last block feeds the head.
        _, bot = self.resid4(top, bot, training=training)
        y = self.flatten(bot)
        y = self.dense1(y)
        y = self.dropout1(y, training=training)
        return self.dense2(y)

def get_forked_model(nSamples=4096, nChannels=12, nClasses=27, nRF_features=14):
    """Build and compile a two-input (functional API) variant of the network.

    Branch 1 feeds an ``(nSamples, nChannels)`` input through the same stem
    and four 64-filter residual blocks used by ``NatureResid1D``. Branch 2 is
    a flat vector of ``nRF_features`` values. It is concatenated with the
    flattened branch-1 features before the Dense(128) -> Dropout ->
    Dense(nClasses, sigmoid) head.

    Returns:
        A ``K.Model`` with inputs ``[signal, features]``, compiled with Adam,
        binary cross-entropy and accuracy.
    """
    signal_in = K.layers.Input(shape=(nSamples, nChannels))
    extra_in = K.layers.Input(shape=(nRF_features,))

    # Stem, then four residual blocks; both streams start from the stem output.
    stem_out = Conv1Dblock(64, width=20)(signal_in)
    top = bot = stem_out
    for _ in range(4):
        top, bot = ResidBlock(64, width=20)(top, bot)

    # Head: flattened bottom stream joined with the extra feature vector.
    head = K.layers.concatenate([K.layers.Flatten()(bot), extra_in])
    head = K.layers.Dense(128, activation='relu')(head)
    head = K.layers.Dropout(dropoutPerc)(head)
    probs = K.layers.Dense(nClasses, activation='sigmoid')(head)

    forked = K.Model(inputs=[signal_in, extra_in], outputs=probs)
    forked.compile("adam", "binary_crossentropy", ["accuracy"])
    return forked


#########################################################################################
if __name__ == '__main__':
    # Smoke test: build the subclassed model and run one fit() pass on two
    # dummy samples.
    import numpy as np

    nClasses = 27
    x = np.ones([2, 4096, 12], dtype=np.float32)
    # Multi-label targets must have the same (batch, nClasses) shape as the
    # sigmoid output: sample 0 is positive for every class, sample 1 negative.
    y = np.zeros([2, nClasses], dtype=np.float32)
    y[0] = 1.0
    model = NatureResid1D(nClasses=nClasses)
    model.compile("adam", "binary_crossentropy", ["accuracy"])
    model.fit(x, y)
    # summary() prints itself and returns None, so it must not be wrapped
    # in print().
    model.summary()