import numpy as np
from scipy.signal import resample
import neurokit2 as nk
from helper_code import *
from windowing import *

# Canonical lead names; every reduced lead set below is a subset of these.
twelve_leads = tuple('I II III aVR aVL aVF V1 V2 V3 V4 V5 V6'.split())
# The first six entries of the twelve-lead tuple.
six_leads = twelve_leads[:6]
# Leads I and II, with V2 appended (and lead III for the four-lead set).
two_leads = twelve_leads[:2]
three_leads = two_leads + ('V2',)
four_leads = twelve_leads[:3] + ('V2',)
# All supported lead configurations, largest first.
lead_sets = (twelve_leads, six_leads, four_leads, three_leads, two_leads)

def _filter_lead(signal, fs, lowcut=5, highcut=100):
    """Detrend one lead and band-pass filter it; cut-offs are in Hz.

    ``sampling_rate`` is passed explicitly because ``nk.signal_filter``
    otherwise assumes 1000 Hz, which silently rescales the cut-offs for
    recordings sampled at any other rate. A high cut at or above the Nyquist
    frequency is not a valid band edge, so in that case only the high-pass
    stage is applied.
    """
    signal = nk.signal_detrend(signal)
    if highcut is not None and highcut >= fs / 2:
        highcut = None
    return nk.signal_filter(signal, sampling_rate=fs, lowcut=lowcut, highcut=highcut)


def _hrv_features(lead, fs, n_fallback=42):
    """Return HRV time-domain + non-linear features computed from one filtered lead.

    The last non-linear column (sample entropy) is dropped. Non-finite values
    (NaN, +inf, -inf) are replaced by 0 so they cannot propagate into the
    model. If neurokit raises on this lead, a zero vector of length
    ``n_fallback`` is returned instead.
    NOTE(review): the fallback length (42) must match the feature count the
    installed neurokit version produces -- TODO confirm.
    """
    try:
        peaks, _info = nk.ecg_peaks(lead, sampling_rate=fs)
        hrv_time = np.asarray(nk.hrv_time(peaks, sampling_rate=fs, show=False))
        hrv_non_linear = np.asarray(nk.hrv_nonlinear(peaks, sampling_rate=fs, show=False))
        hrv_features = np.append(hrv_time, hrv_non_linear[:, :-1]).astype(float)
        return np.nan_to_num(hrv_features, nan=0.0, posinf=0.0, neginf=0.0)
    except Exception:
        # Rare; keep the pipeline running with a neutral feature vector.
        return np.zeros(n_fallback)


def window_sig(recording, header, window_len, overlap, get_features, resampled_fs = 100):
    """Optionally filter, then resample and window one multi-lead recording.

    Parameters
    ----------
    recording : array, shape (n_leads, n_samples)
        One row per lead. When ``get_features`` is true, each row is
        detrended and band-pass filtered IN PLACE (the caller's array is
        modified).
        NOTE(review): if ``recording`` has an integer dtype, the filtered
        values are truncated on assignment -- TODO confirm the loader's dtype.
    header
        Header text; ``get_frequency(header)`` supplies the sampling rate.
    window_len, overlap
        Passed unchanged to ``windowing2`` (units are whatever that helper
        expects -- presumably seconds, TODO confirm).
    get_features : bool
        If true, also filter the leads and compute HRV features from row 1
        (lead II in every lead set defined in this module).
    resampled_fs : int
        Target sampling rate in Hz applied before windowing.

    Returns
    -------
    ``(windowed, count)``, or ``(windowed, count, hrv_features)`` when
    ``get_features`` is true. ``windowed`` is ``windowing2``'s output with its
    last two axes swapped (intended layout: windows, window length, leads).
    ``count`` is the number of windows reported by ``windowing2``.
    """
    fs = int(get_frequency(header))
    n_leads, n_samples = recording.shape[0], recording.shape[1]

    if get_features:
        # Fallback when there is no second lead to take HRV features from.
        hrv_features = np.zeros(42)
        for i in range(n_leads):
            recording[i, :] = _filter_lead(recording[i, :], fs)
        if n_leads > 1:
            hrv_features = _hrv_features(recording[1, :], fs)

    # Exact integer arithmetic: (n_samples / fs) * resampled_fs in floating
    # point can land just below an integer, and int() would then drop a sample.
    n_resample = int(n_samples * resampled_fs // fs)

    # Resample along the time axis, then put time first: (n_resample, n_leads).
    sig_resample = resample(recording, n_resample, axis=1).T

    win, count = windowing2(sig_resample, window_len, overlap, resampled_fs)
    # Swap the last two axes into the layout the TensorFlow model consumes.
    windowed = win.transpose(0, 2, 1)

    if get_features:
        return windowed, count, hrv_features
    return windowed, count

def _make_label(header, classes):
    """Return a (1, len(classes)) multi-hot row marking the scored labels in ``header``.

    Labels from ``get_labels`` that are empty or whose ``int`` value is not
    in ``classes`` are ignored; the column index is the label's position in
    ``classes``.
    """
    label_row = np.zeros((1, len(classes)))
    for label in get_labels(header):
        if label != '' and int(label) in classes:
            label_row[:, classes.index(int(label))] = 1
    return label_row


def process_recording(header, recording, window_len, overlap, classes, train, leads):
    """Preprocess one recording and, when training, build its per-window labels.

    Parameters
    ----------
    header
        Header text, passed to ``get_labels`` and ``window_sig``.
    recording
        Signal array with one row per lead (see ``window_sig``).
    window_len, overlap
        Forwarded to ``window_sig``.
    classes : list
        Scored class codes compared against ``int(label)``; their order
        defines the label columns.
    train
        ``True`` for training, ``False`` for inference. Any other value
        returns ``('None', 'None')``.
    leads : tuple
        During training, HRV features are computed only when this equals
        ``twelve_leads``.

    Returns
    -------
    Training with ``twelve_leads``:
        ``(windows, labels, features)``, with one label row and one feature row
        per window, or ``('None', 'None', 'None')`` if no label is scored.
    Training with another lead set:
        ``(windows, labels)``, or ``('None', 'None')`` if no label is scored.
    Inference:
        ``(windows, features)``, where ``features`` has shape (1, n_features).
    The string ``'None'`` (not ``None``) is kept for compatibility with callers.

    NOTE(review): reduced-lead training calls ``window_sig`` with
    ``get_features=False``, which skips filtering, while inference always
    filters -- confirm the caller passes an already-filtered recording in
    that case.
    """
    if train == True:
        label_row = _make_label(header, classes)
        with_features = leads == twelve_leads

        if not np.any(label_row):
            # None of the recording's labels is scored.
            return ('None', 'None', 'None') if with_features else ('None', 'None')

        if with_features:
            windowed_signal, counts, hrv_features = window_sig(
                recording, header, window_len, overlap, get_features=True)
            # One copy of the label and the feature vector per window; also
            # well-defined (empty) when no windows were produced.
            labels = np.repeat(label_row, counts, axis=0)
            features = np.tile(hrv_features, (counts, 1))
            return windowed_signal, labels, features

        windowed_signal, counts = window_sig(
            recording, header, window_len, overlap, get_features=False)
        return windowed_signal, np.repeat(label_row, counts, axis=0)

    if train == False:
        windowed_signal, _counts, hrv_features = window_sig(
            recording, header, window_len, overlap, get_features=True)
        # Contiguous copy of the windows and a single feature row.
        return np.ascontiguousarray(windowed_signal), hrv_features[np.newaxis, :]

    return ('None', 'None')
