Untitled
unknown
plain_text
2 years ago
1.5 kB
7
Indexable
# %%
from torchvision.datasets import ImageFolder
from copy import deepcopy
from torchvision.models import resnet101, ResNet101_Weights
import scipy.io as sio
def get_data(data_dir, transform, split_mat, train_on_all):
    """Build train / val / test-seen / test-unseen datasets from one ImageFolder.

    Every image under ``data_dir`` is indexed once with ``ImageFolder``.
    Each split is then a copy of that dataset whose sample lists are
    restricted to the indices stored in ``split_mat``.

    Args:
        data_dir: Root directory handed to ``ImageFolder`` (one sub-directory
            per class).
        transform: Image transform given to ``ImageFolder``. Every split holds
            its own deep copy of it, because each split is a deep copy of
            the dataset.
        split_mat: Mapping, e.g. the dict returned by ``scipy.io.loadmat``.
            It must hold 2-D index arrays under ``"train_loc"``, ``"val_loc"``,
            ``"test_seen_loc"`` and ``"test_unseen_loc"``. Only the first
            column is read, and indices are 1-based (MATLAB convention).
            NOTE(review): this assumes the ``ImageFolder`` ordering (sorted
            class folders, then sorted files) matches the order the indices
            were generated against. Confirm this for the split file in use.
        train_on_all: If True, the training set is a copy of the *whole*
            dataset instead of the ``"train_loc"`` subset.
            NOTE(review): that includes the val and test images too. Confirm
            this is intended rather than a train+val split.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_seen_dataset,
        test_unseen_dataset)``. Each is an independent deep copy of the
        ``ImageFolder`` with ``samples`` / ``imgs`` / ``targets`` restricted
        to its split.

    Raises:
        IndexError: If a split index falls outside ``1..len(dataset)``.
    """

    def create_split(dataset, locations):
        """Return a deep copy of ``dataset`` restricted to 1-based ``locations``."""
        split = deepcopy(dataset)
        n_samples = len(split.samples)
        # Convert each index to a plain int before shifting to 0-based.
        # The index arrays may be unsigned numpy integers, where ``idx - 1``
        # on a bad 0 index could wrap around. Separately, a -1 would
        # otherwise silently pick the last sample through Python's negative
        # indexing.
        positions = [int(idx) - 1 for idx in locations[:, 0]]
        for pos in positions:
            if not 0 <= pos < n_samples:
                raise IndexError(
                    f"split index {pos + 1} out of range 1..{n_samples}"
                )
        split.samples = [split.samples[pos] for pos in positions]
        # Keep ImageFolder's ``imgs`` alias pointing at the same list as ``samples``.
        split.imgs = split.samples
        split.targets = [split.targets[pos] for pos in positions]
        return split

    whole_cub = ImageFolder(data_dir, transform=transform)
    if not train_on_all:
        train_dataset = create_split(whole_cub, split_mat["train_loc"])
    else:
        train_dataset = deepcopy(whole_cub)
    val_dataset = create_split(whole_cub, split_mat["val_loc"])
    test_seen_dataset = create_split(whole_cub, split_mat["test_seen_loc"])
    test_unseen_dataset = create_split(whole_cub, split_mat["test_unseen_loc"])
    return train_dataset, val_dataset, test_seen_dataset, test_unseen_dataset
# %%
# Root directory of the CUB images, laid out one sub-folder per class as
# torchvision's ImageFolder expects.
data_dir = '/home/tin/datasets/cub/CUB/images/'

# Select which ImageNet-pretrained ResNet-101 checkpoint supplies the
# preprocessing pipeline; change the key to switch between V1 and V2.
weights_key = 'IMAGENET1K_V2'
weights = {
    'IMAGENET1K_V1': ResNet101_Weights.IMAGENET1K_V1,
    'IMAGENET1K_V2': ResNet101_Weights.IMAGENET1K_V2,
}[weights_key]
# Backbone construction is currently disabled; only the transforms are used.
# vision_model = resnet101(weights=weights)

# Preprocessing transforms provided by the chosen weights; used for every split.
transforms = weights.transforms()

# Split definition file; get_data reads its "*_loc" index arrays.
# NOTE(review): the path is resolved relative to the current working directory.
split_mat = sio.loadmat('att_splits.mat')

# False: train on the "train_loc" subset only. True: train on every image.
train_on_all = False

(train_dataset, val_dataset,
 test_seen_dataset, test_unseen_dataset) = get_data(
    data_dir=data_dir,
    transform=transforms,
    split_mat=split_mat,
    train_on_all=train_on_all,
)
# %%
# Inspect the training split. In an interactive ``# %%`` cell, the trailing
# bare expression is displayed.
train_dataset