#!/usr/bin/env python

# Copyright 2021 AMI inc.


# Edit this script to add your team's training code.
# Some functions are *required*, but you can edit most parts of required functions, remove non-required functions, and add your own function.

import os
os.environ["OMP_NUM_THREADS"] = "1"

import argparse
import configargparse
from collections import defaultdict, namedtuple
import datetime
from distutils.util import strtobool as dist_strtobool
from functools import partial
import re

import joblib
import json
import logging
import shutil
import subprocess
import sys
import time
import math
import random

from tabulate import tabulate
from colorama import Fore, Back, Style
from concurrent.futures import wait as confu_wait
from concurrent.futures import ProcessPoolExecutor

import numpy as np
import pandas as pd
import scipy

import biosppy
import hrv
import resampy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.optimizer import Optimizer
from torch.optim.swa_utils import AveragedModel

from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestClassifier
from sklearn.mixture import GaussianMixture

import ray
from ray import tune
from ray.tune.suggest.ax import AxSearch
from ray.tune.suggest.bayesopt import BayesOptSearch
# ray tune
ray.init(address=None)

import helper_code
from helper_code import *

###############################################################################
# official evaluation code
#
# On first run, clone the PhysioNet/CinC Challenge 2021 evaluation repository
# and copy its scoring module and class-weight table next to this script so
# that `evaluate_model` can be imported below.
# NOTE(review): only the clone directory is checked; if it exists but the two
# copies are missing (e.g. an interrupted earlier run), the import below fails.
# The exit status of `git clone` is not checked either.
if not os.path.exists('evaluation-2021'):
    res = subprocess.run(
        ["git", "clone", "https://github.com/physionetchallenges/evaluation-2021"],
        stdout=subprocess.PIPE
    )
    # Forward git's captured stdout (stderr is not captured here).
    sys.stdout.buffer.write(res.stdout)

    shutil.copyfile('evaluation-2021/evaluate_model.py', 'evaluate_model.py')
    shutil.copyfile('evaluation-2021/weights.csv', 'weights.csv')

#sys.path.append('evaluation-2021')
from evaluate_model import (
    compute_f_measure,
    compute_accuracy,
    compute_challenge_metric,
    load_weights,
)
# Class/weight table copied from the evaluation repository above.
weights_file = 'weights.csv'
# Label code treated as the "normal" class.
normal_class = '426783006'
# Pairs of label codes that are presumably scored as one class by the
# evaluation; the same pairs are encoded in `sym_map` (first -> second).
equivalent_classes = [['733534002', '164909002'],
                      ['713427006', '59118001'],
                      ['284470004', '63593006'],
                      ['427172004', '17338001']]
# Classes and weights read from weights.csv. Note that `equivalent_classes`
# is NOT passed to load_weights (the argument is commented out).
_classes, _weights = load_weights(weights_file) #, equivalent_classes)
# print("classes: {}".format(_classes))
# print("weights: {}".format(_weights))
# Placeholders; presumably assigned later (e.g. at training time) -- confirm.
classes = None
weights = None

# Define the Challenge lead sets. These variables are not required. You can change or remove them.
# Every set contains lead 'II' (ECGData.process relies on this for RR features).
twelve_leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')
six_leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF')
four_leads = ('I', 'II', 'III', 'V2')
three_leads = ('I', 'II', 'V2')
two_leads = ('I', 'II')
lead_sets = (twelve_leads, six_leads, four_leads, three_leads, two_leads)
###############################################################################

# Fix the Python and NumPy RNG seeds for reproducibility (torch is not seeded here).
random.seed(0)
np.random.seed(0)

# Mixed-precision gradient scalers, presumably one per network of a two-model
# training setup; the original author was unsure whether two are needed.
# (TODO) we need two?
scaler_1 = torch.cuda.amp.GradScaler()
scaler_2 = torch.cuda.amp.GradScaler()
# Maximum sequence length in samples (15 s if the signal is at 500 Hz).
max_seqlen = 7500

################################################################################
#
# Data Preparation
#
################################################################################

# Default preprocessing options.
# NOTE(review): "dt_rate" / "tt_rate" are not read in the visible part of this
# file; presumably dev/test split ratios -- confirm with their users.
preproc_config = {
    "dt_rate": 0.0,
    "tt_rate": 0.1,
    "fs": 500,  # presumably the target (resampled) sampling frequency [Hz]
    "trim_dur": -1, # 30  (<= 0 disables trimming in ECGData.check_data)
    "norm_opt": "range",  # "range" | "gauss_sample"; anything else skips normalization
    "denoise": False,  # the filter call in ECGData.check_data is currently disabled
    "use_tsfresh": False,  # ECGData.get_input_features raises NotImplementedError if True
}

# Label codes (presumably SNOMED CT) kept for training; other labels are
# dropped by ECGData.process and read_data_dirs.
valid_syms = ["164889003", "164890007", "6374002", "426627000", "733534002",
              "713427006", "270492004", "713426002", "39732003", "445118002",
              "164909002", "251146004", "698252002", "426783006", "284470004",
              "10370003", "365413008", "427172004", "164947007", "111975006",
              "164917005", "47665007", "59118001",  "427393009", "426177001",
              "427084000", "63593006", "164934002", "59931005", "17338001"]
# Previous, smaller label list (kept for reference, unused):
#valid_syms = [
#    "270492004", "164889003", "164890007", "426627000", "713427006",
#    "713426002", "445118002", "39732003", "164909002", "251146004",
#    "698252002", "10370003", "284470004", "427172004", "164947007",
#    "111975006", "164917005", "47665007", "59118001", "427393009",
#    "426177001", "426783006", "427084000", "63593006", "164934002",
#    "59931005", "17338001"]
# Maps the first code of each equivalent pair to its representative code
# (same pairs as `equivalent_classes`); used by update_sym2int/update_labels.
sym_map = {
    "733534002": "164909002",
    "713427006": "59118001",
    "284470004": "63593006",
    "427172004": "17338001"
}
# Former mapping from record-name prefix to source corpus (unused):
#corpus_dic = {
#    "A": 0,  # Training_CPSC
#    "E": 1,  # Training_E
#    "H": 2, # PTB_XL
#    "HR": 2, # PTB_XL
#    "I": 3,  # StPetersburg
#    "Q": 4,  # Training_2
#    "S": 5   # PTB
#}


################################################################################
#
# ESPNet Related
#
################################################################################
def load_trained_model(model_path):
    """Rebuild a trained ``DivMixNet`` together with its training arguments.

    Hyper-parameters come from ``model.json`` in the directory of
    ``model_path``; the weights come from ``model_path`` itself.

    Args:
        model_path (str): saved state dict, or a snapshot file (its basename
            contains "snapshot") whose state dict is stored under "model".

    Returns:
        tuple: (model, train_args) -- the restored network and the
            ``argparse.Namespace`` read from ``model.json``.
    """
    def get_model_conf(model_path, conf_path=None):
        """Read ``[idim, odim, args]`` from a model config json.

        Args:
            model_path (str): model path; ``model.json`` next to it is used
                when ``conf_path`` is None.
            conf_path (str): optional explicit config path.

        Returns:
            tuple: (idim, odim, argparse.Namespace)
        """
        model_conf = conf_path
        if model_conf is None:
            model_conf = os.path.dirname(model_path) + "/model.json"
        with open(model_conf, "rb") as fin:
            logging.info("reading a config file from " + model_conf)
            loaded = json.load(fin)
        idim, odim, arg_dict = loaded
        return idim, odim, argparse.Namespace(**arg_dict)

    def load_state_into(path, model):
        """Copy the weights stored at ``path`` into ``model``."""
        # A map_location returning the storage unchanged loads tensors on CPU.
        checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
        is_snapshot = "snapshot" in os.path.basename(path)
        state_dict = checkpoint["model"] if is_snapshot else checkpoint
        # Unwrap containers exposing `.module` (e.g. DataParallel).
        target = model.module if hasattr(model, "module") else model
        target.load_state_dict(state_dict)

    config_file = os.path.join(os.path.dirname(model_path), "model.json")
    idim, odim, train_args = get_model_conf(model_path, config_file)

    model = DivMixNet(idim, odim, train_args)
    load_state_into(model_path, model)
    return model, train_args

def update_sym2int(sym2int):
    """Give every aliased label code the id of its representative code.

    For each ``alias -> representative`` pair in ``sym_map`` whose
    representative is already a key of ``sym2int``, the alias is assigned the
    same integer id. ``sym2int`` is modified in place and also returned.

    Args:
        sym2int (dict): mapping from label code to integer id.

    Returns:
        dict: the same (mutated) ``sym2int`` object.
    """
    for alias, representative in sym_map.items():
        if representative not in sym2int:
            continue
        sym2int[alias] = sym2int[representative]
    return sym2int
        
def is_nan_inf(x):
    """Tell whether ``x`` contains any NaN or +/-inf value.

    Args:
        x (array_like): number or array of numbers.

    Returns:
        bool: True if at least one element is NaN or infinite, False otherwise.
    """
    # an element is "finite" exactly when it is neither NaN nor infinite
    return not bool(np.isfinite(x).all())

################################################################################
#
# Feature Extraction Related
#
################################################################################
# Column names of the per-lead DataFrame fed to tsfresh (one column per lead).
tsfresh_columns = ["lead_idx_" + str(i) for i in range(12)]


class TsFresh:
    """Per-lead tsfresh feature extractor for 12-lead ECG recordings."""

    def __init__(self):
        pass

    def get_tsfresh_setting(self):
        """Get custom feature set

        Returns:
            dict: tsfresh ``default_fc_parameters``-style mapping from feature
                calculator name to its parameter list (None = no parameters).

        The numbers in the comments were left by the original author; they
        look like per-calculator cost measurements (unverified).
        """
        # Hide GPUs from this process before importing tsfresh.
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        import tsfresh
        default_setting = tsfresh.feature_extraction.settings.EfficientFCParameters()
        setting = {}

        # 0.385583878
        setting['count_above'] = [{'t': 0}, {'t': 0.25}, {'t': 0.5}, {'t': 0.75}]
        # 0.387517929
        setting['mean_change'] = None
        # 0.389936447
        setting['mean_second_derivative_central'] = None
        # 0.39177227
        setting['count_below'] = [{'t': 0}, {'t': -0.25}, {'t': -0.5}, {'t': -0.75}]
        # 0.437315226
        setting['root_mean_square'] = None
        # 0.446575165
        setting['count_below_mean'] = None
        # 0.451524734
        setting['count_above_mean'] = None
        # 0.460935116
        setting['mean_abs_change'] = None
        # 0.471253157
        setting['standard_deviation'] = None
        # 0.495888948
        setting['variation_coefficient'] = None
        # 0.511912823 (number_crossing_m disabled)
        # 0.534977436
        setting['range_count'] = [{'min': -0.75, 'max': 0.75}, {'min': -0.5, 'max': 0.5}, {'min': -0.25, 'max': 0.25}]
        # 0.669710636
        setting['time_reversal_asymmetry_statistic'] = [{'lag': 1}, {'lag': 2}, {'lag': 3}]
        # 0.721371412
        setting['kurtosis'] = None
        # 0.696781635
        setting['skewness'] = None
        # 0.618685484
        setting['c3'] = [{'lag': 1}, {'lag': 2}, {'lag': 3}]
        # 0.774443865
        setting['binned_entropy'] = [{'max_bins': 10}]
        # 0.963179588
        setting['ratio_value_number_to_time_series_length'] = None

        # 1.065071821
        setting['percentage_of_reoccurring_values_to_all_values'] = None
        # 1.367287397
        setting['energy_ratio_by_chunks'] = default_setting['energy_ratio_by_chunks']
        # 1.418880939
        setting['fft_aggregated'] = [{'aggtype': 'centroid'}, {'aggtype': 'variance'}, {'aggtype': 'skew'}, {'aggtype': 'kurtosis'}]
        # 1.59143734
        setting['longest_strike_below_mean'] = None
        # 1.580827475
        setting['longest_strike_above_mean'] = None
        # 1.827394009
        setting['ratio_beyond_r_sigma'] = default_setting['ratio_beyond_r_sigma']
        # 1.187055588
        setting['spkt_welch_density'] = default_setting['spkt_welch_density']

        # 2.343521595
        setting['large_standard_deviation'] = [{'r': 0.05}, {'r': 0.1}, {'r': 0.15000000000000002}, {'r': 0.2}, {'r': 0.25}, {'r': 0.30000000000000004}]
        # 2.379972219 (number_peaks disabled)
        # 2.39845705
        setting['linear_trend'] = default_setting['linear_trend']
        # 2.450432301
        setting['quantile'] = default_setting['quantile']
        # 2.520286083
        setting['autocorrelation'] = default_setting['autocorrelation']
        # 2.563483953
        setting['agg_autocorrelation'] = default_setting['agg_autocorrelation']
        # 2.65562582
        setting['percentage_of_reoccurring_datapoints_to_all_datapoints'] = None
        # 3.19198823
        setting['cwt_coefficients'] = default_setting['cwt_coefficients']
        # 3.611354589
        setting['fft_coefficient'] = default_setting['fft_coefficient']
        # 5.246759653 (fourier_entropy disabled)
        return setting

    def extract_tsfresh_feat(self, recording, src_fs=500, dst_fs=200, max_sec=60):
        """Extract tsfresh features independently for each lead.

        Args:
            recording (numpy.ndarray): ECG recording (C, T); C must equal
                ``len(tsfresh_columns)`` (12)
            src_fs (int): sampling freq of ``recording``
            dst_fs (int): down sample ECG data to this freq for faster computation
            max_sec (float): only the first ``max_sec`` seconds (at ``src_fs``)
                are used when the recording is longer

        Returns:
            int: number of features per lead
            list: nested list (C, feat_num) of features, NaN/inf replaced by 0
        """
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        import tsfresh
        setting = self.get_tsfresh_setting()

        # 1. trim, then resample.  The limit used to be hard-coded to 30000
        #    samples (60 s at 500 Hz), ignoring both max_sec and src_fs.
        max_samples = int(max_sec * src_fs)
        if recording.shape[1] > max_samples:
            recording = recording[:, :max_samples]
        recording = resampy.resample(recording, src_fs, dst_fs).T  # (C, T) -> (T, C)

        # 2. make pandas DataFrame; all rows belong to the same single sample
        df = pd.DataFrame(data=recording, columns=tsfresh_columns)
        df['sample_id'] = "0"

        extracted_features = tsfresh.extract_features(
            df, column_id="sample_id", default_fc_parameters=setting, n_jobs=0,
            disable_progressbar=True,
        )
        # one sample -> take the single row
        extracted_features = np.nan_to_num(
            extracted_features.values, nan=0.0, posinf=0.0, neginf=0.0,
        )[0]

        # NOTE(review): the reshape assumes tsfresh returns the feature columns
        # grouped by lead, in the same order for every lead -- confirm.
        n_leads = len(tsfresh_columns)
        feat_num = int(extracted_features.shape[0] / n_leads)
        return feat_num, extracted_features.reshape(n_leads, feat_num).tolist()
    
        
