#!/usr/bin/env python

# Edit this script to add your team's code. Some functions are *required*, but you can edit most parts of the required functions,
# change or remove non-required functions, and add your own functions.

################################################################################
#
# Import libraries and functions. You can change or remove them.
#
################################################################################

import chunk
import glob
import pdb
import subprocess
from helper_code import *
import numpy as np, scipy as sp, scipy.stats, os, sys, joblib
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestClassifier
import matplotlib.pyplot as plt

from dataset.prepare import prepare, MODES
from dataset.dataset import Murmur2022Dataset,Murmur2022ValidDataset
from dataset.dataloader import Murmur2022DataLoader,ChunkSplitter

import torch
import torch.nn as nn
import torch.optim as opt
import torch.nn.functional as F
import torchvision.transforms as transforms
import random
from sklearn.metrics import confusion_matrix, f1_score, fbeta_score
from models.crnn_lstm import model_segment_lstm
from trainer import Trainer

from transforms import *
from models.model import *
from metrics.challenge_loss import ChallengeLoss
from utils import get_device

from evaluation_2022.evaluate_model import compute_cost
from train_outcome_model import train_outcome_model_features,NN,get_features,prepare_features
# torch.manual_seed(0)
# random.seed(0)
# np.random.seed(0)
################################################################################
#
# Required functions. Edit these functions to add your code, but do not change the arguments.
#
################################################################################


def show_spectrogram(x,s,bidx):
    """Display one batch item: its spectrogram on top, its segmentation trace below.

    :param x: spectrogram tensor, [batch, frequency, time]
    :param s: segmentation tensor, [batch, time]
    :param bidx: index of the item within the batch to plot
    :return: None (the figure is shown with ``plt.show()``)
    """
    spectrogram = x.detach()[bidx, :, :].cpu().numpy()
    segmentation = s.detach()[bidx, :].cpu().numpy()
    n_steps = s.shape[1]

    _, (spec_ax, seg_ax) = plt.subplots(2, 1, dpi=300)
    spec_ax.imshow(spectrogram, aspect='auto')
    seg_ax.plot(segmentation)
    seg_ax.set_xlim(0, n_steps)
    plt.show()



def transform_segmentation_to_STFT_len(seg,window,offset,Sxx):
    """Resample per-sample segmentation labels onto the STFT frame axis.

    For each batch item ``k`` and frame ``i`` the output label is the most
    frequent value in ``seg[k, i*offset : i*offset + window]`` (ties resolve
    to the smallest label, as with ``np.bincount(...).argmax()``).

    :param seg: 2-D tensor or array of non-negative integer labels, [batch, samples];
        float labels are truncated to integers
    :param window: number of samples pooled into one frame
    :param offset: hop between consecutive frame starts, in samples
    :param Sxx: spectrogram; only ``Sxx.shape[0]`` (batch) and ``Sxx.shape[-1]``
        (number of frames) are used
    :return: float64 tensor of shape [Sxx.shape[0], Sxx.shape[-1]]
    """
    if torch.is_tensor(seg):
        seg = seg.detach().cpu().numpy()
    # np.bincount only accepts integer input; labels may arrive as a float tensor.
    seg = np.asarray(seg).astype(np.int64, copy=False)

    n_batch, n_frames = Sxx.shape[0], Sxx.shape[-1]
    n_samples = seg.shape[1]
    out = np.zeros((n_batch, n_frames))
    for k in range(n_batch):
        for i in range(n_frames):
            # A centred STFT can have its last frame start at/after the end of the
            # signal; that slice would be empty and argmax() would raise. Clamp the
            # start so every frame sees at least the final sample.
            start = min(i * offset, max(n_samples - 1, 0))
            out[k, i] = np.bincount(seg[k, start:start + window]).argmax()
    return torch.from_numpy(out)



# Train your model.
def train_challenge_model(data_folder, model_folder, verbose):
    """Train every model used at inference time.

    Features are prepared once from ``data_folder``; five outcome classifiers
    (ensemble ids 0..4) are fitted on them, then fifteen murmur networks
    (ensemble ids 0..14) are trained for 10 epochs each via
    ``_train_challenge_model``.
    """
    n_murmur_models, n_outcome_models = 15, 5

    features, murmurs, outcomes = prepare_features(data_folder, model_folder, verbose)

    for ensemble_id in range(n_outcome_models):
        train_outcome_model_features(features, murmurs, outcomes, model_folder, ensemble_id=ensemble_id)

    for ensemble_id in range(n_murmur_models):
        _train_challenge_model(data_folder, model_folder, verbose, id=ensemble_id, num_epochs=10)








