#!/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import datetime
import os

import keras
from sklearn.utils import compute_class_weight

from ResSeNet import ResSENet
# from new_resNet import ResSENet
# from compute_test_feature import Metrics
from data_util import *
from parameters import convlstm1_params
from read_data import ReadData, MAX_LEN

# Maximum number of epochs passed to fit_generator; the EarlyStopping
# callback used during training can end the run sooner.
MAX_EPOCHS = 30


def make_save_dir(dirname, experiment_name):
    """Create and return a timestamped run directory under ``experiment_name``.

    The directory name is the current local time formatted as
    ``%m%d-%H_%M_%S``.

    Args:
        dirname: Unused; kept so existing call sites keep working.
        experiment_name: Parent directory in which the run directory is
            created (intermediate directories are created as needed).

    Returns:
        The path of the run directory, which exists when this returns.

    Raises:
        OSError: If the directory cannot be created and does not already
            exist as a directory.
    """
    start_time = datetime.datetime.now().strftime('%m%d-%H_%M_%S')
    save_dir = os.path.join(experiment_name, start_time)
    try:
        os.makedirs(save_dir)
    except OSError:
        # Creating first and checking afterwards avoids the race between an
        # existence check and makedirs (e.g. two runs started in the same
        # second). Anything other than "already a directory" is re-raised.
        if not os.path.isdir(save_dir):
            raise
    return save_dir


def get_filename_for_saving(save_dir):
    """Return the checkpoint filename template located inside ``save_dir``.

    The returned path still contains the ``{epoch}``, ``{val_loss}``,
    ``{val_acc}``, ``{loss}`` and ``{acc}`` format placeholders; they are
    not filled in here.
    """
    checkpoint_template = '{epoch:03d}-{val_loss:.3f}-{val_acc:.3f}-{loss:.3f}-{acc:.3f}.hdf5'
    return os.path.join(save_dir, checkpoint_template)


def _balanced_class_weights(labelset):
    """Return balanced class weights computed over every non-zero label.

    All entries of ``labelset`` are pooled, entries equal to 0 are dropped
    (presumably 0 marks "no label" / padding -- TODO confirm against
    ReadData), and scikit-learn's ``'balanced'`` heuristic is applied.

    Returns:
        A dict mapping the position of each class in the sorted unique
        labels (0..k-1) to its weight. Note the keys are positions, not the
        label values themselves.
    """
    # The element order does not affect the weights, so one flatten replaces
    # stacking the columns one at a time.
    labels = np.asarray(labelset).ravel()
    labels = labels[labels != 0]
    # Keyword arguments: newer scikit-learn releases make `classes` and `y`
    # keyword-only, and older releases accept the same keywords.
    weights = compute_class_weight('balanced', classes=np.unique(labels), y=labels)
    return dict(enumerate(weights))