class ECGData:
    """ECG data class for feature extraction

    Loads one Challenge record (header + signal), converts the signal to
    physical units, resamples, normalizes and optionally trims it, and extracts
    auxiliary RR/HRV (and optionally tsfresh) features.

    Args:
        header_file (str): header file path ending in ``.hea``; any other
            string is used as the already-loaded header content
        recording_file (str | numpy.ndarray): recording data file path, or an
            already-loaded recording (C, T)
        sym2int (dict): symbol str to id mapper (None maps every label to -1)
        outdir (str): directory where processed arrays are written as .npy;
            None keeps them in memory
        fs (int): sampling freq after resampling
        trim_dur (float): keep only the first ``trim_dur`` seconds if > 0
        norm_opt (str): normalization option ["range"|"gauss_sample"];
            any other value skips normalization
        denoise (bool): stored only; the low-pass filter call in
            ``check_data`` is currently disabled
        use_tsfresh (bool): also extract tsfresh features in ``process``
        is_train_mode (bool): convert header labels to integer ids in
            ``process``
    """
    def __init__(
            self,
            header_file,
            recording_file,
            sym2int,
            outdir,
            fs=500,
            trim_dur=-1,
            norm_opt="range",
            denoise=False,
            use_tsfresh=False,
            is_train_mode=True,
    ):
        if header_file.endswith('.hea'):
            self.header_file = header_file
            self.header = helper_code.load_header(header_file)
        else:
            # the caller passed the header content itself
            self.header_file = None
            self.header = header_file

        if isinstance(recording_file, str):
            # loaded lazily in check_data
            self.recording_file = recording_file
            self.recording = None
        else:
            self.recording_file = None
            self.recording = recording_file

        self.sym2int = sym2int
        self.outdir = outdir
        self.dst_fs = fs
        self.trim_dur = trim_dur
        self.norm_opt = norm_opt
        self.denoise = denoise
        self.use_tsfresh = use_tsfresh
        self.is_train_mode = is_train_mode

        self.sample_id = self.set_sample_id(header_file)

    def set_sample_id(self, header_file):
        """Build a sample id from the header path.

        ``.../<dir>/<name>.hea`` -> ``<dir>_<name>``; a bare ``<name>.hea``
        (no directory part) -> ``<name>``.
        """
        header_file_arr = header_file.split('/')
        if len(header_file_arr) >= 2:
            sample_id = '_'.join(header_file_arr[-2:]).replace('.hea', '')
        else:
            # The file name used to be passed to '_'.join, which interleaved
            # '_' between its characters ("A_0_0_1_._h_e_a").
            sample_id = header_file_arr[-1].replace('.hea', '')
        return sample_id

    @staticmethod
    def get_official_leads(n_leads):
        """Return the official Challenge lead tuple for ``n_leads`` leads.

        Raises:
            ValueError: if ``n_leads`` is not 12, 6, 4, 3 or 2.
        """
        if n_leads == 12:
            return twelve_leads
        elif n_leads == 6:
            return six_leads
        elif n_leads == 4:
            return four_leads
        elif n_leads == 3:
            return three_leads
        elif n_leads == 2:
            return two_leads
        else:
            raise ValueError("got unknown label set")

    def check_age_sex(self):
        """Return age and sex

        Returns:
            age as returned by ``helper_code.get_age`` (-1 when it is NaN)
            int: sex (0: Female, 1: unknown/other, 2: Male)
        """
        age = helper_code.get_age(self.header)
        if math.isnan(age):
            age = -1

        sex = helper_code.get_sex(self.header)
        if sex in ("f", "F", "female", "Female"):
            sex = 0  # "Female"
        elif sex in ("m", "M", "male", "Male"):
            sex = 2  # "Male"
        else:
            sex = 1
        return age, sex

    def check_header(self):
        """Load header and check statistics, age, gender, etc

        Returns:
            dict: age, sex, leads, num_leads, freq (original sampling freq),
                labels (raw header labels) and n_labels
        """
        dic = {}

        age, sex = self.check_age_sex()
        dic["age"] = age
        dic["sex"] = sex

        leads = helper_code.get_leads(self.header)
        num_leads = len(leads)

        # (TODO) check leads order, and refine it
        dic["leads"] = leads
        dic["num_leads"] = num_leads

        freq = helper_code.get_frequency(self.header)
        dic["freq"] = freq

        # parsed but currently not stored
        n_samp = helper_code.get_num_samples(self.header)

        labels = helper_code.get_labels(self.header)
        dic["labels"] = labels
        dic["n_labels"] = len(labels)
        return dic

    @staticmethod
    def normalize_signal_sigma(recording):
        """Standardize each lead to zero mean / unit std (in place).

        Leads with zero std are set to 0.
        """
        mean = np.mean(recording, axis=1, keepdims=True)
        std = np.std(recording, axis=1, keepdims=True)

        for i in range(recording.shape[0]):
            if std[i] != 0.0:
                recording[i, :] = (recording[i, :] - mean[i]) / std[i]
            else:
                recording[i, :] = 0.0
        return recording

    @staticmethod
    def normalize_signal_range(recording, ret_min=-1, ret_max=1):
        """Min-max scale each lead to [ret_min, ret_max].

        Constant leads are not rescaled before the final affine map, so a
        constant value c ends up as ``c * (ret_max - ret_min) + ret_min``
        (an all-zero lead becomes ``ret_min``).
        """
        vmin = np.min(recording, axis=1, keepdims=True)
        vmax = np.max(recording, axis=1, keepdims=True)

        # there's unrecorded lead  (if np.any(vmax - vmin == 0))
        for i in range(recording.shape[0]):
            if vmax[i] - vmin[i] != 0:
                recording[i] = (recording[i] - vmin[i]) / (vmax[i] - vmin[i])

        recording = recording * (ret_max - ret_min) + ret_min
        return recording

    @staticmethod
    def normalize_signal(recording, norm_opt):
        """Dispatch to the normalization selected by ``norm_opt``."""
        if norm_opt == "gauss_sample":
            recording = ECGData.normalize_signal_sigma(recording)
        elif norm_opt == "range":
            recording = ECGData.normalize_signal_range(recording)
        else:
            print("skip normalization.")
        return recording

    @staticmethod
    def denoise_signal(
            recording,
            fs,
            btype='low',
            cutoff_low=3.0,
            cutoff_high=45.0,
            order=5
    ):
        """Apply a Butterworth filter to every lead (in place).

        Args:
            recording (numpy.ndarray): ECG data (C, T); modified in place
            fs (int): sampling freq
            btype (str): 'low', 'high' or 'band'; any other value returns
                ``recording`` unchanged
            cutoff_low (float): high-pass cutoff [Hz] ('high'/'band')
            cutoff_high (float): low-pass cutoff [Hz] ('low'/'band')
            order (int): filter order
        """
        # butter/lfilter are not imported explicitly by this file; import them
        # here rather than relying on `from helper_code import *`.
        from scipy.signal import butter, lfilter

        nyquist = fs / 2.
        if btype == 'band':
            cut_off = (cutoff_low / nyquist, cutoff_high / nyquist)
        elif btype == 'high':
            cut_off = cutoff_low / nyquist
        elif btype == 'low':
            cut_off = cutoff_high / nyquist
        else:
            return recording

        b, a = butter(order, cut_off, analog=False, btype=btype)
        for i in range(recording.shape[0]):
            recording[i, :] = lfilter(b, a, recording[i, :])
        return recording

    @staticmethod
    def resample(recording, src_fs, dst_fs):
        """Resample recording

        Args:
            recording (numpy.ndarray): ECG data (C, T)
            src_fs (int): sampling freq before resample
            dst_fs (int): sampling freq after resample
        """
        if src_fs != dst_fs:
            _recording = []
            for i in range(recording.shape[0]):
                _recording.append(resampy.resample(recording[i, :], src_fs, dst_fs))
            recording = np.stack(_recording)
        return recording

    def check_data(self, src_fs, leads=twelve_leads):
        """Load, resample, and (pre)process ECG signal

        Args:
            src_fs (int): original sampling frequency
            leads (tuple): lead names to load, in output row order

        Returns:
            numpy.ndarray: processed recording (len(leads), T), or None if it
                contains NaN/inf
            list: indices of constant (unrecorded) leads, or None
        """
        # 1. shape of recording: (#n-leads, #seq-len)
        if self.recording is None:
            recording = helper_code.load_recording(
                self.recording_file, header=self.header, leads=leads
            ).astype(np.float32)
        else:
            recording = self.recording.astype(np.float32)

        adc_gains = helper_code.get_adc_gains(self.header, leads).astype(np.float32)
        baselines = helper_code.get_baselines(self.header, leads).astype(np.float32)
        num_leads = len(leads)

        # 2. zero out invalid (constant) leads; convert the others to physical
        #    units: (digital value - baseline) / adc gain
        invalid_ids = []
        vmin = np.min(recording, axis=1, keepdims=True)
        vmax = np.max(recording, axis=1, keepdims=True)

        for i in range(num_leads):
            if vmin[i] == vmax[i]:
                invalid_ids.append(i)
                recording[i, :] = 0.0 * recording[i, :]
            else:
                recording[i, :] = (recording[i, :] - baselines[i]) / adc_gains[i]

        # 3. resample ECG data
        recording = ECGData.resample(recording, src_fs, self.dst_fs)

        # 4. low-pass filter: currently disabled, `self.denoise` is ignored
        #    (would be `ECGData.denoise_signal(recording, self.dst_fs)`).

        # 5. normalize
        recording = ECGData.normalize_signal(recording, self.norm_opt)

        # 6. keep only the first trim_dur seconds
        if self.trim_dur > 0:
            n_samples = int(self.dst_fs * self.trim_dur)
            recording = recording[:, :n_samples]

        if is_nan_inf(recording):
            # Used to return a 3-tuple, which made the 2-value unpacking in
            # `process` raise ValueError instead of skipping the sample.
            return None, None

        return recording, invalid_ids

    def compute_rr_features(self, recording, max_dur=30000, stat_dim=20):
        """Extract RR related features

        Please refer to biosppy and hrv for more details.

        Args:
            recording (numpy.ndarray): single-lead ECG recording (T,)
            max_dur (int): maximum number of samples used (longer input is
                truncated)
            stat_dim (int): length of the all-zero vector returned when the
                R-peak / HRV analysis fails

        Returns:
            int: feature dimension
            numpy.ndarray: [rri_min, rri_max, rri_mean, rri_std] followed by
                the hrv time/frequency/non-linear statistics sorted by name,
                with NaN/inf replaced by 0
        """
        # `import scipy` alone does not load submodules on older SciPy.
        import scipy.interpolate

        fs = self.dst_fs
        if recording.shape[0] > max_dur:
            recording = recording[:max_dur]

        # 1. filter signal
        order = int(0.3 * fs)
        filtered, _, _ = biosppy.tools.filter_signal(
            signal=recording, ftype='FIR', band='bandpass', order=order,
            frequency=[3, 45], sampling_rate=fs,
        )

        # 2. segment
        rpeaks, = biosppy.signals.ecg.christov_segmenter(
            signal=filtered, sampling_rate=fs,
        )

        # 3. check R-peak locations
        rpeaks, = biosppy.signals.ecg.correct_rpeaks(
            signal=filtered, rpeaks=rpeaks, sampling_rate=fs, tol=0.05,
        )

        # 4. extract templates (also drops peaks too close to the borders)
        templates, rpeaks = biosppy.signals.ecg.extract_heartbeats(
            signal=filtered, rpeaks=rpeaks, sampling_rate=fs, before=0.2, after=0.4,
        )

        # 5. compute heart rate; the result is unused, but a ValueError here
        #    means there are not enough beats for the features below
        try:
            hr_idx, hr = biosppy.tools.get_heart_rate(
                beats=rpeaks, sampling_rate=fs, smooth=True, size=3
            )
        except ValueError:
            return stat_dim, np.zeros(stat_dim, dtype=np.float32)

        # RR intervals [ms], cubic-interpolated on a 1 s grid
        length = len(recording)
        T = (length - 1) / fs
        ts = np.linspace(0, T, length, endpoint=True)
        ts_peaks = ts[rpeaks]

        try:
            rri = np.diff(ts_peaks) * 1000
            spline_func = scipy.interpolate.interp1d(ts_peaks[:-1], rri, kind='cubic')
            ts_1sec = np.arange(ts_peaks[0], ts_peaks[-2], 1)
            rri_1sec = spline_func(ts_1sec).round(6)
        except ValueError:
            return stat_dim, np.zeros(stat_dim, dtype=np.float32)

        # ['rri_min', 'rri_max', 'rri_mean', 'rri_std']
        rri_min = np.min(rri_1sec)
        rri_max = np.max(rri_1sec)
        rri_mean = np.mean(rri_1sec)
        rri_std = np.std(rri_1sec)

        # hrv statistics, expected keys:
        # ['hf', 'hfnu', 'lf', 'lf_hf', 'lfnu',
        #  'mhr', 'mrri', 'nn50', 'pnn50', 'rmssd',
        #  'sd1', 'sd2', 'sdnn', 'sdsd', 'total_power',
        #  'vlf']
        try:
            hrv_rri = hrv.rri.RRi(rri_1sec)
            hrv_time = hrv.classical.time_domain(hrv_rri)
            hrv_freq = hrv.classical.frequency_domain(
                rri=hrv_rri,
                fs=4.0,
                method='welch',
                interp_method='cubic',
                detrend='linear'
            )
            hrv_nlin = hrv.classical.non_linear(hrv_rri)
        except (ValueError, TypeError):
            return stat_dim, np.zeros(stat_dim, dtype=np.float32)

        hrv_dic = {}
        hrv_dic.update(hrv_time)
        hrv_dic.update(hrv_freq)
        hrv_dic.update(hrv_nlin)
        # sort by key so the feature order is stable across samples
        hrv_vals = [v for _, v in sorted(hrv_dic.items(), key=lambda x: x[0])]

        dim = 4 + len(hrv_vals)
        feat = np.array([rri_min, rri_max, rri_mean, rri_std])
        feat = np.hstack([feat, np.array(hrv_vals)])
        feat = np.nan_to_num(feat, nan=0.0, posinf=0.0, neginf=0.0)
        return dim, feat

    def process(self, leads=twelve_leads):
        """Run the full pipeline for this record.

        Args:
            leads (tuple): lead names to load; must contain 'II'

        Returns:
            str: sample id
            dict: header info and ECG data ("feat" is the .npy path when
                ``outdir`` is set, the array otherwise)
            dict: auxiliary features (stat_dim/stat_feat and, with tsfresh,
                tsfresh_dim/tsfresh_feat)
            All three are None when the processed recording has NaN/inf.
        """
        # 1. check header/data
        sample_dic = self.check_header()
        recording, invalid_ids = self.check_data(sample_dic["freq"], leads=leads)
        if recording is None:
            return None, None, None

        # 2. update dic
        sample_dic["n_samp"] = recording.shape[1]
        sample_dic["freq"] = self.dst_fs
        sample_dic["invalid_ids"] = invalid_ids
        sample_dic["n_invalid"] = len(invalid_ids)
        if self.is_train_mode:
            # keep only labels used for training, then map them to integer ids
            syms = [sym for sym in sample_dic["labels"] if sym in valid_syms]
            if self.sym2int is None:
                sample_dic["labels"] = [-1] * len(syms)
            else:
                sample_dic["labels"] = [int(self.sym2int[sym]) for sym in syms]
            sample_dic["n_labels"] = len(sample_dic["labels"])

        # 3. save ECG data (local file or variable)
        if self.outdir is not None:
            signal_path = self.outdir + '/' + self.sample_id + '.npy'
            sample_dic["feat"] = signal_path
            np.save(signal_path, recording)
        else:
            sample_dic["feat"] = recording

        # 4. extract RR related feature; all official lead sets contain lead II
        lead_idx = leads.index('II')
        stat_dim, stat_feat = self.compute_rr_features(recording[lead_idx, :])

        # Created before the tsfresh step, which used to write into it before
        # it existed (UnboundLocalError) and whose entries were then discarded.
        extra_sample_dic = {}

        # 5. update extra_dic (tsfresh_feat)
        if self.use_tsfresh:
            tsf = TsFresh()
            tsfresh_dim, tsfresh_feat = tsf.extract_tsfresh_feat(recording, src_fs=self.dst_fs)
            extra_sample_dic["tsfresh_dim"] = tsfresh_dim
            if self.outdir is not None:
                tsfresh_feat_path = self.outdir + '/' + self.sample_id + '_extra.npy'
                np.save(tsfresh_feat_path, tsfresh_feat)
                extra_sample_dic["tsfresh_feat"] = tsfresh_feat_path
            else:
                extra_sample_dic["tsfresh_feat"] = tsfresh_feat

        # 6. update extra_dic (stat_feat)
        extra_sample_dic["stat_dim"] = stat_dim
        extra_sample_dic["stat_feat"] = stat_feat.tolist()

        return self.sample_id, sample_dic, extra_sample_dic

    @staticmethod
    def get_input_features(
            sample_dic,
            extra_sample_dic,
            train_args,
    ):
        """Get ECG and related stat features for inference

        Args:
            sample_dic (dict): sample_dic generated by process
                               (header info, ECG data)
            extra_sample_dic (dict): extra_sample_dic
                               (features extracted by biosppy, hrv, and tsfresh)
            train_args (argparse.Namespace): arguments used in the training stage

        Returns:
            numpy.ndarray: ECG (C, T) as float32, or None
            numpy.ndarray: extra features of length ``train_args.edim``
                ([age, sex, stat_feat..., zeros]), a 0-d dummy array when
                ``edim <= 0``, or None

        NOTE(review): ``sample_dic["feat"]`` must be an array here (i.e. the
        sample was processed with ``outdir=None``); a path string would fail.
        """
        # ECG
        if sample_dic["feat"] is None:
            return None, None
        feat = sample_dic["feat"].astype(np.float32)  # (C, T)

        # extra feature
        if train_args.edim > 0:
            extra = np.zeros(train_args.edim, dtype=np.float32)
            extra[0] = sample_dic["age"]  # age
            extra[1] = sample_dic["sex"]  # gender
            pos = 2

            extra[pos:pos+extra_sample_dic["stat_dim"]] = \
                np.array(extra_sample_dic["stat_feat"], dtype=np.float32)
            pos = pos + extra_sample_dic["stat_dim"]

            if preproc_config["use_tsfresh"]:
                raise NotImplementedError
        else:
            extra = np.array(1.0, dtype=np.float32)

        return feat, extra

    
# top function called by executor
def stat_feat_fn(header_file, feat_opt, fs, sym2int, outdir):
    """Run ``ECGData.process`` on one record (worker used by ``make_json``).

    Kept at module level so it can be submitted to a ProcessPoolExecutor.

    Args:
        header_file (str): header file name; the signal is read from the
            ``.mat`` file with the same name
        feat_opt (Dict[str]): must provide "trim_dur", "norm_opt", "denoise"
            and "use_tsfresh"
        fs (int): sampling frequency after resampling
        sym2int (Dict[str]): symbol to id mapping
        outdir (str): output directory name

    Returns:
        str: sample id
        Dict[str]: feature (ECG signal and header info)
        Dict[str]: extra features (RR/HRV, optionally tsfresh)
        All three are None when the record is rejected.
    """
    print("processing: {}".format(header_file), flush=True)
    mat_file = header_file.replace('.hea', '.mat')
    options = {
        name: feat_opt[name]
        for name in ("trim_dur", "norm_opt", "denoise", "use_tsfresh")
    }
    record = ECGData(header_file, mat_file, sym2int, outdir, fs=fs, **options)
    result = record.process(leads=twelve_leads)
    sample_id, sample_dic, extra_sample_dic = result
    return sample_id, sample_dic, extra_sample_dic
    
def make_json(header_files,
              outdir,
              json_path,
              sym2int,
              fs=500,
              trim_dur=-1,
              norm_opt="",
              denoise=False,
              extra_prefix=False,
              verbose=True,
              use_tsfresh=False,
):
    """Dump ECG data path and related information to json files

    Records are processed in parallel, one ``stat_feat_fn`` call per header
    file, in batches of ``32 * cpu_count`` files. Two files are written:
    ``json_path`` (header info / ECG data per sample) and
    ``<json_path stem>_extra<ext>`` (auxiliary RR/HRV and tsfresh features).
    Samples whose processing returns None are skipped.

    Args:
        header_files (list): list of header files to dump in a json file
        outdir (str): output directory for the per-sample .npy files
        json_path (str): json path
        sym2int (dict): mapping from label symbol to integer id
        fs (int): resampling frequency
        trim_dur (float): trim duration [sec]
        norm_opt (str): normalization method
        denoise (bool): apply low-pass filter if True
        extra_prefix (bool): unused; kept for backward compatibility
        verbose (bool): print the first header file of every batch
        use_tsfresh (bool): also extract tsfresh features

    NOTE(review): with ``outdir=None`` the numpy arrays stay in the dicts and
    ``json.dump`` cannot serialize them.
    """
    dic = {"data": {}}
    extra_dic = {"data": {}}

    feat_opt = {
        "trim_dur": trim_dur,
        "norm_opt": norm_opt,
        "denoise": denoise,
        "use_tsfresh": use_tsfresh,
    }

    batch_size = 32
    # os.cpu_count() returns None when the count cannot be determined
    num_cpus = os.cpu_count() or 1
    step = num_cpus * batch_size

    for i in range(0, len(header_files), step):
        # 0. progress: first file of the batch
        if verbose:
            print("processing: {}".format(header_files[i]))

        # 1. get batch of files (slicing clamps at the end of the list)
        batched_header_files = header_files[i:i + step]

        # 2. extract feat in parallel; a fresh worker pool per batch
        with ProcessPoolExecutor(max_workers=num_cpus) as executor:
            futures = [
                executor.submit(stat_feat_fn, header_file, feat_opt, fs,
                                sym2int, outdir)
                for header_file in batched_header_files
            ]
            confu_wait(futures)

            # collect in submission order so the json key order is deterministic
            for future in futures:
                sample_id, sample_dic, sample_extra_dic = future.result()
                if sample_id is not None and sample_dic is not None:
                    dic["data"][sample_id] = sample_dic
                    extra_dic["data"][sample_id] = sample_extra_dic

    # 3. dump json files. Derive the extra path from the extension only:
    #    str.replace('.json', ...) overwrote json_path when it did not contain
    #    '.json' and also rewrote '.json' occurring in directory names.
    stem, ext = os.path.splitext(json_path)
    extra_json_path = stem + '_extra' + ext
    with open(json_path, 'w') as f:
        json.dump(dic, f, indent=4)
    with open(extra_json_path, 'w') as f:
        json.dump(extra_dic, f, indent=4)

def update_labels(labels):
    """Map labels through ``sym_map`` and drop duplicates.

    Some labels are merged at the evaluation stage; ``sym_map`` gives the
    label each one is merged into. Labels absent from ``sym_map`` are kept
    as they are.

    Args:
        labels (list): list of labels given to the sample
    Returns:
        list: the mapped labels without duplicates (order unspecified, since
        they are collected in a set)
    """
    merged = {sym_map[lab] if lab in sym_map else lab for lab in labels}
    return list(merged)

def read_data_dirs(
        data_dirs=None,
        use_all_symbol=False
):
    """Index challenge samples by (merged) label.

    Args:
        data_dirs (list): training data directories (physionet 2021 format);
            ``None`` (default) means ``['train']``
        use_all_symbol (bool): keep every label if True; otherwise keep only
            labels contained in ``valid_syms``
    Returns:
        collections.defaultdict(list): label -> list of header file paths.
        A sample carrying several labels is listed under each of them.
    """
    # None sentinel instead of a shared mutable list default
    if data_dirs is None:
        data_dirs = ['train']

    data_dic = defaultdict(list)  # key: label, val: header file paths
    for data_dir in data_dirs:
        header_files, _ = helper_code.find_challenge_files(data_dir)
        for header_file in header_files:
            header = helper_code.load_header(header_file)
            # reuse the module-level helper instead of a local duplicate
            labels = update_labels(helper_code.get_labels(header))
            for label in labels:
                if use_all_symbol or label in valid_syms:
                    data_dic[label].append(header_file)
    return data_dic

def make_sym2int(data_dic, outdir, store=True):
    """Make sym2int (wrt label, gender), and optionally dump it to storage.

    Labels (the keys of ``data_dic``) are sorted numerically, so they must be
    int-convertible, and are numbered from 0. When ``store`` is True the
    label map is written as ``<sym> <id>`` lines to ``outdir/sym2int_label``
    and the fixed gender list to ``outdir/sym2int_gender``.

    Args:
        data_dic (dict): label -> sample ids (only the keys are used)
        outdir (str): output directory to save sym2int files
        store (bool): write sym2int files to the outdir if True; when False
            nothing is written
    Returns:
        list: list of labels, sorted numerically
        dict: mapping dict from label symbol (str) to integer id, passed
            through ``update_sym2int``
    """
    label_set = sorted(data_dic, key=int)
    gender_set = ["Female", "Male", "None"]

    def write_sym2int(syms, fname):
        """Number ``syms`` from 0, optionally write them to ``fname``.

        syms (list): list of symbols
        fname (str): output file name
        """
        pairs = [(str(sym), i) for i, sym in enumerate(syms)]
        if store:
            with open(fname, 'w') as f:
                f.writelines("{} {}\n".format(sym, i) for sym, i in pairs)
        return dict(pairs)

    sym2int = write_sym2int(label_set, os.path.join(outdir, 'sym2int_label'))
    write_sym2int(gender_set, os.path.join(outdir, 'sym2int_gender'))
    sym2int = update_sym2int(sym2int)
    return label_set, sym2int

def make_tr_dt_tt(data_dic, n_ignore=10, tt_rate=0.1, dt_rate=0.1, n_min=-1):
    """Split sample ids into disjoint train / dev / test lists, label by label.

    Labels are processed in ascending order of their sample count. For each
    label, samples not yet assigned (through an earlier label) are shuffled
    and divided into test, dev and train parts. A sample is never assigned
    twice.

    Args:
        data_dic (dict): label -> list of sample ids
        n_ignore (int): ignore labels with fewer than this many samples
        tt_rate (float): split data (#candidates * tt_rate -> test set)
        dt_rate (float): split data (#candidates * dt_rate -> dev set)
        n_min (int): if positive, cap the number of tt/dt samples drawn per
            label at this value; non-positive disables the cap
    Returns:
        tuple: (tr_ids, dt_ids, tt_ids), three lists of sample ids
    Raises:
        ValueError: if tt_rate or dt_rate is negative
    """
    if tt_rate < 0.0 or dt_rate < 0.0:
        raise ValueError(
            "tt_rate and dt_rate must be >= 0, got {} and {}".format(
                tt_rate, dt_rate))

    tr_ids = []  # training set
    dt_ids = []  # development set
    tt_ids = []  # evaluation set
    assigned = set()  # every id already placed in one of the three lists

    for _label, sample_ids in sorted(data_dic.items(), key=lambda x: len(x[1])):
        if len(sample_ids) < n_ignore:
            continue  # ignore rare labels

        # Sort before shuffling: the iteration order of a set of str ids
        # depends on PYTHONHASHSEED, which made the split irreproducible
        # even with a fixed random seed.
        cands = sorted(set(sample_ids) - assigned)
        n_tt = int(tt_rate * len(cands))
        n_dt = int(dt_rate * len(cands))
        if n_min > 0:
            n_tt = min(n_tt, n_min)
            n_dt = min(n_dt, n_min)

        random.shuffle(cands)
        tt_ids.extend(cands[:n_tt])
        rest = cands[n_tt:]
        random.shuffle(rest)
        dt_ids.extend(rest[:n_dt])
        tr_ids.extend(rest[n_dt:])
        assigned.update(cands)

    return tr_ids, dt_ids, tt_ids


