# -*- coding: utf-8 -*-
# @Time    : 2020/7/17 10:10
# @Author  : Xiaofeifei, Tengfei Shen, Chen Ziwei, Du changping
# File:data_util.py

import random

import numpy as np
from sklearn.preprocessing import scale, MultiLabelBinarizer
# from scipy import signal
from read_data import MAX_LEN


class Preprocess_data:
    """Convert raw ECG records and label tuples into fixed-size arrays.

    Records are brought to ``MAX_LEN`` time steps (by self-repetition in
    :meth:`copy`, or by constant padding in :meth:`pad`) and labels are
    multi-hot encoded over the class ids 1..111.
    """

    def __init__(self, ecg, labels=None, channels=12):
        """
        :param ecg: raw ECG records; not used at the moment (only the
            disabled mean/std normalisation needed it)
        :param labels: pass any non-None value to enable label encoding
        :param channels: number of channels (leads) per record
        """
        self.channels = channels
        # self.mean, self.std = self.compute_mean_std(ecg)
        if labels is not None:
            self.classes = np.arange(1, 112)
            # `classes` must be passed by keyword: it is keyword-only in
            # current scikit-learn releases.
            self.mlb = MultiLabelBinarizer(classes=self.classes)

    def process(self, ecg, label):
        """Return ``(process_ecg(ecg), process_label(label))``."""
        return self.process_ecg(ecg), self.process_label(label)

    def process_ecg(self, ecg):
        """Return the records as one float32 array of ``MAX_LEN`` steps each.

        Records are extended by repeating themselves (see :meth:`copy`);
        zero padding (:meth:`pad`) and mean/std normalisation are disabled.
        """
        ecg = self.copy(ecg, dtype=np.float32)
        # ecg = self.pad(ecg, dtype=np.float32)
        # ecg = (ecg - self.mean) / self.std
        return ecg

    def process_label(self, label):
        """Multi-hot encode an iterable of label collections.

        Each element of ``label`` holds the class ids of one ECG record
        (original note: at most three per record); the result has one 0/1
        column per class, e.g. ``(0, 0, 1, 0, 0, 1, 0, 1, 0)``.

        Raises ``AttributeError`` if the instance was built with
        ``labels=None``, because no binarizer exists in that case.
        """
        return self.mlb.fit_transform(label)

    def pad(self, x, val=0, dtype=np.float32):
        """Right-pad every record with ``val`` up to ``MAX_LEN`` time steps.

        Records longer than ``MAX_LEN`` are truncated (consistent with
        :meth:`copy`) instead of raising a broadcasting error.

        :return: array of shape ``(len(x), MAX_LEN, channels)``
        """
        max_len = MAX_LEN
        padded = np.full((len(x), max_len, self.channels), val, dtype=dtype)
        for row, record in enumerate(x):
            keep = min(len(record), max_len)
            padded[row, :keep, :] = record[:keep]
        return padded

    def copy(self, data, val=0, dtype=np.float32):
        """Fill ``MAX_LEN`` time steps by repeating each record end to end.

        The final repetition is cut off at ``MAX_LEN``; a record longer than
        ``MAX_LEN`` is truncated. A zero-length record cannot be repeated, so
        its row is left filled with ``val`` (this used to loop forever).

        :return: array of shape ``(len(data), MAX_LEN, channels)``
        """
        tiled = np.full((len(data), MAX_LEN, self.channels), val, dtype=dtype)

        for row, record in enumerate(data):
            n = len(record)
            if n == 0:
                continue
            for start in range(0, MAX_LEN, n):
                stop = min(start + n, MAX_LEN)
                tiled[row, start:stop, :] = record[:stop - start]
        return tiled

    def compute_mean_std(self, ecg):
        """Compute the mean and standard deviation over all records.

        With ``channels == 1`` all records are concatenated and two float32
        scalars are returned; otherwise the statistics are computed per
        channel (column) and returned as two arrays of length ``channels``.
        """
        if self.channels == 1:
            flat = np.hstack(ecg)
            return np.mean(flat).astype(np.float32), np.std(flat).astype(np.float32)

        mean_ecg = np.zeros(self.channels)
        std_ecg = np.zeros(self.channels)
        for ch in range(self.channels):
            lead = np.hstack([record[:, ch] for record in ecg])
            mean_ecg[ch] = np.mean(lead).astype(np.float32)
            std_ecg[ch] = np.std(lead).astype(np.float32)
        return mean_ecg, std_ecg


