import copy

import torch
from scipy import signal
from scipy.interpolate import interp1d

from helper_code import *
import src.augmentation.transforms as T

# Pre-process recordings
def preprocess_recording(header, recording, leads, preprocess_configs, recording_resampling_check=None):
    """Select leads, resample, convert to physical units and optionally filter/normalize.

    :param header: recording header, read through the helper_code getters
        (``get_leads``, ``get_adc_gains``, ``get_baselines``; the resampler also
        reads the sampling frequency from it).
    :param recording: 2-D array indexed as ``recording[lead_index, sample]``.
    :param leads: lead names to keep, in output order. Each name must appear in
        ``get_leads(header)``; otherwise ``list.index`` raises ``ValueError``.
    :param preprocess_configs: dict with the keys ``"resample_freq"``,
        ``"use_filter"``, ``"use_standardization"`` and
        ``"use_max_min_normalization"``.
    :param recording_resampling_check: optional list used as a cache for the
        ``Recording_resampling`` instance; pass the same list across calls to
        reuse it. Its first element is (re)built when missing or when its
        target frequency differs from ``preprocess_configs["resample_freq"]``.
    :return: the processed recording. It is a torch tensor whenever
        standardization or max-min normalization is enabled; otherwise it is
        whatever the resampler (a NumPy array) or the bandpass filter produced.
    """
    available_leads = get_leads(header)
    indices = [available_leads.index(lead) for lead in leads]
    recording = recording[indices, :]

    # A fresh list per call instead of a shared mutable default, and rebuild the
    # cached resampler if the requested frequency changed since it was created.
    target_frequency = preprocess_configs["resample_freq"]
    if recording_resampling_check is None:
        recording_resampling_check = []
    if not recording_resampling_check:
        recording_resampling_check.append(Recording_resampling(target_frequency))
    elif recording_resampling_check[0].target_frequency != target_frequency:
        recording_resampling_check[0] = Recording_resampling(target_frequency)

    recording_resampling = recording_resampling_check[0]
    recording = recording_resampling(header, recording)

    # Convert digital units to physical units, one gain/baseline per lead (row).
    adc_gains = get_adc_gains(header, leads).reshape(-1, 1)
    baselines = get_baselines(header, leads).reshape(-1, 1)
    recording = (recording - baselines) / adc_gains

    # Use Bandpass filter?
    if preprocess_configs["use_filter"]:
        # NOTE(review): the filter is built without a sampling rate; confirm its
        # default matches preprocess_configs["resample_freq"].
        bandpass = T.ButterFilter(btype="band")
        recording = bandpass(recording)

    use_standardization = preprocess_configs["use_standardization"]
    use_max_min_normalization = preprocess_configs["use_max_min_normalization"]

    if use_standardization or use_max_min_normalization:
        # The steps below use torch-only APIs (dim=/keepdim=, torch.abs, quantile),
        # but the resampler yields a NumPy array. as_tensor is a no-op for tensors.
        recording = torch.as_tensor(recording)

    # Use Standardization? (zero mean, unit std per lead)
    if use_standardization:
        recording = recording - recording.mean(dim=1, keepdim=True)
        recording = recording / (recording.std(dim=1, keepdim=True) + 1e-6)
    # Use Max-min normalization?
    if use_max_min_normalization:
        # Divide by the per-lead 99.7th percentile of |x| so that almost all
        # values fall in [-1, 1] while being robust to a few extreme spikes.
        normalizer = torch.abs(recording).quantile(q=0.997, dim=1, keepdim=True)
        recording = recording / (normalizer + 1e-6)
    return recording


