# -*- coding: utf-8 -*-
"""
Created on Fri Jun 18 11:51:11 2021

@author: Maurice
"""


#DIAG_CLASSES = {'270492004':0,'164889003':1,'164890007':2,'426627000':3,'713426002':4,'445118002':5,
#                '39732003':6,'164909002':7,'251146004':8,'698252002':9,'10370003':10,'164947007':11,'111975006':12,'164917005':13,
#                '47665007':14,'59118001':15,'427393009':16,'426177001':17,'426783006':18,'427084000':19,'63593006':20,'164934002':21,'59931005':22,'17338001':23}

# Alternative diagnosis codes, each mapped to the code under which it is
# indexed in DIAG_CLASSES (both codes end up in the same label slot).
EQUIV_CLASSES = {
    '164909002': '733534002',
    '59118001': '713427006',
    '63593006': '284470004',
    '17338001': '427172004',
}

# Diagnosis code -> position in the label vector. The position is simply the
# rank of the code in the tuple below, so the order here is significant.
DIAG_CLASSES = {
    code: position
    for position, code in enumerate((
        '164889003',
        '164890007',
        '6374002',
        '426627000',
        '733534002',
        '713427006',
        '270492004',
        '713426002',
        '39732003',
        '445118002',
        '251146004',
        '698252002',
        '426783006',
        '284470004',
        '10370003',
        '365413008',
        '427172004',
        '164947007',
        '111975006',
        '164917005',
        '47665007',
        '427393009',
        '426177001',
        '427084000',
        '164934002',
        '59931005',
    ))
}
# Including dual classes (which are scored as equal):
# DIAG_CLASSES = {'164889003':0,'164890007':1,'6374002':2,'426627000':3,'733534002':4,'713427006':5,'270492004':6,'713426002':7,
#                 '39732003':8,'445118002':9,'164909002':10,'251146004':11,'698252002':12,'426783006':13,'284470004':14,'10370003':15,
#                 '365413008':16,'427172004':17,'164947007':18,'111975006':19,'164917005':20,'47665007':21,'59118001':22,'427393009':23,
#                 '426177001':24,'427084000':25,'63593006':26,'164934002':27,'59931005':28,'17338001':29}


# %% Data Extraction
from helper_code import *
import neurokit2 as nk
from tqdm import tqdm
import h5py
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from sklearn.model_selection import train_test_split
import torch


def extract_labels_OH(header):
    '''
    Multi-hot encoding of the diagnosis labels found in a header.

    Every label returned by ``get_labels(header)`` is looked up as follows:
    a code listed in EQUIV_CLASSES is replaced by its equivalent code first,
    then the position of the code in DIAG_CLASSES is set to 1. Codes that
    appear in neither table are ignored.

    Parameters
    ----------
    header : string
        Header content as accepted by ``get_labels``.

    Returns
    -------
    labels_OH : ndarray
        1 dim float array of length ``len(DIAG_CLASSES)`` holding 1 for each
        assigned label and 0 elsewhere (several entries may be 1).

    '''
    labels_OH = np.zeros(len(DIAG_CLASSES), float)
    for code in get_labels(header):
        if code in EQUIV_CLASSES:
            # Equivalent codes share the slot of their counterpart.
            slot = DIAG_CLASSES[EQUIV_CLASSES[code]]
        else:
            slot = DIAG_CLASSES.get(code)
        if slot is not None:
            labels_OH[slot] = 1
    return labels_OH

def lin_resample(signal,fs,fsnew):
    '''
    Resample a signal to a new sampling rate by linear interpolation.

    Parameters
    ----------
    signal : ndarray
        Samples taken at rate ``fs``; the sample axis is the last one.
    fs : number
        Original sampling rate.
    fsnew : number
        Target sampling rate.

    Returns
    -------
    ndarray
        The signal evaluated on a grid with spacing ``1/fsnew`` that starts
        at the first original time stamp and stops before the last one.

    '''
    n_points = signal.T.shape[0]
    t_orig = np.arange(n_points)/fs
    # Half-open grid: the final original time stamp itself is excluded.
    t_res = np.arange(t_orig[0],t_orig[-1],1/fsnew)
    interpolator = interp1d(t_orig,signal)
    return interpolator(t_res)



from sklearn.model_selection import train_test_split

def normalize_signal2(signal):
    '''
    Standardise a signal to zero mean and unit standard deviation.

    Parameters
    ----------
    signal : ndarray
        Input samples; mean and standard deviation are taken over all
        elements.

    Returns
    -------
    ndarray
        ``(signal - mean) / std``. For a signal whose standard deviation is
        exactly 0 (e.g. a flat lead) the centred signal is returned as is,
        which avoids the NaN/inf values a division by zero would produce.

    '''
    centered = signal - signal.mean()
    sig_sig = signal.std()
    if sig_sig == 0:
        # Flat signal: nothing to scale, and dividing would give NaN.
        return centered
    return centered/sig_sig

