#!/usr/bin/env python

import numpy as np, os, sys, joblib
from scipy.io import loadmat
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestClassifier
from get_12ECG_features import get_random_forest_features, get_ensemble_features
import pandas as pd
import datetime
from DataGenerator import DataGenerator
from Classifiers.CNN1D_Nature import NatureResid1D, get_forked_model
from keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
from sklearn.preprocessing import StandardScaler
#from xgboost import XGBClassifier

# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
dataLength = 4096  # number of samples passed to the CNN model builder (nSamples)
batch_size = 64  # batch size handed to both data generators
epochs = 1000  # upper bound on epochs passed to the Keras fit call
num_leads = 12  # number of ECG channels (nChannels / rows of the scaler buffer)
trainPerc = 0.90  # fraction of the records used for training; the rest validate
scalerDataLength = 4096  # max samples per record used to fit the scaler (was int(dataLength/4))
makeScaler = False  # True: fit and save a new signal scaler; False: load scaler.sav

# Switches selecting which training stages train_12ECG_classifier runs.
doTrainCNN = True
doTrainRF = False
doTrainEnsembleClf = False

# TensorBoard log location.  Built with os.path.join so the separators match
# the host OS; on Windows this is still "logs\\train\\".  The trailing empty
# component keeps the trailing separator that the concatenation below needs.
logdir = os.path.join("logs", "train", "")
log_dir = logdir + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# Build the list of SNOMED CT codes (as strings) from the class table
def createClassList(classesDF):
    """Return the 'SNOMED CT Code' column of ``classesDF`` as a list of strings.

    Every entry is converted with ``astype(str)`` so that later membership
    tests against codes read from header text compare string to string.
    """
    code_column = classesDF['SNOMED CT Code']
    codes_as_text = code_column.astype(str)
    return codes_as_text.tolist()

# Load the class table and build the class list once, at import time.
# The CSV path is relative, so it is resolved against the current working
# directory of the process that imports this module.
classesDF = pd.read_csv('dx_mapping_scored.csv')
classList = createClassList(classesDF)

