"""Sequence transformations for augmentation.
- Reference:
    https://github.com/uchidalab/time_series_augmentation
"""

import numpy as np
import torch
from scipy.signal import butter, sosfilt, sosfiltfilt, sosfreqz


class RandomCrop(object):
    """
    Brings a record to a fixed number of samples along axis 1.

    A record that already has ``window_size`` samples is passed through,
    a shorter record is zero-padded on the right, and a longer record is
    cut down to a randomly positioned window.
    """

    def __init__(self, window_size):
        """
        Args:
            window_size (int): number of samples in the returned record
        """
        self.window_size = window_size

    def __call__(self, sample):
        """Return the cropped (or padded) record as a float32 tensor."""
        width = self.window_size
        length = sample.shape[1]

        # Nothing to do: the record already has the requested length
        if length == width:
            return sample.float()

        # Too short: copy the record into the left part of a zero buffer
        if length < width:
            padded = torch.zeros(sample.shape[0], width)
            padded[:, :length] = sample
            return padded.float()

        # Too long: draw the (exclusive) end index of the window, so that
        # the window always lies completely inside the record
        stop = np.random.randint(width, length + 1)
        window = sample[:, (stop - width):stop]
        return window.float()


class AddChannelDimension(object):
    """
    Prepends a channel axis of size one, so that the record can be
    treated like a single-channel image
    """
    def __init__(self):
        pass

    def __call__(self, sample):
        """Return ``sample`` with a new leading axis, cast to float32."""
        with_channel = torch.unsqueeze(sample, 0)
        return with_channel.float()


class Jittering:
    """
    Adds independent zero-mean Gaussian noise to every sample of the sequence
    """
    def __init__(self, sigma=0.03):
        """
        Args:
            sigma (float): the magnitude (standard deviation) of the jitter,
                small sigma implies small jittering
        """
        self.sigma = sigma

    def __call__(self, seq):
        """
        Args:
            seq (torch.Tensor): sequence to perturb

        Returns:
            torch.Tensor: float32 tensor with the same shape as ``seq``
        """
        noise = np.random.normal(loc=0, scale=self.sigma, size=tuple(seq.shape))
        # Convert the noise to a tensor before adding it. Adding a raw NumPy
        # array to a tensor only works through NumPy's implicit array
        # protocol, which fails for tensors that require grad or that do not
        # live on the CPU. The sum is still evaluated in float64.
        noise = torch.as_tensor(noise, dtype=torch.float64, device=seq.device)
        return (seq + noise).float()


class Scaling:
    """
    Scales every lead around its own mean by a random factor
    """
    def __init__(self, sigma=0.1):
        """
        Args:
            sigma (float): the magnitude of scaling (standard deviation of the
                factor around 1), small sigma implies small scaling
        """
        self.sigma = sigma

    def __call__(self, seq):
        """
        Args:
            seq (torch.Tensor): sequence of shape (num leads, len_seq)

        Returns:
            torch.Tensor: float32 tensor of the same shape; one factor is
            drawn per lead, so different leads get different scaling while
            the mean of every lead is preserved
        """
        factor = np.random.normal(loc=1., scale=self.sigma, size=(seq.shape[0], 1))
        # Work with a tensor factor instead of np.multiply(tensor, ndarray):
        # the mixed call depends on NumPy's implicit array protocol and
        # fails for tensors that require grad or are not on the CPU.
        factor = torch.as_tensor(factor, dtype=torch.float64, device=seq.device)
        seq_mean = seq.mean(dim=1).reshape(-1, 1)
        # Scale only the deviation from the mean, then restore the mean
        scaled_seq = (seq - seq_mean) * factor + seq_mean
        return scaled_seq.float()


class HorzontalFlip:
    """
    Reverses the sequence along axis 1 (horizontal flip)
    """
    def __init__(self):
        pass

    def __call__(self, seq):
        """Return the reversed sequence as a float32 tensor."""
        reversed_seq = seq.flip(dims=(1,))
        return reversed_seq.float()

class Cutout:
    """
    Zeroes out one randomly placed time interval in every lead
    """
    def __init__(self, ratio=0.05):
        """
        Args:
            ratio (float): the ratio of the cutout length to the whole signal
        """
        self.ratio = ratio

    def __call__(self, seq):
        """
        Args:
            seq (torch.Tensor): sequence of shape (num leads, len_seq)

        Returns:
            torch.Tensor: float32 copy of ``seq`` in which, independently for
            every lead, ``int(ratio * len_seq)`` consecutive samples are zero
        """
        seq_len = seq.shape[1]
        new_seq = torch.clone(seq)
        # Keep the interval length inside [0, seq_len] so that any ratio
        # (including ratio >= 1) and any record length are handled
        cutout_len = min(max(int(self.ratio * seq_len), 0), seq_len)
        for i in range(seq.shape[0]):
            # randint excludes its upper bound, so "+ 1" makes the last
            # admissible start (seq_len - cutout_len) reachable; otherwise
            # the tail of the record could never be cut out
            start = np.random.randint(0, seq_len - cutout_len + 1)
            new_seq[i, start:start + cutout_len] = 0
        return new_seq.float()

