# Petr Nejedly

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChallengeLoss(nn.Module):
    """Cost-weighted loss over three classes indexed P=0, U=1, N=2.

    A 3x3 weight matrix ``W`` is assembled from the cost constants passed to
    the constructor.  The loss is the element-wise product of ``W`` with a
    (soft) confusion matrix, summed and divided by the number of samples.

    The confusion matrix built in :meth:`forward` has its rows indexed by the
    columns of ``y`` and its columns indexed by the integer labels in ``t``,
    so ``W[i, j]`` is the cost applied to mass that ``y`` puts on class ``i``
    for samples labelled ``j``.

    Args:
        c_algorithm: Cost term added to every cell of ``W``.
        c_GP: Cost term added to the whole ``P`` row.
        c_specialist: Cost term added to the whole ``U`` row and to
            ``W[P, U]``.
        c_treatment: Cost term added to the ``P`` column (full) and the
            ``U`` column (scaled by ``alpha``) of the ``P`` and ``U`` rows.
        c_error: Cost term added to ``W[N, P]`` (full) and ``W[N, U]``
            (scaled by ``alpha``).
        alpha: Scaling factor applied to ``c_treatment`` / ``c_error`` in the
            ``U`` column.
    """

    def __init__(self, c_algorithm=1, c_GP=250, c_specialist=500, c_treatment=1000, c_error=10000, alpha=0.5):
        super().__init__()

        # Class indices used to address rows/columns of the weight matrix.
        self.P = 0
        self.U = 1
        self.N = 2

        # NumPy copy of the weights, used by compute_from_CM.
        self.W = np.zeros((3, 3))
        self.W[self.P, self.P] = c_algorithm + c_GP + c_treatment
        self.W[self.P, self.U] = c_algorithm + c_GP + c_specialist + alpha * c_treatment
        self.W[self.P, self.N] = c_algorithm + c_GP
        self.W[self.U, self.P] = c_algorithm + c_specialist + c_treatment
        self.W[self.U, self.U] = c_algorithm + c_specialist + alpha * c_treatment
        self.W[self.U, self.N] = c_algorithm + c_specialist
        self.W[self.N, self.P] = c_algorithm + c_error
        self.W[self.N, self.U] = c_algorithm + alpha * c_error
        self.W[self.N, self.N] = c_algorithm

        # float32 torch copy of the same weights; constant, never trained.
        self.Wtorch = torch.from_numpy(self.W).float().requires_grad_(False)

    def forward(self, t, y, disp=False):
        """Return the mean weighted cost of a batch.

        Implemented as ``forward`` (rather than overriding ``__call__``) so
        the instance is still invoked as ``loss(t, y)`` while going through
        the regular ``nn.Module`` call path.

        Args:
            t: 1-D integer (long) tensor of class indices in ``{0, 1, 2}``,
                one per sample.
            y: 2-D tensor with one row per sample and one column per class
                (3 columns, to match ``W``).  Presumably class probabilities
                or scores -- confirm against the caller.
            disp: If true, print the confusion matrix rounded to 2 decimals.

        Returns:
            Scalar tensor: ``sum(cm * W) / t.shape[0]``.
        """
        n_total = t.shape[0]

        # Pin the one-hot width to the weight-matrix size.  Without
        # ``num_classes`` the width is ``max(t) + 1``, so a batch that does
        # not contain the highest label yields a matrix that cannot be
        # multiplied with the 3x3 weights.
        num_classes = self.W.shape[0]
        t = F.one_hot(t, num_classes=num_classes).float()

        # (classes x samples) @ (samples x classes), transposed so that rows
        # follow the columns of ``y`` and columns follow the labels.
        cm = (t.T @ y).T

        if disp:
            print(np.round(cm.detach().cpu().numpy(), 2))

        c1w = torch.sum(cm * self.Wtorch.to(t.device)) / n_total
        return c1w

    def compute_from_CM(self, CM):
        """Return the mean weighted cost for a precomputed confusion matrix.

        Args:
            CM: NumPy array with the same 3x3 layout as ``W``.

        Returns:
            ``sum(CM * W) / CM.sum()``.
        """
        n_total = CM.sum()
        c1w = np.sum(CM * self.W) / n_total
        return c1w