import yaml
import pandas as pd
import numpy as np
import glob
import os
import shutil
import sys
import random
from pathlib import Path
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
from src.utils.preprocess import Recording_resampling
from helper_code import load_recording
from scipy.io import savemat

# Silence pandas' SettingWithCopyWarning for the whole process (pandas' own
# default for this option is 'warn'). NOTE(review): this is a global pandas
# option set at import time, so it also affects any code that imports this module.
pd.options.mode.chained_assignment = None  # default='warn'

def set_random_seed(SEED):
    """Seed Python's ``random`` module and NumPy's global RNG with ``SEED``.

    Both generators are seeded so that every random draw made by this script
    (the train/test split uses ``random``) is reproducible.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(SEED)

def get_is_scored(header_file_name, scored_set):
    """Return whether a header lists at least one scored diagnosis code.

    Args:
        header_file_name: Path to a ``.hea`` header file.
        scored_set: Collection of diagnosis codes (strings) that count as scored.

    Returns:
        True if any code on a ``#Dx: `` line is in ``scored_set``. Returns
        False otherwise, including when the header has no ``#Dx: `` line at all
        (previously this case raised ``UnboundLocalError``). Codes from every
        ``#Dx: `` line are combined.
    """
    dx_set = set()
    with open(header_file_name, 'r') as fid:
        for line in fid:
            if line.startswith('#Dx: '):
                # Strip each entry so "a, b" parses like "a,b" (matches get_labels).
                dx_set.update(code.strip() for code in line[5:].split(','))
    return not dx_set.isdisjoint(scored_set)

# Get labels from header file.
def get_labels(header_file_name):
    """Return the diagnosis codes listed on a header's ``#Dx`` line(s), in file order.

    Args:
        header_file_name: Path to a ``.hea`` header file.

    Returns:
        A list of stripped code strings. A line that starts with ``#Dx`` but has
        no ``': '`` separator is skipped. Empty entries (e.g. ``#Dx: `` or a
        trailing comma) are dropped.
    """
    labels = []
    with open(header_file_name, 'r') as f:
        for line in f:
            if not line.startswith('#Dx'):
                continue
            _, sep, value = line.rstrip('\n').partition(': ')
            if not sep:
                # Malformed "#Dx" line without ": " -- nothing to parse.
                continue
            for entry in value.split(','):
                entry = entry.strip()
                if entry:
                    labels.append(entry)
    return labels

def _count_scored_labels(label_df, header_files, column, scored_set):
    """Set ``label_df[column]`` to how often each row's SNOMEDCTCode appears in ``header_files``."""
    from collections import Counter

    counts = Counter(
        label
        for header_file in header_files
        for label in get_labels(header_file)
        if label in scored_set
    )
    label_df[column] = [counts.get(code, 0) for code in label_df['SNOMEDCTCode']]


def _split_dataset(set_name, datapath, scored_set, num_test):
    """Build the per-record DataFrame for one dataset and randomly flag ``num_test`` scored records as test."""
    head_list = sorted(glob.glob('{0}/{1}/*.hea'.format(datapath, set_name)))

    df = pd.DataFrame()
    df['header'] = head_list
    df['is_scored'] = [bool(get_is_scored(h, scored_set)) for h in head_list]
    df['set_name'] = set_name

    # Only scored records are eligible for the test split.
    scored_index = [idx for idx, scored in zip(df.index, df['is_scored']) if scored]
    if not 0 <= num_test <= len(scored_index):
        raise ValueError(
            '{}: requested {} test samples but only {} scored records are available'.format(
                set_name, num_test, len(scored_index)))

    # Always draw (even when num_test == 0) so the RNG stream, and thus the split
    # of every later dataset, is identical to earlier runs with the same seed.
    shuffled = random.sample(scored_index, len(scored_index))
    test_index = set(shuffled[:num_test])
    df['is_test'] = [idx in test_index for idx in df.index]
    return df


def _copy_record(header_file, dest_dir):
    """Copy the ``.hea`` header and its sibling ``.mat`` data file into ``dest_dir``."""
    base_name = os.path.splitext(os.path.basename(header_file))[0]
    src_dir = os.path.dirname(header_file)
    for ext in ('.hea', '.mat'):
        shutil.copy('{0}/{1}{2}'.format(src_dir, base_name, ext),
                    '{0}/{1}{2}'.format(dest_dir, base_name, ext))


def _rewrite_header_for_300hz(header_lines, header_file):
    """Return ``header_lines`` with the record line changed from 500 Hz/5000 samples to 300 Hz/3000."""
    fields = header_lines[0].rstrip('\r\n').split(' ') if header_lines else []
    # Explicit check (not assert): a mismatch would otherwise mislabel the output file.
    if len(fields) < 4 or fields[2] != '500' or fields[3] != '5000':
        raise ValueError(
            '{}: expected a 500 Hz, 5000-sample (10 s) record line, got {!r}'.format(
                header_file, header_lines[0] if header_lines else ''))
    fields[2] = '300'   # new sampling frequency
    fields[3] = '3000'  # same 10 s duration at 300 Hz
    return [' '.join(fields) + '\n'] + list(header_lines[1:])


