#!/usr/bin/env python

import numpy as np, os, sys, joblib
from scipy.io import loadmat
from scipy.signal import butter, lfilter,resample_poly
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
import gc
from pathlib import Path
import math
from sklearn.svm import SVC
import tensorflow as tf
from tensorflow.keras.layers import Conv1D, MaxPool1D, Flatten, Dense, Input, BatchNormalization,GlobalMaxPool1D
from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.callbacks import ModelCheckpoint
from wfdb import processing
from functools import partial
import pickle


def train_12ECG_classifier(input_directory, output_directory):
    """Train the CNN feature extractor and one binary SVC per scored class.

    Steps:
      1. Select up to ``max_cases`` recordings per scored class from
         ``input_directory`` (``readFilesWithLabels``) and load them
         (``readData``).
      2. One-hot encode the class indices over the 27 scored classes and
         merge the equivalent classes (``equvialentClassesConversion``),
         leaving 24 output columns.
      3. Train the CNN (``create_model``) on an 80/20 stratified split,
         keeping the best ``val_loss`` weights in
         ``<output_directory>/cnn_model.h5``.
      4. Use the output of the CNN's GlobalMaxPool1D layer as features for
         one binary SVC per class (``createBinarySVC``).

    Args:
        input_directory: directory holding the ``.mat``/``.hea`` recordings.
        output_directory: directory the model files are written to; it is
            created if it does not exist.
    """
    # Define Model Parameters
    bs = 12  # Batch size
    ep = 50  # Epochs
    frame_len = 15000  # Frame length in samples (readData uses the same value)
    max_cases = 2000  # Max cases for each class for CNN
    max_cases_svc = 2000  # Max cases for each class for SVC
    os.makedirs(output_directory, exist_ok=True)
    #__________________________________________________________________
    # All scored Classes
    #__________________________________________________________________
    scored_classes={270492004:"IAVB",
                    164889003:"AF",
                    164890007:"AFL",
                    426627000:"Brady",
                    713427006:"CRBBB",
                    713426002:"IRBBB",
                    445118002:"LAnFB",
                    39732003:"LAD",
                    164909002:"LBBB",
                    251146004:"LQRSV",
                    698252002:"NSIVCB",
                    10370003:"PR",
                    284470004:"PAC",
                    427172004:"PVC",
                    164947007:"LPR",
                    111975006:"LQT",
                    164917005:"QAb",
                    47665007:"RAD",
                    59118001:"RBBB",
                    427393009:"SA",
                    426177001:"SB",
                    426783006:"SNR",
                    427084000:"STach",
                    63593006:"SVPB",
                    164934002:"TAb",
                    59931005:"TInv",
                    17338001:"VPB"
                    }
    classes = sorted(scored_classes)
    # Maximum number of cases to select for each class (consumed as a quota)
    classes_cases = [max_cases] * len(classes)

    # Read all files with labels according to the quotas in classes_cases
    input_files, labels = readFilesWithLabels(input_directory=input_directory, classes=classes, classes_cases=classes_cases)

    # Read data from files
    X, y = readData(input_files, labels, classes, scored_classes)

    # One-hot encode the class indices over all scored classes (27 columns).
    # Row k of the identity matrix is the one-hot vector for index k; this
    # matches OneHotEncoder(categories=[range(n)], sparse=False) without
    # depending on the scikit-learn version (``sparse`` was renamed/removed).
    y = np.eye(len(classes))[y]

    # Convert equivalent classes labels (27 -> 24 columns)
    scored_classes, classes, y = equvialentClassesConversion(scored_classes, classes, y)

    # Split data into training and testing
    X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.2, random_state=42)
    del X, y, input_files

    # Define CNN model architecture
    model_cnn = create_model(frame_len, len(classes))
    model_path_cnn = os.path.join(output_directory, 'cnn_model.h5')
    checkpoint_cnn = ModelCheckpoint(model_path_cnn, monitor='val_loss', verbose=1, save_best_only=True, mode='min')

    print('------------------------------------------------------------------------')
    print('Training CNN Model...')
    # Best-effort resume from an earlier checkpoint. An unreadable or
    # incompatible file is reported and training starts from scratch; unlike
    # a bare ``except``, this does not swallow KeyboardInterrupt/SystemExit.
    if os.path.isfile(model_path_cnn):
        try:
            model_cnn.load_weights(model_path_cnn)
        except Exception as err:
            print('Could not load existing weights ({}); training from scratch.'.format(err))

    model_cnn.fit(X_train, y_train, batch_size=bs, epochs=ep,
                  validation_data=(X_test, y_test),
                  callbacks=[checkpoint_cnn])  # starts training

    # Free the CNN training data before the SVC stage loads its own data.
    del X_train, y_train, X_test, y_test

    # Load best epoch weights.
    model_cnn.load_weights(model_path_cnn)

    print('------------------------------------------------------------------------')
    print('Define Feature Extractor and create Binary SVCs')
    # Locate the pooling layer by type: auto-generated layer names get a
    # numeric suffix when several models are built in one session.
    pool_layer = next(layer for layer in model_cnn.layers if isinstance(layer, GlobalMaxPool1D))
    model_feat = Model(inputs=model_cnn.input, outputs=pool_layer.output)
    createBinarySVC(classes, max_cases_svc, model_feat, input_directory, bs, output_directory)
    print('-------------------DONE---------------------------------')
    
    
    



