#!/usr/bin/env python

import numpy as np
import pandas as pd
import scipy.signal as scipysi
# Added by Project Team
from scipy import signal
from scipy.signal import butter, lfilter, find_peaks, filtfilt, iirnotch

import constants
from joblib import Parallel, delayed
from sklearn.preprocessing import MinMaxScaler
import biosppy

#
# def detect_peaks(ecg_measurements,signal_frequency,gain):
#
#         """
#         Method responsible for extracting peaks from loaded ECG measurements data through measurements processing.
#
#         This implementation of a QRS Complex Detector is by no means a certified medical tool and should not be used in health monitoring.
#         It was created and used for experimental purposes in psychophysiology and psychology.
#         You can find more information in module documentation:
#         https://github.com/c-labpl/qrs_detector
#         If you use these modules in a research project, please consider citing it:
#         https://zenodo.org/record/583770
#         If you use these modules in any other project, please refer to MIT open-source license.
#
#         If you have any question on the implementation, please refer to:
#
#         Michal Sznajder (Jagiellonian University) - technical contact (msznajder@gmail.com)
#         Marta lukowska (Jagiellonian University)
#         Janko Slavic peak detection algorithm and implementation.
#         https://github.com/c-labpl/qrs_detector
#         https://github.com/jankoslavic/py-tools/tree/master/findpeaks
#
#         MIT License
#         Copyright (c) 2017 Michal Sznajder, Marta Lukowska
#
#         Permission is hereby granted, free of charge, to any person obtaining a copy
#         of this software and associated documentation files (the "Software"), to deal
#         in the Software without restriction, including without limitation the rights
#         to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#         copies of the Software, and to permit persons to whom the Software is
#         furnished to do so, subject to the following conditions:
#         The above copyright notice and this permission notice shall be included in all
#         copies or substantial portions of the Software.
#         THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#         IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#         FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#         AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#         LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#         OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#         SOFTWARE.
#
#         """
#
#
#         filter_lowcut = 0.001
#         filter_highcut = 15.0
#         filter_order = 1
#         integration_window = 30  # Change proportionally when adjusting frequency (in samples).
#         findpeaks_limit = 0.35
#         findpeaks_spacing = 100  # Change proportionally when adjusting frequency (in samples).
#         refractory_period = 240  # Change proportionally when adjusting frequency (in samples).
#         qrs_peak_filtering_factor = 0.125
#         noise_peak_filtering_factor = 0.125
#         qrs_noise_diff_weight = 0.25
#
#
#         # Detection results.
#         qrs_peaks_indices = np.array([], dtype=int)
#         noise_peaks_indices = np.array([], dtype=int)
#
#
#         # Measurements filtering - 0-15 Hz band pass filter.
#         filtered_ecg_measurements = bandpass_filter(ecg_measurements, lowcut=filter_lowcut, highcut=filter_highcut, signal_freq=signal_frequency, filter_order=filter_order)
#
#         filtered_ecg_measurements[:5] = filtered_ecg_measurements[5]
#
#         # Derivative - provides QRS slope information.
#         differentiated_ecg_measurements = np.ediff1d(filtered_ecg_measurements)
#
#         # Squaring - intensifies values received in derivative.
#         squared_ecg_measurements = differentiated_ecg_measurements ** 2
#
#         # Moving-window integration.
#         integrated_ecg_measurements = np.convolve(squared_ecg_measurements, np.ones(integration_window)/integration_window)
#
#         # Fiducial mark - peak detection on integrated measurements.
#         detected_peaks_indices = findpeaks(data=integrated_ecg_measurements,
#                                                      limit=findpeaks_limit,
#                                                      spacing=findpeaks_spacing)
#
#         detected_peaks_values = integrated_ecg_measurements[detected_peaks_indices]
#
#         return detected_peaks_values,detected_peaks_indices

#
# def detect_peaks_1(data,Fs, cutoff=0.05):
#     # filtered_lead = filter_lead(data,Fs)
#     b, a = iirnotch(cutoff, Q = 0.005, fs = Fs)
#     filter_wandering = filtfilt(b, a, data)
#     peaks, _ = find_peaks(filter_wandering, distance=Fs/2)
#     return peaks

