#!/usr/bin/env python

# Edit this script to add your team's training code.
# Some functions are *required*, but you can edit most parts of required functions, remove non-required functions, and add your own function.

from abc import abstractclassmethod
import numpy as np, os, sys, joblib, sklearn
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold
import pdb

from utils.helper_code import format_time, twelve_leads, six_leads, four_leads, three_leads, two_leads, train_val_dataset, find_challenge_files, load_header, load_recording, parse_labels
from utils.sampler import MultilabelBalancedRandomSampler
from utils.data_process import get_features, process_data
from utils.metrics import compute_confusion_matrices, compute_challenge_loss, compute_challenge_metric, compute_challenge_metric_fast, compute_auc
from model_architecture.res_models import ResNet, resnet18, resnet34, resnet50, resnet101, resnet152
from model_architecture import res_models
from model_architecture.efficientnet_1D import select_model

# File names for the per-lead-configuration models (12/6/4/3/2 leads).
# NOTE(review): not referenced in the visible part of this file — presumably
# consumed by the save/load helpers further down; confirm.
twelve_lead_model_filename = '12_lead_model.sav'
six_lead_model_filename = '6_lead_model.sav'
four_lead_model_filename = '4_lead_model.sav'
three_lead_model_filename = '3_lead_model.sav'
two_lead_model_filename = '2_lead_model.sav'

import sys
################################################################################
#
# Training function
#
################################################################################
try:
    import nni
    from nni.utils import merge_parameter
except:
    pass
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import math
import torch
import torch.optim as optim
import torch.nn.functional as F
import sys
import time
import argparse, os
from torch.utils.data import Dataset
from tensorboardX import SummaryWriter
import wfdb
from wfdb import processing
import scipy
# eliminate randomness: seed the Python, NumPy and PyTorch RNGs once, at import time.
# NOTE(review): cuDNN determinism flags (torch.backends.cudnn.*) are not set
# here, so GPU runs may still differ slightly between executions — confirm
# whether bit-exact reproducibility is required.
import random
seed = 2021
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Run identifier (month-day-hourminutesecond); set_args uses it as the
# tensorboard log sub-directory name.
time_stamp = time.strftime("%m%d-%H%M%S", time.localtime())
# -----------------------input size>=32---------------------------------
# from efficientnet_pytorch import EfficientNet
# Names exported by `from <this module> import *`: only the re-exported ResNet
# constructors, not the training functions defined below.
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


# _, term_width = os.popen('stty size', 'r').read().split()
# term_width = int(term_width)

# TOTAL_BAR_LENGTH = 65.
# last_time = time.time()
# begin_time = last_time



def set_args(print_details=False):
    """Build the run configuration shared by the training/evaluation functions.

    Side effects: creates the ``logdir`` directory when it is missing and, if
    ``print_details`` is True, opens a tensorboardX ``SummaryWriter`` in
    ``logdir/<time_stamp>``.

    Args:
        print_details: attach a SummaryWriter as ``params["writer"]``; when
            False the writer is None and nothing is logged to tensorboard.

    Returns:
        dict: hyper-parameters and run switches (see the inline comments).
    """
    logdir = "logdir"
    # exist_ok replaces the racy "check, then create" pattern.
    os.makedirs(logdir, exist_ok=True)

    # Use the GPU when present; fall back to CPU instead of failing later at
    # the first `.to(device)` on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    writer = SummaryWriter(log_dir=os.path.join(logdir, time_stamp)) if print_details else None

    params = {
        "writer": writer, "device": device, "batch_size": 64, "test_batch_size": 1000,
        "mixup_alpha": 0.15, "random_seed": None,

        "len_signal": 4096,  # when changing this value, the one in run_model() should be changed as well

        # if want to set bce_cof, must set bce_penalty=True
        # NOTE(review): train_whole_model in this file always applies bce_cof
        # and does not read bce_penalty.
        "bce_penalty": True, "bce_cof": 1.0, "challenge_cof": 0.1, "threshold": 0.1,

        "pre_train_12leads": False,

        "multi_patch": False,  # will affect test data; won't affect train data since the multi-patch implementation has been commented
        "Othertypes": False,
        "multi_val": True,
        # weight_decay and momentum are only passed to SGD.
        "num_epochs": 50, "lr": 3e-3, "weight_decay": 5e-4, "momentum": 0.9, "is_nni": False,
        "optimizer": "Adam",  # choice: ["Adam", "SGD"]

        "preprocess": True,  # to use preprocessed file or preprocess from scratch
        "save_preprocessed_data": False,

        # "split_train_val" and "is_test" are mutually exclusive
        "split_train_val": True,  # whether to split train_dataset into trainset and validation set
        "is_test": False,  # whether to evaluate in new_testing data
        "test_file": "new_testing_data",

        "sampler": None,  # "MultilabelBalancedRandomSampler", # whether to use a sampler for imbalance data in Dataloader

        "gpu_ids": "0",
        "add_domain_knowledge": False,

        "clf_nn": "resnet18",  # "resnet18","efficientnet-b4"
        "kernel_size": 11, "alpha": 1, "beta": 1, "phi": 1,

        "is_sigmoid_in_resnet": True,

        "train_mode": "train_whole_model",  # have several options: ["train_whole_model", "train_test_model", "train_test_model_mixup", "train_model_mixup"]
    }

    return params

def searchthres_global(concat_outputs, concat_targets, weights, eva_classes, normal_class):
    """Choose one decision threshold shared by every class.

    Each candidate 0.0, 0.1, ..., 0.9 binarises ``concat_outputs`` (score
    strictly greater than the candidate -> 1.0, else 0.0). The result is
    scored with ``compute_challenge_metric_fast``. Ties go to the smallest
    candidate, because ``np.argmax`` returns the first maximum.

    Returns:
        (best_threshold, best_score)
    """
    candidates = np.arange(10) / 10
    scores = []
    for candidate in candidates:
        predictions = np.where(concat_outputs > candidate, 1.0, 0.0)
        score = compute_challenge_metric_fast(weights, concat_targets, predictions, eva_classes, normal_class)
        scores.append(score.cpu().numpy())
    winner = np.argmax(scores)
    best_threshold, best_score = candidates[winner], scores[winner]
    print("Best threshold is {:.2f} with score {:.5f}".format(best_threshold, best_score))
    return best_threshold, best_score


def searchthres(tot_scalar_outputs, tot_labels, weights, eva_classes, normal_class, global_threshold):
    """Greedy per-class threshold search that maximises the challenge metric.

    Classes (columns) are visited in order. While class ``i`` is tuned, the
    candidates 0.01, 0.02, ..., 0.99 are tried. Classes already visited use
    their tuned threshold, and classes not yet visited use
    ``global_threshold``. The first candidate with the highest score wins.
    If no candidate scores above 0, the class keeps threshold 0.0.

    Args:
        tot_scalar_outputs: 2-D array of scores, one column per class.
        tot_labels: ground-truth labels aligned with ``tot_scalar_outputs``.
        weights, eva_classes, normal_class: forwarded to
            ``compute_challenge_metric_fast``.
        global_threshold: starting threshold for classes not yet tuned.

    Returns:
        (thresholds, challenge_metric): one threshold per column, and the score
        obtained when all of them are applied.
    """
    num_classes = tot_scalar_outputs.shape[1]
    thresholds = []
    # Predictions reflecting the current search state: tuned columns use their
    # own threshold, the rest use the global one. Updated once per class
    # instead of being rebuilt for every candidate threshold.
    base_outputs = tot_scalar_outputs > global_threshold
    for i in range(num_classes):
        max_challenge_metric = 0.
        maxthres = 0.
        for j in range(99):
            jthreshold = 0.01 + j * 0.01
            binary_outputs = base_outputs.copy()
            binary_outputs[:, i] = tot_scalar_outputs[:, i] > jthreshold
            jchallenge_metric = compute_challenge_metric_fast(weights, tot_labels, binary_outputs, eva_classes, normal_class)
            if jchallenge_metric > max_challenge_metric:
                maxthres = jthreshold
                max_challenge_metric = jchallenge_metric
        thresholds.append(maxthres)
        base_outputs[:, i] = tot_scalar_outputs[:, i] > maxthres

    binary_outputs = np.zeros(tot_scalar_outputs.shape)
    for i, thres in enumerate(thresholds):
        binary_outputs[:, i] = tot_scalar_outputs[:, i] > thres
    challenge_metric = compute_challenge_metric_fast(weights, tot_labels, binary_outputs, eva_classes, normal_class)
    print('searched challenge_metric :', challenge_metric)
    return thresholds, challenge_metric


