import math
import torch
from transformers import BertConfig, BertModel, ViTConfig, DistilBertConfig, DistilBertModel, DeiTConfig
from transformers.models.vit.modeling_vit import ViTEncoder
from transformers.models.deit.modeling_deit import DeiTEncoder

class ECGEmbedding(torch.nn.Module):
    """Six-stage 1-D convolutional front end for multi-lead recordings.

    Every stage is ``Conv1d (no bias) -> BatchNorm1d -> ReLU``. The sub-modules
    are registered as ``conv1..conv6``, ``ln1..ln6`` (these are BatchNorm
    layers despite the ``ln`` prefix) and ``relu1..relu6``.

    Args:
        num_leads: Input channel count of the first convolution.
        num_classes: Stored on the instance; not used by this module.
        configs: Mapping providing ``configs["MODEL_CONFIG"]["conv_channel"]``,
            the channel width of the intermediate stages.
        hiddin_size: Output channel count of the last stage (the spelling is
            part of the existing signature).
        info_for_test: Stored on the instance; not used by this module.
        flatten: When true, ``forward`` flattens everything after the batch
            axis.
        pooling: When true, ``forward`` adaptively average-pools the sequence
            axis to length 5.
        pooling_out: Stored on the instance; not used by this module.
    """

    def __init__(
        self,
        num_leads,
        num_classes,
        configs,
        hiddin_size,
        info_for_test=None,
        flatten=False,
        pooling=False,
        pooling_out=None,
    ):
        super().__init__()
        self.num_leads = num_leads
        self.num_classes = num_classes
        self.configs = configs
        self.info_for_test = info_for_test
        self.flatten = flatten
        self.pooling = pooling
        self.pooling_out = pooling_out

        width = self.configs["MODEL_CONFIG"]["conv_channel"]
        half_width = width // 2

        # One row per stage: (in_channels, out_channels, kernel, stride, padding).
        stage_specs = (
            (self.num_leads, half_width, 14, 3, 2),
            (half_width, width, 14, 3, 0),
            (width, width, 10, 2, 0),
            (width, width, 10, 2, 0),
            (width, width, 10, 1, 0),
            (width, hiddin_size, 10, 1, 0),
        )
        # Attribute names and registration order are kept stable so that
        # state_dict keys and parameter ordering do not change.
        for idx, (c_in, c_out, kernel, step, pad) in enumerate(stage_specs, start=1):
            setattr(
                self,
                f"conv{idx}",
                torch.nn.Conv1d(
                    c_in,
                    c_out,
                    kernel_size=kernel,
                    stride=step,
                    padding=pad,
                    bias=False,
                ),
            )
            setattr(self, f"ln{idx}", torch.nn.BatchNorm1d(c_out))
            setattr(self, f"relu{idx}", torch.nn.ReLU())

        self.avgpool = torch.nn.AdaptiveAvgPool1d(5)
        self.flatten_layer = torch.nn.Flatten()

    def forward(self, input_recording):
        """Embed a batch of recordings.

        input_recording.shape : (batch_size, num_leads, len_recording)

        Returns a tensor with channels moved to the last axis, i.e.
        (batch_size, sequence_length, hiddin_size), or its flattened
        (batch_size, sequence_length * hiddin_size) form when ``flatten`` is set.
        """
        stages = (
            (self.conv1, self.ln1, self.relu1),
            (self.conv2, self.ln2, self.relu2),
            (self.conv3, self.ln3, self.relu3),
            (self.conv4, self.ln4, self.relu4),
            (self.conv5, self.ln5, self.relu5),
            (self.conv6, self.ln6, self.relu6),
        )
        features = input_recording
        for conv, norm, act in stages:
            features = act(norm(conv(features)))

        if self.pooling:
            features = self.avgpool(features)

        # (batch, channels, sequence) -> (batch, sequence, channels)
        features = features.transpose(1, 2)

        if self.flatten:
            return self.flatten_layer(features)
        return features

