import os
import re
import pickle
import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import (confusion_matrix, f1_score, roc_auc_score,
                             roc_curve)

from helper_code import *
import itertools
import neurokit2 as nk
import warnings
from neurokit2.misc import NeuroKitWarning

from src.setup import equivalent_classes, replace_equivalent_classes, dx_scored
################################################################################
#
# Other functions
#
################################################################################

# Maps the alphabetic prefix of a sample id to the dataset it comes from.
# NOTE: Chapman_Shaoxing and Ningbo share the "JS" prefix. A dict keeps only one
# value per key, so "JS" resolves to "Ningbo" here (exactly what the previous
# literal with two "JS" entries evaluated to). Code that has to tell the two
# datasets apart must also look at the numeric part of the sample id.
prefix_to_dataset = {
    "A": "CPSC",
    "Q": "CPSC_Extra",
    "I": "StPetersburg",
    "S": "PTB",
    "HR": "PTB_XL",
    "E": "Georgia",
    "JS": "Ningbo",
}


def get_target_classes_mask_by_sample_id(classes, sample_id):
    """
    Build a boolean mask over `classes` for a single sample.

    An entry is True when the class at that position is one of the target
    classes resolved for the sample, and False otherwise.

    Args:
        classes: classes to predict (all_targeted_classes)
        sample_id: ex) "E2909"

    Return:
        boolean array, ex) [True, False, False, True, ...], shape: (len(classes),)
    """
    sample_targets = get_target_classes_by_sample_id(sample_id)
    class_mask = np.isin(classes, sample_targets)
    return class_mask


def get_target_classes_by_sample_id(sample_id):
    """
    Return the target classes of the dataset a sample belongs to.

    The dataset is resolved from the alphabetic prefix of the sample id via
    prefix_to_dataset; a prefix that is not in the mapping resolves to
    "UNKNOWN". The "JS" prefix is shared by two datasets, which are told apart
    by the number that follows the prefix: numbers up to 10646, except 3074,
    are "Chapman_Shaoxing", every other number is "Ningbo". A "JS" id without
    any digits after the prefix is treated as "UNKNOWN".

    Args:
        sample_id:
            ex) "E0923.mat", "JS0124"

    Return:
        the result of get_target_classes_per_dataset for the resolved dataset
    """
    sample_id_prefix = get_sample_id_prefix(sample_id)

    if sample_id_prefix != "JS":
        dataset = prefix_to_dataset.get(sample_id_prefix, "UNKNOWN")
    else:
        # Read only the digits right after the prefix, so ids that carry a
        # file extension (ex. "JS0124.mat") are parsed instead of crashing int().
        number_match = re.match(r"JS(\d+)", sample_id)
        if number_match is None:
            dataset = "UNKNOWN"
        else:
            sample_number = int(number_match.group(1))
            if sample_number <= 10646 and sample_number != 3074:
                dataset = "Chapman_Shaoxing"
            else:
                dataset = "Ningbo"

    return get_target_classes_per_dataset(dataset)


def get_target_classes_per_dataset(dataset):
    """
    List the SNOMED CT codes that are targets for one dataset.

    Reads the module-level dx_scored table (imported from src.setup). For
    "UNKNOWN" every code of the table is returned. Otherwise only the rows
    whose value in the dataset's column is non-zero are kept, equivalent
    classes are merged with replace_equivalent_classes, and duplicates are
    dropped.

    Args:
        dataset: one of ['CPSC', 'CPSC_Extra', 'StPetersburg', 'PTB', 'PTB_XL',
                         'Georgia', 'Chapman_Shaoxing', 'Ningbo'] or "UNKNOWN"
    Return:
        sorted list of SNOMED CT codes
    """
    if dataset == "UNKNOWN":
        return sorted(dx_scored["SNOMEDCTCode"].values)

    # Rows with a non-zero entry in this dataset's column.
    present_rows = dx_scored[dataset].astype('int32') != 0
    dataset_codes = dx_scored.loc[present_rows, "SNOMEDCTCode"].values
    merged_codes = replace_equivalent_classes(dataset_codes, equivalent_classes)
    return sorted(set(merged_codes))