class mydataset(Dataset):
    """Turn variable-length multi-lead recordings into fixed-length model inputs.

    ``__getitem__`` returns ``(out, iattachpro, ilabel)``:

    * ``multi_patch=False``: ``out`` is one ``(num_leads, len_signal)`` window.
      A shorter recording is zero-padded at a random offset; a longer one is
      randomly cropped.
    * ``multi_patch=True``: ``out`` is ``(num_patch, num_leads, len_signal)``,
      a stack of overlapping windows covering the whole recording.
      ``iattachpro`` is expanded to ``(num_patch, -1)`` to match.

    Args:
        data: indexable collection of recordings, each indexed as
            ``rec[feature_indices, :]`` (rows = leads, columns = samples).
            The perturbation code uses torch ops, so recordings are presumably
            torch tensors — confirm for new callers.
        labels: per-recording labels, returned unchanged.
        attachedpro: per-recording extra features, returned with each item.
        num_leads: number of rows in each output window.
        len_signal: window length in samples.
        feature_indices: rows (leads) selected from every recording.
        multi_patch: return all overlapping windows instead of one random one.
        perturbations: enable random sample-level augmentation.
    """

    def __init__(self, data, labels, attachedpro, num_leads, len_signal, feature_indices, multi_patch=False, perturbations=False):
        self.data = data
        self.attachpro = attachedpro
        self.labels = labels
        self.len = len(data)
        self.num_leads = num_leads
        self.len_signal = len_signal
        self.feature_indices = feature_indices
        self.perturbations = perturbations
        # number of random sample positions drawn per perturbation pass
        self.perturbations_num = 3
        # multi_patch default as False, set to be True for validation set
        self.multi_patch = multi_patch
        # overlap (in samples) between consecutive multi-patch windows;
        # must be smaller than len_signal so the window step is positive
        self.len_overlap = 256

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        # With a list of indices this is advanced indexing, which copies, so
        # the in-place perturbations below do not touch self.data.
        idata = self.data[index][self.feature_indices, :]
        iattachpro = self.attachpro[index]
        ilabel = self.labels[index]

        idata = self._perturb(idata)
        # Measured after perturbation, which may have inserted samples, so the
        # output buffers below are always sized for the final recording.
        len_recording = len(idata[0])

        if not self.multi_patch:
            return self._single_window(idata, len_recording), iattachpro, ilabel

        out = self._multi_patch_windows(idata, len_recording)
        return out, iattachpro.expand(out.shape[0], -1), ilabel

    def _perturb(self, idata):
        """Randomly augment ``idata``; returns the (possibly new) tensor.

        One uniform draw picks the branch. Below 0.1, up to
        ``perturbations_num`` random positions are each zeroed (p=0.3), given
        an interpolated sample inserted before them (p=0.4), or left alone.
        Above 0.9, each lead whose variance is at least 0.001 is
        Gaussian-smoothed with sigma equal to that variance. The draw is made
        even when perturbations are disabled.
        """
        adice = np.random.random()
        if not self.perturbations:
            return idata

        len_recording = len(idata[0])
        if adice < 0.1:
            perturb_ticks = np.random.choice(len_recording, self.perturbations_num) + 1
            for tick in perturb_ticks:
                tick = min(tick, len_recording - 1)
                dice = np.random.random()
                if dice < 0.3:
                    idata[:, tick] = 0.
                elif dice < 0.7:
                    # insert the mean of samples tick-1 and tick between them
                    insertv = (idata[:, tick - 1] + idata[:, tick]) / 2
                    idata = torch.cat((idata[:, :tick], insertv.unsqueeze(1), idata[:, tick:]), dim=1)
                    len_recording += 1
        elif adice > 0.9:
            for i in range(self.num_leads):
                var = torch.var(idata[i])
                if abs(var) < 0.001:
                    # nearly flat lead: leave it untouched
                    continue
                idata[i] = torch.from_numpy(scipy.ndimage.gaussian_filter1d(idata[i].numpy(), var.numpy()))
        return idata

    def _single_window(self, idata, len_recording):
        """Return one ``(num_leads, len_signal)`` window (random pad or random crop)."""
        out = torch.zeros((self.num_leads, self.len_signal))
        if len_recording <= self.len_signal:
            # randint's upper bound is exclusive: +1 makes end == len_signal
            # (recording flush with the right edge) reachable.
            end = np.random.randint(len_recording, self.len_signal + 1)
            out[:, end - len_recording:end] = idata
        else:
            # +1 so the crop ending on the last recorded sample is reachable.
            end = np.random.randint(self.len_signal, len_recording + 1)
            out[:, :] = idata[:, end - self.len_signal:end]
        return out

    def _multi_patch_windows(self, idata, len_recording):
        """Return ``(num_patch, num_leads, len_signal)`` windows covering the recording.

        Windows start every ``len_signal - len_overlap`` samples, and the last
        one is aligned to the end of the recording. A recording no longer than
        ``len_signal`` yields exactly one zero-padded, randomly placed window.
        """
        step = self.len_signal - self.len_overlap
        # For recordings of len_overlap samples or fewer the ceil term is <= -1;
        # clamp so there is always at least one patch to write into.
        num_patch = max(1, int(np.ceil((len_recording - self.len_signal) / step)) + 1)
        out = torch.zeros((num_patch, self.num_leads, self.len_signal))
        if len_recording <= self.len_signal:
            end = np.random.randint(len_recording, self.len_signal + 1)
            out[0, :, end - len_recording:end] = idata
        else:
            cnt = 0
            start = 0
            while (len_recording - start) > self.len_signal:
                out[cnt] = idata[:, start:start + self.len_signal]
                cnt += 1
                start += step
            out[cnt] = idata[:, len_recording - self.len_signal:len_recording]
        return out

def test_new_testing_data_faster(args, net, weights, test_file="new_testing_data", threshold=0.5):
    """Score ``net`` on ``test_file`` with batched inference and one fixed threshold.

    The data is loaded through ``process_data`` and the network is run in eval
    mode without gradients. Scores strictly above ``threshold`` become 1, and
    the result is evaluated with ``compute_challenge_metric_fast``.

    Args:
        args: config dict; reads 'test_batch_size', 'device', 'eva_classes'
            and 'normal_class'.
        net: model called as ``net(inputs, ags)``; may return a single tensor
            or an ``(out1, out2)`` tuple.
        weights: challenge scoring weights.
        test_file: dataset location handed to ``process_data``.
        threshold: decision threshold applied to every class (default 0.5).

    Returns:
        The value returned by ``compute_challenge_metric_fast``.
    """
    net.eval()
    # @Todo, process_data is problematic
    features, labels, attachedpro, _ = process_data(args, test_file)
    val_dataset = torch.utils.data.TensorDataset(torch.Tensor(features), torch.Tensor(attachedpro), torch.Tensor(labels))
    val_iter = torch.utils.data.DataLoader(val_dataset, args['test_batch_size'], shuffle=False)

    tot_labels = []
    tot_binary_outputs = []
    with torch.no_grad():
        for inputs, ags, targets in val_iter:
            inputs, ags = inputs.to(args['device']), ags.to(args['device'])
            outputs = net(inputs, ags)
            # The two-head networks trained in this file return (out1, out2).
            # Use the second (challenge-loss) head, which is the one
            # test_from_train_split searches thresholds on.
            if isinstance(outputs, (tuple, list)):
                outputs = outputs[-1]
            scalar_outputs = outputs.cpu().numpy()
            tot_binary_outputs.append((scalar_outputs > threshold).astype(np.float64))
            tot_labels.append(targets.numpy())

        tot_labels = np.concatenate(tot_labels, axis=0)
        tot_binary_outputs = np.concatenate(tot_binary_outputs, axis=0)
        challenge_metric = compute_challenge_metric_fast(weights, tot_labels, tot_binary_outputs, args['eva_classes'], args['normal_class'])
    return challenge_metric



def test_new_testing_data(args, net, weights, leads, searched_thresholds, test_file="new_testing_data"):
    """Score ``net`` recording-by-recording on ``test_file`` via ``run_model``.

    Each recording is loaded from its own header/signal files. Its labels are
    parsed against ``args['eva_classes']``, and the binary predictions come
    from ``run_model`` using a model dict that bundles ``net`` with ``leads``
    and ``searched_thresholds``.

    Args:
        args: config dict; reads 'eva_classes', 'len_signal', 'multi_patch'
            and 'normal_class'.
        net: trained classifier (switched to eval mode here).
        weights: challenge scoring weights.
        leads: lead names stored in the model dict.
        searched_thresholds: decision thresholds stored in the model dict.
        test_file: directory scanned by ``find_challenge_files``.

    Returns:
        The value returned by ``compute_challenge_metric_fast``.
    """
    net.eval()
    header_files, recording_files = find_challenge_files(test_file)
    eva_classes = args['eva_classes']
    model = {
        "classes": eva_classes,
        "leads": leads,
        "imputer": None,
        "classifier": net,
        "searched_thresholds": searched_thresholds,
    }

    all_labels = []
    all_predictions = []
    with torch.no_grad():
        # 12-lead is ok, need more consideration if for other leads
        for idx, recording_file in enumerate(recording_files):
            header = load_header(header_files[idx])
            all_labels.append(parse_labels(header, eva_classes))
            recording = load_recording(recording_file)
            _, predictions, _ = run_model(model, header, recording, len_signal=args['len_signal'], len_overlap=256, multi_patch=args['multi_patch'])
            all_predictions.append(predictions)

        score = compute_challenge_metric_fast(
            weights,
            np.stack(all_labels, axis=0),
            np.stack(all_predictions, axis=0),
            args['eva_classes'],
            args['normal_class'],
        )
    return score

