#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed May 12 17:35:05 2021

@author: chadyang
"""
#%%
from helper_code import get_age, get_frequency, get_sex, load_recording, get_adc_gains, get_baselines
import torch
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import json, random

from pathlib import Path
from lib.evaluation_2021.evaluate_model import *
from src.engine.utils import map_sex, replace_equivalent_classes, SCORED_CLASSES, EQ_CLASSES, ALL_TRAIN_SITES
from src.engine.loaders.transforms import *
from tqdm import tqdm
from neurokit2.signal import signal_resample
from torch.nn import functional as F
from torchvision import transforms
# from torchaudio.transforms import Spectrogram
import coloredlogs, logging
coloredlogs.install()
logger = logging.getLogger(__name__)


#%%
class TestDataset(Dataset):
    """Inference-time dataset serving fixed-length patches of a single recording.

    On construction the recording is (optionally) band-passed and resampled to
    the configured sampling rate, passed through the transform pipeline and
    cut into windows of ``max_length`` samples. Every item is a tuple
    ``(x, dummy)`` where ``x`` is a dict with the keys ``"signal"``, ``"age"``
    and ``"sex"`` and ``dummy`` is an all-zero label vector with one entry per
    scored class.
    """

    def __init__(self, config, header, recording, folds=None, transform=None, mode='test'):
        """Store the configuration and pre-compute all patches.

        Args:
            config: Experiment configuration. Reads ``param_aug``,
                ``param_loader``, ``param_feature.raw.sr``,
                ``param_feature.raw.window``, ``exp.random_state`` and
                ``exp.overlap_size``.
            header: Header of the recording, handed to the ``helper_code``
                getters (frequency, age, sex).
            recording: Signal to segment. Presumably a
                ``(num_leads, num_samples)`` array -- TODO confirm with caller.
            folds: Unused; kept so the signature stays compatible.
            transform: Optional callable applied to the recording. When
                omitted, the pipeline from ``_get_transform`` is used.
            mode: Stored on the instance and forwarded to ``_get_transform``.
        """
        super().__init__()

        # Initialization
        self.config = config
        self.header = header
        self.recording = recording
        self.param_aug = config.param_aug
        self.param_loader = self.config.param_loader

        # Window length in samples at the target sampling rate.
        self.max_length = int(config.param_feature.raw.sr * config.param_feature.raw.window)
        if self.config.param_loader.get("spectrogram"):
            # With a spectrogram front-end the window is counted in hops.
            self.max_length = int(np.ceil(self.max_length / self.param_loader.spectrogram.HOP_LEN))

        self.mode = mode
        self.random_state = config.exp.random_state
        self.transform = transform if transform else self._get_transform(mode)

        # NOTE: this seeds the *global* numpy and torch generators.
        np.random.seed(self.random_state)
        torch.manual_seed(self.random_state)

        # Process the signal and split it into patches.
        self._get_patch()

    def __len__(self):
        """Return the number of patches extracted from the recording."""
        return len(self.all_patch)

    def _get_transform(self, mode="train"):
        """Build the default transform pipeline from ``config.param_aug``.

        Args:
            mode: Currently unused; accepted for interface compatibility.

        Returns:
            A ``transforms.Compose`` of the optional filtering and rescaling
            steps followed by ``transforms.ToTensor()``.
        """
        # Input dimension: num_lead * len_signal
        steps = []

        # Filtering on the raw time series (numpy).
        if self.param_aug.get("RemoveBaselineWander"):
            steps.append(RemoveBaselineWander(cutoff=self.param_aug.RemoveBaselineWander.cutoff))
        if self.param_aug.get("BandPass"):
            steps.append(BandPass(sr=self.config.param_feature.raw.sr, cutoff=self.param_aug.BandPass.cutoff))

        # Time-series rescaling. NOTE: no spectrogram transform is built here,
        # even when ``param_loader.spectrogram`` is configured.
        rescale = self.param_aug.get("Rescale")
        if rescale == "zscore":
            steps.append(Zscore())
        elif rescale == "minmax":
            steps.append(MinMaxScaler())

        steps.append(transforms.ToTensor())
        return transforms.Compose(steps)

    @staticmethod
    def _sliding_windows(signal, win_size, overlap):
        """Cut ``signal`` along dimension 1 into full windows of ``win_size``.

        Consecutive windows start ``win_size - overlap`` samples apart.

        Args:
            signal: Tensor whose dimension 1 is the axis to segment.
            win_size: Number of samples per window.
            overlap: Number of samples shared by consecutive windows.

        Returns:
            List of full-length windows; empty when ``signal`` is shorter
            than ``win_size``.

        Raises:
            ValueError: If ``overlap`` is not smaller than ``win_size``.
        """
        stride = win_size - overlap
        if stride <= 0:
            raise ValueError(
                f"overlap ({overlap}) must be smaller than the window size ({win_size})"
            )
        last_start = signal.shape[1] - win_size
        return [signal[:, start:start + win_size] for start in range(0, last_start + 1, stride)]

    def _get_patch(self):
        """Resample, transform and segment the recording into ``self.all_patch``.

        Side effects:
            Replaces ``self.recording`` with the processed float tensor and
            sets ``self.all_patch`` to the stacked patches.
        """
        # =============================================================================
        # naive preprocess
        # =============================================================================
        target_sr = self.config.param_feature.raw.sr
        sr = get_frequency(self.header)
        recording = self.recording
        if sr != target_sr:
            # Bandpass before resampling (prevents high-frequency noise warping).
            recording = filter_signal(recording, sample_rate=sr, cutoff=[3, 45], filtertype="bandpass")
            recording = signal_resample(
                list(recording.T),
                sampling_rate=sr,
                desired_sampling_rate=target_sr,
                method="FFT",
            ).T
            # The transpose above yields a strided view; make it contiguous
            # before handing it to the transforms.
            recording = np.ascontiguousarray(recording)

        # Transforms must see the resampled signal, otherwise ``max_length``
        # (expressed at the target rate) would not match the data.
        if self.transform:
            recording = self.transform(recording)
            if self.config.param_loader.get("spectrogram") is None:
                # Drop the leading dimension added by the default ToTensor step.
                recording = recording.squeeze(0)
        if not torch.is_tensor(recording):
            recording = torch.as_tensor(np.asarray(recording))
        self.recording = recording.float()
        logger.debug("recording shape: %s", tuple(self.recording.shape))

        # Segment the test recording into several patches.
        win_size = self.max_length
        overlap = self.config.exp.overlap_size
        all_patch = self._sliding_windows(self.recording, win_size, overlap)

        # If the recording is too short, return a single patch produced by
        # CenterCrop at the requested window length.
        if not all_patch:
            crop = transforms.CenterCrop(size=[len(self.recording), win_size])
            all_patch = [crop(self.recording)]

        self.all_patch = torch.stack(all_patch)

    def __getitem__(self, idx):
        """Return ``(x, dummy)`` for patch ``idx``.

        ``x`` holds the patch under ``"signal"`` plus the ``"age"`` (float32)
        and mapped ``"sex"`` read from the header; ``dummy`` is a zero vector
        of length ``len(SCORED_CLASSES)``.
        """
        # Demographics come from the header and are identical for every patch.
        age = np.asarray(get_age(self.header)).astype("float32")
        sex = map_sex(get_sex(self.header), impute_nan=1)

        x = {"signal": self.all_patch[idx], "age": age, "sex": sex}

        # =============================================================================
        # label
        # =============================================================================
        # No ground truth at test time: return an all-zero placeholder.
        dummy = np.zeros(len(SCORED_CLASSES), dtype=int)
        return x, dummy



#%% test data loader
if __name__ == "__main__":
    # Manual smoke test: build a TestDataset from the first recording found
    # in the local PTB folder, using the raw-CNN experiment configuration.
    from omegaconf import OmegaConf

    data_dir = "../PhysioNet-CinC-Challenges-2021/datasets/raw/WFDB_PTB"
    all_headers, all_recordings = find_challenge_files(data_dir)

    first_header = load_header(all_headers[0])
    first_recording = load_recording(all_recordings[0])

    config = OmegaConf.load("./config/exp/cnn_raw.yaml")
    testset = TestDataset(config, first_header, first_recording)
