ClassificationTransformsTrain

mail@pastecode.io avatar
unknown
python
a year ago
1.5 kB
2
Indexable
Never
class ClassificationTransformsTrain:
    """Callable training-time augmentation pipeline for classification images.

    The constructor assembles a single ``T.Compose`` and stores it on
    ``self.transforms``; calling the instance applies it to one image.
    ``T`` is presumably ``torchvision.transforms`` -- confirm against the
    file's imports.

    Stages, in the order they run:

    1. ``T.ToPILImage()``
    2. ``ResizeRandomCropCollageTrain`` (defined elsewhere; receives
       ``resize``, ``crop_size`` and ``aspect_ratio_th``)
    3. ``T.RandomChoice`` over a fixed pool of photometric augmentations
    4. ``T.RandomHorizontalFlip`` -- only if ``hflip_prob > 0``
    5. ``T.PILToTensor()``, ``T.ConvertImageDtype(torch.float)``,
       ``T.Normalize(mean, std)``
    6. ``T.RandomErasing`` -- only if ``random_erase_prob > 0``

    All constructor arguments are keyword-only.

    Args:
        crop_size: Forwarded as ``crop_size`` to
            ``ResizeRandomCropCollageTrain``.
        mean: Per-channel mean handed to ``T.Normalize``.
        std: Per-channel standard deviation handed to ``T.Normalize``.
        hflip_prob: Probability passed to ``T.RandomHorizontalFlip``; the
            flip stage is left out entirely when this is not positive.
        random_erase_prob: Probability passed as ``p`` to
            ``T.RandomErasing``; the stage is left out when not positive.
        resize: Forwarded as ``resize_size`` to
            ``ResizeRandomCropCollageTrain``.
        aspect_ratio_th: Forwarded as ``aspect_ratio_th`` to
            ``ResizeRandomCropCollageTrain``.
    """

    def __init__(
            self,
            *,
            crop_size,
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
            hflip_prob=0.5,
            random_erase_prob,
            resize,
            aspect_ratio_th=4
    ):
        # Candidate augmentations handed to T.RandomChoice below.
        augmentation_pool = [
            T.AutoAugment(T.AutoAugmentPolicy.IMAGENET),
            T.ColorJitter(brightness=.5, hue=.3),
            T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
            T.RandomPosterize(bits=2),
            T.RandomSolarize(threshold=192.0),
            T.RandomAdjustSharpness(sharpness_factor=2),
            T.RandomAutocontrast(),
            T.RandomEqualize(),
            T.AugMix(),
            T.RandomInvert(),
        ]

        resize_and_crop = ResizeRandomCropCollageTrain(
            resize_size=resize,
            crop_size=crop_size,
            aspect_ratio_th=aspect_ratio_th,
        )

        # Optional stages are modelled as zero-or-one element lists so the
        # whole pipeline can be written out once, in execution order.
        maybe_flip = (
            [T.RandomHorizontalFlip(hflip_prob)] if hflip_prob > 0 else []
        )
        maybe_erase = (
            [T.RandomErasing(p=random_erase_prob)]
            if random_erase_prob > 0
            else []
        )

        self.transforms = T.Compose([
            T.ToPILImage(),
            resize_and_crop,
            T.RandomChoice(transforms=augmentation_pool),
            *maybe_flip,
            # From here on the image is a tensor rather than a PIL image.
            T.PILToTensor(),
            T.ConvertImageDtype(torch.float),
            T.Normalize(mean=mean, std=std),
            # Erasing is placed after normalisation, on the float tensor.
            *maybe_erase,
        ])

    def __call__(self, img):
        """Run ``img`` through the composed pipeline and return the result."""
        pipeline = self.transforms
        return pipeline(img)