#!/usr/bin/env python

# Edit this script to add your team's training code.
# Some functions are *required*, but you can edit most parts of required functions, remove non-required functions, and add your own function.

from helper_code import *
import numpy as np, os, sys, joblib
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestClassifier
from scipy import signal, stats
# Fixed model file names, one per lead count.
# NOTE(review): nothing in the visible part of this file references these
# constants; get_model_filename() derives file names from the lead labels
# instead. Confirm whether another module still imports them.
twelve_lead_model_filename = '12_lead_model.sav'
six_lead_model_filename = '6_lead_model.sav'
three_lead_model_filename = '3_lead_model.sav'
two_lead_model_filename = '2_lead_model.sav'
import torch 
import team_classifier_code
from ecg_recording import DIAG_CLASSES, create_ecg_recording, extract_header
from ecg_processing import process_records
from sklearn.preprocessing import MinMaxScaler
from feature_extraction import compute_statistical_features
from sklearn.multiclass import OneVsRestClassifier

from train_model_Deep import train_deep_model
from test_model_deep import load_model_deep, test_single_example

# Pipeline switch: True routes training, loading and inference through the
# deep-model helpers; False uses the feature-based random-forest code below.
USE_DEEP_MODEL = True

# Challenge lead sets. The Challenge does not require these variables, but
# training_code() in this file iterates over lead_sets and indexes twelve_leads.
twelve_leads = (
    'I', 'II', 'III', 'aVR', 'aVL', 'aVF',
    'V1', 'V2', 'V3', 'V4', 'V5', 'V6',
)
six_leads = twelve_leads[:6]
four_leads = ('I', 'II', 'III', 'V2')
three_leads = ('I', 'II', 'V2')
two_leads = ('I', 'II')
lead_sets = (
    twelve_leads,
    six_leads,
    four_leads,
    three_leads,
    two_leads,
)

# Fold number used to build the name of the saved feature file that
# training_code() tries to load.
fold = 1
################################################################################
#
# Training function
#
################################################################################

# Train your model. This function is *required*. Do *not* change the arguments of this function.
def training_code(data_directory, model_directory):
    """Train the Challenge model(s) and write them to ``model_directory``.

    With ``USE_DEEP_MODEL`` set, training is delegated to ``train_deep_model``.
    Otherwise one one-vs-rest random forest is trained per entry of
    ``lead_sets`` on min-max scaled features, and each model is saved together
    with its imputer and a scaler restricted to that lead set's columns.

    Args:
        data_directory: Folder containing the Challenge header/recording files.
        model_directory: Folder that receives the trained models. In the
            random-forest path it is created (including parents) if missing.

    Raises:
        Exception: If no recording files are found (random-forest path only).
    """
    if USE_DEEP_MODEL:
        print('training deep model...')
        parameters = {'model_name':'deep_model.pkl','epochs':40,'alpha':0.0,'random_cuts':6}
        train_deep_model(data_directory,model_directory,parameters)
        return

    # Set to False to always re-extract the features from the recordings.
    load_saved_features = True

    # Find header and recording files.
    print('Finding header and recording files...')
    _, recording_files = find_challenge_files(data_directory)
    if not recording_files:
        raise Exception('No data was provided.')

    # Create a folder for the model if it does not already exist. makedirs
    # also handles nested paths and an already existing folder.
    os.makedirs(model_directory, exist_ok=True)

    # Extract classes from dataset.
    print('Extracting classes...')
    classes = DIAG_CLASSES

    # Extract features and labels from dataset.
    print('Extracting features and labels...')
    data = None
    if load_saved_features:
        feature_file = "feature_extraction_result/Train/features_fold"+str(fold)+"_train.npz"
        try:
            data, labels, indices = load_features_from_file(feature_file)
            scaler = MinMaxScaler()
            data = scaler.fit_transform(data)
        except Exception as err:
            # Best effort: any failure to read or scale the cached features
            # falls back to a fresh extraction. KeyboardInterrupt and
            # SystemExit are deliberately no longer swallowed.
            print("WARNING: Features could not be loaded and are extracted newly ({}: {})".format(type(err).__name__, err))
            data = None
    if data is None:
        data, labels, scaler, indices = process_recordings(data_directory)

    # One scaler per lead set, built from the statistics of the full scaler.
    # ``indices[k]`` is used as the first feature column of lead k in
    # ``twelve_leads`` order; V2 is lead 7.
    lead_set_scalers = {
        twelve_leads: scaler,
        six_leads: _subset_scaler(scaler, np.r_[0:indices[6]]),
        four_leads: _subset_scaler(scaler, np.r_[0:indices[3], indices[7]:indices[8]]),
        three_leads: _subset_scaler(scaler, np.r_[0:indices[2], indices[7]:indices[8]]),
        two_leads: _subset_scaler(scaler, np.r_[0:indices[2]]),
    }

    # Define parameters for random forest classifier.
    n_estimators = 100     # Number of trees in the forest.
    max_leaf_nodes = 2000  # Maximum number of leaf nodes in each tree.
    random_state = 123     # Random state; set for reproducibility.

    # Train a model for each lead set.
    for leads in lead_sets:
        print('Training model for {}-lead set: {}...'.format(len(leads), ', '.join(leads)))

        # Extract the features for the model: columns 0 and 1 are always kept
        # (run_model puts age and sex there), followed by each lead's slice.
        lead_indices = [twelve_leads.index(lead) for lead in leads]
        feature_indices = [0, 1]
        for idx in lead_indices:
            feature_indices.extend(range(indices[idx], indices[idx + 1]))
        features = data[:, feature_indices]

        # Train the model.
        imputer = SimpleImputer().fit(features)
        features = imputer.transform(features)

        # Random forest classifier, one binary forest per class.
        forest = RandomForestClassifier(
            n_estimators=n_estimators,
            max_leaf_nodes=max_leaf_nodes,
            random_state=random_state,
        )
        classifier = OneVsRestClassifier(forest).fit(features, labels.astype(bool))

        # Choose the scaler that matches this lead set. A lead set without a
        # predefined scaler gets one built from the exact columns the
        # classifier was trained on, instead of reusing a stale scaler.
        lead_scaler = lead_set_scalers.get(leads)
        if lead_scaler is None:
            print("Warning: Unknown scaler")
            lead_scaler = _subset_scaler(scaler, feature_indices)

        # Save model and scaler.
        save_model(model_directory, leads, classes, imputer, classifier, lead_scaler)


