#!/usr/bin/env python
# Copyright 2020, TATA Consultancy Services. All rights reserved.

import numpy as np, os, sys
import joblib
from tensorflow.keras.models import load_model
from get_12ECG_features import get_12ECG_features
import pickle
from sklearn.cluster import KMeans
from sklearn.preprocessing import MinMaxScaler

def _load_pickle(path):
    """Return the object unpickled from the file at *path*; the file is always closed."""
    # NOTE(review): pickle.load can execute arbitrary code. These files are
    # produced by this pipeline (address.pkl by load_12ECG_model) and must never
    # come from an untrusted source.
    with open(path, 'rb') as fh:
        return pickle.load(fh)


def run_12ECG_classifier(data,header_data,loaded_model):
    """Score one ECG recording with the loaded model and binarize the scores.

    The model directory is recovered from ``address.pkl`` in the current
    working directory (written by ``load_12ECG_model``), and the
    preprocessing/normalization parameters are read from
    ``final_model_parameters.pickle`` in that directory.

    Args:
        data: raw recording, forwarded unchanged to ``get_12ECG_features``.
        header_data: recording header, forwarded unchanged to
            ``get_12ECG_features``.
        loaded_model: model returned by ``load_12ECG_model``. Its ``predict``
            is called with three inputs: the sliced signal, the row at index 1
            of axis 1 of that signal, and the normalized hand-crafted features.

    Returns:
        (current_label, current_score, classes):
        - current_label: int array with one 0/1 label per class, from
          ``Pred_prob2Pred_class``.
        - current_score: 1-D array of per-class scores, rounded to 3 decimals
          and averaged over slices.
        - classes: the ``'classes'`` entry of the parameters pickle.

    Raises:
        ValueError: if slicing produced no slices to score.
    """
    model_dir = _load_pickle(os.path.join(os.getcwd(), 'address.pkl'))
    loaded_model_parameters = _load_pickle(
        os.path.join(model_dir, 'final_model_parameters.pickle'))

    model = loaded_model
    classes = loaded_model_parameters['classes']
    Param = loaded_model_parameters['Parameters']
    MaxVal_Per_Channel = loaded_model_parameters['MinMax_nor_max']
    MinVal_Per_Channel = loaded_model_parameters['MinMax_nor_min']
    Min_X_train_HCF = loaded_model_parameters['Min_train_HCF']
    Max_X_train_HCF = loaded_model_parameters['Max_train_HCF']
    ReshapeInputData = Param.get("ReshapeInputData")
    UniformTimelength = Param.get("TimeLength")
    Sampling_Rate_uniform = Param.get("Uniform_Sampling_Rate")
    SliceLength = Param.get("SliceLength")
    ColumnCount = SliceLength*Sampling_Rate_uniform
    nrw = Param.get("nrw")
    MinMaxNor = Param.get("MinMaxNormalization")
    CNN_LSTM = Param.get("CNN_LSTM")
    CNNet_val = Param.get("CNNet")

    DNN_ready_data, generated_features = get_12ECG_features(data, header_data, Param)
    generated_features = np.nan_to_num(generated_features)
    DNN_ready_data = np.nan_to_num(DNN_ready_data)

    DNN_ready_data, generated_features = Data_slicing(DNN_ready_data, Param, generated_features)
    if len(DNN_ready_data) == 0:
        raise ValueError("recording produced no slices to classify")

    # Min-max normalize the hand-crafted features with the training-set bounds.
    # A zero range is replaced by 1 so constant features do not divide by zero.
    hcf_range = Max_X_train_HCF - Min_X_train_HCF
    hcf_range = np.where(hcf_range == 0, 1, hcf_range)
    generated_features = (generated_features - Min_X_train_HCF) / hcf_range

    if MinMaxNor == 1:
        # Same zero-range guard as for the hand-crafted features above.
        channel_range = MaxVal_Per_Channel - MinVal_Per_Channel
        channel_range = np.where(channel_range == 0, 1, channel_range)
        DNN_ready_data = (DNN_ready_data - MinVal_Per_Channel) / channel_range
    if ReshapeInputData == 1:
        # NOTE(review): reshapes the whole array into (nrw*8) rows; this only
        # succeeds when the element count matches (e.g. a single slice) - confirm.
        DNN_ready_data = DNN_ready_data.reshape(int(nrw*8), int((Sampling_Rate_uniform*UniformTimelength)/8))
        ColumnCount = int((Sampling_Rate_uniform*UniformTimelength)/8)
        nrw = int(nrw*8)
    if CNNet_val == 1:
        # (n_slices, nrw, ColumnCount, 1): trailing channel axis for the CNN input.
        DNN_ready_data = DNN_ready_data.reshape(DNN_ready_data.shape[0], nrw, ColumnCount, 1)
    if CNN_LSTM == 1:
        # NOTE(review): transpose(1, 0) requires a 2-D array here - confirm
        # this branch is still reachable with the sliced (3-D) data.
        DNN_ready_data = DNN_ready_data.transpose(1, 0)
        n_steps, n_length = 4, 150
        n_features = nrw
        DNN_ready_data = DNN_ready_data.reshape((1, n_steps, n_length, n_features))

    # Score all slices with a single batched predict call.
    DNN_ready_dataR = DNN_ready_data[:, 1, :, :]
    current_score = model.predict([DNN_ready_data, DNN_ready_dataR, generated_features])
    current_score = np.around(current_score, decimals=3)
    # Average the per-slice scores, then shape them as a single (1, n_classes) row.
    current_score = np.mean(current_score, axis=0)
    current_score = np.transpose(np.reshape(current_score, [-1, 1]))
    current_label = Pred_prob2Pred_class(current_score).astype(int)
    current_score = np.asarray(current_score)[0]

    return current_label, current_score, classes

