import torch

class IBN_Block(torch.nn.Module):
    """Residual 1-D conv block that mixes batch and instance normalization.

    Layout: strided ``conv`` -> ReLU -> ``conv_pad`` -> channel split
    (BatchNorm on the first part, InstanceNorm on the rest) -> add the
    ``conv`` output as a residual -> ReLU -> dropout -> max-pool -> dropout.
    ``conv_pad`` (kernel 5, padding 2, stride 1) and ``max_pool`` (kernel 5,
    padding 2, stride 1) keep the temporal length, so only ``conv`` changes it.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels. The first ``out_planes // 2``
            channels are batch-normalized and the remaining ones are
            instance-normalized, so odd values are supported.
        version: Selects the kernel size of ``conv`` as ``17 - 2 * version``.
        stride: Stride of ``conv``.
        padding: Padding of ``conv``.
        no_dropout1: If True, the first dropout uses p=0 (i.e. is disabled).
    """

    def __init__(self, in_planes, out_planes, version, stride, padding, no_dropout1 = False):
        super(IBN_Block, self).__init__()

        self.conv = torch.nn.Conv1d(in_channels=in_planes, out_channels=out_planes, kernel_size=17-2*version, stride=stride, padding=padding)
        self.relu = torch.nn.ReLU()
        self.conv_pad = torch.nn.Conv1d(in_channels=out_planes, out_channels=out_planes, kernel_size=5, stride=1, padding=2)
        # Integer split whose two parts always add up to out_planes. Sizing both
        # norms with int(out_planes / 2) lost a channel and broke the split for
        # odd out_planes.
        bn_planes = out_planes // 2
        self.batch_norm = torch.nn.BatchNorm1d(bn_planes)
        self.instance_norm = torch.nn.InstanceNorm1d(out_planes - bn_planes)
        self.drop_out1 = torch.nn.Dropout(p=0 if no_dropout1 else 0.2)
        self.max_pool = torch.nn.MaxPool1d(kernel_size=5, stride=1, padding=2)
        self.drop_out2 = torch.nn.Dropout(p=0.2)

    def forward(self, x):
        x = self.conv(x)
        x_saved = x  # residual branch added back after normalization
        x = self.relu(x)
        x = self.conv_pad(x)
        # The first `bn_planes` channels go to BatchNorm and the rest to
        # InstanceNorm. Explicit sizes guarantee exactly two chunks.
        bn_planes = self.batch_norm.num_features
        x1, x2 = torch.split(x, [bn_planes, x.size(1) - bn_planes], dim=1)
        x1 = self.batch_norm(x1)
        x2 = self.instance_norm(x2)
        x = torch.cat((x1, x2), dim=1)
        x = x + x_saved
        x = self.relu(x)
        x = self.drop_out1(x)
        x = self.max_pool(x)
        x = self.drop_out2(x)

        return x

class Block(torch.nn.Module):
    """Residual 1-D convolution block with batch normalization.

    Computes ``conv -> ReLU -> conv_pad -> BatchNorm``, adds the ``conv``
    output back as a skip connection, then applies ReLU -> dropout ->
    max-pool -> dropout. Only ``conv`` (kernel ``17 - 2 * version``) can change
    the sequence length. The 5-wide ``conv_pad`` and ``max_pool`` (padding 2,
    stride 1) both keep it.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        version: Selects the kernel size of ``conv`` as ``17 - 2 * version``.
        stride: Stride of ``conv``.
        padding: Padding of ``conv``.
        no_dropout1: If True, the first dropout uses p=0 (i.e. is disabled).
    """

    def __init__(self, in_planes, out_planes, version, stride, padding, no_dropout1 = False):
        super().__init__()

        kernel = 17 - 2 * version
        first_drop_p = 0 if no_dropout1 else 0.2

        # Registration order matters for state_dict / parameters() ordering.
        self.conv = torch.nn.Conv1d(in_planes, out_planes, kernel, stride=stride, padding=padding)
        self.relu = torch.nn.ReLU()
        self.conv_pad = torch.nn.Conv1d(out_planes, out_planes, 5, stride=1, padding=2)
        self.batch_norm = torch.nn.BatchNorm1d(out_planes)
        self.drop_out1 = torch.nn.Dropout(p=first_drop_p)
        self.max_pool = torch.nn.MaxPool1d(kernel_size=5, stride=1, padding=2)
        self.drop_out2 = torch.nn.Dropout(p=0.2)

    def forward(self, x):
        shortcut = self.conv(x)
        out = self.batch_norm(self.conv_pad(self.relu(shortcut)))
        out = self.relu(out + shortcut)
        out = self.drop_out1(out)
        return self.drop_out2(self.max_pool(out))

