#!/usr/bin/env python

# Edit this script to add your team's code. Some functions are *required*, but you can edit most parts of the required functions,
# change or remove non-required functions, and add your own functions.

################################################################################
#
# Import libraries and functions. You can change or remove them.
#
################################################################################

from helper_code import *
import numpy as np, scipy as sp, scipy.stats, os, joblib
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from scipy.integrate import simps
from sklearn.neural_network import MLPClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE
from imblearn.over_sampling import SMOTENC
from imblearn.over_sampling import BorderlineSMOTE
from imblearn.over_sampling import SVMSMOTE
from imblearn.over_sampling import ADASYN
from imblearn.over_sampling import KMeansSMOTE
from imblearn.combine import SMOTEENN
from imblearn.combine import SMOTETomek
from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler
from sklearn.model_selection import KFold
from sklearn.model_selection import StratifiedKFold
from sklearn.feature_selection import SelectKBest, SelectPercentile
from sklearn.feature_selection import chi2, f_classif
from sklearn.svm import SVC
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import recall_score
from sklearn.model_selection import cross_val_score
from sklearn.metrics import confusion_matrix
from scipy.io.wavfile import write
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import classification_report
from  scipy.signal import filtfilt
from scipy import signal
from sklearn.feature_selection import SequentialFeatureSelector
from mlxtend.feature_selection import SequentialFeatureSelector as sfsx
# import matplotlib.pyplot as plt
import pywt

################################################################################
#
# Required functions. Edit these functions to add your code, but do not change the arguments.
#
################################################################################

# Train your model.
myflag = 1  # Debug flag; appears unused by active code in this module.


def _one_hot(value, classes):
    """Return a one-hot int vector marking `value` in `classes`.

    The vector is all zeros when `value` is not one of `classes`.
    """
    encoded = np.zeros(len(classes), dtype=int)
    if value in classes:
        encoded[classes.index(value)] = 1
    return encoded