class OnlyEmbedding(torch.nn.Module):
    """Convolutional embedding followed by a stack of fully connected layers.

    The recording is embedded and flattened by ``ECGEmbedding`` (built with
    ``flatten=True``), then passed through ``MODEL_CONFIG["n_layers"]`` linear
    layers: every layer but the last maps to 128 units and is followed by a
    ReLU; the last one maps to ``num_classes``.

    Args:
        num_leads: Number of input channels of the recording.
        num_classes: Width of the final linear layer.
        configs: Mapping providing ``configs["MODEL_CONFIG"]`` with the keys
            ``"dim"`` and ``"n_layers"`` (plus whatever ``ECGEmbedding`` reads).
        info_for_test: Stored on the instance and forwarded to the embedding.
        fc_in_dim: Width of the flattened embedding fed to the first linear
            layer. It equals sequence_length * MODEL_CONFIG["dim"] of the
            embedding output, so it depends on the recording length. The
            default (45056) is the value that used to be hard-coded.
    """

    def __init__(self, num_leads, num_classes, configs, info_for_test=None, fc_in_dim=45056):
        super(OnlyEmbedding, self).__init__()
        self.num_leads = num_leads
        self.num_classes = num_classes
        self.configs = configs
        self.info_for_test = info_for_test

        self.model_config = self.configs["MODEL_CONFIG"]
        self.embeddings = ECGEmbedding(num_leads, num_classes, configs, self.model_config["dim"], info_for_test, flatten=True)

        hidden_dim = 128
        n_layers = self.model_config["n_layers"]

        self.fc_layers = torch.nn.ModuleList()
        in_dim = fc_in_dim
        # Hidden layers: the running input width is always updated, so
        # consecutive layers line up even when fc_in_dim is below hidden_dim.
        for _ in range(n_layers - 1):
            self.fc_layers.append(torch.nn.Linear(in_dim, hidden_dim))
            in_dim = hidden_dim
        # Output layer (skipped only for a non-positive n_layers, which leaves
        # the list empty exactly as before).
        if n_layers > 0:
            self.fc_layers.append(torch.nn.Linear(in_dim, num_classes))

    def forward(self, input_recording):
        """Return ``(logits, None)`` for a batch of recordings.

        input_recording.shape : (batch_size, num_leads, len_recording)
        """
        # The embedding is already flattened to (batch_size, features), so no
        # extra flatten step is needed here.
        output = self.embeddings(input_recording)

        for fc in self.fc_layers[:-1]:
            output = torch.nn.functional.relu(fc(output))

        output = self.fc_layers[-1](output)
        return output, None

