from constants import *
import math
import torch.nn.functional as F


class CNN_cell(nn.Module):
    """Two back-to-back ``Conv1d -> BatchNorm1d -> LeakyReLU`` stages.

    The first convolution maps ``channels_in`` to ``channels_out``; the second
    keeps ``channels_out``. Both convolutions share the same ``kernel_size``,
    ``stride`` and ``padding``, so a stride greater than 1 is applied twice.
    """

    def __init__(self, channels_in, channels_out, kernel_size, stride, padding):
        super(CNN_cell, self).__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding)

        stages = []
        # Stage 1 consumes the input channels, stage 2 the channels stage 1 produced.
        for stage_in in (channels_in, channels_out):
            stages.append(nn.Conv1d(stage_in, channels_out, **conv_kwargs))
            stages.append(nn.BatchNorm1d(channels_out))
            stages.append(nn.LeakyReLU())
        self.cell = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both stages and return the resulting activations."""
        return self.cell(x)


class Shortcut(nn.Module):
    """Residual block with a skip connection around three conv stages.

    The main path is ``CNN_cell -> CNN_cell -> Conv1d -> BatchNorm1d``; the
    block input is added to its result, and the sum goes through
    ``LeakyReLU -> MaxPool1d(2) -> Dropout``. Every convolution uses
    kernel 3 / stride 1 / padding 1 and keeps the channel count, so the main
    path output has the same shape as the input and the addition is valid.
    """

    def __init__(self, channels, dropout):
        super(Shortcut, self).__init__()
        # Shape-preserving convolution settings shared by the whole main path.
        same_shape = dict(kernel_size=3, stride=1, padding=1)

        self.cell1 = CNN_cell(channels, channels, **same_shape)
        self.cell2 = CNN_cell(channels, channels, **same_shape)
        self.conv3 = nn.Sequential(
            nn.Conv1d(channels, channels, **same_shape),
            nn.BatchNorm1d(channels),
        )

        # Applied to the sum of main path and skip connection.
        post_sum = [nn.LeakyReLU(), nn.MaxPool1d(2), nn.Dropout(p=dropout)]
        self.after_shortcut = nn.Sequential(*post_sum)

    def forward(self, x):
        """Return the pooled, activated sum of the main path and ``x``."""
        main_path = self.conv3(self.cell2(self.cell1(x)))
        main_path += x  # skip connection, accumulated in place
        return self.after_shortcut(main_path)


class AttentionWithContext(nn.Module):  # expects input in (S,B,H)
    """Attention pooling over the sequence dimension with a learned context vector.

    NLP:: S= word position, B= sentence, H= word details (which word)
    ECG:: S= sequence posn., B= patient, H= lead (which lead, after BiGRU 2 reps/lead exist)
    Attention is computed over the sequence dimension (not the H dimension,
    which is leads in our case).

    Args:
        linear_dims: size H of the last input dimension.
        context_dims: size of the hidden representation that is compared
            against the context vector.
        use_bias: whether the projection layer has a bias term.
    """

    def __init__(self, linear_dims, context_dims, use_bias):
        super(AttentionWithContext, self).__init__()

        # MLP layer. It projects H -> context_dims so that its output can be
        # multiplied with the (context_dims, 1) context vector; projecting to
        # linear_dims only worked when linear_dims == context_dims.
        self.mlp_layer = nn.Linear(in_features=linear_dims, out_features=context_dims, bias=use_bias)
        # context vector
        self.context_vector = nn.Parameter(torch.empty(context_dims, 1))  # mark as trainable
        nn.init.xavier_uniform_(self.context_vector, gain=nn.init.calculate_gain('relu'))

    def forward(self, x):
        """Pool ``x`` of shape (S, B, H) into (B, H) using attention weights over S."""
        u_it = torch.tanh(self.mlp_layer(x))  # MLP -> (S, B, context_dims)
        scores = torch.matmul(u_it, self.context_vector).squeeze(dim=2)  # u_it (dot) u_w -> (S, B)
        a_it = F.softmax(scores, dim=0)  # softmax along sequence dim

        # Weight every sequence position and sum the positions away.
        return torch.sum(x * a_it.unsqueeze(2), dim=0)


class ecg_classifier(nn.Module):
    """Two-branch classifier for 12-channel 1-D signals (ECG leads, presumably).

    * "other" branch: two ``CNN_cell`` stages -> 2-layer bidirectional GRU ->
      one ``BatchNorm1d + Linear`` head per pathology, softmax-normalised
      across the heads.
    * "snr" branch: ``CNN_cell -> MaxPool1d(2) -> Dropout -> Shortcut`` ->
      ``BatchNorm1d + Linear + Sigmoid`` producing one extra score.

    Args:
        dropout_cnn: dropout probability used in the convolutional branches.
        hidden_size: GRU hidden size per direction.
        dropout_gru: dropout probability between the GRU layers.
        batch_size: default batch size; only used by ``initHidden`` when it is
            called without an explicit size. ``forward`` reads the batch size
            from its input.
        signal_length: number of samples per channel. The default of 5000
            reproduces the previously hard-coded dimensions.
    """

    def __init__(self, dropout_cnn, hidden_size, dropout_gru, batch_size, signal_length=5000):
        super(ecg_classifier, self).__init__()
        channel_out = 64
        channel_out_snr = 32
        # The snr branch halves the length twice: once in its own MaxPool1d(2)
        # and once in the MaxPool1d(2) inside Shortcut (32 * 1250 = 40000 for
        # the default signal length).
        snr_output = channel_out_snr * (signal_length // 2 // 2)
        # The GRU sees (batch, channel_out, signal_length) with batch_first, so
        # its flattened output holds channel_out steps of 2 * hidden_size each.
        rnn_flat = channel_out * 2 * hidden_size

        self.hidden_size = hidden_size
        self.batch_size = batch_size

        self.other_features = nn.Sequential(
            CNN_cell(12, 32, kernel_size=3, stride=1, padding=1),
            nn.Dropout(p=dropout_cnn),
            CNN_cell(32, channel_out, kernel_size=3, stride=1, padding=1),
            nn.Dropout(p=dropout_cnn),
        )

        self.snr_features = nn.Sequential(
            CNN_cell(12, channel_out_snr, kernel_size=3, stride=1, padding=1),
            nn.MaxPool1d(2),
            nn.Dropout(p=dropout_cnn),
            Shortcut(channel_out_snr, dropout_cnn),
        )

        self.rnn = nn.GRU(signal_length, hidden_size, num_layers=2, batch_first=True,
                          bidirectional=True, dropout=dropout_gru)

        # One single-logit head per pathology.
        self.fc = nn.ModuleList()
        for _ in range(n_pathologies):
            self.fc.append(nn.Sequential(
                nn.BatchNorm1d(rnn_flat),
                nn.Linear(rnn_flat, 1, bias=True),
            ))

        self.others_acti = nn.Softmax(dim=1)
        self.snr_classifier = nn.Sequential(
            nn.BatchNorm1d(snr_output),
            nn.Linear(snr_output, 1, bias=True),
            nn.Sigmoid()
        )

        # Re-initialise all convolutions and batch norms registered above.
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Score a batch of signals.

        Args:
            x: tensor of shape (batch, 12, signal_length).

        Returns:
            Tensor of shape (batch, len(self.fc) + 1): the softmax-normalised
            pathology scores followed by the sigmoid score of the snr branch.
        """
        # Use the real batch size so partial batches and inference on a
        # different batch size than the configured one both work.
        batch = x.size(0)

        snr_features = self.snr_features(x).reshape(batch, -1)
        snr_output = self.snr_classifier(snr_features)

        other_features = self.other_features(x)
        # Allocate the initial state next to the activations (device/dtype).
        hidden = self.initHidden(batch, device=other_features.device, dtype=other_features.dtype)
        rnn_output, _ = self.rnn(other_features, hidden)
        rnn_output = rnn_output.reshape(batch, -1)

        head_logits = [head(rnn_output) for head in self.fc]
        if head_logits:
            multi_head = torch.cat(head_logits, dim=1)
        else:
            # No heads configured: keep an empty (batch, 0) block.
            multi_head = rnn_output.new_empty((batch, 0))
        multi_head = self.others_acti(multi_head)

        return torch.cat((multi_head, snr_output), dim=1)

    def initHidden(self, batch_size=None, device=None, dtype=None):
        """Return a zero initial GRU state of shape (4, batch_size, hidden_size).

        Args:
            batch_size: number of sequences; defaults to ``self.batch_size``.
            device: target device; defaults to torch's default device.
            dtype: target dtype; defaults to torch's default dtype.
        """
        if batch_size is None:
            batch_size = self.batch_size
        # 2 layers * 2 directions
        return torch.zeros(2 * 2, batch_size, self.hidden_size, device=device, dtype=dtype)
