# -*- coding: utf-8 -*-
"""
Created on Fri Apr 29 13:02:00 2016

@author: nicho_000
"""
import json
import numpy as np
import sys
import os
import wave
from scipy import signal
from sklearn import preprocessing
from sklearn import ensemble
from sklearn import decomposition
from sklearn import cluster
from scipy.optimize import curve_fit

#from matplotlib import use as useMat
#useMat('Agg')
import pylab

def graph_spectrogram(xmin, xmax, window, sound_info, frame_rate,
                      nperseg=2000, noverlap=1980):
    """Compute a Gaussian-window spectrogram of ``sound_info[xmin:xmax]``.

    Parameters
    ----------
    xmin, xmax : int
        Slice bounds into ``sound_info``.
    window : float
        Width parameter (standard deviation, in samples) of the Gaussian
        window built by ``scipy.signal.get_window(("gaussian", window), nperseg)``.
    sound_info : array_like
        1-D sequence of samples.
    frame_rate : float
        Sampling frequency in Hz, passed to ``signal.spectrogram`` as ``fs``.
    nperseg : int, optional
        Segment length in samples; also used as the window length.
        Defaults to 2000.
    noverlap : int, optional
        Number of samples shared by consecutive segments. Defaults to 1980,
        i.e. a hop of 20 samples with the default ``nperseg``.

    Returns
    -------
    spectrum, freqs, t
        Note the order differs from ``scipy.signal.spectrogram``, which
        returns ``freqs, t, spectrum``.
    """
    # scipy rejects a window array whose length differs from nperseg, so both
    # are derived from the single ``nperseg`` value.
    taper = signal.get_window(("gaussian", window), nperseg)
    freqs, t, spectrum = signal.spectrogram(sound_info[xmin:xmax], fs=frame_rate,
                                            nperseg=nperseg, noverlap=noverlap,
                                            window=taper, detrend=False)
    return spectrum, freqs, t


def get_wav_info(wav_file):
    """Read a 16-bit PCM WAV file and return its samples and sample rate.

    Parameters
    ----------
    wav_file : str or file-like
        Anything accepted by ``wave.open`` for reading.

    Returns
    -------
    sound_info : numpy.ndarray
        Every frame decoded as little-endian signed 16-bit integers
        (a writable copy). Multi-channel data stays interleaved.
    frame_rate : int
        Sampling frequency in Hz as reported by the file header.

    Raises
    ------
    ValueError
        If the file does not use 2-byte samples; such data cannot be decoded
        as int16 and previously produced garbage silently.
    """
    # ``with`` closes the file even if reading raises.
    with wave.open(wav_file, 'rb') as wav:
        sample_width = wav.getsampwidth()
        if sample_width != 2:
            raise ValueError(
                "expected 16-bit PCM, got {:d}-byte samples in {!r}".format(
                    sample_width, wav_file))
        frame_rate = wav.getframerate()
        frames = wav.readframes(wav.getnframes())
    # WAV PCM data is little-endian, so decode explicitly instead of relying on
    # host byte order. np.frombuffer replaces the deprecated np.fromstring (and
    # the removed 'Int16' alias); copy() keeps the array writable, as before.
    sound_info = np.frombuffer(frames, dtype='<i2').copy()
    return sound_info, frame_rate
    
def clusterit(data, dim=2, clusters=2):
    """Cluster the rows of ``data`` after projecting them with PCA.

    Parameters
    ----------
    data : array_like
        Samples as rows (a 2-D array, as required by sklearn's PCA).
    dim : int
        Number of principal components kept before clustering.
    clusters : int
        Number of k-means clusters to form.

    Returns
    -------
    numpy.ndarray
        One cluster label per row of ``data``.
    """
    # fit() returns the fitted estimator itself, so fit/transform can chain.
    reduced = decomposition.PCA(n_components=dim).fit(data).transform(data)
    # The fixed random_state keeps the 'random' initialisation reproducible.
    kmeans = cluster.KMeans(n_clusters=clusters, random_state=2013,
                            n_init=10, init='random')
    return kmeans.fit(reduced).predict(reduced)

def insertLabels(labelChain, newLabels, roundNum):
    """Stamp ``roundNum`` onto the unlabelled entries picked by a 0/1 split.

    ``newLabels`` holds one 0/1 label for each entry of ``labelChain`` that is
    currently 0, in order. The label chosen is the minority one: 0 when the
    mean of ``newLabels`` exceeds 0.5, otherwise 1 (so 1 wins ties). Every
    zero entry whose corresponding new label equals that choice is set to
    ``roundNum``.

    ``labelChain`` is modified in place and also returned.
    """
    target = 0 if np.mean(newLabels) > 0.5 else 1
    cursor = 0  # position in newLabels of the next unlabelled entry
    for pos, current in enumerate(labelChain):
        if current != 0:
            continue
        if newLabels[cursor] == target:
            labelChain[pos] = roundNum
        cursor += 1
    return labelChain
   