class Base_CNN(torch.nn.Module):
    """1-D CNN mapping a (batch, num_leads, length) signal to class logits.

    Seven stride-2 residual blocks (``layer1``..``layer7``) produce a feature map
    that is flattened and fed to two heads:

    * a projection head (``pretrain_linear`` -> ``pretrain_bn`` -> ReLU ->
      ``projection_head``) whose output is returned as ``projected_reps``;
    * a classifier (``fc1`` -> ReLU -> ``fc2`` or ``fc3``). When an extra
      feature vector is passed to ``forward``, it is transformed by
      ``ln1``/``ln2`` and concatenated before ``fc2``.

    Args:
        num_leads: Number of input channels.
        num_classes: Number of output logits.
        configs: Stored on the instance; not used within this class.
        info_for_test: Stored on the instance; not used within this class.
        input_length: Optional input signal length. When given, the flattened
            feature size is derived from it. When ``None``, 28 time steps after
            ``layer7`` are assumed, as before (inputs of 3941..4068 samples).
        num_features: Size of the optional feature vector given to ``forward``.
    """

    # Padding of the stride-2 convolution that opens each of layer1..layer7.
    # Layer i is built with version=i, i.e. kernel_size = 17 - 2*i. The other
    # ops in Block / IBN_Block (conv_pad and max_pool: kernel 5, padding 2,
    # stride 1) leave the temporal length unchanged.
    _LAYER_PADDINGS = (3, 1, 0, 0, 0, 0, 0)
    _LAYER_STRIDE = 2
    # Channels produced by layer7.
    _FINAL_CHANNELS = 64
    # Time steps after layer7 assumed when `input_length` is not given. Inputs
    # of 3941..4068 samples (e.g. 4000) produce this; a 3600-sample input
    # yields 25 steps and would not fit fc1 / pretrain_linear.
    _DEFAULT_FINAL_LENGTH = 28

    def __init__(self, num_leads, num_classes, configs, info_for_test=None,
                 input_length=None, num_features=20):
        super(Base_CNN, self).__init__()
        self.num_leads = num_leads
        self.num_classes = num_classes
        self.configs = configs
        self.info_for_test = info_for_test
        self.last_dim = 512
        self.proj_dim = 256  # for projection head
        # Size of the optional hand-crafted feature vector given to forward().
        self.num_features = num_features
        # Time steps left after layer7. This fixes the input size of both heads.
        if input_length is None:
            self.final_length = self._DEFAULT_FINAL_LENGTH
        else:
            self.final_length = self._final_length(input_length)
        flat_dim = self._FINAL_CHANNELS * self.final_length

        self.relu = torch.nn.ReLU()

        pads = self._LAYER_PADDINGS
        stride = self._LAYER_STRIDE
        self.layer1 = self.make_layer(in_planes=num_leads, out_planes=64, version=1, stride=stride, padding=pads[0], no_dropout1=True, ibn=False)
        self.layer2 = self.make_layer(in_planes=64, out_planes=128, version=2, stride=stride, padding=pads[1], no_dropout1=True, ibn=False)
        self.layer3 = self.make_layer(in_planes=128, out_planes=256, version=3, stride=stride, padding=pads[2], no_dropout1=False, ibn=False)
        self.layer4 = self.make_layer(in_planes=256, out_planes=512, version=4, stride=stride, padding=pads[3], no_dropout1=False, ibn=False)
        self.layer5 = self.make_layer(in_planes=512, out_planes=256, version=5, stride=stride, padding=pads[4], no_dropout1=False, ibn=False)
        self.layer6 = self.make_layer(in_planes=256, out_planes=128, version=6, stride=stride, padding=pads[5], no_dropout1=False, ibn=False)
        self.layer7 = self.make_layer(in_planes=128, out_planes=self._FINAL_CHANNELS, version=7, stride=stride, padding=pads[6], no_dropout1=False, ibn=False)

        self.pretrain_linear = torch.nn.Linear(flat_dim, self.last_dim)
        self.pretrain_bn = torch.nn.BatchNorm1d(self.last_dim)
        self.projection_head = torch.nn.Sequential(
            torch.nn.Linear(self.last_dim, self.proj_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(self.proj_dim, self.proj_dim),
        )

        self.fc1 = torch.nn.Linear(flat_dim, 273)

        self.ln1 = torch.nn.Linear(self.num_features, 32)
        self.ln2 = torch.nn.Linear(32, self.num_features)

        self.fc2 = torch.nn.Linear(273 + self.num_features, self.num_classes)
        self.fc3 = torch.nn.Linear(273, self.num_classes)  # used when features is None

    @classmethod
    def _final_length(cls, input_length):
        """Return the number of time steps left after layer1..layer7.

        Applies the Conv1d output-length formula
        ``floor((L + 2*padding - kernel_size) / stride) + 1`` to each layer's
        strided convolution. The remaining ops in each block keep the length.

        Raises:
            ValueError: if ``input_length`` is too short to survive every layer.
        """
        length = input_length
        for version, padding in enumerate(cls._LAYER_PADDINGS, start=1):
            kernel_size = 17 - 2 * version
            length = (length + 2 * padding - kernel_size) // cls._LAYER_STRIDE + 1
            if length < 1:
                raise ValueError(
                    f"input_length={input_length} is too short for layer{version}"
                )
        return length

    def make_layer(self, in_planes, out_planes, version, stride, padding, no_dropout1, ibn=False):
        """Wrap a single Block (or IBN_Block when ``ibn``) in a Sequential."""
        layers = []
        if ibn:
            layers.append(IBN_Block(in_planes, out_planes, version, stride, padding, no_dropout1))
        else:
            layers.append(Block(in_planes, out_planes, version, stride, padding, no_dropout1))
        return torch.nn.Sequential(*layers)

    def forward(self, input_data, features=None):
        """Run the network.

        Args:
            input_data: Tensor of shape (batch, num_leads, length). The length
                must yield ``self.final_length`` steps after layer7.
            features: Optional tensor of shape (batch, num_features).

        Returns:
            Tuple ``(logits, projected_reps)`` of shapes (batch, num_classes)
            and (batch, proj_dim).
        """
        x = input_data
        for layer in (self.layer1, self.layer2, self.layer3, self.layer4,
                      self.layer5, self.layer6, self.layer7):
            x = layer(x)
        # x: (batch, 64, self.final_length)

        x = torch.flatten(x, 1)  # flatten all dimensions except batch

        x_proj = self.relu(self.pretrain_bn(self.pretrain_linear(x)))  # (batch, last_dim)
        projected_reps = self.projection_head(x_proj)  # (batch, proj_dim)

        x = self.relu(self.fc1(x))  # (batch, 273)

        # Identity check. `!=` on a tensor dispatches to Tensor.__ne__ rather
        # than testing for None.
        if features is not None:
            features = self.ln1(features)
            features = self.relu(features)
            features = self.ln2(features)  # (batch, num_features)
            x = self.fc2(torch.cat((x, features), dim=1))
        else:
            x = self.fc3(x)

        return x, projected_reps
