# -*- coding: utf-8 -*-
# @Time    : 2018/4/8 22:37
# @Author  : Xiaofeifei, Chen Ziwei, Du changping
# @File    : read_data.py

import os
import os.path
import scipy.io
import pandas as pd
import numpy as np
import tqdm
import random

#STEP = 256
# Upper bound on the first-axis length of a loaded ECG array: ReadData.load_ecg
# truncates any recording with at least this many rows to its first MAX_LEN rows.
MAX_LEN = 30000


def get_classes(files):
    """
    Collect every diagnosis code that appears on a ``#Dx`` line of the given files.

    :param files: iterable of paths to text header files
    :return: sorted list of the unique, whitespace-stripped codes
    """
    found = set()
    for header_path in files:
        with open(header_path, 'r') as handle:
            dx_rows = [row for row in handle if row.startswith('#Dx')]
        for row in dx_rows:
            # Everything after the first ': ' is a comma-separated list of codes.
            codes = row.split(': ')[1].split(',')
            found.update(code.strip() for code in codes)
    # NOTE: a fixed set of classes was previously hard-coded here instead of
    # being discovered from the files:
    # classes = {'270492004', '164889003', '164890007', '426627000',
    #            '713427006', '713426002', '445118002', '39732003',
    #            '164909002', '251146004', '698252002', '10370003',
    #            '284470004', '427172004', '164947007', '111975006',
    #            '164917005', '47665007', '59118001', '427393009',
    #            '426177001', '426783006', '427084000', '63593006',
    #            '164934002', '59931005', '17338001'}
    return sorted(found)

def get_true_labels(input_file, classes, max_labels=11):
    """
    Encode the diagnosis codes of one header file as a fixed-length vector.

    Each code on the file's ``#Dx`` line that is present in ``classes`` is
    written, in order of appearance, as its 1-based position in ``classes``.
    Unused slots stay 0. Codes that are not in ``classes`` are skipped. If the
    file has several ``#Dx`` lines, the last one determines the result.

    :param input_file: path to a text header file
    :param classes: sequence of known diagnosis codes (e.g. from get_classes)
    :param max_labels: length of the returned vector
    :return: float numpy array of shape (max_labels,); all zeros when the file
             contains no ``#Dx`` line
    :raises IndexError: if a ``#Dx`` line holds more than ``max_labels`` known codes
    """
    # Position of the first occurrence of every class, mirroring list.index()
    # while avoiding a linear scan per code.
    position = {}
    for idx, name in enumerate(classes):
        position.setdefault(name, idx)

    # Initialised up front so a file without a '#Dx' line yields zeros
    # instead of failing on an unbound variable.
    labels = np.zeros(max_labels)
    with open(input_file, 'r') as f:
        for line in f:
            if not line.startswith('#Dx'):
                continue
            labels = np.zeros(max_labels)
            slot = 0
            for code in line.split(': ')[1].split(','):
                idx = position.get(code.strip())
                if idx is None:
                    continue
                if slot >= max_labels:
                    raise IndexError(
                        'more than %d known diagnosis codes in %s'
                        % (max_labels, input_file))
                # 0 is reserved for "empty slot", hence the +1 offset.
                labels[slot] = idx + 1
                slot += 1
    return labels

class ReadData:
    """
    Reader for a directory holding ECG recordings (``.mat`` files) together
    with their text header files (names ending in ``hea``).
    """

    def __init__(self, raw_label=True):
        """
        :param raw_label: selects the reference file name stored on the
                          instance: 'REFERENCE.csv' when true,
                          'REFERENCE_CHECK.csv' otherwise
        """
        self.label_filename = 'REFERENCE.csv' if raw_label else 'REFERENCE_CHECK.csv'

    def load_labels(self, path):
        """
        Build the label matrix for every header file found in ``path``.

        :param path: directory containing the header files
        :return: numpy array with one row per header file, rows ordered by
                 sorted file path
        """
        # Header files: regular, non-hidden files whose name ends in 'hea'.
        label_files = []
        for f in os.listdir(path):
            g = os.path.join(path, f)
            if os.path.isfile(g) and not f.startswith('.') and f.lower().endswith('hea'):
                label_files.append(g)
        label_files.sort()

        classes = get_classes(label_files)
        return np.array([get_true_labels(g, classes) for g in label_files])

    def load_data(self, path):
        """
        Load every ``.mat`` recording found in ``path``.

        :param path: directory where the data is stored
        :return: a list of ECG arrays ordered by sorted file name
        """
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # recordings in the same order as the sorted headers in load_labels.
        data_file_names = sorted(
            name for name in os.listdir(path) if name.endswith('.mat'))
        print(data_file_names)
        if not data_file_names:
            return []
        data = []
        for name in tqdm.tqdm(data_file_names):
            data_path = os.path.join(path, name)     # path of this recording
            data.append(self.load_ecg(data_path))
        return data

    def load_ecg(self, path):
        """
        Read one recording and cap its length.

        :param path: path to a ``.mat`` file holding the signal under key 'val'
        :return: the transposed 'val' array, cut to at most MAX_LEN rows
        """
        sample = scipy.io.loadmat(path)
        # Slicing is a no-op for recordings shorter than MAX_LEN.
        return sample['val'].T[:MAX_LEN, :]

    def load_dataset(self, path):
        """
        :param path: directory containing both recordings and header files
        :return: tuple (data, labels) as produced by load_data and load_labels
        """
        data = self.load_data(path)
        labels = self.load_labels(path)
        return data, labels

    def preprocess_outlier_ecg(self, ecg, threshold=10):
        """
        Remove outliers from a 2-D ECG array.

        A row counts as an outlier row when any of its values has an absolute
        value >= ``threshold``. With at most 10 outlier rows the outlier
        values are simply zeroed. Otherwise the longest run of consecutive
        outlier-free rows is returned.

        :param ecg: raw ecg signal, indexed as (row, column)
        :param threshold: outlier threshold on the absolute signal value
        :return: ecg with outlier values zeroed (same shape), or the longest
                 outlier-free segment of rows
        """
        outlier_mask = np.abs(ecg) >= threshold
        bad_rows = np.flatnonzero(outlier_mask.any(axis=1))
        if len(bad_rows) <= 10:
            return np.where(outlier_mask, 0, ecg)

        # Number of clean rows before the first, after the last, and between
        # consecutive outlier rows. All three are true clean-row counts so
        # that they are directly comparable.
        head_len = bad_rows[0]
        tail_len = len(ecg) - bad_rows[-1] - 1
        gaps = np.diff(bad_rows) - 1
        widest = np.argmax(gaps)

        longest = max(head_len, tail_len, gaps[widest])
        # Ties are resolved in the order: tail, head, interior segment.
        if longest == tail_len:
            return ecg[bad_rows[-1] + 1:, :]
        if longest == head_len:
            return ecg[:bad_rows[0], :]
        return ecg[bad_rows[widest] + 1:bad_rows[widest + 1], :]
