
import numpy as np, os, sys
from scipy.io import loadmat
import pandas as pd
import glob
import numpy as np
from scipy.signal import butter, lfilter, filtfilt
import scipy.signal 
import pickle
from ecgdetectors import Detectors


# In[3]:


import tensorflow as tf
from tensorflow import keras

from tensorflow.keras.layers import Conv1D, BatchNormalization, MaxPooling1D, Dense, Permute, Reshape, Bidirectional, LSTM, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

from numpy.random import seed
from tensorflow import set_random_seed
import random

from sklearn.ensemble import RandomForestClassifier 
from sklearn.metrics import roc_auc_score
import joblib

BATCH_SIZE=128
detectors = Detectors(100)


def get_header_info(input_directory,files,statDataFrame,scored_classes_list):
    """Parse header files and add one row per file to a statistics table.

    For every entry of ``files`` the header named like the entry (with a
    ``.mat`` suffix replaced by ``.hea``) is read from ``input_directory``.
    The first line is split on single spaces: field 2 is stored as the
    sampling frequency and field 3 as the number of samples.  Every line
    starting with ``#Dx`` contributes its comma-separated codes.

    Args:
        input_directory: Directory containing the header files; also stored
            verbatim in the ``database`` column.
        files: Iterable of file names (``.hea`` or ``.mat``).
        statDataFrame: Table the new rows are added to.  Its columns seed
            every new row with 0 before the parsed values are filled in.
        scored_classes_list: Codes that get a boolean column each (True when
            the code is listed in the header, False otherwise).

    Returns:
        A new DataFrame holding the rows of ``statDataFrame`` followed by one
        row per processed file, with a fresh 0..n-1 index.  ``statDataFrame``
        itself is returned unchanged when ``files`` is empty.
    """
    # Codes that are rewritten to another code before the scored-class lookup.
    code_aliases = {
        '59118001': '713427006',
        '63593006': '284470004',
        '17338001': '427172004',
    }

    # Rows are collected and concatenated once: DataFrame.append was removed
    # in pandas 2.0 and copied the whole frame on every call.
    rows = []

    for i, f in enumerate(files):
        g = f.replace('.mat', '.hea')
        input_file = os.path.join(input_directory, g)

        classes = set()
        sample_num = -1
        sample_freq = -1

        with open(input_file, 'r') as header:
            for line in header:
                if sample_num < 0:
                    # Only the first line is parsed for the record geometry.
                    fields = line.split(' ')
                    sample_num = int(fields[3])
                    sample_freq = int(fields[2])
                if line.startswith('#Dx'):
                    for code in line.split(': ')[1].split(','):
                        code = code.strip()
                        classes.add(code_aliases.get(code, code))

        new_line = dict((a, 0) for a in statDataFrame.columns)
        new_line['database'] = input_directory
        new_line['file'] = f
        new_line['sample num'] = sample_num
        new_line['sample freq'] = sample_freq
        for c in scored_classes_list:
            new_line[c] = c in classes
        rows.append(new_line)

        # Progress output every 500 files.
        if i % 500 == 0:
            print(i)

    if not rows:
        return statDataFrame

    # Every row starts with the existing columns in order, so this frame
    # already has the original column layout (plus any newly added keys).
    new_rows = pd.DataFrame(rows)
    if len(statDataFrame) == 0:
        return new_rows
    return pd.concat([statDataFrame, new_rows], ignore_index=True)


# In[6]:



# DISCLAIMER: This function is copied from https://github.com/nwhitehead/swmixer/blob/master/swmixer.py, 
#             which was released under LGPL. 
def resample_by_interpolation(signal, input_fs, output_fs):
    """Resample a 1-D signal by linear interpolation.

    Args:
        signal: Sequence of samples recorded at ``input_fs``.
        input_fs: Sampling rate of ``signal``.
        output_fs: Desired sampling rate.

    Returns:
        Array with ``round(len(signal) * output_fs / input_fs)`` samples.
    """
    source_len = len(signal)
    target_len = round(source_len * (output_fs / input_fs))

    # Both grids cover [0, 1) and leave out the end point.  With exact rate
    # ratios the new sample positions then coincide with original ones,
    # e.g. doubling [1, 2, 3] gives [1, 1.5, 2, 2.5, 3, 3]; including the end
    # point would give [1, 1.4, 1.8, 2.2, 2.6, 3] instead.
    known_positions = np.linspace(0.0, 1.0, source_len, endpoint=False)
    query_positions = np.linspace(0.0, 1.0, target_len, endpoint=False)

    return np.interp(query_positions, known_positions, signal)



def bandpass_filter_and_resample_12_channel(data_12,lowcut, highcut, signal_freq, filter_order, resample_freq):
    """Normalise, band-pass filter and resample the first 12 rows of a record.

    Each of the first 12 rows is shifted to a zero minimum, divided by its
    maximum (skipped when the row is constant), mean-centred, filtered
    forwards and backwards with a Butterworth band-pass and finally resampled
    with ``resample_by_interpolation``.

    Args:
        data_12: 2-D array with at least 12 rows (one row per channel).
        lowcut: Lower cut-off frequency, same unit as ``signal_freq``.
        highcut: Upper cut-off frequency, same unit as ``signal_freq``.
        signal_freq: Sampling rate of ``data_12``.
        filter_order: Order passed to ``scipy.signal.butter``.
        resample_freq: Sampling rate of the returned data.

    Returns:
        Array with as many rows as ``data_12`` and the resampled length as
        its second dimension.  Rows beyond the twelfth stay zero.  Floating
        point inputs keep their dtype, every other input yields float64.
    """
    data_12 = np.asarray(data_12)
    # The result holds values in (-1, 1); storing them in an integer buffer
    # (as a zero-copy of an integer input would be) truncates them to 0.
    if np.issubdtype(data_12.dtype, np.floating):
        out_dtype = data_12.dtype
    else:
        out_dtype = np.float64

    # The filter depends only on the arguments, so design it once.
    nyquist_freq = 0.5 * signal_freq
    low = lowcut / nyquist_freq
    high = highcut / nyquist_freq
    b, a = butter(filter_order, [low, high], btype="bandpass")

    new_data_12 = None
    for i in range(12):
        data = data_12[i, :]
        data = data - np.min(data)
        mx = np.max(data)
        if mx == 0:
            mx = 1
        data = data / mx
        data = data - np.mean(data)
        y = filtfilt(b, a, data)
        y = resample_by_interpolation(y, signal_freq, resample_freq)
        if new_data_12 is None:
            # Sized from the resampled length so that upsampling
            # (resample_freq > signal_freq) fits as well.
            new_data_12 = np.zeros((data_12.shape[0], len(y)), dtype=out_dtype)
        new_data_12[i, :] = y
    return new_data_12

