#!/usr/bin/env python
# -*- coding: utf-8 -*-

import librosa
#from matplotlib import pyplot as plt
# import pywt

from scipy.signal import butter, lfilter
from helper_code import *

"""
Signal preprocessing: extract data features and labels.
"""

# Murmur label vocabulary; a label's index in this list is its one-hot position
# in the murmur label vectors built by getDataFeatures.
murmur_classes = ['Present', 'Unknown', 'Absent']
num_murmur_classes = len(murmur_classes)
# Outcome label vocabulary; a label's index in this list is its one-hot position
# in the outcome label vectors built by getDataFeatures.
outcome_classes = ['Abnormal', 'Normal']
num_outcome_classes = len(outcome_classes)

################################################################################
# data preprocess
################################################################################
def _one_hot(label, classes):
    """Return a one-hot int vector of length len(classes) for *label*.

    A label that is not in *classes* yields an all-zero vector.
    """
    encoded = np.zeros(len(classes), dtype=int)
    if label in classes:
        encoded[classes.index(label)] = 1
    return encoded


def getDataFeatures(data_folder, patient_files, verbose):
    """Load every patient and extract signal features plus one-hot labels.

    :param data_folder: folder passed to load_recordings for each patient
    :param patient_files: sequence of patient metadata files (load_patient_data input)
    :param verbose: progress is printed per patient when verbose >= 2
    :return: (features, murmur_labels, outcome_labels), three parallel lists.
        features[i] is the first value returned by get_PCG_features for
        patient i; the label entries are one-hot int arrays ordered as
        murmur_classes / outcome_classes (all zeros for an unknown label).
    """
    inner_all_features = []
    all_true_murmurs = []
    all_true_outcomes = []

    num_patients = len(patient_files)
    for position, patient_file in enumerate(patient_files, start=1):
        if verbose >= 2:
            print('    {}/{}...'.format(position, num_patients))

        # Load the current patient data and recordings.
        current_patient_data = load_patient_data(patient_file)
        current_frequencies = get_frequency(current_patient_data)
        current_recordings = load_recordings(data_folder, current_patient_data)

        # Extract features. Only the signal features are kept; the demographic
        # features returned alongside them are currently unused.
        single_features, _demographic_features = get_PCG_features(
            current_patient_data, current_recordings, current_frequencies)
        inner_all_features.append(single_features)

        # One-hot encode the murmur and outcome labels.
        all_true_murmurs.append(_one_hot(get_murmur(current_patient_data), murmur_classes))
        all_true_outcomes.append(_one_hot(get_outcome(current_patient_data), outcome_classes))

    return inner_all_features, all_true_murmurs, all_true_outcomes


