Untitled
unknown
python
a year ago
937 B
9
Indexable
def pad_sequences_collate_fn(samples) -> tuple:
    """
    Collate variable-length samples into a single zero-padded batch.

    Each sample is a mapping holding a 3-D ``'pixels'`` tensor whose LAST axis
    is the variable-length one, and a ``'label_idx'`` value. The longest
    sample defines the length of that axis for the whole batch; shorter
    samples are zero-padded at the END of the axis (the default behaviour of
    ``pad_sequence``). The two leading axes must match across samples.

    Args:
        samples: sequence of mappings with keys ``'pixels'`` and
            ``'label_idx'``.

    Returns:
        A ``(data, labels, key_mask)`` tuple:

        * ``data`` -- ``(batch, d0, d1, max_len)`` padded pixels, where
          ``(d0, d1, len)`` is the shape of one sample's ``'pixels'``.
        * ``labels`` -- ``(batch,)`` tensor built from ``'label_idx'``.
        * ``key_mask`` -- ``(batch, d0, max_len)`` bool tensor, ``True`` at
          real positions and ``False`` at padded ones.
          NOTE(review): this is the inverse of the ``key_padding_mask``
          convention of ``torch.nn.MultiheadAttention`` -- confirm how the
          consumer interprets it.
    """
    pixels = [sample['pixels'] for sample in samples]
    labels = torch.tensor([sample['label_idx'] for sample in samples])

    # pad_sequence pads along the first axis, so the variable axis is moved to
    # the front for padding and moved back to the end afterwards.
    data = pad_sequence(
        [p.permute((2, 0, 1)) for p in pixels],
        batch_first=True,
    )

    # One all-True (len, d0) block per sample; padding it with False marks
    # exactly the positions that pad_sequence filled in for `data`.
    key_mask = pad_sequence(
        [
            torch.ones((p.shape[-1], p.shape[0]), dtype=torch.bool)
            for p in pixels
        ],
        padding_value=False,
        batch_first=True,
    )

    return data.permute((0, 2, 3, 1)), labels, key_mask.permute((0, 2, 1))
Leave a Comment