def get12channelData(input_files,input_directory,statDataFrame):
    """Load, filter and resample every recording listed in ``input_files``.

    For each ``.mat`` file the row of ``statDataFrame`` whose ``file`` column
    equals the matching ``.hea`` name supplies the sampling frequency
    (``sample freq``) and the label vector (all columns from position 4 on,
    cast to int).  The ``val`` array of the ``.mat`` file is band-pass
    filtered between 0.5 and 40 with a third-order filter and resampled
    to 100.

    Args:
        input_files: Names of the ``.mat`` files to process.
        input_directory: Directory the files are read from.
        statDataFrame: Header table with one row per ``.hea`` file.

    Returns:
        Tuple ``(waves, labels, names)`` of three equally long lists: the
        processed signal arrays, the integer label vectors and the file names.
    """
    waves = []
    labels = []
    names = []

    for counter, mat_name in enumerate(input_files):
        # The header table is keyed by the .hea file name.
        header_name = mat_name[:-3] + 'hea'
        matching_rows = statDataFrame[statDataFrame['file'] == header_name]
        header_row = matching_rows.iloc[0]

        label_vector = header_row.iloc[4:].astype(int).values
        sampling_rate = header_row['sample freq']

        mat_path = os.path.join(input_directory, mat_name)
        recording = np.asarray(loadmat(mat_path)['val'], dtype=np.float64)

        # Progress output every 100 files.
        if counter % 100 == 0:
            print(counter)

        processed = bandpass_filter_and_resample_12_channel(
            recording, 0.5, 40, sampling_rate, 3, 100)

        waves.append(np.array(processed))
        labels.append(label_vector)
        names.append(mat_name)

    return waves, labels, names


# In[7]:


def getBeats(wave_12_data_list):
    """Detect beat boundaries in every record and cut out beats 1 to 4.

    Peaks are detected on row 11 of each record with the module-level
    ``detectors.swt_detector``.  Every detected position is then walked
    backwards, first while the preceding sample is larger and then while it
    is smaller, and kept only if it lies more than 30 samples after the
    previous walked-back position.

    Args:
        wave_12_data_list: Sequence of 2-D arrays (rows are channels).

    Returns:
        Tuple ``(peaksList, ekg_1_beat, ekg_2_beat, ekg_3_beat, ekg_4_beat)``.
        ``peaksList[i]`` holds the kept positions of record ``i``;
        ``ekg_k_beat[i]`` is the all-channel slice between kept positions
        ``k`` and ``k + 1``, or an empty list when the record has too few
        positions for that beat.
    """
    peaksList = []
    # One bucket per extracted beat (beats 1..4; beat 0 is never used).
    beat_buckets = ([], [], [], [])

    for i, record in enumerate(wave_12_data_list):
        ekg = record[11]

        r_peaks = []
        prev_peak = 0
        for j in detectors.swt_detector(ekg):
            if j >= len(ekg):
                continue
            while j > 0 and ekg[j-1] > ekg[j]:
                j = j-1
            while j > 0 and ekg[j-1] < ekg[j]:
                j = j-1
            if j-prev_peak > 30:
                r_peaks.append(j)
            # Updated even for rejected positions, so the distance test is
            # always against the most recent candidate.
            prev_peak = j

        peaksList.append(r_peaks)

        leads = np.asarray(record)
        for k, bucket in enumerate(beat_buckets, start=1):
            try:
                beat = leads[:, r_peaks[k]:r_peaks[k+1]]
            except IndexError:
                # Too few kept positions for this beat: keep the lists
                # aligned with an empty placeholder.
                beat = []
            bucket.append(beat)

        # Progress output every 1000 records.
        if i % 1000 == 0:
            print(i)

    ekg_1_beat, ekg_2_beat, ekg_3_beat, ekg_4_beat = beat_buckets
    return peaksList, ekg_1_beat, ekg_2_beat, ekg_3_beat, ekg_4_beat


# In[8]:


def extendWithZeros(fea,length):
    """Return ``fea`` as a float array of exactly ``length`` entries.

    Shorter inputs are padded with zeros at the end, longer inputs are cut
    off after ``length`` entries.
    """
    padded = np.zeros(length)
    keep = length if length < len(fea) else len(fea)
    padded[:keep] = fea[:keep]
    return padded
        


# In[9]:


def getBeatFeatures(wave_12_data,rf,encoder_model):
    """Extract interval, beat and waveform features from one record.

    Beat boundaries are found on row 11 with the module-level
    ``detectors.swt_detector`` and walked back exactly as in ``getBeats``.
    When more than four boundaries remain, every beat of row 1 is resampled
    to 80 points, encoded with ``encoder_model`` and scored with ``rf``; the
    second ``predict_proba`` column is used as the score (the probability of
    "normal" according to the original author).

    Args:
        wave_12_data: 2-D array, rows are channels.
        rf: Classifier offering ``predict_proba``.
        encoder_model: Model offering ``predict``.

    Returns:
        Tuple ``(RR_fea_1, RR_fea_2, encoder_out, beat_fea_1, beat_fea_2,
        beat_fea_3, beat_fea_4, wave_fea)``.  ``RR_fea_1``/``RR_fea_2`` are
        mean and standard deviation of the boundary distances;
        ``beat_fea_1``..``beat_fea_4`` are the four lowest-scoring beats of
        row 1, each zero-padded or cut to 200 samples; ``encoder_out`` is the
        encoder output of the last beat that was scored.  All of these are
        ``None`` when four or fewer boundaries were found.  ``wave_fea`` is
        always a (4, 1000) window of rows 0, 1, 3 and 5.
    """
    wave_channels = [0, 1, 3, 5]
    detection_lead = wave_12_data[11]
    beat_lead = wave_12_data[1]

    boundaries = []
    last_candidate = 0
    for pos in detectors.swt_detector(detection_lead):
        if pos >= len(detection_lead):
            continue
        while pos > 0 and detection_lead[pos-1] > detection_lead[pos]:
            pos = pos-1
        while pos > 0 and detection_lead[pos-1] < detection_lead[pos]:
            pos = pos-1
        if pos-last_candidate > 30:
            boundaries.append(pos)
        # Rejected candidates still move the reference position.
        last_candidate = pos

    window_start = 0

    if len(boundaries) > 4:
        target_grid = np.linspace(0.0, 1.0, 80, endpoint=False)
        scores = []
        for left, right in zip(boundaries[:-1], boundaries[1:]):
            beat = beat_lead[left:right]
            beat_80 = np.interp(
                target_grid,
                np.linspace(0.0, 1.0, len(beat), endpoint=False),
                beat)
            # Stretch factor of the beat, appended to the encoding as an
            # extra classifier input.
            stretch = 80.0/len(beat)

            encoder_out = encoder_model.predict(np.expand_dims(beat_80, axis=0))
            rf_input = np.expand_dims(np.append(encoder_out, stretch), 0)
            scores.append(rf.predict_proba(rf_input)[0][1])

        # Ascending order: the lowest-scoring beats come first.
        ranking = np.argsort(scores).tolist()
        beat_fea_1, beat_fea_2, beat_fea_3, beat_fea_4 = (
            extendWithZeros(beat_lead[boundaries[k]:boundaries[k+1]], 200)
            for k in ranking[:4]
        )

        # The waveform window is placed to end near the lowest-scoring beat.
        window_start = max(0, boundaries[1+ranking[0]]-1000)

        intervals = np.diff(boundaries)
        RR_fea_1 = np.mean(intervals)
        RR_fea_2 = np.std(intervals)
    else:
        RR_fea_1 = RR_fea_2 = None
        encoder_out = None
        beat_fea_1 = beat_fea_2 = beat_fea_3 = beat_fea_4 = None

    # 4 channel waveform features
    selected = wave_12_data[wave_channels]
    total_samples = selected.shape[1]
    # Pull the window back so that it ends inside the record when possible.
    window_start = max(0, min(window_start, total_samples-1000))
    wave_fea = np.zeros((4, 1000))
    span = min(1000, total_samples-window_start)
    wave_fea[:, :span] = selected[:, window_start:window_start+span]

    return RR_fea_1, RR_fea_2, encoder_out, beat_fea_1, beat_fea_2, beat_fea_3, beat_fea_4, wave_fea
    
            


