from __future__ import print_function

import sys
import os
import glob
import getopt
import scipy.io.wavfile

import keras.models as km
import keras.layers as kl
import keras.optimizers as ko

def compile_model(nhid=100, weights_path="weights.h5py"):
	"""Build the single-LSTM sign model and load its trained weights.

	nhid -- number of LSTM units. It must match the architecture that
	        weights_path was saved from; presumably load_weights rejects a
	        mismatched layer shape.
	weights_path -- weights file to load. The default "weights.h5py" keeps the
	        original behaviour; exposing it lets a non-default nhid be paired
	        with matching weights.

	Returns the compiled keras Sequential model: LSTM (1 input feature,
	final output only) followed by a Dense layer producing one value.
	"""
	net = km.Sequential()
	# return_sequences=False: only the LSTM's last output feeds the Dense layer.
	net.add(kl.LSTM(input_dim=1, output_dim=nhid, return_sequences=False))
	net.add(kl.Dense(input_dim=nhid, output_dim=1))
	# compile() is required by keras before use; in the code shown the model
	# is only used for predict(), so loss/optimizer do not affect the output.
	net.compile(loss='mse', optimizer='adam')
	net.load_weights(weights_path)
	return net

def compile_pcg2ecg_t():
	"""Build the sequence-to-sequence "tLSTM" model and load its weights.

	The name suggests it maps a PCG signal to an estimated ECG signal
	(NOTE(review): inferred from the name only — confirm). Architecture:
	LSTM with 32 units returning the full sequence, then a per-timestep
	dense layer with one output, with linear activations in between.
	Weights are read from "weights_tLSTM.h5py".
	"""
	hidden = 32
	stack = [
		kl.LSTM(input_dim=1, output_dim=hidden, return_sequences=True),
		kl.Activation('linear'),
		kl.TimeDistributedDense(input_dim=hidden, output_dim=1),
		kl.Activation('linear'),
	]
	# Sequential(layers) adds each layer in order, same as repeated add().
	model = km.Sequential(stack)
	model.compile(loss="mse", optimizer="adam")
	model.load_weights("weights_tLSTM.h5py")
	return model

def compile_pcg2ecg_c():
	"""Build the "cLSTM" classifier and load its weights.

	Architecture: LSTM with 100 units over a 1-feature sequence, keeping only
	its final output, then a Dense layer producing one value. The model is
	compiled with MSE loss and RMSprop before loading "weights_cLSTM.h5py".
	"""
	hidden = 100
	model = km.Sequential([
		kl.LSTM(input_dim=1, output_dim=hidden, return_sequences=False),
		kl.Dense(input_dim=hidden, output_dim=1),
	])
	model.compile(loss='mse', optimizer=ko.RMSprop())
	model.load_weights("weights_cLSTM.h5py")
	return model

def predict_with_model(model, dataset):
	"""Return the sign (-1 or 1) of model's prediction for one recording.

	dataset -- sample array reshaped to (1, len(dataset), 1), i.e. a batch of
	           one single-feature sequence. This assumes mono audio (one
	           channel).
	Returns -1 when the prediction is negative, otherwise 1 (zero maps to 1).
	The prediction is assumed to hold a single value; a multi-element result
	would make the comparison ambiguous.
	"""
	batch = dataset.reshape((1, len(dataset), 1))
	score = model.predict(batch)
	if score < 0:
		return -1
	return 1

def predict_pcg2ecg(tLSTM, cLSTM, data):
	"""Return the sign (-1 or 1) of the two-stage tLSTM -> cLSTM prediction.

	data is reshaped to (1, len(data), 1) and fed to tLSTM. tLSTM's output is
	passed unchanged to cLSTM. This presumably uses the models from
	compile_pcg2ecg_t / compile_pcg2ecg_c.
	Returns -1 when cLSTM's prediction is negative, otherwise 1.
	"""
	batch = data.reshape((1, len(data), 1))
	score = cLSTM.predict(tLSTM.predict(batch))
	if score < 0:
		return -1
	return 1

def predict_and_write(model, infile, outfile):
	"""Classify one recording with model and append the result to outfile.

	infile -- recording path WITHOUT the ".wav" extension. ".wav" is appended
	          before reading, and only the basename of infile is written.
	outfile -- text file that receives one "<basename>,<sign>" line. It is
	           created if it does not exist.
	"""
	_, data = scipy.io.wavfile.read(infile + ".wav")
	# Run inference before opening the output. If prediction fails, the
	# answers file is neither created nor held open.
	pred = predict_with_model(model, data)
	# The file is only appended to, so plain 'a' is enough (no read access).
	with open(outfile, 'a') as f:
		f.write("%s,%s\n" % (os.path.basename(infile), str(pred)))

