import argparse
from sklearn.externals import joblib
import numpy as np
import scipy.io
from scipy import signal
from biosppy.signals import ecg
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from biosppy.signals.tools import filter_signal
from sklearn.decomposition import PCA
from scipy.signal import resample
from scipy import spatial
from numpy.linalg import norm
import scipy.fftpack
from scipy.stats import stats, entropy
from biosppy.signals.tools import find_extrema, get_heart_rate
from collections import Counter
import itertools
from keras.utils import np_utils
import pickle

STEPS_COUNT = 23
FEATURES_LENGTH = 46

def hrto(rri):
    """Return the relative change between neighbouring R-R interval pair sums.

    For every interior position ``c`` (1 .. len(rri) - 2) the value
    ``((rri[c+1] + rri[c]) - (rri[c] + rri[c-1])) / (rri[c] + rri[c-1])``
    is stored; the first and last entries stay 0. The name suggests a
    heart-rate-turbulence-onset style feature (unconfirmed).

    Args:
        rri: sequence of R-R intervals.

    Returns:
        Float array with the same length as ``rri``.
    """
    onset = np.zeros((len(rri)))
    for centre in range(1, len(rri) - 1):
        prev_pair = rri[centre] + rri[centre - 1]
        next_pair = rri[centre + 1] + rri[centre]
        onset[centre] = (next_pair - prev_pair) / prev_pair
    return onset
from scipy.stats import stats

def peaksInQS(signal, q, s, trim=55):
    """Return the spread of the slope in each S-to-next-Q interval.

    For every consecutive beat pair ``x, x+1`` the samples
    ``signal[s[x]:q[x+1]]`` are taken, ``trim`` samples are dropped from each
    end, and the sample standard deviation (ddof=1) of the first difference
    of what remains is recorded.

    Args:
        signal: 1-D sequence of samples.
        q: Q-point sample indices, one per beat (same length as ``s``).
        s: S-point sample indices, one per beat.
        trim: number of samples removed from each end of every interval
            (default 55).

    Returns:
        Float array of length ``max(len(s) - 1, 0)``. Entries whose trimmed
        interval has 4 samples or fewer are left at 0. An empty ``s`` yields
        an empty array instead of raising.
    """
    stds = np.zeros(max(len(s) - 1, 0))

    for x in range(len(stds)):
        segment = signal[s[x]:q[x + 1]]
        segment = segment[trim:]
        # Same two-step trim as before: a negative stop still slices from the end.
        segment = segment[:len(segment) - trim]
        if len(segment) > 4:
            stds[x] = np.std(np.diff(segment), ddof=1)

    return stds

