#!/usr/bin/env python
# Copyright 2020, TATA Consultancy Services. All rights reserved.
from __future__ import print_function
import os
import time
import numpy
import random
import tensorflow
# Seed TensorFlow's global RNG with 42 for repeatable weight initialisation.
# NOTE(review): under TF2 the compat.v1 call looks redundant with
# tensorflow.random.set_seed — confirm before removing either.
# NOTE: `seed_value` is rebound to 20 further down, so the Python/NumPy seeds
# differ from this TensorFlow seed.
seed_value = 42
seed = 42
tensorflow.compat.v1.set_random_seed(seed_value)
tensorflow.random.set_seed(seed)
# import keras
# import keras.callbacks

import numpy as np
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from Parameter import Parameter_dictionary
from scipy.io import loadmat
import pickle
# from Validation_data_preparation import Validation_Data_Preparation
from tensorflow.keras.models import load_model
import joblib
from get_12ECG_features import get_12ECG_features
from CNNet import CNNet
# from Conv_LSTM_Net import Conv_LSTM_Net

# from LSTM_model import LSTM_model
from sklearn.model_selection import train_test_split
# from sklearn.preprocessing import MinMaxScaler
# import matplotlib.pyplot as plt
# import keras.backend as K
# from evaluate_12ECG_score_loss_function import evaluate_12ECG_score_loss_function
# import keras.backend as K
# from sklearn.impute import SimpleImputer
# from sklearn.ensemble import RandomForestClassifier

# Seed the Python and NumPy RNGs with 20 (TensorFlow was seeded with 42 above).
seed_value = 20
# 1. Set `PYTHONHASHSEED` environment variable at a fixed value
# NOTE: the interpreter reads PYTHONHASHSEED only at start-up, so this affects
# child processes, not str/set hashing inside the current process.
os.environ['PYTHONHASHSEED']=str(seed_value)
# 2. Set `python` built-in pseudo-random generator at a fixed value
random.seed(seed_value)
# 3. Set `numpy` pseudo-random generator at a fixed value
np.random.seed(seed_value)
# The `random` module seeded above drives the class oversampling
# (random.sample) performed in train_12ECG_classifier.

# from skmultilearn.model_selection.measures import get_combination_wise_output_matrix

def _multi_hot_labels(header, classes):
    """Return a float multi-hot vector over ``classes`` for the ``#Dx`` codes in ``header``.

    Parsing mirrors ``get_classes`` (prefix ``#Dx``, codes after ``': '``,
    comma separated), so when ``classes`` came from ``get_classes`` over the
    same headers every code found here is present in ``classes``. A header
    without a ``#Dx`` line yields an all-zero vector instead of silently
    reusing the previous record's labels.
    """
    label_vector = np.zeros(len(classes))
    for line in header:
        if line.startswith('#Dx'):
            for code in line.split(': ')[1].split(','):
                label_vector[classes.index(code.strip())] = 1
    return label_vector


def _expand_multilabel_to_multiclass(data, features, labels):
    """Duplicate every instance once per positive label so each copy has a one-hot target.

    ``data`` and ``features`` are expanded with the same row mapping so that
    they stay aligned with the returned labels.

    Returns:
        ``(data, features, labels)`` for the expanded rows; ``labels`` is an
        ``int`` array of shape ``(n_expanded, n_classes)``.
    """
    num_classes = labels.shape[1]
    source_rows, one_hot = [], []
    for row, label_row in enumerate(labels):
        for class_index in np.flatnonzero(label_row == 1):
            target = np.zeros(num_classes, dtype=int)
            target[class_index] = 1
            source_rows.append(row)
            one_hot.append(target)
    source_rows = np.array(source_rows, dtype=int)
    one_hot = np.array(one_hot, dtype=int).reshape(len(source_rows), num_classes)
    return data[source_rows], features[source_rows], one_hot