#
# def bandpass_filter(data, lowcut, highcut, signal_freq, filter_order):
#         """
#         Method responsible for creating and applying Butterworth filter.
#         :param deque data: raw data
#         :param float lowcut: filter lowcut frequency value
#         :param float highcut: filter highcut frequency value
#         :param int signal_freq: signal frequency in samples per second (Hz)
#         :param int filter_order: filter order
#         :return array: filtered data
#         """
#         nyquist_freq = 0.5 * signal_freq
#         low = lowcut / nyquist_freq
#         high = highcut / nyquist_freq
#         b, a = butter(filter_order, [low, high], btype="band")
#         y = lfilter(b, a, data)
#         return y

# def filter_lead(data, Fs):
#     # bp_filter = bandpass_filter(np.array(data),
#     #                             lowcut=constants.FILTER_LOWCUT,
#     #                             highcut=constants.FILTER_HIGHCUT,
#     #                             signal_freq=Fs,
#     #                             filter_order=constants.FILTER_ORDER)
#     bp_filter = data
#     t = np.arange(start=0, stop=len(bp_filter), step=1)
#     yy2 = lowess(t, bp_filter, 0.1).transpose()[0]
#     return (bp_filter - yy2)


#
# def detect_peaks_1(data,Fs, cutoff=0.05):
#     # filtered_lead = filter_lead(data,Fs)
#     b, a = iirnotch(cutoff, Q = 0.005, fs = Fs)
#     filter_wandering = filtfilt(b, a, data)
#     peaks, _ = find_peaks(filter_wandering, distance=Fs/2)
#     return peaks

#
# def bandpass_filter(data, lowcut, highcut, signal_freq, filter_order):
#         """
#         Method responsible for creating and applying Butterworth filter.
#         :param deque data: raw data
#         :param float lowcut: filter lowcut frequency value
#         :param float highcut: filter highcut frequency value
#         :param int signal_freq: signal frequency in samples per second (Hz)
#         :param int filter_order: filter order
#         :return array: filtered data
#         """
#         nyquist_freq = 0.5 * signal_freq
#         low = lowcut / nyquist_freq
#         high = highcut / nyquist_freq
#         b, a = butter(filter_order, [low, high], btype="band")
#         y = lfilter(b, a, data)
#         return y


# def filter_lead(data, Fs):
#     bp_filter = bandpass_filter(np.array(data),
#                                 lowcut=constants.FILTER_LOWCUT,
#                                 highcut=constants.FILTER_HIGHCUT,
#                                 signal_freq=Fs,
#                                 filter_order=constants.FILTER_ORDER)
#     t = np.arange(start=0, stop=len(bp_filter), step=1)
#     yy2 = lowess(t, bp_filter, 0.1).transpose()[0]
#     return (bp_filter - yy2)
#
#
# def detect_peaks_1(data,Fs, cutoff=0.05):
#     # filtered_lead = filter_lead(data,Fs)
#     b, a = iirnotch(cutoff, Q = 0.005, fs = Fs)
#     filter_wandering = filtfilt(b, a, data)
#     peaks, _ = find_peaks(filter_wandering, distance=Fs/2)
#     return peaks



#
# def findpeaks(data, spacing=1, limit=None):
#         """
#         Janko Slavic peak detection algorithm and implementation.
#         https://github.com/jankoslavic/py-tools/tree/master/findpeaks
#         Finds peaks in `data` which are of `spacing` width and >=`limit`.
#         :param ndarray data: data
#         :param float spacing: minimum spacing to the next peak (should be 1 or more)
#         :param float limit: peaks should have value greater or equal
#         :return array: detected peaks indexes array
#         """
#         len = data.size
#         x = np.zeros(len + 2 * spacing)
#         x[:spacing] = data[0] - 1.e-6
#         x[-spacing:] = data[-1] - 1.e-6
#         x[spacing:spacing + len] = data
#         peak_candidate = np.zeros(len)
#         peak_candidate[:] = True
#         for s in range(spacing):
#             start = spacing - s - 1
#             h_b = x[start: start + len]  # before
#             start = spacing
#             h_c = x[start: start + len]  # central
#             start = spacing + s + 1
#             h_a = x[start: start + len]  # after
#             peak_candidate = np.logical_and(peak_candidate, np.logical_and(h_c > h_b, h_c > h_a))
#
#         ind = np.argwhere(peak_candidate)
#         ind = ind.reshape(ind.size)
#         if limit is not None:
#             ind = ind[data[ind] > limit]
#         return ind


