import torch


class Feedback_CNN_pretrain(torch.nn.Module):
    """1-D residual CNN with a classification head and a projection head.

    The encoder is seven strided convolution stages, each followed by a
    length-preserving convolution + batch norm with a residual connection,
    max pooling and dropout. The flattened encoder output feeds two heads:

    * a classification head (``fc1`` followed by ``fc2`` or ``fc3``), and
    * a projection head (``linear1`` -> ``bn1`` -> ``projection_head``) whose
      output is returned alongside the class scores.

    ``fc1`` and ``linear1`` expect ``64 * 20`` flattened features, i.e. the
    encoder must end with 20 time steps; with the layer settings below this
    is what a 3000-sample input produces.

    Args:
        num_leads: Number of input channels of the signal.
        num_classes: Number of output class scores.
        configs: Stored on the instance as-is; not read by this class.
        info_for_test: Stored on the instance as-is; not read by this class.
    """

    def __init__(self, num_leads, num_classes, configs, info_for_test=None):
        super().__init__()
        self.num_leads = num_leads
        self.num_classes = num_classes
        self.configs = configs
        self.info_for_test = info_for_test

        self.relu = torch.nn.ReLU()

        # NOTE: attribute names and creation order are kept exactly as they
        # were so that state_dict keys and seeded initialisation are unchanged.
        self.conv1 = torch.nn.Conv1d(in_channels=num_leads, out_channels=64, kernel_size=15, stride=2, padding=3)
        self.conv_pad1 = torch.nn.Conv1d(in_channels=64, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.max_pool1 = torch.nn.MaxPool1d(kernel_size=5, stride=1, padding=2)
        self.batch_norm1 = torch.nn.BatchNorm1d(64)

        self.conv2 = torch.nn.Conv1d(in_channels=64, out_channels=128, kernel_size=13, stride=2, padding=1)
        self.conv_pad2 = torch.nn.Conv1d(in_channels=128, out_channels=128, kernel_size=5, stride=1, padding=2)
        self.max_pool2 = torch.nn.MaxPool1d(kernel_size=5, stride=1, padding=2)
        self.batch_norm2 = torch.nn.BatchNorm1d(128)

        self.conv3 = torch.nn.Conv1d(in_channels=128, out_channels=256, kernel_size=11, stride=2)
        self.conv_pad3 = torch.nn.Conv1d(in_channels=256, out_channels=256, kernel_size=5, stride=1, padding=2)
        self.max_pool3 = torch.nn.MaxPool1d(kernel_size=5, stride=1, padding=2)
        self.batch_norm3 = torch.nn.BatchNorm1d(256)

        self.conv4 = torch.nn.Conv1d(in_channels=256, out_channels=512, kernel_size=9, stride=2)
        self.conv_pad4 = torch.nn.Conv1d(in_channels=512, out_channels=512, kernel_size=5, stride=1, padding=2)
        self.max_pool4 = torch.nn.MaxPool1d(kernel_size=5, stride=1, padding=2)
        self.batch_norm4 = torch.nn.BatchNorm1d(512)

        self.conv5 = torch.nn.Conv1d(in_channels=512, out_channels=256, kernel_size=7, stride=2)
        self.conv_pad5 = torch.nn.Conv1d(in_channels=256, out_channels=256, kernel_size=5, stride=1, padding=2)
        self.max_pool5 = torch.nn.MaxPool1d(kernel_size=5, stride=1, padding=2)
        self.batch_norm5 = torch.nn.BatchNorm1d(256)

        self.conv6 = torch.nn.Conv1d(in_channels=256, out_channels=128, kernel_size=5, stride=2)
        self.conv_pad6 = torch.nn.Conv1d(in_channels=128, out_channels=128, kernel_size=5, stride=1, padding=2)
        self.max_pool6 = torch.nn.MaxPool1d(kernel_size=5, stride=1, padding=2)
        self.batch_norm6 = torch.nn.BatchNorm1d(128)

        self.conv7 = torch.nn.Conv1d(in_channels=128, out_channels=64, kernel_size=3, stride=2)
        self.conv_pad7 = torch.nn.Conv1d(in_channels=64, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.max_pool7 = torch.nn.MaxPool1d(kernel_size=5, stride=1, padding=2)
        self.batch_norm7 = torch.nn.BatchNorm1d(64)

        self.drop_out = torch.nn.Dropout(p=0.2)

        # Classification head on the flattened encoder output.
        self.fc1 = torch.nn.Linear(64 * 20, 273)

        # Small MLP applied to the 7 encoded metadata features.
        self.ln1 = torch.nn.Linear(7, 32)
        self.ln2 = torch.nn.Linear(32, 7)

        # fc2 is used when metadata features are supplied, fc3 otherwise.
        self.fc2 = torch.nn.Linear(273 + 7, self.num_classes)

        self.fc3 = torch.nn.Linear(273, self.num_classes)

        self.last_dim = 512
        self.proj_dim = 256  # for projection head

        self.linear1 = torch.nn.Linear(64 * 20, self.last_dim)
        self.bn1 = torch.nn.BatchNorm1d(self.last_dim)
        # linear2 / bn2 / linear3 / pretrain_linear3 are not used by forward();
        # they are kept so existing checkpoints still load
        # (presumably used by other training code -- TODO confirm).
        self.linear2 = torch.nn.Linear(self.last_dim, 64)
        self.bn2 = torch.nn.BatchNorm1d(64)
        self.linear3 = torch.nn.Linear(64 + 7, self.num_classes)
        self.pretrain_linear3 = torch.nn.Linear(64, self.num_classes)

        self.projection_head = torch.nn.Sequential(
            torch.nn.Linear(self.last_dim, self.proj_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(self.proj_dim, self.proj_dim)
        )

    @staticmethod
    def _encode_features(features):
        """Turn the raw metadata columns into a ``(batch, 7)`` tensor.

        Columns read from ``features``: 0 = age, 1 = gender, 2 = mean RR,
        3 = std RR. NaN in age/gender means "missing" and is reported through
        a 0/1 mask. Output column order:
        ``mean_RR, std_RR, age / 100, age_mask, man, woman, gender_mask``.

        Everything is computed out-of-place: ``features[:, i]`` is a view, so
        in-place edits would silently rewrite the caller's tensor (and rescale
        age again on every forward pass with the same tensor).
        """
        age = features[:, 0]
        gender = features[:, 1]
        mean_rr = features[:, 2]
        std_rr = features[:, 3]

        age_missing = torch.isnan(age)
        gender_missing = torch.isnan(gender)

        age_mask = (~age_missing).to(age.dtype)
        gender_mask = (~gender_missing).to(gender.dtype)

        # Missing age becomes 0; the mask tells the network it was missing.
        scaled_age = torch.where(age_missing, torch.zeros_like(age), age) / 100

        # "man" keeps the gender value (missing -> 0); "woman" is 1 exactly
        # where gender == 0, so a missing gender yields 0 for both flags.
        man = torch.where(gender_missing, torch.zeros_like(gender), gender)
        woman = (gender == 0).to(gender.dtype)

        return torch.stack(
            (mean_rr, std_rr, scaled_age, age_mask, man, woman, gender_mask),
            dim=1,
        )  # shape=(batch, 7)

    def _residual_stage(self, x, conv, conv_pad, batch_norm, max_pool, pre_pool_dropout):
        """Run one encoder stage: strided conv, residual refinement, pooling.

        ``pre_pool_dropout`` adds a dropout before the max pool in addition to
        the one after it; this mirrors the original layout of stages 3-7.
        """
        x = conv(x)
        shortcut = x
        x = batch_norm(conv_pad(self.relu(x)))
        x = self.relu(x + shortcut)
        if pre_pool_dropout:
            x = self.drop_out(x)
        x = max_pool(x)
        return self.drop_out(x)

    def forward(self, input_data, features=None):
        """Compute class scores and projected representations.

        Args:
            input_data: Signal tensor of shape ``(batch, num_leads, length)``;
                ``length`` must reduce to 20 time steps after the encoder
                (3000 samples with the configured layers).
            features: Optional tensor whose columns 0-3 are age, gender,
                mean RR and std RR (NaN marks a missing age/gender). It is
                not modified. When ``None`` the metadata branch is skipped
                and ``fc3`` is used instead of ``fc2``.

        Returns:
            Tuple ``(scores, projected_reps)`` with shapes
            ``(batch, num_classes)`` and ``(batch, proj_dim)``.
        """
        extracted_features = None
        if features is not None:
            extracted_features = self._encode_features(features)

        x = input_data
        # Sizes below are for a (batch, 6, 3000) input.
        x = self._residual_stage(x, self.conv1, self.conv_pad1, self.batch_norm1, self.max_pool1, False)
        # (batch, 64, 1496)
        x = self._residual_stage(x, self.conv2, self.conv_pad2, self.batch_norm2, self.max_pool2, False)
        # (batch, 128, 743)
        x = self._residual_stage(x, self.conv3, self.conv_pad3, self.batch_norm3, self.max_pool3, True)
        # (batch, 256, 367)
        x = self._residual_stage(x, self.conv4, self.conv_pad4, self.batch_norm4, self.max_pool4, True)
        # (batch, 512, 180)
        x = self._residual_stage(x, self.conv5, self.conv_pad5, self.batch_norm5, self.max_pool5, True)
        # (batch, 256, 87)
        x = self._residual_stage(x, self.conv6, self.conv_pad6, self.batch_norm6, self.max_pool6, True)
        # (batch, 128, 42)
        x = self._residual_stage(x, self.conv7, self.conv_pad7, self.batch_norm7, self.max_pool7, True)
        # (batch, 64, 20)

        x = torch.flatten(x, 1)  # flatten all dimensions except batch

        # Projection branch: (batch, 64*20) -> (batch, last_dim) -> (batch, proj_dim)
        x_proj = self.relu(self.bn1(self.linear1(x)))
        projected_reps = self.projection_head(x_proj)

        # Classification branch: (batch, 64*20) -> (batch, 273)
        x = self.relu(self.fc1(x))

        if extracted_features is not None:
            extracted_features = self.ln2(self.relu(self.ln1(extracted_features)))
            x = self.fc2(torch.cat((x, extracted_features), dim=1))
        else:
            x = self.fc3(x)

        return x, projected_reps
