# cinc2020.metrics 19/03/21
import tensorflow as tf
import numpy as np

class ClassWeightedAccuracy(tf.keras.metrics.Metric):
    """Class-weighted accuracy for Keras.

    Metric to compute the class-weighted accuracy as defined in the challenge
    of Computing in Cardiology 2020. This metric should be used when training a
    Keras model.

    Each recording's contribution to every class's true/false positive/negative
    counts is divided by its number of positive labels. A recording without any
    positive label is counted with weight one instead of dividing by zero. A
    class whose counts are all zero (e.g. before the first update) scores 1.0
    instead of NaN, so one empty class cannot turn the whole average into NaN.

    Args:
        n_classes: Number of classes.
        class_weights: Array of per-class weights; defaults to all ones.
        threshold: Threshold used to convert predicted probabilities to class
            labels (a class is predicted when its probability is strictly
            greater than the threshold).
        name: Metric name.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.

    Note:
        ``sample_weight`` passed to ``update_state`` is ignored.
    """

    def __init__(self,
                 n_classes,
                 class_weights=None,
                 threshold=0.5,
                 name='class_weighted_accuracy',
                 **kwargs):
        super().__init__(name=name, **kwargs)
        self.n_classes = n_classes
        if class_weights is None:
            self.class_weights = tf.ones((self.n_classes,), 'float32')
        else:
            self.class_weights = tf.convert_to_tensor(class_weights, 'float32')
        self.threshold = threshold
        self.tp = self.add_weight(
            name='tp', shape=(n_classes,), initializer='zeros')
        self.tn = self.add_weight(
            name='tn', shape=(n_classes,), initializer='zeros')
        self.fp = self.add_weight(
            name='fp', shape=(n_classes,), initializer='zeros')
        self.fn = self.add_weight(
            name='fn', shape=(n_classes,), initializer='zeros')

    def update_state(self, references, predictions, sample_weight=None):
        """Accumulate the label-normalized contingency counts of one batch.

        Args:
            references: Ground-truth labels (0/1), shape (batch, n_classes).
            predictions: Predicted probabilities, same shape as references.
            sample_weight: Ignored.
        """
        # Cast so integer label tensors can be divided by the float counts.
        references = tf.cast(references, 'float32')
        predicted = tf.cast(predictions > self.threshold, 'float32')
        # Clamp at one: recordings without a positive label would otherwise
        # produce 0/0 or x/0 and poison the accumulators with NaN/inf.
        n_labels = tf.maximum(
            tf.reduce_sum(references, axis=1, keepdims=True), 1.0)

        def weighted_count(ref_value, pred_value):
            mask = tf.logical_and(tf.equal(references, ref_value),
                                  tf.equal(predicted, pred_value))
            return tf.reduce_sum(tf.cast(mask, 'float32') / n_labels, axis=0)

        self.tp.assign_add(weighted_count(1.0, 1.0))
        self.tn.assign_add(weighted_count(0.0, 0.0))
        self.fp.assign_add(weighted_count(0.0, 1.0))
        self.fn.assign_add(weighted_count(1.0, 0.0))

    def reset_state(self):
        """Zero all accumulated counts."""
        for variable in (self.tp, self.tn, self.fp, self.fn):
            variable.assign(tf.zeros((self.n_classes,), 'float32'))

    def reset_states(self):
        """Alias of ``reset_state`` kept for older Keras versions."""
        self.reset_state()

    def result(self):
        """Return the class-weighted mean of the per-class accuracies."""
        total = self.tp + self.tn + self.fp + self.fn
        # Classes with no accumulated counts score 1.0 instead of NaN.
        accuracy = tf.where(
            total > 0,
            tf.math.divide_no_nan(self.tp + self.tn, total),
            tf.ones_like(total))
        return tf.reduce_sum(self.class_weights * accuracy) / self.n_classes