def parse_hea_file(header_data):
    """Extract recording metadata from the lines of a ``.hea`` header.

    Layout read by this function (space-separated fields):

    * line 0: ``<record id> <number of leads> <sampling rate> ...``
    * lines 1..num_leads: one lead each; field 2 holds ``<gain>/<unit>`` and
      field 8 holds the lead name.
    * any line starting with ``#Age: ``, ``#Sex: `` or ``#Dx: `` provides the
      age, the sex and the comma-separated diagnosis codes (the last such
      line of each kind wins).

    :param header_data: sequence of header lines; trailing newlines are fine.
    :return: ``(num_leads, ptID, gender, age, sample_Fs, lead_info, classes)``.
        ``lead_info`` is a DataFrame with one row per lead and the columns
        ``idx``, ``gain`` and ``lead``.  ``gender`` stays ``None`` without a
        ``#Sex`` line (any value other than ``Female`` maps to
        ``constants.MALE``); ``age`` is 0 when missing or ``NaN``; ``classes``
        is empty without a ``#Dx`` line.
    """
    record_fields = header_data[0].split(' ')
    ptID = record_fields[0]
    num_leads = int(record_fields[1])
    sample_Fs = int(record_fields[2])

    lead_rows = []
    for lead_idx in range(num_leads):
        lead_fields = header_data[lead_idx + 1].split(' ')
        row = {
            'idx': lead_idx,
            'gain': int(lead_fields[2].split('/')[0]),
            'lead': lead_fields[8].strip(),
        }
        lead_rows.append(pd.Series(row, index=row.keys()))

    gender, age, classes = None, 0, []
    for line in header_data:
        if line.startswith('#Age: '):
            raw_age = line.split(': ')[1].strip()
            age = 0 if raw_age == 'NaN' else int(raw_age)
        elif line.startswith('#Sex: '):
            is_female = line.split(': ')[1].strip() == 'Female'
            gender = constants.FEMALE if is_female else constants.MALE
        elif line.startswith('#Dx: '):
            classes = [code.strip() for code in line.split(': ')[1].split(',')]

    # One Series per lead becomes one column; transposing turns each lead into a row.
    lead_info = pd.concat(lead_rows, axis=1).T
    return num_leads, ptID, gender, age, sample_Fs, lead_info, classes

def filter_lead(data, Fs, lowpass_cutoff=35.0, notch_freq=0.05, notch_q=0.005):
    """Denoise one ECG lead with zero-phase (forward-backward) filtering.

    Stage 1 is a 3rd-order Butterworth low-pass at ``lowpass_cutoff`` Hz.
    Stage 2 is an IIR notch centred on ``notch_freq`` Hz.  SciPy defines the
    quality factor as ``Q = w0 / bw``, so with the defaults the -3 dB
    bandwidth is 0.05 / 0.005 = 10 Hz; the "notch" therefore behaves more like
    a wide low-frequency (baseline) suppressor than a narrow notch.  Both
    stages use ``filtfilt``, so no phase shift is introduced.

    :param data: samples of a single lead (1-D array-like).
    :param Fs: sampling frequency of ``data`` in Hz; both filters are designed
        at this rate.
    :param lowpass_cutoff: low-pass cut-off in Hz (must be below ``Fs / 2``).
    :param notch_freq: notch centre frequency in Hz (must lie in ``(0, Fs / 2)``).
    :param notch_q: notch quality factor.
    :return: filtered samples, same length as ``data``.
    """
    nyquist_freq = 0.5 * Fs
    b_low, a_low = butter(3, lowpass_cutoff / nyquist_freq, btype='lowpass', analog=False)
    lowpassed = signal.filtfilt(b_low, a_low, data)
    # Design the notch at the lead's real sampling rate.  It used to be
    # hard-coded to fs=1000, which shifted the notch centre and bandwidth (in Hz)
    # for every recording not sampled at 1 kHz.
    b_notch, a_notch = signal.iirnotch(notch_freq, Q=notch_q, fs=Fs)
    return signal.filtfilt(b_notch, a_notch, lowpassed)

