#!/usr/bin/env python

import numpy as np
import pandas as pd
import os
import keras
from keras.layers import Input
from sklearn.model_selection import KFold
from sklearn.metrics import f1_score
from keras.callbacks import ModelCheckpoint, EarlyStopping
import get_12ECG_features as d

# Diagnosis-code lookup used by get_true_labels: every code listed on a
# header's '#Dx' line is translated to a short class abbreviation here.
# NOTE(review): the numeric keys look like SNOMED CT concept ids -- confirm
# against the dataset documentation.
# Several codes share one abbreviation (CRBBB, PAC, PVC), so those codes are
# pooled into a single output class.
diagnosis_id = {
    '270492004': 'IAVB',
    '164889003': 'AF',
    '164890007': 'AFL',
    '426627000': 'Brady',
    '713427006': 'CRBBB',
    '713426002': 'IRBBB',
    '445118002': 'LAnFB',
    '39732003': 'LAD',
    '164909002': 'LBBB',
    '251146004': 'LQRSV',
    '698252002': 'NSIVCB',
    '10370003': 'PR',
    '284470004': 'PAC',
    '427172004': 'PVC',
    '164947007': 'LPR',
    '111975006': 'LQT',
    '164917005': 'QAb',
    '47665007': 'RAD',
    '59118001': 'CRBBB',
    '427393009': 'SA',
    '426177001': 'SB',
    '426783006': 'SNR',
    '427084000': 'STach',
    '63593006': 'PAC',
    '164934002': 'TAb',
    '59931005': 'TInv',
    '17338001': 'PVC',
}

# Unique class abbreviations in sorted order. This ordering defines the
# column layout of every label vector and of the saved per-class thresholds.
classes = sorted(set(diagnosis_id.values()))

def _list_header_files(input_directory):
    """Return the sorted paths of the visible header files in *input_directory*.

    A file qualifies when its name does not start with '.', ends with 'hea'
    (case-insensitive) and is a regular file. Sorting removes any dependence
    on ``os.listdir`` order, so the seeded shuffle and fold split downstream
    are reproducible across machines.
    """
    header_files = []
    for name in sorted(os.listdir(input_directory)):
        path = os.path.join(input_directory, name)
        lowered = name.lower()
        if not lowered.startswith('.') and lowered.endswith('hea') and os.path.isfile(path):
            header_files.append(path)
    return header_files


def _tune_thresholds(y_true, probabilities, class_names,
                     thresholds=None, default_threshold=0.5):
    """Pick, per class, the probability cut-off that maximises F1.

    Args:
        y_true: table indexable by class name (e.g. a DataFrame) holding the
            0/1 ground truth for each validation record.
        probabilities: 2-D array of predicted scores, one column per entry of
            *class_names*, rows aligned with *y_true*.
        class_names: class names; defines the column order of *probabilities*
            and the order of the returned thresholds.
        thresholds: candidate cut-offs; defaults to ``np.arange(0, 1, 0.001)``.
        default_threshold: cut-off used for a class with no positive records.
            In that case every candidate scores F1 = 0 and argmax would pick
            the smallest cut-off, flagging every record as positive.

    Returns:
        1-D numpy array of chosen cut-offs, ordered like *class_names*.

    Raises:
        ValueError: if *probabilities* does not have one column per class.
    """
    if thresholds is None:
        thresholds = np.arange(0, 1, 0.001)
    probabilities = np.asarray(probabilities)
    if probabilities.ndim != 2 or probabilities.shape[1] != len(class_names):
        raise ValueError(
            'expected predictions of shape (n_samples, %d), got %s'
            % (len(class_names), probabilities.shape))

    best = []
    for col, name in enumerate(class_names):
        truth = np.asarray(y_true[name]).astype(int)
        class_scores = probabilities[:, col]
        print('type:', name)
        if not truth.any():
            best.append(default_threshold)
            print('No positive samples; using default Threshold=%.3f' % default_threshold)
        else:
            scores = [f1_score(truth, (class_scores >= t).astype('int')) for t in thresholds]
            ix = int(np.argmax(scores))  # first maximum wins on ties
            best.append(thresholds[ix])
            print('Threshold=%.3f, F-Score=%.5f' % (thresholds[ix], scores[ix]))
        print('----------------------')
        print('\n')
    return np.array(best)


