# cinc2020.utils 19/03/21
import copy
import os
import random
import numpy as np

# Groups of class codes that are treated as equivalent. The first code of each
# group is the reference: get_classes() drops the remaining codes and
# ClassEncoder._build_index_mapper() maps them to the reference's index.
# NOTE(review): the codes look like SNOMED CT concept identifiers -- confirm.
EQUIVALENT_CLASSES = [
    ['713427006', '59118001'],
    ['284470004', '63593006'],
    ['427172004', '17338001'],
    ['733534002', '164909002']
]

# Subset of EQUIVALENT_CLASSES (second and third groups only); this is the
# default `equivalent_classes` of ClassEncoder.
EQUIVALENT_CLASSES_FIXED = [
    ['284470004', '63593006'],
    ['427172004', '17338001']
]

# Class codes that are scored in the cinc2020 challenge. The list still
# contains both members of every group in EQUIVALENT_CLASSES; get_classes()
# removes the non-reference members.
SCORED_CLASSES = [
    '164889003', '164890007', '6374002', '426627000', '733534002', '713427006',
    '270492004', '713426002', '39732003', '445118002', '164909002', '251146004',
    '698252002', '426783006', '284470004', '10370003', '365413008', '427172004',
    '164947007', '111975006', '164917005', '47665007', '59118001', '427393009',
    '426177001', '427084000', '63593006', '164934002', '59931005', '17338001'
]

# Class codes that are not part of SCORED_CLASSES. With the default
# arguments of ClassEncoder (negative=True), the first entry of this list
# names the extra column placed in front of the scored classes.
UNSCORED_CLASSES = [
    '233892002', '164951009', '251187003', '61277005', '426664006', '251139008',
    '57054005', '413444003', '426434006', '54329005', '251173003', '195080001',
    '195126007', '251268003', '106068003', '713422000', '233917008', '50799005',
    '29320008', '251166008', '233897008', '251170000', '418818005', '74615001',
    '426749004', '251199005', '61721007', '698247007', '27885002', '204384007',
    '53741008', '413844008', '251198002', '82226007', '428417006', '13640000',
    '164942001', '84114007', '368009', '251259000', '251200008', '195042002',
    '426183003', '425419005', '251120003', '704997005', '49260003', '426995002',
    '251164006', '426648003', '253352002', '67741000119109', '446813000',
    '425623009', '445211001', '164873001', '55827005', '370365005', '164865005',
    '164861001', '54016002', '428750005', '164867002', '282825002', '251205003',
    '67198005', '425856008', '164912004', '253339007', '164921003', '446358003',
    '67751000119106', '314208002', '89792004', '17366009', '65778007', '5609005',
    '60423000', '49578007', '77867006', '55930002', '429622005', '164931005',
    '164930006', '251168009', '426761007', '266257000', '251223006', '164937009',
    '11157007', '164884008', '75532003', '81898007', '164896001', '111288001',
    '266249003', '195060002', '251266004', '251182009', '164895002', '251180001',
    '195101003', '74390002'
]

# Every known class code: unscored classes first, then scored classes.
ALL_CLASSES = UNSCORED_CLASSES + SCORED_CLASSES


