# -*- coding: utf-8 -*-
"""
Created on Wed Mar 17 15:02:24 2021

@author: Maurice
"""

from helper_code import *
import numpy as np, os, sys, joblib
from feature_extraction import compute_statistical_features
# Lead name sets. twelve_leads also fixes the row order that ECGRecording
# rearranges 12-lead recordings into.
twelve_leads = (
    'I', 'II', 'III', 'aVR', 'aVL', 'aVF',
    'V1', 'V2', 'V3', 'V4', 'V5', 'V6',
)
six_leads = twelve_leads[:6]
four_leads = ('I', 'II', 'III', 'V2')
three_leads = ('I', 'II', 'V2')
two_leads = twelve_leads[:2]

# Diagnosis code of the "normal" class (presumably a SNOMED CT code - confirm).
normal_class = '426783006'

# Pairs of codes treated as the same diagnosis: [alias, code used in DIAG_CLASSES].
equivalent_classes = [
    ['713427006', '59118001'],
    ['284470004', '63593006'],
    ['427172004', '17338001'],
]
# Same pairs as a lookup table: alias -> code used in DIAG_CLASSES.
EQUIV_CLASSES = {alias: target for alias, target in equivalent_classes}

# Diagnosis code -> position in the label vector built by extract_labels_OH.
DIAG_CLASSES = {
    code: position
    for position, code in enumerate((
        '270492004', '164889003', '164890007', '426627000', '713426002', '445118002',
        '39732003', '164909002', '251146004', '698252002', '10370003', '164947007',
        '111975006', '164917005', '47665007', '59118001', '427393009', '426177001',
        '426783006', '427084000', '63593006', '164934002', '59931005', '17338001',
    ))
}

def extract_header(header_file):
    '''
    read lead information and patient demographics from a header file

    Parameters
    ----------
    header_file : string
        header file path.

    Returns
    -------
    tuple
        (leads, sampling_rate, baselines, adcgains, age, sex).
        age is NaN when it is missing or negative; sex is 0 for female,
        1 for male and NaN for anything else.

    '''
    header = load_header(header_file)
    leads = get_leads(header)
    sampling_rate = get_frequency(header)
    baselines = get_baselines(header, leads)
    adcgains = get_adc_gains(header, leads)

    # Age: a missing or negative value is replaced by NaN.
    reported_age = get_age(header)
    age = float('nan') if (reported_age is None or reported_age<0) else reported_age

    # Sex: 0 for female, 1 for male, NaN for every other value.
    reported_sex = get_sex(header)
    is_female = reported_sex in ('Female', 'female', 'F', 'f')
    is_male = reported_sex in ('Male', 'male', 'M', 'm')
    if is_female:
        sex = 0
    else:
        sex = 1 if is_male else float('nan')

    return leads,sampling_rate,baselines,adcgains,age,sex

def extract_labels_OH(header_file):
    '''
    pseudo one hot encoding of labels (multiple labels can be true)

    Each label in the header is first mapped through EQUIV_CLASSES (so an
    alias code counts as its equivalent class) and then looked up in
    DIAG_CLASSES. Labels that are not in DIAG_CLASSES are ignored.

    Parameters
    ----------
    header_file : string
        header file path.

    Returns
    -------
    labels_OH : ndarray
        1 dim float array of length len(DIAG_CLASSES) with entries:
        1 for each assigned label, 0 otherwise.

    '''
    header = load_header(header_file)
    labels_OH = np.zeros(len(DIAG_CLASSES),float)
    for label in get_labels(header):
        # An alias resolves to its equivalent class; any other label maps to itself.
        canonical = EQUIV_CLASSES.get(label, label)
        idx = DIAG_CLASSES.get(canonical)
        # Index 0 is a valid position, so test against None rather than truthiness.
        if idx is not None:
            labels_OH[idx] = 1
    return labels_OH
        
        
    
        
        