class ECGDistilBert(torch.nn.Module):
    """DistilBERT encoder over convolutional ECG embeddings.

    The recording is embedded by ``ECGEmbedding``, wrapped with a learned CLS
    vector in front and a learned SEP vector at the end, and fed to a
    ``DistilBertModel`` through ``inputs_embeds``. Two linear heads read the
    hidden state at position 0.

    Args:
        num_leads: Number of input channels of the recording.
        num_classes: Output width of the classification head.
        configs: Mapping providing ``"MODEL_CONFIG"`` and ``"DISTILBERT_CONFIG"``;
            the latter is expanded into ``DistilBertConfig``.
        info_for_test: Stored on the instance and forwarded to the embedding.
    """

    def __init__(self, num_leads, num_classes, configs, info_for_test=None):
        super().__init__()
        self.num_leads = num_leads
        self.num_classes = num_classes
        self.configs = configs
        self.info_for_test = info_for_test

        # Not used below; kept so a missing "MODEL_CONFIG" section still fails
        # at this point.
        _ = self.configs["MODEL_CONFIG"]

        # NOTE(review): confirm that DistilBERT honours
        # position_embedding_type; it is passed through unchanged here.
        self.distilbert_config = DistilBertConfig(
            **self.configs["DISTILBERT_CONFIG"],
            position_embedding_type="relative_key_query",
        )
        cfg = self.distilbert_config

        # Learned boundary vectors, one value per hidden unit.
        self.cls_embedding = torch.nn.Parameter(torch.randn((cfg.hidden_size,)))
        self.sep_embedding = torch.nn.Parameter(torch.randn((cfg.hidden_size,)))

        self.embeddings = ECGEmbedding(
            num_leads, num_classes, configs, cfg.dim, info_for_test
        )
        self.distilbert = DistilBertModel(cfg)

        self.class_head = torch.nn.Linear(cfg.dim, self.num_classes)
        self.pretrain_head = torch.nn.Linear(cfg.dim, 2)

    def forward(self, input_recording):
        """Return ``(class_out, pretrain_out)`` computed from the CLS position.

        input_recording.shape : (batch_size, num_leads, len_recording)
        """
        n_samples = input_recording.shape[0]
        # (batch_size, sequence_length, hidden_size)
        tokens = self.embeddings(input_recording)

        # Tile the 1-D boundary vectors to (batch_size, 1, hidden_size).
        cls_tokens = self.cls_embedding.repeat(n_samples, 1, 1)
        sep_tokens = self.sep_embedding.repeat(n_samples, 1, 1)
        sequence = torch.cat((cls_tokens, tokens, sep_tokens), dim=1)

        hidden_states = self.distilbert(inputs_embeds=sequence).last_hidden_state
        cls_state = hidden_states[:, 0, :]
        return self.class_head(cls_state), self.pretrain_head(cls_state)
'''
class ECGBert(torch.nn.Module):
    def __init__(self, num_leads, num_classes, configs, info_for_test=None):
        super(ECGBert, self).__init__()
        self.num_leads = num_leads
        self.num_classes = num_classes
        self.configs = configs
        self.info_for_test = info_for_test

        model_config = self.configs["MODEL_CONFIG"]


        self.bert_config = BertConfig(
            **self.configs["BERT_CONFIG"],
            position_embedding_type="relative_key_query",
        )

        # self.embeddings = ECGEmbedding(num_leads, 
        #                                 num_classes, 
        #                                 configs, 
        #                                 self.bert_config.hidden_size, 
        #                                 info_for_test,
        #                                 pooling=model_config["pooling"]
        #                                 )
        self.embeddings = EfficientNet(num_leads, 
                                        num_classes, 
                                        configs=configs, 
                                        info_for_test=info_for_test,
                                        include_top=False
                                        )
        self._conv = torch.nn.Conv1d(
            1792,
            self.bert_config.hidden_size,
            kernel_size=10,
            stride=3,
            padding=0,
            bias=True,
        )
        self.bert = BertModel(self.bert_config, add_pooling_layer=False)

        self.cls_embedding = torch.nn.Parameter(
            torch.randn((1792,))
        )
        self.sep_embedding = torch.nn.Parameter(
            torch.randn((1792,))
        )
        # self.cls_embedding = torch.nn.Parameter(
        #     torch.randn((self.bert_config.hidden_size,))
        # )
        # self.sep_embedding = torch.nn.Parameter(
        #     torch.randn((self.bert_config.hidden_size,))
        # )
        self.class_head = torch.nn.Linear(
            self.bert_config.hidden_size, self.num_classes
        )
        self.pretrain_head = torch.nn.Linear(self.bert_config.hidden_size, 2)

    def forward(self, input_recording, attention_mask=None):
        # input_recording = input_recording.squeeze(-3)
        """
        input_recording.shape : (batch_size, num_leads, len_recording)
        """
        batch_size = input_recording.shape[0]
        ecg_conv_embeddings, _ = self.embeddings(input_recording)
        ecg_conv_embeddings = ecg_conv_embeddings.transpose(1, 2)
        # print(ecg_conv_embeddings.size())
        # ecg_conv_embeddings now has a shape of (batch_size, sequence_length, hidden_size)
        # cls_embedding + ecg_conv_embeddings + sep_embedding
        # transformer_input = torch.cat(
        #     (
        #         self.cls_embedding.repeat(batch_size, 1, 1),
        #         ecg_conv_embeddings,
        #         self.sep_embedding.repeat(batch_size, 1, 1),
        #     ),
        #     dim=1,
        # )
        # print(self.cls_embedding.repeat(batch_size, 1, 1).size())
        transformer_input = torch.cat(
            (
                self.cls_embedding.repeat(batch_size, 1, 1),
                ecg_conv_embeddings,
                self.sep_embedding.repeat(batch_size, 1, 1),
            ),
            dim=1,
        )
        bert_outputs = self.bert(inputs_embeds=transformer_input)
        cls_output = bert_outputs.last_hidden_state[:, 0, :]
        class_out = self.class_head(cls_output)
        pretrain_out = self.pretrain_head(cls_output)
        return class_out, pretrain_out
'''

