GRABMyoFlow - Dataset extension 1.0.0

File: <base>/grabmyoflow_wfdb_to_mat_static.py (8,640 bytes)
import sys
import os
import shutil
import wfdb
import numpy as np
from scipy.io import savemat
import tkinter as tk
from tkinter import filedialog

# =============================================================================
# README: WFDB to MAT Conversion Script (GRABMyoFlow static Dataset)
# =============================================================================
# This script processes raw WFDB data files (.dat, .hea) from the GRABMyo
# dataset (43 subjects) and/or the extension from the GRABMyoFlow dataset.
#
# UPDATED LOGIC:
# - Uses Global IDs (1-63) for BOTH folder names and filenames.
# - Extension files (44-63) are processed to match the original format.
# - Output .mat files contain DATA_FOREARM (16ch) and DATA_WRIST (12ch) matrices,
#   where the DATA_FOREARM matrices for extension subjects are empty.
# =============================================================================

# =============================================================================
# STEP 0: CONSTANTS AND INITIAL SETUP
# =============================================================================

# Create the Tk root window and hide it right away: the script only uses the
# folder-selection dialogs opened further down, not a main application window.
root = tk.Tk()
root.withdraw()

# --- Dataset Constants ---
# Number of sessions, gestures and trials used to size the loops and matrices.
NSESSION, NGEESTURE, NTRIALS = 3, 16, 7
# Channel count of one raw record of the original dataset.
NCH_ORIGINAL = 32

# Subject counts; global subject IDs run from 1 to NSUB_TOTAL.
NSUB_ORIGINAL, NSUB_EXTENSION = 43, 20
NSUB_TOTAL = NSUB_ORIGINAL + NSUB_EXTENSION

# --- Original Dataset Masks (32 Channels) ---
# Forearm: the first 16 channels (indices 0-15).
FOREARM_INDICES_ORIGINAL = [*range(16)]
# Wrist: 12 channels, indices 17-22 and 25-30 (16, 23, 24 and 31 are dropped).
WRIST_INDICES_ORIGINAL = [*range(17, 23), *range(25, 31)]

# Boolean column selectors over the 32 raw channels, True where a channel is kept.
forearm_mask_orig = np.isin(np.arange(NCH_ORIGINAL), FOREARM_INDICES_ORIGINAL)
wrist_mask_orig = np.isin(np.arange(NCH_ORIGINAL), WRIST_INDICES_ORIGINAL)

# --- Extension Dataset Mask (16 Channels) ---
# Wrist: 12 channels, indices 1-6 and 9-14 (0, 7, 8 and 15 are dropped).
EXTENSION_INDICES_KEEP = [*range(1, 7), *range(9, 15)]


# =============================================================================
# STEP 1: USER CHOICE AND PATH SELECTION
# =============================================================================
# Ask which part of the dataset to convert. Surrounding whitespace is ignored
# so that an answer such as "w " is still accepted.
WHOLE_DATASET = input("Convert whole dataset (W) or only reconstructed extension (E)? ").strip().upper()
if WHOLE_DATASET not in ('W', 'E'):
    print("Invalid input. Exiting Script!")
    sys.exit()
WHOLE_DATASET_FLAG = (WHOLE_DATASET == 'W')

# Input folders in selection order: the original dataset root (only in
# whole-dataset mode) followed by the extension folder, which is always last.
DATA_PATHS = []

if WHOLE_DATASET_FLAG:
    print("Please select the ROOT folder for the original 43 participants ('.../grabmyo-1.1.0/')...")
    gm_BASE_INPUT = filedialog.askdirectory(title="Select Original Dataset ROOT Folder")
    if not gm_BASE_INPUT:
        # An empty result means the dialog was closed without a selection.
        print("No folder selected. Exiting Script!")
        sys.exit()
    DATA_PATHS.append(gm_BASE_INPUT)

    OUTPUT_FOLDER_NAME = "static_whole_MATLAB"
    print("\n--- Next: Select Extension Dataset Path ---")
else:
    OUTPUT_FOLDER_NAME = "static_extension_MATLAB"

# The backslash is escaped explicitly: a bare "\s" is an invalid escape sequence
# and triggers a warning on recent Python versions. The printed text is unchanged.
print("Please select the folder containing the extension data ('...GRABMyoFlow_1.0\\static_extension_WFDB')...")
ext_BASE_INPUT = filedialog.askdirectory(title="Select Extension Data (static_extension_WFDB) Folder")

if not ext_BASE_INPUT:
    print("No folder selected. Exiting Script!")
    sys.exit()
DATA_PATHS.append(ext_BASE_INPUT)


# =============================================================================
# STEP 2: OUTPUT CONFIGURATION AND RANGE DEFINITION
# =============================================================================

# The output folder is created next to (i.e. in the parent directory of) the
# selected extension folder.
OUTPUT_ROOT = os.path.join(os.path.dirname(ext_BASE_INPUT), OUTPUT_FOLDER_NAME)

# Global subject IDs to convert: 1 to 63 for the whole dataset,
# 44 to 63 when only the extension is converted.
sub_range = range(1 if WHOLE_DATASET_FLAG else NSUB_ORIGINAL + 1, NSUB_TOTAL + 1)


