import scipy.io as sio
import scipy.signal as sig
import pandas as pd
import numpy as np
from skmultilearn.model_selection import iterative_train_test_split
from preprocessing import *
import re





class dataset():
    """Multi-label ECG record dataset backed by ``.hea`` / ``.mat`` file pairs.

    Each item is loaded lazily: the ``.mat`` file supplies the raw signal
    (its ``'val'`` entry) and the matching ``.hea`` header supplies the
    sampling frequency, the ``<gain>/mV`` value and the ``#Dx`` diagnosis
    codes, which are turned into a multi-hot label vector over ``CLASSES``.
    """

    # Ordered as in the weights.csv file.
    CLASSES_ALL = ['270492004','164889003','164890007','426627000','713427006','713426002','445118002','39732003','164909002','251146004','698252002','10370003','284470004','427172004','164947007','111975006','164917005','47665007','59118001','427393009','426177001','426783006','427084000','63593006','164934002','59931005','17338001']
    # Pairs of [canonical, alias] codes treated as the same class; the alias
    # (second entry) is folded into the canonical (first entry) code.
    EQUIVALENT = [['713427006', '59118001'],
                  ['284470004', '63593006'],
                  ['427172004', '17338001']]

    # Label-vector layout: every code in CLASSES_ALL except the aliases,
    # sorted as strings.
    CLASSES = sorted(set(CLASSES_ALL) - {alias for _, alias in EQUIVALENT})

    # Lookup tables derived from the lists above (built once, O(1) per code).
    _ALIAS_TO_CANONICAL = {alias: canonical for canonical, alias in EQUIVALENT}
    _CLASS_INDEX = {code: i for i, code in enumerate(CLASSES)}

    # Gain field of the first signal line: "<gain>[(<baseline>)]/mV",
    # e.g. "1000/mV", "1000.0/mV" or "200(0)/mV".
    _GAIN_RE = re.compile(r"(\d+(?:\.\d*)?)(?:\(-?\d+\))?/mV")

    @staticmethod
    def train_test_split(header_files, test_eval, test_size=0.25):
        """Split header files into train/test datasets with iterative stratification.

        Args:
            header_files: iterable of paths to ``.hea`` header files.
            test_eval: value passed as ``eval`` to the test dataset (the train
                dataset always gets ``eval=False``).
            test_size: fraction of records placed in the test split.

        Returns:
            ``(train, test, cweights)`` where ``train``/``test`` are ``dataset``
            instances and ``cweights`` are the per-class weights computed by
            ``_class_weights`` over all records.
        """
        headers = list(header_files)
        labels = np.array([dataset.read_header(h)[1] for h in headers])
        cweights = dataset._class_weights(labels)

        # iterative_train_test_split expects a 2-D feature matrix, hence the
        # (n, 1) column of paths; returned order is X_train, y_train, X_test, y_test.
        X_train, _, X_test, _ = iterative_train_test_split(
            np.array(headers).reshape(-1, 1), labels, test_size=test_size)
        train = dataset(X_train, eval=False)
        test = dataset(X_test, eval=test_eval)
        # TODO: add a validation split if necessary.
        return train, test, cweights

    @staticmethod
    def _class_weights(labels):
        """Return per-class weights proportional to 1 / (positive count), max-normalised.

        Weights are rounded to 2 decimals. A class with no positive sample has
        no defined inverse frequency; it gets weight 1.0 (treated like the
        rarest class) instead of the inf/NaN that 1/0 would produce.
        NOTE(review): confirm 1.0 is the desired weight for absent classes.
        """
        counts = np.sum(labels, axis=0)
        weights = np.ones(np.shape(counts), dtype=float)
        present = counts > 0
        if np.any(present):
            inverse = 1.0 / counts[present]
            weights[present] = inverse / np.max(inverse)
        return np.around(weights, 2)

    def __init__(self, headers, eval):
        """Store the header paths and the ``eval`` flag forwarded to ``preprocessing``.

        Args:
            headers: header paths as a list/tuple/1-D array, or a 2-D array
                (e.g. the ``(n, 1)`` column from ``train_test_split``), in
                which case column 0 is used.
            eval: flag passed through to ``preprocessing`` for every item.
        """
        self.eval = eval
        paths = np.asarray(headers)
        self.headers = paths[:, 0] if paths.ndim == 2 else paths.reshape(-1)

    def __len__(self):
        """Return the number of records."""
        return len(self.headers)

    def __getitem__(self, item):
        """Load record ``item`` and return ``(preprocessed_signal, label_vector)``."""
        header_path = str(self.headers[item])
        # Swap only the trailing extension so ".hea" elsewhere in the path is untouched.
        if header_path.endswith(".hea"):
            mat_path = header_path[:-len(".hea")] + ".mat"
        else:
            mat_path = header_path
        x = np.asarray(sio.loadmat(mat_path)['val'])
        fs, y, srange = self.read_header(header_path)
        X = preprocessing(x, fs, srange, eval=self.eval)
        return X, y

    @staticmethod
    def read_header(filename):
        """Parse a ``.hea`` header file.

        Returns:
            ``(fs, y, srange)``: the integer sampling frequency from the record
            line, a float multi-hot vector of length ``len(CLASSES)`` built from
            all ``#Dx`` lines (aliases mapped via ``EQUIVALENT``, unknown codes
            ignored), and the integer ``<gain>/mV`` value from the second line.

        Raises:
            ValueError: if the second line has no ``<gain>/mV`` field.
        """
        with open(filename, 'r') as f:
            header = f.readlines()

        # Record line: "<name> <n_signals> <fs>[/<counter freq>] ..."; keep the base frequency.
        fs = int(float(header[0].split()[2].split('/')[0]))

        match = dataset._GAIN_RE.search(header[1])
        if match is None:
            raise ValueError(f"no '<gain>/mV' field in signal line of {filename!r}")
        srange = int(float(match.group(1)))

        y = np.zeros(len(dataset.CLASSES))
        for line in header:
            if not line.startswith('#Dx'):
                continue
            # Accept both "#Dx: a,b" and "#Dx:a,b".
            for raw_code in line.partition(':')[2].split(','):
                code = raw_code.strip()
                code = dataset._ALIAS_TO_CANONICAL.get(code, code)
                idx = dataset._CLASS_INDEX.get(code)
                if idx is not None:
                    y[idx] = 1
        return fs, y, srange
