#%%
from src.engine.models.base import Classifier
import torch
import torch.nn as nn
import torch.nn.functional as F

#%% ResNet1d related
class MyConv1dPadSame(nn.Module):
    """
    nn.Conv1d wrapper that zero-pads the last axis before convolving ("SAME" padding).

    The amount of padding is derived from the input length on every call so
    that the output length is ceil(input_length / stride).

    Args:
        in_channels: number of channels of the input
        out_channels: number of channels produced by the convolution
        kernel_size: width of the convolution kernel
        stride: step of the convolution
        groups: forwarded to nn.Conv1d (default 1)
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        # The wrapped convolution itself is unpadded; forward() pads explicitly.
        self.conv = nn.Conv1d(
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            stride=self.stride,
            groups=self.groups,
        )

    def forward(self, x):
        """Zero-pad the last axis of ``x`` and apply the convolution."""
        length = x.shape[-1]

        # Target output length: ceil(length / stride).
        whole, remainder = divmod(length, self.stride)
        target_len = whole + (1 if remainder else 0)

        # Total padding needed for the last kernel window to fit.
        total_pad = max(0, (target_len - 1) * self.stride + self.kernel_size - length)

        # An odd amount of padding puts the extra element on the right.
        left = total_pad // 2
        right = total_pad - left

        return self.conv(F.pad(x, (left, right), mode="constant", value=0))
        
class MyMaxPool1dPadSame(nn.Module):
    """
    nn.MaxPool1d wrapper that zero-pads the last axis before pooling.

    The wrapped nn.MaxPool1d is built with ``kernel_size`` only.  The padding
    arithmetic uses ``self.stride``, which the constructor fixes at 1, so
    ``kernel_size - 1`` zeros are added (split left/right) and the pooled
    length comes out as ceil(input_length / kernel_size).

    Note that the padding value is 0, not -inf: on an edge window whose real
    values are all negative, the padded zero is what the pooling returns.

    Args:
        kernel_size: width of the pooling window
    """
    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size
        # Used only by the padding computation in forward().
        self.stride = 1
        self.max_pool = nn.MaxPool1d(kernel_size=self.kernel_size)

    def forward(self, x):
        """Zero-pad the last axis of ``x`` and apply max pooling."""
        length = x.shape[-1]

        # ceil(length / stride) with integer arithmetic.
        whole, remainder = divmod(length, self.stride)
        target_len = whole + (1 if remainder else 0)

        total_pad = max(0, (target_len - 1) * self.stride + self.kernel_size - length)

        # An odd amount of padding puts the extra element on the right.
        left = total_pad // 2
        right = total_pad - left

        return self.max_pool(F.pad(x, (left, right), mode="constant", value=0))

class SE_Block(nn.Module):
    """
    Squeeze-and-excitation gate for (batch, channel, length) tensors.

    credits: https://github.com/moskomule/senet.pytorch/blob/master/senet/se_module.py#L4

    Args:
        c: number of channels of the input
        r: reduction ratio; the bottleneck layer has ``c // r`` units
    """
    def __init__(self, c, r=16):
        super().__init__()
        # Squeeze: average every channel over the length axis.
        self.squeeze = nn.AdaptiveAvgPool1d(1)
        # Excitation: bias-free bottleneck MLP ending in a sigmoid gate.
        bottleneck = c // r
        layers = [
            nn.Linear(c, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, c, bias=False),
            nn.Sigmoid(),
        ]
        self.excitation = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` with every channel scaled by its learned gate."""
        # Unpacking requires exactly three dimensions.
        n_batch, n_chan, _length = x.shape
        pooled = self.squeeze(x).flatten(1)      # (n_batch, n_chan)
        gate = self.excitation(pooled)           # (n_batch, n_chan)
        gate = gate.unsqueeze(-1)                # (n_batch, n_chan, 1)
        return x * gate.expand_as(x)

