import torch
import numpy as np
from time import time
#from ax.service.managed_loop import optimize
from ax import optimize
from bayes_opt import BayesianOptimization
from hyperopt import fmin, tpe, hp, rand
#import hyperopt


import lib.models.batcher_lib as batcher_lib
from lib.scoring.score_challenge_lib import compute_challenge_metric 

def train_model(model, X, Y, optimizer, batch_size=64, epochs=100,
                device='cpu', ids=None, verbose=True, grad_clip_norm=None,
                lambda_1=0, lambda_2=0):
    """Train ``model`` in place with BCE loss plus optional L1/L2 parameter penalties.

    Args:
        model: module whose ``forward`` output is fed to ``torch.nn.BCELoss``,
            so its outputs must lie in [0, 1].
        X, Y: inputs and multi-label targets, batched by ``batcher_lib.batcher``.
        optimizer: torch optimizer stepping ``model``'s parameters (presumably
            built over ``model.parameters()`` -- confirm at the call site).
        batch_size: mini-batch size passed to the batcher.
        epochs: number of passes over the data.
        device, ids: forwarded unchanged to ``batcher_lib.batcher``.
        verbose: print progress information when True.
        grad_clip_norm: if not None, maximum total gradient norm enforced with
            ``torch.nn.utils.clip_grad_norm_`` before each optimizer step.
        lambda_1: weight of the L1 norm of all parameters (divided by the
            total parameter count) added to the loss.
        lambda_2: weight of the L2 norm of all parameters (divided by the
            total parameter count) added to the loss. NOTE: this is the plain,
            unsquared norm, not the usual squared weight-decay penalty.

    Returns:
        list: per-epoch mean batch loss including the regularization terms
        (0-d tensors, or 0.0 for an epoch in which the batcher yielded nothing).
    """
    def print_x(*args):
        if verbose:
            print(*args)

    # Per-class negative counts, printed for diagnostics only. A class-weighted
    # BCE (weight = negatives / positives) used to be built here but was
    # immediately overwritten by the unweighted loss below, so it is not used.
    total = Y.shape[0]
    total_positives = torch.sum(torch.as_tensor(Y, dtype=torch.float), dim=0).to(device)
    print_x(total - total_positives)

    loss_fn = torch.nn.BCELoss()
    use_reg = bool(lambda_1) or bool(lambda_2)
    losses = []

    # Ensure training-mode behaviour (dropout, batch-norm updates) even if the
    # model was previously switched to eval mode, e.g. by evaluate_model.
    model.train()

    for epoch in range(epochs):
        print_x("---------- Running EPOCH --------", epoch + 1)
        total_loss = 0.0
        n_batches = 0
        batcher = batcher_lib.batcher(X, Y,
                                      batch_size=batch_size,
                                      device=device,
                                      ids=ids)

        for batch_idx, (x_batch, y_batch) in enumerate(batcher):
            if batch_idx % 100 == 0:
                print_x(batch_idx)
            y_pred = model.forward(x_batch)
            loss = loss_fn(y_pred, y_batch).mean()

            # Skip flattening every parameter when both penalties are disabled;
            # adding 0 * norm would not change the loss or its gradient.
            if use_reg:
                all_params = torch.cat([p.reshape(-1) for p in model.parameters()])
                n_params = all_params.shape[0]
                l1_reg = lambda_1 * torch.norm(all_params, 1) / n_params
                l2_reg = lambda_2 * torch.norm(all_params, 2) / n_params
                loss = loss + l1_reg + l2_reg

            total_loss += loss.detach()
            optimizer.zero_grad()
            loss.backward()
            if grad_clip_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
            optimizer.step()
            n_batches += 1

        # max(..., 1) avoids ZeroDivisionError when the batcher yields nothing.
        losses.append(total_loss / max(n_batches, 1))
        if epoch % 10 == 0:
            print_x(losses[-10:])

    return losses
        
        