def meanstd_scale(ecg):
    """Standardise an ECG array to zero mean and unit variance along time.

    The layout is guessed from the second dimension: if it equals 12 the
    input is treated as (time_step, channels) and scaled along axis 0,
    otherwise as (channels, time_step) and scaled along axis 1.
    """
    time_major = ecg.shape[1] == 12
    return scale(ecg, axis=0 if time_major else 1, with_mean=True, with_std=True)


def extend_signal(signal, length):
    """Zero-pad or truncate one ECG sample to exactly ``length`` time steps.

    :param signal: one sample, shape (time_steps, n_channels)
    :param length: number of time steps in the result
    :return: float array of shape (length, 12)
    """
    out = np.zeros((length, 12))
    keep = min(length, signal.shape[0])
    out[:keep] = signal[:keep]
    return out


def data_generator(batch_size, preproc, x, y):
    """Yield preprocessed ``(ecg, label)`` batches forever.

    Examples are sorted by signal length, cut into consecutive full batches
    (a trailing partial batch is dropped), and the batch order is shuffled
    once; the generator then cycles over the batches endlessly.

    :param batch_size: number of examples per batch
    :param preproc: object whose ``process(x, y)`` converts one batch
    :param x: sequence of ECG arrays, time on axis 0
    :param y: labels, aligned with ``x``
    :raises ValueError: if not even one full batch can be formed (the
        generator would otherwise spin forever without yielding anything)
    """
    # Sort by ECG signal length so each batch holds similar-length records.
    examples = sorted(zip(x, y), key=lambda pair: pair[0].shape[0])
    last_start = len(x) - batch_size + 1
    batches = [examples[i:i + batch_size] for i in range(0, last_start, batch_size)]
    if not batches:
        raise ValueError(
            "cannot form a full batch: %d examples, batch_size=%r" % (len(x), batch_size)
        )
    random.shuffle(batches)
    while True:
        for batch in batches:
            batch_x, batch_y = zip(*batch)
            yield preproc.process(batch_x, batch_y)


def data_generator2(batch_size, preproc, x, y, x_feat):
    """Yield ``([ecg_batch, feature_batch], label_batch)`` forever.

    Like ``data_generator`` but every example also carries a hand-crafted
    feature row from ``x_feat``. Signals, labels and features share ONE
    stable length ordering, so they stay aligned even when several signals
    have the same length (previously the features were ordered with
    ``np.argsort``, whose default sort is not stable, while the signals were
    ordered with the stable ``sorted``).

    :param batch_size: number of examples per batch
    :param preproc: object whose ``process(x, y)`` converts one batch
    :param x: sequence of ECG arrays, time on axis 0
    :param y: labels, aligned with ``x``
    :param x_feat: per-example feature rows, aligned with ``x``
    :raises ValueError: if not even one full batch can be formed (the
        generator would otherwise spin forever without yielding anything)
    """
    num_examples = len(x)
    # Sort example indices by ECG signal length (stable for ties).
    order = sorted(range(num_examples), key=lambda idx: x[idx].shape[0])
    last_start = num_examples - batch_size + 1
    batches = [order[i:i + batch_size] for i in range(0, last_start, batch_size)]
    if not batches:
        raise ValueError(
            "cannot form a full batch: %d examples, batch_size=%r" % (num_examples, batch_size)
        )
    # Fixed seed keeps the batch order reproducible; shuffling index batches
    # once replaces the former "reseed and shuffle two parallel lists".
    random.seed(12306)
    random.shuffle(batches)
    while True:
        for batch in batches:
            batch_x = tuple(x[i] for i in batch)
            batch_y = tuple(y[i] for i in batch)
            prec_x, prec_y = preproc.process(batch_x, batch_y)
            x_feat_hand = np.vstack([x_feat[i] for i in batch])
            yield ([prec_x, x_feat_hand], prec_y)


