# Untitled
#
# mail@pastecode.io avatar
# unknown
# python
# 2 years ago
# 4.3 kB
# 167
# Indexable
# Never
import sys
from typing import List, Union, Tuple
from functools import lru_cache

@lru_cache(None)
def read_input() -> List[str]:
    """Read every non-blank line of stdin, stripped of surrounding whitespace.

    The result is cached, so ``solve`` and ``solve2`` share one read of stdin
    (the stream can only be consumed once). Blank lines, such as a trailing
    empty line, are skipped. An empty string is not a valid snailfish number
    and would produce malformed sums such as ``[x,]`` downstream.

    Returns:
        The non-blank input lines, in order. The same list object is returned
        on every call, so callers must not mutate it.
    """
    stripped = (line.strip() for line in sys.stdin)
    return [line for line in stripped if line]

def get_pair(a: str, i: int) -> Union[None, List[int]]:
    """Parse the pair whose opening bracket is at index ``i``.

    Returns ``[left, right]`` when the text between ``a[i]`` and the next
    ``]`` contains no further ``[``, meaning both members are plain numbers.
    Returns ``None`` when a nested pair sits inside.
    """
    end = a.find(']', i)  # the first ']' after i closes the innermost bracket
    inner = a[i + 1:end]
    if '[' in inner:
        # Another pair opens before this one closes: not a regular pair.
        return None
    return list(map(int, inner.split(',')))

def find_prev_num_index(a: str, i: int) -> Tuple[Union[int, None], Union[str, None]]:
    """Find the nearest run of digits at or before index ``i``.

    Scans left from ``a[i]`` to the first digit, then continues left to the
    start of that number.

    Returns:
        ``(start, digits)``, where ``digits`` is the full number text and
        ``start`` is its first index. Returns ``(None, None)`` if no digit
        occurs in ``a[:i + 1]``. ``start`` may be 0, so test it with
        ``is None`` rather than truthiness.
    """
    while i >= 0 and not a[i].isdigit():
        i -= 1

    if i < 0:
        return None, None

    start = i
    # Stop at index 0: stepping to -1 would let a[-1] wrap around to the
    # end of the string and corrupt the result.
    while start > 0 and a[start - 1].isdigit():
        start -= 1

    return start, a[start:i + 1]

def find_next_num_index(a: str, i: int) -> Tuple[Union[int, None], Union[str, None]]:
    """Find the first run of digits at or after index ``i``.

    Returns:
        ``(start, digits)`` for the first number in ``a[i:]``, or
        ``(None, None)`` when there is none. ``start`` may be 0, so test it
        with ``is None`` rather than truthiness.
    """
    n = len(a)
    while i < n and not a[i].isdigit():
        i += 1

    if i >= n:
        return None, None

    end = i
    # Bounds-checked: a number may run right up to the end of the string.
    while end < n and a[end].isdigit():
        end += 1

    return i, a[i:end]

def try_explode(a: str) -> Tuple[bool, str]:
    """Perform the leftmost snailfish *explode* action on ``a``, if any.

    The target is the leftmost regular pair whose ``[`` opens depth 5, i.e.
    a pair nested inside four others. Its left value is added to the nearest
    number on its left, and its right value to the nearest number on its
    right, when those numbers exist. The pair itself is then replaced by
    ``0``.

    Returns:
        ``(True, new_string)`` if a pair exploded, otherwise ``(False, a)``.
    """
    depth = 0
    pair = None
    for i, ch in enumerate(a):
        if ch == '[':
            depth += 1
            if depth == 5:
                pair = get_pair(a, i)
                if pair is not None:
                    break
        elif ch == ']':
            depth -= 1

    # No explodable pair: hand back the original string untouched.
    if pair is None:
        return False, a

    left, right = pair
    # End of the pair's actual text (one past its ']'), taken from the string
    # itself rather than re-formatting the numbers.
    pair_end = a.find(']', i) + 1

    prev_idx, prev_str = find_prev_num_index(a, i)
    next_idx, next_str = find_next_num_index(a, pair_end)

    # Text left of the pair, with its nearest number (if any) increased by `left`.
    before = a[:i]
    if prev_idx is not None:
        before = (before[:prev_idx]
                  + str(int(prev_str) + left)
                  + before[prev_idx + len(prev_str):])

    # Text right of the pair, with its nearest number (if any) increased by `right`.
    after = a[pair_end:]
    if next_idx is not None:
        offset = next_idx - pair_end
        after = (after[:offset]
                 + str(int(next_str) + right)
                 + after[offset + len(next_str):])

    return True, before + '0' + after

