import torch
import numpy as np

def check_datapoint_ok(x):
    """Return True when *x* is non-empty and none of its values is None.

    Every ``key`` produced by iterating ``x`` is looked up as ``x[key]``, so
    ``x`` is expected to be a mapping such as a per-example feature dict.
    A falsy ``x`` (e.g. ``None`` or ``{}``) is rejected outright.
    """
    return bool(x) and all(x[key] is not None for key in x)


def batchify(X_batch_dict_list, Y_batch_list, device='cpu', feat_names=None):
    """Stack a list of per-example feature dicts and labels into float tensors.

    Args:
        X_batch_dict_list: list of dicts, one per example. Each dict must
            contain every key in ``feat_names`` (otherwise ``KeyError``).
            The arrays stored under one key are stacked row-wise with
            ``np.vstack``, so they must agree in all dimensions but the first.
        Y_batch_list: list of per-example label arrays, stacked with
            ``np.vstack``.
        device: torch device on which the resulting tensors are created.
        feat_names: feature keys to batch. Defaults to
            ``('header_feats', 'peak_feats', 'peak_delta_feats', 'fft_feats')``.

    Returns:
        ``(X_batch, Y_batch)`` where ``X_batch`` maps each feature name to a
        ``torch.float`` tensor and ``Y_batch`` is a ``torch.float`` tensor.

    Raises:
        ValueError: propagated from ``np.vstack`` when an input list is empty
            or the stacked arrays have incompatible shapes.
    """
    if feat_names is None:
        feat_names = ('header_feats', 'peak_feats',
                      'peak_delta_feats', 'fft_feats')

    # np.vstack always returns a freshly allocated array, so as_tensor never
    # aliases caller-owned data; building on `device` directly avoids an extra
    # CPU tensor followed by a .to() copy.
    X_batch = {
        name: torch.as_tensor(np.vstack([x[name] for x in X_batch_dict_list]),
                              dtype=torch.float, device=device)
        for name in feat_names
    }
    Y_batch = torch.as_tensor(np.vstack(Y_batch_list),
                              dtype=torch.float, device=device)
    return X_batch, Y_batch
    
    
def batcher(X, Y, batch_size=64, device=None, max_seq_len=10, ids=None):
    """Yield mini-batches in which every example has the same beat count.

    Examples are bucketed by ``x['peak_delta_feats'].shape[1]`` (the beat
    count) so each batch can be stacked by ``batchify``. A bucket is emitted
    as soon as it holds ``batch_size`` examples; once the input is exhausted,
    remaining partial buckets are emitted in increasing beat-count order.
    Examples rejected by ``check_datapoint_ok`` are skipped.

    Args:
        X: iterable of per-example feature dicts.
        Y: iterable of per-example labels, aligned with ``X``.
        batch_size: number of examples in a full batch.
        device: forwarded unchanged to ``batchify``.
        max_seq_len: largest supported beat count; one bucket is kept for
            each count in ``1..max_seq_len``.
        ids: which elements to take, used for train/test split. Positions
            (in ``enumerate(zip(X, Y))`` order) not in ``ids`` are skipped.
            Its elements must be hashable and compare equal to ``int``
            positions, e.g. a list/range of ints or a NumPy integer array.

    Yields:
        ``(X_batch, Y_batch)`` tuples as returned by ``batchify``.

    Raises:
        ValueError: if an accepted example's beat count is outside
            ``1..max_seq_len``.
    """
    # Convert once: `in` on a list/array is O(len(ids)) per example.
    wanted = None if ids is None else set(ids)
    x_buckets = [[] for _ in range(max_seq_len)]
    y_buckets = [[] for _ in range(max_seq_len)]

    for idx, (x, y) in enumerate(zip(X, Y)):
        if wanted is not None and idx not in wanted:
            continue
        if not check_datapoint_ok(x):
            continue
        n_beats = x['peak_delta_feats'].shape[1]
        # Without this check a beat count of 0 would index bucket -1 and be
        # silently mixed with the max_seq_len bucket.
        if not 1 <= n_beats <= max_seq_len:
            raise ValueError(
                f"example {idx} has {n_beats} beats; "
                f"expected 1..{max_seq_len}")
        bucket = n_beats - 1
        x_buckets[bucket].append(x)
        y_buckets[bucket].append(y)

        if len(x_buckets[bucket]) == batch_size:
            yield batchify(x_buckets[bucket], y_buckets[bucket],
                           device=device)
            # Reset batch bucket.
            x_buckets[bucket] = []
            y_buckets[bucket] = []

    # Flush partially filled buckets, shortest beat count first.
    for x_rest, y_rest in zip(x_buckets, y_buckets):
        if x_rest:
            yield batchify(x_rest, y_rest, device=device)
        
    