#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Apr  7 19:52:07 2021

@author: chadyang
"""
#%%
from helper_code import *
from sklearn.impute import SimpleImputer
from pycaret.classification import *
import pandas as pd

#%% params
# Site folders (under DATAROOT) to load. Full list kept for reference:
# ['WFDB_CPSC2018', 'WFDB_CPSC2018_2', 'WFDB_Ga', 'WFDB_PTB', 'WFDB_PTBXL', 'WFDB_StPetersburg']
SITES = ['WFDB_CPSC2018_2']
# Alternative absolute data root used previously:
# '/HDD/HDD2/Projects/Challenges/Physionet/cinc/scripts/python-classifier-2021'
DATAROOT = '../../data'

# Number of CV folds.
# NOTE(review): not referenced by the visible training code; confirm intent.
FOLD = 5


#%%
# One directory per selected site, joined with '/'.
data_directory = ['/'.join((DATAROOT, site)) for site in SITES]
model_directory = './models'

# Output filename for each lead-subset model ("<lead count>_lead_model.sav").
(twelve_lead_model_filename,
 six_lead_model_filename,
 three_lead_model_filename,
 two_lead_model_filename) = (f'{count}_lead_model.sav' for count in (12, 6, 3, 2))


#%% get data paths
# Gather the header/recording file lists of every site directory into two
# flat, parallel lists.
header_files, recording_files = [], []
for site_dir in data_directory:
    site_headers, site_recordings = find_challenge_files(site_dir)
    header_files += site_headers
    recording_files += site_recordings
num_recordings = len(recording_files)
print(f'Num of recordings {num_recordings}')


#%% Extract classes from dataset.
print('Extracting classes...')
# Union of every label that appears in any header.
label_set = set()
for header_file in header_files:
    header = load_header(header_file)
    label_set.update(get_labels(header))
# When every label is an integer string, order numerically; otherwise order
# alphanumerically (key=None means plain string comparison).
sort_key = int if all(is_integer(code) for code in label_set) else None
classes = sorted(label_set, key=sort_key)
num_classes = len(classes)
print(f'Number of classes: {num_classes}')


#%%
def get_features(header, recording, leads):
    """Extract age, sex and per-lead RMS features from a single recording.

    Parameters
    ----------
    header
        Header as returned by ``load_header``; it is only passed through to
        the helper_code accessors (``get_age``, ``get_sex``, ``get_leads``,
        ``get_adcgains``, ``get_baselines``).
    recording
        2-D signal array indexed ``recording[lead, sample]``, with rows in
        the lead order reported by ``get_leads(header)``. It is not modified.
    leads
        Lead names to select, in the order the features are returned.

    Returns
    -------
    tuple
        ``(age, sex, rms)``:

        - ``age`` is the value from ``get_age``, or NaN when it is None.
        - ``sex`` is 0 (female), 1 (male) or NaN (anything else).
        - ``rms`` is a list holding one single-element list ``[rms_value]``
          per requested lead. ``parse_features`` concatenates these lists,
          so keep the nesting.

    Raises
    ------
    ValueError
        If a requested lead is not listed in the header.
    """
    # Extract age; a missing age becomes NaN so it can be imputed downstream.
    age = get_age(header)
    if age is None:
        age = float('nan')

    # Extract sex. Encode as 0 for female, 1 for male, and NaN for other.
    sex = get_sex(header)
    if sex in ('Female', 'female', 'F', 'f'):
        sex = 0
    elif sex in ('Male', 'male', 'M', 'm'):
        sex = 1
    else:
        sex = float('nan')

    # Reorder/reselect leads; .index raises ValueError for an unknown lead.
    available_leads = get_leads(header)
    indices = [available_leads.index(lead) for lead in leads]

    # Work on a float64 copy. The integer-array indexing already copies, but
    # the copy keeps the recording's dtype, and normalised values are written
    # back row by row. With an integer recording (presumably the raw .mat
    # 'val' array from load_recording — TODO confirm), those writes would
    # truncate (x - baseline) / gain, and x**2 below could overflow.
    signal = np.asarray(recording[indices, :], dtype=np.float64)

    # Pre-process: remove baseline and divide by the ADC gain, per lead.
    adc_gains = get_adcgains(header, leads)
    baselines = get_baselines(header, leads)
    for i in range(len(leads)):
        signal[i, :] = (signal[i, :] - baselines[i]) / adc_gains[i]

    # Root mean square of each lead, wrapped in a one-element list per lead.
    rms = [[np.sqrt(np.sum(row ** 2) / np.size(row))] for row in signal]
    return age, sex, rms


#%% Extract features and labels from dataset.
print('Extracting features and labels...')
# feas[k] is the feature list for recording k:
#   - one [rms] list per lead, in twelve_leads order;
#   - a trailing [age, sex] list.
# parse_features later flattens a chosen lead subset of these entries.
feas = []
# Multi-hot label matrix: labels[k, j] is True when recording k carries label
# classes[j]. A recording may carry several labels.
# The builtin `bool` is used because the `np.bool` alias was deprecated in
# NumPy 1.20 and removed in 1.24, where it raises AttributeError.
labels = np.zeros((num_recordings, num_classes), dtype=bool)
# Label -> column lookup, built once. This replaces a linear classes.index()
# per label.
class_columns = {label: j for j, label in enumerate(classes)}

for idx in range(num_recordings):
    print(f'\r\t{idx+1}/{num_recordings}...', end="\r")

    # Load header and recording.
    header = load_header(header_files[idx])
    recording = load_recording(recording_files[idx])

    # Age, sex and per-lead RMS over all twelve leads. Lead subsets are
    # selected afterwards by parse_features.
    age, sex, fea_rms = get_features(header, recording, twelve_leads)
    fea = fea_rms + [[age, sex]]  # length = len(twelve_leads) + 1
    feas.append(fea)

    current_labels = get_labels(header)
    for label in current_labels:
        j = class_columns.get(label)
        if j is not None:
            labels[idx, j] = True


#%%
def parse_features(feas, leads):
    """Flatten per-recording feature lists for a subset of leads.

    Each entry of ``feas`` is a list whose leading items are per-lead feature
    lists, positioned as in ``twelve_leads``, and whose last item is the
    demographic ``[age, sex]`` list.

    Returns a numpy array with one row per entry of ``feas``. Each row holds
    the features of the requested ``leads`` (in ``leads`` order), followed by
    the values of that last item. Unknown lead names raise ValueError; an
    empty ``feas`` raises IndexError.
    """
    picks = [twelve_leads.index(lead) for lead in leads]
    # The last entry of every record holds the demographics.
    picks.append(len(feas[0]) - 1)

    rows = [
        [value for pos in picks for value in record[pos]]
        for record in feas
    ]
    return np.asarray(rows)
    

#%% train model
#%% Train 12-lead ECG model.
print('Training 12-lead ECG model...')
leads = twelve_leads
# NOTE(review): `filename` is not used by the visible code; presumably meant
# for saving the trained model.
filename = os.path.join(model_directory, twelve_lead_model_filename)

# Fill missing values (e.g. NaN age/sex) column-wise using SimpleImputer's
# default strategy. fit_transform is equivalent to fit(...).transform(...).
# NOTE(review): the imputer is fit on every row, including rows PyCaret's
# setup() may later hold out for evaluation. Consider imputing inside the
# PyCaret pipeline instead to avoid leaking holdout statistics.
imputer = SimpleImputer()
features = imputer.fit_transform(parse_features(feas, leads))


#%%
# Build a binary one-vs-rest dataset for a single column of `labels`.
# To cover every class, loop: for class_idx in range(labels.shape[1]): ...
class_idx = 3
label_class = labels[:, class_idx].astype('int')
# Keep the target as its own integer column. Stacking it onto the float
# feature matrix with np.hstack would upcast the 0/1 labels to float.
# Column names stay 0..n_features-1 plus 'target'.
data = pd.DataFrame(features, columns=list(range(features.shape[1])))
data['target'] = label_class
num_negative = (label_class == 0).sum()
num_positive = (label_class == 1).sum()
print(f'label: {class_idx} | 0: {num_negative} | 1: {num_positive}')

#%%
# PyCaret (2.x API) experiment: session_id seeds the experiment's randomness,
# and silent=True skips the interactive data-type confirmation prompt.
exp_clf = setup(data=data, target='target', session_id=class_idx, silent=True)
lgbm = create_model('lightgbm')
# NOTE(review): FOLD from the params cell is not passed here, so PyCaret's
# fold setting from setup() applies. Use tune_model(lgbm, fold=FOLD) if intended.
tuned_lgbm = tune_model(lgbm)

#%% plot model
plot_model(tuned_lgbm, plot='auc')



#%% Train 2-lead ECG model.
# Select the 2-lead feature subset (plus age/sex) from the extracted features.
# NOTE(review): unlike the 12-lead cell, nothing here imputes missing values
# or trains a model, and `filename` is unused. This looks like work in
# progress; `features` may still contain NaNs (e.g. missing age/sex).
print('Training 2-lead ECG model...')
leads = two_leads
filename = os.path.join(model_directory, two_lead_model_filename)
features = parse_features(feas, leads)