@torch.no_grad()
def evaluate_model(model, X, Y, classes, weights, device='cpu', batch_size=512, ids=None,
                   thresholds=None):
    """Run ``model`` over (X, Y) and score it with the challenge metric.

    The model is put in eval mode for the forward passes. Its previous
    train/eval mode is restored afterwards, so a later training loop is not
    silently left in eval mode.

    Args:
        model: module applied batch-wise via ``model.forward``.
        X, Y: inputs and labels, batched by ``batcher_lib.batcher``.
        classes, weights: forwarded to ``compute_challenge_metric``.
        device, batch_size, ids: forwarded to ``batcher_lib.batcher``.
        thresholds: optional per-class thresholds. When given, predictions are
            binarized as ``Y_pred > thresholds`` before scoring. When None, the
            raw model outputs are passed to the metric unchanged.
            NOTE(review): confirm the metric accepts non-binary outputs.

    Returns:
        tuple: ``(Y_orig_cpu, Y_pred_cpu, val)`` -- labels and raw predictions
        as numpy arrays, plus the challenge-metric value.
    """
    was_training = model.training
    model.eval()
    preds = []
    labels = []
    try:
        batcher = batcher_lib.batcher(X, Y,
                                      batch_size=batch_size,
                                      device=device,
                                      ids=ids)
        for batch_idx, (x_batch, y_batch) in enumerate(batcher):
            if batch_idx % 100 == 0:
                print(batch_idx)
            preds.append(model.forward(x_batch))
            labels.append(y_batch)
    finally:
        # Restore the caller's mode even if a forward pass raises.
        model.train(was_training)

    Y_pred = torch.cat(preds)
    Y_orig = torch.cat(labels)
    print(Y_pred.shape)

    start = time()
    Y_orig_cpu = Y_orig.detach().cpu().numpy()
    Y_pred_cpu = Y_pred.detach().cpu().numpy()
    Y_pred_cpu_binary = Y_pred_cpu if thresholds is None else Y_pred_cpu > thresholds
    val = compute_challenge_metric(weights, Y_orig_cpu, Y_pred_cpu_binary, classes)
    print("Computing the challenge metric took: ", time() - start)
    print("Value: ", val)

    return Y_orig_cpu, Y_pred_cpu, val
    

def tune_thresholds(Y_orig, Y_pred, weights, classes, num_trials=200):
    """Search per-class thresholds in [0, 1] with Ax to maximize the challenge metric.

    Args:
        Y_orig: label matrix, one column per class.
        Y_pred: prediction matrix with the same columns; thresholds are
            compared against it as ``Y_pred > thresholds``.
        weights, classes: forwarded to ``compute_challenge_metric``.
        num_trials: total number of Ax trials (one arm per trial).

    Returns:
        numpy.ndarray: best thresholds found, one per class (column of Y_pred).
    """
    n_classes = Y_pred.shape[1]

    def get_score(params):
        # Ax passes {"<class index>": threshold, ...}.
        thresholds = np.zeros(n_classes)
        for idx, threshold in params.items():
            thresholds[int(idx)] = threshold
        score = compute_challenge_metric(weights, Y_orig,
                                         Y_pred > thresholds, classes)
        print(score)
        return score

    search_space = [
        {"name": str(i), "type": "range", "bounds": [0.0, 1.0]}
        for i in range(n_classes)
    ]
    best_parameters, values, experiment, model = optimize(
        parameters=search_space,
        evaluation_function=get_score,
        minimize=False,
        total_trials=num_trials,
        arms_per_trial=1,
    )
    print(best_parameters, values)

    best = np.zeros(n_classes)
    for idx, threshold in best_parameters.items():
        best[int(idx)] = threshold
    return best
    
    
def tune_thresholds_2(Y_orig, Y_pred, weights, classes, num_trials=200):
    """Search per-class thresholds in [0, 1] with bayes_opt to maximize the challenge metric.

    Args:
        Y_orig: label matrix, one column per class.
        Y_pred: prediction matrix with the same columns; thresholds are
            compared against it as ``Y_pred > thresholds``.
        weights, classes: forwarded to ``compute_challenge_metric``.
        num_trials: number of Bayesian-optimization iterations (``n_iter``)
            run after the 2 random initial points.

    Returns:
        numpy.ndarray: best thresholds found, one per class (column of Y_pred).
    """
    n_classes = Y_pred.shape[1]

    def get_score(**params):
        # bayes_opt passes each bound name ("<class index>") as a keyword.
        thresholds = np.zeros(n_classes)
        for idx, threshold in params.items():
            thresholds[int(idx)] = threshold
        score = compute_challenge_metric(weights, Y_orig,
                                         Y_pred > thresholds, classes)
        print(score)
        return score

    pbounds = {str(i): (0.0, 1.0) for i in range(n_classes)}

    optimizer = BayesianOptimization(
        f=get_score,
        pbounds=pbounds,
        verbose=2,
        random_state=1,
    )
    optimizer.maximize(
        init_points=2,
        n_iter=num_trials,
    )

    best = np.zeros(n_classes)
    for idx, threshold in optimizer.max["params"].items():
        best[int(idx)] = threshold
    return best
  

