#!/usr/bin/env python

import numpy as np, os, sys, joblib
from scipy.io import loadmat
from scipy.signal import butter, lfilter,resample_poly
import gc
from pathlib import Path
import math
from sklearn.svm import SVC
import tensorflow as tf
from tensorflow.keras.layers import Conv1D, MaxPool1D, Flatten, Dense, Input, BatchNormalization,GlobalMaxPool1D
from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.callbacks import ModelCheckpoint
from wfdb import processing
from functools import partial
import pickle

def run_12ECG_classifier(data, header_data, loaded_model, model_directory='model'):
    """Classify one 12-lead ECG recording.

    The recording is preprocessed into a single-item batch, passed through the
    CNN feature extractor ``loaded_model``, and the resulting features are fed
    to one pickled SVC per scored class.

    :param data: raw signal forwarded to ``preprocess_data``.
    :param header_data: header lines forwarded to ``preprocess_data``.
    :param loaded_model: feature-extractor model exposing ``predict``
        (as returned by ``load_12ECG_model``).
    :param model_directory: directory holding the ``svc_<code>.sav`` files;
        defaults to ``'model'`` (the previously hard-coded value).
    :return: ``(labels, scores, class_codes)``: an int array of per-class SVC
        predictions, a float32 array with the last ``predict_proba`` column of
        each SVC, and the class codes as strings, all in the same order.
    """
    X = preprocess_data(data, header_data)
    scored_classes = {270492004: "IAVB",
                      164889003: "AF",
                      164890007: "AFL",
                      426627000: "Brady",
                      713427006: "CRBBB",
                      713426002: "IRBBB",
                      445118002: "LAnFB",
                      39732003: "LAD",
                      164909002: "LBBB",
                      251146004: "LQRSV",
                      698252002: "NSIVCB",
                      10370003: "PR",
                      284470004: "PAC",
                      427172004: "PVC",
                      164947007: "LPR",
                      111975006: "LQT",
                      164917005: "QAb",
                      47665007: "RAD",
                      59118001: "RBBB",
                      427393009: "SA",
                      426177001: "SB",
                      426783006: "SNR",
                      427084000: "STach",
                      63593006: "SVPB",
                      164934002: "TAb",
                      59931005: "TInv",
                      17338001: "VPB"
                      }
    classes = sorted(scored_classes)
    # Keep only one representative code per set of equivalent classes.
    scored_classes, classes = equvialentClassesConversion(scored_classes, classes)

    # CNN features for the (single-recording) batch produced by preprocess_data.
    feat_test = loaded_model.predict(X)

    y_test_pred = []
    y_test_pred_score = []
    for current_class in classes:
        filename = os.path.join(model_directory, 'svc_' + str(current_class) + '.sav')
        # NOTE: pickle can execute arbitrary code; only load trusted model files.
        with open(filename, 'rb') as fh:
            svc = pickle.load(fh)
        y_test_pred.append(svc.predict(feat_test))
        # predict_proba columns follow svc.classes_ (sorted); the last column is
        # presumably the positive label. Confirm against the training code.
        y_test_pred_score.append(svc.predict_proba(feat_test)[:, -1])

    # Each entry has length 1 (one recording), so drop that axis.
    # dtype=int replaces the np.int alias, which was removed in NumPy 1.24.
    y_test_pred = np.squeeze(np.asarray(y_test_pred, dtype=int), axis=1)
    y_test_pred_score = np.squeeze(np.asarray(y_test_pred_score, dtype=np.float32), axis=1)
    str_classes = [str(c) for c in classes]

    return y_test_pred, y_test_pred_score, str_classes


def equvialentClassesConversion(scored_classes,classes):
    """Collapse each set of equivalent class codes to a single representative.

    For every equivalence set below, the members present in ``classes`` are
    collected. The first one, in the order listed in the set, is kept; the
    others are removed from ``classes`` and from ``scored_classes``. This
    function only edits the class lists. It does not merge any labels.

    Both arguments are modified in place and also returned.

    :param scored_classes: dict keyed by class code.
    :param classes: list of class codes.
    :return: ``(scored_classes, classes)`` after the removals.
    """
    equivalent_classes_collection = [[713427006, 59118001], [284470004, 63593006], [427172004, 17338001]]
    remove_classes = []
    for equivalent_classes in equivalent_classes_collection:
        present = [x for x in equivalent_classes if x in classes]
        # present[0] is the representative; everything after it is dropped.
        # Slicing yields nothing when fewer than two members are present.
        remove_classes.extend(present[1:])

    for x in remove_classes:
        classes.remove(x)
        # pop() rather than del: tolerate a code listed in ``classes`` that has
        # no entry in ``scored_classes``.
        scored_classes.pop(x, None)

    return scored_classes, classes


def bandpass_filter(data, lowcut=0.001, highcut=15.0, signal_freq=500, filter_order=1):
    """Apply a Butterworth band-pass filter via ``scipy.signal.lfilter``.

    :param data: samples to filter (filtered along the last axis).
    :param float lowcut: lower cutoff frequency in Hz.
    :param float highcut: upper cutoff frequency in Hz.
    :param signal_freq: sampling rate in Hz.
    :param int filter_order: order passed to ``butter``.
    :return: the filtered array, same shape as ``data``.
    """
    nyquist = signal_freq / 2.0
    # butter expects cutoffs normalised to the Nyquist frequency.
    band = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = butter(filter_order, band, btype="band")
    return lfilter(numerator, denominator, data)
    
