import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
import os
import argparse

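# Flattened size of the conv stack output (64 channels x remaining length),
# used by CNN_jp as `last_dim`, for each supported random crop size:
#   crop size 3600 or 4000 -> last_dim 192
#   crop size 4500         -> last_dim 256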

class Block(torch.nn.Module):
    """One convolutional stage: Conv1d -> normalization -> ReLU -> dropout / pooling.

    Args:
        index: Position of this block in the network. It selects the
            block's behavior: the block with ``index == 8`` uses a conv
            stride of 2 and skips average pooling; blocks with
            ``index >= 3`` (other than 8) get an extra dropout before pooling.
        conv_in_channels: Number of input channels of the convolution.
        conv_out_channels: Number of output channels of the convolution.
        ibn_usage: When exactly ``True``, the first half of the output
            channels is batch-normalized and the second half is
            instance-normalized; otherwise all channels go through
            batch normalization. ``conv_out_channels`` must be even for the
            split to produce exactly two chunks.
    """

    def __init__(self, index, conv_in_channels, conv_out_channels, ibn_usage):
        super(Block, self).__init__()

        self.pool_kernel_size = 5
        self.pool_kernel_stride = 2
        self.pool_kernel_padding = 1

        self.conv_kernel_size = 7
        self.conv_kernel_stride = 1
        self.conv_kernel_padding = 1

        self.conv_in_channels = conv_in_channels
        self.conv_out_channels = conv_out_channels
        self.ibn_usage = ibn_usage

        self.dropout_ratio = 0.2

        # The default stride comes from the stride setting. (It used to be
        # read from the padding setting, which only worked because both
        # values happen to be 1.)
        self.stride = self.conv_kernel_stride
        self.index = index

        # The block with index 8 downsamples with a strided convolution
        # instead of average pooling (see forward()).
        if index == 8:
            self.stride = 2

        # Layers are created in a fixed order so that seeded weight
        # initialization stays reproducible.
        self.conv = torch.nn.Conv1d(
            in_channels=self.conv_in_channels,
            out_channels=self.conv_out_channels,
            kernel_size=self.conv_kernel_size,
            stride=self.stride,
            padding=self.conv_kernel_padding,
        )
        self.bn = torch.nn.BatchNorm1d(self.conv_out_channels)
        self.relu = torch.nn.ReLU()
        self.avg_pool = torch.nn.AvgPool1d(
            kernel_size=self.pool_kernel_size,
            stride=self.pool_kernel_stride,
            padding=self.pool_kernel_padding,
        )
        self.drop_out = torch.nn.Dropout(p=self.dropout_ratio)

        # Normalization layers for the IBN path. They are always created,
        # even when ibn_usage is off, so the set of registered submodules
        # (and therefore the state_dict keys) does not depend on ibn_usage.
        half_channels = self.conv_out_channels // 2
        self.bn_ibn = torch.nn.BatchNorm1d(half_channels)
        self.ibn_layer = torch.nn.InstanceNorm1d(half_channels)

    def forward(self, x):
        """Apply the block to ``x`` and return the transformed tensor.

        ``x`` is fed straight into ``Conv1d``, so its channel dimension must
        equal ``conv_in_channels``.
        """
        x = self.conv(x)

        # Strict identity check: only the literal True enables the IBN path;
        # other truthy values fall through to plain batch normalization.
        if self.ibn_usage is True:
            half_channels = self.conv_out_channels // 2
            bn_part, in_part = torch.split(
                x, split_size_or_sections=half_channels, dim=1
            )
            bn_part = self.bn_ibn(bn_part)
            in_part = self.ibn_layer(in_part)
            x = torch.cat((bn_part, in_part), dim=1)
        else:
            x = self.bn(x)

        x = self.relu(x)

        is_last_block = self.index == 8
        if not is_last_block:
            if self.index >= 3:
                # Extra dropout before pooling for blocks 3..7.
                x = self.drop_out(x)
            x = self.avg_pool(x)

        # NOTE(review): this dropout runs for every block, so blocks 3..7
        # apply dropout twice (before and after pooling). Kept as-is to
        # preserve training behavior -- confirm this is intentional.
        x = self.drop_out(x)
        return x

class CNN_jp(torch.nn.Module):
    """Nine-block 1-D CNN with a classification head and a projection head.

    The convolutional stack (``layer1`` .. ``layer9``) is followed by:

    * ``projection_head``: a two-layer MLP on the flattened conv output,
      producing a ``proj_dim``-sized vector;
    * ``fc1`` -> ReLU -> ``fc2``: the classifier used when no extra
      features are supplied;
    * ``ln1`` -> ReLU -> ``ln2`` plus ``fc3``: the classifier used when a
      ``num_features``-sized feature vector is supplied; the transformed
      features are concatenated with the ``fc1`` output.

    Args:
        num_leads: Number of input channels of the first convolution.
        num_classes: Size of the classifier output.
        configs: Stored on the instance as ``self.configs``; not otherwise
            read by this class.
        info_for_test: Stored on the instance as ``self.info_for_test``; not
            otherwise read by this class.

    Raises:
        OSError: If ``configs/configs.yaml`` cannot be opened.
        KeyError: If the YAML file lacks
            ``TRANS_BASE_PARAMS.random_crop_size``.
    """

    def __init__(self, num_leads, num_classes, configs, info_for_test=None):
        super(CNN_jp, self).__init__()
        self.num_leads = num_leads
        self.num_classes = num_classes
        self.configs = configs
        self.info_for_test = info_for_test

        # Per-block channel plan; entry i configures layer{i+1}.
        self.conv_in_channels = [num_leads, 128, 64, 128, 64, 128, 64, 128, 64]
        self.conv_out_channels = [128, 64, 128, 64, 128, 64, 128, 64, 64]
        self.ibn_usage = [False] * 9

        self.num_features = 20

        # Flattened size of the conv stack output. 192 is the default; it
        # becomes 256 when the configured random crop size is 4500.
        self.last_dim = 192
        self.proj_dim = 128

        # NOTE(review): the path is relative to the current working
        # directory, and the `configs` argument is not consulted here.
        # NOTE(review): yaml.load with FullLoader is kept to preserve
        # behavior; yaml.safe_load is preferable if this file could ever
        # come from an untrusted source.
        with open("configs/configs.yaml") as f:
            global_config = yaml.load(f, Loader=yaml.FullLoader)

        if global_config["TRANS_BASE_PARAMS"]["random_crop_size"] == 4500:
            self.last_dim = 256

        # Heads are registered before the conv blocks, matching the original
        # submodule order (state_dict ordering and seeded initialization).
        self.fc1 = torch.nn.Linear(self.last_dim, 64)
        self.fc2 = torch.nn.Linear(64, self.num_classes)
        self.fc3 = torch.nn.Linear(64 + self.num_features, self.num_classes)

        self.projection_head = torch.nn.Sequential(
            torch.nn.Linear(self.last_dim, self.proj_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(self.proj_dim, self.proj_dim),
        )
        self.ln1 = torch.nn.Linear(self.num_features, 32)
        self.ln2 = torch.nn.Linear(32, self.num_features)

        # Register layer1 .. layer9 under the same attribute names as before
        # so existing checkpoints keep loading.
        block_specs = zip(
            self.conv_in_channels, self.conv_out_channels, self.ibn_usage
        )
        for index, (in_ch, out_ch, use_ibn) in enumerate(block_specs):
            setattr(
                self,
                "layer{}".format(index + 1),
                self.make_layer(index, in_ch, out_ch, use_ibn),
            )

    def make_layer(self, index, conv_in_channels, conv_out_channels, ibn_usage):
        """Return a ``Sequential`` wrapping a single ``Block``."""
        return torch.nn.Sequential(
            Block(index, conv_in_channels, conv_out_channels, ibn_usage)
        )

    def forward(self, input_data, features=None):
        """Run the network and return ``(class_scores, projection)``.

        Args:
            input_data: Tensor passed to the first ``Conv1d``; after the conv
                stack it is flattened from dim 1 on and must then have
                ``last_dim`` values per sample.
            features: Optional tensor whose last dimension is
                ``num_features``. When given, it is passed through
                ``ln1`` -> ReLU -> ``ln2`` and concatenated with the ``fc1``
                output before the final ``fc3`` layer; otherwise ``fc2`` is
                used.

        Returns:
            A tuple ``(x, x_proj)``: the raw output of the final linear
            layer and the output of the projection head.
        """
        x = input_data
        conv_stack = (
            self.layer1, self.layer2, self.layer3,
            self.layer4, self.layer5, self.layer6,
            self.layer7, self.layer8, self.layer9,
        )
        for layer in conv_stack:
            x = layer(x)

        x = torch.flatten(x, 1)
        x_proj = self.projection_head(x)

        x = torch.relu(self.fc1(x))

        # Identity check: comparing a tensor to None with != relies on
        # Tensor.__ne__ and is not a reliable presence test.
        if features is not None:
            features = self.ln2(torch.relu(self.ln1(features)))
            x = self.fc3(torch.cat((x, features), dim=1))
        else:
            x = self.fc2(x)

        return x, x_proj
