import numpy as np

def get_fft(signal, fs=500, upper_freq=50, bins_per_freq=2):
    """Return the FFT power of ``signal`` summed into fixed-width frequency bins.

    The power spectrum ``|FFT(signal)|**2`` is computed. Every spectral line
    whose frequency lies strictly between 0 and ``upper_freq`` is added into
    bin ``int(freq * bins_per_freq)``, so each bin is ``1 / bins_per_freq``
    wide. The DC term and negative frequencies are excluded.

    Args:
        signal: 1-D array-like of samples.
        fs: sampling rate. Frequencies are in the same unit (Hz if ``fs`` is in Hz).
        upper_freq: exclusive upper frequency limit.
        bins_per_freq: number of bins per unit of frequency.
            ``upper_freq * bins_per_freq`` must be an integer (it sizes the output).

    Returns:
        float64 array of length ``upper_freq * bins_per_freq``. It is all zeros
        for an empty ``signal``.
    """
    signal = np.asarray(signal)
    n_bins = upper_freq * bins_per_freq
    if signal.size == 0:
        # np.fft.fft rejects zero-length input; an empty segment has no power.
        return np.zeros(n_bins)

    power = np.absolute(np.fft.fft(signal)) ** 2
    freqs = np.fft.fftfreq(signal.size, 1.0 / fs)

    # For even lengths fftfreq reports the Nyquist line as negative,
    # so it is excluded together with DC and the other negative frequencies.
    keep = (freqs > 0) & (freqs < upper_freq)
    idx = (freqs[keep] * bins_per_freq).astype(int)
    # Guard against float rounding pushing an index onto n_bins. Without the
    # clip, bincount would silently return a longer vector.
    idx = np.minimum(idx, n_bins - 1)
    # bincount accumulates weights in input order, matching a per-line += loop.
    return np.bincount(idx, weights=power[keep], minlength=n_bins)
    


def get_fft_feats(signals, peaks, fs, upper_freq=50, bins_per_freq=2):
    """Build one row of binned FFT-power features per beat across all leads.

    A beat is the sample span ``peaks[i]:peaks[i + 1]`` between consecutive
    entries of ``peaks``. For each beat, ``get_fft`` is applied to that span
    of every lead. The per-lead feature vectors are laid side by side in lead
    order.

    Args:
        signals: sequence of per-lead sample arrays (e.g. a 2-D array indexed
            ``signals[lead][sample]``).
        peaks: array-like of sample indices delimiting beats.
        fs: sampling rate, forwarded to ``get_fft``.
        upper_freq: forwarded to ``get_fft``.
        bins_per_freq: forwarded to ``get_fft``.
            Each lead contributes ``upper_freq * bins_per_freq`` columns.

    Returns:
        float64 array of shape
        ``(max(len(peaks) - 1, 0), n_leads * upper_freq * bins_per_freq)``.
        It has zero rows when fewer than two peaks are given.
    """
    peaks = np.asarray(peaks)
    n_leads = len(signals)
    n_bins = upper_freq * bins_per_freq
    # Fewer than two peaks means no complete beat. Clamp so np.zeros never
    # receives a negative dimension.
    n_beats = max(peaks.size - 1, 0)
    ffts = np.zeros([n_beats, n_leads * n_bins])

    for beat, (beat_start, beat_end) in enumerate(zip(peaks[:-1], peaks[1:])):
        for lead, signal in enumerate(signals):
            cols = slice(lead * n_bins, (lead + 1) * n_bins)
            ffts[beat, cols] = get_fft(signal[beat_start:beat_end],
                                       fs,
                                       upper_freq,
                                       bins_per_freq)

    return ffts