Waveform Database Software Package (WFDB) for Python 4.1.0

File: <base>/tests/test_plot.py (6,402 bytes)
import unittest

import matplotlib.pyplot as plt
import numpy as np

import wfdb
from wfdb.plot import plot


class TestPlotWfdb(unittest.TestCase):
    """
    Tests for the wfdb.plot_wfdb function
    """

    def assertAxesMatchSignal(self, axes, signal, t_divisor=1):
        """
        Check that axis limits are reasonable for plotting a signal array.

        The X axis must cover sample 0 through the last sample of `signal`
        (scaled by `t_divisor`), and the Y axis must cover the range of the
        non-NaN values of `signal`.  On each axis the padding on both sides
        must be approximately equal.

        Parameters
        ----------
        axes : matplotlib.axes.Axes
            An Axes object.
        signal : numpy.ndarray
            A one-dimensional array of sample values.  NaN values are ignored
            when computing the expected Y range.
        t_divisor : float, optional
            The intended plotting resolution (number of samples of `signal`
            per unit of the X axis.)

        """
        xmin, xmax = axes.get_xlim()
        tmin = 0
        tmax = (len(signal) - 1) / t_divisor
        x_msg = f"X range is [{xmin}, {xmax}]; expected [{tmin}, {tmax}]"
        # The range from tmin to tmax should fit within the plot.
        self.assertLessEqual(xmin, tmin, msg=x_msg)
        self.assertGreaterEqual(xmax, tmax, msg=x_msg)
        # The padding on left and right sides should be approximately equal.
        # (xmin - tmin and tmax - xmax are the negated left/right margins.)
        self.assertAlmostEqual(
            xmin - tmin,
            tmax - xmax,
            delta=(tmax - tmin) / 10 + 1 / t_divisor,
            msg=x_msg,
        )

        ymin, ymax = axes.get_ylim()
        vmin = np.nanmin(signal)
        vmax = np.nanmax(signal)
        y_msg = f"Y range is [{ymin}, {ymax}]; expected [{vmin}, {vmax}]"
        # The range from vmin to vmax should fit within the plot.
        self.assertLessEqual(ymin, vmin, msg=y_msg)
        self.assertGreaterEqual(ymax, vmax, msg=y_msg)
        # The padding on top and bottom should be approximately equal.
        # For a constant signal (vmax == vmin) the relative tolerance would
        # be exactly zero, demanding bit-exact float equality of the two
        # margins; floor it at a tiny fraction of the displayed range so
        # rounding in the axis limits cannot cause a spurious failure.
        self.assertAlmostEqual(
            ymin - vmin,
            vmax - ymax,
            delta=max((vmax - vmin) / 10, (ymax - ymin) * 1e-6),
            msg=y_msg,
        )

    def test_physical_smooth(self):
        """
        Plot a record with physical, single-frequency data
        """
        record = wfdb.rdrecord(
            "sample-data/100",
            sampto=1000,
            physical=True,
            smooth_frames=True,
        )
        self.assertIsNotNone(record.p_signal)

        annotation = wfdb.rdann("sample-data/100", "atr", sampto=1000)

        fig = wfdb.plot_wfdb(
            record,
            annotation,
            time_units="samples",
            ecg_grids="all",
            return_fig=True,
        )
        # Close immediately; the assertions below only inspect the Figure
        # and Axes objects, which remain usable after closing.
        plt.close(fig)

        self.assertEqual(len(fig.axes), record.n_sig)
        for ch in range(record.n_sig):
            with self.subTest(channel=ch):
                # time_units="samples": one X unit per sample.
                self.assertAxesMatchSignal(fig.axes[ch], record.p_signal[:, ch])

    def test_digital_smooth(self):
        """
        Plot a record with digital, single-frequency data
        """
        record = wfdb.rdrecord(
            "sample-data/drive02",
            sampto=1000,
            physical=False,
            smooth_frames=True,
        )
        self.assertIsNotNone(record.d_signal)

        fig = wfdb.plot_wfdb(record, time_units="seconds", return_fig=True)
        plt.close(fig)

        self.assertEqual(len(fig.axes), record.n_sig)
        for ch in range(record.n_sig):
            with self.subTest(channel=ch):
                # time_units="seconds": record.fs samples per X unit.
                self.assertAxesMatchSignal(
                    fig.axes[ch], record.d_signal[:, ch], record.fs
                )

    def test_physical_multifrequency(self):
        """
        Plot a record with physical, multi-frequency data
        """
        record = wfdb.rdrecord(
            "sample-data/wave_4",
            sampto=10,
            physical=True,
            smooth_frames=False,
        )
        self.assertIsNotNone(record.e_p_signal)

        fig = wfdb.plot_wfdb(record, time_units="seconds", return_fig=True)
        plt.close(fig)

        self.assertEqual(len(fig.axes), record.n_sig)
        for ch in range(record.n_sig):
            with self.subTest(channel=ch):
                # Each channel has its own rate: fs * samples-per-frame.
                self.assertAxesMatchSignal(
                    fig.axes[ch],
                    record.e_p_signal[ch],
                    record.fs * record.samps_per_frame[ch],
                )

    def test_digital_multifrequency(self):
        """
        Plot a record with digital, multi-frequency data
        """
        record = wfdb.rdrecord(
            "sample-data/multi-segment/041s/041s",
            sampto=1000,
            physical=False,
            smooth_frames=False,
        )
        self.assertIsNotNone(record.e_d_signal)

        fig = wfdb.plot_wfdb(record, time_units="seconds", return_fig=True)
        plt.close(fig)

        self.assertEqual(len(fig.axes), record.n_sig)
        for ch in range(record.n_sig):
            with self.subTest(channel=ch):
                # Each channel has its own rate: fs * samples-per-frame.
                self.assertAxesMatchSignal(
                    fig.axes[ch],
                    record.e_d_signal[ch],
                    record.fs * record.samps_per_frame[ch],
                )


