from abc import abstractclassmethod
import torch, torchaudio
import numpy as np, pickle, os, pdb
import wfdb
from wfdb import processing
from utils.helper_code import load_header, get_labels, load_recording, get_age, get_sex, get_leads, get_adcgains,  get_baselines, get_frequency, find_challenge_files, is_integer, load_weights, twelve_leads

def _dump_pickle(obj, path):
    """Pickle `obj` to `path`, creating the parent directory first (best effort)."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    try:
        # Protocol 4+ can serialize objects larger than 4 GiB.
        with open(path, "wb") as f:
            pickle.dump(obj, f, protocol=4)
    except OverflowError:
        print("OverflowError: cannot serialize a bytes object larger than 4 GiB")
        # Don't leave a truncated cache that a later pickle.load would choke on.
        if os.path.exists(path):
            os.remove(path)


def _load_pickle(path):
    """Return the object pickled at `path` (a cache written by _dump_pickle)."""
    # NOTE: pickle.load can execute arbitrary code; only load trusted local caches.
    with open(path, "rb") as f:
        return pickle.load(f)


def _samples_since_qrs(qrs_inds, sig_len):
    """Return a (1, sig_len) float32 tensor counting samples since the latest QRS index.

    Samples before the first QRS index are 0. Each QRS sample itself is 0.
    When no QRS was detected, the count runs 0, 1, 2, ... from sample 0.
    Assumes `qrs_inds` is sorted ascending.
    """
    positions = np.arange(sig_len, dtype=np.int64)
    qrs_inds = np.asarray(qrs_inds, dtype=np.int64)
    if qrs_inds.size == 0:
        counts = positions
    else:
        # last_qrs[j] = largest QRS index <= j, or -1 before the first one.
        last_qrs = np.full(sig_len, -1, dtype=np.int64)
        last_qrs[qrs_inds] = qrs_inds
        last_qrs = np.maximum.accumulate(last_qrs)
        counts = np.where(last_qrs >= 0, positions - last_qrs, 0)
    return torch.from_numpy(counts.astype(np.float32)).unsqueeze(0)


def process_data(args, data_directory):
    """Load (or preprocess and cache) the training records found in `data_directory`.

    Args:
        args: dict of run options.
            Reads 'preprocess', 'save_preprocessed_data', 'Othertypes' and
            'add_domain_knowledge'.
            Writes 'normal_class', 'eva_class_indices' and 'eva_classes'.
        data_directory: directory passed to find_challenge_files. Its string is
            also embedded in the cache file names under ./preprocess/.

    Returns:
        (data, new_labels, attachedpro, weights, xqrssig):
        - data, attachedpro: per-record lists from get_training_data (or the cache).
        - new_labels: each record's label vector restricted to the scored
          classes, in `eva_classes` order.
        - weights: the matrix returned by load_weights.
        - xqrssig: per-record (1, len) tensors counting samples since the last
          detected QRS on data[i][1], or [] when 'add_domain_knowledge' is off.

    Raises:
        Exception: if no recordings are found.
    """
    print('Finding header and recording files...')
    header_files, recording_files = find_challenge_files(data_directory)
    if not recording_files:
        raise Exception('No data was provided.')

    # Fixed label space of 108 class codes; label vectors follow this order.
    classes = [
        '368009', '6374002', '10370003', '11157007', '13640000', '17338001', '27885002', '29320008', '39732003', '47665007',
        '49578007', '53741008', '54016002', '54329005', '55930002', '57054005', '59118001', '59931005', '60423000', '63593006',
        '65778007', '67198005', '74390002', '74615001', '75532003', '77867006', '81898007', '82226007', '84114007', '89792004',
        '111288001', '111975006', '164861001', '164865005', '164867002', '164873001', '164884008', '164889003', '164890007', '164895002',
        '164896001', '164909002', '164917005', '164921003', '164930006', '164931005', '164934002', '164937009', '164947007', '164951009',
        '195042002', '195060002', '195080001', '195101003', '195126007', '204384007', '233917008', '251120003', '251139008', '251146004',
        '251164006', '251170000', '251180001', '251182009', '251200008', '251259000', '251266004', '251268003', '253339007', '253352002',
        '266249003', '266257000', '270492004', '282825002', '284470004', '314208002', '370365005', '413444003', '413844008', '425419005',
        '425623009', '425856008', '426177001', '426434006', '426627000', '426648003', '426664006', '426749004', '426761007', '426783006',
        '426995002', '427084000', '427172004', '427393009', '428417006', '428750005', '429622005', '445118002', '445211001', '446358003',
        '446813000', '698247007', '698252002', '704997005', '713422000', '713426002', '713427006', '67741000119109',
    ]
    weights_file = 'weights.csv'
    normal_class = '426783006'
    equivalent_classes = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]
    args['normal_class'] = normal_class

    # Load the scored classes and the weights for the Challenge metric.
    print('Loading weights...')
    eva_classes, weights = load_weights(weights_file, equivalent_classes)
    print('eva_classes len :', len(eva_classes))
    eva_class_indices = [classes.index(eva_class) for eva_class in eva_classes]
    args['eva_class_indices'] = eva_class_indices
    args['eva_classes'] = eva_classes

    preprocess_path = "preprocess/preprocess_{}.p".format(data_directory)
    if args["preprocess"]:
        print("Preprocessing")
        data, labels, attachedpro = get_training_data(header_files, recording_files, classes, twelve_leads)
        if args['save_preprocessed_data']:
            _dump_pickle((data, labels, attachedpro), preprocess_path)
    else:
        print("Loading preprocessed data")
        data, labels, attachedpro = _load_pickle(preprocess_path)

    # Keep only the scored classes (advanced indexing copies, so `labels` is untouched).
    new_labels = [label[eva_class_indices] for label in labels]

    if args['Othertypes']:
        # Records with no scored label at all are relabelled as the normal class.
        # Work per record: new_labels is a list of 1-D tensors, not a 2-D tensor.
        normal_index = eva_classes.index(normal_class)
        num_unlabelled = 0
        for label in new_labels:
            if float(label.sum()) == 0:
                label[normal_index] = 1
                num_unlabelled += 1
        print('nolabelsum :', num_unlabelled)

    xqrssig = []
    preprocess_sig_path = "preprocess/preprocess_xqrssig_{}.p".format(data_directory)
    if args['add_domain_knowledge']:
        if args["preprocess"]:
            for record in data:
                # QRS detection runs on row 1; fs=257 matches the rate the signals are resampled to.
                sig = record[1].numpy()
                xqrs = processing.XQRS(sig=sig, fs=257)
                xqrs.detect()
                xqrssig.append(_samples_since_qrs(xqrs.qrs_inds, len(sig)))
            if args['save_preprocessed_data']:
                _dump_pickle(xqrssig, preprocess_sig_path)
        else:
            print("Loading xqrs preprocessed data")
            xqrssig = _load_pickle(preprocess_sig_path)
    return data, new_labels, attachedpro, weights, xqrssig


def get_training_data(header_files, recording_files, classes, leads):
    """Load every recording and build aligned per-record feature, label and metadata lists.

    Args:
        header_files: header paths; header_files[i] belongs to recording_files[i].
        recording_files: recording paths; one output entry is produced per item.
        classes: ordered class codes defining the label-vector layout.
        leads: lead names to keep, forwarded to _get_features.

    Returns:
        (data, labels, attachedpro), three lists of len(recording_files):
        - data[i], attachedpro[i]: the signal and demographics from _get_features.
        - labels[i]: a float tensor of len(classes) with 1 at every header label
          found in `classes`. Labels outside `classes` are ignored.

    Raises:
        IndexError: if header_files is shorter than recording_files.
    """
    num_classes = len(classes)
    # O(1) label lookup; first occurrence wins, matching list.index semantics.
    class_index = {}
    for j, code in enumerate(classes):
        class_index.setdefault(code, j)

    data = []
    labels = []
    attachedpro = []
    for i, recording_file in enumerate(recording_files):
        header_file = header_files[i]
        header = load_header(header_file)

        this_attachedpro, this_data = _get_features(header_file, recording_file, leads)
        data.append(this_data)

        # Only count target classes; other labels in the header are skipped.
        this_labels = torch.zeros(num_classes)
        for label in get_labels(header):
            j = class_index.get(label)
            if j is not None:
                this_labels[j] = 1

        labels.append(this_labels)
        attachedpro.append(this_attachedpro)
    print('data len:', len(data))
    print('labels len:', len(labels))
    print('attachedpro len:', len(attachedpro))
    return data, labels, attachedpro
    

def _get_features(header_file, recording_file, leads):
    """Load one training record and return its demographics and preprocessed signal.

    Args:
        header_file: path passed to load_header.
        recording_file: path passed to load_recording.
        leads: lead names to keep, in output order.

    Returns:
        (attachedpro, recording):
        - attachedpro: a float32 tensor of 5 slots
          [age / 120, age missing or outside 0..120, female, male, other sex].
        - recording: a float tensor with one row per entry of `leads`. The
          baseline is subtracted, the result is divided by the ADC gain, and it
          is resampled from the header frequency to 257 Hz.

    Raises:
        ValueError: if a requested lead is not listed in the header.
    """
    header = load_header(header_file)
    recording = torch.Tensor(load_recording(recording_file))

    attachedpro = torch.zeros(5, dtype=torch.float32)

    # None and NaN both fail the range test and land in the "unknown age" slot.
    age = get_age(header)
    if age is not None and 0 <= age <= 120:
        attachedpro[0] = age / 120
    else:
        attachedpro[1] = 1

    # One-hot sex: female, male, or other/unknown.
    sex = get_sex(header)
    if sex in ('Female', 'female', 'F', 'f'):
        attachedpro[2] = 1
    elif sex in ('Male', 'male', 'M', 'm'):
        attachedpro[3] = 1
    else:
        attachedpro[4] = 1

    # Reorder/reselect leads to match `leads`.
    available_leads = get_leads(header)
    recording = recording[[available_leads.index(lead) for lead in leads], :]

    # (sample - baseline) / gain per lead, then resample to 257 Hz.
    baselines = torch.Tensor(get_baselines(header, leads)).unsqueeze(1)
    adc_gains = torch.Tensor(get_adcgains(header, leads)).unsqueeze(1)
    recording = (recording - baselines) / adc_gains
    resample = torchaudio.transforms.Resample(get_frequency(header), 257)
    return attachedpro, resample(recording)
 
# Extract features from the header and recording.
# Currently this 'get_features' function is intended for the test (inference) process.
def get_features(header, recording, leads, len_signal, len_overlap, multi_patch):
    """Build model inputs for one record from an already-loaded header and recording.

    Args:
        header: header text, as consumed by the get_* helpers.
        recording: raw samples indexable as recording[lead_index, sample]
            (converted with torch.Tensor).
        leads: lead names to keep, in output order.
        len_signal: window length in samples, after resampling to 257 Hz.
        len_overlap: samples shared by consecutive windows when multi_patch is set.
        multi_patch: if falsy, return one window holding the first len_signal
            samples. Otherwise cover the whole recording with overlapping windows,
            the last one aligned to the end of the recording.

    Returns:
        (attachedpro, data):
        - attachedpro: a float32 np array of 5 slots
          [age / 120, age missing or outside 0..120, female, male, other sex].
        - data: a float64 np array of shape (num_patch, len(leads), len_signal).
          Windows shorter than len_signal are zero-padded at the end.

    Raises:
        ValueError: if a requested lead is not listed in the header, or if
            multi_patch needs more than one window while len_overlap >= len_signal.
    """
    attachedpro = np.zeros(5, dtype=np.float32)

    # None and NaN both fail the range test and land in the "unknown age" slot.
    age = get_age(header)
    if age is not None and 0 <= age <= 120:
        attachedpro[0] = age / 120
    else:
        attachedpro[1] = 1

    # One-hot sex: female, male, or other/unknown.
    sex = get_sex(header)
    if sex in ('Female', 'female', 'F', 'f'):
        attachedpro[2] = 1
    elif sex in ('Male', 'male', 'M', 'm'):
        attachedpro[3] = 1
    else:
        attachedpro[4] = 1

    # Reorder/reselect leads to match `leads`.
    recording = torch.Tensor(recording)
    available_leads = get_leads(header)
    recording = recording[[available_leads.index(lead) for lead in leads], :]

    # (sample - baseline) / gain per lead, then resample to 257 Hz.
    baselines = torch.Tensor(get_baselines(header, leads)).unsqueeze(1)
    adc_gains = torch.Tensor(get_adcgains(header, leads)).unsqueeze(1)
    recording = (recording - baselines) / adc_gains
    recording = torchaudio.transforms.Resample(get_frequency(header), 257)(recording)

    return attachedpro, _split_into_patches(recording, len_signal, len_overlap, multi_patch)


def _split_into_patches(recording, len_signal, len_overlap, multi_patch):
    """Cut a (num_leads, num_samples) signal into zero-padded windows of len_signal samples."""
    num_leads, len_recording = recording.shape[0], recording.shape[1]

    # Single-patch mode and recordings that fit in one window yield exactly one
    # zero-padded window. This also guards the patch-count formula below, which
    # drops to 0 or below once len_recording <= len_overlap.
    if not multi_patch or len_recording <= len_signal:
        data = np.zeros((1, num_leads, len_signal))
        keep = min(len_recording, len_signal)
        data[0, :, :keep] = recording[:, :keep]
        return data

    step = len_signal - len_overlap
    if step <= 0:
        raise ValueError(
            "len_overlap ({}) must be smaller than len_signal ({})".format(len_overlap, len_signal))

    # ceil((len_recording - len_signal) / step) stride-aligned windows, plus one
    # final window aligned to the end of the recording.
    num_patch = -(-(len_recording - len_signal) // step) + 1
    data = np.zeros((num_patch, num_leads, len_signal))
    cnt = 0
    start = 0
    while len_recording - start > len_signal:
        data[cnt] = recording[:, start:start + len_signal]
        cnt += 1
        start += step
    data[cnt] = recording[:, len_recording - len_signal:]
    return data
