# -*- coding: utf-8 -*-
"""
Created on Tue Apr 13 10:46:48 2021

@author: Linschmann
"""
import os

import numpy as np

from team_code import fold,load_model,save_model,load_model,load_features_from_file,training_code, process_recordings, lead_sets, twelve_leads, six_leads, four_leads, three_leads, two_leads
from helper_code import *
from ecg_recording import DIAG_CLASSES
from evaluate_model import evaluate_model


# Per-fold directory for the model output CSVs (one subfolder per lead set).
output_path = f"outputs/fold{fold}_outputs/"

# Per-fold directory for the reference ``.hea`` header files used in evaluation.
fold_header_path = f"E:/scibo/Arbeit/Diss/CINC/2021/Test_Code_locally/test_data/fold{fold}_header/"

def fold_extraction(root='E:/scibo/Arbeit/Diss/CINC/2021/', data_sets=None,
                    output_dir="feature_extraction_result", num_folds=5, seed=None):
    """Pool pre-extracted features of several datasets and write k-fold splits.

    For every dataset, ``<root>Datasets/<dataset>features.npz`` is loaded
    (keys ``data``, ``labels`` and ``indices``). Recordings whose label row
    sums to less than one (no scored class) are dropped. The remaining
    recordings of all datasets are pooled, shuffled and cut into
    ``num_folds`` contiguous validation folds. For fold ``i`` (1-based) this
    writes ``<output_dir>/Train/features_fold<i>_train.npz`` (all other
    recordings) and ``<output_dir>/Test/features_fold<i>_test.npz`` (the fold
    itself).

    Args:
        root: Base directory (with trailing slash) that contains ``Datasets/``.
        data_sets: Dataset name prefixes to pool. Defaults to the eight
            training sets this script has always used.
        output_dir: Directory that receives the ``Train``/``Test`` subfolders
            (created if missing).
        num_folds: Number of cross-validation folds.
        seed: Seed for the shuffle. ``None`` gives a different split on
            every call, as before.

    Raises:
        ValueError: If ``data_sets`` is empty or ``num_folds`` is below 1.
    """
    if data_sets is None:
        data_sets = ['WFDB_Ningbo', 'WFDB_ChapmanShaoxing', 'WFDB_CPSC2018', 'WFDB_CPSC2018_2',
                     'WFDB_Ga', 'WFDB_PTB', 'WFDB_PTBXL', 'WFDB_StPetersburg']
    if not data_sets:
        raise ValueError("data_sets must name at least one dataset")
    if num_folds < 1:
        raise ValueError("num_folds must be at least 1")

    feature_parts = []
    label_parts = []
    indices = None
    for dataset in data_sets:
        # The context manager closes the underlying .npz file handle.
        with np.load(root + 'Datasets/' + dataset + 'features.npz') as npz:
            labels_temp = npz['labels']
            features_temp = npz['data']
            # NOTE(review): only the last dataset's ``indices`` survive this
            # loop and are saved with every fold. This assumes all datasets
            # share the same per-lead feature layout; confirm.
            indices = npz['indices']
        # Remove recordings with no label among the scored labels
        # (same test as the former per-row ``sum(label) < 1``).
        keep = ~(np.sum(labels_temp, axis=1) < 1)
        label_parts.append(labels_temp[keep])
        feature_parts.append(features_temp[keep])

    labels = np.concatenate(label_parts, axis=0)
    features = np.concatenate(feature_parts, axis=0)

    # Randomly shuffle features/labels with one shared permutation so that
    # rows stay paired.
    rng = np.random.default_rng(seed)
    order = rng.permutation(labels.shape[0])
    labels = labels[order]
    data = features[order]

    os.makedirs(output_dir + "/Train", exist_ok=True)
    os.makedirs(output_dir + "/Test", exist_ok=True)

    # array_split spreads the remainder over the first folds, so every
    # recording lands in exactly one validation fold. The former
    # floor(n / 5) fold length never validated the last n % 5 recordings.
    folds = np.array_split(np.arange(data.shape[0]), num_folds)
    for i, val_idx in enumerate(folds, start=1):
        data_val = data[val_idx]
        labels_val = labels[val_idx]
        data_train = np.delete(data, val_idx, axis=0)
        labels_train = np.delete(labels, val_idx, axis=0)
        np.savez(output_dir + "/Train/features_fold" + str(i) + "_train.npz",
                 data=data_train, labels=labels_train, indices=indices)
        np.savez(output_dir + "/Test/features_fold" + str(i) + "_test.npz",
                 data=data_val, labels=labels_val, indices=indices)
    
    

def _write_fold_headers(ref_labels, class_array, header_path):
    """Write one ``recording<i>.hea`` file per row of ``ref_labels``.

    Each file contains a single ``#Dx: `` line. The line lists, comma
    separated, the entries of ``class_array`` whose column is > 0 in that
    row. The fold-test loop treats these files as the reference labels for
    evaluation.
    """
    os.makedirs(header_path, exist_ok=True)
    for i, row in enumerate(ref_labels):
        dx = ",".join(str(class_array[j]) for j, value in enumerate(row) if value > 0)
        # ``with`` guarantees the handle is closed even if a write fails.
        with open(header_path + "recording" + str(i) + ".hea", "w") as file:
            file.write("#Dx: " + dx)


