import torch
import torch.nn as nn
import torch.nn.functional as F


class MyResidualBlock(nn.Module):
    """Three-convolution residual block with batch norm, leaky ReLU and dropout.

    The first convolution maps ``Fin`` to ``Fout`` channels and is the only
    one that receives ``stride``. Its activated output is both the input of
    the two following convolutions and the skip connection that is added
    back before the final activation.

    Args:
        Fin: Number of input channels.
        Fout: Number of output channels of every convolution in the block.
        kernel_size: Kernel size passed to all three convolutions.
        padding: Padding passed to all three convolutions.
        stride: Stride of the first convolution only.
        bias: Whether the convolutions carry a bias term.
        dropout: Dropout probability applied to the block output.
    """

    def __init__(self, Fin, Fout, kernel_size, padding, stride, bias, dropout):
        super().__init__()
        self.dropout = dropout

        # Settings common to the three convolutions; only conv_0 is strided.
        shared = dict(kernel_size=kernel_size, padding=padding, bias=bias)

        self.conv_0 = nn.Conv2d(Fin, Fout, stride=stride, **shared)
        self.bn_0 = nn.BatchNorm2d(Fout)

        self.conv_1 = nn.Conv2d(Fout, Fout, **shared)
        self.bn_1 = nn.BatchNorm2d(Fout)

        self.conv_2 = nn.Conv2d(Fout, Fout, **shared)
        self.bn_2 = nn.BatchNorm2d(Fout)

    def forward(self, x):
        """Run the block on ``x`` and return the activated, dropped-out result."""
        # Projection stage: its output doubles as the skip connection.
        skip = F.leaky_relu(self.bn_0(self.conv_0(x)))

        # Residual branch: conv -> bn -> leaky ReLU -> conv -> bn.
        branch = self.conv_1(skip)
        branch = self.bn_1(branch)
        branch = F.leaky_relu(branch)
        branch = self.bn_2(self.conv_2(branch))

        activated = F.leaky_relu(branch + skip)
        # Dropout is only active while the module is in training mode.
        return F.dropout(activated, p=self.dropout, training=self.training)



class NN(nn.Module):
    """Residual-CNN + bidirectional GRU model with attention pooling.

    Five ``MyResidualBlock`` stages (kernel ``(1, 3)``, padding ``(0, 1)``,
    stride ``(1, 2)``) convolve along the last axis only and raise the channel
    count to 256. The resulting feature map is read as a sequence by a
    two-layer bidirectional GRU (2 x 128 = 256 features per step), collapsed
    to one vector per sample by a learned attention weighting, and projected
    to ``nOUT`` values by a linear layer.

    Args:
        device: Device the module is moved to at the end of construction
            (passed to ``self.to``).
        nOUT: Number of output features of the final linear layer.
        Fin: Number of input channels expected by the first residual block.
            Defaults to 14, the value that was previously hard-coded.
    """

    def __init__(self, device, nOUT, Fin=14):
        super().__init__()
        self.Fin = Fin

        # Every residual stage shares these settings; only the channel
        # counts differ from stage to stage.
        block_cfg = dict(kernel_size=(1, 3), padding=(0, 1), stride=(1, 2),
                         bias=False, dropout=0.5)
        self.rb_0 = MyResidualBlock(Fin=self.Fin, Fout=32, **block_cfg)
        self.rb_1 = MyResidualBlock(Fin=32, Fout=64, **block_cfg)
        self.rb_2 = MyResidualBlock(Fin=64, Fout=128, **block_cfg)
        self.rb_3 = MyResidualBlock(Fin=128, Fout=128, **block_cfg)
        self.rb_4 = MyResidualBlock(Fin=128, Fout=256, **block_cfg)

        # Bidirectional with hidden size 128 -> 256 output features per step.
        self.gru = nn.GRU(256, 128,
                          num_layers=2,
                          batch_first=True,
                          bidirectional=True,
                          dropout=0.5)

        # Attention layers: score each GRU step with a small two-layer MLP.
        self.attn0 = nn.Linear(256, 128)
        self.attn1 = nn.Linear(128, 1)

        self.fc_1 = nn.Linear(256, nOUT)

        self.to(device)

    def forward(self, x):
        """Compute the model output for a batch.

        Args:
            x: 4-D input whose channel axis (dim 1) has ``Fin`` entries.
                Dim 2 must have size 1: the convolutions leave that axis
                untouched, and it is squeezed away before the GRU. Any other
                size makes the 3-axis ``permute`` below fail.

        Returns:
            Tensor of shape ``(batch, nOUT)``: the raw output of the final
            linear layer (no activation is applied).
        """
        x = self.rb_0(x)
        x = self.rb_1(x)
        x = self.rb_2(x)
        x = self.rb_3(x)
        x = self.rb_4(x)

        # (batch, 256, 1, steps) -> (batch, steps, 256) for the batch_first GRU.
        x = x.squeeze(2).permute(0, 2, 1)

        x, _ = self.gru(x)

        # Attention mechanism: one scalar score per step, normalised over the
        # step axis (dim 1), then used to form a weighted sum of GRU outputs.
        a = torch.tanh(self.attn0(x))
        a = F.softmax(self.attn1(a), dim=1)
        x = torch.sum(x * a, dim=1)

        x = self.fc_1(x)

        return x