def make_dirs(outdir):
    """Create the split sub-directories under ``outdir`` (like ``mkdir -p``).

    Existing directories are left untouched.

    Args:
        outdir (str): data directory name
    """
    subdirs = ('tr', 'dt', 'tt', 'tr_2', 'dt_2', 'tt_2', 'tr_extra')
    for name in subdirs:
        os.makedirs("{}/{}".format(outdir, name), exist_ok=True)
        
def make_json_files(
        data_dir, outdir, dt_rate=0.0, tt_rate=0.1, fs=500, trim_dur=30,
        norm_opt="range", denoise=False, use_tsfresh=False,
):
    """Preprocess one data directory into tr/dt/tt json files under outdir.

    Steps: create the output sub-directories, index samples by label,
    write the sym2int files, split sample ids into tr/dt/tt and dump one
    json (plus its ``_extra`` companion) per split via ``make_json``.
    """
    make_dirs(outdir)

    # key: label, val: list of sample-ids
    label_to_ids = read_data_dirs([data_dir], use_all_symbol=False)
    _, sym2int = make_sym2int(label_to_ids, outdir, store=True)

    splits = make_tr_dt_tt(label_to_ids, tt_rate=tt_rate, dt_rate=dt_rate)
    for name, ids in zip(('tr', 'dt', 'tt'), splits):
        make_json(ids, outdir + '/' + name, outdir + '/' + name + '.json',
                  sym2int, fs, trim_dur, norm_opt, denoise,
                  use_tsfresh=use_tsfresh)
    

################################################################################
#
# Data loader and Utils
#
################################################################################
def initialize():
    """Configure cuDNN flags and report whether CUDA is available.

    Returns:
        bool: ``torch.cuda.is_available()``
    """
    cudnn = torch.backends.cudnn
    # NOTE(review): benchmark=True lets cuDNN choose algorithms by timing,
    # which PyTorch's reproducibility notes say may differ between runs;
    # confirm this is intended together with deterministic=True.
    cudnn.deterministic = True
    cudnn.benchmark = True
    return torch.cuda.is_available()

def get_idim(n_leads=12):
    """Return the model input dimension: one input channel per ECG lead."""
    return n_leads

def get_edim(idim, extra_json_path, use_tsfresh=False):
    """Get dimension of the extra (non-waveform) feature vector.

    Only the first sample of the extra json is inspected; every sample is
    presumably stored with the same dims (not checked here).

    Args:
        idim (int): input ECG dimension (number of leads)
        extra_json_path (str): path to extra json file
        use_tsfresh (bool): also count ``idim * tsfresh_dim`` per-lead feats
    Returns:
        int: dimension of extra feature (2 for age/gender if json is empty)
    """
    base = 2  # age, gender

    with open(extra_json_path, 'r') as f:
        samples = json.load(f)["data"]

    for entry in samples.values():
        total = base + entry["stat_dim"]
        if use_tsfresh:
            total = total + idim * entry["tsfresh_dim"]
        return total
    return base

def read_dic(fname):
    """Read a two-column ``key value`` text file into a dict of strings.

    Blank lines are skipped. Both columns are kept as str.

    Args:
        fname (str): path of the file (e.g. a sym2int file)
    Returns:
        dict: key -> value
    Raises:
        ValueError: if a non-blank line does not have exactly two fields
    """
    dic = {}
    with open(fname, 'r') as f:
        for lineno, line in enumerate(f, 1):
            fields = line.split()
            if not fields:
                continue  # tolerate blank / trailing empty lines
            if len(fields) != 2:
                raise ValueError("{}:{}: expected 2 fields, got {}".format(
                    fname, lineno, len(fields)))
            key, val = fields
            dic[key] = val
    return dic

class AddGaussianNoise:
    """Add Gaussian noise with a random amplitude.

    The amplitude is drawn from U(vmin, vmax) per call and scales a
    zero-mean, unit-variance float32 noise array of the input's shape.
    remind: decided by amplitude, not by SNR.

    Instances are callable, so they can be used as an element of a
    transform list (``fn(data)``).
    """
    def __init__(
            self,
            min_noise_amplitude=0.001,
            max_noise_amplitude=0.01,
    ):
        self.min_noise_amplitude = min_noise_amplitude
        self.max_noise_amplitude = max_noise_amplitude

    def forward(self, feat):
        """Return ``feat`` plus noise using this instance's amplitude range."""
        return self.apply(
            feat, self.min_noise_amplitude, self.max_noise_amplitude
        )

    def __call__(self, feat):
        return self.forward(feat)

    @staticmethod
    def apply(feat, vmin=0.001, vmax=0.01):
        """Return ``feat`` plus noise with amplitude drawn from U(vmin, vmax).

        Args:
            feat (numpy.ndarray): input array (not modified)
            vmin (float): lower bound of the noise amplitude
            vmax (float): upper bound of the noise amplitude
        """
        # 1. generate 0-mean unit-variance noise
        noise = np.random.randn(*feat.shape).astype(np.float32)
        # 2. scale to a specified value U(vmin, vmax)
        noise_amplitude = np.random.uniform(vmin, vmax)
        return feat + noise_amplitude * noise
    

class physionet_dataset(Dataset):
    """Map-style dataset over preprocessed PhysioNet 2021 ECG samples.

    Args:
        json_file (str): json file whose ``"data"`` dict maps sample id to an
            entry with ``"feat"`` (path to a ``.npy`` array loaded as
            (leads, seqlen)), ``"labels"`` (label ids) and, when
            ``edim > 0``, ``"age"`` and ``"sex"``
        cache_dir (str): tmp directory to save temp files (unused here)
        transform: iterable of callables applied in order to each feature
            array, or None for identity
        mode (str): 'all' | 'test' | 'labeled' | 'unlabeled'
        pred (array-like): per-sample 0/1 mask selecting samples in the
            'labeled' / 'unlabeled' modes (must support ``.nonzero()``,
            e.g. a numpy array)
        probability (list): per-sample probability; subset to the selected
            samples in 'labeled' mode
        idim (int): model input dimension (unused here)
        odim (int): model output dimension (number of label ids)
        n_leads (int): make reduced ECG data for physionet challenge
        min_seqlen (int): if positive, randomly shorten samples, keeping at
            least this many time steps
        max_seqlen (int): generate fixed length tensor
        loader_version (int): in the 'labeled' / 'unlabeled' modes, 1 returns
            the raw feature twice; any other value applies ``transform``
        edim (int): if positive, load age / sex / stat (and optionally
            tsfresh) extra features from ``<json_file>_extra.json``
        efeat_drop (float): probability of dropping each extra-feature group
        use_tsfresh (bool or None): append per-lead tsfresh features to the
            extra vector; None (default) defers to the module-level
            ``preproc_config["use_tsfresh"]``
    Raises:
        ValueError: if ``mode`` is not one of the supported modes
    """
    # number of lead-independent statistics stored after age and gender
    _STAT_DIM = 20

    def __init__(
            self,
            json_file,
            cache_dir,
            transform,
            mode,
            pred=None,
            probability=None,
            idim=-1,
            odim=-1,
            n_leads=12,
            min_seqlen=-1,
            max_seqlen=7500,
            loader_version=1,
            edim=-1,
            efeat_drop=0.1,
            use_tsfresh=None,
    ):
        if mode not in ('all', 'test', 'labeled', 'unlabeled'):
            raise ValueError("unknown mode: {}".format(mode))
        pred = [] if pred is None else pred
        probability = [] if probability is None else probability
        self._use_tsfresh = use_tsfresh

        with open(json_file, 'r') as f:
            js = json.load(f)["data"]
        self.utts = list(js.keys())
        self._n = len(self.utts)

        extra_js = {}
        if edim > 0:
            extra_json_file = json_file.replace('.json', '_extra.json')
            with open(extra_json_file, 'r') as f:
                extra_js = json.load(f)["data"]

        # list of feature .npy paths (str)
        self.xs = []
        # multi-hot target labels, (n_samples, odim)
        self.ys = np.zeros((self._n, odim), dtype=np.float32)
        self.edim = edim
        # list of per-lead tsfresh .npy paths (filled only with tsfresh on)
        self.es = []
        self.efeat_drop = efeat_drop
        # lead-independent extra feats: [age, gender, stat feats]
        if edim > 0:
            self.extra = np.zeros(
                (self._n, 2 + self._STAT_DIM), dtype=np.float32
            )
        else:
            self.extra = None

        self.n_leads = n_leads
        self.lead_ids = self.get_lead_ids(n_leads)

        for i, (key, entry) in enumerate(js.items()):
            self.xs.append(entry["feat"])
            self.ys[i, list(map(int, entry["labels"]))] = 1

            if edim > 0:
                self.extra[i, 0] = entry["age"]  # -1.0: None
                # gender (0: female, 1: none, 2: male)
                self.extra[i, 1] = entry["sex"]
                self.extra[i, 2:2 + self._STAT_DIM] = np.array(
                    extra_js[key]['stat_feat'], dtype=np.float32
                )
                # lead-dependent stat feature
                if self._tsfresh_enabled():
                    self.es.append(extra_js[key]['tsfresh_feat'])

        self.xs = np.array(self.xs)
        self.es = np.array(self.es)

        self.loader_version = loader_version
        self.transform = transform
        self.mode = mode

        if mode in ('labeled', 'unlabeled'):
            if mode == 'labeled':
                pred_idx = pred.nonzero()[0]
                self.probability = [probability[i] for i in pred_idx]
            else:
                pred_idx = (1 - pred).nonzero()[0]

            pred_idx = pred_idx.astype(np.int32).tolist()
            self.xs = self.xs[pred_idx]
            # es is only populated when extra feats are loaded
            if self.edim > 0 and self._tsfresh_enabled():
                self.es = self.es[pred_idx]
            self.ys = self.ys[pred_idx]
            if self.extra is not None:
                self.extra = self.extra[pred_idx]

        self.n = len(self.xs)
        self.min_seqlen = min_seqlen
        self.max_seqlen = max_seqlen

        # scale (std) of the Gaussian noise added to age in __getitem__
        self.age_var = 1

    def _tsfresh_enabled(self):
        """Whether per-lead tsfresh features are used."""
        if self._use_tsfresh is not None:
            return bool(self._use_tsfresh)
        return preproc_config["use_tsfresh"]

    def get_lead_ids(self, n_leads):
        """Get lead ids given n_leads (any other value selects all 12).

        #twelve_leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')
        #six_leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF')
        #four_leads = ('I', 'II', 'III', 'V2')
        #three_leads = ('I', 'II', 'V2')
        #two_leads = ('I', 'II')
        """
        if n_leads == 2:
            ids = np.array([0, 1])
        elif n_leads == 3:
            ids = np.array([0, 1, 7])
        elif n_leads == 4:
            ids = np.array([0, 1, 2, 7])
        elif n_leads == 6:
            ids = np.array([0, 1, 2, 3, 4, 5])
        else:
            ids = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
        return ids

    def apply_transform(self, data):
        """Apply each callable of ``self.transform`` in order (identity if None).

        Args:
            data (numpy.ndarray): ECG data (n_leads, seqlen)
        """
        if self.transform is None:
            return data
        for fn in self.transform:
            data = fn(data)
        return data

    def _pad(self, _feat):
        """Return a (leads, max_seqlen) float32 copy of ``_feat``.

        Shorter inputs are zero-padded at the end; longer inputs are cropped
        at a random start. With ``min_seqlen > 0`` the kept length is also
        drawn at random (the remainder stays zero).
        """
        ilen = _feat.shape[1]
        feat = np.zeros((_feat.shape[0], self.max_seqlen), dtype=np.float32)

        if ilen <= self.max_seqlen:
            # make min_seqlen < x < max_seqlen feats
            if self.min_seqlen > 0 and ilen > self.min_seqlen:
                seqlen = random.randint(self.min_seqlen, ilen)
                feat[:, :seqlen] = _feat[:, :seqlen]
            else:
                feat[:, :ilen] = _feat
        else:
            # randint is inclusive: ilen - max_seqlen is the last valid start
            pos = random.randint(0, ilen - self.max_seqlen)

            # make min_seqlen < x < max_seqlen feats
            if self.min_seqlen > 0:
                seqlen = random.randint(self.min_seqlen, self.max_seqlen)
            else:
                seqlen = self.max_seqlen

            feat[:, :seqlen] = _feat[:, pos:pos + seqlen]
        return feat

    def __getitem__(self, index):
        """Return index-th sample; the tuple layout depends on ``self.mode``."""
        feat = np.load(self.xs[index])[self.lead_ids, :]  # (n_leads, seqlen)
        feat = self._pad(feat).astype(np.float32)
        label = self.ys[index]

        # NOTE(review): the dropout / age noise below run in every mode,
        # including 'test'; confirm that is intended for evaluation.
        if self.edim > 0:
            # copy: self.extra[index] is a view, and writing the augmented
            # values into it would make them persist in the stored features
            extra = self.extra[index].copy()

            if random.random() < self.efeat_drop:
                extra[0] = -1.0  # age
            else:
                extra[0] = extra[0] + np.random.normal(0, self.age_var)

            if random.random() < self.efeat_drop:
                extra[1] = 1.0  # gender

            if random.random() < self.efeat_drop:
                extra[2:22] = extra[2:22] * 0.0

            if self._tsfresh_enabled():
                lead_dep_extra = np.load(self.es[index])[self.lead_ids].reshape(-1)
                if random.random() < self.efeat_drop:
                    lead_dep_extra = lead_dep_extra * 0.0
                extra = np.hstack((extra, lead_dep_extra))
        else:
            extra = 1.0
        extra = np.array(extra, dtype=np.float32)

        if self.mode == 'labeled':
            prob = self.probability[index]
            if self.loader_version == 1:
                feat1 = feat
                feat2 = feat
            else:
                feat1 = self.apply_transform(feat)
                feat2 = self.apply_transform(feat)
            return (feat1, feat2, label, prob, extra)

        elif self.mode == 'unlabeled':
            if self.loader_version == 1:
                feat1 = feat
                feat2 = feat
            else:
                feat1 = self.apply_transform(feat)
                feat2 = self.apply_transform(feat)
            return (feat1, feat2, label, extra)

        elif self.mode == 'all':
            feat = self.apply_transform(feat)
            index = np.array(index)
            return (feat, label, index, extra)

        elif self.mode == 'test':
            feat = self.apply_transform(feat)
            return (feat, label, extra)

    def __len__(self):
        """return #samples
        """
        return self.n
    
    
class physionet_dataloader:
    """Factory for the DataLoaders used in each training / evaluation phase.

    Args:
        json_file (str): json file
        batch_size (int): batch size
        num_workers (int): num workers
        cache_dir (str): tmp directory to save temp files (created if missing)
        idim (int): model input dimension
        odim (int): model output dimension
        n_leads (int): make reduced ECG data for physionet challenge
        min_seqlen (int): default minimum sequence length for random cropping
        max_seqlen (int): generate fixed length tensor
        loader_version (int): dataloader version control
        edim (int): extra feature dimension (<= 0 disables extra feats)
        efeat_drop (float): extra-feature dropout probability
        drop_last (bool): drop the last incomplete batch in training loaders
    """
    _MODES = ('warmup', 'train', 'test', 'eval_train')

    def __init__(
            self,
            json_file,
            batch_size,
            num_workers,
            cache_dir,
            idim,
            odim,
            n_leads=12,
            min_seqlen=-1,
            max_seqlen=7500,
            loader_version=1,
            edim=-1,
            efeat_drop=0.1,
            drop_last=True,
    ):
        self.json_file = json_file
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.drop_last = drop_last

        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

        self.idim = idim
        self.odim = odim
        self.n_leads = n_leads
        self.min_seqlen = min_seqlen
        self.max_seqlen = max_seqlen

        self.edim = edim
        self.efeat_drop = efeat_drop

        self.loader_version = loader_version
        # No augmentation is enabled for any loader version yet. Both
        # attributes are set unconditionally so that an unexpected
        # loader_version no longer leaves transform_train undefined.
        self.transform_train = None
        self.transform_test = None

    def _dataset(self, mode, transform, min_seqlen, max_seqlen, **kwargs):
        """Build a physionet_dataset that shares this loader's configuration."""
        return physionet_dataset(
            json_file=self.json_file,
            cache_dir=self.cache_dir,
            transform=transform,
            mode=mode,
            idim=self.idim,
            odim=self.odim,
            n_leads=self.n_leads,
            min_seqlen=min_seqlen,
            max_seqlen=max_seqlen,
            loader_version=self.loader_version,
            edim=self.edim,
            efeat_drop=self.efeat_drop,
            **kwargs
        )

    def _loader(self, dataset, batch_size, shuffle, drop_last):
        """Wrap ``dataset`` in a DataLoader using this loader's worker count."""
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=drop_last,
        )

    def run(self, mode, pred=None, prob=None, min_seqlen=None, max_seqlen=None):
        """Generate DataLoader(s).

        mode (str): warmup | train | test | eval_train
                    warmup: use all samples as labeled clean data, shuffled,
                           with batch size ``2 * batch_size``.
                           Return one dataloader
                    train: divide data into labeled/unlabeled samples by
                           ``pred``. Return two dataloaders
                    test: use all samples, not shuffled.
                           Return one dataloader
                    eval_train: use all samples, not shuffled (dataset
                           mode 'all', which also yields sample indices).
                           Return one dataloader
        pred: per-sample 0/1 mask used by 'train' (None -> [])
        prob: per-sample probability for the labeled part (None -> [])
        min_seqlen / max_seqlen: override the stored defaults when not None

        Raises:
            ValueError: if ``mode`` is unknown
        """
        if mode not in self._MODES:
            raise ValueError("unknown mode: {} (expected one of {})".format(
                mode, self._MODES))
        pred = [] if pred is None else pred
        prob = [] if prob is None else prob

        # update seqlen for gradual? training
        if min_seqlen is None:
            min_seqlen = self.min_seqlen
        if max_seqlen is None:
            max_seqlen = self.max_seqlen

        if mode == 'warmup':
            all_dataset = self._dataset(
                "all", self.transform_train, min_seqlen, max_seqlen
            )
            return self._loader(
                all_dataset, self.batch_size * 2, True, self.drop_last
            )

        if mode == 'train':
            labeled_dataset = self._dataset(
                "labeled", self.transform_train, min_seqlen, max_seqlen,
                pred=pred, probability=prob,
            )
            labeled_train_loader = self._loader(
                labeled_dataset, self.batch_size, True, self.drop_last
            )
            unlabeled_dataset = self._dataset(
                "unlabeled", self.transform_train, min_seqlen, max_seqlen,
                pred=pred,
            )
            unlabeled_train_loader = self._loader(
                unlabeled_dataset, self.batch_size, True, self.drop_last
            )
            return labeled_train_loader, unlabeled_train_loader

        if mode == 'test':
            test_dataset = self._dataset(
                'test', self.transform_test, min_seqlen, max_seqlen
            )
            return self._loader(test_dataset, self.batch_size, False, False)

        # eval_train
        eval_dataset = self._dataset(
            'all', self.transform_test, min_seqlen, max_seqlen
        )
        return self._loader(eval_dataset, self.batch_size, False, False)
        
        