class ClassWeightedFScore(tf.keras.metrics.Metric):
    """Class-weighted F-score for Keras.

    Metric to compute the class-weighted F-score as defined in the challenge
    of Computing in Cardiology 2020. This metric should be used when training a
    Keras model.

    Each recording's contribution to every class's counts is divided by its
    number of positive labels (clamped to at least one). A class whose F-score
    denominator is zero (never present and never predicted, or before the
    first update) scores 1.0 instead of NaN.

    Args:
        n_classes: Number of classes.
        class_weights: Array of per-class weights; defaults to all ones.
        beta: Parameter of the F-score (weight of recall relative to
            precision).
        threshold: Threshold used to convert predicted probabilities to class
            labels (a class is predicted when its probability is strictly
            greater than the threshold).
        name: Metric name.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.

    Note:
        ``sample_weight`` passed to ``update_state`` is ignored.
    """

    def __init__(self,
                 n_classes,
                 class_weights=None,
                 beta=2,
                 threshold=0.5,
                 name='class_weighted_f_score',
                 **kwargs):
        super().__init__(name=name, **kwargs)
        self.n_classes = n_classes
        if class_weights is None:
            self.class_weights = tf.ones((self.n_classes,), 'float32')
        else:
            self.class_weights = tf.convert_to_tensor(class_weights, 'float32')
        self.beta = beta
        self.threshold = threshold
        self.tp = self.add_weight(
            name='tp', shape=(n_classes,), initializer='zeros')
        self.fp = self.add_weight(
            name='fp', shape=(n_classes,), initializer='zeros')
        self.fn = self.add_weight(
            name='fn', shape=(n_classes,), initializer='zeros')

    def update_state(self, references, predictions, sample_weight=None):
        """Accumulate the label-normalized TP/FP/FN counts of one batch.

        Args:
            references: Ground-truth labels (0/1), shape (batch, n_classes).
            predictions: Predicted probabilities, same shape as references.
            sample_weight: Ignored.
        """
        # Cast so integer label tensors can be divided by the float counts.
        references = tf.cast(references, 'float32')
        predicted = tf.cast(predictions > self.threshold, 'float32')
        # Clamp at one so unlabeled recordings do not divide by zero.
        n_labels = tf.maximum(
            tf.reduce_sum(references, axis=1, keepdims=True), 1.0)

        def weighted_count(ref_value, pred_value):
            mask = tf.logical_and(tf.equal(references, ref_value),
                                  tf.equal(predicted, pred_value))
            return tf.reduce_sum(tf.cast(mask, 'float32') / n_labels, axis=0)

        self.tp.assign_add(weighted_count(1.0, 1.0))
        self.fp.assign_add(weighted_count(0.0, 1.0))
        self.fn.assign_add(weighted_count(1.0, 0.0))

    def reset_state(self):
        """Zero all accumulated counts."""
        for variable in (self.tp, self.fp, self.fn):
            variable.assign(tf.zeros((self.n_classes,), 'float32'))

    def reset_states(self):
        """Alias of ``reset_state`` kept for older Keras versions."""
        self.reset_state()

    def result(self):
        """Return the class-weighted mean of the per-class F-beta scores."""
        beta_sq = self.beta ** 2
        numerator = (1 + beta_sq) * self.tp
        denominator = numerator + self.fp + beta_sq * self.fn
        # Classes with an empty denominator score 1.0 instead of NaN.
        f_score = tf.where(
            denominator > 0,
            tf.math.divide_no_nan(numerator, denominator),
            tf.ones_like(denominator))
        return tf.reduce_sum(self.class_weights * f_score) / self.n_classes