def fold_test():
    """Run the per-lead-set models of fold ``fold`` on that fold's test features.

    Loads ``feature_extraction_result/Test/features_fold<fold>_test.npz``.
    It writes one reference header per recording into ``fold_header_path``
    (once, independent of the lead set). Then, for every lead set in
    ``lead_sets``, it loads the model from ``model/fold<fold>``, selects that
    lead set's feature columns and saves one output CSV per recording under
    ``output_path/<sorted-leads>/``.
    """
    # Arrays are materialized on access, so they remain valid after the
    # NpzFile is closed.
    with np.load("feature_extraction_result/Test/features_fold" + str(fold) + "_test.npz") as loaded:
        data = loaded['data']
        ref_labels = loaded['labels']
        indices = loaded['indices']
    num_recordings = data.shape[0]
    model_directory = "model/fold" + str(fold)

    # Column j of the label matrix -> class code, taken from the keys of
    # DIAG_CLASSES. Hoisted out of the loops because it never changes.
    # NOTE(review): assumes DIAG_CLASSES insertion order matches the label
    # column order; confirm.
    class_array = np.array(list(DIAG_CLASSES.items()))[:, 0]

    # Reference headers are identical for all lead sets, so write them once.
    _write_fold_headers(ref_labels, class_array, fold_header_path)

    for leads in lead_sets:
        model = load_model(model_directory, leads)

        # Features for this lead set: columns 0 and 1 are always kept
        # (presumably lead-independent features such as age/sex; confirm),
        # plus the column range [indices[k], indices[k+1]) of each lead k.
        lead_indices = [twelve_leads.index(lead) for lead in leads]
        feature_indices = [0, 1]
        for idx in lead_indices:
            feature_indices.extend(range(indices[idx], indices[idx + 1]))
        features = data[:, feature_indices]

        lead_output_dir = output_path + '-'.join(sort_leads(leads))
        os.makedirs(lead_output_dir, exist_ok=True)
        for i in range(num_recordings):
            print(str(i+1)+ "of" + str(num_recordings))
            classes, labels, probabilities = run_model_val(model, features[i, :])
            save_outputs(lead_output_dir + "/recording" + str(i) + ".csv", i, classes, labels, probabilities)
        
def run_model_val(model, features):
    """Run a trained model on the feature vector of a single recording.

    Args:
        model: Model dict (as loaded by ``load_model`` in ``fold_test``) with
            the keys ``'classes'``, ``'imputer'``, ``'scaler'`` and
            ``'classifier'``.
        features: Feature vector of one recording. It is reshaped into a
            single row before prediction.

    Returns:
        ``(classes, labels, probabilities)``: the model's class list, the
        predicted labels of the recording as a 1-D integer array, and the
        flattened output of ``classifier.predict_proba``.
    """
    classes = model['classes']
    imputer = model['imputer']
    classifier = model['classifier']
    scaler = model['scaler']

    features = features.reshape(1, -1)
    # Scale first, then impute missing values. NOTE(review): this order must
    # mirror the preprocessing used at training time; confirm before changing.
    features = scaler.transform(features)
    features = imputer.transform(features)

    # Builtin ``int`` replaces the ``np.int`` alias, which was removed in
    # NumPy 1.24 (it was always just ``int``).
    labels = np.asarray(classifier.predict(features), dtype=int)[0]
    # NOTE(review): ``.flatten()`` assumes predict_proba returns a single
    # array. A multi-output scikit-learn forest returns a list of per-class
    # (n_samples, 2) arrays instead; confirm the classifier type.
    probabilities = classifier.predict_proba(features).flatten()

    return classes, labels, probabilities
     
     
     
def generate_features():
    """Call ``process_recordings`` on each training dataset folder under ``../Datasets/``."""
    base_dir = '../Datasets/'
    for name in ('WFDB_Ningbo', 'WFDB_ChapmanShaoxing', 'WFDB_CPSC2018', 'WFDB_CPSC2018_2',
                 'WFDB_Ga', 'WFDB_PTB', 'WFDB_PTBXL', 'WFDB_StPetersburg'):
        print(f'processing dataset{name}')
        process_recordings(f'{base_dir}{name}')
         
     
     
     
if __name__ == '__main__':
    # Preprocessing pipeline: extract features for every dataset, then pool
    # them and write the cross-validation fold files.
    pipeline = (generate_features, fold_extraction)
    for stage in pipeline:
        stage()

    # Training, per-fold testing and evaluation are currently disabled.
    # Uncomment the lines below to run them for the fold given by ``fold``.
    #model_directory = "model/fold" + str(fold) 
    #data_directory = "../Datasets/Train_Full"
    
    #model_directory = "model/test"
    #training_code(data_directory, model_directory)
    #fold_test()
    
    # Evaluation of outputs
    #for leads in lead_sets:
    #    lead = '-'.join(sort_leads(leads)) 
    #    classes, auroc, auprc, auroc_classes, auprc_classes, accuracy, f_measure, f_measure_classes, challenge_metric = evaluate_model('test_data/fold' + str(fold) +'_header', 'outputs/fold' + str(fold) +'_outputs/' + lead)
    #    output_string = 'AUROC,AUPRC,Accuracy,F-measure,Challenge metric\n{:.3f},{:.3f},{:.3f},{:.3f},{:.3f}'.format(auroc, auprc, accuracy, f_measure, challenge_metric)
    #    class_output_string = 'Classes,{}\nAUROC,{}\nAUPRC,{}\nF-measure,{}'.format(
    #        ','.join('|'.join(sorted(x)) for x in classes),
    #        ','.join('{:.3f}'.format(x) for x in auroc_classes),
    #        ','.join('{:.3f}'.format(x) for x in auprc_classes),
    #        ','.join('{:.3f}'.format(x) for x in f_measure_classes))
    #    print("==================================================")
    #    print('Evaluate: ' + lead)
    #    print(output_string)
    #    print("==================================================")

        
        
    