# This file collects currently unused (legacy) code kept for reference.

import numpy as np
import torch
from torch.utils.data import Dataset
from helper_code import load_header, get_labels, load_recording, get_age, get_sex, get_leads, get_adcgains,  get_baselines #, twelve_leads
class CustomDataset(Dataset):
    """Map-style dataset yielding one ECG recording, its demographics and labels per index.

    Each item is a tuple ``(data, attachedpro, labels)``:

    * ``data`` -- float32 tensor of shape ``(num_leads, len_signal)``, or
      ``(num_patches, num_leads, len_signal)`` when ``multi_patch`` is True.
    * ``attachedpro`` -- float32 tensor of 5 demographic features
      (see ``_encode_demographics``).
    * ``labels`` -- float32 multi-hot vector over ``classes``, subset to
      ``eva_class_indices``.

    NOTE(review): with ``multi_patch=True`` the number of patches depends on the
    recording length, so the default DataLoader collate can only stack samples
    that produce the same number of patches.
    """

    def __init__(self, header_files, recording_files, classes, leads, len_signal, eva_class_indices, multi_patch=False, len_overlap=256, transform=None, target_transform=None):
        """Store file lists and preprocessing configuration.

        Args:
            header_files: per-sample header paths, passed to ``load_header``.
            recording_files: per-sample recording paths, passed to ``load_recording``
                (same order as ``header_files``).
            classes: list of label codes; position defines the label column.
            leads: lead names to select, in output row order.
            len_signal: number of samples per output window.
            eva_class_indices: indices into ``classes`` kept in the returned labels.
            multi_patch: if True, split each recording into overlapping windows.
            len_overlap: overlap (in samples) between consecutive windows.
            transform, target_transform: stored only; not applied by this class.
        """
        self.header_files = header_files
        self.recording_files = recording_files
        self.transform = transform
        self.target_transform = target_transform
        self.classes = classes
        self.num_classes = len(classes)
        self.leads = leads
        self.multi_patch = multi_patch
        self.len_signal = len_signal
        self.len_overlap = len_overlap
        self.eva_class_indices = eva_class_indices
        # Label -> column lookup built once; setdefault keeps the first occurrence,
        # matching the semantics of ``list.index``.
        self._class_to_index = {}
        for j, label in enumerate(classes):
            self._class_to_index.setdefault(label, j)

    def __len__(self):
        """Return the number of samples (one per header file)."""
        return len(self.header_files)

    @staticmethod
    def _encode_demographics(age, sex):
        """Encode age and sex as a 5-element float32 tensor.

        Index 0: the age when it lies in [0, 120], otherwise 0.
        Index 1: 1 when the age is missing (None/NaN) or outside [0, 120].
        Indices 2/3/4: one-hot sex -- female / male / anything else.
        """
        attachedpro = torch.zeros(5, dtype=torch.float32)
        # NaN fails both comparisons, so it is flagged like a missing age.
        if age is not None and 0 <= age <= 120:
            attachedpro[0] = age
        else:
            attachedpro[1] = 1

        if sex in ('Female', 'female', 'F', 'f'):
            attachedpro[2] = 1
        elif sex in ('Male', 'male', 'M', 'm'):
            attachedpro[3] = 1
        else:
            attachedpro[4] = 1
        return attachedpro

    @staticmethod
    def _crop_or_pad(recording, len_signal):
        """Return a float32 ``(num_leads, len_signal)`` copy of ``recording``.

        Longer recordings keep their first ``len_signal`` samples; shorter ones
        are zero-padded at the end.
        """
        num_leads, len_recording = recording.shape
        data = torch.zeros((num_leads, len_signal), dtype=torch.float32)
        n = min(len_recording, len_signal)
        data[:, :n] = recording[:, :n]
        return data

    @staticmethod
    def _split_into_patches(recording, len_signal, len_overlap):
        """Split ``recording`` into overlapping float32 windows of ``len_signal`` samples.

        Windows start every ``len_signal - len_overlap`` samples. A final window
        is always aligned to the end of the recording. A recording no longer than
        ``len_signal`` yields a single zero-padded window.

        Returns:
            Tensor of shape ``(num_patches, num_leads, len_signal)``.

        Raises:
            ValueError: if ``len_overlap >= len_signal`` (the window would never advance).
        """
        step = len_signal - len_overlap
        if step <= 0:
            raise ValueError(
                f"len_overlap ({len_overlap}) must be smaller than len_signal ({len_signal})")
        num_leads, len_recording = recording.shape
        if len_recording <= len_signal:
            # Always return at least one (padded) window, even for very short recordings.
            data = torch.zeros((1, num_leads, len_signal), dtype=torch.float32)
            data[0, :, :len_recording] = recording
            return data

        starts = list(range(0, len_recording - len_signal, step))
        starts.append(len_recording - len_signal)  # tail window aligned to the end
        data = torch.zeros((len(starts), num_leads, len_signal), dtype=torch.float32)
        for k, start in enumerate(starts):
            data[k] = recording[:, start:start + len_signal]
        return data

    def __getitem__(self, idx):
        """Load, normalize and window sample ``idx``; return ``(data, attachedpro, labels)``."""
        header = load_header(self.header_files[idx])
        recording = np.asarray(load_recording(self.recording_files[idx]))

        attachedpro = self._encode_demographics(get_age(header), get_sex(header))

        # Reorder/reselect leads so rows follow self.leads (ValueError if a lead is absent).
        available_leads = get_leads(header)
        indices = [available_leads.index(lead) for lead in self.leads]
        recording = recording[indices, :]

        # Per-lead normalization: (raw - baseline) / adc_gain, then a single cast to float32.
        adc_gains = np.asarray(get_adcgains(header, self.leads), dtype=np.float64)
        baselines = np.asarray(get_baselines(header, self.leads), dtype=np.float64)
        recording = (recording - baselines[:, None]) / adc_gains[:, None]
        recording = torch.as_tensor(recording, dtype=torch.float32)

        if self.multi_patch:
            data = self._split_into_patches(recording, self.len_signal, self.len_overlap)
        else:
            data = self._crop_or_pad(recording, self.len_signal)

        # Only count target classes; labels not listed in self.classes are ignored.
        this_labels = torch.zeros(self.num_classes, dtype=torch.float)
        for label in get_labels(header):
            j = self._class_to_index.get(label)
            if j is not None:
                this_labels[j] = 1
        labels = this_labels[self.eva_class_indices]

        return data, attachedpro, labels



