#!/usr/bin/env python

# Which sub-models take part in inference; load_12ECG_model and
# run_12ECG_classifier both consult these flags.
b_train_rf = True
b_train_wave = True
b_train_beat = False
b_train_ensemble = True

# Class codes predicted by the models, in model-output column order.
class_names = [
    '270492004', '164889003', '164890007', '426627000',
    '713427006', '713426002', '445118002', '39732003',
    '164909002', '251146004', '698252002', '10370003',
    '284470004', '427172004', '164947007', '111975006',
    '164917005', '47665007', '427393009', '426177001',
    '426783006', '427084000', '164934002', '59931005',
]
# Code of the "normal" class and its column position in class_names.
normal_class = '426783006'
normal_col_idx = class_names.index(normal_class)
# NOTE(review): not referenced in the visible part of this file — possibly dead.
b_out = False

import numpy as np, os, sys
from scipy.io import loadmat
import pandas as pd
import glob
import numpy as np
from scipy.signal import butter, lfilter, filtfilt
import scipy.signal 
import pickle
from ecgdetectors import Detectors
from sklearn.ensemble import RandomForestClassifier 
from sklearn.metrics import roc_auc_score
import joblib

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Conv1D, BatchNormalization, MaxPooling1D, Dense, Permute, Reshape, Bidirectional, LSTM, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

# Seed the NumPy, TensorFlow and Python `random` generators with fixed values
# so that repeated runs start from the same random state.
# NOTE(review): `from tensorflow import set_random_seed` exists only in
# TensorFlow 1.x (TF 2.x uses `tf.random.set_seed`) — confirm the target TF version.
from numpy.random import seed
seed(1)
from tensorflow import set_random_seed
set_random_seed(2)
import random
random.seed(4)

# Shared R-peak detector used by getBeats / getBeatFeatures (via swt_detector).
# The argument is presumably the sampling rate in Hz; 100 matches the rate that
# run_12ECG_classifier resamples every recording to — confirm.
detectors = Detectors(100)



# Per-class decision thresholds: run_12ECG_classifier sets a class label when
# that class's score is strictly greater than its threshold. There is one entry
# per class in class_names, which (together with normal_class / normal_col_idx)
# is defined once at the top of the module instead of being re-declared here.
class_thresh = [0.1] * len(class_names)

# Diagnosis codes that are folded into another (canonical) code before being
# matched against the scored classes.
_EQUIVALENT_CLASSES = {
    '59118001': '713427006',
    '63593006': '284470004',
    '17338001': '427172004',
}


def get_header_info(input_directory, files, statDataFrame, scored_classes_list):
    """Return ``statDataFrame`` extended with one row of header metadata per file.

    For every ``.mat`` name in ``files``, the matching ``.hea`` header in
    ``input_directory`` is read. The first line supplies the sampling frequency
    (3rd token) and the number of samples (4th token). Every line starting with
    ``#Dx`` supplies comma-separated diagnosis codes. Codes listed in
    ``_EQUIVALENT_CLASSES`` are remapped to their canonical code.

    Parameters
    ----------
    input_directory : str
        Directory containing the header files; also stored in the 'database' column.
    files : iterable of str
        ``.mat`` file names (``'.mat'`` is replaced by ``'.hea'`` to find the header).
    statDataFrame : pandas.DataFrame
        Existing rows; it is not modified.
    scored_classes_list : iterable of str
        Codes that each get a boolean column (True if present in the header).

    Returns
    -------
    pandas.DataFrame
        The rows of ``statDataFrame`` followed by one row per file, with
        'database', 'file', 'sample num' and 'sample freq' columns plus one
        boolean column per scored code. Other pre-existing columns are filled
        with 0. When ``files`` is empty, ``statDataFrame`` itself is returned.
    """
    rows = []
    for i, f in enumerate(files):
        header_path = os.path.join(input_directory, f.replace('.mat', '.hea'))

        classes = set()
        sample_num = -1
        sample_freq = -1
        with open(header_path, 'r') as header:
            for line in header:
                if sample_num < 0:
                    # First line: "<record> <n_leads> <fs> <n_samples> ..."
                    tokens = line.split()
                    sample_num = int(tokens[3])
                    sample_freq = int(tokens[2])
                if line.startswith('#Dx'):
                    for code in line.split(': ')[1].split(','):
                        code = code.strip()
                        classes.add(_EQUIVALENT_CLASSES.get(code, code))

        new_line = dict.fromkeys(statDataFrame.columns, 0)
        new_line['database'] = input_directory
        new_line['file'] = f
        new_line['sample num'] = sample_num
        new_line['sample freq'] = sample_freq
        for c in scored_classes_list:
            new_line[c] = c in classes
        rows.append(new_line)

        if i % 500 == 0:
            print(i)

    if not rows:
        return statDataFrame
    # DataFrame.append was removed in pandas 2.0 and copied the whole frame on
    # every call; build all rows first and concatenate once.
    return pd.concat([statDataFrame, pd.DataFrame(rows)], ignore_index=True)

