import os
import psutil
import pathlib
import time
import copy
import tensorflow as tf
import numpy as np
from .io import write_json, read_json
from .metrics import CincScore2
from .utils import set_GPU, load_weights
from .data import DataGeneratorSupervised
from .transformer_functions import (
    TransformerEncoder,
    rms_tf,
    CustomSchedule,
    zeropad,
    create_padding_mask,
    loss_function,
    prepare_transformer,
    submitted_CNN,
    SubmittedCNN
)


def supertrain_transformer(
        model_name, epochs, data_folder, checkpoint_folder,
        pretrain_folder=None, num_layers=6, num_heads=8, d_model=1000,
        dff=2048, maximum_position_encoding=50, dropout_rate=0.1,
        embedding_function=rms_tf, seed=7, batch_size=8, gpu_id=None,
        share_memory=False, load_checkpoint=False, load_pretrained=True,
        pretrained_model_name=None, layers_to_remove=None,
        layers_to_freeze=None, datasets=None, training_folds=None,
        validation_datasets=None, validation_folds=None
):
    """Supervised training of the transformer encoder.

    Train the transformer on a labelled (supervised) task, optionally
    starting from a pretrained checkpoint and/or resuming from an earlier
    supervised checkpoint. After every epoch the per-epoch metrics are
    written to "accuracy_metrics.json" and the duration of that epoch to
    "time.json" inside `checkpoint_folder / model_name`.

    Args:
        model_name: name of model
        epochs: number of epochs to train
        data_folder: path to datasets used for training
        checkpoint_folder: folder to save the checkpoints of the model
        pretrain_folder: path to pretrain model; defaults to
                         `checkpoint_folder` when None
        num_layers: number of transformer encoder layers
                    default is 6, same as in original transformer
        num_heads: number of multihead attention heads
                   default is 8, same as in original transformer
        d_model: the inner dimension of the model and the dimension of the
                 "word vectors" (length of each beat)
                 default is 1000 to make sure to capture the slowest heart rate
        dff: the dimension of the final fully connected layer
             default is 2048, same as in original transformer
        maximum_position_encoding: maximum number of beats per recording
        dropout_rate: dropout rate of transformer
                      default is 0.1, same as in tutorial
        embedding_function: function used to create "word vectors", has to be
                            compatible with Tensorflow graph mode
        seed: seed for random number generators
        batch_size: (mini) batch size
        gpu_id: which GPU to train on. If None, all available GPUs will be used
        share_memory: Boolean to say whether or not to share memory
        load_checkpoint: whether or not to load existing checkpoints in
                         checkpoint folder
        load_pretrained: a boolean, whether or not to load a pretrained model
                         the last layer of this model will be removed and
                         replaced with a linear classification layer
        pretrained_model_name: a string, name of the pretrained model to load
                               (required when `load_pretrained` is true)
        layers_to_remove: an integer of layers to remove from the end of the
                          layer list (excluding last linear layer)
                          Ex: 0 removes just the last linear layer and replaces
                              it with another linear layer
                          Ex: 1 removes the last encoder layer etc.
        layers_to_freeze: integer of encoder layers to freeze during training.
                          Ex: 0 will retrain the whole model
                          Ex: 1 will freeze only the first encoder layer etc.
        datasets: A list of dataset names (strings) to include in training
        training_folds: a list of integers (fold numbers) to train on
        validation_datasets: A list of dataset names (strings) for validation
        validation_folds: a list of integers (fold numbers) to validate on;
                          when empty or None no validation is run

    Raises:
        ValueError: if `load_pretrained` is true but no
                    `pretrained_model_name` is given, or if `load_checkpoint`
                    is true but no checkpoint exists.
        FileNotFoundError: if `load_pretrained` is true but no pretrained
                           checkpoint is found.
    """
    # A checkpoint is written every `checkpoint_every` epochs. The same
    # number is used to recover the epoch to resume from, so the two must
    # stay in sync.
    checkpoint_every = 5

    # Resolve defaults first so the parameter dump below shows the values
    # that are actually used.
    if training_folds is None:
        training_folds = [0, 1, 2, 3]
    if datasets is None:
        datasets = ['Georgia', 'PTB_XL']
    if validation_datasets is None:
        validation_datasets = [
            'Georgia', 'PTB_XL', 'CPSC_1', 'CPSC_2', 'PTB', 'StPetersburg'
        ]

    print("Model parameters:")
    model_params = {
        "datasets": datasets,
        "validation_datasets": validation_datasets,
        "training_folds": training_folds,
        "validation_folds": validation_folds,
        "pretrained_model_name": pretrained_model_name,
        "load_pretrained": load_pretrained,
        "model_name": model_name,
        "load_checkpoint": load_checkpoint,
        "num_layers": num_layers,
        "num_heads": num_heads,
        "d_model": d_model,
        "dff": dff,
        "layers_to_remove": layers_to_remove,
        "layers_to_freeze": layers_to_freeze,
        "embedding_function": embedding_function,
        "maximum_position_encoding": maximum_position_encoding,
        "dropout_rate": dropout_rate,
        "gpu_id": gpu_id,
        "share_memory": share_memory,
        "epochs": epochs,
        "seed": seed,
        "batch_size": batch_size
    }
    for key, val in model_params.items():
        print(key, val)

    # Fail early with a clear message instead of a TypeError from
    # `Path / None` further down.
    if load_pretrained and pretrained_model_name is None:
        raise ValueError(
            "load_pretrained is set but pretrained_model_name is None")

    if gpu_id is not None:
        set_GPU(gpu_id)
        print("GPU set to", gpu_id)
    if share_memory:
        # `share_memory` is a boolean flag; the previous code tried to call
        # it, which always raised TypeError.
        # NOTE(review): enabling GPU memory growth is assumed to be the
        # intended meaning of "sharing" memory - confirm against the helper
        # this flag was originally meant to invoke.
        for gpu in tf.config.list_physical_devices('GPU'):
            tf.config.experimental.set_memory_growth(gpu, True)
        print("Memory being shared")

    np.random.seed(seed)
    tf.random.set_seed(seed)

    # Load the training data generator
    print("Load the data")
    print("data folder:", data_folder)
    generator = DataGeneratorSupervised(
        batch_size=batch_size, root_path=data_folder,
        datasets=datasets, folds=training_folds
    )
    if pretrain_folder is None:
        pretrain_folder = checkpoint_folder

    checkpoint_folder = pathlib.Path(checkpoint_folder)
    checkpoint_path = checkpoint_folder / model_name
    checkpoint_path.mkdir(parents=True, exist_ok=True)
    pretrain_folder = pathlib.Path(pretrain_folder)
    # Save the classes next to the checkpoints
    write_json(checkpoint_path / "classes.json", generator.encoder.classes)

    generator_validation = None
    if validation_folds:
        generator_validation = DataGeneratorSupervised(
            batch_size=batch_size, root_path=data_folder,
            folds=validation_folds, datasets=validation_datasets
        )

    # Define the model
    print("Define the model")
    transformer = TransformerEncoder(
        num_layers, d_model, num_heads, dff, embedding_function,
        maximum_position_encoding, rate=dropout_rate, seed=seed
    )

    # Define optimizer
    learning_rate = CustomSchedule(d_model)
    optimizer = tf.keras.optimizers.Adam(
        learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9
    )
    loss_object = tf.keras.losses.BinaryCrossentropy()

    # Define metrics. `epoch_metrics` maps the keys stored in
    # accuracy_metrics.json to the metric objects that produce them.
    scored_classes = generator.encoder.classes
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.BinaryAccuracy()
    weights = load_weights("weights.csv", scored_classes)
    train_score = CincScore2(weights, scored_classes)
    epoch_metrics = {
        "train_loss": train_loss,
        "train_accuracy": train_accuracy,
        "train_score": train_score,
    }
    if validation_folds:
        val_loss = tf.keras.metrics.Mean(name='val_loss')
        val_accuracy = tf.keras.metrics.BinaryAccuracy()
        val_weights = load_weights("weights.csv", scored_classes)
        val_score = CincScore2(val_weights, scored_classes)
        epoch_metrics["val_loss"] = val_loss
        epoch_metrics["val_accuracy"] = val_accuracy
        epoch_metrics["val_score"] = val_score

    if load_pretrained:
        # Restore the pretrained weights into the freshly built model
        pre_checkpoint_path = pretrain_folder / pretrained_model_name
        pre_ckpt = tf.train.Checkpoint(
            transformer=transformer, optimizer=optimizer
        )
        pre_ckpt_manager = tf.train.CheckpointManager(
            pre_ckpt, pre_checkpoint_path, max_to_keep=5
        )
        if not pre_ckpt_manager.latest_checkpoint:
            raise FileNotFoundError("No pretrained model found at",
                                    pre_checkpoint_path)
        pre_ckpt.restore(pre_ckpt_manager.latest_checkpoint).expect_partial()
        print("Restored from {}".format(pre_ckpt_manager.latest_checkpoint))

    n_classes = len(scored_classes)
    print("Number of classes", n_classes)
    transformer = prepare_transformer(
        transformer, layers_to_freeze, layers_to_remove, n_classes,
        embedding_function
    )

    # Define checkpoints
    ckpt = tf.train.Checkpoint(
        transformer=transformer, optimizer=optimizer
    )
    ckpt_manager = tf.train.CheckpointManager(
        ckpt, checkpoint_path, max_to_keep=5
    )
    # if a checkpoint exists, restore the latest checkpoint.
    first_epoch = 0
    if ckpt_manager.latest_checkpoint:
        if load_checkpoint:
            ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
            # One save per `checkpoint_every` epochs, so the number of saves
            # gives the number of epochs already trained.
            first_epoch = int(ckpt.save_counter.numpy()) * checkpoint_every
            if load_pretrained:
                print("WARNING: Pretrained checkpoint overwritten "
                      "by supervised checkpoint!")
            print('Latest checkpoint restored from {}'.format(
                ckpt_manager.latest_checkpoint))
        else:
            print("WARNING: Overwriting exsisting checkpoints!")
    elif load_checkpoint:
        raise ValueError("Did not find checkpoint at", checkpoint_folder)

    def train_step(inp, tar, maximum_position_encoding):
        """Run one optimisation step on a batch and update train metrics."""
        padding_mask = create_padding_mask(inp, maximum_position_encoding)
        inp_padded = zeropad(inp, maximum_position_encoding)

        with tf.GradientTape() as tape:
            # Errors from the forward pass propagate unchanged so their
            # message and traceback are preserved.
            predictions, _ = transformer(inp_padded, True, padding_mask)
            loss = loss_function(tar, predictions, loss_object)

        gradients = tape.gradient(loss, transformer.trainable_variables)
        optimizer.apply_gradients(
            zip(gradients, transformer.trainable_variables)
        )

        train_loss(loss)
        assert tar.shape == predictions.shape
        train_accuracy(tar, predictions)
        train_score(tar, predictions)

    # Training loop
    print("Start training")
    if not load_checkpoint:
        accuracy_metrics = {
            "train_loss": [],
            "train_accuracy": [],
            "train_score": [],
            "val_loss": [],
            "val_accuracy": [],
            "val_score": [],
        }
    else:
        accuracy_metrics = read_json(checkpoint_path / "accuracy_metrics.json")

    for epoch in range(first_epoch, epochs):
        start = time.time()

        train_loss.reset_states()
        train_accuracy.reset_states()
        train_score.reset_states()

        for batch in range(generator.n_batches):
            inp, tar = generator.__getitem__(batch)
            tar = np.array(tar)
            train_step(inp, tar, maximum_position_encoding)

            if batch % 10 == 0:
                string_1 = 'Epoch: {} Batch: {} Loss: {:.7f}'
                string_2 = " ".join([
                    'Binary accuracy: {:.4f}',
                    'Challenge metric: {:.4f}'
                ])
                print(" ".join([string_1, string_2]).format(
                    epoch + 1, batch, train_loss.result(),
                    train_accuracy.result(),
                    train_score.result()
                ), end="\r")

        if (epoch + 1) % checkpoint_every == 0:
            ckpt_save_path = ckpt_manager.save()
            print('Saving checkpoint for epoch {} at {}'
                  .format(epoch + 1, ckpt_save_path))

        if validation_folds:
            val_loss, val_accuracy, val_score = validate_epoch(
                transformer, generator_validation, embedding_function,
                maximum_position_encoding, loss_function, loss_object,
                val_loss, val_accuracy, val_score
            )

        print('Epoch: {}'.format(epoch + 1))
        print('  ||  '.join([
            "Training loss: {:.7f}",
            "Training accuracy: {:.4f}",
            "Training CinC2020 score: {:.4f}"
        ]).format(
            train_loss.result(),
            train_accuracy.result(),
            train_score.result()
        ))
        # Validation metrics only exist when validation was requested.
        if validation_folds:
            print('  ||  '.join([
                "Validation loss: {:.7f}",
                "Validation accuracy: {:.4f}",
                "Validation CinC2020 score: {:.4f}"
            ]).format(
                val_loss.result(),
                val_accuracy.result(),
                val_score.result()
            ))

        t = time.time() - start
        print('Time taken for 1 epoch: {} secs\n'.format(t))
        write_json(checkpoint_path / "time.json", t)
        # Only metrics that were computed this epoch are appended; without
        # validation the "val_*" histories are left as they are.
        for key, metric in epoch_metrics.items():
            accuracy_metrics.setdefault(key, []).append(
                float(metric.result().numpy()))
        write_json(checkpoint_path / "accuracy_metrics.json", accuracy_metrics)


