


import numpy as np
import scipy.signal as sig



def _mean_band_envelope(x, fs, to_hz, n_leads):
    """Average the 0..to_hz Hz Tukey band-pass envelopes of rows 0..n_leads-1 of x; returns shape (1, n_samples)."""
    envelopes = np.concatenate(
        [filterFFT_bandPass(x[i, :], fs, 0, to_hz, True, 'tukey') for i in range(n_leads)],
        axis=0,
    )
    return np.mean(envelopes, axis=0, keepdims=True)


def preprocessing(x,fs,srange,eval):
    """Resample, optionally crop, normalise and augment a multi-channel recording.

    Parameters
    ----------
    x : np.ndarray
        2-D recording indexed as ``x[channel, sample]`` (the code slices
        ``x[:, ...]`` and takes means over ``axis=1``).
    fs : int or float
        Sampling frequency of ``x`` in Hz.
    srange : scalar or array broadcastable against ``x``
        Divisor applied to the signal before mean removal (presumably the
        amplitude range / gain of the recording -- TODO confirm with caller).
    eval : bool
        ``False`` selects training mode: recordings longer than 30000 samples
        (2 minutes at 250 Hz) are randomly cropped to 30000 samples. Any
        other value keeps the full recording.

    Returns
    -------
    np.ndarray
        The processed ``x`` with two rows appended: the channel-averaged
        0-8 Hz and 0-24 Hz Tukey-windowed band-pass Hilbert envelopes,
        averaged over the first ``min(12, x.shape[0])`` channels. The sample
        count is forced to be even.
    """
    # Bring the signal to 250 Hz. Recordings at 250 Hz or 257 Hz are kept at
    # their native rate, so downstream code sees "~250 Hz".
    if fs not in (250, 257):
        # Resample by the exact ratio 250/fs. resample_poly divides up/down by
        # their gcd, so 1000 Hz becomes (1, 4) and 500 Hz (1, 2), and any other
        # integer rate (e.g. 256 Hz -> (125, 128), 2000 Hz -> (1, 8),
        # 100 Hz -> (5, 2)) also lands on 250 Hz. Non-integer fs is rounded.
        x = sig.resample_poly(x, up=250, down=int(round(fs)), axis=-1)
        fs = 250

    if eval == False:  # noqa: E712 -- keep the original comparison semantics
        # Training: crop a random window of T samples.
        T = 250 * 60 * 2
        if x.shape[-1] > T:
            # Valid start offsets are 0 .. len - T inclusive; randint's upper
            # bound is exclusive, hence the +1.
            idx = np.random.randint(x.shape[-1] - T + 1)
            x = x[:, idx:idx + T]

    # Even length: irfft in filterFFT_bandPass reconstructs even-length signals.
    if x.shape[-1] % 2 != 0:
        x = x[:, :-1]

    # Replace NaN/inf, scale by srange and remove each channel's mean.
    x = np.nan_to_num(x)
    x = x / srange
    x = x - np.mean(x, axis=1, keepdims=True)

    # Up to 12 channels are averaged (fewer if the recording has fewer).
    n_leads = min(12, x.shape[0])
    lf = _mean_band_envelope(x, fs, 8, n_leads)
    mf = _mean_band_envelope(x, fs, 24, n_leads)

    return np.concatenate([x, lf, mf], axis=0)


from scipy.signal import hilbert
import scipy.signal as signal

def filterFFT_bandPass(sgnl,fs,fromHz,toHz, doEnvelope, winType):
    """FFT band-pass filter with an optional Hilbert envelope.

    The one-sided spectrum (``np.fft.rfft``) of ``sgnl`` is multiplied by a
    window over the bins covering ``[fromHz, toHz)``; every other bin is
    zeroed, then the signal is reconstructed with ``np.fft.irfft``.

    Parameters
    ----------
    sgnl : array_like
        1-D signal of length N.
    fs : float
        Sampling frequency in Hz.
    fromHz, toHz : float
        Pass-band edges in Hz. Bin indices are ``round(f * N / fs)``, clipped
        to the valid rfft range ``[0, N // 2 + 1]``.
    doEnvelope : bool
        If True, return the magnitude of the analytic signal
        (``scipy.signal.hilbert``) of the filtered signal.
    winType : str or tuple
        Window used to smooth the pass band: ``'tukey'`` (symmetric Tukey,
        scipy default alpha=0.5) or any spec accepted by
        ``scipy.signal.get_window`` (which returns the periodic variant), e.g.
        ``'boxcar'``, ``'triang'``, ``'blackman'``, ``'hamming'``, ``'hann'``,
        ``'nuttall'``.

    Returns
    -------
    np.ndarray
        Filtered signal (or its envelope) with shape ``(1, N)``.

    Raises
    ------
    ValueError
        If ``fromHz`` maps to a higher bin than ``toHz``.
    """
    n = len(sgnl)
    sp = np.fft.rfft(sgnl)
    n_bins = len(sp)

    bins_per_hz = n / fs
    # Clip edges so the window length always matches the slice it multiplies
    # (toHz above Nyquist previously caused a shape mismatch).
    na = min(max(int(round(fromHz * bins_per_hz)), 0), n_bins)
    nb = min(max(int(round(toHz * bins_per_hz)), 0), n_bins)
    if nb < na:
        raise ValueError("fromHz must not exceed toHz")

    nw = nb - na
    if winType == 'tukey':
        # Window aliases such as scipy.signal.tukey were removed from the
        # scipy.signal namespace in SciPy 1.13; use the windows submodule.
        wind = signal.windows.tukey(nw)
    else:
        wind = signal.get_window(winType, nw)

    sp[na:nb] *= wind
    # rfft holds only non-negative frequencies, so zeroing below na and from
    # nb upward removes everything outside the band.
    sp[:na] = 0
    sp[nb:] = 0

    # n=n keeps the original length for odd N (irfft defaults to 2*(n_bins-1)).
    rek = np.fft.irfft(sp, n=n)

    if doEnvelope:
        rek = np.abs(hilbert(rek))

    return rek.reshape(1, -1)


def detrend(ecg):
    """Remove the line through the first and last samples so the signal wraps cleanly.

    Sample ``i`` has ``i * (ecg[end] - ecg[0]) / (len(ecg) - 1)`` subtracted,
    so afterwards ``ecg[0] == ecg[end]`` (up to rounding). The signal then has
    no jump at the wrap-around point, which an FFT implicitly assumes.

    ``ecg`` must support ``len`` and item assignment (e.g. a list or 1-D
    ndarray). It is modified in place and also returned. For an ndarray, each
    assigned value is cast to the array's dtype (integer arrays truncate).
    Inputs with fewer than two samples have no slope and are returned
    unchanged.
    """
    last = len(ecg) - 1
    if last < 1:
        # Empty or single-sample input: nothing to detrend (and no division by zero).
        return ecg

    slope = (ecg[last] - ecg[0]) / last

    for i in range(last + 1):
        ecg[i] -= slope * i

    return ecg