# Jan Pavlus
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as opt
from torch.cuda import amp

from sklearn.metrics import confusion_matrix, f1_score, fbeta_score
from tqdm import tqdm
from sklearn.preprocessing import LabelBinarizer
from metrics.score_murmur import score_murmur
from utils import get_device

from evaluation_2022.evaluate_model import compute_cost,compute_weighted_accuracy

class Trainer():
    def __init__(self, model:nn.Module,chunk_transforms, model_folder:str, learning_rate:float,num_epochs:int, loss_S_weigths=torch.Tensor([0, 1, 1,1,1]), loss_Z_weigths=torch.Tensor([5, 3, 1])):
        self._device = torch.device('cuda:0')#get_device()
        print(self._device)
        self._loss_S = nn.CrossEntropyLoss(weight=loss_S_weigths.to(self._device))
        self._loss_C = nn.CrossEntropyLoss(weight=loss_Z_weigths.to(self._device))
        self._loss_O = nn.CrossEntropyLoss(weight=torch.Tensor([2,1]).to(self._device))

        self._num_epochs = num_epochs

        self._mdl = model
        self._mdl.to(self._device)

        self._optim = opt.Adam(self._mdl.parameters(),lr=learning_rate)
        self._epoch = 0
        self.chunk_transforms = chunk_transforms
        self._model_folder = model_folder
        self._best=None

    @staticmethod
    def transform_segmentation_to_STFT_len(seg,window,offset,Sxx):
        #seg = seg.data
        seg = seg.data.cpu().numpy()
        #out = torch.zeros((Sxx.shape[0],Sxx.shape[-1]))
        out = np.zeros((Sxx.shape[0],Sxx.shape[-1]))
        for k in range(Sxx.shape[0]):
            for i in range(Sxx.shape[-1]):
                tmp = seg[k,i*offset:i*offset+window]

                #Most frequent value in the above array
                #tmp = torch.bincount(tmp).argmax()
                tmp = np.bincount(tmp).argmax()

                out[k,i] = tmp
        out = torch.from_numpy(out)
        return out

    @staticmethod
    def split(x,window,step,transform):
        out = []
        i = 0
        while i < x.shape[0]:
            tmp = torch.zeros((window,))
            if x.shape[0] - i >= window:
                tmp[:] = x[i:i+window]
            else:
                # zeropad last segment
                tmp[ :x.shape[0] - i] = x[i:x.shape[0]]
            if transform:
                tmp = transform(tmp)
            out.append(tmp)
            i += step
        out = torch.stack(out,dim=0)
        return out



    def train_valid_loop(self,dataloader,train=False):
        """Run one pass over ``dataloader`` and print confusion matrices.

        The objective is the sum of the segmentation and murmur cross-entropy
        losses. The outcome loss is not part of the objective, so outcome
        predictions are only collected and reported.

        Args:
            dataloader: Iterable yielding ``(input_data, target_murmur,
                target_segmentation, target_output)`` batches.
            train: If True, the model is put in train mode and its weights are
                updated; otherwise it runs in eval mode with gradients off.

        Returns:
            None. The confusion matrices and the murmur score are printed.
        """
        if train:
            self._mdl.train()
        else:
            self._mdl.eval()

        collector = {
            'target_murmur': [],
            'predicted_murmur': [],
            'target_segmentation': [],
            'predicted_segmentation': [],
            'target_outcome': [],
            'predicted_outcome': [],
        }

        use_amp = self._device.type == 'cuda'

        # tqdm wraps the dataloader itself (not the enumerate generator) so it
        # can read the length and display real progress.
        for i, (input_data, target_murmur, target_segmentation, target_output) in enumerate(tqdm(dataloader)):
            with torch.set_grad_enabled(train):
                # One label per frame; 256/128 are the window/hop used for the
                # reduction (presumably matching the input's STFT - TODO confirm).
                target_segmentation = self.transform_segmentation_to_STFT_len(target_segmentation,256,128,input_data)
                target_segmentation = target_segmentation.to(self._device).long()

                input_data = input_data.to(self._device).float().unsqueeze(1)
                target_murmur = target_murmur.to(self._device).long()
                target_outcome = target_output.to(self._device).long()

                # Only the forward pass and the loss belong under autocast.
                with amp.autocast(enabled=use_amp):
                    predicted_segmentation, predicted_murmur, predicted_outcome = self._mdl(input_data)

                    J = self._loss_S(predicted_segmentation, target_segmentation) \
                      + self._loss_C(predicted_murmur,target_murmur)

                if train:
                    # Backward/step run outside autocast, as recommended by
                    # the PyTorch AMP documentation.
                    # NOTE(review): no GradScaler is used, so fp16 gradients
                    # may underflow - consider adding one.
                    self._optim.zero_grad()
                    J.backward()
                    self._optim.step()

                if i % 10 == 0:
                    print(self._epoch, J)

                collector['target_murmur'].append(target_murmur.detach().cpu())
                collector['predicted_murmur'].append(predicted_murmur.detach().cpu())
                collector['target_segmentation'].append(target_segmentation.detach().cpu())
                collector['predicted_segmentation'].append(predicted_segmentation.detach().cpu())
                collector['target_outcome'].append(target_outcome.detach().cpu())
                collector['predicted_outcome'].append(predicted_outcome.detach().cpu())

        # Join the per-batch tensors along the sample dimension.
        collector = {key: torch.cat(batches, dim=0) for key, batches in collector.items()}

        # Class count taken from the logits so the murmur confusion matrix
        # always has the full size, even when a class is absent in this pass.
        n_murmur_classes = collector['predicted_murmur'].shape[1]

        collector['predicted_murmur'] = torch.argmax(collector['predicted_murmur'],dim=1)
        collector['predicted_segmentation'] = torch.argmax(collector['predicted_segmentation'],dim=1)
        collector['predicted_outcome'] = torch.argmax(collector['predicted_outcome'],dim=1)

        # confusion matrix for segmentation
        CM = confusion_matrix(y_true=collector['target_segmentation'].ravel(), y_pred=collector['predicted_segmentation'].ravel())
        print(CM)

        # confusion matrix for murmur
        CM = confusion_matrix(y_true=collector['target_murmur'], y_pred=collector['predicted_murmur'],
                              labels=list(range(n_murmur_classes)))
        print(CM)
        print(f"Murmur score: {score_murmur(CM.T)}")

        # confusion matrix for outcome
        CM = confusion_matrix(y_true=collector['target_outcome'], y_pred=collector['predicted_outcome'])
        print(CM)


    def train(self, train_dataloader):
        """Run a single training pass over ``train_dataloader``."""
        self.train_valid_loop(dataloader=train_dataloader, train=True)



    def validate(self, valid_dataloader, valid_inference_dataset, ensemble_id):
        """Validate the model and checkpoint it when the murmur score improves.

        First runs ``train_valid_loop`` in eval mode on ``valid_dataloader``
        (which prints chunk-level confusion matrices). Then, for every item of
        ``valid_inference_dataset``, each recording is split into overlapping
        chunks, the chunk predictions are averaged over all recordings of the
        item, and the murmur weighted accuracy is computed over all items.

        Args:
            valid_dataloader: Batched loader passed to ``train_valid_loop``.
            valid_inference_dataset: Iterable yielding
                ``(recordings, target, outcome)`` where ``recordings`` is an
                iterable of 1-D signals.
            ensemble_id: Suffix of the checkpoint file name.

        Returns:
            None. A checkpoint (weights, epoch, score, outcome threshold) is
            written when the score is the best seen so far.
        """
        self.train_valid_loop(valid_dataloader,train=False)

        # VALIDATION
        chunk_size = 20000

        results_murmur = []
        targets_murmur = []
        results_outcome = []
        targets_outcome = []
        # Pure inference: no_grad avoids building an autograd graph per chunk.
        with torch.no_grad():
            for recordings, target,outcome in tqdm(valid_inference_dataset):
                tmp_murmur = []
                tmp_outcome = []
                for recording in recordings:
                    x = self.split(recording, window=chunk_size, step=10000, transform=self.chunk_transforms)
                    x = x.to(self._device).float().unsqueeze(1)
                    # The segmentation output is not needed here.
                    _, predicted_murmur, predicted_outcome = self._mdl(x)
                    predicted_murmur = torch.softmax(predicted_murmur,dim=-1)
                    predicted_murmur = predicted_murmur.detach().cpu().numpy()

                    predicted_outcome = torch.sigmoid(predicted_outcome)
                    predicted_outcome = predicted_outcome.detach().cpu().numpy()

                    tmp_murmur.append(predicted_murmur)
                    tmp_outcome.append(predicted_outcome)

                # Average the murmur probabilities over every chunk of every
                # recording, then renormalise them to sum to one.
                tmp_murmur = np.concatenate(tmp_murmur, axis=0)
                tmp_murmur = tmp_murmur.mean(axis=0)
                tmp_murmur = tmp_murmur / tmp_murmur.sum()
                results_murmur.append(tmp_murmur)
                targets_murmur.append(target)

                tmp_outcome = np.concatenate(tmp_outcome, axis=0)
                tmp_outcome = tmp_outcome.mean(axis=0)
                results_outcome.append(tmp_outcome)
                targets_outcome.append(outcome)

        targets_outcome = np.array(targets_outcome)
        results_outcome = np.stack(results_outcome,axis=0)

        results_murmur = np.argmax(results_murmur, axis=1)
        targets_murmur = np.array(targets_murmur)

        targets_murmur = self.one_hot(targets_murmur,3)
        results_murmur = self.one_hot(results_murmur,3)

        score = compute_weighted_accuracy(targets_murmur, results_murmur, ['Present', 'Unknown', 'Absent'])  # This is the murmur scoring metric.
        print(score)

        if self._best is None or score > self._best:
            # The outcome threshold is only tuned for models worth keeping.
            threshold = self.optimize_outcome_threshold(targets_outcome, results_outcome)
            print('SAVE')
            self._best = score
            torch.save({'model':self._mdl.state_dict(),
                        'epoch':self._epoch,
                        'score':score,
                        'threshold':threshold}, f'{self._model_folder}/model_murmur_{ensemble_id}')

    def one_hot(self,a, num_classes):
        y = np.zeros((a.shape[0],num_classes))
        y[np.arange(a.shape[0]),a] = 1
        return y

    def optimize_outcome_threshold(self,y_true,y_pred):
        """Pick the outcome threshold, in percent, with the lowest cost.

        For each candidate percentage ``p`` in 20..79, a sample is assigned to
        the first class ('Abnormal') when ``y_pred[:, 0] > p / 100`` and to the
        second ('Normal') otherwise; the resulting labels are scored with
        ``compute_cost``.

        Args:
            y_true: Array of integer outcome class indices (0 or 1).
            y_pred: 2-D array of outcome scores; only column 0 is used.

        Returns:
            The candidate percentage (integer in 20..79) with the minimum
            cost; the first one wins on ties.
        """
        outcome_classes = ['Abnormal', 'Normal']
        y_true = self.one_hot(y_true,2)

        # Only these percentages are ever eligible, so only they are scored.
        candidates = range(20, 80)
        scores = np.zeros((len(candidates),))
        for idx, percent in enumerate(candidates):
            is_first_class = y_pred[:,0] > percent/100
            labels = np.stack([is_first_class,~is_first_class],axis=1)
            # This is the clinical outcomes scoring metric.
            scores[idx] = compute_cost(y_true, labels, outcome_classes,
                                       outcome_classes)

        return np.argmin(scores) + candidates.start



    def run(self, train_dataloader, valid_dataloader, valid_inference_dataset,ensemble_id):
        for self._epoch in range(self._num_epochs):
            print(f"Epoch {self._epoch}")
            print("Train...")
            self.train(train_dataloader)
            print("Validate...")
            self.validate(valid_dataloader=valid_dataloader, valid_inference_dataset=valid_inference_dataset,ensemble_id=ensemble_id)
        print("Complete")