import os, sys
import torch
import torch.nn.functional as F
import json
import mlflow as mf
from mlflow.tracking import MlflowClient
import numpy as np
from pathlib import Path
from lib.evaluation_2021.evaluate_model import *
from mlflow.utils.autologging_utils.safety import try_mlflow_log
from shutil import rmtree
import coloredlogs, logging
coloredlogs.install()
logger = logging.getLogger(__name__)  

#%%
# =============================================================================
# data
# =============================================================================
def to_numpy(x):
    """Convert a torch tensor to a NumPy array.

    The tensor is detached from the autograd graph first; a CUDA tensor is
    additionally copied to host memory before the conversion.
    """
    detached = x.detach()
    if detached.is_cuda:
        detached = detached.cpu()
    return detached.data.numpy()

# =============================================================================
# data
# =============================================================================
# Names of the training data sources, in their original order.
ALL_TRAIN_SITES = [
    "WFDB_CPSC2018",
    "WFDB_CPSC2018_2",
    "WFDB_StPetersburg",
    "WFDB_PTB",
    "WFDB_PTBXL",
    "WFDB_Ga",
    "WFDB_ChapmanShaoxing",
    "WFDB_Ningbo",
]


# =============================================================================
# feature
# =============================================================================
def map_sex(sex, impute_nan=1):
    """Encode a sex label as a one-element float32 array.

    Args:
        sex: Label to encode. 'Female'/'female'/'F'/'f' map to 0 and
            'Male'/'male'/'M'/'m' map to 1.
        impute_nan: Value used for any other label. When it is falsy
            (including 0) the result is NaN instead.

    Returns:
        A float32 NumPy array of shape (1,).
    """
    female_labels = ('Female', 'female', 'F', 'f')
    male_labels = ('Male', 'male', 'M', 'm')

    if sex in female_labels:
        code = 0
    elif sex in male_labels:
        code = 1
    else:
        # Unrecognised label: impute only when a truthy fill value was given.
        code = impute_nan if impute_nan else np.nan
    return np.asarray([code]).astype("float32")

# =============================================================================
# label
# =============================================================================
# Scored target classes: the first element returned by load_table for the
# challenge weights file (original note: 24 classes, not merged).
SCORED_CLASSES = load_table('lib/evaluation_2021/weights.csv')[0]

# For every scored class whose name contains "|", record the pair
# [list of "|"-separated parts, full class name].
EQ_CLASSES = []
for c in SCORED_CLASSES:
    tmp = c.split("|")
    if len(tmp) <= 1:
        continue  # a single code, nothing to pair up
    EQ_CLASSES.append([tmp, c])

def replace_equivalent_classes(classes, equivalent_classes):
    """Replace class names by the representative of their equivalence group.

    Args:
        classes: Mutable sequence of class names; it is modified in place.
        equivalent_classes: Iterable of entries whose element 0 is a
            container of member names and whose element 1 is the
            replacement name written for any of those members.

    Returns:
        The same ``classes`` object. When a name belongs to several groups,
        the replacement of the last matching group is kept.
    """
    for position, label in enumerate(classes):
        replacements = [group[1] for group in equivalent_classes if label in group[0]]
        if replacements:
            classes[position] = replacements[-1]
    return classes


# =============================================================================
# loader
# =============================================================================
def get_nested_fold_idx(kfold):
    """Yield (train, validation, test) fold indices for a rotating split.

    Each fold is used once as the test fold; the following fold (wrapping
    around) is the validation fold and all remaining folds form the
    training set.

    Args:
        kfold: Total number of folds.

    Yields:
        Tuples ``(train_folds, [val_fold], [test_fold]) `` of lists of ints.
    """
    for test_fold in range(kfold):
        val_fold = (test_fold + 1) % kfold
        held_out = (test_fold, val_fold)
        train_folds = [candidate for candidate in range(kfold) if candidate not in held_out]
        yield train_folds, [val_fold], [test_fold]

# =============================================================================
# metric
# =============================================================================
def predToScalarBinary(pred, label, classes):
    """Turn raw logits into sigmoid probabilities and 0/1 decisions.

    Args:
        pred: Logits as a tensor, NumPy array or (nested) sequence of numbers.
        label: Unused; kept for interface compatibility.
        classes: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(scalar_outputs, binary_outputs)`` of NumPy arrays with the
        same shape and dtype. ``binary_outputs`` is 1 where the probability
        is at least 0.5 and 0 elsewhere (NaN probabilities map to 0).
    """
    # as_tensor + detach avoids the copy-construct warning that
    # torch.tensor() raises for tensor input; cpu() makes .numpy() valid
    # for tensors living on an accelerator.
    logits = torch.as_tensor(pred).detach().cpu()
    scalar_outputs = torch.sigmoid(logits).numpy()
    binary_outputs = (scalar_outputs >= 0.5).astype(scalar_outputs.dtype)
    return scalar_outputs, binary_outputs

