import os
import torch
import mlflow
import torch.nn.functional as F
from tqdm import tqdm
from src.utils.team_helper_code import EarlyStopping

class SimCLR(torch.nn.Module):
    """SimCLR-style contrastive pre-training wrapper around a backbone.

    The backbone is moved to ``device`` and, when more than one CUDA device is
    visible, wrapped in ``torch.nn.DataParallel``. It is trained with an
    InfoNCE (NT-Xent) objective over ``self.k`` augmented views of each sample.

    Args:
        model: backbone whose forward pass returns a 2-tuple; the second item is
            used as the projected representation fed to the contrastive loss.
        configs: mapping with required keys ``BATCH_SIZE_PRETRAINING``,
            ``N_EPOCHS_PRETRAINING``, ``LEARNING_RATE_PRETRAINING`` and optional
            ``TEMPERATURE_PRETRAINING`` (default 0.07), ``PATIENCE_PRETRAINING``
            (default 30), ``WEIGHT_DECAY_PRETRAINING`` (default 1e-4) and
            ``SAVE_FREQUENCY_PRETRAINING`` (default 10).
        device: device the backbone, inputs and loss module are placed on.
        model_directory: directory where backbone checkpoints are written.
        use_mlflow: if True, per-epoch metrics are logged with ``mlflow``.
    """

    def __init__(self, model, configs, device, model_directory, use_mlflow=False):
        super().__init__()
        self.device = device
        self.model = model.to(device)
        # Kept for callers; the loss infers the per-view batch size from its
        # input so that a smaller final batch does not break it.
        self.batch_size = configs["BATCH_SIZE_PRETRAINING"]
        self.epochs = configs["N_EPOCHS_PRETRAINING"]
        self.lr = configs["LEARNING_RATE_PRETRAINING"]
        self.temperature = configs.get("TEMPERATURE_PRETRAINING", 0.07)  # default: 0.07
        self.patience = configs.get("PATIENCE_PRETRAINING", 30)  # default: 30
        self.weight_decay = configs.get("WEIGHT_DECAY_PRETRAINING", 1e-4)  # default: 1e-4
        self.save_frequency = configs.get("SAVE_FREQUENCY_PRETRAINING", 10)  # default: 10
        self.criterion = torch.nn.CrossEntropyLoss().to(device)
        self.model_directory = model_directory
        # Number of augmented views per sample; must match the number of views
        # produced by src.data_preparation.dataset_class.Dataset_ECG_SimCLR.
        self.k = 2
        self.use_mlflow = use_mlflow

        if torch.cuda.device_count() > 1:
            self.model = torch.nn.DataParallel(self.model)

    def info_nce_loss(self, features):
        """Build InfoNCE logits and targets from stacked multi-view features.

        Reference: https://github.com/sthalles/SimCLR

        Args:
            features: tensor of shape (k * B, feature_size). Rows ``i``,
                ``i + B``, ... are views of the same sample.

        Returns:
            ``(logits, targets)``. ``logits`` has shape (k * B, k * B - 1);
            column 0 holds the positive similarity (for k == 2) and the
            remaining columns hold the negatives, all divided by
            ``self.temperature``. ``targets`` is an all-zero long tensor of
            length k * B, so cross-entropy selects column 0.

        Raises:
            ValueError: if the number of rows is zero or not a multiple of ``self.k``.
        """
        n_rows = features.shape[0]
        if n_rows == 0 or n_rows % self.k != 0:
            raise ValueError(
                "expected a non-empty multiple of {} rows, got {}".format(self.k, n_rows))
        batch_size = n_rows // self.k
        device = features.device  # follow the features instead of assuming CUDA

        # Sample id of each row: 0..B-1 repeated k times. Two rows are
        # positives of each other when their ids match.
        sample_ids = torch.arange(batch_size, device=device).repeat(self.k)
        positive_mask = (sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)).float()

        # Cosine similarity between every pair of rows: (k * B, k * B).
        features = F.normalize(features, dim=1)
        similarity_matrix = features @ features.T

        # Drop self-similarity on the diagonal from both matrices: (k * B, k * B - 1).
        not_self = ~torch.eye(n_rows, dtype=torch.bool, device=device)
        positive_mask = positive_mask[not_self].view(n_rows, -1)
        similarity_matrix = similarity_matrix[not_self].view(n_rows, -1)

        positives = similarity_matrix[positive_mask.bool()].view(n_rows, -1)
        negatives = similarity_matrix[~positive_mask.bool()].view(n_rows, -1)

        # Positives first, so the correct "class" for every row is index 0.
        logits = torch.cat([positives, negatives], dim=1) / self.temperature
        targets = torch.zeros(n_rows, dtype=torch.long, device=device)
        return logits, targets

    def train(self, train_loader):
        """Pre-train the backbone on ``train_loader`` and return it.

        This method overrides ``torch.nn.Module.train``. ``Module.eval()`` and a
        parent module's ``.train(mode)`` call ``self.train(<bool>)``, so a bool
        argument is forwarded to the base class to switch modes instead of
        starting a training run.

        Args:
            train_loader: sized iterable. Each step yields a sequence of
                ``self.k`` view tensors, which are concatenated along dim 0.
                Presumably each view is (batch_size, timestamp_dim,
                channel_dim); TODO confirm against the dataset.

        Returns:
            The trained ``self.model`` (possibly DataParallel-wrapped). When
            called with a bool, returns ``self``, as ``nn.Module.train`` does.

        Raises:
            ValueError: if ``train_loader`` is empty, or a step yields a number
                of views different from ``self.k``.
        """
        if isinstance(train_loader, bool):
            return super().train(train_loader)
        if len(train_loader) == 0:
            raise ValueError("train_loader is empty; nothing to pre-train on")

        early_stopping = EarlyStopping(self.patience, verbose=True,
                                       checkpoint_pth='{}/backbone_best.pt'.format(self.model_directory))
        optimizer = torch.optim.Adam(
            [{'params': self.model.parameters()}], lr=self.lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.epochs, eta_min=0, last_epoch=-1)
        self.model.train()
        epoch_max = 0
        acc_1_max = 0

        for epoch in range(1, self.epochs + 1):
            # The scheduler steps once per epoch, so this is the LR used by
            # every optimisation step of this epoch.
            running_lr = scheduler.get_last_lr()[0]
            loss_epoch, acc_1_epoch, acc_5_epoch = self._run_epoch(
                train_loader, optimizer, epoch, running_lr)

            if acc_1_epoch > acc_1_max:
                acc_1_max = acc_1_epoch
                epoch_max = epoch

            # No warmup as of now: anneal the LR after every epoch.
            scheduler.step()

            if epoch % self.save_frequency == 0:
                print("[INFO] save backbone at epoch {}!".format(epoch))
                torch.save(self.model, '{}/backbone_{}.pt'.format(self.model_directory, epoch))

            print('Epoch [{}] loss= {:.5f}; Epoch Top1 ACC.= {:.2f}%, Epoch Top5 ACC.= {:.2f}%, Max Top1 ACC.= {:.1f}%, Max Epoch={}'
                  .format(epoch, loss_epoch, acc_1_epoch, acc_5_epoch, acc_1_max, epoch_max))

            if self.use_mlflow:
                mlflow.log_metrics({
                    "pretraining_loss": loss_epoch,
                    "pretraining_top1_acc": acc_1_epoch,
                    "pretraining_top5_acc": acc_5_epoch,
                    "pretraining_best_epoch": epoch_max,
                    "pretraining_running_lr": running_lr,
                }, step=epoch)

            # NOTE(review): this passes top-1 accuracy (higher is better).
            # Confirm that EarlyStopping treats its first argument as a score
            # to maximise rather than a loss to minimise.
            early_stopping(acc_1_epoch, self.model)
            if early_stopping.early_stop:
                print("Early stopping")
                break

        return self.model

    def _run_epoch(self, train_loader, optimizer, epoch, running_lr):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            ``(mean loss, mean top-1 accuracy %, mean top-5 accuracy %)``,
            averaged over the batches of the epoch.
        """
        loss_sum = 0.0
        acc_1_sum = 0.0
        acc_5_sum = 0.0
        n_batches = len(train_loader)

        # The real target is discarded (unsupervised): each step yields only the views.
        with tqdm(enumerate(train_loader), total=n_batches) as pbar:
            for i, data_augmented in pbar:
                if len(data_augmented) != self.k:
                    raise ValueError("expected {} augmented views per batch, got {}".format(
                        self.k, len(data_augmented)))
                # Stack the views along the batch axis: (k * batch_size, ...).
                x = torch.cat(data_augmented, 0).to(self.device)

                optimizer.zero_grad()
                # Only the second output (the projection) enters the contrastive loss.
                _, projected_reps = self.model(x)
                logits, labels = self.info_nce_loss(projected_reps)
                loss = self.criterion(logits, labels)
                loss.backward()
                optimizer.step()

                # How often the positive pair is ranked first / within the top 5.
                top1, top5 = cal_accuracy(logits, labels, topk=(1, 5))
                acc_1_sum += top1.item()
                acc_5_sum += top5.item()
                loss_sum += loss.item()

                pbar.set_description(
                    f"Train: [{epoch:03d}] "
                    f"Loss: {loss_sum / (i+1):.4f}"
                )
                pbar.set_postfix_str(
                    f"lr: [{running_lr:.6f}]"
                )

        return loss_sum / n_batches, acc_1_sum / n_batches, acc_5_sum / n_batches

def cal_accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, of ``output`` against ``target``.

    Args:
        output: score tensor of shape (batch_size, n_candidates).
        target: index tensor of shape (batch_size,).
        topk: the k values to evaluate.

    Returns:
        A list with one shape-(1,) tensor per entry of ``topk``. Each holds the
        percentage of rows whose target index is among the ``k``
        highest-scoring columns. A ``k`` larger than ``n_candidates`` is
        clamped, so every in-range target counts as a hit for it.
    """
    with torch.no_grad():
        # Clamp so Tensor.topk does not raise when there are fewer columns
        # than max(topk), e.g. top-5 over very small contrastive batches.
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch_size): row j is every sample's j-th best guess
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        return [
            correct[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / batch_size)
            for k in topk
        ]