class ECGBert(torch.nn.Module):
    """BERT encoder over convolutional ECG features with two linear heads.

    The module owns its own six-stage convolution stack (``conv1..conv6``)
    instead of using a separate embedding module. BatchNorm layers
    ``ln1..ln6`` are registered as sub-modules, but ``forward`` does not apply
    them: each stage is ``Conv1d -> ReLU`` only.

    Args:
        num_leads: Input channel count of the first convolution.
        num_classes: Output width of the classification head.
        configs: Mapping providing ``configs["MODEL_CONFIG"]["conv_channel"]``
            and ``configs["BERT_CONFIG"]`` (expanded into ``BertConfig``).
        info_for_test: Stored on the instance; not used by this module.
    """

    def __init__(self, num_leads, num_classes, configs, info_for_test=None):
        super().__init__()
        self.num_leads = num_leads
        self.num_classes = num_classes
        self.configs = configs
        self.info_for_test = info_for_test

        width = self.configs["MODEL_CONFIG"]["conv_channel"]
        half_width = width // 2

        self.bert_config = BertConfig(
            **self.configs["BERT_CONFIG"],
            position_embedding_type="relative_key_query",
        )
        hidden = self.bert_config.hidden_size

        # Learned boundary vectors, one value per hidden unit.
        self.cls_embedding = torch.nn.Parameter(torch.randn((hidden,)))
        self.sep_embedding = torch.nn.Parameter(torch.randn((hidden,)))

        self.bert = BertModel(self.bert_config, add_pooling_layer=False)

        # One row per stage: (in_channels, out_channels, kernel, stride, padding).
        stage_specs = (
            (self.num_leads, half_width, 14, 3, 2),
            (half_width, width, 14, 3, 0),
            (width, width, 10, 2, 0),
            (width, width, 10, 2, 0),
            (width, width, 10, 1, 0),
            (width, hidden, 10, 1, 0),
        )
        # Attribute names and registration order are kept stable so that
        # state_dict keys and parameter ordering do not change.
        for idx, (c_in, c_out, kernel, step, pad) in enumerate(stage_specs, start=1):
            setattr(
                self,
                f"conv{idx}",
                torch.nn.Conv1d(
                    c_in,
                    c_out,
                    kernel_size=kernel,
                    stride=step,
                    padding=pad,
                    bias=False,
                ),
            )
            setattr(self, f"ln{idx}", torch.nn.BatchNorm1d(c_out))
            setattr(self, f"relu{idx}", torch.nn.ReLU())

        self.class_head = torch.nn.Linear(hidden, self.num_classes)
        self.pretrain_head = torch.nn.Linear(hidden, 2)

    def forward(self, input_recording, attention_mask=None):
        """Return ``(class_out, pretrain_out)`` computed from the CLS position.

        input_recording.shape : (batch_size, num_leads, len_recording)

        ``attention_mask`` is accepted but not used.
        """
        n_samples = input_recording.shape[0]

        # The ln* BatchNorm layers are intentionally not part of this chain.
        stages = (
            (self.conv1, self.relu1),
            (self.conv2, self.relu2),
            (self.conv3, self.relu3),
            (self.conv4, self.relu4),
            (self.conv5, self.relu5),
            (self.conv6, self.relu6),
        )
        features = input_recording
        for conv, act in stages:
            features = act(conv(features))

        # (batch, hidden, sequence) -> (batch, sequence, hidden)
        features = features.transpose(1, 2)

        # Sequence layout: CLS vector, convolutional features, SEP vector.
        cls_tokens = self.cls_embedding.repeat(n_samples, 1, 1)
        sep_tokens = self.sep_embedding.repeat(n_samples, 1, 1)
        sequence = torch.cat((cls_tokens, features, sep_tokens), dim=1)

        hidden_states = self.bert(inputs_embeds=sequence).last_hidden_state
        cls_state = hidden_states[:, 0, :]
        return self.class_head(cls_state), self.pretrain_head(cls_state)


