'''
    Splits a dataset into a training set and a test set.
    Usage:
        python split_dataset.py --source [dataset directory] --split [test set ratio (0~1)]
'''

import argparse
import os
import sys
from shutil import copyfile

import numpy as np

# Every record consists of one file per extension listed here.
_RECORD_EXTENSIONS = (".hea", ".mat")


def _list_record_ids(data_directory):
    """Return the sorted IDs of records that have every required file.

    A record ID is a file name without its extension. Only regular files
    ending in one of ``_RECORD_EXTENSIONS`` are considered, so unrelated
    files and sub-directories in the dataset folder are ignored. Records
    missing one of the required files are reported on stderr and skipped,
    because copying them later would fail half-way through the split.
    """
    stems_by_extension = {extension: set() for extension in _RECORD_EXTENSIONS}
    for entry in os.listdir(data_directory):
        # splitext only strips the last suffix, so IDs containing dots survive.
        stem, extension = os.path.splitext(entry)
        if extension not in stems_by_extension:
            continue
        if os.path.isfile(os.path.join(data_directory, entry)):
            stems_by_extension[extension].add(stem)

    complete = set.intersection(*stems_by_extension.values())
    incomplete = set.union(*stems_by_extension.values()) - complete
    for stem in sorted(incomplete):
        sys.stderr.write(
            "Skipping incomplete record (missing .hea or .mat): {}\n".format(stem)
        )

    # os.listdir order is arbitrary; sorting makes a seeded shuffle repeatable.
    return sorted(complete)


def _output_directories(data_directory):
    """Return ``(train_dir, test_dir)`` located next to ``data_directory``.

    The dataset name is the part of the source folder's own name after its
    last underscore (the whole name when it has none). The path is made
    absolute first so a trailing separator or underscores in parent
    directories cannot leak into the generated names.
    """
    parent, folder = os.path.split(os.path.abspath(data_directory))
    dataset_name = folder.split("_")[-1]
    return (
        os.path.join(parent, "Training_" + dataset_name),
        os.path.join(parent, "Test_" + dataset_name),
    )


def _copy_records(record_ids, source_directory, target_directory):
    """Copy all files of the given records into ``target_directory``.

    The target directory is created when it does not exist yet.
    """
    if not os.path.isdir(target_directory):
        os.mkdir(target_directory)
    for record_id in record_ids:
        for extension in _RECORD_EXTENSIONS:
            filename = record_id + extension
            copyfile(
                os.path.join(source_directory, filename),
                os.path.join(target_directory, filename),
            )


def main(argv=None):
    """Split a dataset folder into ``Training_<name>`` and ``Test_<name>``.

    Args:
        argv: Command-line arguments without the program name. ``None``
            makes argparse read ``sys.argv``.

    Exits through ``parser.error`` (SystemExit) when ``--source`` is not a
    directory or ``--split`` lies outside the range [0, 1].
    """
    # The text must be passed as ``description``: the first positional
    # parameter of ArgumentParser is ``prog`` and would replace the program
    # name in the usage line.
    parser = argparse.ArgumentParser(
        description="Splitting dataset into training and test set"
    )
    parser.add_argument(
        "--source", required=True, help="dataset folder you want to split"
    )
    parser.add_argument(
        "--split", type=float, default=0.2, help="split ratio for test data"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="optional random seed for a reproducible split",
    )
    args = parser.parse_args(argv)

    data_directory = args.source
    split_ratio = args.split

    if not os.path.isdir(data_directory):
        parser.error("--source is not a directory: {}".format(data_directory))
    # The chained comparison also rejects NaN.
    if not 0.0 <= split_ratio <= 1.0:
        parser.error("--split must be between 0 and 1, got {}".format(split_ratio))

    record_ids = _list_record_ids(data_directory)
    dataset_size = len(record_ids)
    split_index = int(np.floor(split_ratio * dataset_size))

    # A private generator keeps the split independent of global numpy state;
    # seed=None draws fresh entropy, matching the unseeded default behaviour.
    np.random.RandomState(args.seed).shuffle(record_ids)
    train_files, test_files = record_ids[split_index:], record_ids[:split_index]

    print("Total dataset size: {}".format(dataset_size))

    train_data_directory, test_data_directory = _output_directories(data_directory)

    _copy_records(train_files, data_directory, train_data_directory)
    print("Training dataset size: {}".format(len(train_files)))

    _copy_records(test_files, data_directory, test_data_directory)
    print("Test dataset size: {}".format(len(test_files)))


if __name__ == "__main__":
    main()
