# -*- coding:utf-8 -*-
# Time:2019/3/27 17:01
# author: Tengfei Shen, Du Changping
# File:compute_test_feature.py

"""
Reference: https://medium.com/@thongonary/how-to-compute-f1-score-for-each-epoch-in-keras-a1acd17715a2
"""

import numpy as np
from keras.callbacks import Callback
from keras import backend as K
import warnings
# NOTE(review): this installs a global "ignore" filter as a side effect of
# importing this module, so every Python warning (from any library, for the
# rest of the process) is silenced. Consider scoping it with
# warnings.catch_warnings() around the code that actually needs it.
warnings.filterwarnings("ignore")

class Metrics(Callback):
    """Keras callback that scores multi-label predictions on a held-out set.

    At the end of every epoch the model predicts on ``dev_x``. Each score
    vector is binarised with ``threshold`` (falling back to the arg-max class
    when no score exceeds it) and compared with the multi-hot encoding of the
    matching entry of ``dev_y``. :meth:`score` prints and returns
    ``(accuracy, f_measure, f_beta, g_beta)``; that tuple is appended to
    ``self.val_f1`` once per epoch.
    """

    def __init__(self, dev_x, dev_y, num_classes=111, threshold=0.25, beta=2):
        """
        :param dev_x: inputs passed to ``self.model.predict`` after each epoch.
        :param dev_y: one collection of class ids per recording; ids are
            1-based (id ``k`` marks column ``k - 1``).
        :param num_classes: length of the binary prediction/label vectors.
        :param threshold: a class is predicted when its score is strictly
            greater than this value.
        :param beta: the beta used by the F-beta and G-beta scores.
        """
        # Let the Keras base class initialise its own state.
        super(Metrics, self).__init__()
        self.dev_data = dev_x
        self.dev_label = dev_y
        self.num_classes = num_classes
        self.threshold = threshold
        self.beta = beta
        # Also initialised here so call_f1_score() works before training.
        self.val_f1 = []

    def on_train_begin(self, logs=None):
        """Reset the score history at the start of each training run."""
        self.val_f1 = []

    def on_epoch_end(self, epoch, logs=None):
        """Score the model on the dev set and append the result to ``val_f1``."""
        val_predict = np.asarray(self.model.predict(self.dev_data))
        preds = [self._binarize(pred) for pred in val_predict]
        labels = [self._encode_labels(label) for label in self.dev_label]
        self.val_f1.append(self.score(preds, labels))

    def _binarize(self, pred):
        """Return a 0/1 vector of length ``num_classes`` for one score vector."""
        new_pred = np.zeros(self.num_classes)
        hits = np.flatnonzero(np.asarray(pred) > self.threshold)
        if hits.size == 0:
            # Always predict at least one class per recording.
            hits = np.argmax(pred)
        new_pred[hits] = 1
        return new_pred

    def _encode_labels(self, label):
        """Multi-hot encode one recording's collection of 1-based class ids."""
        new_label = np.zeros(self.num_classes)
        # atleast_1d lets a single scalar id be encoded instead of failing
        # with "iteration over a 0-d array".
        for class_id in np.atleast_1d(label).ravel():
            # NOTE(review): an id of 0 would wrap to the last column through
            # negative indexing -- confirm ids are always >= 1.
            new_label[int(class_id) - 1] = 1
        return new_label

    def _as_matrix(self, rows):
        """Stack per-recording vectors into an ``(n, num_classes)`` float array.

        :raises ValueError: if the vectors are not ``num_classes`` long.
        """
        matrix = np.asarray(rows, dtype=float)
        if matrix.ndim == 1 and matrix.size == 0:
            # No recordings at all: keep the class axis so counts are zeros.
            return np.zeros((0, self.num_classes))
        if matrix.ndim != 2 or matrix.shape[1] != self.num_classes:
            raise ValueError(
                'expected vectors of length {}, got an array of shape {}'.format(
                    self.num_classes, matrix.shape))
        return matrix

    @staticmethod
    def _safe_ratio(numerator, denominator):
        """Element-wise ``numerator / denominator``; 1.0 where denominator is 0."""
        nonzero = denominator != 0
        return np.where(nonzero,
                        numerator / np.where(nonzero, denominator, 1.0),
                        1.0)

    def score(self, output, labels):
        """Compute class-averaged multi-label scores and print them.

        Each recording contributes ``1 / (sum of its label vector)`` to the
        true/false positive/negative counts of every class, so recordings with
        many labels are down-weighted; recordings whose label vector sums to
        zero contribute nothing (instead of dividing by zero). Per class::

            f_beta    = (1 + b^2) tp / ((1 + b^2) tp + b^2 fn + fp)
            g_beta    = tp / (tp + fp + b fn)
            f_measure = 2 tp / (2 tp + fp + fn)
            accuracy  = (tp + tn) / (tp + fp + fn + tn)

        with ``b = self.beta``; a score whose denominator is zero is 1.0. The
        returned values are the unweighted means over all classes.

        :param output: binary prediction vectors, one per recording.
        :param labels: binary (multi-hot) label vectors, one per recording.
        :return: ``(accuracy, f_measure, f_beta, g_beta)`` as floats.
        :raises ValueError: if the vectors are not ``num_classes`` long or the
            two inputs hold different numbers of recordings.
        """
        # TODO: record the outputs of misclassified recordings and study how
        #  the predicted probabilities relate to the true labels, to see how
        #  the model can be improved.
        beta = self.beta
        truth = self._as_matrix(labels)
        guess = self._as_matrix(output)
        if truth.shape != guess.shape:
            raise ValueError(
                'output has shape {} but labels have shape {}'.format(
                    guess.shape, truth.shape))

        # Per-recording weight 1 / num_labels, as a column for broadcasting.
        counts = truth.sum(axis=1)
        weights = np.divide(1.0, counts, out=np.zeros_like(counts),
                            where=counts != 0)[:, np.newaxis]

        # Non-zero entries count as "present", matching truthiness tests.
        is_true = truth != 0
        is_pred = guess != 0
        tp = np.sum(weights * (is_true & is_pred), axis=0)
        fp = np.sum(weights * (~is_true & is_pred), axis=0)
        fn = np.sum(weights * (is_true & ~is_pred), axis=0)
        tn = np.sum(weights * (~is_true & ~is_pred), axis=0)

        b2 = beta ** 2
        fbeta_l = self._safe_ratio((1 + b2) * tp, (1 + b2) * tp + b2 * fn + fp)
        gbeta_l = self._safe_ratio(tp, tp + fp + beta * fn)
        fmeasure_l = self._safe_ratio(2 * tp, 2 * tp + fp + fn)
        accuracy_l = self._safe_ratio(tp + tn, tp + fp + fn + tn)

        accuracy = float(np.mean(accuracy_l))
        f_measure = float(np.mean(fmeasure_l))
        f_beta = float(np.mean(fbeta_l))
        g_beta = float(np.mean(gbeta_l))

        print('accuracy:{:.3f}'.format(accuracy),
              'f_measure:{:.3f}'.format(f_measure),
              'f_beta:{:.3f}'.format(f_beta),
              'g_beta:{:.3f}'.format(g_beta))

        return accuracy, f_measure, f_beta, g_beta

    def call_f1_score(self):
        """Return the list of per-epoch score tuples recorded so far."""
        return self.val_f1