class ClassEncoder:
    """Class encoder.

    Class encoder for a multi-label classification task.

    Attributes:
        classes: List of unique classes that can be encoded. The position of
            a class in this list is its column index in the encoded array.
        mapper: Dictionary from class name to column index. Besides the
            entries of ``classes`` it contains the equivalent classes (mapped
            to the index of their reference class) and the unscored classes
            (mapped to column 0 when ``negative`` is true, to ``None``
            otherwise).
        negative: Whether column 0 is reserved for the unscored classes.
    """

    def __init__(
        self, classes=ALL_CLASSES,
        scored_classes=SCORED_CLASSES,
        equivalent_classes=EQUIVALENT_CLASSES_FIXED,
        negative=True
    ):
        """Build the class list and the name-to-index mapper.

        Args:
            classes: List of all classes present in the dataset. ``None``
                falls back to ``ALL_CLASSES``.
            scored_classes: List of classes that are scored, or ``None`` to
                encode every entry of ``classes`` as-is.
            equivalent_classes: List of lists of equivalent classes (the
                first element of each list is the reference), or ``None``.
            negative: If true (and ``scored_classes`` is given), column 0 is
                an extra column shared by all unscored classes.
        """
        self.negative = negative
        if classes is None:
            classes = ALL_CLASSES
        unscored_classes = self._build_class_list(
            classes, scored_classes, equivalent_classes
        )
        self._build_index_mapper(unscored_classes, equivalent_classes)

    @property
    def n_classes(self):
        """Number of classes in the encoder."""
        return len(self.classes)

    def encode(self, labels):
        """Encode labels.

        Encode labels by building an array of zeros and ones where ones
        indicate that a label is present. Class names are resolved through
        ``mapper``, so an equivalent class sets the column of its reference
        class and, when ``negative`` is true, any unscored class sets
        column 0. Names unknown to the mapper, or mapped to ``None``, are
        ignored.

        Args:
            labels: List of labels to encode where each element is a list of
                classes.

        Returns:
            A two-dimensional float32 array of zeros and ones encoding
            labels. The number of rows is equal to the length of the input
            list and the number of columns is equal to the number of classes
            that can be encoded. There can be multiple ones by rows in case
            of multi-label samples.
        """
        p = np.zeros((len(labels), self.n_classes), dtype='float32')
        for i, names in enumerate(labels):
            # Look names up in the mapper rather than in ``self.classes``:
            # equivalent and unscored classes are absent from the class list
            # but do have a mapper entry.
            indices = [
                self.mapper[name] for name in names
                if self.mapper.get(name) is not None
            ]
            if indices:
                p[i, indices] = 1.0
        return p

    def decode(self, probabilities, threshold=0.5):
        """Decode labels.

        Decode labels from an array of probabilities.

        Args:
            probabilities: Two-dimensional array of estimated probabilities.
                The number of rows corresponds to the number of samples and the
                number of columns to the number of classes that can be decoded.
            threshold: Probability threshold to determine if a class is
                present (strictly greater than the threshold).

        Returns:
            A list with one entry per sample; each entry is the list of
            classes whose probability exceeds the threshold.
        """
        probabilities = np.asarray(probabilities)
        assert probabilities.ndim == 2
        assert probabilities.shape[1] == self.n_classes
        labels = []
        for p in probabilities:
            indices = np.flatnonzero(p > threshold)
            labels.append([self.classes[i] for i in indices])
        return labels

    def _build_index_mapper(self, unscored_classes, equivalent_classes):
        """Build a mapper from class names to indices.

        Args:
            unscored_classes: A list of classes that will not be scored in the
                cinc2020 challenge, or ``None``.
            equivalent_classes: A list of lists of classes that are considered
                equivalent in the cinc2020 challenge, or ``None``.
        """
        self.mapper = {cls: i for i, cls in enumerate(self.classes)}
        if unscored_classes is not None:
            # Unscored classes either share column 0 or are marked as
            # "not encodable" with None.
            unscored_index = 0 if self.negative else None
            for cl in unscored_classes:
                self.mapper[cl] = unscored_index

        if equivalent_classes is not None:
            for class_list in equivalent_classes:
                reference_class = class_list[0]
                for cl in class_list[1:]:
                    self.mapper[cl] = self.mapper[reference_class]

    def _build_class_list(self, classes, scored_classes, equivalent_classes):
        """Build a list of class names to map from indexes to class names.

        Sets ``self.classes``.

        Args:
            classes: A list of classes in the dataset.
            scored_classes: A list of classes that will be scored in the
                cinc2020 challenge, or ``None``.
            equivalent_classes: A list of lists of classes that are considered
                equivalent in the cinc2020 challenge, or ``None``.

        Returns:
            The entries of ``classes`` that are not scored, or ``None`` when
            ``scored_classes`` is ``None``.

        Raises:
            IndexError: ``negative`` is true and every class is scored, so
                there is no unscored class to name column 0.
        """
        if scored_classes is not None:
            scored_lookup = set(scored_classes)
            unscored_classes = [
                cl for cl in classes if cl not in scored_lookup
            ]
            if self.negative:
                self.classes = [unscored_classes[0]] + scored_classes
            else:
                # Copy so the encoder never aliases the caller's list.
                self.classes = list(scored_classes)
        else:
            unscored_classes = None
            self.classes = list(classes)

        if equivalent_classes is not None:
            self.classes = get_classes(
                scored_classes=self.classes,
                equivalent_classes=equivalent_classes
            )

        return unscored_classes