def train_challenge_model(data_folder, model_folder, verbose):
    """Train the murmur and outcome classifiers and save them to `model_folder`.

    Features are extracted per patient with `get_features`, mean-imputed and
    standardised. Each task (murmur, outcome) uses its own fixed subset of
    feature columns and is evaluated with 10-fold stratified cross-validation;
    the classifier fitted on the best-scoring fold is the one that is saved.

    Args:
        data_folder: Folder containing the Challenge patient files.
        model_folder: Folder the trained model is written to (created if missing).
        verbose: Verbosity level; >= 1 prints progress and cross-validation
            summaries, >= 2 additionally prints per-patient progress.

    Raises:
        Exception: If `data_folder` contains no patient files.
    """
    # Find data files.
    if verbose >= 1:
        print('Finding data files...')

    # Find the patient data files.
    patient_files = find_patient_files(data_folder)
    num_patient_files = len(patient_files)

    if num_patient_files == 0:
        raise Exception('No data was provided.')

    # Create a folder for the model if it does not already exist.
    os.makedirs(model_folder, exist_ok=True)

    # Extract the features and labels.
    if verbose >= 1:
        print('Extracting features and labels from the Challenge data...')

    murmur_classes = ['Present', 'Unknown', 'Absent']
    num_murmur_classes = len(murmur_classes)
    outcome_classes = ['Abnormal', 'Normal']

    features = []
    murmurs = []
    outcomes = []

    for i, patient_file in enumerate(patient_files):
        if verbose >= 2:
            print('    {}/{}...'.format(i+1, num_patient_files))

        # Load the current patient data and recordings.
        current_patient_data = load_patient_data(patient_file)
        current_recordings = load_recordings(data_folder, current_patient_data)

        features.append(get_features(current_patient_data, current_recordings))

        # One-hot encode the labels.
        murmurs.append(_one_hot(get_murmur(current_patient_data), murmur_classes))
        outcomes.append(_one_hot(get_outcome(current_patient_data), outcome_classes))

    features = np.vstack(features)

    # Train the model.
    if verbose >= 1:
        print('Training model...')

    random_state = 4  # Random state; set for reproducibility.

    imputer = SimpleImputer().fit(features)
    features = imputer.transform(features)

    sc = StandardScaler().fit(features)
    all_features = sc.transform(features)

    # Integer class ids. NOTE: a label outside the class list one-hot encodes
    # to all zeros, which argmax maps to class 0.
    labels = np.vstack(murmurs).argmax(axis=1)
    outcomes = np.vstack(outcomes).argmax(axis=1)

    # Column subsets used by each classifier. These must stay identical to
    # the index lists used in run_challenge_model.
    murmur_columns = [48, 66, 72, 75, 77, 79, 83, 91, 93, 95, 122, 124, 126, 127, 130, 131, 132, 133, 135, 137,
                      140, 142, 143, 147, 148]
    outcome_columns = [5, 15, 16, 33, 35, 38, 43, 64, 93, 104, 114, 122, 123, 124, 125, 127, 129, 130, 131, 133,
                       135, 136, 137, 138, 140, 141, 142, 143, 144, 145, 146, 148, 150]

    k = 10
    kf = StratifiedKFold(n_splits=k, random_state=random_state, shuffle=True)
    mlp_params = dict(hidden_layer_sizes=(256, 128, 64, 32), max_iter=10000, activation='relu', solver='adam',
                      random_state=random_state)

    # ------------------------------------------------------------------ murmur
    features = all_features[:, murmur_columns]
    if verbose >= 1:
        print(all_features.shape)
        print(features.shape)

    murmur_label_ids = list(range(num_murmur_classes))
    class_weights = np.array([5, 3, 1])
    acc_score = np.zeros((4, num_murmur_classes))
    score_total = 0
    bas = 0
    best_score = -np.inf  # Guarantees a classifier is selected even if every fold scores 0.
    murmur_classifier = None

    for train_index, test_index in kf.split(features, labels):
        X_train, X_test = features[train_index, :], features[test_index, :]
        y_train, y_test = labels[train_index], labels[test_index]

        # Rebalance the training split only; the test split stays untouched.
        sm = SMOTEENN(random_state=random_state, sampling_strategy='all')
        X_train, y_train = sm.fit_resample(X_train, y_train)

        # A fresh estimator per fold: fit() returns the estimator itself, so
        # reusing one instance would silently replace the "best" model with
        # whatever the last fold produced.
        fold_classifier = MLPClassifier(**mlp_params).fit(X_train, y_train)
        pred_values = fold_classifier.predict(X_test)

        # Pin the label set so the matrix is always 3x3, even when a class is
        # absent from both the fold's truth and its predictions.
        cm = confusion_matrix(y_test, pred_values, labels=murmur_label_ids)
        # NOTE(review): rows of `cm` are true labels and columns are predicted
        # labels, so the denominator weights *predicted* counts - confirm this
        # is the intended weighting.
        score = (class_weights * np.diag(cm)).sum() / (class_weights * cm.sum(axis=0)).sum()
        score_total = score_total + score

        bas = bas + balanced_accuracy_score(y_test, pred_values)

        if score > best_score:
            murmur_classifier = fold_classifier
            best_score = score
            if verbose >= 1:
                print(cm)

        acc = precision_recall_fscore_support(y_test, pred_values, average=None, labels=murmur_label_ids)
        acc_score = np.add(acc_score, np.asarray(acc))

    if verbose >= 1:
        # Mean precision / recall / F-score (in percent), mean support, mean
        # weighted score and mean balanced accuracy over the folds.
        print(trunc(acc_score[0:3, :] * 100 / k, decs=2))
        print(acc_score[3, :] / k)
        print(score_total / k)
        print(bas / k)

    # ----------------------------------------------------------------- outcome
    features = all_features[:, outcome_columns]

    bas = 0
    best_score = -np.inf
    outcome_classifier = None

    for train_index, test_index in kf.split(features, outcomes):
        X_train, X_test = features[train_index, :], features[test_index, :]
        y_train, y_test = outcomes[train_index], outcomes[test_index]

        rus = RandomUnderSampler(random_state=random_state)
        X_train, y_train = rus.fit_resample(X_train, y_train)

        fold_classifier = MLPClassifier(**mlp_params).fit(X_train, y_train)
        pred_outcomes = fold_classifier.predict(X_test)

        score = balanced_accuracy_score(y_test, pred_outcomes)
        bas = bas + score

        if score > best_score:
            outcome_classifier = fold_classifier
            best_score = score

    if verbose >= 1:
        print(bas / k)
        print(best_score)

    # Save the model.
    save_challenge_model(model_folder, imputer, sc, murmur_classes, murmur_classifier, outcome_classes, outcome_classifier)

    if verbose >= 1:
        print('Done.')

