Untitled

 avatar
unknown
python
2 months ago
2.7 kB
4
Indexable
from advent_of_code.lib.all import *


def get_input(file):
    """Read the 2024 day-20 input and index it as a grid keyed by complex numbers.

    Every character of line ``row`` at column ``col`` is stored under the key
    ``row + 1j * col``.
    """
    lines = aoc_parse.as_lines(aoc.read_input(2024, 20, file))
    grid = {}
    for row, text in enumerate(lines):
        for col, ch in enumerate(text):
            grid[row + 1j * col] = ch
    return grid


def shortest(m, start, end):
    """Return the length of the shortest path from ``start`` to ``end``.

    Breadth-first search over the grid ``m`` (keys ``row + 1j*col``): every
    cell present in ``m`` except ``"#"`` is walkable, and each of the four
    orthogonal moves costs 1.

    Returns:
        The number of steps on a shortest path, or None if ``end`` is not
        reachable from ``start``.
    """
    from collections import deque

    q = deque([(start, 0)])
    seen = {start}
    while q:
        # FIFO order: cells leave the queue in non-decreasing distance, so the
        # first time ``end`` is dequeued its distance is minimal. (A LIFO
        # ``list.pop()`` here would be DFS and could return a detour.)
        p, l = q.popleft()
        if p == end:
            return l
        for step in (1, -1, 1j, -1j):
            new = p + step
            if new in m and m[new] != "#" and new not in seen:
                seen.add(new)
                q.append((new, l + 1))
    return None


def bfs(m, start):
    """Map every cell reachable from ``start`` to its step distance.

    Cells absent from ``m`` and cells holding ``"#"`` are never entered.
    The search expands one whole distance layer at a time, so each cell is
    recorded with its shortest distance from ``start``.
    """
    dist = {}
    visited = {start}
    layer = [start]
    depth = 0
    while layer:
        following = []
        for cell in layer:
            dist[cell] = depth
            for delta in (1, -1, 1j, -1j):
                nxt = cell + delta
                if nxt in m and m[nxt] != "#" and nxt not in visited:
                    visited.add(nxt)
                    following.append(nxt)
        layer = following
        depth += 1
    return dist


def dfs(m, start, end_dist, target, cheat_length):
    """Count distinct cheats whose cheated route length is at most ``target``.

    A cheat is identified by its (start cell, end cell) pair. From a cell ``p``
    reached normally from ``start``, the racer may take up to ``cheat_length``
    steps ignoring walls and must land on a non-wall cell ``p2``. The cheated
    route length is ``dist(start, p) + steps(p, p2) + end_dist[p2]``.

    Args:
        m: grid mapping ``row + 1j*col`` to its character; ``"#"`` is a wall.
        start: cell the race starts from.
        end_dist: distance from each cell to the end. Cells missing from it
            cannot reach the end and are never counted as a cheat end.
        target: maximum allowed total route length (inclusive).
        cheat_length: maximum number of wall-ignoring steps in one cheat.

    Returns:
        The number of distinct (cheat start, cheat end) pairs within ``target``.

    The function name is kept for existing callers. The walk over cheat
    starts is breadth-first, so ``l`` is the true shortest distance from
    ``start``.
    """
    from collections import deque

    dirs = (1, -1, 1j, -1j)
    # While cheating, every in-grid cell is passable. On a rectangular grid the
    # fewest steps to any cell is therefore its Manhattan distance, because a
    # monotone path stays inside the grid. Enumerate that diamond once rather
    # than searching per cell: a stack-based search would over-estimate
    # distances and miss cheats.
    offsets = [
        (dr + 1j * dc, abs(dr) + abs(dc))
        for dr in range(-cheat_length, cheat_length + 1)
        for dc in range(abs(dr) - cheat_length, cheat_length - abs(dr) + 1)
    ]
    q = deque([(start, 0)])
    seen = {start}
    count = 0
    while q:
        p, l = q.popleft()
        for step in dirs:
            new = p + step
            # Cells with l > target can never start a qualifying cheat.
            if new in m and m[new] != "#" and l + 1 <= target and new not in seen:
                q.append((new, l + 1))
                seen.add(new)
        # Each p is dequeued once and the offsets are distinct, so each
        # (p, p2) pair is examined at most once and a plain counter suffices.
        for offset, l2 in offsets:
            p2 = p + offset
            if (
                m.get(p2, "#") != "#"
                and p2 in end_dist
                and end_dist[p2] + l + l2 <= target
            ):
                count += 1
    return count


def _count_cheats(m, cheat_length, min_saving=100):
    """Count cheats of at most ``cheat_length`` steps that save ``min_saving`` or more.

    The uncheated route length is read from the BFS distance map rooted at
    ``E`` (its entry for ``S``), so no separate shortest-path search is needed.

    Raises:
        ValueError: if ``E`` cannot be reached from ``S``.
    """
    start, end = (next(p for p in m if m[p] == x) for x in ("S", "E"))
    end_dist = bfs(m, end)
    if start not in end_dist:
        raise ValueError("no path from S to E")
    target = end_dist[start] - min_saving
    return dfs(m, start, end_dist, target, cheat_length)


@aoc.pretty_solution(1)
def part1(m):
    """Part 1: count cheats of up to 2 steps that save at least 100 picoseconds."""
    return _count_cheats(m, 2)


@aoc.pretty_solution(2)
def part2(m):
    """Part 2: count cheats of up to 20 steps that save at least 100 picoseconds.

    The puzzle uses the same 100-picosecond threshold as part 1. A threshold
    of 50 only matches the listing in the worked example.
    """
    return _count_cheats(m, 20)


def test():
    """Run both parts on ``input.txt``, then print a completion message.

    part1 is handed a deep copy of the parsed grid; part2 gets the parsed grid
    itself. Previously recorded part 1 answer for this input: 1369.
    """
    puzzle = get_input("input.txt")
    for solve, grid in ((part1, deepcopy(puzzle)), (part2, puzzle)):
        solve(grid)
    print("Test OK")


# Script entry point: run test() (both puzzle parts) when executed directly.
if __name__ == "__main__":
    test()
Leave a Comment