def get_sample_id_prefix(sample_id):
    """
    Extract the run of capital letters a sample id starts with.

    Args:
        sample_id: ex) E5050.mat or JS01234

    Return:
        prefix of sample id: ex) E, JS
        "UNKNOWN" when the id does not start with a capital letter
    """
    leading_capitals = re.match('[A-Z]+', sample_id)
    if leading_capitals is None:
        return "UNKNOWN"
    return leading_capitals.group()


def identify_dataset(header, recording_mean_per_lead, thres=0.5):
    """
    Guess which of three dataset groups ("American", "CPSC", "Georgia") a
    recording comes from.

    A recording whose header frequency is not 500 is reported as "UNKNOWN".
    At 500, the recording is "A" when every per-lead mean lies within
    [-|thres|, |thres|], and "E" otherwise.

    Args:
        header: header, used to read the frequency
        recording_mean_per_lead: mean value of each lead
        thres: bound on the per-lead mean separating CPSC from Georgia

    Return:
        "UNKNOWN": American
        "A": CPSC
        "E": Georgia
    """
    freq = get_frequency(header)
    lead_means = np.array(recording_mean_per_lead)

    if freq != 500:
        return "UNKNOWN"  # American

    bound = abs(thres)
    inside_band = (lead_means >= -1 * bound) & (lead_means <= bound)
    if all(inside_band):
        return "A"  # CPSC
    return "E"  # Georgia

def get_recording_mean_per_lead(recording):
    """
    Compute the average value of every lead of a recording.

    Args:
        recording: [[lead1], [lead2], ...] before preprocess; every lead must
                   provide .mean() whose result provides .item()

    Returns:
        [mean value of lead1, mean value of lead2, ...]

    """
    return [lead.mean().item() for lead in recording]


# Extract features from the header and recording.
def get_features(header, recording, leads, preprocess_configs=None):
    """
    Assemble the 20-element float32 feature vector of one recording.

    Positions 0-3 hold the result of get_age_sex(header); positions 4-19 hold
    the result of get_rr_mean_std(recording, ...).

    Args:
        header: header of the recording
        recording: recording passed on to get_rr_mean_std
        leads: not used by this function
        preprocess_configs: passed on to get_rr_mean_std

    Return:
        np.ndarray, shape: (20,), dtype float32
    """
    demographic_part = get_age_sex(header)
    rhythm_part = get_rr_mean_std(recording, preprocess_configs=preprocess_configs)

    features = np.zeros(20, dtype=np.float32)
    features[:4], features[4:] = demographic_part, rhythm_part
    return features

def get_age_sex(header):
    """
    Return np.array([age, man, woman, sex_mask]) as float32.

    age is the header age divided by 100, or 0.6 when the age is missing
    (None or NaN). man / woman are one-hot flags for the header sex;
    sex_mask is 1 only when the sex is neither a male nor a female spelling.
    """
    # Age: fall back to the default when the header has no usable value.
    age_default = 0.6
    header_age = get_age(header)
    if header_age is None or np.isnan(header_age):
        age = age_default
    else:
        age = header_age / 100

    # Sex: one-hot flags plus a mask that marks an unrecognised value.
    sex = get_sex(header)
    is_woman = sex in ("Female", "female", "F", "f")
    is_man = (not is_woman) and sex in ("Male", "male", "M", "m")
    man = 1 if is_man else 0
    woman = 1 if is_woman else 0
    sex_mask = 0 if (is_man or is_woman) else 1

    return np.array([age, man, woman, sex_mask], dtype=np.float32)