# In[10]:


def batch_generator(X, Y, batch_size = BATCH_SIZE):
    """Yield shuffled ``(X[batch], Y[batch])`` pairs forever.

    The sample indices are reshuffled at the start of every pass over the
    data.  A partly filled batch at the end of a pass is not dropped; it is
    completed with samples from the next pass.

    Args:
        X: Array of inputs; must support indexing with a list of indices.
        Y: Array of targets, indexed with the same index lists as ``X``.
        batch_size: Number of samples per yielded batch.

    Yields:
        Tuples ``(X[batch], Y[batch])`` with exactly ``batch_size`` samples.

    Raises:
        ValueError: On the first ``next()`` call if ``X`` is empty or
            ``batch_size`` is smaller than 1.  Without this check the loop
            below would spin forever without ever yielding.
    """
    if len(X) == 0:
        raise ValueError("batch_generator needs at least one sample")
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    indices = np.arange(len(X))
    batch = []
    while True:
        # Reshuffle before each pass over the data.
        np.random.shuffle(indices)
        for i in indices:
            batch.append(i)
            if len(batch) == batch_size:
                yield X[batch], Y[batch]
                batch = []


# In[11]:


class CRNN:
    """Builder for the convolutional + bidirectional-LSTM waveform classifier.

    The ``Sequential`` container is created in ``__init__`` and filled by
    ``build_model``; calling ``build_model`` a second time on the same
    instance would append a second stack of layers to the same container.
    """

    def __init__(self, num_classes=1):
        # Number of units of the final Dense layer.
        self.num_classes = num_classes
        self.model = Sequential()

    def build_model(self, input_shape, weight_decay=0.001, convolution_activation='relu', padding='same',
                    pool_size=2, strides=2, 
                    output_layer_activation='sigmoid'):
        """Add all layers to ``self.model`` and return it.

        Args:
            input_shape: Shape including the leading batch entry; only
                ``input_shape[1:]`` is handed to the first layer.
            weight_decay: L2 factor applied to every convolution kernel.
            convolution_activation: Activation of the convolution layers.
            padding: Padding mode of the convolution layers.
            pool_size: Window of every max-pooling layer.
            strides: Stride of every max-pooling layer.
            output_layer_activation: Activation of the final Dense layer.

        Returns:
            The (uncompiled) model stored in ``self.model``.
        """
        kernel_regularizer = l2(weight_decay)

        # (filters, kernel size) of the three convolution + pooling stages.
        conv_stages = ((16, 7), (32, 5), (64, 3))
        for stage, (filters, kernel_size) in enumerate(conv_stages):
            # Only the first layer of a Sequential model declares the input.
            extra = {'input_shape': input_shape[1:]} if stage == 0 else {}
            self.model.add(Conv1D(filters, kernel_size=kernel_size, padding=padding,
                                  activation=convolution_activation,
                                  kernel_regularizer=kernel_regularizer, **extra))
            self.model.add(MaxPooling1D(pool_size=pool_size, strides=strides))

        self.model.add(BatchNormalization())

        self.model.add(Bidirectional(LSTM(100, return_sequences=False), merge_mode="concat"))
        self.model.add(BatchNormalization())

        self.model.add(Dense(self.num_classes, activation=output_layer_activation))

        # summary() prints the table itself and returns None, so wrapping it
        # in print() only added a stray "None" line.
        self.model.summary()
        return self.model

def start_training_wave_model(x_train,y_train,x_val,y_val, log_dir, start_model, num_epochs, optimizer=None, loss="binary_crossentropy",
                   metrics=None):
    """Train a waveform model and report the last epoch's validation metrics.

    Args:
        x_train: Training inputs (converted with ``np.array``).
        y_train: Training targets.
        x_val: Validation inputs (converted with ``np.array``).
        y_val: Validation targets.
        log_dir: Directory for checkpoints (``weights.<epoch>.model``) and
            ``log.csv``.
        start_model: ``None`` to build a fresh ``CRNN`` for input shape
            (1000, 4); otherwise a built model whose architecture and
            weights are copied as the starting point.
        num_epochs: Maximum number of epochs (early stopping on ``val_auc``
            with patience 8 may end training sooner).
        optimizer: Optimizer instance; ``None`` creates a new
            ``Adam(lr=0.001, decay=1e-6)`` for this call.
        loss: Loss passed to ``compile``.
        metrics: Metrics passed to ``compile``; ``None`` selects accuracy,
            precision, recall and AUC.  The returned values are read from
            the ``val_precision``, ``val_recall`` and ``val_auc`` history
            entries, so custom metrics must provide those names.

    Returns:
        Tuple ``(model, val_precision, val_recall, val_auc)`` with the
        validation values of the last epoch that was run.
    """
    if optimizer is None:
        # Created per call: a default built in the signature is evaluated
        # once, so every model trained in the process would share a single
        # optimizer object and its internal state.
        optimizer = tf.keras.optimizers.Adam(lr=0.001, decay=1e-6)

    if metrics is None:
        metrics = ['accuracy',
                   tf.keras.metrics.Precision(name="precision"),
                   tf.keras.metrics.Recall(name="recall"),
                   tf.keras.metrics.AUC(name="auc")]

    # Training Callbacks
    checkpoint_filename = os.path.join(log_dir, "weights.{epoch:02d}.model")
    model_checkpoint_callback = ModelCheckpoint(checkpoint_filename, save_best_only=True, verbose=1, monitor='val_auc', mode="max")
    early_stopping_callback = EarlyStopping(monitor='val_auc', min_delta=0, patience=8, verbose=1, mode="max")

    csv_logger_callback = CSVLogger(os.path.join(log_dir, "log.csv"))

    if start_model is None:
        crnn = CRNN()
        model = crnn.build_model((None,1000,4))
    else:
        model = tf.keras.models.clone_model(start_model)
        # clone_model rebuilds the layers with newly initialised weights;
        # copy the values so training really starts from start_model.
        model.set_weights(start_model.get_weights())

    model.compile(optimizer, loss, metrics)

    train_generator = batch_generator(np.array(x_train), y_train, BATCH_SIZE)
    val_generator = batch_generator(np.array(x_val), y_val, BATCH_SIZE)

    history = model.fit_generator(
        train_generator,
        steps_per_epoch=len(y_train)// BATCH_SIZE,
        epochs=num_epochs,
        callbacks=[model_checkpoint_callback, csv_logger_callback, early_stopping_callback],
        verbose=1,
        validation_data=val_generator,
        validation_steps=len(y_val)// BATCH_SIZE
    )

    return model, history.history['val_precision'][-1], history.history['val_recall'][-1],history.history['val_auc'][-1]