def _oversample_indices(y_train):
    """Return training-row indices that oversample classes below the mean class count.

    Every class contributes all of its rows. A class with fewer positives than
    the per-class mean (the "fulcrum") additionally contributes rows re-drawn
    with ``random.sample`` until it reaches the fulcrum. Rows carrying several
    positive labels appear once per label. A class with no training rows
    cannot be resampled and is skipped (the original looped forever on it).
    """
    num_classes = y_train.shape[1]
    class_counts = np.sum(y_train, axis=0)
    fulcrum = int(np.round(np.sum(class_counts) / num_classes))
    resampled = []
    for class_index in range(num_classes):
        members = np.where(y_train[:, class_index] == 1)[0]
        deficit = int(fulcrum - class_counts[class_index])
        if deficit >= 1 and len(members) > 0:
            extra = []
            while deficit > 0:
                batch = min(deficit, len(members))
                extra.extend(members[random.sample(range(len(members)), batch)])
                deficit -= batch
            resampled.extend(extra)
            resampled.extend(members)
            print('class index, count,Fulcrum_point,resamples', class_index, len(members), fulcrum, len(extra))
        else:
            resampled.extend(members)
            print('class index, count,Fulcrum_point', class_index, len(members), fulcrum)
    return np.array(resampled, dtype=int)


