# this file contains the functions and code to load the databases
import csv
import os
import re
from scipy.io import loadmat
import numpy as np
import warnings
import torch
import glob

'''
Some information regarding the databases.
There are six databases in total; each line lists the database name, its record-name pattern and its size:
CPSC: AXXXX (6877 Patients: 13754 files in total)
e: EXXXX (10344 Patients: 20688 files in total)
ptb: SXXXX (516 Patients: 1032 files in total)
ptb-xl: HRXXXXX (21837 Patients: 43674 files in total)
training_2: QXXXX (3453 Patients: 6906 files in total)
training_ptb: IXXXX (74 Patients: 148 files in total)

TOTAL: (43101 Patients: 86202 files)
'''


def load_classes_mapping(path, combine_equal_classes=True):
    """
    Load a csv file that maps SNOMED codes to class indices and back again.

    The first row of the file is treated as the table heading. For every following row the SNOMED code is taken from
    the second column (``row[1]``); the last column may contain a note naming equal classes, and every run of digits
    in that note is treated as a SNOMED code of an equal class. Blank rows are ignored.

    :param path: The path to the classes.txt file that contains the necessary information.
    :type path: str
    :param combine_equal_classes: This decides whether equal classes are mapped to the same integer.
    :type combine_equal_classes: bool
    :return: ``(code_to_int, int_to_code, text)`` where ``code_to_int`` maps a SNOMED code (str) to its class index,
        ``int_to_code`` maps a class index to the list of SNOMED codes sharing it and ``text`` is the raw file content.
    :rtype: tuple
    :raises ValueError: If ``combine_equal_classes`` is set and an odd number of rows carry an equality note.
    """
    import io

    # read the file once and parse the csv from the in-memory text (the raw text is returned as well)
    with open(path, 'r') as filet:
        text = filet.read()
    rows = list(csv.reader(io.StringIO(text)))

    # separate the heading of the table from the content; csv yields [] for blank lines, which we skip
    data = [row for row in rows[1:] if row]

    code_to_int = {}
    int_to_code = {}

    # plain mapping: every row gets its own consecutive class index
    if not combine_equal_classes:
        for class_counter, row in enumerate(data):
            code_to_int[row[1]] = class_counter
            int_to_code[class_counter] = [row[1]]
        return code_to_int, int_to_code, text

    # sanity check: equality notes come in pairs (A is noted as equal to B and B as equal to A)
    mapping_counter = sum(1 for row in data if row[-1])
    if mapping_counter % 2 != 0:
        raise ValueError('There is an odd number of mappings, which is not possible.')

    # combined mapping: equal classes share one integer
    class_counter = 0
    for row in data:
        code = row[1]
        equal_codes = re.findall(r'[0-9]+', row[-1])
        # reuse the index of the first equal code that is already mapped; deciding per listed code (instead of once
        # per row) could put one code into two classes or leave an unused class index behind
        shared = next((code_to_int[ele] for ele in equal_codes if ele in code_to_int), None)
        if shared is not None:
            code_to_int[code] = shared
            if code not in int_to_code[shared]:
                int_to_code[shared].append(code)
        else:
            code_to_int[code] = class_counter
            int_to_code[class_counter] = [code]
            # IMPORTANT: only a new class increases the class counter
            class_counter += 1

    return code_to_int, int_to_code, text


def code2name(code_list, path):
    """
    Map SNOMED codes to readable names using the class description file.

    :param code_list: This is a list of SNOMED codes as strings.
    :type code_list: list
    :param path: This defines the path to the class file that will be used for the mapping. It is read as csv with a
        heading row; the name is taken from the first column and the code from the second column.
    :type path: str
    :return: One string per input code: ``'<name> (<code>)'`` for codes listed in the file and
        ``'Not scored (<code>)'`` for all others.
    :rtype: list
    """

    # make a code-name dictionary
    code_to_name = {}
    with open(path) as filet:
        rows = list(csv.reader(filet))
    # cut the first row since these are headings; skip rows too short to hold a code (e.g. blank lines)
    for row in rows[1:]:
        if len(row) >= 2:
            code_to_name[row[1]] = row[0]

    # exact lookup: a substring test on the raw file text would also accept partial codes or codes that only appear
    # inside a note, and then fail with a KeyError
    return [f'{code_to_name[code]} ({code})' if code in code_to_name else f'Not scored ({code})'
            for code in code_list]


def load_challenge_data(path):
    """
    Load the ECG signal matrix stored in a MATLAB ``.mat`` file.

    The array is read from the ``'val'`` entry of the file.

    :param path: The path to the .mat file we want to read.
    :type path: str
    :return: The ``'val'`` array converted to float64.
    :rtype: numpy.ndarray
    """
    return np.asarray(loadmat(path)['val'], dtype=np.float64)


def load_header_text(path):
    """
    Read a .hea header file and return its lines.

    :param path: The path to the header file we want to read.
    :type path: str
    :return: The lines of the file with their line terminators kept (as ``readlines`` returns them).
    :rtype: list
    """
    with open(path, 'r') as header_file:
        return list(header_file)


