from torch.utils.data.sampler import WeightedRandomSampler

from Scoring_file import compute_challenge_metric, compute_confusion_matrices
from constants import *
from models import *
from metrics import *
from dataset import *
from utils import save_output
import time


# train over a batch of signals.
def train_signals_batch(label_batch, samples_batch, model, optimizer, clip_grad=1., train=True):
    """Run the model on one batch and, in training mode, take one optimizer step.

    The objective is the sum of the module-level ``third_criterion`` (fed the
    labels cast to the output's type) and ``criterion`` (fed the labels as-is).

    Args:
        label_batch: target labels for the batch.
        samples_batch: input signals; per the original note the shape is
            (batch_size, n_lead, len_ecg).
        model: callable network applied to ``samples_batch``.
        optimizer: optimizer stepped in training mode; untouched otherwise.
        clip_grad: maximum gradient norm; clipping is skipped when this is
            not strictly positive.
        train: only the literal ``False`` selects the gradient-free
            evaluation path; any other value runs a training step.

    Returns:
        Tuple ``(output, loss_value)`` with the raw model output and the loss
        as a Python number.
    """

    def _forward_with_loss():
        # Forward pass and combined objective, shared by both modes.
        predictions = model(samples_batch)
        objective = (third_criterion(predictions, label_batch.type_as(predictions))
                     + criterion(predictions, label_batch))
        return predictions, objective

    # Evaluation path: no graph is recorded and no parameter is updated.
    if train is False:
        with torch.no_grad():
            output, loss = _forward_with_loss()
        return output, loss.item()

    # Training path: backpropagate, optionally clip, then update the weights.
    output, loss = _forward_with_loss()
    optimizer.zero_grad()
    loss.backward()

    # Prevent exploding gradients
    if clip_grad > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

    optimizer.step()
    return output, loss.item()


def train_epoch(dl, model, lr_scheduler, optimizer, epoch_number, train=True, inference=False, batch_size=1024):
    """Run one pass over ``dl`` and collect per-batch metrics and predictions.

    Each batch is handed to ``train_signals_batch``; the raw outputs are then
    thresholded (``treshold_23`` for the first ``n_pathologies`` columns,
    ``treshold_snr`` for the last one) to obtain binary predictions, and the
    F2 score and challenge metric are computed per batch.

    Args:
        dl: iterable yielding ``(samples_batch, label_batch)`` pairs.
        model: network applied to every batch.
        lr_scheduler: accepted for interface compatibility; not used here.
        optimizer: forwarded to ``train_signals_batch``.
        epoch_number: accepted for interface compatibility; not used here.
        train: forwarded to ``train_signals_batch`` to choose between a
            training step and a gradient-free forward pass.
        inference: accepted for interface compatibility; not used here.
        batch_size: expected number of samples per batch. Iteration stops at
            the first batch whose leading dimension differs, so a trailing
            partial batch is dropped.

    Returns:
        Tuple ``(mean_f2, mean_challenge_metric, binary_outputs, raw_outputs)``.
        Both tensors are detached CPU tensors with one row per processed
        sample. If no batch is processed the tensors have zero rows and the
        two means are NaN (numpy's mean of an empty list).
    """
    # NOTE(review): the callers in this file build their DataLoader with the
    # module-level ``batch_size`` but leave this parameter at its 1024
    # default; if the two differ, the size check below stops before the first
    # batch is processed. Confirm the values agree.
    epoch_fbeta_arr, challenge_metric_arr = [], []

    # Seed the chunk lists with empty (0, n_pathologies + 1) tensors so that
    # the final concatenation is well defined even without any batch.
    binary_chunks = [torch.empty(0, n_pathologies + 1, dtype=torch.float)]
    output_chunks = [torch.empty(0, n_pathologies + 1, dtype=torch.float)]

    # The per-column decision thresholds do not depend on the batch content,
    # so build them once instead of on every iteration.
    tresh_23 = torch.ones((batch_size, n_pathologies)) * treshold_23
    tresh_snr = torch.ones((batch_size, 1)) * treshold_snr
    treshold_total = torch.cat((tresh_23, tresh_snr), dim=1)

    # Feed the model on the device its parameters live on, which can never
    # introduce a device mismatch; fall back to the configured ``device``
    # for a model without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    for samples_batch, label_batch in dl:
        if samples_batch.shape[0] != batch_size:
            break

        # Tensor.to() returns a new tensor rather than moving in place, so the
        # result has to be kept; discarding it leaves the batch where it was.
        samples_on_device = samples_batch.to(target_device)
        labels_on_device = label_batch.to(target_device)

        batch_output, _ = train_signals_batch(labels_on_device, samples_on_device, model, optimizer,
                                              clip_grad=1, train=train)

        # Drop the autograd graph before any bookkeeping: keeping graph-attached
        # outputs around would retain every batch's graph for the whole epoch.
        # Metrics and accumulation are done on CPU (numpy conversion needs it).
        batch_output = batch_output.detach().cpu()
        labels_cpu = label_batch.cpu()

        epoch_fbeta_arr.append(f2_score(labels_cpu, batch_output).item())

        # 1.0 where the output exceeds its column threshold, 0.0 elsewhere.
        binary_output = (batch_output > treshold_total).float()
        challenge_metric = compute_challenge_metric(weights, labels_cpu.numpy(), binary_output.numpy(),
                                                    compute_score, '426783006')
        challenge_metric_arr.append(challenge_metric)

        binary_chunks.append(binary_output)
        output_chunks.append(batch_output)

    # Compute stats for the whole epoch
    epoch_fbeta_score = np.mean(epoch_fbeta_arr)
    epoch_challenge_metric = np.mean(challenge_metric_arr)

    # A single concatenation instead of re-copying a growing tensor per batch.
    binary_result_model = torch.cat(binary_chunks, dim=0)
    real_result_model = torch.cat(output_chunks, dim=0)

    return epoch_fbeta_score, epoch_challenge_metric, binary_result_model, real_result_model


