import pathlib
import tensorflow as tf
from .transformer_functions import CustomSchedule, TransformerEncoder
from .transformer_functions import (
    prepare_transformer,
    rms_tf,
    submitted_CNN,
    SubmittedCNN
)
from .io import read_json


def load_input_parameters(model_name, log_folder="."):
    """Load a model's ``configuration.json`` and resolve its embedding function.

    Parameters
    ----------
    model_name : str
        Name of the model sub-directory inside ``log_folder``.
    log_folder : str or os.PathLike, optional
        Directory that contains the per-model sub-directories. Defaults to
        the current directory.

    Returns
    -------
    dict
        The parsed configuration. When ``'embedding_function'`` holds a
        string, it is replaced by the object that string evaluates to in this
        module's namespace. Non-string values (e.g. JSON ``null``) are left
        unchanged instead of making ``eval`` raise ``TypeError``.

    Raises
    ------
    KeyError
        If the configuration has no ``'embedding_function'`` entry.
    """
    model_folder = pathlib.Path(log_folder) / model_name

    # Load the config file
    input_params = read_json(model_folder / "configuration.json")

    embedding = input_params['embedding_function']
    # SECURITY(review): ``eval`` executes arbitrary code taken from the config
    # file, with this module's globals in scope. Only load configurations
    # from trusted sources; consider replacing this with an explicit
    # name -> callable lookup. The string is presumably the name of one of
    # the embedding callables imported at the top of this module
    # (e.g. ``submitted_CNN``) — TODO confirm.
    if isinstance(embedding, str):
        input_params['embedding_function'] = eval(embedding)
    return input_params


def load_transformer(model_name, log_folder=".", n_classes=None):
    """Rebuild a ``TransformerEncoder`` from its saved configuration and restore weights.

    Parameters
    ----------
    model_name : str
        Name of the model sub-directory inside ``log_folder``. That directory
        holds ``configuration.json`` and the TensorFlow checkpoints.
    log_folder : str or os.PathLike, optional
        Directory that contains the per-model sub-directories. Defaults to
        the current directory.
    n_classes : int or None, optional
        If given, the encoder is passed through ``prepare_transformer`` with
        this number of classes before the checkpoint is restored. This uses
        the config's ``layers_to_freeze`` / ``layers_to_remove`` entries.

    Returns
    -------
    The (possibly prepared) transformer model. It holds the weights from the
    latest checkpoint in the model directory, or its fresh initialization if
    no checkpoint is found.
    """
    input_params = load_input_parameters(model_name, log_folder)
    # CheckpointManager is given a plain string: older TensorFlow file-IO
    # helpers reject ``pathlib.Path`` objects.
    checkpoint_path = str(pathlib.Path(log_folder) / model_name)

    # Define the model
    print("Define the model")
    learning_rate = CustomSchedule(input_params['d_model'])
    # The optimizer is tracked by the checkpoint next to the model,
    # presumably so the object graph matches the one saved during
    # training — TODO confirm.
    optimizer = tf.keras.optimizers.Adam(
        learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9
    )

    transformer_keys = (
        'num_layers', 'num_heads', 'd_model', 'dff',
        'maximum_position_encoding', 'embedding_function',
    )
    transformer_params = {key: input_params[key] for key in transformer_keys}

    transformer = TransformerEncoder(
        rate=input_params['dropout_rate'], **transformer_params
    )

    if n_classes is not None:
        transformer = prepare_transformer(
            transformer, input_params['layers_to_freeze'],
            input_params['layers_to_remove'], n_classes,
            input_params['embedding_function']
        )

    # Load the trained weights
    ckpt = tf.train.Checkpoint(transformer=transformer, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(
        ckpt, checkpoint_path, max_to_keep=5
    )
    # expect_partial() silences warnings about checkpoint values that are not
    # matched by the current objects. This is likely relevant when
    # prepare_transformer removed or replaced layers — TODO confirm.
    ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
    if ckpt_manager.latest_checkpoint:
        print("Restored from {}".format(ckpt_manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")
    return transformer