################################################################################
# Extract features from the data.
################################################################################
def get_PCG_features(data, recordings, sample_Fs):
    """Turn one patient's recordings into a log-mel spectrum plus demographics.

    For each auscultation location in ('AV', 'MV', 'PV', 'TV', 'PhC') one
    10 s slot (10000 samples at 1000 Hz) is filled. The matching recording is
    down-sampled with myDownSample, truncated to the slot length and
    zero-padded with dataLenCheck. Locations without a recording stay zero.
    If a location was recorded more than once, the last recording wins. The
    five slots are concatenated into one 1-D signal and passed to
    Mel_Time_Frequency_Spectrum.

    :param data: patient metadata, as returned by load_patient_data
    :param recordings: recordings matched to get_locations(data) by position;
        when the two lists differ in length, all slots are left zero
    :param sample_Fs: sampling frequency of the recordings in Hz
    :return: (log-mel spectrum, [age, sex_features, height, weight, is_pregnant])
        where age is in months (NaN for an unknown age group) and
        sex_features is a one-hot [Female, Male] int vector
    """
    # Replace the age group with the (approximate) number of months for the
    # middle of the group. Groups are tried in order; the first match wins.
    age_group = get_age(data)
    age = float('nan')
    for group_name, months in (('Neonate', 0.5),
                               ('Infant', 6),
                               ('Child', 6 * 12),
                               ('Adolescent', 15 * 12),
                               ('Young Adult', 20 * 12)):
        if compare_strings(age_group, group_name):
            age = months
            break

    # Extract sex. Use one-hot encoding: [Female, Male].
    sex = get_sex(data)
    sex_features = np.zeros(2, dtype=int)
    if compare_strings(sex, 'Female'):
        sex_features[0] = 1
    elif compare_strings(sex, 'Male'):
        sex_features[1] = 1

    # Extract height, weight and pregnancy status.
    height = get_height(data)
    weight = get_weight(data)
    is_pregnant = get_pregnancy_status(data)

    # Store the down-sampled waveform of each location in its own slot.
    targetFreq = 1000  # down-sample to 1000 Hz
    slot_len = 10 * targetFreq  # 10 s per location -> 10000 samples
    locations = get_locations(data)
    recording_locations = ['AV', 'MV', 'PV', 'TV', 'PhC']
    recording_features = np.zeros((len(recording_locations), slot_len), dtype=float)
    if len(locations) == len(recordings):
        for location, recording in zip(locations, recordings):
            for j, slot_name in enumerate(recording_locations):
                if compare_strings(location, slot_name) and np.size(recording) > 0:
                    down_sampled = myDownSample(recording, sample_Fs, targetFreq)
                    recording_features[j] = dataLenCheck(down_sampled[:slot_len], slot_len)

    # Concatenate the slots in recording_locations order. The spectrogram is
    # computed over this joined signal, so frames around each slot boundary
    # cover two adjacent locations.
    recording_features = recording_features.flatten()

    # Optional 10-400 Hz band-pass, currently disabled:
    # recording_features = bandpass_filter(recording_features, 10, 400, targetFreq, 6)

    Mel_Spectrum = Mel_Time_Frequency_Spectrum(recording_features, targetFreq)
    return Mel_Spectrum, [age, sex_features, height, weight, is_pregnant]


################################################################################
# wavelet (CWT) time-frequency spectrum (currently disabled)
################################################################################
# def Time_Frequency_Spectrum(signal, Fs):
#     Wavename = 'morl'
#     Totalscale = 256
#     fc = pywt.central_frequency(Wavename)
#     Cparam = 2 * fc * Totalscale
#     Scales = Cparam / np.arange(Totalscale, 1, -1)
#     [Cwtmatr, frequencies] = pywt.cwt(signal, Scales, Wavename, 1.0 / Fs)
#     # plt.figure(3)
#     # plt.contourf(abs(Cwtmatr))
#     # plt.show()
#
#     return abs(Cwtmatr)

################################################################################
# log-mel time-frequency spectrum
################################################################################
def Mel_Time_Frequency_Spectrum(signal, Fs):
    """Return the natural-log mel spectrogram of *signal* sampled at *Fs* Hz.

    Uses 64 mel bands, a 20-sample window and an 8-sample hop. A small floor
    (1e-6) is added before the logarithm so all-zero frames stay finite.
    """
    mel_power = librosa.feature.melspectrogram(
        y=signal,
        sr=Fs,
        n_mels=64,
        hop_length=8,
        win_length=20,
    )
    # NOTE(review): n_fft is not set, so librosa's default FFT size is used
    # together with a 20-sample window — confirm this is intended.
    log_floor = 1E-6
    return np.log(mel_power + log_floor)
