# Untitled
#
# mail@pastecode.io avatar
# unknown
# plain_text
# a year ago
# 1.5 kB
# 1
# Indexable
# Never
# %%
from torchvision.datasets import ImageFolder
from copy import deepcopy
from torchvision.models import resnet101, ResNet101_Weights
import scipy.io as sio

def get_data(data_dir, transform, split_mat, train_on_all):
    """Load an image-folder dataset and carve it into train/val/test subsets.

    Args:
        data_dir: Root directory handed to ``ImageFolder``.
        transform: Transform handed to ``ImageFolder``. Each returned split
            is a deep copy of the full dataset, so it carries its own copy.
        split_mat: Mapping with the keys ``"train_loc"``, ``"val_loc"``,
            ``"test_seen_loc"`` and ``"test_unseen_loc"``. Each value must
            support ``value[:, 0]`` and yield 1-based positions into the
            dataset's ``samples`` list.
        train_on_all: When true, the training split is a copy of the whole
            dataset and ``"train_loc"`` is not read.

    Returns:
        Tuple ``(train, val, test_seen, test_unseen)`` of dataset objects
        that do not share state with each other.

    Raises:
        IndexError: If any position is outside ``1..len(samples)``.
    """
    # NOTE(review): this assumes the order in which ImageFolder enumerates
    # samples matches the order the positions in split_mat were generated
    # against -- confirm against the split file.

    def create_split(dataset, locations):
        """Return a copy of *dataset* restricted to the 1-based *locations*."""
        total = len(dataset.samples)
        # Convert once to plain 0-based Python ints so the subtraction is not
        # done on a fixed-width NumPy scalar and the column is walked once.
        indices = [int(loc) - 1 for loc in locations[:, 0]]

        # Without this check a position of 0 would become index -1 and
        # silently pick the last sample through negative indexing.
        invalid = [idx + 1 for idx in indices if not 0 <= idx < total]
        if invalid:
            raise IndexError(
                f"split position {invalid[0]} is outside the valid range "
                f"1..{total}"
            )

        split = deepcopy(dataset)
        split.samples = [split.samples[idx] for idx in indices]
        # ``imgs`` is kept pointing at the same list as ``samples``.
        split.imgs = split.samples
        split.targets = [split.targets[idx] for idx in indices]
        return split

    whole_cub = ImageFolder(data_dir, transform=transform)
    if train_on_all:
        train_dataset = deepcopy(whole_cub)
    else:
        train_dataset = create_split(whole_cub, split_mat["train_loc"])
    val_dataset = create_split(whole_cub, split_mat["val_loc"])
    test_seen_dataset = create_split(whole_cub, split_mat["test_seen_loc"])
    test_unseen_dataset = create_split(whole_cub, split_mat["test_unseen_loc"])
    return train_dataset, val_dataset, test_seen_dataset, test_unseen_dataset
# %%
# Root directory passed to ImageFolder via get_data.
data_dir = '/home/tin/datasets/cub/CUB/images/'

# Pretrained-weight variant, selected by name; change the trailing key to
# switch variants.
weights = dict(
    IMAGENET1K_V1=ResNet101_Weights.IMAGENET1K_V1,
    IMAGENET1K_V2=ResNet101_Weights.IMAGENET1K_V2,
)['IMAGENET1K_V2']
# vision_model = resnet101(weights=weights)

# Preprocessing transform obtained from the selected weights.
transforms = weights.transforms()

# Split positions are read from a .mat file resolved against the current
# working directory.
split_mat = sio.loadmat('att_splits.mat')

# False: the training split uses only the "train_loc" positions.
train_on_all = False
(
    train_dataset,
    val_dataset,
    test_seen_dataset,
    test_unseen_dataset,
) = get_data(data_dir, transforms, split_mat, train_on_all)

# %%
train_dataset