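# npartition.py
#
# Originally shared as a pastecode.io paste (author: unknown; language: python;
# posted "a year ago"; size: 2.4 kB; "20" as shown on the page, presumably a
# view count; tag: Indexable).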
import typing
import queue
import threading
from collections.abc import Iterable, Callable

# Element type shared by the input iterable and every output iterator.
_T = typing.TypeVar('_T')


def npartition(
    iterable: Iterable[_T], *predicates: Callable[[_T], bool]
) -> tuple[Iterable[_T], ...]:
    """
    Lazily split an iterable into N + 1 iterators using N unary predicates.

    Given an iterable and N unary predicates, returns N + 1 iterables such that
    - for 0 <= i < N, the ith output iterable yields only those elements of the
        input for which the ith predicate, but not any predicate 0 <= j < i,
        returns truthy
    - the Nth output iterable yields only those elements of the input for which
        no predicate returns truthy
    - in the case where N == 0, the sole output will be the input as a forward
        iterator. This will still be inside a tuple, so the user will have to
        index or unpack it to access the iterator.

    Each output yields its elements in the same relative order as the input.
    Advancing one output pulls from the shared source only until it reaches an
    element of its own. Elements belonging to other outputs that are seen on
    the way are buffered for them. Consuming the outputs unevenly can therefore
    hold a large part of the input in memory.

    All outputs share one non-reentrant ``threading.Lock``, held while the
    source is advanced and the predicates are evaluated. The outputs may be
    consumed from different threads, but a predicate must never advance any of
    the returned iterators, or it will deadlock.

    If a predicate raises, the exception propagates out of ``next()``. The
    element being classified at that moment is not delivered to any output.

    :param iterable: Any iterable object. The iterable will be consumed lazily.
    :param predicates: Zero or more unary predicates.
    :return: A tuple of N + 1 iterators (a 1-tuple when N == 0) with the
        behavior described above.
    """
    lock = threading.Lock()
    iterable = iter(iterable)
    if not predicates:
        return iterable,

    # Index of the catch-all output, which receives elements no predicate accepts.
    rest_index = len(predicates)
    ret: 'tuple[wrapped, ...]'

    def classify(x: _T) -> int:
        """Return the index of the output that element ``x`` belongs to."""
        for i, pred in enumerate(predicates):
            if pred(x):
                return i  # first matching predicate wins
        return rest_index

    class wrapped:
        def __init__(self, index: int):
            self._index = index
            # Elements routed here by sibling outputs; only accessed under `lock`.
            self._queue: queue.Queue[_T] = queue.Queue()

        def __next__(self) -> _T:
            with lock:
                # First serve anything a sibling already pulled on our behalf.
                try:
                    return self._queue.get_nowait()
                except queue.Empty:
                    pass
                # Pull from the shared source until an element for this output
                # shows up, buffering every other element for its owner. This
                # loop sits outside the `except` clause so that predicate errors
                # and the final StopIteration are not chained to queue.Empty.
                for x in iterable:
                    target = classify(x)
                    if target == self._index:
                        return x
                    ret[target]._queue.put_nowait(x)
                raise StopIteration

        def __iter__(self):
            return self

    # Assign ret here to expose it to the class above
    ret = tuple(wrapped(i) for i in range(rest_index + 1))
    return ret