# Jan Pavlus
# https://github.com/funcwj/conv-tasnet/blob/master/nnet/libs/dataset.py

from ast import Tuple
import random
from typing import Any, Callable, List, Optional

import torch as th
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
import torchvision

from .dataset import Murmur2022Dataset


class ChunkSplitter(object):
    """
    Split an utterance into fixed-size chunks along its first dimension.

    An utterance is a 4-element sequence ``(signal, label, segmentation, extra)``.
    ``signal`` and ``segmentation`` are tensors that are sliced/padded along
    dim 0; ``label`` and ``extra`` are copied into every chunk unchanged.
    """

    def __init__(self,
                 chunk_size: int,
                 train: bool = True,
                 least: int = 16000,
                 transforms: Optional[Callable] = None,
                 segTransforms: Optional[Callable] = None):
        """Chunk splitter class.

        Args:
            chunk_size (int): Size of the chunk (number of samples along dim 0).
            train (bool, optional): Random chunk start if the train is set to True. Defaults to True.
            least (int, optional): Hop between the starts of two consecutive chunks, i.e.
                                   least = chunk_size - overlap. Utterances shorter than ``least``
                                   are discarded. Defaults to 16000.
            transforms (Callable, optional): Callable (e.g. a torchvision ``Compose``) applied to
                                             every chunked signal. Defaults to None.
            segTransforms (Callable, optional): Callable (e.g. a torchvision ``Compose``) applied to
                                                every chunked segmentation. Defaults to None.

        Raises:
            ValueError: If ``least`` is not positive (the split loop would never advance).
        """
        if least <= 0:
            raise ValueError(f"least must be a positive integer, got {least}")

        self.chunk_size = chunk_size
        self.least = least
        self.train = train
        self.transforms = transforms
        self.segTransforms = segTransforms

    @staticmethod
    def _pad_tail(tensor: th.Tensor, amount: int) -> th.Tensor:
        """Zero-pad ``tensor`` with ``amount`` entries at the end of dim 0."""
        # F.pad takes (left, right) pairs starting from the LAST dimension, so
        # the pair for dim 0 goes last. For 1-D input this is simply (0, amount).
        pad_spec = (0, 0) * (tensor.dim() - 1) + (0, amount)
        return th.nn.functional.pad(tensor, pad_spec, "constant")

    def _finalize(self, chunk_data: th.Tensor, seg_data: th.Tensor, data: tuple) -> tuple:
        """Apply the optional transforms and assemble the chunk tuple."""
        if self.transforms:
            chunk_data = self.transforms(chunk_data)
        if self.segTransforms:
            seg_data = self.segTransforms(seg_data)
        return (chunk_data, data[1], seg_data, data[3])

    def _make_chunk(self, data: tuple, s: int) -> tuple:
        """Make a chunk tuple instance, which contains: signal, class number,
        segmentation data and the utterance's fourth element (passed through).

        Args:
            data (tuple): Data from which chunk will be created.
            s (int): Start of the created chunk.

        Returns:
            tuple: Chunk instance.
        """
        end = s + self.chunk_size
        return self._finalize(data[0][s:end], data[2][s:end], data)

    def split(self, data: tuple) -> List[tuple]:
        """Split data to chunks.

        * Shorter than ``least``: dropped, an empty list is returned.
        * Shorter than ``chunk_size``: zero-padded at the end to one full chunk.
        * Otherwise: chunks of ``chunk_size`` taken every ``least`` samples; a
          trailing remainder that does not fill a whole chunk is discarded.

        Args:
            data (tuple): Data to be splitted to chunks.

        Returns:
            List[tuple]: List of created chunk instances.
        """
        N = data[0].size(dim=0)
        # too short, throw away
        if N < self.least:
            return []

        # padding zeros; the padded chunk goes through the same transforms
        # (signal AND segmentation) as regular chunks
        if N < self.chunk_size:
            P = self.chunk_size - N
            return [self._finalize(self._pad_tail(data[0], P),
                                   self._pad_tail(data[2], P),
                                   data)]

        if self.train:
            # random select start point for training; capped at N - chunk_size
            # so that at least one full chunk always fits
            s = random.randint(0, min(N % self.least, N - self.chunk_size))
        else:
            s = 0

        chunks = []
        while s + self.chunk_size <= N:
            chunks.append(self._make_chunk(data, s))
            s += self.least
        return chunks


