import torch
import torch.nn as nn
import torch.nn.functional as F

from .base import Classifier, Dropout

class MLPVanillaArch(nn.Module):
    """Plain fully connected network: hidden stages followed by a linear head.

    Every hidden stage is ``Linear -> LeakyReLU(0.02, inplace) -> Dropout``.
    The trunk has ``n_layers + 1`` such stages: the first maps
    ``input_size -> hidden_size`` and the remaining ``n_layers`` map
    ``hidden_size -> hidden_size``. A final ``Linear(hidden_size, output_size)``
    produces the un-normalised outputs (logits).
    """

    def __init__(self, input_size, output_size, hidden_size, n_layers, dropout_prob):
        super().__init__()

        # Feature extractor. Each entry is the input width of one stage; only
        # the first stage changes the width, the others keep hidden_size.
        stage_inputs = [input_size] + [hidden_size] * n_layers
        trunk = []
        for in_width in stage_inputs:
            # Order matters: it fixes both the RNG draws of the Linear init
            # and the `main.<idx>` names in the state_dict.
            trunk.extend(
                (
                    nn.Linear(in_width, hidden_size),
                    nn.LeakyReLU(0.02, inplace=True),
                    Dropout(dropout_prob),
                )
            )
        self.main = nn.Sequential(*trunk)

        # Output head (named "clf"; emits raw scores, no activation).
        self.clf = nn.Linear(hidden_size, output_size)

    def forward(self, x_data):
        """Map a rank-2 input ``(N, F)`` to output logits ``(N, output_size)``.

        Raises:
            RuntimeError: if ``x_data`` does not have exactly two dimensions.
        """
        rank = len(x_data.shape)
        if rank != 2:
            raise RuntimeError("Tensor rank ({}) is not supported!".format(rank))

        features = self.main(x_data)
        logits = self.clf(features)
        return logits

class MLPVanilla(Classifier):
    """``Classifier`` wrapper around ``MLPVanillaArch`` optimised with AdamW.

    Args:
        param_model: hyper-parameter container; read here for ``input_size``,
            ``output_size``, ``hidden_size``, ``n_layers``, ``dropout_prob``
            and ``lr``.
        random_state: forwarded unchanged to the ``Classifier`` base.
        class_weight: stored as ``self.class_weight``. It is NOT forwarded to
            the base initializer, which always receives ``None``.
    """

    def __init__(self, param_model, random_state, class_weight=None):
        # NOTE(review): the base always gets class_weight=None and the real
        # value is assigned afterwards -- presumably to bypass base-class
        # handling of class weights; confirm this is intentional.
        super(MLPVanilla, self).__init__(param_model, random_state, class_weight=None)
        self.class_weight = class_weight

        self._engine = MLPVanillaArch(
            param_model.input_size,
            param_model.output_size,
            param_model.hidden_size,
            param_model.n_layers,
            param_model.dropout_prob,
        )

        # Move the network to the GPU *before* creating the optimizer, as the
        # torch.optim documentation requires, so the optimizer is bound to the
        # device-resident parameters rather than possibly stale CPU ones.
        if torch.cuda.is_available():
            self._engine.cuda()

        self._optimizer = torch.optim.AdamW(self._engine.parameters(), lr=param_model.lr)