Untitled
user_5581171
plain_text
a year ago
1.2 kB
5
Indexable
def main():
    """Answer "where am I after k teleports?" queries read from stdin.

    Expected input (whitespace-separated; line breaks are not significant):
        n q
        t_1 t_2 ... t_n      t_i is the 1-based node that node i teleports to
        x_1 k_1
        ...
        x_q k_q              start node (1-based) and number of teleports

    Prints one line per query: the 1-based node reached from x after k
    teleports. Empty input produces no output.

    Uses binary lifting: ``up[j][v]`` is the node reached from ``v`` after
    ``2**j`` teleports, so a query walks at most ``k.bit_length()`` table
    levels. Only as many levels as the largest ``k`` needs are built, so
    there is no fixed upper bound on ``k``. ``k`` is assumed non-negative;
    a non-positive ``k`` leaves the start node unchanged.
    """
    import sys

    # Token-based parsing tolerates blank/trailing lines and extra spaces,
    # which line-by-line ``x, k = ...split()`` unpacking would reject.
    tokens = sys.stdin.read().split()
    if not tokens:
        return
    n, q = int(tokens[0]), int(tokens[1])
    teleport = [int(t) - 1 for t in tokens[2 : 2 + n]]  # 0-based destinations
    query_tokens = tokens[2 + n : 2 + n + 2 * q]
    starts = [int(x) - 1 for x in query_tokens[0::2]]
    counts = [int(k) for k in query_tokens[1::2]]

    # up[0] is a single teleport; each further level applies the previous one
    # twice. Build just enough levels to cover the largest requested k.
    up = [teleport]
    for _ in range(1, max(counts, default=0).bit_length()):
        prev = up[-1]
        up.append([prev[v] for v in prev])

    results = []
    for x, k in zip(starts, counts):
        level = 0
        # Consume k bit by bit; every set bit j means "jump 2**j steps".
        while k > 0:
            if k & 1:
                x = up[level][x]
            k >>= 1
            level += 1
        results.append(x + 1)  # back to 1-based
    print("\n".join(map(str, results)))
# Run only when executed as a script, so importing this module does not
# consume stdin.
if __name__ == "__main__":
    main()
Editor is loading...
Leave a Comment