import collections
import itertools
import operator
import os.path
import pathlib
import re
import copy
import psutil
import scipy.io
import scipy.signal
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from scipy.signal import sosfiltfilt, iirdesign
from ecgdetectors import Detectors
# from pyecg.fiducial import RWaveDetector
from zipfile import BadZipFile
from tensorflow.keras.utils import Sequence
from .transformer_functions import create_padding_mask, zeropad
from .utils import rms, ClassEncoder
from .io import read_json, write_json

# Label codes (as strings) that get_classes() contributes when "PTB_XL" is
# among the requested datasets.
# NOTE(review): these look like SNOMED CT diagnosis codes - confirm.
PTB_XL_LABELS = {
    '164865005', '445118002', '54329005', '63593006', '164947007', '446358003',
    '111975006', '425419005', '67741000119109', '428750005', '445211001',
    '698252002', '10370003', '164884008', '426434006', '195042002', '55930002',
    '67198005', '713426002', '251120003', '27885002', '39732003', '164861001',
    '426177001', '427084000', '47665007', '164951009', '427393009',
    '164917005', '59931005', '251180001', '425623009', '164931005',
    '251200008', '89792004', '164873001', '270492004', '426761007',
    '426783006', '11157007', '713427006', '266249003', '284470004',
    '164889003', '429622005', '164909002', '164890007', '164934002',
    '251146004', '74390002'
}

# Label codes (as strings) that get_classes() contributes when "Georgia" is
# among the requested datasets.
# NOTE(review): these look like SNOMED CT diagnosis codes - confirm.
GEORGIA_LABELS = {
    '233917008', '59118001', '445118002', '164865005', '63593006', '713422000',
    '253339007', '251266004', '111975006', '81898007', '425419005',
    '164930006', '67741000119109', '428750005', '251139008', '445211001',
    '698252002', '164884008', '426434006', '195042002', '55930002',
    '713426002', '251120003', '27885002', '195060002', '413444003', '39732003',
    '426995002', '426177001', '195080001', '427084000', '47665007',
    '427393009', '426664006', '59931005', '164917005', '251180001',
    '425623009', '164931005', '253352002', '89792004', '164873001',
    '426761007', '270492004', '428417006', '164896001', '426783006', '6374002',
    '11157007', '713427006', '266249003', '426648003', '195126007',
    '284470004', '164889003', '429622005', '164921003', '164909002',
    '49578007', '164890007', '426627000', '164934002', '251268003',
    '195101003', '17338001', '251146004', '74390002'
}

# Sub-folder of ``root_path`` from which the data generators below read the
# beat-split records and ``n_beats_per_recording.json``.
FOLDER_NAME = "beated"
# Name of the ``ecgdetectors.Detectors`` attribute that
# Record.r_peak_detector() looks up with ``getattr``.
DETECTOR = "christov_detector"
SUPPORTED_NUM_LEADS = [1, 2, 3, 4, 6, 12]  # Supported number of leads
# Lead count up to which DataGeneratorAutoregressive tiles records that have
# fewer leads.
MAX_NUM_LEADS = 12


###############################################################################
#
# From cinc2020.data
#
###############################################################################


