# Pre-allocates fixed-shape tensors for all training data up front, in order to avoid out-of-memory errors while loading.

import numpy as np
import torch
from helper_code import load_header, get_labels, load_recording, get_age, get_sex, get_leads, get_adcgains,  get_baselines
import pdb

def get_training_data_pre_define(header_files, recording_files, classes, leads, len_signal, multi_patch=False, len_overlap=256, num_samples=72590):
    """Load every recording into pre-allocated, fixed-shape tensors.

    The buffers are allocated once, before any file is read, and are then
    filled row by row by ``get_features_train``.

    Args:
        header_files: Sequence of header file paths, index-aligned with
            ``recording_files``.
        recording_files: Sequence of recording file paths.
        classes: List of target class labels; labels that are not in this
            list are ignored.
        leads: Lead names to select from every recording, in output order.
        len_signal: Number of samples per output row.
        multi_patch: If true, a recording longer than ``len_signal`` is cut
            into several overlapping windows (one row each); otherwise it is
            truncated to its first ``len_signal`` samples.
        len_overlap: Overlap in samples between consecutive windows when
            ``multi_patch`` is true.
        num_samples: Number of rows to pre-allocate. It must be at least the
            total number of rows produced by all recordings.

    Returns:
        Tuple ``(data, labels, attachedpro)`` with shapes
        ``(num_samples, len(leads), len_signal)``,
        ``(num_samples, len(classes))`` and ``(num_samples, 5)``. Rows beyond
        the ones that were filled stay zero.

    Raises:
        IndexError: If the recordings need more than ``num_samples`` rows.
    """
    num_classes = len(classes)
    num_leads = len(leads)

    data = torch.zeros((num_samples, num_leads, len_signal), dtype=torch.float)
    labels = torch.zeros((num_samples, num_classes), dtype=torch.bool)
    attachedpro = torch.zeros((num_samples, 5), dtype=torch.float)

    flag = 0  # index of the next free row in the buffers
    for i, recording_file in enumerate(recording_files):  # one record at a time
        data, labels, attachedpro, flag = get_features_train(
            flag, data, labels, attachedpro, header_files[i], recording_file,
            leads, len_signal, len_overlap, multi_patch, classes=classes)

    return data, labels, attachedpro


def _patch_starts(len_recording, len_signal, len_overlap):
    """Return the start offsets of the windows that cover one recording."""
    if len_recording <= len_signal:
        # A short recording always occupies exactly one zero-padded row.
        return [0]

    stride = len_signal - len_overlap
    if stride <= 0:
        raise ValueError(
            f"len_overlap ({len_overlap}) must be smaller than "
            f"len_signal ({len_signal})")

    # Regular windows, followed by one final window aligned to the end of the
    # recording so that the tail is always covered.
    starts = list(range(0, len_recording - len_signal, stride))
    starts.append(len_recording - len_signal)
    return starts


def get_features_train(flag, data, labels, attachedpro, header_file, recording_file, leads, len_signal, len_overlap, multi_patch, classes=None):
    """Load one recording and write its rows into the shared buffers.

    Args:
        flag: Index of the first free row in ``data``/``labels``/``attachedpro``.
        data: Signal buffer indexed as ``[row, lead, sample]``; modified in place.
        labels: Multi-hot label buffer indexed as ``[row, class]``; modified
            in place.
        attachedpro: Metadata buffer with five columns per row:
            ``[age, age_missing, is_female, is_male, sex_unknown]``;
            modified in place.
        header_file: Path of the header file of this recording.
        recording_file: Path of the signal file of this recording.
        leads: Lead names to select, in output order.
        len_signal: Number of samples per output row.
        len_overlap: Overlap in samples between consecutive windows
            (only used when ``multi_patch`` is true).
        multi_patch: If true, split a long recording into overlapping windows;
            otherwise keep only its first ``len_signal`` samples.
        classes: List of target class labels. Labels outside this list are
            ignored.

    Returns:
        Tuple ``(data, labels, attachedpro, flag)`` where ``flag`` has been
        advanced by the number of rows written.

    Raises:
        ValueError: If ``classes`` is not available, if a requested lead is
            missing from the header, or if ``len_overlap >= len_signal`` while
            a recording has to be split.
        IndexError: If the buffers have too few rows left for this recording.
    """
    if classes is None:
        # Legacy call path: earlier versions read ``classes`` as a global.
        classes = globals().get("classes")
        if classes is None:
            raise ValueError("classes must be provided to get_features_train")

    header = load_header(header_file)
    recording = torch.Tensor(load_recording(recording_file))  # indexed [lead, sample]

    # Metadata row: [age, age_missing, is_female, is_male, sex_unknown].
    meta_row = torch.zeros(attachedpro.shape[1], dtype=attachedpro.dtype)
    age = get_age(header)
    if age is None:
        age = float('nan')
    if 0 <= age <= 120:  # NaN fails this comparison and is flagged as missing
        meta_row[0] = age
    else:
        meta_row[1] = 1

    # Sex is one-hot encoded over (female, male, other/unknown).
    sex = get_sex(header)
    if sex in ('Female', 'female', 'F', 'f'):
        meta_row[2] = 1
    elif sex in ('Male', 'male', 'M', 'm'):
        meta_row[3] = 1
    else:
        meta_row[4] = 1

    # Only count target classes (some diagnoses are not scored).
    label_row = torch.zeros(labels.shape[1], dtype=labels.dtype)
    for label in get_labels(header):
        if label in classes:
            label_row[classes.index(label)] = True

    # Reorder/reselect leads in the recording.
    available_leads = get_leads(header)
    indices = [available_leads.index(lead) for lead in leads]
    recording = recording[indices, :]

    # Convert raw ADC units: subtract the baseline and divide by the gain.
    # unsqueeze turns the per-lead vectors into columns so they broadcast
    # across the sample axis.
    adc_gains = torch.unsqueeze(torch.Tensor(get_adcgains(header, leads)), dim=1)
    baselines = torch.unsqueeze(torch.Tensor(get_baselines(header, leads)), dim=1)
    len_recording = len(recording[0])
    recording = (recording - baselines) / adc_gains

    if multi_patch:
        starts = _patch_starts(len_recording, len_signal, len_overlap)
    else:
        starts = [0]
    num_patch = len(starts)

    end = flag + num_patch
    if end > data.shape[0]:
        raise IndexError(
            f"pre-allocated buffers hold {data.shape[0]} rows, but rows "
            f"{flag}..{end - 1} are needed; increase num_samples")

    if len_recording <= len_signal:
        # Short recording: copy it and keep the tail zero-padded.
        data[flag, :, :len_recording] = recording
        data[flag, :, len_recording:] = 0
    else:
        for offset, start in enumerate(starts):
            data[flag + offset] = recording[:, start:start + len_signal]

    # Every window of a recording shares its labels and metadata.
    labels[flag:end] = label_row
    attachedpro[flag:end] = meta_row

    return data, labels, attachedpro, end
 