####### Added by Arnold: target-label encoding plus basic feature extraction from peaks and valleys
def get_target_classes(targets):
    """Encode diagnosis codes as a multi-hot vector over ``constants.LABELS``.

    A code that is a key of ``constants.REMAP`` is first translated to its
    mapped label; otherwise it is used as-is when it appears in
    ``constants.LABELS``.  Codes matching neither are ignored.

    :param targets: iterable of diagnosis-code strings (e.g. ``classes`` from
        ``parse_hea_file``).
    :return: float ndarray of length ``len(constants.LABELS)`` holding 1 at
        every position whose label was seen and 0 elsewhere.
    """
    labels = constants.LABELS
    remap = constants.REMAP
    multi_hot = np.zeros(len(labels))
    for code in targets:
        if code in remap.keys():
            label = remap[code]
        elif code in labels:
            label = code
        else:
            continue
        multi_hot[labels.index(label)] = 1
    return multi_hot

# def get_target_classes_level1(targets):
#     ll = np.zeros(len(constants.LABELS_LEVLE_CONDITION))
#     for tg in targets:
#         if tg in constants.REMAP.keys():
#             ll[constants.LABELS.index(constants.REMAP[tg])] = 1
#         elif tg in constants.LABELS_LEVLE_CONDITION:
#             ll[constants.LABELS_LEVLE_CONDITION.index(tg)] = 1
#     return ll
#
# def get_target_classes_level2(targets):
#     ll = np.zeros(len(constants.LABELS_LEVEL_RHYTHM))
#     for tg in targets:
#         #if tg in constants.REMAP.keys():
#         #    ll[constants.LABELS.index(constants.REMAP[tg])] = 1
#         if tg in constants.LABELS_LEVEL_RHYTHM:
#             ll[constants.LABELS_LEVEL_RHYTHM.index(tg)] = 1
#     return ll
#
# def get_target_classes_NA(targets):
#     ll = 1
#     #if (targets not in constants.LABELS_LEVEL_RHYTHM) or ('426783006' in targets):
#     if ('426783006' == targets[0]) and (len(targets)==1):
#         ll =0
#     return ll

def _extrema_table(search_signal, locations, heights, is_peak):
    """Build the property table for one polarity of extrema (peaks or valleys).

    :param search_signal: signal in which ``locations`` are local maxima (the
        lead itself for peaks, its negation for valleys).
    :param locations: sample indices of the extrema.
    :param heights: values reported in the ``Height`` column (the raw,
        un-negated lead samples at ``locations``).
    :param is_peak: constant written to the ``IsPeak`` column.
    """
    prominences = scipysi.peak_prominences(search_signal, locations)
    # Reuse the prominence data instead of letting peak_widths recompute it.
    half_widths = scipysi.peak_widths(search_signal, locations, rel_height=0.5,
                                      prominence_data=prominences)
    full_widths = scipysi.peak_widths(search_signal, locations, rel_height=1,
                                      prominence_data=prominences)
    table = pd.DataFrame(
        {'Location': locations,
         'Prom': prominences[0],
         'PromLB': prominences[1] - locations,
         'PromRB': prominences[2] - locations,
         'HalfWidth': half_widths[0],
         # NOTE(review): Prom - Height, kept as originally defined; scipy's contour
         # line height would be Height - Prom for peaks.
         'ContourHeight': prominences[0] - heights,
         'FullWidth': full_widths[0],
         'Height': heights})
    table['IsPeak'] = is_peak
    return table