class BinaryClassEncoder:
    """Binary class encoder.

    Class encoder for a binary classification task: a sample is positive
    when it carries at least one of the target classes.

    Attributes:
        classes: List of target classes. A non-list argument given to the
            constructor is wrapped in a one-element list.
    """

    def __init__(self, cls):
        self.classes = cls if isinstance(cls, list) else [cls]

    @property
    def n_classes(self):
        """Number of outputs of the encoder (always 1)."""
        return 1

    def encode(self, labels):
        """Encode labels.

        Args:
            labels: List of labels to encode where each element is a list of
                classes.

        Returns:
            A one-dimensional float32 array with one entry per sample: 1.0
            when the sample contains any of the target classes, 0.0
            otherwise.
        """
        encoded = np.zeros(len(labels), dtype='float32')
        for row, sample_labels in enumerate(labels):
            if any(target in sample_labels for target in self.classes):
                encoded[row] = 1.0
        return encoded

    def decode(self, probabilities, threshold=0.5):
        """Decode labels.

        Args:
            probabilities: Array of estimated probabilities, one per sample.
            threshold: Probability threshold to determine if the class is
                present (strictly greater than the threshold).

        Returns:
            A flat list holding the first target class once for every sample
            above the threshold. Samples below the threshold contribute
            nothing, so the result is not aligned with the input rows.
        """
        return [
            self.classes[0]
            for probability in np.asarray(probabilities)
            if probability > threshold
        ]


def indirect_sort(x):
    """Compute forward and backward sorting indices.

    The forward indices sort the input (stable sort); the backward indices
    undo that sort, i.e. they map the sorted array back to the original
    order.

    Args:
        x: A one-dimensional array.

    Returns:
        A tuple ``(i_fwd, i_bwd)`` such that ``x[i_fwd]`` is sorted and
        ``x[i_fwd][i_bwd]`` equals ``x``.

    Raises:
        ValueError: The input array is not one-dimensional.

    Examples:
        >>> x = np.array([5, 2, 4, 1, 3])
        >>> i_fwd, i_bwd = indirect_sort(x)
        >>> y = x[i_fwd]
        >>> print(y)
        [1 2 3 4 5]
        >>> print(y[i_bwd])
        [5 2 4 1 3]
    """
    values = np.asarray(x)
    if values.ndim != 1:
        raise ValueError('array must be one-dimensional')
    forward = np.argsort(values, kind='mergesort')
    # The inverse of a permutation is the permutation that sorts it.
    backward = np.argsort(forward, kind='mergesort')
    return forward, backward


def build_sliding_windows(x, window_size, overlap_size=0):
    """Build sliding windows.

    Build sliding windows from an array with optional overlap. The windows are
    built along the first dimension for multi-dimensional arrays. Trailing
    samples that do not fill a complete window are dropped.

    Args:
        x: An array (or array-like) with at least one dimension.
        window_size: Window size, a positive integer.
        overlap_size: Size of the overlap between successive windows. Must be
            smaller than ``window_size``; a negative value leaves a gap
            between windows.

    Returns:
        An array of shape ``(n_windows, window_size) + x.shape[1:]``. When
        the input is shorter than one window, ``n_windows`` is 0.

    Raises:
        ValueError: ``window_size`` is not positive or ``overlap_size`` is
            not smaller than ``window_size``.
    """
    x = np.asanyarray(x)
    if window_size < 1:
        raise ValueError('window_size must be a positive integer')
    step = window_size - overlap_size
    if step < 1:
        # A zero or negative step would never advance through the array.
        raise ValueError('overlap_size must be smaller than window_size')
    if x.shape[0] < window_size:
        return np.zeros_like(x, shape=(0, window_size) + x.shape[1:])
    return np.stack([
        x[end - window_size:end]
        for end in range(window_size, x.shape[0] + 1, step)
    ])


def load_weights(weight_file, classes):
    """Load a class-by-class weight matrix restricted to `classes`.

    The table is read with load_table() and its two label lists must be
    equal. A label may bundle several equivalent classes as
    'cl_eq_1|cl_eq_2'; each bundled class found in `classes` receives the
    weights of that label.

    Args:
        weight_file: Path of the weight table.
        classes: List of class names defining the rows and columns of the
            returned matrix.

    Returns:
        A float64 array of shape (len(classes), len(classes)). Entries for
        classes that do not appear in the table stay at zero.
    """
    labels, other_labels, values = load_table(weight_file)
    assert labels == other_labels

    size = len(classes)
    weights = np.zeros((size, size), dtype=np.float64)

    # For every table label, the positions in `classes` of the classes it
    # bundles (new weights.csv format: 'cl_eq_1|cl_eq_2').
    positions = [
        [classes.index(name) for name in label.split('|') if name in classes]
        for label in labels
    ]

    for i, row_positions in enumerate(positions):
        for k in row_positions:
            for j, col_positions in enumerate(positions):
                for m in col_positions:
                    weights[k, m] = values[i, j]

    return weights


