import glob
import pandas as pd
import re
import numpy as np
from skmultilearn.model_selection import iterative_train_test_split
import os
from shutil import copyfile

class get_Dx():
    """Callable that parses the WFDB ``.hea`` header of one record.

    Designed for ``DataFrame.apply(..., axis=1)``: each row ``x`` must hold the
    header path under ``x['header']``. The same row is returned with two
    entries added:

    * ``'fs'`` -- the integer in the third whitespace-separated field of the
      header's first line (the sampling-frequency field of a WFDB record line).
    * ``'Dx'`` -- the list of integer codes found on the first line containing
      ``"#Dx: "``, or ``[]`` when the header has no such line.

    Args:
        verbose: when true, print a running row counter on every call.
    """

    def __init__(self, verbose=True):
        # Number of rows processed so far; printed as a progress counter.
        self.i = 0
        self.verbose = verbose

    def __call__(self, x):
        self.i += 1
        if self.verbose:
            print(self.i)
        with open(x['header']) as f:
            lines = f.readlines()

        # split() rather than split(" ") so repeated spaces/tabs between
        # fields do not produce empty tokens.
        x['fs'] = int(lines[0].split()[2])

        # Default to no diagnoses: previously a header without a "#Dx: " line
        # made this method return None, which corrupted the DataFrame.apply
        # result and broke the later len(x["Dx"]).
        dx_codes = []
        for line in lines:
            if "#Dx: " in line:
                dx_codes = [int(code) for code in re.findall(r'\d+', line)]
                break
        x['Dx'] = dx_codes
        return x

class get_Targets():
    """Callable that turns a row's ``'Dx'`` code list into a multi-hot vector.

    The vector has one slot per entry of the ``'SNOMED CT Code -A'`` column of
    the mapping CSV, in row order. A slot is set to 1 when that code appears in
    ``x["Dx"]``. Codes absent from the mapping are ignored.

    Args:
        verbose: when true, print a running row counter on every call.
        mapping_path: CSV file containing the ``'SNOMED CT Code -A'`` column.
    """

    def __init__(self, verbose=True, mapping_path="data/dx_mapping_scored.csv"):
        # Number of rows processed so far; printed as a progress counter.
        self.i = 0
        self.verbose = verbose
        # Ordered code list: a code's index here is its position in 'Target'.
        self.mapping = pd.read_csv(mapping_path)['SNOMED CT Code -A'].tolist()

    def __call__(self, x):
        self.i += 1
        if self.verbose:
            print(self.i)

        # Size the vector from the mapping rather than a hard-coded 27. A
        # longer CSV previously raised IndexError, and the broad except
        # silently dropped those codes.
        target = np.zeros((len(self.mapping),))
        for d in x["Dx"]:
            try:
                target[self.mapping.index(d)] = 1
            except ValueError:
                # Code is not in the mapping (not a scored class): skip it.
                pass
        x["Target"] = target
        return x


# Gather every WFDB header exactly one directory below the working directory.
# The script's own output folders (./_train, ./_test) match the same pattern,
# so they are skipped. Otherwise a re-run would split the copies again and ask
# copyfile to copy files onto themselves.
_OUTPUT_DIRS = {"_train", "_test"}
header_paths = [
    path
    for path in glob.glob("./*/*.hea")
    if os.path.basename(os.path.dirname(path)) not in _OUTPUT_DIRS
]
if not header_paths:
    raise SystemExit("No header files matching ./*/*.hea were found.")
# To debug on a subset: header_paths = header_paths[0:1000]
files = pd.DataFrame(header_paths, columns=['header'])
# Each record's signal file sits next to its header, with a .mat extension.
files['data'] = files['header'].map(lambda p: os.path.splitext(p)[0] + ".mat")
files = files.apply(get_Dx(verbose=True), axis=1)
files["numDx"] = files["Dx"].map(len)
files = files.sort_values(by="numDx")
files = files.apply(get_Targets(verbose=True), axis=1)
# Record id = header file name without directory or extension. os.path keeps
# this correct when glob returns OS-native separators (e.g. on Windows).
files["id"] = files["header"].map(lambda p: os.path.splitext(os.path.basename(p))[0])
# Column order (data, header, fs, id) is what copy() indexes into.
recordings = files[['data', 'header', 'fs', 'id']].to_numpy()
labels = np.array(files['Target'].to_list())
X_train, y_train, X_test, y_test = iterative_train_test_split(recordings, labels, test_size=0.25)

# Create the output folders for the copied records. exist_ok=True avoids the
# check-then-create race and makes re-runs a no-op.
os.makedirs('./_train', exist_ok=True)
os.makedirs('./_test', exist_ok=True)


def copy(arr, dir, verbose=True):
    """Copy each record's header and data file into ``./_<dir>/``.

    Args:
        arr: 2-D array whose rows are ``(data_path, header_path, fs, id)``
            (the column order this script builds for ``recordings``). Only
            columns 0, 1 and 3 are used.
        dir: split name; files go to ``./_<dir>/<id>.hea`` and
            ``./_<dir>/<id>.mat``. The directory must already exist. The name
            shadows the builtin ``dir`` but is kept because callers pass it by
            keyword.
        verbose: when true, print the 0-based row index before each copy.
    """
    out_dir = os.path.join(".", "_" + dir)
    for i, row in enumerate(arr):
        if verbose:
            print(i)
        data_path, header_path, record_id = row[0], row[1], row[3]
        copyfile(header_path, os.path.join(out_dir, record_id + ".hea"))
        copyfile(data_path, os.path.join(out_dir, record_id + ".mat"))

# Write both splits to disk: first the training rows, then the test rows.
for split_name, split_rows in (("train", X_train), ("test", X_test)):
    copy(split_rows, dir=split_name)


