import numpy as np
import tensorflow as tf


def rms_tf(ecg):
    """Return the root-mean-square of ``ecg`` taken over its last axis.

    The last axis is reduced away, so the result has one dimension fewer
    than the input.
    """
    squared = tf.math.square(ecg)
    mean_of_squares = tf.math.reduce_mean(squared, axis=-1)
    return tf.math.sqrt(mean_of_squares)


class SingleLead:
    """Callable that picks a single lead (index on the last axis) of an array.

    The instance carries a ``__name__`` so it can be used wherever a plain
    function with a name is expected.

    Args:
        lead: Index selected on the last axis.
    """

    def __init__(self, lead):
        self.lead = lead
        self.__name__ = "".join(["single_lead_", str(lead)])

    def __call__(self, x):
        selected = x[..., self.lead]
        return selected


def scaled_dot_product_attention(q, k, v, mask=None):
    """Calculate the attention weights.

    q, k, v must have matching leading dimensions.
    k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
    The mask has different shapes depending on its type (padding or look
    ahead) but it must be broadcastable for addition. Positions where the
    mask is 1 get -1e9 added to their logit, so their weight after the
    softmax is ~0.

    Args:
        q: query shape == (..., seq_len_q, depth)
        k: key shape == (..., seq_len_k, depth)
        v: value shape == (..., seq_len_v, depth_v)
        mask: Numeric tensor/array of any dtype (it is cast to the logits'
            dtype), broadcastable to (..., seq_len_q, seq_len_k), or None.
            Defaults to None.

    Returns:
        output, attention_weights where
        output has shape (..., seq_len_q, depth_v) and
        attention_weights has shape (..., seq_len_q, seq_len_k).
    """
    matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
    logits_dtype = matmul_qk.dtype

    # Scale by sqrt(depth). Casting to the logits' dtype (instead of a fixed
    # float32) keeps the division valid for non-float32 inputs.
    dk = tf.cast(tf.shape(k)[-1], logits_dtype)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    if mask is not None:
        # Integer masks (e.g. a boolean comparison cast to int32) cannot be
        # multiplied by the float -1e9 directly, so cast first.
        scaled_attention_logits += tf.cast(mask, logits_dtype) * -1e9

    # softmax is normalized on the last axis (seq_len_k) so that the scores
    # add up to 1.
    attention_weights = tf.nn.softmax(
        scaled_attention_logits, axis=-1
    )  # (..., seq_len_q, seq_len_k)

    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

    return output, attention_weights


def point_wise_feed_forward_network(d_model, dff):
    """Build the position-wise feed-forward sub-network of a Transformer block.

    Two ``Dense`` layers: an expansion to ``dff`` units with ReLU, followed by
    a linear projection back to ``d_model`` units. ``Dense`` acts on the last
    axis only, so every sequence position is transformed independently.
    """
    expand = tf.keras.layers.Dense(dff, activation='relu')  # (batch_size, seq_len, dff)
    project = tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
    return tf.keras.Sequential([expand, project])