# DISCLAIMER: This function is copied from https://github.com/nwhitehead/swmixer/blob/master/swmixer.py, 
#             which was released under LGPL. 
def resample_by_interpolation(signal, input_fs, output_fs):
    """Resample ``signal`` from ``input_fs`` to ``output_fs`` by linear interpolation.

    The source and target sample grids both span [0, 1) (``endpoint=False``).
    For exact rate ratios the grids therefore line up without off-by-one drift.
    For example, doubling [1, 2, 3] gives [1, 1.5, 2, 2.5, 3, 3]: the tail is
    held at the last value. With ``endpoint=True`` it would give
    [1, 1.4, 1.8, 2.2, 2.6, 3].

    Returns a float ndarray of ``round(len(signal) * output_fs / input_fs)`` samples.
    """
    target_len = round(len(signal) * (output_fs / input_fs))
    source_grid = np.linspace(0.0, 1.0, len(signal), endpoint=False)
    target_grid = np.linspace(0.0, 1.0, target_len, endpoint=False)
    return np.interp(target_grid, source_grid, signal)



def bandpass_filter_and_resample_12_channel(data_12, lowcut, highcut, signal_freq, filter_order, resample_freq):
    """Normalise, band-pass filter and resample every lead of a recording.

    Each lead (row) is processed as follows:
    1. Shift it so its minimum is 0, then divide by its maximum (values in [0, 1]).
    2. Subtract the mean.
    3. Apply a zero-phase Butterworth band-pass (``filtfilt``).
    4. Resample with ``resample_by_interpolation``.

    Parameters
    ----------
    data_12 : array-like, shape (n_leads, n_samples)
        Raw signal, one row per lead (any numeric dtype).
    lowcut, highcut : float
        Pass-band edges in Hz.
    signal_freq : float
        Sampling rate of ``data_12`` in Hz.
    filter_order : int
        Butterworth order.
    resample_freq : float
        Output sampling rate in Hz (may be above or below ``signal_freq``).

    Returns
    -------
    numpy.ndarray of float64, shape (n_leads, round(n_samples * resample_freq / signal_freq))
    """
    data_12 = np.asarray(data_12)
    nyquist_freq = 0.5 * signal_freq
    # The filter does not depend on the lead, so design it once.
    b, a = butter(filter_order, [lowcut / nyquist_freq, highcut / nyquist_freq], btype="bandpass")

    resampled_leads = []
    for lead in data_12:
        # Convert to float first: integer input would otherwise overflow on the
        # shift and truncate the normalised values.
        lead = lead.astype(np.float64) - np.min(lead)
        peak = np.max(lead)
        if peak == 0:
            # Flat lead: divide by 1 instead of 0 so the output is zeros, not NaN.
            peak = 1
        lead = lead / peak
        lead = lead - np.mean(lead)
        filtered = filtfilt(b, a, lead)
        resampled_leads.append(resample_by_interpolation(filtered, signal_freq, resample_freq))

    # Every lead has the same length, so the rows stack into one array; no
    # preallocated buffer that would be too short when upsampling.
    return np.vstack(resampled_leads)

def _refine_r_peaks(ekg, raw_peaks, min_gap=30):
    """Move raw detector indices onto a nearby local minimum and thin them out.

    Indices at or beyond ``len(ekg)`` are skipped. Each remaining index ``j`` is
    first walked left while its left neighbour is larger (up to a local maximum).
    It is then walked left while its left neighbour is smaller (down to a local
    minimum). A refined index is kept only if it lies more than ``min_gap``
    samples after the previous refined index, whether or not that one was kept.
    """
    refined = []
    prev_peak = 0
    for j in raw_peaks:
        if j >= len(ekg):
            continue
        while j > 0 and ekg[j - 1] > ekg[j]:
            j -= 1
        while j > 0 and ekg[j - 1] < ekg[j]:
            j -= 1
        if j - prev_peak > min_gap:
            refined.append(j)
        prev_peak = j
    return refined