def train_model(leads, recording_files, header_files, classes, weights, eva_class_indices, eva_classes, normal_class, filename, num_epochs, device, writer, len_signal, multi_patch=False, len_overlap=256):
    """Train a ResNet-18 ECG classifier with SGD and save it to ``filename``.

    Args:
        leads: lead names to use; their count sets the model's input channels.
        recording_files, header_files: per-sample file paths for ``CustomDataset``.
        classes: all label codes; ``eva_class_indices`` selects the trained subset.
        weights, eva_classes, normal_class: forwarded to the challenge loss/metric helpers.
        filename: destination passed to ``save_model``.
        num_epochs: number of passes over the data.
        device: torch device; the string ``'cuda'`` additionally enables DataParallel.
        writer: logger exposing ``add_scalar(tag, value, step)``
            (presumably a TensorBoard SummaryWriter -- confirm).
        len_signal, multi_patch, len_overlap: windowing options for ``CustomDataset``.
    """
    num_leads = len(leads)
    batch_size = 64
    tag_prefix = "lead{}/train/alldata".format(num_leads)

    train_data = CustomDataset(header_files, recording_files, classes, leads, len_signal, eva_class_indices, multi_patch, len_overlap)
    train_iter = torch.utils.data.DataLoader(train_data, batch_size, shuffle=True)

    classifier = resnet18(in_channel=num_leads)
    print('# classifier total parameters:', sum(param.numel() for param in classifier.parameters()))
    net = classifier.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)

    # The bare name `optim` is not imported in this module; use torch.optim directly.
    optimizer = torch.optim.SGD(net.parameters(), lr=3e-2,
                                momentum=0.9, weight_decay=5e-4)
    # NOTE(review): T_max is fixed at 200 scheduler steps regardless of num_epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

    num_batches = len(train_iter)
    for epoch in range(num_epochs):
        print('\nEpoch: %d' % epoch)
        net.train()
        train_loss = 0
        for batch_idx, (inputs, ags, targets) in enumerate(train_iter):
            global_step = epoch * num_batches + batch_idx
            inputs, ags, targets = inputs.to(device), ags.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = net(inputs, ags)
            loss = compute_challenge_loss(weights, targets, outputs, eva_classes, eva_class_indices, normal_class)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            train_loss += loss_value
            writer.add_scalar(tag_prefix + "/challenge_loss", loss_value, global_step)

            # Per-batch diagnostics computed on CPU copies of the outputs/targets.
            scalar_outputs = outputs.cpu().detach().numpy()
            ilabels = targets.cpu().detach().numpy()
            auroc, auprc, auroc_classes, auprc_classes = compute_auc(ilabels, scalar_outputs)
            writer.add_scalar(tag_prefix + "/auroc", auroc, global_step)

            # NOTE(review): a 0.5 threshold assumes outputs are probabilities, not
            # logits -- confirm what resnet18 returns.
            binary_outputs = np.zeros(scalar_outputs.shape)
            binary_outputs[scalar_outputs > 0.5] = 1
            challenge_metric = compute_challenge_metric(weights, ilabels, binary_outputs, eva_classes, normal_class)
            writer.add_scalar(tag_prefix + "/challenge_score", challenge_metric, global_step)

            if batch_idx % 100 == 1:
                print(batch_idx, num_batches, 'Loss: %.3f ' % (train_loss / (batch_idx + 1)))
        scheduler.step()

    # `classifier` is the unwrapped module, so DataParallel is never saved.
    save_model(filename, eva_classes, leads, classifier, imputer=None)