class ClassWeightedJaccardMeasure(tf.keras.metrics.Metric):
    """Class-weighted Jaccard measure for Keras.

    Metric to compute the class-weighted Jaccard measure as defined in the
    challenge of Computing in Cardiology 2020. This metric should be used when
    training a Keras model.

    Each recording's contribution to every class's counts is divided by its
    number of positive labels (clamped to at least one). A class whose
    denominator is zero (never present and never predicted, or before the
    first update) scores 1.0 instead of NaN.

    Args:
        n_classes: Number of classes.
        class_weights: Array of per-class weights; defaults to all ones.
        beta: Parameter of the Jaccard measure (weight of the false negatives
            in the denominator).
        threshold: Threshold used to convert predicted probabilities to class
            labels (a class is predicted when its probability is strictly
            greater than the threshold).
        name: Metric name.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.

    Note:
        ``sample_weight`` passed to ``update_state`` is ignored.
    """

    def __init__(self,
                 n_classes,
                 class_weights=None,
                 beta=2,
                 threshold=0.5,
                 name='class_weighted_jaccard_measure',
                 **kwargs):
        super().__init__(name=name, **kwargs)
        self.n_classes = n_classes
        if class_weights is None:
            self.class_weights = tf.ones((self.n_classes,), 'float32')
        else:
            self.class_weights = tf.convert_to_tensor(class_weights, 'float32')
        self.beta = beta
        self.threshold = threshold
        self.tp = self.add_weight(
            name='tp', shape=(n_classes,), initializer='zeros')
        self.fp = self.add_weight(
            name='fp', shape=(n_classes,), initializer='zeros')
        self.fn = self.add_weight(
            name='fn', shape=(n_classes,), initializer='zeros')

    def update_state(self, references, predictions, sample_weight=None):
        """Accumulate the label-normalized TP/FP/FN counts of one batch.

        Args:
            references: Ground-truth labels (0/1), shape (batch, n_classes).
            predictions: Predicted probabilities, same shape as references.
            sample_weight: Ignored.
        """
        # Cast so integer label tensors can be divided by the float counts.
        references = tf.cast(references, 'float32')
        predicted = tf.cast(predictions > self.threshold, 'float32')
        # Clamp at one so unlabeled recordings do not divide by zero.
        n_labels = tf.maximum(
            tf.reduce_sum(references, axis=1, keepdims=True), 1.0)

        def weighted_count(ref_value, pred_value):
            mask = tf.logical_and(tf.equal(references, ref_value),
                                  tf.equal(predicted, pred_value))
            return tf.reduce_sum(tf.cast(mask, 'float32') / n_labels, axis=0)

        self.tp.assign_add(weighted_count(1.0, 1.0))
        self.fp.assign_add(weighted_count(0.0, 1.0))
        self.fn.assign_add(weighted_count(1.0, 0.0))

    def reset_state(self):
        """Zero all accumulated counts."""
        for variable in (self.tp, self.fp, self.fn):
            variable.assign(tf.zeros((self.n_classes,), 'float32'))

    def reset_states(self):
        """Alias of ``reset_state`` kept for older Keras versions."""
        self.reset_state()

    def result(self):
        """Return the class-weighted mean of the per-class Jaccard measures."""
        denominator = self.tp + self.fp + self.beta * self.fn
        # Classes with an empty denominator score 1.0 instead of NaN.
        jaccard = tf.where(
            denominator > 0,
            tf.math.divide_no_nan(self.tp, denominator),
            tf.ones_like(denominator))
        return tf.reduce_sum(self.class_weights * jaccard) / self.n_classes