# In[12]:


class CRNN_beat:
    """Builder for the smaller convolutional + stacked-LSTM beat classifier.

    The ``Sequential`` container is created in ``__init__`` and filled by
    ``build_model``; calling ``build_model`` a second time on the same
    instance would append a second stack of layers to the same container.
    """

    def __init__(self, num_classes=1):
        # Number of units of the final Dense layer.
        self.num_classes = num_classes
        self.model = Sequential()

    # TODO: tune model hyper-parameters
    def build_model(self, input_shape, weight_decay=0.001, convolution_activation='relu', padding='same',
                    pool_size=2, strides=1, 
                    output_layer_activation='sigmoid'):
        """Add all layers to ``self.model`` and return it.

        Args:
            input_shape: Shape including the leading batch entry; only
                ``input_shape[1:]`` is handed to the first layer.
            weight_decay: L2 factor applied to every convolution kernel.
            convolution_activation: Activation of the convolution layers.
            padding: Padding mode of the convolution layers.
            pool_size: Window of every max-pooling layer.
            strides: Stride of every max-pooling layer.
            output_layer_activation: Activation of the final Dense layer.

        Returns:
            The (uncompiled) model stored in ``self.model``.
        """
        kernel_regularizer = l2(weight_decay)

        # (filters, kernel size) of the two convolution + pooling + batch
        # normalisation stages.
        conv_stages = ((8, 5), (16, 3))
        for stage, (filters, kernel_size) in enumerate(conv_stages):
            # Only the first layer of a Sequential model declares the input.
            extra = {'input_shape': input_shape[1:]} if stage == 0 else {}
            self.model.add(Conv1D(filters, kernel_size=kernel_size, padding=padding,
                                  activation=convolution_activation,
                                  kernel_regularizer=kernel_regularizer, **extra))
            self.model.add(MaxPooling1D(pool_size=pool_size, strides=strides))
            self.model.add(BatchNormalization())

        # Two stacked bidirectional LSTMs; only the first returns sequences.
        self.model.add(Bidirectional(LSTM(50, return_sequences=True) ))
        self.model.add(Bidirectional(LSTM(50))) 

        self.model.add(Dense(self.num_classes, activation=output_layer_activation))

        # summary() prints the table itself and returns None, so wrapping it
        # in print() only added a stray "None" line.
        self.model.summary()
        return self.model

def start_training_beat(x_train,y_train,x_val,y_val, log_dir, start_model, num_epochs, optimizer=None, loss="binary_crossentropy",
                   metrics=None):
    """Train a beat model and evaluate it on batches of the validation data.

    Args:
        x_train: Training inputs (converted with ``np.array``).
        y_train: Training targets.
        x_val: Validation inputs (converted with ``np.array``).
        y_val: Validation targets.
        log_dir: Directory for checkpoints (``weights.<epoch>.model``) and
            ``log.csv``.
        start_model: ``None`` to build a fresh ``CRNN_beat`` for input shape
            (200, 4); otherwise a built model whose architecture and weights
            are copied as the starting point.
        num_epochs: Maximum number of epochs (early stopping on ``val_auc``
            with patience 8 may end training sooner).
        optimizer: Optimizer instance; ``None`` creates a new
            ``Adam(lr=0.001, decay=1e-6)`` for this call.
        loss: Loss passed to ``compile``.
        metrics: Metrics passed to ``compile``; ``None`` selects accuracy,
            precision, recall and AUC.

    Returns:
        Tuple ``(model, m2, m3, m4)`` where ``m2``..``m4`` are entries 2, 3
        and 4 of the ``model.evaluate`` result.  With the default metrics
        these are precision, recall and AUC (entries 0 and 1 being loss and
        accuracy).
    """
    if optimizer is None:
        # Created per call: a default built in the signature is evaluated
        # once, so every model trained in the process would share a single
        # optimizer object and its internal state.
        optimizer = tf.keras.optimizers.Adam(lr=0.001, decay=1e-6)

    if metrics is None:
        metrics = ['accuracy',
                   tf.keras.metrics.Precision(name="precision"),
                   tf.keras.metrics.Recall(name="recall"),
                   tf.keras.metrics.AUC(name="auc")]

    # Training Callbacks
    checkpoint_filename = os.path.join(log_dir, "weights.{epoch:02d}.model")
    model_checkpoint_callback = ModelCheckpoint(checkpoint_filename, save_best_only=True, verbose=1, monitor='val_auc', mode="max")
    early_stopping_callback = EarlyStopping(monitor='val_auc', min_delta=0, patience=8, verbose=1, mode="max")

    csv_logger_callback = CSVLogger(os.path.join(log_dir, "log.csv"))

    if start_model is None:
        crnn = CRNN_beat()
        model = crnn.build_model((None,200,4))
    else:
        model = tf.keras.models.clone_model(start_model)
        # clone_model rebuilds the layers with newly initialised weights;
        # copy the values so training really starts from start_model.
        model.set_weights(start_model.get_weights())

    model.compile(optimizer, loss, metrics)

    train_generator = batch_generator(np.array(x_train), y_train, BATCH_SIZE)
    val_generator = batch_generator(np.array(x_val), y_val, BATCH_SIZE)

    history = model.fit_generator(
        train_generator,
        steps_per_epoch=len(y_train)// BATCH_SIZE,
        epochs=num_epochs,
        callbacks=[model_checkpoint_callback, csv_logger_callback, early_stopping_callback],
        verbose=1,
        validation_data=val_generator,
        validation_steps=len(y_val)// BATCH_SIZE
    )

    # The generator reshuffles on its own, so this scores a fresh set of
    # validation batches with the final weights.
    metrics_ = model.evaluate(
        val_generator,
        steps=len(y_val)// BATCH_SIZE,
        verbose=1,
        sample_weight=None
        )
    return model, metrics_[2], metrics_[3],metrics_[4]