def train_12ECG_classifier(input_directory, output_directory):
    """Load the challenge training data, preprocess it and train the network.

    Steps:
    1. Read every ``*.hea`` header and its ``.mat`` recording in
       ``input_directory``, and build multi-hot labels.
    2. Optionally keep only scoring classes.
    3. Cut the recordings into fixed-width windows with their hand-crafted
       features (HCF).
    4. Split the windows 80/20 (seeded) into train and validation sets.
    5. Normalise the data.
    6. Oversample minority classes and call ``train``.

    Args:
        input_directory: directory holding the ``.hea``/``.mat`` pairs.
        output_directory: directory where ``train`` writes the best model
            checkpoint and the pickled preprocessing parameters.
    """
    Param = Parameter_dictionary()
    nrw = Param.get("nrw")
    ReshapeInputData = Param.get("ReshapeInputData")
    MinMaxNor = Param.get("MinMaxNormalization")
    CNN_LSTM = Param.get("CNN_LSTM")
    CNNet_val = Param.get("CNNet")
    OnlyScoringClass = Param.get("TrainOnlyScoringClasses")
    MultiLabel_2_multiClass = Param.get("MultiLabel_2_multiClass")
    loss_function = ('categorical_crossentropy' if MultiLabel_2_multiClass == 1
                     else 'binary_crossentropy')

    # ---- Load headers, recordings and multi-hot labels ----
    print('Loading training and validation data...')
    # Sorted so the seeded train/validation split is reproducible
    # (os.listdir order is filesystem dependent).
    header_files = []
    for f in sorted(os.listdir(input_directory)):
        g = os.path.join(input_directory, f)
        if not f.lower().startswith('.') and f.lower().endswith('hea') and os.path.isfile(g):
            header_files.append(g)

    classes = get_classes(input_directory, header_files)
    print('Number of files read:', len(header_files))
    print('Number of classes present:', len(classes))

    recordings, headers, labels = [], [], []
    for header_file in header_files:
        recording, header = load_challenge_data(header_file)
        recordings.append(np.nan_to_num(recording))
        headers.append(header)
        labels.append(_multi_hot_labels(header, classes))
    labels = np.array(labels)

    if OnlyScoringClass == 1:
        # Filter recordings and headers together so both stay aligned with
        # the filtered labels (headers feed get_12ECG_features below).
        kept_pairs, labels, classes = OnlyScoringClassInstances(
            list(zip(recordings, headers)), labels, classes)
        recordings = [rec for rec, _ in kept_pairs]
        headers = [hdr for _, hdr in kept_pairs]
        print('Number of files remaining  after retaining labels containing scoring classes:', len(labels))
        print('Number of classes remaining  after retaining labels containing scoring classes:', len(classes))

    # ---- Feature generation and windowing ----
    print('Training and validation data preprocessing and feature generation...')
    features, DNN_ready_data, repeated_labels = [], [], []
    for recording, header, current_label in zip(recordings, headers, labels):
        tmp, generated_features = get_12ECG_features(recording, header, Param)
        DataSliced, LabelRepeat, generated_features_repeated = Data_slicing(
            tmp, Param, current_label, generated_features)
        DNN_ready_data.extend(DataSliced)
        features.extend(generated_features_repeated)
        repeated_labels.extend(LabelRepeat)

    features = np.nan_to_num(np.array(features))
    DNN_ready_data = np.array(DNN_ready_data)
    labels = np.array(repeated_labels)
    ColumnCount = DNN_ready_data.shape[2]  # samples per window

    if MultiLabel_2_multiClass == 1:
        print('MultiLabel_2_multiClass')
        DNN_ready_data, features, labels = _expand_multilabel_to_multiclass(
            DNN_ready_data, features, labels)
        print('Number of files  after multilabel to multiclass:', len(labels))

    # ---- Train / validation split (indices keep HCF rows aligned) ----
    indices = np.arange(len(DNN_ready_data))
    X_train, X_valid, y_train, y_valid, train_indices, valid_indices = train_test_split(
        DNN_ready_data, labels, indices, test_size=0.2, random_state=42)
    print('Number of training files:', len(y_train))
    print('Number of validating files:', len(y_valid))
    X_train_HCF = features[train_indices, :]
    X_valid_HCF = features[valid_indices, :]

    # Min-max normalise HCF with training-set bounds (the max was previously
    # computed with np.std, giving a meaningless and possibly negative range).
    Min_X_train_HCF = np.min(X_train_HCF, axis=0)
    Max_X_train_HCF = np.max(X_train_HCF, axis=0)
    hcf_range = Max_X_train_HCF - Min_X_train_HCF
    hcf_range[hcf_range == 0] = 1  # constant features: avoid division by zero
    X_train_HCF = (X_train_HCF - Min_X_train_HCF) / hcf_range
    X_valid_HCF = (X_valid_HCF - Min_X_train_HCF) / hcf_range
    print('Data dimension:', X_train.shape[1], X_train.shape[2])

    # ---- Optional per-channel min-max normalisation of the signal ----
    # Float arrays: the former int arrays truncated the percentile bounds.
    MaxVal_Per_Channel = np.zeros(nrw)
    MinVal_Per_Channel = np.zeros(nrw)
    if MinMaxNor == 1:
        print('Min-Max Normalization applied')
        # Robust bounds: 90th percentile of window maxima and 10th percentile
        # of window minima, computed per channel on the training set.
        train_channels = X_train[:, :nrw, :]
        MaxVal_Per_Channel = np.percentile(train_channels.max(axis=2), 90, axis=0).reshape(nrw, 1)
        MinVal_Per_Channel = np.percentile(train_channels.min(axis=2), 10, axis=0).reshape(nrw, 1)
        Range_per_Channel = MaxVal_Per_Channel - MinVal_Per_Channel
        Range_per_Channel[Range_per_Channel == 0] = 1
        X_train[:, :nrw, :] = (X_train[:, :nrw, :] - MinVal_Per_Channel) / Range_per_Channel
        X_valid[:, :nrw, :] = (X_valid[:, :nrw, :] - MinVal_Per_Channel) / Range_per_Channel

    if ReshapeInputData == 1:
        # Fold each channel into 8 rows; based on the actual window width
        # rather than TimeLength, which only matched when SliceLength == TimeLength.
        ColumnCount = int(ColumnCount / 8)
        nrw = int(nrw * 8)
        X_train = X_train.reshape(len(X_train), nrw, ColumnCount)
        X_valid = X_valid.reshape(len(X_valid), nrw, ColumnCount)

    # ---- Class-imbalance handling on the training set only ----
    Indices_resampled = _oversample_indices(y_train)
    X_train = X_train[Indices_resampled]
    y_train = y_train[Indices_resampled]
    X_train_HCF = X_train_HCF[Indices_resampled]

    # ---- Train ----
    print('Training model...')
    # time.clock() was removed in Python 3.8; perf_counter is the replacement.
    t_start = time.perf_counter()
    train(X_train, y_train, X_valid, y_valid, ColumnCount, nrw, classes, Param, output_directory,
          MaxVal_Per_Channel, MinVal_Per_Channel, CNNet_val, CNN_LSTM, loss_function,
          X_train_HCF, X_valid_HCF, Min_X_train_HCF, Max_X_train_HCF)
    print("Time elapsed for Training in Hours: ", round((time.perf_counter() - t_start) / 3600, 2))

