import torchvision.transforms as transforms

import src.augmentation.transforms as T
from src.augmentation.methods import RandAugmentation, RandAugmentationWithProb

"""
base_transform은 모든 데이터에 다 적용
그 외의 transform은 supervised training 또는 contrastive learning(SSL)에 적용
ex) example_transform(base_transform(x))
"""


def base_transform(random_crop_size=3000):
    """Build the base pipeline, which consists of a single random crop.

    Args:
        random_crop_size: Passed unchanged to ``T.RandomCrop``
            (presumably the size of the cropped window -- confirm against
            ``src.augmentation.transforms``).

    Returns:
        A ``transforms.Compose`` whose only step is the ``T.RandomCrop``.
    """
    crop_step = T.RandomCrop(random_crop_size)
    return transforms.Compose([crop_step])


def example_transform(n_select=2):
    """Build an example pipeline around a single ``RandAugmentation`` step.

    Args:
        n_select: Forwarded to ``RandAugmentation`` as ``n_select``
            (presumably how many candidates are picked per call -- confirm
            against ``src.augmentation.methods``).

    Returns:
        A ``transforms.Compose`` whose only step is the configured
        ``RandAugmentation``.
    """
    # Candidate names are expected to exist in src.augmentation.transforms.
    # Entries are either (name, value) tuples or bare name strings.
    candidates = [
        ("Jittering", 0.02),
        ("Scaling", 0.2),
        # NOTE(review): spelling kept verbatim; confirm it matches the
        # transform's actual name before changing it.
        "HorzontalFlip",
        "Cutout",
        "Baseline_Shift",
        "Gaussian_Blur",
        "Baseline_Wander",
    ]
    rand_aug = RandAugmentation(transforms_info=candidates, n_select=n_select)
    return transforms.Compose([rand_aug])

def band_pass(n_select=2):
    """Build a pipeline with a single ``T.ButterFilter(btype='band')`` step.

    Args:
        n_select: Not used by this function; presumably kept so the
            signature lines up with the other pipeline builders.

    Returns:
        A ``transforms.Compose`` whose only step is the ``T.ButterFilter``.
    """
    band_filter = T.ButterFilter(btype='band')
    return transforms.Compose([band_filter])

def sup_aug_comb1(n_select=2, trans_prob=0.2):
    """Build a pipeline around a single ``RandAugmentationWithProb`` step.

    The candidate list holds the two names ``"Cutout"`` and
    ``"Gaussian_Blur"``.

    Args:
        n_select: Forwarded to ``RandAugmentationWithProb`` as ``n_select``.
        trans_prob: Forwarded to ``RandAugmentationWithProb`` as
            ``trans_prob`` (presumably the probability of applying the
            augmentation -- confirm against ``src.augmentation.methods``).

    Returns:
        A ``transforms.Compose`` whose only step is the configured
        ``RandAugmentationWithProb``.
    """
    # Candidate names are expected to exist in src.augmentation.transforms;
    # both entries are bare name strings.
    candidates = ["Cutout", "Gaussian_Blur"]
    prob_aug = RandAugmentationWithProb(
        transforms_info=candidates,
        n_select=n_select,
        trans_prob=trans_prob,
    )
    return transforms.Compose([prob_aug])