class ViTEmbeddings(torch.nn.Module):
    """
    Adapted from huggingface transformers (transformers.models.vit.modeling_vit.py)
    Construct the CLS token, position and patch embeddings.

    Patches are produced by a strided ``Conv1d`` over the recording followed by
    ``BatchNorm1d``. A CLS token is prepended, learned position embeddings are
    added and dropout is applied.

    Args:
        vit_config: Object exposing ``hidden_size`` and ``hidden_dropout_prob``.
        model_config: Mapping providing ``"window_length"`` (patch kernel size)
            and ``"window_stride"`` (patch stride).
        num_leads: Number of input channels of the recording.
        raw_input_length: Length of the recordings passed to ``forward``. It
            fixes the size of the position-embedding table, so inputs that
            yield a different number of patches fail at the addition in
            ``forward``. Defaults to 4096, the previously hard-coded value.
    """

    def __init__(self, vit_config, model_config, num_leads, raw_input_length=4096):
        super().__init__()

        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, vit_config.hidden_size))

        self.window_length = model_config["window_length"]
        self.window_stride = model_config["window_stride"]
        self.padding = 0
        self.dilation = 1
        self.input_projection = torch.nn.Conv1d(
            num_leads,
            vit_config.hidden_size,
            kernel_size=self.window_length,
            stride=self.window_stride,
            padding=self.padding,
            dilation=self.dilation,
            bias=False,
        )
        self.ln = torch.nn.BatchNorm1d(vit_config.hidden_size)

        # Conv1d output-length formula, evaluated with integer floor division
        # so the patch count is exact.
        receptive_span = self.dilation * (self.window_length - 1) + 1
        self.num_patches = (
            raw_input_length + 2 * self.padding - receptive_span
        ) // self.window_stride + 1
        print("input token length: {}".format(raw_input_length))

        # One position per patch plus one for the CLS token.
        self.position_embeddings = torch.nn.Parameter(torch.zeros(1, self.num_patches + 1, vit_config.hidden_size))
        self.dropout = torch.nn.Dropout(vit_config.hidden_dropout_prob)

    def forward(self, input_recording):
        """Return embeddings of shape (batch_size, num_patches + 1, hidden_size).

        input_recording.shape : (batch_size, num_leads, len_recording)
        """
        batch_size = input_recording.shape[0]
        embeddings = self.input_projection(input_recording)
        embeddings = self.ln(embeddings)
        # (batch, hidden, patches) -> (batch, patches, hidden)
        embeddings = embeddings.transpose(1, 2)

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)
        embeddings = embeddings + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings

