#!/usr/bin/env python

# Edit this script to add your team's training code.
# Some functions are *required*, but you can edit most parts of required
# functions, remove non-required functions, and add your own functions.

from helper_code import *
import numpy as np
import os
import pathlib
from csem.io import read_json
from csem.data import load_data, transformer_preprocess, split_zeropad_save
from csem.data import FOLDER_NAME, transformer_preprocess
from csem.training import supertrain_transformer, get_input_params
from csem.load_model import load_input_parameters, load_transformer
from csem.post_processing import post_process


###############################################################################
#
# Definable constants
#
###############################################################################
# Sampling frequency handed to load_data() and split_zeropad_save()
# (presumably in Hz -- confirm against the csem helpers).
fs = 500.0

# Fold indices reserved for training and for validation.
training_folds = list(range(4))
validation_folds = [4]

# Name of the configuration file looked up by get_input_params().
config_name = "configuration.json"

# Name of the trained model folder inside the model directory. Every lead
# configuration shares the same model name.
twelve_lead_model_filename = 'submission_1'
six_lead_model_filename = three_lead_model_filename = \
    two_lead_model_filename = twelve_lead_model_filename

###############################################################################
#
# Training function
#
###############################################################################


# Train your model.
# This function is *required*. Do *not* change the arguments of this function.
def training_code(data_directory, model_directory):
    """Prepare the beat-level training data and train the transformer.

    The records found by ``load_data`` are handed to ``split_zeropad_save``
    together with the beats folder, then ``supertrain_transformer`` is run
    with the parameters obtained from ``get_input_params``.

    Args:
        data_directory: Location of the challenge data, forwarded to
            ``load_data``.
        model_directory: Location for the trained model, forwarded to
            ``get_input_params``.
    """
    # Find header and recording files.
    print('Finding header and recording files...')
    dataset = load_data(data_directory, "..", fs)

    # Working folder for the beats:
    #   - CSEM machine: pathlib.Path(data_directory).parent
    #   - Submission:   current working directory (used here)
    data_folder = pathlib.Path(".")
    beats_folder = data_folder / FOLDER_NAME
    # Create the folder without a separate existence check: this avoids the
    # check-then-create race and also works if FOLDER_NAME is nested.
    beats_folder.mkdir(parents=True, exist_ok=True)

    config_folder = "config"
    input_params = get_input_params(
        config_folder, config_name, model_directory
    )

    # WARNING: splitting *all* datasets (training + validation) into beats
    # produces a ~140GB folder, above the 100GB limit of the challenge, so
    # only the training datasets from the configuration are processed.
    datasets = input_params['datasets']
    split_zeropad_save(
        dataset.iterate_records(fold_index=True), beats_folder, fs=fs,
        datasets=datasets, return_beats=False, supervised=True
    )
    # Drop the reference to the loaded dataset before training starts.
    del dataset

    pretrain_folder = "pretrained"
    supertrain_transformer(
        data_folder=data_folder,
        pretrain_folder=pretrain_folder,
        **input_params
    )

###############################################################################
#
# File I/O functions
#
###############################################################################


# Load your trained 12-lead ECG model.
# This function is *required*. Do *not* change the arguments of this function.
def load_twelve_lead_model(model_directory):
    """Load the trained model used for 12-lead recordings.

    Args:
        model_directory: Folder that contains the trained model folder.

    Returns:
        The value returned by ``load_model`` for ``model_directory``.
    """
    # load_model() requires a second ``leads`` argument (which it does not
    # use) and appends twelve_lead_model_filename to the directory itself, so
    # pass the bare directory rather than a pre-joined path.
    return load_model(model_directory, None)


# Load your trained 6-lead ECG model.
# This function is *required*. Do *not* change the arguments of this function.
def load_six_lead_model(model_directory):
    """Load the trained model used for 6-lead recordings.

    Args:
        model_directory: Folder that contains the trained model folder.

    Returns:
        The value returned by ``load_model`` for ``model_directory``.
    """
    # load_model() requires a second ``leads`` argument (which it does not
    # use) and appends twelve_lead_model_filename -- which
    # six_lead_model_filename aliases -- to the directory itself, so pass the
    # bare directory rather than a pre-joined path.
    return load_model(model_directory, None)


# Load your trained 3-lead ECG model.
# This function is *required*. Do *not* change the arguments of this function.
def load_three_lead_model(model_directory):
    """Load the trained model used for 3-lead recordings.

    Args:
        model_directory: Folder that contains the trained model folder.

    Returns:
        The value returned by ``load_model`` for ``model_directory``.
    """
    # load_model() requires a second ``leads`` argument (which it does not
    # use) and appends twelve_lead_model_filename -- which
    # three_lead_model_filename aliases -- to the directory itself, so pass
    # the bare directory rather than a pre-joined path.
    return load_model(model_directory, None)


