ClassificationTransformsTrain
unknown
python
2 years ago
1.5 kB
5
Indexable
class ClassificationTransformsTrain:
    """Training-time augmentation pipeline for image classification.

    The composed pipeline applies, in order:

    1. ``T.ToPILImage()`` -- converts the input to a PIL image.
       NOTE(review): ToPILImage accepts tensors / ndarrays, so this assumes
       callers do not pass an already-decoded PIL image -- TODO confirm.
    2. ``ResizeRandomCropCollageTrain`` -- project-specific resize/crop step,
       configured with ``resize`` (as ``resize_size``), ``crop_size`` and
       ``aspect_ratio_th``; see its definition for the exact behavior.
    3. ``T.RandomChoice`` -- applies one transform picked at random from the
       augmentation pool. Several default pool entries are themselves
       probabilistic (e.g. ``RandomPosterize``, ``RandomSolarize``), so the
       picked transform may leave the image unchanged. Skipped entirely when
       the pool is empty.
    4. ``T.RandomHorizontalFlip(hflip_prob)`` -- only when ``hflip_prob > 0``.
    5. ``T.PILToTensor`` -> ``T.ConvertImageDtype(torch.float)`` ->
       ``T.Normalize(mean, std)`` -- produces a normalized float tensor.
    6. ``T.RandomErasing(p=random_erase_prob)`` -- only when
       ``random_erase_prob > 0``; it runs after normalization because it
       operates on tensors.

    All arguments are keyword-only.

    Args:
        crop_size: Crop size forwarded to ``ResizeRandomCropCollageTrain``.
        mean: Per-channel normalization mean (default: the commonly used
            ImageNet RGB statistics).
        std: Per-channel normalization std (default: ImageNet RGB statistics).
        hflip_prob: Horizontal-flip probability; ``<= 0`` disables the flip.
        random_erase_prob: Random-erasing probability; ``<= 0`` (the default)
            disables random erasing.
        resize: Resize size forwarded to ``ResizeRandomCropCollageTrain`` as
            ``resize_size``.
        aspect_ratio_th: Aspect-ratio threshold forwarded to
            ``ResizeRandomCropCollageTrain``.
        random_transforms: Optional sequence of transforms from which
            ``T.RandomChoice`` picks one per call. ``None`` (the default) uses
            the built-in pool (see ``_default_random_transforms``). An empty
            sequence disables this step.
    """

    def __init__(
        self,
        *,
        crop_size,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        hflip_prob=0.5,
        random_erase_prob=0.0,
        resize,
        aspect_ratio_th=4,
        random_transforms=None,
    ):
        # Build the pool first, matching the original construction order.
        if random_transforms is None:
            random_transforms = self._default_random_transforms()
        else:
            # RandomChoice expects a Sequence; accept any iterable from callers.
            random_transforms = list(random_transforms)

        trans = [
            T.ToPILImage(),
            ResizeRandomCropCollageTrain(
                resize_size=resize,
                crop_size=crop_size,
                aspect_ratio_th=aspect_ratio_th,
            ),
        ]
        # RandomChoice cannot pick from an empty list, so skip the step instead.
        if random_transforms:
            trans.append(T.RandomChoice(transforms=random_transforms))
        if hflip_prob > 0:
            trans.append(T.RandomHorizontalFlip(hflip_prob))
        trans.extend(
            [
                T.PILToTensor(),
                T.ConvertImageDtype(torch.float),
                T.Normalize(mean=mean, std=std),
            ]
        )
        # RandomErasing works on tensors, hence it comes after normalization.
        if random_erase_prob > 0:
            trans.append(T.RandomErasing(p=random_erase_prob))
        self.transforms = T.Compose(trans)

    @staticmethod
    def _default_random_transforms():
        """Return a fresh list with the built-in augmentation pool."""
        return [
            T.AutoAugment(T.AutoAugmentPolicy.IMAGENET),
            T.ColorJitter(brightness=0.5, hue=0.3),
            T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
            T.RandomPosterize(bits=2),
            T.RandomSolarize(threshold=192.0),
            T.RandomAdjustSharpness(sharpness_factor=2),
            T.RandomAutocontrast(),
            T.RandomEqualize(),
            T.AugMix(),
            T.RandomInvert(),
        ]

    def __call__(self, img):
        """Run the composed pipeline on ``img``.

        Returns:
            A normalized float tensor. Random erasing has also been applied
            to it when ``random_erase_prob > 0``.
        """
        return self.transforms(img)
Editor is loading...