def LocatePeaks(SingleLeadDataSet, list_QRSlocations, num_samplingRate):
    """Tabulate every local peak and valley of one lead and flag QRS waves.

    :param SingleLeadDataSet: 1-D array holding one lead.
    :param list_QRSlocations: sample indices of detected R peaks.
    :param num_samplingRate: sampling rate in Hz; each R peak's search window
        is the open interval of +/- 0.1 s around it.
    :return: DataFrame sorted by ``Location`` with a fresh 0..n-1 index and the
        columns ``Location, Prom, PromLB, PromRB, HalfWidth, ContourHeight,
        FullWidth, Height, IsPeak``.  ``PromLB``/``PromRB`` are the prominence
        bases relative to the extremum (in samples); widths are in samples at
        50 % and 100 % of the prominence.  ``IsPeak`` is 0 for the most
        prominent extremum (peak or valley) inside each R-peak window, 1 for
        other peaks and -1 for other valleys.
    """
    peak_locs, _ = scipysi.find_peaks(SingleLeadDataSet)
    valley_locs, _ = scipysi.find_peaks(-SingleLeadDataSet)
    peaks_df = _extrema_table(SingleLeadDataSet, peak_locs, SingleLeadDataSet[peak_locs], 1)
    valleys_df = _extrema_table(-SingleLeadDataSet, valley_locs, SingleLeadDataSet[valley_locs], -1)

    PeaksValleysDF = pd.concat([peaks_df, valleys_df], axis=0, ignore_index=True)
    PeaksValleysDF = PeaksValleysDF.sort_values(by=['Location'], ascending=True).reset_index(drop=True)

    half_window = num_samplingRate * 0.1
    locations = PeaksValleysDF['Location']
    for r_loc in list_QRSlocations:
        in_window = (locations < r_loc + half_window) & (locations > r_loc - half_window)
        if in_window.any():
            # scipy prominences are never negative, so no abs() is needed; writing
            # through .loc on the full frame avoids mutating a filtered copy.
            dominant = PeaksValleysDF.loc[in_window, 'Prom'].idxmax()
            PeaksValleysDF.loc[dominant, 'IsPeak'] = 0

    return PeaksValleysDF

def CalculateSingleFileFeatures(array_ECGDataSet, list_QRSlocations, num_samplingRate):
    """Build the histogram feature vector of one recording.

    For every lead (column of ``array_ECGDataSet``):

    1. ``LocatePeaks`` tabulates peaks/valleys and flags QRS waves.
    2. For the 8 properties ``Location..Height``, every ratio
       ``later_property / earlier_property`` is computed (28 ratios; 0 where
       the denominator is 0).
    3. The 7 properties other than ``Location`` plus the 28 ratios (35 columns)
       are each histogrammed into 100 bins, separately for QRS waves
       (``IsPeak == 0``), other peaks (1) and valleys (-1).

    :param array_ECGDataSet: 2-D array indexed ``[sample, lead]``.
    :param list_QRSlocations: R-peak sample indices shared by all leads.
    :param num_samplingRate: sampling rate in Hz, forwarded to ``LocatePeaks``.
    :return: 1-D float array of length ``n_leads * 35 * 3 * 100`` ordered by
        lead, then column, then group (QRS, peaks, valleys).

    NOTE(review): ``np.histogram`` derives each range from the data's own
    min/max, so bin edges differ between recordings and groups.
    """
    ColNames = ['Location', 'Prom', 'PromLB', 'PromRB', 'HalfWidth', 'ContourHeight', 'FullWidth', 'Height',
                'IsPeak']
    # IsPeak first, then every property except Location, in table order.
    FeatureColNames = ['IsPeak'] + ColNames[1:-1]

    # Collect the histograms and concatenate once at the end: growing an array
    # with np.append inside the loop copied the whole vector on every call.
    histogram_parts = []
    for iLead in range(array_ECGDataSet.shape[1]):
        SinglePeaksValleysDF = LocatePeaks(array_ECGDataSet[:, iLead], list_QRSlocations, num_samplingRate)
        PeaksValleysHorizontalDF = SinglePeaksValleysDF[ColNames[0:-1]]
        n_rows, n_props = PeaksValleysHorizontalDF.shape

        ratio_blocks = []
        for iCol in range(n_props - 1):
            # (remaining properties x rows) divided row-wise by property iCol.
            numerators = PeaksValleysHorizontalDF.iloc[:, iCol + 1:].transpose().to_numpy()
            denominators = PeaksValleysHorizontalDF.iloc[:, iCol].to_numpy()
            ratio_blocks.append(np.true_divide(numerators, denominators, casting='unsafe',
                                               out=np.zeros_like(numerators),
                                               where=denominators != 0))
        # Seeding with an empty float block keeps the dtype of the former np.append result.
        RatioArray_LeadCols = np.concatenate([np.empty((0, n_rows), float)] + ratio_blocks, axis=0).transpose()

        SingleLeadFullFeatureDF = pd.concat([SinglePeaksValleysDF[FeatureColNames],
                                             pd.DataFrame(RatioArray_LeadCols)], axis=1, ignore_index=True)
        wave_type = SingleLeadFullFeatureDF.iloc[:, 0]
        wave_groups = (SingleLeadFullFeatureDF.loc[wave_type == 0],    # QRS (R) waves
                       SingleLeadFullFeatureDF.loc[wave_type == 1],    # non-QRS peaks
                       SingleLeadFullFeatureDF.loc[wave_type == -1])   # valleys
        for icolumn in range(1, SingleLeadFullFeatureDF.shape[1]):
            for group in wave_groups:
                histogram_parts.append(np.histogram(group.iloc[:, icolumn], bins=100)[0])

    return np.concatenate([np.empty(0, float)] + histogram_parts)

