#!/usr/bin/env python

# Edit this script to add your team's training code.
# Some functions are *required*, but you can edit most parts of the required functions, remove non-required functions, and add your own functions.

################################################################################
#
# Imported functions and variables
#
################################################################################

import os
import json

import joblib
import mlflow
import numpy as np
import torch
import gc

# Import functions. These functions are not required. You can change or remove them.
from helper_code import *
from src.data_preparation.dataset_func import get_data_info_dict, get_supervised_dataset, get_simclr_dataset, get_dataloader, get_ssl_dataloader
from src.setup import LEAD_DICT, TARGETED, get_configs, setup
from src.trainer import Trainer
from src.utils.preprocess import preprocess_recording
from src.utils.team_helper_code import get_features, get_recording_mean_per_lead, identify_dataset, get_target_classes_mask_by_sample_id, get_id

# Define the Challenge lead sets. These variables are not required. You can change or remove them.
# Lead sets, ordered from the full 12-lead recording down to the 2-lead subset.
# NOTE(review): training_code resolves lead sets with eval(LEAD_DICT[...]);
# presumably those strings refer to the names below - confirm before renaming.
twelve_leads = (
    'I', 'II', 'III', 'aVR', 'aVL', 'aVF',
    'V1', 'V2', 'V3', 'V4', 'V5', 'V6',
)
six_leads = twelve_leads[:6]   # the first six entries: I, II, III, aVR, aVL, aVF
four_leads = ('I', 'II', 'III', 'V2')
three_leads = ('I', 'II', 'V2')
two_leads = twelve_leads[:2]   # I and II
lead_sets = (
    twelve_leads,
    six_leads,
    four_leads,
    three_leads,
    two_leads,
)

# Checkpoint file names, one per lead set (same pattern as get_model_filename).
twelve_lead_model_filename = "12_lead_model.pt"
six_lead_model_filename = "6_lead_model.pt"
four_lead_model_filename = "4_lead_model.pt"
three_lead_model_filename = "3_lead_model.pt"
two_lead_model_filename = "2_lead_model.pt"

################################################################################
#
# Training model function
#
################################################################################

def _copy_training_inputs(model_directory):
    """Copy ``configs/configs.yaml`` and ``thresholds/`` into ``model_directory``.

    Uses ``shutil`` instead of shelling out to ``cp`` so that paths containing
    spaces or shell metacharacters are handled correctly. As with the previous
    ``cp`` calls, a missing source is not an error here; later steps that
    actually need the files will fail on their own.
    """
    import shutil

    config_source = os.path.join("configs", "configs.yaml")
    if os.path.isfile(config_source):
        shutil.copy(config_source, os.path.join(model_directory, "configs.yaml"))

    thresholds_destination = os.path.join(model_directory, "thresholds")
    # Only copy when the destination does not exist yet: ``cp -r`` onto an
    # existing directory would create a nested ``thresholds/thresholds`` copy.
    if os.path.isdir("thresholds") and not os.path.isdir(thresholds_destination):
        shutil.copytree("thresholds", thresholds_destination)


def _build_model(lead_configs, num_leads, num_classes, info_for_test):
    """Instantiate the class named by ``lead_configs["MODEL"]`` from ``src.models``."""
    model_class = getattr(__import__("src.models", fromlist=[""]), lead_configs["MODEL"])
    return model_class(
        num_leads=num_leads,
        num_classes=num_classes,  # use information from supervised dataset
        configs=lead_configs,
        info_for_test=info_for_test,
    )


def _log_grouped_params(params, group_suffix):
    """Log ``params`` to MLflow, flattening dict values whose key ends in ``group_suffix``.

    A key such as ``TRAIN_PARAMS`` (with ``group_suffix="PARAMS"``) is expanded
    to one MLflow param per entry, named ``TRAIN_<entry key>``; every other key
    is logged as-is.
    """
    for key, value in params.items():
        prefix, _, suffix = key.rpartition("_")
        if suffix == group_suffix:
            for sub_key, sub_value in value.items():
                mlflow.log_param(f"{prefix}_{sub_key}", sub_value)
        else:
            mlflow.log_param(key, value)


