import glob
from scipy.io import loadmat
import numpy as np
from os import remove
import pandas as pd
from keras.preprocessing import sequence
from keras.models import Sequential,Model
from keras.layers import Dense, LSTM, Merge, Masking, Dropout, Conv1D, MaxPooling1D
import warnings
warnings.filterwarnings("ignore")
import time
import pickle
###################################################################################

                               # Utility functions #

###################################################################################

# Wall-clock reference point; only read by the commented-out runtime print at the end of the script.
start_time = time.time()

def train_valid_set():
    """Load and consume every ``*_data.mat`` file in the working directory.

    Each file is read with ``scipy.io.loadmat(..., squeeze_me=True)`` and then
    deleted from disk, so a second call will not see it again.

    Returns:
        (valid_name, valid_data): parallel lists with the record id (the part
        of the filename before the first ``_``) and the dict returned by
        ``loadmat`` for that file. Files are processed in sorted filename
        order so the result does not depend on filesystem listing order.
    """
    valid_name = []
    valid_data = []
    # glob's order is filesystem-dependent; sort for a reproducible record order.
    for path in sorted(glob.glob('*_data.mat')):
        data = loadmat(path, squeeze_me=True)
        # The input is consumed: it is removed only after it loaded successfully.
        remove(path)
        valid_name.append(path.split('_')[0])
        valid_data.append(data)
    return valid_name, valid_data

def get_column_val(df_val, column):
    """Return the underlying NumPy values of ``df_val[column]`` (one entry per row)."""
    selected = df_val[column]
    return selected.values

def get_column_pad_value(df_val, column, val=0., maxlen=175):
    """Return ``df_val[column]`` as a float32 matrix of fixed-length rows.

    Each entry of the column is treated as one sequence and passed to Keras
    ``pad_sequences``. Shorter sequences are padded at the end with *val*.
    Longer ones are cut using ``pad_sequences``' default truncation mode
    (``'pre'`` in Keras, i.e. the leading values are dropped).

    Args:
        df_val: DataFrame whose *column* holds one sequence per row.
        column: name of the column to extract.
        val: padding value.
        maxlen: target sequence length. The default of 175 matches the HRV
            ``input_shape=(175, 1)`` used by the CRNN models in this script.

    Returns:
        float32 ndarray with one padded row of length *maxlen* per DataFrame row.
    """
    sequences = df_val[column].values
    return sequence.pad_sequences(sequences, maxlen=maxlen, value=val, padding='post', dtype='float32')

def mid_rr_tm(r_t):
    """Return the midpoint between each pair of consecutive values in *r_t*.

    *r_t* is presumably the sequence of R-peak times of one record. For ``n``
    input values the result is a list of ``n - 1`` midpoints; it is empty when
    fewer than two values are given.
    """
    return [(r_t[k] + r_t[k + 1]) / 2.0 for k in range(len(r_t) - 1)]

def _as_window_array(windows):
    """Pack a list of windows into an ndarray, keeping ragged input as a 1-D object array."""
    try:
        return np.array(windows)
    except ValueError:
        # NumPy >= 1.24 refuses to build ragged arrays implicitly; reproduce the
        # 1-D object array that older NumPy returned for unequal window lengths.
        packed = np.empty(len(windows), dtype=object)
        for k, window in enumerate(windows):
            packed[k] = window
        return packed


def rr_window(ecg, tm, qrs, mid_rr, max_half_width=0.47):
    """Cut one ECG segment around every interior R peak.

    For each beat ``i`` except the first and the last, the segment is
    ``ecg[start:end]``:

    * ``start`` is the first sample found searching backwards from ``qrs[i]``
      whose time is at or before ``mid_rr[i-1]``, or at least
      *max_half_width* before the peak.
    * ``end`` (exclusive) is the first sample found searching forwards from
      ``qrs[i]`` whose time is at or after ``mid_rr[i]``, or at least
      *max_half_width* after the peak.

    If a search reaches the edge of the signal without a match, the window is
    clamped to that edge. Start and end therefore always come in pairs.

    Args:
        ecg: signal samples.
        tm: sample times aligned index-by-index with *ecg*.
        qrs: 0-based sample indices of the R peaks.
        mid_rr: midpoints between consecutive R times (as built by ``mid_rr_tm``).
        max_half_width: cap on each half of a window, in the units of *tm*
            (presumably seconds).

    Returns:
        ndarray of windows. It is 2-D when all windows have equal length and a
        1-D object array otherwise. It is ``[[]]`` (shape ``(1, 0)``) when there
        is no interior beat.
    """
    windows = []
    # Beat i needs both mid_rr[i-1] and mid_rr[i]; min() avoids an IndexError
    # when qrs and mid_rr disagree in length.
    n_beats = min(len(qrs) - 1, len(mid_rr))
    for i in range(1, n_beats):
        peak = qrs[i]
        start_r = tm[peak]

        start = 0  # no boundary found -> window starts at the signal start
        for n in range(peak, 0, -1):
            if start_r - tm[n] >= start_r - mid_rr[i - 1] or start_r - tm[n] >= max_half_width:
                start = n
                break

        end = len(ecg)  # no boundary found -> window runs to the signal end
        for n in range(peak, len(ecg)):
            if tm[n] - start_r >= mid_rr[i] - start_r or tm[n] - start_r >= max_half_width:
                end = n
                break

        windows.append(ecg[start:end])

    if not windows:
        windows.append([])
    return _as_window_array(windows)