def prepare_batch(
        generator, batch, embedding_function, maximum_position_encoding
):
    """Fetch one batch and prepare it for the model.

    Args:
        generator: data generator; batch `batch` is read from it through
                   `__getitem__` and must yield an (input, target) pair
        batch: index of the batch to fetch
        embedding_function: accepted for interface compatibility; not used
                            by this function
        maximum_position_encoding: passed on to `create_padding_mask` and
                                   `zeropad`

    Returns:
        A tuple (zero-padded input, targets as a numpy array, padding mask).
    """
    raw_input, raw_target = generator.__getitem__(batch)
    target = np.array(raw_target)
    # The mask is built from the unpadded input, before zero-padding it.
    mask = create_padding_mask(raw_input, maximum_position_encoding)
    padded_input = zeropad(raw_input, maximum_position_encoding)
    return padded_input, target, mask


def validate_epoch(
        model, generator, embedding_function, maximum_position_encoding,
        loss_function, loss_object, val_loss, val_accuracy, val_score
):
    """Evaluate `model` on every batch of `generator`.

    The three metric objects are reset, then updated with the loss and the
    predictions of each batch. The model is called with `False` as its
    second argument (the training call passes `True` in the same position).

    Args:
        model: callable returning a (predictions, extra) pair
        generator: validation data generator exposing `n_batches`
        embedding_function: forwarded to `prepare_batch`
        maximum_position_encoding: forwarded to `prepare_batch`
        loss_function: callable taking (targets, predictions, loss_object)
        loss_object: loss object handed to `loss_function`
        val_loss: metric accumulating the batch losses
        val_accuracy: metric updated with (targets, predictions)
        val_score: metric updated with (targets, predictions)

    Returns:
        The same (val_loss, val_accuracy, val_score) objects, updated.
    """
    for metric in (val_loss, val_accuracy, val_score):
        metric.reset_states()

    print()
    n_batches = generator.n_batches
    for batch_idx in range(n_batches):
        # Progress line, overwritten in place every tenth batch.
        if not batch_idx % 10:
            print("Validating batch:", batch_idx, end="\r")
        padded, targets, mask = prepare_batch(
            generator, batch_idx, embedding_function,
            maximum_position_encoding
        )
        outputs = model(padded, False, mask)
        predictions, _ = outputs
        val_loss(loss_function(targets, predictions, loss_object))
        val_accuracy(targets, predictions)
        val_score(targets, predictions)
    print()
    return val_loss, val_accuracy, val_score