# Train your model. This function is *required*. You should edit this function to add your code, but do *not* change the arguments of this function.
def training_code(data_directory, model_directory):
    """Train and save one model per lead set listed in ``configs["ON_TRAINING"]``.

    For each lead set this optionally runs a pretraining stage (loading a
    pretrained model, or SimCLR / BCSSL / BCSSL_12L pretraining), then
    supervised training, and finally saves the best model together with its
    threshold data to ``<model_directory>/<n>_lead_model.pt``.

    Args:
        data_directory: Directory containing the training data; passed to
            ``get_data_info_dict``.
        model_directory: Output directory. It is created if needed and receives
            a copy of ``configs/configs.yaml`` and ``thresholds/``.
    """
    # Create a folder for the model if it does not already exist
    # (including any missing parent directories).
    os.makedirs(model_directory, exist_ok=True)

    _copy_training_inputs(model_directory)
    configs = get_configs(model_directory)

    setup(configs)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    scaler = torch.cuda.amp.GradScaler() if configs["FP16_AMP"] else None
    use_mlflow = configs["USE_MLFLOW"]
    train_params = configs["TRAIN_PARAMS"]
    # ["TWELVE", "SIX", ...]
    on_training = configs["ON_TRAINING"] if configs["ON_TRAINING"] is not None else []

    # model should hold these variables for test session
    info_for_test = {
        "TRANS_BASE": configs["TRANS_BASE"],
        "TRANS_BASE_PARAMS": configs["TRANS_BASE_PARAMS"],
        "PREPROCESS_PARAMS": configs["PREPROCESS_PARAMS"]
    }

    # Top-level config entries that are logged to MLflow for every run.
    params_to_log = (
        "DEVELOP",
        "RANDOM_SEED",
        "FP16_AMP",
        "GLOBAL_THRESHOLDS",
        "TRAIN_PARAMS",
        "TRANS_BASE",
        "TRANS_BASE_PARAMS",
        "TRANS_TRAIN",
        "TRANS_TRAIN_PARAMS",
        "TRANS_CONTRAST",
        "TRANS_CONTRAST_PARAMS",
        "PREPROCESS_PARAMS",
    )

    for lead_str in on_training:
        lead_configs = configs[lead_str]
        # NOTE(review): eval() executes whatever string LEAD_DICT holds. This is
        # only safe while LEAD_DICT is a trusted, project-controlled constant.
        leads = eval(LEAD_DICT[lead_str])
        num_leads = len(leads)
        model_path = os.path.join(model_directory, get_model_filename(leads))

        info_for_test.update({
            "leads": leads,
            "model_directory": model_directory
        })

        if use_mlflow:
            ### Set MLflow
            mlflow.set_experiment(lead_str)
            experiment_id = mlflow.get_experiment_by_name(lead_str).experiment_id
            mlflow.start_run(experiment_id=experiment_id)
            run_id = mlflow.active_run().info.run_id
            # Append "<lead set>: <run id>" so runs can be matched to models later.
            with open(os.path.join(model_directory, "run_id.yaml"), "a") as run_id_file:
                run_id_file.write(f"{lead_str}: {run_id}\n")

            ### log params
            # log necessary params
            _log_grouped_params({key: configs[key] for key in params_to_log}, "PARAMS")
            # log all lead_configs i.e. configs[lead_str]
            _log_grouped_params(lead_configs, "CONFIG")

            # log artifacts
            mlflow.log_artifact(os.path.join(model_directory, "configs.yaml"))

        data_info_dict = get_data_info_dict(lead_str, data_directory, configs)

        # prepare supervised dataset first: should know how many classes are in supervised dataset in advance
        train_dataset, valid_dataset = get_supervised_dataset(
            data_info_dict,
            classes=TARGETED["target_dxs"],
            valid_ratio=lead_configs["VALID_RATIO"]
        )

        ### pretraining session
        from_pretrained = False
        if lead_configs.get("USE_PRETRAINED_MODEL", False):
            # Start from pretrained model. map_location lets a checkpoint saved
            # on a GPU be loaded on a CPU-only machine.
            from_pretrained = True
            pretrained_model = torch.load(lead_configs["PRETRAINED_MODEL_PATH"], map_location=device)
        elif lead_configs.get("N_EPOCHS_PRETRAINING", 0) > 0:
            # Do pretraining
            from_pretrained = True
            pretrained_model = _build_model(lead_configs, num_leads, train_dataset.num_classes, info_for_test)
            pretraining_model_directory = os.path.join(model_directory, f"{num_leads}_lead_model_backbone")
            os.makedirs(pretraining_model_directory, exist_ok=True)

            # NOTE: an unrecognised METHOD_PRETRAINING falls through every branch
            # below, leaving the freshly initialised model as the "pretrained" one.
            pretraining_method = lead_configs["METHOD_PRETRAINING"]
            if pretraining_method == "SimCLR":
                pretrain_dataset = get_simclr_dataset(data_info_dict)
                pretrain_loader = get_ssl_dataloader(dataset=pretrain_dataset, batch_size=lead_configs["BATCH_SIZE_PRETRAINING"], num_workers=configs["NUM_WORKERS"])
                pretrainer = getattr(__import__("src.pretraining",  fromlist=[""]), "SimCLR")(pretrained_model, lead_configs, device, pretraining_model_directory, use_mlflow)
                pretrained_model = pretrainer.train(train_loader=pretrain_loader)

            elif pretraining_method == "BCSSL":
                train_loader = get_dataloader(train_dataset, lead_configs["BATCH_SIZE_PRETRAINING"], num_workers=configs["NUM_WORKERS"])
                pretrainer = getattr(__import__("src.pretraining",  fromlist=[""]), "BCSSL")(pretrained_model, lead_configs, device, pretraining_model_directory, use_mlflow)
                pretrained_model = pretrainer.train(train_loader=train_loader)

            elif pretraining_method == "BCSSL_12L":
                ##### 12 lead pretraining
                # Drop the target-lead datasets and model before loading the
                # 12-lead data so both are not held in memory at once.
                del train_dataset, valid_dataset, pretrained_model, data_info_dict
                gc.collect()
                torch.cuda.empty_cache()

                _data_info_dict = get_data_info_dict('TWELVE', data_directory, configs)
                _train_dataset, _valid_dataset = get_supervised_dataset(
                    _data_info_dict,
                    classes=TARGETED["target_dxs"],
                    valid_ratio=lead_configs["VALID_RATIO"]
                )
                pretrained_model = _build_model(lead_configs, 12, _train_dataset.num_classes, info_for_test)

                _train_loader = get_dataloader(_train_dataset, lead_configs["BATCH_SIZE_PRETRAINING"], num_workers=configs["NUM_WORKERS"])
                pretrainer = getattr(__import__("src.pretraining",  fromlist=[""]), "BCSSL")(pretrained_model, lead_configs, device, pretraining_model_directory, use_mlflow)
                pretrained_model = pretrainer.train(train_loader=_train_loader)
                ##### change leads to supervised target
                # Replace the classifier head so its input size matches the
                # target number of leads.
                pretrained_model.num_leads = num_leads
                pretrained_model.cls = torch.nn.Linear(pretrained_model.last_dim * num_leads + pretrained_model.feature_dim, pretrained_model.num_classes)

                del _train_dataset, _valid_dataset, _train_loader, _data_info_dict
                gc.collect()
                torch.cuda.empty_cache()

                # Reload the datasets for the target lead set.
                data_info_dict = get_data_info_dict(lead_str, data_directory, configs)
                train_dataset, valid_dataset = get_supervised_dataset(
                    data_info_dict,
                    classes=TARGETED["target_dxs"],
                    valid_ratio=lead_configs["VALID_RATIO"]
                )

        ### supervised training session
        if from_pretrained:
            model = pretrained_model
        else:
            # (from scratch) random initializing
            model = _build_model(lead_configs, num_leads, train_dataset.num_classes, info_for_test)

        train_loader = get_dataloader(train_dataset, lead_configs["BATCH_SIZE"], num_workers=configs["NUM_WORKERS"])
        valid_loader = get_dataloader(valid_dataset, lead_configs["BATCH_SIZE"], num_workers=configs["NUM_WORKERS"]) if len(valid_dataset) != 0 else None

        criterion = torch.nn.BCEWithLogitsLoss()  # Multi-label classification

        # Learning rate is multiplied by `gamma` at each epoch listed in `steps`.
        steps = [7, 14, 25]
        gamma = 0.1
        optimizer = torch.optim.Adam(model.parameters(), lr=lead_configs["LEARNING_RATE"])
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steps, gamma=gamma)

        if use_mlflow:
            mlflow.log_params({
                "scheduler": scheduler.__class__.__name__,
                "steps": steps,
                "gamma": gamma
            })
        trainer = Trainer(model, model_path, device, scaler, use_mlflow, **train_params)
        if configs["GLOBAL_THRESHOLDS"]:
            # No validation loader is passed; thresholds come from the JSON file
            # copied into the model directory, so the thresholds returned by the
            # trainer are discarded.
            best_model, _ = trainer.train(
                train_loader=train_loader,
                criterion=criterion,
                optimizer=optimizer,
                n_epochs=lead_configs["N_EPOCHS"],
                scheduler=scheduler,
                valid_loader=None,
            )
            with open(os.path.join(model_directory, "thresholds", f"{num_leads}_threshold_data.json")) as json_file:
                threshold_data = json.load(json_file)

            # Sort classes (and their thresholds in lockstep) by class name.
            sorted_idx = np.argsort(threshold_data["classes"])
            threshold_data["classes"] = np.array(threshold_data["classes"])[sorted_idx].tolist()
            threshold_data["thresholds"] = np.array(threshold_data["thresholds"])[sorted_idx].tolist()

        else:
            best_model, threshold_data = trainer.train(
                train_loader=train_loader,
                criterion=criterion,
                optimizer=optimizer,
                n_epochs=lead_configs["N_EPOCHS"],
                scheduler=scheduler,
                valid_loader=valid_loader,
                valid_period=1,
            )

        # run_model reads the thresholds back from the saved model object.
        best_model.threshold_data = threshold_data
        torch.save(best_model, model_path)

        if use_mlflow:
            mlflow.log_dict(threshold_data, "threshold_data.json")
            mlflow.end_run()