def header2info(header, code_to_int, int_to_code, text, debug=False):
    """
    This function reads the information from a header file and extracts the necessary information.

    The header is expected to hold the record line at index 0, a signal line at index 1 (used for the gain) and the
    age, sex and diagnosis lines at indices 13, 14 and 15 (i.e. the record line followed by 12 signal lines).

    :param header: The lines of a header file, e.g. as returned by load_header_text.
    :type header: list
    :param code_to_int: Dictionary that maps SNOMED codes to class integers.
    :type code_to_int: dict
    :param int_to_code: Dictionary that maps class integers to SNOMED codes.
    :type int_to_code: dict
    :param text: The plain text of the classes.txt. Kept for backward compatibility; whether a diagnosis is scored
        is decided by an exact lookup in code_to_int (codes not in it are treated as unscored and skipped).
    :type text: str
    :param debug: With debug=True the function will print, when Age or Sex are not given in the header file
    :type debug: bool
    :return: ``(class_codes, class_index, one_hot, age, sex, identifier, spl_rate, gain)``; age and sex are -1 when
        they cannot be parsed, sex is 0 for male and 1 for female.
    :rtype: tuple
    :raises ValueError: If the sampling rate is below 250 Hz.
    """

    # define how to translate the sex (male -> 0, female -> 1)
    sex_list = {'Male': 0, 'male': 0, 'M': 0, 'm': 0, 'Female': 1, 'female': 1, 'F': 1, 'f': 1}

    # the identifier is the first token of the record line
    identifier = re.findall(r"[\w'-]+", header[0])[0]

    # the sampling rate is the third space-separated field of the record line; reject anything below 250 Hz
    spl_rate = int((header[0].split(' '))[2])
    if spl_rate < 250:
        raise ValueError(
            f'Sampling frequency is below 250 (real value: {spl_rate}) in file: [{identifier}].\n {header[0]}')

    # try to find an age (last token of the age line), otherwise fall back to -1
    age_line = header[13]
    try:
        age = int(re.findall(r"[\w'-]+", age_line)[-1])
    except (ValueError, IndexError):
        if debug:
            print(f'The Age in header file [{header[0][:-1]}] is not defined ({age_line[:-1]})!')
        age = -1

    # try to find the sex (last token of the sex line), otherwise fall back to -1
    sex_line = header[14]
    try:
        sex = sex_list[re.findall(r"[\w'-]+", sex_line)[-1]]
    except (KeyError, IndexError):
        if debug:
            print(f'The Sex in header file [{header[0][:-1]}] is not defined ({sex_line[:-1]})!')
        sex = -1

    # the diagnosis line lists the SNOMED codes after its label token
    class_codes = re.findall(r"[\w'-]+", header[15])[1:]

    # find the gain: third field of the first signal line, e.g. '1000/mV' -> 1000
    gain = int(header[1].split(' ')[2].split('/')[0])
    # NOTE(review): dataset-specific gain corrections; their origin is not visible in this file
    if gain > 10000:
        gain = gain / 1000
    if gain == 4880:
        gain = 2440

    # get the class index and the one hot vector; codes not in the mapping are unscored and skipped
    one_hot = torch.zeros(len(int_to_code))
    class_index = [code_to_int[ele] for ele in class_codes if ele in code_to_int]
    if class_index:
        one_hot[class_index] = 1

    return class_codes, class_index, one_hot, age, sex, identifier, spl_rate, gain


def header2text(header, path):
    """
    Print a readable summary of the information returned by header2info.

    :param header: The output of the header2info function
        ``(class_codes, class_index, one_hot, age, sex, identifier, spl_rate, gain)``.
    :type header: tuple
    :param path: The path to the classes.txt
    :type path: str
    """

    # translate all diagnosis codes in one call so the classes file is parsed once instead of once per code
    name_list = code2name(header[0], path)

    # header[4] is the sex code: 0 male, 1 female, anything else is reported as 'ERROR'
    sex = {0: 'Male', 1: 'Female'}.get(header[4], 'ERROR')

    # Make the print.
    print(f'\nThe following Information is for patient {header[5]}.')
    print('---------------------------------------------------')
    print(f'The patient is {sex} and {header[3]} years old.')
    print(f'The signal is sampled with {header[6]} Hz (Gain: {header[7]}/mV).')
    # gender-neutral wording: the summary is printed for female patients as well
    print('The patient shows the following arrhythmias:')
    for ele in name_list:
        print(f'\t - {ele}')
    print('---------------------------------------------------')


def generate_path_list(path):
    """
    This function finds all the .mat files in a directory and also checks whether all of them have .hea files.

    :param path: The path to the directory with the data
    :type path: str
    :return: The sorted record paths (directory joined with the file name) without the ``.mat`` extension.
    :rtype: list
    :raises ValueError: If a .mat file has no .hea file next to it.
    """

    # list all .mat files without their extension; sort so the order does not depend on the file system listing
    file_list = sorted(os.path.splitext(sg)[0] for sg in glob.glob(os.path.join(path, '*.mat')))

    # sanity check the directory: every signal needs a header file with the same base name
    missing = [sg for sg in file_list if not os.path.isfile(sg + '.hea')]
    if missing:
        raise ValueError(f'Not every signal has a header file (e.g. missing: {missing[0]}.hea).')

    return file_list


if __name__ == '__main__':

    # define the classes.txt path and an example header (machine-specific paths for this smoke test)
    pat = os.path.join(os.path.expanduser('~'), 'project/classes.txt')
    header_path = '/home/ubuntu/datasets/e/E09660.hea'

    # test the classes mapping
    out = load_classes_mapping(pat, combine_equal_classes=True)

    # read a random header
    hd = load_header_text(header_path)

    # test the header2info function (it expects the header lines, not the path of the header file)
    outp = header2info(hd, *out)

    # test the code2name function
    print(code2name(['164890007'], pat))

    # test the header2text function
    header2text(outp, pat)

    # test the generate_path_list function
    lst = generate_path_list('/home/ubuntu/datasets/all_data')
    print(f'Found {len(lst)} records with header files.')