def predict_and_write_pcg2ecg(tLSTM, cLSTM, infile, outfile):
	"""Classify one recording with the tLSTM -> cLSTM pipeline and append it.

	tLSTM, cLSTM -- models passed straight through to predict_pcg2ecg.
	infile -- recording path WITHOUT the ".wav" extension. ".wav" is appended
	          before reading, and only the basename of infile is written.
	outfile -- text file that receives one "<basename>,<sign>" line. It is
	           created if it does not exist.
	"""
	_, data = scipy.io.wavfile.read(infile + ".wav")
	# Run inference before opening the output. If prediction fails, the
	# answers file is neither created nor held open.
	pred = predict_pcg2ecg(tLSTM, cLSTM, data)
	# The file is only appended to, so plain 'a' is enough (no read access).
	with open(outfile, 'a') as f:
		f.write("%s,%s\n" % (os.path.basename(infile), str(pred)))

def create_answers(dirname, ans="answers.txt"):
	"""Rebuild the answers file from every "*.wav" recording in dirname.

	dirname -- directory searched (non-recursively) for "*.wav" files.
	ans -- output file (default "answers.txt"). It is deleted first, then
	       predict_and_write appends one line per recording.

	Recordings are processed in sorted path order. glob's own order depends
	on the filesystem, so sorting makes the answers file reproducible.
	"""
	if os.path.exists(ans):
		os.remove(ans)
	m = compile_model()
	for wav in sorted(glob.glob(os.path.join(dirname, "*.wav"))):
		# predict_and_write re-appends ".wav", so pass the path without it.
		predict_and_write(m, os.path.splitext(wav)[0], ans)

def create_answers_pcg2ecg(dirname, ans="answers.txt"):
	"""Rebuild the answers file using the tLSTM -> cLSTM pipeline.

	dirname -- directory searched (non-recursively) for "*.wav" files.
	ans -- output file (default "answers.txt"). It is deleted first, then
	       predict_and_write_pcg2ecg appends one line per recording.

	Both models are built once and reused for every file. Recordings are
	processed in sorted path order so the answers file is reproducible.
	"""
	if os.path.exists(ans):
		os.remove(ans)
	t = compile_pcg2ecg_t()
	c = compile_pcg2ecg_c()
	for wav in sorted(glob.glob(os.path.join(dirname, "*.wav"))):
		# predict_and_write_pcg2ecg re-appends ".wav", so strip it here.
		predict_and_write_pcg2ecg(t, c, os.path.splitext(wav)[0], ans)

def main(argv=None):
	"""Parse the command line and run the selected mode.

	argv -- full argument vector including the program name (defaults to
	        sys.argv).

	Modes:
	  predict <name>    append the compile_model() prediction for <name>.wav
	  answers           rebuild answers.txt from ../validation
	  predict-t <name>  like predict, using the tLSTM -> cLSTM pipeline
	  answers-t         like answers, using the tLSTM -> cLSTM pipeline

	If the mode is missing or unknown, or a required <name> is missing, the
	usage text is printed to stderr and the process exits with status 2.
	"""
	if argv is None:
		argv = sys.argv
	usage = "usage:"
	usage += "\nrun_lstm.py predict <filename> (write prediction to answers.txt)"
	usage += "\nrun_lstm.py answers (delete and rebuild answers.txt)"
	usage += "\nrun_lstm.py predict-t <filename> (like predict, using the tLSTM/cLSTM models)"
	usage += "\nrun_lstm.py answers-t (like answers, using the tLSTM/cLSTM models)"
	usage += "\n<filename> is given without its .wav extension"
	# Number of arguments each mode needs after the mode name itself.
	argcount = {"predict": 1, "answers": 0, "predict-t": 1, "answers-t": 0}
	mode = argv[1] if len(argv) > 1 else None
	# The membership test runs first so an unknown mode cannot raise KeyError.
	# argv also holds the program name and the mode, hence the "+ 2".
	if mode not in argcount or len(argv) < argcount[mode] + 2:
		print(usage, file=sys.stderr)
		sys.exit(2)
	if mode == "predict":
		m = compile_model()
		predict_and_write(m, argv[2], "answers.txt")
	elif mode == "answers":
		create_answers("../validation")
	elif mode == "predict-t":
		t = compile_pcg2ecg_t()
		c = compile_pcg2ecg_c()
		predict_and_write_pcg2ecg(t, c, argv[2], "answers.txt")
	else:  # "answers-t" — the only mode left after validation
		create_answers_pcg2ecg("../validation")


if __name__ == "__main__":
	main()
