import torch
from torch.utils.data import Dataset, DataLoader
import torch_rbf as rbf
#############################################################################################################################
######                      Definition of neural network classifier                                                ##########
#############################################################################################################################

class MyDataset(Dataset):
    """Map-style dataset that pairs each feature row with the label row at the same index.

    ``features`` and ``labels`` are indexed in parallel; the dataset length is
    the size of the first axis of ``features``.
    """

    def __init__(self, features, labels):
        self.features, self.labels = features, labels

    def __len__(self):
        # Number of recordings (first axis), not the number of features per recording.
        return self.features.shape[0]

    def __getitem__(self, index):
        # Returns a (feature, label) tuple for the requested recording.
        return self.features[index], self.labels[index]

class One_Vs_All_Net(torch.nn.Module):
    """One-vs-all multi-label classifier built from independent ``SingleModel`` networks.

    ``self.classifiers[i]`` is a separate binary network whose sigmoid output is
    the probability of label ``i``.

    Args:
        num_features: width of each input feature vector.
        num_labels: number of labels, i.e. number of per-label networks.
    """

    def __init__(self, num_features, num_labels):
        super(One_Vs_All_Net, self).__init__()
        self.num_features = num_features
        self.num_labels = num_labels
        # Kept for backward compatibility (attribute access and state_dict keys);
        # it is the first per-label classifier.
        self.Base_Model = SingleModel(self.num_features)
        # Every label needs its *own* network. Repeating one instance here would
        # make all "classifiers" share a single set of weights.
        self.classifiers = torch.nn.ModuleList(
            [self.Base_Model if i == 0 else SingleModel(self.num_features)
             for i in range(self.num_labels)]
        )

    def forward(self, indx, feature):
        """Return the sigmoid output of the classifier for label ``indx``."""
        return self.classifiers[indx](feature)

    def fit(self, features, labels, epochs=20, batch_size=512, lr=0.1, weight_decay=0.1):
        """Train each per-label classifier with binary cross-entropy.

        Args:
            features: numpy array of inputs, one row per sample.
            labels: numpy array with one column per label (0/1 targets).
            epochs, batch_size, lr, weight_decay: training hyper-parameters; the
                defaults are the values that used to be hard-coded here.

        Prints the epoch number and the mean loss per (batch, label) step.
        """
        self.train()  # predict()/predict_proba() switch to eval mode; re-enable dropout
        loss_func = torch.nn.BCELoss()
        trainingset = MyDataset(torch.from_numpy(features).float(), torch.from_numpy(labels).float())
        trainloader = DataLoader(trainingset, batch_size=batch_size, shuffle=True)
        # One optimiser per label so that a step for label i only updates classifier i
        # (a single Adam with weight decay over all parameters may move the others too).
        optimisers = [
            torch.optim.Adam(classifier.parameters(), lr=lr, weight_decay=weight_decay)
            for classifier in self.classifiers
        ]
        for epoch in range(1, epochs + 1):
            loss_epoch = 0.0
            num_steps = 0
            for feature, label in trainloader:
                for i, optimiser in enumerate(optimisers):
                    target = label[:, i].reshape([-1, 1])
                    loss = loss_func(self.forward(i, feature), target)
                    loss.backward()
                    optimiser.step()
                    optimiser.zero_grad()
                    # .item() keeps the running total from holding autograd graphs alive
                    loss_epoch += loss.item()
                    num_steps += 1
            print(epoch, loss_epoch / max(num_steps, 1))

    def _label_probabilities(self, features):
        """Return per-label probabilities for a numpy ``features`` array (eval mode, no grad).

        The result has shape (N, num_labels) for N > 1 samples, and
        (num_labels,) for a single sample (a 1-D vector or a one-row 2-D array).
        """
        self.eval()
        with torch.no_grad():
            inputs = torch.from_numpy(features).float()
            # Each classifier contributes one column; join them along the last axis.
            probabilities = torch.cat(
                [self.forward(i, inputs) for i in range(self.num_labels)], dim=-1
            )
        if probabilities.dim() == 2 and probabilities.shape[0] == 1:
            probabilities = probabilities[0]
        return probabilities

    def predict(self, features):
        """Return 0/1 labels obtained by rounding each label probability (numpy array)."""
        return torch.round(self._label_probabilities(features)).numpy()

    def predict_proba(self, features):
        """Return the probability of every label as a numpy array.

        Shape is (num_labels,) for a single sample and (N, num_labels) for a batch.
        """
        return self._label_probabilities(features).numpy()