def test_from_train_split(args, net, weights, val_iter, criterion, epoch, best_metric, local_search):
    """Evaluate the two-head ``net`` on ``val_iter`` and choose decision thresholds.

    BCE loss is accumulated on the first head and the challenge loss on the
    second. With ``args['multi_patch']``, predictions are averaged over the
    patch axis first. Then:

    * one global threshold is searched for each head (``searchthres_global``);
    * after epoch 10, if ``local_search`` is set, the second head's thresholds
      are refined per class (``searchthres``);
    * losses, sensitivity/specificity (at ``args['threshold']``) and the
      challenge score are logged.

    Returns:
        (thresholds, score): per-class thresholds for the second head and the
        challenge score they achieve, i.e. ``max(global score, local score)``.
        The thresholds always belong to whichever search produced that score.
    """
    net.eval()
    with torch.no_grad():
        all_targets = []
        all_outputs1 = []  # output from 1st model component (for BCE loss)
        all_outputs2 = []  # output from 2nd model component (for challenge loss)
        val_bce = 0
        val_chal = 0

        for inputs, ags, targets in val_iter:
            if args['multi_patch']:
                # presumably one recording per batch: the patch axis becomes the batch axis
                inputs, ags = inputs.squeeze(0), ags.squeeze(0)
            inputs, ags, targets = inputs.to(args['device']), ags.to(args['device']), targets.to(args['device'])
            out1, out2 = net(inputs, ags)

            if args['multi_patch']:
                # average patch predictions into one recording-level prediction
                out1 = torch.mean(out1, axis=0).unsqueeze(0)
                out2 = torch.mean(out2, axis=0).unsqueeze(0)

            loss_bce = criterion(out1, targets)
            loss_chal = compute_challenge_loss(weights, targets, out2, args['eva_classes'], args['eva_class_indices'], args['normal_class'], challenge_norm=False)
            val_bce += loss_bce.item()
            val_chal += loss_chal.item()

            all_targets.append(targets.cpu().numpy())
            all_outputs1.append(out1.cpu().numpy())
            all_outputs2.append(out2.cpu().numpy())

        concat_targets = np.concatenate(all_targets, axis=0)
        concat_outputs1 = np.concatenate(all_outputs1, axis=0)
        concat_outputs2 = np.concatenate(all_outputs2, axis=0)

        # The first head's score is only reported (chal1 in the log line).
        _, challenge1 = searchthres_global(concat_outputs1, concat_targets, weights, args['eva_classes'], args['normal_class'])
        global_thres2, challenge2 = searchthres_global(concat_outputs2, concat_targets, weights, args['eva_classes'], args['normal_class'])

        # Sensitivity / specificity are measured at the fixed args['threshold'].
        binary_outputs2 = np.zeros(concat_outputs2.shape)
        binary_outputs2[concat_outputs2 > args['threshold']] = 1
        A = compute_confusion_matrices(concat_targets, binary_outputs2)

        num_classes = concat_outputs2.shape[1]
        searched_thresholds = [global_thres2] * num_classes
        best_searched = 0.
        if epoch > 10 and local_search:
            local_thresholds, best_searched = searchthres(concat_outputs2, concat_targets, weights, args['eva_classes'], args['normal_class'], global_thres2)
            best_searched = best_searched.cpu().numpy()
            print('best_searched challenge score using local search in challenge loss:', best_searched)
            # Adopt the per-class thresholds only when they score at least as
            # well as the global one, so the returned thresholds match best2.
            if best_searched >= challenge2:
                searched_thresholds = local_thresholds

        Asum = A.sum(axis=0)
        sens = Asum[1, 1] / (Asum[1, 1] + Asum[0, 1])
        spec = Asum[0, 0] / (Asum[0, 0] + Asum[1, 0])

        if args['writer'] is not None:
            args["writer"].add_scalar("val_loss", val_bce / len(val_iter), epoch)
            args["writer"].add_scalar("sens", sens, epoch)
            args["writer"].add_scalar("spec", spec, epoch)
            args["writer"].add_scalar("challenge", challenge2, epoch)

        best_so_far = max(challenge2, best_metric)  # best challenge score trained using challenge loss over epochs
        best2 = max(challenge2, best_searched)
        print("Epoch {}: val: {:.5f} valC: {:.5f} sens: {:.5f} spec: {:.5f} chal1: {:.5f} chal2: {:.5f} best_so_far {:.5f} best_for_current_epoch {:.5f}".format(
            epoch, val_bce / len(val_iter), val_chal / len(val_iter), sens, spec, challenge1, challenge2, best_so_far, best2))
    return searched_thresholds, best2



def _make_classifier(args, in_channel, trained_classifier=None):
    """Instantiate the classifier named by ``args['clf_nn']``, optionally warm-started.

    When ``trained_classifier`` is given, all of its weights are loaded except
    ``conv1.weight``, which keeps the fresh initialisation (presumably because
    the input-channel count differs between the two models).

    Raises:
        ValueError: ``args['clf_nn']`` names neither a resnet nor an efficientnet.
    """
    if "resnet" in args['clf_nn']:
        model = getattr(res_models, args['clf_nn'])
        classifier = model(in_channel=in_channel, is_sigmoid=args['is_sigmoid_in_resnet'])
    elif "efficientnet" in args['clf_nn']:
        num_classes = 24
        classifier = select_model('cnn1d_adaptive', in_channel, args['kernel_size'], num_classes, args['alpha'], args['beta'], args['phi'])
    else:
        raise ValueError("unsupported clf_nn: {}".format(args['clf_nn']))

    if trained_classifier is not None:
        trained_dict = trained_classifier.state_dict()
        trained_dict['conv1.weight'] = classifier.state_dict()['conv1.weight']
        classifier.load_state_dict(trained_dict)
    return classifier


def _make_criterion(args, train_dataset):
    """Return the loss for the first (BCE) head.

    With the sigmoid inside the network a plain ``BCELoss`` is used. Otherwise
    ``BCEWithLogitsLoss`` is used with per-class ``pos_weight = negatives /
    positives``; the 1e-5 avoids division by zero for classes without positives.
    """
    if args['is_sigmoid_in_resnet']:
        return torch.nn.BCELoss()
    labels = torch.stack(train_dataset.labels, dim=0)
    positives = labels.sum(dim=0)
    num_samples = labels.shape[0]
    pos_weight = ((num_samples - positives) / (positives + 1e-5)).to(torch.float)
    return torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(args['device'])


def _make_optimizer(args, net):
    """Return ``(optimizer, scheduler)`` selected by ``args['optimizer']``.

    Adam uses step decay at epochs 20 and 40; SGD (with momentum and weight
    decay) uses cosine annealing over 200 steps. Note: weight_decay is not
    applied to Adam.

    Raises:
        ValueError: ``args['optimizer']`` is neither "Adam" nor "SGD".
    """
    if args['optimizer'] == "Adam":
        optimizer = optim.Adam(net.parameters(), lr=args['lr'])
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 40], gamma=0.1)
    elif args['optimizer'] == "SGD":
        optimizer = optim.SGD(net.parameters(), lr=args['lr'], momentum=args['momentum'], weight_decay=args['weight_decay'])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
    else:
        raise ValueError("unsupported optimizer: {}".format(args['optimizer']))
    return optimizer, scheduler