#extract random sample with zero padding
#np.random.seed(42)
def extract_random_sample(sig,length):
    '''
    Cut a random window of ``length`` samples out of a 1 dim signal.

    Parameters
    ----------
    sig : ndarray
        1 dim signal.
    length : int
        Number of samples in the returned window.

    Returns
    -------
    ndarray
        If the signal has at most ``length`` samples: a new float array with
        the signal at the start and zeros after it. Otherwise a slice of
        ``sig`` (a view for NumPy arrays) starting at a uniformly drawn
        position.

    '''
    sig_len = sig.shape[0]
    if sig_len <= length:
        sample = np.zeros(length)
        sample[:sig_len] = sig
        return sample
    # randint excludes its upper bound, hence the +1: the last valid start
    # (sig_len - length) must be drawable too.
    sig_start = np.random.randint(0, sig_len-length+1)
    return sig[sig_start:sig_start+length]
# extract one random window (zero padded if too short) shared by all signals in sigs
def extract_random_nsamples(sigs,length):
    '''
    Cut the same random window of ``length`` samples out of several signals.

    Parameters
    ----------
    sigs : ndarray
        2 dim array, one signal per row.
    length : int
        Number of samples per row in the result.

    Returns
    -------
    ndarray
        Array with ``sigs.shape[0]`` rows and ``length`` columns. If the
        signals have at most ``length`` samples the result is a new float
        array, zero padded at the end; otherwise it is a column slice of
        ``sigs`` whose start is drawn once and shared by all rows.

    '''
    sig_len = sigs.shape[1]
    if sig_len <= length:
        samples = np.zeros((sigs.shape[0],length))
        samples[:,:sig_len] = sigs
        return samples
    # randint excludes its upper bound, hence the +1: the last valid start
    # (sig_len - length) must be drawable too.
    sig_start = np.random.randint(0, sig_len-length+1)
    return sigs[:,sig_start:sig_start+length]

# returns list of segments
def extract_overlapping_nsamples(sigs,length,overlap=0.5):
    '''
    Split several signals into equally long, overlapping windows.

    Parameters
    ----------
    sigs : ndarray
        2 dim array, one signal per row.
    length : int
        Number of samples per window.
    overlap : float, optional
        Fraction of a window shared with its successor, ``0 <= overlap < 1``.
        Consecutive windows start ``length * (1 - overlap)`` samples apart.

    Returns
    -------
    list of ndarray
        Windows with ``sigs.shape[0]`` rows and ``length`` columns. Signals
        with at most ``length`` samples give a single zero padded window.
        For longer signals only complete windows are returned, so samples
        after the last complete window are dropped.

    Raises
    ------
    ValueError
        If the signals are longer than ``length`` and ``overlap`` is outside
        ``[0, 1)``.

    '''
    sig_len = sigs.shape[1]
    if sig_len <= length:
        padded = np.zeros((sigs.shape[0],length))
        padded[:,:sig_len] = sigs
        return [padded]
    if not 0 <= overlap < 1:
        raise ValueError('overlap must satisfy 0 <= overlap < 1')
    # The hop between window starts must be the non-overlapping part of a
    # window; this is what the window count below is based on as well.
    hop = length*(1-overlap)
    n_samples = int((sig_len/length-1)/(1-overlap)+1)
    starts = [int(i*hop) for i in range(n_samples)]
    return [sigs[:,start:start+length] for start in starts]

def get_randomly_zeroed(sigs,percent):
    '''
    Return a copy of the signals with randomly chosen samples set to zero.

    Parameters
    ----------
    sigs : ndarray
        2 dim array, one signal per row. It is not modified.
    percent : number
        Share of the signal length (in percent) that determines how many
        positions are drawn per row.

    Returns
    -------
    ndarray
        Copy of ``sigs`` in which, for every row independently,
        ``int(percent/100 * sigs.shape[1])`` positions were drawn and zeroed.
        Positions are drawn with replacement, so fewer distinct samples may
        end up zeroed.

    '''
    sig_len = sigs.shape[1]
    amount = int(percent/100*sig_len)
    sigs_out = sigs.copy()
    if amount <= 0:
        return sigs_out
    for row in range(sigs.shape[0]):
        # Draw from the actual signal length so that short signals do not
        # raise an IndexError and long signals are covered completely.
        zero_ind = np.random.randint(0,sig_len,size=amount)
        sigs_out[row,zero_ind] = 0
    return sigs_out