def qs_signal(signal, q, s):
    """Find T- and P-wave candidates between each S point and the next Q point.

    For every consecutive beat pair ``x, x+1`` the samples
    ``signal[s[x]:q[x+1]]`` are trimmed by 55 samples at each end. Running
    means with windows 15 and 9 are compared. If they cross at least twice,
    the crossing positions are passed to ``mean_by_indices(..., th=1)`` to
    build a +1/-1 mask (segment mean above/below the interval mean).
    Otherwise the mask is all -1. The widths between mask sign changes are
    filtered so that only the +1 runs keep a non-zero width. The two widest
    such runs are taken in positional order (earlier -> T, later -> P), and
    the arg-max sample inside each run gives the T / P position.

    Args:
        signal: 1-D signal samples.
        q: Q-point sample indices (one per beat, same length as ``s``).
        s: S-point sample indices.

    Returns:
        Tuple ``(num_peaks, t, p, t_width, p_width)``. Each element has one
        entry per S-to-next-Q interval (``len(s) - 1`` of them):
        - num_peaks: a count derived from the number of mask sign changes.
        - t, p: T and P positions as int indices into ``signal``.
        - t_width, p_width: the widths (in samples) of the runs T and P were
          taken from.
    """
    
    # NOTE(review): samples, sds and qses are accumulated below but never
    # returned or otherwise used; z is assigned but unused.
    samples = np.zeros((0,))
    sds = np.zeros((0,))
    num_peaks = np.zeros((0,))
    t = np.zeros((0,))
    p = np.zeros((0,))
    t_width = np.zeros((0,))
    p_width = np.zeros((0,))
    qses = []

    for x in range(len(s)-1):
        # Interval from this beat's S point to the next beat's Q point.
        signal_range = signal[s[x]:q[x+1]]
        z = 0
        #signal_range[signal_range<-40] = 0
        # Drop 55 samples at each end of the interval (undone at the bottom
        # of the loop when converting back to signal indices).
        signal_range = signal_range[55:]
        signal_range = signal_range[:len(signal_range)-55]
            
        #signal_range = null_outliers(signal_range)
        samples = np.hstack((samples, signal_range))
        qses.append(signal_range)
        sd_arange = np.zeros((len(signal_range),))
        sd_arange[sd_arange==0] = np.std(signal_range, ddof=1)
        sds = np.hstack((sds, sd_arange))
        
        mean2 = running_mean(signal_range, 15)
        mean1 = running_mean(signal_range, 9)
        
        # Positions where the 9-sample running mean crosses the 15-sample one.
        idx = np.argwhere(np.diff(np.sign(mean1[:len(mean2)] - mean2)) != 0).reshape(-1) + 0
        if len(idx) < 2:
            mask = np.ones((len(signal_range)))*-1
        else:
            # +1 / -1 per sample: segment mean above / below the interval mean.
            mask = mean_by_indices(signal_range, idx, 1)
        
        # Positions where the mask changes sign.
        indices = np.where(np.diff(np.sign(mask)))[0]
        
        # Each branch sets num_peaks from the number of sign changes (adjusted
        # by the mask's first/last value). It then zeroes every other width
        # so that, for a non-empty mask, only the +1 runs keep a non-zero
        # width; which parity is zeroed depends on mask[0].
        # Note: widths_filtered aliases widths (no copy).
        if len(mask) == 0:
            num_peaks = np.hstack((num_peaks, [len(indices)]))
            indices_new = np.hstack(([0], indices))
            indices_new = np.hstack((indices_new, [len(signal_range)]))
            widths = np.diff(indices_new)
            
            widths_filtered = widths
            widths_filtered[1::2] = 0
        elif mask[0] == 1 and mask[-1] == 1: 
            num_peaks = np.hstack((num_peaks, [len(indices)]))
            indices_new = np.hstack(([0], indices))
            indices_new = np.hstack((indices_new, [len(signal_range)]))
            widths = np.diff(indices_new)
            
            widths_filtered = widths
            widths_filtered[1::2] = 0

        elif mask[0] == -1 and mask[-1] == -1:
            num_peaks = np.hstack((num_peaks, [len(indices)-2]))
            indices_new = np.hstack(([0], indices))
            indices_new = np.hstack((indices_new, [len(signal_range)]))
            widths = np.diff(indices_new)
            
            widths_filtered = widths
            widths_filtered[::2] = 0
        elif mask[0] == 1 and mask[-1] == -1:
            num_peaks = np.hstack((num_peaks, [len(indices)-1]))
            indices_new = np.hstack(([0], indices))
            indices_new = np.hstack((indices_new, [len(signal_range)]))
            widths = np.diff(indices_new)
            
            widths_filtered = widths
            widths_filtered[1::2] = 0
        elif mask[0] == -1 and mask[-1] == 1:
            num_peaks = np.hstack((num_peaks, [len(indices)-1]))
            indices_new = np.hstack(([0], indices))
            indices_new = np.hstack((indices_new, [len(signal_range)]))
            widths = np.diff(indices_new)
            widths_filtered = widths
            widths_filtered[::2] = 0

        
        # Positions of the two widest kept runs, sorted so that
        # max_indexes[0] is the earlier (T) run and max_indexes[1] the later
        # (P) run.
        if len(widths_filtered) < 2:
            max_indexes = np.array([0, 0])
            max_widths = np.array([0, 0])
        else:
            max_indexes = np.sort(np.argpartition(widths_filtered, -2)[-2:])
            max_widths = np.take(widths_filtered, max_indexes)
        # T offset within signal_range: arg-max inside the earlier run plus
        # the run's start. The if/else branches below are identical.
        # NOTE(review): the bare except (e.g. argmax of an empty run) falls
        # back to argmax over the whole trimmed range but still adds the run's
        # start offset, which looks inconsistent -- confirm intent.
        try:
            if max_indexes[0]+1 > len(indices_new)-1:
                ts = np.argmax(signal_range[indices_new[max_indexes[0]]:indices_new[max_indexes[0]+1]])
                ts += len(signal_range[:indices_new[max_indexes[0]]])
            else:

                ts = np.argmax(signal_range[indices_new[max_indexes[0]]:indices_new[max_indexes[0]+1]])
                ts += len(signal_range[:indices_new[max_indexes[0]]])
        except:
            if len(signal_range) > 0:
                ts = np.argmax(signal_range[:])
                ts += len(signal_range[:indices_new[max_indexes[0]]])
            else:
                ts = 0
                ts += len(signal_range[:indices_new[max_indexes[0]]])
        
        # P offset: same procedure for the later run (same review note).
        try:
            if max_indexes[1]+1 > len(indices_new)-1:

                ps = np.argmax(signal_range[indices_new[max_indexes[1]]:indices_new[max_indexes[1]+1]])
                ps += len(signal_range[:indices_new[max_indexes[1]]])
            else:
                ps = np.argmax(signal_range[indices_new[max_indexes[1]]:indices_new[max_indexes[1]+1]])
                ps += len(signal_range[:indices_new[max_indexes[1]]])
        except:
            if len(signal_range) > 0:
                ps = np.argmax(signal_range[:])
                ps += len(signal_range[:indices_new[max_indexes[1]]])
            else:
                ps = 0
                ps += len(signal_range[:indices_new[max_indexes[1]]])
        
        # Convert offsets back to indices into `signal` (undo the 55-sample trim).
        t = np.hstack((t, [s[x]+ts+55]))
        p = np.hstack((p, [s[x]+ps+55]))
        
        t_width = np.hstack((t_width, max_widths[0]))
        p_width = np.hstack((p_width, max_widths[1]))
        
    p = np.array(p, dtype='int')
    t = np.array(t, dtype='int')
    
    return num_peaks, t, p, t_width, p_width