def _resample_record_to_300hz(header_file, dest_dir, resampler):
    """Write a 300 Hz copy of one 500 Hz record (header and ``.mat`` data) into ``dest_dir``."""
    base_name = os.path.splitext(os.path.basename(header_file))[0]
    data_file = os.path.join(os.path.dirname(header_file), base_name) + '.mat'

    with open(header_file, 'r') as h_file:
        header = h_file.readlines()
    with open(os.path.join(dest_dir, base_name) + '.hea', 'w') as new_h_file:
        new_h_file.writelines(_rewrite_header_for_300hz(header, header_file))

    recording = load_recording(data_file)
    # The resampler gets the ORIGINAL (500 Hz) header text.
    # NOTE(review): if the resampler returns floats, astype(np.int16) truncates
    # toward zero (and wraps on overflow) rather than rounding -- confirm intended.
    recording_resampled = resampler(''.join(header), recording).astype(np.int16)
    savemat(os.path.join(dest_dir, base_name) + '.mat', mdict={'val': recording_resampled})


def merge_sets_split_train_test(yaml_file):
    """Merge the configured datasets and write a random train/test split to disk.

    Config keys read from ``yaml_file``:
        dx_mapping_scored: CSV of scored diagnoses; must have an ``SNOMEDCTCode``
            column among its first three columns.
        datapath: Root directory; each dataset lives in ``{datapath}/{name}``.
        dataset_list: Dataset directory names to merge.
        trainpath / testpath: Output directories; they must not already exist.
        num_test_samples: Mapping of dataset name -> number of test records
            (missing names mean 0). Test records are drawn only from records
            that have at least one scored label.
        test_resample_to_300: Dataset names whose test records are rewritten
            from 500 Hz / 5000 samples to 300 Hz / 3000 samples.

    Side effects:
        Creates ``trainpath`` and ``testpath``. Writes
        ``local_train_label_dist.csv`` and ``local_test_label_dist.csv`` (per-dataset
        counts of each scored label) next to ``trainpath``. Copies or resamples every
        record into the output directories.

    Raises:
        FileExistsError: if ``trainpath`` or ``testpath`` already exists.
        ValueError: if a dataset has fewer scored records than requested test
            samples, or a record to resample is not 500 Hz / 5000 samples.
    """
    with open(yaml_file, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    dx_mapping_path = config['dx_mapping_scored']
    datapath = config['datapath']
    dataset_list = config['dataset_list']
    trainpath = config['trainpath']
    testpath = config['testpath']
    # An empty YAML key loads as None; treat it as "nothing configured".
    num_test_samples = config['num_test_samples'] or {}
    test_resample_to_300 = config['test_resample_to_300'] or []

    dx_mapping_scored = pd.read_csv(dx_mapping_path, dtype=str)
    scored_set = {str(code) for code in dx_mapping_scored['SNOMEDCTCode'].tolist()}

    # Explicit copies: per-dataset count columns are added to these below, so
    # they must not be views of dx_mapping_scored.
    train_label_df = dx_mapping_scored.iloc[:, 0:3].copy()
    test_label_df = dx_mapping_scored.iloc[:, 0:3].copy()

    df_list = []
    for set_name in dataset_list:
        num_test_here = num_test_samples.get(set_name, 0)
        df = _split_dataset(set_name, datapath, scored_set, num_test_here)
        df_list.append(df)
        num_train_here = len(df) - num_test_here
        print('Dataset: {}\t\t# of training samples: {}\t\t# of test samples: {}'.format(set_name, str(num_train_here), str(num_test_here)))

        # How many of each scored label fall in this dataset's local train/test split?
        train_headers = [h for h, is_test in zip(df['header'], df['is_test']) if not is_test]
        test_headers = [h for h, is_test in zip(df['header'], df['is_test']) if is_test]
        _count_scored_labels(train_label_df, train_headers, set_name, scored_set)
        _count_scored_labels(test_label_df, test_headers, set_name, scored_set)

    os.makedirs(trainpath, exist_ok=False)  # overwriting is not recommended; set to False
    os.makedirs(testpath, exist_ok=False)

    dist_path = Path(trainpath).parent.absolute()
    train_label_df.to_csv(os.path.join(dist_path, 'local_train_label_dist.csv'))
    test_label_df.to_csv(os.path.join(dist_path, 'local_test_label_dist.csv'))

    df = pd.concat(df_list, ignore_index=True)
    train_list = df.query('not is_test')['header'].tolist()
    test_list = df.query('is_test')['header'].tolist()

    print('='*30)
    print('Total # of training samples: {}\t\tTotal # of test samples: {}'.format(str(len(train_list)), str(len(test_list))))

    for header_file in train_list:
        _copy_record(header_file, trainpath)

    RRS = Recording_resampling(target_frequency=300)
    for header_file in test_list:
        # The dataset is identified by the name of the directory holding the record.
        if os.path.basename(os.path.dirname(header_file)) in test_resample_to_300:
            _resample_record_to_300hz(header_file, testpath, RRS)
        else:
            _copy_record(header_file, testpath)

if __name__ == "__main__":
    # CLI: exactly one argument, the path to the YAML config file.
    if len(sys.argv) == 2:
        # Fixed seed so the train/test split is reproducible across runs.
        set_random_seed(42)
        merge_sets_split_train_test(sys.argv[1])
    else:
        raise Exception(
            "Include the yaml file as arguments, e.g., python merge_sets.split_train_test.py merge_sets.split_train_test.yaml"
        )