class AmiTestDataset(Dataset):
    """Batched, shuffled access to preprocessed samples for inference.

    Iterating the object yields lists of at most ``batch_size`` sample ids
    (in the order shuffled by ``init``). Pass such a list to ``make_tensor``
    (or ``load_feats``) to get ``(feats, ilens, ys, extra)`` tensors on
    ``device``. Features longer than ``max_seqlen`` are centre-cropped.

    Args:
        json_file (str): json whose ``"data"`` maps sample id -> entry with
            ``"feat"`` (path to a ``.npy`` array indexed (lead, time)),
            ``"labels"`` and, when ``edim > 0``, ``"age"`` / ``"sex"``
        idim (int): channel count of the returned feature tensor
        odim (int): number of label ids
        max_seqlen (int): time length of the returned feature tensor
        n_leads (int): lead subset (2, 3, 4 or 6; anything else selects 12)
        use_extra (bool): stored as an attribute; not used in this class
        batch_size (int): number of sample ids per iteration step
        device (str | torch.device): device the tensors are moved to
        edim (int): if positive, also read ``<json_file>_extra.json`` and
            fill a (batch, edim) extra-feature tensor
    """
    def __init__(
            self,
            json_file,
            idim,
            odim,
            max_seqlen=10000,
            n_leads=12,
            use_extra=False,
            batch_size=1,
            device="cpu",
            edim=-1,
    ):
        with open(json_file, 'r') as fp:
            self.js = json.load(fp)["data"]
        self.utts = list(self.js)
        self.n = len(self.utts)

        if edim > 0:
            extra_path = json_file.replace('.json', '_extra.json')
            with open(extra_path, 'r') as fp:
                self.extra_js = json.load(fp)["data"]

        print("idim: {}, edim: {}, odim: {}".format(idim, edim, odim))
        self.idim, self.odim, self.edim = idim, odim, edim

        self.max_seqlen = max_seqlen
        self.n_leads = n_leads
        self.lead_ids = self.get_lead_ids(n_leads)

        self.batch_size = batch_size
        self.device = device

        self.init()

        self.use_extra = use_extra
        # when True, load_feats also fills the multi-hot target tensor
        self.load_target = False

    def init(self):
        """Shuffle sample ids (in place) and use them as the iteration order."""
        random.shuffle(self.utts)
        self.iterable = self.utts

    def get_lead_ids(self, n_leads):
        """Return the indices of the leads kept for a given lead count.

        Official phase subsets:
        two_leads = ('I', 'II'); three_leads = ('I', 'II', 'V2');
        four_leads = ('I', 'II', 'III', 'V2');
        six_leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF');
        twelve_leads = all 12 (also the fallback).
        """
        subsets = {
            2: [0, 1],
            3: [0, 1, 7],
            4: [0, 1, 2, 7],
            6: [0, 1, 2, 3, 4, 5],
        }
        return np.array(subsets.get(n_leads, list(range(12))))

    def __iter__(self):
        """Yield lists of sample ids, ``batch_size`` at a time.

        The final list may be shorter than ``batch_size``.
        """
        total = len(self.iterable)
        for start in range(0, total, self.batch_size):
            yield self.iterable[start:start + self.batch_size]

    def load_feats(self, utts):
        """Load a batch of samples as tensors.

        Args:
            utts (list): list of sample ids
        Returns:
            tuple: feats (B, idim, max_seqlen) float32, ilens (B,) original
            or cropped lengths, ys (B, odim) float64 targets (all zero unless
            ``load_target``), extra (B, edim) float32 (or (B, 1) zeros when
            ``edim <= 0``)
        """
        batch = len(utts)
        feats = np.zeros((batch, self.idim, self.max_seqlen), dtype=np.float32)
        ys = np.zeros((batch, self.odim))
        extra_width = self.edim if self.edim > 0 else 1
        extra = np.zeros((batch, extra_width), dtype=np.float32)
        ilens = []

        for row, utt in enumerate(utts):
            entry = self.js[utt]
            raw = np.load(entry["feat"])  # shape: (C, T)
            ilen = raw.shape[1]
            raw = raw[self.lead_ids, :]

            if ilen > self.max_seqlen:
                # centre crop to max_seqlen
                offset = (raw.shape[1] - self.max_seqlen) // 2
                feats[row] = raw[:, offset:offset + self.max_seqlen]
                ilen = self.max_seqlen
            else:
                feats[row, :, :raw.shape[1]] = raw
            ilens.append(ilen)

            if self.load_target:
                ys[row, list(map(int, entry["labels"]))] = 1

            if self.edim > 0:
                side = self.extra_js[utt]
                extra[row, 0] = entry["age"]
                extra[row, 1] = entry["sex"]

                # lead-independent stat feats follow age and sex
                stat_dim = side["stat_dim"]
                extra[row, 2:2 + stat_dim] = \
                    np.array(side['stat_feat'], dtype=np.float32)

                if preproc_config["use_tsfresh"]:
                    # per-lead tsfresh feats fill the remaining columns
                    tail = 2 + stat_dim
                    tsfresh_dim = side["tsfresh_dim"]  # read but not used
                    ts = np.load(side["tsfresh_feat"]).astype(np.float32)
                    extra[row, tail:] = ts[self.lead_ids].reshape(1, -1)

        feats = torch.from_numpy(feats).to(device=self.device)
        target_device = feats.device
        ilens = torch.from_numpy(np.array(ilens)).to(device=target_device)
        ys = torch.from_numpy(ys).to(device=target_device)
        extra = torch.from_numpy(extra).to(device=target_device)
        return feats, ilens, ys, extra

    def make_tensor(self, utts):
        """Alias of ``load_feats``: returns (feats, ilens, ys, extra)."""
        return self.load_feats(utts)

    def __len__(self):
        """Return the number of samples (not the number of batches)."""
        return self.n
    
    
################################################################################
#
# Optimizers
#
################################################################################

# https://github.com/juntang-zhuang/Adabelief-Optimizer
#from adabelief_pytorch import AdaBelief
# Compare (major, minor) numerically: comparing version strings is
# lexicographic, so with a plain str "1.10.0" >= "1.5.0" would be False.
_torch_version_match = re.match(r"(\d+)\.(\d+)", str(torch.__version__))
version_higher = (
    _torch_version_match is not None
    and (int(_torch_version_match.group(1)),
         int(_torch_version_match.group(2))) >= (1, 5)
)