def tune_thresholds_3(Y_orig, Y_pred, weights, classes, num_trials=200):
    """Search per-class thresholds in [0, 1] with hyperopt TPE to maximize the challenge metric.

    Args:
        Y_orig: label matrix, one column per class.
        Y_pred: prediction matrix with the same columns; thresholds are
            compared against it as ``Y_pred > thresholds``.
        weights, classes: forwarded to ``compute_challenge_metric``.
        num_trials: ``max_evals`` passed to ``hyperopt.fmin``.

    Returns:
        numpy.ndarray: best thresholds found, one per class (column of Y_pred).
    """
    n_classes = Y_pred.shape[1]

    def get_score(params):
        # hyperopt passes one dict {"<class index>": threshold, ...}.
        thresholds = np.zeros(n_classes)
        for idx, threshold in params.items():
            thresholds[int(idx)] = threshold
        score = compute_challenge_metric(weights, Y_orig,
                                         Y_pred > thresholds, classes)
        print(score)
        # fmin minimizes, so negate the metric we want to maximize.
        return -score

    space = {str(i): hp.uniform(str(i), 0, 1) for i in range(n_classes)}

    best = fmin(fn=get_score,
                space=space,
                algo=tpe.suggest,
                max_evals=num_trials)

    print(best)

    thresholds = np.zeros(n_classes)
    for idx, threshold in best.items():
        thresholds[int(idx)] = threshold
    return thresholds
        

