#!/usr/bin/python3

"""
Classify an ECG record for CinC2017 challenge
"""
import os
import pickle
import sys
from collections import Counter

import numpy as np
import scipy.io as sio
from scipy.stats import mode
from sklearn.ensemble import ExtraTreesClassifier, RandomForestClassifier
from sklearn.preprocessing import minmax_scale, robust_scale
# Keras model
from keras.models import model_from_json
# EEMD
#from pyeemd import eemd

def read_ecg_record(record,folder='./'):
    """
    Reads an ECG record from CINC 2017 challenge, record format is
    Matlab 'mat' file and ASCII 'hea'. Only the '.mat' file is read here.

    Input:

    record --- record name, basename of a file without extension, e.g. 'A00010' (string)

    folder (optional) --- the folder where to search for the record, with or
    without a trailing '/'. Default: './'.

    Output:

    ecg_data --- 1D numpy array with the ECG record, dtype float64.

    Raises ValueError if the first row of the 'val' variable is not 1D.
    """
    # os.path.join handles folders with/without a trailing separator and does
    # not turn an empty folder into the filesystem root ('/A00010.mat').
    matname = os.path.join(folder, record + '.mat')

    ### Read ECG itself
    # Only the first row of the 'val' matrix is used (presumably the single
    # ECG lead -- TODO confirm for multi-lead data).
    # np.float was removed in NumPy 1.24; np.float64 is what it aliased.
    ecg_data = sio.loadmat(matname)['val'][0].astype(np.float64)
    if ecg_data.ndim != 1:
        raise ValueError('ECG record must be a 1D array.')

    return ecg_data

def segment_ecg(ecg_data,width=2600):
    """
    Split an ECG record into consecutive, non-overlapping segments of equal
    width. Trailing samples that do not fill a whole segment are dropped.

    Input:

    ecg_data --- 1D numpy array (ECG recording)

    width (optional) --- the width of each segment, in samples. Default: 2600.

    Output: list of 1D numpy arrays (slices of ecg_data), each `width`
    samples long, in record order.

    Raises ValueError if ecg_data is not 1D, if width is not positive, or if
    the record is shorter than one segment.
    """
    if ecg_data.ndim != 1:
        raise ValueError('ECG record must be 1D array.')
    if width < 1:
        # A zero/negative step would never advance through the record.
        raise ValueError('Segment width must be positive, got {}'.format(width))

    n = ecg_data.size
    # The former check (n < width-1) accepted records of width-1 samples,
    # which produce no segment at all.
    if n < width:
        raise ValueError('ECG record must be at least {} samples long'.format(str(width)))

    # Start offsets 0, width, 2*width, ... while a full segment still fits.
    return [ecg_data[start:start + width] for start in range(0, n - width + 1, width)]

def windows(data, size):
    """
    Yield (start, end) index pairs of consecutive, non-overlapping windows of
    `size` items over `data`. A trailing remainder shorter than `size` is
    skipped.

    Input:

    data --- any sequence supporting len().

    size --- window width (positive int).

    Output: generator of (start, end) tuples, end = start + size.

    Raises ValueError (on the first iteration) if size is not positive; the
    previous version looped forever in that case.
    """
    if size < 1:
        raise ValueError('Window size must be positive, got {}'.format(size))
    for start in range(0, len(data) - size + 1, size):
        yield start, start + size

def process_ecg(ecg_data,transform=None,reshape='image',iwidth=2600):
    """
    Process an ECG record to make it ready for classification.
    Namely, segment it (see segment_ecg), then reshape and normalize every
    segment.

    Input:

    ecg_data --- 1D numpy array of the ECG record, type float.

    transform --- type of a transform: None (leave original ECG, default),
    'af' (Auto-correlation), 'eemd' (EEMD). NOTE: no transform is currently
    implemented (the EEMD code is disabled), so this argument is ignored.

    reshape (optional) --- layout of each output segment:
    'image' -> robust-scaled array of shape (1, 1, iwidth, 1) (default);
    'table' -> robust-scaled array of shape (1, iwidth);
    '1d'    -> unscaled array of shape (iwidth,);
    any other value -> the segment as returned by segment_ecg.

    iwidth (optional) --- segment width in samples. Default: 2600.

    Output: list of numpy arrays, one per segment, shaped as described above.
    """
    segments = segment_ecg(ecg_data,width=iwidth)

    out_segments = []
    for segment in segments:
        out = segment
        if reshape in ('image', 'table'):
            # robust_scale expects 2D input; axis=1 scales each row, i.e. the
            # samples of this one segment. Keep the returned array: the old
            # copy=False call discarded it, and copy=False is only a hint
            # (a copy is made e.g. for int input, losing the normalization).
            # In-place scaling would also overwrite the caller's ecg_data,
            # because segments may be slices (views) of it.
            out = robust_scale(out.reshape((1, out.size)), axis=1)
            if reshape == 'image':
                out = out.reshape((1, 1, out.size, 1))
        elif reshape == '1d':
            out = out.reshape((out.size,))
        out_segments.append(out)

    return out_segments