import scipy as sc


def random_resample(data, resample_rate=0.4):
    """Resample one ECG sample along time by linear interpolation.

    :param data: sample of shape (n_channel, time_steps)
    :param resample_rate: output length as a fraction of the input length
    :return: resampled data of shape (new_time_steps, n_channel), where
        ``new_time_steps = round(time_steps * resample_rate)``
    """
    # Imported here so the function does not depend on `scipy.interpolate`
    # having been loaded as a side effect of some other import.
    from scipy.interpolate import interp1d

    _, length = data.shape
    # np.linspace requires an integer sample count; length * rate is a float.
    new_length = int(round(length * resample_rate))
    interpol = interp1d(np.arange(length), data)
    resample_point = np.linspace(start=0, stop=length - 1, num=new_length)
    return interpol(resample_point).T


def bandpass_filter(data, lowcut=0.5, highcut=49, signal_freq=500, filter_order=1):
    """
    Design a Butterworth band-pass filter and apply it along the last axis.
    :param array data: raw data, shape (n_channels, time_steps)
    :param float lowcut: lower cut-off frequency (Hz)
    :param float highcut: upper cut-off frequency (Hz)
    :param int signal_freq: sampling frequency in samples per second (Hz)
    :param int filter_order: filter order
    :return array: filtered data, transposed to (time_steps, n_channels)
    """
    from scipy.signal import butter, lfilter
    nyquist_freq = 0.5 * signal_freq
    # butter() expects cut-offs normalised by the Nyquist frequency
    band = [lowcut / nyquist_freq, highcut / nyquist_freq]
    numerator, denominator = butter(filter_order, band, btype="band")
    # single forward pass (lfilter), not the zero-phase filtfilt
    return lfilter(numerator, denominator, data).T


def lowpass_filter(data, lowcut=50, signal_freq=500, filter_order=1):
    """
    Design a Butterworth low-pass filter and apply it along the last axis.
    :param array data: raw data, shape (n_channels, time_steps)
    :param float lowcut: cut-off frequency of the low-pass filter (Hz)
    :param int signal_freq: sampling frequency in samples per second (Hz)
    :param int filter_order: filter order
    :return array: filtered data, transposed to (time_steps, n_channels)
    """
    from scipy.signal import butter, lfilter
    # butter() expects the cut-off normalised by the Nyquist frequency
    cutoff = lowcut / (0.5 * signal_freq)
    numerator, denominator = butter(filter_order, cutoff, btype="lowpass")
    # single forward pass (lfilter), not the zero-phase filtfilt
    return lfilter(numerator, denominator, data).T


def ecg_spectrogram(data, fs=100, nperseg=16, noverlap=8, nfft=None):
    """Compute the log-magnitude short-time Fourier spectrogram of an ECG.

    :param data: one ECG sample, shape (n_channel, time_steps)
    :param fs: sampling frequency of the ECG signal
    :param nperseg: length of each window
    :param noverlap: number of samples shared by neighbouring windows
    :param nfft: FFT length, forwarded to scipy unchanged
    :return: array of shape (n_windows, n_channel * n_frequencies); strictly
        positive entries are replaced by their natural log, zeros stay zero

    Parameter details:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.spectrogram.html
    """
    # Local import: the module-level `from scipy import signal` is commented
    # out, so the former `signal.spectrogram` reference raised NameError.
    from scipy.signal import spectrogram

    _, _, Sxx = spectrogram(data, fs=fs, nperseg=nperseg, noverlap=noverlap, nfft=nfft)
    # (channels, freqs, windows) -> (channels * freqs, windows) -> (windows, channels * freqs)
    Sxx = np.abs(np.transpose(np.reshape(Sxx, (-1, Sxx.shape[2]))))
    # Take the log only where it is defined.
    positive = Sxx > 0
    Sxx[positive] = np.log(Sxx[positive])
    return Sxx