def load_12ECG_model(input_directory):
    """Load the trained Keras model and record which directory it came from.

    Reads ``Final_model.hdf5`` from *input_directory* with Keras
    ``load_model``. It then pickles the *input_directory* string into
    ``address.pkl`` in the current working directory;
    ``run_12ECG_classifier`` reads that file back to locate
    ``final_model_parameters.pickle``.

    Args:
        input_directory: directory containing ``Final_model.hdf5``.

    Returns:
        The loaded Keras model.
    """
    loaded_model = load_model(os.path.join(input_directory, 'Final_model.hdf5'))

    # The file handle is closed even if pickling fails.
    address_file = os.path.join(os.getcwd(), 'address.pkl')
    with open(address_file, 'wb') as output:
        pickle.dump(input_directory, output)

    return loaded_model

# def load_model_parameters(input_directory):
#     # load the model from disk 
#     f_out = 'address.sav'
#     filename = os.path.join(os.getcwd(),f_out)
#     # loaded_model = joblib.load(filename)
#     loaded_model = load_model(filename)

#     return loaded_model


def Pred_prob2Pred_class(Predicted_scores):
    """Binarize class scores by splitting them into two 1-D k-means clusters.

    Args:
        Predicted_scores: 2-D array of scores. ``run_12ECG_classifier`` passes a
            single (1, n_classes) row; after transposing, each class score
            becomes a one-feature sample.

    Returns:
        Array of the KMeans label dtype with one entry per class. An entry is 1
        when the class score lies in the cluster with the strictly larger centre
        (or in cluster 1 when the centres tie), and 0 otherwise.
    """
    samples = np.transpose(Predicted_scores)
    clustering = KMeans(n_clusters=2, random_state=45).fit(samples)
    cluster_ids = clustering.labels_
    first_centre = clustering.cluster_centers_[0][0]
    second_centre = clustering.cluster_centers_[1][0]

    # Cluster 1 is treated as "positive" unless cluster 0's centre is strictly higher.
    positive_cluster = 0 if first_centre > second_centre else 1
    return (cluster_ids == positive_cluster).astype(cluster_ids.dtype)

def Data_slicing(data,Param,generated_features):
    """Cut a multi-channel recording into consecutive fixed-width slices.

    Args:
        data: 2-D array; axis 1 (samples) is the axis that gets sliced.
        Param: mapping with "SliceLength" and "Uniform_Sampling_Rate". Their
            product is the slice width in samples.
        generated_features: per-recording feature vector; one copy is stacked
            per slice so the feature batch lines up with the signal batch.

    Returns:
        (DataSliced, generated_features_repeated):
        - DataSliced: shape (n_slices, data.shape[0], slice_width).
        - generated_features_repeated: n_slices stacked copies of
          generated_features.
        Trailing samples that do not fill a whole slice are dropped. A recording
        shorter than one slice is lengthened first, so at least one slice is
        always returned.

    Raises:
        ValueError: if the slice width is not positive or data has no samples
            (either would otherwise loop forever or divide by zero).
    """
    slice_width = int(Param.get("SliceLength") * Param.get("Uniform_Sampling_Rate"))
    if slice_width <= 0:
        raise ValueError("slice width (SliceLength * Uniform_Sampling_Rate) must be positive")
    if data.shape[1] == 0:
        raise ValueError("cannot slice a recording with no samples")

    # Lengthen short recordings by repeating every sample twice (np.repeat, not
    # tiling) until one full slice fits, then trim to exactly one slice width.
    # NOTE(review): this stretches the signal in time rather than looping it -
    # confirm np.tile was not intended.
    while data.shape[1] < slice_width:
        data = np.repeat(data, 2, axis=1)
        data = data[:, :slice_width]

    # Count slices only after padding, so short recordings still yield one slice.
    slice_count = data.shape[1] // slice_width
    DataSliced = np.array(
        [data[:, i * slice_width:(i + 1) * slice_width] for i in range(slice_count)])
    generated_features_repeated = np.array([generated_features] * slice_count)
    return DataSliced, generated_features_repeated

