


import glob
import os
import shutil
from tqdm import tqdm
import numpy as np
#from skmultilearn.model_selection import iterative_train_test_split
from sklearn.model_selection import StratifiedKFold


def is_number(x):
    """Report whether ``float(x)`` succeeds.

    Returns True when the conversion works and False when it raises
    ValueError or TypeError (e.g. non-numeric text, None, a list).
    """
    try:
        float(x)
    except (ValueError, TypeError):
        return False
    else:
        return True


# Check whether a value is an integer or text/float that represents one.
def is_integer(x):
    """Return True when ``x`` converts to a float with no fractional part.

    Values rejected by ``is_number`` yield False.
    """
    # Short-circuit: float(x) is only evaluated once is_number(x) confirmed
    # that the conversion cannot raise ValueError/TypeError.
    return is_number(x) and float(x).is_integer()


def get_murmur(data):
    """Extract the murmur label from the text of a patient file.

    Scans every line of ``data`` for one starting with ``#Murmur:`` and takes
    the text between the first and second ``': '`` separators, with
    surrounding whitespace (including a stray ``\\r`` from CRLF text) removed.
    When several such lines exist, the last one wins.

    Raises:
        ValueError: if no ``#Murmur:`` line carrying a ``': '`` separator is
            present.
    """
    murmur = None
    for line in data.split('\n'):
        if not line.startswith('#Murmur:'):
            continue
        fields = line.split(': ')
        # A header without the ': ' separator carries no value; skip it
        # explicitly instead of hiding the IndexError behind a bare except.
        if len(fields) > 1:
            murmur = fields[1].strip()
    if murmur is None:
        raise ValueError('No murmur available. Is your code trying to load labels from the hidden data?')
    return murmur

# Get outcome from patient data.
def get_outcome(data):
    """Extract the outcome label from the text of a patient file.

    Scans every line of ``data`` for one starting with ``#Outcome:`` and takes
    the text between the first and second ``': '`` separators, with
    surrounding whitespace (including a stray ``\\r`` from CRLF text) removed.
    When several such lines exist, the last one wins.

    Raises:
        ValueError: if no ``#Outcome:`` line carrying a ``': '`` separator is
            present.
    """
    outcome = None
    for line in data.split('\n'):
        if not line.startswith('#Outcome:'):
            continue
        fields = line.split(': ')
        # A header without the ': ' separator carries no value; skip it
        # explicitly instead of hiding the IndexError behind a bare except.
        if len(fields) > 1:
            outcome = fields[1].strip()
    if outcome is None:
        raise ValueError('No outcome available. Is your code trying to load labels from the hidden data?')
    return outcome

# Build the combined "<murmur>-<outcome>" label from patient data.
def get_label(data):
    """Return the murmur and outcome labels of ``data`` joined by a hyphen.

    Any ValueError raised by ``get_murmur`` or ``get_outcome`` propagates.
    """
    # The tuple is evaluated left to right, so the murmur is looked up first.
    return "-".join((get_murmur(data), get_outcome(data)))

# Get number of recording locations from patient data.
def get_num_locations(data):
    """Parse the second space-separated field of the first line as an int.

    Returns:
        The parsed integer, or None when the first line has no second field
        or that field is not a valid integer.
    """
    # Only the first line is relevant; no need to walk the whole text.
    first_line = data.split('\n', 1)[0]
    try:
        return int(first_line.split(' ')[1])
    except (IndexError, ValueError):
        # Missing or malformed field: keep the original "None means unknown"
        # contract, but no longer swallow unrelated exceptions.
        return None

# Find patient data files.
def find_patient_files(data_folder):
    """List the patient ``.txt`` files located directly in ``data_folder``.

    Entries whose name starts with a dot or whose extension is not ``.txt``
    are ignored. Each result is ``data_folder`` joined with the file name.

    Returns:
        The paths sorted by file name. If every file name (without the
        ``.txt`` extension) parses with ``int()``, the paths are instead
        sorted by that integer value, which helps with debugging.
    """
    roots = []
    filenames = []
    for name in sorted(os.listdir(data_folder)):
        root, extension = os.path.splitext(name)
        if extension == '.txt' and not root.startswith('.'):
            roots.append(root)
            filenames.append(os.path.join(data_folder, name))

    # Probe with int() itself, the same parser that produces the sort keys.
    # A float-based integer test accepts roots such as "1.0" or "1e3" that
    # int() then rejects, which used to abort the listing with a ValueError.
    try:
        keys = [int(root) for root in roots]
    except ValueError:
        return filenames

    # Stable sort on the key only, so equal keys keep their name order.
    ordered = sorted(zip(keys, filenames), key=lambda pair: pair[0])
    return [filename for _, filename in ordered]