class Recording_resampling(object):
    """Resample a multi-lead recording onto a uniform grid at a target frequency.

    Every lead (row) is interpolated with ``scipy.interpolate.interp1d`` from the
    sampling frequency stored in the header to ``target_frequency``. When the two
    frequencies are equal the recording is returned unchanged (same object).
    """

    def __init__(self, target_frequency=300, interpolation_method="linear"):
        """
        :param target_frequency: sampling frequency of the output, in the same
            unit as the header's frequency (presumably Hz).
        :param interpolation_method: ``kind`` argument passed to ``interp1d``.
        """
        self.target_frequency = target_frequency
        self.interpolation_method = interpolation_method

    def __call__(self, header, recording):
        """
        :param header: recording header; its sampling frequency is read with
            ``get_frequency``.
        :param recording: 2-D array indexed as ``recording[lead, sample]``.
        :return: the input itself if no resampling is needed, otherwise a new
            float array of shape ``(num_leads, end_time + 1)``.
        """
        frequency = get_frequency(header)
        if frequency == self.target_frequency:
            return recording

        # Derive the length from the array itself: interp1d requires x and y to
        # have the same length along the interpolation axis.
        num_samples = recording.shape[-1]
        data_time_sequence = np.arange(0, num_samples) / frequency
        end_time = int(data_time_sequence[-1] * self.target_frequency)
        target_data_time_sequence = np.arange(0, end_time + 1) / self.target_frequency
        # Float rounding above can push the last target time a hair past the last
        # source time, which interp1d rejects (bounds_error=True); clamp it.
        target_data_time_sequence = np.minimum(
            target_data_time_sequence, data_time_sequence[-1]
        )

        # One call interpolates every lead independently along the sample axis.
        interpolate_function = interp1d(
            data_time_sequence,
            recording,
            kind=self.interpolation_method,
            axis=-1,
        )
        return interpolate_function(target_data_time_sequence)


def test_plot(recording, header):
    """Debug helper: plot lead index 4 unless the header's frequency is 500.

    When plotting, the lead's shape and a separator line are printed first, then
    the lead is drawn and shown with ``matplotlib.pyplot.show``.
    """
    import matplotlib.pyplot as plt

    if get_frequency(header) == 500:
        return

    lead_signal = recording[4]
    print(lead_signal.shape)
    print("-----")
    plt.plot(lead_signal)
    plt.show()


class Wassnorm(object):
    """Replace each lead's samples by their ``argsort`` indices scaled to [0, 1]."""

    def __init__(self, num_leads):
        """
        :param num_leads: number of leads per item; the input is flattened to
            ``(shape[0] * num_leads, shape[-1])``, so it is typically shaped
            ``(batch, num_leads, length)``.
        """
        self.num_leads = num_leads

    def __call__(self, recording):
        """
        :param recording: torch tensor whose element count equals
            ``shape[0] * num_leads * shape[-1]`` (reshape raises otherwise).
        :return: float tensor with the same shape as ``recording``.
        """
        shapes = recording.shape
        length = shapes[-1]
        flat = recording.reshape([shapes[0] * self.num_leads, length])
        # NOTE(review): argsort returns the sort permutation, not each sample's
        # rank; if rank/CDF normalization was intended, use
        # torch.argsort(torch.argsort(flat, dim=1), dim=1) instead.
        # Guard length 1, where length - 1 == 0 would give 0/0 = NaN.
        denominator = float(max(length - 1, 1))
        normalized = torch.argsort(flat, dim=1) / denominator
        # Restore the caller's layout (the flattened view is only for argsort).
        return normalized.reshape(shapes)


# class Recording_filtering(object):
#     def __init__(self, filter_cut=81, sampling_freq=50, f_cut_low=0.2, f_cut_high=12):
#         self.filter_cut = filter_cut
#         self.sampling_freq = sampling_freq
#         self.f_cut_low = f_cut_low
#         self.f_cut_high = f_cut_high

#     def __call__(self, recording):
#         L_block = np.shape(recording)[1] * self.sampling_freq + 2 * self.filter_cut
#         bands = (
#             np.array(
#                 [
#                     0,
#                     0.1,
#                     self.f_cut_low,
#                     self.f_cut_high,
#                     self.f_cut_high * 1.1,
#                     self.sampling_freq / 2,
#                 ]
#             )
#             / self.sampling_freq
#             * 2
#         )
#         desired = (0, 0, 1, 1, 0, 0)
#         band_pass_filter = signal.firls(self.filter_cut, bands, desired, [100, 1, 100])
#         filtered_recording = []

#         for i in range(recording.shape[0]):
#             temp_pulse = recording[i]
#             temp_pulse = temp_pulse / np.abs(temp_pulse).max()  # Normalize
#             # TODO: "RuntimeWarning: invalid value encountered in true_divide" occurs for some samples -> loss NaN

#             h, w = signal.freqz(band_pass_filter, 1, 4096, self.sampling_freq)
#             temp = np.convolve(temp_pulse, band_pass_filter)
#             flt_out = temp[
#                 int((len(band_pass_filter) - 1) / 2) : -int(
#                     (len(band_pass_filter) - 1) / 2
#                 )
#             ]

#             flt_out = flt_out[self.filter_cut : len(flt_out) - self.filter_cut]
#             flt_out = flt_out / np.abs(flt_out).max()  # Normalize
#             filtered_recording.append(flt_out)

#         filtered_recording = np.array(filtered_recording)

#         return filtered_recording