def train_12ECG_classifier(input_directory, output_directory):
    """Run the training stages enabled by the module-level switches.

    Args:
        input_directory: Directory scanned for ``*.hea`` header files; the
            matching recordings are read through ``load_challenge_data``.
        output_directory: Directory that receives the artifacts written here
            (``scaler.sav`` when ``makeScaler`` is set, ``validIDs.sav``,
            the ``finalized_model.wts`` checkpoint and ``classes.sav``).

    Stages (each guarded by its flag):
        * ``doTrainCNN``: fit or load the signal scaler, split the records
          into training/validation IDs, and train the forked CNN model with
          checkpointing, early stopping and TensorBoard logging.
        * ``doTrainRF``: delegate to ``train_random_forest_classifier``.
        * ``doTrainEnsembleClf``: delegate to ``train_ensemble_model``.

    Note:
        ``scaler.sav`` (when not rebuilt), ``rfScaler.sav`` and
        ``imputer.sav`` are loaded from the current working directory, not
        from ``output_directory``.
    """
    # Determine the classes
    print('Determining classes...')
    header_files = []
    for f in os.listdir(input_directory):
        g = os.path.join(input_directory, f)
        if not f.lower().startswith('.') and f.lower().endswith('hea') and os.path.isfile(g):
            header_files.append(g)

    classes = get_classes(input_directory, header_files, classList) # Redundant since we already know the classes of interest
    num_classes = len(classes)
    num_files = len(header_files)

    if doTrainCNN:
        if makeScaler:
            # Read the data in preparation for training the scaler
            print('Loading data...')
            # Allocate the buffer only on this path: it holds
            # num_leads x scalerDataLength x num_files float64 values and is
            # useless when an existing scaler is loaded instead.
            recordings = np.empty((num_leads, scalerDataLength * num_files))
            lastInd = 0
            for header_file in header_files:
                recording, header = load_challenge_data(header_file)
                recordingLength = min(np.size(recording, 1), scalerDataLength)
                recordings[:, lastInd:lastInd + recordingLength] = recording[:, :recordingLength]
                lastInd += recordingLength

            # JCC - Train and save the scaler
            print('Training the scaler...')
            scaler = StandardScaler()
            # Fit only on the filled part of the buffer.  Recordings shorter
            # than scalerDataLength leave the tail of the np.empty buffer
            # uninitialised, and those values must not enter the statistics.
            scaler.fit(np.transpose(recordings[:, :lastInd]))
            joblib.dump(scaler, os.path.join(output_directory, 'scaler.sav'))
        else:
            print('Loading scaler...')
            scaler = joblib.load(os.path.join(os.getcwd(), 'scaler.sav'))

        # JCC - Separate training and validation data
        # Record IDs are the header file names without the 4-character suffix.
        allIDs = [os.path.basename(x)[:-4] for x in header_files]
        randomIDs = list(np.random.permutation(len(allIDs)))
        splitInd = int(trainPerc * len(header_files))
        trainIDs = [allIDs[x] for x in randomIDs[:splitInd]]
        validIDs = [allIDs[x] for x in randomIDs[splitInd:]]

        # JCC - Save the validIDs so the held-out split can be recovered later
        joblib.dump(validIDs, os.path.join(output_directory, 'validIDs.sav'))

        # JCC - Load the Random Forest scaler
        rfScaler = joblib.load(os.path.join(os.getcwd(), 'rfScaler.sav'))

        # JCC - Load the imputer
        imputer = joblib.load(os.path.join(os.getcwd(), 'imputer.sav'))

        # JCC - Create data generators
        training_generator = DataGenerator(trainIDs, input_directory, scaler, load_challenge_data, map_class,
                                           classes, batch_size=batch_size, rfScaler=rfScaler, imputer=imputer)
        validation_generator = DataGenerator(validIDs, input_directory, scaler, load_challenge_data, map_class,
                                             classes, batch_size=batch_size, rfScaler=rfScaler, imputer=imputer)

        # Train model.
        print('Training model...')

        # Train the classifier
        # model = NatureResid1D(nClasses=num_classes)
        # model.compile(optimizer="adam", loss="binary_crossentropy",
        #               metrics = ['accuracy'])

        model = get_forked_model(nSamples=dataLength, nClasses=num_classes, nChannels=num_leads)

        # The checkpoint callback is what persists the weights: it keeps only
        # the weights of the epoch with the best validation loss.
        modelName = os.path.join(output_directory, 'finalized_model.wts')
        callbacks = [ModelCheckpoint(filepath=modelName, monitor='val_loss',
                                     save_weights_only=True, save_best_only=True),
                     EarlyStopping(monitor='val_loss', patience=3),
                     TensorBoard(log_dir=log_dir, histogram_freq=1)]

        # Train model on dataset
        model.fit_generator(generator=training_generator,
                            validation_data=validation_generator,
                            epochs=epochs,
                            callbacks=callbacks)

        # Save model.
        print('Saving model...')
        # Only the class list is written here; the weights were already
        # saved by ModelCheckpoint during training.
        joblib.dump(classes, os.path.join(output_directory, 'classes.sav'), protocol=0)

    if doTrainRF:
        # JCC - Train the Random Forest Classifier
        print('Training the Random Forest...')
        train_random_forest_classifier(input_directory, output_directory)

    if doTrainEnsembleClf:
        # JCC - Train the ensemble model
        print('Training the ensemble classifier...')
        train_ensemble_model(input_directory)

    return

