
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thur June 17 17:35:05 2021

@author: chadyang
"""
import random
import numpy as np
import torch
from torchvision import transforms
from torchvision.transforms import functional as TF
from heartpy.filtering import filter_signal, remove_baseline_wander
# import librosa

class AddGaussianNoise(object):
    """Add Gaussian noise to a torch tensor.

    Each call draws fresh standard-normal noise with the same shape as the
    input and returns ``tensor + noise * std + mean``.

    Args:
        mean: constant offset added to every element.
        std: scale applied to the standard-normal noise.
    """

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        # Create the noise on the input's device: a CPU noise tensor added to
        # a CUDA tensor raises a device-mismatch error. On CPU this draws
        # exactly the same values as torch.randn(tensor.size()).
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

class Zscore(object):
    """Standardize an array to zero mean and unit variance along axis 0.

    Uses the population standard deviation (numpy's default ``ddof=0``).
    Slices with zero standard deviation (constant values) produce 0/0; those
    entries, and any other non-finite results, are replaced with 0.

    NOTE(review): ``MinMaxScaler`` in this module normalizes along the last
    axis, while this normalizes along axis 0 — confirm which layout the
    caller's arrays use.
    """

    def __call__(self, array):
        # Presumably a NumPy array; the result is a new array, the input is
        # not modified. A constant slice makes std == 0, which is expected
        # here, so silence numpy's divide/invalid RuntimeWarnings and clean
        # up the resulting NaN/inf values below.
        with np.errstate(divide='ignore', invalid='ignore'):
            array = (array - array.mean(axis=0)) / array.std(axis=0)
        array[~np.isfinite(array)] = 0
        return array

    def __repr__(self):
        return self.__class__.__name__

class RemoveBaselineWander(object):
    """Transform wrapping heartpy's ``remove_baseline_wander``.

    Args:
        sr: sampling rate, forwarded to heartpy as ``sample_rate``.
        cutoff: forwarded to heartpy as ``cutoff``.
    """

    def __init__(self, sr=500, cutoff=0.05):
        self.sr, self.cutoff = sr, cutoff

    def __call__(self, data:list):
        filtered = remove_baseline_wander(data, sample_rate=self.sr, cutoff=self.cutoff)
        # Hand back a fresh copy of heartpy's output — presumably so later
        # conversion to a tensor gets a plain contiguous array; TODO confirm.
        return filtered.copy()

    def __repr__(self):
        return type(self).__name__

class BandPass(object):
    """Transform wrapping heartpy's ``filter_signal``.

    Args:
        sr: sampling rate, forwarded to heartpy as ``sample_rate``.
        cutoff: cutoff frequency (or frequencies). For
            ``filtertype="bandpass"`` heartpy reads ``cutoff[0]`` and
            ``cutoff[1]`` as the lower and upper bound, so a pair such as
            ``[lower, upper]`` is required. A scalar is only meaningful for
            single-cutoff filter types.
        filtertype: forwarded to heartpy as ``filtertype``.

    NOTE(review): the default ``cutoff=0.05`` is a scalar while the default
    filtertype is ``"bandpass"``, so ``BandPass()`` cannot be applied as-is;
    pass an explicit ``[lower, upper]`` pair.
    """

    def __init__(self, sr=500, cutoff=0.05, filtertype="bandpass"):
        self.sr = sr
        self.cutoff = cutoff
        self.filtertype = filtertype

    def __call__(self, data:list):
        # Fail with a clear message instead of heartpy's opaque TypeError /
        # IndexError when a bandpass filter gets fewer than two bounds.
        is_bandpass = isinstance(self.filtertype, str) and self.filtertype.lower() == "bandpass"
        if is_bandpass and np.size(self.cutoff) < 2:
            raise ValueError(
                "BandPass with filtertype='bandpass' needs cutoff=[lower, upper], "
                "got cutoff={!r}".format(self.cutoff)
            )
        filtered = filter_signal(data, sample_rate=self.sr, cutoff=self.cutoff, filtertype=self.filtertype)
        # Return a fresh copy of heartpy's output (presumably to hand later
        # tensor conversion a plain contiguous array; TODO confirm).
        return filtered.copy()

    def __repr__(self):
        return self.__class__.__name__

class MinMaxScaler(object):
    """Rescale values to ``[0, 1]`` independently along the last axis.

    For a 2-D array each row is scaled by its own min and max. Rows whose
    max equals their min (constant values) divide by zero. Those entries,
    and any other non-finite results, are replaced with 0.
    """

    def __call__(self, array):
        # Reduce over axis 0 of the transpose, i.e. the original last axis,
        # so the per-row min/span broadcast directly; transpose back after.
        flipped = array.T
        lo = flipped.min(axis=0)
        span = flipped.max(axis=0) - lo
        # A constant row makes span == 0, which is expected here, so silence
        # numpy's divide/invalid RuntimeWarnings and zero those entries below.
        with np.errstate(divide='ignore', invalid='ignore'):
            array = ((flipped - lo) / span).T
        array[~np.isfinite(array)] = 0
        return array

    def __repr__(self):
        return self.__class__.__name__

class RandomLeadMask(object):
    """Randomly zero whole leads along dimension 1 of a tensor.

    One uniform ``[0, 1)`` value is drawn per index of dim 1 from NumPy's
    global RNG. An index is kept when its draw exceeds ``p`` and zeroed
    otherwise, so for ``0 <= p <= 1`` each lead is masked with probability
    ``p``. The tensor needs at least 3 dims (presumably (batch, lead, time);
    confirm against callers).
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, tensor):
        n_leads = tensor.shape[1]
        keep = np.random.uniform(low=0, high=1, size=n_leads) > self.p
        masked = torch.zeros_like(tensor)
        masked[:, keep, :] = tensor[:, keep, :]
        return masked

    def __repr__(self):
        return type(self).__name__

class RandomShuflleLead(object):
    """With probability ``p``, randomly permute the leads along dimension 1.

    The shuffle-or-not decision uses NumPy's global RNG, while the
    permutation itself uses Python's ``random`` module, so both need seeding
    for reproducible output. When no shuffle happens the input object is
    returned unchanged.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, tensor):
        coin = np.random.uniform(low=0, high=1, size=1)[0]
        if not (coin < self.p):
            return tensor
        order = list(range(tensor.shape[1]))
        random.shuffle(order)
        return tensor[:, order, :]

    def __repr__(self):
        return type(self).__name__


# class GetSpecTralGram(object):
#     def __init__(self, nfft, hop_len):
#         self.nfft = nfft
#         self.hop_len = hop_len
#     def __call__(self, array):
#         # array: lead*len
#         array_stft = []
#         for lead in array:
#             specgm = sa.stft(lead, n_fft=self.nfft, hop_length=self.hop_len)
#             specgm = np.abs(specgm) 
#             array_stft.append(specgm)
#         array_stft = np.asarray(array_stft) # Lead*D*T
#         # array_stft = np.transpose(array_stft, (2,1,0)) # T*D*Lead
#         return array_stft.astype("float")
#     def __repr__(self):
#         return self.__class__.__name__

class AsTensor(object):
    """Convert the input to a new ``torch.Tensor`` with ``torch.tensor``.

    Per the PyTorch docs, ``torch.tensor`` copies its data, so the result
    does not share memory with the source array.
    """

    def __call__(self, array):
        converted = torch.tensor(array)
        return converted

    def __repr__(self):
        return type(self).__name__

