# Jan Pavlus
import scipy.io as sio
import glob
from typing import Dict, List, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import scipy.io.wavfile as wf
import torch as th
from torch.utils.data import Dataset
import torchvision


from helper_code import *
# Largest value representable by a signed 16-bit integer; the _load_wav
# methods below divide int16 PCM samples by it to scale them to floats.
MAX_INT16 = np.iinfo(np.int16).max


class Segmentation2016Dataset(Dataset):
    """Pairs recordings with their state-segmentation annotation files.

    Header files (``*.hea``) are globbed from ``<pth_data>/<subdir>/`` and
    annotation files (``*.mat``) from ``<pth_seg>/<subdir>/``. The two lists
    are joined on a record id taken from the file names, and the resulting
    table is shuffled once at construction time.
    """

    # Integer code written into the segmentation track for each state label.
    # Labels not listed here are encoded as 0.
    _STATE_CODES = (('diastole', 4), ('S1', 1), ('systole', 2), ('S2', 3))

    def __init__(self, pth_data, pth_seg,
                 subdirs=('training-a', 'training-b', 'training-c',
                          'training-d', 'training-e', 'training-f')):
        """Index the recordings and annotations found on disk.

        Args:
            pth_data: Root directory holding one sub-directory of ``.hea``
                headers (with sibling ``.wav`` files) per entry of `subdirs`.
            pth_seg: Root directory holding one sub-directory of ``.mat``
                annotation files per entry of `subdirs`.
            subdirs: Names of the sub-directories to scan. Any iterable of
                strings works; the default is an immutable tuple so it cannot
                be mutated between calls.
        """
        headers = [f for k in subdirs
                   for f in glob.glob(f"{pth_data}/{k}/*.hea")]
        annotations = [f for k in subdirs
                       for f in glob.glob(f"{pth_seg}/{k}/*.mat")]

        self.pth_data = pd.DataFrame(headers, columns=['header'])
        # Swap only the trailing extension. A plain str.replace('hea', 'wav')
        # would also rewrite any directory name that contains "hea"
        # (for example ".../heart_sounds/...").
        self.pth_data['wav'] = self.pth_data['header'].apply(
            lambda p: p[:-len('hea')] + 'wav')
        # Record id = header file name without its extension.
        self.pth_data['id'] = self.pth_data['header'].apply(
            lambda p: p.split('/')[-1].split('.')[0])

        self.pth_seg = pd.DataFrame(annotations, columns=['segmentation'])
        # Record id = annotation file name up to the first underscore.
        self.pth_seg['id'] = self.pth_seg['segmentation'].apply(
            lambda p: p.split('/')[-1].split('_')[0])

        # Inner join: only records that have both a header and an annotation.
        self.data = pd.merge(left=self.pth_data, right=self.pth_seg, on=['id'])
        # frac=1 returns every row in random order, i.e. a shuffle.
        self.data = self.data.sample(frac=1)

    def __getitem__(self, item):
        """Plot the start of one recording together with its state track.

        This is a visual debugging aid: it opens a matplotlib window and
        always returns 0 instead of a training sample.

        Args:
            item: Positional row index into the shuffled record table.

        Returns:
            int: Always 0.
        """
        row = self.data.iloc[item]
        wav, _ = load_wav_file(row['wav'])
        # NOTE(review): indexing below assumes each 'state_ans' row is
        # (start sample wrapped in nested arrays, state label) - confirm
        # against the annotation files.
        seg = sio.loadmat(row['segmentation'])['state_ans']

        # Per-sample state code, same shape and dtype as the recording.
        segarray = np.zeros_like(wav)
        # Each state lasts from its own start to the start of the next one;
        # samples after the final start marker therefore stay 0.
        for i in range(seg.shape[0] - 1):
            start = seg[i][0][0][0]
            stop = seg[i + 1][0][0][0]
            val = 0
            for label, code in self._STATE_CODES:
                if seg[i][1] == label:
                    val = code
            segarray[start:stop] = val

        plt.plot(scipy.stats.zscore(wav)[:10000])
        plt.plot(segarray[:10000])
        plt.show()
        return 0

    def __len__(self):
        """Return the number of records that have both files."""
        return len(self.data)


