#!/usr/bin/env python

# Edit this script to add your team's training code.
# Some functions are *required*, but you can edit most parts of the required functions, remove non-required functions, and add your own functions.

################################################################################
#
# Imported functions and variables
#
################################################################################

#%%
# Import functions. These functions are not required. You can change or remove them.
import os

# from nevergrad.parametrization.core import P 
from helper_code import *
import numpy as np, os, sys, joblib, json
import pandas as pd
# PATHS = {"PROJECT_PATH": f"{os.getcwd()}/",
#          "LIB_PATH": f"{os.getcwd()}/lib/",
#          "ENGINE_PATH": f"{os.getcwd()}/src/engine/",
#          "DB_PATH": f"{os.getcwd()}/datasets/",
#          "CONFIG_PATH": f"{os.getcwd()}/configs/"}
# for k in PATHS:
#     sys.path.append(PATHS[k])

from src.engine.models.mlp import MLPVanilla
from src.engine.models.base import to_tensor
from src.engine.utils import set_device, init_mlflow, log_params_mlflow, to_numpy, SCORED_CLASSES, logit_To_Scalar_Binary

from src.engine.solver_raw import Solver
from src.engine.loaders.raw_testloader import TestDataset
from time import time, ctime

import coloredlogs, logging
coloredlogs.install()
logger = logging.getLogger(__name__)  

from omegaconf import OmegaConf
from shutil import rmtree
from pathlib import Path
import mlflow as mf

from src.engine.models import *
from torch.utils.data import DataLoader
from glob import glob
import torch

# Challenge lead subsets, each listed in the canonical 12-lead order. These
# variables are not required. You can change or remove them.
twelve_leads = (
    'I', 'II', 'III', 'aVR', 'aVL', 'aVF',
    'V1', 'V2', 'V3', 'V4', 'V5', 'V6',
)
six_leads = twelve_leads[:6]    # first six leads of the 12-lead order
four_leads = ('I', 'II', 'III', 'V2')
three_leads = ('I', 'II', 'V2')
two_leads = twelve_leads[:2]    # first two leads of the 12-lead order
# Iterated by training_code to train one model per subset (largest first).
lead_sets = (twelve_leads, six_leads, four_leads, three_leads, two_leads)

################################################################################
#
# Training model function
#
################################################################################

# Train your model. This function is *required*. You should edit this function to add your code, but do *not* change the arguments of this function.
def training_code(data_directory, model_directory):
    """Train and cross-validate one model per Challenge lead set.

    For every lead subset in ``lead_sets`` this loads
    ``./config/entry3/rsn_raw_<N>leads.yaml``, points it at the given data and
    model directories, restricts it to that lead subset, and runs
    ``Solver.evaluate()`` inside an MLflow run whose returned metrics are
    logged with ``mf.log_metrics``.

    Args:
        data_directory: Official Challenge data directory; stored as
            ``config.path.official_data_directory``.
        model_directory: Output directory; stored as
            ``config.path.model_directory``.
    """
    from datetime import timedelta  # local: only used for the elapsed-time log

    # =============================================================================
    # initialize
    # =============================================================================
    for leads in lead_sets:
        config_path = f'./config/entry3/rsn_raw_{len(leads)}leads.yaml'
        time_start = time()
        config = OmegaConf.load(config_path)

        config.path.official_data_directory = data_directory
        config.path.model_directory = model_directory
        logger.info(f"Data Path: {Path(config.path.data_directory).absolute()} | MLflow Path: {Path(config.path.mlflow_dir).absolute()}")
        logger.info(f"Training Site:{config.exp.train_sites} | Eval Sites:{config.exp.eval_sites}")

        logger.info("\n"+'='*120)
        logger.info(f'Training {len(leads)}-lead ECG model...')
        # Restrict the loader to this lead subset (run_model must use the same order).
        config.param_loader.leads = leads
        config.param_loader.num_leads = len(config.param_loader.leads)

        solver = Solver(config)

        init_mlflow(config)
        with mf.start_run(run_name=f"{config.exp.N_fold}fold_CV_Results"):
            log_params_mlflow(config)
            cv_metrics = solver.evaluate()
            print(cv_metrics)
            mf.log_metrics(cv_metrics)

        # ctime() formats an epoch timestamp, so passing it a duration printed a
        # 1970 date; report the elapsed wall-clock time as H:MM:SS instead.
        elapsed = timedelta(seconds=round(time() - time_start))
        logger.warning(f"{len(leads)}-Lead Time Used: {elapsed}")

        # =============================================================================
        # output
        # =============================================================================
        # Clean up ./lightning_logs/ left in the working directory (presumably
        # by PyTorch Lightning) before training the next lead set.
        pytorch_lightning_ckpt_dir = Path("./lightning_logs/")
        if pytorch_lightning_ckpt_dir.exists():
            rmtree(pytorch_lightning_ckpt_dir)