def load_table(table_file):
    """Load a labelled, comma-separated table of numbers.

    The table should have the following form:

        ,    a,   b,   c
        a, 1.2, 2.3, 3.4
        b, 4.5, 5.6, 6.7
        c, 7.8, 8.9, 9.0

    Cells are stripped of surrounding whitespace and blank lines are
    ignored.

    Args:
        table_file: Path of the file to read.

    Returns:
        A tuple ``(rows, cols, values)`` where ``rows`` are the labels found
        in the first column, ``cols`` are the labels found in the header
        line and ``values`` is a float array of shape
        ``(len(rows), len(cols))``. Cells that cannot be parsed as a number
        are stored as NaN.

    Raises:
        ValueError: The table has no data row, no data column, or lines
            with different numbers of cells.
    """
    def is_number(x):
        try:
            float(x)
            return True
        except ValueError:
            return False

    with open(table_file, 'r') as f:
        table = [
            [cell.strip() for cell in line.split(',')]
            for line in f
            if line.strip()
        ]

    # Define the numbers of rows and columns and check for errors.
    num_rows = len(table) - 1
    if num_rows < 1:
        raise ValueError('The table {} is empty.'.format(table_file))

    # Every line, header and last row included, must have the same width.
    widths = {len(line) - 1 for line in table}
    if len(widths) != 1:
        raise ValueError('The table {} has rows with different lengths.'
                         .format(table_file))
    num_cols = widths.pop()
    if num_cols < 1:
        raise ValueError('The table {} is empty.'.format(table_file))

    # Row labels live in the first column, column labels in the header line.
    rows = [table[i + 1][0] for i in range(num_rows)]
    cols = [table[0][j + 1] for j in range(num_cols)]

    # Find the entries of the table.
    values = np.zeros((num_rows, num_cols))
    for i in range(num_rows):
        for j in range(num_cols):
            value = table[i + 1][j + 1]
            values[i, j] = float(value) if is_number(value) else float('nan')

    return rows, cols, values


def get_classes(
        scored_classes=SCORED_CLASSES, equivalent_classes=EQUIVALENT_CLASSES
):
    """Get the classes for training and testing.

    Returns a copy of `scored_classes` in which, for every group of
    equivalent classes, all members but the first (the reference) have been
    removed. The input is left untouched.

    Args:
        scored_classes: Sequence of class names.
        equivalent_classes: List of lists of equivalent class names; the
            first name of each list is kept.

    Returns:
        A new list of class names. Equivalent classes that are not present
        in `scored_classes` are skipped.
    """
    # Class names are strings, so a shallow copy is enough to protect the
    # caller's list (and it also accepts tuples).
    classes = list(scored_classes)
    for class_list in equivalent_classes:
        for cl in class_list[1:]:
            # list.remove raises ValueError on a missing item; a custom class
            # list may legitimately lack some equivalent classes.
            if cl in classes:
                classes.remove(cl)
    return classes


def set_GPU(GPU_ID):
    """Set which GPU to use.

    Writes the CUDA_VISIBLE_DEVICES and CUDA_DEVICE_ORDER environment
    variables of the current process. Does nothing when `GPU_ID` is None.

    Args:
        GPU_ID: Value for CUDA_VISIBLE_DEVICES (converted with `str`), or
            None to leave the environment untouched.
    """
    if GPU_ID is None:
        return
    os.environ.update({
        "CUDA_VISIBLE_DEVICES": str(GPU_ID),
        "CUDA_DEVICE_ORDER": 'PCI_BUS_ID',
    })


def find_model_file(model_directory, metric=''):
    """Find the most recent model file in a directory.

    Files are selected by the start of their stem: 'last_epoch' when
    `metric` is empty, 'best_val_<metric>' otherwise. The last
    underscore-separated token of the stem is read as the epoch number.

    Args:
        model_directory: Directory to scan; must provide `iterdir()` (e.g. a
            `pathlib.Path`).
        metric: Name of the validation metric, or '' for the last-epoch
            files.

    Returns:
        A tuple ``(model_file, initial_epoch)`` with the matching file of
        highest epoch and that epoch plus one.

    Raises:
        FileNotFoundError: No file in the directory matches. This is a
            subclass of `EnvironmentError` / `OSError`.
        ValueError: A matching file does not end with an integer epoch.
    """
    prefix = 'last_epoch' if metric == '' else "best_val_" + metric
    candidates = []
    for file in model_directory.iterdir():
        stem = file.stem
        # Compare against the whole prefix; a fixed 10-character slice can
        # never equal the longer 'best_val_<metric>' prefixes.
        if stem.startswith(prefix):
            candidates.append((int(stem.split("_")[-1]) + 1, file))
    if not candidates:
        raise FileNotFoundError(
            "Model file not found in directory: {}".format(model_directory)
        )
    # max() keeps the first candidate among ties.
    initial_epoch, model_file = max(candidates, key=lambda item: item[0])
    return model_file, initial_epoch