def try_split(a: str) -> Tuple[bool, str]:
    """Perform the leftmost snailfish *split* action on ``a``, if any.

    The first number that is 10 or greater is replaced by a pair. The pair's
    left element is the number halved and rounded down; its right element is
    the number halved and rounded up.

    Returns:
        ``(True, new_string)`` if a number was split, otherwise ``(False, a)``.
    """
    i, n = 0, len(a)
    while i < n:
        if not a[i].isdigit():
            i += 1
            continue
        # a[i:j] is one whole run of digits.
        j = i
        while j < n and a[j].isdigit():
            j += 1
        value = int(a[i:j])
        if value >= 10:
            low = value // 2
            high = value - low  # == ceil(value / 2)
            return True, a[:i] + f'[{low},{high}]' + a[j:]
        i = j
    return False, a

def calc_magnitude(a: str) -> int:
    """Return the magnitude of the snailfish number ``a``.

    A regular number's magnitude is its value. A pair's magnitude is three
    times the magnitude of its left element plus twice that of its right.

    The string is evaluated bottom-up with a stack: numbers are pushed as
    they are read, and every ``]`` folds the top two entries into their
    pair's magnitude.

    Raises:
        IndexError: if ``a`` contains no number (e.g. the empty string).
    """
    stack: List[int] = []
    i, n = 0, len(a)
    while i < n:
        ch = a[i]
        if ch.isdigit():
            # Consume the whole (possibly multi-digit) number, bounds-checked.
            j = i
            while j < n and a[j].isdigit():
                j += 1
            stack.append(int(a[i:j]))
            i = j
            continue
        if ch == ']':
            right = stack.pop()
            left = stack.pop()
            stack.append(3 * left + 2 * right)
        # '[' and ',' carry no information for the fold.
        i += 1
    return stack.pop()
    

def add_nums(a, b):
    """Add two snailfish numbers and fully reduce the result.

    The raw sum is the pair ``[a,b]``. Reduction repeatedly applies one
    action, with an explode always taking priority over a split, until
    neither action applies.
    """
    total = f'[{a},{b}]'
    while True:
        exploded, total = try_explode(total)
        if exploded:
            continue
        was_split, total = try_split(total)
        if not was_split:
            return total

def solve():
    """Part 1: add every input number in order and report the final sum.

    Prints a header, the fully reduced sum, and the sum's magnitude.
    """
    print('PART 1')
    numbers = read_input()
    total = numbers[0]
    for snail_number in numbers[1:]:
        total = add_nums(total, snail_number)

    print(total)
    print(calc_magnitude(total))

def solve2():
    """Part 2: print the largest magnitude of any sum of two distinct inputs.

    Snailfish addition is not commutative, so both orders of every pair are
    tried. Pairs are distinguished by position in the input, not by text.
    Two identical lines are still different numbers and may be added to
    each other.
    """
    print('PART 2')
    nums = read_input()
    max_mag = 0

    for i, first in enumerate(nums):
        for j, second in enumerate(nums):
            if i != j:
                res = add_nums(first, second)
                max_mag = max(max_mag, calc_magnitude(res))

    print(max_mag)

if __name__ == '__main__':
    # Run both parts only when executed as a script, so importing this module
    # does not consume stdin or print. Both parts share read_input's cached lines.
    solve()
    solve2()