import torch
from collections import OrderedDict
import torchvision.models as models


class BaseResnet34_SSL(torch.nn.Module):
    """Shared per-lead ResNet encoder with an SSL projection head and a classifier.

    Every lead of the input is encoded independently by the same backbone.
    The per-lead embeddings are:

    * passed through ``projection_head`` to give ``projected_reps`` (for SSL);
    * concatenated across leads, together with ``feature_dim`` extra
      per-sample features, and fed to ``cls`` to give ``cls_reps``.

    NOTE(review): despite the class name, the backbone is built with
    ``models.resnet18``. It is left unchanged so existing checkpoints keep
    loading. Switch the backbone deliberately if resnet34 was intended.
    """

    def __init__(self, num_leads, num_classes, configs, info_for_test=None):
        super(BaseResnet34_SSL, self).__init__()
        self.num_leads = num_leads
        self.num_classes = num_classes
        self.configs = configs
        self.info_for_test = info_for_test
        self.last_dim = 512  # width of one lead's embedding once ``fc`` is removed
        self.proj_dim = 256  # for projection head
        self.feature_dim = 20  # number of extra features concatenated before ``cls``
        self.model = models.resnet18(num_classes=2)
        # Replace the stem with (1, k) convolutions: each lead is fed as a
        # single-row image, so these convolve only along the last axis.
        self.model.conv1 = torch.nn.Sequential(OrderedDict([
            ('pre-conv1', torch.nn.Conv2d(1,32, kernel_size=(1, 15), stride=(1,3), padding=(0,3), bias=False)),
            ('pre-bn1', torch.nn.BatchNorm2d(32)),
            ('pre-relu1', torch.nn.ReLU(inplace=True)),
            ('pre-conv2', torch.nn.Conv2d(32,64, kernel_size=(1, 11), stride=(1,3), padding=(0,1), bias=False)),
            ('pre-bn2', torch.nn.BatchNorm2d(64)),
            ('pre-relu2', torch.nn.ReLU(inplace=True)),
            ('pre-conv3', torch.nn.Conv2d(64,64, kernel_size=(1, 7), stride=(1,3), padding=(0,1), bias=False)),
        ]))
        # Drop the torchvision classification layer so ``self.model`` returns embeddings.
        self.model.fc = torch.nn.Identity()

        self.projection_head = torch.nn.Sequential(
                                    torch.nn.Linear(self.last_dim, self.last_dim), torch.nn.BatchNorm1d(self.last_dim), torch.nn.ReLU(),
                                    torch.nn.Linear(self.last_dim, self.proj_dim), torch.nn.BatchNorm1d(self.proj_dim)
                                )
        # NOTE(review): ``lead_classifier`` is not used by ``forward``. It is
        # kept so parameter order and state_dict keys stay stable.
        self.lead_classifier = torch.nn.Linear(in_features=self.proj_dim, out_features=self.num_leads, bias=True)
        self.cls = torch.nn.Linear(self.last_dim * self.num_leads + self.feature_dim, self.num_classes)

    def forward(self, input_data, features=None):
        """Encode each lead, then classify and project.

        Args:
            input_data: tensor reshaped below to
                ``(batch * num_leads, 1, 1, input_data.shape[2])``. It is
                therefore expected to be ``(batch, num_leads, length)``.
            features: tensor of extra per-sample features, concatenated along
                dim 1 with the flattened lead embeddings. ``cls`` expects
                ``feature_dim`` columns. Required despite the ``None`` default.

        Returns:
            ``(cls_reps, projected_reps)``:

            * ``cls_reps`` is used for classification and has shape
              ``(batch, num_classes)``.
            * ``projected_reps`` is used for SSL and holds one row per
              (sample, lead) pair.

        Raises:
            TypeError: if ``features`` is None. The classifier input width
                includes ``feature_dim``, so the call cannot proceed without it.
        """
        if features is None:
            raise TypeError(
                "BaseResnet34_SSL.forward requires `features` "
                f"(expected {self.feature_dim} columns per sample); got None"
            )
        batch_size = input_data.shape[0]
        # Fold leads into the batch dimension so the backbone encodes every lead
        # independently as a 1-channel, 1-row image.
        per_lead = input_data.reshape(batch_size * self.num_leads, 1, 1, input_data.shape[2])
        outputs = self.model(per_lead)
        projected_reps = self.projection_head(outputs)

        # Regroup per-lead embeddings back into one flat vector per sample.
        reps = outputs.reshape(batch_size, -1)
        cls_reps = self.cls(torch.cat((reps, features), dim=1))

        return cls_reps, projected_reps