class CincScore(tf.keras.metrics.Metric):
    """CinC score for Keras.

    Metric to compute the score for evaluating entries in the challenge of
    Computing in Cardiology 2020. This metric is defined as the geometric mean
    of the class-weighted F-score and the class-weighted Jaccard measure and
    should be used when training a Keras model.

    Each recording's contribution to every class's counts is divided by its
    number of positive labels (clamped to at least one). A class whose
    F-score or Jaccard denominator is zero scores 1.0 in that measure instead
    of NaN.

    Args:
        n_classes: Number of classes.
        class_weights: Array of per-class weights; defaults to all ones.
        beta: Parameter of the F-score and the Jaccard measure.
        threshold: Threshold used to convert predicted probabilities to class
            labels (a class is predicted when its probability is strictly
            greater than the threshold).
        name: Metric name.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.

    Note:
        ``sample_weight`` passed to ``update_state`` is ignored.
    """

    def __init__(self,
                 n_classes,
                 class_weights=None,
                 beta=2,
                 threshold=0.5,
                 name='cinc_score',
                 **kwargs):
        super().__init__(name=name, **kwargs)
        self.n_classes = n_classes
        if class_weights is None:
            self.class_weights = tf.ones((self.n_classes,), 'float32')
        else:
            self.class_weights = tf.convert_to_tensor(class_weights, 'float32')
        self.beta = beta
        self.threshold = threshold
        self.tp = self.add_weight(
            name='tp', shape=(n_classes,), initializer='zeros')
        self.fp = self.add_weight(
            name='fp', shape=(n_classes,), initializer='zeros')
        self.fn = self.add_weight(
            name='fn', shape=(n_classes,), initializer='zeros')

    def update_state(self, references, predictions, sample_weight=None):
        """Accumulate the label-normalized TP/FP/FN counts of one batch.

        Args:
            references: Ground-truth labels (0/1), shape (batch, n_classes).
            predictions: Predicted probabilities, same shape as references.
            sample_weight: Ignored.
        """
        # Cast so integer label tensors can be divided by the float counts.
        references = tf.cast(references, 'float32')
        predicted = tf.cast(predictions > self.threshold, 'float32')
        # Clamp at one so unlabeled recordings do not divide by zero.
        n_labels = tf.maximum(
            tf.reduce_sum(references, axis=1, keepdims=True), 1.0)

        def weighted_count(ref_value, pred_value):
            mask = tf.logical_and(tf.equal(references, ref_value),
                                  tf.equal(predicted, pred_value))
            return tf.reduce_sum(tf.cast(mask, 'float32') / n_labels, axis=0)

        self.tp.assign_add(weighted_count(1.0, 1.0))
        self.fp.assign_add(weighted_count(0.0, 1.0))
        self.fn.assign_add(weighted_count(1.0, 0.0))

    def reset_state(self):
        """Zero all accumulated counts."""
        for variable in (self.tp, self.fp, self.fn):
            variable.assign(tf.zeros((self.n_classes,), 'float32'))

    def reset_states(self):
        """Alias of ``reset_state`` kept for older Keras versions."""
        self.reset_state()

    def result(self):
        """Return sqrt(weighted F-score * weighted Jaccard measure)."""
        beta_sq = self.beta ** 2
        f_numerator = (1 + beta_sq) * self.tp
        f_denominator = f_numerator + self.fp + beta_sq * self.fn
        f_score = tf.where(
            f_denominator > 0,
            tf.math.divide_no_nan(f_numerator, f_denominator),
            tf.ones_like(f_denominator))
        g_denominator = self.tp + self.fp + self.beta * self.fn
        jaccard = tf.where(
            g_denominator > 0,
            tf.math.divide_no_nan(self.tp, g_denominator),
            tf.ones_like(g_denominator))
        f_mean = tf.reduce_sum(self.class_weights * f_score) / self.n_classes
        g_mean = tf.reduce_sum(self.class_weights * jaccard) / self.n_classes
        return tf.sqrt(f_mean * g_mean)