def getBeats(wave_12_data_list):
    """Detect R peaks on lead 11 of every recording and cut out beats 1-4.

    For recording ``i``, ``ekg_k_beat[i]`` (k = 1..4) is the slice of all leads
    between refined peaks ``r_peaks[k]`` and ``r_peaks[k + 1]`` (0-based). It is
    ``[]`` when fewer than ``k + 2`` refined peaks were found.

    Returns
    -------
    tuple
        ``(peaksList, ekg_1_beat, ekg_2_beat, ekg_3_beat, ekg_4_beat)``, where
        ``peaksList[i]`` is the refined peak list of recording ``i``.
    """
    peaksList = []
    beat_lists = ([], [], [], [])  # ekg_1_beat .. ekg_4_beat

    for i, wave_12_data in enumerate(wave_12_data_list):
        ekg = wave_12_data[11]
        r_peaks = _refine_r_peaks(ekg, detectors.swt_detector(ekg))
        peaksList.append(r_peaks)

        leads = np.asarray(wave_12_data)
        for k, beats in enumerate(beat_lists, start=1):
            # Explicit length check instead of a bare except around IndexError.
            if len(r_peaks) > k + 1:
                beats.append(leads[:, r_peaks[k]:r_peaks[k + 1]])
            else:
                beats.append([])

        if i % 1000 == 0:
            print(i)

    ekg_1_beat, ekg_2_beat, ekg_3_beat, ekg_4_beat = beat_lists
    return peaksList, ekg_1_beat, ekg_2_beat, ekg_3_beat, ekg_4_beat


def extendWithZeros(fea, length):
    """Return a float array of exactly ``length`` samples: ``fea`` truncated or zero-padded."""
    n_copy = min(len(fea), length)
    padded = np.zeros(length)
    padded[:n_copy] = fea[:n_copy]
    return padded