################################################################################
# data uniform
################################################################################
def data_uniform(totalFeatures, totalMurmurs, totalOutcomes):
    """Z-score every feature sample independently and convert labels to arrays.

    Each sample ``totalFeatures[i]`` is standardised with its own mean and
    standard deviation. NaN/inf values, e.g. from a constant sample whose
    standard deviation is 0, are replaced by 0.01.

    :param totalFeatures: non-empty sequence of equally-shaped feature arrays
    :param totalMurmurs: murmur labels, converted with np.array
    :param totalOutcomes: outcome labels, converted with np.array
    :return: (features as a float32 array, murmurs array, outcomes array,
        mean, std), where mean/std are the statistics of the LAST sample only
    :raises ValueError: if totalFeatures is empty (previously an
        UnboundLocalError)
    """
    # float32 instead of float64 to halve memory use.
    features = np.array(totalFeatures, dtype='float32')
    total_Murmurs = np.array(totalMurmurs)
    total_Outcomes = np.array(totalOutcomes)

    if len(features) == 0:
        raise ValueError("data_uniform: totalFeatures is empty")

    # Normalise into a preallocated buffer instead of building a Python list
    # and then copying it into a new array (which holds two full copies).
    uniform_total_features = np.empty_like(features)
    # 0/0 for constant samples is expected and cleaned up below, so silence the warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        for idx, sigFeature in enumerate(features):
            theMean = np.mean(sigFeature)
            theStd = np.std(sigFeature)
            uniform_total_features[idx] = (sigFeature - theMean) / theStd

    uniform_total_features[~np.isfinite(uniform_total_features)] = 0.01

    # NOTE(review): theMean/theStd belong to the last sample only; they are not
    # dataset-wide normalisation constants.
    return uniform_total_features, total_Murmurs, total_Outcomes, theMean, theStd

################################################################################
# bandpass filter
################################################################################
def bandpass_filter(data, lowcut, highcut, signal_freq, filter_order):
    """
    Design a Butterworth band-pass filter and apply it (causally) to the data.

    The filter is designed and applied as second-order sections (SOS). A
    band-pass of order N has effective order 2*N, and the single
    transfer-function ('ba') form is numerically unstable at such orders.
    The SciPy documentation recommends SOS for this reason.

    :param array_like data: 1-D raw signal
    :param float lowcut: lower cut-off frequency in Hz
    :param float highcut: upper cut-off frequency in Hz
    :param float signal_freq: sampling frequency in Hz
    :param int filter_order: Butterworth order N (the band-pass has order 2*N)
    :return numpy.ndarray: filtered data, same length as ``data``
    """
    # Local import: only `butter` and `lfilter` are imported at file level.
    from scipy.signal import sosfilt

    nyquist_freq = 0.5 * signal_freq
    low = lowcut / nyquist_freq
    high = highcut / nyquist_freq
    sos = butter(filter_order, [low, high], btype="band", output="sos")
    return sosfilt(sos, data)


################################################################################
# Down sampling
################################################################################
def myDownSample(data, sample_Fs, targetFreq):
    """Down-sample by keeping every ``sample_Fs // targetFreq``-th sample.

    This is plain decimation: no anti-aliasing low-pass filter is applied.
    When the rate ratio is not an integer, the effective output rate is
    ``sample_Fs / (sample_Fs // targetFreq)`` rather than ``targetFreq``.
    If ``sample_Fs < targetFreq`` the step would be 0, which previously
    raised ZeroDivisionError; the step is now clamped to 1, so the data is
    returned unchanged.

    :param data: indexable 1-D sequence supporting len() (list or ndarray)
    :param sample_Fs: original sampling frequency in Hz (int or float)
    :param targetFreq: desired sampling frequency in Hz
    :return list: the kept samples, as a Python list
    """
    step = max(1, int(sample_Fs // targetFreq))
    # Visit only the kept indices 0, step, 2*step, ... — the same samples as
    # testing `i % step == 0` for every index, without scanning all of them.
    return [data[i] for i in range(0, len(data), step)]


################################################################################
# data aligning
################################################################################
def dataLenCheck(data, window):
    """Zero-pad *data* to exactly *window* samples.

    :param data: sequence of samples (list or 1-D ndarray)
    :param int window: required length
    :return: ``data`` itself if it already has ``window`` samples; otherwise a
        new list with zeros appended. The caller's object is never mutated.
    :raises ValueError: if ``data`` is longer than ``window`` (callers are
        expected to truncate first)
    """
    missing = window - len(data)
    if missing == 0:
        return data
    if missing > 0:
        # Build a new list rather than `data += ...`. The in-place form
        # mutated the caller's list and failed on NumPy arrays (broadcast error).
        return list(data) + [0] * missing
    raise ValueError("Error in dataLenCheck")
