#!/usr/bin/env python

import numpy as np, os, sys
import joblib
import torch
from scipy import signal

from constants import treshold_23, treshold_snr, hidden_size_final, dropout_cnn_final
from dataset import resample, tensor_padding
from models import ecg_classifier
from preprocessing import preprocessing_ecg
from utils import convert_results


def run_12ECG_classifier(data, header_data, loaded_model):
    """Classify one ECG recording with a loaded model.

    The first 12 leads of ``data`` are resampled to 500 Hz, passed through
    ``preprocessing_ecg`` and ``tensor_padding`` (target length 5000), and
    fed to ``loaded_model`` as a batch of one. The model outputs are then
    thresholded: the first 23 against ``treshold_23`` and the 24th against
    ``treshold_snr``.

    Args:
        data: per-lead signal samples, indexed as ``data[lead]`` with
            ``len(data[0])`` samples per lead -- presumably an
            (n_leads, n_samples) array; TODO confirm with the caller.
        header_data: header lines; the sampling frequency is read from them
            with ``get_freq``.
        loaded_model: torch module, e.g. as returned by ``load_12ECG_model``.

    Returns:
        ``(current_label, current_score, classes)``: the 0/1 decisions and
        the raw model outputs, both passed through ``convert_results``, and
        the list of 27 class code strings. NOTE(review): presumably
        ``convert_results`` maps the 24 model outputs onto these 27 classes;
        confirm in ``utils``.

    Raises:
        IndexError: if ``data`` has fewer than 12 leads.
    """
    maxlen = 5000     # length passed to tensor_padding
    goal_freq = 500   # Hz, sampling rate the model input is resampled to
    n_leads = 12

    freq = get_freq(header_data)
    # Derive the target length from the exact duration, not whole seconds:
    # signal.resample maps the entire input onto n_target samples, so a
    # truncated duration would compress the signal in time (and yield zero
    # samples for recordings shorter than one second).
    n_target = round(len(data[0]) * goal_freq / freq)
    data_ecg = [signal.resample(data[lead], n_target) for lead in range(n_leads)]
    data_ecg = preprocessing_ecg(data_ecg, header_data)

    ecg = torch.FloatTensor(tensor_padding(data_ecg, maxlen))
    ecg = ecg.unsqueeze(0)  # add the batch dimension (batch of one)

    loaded_model.eval()
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        current_score = loaded_model(ecg)

    # Thresholds assume 24 model outputs: 23 class scores plus one SNR output.
    tresh_23 = torch.ones((1, 23)) * treshold_23
    tresh_snr = torch.ones((1, 1)) * treshold_snr
    treshold_total = torch.cat((tresh_23, tresh_snr), dim=1)
    # 1.0 where the score exceeds its threshold, 0.0 otherwise.
    current_label = (current_score > treshold_total).float()

    classes = ['164889003',
               '164890007',
               '426627000',
               '270492004',
               '713426002',
               '39732003',
               '445118002',
               '164909002',
               '164947007',
               '251146004',
               '111975006',
               '698252002',
               '284470004',
               '63593006',
               '10370003',
               '427172004',
               '17338001',
               '164917005',
               '47665007',
               '713427006',
               '59118001',
               '427393009',
               '426177001',
               '426783006',
               '427084000',
               '164934002',
               '9931005']

    current_label = convert_results(current_label.detach().numpy().tolist()[0])
    current_score = convert_results(current_score.detach().numpy().tolist()[0])

    return current_label, current_score, classes


def load_12ECG_model(input_directory):
    """Build the ECG classifier and load its trained weights from disk.

    Args:
        input_directory: directory containing ``finalized_model.pt``, a
            state dict for ``ecg_classifier``.

    Returns:
        The ``ecg_classifier`` instance (batch_size=1) with the weights
        loaded, placed in eval mode.
    """
    load_model_name = "finalized_model.pt"
    filename = os.path.join(input_directory, load_model_name)

    # NOTE(review): dropout_cnn_final is passed as both the first and the
    # third positional argument; confirm against ecg_classifier's signature
    # that the third one is not meant to be a different hyperparameter.
    model_eval = ecg_classifier(dropout_cnn_final, hidden_size_final, dropout_cnn_final, batch_size=1)
    # Inference feeds CPU tensors, so map the weights to CPU. Without this,
    # a checkpoint saved from a GPU fails to load on a CPU-only machine.
    state_dict = torch.load(filename, map_location="cpu")
    model_eval.load_state_dict(state_dict)
    model_eval.eval()

    return model_eval


def get_freq(header_data):
    """Return the sampling frequency from the header's first (record) line.

    The frequency is the third whitespace-separated field of the first line.
    An optional ``/counter-frequency`` suffix (e.g. ``"257/1000"``) is
    ignored.

    Args:
        header_data: iterable of header lines (strings).

    Returns:
        The sampling frequency as an ``int``.

    Raises:
        ValueError: if ``header_data`` is empty.
        IndexError: if the first line has fewer than three fields.
    """
    record_line = next(iter(header_data), None)
    if record_line is None:
        raise ValueError("header_data is empty; cannot read the sampling frequency")
    # split() with no argument tolerates repeated spaces/tabs and the
    # trailing newline; split(' ') would yield empty fields and shift indices.
    fields = record_line.split()
    return int(fields[2].split('/', 1)[0])
