# -*- coding: utf-8 -*-
"""
Created on Fri Jun 18 14:59:33 2021

@author: Maurice
"""

from helper_code import *
import neurokit2 as nk
from tqdm import tqdm
import h5py
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import numpy as np
from torch.utils.data import DataLoader

import torch
from sklearn.model_selection import train_test_split
from prep import *
from DualLSTMClass import *
from evaluate_model import compute_accuracy,compute_f_measure,compute_auc,compute_challenge_metric

import argparse
#import json



def _evaluate_on_validation(wrap, data_test, n_labels):
    """Print validation scores of a trained model wrapper.

    Each entry of ``data_test`` is indexed as ``entry[0]`` (signal tensor)
    and ``entry[1]`` (label vector of length ``n_labels``). The signals are
    stacked into one float32 batch and passed to ``wrap.predict``; its second
    output is thresholded at 0.5 to obtain binary decisions.

    Prints macro AUROC, macro AUPRC, accuracy and macro F-measure. An empty
    validation set is reported and skipped instead of crashing the metrics.
    """
    print('evaluating model ...')
    if len(data_test) == 0:
        print('Validation set is empty - skipping evaluation.')
        return

    # Batch of all validation signals; torch.Tensor yields float32.
    X_t = torch.Tensor(np.stack([sample[0].numpy() for sample in data_test]))
    # Label matrix sized by the real number of classes (was hard-coded to 26).
    Y_t = np.zeros((len(data_test), n_labels))
    for i, sample in enumerate(data_test):
        Y_t[i, :] = sample[1]

    _, lab = wrap.predict(X_t)
    lab = lab.numpy()
    decisions = lab > 0.5  # fixed decision threshold for every class

    macro_f_measure, _ = compute_f_measure(Y_t, decisions)
    macro_auroc, macro_auprc, _, _ = compute_auc(Y_t, lab)
    accuracy = compute_accuracy(Y_t, decisions)

    output_string = 'AUROC,AUPRC,Accuracy,F-measure\n{:.3f},{:.3f},{:.3f},{:.3f}'.format(macro_auroc, macro_auprc, accuracy, macro_f_measure)
    print('Validation Scores:', output_string)


def train_deep_model(data_dir, model_dir, parameters):
    """Train a DualAEClassificationV2Deep model, pickle it and report validation scores.

    The recordings in ``data_dir`` are preprocessed by ``preprocess_recordings``.
    It is called with ``validation=0.10``, which presumably gives a 10 %
    held-out split. The model is trained through ``DualModelWrapper``. The
    wrapper is pickled to ``os.path.join(model_dir, parameters['model_name'])``
    and then evaluated on the held-out data.

    Args:
        data_dir: directory passed to ``preprocess_recordings``.
        model_dir: directory in which the pickled model wrapper is written.
        parameters: dict with the required keys:
            ``'model_name'`` (output file name), ``'epochs'`` (int),
            ``'alpha'`` (weight on the autoencoder loss, forwarded to
            ``DualModelWrapper``) and ``'random_cuts'`` (number of random
            cuts per recording used for training).
            Optional keys:
            ``'gpu'`` (bool, default ``torch.cuda.is_available()``),
            ``'learning_rate'`` (default 1.5e-3) and
            ``'batch_size'`` (default 128).

    Raises:
        ValueError: if preprocessing produces no training samples.
    """
    import os
    import pickle

    from torch.optim.lr_scheduler import MultiStepLR

    # Seed both RNGs so that data cuts, weight init and shuffling are reproducible.
    np.random.seed(42)
    torch.manual_seed(42)

    model_name = os.path.join(model_dir, parameters['model_name'])
    epochs = parameters['epochs']
    n_random_cuts = parameters['random_cuts']
    alpha = parameters['alpha']
    gpu = parameters.get('gpu', torch.cuda.is_available())
    learning_rate = parameters.get('learning_rate', 1.5e-3)
    batch_size = parameters.get('batch_size', 128)
    sample_len = 2048

    print('loading data...')
    augmentation_types = {'dropout_burst': False, 'random_zeroing': False}
    data_train, data_test, _ = preprocess_recordings(
        data_dir,
        n_rand_cuts=n_random_cuts,
        validation=0.10,
        fs_res=250,
        sample_len=sample_len,
        augmentation_types=augmentation_types,
        thresh_clean=11,
    )
    if len(data_train) == 0:
        raise ValueError('no training samples produced from {}'.format(data_dir))

    # Number of output classes, read from the label vector of the first sample.
    n_labels = data_train[0][1].shape[0]

    dataloader_train = DataLoader(SmallDataset(data_train), batch_size=batch_size, shuffle=True)
    dataloader_val = DataLoader(SmallDataset(data_test), batch_size=batch_size, shuffle=True)

    model = DualAEClassificationV2Deep(n_labels=n_labels, seq_len=sample_len, in_channels=2,
                                       lstm_h_size=512, y_multiplier=5)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # The lr is multiplied by 0.25 at scheduler steps 20, 40, 80 and 160.
    # Presumably these are epochs; it depends on how DualModelWrapper calls
    # scheduler.step().
    scheduler = MultiStepLR(optimizer, milestones=[20, 40, 80, 160], gamma=0.25)

    # Per-class loss weights: 0.5 everywhere, sinus rhythm ('426783006') weighted more.
    weight_loss = torch.ones((1, len(DIAG_CLASSES))) * 0.5
    weight_loss[0, DIAG_CLASSES['426783006']] = 1.1
    loss2 = MHotLoss_w(weight_loss)

    wrap = DualModelWrapper(model, alpha, optimizer, scheduler=scheduler, loss2=loss2)
    print('Starting training...')
    wrap.train(epochs, dataloader_train, dataloader_val, gpu=gpu)

    # NOTE: per-class decision-threshold tuning was prototyped here and is disabled.
    # It was a grid search over 0.05..0.95 maximising macro F-measure, with the
    # result written to decision_thresholds.json. Evaluation uses a fixed 0.5
    # threshold instead.

    # Context manager guarantees the file is closed even if pickling fails.
    with open(model_name, 'wb') as filehandle:
        pickle.dump(wrap, filehandle)
    print('stored model as', model_name)

    _evaluate_on_validation(wrap, data_test, n_labels)
    

def build_arg_parser():
    """Build the command-line parser for the training script.

    Returns:
        argparse.ArgumentParser: parser with a positional ``data_dir`` and the
        options ``--model_dir`` (default '.'), ``--model_name``
        (default 'model.pkl'), ``--epochs`` (default 10), ``--alpha``
        (default 0.0) and ``--random_cuts`` (default 2).
    """
    parser = argparse.ArgumentParser(description='Train Deep Model')
    parser.add_argument('data_dir', type=str,
                        help='directory containing the recordings to train on')
    parser.add_argument('--model_dir', type=str, default='.',
                        help='directory in which the trained model is stored')
    parser.add_argument('--model_name', type=str, default='model.pkl')
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--alpha', type=float, default=0.0,
                        help='Weight on Autoencoder Loss function 0-1')
    parser.add_argument('--random_cuts', type=int, default=2,
                        help='number of random cuts of recording for training')
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    parameters = {
        'model_name': args.model_name,
        'epochs': args.epochs,
        'alpha': args.alpha,
        'random_cuts': args.random_cuts,
    }
    train_deep_model(args.data_dir, args.model_dir, parameters)
    

