#!/usr/bin/env python


import numpy as np, os, sys
from helper_code import *
from prep import *
from DualLSTMClass import *
import pickle
import torch

# Debug switch: when True, test_single_example prints each header together
# with the model's predicted label array before the labels are combined.
ENABLE_LOGGING=False
# Computation of measures
# attributes wrong positive predictions to all true positives equally
def compute_confusion(y_true,y_pred):
    """Accumulate a multi-label confusion matrix (rows: true class, columns: predicted class).

    For every sample that has at least one non-zero entry in ``y_true``:

    * the prediction for each true class is added to that class's diagonal cell;
    * the prediction for each non-true class is divided by the number of true
      classes of the sample and added to the cell (true class, predicted class)
      of every true class, i.e. it is shared equally between them.

    Samples without any true class contribute nothing.

    Args:
        y_true: 2-D array, one row per sample and one column per class;
            non-zero entries mark the true classes.
        y_pred: 2-D array indexed the same way, holding the predictions
            (boolean decisions or numeric scores).

    Returns:
        A ``(n_class, n_class)`` float array with the accumulated values.
    """
    n_class = y_true.shape[1]
    conf = np.zeros((n_class, n_class))
    every_class = np.arange(n_class)

    for n in range(y_true.shape[0]):
        hit = np.flatnonzero(y_true[n, :])
        if hit.size == 0:
            # No ground-truth class: nothing to attribute predictions to.
            continue
        miss = np.setdiff1d(every_class, hit)

        # Predictions on true classes go straight onto the diagonal.
        conf[hit, hit] += y_pred[n, hit]
        # Predictions on the remaining classes are split evenly across all
        # true classes of this sample (one row per true class).
        conf[np.ix_(hit, miss)] += y_pred[n, miss] / hit.size

    return conf

def save_confusion(confusion,output_dir,model_name,data_name):
    """Write a confusion matrix as CSV to ``<output_dir>/<model_name>_<data_name>.csv``.

    The first line is a header holding a blank cell followed by the keys of
    ``DIAG_CLASSES``; each following line starts with the class key of that
    row and lists the row's values formatted with four decimals.

    Args:
        confusion: 2-D array of numbers; row ``i`` is labelled with the
            ``i``-th key of ``DIAG_CLASSES``.
        output_dir: Directory the CSV file is written to.
        model_name: First part of the file name.
        data_name: Second part of the file name.
    """
    classes = list(DIAG_CLASSES.keys())

    # Collect the text line by line, header first.
    lines = [' ,' + ','.join(classes)]
    for i in range(confusion.shape[0]):
        cells = ','.join(f"{confusion[i,j]:.4f}" for j in range(confusion.shape[1]))
        lines.append(classes[i] + ',' + cells)

    # Save the table.
    output_file = os.path.join(output_dir, model_name + '_' + data_name + '.csv')
    with open(output_file, 'w') as f:
        f.write('\n'.join(lines) + '\n')

def compute_confusion_rel(y_true,y_pred):
    """Return the confusion matrix with each row divided by the class count.

    The per-class counts are the column sums of ``y_true``; they are printed
    before normalisation. Classes that never occur use a divisor of 1 so the
    division is always defined.

    Args:
        y_true: 2-D array of ground-truth labels (samples x classes).
        y_pred: 2-D array of predictions, same layout as ``y_true``.

    Returns:
        The result of ``compute_confusion`` with row ``i`` divided by the
        number of occurrences of class ``i``.
    """
    class_counts = np.sum(y_true, axis=0)
    print('occurence of classes:', class_counts)
    # Replace zero counts by one to avoid dividing by zero.
    divisors = np.where(class_counts == 0, 1, class_counts).reshape(-1, 1)
    return compute_confusion(y_true, y_pred) / divisors