def _subset_scaler(full_scaler, feature_idx):
    """Return a MinMaxScaler applying ``full_scaler``'s scaling to the columns ``feature_idx``."""
    sub_scaler = MinMaxScaler()
    sub_scaler.min_ = full_scaler.min_[feature_idx]
    sub_scaler.scale_ = full_scaler.scale_[feature_idx]
    return sub_scaler



################################################################################
#
# File I/O functions
#
################################################################################

# Save a trained model. This function is not required. You can change or remove it.
def save_model(model_directory, leads, classes, imputer, classifier, scaler):
    """Dump one trained model bundle into ``model_directory``.

    The bundle is a dict with the keys 'leads', 'classes', 'imputer',
    'classifier' and 'scaler'; the file name comes from get_model_filename().
    """
    target_path = os.path.join(model_directory, get_model_filename(leads))
    bundle = dict(
        leads=leads,
        classes=classes,
        imputer=imputer,
        classifier=classifier,
        scaler=scaler,
    )
    joblib.dump(bundle, target_path, protocol=0)

# Load a trained model. This function is *required*. Do *not* change the arguments of this function.
def load_model(model_directory, leads):
    """Load the trained model for ``leads`` from ``model_directory``.

    Returns whatever load_model_deep() yields when USE_DEEP_MODEL is set,
    otherwise the object stored by joblib under get_model_filename(leads).
    """
    if not USE_DEEP_MODEL:
        model_path = os.path.join(model_directory, get_model_filename(leads))
        return joblib.load(model_path)
    return load_model_deep(model_directory, leads)

# Define the filename(s) for the trained models. This function is not required. You can change or remove it.
def get_model_filename(leads):
    """Return the model file name for ``leads``: 'model_' + sorted leads joined by '-' + '.sav'."""
    lead_tag = '-'.join(sort_leads(leads))
    return ''.join(('model_', lead_tag, '.sav'))


################################################################################
#
# Running trained model functions
#
################################################################################

