from abc import ABC

import dataLoading
import dataReading
import modelDefinition
import modelEvaluation
import utilityFunctions
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.sampler import RandomSampler
import os
import time
import shutil


def start_simple_training(input_directory, save_dir, split_value=0.8, shuffle_seed=3, epochs=30):
    """Split the files of ``input_directory`` into train/test sets and train one network.

    Args:
        input_directory: Handed to ``dataReading.generate_path_list`` to obtain
            the list of file paths.
        save_dir: Directory handed to ``train_network``, which stores the
            saved models there.
        split_value: Fraction of the shuffled path list used for training; the
            remainder is used for testing. Must lie within ``[0, 1]``.
        shuffle_seed: Seed applied to the ``random`` module before shuffling,
            which makes the split reproducible.
        epochs: Number of training epochs handed to ``train_network``.

    Raises:
        ValueError: If ``split_value`` is outside ``[0, 1]`` or no input files
            were found.
    """
    # A negative fraction would silently become a negative slice index and a
    # fraction above one would silently leave the test set empty.
    if not 0 <= split_value <= 1:
        raise ValueError(f'split_value must be within [0, 1], got {split_value}.')

    # get the list of files
    path_list = dataReading.generate_path_list(input_directory)
    if len(path_list) == 0:
        raise ValueError(f'No input files found in "{input_directory}".')

    # make the train/test split
    # NOTE: the module-level generator is seeded on purpose (instead of using a
    # private random.Random instance) so that the global random state stays
    # seeded for any code that runs afterwards.
    random.seed(shuffle_seed)
    random.shuffle(path_list)
    split_int = int(len(path_list) * split_value)
    train_data, test_data = path_list[:split_int], path_list[split_int:]

    # train a single network
    train_network(train_data, test_data, save_dir, epochs=epochs)