# Load your trained model. This function is *required*. You should edit this function to add your code, but do *not* change the
# arguments of this function.
def load_challenge_model(model_folder, verbose):
    """Load and return the object stored in `model.sav` inside `model_folder`.

    `verbose` is accepted for interface compatibility and is not used.
    """
    # NOTE: joblib files are pickle-based; only load model files you trust.
    model_path = os.path.join(model_folder, 'model.sav')
    model = joblib.load(model_path)
    return model

# Run your trained model. This function is *required*. You should edit this function to add your code, but do *not* change the
# arguments of this function.
def run_challenge_model(model, data, recordings, verbose):
    """Predict murmur and outcome labels for one patient.

    Args:
        model: Dict with the keys 'imputer', 'sc', 'murmur_classes',
            'murmur_classifier', 'outcome_classes' and 'outcome_classifier'.
        data: Patient data passed through to `get_features`.
        recordings: Recordings passed through to `get_features`.
        verbose: Accepted for interface compatibility; not used.

    Returns:
        A tuple `(classes, labels, probabilities)`: the murmur classes followed
        by the outcome classes, a 0/1 vector with a single 1 per task at the
        most probable class, and the concatenated class probabilities.
    """
    # Feature columns fed to each classifier.
    murmur_columns = [48, 66, 72, 75, 77, 79, 83, 91, 93, 95, 122, 124, 126, 127, 130, 131, 132, 133, 135, 137,
                      140, 142, 143, 147, 148]
    outcome_columns = [5, 15, 16, 33, 35, 38, 43, 64, 93, 104, 114, 122, 123, 124, 125, 127, 129, 130, 131, 133,
                       135, 136, 137, 138, 140, 141, 142, 143, 144, 145, 146, 148, 150]

    # Unpack the stored model components.
    imputer, sc, murmur_classes, murmur_classifier, outcome_classes, outcome_classifier = (
        model[key] for key in ('imputer', 'sc', 'murmur_classes', 'murmur_classifier',
                               'outcome_classes', 'outcome_classifier'))

    # Extract features as a single-row matrix, then impute and standardise.
    feature_row = get_features(data, recordings).reshape(1, -1)
    prepared = sc.transform(imputer.transform(feature_row))

    # Class probabilities for the single sample.
    murmur_probabilities = murmur_classifier.predict_proba(prepared[:, murmur_columns])[0]
    outcome_probabilities = outcome_classifier.predict_proba(prepared[:, outcome_columns])[0]

    # Mark the most probable class of each task.
    murmur_labels = np.zeros(len(murmur_classes), dtype=np.int_)
    murmur_labels[np.argmax(murmur_probabilities)] = 1
    outcome_labels = np.zeros(len(outcome_classes), dtype=np.int_)
    outcome_labels[np.argmax(outcome_probabilities)] = 1

    # Concatenate classes, labels, and probabilities.
    classes = murmur_classes + outcome_classes
    labels = np.concatenate((murmur_labels, outcome_labels))
    probabilities = np.concatenate((murmur_probabilities, outcome_probabilities))

    return classes, labels, probabilities

################################################################################
#
# Optional functions. You can change or remove these functions and/or add new functions.
#
################################################################################

# Save your trained model.
def save_challenge_model(model_folder, imputer , sc, murmur_classes, murmur_classifier, outcome_classes, outcome_classifier):
    """Bundle the fitted components into a dict and dump it to `model.sav` in `model_folder`."""
    keys = ('imputer', 'sc', 'murmur_classes', 'murmur_classifier', 'outcome_classes', 'outcome_classifier')
    values = (imputer, sc, murmur_classes, murmur_classifier, outcome_classes, outcome_classifier)
    bundle = dict(zip(keys, values))

    joblib.dump(bundle, os.path.join(model_folder, 'model.sav'), protocol=0)

