#%%
import os
import pandas as pd
import numpy as np 
# import seaborn as sn
import matplotlib.pyplot as plt
from matplotlib import gridspec

from loaders.hrv_loader import HrvDataModule
from src.engine.utils import get_ckpt, get_cv_logits_metrics, get_nested_fold_idx
from src.engine.utils import NORMAL_CLASS, CLASSES, WEIGHTS, SCORED_CLASSES
from lib.evaluation_2021.evaluate_model import *

from src.engine.models import *
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint

import mlflow as mf

import coloredlogs, logging
from pathlib import Path
import warnings
# Silence every warning category for the whole process (no category/module
# filter is given), so deprecation and numerical warnings will not be shown.
warnings.filterwarnings("ignore")

# Hook coloredlogs into the logging module, then create this module's logger.
coloredlogs.install()
logger = logging.getLogger(__name__)  
# Uncomment to get debug-level output from this module:
# logger.setLevel(logging.DEBUG)

#%%
class Solver:
    """Train and evaluate the configured classifier with nested cross-validation.

    For every fold split produced by ``get_nested_fold_idx`` a fresh model is
    trained inside a nested MLflow run, the best checkpoint is reloaded, and
    validation/test outputs are accumulated. ``evaluate`` then scores the
    outputs pooled over all folds.
    """

    DEFAULTS = {}

    def __init__(self, config, fea_lead_all_cv, label_bin_all_cv):
        """Store the configuration and the cross-validation data.

        Args:
            config: Attribute-style configuration. This class reads
                ``param_loader.leads``, ``param_model``, ``exp``,
                ``param_early_stop``, ``param_trainer`` and ``logger`` from it.
                NOTE: ``config.param_model.output_size`` is overwritten here
                with the number of scored classes.
            fea_lead_all_cv: Features handed to ``HrvDataModule``.
            label_bin_all_cv: Labels handed to ``HrvDataModule``; its
                ``mean(axis=0)`` is used as the per-class prevalence
                (presumably a samples x classes binary matrix - confirm
                against the caller).
        """
        self.config = config
        self.leads = config.param_loader.leads
        self.NORMAL_CLASS = NORMAL_CLASS
        # All target classes, without merging equivalent ones.
        self.SCORED_CLASSES = SCORED_CLASSES
        self.config.param_model.output_size = len(self.SCORED_CLASSES)
        self.fea_lead_all_cv = fea_lead_all_cv
        self.label_bin_all_cv = label_bin_all_cv

    def _get_model(self, class_weight=None, ckpt_path_abs=None):
        """Build a new model, or restore one from a checkpoint.

        Args:
            class_weight: Optional per-class weights forwarded to
                ``MLPVanilla`` when a new model is built. It is not passed to
                ``ToyModel`` and is ignored when loading a checkpoint.
            ckpt_path_abs: Checkpoint path. When truthy the model is restored
                with ``load_from_checkpoint`` instead of being constructed.

        Returns:
            The model instance selected by ``config.exp.model_type``.

        Raises:
            ValueError: If ``config.exp.model_type`` is not a supported name
                (previously ``None`` was returned silently).
        """
        model_type = self.config.exp.model_type
        if model_type not in ("toy_model", "mlp_vanilla"):
            raise ValueError(
                f"Unsupported model_type {model_type!r}; "
                "expected 'toy_model' or 'mlp_vanilla'"
            )

        if ckpt_path_abs:
            model_cls = ToyModel if model_type == "toy_model" else MLPVanilla
            return model_cls.load_from_checkpoint(ckpt_path_abs)

        if model_type == "toy_model":
            # The settings live on the config object; Solver itself has no
            # ``param_model`` / ``random_state`` attributes.
            return ToyModel(self.config.param_model, self.config.exp.random_state)
        return MLPVanilla(
            self.config.param_model,
            random_state=self.config.exp.random_state,
            class_weight=class_weight,
        )

    def _get_class_weight(self):
        """Return min-max scaled inverse-prevalence weights, one per class.

        The prevalence of each class is the column mean of
        ``self.label_bin_all_cv``. Classes with positive prevalence get the
        weight ``1 / prevalence``; all other entries are left as they are
        (so a class that never occurs keeps weight 0). The weights are then
        rescaled to ``[0, 1]``, which means the smallest weight becomes 0.

        Returns:
            numpy.ndarray: The scaled weights. When every class has the same
            weight the min-max range is zero, so a vector of ones (uniform
            weighting) is returned instead of dividing by zero.
        """
        prevalence = np.asarray(self.label_bin_all_cv.mean(axis=0), dtype=float)
        inverse = prevalence.copy()
        positive = prevalence > 0
        inverse[positive] = 1.0 / prevalence[positive]

        lowest = inverse.min()
        spread = inverse.max() - lowest
        if spread == 0:
            return np.ones_like(inverse)
        return (inverse - lowest) / spread

    def _build_trainer(self):
        """Create a Lightning trainer with early stopping and checkpointing."""
        callbacks = [
            EarlyStopping(**dict(self.config.param_early_stop)),
            ModelCheckpoint(**dict(self.config.logger.param_ckpt)),
        ]
        return pl.Trainer(
            **dict(self.config.param_trainer),
            callbacks=callbacks,
            logger=self.config.logger.log_lightning,
        )

    def _summarize_split(self, fold_errors, mode):
        """Score the outputs of one split ("val" or "test") pooled over folds.

        Returns:
            tuple: ``(challenge_metric, auroc, auprc)``.
        """
        # Only the arrays needed for scoring are concatenated; the remaining
        # per-fold lists in ``fold_errors`` are not used here.
        label = np.concatenate(fold_errors[f"{mode}_label"], axis=0)
        scalar_outputs = np.concatenate(fold_errors[f"{mode}_scalar_outputs"], axis=0)
        binary_outputs = np.concatenate(fold_errors[f"{mode}_binary_outputs"], axis=0)

        auroc, auprc, _, _ = compute_auc(label, scalar_outputs)
        challenge_metric = compute_challenge_metric(
            WEIGHTS, label, binary_outputs, CLASSES, self.NORMAL_CLASS
        )
        return challenge_metric, auroc, auprc

    def evaluate(self):
        """Run nested cross-validation and score the pooled predictions.

        For each fold a new model is trained, the best checkpoint recorded in
        the fold's MLflow run is reloaded, and validation/test outputs are
        collected with ``get_cv_logits_metrics``.

        Returns:
            dict: Keys ``cv_test_cm``, ``cv_test_auroc``, ``cv_test_auprc``,
            ``cv_val_cm``, ``cv_val_auroc`` and ``cv_val_auprc``, computed on
            the outputs concatenated over all folds.
        """
        metric_names = ("binary_outputs", "scalar_outputs", "label", "auroc", "auprc", "cm")
        fold_errors = {
            f"{mode}_{name}": [] for name in metric_names for mode in ("val", "test")
        }

        # =============================================================================
        # data module
        # =============================================================================
        class_weight = self._get_class_weight() if self.config.param_model.is_class_weight else None
        # NOTE(review): fea_lead_all_cv is passed as both the 2nd and the 4th
        # argument - confirm against HrvDataModule's signature.
        dm = HrvDataModule(self.config, self.fea_lead_all_cv, self.label_bin_all_cv, self.fea_lead_all_cv)
        dm.setup()

        # Enabling autologging once is enough; it does not depend on the fold.
        mf.pytorch.autolog()

        n_fold = self.config.exp.N_fold
        for foldIdx, (folds_train, folds_val, folds_test) in enumerate(get_nested_fold_idx(n_fold)):
            logger.info("== CROSS-SUBJECT FOLD [%s/%s] ==", foldIdx + 1, n_fold)
            dm.setup_kfold(folds_train, folds_val, folds_test)

            # A fresh model and trainer for every fold.
            model = self._get_model(class_weight=class_weight)
            trainer = self._build_trainer()

            with mf.start_run(run_name=f"cv{foldIdx}", nested=True) as run:
                trainer.fit(model, dm)
                artifact_uri, ckpt_path = get_ckpt(mf.get_run(run_id=run.info.run_id))
                trainer.test(ckpt_path="best")
            # NOTE: validation runs after the nested MLflow run has closed.
            trainer.validate(ckpt_path="best")
            # NOTE(review): the returned logits are not used; the call is kept
            # because get_logits may have side effects - confirm and remove.
            trainer.model.get_logits(dm.val_dataloader())

            # =============================================================================
            # load trained best model (from the MLflow artifacts)
            # =============================================================================
            ckpt_path_abs = Path(artifact_uri)/ckpt_path[0]
            model = self._get_model(ckpt_path_abs=ckpt_path_abs)
            trainer.model = model
            logger.info(f"bin_thre_arr:{model.bin_thre_arr}")

            # =============================================================================
            # collect validation/test outputs and metrics for this fold
            # =============================================================================
            get_cv_logits_metrics(fold_errors, trainer=trainer, loader=dm.val_dataloader(), mode="val")
            get_cv_logits_metrics(fold_errors, trainer=trainer, loader=dm.test_dataloader(), mode="test")

        # Cross-validation summary over the pooled outputs of all folds.
        val_cm, val_auroc, val_auprc = self._summarize_split(fold_errors, "val")
        test_cm, test_auroc, test_auprc = self._summarize_split(fold_errors, "test")

        return {"cv_test_cm":test_cm, "cv_test_auroc":test_auroc, "cv_test_auprc":test_auprc, "cv_val_cm":val_cm, "cv_val_auroc":val_auroc, "cv_val_auprc":val_auprc}



    
#%%
# NOTE: the triple-quoted block below is a bare module-level string literal,
# so none of the code inside it is ever executed. It archives an earlier
# "train general model" step that looked for a cached model file before
# fitting; it refers to names (model, loader_train, _get_gm_tag) that are not
# defined in this part of the file.
"""
            # =============================================================================
            # 1. Train General Model.
            # =============================================================================
            logger.info("-- Training --")
            gm_tag = self._get_gm_tag(folds_train, folds_val)
            cache_path = Path(self.config.model_directory)/f"{gm_tag}.pth"
            print(cache_path)
            
            if os.path.exists(cache_path):
                logger.info("Skip training general model, found {}".format(cache_path))
                # model.load(cache_path)
                model.fit(loader_train)
            else:
                model.fit(loader_train)
                model.save(cache_path) 
            

"""