def transform_class_weights(class_weights):
    """Transform class weights into a dictionary form.

    Args:
        class_weights: A list of weights, a dictionary, or None.

    Returns:
        For a list, a dictionary mapping each position to its weight. A
        dictionary or None is returned unchanged.

    Raises:
        ValueError: `class_weights` is neither a list, a dictionary nor None.
    """
    if isinstance(class_weights, list):
        return dict(enumerate(class_weights))
    if class_weights is None or isinstance(class_weights, dict):
        return class_weights
    raise ValueError("class_weights must be a dictionary or list")


def simple_oversample(X, y, seed=7, reshuffle=True):
    """Perform simple oversampling of a binary dataset.

    Samples of the less frequent label (0 or 1) are duplicated at random,
    with replacement, until both labels are equally frequent. The inputs are
    not modified. Progress information is printed to standard output.

    Args:
        X: Sequence of samples.
        y: Sequence of labels (0 or 1), aligned with `X`.
        seed: Seed of the random generator used for drawing the duplicates
            and for shuffling, so that results are reproducible. None gives
            a non-deterministic result.
        reshuffle: Whether to shuffle samples and labels (in unison) after
            oversampling. Without shuffling, duplicates are appended at the
            end.

    Returns:
        A tuple ``(X_oversampled, y_oversampled)``. `X_oversampled` is a
        tuple when `reshuffle` is true and a list otherwise;
        `y_oversampled` is a numpy array.

    Raises:
        IndexError: The classes are unbalanced but the minority class has no
            sample to duplicate.
    """
    if len(X) != len(y):
        print("WARNING: " + str(len(X)) + " signals and " + str(len(y))
              + " labels")

    # Private generator: honours `seed` without touching the global
    # `random` state.
    rng = random.Random(seed)

    labels = np.array(y)
    n_total = len(labels)
    n_zeros = int(np.sum(labels == 0))
    n_ones = int(np.sum(labels == 1))
    print("Distribution before:", [n_zeros / n_total, n_ones / n_total])
    print("Total number of samples:", n_total)

    diff = n_zeros - n_ones
    over_sampled_class = 1 if diff > 0 else 0
    over_sampled_idx = np.flatnonzero(labels == over_sampled_class).tolist()
    new_sample_idx = rng.choices(over_sampled_idx, k=abs(diff))

    # Work on copies so the caller's sequences are left untouched.
    X_oversampled = list(X)
    y_oversampled = list(labels)
    X_oversampled += [X_oversampled[idx] for idx in new_sample_idx]
    y_oversampled += [y_oversampled[idx] for idx in new_sample_idx]

    if reshuffle:
        # Shuffle (sample, label) pairs so both stay aligned.
        pairs = list(zip(X_oversampled, y_oversampled))
        rng.shuffle(pairs)
        X_oversampled, y_oversampled = zip(*pairs)

    y_oversampled = np.array(y_oversampled)
    n_after = len(y_oversampled)
    print("Distribution after:", [int(np.sum(y_oversampled == 0)) / n_after,
                                  int(np.sum(y_oversampled == 1)) / n_after])
    print("Total number of samples:", n_after)
    return X_oversampled, y_oversampled


def shuffle_in_unison(a, b):
    """Shuffle two lists in unison.

    Both sequences are shuffled with the same random permutation (drawn from
    the global `random` generator), so that elements at the same position
    stay paired. The inputs are not modified. If the sequences differ in
    length, the extra elements of the longer one are dropped.

    Args:
        a: First sequence.
        b: Second sequence.

    Returns:
        A tuple of two tuples holding the shuffled elements of `a` and `b`.
        Two empty tuples are returned for empty input.
    """
    pairs = list(zip(a, b))
    if not pairs:
        # zip(*[]) yields nothing, so unpacking it into two names would fail.
        return (), ()
    random.shuffle(pairs)
    shuffled_a, shuffled_b = zip(*pairs)
    return shuffled_a, shuffled_b


def rms(ecg):
    """Root mean square along the last axis.

    Args:
        ecg: Array or array-like with at least one dimension.

    Returns:
        The root mean square of the values along the last axis. Floating
        point inputs keep their precision; integer and boolean inputs are
        computed in float64.
    """
    ecg = np.asanyarray(ecg)
    if ecg.dtype.kind in 'biu':
        # Squaring in a narrow integer type (e.g. int16 samples) silently
        # overflows; do the arithmetic in float64 instead.
        ecg = ecg.astype(np.float64)
    return np.sqrt(np.mean(ecg**2, axis=-1))
