import sys
import os
import shutil
import wfdb
import numpy as np
from scipy.io import savemat
import tkinter as tk
from tkinter import filedialog
# =============================================================================
# README: WFDB to MAT Conversion Script (GRABMyoFlow static Dataset)
# =============================================================================
# This script processes raw WFDB data files (.dat, .hea) from the GRABMyo
# dataset (43 subjects) and/or the extension from the GRABMyoFlow dataset.
#
# UPDATED LOGIC:
# - Uses Global IDs (1-63) for BOTH folder names and filenames.
# - Extension files (44-63) are processed to match the original format.
# - Output .mat files contain DATA_FOREARM (16ch) and DATA_WRIST (12ch) cell
#   arrays (trials x gestures, one signal matrix per cell); the DATA_FOREARM
#   matrices are empty for extension subjects.
# =============================================================================
# =============================================================================
# STEP 0: CONSTANTS AND INITIAL SETUP
# =============================================================================
# Create a Tk root window for the folder dialogs used below and hide it
# immediately, so only the dialogs themselves appear on screen.
root = tk.Tk()
root.withdraw()
# --- Dataset Constants ---
NSESSION = 3         # recording sessions per participant
NGEESTURE = 16       # gesture count; the conversion loop uses NGEESTURE + 1 gesture slots
NTRIALS = 7          # trials per gesture
NCH_ORIGINAL = 32    # channels per record in the original GRABMyo dataset
NSUB_ORIGINAL = 43   # global IDs 1..43 belong to the original dataset
NSUB_EXTENSION = 20  # global IDs 44..63 belong to the extension
NSUB_TOTAL = NSUB_ORIGINAL + NSUB_EXTENSION
# --- Original Dataset Masks (32 Channels) ---
# Keeps 16 Forearm channels (indices 0-15)
FOREARM_INDICES_ORIGINAL = list(range(0, 16))
# Keeps 12 Wrist channels (Skipping indices 16, 23, 24, 31)
WRIST_INDICES_ORIGINAL = list(range(17, 23)) + list(range(25, 31))
# Boolean column masks over the 32 original channels.
forearm_mask_orig = np.zeros(NCH_ORIGINAL, dtype=bool)
forearm_mask_orig[FOREARM_INDICES_ORIGINAL] = True
wrist_mask_orig = np.zeros(NCH_ORIGINAL, dtype=bool)
wrist_mask_orig[WRIST_INDICES_ORIGINAL] = True
# --- Extension Dataset Mask (16 Channels) ---
# Keeps 12 Wrist channels (Skipping indices 0, 7, 8, 15). These are the same
# relative positions the original wrist mask drops (16, 23, 24, 31 = 16 + 0/7/8/15).
EXTENSION_INDICES_KEEP = list(range(1, 7)) + list(range(9, 15))
# =============================================================================
# STEP 1: USER CHOICE AND PATH SELECTION
# =============================================================================
# 'W' converts the original participants plus the extension, 'E' converts only
# the extension. Surrounding whitespace in the answer is ignored.
WHOLE_DATASET = input("Convert whole dataset (W) or only reconstructed extension (E)? ").strip().upper()
if WHOLE_DATASET not in ('W', 'E'):
    print("Invalid input. Exiting Script!")
    sys.exit()
WHOLE_DATASET_FLAG = (WHOLE_DATASET == 'W')
# DATA_PATHS[0] is the original dataset root (whole mode only); the extension
# folder is always appended last, so DATA_PATHS[-1] is the extension in both modes.
DATA_PATHS = []
if WHOLE_DATASET_FLAG:
    print("Please select the ROOT folder for the original 43 participants ('.../grabmyo-1.1.0/')...")
    gm_BASE_INPUT = filedialog.askdirectory(title="Select Original Dataset ROOT Folder")
    if not gm_BASE_INPUT:
        print("No folder selected. Exiting Script!")
        sys.exit()
    DATA_PATHS.append(gm_BASE_INPUT)
    OUTPUT_FOLDER_NAME = "static_whole_MATLAB"
    print("\n--- Next: Select Extension Dataset Path ---")
else:
    OUTPUT_FOLDER_NAME = "static_extension_MATLAB"
# The extension folder is required in both modes. The backslash is escaped
# ("\\s") because "\s" is an invalid escape sequence (SyntaxWarning on
# Python >= 3.12); the printed text is unchanged.
print("Please select the folder containing the extension data ('...GRABMyoFlow_1.0\\static_extension_WFDB')...")
ext_BASE_INPUT = filedialog.askdirectory(title="Select Extension Data (static_extension_WFDB) Folder")
if not ext_BASE_INPUT:
    print("No folder selected. Exiting Script!")
    sys.exit()
DATA_PATHS.append(ext_BASE_INPUT)
# =============================================================================
# STEP 2: OUTPUT CONFIGURATION AND RANGE DEFINITION
# =============================================================================
# The output folder is created next to the selected extension folder
# (assumes the dialog path has no trailing separator).
OUTPUT_ROOT = os.path.join(os.path.split(ext_BASE_INPUT)[0], OUTPUT_FOLDER_NAME)
# Global participant IDs to convert: 1..NSUB_TOTAL in whole-dataset mode,
# otherwise only the extension IDs NSUB_ORIGINAL+1..NSUB_TOTAL.
sub_range = range(1 if WHOLE_DATASET_FLAG else NSUB_ORIGINAL + 1, NSUB_TOTAL + 1)
# =============================================================================
# STEP 3: OUTPUT FOLDER HANDLING (Overwrite check)
# =============================================================================
# An existing output folder is deleted and recreated only after an explicit
# 'Y'; 'N' aborts the script, and any other answer repeats the prompt.
if os.path.exists(OUTPUT_ROOT):
    while True:
        print(f"Found existing folder in: {OUTPUT_ROOT}")
        cont = input("Overwrite it (Y/N)? ").upper()
        if cont == 'N':
            print("Exiting Script!")
            sys.exit()
        if cont == 'Y':
            print("Overwriting")
            shutil.rmtree(OUTPUT_ROOT)
            os.makedirs(OUTPUT_ROOT)
            break
