# aoc2021day06 using matrix exp
#
# Paste metadata (pastecode.io): author unknown, python, 3 years ago, 1.4 kB.
import sys
import numpy as np

# Default number of days to simulate. Change to 80 for part 1
DAYS = 256

# Multiplying fish_v with this matrix returns the fish_v one day later.
# Row i describes where the fish with timer i on the next day come from.
# The dtype is pinned to int64 because NumPy's default integer type is
# platform-dependent and can be 32-bit, which the 256-day counts overflow.
TIME_SHIFT_M = np.array([
    [0, 1, 0, 0, 0, 0, 0, 0, 0],  # Simple timer decrease
    [0, 0, 1, 0, 0, 0, 0, 0, 0],  # Simple timer decrease
    [0, 0, 0, 1, 0, 0, 0, 0, 0],  # Simple timer decrease
    [0, 0, 0, 0, 1, 0, 0, 0, 0],  # Simple timer decrease
    [0, 0, 0, 0, 0, 1, 0, 0, 0],  # Simple timer decrease
    [0, 0, 0, 0, 0, 0, 1, 0, 0],  # Simple timer decrease
    [1, 0, 0, 0, 0, 0, 0, 1, 0],  # Simple timer decrease + fish that just had a baby
    [0, 0, 0, 0, 0, 0, 0, 0, 1],  # Simple timer decrease
    [1, 0, 0, 0, 0, 0, 0, 0, 0],  # Babies
], dtype=np.int64)


def main(input_file=sys.stdin, output_file=sys.stdout, days=None):
    """Write the total number of fish after simulating ``days`` days.

    Args:
        input_file: Text stream; its first line holds the comma-separated
            initial timer values (each in 0..8). A blank line means no fish.
        output_file: Text stream that receives the final fish count as a
            decimal string (no trailing newline).
        days: Number of days to simulate (expected to be >= 0). Defaults to
            the module-level ``DAYS`` when omitted.

    Raises:
        ValueError: If a value is not an integer or is a negative timer.
        IndexError: If a timer value is larger than 8.

    Note:
        Counts are held in int64; day counts large enough to exceed that
        range would wrap silently.
    """
    if days is None:
        days = DAYS

    # Read input vector. A blank line is treated as an empty school of fish.
    line = input_file.readline().strip()
    initial_fish = [int(e) for e in line.split(',')] if line else []

    # Compute fish count vector
    # fish_v[i] = number of fish with timer at i
    fish_v = np.zeros(TIME_SHIFT_M.shape[0], dtype=np.int64)
    for timer in initial_fish:
        if timer < 0:
            # A negative index would wrap around and be counted as a newborn.
            raise ValueError(f"fish timer must be non-negative, got {timer}")
        fish_v[timer] += 1

    # We want to compute (TIME_SHIFT_M ** days) * fish_v.
    # matrix_power uses repeated squaring, so this takes O(log(days))
    # matrix multiplications.
    shift_matrix = np.linalg.matrix_power(TIME_SHIFT_M, days)
    fish_v = shift_matrix @ fish_v

    # Return number of fish in end-state by simply summing vector values.
    # int() makes the text independent of how NumPy scalars are formatted.
    output_file.write(str(int(fish_v.sum())))


# Run the solver with its default streams (stdin -> stdout) only when this
# file is executed as a script, not when it is imported as a module.
if __name__ == "__main__":
    main()