def tune_thresholds_4(Y_orig, Y_pred, weights, classes, percent=20):
    """Pick per-class thresholds heuristically from the positives' predicted scores.

    For each class, the predicted scores of samples labelled 1 are sorted. The
    value at index ``percent * n_pos // 100`` becomes the threshold, so roughly
    the top ``100 - percent`` percent of positives score above it. The
    resulting challenge metric is printed.

    A class with no positive labels gets threshold 1.0. Model outputs lie in
    [0, 1] (they are fed to BCELoss during training), so ``Y_pred > 1.0`` means
    that class is never predicted.

    Args:
        Y_orig: label matrix, one column per class.
        Y_pred: prediction matrix with the same columns.
        weights, classes: forwarded to ``compute_challenge_metric``.
        percent: percentage of each class's lowest-scoring positives to place
            at or below the threshold; the index is clamped to the valid range.

    Returns:
        numpy.ndarray: the chosen thresholds, one per class (column of Y_pred).
    """
    n_classes = Y_pred.shape[1]

    def get_score(params):
        thresholds = np.zeros(n_classes)
        for idx, threshold in params.items():
            thresholds[int(idx)] = threshold
        score = compute_challenge_metric(weights, Y_orig,
                                         Y_pred > thresholds, classes)
        print(score)
        return score

    scores = np.asarray(Y_pred)
    labels = np.asarray(Y_orig)

    parameters = {}
    for i in range(n_classes):
        positives = np.sort(scores[labels[:, i] == 1, i])
        n_pos = len(positives)
        if n_pos == 0:
            parameters[str(i)] = 1.0
            continue
        # Clamp so percent >= 100 (or < 0) cannot index past either end.
        k = min(max(percent * n_pos // 100, 0), n_pos - 1)
        parameters[str(i)] = positives[k]

    score = get_score(parameters)

    print(parameters)
    print("HEURISTIC", score)

    thresholds = np.zeros(n_classes)
    for idx, threshold in parameters.items():
        thresholds[int(idx)] = threshold
    return thresholds
        
        
        
def tune_thresholds_5(Y_orig, Y_pred, weights, classes, percent_skip_pos=5, num_trials=200):
    """Search per-class thresholds with hyperopt TPE inside data-driven bounds.

    For each class, the search range is the span of the positives' predicted
    scores after trimming ``percent_skip_pos`` percent of them from each end.
    Classes with no positives, or whose trimmed span collapses to a point,
    fall back to the full [0, 1] range.

    Args:
        Y_orig: label matrix, one column per class.
        Y_pred: prediction matrix with the same columns; thresholds are
            compared against it as ``Y_pred > thresholds``.
        weights, classes: forwarded to ``compute_challenge_metric``.
        percent_skip_pos: percentage of positives trimmed from each end
            when computing a class's search bounds.
        num_trials: ``max_evals`` passed to ``hyperopt.fmin``.

    Returns:
        numpy.ndarray: best thresholds found, one per class (column of Y_pred).
    """
    n_classes = Y_pred.shape[1]

    def get_score(params):
        # hyperopt passes one dict {"<class index>": threshold, ...}.
        thresholds = np.zeros(n_classes)
        for idx, threshold in params.items():
            thresholds[int(idx)] = threshold
        score = compute_challenge_metric(weights, Y_orig,
                                         Y_pred > thresholds, classes)
        print(score)
        # fmin minimizes, so negate the metric we want to maximize.
        return -score

    scores = np.asarray(Y_pred)
    labels = np.asarray(Y_orig)

    space = {}
    for i in range(n_classes):
        positives = np.sort(scores[labels[:, i] == 1, i])
        n_pos = len(positives)
        low, high = 0.0, 1.0
        if n_pos > 0:
            num_boundary = percent_skip_pos * n_pos // 100
            lo_idx = min(max(num_boundary, 0), n_pos - 1)
            hi_idx = max(n_pos - 1 - num_boundary, lo_idx)
            if positives[lo_idx] < positives[hi_idx]:
                low, high = float(positives[lo_idx]), float(positives[hi_idx])
        space[str(i)] = hp.uniform(str(i), low, high)

    best = fmin(fn=get_score,
                space=space,
                algo=tpe.suggest,
                max_evals=num_trials)

    thresholds = np.zeros(n_classes)
    for idx, threshold in best.items():
        thresholds[int(idx)] = threshold
    print(best)
    return thresholds
    
#from bayes_opt import SequentialDomainReductionTransformer
    
def tune_thresholds_6(Y_orig, Y_pred, weights, classes, percent_skip_pos=5, num_trials=200):
    """Search per-class thresholds with bayes_opt plus sequential domain reduction.

    The search bounds match ``tune_thresholds_5``: for each class, the span of
    the positives' predicted scores after trimming ``percent_skip_pos``
    percent from each end. The fallback is [0, 1] when a class has no
    positives or the span collapses; zero-width bounds would break domain
    reduction.

    Args:
        Y_orig: label matrix, one column per class.
        Y_pred: prediction matrix with the same columns; thresholds are
            compared against it as ``Y_pred > thresholds``.
        weights, classes: forwarded to ``compute_challenge_metric``.
        percent_skip_pos: percentage of positives trimmed from each end.
        num_trials: number of Bayesian-optimization iterations (``n_iter``)
            run after the 2 random initial points.

    Returns:
        numpy.ndarray: best thresholds found, one per class (column of Y_pred).
    """
    # Imported locally: the module-level import is commented out (presumably
    # so the module still loads on bayes_opt versions without this class),
    # which previously made this function fail with NameError.
    from bayes_opt import SequentialDomainReductionTransformer

    n_classes = Y_pred.shape[1]

    def get_score(**params):
        # bayes_opt passes each bound name ("<class index>") as a keyword.
        thresholds = np.zeros(n_classes)
        for idx, threshold in params.items():
            thresholds[int(idx)] = threshold
        score = compute_challenge_metric(weights, Y_orig,
                                         Y_pred > thresholds, classes)
        print(score)
        return score

    scores = np.asarray(Y_pred)
    labels = np.asarray(Y_orig)

    pbounds = {}
    for i in range(n_classes):
        positives = np.sort(scores[labels[:, i] == 1, i])
        n_pos = len(positives)
        low, high = 0.0, 1.0
        if n_pos > 0:
            num_boundary = percent_skip_pos * n_pos // 100
            lo_idx = min(max(num_boundary, 0), n_pos - 1)
            hi_idx = max(n_pos - 1 - num_boundary, lo_idx)
            if positives[lo_idx] < positives[hi_idx]:
                low, high = float(positives[lo_idx]), float(positives[hi_idx])
        pbounds[str(i)] = (low, high)

    optimizer = BayesianOptimization(
        f=get_score,
        pbounds=pbounds,
        verbose=2,
        random_state=1,
        bounds_transformer=SequentialDomainReductionTransformer(),
    )
    optimizer.maximize(
        init_points=2,
        n_iter=num_trials,
    )

    best = np.zeros(n_classes)
    for idx, threshold in optimizer.max["params"].items():
        best[int(idx)] = threshold
    return best
  


@torch.no_grad()
def calc_loss(model, X, Y, batch_size=64,
              device='cpu', ids=None):
    """Return the mean per-sample BCE loss of ``model`` on (X, Y).

    The element-wise BCE is summed over all outputs of each sample and then
    averaged over samples. The model runs in eval mode, so dropout is disabled
    and batch-norm running statistics are not updated by this measurement.
    Its previous train/eval mode is restored afterwards.

    Args:
        model: module applied batch-wise via ``model.forward``; its outputs
            must lie in [0, 1] for ``BCELoss``.
        X, Y: inputs and targets, batched by ``batcher_lib.batcher``.
        batch_size, device, ids: forwarded to ``batcher_lib.batcher``.

    Returns:
        numpy.ndarray: 0-d array holding the loss.

    Raises:
        ValueError: if the batcher yields no samples.
    """
    loss_fn = torch.nn.BCELoss(reduction='none')

    total_loss = 0.0
    total_pts = 0
    was_training = model.training
    model.eval()
    try:
        batcher = batcher_lib.batcher(X, Y,
                                      batch_size=batch_size,
                                      device=device,
                                      ids=ids)
        for x_batch, y_batch in batcher:
            y_pred = model.forward(x_batch)
            total_loss += loss_fn(y_pred, y_batch).sum()
            total_pts += y_pred.shape[0]
    finally:
        # Restore the caller's mode even if a forward pass raises.
        model.train(was_training)

    if total_pts == 0:
        raise ValueError("calc_loss: the batcher produced no samples")
    return (total_loss / total_pts).detach().cpu().numpy()