def train_whole_model(args, leads, weights, filename, train_dataset, val_dataset, local_search = False, trained_classifier=None):
    """Train one classifier for the lead set ``leads``.

    Counterpart of ``train_model_mixup``, without mixup. Every epoch optimises
    ``challenge_cof * challenge_loss + bce_cof * BCE`` on ``train_dataset``
    (``args['bce_penalty']`` is not consulted here).

    When ``args['split_train_val']`` is set, the model is evaluated on
    ``val_dataset`` after each epoch. It is saved to ``filename`` whenever the
    validation challenge score improves. Epochs 21-29 are additionally saved
    to ``filename + str(epoch)``.

    Returns:
        (classifier, challenge_metric): the trained classifier (never wrapped in
        DataParallel) and the best validation score. The score is replaced by
        the new-testing-data score when ``args['is_test']``, and is None when
        neither evaluation ran.

    Raises:
        ValueError: unknown ``args['clf_nn']`` or ``args['optimizer']``.
    """
    num_leads = len(leads)

    if args['split_train_val']:
        # Both args['multi_val'] settings currently validate on a single loader.
        val_iter = torch.utils.data.DataLoader(val_dataset, args['test_batch_size'], shuffle=False, num_workers=4)

    if args['sampler'] == "MultilabelBalancedRandomSampler":
        sampler = MultilabelBalancedRandomSampler(labels=torch.stack(train_dataset.labels, dim=0))
        train_iter = torch.utils.data.DataLoader(train_dataset, args['batch_size'], shuffle=False, sampler=sampler, num_workers=4)
    else:
        train_iter = torch.utils.data.DataLoader(train_dataset, args['batch_size'], shuffle=True, num_workers=4)

    in_channel = num_leads + 1 if args['add_domain_knowledge'] else num_leads
    classifier = _make_classifier(args, in_channel, trained_classifier)
    print('# classifier total parameters:', sum(param.numel() for param in classifier.parameters()))
    net = classifier.to(args['device'])
    # NOTE(review): set_args stores a torch.device, which does not appear to
    # compare equal to the plain string 'cuda', so DataParallel is likely never
    # applied with that config. Left as-is; confirm the intent before changing.
    if args['device'] == 'cuda':
        net = torch.nn.DataParallel(net)

    criterion = _make_criterion(args, train_dataset)
    optimizer, scheduler = _make_optimizer(args, net)

    best_metric = 0.
    best_threshold = 0.
    # Defaults so the intermediate saves and the final return never hit an
    # unbound name when no validation runs (split_train_val=False).
    searched_thresholds = [args['threshold']] * 24
    challenge_metric = None

    for epoch in range(args['num_epochs']):
        print('\nEpoch: %d' % epoch)

        # train
        net.train()
        train_loss = 0
        train_bceloss = 0
        train_challenge_loss = 0
        for batch_idx, (inputs, ags, targets) in enumerate(train_iter):
            inputs, ags, targets = inputs.to(args['device']), ags.to(args['device']), targets.to(args['device'])
            optimizer.zero_grad()
            out1, out2 = net(inputs, ags)

            bceloss = criterion(out1, targets)
            challenge_loss = compute_challenge_loss(weights, targets, out2, args['eva_classes'], args['eva_class_indices'], args['normal_class'], challenge_norm=False)
            loss = args['challenge_cof'] * challenge_loss + args['bce_cof'] * bceloss

            loss.backward()
            optimizer.step()

            # add log about loss
            train_loss += loss.item()
            train_bceloss += bceloss.item()
            train_challenge_loss += challenge_loss.item()
            if args['writer'] is not None:
                global_step = epoch * len(train_iter) + batch_idx
                args['writer'].add_scalar("lead{}/train/split_{}/challenge_loss".format(num_leads, 0), loss, global_step)
                args['writer'].add_scalar("lead{}/train/split_{}/average_challenge_loss".format(num_leads, 0), train_loss / (batch_idx + 1), global_step)

            if batch_idx % 100 == 1:
                print(batch_idx, len(train_iter), 'Loss: %.3f BCE: %.3f, Chal: %.3f' % (train_loss / (batch_idx + 1), train_bceloss / (batch_idx + 1), train_challenge_loss / (batch_idx + 1)))
        scheduler.step()

        # if not evaluate during training time, record training loss for nni
        if not args['split_train_val'] and args['is_nni']:
            nni.report_intermediate_result(train_loss / max(len(train_iter), 1))

        # evaluation after every epoch; thresholds are searched on the challenge-loss head
        if args['split_train_val']:
            searched_thresholds, challenge_metric = test_from_train_split(args, net, weights, val_iter, criterion, epoch, best_metric, local_search)
            if challenge_metric > best_metric:
                best_metric = challenge_metric
                best_threshold = searched_thresholds
                save_model(filename, args['eva_classes'], leads, classifier, searched_thresholds, imputer=None)
            if args['is_nni']:
                nni.report_intermediate_result(challenge_metric)

        # save intermediate model
        if 20 < epoch < 30:
            save_model(filename + str(epoch), args['eva_classes'], leads, classifier, searched_thresholds, imputer=None)

    ####################### Test ECG model. #######################
    if args['split_train_val']:
        challenge_metric = best_metric
        print("split_train_val challenge metric: ", challenge_metric)

    # use new_testing_data for generalization test
    if args['is_test']:
        challenge_metric = test_new_testing_data(args, net, weights, leads, best_threshold, test_file=args['test_file'])
        print("new testing data's challenge metric: ", challenge_metric)

    if args['is_nni'] and challenge_metric is not None:
        nni.report_final_result(challenge_metric)

    return classifier, challenge_metric


def train_test_model(args, leads, data, labels, attachedpro, weights, filename, trained_classifier = None): 
    """Train a ResNet-18 ECG classifier on the first KFold split and evaluate it on the held-out fold.

    Only the first of the two ``KFold`` splits is used: the split loop ``break``s
    after one iteration.

    Args:
        args: configuration dict. Keys read here: 'sampler', 'batch_size',
            'is_sigmoid_in_resnet', 'device', 'lr', 'momentum', 'weight_decay',
            'num_epochs', 'bce_penalty', 'challenge_cof', 'bce_cof', 'writer',
            'eva_classes', 'eva_class_indices', 'normal_class'.
        leads: lead names to keep; each must be an element of ``twelve_leads``.
        data: recordings indexed as ``data[:, lead_idx, :]`` (presumably
            (num_recordings, 12, num_samples) — TODO confirm).
        labels: multi-label targets, row-aligned with ``data``.
        attachedpro: auxiliary per-recording features fed as the net's second input.
        weights: challenge scoring weights passed to the loss/metric helpers.
        filename: base path for ``save_model``; periodic checkpoints append the epoch number.
        trained_classifier: optional model whose state dict (with ``conv1.weight``
            replaced by the fresh model's) initialises the new classifier.

    Returns:
        The trained classifier module (not the DataParallel wrapper).
    """
    num_leads = len(leads)
    feature_indices = [twelve_leads.index(lead) for lead in leads]
    features = data[:, feature_indices, :]  # keep only the requested leads

    ####################### Train ECG model. #######################
    kf = KFold(n_splits=2)
    for split_idx, (train_index, test_index) in enumerate(kf.split(features, labels)):
        X_train, X_test = features[train_index], features[test_index]
        Y_train, Y_test = labels[train_index], labels[test_index]
        attachedpro_train, attachedpro_test = attachedpro[train_index], attachedpro[test_index]

        X_train = torch.Tensor(X_train)
        attachedpro_train = torch.Tensor(attachedpro_train)
        Y_train = torch.Tensor(Y_train)

        X_test = torch.Tensor(X_test)
        attachedpro_test = torch.Tensor(attachedpro_test)
        Y_test = torch.Tensor(Y_test)

        train_data = torch.utils.data.TensorDataset(X_train, attachedpro_train, Y_train)
        if args['sampler'] == "MultilabelBalancedRandomSampler":
            # DataLoader does not allow shuffle=True together with a custom sampler.
            sampler = MultilabelBalancedRandomSampler(labels=Y_train)
            train_iter = torch.utils.data.DataLoader(train_data, args['batch_size'], shuffle=False, sampler=sampler)
        else:
            train_iter = torch.utils.data.DataLoader(train_data, args['batch_size'], shuffle=True)

        test_data = torch.utils.data.TensorDataset(X_test, attachedpro_test, Y_test)
        test_iter = torch.utils.data.DataLoader(test_data, 1000, shuffle=True)

        # instantiate a classifier
        classifier = resnet18(in_channel=num_leads, is_sigmoid=args['is_sigmoid_in_resnet'])
        if trained_classifier is not None:
            # Reuse all pretrained weights except the first conv, whose input
            # channel count depends on the number of leads.
            trained_dict = trained_classifier.state_dict()
            model_dict = classifier.state_dict()
            trained_dict['conv1.weight'] = model_dict['conv1.weight']
            classifier.load_state_dict(trained_dict)

        print('# classifier total parameters:', sum(param.numel() for param in classifier.parameters()))
        net = classifier.to(args['device'])
        if args['device'] == 'cuda':
            net = torch.nn.DataParallel(net)

        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = optim.SGD(net.parameters(), lr=args['lr'],
                              momentum=args['momentum'], weight_decay=args['weight_decay'])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

        for epoch in range(args['num_epochs']):
            print('\nEpoch: %d' % epoch)
            net.train()
            train_loss = 0
            for batch_idx, (inputs, ags, targets) in enumerate(train_iter):
                inputs, ags, targets = inputs.to(args['device']), ags.to(args['device']), targets.to(args['device'])
                optimizer.zero_grad()
                outputs = net(inputs, ags)
                challenge_loss = compute_challenge_loss(weights, targets, outputs, args['eva_classes'], args['eva_class_indices'], args['normal_class'])
                if args['bce_penalty']:
                    bceloss = criterion(outputs, targets)
                    loss = args['challenge_cof'] * challenge_loss + args['bce_cof'] * bceloss
                else:
                    # No BCE term on this path, so it is neither logged nor printed below
                    # (previously this raised NameError on `bceloss`).
                    bceloss = None
                    loss = challenge_loss
                loss.backward()
                optimizer.step()

                # add log about loss
                train_loss += loss.item()
                if args['writer'] is not None:
                    step = epoch * len(train_iter) + batch_idx
                    args['writer'].add_scalar("lead{}/train/split_{}/challenge_loss".format(num_leads, split_idx), challenge_loss, step)
                    args['writer'].add_scalar("lead{}/train/split_{}/total_loss".format(num_leads, split_idx), loss, step)
                    if bceloss is not None:
                        args['writer'].add_scalar("lead{}/train/split_{}/bce_loss".format(num_leads, split_idx), bceloss, step)

                if batch_idx % 100 == 1:
                    print(batch_idx, len(train_iter), 'Avg Loss: %.3f ' % (train_loss/(batch_idx+1)))
                    if bceloss is not None:
                        print(batch_idx, len(train_iter), 'BCE Loss: %.3f ' % (bceloss))

            if epoch % 20 == 0:
                # NOTE(review): the LR scheduler only steps every 20 epochs here, whereas
                # train_test_model_mixup steps it every epoch — confirm this is intended.
                scheduler.step()
                filename_epoch = filename + str(epoch)
                save_model(filename_epoch, args['eva_classes'], leads, classifier, imputer=None)

        ####################### Test ECG model. #######################
        print('Test start:')
        net.eval()
        test_loss = 0.
        tot_labels = []
        tot_binary_outputs = []
        tot_scalar_outputs = []
        with torch.no_grad():
            for batch_idx, (inputs, ags, targets) in enumerate(test_iter):
                inputs, ags, targets = inputs.to(args['device']), ags.to(args['device']), targets.to(args['device'])
                outputs = net(inputs, ags)
                loss = compute_challenge_loss(weights, targets, outputs, args['eva_classes'], args['eva_class_indices'], args['normal_class'])

                test_loss += loss.item()
                if args['writer'] is not None:
                    args['writer'].add_scalar("lead{}/split_{}/test/challege_loss".format(num_leads, split_idx), test_loss/(batch_idx+1), batch_idx)

                scalar_outputs = outputs.cpu().numpy()
                ilabels = targets.cpu().numpy()

                # Fixed 0.5 threshold for the headline metric; a sweep follows below.
                binary_outputs = np.zeros(scalar_outputs.shape)
                binary_outputs[scalar_outputs > 0.5] = 1

                tot_scalar_outputs.append(scalar_outputs)
                tot_labels.append(ilabels)
                tot_binary_outputs.append(binary_outputs)
                if args['writer'] is not None:
                    batch_metric = compute_challenge_metric_fast(weights, ilabels, binary_outputs, args['eva_classes'], args['normal_class'])
                    args['writer'].add_scalar("lead{}/split_{}/test/challenge_score".format(num_leads, split_idx), batch_metric, batch_idx)

        tot_labels = np.concatenate(tot_labels, axis=0)
        tot_binary_outputs = np.concatenate(tot_binary_outputs, axis=0)
        tot_scalar_outputs = np.concatenate(tot_scalar_outputs, axis=0)
        print('tot_binary_outputs shape',tot_binary_outputs.shape)
        print('tot_labels shape',tot_labels.shape)
        challenge_metric = compute_challenge_metric_fast(weights, tot_labels, tot_binary_outputs, args['eva_classes'], args['normal_class'])
        auroc, _auprc, _auroc_classes, _auprc_classes = compute_auc(tot_labels, tot_scalar_outputs)
        print('Test Results')
        print('challenge metric :',challenge_metric)
        print('auroc: ',auroc)

        # Sweep thresholds 0.01 .. 0.99 and keep the one with the highest challenge metric
        # (first one wins on ties).
        ichallenge_metric_max = -np.inf
        ithreshold_best = 0
        for i in range(99):
            ithreshold = 0.01 + i * 0.01
            ioutputs = np.zeros(tot_scalar_outputs.shape)
            ioutputs[tot_scalar_outputs >= ithreshold] = 1
            ichallenge_metric = compute_challenge_metric_fast(weights, tot_labels, ioutputs, args['eva_classes'], args['normal_class'])
            if ichallenge_metric > ichallenge_metric_max:
                ichallenge_metric_max = ichallenge_metric
                ithreshold_best = ithreshold
        print('max challenge matrix:',ichallenge_metric_max)
        print('best threshold: ', ithreshold_best)
        if args['writer'] is not None:
            args['writer'].add_scalar('max_challenge_matrix', ichallenge_metric_max, num_leads)
            args['writer'].add_scalar('best_threshold', ithreshold_best, num_leads)
        break  # only the first KFold split is trained/evaluated

    save_model(filename, args['eva_classes'], leads, classifier, imputer=None)
    return classifier

