import torch
from torch.utils.data import Dataset
#from matplotlib import pyplot as plt

class PCG_Dataset(Dataset):
    """Map-style dataset pairing feature samples with their labels.

    Each item is returned as an ``(x, y)`` pair of ``torch.float32`` tensors.
    """

    def __init__(self, features, labels):
        """Store the feature and label collections.

        Args:
            features: Indexable, sized collection of samples (presumably a
                NumPy array whose first axis indexes samples -- confirm
                against callers).
            labels: Indexable, sized collection of targets aligned
                element-for-element with ``features``.

        Raises:
            ValueError: If ``features`` and ``labels`` differ in length.
        """
        n_features = len(features)
        n_labels = len(labels)
        # Fail fast: a mismatch would otherwise show up later as an
        # IndexError, or leave trailing labels silently unused.
        if n_features != n_labels:
            raise ValueError(
                "features and labels must have the same length, "
                f"got {n_features} and {n_labels}"
            )
        self.x_data = features
        self.y_data = labels
        self.len = n_features

    def __getitem__(self, index):
        """Return sample ``index`` as a ``(feature, label)`` pair of float32 tensors."""
        # torch.as_tensor, unlike torch.from_numpy, also accepts NumPy scalars.
        # Indexing a 1-D label array yields a NumPy scalar. It also accepts
        # Python numbers and lists. Float32 NumPy arrays are still wrapped
        # without a copy, as from_numpy(...).float() did.
        x = torch.as_tensor(self.x_data[index], dtype=torch.float32)
        y = torch.as_tensor(self.y_data[index], dtype=torch.float32)
        return x, y

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.len