class AdaBelief(Optimizer):
    r"""Implements the AdaBelief algorithm (vendored from adabelief-pytorch,
    itself modified from Adam in PyTorch).

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of the gradient and of the squared deviation of the
            gradient from that average (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-16)
        weight_decay (float, optional): weight decay coefficient (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        weight_decouple (boolean, optional): (default: True) If True, the weights
            are decayed directly (decoupled weight decay as in AdamW); if False,
            ``weight_decay * p`` is added to the gradient (L2 penalty).
        fixed_decay (boolean, optional): (default: False) Only used when weight_decouple
            is True.
            When fixed_decay == True, the weight decay is performed as
            $W_{new} = W_{old} - W_{old} \times decay$.
            When fixed_decay == False, the weight decay is performed as
            $W_{new} = W_{old} - W_{old} \times decay \times lr$. Note that in this case, the
            weight decay ratio decreases with learning rate (lr).
        rectify (boolean, optional): (default: True) If True, perform the rectified
            update similar to RAdam.
        degenerated_to_sgd (boolean, optional): (default: True) Only used with rectify.
            While the rectification term is unavailable (N_sma < 5, i.e. during the
            first steps), take the momentum-SGD step
            ``p -= lr * exp_avg / (1 - beta1**step)`` if True, or skip the update if False.
        print_change_log (boolean, optional): (default: True) If True, print the
            modifications to the default hyper-parameters when constructed.

    Per-parameter state (created lazily in ``step`` or eagerly by ``reset``):
    ``step`` (int), ``exp_avg``, ``exp_avg_var`` and, with amsgrad,
    ``max_exp_avg_var``.  Each param group also holds ``buffer``: 10
    ``[step, N_sma, step_size]`` slots that cache the rectification terms,
    indexed by ``step % 10``.

    Reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed
    gradients, NeurIPS 2020.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay=0, amsgrad=False, weight_decouple=True, fixed_decay=False, rectify=True,
                 degenerated_to_sgd=True, print_change_log = True):

        # ------------------------------------------------------------------------------
        # Print modifications to default arguments (colored banner from upstream
        # adabelief-pytorch; printed once per constructed optimizer).
        if print_change_log:
            print(Fore.RED + 'Please check your arguments if you have upgraded adabelief-pytorch from version 0.0.5.')
            print(Fore.RED + 'Modifications to default arguments:')
            default_table = tabulate([
                ['adabelief-pytorch=0.0.5','1e-8','False','False'],
                ['>=0.1.0 (Current 0.2.0)','1e-16','True','True']],
                                     headers=['eps','weight_decouple','rectify'])
            print(Fore.RED + default_table)

            recommend_table = tabulate(
                [
                    ['Recommended eps = 1e-8', 'Recommended eps = 1e-16'],
                ],
                headers=['SGD better than Adam (e.g. CNN for Image Classification)','Adam better than SGD (e.g. Transformer, GAN)'])
            print(Fore.BLUE + recommend_table)

            print(Fore.BLUE +'For a complete table of recommended hyperparameters, see')
            print(Fore.BLUE + 'https://github.com/juntang-zhuang/Adabelief-Optimizer')

            print(Fore.GREEN + 'You can disable the log message by setting "print_change_log = False", though it is recommended to keep as a reminder.')

            print(Style.RESET_ALL)
        # ------------------------------------------------------------------------------

        # Validate hyper-parameters (weight_decay is not range-checked).
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        self.degenerated_to_sgd = degenerated_to_sgd
        # Param groups that override ``betas`` get their own rectification cache,
        # because the cached N_sma / step_size values depend on the betas.
        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
            for param in params:
                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                    param['buffer'] = [[None, None, None] for _ in range(10)]

        # Default per-group entries. ``buffer`` caches rectification terms (see step()).
        # NOTE(review): groups without their own 'buffer' presumably receive this same
        # list object from the base Optimizer, i.e. they share one cache.
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad, buffer=[[None, None, None] for _ in range(10)])
        super(AdaBelief, self).__init__(params, defaults)

        # (duplicate of the assignment above; harmless)
        self.degenerated_to_sgd = degenerated_to_sgd
        self.weight_decouple = weight_decouple
        self.rectify = rectify
        self.fixed_decay = fixed_decay
        if self.weight_decouple:
            print('Weight decoupling enabled in AdaBelief')
            if self.fixed_decay:
                print('Weight decay fixed')
        if self.rectify:
            print('Rectification enabled in AdaBelief')
        if amsgrad:
            print('AMSGrad enabled in AdaBelief')

    def __setstate__(self, state):
        # Make sure every group has an 'amsgrad' entry after unpickling
        # (same as torch.optim.Adam).
        super(AdaBelief, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def reset(self):
        """Re-initialize the state of every parameter: step counter to 0 and all
        moment buffers to zeros. The per-group ``buffer`` cache is left untouched."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                amsgrad = group['amsgrad']

                # State initialization
                state['step'] = 0
                # Exponential moving average of gradient values
                state['exp_avg'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
                    if version_higher else torch.zeros_like(p.data)

                # Exponential moving average of the squared deviation (grad - exp_avg)**2
                state['exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
                    if version_higher else torch.zeros_like(p.data)

                if amsgrad:
                    # Maintains max of all exp. moving avg. of the deviation values
                    state['max_exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
                        if version_higher else torch.zeros_like(p.data)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                # fp16 parameters/gradients are upcast to fp32 for the update and
                # cast back at the end of the loop body (state is created from
                # the upcast data, so it is fp32).
                half_precision = False
                if p.data.dtype == torch.float16:
                    half_precision = True
                    p.data = p.data.float()
                    p.grad = p.grad.float()

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'AdaBelief does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                beta1, beta2 = group['betas']

                # State initialization (lazily, on the first step with a gradient)
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
                        if version_higher else torch.zeros_like(p.data)
                    # Exponential moving average of the squared deviation (grad - exp_avg)**2
                    state['exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
                        if version_higher else torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of the deviation values
                        state['max_exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
                            if version_higher else torch.zeros_like(p.data)

                # perform weight decay, check if decoupled weight decay
                if self.weight_decouple:
                    # Shrink the weights directly; runs even when weight_decay == 0
                    # (then it multiplies by 1.0).
                    if not self.fixed_decay:
                        p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
                    else:
                        p.data.mul_(1.0 - group['weight_decay'])
                else:
                    # Classic L2 penalty: modifies p.grad in place.
                    if group['weight_decay'] != 0:
                        grad.add_(p.data, alpha=group['weight_decay'])

                # get current state variable
                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Update first and second moment running average. The residual uses
                # the already-updated exp_avg, so the "belief" term is
                # EMA of (grad - m_t)**2 rather than of grad**2 as in Adam.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                grad_residual = grad - exp_avg
                exp_avg_var.mul_(beta2).addcmul_( grad_residual, grad_residual, value=1 - beta2)

                # Note: both branches below add eps to exp_avg_var IN PLACE, so eps
                # accumulates into the stored state on every step.
                if amsgrad:
                    max_exp_avg_var = state['max_exp_avg_var']
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_var, exp_avg_var.add_(group['eps']), out=max_exp_avg_var)

                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                else:
                    denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                # update
                if not self.rectify:
                    # Default update
                    step_size = group['lr'] / bias_correction1
                    p.data.addcdiv_( exp_avg, denom, value=-step_size)

                else:  # Rectified update, forked from RAdam
                    # Cache slot for this step: [step, N_sma, step_size]; reused when
                    # another parameter of the group is updated at the same step.
                    buffered = group['buffer'][int(state['step'] % 10)]
                    if state['step'] == buffered[0]:
                        N_sma, step_size = buffered[1], buffered[2]
                    else:
                        buffered[0] = state['step']
                        beta2_t = beta2 ** state['step']
                        N_sma_max = 2 / (1 - beta2) - 1
                        N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                        buffered[1] = N_sma

                        # more conservative since it's an approximated value
                        if N_sma >= 5:
                            step_size = math.sqrt(
                                (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                    N_sma_max - 2)) / (1 - beta1 ** state['step'])
                        elif self.degenerated_to_sgd:
                            step_size = 1.0 / (1 - beta1 ** state['step'])
                        else:
                            # sentinel: no update this step
                            step_size = -1
                        buffered[2] = step_size

                    if N_sma >= 5:
                        # NOTE(review): this recomputes denom without bias_correction2 and
                        # from exp_avg_var (not max_exp_avg_var), so the amsgrad denom
                        # computed above is not used on this path.
                        denom = exp_avg_var.sqrt().add_(group['eps'])
                        p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                    elif step_size > 0:
                        # degenerated_to_sgd: momentum-SGD step without denominator
                        p.data.add_( exp_avg, alpha=-step_size * group['lr'])

                if half_precision:
                    p.data = p.data.half()
                    p.grad = p.grad.half()

        return loss

    
class StepLR:
    """Step-decay learning-rate schedule applied to two optimizers in lockstep.

    Intended for the paired DivideMix networks.  The rate is ``learning_rate``
    before 30% of ``total_epochs`` and is multiplied by 0.2 at 30%, 60% and 80%
    of ``total_epochs``.
    """

    def __init__(self, optimizer_1, optimizer_2, learning_rate: float, total_epochs: int):
        self.optimizer_1 = optimizer_1
        self.optimizer_2 = optimizer_2
        self.total_epochs = total_epochs
        self.base = learning_rate

    def __call__(self, epoch):
        """Set ``lr`` on every param group of both optimizers for ``epoch``."""
        print("StepLR@divide_mix")
        # Find how many milestones (in tenths of total_epochs) have been reached;
        # all three have been passed unless an earlier one is still ahead.
        decays = 3
        for count, tenths in enumerate((3, 6, 8)):
            if epoch < self.total_epochs * tenths / 10:
                decays = count
                break
        lr = self.base if decays == 0 else self.base * 0.2 ** decays

        for optimizer in (self.optimizer_1, self.optimizer_2):
            for group in optimizer.param_groups:
                group["lr"] = lr

    def lr(self) -> float:
        """Current learning rate (taken from the first group of ``optimizer_1``)."""
        first_group = self.optimizer_1.param_groups[0]
        return first_group["lr"]

    
def get_optimizer(args, model_params_1, model_params_2):
    """Create one optimizer of kind ``args.opt`` for each of the two networks.

    Args:
        args: parsed options; reads ``opt`` and, for "adadelta"/"adam",
            ``eps`` and/or ``weight_decay``.
        model_params_1: parameters (or param groups) of the first network.
        model_params_2: parameters (or param groups) of the second network.

    Returns:
        ``(optimizer_1, optimizer_2, scheduler)``; ``scheduler`` is always None.

    Raises:
        NotImplementedError: if ``args.opt`` is not a supported optimizer name.
    """
    opt_name = args.opt

    def build(params):
        # One optimizer for one network; called once per network, in order.
        if opt_name == "adadelta":
            return torch.optim.Adadelta(
                params, rho=0.95, eps=args.eps, weight_decay=args.weight_decay
            )
        if opt_name == "adam":
            return torch.optim.Adam(params, weight_decay=args.weight_decay)
        if opt_name == "radam":
            from torch_optimizer import RAdam
            return RAdam(params)
        if opt_name == "sgd":
            return torch.optim.SGD(params, lr=0.1, momentum=0.9, nesterov=False)
        if opt_name == "adabound":
            from torch_optimizer import AdaBound
            return AdaBound(params, lr=1e-3, betas=(0.9, 0.999), final_lr=0.1,
                            gamma=1e-3, eps=1e-8, weight_decay=0, amsbound=False)
        if opt_name == "adabelief":
            # https://github.com/juntang-zhuang/Adabelief-Optimizer
            return AdaBelief(params, lr=1e-3, eps=1e-16, betas=(0.9, 0.999),
                             weight_decay=1.2e-6, weight_decouple=False, rectify=False)
        raise NotImplementedError("unknown optimizer: " + opt_name)

    optimizer_1 = build(model_params_1)
    optimizer_2 = build(model_params_2)
    return optimizer_1, optimizer_2, None
    

        
################################################################################
#
# Hyper-parameters
#
################################################################################
from distutils.util import strtobool as dist_strtobool
def strtobool(x):
    """Parse a truth-value string into a real ``bool`` (for argparse ``type=``).

    Accepts the same case-insensitive spellings as ``distutils.util.strtobool``:
    ``y, yes, t, true, on, 1`` are True and ``n, no, f, false, off, 0`` are False.
    Unlike that helper it returns ``bool`` instead of ``int`` and passes actual
    ``bool`` values through unchanged.  It does not depend on ``distutils``,
    which is deprecated (PEP 632) and removed in Python 3.12.

    Raises:
        ValueError: for any other value (argparse reports it as an invalid
            argument value).
    """
    if isinstance(x, bool):
        return x
    value = str(x).lower()
    if value in ("y", "yes", "t", "true", "on", "1"):
        return True
    if value in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError("invalid truth value %r" % (value,))

def get_train_parser(parser=None):
    """Register the training command-line / YAML-config options.

    Args:
        parser: parser to add the options to.  When None, a new
            ``configargparse.ArgumentParser`` is created that can also read a
            YAML config file and shows defaults in ``--help``.

    Returns:
        The parser with all training options registered.
    """
    if parser is None:
        parser = configargparse.ArgumentParser(
            description="Train reduced-lead ECG classifier w/ dividemix",
            config_file_parser_class=configargparse.YAMLConfigFileParser,
            formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
        )

    parser.add_argument('data_directory', type=str)
    parser.add_argument('model_directory', type=str)

    # general option
    parser.add_argument(
        "--ngpu",
        default=1,
        type=int,
        help="Number of GPUs to use",
    )
    parser.add_argument(
        "--outdir", type=str, help="Output directory"
    )
    parser.add_argument(
        "--dict", help="Dictionary"
    )
    parser.add_argument(
        "--seed", default=1, type=int, help="Random seed"
    )
    parser.add_argument(
        "--resume",
        default="",
        nargs="?",
        help="Resume the training from snapshot",
    )
    parser.add_argument(
        "--tensorboard-dir",
        default=None,
        type=str,
        nargs="?",
        help="Tensorboard log dir path",
    )

    # task related
    parser.add_argument(
        "--train-json",
        type=str,
        default=None,
        help="Filename of train label data (json)",
    )
    parser.add_argument(
        "--valid-json",
        type=str,
        default=None,
        help="Filename of validation label data (json)",
    )
    # network architecture
    parser.add_argument(
        "--model-module",
        type=str,
        default=None,
        help="Model definition module (e.g. espnet.nets.xxx_backend.e2e_asr:E2E)",
    )
    # minibatch related
    parser.add_argument(
        "--batch-size",
        default=0,
        type=int,
        help="Maximum seqs in a minibatch (0 to disable)",
    )
    parser.add_argument(
        "--n-iter-processes",
        default=0,
        type=int,
        help="Number of processes of iterator",
    )
    # optimization related
    parser.add_argument(
        "--opt",
        default="adam",
        type=str,
        # "sgd" and "adabound" are implemented by get_optimizer() and are now
        # selectable.  NOTE(review): get_optimizer() raises NotImplementedError
        # for "noam" and "agc"; they are kept in case another code path handles them.
        choices=["adadelta", "adam", "noam", "agc", "radam", "sgd", "adabound", "adabelief"],
        help="Optimizer",
    )
    parser.add_argument(
        "--eps", default=1e-8, type=float, help="Epsilon constant for optimizer"
    )
    parser.add_argument(
        "--weight-decay", default=0.0, type=float, help="Weight decay ratio"
    )
    parser.add_argument(
        "--epochs", default=30, type=int, help="Maximum number of epochs"
    )
    parser.add_argument(
        "--grad-clip", default=5, type=float, help="Gradient norm threshold to clip"
    )
    parser.add_argument(
        "--grad-noise",
        type=strtobool,
        default=False,
        help="The flag to switch to use noise injection to gradients during training",
    )
    # ECG related
    parser.add_argument(
        "--min-seqlen",
        type=int,
        default=5000,
        help="min sequence length"
    )
    parser.add_argument(
        "--max-seqlen",
        type=int,
        default=7500,
        help="max sequence length"
    )
    parser.add_argument(
        "--n-leads",
        type=int,
        default=None,
        help="number of leads to use"
    )
    parser.add_argument(
        "--mixup", default=False, type=strtobool, help="Add manifold mixup",
    )
    parser.add_argument(
        "--resample",
        default=False,
        type=strtobool,
        help="resample freq to 500->[250,499]->500",
    )
    parser.add_argument(
        "--edim",
        type=int,
        default=-1,
        help="extra feature dim. actual dim is calculated by get_edim()",
    )
    parser.add_argument(
        "--efeat-drop",
        type=float,
        default=0.1,
        help="ratio to drop age and gender information",
    )
    return parser
    
    
################################################################################
#
# Network Definition
#
################################################################################

class Mish_func(torch.autograd.Function):
    """Mish activation ``x * tanh(softplus(x))`` with a hand-written backward.

    Only the input is saved for backward; softplus/tanh are recomputed there.
    """

    @staticmethod
    def forward(ctx, i):
        result = i * torch.tanh(F.softplus(i))
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # d/dx [x * tanh(sp(x))] = tanh(sp(x)) + x * sech^2(sp(x)) * sigmoid(x),
        # where sp = softplus and sp'(x) = sigmoid(x).  sech^2 is computed as
        # 1 - tanh^2, which avoids the exp/log/cosh intermediates that overflow
        # to inf for large inputs.  ``saved_tensors`` replaces the deprecated
        # ``saved_variables``.
        (i,) = ctx.saved_tensors
        tanh_sp = torch.tanh(F.softplus(i))
        grad_f = tanh_sp + i * torch.sigmoid(i) * (1.0 - tanh_sp * tanh_sp)
        return grad_output * grad_f
    
    
class Mish(nn.Module):
    """``nn.Module`` wrapper that applies :class:`Mish_func` (x * tanh(softplus(x))).

    Keyword arguments are accepted and ignored; the module has no parameters.
    """

    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, input_tensor):
        return Mish_func.apply(input_tensor)


# Shared instance; used as the default ``act`` argument of MBConvBlock.
mish = Mish()


# Parameters for the entire model (stem, all blocks, and head).
# Every field defaults to None, so callers pass only what they need.
GlobalParams = namedtuple(
    'GlobalParams',
    'width_coefficient depth_coefficient image_size dropout_rate '
    'num_classes batch_norm_momentum batch_norm_epsilon '
    'drop_connect_rate depth_divisor min_depth include_top n_blocks e_version',
)
GlobalParams.__new__.__defaults__ = tuple(None for _ in GlobalParams._fields)

# Parameters for an individual model block (every field defaults to None).
BlockArgs = namedtuple(
    'BlockArgs',
    'num_repeat kernel_size stride expand_ratio '
    'input_filters output_filters se_ratio id_skip',
)
BlockArgs.__new__.__defaults__ = tuple(None for _ in BlockArgs._fields)

def round_filters(filters, global_params):
    """Scale a channel count by the width multiplier and round it to the divisor.

    Uses ``width_coefficient``, ``depth_divisor`` and ``min_depth`` of
    ``global_params`` (the rounding rule follows the official TensorFlow
    EfficientNet implementation).

    Args:
        filters (int): channel count before scaling.
        global_params (namedtuple): model-wide parameters.
    Returns:
        int: the scaled, divisor-aligned channel count, or ``filters`` itself
        unchanged when ``width_coefficient`` is falsy.
    """
    width_mult = global_params.width_coefficient
    if not width_mult:
        return filters

    # NOTE: the divisor/minimum are stored under "depth" names
    # (depth_divisor, min_depth) although they act on the width.
    divisor = global_params.depth_divisor
    floor_value = global_params.min_depth or divisor
    scaled = filters * width_mult
    rounded = max(floor_value, int(scaled + divisor / 2) // divisor * divisor)
    # never let rounding shrink the channel count by more than 10%
    if rounded < 0.9 * scaled:
        rounded += divisor
    return int(rounded)

def round_repeats(repeats, global_params):
    """Scale a block's repeat count by the depth multiplier, rounding up.

    Uses ``depth_coefficient`` of ``global_params`` (formula from the official
    TensorFlow EfficientNet implementation).

    Args:
        repeats (int): repeat count before scaling.
        global_params (namedtuple): model-wide parameters.
    Returns:
        int: ``ceil(depth_coefficient * repeats)``, or ``repeats`` unchanged
        when ``depth_coefficient`` is falsy.
    """
    depth_mult = global_params.depth_coefficient
    return int(math.ceil(depth_mult * repeats)) if depth_mult else repeats

def drop_connect(inputs, p, training):
    """Per-sample drop connect: zero whole samples of the batch with probability ``p``.

    Each sample is kept with probability ``1 - p``.  Kept samples are rescaled
    by ``1 / (1 - p)`` so the expected value is unchanged.

    Args:
        inputs (tensor): batch-first tensor, e.g. (B, C, W).  Any rank >= 1 is
            accepted; one keep/drop decision is drawn per sample.
        p (float: 0.0~1.0): probability of dropping a sample.
        training (bool): when False, ``inputs`` is returned unchanged.
    Returns:
        output: tensor with the same shape as ``inputs``.
    """
    assert 0 <= p <= 1, 'p must be in range of [0,1]'

    if not training:
        return inputs
    if p == 1:
        # Every sample is dropped; avoids the 0/0 -> NaN of the rescale below.
        return torch.zeros_like(inputs)

    keep_prob = 1 - p
    # One draw per sample, broadcast over all remaining dimensions
    # (previously hard-coded to (B, 1, 1), which only fit 3-D inputs).
    mask_shape = (inputs.shape[0],) + (1,) * (inputs.dim() - 1)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    random_tensor = keep_prob + torch.rand(mask_shape, dtype=inputs.dtype, device=inputs.device)
    binary_tensor = torch.floor(random_tensor)

    return inputs / keep_prob * binary_tensor

def get_width_and_height_from_size(x):
    """Normalize a size argument to a pair.

    Args:
        x (int, tuple or list): data size.
    Returns:
        ``(x, x)`` for an int; a list or tuple is returned as-is.
    Raises:
        TypeError: for any other type.
    """
    if isinstance(x, int):
        return x, x
    if not isinstance(x, (list, tuple)):
        raise TypeError()
    return x
    
def calculate_output_image_size(input_image_size, stride):
    """Output length of a 1-D 'same'-padded convolution with the given stride.

    Needed for static padding (see Conv1dStaticSamePadding).

    Args:
        input_image_size (int or None): input sequence length.
        stride (int, tuple or list): convolution stride; for a sequence only
            the first element is used.
    Returns:
        int: ``ceil(input_image_size / stride)``, or None if
        ``input_image_size`` is None.
    """
    if input_image_size is None:
        return None

    step = stride[0] if not isinstance(stride, int) else stride
    return int(math.ceil(input_image_size / step))

# Note:
# The 'SamePadding' convolutions below produce an output length of
# ceil(input_length / stride), like TensorFlow's 'SAME' padding. Despite the
# name, the output length equals the input length only when stride == 1.

class Conv1dStaticSamePadding(nn.Conv1d):
    """1-D convolution with TensorFlow-style 'SAME' padding for a fixed input length.

    The padding is computed once in the constructor from ``image_size`` (the
    input sequence length) and applied in ``forward``.  The output length is
    ``ceil(image_size / stride)``.  ``self.stride`` is stored as a plain int.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
        if not isinstance(self.stride, int):
            self.stride = self.stride[0]

        # The padding can only be precomputed for a known integer input length.
        assert image_size is not None and isinstance(image_size, int)

        kernel_w = self.weight.size()[-1]  # weight shape: (out_channels, in_channels / groups, kW)
        out_w = math.ceil(image_size / self.stride)
        total_pad = max((out_w - 1) * self.stride + (kernel_w - 1) * self.dilation[0] + 1 - image_size, 0)

        if total_pad <= 0:
            self.static_padding = nn.Identity()
        else:
            # Odd padding puts the extra element on the right, as TensorFlow does.
            left = total_pad // 2
            self.static_padding_val = (left, total_pad - left)
            self.static_padding = None

    def forward(self, x):
        if self.static_padding is None:
            padded = F.pad(x, self.static_padding_val)  # constant zero padding
        else:
            padded = self.static_padding(x)
        return F.conv1d(padded, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
    
def get_same_padding_conv1d(image_size):
    """Return a ``Conv1dStaticSamePadding`` constructor with ``image_size`` pre-bound.

    Only the static-padding variant exists here, so the result is always a
    ``functools.partial`` of ``Conv1dStaticSamePadding``.  ``image_size`` must
    be an int; the class asserts this when an instance is created.

    Args:
        image_size (int): input sequence length the padding is computed for.
    Returns:
        functools.partial: call it like ``nn.Conv1d`` (in_channels,
        out_channels, kernel_size, stride=1, **kwargs).
    """
    conv_factory = partial(Conv1dStaticSamePadding, image_size=image_size)
    return conv_factory


class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck Block (1D variant).

    Two convolution variants are selected through ``ds_filter_num``:
      * ``input_filters > ds_filter_num`` (``use_ds_conv=True``): a 1x1
        expansion conv (only when ``expand_ratio != 1``) followed by a
        depthwise conv;
      * otherwise: a single full conv ``_ds_conv`` mapping ``input_filters``
        to ``input_filters * expand_ratio`` channels.
    These are followed by optional squeeze-and-excitation, optional
    concatenation of extra features (``edim > 0``), a 1x1 projection conv, and
    a residual connection when the block keeps both stride 1 and the channel
    count.

    Args:
        block_args (namedtuple): BlockArgs (see ``BlockDecoder``).
        global_params (namedtuple): GlobalParams shared between blocks.
        image_size (int): input length of this block, used for static 'SAME' padding.
        act (callable): activation function.
        ds_filter_num (int): threshold on ``input_filters`` above which the
            depthwise variant is used.
        edim (int): size of the extra feature vector concatenated before the
            projection conv; ``<= 0`` disables the concatenation.
    References:
        [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
        [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
        [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
    """
    
    def __init__(self, block_args, global_params, image_size=None, act=mish, ds_filter_num=10000, edim=-1):
        super().__init__()
        self._block_args = block_args
        self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow
        self._bn_eps = global_params.batch_norm_epsilon
        self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
        self.id_skip = block_args.id_skip  # whether to use skip connection and drop connect
        
        # Depthwise (expansion + depthwise) variant only when input_filters > ds_filter_num.
        self.use_ds_conv = False
        if block_args.input_filters > ds_filter_num:
            self.use_ds_conv = True
            
        if self.use_ds_conv:
            # Expansion phase (Inverted Bottleneck)
            inp = self._block_args.input_filters  # number of input channels
            oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
            if self._block_args.expand_ratio != 1:
                Conv1d = get_same_padding_conv1d(image_size=image_size)
                self._expand_conv = Conv1d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
                self._bn0 = nn.BatchNorm1d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
                # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
                
            # Depthwise convolution phase
            k = self._block_args.kernel_size
            s = self._block_args.stride
            Conv1d = get_same_padding_conv1d(image_size=image_size)
            self._depthwise_conv = Conv1d(
                in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
                kernel_size=k, stride=s, bias=False)
            self._bn1 = nn.BatchNorm1d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
            image_size = calculate_output_image_size(image_size, s)
            
        else:
            # Single full convolution replacing expansion + depthwise.
            inp = self._block_args.input_filters  # number of input channels
            oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
            k = self._block_args.kernel_size
            s = self._block_args.stride
            Conv1d = get_same_padding_conv1d(image_size=image_size)
            self._ds_conv = Conv1d(
                in_channels=inp, out_channels=oup, kernel_size=k, stride=s, bias=False)
            self._bn1 = nn.BatchNorm1d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
            image_size = calculate_output_image_size(image_size, s)
            
        # Squeeze and Excitation layer, if desired
        if self.has_se:
            Conv1d = get_same_padding_conv1d(image_size=(1))
            num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
            self._se_reduce = Conv1d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
            self._se_expand = Conv1d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
            
        # Concat extra feats: the projection conv sees oup + edim channels.
        self.edim = edim
        if edim > 0:
            oup = oup + edim
            
        # Pointwise convolution phase
        final_oup = self._block_args.output_filters
        Conv1d = get_same_padding_conv1d(image_size=image_size)
        self._project_conv = Conv1d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm1d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
        self._act = act
        
        # dropout (applied only on the residual path, see forward)
        self._dropout = nn.Dropout(global_params.dropout_rate)
        self._dropblock = None
        
    def forward(self, inputs, extra, drop_connect_rate=None):
        """MBConvBlock's forward function.

        Args:
            inputs (tensor): input tensor; assumed (batch, channels, length).
            extra (tensor): extra features, concatenated along the channel axis
                when ``self.edim > 0``; assumed (batch, edim). Ignored otherwise.
            drop_connect_rate (float): accepted for API compatibility; currently
                unused (drop connect is disabled).
        Returns:
            Output of this block after processing.
        """
        
        # Expansion and Depthwise Convolution
        x = inputs
        
        if self.use_ds_conv:
            if self._block_args.expand_ratio != 1:
                x = self._expand_conv(inputs)
                x = self._bn0(x)
                x = self._act(x)
                
            x = self._depthwise_conv(x)
            x = self._bn1(x)
            x = self._act(x)
            
        else:
            x = self._ds_conv(x)
            x = self._bn1(x)
            x = self._act(x)
            
        # Squeeze and Excitation
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool1d(x, 1)
            x_squeezed = self._se_reduce(x_squeezed)
            x_squeezed = self._act(x_squeezed)
            x_squeezed = self._se_expand(x_squeezed)
            x = torch.sigmoid(x_squeezed) * x
            
        # Concat extra feats: broadcast (B, E) over the length axis -> (B, E, L).
        if self.edim > 0:
            extra = extra.unsqueeze(-1).expand(
                extra.size(0), extra.size(1), x.size(-1)
            )
            x = torch.cat((x, extra), axis=1)
            
        # Pointwise Convolution
        x = self._project_conv(x)
        x = self._bn2(x)
        
        # Skip connection and drop connect
        input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
        # BlockDecoder stores stride as a one-element list (e.g. [1]) while
        # repeated blocks receive a plain int; normalise before comparing with 1,
        # otherwise a stride-[1] block would never get its skip connection.
        stride = self._block_args.stride
        if isinstance(stride, (list, tuple)):
            stride = stride[0]
        if self.id_skip and stride == 1 and input_filters == output_filters:
            # The combination of skip connection and drop connect brings about stochastic depth.
            #if drop_connect_rate:
            #    x = drop_connect(x, p=drop_connect_rate, training=self.training)
            if self._dropout is not None:
                x = self._dropout(x)
            if self._dropblock is not None:
                x = self._dropblock(x)
            x = x + inputs  # skip connection
            
        return x
                
                
class EfficientNet(nn.Module):
    """1D EfficientNet backbone with optional extra-feature conditioning.

    Pipeline: stem conv (kernel 7, stride 2) -> MBConvBlocks built from
    ``blocks_args`` -> 1x1 head conv to ``round_filters(512)`` channels ->
    global max pooling -> linear layer to ``global_params.num_classes``.

    Args:
        blocks_args (list[namedtuple]): BlockArgs per stage (e.g. from
            ``BlockDecoder.decode``); must be non-empty.
        global_params (namedtuple): GlobalParams shared between blocks; this
            class reads, among others, ``e_version``, ``batch_norm_momentum``,
            ``batch_norm_epsilon``, ``image_size``, ``n_blocks``,
            ``drop_connect_rate`` and ``num_classes``.
        in_channels (int): number of input channels.
        act (callable): activation function.
        stem_out_dim (int): stem output channels before ``round_filters``.
        edim (int): extra-feature size given to the first block of each stage
            whose index is ``< global_params.n_blocks``.
    References:
        [1] https://arxiv.org/abs/1905.11946 (EfficientNet)
    """
    def __init__(self, blocks_args=None, global_params=None, in_channels=-1, act=mish, stem_out_dim=32, edim=-1):
        super().__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self._blocks_args = blocks_args
        
        self.e_version = global_params.e_version
        print("e_version @ effnet: {}".format(self.e_version))
        self.edim = edim
        
        # Batch norm parameters
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon
        
        # Stem uses static 'SAME' padding computed from the input length
        image_size = global_params.image_size
        Conv1d = get_same_padding_conv1d(image_size=image_size)
        
        # Stem
        out_channels = round_filters(stem_out_dim, self._global_params)  # number of output channels
        
        #self._conv_stem = Conv1d(in_channels, out_channels, kernel_size=15, stride=2, bias=False)
        self._conv_stem = Conv1d(in_channels, out_channels, kernel_size=7, stride=2, bias=False)
        # [memo] stem of resnet-18
        #self.conv1 = nn.Conv1d(idim, 64, kernel_size=15, stride=2, padding=7, bias=False)
        
        self._bn0 = nn.BatchNorm1d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
        image_size = calculate_output_image_size(image_size, 2)
        
        self.n_emb_blocks = self._global_params.n_blocks
        
        # Build blocks
        self._blocks = nn.ModuleList([])
        for i, block_args in enumerate(self._blocks_args):
            
            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(block_args.input_filters, self._global_params),
                output_filters=round_filters(block_args.output_filters, self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
            )
            
            # Stages from index n_emb_blocks on receive no extra features.
            # NOTE(review): with n_blocks == -1 (get_effnet's default) this is
            # true for every stage, so no block is conditioned — confirm intended.
            if i >= self.n_emb_blocks:
                edim = -1
                
            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size, edim=edim))
            image_size = calculate_output_image_size(image_size, block_args.stride)
            if block_args.num_repeat > 1:  # modify block_args to keep same output size
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size, edim=-1))
                
        # Head
        in_channels = block_args.output_filters  # output of final block
        out_channels = round_filters(512, self._global_params)
        Conv1d = get_same_padding_conv1d(image_size=image_size)
        self._conv_head = Conv1d(in_channels, out_channels, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm1d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
        
        # Final linear layer (note: despite the attribute name, this is max pooling)
        self._avg_pooling = nn.AdaptiveMaxPool1d(1)
        self._fc = nn.Linear(out_channels, self._global_params.num_classes)

        # Activation function
        self._act = act

    def extract_features(self, inputs, extra):
        """Run stem, blocks and head convolution.

        Args:
            inputs (tensor): input tensor; assumed (batch, in_channels, length).
            extra: for ``e_version == 2`` a sequence indexed per block (the
                index is clamped to the last element); otherwise passed to
                every block unchanged.
        Returns:
            Output of the final convolution layer.
        """
        # Stem
        x = self._act(self._bn0(self._conv_stem(inputs)))
        
        # Blocks
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
                
            if self.e_version == 2:
                # if idx >= n_blocks: extra_i is not used.
                # NOTE(review): idx counts every block (repeats included), while
                # __init__ conditions blocks by stage index; when stages repeat,
                # some chunks of `extra` are skipped and the last one is reused.
                extra_idx = idx
                if extra_idx >= len(extra):
                    extra_idx = len(extra) - 1
                extra_i = extra[extra_idx]
            else:
                extra_i = extra
                
            x = block(x, extra_i, drop_connect_rate=drop_connect_rate)
            
        # Head
        x = self._act(self._bn1(self._conv_head(x)))
        return x
    
    def forward(self, inputs, extra):
        """EfficientNet's forward function.
           Calls extract_features to extract features, applies final linear layer, and returns logits.
        Args:
            inputs (tensor): Input tensor.
            extra (tensor): extra input tensor.
        Returns:
            Output of this model after processing.
        """
        # Convolution layers
        x = self.extract_features(inputs, extra)
        
        # Pooling and final linear layer
        x = self._avg_pooling(x)
        x = x.flatten(start_dim=1)
        x = self._fc(x)
        return x
    
    @classmethod
    def get_image_size(cls, model_name):
        """Not supported by this 1D variant.

        This class has no registry of named presets (there is no
        ``_check_model_name_is_valid`` on it), so a model name cannot be mapped
        to an input size. The input length is given by
        ``global_params.image_size`` instead.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "EfficientNet.get_image_size is not supported for the 1D variant; "
            "the input length is given by GlobalParams.image_size"
        )
    

# BlockDecoder: A Class for encoding and decoding BlockArgs
# efficientnet_params: A function to query compound coefficient
# get_model_params and efficientnet:
#     Functions to get BlockArgs and GlobalParams for efficientnet
# url_map and url_map_advprop: Dicts of url_map for pretrained weights
# load_pretrained_weights: A function to load pretrained weights
class BlockDecoder(object):
    """Convert between BlockArgs namedtuples and their compact string notation.

    Notation example: ``'r1_k3_s11_e1_i32_o16_se0.25_noskip'``: repeats,
    kernel size, stride (written twice), expand ratio, input filters, output
    filters, optional squeeze-excitation ratio, optional ``noskip``.
    """
    @staticmethod
    def _decode_block_string(block_string):
        """Get a block through a string notation of arguments.

        Args:
            block_string (str): A string notation of arguments.
                                Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.
        Returns:
            BlockArgs with ``stride`` stored as a one-element list, e.g. ``[2]``.
        Raises:
            AssertionError: if the string is not a str or the stride token is
                missing or malformed.
        """
        assert isinstance(block_string, str)
        
        ops = block_string.split('_')
        options = {}
        for op in ops:
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value
                
        # Check stride: one digit, or two identical digits (e.g. 's22').
        # Using .get() makes a missing 's' token fail the assertion instead of
        # raising KeyError.
        stride = options.get('s', '')
        assert len(stride) == 1 or (len(stride) == 2 and stride[0] == stride[1])
        
        return BlockArgs(
            num_repeat=int(options['r']),
            kernel_size=int(options['k']),
            stride=[int(options['s'][0])],
            expand_ratio=int(options['e']),
            input_filters=int(options['i']),
            output_filters=int(options['o']),
            se_ratio=float(options['se']) if 'se' in options else None,
            id_skip=('noskip' not in block_string))
    
    @staticmethod
    def _encode_block_string(block):
        """Encode a block to a string.

        Args:
            block (namedtuple): A BlockArgs-like object; ``stride`` may be an int
                or a one-element list/tuple (as produced by decoding).
        Returns:
            block_string: A String form of BlockArgs.
        """
        # BlockArgs has a `stride` field (a one-element list after decoding);
        # there is no `strides` pair, so write the single value twice.
        stride = block.stride
        if isinstance(stride, (list, tuple)):
            stride = stride[0]
        args = [
            'r%d' % block.num_repeat,
            'k%d' % block.kernel_size,
            's%d%d' % (stride, stride),
            'e%s' % block.expand_ratio,
            'i%d' % block.input_filters,
            'o%d' % block.output_filters,
        ]
        # se_ratio is None when the block has no squeeze-excitation.
        if block.se_ratio is not None and 0 < block.se_ratio <= 1:
            args.append('se%s' % block.se_ratio)
        if block.id_skip is False:
            args.append('noskip')
        return '_'.join(args)
    
    @staticmethod
    def decode(string_list):
        """Decode a list of string notations to specify blocks inside the network.

        Args:
            string_list (list[str]): A list of strings, each string is a notation of block.
        Returns:
            blocks_args: A list of BlockArgs namedtuples of block args.
        """
        assert isinstance(string_list, list)
        return [BlockDecoder._decode_block_string(s) for s in string_list]
    
    @staticmethod
    def encode(blocks_args):
        """Encode a list of BlockArgs to a list of strings.

        Args:
            blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.
        Returns:
            block_strings: A list of strings, each string is a notation of block.
        """
        return [BlockDecoder._encode_block_string(b) for b in blocks_args]
    
    
def get_effnet(effnet_type, idim, hdim, edim, n_blocks=-1, fs=500, dur=15, e_version=1):
    """Build a 1D EfficientNet from one of the preset block configurations.

    Args:
        effnet_type (int): architecture preset, one of 1, 2 or 3.
        idim (int): number of input channels.
        hdim (int): output size of the final linear layer (``num_classes``).
        edim (int): extra-feature size passed to EfficientNet as ``edim``.
        n_blocks (int): stored as ``GlobalParams.n_blocks``; EfficientNet gives
            ``edim`` only to stages with index ``< n_blocks``.
        fs (int): sampling rate; ``fs * dur`` is the input length used for
            static padding.
        dur (int): input duration (see ``fs``).
        e_version (int): conditioning method stored in ``GlobalParams``.
    Returns:
        EfficientNet
    Raises:
        ValueError: if ``effnet_type`` is not 1, 2 or 3.
    """
    print("effnet_type: {}".format(effnet_type))
    width_coefficient=1.0
    depth_coefficient=1.0
    dropout_rate = 0.2
    drop_connect_rate = 0.0
    
    if effnet_type == 1:
        blocks_args = [
            'r2_k5_s22_e2_i32_o64_se0.25', # 48 10e2
            'r2_k5_s22_e2_i64_o128_se0.25', # 48
            'r2_k7_s22_e2_i128_o128_se0.25', # 48-64
            'r2_k7_s22_e2_i128_o256_se0.25', # 0.25
            'r2_k7_s22_e2_i256_o256_se0.25', # 0.25
        ]
    elif effnet_type == 2:
        blocks_args = [
            'r2_k5_s22_e2_i32_o64_se0.25',
            'r1_k5_s22_e2_i64_o128_se0.25',
            'r2_k7_s22_e2_i128_o128_se0.25',
            'r1_k7_s22_e2_i128_o256_se0.25',
            'r2_k7_s22_e2_i256_o256_se0.25',
            'r2_k7_s22_e2_i256_o256_se0.25',
        ]
    elif effnet_type == 3:
        blocks_args = [
            'r2_k5_s22_e2_i32_o64_se0.25',
            'r1_k5_s22_e2_i64_o128_se0.25',
            'r2_k7_s22_e2_i128_o128_se0.25',
            'r1_k7_s22_e2_i128_o256_se0.25',
            'r2_k7_s22_e2_i256_o256_se0.25',
        ]
    else:
        # Previously fell through and crashed with UnboundLocalError below.
        raise ValueError(
            "unsupported effnet_type {!r}; expected 1, 2 or 3".format(effnet_type)
        )
        
    blocks_args = BlockDecoder.decode(blocks_args)
        
    global_params = GlobalParams(
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        image_size=fs * dur,
        dropout_rate=dropout_rate,
        num_classes=hdim,
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        drop_connect_rate=drop_connect_rate,
        depth_divisor=8,
        min_depth=None,
        include_top=False,
        n_blocks=n_blocks,
        e_version=e_version,
    )

    effnet = EfficientNet(blocks_args, global_params, in_channels=idim, edim=edim)
    return effnet
                                        
#def conv3x3(in_planes, out_planes, stride=1):
#    """3x3 convolution with padding"""
#    return nn.Conv1d(
#        in_planes, out_planes, kernel_size=7, stride=stride,
#        padding=3, bias=False
#    )
#
#class BasicBlock(nn.Module):
#    expansion = 1
#    
#    def __init__(self, inplanes, planes, stride=1, downsample=None):
#        super(BasicBlock, self).__init__()
#        self.conv1 = conv3x3(inplanes, planes, stride)
#        self.bn1 = nn.BatchNorm1d(planes)
#        self.relu = nn.ReLU(inplace=True)
#        self.conv2 = conv3x3(planes, planes)
#        self.bn2 = nn.BatchNorm1d(planes)
#        self.downsample = downsample
#        self.stride = stride
#        self.dropout = nn.Dropout(.2)
#        
#    def forward(self, x):
#        residual = x
#        
#        out = self.conv1(x)
#        out = self.bn1(out)
#        out = self.relu(out)
#        out = self.dropout(out)
#        out = self.conv2(out)
#        out = self.bn2(out)
#        
#        if self.downsample is not None:
#            residual = self.downsample(x)
#            
#        out += residual
#        out = self.relu(out)
#        return out
#    
#    
#class Bottleneck(nn.Module):
#    expansion = 4
#    
#    def __init__(self, inplanes, planes, stride=1, downsample=None):
#        super(Bottleneck, self).__init__()
#        self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=7, bias=False, padding=3)
#        self.bn1 = nn.BatchNorm1d(planes)
#        self.conv2 = nn.Conv1d(planes, planes, kernel_size=11, stride=stride,
#                               padding=5, bias=False)
#        self.bn2 = nn.BatchNorm1d(planes)
#        self.conv3 = nn.Conv1d(planes, planes * 4, kernel_size=7, bias=False, padding=3)
#        self.bn3 = nn.BatchNorm1d(planes * 4)
#        self.relu = nn.ReLU(inplace=True)
#        self.downsample = downsample
#        self.stride = stride
#        self.dropout = nn.Dropout(.2)
#        
#    def forward(self, x):
#        residual = x
#        
#        out = self.conv1(x)
#        out = self.bn1(out)
#        out = self.relu(out)
#        
#        out = self.conv2(out)
#        out = self.bn2(out)
#        out = self.relu(out)
#        out = self.dropout(out)
#        
#        out = self.conv3(out)
#        out = self.bn3(out)
#        if self.downsample is not None:
#            residual = self.downsample(x)
#            
#        out += residual
#        out = self.relu(out)
#        return out
#  
#
#class ResNetBase(nn.Module):
#    def __init__(self, block, layers, idim=12, num_classes=9, use_maxpool=True):
#        self.use_maxpool = use_maxpool
#        self.inplanes = 64
#        
#        super(ResNetBase, self).__init__()
#        
#        self.conv1 = nn.Conv1d(idim, 64, kernel_size=15, stride=2, padding=7,
#                               bias=False)
#        self.bn1 = nn.BatchNorm1d(64)
#        self.relu = nn.ReLU(inplace=True)
#        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
#        
#        self.layer1 = self._make_layer(block, 64,  layers[0])
#        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
#        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
#        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
#        
#        self.fc = nn.Linear(512 * block.expansion, num_classes)
#        
#        for m in self.modules():
#            if isinstance(m, nn.Conv1d):
#                n = m.kernel_size[0] * m.kernel_size[0] * m.out_channels
#                m.weight.data.normal_(0, math.sqrt(2. / n))
#            elif isinstance(m, nn.BatchNorm1d):
#                m.weight.data.fill_(1)
#                m.bias.data.zero_()
#
#    def _make_layer(self, block, planes, blocks, stride=1):
#        downsample = None
#        if stride != 1 or self.inplanes != planes * block.expansion:
#            downsample = nn.Sequential(
#                nn.Conv1d(self.inplanes, planes * block.expansion,
#                          kernel_size=1, stride=stride, bias=False),
#                nn.BatchNorm1d(planes * block.expansion),
#            )
#            
#        layers = []
#        layers.append(block(self.inplanes, planes, stride, downsample))
#        self.inplanes = planes * block.expansion
#        for i in range(1, blocks):
#            layers.append(block(self.inplanes, planes))
#            
#        return nn.Sequential(*layers)
#            
#    def forward(self, x):
#        x = self.conv1(x)
#        x = self.bn1(x)
#        x = self.relu(x)
#        x = self.maxpool(x)
#        
#        x = self.layer1(x)
#        x = self.layer2(x)
#        x = self.layer3(x)
#        x = self.layer4(x)
#
#        x = F.adaptive_max_pool1d(x, 1)
#        x = x.view(x.size(0), -1)
#        x = self.fc(x)
#        return x
#                
#def get_resnet(resnet_type, idim, hdim):
#    if resnet_type == 18:
#        net = ResNetBase(
#            BasicBlock, [2, 2, 2, 2], idim=idim, num_classes=hdim
#        )
#    elif resnet_type == 34:
#        net = ResNetBase(
#            BasicBlock, [3, 4, 6, 3], idim=idim, num_classes=hdim
#        )
#    elif resnet_type == 50:
#        net = ResNetBase(
#            Bottleneck, [3, 4, 6, 3], idim=idim, num_classes=hdim
#        )
#    elif resnet_type == 101:
#        net = ResNetBase(
#            Bottleneck, [3, 4, 23, 3], idim=idim, num_classes=hdim
#        )
#    elif resnet_type == 152:
#        net = ResNetBase(
#            Bottleneck, [3, 8, 36, 3], idim=idim, num_classes=hdim
#        )
#    else:
#        raise NotImplementedError    
#    return net

    
################################################################################
#
# DivideMix Training
#
################################################################################

class DivMixNet(nn.Module):
    """Classifier used for DivideMix training.

    ``self.model`` is a 1D EfficientNet (``--effnet-type``) or a ResNet
    (``--resnet-type``) producing ``hdim = odim * mixup_hdim_scale`` features.
    ``fc_cls`` maps them to ``odim`` logits. When ``args.edim > 0`` the extra
    input is first embedded by ``fc_emb`` and then handed to the backbone.
    """
    @staticmethod
    def add_arguments(parser):
        """Register DivMixNet's command-line arguments on *parser* and return it."""
        DivMixNet.add_divmix_arguments(parser)
        return parser
    
    @staticmethod
    def add_divmix_arguments(parser):
        """Add arguments for DivideMix
        """
        group = parser.add_argument_group("DivMixNet model specific setting")
        group.add_argument(
            "--resnet-type",
            type=int,
            default=-1,
            choices=[-1, 18, 34, 50, 101, 152],
            help="ResNet model setting",
        )
        group.add_argument(
            "--effnet-type",
            type=int,
            default=2,
            choices=[-1, 1, 2, 3],
            help="decide efficientnet architecture",
        )
        group.add_argument(
            "--warm-up",
            type=int,
            default=2,
            help="num of epochs for warmup stage",
        )
        group.add_argument(
            "--alpha",
            type=float,
            default=2.0,
            help="parameter for Beta"
        )
        group.add_argument(
            "--lambda-u",
            type=float,
            default=1,
            help="weight for unsupervised loss"
        )
        group.add_argument(
            "--p-threshold",
            type=float,
            default=0.5,
            help="clean probability threshold"
        )
        group.add_argument(
            "--T",
            type=float,
            default=1.0,
            help="sharpening temperature"
        )
        group.add_argument(
            "--r",
            type=float,
            default=0.5,
            help="noise ratio"
        )
        group.add_argument(
            "--ulabel-coeff",
            type=float,
            default=0.5,
            help="interpolate co-guessed labels with ground-truth",
        )
        group.add_argument(
            "--loader-version",
            type=int,
            default=2,
            help="control data loader version",
        )
        group.add_argument(
            "--use-f1",
            default=False,
            action='store_true',
        )
        group.add_argument(
            "--f1-coeff",
            type=float,
            default=0.01,
            help="coefficient of f1 loss",
        )
        # SWA related
        group.add_argument(
            "--swag-epoch-start",
            type=int,
            default=32,
            help="epoch to start SWAG",
        )
        group.add_argument(
            "--swag-collect-interval",
            type=int,
            default=1,
            help="SWAG model collection interval in epochs",
        )
        group.add_argument(
            "--progressive-input-epochs",
            type=int,
            default=-1,
            help='use 10<x<15 sec ECG input in specified epochs',
        )
        group.add_argument(
            "--anneal-lr-epoch",
            type=int,
            default=-1,
            help="change learning rate",
        )
        group.add_argument(
            "--e-version",
            type=int,
            default=2,
            help="conditioning method",
        )
        return parser
    
    def __init__(self, idim, odim, args):
        """
        Args:
            idim (int): number of input channels.
            odim (int): number of output classes.
            args: namespace; reads ``edim``, ``e_version``, ``resnet_type``,
                ``effnet_type`` and optionally ``mixup_hdim_scale`` and
                ``tensorboard_dir``.
        Raises:
            NotImplementedError: if ``args.e_version`` is not 1 or 2.
            ValueError: if both ``resnet_type`` and ``effnet_type`` are -1.
        """
        super(DivMixNet, self).__init__()
        
        self.expand = getattr(args, "mixup_hdim_scale", 10)
        self.idim = idim
        self.edim = args.edim
        self.odim = odim
        self.hdim = odim * self.expand

        # stat feat embedding
        self.e_version = args.e_version
        if self.e_version == 1:
            self.e_expand = 10
            self.e_odim = args.edim
        elif self.e_version == 2:
            # One embedding chunk of block_emb_dim per conditioned block.
            self.n_blocks = 5
            self.block_emb_dim = 32
            self.e_odim = self.n_blocks * self.block_emb_dim
            self.e_expand = 10  # 10*5*32 = 1600
        else:
            raise NotImplementedError
        
        # load model
        if args.resnet_type != -1:
            # NOTE(review): get_resnet's definition is commented out in this part
            # of the file; confirm it exists elsewhere before using --resnet-type.
            self.model = get_resnet(args.resnet_type, self.idim, self.hdim) #, self.e_odim)  # (not impled)
        elif args.effnet_type != -1:
            if self.e_version == 1:
                self.model = get_effnet(args.effnet_type, self.idim, self.hdim, self.e_odim)
            else:
                self.model = get_effnet(
                    args.effnet_type, self.idim, self.hdim, self.block_emb_dim,
                    n_blocks=self.n_blocks,
                    e_version=self.e_version
                )
        else:
            # Without this, the next line failed with an obscure AttributeError.
            raise ValueError(
                "either resnet_type or effnet_type must be set (both are -1)"
            )
        self.model.e_version = self.e_version
        
        if self.edim > 0:
            _edim_first = self.edim if self.edim > 1000 else self.edim * self.expand
            _edim = self.e_odim * self.e_expand
            self.fc_emb = nn.Sequential(
                nn.Linear(self.edim, _edim_first),
                nn.BatchNorm1d(_edim_first),
                Mish(),
                nn.Linear(_edim_first, _edim),
                nn.BatchNorm1d(_edim),
                Mish(),
                nn.Linear(_edim, _edim),
                nn.BatchNorm1d(_edim),
                Mish(),
                nn.Linear(_edim, self.e_odim),
                nn.BatchNorm1d(self.e_odim),
                Mish(),
            )

        self.fc_cls = nn.Sequential(
            nn.Linear(self.hdim, self.hdim),
            nn.BatchNorm1d(self.hdim),
            Mish(),
            nn.Linear(self.hdim, self.hdim),
            nn.BatchNorm1d(self.hdim),
            Mish(),
            nn.Linear(self.hdim, odim),
        )
        
        # logging
        tensorboard_dir = getattr(args, "tensorboard_dir", None)
        if tensorboard_dir is not None:
            self.writer = SummaryWriter(tensorboard_dir)
        else:
            self.writer = None
            
        # metric
        #self.reporter = Reporter()
        self.thr = 0.3
        self.sym2int = {'-1': 0}

    def _forward_impl(self, xs_pad, extra, return_logit=True):
        """Embed ``extra`` (if enabled), run the backbone and optionally ``fc_cls``."""
        
        if self.edim > 0:
            # compute embedding
            extra = self.fc_emb(extra)
            
            # Fall back to an all-zero embedding if it contains NaNs.
            if torch.isnan(extra).any():
                extra = torch.zeros(
                    (xs_pad.size(0), self.e_odim), dtype=torch.float32, device=xs_pad.device
                )
                
            if self.e_version == 2:  # (B, C) -> tuple of (B, block_emb_dim) chunks
                extra = torch.split(extra, self.block_emb_dim, dim=-1)
                
        hs_pad = self.model(xs_pad, extra)
        if return_logit:
            hs_pad = self.fc_cls(hs_pad)
        return hs_pad
    
    def forward(self, xs_pad, extra, ilens=None, ys=None, return_logit=False):
        """
        xs_pad (torch.Tensor): input ECG data (B, C, T)
        extra (torch.Tensor): extra input data (B, E')
        Returns backbone features (hdim) by default, logits if return_logit.
        """
        hs_pad = self._forward_impl(xs_pad, extra, return_logit)
        return hs_pad
    
    def predict(self, xs_pad, extra, ilens=None, ys=None, prob=True):
        """Return prediction
        Args:
            xs_pad (torch.Tensor): input ECG data, (B, C, T)
            prob (bool): return repr after sigmoid if True
        """
        hs_pad = self._forward_impl(xs_pad, extra, return_logit=True)
        if prob:
            hs_pad = torch.sigmoid(hs_pad)
        return hs_pad

    
def disable_bn(model):
    """Put every ``nn.BatchNorm1d`` submodule of *model* into eval mode.

    All other submodules keep their current train/eval mode. Returns None.
    """
    bn_layers = [m for m in model.modules() if isinstance(m, nn.BatchNorm1d)]
    for bn in bn_layers:
        bn.eval()
            
def enable_bn(model):
    """Switch *model* and all of its submodules to training mode. Returns None."""
    model.train(mode=True)
    
def zero_grad(model):
    """Clear gradients by setting ``.grad`` to None on every parameter of *model*."""
    params = model.parameters()
    for p in params:
        p.grad = None

def update_param(model, opt, L, scaler, clip):
    """Perform one optimizer step through a gradient scaler.

    Backpropagates the scaled loss *L*, unscales the gradients held by *opt*,
    clips the total gradient norm of *model*'s parameters to *clip*, steps the
    optimizer via *scaler*, updates the scale factor and finally sets every
    parameter's ``.grad`` to None.

    *scaler* is expected to be a ``GradScaler``-like object (``scale``,
    ``unscale_``, ``step``, ``update``).
    """
    scaled_loss = scaler.scale(L)
    scaled_loss.backward()
    # Unscale before clipping so the norm threshold applies to real gradients.
    scaler.unscale_(opt)
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
    scaler.step(opt)
    scaler.update()
    for param in model.parameters():
        param.grad = None


class MacroSoftF1Loss(nn.Module):
    """Differentiable macro-averaged soft-F1 loss for multi-label logits.

    Predictions are ``sigmoid(inputs)``; per class (dim 1) the soft counts are
    accumulated over the batch (dim 0) and the loss is
    ``mean_c(1 - 2*TP_c / (2*TP_c + FP_c + FN_c + eps))``.

    Args:
        eps (float): small constant keeping the denominator non-zero.
    """

    def __init__(self, eps=1e-8):
        super(MacroSoftF1Loss, self).__init__()
        self.eps = eps

    def forward(self, inputs, targets):
        """Return the scalar soft-F1 loss.

        Args:
            inputs (torch.Tensor): logits, classes along dim 1.
            targets (torch.Tensor): 0/1 (or soft) labels with the same shape.
        """
        targets = targets.float()
        probs = torch.sigmoid(inputs)

        tp = torch.sum(probs * targets, dim=0)
        fp = torch.sum(probs * (1 - targets), dim=0)  # predicted on, label off
        fn = torch.sum((1 - probs) * targets, dim=0)  # predicted off, label on

        soft_f1 = 2 * tp / (2 * tp + fp + fn + self.eps)
        return torch.mean(1 - soft_f1)
    
class SemiLoss(object):
    """DivideMix-style semi-supervised objective.

    Calling the object returns ``(Lx, Lu, lamb)``:

    * ``Lx``: BCE-with-logits on the (mixed) labeled samples, plus
      ``f1_coeff * MacroSoftF1Loss`` on ``logits_s``/``targets_s`` when given;
    * ``Lu``: mean squared error between ``sigmoid(outputs_u)`` and ``targets_u``;
    * ``lamb``: ``lambda_u * clip((epoch - warm_up) / rampup_length, 0, 1)``.

    ``use_fscore`` is accepted but not used.
    """

    def __init__(self, lambda_u, warm_up, rampup_length=16, use_fscore=False, f1_coeff=0.01):
        self.lambda_u = lambda_u
        self.warm_up = warm_up
        self.rampup_length = rampup_length
        self.f1_coeff = f1_coeff
        self.lx_loss = MacroSoftF1Loss()

    def linear_rampup(self, current):
        """Weight of the unlabeled loss, ramping linearly after warm-up."""
        progress = np.clip((current - self.warm_up) / self.rampup_length, 0.0, 1.0)
        return self.lambda_u * float(progress)

    def __call__(
            self,
            outputs_x,
            targets_x,
            outputs_u,
            targets_u,
            epoch,
            logits_s=None,
            targets_s=None,
    ):
        supervised = F.binary_cross_entropy_with_logits(outputs_x, targets_x)
        if logits_s is not None:
            supervised = supervised + self.f1_coeff * self.lx_loss(logits_s, targets_s)

        unsupervised = torch.sigmoid(outputs_u).sub(targets_u).pow(2).mean()
        return supervised, unsupervised, self.linear_rampup(epoch)
    

class NegEntropy(object):
    """Batch-mean negative entropy of the softmax over dim 1.

    Returns ``mean_b sum_c p_bc * log p_bc`` (always <= 0); minimising it
    pushes predictions toward higher entropy.
    """

    def __call__(self, outputs):
        # log_softmax is numerically stable: softmax(...).log() gives -inf once a
        # probability underflows to 0, and 0 * -inf turns the result into NaN.
        log_probs = F.log_softmax(outputs, dim=1)
        probs = torch.softmax(outputs, dim=1)
        return torch.mean(torch.sum(log_probs * probs, dim=1))
    
    
################################################################################
#
# Training function
#
################################################################################

def challenge_score_fn(target, pred_bin):
    """Official challenge metric of binary predictions ``pred_bin`` vs ``target``.

    Reads the module-level ``weights``, ``classes`` and ``normal_class``.
    """
    score = compute_challenge_metric(
        weights, target, pred_bin, classes, normal_class
    )
    return score

def tune_threshold(loader, model, thr_num=100, num_samples=4):
    """Search per-class decision thresholds that maximise the challenge score.

    The sub-models ``model.m1`` .. ``model.m4`` are evaluated on ``loader``
    and their sigmoid outputs are averaged.  Ray Tune then draws
    ``num_samples`` threshold vectors (one value per class from
    ``U(0.1, 0.6)``), and the best-scoring vector is returned.

    Args:
        loader: data loader accepted by ``epoch_test``.
        model: container exposing ``m1``..``m4`` (e.g. ``SwaModel``); it is
            moved to CUDA for evaluation and back to the CPU afterwards.
        thr_num (int): unused; kept for backward compatibility.
        num_samples (int): number of Ray Tune trials.

    Returns:
        list[float]: one threshold per output class, in class-index order.
    """
    model = model.to('cuda')
    try:
        # epoch_test(..., return_logit=True) returns sigmoid probabilities
        # (B, odim) together with the ground truth.
        preds = []
        target = None
        for sub_model in (model.m1, model.m2, model.m3, model.m4):
            sub_pred, target = epoch_test(None, sub_model, None, loader, return_logit=True)
            preds.append(sub_pred)
    finally:
        model = model.to('cpu')
    pred = sum(preds) / len(preds)

    # Bind the scoring inputs explicitly instead of reading module globals
    # inside the objective, which Ray executes in worker processes.
    score_weights, score_classes = weights, classes
    keys = [str(i) for i in range(pred.shape[1])]

    def _thresholds(config):
        # Look up by key so thresholds stay aligned with the class columns
        # regardless of how the config dict happens to be ordered.
        return np.array([config[key] for key in keys])

    def _tune_objective(config):
        pred_bin = (pred > _thresholds(config)).astype(np.int32)
        score = compute_challenge_metric(
            score_weights, target, pred_bin, score_classes, normal_class
        )
        # Feed the score back to Tune.
        tune.report(challenge_score=score)

    analysis = tune.run(
        _tune_objective,
        config={key: tune.uniform(0.1, 0.6) for key in keys},
        num_samples=num_samples,
    )

    best_config = analysis.get_best_config(metric="challenge_score", mode="max")
    return _thresholds(best_config).tolist()

def epoch_warmup(epoch, model, optimizer, data_loader, scaler, clip, grad_noise=False):
    """Train ``model`` for one epoch on all samples (clean and noisy) with BCE.

    Args:
        epoch (int): current epoch number (not used by this function)
        model: network; called as ``model(feats, extra, None, ys, True)`` for logits
        optimizer: optimizer for ``model``
        data_loader: loader yielding ``(feats, ys, ids, extra)`` batches
        scaler: torch amp scaler
        clip (float): gradient clip threshold
        grad_noise (bool): not used

    Returns:
        float: mean training loss over the epoch's batches.
    """
    model.train()

    tot_loss = []
    for feats, ys, _ids, extra in data_loader:
        feats, ys, extra = feats.cuda(), ys.cuda(), extra.cuda()
        optimizer.zero_grad()

        # Only the forward pass and loss go under autocast; the AMP docs advise
        # against running backward / optimizer steps inside the autocast region.
        with torch.cuda.amp.autocast():
            logits = model(feats, extra, None, ys, True)
            L = F.binary_cross_entropy_with_logits(logits, ys)
        update_param(model, optimizer, L, scaler, clip)
        tot_loss.append(L.item())

    return np.mean(np.array(tot_loss))

def _epoch_train(all_inputs, all_extra, mixed_target, idx,
                 mixup_beta, model, batch_size, criterion, epoch_coeff, use_f1):
    # (expanded_B, hdim)
    all_hidden = model(all_inputs, all_extra, None, None, False)
    hidden_a, hidden_b = all_hidden, all_hidden[idx]
    
    if use_f1:
        mixed_hidden = torch.cat(
            (all_hidden[:batch_size], mixup_beta * hidden_a + (1 - mixup_beta) * hidden_b),
            axis=0
        )
        
        logits = model.fc_cls(mixed_hidden)
        logits_s = logits[:batch_size]
        logits_x = logits[batch_size:batch_size*2]
        logits_u = logits[batch_size*2:]
        
        Lx, Lu, lamb = criterion(
            logits_x, mixed_target[batch_size:batch_size*2],
            logits_u, mixed_target[batch_size*2:],
            epoch_coeff,
            logits_s=logits_s, targets_s=mixed_target[:batch_size],
        )
    else:
        mixed_hidden = mixup_beta * hidden_a + (1 - mixup_beta) * hidden_b
        
        logits = model.fc_cls(mixed_hidden)
        logits_x = logits[:batch_size]
        logits_u = logits[batch_size:]
        
        Lx, Lu, lamb = criterion(
            logits_x, mixed_target[:batch_size],
            logits_u, mixed_target[batch_size:],
            epoch_coeff
        )
        
    loss = Lx + lamb * Lu
    return loss, Lx, Lu

def epoch_train(
        epoch,
        model,
        model_2,
        optimizer,
        labeled_train_loader,
        unlabeled_train_loader,
        T,
        alpha,
        scaler,
        clip,
        criterion,
        batch_size,
        ulabel_coeff,
        grad_noise=False,
        use_f1=False,
):
    """Core train stage of DivideMix: train ``model`` while ``model_2`` is fixed.

    Args:
        epoch (int): current epoch number
        model: network being trained
        model_2: peer network, put in eval mode and used only for co-guessing
        optimizer: optimizer for ``model``
        labeled_train_loader: yields ``(inputs_x, inputs_x2, labels_x, w_x, extra_x)``
        unlabeled_train_loader: yields ``(inputs_u, inputs_u2, labels_u, extra_u)``;
            it is restarted whenever it runs out before the labeled loader
        T (float): temperature for sharpening (not used)
        alpha (float): parameter for Beta (mixup)
        scaler: torch amp GradScaler
        clip (float): gradient clipping threshold
        criterion: objective function (SemiLoss)
        batch_size (int): nominal batch size; only used to estimate the number
            of iterations per epoch for the ramp-up coefficient
        ulabel_coeff (float): interpolate co-guessed labels with ground-truth
        grad_noise (bool): not used
        use_f1 (bool): add the soft-F1 term on unmixed labeled samples

    Returns:
        tuple: mean labeled loss ``Lx`` and mean unlabeled loss ``Lu``.
    """
    model.train()
    model_2.eval()  # fix one network and train the other

    unlabeled_train_iter = iter(unlabeled_train_loader)
    num_iter = (len(labeled_train_loader.dataset) // batch_size) + 1

    tot_Lx = []
    tot_Lu = []

    for batch_idx, (inputs_x, _inputs_x2, labels_x, w_x, extra_x) in enumerate(labeled_train_loader):
        # Cycle the unlabeled loader. Use the builtin next(): DataLoader
        # iterators do not provide a .next() method in recent PyTorch.
        try:
            inputs_u, _inputs_u2, labels_u, extra_u = next(unlabeled_train_iter)
        except StopIteration:
            unlabeled_train_iter = iter(unlabeled_train_loader)
            inputs_u, _inputs_u2, labels_u, extra_u = next(unlabeled_train_iter)
        batch_size = inputs_x.size(0)

        w_x = w_x.view(-1, 1).type(torch.FloatTensor)  # (B, 1)

        # the second views (*_x2 / *_u2) are unused, so they stay on the CPU
        inputs_x, labels_x, w_x = inputs_x.cuda(), labels_x.cuda(), w_x.cuda()
        inputs_u, labels_u = inputs_u.cuda(), labels_u.cuda()
        extra_x, extra_u = extra_x.cuda(), extra_u.cuda()

        with torch.no_grad():
            with torch.cuda.amp.autocast():
                # label co-guessing of unlabeled samples: average both networks'
                # probabilities, then blend with the given labels
                outputs_u11 = model(inputs_u, extra_u, return_logit=True)
                outputs_u21 = model_2(inputs_u, extra_u, return_logit=True)
                pu = (torch.sigmoid(outputs_u11) + torch.sigmoid(outputs_u21)) / 2
                targets_u = ((1.0 - ulabel_coeff) * pu + ulabel_coeff * labels_u).detach()

                # label refinement of labeled samples: weight given labels by w_x
                # and the current prediction by (1 - w_x)
                outputs_x = model(inputs_x, extra_x, return_logit=True)
                px = torch.sigmoid(outputs_x)
                targets_x = (w_x * labels_x + (1 - w_x) * px).detach()

        # mixmatch; max(b, 1 - b) keeps each mixed sample weighted >= 0.5 on itself
        mixup_beta = np.random.beta(alpha, alpha)
        mixup_beta = max(mixup_beta, 1 - mixup_beta)
        epoch_coeff = epoch + batch_idx / num_iter

        all_inputs = torch.cat([inputs_x, inputs_u], dim=0)
        all_extra = torch.cat([extra_x, extra_u], dim=0)
        all_targets = torch.cat([targets_x, targets_u], dim=0)

        mixup_idx = torch.randperm(all_inputs.size(0))
        mixed_target = mixup_beta * all_targets + (1 - mixup_beta) * all_targets[mixup_idx]

        if use_f1:
            # unmixed labeled targets first, matching _epoch_train's layout
            mixed_target = torch.cat((targets_x, mixed_target), dim=0)

        with torch.cuda.amp.autocast():
            loss, Lx, Lu = _epoch_train(
                all_inputs,
                all_extra,
                mixed_target,
                mixup_idx,
                mixup_beta,
                model,
                batch_size,
                criterion,
                epoch_coeff,
                use_f1
            )
        # backward / optimizer step outside autocast, as the AMP docs recommend
        update_param(model, optimizer, loss, scaler, clip)
        tot_Lx.append(Lx.item())
        tot_Lu.append(Lu.item())

    return np.mean(tot_Lx), np.mean(tot_Lu)
            
def epoch_eval_train(model, all_loss, loader):
    """Co-divide step: fit a 2-component GMM to per-sample training losses.

    Every sample's summed BCE loss is computed (in eval mode) and min-max
    normalised, then a two-component ``GaussianMixture`` is fitted.

    Args:
        model: network evaluated as ``model(feats, extra, return_logit=True)``
        all_loss (list): history list; the normalised losses are appended
        loader: loader yielding ``(feats, ys, ids, extra)``, where ``ids`` are
            the samples' indices into ``loader.dataset``

    Returns:
        tuple: ``(prob, all_loss)`` where ``prob[i]`` is the posterior of the
        lower-mean (small-loss) GMM component for sample ``i``.
    """
    model.eval()
    losses = torch.zeros(len(loader.dataset))

    with torch.no_grad():
        for feats, ys, ids, extra in loader:
            feats, ys, extra = feats.cuda(), ys.cuda(), extra.cuda()

            with torch.cuda.amp.autocast():
                logits = model(feats, extra, return_logit=True)
                loss = torch.sum(
                    F.binary_cross_entropy_with_logits(logits, ys, reduction='none'),
                    dim=1,
                )
            # scatter the batch's per-sample losses into dataset order at once
            losses[torch.as_tensor(ids, dtype=torch.long)] = loss.float().cpu()

    zero_grad(model)

    # min-max normalisation; a constant loss vector would otherwise give 0/0
    # (NaN), which GaussianMixture.fit rejects
    loss_range = losses.max() - losses.min()
    if loss_range > 0:
        losses = (losses - losses.min()) / loss_range
    else:
        losses = torch.zeros_like(losses)
    all_loss.append(losses)

    input_loss = losses.reshape(-1, 1).numpy()

    # fit a two-component GMM to the loss
    gmm = GaussianMixture(
        n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4
    )
    gmm.fit(input_loss)
    prob = gmm.predict_proba(input_loss)
    prob = prob[:, gmm.means_.argmin()]
    return prob, all_loss

def epoch_test(
        epoch, model_1, model_2, loader, thr=0.3, return_logit=False, is_swa=False
):
    """Evaluate current model(s)

    Args:
        epoch (int): current epoch number (not used)
        model_1: first model
        model_2: second model, or None to evaluate ``model_1`` only
        loader: data loader yielding ``(feats, ys, extra)``
        thr (float): threshold to make binary predictions
        return_logit (bool): if True return ``(pred_1, target)``; despite the
            name, ``pred_1`` holds model_1's sigmoid probabilities
        is_swa (bool): models are ``AveragedModel`` wrappers; call ``.module``

    Returns:
        ``(pred_1, target)`` when ``return_logit``; otherwise
        ``(acc_1, f1_1, challenge_1, acc_2, f1_2, challenge_2)``, with the
        three model_2 entries set to None when ``model_2`` is None.

    Both models are put back into train mode and their grads cleared.
    """
    models = [model_1] if model_2 is None else [model_1, model_2]
    preds = [[] for _ in models]
    target = []
    for m in models:
        m.eval()

    with torch.no_grad():
        for feats, ys, extra in loader:
            feats, extra = feats.cuda(), extra.cuda()

            with torch.cuda.amp.autocast():
                for m, pred in zip(models, preds):
                    net = m.module if is_swa else m
                    outputs = net.predict(feats, extra, None, prob=True)
                    pred.append(outputs.cpu().data.numpy())
            target.append(ys.cpu().data.numpy().astype(np.int32))

    for m in models:
        zero_grad(m)
        m.train()

    target = np.vstack(target)
    probs = [np.vstack(p) for p in preds]

    if return_logit:
        return probs[0], target

    scores = []
    for p in probs:
        pred_bin = (p > thr).astype(np.int32)
        scores.extend([
            compute_accuracy(target, pred_bin),
            compute_f_measure(target, pred_bin)[0],
            compute_challenge_metric(
                weights, target, pred_bin, classes, normal_class
            ),
        ])
    if model_2 is None:
        # previously a NameError on pred_bin_2; report "no second model"
        scores.extend([None, None, None])

    return tuple(scores)

@torch.no_grad()
def update_bn(loader, model, device=None):
    r"""Recompute BatchNorm running_mean / running_var buffers of ``model``.

    Adapted from ``torch.optim.swa_utils.update_bn`` for this project's batch
    layout.  All BN running statistics are reset, the momentum is temporarily
    set to ``None`` (cumulative average), and one pass is made over
    ``loader`` in train mode.

    Args:
        loader: iterable of ``(inputs, _, _, extra)`` batches.
        model (torch.nn.Module): called as ``model(inputs, extra, None, None, True)``.
        device (torch.device or str, optional): if set, both ``inputs`` and
            ``extra`` are moved to this device before the forward pass.

    The model's original train/eval mode and BN momenta are restored at the end.
    """
    momenta = {}
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.running_mean = torch.zeros_like(module.running_mean)
            module.running_var = torch.ones_like(module.running_var)
            momenta[module] = module.momentum

    if not momenta:
        return

    was_training = model.training
    model.train()
    for module in momenta.keys():
        module.momentum = None
        module.num_batches_tracked *= 0

    for inputs, _, _, extra in loader:
        if device is not None:
            # honour ``device`` for both tensors (extra was hard-wired to CUDA)
            inputs, extra = inputs.to(device), extra.to(device)

        with torch.cuda.amp.autocast():
            model(inputs, extra, None, None, True)

    for bn_module in momenta.keys():
        bn_module.momentum = momenta[bn_module]

    model.train(was_training)

def save_log(dname, f1_list, challenge_list, model_idx=1, prefix=""):
    """Log per-epoch scores and promote the best checkpoints.

    Writes ``log_<prefix><idx>.log`` (``epoch,f1,challenge`` per line) and
    ``best_ids_<prefix><idx>.log`` (JSON with the argmax epochs), then copies
    ``<prefix>model_<idx>_epoch_<best>.pth`` to
    ``<prefix>model_<idx>_best_f1.pth`` and ``..._best_challenge.pth``.
    """
    rows = [
        "{},{},{}\n".format(epoch, f1, challenge)
        for epoch, (f1, challenge) in enumerate(zip(f1_list, challenge_list))
    ]
    with open('%s/log_%s%d.log' % (dname, prefix, model_idx), 'w') as log_file:
        log_file.writelines(rows)

    best = {}
    best["best_f1"] = int(np.argmax(np.array(f1_list)))
    best["best_challenge"] = int(np.argmax(np.array(challenge_list)))
    with open('%s/best_ids_%s%d.log' % (dname, prefix, model_idx), 'w') as ids_file:
        json.dump(best, ids_file, indent=4)

    for criterion_name in best:
        src = '%s/%smodel_%d_epoch_%d.pth' % (dname, prefix, model_idx, best[criterion_name])
        dst = '%s/%smodel_%d_%s.pth' % (dname, prefix, model_idx, criterion_name)
        shutil.copyfile(src, dst)
        
class SwaModel(nn.Module):
    """Container holding the four checkpoints used as an ensemble.

    Args:
        idim (int): input dimension (stored as ``self.idim``)
        odim (int): output dimension (stored as ``self.odim``)
        args: training arguments (not stored)
        m1, m2, m3, m4: member models; ``nn.Module`` values are registered as
            sub-modules, so ``.to()`` / ``state_dict()`` cover them.
    """

    def __init__(self, idim, odim, args, m1=None, m2=None, m3=None, m4=None):
        super(SwaModel, self).__init__()
        self.idim, self.odim = idim, odim
        for name, member in (("m1", m1), ("m2", m2), ("m3", m3), ("m4", m4)):
            setattr(self, name, member)
    
def swa_divmix_train(args, leads):
    """Train two DivideMix networks and their SWA averages for one lead set.

    Also sets the module-level ``classes`` / ``weights`` used by the challenge
    metric, writes per-epoch checkpoints and logs to ``args.outdir``, and tunes
    per-class thresholds on the validation set.

    Args:
        args: parsed training arguments (``args.edim`` may be overwritten)
        leads: lead names; ``len(leads)`` must equal ``args.n_leads``

    Returns:
        dict with ``classes``, ``leads``, ``imputer`` (None) and ``classifier``
        (a ``SwaModel`` of the best-challenge checkpoints, thresholds in ``thr``).
    """
    # get idim, edim and odim
    sym2int = read_dic(args.dict)
    idim = args.n_leads
    if args.edim > 0:
        args.edim = get_edim(
            idim, args.valid_json.replace('.json', '_extra.json'),
            preproc_config["use_tsfresh"]
        )
    odim = len(list(sym2int.keys()))
    assert(idim == len(leads))
    print('idim: {}, edim: {}, odim: {}'.format(idim, args.edim, odim))

    os.makedirs(args.outdir, exist_ok=True)

    model_1 = DivMixNet(idim, odim, args)
    model_2 = DivMixNet(idim, odim, args)
    model_1.sym2int = sym2int
    model_2.sym2int = sym2int
    swa_model_1 = AveragedModel(model_1)
    swa_model_2 = AveragedModel(model_2)

    # Codes scored as equivalent. NOTE(review): assumes _classes (from the
    # official weights) holds sets of codes with each pair merged into one set.
    equivalent_pairs = (
        {"17338001", "427172004"},
        {"63593006", "284470004"},
        {"59118001", "713427006"},
        {"164909002", "733534002"},
    )

    def _replace(class_name):
        """Return the set used to look ``class_name`` up in ``_classes``."""
        # The previous `class_name == "a" or "b"` tests were always truthy, so
        # every class mapped to the first pair and all weights became 1.
        for pair in equivalent_pairs:
            if class_name in pair:
                return set(pair)
        return {class_name}

    # update classes and weights for computation of challenge score
    global classes
    global weights
    classes = list(sym2int.keys())
    weights = np.zeros_like(_weights)
    for ci in classes:
        src_i = _classes.index(_replace(ci))
        dst_i = classes.index(ci)
        for cj in classes:
            src_j = _classes.index(_replace(cj))
            dst_j = classes.index(cj)
            weights[dst_i, dst_j] = _weights[src_i, src_j]

    # write model config
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim, odim, vars(args)),
                indent=4,
                ensure_ascii=False,
                sort_keys=True,
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        print("ARGS: {}: {}".format(key, str(vars(args)[key])))
    print(model_1)

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    dtype = torch.float32

    model_1 = model_1.to(device=device, dtype=dtype)
    model_2 = model_2.to(device=device, dtype=dtype)
    swa_model_1 = swa_model_1.to(device=device, dtype=dtype)
    swa_model_2 = swa_model_2.to(device=device, dtype=dtype)

    optimizer_1, optimizer_2, scheduler = get_optimizer(
        args, model_1.parameters(), model_2.parameters()
    )

    # setup amp
    scaler_1 = torch.cuda.amp.GradScaler()
    scaler_2 = torch.cuda.amp.GradScaler()

    criterion = SemiLoss(args.lambda_u, args.warm_up, 5, f1_coeff=args.f1_coeff)

    master_train_loader = physionet_dataloader(
        args.train_json,
        batch_size=args.batch_size,
        num_workers=os.cpu_count(),
        cache_dir=args.outdir + '/cache_tr',
        idim=idim,
        odim=odim,
        n_leads=args.n_leads,
        min_seqlen=args.min_seqlen,
        max_seqlen=args.max_seqlen,
        loader_version=args.loader_version,
        edim=args.edim,
        efeat_drop=args.efeat_drop,
    )
    master_valid_loader = physionet_dataloader(
        args.valid_json,
        batch_size=args.batch_size,
        num_workers=os.cpu_count(),
        cache_dir=args.outdir + '/cache_dt',
        idim=idim,
        odim=odim,
        n_leads=args.n_leads,
        max_seqlen=args.max_seqlen,
        loader_version=args.loader_version,
        edim=args.edim,
        efeat_drop=0.0,
    )

    all_loss = [[], []]  # save the history of losses from two networks
    log_f1_1 = []
    log_f1_2 = []
    log_challenge_1 = []
    log_challenge_2 = []
    log_f1_swag_1 = []
    log_f1_swag_2 = []
    log_challenge_swag_1 = []
    log_challenge_swag_2 = []
    # Placeholder for epochs without an SWA model: -inf can never win the
    # argmax in save_log (0.0 could, e.g. against negative challenge scores).
    no_swa_score = float('-inf')

    seqlen_list = np.ones(args.epochs + 1) * args.max_seqlen
    seqlen_list = seqlen_list.astype(np.int32)

    # built during warm-up; created lazily for the BN update if warm_up == 0
    warmup_train_loader = None

    for epoch in range(args.epochs):
        print(datetime.datetime.now())
        start_time = time.time()

        # 1. epoch dependent setting
        print("epoch: {}".format(epoch))
        max_seqlen = seqlen_list[epoch]

        # 1.2 progressive learning of effnet v2
        eval_train_loader = master_train_loader.run('eval_train', max_seqlen=max_seqlen)
        valid_loader = master_valid_loader.run('test', max_seqlen=max_seqlen)

        # 2. training
        if epoch < args.warm_up:
            # 2.1 warmup training
            warmup_train_loader = master_train_loader.run('warmup', max_seqlen=max_seqlen)
            np_loss_1 = epoch_warmup(
                epoch, model_1, optimizer_1, warmup_train_loader, scaler_1, args.grad_clip
            )
            np_loss_2 = epoch_warmup(
                epoch, model_2, optimizer_2, warmup_train_loader, scaler_2, args.grad_clip
            )
            print("\ttrain loss: {}".format(np_loss_1 + np_loss_2))

        else:
            # 2.2 dividemix training
            prob1, all_loss[0] = epoch_eval_train(model_1, all_loss[0], eval_train_loader)
            prob2, all_loss[1] = epoch_eval_train(model_2, all_loss[1], eval_train_loader)

            pred1 = (prob1 > args.p_threshold)
            pred2 = (prob2 > args.p_threshold)

            # 2.3 train net1 (co-divide)
            labeled_train_loader, unlabeled_train_loader = \
                master_train_loader.run('train', pred2, prob2, max_seqlen=max_seqlen)
            np_loss_x1, np_loss_u1 = epoch_train(
                epoch, model_1, model_2, optimizer_1, labeled_train_loader, unlabeled_train_loader,
                args.T, args.alpha, scaler_1, args.grad_clip, criterion, args.batch_size,
                args.ulabel_coeff, args.grad_noise, args.use_f1,
            )

            # 2.4 train net2 (co-divide)
            labeled_train_loader, unlabeled_train_loader = \
                master_train_loader.run('train', pred1, prob1, max_seqlen=max_seqlen)
            np_loss_x2, np_loss_u2 = epoch_train(
                epoch, model_2, model_1, optimizer_2, labeled_train_loader, unlabeled_train_loader,
                args.T, args.alpha, scaler_2, args.grad_clip, criterion, args.batch_size,
                args.ulabel_coeff, args.grad_noise, args.use_f1,
            )
            print("\ttrain loss_x: {}, loss_u: {}".format(
                np_loss_x1 + np_loss_x2, np_loss_u1 + np_loss_u2)
            )

        # 3. validate
        acc_1, f1_1, challenge_1, acc_2, f1_2, challenge_2 = \
            epoch_test(epoch, model_1, model_2, valid_loader)
        log_f1_1.append(f1_1)
        log_f1_2.append(f1_2)
        log_challenge_1.append(challenge_1)
        log_challenge_2.append(challenge_2)
        print("\tvalid f1_1: {}, challenge_1: {}, f1_2: {}, challenge_2: {}".format(
            f1_1, challenge_1, f1_2, challenge_2
        ), flush=True)

        # 4. save model
        torch.save(model_1.state_dict(), "%s/model_1_epoch_%d.pth" % (args.outdir, epoch))
        torch.save(model_2.state_dict(), "%s/model_2_epoch_%d.pth" % (args.outdir, epoch))

        # 5. update bn statistics for swa model
        if epoch > args.swag_epoch_start:
            swa_model_1.update_parameters(model_1)
            swa_model_2.update_parameters(model_2)

            if warmup_train_loader is None:
                warmup_train_loader = master_train_loader.run('warmup', max_seqlen=max_seqlen)
            update_bn(warmup_train_loader, swa_model_1, 'cuda')
            update_bn(warmup_train_loader, swa_model_2, 'cuda')
            swag_acc_1, swag_f1_1, swag_challenge_1, swag_acc_2, swag_f1_2, swag_challenge_2 \
                = epoch_test(epoch, swa_model_1, swa_model_2, valid_loader, is_swa=True)

            log_f1_swag_1.append(swag_f1_1)
            log_f1_swag_2.append(swag_f1_2)
            log_challenge_swag_1.append(swag_challenge_1)
            log_challenge_swag_2.append(swag_challenge_2)

            torch.save(swa_model_1.module.state_dict(), "%s/swa_model_1_epoch_%d.pth" % (args.outdir, epoch))
            torch.save(swa_model_2.module.state_dict(), "%s/swa_model_2_epoch_%d.pth" % (args.outdir, epoch))
            print("\t (swa valid f1_1: {}, challenge_1: {}, f1_2: {}, challenge_2: {})".format(
                swag_f1_1, swag_challenge_1, swag_f1_2, swag_challenge_2), flush=True
            )
        else:
            log_f1_swag_1.append(no_swa_score)
            log_f1_swag_2.append(no_swa_score)
            log_challenge_swag_1.append(no_swa_score)
            log_challenge_swag_2.append(no_swa_score)

        if scheduler is not None:
            scheduler(epoch)

        end_time = time.time()
        print("\tduration [sec]: {}".format(end_time - start_time), flush=True)

    print("training done.")

    # save best model
    save_log(args.outdir, log_f1_1, log_challenge_1, model_idx=1)
    save_log(args.outdir, log_f1_2, log_challenge_2, model_idx=2)
    # save best swa model
    save_log(args.outdir, log_f1_swag_1, log_challenge_swag_1, model_idx=1, prefix='swa_')
    save_log(args.outdir, log_f1_swag_2, log_challenge_swag_2, model_idx=2, prefix='swa_')

    m1_path = args.outdir + '/model_1_best_challenge.pth'
    m2_path = args.outdir + '/model_2_best_challenge.pth'
    m3_path = args.outdir + '/swa_model_1_best_challenge.pth'
    m4_path = args.outdir + '/swa_model_2_best_challenge.pth'
    m1, train_args = load_trained_model(m1_path)
    m2, train_args = load_trained_model(m2_path)
    m3, train_args = load_trained_model(m3_path)
    m4, train_args = load_trained_model(m4_path)

    final_model = SwaModel(idim, odim, args, m1, m2, m3, m4)
    final_model.train_args = train_args
    final_model.sym2int = sym2int

    # threshold tuning
    final_model.thr = tune_threshold(
        master_valid_loader.run('test'), final_model,
    )

    return {
        "classes": list(model_1.sym2int.keys()),
        "leads": leads,
        "imputer": None,
        "classifier": final_model,
    }

def train(data_directory, model_directory, train_leads_sets):
    """Train and save one model per lead configuration in ``train_leads_sets``.

    Command-line arguments are parsed with the project parser extended by
    ``DivMixNet.add_arguments``; the schedule/data settings below then override
    them.  Lead sets are processed in the order 2, 3, 4, 6, 12, and each
    trained model is stored via ``save_model`` under ``model_directory``.
    """
    parser = DivMixNet.add_arguments(get_train_parser())
    args = parser.parse_args()

    # display PYTHONPATH
    print("python path = " + os.environ.get("PYTHONPATH", "(None)"))

    def set_exp(args, model_directory, n_leads):
        """Point ``args`` at the per-lead output directory (created if needed)."""
        outdir = model_directory + '/train_' + str(n_leads)
        os.makedirs(outdir, exist_ok=True)
        args.outdir = outdir
        args.n_leads = n_leads
        print("************************************************************")
        print("training {}-leads model.......".format(args.n_leads))
        return args

    # task independent configs
    args.train_json = data_directory + "/tr.json"
    args.valid_json = data_directory + "/tt.json"
    args.batch_size = 200  # previously 128; tesla t4: 16GB GDDR6 memory
    args.opt = "adam"  # alternative tried: "adabelief"

    # Active schedule. Earlier experiments used:
    #   warm_up=6, swag_epoch_start=3, epochs=6
    #   warm_up=6, swag_epoch_start=6, epochs=10, edim=1
    args.warm_up = 2
    args.swag_epoch_start = 26
    args.epochs = 34  # previously 40
    args.edim = 1

    args.dict = data_directory + "/sym2int_label"

    lead_configs = (
        (2, two_leads),
        (3, three_leads),
        (4, four_leads),
        (6, six_leads),
        (12, twelve_leads),
    )
    for n_leads, leads in lead_configs:
        if leads not in train_leads_sets:
            continue
        args = set_exp(args, model_directory, n_leads)
        ret = swa_divmix_train(args, leads=leads)
        save_model(
            model_directory, ret["leads"], ret["classes"], ret["imputer"], ret["classifier"]
        )

    
################################################################################
#
# Official Training function
#
################################################################################

# Train your model. This function is *required*. Do *not* change the arguments of this function.
def training_code(data_directory, model_directory):
    """Preprocess the challenge data and train the models for every lead set.

    Args:
        data_directory: directory containing the challenge header/recording
            files; it is passed to ``make_json_files``.
        model_directory: directory the trained models are written to. It is
            created (including parents) if it does not exist.

    Exits the process with status 1 when ``initialize()`` reports that CUDA is
    unavailable.
    """
    print("data directory: {}".format(data_directory))
    print("model directory: {}".format(model_directory))

    # Preprocessed files are written to ./preprocessed (relative to the
    # current working directory) and then consumed by train().
    preprocessed = "preprocessed"
    make_json_files(
        data_directory, preprocessed,
        dt_rate=preproc_config["dt_rate"],
        tt_rate=preproc_config["tt_rate"],
        fs=preproc_config["fs"],
        trim_dur=preproc_config["trim_dur"],
        norm_opt=preproc_config["norm_opt"],
        denoise=preproc_config["denoise"],
        use_tsfresh=preproc_config["use_tsfresh"],
    )

    # Create the model folder (and any missing parents) if needed.
    os.makedirs(model_directory, exist_ok=True)

    # Pin training to the first GPU.
    # NOTE(review): this only takes effect if CUDA has not already been
    # initialised in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    cuda_avail = initialize()
    if not cuda_avail:
        logging.error("cuda is not available.")
        # Exit with a non-zero status so the driver sees the failure; the
        # previous exit(0) reported success.
        sys.exit(1)

    # Train (and save) one model per lead set in lead_sets.
    train(preprocessed, model_directory, lead_sets)
    
    
################################################################################
#
# File I/O functions
#
################################################################################

# Save your trained models.
#def save_model(filename, classes, leads, imputer, classifier):
#    # Construct a data structure for the model and save it.
#    d = {'classes': classes, 'leads': leads, 'imputer': imputer, 'classifier': classifier}
#    joblib.dump(d, filename, protocol=0)
#
# Load your trained 12-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
#def load_twelve_lead_model(model_directory):
#    filename = os.path.join(model_directory, twelve_lead_model_filename)
#    return load_model(filename)
#
# Load your trained 6-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
#def load_six_lead_model(model_directory):
#    filename = os.path.join(model_directory, six_lead_model_filename)
#    return load_model(filename)
#
# Load your trained 3-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
#def load_three_lead_model(model_directory):
#    filename = os.path.join(model_directory, three_lead_model_filename)
#    return load_model(filename)
#
# Load your trained 2-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
#def load_two_lead_model(model_directory):
#    filename = os.path.join(model_directory, two_lead_model_filename)
#    return load_model(filename)
#
# Generic function for loading a model.
#def load_model(filename):
#    return joblib.load(filename)

# Save a trained model. This function is not required. You can change or remove it.
def save_model(model_directory, leads, classes, imputer, classifier):
    """Serialize a trained model bundle into ``model_directory``.

    The bundle is a dict with the keys 'leads', 'classes', 'imputer' and
    'classifier'. It is written with ``joblib.dump`` (pickle protocol 0) to
    the file name returned by ``get_model_filename(leads)``.
    """
    bundle = dict(leads=leads, classes=classes, imputer=imputer, classifier=classifier)
    target_path = os.path.join(model_directory, get_model_filename(leads))
    joblib.dump(bundle, target_path, protocol=0)
    
# Load a trained model. This function is *required*.
# You should edit this function to add your code, but do *not* change the arguments of this function.
def load_model(model_directory, leads):
    """Load the model bundle saved for ``leads`` from ``model_directory``.

    Returns whatever ``joblib.load`` reads from the file named by
    ``get_model_filename(leads)``.
    """
    return joblib.load(os.path.join(model_directory, get_model_filename(leads)))

# Define the filename(s) for the trained models.
# This function is not required. You can change or remove it.
def get_model_filename(leads):
    """Return the model file name for a lead set, e.g. 'model_<l1>-<l2>.sav'.

    The lead names are first put in canonical order by ``sort_leads`` and
    joined with '-'.
    """
    canonical = '-'.join(sort_leads(leads))
    return 'model_{}.sav'.format(canonical)

################################################################################
#
# Running trained model functions
#
################################################################################

# Run your trained 12-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
#def run_twelve_lead_model(model, header, recording):
#    return run_model(model, header, recording)
#
# Run your trained 6-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
#def run_six_lead_model(model, header, recording):
#    return run_model(model, header, recording)
#
# Run your trained 3-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
#def run_three_lead_model(model, header, recording):
#    return run_model(model, header, recording)
#
# Run your trained 2-lead ECG model. This function is *required*. Do *not* change the arguments of this function.
#def run_two_lead_model(model, header, recording):
#    return run_model(model, header, recording)


def pad(recording, seqlen):
    """Zero-pad or crop a multi-lead recording to exactly ``seqlen`` samples.

    Args:
        recording: 2-D array indexed as (leads, samples), or None.
        seqlen: target number of samples along axis 1.

    Returns:
        None if ``recording`` is None. Otherwise a new float32 array of shape
        ``(recording.shape[0], seqlen)``:
        - a shorter or equal-length input is copied to the left and the rest
          is filled with zeros;
        - a longer input is cropped near its middle.
        The result never aliases ``recording``.
    """
    if recording is None:
        return None

    num_leads, length = recording.shape[0], recording.shape[1]
    if length <= seqlen:
        out = np.zeros((num_leads, seqlen), dtype=np.float32)
        out[:, :length] = recording
        return out

    # Deterministic crop for inference. The "- 1" in the offset is kept so
    # outputs match the previous implementation; the window is within one
    # sample of the exact center.
    start = (length - seqlen - 1) // 2
    # astype copies and yields float32, so both branches have the same dtype
    # (the crop branch used to return a view in the input's dtype).
    return recording[:, start:start + seqlen].astype(np.float32)

# Generic function for running a trained model.
def run_model(model, header, recording):
    """Run a trained ensemble on a single recording.

    Args:
        model: dict as saved by ``save_model``. This function reads the keys
            'classes', 'leads' and 'classifier'. The classifier must expose
            ``sym2int``, ``train_args``, ``thr`` and four sub-models
            ``m1``..``m4``, each with ``predict(feat, extra, prob=True)``.
        header: header for the recording (passed to ``ECGData``).
        recording: raw signal for the recording (passed to ``ECGData``).

    Returns:
        ``(classes, labels, probabilities)``. ``labels`` is an int32 0/1
        vector and ``probabilities`` holds the mean of the four sub-model
        outputs. If no features could be extracted, labels are all 0 and
        probabilities are all 0.5 (float32), with one entry per class.
    """
    classes = model['classes']
    leads = model['leads']
    classifier = model['classifier']
    classifier.eval()

    n_leads = len(leads)
    print("n_leads: {}".format(n_leads))

    ecg_data = ECGData(
        header,
        recording,
        classifier.sym2int,
        outdir=None,
        fs=preproc_config["fs"],
        trim_dur=preproc_config["trim_dur"],
        norm_opt=preproc_config["norm_opt"],
        denoise=preproc_config["denoise"],
        use_tsfresh=preproc_config["use_tsfresh"],
        is_train_mode=False,
    )
    _, sample_dic, extra_sample_dic = ecg_data.process(leads=leads)
    feat, extra = ECGData.get_input_features(
        sample_dic, extra_sample_dic, classifier.train_args
    )
    feat = pad(feat, max_seqlen)

    if feat is None:
        # Feature extraction failed: return a neutral prediction. The old
        # code misspelled `probabilities` here and crashed at return.
        labels = np.zeros(len(classes), dtype=np.int32)
        probabilities = np.ones(len(classes), dtype=np.float32) * 0.5
        return classes, labels, probabilities

    # Add a batch dimension: (C, T) -> (1, C, T).
    feat = torch.from_numpy(feat).unsqueeze(0)
    extra = torch.from_numpy(extra).unsqueeze(0)

    if torch.cuda.is_available():
        classifier = classifier.cuda()
        feat, extra = feat.cuda(), extra.cuda()
    print("---shape of feat/extra: {}/{}".format(feat.size(), extra.size()))

    # TODO(hseki): the learned per-class thresholds are currently overridden
    # by a fixed 0.3; only the shape of classifier.thr is used.
    thr = np.array(classifier.thr).reshape(-1) * 0.0 + 0.3

    # Average the four ensemble members (summed left to right, as before).
    # no_grad avoids building autograd graphs during pure inference.
    members = (classifier.m1, classifier.m2, classifier.m3, classifier.m4)
    with torch.no_grad():
        probabilities = sum(
            member.predict(feat, extra, prob=True).cpu().data.numpy()[0]
            for member in members
        ) / 4.0
    labels = (probabilities > thr).astype(np.int32)

    return classes, labels, probabilities

################################################################################
#
# Other functions
#
################################################################################

# Extract features from the header and recording.
def get_features(header, recording, leads):
    """Extract age, sex and the per-lead RMS amplitude for one recording.

    Args:
        header: header for the recording (read via the helper_code getters).
        recording: 2-D array indexed as (lead, sample). Its rows are in the
            order returned by ``get_leads(header)``.
        leads: names of the leads to use, in output order.

    Returns:
        ``(age, sex, rms)``:
        - ``age`` is ``get_age(header)``, or NaN when it is None;
        - ``sex`` is 0 for female, 1 for male, and NaN otherwise;
        - ``rms`` is a float32 array with one value per entry of ``leads``,
          computed after removing the baseline and dividing by the ADC gain.

    Raises:
        ValueError: if a requested lead is not listed in the header.
    """
    # Extract age.
    age = get_age(header)
    if age is None:
        age = float('nan')

    # Extract sex. Encode as 0 for female, 1 for male, and NaN for other.
    sex = get_sex(header)
    if sex in ('Female', 'female', 'F', 'f'):
        sex = 0
    elif sex in ('Male', 'male', 'M', 'm'):
        sex = 1
    else:
        sex = float('nan')

    # Reorder/reselect leads in recordings.
    available_leads = get_leads(header)
    indices = [available_leads.index(lead) for lead in leads]
    # Convert to float: recordings may hold integer ADC counts. The old
    # in-place assignment of the scaled values into such an array truncated
    # them to integers.
    signals = np.asarray(recording)[indices, :].astype(np.float64)

    # Pre-process recordings: (x - baseline) / gain, applied per lead (row).
    adc_gains = np.asarray(get_adcgains(header, leads), dtype=np.float64).reshape(-1, 1)
    baselines = np.asarray(get_baselines(header, leads), dtype=np.float64).reshape(-1, 1)
    signals = (signals - baselines) / adc_gains

    # Root mean square of each lead signal.
    rms = np.sqrt(np.mean(signals ** 2, axis=1)).astype(np.float32)

    return age, sex, rms
