import pathlib
import os
import psutil
import time
import copy
import tensorflow as tf
import numpy as np
from .io import read_json, write_json
from .utils import set_GPU
from .data import DataGeneratorAutoregressive
from .transformer_functions import (
    TransformerEncoder,
    rms_tf,
    CustomSchedule,
    MSEAccuracy,
    zeropad,
    create_padding_mask,
    loss_function,
    submitted_CNN,
    SubmittedCNN
)


def _enable_gpu_memory_growth():
    """Let TensorFlow allocate GPU memory on demand instead of all up front.

    NOTE(review): the original code called the boolean ``share_memory``
    argument as if it were a function, which always raised TypeError. Enabling
    memory growth (so other processes can share the GPU) is the assumed
    intent — confirm against the project's utilities.
    """
    for gpu in tf.config.list_physical_devices('GPU'):
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as err:
            # Memory growth can only be configured before GPUs are initialised.
            print("Could not enable memory growth on", gpu.name, "-", err)


def _memory_usage_mb(process):
    """Return the resident set size of ``process`` in whole megabytes."""
    return process.memory_info().rss // 1024 ** 2


def pretrain_transformer(
        data_folder, model_name, epochs, num_layers=6, num_heads=8,
        d_model=1000, dff=2048, maximum_position_encoding=50, dropout_rate=0.1,
        embedding_function=rms_tf, thrs=None, seed=7, batch_size=8,
        gpu_id=None, share_memory=False,
        checkpoint_folder="../../../../logs/transformer",
        load_checkpoint=False, datasets=None,
        training_folds=None, validation_folds=None
):
    """Unsupervised pretraining of transformer.

    Train transformer on autoregressive task in an unsupervised manner.
    Checkpoints are written to ``checkpoint_folder/model_name`` every 5 epochs
    and after the final epoch (at most 5 are kept). The duration of the most
    recent epoch is written to ``time.json`` in the same folder.

    Args:
        data_folder: path to datasets used for pretrain
        model_name: name of pretrain model (sub-folder of checkpoint_folder)
        epochs: number of epochs to train
        num_layers: number of transformer encoder layers
                    default is 6, same as in original transformer
        num_heads: number of multihead attention heads
                   default is 8, same as in original transformer
        d_model: the inner dimension of the model and the dimension of the
                 "word vectors" (length of each beat)
                 default is 1000 to make sure to capture the slowest heart rate
        dff: the dimension of the final fully connected layer
             default is 2048, same as in original transformer
        maximum_position_encoding: maximum number of beats per recording
        dropout_rate: dropout rate of transformer
                      default is 0.1, same as in tutorial
        embedding_function: function used to create "word vectors", has to be
                            compatible with Tensorflow graph mode
        thrs: thresholds on individual beat regression loss to consider it
              "correct" for accuracy metric
              default is [4e-06, 6e-06, 1e-05, 1e-03]
        seed: seed for random number generators
        batch_size: (mini) batch size
        gpu_id: which GPU to train on. If None, all available GPUs will be used
        share_memory: if True (POSIX only), enable on-demand GPU memory growth
        checkpoint_folder: folder to save the checkpoints of the model
        load_checkpoint: whether or not to load existing checkpoints in
                         checkpoint folder
        datasets: A list of dataset used to pretrain the model
                  default is ['StPetersburg']
        training_folds: a list of integers (fold numbers) to train on
                        default is [0, 1, 2, 3]
        validation_folds: a list of integers (fold numbers) to validate on;
                          validation is skipped when empty or None

    Raises:
        ValueError: if a checkpoint exists but ``load_checkpoint`` is False,
            or if ``load_checkpoint`` is True but no checkpoint exists.
    """
    process = psutil.Process(os.getpid())

    # Resolve defaults first so the printed configuration shows the values
    # that are actually used.
    if training_folds is None:
        training_folds = [0, 1, 2, 3]
    if datasets is None:
        datasets = ['StPetersburg']
    if thrs is None:
        thrs = [4e-06, 6e-06, 1e-05, 1e-03]

    print("Model parameters:")
    model_params = {
        "datasets": datasets,
        "training_folds": training_folds,
        "validation_folds": validation_folds,
        "model_name": model_name,
        "load_checkpoint": load_checkpoint,
        "num_layers": num_layers,
        "num_heads": num_heads,
        "d_model": d_model,
        "dff": dff,
        "embedding_function": embedding_function,
        "maximum_position_encoding": maximum_position_encoding,
        "dropout_rate": dropout_rate,
        "gpu_id": gpu_id,
        "share_memory": share_memory,
        "epochs": epochs,
        "seed": seed,
        "batch_size": batch_size,
        "thrs": thrs,
    }
    for key, val in model_params.items():
        print(key, val)

    if os.name == 'posix':
        if gpu_id is not None:
            set_GPU(gpu_id)
            print("GPU set to", gpu_id)
        if share_memory:
            # `share_memory` is the boolean flag here, not a callable.
            _enable_gpu_memory_growth()
            print("Memory being shared")

    np.random.seed(seed)
    tf.random.set_seed(seed)

    # Load the saved beat data generators
    print("Load the data")
    generator = DataGeneratorAutoregressive(
        batch_size=batch_size, root_path=data_folder, datasets=datasets,
        folds=training_folds
    )
    generator_validation = None
    if validation_folds:
        generator_validation = DataGeneratorAutoregressive(
            batch_size=batch_size, root_path=data_folder,
            folds=validation_folds, datasets=datasets
        )

    # Define the model
    print("Define the model")
    transformer = TransformerEncoder(
        num_layers, d_model, num_heads, dff, embedding_function,
        maximum_position_encoding, rate=dropout_rate, seed=seed
    )

    # Adam driven by the CustomSchedule learning-rate schedule (keyed on d_model)
    learning_rate = CustomSchedule(d_model)
    optimizer = tf.keras.optimizers.Adam(
        learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9
    )
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    loss_object = tf.keras.losses.MeanSquaredError()

    # Define checkpoints
    checkpoint_path = pathlib.Path(checkpoint_folder) / model_name
    checkpoint_path.mkdir(parents=True, exist_ok=True)
    ckpt = tf.train.Checkpoint(transformer=transformer, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(
        ckpt, str(checkpoint_path), max_to_keep=5
    )
    latest = ckpt_manager.latest_checkpoint
    if latest:
        if not load_checkpoint:
            # Refuse to train over (and, with max_to_keep, rotate out) an
            # existing run unless restoring was explicitly requested.
            raise ValueError(
                "Found existing checkpoint {} but load_checkpoint is False"
                .format(latest)
            )
        ckpt.restore(latest).expect_partial()
        print('Latest checkpoint restored from {}'.format(latest))
    elif load_checkpoint:
        raise ValueError(
            "Did not find checkpoint at {}".format(checkpoint_path)
        )

    train_accuracy = [MSEAccuracy(thr) for thr in thrs]

    def train_step(inp_padded, tar, padding_mask):
        """Run one optimisation step on an already padded batch."""
        with tf.GradientTape() as tape:
            predictions, _ = transformer(inp_padded, True, padding_mask)
            loss = loss_function(tar, predictions, loss_object)

        gradients = tape.gradient(loss, transformer.trainable_variables)
        optimizer.apply_gradients(
            zip(gradients, transformer.trainable_variables)
        )

        train_loss(loss)
        for accuracy in train_accuracy:
            accuracy(tar, predictions)

        # MB taken by process
        print('memory usesd: ' + str(_memory_usage_mb(process)) + 'MB')

    # Training loop
    print("Start training")
    for epoch in range(epochs):
        start = time.time()

        train_loss.reset_states()
        for accuracy in train_accuracy:
            accuracy.reset_states()

        for batch in range(generator.n_batches):
            # Shared with validation: filters malformed targets, embeds the
            # targets and pads/masks the inputs.
            inp_padded, tar, padding_mask = prepare_batch(
                generator, batch, embedding_function,
                maximum_position_encoding
            )
            train_step(inp_padded, tar, padding_mask)

            if batch % 50 == 0:
                string_1 = 'Epoch: {} Batch: {} Loss: {:.7f}'
                string_2 = " ".join(['Accuracy ({:.0e}): {:.4f}'] * len(thrs))
                print(" ".join([string_1, string_2]).format(
                    epoch + 1, batch, train_loss.result(),
                    *np.concatenate([[thr, ta.result()] for thr, ta
                                     in zip(thrs, train_accuracy)])
                ), end="\r")

        # Save every 5 epochs and always after the last one, so the final
        # weights are kept even when epochs is not a multiple of 5.
        if (epoch + 1) % 5 == 0 or epoch + 1 == epochs:
            ckpt_save_path = ckpt_manager.save()
            print('Saving checkpoint for epoch {} at {}'
                  .format(epoch + 1, ckpt_save_path))

        if generator_validation is not None:
            validation_loss = validate_epoch(
                transformer, generator_validation, embedding_function,
                maximum_position_encoding, loss_function, loss_object
            )
            # MB taken by process
            print('memory usesd: ' + str(_memory_usage_mb(process)) + 'MB')
            string = 'Epoch: {}  ||  Training loss: {:.7f}  ' \
                     '||  Validation loss {:.7f}'
            print(string.format(epoch + 1,
                                train_loss.result(), validation_loss.result()))

        t = time.time() - start
        print('Time taken for 1 epoch: {} secs\n'.format(t))
        # Overwritten every epoch: holds only the latest epoch's duration.
        write_json(checkpoint_path / "time.json", t)


def _as_array(items):
    """Convert a batch to a NumPy array, keeping ragged batches as object arrays.

    NumPy >= 1.24 raises ValueError for inhomogeneous nested sequences, where
    older releases built a 1-D ``dtype=object`` array (with a deprecation
    warning). The fallback reproduces that older result so ragged batches can
    still be filtered and padded.
    """
    try:
        return np.array(items)
    except ValueError:
        ragged = np.empty(len(items), dtype=object)
        for i, item in enumerate(items):
            ragged[i] = item
        return ragged


def _filter_targets(tar, beat_length):
    """Drop recordings whose target does not have ``beat_length`` on axis -2.

    A batch that already stacks into a 3-D array is returned unfiltered.

    Returns:
        ``(targets, keep_idx)``: ``targets`` is a 3-D array, and ``keep_idx``
        is a boolean mask over the original recordings, used to drop the
        matching inputs.

    Raises:
        ValueError: if the remaining targets do not form a 3-D array
            (for example when every recording was removed).
    """
    tar = _as_array(tar)
    if tar.ndim == 3:
        return tar, np.ones(len(tar), dtype=bool)

    # Explicit checks rather than `assert`, which `python -O` would strip and
    # thereby silently skip the filtering.
    keep_idx = np.zeros(len(tar), dtype=bool)
    kept = []
    for i, beat in enumerate(tar):
        if beat.shape[-2] != beat_length:
            print("Removing recording with target shape", beat.shape)
        else:
            keep_idx[i] = True
            kept.append(beat)
    tar = _as_array(kept)
    if tar.ndim != 3:
        raise ValueError(
            "Expected a 3-D target batch after filtering, got shape {}"
            .format(tar.shape)
        )
    return tar, keep_idx


def prepare_batch(
        generator, batch, embedding_function, maximum_position_encoding,
        beat_length=1000
):
    """Load batch ``batch`` from ``generator`` and make it model-ready.

    Recordings whose target does not have ``beat_length`` samples on axis -2
    are dropped (together with their inputs) when the targets do not stack
    into a 3-D array. The targets are then passed through ``rms_tf``. The
    inputs are zero-padded and masked up to ``maximum_position_encoding``.

    Args:
        generator: object supporting ``generator[batch] -> (inp, tar)``
        batch: batch index
        embedding_function: NOTE(review): accepted but unused; targets are
            always embedded with ``rms_tf`` (the original left
            ``embedding_function(...)`` commented out) — confirm intent.
        maximum_position_encoding: length the inputs are padded to
        beat_length: required size of axis -2 of each target (default 1000)

    Returns:
        ``(inp_padded, tar, padding_mask)`` as produced by ``zeropad``,
        ``rms_tf`` and ``create_padding_mask`` respectively.

    Raises:
        ValueError: if the filtered targets do not form a 3-D array.
    """
    inp, tar = generator[batch]
    tar, keep_idx = _filter_targets(tar, beat_length)
    tar = rms_tf(tar)  # embedding_function(tar)
    inp = _as_array(inp)[keep_idx]
    padding_mask = create_padding_mask(inp, maximum_position_encoding)
    inp_padded = zeropad(inp, maximum_position_encoding)
    return inp_padded, tar, padding_mask


def validate_epoch(
        model, generator, embedding_function, maximum_position_encoding,
        loss_function, loss_object
):
    """Run ``model`` in inference mode over every batch of ``generator``.

    Each batch goes through ``prepare_batch``. The model is called with its
    training flag set to False, and the per-batch loss from
    ``loss_function(tar, predictions, loss_object)`` is averaged.

    Returns:
        The ``tf.keras.metrics.Mean`` metric named ``'val_loss'``; read the
        mean loss with ``.result()``.
    """
    mean_loss = tf.keras.metrics.Mean(name='val_loss')
    print()
    for batch_idx in range(generator.n_batches):
        # Progress line, overwritten in place every 50 batches.
        if not batch_idx % 50:
            print("Validating batch:", batch_idx, end="\r")
        batch_inputs, batch_targets, batch_mask = prepare_batch(
            generator, batch_idx, embedding_function,
            maximum_position_encoding
        )
        preds, _attention = model(batch_inputs, False, batch_mask)
        mean_loss(loss_function(batch_targets, preds, loss_object))
    print()
    return mean_loss


def save_logs(input_params):
    """Write the run configuration to ``<checkpoint_folder>/<model_name>/configuration.json``.

    A deep copy is written, so ``input_params`` itself is not modified. In the
    copy, the callable ``embedding_function`` is replaced by its ``__name__``.
    """
    config = copy.deepcopy(input_params)
    log_dir = pathlib.Path(config['checkpoint_folder'], config['model_name'])
    log_dir.mkdir(parents=True, exist_ok=True)
    config['embedding_function'] = config['embedding_function'].__name__
    write_json(log_dir / "configuration.json", config)


def load_config(model_folder):
    """Read and return ``configuration.json`` located in ``model_folder``."""
    config_file = pathlib.Path(model_folder).joinpath("configuration.json")
    return read_json(config_file)


def _resolve_embedding_function(spec):
    """Turn a configured embedding function into a callable.

    An already-callable ``spec`` is returned unchanged. A bare name such as
    ``"rms_tf"`` is looked up in this module's globals without evaluating any
    code. Anything else falls back to ``eval``.
    """
    if callable(spec):
        return spec
    module_globals = globals()
    if isinstance(spec, str) and spec.isidentifier() and spec in module_globals:
        return module_globals[spec]
    # SECURITY NOTE(review): eval executes arbitrary code taken from the
    # configuration file; only use trusted configuration files.
    return eval(spec, module_globals)


def get_input_params(config_folder, config_name, checkpoint_folder):
    """Build the keyword arguments for a training run from a JSON config.

    Reads ``config_folder/config_name``. When its ``load_checkpoint`` entry is
    truthy, the saved configuration of ``checkpoint_folder/<model_name>`` is
    used instead. Only ``epochs``, ``gpu_id``, ``share_memory`` and
    ``load_checkpoint`` are taken from the new config. The result gets
    ``checkpoint_folder`` set and ``embedding_function`` resolved to a
    callable, and it is saved via ``save_logs`` before being returned.

    Raises:
        KeyError: if a required key is missing from the configuration.
    """
    config_params = read_json(pathlib.Path(config_folder) / config_name)
    if config_params['load_checkpoint']:
        input_params = load_config(
            pathlib.Path(checkpoint_folder) / config_params["model_name"]
        )
        # Run-specific settings always come from the new config file.
        for key in ('epochs', 'gpu_id', 'share_memory', 'load_checkpoint'):
            input_params[key] = config_params[key]
    else:
        input_params = config_params
    input_params['checkpoint_folder'] = checkpoint_folder
    input_params['embedding_function'] = _resolve_embedding_function(
        input_params['embedding_function']
    )
    save_logs(input_params)
    return input_params