class Record:
    """CinC 2020 record.

    Class to store data from a record of the challenge of Computing in
    Cardiology 2020.

    Attributes:
        subject: Subject identifier.
        age: Subject age.
        gender: Subject gender.
        labels: Rhythm labels.
        fs: Sampling frequency in Hertz.
        ecg: ECG signals; (num_samples, num_leads) for records built from a
            header.
        dataset: Name of the source dataset (stored as ``str``), or None.
        time: Time axis in seconds, one entry per row of ``ecg``.
        r_peaks: Sample indices of the detected R-peaks, or None.
    """
    # Not used by the current gain parsing in ``from_header``; kept because it
    # is part of the class namespace.
    _SCALE_REGEX = re.compile(r'(?P<scale>\d+(\.\d*)?)/m[vV]')

    def __init__(
            self, subject, age, gender, labels, fs, ecg, dataset=None,
            r_peaks=None, beats=None
    ):
        """Store the record fields and derive the time axis.

        Args:
            subject: Subject identifier.
            age: Subject age.
            gender: Subject gender.
            labels: Rhythm labels.
            fs: Sampling frequency in Hertz.
            ecg: ECG signals; the first axis is used as the sample axis.
            dataset: Name of the source dataset; converted to ``str`` when
                not None.
            r_peaks: Optional R-peak sample indices.
            beats: Accepted for backward compatibility; not stored.
        """
        self.subject = subject
        self.age = age
        self.gender = gender
        self.labels = labels
        self.fs = fs
        self.ecg = ecg
        self.dataset = dataset
        if self.dataset is not None:
            self.dataset = str(dataset)
        self.time = np.arange(self.ecg.shape[0]) / self.fs
        self.r_peaks = r_peaks

    def __repr__(self):
        return f'<{self.__class__.__name__} {self.subject}>'

    def to_npz(self, file):
        """Write the record to an NPZ file.

        Write the record to an NPZ file at the specified path.

        Args:
            file: Path where to write the record.
        """
        np.savez(
            file=file,
            subject=self.subject,
            age=self.age,
            gender=self.gender,
            labels=self.labels,
            fs=self.fs,
            ecg=self.ecg,
            dataset=self.dataset,
            r_peaks=self.r_peaks,
        )

    @classmethod
    def from_npz(cls, file):
        """Read a record from an NPZ file.

        Read a record of the CinC challenge from an NPZ file.

        Args:
            file: Path to an NPZ file.

        Returns:
            A record. ``dataset`` falls back to "Information Unavailable" and
            ``r_peaks`` to None when the archive has no such entry.
        """
        # allow_pickle=True is needed because None-valued fields are stored as
        # object arrays.
        # SECURITY: unpickling executes arbitrary code; only load NPZ files
        # from trusted sources.
        with np.load(file, allow_pickle=True) as data:
            try:
                dataset = data['dataset']
            except KeyError:
                dataset = "Information Unavailable"
            try:
                r_peaks = data['r_peaks']
            except KeyError:
                r_peaks = None
            else:
                # A record saved with r_peaks=None comes back as a 0-d object
                # array; unwrap it so that ``r_peaks is None`` checks work.
                if r_peaks.ndim == 0:
                    r_peaks = r_peaks.item()
            return cls(
                subject=data['subject'].item(),
                age=data['age'].item(),
                gender=data['gender'].item(),
                labels=data['labels'].tolist(),
                fs=data['fs'].item(),
                ecg=data['ecg'],
                dataset=dataset,
                r_peaks=r_peaks
            )

    @classmethod
    def from_header(cls, header, ecg, dataset='unknown'):
        """Build a record from a parsed header and the matching ECG data.

        Args:
            header: list(list('str')), one token list per header line.
            ecg: ecg data (num_samples, num_leads). It is divided by the
                per-lead gain IN PLACE.
            dataset: name of dataset

        Returns:
            A record.
        """
        n_leads = int(header[0][1])
        assert n_leads == ecg.shape[1]
        fs = float(header[0][2])
        n_samples = int(header[0][3])
        assert n_samples == ecg.shape[0]

        # The gain is the leading number of the third token of each lead line
        # (everything before the first '/', '(' or ')').
        scale = np.ones(ecg.shape[1])
        for i in range(n_leads):
            scale[i] = float(re.split('/|[(]|[)]', header[i + 1][2])[0])
        ecg /= scale

        subject = header[0][0]
        if "." in subject:  # Remove .mat if it's there
            subject = subject.split(".")[0]
        age = None
        gender = None
        labels = None
        for line in header:
            if not len(line):  # skip empty lines
                continue
            if line[0] == '#Age:':
                age = float(line[1]) if line[1] != '-' else None
            elif line[0] == '#Sex:':
                gender = line[1].lower()
                if gender == 'f':
                    print("Fixing female gender")
                    gender = 'female'
                if gender == 'm':
                    print("Fixing male gender")
                    gender = 'male'
            elif line[0] == '#Dx:':
                labels = line[1].lower().split(',')

        return cls(subject, age, gender, labels, fs, ecg, dataset)

    @staticmethod
    def _find_dataset(file, version="cinc2021"):
        """Map a record path to the name of the dataset it belongs to.

        Args:
            file: ``pathlib.Path`` of the record file.
            version: 'cinc2021' (file-name prefix, then folder name) or
                'cinc2020' (folder name only).

        Returns:
            The dataset name; "unknown" for an unrecognised 'cinc2021' record.

        Raises:
            ValueError: If ``version`` is not one of the two known schemes.
        """
        # ``parent.name`` instead of ``parts[-2]`` so that a bare file name
        # does not raise IndexError.
        folder = file.parent.name
        if version == 'cinc2021':
            name = file.stem
            if name.startswith("A"):  # WFDB_CPSC2018
                return "CPSC_1"
            if name.startswith("Q"):  # WFDB_CPSC2018_2
                return "CPSC_2"
            if name.startswith("E"):  # WFDB_Ga
                return "Georgia"
            if name.startswith("S"):  # WFDB_PTB
                return "PTB"
            if name.startswith("HR"):  # WFDB_PTBXL
                return "PTB_XL"
            if name.startswith("I"):  # WFDB_StPetersburg
                return "StPetersburg"
            if name.startswith("JS"):
                # Chapman-Shaoxing and Ningbo share the "JS" prefix, so the
                # folder name is the only way to tell them apart. Anything
                # outside WFDB_Ningbo keeps the previous answer ("Chapman").
                return "Ningbo" if folder == "WFDB_Ningbo" else "Chapman"
            if folder in ("2000mitdb", "2000ltafdb", "2000nstdb", "2000cudb",
                          "1977AHADB", "2009edb"):
                return folder
            return "unknown"

        if version == 'cinc2020':
            by_folder = {
                "Training_2": "CPSC_2",
                "Training_PTB": "PTB",
                "Training_StPetersburg": "StPetersburg",
                "Training_WFDB": "CPSC_1",
            }
            if folder in by_folder:
                return by_folder[folder]
            if folder == "WFDB":
                # Georgia subjects are named E + five digits; everything else
                # in this folder is treated as PTB-XL.
                if re.fullmatch(r'E\d{5}', file.stem):
                    return "Georgia"
                return "PTB_XL"
            return folder
        raise ValueError("Unknown dataset finding version", version)

    @classmethod
    def from_mat(cls, file):
        """Read a record from a MAT file.

        Read a record of the CinC 2020 challenge from a MAT file and the
        corresponding HEA file.

        Args:
            file: Path to a MAT file.

        Returns:
            A record.
        """
        file = pathlib.Path(file).with_suffix('.mat')
        dataset = cls._find_dataset(file)
        assert dataset != "unknown"
        data = scipy.io.loadmat(file)
        # 'val' is transposed so that samples run along the first axis.
        ecg = data['val'].astype('float32').T
        assert ecg.shape[1] in SUPPORTED_NUM_LEADS
        with open(file.with_suffix('.hea'), mode='r') as f:
            header = list(map(str.split, map(str.strip, f)))
        return cls.from_header(header, ecg, dataset)

    def resample(self, fs):
        """Resample the record.

        Resample the ECG signals of the record to the specified sampling
        frequency and refresh the time axis accordingly. ``r_peaks`` is NOT
        rescaled here; callers do that themselves.

        Args:
            fs: New sampling frequency in Hertz.
        """
        if fs == self.fs:
            return
        self.ecg = scipy.signal.resample_poly(self.ecg, fs, self.fs, axis=0)
        self.ecg = self.ecg.astype('float32')
        # WARNING: resampling messes with r-peak detector!??
        self.fs = fs
        # Keep the time axis consistent with the new length and rate.
        self.time = np.arange(self.ecg.shape[0]) / self.fs

    def filter(self, wp=0.5, ws=0.3, gpass=0.5, gstop=20.0):
        """Filter the record.

        High pass filter the ECG signals of the record with a zero-phase
        Butterworth filter. NaN samples are filled first: linearly between
        valid neighbours, and with the first/last valid value at the edges.
        A lead without any valid sample is set to zero.

        Args:
            wp: Passband edge frequency in Hertz.
            ws: Stopband edge frequency in Hertz.
            gpass: Maximum loss in the passband in dB.
            gstop: Minimum attenuation in the stopband in dB.
        """
        sos = iirdesign(
            wp, ws, gpass, gstop,
            ftype='butter', output='sos', fs=self.fs
        )
        # Rows of the transposed array are views, so writes reach self.ecg.
        for lead in self.ecg.T:
            missing = np.isnan(lead)
            if not missing.any():
                continue
            if missing.all():
                # Nothing to interpolate from: treat the lead as flat.
                lead[:] = 0.0
                continue
            positions = np.arange(lead.size)
            # np.interp clamps to the first/last valid value outside the
            # valid range, which yields the straight-line edge fill.
            lead[missing] = np.interp(
                positions[missing], positions[~missing], lead[~missing]
            )
        self.ecg = sosfiltfilt(sos, self.ecg, axis=0)

    def r_peak_detector(self):
        """Find R-peaks in the ECG signal.

        Returns:
            The output of the detector named by ``DETECTOR`` applied to
            ``rms(self.ecg)``.
        """
        detectors = Detectors(self.fs)
        detector = getattr(detectors, DETECTOR)

        return detector(rms(self.ecg))