# Generic function for running a trained model.
def run_model(model, header, recording):
    """Run a trained model on a single recording.

    Args:
        model: Object returned by load_model(). In the random-forest path it
            is a dict with the keys 'classes', 'imputer', 'classifier' and
            'scaler'.
        header: Header of the recording, as consumed by the helper functions.
        recording: Signal data passed to compute_statistical_features().

    Returns:
        The result of test_single_example() when USE_DEEP_MODEL is set,
        otherwise the tuple ``(classes, labels, probabilities)`` where
        ``labels`` is an integer array and ``probabilities`` a flat float32
        array.
    """
    if USE_DEEP_MODEL:
        return test_single_example(model, header, recording)

    classes = model['classes']
    imputer = model['imputer']
    classifier = model['classifier']
    scaler = model['scaler']

    # Extract information from header.
    lead_names = get_leads(header)
    sampling_rate = get_frequency(header)

    # Extract age; missing or negative ages become NaN and are imputed later.
    age = get_age(header)
    if age is None or age < 0:
        age = float('nan')

    # Extract sex. Encode as 0 for female, 1 for male, and NaN for other.
    sex = get_sex(header)
    if sex in ('Female', 'female', 'F', 'f'):
        sex = 0
    elif sex in ('Male', 'male', 'M', 'm'):
        sex = 1
    else:
        sex = float('nan')

    # Compute the features; age and sex occupy the first two columns.
    data_pre, _ = compute_statistical_features(recording, sampling_rate, lead_names, resampling_freq=250)
    data = np.concatenate((np.array((age, sex)), data_pre))

    # Scale first, then impute missing data (same order as during training).
    features = data.reshape(1, -1)
    features = scaler.transform(features)
    features = imputer.transform(features)

    # Predict labels and probabilities. The builtin int replaces the np.int
    # alias, which no longer exists in current NumPy releases.
    labels = classifier.predict(features)
    labels = np.asarray(labels, dtype=int)[0]
    probabilities = classifier.predict_proba(features)
    probabilities = np.asarray(probabilities, dtype=np.float32).flatten()

    return classes, labels, probabilities

################################################################################
#
# Other functions
#
################################################################################

# Extract features from the header and recording.
def get_features(header, recording, leads):
    """Return age, sex and the per-lead RMS of a pre-processed recording.

    The recording is restricted and reordered to ``leads``, converted with the
    header's baselines and ADC gains, filtered with 3rd-order Butterworth
    filters (30 Hz low-pass, 0.5 Hz high-pass) and z-scored per lead before
    the RMS is taken.

    Args:
        header: Header of the recording, as consumed by the helper functions.
        recording: 2-D array with one row per lead listed in the header.
        leads: Lead names to extract, in the desired output order.

    Returns:
        Tuple ``(age, sex, rms)``: ``age`` is NaN when the header has none,
        ``sex`` is 0 for female, 1 for male and NaN otherwise, and ``rms`` is
        a float32 array with one value per entry of ``leads``.
    """
    # Extract age.
    age = get_age(header)
    if age is None:
        age = float('nan')

    # Extract sex. Encode as 0 for female, 1 for male, and NaN for other.
    sex = get_sex(header)
    if sex in ('Female', 'female', 'F', 'f'):
        sex = 0
    elif sex in ('Male', 'male', 'M', 'm'):
        sex = 1
    else:
        sex = float('nan')

    # Reorder/reselect leads in recordings. Fancy indexing copies the rows,
    # so the caller's array is left untouched by the in-place steps below.
    available_leads = get_leads(header)
    lead_rows = [available_leads.index(lead) for lead in leads]
    recording = recording[lead_rows, :]
    if not np.issubdtype(recording.dtype, np.floating):
        # Integer samples would be truncated by the in-place float assignments.
        recording = recording.astype(float)

    # Get sampling rate.
    frequency = get_frequency(header)

    # Convert raw samples with the header's baselines and ADC gains.
    # NOTE(review): this uses get_adc_gains, the helper name run_model() also
    # uses; the previous get_adcgains spelling was inconsistent with it.
    adc_gains = get_adc_gains(header, leads)
    baselines = get_baselines(header, leads)
    num_leads = len(leads)
    for i in range(num_leads):
        recording[i, :] = (recording[i, :] - baselines[i]) / adc_gains[i]

    # Filter and normalize each lead. The filters depend only on the sampling
    # rate, so they are designed once instead of once per lead.
    sos_low = signal.butter(3, 30, 'lowpass', False, output='sos', fs=frequency)
    sos_high = signal.butter(3, 0.5, 'highpass', False, output='sos', fs=frequency)
    for i in range(num_leads):
        recording[i, :] = signal.sosfilt(sos_low, recording[i, :])
        recording[i, :] = signal.sosfilt(sos_high, recording[i, :])
        # Normalize to zero mean and unit standard deviation.
        recording[i, :] = stats.zscore(recording[i, :])

    # TODO: HRV features (mean IBI, SD, RMSSD) and PQ/PT/ST intervals.
    # The previous draft resampled to 250 Hz and called
    # `ecg.christov_segmenter`, but no `ecg` name appears to be imported in
    # this file, so the call failed, and its results were never returned.
    # Reinstate once an R-peak detector is available.

    # Compute the root mean square of each ECG lead signal.
    # NOTE(review): each lead is z-scored above, so its RMS is 1 by
    # construction (NaN for a constant lead). Confirm whether the RMS was
    # meant to be taken before normalization.
    rms = np.zeros(num_leads, dtype=np.float32)
    for i in range(num_leads):
        x = recording[i, :]
        rms[i] = np.sqrt(np.sum(x**2) / np.size(x))

    return age, sex, rms



