'''
This program trains a binary classifier to determine whether the input data represents a normal or an abnormal cardiac condition.
'''

import os, glob
import shutil
from scipy import signal
from scipy.io import loadmat
from scipy.sparse import csr_matrix
import numpy as np
import pandas as pd
import joblib
from sklearn import svm
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from keras.preprocessing import sequence
from keras.models import Sequential
from keras.layers import Input, Dense
from keras.layers import LSTM
from keras.optimizers import Adam
from keras.models import load_model
from sklearn.model_selection import KFold
from keras.callbacks import ModelCheckpoint
from evaluate_12ECG_score import evaluate_12ECG_score
from CNN1D import Simple1DCNN, ResNet1D18, fbeta
from Resnet1D import ResNet

################
# Select method: deep learning ('dl', 1-D CNN) or 'conventional' (PCA + SVM)
method = 'conventional'
################

# Names of the input signal columns; their count sets the feature dimension.
colNames = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
# Names of the diagnostic labels (not referenced elsewhere in this section).
labelTypes = ['Normal', 'AF', 'I-AVB', 'LBBB', 'RBBB', 'PAC', 'PVC', 'STD', 'STE']

sampleRate = 500  # presumably samples per second -- TODO confirm against the data
# The np.float alias was deprecated in NumPy 1.20 and removed in 1.24; the
# builtin float is the exact equivalent.
# NOTE: despite its name, sampleFreq is the reciprocal of sampleRate, i.e. the
# spacing between consecutive samples.
sampleFreq = 1.0 / float(sampleRate)
# Number of samples in one unit of time. round() guards against a floating
# point quotient such as 499.999... being truncated to one sample too few.
n_timesteps = int(round(1.0 / sampleFreq))
n_features = len(colNames)
# A single output unit: the task is binary (normal vs. not normal).
n_classes = 1

if method.lower() == 'dl':
    # Read the previously saved collection of per-record dictionaries from disk
    print('Loading data...')
    dataDictList = joblib.load(r'D:\PhysioNet\Training_WFDB\dataDictList.sav')

    # Tag every entry with a binary target: 1 when the first element of its
    # label vector equals 1, otherwise 0.
    # NOTE(review): presumably labels[0] is the 'Normal' flag -- confirm.
    for record in dataDictList:
        record['isNormal'] = 1 if record['labels'][0] == 1 else 0

    def assembleData(nSamplesPerPatient=1):
        """Stack the leading windows of every entry in ``dataDictList`` into one array.

        Args:
            nSamplesPerPatient: how many windows (slices along the first axis
                of each entry's ``'data'`` array) to take from every entry.

        Returns:
            A tuple ``(X, y)``. ``X`` has shape
            ``(nSamplesPerPatient * len(dataDictList), windowLength, len(colNames))``
            and ``y`` holds the entry's ``'isNormal'`` flag repeated once per
            window taken from it.
        """
        # NOTE(review): assumes every entry's 'data' array offers at least
        # nSamplesPerPatient windows, all with the same length as the first
        # entry's -- confirm against the data preparation step.
        patientCount = len(dataDictList)
        samplesPerWindow = len(dataDictList[0]['data'][0,:,0])
        X = np.empty((nSamplesPerPatient*patientCount, samplesPerWindow, len(colNames)))
        targets = []
        for idx, record in enumerate(dataDictList):
            firstRow = idx * nSamplesPerPatient
            lastRow = firstRow + nSamplesPerPatient
            X[firstRow:lastRow, :, :] = record['data'][:nSamplesPerPatient, :, :]
            # One label per window taken from this entry
            targets.extend([record['isNormal']] * nSamplesPerPatient)
        return X, np.array(targets)

    def getSimple1DCNN():
        """Build and compile the simple 1-D CNN used as the binary classifier.

        Returns:
            A tuple ``(model, name)``: the compiled model and the identifier
            string ``'Simple1DCNN'``.
        """
        cnn = Simple1DCNN()
        # n_classes is 1, so the network has a single output unit. Softmax over
        # one unit is always exactly 1.0, which makes binary cross-entropy
        # untrainable; a sigmoid yields a proper probability instead.
        # NOTE(review): assumes simpleCNN applies finalActivation to its output
        # layer -- confirm in CNN1D.
        model = cnn.simpleCNN(n_timesteps, n_features, n_classes, finalActivation='sigmoid')
        # Keyword arguments: the third positional parameter of compile() is not
        # the same across Keras versions (metrics vs. loss_weights).
        model.compile(optimizer="adam", loss="binary_crossentropy", metrics=['accuracy'])
        return model, 'Simple1DCNN'

    def scaleData(X_train, X_test):
        """Standardise both sets per feature using statistics of the training set.

        Each array is flattened so that its first two axes become rows and its
        third axis becomes the columns, scaled, and then restored to its
        original shape. The scaler is fitted on ``X_train`` only and merely
        applied to ``X_test``.

        Args:
            X_train: training array, indexed as (sample, timestep, feature).
            X_test: test array with the same layout.

        Returns:
            A tuple ``(X_train_scaled, X_test_scaled)`` with the input shapes.
        """
        def flatten(batch):
            # Collapse the first two axes so every row is one time step
            dims = batch.shape
            return np.reshape(batch, (dims[0] * dims[1], dims[2]))

        scaler = StandardScaler()
        scaledTrain = scaler.fit_transform(flatten(X_train))
        scaledTest = scaler.transform(flatten(X_test))
        return np.reshape(scaledTrain, X_train.shape), np.reshape(scaledTest, X_test.shape)

    def trainModel(model, mType, X_train, X_test, y_train, y_test):
        """Fit ``model`` and checkpoint the best epoch to ``normal_or_not_model.pkl``.

        Args:
            model: compiled Keras model to train.
            mType: model identifier string (currently unused).
            X_train, y_train: training inputs and targets.
            X_test, y_test: held-out inputs and targets, used as validation data.

        Returns:
            The same ``model`` object after training (its weights are those of
            the final epoch; the best epoch is the one saved on disk).
        """
        # Validation loss is better when it is lower, so the checkpoint must
        # track its minimum; mode='max' would keep the worst-performing epoch.
        chk = ModelCheckpoint('normal_or_not_model.pkl', monitor='val_loss', save_best_only=True, mode='min', verbose=1)
        model.fit(X_train, y_train, epochs=5,
                  batch_size=256, callbacks=[chk],
                  validation_data=(X_test, y_test))
        return model

    def evaluateModel(mType, X_test, y_test):
        """Reload the checkpointed model from disk and print its test accuracy.

        Args:
            mType: model identifier string (currently unused).
            X_test: inputs to evaluate on.
            y_test: expected targets for ``X_test``.
        """
        # fbeta is passed as a custom object so a model saved with it can be
        # deserialised.
        bestModel = load_model('normal_or_not_model.pkl', custom_objects={'fbeta':fbeta})
        results = bestModel.evaluate(X_test, y_test, verbose=True)
        # NOTE(review): assumes the second entry of the evaluation result is
        # the accuracy metric -- confirm against the compile() call.
        accuracy = results[1]
        print('Test Accuracy: %.3f' % accuracy)