# Extract features from the data.
def get_features(data, recordings):
    """Build the feature vector for one patient.

    Layout of the returned vector:
        [age, sex (2, one-hot), height, weight, is_pregnant,
         29 features for each of the locations AV, MV, PV, TV, PhC]

    The per-location block is all zeros when no non-empty recording exists for
    that location, or when the number of recordings does not match the number
    of locations. If a location has several recordings, the last one wins.

    NOTE: slot 25 of each location block now holds var(cD3); it previously
    duplicated var(cD4). Models trained before this fix must be retrained.

    Args:
        data: Patient data understood by the `get_*` helper functions.
        recordings: Sequence of recordings, parallel to the patient's
            locations. It is not modified.

    Returns:
        1-D float32 numpy array of features.
    """
    # Extract the age group and replace with the (approximate) number of months for the middle of the age group.
    age_group = get_age(data)
    age = float('nan')
    for group_name, months in (('Neonate', 0.5), ('Infant', 6), ('Child', 6 * 12),
                               ('Adolescent', 15 * 12), ('Young Adult', 20 * 12)):
        if compare_strings(age_group, group_name):
            age = months
            break

    # Extract sex. Use one-hot encoding.
    sex = get_sex(data)
    sex_features = np.zeros(2, dtype=int)
    if compare_strings(sex, 'Female'):
        sex_features[0] = 1
    elif compare_strings(sex, 'Male'):
        sex_features[1] = 1

    # Extract height and weight.
    height = get_height(data)
    weight = get_weight(data)

    # Extract pregnancy status.
    is_pregnant = get_pregnancy_status(data)

    # Extract per-location recording features.
    locations = get_locations(data)
    recording_locations = ['AV', 'MV', 'PV', 'TV', 'PhC']
    recording_features = np.zeros((len(recording_locations), _RECORDING_FEATURE_COUNT), dtype=float)

    if len(locations) == len(recordings):
        for location, recording in zip(locations, recordings):
            if np.size(recording) == 0:
                continue
            for j, known_location in enumerate(recording_locations):
                if compare_strings(location, known_location):
                    recording_features[j, :] = _recording_feature_row(recording)

    recording_features = recording_features.flatten()

    features = np.hstack(([age], sex_features, [height], [weight], [is_pregnant], recording_features))

    return np.asarray(features, dtype=np.float32)


# Number of features computed per recording location.
_RECORDING_FEATURE_COUNT = 29


def _write_summary_stats(row, start, values):
    """Store mean, variance, skewness and kurtosis of `values` in row[start:start + 4]."""
    row[start] = np.mean(values)
    row[start + 1] = np.var(values)
    row[start + 2] = sp.stats.skew(values)
    row[start + 3] = sp.stats.kurtosis(values)


def _recording_feature_row(recording):
    """Compute the 29 features of a single recording.

    Slot layout:
        0      presence flag (always 1)
        1-4    mean/var/skew/kurtosis of the band-passed signal
        5-8    same statistics of the FFT magnitude
        9-12   same statistics of the Welch power spectral density
        13-15  relative power in 20-130, 130-400 and 400-500 Hz
        16-27  (mean, var) pairs of db7 wavelet coefficients, in the order
               cD4, cD5, cD1, cD2, cD3, cA5
        28     thd() of the first half of the FFT magnitude (DC bin excluded)
    """
    # Work on a filtered copy so the caller's recording is left untouched.
    filtered = bandPass(recording, 20, 500, 5)

    row = np.zeros(_RECORDING_FEATURE_COUNT, dtype=float)
    row[0] = 1
    _write_summary_stats(row, 1, filtered)

    spectrum = np.abs(np.fft.fft(filtered))
    _write_summary_stats(row, 5, spectrum)

    # 4000 is passed as the sampling frequency; with nperseg=4000 that gives
    # 1 Hz bins when the recording is at least that long.
    freqs, Pxx_den = signal.welch(filtered, 4000, nperseg=4000)
    _write_summary_stats(row, 9, Pxx_den)

    row[13] = RP(freqs, Pxx_den, 20, 130)
    row[14] = RP(freqs, Pxx_den, 130, 400)
    row[15] = RP(freqs, Pxx_den, 400, 500)

    db7 = pywt.Wavelet('db7')
    cA5, cD5, cD4, cD3, cD2, cD1 = pywt.wavedec(filtered, db7, level=5)
    for slot, coefficients in zip(range(16, 28, 2), (cD4, cD5, cD1, cD2, cD3, cA5)):
        row[slot] = np.mean(coefficients)
        row[slot + 1] = np.var(coefficients)

    row[28] = thd(spectrum[1:int(len(spectrum) / 2)])

    return row