def logit_To_Scalar_Binary(logit: torch.Tensor, bin_thre_arr=None):
    """Turn raw logits into sigmoid probabilities and thresholded decisions.

    Args:
        logit: Logits as a tensor (NumPy arrays and nested sequences are
            accepted as well).
        bin_thre_arr: Optional threshold(s) compared against the
            probabilities; must broadcast against them. When ``None`` a
            fixed threshold of 0.5 is used.

    Returns:
        Tuple ``(scalar_output, binary_output)`` of CPU tensors.
        ``binary_output`` is float32, has the shape of ``scalar_output`` and
        holds 1 where the probability reaches the threshold, 0 elsewhere.
    """
    # as_tensor + detach avoids the copy-construct warning that
    # torch.tensor() raises when it is handed an existing tensor.
    scalar_output = torch.sigmoid(torch.as_tensor(logit).detach()).cpu()
    threshold = 0.5 if bin_thre_arr is None else bin_thre_arr
    # Start from zeros so only the positive decisions need to be written.
    binary_output = torch.zeros_like(scalar_output, dtype=torch.float32)
    binary_output[scalar_output >= threshold] = 1
    return scalar_output, binary_output

def get_cv_logits_metrics(fold_errors, model, outputs, mode="val"):
    """Append one fold's predictions and metrics to ``fold_errors``.

    Args:
        fold_errors: Mapping from ``"{mode}_<name>"`` keys to lists; each
            list receives one new entry.
        model: Object providing ``bin_thre_arr`` (moved to CPU here) and
            ``_cal_metric(logit, label)``, whose result is indexed with
            ``"auroc"``, ``"auprc"`` and ``"cm"``.
        outputs: Mapping with ``"logit"`` and ``"label"`` entries.
        mode: Key prefix, ``"val"`` by default.
    """
    thresholds = model.bin_thre_arr.cpu()
    probabilities, decisions = logit_To_Scalar_Binary(outputs["logit"], thresholds)
    metrics = model._cal_metric(torch.tensor(outputs["logit"]), torch.tensor(outputs["label"]))

    fold_errors[f"{mode}_binary_outputs"].append(decisions)
    fold_errors[f"{mode}_scalar_outputs"].append(probabilities)
    fold_errors[f"{mode}_label"].append(outputs["label"])
    # Each metric is stored wrapped in a one-element list.
    for metric_name in ("auroc", "auprc", "cm"):
        fold_errors[f"{mode}_{metric_name}"].append([metrics[metric_name]])


# =============================================================================
# evaluation
# =============================================================================
# Code(s) of the normal class (presumably a SNOMED CT code — confirm).
NORMAL_CLASS = {'426783006'}

# Absolute path of the challenge weights table.
weights_file = Path('lib/evaluation_2021/weights.csv').absolute()

# Class names and weights as returned by load_weights.
CLASSES, WEIGHTS = load_weights(weights_file)

# Scored classes: first element returned by load_table for the same file.
SCORED_CLASSES = load_table(Path('lib/evaluation_2021/weights.csv'))[0]