def getBeatFeatures(wave_12_data, rf, encoder_model):
    """Extract the RR, beat and waveform features used by run_12ECG_classifier.

    R peaks are detected on lead 11 and refined with ``_refine_r_peaks``. Beats
    are the lead-1 segments between consecutive refined peaks. Each segment is
    resampled to 80 samples, encoded with ``encoder_model`` (the segment-length
    ratio is appended to the encoding) and scored by ``rf``. Column 1 of
    ``rf.predict_proba`` is treated as the probability of a normal beat.

    Parameters
    ----------
    wave_12_data : numpy.ndarray, shape (n_leads, n_samples)
        Rows 0, 1, 3, 5 and 11 are read. Fancy indexing requires an ndarray.
        Presumably already resampled to 100 Hz by
        bandpass_filter_and_resample_12_channel — confirm for other callers.
    rf : object with ``predict_proba``
    encoder_model : object with ``predict``

    Returns
    -------
    tuple
        ``(RR_fea_1, RR_fea_2, encoder_out, beat_fea_1, beat_fea_2, beat_fea_3,
        beat_fea_4, wave_fea)``:

        * ``RR_fea_1`` / ``RR_fea_2``: mean / std of the refined peak spacing, in samples.
        * ``encoder_out``: encoder output of the last segment, with its leading
          batch axis of length 1.
        * ``beat_fea_1..4``: the four segments with the lowest "normal"
          probability (least normal first), zero-padded/truncated to 200 samples.
        * ``wave_fea``: (4, 1000) window of leads 0, 1, 3, 5.

        With 4 or fewer refined peaks, every entry except ``wave_fea`` is None
        and the window starts at sample 0.
    """
    ch_to_listen = [0, 1, 3, 5]
    ekg = wave_12_data[11]
    ekg_1 = wave_12_data[1]
    r_peaks = _refine_r_peaks(ekg, detectors.swt_detector(ekg))

    interesting_part_begin_idx = 0

    if len(r_peaks) > 4:
        beat_len = 80
        segments = [ekg_1[start:end] for start, end in zip(r_peaks[:-1], r_peaks[1:])]
        target_grid = np.linspace(0.0, 1.0, beat_len, endpoint=False)
        resampled = np.stack([
            np.interp(target_grid, np.linspace(0.0, 1.0, len(seg), endpoint=False), seg)
            for seg in segments
        ])
        # Ratio of the fixed beat length to the original segment length.
        test_factors = np.array([float(beat_len) / len(seg) for seg in segments])

        # One batched predict per model instead of one call per beat. Each row
        # equals the former per-beat (1, 80) input.
        encoder_outs = encoder_model.predict(resampled)
        rf_input = np.hstack([np.reshape(encoder_outs, (len(segments), -1)), test_factors[:, None]])
        normal_probs = rf.predict_proba(rf_input)[:, 1]

        # Ascending "normal" probability: least normal segments first.
        sorted_idxs = np.argsort(normal_probs).tolist()
        beat_fea_1, beat_fea_2, beat_fea_3, beat_fea_4 = (
            extendWithZeros(segments[k], 200) for k in sorted_idxs[:4]
        )

        # Waveform window: the 1000 samples ending where the least-normal
        # segment ends (clamped to the signal below).
        interesting_part_begin_idx = max(0, r_peaks[1 + sorted_idxs[0]] - 1000)

        rr_intervals = np.diff(r_peaks)
        RR_fea_1 = np.mean(rr_intervals)
        RR_fea_2 = np.std(rr_intervals)

        # NOTE(review): this is the encoding of the LAST segment, not of the
        # least-normal one. The downstream RF models were presumably trained on
        # exactly this feature, so it is kept — confirm the intent.
        encoder_out = encoder_outs[-1:]
    else:
        beat_fea_1 = beat_fea_2 = beat_fea_3 = beat_fea_4 = None
        encoder_out = None
        RR_fea_1 = RR_fea_2 = None

    # 4-channel waveform feature: a 1000-sample window, zero-padded if short.
    ekg_part = wave_12_data[ch_to_listen]
    interesting_part_begin_idx = max(0, min(interesting_part_begin_idx, ekg_part.shape[1] - 1000))
    wave_fea = np.zeros((4, 1000))
    n_keep = min(1000, ekg_part.shape[1] - interesting_part_begin_idx)
    wave_fea[:, :n_keep] = ekg_part[:, interesting_part_begin_idx:interesting_part_begin_idx + n_keep]

    return RR_fea_1, RR_fea_2, encoder_out, beat_fea_1, beat_fea_2, beat_fea_3, beat_fea_4, wave_fea

class CRNN:
    """Convolutional + bidirectional-LSTM classifier built on a Keras ``Sequential``.

    ``build_model`` stacks the following layers:
    * three Conv1D/MaxPooling1D stages (16, 32 and 64 filters; kernel sizes 7, 5, 3);
    * BatchNormalization;
    * a Bidirectional LSTM with 100 units per direction, outputs concatenated;
    * BatchNormalization;
    * a Dense layer with ``num_classes`` units.
    """

    def __init__(self, num_classes=1):
        # Number of units in the final Dense layer.
        self.num_classes = num_classes
        # Empty model; build_model appends the layers.
        self.model = Sequential()

    def build_model(self, input_shape, weight_decay=0.001, convolution_activation='relu', padding='same',
                    pool_size=2, strides=2,
                    output_layer_activation='sigmoid'):
        """Append the layers to ``self.model``, print its summary and return it.

        Parameters
        ----------
        input_shape : tuple
            Full input shape including the leading batch axis; only
            ``input_shape[1:]`` is given to the first layer.
        weight_decay : float
            L2 factor applied to every Conv1D kernel.
        convolution_activation, padding : str
            Passed to every Conv1D layer.
        pool_size, strides : int
            Passed to every MaxPooling1D layer.
        output_layer_activation : str
            Activation of the final Dense layer.

        Layers are appended to the existing ``self.model``, so call this once per instance.
        """
        kernel_regularizer = l2(weight_decay)

        # Layer 1
        self.model.add(Conv1D(16, kernel_size=7, padding=padding, activation=convolution_activation,
                              kernel_regularizer=kernel_regularizer, input_shape=input_shape[1:]))
        self.model.add(MaxPooling1D(pool_size=pool_size, strides=strides))

        # Layer 2
        self.model.add(Conv1D(32, kernel_size=5, padding=padding, activation=convolution_activation,
                              kernel_regularizer=kernel_regularizer))
        self.model.add(MaxPooling1D(pool_size=pool_size, strides=strides))

        # Layer 3
        self.model.add(Conv1D(64, kernel_size=3, padding=padding, activation=convolution_activation,
                              kernel_regularizer=kernel_regularizer))
        self.model.add(MaxPooling1D(pool_size=pool_size, strides=strides))

        self.model.add(BatchNormalization())

        self.model.add(Bidirectional(LSTM(100, return_sequences=False), merge_mode="concat"))
        self.model.add(BatchNormalization())

        self.model.add(Dense(self.num_classes, activation=output_layer_activation))

        # summary() prints the table itself and returns None; wrapping it in
        # print() only added a stray "None" line.
        self.model.summary()
        return self.model

