# This file contains all functions to preprocess (and augment) the data.
import random
import numpy as np
from scipy.ndimage.filters import uniform_filter1d
from scipy import signal
from math import ceil


def cutout(sig, cutout_length, cutout_probability, cutout_value=None):
    """
    This function realizes cutout augmentation.

    Every channel is handled independently. With probability ``cutout_probability``, a window is overwritten
    with ``cutout_value``. The window starts at a random sample and has a random length below ``cutout_length``.
    Windows that would run past the end of the signal are clipped.

    :param sig: An array that has the size ChannelxSamples where the cutout is applied independent to every channel.
    :type sig: np.ndarray
    :param cutout_length: The (exclusive) maximum length of the cutout window in samples
    :type cutout_length: int
    :param cutout_probability: The probability whether there will be cutout for each channel independently (0-1)
    :type cutout_probability: float
    :param cutout_value: Which value will be substituted in the cutout window (None for the mean of the channel
        that is cut)
    :type cutout_value: float
    :return: A copy of ``sig`` with the cutouts applied; the input array is not modified.
    :rtype: np.ndarray
    :raises ValueError: If ``cutout_probability`` is outside of [0, 1].
    """
    # make some checks
    if cutout_probability < 0 or cutout_probability > 1:
        raise ValueError("The input values aren't right.")

    # copy the signal so the caller's array stays untouched
    sig = sig.copy()
    n_samples = sig.shape[1]

    for i in range(sig.shape[0]):
        # randomly decide if we want to cutout this channel
        if random.random() >= cutout_probability:
            continue

        # decide where to cut and how long the cut is
        start_sample = int(random.random() * n_samples)
        length = int(cutout_length * random.random())

        # Use a local fill value. Assigning the mean to ``cutout_value`` would reuse
        # the first cut channel's mean for all following channels.
        fill_value = np.mean(sig[i, :]) if cutout_value is None else cutout_value

        # put in the cutout (numpy clips the slice at the signal end)
        sig[i, start_sample:start_sample + length] = fill_value

    return sig


def cropping(sig, minimum_length=5000, maximum_length=7500):
    """
    This function realizes cropping augmentation. It crops a random window from a given signal.

    The window length is drawn uniformly from ``[minimum_length, min(maximum_length, n_samples)]`` (both ends
    inclusive). The window position is drawn uniformly over all valid onsets, including the one whose window
    ends on the last sample. Signals not longer than ``minimum_length`` are returned uncropped.

    :param sig: An array that has the size ChannelxSamples; all channels are cropped with the same window.
    :type sig: np.ndarray
    :param minimum_length: The minimal length of the cropping window in samples
    :type minimum_length: int
    :param maximum_length: The maximum length of the cropping window in samples
    :type maximum_length: int
    :return: A new array containing the cropped window; the input array is not modified.
    :rtype: np.ndarray
    :raises ValueError: If ``minimum_length`` is larger than ``maximum_length``.
    """
    # calculate how much playroom we do have from minimum length to maximum length
    max_additional_samples = int(maximum_length - minimum_length)
    if max_additional_samples < 0:
        raise ValueError('"minimum_length" can not be longer than "maximum_length"!')

    n_samples = sig.shape[1]
    if n_samples <= minimum_length:
        # nothing to crop, but still hand back an independent copy
        return sig.copy()

    # randint is inclusive, so the maximum length (or the full signal) can actually be drawn
    signal_length = minimum_length + random.randint(0, min(n_samples - minimum_length, max_additional_samples))
    # inclusive upper bound: the window may also end on the very last sample
    onset = random.randint(0, n_samples - signal_length)
    return sig[:, onset:onset + signal_length].copy()


def noise(sig, max_amplitude):
    """
    This function adds additive white Gaussian noise with a random strength independently to every channel.

    For each channel, a standard deviation is drawn uniformly from ``[0, max_amplitude)``. Zero-mean Gaussian
    noise with that standard deviation is then added to every sample of the channel.

    :param sig: An array that has the size ChannelxSamples; the noise is drawn independently for every channel.
    :type sig: np.ndarray
    :param max_amplitude: Upper bound for the per-channel standard deviation of the noise
    :type max_amplitude: float
    :return: A noisy copy of ``sig``; the input array is not modified.
    :rtype: np.ndarray
    """
    noisy = sig.copy()
    n_samples = noisy.shape[1]

    # every row is a view into ``noisy``, so the in-place add updates the copy
    for row in noisy:
        std = max_amplitude * random.random()
        row += np.random.normal(0, std, n_samples)

    return noisy


