#!/usr/bin/env python

import numpy as np
import os
import modelDefinition
import torch
import dataReading
from scipy import signal
import dataPreprocessing
import math


def _class_name_list(int_to_code, filler='------'):
    """Return a list giving the class code for every integer class index.

    When a mapping entry is a list (several codes combined into one class), the
    first code of that list is used as the class name.

    Raises:
        ValueError: if some index in ``range(len(int_to_code))`` got no name.
    """
    cls_lst = [filler] * len(int_to_code)
    for key, code in int_to_code.items():
        cls_lst[key] = code[0] if isinstance(code, list) else code

    if filler in cls_lst:
        raise ValueError('Something is wrong.')
    return cls_lst


def run_12ECG_classifier(data, header_data, loaded_model, threshold=0.071):
    """Classify a single ECG recording.

    Args:
        data: raw signal samples. They are divided by the header gain, resampled
            and band-pass filtered before inference. Presumably shaped
            (leads, samples), since the windowing below slices along axis 1.
            TODO confirm.
        header_data: recording header. It is parsed by
            ``dataReading.header2info`` for age, sex, sampling rate and gain.
        loaded_model: the network returned by ``load_12ECG_model``. Its
            ``temp_device``, ``temp_activation`` and ``sos`` attributes are used.
        threshold: a class is predicted when its averaged score is strictly
            greater than this value. The default 0.071 is the value that was
            previously hard-coded.

    Returns:
        ``(labels, scores, class_names)``:
        - ``labels`` is a boolean array of predicted classes.
        - ``scores`` holds the per-class outputs after
          ``loaded_model.temp_activation``, averaged over all windows.
        - ``class_names`` gives the class code for every output index.

    Raises:
        ValueError: if the class mapping leaves an output index unnamed, or if
            the header gain is not below 5000.
    """
    with torch.no_grad():
        # Load the mapping; equal classes are combined into a single output.
        code_to_int, int_to_code, text = dataReading.load_classes_mapping('classes.txt', combine_equal_classes=True)
        cls_lst = _class_name_list(int_to_code)

        # Get the header information needed for preprocessing and the model input.
        _, _, _, age, sex, _, spl_rate, gain = dataReading.header2info(header_data, code_to_int, int_to_code, text)

        # Explicit check instead of `assert` so it still runs under `python -O`.
        # Written as `not gain < 5000` so a NaN gain is rejected, as before.
        if not gain < 5000:
            raise ValueError(f'Gain is {gain} (too big).')
        data = data / gain

        # Resample (presumably from spl_rate to 500 Hz). The band-pass filter
        # in loaded_model.sos is designed for fs=500.
        unfiltered_sig = dataPreprocessing.resample_signal(data, spl_rate, 500)
        filtered_sig = dataPreprocessing.filter_sig(unfiltered_sig, loaded_model.sos)

        # Build the model inputs on the model's device.
        device = loaded_model.temp_device
        inp = torch.from_numpy(filtered_sig).float().to(device)
        spl = torch.tensor([[age, sex]]).to(device).float()

        # Signals longer than 20 s (10000 samples at 500 Hz) are cut into
        # consecutive, non-overlapping 5000-sample windows. Each window gives
        # one opinion, and any shorter remainder is dropped. Shorter signals are
        # fed whole as a batch of one.
        window = 5000
        if inp.shape[1] > 2 * window:
            repetitions = inp.shape[1] // window
            inp = torch.stack([inp[:, k * window:(k + 1) * window] for k in range(repetitions)], dim=0)
            spl = spl.repeat((repetitions, 1))
        else:
            inp = inp.unsqueeze(0)

        # Feed through the network and average the per-window scores.
        scores = loaded_model(inp, spl)
        scores = loaded_model.temp_activation(scores)
        scores = scores.detach().cpu().numpy()
        current_score = np.mean(scores, axis=0, keepdims=True)
        current_label = current_score > threshold

        return current_label[0, :], current_score[0, :], cls_lst


def load_12ECG_model(input_directory):
    """Build the network and load its trained weights from ``input_directory``.

    The weights are read from the file ``best_model`` inside ``input_directory``.
    Several inference helpers are attached to the returned module as attributes
    and read by ``run_12ECG_classifier``:

    * ``temp_device``: the ``torch.device`` the model lives on. This is
      ``cuda:0`` when CUDA is available, otherwise ``cpu``.
    * ``temp_activation``: a ``torch.nn.Sigmoid`` applied to the raw outputs.
    * ``sos``: second-order sections of a Butterworth band-pass filter. It is
      built with ``signal.butter(2, [0.5, 45], btype='band', fs=500)``.

    Returns:
        The ``modelDefinition.GenericResNet`` instance, in eval mode.
    """
    # Load the model from disk.
    filename = os.path.join(input_directory, 'best_model')

    # Get the cuda device, falling back to the CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(device)

    netw = modelDefinition.GenericResNet(features=16, classes=24, kernel_size=5, padding=2,
                                         down_kernel_size=5, down_padding=2)
    netw.to(device)
    # Without map_location, torch.load restores tensors onto the device they
    # were saved from. That fails on a CPU-only machine for a checkpoint
    # written during a GPU run.
    netw.load_state_dict(torch.load(filename, map_location=device))
    netw.eval()

    # Save the device so inputs can be moved to it at inference time.
    netw.temp_device = device

    # Initialize the output activation.
    netw.temp_activation = torch.nn.Sigmoid()

    # Make the band-pass filter for the (500 Hz) signals.
    netw.sos = signal.butter(2, [0.5, 45], btype='band', analog=False, output='sos', fs=500)

    return netw