def get_dropout_bursted(sigs,dropout_len):
    '''
    Return a copy of the signals with one contiguous stretch set to zero.

    Parameters
    ----------
    sigs : ndarray
        2 dim array, one signal per row. It is not modified.
    dropout_len : int
        Number of consecutive samples to zero. Values larger than the signal
        length are clipped to it, i.e. the whole signal is zeroed.

    Returns
    -------
    ndarray
        Copy of ``sigs`` where the same randomly positioned stretch of
        ``dropout_len`` samples is zero in every row.

    '''
    sig_len = sigs.shape[1]
    burst_len = min(dropout_len, sig_len)
    # randint excludes its upper bound, hence the +1: a burst ending exactly
    # at the last sample must be possible, and burst_len == sig_len is valid.
    zero_ind = np.random.randint(0, sig_len-burst_len+1)
    sigs_out = sigs.copy()
    sigs_out[:,zero_ind:zero_ind+burst_len] = 0
    return sigs_out


#data_dir = "../../Data/WFDB_CPSC2018"

#my_labels= {'426783006':0,'164889003':1,'59118001':2,'284470004':3,'270492004':4}
#my_label_names = ['sinus rhythm','atrial fibrillation','right bundle branch block','premature atrial contraction','1st degree av block']
# def oh_5(label):
#     y = np.zeros(5)
#     for idx,l in enumerate(my_labels.keys()):
#         if l==label:
#             y[idx]=1
#     return y

def _clean_lead_pair(header, recording, fs_res, thresh_clean):
    '''
    Select leads I and II of a recording and clean each of them.

    Per lead: drop the first and last 100 samples, filter with
    ``nk.ecg_clean``, resample to ``fs_res``, standardise, and set samples
    whose absolute value exceeds ``thresh_clean`` to zero.

    Returns
    -------
    tuple of ndarray
        The cleaned signals of lead I and lead II.

    '''
    recording = choose_leads(recording, header, ['I', 'II'])
    sampling_rate = get_frequency(header)
    cleaned = []
    for lead_idx in (0, 1):
        # cut off end and beginning because often bad
        # butterworth highpass+powerline (50Hz) alt:'pantompkins1985'
        sig = nk.ecg_clean(recording[lead_idx, 100:-100], sampling_rate, method='neurokit')
        sig = lin_resample(sig, sampling_rate, fs_res)
        sig = normalize_signal2(sig)
        # cut out bad parts
        sig[np.abs(sig) > thresh_clean] = 0
        cleaned.append(sig)
    return cleaned[0], cleaned[1]


