import yaml
import pandas as pd
import numpy as np
import glob
import os
import shutil
import sys


def get_is_scored(header_file_name, scored_set):
    """Return True if a header file lists at least one scored diagnosis code.

    The header is scanned for lines beginning with ``'#Dx: '``; the text after
    that prefix is split on commas into diagnosis codes. If several such lines
    exist, the last one wins.

    Args:
        header_file_name: Path to a ``.hea`` header file.
        scored_set: Collection of diagnosis codes (as strings) that count as
            scored.

    Returns:
        True if any listed code is in ``scored_set``; False otherwise,
        including when the header contains no ``'#Dx: '`` line at all.
    """
    # Stays empty when no '#Dx: ' line is present, so such headers are
    # reported as unscored instead of raising UnboundLocalError.
    dx_set = set()
    with open(header_file_name, 'r') as fid:
        for line in fid:
            if line.startswith('#Dx: '):
                # Strip each code so stray whitespace (", " separators,
                # trailing spaces, the newline) does not prevent a match.
                dx_set = {code.strip() for code in line[5:].split(',')}
    return not dx_set.isdisjoint(scored_set)

def _copy_record(header_file, dest_dir):
    """Copy a record's ``.hea`` header and its sibling ``.mat`` file into dest_dir."""
    src_dir = os.path.dirname(header_file)
    # splitext removes only the trailing extension; str.replace('.hea', '')
    # would also mangle any '.hea' occurring earlier in the file name.
    base_name = os.path.splitext(os.path.basename(header_file))[0]
    for ext in ('.hea', '.mat'):
        shutil.copy(os.path.join(src_dir, base_name + ext),
                    os.path.join(dest_dir, base_name + ext))


def merge_sets_split_train_test(yaml_file):
    """Merge several datasets and randomly split their records into train/test.

    The YAML config must provide these keys:

    * ``dx_mapping_scored``: CSV path with a ``SNOMEDCTCode`` column listing
      the scored diagnosis codes.
    * ``split``: probability in [0, 1] that a *scored* record goes to test.
    * ``datapath``: root directory containing one sub-directory per dataset.
    * ``dataset_list``: names of the dataset sub-directories to merge.
    * ``trainpath`` / ``testpath``: output directories (created if missing).
    * ``seed`` (optional): integer seed for a reproducible split.

    Every ``*.hea`` file found is copied together with its same-named ``.mat``
    file into either ``trainpath`` or ``testpath``. Records with no scored
    diagnosis always go to train; only scored records are eligible for test.

    NOTE: records from all datasets are copied into flat output directories,
    so identical base names in different datasets overwrite each other.
    """
    # NOTE(review): FullLoader is fine for a trusted local config but is not
    # safe for untrusted input; prefer yaml.safe_load in that case.
    with open(yaml_file, 'r') as fh:
        config = yaml.load(fh, Loader=yaml.FullLoader)

    split = config['split']
    datapath = config['datapath']
    dataset_list = config['dataset_list']
    trainpath = config['trainpath']
    testpath = config['testpath']
    seed = config.get('seed')

    dx_mapping_scored = pd.read_csv(config['dx_mapping_scored'])
    scored_set = {str(code) for code in dx_mapping_scored['SNOMEDCTCode'].tolist()}

    # Collect headers dataset by dataset (sorted within each dataset) with a
    # parallel list of "has a scored diagnosis" flags. Plain lists avoid the
    # pd.concat([]) crash when dataset_list is empty.
    headers = []
    is_scored = []
    for set_name in dataset_list:
        set_headers = sorted(glob.glob(os.path.join(datapath, set_name, '*.hea')))
        headers.extend(set_headers)
        is_scored.extend(get_is_scored(h, scored_set) for h in set_headers)

    n_records = len(headers)
    if seed is None:
        draws = np.random.random(n_records)
    else:
        draws = np.random.default_rng(seed).random(n_records)
    # A record goes to test only if it is scored AND wins the random draw.
    is_test = np.asarray(is_scored, dtype=bool) & (draws < split)

    train_list = [h for h, test in zip(headers, is_test) if not test]
    test_list = [h for h, test in zip(headers, is_test) if test]

    os.makedirs(trainpath, exist_ok=True)
    os.makedirs(testpath, exist_ok=True)

    for header_file in train_list:
        _copy_record(header_file, trainpath)
    for header_file in test_list:
        _copy_record(header_file, testpath)


if __name__ == "__main__":
    # Parse arguments: exactly one positional argument, the YAML config path.
    if len(sys.argv) != 2:
        # Report the script's actual name instead of a hard-coded guess, and
        # exit with a usage message rather than a traceback.
        sys.exit(
            "Include the yaml file as an argument, e.g., python {0} config.yaml".format(
                os.path.basename(sys.argv[0])
            )
        )

    merge_sets_split_train_test(sys.argv[1])