# Find unique classes.
def get_classes(input_directory, filenames):
    """Collect the diagnosis codes listed on ``#Dx`` lines of header files.

    Args:
        input_directory: accepted for interface compatibility; not used
            (each entry of ``filenames`` is opened as given).
        filenames: iterable of header file paths.

    Returns:
        Sorted list of the unique, whitespace-stripped code strings.
    """
    found = set()
    for header_path in filenames:
        with open(header_path, 'r') as handle:
            dx_lines = (line for line in handle if line.startswith('#Dx'))
            for line in dx_lines:
                codes = line.split(': ')[1].split(',')
                found.update(code.strip() for code in codes)
    return sorted(found)

def bandpass_filter(data, lowcut = 0.001, highcut = 15.0, signal_freq = 500, filter_order = 1):
    """Apply a Butterworth band-pass filter to ``data``.

    The filter is designed with ``scipy.signal.butter`` using cut-offs
    normalised by the Nyquist frequency (``signal_freq / 2``) and applied
    once, forward only, with ``scipy.signal.lfilter`` (so it is not
    zero-phase).

    Args:
        data: array-like signal; lfilter filters along its last axis.
        lowcut: lower cut-off frequency in Hz.
        highcut: upper cut-off frequency in Hz.
        signal_freq: sampling frequency in Hz.
        filter_order: Butterworth filter order.

    Returns:
        The filtered signal as a numpy array.
    """
    half_rate = 0.5 * signal_freq
    band_edges = [lowcut / half_rate, highcut / half_rate]
    numerator, denominator = butter(filter_order, band_edges, btype="band")
    return lfilter(numerator, denominator, data)
    
    
def load_challenge_data(filename):
    """Load one recording's signal matrix and the lines of its header file.

    Args:
        filename: path to a ``.mat`` file containing a ``'val'`` variable.
            The header is expected next to it with the same stem and a
            ``.hea`` extension.

    Returns:
        (data, header_data): ``data`` is ``val`` converted to a float64
        ndarray; ``header_data`` is the list of header lines (with newlines).
    """
    x = loadmat(filename)
    data = np.asarray(x['val'], dtype=np.float64)

    # Swap only the final extension. str.replace('.mat', '.hea') would also
    # rewrite a '.mat' inside a directory name, and it misses '.MAT'.
    input_header_file = os.path.splitext(filename)[0] + '.hea'

    with open(input_header_file, 'r') as f:
        header_data = f.readlines()

    return data, header_data

# Find unique true labels
def get_true_labels(input_file,classes,classes_cases):
    """Parse a header file's ``#Dx`` codes into a quota-limited label vector.

    Args:
        input_file: path of the header file.
        classes: list of integer codes; position ``k`` is label index ``k``.
        classes_cases: remaining quota per class, same length as ``classes``.
            MUTATED IN PLACE: each accepted label decrements its entry.

    Returns:
        (found, record_name, classes, label_vector):
        ``found`` is True if any listed code is in ``classes``, even when
        its quota is already exhausted. ``record_name`` is the first
        space-separated token of the header's first line. ``classes`` is
        the same list object that was passed in. ``label_vector`` is an int
        array with 1 where a code was accepted within its quota.
    """
    label_vector = np.zeros(len(classes), dtype=int)
    found = False
    with open(input_file, 'r') as header:
        record_name = header.readline().split(' ')[0]
        for line in header:
            if not line.startswith('#Dx'):
                continue
            for token in line.split(': ')[1].split(','):
                code = int(token.strip())
                if code not in classes:
                    continue
                found = True
                position = classes.index(code)
                if classes_cases[position] > 0:
                    classes_cases[position] -= 1
                    label_vector[position] = 1
    return found, record_name, classes, label_vector