def nn_model_clf(signal,groups,model):
    """
    Classify one processed ECG segment with a Keras-style model.

    Input:

    signal --- model input for a single segment.

    groups --- class labels, in the same order as the model's output columns.

    model --- object with a predict(x, batch_size=...) method.

    Output: the label from `groups` with the highest score in the first row
    of the prediction. The raw prediction is printed as a side effect.
    """
    labels = np.asarray(groups)
    scores = model.predict(signal, batch_size=1)
    print("Prediction:", scores)
    return labels[np.argmax(scores[0])]

def load_keras_model(model_load_fname="model.json",weights_load_fname="weights.hdf5"):
    """
    Load a Keras model from a JSON architecture file plus a weights file.

    Input:

    model_load_fname (optional) --- JSON file with the model architecture.
    Default: 'model.json'.

    weights_load_fname (optional) --- weights file passed to
    model.load_weights. Default: 'weights.hdf5'.

    Output: the Keras model with weights loaded.
    """
    # `with` closes the file even if reading fails.
    with open(model_load_fname,"r") as json_file:
        loaded_model_json = json_file.read()
    model = model_from_json(loaded_model_json)
    model.load_weights(weights_load_fname)
    print("Keras model is loaded")
    return model

def classify_ecg_nn(record_name, model, **kwargs):
    """
    Classify an ECG record with a Keras neural network.

    The record is split into fixed-width segments. Every segment is
    classified separately, and the most frequent label is the final answer.

    Input:

    record_name --- record name without extension, e.g. 'A00010' (string)

    model --- loaded Keras model (see load_keras_model). Its output columns
    are read as the classes 'A', 'N', 'O', '~' in that order.

    Keyword arguments (optional):
    folder --- folder with the record files. Default: './'.
    transform --- passed to process_ecg. Default: None.
    reshape --- passed to process_ecg. Default: 'image'.
    iwidth --- segment width in samples. Default: 2600.

    Output: the predicted class label (string). On a tie, the
    alphabetically smallest label wins, as with scipy.stats.mode.

    Raises ValueError if the record yields no segments.
    """
    folname = kwargs.get("folder", "./")
    trans = kwargs.get("transform", None)
    rsh = kwargs.get("reshape", 'image')
    # Previously `iw` was unbound when 'iwidth' was omitted; fall back to
    # process_ecg's own default.
    iw = kwargs.get("iwidth", 2600)

    ##### 1. Get the ECG signal from the record
    signal = read_ecg_record(record_name,folder=folname)

    ##### 2. Process the signal
    signal2 = process_ecg(signal,reshape=rsh,iwidth=iw,transform=trans)
    if not signal2:
        raise ValueError('Record {} yields no segments of width {}.'.format(record_name, iw))

    ##### 3. Call a classifier for each segment
    groups = ['A','N','O','~']
    cls = np.array([nn_model_clf(segment, groups, model) for segment in signal2], dtype=str)

    ## Majority vote as the final classification decision.
    # scipy.stats.mode no longer accepts string arrays in current SciPy, and
    # its return shape changed (keepdims), which broke `mode(cls)[0][0]`.
    # Iterating the sorted labels makes max() pick the smallest label on a
    # tie, matching mode's tie rule.
    counts = Counter(cls.tolist())
    pred = max(sorted(counts), key=counts.__getitem__)
    print(record_name,":",cls,"=>",pred)
    return pred

def classify_ecg_sk(record_name,clf_fname='clf.pickle',**kwargs):
    """
    Classify an ECG record with a pickled scikit-learn classifier.

    The record is split into fixed-width segments. Every segment is
    classified separately, and the most frequent label is the final answer.

    Input:

    record_name --- record name without extension, e.g. 'A00010' (string)

    clf_fname (optional) --- pickle file with a fitted classifier that has
    a predict() method. Default: 'clf.pickle'.

    Keyword arguments (optional):
    folder --- folder with the record files. Default: './'.
    transform --- passed to process_ecg. Default: None.
    reshape --- passed to process_ecg. Default: 'table'.
    iwidth --- segment width in samples. Default: 2600.

    Output: the predicted class label (string). On a tie, the
    alphabetically smallest label wins, as with scipy.stats.mode.

    Raises ValueError if the record yields no segments.
    """
    folname = kwargs.get("folder", "./")
    trans = kwargs.get("transform", None)
    # scikit-learn estimators take 2D (n_samples, n_features) input, which is
    # what reshape='table' produces. The former 'image' default yields 4D
    # arrays that predict() rejects.
    rsh = kwargs.get("reshape", 'table')
    # Previously `iw` was unbound when 'iwidth' was omitted.
    iw = kwargs.get("iwidth", 2600)

    signal = read_ecg_record(record_name,folder=folname)
    signal2 = process_ecg(signal,reshape=rsh,iwidth=iw,transform=trans)
    if not signal2:
        raise ValueError('Record {} yields no segments of width {}.'.format(record_name, iw))

    # SECURITY: pickle.load can execute arbitrary code stored in the file;
    # only load classifiers from a trusted source.
    with open(clf_fname,'rb') as f:
        clf = pickle.load(f)

    # dtype=str without a fixed width: np.empty(..., dtype=str) truncated
    # every label to a single character.
    cls = np.array([clf.predict(segment)[0] for segment in signal2], dtype=str)

    # Majority vote. scipy.stats.mode no longer works on string arrays in
    # current SciPy. Iterating the sorted labels keeps mode's smallest-label
    # tie rule.
    counts = Counter(cls.tolist())
    pred = max(sorted(counts), key=counts.__getitem__)
    print(record_name,":",cls,"=>",pred)
    return pred