class BaseResnet18(torch.nn.Module):
    """ResNet encoder with an SSL projection head and a linear classifier.

    The whole input is encoded as one single-channel 2D image. The resulting
    embedding feeds both ``projection_head`` and ``cls``.

    NOTE(review): despite the class name, the backbone is
    ``models.resnet34``. It is left as-is to preserve behavior and checkpoint
    compatibility. Confirm which depth was intended.
    """

    def __init__(self, num_leads, num_classes, configs, info_for_test=None):
        super().__init__()
        self.num_leads, self.num_classes = num_leads, num_classes
        self.configs, self.info_for_test = configs, info_for_test
        self.last_dim = 512   # embedding width once the backbone's ``fc`` is removed
        self.proj_dim = 256   # projection-head width

        backbone = models.resnet34(num_classes=2)
        # Stem of (1, k) convolutions: they only slide along the last axis.
        stem_layers = [
            ('pre-conv1', torch.nn.Conv2d(1, 32, kernel_size=(1, 15), stride=(1, 3),
                                          padding=(0, 3), bias=False)),
            ('pre-bn1', torch.nn.BatchNorm2d(32)),
            ('pre-relu1', torch.nn.ReLU(inplace=True)),
            ('pre-conv2', torch.nn.Conv2d(32, 64, kernel_size=(1, 11), stride=(1, 3),
                                          padding=(0, 1), bias=False)),
            ('pre-bn2', torch.nn.BatchNorm2d(64)),
            ('pre-relu2', torch.nn.ReLU(inplace=True)),
            ('pre-conv3', torch.nn.Conv2d(64, 64, kernel_size=(1, 7), stride=(1, 3),
                                          padding=(0, 1), bias=False)),
        ]
        backbone.conv1 = torch.nn.Sequential(OrderedDict(stem_layers))
        backbone.fc = torch.nn.Identity()  # expose embeddings instead of logits
        self.model = backbone

        head_layers = [
            torch.nn.Linear(self.last_dim, self.proj_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(self.proj_dim, self.proj_dim),
        ]
        self.projection_head = torch.nn.Sequential(*head_layers)
        self.cls = torch.nn.Linear(self.last_dim, self.num_classes)

    def forward(self, input_data, features=None):
        """Return ``(cls_reps, projected_reps)``.

        ``cls_reps`` is used for classification and ``projected_reps`` for SSL.
        ``input_data`` gets a channel axis inserted at dim 1, so it is assumed
        to be ``(batch, leads, length)`` (TODO confirm). ``features`` is
        ignored, presumably so the call signature matches the SSL variant.
        """
        embedding = self.model(input_data.unsqueeze(1))
        projected = self.projection_head(embedding)
        return self.cls(embedding), projected

class BaseResnet34(torch.nn.Module):
    """ResNet-34 encoder with an SSL projection head and a linear classifier.

    The input is encoded as one single-channel 2D image. The resulting
    embedding feeds both ``projection_head`` (SSL) and ``cls`` (classification).
    """

    def __init__(self, num_leads, num_classes, configs, info_for_test=None):
        super().__init__()
        self.num_leads, self.num_classes = num_leads, num_classes
        self.configs, self.info_for_test = configs, info_for_test
        self.last_dim = 512   # embedding width once the backbone's ``fc`` is removed
        self.proj_dim = 256   # projection-head width

        backbone = models.resnet34(num_classes=2)
        # Stem of (1, k) convolutions: they only slide along the last axis.
        stem_layers = [
            ('pre-conv1', torch.nn.Conv2d(1, 32, kernel_size=(1, 15), stride=(1, 3),
                                          padding=(0, 3), bias=False)),
            ('pre-bn1', torch.nn.BatchNorm2d(32)),
            ('pre-relu1', torch.nn.ReLU(inplace=True)),
            ('pre-conv2', torch.nn.Conv2d(32, 64, kernel_size=(1, 11), stride=(1, 3),
                                          padding=(0, 1), bias=False)),
            ('pre-bn2', torch.nn.BatchNorm2d(64)),
            ('pre-relu2', torch.nn.ReLU(inplace=True)),
            ('pre-conv3', torch.nn.Conv2d(64, 64, kernel_size=(1, 7), stride=(1, 3),
                                          padding=(0, 1), bias=False)),
        ]
        backbone.conv1 = torch.nn.Sequential(OrderedDict(stem_layers))
        backbone.fc = torch.nn.Identity()  # expose embeddings instead of logits
        self.model = backbone

        head_layers = [
            torch.nn.Linear(self.last_dim, self.proj_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(self.proj_dim, self.proj_dim),
        ]
        self.projection_head = torch.nn.Sequential(*head_layers)
        self.cls = torch.nn.Linear(self.last_dim, self.num_classes)

    def forward(self, input_data, features=None):
        """Return ``(cls_reps, projected_reps)``.

        ``input_data`` gets a channel axis inserted at dim 1, so it is assumed
        to be ``(batch, leads, length)`` (TODO confirm). ``features`` is
        ignored, presumably so the call signature matches the SSL variant.
        """
        embedding = self.model(input_data.unsqueeze(1))
        projected = self.projection_head(embedding)
        return self.cls(embedding), projected

class BaseResnet50(torch.nn.Module):
    """ResNet-50 encoder with an SSL projection head and a linear classifier.

    The input is encoded as one single-channel 2D image. The resulting
    embedding (``last_dim`` wide) feeds both ``projection_head`` (SSL) and
    ``cls`` (classification).
    """

    def __init__(self, num_leads, num_classes, configs, info_for_test=None):
        super().__init__()
        self.num_leads, self.num_classes = num_leads, num_classes
        self.configs, self.info_for_test = configs, info_for_test
        self.last_dim = 2048  # embedding width once the backbone's ``fc`` is removed
        self.proj_dim = 512   # projection-head width

        backbone = models.resnet50(num_classes=2)
        # Stem of (1, k) convolutions: they only slide along the last axis.
        stem_layers = [
            ('pre-conv1', torch.nn.Conv2d(1, 32, kernel_size=(1, 15), stride=(1, 3),
                                          padding=(0, 3), bias=False)),
            ('pre-bn1', torch.nn.BatchNorm2d(32)),
            ('pre-relu1', torch.nn.ReLU(inplace=True)),
            ('pre-conv2', torch.nn.Conv2d(32, 64, kernel_size=(1, 11), stride=(1, 3),
                                          padding=(0, 1), bias=False)),
            ('pre-bn2', torch.nn.BatchNorm2d(64)),
            ('pre-relu2', torch.nn.ReLU(inplace=True)),
            ('pre-conv3', torch.nn.Conv2d(64, 64, kernel_size=(1, 7), stride=(1, 3),
                                          padding=(0, 1), bias=False)),
        ]
        backbone.conv1 = torch.nn.Sequential(OrderedDict(stem_layers))
        backbone.fc = torch.nn.Identity()  # expose embeddings instead of logits
        self.model = backbone

        head_layers = [
            torch.nn.Linear(self.last_dim, self.proj_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(self.proj_dim, self.proj_dim),
        ]
        self.projection_head = torch.nn.Sequential(*head_layers)
        self.cls = torch.nn.Linear(self.last_dim, self.num_classes)

    def forward(self, input_data, features=None):
        """Return ``(cls_reps, projected_reps)``.

        ``input_data`` gets a channel axis inserted at dim 1, so it is assumed
        to be ``(batch, leads, length)`` (TODO confirm). ``features`` is
        ignored, presumably so the call signature matches the SSL variant.
        """
        embedding = self.model(input_data.unsqueeze(1))
        projected = self.projection_head(embedding)
        return self.cls(embedding), projected