def extract_qs(signal, rpeaks, before, after):
    """Locate a Q point before and an S point after each R peak.

    Peaks are processed in sorted order. The window for peak ``r`` is
    ``[int(r - before), int(r + after))``. A peak whose window starts before
    index 0 is skipped. Processing stops at the first peak whose window ends
    past ``len(signal)``. Q (S) is the first index of the minimum of
    ``signal`` in the window part before (from) the peak. The maximum is
    used instead when ``signal[r] <= 0``.

    Args:
        signal: 1-D array of samples.
        rpeaks: R-peak sample indices (any order).
        before: samples to search before each peak (truncated with int()).
        after: samples to search after each peak (truncated with int()).

    Returns:
        Tuple ``(q, s)`` of int arrays with one entry per accepted peak.
    """
    total = len(signal)
    q_points = []
    s_points = []

    for peak in np.sort(rpeaks):
        start = int(peak - before)
        if start < 0:
            continue
        stop = int(peak + after)
        if stop > total:
            break

        # Q/S are the extreme opposite in sign to the R peak.
        pick = np.min if signal[peak] > 0 else np.max

        left = signal[start:peak]
        q_points.append(start + np.flatnonzero(left == pick(left))[0])

        right = signal[peak:stop]
        s_points.append(peak + np.flatnonzero(right == pick(right))[0])

    return np.array(q_points, dtype='int'), np.array(s_points, dtype='int')

def running_mean(x, N):
    """Return the length-``N`` moving average of ``x`` over full windows only.

    Computed from prefix sums: ``len(x) - N + 1`` values for ``N <= len(x)``.
    """
    prefix = np.insert(x, 0, 0)
    prefix = np.cumsum(prefix)
    window_sums = prefix[N:] - prefix[:-N]
    return window_sums / N

def mean_by_indices(qses, indices, th=15):
    """Label each sample +1/-1 by whether its segment's mean beats the overall mean.

    ``indices`` are split positions into ``qses``. A split position is
    discarded when it lies fewer than ``th`` samples after the previous
    *unfiltered* position (or after 0, for the first one). The surviving
    positions cut ``qses`` into consecutive segments. Every sample of a
    segment becomes 1.0 if that segment's mean is greater than the mean of
    the whole of ``qses``, and -1.0 otherwise.

    Args:
        qses: 1-D sequence of samples.
        indices: increasing split positions into ``qses``.
        th: minimum gap needed to keep a split position (default 15).

    Returns:
        Float array with the same length as ``qses`` containing only 1.0 and
        -1.0. When ``indices`` is empty, or every position is filtered out,
        the whole input is treated as one segment instead of raising
        IndexError.
    """
    qses = np.array(qses)

    # Gaps are measured between the original positions, before any removal.
    gaps = np.diff(np.hstack(([0], indices)))
    kept = np.delete(indices, np.where(gaps < th)[0])

    overall_mean = np.mean(qses)
    # np.split yields qses[:k0], qses[k0:k1], ..., qses[k_last:], i.e. the
    # same segments as slicing by hand; an empty segment has NaN mean -> -1.
    labelled = [
        np.full(len(segment), 1.0 if np.mean(segment) > overall_mean else -1.0)
        for segment in np.split(qses, kept)
    ]
    return np.concatenate(labelled)