def extend_ts(ts, length):
    """Return a float64 array of exactly ``length`` samples built from ``ts``.

    Longer inputs are truncated and shorter ones are zero-padded at the end.
    """
    keep = min(length, ts.shape[0])
    padded = np.zeros(length)
    padded[:keep] = ts[:keep]
    return padded




def readData(input_files,labels,classes,scored_classes):
    """Load, preprocess and stack the recordings listed in ``input_files``.

    Each recording is:
      1. resampled to 500 Hz with ``resample_poly`` if its header's first
         line reports another sampling frequency (field index 2);
      2. truncated to 15000 samples;
      3. per lead, scaled with ``processing.normalize_bound(lb=-1, ub=1)``
         unless the lead is all zeros;
      4. zero-padded to 15000 samples (``extend_ts``) and passed through
         ``bandpass_filter``.

    Args:
        input_files: list of ``.mat`` paths. The same path may appear more
            than once (``readFilesWithLabels`` lists a file once per label).
        labels: label for each entry of ``input_files`` (indexed in parallel).
        classes, scored_classes: unused; kept for interface compatibility.

    Returns:
        (X, y): ``X`` is a float32 array of shape (n, 15000, 12) and ``y``
        is an int array of shape (n,). Empty input yields shapes
        (0, 15000, 12) and (0,).
    """
    fs = 500
    num_leads = 12
    frame_len = 15000
    num_files = len(input_files)
    normalize = partial(processing.normalize_bound, lb=-1, ub=1)

    X = []
    y = []
    for i, mat_file in enumerate(input_files):
        print('    {}/{}...'.format(i+1, num_files))
        data, header_data = load_challenge_data(mat_file)

        # Sampling frequency is field 2 of the header's first line, which
        # load_challenge_data has already read; no need to open it again.
        sampled_fs = int(header_data[0].split(' ')[2])
        if sampled_fs != fs:
            data = resample_poly(data, fs, sampled_fs, axis=1)

        # If the ECG signal is longer than the frame length, truncate it.
        if data.shape[1] > frame_len:
            data = data[:, :frame_len]

        extended_data = np.zeros((num_leads, frame_len))
        for j in range(num_leads):
            # Normalise only leads that are not entirely zero.
            if data[j, :].any():
                data[j, :] = np.squeeze(np.apply_along_axis(normalize, 0, data[j, :]))
            # Zero-pad to frame_len, then band-pass filter.
            extended_data[j, :] = bandpass_filter(extend_ts(data[j, :], length=frame_len))

        # Store as float32 right away, which halves the memory held by the
        # list. The values are identical to casting the stacked array later.
        X.append(extended_data.T.astype(np.float32))
        y.append(labels[i])

    # The reshape keeps the (n, frame_len, num_leads) shape for empty input.
    X = np.asarray(X, dtype=np.float32).reshape(-1, frame_len, num_leads)
    # Builtin int: the np.int alias was removed in NumPy 1.24.
    y = np.asarray(y, dtype=int)
    return X, y



def readFilesWithLabels(input_directory,classes,classes_cases):
    """Select ``.mat`` recordings carrying scored labels, honouring per-class quotas.

    Each non-hidden ``*.mat`` file in ``input_directory`` is examined in
    sorted filename order. Its header (same stem, ``.hea`` extension) is
    parsed with ``get_true_labels``, which consumes ``classes_cases`` in
    place. Recordings with no code from ``classes`` are skipped.

    Args:
        input_directory: directory to scan (not recursive).
        classes: list of integer codes; a label is an index into it.
        classes_cases: remaining quota per class. MUTATED IN PLACE.

    Returns:
        (input_files, labels): parallel lists. A recording with k accepted
        labels appears k times, once per label index.
    """
    input_files = []
    labels = []

    # Sorted so the quota-limited selection is reproducible; os.listdir
    # returns entries in an arbitrary, filesystem-dependent order.
    for name in sorted(os.listdir(input_directory)):
        current_file = os.path.join(input_directory, name)
        if name.startswith('.') or not name.lower().endswith('.mat') or not os.path.isfile(current_file):
            continue

        # Header file shares the stem; splitext also handles '.MAT' and
        # directory names containing '.mat'.
        header_file = os.path.splitext(current_file)[0] + '.hea'
        scored_classes_flag, _, _, label_vector = get_true_labels(header_file, classes, classes_cases)

        # Skip files where no scored class label is found
        if not scored_classes_flag:
            continue

        # One entry per accepted label index
        for class_index in np.flatnonzero(label_vector == 1):
            input_files.append(current_file)
            labels.append(class_index)

    return input_files, labels
    






