import argparse
import ast
import os

import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import yaml

from src.utils.team_helper_code import plot_confusion_matrix


# Command-line interface. --dir-path is the directory holding the result CSVs
# and a "model" sub-directory with run_id.yaml and configs.yaml; --lead is the
# lead configuration name that logging() receives.
parser = argparse.ArgumentParser()
parser.add_argument("--dir-path", "-d", type=str, required=True, help="directory path")
parser.add_argument("--lead", type=str)
args = parser.parse_args()

# Mapping indexed by lead name to get the MLflow run id (see logging()).
# Explicit UTF-8 so decoding does not depend on the platform's locale encoding.
# NOTE(review): FullLoader is kept for compatibility; yaml.safe_load would be
# stricter if these files never use Python-specific YAML tags.
with open(os.path.join(args.dir_path, "model", "run_id.yaml"), encoding="utf-8") as f:
    run_ids = yaml.load(f, Loader=yaml.FullLoader)

# Global configuration; USE_MLFLOW and ON_TRAINING are read from it below.
with open(os.path.join(args.dir_path, "model", "configs.yaml"), encoding="utf-8") as f:
    global_config = yaml.load(f, Loader=yaml.FullLoader)


def _format_class_labels(columns):
    """Turn confusion-matrix column headers into plot tick labels.

    Each header is parsed as a Python literal collection (e.g. the string form
    of a set or tuple of class codes). A single-element collection yields its
    element; a multi-element one yields its first element followed by ``"_"``.
    NOTE(review): the header format is inferred from this parsing; confirm it
    against whatever writes ``cm_*.csv``.

    Args:
        columns: Iterable of header strings.

    Returns:
        A list with one label per header.
    """
    labels = []
    for header in columns:
        # literal_eval only accepts literals, so a malformed or hostile CSV
        # header cannot execute code the way eval() could.
        members = list(ast.literal_eval(header))
        # NOTE(review): for multi-element *sets* of strings, "first element"
        # depends on hash ordering and may differ between processes.
        labels.append(members[0] if len(members) == 1 else f"{members[0]}_")
    return labels


def logging(lead):
    """Log one lead configuration's test results to its existing MLflow run.

    Does nothing when ``USE_MLFLOW`` is falsy in the global config. Otherwise
    it opens the run whose id is ``run_ids[lead]`` inside the MLflow experiment
    named ``lead`` and logs:

    * ``scores_<n>.csv`` and ``class_scores_<n>.csv`` as artifacts,
    * the first row of ``scores_<n>.csv`` as metrics,
    * normalized and unnormalized confusion-matrix figures built from
      ``cm_<n>.csv``; each is also saved as a 300-dpi PNG under ``--dir-path``.

    ``<n>`` is the integer that ``lead`` maps to (e.g. 12 for ``"TWELVE"``).

    Args:
        lead: One of ``"TWELVE"``, ``"SIX"``, ``"FOUR"``, ``"THREE"``, ``"TWO"``.

    Raises:
        ValueError: If ``lead`` is not one of the names above, or no MLflow
            experiment with that name exists.
        KeyError: If ``run_id.yaml`` has no entry for ``lead``.
    """
    if not global_config["USE_MLFLOW"]:
        return

    lead_dict = {
        "TWELVE": 12,
        "SIX": 6,
        "FOUR": 4,
        "THREE": 3,
        "TWO": 2,
    }
    # Validate before touching MLflow so a bad --lead fails with a clear message.
    if lead not in lead_dict:
        raise ValueError(f"Unknown lead {lead!r}; expected one of {sorted(lead_dict)}")
    n_leads = lead_dict[lead]

    experiment = mlflow.get_experiment_by_name(lead)
    if experiment is None:
        raise ValueError(f"No MLflow experiment named {lead!r}")
    run_id = run_ids[lead]

    scores_path = os.path.join(args.dir_path, f"scores_{n_leads}.csv")
    class_scores_path = os.path.join(args.dir_path, f"class_scores_{n_leads}.csv")
    cm_path = os.path.join(args.dir_path, f"cm_{n_leads}.csv")

    with mlflow.start_run(experiment_id=experiment.experiment_id, run_id=run_id):
        mlflow.log_artifact(scores_path)
        mlflow.log_artifact(class_scores_path)
        # Only the first row of the scores table is logged as metrics.
        scores = pd.read_csv(scores_path).loc[0].to_dict()
        mlflow.log_metrics(scores)

        cm_df = pd.read_csv(cm_path, index_col=0)
        classes = _format_class_labels(cm_df.columns)
        # NOTE(review): PNG names carry no lead suffix, so runs for different
        # leads overwrite each other's local copies in --dir-path.
        for normalize, kind in ((True, "normalized"), (False, "unnormalized")):
            filename = f"test_confusion_matrix_{kind}.png"
            fig = plot_confusion_matrix(cm_df.values, target_names=classes, normalize=normalize)
            # Save the returned figure explicitly rather than pyplot's "current" one.
            fig.savefig(os.path.join(args.dir_path, filename), dpi=300)
            mlflow.log_figure(fig, filename)
            # Release the figure; pyplot otherwise keeps every figure alive.
            plt.close(fig)


if __name__ == "__main__":
    # Push results to MLflow only when the config defines ON_TRAINING
    # (any value other than None); logging() additionally checks USE_MLFLOW.
    training_flag = global_config["ON_TRAINING"]
    if training_flag is not None:
        logging(args.lead)