class Dataset:
    """CinC 2020 dataset.

    Class to store a dataset of records from the challenge of Computing in
    Cardiology 2020 divided into folds.

    Attributes:
        folds: List of folds where each fold is a list of records.
        datasets: Set of the dataset names encountered while loading.
    """
    _SUBJECT_REGEX = re.compile(r'A\d{4}')
    _FOLD_REGEX = re.compile(r'fold(?P<fold>\d+)')

    def __init__(self, folds, datasets):
        self.folds = folds
        self.datasets = datasets

    @property
    def n_folds(self):
        """Number of folds in the dataset."""
        return len(self.folds)

    @property
    def n_records(self):
        """Total number of records in the dataset."""
        return sum(map(len, self.folds))

    def estimate_scaling_parameters(self, folds=None):
        """Estimate scaling parameters for the dataset.

        Compute the mean and standard deviation of the ECG signals in the
        dataset for scaling. Optionally, the indices of the folds to use can be
        specified.

        Args:
            folds: Indices of the folds to use for computing the mean and
                standard deviation.

        Returns:
            The mean and standard deviation (``ddof=1``) of the ECG signals,
            computed along the first axis of the stacked signals.
        """
        if folds is None:
            it = itertools.chain.from_iterable(self.folds)
        else:
            it = itertools.chain.from_iterable(self.folds[i] for i in folds)
        x = np.vstack(list(map(operator.attrgetter('ecg'), it)))
        mean = np.mean(x, axis=0)
        std = np.std(x, axis=0, ddof=1)
        return mean, std

    def iterate_records(self, folds=None, fold_index=False):
        """Iterate over records of the dataset.

        Args:
            folds: Indices of the folds to use when iterating over records.
            fold_index: Boolean to indicate that the fold index of each record
                should be included in the outputs.

        Yields:
            A record if fold_index is False and a fold index and a record
            otherwise.
        """
        if folds is None:
            folds = range(self.n_folds)
        for i in folds:
            for record in self.folds[i]:
                if fold_index:
                    yield i, record
                else:
                    yield record

    def list_classes(self):
        """List classes in the dataset.

        Return a list of unique classes present in the dataset. Records
        without labels (``labels`` is None or empty) are skipped.

        Returns:
            A sorted list of classes.
        """
        classes = set()
        for record in self.iterate_records(fold_index=False):
            if record.labels:
                classes.update(record.labels)
        return sorted(classes)

    def to_npz(self, directory):
        """Write the dataset to NPZ files.

        Write the records of the dataset to NPZ files in the specified
        directory. There is one subdirectory for each fold, named ``fold``
        followed by the zero-padded fold index.

        Args:
             directory: Path where to write the dataset.
        """
        directory = pathlib.Path(directory)
        width = len(str(self.n_folds - 1))
        for i, fold in enumerate(self.folds):
            subdirectory = directory / f'fold{i:0{width}d}'
            subdirectory.mkdir(parents=True, exist_ok=True)
            for record in fold:
                record.to_npz(subdirectory / f'{record.subject}.npz')

    @classmethod
    def from_npz(cls, directory):
        """Read a dataset from NPZ files.

        Read the records of a dataset from the ``fold<N>`` subdirectories of
        ``directory``. Folds are ordered by their numeric index, so
        non-padded names such as ``fold10`` come after ``fold2``.

        Args:
            directory: Path to a directory with NPZ files.

        Returns:
            A dataset.
        """
        directory = pathlib.Path(directory)

        def fold_number(path):
            # Numeric fold index of a fold directory, None for anything else.
            match = cls._FOLD_REGEX.fullmatch(path.name)
            if match is None or not path.is_dir():
                return None
            return int(match.group('fold'))

        def is_record_file(path):
            return path.is_file() and path.suffix == '.npz'

        numbered = [(fold_number(path), path) for path in directory.iterdir()]
        # The name is only a tie-breaker (e.g. 'fold1' vs 'fold01').
        subdirectories = [
            path for _, path in sorted(
                ((number, path) for number, path in numbered
                 if number is not None),
                key=lambda pair: (pair[0], pair[1].name)
            )
        ]
        folds = []
        datasets = set()
        count = 0
        for subdirectory in subdirectories:
            records = []
            for file in sorted(filter(is_record_file, subdirectory.iterdir())):
                if count % 500 == 0:
                    print("Recording:", count, "  ||  ", file.stem, end="\r")
                count += 1
                record = Record.from_npz(file)
                datasets.add(record.dataset)
                records.append(record)
            folds.append(records)

        return cls(folds, datasets)

    @classmethod
    def from_mat(cls, directory, n_folds=None, fs=None, hp_filter=False):
        """Create a dataset from MAT files.

        Create a dataset of records of the CinC 2020 challenge from MAT files
        and the corresponding HEA files. The records are split into folds
        stratified by labels to simplify cross-validation. Optionally, the ECG
        signals in the records can be resampled. Recordings in which no beat
        is detected are excluded.

        Args:
            directory: Path to a directory with MAT files, or to a directory
                whose direct subdirectories contain them.
            n_folds: Number of folds to create.
            fs: Sampling frequency in Hertz for resampling.
            hp_filter: High pass filter the ECG signals of the record

        Returns:
            A dataset of records.
        """
        process = psutil.Process(os.getpid())
        directory = pathlib.Path(directory)

        def is_record_file(path):
            return path.is_file() and path.suffix == '.mat'

        def iterate_directory(directory):
            # Use the directory itself when it holds MAT files, otherwise
            # look one level down.
            if any(is_record_file(path) for path in directory.iterdir()):
                return directory.iterdir()
            return directory.glob("*/*.mat")

        records = []
        datasets = set()
        count = 0
        for file in sorted(
                filter(is_record_file, iterate_directory(directory))):
            record = Record.from_mat(file)
            datasets.add(record.dataset)
            if hp_filter:
                record.filter()
            # R-peaks are detected at the native rate, before resampling.
            try:
                idxs = record.r_peak_detector()
            except IndexError:
                # Handled like an empty detection result.
                idxs = []
            if len(idxs) == 0:
                print("WARNING: no beats found in recording", record.subject)
                print("Label", record.labels)
                print("Excluding recording.")
                continue
            record.r_peaks = idxs
            if fs is not None:
                # Rescale the peak indices to the new sampling rate.
                record.r_peaks = (
                        fs / record.fs * np.array(record.r_peaks)
                ).astype(int)
                record.resample(fs)
            records.append(record)
            if count % 500 == 0:
                print("Recording:", count, "  ||  ", file.stem, end="\r")
                print(
                    'memory usesd: ' +
                    str(process.memory_info().rss // 1024 ** 2) + 'MB'
                )
            count += 1

        if n_folds is None or n_folds == 1:
            folds = [records]
        else:
            # The least common label of multi-labeled records is used to build
            # stratified folds.
            counter = collections.Counter()
            for record in records:
                counter.update(record.labels or ())

            labels = []
            for record in records:
                if record.labels:
                    # min() returns the first of equally rare labels.
                    rarest = min(record.labels, key=counter.__getitem__)
                    # Stratify by dataset.
                    labels.append(rarest + record.dataset)
                else:
                    # Unlabelled records are stratified by dataset only.
                    labels.append(record.dataset)
            splitter = StratifiedKFold(n_splits=n_folds, shuffle=False)
            folds = []
            for _, indices in splitter.split(np.zeros(len(labels)), labels):
                folds.append([records[i] for i in indices])

        return cls(folds, datasets)


###############################################################################
#
# From cinc2020.scripts.transformer_test.data 19/03/21
#
###############################################################################

class DataGeneratorAutoregressive(Sequence):
    """
    Data generator for ECG data that has been split into beats.
    The generator assumes an autoregressive task of next beat regression:
    every sample is (the first j beats of a recording, beat j).

    Attributes:
        paths: Record files matched under ``root_path / FOLDER_NAME``.
        path_target_idx: List of (path, target beat index) training points.
        batch_size: Number of training points per batch.
        n_batches: Number of full batches; a trailing partial batch is
            dropped.
        n_recordings: Number of record files.
        n_training_points: Number of (path, target index) pairs.
    """

    def __init__(
            self, batch_size, root_path, folds=None, shuffle=True,
            datasets=None
    ):
        """Collect the record files and enumerate all training points.

        Args:
            batch_size: Number of training points per batch.
            root_path: Directory that contains the ``FOLDER_NAME`` folder.
            folds: Fold indices to read; defaults to [0, 1, 2, 3].
            shuffle: Shuffle the training points once at construction.
            datasets: Dataset names to read; defaults to Georgia and PTB_XL.
        """
        # Let the Keras base class initialise its own state.
        super().__init__()
        if folds is None:
            folds = [0, 1, 2, 3]
        if datasets is None:
            datasets = ['Georgia', "PTB_XL"]
        folder_path = pathlib.Path(root_path) / pathlib.Path(FOLDER_NAME)
        pattern = create_pattern(datasets, folds)
        self.paths = get_paths(folder_path, pattern)
        n_beats_per_recording = read_json(
            folder_path / pathlib.Path("n_beats_per_recording.json")
        )

        print("Preparing paths and beat lengths")
        # Beat 0 has no history, so targets start at beat 1.
        self.path_target_idx = [
            (path, j)
            for path in self.paths
            for j in range(1, n_beats_per_recording[path.stem])
        ]
        if shuffle:
            print("Shuffling data")
            np.random.shuffle(self.path_target_idx)

        self.batch_size = batch_size
        self.n_batches = len(self.path_target_idx) // self.batch_size
        self.n_recordings = len(self.paths)
        self.n_training_points = len(self.path_target_idx)

    def __len__(self):
        """Denotes the number of batches per epoch"""
        return self.n_batches

    def __getitem__(self, idx):
        """Return the inputs and targets of one batch.

        Args:
            idx: Batch index (Python or NumPy integer).

        Returns:
            ``(x_data, y_data)``: two lists, with the beats before the target
            beat and the target beat of every training point in the batch.
            Files that cannot be opened (BadZipFile) are left out.

        Raises:
            ValueError: If ``idx`` is not an integer or is out of bounds.
        """
        # NumPy integers are accepted too; index arrays often produce them.
        if not isinstance(idx, (int, np.integer)):
            raise ValueError("idx must be an integer.")
        if not 0 <= idx < self.n_batches:
            raise ValueError("idx is out of bounds.")
        first = idx * self.batch_size
        batch = self.path_target_idx[first:first + self.batch_size]
        beats = []
        for filename, j_beats in batch:
            try:
                record = Record.from_npz(filename)
            except BadZipFile:
                print()
                print("Excluding file", filename, "because of error.")
                continue
            n_leads = record.ecg.shape[-1]
            # Concatenate (clone) the ecg data 12 leads // n_leads times
            if n_leads != MAX_NUM_LEADS:
                record.ecg = np.concatenate(
                    [record.ecg] * (MAX_NUM_LEADS // n_leads), axis=-1
                )
            beats.append((record.ecg, j_beats))
        x_data = [beat[:j_beats] for beat, j_beats in beats]
        y_data = [beat[j_beats] for beat, j_beats in beats]
        return x_data, y_data


class DataGeneratorSupervised(Sequence):
    """
    Data generator for ECG data that has been split into beats for supervised
    learning task.

    Attributes:
        paths: Record files matched under ``root_path / FOLDER_NAME``.
        encoder: ``ClassEncoder`` used to encode the record labels.
        batch_size: Number of recordings per batch.
        n_batches: Number of full batches; a trailing partial batch is
            dropped.
        n_recordings: Number of record files.
    """

    def __init__(
            self, batch_size, root_path, classes=None, folds=None,
            shuffle=True, datasets=None, negative=False
    ):
        """Collect the record files and set up the label encoder.

        Args:
            batch_size: Number of recordings per batch.
            root_path: Directory that contains the ``FOLDER_NAME`` folder.
            classes: Passed to ``ClassEncoder``.
            folds: Fold indices to read; defaults to [0, 1, 2, 3].
            shuffle: Shuffle the record files once at construction.
            datasets: Dataset names to read; defaults to Georgia and PTB_XL.
            negative: Passed to ``ClassEncoder``.
        """
        # Let the Keras base class initialise its own state.
        super().__init__()
        if folds is None:
            folds = [0, 1, 2, 3]
        if datasets is None:
            datasets = ['Georgia', "PTB_XL"]
        folder_path = pathlib.Path(root_path) / pathlib.Path(FOLDER_NAME)
        print(folder_path)
        pattern = create_pattern(datasets, folds)
        self.paths = get_paths(folder_path, pattern)
        assert len(self.paths) > 0

        if shuffle:
            print("Shuffling data")
            np.random.shuffle(self.paths)
        self.encoder = ClassEncoder(classes=classes, negative=negative)
        self.batch_size = batch_size
        self.n_batches = len(self.paths) // self.batch_size
        self.n_recordings = len(self.paths)

    def __len__(self):
        """Denotes the number of batches per epoch"""
        return self.n_batches

    def __getitem__(self, idx):
        """Return the ECG data and encoded labels of one batch.

        Args:
            idx: Batch index (Python or NumPy integer).

        Returns:
            ``(x_data, y_data)``: the list of ECG arrays and the encoded
            labels of the recordings in the batch.

        Raises:
            ValueError: If ``idx`` is not an integer or is out of bounds.
            AssertionError: If the batch is incomplete, e.g. because a file
                could not be opened.
        """
        # NumPy integers are accepted too; index arrays often produce them.
        if not isinstance(idx, (int, np.integer)):
            raise ValueError("idx must be an integer.")
        if not 0 <= idx < self.n_batches:
            raise ValueError("idx is out of bounds.")
        first = idx * self.batch_size
        x_data = []
        y_data = []
        for filename in self.paths[first:first + self.batch_size]:
            try:
                record = Record.from_npz(filename)
            except BadZipFile:
                print()
                print("Excluding file", filename, "because of error.")
                continue
            x_data.append(record.ecg)
            y_data.append(record.labels)
        y_data = self.encoder.encode(y_data)
        # NOTE(review): an excluded file shrinks the batch, so these checks
        # turn a skipped file into an AssertionError - confirm this is wanted.
        assert len(x_data) == self.batch_size
        assert len(y_data) == self.batch_size
        return x_data, y_data


def create_pattern(datasets, folds):
    """Build the glob pattern that selects the record files to load.

    The result is ``"fold" + str(folds) + "/" + str(prefixes) + "*.npz"``,
    where ``prefixes`` lists the file-name prefix of every requested dataset
    in a fixed order.

    NOTE(review): with list arguments both ``str()`` results are bracketed,
    so a glob presumably reads them as character classes matching a single
    character (e.g. "JS" would match "J" or "S", fold 10 would match "1" or
    "0") - verify against the callers before relying on multi-character
    prefixes or fold indices above 9.

    Args:
        datasets: Container of dataset names.
        folds: Fold indices; inserted into the pattern with ``str()``.

    Returns:
        The pattern string.
    """
    prefix_table = (
        ("Georgia", "E"),
        ("CPSC_1", "A"),
        ("CPSC_2", "Q"),
        ("PTB", "S"),
        ("PTB_XL", "HR"),
        ("StPetersburg", "I"),
        ("Chapman", "JS"),
        ("Ningbo", "JS"),
        ("2000mitdb", "M"),
        ("2000ltafdb", "L"),
        ("2000nstdb", "N"),
        ("2000cudb", "cu"),
        ("1977AHADB", "AH"),
        ("2009edb", "e"),
    )
    prefixes = [
        prefix for dataset, prefix in prefix_table if dataset in datasets
    ]
    return "".join(("fold", str(folds), "/", str(prefixes), "*.npz"))


def get_paths(folder, pattern):
    """Return the paths below ``folder`` that match the glob ``pattern``.

    Args:
        folder: Directory to search (string or path).
        pattern: Glob pattern, relative to ``folder``.

    Returns:
        A list of ``pathlib.Path`` objects.
    """
    matches = pathlib.Path(folder).glob(pattern)
    return list(matches)


def preprocess_recording(
        recording, header, n_leads=12, fs=500, hp_filter=False
):
    """ Preprocess recording data

    Build a record from the raw data, optionally high pass filter it, detect
    the R-peaks and optionally resample it. When no beat is found a warning
    is printed and ``r_peaks`` of the returned record stays None.

    Args:
        recording: ecg data; transposed first when its length equals
            ``n_leads`` (i.e. when the leads are on the first axis)
        header: string from helper_code.load_header()
        n_leads: number of leads
        fs: sampling frequency; None disables resampling
        hp_filter: High pass filter the ECG signals of the record

    Returns:
        A record
    """
    # Tokenise like Record.from_mat: split on any whitespace so repeated
    # blanks or '\r' line endings do not leak into the tokens.
    header = [line.split() for line in header.strip().splitlines()]
    if len(recording) == n_leads:
        recording = recording.T
    recording = recording.astype('float32')
    record = Record.from_header(header, recording)
    if hp_filter:
        record.filter()
    try:
        idxs = record.r_peak_detector()
    except IndexError:
        # Handled like an empty detection result.
        idxs = []
    if len(idxs) == 0:
        print("WARNING: no beats found in recording", record.subject)
    else:
        record.r_peaks = idxs
    if fs is not None:
        # Only rescale existing peaks: with r_peaks None the multiplication
        # would raise TypeError and callers could never see the None.
        if record.r_peaks is not None:
            record.r_peaks = (
                    fs / record.fs * np.array(record.r_peaks)
            ).astype(int)
        record.resample(fs)
    return record


###############################################################################
#
# From cinc2020.scripts.transformer_test.visualization_tool_functions 11/03/21
#
###############################################################################


def transformer_preprocess(
        recording, header, max_before=333, max_after=667, n_leads=12, fs=500,
        maximum_position_encoding=50, hp_filter=False
):
    """Turn a raw recording into transformer inputs.

    The recording is preprocessed with ``preprocess_recording``, cut into
    R-peak aligned, zero-padded beats with ``zeropad_beats`` and then handed
    to ``create_padding_mask`` and ``zeropad`` as a batch of one recording.

    Args:
        recording: an input recording
        header: string from helper_code.load_header()
        max_before: the maximum number of samples before the R-peak
        max_after: the maximum number of samples after the R-peak
        n_leads: the number of leads in the data
        fs: sampling frequency
        maximum_position_encoding: the maximum number of beats per recording
        hp_filter: High pass filter the ECG signals of the record

    Returns:
        inp: input ECG to transformer model
        padding_mask: padding mask input to transformer model
        Both are the string "flat" when the record has no R-peaks.
        EXAMPLE transformer use: transformer(inp, False, padding_mask)
    """
    record = preprocess_recording(
        recording, header, n_leads=n_leads, fs=fs, hp_filter=hp_filter
    )
    peaks = record.r_peaks
    if peaks is None:
        # No beats available: signal a flat recording to the caller.
        return "flat", "flat"

    beats = zeropad_beats(record.ecg, peaks, max_before, max_after, n_leads)
    # Each helper receives its own single-recording batch.
    mask_batch = [beats]
    padding_mask = create_padding_mask(mask_batch, maximum_position_encoding)
    input_batch = [beats]
    inp = zeropad(input_batch, maximum_position_encoding)
    return inp, padding_mask


###############################################################################
#
# From transform_ECG_classification.ipynb ("Better splitting") 22/03/21
#
###############################################################################

def zeropad_beats(record, rr_idxs, max_before, max_after, n_leads):
    """Cut an ECG into R-peak aligned beats of fixed length.

    Every beat keeps one third of the preceding RR interval before its
    R-peak and two thirds of the following RR interval after it, capped at
    ``max_before`` / ``max_after`` samples, and is zero-padded on both sides
    so that the R-peak always sits at row ``max_before``. For the first beat
    the distance to the start of the signal is used as preceding interval,
    for the last beat the distance to the end of the signal as following one.

    Args:
        record: Object with an ``ecg`` attribute, or the ECG array itself,
            indexed as (sample, lead).
        rr_idxs: Sample indices of the R-peaks.
        max_before: Number of rows before the R-peak in every beat.
        max_after: Number of rows from the R-peak onwards in every beat.
        n_leads: Number of leads, used for the zero padding.

    Returns:
        Array of shape (n_beats, max_before + max_after, n_leads).

    Raises:
        AssertionError: If a beat does not have the expected length or no
            3-D array results (e.g. ``rr_idxs`` is empty).
    """
    ecg = getattr(record, 'ecg', record)
    idxs = rr_idxs
    last = len(idxs) - 1
    beats = []
    for i, peak in enumerate(idxs):
        rr_interval_before = peak if i == 0 else peak - idxs[i - 1]
        if i == last:
            rr_interval_after = len(ecg) - peak
        else:
            rr_interval_after = idxs[i + 1] - peak

        len_before = min(int(np.round(rr_interval_before / 3)), max_before)
        len_after = min(int(np.round(2 * rr_interval_after / 3)), max_after)

        beat = np.concatenate([
            np.zeros((max_before - len_before, n_leads)),
            ecg[peak - len_before:peak + len_after, :],
            np.zeros((max_after - len_after, n_leads))
        ])
        assert beat.shape[-2] == max_before + max_after
        beats.append(beat)
    beats = np.array(beats)
    # ``record`` may be a bare array; read the diagnostics defensively so a
    # failing check reports AssertionError rather than AttributeError.
    subject = getattr(record, 'subject', None)
    assert len(beats.shape) == 3, f'{beats.shape}, {subject}, ' \
                                  f'{len(idxs)}, {idxs}, {ecg.shape}'
    return beats


def split_zeropad_save(
        data_iterator, data_folder, max_before=333, max_after=667,
        maximum_position_encoding=50, fs=500, return_beats=True, datasets=None,
        supervised=False
):
    """Split ECG recording into zeropadded heartbeats.

    Split every record at its stored R-peaks into beats aligned at the
    R-peak and zero-padded on both sides, and save them as NPZ files in
    ``data_folder / "fold<N>"``. Recordings with more than
    ``maximum_position_encoding`` beats are written as several files
    (``<subject>.npz``, ``<subject>_1.npz``, ...). The number of beats per
    file is merged into ``data_folder / "n_beats_per_recording.json"``.

    Args:
        data_iterator: an iterator of (fold index, record) pairs; records of
            the same fold are expected to be consecutive
        data_folder: path to datasets
        max_before: the maximum number of samples before the R-peak
        max_after: the maximum number of samples after the R-peak
        maximum_position_encoding: maximum number of beats per recording
        fs: sampling frequency (currently unused)
        return_beats: return constructed beats
        datasets: list of datasets to be separated into beats
        supervised: flag for supervised vs unsupervised training; when set
            only the first chunk of every recording is kept

    Returns:
        all_beats: one list per fold seen, each holding np.arrays of shape
                   (number_of_beats, fixed_heartbeat_length, n_leads).
                   None when return_beats is False.
    """
    if datasets is None:
        datasets = ["CPSC_1", "CPSC_2", "Georgia", "PTB", "PTB_XL",
                    "StPetersburg"]
    data_folder = pathlib.Path(data_folder)
    n_beats_per_recording = {}
    count = 0
    all_beats = [] if return_beats else None
    fold_before = -1
    for fold, record in data_iterator:
        if fold != fold_before:
            if return_beats:
                all_beats.append([])
            fold_before = fold
        if record.dataset not in datasets:
            continue
        if count % 500 == 0:
            print("Record", count, record.dataset)
        beats = zeropad_beats(
            record, record.r_peaks, max_before, max_after,
            record.ecg.shape[1]
        )
        fold_path = data_folder / ("fold" + str(fold))
        fold_path.mkdir(parents=True, exist_ok=True)
        # A shallow copy is enough: only ``ecg`` is rebound below, so the
        # full signal does not have to be duplicated.
        record_temp = copy.copy(record)
        start = 0
        for i in range((len(beats) - 1) // maximum_position_encoding + 1):
            suf = '' if not i else '_' + str(i)
            n_beats = min(maximum_position_encoding,
                          len(beats) - maximum_position_encoding * i)
            chunk = beats[start:start + n_beats]
            n_beats_per_recording[record.subject + suf] = n_beats
            record_temp.ecg = chunk
            record_temp.to_npz(
                fold_path / (record_temp.subject + suf + ".npz"))
            if return_beats:
                # The newest bucket belongs to the current fold; indexing by
                # fold number would fail when folds do not start at 0.
                all_beats[-1].append(chunk)
            start += n_beats
            # For supervised training with classification, only consider first
            # 50 beats
            if supervised:
                break
        count += 1
    json_path = data_folder / 'n_beats_per_recording.json'
    if json_path.is_file():
        # Entries from this run override those of earlier runs.
        n_beats_per_recording = {
            **read_json(json_path), **n_beats_per_recording
        }
    write_json(json_path, n_beats_per_recording)
    print("Success")
    if return_beats:
        return all_beats


###############################################################################
#
# From analyze_transformer.ipynb 22/03/21
#
###############################################################################

def get_classes(datasets):
    """Return the label codes of the requested datasets.

    Only "PTB_XL" and "Georgia" contribute labels; other names are ignored.

    Args:
        datasets: Container of dataset names.

    Returns:
        A sorted list of the unique label codes. Sorting keeps the class
        order identical between runs; plain ``list(set)`` order is not.
    """
    classes = set()
    if "PTB_XL" in datasets:
        classes |= PTB_XL_LABELS
    if "Georgia" in datasets:
        classes |= GEORGIA_LABELS
    return sorted(classes)


###############################################################################
#
# From cinc2020-challenge-9.train_12ECG_classifier
#
###############################################################################


def load_data(input_directory, output_directory, fs):
    """Load the training dataset, preprocessing the raw files if needed.

    When ``output_directory / "train_from_scratch_<fs>hz"`` exists the
    dataset is read from its NPZ files. Otherwise it is built from the MAT
    files in ``input_directory`` (5 folds, resampled to ``fs``, high pass
    filtered); the result is returned without being written to disk.

    Args:
        input_directory: Directory with the raw MAT/HEA files.
        output_directory: Directory that may hold the preprocessed dataset.
        fs: Sampling frequency in Hertz.

    Returns:
        A ``Dataset``.
    """
    # Submission layout: the cache lives below the output directory.
    cache_name = f"train_from_scratch_{fs!s}hz"
    dataset_directory = pathlib.Path(output_directory) / cache_name
    if dataset_directory.is_dir():
        return Dataset.from_npz(dataset_directory)

    print("Preprocessing data...")
    return Dataset.from_mat(
        input_directory, n_folds=5, fs=fs, hp_filter=True
    )