def train_test_model_mixup(args, leads, data, labels, attachedpro, weights, filename, trained_classifier = None):
    """Train/evaluate a ResNet-18 ECG classifier with mixup over both folds of a 2-way KFold.

    Each training step draws one batch from each of two loaders over the same
    training set and mixes inputs, auxiliary features and targets with
    lam ~ Beta(alpha, alpha), alpha = 0.4. The loss is the challenge loss only.
    Both KFold splits are run; the classifier saved and returned is the one
    trained on the last split.

    Args:
        args: configuration dict. Keys read here: 'sampler', 'batch_size',
            'is_sigmoid_in_resnet', 'device', 'lr', 'momentum', 'weight_decay',
            'num_epochs', 'writer', 'eva_classes', 'eva_class_indices', 'normal_class'.
        leads: lead names to keep; each must be an element of ``twelve_leads``.
        data: recordings indexed as ``data[:, lead_idx, :]`` (presumably
            (num_recordings, 12, num_samples) — TODO confirm).
        labels: multi-label targets, row-aligned with ``data``.
        attachedpro: auxiliary per-recording features fed as the net's second input.
        weights: challenge scoring weights passed to the loss/metric helpers.
        filename: base path for ``save_model``; periodic checkpoints append the epoch number.
        trained_classifier: optional model whose state dict (with ``conv1.weight``
            replaced by the fresh model's) initialises each new classifier.

    Returns:
        The classifier module trained on the last split.
    """
    alpha = 0.4  # mixup Beta-distribution parameter
    num_leads = len(leads)
    feature_indices = [twelve_leads.index(lead) for lead in leads]
    features = data[:, feature_indices, :]  # keep only the requested leads
    print('feature_indices :',feature_indices)
    print('labels shape :',labels.shape)

    ####################### Train ECG model. #######################
    kf = KFold(n_splits=2)
    for split_idx, (train_index, test_index) in enumerate(kf.split(features, labels)):
        X_train, X_test = features[train_index], features[test_index]
        Y_train, Y_test = labels[train_index], labels[test_index]
        attachedpro_train, attachedpro_test = attachedpro[train_index], attachedpro[test_index]

        X_train = torch.Tensor(X_train)
        attachedpro_train = torch.Tensor(attachedpro_train)
        Y_train = torch.Tensor(Y_train)

        X_test = torch.Tensor(X_test)
        attachedpro_test = torch.Tensor(attachedpro_test)
        Y_test = torch.Tensor(Y_test)

        # Two loaders over the same training set provide the paired batches for mixup.
        train_data = torch.utils.data.TensorDataset(X_train, attachedpro_train, Y_train)
        if args['sampler'] == "MultilabelBalancedRandomSampler":
            sampler = MultilabelBalancedRandomSampler(labels=Y_train)
            train_iter1 = torch.utils.data.DataLoader(train_data, args['batch_size'], shuffle=False, sampler=sampler)
            train_iter2 = torch.utils.data.DataLoader(train_data, args['batch_size'], shuffle=False, sampler=sampler)
        else:
            train_iter1 = torch.utils.data.DataLoader(train_data, args['batch_size'], shuffle=True)
            train_iter2 = torch.utils.data.DataLoader(train_data, args['batch_size'], shuffle=True)

        test_data = torch.utils.data.TensorDataset(X_test, attachedpro_test, Y_test)
        test_iter = torch.utils.data.DataLoader(test_data, args['batch_size'], shuffle=True)

        # instantiate a classifier
        classifier = resnet18(in_channel=num_leads, is_sigmoid=args['is_sigmoid_in_resnet'])
        if trained_classifier is not None:
            # Reuse all pretrained weights except the first conv, whose input
            # channel count depends on the number of leads.
            trained_dict = trained_classifier.state_dict()
            model_dict = classifier.state_dict()
            trained_dict['conv1.weight'] = model_dict['conv1.weight']
            classifier.load_state_dict(trained_dict)

        print('# classifier total parameters:', sum(param.numel() for param in classifier.parameters()))
        net = classifier.to(args['device'])
        if args['device'] == 'cuda':
            net = torch.nn.DataParallel(net)

        optimizer = optim.SGD(net.parameters(), lr=args['lr'],
                              momentum=args['momentum'], weight_decay=args['weight_decay'])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

        for epoch in range(args['num_epochs']):
            print('\nEpoch: %d' % epoch)
            net.train()
            train_loss = 0
            for batch_idx, ((inputs1, ags1, targets1), (inputs2, ags2, targets2)) in enumerate(zip(train_iter1, train_iter2)):
                lam = np.random.beta(alpha, alpha)
                inputs = inputs1 * lam + inputs2 * (1. - lam)
                ags = ags1 * lam + ags2 * (1. - lam)
                targets = targets1 * lam + targets2 * (1. - lam)
                inputs, ags, targets = inputs.to(args['device']), ags.to(args['device']), targets.to(args['device'])
                optimizer.zero_grad()
                outputs = net(inputs, ags)
                loss = compute_challenge_loss(weights, targets, outputs, args['eva_classes'], args['eva_class_indices'], args['normal_class'])
                loss.backward()
                optimizer.step()

                # add log about loss
                train_loss += loss.item()
                if args['writer'] is not None:
                    step = epoch * len(train_iter1) + batch_idx
                    args['writer'].add_scalar("lead{}/train/split_{}/challenge_loss".format(num_leads, split_idx), loss, step)
                    args['writer'].add_scalar("lead{}/train/split_{}/average_challenge_loss".format(num_leads, split_idx), train_loss/(batch_idx+1), step)

                if batch_idx % 100 == 1:
                    print(batch_idx, len(train_iter1), 'Loss: %.3f ' % (train_loss/(batch_idx+1)))
            scheduler.step()
            if epoch % 20 == 0:
                filename_epoch = filename + str(epoch)
                save_model(filename_epoch, args['eva_classes'], leads, classifier, imputer=None)

        ####################### Test ECG model. #######################
        print('Test start:')
        net.eval()
        test_loss = 0.
        tot_labels = []
        tot_binary_outputs = []
        tot_scalar_outputs = []
        # Evaluation only: no autograd graph and no optimizer interaction needed.
        with torch.no_grad():
            for batch_idx, (inputs, ags, targets) in enumerate(test_iter):
                inputs, ags, targets = inputs.to(args['device']), ags.to(args['device']), targets.to(args['device'])
                outputs = net(inputs, ags)
                loss = compute_challenge_loss(weights, targets, outputs, args['eva_classes'], args['eva_class_indices'], args['normal_class'])

                test_loss += loss.item()
                if args['writer'] is not None:
                    args['writer'].add_scalar("lead{}/split_{}/test/challege_loss".format(num_leads, split_idx), test_loss/(batch_idx+1), batch_idx)

                scalar_outputs = outputs.cpu().numpy()
                ilabels = targets.cpu().numpy()

                # Fixed 0.5 threshold for the headline metric; a sweep follows below.
                binary_outputs = np.zeros(scalar_outputs.shape)
                binary_outputs[scalar_outputs > 0.5] = 1

                tot_scalar_outputs.append(scalar_outputs)
                tot_labels.append(ilabels)
                tot_binary_outputs.append(binary_outputs)
                if args['writer'] is not None:
                    batch_metric = compute_challenge_metric_fast(weights, ilabels, binary_outputs, args['eva_classes'], args['normal_class'])
                    args['writer'].add_scalar("lead{}/split_{}/test/challenge_score".format(num_leads, split_idx), batch_metric, batch_idx)

        tot_labels = np.concatenate(tot_labels, axis=0)
        tot_binary_outputs = np.concatenate(tot_binary_outputs, axis=0)
        tot_scalar_outputs = np.concatenate(tot_scalar_outputs, axis=0)
        print('tot_binary_outputs shape',tot_binary_outputs.shape)
        print('tot_labels shape',tot_labels.shape)
        challenge_metric = compute_challenge_metric_fast(weights, tot_labels, tot_binary_outputs, args['eva_classes'], args['normal_class'])
        auroc, _auprc, _auroc_classes, _auprc_classes = compute_auc(tot_labels, tot_scalar_outputs)
        print('Test Results')
        print('challenge metric :',challenge_metric)
        print('auroc: ',auroc)

        # Sweep thresholds 0.01 .. 0.99 and keep the one with the highest challenge metric
        # (first one wins on ties).
        ichallenge_metric_max = -np.inf
        ithreshold_best = 0
        for i in range(99):
            ithreshold = 0.01 + i * 0.01
            ioutputs = np.zeros(tot_scalar_outputs.shape)
            ioutputs[tot_scalar_outputs >= ithreshold] = 1
            ichallenge_metric = compute_challenge_metric_fast(weights, tot_labels, ioutputs, args['eva_classes'], args['normal_class'])
            if ichallenge_metric > ichallenge_metric_max:
                ichallenge_metric_max = ichallenge_metric
                ithreshold_best = ithreshold
        print('max challenge matrix:',ichallenge_metric_max)
        print('best threshold: ', ithreshold_best)
        if args['writer'] is not None:
            args['writer'].add_scalar('max_challenge_matrix', ichallenge_metric_max, num_leads)
            args['writer'].add_scalar('best_threshold', ithreshold_best, num_leads)

    save_model(filename, args['eva_classes'], leads, classifier, imputer=None)
    return classifier
      