def shifting(sig, max_samples_rotation=500):
    """
    This function rotates all channels of the signal together by a random number of samples.

    The offset is ``int(max_samples_rotation * random.random())``, i.e. in ``[0, max_samples_rotation)`` for a
    positive maximum. Samples pushed past the end wrap around to the start (``np.roll`` along axis 1).

    :param sig: An array that has the size ChannelxSamples; all channels are rotated by the same offset.
    :type sig: np.ndarray
    :param max_samples_rotation: The (exclusive) maximum amount of samples the signal should be rotated
    :type max_samples_rotation: int
    :return: A rotated copy of ``sig``.
    :rtype: np.ndarray
    """
    offset = int(max_samples_rotation * random.random())
    # np.roll allocates a new array, so the caller's signal is never modified
    return np.roll(sig, offset, axis=1)


def wiggle_channels(sig, max_samples_rotation=5):
    """
    This function rotates every channel by its own small random offset, so channels move relative to each other.

    Each channel is rolled (with wrap-around) by ``int(max_samples_rotation * random.random())`` samples. This is
    a single uniform draw in ``[0, max_samples_rotation)``, the same range that ``shifting`` uses for a whole
    signal.

    :param sig: An array that has the size ChannelxSamples; every channel is rotated independently.
    :type sig: np.ndarray
    :param max_samples_rotation: The (exclusive) maximum amount of samples each channel will be moved
    :type max_samples_rotation: int
    :return: A copy of ``sig`` with independently rotated channels.
    :rtype: np.ndarray
    :raises ValueError: If ``sig`` is not two-dimensional.
    """
    # check the shape before doing any work
    if len(sig.shape) != 2:
        raise ValueError('The signal needs to have exactly two dimensions.')

    # copy the signal
    sig = sig.copy()

    # Roll every channel independently with ONE random draw per channel. Drawing a
    # random maximum first and then a random shift below it (the old behaviour) made
    # the documented maximum unreachable and biased the shifts towards zero.
    for i in range(sig.shape[0]):
        sig[i, :] = np.roll(sig[i, :], int(max_samples_rotation * random.random()))

    return sig


def outlier_removal(sig, window_size, threshold=2.5):
    """
    This function removes outliers by comparing a moving average of the absolute signal with a threshold.

    The absolute signal is mean-filtered along the sample axis with a window of ``window_size`` samples. Every
    sample whose filtered value exceeds ``threshold`` is replaced. The replacement is the mean over the whole
    array (all channels), computed before any replacement takes place.

    :param sig: An array that has the size ChannelxSamples.
    :type sig: np.ndarray
    :param window_size: The size of the mean filter to extract the comparison mean signal
    :type window_size: int
    :param threshold: Moving-average magnitude above which samples count as outliers (default 2.5, the value
        previously hard-coded)
    :type threshold: float
    :return: A copy of ``sig`` with the outliers replaced; the input array is not modified.
    :rtype: np.ndarray
    """
    # copy the signal
    sig = sig.copy()

    # moving average of the magnitude along the sample axis
    mean_sig = uniform_filter1d(np.abs(sig), window_size, axis=1)

    # the right-hand side is evaluated before the assignment, so the global mean is taken from the unmodified copy
    sig[mean_sig > threshold] = np.mean(sig)

    return sig


def filter_sig(sig, sos):
    """
    This function removes outliers from the signal and frequency-filters it.

    First ``outlier_removal`` is applied with a window of 50 samples. Then every channel is passed through a
    5-sample median filter (``scipy.signal.medfilt``). Finally the channel is filtered with the second-order
    sections ``sos`` (``scipy.signal.sosfilt``).

    :param sig: An array that has the size ChannelxSamples.
    :type sig: np.ndarray
    :param sos: The sos-data of a given filter
    :return: The filtered signal, one row per channel.
    :rtype: np.ndarray
    """
    # outlier_removal works on its own copy, so the caller's array is never modified here
    cleaned = outlier_removal(sig, 50)

    filtered_channels = []
    for channel in cleaned:
        despiked = signal.medfilt(channel, 5)
        filtered_channels.append(signal.sosfilt(sos, despiked))

    return np.stack(filtered_channels, axis=0)