def GenerateSingleEngFeatureData(array_ECGData, array_HeaderData, num_leadIIindex=1):
    """Compute the engineered feature vector and label vector of one recording.

    Pipeline: parse the header, filter every lead with ``filter_lead``,
    resample to 125 Hz, min-max scale each lead to [-1, 1], detect R peaks on
    one lead with biosppy, then append ``CalculateSingleFileFeatures``.

    :param array_ECGData: 2-D array whose rows are leads (each row is filtered
        as one lead and resampling runs along axis 1).
    :param array_HeaderData: header lines accepted by ``parse_hea_file``.
    :param num_leadIIindex: lead column (after transposing to
        ``[sample, lead]``) used for R-peak detection.  1 is lead II only if the
        leads are ordered I, II, ...; it must change for other lead layouts.
    :return: ``(features, binary_classes)``.  ``features`` starts with
        ``[gender, age, mean RR (samples), std RR (samples), heart rate (bpm),
        duration (s)]`` followed by the histogram features.

    NOTE(review): when the header has no ``#Sex`` line ``gender`` is ``None``,
    which makes ``np.hstack`` return an object array.
    """
    _, ptID, gender, age, sample_Fs, _, classes = parse_hea_file(array_HeaderData)
    binary_classes = get_target_classes(classes)

    target_fs = 125
    filtered = np.array([filter_lead(lead, sample_Fs) for lead in array_ECGData])
    resampled_len = int(filtered.shape[1] * target_fs / sample_Fs + 0.5)
    # Resample along time, then transpose to [sample, lead].
    by_sample = signal.resample(filtered, resampled_len, axis=1).transpose()

    normalised = MinMaxScaler(feature_range=(-1, 1)).fit_transform(by_sample)

    print("Start File: ", ptID)
    detection = biosppy.signals.ecg.ecg(normalised[:, num_leadIIindex], target_fs, show=False)
    r_peaks = detection['rpeaks']
    histogram_features = CalculateSingleFileFeatures(normalised, r_peaks, target_fs)

    rr_intervals = np.diff(r_peaks)
    mean_rr = np.mean(rr_intervals)
    summary = [gender,
               age,
               mean_rr,
               np.std(rr_intervals),
               60 / (mean_rr / target_fs),         # heart rate in beats per minute
               by_sample.shape[0] / target_fs]     # recording length in seconds
    array_FullFeatures = np.hstack([summary, histogram_features])
    print("Finished File: ", ptID)
    return array_FullFeatures, binary_classes