def _extract_rr_features(recording, lead_idx, sampling_rate):
    """
    Compute the eight R-peak based features from a single lead.

    Args:
        recording: indexable collection of leads; a lead must support .clone().numpy()
        lead_idx: index of the lead to analyse
        sampling_rate: sampling frequency passed to neurokit2

    Return:
        tuple (mean RR interval, std RR interval, RMSSD RR interval,
               mean R peak value, RMSSD R peak value, HR mean, HR min, HR max),
        or None when the lead is unusable: any processing error, fewer than
        4 R peaks, or implausible RR statistics.
    """
    try:
        signal = recording[lead_idx].clone().numpy()
        # Peak detection may raise on some signals, hence the surrounding try.
        _, peak_info = nk.ecg_peaks(signal, sampling_rate=sampling_rate)
        r_peaks = peak_info['ECG_R_Peaks']
        if r_peaks.shape[0] < 4:
            return None

        # Dividing the sample distance by the sampling rate gives seconds.
        rr_intervals = (r_peaks[1:] - r_peaks[:-1]) / sampling_rate
        mean_rr = rr_intervals.mean()
        std_rr = rr_intervals.std()
        peak_values = signal[r_peaks]
        rate = nk.signal.signal_rate(r_peaks, sampling_rate=sampling_rate, desired_length=len(signal))
        features = (
            mean_rr,
            std_rr,
            get_RMSSD(rr_intervals),
            peak_values.mean(),
            get_RMSSD(peak_values),
            rate.mean() / 100,
            rate.min() / 100,
            rate.max() / 100,
        )
        # Sanity limits: any interval under 0.2 s, a mean above 2.5 s or a
        # std above 0.215 s is treated as a failed peak detection.
        implausible = (rr_intervals < 0.2).sum() != 0 or mean_rr > 2.5 or std_rr > 0.215
    except Exception:
        # Exception rather than a bare except, so KeyboardInterrupt and
        # SystemExit still propagate instead of being turned into defaults.
        return None
    return None if implausible else features


def get_rr_mean_std(recording, preprocess_configs=None):
    """
    Extract mean, std, RMSSD of RR interval length and mean, RMSSD of R peak value, HR mean, HR min, and HR max from lead 2.
    If lead 2 does not work, try with lead 1.
    If lead 1 & lead 2 do not work, put default value (-1) and set every mask to 1.

    Args:
        recording: indexable collection of leads; lead 2 is index 1, lead 1 is index 0
        preprocess_configs: mapping that must contain "resample_freq" (required despite the None default)

    Return:
        (16,)-ndarray of float32, values and masks interleaved:
        output = np.array([(mean of RR interval length), (mask-mean of RR interval length), (std of RR interval length),
                           (mask-std of RR interval length), (RMSSD of RR interval length), (mask-RMSSD of RR interval length),
                           (mean of R peak value), (mask-mean of R peak value), (RMSSD of R peak value), (mask-RMSSD of R peak value),
                           (HR_mean), (mask-HR_mean), (HR_min), (mask-HR_min), (HR_max), (mask-HR_max)])

    Raises:
        TypeError: when preprocess_configs is None
    """
    if preprocess_configs is None:
        raise TypeError("preprocess_configs must be a mapping containing 'resample_freq'")
    resample_freq = preprocess_configs["resample_freq"]
    default_value = -1

    # Lead 2 (idx 1) is preferred; lead 1 (idx 0) is the fallback.
    features = None
    for lead_idx in (1, 0):
        features = _extract_rr_features(recording, lead_idx, resample_freq)
        if features is not None:
            break

    output = np.zeros(16, dtype=np.float32)
    if features is None:
        # Cannot extract reliable R peak locations: default values, masks set.
        output[0::2] = default_value
        output[1::2] = 1
    else:
        # Even slots carry the values; the odd mask slots stay 0.
        output[0::2] = features
    return output

# get RMSSD
def get_RMSSD(intervals):
    """
    Return sqrt(sum(intervals ** 2) / (len(intervals) - 1)).

    NOTE(review): despite the name, the values are squared as given; no
    successive differences are taken here - confirm this is intended.

    Args:
        intervals: sequence supporting elementwise `** 2` (ex. a numpy array)
    """
    squared_total = sum(intervals ** 2)
    denominator = len(intervals) - 1
    return (squared_total / denominator) ** 0.5

