#!/usr/bin/env python

import numpy as np
#import joblib
import pickle as pkl
import torch
import numpy as np
from lib.models import attention_model
from lib.data.process import label_lib
from lib.data.process import input_lib

def run_12ECG_classifier(data, header_data, loaded_model):
    """Score one 12-lead ECG recording with the loaded attention model.

    Args:
        data: Raw recording signal. It is copied (``np.copy``) before feature
            extraction, so the caller's object is not passed on directly.
        header_data: Header data for the recording; also copied.
        loaded_model: Dict as returned by ``load_12ECG_model``, with keys
            ``'model'`` (the torch module) and ``'thresholds'`` (per-class
            decision thresholds compared against the model output).

    Returns:
        Tuple ``(y_pred_binary, y_pred, classes)``:
        ``y_pred_binary`` is a flat boolean numpy array of 24 per-class
        decisions, ``y_pred`` a flat numpy array of 24 per-class scores, and
        ``classes`` the class list from ``label_lib.get_classes_and_weights()``.

        If tensor conversion, inference, or the output sanity checks fail, a
        warning is issued and the prediction falls back to the normal class
        only (score 1 / decision True at ``normal_class_idx``, 0 / False
        elsewhere). Errors raised by feature extraction are not caught.
    """
    import warnings

    num_classes = 24

    classes, _, _, normal_class_idx = label_lib.get_classes_and_weights()
    x = input_lib.get_all_feats(np.copy(data), np.copy(header_data))

    try:
        for feat in x:
            x[feat] = torch.tensor(x[feat], dtype=torch.float)

        model = loaded_model['model']
        model.eval()

        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            scores = model(x)
            decisions = scores > loaded_model['thresholds']

        y_pred = scores.cpu().numpy().reshape(-1)
        y_pred_binary = decisions.cpu().numpy().reshape(-1)

        # Explicit checks rather than `assert`: asserts are stripped under
        # `python -O`, which would let malformed output skip the fallback.
        if y_pred.shape[0] != num_classes or y_pred_binary.shape[0] != num_classes:
            raise ValueError(
                f'expected {num_classes} outputs, got scores of shape '
                f'{y_pred.shape} and decisions of shape {y_pred_binary.shape}'
            )
        if np.sum(y_pred) >= num_classes or np.sum(y_pred_binary) >= num_classes:
            raise ValueError('implausible model output: every class is active')

    except Exception as err:  # best-effort: never fail the whole record
        warnings.warn(
            f'run_12ECG_classifier: falling back to normal class ({err!r})'
        )
        y_pred = np.zeros(num_classes)
        y_pred[normal_class_idx] = 1
        # Independent boolean array, matching the dtype of the success path
        # and avoiding aliasing with y_pred.
        y_pred_binary = y_pred.astype(bool)

    return y_pred_binary, y_pred, classes

def load_12ECG_model(input_directory):
    """Build the attention model, load its trained weights and thresholds.

    Args:
        input_directory: Directory expected to contain ``model.pt``. If it is
            empty/None or has no ``model.pt``, the file is loaded from the
            current working directory instead (the previous behavior, kept
            for backward compatibility).

    Returns:
        Dict with ``'model'`` (an ``attention_model.Attention`` instance on
        CPU with the saved state dict loaded) and ``'thresholds'`` (a float32
        tensor of 24 per-class decision thresholds).
    """
    import os

    # Hyper-parameters passed to the Attention constructor; presumably these
    # must match the configuration the saved weights were trained with.
    hparams = {
        'dropout': 0.3,
        'grad_clip_norm': 1.1756550102559349,
        'lambda_1': 0.3360538188978585,
        'lambda_2': 5.093445459314259,
        'lr': 0.0001513743424918295,
        'lstm_input_dim': 512,
        'lstm_hidden_dim': 256,
        'lstm_num_layers': 1,
        'num_classes': 24,
        'use_layer_norm': True,
    }
    device = torch.device("cpu")

    model = attention_model.Attention(
        dim_header=3,
        dim_peak=5,
        dim_peak_delta=1,
        dim_fft=100 * 12,
        hparams=hparams,
        device='cpu').to(device)

    # Prefer the weights in the requested directory; fall back to the CWD,
    # which is where this function previously always looked.
    model_path = 'model.pt'
    if input_directory:
        candidate = os.path.join(input_directory, 'model.pt')
        if os.path.isfile(candidate):
            model_path = candidate

    # NOTE: torch.load unpickles the file; only load trusted model files.
    model.load_state_dict(torch.load(model_path, map_location=device))

    # Thresholds computed using Bayesian optimization.
    thresholds = torch.tensor([
        0.18695211, 0.07317427, 0.37836097, 0.83147034,
        0.81668541, 0.62888138, 0.06455703, 0.52245425, 0.22561041, 0.17002092,
        0.22300197, 0.62799837, 0.63907804, 0.17370408, 0.30089595, 0.18410993,
        0.55902293, 0.65451802, 0.12220304, 0.36191285, 0.17743886, 0.10119567,
        0.39077215, 0.19125648,
    ], dtype=torch.float)

    return {'model': model, 'thresholds': thresholds}