##################################################################################

valid_name, valid_data = train_valid_set()
df_valid = pd.DataFrame(valid_data, index=valid_name)
# Drop the metadata keys scipy.io.loadmat adds to every record (e.g.
# '__header__', '__version__', '__globals__'). DataFrame.select() was
# deprecated in pandas 0.21 and removed in 1.0; a boolean column mask via
# .loc performs the same filtering on old and new pandas.
df_valid = df_valid.loc[:, [("__" not in x) for x in df_valid.columns]]

# Despite its name, this is the length (in samples) every RR window is
# padded/truncated to before a record's windows are flattened.
max_windows = 284
# Flattened per-record length: 175 windows of max_windows samples each.
max_len_window = 175 * max_windows
pad_val = 0.

###################################################################################

                               #DATA EXTRACTION#

###################################################################################

ecg_data_valid = get_column_val(df_valid, column='ecg')
tm_data_valid = get_column_val(df_valid, column='tm')
qrs_data_valid = get_column_val(df_valid, column='qrs')
ramp_valid = get_column_val(df_valid, column="r_amp")  # NOTE(review): not used below
rt_valid = get_column_val(df_valid, column="r_t")
hrv_valid = get_column_pad_value(df_valid, column="hrv")

ecg_windows_valid = []
for ecg_sig, tm_sig, qrs_raw, r_t in zip(ecg_data_valid, tm_data_valid, qrs_data_valid, rt_valid):
    # qrs holds 1-based sample numbers (presumably MATLAB-style); convert to 0-based indices.
    qrs_ind = [int(x - 1) for x in qrs_raw]
    windows = rr_window(ecg_sig, tm_sig, qrs_ind, mid_rr_tm(r_t))
    # Pad/truncate every window to max_windows samples, then flatten the record.
    windows = sequence.pad_sequences(windows, maxlen=max_windows, value=pad_val, padding='post', dtype='float32')
    ecg_windows_valid.append(windows.reshape(-1))

# pad_sequences accepts the list directly; wrapping these ragged per-record
# arrays in np.array() first raises ValueError on NumPy >= 1.24.
ecg_windows_valid = sequence.pad_sequences(ecg_windows_valid, maxlen=max_len_window, value=pad_val, padding='post', truncating='post', dtype='float32')