#%%
def set_device(gpu_id):
    """Select the compute device through ``CUDA_VISIBLE_DEVICES``.

    Args:
        gpu_id: String written to ``CUDA_VISIBLE_DEVICES``. A non-empty
            value makes CUDA device 0 the current device; an empty string
            selects CPU execution and caps torch's thread count at 8.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id

    if gpu_id == "":
        # CPU only: never use more than 8 threads.
        n_threads = min(torch.get_num_threads(), 8)
        torch.set_num_threads(n_threads)
        print("Using {} CPU Core".format(n_threads))
        return

    torch.cuda.set_device(0)
        
        
#%%
import multiprocessing.pool as mpp
# istarmap, hgy
def istarmap(self, func, iterable, chunksize=1):
    """Starmap version of ``Pool.imap``.

    Each element of ``iterable`` is dispatched through ``mpp.starmapstar``
    and the results are exposed through an ``mpp.IMapIterator``.

    Args:
        func: Callable handed to the pool workers.
        iterable: Iterable of work items (presumably argument tuples, as
            with ``Pool.starmap`` — confirm).
        chunksize: Number of items per task batch; must be at least 1.

    Returns:
        A generator that flattens the chunks produced by the iterator and
        yields individual results.

    Raises:
        ValueError: If the pool is not in the running state or if
            ``chunksize`` is smaller than 1.

    NOTE(review): relies on private ``multiprocessing.pool`` internals
    (``_state``, ``_get_tasks``, ``_taskqueue``, ``_guarded_task_generation``),
    which may differ between Python versions.
    """
    if self._state != mpp.RUN:
        raise ValueError("Pool not running")

    if chunksize < 1:
        raise ValueError(
            "Chunksize must be 1+, not {0:n}".format(
                chunksize))

    # Split the work into batches of ``chunksize`` items.
    task_batches = mpp.Pool._get_tasks(func, iterable, chunksize)
    result = mpp.IMapIterator(self)
    # Queue the task generator together with the callback that receives the
    # total length.
    self._taskqueue.put(
        (
            self._guarded_task_generation(result._job,
                                          mpp.starmapstar,
                                          task_batches),
            result._set_length
        ))
    # Results arrive per chunk; flatten them into single items.
    return (item for chunk in result for item in chunk)
# Monkeypatch: expose istarmap as a method on every multiprocessing Pool.
mpp.Pool.istarmap = istarmap

def init_mlflow(config):
    """Point MLflow at the configured tracking directory and experiment.

    Args:
        config: Configuration providing ``path.mlflow_dir`` and
            ``exp.exp_name``.
    """
    tracking_dir = Path(config.path.mlflow_dir).absolute()
    mf.set_tracking_uri(str(tracking_dir))
    mf.set_experiment(config.exp.exp_name)

def log_params_mlflow(config):
    """Log every parameter section of ``config`` to the active MLflow run.

    The ``param_preprocess`` section goes through ``try_mlflow_log``; all
    other sections are passed to ``mf.log_params`` directly.
    """
    for section in ("exp", "param_feature"):
        mf.log_params(config.get(section))

    try_mlflow_log(mf.log_params, config.get("param_preprocess"))

    remaining_sections = (
        "param_loader",
        "param_trainer",
        "param_early_stop",
        "param_aug",
        "param_model",
    )
    for section in remaining_sections:
        mf.log_params(config.get(section))

def log_hydra_mlflow(name):
    """Log the Hydra files and the run log as MLflow artifacts, then clean up.

    WARNING: after logging, the current working directory is deleted
    recursively (presumably the Hydra run directory — confirm before calling
    this from any other location).

    Args:
        name: Stem of the log file; ``{name}.log`` in the current working
            directory is logged.
    """
    artifact_names = (
        '.hydra/config.yaml',
        '.hydra/hydra.yaml',
        '.hydra/overrides.yaml',
        f'{name}.log',
    )
    for relative_path in artifact_names:
        mf.log_artifact(os.path.join(os.getcwd(), relative_path))

    # Remove the whole working directory once everything has been logged.
    rmtree(os.getcwd())

def init_dirs(config):
    """Create the output directories inside the active run's artifact folder.

    ``config.root_dir`` is set to the artifact directory of the active MLflow
    run; ``model_dir``, ``sample_dir`` and ``log_dir`` are prefixed with it
    and created on disk.

    Args:
        config: Configuration supporting attribute and item access.

    Returns:
        The same ``config`` object, updated in place.
    """
    # NOTE(review): PATHS is not defined in the visible part of this module —
    # confirm where it comes from.
    mlflow_dir = PATHS["PROJECT_PATH"] + "mlruns"
    run_info = mf.active_run().info
    config.root_dir = f"{mlflow_dir}/{run_info.experiment_id}/{run_info.run_id}/artifacts/"

    for key in ("model_dir", "sample_dir", "log_dir"):
        config[key] = config.root_dir + config[key]
        Path(config[key]).mkdir(parents=True, exist_ok=True)
    return config

def print_auto_logged_info(r):
    """Print a summary of what was logged for an MLflow run.

    Args:
        r: MLflow run object; ``r.info.run_id``, ``r.data.tags``,
            ``r.data.params`` and ``r.data.metrics`` are read.
    """
    client = MlflowClient()
    # Hide MLflow's own bookkeeping tags.
    tags = {k: v for k, v in r.data.tags.items() if not k.startswith("mlflow.")}
    models = [f.path for f in client.list_artifacts(r.info.run_id, "model")]
    ckpts = [f for f in client.list_artifacts(r.info.run_id, "restored_model_checkpoint")]
    print("run_id: {}".format(r.info.run_id))
    # Fixed: this line used to reference an undefined name ``artifacts``.
    print("model: {}".format(models))
    print("ckpt: {}".format(ckpts))
    print("params: {}".format(r.data.params))
    print("metrics: {}".format(r.data.metrics))
    print("tags: {}".format(tags))
    
def get_ckpt(r):
    """Return the artifact URI of a run and its restored checkpoint paths.

    Args:
        r: MLflow run object; ``r.info.run_id`` and ``r.info.artifact_uri``
            are read.

    Returns:
        Tuple ``(artifact_uri, ckpt_paths)`` where ``ckpt_paths`` lists the
        paths found under the ``restored_model_checkpoint`` artifact folder.
    """
    listing = MlflowClient().list_artifacts(r.info.run_id, "restored_model_checkpoint")
    ckpt_paths = [entry.path for entry in listing]
    return r.info.artifact_uri, ckpt_paths
    
class TuneBinaryThreshold:
    """Context manager that switches a model's threshold-update flag on.

    On entry ``model.update_bin_thre_arr`` is set to ``True`` and the model
    is returned; on exit the flag is set to ``False`` regardless of its
    earlier value or of any exception raised inside the block.
    """

    def __init__(self, model):
        self.model = model

    def __enter__(self):
        wrapped = self.model
        wrapped.update_bin_thre_arr = True
        logger.info("Tuning Bianry Thresholding !")
        return wrapped

    def __exit__(self, exc_type, exc_val, exc_tb):
        # No return value, so exceptions from the block propagate.
        self.model.update_bin_thre_arr = False

