Untitled
unknown
python
2 months ago
1.4 kB
4
Indexable
from typing import Callable, Generic, Sequence, TypeVar, overload, TypeAlias import numpy as np import torch InputT = TypeVar("InputT", np.ndarray, torch.Tensor) OutputT = TypeVar("OutputT", np.ndarray, torch.Tensor) # InputT: TypeAlias = np.ndarray | torch.Tensor def _wrap(elem: np.ndarray | torch.Tensor) -> torch.Tensor: if isinstance(elem, np.ndarray): return torch.from_numpy(elem) return elem class Wrapper(Generic[InputT, OutputT], Sequence[OutputT]): def __init__( self, seq: Sequence[InputT], wrapper_func: Callable[[InputT], OutputT], ) -> None: self._seq: Sequence[InputT] = seq self._wrapper: Callable[[InputT], OutputT] = wrapper_func def __len__(self) -> int: return len(self._seq) @overload def __getitem__(self, idx: int) -> OutputT: ... @overload def __getitem__(self, idx: slice) -> Sequence[OutputT]: ... def __getitem__(self, idx: int | slice) -> OutputT | Sequence[OutputT]: if isinstance(idx, int): return self._wrapper(self._seq[idx]) else: return [self._wrapper(elem) for elem in self._seq[idx]] def some_function(data: Sequence[torch.Tensor]) -> None: print("Just to check types") def main(data: Sequence[np.ndarray] | Sequence[torch.Tensor], flag: bool): if flag: data = Wrapper(data, _wrap) # some_function(data)
Editor is loading...
Leave a Comment