# (records, max_len_window) -> (records, max_len_window // max_windows windows, max_windows samples)
ecg_windows_valid = ecg_windows_valid.reshape(ecg_windows_valid.shape[0], ecg_windows_valid.shape[1] // max_windows, max_windows)

##############################################
#             END OF DATA PROCESSING         #
##############################################

##############################################
#          LOADING THE TRAINED MODELS        #
##############################################

batch_size = 1  # NOTE(review): not referenced anywhere below in this script

k1 = 3  # Conv1D kernel size used by every convolutional layer


def _build_conv_branch(input_shape):
    """Build one input branch: three Conv1D -> MaxPooling1D -> Dropout stages.

    The stack (32/64/128 filters, dropout 0.05/0.1/0.15) must stay identical
    to the architecture the .hdf5 weight files were saved from.
    """
    branch = Sequential()
    branch.add(Conv1D(filters=32, kernel_size=k1, activation='relu', padding='valid', input_shape=input_shape, kernel_initializer='he_normal'))
    branch.add(MaxPooling1D())
    branch.add(Dropout(0.05))
    for filters, drop_rate in ((64, 0.1), (128, 0.15)):
        branch.add(Conv1D(filters=filters, kernel_size=k1, activation='relu', padding='valid', kernel_initializer='he_normal'))
        branch.add(MaxPooling1D())
        branch.add(Dropout(drop_rate))
    return branch


def _build_crnn(ecg_input_shape, weights_path):
    """Build the two-input (ECG windows + HRV) CRNN classifier and load its weights.

    The HRV branch is built before the ECG branch, matching the original
    hand-written definitions, so Keras' auto-generated layer names and the
    layer order used by ``load_weights`` are unchanged.

    Args:
        ecg_input_shape: per-sample shape of the ECG input.
        weights_path: .hdf5 file passed to ``load_weights``.

    Returns:
        (model, ecg_branch, hrv_branch)
    """
    hrv_branch = _build_conv_branch((175, 1))
    ecg_branch = _build_conv_branch(ecg_input_shape)

    model = Sequential()
    # Branch order [ecg, hrv] fixes the order of the model's two inputs.
    model.add(Merge([ecg_branch, hrv_branch], mode='concat'))
    model.add(Masking(mask_value=0.))
    model.add(LSTM(64, return_sequences=True, kernel_initializer='he_normal'))
    # Named 'Output' so its activations can be extracted as SVM features.
    model.add(LSTM(64, name='Output', kernel_initializer='he_normal'))
    model.add(Dense(2, activation='softmax', kernel_initializer='he_normal'))
    model.load_weights(weights_path)
    return model, ecg_branch, hrv_branch


modelAFvsAll, modelEcgAFvsAll, modelHrvAFvsAll = _build_crnn(ecg_windows_valid.shape[1:], "Final_CRNN_AF_VS_ALL.hdf5")

##############################################

modelNvsAll, modelEcgNvsAll, modelHrvNvsAll = _build_crnn(ecg_windows_valid.shape[1:], "Final_CRNN_N_VS_ALL.hdf5")

##############################################

modelOvsNoisy, modelEcgOvsNoisy, modelHrvOvsNoisy = _build_crnn(ecg_windows_valid.shape[1:], "Final_CRNN_O_VS_Noisy.hdf5")

##############################################

def _output_layer_model(crnn):
    """Wrap *crnn* so that ``predict`` returns the activations of its 'Output' layer."""
    penultimate = crnn.get_layer('Output').output
    return Model(inputs=crnn.input, outputs=penultimate)


intermediate_layer_model_AFvsALL = _output_layer_model(modelAFvsAll)
intermediate_layer_model_NvsALL = _output_layer_model(modelNvsAll)
intermediate_layer_model_OvsNoisy = _output_layer_model(modelOvsNoisy)

##############################################

def _load_pickle(path):
    """Unpickle one stored artifact (MinMax scaler or SVM) from *path*, closing the file.

    NOTE(review): unpickling can execute arbitrary code; only point this at
    the trusted model files shipped with this script.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


def _scaled_hrv(scaler, hrv_row):
    """Apply *scaler* to one padded HRV sequence and return it as a (length, 1) column."""
    scaled = scaler.transform(np.array([hrv_row]))
    return scaled.reshape(scaled.shape[1], 1)


def _svm_label(feature_model, svm, ecg_row, hrv_col):
    """Feed one record (batch of 1) to *feature_model* and classify its activations with *svm*."""
    features = feature_model.predict([np.array([ecg_row]), np.array([hrv_col])])
    return svm.predict(features)


class_pred = ''

minMaxScaler1 = _load_pickle('MinMax1.pkl')
minMaxScaler2 = _load_pickle('MinMax2.pkl')
minMaxScaler3 = _load_pickle('MinMax3.pkl')

svcAFvsALL = _load_pickle('SvmAFvsALL.pkl')
svcNvsALL = _load_pickle('SvmNvsALL.pkl')
svcOvsNoisy = _load_pickle('SvmOvsNoisy.pkl')

# Three-stage cascade per record: AF-vs-all, then N-vs-all, then O-vs-noisy.
# Each stage has its own scaler; later stages run only when the previous SVM returned 1.
predictions = []
for i in range(len(ecg_windows_valid)):
    ecg_row = ecg_windows_valid[i]
    AF = _svm_label(intermediate_layer_model_AFvsALL, svcAFvsALL, ecg_row, _scaled_hrv(minMaxScaler1, hrv_valid[i]))
    if AF == 1:
        N = _svm_label(intermediate_layer_model_NvsALL, svcNvsALL, ecg_row, _scaled_hrv(minMaxScaler2, hrv_valid[i]))
        if N == 1:
            OvsNoisy = _svm_label(intermediate_layer_model_OvsNoisy, svcOvsNoisy, ecg_row, _scaled_hrv(minMaxScaler3, hrv_valid[i]))
            class_pred = 'O' if OvsNoisy == 0 else '~'
        else:
            class_pred = 'N'
    else:
        class_pred = 'A'
    predictions.append(class_pred)

y_lab = df_valid.index.tolist()

# Append one "<record>,<label>" line per record. Previously only the first
# record name was written, paired with the last record's prediction, and an
# empty input raised IndexError.
with open("answers.txt", "a") as answers:
    for record_name, record_pred in zip(y_lab, predictions):
        answers.write(record_name + "," + record_pred + "\n")

#print("Time Script : ", time.time() - start_time, 'seconds')

#print()
#print()