# Load challenge data.
def load_challenge_data(header_file):
    """Read a header file and the ``.mat`` recording stored next to it.

    Args:
        header_file: path to a header file; the recording must sit beside it
            with the same stem and a ``.mat`` extension.

    Returns:
        ``(recording, header)``. ``recording`` is the ``'val'`` array of the
        ``.mat`` file as ``float64``. ``header`` is the list of header lines,
        with newlines kept.
    """
    with open(header_file, 'r') as f:
        header = f.readlines()
    # Swap only the extension: str.replace('.hea', '.mat') also rewrote any
    # '.hea' inside directory names and ignored upper-case '.HEA' suffixes.
    mat_file = os.path.splitext(header_file)[0] + '.mat'
    x = loadmat(mat_file)
    recording = np.asarray(x['val'], dtype=np.float64)
    return recording, header

# Find unique classes.
def get_classes(input_directory, filenames):
    """Collect the sorted set of diagnosis codes found on ``#Dx`` header lines.

    Args:
        input_directory: unused; kept for interface compatibility.
        filenames: paths of header files to scan.

    Returns:
        Sorted list of unique, whitespace-stripped codes taken from the text
        after ``': '`` on every line that starts with ``#Dx``.
    """
    found = set()
    for header_path in filenames:
        with open(header_path, 'r') as handle:
            dx_lines = (line for line in handle if line.startswith('#Dx'))
            for line in dx_lines:
                codes = line.split(': ')[1].split(',')
                found.update(code.strip() for code in codes)
    return sorted(found)

def OnlyScoringClassInstances(data,labels,classes):
    """Keep instances with at least one scored label, restricted to the scored classes.

    The scored diagnosis codes are read from ``scored_labels.csv`` in the
    current working directory.

    Args:
        data: indexable sequence of instances, aligned with ``labels`` rows.
        labels: 2-D multi-hot array, one column per entry of ``classes``.
        classes: class codes in label-column order.

    Returns:
        ``(data_screened, labels_screened, scoring_labels)``:
        - ``data_screened``: list of the kept instances.
        - ``labels_screened``: the kept label rows, limited to the scored
          columns (shape ``(n_kept, n_scored)``, even when nothing is kept).
        - ``scoring_labels``: the scored codes present in ``classes``, in
          ``classes`` order. A deterministic order replaces the former
          ``list(set(...))``, whose order varied between processes.
    """
    # atleast_1d: a one-line file makes loadtxt return a 0-d array, whose
    # tolist() is a plain str (and set(str) would yield its characters).
    scored_codes = set(np.atleast_1d(np.loadtxt("scored_labels.csv", dtype=str)).tolist())
    scoring_labels = [code for code in classes if code in scored_codes]
    scoring_columns = np.array([classes.index(code) for code in scoring_labels], dtype=int)

    labels = np.asarray(labels)
    keep = np.array([i for i in range(len(labels))
                     if np.sum(labels[i, scoring_columns]) >= 1], dtype=int)
    data_screened = [data[i] for i in keep]
    labels_screened = labels[keep][:, scoring_columns]
    return data_screened, labels_screened, scoring_labels

def Limit_unique_classes(data,labels):
    """Keep only the instances whose label pattern is among the most frequent ones.

    Each distinct label row (a "pattern") forms a group. The groups are
    ranked by size, largest first, and the top ``int(log(n) + sqrt(n))``
    groups are kept, where ``n`` is the number of label rows. Instances are
    emitted group by group, in original row order within each group.

    Args:
        data: indexable sequence aligned with ``labels`` rows.
        labels: 2-D label array.

    Returns:
        ``(Trimmed_data, Trimmed_Labels)`` as NumPy arrays.
    """
    n_rows = len(labels)
    groups_to_keep = int(np.log(n_rows) + np.sqrt(n_rows))

    row_groups = [np.flatnonzero(np.all(labels == pattern, axis=1))
                  for pattern in np.unique(labels, axis=0)]
    group_sizes = np.array([rows.size for rows in row_groups])
    ranked = np.argsort(group_sizes)[::-1][:groups_to_keep]

    kept_labels = []
    kept_data = []
    for group in ranked:
        rows = row_groups[group]
        kept_labels.extend(labels[rows])
        kept_data.extend(data[r] for r in rows)
    return np.array(kept_data), np.array(kept_labels)

