# Untitled
"""Paint '.' cells to maximise the total count of single-character substrings.

``solve(color)`` considers every way of replacing each ``'.'`` in ``color``
with a letter. For each result it sums ``count_substring(k)`` over the
maximal runs of equal characters, where ``k`` is the run length, and it
returns the largest such sum.

Each dot run is painted like its left (``LEFT``) or right (``RIGHT``)
neighbouring run, and ``dp`` chooses between the two.
"""
from collections import namedtuple
from itertools import groupby
from typing import List
import random
import sys
import time

Data = namedtuple('Data', ['char', 'cnt'])  # one maximal run: character, length
LEFT = 'LEFT'    # a dot run is painted like the run on its left
RIGHT = 'RIGHT'  # a dot run is painted like the run on its right

# Alphabet sampled by generate_test; the many dots make blank cells common.
_TEST_ALPHABET = 'abcdefghijklmnopqrstuvwxyz....................'


def count_substring(a):
    """Return the number of non-empty substrings of a run of length ``a``."""
    return a * (a + 1) // 2


def _segment_value(data, seen, k, direction, memo):
    """Score the letter runs after dot run ``seen[k]``, up to the next dot run.

    If ``direction`` is RIGHT, the dot run itself is included, merged into
    the first letter run on its right. If it is LEFT, the dot run is assumed
    to have been scored already with its left neighbour.

    For a dot run that is not the last one, the next dot run's best
    continuation is read from ``memo``. That run may either leave the last
    letter run alone (it goes RIGHT) or be merged into it (it goes LEFT).
    ``memo`` must therefore already hold both states for ``k + 1``.
    """
    here = seen[k]
    is_last = k == len(seen) - 1
    end = len(data) if is_last else seen[k + 1]
    # Letter runs strictly between this dot run and the next dot run (or the
    # end). solve() removes leading/trailing dot runs and dot runs are never
    # adjacent, so this list is never empty.
    lens = [run.cnt for run in data[here + 1:end]]
    if direction == RIGHT:
        # The dots are painted like, and so lengthen, the first run after them.
        lens[0] += data[here].cnt
    body = sum(count_substring(x) for x in lens[:-1])
    last_run = lens[-1]
    if is_last:
        return body + count_substring(last_run)
    return body + max(
        count_substring(last_run) + memo[(k + 1, RIGHT)],
        count_substring(last_run + data[end].cnt) + memo[(k + 1, LEFT)],
    )


def dp(data: List[Data], seen: List[int], idx: int, direction: str, memo: dict) -> int:
    """Return the best score for dot run ``seen[idx]`` and everything after it.

    Args:
        data: run-length encoded cells, as produced inside ``solve``.
        seen: indices into ``data`` of the dot runs, in increasing order.
        idx: which dot run (an index into ``seen``) to start from.
        direction: LEFT if that dot run has already been scored together with
            its left neighbour; RIGHT if it merges into its right neighbour.
        memo: cache keyed by ``(idx, direction)``; filled in place.
    """
    key = (idx, direction)
    if key in memo:
        return memo[key]
    # Fill the table from the last dot run backwards, so each state only
    # reads states that are already known. This keeps the stack depth
    # constant however many dot runs the input has.
    for k in range(len(seen) - 1, idx - 1, -1):
        if (k, LEFT) not in memo or (k, RIGHT) not in memo:
            memo[(k, LEFT)] = _segment_value(data, seen, k, LEFT, memo)
            memo[(k, RIGHT)] = _segment_value(data, seen, k, RIGHT, memo)
    return memo[key]


def solve(color: str) -> int:
    """Return the maximum total of ``count_substring`` over runs after filling dots.

    Returns 0 for an empty string.
    """
    if not color:
        return 0
    data = [Data(char, sum(1 for _ in group)) for char, group in groupby(color)]
    if len(data) == 1:
        return count_substring(data[0].cnt)
    # A leading or trailing dot run has only one neighbour, so it joins it.
    if data[0].char == '.':
        data[1] = Data(data[1].char, data[0].cnt + data[1].cnt)
        data.pop(0)
    if data[-1].char == '.':
        data[-2] = Data(data[-2].char, data[-2].cnt + data[-1].cnt)
        data.pop()
    # A dot run between two runs of the same letter is best painted with that
    # letter, which fuses all three into one run. Chains such as 'a.a.a'
    # collapse because the fused run is what the next dot run sees on its left.
    runs = []
    for cur in data:
        if len(runs) >= 2 and runs[-1].char == '.' and runs[-2].char == cur.char:
            dots = runs.pop()
            left = runs.pop()
            cur = Data(cur.char, left.cnt + dots.cnt + cur.cnt)
        runs.append(cur)
    seen = [i for i, run in enumerate(runs) if run.char == '.']
    if not seen:
        # Every dot has been absorbed (or there were none): nothing to choose.
        return sum(count_substring(run.cnt) for run in runs)
    first = seen[0]
    # Runs before the first dot run's left neighbour are unaffected by any choice.
    prefix = sum(count_substring(run.cnt) for run in runs[:first - 1])
    memo = {}
    return prefix + max(
        count_substring(runs[first - 1].cnt) + dp(runs, seen, 0, RIGHT, memo),
        count_substring(runs[first - 1].cnt + runs[first].cnt) + dp(runs, seen, 0, LEFT, memo),
    )


def generate_test(n):
    """Return a random string of ``n`` characters drawn from lowercase letters and dots."""
    return ''.join(random.choices(_TEST_ALPHABET, k=n))


def main():
    """Run the sample case, then time solve() on random inputs of growing size."""
    color = '.bb.a.aa.'
    print(f'solve {color}: {solve(color)}')
    print()
    for exponent in range(2, 8):
        color = generate_test(10 ** exponent)
        start = time.perf_counter()
        print(solve(color))
        print(f'len {len(color)}')
        print(f'time {(time.perf_counter() - start) * 1000} ms')


if __name__ == '__main__':
    main()
# Leave a Comment