import pathlib
import os,sys
from helper_code import *
import shutil
import random

#%%

# ---------------------------------------------------------------------------
# Label tables for the earlier label set ("_old" variants).
#
# EQUIV_CLASSES_old maps a diagnosis code onto the code it is considered
# equivalent to (how the mapping is applied is not visible in this script).
# DIAG_CLASSES_old gives every remaining code a consecutive integer index,
# and SCORED_CLASSES_old lists the equivalence keys followed by those indexed
# codes, in index order.
# ---------------------------------------------------------------------------
EQUIV_CLASSES_old = {
    '713427006': '59118001',
    '284470004': '63593006',
    '427172004': '17338001',
}
DIAG_CLASSES_old = {
    code: index
    for index, code in enumerate((
        '270492004', '164889003', '164890007', '426627000', '713426002', '445118002',
        '39732003', '164909002', '251146004', '698252002', '10370003', '164947007',
        '111975006', '164917005', '47665007', '59118001', '427393009', '426177001',
        '426783006', '427084000', '63593006', '164934002', '59931005', '17338001',
    ))
}
SCORED_CLASSES_old = [*EQUIV_CLASSES_old, *DIAG_CLASSES_old]

# ---------------------------------------------------------------------------
# Label tables for the current label set.
#
# Same layout as above: EQUIV_CLASSES maps a code onto its equivalent,
# DIAG_CLASSES indexes the 26 codes left after that mapping, and
# SCORED_CLASSES lists all 30 scored codes (the indexed codes plus the
# equivalence keys). The script below keeps only recordings that carry at
# least one code from SCORED_CLASSES.
# ---------------------------------------------------------------------------
EQUIV_CLASSES = {
    '164909002': '733534002',
    '59118001': '713427006',
    '63593006': '284470004',
    '17338001': '427172004',
}
DIAG_CLASSES = {
    code: index
    for index, code in enumerate((
        '164889003', '164890007', '6374002', '426627000', '733534002', '713427006',
        '270492004', '713426002', '39732003', '445118002', '251146004', '698252002',
        '426783006', '284470004', '10370003', '365413008', '427172004', '164947007',
        '111975006', '164917005', '47665007', '427393009', '426177001', '427084000',
        '164934002', '59931005',
    ))
}
SCORED_CLASSES = [
    '164889003', '164890007', '6374002', '426627000', '733534002',
    '164909002', '59118001', '713427006', '270492004', '713426002',
    '39732003', '445118002', '164947007', '251146004', '111975006',
    '698252002', '426783006', '63593006', '284470004', '10370003',
    '365413008', '17338001', '427172004', '164917005', '47665007',
    '427393009', '426177001', '427084000', '164934002', '59931005',
]

# Database folder names expected below the data directory; the current tuple
# extends the earlier one with two additional databases.
database_names_old = (
    'WFDB_CPSC2018',
    'WFDB_CPSC2018_2',
    'WFDB_Ga',
    'WFDB_PTB',
    'WFDB_PTBXL',
    'WFDB_StPetersburg',
)
database_names = database_names_old + ('WFDB_Ningbo', 'WFDB_ChapmanShaoxing')


def parse_cli_args(argv, defaults):
    """Return a copy of *defaults* updated from ``key=value`` tokens in *argv*.

    Recognised keys:
        data_dir        directory containing the WFDB_* database folders
        nr_samples      positive int, or 'all' (leaves the current value)
        test_partition  fraction in [0, 1] of the samples put in the test set
        training_out    output sub-folder of data_dir for the training set
        testing_out     output sub-folder of data_dir for the test set
        seed            int seed for a reproducible shuffle
        databases       comma-separated database folder names, or 'all'

    Only the first '=' separates key from value, so values (e.g. paths) may
    themselves contain '='. Tokens without '=' are ignored silently, because
    interactive runners may inject such arguments; unrecognised keys are
    ignored with a warning on stderr so typos are noticed.

    Raises:
        ValueError: if a numeric value cannot be parsed, nr_samples is not
            positive, test_partition lies outside [0, 1], or databases is
            given an empty list.
    """
    config = dict(defaults)
    for param in argv:
        key, sep, value = param.partition("=")
        if not sep:
            continue
        if key in ('data_dir', 'training_out', 'testing_out'):
            config[key] = value
        elif key == 'nr_samples':
            if value != 'all':
                n_samples = int(value)
                if n_samples <= 0:
                    raise ValueError('nr_samples must be a positive integer or "all"')
                config['nr_samples'] = n_samples
        elif key == 'test_partition':
            fraction = float(value)
            # Written as a negated range check so NaN is rejected as well.
            if not 0.0 <= fraction <= 1.0:
                raise ValueError('test_partition not between 0 and 1')
            config['test_partition'] = fraction
        elif key == 'seed':
            config['seed'] = int(value)
        elif key == 'databases':
            if value == 'all':
                config['databases'] = 'all'
            else:
                names = tuple(part for part in value.split(',') if part)
                if not names:
                    raise ValueError('databases must name at least one database or be "all"')
                config['databases'] = names
        else:
            print('Warning: ignoring unknown parameter', repr(key), file=sys.stderr)
    return config


