from constants import *
from scipy import signal
import csv

from preprocessing import preprocessing_ecg


class Dataset(data.Dataset):
    """Dataset of padded ECG recordings paired with multi-hot label vectors.

    Two modes are selected by ``training``:

    * ``training is False``: the whole set is materialized in ``__init__``,
      either from the cached ``ecgs_test.pt`` / ``labels_test.pt`` files (when
      the module-level ``load_from_last`` flag is truthy) or by calling
      ``load_data_test_set``. Items are then served from memory.
    * any other value: nothing is preloaded; each item is read from disk,
      resampled, preprocessed and padded on access.

    NOTE(review): ``data`` comes from the star import; presumably it is
    ``torch.utils.data`` -- confirm.
    """

    def __init__(self, maxlen, number_samples, files_list, input_directory, training):
        """Store the configuration and, outside training mode, preload all items.

        Args:
            maxlen: Number of samples every lead is cropped or padded to.
            number_samples: Value reported by ``len()``.
            files_list: Recording file names, relative to ``input_directory``.
            input_directory: Directory containing the recordings.
            training: ``False`` selects the preloaded mode described on the class.
        """
        self.input_directory = input_directory
        self.files_list = files_list
        self.maxlen = maxlen
        self.len = number_samples
        self.training = training

        # Only the literal False triggers preloading (identity check, not truthiness).
        if training is not False:
            return

        if load_from_last:
            # Reuse the tensors written by a previous load_data_test_set run.
            self.ecgs = torch.load("ecgs_test.pt")
            self.labels = torch.load("labels_test.pt")
            return

        self.ecgs, self.labels = load_data_test_set(files_list, maxlen, self.input_directory)

    def __len__(self):
        """Return the sample count given at construction time."""
        return self.len

    def __getitem__(self, index):
        """Return the ``(ecg, multilabel)`` pair for ``index``."""
        if self.training is False:
            # Preloaded mode: plain lookup in the in-memory tensors.
            return self.ecgs[index], self.labels[index]

        record_path = os.path.join(self.input_directory, self.files_list[index])
        recording, header = load_challenge_data(record_path)
        recording = preprocessing_ecg(resample(record_path, recording), header)

        target = extract_label(header)

        padded = torch.FloatTensor(tensor_padding(recording, self.maxlen))
        return padded, target


def extract_label(header_data):
    """Build the multi-hot label vector described by a recording header.

    Every line starting with ``#Dx`` is parsed as ``"<prefix>: code,code,..."``.
    Codes listed in ``test_pathologies`` set the slot returned by
    ``encoding_labels`` for their mapped pathology name; the code
    ``'426783006'`` sets the extra last slot (index ``n_pathologies``).

    Args:
        header_data: Iterable of header lines (strings).

    Returns:
        A ``torch.long`` tensor of length ``n_pathologies + 1`` with 0/1 entries.
    """
    multilabel = torch.zeros(n_pathologies + 1, dtype=torch.long)

    dx_lines = (line for line in header_data if line.startswith('#Dx'))
    for dx_line in dx_lines:
        for raw_code in dx_line.split(': ')[1].split(','):
            code = raw_code.rstrip("\n")
            if code in test_pathologies:
                # transform() returns a one-element sequence; take its only entry.
                slot = encoding_labels(annot_to_patho[code])[0]
                multilabel[slot] = 1
            # NOTE(review): presumably the "normal" diagnosis code -- confirm.
            if code == '426783006':
                multilabel[n_pathologies] = 1

    return torch.LongTensor(multilabel)


def encoding_labels(label):
    """Encode one label with the module-level ``encodeur``.

    The label is wrapped in a one-element list because ``transform`` is called
    on a batch; the result is whatever ``encodeur.transform`` returns for it.
    """
    single_item_batch = [label]
    return encodeur.transform(single_item_batch)


