import pickle

import numpy as np
import torch
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from helper_code import *
from src.data_preparation.dataset_class import (Dataset_ECG,
                                                Dataset_ECG_SimCLR)
from src.setup import LEAD_DICT
from src.utils.preprocess import preprocess_recording
from src.utils.team_helper_code import get_features, get_id, get_recording_mean_per_lead, get_sample_id_prefix


def _build_transform(policy_name, policy_params):
    """Instantiate the augmentation policy named `policy_name`.

    The policy is looked up by name in `src.augmentation.policies` and
    called with `policy_params` as keyword arguments.
    """
    policies = __import__("src.augmentation.policies", fromlist=[""])
    return getattr(policies, policy_name)(**policy_params)


def get_data_info_dict(lead_str, data_directory, configs):
    """
    Create data_info_dict.
        - if configs["DEVELOP"] is truthy, load it from a pickle file under
          configs["PICKLE_PATH"]
        - otherwise build it from the header / recording files found in
          data_directory

    When built from files, samples whose sample-id prefix is "I" or "S" are
    skipped; every other sample is kept regardless of its labels. The class
    list, however, is collected from ALL header files (skipped ones included).

    params:
    lead_str: key into LEAD_DICT selecting the lead set; also part of the
        pickle file name in develop mode
    data_directory: directory searched for header / recording files
        (unused in develop mode)
    configs: dict read for "DEVELOP", "PICKLE_PATH", "PREPROCESS_PARAMS",
        "TRANS_BASE(_PARAMS)", "TRANS_TRAIN(_PARAMS)", "TRANS_CONTRAST(_PARAMS)"

    return:
    dict: {
        "headers": list of header per sample,
        "sample_ids": list of sample id per sample,
        "features": list of feature tensor per sample,
        "recordings": list of (preprocessed) recording tensor per sample,
        "labels": list of multi-hot label tensor per sample
                  ([[0, 1, 0, 1, 0, ...], [0, 0, 0, 0, 1, ...], ...]),
        "mean_per_lead": list of per-lead mean of the raw recording per sample,
        "classes": sorted list of all classes,
        "trans_base": base transform from configs,
        "trans_train": train transform from configs (base + train), or just
                       the base transform when configs["TRANS_TRAIN"] is None,
        "trans_contrast": contrast transform from configs (base + contrast)
    }
    (In develop mode the data keys come from the pickle file; presumably it
    has the same layout -- verify against the script that wrote it.)
    """
    if not configs["DEVELOP"]:
        print("Finding header and recording files...")
        header_files, recording_files = find_challenge_files(
            data_directory
        )

        # First pass over the headers: collect the full label vocabulary.
        all_classes = set()
        for header_file in header_files:
            all_classes.update(get_labels(load_header(header_file)))
        all_classes.discard('')  # remove empty class
        if all(is_integer(x) for x in all_classes):
            # Sort classes numerically if numbers.
            classes = sorted(all_classes, key=int)
        else:
            classes = sorted(all_classes)
        # O(1) class -> column lookup instead of list.index() per label.
        class_to_index = {name: idx for idx, name in enumerate(classes)}

        leads = getattr(__import__("team_code", fromlist=[""]), LEAD_DICT[lead_str])
        data_info_dict = {
            "headers": [],
            "sample_ids": [],
            "features": [],
            "recordings": [],
            "labels": [],
            "mean_per_lead": []
        }
        # Second pass: load, preprocess and label every retained sample.
        for header_file, recording_file in tqdm(
            zip(header_files, recording_files),
            total=len(recording_files),
            desc="Extracting Data",
        ):
            header = load_header(header_file)
            sample_id = get_id(header)

            sample_id_prefix = get_sample_id_prefix(sample_id)
            if sample_id_prefix in ["I", "S"]:  # Skip StPetersburg(I), PTB(S)
                continue

            recording = load_recording(recording_file)
            # The mean is taken on the raw recording, before preprocessing.
            recording_mean_per_lead = get_recording_mean_per_lead(recording)
            recording = preprocess_recording(header, recording, leads, configs["PREPROCESS_PARAMS"])
            feature = get_features(header, recording, leads, configs["PREPROCESS_PARAMS"])

            label = [0] * len(classes)
            for class_name in get_labels(header):
                if class_name != '':  # remove empty class
                    label[class_to_index[class_name]] = 1

            data_info_dict["headers"].append(header)
            data_info_dict["sample_ids"].append(sample_id)
            data_info_dict["features"].append(torch.Tensor(feature))
            data_info_dict["recordings"].append(torch.Tensor(recording))
            data_info_dict["labels"].append(torch.Tensor(label))
            data_info_dict["mean_per_lead"].append(recording_mean_per_lead)

        data_info_dict["classes"] = classes

    else:
        # NOTE(security): pickle.load executes arbitrary code embedded in the
        # file; only point PICKLE_PATH at files produced by this project.
        with open(configs["PICKLE_PATH"] + f"/{lead_str}_simple_preprocessed_train_data_2021.pkl", "rb") as f:
            data_info_dict = pickle.load(f)

    ### Transformations here!
    base_transform = _build_transform(
        configs["TRANS_BASE"], configs["TRANS_BASE_PARAMS"]
    )
    data_info_dict["trans_base"] = base_transform

    if configs["TRANS_TRAIN"] is not None:
        data_info_dict["trans_train"] = transforms.Compose([
            base_transform,
            _build_transform(configs["TRANS_TRAIN"], configs["TRANS_TRAIN_PARAMS"]),
        ])
    else:
        data_info_dict["trans_train"] = base_transform

    data_info_dict["trans_contrast"] = transforms.Compose([
        base_transform,
        _build_transform(configs["TRANS_CONTRAST"], configs["TRANS_CONTRAST_PARAMS"]),
    ])

    return data_info_dict