def Data_slicing(data,Param,CurrentLabel,generated_features):
    """Cut a ``(channels, samples)`` recording into consecutive fixed-width windows.

    The window width is ``SliceLength * Uniform_Sampling_Rate`` samples. A
    recording shorter than one window is first lengthened with repeated
    ``np.repeat(..., 2, axis=1)`` calls and then truncated to exactly one
    window. Each call duplicates every sample, i.e. it stretches the signal
    in time; it does not tile it. Trailing samples that do not fill a whole
    window are dropped.

    Args:
        data: 2-D array, samples along axis 1.
        Param: mapping providing ``"SliceLength"`` and ``"Uniform_Sampling_Rate"``.
        CurrentLabel: label vector copied to every window.
        generated_features: feature vector copied to every window.

    Returns:
        ``(DataSliced, LabelRepeat, generated_features_repeated)``: arrays
        with one leading entry per window.

    Raises:
        ValueError: if the window width is not positive, or if ``data`` has
            no samples (the lengthening loop could never terminate).
    """
    # int(): a float parameter (e.g. 10.0 s) would otherwise break slicing.
    slice_width = int(Param.get("SliceLength") * Param.get("Uniform_Sampling_Rate"))
    if slice_width <= 0:
        raise ValueError("SliceLength * Uniform_Sampling_Rate must be a positive number of samples")
    if data.shape[1] == 0:
        raise ValueError("cannot slice a recording with no samples")

    if data.shape[1] < slice_width:
        while data.shape[1] < slice_width:
            data = np.repeat(data, 2, axis=1)
        data = data[:, :slice_width]

    slice_count = data.shape[1] // slice_width
    windows = [data[:, k * slice_width:(k + 1) * slice_width] for k in range(slice_count)]
    DataSliced = np.array(windows)
    LabelRepeat = np.array([CurrentLabel] * slice_count)
    generated_features_repeated = np.array([generated_features] * slice_count)
    return DataSliced, LabelRepeat, generated_features_repeated