class ECGRecording:
    '''
    one ecg recording from one person: preprocessed signal plus feature vector

    Attributes
    ----------
    header_file, record_file : string
        paths the recording was loaded from.
    leads : tuple
        lead names, one per row of recording (twelve_leads order for
        12-lead recordings).
    sampling_rate, age, sex
        values returned by extract_header.
    baselines, adcgains : ndarray
        per-lead baseline and ADC gain, kept in the same order as leads.
    recording : ndarray
        float signal, (raw - baseline) / adcgain per lead.
    ecg_features : ndarray
        starts as (age, sex); append_statis_features appends to it.
    labels_OH : ndarray or None
        label vector from extract_labels_OH, None until load_labels is called.
    lead_indices_stat : ndarray
        zeros until append_statis_features has been called.
    '''

    # Recordings with more samples than the threshold are cropped to the
    # fixed window [_CROP_START, _CROP_STOP), i.e. 20000 samples.
    _CROP_THRESHOLD = 22000
    _CROP_START = 1999
    _CROP_STOP = 21999

    def __init__(self,header_file,record_file,load_labels=False):
        '''
        representing one ecg recording from one person

        Parameters
        ----------
        header_file : string
            path location of header file.
        record_file : string
            path location of recording file.
        load_labels : bool, optional
            Should diagnostic labels be loaded (for training). The default is False.

        Returns
        -------
        ecg_recording.

        Raises
        ------
        ValueError
            if a 12-lead recording does not contain every lead of twelve_leads.

        '''
        self.header_file=header_file
        self.record_file=record_file

        (self.leads, self.sampling_rate, self.baselines, self.adcgains,
         self.age, self.sex) = extract_header(header_file)
        t_recording = load_recording(record_file)

        # Reorder leads into the twelve_leads order. Baselines and gains are
        # permuted together with the signal rows so that index i keeps
        # referring to the same lead in all three arrays.
        if len(self.leads) == 12:
            t_recording, self.baselines, self.adcgains = self._reorder_leads(
                t_recording, self.baselines, self.adcgains, self.leads, twelve_leads)
            self.leads = twelve_leads
        else:
            # Not fatal: the recording is kept in header order.
            print('RecordingError: Recording ',header_file, ' has less than 12 leads!')

        # Cut long recordings to a fixed window (kept short for timing).
        if t_recording.shape[1] > self._CROP_THRESHOLD:
            t_recording = t_recording[:, self._CROP_START:self._CROP_STOP]

        # astype(float) copies, so the scaling below never touches the loaded array.
        self.recording = t_recording.astype(float)

        # Preprocess with header info: subtract baseline, divide by ADC gain,
        # lead by lead (column vectors broadcast along the time axis).
        n_leads = len(self.leads)
        offsets = np.asarray(self.baselines, dtype=float)[:n_leads, np.newaxis]
        gains = np.asarray(self.adcgains, dtype=float)[:n_leads, np.newaxis]
        self.recording[:n_leads] = (self.recording[:n_leads] - offsets) / gains

        self.ecg_features = np.array((self.age,self.sex))
        self.labels_OH = None
        self.lead_indices_stat = np.zeros(len(self.leads))
        if load_labels:
            self.load_labels()

    @staticmethod
    def _reorder_leads(recording, baselines, adcgains, available_leads, target_leads):
        '''
        select/reorder per-lead data so that it follows target_leads

        Parameters
        ----------
        recording : ndarray
            2 dim array with one row per entry of available_leads.
        baselines, adcgains : array-like
            one value per entry of available_leads.
        available_leads : sequence of string
            lead names in the current row order.
        target_leads : sequence of string
            lead names in the wanted row order.

        Returns
        -------
        tuple of ndarray
            (recording, baselines, adcgains) in target_leads order.

        Raises
        ------
        ValueError
            if a target lead is not in available_leads.

        '''
        available = list(available_leads)
        indices = [available.index(lead) for lead in target_leads]
        return (recording[indices, :],
                np.asarray(baselines)[indices],
                np.asarray(adcgains)[indices])

    def append_statis_features(self):
        '''
        compute statistical features and append them to ecg_features
        '''
        statis_features,lead_indices = compute_statistical_features(self.recording,self.sampling_rate,self.leads)
        self.ecg_features = np.concatenate((self.ecg_features, statis_features))
        # Shift by 2: age and sex occupy the first two entries of ecg_features
        # (presumably lead_indices index into statis_features - confirm).
        self.lead_indices_stat = lead_indices+2

    def get_lead_indices(self):
        '''
        return lead_indices_stat (zeros until append_statis_features was called)
        '''
        return self.lead_indices_stat

    def load_labels(self):
        '''
        read the diagnostic labels from the header file into labels_OH
        '''
        self.labels_OH = extract_labels_OH(self.header_file)
        
def create_ecg_recording(header_file,record_file,load_labels=False,compute_statis_features=True):
    '''
    build an ECGRecording and optionally add its statistical features

    Parameters
    ----------
    header_file : string
        path location of header file.
    record_file : string
        path location of recording file.
    load_labels : bool, optional
        passed on to ECGRecording. The default is False.
    compute_statis_features : bool, optional
        call append_statis_features on the new object. The default is True.

    Returns
    -------
    ECGRecording

    '''
    ecg = ECGRecording(header_file, record_file,load_labels)
    if not compute_statis_features:
        return ecg
    ecg.append_statis_features()
    return ecg