def report_class(record,cls,fname='answers.txt'):
    """
    Append one answer line, 'record,class', to the answers file.

    Input:

    record --- record name (string).

    cls --- predicted class label (string).

    fname (optional) --- answers file, opened in append mode (created if
    missing). Default: 'answers.txt'.
    """
    with open(fname,'a') as answers:
        answers.write(','.join((record, cls)) + '\n')

def classify_ecg_records(clf='nn',records_file='validation/RECORDS',out_file='answers.txt',**kwargs):
    """
    Classify every record listed in a RECORDS file and write the answers.

    Input:

    clf (optional) --- 'nn' (Keras model from load_keras_model) or 'sk'
    (pickled scikit-learn classifier, see classify_ecg_sk). Default: 'nn'.

    records_file (optional) --- text file with one record name per line;
    blank lines are ignored. Default: 'validation/RECORDS'.

    out_file (optional) --- answers file. It is deleted first, then one
    'record,class' line is appended per record. Default: 'answers.txt'.

    **kwargs --- forwarded to classify_ecg_nn / classify_ecg_sk
    (folder, transform, reshape, iwidth).

    Raises ValueError for an unknown classifier.
    """
    # Validate before deleting out_file, so a typo in `clf` does not wipe the
    # previous answers.
    if clf not in ('nn', 'sk'):
        raise ValueError('Unknown classifier.')
    if os.path.exists(out_file):
        os.remove(out_file)
    model = load_keras_model() if clf == 'nn' else None
    with open(records_file,'r') as inf:
        for line in inf:
            # strip() also removes '\r' (CRLF files) and tabs.
            record = line.strip()
            if not record:
                # e.g. an empty last line: not a record name.
                continue
            if clf == 'nn':
                C = classify_ecg_nn(record,model,**kwargs)
            else:
                C = classify_ecg_sk(record,**kwargs)
            report_class(record,C,fname=out_file)


if __name__ == "__main__":
    import argparse

    # Command line (flags unchanged from the original hand-rolled parser):
    #   -w WIDTH   segment width in samples (default 300)
    #   -f FOLDER  folder with the records (default 'validation/')
    #   -c CLF     'nn' (Keras) or 'sk' (pickled scikit-learn) (default 'nn')
    #   -t TRANS   transform passed to process_ecg (default None)
    #   -r RECORD  classify a single record instead of validation/RECORDS
    parser = argparse.ArgumentParser(
        description="Classify ECG records for the CinC 2017 challenge.")
    parser.add_argument('-w', dest='iwidth', type=int, default=300,
                        help='segment width in samples (default: 300)')
    parser.add_argument('-f', dest='folder', default='validation/',
                        help="folder with the records (default: 'validation/')")
    parser.add_argument('-c', dest='clf', default='nn',
                        help="classifier: 'nn' or 'sk' (default: 'nn')")
    parser.add_argument('-t', dest='transform', default=None,
                        help='signal transform passed to process_ecg (default: none)')
    parser.add_argument('-r', dest='record', default=None,
                        help='classify only this record')
    # parse_known_args keeps the old behaviour of ignoring unrelated extra
    # arguments, while missing/invalid flag values now give a usage error
    # instead of an IndexError/ValueError traceback.
    args, _unknown = parser.parse_known_args()

    IW = args.iwidth
    FOL = args.folder
    CLF = args.clf
    TRNS = args.transform
    if CLF not in ('nn', 'sk'):
        raise ValueError('Unknown classifier.')
    # The Keras model is fed 4D 'image' tensors; scikit-learn needs 2D 'table' rows.
    RSH = 'image' if CLF == 'nn' else 'table'

    if args.record is not None:
        ## Classify single record '-r' argument is given
        record = args.record
        print("Classify a record:",record)
        if CLF == 'nn':
            cls = classify_ecg_nn(record,load_keras_model(),folder=FOL,
                                  reshape=RSH,iwidth=IW,transform=TRNS)
        else:
            cls = classify_ecg_sk(record,folder=FOL,reshape=RSH,iwidth=IW,transform=TRNS)
        report_class(record,cls,fname='answers.txt')
    else:
        ## Classify array of records from a file
        classify_ecg_records(clf=CLF,folder=FOL,iwidth=IW,reshape=RSH,transform=TRNS)