def process_recordings(path):
    """Extract, clean and min-max scale the features of all records in ``path``.

    Feature entries equal to -1 are treated as missing. Records in which at
    least half of the features are missing are dropped; the remaining missing
    entries are replaced by the per-feature median of the observed values.
    The cleaned (unscaled) features are also written to ``features.npz``
    inside ``path``.

    Args:
        path: Data folder handed to process_records().

    Returns:
        Tuple ``(X_train_scaled, Y_train, scaler, indices)``: the scaled
        feature matrix, the label matrix, the fitted MinMaxScaler and the
        value of ``get_lead_indices()`` of the first record.

    Raises:
        ValueError: If process_records() returns no records.
    """
    records = process_records(path, multiproc=True)
    if len(records) == 0:
        raise ValueError('No records could be processed from {}'.format(path))
    indices = records[0].get_lead_indices()

    feature_vectors = np.array([record.ecg_features for record in records])
    label_vectors = np.array([record.labels_OH for record in records])

    # Drop records with too many missing (-1) features: keep a row only if
    # fewer than half of its entries are missing.
    missing = feature_vectors == -1
    keep_rows = np.count_nonzero(missing, axis=1) < feature_vectors.shape[1] / 2
    feature_vectors = feature_vectors[keep_rows]
    label_vectors = label_vectors[keep_rows]
    missing = missing[keep_rows]

    # Replace missing entries by the median of the observed values only, so
    # the -1 sentinel does not bias the median. A feature without any
    # observed value keeps the sentinel.
    observed = np.where(missing, np.nan, feature_vectors)
    feature_median = np.full(feature_vectors.shape[1], -1.0)
    has_observed = ~np.isnan(observed).all(axis=0)
    if has_observed.any():
        feature_median[has_observed] = np.nanmedian(observed[:, has_observed], axis=0)
    feature_cleaned = np.where(missing, feature_median, feature_vectors)

    # os.path.join puts the file inside ``path`` whether or not ``path`` ends
    # with a separator.
    np.savez(os.path.join(path, 'features.npz'), data=feature_cleaned, labels=label_vectors, indices=indices)

    # Scale data for training.
    scaler = MinMaxScaler()
    X_train_scaled = scaler.fit_transform(feature_cleaned)
    Y_train = label_vectors

    return X_train_scaled, Y_train, scaler, indices
    
def load_features_from_file(path):
    """Load a saved feature set from an ``.npz`` archive.

    Args:
        path: Archive containing the arrays 'data', 'labels' and 'indices'.

    Returns:
        Tuple ``(features, labels, indices)`` with the three stored arrays.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If one of the expected arrays is missing.
    """
    # The archive keeps a file handle open until closed; the arrays are read
    # into memory on access, so they stay valid after the with-block.
    with np.load(path) as archive:
        features = archive['data']
        labels = archive['labels']
        indices = archive['indices']
    return features, labels, indices