# from keras.engine.topology import Layer
from tensorflow.keras.layers import Layer
import tensorflow as tf


class DistanceLayer(Layer):
    """
    Custom layer comparing each input time series against a bank of shapelets.

    For every shapelet, the input series is cut into consecutive windows of
    the shapelet's length. Each window is compared element-wise with the
    shapelet: the squared difference is computed and divided by the shapelet
    length. No reduction (sum / min) over the window is performed; the
    per-position values are returned so downstream layers can aggregate them.

    Input:  rank-3 tensor ``(batch, tl, nc)`` — presumably
            (batch, time steps, channels); TODO confirm against callers.
    Output: ``(batch, n_windows * ls, nc * output_dim)`` where ``ls`` is the
            shapelet length and ``n_windows = (tl - ls) // ls + 1``.

    Args:
        output_dim: number of shapelets; first dimension of the trainable
            kernel.
        shapelets: initial shapelet values. ``shapelets.shape[1]`` is used as
            the shapelet length and ``shapelets.shape[2]`` as the channel
            count. The attribute is deleted once ``build`` has created the
            kernel.
        stride: window stride. NOTE(review): ``call`` overrides this with the
            shapelet length (non-overlapping windows), so the value passed
            here currently has no effect on the output.
    """

    def __init__(self, output_dim, shapelets, stride, **kwargs):
        self.output_dim = output_dim
        self.shapelets = shapelets
        self.stride = stride
        super(DistanceLayer, self).__init__(**kwargs)

    def get_shapelets(self, shape, dtype=None):
        """Keras initializer returning the user-supplied shapelets as a tensor.

        ``shape`` is ignored; the kernel shape is taken from ``self.shapelets``.
        Only valid before ``build`` deletes ``self.shapelets``.
        """
        return tf.convert_to_tensor(self.shapelets, dtype=dtype)

    def _shapelet_length(self):
        """Return the shapelet length from the kernel if built, else from the initial shapelets."""
        if hasattr(self, 'kernel'):
            return int(self.kernel.shape[1])
        return int(self.shapelets.shape[1])

    def build(self, input_shape):
        """Create the trainable shapelet kernel, initialised from ``shapelets``."""
        self.kernel = self.add_weight(name='kernel',
                                      shape=(self.output_dim, self.shapelets.shape[1], self.shapelets.shape[2]),
                                      initializer=self.get_shapelets,
                                      trainable=True)
        # The initial values now live in the kernel; drop the original array.
        del self.shapelets
        super(DistanceLayer, self).build(input_shape)

    def call(self, data):
        """Return per-position normalised squared differences to every shapelet.

        Args:
            data: rank-3 tensor ``(batch, tl, nc)`` with static ``tl`` and ``nc``.

        Returns:
            Tensor of shape ``(batch, n_windows * ls, nc * ns)``.
        """
        tl, nc = data.shape[1], data.shape[2]
        ns, ls = self.kernel.shape[0], self.kernel.shape[1]
        # NOTE(review): the stride passed to __init__ is overridden with the
        # shapelet length, so windows never overlap. Kept to preserve the
        # existing output shape; revisit if overlapping windows were intended.
        self.stride = ls
        n_windows = (tl - ls) // self.stride + 1
        batch = tf.shape(data)[0]

        # View each series as a 1 x tl "image" with nc channels so that
        # extract_patches can slide a 1 x ls window along the time axis.
        images = tf.reshape(data, (-1, 1, tl, nc))
        patches = tf.image.extract_patches(images, [1, 1, ls, 1], [1, 1, self.stride, 1], [1, 1, 1, 1],
                                           padding='VALID')
        # Each patch is flattened as (ls, nc); unflatten to (batch, n_windows, ls, nc).
        windows = tf.reshape(patches, (-1, n_windows, ls, nc))

        # Broadcast (batch, n_windows, ls, nc) against (ns, 1, 1, ls, nc)
        # -> (ns, batch, n_windows, ls, nc).
        diffs = tf.math.squared_difference(windows, tf.reshape(self.kernel, [ns, 1, 1, ls, nc]))
        diffs = tf.reshape(diffs, [ns, batch, n_windows * ls, nc])

        # Move the shapelet axis last and merge it with the channel axis.
        out = tf.reshape(tf.transpose(diffs, perm=[1, 2, 3, 0]), [batch, n_windows * ls, nc * ns])

        # Normalise by the shapelet length.
        return out / ls

    def compute_output_shape(self, input_shape):
        """Return ``(batch, n_windows * ls, nc * output_dim)``, matching ``call``."""
        input_shape = tf.TensorShape(input_shape)
        tl, nc = input_shape[1], input_shape[2]
        ls = self._shapelet_length()
        # call() always uses a stride equal to the shapelet length.
        time_dim = None if tl is None else ((tl - ls) // ls + 1) * ls
        channel_dim = None if nc is None else nc * self.output_dim
        return (input_shape[0], time_dim, channel_dim)
