# -*- coding: utf-8 -*-
"""
Created on Tue Jul 28 14:51:40 2020

@author: BaumgartnerM
"""

import numpy as np
import os
import cinc_aux as aux
import pandas as pd
from keras.optimizers import Adam
from sklearn.model_selection import StratifiedKFold
from evaluate_12ECG_score import compute_challenge_metric, load_weights

# Root folder expected to contain one sub-directory per dataset.
data_directory = "C:/Data/cinc-2020"
# os.listdir returns entries in arbitrary, filesystem-dependent order, so sort
# them to make the choice of "first dataset" reproducible across machines, and
# skip stray files (archives, notes, ...) that are not dataset directories.
datasets = sorted(
    name for name in os.listdir(data_directory)
    if os.path.isdir(os.path.join(data_directory, name))
)
if not datasets:
    raise FileNotFoundError(
        "No dataset directories found in %r" % data_directory)
input_directory = os.path.join(data_directory, datasets[0])


print("Preparing data...")
x, y, dataset_id, classes = aux.prepare_data(input_directory)
# Swap axes 1 and 2 of the signal array. NOTE(review): presumably this turns
# (records, leads, samples) into (records, samples, leads) for aux.dnn --
# confirm against aux.prepare_data.
x = x.swapaxes(2, 1)

print("Organizing labels...")

# Class code handed to the challenge metric as its "normal" class.
normal_class = '426783006'
# Pairwise class-weight table consumed by load_weights.
weights_file = 'weights.csv'

# Scored-class table: row 0 is a header, and column 1 is taken as the list of
# scored class identifiers (in the order the metric and weights expect).
classes_meta = np.loadtxt('dx_mapping_scored.csv', dtype=object, delimiter=',')
full_classes = classes_meta[1:, 1].tolist()
weights = load_weights(weights_file, full_classes)

# Per-record labels used only for fold stratification below.
# NOTE(review): presumably collapses the label matrix `y` into one label per
# record so StratifiedKFold can stratify on it -- confirm in aux.decode_labels.
decoded_y = aux.decode_labels(y)


# Training schedule.
epochs = 500          # epochs trained per cross-validation fold
batch_size = 32
info_interval = 5     # how often (in epochs) to score the validation split

# Cross-validation: randomized, stratified k-fold split with a fixed seed so
# the fold assignment is reproducible between runs.
n_splits = 5
skf = StratifiedKFold(
    n_splits=n_splits,
    shuffle=True,
    random_state=42,
)

# Train and evaluate one freshly built model per cross-validation fold.
for fold, (train_index, validation_index) in enumerate(skf.split(x, decoded_y), start=1):
    x_train, x_validation = x[train_index], x[validation_index]
    y_train, y_validation = y[train_index], y[validation_index]

    print("Building model...")
    # Rebuilt inside the loop so no trained weights carry over between folds.
    model = aux.dnn(x.shape[1], x.shape[2], y.shape[1])
    # Positional Keras Adam arguments: learning rate 1e-4, beta_1 0.9.
    opt = Adam(0.0001, 0.9)
    model.compile(loss='binary_crossentropy', metrics = ['acc'], optimizer = opt)

    for e in range(epochs):
        model.fit(x_train, y_train, batch_size = batch_size, epochs = 1, verbose = 1)
        epochs_done = e + 1
        # Score after every `info_interval` completed epochs and always after
        # the last epoch, so the final trained model is evaluated too.
        if epochs_done % info_interval == 0 or epochs_done == epochs:
            # Keep the raw predictions separate: both label_mapping calls must
            # receive the same, not-yet-remapped scores.
            raw_scores = model.predict(x_validation)
            labels = aux.scores_to_label(raw_scores)
            # NOTE(review): presumably remaps columns from `classes` to the
            # `full_classes` order expected by compute_challenge_metric.
            scores, labels = aux.label_mapping(classes, raw_scores, labels)
            _, y_full = aux.label_mapping(classes, raw_scores, y_validation)
            metric = compute_challenge_metric(weights, y_full, labels, full_classes, normal_class)
            print("Fold %d/%d, epoch %d/%d, challenge metric: %s"
                  % (fold, n_splits, epochs_done, epochs, metric))
     
        