#!/usr/bin/env python

import numpy as np, os, sys
import joblib
from get_12ECG_features import get_12ECG_features, get_random_forest_features
from Classifiers.CNN1D_Nature import NatureResid1D, get_forked_model

def run_12ECG_classifier(data, header_data, loaded_model, *, threshold=0.5):
    """Score one recording with the forked model and binarise the scores.

    Two feature sets are built from the recording: the sequence features from
    ``get_12ECG_features`` (scaled with ``loaded_model['scaler']``) and the
    random-forest features from ``get_random_forest_features`` (imputed and
    scaled with ``loaded_model['imputer']`` / ``loaded_model['rfScaler']``).
    Both are given a leading batch axis of size 1 and fed to the model as a
    two-input list.

    Parameters
    ----------
    data, header_data
        Raw recording and its header, passed unchanged to the feature
        extractors.
    loaded_model : dict
        Mapping as returned by ``load_12ECG_model``; the keys 'model',
        'classes', 'scaler', 'imputer' and 'rfScaler' are read here.
    threshold : float, optional (keyword-only)
        A class is labelled 1 when its score is strictly greater than this
        value, otherwise 0. Defaults to 0.5, the value previously hard-coded.

    Returns
    -------
    tuple
        ``(current_label, current_score, classes)`` where ``current_label`` is
        an int ndarray of 0/1 values, ``current_score`` is the model's score
        vector for this single recording, and ``classes`` is passed through
        from ``loaded_model``.
    """
    model = loaded_model['model']
    classes = loaded_model['classes']
    scaler = loaded_model['scaler']
    imputer = loaded_model['imputer']
    rfScaler = loaded_model['rfScaler']
    # NOTE(review): loaded_model['thresholds'] (per-class optimal thresholds)
    # is loaded by load_12ECG_model but is not applied here; a single global
    # `threshold` is used for every class.

    # Sequence branch: reshape assumes a 2-D feature matrix, adding batch axis.
    inData1 = get_12ECG_features(data, header_data, scaler)
    inData1_reshape = inData1.reshape((1, inData1.shape[0], inData1.shape[1]))

    # Random-forest branch: flattened to a single (1, n_features) row.
    inData2 = get_random_forest_features(data, header_data, rfScaler=rfScaler, imputer=imputer)
    inData2_reshape = inData2.reshape(1, inData2.shape[1])

    # asarray coerces to ndarray without copying data that already is one.
    features = [np.asarray(inData1_reshape), np.asarray(inData2_reshape)]

    # predict returns a batch; take the single row for this recording.
    current_score = np.asarray(model.predict(features)[0])
    current_label = (current_score > threshold).astype(int)

    return current_label, current_score, classes

def load_12ECG_model(input_directory):
    """Assemble the artifacts consumed by ``run_12ECG_classifier``.

    The preprocessing objects ('scaler.sav', 'imputer.sav', 'rfScaler.sav')
    and 'optimalThresholds.sav' are read relative to the current working
    directory. 'classes.sav' and the network weights 'finalized_model.wts'
    are read from ``input_directory``.

    Returns
    -------
    dict
        Keys 'model', 'classes', 'scaler', 'imputer', 'rfScaler' and
        'thresholds'.
    """
    working_dir = os.getcwd()

    # Preprocessing artifacts saved next to the running script (cwd).
    preprocessors = {}
    for key, fname in (('scaler', 'scaler.sav'),
                       ('imputer', 'imputer.sav'),
                       ('rfScaler', 'rfScaler.sav')):
        preprocessors[key] = joblib.load(os.path.join(working_dir, fname))

    # Relative path: also resolved against the current working directory.
    optimal_thresholds = joblib.load('optimalThresholds.sav')

    weights_path = os.path.join(input_directory, 'finalized_model.wts')
    class_list = joblib.load(os.path.join(input_directory, 'classes.sav'))

    # Architecture is built in code; only the weights come from disk.
    network = get_forked_model()
    network.load_weights(weights_path)

    return {
        'model': network,
        'classes': class_list,
        'scaler': preprocessors['scaler'],
        'imputer': preprocessors['imputer'],
        'rfScaler': preprocessors['rfScaler'],
        'thresholds': optimal_thresholds,
    }