def _train_challenge_model(data_folder, model_folder, verbose,id,num_epochs):
    """Train one murmur network (ensemble member ``id``) for ``num_epochs`` epochs.

    Data is split 80/20 into train/validation by ``prepare``. Recordings are
    cut into 20000-sample chunks with a 10000-sample overlap and turned into
    clipped, log10- and z-score-normalised spectrograms before reaching the
    network. The Trainer is given ``model_folder`` (presumably where it writes
    the checkpoint for this member; confirm in Trainer).
    """
    os.makedirs(model_folder, exist_ok=True)

    if verbose >= 1:
        print('Finding data files...')
    train, valid = prepare(data_folder, mode=MODES["All"], split_ratio=0.2, verbose=verbose)

    # Recording-level transforms: both pipelines are currently empty (identity).
    signal_tf = transforms.Compose([])
    seg_tf = transforms.Compose([])

    train_dataset = Murmur2022Dataset(train, transforms=signal_tf, segTransforms=seg_tf)
    valid_dataset = Murmur2022Dataset(valid, transforms=signal_tf, segTransforms=seg_tf)
    valid_inference_dataset = Murmur2022ValidDataset(valid)

    # Chunk-level transforms applied by the dataloader to each signal chunk.
    chunk_tf = transforms.Compose([
        torchaudio.transforms.Spectrogram(n_fft=256, hop_length=128, normalized=False),
        transform_clip_Faxis(min=0, max=64),
        transform_normalize_log10(),
        transform_normalize_zscore(),
    ])
    chunk_seg_tf = transforms.Compose([])

    # Settings shared by the training and validation loaders.
    loader_kwargs = dict(
        num_workers=0,
        chunk_size=20000,
        overlap=10000,
        batch_size=64,
        transforms=chunk_tf,
        segTransforms=chunk_seg_tf,
    )
    train_dataloader = Murmur2022DataLoader(dataset=train_dataset, train=True, **loader_kwargs)
    valid_dataloader = Murmur2022DataLoader(dataset=valid_dataset, train=False, **loader_kwargs)

    trainer = Trainer(
        model=model_segment_gru(),
        model_folder=model_folder,
        chunk_transforms=chunk_tf,
        learning_rate=1e-4,
        num_epochs=num_epochs,
    )
    trainer.run(train_dataloader, valid_dataloader, valid_inference_dataset, ensemble_id=id)



# Load your trained model. This function is *required*. You should edit this function to add your code, but do *not* change the
# arguments of this function.
def load_challenge_model(model_folder, verbose):
    """Load the murmur/outcome ensemble from ``model_folder``.

    All ``model_murmur*`` checkpoints are ranked by the ``'score'`` stored in
    each one (higher first). The best ``ENSEMBLES_INFERENCE`` are kept. The
    murmur network in ensemble slot ``i`` is paired with the outcome network
    saved as ``model_outcome_<i>``.

    :param model_folder: directory holding the checkpoints
    :param verbose: verbosity level; >= 1 prints the ranked (path, score) list
    :return: dict mapping slot index (0..k-1) to a dict with keys
        'model_murmur', 'model_outcome' (both in eval mode on the chosen device)
        and 'threshold_outcome' (the checkpoint's 'threshold' entry)
    :raises FileNotFoundError: if no murmur checkpoint is found
    """
    ENSEMBLES_INFERENCE = 5

    # Prefer the GPU but fall back to CPU, so loading does not crash on CPU-only hosts.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    murmur_paths = glob.glob(os.path.join(model_folder, 'model_murmur*'))
    if not murmur_paths:
        raise FileNotFoundError(f'No murmur checkpoints (model_murmur*) found in {model_folder}')

    # Checkpoints are loaded onto the CPU. load_state_dict copies the weights into the
    # freshly built (CPU) networks, which are then moved to `device`. Any extra tensors,
    # such as the threshold, stay on the CPU.
    ranked = [(path, torch.load(path, map_location='cpu')['score']) for path in murmur_paths]
    ranked.sort(key=lambda item: item[1], reverse=True)
    if verbose >= 1:
        print(ranked)

    mdl = {}
    # Load the *ranked* checkpoint paths. The previous code loaded model_murmur_{i} by
    # index, which silently ignored the score ranking.
    for slot, (murmur_path, _score) in enumerate(ranked[:ENSEMBLES_INFERENCE]):
        murmur_net = model_segment_gru()
        murmur_ckpt = torch.load(murmur_path, map_location='cpu')
        murmur_net.load_state_dict(murmur_ckpt['model'])
        murmur_net.to(device)
        murmur_net.eval()

        outcome_net = NN()
        outcome_ckpt = torch.load(os.path.join(model_folder, f'model_outcome_{slot}'), map_location='cpu')
        outcome_net.load_state_dict(outcome_ckpt['model'])
        outcome_net.to(device)
        outcome_net.eval()

        mdl[slot] = {
            'model_murmur': murmur_net,
            'model_outcome': outcome_net,
            'threshold_outcome': outcome_ckpt['threshold'],
        }

    return mdl