################################################################################
#
# Running trained model function
#
################################################################################

# Run your trained model. This function is *required*. You should edit this function to add your code, but do *not* change the arguments of this function.
def run_model(model, header, recording):
    """Run a trained model on one recording.

    Args:
        model: Model object returned by ``load_model``; it must carry
            ``threshold_data`` and (after unwrapping ``DataParallel``)
            ``info_for_test`` and ``test_mask``.
        header: Header of the recording, forwarded to the preprocessing helpers.
        recording: Raw recording, forwarded to the preprocessing helpers.

    Returns:
        A tuple ``(classes, labels, probabilities)`` where ``classes`` comes
        from ``model.threshold_data`` and the other two are NumPy arrays.
    """
    # Thresholds are read from the object as passed in, i.e. before a possible
    # DataParallel wrapper is stripped below.
    threshold_data = model.threshold_data
    classes = threshold_data["classes"]
    thresholds = threshold_data["thresholds"]

    if isinstance(model, torch.nn.DataParallel):
        model = model.module

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    test_info = model.info_for_test
    leads = test_info["leads"]
    # Build the same base transform that was configured at training time.
    policies = __import__("src.augmentation.policies", fromlist=[""])
    transform = getattr(policies, test_info["TRANS_BASE"])(**test_info["TRANS_BASE_PARAMS"])

    # Mask over the target classes derived from the dataset this sample is
    # identified as belonging to.
    per_lead_mean = get_recording_mean_per_lead(recording)
    dataset_id = identify_dataset(header, per_lead_mean)
    mask = get_target_classes_mask_by_sample_id(classes, dataset_id)

    model.eval()
    with torch.no_grad():
        preprocess_params = test_info["PREPROCESS_PARAMS"]
        recording = torch.Tensor(preprocess_recording(header, recording, leads, preprocess_params))
        # NOTE(review): get_features receives the already-preprocessed tensor,
        # not the raw recording - confirm this is intended.
        features = torch.Tensor(get_features(header, recording, leads, preprocess_params))

        signal_batch = transform(recording).unsqueeze(0).to(device)
        feature_batch = features.unsqueeze(0).to(device)
        output, _ = model(signal_batch, feature_batch)

        probabilities = torch.sigmoid(output[0])
        threshold_tensor = torch.Tensor(thresholds).to(device)
        labels = (probabilities >= threshold_tensor).cpu()

        if model.test_mask:
            # Zero out labels for classes excluded by the mask.
            mask = np.array(mask, dtype=bool)
            labels = labels * mask

    return classes, labels.numpy(), probabilities.cpu().numpy()