class CRNN_beat:
    """Small Conv1D + stacked bidirectional-LSTM classifier on a Keras ``Sequential``.

    ``build_model`` stacks the following layers:
    * two Conv1D/MaxPooling1D/BatchNormalization stages (8 and 16 filters;
      kernel sizes 5 and 3);
    * two Bidirectional LSTMs with 50 units per direction (the first returns
      sequences);
    * a Dense layer with ``num_classes`` units.
    """

    def __init__(self, num_classes=1):
        # Number of units in the final Dense layer.
        self.num_classes = num_classes
        # Empty model; build_model appends the layers.
        self.model = Sequential()

    # TODO: tune model hyper-parameters
    def build_model(self, input_shape, weight_decay=0.001, convolution_activation='relu', padding='same',
                    pool_size=2, strides=1,
                    output_layer_activation='sigmoid'):
        """Append the layers to ``self.model``, print its summary and return it.

        Parameters
        ----------
        input_shape : tuple
            Full input shape including the leading batch axis; only
            ``input_shape[1:]`` is given to the first layer.
        weight_decay : float
            L2 factor applied to every Conv1D kernel.
        convolution_activation, padding : str
            Passed to every Conv1D layer.
        pool_size, strides : int
            Passed to every MaxPooling1D layer (stride 1 by default here).
        output_layer_activation : str
            Activation of the final Dense layer.

        Layers are appended to the existing ``self.model``, so call this once per instance.
        """
        kernel_regularizer = l2(weight_decay)

        # Layer 1
        self.model.add(Conv1D(8, kernel_size=5, padding=padding, activation=convolution_activation,
                              kernel_regularizer=kernel_regularizer, input_shape=input_shape[1:]))
        self.model.add(MaxPooling1D(pool_size=pool_size, strides=strides))
        self.model.add(BatchNormalization())

        # Layer 2
        self.model.add(Conv1D(16, kernel_size=3, padding=padding, activation=convolution_activation,
                              kernel_regularizer=kernel_regularizer))
        self.model.add(MaxPooling1D(pool_size=pool_size, strides=strides))
        self.model.add(BatchNormalization())

        self.model.add(Bidirectional(LSTM(50, return_sequences=True)))
        self.model.add(Bidirectional(LSTM(50)))

        self.model.add(Dense(self.num_classes, activation=output_layer_activation))

        # summary() prints the table itself and returns None; wrapping it in
        # print() only added a stray "None" line.
        self.model.summary()
        return self.model

