Untitled
unknown
plain_text
2 years ago
1.5 kB
5
Indexable
# %%
from copy import deepcopy

import scipy.io as sio
from torchvision.datasets import ImageFolder
from torchvision.models import resnet101, ResNet101_Weights


def get_data(data_dir, transform, split_mat, train_on_all):
    """Build train / val / test-seen / test-unseen datasets from a split file.

    Args:
        data_dir: Root folder in ``ImageFolder`` layout (one sub-folder per
            class).
        transform: Transform passed to ``ImageFolder``; shared by all splits.
        split_mat: Mapping as returned by ``scipy.io.loadmat``. It must contain
            ``"val_loc"``, ``"test_seen_loc"``, ``"test_unseen_loc"`` and
            (unless ``train_on_all``) ``"train_loc"``. Each holds 1-based
            sample indices into the ``ImageFolder`` sample list.
        train_on_all: If true, the training split is a copy of the whole
            dataset instead of the ``"train_loc"`` subset.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_seen_dataset,
        test_unseen_dataset)``. Each is a deep copy of the full
        ``ImageFolder`` whose ``samples``/``imgs``/``targets`` are restricted
        to that split.

    Raises:
        IndexError: If any split index is outside ``1..len(dataset)``, which
            usually means ``split_mat`` does not match the images under
            ``data_dir``.
    """

    def create_split(dataset, locations):
        """Return a deep copy of *dataset* keeping only the 1-based *locations*."""
        # loadmat yields (N, 1) column arrays; reshape(-1) also accepts 1-D or
        # (1, N) input. int() avoids unsigned-integer wrap-around on `- 1`.
        indices = [int(loc) - 1 for loc in locations.reshape(-1)]

        n_samples = len(dataset.samples)
        out_of_range = [i + 1 for i in indices if not 0 <= i < n_samples]
        if out_of_range:
            # Without this check an index of 0 would silently become -1 and
            # select the last sample.
            raise IndexError(
                f"{len(out_of_range)} split index(es) outside 1..{n_samples} "
                f"(e.g. {out_of_range[:5]}); does split_mat match data_dir?"
            )

        # NOTE(review): the *_loc indices are positions in whatever image list
        # the .mat file was built from, while ImageFolder sorts class folders
        # and files alphabetically. Confirm the two orderings match; otherwise
        # a split silently selects different images.
        split = deepcopy(dataset)
        split.samples = [dataset.samples[i] for i in indices]
        split.imgs = split.samples  # ImageFolder keeps `imgs` as an alias of `samples`
        split.targets = [dataset.targets[i] for i in indices]
        return split

    whole_cub = ImageFolder(data_dir, transform=transform)

    if train_on_all:
        # NOTE(review): this train split contains every val/test sample too;
        # fine for e.g. feature extraction, but a leak if used for training.
        train_dataset = deepcopy(whole_cub)
    else:
        train_dataset = create_split(whole_cub, split_mat["train_loc"])
    val_dataset = create_split(whole_cub, split_mat["val_loc"])
    test_seen_dataset = create_split(whole_cub, split_mat["test_seen_loc"])
    test_unseen_dataset = create_split(whole_cub, split_mat["test_unseen_loc"])

    return train_dataset, val_dataset, test_seen_dataset, test_unseen_dataset


# %%
data_dir = '/home/tin/datasets/cub/CUB/images/'

# Pretrained-weights variant, selected by enum member name ('IMAGENET1K_V1' or
# 'IMAGENET1K_V2'); its preset preprocessing is reused as the dataset transform.
weights_name = 'IMAGENET1K_V2'
weights = ResNet101_Weights[weights_name]
# vision_model = resnet101(weights=weights)
transforms = weights.transforms()

# Loaded relative to the current working directory.
split_mat = sio.loadmat('att_splits.mat')
train_on_all = False
train_dataset, val_dataset, test_seen_dataset, test_unseen_dataset = get_data(
    data_dir, transforms, split_mat, train_on_all
)

# %%
train_dataset
Editor is loading...