import importlib
import itertools

import matplotlib.pyplot as plt
import mlflow
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import f1_score
from tqdm import tqdm
from src.utils.team_helper_code import get_target_classes_mask_by_sample_id, plot_confusion_matrix


class Trainer:
    """Train and evaluate a multi-label classifier, tuning per-class thresholds.

    The model is called as ``model(recording, features)`` and must return a
    2-tuple whose first element holds the logits (they are passed through
    ``torch.sigmoid`` during evaluation). Every batch yielded by a loader is a
    dict with the keys ``"id"``, ``"recording"``, ``"labels"`` and
    ``"features"``; the loader's ``dataset`` must expose a ``classes``
    attribute.
    """

    # Decision thresholds tried for every class during evaluation (ascending).
    _THRESHOLD_CANDIDATES = (
        0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7,
    )

    def __init__(
        self,
        model,
        model_path,
        device,
        scaler=None,
        use_mlflow=False,
        train_mask=False,
        valid_mask=False,
        test_mask=True
    ):
        """Store the configuration and move the model to ``device``.

        Args:
            model: network to train; wrapped in ``nn.DataParallel`` when more
                than one CUDA device is visible.
            model_path: destination handed to ``torch.save`` for checkpoints.
            device: device the model and every batch are moved to.
            scaler: optional gradient scaler; when truthy the training forward
                pass runs under ``torch.cuda.amp.autocast`` and the optimizer
                is stepped through the scaler.
            use_mlflow: log per-epoch metrics to MLflow when True.
            train_mask: when True, ``criterion.weight`` is set to a per-sample
                class mask for every training batch.
            valid_mask: same as ``train_mask`` but for ``test``; masked-out
                predictions are additionally replaced by ``-1``.
            test_mask: stored on the trainer and on the model as
                ``model.test_mask``; not otherwise read by this class.
        """
        self.model = model
        self.model_path = model_path
        self.device = device
        self.scaler = scaler
        self.use_mlflow = use_mlflow
        self.thresholds = None
        self.train_mask = train_mask
        self.valid_mask = valid_mask
        self.test_mask = test_mask
        self.model = self.model.to(device)
        # Set before the optional DataParallel wrap, i.e. on the bare module.
        self.model.test_mask = test_mask
        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)

    @staticmethod
    def _current_lr(optimizer, scheduler):
        """Return the learning rate of the first parameter group.

        Uses the scheduler's last computed value when a scheduler is given and
        falls back to the optimizer's own setting otherwise.
        """
        if scheduler is not None:
            return scheduler.get_last_lr()[0]
        return optimizer.param_groups[0]["lr"]

    def _build_class_mask(self, classes, sample_ids):
        """Stack the per-sample class masks of a batch into one float tensor.

        Each row comes from ``get_target_classes_mask_by_sample_id``
        (presumably a 0/1 indicator of the classes scored for that sample --
        confirm against the helper).
        """
        mask_rows = [
            get_target_classes_mask_by_sample_id(classes, sample_id).tolist()
            for sample_id in sample_ids
        ]
        return torch.Tensor(mask_rows).to(self.device)

    def _apply_loss_mask(self, criterion, enabled, classes, sample_ids):
        """Set ``criterion.weight`` for the current batch.

        Returns:
            The mask tensor assigned as weight, or ``None`` when masking is
            disabled (in which case the weight is reset to ``None``).
        """
        if not enabled:
            criterion.weight = None
            return None
        mask = self._build_class_mask(classes, sample_ids)
        criterion.weight = mask
        return mask

    def train(
        self,
        train_loader,
        criterion,
        optimizer,
        n_epochs,
        scheduler=None,
        valid_loader=None,
        valid_period=2,
    ):
        """Run the training loop and checkpoint the model.

        Args:
            train_loader: loader yielding training batches.
            criterion: loss callable; its ``weight`` attribute is overwritten
                on every batch (see ``train_mask`` / ``valid_mask``).
            optimizer: optimizer stepped once per batch.
            n_epochs: number of epochs to run (epochs are numbered from 1).
            scheduler: optional LR scheduler, stepped once per epoch.
            valid_loader: optional validation loader; when omitted the model
                is saved after every epoch.
            valid_period: validate every ``valid_period`` epochs.

        Returns:
            ``(best_model, best_threshold_data)`` where ``best_threshold_data``
            is ``{"classes": ..., "thresholds": [...]}``.
        """
        # NOTE: best_model is an alias of self.model, not a snapshot, so the
        # returned object carries the weights of the last epoch; the best
        # epoch's weights are the ones written to self.model_path.
        best_model = self.model
        best_challenge_metric = float('-inf')
        best_epoch = 0

        classes = train_loader.dataset.classes

        # Default decision threshold of 0.5 for every class.
        best_threshold_data = {
            "classes": classes,
            "thresholds": [0.5] * len(classes),
        }

        # NOTE(review): this attribute is never refreshed with the tuned
        # thresholds, so saved checkpoints carry the defaults -- confirm that
        # callers rely on the returned threshold data instead.
        self.model.threshold_data = best_threshold_data

        for epoch in range(1, n_epochs + 1):
            self.model.train()
            total = 0
            total_loss = 0.0
            running_lr = self._current_lr(optimizer, scheduler)
            pbar = tqdm(enumerate(train_loader), total=len(train_loader))
            for _, data in pbar:
                sample_id = data["id"]
                recording = data["recording"].to(self.device)
                labels = data["labels"].to(self.device)
                features = data["features"].to(self.device)

                optimizer.zero_grad()
                if self.scaler:
                    with torch.cuda.amp.autocast():
                        outputs, _ = self.model(recording, features)
                else:
                    outputs, _ = self.model(recording, features)

                self._apply_loss_mask(criterion, self.train_mask, classes, sample_id)
                loss = criterion(outputs, labels)

                if self.scaler:
                    self.scaler.scale(loss).backward()
                    self.scaler.step(optimizer)
                    self.scaler.update()
                else:
                    loss.backward()
                    optimizer.step()

                total += labels.size(0)
                # Weight the batch loss by batch size to get a per-sample mean.
                total_loss += loss.item() * len(recording)

                running_lr = self._current_lr(optimizer, scheduler)

                pbar.set_description(
                    f"Train: [{epoch:03d}] "
                    f"Loss: {total_loss / total:.4f}"
                )
                pbar.set_postfix_str(
                    f"lr: [{running_lr:.6f}]"
                )
            total_loss /= total
            pbar.close()
            if scheduler is not None:
                scheduler.step()

            ran_validation = valid_loader is not None and epoch % valid_period == 0
            if ran_validation:
                valid_loss, _, valid_challenge_metric, _, threshold_data = self.test(
                    valid_loader, criterion
                )

                # Checkpoint only when the challenge metric improves.
                if valid_challenge_metric > best_challenge_metric:
                    best_model = self.model
                    best_epoch = epoch
                    best_challenge_metric = valid_challenge_metric
                    best_threshold_data = threshold_data
                    torch.save(best_model, self.model_path)

            elif valid_loader is None:
                # Without validation data every epoch overwrites the checkpoint.
                best_model = self.model
                torch.save(best_model, self.model_path)

            if self.use_mlflow:
                mlflow.log_metrics({
                    "train_loss": total_loss,
                    "running_lr": running_lr,
                }, step=epoch)
                if ran_validation:
                    mlflow.log_metrics({
                        "valid_loss": valid_loss,
                        "valid_challenge_metric": valid_challenge_metric,
                        "best_epoch": best_epoch,
                    }, step=epoch)

        return best_model, best_threshold_data

    @torch.no_grad()
    def test(
        self,
        test_loader,
        criterion,
    ):
        """Evaluate the model and tune per-class thresholds on ``test_loader``.

        Args:
            test_loader: loader yielding evaluation batches.
            criterion: loss callable; its ``weight`` attribute is overwritten
                on every batch.

        Returns:
            ``(mean_loss, predictions, challenge_metric,
            modified_confusion_matrix, threshold_data)`` where ``predictions``
            is a list of ``{"labels": tensor, "prediction": tensor}`` dicts
            (one per sample, both on CPU).
        """
        self.model.eval()
        total = 0
        total_loss = 0.0

        test_predictions = []
        classes = test_loader.dataset.classes
        pbar = tqdm(enumerate(test_loader), total=len(test_loader))
        for _, data in pbar:
            sample_id = data["id"]
            recording = data["recording"].to(self.device)
            labels = data["labels"].to(self.device)
            features = data["features"].to(self.device)
            outputs, _ = self.model(recording, features)

            mask = self._apply_loss_mask(criterion, self.valid_mask, classes, sample_id)
            loss = criterion(outputs, labels)

            batch_size = len(recording)
            total += batch_size
            total_loss += loss.item() * batch_size

            probabilities = torch.sigmoid(outputs).cpu()
            labels_cpu = labels.cpu()
            if mask is not None:
                # Masked-out classes get the sentinel -1 so that threshold
                # selection can skip them (it ignores negative predictions).
                probabilities = torch.where(
                    mask.cpu().bool(),
                    probabilities,
                    torch.full_like(probabilities, -1.0),
                )
            for row in range(batch_size):
                test_predictions.append({
                    "labels": labels_cpu[row],
                    "prediction": probabilities[row],
                })

            pbar.set_description(
                f" Test: {'':5} Loss: {(total_loss / total):.4f}"
            )
        total_loss /= total
        pbar.close()

        threshold_data = self._select_threshold(
            threshold_cadidates=sorted(self._THRESHOLD_CANDIDATES),
            classes=classes,
            top_validation_predictions=test_predictions,
        )
        label_matrix = np.array([item["labels"].tolist() for item in test_predictions])
        prediction = np.array([item["prediction"].tolist() for item in test_predictions])
        label_prediction = prediction > np.array(threshold_data["thresholds"])

        # A package path containing "-" cannot be used in a regular import
        # statement, so the evaluation module is loaded via importlib instead.
        evaluate_model = importlib.import_module("evaluation-2021.custom_evaluate_model")
        challenge_metric, modified_confusion_matrix = evaluate_model.custom_evaluate_model(
            label_matrix, label_prediction, prediction, classes
        )
        return total_loss, test_predictions, challenge_metric, modified_confusion_matrix, threshold_data

    def pre_train(
        self,
        train_loader,
        criterion,
        optimizer,
        n_epochs,
        scheduler=None,
    ):
        """Placeholder for a pre-training stage.

        Currently only switches the model to training mode; the arguments are
        accepted but unused.
        """
        self.model.train()

    def _select_threshold(self, threshold_cadidates, classes, top_validation_predictions):
        """Pick, per class, the candidate threshold with the highest F1 score.

        Args:
            threshold_cadidates: non-empty sequence of candidate thresholds;
                ties keep the earliest candidate.
            classes: class identifiers, one per output column.
            top_validation_predictions: list of
                ``{"labels": ..., "prediction": ...}`` dicts.

        Returns:
            ``{"classes": classes, "thresholds": [float, ...]}``. A class whose
            score is NaN for every candidate keeps the first candidate.
        """
        thresholds = np.zeros(len(classes))
        for class_idx in range(len(classes)):
            top_resulting_threshold = threshold_cadidates[0]
            top_score = float("-inf")
            for candidate_value in threshold_cadidates:
                # Only the F1 score is used for the selection.
                current_score, _, _ = self._evaluate_single_threshold(
                    top_validation_predictions, class_idx, candidate_value
                )
                # A NaN score compares False and therefore never wins.
                if current_score > top_score:
                    top_resulting_threshold = candidate_value
                    top_score = current_score
            thresholds[class_idx] = top_resulting_threshold
        # Built after the loop so an empty class list still yields a result.
        return {"classes": classes, "thresholds": thresholds.tolist()}

    def _evaluate_single_threshold(self, preds, class_idx, threshold):
        """Score one class at one candidate threshold.

        Samples whose prediction for the class is negative (the masked-out
        sentinel) are ignored.

        Args:
            preds (list of dicts): predictions ({'labels': ..., 'prediction': ...})
            class_idx (int): index of the class being processed
            threshold (float): candidate threshold value
        Return:
            A 3-tuple. ``(nan, nan, nan)`` when no sample remains or when both
            predictions and labels are all negative; ``(1, 1, 1)`` when both
            are all positive; otherwise ``(f1_score, None, None)`` -- precision
            and recall are not computed in the general case.
        """
        kept = [x for x in preds if not x["prediction"][class_idx] < 0]
        binary_predictions = np.array(
            [x["prediction"][class_idx] for x in kept]
        ) > threshold  # array of True/False
        true_labels = np.array(
            [x["labels"][class_idx] for x in kept]
        ).astype(bool)  # array of True/False

        undefined = (np.float64("nan"), np.float64("nan"), np.float64("nan"))
        if binary_predictions.size == 0:
            # Every sample was masked out: the score is undefined, and empty
            # arrays are not handed to f1_score.
            return undefined

        if not binary_predictions.any() and not true_labels.any():
            # both all False
            return undefined
        if binary_predictions.all() and true_labels.all():
            # both all True
            return np.float64(1), np.float64(1), np.float64(1)

        result_f1_score = f1_score(true_labels, binary_predictions)
        return result_f1_score, None, None

    def _get_neg_inf_mask(self, outputs, mask):
        """Build an integer multiplier that pushes masked-out logits far negative.

        Args:
            outputs: (batch_size, out_dim)
            mask: boolean tensor broadcastable to ``outputs``; True keeps the
                logit (multiplier 1), False yields ``sign(output) * -1e10``
                truncated to an integer.
        """
        neg_inf = (torch.sign(outputs) * -1e10).long()
        neg_inf_mask = torch.where(mask, torch.ones_like(neg_inf), neg_inf)
        return neg_inf_mask