def data_split(data, labels, folds=5, valid_fold=1, seed=12306):
    """Split a data set into train/test parts, stratified by first label.

    Because the classes are imbalanced, the samples of every class (taken
    from column 0 of ``labels``) are shuffled and split separately, and the
    per-class parts are concatenated class by class.

    :param data: list of samples, one array per sample
    :param labels: 2-D array with one row per sample; column 0 selects the
        class used for stratification, and every entry > 0 in a row is kept
        as one of that sample's labels
    :param folds: number of cross-validation folds
    :param valid_fold: 1-based index of the fold used as the test part
    :param seed: seed for ``np.random``; the global generator is re-seeded
        with ``seed``, ``seed + 1``, ... for successive classes
    :return: ``(train_data, train_label, test_data, test_label)``; each label
        is a tuple such as ``(2.0,)``, ``(2.0, 5.0)`` or ``(3.0, 5.0, 7.0)``
    :raises ValueError: if ``valid_fold`` is not in ``1..folds``
    """
    primary = list(labels[:, 0])
    classes = sorted(set(primary))

    # Collect the sample positions belonging to every class.
    groups = {'index_' + str(clas): [] for clas in classes}
    for position, label in enumerate(primary):
        groups['index_' + str(label)].append(position)
    for clas in classes:
        print(clas,":",groups['index_' + str(clas)])

    def positive_labels(ind):
        # Keep only the strictly positive entries of the sample's label row.
        row = labels[ind]
        return tuple(row[np.where(row > 0)])

    train_data, train_label, test_data, test_label = [], [], [], []
    # Split every class into its training and test share.
    for clas in classes:
        np.random.seed(seed)
        shuffled = np.random.permutation(groups['index_' + str(clas)])

        # Fold ``valid_fold`` of the shuffled indices is held out; samples
        # beyond ``folds * fold_size`` always stay in the training part.
        fold_size = int(len(shuffled) / folds)
        if valid_fold <= 0 or valid_fold > folds:
            raise ValueError("输入的n_fold错误！")
        held_out = shuffled[(valid_fold - 1) * fold_size:valid_fold * fold_size]
        kept = np.setdiff1d(shuffled, held_out)  # sorted ascending

        train_data.extend([data[ind] for ind in kept])
        test_data.extend([data[ind] for ind in held_out])
        train_label.extend([positive_labels(ind) for ind in kept])
        test_label.extend([positive_labels(ind) for ind in held_out])
        seed += 1

    return (train_data, train_label, test_data, test_label)