class CincScore2(tf.keras.metrics.Metric):
    """CinC score, phase 2 for Keras.

    Metric to compute the normalized score used to evaluate entries in the
    second phase of the challenge of Computing in Cardiology 2020.

    Three weighted confusion matrices are accumulated over all batches:
    - "observed", for the thresholded predictions;
    - "perfect", for predictions equal to the references;
    - "inactive", for a baseline predicting only the normal class.

    The result is ``(observed - inactive) / (perfect - inactive)`` of their
    weighted sums. It is NaN when the perfect and inactive scores coincide.

    Args:
        class_weights: Weight matrix multiplied element-wise with the
            accumulated confusion matrices (rows: reference classes, columns:
            predicted classes); expected shape ``(len(classes), len(classes))``.
        classes: Sequence of class identifiers, ordered like the columns of
            the references and predictions.
        normal_class: Identifier of the class predicted by the inactive
            baseline. If it is not in ``classes`` the baseline predicts no
            class at all.
        threshold: Threshold used to convert predicted probabilities to class
            labels (a class is predicted when its probability is strictly
            greater than the threshold).
        name: Metric name.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.

    Note:
        ``sample_weight`` passed to ``update_state`` is ignored.
    """

    def __init__(self,
                 class_weights,
                 classes,
                 normal_class='426783006',
                 threshold=0.5,
                 name='cinc_score_2',
                 **kwargs):
        super().__init__(name=name, **kwargs)
        self.threshold = threshold
        self.class_weights = tf.convert_to_tensor(class_weights, 'float32')
        # Materialize as a list so `.index` works for tuples/arrays as well.
        classes = list(classes)
        self.n_classes = len(classes)
        self.normal_class = normal_class
        matrix_shape = (self.n_classes, self.n_classes)
        self.observed_cm = self.add_weight(
            name='observed_cm', shape=matrix_shape, initializer='zeros')
        self.perfect_cm = self.add_weight(
            name='perfect_cm', shape=matrix_shape, initializer='zeros')
        self.inactive_cm = self.add_weight(
            name='inactive_cm', shape=matrix_shape, initializer='zeros')
        ip = np.zeros(self.n_classes)
        if normal_class in classes:
            ip[classes.index(normal_class)] = 1
        self.inactive_prediction = tf.convert_to_tensor(ip, 'float32')

    def update_state(self, references, predictions, sample_weight=None):
        """Accumulate the three modified confusion matrices of one batch.

        Args:
            references: Ground-truth labels (0/1), shape (batch, n_classes).
            predictions: Predicted probabilities, same shape as references.
            sample_weight: Ignored.
        """
        # Cast so references and predictions share a dtype downstream.
        references = tf.cast(references, 'float32')
        predictions = tf.cast(predictions > self.threshold, 'float32')
        # Baseline that predicts only the normal class for every recording.
        inactive_predictions = tf.broadcast_to(
            self.inactive_prediction, tf.shape(references))
        self.observed_cm.assign_add(
            compute_modified_confusion_matrix(references, predictions))
        self.perfect_cm.assign_add(
            compute_modified_confusion_matrix(references, references))
        self.inactive_cm.assign_add(
            compute_modified_confusion_matrix(references, inactive_predictions))

    def reset_state(self):
        """Zero all accumulated confusion matrices."""
        zeros = tf.zeros((self.n_classes, self.n_classes), 'float32')
        for variable in (self.observed_cm, self.perfect_cm, self.inactive_cm):
            variable.assign(zeros)

    def reset_states(self):
        """Alias of ``reset_state`` kept for older Keras versions."""
        self.reset_state()

    def result(self):
        """Return the normalized challenge score (NaN if undefined)."""
        observed_score = tf.reduce_sum(self.class_weights * self.observed_cm)
        perfect_score = tf.reduce_sum(self.class_weights * self.perfect_cm)
        inactive_score = tf.reduce_sum(self.class_weights * self.inactive_cm)
        denominator = perfect_score - inactive_score
        # An undefined normalization yields NaN rather than +/-inf.
        return tf.where(
            tf.not_equal(denominator, 0.0),
            tf.math.divide_no_nan(observed_score - inactive_score, denominator),
            tf.constant(float('nan'), 'float32'))


def compute_modified_confusion_matrix(references, predictions):
    """Compute the multi-label "modified" confusion matrix of a batch.

    Entry ``[i, j]`` is the sum over recordings ``r`` of
    ``references[r, i] * predictions[r, j] / normalization[r]``. Here
    ``normalization[r]`` is the number of classes that are non-zero in the
    references and/or the predictions of recording ``r``, clamped to at
    least one. Rows index reference classes and columns predicted classes.

    Args:
        references: Ground-truth labels, shape (batch, n_classes).
        predictions: Binary predictions, shape (batch, n_classes).

    Returns:
        float32 tensor of shape (n_classes, n_classes).
    """
    # Common dtype: bool/int inputs would otherwise break the arithmetic.
    references = tf.cast(references, 'float32')
    predictions = tf.cast(predictions, 'float32')
    # Per-recording normalization over the union of labeled and predicted
    # classes (the previous code summed over the whole batch instead).
    union = tf.logical_or(tf.not_equal(references, 0.0),
                          tf.not_equal(predictions, 0.0))
    normalization = tf.maximum(
        tf.reduce_sum(tf.cast(union, 'float32'), axis=1), 1.0)
    # Batched outer product: outer[r, i, j] = references[r, i] * predictions[r, j].
    outer = tf.expand_dims(references, 2) * tf.expand_dims(predictions, 1)
    return tf.reduce_sum(outer / normalization[:, None, None], axis=0)


def tf_outer_product(v1, v2):
    """Return the outer product over the last axes of ``v1`` and ``v2``.

    For 1-D inputs of lengths n and m the result has shape (n, m). Any
    leading axes act as (broadcast) batch axes, so inputs of shape (..., n)
    and (..., m) give shape (..., n, m) with
    ``result[..., i, j] == v1[..., i] * v2[..., j]``.
    """
    column = tf.expand_dims(v1, axis=-1)  # (..., n, 1)
    row = tf.expand_dims(v2, axis=-2)     # (..., 1, m)
    return column * row