def amplitudes_perturbation_noise(sig, max_amplitude):
    """
    This function realizes additive frequency noise.

    Each channel goes through the following steps:

    - The channel is transformed with the FFT.
    - Random complex noise is added to every bin except the DC bin.
    - The real part of the inverse FFT is written back.

    The noise is mirrored as complex conjugates onto the partner bins, so the perturbed spectrum stays
    conjugate-symmetric and the time signal stays real.

    :param sig: An array that has the size ChannelxSamples; every channel gets its own noise.
    :type sig: np.ndarray
    :param max_amplitude: The maximum amplitude of the noise per frequency. Real and imaginary parts are drawn
        uniformly from [0, max_amplitude * n_samples), because numpy's FFT is unnormalized.
    :type max_amplitude: float
    :return: A perturbed copy of ``sig``.
    :rtype: np.ndarray
    """
    perturbed = sig.copy()

    for channel in range(perturbed.shape[0]):
        spectrum = np.fft.fft(perturbed[channel])
        n_bins = spectrum.shape[0]

        # Independent bins after DC are 1 .. n_bins//2 (the Nyquist bin is included for even lengths).
        # Symmetry background:
        # https://stackoverflow.com/questions/48532509/matlab-for-even-real-functions-fft-complex-result-ifft-real-result
        # The noise is scaled with the number of samples because of
        # https://dsp.stackexchange.com/questions/35951/relationship-between-fft-amplitude-and-sample-size
        draws = np.random.uniform(low=0.0, high=max_amplitude * n_bins, size=(n_bins // 2, 2))
        # reinterpret each (real, imag) row as one complex value
        half = draws.view(np.complex128).ravel()
        # Conjugate copies for bins n_bins//2+1 .. n_bins-1, in reverse order.
        # For even lengths the Nyquist bin has no partner.
        mirrored = np.conj(half[:(n_bins - 1) // 2])[::-1]
        spectrum[1:] += np.concatenate((half, mirrored))

        # For even lengths, `.real` drops the imaginary noise part at the Nyquist bin.
        # NOTE(review): both noise components are non-negative, so the noise is not zero-mean -- confirm intended.
        perturbed[channel] = np.fft.ifft(spectrum).real

    return perturbed


def amplitudes_perturbation_zero(sig, probability):
    """
    This function realizes random frequency canceling augmentation.

    For every channel, each non-DC FFT bin is set to zero with probability ``probability``. The decision for a
    bin is mirrored onto its conjugate partner, so the spectrum stays conjugate-symmetric. The real part of the
    inverse FFT is written back. The DC component (signal mean) is always kept.

    :param sig: An array that has the size ChannelxSamples; every channel gets its own random mask.
    :type sig: np.ndarray
    :param probability: The probability whether a frequency will be set to zero (0-1)
    :type probability: float
    :return: A copy of ``sig`` with the selected frequencies removed.
    :rtype: np.ndarray
    :raises ValueError: If ``probability`` is outside of [0, 1].
    """
    if probability < 0 or probability > 1:
        raise ValueError('"probability" has to be between 0 and 1.')

    # copy the signal
    sig = sig.copy()

    for i in range(sig.shape[0]):
        # transform the signal to frequencies
        signal_f = np.fft.fft(sig[i])
        n_bins = signal_f.shape[0]

        # One keep(1)/drop(0) decision per independent bin 1 .. n_bins//2 (includes Nyquist for even lengths),
        # see https://stackoverflow.com/questions/48532509/matlab-for-even-real-functions-fft-complex-result-ifft-real-result
        keep = np.random.binomial(1, 1 - probability, n_bins // 2)
        # mirror the decisions onto the conjugate bins (the Nyquist bin has no partner)
        keep = np.concatenate((keep, keep[:(n_bins - 1) // 2][::-1]))

        # Zero the dropped frequencies. Multiplying by the mask is required here;
        # adding it would only shift the bins by 0 or 1.
        signal_f[1:] *= keep

        # transform the frequencies back into the signal
        sig[i] = np.fft.ifft(signal_f).real

    return sig


def resample_signal(sig, old_frequency, new_frequency):
    """
    This function resamples every channel of a signal from one sampling frequency to another.

    The new number of samples is ``round(n_samples * new_frequency / old_frequency)``. Every channel is resampled
    with ``scipy.signal.resample``.

    :param sig: An array that has the size ChannelxSamples.
    :type sig: np.ndarray
    :param old_frequency: The sampling frequency of ``sig``
    :type old_frequency: float
    :param new_frequency: The desired sampling frequency
    :type new_frequency: float
    :return: A float64 array with the resampled channels, or a plain copy of ``sig`` if both frequencies are
        equal.
    :rtype: np.ndarray
    """
    if old_frequency == new_frequency:
        return sig.copy()

    n_channels, n_samples = sig.shape[0], sig.shape[1]
    target_length = round(n_samples * float(new_frequency) / old_frequency)

    resampled = np.zeros((n_channels, target_length))
    for channel, row in enumerate(sig):
        resampled[channel] = signal.resample(row, target_length)

    return resampled