def train_network(train_data, test_data, latest_subdir, epochs, batch_size=64, eval_interval=4096):
    """Train a ``GenericResNet`` and save the best performing weights.

    Samples are drawn in random order and passed through the network one at a
    time. Their losses are summed and, once ``batch_size`` samples have been
    accumulated, averaged and used for a single optimizer step. Samples left
    over at the end of an epoch are applied as a final, smaller batch.

    After every ``eval_interval`` samples of an epoch the network is evaluated
    on the test data; the sum of the first two values returned by
    ``modelEvaluation.evaluate_model`` (AUROC and AUPRC) is used as the
    performance value and the weights are saved whenever this value exceeds the
    maximum tracked so far.

    Args:
        train_data: Training items handed to
            ``dataLoading.CardiacAbnormalitiesDataset`` (``train=True``).
        test_data: Test items handed to
            ``dataLoading.CardiacAbnormalitiesDataset`` (``train=False``).
        latest_subdir: Directory in which the model files are stored. It is
            created if it does not exist yet.
        epochs: Number of passes over the training data.
        batch_size: Number of samples whose loss is accumulated before an
            optimizer step is taken.
        eval_interval: Number of samples between two evaluations on the test
            data. With the defaults an evaluation always directly follows an
            optimizer step; other values may evaluate in the middle of an
            accumulation window.

    Raises:
        ValueError: If ``batch_size`` or ``eval_interval`` is not positive.
        RuntimeError: If CUDA is not available.
    """
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}.')
    if eval_interval <= 0:
        raise ValueError(f'eval_interval must be positive, got {eval_interval}.')

    # set device (an explicit exception, because asserts are stripped under -O)
    if not torch.cuda.is_available():
        raise RuntimeError('Cuda should be available.')
    device = torch.device('cuda:0')

    # make sure the first model save does not fail on a missing directory
    os.makedirs(latest_subdir, exist_ok=True)

    # initialize value_savers
    performance_saver = utilityFunctions.ValueSaver()
    batch_timer = utilityFunctions.ValueSaver()

    # initialize the dataset
    train_data = dataLoading.CardiacAbnormalitiesDataset(train_data, 'classes.txt', train=True)
    train_data.get_class_list()
    test_data = dataLoading.CardiacAbnormalitiesDataset(test_data, 'classes.txt', train=False)

    # initialize the model
    net = modelDefinition.GenericResNet(features=16, classes=24, kernel_size=5, padding=2,
                                        down_kernel_size=5, down_padding=2)
    net = net.to(device)

    # count the parameters
    parameter_amount = sum(param.nelement() for param in net.parameters())
    print(f'The Model has {parameter_amount:,} parameters.')

    # initialize optimizer
    optimizer = optim.Adam(net.parameters(), lr=0.0001, weight_decay=1e-5)

    # initialize a loss
    criterion = nn.BCEWithLogitsLoss().to(device)

    for epoch in range(epochs):

        # make sure the network is in training mode at the start of every epoch
        net.train()

        # running loss and number of samples accumulated since the last step
        loss = torch.zeros(1, device=device)
        pending = 0

        # set the gradients to zero
        optimizer.zero_grad()

        # start a timer
        start = time.time()

        # reset the timer value saver
        batch_timer.reset()

        # a fresh sampler yields every dataset index once, in random order
        for i, sample_idx in enumerate(RandomSampler(train_data)):
            # each item is (inputs, labels, <unused>, dict with 'age' and 'sex')
            inputs, labels, _, spl = train_data[sample_idx]
            inputs, labels = inputs.to(device).unsqueeze(0), labels.to(device).unsqueeze(0)
            spl = torch.tensor([[spl['age'], spl['sex']]]).to(device).float()

            # forward pass, the loss is accumulated until the batch is complete
            outputs = net(inputs, spl)
            loss += criterion(outputs, labels)
            pending += 1

            if (i + 1) % batch_size == 0:

                # average the loss over the minibatch and make an optimizer step
                loss = _apply_accumulated_loss(loss, pending, optimizer)

                # print some statistics
                timing = ((time.time() - start) / (i + 1)) * batch_size
                batch_timer(timing / batch_size)
                remaining = utilityFunctions.convert_seconds(batch_timer.mean() * (len(train_data) - i))
                print(f'Iteration {i:05d} in epoch {epoch:04d}: {loss[-1].item():05.4f} ({timing:05.2f} s per batch). '
                      f'Remaining epoch time: {remaining}.')

                # reset the loss
                loss = torch.zeros(1, device=device)
                pending = 0

            if (i + 1) % eval_interval == 0:
                _evaluate_and_checkpoint(net, test_data, device, latest_subdir, performance_saver, f'{epoch}-{i}')

                # NOTE(review): evaluate_model is not visible here and may have
                # switched the network to eval mode, so training mode is
                # restored explicitly (a no-op if it was never changed).
                net.train()

        # apply the samples left over after the last full batch; otherwise their
        # forward passes would be discarded without ever contributing a gradient
        if pending:
            _apply_accumulated_loss(loss, pending, optimizer)


def _apply_accumulated_loss(loss, sample_count, optimizer):
    """Average the summed loss, backpropagate it and take one optimizer step.

    The gradients are zeroed afterwards. Returns the averaged loss tensor so
    that the caller can report it.
    """
    loss = loss / sample_count
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss


def _evaluate_and_checkpoint(net, test_data, device, save_dir, performance_saver, tag):
    """Evaluate ``net`` on the test data and save it if it is the best so far.

    The performance (AUROC + AUPRC) is recorded in ``performance_saver`` under
    the name ``tag``.
    """
    print('\nEvaluation on test data.')
    auroc, auprc, _, _, _, _, _ = modelEvaluation.evaluate_model(net, test_data, device, len(test_data))

    # calculate the comparison value
    performance = auroc + auprc

    # save the model if the performance is better (also overwrite the best_model so far)
    if performance > performance_saver.max()[1]:
        file_path = os.path.join(save_dir, 'performance_' + f'{int(performance * 1000):04d}')
        torch.save(net.state_dict(), file_path)
        shutil.copy(file_path, os.path.join(save_dir, 'best_model'))
        print(f'Model saved in "{file_path}"')

    # print the best model so far and save its performance
    performance_saver(performance, tag)
    best_performance = performance_saver.max()
    print(f'The best value is {best_performance[1]:05.4f} in [{" ".join(best_performance[0])}].')