else:
    os.makedirs(OUTPUT_ROOT)
    print(f"Created output folder: {OUTPUT_ROOT}")
# =============================================================================
# STEP 4: MAIN CONVERSION LOOP
# =============================================================================
# Tally of converted participants; incremented inside the conversion loop below.
count = 0
def create_empty_matrix(rows, cols):
    """Return a ``rows`` x ``cols`` object array whose cells each hold a fresh empty array.

    Every cell gets its own ``np.array([])`` instance, so cells can later be
    replaced individually by per-trial signal matrices. ``scipy.io.savemat``
    stores object arrays like this as MATLAB cell arrays.
    """
    grid = np.empty((rows, cols), dtype=object)
    for position in np.ndindex(grid.shape):
        grid[position] = np.array([])
    return grid
def _locate_participant(isession, isub_global):
    """Return ``(participant_dir, is_original_subject)`` for one session/participant.

    IDs up to NSUB_ORIGINAL are read from the original dataset root
    (``DATA_PATHS[0]``, capitalised ``Session{n}`` folders); higher IDs are read
    from the extension folder (``DATA_PATHS[-1]``, lower-case ``session{n}``).
    """
    is_original_subject = isub_global <= NSUB_ORIGINAL
    if is_original_subject:
        # --- ORIGINAL DATASET (1-43) ---
        base_dir = DATA_PATHS[0]
        session_dir_name = f"Session{isession}"
    else:
        # --- EXTENSION DATASET (44-63) ---
        base_dir = DATA_PATHS[-1]
        session_dir_name = f"session{isession}"
    foldername = f"session{isession}_participant{isub_global}"
    return os.path.join(base_dir, session_dir_name, foldername), is_original_subject


def _split_channels(data_emg, is_original_subject, record_name):
    """Split one record's signal matrix into ``(forearm, wrist)`` matrices.

    ``data_emg`` is indexed as (samples, channels). Original records are split
    with the 32-channel forearm/wrist masks. For extension records the forearm
    part is a zero-column matrix and the wrist part keeps EXTENSION_INDICES_KEEP;
    a record with fewer than 16 channels is reported and kept whole.
    Indexing errors propagate to the caller.
    """
    if is_original_subject:
        # Original Logic: Split 32ch into Forearm(16) and Wrist(12)
        return data_emg[:, forearm_mask_orig], data_emg[:, wrist_mask_orig]
    # Extension Logic: Split 16ch (Wrist only) to 12ch; forearm is empty.
    data_forearm = np.empty((data_emg.shape[0], 0))
    if data_emg.shape[1] >= 16:
        return data_forearm, data_emg[:, EXTENSION_INDICES_KEEP]
    print(f"{record_name}: Unexpected channel count ({data_emg.shape[1]}ch). Using all available channels for wrist.")
    return data_forearm, data_emg


for isession in range(1, NSESSION + 1):
    output_session_folder = f"Session{isession}"
    os.makedirs(os.path.join(OUTPUT_ROOT, output_session_folder), exist_ok=True)
    for isub_global in sub_range:
        participant_dir, is_original_subject = _locate_participant(isession, isub_global)
        if not os.path.exists(participant_dir):
            print(f"[SKIP] Folder not found: {participant_dir}")
            continue
        # Cells are (trial, gesture); missing or unreadable trials stay empty.
        matrices_forearm = create_empty_matrix(NTRIALS, NGEESTURE + 1)
        matrices_wrist = create_empty_matrix(NTRIALS, NGEESTURE + 1)
        for igesture in range(1, NGEESTURE + 2):
            for itrial in range(1, NTRIALS + 1):
                filename = f"session{isession}_participant{isub_global}_gesture{igesture}_trial{itrial}"
                wfdb_record_path = os.path.join(participant_dir, filename)
                # Skip only when neither file exists; a partial record is still
                # attempted so that the read error gets reported below.
                if not (os.path.exists(wfdb_record_path + ".dat") or os.path.exists(wfdb_record_path + ".hea")):
                    continue
                try:
                    record = wfdb.rdrecord(wfdb_record_path)
                    data_forearm, data_wrist = _split_channels(record.p_signal, is_original_subject, filename)
                    matrices_forearm[itrial - 1, igesture - 1] = data_forearm
                    matrices_wrist[itrial - 1, igesture - 1] = data_wrist
                except Exception as e:
                    # Best-effort: one bad record must not abort the whole conversion.
                    print(f"[ERROR] Reading {filename}: {e}")
        # =====================================================================
        # STEP 5: SAVE OUTPUT
        # =====================================================================
        mat_filename = f"session{isession}_participant{isub_global}.mat"
        try:
            savemat(os.path.join(OUTPUT_ROOT, output_session_folder, mat_filename),
                    {"DATA_FOREARM": matrices_forearm,
                     "DATA_WRIST": matrices_wrist})
        except Exception as e:
            print(f"[FATAL ERROR] Could not save {mat_filename}: {e}")
            continue
        # Report success (and count it) only once the file has actually been written.
        count += 1
        print(f"[OK] Converted participant {isub_global} of session {isession}")
print(f"\nSaved {count} participant file(s).")
print("\nEnd of Script.")