import tensorflow as tf
from tensorflow import keras as K

class Conv1Dblock(K.Model):
    """Single 1-D convolution, optionally followed by max pooling.

    Args:
        nFilters: number of output channels of the convolution.
        width: convolution kernel size.
        stride: convolution stride.
        pool: max-pool size applied after the convolution; ``None`` disables
            pooling.
        act: activation passed to the convolution layer.
    """

    def __init__(self, nFilters, width=3, stride=1, pool=2, act='relu'):
        super().__init__()
        # Keras' Conv1D spells this argument ``strides``; an unknown
        # ``stride=`` keyword is rejected by the layer constructor.
        self.conv = K.layers.Conv1D(nFilters, width, strides=stride,
                                    activation=act, padding='same')
        # Either the pooling layer or None, so ``call`` can test it directly.
        self.pool = None if pool is None else K.layers.MaxPool1D(pool)

    def call(self, x):
        """Apply the convolution, then pooling if it is enabled."""
        x = self.conv(x)
        if self.pool is not None:
            x = self.pool(x)
        return x

class ResidUnit(K.Model):
    """Residual unit computing ``conv2(conv1(x)) + x``.

    ``conv1`` uses a ReLU and ``conv2`` is linear. Both use stride 1 and
    'same' padding, so the time length is unchanged. The skip connection adds
    the input back, so the input must already have ``nFilters`` channels for
    the addition to be shape-compatible.

    Args:
        nFilters: channel count of both convolutions (and of the input).
        width: kernel size of both convolutions.
    """

    def __init__(self, nFilters, width=3):
        super().__init__()
        shared = dict(strides=1, padding='same')
        self.conv1 = K.layers.Conv1D(nFilters, width, activation='relu', **shared)
        self.conv2 = K.layers.Conv1D(nFilters, width, activation='linear', **shared)

    def call(self, x):
        """Return the two-convolution branch plus the identity shortcut."""
        return self.conv2(self.conv1(x)) + x

class ResidStack(K.Model):
    """1x1 projection, two residual units, then optional max pooling.

    The 1x1 convolution (no activation given, so linear) maps the input to
    ``nFilters`` channels. This lets the following ``ResidUnit`` skip
    connections add tensors of matching width.

    Args:
        nFilters: channel count used throughout the stack.
        width: kernel size of the convolutions inside each residual unit.
        pool: max-pool size applied at the end; ``None`` disables pooling.
    """

    def __init__(self, nFilters, width=3, pool=2):
        super().__init__()
        self.conv = K.layers.Conv1D(nFilters, 1, strides=1, padding='same')
        self.resUnit1 = ResidUnit(nFilters, width=width)
        self.resUnit2 = ResidUnit(nFilters, width=width)
        # Either the pooling layer or None.
        self.pool = K.layers.MaxPool1D(pool) if pool is not None else None

    def call(self, x):
        """Run projection -> residual units -> (optional) pooling."""
        out = x
        for stage in (self.conv, self.resUnit1, self.resUnit2):
            out = stage(out)
        return out if self.pool is None else self.pool(out)

class Oshea_ResNet(K.Model):
    """Six residual stacks followed by a three-layer dense classifier head.

    Each ``ResidStack`` uses its default ``pool=2``. The stacked features are
    flattened and passed through two 128-unit ReLU dense layers and a final
    softmax layer over ``nClasses`` outputs.

    NOTE(review): because of ``Flatten`` + ``Dense``, the input time length
    is presumably fixed after the first call. Confirm against the data
    pipeline.

    Args:
        nClasses: number of output classes (softmax width).
        nFilters: channel count used by every residual stack.
    """

    def __init__(self, nClasses, nFilters=16):
        super().__init__()
        self.nClasses = nClasses
        # Six identical stacks, created and assigned in order resid1..resid6.
        (self.resid1, self.resid2, self.resid3,
         self.resid4, self.resid5, self.resid6) = [
            ResidStack(nFilters=nFilters) for _ in range(6)
        ]
        self.flatten = K.layers.Flatten()
        self.dense1 = K.layers.Dense(128, activation='relu')
        self.dense2 = K.layers.Dense(128, activation='relu')
        self.dense3 = K.layers.Dense(nClasses, activation='softmax')

    def call(self, x):
        """Apply every stack, then the flatten/dense head, in order."""
        pipeline = (
            self.resid1, self.resid2, self.resid3,
            self.resid4, self.resid5, self.resid6,
            self.flatten, self.dense1, self.dense2, self.dense3,
        )
        for layer in pipeline:
            x = layer(x)
        return x

##################################################################################
if __name__ == '__main__':
    model = Oshea_ResNet(9)
    # Subclassed Keras models create their weights lazily. Run one dummy batch
    # through the network to build it and check that the layer shapes line up.
    # NOTE(review): (1024, 2) is only an example input (time steps, channels);
    # adjust it to the real frame length / channel count of the dataset.
    dummy = tf.zeros((1, 1024, 2))
    probs = model(dummy)
    model.summary()
    print('output shape:', probs.shape)