def numberOfSpans(l, index):
    """Count the maximal runs of consecutive elements of ``l`` equal to ``index``.

    For example, ``[1, 1, 0, 1]`` contains two spans of ``1``.
    """
    spans = 0
    inside = False
    for value in l:
        hit = value == index
        # A new span starts whenever a match follows a non-match.
        if hit and not inside:
            spans += 1
        inside = hit
    return spans

def sigmoid(x, x0, k, L):
    """Evaluate the logistic curve ``L / (1 + exp(-k * (x - x0)))``.

    ``x`` may be a scalar or a numpy array (evaluated elementwise).
    ``x0`` is the midpoint, ``k`` the steepness, ``L`` the upper asymptote.
    """
    exponent = -k * (x - x0)
    return L / (1 + np.exp(exponent))

def fitsigmoid(Y):
    """Fit ``sigmoid`` to the sorted values of ``Y``.

    The ascending values of ``Y`` are treated as a curve over the evenly
    spaced positions 0, 1/n, ..., (n-1)/n (n = len(Y)). The fit starts
    from ``x0 = k = L = 1.1`` and allows up to 1,000,000 evaluations.

    Returns
    -------
    numpy.ndarray
        Fitted parameters ``[x0, k, L]`` in ``sigmoid``'s argument order.
        scipy's ``curve_fit`` raises if the fit does not converge.
    """
    n = len(Y)
    ordered = np.sort(Y)
    positions = np.arange(n) * (1 / n)
    params, _covariance = curve_fit(sigmoid, positions, ordered,
                                    p0=[1.1, 1.1, 1.1], maxfev=1000000)
    return params

def identify_hs(data):
    """Sigmoid-fit selected columns of ``data`` and concatenate the parameters.

    For each of the columns 2, 5, 10, 20, 30, 40, 50 and 75 (so ``data`` needs
    at least 76 columns), ``fitsigmoid`` is applied to that column. Its three
    fitted parameters ``[x0, k, L]`` are appended in column order, giving a
    flat list of 24 Python floats.
    """
    columns = (2, 5, 10, 20, 30, 40, 50, 75)
    return [param
            for col in columns
            for param in fitsigmoid(data[:, col]).tolist()]

def _iterative_clusters(data, rounds=4):
    """Peel rows of ``data`` into labelled groups over ``rounds`` rounds.

    Every row starts with label 0. In each round ``clusterit`` splits the rows
    still labelled 0 into two clusters, and ``insertLabels`` stamps the round
    number (1..rounds) onto the minority cluster (label 1 on a tie). Rows never
    peeled off keep label 0.

    Returns a float array with one label per row of ``data``.
    """
    labels = np.zeros(len(data))
    for round_num in range(1, rounds + 1):
        remaining = data[labels == 0]
        labels = insertLabels(labels, clusterit(remaining), round_num)
    return labels