def split_train_test(files, test_partition, nr_samples=None):
    """Split *files* into a training list and a test list.

    Only the first ``min(nr_samples, len(files))`` entries are used (all of
    them when *nr_samples* is None). Of those, the first
    ``round(n * (1 - test_partition))`` go to training and the rest to test.

    Returns:
        ``(train, test, n)`` where ``n`` is the number of entries used.
    """
    n_used = len(files)
    if nr_samples is not None and nr_samples < n_used:
        n_used = nr_samples
    n_train = round(n_used * (1 - test_partition))
    return files[:n_train], files[n_train:n_used], n_used


def copy_recordings(header_files, out_dir):
    """Copy every header file and its matching ``.mat`` file into *out_dir*.

    The recording path is the header path with its extension replaced by
    ``.mat``. Progress is printed as ``i / total`` with a 1-based ``i``.

    Returns:
        The number of recordings copied.
    """
    n_total = len(header_files)
    for idx, header_f in enumerate(header_files, start=1):
        rec = os.path.splitext(header_f)[0] + '.mat'
        shutil.copy(header_f, out_dir)
        shutil.copy(rec, out_dir)
        print(idx, '/', n_total)
    print('copied ', n_total, 'recordings')
    return n_total


if __name__ == '__main__':
    # Usage:
    #   python create_test_training_dir.py data_dir=../Data/ nr_samples=all test_partition=0.1 training_out=Train_Full/ testing_out=Test_Full/
    #
    #   data_dir:       directory containing the database folders (by default
    #                   'WFDB_CPSC2018','WFDB_CPSC2018_2','WFDB_Ga','WFDB_PTB',
    #                   'WFDB_PTBXL','WFDB_StPetersburg')
    #   nr_samples:     total number of extracted samples ('all' = every
    #                   recording that carries a scored label)
    #   test_partition: fraction (0..1) of the extracted samples placed in
    #                   the Test Set; the remainder forms the Training Set
    #   testing_out:    subfolder of data_dir receiving the Test Set
    #   training_out:   subfolder of data_dir receiving the Training Set
    #   seed:           optional int making the random split reproducible
    #   databases:      optional comma-separated database folder names, or
    #                   'all' for every entry of the module-level database_names

    working_dir = pathlib.Path().absolute()  # expected to be the Code/ directory
    config = parse_cli_args(sys.argv[1:], {
        'data_dir': os.path.join(working_dir, "../Data/"),
        'nr_samples': None,
        'test_partition': 0.1,
        'training_out': 'Train_Full/',
        'testing_out': 'Test_Full/',
        'seed': None,
        'databases': None,
    })
    data_dir = config['data_dir']
    nr_samples = config['nr_samples']
    test_set_partition = config['test_partition']
    training_out = config['training_out']
    testing_out = config['testing_out']

    # By default only the six databases of database_names_old are used, as
    # before; databases=all selects the longer module-level database_names.
    selected_databases = config['databases']
    if selected_databases is None:
        selected_databases = database_names_old
    elif selected_databases == 'all':
        selected_databases = database_names

    out_dir_test_full = os.path.join(data_dir, testing_out)
    out_dir_train_full = os.path.join(data_dir, training_out)
    os.makedirs(out_dir_test_full, exist_ok=True)
    os.makedirs(out_dir_train_full, exist_ok=True)

    all_header_files = list()
    for name in selected_databases:
        # os.path.join works whether or not data_dir ends with a separator.
        database_dir = os.path.join(data_dir, name)
        if not os.path.isdir(database_dir):
            print('Warning: database directory not found, skipping:', database_dir, file=sys.stderr)
            continue
        header_files, _ = find_challenge_files(database_dir)
        all_header_files.extend(header_files)

    n_all_header_files = len(all_header_files)

    # Deduplicate, keep only recordings with at least one scored label, and
    # sort so the shuffle depends only on the seed (set iteration order of
    # strings can differ between interpreter runs).
    scored_classes = set(SCORED_CLASSES)
    cleaned_header_files = sorted(
        header_f
        for header_f in set(all_header_files)
        if any(label in scored_classes for label in get_labels(load_header(header_f)))
    )
    random.Random(config['seed']).shuffle(cleaned_header_files)

    print("Total number of Recording: ", len(cleaned_header_files))
    header_files_train, header_files_test, n_recordings = split_train_test(
        cleaned_header_files, test_set_partition, nr_samples)

    print('Copying Test Files to ', out_dir_test_full, ' ...')
    n_recordings_test = copy_recordings(header_files_test, out_dir_test_full)

    print('Copying Training Files ', out_dir_train_full, ' ...')
    n_recordings_train = copy_recordings(header_files_train, out_dir_train_full)

    print("Total number of Recording: ", n_recordings)
    # This count lumps together duplicates, recordings without a scored
    # label, and recordings left out because of nr_samples.
    print("skipped", n_all_header_files - n_recordings, " recordings")

# %%

    