# Process-wide memo of (dataset_size, valid_ratio) -> (train_indices, valid_indices).
# Kept at module level (instead of a mutable default argument) so the caching
# is explicit and visible.
_SPLIT_INDICES_CACHE = {}


def return_indices(dataset_size, valid_ratio, indices_dict=None):
    """
    Return a random train/valid split of range(dataset_size).

    The first request for a given (dataset_size, valid_ratio) shuffles the
    indices with numpy's global RNG and stores the result; every later request
    with the same pair returns that stored split, so all callers within one
    process see the same partition. The split is only reproducible across
    processes if the numpy global RNG is seeded beforehand.

    params:
    dataset_size: number of samples to split
    valid_ratio: fraction of samples assigned to validation;
        floor(valid_ratio * dataset_size) indices go to validation
    indices_dict: optional dict used as the cache; when omitted (None) a
        shared module-level cache is used

    return:
    (train_indices, valid_indices): two lists of ints that partition
    range(dataset_size). The lists are the cached objects themselves, so
    callers should not mutate them.
    """
    cache = _SPLIT_INDICES_CACHE if indices_dict is None else indices_dict
    key = (dataset_size, valid_ratio)

    if key in cache:
        return cache[key]

    indices = list(range(dataset_size))
    split = int(np.floor(valid_ratio * dataset_size))
    np.random.shuffle(indices)
    train_indices, valid_indices = indices[split:], indices[:split]

    cache[key] = (train_indices, valid_indices)

    return train_indices, valid_indices

def get_supervised_dataset(data_info_dict, classes, valid_ratio):
    """
    Build the supervised train and valid datasets.

    A Dataset_ECG is first created over data_info_dict for the given classes
    (presumably this is where samples without a targeted label are dropped --
    confirm in Dataset_ECG). Its per-sample attributes are then split into
    train / valid parts in a ratio of (1 - valid_ratio) : valid_ratio, using
    the indices from return_indices.

    return:
    (train_dataset, valid_dataset): Dataset_ECG instances using the
    "trans_train" and "trans_base" transforms respectively
    """
    def _identity(sample):
        return sample

    full_dataset = Dataset_ECG(data_info_dict, classes, _identity)

    n_samples = len(full_dataset.recordings)
    train_idx, valid_idx = return_indices(n_samples, valid_ratio)

    train_info, valid_info = {}, {}
    for key in data_info_dict:
        # Only attributes that exist on the dataset and hold exactly one
        # entry per sample are split; everything else is left out.
        if not hasattr(full_dataset, key):
            continue
        per_sample_values = getattr(full_dataset, key)
        if len(per_sample_values) != n_samples:
            continue
        train_info[key] = [per_sample_values[i] for i in train_idx]
        valid_info[key] = [per_sample_values[i] for i in valid_idx]

    train_info["classes"] = classes
    valid_info["classes"] = classes

    return (
        Dataset_ECG(train_info, classes, data_info_dict["trans_train"]),
        Dataset_ECG(valid_info, classes, data_info_dict["trans_base"]),
    )


def get_simclr_dataset(data_info_dict):
    """
    Build the SimCLR dataset from data_info_dict, using its "trans_contrast"
    transform.
    """
    contrast_transform = data_info_dict["trans_contrast"]
    return Dataset_ECG_SimCLR(data_info_dict, contrast_transform)


def get_dataloader(dataset, batch_size, num_workers=4):
    """
    Wrap dataset in a shuffling DataLoader with pinned memory.

    params:
    dataset: dataset to load from
    batch_size: number of samples per batch
    num_workers: number of loader worker processes (default 4)
    """
    loader_options = {
        "pin_memory": True,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "shuffle": True,
    }
    return torch.utils.data.DataLoader(dataset, **loader_options)


def get_ssl_dataloader(dataset, batch_size, num_workers=4):
    """
    Wrap dataset in a shuffling DataLoader with pinned memory that drops the
    last incomplete batch.

    params:
    dataset: dataset to load from
    batch_size: number of samples per batch
    num_workers: number of loader worker processes (default 4)
    """
    loader_options = {
        "pin_memory": True,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "shuffle": True,
        "drop_last": True,
    }
    return torch.utils.data.DataLoader(dataset, **loader_options)