def save_logs(input_params):
    """Write the run configuration to ``configuration.json``.

    The file is placed in `checkpoint_folder / model_name` (both taken from
    `input_params`); the directory is created if needed. The caller's dict
    is not modified: a deep copy is written, with the embedding function
    replaced by its `__name__` so the configuration can be stored as JSON.

    Args:
        input_params: dict holding at least 'checkpoint_folder',
                      'model_name' and 'embedding_function' (an object with
                      a `__name__` attribute)
    """
    config = copy.deepcopy(input_params)
    target_dir = pathlib.Path(
        config['checkpoint_folder'], config['model_name'])
    target_dir.mkdir(parents=True, exist_ok=True)
    embedding = config['embedding_function']
    config['embedding_function'] = embedding.__name__
    write_json(target_dir / "configuration.json", config)


def load_config(model_folder):
    """Return the contents of ``configuration.json`` in `model_folder`."""
    config_file = pathlib.Path(model_folder, "configuration.json")
    return read_json(config_file)


def get_input_params(config_folder, config_name, checkpoint_folder):
    """Build the training parameters from a configuration file.

    Three cases, decided by the requested configuration:
      * 'load_checkpoint' is set: the stored configuration of 'model_name'
        is loaded and only the run-specific settings are taken from the
        request.
      * otherwise 'load_pretrained' is set: the stored configuration of
        'pretrained_model_name' is loaded and the data, run and fine-tuning
        settings are taken from the request; a stored 'thrs' entry is
        dropped.
      * otherwise: the requested configuration is used as it is.

    The resulting parameters are saved with `save_logs` before returning.

    Args:
        config_folder: folder containing the requested configuration
        config_name: file name of the requested configuration
        checkpoint_folder: folder holding the model sub-folders; also stored
                           in the result under 'checkpoint_folder'

    Returns:
        dict of parameters, with 'embedding_function' resolved from its
        stored name to an object.
    """
    # Keys copied from the request, in the order they are assigned.
    resume_keys = ('gpu_id', 'share_memory', 'epochs', 'load_checkpoint')
    finetune_keys = (
        'datasets', 'validation_datasets', 'training_folds',
        'validation_folds', 'gpu_id', 'share_memory', 'epochs',
        'load_checkpoint', 'pretrained_model_name', 'load_pretrained',
        'model_name', 'layers_to_freeze', 'layers_to_remove', 'seed',
    )

    requested = read_json(pathlib.Path(config_folder) / config_name)
    models_root = pathlib.Path(checkpoint_folder)

    if requested['load_checkpoint']:
        params = load_config(models_root / requested["model_name"])
        for key in resume_keys:
            params[key] = requested[key]
    elif requested['load_pretrained']:
        params = load_config(models_root / requested["pretrained_model_name"])
        for key in finetune_keys:
            params[key] = requested[key]
        params.pop('thrs', None)
    else:
        params = requested

    params['checkpoint_folder'] = checkpoint_folder
    # NOTE(review): eval() executes whatever string the configuration file
    # contains; only use configuration files from trusted sources.
    params['embedding_function'] = eval(params['embedding_function'])
    save_logs(params)
    return params