def shorten(x, length):
    """Force ``x`` to exactly ``length`` items.

    Inputs longer than ``length`` are truncated with a plain slice, keeping
    their type and dtype. Inputs of ``length`` items or fewer are copied into
    a new zero-filled float array of size ``length``, so equal-length inputs
    also come back as float.
    """
    if len(x) <= length:
        padded = np.zeros((length))
        padded[:len(x)] = x
        return padded
    return x[:length]

def get_features(path, filename, single_recording=True):
    """Extract the (STEPS_COUNT x FEATURES_LENGTH) feature matrix for one recording.

    The file ``path + filename`` is read with ``scipy.io.loadmat`` and its
    ``'val'`` entry is flattened with ``reshape(shape[1])``; this presumably
    expects a 1 x N array (TODO confirm). Processing steps:
      1. Band-pass filter the signal (FIR, 3-45 Hz, 300 Hz sampling).
      2. Invert it when its negative extremes dominate.
      3. Detect R peaks with biosppy's Hamilton segmenter.
      4. Derive Q/S points (``extract_qs``), T/P candidates (``qs_signal``)
         and several per-beat series, packed into a zero-initialised array.

    Args:
        path: prefix concatenated directly with ``filename``.
        filename: recording name passed to ``loadmat``.
        single_recording: if True build exactly one window; otherwise one
            window per ``FEATURES_LENGTH + 1`` R-R intervals (at least one).

    Returns:
        Array of shape ``(num_samples, STEPS_COUNT, FEATURES_LENGTH)``, or
        ``False`` when 3 or fewer R peaks were detected.
    """
    mat = scipy.io.loadmat(path+filename)
    samplerate = 300.
    samples = mat['val']
    samples = samples.reshape((samples.shape[1]))

    order = int(0.3 * samplerate)
    samples, _, _ = filter_signal(samples, ftype='FIR', band='bandpass', order=order,
                                  frequency=[3, 45], sampling_rate=samplerate)

    # Polarity check: pair the 10 lowest samples with the 10 highest (by
    # position in each list) and flip the signal when the negative extreme
    # has the larger magnitude in most pairs.
    ordered = np.sort(samples, axis=None)
    lowest = ordered[:10]
    highest = ordered[-10:]
    negative = int(np.sum(np.abs(lowest) > np.abs(highest)))
    positive = len(lowest) - negative
    if positive < negative:
        samples = -1*samples

    rpeaks, = ecg.hamilton_segmenter(signal=samples, sampling_rate=samplerate)
    rpeaks = ecg.correct_rpeaks(signal=samples, rpeaks=rpeaks, sampling_rate=samplerate, tol=0.05)
    rpeaks = np.array(rpeaks)
    rpeaks = rpeaks.reshape((rpeaks.shape[0]*rpeaks.shape[1]))

    if len(rpeaks) <= 3:
        return False

    # Q/S search windows: 0.02 s before and 0.05 s after each R peak.
    before = 0.02 * samplerate
    after = 0.05 * samplerate
    q, s = extract_qs(signal=samples, rpeaks=rpeaks, before=before, after=after)

    num_peaks, t, p, t_width, p_width = qs_signal(samples, q, s)

    # Leading dummy entry so p / p_width have one value per S point.
    p = np.hstack(([0], p))
    p_width = np.hstack(([0], p_width))

    stds = peaksInQS(samples, q, s)

    symetry = rpeaks[:len(rpeaks)-1]-rpeaks[1:]

    ar_peaks = np.take(samples, rpeaks)
    s_peaks = np.take(samples, s)
    q_peaks = np.take(samples, q)
    t_peaks = np.take(samples, t)
    p_peaks = np.take(samples, p)

    to = hrto(np.diff(rpeaks))

    # Bring every per-beat series to exactly one value per R peak.
    # extract_qs may skip edge beats, so q/s-derived series can be shorter.
    length = len(rpeaks)
    num_peaks = shorten(num_peaks, length)
    q = shorten(q, length)
    s = shorten(s, length)
    p = shorten(p, length)
    t = shorten(t, length)
    p_width = shorten(p_width, length)
    t_width = shorten(t_width, length)
    symetry = shorten(symetry, length)
    ar_peaks = shorten(ar_peaks, length)
    s_peaks = shorten(s_peaks, length)
    q_peaks = shorten(q_peaks, length)
    t_peaks = shorten(t_peaks, length)
    p_peaks = shorten(p_peaks, length)
    to = shorten(to, length)
    # stds has len(s) - 1 entries. Without padding, row 6 below could fail to
    # broadcast (or silently repeat a single value) when beats were skipped.
    stds = shorten(stds, length)

    rminuss = ar_peaks - s_peaks
    rminusq = ar_peaks - q_peaks

    sq = (q - s)
    pr = (rpeaks - p)
    qt = (t - q)[:-1]/np.sqrt(np.diff(rpeaks))*100
    st = (t - s)
    pq = (q - p)

    ar_peaks_diff = np.diff(ar_peaks)
    symetry = np.diff(symetry)
    rpeaks = np.diff(rpeaks)          # R-R intervals in samples from here on
    ddif_r = shorten(np.diff(rpeaks), length)
    q = np.diff(q)
    s = np.diff(s)
    p = np.diff(p)
    t = np.diff(t)

    steps = STEPS_COUNT
    num_features = FEATURES_LENGTH
    window = num_features + 1

    if single_recording:
        num_samples = 1
    else:
        num_samples = max(int(len(rpeaks)/window), 1)

    features = np.zeros((num_samples, steps, num_features))

    for x in range(num_samples):
        wide = slice(x*window, (x+1)*window)
        narrow = slice(x*num_features, (x+1)*num_features)

        nrpeaks = rpeaks[wide]
        n_beats = len(nrpeaks) - 1
        n_all = len(rpeaks) - 1

        # (values, count): row k receives values[:count] in its first `count`
        # slots (clipped to the row width). count None means len(values) - 1.
        # The row order is the model's input layout.
        # NOTE(review): 'wide' rows use stride FEATURES_LENGTH + 1 and the
        # others stride FEATURES_LENGTH, so windows with x > 0 are not
        # aligned. Left unchanged to keep the feature values identical.
        rows = [
            (nrpeaks, n_beats),              # 0  R-R intervals
            (symetry[narrow], n_beats),      # 1
            (q[wide], None),                 # 2  diff of Q positions
            (s[wide], None),                 # 3  diff of S positions
            (p[wide], None),                 # 4  diff of P positions
            (t[wide], None),                 # 5  diff of T positions
            (stds[narrow], n_all),           # 6  peaksInQS output
            (sq[wide], None),                # 7  Q - S
            (pr[wide], None),                # 8  R - P
            (qt[wide], None),                # 9  scaled T - Q
            (st[wide], None),                # 10 T - S
            (num_peaks[narrow], n_all),      # 11
            (ar_peaks_diff[narrow], n_all),  # 12 diff of R amplitudes
            (s_peaks[narrow], n_all),        # 13 S amplitudes
            (p_peaks[narrow], n_all),        # 14 P amplitudes
            (t_peaks[narrow], n_all),        # 15 T amplitudes
            (ddif_r[narrow], n_all),         # 16 diff of R-R intervals
            (rminuss[narrow], n_all),        # 17 R - S amplitude
            (rminusq[narrow], n_all),        # 18 R - Q amplitude
            (pq[narrow], n_all),             # 19 Q - P
            (p_width[narrow], n_all),        # 20
            (t_width[narrow], n_all),        # 21
            (to[narrow], n_all),             # 22 hrto output
        ]
        for row, (values, count) in enumerate(rows):
            if count is None:
                count = len(values) - 1
            features[x:x+1, row:row+1, 0:count] = values[:count]

    # Return after filling every window (previously returned inside the loop).
    return features

