Untitled

 avatar
unknown
plain_text
5 months ago
5.2 kB
4
Indexable
from attrs import define, field
from collections import defaultdict, deque
from itertools import product
from typing import Union

from utils import *


# adj[u] holds the node ids reachable from node u in a single move.
Adjacency_List = list[list[int]]


@define(auto_attribs=True)
class Graph:
    """Game board: an adjacency list (edges read as ``u -> v``) plus the hole node."""

    adj_list: Adjacency_List
    hole_node: int

    @property
    def reverse_adj_list(self) -> Adjacency_List:
        """Return the transposed adjacency list.

        Entry ``v`` lists every node ``u`` that has ``v`` in ``adj_list[u]``, in the
        order those edges appear. The result is rebuilt from ``adj_list`` on every
        access, so it always reflects the current edges.
        """
        incoming: Adjacency_List = [[] for _ in self.adj_list]
        edges = ((src, dst) for src, targets in enumerate(self.adj_list) for dst in targets)
        for src, dst in edges:
            incoming[dst].append(src)
        return incoming


@define(auto_attribs=True)
class Position:
    """Location of one piece on a ``Graph``, plus the nodes that piece may not occupy.

    Equality looks only at ``node``. A ``Position`` equals another ``Position`` on
    the same node, or any value equal to its ``node`` (such as a plain ``int``).
    ``graph`` and ``blocked`` are ignored by the comparison.
    """

    node: int
    graph: Graph
    # Nodes this piece is forbidden to stand on. ``factory=list`` gives each
    # instance its own list. The former ``default=list()`` handed the same list
    # object to every instance created without ``blocked``, so mutating one
    # silently changed all of them.
    blocked: list[int] = field(factory=list)

    @property
    def is_valid(self) -> bool:
        """True when this piece is not standing on one of its blocked nodes."""
        return self.node not in self.blocked

    def fetch_prev_positions(self) -> list["Position"]:
        """Return one ``Position`` per reverse edge into ``node``.

        Each result keeps this piece's ``graph`` and ``blocked`` list. Blocked
        nodes are not filtered out here; callers check ``is_valid``.
        """
        return [
            Position(node=adj, graph=self.graph, blocked=self.blocked) for adj in self.graph.reverse_adj_list[self.node]
        ]

    def __eq__(self, pos: Union[int, "Position"]) -> bool:
        # Compare node ids only, so a Position can also be matched against a raw node id.
        return self.node == pos.node if isinstance(pos, Position) else self.node == pos


@define(auto_attribs=True)
class State:
    """One game position: where the cat and the mouse stand, and whose turn it is."""

    cat_pos: Position
    mouse_pos: Position
    turn: Turn = Turn.MOUSE

    @property
    def is_valid(self) -> bool:
        """True when neither piece is standing on one of its blocked nodes."""
        return all(piece.is_valid for piece in (self.cat_pos, self.mouse_pos))

    def fetch_previous_states(self) -> list["State"]:
        """Return every valid state from which a single move leads to this one.

        The piece at index ``1 - turn`` of ``(cat_pos, mouse_pos)`` is walked back
        along reverse edges while the other piece stays put. Every predecessor
        carries turn ``Turn(1 - turn)``. Candidates where either piece is on a
        blocked node are dropped.
        """
        moved = 1 - self.turn
        options = [[self.cat_pos], [self.mouse_pos]]
        options[moved] = options[moved][0].fetch_prev_positions()
        predecessors: list["State"] = []
        for cat, mouse in product(*options):
            candidate = State(cat, mouse, turn=Turn(value=moved))
            if candidate.is_valid:
                predecessors.append(candidate)
        return predecessors

    def __hash__(self) -> int:
        # Hash on node ids plus turn, matching Position's node-only equality.
        key = (self.cat_pos.node, self.mouse_pos.node, self.turn)
        return hash(key)


def compute_state(present_state: State) -> Result:
    result_dict: defaultdict[State, Result] = defaultdict(lambda: Result.DRAW)
    in_degree: defaultdict[State, int] = DefaultDict(lambda state: len(state.fetch_previous_states()))

    graph = present_state.cat_pos.graph
    hole_position = Position(node=graph.hole_node, graph=graph)

    # Define winning states - A state is winning if the mouse is in the hole and the next turn is CAT's
    # Additionally, a state is winning if its connected to at least one winning state
    winning_states: list[State] = [
        State(
            cat_pos=Position(node=pos, graph=graph, blocked=present_state.cat_pos.blocked),
            mouse_pos=Position(
                node=hole_position.node, graph=hole_position.graph, blocked=present_state.mouse_pos.blocked
            ),
            turn=Turn.CAT,
        )
        for pos in range(
            len(graph.adj_list)
        )  # BONUS: Iterate only through nodes accessible to cat instead of all nodes to trim search
    ]

    # Define losing states - A state is losing if the mouse and cat are in the same node, irrespective of the turn
    # Additionally, a state is losing if the state is only connected to losing states
    losing_states: list[State] = [
        State(
            cat_pos=Position(node=pos, graph=graph, blocked=present_state.cat_pos.blocked),
            mouse_pos=Position(node=pos, graph=graph, blocked=present_state.mouse_pos.blocked),
            turn=turn,
        )
        for pos, turn in product(
            range(len(graph.adj_list)), list(Turn)
        )  # BONUS: Iterate through nodes accessible to both cat and mouse instead of all nodes to trim search
    ]

    states: deque[tuple[State, Result]] = deque(
        iterable=[(state, Result.LOSE) for state in losing_states if state.is_valid]
        + [(state, Result.WIN) for state in winning_states if state.is_valid]
    )

    for state, res in states:
        in_degree[state], result_dict[state] = 0, res

    while states:
        state, res = states.popleft()
        for previous_state in state.fetch_previous_states():
            if in_degree[previous_state]:
                match res, previous_state.turn:
                    case Result.LOSE, Turn.CAT:
                        result_dict[previous_state], in_degree[previous_state] = Result.LOSE, 0

                    case Result.LOSE, Turn.MOUSE:
                        in_degree[previous_state] -= 1
                        if not in_degree[previous_state]:
                            result_dict[previous_state] = Result.LOSE

                    case Result.WIN, Turn.MOUSE:
                        result_dict[previous_state], in_degree[previous_state] = Result.WIN, 0

                    case Result.WIN, Turn.CAT:
                        in_degree[previous_state] -= 1
                        if not in_degree[previous_state]:
                            result_dict[previous_state] = Result.WIN

                if not in_degree[previous_state]:
                    states.append((previous_state, result_dict[previous_state]))

    return result_dict[present_state]
Editor is loading...
Leave a Comment