################################################################################
#
# Running trained model function
#
################################################################################

# Run your trained model. This function is *required*. You should edit this function to add your code, but do *not* change the arguments of this function.
@torch.no_grad()
def run_model(model, header, recording):
    """Predict the scored classes of one recording with an ensemble of models.

    Args:
        model: List of trained models as returned by ``load_model``. Each one
            is called on the batch produced by ``TestDataset`` and must expose
            ``bin_thre_arr`` (per-class thresholds).
        header: Header text of the recording.
        recording: Signal array of the recording.

    Returns:
        ``(classes, labels, probabilities)`` where ``classes`` is
        ``np.array(SCORED_CLASSES)``, ``probabilities`` is the element-wise mean
        of the per-model scalar outputs, and ``labels`` is 1 for a class when
        ANY model predicted it (0 otherwise).

    Raises:
        ValueError: If ``model`` is empty (e.g. ``load_model`` found no
            checkpoints); previously this silently returned NaN outputs.
    """
    if not model:
        raise ValueError("run_model received no models; check that load_model found checkpoints")

    # The recording may carry a reduced lead set. Use the config and canonical
    # lead order matching it, exactly as training_code does per lead set,
    # instead of always the 12-lead config.
    header_leads = tuple(get_leads(header))
    leads = next((s for s in lead_sets if set(s) == set(header_leads)), header_leads)
    config = OmegaConf.load(f'./config/entry3/rsn_raw_{len(leads)}leads.yaml')
    config.param_loader.leads = leads
    config.param_loader.num_leads = len(leads)
    logger.debug(f"config {config}")

    # Dataloader: all items of the recording are loaded as one batch.
    test_set = TestDataset(config, header, recording)
    test_loader = DataLoader(dataset=test_set, batch_size=len(test_set), shuffle=False)
    logger.info(f"Test:{len(test_set)}")
    test_data = next(iter(test_loader))[0]

    # Predict with every (top-N) model.
    top_n_prob = []
    top_n_lab = []
    for n_model in model:
        n_model.eval()
        # Average the model outputs over dim 0 (presumably the patch axis --
        # TODO confirm against TestDataset).
        logits = n_model(test_data).mean(dim=0)
        # probabilities == scalar outputs, labels == binary outputs
        probabilities, labels = logit_To_Scalar_Binary(logits, n_model.bin_thre_arr.cpu())
        top_n_prob.append(to_numpy(probabilities))
        top_n_lab.append(to_numpy(labels))

    probabilities = np.mean(np.array(top_n_prob), axis=0)
    labels = np.any(np.array(top_n_lab), axis=0).astype(int)
    classes = np.array(SCORED_CLASSES)
    logger.debug(f"prob {probabilities} | labels {labels} | classes {classes}")
    return classes, labels, probabilities

################################################################################
#
# File I/O functions
#
################################################################################

# Save a trained model. This function is not required. You can change or remove it.
def save_model(model_directory, leads, classes, imputer, classifier):
    """Serialise a model bundle to ``model_directory`` with joblib.

    The bundle is a dict with the keys ``leads``, ``classes``, ``imputer`` and
    ``classifier``; the file name comes from ``get_model_filename(leads)``.
    Pickle protocol 0 is used, as before.
    """
    bundle = dict(leads=leads, classes=classes, imputer=imputer, classifier=classifier)
    destination = os.path.join(model_directory, get_model_filename(leads))
    joblib.dump(bundle, destination, protocol=0)