def train(X_train, y_train, X_valid, y_valid,ncl,nrw,considered_classes,Param,output_directory,MaxVal_Per_Channel, MinVal_Per_Channel,CNNet_val,CNN_LSTM,loss_function,X_train_HCF,X_valid_HCF,Min_X_train_HCF,Max_X_train_HCF):
    """Build, compile and fit the network, then pickle the preprocessing parameters.

    Args:
        X_train, X_valid: signal windows, reshaped here to ``(n, nrw, ncl, 1)``
            for CNNet.
        y_train, y_valid: target matrices, one column per class.
        ncl: samples per window. nrw: rows (channels) per window.
        considered_classes: class codes in label-column order.
        Param: parameter dictionary, pickled for inference.
        output_directory: receives ``Final_model.hdf5`` (the best checkpoint
            by validation loss) and ``final_model_parameters.pickle``.
        MaxVal_Per_Channel, MinVal_Per_Channel: signal normalisation bounds
            (pickled).
        CNNet_val, CNN_LSTM: architecture switches; CNNet wins if both equal 1.
        loss_function: Keras loss name.
        X_train_HCF, X_valid_HCF: hand-crafted features, the third model input.
        Min_X_train_HCF, Max_X_train_HCF: HCF normalisation bounds (pickled).

    Hyper-parameters ``learning_rate`` (required) and ``clipnorm`` (default 1)
    are read from ``config_resnet.json`` in the working directory.

    Raises:
        ValueError: if neither architecture switch equals 1 (``model`` was
            previously left unbound in that case).
    """
    import json
    from tensorflow.keras.optimizers import Adam

    num_classes = len(considered_classes)
    if CNNet_val == 1:
        X_train = X_train.reshape(X_train.shape[0], nrw, ncl, 1)
        X_valid = X_valid.reshape(X_valid.shape[0], nrw, ncl, 1)
        model = CNNet(nrw, ncl, num_classes)
    elif CNN_LSTM == 1:
        # Imported here because this module does not import Conv_LSTM_Net at
        # top level; previously this branch raised NameError.
        from Conv_LSTM_Net import Conv_LSTM_Net
        X_train = X_train.transpose(0, 2, 1)
        X_valid = X_valid.transpose(0, 2, 1)
        n_steps, n_length = 4, 150
        n_features = nrw
        X_train = X_train.reshape((X_train.shape[0], n_steps, n_length, n_features))
        X_valid = X_valid.reshape((X_valid.shape[0], n_steps, n_length, n_features))
        model = Conv_LSTM_Net(n_length, n_features, num_classes)
    else:
        raise ValueError("Neither 'CNNet' nor 'CNN_LSTM' is enabled in the parameter dictionary")

    with open("config_resnet.json") as config_file:
        params = json.load(config_file)
    # `learning_rate` replaces the deprecated `lr` keyword.
    optimizer = Adam(learning_rate=params["learning_rate"],
                     clipnorm=params.get("clipnorm", 1))
    model.compile(loss=loss_function, optimizer=optimizer, metrics=['accuracy'])
    model.summary()

    # Default save_freq='epoch' is equivalent to the deprecated period=1.
    checkpoint = ModelCheckpoint(os.path.join(output_directory, "Final_model.hdf5"),
                                 monitor='val_loss', verbose=1, save_best_only=True,
                                 save_weights_only=False, mode='auto')
    early = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1, mode='auto')
    reduce_lr = tensorflow.keras.callbacks.ReduceLROnPlateau(
        factor=0.1,
        patience=2,
        min_lr=params["learning_rate"] * 0.001)

    # Second model input: row 1 of every window (presumably a single lead —
    # confirm against the model definition).
    X_trainR = X_train[:, 1, :, :]
    X_validR = X_valid[:, 1, :, :]

    model.fit([X_train, X_trainR, X_train_HCF], y_train,
              validation_data=([X_valid, X_validR, X_valid_HCF], y_valid),
              batch_size=32, epochs=200, callbacks=[checkpoint, reduce_lr, early])

    print('Saving model parameters...')
    model_parameters = {'classes': considered_classes, 'Parameters': Param,
                        'MinMax_nor_max': MaxVal_Per_Channel, 'MinMax_nor_min': MinVal_Per_Channel,
                        'Min_train_HCF': Min_X_train_HCF, 'Max_train_HCF': Max_X_train_HCF}
    with open(os.path.join(output_directory, "final_model_parameters.pickle"), 'wb') as output:
        pickle.dump(model_parameters, output)
    

# Compute modified confusion matrix for multi-class, multi-label tasks.
# def compute_modified_confusion_matrix(labels, outputs):
#     # Compute a binary multi-class, multi-label confusion matrix, where the rows
#     # are the labels and the columns are the outputs.
#     num_recordings, num_classes = K.shape(labels)
#     # A = np.zeros((num_classes, num_classes))
#     A = K.zeros((num_classes, num_classes))


#     # Iterate over all of the recordings.
#     for i in range(len(labels)):
#         # Calculate the number of positive labels and/or outputs.
#         normalization = float(max(K.sum(K.any((labels[i, :], outputs[i, :]), axis=0)), 1))
#         # Iterate over all of the classes.
#         for j in range(num_classes):
#             # Assign full and/or partial credit for each positive class.
#             if labels[i, j]:
#                 for k in range(num_classes):
#                     if outputs[i, k]:
#                         A[j, k] += 1.0/normalization

#     return A
    # ============================== Training End ===================================================#
    
# def custom_loss(y_true, y_pred):
#     weighted_values = y_true * K.abs(1-y_pred) + (1-y_true) * K.abs(y_pred)
#     loss = K.mean(weighted_values)
#     return loss
        
    # return custom_loss
    
    # custom_loss_five = custom_loss_wrapper(fn_cost=5, fp_cost=1)

#model_five.compile(loss=custom_loss_five,
             #optimizer='sgd',
             #metrics=['accuracy'])
# model.compile(optimizer='adam', loss= custom_loss_five, metrics=['accuracy'])