def neighblock_denoising(signal, wavelet='db4', de_level=2):
    """Denoise a 1-D signal by block-wise shrinkage of wavelet details.

    The signal is decomposed with ``pywt.wavedec``; every detail level is cut
    into blocks of ``L0`` coefficients, and each block is scaled by
    ``max(0, (S - lambda_c * L) / S)``, where ``S`` is the energy of the block
    extended by ``L1`` coefficients on both sides. The approximation
    coefficients are kept unchanged.

    :param signal: raw signal
    :param wavelet: name of the base wavelet, e.g. 'sym8' or 'db4'
    :param de_level: decomposition level
    :return: denoised, reconstructed signal with the same number of samples
        as the input
    """
    import pywt
    base_wave = pywt.Wavelet(wavelet)
    coeffs = pywt.wavedec(signal, wavelet=base_wave, mode='symmetric', level=de_level)

    lambda_c = 4.505
    new_coeffs = [coeffs[0]]  # keep the approximation coefficients as they are

    for detail in coeffs[1:]:
        n = len(detail)
        # Block length; clamped to 1 because int(log(n) / 2) is 0 for n < 8,
        # which used to end in a division by zero.
        L0 = max(1, int(np.log(n) / 2))
        L1 = max(1, int(L0 / 2))
        L = L0 + 2 * L1

        # Write into a copy: shrinking `detail` in place would feed already
        # shrunk coefficients into the neighbourhood energy of later blocks.
        shrunk = detail.copy()
        for start in range(0, n, L0):
            stop = min(n, start + L0)
            neighbourhood = detail[max(0, start - L1):min(n, start + L0 + L1)]
            S = np.sum(np.square(neighbourhood))
            # An all-zero neighbourhood carries no signal: shrink to zero
            # rather than dividing by zero.
            factor = max(0, (S - lambda_c * L) / S) if S > 0 else 0.0
            shrunk[start:stop] = detail[start:stop] * factor

        new_coeffs.append(shrunk)

    rec_waves = pywt.waverec(new_coeffs, base_wave)
    # The reconstruction can be one sample longer than an odd-length input;
    # trim it so the output lines up with the input.
    return rec_waves[..., :np.shape(signal)[-1]]


def wavelet_denoise(waves, wavelet='db4', de_level=2):
    """Apply ``neighblock_denoising`` to every lead of a 12-lead signal.

    :param waves: raw signal, shape (time_steps, 12)
    :param wavelet: base wavelet name, forwarded unchanged
    :param de_level: decomposition level, forwarded unchanged
    :return: denoised signal with one column per lead
    """
    assert waves.shape[1] == 12
    denoised_leads = [
        np.reshape(
            neighblock_denoising(waves[:, lead], wavelet=wavelet, de_level=de_level),
            (-1,),
        )
        for lead in range(waves.shape[1])
    ]
    # Leads are stacked as rows, then transposed back to (time, lead).
    return np.vstack(np.array(denoised_leads)).T


from scipy.signal import medfilt


def median_remove_bw(signal):
    """Subtract a median-filter baseline estimate from every lead.

    The baseline of each column is estimated with two cascaded median
    filters (kernel sizes 99, then 299) and subtracted from that column.

    :param signal: array of shape (time_steps, 12)
    :return: float array of the same shape with the baseline removed
    """
    assert signal.shape[1] == 12
    cleaned = np.zeros(signal.shape)
    for lead in range(signal.shape[1]):
        column = signal[:, lead]
        cleaned[:, lead] = column - medfilt(medfilt(column, 99), 299)
    return cleaned

# def preprocess_outlier_ecg(ecg, threshold=10):
#     """
#     :param ecg: raw ecg signal
#     :param threshold: outlier threshold, any signal value greater than the threshold would be regarded as outlier
#     :return: remove outlier ecg, the segment of which not include outlier
#     """
#     ecg_mask = np.where(np.abs(ecg) >= threshold, 1, 0)
#     ecg_mask2 = np.sum(ecg_mask, axis=1)
#     ecg_mask2 = np.where(ecg_mask2 > 0)[0]
#     if len(ecg_mask2) <= 10:
#         process_ecg = np.where(np.abs(ecg) >= threshold, 0, ecg)
#     else:
#         ecg_mask3 = np.diff(ecg_mask2)
#         max_diff = np.max(ecg_mask3)
#         max_diff_ind = np.argmax(ecg_mask3)
#
#         interval1 = ecg_mask2[0]
#         interval2 = len(ecg) - ecg_mask2[-1]
#         max_interval = np.max([interval1, interval2, max_diff])
#         if (max_interval == interval2):
#             process_ecg = ecg[ecg_mask2[-1] + 1:, :]
#         elif (max_interval == interval1):
#             process_ecg = ecg[:ecg_mask2[0], :]
#         else:
#             process_ecg = ecg[ecg_mask2[max_diff_ind] + 1:ecg_mask2[max_diff_ind + 1], :]
#
#     return process_ecg