# Test model.
def test_deep_model(model_directory, data_directory, output_directory):
    """Run a pickled model on a data folder and write its outputs.

    Steps: preprocess the recordings, unpickle the model, predict on all
    samples at once, print the absolute confusion matrix, save the
    row-normalised confusion matrix as CSV and write one output file per
    recording via ``save_outputs``.

    Args:
        model_directory: Path of the pickled model *file* (despite the name,
            it is opened directly with ``open(..., 'rb')``).
        data_directory: Folder handed to ``preprocess_recordings``; its last
            path component is used in the confusion-matrix file name.
        output_directory: Folder receiving all output files; created
            (including missing parents) when absent.

    Raises:
        Exception: If preprocessing yields no recordings.
    """
    # Find header and recording files.
    print('Finding header and recording files...')

    sample_len = 2048

    # NOTE(review): assumes preprocess_recordings returns samples and header
    # files in matching order, one header file per sample -- confirm in prep.
    _, patients_data, header_files = preprocess_recordings(
        data_directory, n_rand_cuts=1, fs_res=250, sample_len=sample_len,
        thresh_clean=11, validation=1.0, extract_labels=True)
    num_recordings = len(patients_data)

    model_name = model_directory
    # SECURITY: pickle.load can execute arbitrary code; only load model files
    # from a trusted source.
    with open(model_name, 'rb') as filehandle:
        wrap = pickle.load(filehandle)
    print('loaded model', model_name)

    if not num_recordings:
        raise Exception('No data was provided.')

    # Create a folder for the outputs if it does not already exist
    # (makedirs also handles nested paths).
    os.makedirs(output_directory, exist_ok=True)

    print('evaluating model ...')
    classes = list(DIAG_CLASSES.keys())
    num_classes = len(DIAG_CLASSES)

    # Stack all samples into one batch: each entry of patients_data is
    # indexed as (signal tensor, label vector).
    X_t = np.zeros((num_recordings, 2, sample_len))
    Y_t = np.zeros((num_recordings, num_classes))
    for i in range(num_recordings):
        X_t[i, :, :] = patients_data[i][0].numpy()
        Y_t[i, :] = patients_data[i][1]
    X_t = torch.Tensor(X_t)

    # Only the second value returned by predict is used here.
    _, lab = wrap.predict(X_t)
    lab = lab.numpy()
    decisions = lab > 0.5

    # Derive file-name parts robustly: ignore a trailing path separator on the
    # data folder, and drop directory and extension (of any length) from the
    # model path so the CSV lands directly in output_directory.
    data_name = os.path.basename(os.path.normpath(data_directory))
    model_stem = os.path.splitext(os.path.basename(model_name))[0]

    print(compute_confusion(Y_t, decisions))
    save_confusion(compute_confusion_rel(Y_t, decisions), output_directory,
                   'conf_' + model_stem, data_name)

    for i in range(num_recordings):
        header_file = header_files[i]

        # Load header.
        header = load_header(header_file)

        # Save model outputs next to each other, named after the header file.
        recording_id = get_recording_id(header)
        root, _ = os.path.splitext(os.path.basename(header_file))
        output_file = os.path.join(output_directory, root + '.csv')
        save_outputs(output_file, recording_id, classes, decisions[i, :], lab[i, :])

    print('Done.')
    
def combine_labels(multi_labels,classes):
    """Collapse several per-row label vectors into a single label vector.

    Every column is reduced over the rows with ``max``, except the column of
    the class keyed ``'426783006'`` (called ``normal`` here), which is reduced
    with ``min``. If that key is absent from ``classes`` all columns use
    ``max``.

    Args:
        multi_labels: 2-D tensor, one row per prediction and one column per
            class.
        classes: Mapping from class key to column index; must have as many
            entries as ``multi_labels`` has columns.

    Returns:
        1-D tensor (created with ``torch.zeros``) holding one value per class.
    """
    normal = '426783006'
    n_columns = multi_labels.shape[1]
    assert len(classes) == n_columns

    normal_column = classes.get(normal)
    combined = torch.zeros((n_columns))
    for column in range(n_columns):
        # The normal class takes the lowest value across rows; every other
        # class takes the highest.
        reducer = torch.min if column == normal_column else torch.max
        combined[column] = reducer(multi_labels[:, column])
    return combined
    
def test_single_example(model,header,recording):
    """Classify one recording and return class names, decisions and scores.

    The recording is prepared with ``prep_single``, passed to
    ``model.predict``, and the predicted label rows are merged into one
    vector by ``combine_labels``. A class is reported as present when its
    merged score is above 0.5.

    Args:
        model: Object exposing ``predict``, which returns two values; only
            the second one is used.
        header: Header of the recording, forwarded to ``prep_single``.
        recording: Signal data of the recording, forwarded to ``prep_single``.

    Returns:
        Tuple ``(classes, labels, probabilities)``: the keys of
        ``DIAG_CLASSES``, a flat boolean array and a flat score array.
    """
    class_names = list(DIAG_CLASSES.keys())

    prepared = prep_single(header, recording)
    _, predicted = model.predict(prepared)
    if ENABLE_LOGGING:
        print(str(header),str(predicted.numpy()))

    # Merge the predicted rows into one score per class, then threshold.
    probabilities = combine_labels(predicted, DIAG_CLASSES).numpy()
    labels = probabilities > 0.5
    return class_names, labels.reshape(-1), probabilities.reshape(-1)
    
def load_model_deep(model_dir,leads):
    """Load the pickled model stored as ``deep_model.pkl`` inside *model_dir*.

    Args:
        model_dir: Directory containing ``deep_model.pkl``.
        leads: Not used by this function; kept for signature compatibility.

    Returns:
        The object unpickled from the model file.

    Raises:
        OSError: If the model file cannot be opened.
    """
    model_name = os.path.join(model_dir, 'deep_model.pkl')
    # SECURITY: pickle.load can execute arbitrary code; only load model files
    # from a trusted source.
    # The context manager closes the file even when unpickling fails.
    with open(model_name, 'rb') as filehandle:
        wrap = pickle.load(filehandle)
    print('loaded model', model_name)
    return wrap

if __name__ == '__main__':
    # Command-line entry point: expects exactly three arguments after the
    # script name (model, data folder, output folder).
    if len(sys.argv) != 4:
        raise Exception('Include the model, data, and output folders as arguments, e.g., python test_model.py model data outputs.')

    model_directory, data_directory, output_directory = sys.argv[1:4]

    test_deep_model(model_directory, data_directory, output_directory)
