import pandas as pd
import numpy as np
import joblib
from sklearn import svm
from xgboost import XGBClassifier
from imblearn.over_sampling import SMOTE

# Names of the labels; one binary classifier is trained/applied per entry.
labelTypes = [
    'Normal',
    'AF',
    'I-AVB',
    'LBBB',
    'RBBB',
    'PAC',
    'PVC',
    'STD',
    'STE',
]

# Prior probabilities per label, stored as [prior of class 0, prior of class 1].
# NOTE(review): presumably empirical class frequencies of the training set --
# confirm where these numbers come from before reusing them on other data.
priorProbsDict = {
    'Normal': [0.867, 0.133],
    'AF': [0.822, 0.178],
    'I-AVB': [0.895, 0.105],
    'LBBB': [0.966, 0.034],
    'RBBB': [0.73, 0.270],
    'PAC': [0.91, 0.090],
    'PVC': [0.898, 0.102],
    'STD': [0.874, 0.126],
    'STE': [0.968, 0.032],
}

class BinaryTrainer:
    """Train and apply one independent binary classifier per label.

    Each column of ``y_trainDF`` is treated as a separate 0/1 target. Row-level
    predictions can be aggregated per patient with ``predictPatient`` /
    ``predictTest``.

    Parameters
    ----------
    clfType : str
        ``'svm'`` or ``'xgboost'``; selects the estimator built by
        ``getClassifier``.
    X_train : array-like
        Training feature matrix passed to SMOTE and then to ``fit``.
    y_trainDF : pandas.DataFrame
        One 0/1 column per label; the column names become model keys.
    labels : sequence of str, optional
        Labels used at prediction time. Defaults to the module-level
        ``labelTypes``.
    priorProbs : mapping, optional
        ``label -> [prior of class 0, prior of class 1]``. Defaults to the
        module-level ``priorProbsDict``.
    """

    def __init__(self, clfType, X_train, y_trainDF, labels=None, priorProbs=None):
        self.clfType = clfType
        self.X_train = X_train
        self.y_trainDF = y_trainDF
        # Fall back to the module-level constants so existing callers that
        # pass only three arguments keep their behaviour.
        self.labels = labelTypes if labels is None else list(labels)
        self.priorProbs = priorProbsDict if priorProbs is None else priorProbs
        self.priors0, self.priors1 = self.createPriors()

    def createPriors(self):
        """Return the class-0 and class-1 priors as two ``(n_labels, 1)`` arrays.

        Rows follow the order of ``self.labels``.
        """
        priors0 = np.array(
            [self.priorProbs[cl][0] for cl in self.labels], dtype=float
        ).reshape(-1, 1)
        priors1 = np.array(
            [self.priorProbs[cl][1] for cl in self.labels], dtype=float
        ).reshape(-1, 1)
        return priors0, priors1

    def balanceDataset(self, X, y):
        """Oversample the minority class of ``y`` with SMOTE.

        Returns the resampled ``(X, y)`` pair.
        """
        # `sampling_strategy` must be passed by keyword and `fit_sample` was
        # replaced by `fit_resample` in current imbalanced-learn releases.
        smote = SMOTE(sampling_strategy='minority')
        X_bal, y_bal = smote.fit_resample(X, y)
        return X_bal, y_bal

    def getClassifier(self):
        """Build a fresh, unfitted estimator according to ``self.clfType``.

        Raises
        ------
        ValueError
            If ``self.clfType`` is neither ``'svm'`` nor ``'xgboost'``.
        """
        if self.clfType == 'svm':
            # probability=True is required so that predict_proba is available.
            return svm.SVC(probability=True)
        if self.clfType == 'xgboost':
            return XGBClassifier()
        raise ValueError(
            "Unknown clfType %r; expected 'svm' or 'xgboost'" % (self.clfType,)
        )

    def trainModels(self):
        """Fit one classifier per column of ``self.y_trainDF``.

        Each training set is balanced with ``balanceDataset`` first. Every
        fitted model is also written to ``'<column>_model.sav'`` (a path
        relative to the current working directory) via ``joblib.dump``.

        Returns
        -------
        dict
            Maps column name to the fitted estimator.
        """
        modelsDict = {}
        for cl in self.y_trainDF:
            print(f'      Training {cl}:')
            y_train = self.y_trainDF[cl].to_numpy()
            X_bal, y_bal = self.balanceDataset(self.X_train, y_train)
            clf = self.getClassifier()
            clf.fit(X_bal, y_bal)
            joblib.dump(clf, f'{cl}_model.sav')
            modelsDict[cl] = clf
        return modelsDict

    def predict(self, modelsDict, X_test):
        """Predict every label for every row of ``X_test``.

        Returns
        -------
        y_pred : ndarray, shape (n_rows, n_labels)
            Predicted class per row and label (stored as float).
        probs : ndarray, shape (n_rows, n_labels)
            ``predict_proba`` value of the *predicted* class.
        """
        nRows = len(X_test)
        nLabels = len(self.labels)
        y_pred = np.empty((nRows, nLabels))
        probs = np.empty((nRows, nLabels))
        rowIdx = np.arange(nRows)
        for i, cl in enumerate(self.labels):
            model = modelsDict[cl]
            y_pred[:, i] = model.predict(X_test)
            probMatrix = np.asarray(model.predict_proba(X_test))
            # The predicted label is used directly as the column index, i.e.
            # this assumes classes are 0/1 in predict_proba column order.
            probs[:, i] = probMatrix[rowIdx, y_pred[:, i].astype(int)]
        return y_pred, probs

    def combinePreds(self, beat_preds, method='mean'):
        """Collapse row-level 0/1 predictions into one prediction per label.

        Parameters
        ----------
        beat_preds : ndarray, shape (n_rows, n_labels)
        method : {'mean', 'median'}
            Aggregate over rows, then round to the nearest integer. Note that
            ``np.round`` rounds halves to even, so an exact 0.5 becomes 0.

        Raises
        ------
        ValueError
            If ``method`` is not ``'mean'`` or ``'median'``.
        """
        if method == 'mean':
            aggregated = np.mean(beat_preds, axis=0)
        elif method == 'median':
            aggregated = np.median(beat_preds, axis=0)
        else:
            raise ValueError(
                "Unknown method %r; expected 'mean' or 'median'" % (method,)
            )
        return np.round(aggregated).astype(int)

    def predictPatient(self, modelsDict, X_patient):
        """Predict the labels of one patient from all of its rows.

        The hard prediction is the vote from ``combinePreds``. The confidence
        is ``max(s0, s1) / (s0 + s1)`` where, per label,
        ``s_c = prior_c * prod(prob / prior_c)`` over the rows predicted as
        class ``c``. If no label is predicted positive, the label(s) with the
        lowest confidence are switched to 1.

        Returns
        -------
        y_pred : ndarray of int, shape (1, n_labels)
        probArray : ndarray of float, shape (1, n_labels)
        """
        beat_preds, probs = self.predict(modelsDict, X_patient)
        nLabels = len(self.labels)
        y_pred = self.combinePreds(beat_preds).reshape((nLabels, 1))

        # Accumulate the products in log space: with a few hundred rows the
        # linear-space product overflows to inf and inf/inf yields NaN.
        predictedNeg = beat_preds == 0
        with np.errstate(divide='ignore', invalid='ignore'):
            logProbs = np.log(probs)
            logPrior0 = np.log(self.priors0).ravel()
            logPrior1 = np.log(self.priors1).ravel()
            logScore0 = logPrior0 + np.where(
                predictedNeg, logProbs - logPrior0, 0.0
            ).sum(axis=0)
            logScore1 = logPrior1 + np.where(
                predictedNeg, 0.0, logProbs - logPrior1
            ).sum(axis=0)
            # max(s0, s1) / (s0 + s1), evaluated stably.
            logConf = np.maximum(logScore0, logScore1) - np.logaddexp(
                logScore0, logScore1
            )
        probArray = np.exp(logConf).reshape((nLabels, 1))

        if np.sum(y_pred) == 0:
            # Nothing predicted positive: flip the least confident label(s).
            y_pred[probArray == np.min(probArray, axis=0)] = 1
        return y_pred.transpose(), probArray.transpose()

    def predictTest(self, modelsDict, X_test, idSeries):
        """Run ``predictPatient`` for every distinct id in ``idSeries``.

        Parameters
        ----------
        modelsDict : dict
            Fitted models keyed by label, as returned by ``trainModels``.
        X_test : ndarray, 2-D
            Feature rows.
        idSeries : pandas.Series
            Patient id of each row.

        Returns
        -------
        y_predAll : ndarray, shape (n_patients, n_labels)
        probMatrix : ndarray, shape (n_patients, n_labels)
        idList : list
            Patient ids in the row order of the two arrays.
        """
        patients = idSeries.unique()
        nLabels = len(self.labels)
        y_predAll = np.empty((len(patients), nLabels))
        probMatrix = np.empty((len(patients), nLabels))
        # When idSeries has one entry per row of X_test, select rows by
        # position. Index *labels* are only valid positions for a default
        # RangeIndex and would otherwise pick wrong rows or raise IndexError.
        rowAligned = len(idSeries) == len(X_test)
        for i, patient in enumerate(patients):
            isPatient = idSeries == patient
            if rowAligned:
                X_patient = X_test[isPatient.to_numpy(), :]
            else:
                # Legacy behaviour: index labels address rows of X_test.
                rowList = idSeries.index[isPatient].tolist()
                X_patient = X_test[rowList, :]
            y_predAll[i, :], probMatrix[i, :] = self.predictPatient(
                modelsDict, X_patient
            )
        return y_predAll, probMatrix, list(patients)