def train_random_forest_classifier(input_directory, output_directory):
    """Fit the Random Forest and save it with its imputer and feature scaler.

    Args:
        input_directory: Directory scanned for ``*.hea`` header files.
        output_directory: Directory that receives ``imputer.sav``,
            ``rfScaler.sav`` and ``randomForest.sav``.

    For every record the features come from ``get_random_forest_features``
    and the label is a multi-hot vector over the classes returned by
    ``get_classes``, filled from the ``#Dx:`` header line.
    ``randomForest.sav`` holds a dict with the keys ``'model'`` and
    ``'imputer'``.
    """
    # Load data.
    print('Loading data...')

    header_files = []
    for f in os.listdir(input_directory):
        g = os.path.join(input_directory, f)
        if not f.lower().startswith('.') and f.lower().endswith('hea') and os.path.isfile(g):
            header_files.append(g)

    classes = get_classes(input_directory, header_files, classList)
    num_classes = len(classes)

    # Train model.
    print('Preparing the features and labels...')

    features = list()
    labels = list()

    # Each record is read exactly once; recordings are not kept in memory.
    for header_file in header_files:
        # load_challenge_data returns (recording, header); unpack both so the
        # feature extractor receives the signal array, not the tuple.
        recording, header = load_challenge_data(header_file)

        features.append(get_random_forest_features(recording, header))

        # Start from an all-zero label so a header without a '#Dx:' line
        # yields an empty label instead of reusing the previous record's.
        labels_act = np.zeros(num_classes)
        for l in header:
            if l.startswith('#Dx:'):
                arrs = l.strip().split(' ')
                for arr in arrs[1].split(','):
                    # Index by the mapped code: `classes` stores the values
                    # returned by map_class, which can differ from the raw
                    # code found in the header.
                    mapped = map_class(arr.strip(), classList)
                    if mapped:
                        labels_act[classes.index(mapped)] = 1
        labels.append(labels_act)

    features = np.array(features)
    labels = np.array(labels)

    # Replace NaN values with mean values
    imputer = SimpleImputer().fit(features)
    features = imputer.transform(features)
    joblib.dump(imputer, os.path.join(output_directory, 'imputer.sav'))

    # Fit and save the feature scaler.  It is only persisted here; the
    # features handed to the Random Forest below are the imputed, unscaled
    # values.
    scaler = StandardScaler()
    scaler.fit(features)
    joblib.dump(scaler, os.path.join(output_directory, 'rfScaler.sav'))

    # Train the classifier
    print('Training the Random Forest...')
    model = RandomForestClassifier(n_estimators=50, max_depth=10).fit(features, labels)

    # Save model.
    print('Saving model...')

    final_model = {'model': model, 'imputer': imputer}

    filename = os.path.join(output_directory, 'randomForest.sav')
    joblib.dump(final_model, filename, protocol=0)
    return

def train_ensemble_model(input_directory, validIDs=None):
    """Build features and labels for the ensemble model from the validation records.

    Args:
        input_directory: Directory scanned for ``*.hea`` header files.
        validIDs: Optional iterable of record IDs (header file names without
            the 4-character suffix) to restrict training to.  When omitted,
            the IDs are loaded from ``validIDs.sav`` in the current working
            directory.  NOTE(review): the CNN stage writes ``validIDs.sav``
            to its output directory; the working directory is used here
            because this function receives no output directory and the other
            saved artifacts in this module are loaded from there - confirm.

    The resulting model object is dumped to ``ensembleModel.sav`` in the
    current working directory.  The classifier itself is currently disabled,
    so the dumped object is ``None``.
    """
    # Load data.
    print('Loading data...')

    if validIDs is None:
        validIDs = joblib.load(os.path.join(os.getcwd(), 'validIDs.sav'))
    # Set membership keeps the per-file filter below cheap.
    validIDSet = set(validIDs)

    header_files = []
    for f in os.listdir(input_directory):
        g = os.path.join(input_directory, f)
        if (not f.lower().startswith('.') and f.lower().endswith('hea')
                and os.path.isfile(g) and f[:-4] in validIDSet):
            header_files.append(g)

    classes = get_classes(input_directory, header_files, classList)
    num_classes = len(classes)

    # Train model.
    print('Training model...')

    features = list()
    labels = list()

    for header_file in header_files:
        recording, header = load_challenge_data(header_file)

        features.append(get_ensemble_features(recording, header))

        # Start from an all-zero label so a header without a '#Dx:' line
        # yields an empty label instead of reusing the previous record's.
        labels_act = np.zeros(num_classes)
        for l in header:
            if l.startswith('#Dx:'):
                arrs = l.strip().split(' ')
                for arr in arrs[1].split(','):
                    # Index by the mapped code: `classes` stores the values
                    # returned by map_class, which can differ from the raw
                    # code found in the header.
                    mapped = map_class(arr.strip(), classList)
                    if mapped:
                        labels_act[classes.index(mapped)] = 1
        labels.append(labels_act)

    features = np.array(features)
    labels = np.array(labels)

    # Train the classifier
    ensembleModel = None #XGBClassifier().fit(features, labels)

    # Save model.
    print('Saving model...')

    filename = 'ensembleModel.sav'
    joblib.dump(ensembleModel, filename, protocol=0)
    return