class TestPlotInternal(unittest.TestCase):
    """
    Unit tests for internal wfdb.plot.plot functions
    """

    def test_get_plot_dims(self):
        """
        _get_plot_dims reports the signal length, channel count, number of
        annotation channels and subplot count for a record slice.
        """
        sampfrom = 0
        sampto = 3000
        record = wfdb.rdrecord(
            "sample-data/100", physical=True, sampfrom=sampfrom, sampto=sampto
        )
        ann = wfdb.rdann(
            "sample-data/100", "atr", sampfrom=sampfrom, sampto=sampto
        )
        sig_len, n_sig, n_annot, n_subplots = plot._get_plot_dims(
            signal=record.p_signal, ann_samp=[ann.sample]
        )

        # Use TestCase assertions rather than bare `assert`: bare asserts are
        # stripped under `python -O`, which would make this test pass
        # vacuously, and they give less informative failure messages.
        self.assertEqual(sig_len, sampto - sampfrom)
        self.assertEqual(n_sig, record.n_sig)
        self.assertEqual(n_annot, 1)
        self.assertEqual(n_subplots, record.n_sig)

    def _check_create_figure(self, n_subplots):
        """
        Shared body for the _create_figure tests: create a figure with
        `n_subplots` subplots and check one Axes is returned per subplot.
        """
        fig, axes = plot._create_figure(
            n_subplots, sharex=True, sharey=True, figsize=None
        )
        self.assertIsNotNone(fig)
        # Close the figure when the test finishes so figures don't accumulate
        # across the test session.  NOTE(review): assumes _create_figure
        # registers the figure with pyplot; plt.close is a no-op otherwise.
        self.addCleanup(plt.close, fig)
        self.assertIsNotNone(axes)
        self.assertEqual(len(axes), n_subplots)

    def test_create_figure_single_subplots(self):
        """
        _create_figure returns one Axes when asked for a single subplot.
        """
        self._check_create_figure(1)

    def test_create_figure_multiple_subplots(self):
        """
        _create_figure returns one Axes per subplot for several subplots.
        """
        self._check_create_figure(5)


if __name__ == "__main__":
    # Allow running this file directly (`python test_plot.py`); unittest
    # discovers and runs every TestCase in the __main__ module.
    unittest.main(verbosity=1)