"""
def start_training_output_map(x_train,y_train,x_val,y_val, log_dir, num_epochs, optimizer=tf.keras.optimizers.Adam(lr=0.001, decay=1e-6), loss="binary_crossentropy",
                   metrics=None):
 
    prec_metric=tf.keras.metrics.Precision(name="precision")
    rec_metric=tf.keras.metrics.Recall(name="recall")

    if metrics is None:
        metrics=['accuracy', prec_metric, rec_metric, tf.keras.metrics.AUC(name="auc")]

    # Training Callbacks
    checkpoint_filename = os.path.join(log_dir, "weights.{epoch:02d}.model")
    #model_checkpoint_callback = ModelCheckpoint(checkpoint_filename, save_best_only=True, verbose=1, monitor="val_accuracy")
    model_checkpoint_callback = ModelCheckpoint(checkpoint_filename, save_best_only=True, verbose=1, monitor='val_auc', mode="max")
    #early_stopping_callback = EarlyStopping(monitor='val_auc', min_delta=0, patience=10, verbose=1, mode="max")
    
    csv_logger_callback = CSVLogger(os.path.join(log_dir, "log.csv"))

    o_map_net = output_mapper()
    model = o_map_net.build_model()
    
    model.compile(optimizer, loss, metrics)
    #train_generator = batch_generator(np.expand_dims(np.array(x_train), axis=2), y_train, BATCH_SIZE)
    #val_generator = batch_generator(np.expand_dims(np.array(x_val), axis=2), y_val, BATCH_SIZE)
    train_generator = batch_generator(np.array(x_train), y_train, BATCH_SIZE)
    val_generator = batch_generator(np.array(x_val), y_val, BATCH_SIZE)

    history = model.fit_generator(
        train_generator,
        steps_per_epoch=len(y_train)// BATCH_SIZE,
        epochs=num_epochs,
        #callbacks=[model_checkpoint_callback, csv_logger_callback, early_stopping_callback],
        callbacks=[model_checkpoint_callback, csv_logger_callback],
        verbose=1,
        validation_data=val_generator,
        validation_steps=len(y_val)// BATCH_SIZE
    )

    return model, history.history['val_precision'][-1], history.history['val_recall'][-1],history.history['val_auc'][-1]
"""
def start_training_output_map2(x_train,y_train,x_val,y_val, log_dir, num_epochs, optimizer=None, loss="binary_crossentropy",
                   metrics=None):
    """Train the 48-input, 24-output ``output_mapper`` network.

    Args:
        x_train: Training inputs (converted with ``np.array``).
        y_train: Training targets.
        x_val: Validation inputs (converted with ``np.array``).
        y_val: Validation targets.
        log_dir: Directory for checkpoints (``weights.<epoch>.model``) and
            ``log.csv``.
        num_epochs: Number of epochs; all of them are run because no
            early-stopping callback is registered.
        optimizer: Optimizer instance; ``None`` creates a new
            ``Adam(lr=0.001, decay=1e-6)`` for this call.
        loss: Loss passed to ``compile``.
        metrics: Metrics passed to ``compile``; ``None`` selects accuracy,
            precision, recall and AUC.  The returned values are read from
            the ``val_precision``, ``val_recall`` and ``val_auc`` history
            entries, so custom metrics must provide those names.

    Returns:
        Tuple ``(model, val_precision, val_recall, val_auc)`` with the
        validation values of the last epoch.
    """
    if optimizer is None:
        # Created per call: a default built in the signature is evaluated
        # once, so every model trained in the process would share a single
        # optimizer object and its internal state.
        optimizer = tf.keras.optimizers.Adam(lr=0.001, decay=1e-6)

    if metrics is None:
        metrics = ['accuracy',
                   tf.keras.metrics.Precision(name="precision"),
                   tf.keras.metrics.Recall(name="recall"),
                   tf.keras.metrics.AUC(name="auc")]

    # Training Callbacks.  The original code built an EarlyStopping callback
    # here as well but had it commented out of the callback list, so it is
    # no longer constructed.
    checkpoint_filename = os.path.join(log_dir, "weights.{epoch:02d}.model")
    model_checkpoint_callback = ModelCheckpoint(checkpoint_filename, save_best_only=True, verbose=1, monitor='val_auc', mode="max")

    csv_logger_callback = CSVLogger(os.path.join(log_dir, "log.csv"))

    o_map_net = output_mapper(num_classes=24, input_size=48)
    model = o_map_net.build_model()

    model.compile(optimizer, loss, metrics)

    train_generator = batch_generator(np.array(x_train), y_train, BATCH_SIZE)
    val_generator = batch_generator(np.array(x_val), y_val, BATCH_SIZE)

    history = model.fit_generator(
        train_generator,
        steps_per_epoch=len(y_train)// BATCH_SIZE,
        epochs=num_epochs,
        callbacks=[model_checkpoint_callback, csv_logger_callback],
        verbose=1,
        validation_data=val_generator,
        validation_steps=len(y_val)// BATCH_SIZE
    )

    return model, history.history['val_precision'][-1], history.history['val_recall'][-1],history.history['val_auc'][-1]

class output_mapper:
    """Builder for a single sigmoid Dense layer mapping inputs to class outputs.

    The ``Sequential`` container is created in ``__init__`` and filled by
    ``build_model``; calling ``build_model`` a second time on the same
    instance would append a second layer to the same container.
    """

    def __init__(self, num_classes=24, input_size=24):
        # Number of output units and length of the input vector.
        self.num_classes = num_classes
        self.input_size = input_size
        self.model = Sequential()

    # TODO: tune model hyper-parameters
    def build_model(self, weight_decay=0.001):
        """Add the Dense layer to ``self.model`` and return it.

        Args:
            weight_decay: Accepted for signature compatibility; no
                regulariser is attached to the layer, so the value has no
                effect.

        Returns:
            The (uncompiled) model stored in ``self.model``.
        """
        self.model.add(Dense(self.num_classes, activation='sigmoid', input_shape=(self.input_size,)))
        # summary() prints the table itself and returns None, so wrapping it
        # in print() only added a stray "None" line.
        self.model.summary()
        return self.model



