#!/usr/bin/env python

import numpy as np, os, sys, joblib
from scipy.io import loadmat
import torch
from torch.utils.data import DataLoader
import csv

from constants import encodeur, list_pathologies, dropout_cnn_final, hidden_size_final, learning_rate, epochs_number, \
    test_pathologies
from dataset import Dataset
from models import ecg_classifier
from training import train_epoch


def _header_has_selected_label(header_path):
    """Return True if any ``#Dx`` code in the header file is one we train on.

    A code is selected when it is in ``test_pathologies`` or equals
    '426783006' (NOTE(review): presumably the SNOMED-CT sinus-rhythm code
    used as the "normal" class — confirm against constants.py).
    """
    with open(header_path, 'r') as header:
        for line in header:
            if not line.startswith('#Dx'):
                continue
            # partition() tolerates a malformed '#Dx' line without ': ';
            # strip() matches get_classes and drops the newline/spaces.
            for code in line.partition(':')[2].split(','):
                code = code.strip()
                if code in test_pathologies or code == '426783006':
                    return True
    return False


def create_dataset(input_directory, filenames_csv='train_filenames.csv'):
    """Select the training records whose diagnoses are relevant to this model.

    Record names are read from the first column of ``filenames_csv``. That path
    is resolved relative to the current working directory, as before. A record
    is kept when both its ``.hea`` header and its ``.mat`` signal file exist in
    ``input_directory``, and its header lists at least one selected ``#Dx``
    code (see ``_header_has_selected_label``).

    Only the header is read here. The signal file is not loaded, since labels
    live in the header.

    Args:
        input_directory: Directory containing the ``.hea``/``.mat`` record files.
        filenames_csv: CSV listing record names, one per row. A name may
            already carry the ``.hea`` suffix.

    Returns:
        List of header file names relative to ``input_directory``. Each name
        ends in ``.hea``, names keep their CSV order, and each record appears
        at most once.
    """
    with open(filenames_csv, newline='') as csv_file:
        rows = list(csv.reader(csv_file))

    input_files = []
    for row in rows:
        if not row:
            # Blank CSV rows would otherwise raise IndexError on row[0].
            continue
        name = row[0]
        if not name.endswith('.hea'):
            name += '.hea'

        header_path = os.path.join(input_directory, name)
        if not os.path.isfile(header_path):
            continue
        # The record is unusable without its signal file; skip it the same
        # way a missing header is skipped.
        mat_path = header_path[:-len('.hea')] + '.mat'
        if not os.path.isfile(mat_path):
            continue

        if _header_has_selected_label(header_path):
            input_files.append(name)
    return input_files


def train_12ECG_classifier(input_directory, output_directory, max_length=5000, batch_size=1024):
    """Train the ECG classifier on the selected records and save its weights.

    Fits the label encoder on ``list_pathologies`` and builds the training set
    from ``create_dataset``. It then runs ``epochs_number`` epochs of
    ``train_epoch`` with Adam and a ReduceLROnPlateau scheduler. Finally it
    writes the model's ``state_dict`` to ``<output_directory>/finalized_model.pt``.

    Args:
        input_directory: Directory containing the challenge record files.
        output_directory: Directory for ``finalized_model.pt``. It is created
            if it does not exist.
        max_length: First argument to ``Dataset``. NOTE(review): presumably
            the number of samples each recording is cropped/padded to —
            confirm in dataset.py.
        batch_size: DataLoader batch size. It is also passed to the model and
            to ``train_epoch``.

    Raises:
        ValueError: If ``create_dataset`` selected no records. Without this
            check, an untrained model would be saved.
    """
    encodeur.fit(list_pathologies)
    input_files = create_dataset(input_directory)
    if not input_files:
        raise ValueError(
            "No training records selected from %r; check train_filenames.csv "
            "and the #Dx labels in the header files." % (input_directory,)
        )
    # Create the destination before training so a bad path fails fast instead
    # of after the last epoch.
    os.makedirs(output_directory, exist_ok=True)

    training_set = Dataset(max_length, len(input_files), input_files, input_directory, training=True)
    trainset_dl = DataLoader(dataset=training_set,
                             batch_size=batch_size,
                             shuffle=True,
                             pin_memory=False)

    # NOTE(review): the GRU dropout reuses dropout_cnn_final — confirm intended.
    model = ecg_classifier(dropout_cnn=dropout_cnn_final, hidden_size=hidden_size_final, dropout_gru=dropout_cnn_final,
                           batch_size=batch_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True)

    for epoch in range(epochs_number):
        # Re-enter train mode every epoch in case train_epoch switches modes.
        model.train()
        train_epoch(trainset_dl, model, lr_scheduler, optimizer, epoch, train=True, batch_size=batch_size)

    torch.save(model.state_dict(), os.path.join(output_directory, "finalized_model.pt"))


# Load challenge data.
def load_challenge_data(header_file):
    """Load one challenge record: its signal matrix and its header lines.

    The signal is read from the ``'val'`` variable of the ``.mat`` file next
    to ``header_file``. That path is ``header_file`` with its trailing
    ``.hea`` replaced by ``.mat``.

    Args:
        header_file: Path to the record's ``.hea`` header file.

    Returns:
        Tuple ``(recording, header)``:
          - ``recording`` is ``x['val']`` converted to a float64 NumPy array.
          - ``header`` is the list of header lines, with newlines kept.
    """
    with open(header_file, 'r') as f:
        header = f.readlines()
    # Swap only the trailing extension: str.replace would also rewrite any
    # '.hea' inside a parent directory name (e.g. 'my.heart/A0001.hea').
    if header_file.endswith('.hea'):
        mat_file = header_file[:-len('.hea')] + '.mat'
    else:
        mat_file = header_file
    x = loadmat(mat_file)
    recording = np.asarray(x['val'], dtype=np.float64)
    return recording, header


# Find unique classes.
def get_classes(input_directory, filenames):
    """Return the sorted, de-duplicated diagnosis codes from the given headers.

    Every line that starts with ``#Dx`` is split after its first ``': '``. The
    remainder is treated as a comma-separated list of codes, each stripped of
    surrounding whitespace.

    Args:
        input_directory: Accepted for interface compatibility but not used.
            Each entry of ``filenames`` is opened exactly as given.
        filenames: Iterable of header file paths.

    Returns:
        Sorted list of unique code strings.
    """
    found_codes = set()
    for header_path in filenames:
        with open(header_path, 'r') as header:
            dx_lines = (line for line in header if line.startswith('#Dx'))
            for dx_line in dx_lines:
                raw_codes = dx_line.split(': ')[1].split(',')
                found_codes.update(code.strip() for code in raw_codes)
    return sorted(found_codes)
