#!/usr/bin/env python

from sklearn.ensemble import ExtraTreesClassifier
import numpy as np, os, sys, joblib
from scipy.io import loadmat
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestClassifier
import xgboost as xgb
from sklearn.preprocessing import OneHotEncoder, MultiLabelBinarizer
from imblearn.over_sampling import SMOTE
from sklearn.multioutput import MultiOutputClassifier
import constants
import get_12ECG_features
from sklearn import metrics
from collections import Counter
import pandas as pd

def _find_header_files(input_directory):
    """Return the sorted paths of visible ``.hea`` files directly inside *input_directory*."""
    header_files = []
    # Sorted so the order in which records reach feature extraction does not
    # depend on the filesystem's (arbitrary) listing order.
    for name in sorted(os.listdir(input_directory)):
        path = os.path.join(input_directory, name)
        # Require the real '.hea' extension (not merely a 'hea' suffix) so a
        # matching '.mat' path can be derived from it.
        if not name.startswith('.') and name.lower().endswith('.hea') and os.path.isfile(path):
            header_files.append(path)
    return header_files


def _load_training_records(header_files):
    """Load the records that are eligible for training.

    Records whose patient ID is listed in ``constants.OMIT``, or whose header
    does not describe exactly 12 leads, are skipped. Records without any target
    class are deliberately kept (they contribute all-zero label rows). For each
    kept record only signal rows 1 and 6 are retained (presumably leads II and
    V1 under the standard 12-lead ordering — TODO confirm against the data).

    Returns:
        (recordings, headers): parallel lists of 2-row signal arrays and the
        corresponding raw header line lists.
    """
    recordings = []
    headers = []
    for header_file in header_files:
        recording, header = load_challenge_data(header_file)
        _, ptID, _, _, _, lead_info, _ = get_12ECG_features.parse_hea_file(header)
        if ptID in constants.OMIT or len(lead_info) != 12:
            continue
        recordings.append(recording[[1, 6], :])
        headers.append(header)
    return recordings, headers


def _keep_frequent_label_combinations(features, labels, min_count):
    """Keep only rows whose exact label vector occurs at least *min_count* times.

    Rows are grouped by their whole label vector, so rare label combinations
    are dropped as a unit. Surviving rows keep their original order.

    Returns:
        (features, labels) restricted to the surviving rows.
    """
    labels = np.asarray(labels)
    if labels.shape[0] == 0:
        return features, labels
    _, inverse, counts = np.unique(labels, axis=0, return_inverse=True, return_counts=True)
    # reshape(-1): some NumPy 2.0.x releases return `inverse` with an extra
    # axis when `axis` is given.
    keep = counts[inverse.reshape(-1)] >= min_count
    return features[keep], labels[keep]


def train_12ECG_classifier(input_directory, output_directory, min_combination_count=50):
    """Train the ECG classifier on a directory of records and save it.

    Steps:

    1. Load every visible ``.hea`` record together with its ``.mat`` signal.
    2. Drop omitted patients and non-12-lead records.
    3. Extract engineered features from two signal rows.
    4. Discard rows whose exact label combination is rarer than
       *min_combination_count*.
    5. Fit an ExtraTreesClassifier on the remaining (2-D, multi-output)
       labels.
    6. Save ``{'model': classifier}`` with joblib.

    Args:
        input_directory: Directory containing ``.hea``/``.mat`` record pairs.
        output_directory: Directory that receives ``xgb_classifier.sav``;
            created if it does not exist.
        min_combination_count: Minimum number of records that must share an
            identical label vector for those records to be used in training
            (default 50, the previously hard-coded threshold).

    Raises:
        ValueError: If no label combination reaches *min_combination_count*.
    """
    print('Loading data...')
    header_files = _find_header_files(input_directory)
    recordings, headers = _load_training_records(header_files)

    features, labels = get_12ECG_features.GenerateEngFeature(recordings, headers, 0, constants.NUM_CORES)
    # Release the raw signals; only the engineered features are needed from here on.
    del recordings, headers

    features, labels = _keep_frequent_label_combinations(features, labels, min_combination_count)
    if len(labels) == 0:
        raise ValueError(
            'No label combination occurs at least %d times; nothing to train on.'
            % min_combination_count)

    print('Training model...')
    classifier = ExtraTreesClassifier(n_estimators=400, criterion='entropy', bootstrap=True,
                                      max_features='sqrt', n_jobs=-1)
    classifier.fit(features, labels)

    print('Saving model...')
    os.makedirs(output_directory, exist_ok=True)
    # The file name is historical (an XGBoost model was used earlier). The
    # loading side presumably expects this exact name — confirm before renaming.
    filename = os.path.join(output_directory, 'xgb_classifier.sav')
    joblib.dump({'model': classifier}, filename, protocol=0)

# Load challenge data.
def load_challenge_data(header_file):
    """Load one record: its header lines and its signal matrix.

    The signal is read from the ``.mat`` file that shares the header's base
    name (``A0001.hea`` -> ``A0001.mat``).

    Args:
        header_file: Path to a ``.hea`` header file.

    Returns:
        (recording, header): ``recording`` is the ``'val'`` array stored in the
        ``.mat`` file, converted to float64; ``header`` is the list of raw
        header lines with their newlines retained.
    """
    with open(header_file, 'r') as f:
        header = f.readlines()
    # Swap only the final extension. str.replace('.hea', '.mat') would also
    # rewrite any '.hea' occurring in a directory component of the path.
    mat_file = os.path.splitext(header_file)[0] + '.mat'
    x = loadmat(mat_file)
    recording = np.asarray(x['val'], dtype=np.float64)
    return recording, header

# Find unique classes.
def get_classes(input_directory, filenames):
    """Collect the sorted set of diagnosis codes listed in header files.

    Args:
        input_directory: Unused; kept so existing callers keep working.
        filenames: Iterable of paths to ``.hea`` header files.

    Returns:
        Sorted list of the unique, non-empty diagnosis code strings found on
        lines starting with ``#Dx``.
    """
    classes = set()
    for filename in filenames:
        with open(filename, 'r') as f:
            for line in f:
                if not line.startswith('#Dx'):
                    continue
                # partition tolerates a missing space after the colon
                # ('#Dx:a,b'), where split(': ')[1] raised IndexError.
                _, _, codes = line.partition(':')
                # Skip empty entries produced by trailing or doubled commas.
                classes.update(code.strip() for code in codes.split(',') if code.strip())
    return sorted(classes)
