Waveform Database Software Package (WFDB) for Python 4.1.0

File: <base>/tests/test_processing.py (4,705 bytes)
import numpy as np

import wfdb
from wfdb import processing


class test_processing:
    """
    Test processing functions
    """

    def test_resample_single(self):
        sig, fields = wfdb.rdsamp("sample-data/100")
        ann = wfdb.rdann("sample-data/100", "atr")

        fs = fields["fs"]
        fs_target = 50

        new_sig, new_ann = processing.resample_singlechan(
            sig[:, 0], ann, fs, fs_target
        )

        expected_length = int(sig.shape[0] * fs_target / fs)

        assert new_sig.shape[0] == expected_length

    def test_resample_multi(self):
        sig, fields = wfdb.rdsamp("sample-data/100")
        ann = wfdb.rdann("sample-data/100", "atr")

        fs = fields["fs"]
        fs_target = 50

        new_sig, new_ann = processing.resample_multichan(
            sig, ann, fs, fs_target
        )

        expected_length = int(sig.shape[0] * fs_target / fs)

        assert new_sig.shape[0] == expected_length
        assert new_sig.shape[1] == sig.shape[1]

    def test_normalize_bound(self):
        sig, _ = wfdb.rdsamp("sample-data/100")
        lb = -5
        ub = 15

        x = processing.normalize_bound(sig[:, 0], lb, ub)
        assert x.shape[0] == sig.shape[0]
        assert np.min(x) >= lb
        assert np.max(x) <= ub

    def test_find_peaks(self):
        x = [0, 2, 1, 0, -10, -15, -15, -15, 9, 8, 0, 0, 1, 2, 10]
        hp, sp = processing.find_peaks(x)
        assert np.array_equal(hp, [1, 8])
        assert np.array_equal(sp, [6, 10])

    def test_find_peaks_empty(self):
        x = []
        hp, sp = processing.find_peaks(x)
        assert hp.shape == (0,)
        assert sp.shape == (0,)

    def test_gqrs(self):

        record = wfdb.rdrecord(
            "sample-data/100",
            channels=[0],
            sampfrom=9998,
            sampto=19998,
            physical=False,
        )

        expected_peaks = [
            271,
            580,
            884,
            1181,
            1469,
            1770,
            2055,
            2339,
            2634,
            2939,
            3255,
            3551,
            3831,
            4120,
            4412,
            4700,
            5000,
            5299,
            5596,
            5889,
            6172,
            6454,
            6744,
            7047,
            7347,
            7646,
            7936,
            8216,
            8503,
            8785,
            9070,
            9377,
            9682,
        ]

        peaks = processing.gqrs_detect(
            d_sig=record.d_signal[:, 0],
            fs=record.fs,
            adc_gain=record.adc_gain[0],
            adc_zero=record.adc_zero[0],
            threshold=1.0,
        )

        assert np.array_equal(peaks, expected_peaks)

    def test_correct_peaks(self):
        sig, fields = wfdb.rdsamp("sample-data/100")
        ann = wfdb.rdann("sample-data/100", "atr")
        fs = fields["fs"]
        min_bpm = 10
        max_bpm = 350
        min_gap = fs * 60 / min_bpm
        max_gap = fs * 60 / max_bpm

        y_idxs = processing.correct_peaks(
            sig=sig[:, 0],
            peak_inds=ann.sample,
            search_radius=int(max_gap),
            smooth_window_size=150,
        )

        yz = np.zeros(sig.shape[0])
        yz[y_idxs] = 1
        yz = np.where(yz[:10000] == 1)[0]

        expected_peaks = [
            77,
            370,
            663,
            947,
            1231,
            1515,
            1809,
            2045,
            2403,
            2706,
            2998,
            3283,
            3560,
            3863,
            4171,
            4466,
            4765,
            5061,
            5347,
            5634,
            5919,
            6215,
            6527,
            6824,
            7106,
            7393,
            7670,
            7953,
            8246,
            8539,
            8837,
            9142,
            9432,
            9710,
            9998,
        ]

        assert np.array_equal(yz, expected_peaks)


class test_qrs:
    """
    Testing QRS detectors
    """

    def test_xqrs(self):
        """
        Run XQRS detector on record 100 and compare to reference annotations
        """
        sig, fields = wfdb.rdsamp("sample-data/100", channels=[0])
        ann_ref = wfdb.rdann("sample-data/100", "atr")

        xqrs = processing.XQRS(sig=sig[:, 0], fs=fields["fs"])
        xqrs.detect()

        comparitor = processing.compare_annotations(
            ann_ref.sample[1:], xqrs.qrs_inds, int(0.1 * fields["fs"])
        )

        assert comparitor.sensitivity > 0.99
        assert comparitor.positive_predictivity > 0.99