# Load a trained model. This function is *required*. You should edit this function to add your code, but do *not* change the arguments of this function.
def load_model(model_directory, leads):
    """Load the top-N checkpoints trained for ``len(leads)`` leads.

    Checkpoints are looked up as ``<model_directory>/<N>leads*.ckpt`` and
    ranked by the number following the last ``=`` in the file name (presumably
    the validation score). The ``config.exp.top_N`` highest-ranked ones are
    loaded with ``RsnVanilla.load_from_checkpoint``.

    Args:
        model_directory: Directory holding the ``.ckpt`` files.
        leads: Lead names; only their count is used.

    Returns:
        List of loaded models, in ascending score order. It is empty when no
        checkpoint matched.
    """
    import re  # local: only used to parse the score out of checkpoint names

    # NOTE(review): training_code and run_model read configs from
    # ./config/entry3/, but top_N is read from ./config/entry1/ here -- confirm
    # this is intentional.
    config_path = f'./config/entry1/rsn_raw_{len(leads)}leads.yaml'
    config = OmegaConf.load(config_path)
    top_N = config.exp.top_N

    model_path = glob(f'{model_directory}/{len(leads)}leads*.ckpt')
    if not model_path:
        logger.warning(f"Not found model for {len(leads)} leads")
        return []

    score_pattern = re.compile(r'[-+]?\d*\.?\d+')

    def checkpoint_score(path):
        # Take the leading number after the last '='. The former `[:5]` slice
        # crashed on short values ("0.65.ckpt"[:5] == "0.65.") and truncated
        # long ones. A regex also ignores suffixes such as "-v1.ckpt".
        match = score_pattern.match(path.split('=')[-1])
        if match is None:
            raise ValueError(f"Cannot parse a score from checkpoint name: {path}")
        return float(match.group())

    # Stable ascending sort keeps glob order for ties; keep the best top_N.
    best_paths = sorted(model_path, key=checkpoint_score)[-top_N:]

    model = [RsnVanilla.load_from_checkpoint(path) for path in best_paths]
    logger.info(f"length of models {len(model)}")
    return model


# Define the filename(s) for the trained models. This function is not required. You can change or remove it.
def get_model_filename(leads):
    """Return the model file name ``model_<lead>-<lead>-....sav`` for ``leads``.

    The leads are put in the order given by ``sort_leads`` first, so any
    permutation of the same lead set maps to the same file name.
    """
    lead_part = '-'.join(sort_leads(leads))
    return f"model_{lead_part}.sav"

################################################################################
#
# Feature extraction function
#
################################################################################

# Extract features from the header and recording. This function is not required. You can change or remove it.
def get_features(header, recording, leads):
    """Extract simple demographic and signal features from one recording.

    Args:
        header: Header text of the recording.
        recording: Signal array, presumably (leads_in_file, num_samples) --
            TODO confirm against choose_leads.
        leads: Leads to select (and order) before computing features.

    Returns:
        ``(age, sex, rms)``:
        - ``age`` comes from the header, or NaN when missing.
        - ``sex`` is 0 for female, 1 for male, and NaN otherwise.
        - ``rms`` is a float64 array with the root mean square of each selected
          lead, computed after subtracting the header baseline and dividing by
          the ADC gain.
    """
    # Extract age.
    age = get_age(header)
    if age is None:
        age = float('nan')

    # Extract sex. Encode as 0 for female, 1 for male, and NaN for other.
    sex = get_sex(header)
    if sex in ('Female', 'female', 'F', 'f'):
        sex = 0
    elif sex in ('Male', 'male', 'M', 'm'):
        sex = 1
    else:
        sex = float('nan')

    # Reorder/reselect leads in recordings.
    recording = choose_leads(recording, header, leads)
    # Integer samples would be truncated by the normalisation below, so promote
    # them to float; floating inputs keep their dtype.
    if not np.issubdtype(recording.dtype, np.floating):
        recording = recording.astype(np.float64)

    # Pre-process recordings: (raw - baseline) / adc_gain per lead (presumably
    # converting to physical units).
    num_leads = len(leads)
    adc_gains = np.asarray(get_adc_gains(header, leads), dtype=recording.dtype).reshape(num_leads, 1)
    baselines = np.asarray(get_baselines(header, leads), dtype=recording.dtype).reshape(num_leads, 1)
    recording = (recording - baselines) / adc_gains

    # Root mean square of each lead signal (one value per row), stored as float64.
    rms = np.sqrt(np.mean(recording ** 2, axis=1)).astype(np.float64)

    return age, sex, rms