################################################################################
#
# File I/O functions
#
################################################################################

# Save a trained model. This function is not required. You can change or remove it.
def save_model(model_directory, leads, classes, imputer, classifier):
    """Dump the leads, classes, imputer and classifier to disk with joblib.

    The file is written to ``model_directory`` under the name returned by
    ``get_model_filename(leads)``.

    NOTE(review): that is the same ``*_lead_model.pt`` path ``load_model`` reads
    with ``torch.load``; this joblib-based saver looks like a leftover - confirm
    before using it.
    """
    payload = dict(leads=leads, classes=classes, imputer=imputer, classifier=classifier)
    target_path = os.path.join(model_directory, get_model_filename(leads))
    joblib.dump(payload, target_path, protocol=0)

# Load a trained model. This function is *required*. You should edit this function to add your code, but do *not* change the arguments of this function.
def load_model(model_directory, leads):
    """Load the trained model for the given lead set.

    Reads the configs stored in ``model_directory``, runs ``setup`` with them,
    and loads ``<model_directory>/<n>_lead_model.pt``.

    Args:
        model_directory: Directory that ``training_code`` wrote its outputs to.
        leads: Sequence of lead names; only its length selects the file.

    Returns:
        The object stored by ``torch.save`` in ``training_code``.
    """
    configs = get_configs(model_directory)
    setup(configs)
    filename = os.path.join(model_directory, get_model_filename(leads))
    # Map storages onto whatever device is available so that a checkpoint
    # saved on a GPU can still be loaded on a CPU-only machine.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    return torch.load(filename, map_location=device)

# Define the filename(s) for the trained models. This function is not required. You can change or remove it.
def get_model_filename(leads):
    """Return the checkpoint file name for a lead set, keyed by its length."""
    lead_count = len(leads)
    return f"{lead_count}_lead_model.pt"