## %%%%%%%%%%%%   %%%%%%%%%%%%%%%%%%   %%%%%%%%%%%%%%%%%%%%%%
# Load challenge data.
def load_challenge_data(header_file):
    """Load one record: the header text and the matching ``.mat`` recording.

    Args:
        header_file: Path of the header file.

    Returns:
        tuple: ``(recording, header)`` where ``recording`` is the ``'val'``
        variable of the MAT file as a float64 array and ``header`` is the
        list of header lines as returned by ``readlines``.
    """
    with open(header_file, 'r') as f:
        header = f.readlines()
    # Swap only the file extension.  A plain str.replace('.hea', '.mat')
    # would also rewrite '.hea' inside directory names and would leave an
    # upper-case '.HEA' extension untouched.
    mat_file = os.path.splitext(header_file)[0] + '.mat'
    x = loadmat(mat_file)
    recording = np.asarray(x['val'], dtype=np.float64)
    return recording, header


# Find unique classes.
def get_classes(input_directory, filenames, classList):
    """Collect the sorted, de-duplicated classes named in the given headers.

    Each file in ``filenames`` is scanned for lines starting with ``#Dx``.
    The comma-separated codes after the first ``': '`` are stripped and
    passed through ``map_class``; every truthy result is kept.
    ``input_directory`` is accepted but not used.
    """
    found = set()
    for path in filenames:
        with open(path, 'r') as handle:
            dx_lines = [line for line in handle if line.startswith('#Dx')]
        for line in dx_lines:
            raw_codes = line.split(': ')[1].split(',')
            for raw in raw_codes:
                mapped = map_class(raw.strip(), classList)  # JCC
                if mapped:
                    found.add(mapped)
    return sorted(found)

# JCC - Map the SNOMED CT code to the diagnoses
# Codes that are replaced by another code when they are not themselves in the
# class list.  NOTE(review): presumably these pairs are diagnoses that are
# treated as the same class - confirm against the class table.
_EQUIVALENT_CODES = {
    '59118001': '713427006',
    '63593006': '284470004',
    '17338001': '427172004',
}


def map_class(code, classList):
    """Map a SNOMED CT code onto the code used as its class.

    Args:
        code: Code string as read from a header.
        classList: Collection of code strings that are classes themselves.

    Returns:
        ``code`` unchanged when it is in ``classList``; otherwise its
        replacement from the equivalence table; ``None`` when neither
        applies.
    """
    if code in classList:
        return code
    return _EQUIVALENT_CODES.get(code)

##############################################################
if __name__ == '__main__':
    # No `classMap` exists in this module, so printing it raised NameError.
    # Print the class list that is built at import time instead.
    print(classList)