class Baseline_Shift:
    """
    Adds an abrupt baseline shift to the sample
    """
    def __init__(self, ratio=0.1):
        """
        Args:
            ratio (float): the ratio of the shifted time interval to the
                whole time interval
        """
        self.ratio = ratio

    def __call__(self, seq):
        """
        Args:
            seq (torch.Tensor): sequence of shape (num leads, len_seq)

        Returns:
            torch.Tensor: float32 copy of ``seq`` in which, independently for
            every lead, ``int(ratio * len_seq)`` consecutive samples are
            offset by a constant drawn uniformly from [-0.25, 0.25)
        """
        seq_len = seq.shape[1]
        new_seq = torch.clone(seq)
        if not new_seq.is_floating_point():
            # An in-place float offset cannot be stored in an integer tensor
            new_seq = new_seq.float()
        # Keep the interval length inside [0, seq_len] so that any ratio
        # (including ratio >= 1) and any record length are handled
        shift_len = min(max(int(self.ratio * seq_len), 0), seq_len)
        for i in range(seq.shape[0]):
            # randint excludes its upper bound, so "+ 1" makes the last
            # admissible start (seq_len - shift_len) reachable; otherwise
            # the tail of the record could never be shifted
            start = np.random.randint(0, seq_len - shift_len + 1)
            amplitude = np.random.uniform(-0.25, 0.25)
            new_seq[i, start:start + shift_len] += amplitude
        return new_seq.float()

class Gaussian_Blur:
    """
    Returns a blurred version of the signal
    As kernel_size or sigma get bigger, the degree of blurriness gets stronger
    """
    def __init__(self, kernel_size=11, sigma=3):
        """
        Args:
            kernel_size (int): width of the Gaussian kernel; the taps are
                placed symmetrically around zero, so an even value yields
                ``kernel_size + 1`` taps
            sigma (float): standard deviation of the Gaussian, in samples
        """
        half = int(kernel_size / 2)
        offsets = np.arange(-half, half + 1)
        # Sampled Gaussian density (the taps are not renormalised to sum to 1)
        taps = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-offsets ** 2 / (2 * sigma ** 2))
        # "Same" convolution: padding by half keeps the sequence length
        self.conv = torch.nn.Conv1d(
            1, 1, kernel_size=len(offsets), stride=1, padding=half, bias=False)
        self.conv.weight = torch.nn.Parameter(
            torch.tensor(taps, dtype=torch.float).reshape(1, 1, -1),
            requires_grad=False)

    def __call__(self, seq):
        """
        Args:
            seq (torch.Tensor): sequence of shape (num leads, len_seq)

        Returns:
            torch.Tensor: float32 tensor of shape (num leads, len_seq); the
            lead axis is kept even for single-lead input
        """
        # Treat the leads as a batch of single-channel signals so that one
        # convolution call blurs all of them with the same kernel
        leads = seq.float().unsqueeze(1)      # (num leads, 1, len_seq)
        blurred = self.conv(leads)            # (num leads, 1, len_seq)
        # Remove only the dummy channel axis; a bare squeeze() would also
        # drop the lead axis when there is a single lead
        return blurred.squeeze(1).float()

class VerticalFlip:
    """
    Mirrors the sequence about zero (vertical flip) by negating every sample
    """
    def __init__(self):
        pass

    def __call__(self, seq):
        """Return the negated sequence as a float32 tensor."""
        negated = torch.neg(seq)
        return negated.float()

class Powerline_Noise:
    """
    Adds sinusoidal powerline interference with a random phase and amplitude

    fs: sampling frequency (Hz)
    fn: base frequency of powerline noise (Hz)
    K: stored on the instance but not used by ``__call__``
    """
    def __init__(self, fs=300, fn=50., K=3):
        self.fs = fs
        self.fn = fn
        self.K = K

    def __call__(self, seq):
        """
        Args:
            seq (torch.Tensor): sequence of shape (num leads, len_seq)

        Returns:
            torch.Tensor: float32 tensor of the same shape. All leads share
            one cosine of frequency ``fn``; with more than one lead, each
            lead gets its own gain drawn uniformly from [-1, 1).
        """
        C = np.random.uniform(0, 1/4)
        N = seq.shape[1]
        channels = seq.shape[0]
        # Integer arange divided by fs always has exactly N entries; an
        # arange with a float step can come out one element long or short
        t = np.arange(N) / self.fs
        phi1 = np.random.uniform(0, 2*np.pi)
        ak = np.random.uniform(0, 1)
        noise = C * (ak * np.cos(2*np.pi*self.fn*t + phi1))
        noise = noise[None, :]                      # (1, len_seq)
        if channels > 1:
            channel_gains = np.array([np.random.uniform(-1, 1)
                                      for _ in range(channels)])
            noise = channel_gains[:, None] * noise  # (num leads, len_seq)
        # A single lead keeps unit gain (no per-lead gains are drawn)
        noise = torch.as_tensor(noise, dtype=torch.float64, device=seq.device)
        return (seq + noise).float()