def get_features(infile):
    """Extract the filtered feature vector for one heart-sound recording.

    Steps:

    1. Load the samples with ``get_wav_info``. Compute two Gaussian-window
       spectrograms, with window parameters 15 and 75.
    2. Keep the first 1000 frequency bins and average groups of 10 adjacent
       bins, giving 100 bands per time slice. Then max-normalise each band
       across time.
    3. Cluster the time slices of each spectrogram iteratively into labels
       1..4. Drop the slices whose label (in the window-75 clustering) forms
       3 or fewer contiguous spans.
    4. Compute 69 summary statistics from the surviving slices. Keep only
       those whose ``filterpick`` entry is below 25.

    Parameters
    ----------
    infile : str
        Path to a WAV file readable by ``get_wav_info``.

    Returns
    -------
    list
        A one-element list holding a 1-D numpy array of the selected
        features, i.e. a single-row sample suitable for ``clf.predict``.
    """
    sound_info, frame_rate = get_wav_info(infile)
    xmin, xmax = 0, len(sound_info)

    spectrum15, _, t15 = graph_spectrogram(xmin, xmax, 15, sound_info, frame_rate)
    spectrum75, _, t75 = graph_spectrogram(xmin, xmax, 75, sound_info, frame_rate)

    # Rows become time slices; pool every 10 adjacent frequency bins into one band.
    datat15 = np.transpose(spectrum15[0:1000, :])
    datat75 = np.transpose(spectrum75[0:1000, :])
    datac15 = datat15.reshape(-1, 10).mean(1).reshape(len(t15), -1)
    datac75 = datat75.reshape(-1, 10).mean(1).reshape(len(t75), -1)
    # axis=0: scale each band by its maximum over time.
    datap15 = preprocessing.normalize(datac15, norm='max', axis=0)
    datap75 = preprocessing.normalize(datac75, norm='max', axis=0)

    clusters = _iterative_clusters(datap75)
    clusters15 = _iterative_clusters(datap15)

    num75 = numberOfSpans(clusters, 1)
    num15 = numberOfSpans(clusters15, 1)

    # Discard every slice whose window-75 label (1..4) appears in 3 or fewer
    # contiguous spans. Label-0 slices are always kept.
    keep = np.ones(len(datap75), dtype=bool)
    for label in range(1, 5):
        if numberOfSpans(clusters, label) <= 3:
            keep[clusters == label] = False

    dataclean75 = datap75[keep]
    dataclean15 = datap15[keep]
    datacleanc75 = datac75[keep]
    # axis=1: each time slice scaled by its own maximum band.
    datacleanp1 = preprocessing.normalize(datac75, norm='max', axis=1)[keep]

    # NOTE(review): several "mean" features below slice ROWS (time slices)
    # [25:50] / [50:100] instead of COLUMNS (bands) [:, 25:50] like their
    # "var" counterparts. This looks unintended, but some of these values
    # (positions 2, 7, 8, 14, 15) survive the filterpick selection. The model
    # is trained on X_features_filtered.txt, presumably produced by this same
    # code, so the slicing is kept as-is. Retrain before correcting it.
    features = [
        np.mean(dataclean75[:, 0:25]), np.mean(dataclean75[25:50]), np.mean(dataclean75[50:100]),
        np.var(dataclean75[:, 0:25]), np.var(dataclean75[:, 25:50]), np.var(dataclean75[:, 50:100]),
        np.mean(dataclean15[:, 0:25]), np.mean(dataclean15[25:50]), np.mean(dataclean15[50:100]),
        np.var(dataclean15[:, 0:25]), np.var(dataclean15[:, 25:50]), np.var(dataclean15[:, 50:100]),
        np.mean(datacleanp1[:, 0:25]),
        np.mean(datacleanc75[:, 0:25]), np.mean(datacleanc75[25:50]), np.mean(datacleanc75[50:100]),
        np.var(datacleanc75[:, 0:25]), np.var(datacleanc75[:, 25:50]), np.var(datacleanc75[:, 50:100]),
    ]
    features += identify_hs(dataclean75)
    features += identify_hs(dataclean15)
    # Fraction of time slices at which a label-1 span begins.
    features += [num75 / len(clusters), num15 / len(clusters15)]

    features = np.array(features).flatten()

    # One entry per feature (69 total); only features whose entry is below 25
    # are passed on. Presumably an importance ranking from training (confirm).
    filterpick = np.array([ 5, 51, 23,  1,  7, 46, 10, 15,  4,  1,  1, 30,  1,  1, 24, 11,  1,
                           16,  9, 35,  3, 47, 49, 21, 36, 53, 50, 57, 25, 29, 48, 52, 62, 41,
                           27, 38, 56, 33, 40,  8, 31, 14, 60, 28, 18, 44, 43, 13, 54, 26, 42,
                           59, 22,  6, 39, 55, 32, 61,  2, 19, 58, 12, 20, 45,  1,  1, 34, 37,
                           17])

    return [features[filterpick < 25]]


if __name__ == '__main__':

    # Set to True to batch-score every .wav in the validation folder below
    # instead of the single recording named on the command line.
    validate = False

    # Fail fast with a usage message before spending time on training.
    if not validate and len(sys.argv) < 2:
        sys.exit("usage: {:s} RECORD  (scores RECORD.wav)".format(sys.argv[0]))

    # Fit the random forest on the stored training set. ``with`` closes each
    # file, so no explicit close() is needed.
    with open('X_features_filtered.txt', 'r') as infile:
        X = json.load(infile)
    with open('Y_features.txt', 'r') as infile:
        y = json.load(infile)
    clf = ensemble.RandomForestRegressor(n_estimators=500, max_features=5,
                                         random_state=20082013)
    clf.fit(X, y)

    if not validate:
        the_file = sys.argv[1] + ".wav"
        features = get_features(the_file)
        clf_pred = clf.predict(features)
        # Shifting by +0.4 before taking the sign maps predictions above -0.4
        # to 1 and below -0.4 to -1 (exactly -0.4 gives 0).
        print("{:s},{:d}".format(sys.argv[1], int(np.sign(clf_pred[0] + 0.4))))
    else:
        validation_dir = 'C:\\Users\\nicho_000\\Desktop\\Heart_Sounds\\validation'
        answers_path = 'C:\\Users\\nicho_000\\Desktop\\hs\\answers.txt'
        with open(answers_path, 'w') as answers:
            # os.listdir order is arbitrary; sort for reproducible output.
            for name in sorted(os.listdir(validation_dir)):
                if not name.endswith(".wav"):
                    continue
                features = get_features(os.path.join(validation_dir, name))
                clf_pred = clf.predict(features)
                answers.write("{:s},{:d}\n".format(name, int(np.sign(clf_pred[0] + 0.4))))
                # Flush per line so partial results survive a failure mid-batch.
                answers.flush()
    
    