from keras import backend as K
from keras.layers import Input, Dense, Permute, Reshape
from keras.models import Model
from keras.layers.recurrent import LSTM
import keras
from keras.optimizers import Adam

def build_model(full=True):
    """Build the stacked-LSTM network used for prediction.

    The ``(STEPS_COUNT, FEATURES_LENGTH)`` input is permuted to
    ``(FEATURES_LENGTH, STEPS_COUNT)``. It then passes through three LSTM
    layers of 8, 8 and 4 units, all with ``dropout_W=0.2, dropout_U=0.2``
    (Keras 1 argument names). Only the last layer collapses the sequence.

    Args:
        full: if True, append a 3-way softmax ``Dense`` layer as the output.
            Otherwise the model outputs the last LSTM's 4-unit state.

    Returns:
        An uncompiled Keras ``Model``.
    """
    inputs = Input(shape=(STEPS_COUNT, FEATURES_LENGTH,))
    hidden = Permute((2, 1))(inputs)
    for units, keep_sequence in ((8, True), (8, True), (4, False)):
        hidden = LSTM(units, return_sequences=keep_sequence,
                      dropout_W=0.2, dropout_U=0.2)(hidden)

    if not full:
        return Model(input=inputs, output=hidden)
    out = Dense(3, activation='softmax')(hidden)
    return Model(input=inputs, output=out)

