import torch
from torch.nn.utils.rnn import pad_sequence


def pad_sequences_collate_fn(samples) -> tuple:
    """
    Zero-pad each sample (at the end of its variable-length dimension) so the
    samples can be batched. The longest sequence defines the sequence length
    for the batch.
    """
    # one integer class index per sample -> shape (batch,)
    labels = torch.tensor([v['label_idx'] for v in samples])

    # pad_sequence requires the variable-length dimension to come first, so
    # move the last dimension of 'pixels' to the front before padding
    data = pad_sequence([v['pixels'].permute((2, 0, 1)) for v in samples], batch_first=True)

    # key padding mask: True at real (unpadded) positions, False where padding was added
    key_mask = pad_sequence(
        [
            torch.ones((v['pixels'].shape[-1], v['pixels'].shape[0]), dtype=torch.bool)
            for v in samples
        ],
        padding_value=False,
        batch_first=True,
    )

    # move the padded (previously variable) dimension back to the last position
    return data.permute((0, 2, 3, 1)), labels, key_mask.permute((0, 2, 1))
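

# Minimal usage sketch (an assumption about the dataset, not part of the
# original paste): samples are dicts with a 'pixels' tensor whose last
# dimension is the variable-length one, plus an integer 'label_idx'. The
# dummy names and shapes below are illustrative only.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dummy_samples = [
        {'pixels': torch.rand(4, 4, t), 'label_idx': i % 2}
        for i, t in enumerate((3, 5, 2))
    ]
    loader = DataLoader(dummy_samples, batch_size=3, collate_fn=pad_sequences_collate_fn)
    data, labels, key_mask = next(iter(loader))
    print(data.shape)      # torch.Size([3, 4, 4, 5]) -- padded to the longest sequence
    print(labels.shape)    # torch.Size([3])
    print(key_mask.shape)  # torch.Size([3, 4, 5])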