def preprocess_recordings(data_dir, n_rand_cuts=1, fs_res=250, sample_len=2048,
                          thresh_clean=11, augmentation_types=None,
                          extract_labels=True, shuffle_leads=False,
                          validation=0.0):
    '''
    Load all recordings below ``data_dir`` and turn them into two-lead
    training and test segments.

    Parameters
    ----------
    data_dir : string
        Directory handed to ``find_challenge_files``.
    n_rand_cuts : int, optional
        Number of random segments cut from every training recording.
    fs_res : number, optional
        Sampling rate the signals are resampled to.
    sample_len : int, optional
        Length of every segment in samples (shorter signals are zero padded).
    thresh_clean : number, optional
        Standardised samples with a larger absolute value are set to zero.
    augmentation_types : dict, optional
        Truthy entries ``'dropout_burst'`` and/or ``'random_zeroing'`` add one
        augmented segment each per training recording. ``None`` (default)
        means no augmentation.
    extract_labels : bool, optional
        If False, the label part of every tuple is ``None``.
    shuffle_leads : bool, optional
        If True, the window of each lead is drawn independently; otherwise
        both leads share the same window.
    validation : float, optional
        ``0`` puts every recording into the training set, ``1.0`` puts every
        recording into the test set, any other positive value is passed as
        ``test_size`` to ``train_test_split`` (``random_state=42``).

    Returns
    -------
    patients_data_train : list of (Tensor, labels)
        Training segments, each tensor reshaped to 2 rows.
    patients_data_test : list of (Tensor, labels)
        One randomly cut, non augmented segment per test recording.
    t_headers : list
        The first value returned by ``find_challenge_files`` (all recordings,
        before the split).

    '''
    # A None default avoids sharing one dict instance between calls.
    if augmentation_types is None:
        augmentation_types = {}
    do_burst_augmentation = augmentation_types.get('dropout_burst')
    if do_burst_augmentation:
        print('Adding dropout burst augmentation')
    do_random_zeroing_augmentation = augmentation_types.get('random_zeroing')
    if do_random_zeroing_augmentation:
        print('Adding random_zeroing augmentation')

    t_headers_train, t_headers_test, t_recordings_train, t_recordings_test = [], [], [], []
    t_headers, t_recordings = find_challenge_files(data_dir)
    if validation > 0:
        if validation == 1.0:
            t_headers_test, t_recordings_test = t_headers, t_recordings
        else:
            t_headers_train, t_headers_test, t_recordings_train, t_recordings_test = train_test_split(
                t_headers, t_recordings, test_size=validation, random_state=42)
    else:
        t_headers_train, t_recordings_train = t_headers, t_recordings

    patients_data_train = []
    patients_data_test = []

    train_files = zip(t_headers_train, t_recordings_train)
    for header_file, recording_file in tqdm(train_files, total=len(t_headers_train), miniters=100):
        header = load_header(header_file)
        recording = load_recording(recording_file)
        labels = extract_labels_OH(header) if extract_labels else None

        clean_ecg_signal0, clean_ecg_signal1 = _clean_lead_pair(header, recording, fs_res, thresh_clean)
        # Built once per recording so that the random cuts AND the
        # augmentations always work on the signals of the current recording.
        clean_sigs = np.concatenate((clean_ecg_signal0.reshape(1, -1),
                                     clean_ecg_signal1.reshape(1, -1)), axis=0)

        for _ in range(n_rand_cuts):
            if shuffle_leads:
                segment_person0 = extract_random_sample(clean_ecg_signal0, sample_len).reshape(1, -1)
                segment_person1 = extract_random_sample(clean_ecg_signal1, sample_len).reshape(1, -1)
                segments_person = np.concatenate((segment_person0, segment_person1), axis=0)
            else:
                segments_person = extract_random_nsamples(clean_sigs, sample_len)
            patients_data_train.append((torch.Tensor(segments_person).reshape(2, -1), labels))

        if do_burst_augmentation or do_random_zeroing_augmentation:
            # The augmentations are applied to one additional random cut.
            segments_person = extract_random_nsamples(clean_sigs, sample_len)
            if do_burst_augmentation:
                segments_bursted = get_dropout_bursted(segments_person, 12)  # at 250Hz that is roughly 50ms
                patients_data_train.append((torch.Tensor(segments_bursted).reshape(2, -1), labels))
            if do_random_zeroing_augmentation:
                segments_zeroing = get_randomly_zeroed(segments_person, 1)  # 1 percent of 2048 is 20
                patients_data_train.append((torch.Tensor(segments_zeroing).reshape(2, -1), labels))

    test_files = zip(t_headers_test, t_recordings_test)
    for header_file, recording_file in tqdm(test_files, total=len(t_headers_test), miniters=100):
        header = load_header(header_file)
        recording = load_recording(recording_file)
        labels = extract_labels_OH(header) if extract_labels else None

        clean_ecg_signal0, clean_ecg_signal1 = _clean_lead_pair(header, recording, fs_res, thresh_clean)
        clean_sigs = np.concatenate((clean_ecg_signal0.reshape(1, -1),
                                     clean_ecg_signal1.reshape(1, -1)), axis=0)
        segments_person = extract_random_nsamples(clean_sigs, sample_len)
        patients_data_test.append((torch.Tensor(segments_person).reshape(2, -1), labels))

    return patients_data_train, patients_data_test, t_headers
        
def prep_single(header,recording,fs_res=250,sample_len=2048,thresh_clean=11):
    '''
    Preprocess one recording into a batch of overlapping two-lead windows.

    Leads I and II are selected; each lead loses its first and last 100
    samples, is filtered with ``nk.ecg_clean``, resampled to ``fs_res``,
    standardised, and has samples whose absolute value exceeds
    ``thresh_clean`` set to zero.

    Parameters
    ----------
    header : string
        Header of the recording, as accepted by ``choose_leads`` and
        ``get_frequency``.
    recording : ndarray
        Raw recording, as accepted by ``choose_leads``.
    fs_res : number, optional
        Sampling rate the signals are resampled to.
    sample_len : int, optional
        Window length in samples.
    thresh_clean : number, optional
        Clipping threshold applied to the standardised signals.

    Returns
    -------
    Tensor
        All windows from ``extract_overlapping_nsamples`` (default overlap),
        concatenated along the first dimension; each window is reshaped to
        ``(1, 2, -1)``.

    '''
    recording = choose_leads(recording, header, ['I','II'])
    sampling_rate = get_frequency(header)

    lead_rows = []
    for lead_idx in (0, 1):
        # butterworth highpass+powerline (50Hz) alt:'pantompkins1985'
        lead = nk.ecg_clean(recording[lead_idx,100:-100],sampling_rate,method='neurokit')
        lead = lin_resample(lead, sampling_rate, fs_res)
        lead = normalize_signal2(lead)
        # cut out bad parts
        lead[np.abs(lead)>thresh_clean] = 0
        lead_rows.append(lead.reshape(1,-1))

    clean_sigs = np.concatenate(lead_rows,axis=0)
    windows = extract_overlapping_nsamples(clean_sigs, sample_len)
    batch = [torch.Tensor(window).reshape(1,2,-1) for window in windows]
    return torch.cat(batch,dim=0)