def split(x,window,step,transform):
    """Cut a 1-D signal into fixed-length, possibly overlapping windows.

    Windows start at 0, step, 2*step, ... for every start < len(x). Any window
    that runs past the end of the signal is zero-padded to ``window`` samples.

    :param x: 1-D signal (numpy array or CPU tensor) supporting ``shape[0]`` and slicing
    :param window: window length in samples (positive int)
    :param step: distance between consecutive window starts in samples (positive int)
    :param transform: callable applied to each float64 window tensor, or None to skip
    :return: tensor stacking the (transformed) windows along a new dim 0
    :raises ValueError: if ``window`` or ``step`` is not positive (a zero step would
        loop forever), or if ``x`` is empty (nothing to stack)
    """
    if window <= 0 or step <= 0:
        raise ValueError(f"window and step must be positive, got window={window}, step={step}")
    n_samples = x.shape[0]
    if n_samples == 0:
        raise ValueError("cannot split an empty signal")

    out = []
    for start in range(0, n_samples, step):
        stop = min(start + window, n_samples)
        buf = np.zeros((window,))
        # Copy what is available; the remainder of a trailing window stays zero-padded.
        buf[:stop - start] = x[start:stop]
        segment = torch.from_numpy(buf)
        if transform is not None:
            segment = transform(segment)
        out.append(segment)
    return torch.stack(out, dim=0)




# Run your trained model. This function is *required*. You should edit this function to add your code, but do *not* change the
# arguments of this function.
def run_challenge_model(model, data, recordings, verbose):
    """
    Predict the murmur class and the clinical outcome for one patient.

    model := dict returned by load_challenge_model; each value is a dict with
             'model_murmur', 'model_outcome' and 'threshold_outcome'
    data := patient metadata, forwarded unchanged to get_features
    recordings := sequence of signals; forwarded to get_features, and each one
                  is cut into 20000-sample windows (10000-sample step) for the
                  murmur networks
    verbose := currently unused

    Outcome: each member's class-1 softmax probability is compared with its
    threshold_outcome / 100 (apparently stored as a percentage). If a strict
    majority of members exceed their threshold, 'Normal' is predicted,
    otherwise 'Abnormal'.
    Murmur: softmax outputs of every member on every window of every recording
    are averaged, and the argmax is the predicted murmur class.

    return:
    classes := ['Present', 'Unknown', 'Absent', 'Abnormal', 'Normal']
    labels := ndarray(5,) of int; one-hot over the three murmur entries plus
              one-hot over the two outcome entries
    probabilities := ndarray(5,) of float; averaged murmur probabilities
                     followed by averaged outcome probabilities

    Any exception is printed (with traceback) and the fixed fallback
    prediction Present / Abnormal is returned instead.
    """
    classes = ['Present', 'Unknown', 'Absent', 'Abnormal', 'Normal']
    chunk_size = 20000
    chunk_step = 10000
    try:
        # Iterate over however many members were actually loaded, instead of assuming 5.
        ensemble = list(model.values())
        # Run on the device load_challenge_model placed the networks on, instead of
        # assuming cuda:0 exists.
        device = next(ensemble[0]['model_outcome'].parameters()).device

        # Inference only: skip building autograd graphs.
        with torch.no_grad():
            features = get_features(data, recordings).reshape(1, -1)
            features = torch.from_numpy(features).to(device).float()

            outcome_probs = np.concatenate(
                [torch.softmax(member['model_outcome'](features), dim=1).cpu().numpy()
                 for member in ensemble],
                axis=0,
            )
            thresholds = [member['threshold_outcome'] / 100 for member in ensemble]
            votes = outcome_probs[:, 1] > thresholds
            results_outcome = 1 if votes.mean(axis=0) > 0.5 else 0
            probs = outcome_probs.mean(axis=0)

            chunk_transforms = transforms.Compose([
                torchaudio.transforms.Spectrogram(n_fft=256, hop_length=128, normalized=False),
                transform_clip_Faxis(min=0, max=64),
                transform_normalize_log10(),
                transform_normalize_zscore(),
            ])

            results_murmur = []
            for recording in recordings:
                x = split(recording, window=chunk_size, step=chunk_step, transform=chunk_transforms)
                x = x.to(device).float().unsqueeze(1)
                for member in ensemble:
                    _segmentation, murmur, _outcome = member['model_murmur'](x)
                    results_murmur.append(torch.softmax(murmur, dim=-1).cpu().numpy())

        results_murmur = np.concatenate(results_murmur, axis=0).mean(axis=0)
        results_murmur = results_murmur / results_murmur.sum()

        labels = np.zeros(len(classes), dtype=int)
        labels[np.argmax(results_murmur)] = 1
        labels[3 + results_outcome] = 1

        probabilities = np.zeros(len(classes), dtype=float)
        probabilities[:3] = results_murmur
        probabilities[3:] = probs

        return classes, labels, probabilities
    except Exception as exc:
        # Deliberate best-effort: one failing patient must not abort the whole run.
        # Report the error with its traceback and fall back to a fixed prediction.
        import traceback
        print("EXCEPTION")
        print(exc)
        traceback.print_exc()
        labels = np.array([1, 0, 0, 1, 0], dtype=int)
        probabilities = np.array([1, 0, 0, 1, 0], dtype=float)
        return classes, labels, probabilities

