import torch

class Attention(torch.nn.Module):
    """Bidirectional-LSTM classifier with an element-wise gating step.

    Per-beat features (FFT + peak-delta) are concatenated with per-sample
    static features (header + peak), linearly embedded, run through a
    bidirectional LSTM, gated element-wise, averaged over the beat axis and
    projected to ``num_classes`` sigmoid outputs.

    Required ``hparams`` keys: ``'lstm_input_dim'``, ``'lstm_hidden_dim'``,
    ``'lstm_num_layers'``, ``'num_classes'``, ``'dropout'``.
    Optional key: ``'use_layer_norm'`` -- layer normalisation is applied only
    when the key is present AND its value is truthy.
    """

    def __init__(self,
                 dim_header,
                 dim_peak,
                 dim_peak_delta,
                 dim_fft,
                 device=None,
                 hparams=None):
        """Build all sub-modules.

        Args:
            dim_header: Width of ``x['header_feats']`` (last axis).
            dim_peak: Width of ``x['peak_feats']`` (last axis).
            dim_peak_delta: Width of ``x['peak_delta_feats']`` (last axis).
            dim_fft: Width of ``x['fft_feats']`` (last axis).
            device: Stored on the instance as ``self.device``; not otherwise
                used by this class.
            hparams: Mapping of hyper-parameters (see class docstring).
                ``None`` is treated as an empty mapping.

        Raises:
            KeyError: If a required hyper-parameter is missing.
        """
        super().__init__()
        # A fresh dict per instance: a `{}` default in the signature would be
        # one object shared by every instance that omits the argument.
        if hparams is None:
            hparams = {}
        self.hparams = hparams
        self.device = device

        input_dim = hparams['lstm_input_dim']
        hidden_dim = hparams['lstm_hidden_dim']
        # The LSTM is bidirectional, so its output width is doubled.
        lstm_out_dim = 2 * hidden_dim

        # NOTE: sub-modules are created in a fixed order and under fixed
        # attribute names so that parameter ordering, state_dict keys and
        # seeded initialisation stay stable.

        # Not used by forward(); retained so existing state_dicts still load.
        self.pre_attn_merge = torch.nn.Linear(int(dim_fft / 12),
                                              int(dim_fft / 12))

        # Width of one beat after concatenating dynamic and static features.
        self.beat_dim = dim_fft + dim_peak_delta + dim_peak + dim_header

        self.beat_embed = torch.nn.Linear(self.beat_dim, input_dim, bias=True)
        self.layer_norm_embed = torch.nn.LayerNorm(input_dim)

        self.lstm = torch.nn.LSTM(input_dim,
                                  hidden_dim,
                                  hparams['lstm_num_layers'],
                                  batch_first=True,
                                  bidirectional=True)

        self.attn_merge = torch.nn.Linear(lstm_out_dim, lstm_out_dim)
        self.post_lstm_linear = torch.nn.Linear(lstm_out_dim, lstm_out_dim)
        self.attn_linear = torch.nn.Linear(lstm_out_dim, lstm_out_dim)
        self.layer_norm_attn = torch.nn.LayerNorm(lstm_out_dim)

        self.linear_final = torch.nn.Linear(lstm_out_dim,
                                            hparams['num_classes'],
                                            bias=True)

        self.dropout = torch.nn.Dropout(hparams['dropout'])

    def forward(self, x):
        """Compute per-class sigmoid scores for a batch.

        Args:
            x: Mapping with tensor entries
                ``'header_feats'``     -- [batch_size, dim_header]
                ``'peak_feats'``       -- [batch_size, dim_peak]
                ``'fft_feats'``        -- [batch_size, n_beats, dim_fft]
                ``'peak_delta_feats'`` -- [batch_size, n_beats, dim_peak_delta]

        Returns:
            Tensor of shape [batch_size, num_classes]; each entry is the
            output of a sigmoid.
        """
        # Layer norm is opt-in: the flag must be present and truthy, so
        # {'use_layer_norm': False} disables it rather than enabling it.
        use_layer_norm = ('use_layer_norm' in self.hparams
                          and bool(self.hparams['use_layer_norm']))

        # shape = [batch_size, dim_header + dim_peak]
        static_feats = torch.cat([x['header_feats'], x['peak_feats']],
                                 dim=-1)

        fft_feats = x['fft_feats']
        batch_size = fft_feats.shape[0]
        n_beats = fft_feats.shape[1]

        # shape = [batch_size, n_beats, dim_fft + dim_peak_delta]
        dynamic_feats = torch.cat([fft_feats, x['peak_delta_feats']], dim=-1)

        # Broadcast the per-sample static features to every beat.
        static_feats = static_feats[:, None, :].expand(batch_size, n_beats, -1)

        # shape = [batch_size, n_beats, beat_dim]
        beat_feats = torch.cat([dynamic_feats, static_feats], dim=-1)

        # shape = [batch_size, n_beats, lstm_input_dim]
        embeds = self.beat_embed(beat_feats)
        embeds = self.dropout(embeds)
        if use_layer_norm:
            embeds = self.layer_norm_embed(embeds)
        embeds = torch.tanh(embeds)

        # shape = [batch_size, n_beats, 2 * lstm_hidden_dim]
        outputs, _ = self.lstm(embeds)
        if use_layer_norm:
            outputs = self.layer_norm_attn(outputs)

        # NOTE(review): `lstm_outputs` is never consumed below. The branch is
        # kept because its dropout call draws from the RNG in training mode,
        # so deleting it would alter seeded training runs. Either wire it
        # into the gating product or remove it deliberately.
        lstm_outputs = self.post_lstm_linear(torch.tanh(outputs))
        lstm_outputs = self.dropout(lstm_outputs)

        outputs = self.dropout(outputs)
        outputs = torch.tanh(outputs)
        outputs = self.attn_linear(outputs)
        outputs = torch.tanh(outputs)

        # Element-wise gate in (-1, 1) computed from the projected outputs.
        alphas = torch.tanh(self.attn_merge(outputs))
        outputs = alphas * outputs

        # Pool over beats with a plain mean.
        # shape = [batch_size, 2 * lstm_hidden_dim]
        outputs = torch.mean(outputs, dim=1)

        # shape = [batch_size, num_classes]
        outputs = self.linear_final(outputs)
        return torch.sigmoid(outputs)