#!/usr/bin/env python
import torch
import numpy as np, os, sys
from model import *
import pickle
import scipy.signal as sig
from preprocessing import *
import re
def run_12ECG_classifier(data,header_data,loaded_model):
    """Classify one recording and return per-class labels, scores and class names.

    Args:
        data: raw recording, forwarded unchanged to ``_preprocessing``.
        header_data: header lines of the recording. Line 0 is split on
            whitespace and its third field is parsed as the integer sampling
            frequency. Line 1 must contain a ``<digits>/mV`` token; that integer
            is forwarded to ``_preprocessing`` as the amplitude scale.
        loaded_model: dict with keys ``'model'`` (callable network),
            ``'classes'`` and ``'thresholds'`` (compared elementwise against
            the scores).

    Returns:
        (current_label, current_score, classes): ``current_score`` is the
        sigmoid output of the first batch row rounded to 3 decimals, and
        ``current_label`` is ``int`` 1 where the score >= its threshold, else 0.

    Raises:
        ValueError: if header line 1 has no ``<digits>/mV`` field.
    """
    model = loaded_model['model']
    classes = loaded_model['classes']
    threshold = loaded_model['thresholds']

    # split() (rather than split(" ")) tolerates repeated spaces between fields.
    fs = int(header_data[0].split()[2])
    # Raw string: "\d" in a plain literal is an invalid escape sequence.
    match = re.search(r"(\d+)/mV", header_data[1])
    if match is None:
        raise ValueError(
            f"no '<gain>/mV' field in header line: {header_data[1]!r}")
    srange = int(match.group(1))

    data = _preprocessing(data, fs, srange)

    # Inference only: no_grad avoids building an autograd graph (saves memory/time).
    with torch.no_grad():
        y = torch.sigmoid(model(data))
    # Take the first (only) batch row as a NumPy vector of per-class scores.
    p = y.cpu().numpy()[0, :]
    # Scores are rounded *before* thresholding, so labels use the rounded values.
    p = np.round(p, 3)
    current_label = np.array(p >= threshold, dtype=int)
    current_score = p
    return current_label, current_score, classes

def _preprocessing(data, fs, scale, device="cuda:0"):
    """Turn a raw recording into a fixed-length model input tensor.

    Args:
        data: raw recording, passed to the project ``preprocessing`` helper.
        fs: sampling frequency of ``data``.
        scale: amplitude scale forwarded as ``srange``.
        device: torch device the returned tensor is moved to (default
            ``"cuda:0"``, as before).

    Returns:
        float32 tensor of shape (1, 14, 1, T) with T = 250*60*2 = 30000 on
        ``device``. The preprocessed signal is right-aligned, with zero-padding
        at the front.
    """
    # resample to ~250hz
    # NOTE(review): assumes `preprocessing` returns an array broadcastable to
    # (14, n) at ~250 Hz, per the comment above — TODO confirm.
    data = preprocessing(data, fs, srange=scale, eval=False)

    # Fixed input length: 2 minutes at 250 Hz.
    T = 250 * 60 * 2
    X = np.zeros((14, T))

    # Right-align the signal. Recordings longer than T keep only their last
    # T samples instead of failing with a broadcast error. An empty recording
    # yields an all-zero input; without this guard, "-0:" would select the
    # whole buffer and the assignment would fail.
    tail = data[..., -T:]
    n = tail.shape[-1]
    if n:
        X[:, -n:] = tail

    # (14, T) -> (1, 14, 1, T): batch axis plus a singleton axis 2, presumably
    # the layout the network expects — TODO confirm against model.NN.
    X = torch.from_numpy(X).unsqueeze(0).unsqueeze(2)
    return X.to(device).float()


def load_12ECG_model(input_directory):
    """Load the trained network and its metadata from ``input_directory``.

    Reads the checkpoint file ``<input_directory>/model``. It must be a mapping
    with keys ``"state_dict"``, ``"classes"`` and ``"thresholds"``.

    Returns:
        dict with keys ``"model"`` (an ``NN`` on ``cuda:0`` in eval mode),
        ``"classes"`` and ``"thresholds"`` (copied from the checkpoint).
    """
    filename = os.path.join(input_directory, 'model')
    device = 'cuda:0'

    # The network is built on CPU, loaded, then moved to `device` explicitly.
    model = NN(nOUT=24, device="cpu")
    # map_location="cpu": deserialize all tensors onto the CPU first, so a
    # checkpoint saved from any GPU index (or with no GPU) loads regardless of
    # which devices exist here.
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
    state = torch.load(filename, map_location="cpu")
    model.load_state_dict(state["state_dict"])
    model.to(device)
    model.eval()

    return {
        "model": model,
        "classes": state["classes"],
        "thresholds": state["thresholds"],
    }
