# This file implements the data loading for PyTorch.
from torch.utils.data import Dataset
import sys
import dataReading
import os
import utilityFunctions
from tqdm import tqdm
import torch
import dataPreprocessing
from scipy import signal
import numpy as np


class CardiacAbnormalitiesDataset(Dataset):
    """
    This is the pytorch dataset class for the challenge data.

    All header information is loaded once at construction time. The signals themselves are read from disk,
    scaled by their gain, resampled, band-pass filtered, cropped and (in train mode) augmented every time an
    element is requested.
    """

    def __init__(self, file_list, classes_path, train, target_sample_frequency=500):
        """
        This function initializes the Dataset.

        :param file_list: A list of full paths (str) to the dataset files, without the '.hea' extension
        :type file_list: list
        :param classes_path: The path to the classes.txt
        :type classes_path: str
        :param train: A boolean to tell the dataset if it is in train or test mode (augmentations are only
            applied in train mode)
        :type train: bool
        :param target_sample_frequency: The sampling frequency we wish to have for the signals. The band-pass
            filter is designed for this frequency.
        :type target_sample_frequency: int
        :raises ValueError: If the loaded header information does not line up with the file list.
        """

        # save the input variables
        self.file_list = file_list
        self.train = train
        self.classes_path = classes_path
        self.target_sample_frequency = target_sample_frequency

        # get the mapping information
        code_to_int, int_to_code, text = dataReading.load_classes_mapping(classes_path)
        self.code_to_int = code_to_int
        self.int_to_code = int_to_code
        self.class_text = text

        # extract the class amount
        self.class_amount = len(int_to_code)

        # load all the header information; every entry is the tuple returned by header2info extended by the
        # human-readable class names
        self.header_list = []
        for fil in tqdm(file_list, desc='Loading the header information'):
            header = dataReading.load_header_text(fil + '.hea')
            inf = dataReading.header2info(header, code_to_int, int_to_code, text)
            class_names = dataReading.code2name(inf[0], classes_path)
            self.header_list.append(inf + (class_names,))

        # print the size of the header list (sys.getsizeof is shallow, so this is a lower bound that does not
        # include the objects referenced by the tuples)
        mem = sum(sys.getsizeof(ele) for ele in self.header_list)
        print(f'The header list occupies {utilityFunctions.convert_size(mem)} RAM.')

        # make the band-pass filter (0.5 Hz - 45 Hz) for the signals. __getitem__ resamples every signal to
        # target_sample_frequency before filtering, so the filter has to be designed for that rate instead of
        # a fixed 500 Hz (with the default of 500 this is unchanged).
        self.sos = signal.butter(2, [0.5, 45], btype='band', analog=False, output='sos',
                                 fs=self.target_sample_frequency)

        # sanity check the file list in correspondence with the header info list
        if len(self.header_list) != len(self.file_list):
            raise ValueError('Path amount and Header amount are not the same.')
        for header_info, path in zip(self.header_list, self.file_list):
            # element 5 of the header tuple is the patient/record id, which has to match the file name
            if header_info[5] != os.path.basename(path):
                raise ValueError('The elements in the lists are not the same.')
        print('Dataset has been initialized with no errors.\n')

    def __len__(self):
        """Return the number of records in the dataset."""
        return len(self.header_list)

    def print_patient_info(self, idx):
        """
        Show the header information of one record (delegates to dataReading.header2text).

        :param idx: Index of the record
        :type idx: int
        """
        dataReading.header2text(self.header_list[idx], self.classes_path)

    def get_class_amount(self):
        """Return the number of classes in the class mapping."""
        return self.class_amount

    def int2code(self, idx):
        """
        This function maps class integers to SNOMED codes.

        :param idx: Single integer value to map a SNOMED code to
        :type idx: int
        :return: The entry of the class mapping for this integer (a code, or a list of codes, exactly as
            stored by dataReading.load_classes_mapping)
        :raises KeyError: If idx is not a known class integer
        """
        # int_to_code is a dict (see get_class_list), so it has to be indexed, not called
        return self.int_to_code[idx]

    def get_class_list(self):
        """
        Build a list of SNOMED codes ordered by class integer.

        If a class maps to a list of codes, only the first code of that list is used.
        Assumes the class integers are 0..n-1 (n = number of classes).

        :return: List whose i-th element is the code of class i
        :rtype: list
        :raises ValueError: If a position of the list stays unfilled
        """

        # make an empty list
        filler = '------'
        cls_lst = [filler] * len(self.int_to_code)

        # put in the codes
        for key, code in self.int_to_code.items():
            cls_lst[key] = code[0] if isinstance(code, list) else code

        # sanity check the cls_list
        if filler in cls_lst:
            raise ValueError('Something is wrong.')
        return cls_lst

    def __getitem__(self, idx):
        """
        Load, preprocess and return one record.

        The stored header tuple is unpacked as
        (codes, class_index, class_one_hot, age, sex, patient_id, sample_frequency, gain, class_names).
        Example of the header information for element 22 from all_data (sample_frequency and gain omitted):
        (['164865005', '426783006', '429622005', '59931005'],
        [20, 23],
        tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1.]),
        66,
        1,
        'HR17869',
        ['Not scored (164865005)', 'sinus rhythm (426783006)', 'Not scored (429622005)', 't wave inversion (59931005)'])

        :param idx: Index of the record
        :type idx: int
        :return: Tuple (signal, label, label_idx, sample): the filtered, cropped (and in train mode augmented)
            signal as float tensor, the one-hot label, the class indices and a dict holding the unfiltered and
            filtered signal together with the patient meta data
        :rtype: tuple
        """

        # get the header information
        _, class_index, class_one_hot, age, sex, patient_id, sample_frequency, gain, class_names = self.header_list[idx]
        label = class_one_hot
        label_idx = class_index

        # get the signal and scale it by the gain (with a sanity check against implausible gain values)
        assert gain < 5000, f'Gain is {gain} for patient {patient_id}.'
        unfiltered_sig = dataReading.load_challenge_data(self.file_list[idx]) / gain

        # resample the signal to the target frequency (the band-pass filter is designed for that rate)
        unfiltered_sig = dataPreprocessing.resample_signal(unfiltered_sig, sample_frequency,
                                                           self.target_sample_frequency)

        # filter the signal
        sig = dataPreprocessing.filter_sig(unfiltered_sig, self.sos)

        # make the sample
        sample = {'signal': unfiltered_sig, 'class_index': class_index, 'onehot': class_one_hot, 'age': age, 'sex': sex,
                  'id': patient_id, 'class_names': class_names, 'filtered_signal': sig,
                  'sample_frequency': sample_frequency}

        # crop the signal to a window of 4000 to 10000 samples (presumably a random window for long signals;
        # see dataPreprocessing.cropping)
        sig = dataPreprocessing.cropping(sig, minimum_length=4000, maximum_length=10000)

        if self.train:

            # apply the cutout
            sig = dataPreprocessing.cutout(sig, sig.shape[1] * 0.2, 0.35)

            # shift the signal for a little bit
            sig = dataPreprocessing.shifting(sig)

            # add a little noise
            sig = dataPreprocessing.noise(sig, 0.02)

            # make a little wiggle per channel
            sig = dataPreprocessing.wiggle_channels(sig, max_samples_rotation=5)

        # make a tensor out of the samples
        sig = torch.from_numpy(sig).float()

        return sig, label, label_idx, sample


if __name__ == '__main__':
    home_dir = os.path.expanduser('~')

    # collect the record paths and build a training-mode dataset from them
    record_paths = dataReading.generate_path_list(f'{home_dir}/datasets/all_data')
    dataset = CardiacAbnormalitiesDataset(record_paths, f'{home_dir}/project/classes.txt', True)
    print(f'We now have access to {len(dataset)} samples.')
    print(f'The samples have {dataset.get_class_amount()} different classes.')

    # fetch one element and show the meta data dict returned alongside the signal
    example_idx = 22
    _, _, _, meta = dataset[example_idx]
    print(meta['class_index'], meta['onehot'], meta['age'], meta['sex'], meta['id'],
          meta['sample_frequency'], meta['class_names'])
    dataset.print_patient_info(example_idx)