class Murmur2022Dataset(Dataset):
    """Recording-level dataset: one item per entry of the supplied list."""

    def __init__(self,
                 data: List[Dict],
                 transforms: Union["torchvision.transforms.transforms.Compose", None] = None,
                 segTransforms: Union["torchvision.transforms.transforms.Compose", None] = None):
        """Dataset class

        Args:
            data (List[Dict]): Loaded data. Each entry is read through the
                              keys "wav", "tsv", "classnum" and "outcomenum".
            transforms (torchvision.transforms.transforms.Compose, optional): Pytorch transforms that will be applied
                                                                              on signal. Defaults to None.
            segTransforms (torchvision.transforms.transforms.Compose, optional): Pytorch transforms that will be applied
                                                                                 on segmentation. Defaults to None.
        """
        self._data = data
        self._transforms = transforms
        self._segTransforms = segTransforms

    def _load_wav(self, path: str, return_rate: bool = False) -> Union[np.ndarray, Tuple[int, np.ndarray]]:
        """Read wave files using scipy.io.wavfile(support multi-channel)

        Args:
            path (str): Path of the WAV file to read.
            return_rate (bool, optional): Also return the sampling rate. Defaults to False.

        Returns:
            Union[np.ndarray, Tuple[int, np.ndarray]]: The samples as float64, shaped N for mono
                or C x N for multi-channel files; preceded by the sampling rate when
                ``return_rate`` is True.
        """
        # samps_raw: N x C or N
        #   N: number of samples
        #   C: number of channels
        samp_rate, samps_raw = wf.read(path)
        # Only int16 PCM is rescaled; other sample formats are passed through
        # as stored. Computing the flag unconditionally keeps it defined for
        # every dtype (it used to be set only inside the int16 branch).
        normalize = samps_raw.dtype == np.dtype('int16')
        # np.float was an alias of the builtin float (float64) and has been
        # removed from NumPy, so the concrete dtype is named explicitly.
        samps = samps_raw.astype(np.float64)
        # N x C => C x N, channel axis first
        if samps.ndim != 1:
            samps = np.transpose(samps)
        # normalize like MATLAB and librosa
        if normalize:
            samps = samps / MAX_INT16
        if return_rate:
            return samp_rate, samps
        return samps

    def __getitem__(self, id: int) -> Tuple:
        """Load one recording and its labels.

        Args:
            id (int): Index into the data list.

        Returns:
            Tuple: ``(wav, classnum, tsv, outcomenum)`` where ``wav`` is the loaded signal
                as a tensor (after ``transforms``, if set) and ``tsv`` is the entry's
                segmentation (after ``segTransforms``, if set).
        """
        record = self._data[id]
        wav = th.tensor(self._load_wav(record["wav"]))
        if self._transforms:
            wav = self._transforms(wav)
        tsv = record["tsv"]
        if self._segTransforms:
            tsv = self._segTransforms(tsv)
        return wav, record["classnum"], tsv, record["outcomenum"]

    def __len__(self):
        """Return the number of recordings."""
        return len(self._data)


class Murmur2022ValidDataset(Dataset):
    """Patient-level dataset: one item bundles all recordings of a patient.

    The patient id is the part of each WAV file name before the first
    underscore; recordings sharing that id form one item.
    """

    def __init__(self, files):
        """Group the recordings by patient.

        Args:
            files: Records convertible to a ``pandas.DataFrame`` with at
                least the columns "wav", "tsv", "classnum" and "outcomenum".
                The "tsv" column is dropped; it is not used by this dataset.
        """
        frame = pd.DataFrame(files)
        frame = frame.drop(columns=['tsv'])
        frame['patient'] = frame['wav'].apply(
            lambda p: p.split('/')[-1].split("_")[0])
        # One DataFrame per patient; the group keys themselves are not kept.
        self.data = [group for _, group in frame.groupby(by=['patient'])]

    def __len__(self):
        """Return the number of patients."""
        return len(self.data)

    def _load_wav(self, path: str, return_rate: bool = False) -> Union[np.ndarray, Tuple[int, np.ndarray]]:
        """Read wave files using scipy.io.wavfile(support multi-channel)

        Args:
            path (str): Path of the WAV file to read.
            return_rate (bool, optional): Also return the sampling rate. Defaults to False.

        Returns:
            Union[np.ndarray, Tuple[int, np.ndarray]]: The samples as float64, shaped N for mono
                or C x N for multi-channel files; preceded by the sampling rate when
                ``return_rate`` is True.
        """
        # samps_raw: N x C or N
        #   N: number of samples
        #   C: number of channels
        samp_rate, samps_raw = wf.read(path)
        # Only int16 PCM is rescaled; other sample formats are passed through
        # as stored. Computing the flag unconditionally keeps it defined for
        # every dtype (it used to be set only inside the int16 branch).
        normalize = samps_raw.dtype == np.dtype('int16')
        # np.float was an alias of the builtin float (float64) and has been
        # removed from NumPy, so the concrete dtype is named explicitly.
        samps = samps_raw.astype(np.float64)
        # N x C => C x N, channel axis first
        if samps.ndim != 1:
            samps = np.transpose(samps)
        # normalize like MATLAB and librosa
        if normalize:
            samps = samps / MAX_INT16
        if return_rate:
            return samp_rate, samps
        return samps

    def __getitem__(self, item):
        """Load every recording of one patient.

        Args:
            item: Index of the patient group.

        Returns:
            Tuple: ``(wav, target, outcome_target)`` where ``wav`` is a list
                with one tensor per recording, ``target`` is the smallest
                "classnum" among the patient's recordings and
                ``outcome_target`` the smallest "outcomenum".
        """
        group = self.data[item]
        # Recordings of one patient may carry different labels; the minimum
        # is used as the single patient-level label.
        target = group['classnum'].to_numpy().min()
        outcome_target = group['outcomenum'].to_numpy().min()

        wav = [th.tensor(self._load_wav(file))
               for file in group['wav'].tolist()]

        return wav, target, outcome_target