# Extract id from the header
def get_id(header):
    """
    Return the sample id stored in a header.

    The id is the text before the first space on the first line of the
    header. An empty header (or a first line starting with a space) gives "".

    Args:
        header: header text, lines separated by "\\n"

    Return:
        sample id, ex) "A0001"
    """
    first_line = header.split("\n", 1)[0]
    return first_line.split(" ", 1)[0]


def print_and_write(target_str, file_to_write):
    """
    Print a string and append it, followed by a newline, to an open file.

    Args:
        target_str: text to print and write
        file_to_write: writable text file object
    """
    print(target_str)
    file_to_write.write(target_str + "\n")


def evaluate_single_threshold(preds, class_idx, threshold):
    """
    evaluates prediction score of a specified class, given a candidate threshold value

    A prediction counts as positive when it is strictly greater than the
    threshold. Undefined scores are reported as nan: all three when there is
    nothing to evaluate (no predictions, or neither a positive prediction nor
    a positive label), and precision / recall individually when their
    denominator is zero. Any comparison with nan is False, so callers that
    compare scores simply never pick such a threshold.

    Args:
        preds (list of dicts): predictions ({'id': ..., 'labels': ..., 'prediction': ...})
        class_idx (int): index of currently processing class
        threshold (float): candidate threshold value
    Return:
        f1 score, precision, recall (tuple of floats)
    """
    undefined = (np.float64("nan"), np.float64("nan"), np.float64("nan"))

    scores = np.array([x["prediction"][class_idx] for x in preds])
    if scores.size == 0:
        # Nothing to evaluate; report undefined scores instead of failing in sklearn.
        return undefined
    binary_predictions = scores > threshold  # array of True/False
    true_labels = np.array([x["labels"][class_idx] for x in preds]).astype(bool)  # array of True/False

    if not binary_predictions.any() and not true_labels.any():
        # both all False
        return undefined
    if binary_predictions.all() and true_labels.all():
        # both all True
        return np.float64(1), np.float64(1), np.float64(1)

    # Past the two early returns both True and False occur among labels and
    # predictions, so the confusion matrix is always 2x2.
    tn, fp, fn, tp = confusion_matrix(true_labels, binary_predictions).ravel()
    # A zero denominator is expected here and intentionally yields nan.
    with np.errstate(divide="ignore", invalid="ignore"):
        result_precision = tp / (tp + fp)
        result_recall = tp / (tp + fn)
    result_f1_score = f1_score(true_labels, binary_predictions)
    return result_f1_score, result_precision, result_recall


def evaluate_scores(predictions, classes, thresholds):
    """
    evaluates prediction scores of all classes, given thresholds
    calls evaluate_single_threshold once per class

    Args:
        predictions (list of dicts): predictions ({'id': ..., 'labels': ..., 'prediction': ...})
        classes (list of str): classes
        thresholds (list of float): thresholds for each class

    Return:
        f1 scores, precisions, recalls (-> np.array(num_classes,))
    """
    class_count = len(classes)
    f1_scores, precisions, recalls = (np.zeros(class_count) for _ in range(3))

    for idx in range(class_count):
        f1_scores[idx], precisions[idx], recalls[idx] = evaluate_single_threshold(
            predictions, idx, thresholds[idx]
        )
    return f1_scores, precisions, recalls


def evaluate_roc_auc_scores(predictions, classes, classes_names, roc_dir):
    """
    evaluates the ROC AUC score of every class and plots one roc curve per class
    by calling evaluate_single_class_roc_auc

    Args:
        predictions (list of dicts): predictions ({'id': ..., 'labels': ..., 'prediction': ...})
        classes (list of str): class codes
        classes_names (list of str): class names
        roc_dir (str): dir for saving roc curve plots

    Return:
        ROC AUC scores (-> np.array(num_classes,))
    """
    class_count = len(classes)
    roc_auc_scores = np.zeros(class_count)

    for idx in range(class_count):
        roc_auc_scores[idx] = evaluate_single_class_roc_auc(
            predictions, idx, classes_names[idx], roc_dir
        )
    return roc_auc_scores


