import torch
import torch.nn as nn


class Classifier(nn.Module):
    """Two-layer fully connected head: Linear -> ELU -> Dropout -> Linear.

    Args:
        in_features: size of the last dimension of the input tensor.
        hidden_features: width of the hidden layer.
        classes: number of values produced per sample.
        drop: dropout probability applied after the activation.
    """

    def __init__(self, in_features, hidden_features, classes, drop=0.5):
        super().__init__()

        # hidden projection; note that ``relu1`` holds an ELU, not a ReLU
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.relu1 = nn.ELU()
        self.drop = nn.Dropout(drop)

        # output projection; no softmax/sigmoid is applied afterwards
        self.fc2 = nn.Linear(hidden_features, classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw (un-normalised) outputs of the head for ``x``."""
        hidden = self.drop(self.relu1(self.fc1(x)))
        return self.fc2(hidden)


class BasicBlock(nn.Module):
    """Residual block: two stride-1 Conv1d + ELU stages plus an identity skip.

    The channel count is unchanged. The skip addition requires the
    convolutions to preserve the length of the input, which holds for the
    defaults (``kernel_size=3``, ``padding=1``) and in general when
    ``2 * padding == kernel_size - 1``.

    Args:
        features: number of input and output channels.
        kernel_size: kernel size of both convolutions.
        padding: zero padding applied on both sides by both convolutions.
    """

    def __init__(self, features, kernel_size=3, padding=1):
        super().__init__()

        # both stages use the same channel-preserving convolution settings
        conv_kwargs = dict(in_channels=features, out_channels=features,
                           kernel_size=kernel_size, stride=1, padding=padding)

        self.conv1 = nn.Conv1d(**conv_kwargs)
        self.act1 = nn.ELU()

        self.conv2 = nn.Conv1d(**conv_kwargs)
        self.act2 = nn.ELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the output of the two convolution stages."""
        residual = self.act1(self.conv1(x))
        residual = self.act2(self.conv2(residual))

        # skip connection: the untouched input is added back
        return x + residual


class Bottleneck(nn.Module):
    """Down-sampling residual block: doubles the channels, stride 3 along length.

    The main path is a strided Conv1d + ELU followed by a stride-1
    Conv1d + ELU. The skip path is a kernel-size-1, stride-3 Conv1d + ELU
    without padding. Both paths must yield the same length for the final
    addition; this holds for the defaults (``down_kernel_size=5``,
    ``down_pad=2``) and in general when
    ``2 * down_pad == down_kernel_size - 1`` and
    ``2 * padding == kernel_size - 1``.

    Args:
        features: number of input channels; the output has ``2 * features``.
        kernel_size: kernel size of the second (stride-1) convolution.
        padding: padding of the second convolution.
        down_kernel_size: kernel size of the first (stride-3) convolution.
        down_pad: padding of the first convolution.
    """

    def __init__(self, features, kernel_size=3, padding=1, down_kernel_size=5, down_pad=2):
        super().__init__()
        doubled = features * 2

        # main path, stage 1: widen the channels and stride over the length
        self.conv1 = nn.Conv1d(features, doubled, down_kernel_size, stride=3, padding=down_pad)
        self.act1 = nn.ELU()

        # main path, stage 2: stride-1 refinement at the new width
        self.conv2 = nn.Conv1d(doubled, doubled, kernel_size, stride=1, padding=padding)
        self.act2 = nn.ELU()

        # skip path: pointwise projection with the same stride and width
        self.ident_conv = nn.Conv1d(features, doubled, 1, stride=3)
        self.ident_act = nn.ELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the sum of the projected skip path and the main path."""
        shortcut = self.ident_act(self.ident_conv(x))

        main = self.act1(self.conv1(x))
        main = self.act2(self.conv2(main))

        return shortcut + main


class StartBlock(nn.Module):
    """Stem of the network: an input convolution followed by a widening residual block.

    ``conv0`` maps ``in_channels`` to ``features``. The result then goes
    through two stride-1 Conv1d + ELU stages that double the channel count,
    and is added to a kernel-size-1 projection (also followed by ELU) of the
    ``conv0`` output. The addition requires the convolutions to preserve the
    length, which holds for the defaults and in general when
    ``2 * padding == kernel_size - 1``.

    Args:
        in_channels: number of channels of the input signal.
        features: number of channels after ``conv0``; the block outputs
            ``2 * features`` channels.
        kernel_size: kernel size of ``conv0``, ``conv1`` and ``conv2``.
        padding: padding of ``conv0``, ``conv1`` and ``conv2``.
    """

    def __init__(self, in_channels, features, kernel_size=3, padding=1):
        super().__init__()
        doubled = features * 2

        # settings shared by the three non-pointwise convolutions
        shared = dict(kernel_size=kernel_size, stride=1, padding=padding)

        # input convolution
        self.conv0 = nn.Conv1d(in_channels, features, **shared)
        self.act0 = nn.ELU()

        # main path of the residual part
        self.conv1 = nn.Conv1d(features, doubled, **shared)
        self.act1 = nn.ELU()
        self.conv2 = nn.Conv1d(doubled, doubled, **shared)
        self.act2 = nn.ELU()

        # skip path: pointwise projection to the doubled width
        self.ident_conv = nn.Conv1d(features, doubled, kernel_size=1, stride=1)
        self.ident_act = nn.ELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the input convolution, then the widening residual block."""
        stem = self.act0(self.conv0(x))

        # projected skip path
        shortcut = self.ident_act(self.ident_conv(stem))

        # main path
        main = self.act1(self.conv1(stem))
        main = self.act2(self.conv2(main))

        return shortcut + main


class GenericResNet(nn.Module):
    """1D residual network that classifies a signal together with a side-input vector.

    Pipeline: ``StartBlock`` -> ``depth`` repetitions of (``multiplier``
    ``BasicBlock`` s followed by one ``Bottleneck``) -> ``multiplier`` final
    ``BasicBlock`` s -> global max pooling over the length axis. The pooled
    features are concatenated with the output of a small ``Classifier``
    applied to the side input (the patient characteristics) and fed to the
    final ``Classifier``.

    The channel count is ``2 * features`` after the start block and doubles
    with every bottleneck, ending at ``features * 2 ** (depth + 1)``.

    Args:
        in_channels: number of channels of the input signal.
        features: base channel count (see above for the resulting widths).
        multiplier: number of ``BasicBlock`` s per stage.
        classes: number of outputs of the final classifier.
        depth: number of (basic blocks, bottleneck) stages.
        kernel_size: kernel size of the stride-1 convolutions.
        padding: padding of the stride-1 convolutions.
        down_kernel_size: kernel size of the strided bottleneck convolution.
        down_padding: padding of the strided bottleneck convolution.
        spl_features: size of the side-input vector per sample.
        spl_hidden: hidden width of the side-input network.
        spl_out: output width of the side-input network; this many values
            are appended to the pooled signal features.
        classifier_hidden: hidden width of the final classifier.
    """

    def __init__(self, in_channels=12, features=16, multiplier=5, classes=9, depth=3,
                 kernel_size=3, padding=1, down_kernel_size=5, down_padding=2, *,
                 spl_features=2, spl_hidden=4, spl_out=2, classifier_hidden=100):
        super().__init__()

        # save the values
        self.in_channels = in_channels
        self.features = features
        self.multiplier = multiplier
        self.classes = classes
        self.depth = depth
        self.kernel_size = kernel_size
        self.padding = padding
        self.down_kernel_size = down_kernel_size
        self.down_padding = down_padding
        self.spl_features = spl_features
        self.spl_hidden = spl_hidden
        self.spl_out = spl_out
        self.classifier_hidden = classifier_hidden

        # initialize the first few convolutions (outputs 2 * features channels)
        self.first_convs = StartBlock(in_channels, features, kernel_size, padding)

        # build the repetitive structure in a local list, so the module
        # attribute is only ever bound to the finished nn.Sequential
        stages = []
        feature_multiplier = 2
        for _stage in range(depth):
            width = features * feature_multiplier

            # run of basic blocks at the current width
            stages.append(nn.Sequential(*[
                BasicBlock(width, kernel_size=kernel_size, padding=padding)
                for _ in range(multiplier)
            ]))

            # bottleneck: doubles the width and strides over the length
            stages.append(Bottleneck(width, kernel_size=kernel_size, padding=padding,
                                     down_kernel_size=down_kernel_size, down_pad=down_padding))

            # the next stage works on twice as many channels
            feature_multiplier *= 2
        self.block_repetition = nn.Sequential(*stages)

        # initialize the final run of basic blocks
        final_width = features * feature_multiplier
        self.final_basic = nn.Sequential(*[
            BasicBlock(final_width, kernel_size=kernel_size, padding=padding)
            for _ in range(multiplier)
        ])

        # global *max* pooling over the length axis (one value per channel)
        self.pool = nn.AdaptiveMaxPool1d(1)

        # small network applied to the side input (patient characteristics)
        self.characteristics = Classifier(spl_features, spl_hidden, spl_out)

        # final classifier on the pooled features plus the side-input features
        self.classifier = Classifier(final_width + spl_out, classifier_hidden, classes)

    def forward(self, x: torch.Tensor, spl: torch.Tensor) -> torch.Tensor:
        """Classify the signal ``x`` together with the side input ``spl``.

        Args:
            x: signal of shape ``(batch, in_channels, length)`` as required
                by the ``Conv1d`` layers.
            spl: side input of shape ``(batch, spl_features)``; it is
                concatenated with the pooled features along dimension 1.

        Returns:
            Tensor of shape ``(batch, classes)`` holding the raw classifier
            outputs (no softmax or sigmoid is applied).
        """
        # convolutional trunk
        x = self.first_convs(x)
        x = self.block_repetition(x)
        x = self.final_basic(x)

        # pool over the length axis and drop it: (batch, channels, 1) -> (batch, channels)
        x = torch.flatten(self.pool(x), start_dim=1)

        # embed the patient characteristics
        spl = self.characteristics(spl)

        # concatenate features and patient characteristics, then classify
        return self.classifier(torch.cat([x, spl], dim=1))
