import numpy as np
import torch
from tqdm import tqdm

from src.setup import equivalent_classes


class Dataset_ECG(torch.utils.data.Dataset):
    """Supervised ECG dataset with multi-hot labels over a chosen class list.

    Args:
        data_info_dict: mapping with the keys ``"headers"``, ``"sample_ids"``,
            ``"features"``, ``"recordings"``, ``"labels"``, ``"classes"`` and
            ``"mean_per_lead"``. Per-sample entries are indexed by sample
            position, and ``labels[i]`` is a multi-hot vector whose non-zero
            positions index into ``classes``.
        classes: target class codes, in the order the output label vectors
            should use, or ``None`` to keep ``data_info_dict["classes"]``.
        transform: callable applied to each recording in ``__getitem__``, or
            ``None`` to return the recording unchanged.

    When the target classes differ from the source classes (in membership or
    order), every sample's labels are remapped onto ``self.classes`` (after
    substituting equivalent diagnosis codes). Samples with no label in the
    target classes are dropped.
    """

    def __init__(self, data_info_dict, classes, transform):
        self.headers = []
        self.sample_ids = []
        self.features = []
        self.recordings = []
        self.labels = []
        self.mean_per_lead = []

        data_headers = data_info_dict["headers"]
        data_sample_ids = data_info_dict["sample_ids"]
        data_features = data_info_dict["features"]
        data_recordings = data_info_dict["recordings"]
        data_labels = data_info_dict["labels"]
        data_classes = data_info_dict["classes"]
        data_mean_per_lead = data_info_dict["mean_per_lead"]

        # Map each alternative diagnosis code (x[1]) onto its canonical code (x[0]).
        self.substitute_dx_dict = {x[1]: x[0] for x in equivalent_classes}
        self.classes = classes if classes is not None else data_classes
        self.num_classes = len(self.classes)
        self.transform = transform

        # Reuse the source data as-is only when the target classes match the
        # source classes exactly, including order. A mere set match would leave
        # the label columns ordered like data_classes rather than self.classes.
        # Compare against self.classes so that classes=None is handled.
        if list(self.classes) == list(data_classes):
            self.headers = data_headers
            self.sample_ids = data_sample_ids
            self.features = data_features
            self.recordings = data_recordings
            self.labels = data_labels  # stored as provided by the caller
            self.mean_per_lead = data_mean_per_lead
        else:
            # Built once; used to turn each multi-hot vector back into class codes.
            data_classes_arr = np.array(data_classes)
            for i in tqdm(range(len(data_recordings)), desc="preparing dataset"):
                sample_label = data_classes_arr[np.where(data_labels[i])].tolist()
                label, skip_this_sample = self._get_label(sample_label)

                # Drop samples that carry none of the target classes.
                if skip_this_sample:
                    continue

                self.headers.append(data_headers[i])
                self.sample_ids.append(data_sample_ids[i])
                self.features.append(data_features[i])
                self.recordings.append(data_recordings[i])
                self.labels.append(torch.Tensor(label))
                self.mean_per_lead.append(data_mean_per_lead[i])

        self.num_recordings = len(self.recordings)
        print(f"total data length: {self.num_recordings}")

    def __len__(self):
        return self.num_recordings

    def __getitem__(self, idx):
        """Return one sample as a dictionary.

        Returns:
            dict with keys:
                ``"id"``: the sample id;
                ``"recording"``: the recording passed through ``self.transform``
                (returned unchanged when ``transform`` is ``None``);
                ``"labels"``: the multi-hot encoded target;
                ``"features"``: the per-sample feature entry.
        """
        recording = self.recordings[idx]
        if self.transform is not None:
            recording = self.transform(recording)
        return {
            "id": self.sample_ids[idx],
            "recording": recording,
            "labels": self.labels[idx],
            "features": self.features[idx],
        }  # can return additional features here

    def _get_label(self, sample_label):
        """Build a multi-hot vector over ``self.classes`` for one sample.

        Args:
            sample_label: class codes present on the sample. Codes listed as
                equivalents are first replaced by their canonical code.

        Returns:
            ``(label, skip_this_sample)``. ``label`` is a float numpy array of
            length ``self.num_classes`` with 1 at each matched class.
            ``skip_this_sample`` is True when no code matched any target class.
        """
        label = np.zeros(self.num_classes)
        current_labels = {self.substitute_dx_dict.get(x, x) for x in sample_label}

        class_list = list(self.classes)
        matched = current_labels.intersection(class_list)
        for code in matched:
            label[class_list.index(code)] = 1

        return label, not matched

class Dataset_ECG_SimCLR(torch.utils.data.Dataset):
    """Unlabelled ECG dataset returning several views of each recording.

    Args:
        data_info_dict: mapping whose ``"recordings"`` entry is an indexable
            sequence of recordings; no other key is read.
        transform: callable applied independently to produce each view, or
            ``None`` to return the raw recording for every view.
        num_views: number of views returned per sample (default 2).

    Raises:
        ValueError: if ``num_views`` is less than 1.
    """

    def __init__(self, data_info_dict, transform, num_views=2):
        if num_views < 1:
            raise ValueError(f"num_views must be >= 1, got {num_views}")
        self.recordings = data_info_dict["recordings"]

        # Author's note: expected to be Compose(RandomCrop+RandAugmentation).
        self.transform = transform
        self.num_views = num_views
        self.num_recordings = len(self.recordings)

    def __len__(self):
        return self.num_recordings

    def __getitem__(self, idx):
        """Return ``num_views`` views of recording ``idx`` as a list.

        Each view is a separate call to ``self.transform``. No feature vector or
        label is returned.
        """
        img = self.recordings[idx]
        if self.transform is None:
            # Keep the view count consistent instead of returning an empty list,
            # which would yield empty contrastive batches.
            return [img] * self.num_views
        return [self.transform(img) for _ in range(self.num_views)]
