#!/usr/bin/env python

import os
from scipy.io import loadmat
import numpy as np
import pandas as pd
import keras
from keras.models import Sequential, load_model
from keras.layers import LSTM, GRU, TimeDistributed, Bidirectional, LeakyReLU, BatchNormalization
from keras.layers import Dense, Dropout, Activation, Flatten,  Input, Reshape, GRU, CuDNNGRU, MaxPooling1D
from keras.layers import Convolution1D, MaxPool1D, GlobalAveragePooling1D,concatenate,AveragePooling1D,Conv1D
from keras.models import Model
from sklearn.preprocessing import MinMaxScaler
from keras.models import model_from_json
from scipy.signal import resample
from sklearn.model_selection import train_test_split
from keras.callbacks import EarlyStopping, ModelCheckpoint
from helper import *

def _list_header_files(input_directory):
    """Return the sorted paths of the non-hidden header files in *input_directory*.

    A header file is any regular file whose name ends with ``hea``
    (case-insensitive) and does not start with a dot.
    """
    header_files = []
    # Sort the listing: os.listdir() order is arbitrary, and a stable order is
    # needed for the seeded train/validation split to be reproducible.
    for name in sorted(os.listdir(input_directory)):
        path = os.path.join(input_directory, name)
        lowered = name.lower()
        if not lowered.startswith('.') and lowered.endswith('hea') and os.path.isfile(path):
            header_files.append(path)
    return header_files


def _first_scored_code(header_file, scored_codes):
    """Return the first code on a ``#Dx`` line of *header_file* found in *scored_codes*.

    Returns ``None`` when the header lists no scored code.
    """
    with open(header_file, 'r') as f:
        for line in f:
            if not line.startswith('#Dx'):
                continue
            # partition() tolerates a missing space after the colon, where
            # split(': ')[1] would raise IndexError.
            _, _, listed = line.partition(':')
            for candidate in listed.split(','):
                candidate = candidate.strip()
                if candidate in scored_codes:
                    return candidate
    return None


def train_12ECG_classifier(input_directory, output_directory):
    """Train the classifier on the headers in *input_directory* and save it.

    Each header file is labelled with the first of its ``#Dx`` codes that
    appears in the ``SNOMED CT Code`` column of ``dx_mapping_scored.csv``
    (read from the current working directory); headers without any scored
    code are skipped. The labelled files are split 90/10 into training and
    validation sets, a model built by ``ResNet_model`` is fitted with the
    batch ``generator`` from ``helper``, the best checkpoint is written to
    ``<output_directory>/check_model.h5`` and the final model is stored via
    ``save_the_model``.

    Args:
        input_directory: Directory containing the ``.hea`` header files.
        output_directory: Directory receiving the checkpoint and saved model;
            created if it does not exist.

    Raises:
        ValueError: If no header in *input_directory* carries a scored code.
    """
    print('Loading data...')

    header_files = _list_header_files(input_directory)

    df = pd.read_csv('dx_mapping_scored.csv')
    codes = [str(code) for code in df['SNOMED CT Code'].values]
    # Map each code to its first position in the CSV (same result as
    # codes.index) so membership tests and lookups are O(1).
    code_to_index = {}
    for index, code in enumerate(codes):
        code_to_index.setdefault(code, index)

    input_files = []
    label_indices = []
    for header_file in header_files:
        code = _first_scored_code(header_file, code_to_index)
        if code is not None:
            input_files.append(header_file)
            label_indices.append(code_to_index[code])

    if not input_files:
        raise ValueError(
            'No header file with a scored diagnosis code found in {!r}'.format(input_directory))

    # Fix the one-hot width to the number of scored codes; otherwise it would
    # shrink whenever the highest-indexed codes are absent from the data.
    input_labels = keras.utils.to_categorical(label_indices, num_classes=len(codes))

    X_train, X_val, y_train, y_val = train_test_split(
        input_files, input_labels, test_size=0.1, random_state=1, shuffle=True)

    # NOTE(review): the third argument (27) presumably is the number of output
    # classes and should equal len(codes) -- confirm against helper.ResNet_model.
    model = ResNet_model(2500, 12, 27, 8)
    batch_size = 16

    train_generator = generator(X_train, y_train, batch_size)
    val_generator = generator(X_val, y_val, batch_size)

    # Keras expects integer step counts; np.ceil returns a float.
    train_steps = int(np.ceil(len(X_train) / batch_size))
    val_steps = int(np.ceil(len(X_val) / batch_size))

    # ModelCheckpoint cannot write into a directory that does not exist.
    os.makedirs(output_directory, exist_ok=True)

    # NOTE(review): early stopping watches 'val_acc' while the checkpoint
    # watches 'val_loss'; 'val_acc' only exists if the model is compiled with
    # a metric of that name -- confirm against helper.ResNet_model.
    callbacks = [
        ModelCheckpoint(filepath=os.path.join(output_directory, 'check_model.h5'),
                        monitor='val_loss', save_best_only=True),
        EarlyStopping(monitor='val_acc', verbose=1, patience=5)]

    model.fit_generator(train_generator, steps_per_epoch=train_steps, epochs=100, verbose=1,
                        validation_data=val_generator, validation_steps=val_steps,
                        callbacks=callbacks)

    save_the_model(output_directory, model)



    # header_files = []
    # for f in os.listdir(input_directory):
    #     g = os.path.join(input_directory, f)
    #     if not f.lower().startswith('.') and f.lower().endswith('hea') and os.path.isfile(g):
    #         # print (g)
    #         header_files.append(g)


    # classes = get_classes(input_directory, header_files)
    # num_classes = len(classes)
    # num_files = len(header_files)
    # recordings = list()
    # headers = list()

    # print (classes)
    # print('Loading data...')

    # df = pd.read_csv('input_files.csv',header=None)
    # input_files = list(df.iloc[:,0])
    # print (len(input_files))

    # df = pd.read_csv('classes.csv',header=None)
    # classes = list(df.iloc[:,0])
    # print (classes)

    # final_labels = pd.read_csv('final_labels.csv',header=None)
    # final_labels = final_labels.values[:,:-1]
    # print (final_labels.shape)

    # X_train, X_val, y_train, y_val = train_test_split(input_files, final_labels, test_size=0.5, random_state=1, shuffle=True)


# write_history(output_directory, history.history)

# def load_challenge_data(filename):
#     x = loadmat(filename)
#     data = np.asarray(x['val'], dtype=np.float64)
#     new_file = filename.replace('.mat', '.hea')
#     input_header_file = os.path.join(new_file)
#     with open(input_header_file, 'r') as f:
#         header_data = f.readlines()
#     return data, header_data

# Load challenge data.
# def load_challenge_data(header_file):
#     with open(header_file, 'r') as f:
#         header = f.readlines()
#     mat_file = header_file.replace('.hea', '.mat')
#     x = loadmat(mat_file)
#     recording = np.asarray(x['val'], dtype=np.float64)
#     return recording, header

# Find unique classes.
# def get_classes(input_directory, filenames):
#     classes = set()
#     for filename in filenames:
#         with open(filename, 'r') as f:
#             for l in f:
#                 if l.startswith('#Dx'):
#                     tmp = l.split(': ')[1].split(',')
#                     for c in tmp:
#                         classes.add(c.strip())
#     return sorted(classes)
