PADS - Parkinsons Disease Smartwatch dataset 1.0.0

File: <base>/scripts/create_csv_stratified_subset.py (5,684 bytes)
import warnings
from utils.constants import patient_dir
from utils.data_handling import load_all_files
import numpy as np
import pandas as pd

data_path = '../preprocessed/'

# Store file list for ml project
df = pd.concat(load_all_files(patient_dir))
df['label'] = df['condition']
df.replace({'label': {'Healthy': 0,
                      "Parkinson's": 1,
                      'Other Movement Disorders': 2,
                      'Essential Tremor': 2,
                      'Multiple Sclerosis': 2,
                      'Atypical Parkinsonism': 2}},
           inplace=True)

df.sort_values(by=['age'], inplace=True)

# The same number of samples from each class with each gender
ref_data1 = ['Female', 'Male',
             'Female', 'Male',
             'Female', 'Male']
ref_data2 = ['0', '0',
             '1', '1',
             '2', '2']
ref_data = np.array([ref_data1, ref_data2])

data = df.loc[:, ['gender', 'label']].values.transpose()


def _match_groups_encoding(data, ref_data):
    """
    Match group encoding: if e.g. one group has elements 'A' and 'C', and the other group 'A', 'B', 'C', the numerical
    encoding will be adjusted so that group B is counted with 0 values in the one group.
    """
    data_dict = set(data).union(set(ref_data))
    data_dict = list(data_dict)
    data_dict.sort()
    data = pd.Series(data).replace(data_dict, np.arange(len(data_dict))).values
    ref_data = pd.Series(ref_data).replace(data_dict, np.arange(len(data_dict))).values
    return data, ref_data


def _get_overlap_groups(data, ref_data):
    """
    Transform the elements of data to string. Concatenate the rows and join the string values to create new group
    labels.
    """
    data = np.array(data, dtype=str)
    ref_data = np.array(ref_data, dtype=str)
    data = list(map("_".join, zip(*data)))
    ref_data = list(map("_".join, zip(*ref_data)))
    data, ref_data = _match_groups_encoding(data, ref_data)
    return data, ref_data


def _get_groups_weights(data, ref_data):
    """
    Get the counts for each discrete group in the dataset and the resulting relative weights that should be applied to
    the samples in the original data to match the distribution in ref_data.
    """
    data = np.array(data, dtype=str)
    ref_data = np.array(ref_data, dtype=str)
    # data, ref_data = _match_groups_encoding(data, ref_data)
    data = np.array(data, dtype=int)
    if len(data.shape) > 1:
        raise ValueError(f'Wrong input: data must be 1d list or 1d ndarray. Passed: {data.shape}.')
    ref_data = np.array(ref_data, dtype=int)
    if len(ref_data.shape) > 1:
        raise ValueError(f'Wrong input: ref_data must be 1d list or 1d ndarray. Passed: {ref_data.shape}.')
    minlength = int(np.max([np.max(ref_data), np.max(data)]) + 1)
    counts = np.bincount(data, minlength=minlength)
    ref_counts = np.bincount(ref_data, minlength=minlength)
    weights = ref_counts / counts
    weights = np.nan_to_num(weights, nan=0, posinf=0, neginf=0)
    weights_per_sample = weights[data]
    return weights_per_sample, weights, counts, ref_counts


def _exact_stratified_sampling(data, weights, random_state=42):
    """
    Exactly match the objective distribution by randomly choosing the number of samples from each group that is coded
    as relative portions of the original length in weights.
    IMPORTANT NOTE: This method is only approximately giving the right result, as the values may become complex float
    numbers and slight numerical errors occur due to rounding etc.
    """
    num_to_sample = (weights / np.max(weights))
    num_to_sample = num_to_sample.astype(np.float32)
    rng = np.random.default_rng(random_state)
    idxs = []
    for n, frac in enumerate(num_to_sample):
        sub_idxs = np.argwhere(data == n)[:, 0]
        size = int(len(sub_idxs) * frac)
        random_idxs = rng.choice(sub_idxs, size=size, replace=False)
        idxs.extend(random_idxs)
    return idxs


def _iterative_stratified_sampling(data, ref_weights, frac=0.8, random_state=42):
    """
    Match the objective distribution by iteratively removing sample indices from the original data. The sampling is
    performed by calculating the group that has the highest difference when subtracting the relative proportion of the
    group in the original data set to the reference weights.
    """
    ref_weights = ref_weights / np.sum(ref_weights)
    rng = np.random.default_rng(random_state)
    orig_len = len(data)
    new_len = frac * orig_len
    idxs = np.arange(orig_len)
    while len(data) > new_len:
        weights = np.bincount(data, minlength=len(ref_weights))
        weights = weights / np.sum(weights)
        weights_diff = weights - ref_weights
        if not weights_diff.any():
            warnings.warn(f'Given frac {frac} is smaller than actual match {len(data) / orig_len}', UserWarning)
            break
        label_to_remove = np.argmax(weights_diff)
        idx_to_remove = np.where(data == label_to_remove)[0]
        idx_to_remove = rng.choice(idx_to_remove, size=1, replace=False)
        data = np.delete(data, idx_to_remove)
        idxs = np.delete(idxs, idx_to_remove)
    return idxs


data, ref_data = _get_overlap_groups(data, ref_data)

weights_per_sample, weights, counts, ref_counts = _get_groups_weights(data, ref_data)

# idxs = _exact_stratified_sampling(data, weights)

idxs = _iterative_stratified_sampling(data, ref_counts, frac=0.5)
idxs = idxs.tolist()

df = df.iloc[idxs, :]
df.to_csv(f'{data_path}stratified_subset_file_list.csv', index=False, sep=',')