def train_12ECG_classifier(input_directory, output_directory):


    input_training_directory=input_directory #'d:/DATABASES/CHALLENGE2020/divided/0/train/'
    output_training_directory=output_directory #'d:/DATABASES/CHALLENGE2020/output_tr_dir/'

    TrainRatio=0.8
    encoder_epoch_num=200
    WaveCNN_first_epoch_num=30
    WaveCNN_next_epoch_num=15
    BeatCNN_first_epoch_num=30
    BeatCNN_next_epoch_num=15

    b_train_rf=True
    b_train_wave=True
    b_train_beat=False
    b_train_ensemble=True

    class_names=['270492004', '164889003', '164890007', '426627000', '713427006', '713426002', '445118002', '39732003', '164909002', '251146004', '698252002', '10370003', '284470004', '427172004', '164947007', '111975006', '164917005', '47665007', '427393009', '426177001', '426783006', '427084000', '164934002', '59931005']
    normal_class='426783006'
    normal_col_idx=class_names.index(normal_class)
    #BATCH_SIZE=128
    b_out=False

    seed(1)
    set_random_seed(2)
    random.seed(4)

    if not os.path.exists(output_training_directory):
        os.makedirs(output_training_directory)


    statDataFrame=pd.DataFrame(columns=["database","file","sample num",'sample freq']+class_names)

    files = [os.path.basename(x) for x in glob.glob(input_training_directory+ "/*.hea", recursive=False)]
    #files=files[:2000]

    statDataFrame=get_header_info(input_training_directory,files,statDataFrame,class_names)

    classPart=statDataFrame.iloc[:,4:]
    classPart=classPart.astype(int)
    print(classPart.head())
    if b_out:
        statDataFrame.to_pickle("Allinfo_v1.pkl")


    input_files_train = []
    for f in files: #os.listdir(input_training_directory):
        f=f[:-3]+"mat"
        if os.path.isfile(os.path.join(input_training_directory, f)) and not f.lower().startswith('.') and f.lower().endswith('mat'):
            input_files_train.append(f)
    random.shuffle(input_files_train)

    wave_12_data_list,class_list,inp_files=get12channelData(input_files_train,input_training_directory,statDataFrame)
    N_train=int(len(files)*TrainRatio)

    wave_12_data_list_train=wave_12_data_list[:N_train]
    class_list_train=class_list[:N_train]
    inp_files_train=inp_files[:N_train]

    wave_12_data_list_test=wave_12_data_list[N_train:]
    class_list_test=class_list[N_train:]
    inp_files_test=inp_files[N_train:]


    if b_out:
        with open('train_base_.pkl', 'wb') as f:  # Python 3: open(..., 'wb')
            pickle.dump([wave_12_data_list_train, class_list_train, inp_files_train], f)
        with open('test_base_.pkl', 'wb') as f:  # Python 3: open(..., 'wb')
            pickle.dump([wave_12_data_list_test, class_list_test, inp_files_test], f)


    print(len(wave_12_data_list_train), len(class_list_train), len(inp_files_train))

    #detectors = Detectors(100)

    peaksList_train,ekg_1_beat_train,ekg_2_beat_train,ekg_3_beat_train,ekg_4_beat_train=getBeats(wave_12_data_list_train)
    peaksList_test,ekg_1_beat_test,ekg_2_beat_test,ekg_3_beat_test,ekg_4_beat_test=getBeats(wave_12_data_list_test)


    # # Train on heartbeats that belong to the normal class

    train_labels=np.asarray(class_list_train)[:,class_names.index(normal_class)]
    test_labels=np.asarray(class_list_test)[:,class_names.index(normal_class)]

    train_labels_1=[]
    train_samples_1=[]
    train_factors=[]
    for i, sample in enumerate(ekg_1_beat_train):
        if len(sample)>0 and (len(sample[1])>40):
            train_labels_1.append(train_labels[i])
            
            resampled_signal = np.interp(
                np.linspace(0.0, 1.0, 80, endpoint=False),  # where to interpret
                np.linspace(0.0, 1.0, len(sample[1]), endpoint=False),  # known positions
                sample[1],  # known data points
            )
            train_factors.append(80.0/len(sample[1]))
            train_samples_1.append(resampled_signal)
        if i%1000==0:
            print(i)
            

    test_labels_1=[]
    test_samples_1=[]
    test_factors=[]
    for i, sample in enumerate(ekg_1_beat_test):
        if len(sample)>0 and (len(sample[1])>40):
            test_labels_1.append(test_labels[i])
            
            resampled_signal = np.interp(
                np.linspace(0.0, 1.0, 80, endpoint=False),  # where to interpret
                np.linspace(0.0, 1.0, len(sample[1]), endpoint=False),  # known positions
                sample[1],  # known data points
            )
            test_factors.append(80.0/len(sample[1]))
            test_samples_1.append(resampled_signal)
            
        if i%1000==0:
            print(i)


    if b_out:
        pickle.dump((train_samples_1,train_labels_1,test_samples_1,test_labels_1),
               open("normal_data.pkl","wb"))


    # Train it...

    # # autoencoder

    input_ecg = Input(shape=(80,))
    encoded = Dense(40, activation='relu')(input_ecg)
    encoder_output = Dense(20, activation='relu')(encoded)
    decoded = Dense(40, activation='relu')(encoder_output)
    decoded = Dense(80, activation='tanh')(decoded)
    autoencoder = Model(input_ecg, decoded)

    autoencoder.compile(optimizer='adam', loss='mse')

    autoencoder.fit(np.asarray(train_samples_1), np.asarray(train_samples_1),
                    epochs=encoder_epoch_num,
                    batch_size=256,
                    shuffle=True,
                    validation_data=(np.asarray(test_samples_1), np.asarray(test_samples_1)))

    # The autoencoder model should be saved
    autoencoder.save(os.path.join(output_training_directory,'one_beat_model'))

    encoder_model= Model(input_ecg, encoder_output)
    encoder_model.save(os.path.join(output_training_directory,'one_beat_encoder'))


    # Train a normal sinus rhythm model on the encoder output

    encoded_train = encoder_model.predict(np.asarray(train_samples_1))
    encoded_test = encoder_model.predict(np.asarray(test_samples_1))

    X_train=np.hstack((encoded_train,np.expand_dims(np.asarray(train_factors),1)))
    X_test=np.hstack((encoded_test,np.expand_dims(np.asarray(test_factors),1)))
    y_train=train_labels_1
    y_test=test_labels_1
    rf = RandomForestClassifier(n_estimators = 1000, random_state = 42)
    rf.fit(X_train, y_train)

    # save
    joblib.dump(rf, os.path.join(output_training_directory,"normal_beat_rf.joblib"))

    pred=rf.predict_proba(X_test)
    try:
        a=roc_auc_score(y_test, pred[:,1])
    except:
        a=-1
    print(a)


    # # Feature extraction

    train_features=[]
    for i in range(len(wave_12_data_list_train)):

        features=getBeatFeatures(wave_12_data_list_train[i],rf,encoder_model)
        #RR_fea_1, RR_fea_2, encoder_out, beat_fea_1, beat_fea_2, beat_fea_3, beat_fea_4, wave_fea
        train_features.append(features)
        if i%100==0:
            print(i)
    if b_out:
        with open('train_features.pkl', 'wb') as f:
            pickle.dump(train_features, f)



    test_features=[]
    for i in range(len(wave_12_data_list_test)):

        features=getBeatFeatures(wave_12_data_list_test[i],rf,encoder_model)
        #RR_fea_1, RR_fea_2, encoder_out, beat_fea_1, beat_fea_2, beat_fea_3, beat_fea_4, wave_fea
        test_features.append(features)
        if i%100==0:
            print(i)

    if b_out:        
        with open('test_features.pkl', 'wb') as f:  # Python 3: open(..., 'wb')
            pickle.dump(test_features, f)


    # # Waveform-based training

    if b_train_wave:
        train_class_label_df=pd.DataFrame(class_list_train, columns=class_names)
        train_features_df=pd.DataFrame(train_features, columns=["RR_fea_1", "RR_fea_2","encoder_out", 
                                       "beat_fea_1", "beat_fea_2", "beat_fea_3", "beat_fea_4", "wave_fea"])
        test_class_label_df=pd.DataFrame(class_list_test, columns=class_names)
        test_features_df=pd.DataFrame(test_features, columns=["RR_fea_1", "RR_fea_2","encoder_out", 
                                       "beat_fea_1", "beat_fea_2", "beat_fea_3", "beat_fea_4", "wave_fea"])

        #x_train=[np.asarray(train_features_df['wave_fea'].iloc[i].T) for i in range(0,N_train)]
        #x_val=[np.asarray(train_features_df['wave_fea'].iloc[i].T) for i in range(N_train,len(idx))]
        N_train=len(class_list_train)
        N_test=len(class_list_test)
        x_train=[np.asarray(train_features_df['wave_fea'].iloc[i].T) for i in range(N_train)]
        x_val=[np.asarray(test_features_df['wave_fea'].iloc[i].T) for i in range(N_test)]

        # Labels of the "normal" class. Its model is trained first; None is
        # passed in the slot where the later calls pass model_normal
        # (presumably the base-model argument of start_training_wave_model).
        y_train=train_class_label_df[normal_class].iloc[:].values
        # Validation labels must come from the test split so that they pair
        # up with x_val, which is built from test_features_df.
        y_val=test_class_label_df[normal_class].iloc[:].values

        log_dir = "class_"+str(normal_col_idx)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        print("Logging to {}".format(log_dir))

        model, precision, recall, auc = start_training_wave_model(np.asarray(x_train),np.asarray(y_train)
                                    ,np.asarray(x_val),np.asarray(y_val),log_dir,None,WaveCNN_first_epoch_num)

        model.save(os.path.join(output_training_directory,str(normal_col_idx)+"_class_model.h5"), include_optimizer=False)

        print('class {}, precision {}, recall {}, auc {}'.format(normal_col_idx, precision, recall, auc))
        with open("log_train.txt", "a") as myfile:
            myfile.write('class {}, precision {}, recall {}, auc {}\n'.format(normal_col_idx, precision, recall, auc))

        # clone_model() rebuilds the architecture with freshly initialised
        # weights, so the trained weights have to be copied over explicitly.
        # Without this the clone stored in wave_models (and handed to the
        # per-class trainings below) would be an untrained network.
        model_normal=tf.keras.models.clone_model(model)
        model_normal.set_weights(model.get_weights())

        # One wave model per class, indexed by class column.
        wave_models=[]

        for y_col_idx in range(24):

            # The normal class is already trained; reuse its model.
            if y_col_idx==normal_col_idx:
                wave_models.append(model_normal)
                continue

            y_train=train_class_label_df[class_names[y_col_idx]].iloc[:].values
            y_val=test_class_label_df[class_names[y_col_idx]].iloc[:].values

            log_dir = "class_"+str(y_col_idx)
            if not os.path.exists(log_dir):
                os.makedirs(log_dir)
            print("Logging to {}".format(log_dir))

            # Same inputs as above; only the labels, the log directory, the
            # model argument (model_normal) and the epoch count differ.
            model, precision, recall, auc = start_training_wave_model(np.asarray(x_train),np.asarray(y_train)
                                        ,np.asarray(x_val),np.asarray(y_val),log_dir,model_normal,WaveCNN_next_epoch_num)

            model.save(os.path.join(output_training_directory,str(y_col_idx)+"_class_model.h5"), include_optimizer=False)
            wave_models.append(model)
            print('class {}, precision {}, recall {}, auc {}'.format(y_col_idx, precision, recall, auc))
            with open("log_train.txt", "a") as myfile:
                myfile.write('class {}, precision {}, recall {}, auc {}\n'.format(y_col_idx, precision, recall, auc))


    # # random forest training

    if b_train_rf:
        # Train one random forest per class on the two RR features
        # concatenated with the first element of encoder_out, save each
        # forest with joblib and log its validation AUC.
        rf_fea_columns=["RR_fea_1", "RR_fea_2","encoder_out",
                        "beat_fea_1", "beat_fea_2", "beat_fea_3", "beat_fea_4", "wave_fea"]

        train_class_label_df=pd.DataFrame(class_list_train, columns=class_names)
        train_features_df=pd.DataFrame(train_features, columns=rf_fea_columns)
        # Keep only the records whose beat_fea_4 is present; the same mask is
        # applied to labels and features so that the rows stay aligned.
        rf_train_mask=train_features_df['beat_fea_4'].notna()
        train_class_label_df = train_class_label_df[rf_train_mask]
        train_features_df = train_features_df[rf_train_mask]

        test_class_label_df=pd.DataFrame(class_list_test, columns=class_names)
        test_features_df=pd.DataFrame(test_features, columns=rf_fea_columns)
        rf_test_mask=test_features_df['beat_fea_4'].notna()
        test_class_label_df = test_class_label_df[rf_test_mask]
        test_features_df = test_features_df[rf_test_mask]

        N_train=len(train_features_df)
        N_test=len(test_features_df)
        # Feature vector per record: [RR_fea_1, RR_fea_2, *encoder_out[0]]
        x_train=[[train_features_df["RR_fea_1"].iloc[i]]+ [train_features_df["RR_fea_2"].iloc[i]]+train_features_df["encoder_out"].iloc[i][0].tolist() for i in range(N_train)]
        x_val=[[test_features_df["RR_fea_1"].iloc[i]]+ [test_features_df["RR_fea_2"].iloc[i]]+test_features_df["encoder_out"].iloc[i][0].tolist() for i in range(N_test)]

        aucs=[]
        rf_models=[]
        for y_col_idx in range(24):

            y_train=train_class_label_df[class_names[y_col_idx]].iloc[:].values
            y_val=test_class_label_df[class_names[y_col_idx]].iloc[:].values

            # The directory is created here; nothing in this block writes to it.
            log_dir = "class_enseble_"+str(y_col_idx)
            if not os.path.exists(log_dir):
                os.makedirs(log_dir)
            print("Logging to {}".format(log_dir))

            rf = RandomForestClassifier(n_estimators = 1000, random_state = 42)
            rf.fit(x_train, y_train)

            pred=rf.predict_proba(x_val)
            if pred.shape[1] == 1:
                # The forest saw a single label value during fit, so there is
                # no positive-class column to score.
                auc=0
            else:
                try:
                    auc=roc_auc_score(y_val, pred[:,1])
                except ValueError:
                    # roc_auc_score rejects e.g. a y_val with only one class.
                    # Record the -1 sentinel in `auc` as well, so the log
                    # below reports this class instead of a stale value.
                    auc=-1
            aucs.append(auc)

            # save
            joblib.dump(rf, os.path.join(output_training_directory,str(y_col_idx)+"_beat_class_model.joblib"))

            rf_models.append(rf)

            # To load a saved forest again:
            #   joblib.load(str(y_col_idx)+"_beat_class_model.joblib")

            print('class {}, auc {}'.format(y_col_idx, auc))
            with open("log_ensemble_train.txt", "a") as myfile:
                myfile.write('class {}, auc {}\n'.format(y_col_idx, auc))



    # # heartbeat-based training

    if b_train_beat:
        # Train one beat model per class on the four beat_fea_* columns:
        # first the normal class, then every other class with the normal
        # model passed to start_training_beat.
        beat_fea_columns=["beat_fea_1", "beat_fea_2", "beat_fea_3", "beat_fea_4"]
        all_fea_columns=["RR_fea_1", "RR_fea_2","encoder_out"]+beat_fea_columns+["wave_fea"]

        train_class_label_df=pd.DataFrame(class_list_train, columns=class_names)
        train_features_df=pd.DataFrame(train_features, columns=all_fea_columns)
        # Keep only the records whose beat_fea_4 is present; the same mask is
        # applied to labels and features so that the rows stay aligned.
        beat_train_mask=train_features_df['beat_fea_4'].notna()
        train_class_label_df = train_class_label_df[beat_train_mask]
        train_features_df = train_features_df[beat_train_mask]

        test_class_label_df=pd.DataFrame(class_list_test, columns=class_names)
        test_features_df=pd.DataFrame(test_features, columns=all_fea_columns)
        beat_test_mask=test_features_df['beat_fea_4'].notna()
        test_class_label_df = test_class_label_df[beat_test_mask]
        test_features_df = test_features_df[beat_test_mask]

        N_train=len(train_class_label_df)
        N_test=len(test_class_label_df)
        # Per record: stack the four beat feature arrays and transpose.
        x_train=[np.stack(train_features_df[beat_fea_columns].iloc[i]).T for i in range(N_train)]
        # Validation inputs must be built from the test split; they are paired
        # with y_val, which is taken from test_class_label_df.
        x_val=[np.stack(test_features_df[beat_fea_columns].iloc[i]).T for i in range(N_test)]

        y_train=train_class_label_df[normal_class].iloc[:].values
        y_val=test_class_label_df[normal_class].iloc[:].values

        log_dir = "class_beat_"+str(normal_col_idx)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        print("Logging to {}".format(log_dir))


        model, precision, recall, auc = start_training_beat(np.asarray(x_train),np.asarray(y_train)
                                    ,np.asarray(x_val),np.asarray(y_val),log_dir,None,BeatCNN_first_epoch_num)


        model.save(os.path.join(output_training_directory,str(normal_col_idx)+"_beat_class_model.h5"), include_optimizer=False)
        print('class {}, precision {}, recall {}, auc {}'.format(normal_col_idx, precision, recall, auc))
        with open("log_beat_train.txt", "a") as myfile:
            myfile.write('class {}, precision {}, recall {}, auc {}\n'.format(normal_col_idx, precision, recall, auc))

        # clone_model() rebuilds the architecture with freshly initialised
        # weights, so the trained weights have to be copied over explicitly;
        # otherwise the model handed to the per-class trainings is untrained.
        model_normal=tf.keras.models.clone_model(model)
        model_normal.set_weights(model.get_weights())

        for y_col_idx in range(24):

            # The normal class was trained and saved above.
            if y_col_idx==normal_col_idx:
                continue

            y_train=train_class_label_df[class_names[y_col_idx]].iloc[:].values
            y_val=test_class_label_df[class_names[y_col_idx]].iloc[:].values

            log_dir = "class_beat_"+str(y_col_idx)
            if not os.path.exists(log_dir):
                os.makedirs(log_dir)
            print("Logging to {}".format(log_dir))


            model, precision, recall, auc = start_training_beat(np.asarray(x_train),np.asarray(y_train)
                                        ,np.asarray(x_val),np.asarray(y_val),log_dir,model_normal,BeatCNN_next_epoch_num)

            model.save(os.path.join(output_training_directory,str(y_col_idx)+"_beat_class_model.h5"), include_optimizer=False)
            print('class {}, precision {}, recall {}, auc {}'.format(y_col_idx, precision, recall, auc))
            with open("log_beat_train.txt", "a") as myfile:
                myfile.write('class {}, precision {}, recall {}, auc {}\n'.format(y_col_idx, precision, recall, auc))


    if b_train_ensemble:
        # Train the output-mapping model on the concatenated per-class
        # outputs of the wave models and the random forests (24 + 24 values
        # per test record).
        #
        # NOTE: wave_models and rf_models are the lists filled by the wave and
        # random-forest stages of this function; they are not loaded here.
        # To run this stage alone, the saved wave models can be restored with
        #   tf.keras.models.load_model(
        #       os.path.join(output_training_directory, str(i)+"_class_model.h5"))
        # for i in range(24).

        log_dir = "wave_rf_class_map"
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)

        probs_vector_list_wave_rf=[]
        for features in test_features:

            # features[7] is the wave_fea entry; transpose it and add a
            # leading batch axis of size 1.
            wave_fea=np.expand_dims(features[7].T,0)
            probs_wave=[]
            for clidx in range(24):
                # predict() returns the same array as Sequential.predict_proba()
                # and is available on every Keras model type.
                p=wave_models[clidx].predict(wave_fea)
                probs_wave.append(p[0][0])

            if features[0] is None:
                # No RR features for this record: use zeros for the RF part.
                probs_RF=[0]*24
            else:
                # Same layout as the random-forest training input:
                # [RR_fea_1, RR_fea_2, *encoder_out[0]]
                RF_fea=[features[0]]+[features[1]]+features[2][0].tolist()
                RF_fea=np.expand_dims(np.asarray(RF_fea),0)
                probs_RF=[]

                for cli in range(24):
                    rf=rf_models[cli]
                    pred=rf.predict_proba(RF_fea)

                    if pred.shape[1] == 1:
                        # Forest was fitted on a single label value, so there
                        # is no positive-class column.
                        probs_RF.append(0)
                    else:
                        probs_RF.append(pred[:,1][0])

            probs_vector_list_wave_rf.append(probs_wave+probs_RF)

        N=len(test_features)
        # The mapping model is fitted and evaluated on the same test-split
        # vectors (x_tr/y_tr and x_te/y_te are identical).
        x_tr=np.asarray(probs_vector_list_wave_rf)
        y_tr=np.asarray(class_list_test)
        x_te=np.asarray(probs_vector_list_wave_rf)
        y_te=np.asarray(class_list_test)

        model, precision, recall, auc = start_training_output_map2(x_tr,y_tr,x_te,y_te,log_dir,1500)

        model.save(os.path.join(output_training_directory,"wave_rf_class_map_model.h5"), include_optimizer=False)
        # The printed "class" value is normal_col_idx; the metrics are the
        # ones returned for the mapping model above.
        print('class {}, precision {}, recall {}, auc {}'.format(normal_col_idx, precision, recall, auc))

