Untitled
unknown
python
10 months ago
1.4 kB
14
Indexable
import operator
from typing import (
    Callable,
    Generic,
    Sequence,
    SupportsIndex,
    TypeAlias,
    TypeVar,
    overload,
)

import numpy as np
import torch
# Element type variables for ``Wrapper``: InputT is the element type of the
# wrapped sequence, OutputT the element type produced by the wrapper function.
# Both are *constrained* TypeVars, so each use binds to exactly one of
# np.ndarray or torch.Tensor, never to their union.
InputT = TypeVar("InputT", np.ndarray, torch.Tensor)
OutputT = TypeVar("OutputT", np.ndarray, torch.Tensor)
def _wrap(elem: np.ndarray | torch.Tensor) -> torch.Tensor:
if isinstance(elem, np.ndarray):
return torch.from_numpy(elem)
return elem
class Wrapper(Sequence[OutputT], Generic[InputT, OutputT]):
    """Read-only sequence view that converts each element on access.

    Nothing is converted at construction time. Every element read calls
    ``wrapper_func`` again (results are not cached).

    Args:
        seq: The underlying sequence. It is stored by reference, not copied,
            so later changes to it are visible through the wrapper.
        wrapper_func: Conversion applied to each element as it is read.
    """

    def __init__(
        self,
        seq: Sequence[InputT],
        wrapper_func: Callable[[InputT], OutputT],
    ) -> None:
        self._seq: Sequence[InputT] = seq
        self._wrapper: Callable[[InputT], OutputT] = wrapper_func

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}"
            f"(seq={self._seq!r}, wrapper_func={self._wrapper!r})"
        )

    def __len__(self) -> int:
        return len(self._seq)

    @overload
    def __getitem__(self, idx: SupportsIndex) -> OutputT: ...

    @overload
    def __getitem__(self, idx: slice) -> Sequence[OutputT]: ...

    def __getitem__(
        self, idx: SupportsIndex | slice
    ) -> OutputT | Sequence[OutputT]:
        """Return the converted element at *idx*.

        For a slice, return a new list of converted elements instead.
        Out-of-range indices fail however the underlying sequence fails.

        Raises:
            TypeError: if *idx* is neither a slice nor integer-like.
        """
        if isinstance(idx, slice):
            # Slices are materialised eagerly into a plain list.
            return [self._wrapper(elem) for elem in self._seq[idx]]
        # operator.index accepts every integer-like index (int, bool,
        # np.int64, ...), not just int, and raises TypeError otherwise.
        # Checking `isinstance(idx, int)` here would misroute np.int64
        # indices into the slice branch above.
        return self._wrapper(self._seq[operator.index(idx)])
def some_function(data: Sequence[torch.Tensor]) -> None:
    """Print a fixed marker message; ``data`` itself is never inspected.

    Appears to be a static type-checking probe: its only real contract is the
    ``Sequence[torch.Tensor]`` parameter annotation.
    """
    banner = "Just to check types"
    print(banner)
def main(data: Sequence[np.ndarray] | Sequence[torch.Tensor], flag: bool):
if flag:
data = Wrapper(data, _wrap)
# some_function(data)
Editor is loading...
Leave a Comment