def train_model(trainset_dl, testset_dl, rnn):
    """Train ``rnn`` for ``epochs_number`` epochs, evaluating after each one.

    When ``load_model_name`` is not the empty string, the weights stored at
    that path are loaded into ``rnn`` before training starts. Every epoch runs
    one training pass over ``trainset_dl`` followed by one evaluation pass
    over ``testset_dl``.

    Args:
        trainset_dl: data loader used for the training pass.
        testset_dl: data loader used for the evaluation pass.
        rnn: model to train; it is modified in place.

    Returns:
        Tuple ``(max_score, rnn)`` where ``max_score`` is the highest
        evaluation F-beta score seen over all epochs (``0`` if no epoch
        scored above zero) and ``rnn`` is the trained model.
    """
    if load_model_name != "":
        rnn.load_state_dict(torch.load(load_model_name))
    optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)
    # NOTE(review): the scheduler is created and passed to train_epoch, but
    # nothing in this file ever calls lr_scheduler.step(), so the learning
    # rate is never reduced. Left as-is because stepping it would change the
    # training dynamics; confirm what metric (and mode) was intended.
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True)

    max_score = 0
    for i in range(epochs_number):
        rnn.train()
        train_epoch(trainset_dl, rnn, lr_scheduler, optimizer, i, train=True)

        rnn.eval()
        curr_fbeta_score, _, _, _ = train_epoch(testset_dl, rnn, lr_scheduler, optimizer, i, train=False)

        # Keep the best evaluation score; previously max_score was never
        # updated and the function always returned 0.
        # NOTE(review): the F-beta score is used as the tracked metric based
        # on the curr_/max_ naming - confirm it should not be the challenge
        # metric instead. A NaN score (no batch processed) never replaces it.
        if curr_fbeta_score > max_score:
            max_score = float(curr_fbeta_score)

    return max_score, rnn


def eval_model(testset_dl):
    """Evaluate the checkpoint at ``load_model_name`` on ``testset_dl``.

    A fresh ``ecg_classifier`` is built, the stored weights are loaded into
    it, it is switched to evaluation mode and run once over ``testset_dl``
    without training. The raw model outputs of that pass are handed to
    ``save_output``; the scores and binary predictions are discarded.

    Args:
        testset_dl: data loader providing the samples to evaluate.
    """
    classifier = ecg_classifier(dropout_cnn_final, hidden_size_final, 0.2)
    checkpoint = torch.load(load_model_name)
    classifier.load_state_dict(checkpoint)
    classifier.eval()

    # No scheduler or optimizer is needed for a gradient-free pass.
    _, _, _, raw_outputs = train_epoch(testset_dl, classifier,
                                       lr_scheduler=None, optimizer=None, epoch_number=0,
                                       train=False, inference=True)
    save_output(raw_outputs)


def run_model(max_length, num_training, num_test, training_list, test_list, input_directory):
    """Build the model and both data loaders, then delegate to ``train_model``.

    Args:
        max_length: forwarded to both ``Dataset`` constructors.
        num_training: forwarded to the training ``Dataset``.
        num_test: forwarded to the test ``Dataset``.
        training_list: forwarded to the training ``Dataset``.
        test_list: forwarded to the test ``Dataset``.
        input_directory: forwarded to both ``Dataset`` constructors.

    Returns:
        Whatever ``train_model`` returns for the freshly built model.
    """

    def _make_loader(source, shuffled):
        # Both loaders share the module-level batch size and skip pinned memory.
        return DataLoader(dataset=source,
                          batch_size=batch_size,
                          shuffle=shuffled,
                          pin_memory=False)

    network = ecg_classifier(dropout_cnn=dropout_cnn_final, hidden_size=hidden_size_final, dropout_gru=0.2)

    # Both datasets are created with training=True.
    train_data = Dataset(max_length, num_training, training_list, input_directory, training=True)
    test_data = Dataset(max_length, num_test, test_list, input_directory, training=True)

    # Only the training loader shuffles its samples.
    train_loader = _make_loader(train_data, True)
    test_loader = _make_loader(test_data, False)

    return train_model(train_loader, test_loader, network)