def train_model_mixup(args, leads, weights, filename, train_dataset, val_dataset,local_search = False, trained_classifier=None):
    """Train an ECG classifier with mixup and a combined BCE + challenge loss.

    The network returns two heads ``(out1, out2)``. BCE is computed on ``out1``
    and the challenge loss on ``out2``. They are combined with
    ``args['challenge_cof']`` and ``args['bce_cof']``. When
    ``args['split_train_val']`` is set, the model is validated after every epoch
    via ``test_from_train_split`` and saved whenever the validation metric improves.

    Args:
        args: configuration dict (e.g. 'split_train_val', 'test_batch_size',
            'sampler', 'batch_size', 'add_domain_knowledge', 'clf_nn',
            'optimizer', 'lr', 'num_epochs', 'mixup_alpha', 'writer', 'is_nni',
            'is_test', 'test_file', plus loss/metric keys).
        leads: lead names used by this model.
        weights: challenge scoring weights passed to the loss/metric helpers.
        filename: path passed to ``save_model`` for the best checkpoint.
        train_dataset: training dataset; ``train_dataset.labels`` is read when the
            balanced sampler is selected.
        val_dataset: validation dataset (used only when ``args['split_train_val']``).
        local_search: forwarded to ``test_from_train_split``.
        trained_classifier: optional model whose state dict (with ``conv1.weight``
            replaced by the fresh model's) initialises the classifier.

    Returns:
        ``(classifier, challenge_metric)``. ``challenge_metric`` is the best
        validation metric when ``args['split_train_val']`` is set. It is
        overridden by the new-testing-data metric when ``args['is_test']`` is
        set, and is ``None`` when neither evaluation runs.

    Raises:
        ValueError: if ``args['clf_nn']`` contains neither "resnet" nor
            "efficientnet", or ``args['optimizer']`` is neither "Adam" nor "SGD".
    """
    num_leads = len(leads)

    if args['split_train_val']:
        # The former `multi_val` switch selected an identical loader, so one loader covers both.
        val_iter = torch.utils.data.DataLoader(val_dataset, args['test_batch_size'], shuffle=False, num_workers=4)

    # Two loaders over the same dataset supply the paired batches mixed by mixup.
    # DataLoader does not allow shuffle=True together with a custom sampler.
    if args['sampler'] == "MultilabelBalancedRandomSampler":
        sampler = MultilabelBalancedRandomSampler(labels=train_dataset.labels)
        shuffle = False
    else:
        sampler = None
        shuffle = True
    train_iter1 = torch.utils.data.DataLoader(train_dataset, args['batch_size'], shuffle=shuffle, sampler=sampler, num_workers=4)
    train_iter2 = torch.utils.data.DataLoader(train_dataset, args['batch_size'], shuffle=shuffle, sampler=sampler, num_workers=4)

    # instantiate a classifier; domain knowledge adds one extra input channel
    in_channel = num_leads + 1 if args['add_domain_knowledge'] else num_leads
    if "resnet" in args['clf_nn']:
        model = getattr(res_models, args['clf_nn'])
        classifier = model(in_channel=in_channel, is_sigmoid=args['is_sigmoid_in_resnet'])
    elif "efficientnet" in args['clf_nn']:
        num_classes = 24
        classifier = select_model('cnn1d_adaptive', in_channel, args['kernel_size'], num_classes, args['alpha'], args['beta'], args['phi'])
    else:
        raise ValueError("unsupported clf_nn: {!r}".format(args['clf_nn']))
    print('classifier :',classifier)

    if trained_classifier is not None:
        # Reuse all pretrained weights except the first conv, whose input
        # channel count depends on the number of leads.
        trained_dict = trained_classifier.state_dict()
        model_dict = classifier.state_dict()
        trained_dict['conv1.weight'] = model_dict['conv1.weight']
        classifier.load_state_dict(trained_dict)
    print('# classifier total parameters:', sum(param.numel() for param in classifier.parameters()))
    net = classifier.to(args['device'])

    criterion = torch.nn.BCELoss()
    if args['optimizer'] == "Adam":
        optimizer = optim.Adam(net.parameters(), lr=args['lr'])
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 40], gamma=0.1)
    elif args['optimizer'] == "SGD":
        optimizer = optim.SGD(net.parameters(), lr=args['lr'], momentum=args['momentum'], weight_decay=args['weight_decay'])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
    else:
        raise ValueError("unsupported optimizer: {!r}".format(args['optimizer']))

    best_metric = 0.
    best_threshold = 0.
    # Stays None if neither split_train_val nor is_test evaluation runs
    # (previously this raised UnboundLocalError at the return).
    challenge_metric = None
    for epoch in range(args['num_epochs']):
        print('\nEpoch: %d' % epoch)
        net.train()
        train_loss = 0
        train_bceloss = 0
        train_challenge_loss = 0
        for batch_idx, ((inputs1, ags1, targets1), (inputs2, ags2, targets2)) in enumerate(zip(train_iter1, train_iter2)):
            lam = np.random.beta(args['mixup_alpha'], args['mixup_alpha'])
            inputs = inputs1 * lam + inputs2 * (1. - lam)
            ags = ags1 * lam + ags2 * (1. - lam)
            targets = targets1 * lam + targets2 * (1. - lam)
            inputs, ags, targets = inputs.to(args['device']), ags.to(args['device']), targets.to(args['device'])
            optimizer.zero_grad()
            out1, out2 = net(inputs, ags)

            bceloss = criterion(out1, targets)
            challenge_loss = compute_challenge_loss(weights, targets, out2, args['eva_classes'], args['eva_class_indices'], args['normal_class'], challenge_norm=False)
            loss = args['challenge_cof'] * challenge_loss + args['bce_cof'] * bceloss
            loss.backward()
            optimizer.step()

            # add log about loss
            train_loss += loss.item()
            train_bceloss += bceloss.item()
            train_challenge_loss += challenge_loss.item()
            if args['writer'] is not None:
                step = epoch * len(train_iter1) + batch_idx
                args['writer'].add_scalar("lead{}/train/split_{}/challenge_loss".format(num_leads, 0), loss, step)
                args['writer'].add_scalar("lead{}/train/split_{}/average_challenge_loss".format(num_leads, 0), train_loss/(batch_idx+1), step)

            if batch_idx % 100 == 1:
                print(batch_idx, len(train_iter1), 'Loss: %.3f BCE: %.3f, Chal: %.3f' % (train_loss/(batch_idx+1), train_bceloss/(batch_idx+1), train_challenge_loss/(batch_idx+1)))
        scheduler.step()

        # evaluation after every epoch; keep the checkpoint with the best metric
        if args['split_train_val']:
            searched_thresholds, challenge_metric = test_from_train_split(args, net, weights, val_iter, criterion, epoch, best_metric, local_search)
            if challenge_metric > best_metric:
                best_metric = challenge_metric
                best_threshold = searched_thresholds
                save_model(filename, args['eva_classes'], leads, classifier, searched_thresholds, imputer=None)
            if args['is_nni']:
                nni.report_intermediate_result(challenge_metric)

    ####################### Test ECG model. #######################
    # evaluation
    if args['split_train_val']:
        challenge_metric = best_metric
        print("split_train_val challenge metric: ", challenge_metric)

    # use new_testing_data for generalization test
    if args['is_test']:
        challenge_metric = test_new_testing_data(args, net, weights, leads, best_threshold, test_file=args['test_file'])
        print("new testing data's challenge metric: ", challenge_metric)

    if args['is_nni'] and challenge_metric is not None:
        nni.report_final_result(challenge_metric)

    return classifier, challenge_metric