# Read a patient file and hand back its text.
def load_patient_data(filename):
    """Open ``filename`` in text mode and return its entire contents."""
    with open(filename, 'r') as handle:
        return handle.read()


# Integer class id for each "<murmur>-<outcome>" label built by get_label.
CLASS_MAP2 = {"Present-Abnormal": 0, "Unknown-Abnormal": 1, "Absent-Abnormal": 2, "Present-Normal": 3,"Unknown-Normal": 4, "Absent-Normal": 5}
# Values >= 1 (True included) show a tqdm progress bar while labels are read.
VERBOSE = True


def copy_patient_files(patient_file, destination):
    """Copy a patient ``.txt`` file and its companion files to ``destination``.

    Companion files are the regular files next to ``patient_file`` whose name
    starts with the patient file's stem followed by a non-alphanumeric
    character (for example ``1234_AV.wav`` or ``1234.hea`` for ``1234.txt``).
    ``destination`` is created if it does not exist.

    Returns:
        The list of destination paths reported by ``shutil.copy``.
    """
    os.makedirs(destination, exist_ok=True)
    copied = [shutil.copy(patient_file, destination)]

    # splitext removes only the trailing extension; str.replace('.txt', '*')
    # would also rewrite a '.txt' occurring elsewhere in the path.
    prefix = os.path.splitext(patient_file)[0]
    stem = os.path.basename(prefix)
    own_name = os.path.basename(patient_file)

    # glob.escape keeps characters such as '[' in the path from being
    # interpreted as wildcards.
    for candidate in sorted(glob.glob(glob.escape(prefix) + '*')):
        name = os.path.basename(candidate)
        if name == own_name or not os.path.isfile(candidate):
            # The patient file itself is already copied; directories are
            # not something shutil.copy can handle.
            continue
        if name[len(stem):][:1].isalnum():
            # A bare prefix match: "1234*" also matches "12345_AV.wav",
            # which belongs to another patient and would leak that
            # patient's data into this fold.
            continue
        copied.append(shutil.copy(candidate, destination))
    return copied


def main(data_folder="/home/nejedly/Data/CINC2022/physionet.org/files/circor-heart-sound/1.0.3/training_data/",
         output_root='./crossvalidation',
         n_splits=5):
    """Write stratified cross-validation folds of the patient data to disk.

    Every patient file found in ``data_folder`` is labelled through
    ``CLASS_MAP2[get_label(...)]``; ``StratifiedKFold`` then splits the
    patients on those labels, and for fold ``k`` the patient files (with
    their companion files) are copied into ``<output_root>/<k>/train`` and
    ``<output_root>/<k>/test``. The per-class counts of both subsets are
    printed for each fold.
    """
    files = find_patient_files(data_folder)

    statuses = []
    for file in tqdm(files) if VERBOSE >= 1 else files:
        statuses.append(CLASS_MAP2[get_label(load_patient_data(file))])

    file_paths = np.array(files)
    y = np.array(statuses)

    skf = StratifiedKFold(n_splits=n_splits)

    # The stratified split depends only on y, so X is a same-length placeholder.
    placeholder = np.zeros(len(y))
    for k, (train_index, test_index) in enumerate(skf.split(placeholder, y)):
        _, train_counts = np.unique(y[train_index], return_counts=True)
        _, valid_counts = np.unique(y[test_index], return_counts=True)
        print(train_counts)
        print(valid_counts)

        for subset, indices in (('train', train_index), ('test', test_index)):
            destination = os.path.join(output_root, str(k), subset)
            # Create the folder even when the subset turns out to be empty.
            os.makedirs(destination, exist_ok=True)
            for patient_file in file_paths[indices].tolist():
                copy_patient_files(patient_file, destination)

    print("FINISH")


if __name__ == "__main__":
    main()
