Untitled
unknown
python
a year ago
937 B
5
Indexable
def pad_sequences_collate_fn(samples) -> tuple:
    """Collate variable-length samples into a zero-padded batch.

    Each sample's ``'pixels'`` tensor is rank 3, and its *last* dimension is
    the variable-length one. Every sample is zero-padded at the END of that
    dimension (``pad_sequence`` appends padding; it does not prepend it) up
    to the longest sequence in the batch.

    Args:
        samples: sequence of mappings, each holding a ``'pixels'`` tensor of
            shape ``(A, B, T_i)`` and a scalar ``'label_idx'``. The meaning of
            ``A`` and ``B`` is not visible here (presumably pixels and
            channels) — TODO confirm against the dataset.

    Returns:
        ``(data, labels, key_mask)`` where

        * ``data`` has shape ``(batch, A, B, T_max)``, zero-filled past each
          sample's own length ``T_i``;
        * ``labels`` is a 1-D tensor of the ``'label_idx'`` values;
        * ``key_mask`` is a bool tensor of shape ``(batch, A, T_max)``. It is
          ``True`` at real positions and ``False`` at padded positions.
    """
    labels = torch.tensor([v['label_idx'] for v in samples])

    # pad_sequence pads along dim 0, so move the variable (last) dim to the
    # front first: (A, B, T) -> (T, A, B).
    data = pad_sequence(
        [v['pixels'].permute((2, 0, 1)) for v in samples],
        batch_first=True,
    )

    # One (T_i, A) block of True per sample. Padding with False marks exactly
    # the positions pad_sequence filled in above. The mask is created on the
    # sample's device so it lives alongside the data.
    key_mask = pad_sequence(
        [
            torch.ones(
                (v['pixels'].shape[-1], v['pixels'].shape[0]),
                dtype=torch.bool,
                device=v['pixels'].device,
            )
            for v in samples
        ],
        padding_value=False,
        batch_first=True,
    )

    # Move the variable dim back to the last position:
    #   data     (batch, T_max, A, B) -> (batch, A, B, T_max)
    #   key_mask (batch, T_max, A)    -> (batch, A, T_max)
    return data.permute((0, 2, 3, 1)), labels, key_mask.permute((0, 2, 1))
Editor is loading...
Leave a Comment