#-----------------------------
def run_12ECG_classifier(data, header_data, models):
    """Score one recording against every class in ``class_names``.

    Parameters
    ----------
    data : array-like
        Raw signal, one row per lead.
    header_data : list of str
        Header lines. The third space-separated token of the first line is the
        sampling frequency in Hz.
    models : list
        ``[beat_encoder, beat_rf, wave_models, beat_models, rf_models,
        output_map_wave_rf]`` as returned by ``load_12ECG_model``.

    Returns
    -------
    current_label : numpy.ndarray of int
        1 where a class score is strictly greater than its ``class_thresh`` entry.
    current_score : list
        Final per-class scores: the ensemble output when ``b_train_ensemble``
        is set, otherwise the beat-model scores.
    class_names : list of str

    Raises
    ------
    RuntimeError
        If neither ``b_train_ensemble`` nor ``b_train_beat`` is set, so no final
        scores are produced.
    """
    num_classes = len(class_names)
    current_label = np.zeros(num_classes, dtype=int)

    beat_encoder, beat_rf, wave_models, beat_models, rf_models, output_map_wave_rf = models[:6]

    sample_freq = int(header_data[0].split(' ')[2])

    ecg = np.asarray(bandpass_filter_and_resample_12_channel(data, 0.5, 40, sample_freq, 3, 100))
    (RR_fea_1, RR_fea_2, encoder_out,
     beat_fea_1, beat_fea_2, beat_fea_3, beat_fea_4,
     wave_fea) = getBeatFeatures(ecg, beat_rf, beat_encoder)
    # getBeatFeatures returns None for all beat/RR features when it found too few beats.
    beats_found = RR_fea_2 is not None

    def fallback_scores():
        # Too few beats for beat-based models: score only the normal class.
        s = [0] * num_classes
        s[normal_col_idx] = 1
        return s

    scores = None

    # The Keras models are called with predict(): tf.keras Sequential.predict_proba
    # only returned predict()'s output (and was removed in TF 2.6).
    scores_wave = []
    if b_train_wave:
        wave_input = np.expand_dims(wave_fea.T, 0)
        scores_wave = [wave_models[i].predict(wave_input)[0][0] for i in range(num_classes)]

    scores_beat = []
    if b_train_beat:
        if beats_found:
            beat_input = np.expand_dims(np.asarray([beat_fea_1, beat_fea_2, beat_fea_3, beat_fea_4]).T, 0)
            scores_beat = [beat_models[i].predict(beat_input)[0][0] for i in range(num_classes)]
        else:
            scores_beat = fallback_scores()
        scores = scores_beat

    scores_rf = []
    if b_train_rf:
        if beats_found:
            rf_input = np.expand_dims(np.asarray([RR_fea_1, RR_fea_2] + encoder_out[0].tolist()), 0)
            for i in range(num_classes):
                proba = rf_models[i].predict_proba(rf_input)
                # A forest fitted on a single class exposes only one probability column.
                scores_rf.append(proba[0, 1] if proba.shape[1] > 1 else 0)
        else:
            scores_rf = fallback_scores()

    if b_train_ensemble:
        ensemble_input = np.expand_dims(np.asarray(scores_wave + scores_rf), 0)
        scores = output_map_wave_rf.predict(ensemble_input)[0].tolist()

    if scores is None:
        raise RuntimeError('no final scores produced: enable b_train_ensemble or b_train_beat')

    print('probs.: {}'.format(scores))

    for i, (score, thresh) in enumerate(zip(scores, class_thresh)):
        if score > thresh:
            current_label[i] = 1

    current_score = scores
    return current_label, current_score, class_names

def load_12ECG_model(input_directory):    # load the model from disk
    """Load every model that run_12ECG_classifier needs from ``input_directory``.

    The beat encoder (``one_beat_encoder``) and the beat scorer
    (``normal_beat_rf.joblib``) are always loaded. The per-class wave, beat and
    RF models (one file per index of ``class_names``) and the ensemble model are
    loaded only when the matching ``b_train_*`` flag is set.

    Returns
    -------
    list
        ``[beat_encoder, beat_rf, wave_models, beat_models, rf_models,
        output_map_wave_rf]``. Lists for disabled models are empty, and
        ``output_map_wave_rf`` is None when ``b_train_ensemble`` is False.
    """
    def model_path(name):
        return os.path.join(input_directory, name)

    def load_per_class(enabled, loader, suffix):
        # One file per class index: "0<suffix>", "1<suffix>", ...
        loaded = []
        if enabled:
            for i in range(len(class_names)):
                loaded.append(loader(model_path(str(i) + suffix)))
                print(i)
        return loaded

    print("loading beat encoder...")
    beat_encoder = tf.keras.models.load_model(model_path('one_beat_encoder'))
    beat_encoder.summary()

    print("loading beat score model...")
    beat_rf = joblib.load(model_path("normal_beat_rf.joblib"))

    print("loading 'wave' models...")
    wave_models = load_per_class(b_train_wave, tf.keras.models.load_model, "_class_model.h5")

    print("loading 'beat' models...")
    beat_models = load_per_class(b_train_beat, tf.keras.models.load_model, "_beat_class_model.h5")

    print("loading rf classifier...")
    rf_models = load_per_class(b_train_rf, joblib.load, "_beat_class_model.joblib")

    # Bound unconditionally so the return below works with the ensemble disabled.
    output_map_wave_rf = None
    if b_train_ensemble:
        print("loading output model...")
        output_map_wave_rf = tf.keras.models.load_model(model_path("wave_rf_class_map_model.h5"))

    return [beat_encoder, beat_rf, wave_models, beat_models, rf_models, output_map_wave_rf]