def bandPass(signal, lowcut, highcut, order, fs=4000.0):
    """Apply a zero-phase Butterworth band-pass filter along axis 0.

    Args:
        signal: Samples to filter (array-like).
        lowcut: Lower cut-off frequency, in the same unit as `fs`.
        highcut: Upper cut-off frequency, in the same unit as `fs`.
        order: Order passed to `scipy.signal.butter`.
        fs: Sampling frequency of `signal`. Defaults to 4000.0, the value that
            was previously hard-coded.

    Returns:
        The filtered signal, with the same shape as the input.

    Raises:
        ValueError: Raised by SciPy when the normalised cut-offs do not lie
            strictly between 0 and 1 (i.e. outside 0 < f < fs / 2).
    """
    # Cut-offs are expressed as fractions of the Nyquist frequency.
    nyquist = 0.5 * fs
    band = [lowcut / nyquist, highcut / nyquist]

    b, a = scipy.signal.butter(order, band, 'bandpass', analog=False)
    # filtfilt runs the filter forwards and backwards, cancelling phase shift.
    return scipy.signal.filtfilt(b, a, signal, axis=0)

def trunc(values, decs=0):
    """Truncate `values` toward zero, keeping `decs` decimal places."""
    scale = 10**decs
    shifted = values * scale
    return np.trunc(shifted) / scale

def RP(freqs, Pxx_den, low, high):
    """Return the fraction of total spectral power inside [low, high].

    Both the band power and the total power are integrated with Simpson's
    rule, using the spacing of the first two entries of `freqs` as the step.

    Args:
        freqs: Evenly spaced frequency bins (at least two).
        Pxx_den: Power spectral density at each bin of `freqs`.
        low: Lower band edge (inclusive).
        high: Upper band edge (inclusive).

    Returns:
        band power / total power. 0.0 when the band holds fewer than two bins
        (nothing to integrate over); NaN when the total power is zero.
    """
    # `simps` is a deprecated alias that recent SciPy releases removed;
    # `simpson` is the supported name. Fall back only on very old SciPy.
    try:
        from scipy.integrate import simpson
    except ImportError:
        from scipy.integrate import simps as simpson

    freqs = np.asarray(freqs)
    Pxx_den = np.asarray(Pxx_den)
    in_band = np.logical_and(freqs >= low, freqs <= high)

    freq_res = freqs[1] - freqs[0]

    total_power = simpson(Pxx_den, dx=freq_res)
    if total_power == 0:
        # Relative power is undefined for an all-zero spectrum.
        return float('nan')

    if np.count_nonzero(in_band) < 2:
        return 0.0

    band_power = simpson(Pxx_den[in_band], dx=freq_res)
    return band_power / total_power

def thd(abs_data):
    """Return a total-harmonic-distortion style ratio, in percent.

    The largest value of `abs_data` is treated as the fundamental; the result
    is 100 * sqrt(sum of squares of all other values) / fundamental.

    Args:
        abs_data: 1-D sequence of spectral magnitudes.

    Returns:
        The ratio as a float, or NaN when `abs_data` is empty or its largest
        value is zero (the ratio is undefined in both cases).
    """
    magnitudes = np.asarray(abs_data, dtype=float)
    if magnitudes.size == 0:
        return float('nan')

    peak = magnitudes.max()
    if peak == 0:
        return float('nan')

    # Energy of everything except the peak; clamp at zero so that rounding in
    # the summation can never put a negative number under the square root.
    harmonic_energy = max(np.sum(magnitudes ** 2) - peak ** 2, 0.0)
    return 100 * np.sqrt(harmonic_energy) / peak


# def audio_fft(signal, title, sr):
#     ft = np.fft.fft(signal)
#     amplitude_fft = np.abs(ft)
#
#     frequency = np.linspace(0, sr, len(amplitude_fft))
#     bins = int(len(frequency)/2)
#
#     plt.plot(frequency[:bins], amplitude_fft[:bins])
#     plt.xlabel("frequency (Hz)")
#     plt.title(title)
#
#     plt.show()

