import os
import torch
import mlflow
import torch.nn.functional as F
from tqdm import tqdm
from src.utils.team_helper_code import EarlyStopping

class BCSSL(torch.nn.Module):
    """Block-contrastive self-supervised pre-training for a multi-lead backbone.

    Two objectives are optimised jointly on the backbone's projected
    representations:

    * a block-contrastive term that maximises the mean cosine similarity
      between different leads of the same recording and drives the absolute
      cosine similarity between representations of different recordings
      toward zero, and
    * a lead-identification term: ``lead_classifier`` must predict which lead
      each representation came from (cross-entropy).

    NOTE(review): the wrapped ``model`` must expose ``num_leads`` and
    ``lead_classifier``, and its forward must return ``(_, projected_reps)``
    with rows laid out recording-major (row ``b * num_leads + l`` is lead
    ``l`` of recording ``b``). That is what the lead labels and block masks
    below assume -- confirm against the backbone implementation.
    """

    def __init__(self, model, configs, device, model_directory, use_mlflow=False):
        """
        Args:
            model: backbone module; moved to ``device`` and wrapped in
                ``torch.nn.DataParallel`` when more than one CUDA device is visible.
            configs: dict with required keys ``BATCH_SIZE_PRETRAINING``,
                ``N_EPOCHS_PRETRAINING``, ``LEARNING_RATE_PRETRAINING`` and optional
                keys ``TEMPERATURE_PRETRAINING`` (default 0.07),
                ``PATIENCE_PRETRAINING`` (30), ``WEIGHT_DECAY_PRETRAINING`` (1e-4)
                and ``SAVE_FREQUENCY_PRETRAINING`` (10).
            device: device that inputs and the model are moved to.
            model_directory: directory that checkpoints are written into.
            use_mlflow: if True, per-epoch metrics are logged to MLflow.
        """
        super().__init__()
        self.device = device
        self.model = model.to(device)
        self.batch_size = configs["BATCH_SIZE_PRETRAINING"]  # stored only; batching is done by the loader
        self.epochs = configs["N_EPOCHS_PRETRAINING"]
        self.lr = configs["LEARNING_RATE_PRETRAINING"]
        self.temperature = configs.get("TEMPERATURE_PRETRAINING", 0.07)  # stored only; not used by train()
        self.patience = configs.get("PATIENCE_PRETRAINING", 30)
        self.weight_decay = configs.get("WEIGHT_DECAY_PRETRAINING", 1e-4)
        self.save_frequency = configs.get("SAVE_FREQUENCY_PRETRAINING", 10)
        self.criterion = torch.nn.CrossEntropyLoss().to(device)
        self.model_directory = model_directory
        # Number of views; should match src.data_preparation.dataset_class.Dataset_ECG_SimCLR.
        # NOTE(review): not used by train().
        self.k = 2
        self.use_mlflow = use_mlflow

        if torch.cuda.device_count() > 1:
            self.model = torch.nn.DataParallel(self.model)

    def _backbone(self):
        """Return the underlying model, unwrapping ``DataParallel`` if present.

        ``DataParallel`` does not forward custom attributes such as
        ``num_leads`` or ``lead_classifier``, so those must be read from
        ``.module``.
        """
        if isinstance(self.model, torch.nn.DataParallel):
            return self.model.module
        return self.model

    @staticmethod
    def _block_contrastive_loss(projected_reps, batch_size, num_leads):
        """Compute the block-contrastive loss for recording-major representations.

        Rows ``b * num_leads .. (b + 1) * num_leads - 1`` of ``projected_reps``
        form the block of recording ``b``.

        * Positives are pairs inside a block, excluding self-pairs.
        * Negatives are pairs from different blocks.

        The loss is ``mean(|sim(negatives)|) - mean(sim(positives))``, where
        ``sim`` is cosine similarity.

        If ``projected_reps`` does not have ``batch_size * num_leads`` rows,
        the boolean indexing below raises.
        """
        sim = torch.nn.CosineSimilarity(dim=-1)(projected_reps.unsqueeze(1), projected_reps.unsqueeze(0))

        n_rows = batch_size * num_leads
        device = projected_reps.device
        # Block id of every row; equal ids <=> same recording (block-diagonal mask).
        block_id = torch.arange(n_rows, device=device) // num_leads
        same_block = block_id.unsqueeze(0) == block_id.unsqueeze(1)
        not_self = ~torch.eye(n_rows, dtype=torch.bool, device=device)

        mask_positive = same_block & not_self
        mask_negative = ~same_block  # the diagonal is already excluded (it lies in its own block)
        return sim[mask_negative].abs().mean() - sim[mask_positive].mean()

    def train(self, train_loader):
        """Run pre-training and return the (possibly DataParallel-wrapped) model.

        Args:
            train_loader: sized iterable of dict batches providing
                ``"recording"`` and ``"features"`` tensors. As a special case,
                a ``bool`` is forwarded to ``nn.Module.train(mode)`` (and
                ``self`` is returned). This method shadows that base method,
                and ``nn.Module.eval()`` calls ``self.train(False)``.

        Returns:
            ``self.model`` after the last epoch run; early stopping may end
            training before ``N_EPOCHS_PRETRAINING`` epochs.

        Raises:
            ValueError: if ``train_loader`` has no batches.
        """
        if isinstance(train_loader, bool):
            return super().train(train_loader)

        n_batches = len(train_loader)
        if n_batches == 0:
            raise ValueError("train_loader has no batches; cannot run pre-training.")

        early_stopping = EarlyStopping(self.patience, verbose=True,
                                       checkpoint_pth='{}/backbone_best.pt'.format(self.model_directory))
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.epochs, eta_min=0, last_epoch=-1)

        backbone = self._backbone()
        num_leads = backbone.num_leads
        self.model.train()
        epoch_max = 0
        acc_1_max = 0
        for epoch in range(1, self.epochs + 1):
            acc_1_epoch = 0
            acc_2_epoch = 0
            loss_epoch = 0
            # The scheduler steps once per epoch, so the lr is constant within it.
            running_lr = scheduler.get_last_lr()[0]

            pbar = tqdm(enumerate(train_loader), total=n_batches)
            for i, data in pbar:
                recording = data["recording"].to(self.device)
                features = data["features"].to(self.device)
                batch_size = recording.shape[0]

                optimizer.zero_grad()
                _, projected_reps = self.model(recording, features)
                logits = backbone.lead_classifier(projected_reps)

                loss_block_contrastive = self._block_contrastive_loss(projected_reps, batch_size, num_leads)
                # Lead index of every row in recording-major order: [0..L-1, 0..L-1, ...].
                lead_labels = torch.arange(num_leads, device=self.device).repeat(batch_size)
                loss_lead = self.criterion(logits, lead_labels)
                loss = loss_block_contrastive + loss_lead

                loss.backward()
                optimizer.step()

                top1, top2 = cal_accuracy(logits, lead_labels, topk=(1, 2))
                acc_1_epoch += top1.item()
                acc_2_epoch += top2.item()
                loss_epoch += loss.item()

                pbar.set_description(
                    f"Train: [{epoch:03d}] "
                    # Mean of per-batch losses; equals the per-sample mean only when all
                    # batches have the same size (e.g. DataLoader drop_last=True).
                    f"Loss: {loss_epoch / (i+1):.4f}"
                )
                pbar.set_postfix_str(
                    f"lr: [{running_lr:.6f}]"
                )

            acc_1_epoch /= n_batches
            acc_2_epoch /= n_batches
            loss_epoch /= n_batches

            if acc_1_epoch > acc_1_max:
                acc_1_max = acc_1_epoch
                epoch_max = epoch

            # No warmup as of now: the cosine schedule advances every epoch.
            scheduler.step()

            if epoch % self.save_frequency == 0:
                print("[INFO] save backbone at epoch {}!".format(epoch))
                torch.save(self.model, '{}/backbone_{}.pt'.format(self.model_directory, epoch))

            print('Epoch [{}] loss= {:.5f}; Epoch Top1 ACC.= {:.2f}%, Epoch Top2 ACC.= {:.2f}%, Max Top1 ACC.= {:.1f}%, Max Epoch={}'
                  .format(epoch, loss_epoch, acc_1_epoch, acc_2_epoch, acc_1_max, epoch_max))

            if self.use_mlflow:
                mlflow.log_metrics({
                    "pretraining_loss": loss_epoch,
                    "pretraining_top1_acc": acc_1_epoch,
                    "pretraining_top2_acc": acc_2_epoch,
                    "pretraining_best_epoch": epoch_max,
                    "pretraining_running_lr": running_lr,
                }, step=epoch)

            # NOTE(review): the score passed is top-1 accuracy (higher is better);
            # confirm EarlyStopping treats its first argument that way.
            early_stopping(acc_1_epoch, self.model)
            if early_stopping.early_stop:
                print("Early stopping")
                break

        return self.model

def cal_accuracy(output, target, topk=(1,)):
    """Compute top-k classification accuracy (in percent) for each k in ``topk``.

    Args:
        output: score tensor whose predictions are ranked along dim 1
            (rows are samples, columns are classes).
        target: tensor of true class indices, one per sample;
            ``target.size(0)`` is used as the sample count.
        topk: the k values to evaluate; the largest k must not exceed the
            number of columns of ``output``.

    Returns:
        List with one shape-``(1,)`` tensor per k, in ``topk`` order, holding
        the percentage of samples whose target is among the k highest-scoring
        classes.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # After the transpose, row r holds every sample's rank-r prediction.
        _, ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)
        ranked = ranked.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
            for k in topk
        ]