def resample(file, data):
    """Resample a recording to 500 Hz, based on the database its path points to.

    The source sampling rate is inferred from the file path:

    * ``"Training_WFDB/S"`` -- PTB database, 1000 Hz
    * ``"Training_WFDB/I"`` -- St-Petersburg database, 257 Hz

    Any other path is returned untouched.

    The output length is derived from the sample-rate ratio. The previous
    implementation first floored the duration to whole seconds, which squeezed
    recordings with a fractional-second length into too few samples (an
    effective rate below 500 Hz) and produced a zero-length target for
    recordings shorter than one second.

    Args:
        file: Path of the recording; only used to detect the source database.
        data: Indexable collection of leads; ``data[i]`` is the 1-D signal of lead ``i``.

    Returns:
        ``data`` itself when no resampling applies, otherwise a list with the
        first ``n_leads`` leads resampled to 500 Hz.
    """
    target_rate = 500

    if "Training_WFDB/S" in file:
        # PTB database, sampling frequency is 1000 Hz
        source_rate = 1000
    elif "Training_WFDB/I" in file:
        # St-Petersburg database, sampling frequency is 257 Hz
        source_rate = 257
    else:
        return data

    # Keep the recording's real duration: samples_out / 500 == samples_in / source_rate.
    n_out = int(round(len(data[0]) * target_rate / source_rate))
    return [signal.resample(data[i], n_out) for i in range(n_leads)]


# Load challenge data.
def load_challenge_data(header_file):
    """Load one recording and its header lines.

    Accepts either the ``.hea`` header path or the ``.mat`` recording path.
    The callers in this module pass ``.mat`` paths (``extract_data`` appends
    ``.mat`` to every name), so a ``.mat`` argument is mapped to its sibling
    ``.hea`` file instead of being opened as text itself.

    Args:
        header_file: Path ending in ``.hea`` or ``.mat``.

    Returns:
        ``(recording, header)`` where ``recording`` is the ``'val'`` entry of
        the ``.mat`` file as a ``float64`` array and ``header`` is the list of
        lines read from the ``.hea`` file.
    """
    if header_file.endswith('.mat'):
        mat_file = header_file
        header_file = header_file[:-len('.mat')] + '.hea'
    else:
        mat_file = header_file.replace('.hea', '.mat')

    with open(header_file, 'r') as f:
        header = f.readlines()
    x = loadmat(mat_file)
    recording = np.asarray(x['val'], dtype=np.float64)
    return recording, header


# def load_challenge_data(filename):
#     x = loadmat(filename)
#     data = np.asarray(x['val'], dtype=np.float64)
#     new_file = filename.replace('.mat', '.hea')
#     input_header_file = os.path.join(new_file)
#     with open(input_header_file, 'r') as f:
#         header_data = f.readlines()
#     return data, header_data


# Minimum over the dataset is: -17561
# Maximum over the dataset is: 32767
# Receives an ECG of shape (12, len_ecg) and returns a tensor of shape (12, max_length).
def tensor_padding(ecg, max_length, pad_value=500):
    """Crop or right-pad every row of ``ecg`` to exactly ``max_length`` samples.

    Rows longer than ``max_length`` keep their first ``max_length`` samples;
    shorter rows are extended on the right with ``pad_value``.

    Args:
        ecg: Sequence of 1-D rows (lists, NumPy arrays, ...), one per lead.
        max_length: Number of samples per row in the result.
        pad_value: Filler for short rows. The default, 500, is the value that
            was previously hard-coded.

    Returns:
        A ``float32`` tensor of shape ``(len(ecg), max_length)``.
    """
    # Start from a buffer already filled with the padding, then overwrite the
    # leading part of each row; this avoids one torch.cat per row.
    result = torch.full((len(ecg), max_length), float(pad_value), dtype=torch.float32)
    for row_index, row in enumerate(ecg):
        kept = min(len(row), max_length)
        # ascontiguousarray accepts lists and strided arrays and casts to float32.
        values = np.ascontiguousarray(row[:kept], dtype=np.float32)
        result[row_index, :kept] = torch.from_numpy(values)
    return result


