#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed May 12 17:35:05 2021

@author: chadyang
"""
#%%
from helper_code import get_age, get_frequency, get_sex, load_recording, get_adc_gains, get_baselines
import torch
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import json, random
import multiprocessing as mp

from pathlib import Path
from lib.evaluation_2021.evaluate_model import *
from src.engine.utils import map_sex, replace_equivalent_classes, SCORED_CLASSES, EQ_CLASSES, ALL_TRAIN_SITES
from src.engine.loaders.transforms import *
from tqdm import tqdm
from neurokit2.signal import signal_resample
from torchvision import transforms
# from torchaudio.transforms import Spectrogram
import coloredlogs, logging
# Module-level side effect: runs on import. `coloredlogs.install()` is called
# with its defaults (presumably attaching a colourised handler to the root
# logger -- confirm against the coloredlogs documentation if the level matters).
coloredlogs.install()
# Logger named after this module, for use by the classes below.
logger = logging.getLogger(__name__)


#%%
class RawDataModule(pl.LightningDataModule):
    """Lightning data module serving cross-validation folds of raw recordings.

    Typical use: construct with the experiment config, call :meth:`setup` to
    read the per-site fold JSON files, call :meth:`setup_kfold` to choose
    which fold indices feed each split, then request the dataloaders.
    """

    def __init__(self, config):
        """Cache the configuration values this module reads.

        Args:
            config: Experiment configuration. The attributes read by this
                class are ``param_aug``, ``param_loader.leads``,
                ``param_loader.num_workers``, ``param_model.batch_size``,
                ``path.data_directory`` and ``exp.train_sites``,
                ``exp.eval_sites``, ``exp.N_fold``, ``exp.N_fold_Order``.
        """
        super().__init__()
        self.config = config
        self.param_aug = config.param_aug
        self.leads = self.config.param_loader.leads
        # The literal string "all" selects every known training site.
        self.train_sites = ALL_TRAIN_SITES if config.exp.train_sites == "all" else config.exp.train_sites
        self.eval_sites = ALL_TRAIN_SITES if config.exp.eval_sites == "all" else config.exp.eval_sites
        self.N_fold = config.exp.N_fold
        self.N_fold_Order = config.exp.N_fold_Order
        # Populated by setup(): {"train"|"eval": [sample ids of fold 0, fold 1, ...]}
        self.fold_data_scored_all, self.fold_data_unscored_all = None, None

    def _fold_json_path(self, site, fold_idx, suffix):
        """Return the path of one fold file for ``site``.

        ``suffix`` is ``"scored"`` or ``"non_scored"``.
        """
        return (
            Path(self.config.path.data_directory)
            / "CVFolds"
            / f"cv-{self.N_fold}_order-{self.N_fold_Order}"
            / site
            / f"fold_{fold_idx}_{suffix}.json"
        )

    def _collect_folds(self, sites, suffix):
        """Read the fold files of every site and merge them fold by fold.

        Returns:
            A list with ``N_fold`` entries; entry ``i`` holds the
            ``"<site>/<sampleId>"`` strings of fold ``i`` across all ``sites``,
            in site order.
        """
        folds = [[] for _ in range(self.N_fold)]
        for site in sites:
            for fold_idx in range(self.N_fold):
                fold_json = self._fold_json_path(site, fold_idx, suffix)
                with open(fold_json, "r", encoding="utf-8") as f:
                    sample_ids = json.load(f)
                folds[fold_idx].extend(f"{site}/{sample_id}" for sample_id in sample_ids)
        return folds

    def setup(self, stage=None):
        """Load the scored and unscored sample ids of every site and fold.

        Args:
            stage: Accepted for compatibility with the
                ``LightningDataModule.setup(stage)`` hook; it is not used, the
                same fold files are read for every stage.

        Raises:
            OSError: If a fold JSON file is missing or unreadable.
        """
        # Build everything first so the attributes are only replaced once
        # every file has been read successfully.
        fold_data_scored_all = {
            "train": self._collect_folds(self.train_sites, "scored"),
            "eval": self._collect_folds(self.eval_sites, "scored"),
        }
        fold_data_unscored_all = {
            "train": self._collect_folds(self.train_sites, "non_scored"),
            "eval": self._collect_folds(self.eval_sites, "non_scored"),
        }

        self.fold_data_scored_all = fold_data_scored_all
        self.fold_data_unscored_all = fold_data_unscored_all

    def setup_kfold(self, folds_train, folds_val, folds_test):
        """Select which fold indices are used for train / val / test."""
        self.folds_train = folds_train
        self.folds_val = folds_val
        self.folds_test = folds_test

    def _get_loader(self, folds, mode):
        """Build a ``DataLoader`` over ``folds`` for the given ``mode``.

        Args:
            folds: Fold indices forwarded to ``RawDataset``.
            mode: ``"train"``, ``"val"`` or ``"test"``; only ``"train"``
                enables shuffling.

        Raises:
            RuntimeError: If :meth:`setup` has not been called yet.
        """
        if self.fold_data_scored_all is None or self.fold_data_unscored_all is None:
            raise RuntimeError("RawDataModule.setup() must be called before requesting a dataloader")

        # Leave one core free; never exceed the configured worker count.
        num_workers = min(mp.cpu_count() - 1, self.config.param_loader.num_workers)
        dataset = RawDataset(
            self.config,
            self.fold_data_scored_all.copy(),
            self.fold_data_unscored_all.copy(),
            folds=folds,
            transform=None,
            mode=mode,
        )
        return DataLoader(
            dataset=dataset,
            batch_size=self.config.param_model.batch_size,
            num_workers=num_workers,
            shuffle=(mode == "train"),
        )

    def train_dataloader(self):
        """Return the shuffled loader over the training folds."""
        return self._get_loader(self.folds_train, "train")

    def val_dataloader(self):
        """Return the loader over the validation folds."""
        return self._get_loader(self.folds_val, "val")

    def test_dataloader(self):
        """Return the loader over the test folds."""
        return self._get_loader(self.folds_test, "test")


class RawDataset(Dataset):
    """Dataset of raw recordings selected by cross-validation fold.

    Each item is a pair ``(x, y)``: ``x`` is a dict with the keys
    ``"signal"``, ``"age"`` and ``"sex"``; ``y`` is a float32 multi-hot vector
    with one entry per element of ``SCORED_CLASSES``.
    """

    def __init__(self, config, fold_data_scored_all, fold_data_unscored_all, folds=None, transform=None, mode='train'):
        """Collect the sample ids of the requested folds and build the transform.

        Args:
            config: Experiment configuration (reads ``param_aug``,
                ``param_loader``, ``param_feature.raw``, ``exp`` and ``path``).
            fold_data_scored_all: ``{"train"|"eval": [ids of fold 0, ...]}``
                with ids of the form ``"<site>/<sampleId>"`` (scored samples).
            fold_data_unscored_all: Same layout for the unscored samples.
            folds: Fold indices to include. With ``None`` the dataset starts
                empty (``fold_data == []``).
            transform: Callable applied to each recording; when falsy the
                pipeline from :meth:`_get_transform` is used.
            mode: ``"train"`` draws from the ``"train"`` fold lists (unscored
                samples only if ``config.exp.N_fold_Use_Unscored``); any other
                value draws scored and unscored samples from ``"eval"``.
        """
        super(RawDataset, self).__init__()

        # Initialization
        self.config = config
        self.param_aug = config.param_aug
        self.folds = folds
        self.N_folds = config.exp.N_fold
        self.num_leads = config.param_loader.num_leads
        self.param_loader = self.config.param_loader
        # Target length in samples: sampling rate times window.
        self.max_length = int(config.param_feature.raw.sr*config.param_feature.raw.window)
        if self.config.param_loader.get("spectrogram"):
            # With a spectrogram config the length is expressed in hops instead.
            self.max_length = int(np.ceil(self.max_length/self.param_loader.spectrogram.HOP_LEN))
        self.mode = mode
        self.random_state = config.exp.random_state
        self.transform = transform if transform else self._get_transform(mode)
        # NOTE: this reseeds the *global* numpy and torch generators.
        np.random.seed(self.random_state)
        torch.manual_seed(self.random_state)

        # prepare fold
        # Always defined so that __len__ works even when no folds are given.
        self.fold_data = []
        if folds is not None:
            for fold in folds:
                if mode == "train":
                    self.fold_data.extend(fold_data_scored_all["train"][fold])
                    if config.exp.N_fold_Use_Unscored:
                        self.fold_data.extend(fold_data_unscored_all["train"][fold])
                else:
                    self.fold_data.extend(fold_data_scored_all["eval"][fold])
                    self.fold_data.extend(fold_data_unscored_all["eval"][fold])

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.fold_data)

    def _get_transform(self, mode="train"):
        """Build the preprocessing / augmentation pipeline from ``param_aug``.

        Filtering and rescaling are applied in every mode; the random
        augmentations and the random crop are only added for ``mode="train"``,
        other modes end with a centre crop.

        Args:
            mode: Split name; only ``"train"`` enables random augmentation.

        Returns:
            A ``transforms.Compose`` expecting a ``num_lead x len_signal`` array.
        """
        # input dimension: num_lead*len_signal
        tfm = transforms.Compose([])

        # time series raw signal in numpy
        if self.param_aug.get("RemoveBaselineWander"):
            tfm.transforms.append(RemoveBaselineWander(cutoff=self.param_aug.RemoveBaselineWander.cutoff))
        if self.param_aug.get("BandPass"):
            tfm.transforms.append(BandPass(sr=self.config.param_feature.raw.sr, cutoff=self.param_aug.BandPass.cutoff))

        # # spectrogram
        # if self.config.param_loader.get("spectrogram"):
        #     tfm.transforms.append(MinMaxScaler())
        #     tfm.transforms.append(AsTensor())
        #     tfm.transforms.append(Spectrogram(n_fft=self.param_loader.spectrogram.NFFT, hop_length=self.param_loader.spectrogram.HOP_LEN))
        #     if mode=="train" and self.param_aug.get("RandomCrop"):
        #         tfm.transforms.append(transforms.RandomCrop(
        #             size=[int(self.param_loader.spectrogram.NFFT/2)+1, self.max_length], padding=0, pad_if_needed=True))
        #     else: tfm.transforms.append(transforms.CenterCrop(size=[int(self.param_loader.spectrogram.NFFT/2)+1, self.max_length]))

        # time series
        # else:
        if self.param_aug.get("Rescale"):
            if self.param_aug.Rescale == "zscore":
                tfm.transforms.append(Zscore())
            elif self.param_aug.Rescale == "minmax":
                tfm.transforms.append(MinMaxScaler())
        tfm.transforms.append(transforms.ToTensor())

        if mode == "train":
            if self.param_aug.get("RandomShuflleLead"):
                tfm.transforms.append(RandomShuflleLead(p=self.param_aug.RandomShuflleLead))
            if self.param_aug.get("RandomLeadMask"):
                tfm.transforms.append(RandomLeadMask(p=self.param_aug.RandomLeadMask))
            if self.param_aug.get("AddGaussianNoise"):
                noise_cfg = self.param_aug.AddGaussianNoise
                # The configured mean used to be passed for both arguments.
                # NOTE(review): assumes the second argument of AddGaussianNoise
                # is the standard deviation and that it is configured under
                # `std` -- confirm. Without a `std` entry the previous value
                # (the mean) is still used.
                noise_std = getattr(noise_cfg, "std", None)
                if noise_std is None:
                    noise_std = noise_cfg.mean
                tfm.transforms.append(AddGaussianNoise(noise_cfg.mean, noise_std))
        if mode == "train" and self.param_aug.get("RandomCrop"):
            tfm.transforms.append(transforms.RandomCrop(size=[self.num_leads, self.max_length], padding=0, pad_if_needed=True))
        else:
            tfm.transforms.append(transforms.CenterCrop(size=[self.num_leads, self.max_length]))

        return tfm

    def __getitem__(self, idx):
        """Load, preprocess and label the sample at position ``idx``.

        Args:
            idx: Index into ``self.fold_data``.

        Returns:
            ``(x, y)`` where ``x = {"signal", "age", "sex"}`` and ``y`` is a
            flat float32 array with a 1 for every scored class in the header.
        """
        sample_info = self.fold_data[idx]
        # Entries look like "<site>/<sampleId>"; only the id is needed because
        # the files are read from the flat official data directory.
        _, sampleId = sample_info.split("/")
        header_path = Path(self.config.path.official_data_directory)/(sampleId+".hea")
        recording_path = Path(self.config.path.official_data_directory)/(sampleId+".mat")
        # TODO load_recording can specify key, not sure different between train / val
        header = load_header(header_path)
        recording = load_recording(recording_path, header=header, leads=self.config.param_loader.leads, key="val")

        # get demo
        age, sex = np.asarray(get_age(header)).astype("float32"), get_sex(header)
        sex = map_sex(sex, impute_nan=1)

        # =============================================================================
        # naive preprocess
        # =============================================================================
        # Bring every recording to the configured sampling rate.
        sr = get_frequency(header)
        if sr != self.config.param_feature.raw.sr:
            recording = filter_signal(recording, sample_rate=sr, cutoff=[3,45], filtertype="bandpass") # bandpass before resample (prevent high-freq noise warping)
            recording = signal_resample(list(recording.T), sampling_rate=sr, desired_sampling_rate=self.config.param_feature.raw.sr, method="FFT").T

        # augmentation
        if self.transform:
            recording = self.transform(recording)
            if self.config.param_loader.get("spectrogram") is None:
                # Drop the leading singleton dimension of the transformed
                # signal (presumably the channel axis added by ToTensor).
                recording = recording.squeeze(0)
        recording = recording.float()

        # collect
        x = {"signal":recording, "age":age, "sex":sex}

        # =============================================================================
        # label
        # =============================================================================
        label = replace_equivalent_classes(get_labels(header), EQ_CLASSES)
        label_bin = pd.DataFrame(np.zeros((1,len(SCORED_CLASSES))), columns=SCORED_CLASSES, index=[sampleId])
        # Classes outside SCORED_CLASSES are ignored.
        labels = [l for l in label if l in SCORED_CLASSES]
        label_bin[labels] = 1
        y = label_bin.values.flatten().astype("float32")
        return x, y



#%% test data loader
# from tqdm import tqdm
# loader = dm.train_dataloader()
# dataset = loader.dataset
# tfm = deepcopy(dataset.transform)
# dataset[2][0]["signal"].shape
# for idx in range(len(loader.dataset)):
#     x,y = dataset[idx]
# for x,y in tqdm(loader):
#     x,y

#%%
# r = recording.copy()
# for tIdx, t in enumerate(tfm.transforms):
#     r = t(r)
#     if len(r.shape)==3:
#         tmp = r.squeeze(0)
#     else:
#         tmp = r
#     plt.figure()
#     plt.plot(tmp[0,:])
#     plt.title(t)
#     print(t, r.shape)


# #%%
# from tqdm import tqdm
# dataset = dm.test_dataloader().dataset
# for dataIdx in tqdm(range(len(dataset))):
#     x, y = dataset[dataIdx]
#     assert(x["signal"].shape==(12,5000))
#     assert(y.shape==(26,))




#%%
# NFFT = 512
# HOP_LEN = 256
# recording = recording.T
# recording = (recording-recording.min(axis=0))/(recording.max(axis=0)-recording.min(axis=0))
# recording = recording.T

#%%
# recording_stft = []
# for lead in recording:
#     spectrogram = librosa.stft(lead, n_fft=NFFT, hop_length=HOP_LEN)
#     recording_stft.append(spectrogram)
# recording_stft = np.asarray(recording_stft)