def _split_by_source(recording_files, data, labels, attachedpro):
    """Group recordings by the first letter of their file name.

    Names starting with 'E' go to 'Ga', names starting with 'A' or 'Q' go to
    'CPSC', and everything else goes to 'other'. Each group is a
    (data, labels, attachedpro) triple of lists. ``data``, ``labels`` and
    ``attachedpro`` must be index-aligned with ``recording_files``.
    """
    groups = {'Ga': ([], [], []), 'CPSC': ([], [], []), 'other': ([], [], [])}
    for index, path in enumerate(recording_files):
        name = os.path.basename(path)
        if name.startswith('E'):
            key = 'Ga'
        elif name.startswith(('A', 'Q')):
            key = 'CPSC'
        else:
            key = 'other'
        group_data, group_labels, group_attachedpro = groups[key]
        group_data.append(data[index])
        group_labels.append(labels[index])
        group_attachedpro.append(attachedpro[index])
    return groups


def _split_train_val(data, labels, attachedpro):
    """Split aligned sequences 80/20 into train/validation.

    Returns [data_tr, data_val, labels_tr, labels_val, pro_tr, pro_val], which is the
    ordering sklearn's ``train_test_split`` uses. An empty input yields six
    empty lists, because ``train_test_split`` rejects zero samples.
    """
    from sklearn.model_selection import train_test_split
    if len(data) == 0:
        return [[], [], [], [], [], []]
    return train_test_split(data, labels, attachedpro, test_size=0.2, train_size=0.8)


# Train your model. This function is *required*. Do *not* change the arguments of this function.
def training_code(data_directory, model_directory):
    """Train the 12-, 6-, 4-, 3- and 2-lead ECG models, one after another.

    Hyper-parameters come from ``set_args``; when ``args['is_nni']`` is set,
    the NNI tuner may override them. For each lead set, the trainer chosen by
    ``args['train_mode']`` receives a model file path inside
    ``model_directory``. When ``args['pre_train_12leads']`` is true, each lead
    set after the first starts from the classifier trained just before it.
    """
    args = set_args(print_details=False)
    if args['is_nni']:
        # Let the NNI tuner override the default hyper-parameters.
        tuner_params = nni.get_next_parameter()
        try:
            print(str(tuner_params))
        except Exception:
            pass  # printing the tuner parameters is best-effort only
        args = merge_parameter(args, tuner_params)
    print(args)

    # Create a folder for the model if it does not already exist.
    os.makedirs(model_directory, exist_ok=True)

    if args['gpu_ids']:
        os.environ['CUDA_VISIBLE_DEVICES'] = args['gpu_ids']
        print("use GPU ", os.environ['CUDA_VISIBLE_DEVICES'])

    ####################### start of pre-process recordings ############################
    data, labels, attachedpro, weights, xqrssig = process_data(args, data_directory)

    # Append the domain-knowledge channels BEFORE any shuffling/splitting.
    # This pairs xqrssig[i] with its own recording, and gives train, validation
    # and the per-source subsets the same number of channels.
    if args['add_domain_knowledge']:
        if len(xqrssig) != len(data):
            raise ValueError('xqrssig has %d entries but data has %d' % (len(xqrssig), len(data)))
        data = [torch.cat((signal, extra), dim=0) for signal, extra in zip(data, xqrssig)]

    _, recording_files = find_challenge_files(data_directory)
    # NOTE(review): assumes process_data returns recordings in the same order
    # (and count) as find_challenge_files - confirm.
    groups = _split_by_source(recording_files, data, labels, attachedpro)

    # train_triplet / val_triplet are (data, labels, attachedpro) used to build datasets.
    train_triplet = (data, labels, attachedpro)
    val_triplet = None
    if args['split_train_val']:
        # The whole-set split runs first, then the per-source splits, in the same
        # order as before, so the global RNG is consumed identically.
        data, val_data, labels, val_labels, attachedpro, val_attachedpro = _split_train_val(data, labels, attachedpro)
        ga, cpsc, other = (_split_train_val(*groups[name]) for name in ('Ga', 'CPSC', 'other'))
        if args['multi_val']:
            train_triplet = (ga[0] + cpsc[0] + other[0],
                             ga[2] + cpsc[2] + other[2],
                             ga[4] + cpsc[4] + other[4])
            # List repetition oversamples the Ga (x21) and CPSC (x7) validation
            # recordings; presumably to rebalance sources - confirm the weights.
            val_triplet = (21 * ga[1] + 7 * cpsc[1] + other[1],
                           21 * ga[3] + 7 * cpsc[3] + other[3],
                           21 * ga[5] + 7 * cpsc[5] + other[5])
        else:
            train_triplet = (data, labels, attachedpro)
            val_triplet = (val_data, val_labels, val_attachedpro)
    # Without split_train_val, multi_val has no validation data to weight, so
    # training simply uses every recording.
    ####################### end of pre-process recordings ############################

    lead_configs = (
        ('12', twelve_leads, twelve_lead_model_filename),
        ('6', six_leads, six_lead_model_filename),
        ('4', four_leads, four_lead_model_filename),
        ('3', three_leads, three_lead_model_filename),
        ('2', two_leads, two_lead_model_filename),
    )
    score_logger = {}
    local_search = True
    trained_classifier = None
    for position, (tag, leads, model_filename) in enumerate(lead_configs):
        print('Training ' + tag + '-lead ECG model...')
        if position > 0 and not args['pre_train_12leads']:
            trained_classifier = None
        filename = os.path.join(model_directory, model_filename)
        num_leads = len(leads)
        feature_indices = [twelve_leads.index(lead) for lead in leads]

        train_dataset = mydataset(*train_triplet, num_leads, args['len_signal'], feature_indices)
        if val_triplet is not None:
            val_dataset = mydataset(*val_triplet, num_leads, args['len_signal'], feature_indices, multi_patch=args['multi_patch'])
        else:
            val_dataset = None

        # The first (12-lead) model is trained from scratch. Later ones receive the
        # previous classifier (or None) as an extra positional argument.
        pretrained = () if position == 0 else (trained_classifier,)
        # train_test_model only hands back a classifier, so there is no score for it.
        challenge_score = None
        if args['train_mode'] == "train_test_model":
            trained_classifier = train_test_model(args, leads, weights, filename, train_dataset, *pretrained)
        elif args['train_mode'] == "train_model_mixup":
            trained_classifier, challenge_score = train_model_mixup(args, leads, weights, filename, train_dataset, val_dataset, local_search, *pretrained)
        else:
            trained_classifier, challenge_score = train_whole_model(args, leads, weights, filename, train_dataset, val_dataset, local_search, *pretrained)

        score_logger[tag + 'lead'] = challenge_score

    # ####################### Experiments Done. Logging. ####################### 
    # os.makedirs("output", exist_ok=True)
    # with open("output/result.txt", "a") as f:
    #     f.write(str(args)+"\n")
    #     for lead, score in score_logger.items():
    #         f.write(lead + ": " + str(f"{score.cpu().numpy():5.4}") + '\n')
    #     f.write("\n")