# =============================================================================
# STEP 3: OUTPUT FOLDER HANDLING (Overwrite check)
# =============================================================================
if os.path.exists(OUTPUT_ROOT):
    # Keep asking until the user answers Y (wipe and recreate) or N (abort).
    cont = None
    while cont != 'Y':
        print(f"Found existing folder in: {OUTPUT_ROOT}")
        cont = input("Overwrite it (Y/N)? ").upper()
        if cont == 'N':
            print("Exiting Script!")
            sys.exit()
    # Only reached after an explicit 'Y': start from an empty output folder.
    print("Overwriting")
    shutil.rmtree(OUTPUT_ROOT)
    os.makedirs(OUTPUT_ROOT)
else:
    os.makedirs(OUTPUT_ROOT)
    print(f"Created output folder: {OUTPUT_ROOT}")


# =============================================================================
# STEP 4: MAIN CONVERSION LOOP
# =============================================================================
count = 0


def create_empty_matrix(rows, cols):
    """Return a (rows, cols) object array whose cells each hold their own empty array.

    Every cell is pre-filled so that positions which are never overwritten
    still contain an (empty) numpy array rather than None.
    """
    grid = np.empty((rows, cols), dtype=object)
    for cell in np.ndindex(rows, cols):
        # A fresh array per cell; the cells do not share one object.
        grid[cell] = np.array([])
    return grid

# Minimum channel count an extension record must have for the 12-channel
# wrist selection (EXTENSION_INDICES_KEEP) to be applied.
NCH_EXTENSION = 16

for isession in range(1, NSESSION + 1):
    # One output sub-folder per session: <OUTPUT_ROOT>/Session<k>
    output_session_folder = f"Session{isession}"
    os.makedirs(os.path.join(OUTPUT_ROOT, output_session_folder), exist_ok=True)

    for isub_global in sub_range:

        is_original_subject = (isub_global <= NSUB_ORIGINAL)

        if is_original_subject:
            # --- ORIGINAL DATASET (1-43) ---
            # The original root is the first selected path; its session
            # folders are capitalised ("Session1").
            base_dir = DATA_PATHS[0]
            session_dir_name = f"Session{isession}"
        else:
            # --- EXTENSION DATASET (44-63) ---
            # The extension folder is always the last selected path; its
            # session folders are lower-case ("session1").
            base_dir = DATA_PATHS[-1]
            session_dir_name = f"session{isession}"

        foldername = f"session{isession}_participant{isub_global}"
        participant_dir = os.path.join(base_dir, session_dir_name, foldername)

        if not os.path.exists(participant_dir):
            print(f"[SKIP] Folder not found: {participant_dir}")
            continue

        # Rows are trials, columns are gesture slots (NGEESTURE + 1 of them).
        # Cells of records that are missing or unreadable stay empty.
        matrices_forearm = create_empty_matrix(NTRIALS, NGEESTURE + 1)
        matrices_wrist = create_empty_matrix(NTRIALS, NGEESTURE + 1)

        for igesture in range(1, NGEESTURE + 2):
            for itrial in range(1, NTRIALS + 1):

                filename = f"session{isession}_participant{isub_global}_gesture{igesture}_trial{itrial}"
                wfdb_record_path = os.path.join(participant_dir, filename)

                # Skip silently only when neither file of the record exists.
                # A record with just one of the two files falls through and is
                # presumably rejected by the reader, i.e. reported as [ERROR].
                has_dat = os.path.exists(wfdb_record_path + ".dat")
                has_hea = os.path.exists(wfdb_record_path + ".hea")
                if not has_dat and not has_hea:
                    continue

                try:
                    record = wfdb.rdrecord(wfdb_record_path)
                    data_emg = record.p_signal

                    if is_original_subject:
                        # Original Logic: Split 32ch into Forearm(16) and Wrist(12)
                        data_forearm = data_emg[:, forearm_mask_orig]
                        data_wrist = data_emg[:, wrist_mask_orig]
                    else:
                        # Extension Logic: Split 16ch (Wrist only) to 12ch.
                        # Forearm is empty (zero columns) for extension subjects
                        data_forearm = np.empty((data_emg.shape[0], 0))

                        if data_emg.shape[1] >= NCH_EXTENSION:
                            data_wrist = data_emg[:, EXTENSION_INDICES_KEEP]
                        else:
                            print(f"{filename}: Unexpected channel count ({data_emg.shape[1]}ch). Using all available channels for wrist.")
                            data_wrist = data_emg

                    matrices_forearm[itrial - 1, igesture - 1] = data_forearm
                    matrices_wrist[itrial - 1, igesture - 1] = data_wrist

                except Exception as e:
                    # Best effort: one unreadable record must not abort the
                    # whole conversion; its cell simply stays empty.
                    print(f"[ERROR] Reading {filename}: {e}")

        # =============================================================================
        # STEP 5: SAVE OUTPUT
        # =============================================================================
        mat_filename = f"session{isession}_participant{isub_global}.mat"
        try:
            savemat(os.path.join(OUTPUT_ROOT, output_session_folder, mat_filename),
                    {"DATA_FOREARM": matrices_forearm,
                     "DATA_WRIST":   matrices_wrist})
        except Exception as e:
            print(f"[FATAL ERROR] Could not save {mat_filename}: {e}")
        else:
            # Report success and count the file only once it is actually written.
            count += 1
            print(f"[OK] Converted participant {isub_global} of session {isession}")

print(f"\nSaved {count} .mat file(s).")
print("End of Script.")