def extend_ts(ts, length):
    """Return a float64 array of exactly ``length`` samples taken from ``ts``.

    Shorter inputs are zero-padded at the end; longer inputs are truncated.
    """
    out = np.zeros(length)
    n_copy = min(length, ts.shape[0])
    out[:n_copy] = ts[:n_copy]
    return out

def preprocess_data(data,header_data):
    """Turn one raw recording into the CNN input batch.

    Steps:
      1. Resample to 500 Hz if the header reports a different rate.
      2. Truncate to ``frame_len`` samples.
      3. Scale each non-flat lead to [-1, 1] with ``wfdb.processing.normalize_bound``.
      4. Zero-pad to ``frame_len`` and band-pass filter.

    :param data: signal indexed as ``data[lead, sample]``; at least 12 rows
        are expected. The caller's array is not modified.
    :param header_data: header lines. The sampling frequency is read as the
        third whitespace-separated field of the first line.
    :return: float64 array of shape ``(1, frame_len, num_leads)``.
    """
    fs = 500
    num_leads = 12
    frame_len = 15000

    # Sampling frequency: third field of the first header line.
    first_line = header_data[0].split('\n', 1)[0]
    sampled_fs = int(first_line.split()[2])

    # Work on a float copy. Otherwise the per-lead processing could write into
    # (and, for integer input, truncate) the caller's array.
    data = np.array(data, dtype=np.float64)

    # Bring the signal to the rate the model expects.
    if sampled_fs != fs:
        data = resample_poly(data, fs, sampled_fs, axis=1)

    # Longer recordings are truncated here; shorter ones are zero-padded by extend_ts.
    data = data[:, :frame_len]

    extended_data = np.zeros((num_leads, frame_len))
    for j in range(num_leads):
        lead = data[j]
        # Normalise only leads with a non-zero range. All-zero leads were
        # already skipped. A constant non-zero lead would make normalize_bound
        # divide by a zero range and produce NaNs.
        if lead.size > 0 and np.ptp(lead) > 0:
            lead = processing.normalize_bound(lead, lb=-1, ub=1)
        extended_data[j] = bandpass_filter(extend_ts(lead, length=frame_len))

    # (num_leads, frame_len) -> (1, frame_len, num_leads) for the Conv1D input.
    return np.expand_dims(extended_data.T, axis=0)


def load_12ECG_model(input_directory):
    """Build the CNN, load its trained weights and return its feature extractor.

    The returned model maps the CNN input to the output of its
    ``GlobalMaxPool1D`` layer. ``run_12ECG_classifier`` feeds these features
    to the per-class SVCs.

    :param input_directory: accepted for interface compatibility but ignored.
        It is overwritten below, so weights are always read from
        ``model/cnn_model.h5``. NOTE(review): the override is presumably
        deliberate because the caller passes some other directory. Confirm
        before honouring the argument.
    :return: a Keras ``Model`` that outputs the pooled feature vector.
    """
    # Deliberately overrides the argument (see docstring).
    input_directory = 'model'
    frame_len = 15000
    num_classes = 24
    model_cnn = create_model(frame_len, num_classes)
    model_cnn.load_weights(os.path.join(input_directory, 'cnn_model.h5'))
    # Find the pooling layer by type instead of by the auto-generated name
    # 'global_max_pooling1d'. Keras appends a suffix (e.g. '_1') when such a
    # layer was already created in the same session, so get_layer() by that
    # name fails if the model is built a second time.
    pooling_layer = next(layer for layer in model_cnn.layers
                         if isinstance(layer, GlobalMaxPool1D))
    return Model(inputs=model_cnn.input, outputs=pooling_layer.output)

def create_model(frame_len,num_classes):
    """Build and compile the 1-D CNN classifier.

    The network has seven Conv1D + MaxPool1D stages, then global max pooling,
    a 64-unit ReLU dense layer and a ``num_classes``-unit sigmoid output. It is
    compiled with the Adam optimizer and binary cross-entropy loss.

    :param frame_len: samples per input; inputs have shape ``(frame_len, 12)``.
    :param num_classes: width of the sigmoid output layer.
    :return: the compiled ``Sequential`` model.
    """
    # (filters, kernel_size, pool_size) for each convolutional stage, in order.
    conv_stages = [
        (64, 15, 2),
        (64, 15, 2),
        (64, 15, 2),
        (64, 9, 3),
        (64, 9, 3),
        (32, 9, 3),
        (32, 3, 4),
    ]

    stack = [Input(shape=(frame_len, 12))]
    for filters, kernel_size, pool_size in conv_stages:
        stack.append(Conv1D(filters, kernel_size, activation='relu'))
        stack.append(MaxPool1D(pool_size))
    stack.append(GlobalMaxPool1D())
    stack.append(Dense(64, kernel_initializer='normal', activation='relu'))
    stack.append(Dense(num_classes, activation='sigmoid', kernel_initializer='normal'))

    model = Sequential(stack)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model