def GenerateEngFeature(list_ECGData, list_HeaderData, num_leadIIindex, num_cores=-1):
    """Compute engineered features and label vectors for many recordings.

    Each recording is processed by ``GenerateSingleEngFeatureData`` inside a
    joblib ``Parallel`` pool.

    :param list_ECGData: indexable collection of per-recording signal arrays.
    :param list_HeaderData: header lines per recording, indexed in step with
        ``list_ECGData``.
    :param num_leadIIindex: lead column used for R-peak detection, forwarded
        unchanged.
    :param num_cores: joblib ``n_jobs`` (``-1`` uses all CPUs).
    :return: ``(features, targets)``.  ``features`` has one row per recording
        and one column per feature; ``targets`` stacks the per-recording label
        vectors.
    :raises ValueError: if ``list_ECGData`` is empty.
    """
    print("Start Generating Eng Feature Array")
    # len() instead of truthiness so NumPy arrays are accepted as well.  Without
    # this check an empty input fails with an opaque tuple-unpacking error.
    if len(list_ECGData) == 0:
        raise ValueError("GenerateEngFeature: list_ECGData is empty, no recordings to process")

    jobs = (delayed(GenerateSingleEngFeatureData)(list_ECGData[idx], list_HeaderData[idx], num_leadIIindex)
            for idx in range(len(list_ECGData)))
    results, targets = zip(*Parallel(n_jobs=num_cores)(jobs))

    # Each result is a 1-D feature vector; stack them as rows (patients x features).
    EngFeatureArray = np.vstack(results)
    print("Finish Generating Eng Data Feature Array")
    return EngFeatureArray, np.array(targets)
#
# # np.array(minmax_scale(data, feature_range=(-1,1)))
# def get_12ECG_features(data, header_data):
#      tmp_hea = header_data[0].split(' ')
#      ptID = tmp_hea[0]
#      num_leads = int(tmp_hea[1])
#      sample_Fs= int(tmp_hea[2])
#      gain_lead = np.zeros(num_leads)
#
#      for ii in range(num_leads):
#          tmp_hea = header_data[ii+1].split(' ')
#          gain_lead[ii] = int(tmp_hea[2].split('/')[0])
#
#      # for testing, we included the mean age of 57 if the age is a NaN
#      # This value will change as more data is being released
#      for iline in header_data:
#          if iline.startswith('#Age'):
#              tmp_age = iline.split(': ')[1].strip()
#              age = int(tmp_age if tmp_age != 'NaN' else 57)
#          elif iline.startswith('#Sex'):
#              tmp_sex = iline.split(': ')[1]
#              if tmp_sex.strip()=='Female':
#                  sex =1
#              else:
#                  sex=0
#          elif iline.startswith('#Dx'):
#              label = iline.split(': ')[1].split(',')[0]

# #   We are only using data from lead1
#     peaks,idx = detect_peaks(data[0],sample_Fs,gain_lead[0])
#
# #   mean
#     mean_RR = np.mean(idx/sample_Fs*1000)
#     mean_Peaks = np.mean(peaks*gain_lead[0])
#
# #   median
#     median_RR = np.median(idx/sample_Fs*1000)
#     median_Peaks = np.median(peaks*gain_lead[0])
#
# #   standard deviation
#     std_RR = np.std(idx/sample_Fs*1000)
#     std_Peaks = np.std(peaks*gain_lead[0])
#
# #   variance
#     var_RR = stats.tvar(idx/sample_Fs*1000)
#     var_Peaks = stats.tvar(peaks*gain_lead[0])
#
# #   Skewness
#     skew_RR = stats.skew(idx/sample_Fs*1000)
#     skew_Peaks = stats.skew(peaks*gain_lead[0])
#
# #   Kurtosis
#     kurt_RR = stats.kurtosis(idx/sample_Fs*1000)
#     kurt_Peaks = stats.kurtosis(peaks*gain_lead[0])
#
#     features = np.hstack([age,sex,mean_RR,mean_Peaks,median_RR,median_Peaks,std_RR,std_Peaks,var_RR,var_Peaks,skew_RR,skew_Peaks,kurt_RR,kurt_Peaks])
#
#
#     return features


#Testing code to explore generated features