def train_12ECG_classifier(input_directory, output_directory):
    """Train the classifier on one hold-out split and save per-class thresholds.

    Steps:
      1. Collect the header files in *input_directory* and build a multi-hot
         label table (one column per entry of ``classes``) keyed by record id.
      2. Shuffle the table (seed 42) and hold out the first fold of a seeded
         10-fold ``KFold`` split (seed 34) for validation.
      3. Train ``d.Inception_model`` with binary cross-entropy for up to 100
         epochs. The best model is checkpointed to
         ``<output_directory>/Final_model`` (ModelCheckpoint monitors
         val_loss by default). Training stops early after 20 epochs without
         val_loss improvement.
      4. Reload the best checkpoint, predict the validation records and, for
         each class, pick the probability cut-off with the highest F1.
      5. Save those cut-offs (ordered like ``classes``) with ``np.save`` to
         ``<output_directory>/best_threshold.npy``.

    Raises:
        ValueError: if no header files are found. KFold also raises it when
            there are fewer records than splits.
    """
    # Load data.
    print('Loading data...')
    header_files = _list_header_files(input_directory)
    if not header_files:
        raise ValueError('no header files found in %r' % (input_directory,))

    # Train model.
    print('Training model...')

    # One row per recording: a 0/1 column per class, plus the record id.
    record_ids = []
    one_hot_labels = []
    for header_file in header_files:
        recording_id, recording_labels = get_true_labels(header_file, classes)
        record_ids.append(recording_id)
        one_hot_labels.append(recording_labels)

    y_label = pd.DataFrame(one_hot_labels, columns=classes)
    y_label['id'] = record_ids
    y_label = y_label.sample(frac=1, random_state=42).reset_index(drop=True)
    labels = dict(zip(y_label.id, y_label[classes].values.astype(int)))

    # Hold-out split: only the first fold of a seeded 10-fold split is used.
    # The fold is taken straight from the generator; packing all folds with
    # np.asarray would build a ragged array, which NumPy >= 1.24 rejects.
    kf = KFold(n_splits=10, random_state=34, shuffle=True)
    train_index, val_index = next(kf.split(np.zeros((len(y_label), 1))))
    partition = {'train': list(y_label.id.iloc[train_index]),
                 'validation': list(y_label.id.iloc[val_index])}

    # Train the classifier
    model_name = 'Final_model'

    train_generator = d.DataGenerator_resample(partition['train'], labels)
    val_generator = d.DataGenerator_resample(partition['validation'], labels)

    # NOTE(review): the meaning of the argument 6 is defined in
    # get_12ECG_features.Inception_model -- confirm.
    model = d.Inception_model(6)
    model.compile(loss="binary_crossentropy",
                  optimizer=keras.optimizers.Adam(), metrics=["accuracy"])

    model_path = os.path.join(output_directory, model_name)
    checkpointer = ModelCheckpoint(model_path, verbose=1, save_best_only=True)
    es = EarlyStopping(monitor='val_loss', verbose=1, patience=20)

    model.fit_generator(generator=train_generator,
                        epochs=100,
                        validation_data=val_generator,
                        callbacks=[checkpointer, es])

    # Restore the best checkpoint instead of keeping the last epoch's weights.
    # NOTE(review): ModelCheckpoint saves the full model here (weights-only is
    # off); confirm load_weights reads that format for the installed Keras.
    model.load_weights(model_path)

    # validation: shuffle=False with batch_size=1 keeps predictions in the
    # same order as val_index.
    val_generator = d.DataGenerator_resample(partition['validation'], labels,
                                             shuffle=False, batch_size=1)
    y_vali_predict = model.predict_generator(val_generator,
                                             steps=len(partition['validation']))
    y_vali = pd.DataFrame(y_label[classes].values[val_index].astype(int),
                          columns=classes)

    best_threshold = _tune_thresholds(y_vali, y_vali_predict, classes)
    np.save(os.path.join(output_directory, 'best_threshold'), best_threshold)

def get_true_labels(label_files, classes, diagnosis_map=None):
    """Read one header file and return its record id and multi-hot label vector.

    The record id is the first space-separated token of the file's first line.
    Every line starting with '#Dx' is read as ``#Dx: code1,code2,...``. Each
    code is translated through *diagnosis_map*. When the resulting name is in
    *classes*, the matching position is set to 1. Unknown codes, and names
    outside *classes*, are ignored.

    Args:
        label_files: path to a single header file (despite the plural name).
        classes: ordered class names; defines the layout of the returned vector.
        diagnosis_map: mapping from code string to class name. Defaults to the
            module-level ``diagnosis_id`` table.

    Returns:
        ``(recording_id, recording_labels)``, where ``recording_labels`` is an
        int numpy array of length ``len(classes)``.
    """
    if diagnosis_map is None:
        diagnosis_map = diagnosis_id

    # Name -> position lookup. setdefault keeps the first occurrence, the same
    # result list.index would give.
    class_index = {}
    for i, name in enumerate(classes):
        class_index.setdefault(name, i)

    recording_labels = np.zeros(len(classes), dtype=int)

    with open(label_files, 'r') as f:
        # strip() so a first line holding only the id does not keep its '\n'.
        recording_id = f.readline().strip().split(' ')[0]
        for line in f:
            if not line.startswith('#Dx'):
                continue
            # Split on the first ':' only; this also accepts '#Dx:code' with no space.
            _, _, codes = line.partition(':')
            for code in codes.split(','):
                idx = class_index.get(diagnosis_map.get(code.strip()))
                if idx is not None:
                    recording_labels[idx] = 1

    return recording_id, recording_labels