# Load your trained 2-lead ECG model.
# This function is *required*. Do *not* change the arguments of this function.
def load_two_lead_model(model_directory):
    """Load the trained model used for 2-lead recordings.

    Args:
        model_directory: Folder that contains the trained model folder.

    Returns:
        The value returned by ``load_model`` for ``model_directory``.
    """
    # load_model() requires a second ``leads`` argument (which it does not
    # use) and appends twelve_lead_model_filename -- which
    # two_lead_model_filename aliases -- to the directory itself, so pass the
    # bare directory rather than a pre-joined path.
    return load_model(model_directory, None)


# Generic function for loading a model.
def load_model(model_directory, leads):
    """Load the trained transformer and its class list.

    The model is looked up under
    ``model_directory / twelve_lead_model_filename``; that folder must hold a
    ``classes.json`` file.

    Args:
        model_directory: Folder that contains the trained model folder.
        leads: Not used by this function.

    Returns:
        Tuple ``(transformer, classes)``: the object returned by
        ``load_transformer`` and the content of ``classes.json``.
    """
    model_path = pathlib.Path(model_directory, twelve_lead_model_filename)
    classes = read_json(model_path / "classes.json")
    # The model is identified by its folder name and the folder it lives in.
    transformer = load_transformer(
        model_path.stem,
        n_classes=len(classes),
        log_folder=model_path.parent,
    )
    return transformer, classes

###############################################################################
#
# Running trained model functions
#
###############################################################################


# Run your trained 12-lead ECG model.
# This function is *required*. Do *not* change the arguments of this function.
def run_twelve_lead_model(model, header, recording):
    """Run the trained model on a 12-lead recording.

    Delegates to ``run_model`` and returns its
    ``(classes, labels, probabilities)`` result unchanged.
    """
    return run_model(model=model, header=header, recording=recording)


# Run your trained 6-lead ECG model.
# This function is *required*. Do *not* change the arguments of this function.
def run_six_lead_model(model, header, recording):
    """Run the trained model on a 6-lead recording.

    Delegates to ``run_model`` and returns its
    ``(classes, labels, probabilities)`` result unchanged.
    """
    return run_model(model=model, header=header, recording=recording)


# Run your trained 3-lead ECG model.
# This function is *required*. Do *not* change the arguments of this function.
def run_three_lead_model(model, header, recording):
    """Run the trained model on a 3-lead recording.

    Delegates to ``run_model`` and returns its
    ``(classes, labels, probabilities)`` result unchanged.
    """
    return run_model(model=model, header=header, recording=recording)


# Run your trained 2-lead ECG model.
# This function is *required*. Do *not* change the arguments of this function.
def run_two_lead_model(model, header, recording):
    """Run the trained model on a 2-lead recording.

    Delegates to ``run_model`` and returns its
    ``(classes, labels, probabilities)`` result unchanged.
    """
    return run_model(model=model, header=header, recording=recording)


# Generic function for running a trained model.
def run_model(model, header, recording):
    """Classify one recording with a loaded model.

    Args:
        model: ``(transformer, classes)`` pair, as returned by ``load_model``.
        header: Header of the recording, forwarded to
            ``transformer_preprocess``.
        recording: Signal array with a ``shape`` attribute; its smallest
            dimension is taken as the number of leads.

    Returns:
        Tuple ``(classes, labels, probabilities)`` where ``labels`` is the
        output of ``post_process`` and ``probabilities`` is a flattened NumPy
        array. For a flat recording the probabilities are all zeros.
    """
    transformer, classes = model
    # The shortest axis of the recording is taken as the lead axis.
    n_leads = min(recording.shape)
    inp_padded, padding_mask = transformer_preprocess(
        recording, header, n_leads=n_leads, hp_filter=True
    )
    # A flat recording is signalled by the string "flat". Check the type
    # first: `array == "flat"` is an elementwise comparison in recent NumPy
    # and its result cannot be used as an `if` condition.
    if isinstance(inp_padded, str) and inp_padded == "flat":
        # No scored class if recording is flat
        probabilities = np.zeros(len(classes))
    else:
        # NOTE(review): the positional False is presumably a training flag --
        # confirm against the transformer's call signature.
        probabilities, _ = transformer(inp_padded, False, padding_mask)
    labels = post_process(probabilities)
    return classes, labels, np.array(probabilities).flatten()