def check_number_samples(input_directory):
    """Count and list the usable recordings of the training and test splits.

    The candidate file names are read from ``train_filenames.csv`` and
    ``test_filenames.csv`` in the current working directory and each list is
    passed through ``extract_data``.

    Args:
        input_directory: Directory containing the recordings.

    Returns:
        ``(counter_training, counter_test, new_training, new_test)``: the two
        counts followed by the two filtered file lists from ``extract_data``.
    """

    def read_rows(csv_path):
        # One list of fields per CSV line.
        with open(csv_path, newline='') as handle:
            return list(csv.reader(handle))

    # NOTE: a filter that dropped rows whose first field did not start with
    # "HR" used to be applied to both lists here; it is currently disabled.
    training_rows = read_rows('train_filenames.csv')
    test_rows = read_rows('test_filenames.csv')

    n_training, kept_training = extract_data(training_rows, input_directory, test=False)
    n_test, kept_test = extract_data(test_rows, input_directory, test=True)

    return n_training, n_test, kept_training, kept_test


def _header_has_target_dx(header_data):
    """Return True if any ``#Dx`` line lists a code of interest."""
    for line in header_data:
        if not line.startswith('#Dx'):
            continue
        for raw_code in line.split(': ')[1].split(','):
            code = raw_code.rstrip("\n")
            if code in test_pathologies or code == '426783006':
                return True
    return False


def extract_data(data_list, input_directory, test=True):
    """Keep the recordings whose header lists at least one diagnosis of interest.

    A recording is kept when one of its ``#Dx`` codes is in
    ``test_pathologies`` or equals ``'426783006'``.

    Changes from the previous version:
    the ``.mat`` suffix is detected with ``endswith`` rather than by looking at
    the last character only; blank CSV rows are skipped instead of raising
    ``IndexError``; a file is listed once even if its header has several
    ``#Dx`` lines.

    Args:
        data_list: Rows as produced by ``csv.reader``; the first field of each
            row is a file name, with or without the ``.mat`` suffix.
        input_directory: Directory containing the recordings.
        test: Unused; kept so existing callers keep working.

    Returns:
        ``(counter, new_list)``: the number of kept files and their names
        (always carrying the ``.mat`` suffix).
    """
    new_list = []
    for row in data_list:
        if not row:
            # csv.reader yields an empty list for a blank line.
            continue
        name = row[0]
        if not name.endswith('.mat'):
            name = name + '.mat'

        tmp_input_file = os.path.join(input_directory, name)
        _, header_data = load_challenge_data(tmp_input_file)
        if _header_has_target_dx(header_data):
            new_list.append(name)
    return len(new_list), new_list


def load_data_test_set(input_files, max_length, input_directory):
    """Load, preprocess and stack a whole set of recordings, then cache it.

    Each file goes through ``load_challenge_data``, ``resample``,
    ``preprocessing_ecg`` and ``tensor_padding``; its label vector comes from
    ``extract_label``. The stacked tensors are written to ``ecgs_test.pt`` and
    ``labels_test.pt`` in the current working directory before being returned.

    Per-file tensors are collected in lists and concatenated once at the end.
    Concatenating inside the loop re-copied everything accumulated so far on
    every iteration.

    Args:
        input_files: File names relative to ``input_directory``.
        max_length: Number of samples every lead is cropped or padded to.
        input_directory: Directory containing the recordings.

    Returns:
        ``(ecgs, labels)`` with shapes ``(N, 12, max_length)`` and
        ``(N, n_pathologies + 1)``; ``N`` is 0 for an empty ``input_files``.
    """
    # Seeding with the empty tensors keeps the result shape/dtype for an empty
    # input and still makes torch.cat reject recordings without 12 rows.
    ecg_parts = [torch.empty((0, 12, max_length))]
    label_parts = [torch.empty(0, n_pathologies + 1, dtype=torch.long)]

    for f in input_files:
        tmp_input_file = os.path.join(input_directory, f)
        data_ecg, header_data = load_challenge_data(tmp_input_file)
        data_ecg = resample(tmp_input_file, data_ecg)

        ecg = preprocessing_ecg(data_ecg, header_data)
        ecg = torch.FloatTensor(tensor_padding(ecg, max_length))
        ecg_parts.append(ecg.unsqueeze(0))

        label_parts.append(extract_label(header_data).unsqueeze(0))

    ecgs = torch.cat(ecg_parts, dim=0)
    labels = torch.cat(label_parts, dim=0)

    torch.save(ecgs, "ecgs_test.pt")
    torch.save(labels, "labels_test.pt")
    return ecgs, labels