def train_12ECG_classifier(input_directory, output_directory):
    """Train the ResSENet classifier and save it as ``model.hdf5``.

    Loads the data set from ``input_directory``, optionally normalizes each
    recording, derives balanced class weights, splits the data with
    ``data_split(folds=5, valid_fold=1)``, trains with early stopping and
    learning-rate reduction, and writes the model to ``output_directory``.

    Args:
        input_directory: Path handed to ``ReadData.load_dataset``.
        output_directory: Directory that receives ``model.hdf5``; it is
            created if it does not exist.
    """
    # Work on a copy so update() below does not mutate the module-level
    # configuration shared with other importers of `parameters`.
    params = dict(convlstm1_params)
    val_fold = 1
    print('Loading data set...')
    read_data = ReadData(raw_label=True)
    dataset, labelset = read_data.load_dataset(path=input_directory)
    print(np.shape(dataset[0]))

    if params.get('normalization', False):
        print('normalization...')
        dataset = [meanstd_scale(ecg) for ecg in dataset]

    class_weight = _balanced_class_weights(labelset)
    train_data, train_label, dev_data, dev_label = data_split(
        dataset, labelset, folds=5, valid_fold=val_fold)

    print("Building preprocessor...")
    preproc = Preprocess_data(train_data, train_label, channels=12)
    print('Training size: ' + str(len(train_data)) + ' examples.')
    print('Dev size: ' + str(len(dev_data)) + ' examples.')

    params.update({"input_shape": [MAX_LEN, 12],
                   "num_categories": 111,
                   'class_weight': class_weight})

    network = ResSENet()
    model = network.build_network(**params)

    stopping = keras.callbacks.EarlyStopping(patience=5)
    reduce_lr = keras.callbacks.ReduceLROnPlateau(
        factor=0.1,
        patience=2,
        min_lr=params["learning_rate"] * 0.001)

    batch_size = params.get("batch_size", 32)

    train_gen = data_generator(batch_size, preproc, train_data, train_label)
    dev_gen = data_generator(batch_size, preproc, dev_data, dev_label)
    # NOTE(review): dev_x is currently unused (the Metrics callback that
    # consumed it is disabled); the call is kept in case process() has side
    # effects -- confirm and remove if it does not.
    dev_x, _ = preproc.process(dev_data, dev_label)
    model.summary()
    model.fit_generator(train_gen,
                        steps_per_epoch=len(train_data) // batch_size,
                        epochs=MAX_EPOCHS,
                        verbose=2,
                        validation_data=dev_gen,
                        validation_steps=len(dev_data) // batch_size,
                        class_weight=params['class_weight'],
                        callbacks=[reduce_lr, stopping])

    # Make sure the destination exists so saving does not fail after a full
    # training run.
    if not os.path.isdir(output_directory):
        os.makedirs(output_directory)
    model_name = 'model.hdf5'
    path = os.path.join(output_directory, model_name)
    model.save(path)

    # f1_file_name = os.path.join(save_dir, 'dev_set.txt')
    # np.savetxt(f1_file_name, np.asarray(dev_f1), fmt='%3f')
    # param_file = os.path.join(save_dir, 'params.json')
    # util.make_json(param_file, params)
    # util.plot_model_history(history, save_dir)



# import joblib
# import numpy as np
# import os
# from scipy.io import loadmat
# from sklearn.ensemble import RandomForestClassifier
# from sklearn.impute import SimpleImputer
#
# from get_12ECG_features import get_12ECG_features
#
#
# def train_12ECG_classifier(input_directory, output_directory):
#     # Load data.
#     print('Loading data...')
#
#     header_files = []
#     for f in os.listdir(input_directory):
#         g = os.path.join(input_directory, f)
#         if not f.lower().startswith('.') and f.lower().endswith('hea') and os.path.isfile(g):
#             header_files.append(g)
#
#     classes = get_classes(input_directory, header_files)
#     num_classes = len(classes)
#     num_files = len(header_files)
#     recordings = list()
#     headers = list()
#
#     for i in range(num_files):
#         recording, header = load_challenge_data(header_files[i])
#         recordings.append(recording)
#         headers.append(header)
#
#     # Train model.
#     print('Training model...')
#
#     features = list()
#     labels = list()
#
#     for i in range(num_files):
#         recording = recordings[i]
#         header = headers[i]
#
#         tmp = get_12ECG_features(recording, header)
#         features.append(tmp)
#
#         for l in header:
#             if l.startswith('#Dx:'):
#                 labels_act = np.zeros(num_classes)
#                 arrs = l.strip().split(' ')
#                 for arr in arrs[1].split(','):
#                     class_index = classes.index(arr.rstrip())  # Only use first positive index
#                     labels_act[class_index] = 1
#         labels.append(labels_act)
#
#     features = np.array(features)
#     labels = np.array(labels)
#
#     # Replace NaN values with mean values
#     imputer = SimpleImputer().fit(features)
#     features = imputer.transform(features)
#
#     # Train the classifier
#     model = RandomForestClassifier().fit(features, labels)
#
#     # Save model.
#     print('Saving model...')
#
#     final_model = {'model': model, 'imputer': imputer, 'classes': classes}
#
#     filename = os.path.join(output_directory, 'finalized_model.sav')
#     joblib.dump(final_model, filename, protocol=0)
#
#
# # Load challenge data.
# def load_challenge_data(header_file):
#     with open(header_file, 'r') as f:
#         header = f.readlines()
#     mat_file = header_file.replace('.hea', '.mat')
#     x = loadmat(mat_file)
#     recording = np.asarray(x['val'], dtype=np.float64)
#     return recording, header
#
#
# # Find unique classes.
# def get_classes(input_directory, filenames):
#     classes = set()
#     for filename in filenames:
#         with open(filename, 'r') as f:
#             for l in f:
#                 if l.startswith('#Dx'):
#                     tmp = l.split(': ')[1].split(',')
#                     for c in tmp:
#                         classes.add(c.strip())
#     return sorted(classes)