def create_model(frame_len,num_classes,num_leads=12):
    """Build and compile the 1-D CNN used as classifier and feature extractor.

    The network has seven Conv1D (ReLU) + MaxPool1D stages, then a
    GlobalMaxPool1D layer, a 64-unit ReLU Dense layer and a sigmoid output
    with one unit per class. ``train_12ECG_classifier`` reuses the
    GlobalMaxPool1D output as SVC features. The model is compiled with Adam,
    binary cross-entropy and accuracy.

    Args:
        frame_len: number of samples per input recording.
        num_classes: number of sigmoid output units.
        num_leads: number of input channels (default 12, as before).

    Returns:
        A compiled ``Sequential`` model with input shape (frame_len, num_leads).
    """
    model = Sequential([
        Input(shape=(frame_len, num_leads)),
        Conv1D(64, 15, activation='relu'),
        MaxPool1D(2),
        Conv1D(64, 15, activation='relu'),
        MaxPool1D(2),
        Conv1D(64, 15, activation='relu'),
        MaxPool1D(2),
        Conv1D(64, 9, activation='relu'),
        MaxPool1D(3),
        Conv1D(64, 9, activation='relu'),
        MaxPool1D(3),
        Conv1D(32, 9, activation='relu'),
        MaxPool1D(3),
        Conv1D(32, 3, activation='relu'),
        MaxPool1D(4),
        # Explicit name: the feature extractor looks this layer up by name.
        # Auto-generated names gain a suffix ('global_max_pooling1d_1', ...)
        # when more than one model is built in the same session.
        GlobalMaxPool1D(name='global_max_pooling1d'),
        Dense(64, kernel_initializer='normal', activation='relu'),
        Dense(num_classes, activation='sigmoid', kernel_initializer='normal')])
    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    return model



def readDataWithoutLabels(input_files):
    """Load and preprocess recordings without labels.

    Uses the same pipeline as ``readData``:
      1. resample to 500 Hz when the header reports another rate;
      2. truncate to 15000 samples;
      3. scale each non-zero lead with
         ``processing.normalize_bound(lb=-1, ub=1)``;
      4. zero-pad to 15000 samples and ``bandpass_filter`` each lead.

    Args:
        input_files: list of ``.mat`` paths.

    Returns:
        A list (not a stacked array) holding one float64 array of shape
        (15000, 12) per input file, in input order.
    """
    fs = 500
    num_leads = 12
    frame_len = 15000
    normalize = partial(processing.normalize_bound, lb=-1, ub=1)

    X = []
    for mat_file in input_files:
        data, header_data = load_challenge_data(mat_file)

        # Sampling frequency is field 2 of the header's first line, which
        # load_challenge_data has already read; no need to open it again.
        sampled_fs = int(header_data[0].split(' ')[2])
        if sampled_fs != fs:
            data = resample_poly(data, fs, sampled_fs, axis=1)

        # If the ECG signal is longer than the frame length, truncate it.
        if data.shape[1] > frame_len:
            data = data[:, :frame_len]

        extended_data = np.zeros((num_leads, frame_len))
        for j in range(num_leads):
            # Normalise only leads that are not entirely zero.
            if data[j, :].any():
                data[j, :] = np.squeeze(np.apply_along_axis(normalize, 0, data[j, :]))
            # Zero-pad to frame_len, then band-pass filter.
            extended_data[j, :] = bandpass_filter(extend_ts(data[j, :], length=frame_len))

        X.append(extended_data.T)

    return X



def equvialentClassesConversion(scored_classes,classes,labels):
    """Merge each group of equivalent diagnosis codes into one label column.

    For every group in which at least two codes occur in ``classes``, the
    first present code is the representative. Its column of ``labels``
    becomes the logical OR of the group's columns, and the other codes are
    dropped.

    Args:
        scored_classes: dict keyed by code. Dropped codes are deleted IN PLACE.
        classes: list of codes aligned with the columns of ``labels``.
            Dropped codes are removed IN PLACE.
        labels: 2-D array (rows x len(classes)). Representative columns are
            overwritten in place before the dropped columns are removed.

    Returns:
        (scored_classes, classes, labels) with ``labels`` as a new array
        that no longer contains the dropped columns.
    """
    groups = (
        [713427006, 59118001],
        [284470004, 63593006],
        [427172004, 17338001],
    )
    dropped_codes = []
    dropped_columns = []
    for group in groups:
        present = [code for code in group if code in classes]
        if len(present) < 2:
            continue
        # Column positions are taken before any removal, so they stay valid
        # for the single np.delete at the end.
        columns = [classes.index(code) for code in present]
        keep_column, *merge_columns = columns
        labels[:, keep_column] = np.any(labels[:, columns], axis=1)
        dropped_codes.extend(present[1:])
        dropped_columns.extend(merge_columns)

    for code in dropped_codes:
        classes.remove(code)
        del scored_classes[code]
    labels = np.delete(labels, dropped_columns, axis=1)

    return scored_classes, classes, labels