class SingleModel(torch.nn.Module):
    """Binary classifier network: bias-free tanh MLP with dropout and a sigmoid output.

    Layer widths are num_features -> num_features -> 512 -> 64 -> 1. Dropout
    (p=0.1) follows each of the three tanh layers.

    Args:
        num_features: width of the input feature vector.
    """

    def __init__(self, num_features):
        super().__init__()
        self.num_features = num_features
        width = num_features

        def tanh_layer(n_in, n_out):
            # Bias-free linear map followed by tanh.
            return torch.nn.Sequential(torch.nn.Linear(n_in, n_out, bias=False), torch.nn.Tanh())

        # Assignment order is kept (it fixes parameter-initialisation order and state_dict layout).
        self.Dense1 = tanh_layer(width, width)
        self.Dropout1 = torch.nn.Dropout(p=0.1)
        self.Dense2 = tanh_layer(width, 512)
        self.Dropout2 = torch.nn.Dropout(p=0.1)
        self.Dense3 = tanh_layer(512, 64)
        self.Dropout3 = torch.nn.Dropout(p=0.1)
        self.Dense4 = torch.nn.Sequential(torch.nn.Linear(64, 1, bias=False))
        self.OutputFCT = torch.nn.Sigmoid()

    def forward(self, x):
        """Run ``x`` through every layer in order and return the sigmoid output."""
        pipeline = (
            self.Dense1, self.Dropout1,
            self.Dense2, self.Dropout2,
            self.Dense3, self.Dropout3,
            self.Dense4, self.OutputFCT,
        )
        for layer in pipeline:
            x = layer(x)
        return x
    
    

            
    



class NeuralNetworkClassifier(torch.nn.Module):
    """Multi-label MLP classifier with one sigmoid output per label.

    Architecture: num_features -> 512 (sigmoid) -> 64 (linear) -> num_labels (sigmoid);
    all linear layers are bias-free.

    Args:
        num_features: width of each input feature vector.
        num_labels: number of output labels.
        scaler: object stored as-is and returned by ``get_scaler()``; this class
            never calls it (presumably the caller's feature scaler — confirm).
    """

    def __init__(self, num_features, num_labels, scaler):
        super(NeuralNetworkClassifier, self).__init__()
        self.scaler = scaler
        self.Dense1 = torch.nn.Sequential(torch.nn.Linear(num_features, 512, bias=False), torch.nn.Sigmoid())
        # NOTE(review): Dense3 has no activation, so it composes with Dense5's linear
        # layer into a single linear map of rank <= 64.
        self.Dense3 = torch.nn.Sequential(torch.nn.Linear(512, 64, bias=False))
        self.Dense5 = torch.nn.Sequential(torch.nn.Linear(64, num_labels, bias=False), torch.nn.Sigmoid())

    def forward(self, x):
        """Map a float feature tensor to per-label probabilities in (0, 1)."""
        return self.Dense5(self.Dense3(self.Dense1(x)))

    def fit(self, features, labels, epochs=1, batch_size=1, lr=0.001, loss_func=None):
        """Train the network with Adam (weight_decay=0.1).

        Args:
            features: numpy array of inputs, one row per sample.
            labels: numpy array of 0/1 targets, one column per label.
            epochs, batch_size, lr: training hyper-parameters.
            loss_func: loss module; ``None`` (default) uses
                ``torch.nn.BCELoss(reduction='sum')``.

        Prints the epoch number and the sum of the batch losses for that epoch.
        """
        if loss_func is None:
            # Created per call instead of as a default argument shared by every call.
            loss_func = torch.nn.BCELoss(reduction='sum')
        self.train()
        trainingset = MyDataset(torch.from_numpy(features).float(), torch.from_numpy(labels).float())
        trainloader = DataLoader(trainingset, batch_size=batch_size, shuffle=True)
        optimiser = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=0.1)
        for epoch in range(1, epochs + 1):
            epoch_loss = 0.0
            for feature, label in trainloader:
                loss = loss_func(self.forward(feature), label)
                loss.backward()
                optimiser.step()
                optimiser.zero_grad()
                # .item() detaches the value so no autograd graph is retained
                epoch_loss += loss.item()
            print(epoch, epoch_loss)

    def predict_probs(self, features):
        """Return label probabilities for ``features`` (numpy array) as numpy.

        Returns ``probabilities.numpy()[0]``: for a 2-D input this is the first
        sample's per-label probabilities only.
        """
        self.eval()
        with torch.no_grad():
            probabilities = self.forward(torch.from_numpy(features).float())
        return probabilities.numpy()[0]

    def predict_labels(self, features):
        """Return 0/1 labels by rounding each probability (``torch.round``; 0.5 rounds to 0)."""
        self.eval()
        with torch.no_grad():
            probabilities = self.forward(torch.from_numpy(features).float())
            labels = torch.round(probabilities)
        return labels.numpy()

    def get_scaler(self):
        """Return the ``scaler`` object passed to the constructor."""
        return self.scaler