class ECGViT(torch.nn.Module):
    """ViT-style classifier for multi-lead recordings.

    Pipeline: ``ViTEmbeddings`` -> ``ViTEncoder`` -> ``LayerNorm`` -> linear
    classifier applied to sequence position 0 (where ``ViTEmbeddings`` puts
    the CLS token).

    Args:
        num_leads: Number of input channels of the recording.
        num_classes: Output width of the classifier.
        configs: Mapping providing ``"MODEL_CONFIG"`` (passed to the
            embeddings) and ``"VIT_CONFIG"`` (expanded into ``ViTConfig``).
        info_for_test: Stored on the instance; not used by this module.
    """

    def __init__(self, num_leads, num_classes, configs, info_for_test=None):
        super().__init__()
        self.num_leads = num_leads
        self.num_classes = num_classes
        self.configs = configs
        self.info_for_test = info_for_test

        self.model_config = self.configs["MODEL_CONFIG"]
        self.vit_config = ViTConfig(**self.configs["VIT_CONFIG"])
        hidden = self.vit_config.hidden_size

        self.embeddings = ViTEmbeddings(
            self.vit_config, self.model_config, self.num_leads
        )
        self.encoder = ViTEncoder(self.vit_config)
        self.layernorm = torch.nn.LayerNorm(
            hidden, eps=self.vit_config.layer_norm_eps
        )

        # Classifier head
        self.classifier = torch.nn.Linear(hidden, self.num_classes)

    def forward(self, input_recording, attention_mask=None):
        """Return ``(logits, None)``.

        input_recording.shape : (batch_size, num_leads, len_recording)

        ``attention_mask`` is accepted but not used.
        """
        tokens = self.embeddings(input_recording)
        # The first element of the encoder output is the hidden-state sequence.
        encoded = self.encoder(tokens)[0]
        normalized = self.layernorm(encoded)
        return self.classifier(normalized[:, 0, :]), None

class ECGDeiT(torch.nn.Module):
    """DeiT-encoder classifier for multi-lead recordings.

    Pipeline: ``ViTEmbeddings`` -> ``DeiTEncoder`` -> ``LayerNorm`` -> linear
    classifier applied to sequence position 0 (where ``ViTEmbeddings`` puts
    the CLS token).

    Args:
        num_leads: Number of input channels of the recording.
        num_classes: Output width of the classifier.
        configs: Mapping providing ``"MODEL_CONFIG"`` (passed to the
            embeddings) and ``"DEIT_CONFIG"`` (expanded into ``DeiTConfig``).
        info_for_test: Stored on the instance; not used by this module.
    """

    def __init__(self, num_leads, num_classes, configs, info_for_test=None):
        super().__init__()
        self.num_leads = num_leads
        self.num_classes = num_classes
        self.configs = configs
        self.info_for_test = info_for_test

        self.model_config = self.configs["MODEL_CONFIG"]
        self.deit_config = DeiTConfig(**self.configs["DEIT_CONFIG"])
        hidden = self.deit_config.hidden_size

        self.embeddings = ViTEmbeddings(
            self.deit_config, self.model_config, self.num_leads
        )
        self.encoder = DeiTEncoder(self.deit_config)
        self.layernorm = torch.nn.LayerNorm(
            hidden, eps=self.deit_config.layer_norm_eps
        )

        # Classifier head
        self.classifier = torch.nn.Linear(hidden, self.num_classes)

    def forward(self, input_recording, attention_mask=None):
        """Return ``(logits, None)``.

        input_recording.shape : (batch_size, num_leads, len_recording)

        ``attention_mask`` is accepted but not used.
        """
        tokens = self.embeddings(input_recording)
        # The first element of the encoder output is the hidden-state sequence.
        encoded = self.encoder(tokens)[0]
        normalized = self.layernorm(encoded)
        return self.classifier(normalized[:, 0, :]), None