def createBinarySVC(classes,max_cases_svc,model_feat,input_directory,bs,output_directory):
    """Train and save one binary (one-vs-rest) linear SVC per class on CNN features.

    For each class in ``classes``:
      * Positives are recordings labelled with the class, or with one of its
        equivalent codes, capped at ``max_cases_svc`` per code. Each
        recording is used once.
      * Negatives are drawn from the remaining classes with a per-class
        quota of ceil(n_positives / (len(classes) - 1)). Recordings that
        also carry the current class or an equivalent code are excluded, so
        no recording is labelled both 1 and 0.
      * The recordings are embedded with ``model_feat``. An
        ``SVC(kernel='linear', probability=True)`` is fitted on a stratified
        80 % split, and the accuracy on the remaining 20 % is printed.
      * The model is pickled to ``<output_directory>/svc_<class>.sav``.

    Args:
        classes: list of integer class codes (after equivalence merging).
        max_cases_svc: maximum positives selected per code.
        model_feat: Keras model mapping recordings to feature vectors.
        input_directory: directory holding the ``.mat``/``.hea`` recordings.
        bs: batch size for ``model_feat.predict``.
        output_directory: directory for the ``.sav`` files (created if missing).
    """
    # Codes merged into each representative by equvialentClassesConversion.
    equivalent_codes = {
        713427006: [713427006, 59118001],
        284470004: [284470004, 63593006],
        427172004: [427172004, 17338001],
    }
    os.makedirs(output_directory, exist_ok=True)
    n_other = len(classes) - 1

    for current_class in classes:
        current_classes = equivalent_codes.get(current_class, [current_class])

        # Positive recordings. A file carrying two equivalent codes is listed
        # once per code by readFilesWithLabels, so keep only one copy.
        positive_files, _ = readFilesWithLabels(input_directory=input_directory, classes=current_classes, classes_cases=[max_cases_svc] * len(current_classes))
        positive_files = list(dict.fromkeys(positive_files))

        # Negative recordings, balanced across the remaining classes.
        remainder_classes = [c for c in classes if c != current_class]
        remainder_cases = [math.ceil(len(positive_files) / n_other)] * n_other
        remainder_files, _ = readFilesWithLabels(input_directory=input_directory, classes=remainder_classes, classes_cases=remainder_cases)
        negative_files = [f for f in dict.fromkeys(remainder_files)
                          if not _header_has_any_code(f, current_classes)]

        # Read data from files and build the binary training set
        X_pos = readDataWithoutLabels(positive_files)
        X_neg = readDataWithoutLabels(negative_files)
        X = np.asarray(X_pos + X_neg, dtype=np.float32)
        # Builtin int: the np.int alias was removed in NumPy 1.24.
        y = np.asarray([1] * len(X_pos) + [0] * len(X_neg), dtype=int)
        del X_pos, X_neg

        # Split data into training and testing
        X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.2, random_state=42)
        del X

        # Get features from the feature extractor
        feat_train = model_feat.predict(X_train, batch_size=bs)
        feat_test = model_feat.predict(X_test, batch_size=bs)

        # Train SVM (random_state makes the probability calibration reproducible)
        svm = SVC(kernel='linear', probability=True, random_state=42)
        svm.fit(feat_train, y_train)
        print('SVC {}: held-out accuracy {:.3f}'.format(current_class, svm.score(feat_test, y_test)))

        # Save the model to disk; the with-block closes the file handle.
        filename = os.path.join(output_directory, 'svc_' + str(current_class) + '.sav')
        with open(filename, 'wb') as fh:
            pickle.dump(svm, fh)


def _header_has_any_code(mat_file, codes):
    """Return True if the header next to ``mat_file`` lists any of ``codes`` in a ``#Dx`` line."""
    header_file = os.path.splitext(mat_file)[0] + '.hea'
    # Zero quotas: get_true_labels still reports whether a code is present
    # but accepts no labels and does not change the quota list.
    has_code, _, _, _ = get_true_labels(header_file, codes, [0] * len(codes))
    return has_code