def write_answer(filename, result, resultfile="answers.txt"):
    """Append one ``filename,result`` line to ``resultfile`` and return True.

    The file is opened in append mode (created if missing) and closed even if
    the write fails, which the previous open/close pair did not guarantee.
    """
    with open(resultfile, 'a') as fo:
        fo.write(str(filename) + "," + str(result) + "\n")

    return True

# Parse arguments. The input file is required: without it the script would
# try to load a file literally named 'None'.
parser = argparse.ArgumentParser(description='This is a script to train and test PhysioNet 2016 challenge data.')
parser.add_argument('-i','--inputfile', help='Input file name (full path) without .mat ending', required=True)
args = parser.parse_args()

filename = str(args.inputfile)
path = ''

# Load features. get_features returns False when too few R peaks were found;
# the all-zero matrix is then used as-is.
features = get_features(path, filename, single_recording=True)
X = np.zeros((1, STEPS_COUNT, FEATURES_LENGTH))
if features is not False:
    X[0] = features

X = np.nan_to_num(X)

# Scale features: params[k] = [[mean], [scale]] for feature row k, applied as
# (value - mean) / scale (the same arithmetic as StandardScaler.transform).
# Row 9 has (inf, inf), which yields NaN and therefore 0 after nan_to_num.
params = [[[ 184.47900276], [ 117.63358294]], [[-0.23659151], [ 49.31848708]], [[ 184.4784653], [ 117.63355812]], [[ 184.44842686], [ 117.61816366]], [[ 185.24659883], [ 120.67547697]], [[ 184.11243809], [ 130.65376032]], [[ 4.27486786], [ 7.54813831]], [[-12.06194646], [ 6.88674396]], [[ 79.90607298], [ 55.42125151]], [[ np.inf], [ np.inf]], [[ 57.5812324], [ 65.10223948]], [[ 1.61740054], [ 2.46969341]], [[ 2.31930575], [ 114.73319636]], [[-178.683285], [ 176.95879311]], [[ 21.84712646], [ 67.33395767]], [[ 55.99142731], [ 81.12838205]], [[ 0.23659151], [ 49.31848708]], [[ 637.31867845], [ 505.55613724]], [[ 467.94051406], [ 381.23756542]], [[ 73.55607726], [ 53.86135113]], [[ 13.74461056], [ 17.37796325]], [[ 17.36302666], [ 15.13310563]], [[-0.34581473], [ 0.47948587]]]

for x in range(STEPS_COUNT):
    mean, scale = params[x][0][0], params[x][1][0]
    X[0, x] = (X[0, x] - mean) / scale

X = np.nan_to_num(X)

# Build model and load the trained weights.
# NOTE(review): 'f1score' is not a Keras built-in metric name in the versions
# I know of, and predict() does not need compile -- confirm this is required.
model = build_model()
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['f1score'])
model.load_weights('weights.hdf5')

# Predict and write it down. argmax over the class axis is what
# np_utils.probas_to_classes did for multi-class output (that helper was
# removed in Keras 2). The model has 3 outputs, so '~' is never produced.
result = model.predict(X)
labels = ('A', 'N', 'O', '~')
result = labels[int(np.argmax(result, axis=-1)[0])]

write_answer(filename, result)