elif method.lower() == 'conventional':
    # Load the data
    print('Loading data...')
    # Table read from a hard-coded local CSV file. The code below uses its
    # leading `dataLength` columns together with the 'age', 'sex' and 'Normal'
    # columns.
    dataDF = pd.read_csv(r'D:\PhysioNet\Training_WFDB\segBeatsDF.csv')

    # Number of leading dataDF columns that are fed into the PCA
    dataLength = 250
    # Number of principal components kept by the PCA
    numPrincComp = 25

    def prepareTrainAndTestSets():
        """Build PCA-reduced, standardised train/test matrices from ``dataDF``.

        The first ``dataLength`` columns of ``dataDF`` are projected onto
        ``numPrincComp`` principal components, the ``age`` and ``sex`` columns
        are appended, and the result is standardised. The PCA and the scaler
        are fitted on the training split only and written to
        ``normalOrNot_pca.sav`` and ``normalOrNot_scaler.sav``.

        Returns:
            A tuple ``(X_train, X_test, y_train, y_test)`` of NumPy arrays,
            where the targets come from the ``Normal`` column.
        """
        # Adapted from  https://stackabuse.com/implementing-pca-in-python-with-scikit-learn/
        extraCols = ['age', 'sex']

        def withExtras(components, frame):
            # Append the untouched age/sex columns to the PCA output
            return np.concatenate((components, frame[extraCols].to_numpy()), axis=1)

        # Assemble the feature table and the target column
        features = pd.concat([dataDF.iloc[:, :dataLength], dataDF[extraCols]], axis=1)
        target = dataDF['Normal']

        # Hold out 20% of the rows for testing (fixed seed for repeatability)
        trainFrame, testFrame, y_train, y_test = train_test_split(
            features, target, test_size=0.2, random_state=0)

        # Principal component analysis, fitted on the training rows only
        reducer = PCA(n_components=numPrincComp)
        trainMatrix = withExtras(reducer.fit_transform(trainFrame.iloc[:, :dataLength]), trainFrame)
        joblib.dump(reducer, 'normalOrNot_pca.sav')
        testMatrix = withExtras(reducer.transform(testFrame.iloc[:, :dataLength]), testFrame)

        # Standardisation, again fitted on the training rows only
        scaler = StandardScaler()
        trainMatrix = scaler.fit_transform(trainMatrix)
        joblib.dump(scaler, 'normalOrNot_scaler.sav')
        testMatrix = scaler.transform(testMatrix)
        return trainMatrix, testMatrix, y_train.to_numpy(), y_test.to_numpy()

    def trainSVM(X_train, y_train):
        """Fit a support-vector classifier with scikit-learn's default settings.

        Args:
            X_train: training feature matrix.
            y_train: training targets.

        Returns:
            The fitted ``svm.SVC`` instance.
        """
        print('Training the SVM...')
        # fit() hands back the estimator itself, so it can be returned directly
        return svm.SVC().fit(X_train, y_train)


##################################################################################################
if __name__ == '__main__':
    # Resolve the selected pipeline once, case-insensitively
    chosenMethod = method.lower()

    if chosenMethod == 'dl':
        # Deep-learning pipeline: assemble windows, split, scale, train the
        # CNN, then evaluate the checkpointed model.
        X, y = assembleData(nSamplesPerPatient=2)
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=0)
        X_train, X_test = scaleData(X_train, X_test)
        model, mType = getSimple1DCNN()
        trainModel(model, mType, X_train, X_test, y_train, y_test)
        evaluateModel(mType, X_test, y_test)

    elif chosenMethod == 'conventional':
        # Conventional pipeline: PCA + scaling, SVM training, accuracy report,
        # then persist the fitted classifier.
        X_train, X_test, y_train, y_test = prepareTrainAndTestSets()
        clf = trainSVM(X_train, y_train)
        y_pred = clf.predict(X_test)
        print("Accuracy:", accuracy_score(y_test, y_pred))
        joblib.dump(clf, 'normalOrNot_svm.sav')