class TimeSSLTransformer(torch.nn.Module):
    """BERT encoder over windowed recordings with self-supervised heads.

    The recording is split into windows by a strided ``Conv1d``. Two learned
    vectors are prepended (one read by the contrastive head, one by the
    transformation-detection head) and the sequence is encoded by a
    ``BertModel``. Four heads consume the encoder output.

    Args:
        num_leads: Number of input channels of the recording.
        num_classes: Output width of the classification head.
        configs: Mapping providing ``"BERT_CONFIG"`` (expanded into
            ``BertConfig``) and ``"MODEL_CONFIG"`` with the keys
            ``"window_length"``, ``"window_stride"``,
            ``"contrastive_head_size"``, ``"num_transformations"`` and
            ``"recording_projection_head_size"``.
        info_for_test: Stored on the instance; not used by this module.
    """

    def __init__(self, num_leads, num_classes, configs, info_for_test=None):
        super().__init__()
        self.num_leads = num_leads
        self.num_classes = num_classes
        self.configs = configs
        self.info_for_test = info_for_test

        settings = self.configs["MODEL_CONFIG"]
        patch_length = settings["window_length"]
        patch_stride = settings["window_stride"]
        contrastive_dim = settings["contrastive_head_size"]
        n_transformations = settings["num_transformations"]
        recording_dim = settings["recording_projection_head_size"]
        n_extra_features = 0

        self.bert_config = BertConfig(**self.configs["BERT_CONFIG"])
        hidden = self.bert_config.hidden_size

        # Token read by the contrastive-learning head.
        self.cl_embedding = torch.nn.Parameter(torch.randn(hidden))
        # Token read by the transformation-detection head.
        self.td_embedding = torch.nn.Parameter(torch.randn(hidden))

        self.input_projection = torch.nn.Conv1d(
            self.num_leads,
            hidden,
            kernel_size=patch_length,
            stride=patch_stride,
            bias=False,
        )

        self.bert = BertModel(self.bert_config, add_pooling_layer=False)
        self.contrastive_head = torch.nn.Linear(hidden, contrastive_dim)
        self.detection_head = torch.nn.Linear(hidden, n_transformations)
        # Kernel-size-1 convolution: a per-position projection of the
        # recording tokens.
        self.recording_head = torch.nn.Conv1d(
            hidden,
            recording_dim,
            kernel_size=1,
            bias=False,
        )
        self.classification_head = torch.nn.Linear(
            hidden * 2 + n_extra_features, self.num_classes
        )

    def forward(self, input_recording, attention_mask=None):
        """Encode a batch of recordings and apply all heads.

        input_recording.shape : (batch_size, num_leads, len_recording)

        ``attention_mask`` is accepted but not used.

        Returns:
            ``(classification_out, (contrastive_out, detection_out, recordings_out))``
            with shapes

            * classification_out: (batch_size, num_classes)
            * contrastive_out: (batch_size, contrastive_head_size)
            * detection_out: (batch_size, num_transformations)
            * recordings_out: (batch_size, sequence_length,
              recording_projection_head_size)
        """
        n_samples = input_recording.shape[0]

        # (batch, hidden, sequence) -> (batch, sequence, hidden)
        windows = self.input_projection(input_recording).transpose(1, 2)

        # Sequence layout: cl token, td token, then the recording windows.
        cl_tokens = self.cl_embedding.repeat(n_samples, 1, 1)
        td_tokens = self.td_embedding.repeat(n_samples, 1, 1)
        sequence = torch.cat((cl_tokens, td_tokens, windows), dim=1)

        # Positional information is handled inside the BERT model.
        hidden_states = self.bert(inputs_embeds=sequence).last_hidden_state
        cl_state = hidden_states[:, 0, :]  # (batch_size, hidden_dim)
        td_state = hidden_states[:, 1, :]  # (batch_size, hidden_dim)
        window_states = hidden_states[:, 2:, :]  # (batch_size, sequence_length, hidden_dim)

        contrastive_out = self.contrastive_head(cl_state)
        detection_out = self.detection_head(td_state)
        # Conv1d expects channels on axis 1, hence the transpose round trip.
        recordings_out = self.recording_head(window_states.transpose(1, 2)).transpose(1, 2)
        # The classifier sees both special-token states side by side.
        classification_out = self.classification_head(
            torch.cat((cl_state, td_state), dim=1)
        )

        return classification_out, (contrastive_out, detection_out, recordings_out)