class Baseline_Wander:
    """
    Adds low-frequency baseline wander to the signal

    fs: sampling frequency (Hz)
    C: relative scaling factor
    fc: cutoff frequency for the baseline wander (Hz)
    fdelta: lowest resolvable frequency

    The noise is a linear combination of round(fc / fdelta) cosine functions
    of frequencies fdelta, 2*fdelta, ..., fc. With the default arguments
    these are 50 cosines of 0.01Hz, 0.02Hz, ..., 0.5Hz, respectively.
    """
    def __init__(self, fs=300, C=0.02, fc=0.5, fdelta=0.01):
        self.fs = fs
        self.C = C
        self.fc = fc
        self.fdelta = fdelta

    def __call__(self, seq):
        """
        Args:
            seq (torch.Tensor): sequence of shape (num leads, len_seq)

        Returns:
            torch.Tensor: float32 tensor of the same shape with the wander
            added
        """
        N = seq.shape[1]
        channels = seq.shape[0]
        # Integer arange divided by fs always has exactly N entries; an
        # arange with a float step can come out one element long or short
        t = np.arange(N) / self.fs
        K = int(np.round(self.fc / self.fdelta))
        noise = np.zeros((channels, N))
        for k in range(1, K + 1):
            phik = np.random.uniform(0, 2 * np.pi)
            ak = np.random.uniform(0, 1)
            # The k-th harmonic is shared by all leads; only its amplitude
            # (and sign) differs from lead to lead
            harmonic = np.cos(2 * np.pi * k * self.fdelta * t + phik)
            for c in range(channels):
                if c > 0:
                    # Every lead after the first draws its own signed amplitude
                    ak = np.random.uniform(0, 1) * (2 * np.random.randint(0, 1 + 1) - 1)
                noise[c] += self.C * ak * harmonic

        if channels > 1:
            # Additional per-lead gain with a random sign
            channel_gains = np.array(
                [(2 * np.random.randint(0, 1 + 1) - 1) * np.random.normal(1, 1)
                 for _ in range(channels)])
            noise = noise * channel_gains[:, None]
        noise = torch.as_tensor(noise, dtype=torch.float64, device=seq.device)
        return (seq + noise).float()

# Application of Butterworth Filter
def butter_filter(lowcut=1, highcut=45, fs=300, order=20, btype='band'):
    '''Returns a digital Butterworth filter as second-order sections.

    Args:
        lowcut (float): lower cutoff frequency (Hz); used by high-pass and
            band filters
        highcut (float): upper cutoff frequency (Hz); used by low-pass and
            band filters
        fs (float): sampling frequency (Hz)
        order (int): filter order handed to scipy.signal.butter
        btype (str): filter type handed to scipy.signal.butter, e.g.
            'band', 'bandstop', 'highpass' or 'lowpass' (the aliases
            accepted by SciPy are recognised as well)

    Returns:
        numpy.ndarray: the filter in 'sos' format
    '''
    # scipy expects cutoffs normalised to the Nyquist frequency
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq

    # Choose the cutoff(s) from the *kind* of filter rather than from one
    # exact spelling, so that e.g. 'high' does not end up with the low-pass
    # cutoff and 'bandpass' does not receive a single scalar
    kind = str(btype).lower()
    two_sided = ('band', 'bandpass', 'pass', 'bp',
                 'bandstop', 'bands', 'stop', 'bs')
    high_pass = ('highpass', 'high', 'hp', 'h')
    if kind in two_sided:
        cutoff = [low, high]
    elif kind in high_pass:
        cutoff = low
    else:
        cutoff = high

    sos = butter(order, cutoff, analog=False, btype=btype, output='sos')
    return sos

class ButterFilter(object):
    """
    Filters every lead with a Butterworth filter
    [btype] 'highpass' for high-pass, 'lowpass' for low-pass, 'band' for band filter
    Default : high pass with cutoff frequency of 1Hz
              With btype="band" it works as a bandpass filter of 1Hz~45Hz
    [forwardbackward] True runs the filter forward and backward
              (sosfiltfilt), False runs a single forward pass (sosfilt)
    """
    def __init__(self, lowcut=1, highcut=45, fs=300, order=20, btype='highpass', forwardbackward=True):
        self.filter = butter_filter(lowcut, highcut, fs, order, btype)
        self.forwardbackward = forwardbackward

    def __call__(self, seq):
        """Filter ``seq`` along axis 1 and return a float32 tensor."""
        apply_filter = sosfiltfilt if self.forwardbackward else sosfilt
        filtered = apply_filter(self.filter, seq, axis=1)
        # Copy before wrapping, so that the tensor is built from a fresh
        # array rather than from the filter routine's own output buffer
        return torch.from_numpy(filtered.copy()).float()

class NoAug:
    """
    Identity transform: applies no augmentation and only casts to float32
    """
    def __init__(self):
        pass

    def __call__(self, seq):
        """Return ``seq`` unchanged apart from the cast to float32."""
        as_float = seq.float()
        return as_float