################################################################################
#
# File I/O functions
#
################################################################################

# Save your trained models.
def save_model(filename, classes, leads, classifier, searched_thresholds, imputer=None):
    """Serialize a trained model and its metadata to ``filename`` with joblib.

    The saved object is a dict with keys 'classes', 'leads', 'imputer',
    'classifier' and 'searched_thresholds'. ``run_model`` reads back exactly
    these keys. Missing parent directories of ``filename`` are created.
    """
    directory = os.path.dirname(filename)
    if directory:
        # A bare file name has dirname '' and os.makedirs('') raises
        # FileNotFoundError, so only create directories that are named.
        os.makedirs(directory, exist_ok=True)
    d = {'classes': classes, 'leads': leads, 'imputer': imputer, 'classifier': classifier, 'searched_thresholds' : searched_thresholds}
    # protocol=0 is the oldest (ASCII) pickle protocol.
    # NOTE(review): a higher protocol would be smaller/faster; left unchanged.
    joblib.dump(d, filename, protocol=0)

# Load your trained 12-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
def load_twelve_lead_model(model_directory):
    """Return the list of 12-lead models stored in ``model_directory``.

    ``load_model`` takes ``(model_directory, leads)``. Passing a single file
    path, as this wrapper used to do, raised TypeError.
    """
    return load_model(model_directory, twelve_leads)

# Load your trained 6-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
def load_six_lead_model(model_directory):
    """Return the list of 6-lead models stored in ``model_directory``.

    ``load_model`` takes ``(model_directory, leads)``. Passing a single file
    path, as this wrapper used to do, raised TypeError.
    """
    return load_model(model_directory, six_leads)

# Load your trained 3-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
def load_three_lead_model(model_directory):
    """Return the list of 3-lead models stored in ``model_directory``.

    ``load_model`` takes ``(model_directory, leads)``. Passing a single file
    path, as this wrapper used to do, raised TypeError.
    """
    return load_model(model_directory, three_leads)

# Load your trained 2-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
def load_two_lead_model(model_directory):
    """Return the list of 2-lead models stored in ``model_directory``.

    ``load_model`` takes ``(model_directory, leads)``. Passing a single file
    path, as this wrapper used to do, raised TypeError.
    """
    return load_model(model_directory, two_leads)

# Legacy single-file loader (disabled); superseded by load_model(model_directory, leads) defined below.
# def load_model(filename):
#     return joblib.load(filename)

################################################################################
#
# Running trained model functions
#
################################################################################

# Run your trained 12-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
def run_twelve_lead_model(model, header, recording):
    """Classify one recording with the loaded 12-lead model(s) via run_model."""
    outputs = run_model(model, header, recording)
    return outputs

# Run your trained 6-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
def run_six_lead_model(model, header, recording):
    """Classify one recording with the loaded 6-lead model(s) via run_model."""
    outputs = run_model(model, header, recording)
    return outputs

# Run your trained 3-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
def run_three_lead_model(model, header, recording):
    """Classify one recording with the loaded 3-lead model(s) via run_model."""
    outputs = run_model(model, header, recording)
    return outputs

# Run your trained 2-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
def run_two_lead_model(model, header, recording):
    """Classify one recording with the loaded 2-lead model(s) via run_model."""
    outputs = run_model(model, header, recording)
    return outputs


def load_model(model_directory, leads):
    """Load every saved model for the given lead set from ``model_directory``.

    A file belongs to the ``n``-lead set when its name starts with ``str(n)``.
    For example, the training code writes ``12_lead_model.sav`` for 12 leads.
    Every matching file is loaded with ``joblib.load``. The list is returned
    in sorted file-name order; ``run_model`` treats it as an ensemble.

    NOTE(review): the prefix match is loose - any file whose name merely starts
    with those digits (e.g. '2...' for 2 leads) is loaded too.

    Args:
        model_directory: directory containing the saved model files.
        leads: sequence of lead names; only its length is used.

    Returns:
        List of loaded model objects (possibly empty).

    Raises:
        ValueError: if ``len(leads)`` is not 12, 6, 4, 3 or 2.
    """
    print('len leads :',len(leads))
    num_leads = len(leads)
    if num_leads not in (12, 6, 4, 3, 2):
        # Previously this fell through and hit an UnboundLocalError on `models`.
        raise ValueError('Unsupported number of leads: %d' % num_leads)

    prefix = str(num_leads)
    models = []
    # Sorted so the ensemble order (and thus run_model's returned `classes`) is deterministic.
    for f in sorted(os.listdir(model_directory)):
        print('f :',f)
        if f.startswith(prefix):
            filename = os.path.join(model_directory, f)
            models.append(joblib.load(filename))
    return models


def _classifier_device(classifier):
    """Return the device holding ``classifier``'s parameters.

    If the classifier exposes no parameters, fall back to CUDA when available,
    otherwise CPU. The old code always called ``.cuda()``.
    """
    try:
        return next(classifier.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def run_model(model, header, recording, len_signal=4096, len_overlap=256, multi_patch=True):
    """Run an ensemble of trained ECG classifiers on one recording.

    Each entry of ``model`` is a dict with keys 'classes', 'leads', 'imputer',
    'classifier' and 'searched_thresholds', the layout written by save_model.
    Each classifier votes per class using its own threshold(s). A class is
    positive when more than ~half of the models vote for it.

    Args:
        model: non-empty list of model dicts, e.g. from ``load_model``.
        header, recording: forwarded to ``get_features``.
        len_signal, len_overlap, multi_patch: forwarded to ``get_features``.
            When ``multi_patch`` is true, the classifier's second output is
            averaged over all patches. Otherwise only the first patch is used.

    Returns:
        (classes, labels, probabilities): ``classes`` comes from the last
        model, ``labels`` are majority-vote 0/1 ints, and ``probabilities`` is
        the mean of the per-model probability vectors.

    Raises:
        ValueError: if ``model`` is empty.
    """
    if not model:
        raise ValueError('run_model needs at least one loaded model')

    vote_total = 0
    probability_total = 0
    # no_grad is scoped to inference instead of disabling autograd globally.
    with torch.no_grad():
        for imodel in model:
            classes = imodel['classes']
            leads = imodel['leads']
            imputer = imodel['imputer']
            classifier = imodel['classifier']
            searched_thresholds = imodel['searched_thresholds']
            # Load features.
            attachedpro, data = get_features(header, recording, leads, len_signal, len_overlap, multi_patch)

            # Impute missing data.
            if imputer is not None:
                data = data.reshape(1, -1)
                data = imputer.transform(data)

            device = _classifier_device(classifier)
            if multi_patch:
                data = torch.Tensor(data).to(device)
                attachedpro = torch.Tensor(attachedpro).to(device).expand(data.shape[0], -1)
                _, patch_probabilities = classifier(data, attachedpro)
                probabilities = np.mean(patch_probabilities.cpu().numpy(), axis=0)
            else:
                data = torch.Tensor(data[0]).unsqueeze(0).to(device)
                attachedpro = torch.Tensor(attachedpro).unsqueeze(0).to(device)
                _, batch_probabilities = classifier(data, attachedpro)
                probabilities = batch_probabilities.cpu().numpy()[0]

            # Plain `int` dtype: the `np.int` alias was removed in NumPy 1.24.
            labels = np.zeros(probabilities.shape, dtype=int)
            if isinstance(searched_thresholds, float):
                labels[:] = probabilities > searched_thresholds
            else:
                # One threshold per class, indexed in the same order as probabilities.
                for i in range(len(probabilities)):
                    labels[i] = probabilities[i] > searched_thresholds[i]
            vote_total = vote_total + labels
            probability_total = probability_total + probabilities

    vote_fraction = vote_total / len(model)
    mean_probabilities = probability_total / len(model)
    # Threshold just below 0.5 so that an exact half of the votes counts as positive.
    final_labels = (vote_fraction > 0.499999).astype(int)
    return classes, final_labels, mean_probabilities


################################################################################
#
# Other functions
#
################################################################################

