import torch
import numpy as np
from hyperopt import fmin, tpe, hp, rand

from lib.models import attention_model
from lib.models import trainer_lib


def initialize_model(hparams, device):
    """Construct the attention model and place it on ``device``.

    The caller-supplied ``hparams`` are combined with a fixed set of
    architecture settings. On a key collision the fixed value wins, because
    the fixed settings are applied last.

    Args:
        hparams: mapping of hyperparameters supplied by the caller.
        device: passed to the ``Attention`` constructor and to ``.to()``.

    Returns:
        The ``attention_model.Attention`` instance after ``.to(device)``.
    """
    # Architecture settings that are not tuned. The alternative value that
    # was previously left commented out for both LSTM dims is 128.
    fixed_settings = {
        'lstm_input_dim': 512,
        'lstm_hidden_dim': 256,
        'lstm_num_layers': 1,
        'num_classes': 24,
        'use_layer_norm': True,
    }

    # Start from the caller's values, then let the fixed settings override.
    merged_hparams = dict(hparams)
    merged_hparams.update(fixed_settings)

    network = attention_model.Attention(
        dim_header=3,
        dim_peak=5,
        dim_peak_delta=1,
        dim_fft=100 * 12,
        hparams=merged_hparams,
        device=device,
    )
    return network.to(device)


def experiment(X, Y, batch_size, epoches, device,
               train_ids, test_ids):
    """Build the objective function evaluated once per tuning trial.

    Args:
        X: inputs, forwarded unchanged to ``trainer_lib``.
        Y: targets, forwarded unchanged to ``trainer_lib``.
        batch_size: batch size used for training (loss evaluation uses a
            fixed batch size of 128).
        epoches: number of training epochs (forwarded as ``epochs``).
        device: device used to build, train and evaluate the model.
        train_ids: ids forwarded as ``ids`` for training and for the
            train-loss evaluation.
        test_ids: ids forwarded as ``ids`` for the test-loss evaluation.

    Returns:
        A callable ``train_and_evaluate(hparams)`` that trains a freshly
        initialized model and returns its test loss as a plain Python
        scalar. ``hparams`` must contain ``'lr'``, ``'grad_clip_norm'``,
        ``'lambda_1'`` and ``'lambda_2'``; the whole mapping is also handed
        to ``initialize_model``.
    """
    def train_and_evaluate(hparams):
        print("________________________________")
        print("Experiment with hparams: ", hparams)
        # Every trial starts from a new model and a new optimizer.
        model = initialize_model(hparams, device=device)
        optimizer = torch.optim.Adam(model.parameters(), hparams['lr'])
        trainer_lib.train_model(
            model, X, Y, optimizer, device=device,
            batch_size=batch_size, epochs=epoches, ids=train_ids,
            verbose=False,
            grad_clip_norm=hparams['grad_clip_norm'],
            lambda_1=hparams['lambda_1'],
            lambda_2=hparams['lambda_2'])
        train_loss = trainer_lib.calc_loss(model, X, Y, batch_size=128,
                                           device=device, ids=train_ids)
        print("Train loss was: ", train_loss)
        test_loss = trainer_lib.calc_loss(model, X, Y, batch_size=128,
                                          device=device, ids=test_ids)
        print("Test loss was: ", test_loss)
        print("________________________________")
        # np.asscalar was deprecated in NumPy 1.16 and removed in 1.23. It
        # only ever called ``.item()``, so call that directly; fall back to
        # float() in case calc_loss already returns a plain Python number.
        if hasattr(test_loss, "item"):
            return test_loss.item()
        return float(test_loss)
    return train_and_evaluate


def tune_model(X, Y, batch_size, epoches,
               train_ids, test_ids,
               num_trials, device):
    """Search for hyperparameters with hyperopt's TPE algorithm.

    Args:
        X: inputs, forwarded to ``experiment``.
        Y: targets, forwarded to ``experiment``.
        batch_size: training batch size, forwarded to ``experiment``.
        epoches: number of training epochs per trial.
        train_ids: ids used for training within each trial.
        test_ids: ids whose loss is the quantity being minimized.
        num_trials: maximum number of trials (``max_evals`` of ``fmin``).
        device: device used to build and train the models.

    Returns:
        A dict mapping each hyperparameter name to the best value found.
        Values are actual parameter values; in particular ``'dropout'`` is
        the dropout rate itself, not its index in the list of options.
    """
    # Imported locally because the file-level hyperopt import does not
    # include space_eval.
    from hyperopt import space_eval

    dropout_options = [0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5]
    parameters = {
        # hp.loguniform bounds are in log space: exp(-10) .. exp(-5).
        'lr': hp.loguniform('lr', -10, -5),
        'grad_clip_norm': hp.uniform('grad_clip_norm', 0, 10),
        'lambda_1': hp.loguniform('lambda_1', -10, 2),
        'lambda_2': hp.loguniform('lambda_2', -10, 2),
        'dropout': hp.choice('dropout', dropout_options),
    }
    experiment_fn = experiment(X, Y, batch_size, epoches, device,
                               train_ids, test_ids)
    best = fmin(fn=experiment_fn,
                space=parameters,
                algo=tpe.suggest,
                max_evals=num_trials)
    # fmin reports hp.choice parameters as an index into the option list;
    # space_eval maps the raw result back onto real parameter values.
    best_params = space_eval(parameters, best)
    print("Best parameters", best_params)
    return best_params