def evaluate_single_class_roc_auc(preds, class_idx, class_name, roc_dir):
    """
    evaluates ROC AUC score & plotting a roc curve of a specified class

    The curve is saved as "<class_name>.png" inside roc_dir. The figure is
    closed even when plotting or saving fails, so a failure does not leave an
    open matplotlib figure behind.

    Args:
        preds (list of dicts): predictions ({'id': ..., 'labels': ..., 'prediction': ...})
        class_idx (int): index of currently processing class
        class_name (str): class name of this index
        roc_dir (str): dir for saving roc curve plot
    Return:
        ROC AUC score (float)
    """
    predictions = np.array([x["prediction"][class_idx] for x in preds])
    true_labels = np.array([x["labels"][class_idx] for x in preds])

    fpr, tpr, _ = roc_curve(true_labels, predictions)

    fig = plt.figure()
    try:
        plt.plot(fpr, tpr)
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title(class_name)
        plot_path = os.path.join(roc_dir, class_name + ".png")
        plt.savefig(plot_path)
    finally:
        # Always release the figure, also when plotting / saving raised.
        plt.close(fig)

    return roc_auc_score(true_labels, predictions)


def plot_confusion_matrix(cm, target_names=None, cmap=None, normalize=True, labels=True, title='Confusion matrix'):
    """
    Draw a confusion matrix as a heat map and return the figure.

    The figure is left open; closing it is up to the caller.

    Args:
        cm: confusion matrix as a 2-D numpy array (rows are treated as true labels)
        target_names: tick labels for both axes, or None for no custom ticks
        cmap: matplotlib colormap; 'Blues' when None
        normalize: if True, every row is divided by its sum; a row that sums
                   to zero stays all zeros instead of turning into nan
        labels: if True, every cell is annotated with its value (4 decimals)
        title: title of the plot

    Return:
        the created matplotlib figure
    """
    if cmap is None:
        cmap = plt.get_cmap('Blues')

    if normalize:
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        # Divide only rows that have samples; a nan row would also turn
        # cm.max() into nan and break the text colour threshold below.
        cm = np.divide(
            cm.astype('float'),
            row_totals,
            out=np.zeros(cm.shape, dtype='float'),
            where=row_totals != 0,
        )

    fig, _ = plt.subplots(figsize=(30, 30))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    # Cells above this value get white text so they stay readable.
    thresh = cm.max() / 1.5 if normalize else cm.max() / 2

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=15)
        plt.yticks(tick_marks, target_names)

    if labels:
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, "{:.4f}".format(cm[i, j]),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    return fig

class EarlyStopping:
    """Early stops the training if the monitored score (higher is better) doesn't improve after a given patience."""
    def __init__(self, patience=50, verbose=False, delta=0, checkpoint_pth='chechpoint.pt'):
        """
        Args:
            patience (int): How many consecutive calls without improvement are tolerated before early_stop is set.
                            Default: 50
            verbose (bool): If True, prints a message each time an improved model is saved.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            checkpoint_pth (str): File the model is saved to with torch.save.
                            Default: 'chechpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.delta = delta
        self.checkpoint_pth = checkpoint_pth

    def __call__(self, score, model):
        """
        Register a new score and update counter / early_stop.

        Args:
            score: monitored score, higher is better (ex. accuracy)
            model: model to checkpoint on improvement, or None to only track the score
        """
        if self.best_score is None:
            # Start from 0 so the first "increased" message can be formatted.
            self.best_score = 0
            self.save_checkpoint(score, model)
        elif score <= self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.save_checkpoint(score, model)
            self.counter = 0

    def save_checkpoint(self, score, model):
        '''Records the new best score and saves the model when one is given.'''
        saving = model is not None
        if saving and self.verbose:
            print(f'Accuracy increased ({self.best_score:.6f} --> {score:.6f}).  Saving model ...')
        # Track the best score even without a model; otherwise later scores
        # would keep being compared against a stale value and the patience
        # counter would never be allowed to grow.
        self.best_score = score
        if saving:
            torch.save(model, self.checkpoint_pth)
