Untitled

 avatar
unknown
python
2 months ago
1.4 kB
4
Indexable
import operator
from typing import Callable, Generic, Sequence, TypeAlias, TypeVar, overload

import numpy as np
import torch

# Constrained (not bound) type variables: each use of ``InputT`` / ``OutputT``
# must resolve to exactly one of ``np.ndarray`` or ``torch.Tensor`` -- never to
# their union. ``InputT`` is the element type of the wrapped sequence,
# ``OutputT`` the element type produced by the wrapper function.
InputT = TypeVar("InputT", np.ndarray, torch.Tensor)
OutputT = TypeVar("OutputT", np.ndarray, torch.Tensor)


def _wrap(elem: np.ndarray | torch.Tensor) -> torch.Tensor:
    if isinstance(elem, np.ndarray):
        return torch.from_numpy(elem)
    return elem


class Wrapper(Generic[InputT, OutputT], Sequence[OutputT]):
    """Read-only sequence view that applies a function to elements on access.

    Nothing is converted at construction time: ``wrapper_func`` is called each
    time an element is read, and results are not cached. Iteration, ``in``,
    ``index``, ``count`` and ``reversed`` come from the ``Sequence`` mixin
    methods, which go through ``__len__`` / ``__getitem__``.

    Args:
        seq: The underlying sequence. It is stored by reference (not copied),
            so later changes to it are visible through the wrapper.
        wrapper_func: Function applied to each element when it is accessed.
    """

    def __init__(
        self,
        seq: Sequence[InputT],
        wrapper_func: Callable[[InputT], OutputT],
    ) -> None:
        self._seq: Sequence[InputT] = seq
        self._wrapper: Callable[[InputT], OutputT] = wrapper_func

    def __len__(self) -> int:
        """Return the length of the underlying sequence."""
        return len(self._seq)

    @overload
    def __getitem__(self, idx: int) -> OutputT: ...

    @overload
    def __getitem__(self, idx: slice) -> Sequence[OutputT]: ...

    def __getitem__(self, idx: int | slice) -> OutputT | Sequence[OutputT]:
        """Return the wrapped element at *idx*, or a list for a slice.

        Integer-like indices (``int``, ``np.int64``, 0-d integer tensors, ...)
        return a single wrapped element. Slices return a new ``list`` of
        wrapped elements, materialized eagerly.

        Raises:
            Whatever the underlying sequence raises for an invalid index
            (e.g. ``IndexError`` for out-of-range positions).
        """
        if not isinstance(idx, slice):
            # Use the ``__index__`` protocol instead of ``isinstance(idx, int)``:
            # NumPy/torch integer scalars are not ``int`` instances and would
            # otherwise be treated as multi-element selections.
            try:
                position = operator.index(idx)
            except TypeError:
                # Not integer-like (e.g. a fancy index on an array-backed
                # sequence): fall through and wrap each selected element,
                # as earlier versions of this method did.
                pass
            else:
                return self._wrapper(self._seq[position])
        return [self._wrapper(elem) for elem in self._seq[idx]]


def some_function(data: Sequence[torch.Tensor]) -> None:
    """Accept a sequence of tensors and print a fixed marker line.

    *data* is never inspected. Judging by the printed message, the function
    presumably exists only as a target for static type checking of its
    argument.
    """
    marker = "Just to check types"
    print(marker)


def main(data: Sequence[np.ndarray] | Sequence[torch.Tensor], flag: bool):
    if flag:
        data = Wrapper(data, _wrap)

    # some_function(data)
Editor is loading...
Leave a Comment