class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head scaled dot-product attention layer.

    Queries, keys and values are each projected to ``d_model`` units, split
    into ``num_heads`` heads of size ``depth = d_model // num_heads``, attended
    with ``scaled_dot_product_attention``, then merged back and projected by a
    final ``Dense(d_model)``.

    Args:
        d_model: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        # The model width must split evenly across the heads.
        assert d_model % self.num_heads == 0

        self.depth = d_model // self.num_heads

        # Attribute names and creation order are kept stable because they
        # determine the layer's variable/checkpoint structure.
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Reshape (batch_size, seq_len, d_model) into heads.

        Returns a tensor of shape (batch_size, num_heads, seq_len, depth).
        """
        per_head = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(per_head, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        """Attend queries ``q`` over keys ``k`` / values ``v``.

        Note the argument order is (v, k, q, mask).

        Returns:
            output: (batch_size, seq_len_q, d_model)
            attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
        """
        batch_size = tf.shape(q)[0]

        # Project, then split: (batch_size, num_heads, seq_len, depth).
        query_heads = self.split_heads(self.wq(q), batch_size)
        key_heads = self.split_heads(self.wk(k), batch_size)
        value_heads = self.split_heads(self.wv(v), batch_size)

        attended, attention_weights = scaled_dot_product_attention(
            query_heads, key_heads, value_heads, mask
        )

        # (batch, heads, seq_q, depth) -> (batch, seq_q, heads, depth)
        # -> (batch, seq_q, d_model)
        regrouped = tf.transpose(attended, perm=[0, 2, 1, 3])
        merged = tf.reshape(regrouped, (batch_size, -1, self.d_model))

        return self.dense(merged), attention_weights


class EncoderLayer(tf.keras.layers.Layer):
    """One Transformer encoder block with layer norm after each residual add.

    x -> self-attention -> dropout -> add & norm -> FFN -> dropout -> add & norm

    Args:
        d_model: Model width.
        num_heads: Attention heads (must divide ``d_model``).
        dff: Hidden width of the feed-forward sub-network.
        rate: Dropout rate used by both dropout layers.
        seed: Seed passed to both dropout layers.
    """

    def __init__(self, d_model, num_heads, dff, rate=0.1, seed=7):
        super().__init__()

        # Attribute names and creation order are kept stable (checkpoints).
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(rate, seed=seed)
        self.dropout2 = tf.keras.layers.Dropout(rate, seed=seed)

    def call(self, x, training, mask):
        """Return (output, attention_weights); output keeps the shape of ``x``."""
        # Sub-layer 1: self-attention, residual connection, layer norm.
        attn_output, attention_weights = self.mha(x, x, x, mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)

        # Sub-layer 2: feed-forward, residual connection, layer norm.
        ffn_output = self.dropout2(self.ffn(out1), training=training)
        out2 = self.layernorm2(out1 + ffn_output)

        return out2, attention_weights


def get_angles(pos, i, d_model):
    """Return ``pos`` multiplied by the sinusoidal angle rate of index ``i``.

    The rate is ``1 / 10000 ** (2 * (i // 2) / d_model)``, so each
    consecutive (even, odd) pair of indices shares the same rate.
    Broadcasting of ``pos`` against ``i`` follows numpy rules.
    """
    exponent = (2 * (i // 2)) / np.float32(d_model)
    rates = 1 / np.power(10000, exponent)
    return pos * rates


def positional_encoding(position, d_model):
    """Return the sinusoidal positional-encoding table as a float32 tensor.

    Even columns hold ``sin`` and odd columns ``cos`` of the angles from
    ``get_angles``. The result has shape (1, position, d_model) so it
    broadcasts over the batch axis.
    """
    table = get_angles(
        np.arange(position)[:, np.newaxis],
        np.arange(d_model)[np.newaxis, :],
        d_model,
    )

    sin_columns = np.sin(table[:, 0::2])  # even indices: 2i
    cos_columns = np.cos(table[:, 1::2])  # odd indices: 2i+1
    table[:, 0::2] = sin_columns
    table[:, 1::2] = cos_columns

    return tf.cast(table[np.newaxis, ...], dtype=tf.float32)


def identity(x):
    """Return the argument unchanged (a no-op transform)."""
    unchanged = x
    return unchanged


def create_padding_mask(seq, maximum_position_encoding=50, tf_version=False):
    """Build an attention padding mask: 1 at padded positions, 0 elsewhere.

    The mask is meant to be scaled by -1e9 and added to attention logits, so
    positions marked 1 are suppressed.

    Args:
        seq: If ``tf_version`` is True, a rank-4 tensor/array whose first two
            axes are (batch, position); a position counts as padding when all
            of its values are exactly zero. Otherwise, a sequence of
            recordings of which only ``len()`` is used: the first
            ``len(recording)`` positions are real and the rest are padding.
        maximum_position_encoding: Mask length for the non-TF path (ignored
            for the TF path, where the length comes from ``seq``).
        tf_version: Select the TensorFlow path described above.

    Returns:
        Mask of shape (batch_size, 1, 1, seq_len). It is a float32 tensor on
        the TF path and a float64 numpy array otherwise.
    """
    if tf_version:
        # Padded positions are all-zero (zeropad fills missing beats with
        # zeros). The former ``norm > 0`` test flagged the *real* positions,
        # which inverted the mask. It also produced int32, which cannot be
        # scaled by the float -1e9.
        is_padding = tf.math.reduce_all(tf.math.equal(seq, 0), axis=[2, 3])
        mask = tf.cast(is_padding, tf.float32)
    else:
        mask = np.ones((len(seq), maximum_position_encoding))
        for i, recording in enumerate(seq):
            # Real positions are unmasked; anything beyond stays 1 (padding).
            mask[i, :len(recording)] = 0

    # Add two singleton axes so the mask broadcasts over the attention
    # logits' head and query axes.
    return mask[:, np.newaxis, np.newaxis, :]  # (batch_size, 1, 1, seq_len)


def list_shape(l):
    """ Returns the shape of a list
        WARNING: this shape is not necessarily constant throughout the list
                 of lists.

    The shape is found by descending through the first element at each
    level and recording its ``len()``. Descent stops at anything without a
    length or that cannot be indexed. It also stops at strings and bytes,
    which are treated as scalars: ``"a"[0] == "a"`` would otherwise loop
    forever. An empty sequence contributes a 0 and ends the descent (the
    original raised IndexError here).

    Examples:
        list_shape([[1, 2, 3], [4, 5, 6]]) -> (2, 3)
        list_shape([]) -> (0,)
        list_shape(["ab", "cd"]) -> (2,)
    """
    shape = []
    node = l
    while not isinstance(node, (str, bytes)):
        try:
            shape.append(len(node))
            node = node[0]
        except (TypeError, IndexError):
            # TypeError: no len()/not indexable; IndexError: empty sequence.
            break
    return tuple(shape)


def zeropad(x, maximum_position_encoding, printing=False):
    """Zero-pad or truncate every recording to ``maximum_position_encoding`` beats.

    Each recording must stack to a rank-3 array whose first axis is the
    beat axis (the pad width below covers exactly three axes). Recordings
    longer than the limit keep only their first ``maximum_position_encoding``
    beats. Missing beats are appended as zeros.

    Unlike the previous version, the input ``x`` is left unmodified (it used
    to be overwritten element by element).

    Args:
        x: Iterable of recordings.
        maximum_position_encoding: Target number of beats per recording.
        printing: If True, print a warning whenever a recording is truncated.

    Returns:
        numpy array of shape
        (n_recordings, maximum_position_encoding, <beat dims...>), rank 4.
    """
    padded = []
    for recording in x:
        if len(recording) > maximum_position_encoding:
            if printing:
                print(
                    "WARNING: too many beats, taking first",
                    maximum_position_encoding, "/", len(recording)
                )
            recording = recording[:maximum_position_encoding]
        n_zeros = maximum_position_encoding - len(recording)
        pad_width = [(0, n_zeros), (0, 0), (0, 0)]
        try:
            padded.append(np.pad(recording, pad_width))
        except ValueError:
            # Fall back to stacking the beats explicitly when numpy cannot
            # convert the recording directly (kept from the original code).
            padded.append(
                np.pad(np.array([beat for beat in recording]), pad_width)
            )
    result = np.array(padded)
    assert len(result.shape) == 4
    return result


def submitted_CNN(
        maximum_position_encoding, d_model, n_channels,
        separable=0, layers_with_spatial_dropout=5, layers_to_skip=None,
        layers=None, kernel_size=None, strides=None, dropout=None,
        alpha=0.3, pooling=None, seed=7
):
    """Build a windowing sub-model (Keras functional API).

    The model takes inputs of shape
    (batch, maximum_position_encoding, d_model, n_channels) and collapses the
    last (channel) axis with ``rms_tf``. It then processes each of the
    ``maximum_position_encoding`` windows independently (via
    ``TimeDistributed``) with a stack of 1-D convolutions, LeakyReLU
    activations and optional dropout. Finally, each window is pooled or
    flattened and projected by ``Dense(d_model)``, giving an output of shape
    (batch, maximum_position_encoding, d_model).

    Args:
        maximum_position_encoding: Number of windows per input.
        d_model: Samples per window; also the output embedding size.
        n_channels: Number of input channels; ``None`` leaves that Input
            dimension unspecified.
        separable: If truthy, use ``SeparableConv1D`` instead of ``Conv1D``.
        layers_with_spatial_dropout: Conv layers with an index below this
            value use ``SpatialDropout1D``; later ones use plain ``Dropout``.
        layers_to_skip: Period of the additive shortcut connections, or
            ``None`` for no shortcuts.
        layers: Number of filters per conv layer (defaults to 15 layers).
        kernel_size: Per-layer kernel sizes, indexed in parallel with
            ``layers``.
        strides: Per-layer strides, indexed in parallel with ``layers``.
        dropout: Per-layer dropout rates; a rate of 0 adds no dropout layer.
        alpha: Passed positionally to ``LeakyReLU``.
        pooling: ``None`` (flatten), ``'average'`` or ``'max'``.
        seed: Seed for the dropout layers.

    Returns:
        A ``keras.Model`` named ``'embedding_cnn'``.

    Raises:
        ValueError: If ``pooling`` is not one of the supported values.
    """
    # Default architecture: 15 conv layers. Every third layer (indices 2, 5,
    # 8, 11, 14) uses stride 2, a wider kernel (24, or 48 for the last) and
    # 0.1 dropout.
    if dropout is None:
        dropout = [0, 0, 0.1, 0, 0, 0.1, 0, 0, 0.1, 0, 0, 0.1, 0, 0, 0.1]
    if strides is None:
        strides = [1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2]
    if layers is None:
        layers = [16, 16, 16, 32, 32, 32, 32, 32, 32, 64, 64, 64, 64, 64, 64]
    if kernel_size is None:
        kernel_size = [3, 3, 24, 3, 3, 24, 3, 3, 24, 3, 3, 24, 3, 3, 48]
    from tensorflow import keras
    if separable:
        conv_layer = keras.layers.SeparableConv1D
    else:
        conv_layer = keras.layers.Conv1D
    print("CNN input window size is", d_model)
    x = keras.layers.Input(
        shape=(maximum_position_encoding, d_model, n_channels)
    )
    # RMS over the channel axis -> (batch, maximum_position_encoding, d_model),
    # then a singleton feature axis so each window becomes a Conv1D input.
    y = rms_tf(x)
    y = y[..., tf.newaxis]
    print("CNN input shape is", y.shape)
    for i, n_filters in enumerate(layers):
        # Shortcut handling (only when layers_to_skip is set): the shortcut
        # is captured before the conv at non-zero multiples of layers_to_skip
        # (after the conv for i == 0, see below). It is added back before the
        # conv of every layer with i % layers_to_skip == layers_to_skip - 1.
        # NOTE(review): with stride-2 layers inside a period, y and y_shortcut
        # may have different shapes and Add() would fail. With
        # layers_to_skip == 1 the Add at i == 0 runs before y_shortcut is
        # assigned. Confirm the intended configurations.
        if layers_to_skip is not None and i % layers_to_skip == 0 and i != 0:
            y_shortcut = y
            print("Adding start of skip")
        if layers_to_skip is not None and i % layers_to_skip == (
                layers_to_skip - 1):
            y = keras.layers.Add()([y, y_shortcut])
            print("Adding end of skip")
        # Extended batch shape doesn't work, use TimeDistributed layer instead:
        # https://github.com/keras-team/keras/issues/14146
        y = keras.layers.TimeDistributed(conv_layer(
            filters=n_filters,
            kernel_size=kernel_size[i],
            strides=strides[i],
            padding='same',
        ))(y)
        if layers_to_skip is not None and i % layers_to_skip == 0 and i == 0:
            y_shortcut = y
            print("Adding start of skip")
        y = keras.layers.LeakyReLU(alpha)(y)
        print("Activation output shape", y.shape)
        # Dropout is only inserted for layers with a positive rate; early
        # layers drop whole feature maps (spatial), later ones single units.
        if dropout[i] > 0:
            if i < layers_with_spatial_dropout:
                y = keras.layers.TimeDistributed(
                    keras.layers.SpatialDropout1D(dropout[i], seed=seed)
                )(y)
            else:
                y = keras.layers.TimeDistributed(
                    keras.layers.Dropout(dropout[i], seed=seed)
                )(y)

    # Reduce each window to a feature vector. Flatten is used when no pooling
    # is requested, so the output is always one vector per window, as the
    # rest of the code expects.
    if pooling is None:
        y = keras.layers.TimeDistributed(keras.layers.Flatten())(y)
    elif pooling == 'average':
        y = keras.layers.TimeDistributed(
            keras.layers.GlobalAveragePooling1D()
        )(y)
    elif pooling == 'max':
        y = keras.layers.TimeDistributed(keras.layers.GlobalMaxPooling1D())(y)
    else:
        raise ValueError(f'unknown global pooling type: {pooling!r}')
    # Dense layer for y shape to be compatible with input shape of transformers
    y = keras.layers.TimeDistributed(keras.layers.Dense(d_model))(y)

    return keras.Model(x, y, name='embedding_cnn')


class SubmittedCNN(tf.keras.Model):
    """Subclassed-model version of the ``submitted_CNN`` windowing sub-model.

    Input (batch, maximum_position_encoding, d_model, n_channels): the RMS
    over the last (channel) axis is taken. Each window then passes through
    ``TimeDistributed`` 1-D conv -> LeakyReLU -> dropout layers, a per-window
    pooling/flatten, and a final per-window ``Dense(d_model)``.

    Args:
        maximum_position_encoding: Number of windows per input (stored only).
        d_model: Samples per window and size of the output embedding.
        n_channels: Number of input channels (stored only).
        separable: If truthy, use ``SeparableConv1D`` instead of ``Conv1D``.
        layers_with_spatial_dropout: Layers with an index below this value use
            ``SpatialDropout1D``; the rest use plain ``Dropout``.
        layers_to_skip: Period (>= 2) of the additive shortcut connections,
            or ``None`` for none.
        layers: Filters per conv layer; defaults to a 15-layer stack.
        kernel_size, strides, dropout: Per-layer values, each at least as long
            as ``layers``; default to the 15-layer configuration.
        alpha: Passed positionally to ``LeakyReLU``.
        pooling: ``None`` (flatten), ``'average'`` or ``'max'``.
        seed: Seed for the dropout layers.

    Raises:
        ValueError: Unknown ``pooling``, ``layers_to_skip`` < 2, or a
            per-layer list shorter than ``layers``.
    """

    def __init__(self, maximum_position_encoding, d_model, n_channels,
                 separable=0, layers_with_spatial_dropout=5,
                 layers_to_skip=None, layers=None, kernel_size=None,
                 strides=None, dropout=None, alpha=0.3, pooling=None, seed=7
                 ):
        super(SubmittedCNN, self).__init__()
        self.maximum_position_encoding = maximum_position_encoding
        self.d_model = d_model
        self.n_channels = n_channels

        # Fill in defaults, but honour caller-supplied values. Previously the
        # attributes were only assigned inside the ``is None`` branches, so
        # any explicit argument led to an AttributeError below.
        if dropout is None:
            dropout = [0, 0, 0.1, 0, 0, 0.1, 0, 0, 0.1, 0, 0, 0.1, 0, 0, 0.1]
        if strides is None:
            strides = [1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2]
        if layers is None:
            layers = [
                16, 16, 16, 32, 32, 32, 32, 32, 32, 64, 64, 64, 64, 64, 64
            ]
        if kernel_size is None:
            kernel_size = [3, 3, 24, 3, 3, 24, 3, 3, 24, 3, 3, 24, 3, 3, 48]

        # Shorter per-layer lists would either raise IndexError here or
        # silently drop conv layers through zip() in call().
        for arg_name, values in (("kernel_size", kernel_size),
                                 ("strides", strides),
                                 ("dropout", dropout)):
            if len(values) < len(layers):
                raise ValueError(
                    f"{arg_name} has {len(values)} entries but "
                    f"{len(layers)} layers were requested"
                )
        # A period of 0 divides by zero and a period of 1 adds the shortcut
        # at i == 0 before it exists; both used to crash inside call().
        if layers_to_skip is not None and layers_to_skip < 2:
            raise ValueError(
                f"layers_to_skip must be None or >= 2, got {layers_to_skip!r}"
            )

        # Attribute names and order kept stable (checkpoint compatibility).
        self.dropout = dropout
        self.strides = strides
        self.filters = layers
        self.kernel_size = kernel_size
        self.layers_to_skip = layers_to_skip

        if separable:
            conv_layer = tf.keras.layers.SeparableConv1D
        else:
            conv_layer = tf.keras.layers.Conv1D
        self.conv_layers = [
            tf.keras.layers.TimeDistributed(conv_layer(
                filters=n_filters,
                kernel_size=self.kernel_size[i],
                strides=self.strides[i],
                padding='same',
            ))
            for i, n_filters in enumerate(self.filters)
        ]
        self.activation_layer = tf.keras.layers.LeakyReLU(alpha)

        # A rate of 0 makes these layers an identity, matching the
        # functional model, which simply omits them.
        self.dropout_layers = [
            tf.keras.layers.TimeDistributed(
                tf.keras.layers.SpatialDropout1D(rate, seed=seed)
            )
            if i < layers_with_spatial_dropout
            else
            tf.keras.layers.TimeDistributed(
                tf.keras.layers.Dropout(rate, seed=seed)
            )
            for i, rate in enumerate(self.dropout)
        ]

        if pooling is None:
            self.pooling_layer = tf.keras.layers.TimeDistributed(
                tf.keras.layers.Flatten()
            )
        elif pooling == 'average':
            self.pooling_layer = tf.keras.layers.TimeDistributed(
                tf.keras.layers.GlobalAveragePooling1D()
            )
        elif pooling == 'max':
            self.pooling_layer = tf.keras.layers.TimeDistributed(
                tf.keras.layers.GlobalMaxPooling1D()
            )
        else:
            raise ValueError(f'unknown global pooling type: {pooling!r}')

        self.dense_layer = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Dense(d_model)
        )

    def call(self, inp, training):
        """Embed each window of ``inp``; returns (batch, windows, d_model)."""
        # RMS over the channel axis, then a singleton feature axis for Conv1D.
        y = rms_tf(inp)[..., tf.newaxis]
        skip = self.layers_to_skip
        y_shortcut = None
        for i, (conv_layer, dropout_layer) in enumerate(
                zip(self.conv_layers, self.dropout_layers)):
            if skip is not None and i % skip == 0 and i != 0:
                y_shortcut = y
                print("Adding start of skip")
            if skip is not None and i % skip == skip - 1:
                # Plain tensor add instead of creating a new Add layer on
                # every forward pass.
                y = tf.math.add(y, y_shortcut)
                print("Adding end of skip")

            y = conv_layer(y)

            if skip is not None and i == 0:
                # The first shortcut starts after the first conv, so the
                # channel count matches the conv outputs.
                y_shortcut = y
                print("Adding start of skip")

            y = self.activation_layer(y)
            y = dropout_layer(y, training=training)

        y = self.pooling_layer(y)
        return self.dense_layer(y)


class Embedding:
    """Adapter exposing a plain callable as an ``embed(x, training)`` method.

    ``embed`` checks that its input is rank 4 and the wrapped function's
    output is rank 3. It then returns the output as a tensor.
    """

    def __init__(self, embedding_function):
        super().__init__()
        self.embedding_function = embedding_function

    def embed(self, x, training):
        """Apply the wrapped function to ``x``.

        ``training`` is accepted for call compatibility with Keras layers and
        is not used.
        """
        assert len(x.shape) == 4
        # expected: (batch_size, maximum_position_encoding, embedding_size)
        embedded = self.embedding_function(x)
        assert len(embedded.shape) == 3
        return tf.convert_to_tensor(embedded)


class Encoder(tf.keras.layers.Layer):
    """Transformer encoder: embedding, positional encoding and encoder layers.

    Args:
        num_layers: Number of stacked ``EncoderLayer`` blocks.
        d_model: Model width; the embedding must produce this many features.
        num_heads: Attention heads per ``EncoderLayer``.
        dff: Feed-forward hidden width per ``EncoderLayer``.
        embedding_function: Either ``submitted_CNN``/``SubmittedCNN``
            (recognised by ``__name__`` and instantiated here), or any
            callable mapping a rank-4 input to a rank-3 output, which is
            wrapped by ``Embedding``.
        maximum_position_encoding: Length of the positional-encoding table and
            of the automatically generated padding mask.
        rate: Dropout rate.
        seed: Seed for the dropout layers and the CNN embedding.
        n_channels: NOTE(review): currently unused; the CNN embedding is
            always built with ``n_channels=None``.
    """

    def __init__(self, num_layers, d_model, num_heads, dff, embedding_function,
                 maximum_position_encoding, rate=0.1, seed=7, n_channels=12):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers
        self.maximum_position_encoding = maximum_position_encoding

        is_cnn = embedding_function.__name__ in ["submitted_CNN", "SubmittedCNN"]
        if is_cnn:
            # Build the CNN here; n_channels=None leaves its channel
            # dimension unspecified.
            self.embedding = embedding_function(
                maximum_position_encoding, d_model, n_channels=None, seed=seed
            )
        else:
            # Wrap the plain function so it can be called as
            # self.embedding(x, training=...); here d_model = embedding size.
            self.embedding = Embedding(embedding_function).embed

        self.pos_encoding = positional_encoding(
            maximum_position_encoding, self.d_model
        )

        self.enc_layers = [
            EncoderLayer(d_model, num_heads, dff, rate, seed)
            for _ in range(num_layers)
        ]

        self.dropout = tf.keras.layers.Dropout(rate, seed=seed)

    def call(self, x, training, mask=None):
        """Encode ``x``; returns (output, attention_weights dict).

        ``output`` has shape (batch_size, input_seq_len, d_model).
        ``attention_weights`` maps ``'encoder_layer<k>'`` (1-based) to each
        layer's weights. When ``mask`` is None, one is built from ``x``.
        """
        if mask is None:
            print("No padding mask input at encoder, creating padding mask...")
            mask = create_padding_mask(
                x, self.maximum_position_encoding, tf_version=True
            )

        assert len(x.shape) == 4
        x = self.embedding(x, training=training)  # (batch_size, input_seq_len, d_model)

        # Scale the embeddings by sqrt(d_model) before adding the positional
        # encoding, as in "Attention Is All You Need" (Sec. 3.4) and the
        # TensorFlow Transformer tutorial this code follows.
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        seq_len = tf.shape(x)[1]
        x += self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x, training=training)

        attention_weights = {}
        for layer_idx in range(self.num_layers):
            x, layer_weights = self.enc_layers[layer_idx](x, training, mask)
            attention_weights['encoder_layer{}'.format(layer_idx + 1)] = layer_weights

        return x, attention_weights


class TransformerEncoder(tf.keras.Model):
    """Encoder-only Transformer followed by a Dense head on the flattened output.

    Args:
        num_layers, d_model, num_heads, dff, embedding_function,
        maximum_position_encoding, rate, seed: Forwarded to ``Encoder``.

    ``final_layer`` (initially ``Dense(d_model)``) is applied to the encoder
    output flattened to (batch_size, input_seq_len * d_model).
    """

    def __init__(self, num_layers, d_model, num_heads, dff, embedding_function,
                 maximum_position_encoding, rate=0.1, seed=7):
        super(TransformerEncoder, self).__init__()

        self.encoder = Encoder(
            num_layers, d_model, num_heads, dff, embedding_function,
            maximum_position_encoding, rate, seed
        )

        self.final_layer = tf.keras.layers.Dense(d_model)

    def call(self, inp, training, enc_padding_mask=None):
        """Run the encoder and the final layer.

        Args:
            inp: Rank-4 input, or an ``(input, padding_mask)`` tuple/list.
            training: Keras training flag, forwarded to the encoder.
            enc_padding_mask: Optional padding mask; when None the encoder
                builds one itself.

        Returns:
            (final_output, attention_weights), where final_output has shape
            (batch_size, units of ``final_layer``).
        """
        # Only unpack genuine (input, mask) pairs. A bare ``len(inp) == 2``
        # also matched a tensor with batch size 2 and split that batch into
        # a fake input and mask.
        if isinstance(inp, (tuple, list)) and len(inp) == 2:
            inp, enc_padding_mask = inp
        assert len(inp.shape) == 4
        enc_output, attention_weights = self.encoder(
            inp, training, enc_padding_mask
        )
        # Flatten (batch, seq_len, d_model) -> (batch, seq_len * d_model).
        # The dynamic batch size keeps this working when the static batch
        # dimension is unknown (None), e.g. while tracing a graph.
        batch_size = tf.shape(enc_output)[0]
        flat_dim = enc_output.shape[1] * enc_output.shape[2]
        enc_output = tf.reshape(enc_output, (batch_size, flat_dim))
        final_output = self.final_layer(enc_output)

        return final_output, attention_weights


class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Transformer learning-rate schedule (Vaswani et al., 2017, Eq. 3).

    lr(step) = d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

    The rate grows linearly for ``warmup_steps`` steps and then decays with
    the inverse square root of the step.

    Args:
        d_model: Model width used to scale the rate.
        warmup_steps: Number of linear warm-up steps.
    """

    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()

        # Raw value kept for get_config(); self.d_model is the float32 tensor
        # used in the computation.
        self._d_model_config = d_model
        self.d_model = tf.cast(d_model, tf.float32)

        self.warmup_steps = warmup_steps

    def __call__(self, step):
        # Newer Keras optimizers pass the integer iteration counter directly,
        # and rsqrt is only defined for floating-point tensors.
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

    def get_config(self):
        """Return the constructor arguments so the schedule can be serialized."""
        return {
            "d_model": self._d_model_config,
            "warmup_steps": self.warmup_steps,
        }


class MSEAccuracy(tf.keras.metrics.Metric):
    """MSE-accuracy for Keras.

    Accuracy metric based on a threshold on the MSE for each piece of time
    series. It reports the fraction of samples whose mean squared error
    (averaged over the last axis) is strictly below ``thr``.

    Args:
        thr: MSE threshold.
        name: Metric name prefix; the final name is ``"<name>_<thr>"``.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.
    """

    def __init__(self,
                 thr,
                 name='mse_accuracy',
                 **kwargs):
        super().__init__(name="_".join([name, str(thr)]), **kwargs)
        self.threshold = thr
        # reduction="none" keeps one MSE value per sample instead of a single
        # batch mean, so each sample is judged against the threshold.
        self.mse = tf.keras.losses.MeanSquaredError(reduction="none")
        # Counters live in metric variables, not Python ints, so they
        # accumulate correctly when update_state runs inside a tf.function.
        self.correct_points = self.add_weight(
            name="correct_points", initializer="zeros", dtype="float32"
        )
        self.total_points = self.add_weight(
            name="total_points", initializer="zeros", dtype="float32"
        )

    def update_state(self, references, predictions, sample_weight=None):
        """Accumulate correct/total sample counts for one batch.

        ``sample_weight`` is accepted for API compatibility and ignored.
        """
        mse = self.mse(references, predictions)
        correct = tf.cast(mse < self.threshold, tf.float32)
        self.correct_points.assign_add(tf.reduce_sum(correct))
        self.total_points.assign_add(tf.cast(tf.size(mse), tf.float32))

    def reset_state(self):
        """Zero both counters."""
        self.correct_points.assign(0.0)
        self.total_points.assign(0.0)

    def reset_states(self):
        """Alias for callers/Keras versions that use ``reset_states``."""
        self.reset_state()

    def result(self):
        """Return correct / total, or 0 before any sample has been seen."""
        return tf.math.divide_no_nan(self.correct_points, self.total_points)


def loss_function(real, pred, loss_object):
    """Sum ``loss_object(real, pred)`` and divide by the batch size of ``real``.

    The batch size is read dynamically from ``real``'s first axis and cast
    to the loss dtype before dividing.
    """
    per_item_loss = loss_object(real, pred)
    total = tf.reduce_sum(per_item_loss)
    batch_size = tf.cast(tf.shape(real)[0], dtype=per_item_loss.dtype)
    return total / batch_size


def split_targets(data, shuffle=True):
    """Expand recordings into (prefix, next beat) training pairs.

    For each recording ``r`` and each ``i`` in ``1 .. len(r) - 2``, the input
    is ``r[:i]`` and the target is ``r[i]``. Index 0 has no prefix. The last
    beat is skipped on purpose: the beat detection often fails there and
    merges two beats into the last one.

    Args:
        data: Iterable of recordings (sliceable sequences of beats).
        shuffle: If True, shuffle inputs and targets with the same permutation.

    Returns:
        (expanded_data, expanded_targets) as two parallel lists.
    """
    pairs = [
        (recording[:i], recording[i])
        for recording in data
        for i in range(1, len(recording) - 1)
    ]
    expanded_data = [prefix for prefix, _ in pairs]
    expanded_targets = [target for _, target in pairs]
    if shuffle:
        shuffle_in_unison(expanded_data, expanded_targets)
    return expanded_data, expanded_targets


def shuffle_in_unison(a, b):
    """Shuffle ``a`` and ``b`` in place with the same permutation.

    Uses numpy's global RNG. The state is saved, ``a`` is shuffled, the state
    is restored and ``b`` is shuffled, so equal-length inputs receive
    identical permutations.

    Raises:
        ValueError: If ``len(a) != len(b)``. The permutations would then
            differ and silently misalign the pairs.
    """
    if len(a) != len(b):
        raise ValueError(
            f"cannot shuffle in unison: len(a)={len(a)} != len(b)={len(b)}"
        )
    rng_state = np.random.get_state()
    np.random.shuffle(a)
    np.random.set_state(rng_state)
    np.random.shuffle(b)


def split_batches(data, targets, batch_size):
    """Group parallel ``data``/``targets`` into batches of at most ``batch_size``.

    Items are taken from the end backwards: the last item is the first entry
    of the first batch, matching the original pop()-based order. The final
    batch may be smaller. The input lists are no longer emptied.

    Args:
        data: Sequence of inputs.
        targets: Sequence of targets with at least ``len(data)`` items;
            items are paired from the end.
        batch_size: Maximum items per batch (>= 1).

    Returns:
        List of ``(batch_data, batch_targets)`` tuples of lists.

    Raises:
        ValueError: If ``batch_size`` < 1 (the old loop never terminated) or
            ``targets`` is shorter than ``data``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    n_items = len(data)
    if len(targets) < n_items:
        raise ValueError("targets must have at least as many items as data")

    backwards_data = list(reversed(data))
    backwards_targets = list(reversed(targets))[:n_items]
    return [
        (backwards_data[start:start + batch_size],
         backwards_targets[start:start + batch_size])
        for start in range(0, n_items, batch_size)
    ]


def split_batches_2(data, targets, mask, batch_size):
    """Group parallel ``data``/``targets``/``mask`` sequences into batches.

    Items are taken from the end backwards: the last item is the first entry
    of the first batch, matching the original pop()-based order. The final
    batch may be smaller. The input lists are no longer emptied.

    Args:
        data: Sequence of inputs.
        targets, mask: Sequences with at least ``len(data)`` items each;
            items are paired from the end.
        batch_size: Maximum items per batch (>= 1).

    Returns:
        ``[data_batches, target_batches, mask_batches]``, each a numpy array
        of batches. When the batches are ragged (last one shorter), each is a
        1-D object array whose elements are the per-batch lists.

    Raises:
        ValueError: If ``batch_size`` < 1 (the old loop never terminated) or
            ``targets``/``mask`` are shorter than ``data``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    n_items = len(data)
    if len(targets) < n_items or len(mask) < n_items:
        raise ValueError("targets and mask need at least as many items as data")

    def _batched(items):
        # Chunk the last n_items items, last item first.
        backwards = list(reversed(items))[:n_items]
        return [backwards[start:start + batch_size]
                for start in range(0, n_items, batch_size)]

    def _to_array(batches):
        # Convert to an ndarray, falling back to an object array when ragged.
        try:
            return np.array(batches)
        except ValueError:
            # NumPy >= 1.24 refuses to infer an object array from ragged
            # nested lists, so build a 1-D object array explicitly.
            out = np.empty(len(batches), dtype=object)
            for idx, batch in enumerate(batches):
                out[idx] = batch
            return out

    return [_to_array(_batched(data)),
            _to_array(_batched(targets)),
            _to_array(_batched(mask))]


def share_memory():
    """Create a TF1-compat ``InteractiveSession`` with custom GPU options.

    ``allow_growth`` is enabled and ``per_process_gpu_memory_fraction`` is
    set to 2. NOTE(review): TensorFlow documents fractions above 1 as
    enabling unified memory (GPU memory oversubscription); confirm this is
    intended. ``InteractiveSession`` installs itself as the default session
    when constructed, so nothing is returned.
    """
    from tensorflow.compat.v1 import ConfigProto, InteractiveSession

    config = ConfigProto()
    gpu_options = config.gpu_options
    gpu_options.per_process_gpu_memory_fraction = 2
    gpu_options.allow_growth = True
    _session = InteractiveSession(config=config)


def prepare_transformer(
        transformer, layers_to_freeze, layers_to_remove, n_classes,
        embedding_function
):
    """Adapt a ``TransformerEncoder`` for classification, in place.

    Steps:
      1. Replace ``final_layer`` with ``Dense(n_classes, activation="sigmoid")``.
      2. If the embedding is one of the CNN embeddings, mark it trainable.
      3. If ``layers_to_remove`` is not None, drop that many encoder layers
         from the end and update ``num_layers``.
      4. If ``layers_to_freeze`` is not None, freeze that many encoder
         layers from the start and print every layer's trainable flag.

    Returns:
        The same ``transformer`` object.
    """
    transformer.final_layer = tf.keras.layers.Dense(
        n_classes, activation="sigmoid"
    )
    print("Final layer replaced")

    cnn_names = ["submitted_CNN", "SubmittedCNN"]
    if embedding_function.__name__ in cnn_names:
        transformer.encoder.embedding.trainable = True

    if layers_to_remove is not None:
        encoder = transformer.encoder
        n_keep = len(encoder.enc_layers) - layers_to_remove
        encoder.enc_layers = encoder.enc_layers[:n_keep]
        encoder.num_layers = len(encoder.enc_layers)
        print("Number of encoder layers", len(encoder.enc_layers))

    if layers_to_freeze is not None:
        encoder = transformer.encoder
        for idx in range(layers_to_freeze):
            encoder.enc_layers[idx].trainable = False
        for enc_layer in encoder.enc_layers:
            print(enc_layer.name, "trainable", enc_layer.trainable)

    return transformer