class BasicBlock(nn.Module):
    """
    Residual block built from two SAME-padded 1-D convolutions.

    Each convolution is preceded by (optional) batch norm, ReLU and
    (optional) dropout; for the first block of a network that pre-activation
    is skipped in front of the first convolution.  The shortcut branch is
    max-pooled when the block downsamples and zero-padded along the channel
    axis when the block changes the channel count.

    Args:
        in_channels: channels of the block input
        out_channels: channels of the block output
        kernel_size: kernel width of both convolutions
        stride: stride of the first convolution and pooling width of the
            shortcut; only honoured when ``downsample`` is truthy
        groups: ``groups`` argument of both convolutions
        downsample: whether this block reduces the length
        use_bn: apply the batch-norm layers
        use_do: apply the dropout layers (p=0.5)
        is_first_block: skip the pre-activation in front of the first conv
        is_se: append a squeeze-and-excitation gate (reduction 16) after the
            second convolution
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, 
                groups, downsample, use_bn, use_do, is_first_block=False, is_se=False):
        super().__init__()

        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.groups = groups
        self.downsample = downsample
        # A block that does not downsample ignores the requested stride.
        self.stride = stride if self.downsample else 1
        self.is_first_block = is_first_block
        self.is_se = is_se
        self.use_bn = use_bn
        self.use_do = use_do

        # stage 1: (bn -> relu -> dropout) -> conv, possibly strided
        self.bn1 = nn.BatchNorm1d(in_channels)
        self.relu1 = nn.ReLU()
        self.do1 = nn.Dropout(p=0.5)
        self.conv1 = MyConv1dPadSame(
            in_channels, out_channels, kernel_size, self.stride, groups=self.groups)

        # stage 2: bn -> relu -> dropout -> conv, always stride 1
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu2 = nn.ReLU()
        self.do2 = nn.Dropout(p=0.5)
        self.conv2 = MyConv1dPadSame(
            out_channels, out_channels, kernel_size, 1, groups=self.groups)

        # Pooling applied to the shortcut so its length matches the strided conv.
        self.max_pool = MyMaxPool1dPadSame(kernel_size=self.stride)

        # Optional squeeze-and-excitation gate
        if self.is_se:
            self.se = SE_Block(out_channels, 16)

    def _pre_activate(self, tensor, bn, relu, dropout):
        """Run the optional batch norm, the ReLU and the optional dropout."""
        if self.use_bn:
            tensor = bn(tensor)
        tensor = relu(tensor)
        if self.use_do:
            tensor = dropout(tensor)
        return tensor

    def forward(self, x):
        """Return the residual branch of ``x`` plus its (adapted) shortcut."""
        # residual branch
        out = x
        if not self.is_first_block:
            out = self._pre_activate(out, self.bn1, self.relu1, self.do1)
        out = self.conv1(out)
        out = self._pre_activate(out, self.bn2, self.relu2, self.do2)
        out = self.conv2(out)

        # shortcut branch: match the length of the residual branch
        shortcut = x
        if self.downsample:
            shortcut = self.max_pool(shortcut)

        # shortcut branch: match the channel count by zero-padding dim -2,
        # with the odd extra channel going to the end
        extra = self.out_channels - self.in_channels
        if extra != 0:
            front = extra // 2
            back = extra - front
            shortcut = F.pad(shortcut, (0, 0, front, back), "constant", 0)

        # Squeeze and excitation gate on the residual branch only
        if self.is_se:
            out = self.se(out)

        # in-place residual sum
        out += shortcut
        return out

class ResNetArch(nn.Module):
    """
    1-D ResNet feature extractor: a conv/bn/relu/max-pool stem, a stack of
    ``BasicBlock`` modules, and a final batch norm + ReLU.

    Input:
        x: mapping whose "signal" entry is a 3-D tensor
           (n_batch, in_channels, n_length); no other entry is read here

    Output:
        3-D tensor (n_batch, self.out_channels, reduced_length)

    Parameters:
        in_channels: number of channels of the input signal
        base_filters: number of filters of the stem and of the first block
        first_kernel_size: kernel width of the stem convolution
        kernel_size: kernel width inside the residual blocks
        stride: pooling width of the stem and stride of downsampling blocks
        groups: ``groups`` of the residual-block convolutions
        n_block: number of residual blocks
        is_se: add a squeeze-and-excitation gate to every block
        downsample_gap: block ``i`` downsamples when ``i % downsample_gap == 1``
        increasefilter_gap: the channel count doubles at every non-zero block
            index that is a multiple of this value
        use_bn: apply batch-norm layers
        use_do: apply dropout layers inside the blocks
        verbose: print tensor shapes during the forward pass

    """

    def __init__(self, in_channels, base_filters, first_kernel_size, kernel_size, stride, 
                        groups, n_block, is_se=False, downsample_gap=2, 
                        increasefilter_gap=2, use_bn=True, use_do=True, verbose=False):
        super().__init__()

        self.verbose = verbose
        self.n_block = n_block
        self.first_kernel_size = first_kernel_size
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        self.use_bn = use_bn
        self.use_do = use_do
        self.is_se = is_se

        self.downsample_gap = downsample_gap
        self.increasefilter_gap = increasefilter_gap

        # stem: conv (stride 1) -> bn -> relu -> max pool of width `stride`
        self.first_block_conv = MyConv1dPadSame(
            in_channels=in_channels,
            out_channels=base_filters,
            kernel_size=self.first_kernel_size,
            stride=1)
        self.first_block_bn = nn.BatchNorm1d(base_filters)
        self.first_block_relu = nn.ReLU()
        self.first_block_maxpool = MyMaxPool1dPadSame(kernel_size=self.stride)
        self.out_channels = base_filters

        # residual blocks
        self.basicblock_list = nn.ModuleList()
        for i_block in range(self.n_block):
            is_first_block = i_block == 0
            downsample = i_block % self.downsample_gap == 1

            if is_first_block:
                block_in = base_filters
                self.out_channels = block_in
            else:
                # Width reached after the previous block.
                block_in = int(base_filters * 2 ** ((i_block - 1) // self.increasefilter_gap))
                # i_block is non-zero here, so only the modulus decides.
                widen = i_block % self.increasefilter_gap == 0
                self.out_channels = block_in * 2 if widen else block_in

            self.basicblock_list.append(
                BasicBlock(
                    in_channels=block_in,
                    out_channels=self.out_channels,
                    kernel_size=self.kernel_size,
                    stride=self.stride,
                    groups=self.groups,
                    downsample=downsample,
                    use_bn=self.use_bn,
                    use_do=self.use_do,
                    is_first_block=is_first_block,
                    is_se=self.is_se,
                )
            )

        # final activation, sized for the width of the last block
        self.final_bn = nn.BatchNorm1d(self.out_channels)
        self.final_relu = nn.ReLU(inplace=True)

    def forward(self, x):
        '''
        x["signal"]: (n_batch, n_lead, x_dim)
        returns: (n_batch, n_channel, feat_dim)
        '''
        signal = x["signal"]
        assert len(signal.shape) == 3

        # stem
        if self.verbose:
            print('input shape', signal.shape)
        out = self.first_block_conv(signal)
        if self.verbose:
            print('after first conv', out.shape)
        if self.use_bn:
            out = self.first_block_bn(out)
        out = self.first_block_maxpool(self.first_block_relu(out))

        # residual blocks, two convolutions each
        for i_block, net in enumerate(self.basicblock_list):
            if self.verbose:
                print('i_block: {0}, in_channels: {1}, out_channels: {2}, downsample: {3}'.format(i_block, net.in_channels, net.out_channels, net.downsample))
            out = net(out)
            if self.verbose:
                print(out.shape)

        # final batch norm (optional) and ReLU
        if self.use_bn:
            out = self.final_bn(out)
        return self.final_relu(out)

class RsnClf(nn.Module):
    """
    ResNetArch feature extractor followed by a single linear classifier.

    The extractor output is averaged over its last axis; when ``n_demo`` is
    not None the "age" and "sex" entries of the input are concatenated to the
    pooled features before the linear layer.

    Args:
        output_size: number of outputs of the linear layer
        n_demo: total number of demographic feature columns appended to the
            pooled features, or None to use the signal features only
        **rsn_args: forwarded unchanged to ResNetArch
    """
    def __init__(self, output_size, n_demo, **rsn_args):
        super().__init__()
        self.main_extract = ResNetArch(**rsn_args)
        out_channels = self.main_extract.out_channels
        self.n_demo = n_demo
        #------ Classifier
        if self.n_demo is not None:
            self.main_clf = nn.Linear(out_channels + self.n_demo, output_size)
        else:
            self.main_clf = nn.Linear(out_channels, output_size)

    @staticmethod
    def _as_column(t):
        """Promote a 1-D (n_batch,) tensor to (n_batch, 1); pass others through."""
        return t.unsqueeze(-1) if t.dim() == 1 else t

    def forward(self, x):
        '''
        x["signal"]: (n_batch, n_lead, x_dim)
        x["age"]: (n_batch,) or (n_batch, k) -- read only when n_demo is not None
        x["sex"]: (n_batch,) or (n_batch, k) -- read only when n_demo is not None
        returns: (n_batch, output_size)
        '''
        h = self.main_extract(x)
        # Global average pooling over the length axis -> (n_batch, n_channel)
        h = h.mean(-1)
        if self.n_demo is not None:
            # torch.cat needs every operand to be 2-D here.  Unsqueezing
            # unconditionally would turn an (n_batch, 1) tensor into a 3-D
            # one and make the concatenation fail, so only 1-D inputs are
            # promoted to a column.
            age = self._as_column(x["age"])
            sex = self._as_column(x["sex"])
            h = torch.cat((h, age, sex), dim=-1)
        return self.main_clf(h)

#%%
class Resnet1d(Classifier):
    """
    Classifier wrapper around a fixed-configuration RsnClf engine.

    Builds the engine (12 input channels, 26 outputs, one SE residual block,
    no demographic inputs) and an AdamW optimizer over its parameters.

    NOTE(review): ``self.is_attention`` and ``self.lr`` are read but never set
    in this class; presumably Classifier.__init__ populates them from
    ``config`` -- confirm against the base class.
    """
    def __init__(self, config, random_state):
        super(Resnet1d, self).__init__(config, random_state)

        # The flag is compared against the string "True" and stored as a bool.
        self.is_attention = self.is_attention=="True"
        self._engine = RsnClf(26, None, in_channels=12, base_filters=256, 
                              first_kernel_size=15, kernel_size=7, 
                              stride=3, groups=2, n_block=1, is_se=True)
        # self._engine = ResNet1D(self.input_channel, self.output_channel, self.kernel_size,
        #                         self.stride, self.groups, self.n_block, 
        #                         self.output_size, self.n_demo, is_se=self.is_se)

        # Move the model to the GPU *before* building the optimizer, so the
        # optimizer is constructed over parameters that already live on the
        # device they will be trained on (torch.optim guidance).
        if torch.cuda.is_available():
            self._engine.cuda()

        self._optimizer = torch.optim.AdamW(self._engine.parameters(), lr=self.lr)  # Optimizer

#%%
if __name__=='__main__':
    # Smoke tests: run each model once on random data and print the output
    # shape.  The `#%%` markers split the script into IDE cells.
#%% Run ResNetArch
    resnet = ResNetArch(12, 256, 15, 7, 3, 2, 1)
    x_sig = torch.randn(4, 12, 5000)
    x = {"signal": x_sig}
    output = resnet(x)
    print("output shape: ", output.shape)

#%% Run RsnClf
    resnet = RsnClf(26, None, in_channels=12, base_filters=256, first_kernel_size=15, 
            kernel_size=7, stride=3, groups=2, n_block=1, is_se=True)
    x_sig = torch.randn(4, 12, 5000)
    x = {"signal": x_sig}
    output = resnet(x)
    print("output shape: ", output.shape)

#%% Run RsnClf with demographic inputs
    # n_demo=2: one age column plus one sex column are appended to the
    # pooled signal features before the linear classifier.
    resnet = RsnClf(26, 2, in_channels=12, base_filters=256, first_kernel_size=15, 
            kernel_size=7, stride=3, groups=2, n_block=1, is_se=True)
    x_demo = torch.randn(4, 2)
    x_sig = torch.randn(4, 12, 5000)
    # age is passed as (n_batch,), sex as (n_batch, 1)
    x = {"signal": x_sig, "age": x_demo[:, 0], "sex": x_demo[:, 1:]}
    output = resnet(x)
    print("output shape: ", output.shape)