class Murmur2022DataLoader(object):
    """Iterable that yields fixed-size mini-batches of chunks.

    Utterances are drawn from a wrapped ``torch.utils.data.DataLoader``, split
    into chunks by a ``ChunkSplitter`` inside the collate function, and the
    resulting chunks are regrouped into mini-batches of ``batch_size``.
    """

    def __init__(self, dataset: Murmur2022Dataset,
                 num_workers: int,
                 chunk_size: int,
                 overlap: int,
                 batch_size: int,
                 transforms: Optional[Callable] = None,
                 segTransforms: Optional[Callable] = None,
                 train: bool = True):
        """Murmur 2022 dataset dataloader supports signal, segmentation and class label data.

        Args:
            dataset (Murmur2022Dataset): Dataset class.
            num_workers (int): Number of worker processes used to load from dataset.
            chunk_size (int): Size of chunk (window).
            overlap (int): Size of overlap of the new chunk over the previous one.
            batch_size (int): Number of chunks in every yielded mini-batch.
            transforms (Callable, optional): Callable (e.g. a torchvision ``Compose``) applied to
                                             every chunked signal. Defaults to None.
            segTransforms (Callable, optional): Callable (e.g. a torchvision ``Compose``) applied to
                                                every chunked segmentation. Defaults to None.
            train (bool, optional): If true, then shuffle and random chunking start is turned on.
                                    Defaults to True.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        assert chunk_size > overlap, "chunk_size must be greater than overlap"
        if batch_size < 1:
            raise ValueError(
                f"batch_size should be a positive integer value, but got batch_size={batch_size}")
        self.batch_size = batch_size
        self.train = train
        self.splitter = ChunkSplitter(chunk_size,
                                      train=train,
                                      least=chunk_size - overlap,
                                      transforms=transforms,
                                      segTransforms=segTransforms)
        # The wrapped loader only returns lists of chunks (see _collate), which
        # keeps multi-worker loading possible. It reads batch_size // 2
        # utterances at a time -- presumably because one utterance usually
        # yields several chunks (TODO confirm). Clamped to 1 so that
        # batch_size == 1 does not ask DataLoader for an invalid batch size of 0.
        self.data_loader = DataLoader(dataset,
                                      batch_size=max(1, batch_size // 2),
                                      shuffle=train,
                                      num_workers=num_workers,
                                      collate_fn=self._collate)

    def _collate(self, batch: Any) -> List:
        """Online split utterances

        Args:
            batch (Any): Iterable of utterances as returned by the dataset.

        Returns:
            List: Flat list with the chunks of every utterance in ``batch``.
        """
        chunk = []
        for data in batch:
            chunk += self.splitter.split(data)
        return chunk

    def _merge(self, chunk_list: List) -> tuple:
        """Merge chunk list into mini-batch

        When training, ``chunk_list`` is shuffled in place first.

        Args:
            chunk_list (List): Chunks waiting to be batched.

        Returns:
            tuple: ``(batches, remainder)`` where ``batches`` is a list of
            ``default_collate``-d mini-batches of exactly ``batch_size`` chunks
            and ``remainder`` is the list of trailing chunks that did not fill
            a whole mini-batch.
        """
        N = len(chunk_list)
        if self.train:
            random.shuffle(chunk_list)
        blist = []
        for s in range(0, N - self.batch_size + 1, self.batch_size):
            batch = default_collate(chunk_list[s:s + self.batch_size])
            blist.append(batch)
        rn = N % self.batch_size
        # chunk_list[-0:] would be the whole list, hence the explicit guard
        return blist, chunk_list[-rn:] if rn else []

    def __iter__(self):
        """Iterate through dataloader.

        Chunks that do not fill a mini-batch are carried over to the next
        round; whatever is still left when the wrapped loader is exhausted is
        not yielded.

        Yields:
            Any: One ``default_collate``-d mini-batch of ``batch_size`` chunk
            tuples (signal, class number, segmentation data and the
            utterance's fourth element).
        """
        chunk_list = []
        for chunks in self.data_loader:
            chunk_list += chunks
            batch, chunk_list = self._merge(chunk_list)
            for obj in batch:
                yield obj
