# -*- coding: utf-8 -*-
"""
Created on Tue May 18 17:04:57 2021

@author: Maurice
"""
# %% 
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import copy


class PooConvNorm(nn.Module):
    """Downsampling stage: max-pool -> strided Conv1d -> ReLU -> BatchNorm.

    Both the pool and the convolution use stride 2, so an input of length L
    (divisible by 4) comes out with length L / 4.

    Args:
        in_channels: number of input channels.
        out_channels: number of channels produced by the convolution.
        parameters: unused; accepted for a uniform constructor signature.
    """

    def __init__(self, in_channels, out_channels, parameters=None):
        super().__init__()
        layers = [
            nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
            nn.Conv1d(in_channels, out_channels, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.BatchNorm1d(out_channels),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
    
class InvPooConvNorm(nn.Module):
    """Upsampling stage: stride-2 transposed Conv1d followed by 2x Upsample.

    The transposed convolution doubles the length (kernel 5, padding 2,
    output_padding 1) and the upsample doubles it again, so length L becomes
    4 * L. Despite the name, this stage has no activation and no normalisation.

    Args:
        in_channels: number of input channels.
        out_channels: number of channels produced by the transposed conv.
        parameters: unused; accepted for a uniform constructor signature.
    """

    def __init__(self, in_channels, out_channels, parameters=None):
        super().__init__()
        upconv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size=5,
                                    stride=2, padding=2, output_padding=1)
        self.net = nn.Sequential(upconv, nn.Upsample(scale_factor=2))

    def forward(self, x):
        out = self.net(x)
        return out
    
class ResidualBlock(nn.Module):
    """Length- and channel-preserving residual block: ``x + res(x)``.

    ``res`` is Conv1d -> BatchNorm -> ReLU -> Conv1d -> BatchNorm, using
    kernel 5 / padding 2 / stride 1 so the sequence length is unchanged.

    Args:
        in_channels: channels of the input (and of the output).
        mid_channels: channels of the intermediate representation.
    """

    def __init__(self, in_channels, mid_channels):
        super().__init__()
        expand = nn.Conv1d(in_channels, mid_channels, kernel_size=5, padding=2)
        expand_norm = nn.BatchNorm1d(mid_channels)
        activation = nn.ReLU()
        project = nn.Conv1d(mid_channels, in_channels, kernel_size=5, padding=2)
        project_norm = nn.BatchNorm1d(in_channels)
        self.res = nn.Sequential(expand, expand_norm, activation, project, project_norm)

    def forward(self, x):
        residual = self.res(x)
        return x + residual

class DualAEClassification(nn.Module):
    """Convolutional autoencoder with an LSTM classifier on the latent code.

    ``forward(x)`` returns ``(decoder(encoder(x)) * y_multiplier, labels)``.

    The encoder shortens the sequence 512-fold and the classifier is built
    for ``seq_len // 512`` latent steps, so ``seq_len`` should be a multiple
    of 512 (e.g. a power of two >= 512).

    Args:
        n_labels: number of output labels.
        seq_len: length of the input sequences.
        in_channels: number of channels of the input signal.
        lstm_h_size: hidden size of the classifier LSTM.
        dropout: dropout probability used inside the encoder.
        parameters: unused.
        y_multiplier: factor applied to the reconstruction. The decoder ends
            in Tanh, so this is a stopgap for signals outside (-1, 1).
    """

    def __init__(self, n_labels, seq_len, in_channels, lstm_h_size, dropout=0.1,
                 parameters=None, y_multiplier=1):
        super().__init__()
        self.y_multiplier = y_multiplier  # crude rescaling instead of normalising signals
        self.n_labels = n_labels
        self.dropout = dropout
        widths = [1, 8, 8, 16, 32, 64, 64, 64]
        self.channels = [w * in_channels for w in widths]
        self.encoder = CEncoder(self.channels, self.dropout)
        self.decoder = CDecoder(self.channels)
        latent_len = seq_len // 512
        self.classifier = LSTMClass(n_labels, self.channels[-1], lstm_h_size, latent_len)

    def forward(self, x):
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        labels = self.classifier(latent)
        return reconstruction * self.y_multiplier, labels
    
class DualAEClassificationV2(nn.Module):
    """Autoencoder variant whose classifier is the two-path ``LSTMClassV2``.

    ``forward(x)`` returns ``(decoder(encoder(x)) * y_multiplier, labels)``.
    The encoder reduces the length 512-fold and the classifier expects
    ``seq_len // 512`` latent steps, so ``seq_len`` should be a multiple of
    512 (e.g. a power of two >= 512).

    Args:
        n_labels: number of output labels.
        seq_len: length of the input sequences.
        in_channels: number of channels of the input signal.
        lstm_h_size: hidden size of the classifier LSTM.
        dropout: dropout probability used inside the encoder.
        parameters: unused.
        y_multiplier: factor applied to the Tanh-bounded reconstruction.
    """

    def __init__(self, n_labels, seq_len, in_channels, lstm_h_size, dropout=0.1,
                 parameters=None, y_multiplier=1):
        super().__init__()
        self.y_multiplier = y_multiplier  # crude rescaling instead of normalising signals
        self.n_labels = n_labels
        self.dropout = dropout
        widths = [1, 8, 8, 16, 32, 64, 64, 64]
        self.channels = [w * in_channels for w in widths]
        self.encoder = CEncoder(self.channels, self.dropout)
        self.decoder = CDecoder(self.channels)
        latent_len = seq_len // 512
        self.classifier = LSTMClassV2(n_labels, self.channels[-1], lstm_h_size, latent_len)

    def forward(self, x):
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        labels = self.classifier(latent)
        return reconstruction * self.y_multiplier, labels
    
class DualAEClassificationV2Deep(nn.Module):
    """Wider encoder plus a 2-layer LSTM classifier; reconstruction is bypassed.

    ``forward(x)`` returns ``(x * y_multiplier, labels)``. The input itself
    stands in for the reconstruction. A decoder is constructed but not called.

    NOTE(review): because the decoder is unused, its parameters never receive
    gradients. Also, with this 9-entry channel list, ``CDecoder`` would output
    ``channels[1]`` (8 * in_channels) rather than ``in_channels`` channels.
    The encoder only uses ``channels[0..7]``. The classifier reads
    ``channels[-1]`` (``channels[8]``), which matches the encoder output width
    only because both equal 64 * in_channels.

    ``seq_len`` should be a multiple of 512 (encoder reduction factor).

    Args:
        n_labels: number of output labels.
        seq_len: length of the input sequences.
        in_channels: number of channels of the input signal.
        lstm_h_size: hidden size of the classifier LSTM.
        dropout: dropout probability used inside the encoder.
        parameters: unused.
        y_multiplier: factor applied to the returned signal.
    """

    def __init__(self, n_labels, seq_len, in_channels, lstm_h_size, dropout=0.1,
                 parameters=None, y_multiplier=1):
        super().__init__()
        self.y_multiplier = y_multiplier  # crude rescaling instead of normalising signals
        self.n_labels = n_labels
        self.dropout = dropout
        widths = [1, 8, 16, 32, 64, 128, 128, 64, 64]
        self.channels = [w * in_channels for w in widths]
        self.encoder = CEncoder(self.channels, self.dropout)
        self.decoder = CDecoder(self.channels)
        latent_len = seq_len // 512
        self.classifier = LSTMClassV2Deep(n_labels, self.channels[-1], lstm_h_size, latent_len)

    def forward(self, x):
        latent = self.encoder(x)
        labels = self.classifier(latent)
        # Reconstruction path disabled: the input is returned in its place.
        return x * self.y_multiplier, labels
    
class DualAEClassificationV3Deep(nn.Module):
    """Residual ``CEncoderDeepWide`` plus a 2-layer LSTM classifier.

    ``forward(x)`` returns ``(x * y_multiplier, labels)``. The input stands
    in for the reconstruction. A decoder is constructed (so it appears in
    ``state_dict``) but is never called, so its parameters get no gradients.

    ``seq_len`` should be a multiple of 512 (encoder reduction factor).

    Args:
        n_labels: number of output labels.
        seq_len: length of the input sequences.
        in_channels: number of channels of the input signal.
        lstm_h_size: hidden size of the classifier LSTM.
        dropout: dropout probability used inside the encoder.
        parameters: unused.
        y_multiplier: factor applied to the returned signal.
    """

    def __init__(self, n_labels, seq_len, in_channels, lstm_h_size, dropout=0.1,
                 parameters=None, y_multiplier=1):
        super().__init__()
        self.y_multiplier = y_multiplier  # crude rescaling instead of normalising signals
        self.n_labels = n_labels
        self.dropout = dropout
        widths = [1, 8, 16, 32, 64, 128, 128, 64]
        self.channels = [w * in_channels for w in widths]
        self.encoder = CEncoderDeepWide(self.channels, self.dropout)
        self.decoder = CDecoder(self.channels)
        latent_len = seq_len // 512
        self.classifier = LSTMClassV2Deep(n_labels, self.channels[-1], lstm_h_size, latent_len)

    def forward(self, x):
        latent = self.encoder(x)
        labels = self.classifier(latent)
        # Reconstruction path disabled: the input is returned in its place.
        return x * self.y_multiplier, labels
    
class CEncoder(nn.Module):
    """1-D convolutional encoder that shortens the sequence 512-fold.

    Halving steps: the first conv (stride 2), each of the three PooConvNorm
    stages (x4 each), the final max-pool and the final strided conv, giving
    2 * 4**3 * 2 * 2 = 512.

    Args:
        channels: channel widths. Entries 0..7 are used; ``channels[0]`` must
            equal the input channel count and ``channels[7]`` is the output width.
        dropout: dropout probability applied after the first and third
            PooConvNorm stages.
        parameters: unused.
    """

    def __init__(self, channels, dropout=0.1, parameters=None):
        super().__init__()
        self.channels = channels
        stem = [
            nn.Conv1d(channels[0], channels[1], kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.Conv1d(channels[1], channels[2], kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
        ]
        pooled_stages = [
            PooConvNorm(channels[2], channels[3]),
            nn.Dropout(p=dropout),
            PooConvNorm(channels[3], channels[4]),
            PooConvNorm(channels[4], channels[5]),
            nn.Dropout(p=dropout),
        ]
        head = [
            nn.Conv1d(channels[5], channels[6], kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.BatchNorm1d(channels[6]),
            nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
            nn.Conv1d(channels[6], channels[7], kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*stem, *pooled_stages, *head)

    def forward(self, x):
        # Channel-first input (Conv1d layout); dim 1 must match channels[0].
        assert x.shape[1] == self.channels[0]
        return self.net(x)
    
    
class CEncoderDeepWide(nn.Module):
    """``CEncoder`` variant with two residual blocks, again reducing length 512-fold.

    Two ``ResidualBlock`` modules (bottleneck width ``channels[3]``) sit
    between the second and third PooConvNorm stages. The only dropout is after
    the first and the third PooConvNorm stage.

    Args:
        channels: channel widths. Entries 0..7 are used; ``channels[0]`` must
            equal the input channel count and ``channels[7]`` is the output width.
        dropout: dropout probability.
        parameters: unused.
    """

    def __init__(self, channels, dropout=0.1, parameters=None):
        super().__init__()
        self.channels = channels
        stem = [
            nn.Conv1d(channels[0], channels[1], kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.Conv1d(channels[1], channels[2], kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
        ]
        body = [
            PooConvNorm(channels[2], channels[3]),
            nn.Dropout(p=dropout),
            PooConvNorm(channels[3], channels[4]),
            ResidualBlock(channels[4], channels[3]),
            ResidualBlock(channels[4], channels[3]),
            PooConvNorm(channels[4], channels[5]),
            nn.Dropout(p=dropout),
        ]
        head = [
            nn.Conv1d(channels[5], channels[6], kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.BatchNorm1d(channels[6]),
            nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
            nn.Conv1d(channels[6], channels[7], kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*stem, *body, *head)

    def forward(self, x):
        # Channel-first input (Conv1d layout); dim 1 must match channels[0].
        assert x.shape[1] == self.channels[0]
        return self.net(x)

class CDecoder(nn.Module):
    """Mirror of the encoders: lengthens the sequence 512-fold.

    Channels are read from the end of ``channels``: input width
    ``channels[-1]``, output width ``channels[-8]``. Length doubling steps:
    first transposed conv (x2), Upsample (x2), three InvPooConvNorm stages
    (x4 each), last transposed conv (x2). The output passes through Tanh, so
    it lies in (-1, 1).

    Args:
        channels: channel widths (at least 8 entries).
        parameters: unused.
    """

    def __init__(self, channels, parameters=None):
        super().__init__()
        self.channels = channels
        widen = [
            nn.ConvTranspose1d(channels[-1], channels[-2], kernel_size=5, stride=2,
                               padding=2, output_padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv1d(channels[-2], channels[-3], kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
        ]
        upsample_stages = [
            InvPooConvNorm(channels[-3], channels[-4]),
            InvPooConvNorm(channels[-4], channels[-5]),
            InvPooConvNorm(channels[-5], channels[-6]),
        ]
        head = [
            nn.ConvTranspose1d(channels[-6], channels[-7], kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.ConvTranspose1d(channels[-7], channels[-8], kernel_size=5, stride=2,
                               padding=2, output_padding=1),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*widen, *upsample_stages, *head)

    def forward(self, x):
        out = self.net(x)
        return out
    
class LSTMClassV2(nn.Module):
    """Two-path classifier head: LSTM features plus directly flattened features.

    Input ``x`` has shape (batch, input_size, seq_len). This follows from the
    ``permute(2, 0, 1)`` into a seq-first ``nn.LSTM`` and the
    ``Linear(input_size * seq_len, ...)`` applied after ``Flatten``. Each path
    is reduced to 256 features; the two are concatenated (LSTM first) and
    mapped through a sigmoid to ``n_labels`` outputs in [0, 1].

    NOTE(review): the initial hidden and cell states are drawn with
    ``torch.randn`` on every call, so the output is stochastic even in eval
    mode and under ``no_grad``. Confirm this is intended.

    Args:
        n_labels: number of output labels.
        input_size: feature (channel) count per time step.
        hidden_size: LSTM hidden size.
        seq_len: number of time steps.
        parameters: unused.
    """

    def __init__(self, n_labels, input_size, hidden_size, seq_len, parameters=None):
        super().__init__()
        self.lstm_hid_size = hidden_size
        self.lstm_in_size = input_size
        self.seq_len = seq_len
        self.n_labels = n_labels
        self.lstm = nn.LSTM(input_size, hidden_size=hidden_size)
        # Bypass path that skips the LSTM entirely.
        self.fcConvFeatures = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_size * seq_len, 1000),
            nn.ReLU(),
            nn.Linear(1000, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
        )
        self.fcLSTMFeatures = nn.Sequential(
            nn.Linear(hidden_size * seq_len, 1000),
            nn.ReLU(),
            nn.Linear(1000, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
        )
        self.output = nn.Sequential(nn.Linear(2 * 256, n_labels), nn.Sigmoid())

    def forward(self, x):
        device = x.device
        batch = x.shape[0]
        h0 = torch.randn(1, batch, self.lstm_hid_size).to(device)
        c0 = torch.randn(1, batch, self.lstm_hid_size).to(device)
        lstm_out, _ = self.lstm(x.permute(2, 0, 1), (h0, c0))
        # (seq, batch, hidden) -> (batch, seq * hidden)
        lstm_flat = lstm_out.permute(1, 0, 2).reshape(batch, -1)
        lstm_feats = self.fcLSTMFeatures(lstm_flat)
        direct_feats = self.fcConvFeatures(x)
        combined = torch.cat((lstm_feats, direct_feats), 1)
        return self.output(combined)
    
class LSTMClassV2Deep(nn.Module):
    """``LSTMClassV2`` with a 2-layer LSTM.

    Input ``x`` has shape (batch, input_size, seq_len) (implied by the
    seq-first permute into ``nn.LSTM`` and the flattened
    ``Linear(input_size * seq_len, ...)``). The LSTM path and the direct
    flattened path each yield 256 features. They are concatenated (LSTM
    first) and passed through a sigmoid to give ``n_labels`` outputs in [0, 1].

    NOTE(review): initial hidden and cell states come from ``torch.randn`` on
    every call, so outputs are stochastic even in eval mode. Confirm intent.

    Args:
        n_labels: number of output labels.
        input_size: feature (channel) count per time step.
        hidden_size: LSTM hidden size.
        seq_len: number of time steps.
        parameters: unused.
    """

    def __init__(self, n_labels, input_size, hidden_size, seq_len, parameters=None):
        super().__init__()
        self.lstm_hid_size = hidden_size
        self.lstm_in_size = input_size
        self.seq_len = seq_len
        self.n_labels = n_labels
        self.num_layer = 2
        self.lstm = nn.LSTM(input_size, hidden_size=hidden_size, num_layers=self.num_layer)
        # Bypass path that skips the LSTM entirely.
        self.fcConvFeatures = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_size * seq_len, 1000),
            nn.ReLU(),
            nn.Linear(1000, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
        )
        self.fcLSTMFeatures = nn.Sequential(
            nn.Linear(hidden_size * seq_len, 1000),
            nn.ReLU(),
            nn.Linear(1000, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
        )
        self.output = nn.Sequential(nn.Linear(2 * 256, n_labels), nn.Sigmoid())

    def forward(self, x):
        device = x.device
        batch = x.shape[0]
        state_shape = (self.num_layer, batch, self.lstm_hid_size)
        h0 = torch.randn(state_shape).to(device)
        c0 = torch.randn(state_shape).to(device)
        lstm_out, _ = self.lstm(x.permute(2, 0, 1), (h0, c0))
        # (seq, batch, hidden) -> (batch, seq * hidden)
        lstm_flat = lstm_out.permute(1, 0, 2).reshape(batch, -1)
        lstm_feats = self.fcLSTMFeatures(lstm_flat)
        direct_feats = self.fcConvFeatures(x)
        combined = torch.cat((lstm_feats, direct_feats), 1)
        return self.output(combined)

class LSTMClass(nn.Module):
    """Single-path classifier head: LSTM, two dense layers, sigmoid output.

    Input ``x`` has shape (batch, input_size, seq_len): it is permuted to the
    seq-first layout expected by ``nn.LSTM``. All time steps of the LSTM
    output are flattened into ``Linear(hidden_size * seq_len, ...)``. Returns
    ``n_labels`` values in [0, 1] per sample.

    NOTE(review): the initial hidden and cell states are drawn with
    ``torch.randn`` on every call, so outputs are stochastic even in eval
    mode. Confirm intent.

    Args:
        n_labels: number of output labels.
        input_size: feature (channel) count per time step.
        hidden_size: LSTM hidden size.
        seq_len: number of time steps.
        parameters: unused.
    """

    def __init__(self, n_labels, input_size, hidden_size, seq_len, parameters=None):
        super().__init__()
        self.lstm_hid_size = hidden_size
        self.lstm_in_size = input_size
        self.seq_len = seq_len
        self.n_labels = n_labels
        self.lstm = nn.LSTM(input_size, hidden_size=hidden_size)
        self.fc1 = nn.Sequential(
            nn.Linear(hidden_size * seq_len, 1000),
            nn.ReLU(),
            nn.Linear(1000, 256),
            nn.ReLU(),
        )
        self.output = nn.Sequential(nn.Linear(256, n_labels), nn.Sigmoid())

    def forward(self, x):
        device = x.device
        batch = x.shape[0]
        h0 = torch.randn(1, batch, self.lstm_hid_size).to(device)
        c0 = torch.randn(1, batch, self.lstm_hid_size).to(device)
        lstm_out, _ = self.lstm(x.permute(2, 0, 1), (h0, c0))
        # (seq, batch, hidden) -> (batch, seq * hidden)
        flat = lstm_out.permute(1, 0, 2).reshape(batch, -1)
        return self.output(self.fc1(flat))



class DualModelWrapper:
    """Train and run a model that returns ``(signal, labels)``.

    The wrapped model maps an input batch ``X`` to a tuple ``(signal, labels)``.
    Training minimises
    ``loss1_weight * loss1(signal, X) + (1 - loss1_weight) * loss2(labels, Y)``.

    Args:
        model: ``nn.Module`` returning ``(signal, labels)``.
        loss1_weight: weight of the signal loss; the label loss is weighted
            by ``1 - loss1_weight``.
        optimizer: optimizer over the model's parameters.
        loss1: signal loss; a fresh ``nn.MSELoss()`` when omitted.
        loss2: label loss; a fresh ``nn.MSELoss()`` when omitted.
        scheduler: optional LR scheduler, stepped once per epoch.
    """

    def __init__(self, model, loss1_weight, optimizer, loss1=None, loss2=None, scheduler=None):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        # Build per-instance losses instead of sharing one default instance
        # created at class-definition time.
        self.loss1 = nn.MSELoss() if loss1 is None else loss1
        self.loss2 = nn.MSELoss() if loss2 is None else loss2
        self.loss1_weight = loss1_weight
        self.history = list()

    @staticmethod
    def _select_device(gpu):
        """Return the CUDA device when requested and available, else the CPU."""
        if gpu and torch.cuda.is_available():
            return torch.device('cuda')
        return torch.device('cpu')

    def _batch_loss(self, X, Y, device):
        """Prepare one batch, run the model and return the combined loss tensor."""
        # 3-D label batches are squeezed on axis 1 so that they line up with
        # the model's label output; applied to training and validation alike.
        if len(Y.shape) == 3:
            Y = Y.squeeze(1)
        X = X.float().to(device)
        Y = Y.float().to(device)
        signal, labels = self.model(X)
        return (self.loss1_weight * self.loss1(signal, X)
                + (1 - self.loss1_weight) * self.loss2(labels, Y))

    def train(self, epochs, dataloader_train, dataloader_val, gpu=False):
        """Train for ``epochs`` epochs and restore the best validation weights.

        Args:
            epochs: number of passes over ``dataloader_train``.
            dataloader_train: iterable of ``(X, Y)`` batches exposing
                ``.dataset`` with a length.
            dataloader_val: same, used for validation after every epoch.
            gpu: use CUDA if it is available.

        On return the model holds the weights of the epoch with the lowest
        validation loss. The per-epoch losses are appended to ``self.history``.
        Reported losses are the sum of per-batch loss values divided by the
        number of samples in the dataset, not by the number of batches.
        """
        device = self._select_device(gpu)
        self.model.to(device)

        best_model_wts = copy.deepcopy(self.model.state_dict())
        best_loss = float('inf')
        train_size = len(dataloader_train.dataset)
        val_size = len(dataloader_val.dataset)
        train_losses = list()
        val_losses = list()
        for ep in range(epochs):
            self.model.train()  # training mode (dropout, batch-norm statistics)
            train_loss = 0
            for X, Y in dataloader_train:
                loss = self._batch_loss(X, Y, device)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                train_loss += loss.item()
            train_loss /= train_size
            train_losses.append(train_loss)
            if self.scheduler is not None:
                self.scheduler.step()

            self.model.eval()
            val_loss = 0
            with torch.no_grad():
                for X, Y in dataloader_val:
                    val_loss += self._batch_loss(X, Y, device).item()
            val_loss /= val_size
            val_losses.append(val_loss)

            # Track the best loss so only genuinely better epochs replace the
            # checkpoint.
            if val_loss < best_loss:
                best_loss = val_loss
                best_model_wts = copy.deepcopy(self.model.state_dict())
            print('Epoch', ep, 'Train_loss', train_loss, 'Val_loss', val_loss)
        self.model.load_state_dict(best_model_wts)
        self.history.append({'loss_train': train_losses, 'loss_val': val_losses})

    def predict(self, X, gpu=False):
        """Run the model on ``X`` in eval mode without gradient tracking.

        Inputs with fewer than 3 dimensions are reshaped to ``(1, 1, -1)``.
        The input is cast to float32, as in ``train``. Returns the model's
        ``(signal, labels)`` on the selected device.
        """
        self.model.eval()
        device = self._select_device(gpu)
        self.model.to(device)
        if len(X.shape) < 3:
            X = X.reshape(1, 1, -1)
        X = X.float().to(device)
        with torch.no_grad():
            signal, labels = self.model(X)
        return signal, labels

class SmallDataset(Dataset):
    """Map-style dataset over an in-memory, indexable collection of samples.

    ``data[idx]`` is returned unchanged for each index. Given how batches are
    unpacked in training, the samples are presumably ``(X, Y)`` pairs; confirm
    against callers.
    """

    def __init__(self, data):
        self.data = data
        # Length is captured once; later changes to ``data`` are not reflected.
        self.datalen = len(data)

    def __len__(self):
        return self.datalen

    def __getitem__(self, idx):
        return self.data[idx]

# TODO: weight classes according to their occurrence.
# y has shape [batch, classes].
def MHotLoss(y_tilde, y):
    """Mean squared error between multi-hot targets and rescaled predictions.

    Each row of ``y_tilde`` is divided by its own sum and multiplied by the
    row sum of ``y`` (the number of active labels for a 0/1 target). Thus the
    prediction carries the same total mass as the target before the squared
    error is taken.

    Args:
        y_tilde: predictions, shape (batch, classes); rows must not sum to 0.
        y: multi-hot targets, shape (batch, classes).

    Returns:
        Scalar tensor: mean of the element-wise squared error.
    """
    active_per_row = y.sum(dim=1, keepdim=True)
    row_totals = y_tilde.sum(dim=1, keepdim=True)
    rescaled = y_tilde / row_totals * active_per_row
    return ((y - rescaled) ** 2).mean()

class MHotLoss_w:
    """Class-weighted multi-hot squared-error loss.

    Like ``MHotLoss``: each prediction row is normalised to sum to 1, then
    scaled by the row sum of the target. The element-wise squared error is
    multiplied by a per-class weight before averaging.

    Args:
        W: per-class weights as a tensor of shape (1, classes) or (classes,).
            ``None`` means uniform weights of 1.
    """

    def __init__(self, W=None):
        self.W = W

    def __call__(self, y_tilde, y):
        """Return the weighted loss for predictions ``y_tilde`` and targets ``y``.

        Both are (batch, classes) tensors. Raises ``ValueError`` when ``W``
        does not have one weight per class.
        """
        n_classes = y.shape[1]
        W = self.W
        if W is None:
            W = torch.ones((1, n_classes))
        else:
            if W.dim() == 1:
                W = W.reshape(1, -1)  # accept a plain weight vector
            if W.dim() != 2 or W.shape[0] != 1 or W.shape[1] != n_classes:
                raise ValueError(
                    f"W must have shape (1, {n_classes}) or ({n_classes},), "
                    f"got {tuple(self.W.shape)}")
        W = W.to(y.device)
        s = y.sum(axis=1, keepdim=True)
        y_hat = y_tilde / y_tilde.sum(axis=1, keepdim=True) * s
        return torch.mean(W * (y - y_hat) ** 2)
    
if __name__ == '__main__':
    # Smoke run: a random batch of 16 two-channel sequences of length 2048.
    ins = torch.rand((16, 2, 2048))
    model = DualAEClassificationV2Deep(n_labels=2, seq_len=2048, in_channels=2,
                                       lstm_h_size=256, y_multiplier=5)
    outputs = model(ins)
    
    