from keras.models import model_from_json
import argparse
import numpy as np
from scipy import signal
from scipy.fftpack import fft
from scipy.io import wavfile

def _bandpass_and_downsample(wave_data, sampling_rate, low_hz, high_hz, filter_order, downsample_rate):
    """Butterworth band-pass *wave_data*, then keep every ``downsample_rate``-th sample."""
    # Cutoffs are normalised to the Nyquist frequency, as scipy.signal.butter expects.
    b, a = signal.butter(filter_order, [low_hz / sampling_rate * 2, high_hz / sampling_rate * 2], btype='bandpass')
    filtered = signal.lfilter(b, a, wave_data)
    # Plain decimation with no extra anti-alias stage; presumably the band-pass above
    # is relied on for that (confirm high_hz < sampling_rate / downsample_rate / 2).
    n_kept = len(wave_data) // downsample_rate
    return filtered[::downsample_rate][:n_kept]


def _magnitude_frames(samples, window_size, step_size):
    """Return |FFT| of every full ``window_size`` frame (hop ``step_size``), first ``window_size // 2`` bins."""
    n_bins = window_size // 2
    # Count every frame that fits completely, including the one ending on the last sample.
    n_frames = max(0, (len(samples) - window_size) // step_size + 1)
    if n_frames == 0:
        return np.zeros((0, n_bins))
    # Row k holds the sample indices of frame k; fft() transforms along the last axis.
    idx = np.arange(n_frames)[:, None] * step_size + np.arange(window_size)
    return np.absolute(fft(samples[idx]))[:, :n_bins]


def _segment(frames, n_timestep, hop):
    """Stack every full run of ``n_timestep`` consecutive frames, starting every ``hop`` frames."""
    n_bins = frames.shape[1]
    n_windows = max(0, (len(frames) - n_timestep) // hop + 1)
    if n_windows == 0:
        return np.zeros((0, n_timestep, n_bins))
    idx = np.arange(n_windows)[:, None] * hop + np.arange(n_timestep)
    # Fancy indexing (n_windows, n_timestep) into (n_frames, n_bins) gives a 3-D copy.
    return frames[idx]


def rnn_from_wave_file(file_name, *, seconds=4.8, window_size=64, step_size=32,
                       downsample_rate=4, data_set_step=32,
                       low_hz=25.0, high_hz=200.0, filter_order=2):
    """Build RNN input sequences of short-time FFT magnitudes from a mono WAV file.

    The signal is band-pass filtered (``low_hz``-``high_hz``, Butterworth of order
    ``filter_order``) and decimated by ``downsample_rate``. It is then cut into
    ``window_size``-sample frames every ``step_size`` samples, and each frame is
    replaced by the magnitudes of its first ``window_size // 2`` FFT bins. Finally,
    runs of ``n_timestep`` consecutive frames are taken every ``data_set_step`` frames,
    where ``n_timestep = int(seconds * rate / downsample_rate / step_size)``.

    The keyword-only parameters default to the values this script has always used.

    Args:
        file_name: Path of the WAV file to read.

    Returns:
        float64 ndarray of shape ``(n_windows, n_timestep, window_size // 2)``.
        ``n_windows`` is 0 when the recording is too short for one full window.

    Raises:
        ValueError: if the file has more than one channel, or if scipy rejects the
            filter cutoffs for this sampling rate.
    """
    sampling_rate, wave_data = wavfile.read(file_name)
    if wave_data.ndim != 1:
        raise ValueError('expected a mono WAV file, got data of shape {}'.format(wave_data.shape))
    n_timestep = int(seconds * sampling_rate / downsample_rate / step_size)
    samples = _bandpass_and_downsample(wave_data, sampling_rate, low_hz, high_hz, filter_order, downsample_rate)
    frames = _magnitude_frames(samples, window_size, step_size)
    return _segment(frames, n_timestep, data_set_step)

if __name__ == '__main__':
    # Command-line entry point: load a Keras model from a JSON architecture file plus
    # a weight file, featurize one WAV file and print the model's predictions.
    parser = argparse.ArgumentParser(description='Predict by RNN in freq domain with Physionet 2016 challenge data')
    parser.add_argument('-mf', dest='mf', help='model file', required=True)
    parser.add_argument('-wf', dest='wf', help='weight file', required=True)
    parser.add_argument('-wav', dest='wav', default='pre.wav',
                        help='input WAV file (default: pre.wav)')
    args = parser.parse_args()

    # Read the architecture inside a context manager so the file handle is closed.
    with open(args.mf) as model_file:
        model = model_from_json(model_file.read())
    model.load_weights(args.wf)
    model.compile(loss='binary_crossentropy', optimizer='RMSprop', metrics=['accuracy'])

    data = rnn_from_wave_file(args.wav)
    if len(data) == 0:
        # Avoid calling predict() on an empty batch; report the problem instead.
        raise SystemExit('{}: recording too short to extract any input window'.format(args.wav))
    result = model.predict(data)
    # One prediction row per extracted window.
    print(result)
