from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
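# Scheduled TensorIR module as emitted by the TVMScript printer: executing this
# file binds `Module` to a tvm.IRModule whose prim_funcs already carry their GPU
# schedules (thread bindings, shared/local staging, rfactor reductions). The
# shapes involved (4096 hidden, 22016 FFN, 32001 vocab, 4-bit group-32 weights)
# are consistent with a q4f16-quantized Llama-style model; that reading, and the
# per-kernel notes below, are annotations rather than part of the printed dump.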
@I.ir_module
class Module:
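# divide2: elementwise division of the (1, 1, 32001) float32 tensor A by the
# scalar B, flattened over 126 blocks x 256 threads with a T.where tail guard;
# likely a logits-scaling step (interpretation).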
@T.prim_func(private=True)
def divide2(A: T.Buffer((1, 1, 32001), "float32"), B: T.Buffer((), "float32"), T_divide: T.Buffer((1, 1, 32001), "float32")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0_fused_0 in T.thread_binding(126, thread="blockIdx.x"):
for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_divide"):
v0 = T.axis.spatial(32001, ax0_fused_0 * 256 + ax0_fused_1)
T.where(ax0_fused_0 * 256 + ax0_fused_1 < 32001)
T.reads(A[0, 0, v0], B[()])
T.writes(T_divide[0, 0, v0])
T_divide[0, 0, v0] = A[0, 0, v0] / B[()]
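# extend_te: extends a square (1, 1, n, n) float16 mask A to (1, 1, n, m);
# columns with v1 < m - n are filled with 65504 (float16 max) and the remaining
# columns copy A shifted by n - m, covering n * m elements in 256-thread blocks.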
@T.prim_func(private=True)
def extend_te(var_A: T.handle, var_concat_te: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
A = T.match_buffer(var_A, (1, 1, n, n), "float16")
m = T.int32()
concat_te = T.match_buffer(var_concat_te, (1, 1, n, m), "float16")
# with T.block("root"):
for ax0_ax1_fused_0 in T.thread_binding((n * m + 255) // 256, thread="blockIdx.x"):
for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("concat_te"):
v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % (m * n) // m)
v1 = T.axis.spatial(m, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % m)
T.where(ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 < n * m)
T.reads(A[0, 0, v0, v1 + (n - m)])
T.writes(concat_te[0, 0, v0, v1])
concat_te[0, 0, v0, v1] = T.if_then_else(v1 < m - n, T.float16(65504), A[0, 0, v0, v1 + (n - m)])
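# full: fills a dynamically sized (1, 1, 1, n) float16 buffer with 65504
# (float16 max), one element per thread.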
@T.prim_func(private=True)
def full(var_T_full: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
T_full = T.match_buffer(var_T_full, (1, 1, 1, n), "float16")
# with T.block("root"):
for ax0_fused_0 in T.thread_binding((n + 255) // 256, thread="blockIdx.x"):
for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_full"):
v0 = T.axis.spatial(n, ax0_fused_0 * 256 + ax0_fused_1)
T.where(ax0_fused_0 * 256 + ax0_fused_1 < n)
T.reads()
T.writes(T_full[0, 0, 0, v0])
T_full[0, 0, 0, v0] = T.float16(65504)
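# fused_NT_matmul1_divide1_maximum1_minimum1_cast3: batched fp16 NT matmul of
# lv28 (1, 32, n, 128) against lv29 (1, 32, m, 128) with a 128-deep reduction,
# followed by a constant scale (~1/sqrt(128)), max with -65504, min with the
# mask lv5, and a cast to float32. Each block computes a 32 x 64 tile through
# shared-memory staging and 4 x 4 register accumulation; padded rows and
# columns are masked off on the final write of the (1, 32, n, m) output.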
@T.prim_func(private=True)
def fused_NT_matmul1_divide1_maximum1_minimum1_cast3(p_lv28: T.handle, p_lv29: T.handle, p_lv5: T.handle, p_output0: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
lv28 = T.match_buffer(p_lv28, (1, 32, n, 128), "float16")
m = T.int32()
lv29 = T.match_buffer(p_lv29, (1, 32, m, 128), "float16")
lv5 = T.match_buffer(p_lv5, (1, 1, n, m), "float16")
var_compute_intermediate = T.match_buffer(p_output0, (1, 32, n, m))
# with T.block("root"):
var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((32, (n + 31) // 32 * 32, (m + 63) // 64 * 64), "float16", scope="local")
lv28_reindex_pad_shared = T.alloc_buffer((32, (n + 31) // 32 * 32, 128), "float16", scope="shared")
lv29_reindex_pad_shared = T.alloc_buffer((32, (m + 63) // 64 * 64, 128), "float16", scope="shared")
for ax0_ax2_0_fused in T.thread_binding((m + 63) // 64 * 32, thread="blockIdx.y"):
for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"):
for ax2_1 in T.thread_binding(1, thread="vthread.y"):
for ax1_1 in T.thread_binding(1, thread="vthread.x"):
for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax2_3_init, ax1_3_init in T.grid(4, 4):
with T.block("NT_matmul_init"):
v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((m + 63) // 64))
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
v2 = T.axis.spatial((m + 63) // 64 * 64, ax0_ax2_0_fused % ((m + 63) // 64) * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init)
T.reads()
T.writes(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] = T.float16(0)
for ax3_0 in range(8):
for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(2):
for ax0_ax1_ax2_fused_3 in T.vectorized(2):
with T.block("lv28_reindex_pad_shared"):
v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((m + 63) // 64))
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
v2 = T.axis.spatial(128, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
T.reads(lv28[0, v0, v1, v2])
T.writes(lv28_reindex_pad_shared[v0, v1, v2])
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
lv28_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv28[0, v0, v1, v2], T.float16(0))
for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(4):
for ax0_ax1_ax2_fused_3 in T.vectorized(2):
with T.block("lv29_reindex_pad_shared"):
v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((m + 63) // 64))
v1 = T.axis.spatial((m + 63) // 64 * 64, ax0_ax2_0_fused % ((m + 63) // 64) * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
v2 = T.axis.spatial(128, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
T.reads(lv29[0, v0, v1, v2])
T.writes(lv29_reindex_pad_shared[v0, v1, v2])
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
lv29_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, lv29[0, v0, v1, v2], T.float16(0))
for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 4):
with T.block("NT_matmul_update"):
v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((m + 63) // 64))
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
v2 = T.axis.spatial((m + 63) // 64 * 64, ax0_ax2_0_fused % ((m + 63) // 64) * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3)
v3 = T.axis.reduce(128, ax3_0 * 16 + ax3_1)
T.reads(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv28_reindex_pad_shared[v0, v1, v3], lv29_reindex_pad_shared[v0, v2, v3])
T.writes(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] + lv28_reindex_pad_shared[v0, v1, v3] * lv29_reindex_pad_shared[v0, v2, v3]
for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
for ax2_1_1 in T.vectorized(2):
with T.block("var_NT_matmul_intermediate_reindex_pad_local"):
v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((m + 63) // 64) + ax0)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
v2 = T.axis.spatial((m + 63) // 64 * 64, ax0_ax2_0_fused % ((m + 63) // 64) * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
T.reads(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv5[0, 0, v1, v2])
T.writes(var_compute_intermediate[0, v0, v1, v2])
if v1 < n and v2 < m:
var_compute_intermediate[0, v0, v1, v2] = T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] * T.float16(0.088397790055248615), T.float16(-65504)), lv5[0, 0, v1, v2]))
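# fused_NT_matmul7_divide_maximum_minimum_cast: single-row variant of the score
# kernel above, multiplying lv1637 (1, 32, 1, 128) against lv1638 (1, 32, n, 128)
# with the same scale / max / min-with-mask / cast-to-float32 epilogue; the
# 128-deep reduction is split with rfactor and finished across 64 threadIdx.x
# lanes per output element.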
@T.prim_func(private=True)
def fused_NT_matmul7_divide_maximum_minimum_cast(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p_lv1614: T.handle, p_output0: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
lv1638 = T.match_buffer(p_lv1638, (1, 32, n, 128), "float16")
lv1614 = T.match_buffer(p_lv1614, (1, 1, 1, n), "float16")
var_compute_intermediate = T.match_buffer(p_output0, (1, 32, 1, n))
# with T.block("root"):
var_NT_matmul_intermediate_local = T.alloc_buffer((1, 32, 1, n), "float16", scope="local")
var_NT_matmul_intermediate_rf_local = T.alloc_buffer((128, 1, 32, 1, n), "float16", scope="local")
var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((64, 1, 32, 1, n), "float16", scope="local")
lv1638_local = T.alloc_buffer((1, 32, n, 128), "float16", scope="local")
lv1637_shared = T.alloc_buffer((1, 32, 1, 128), "float16", scope="shared")
for ax0_fused_ax1_fused_fused_0 in T.thread_binding(n * 32, thread="blockIdx.x"):
for ax0_fused_ax1_fused_fused_1 in T.thread_binding(1, thread="threadIdx.y"):
for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in T.thread_binding(64, thread="threadIdx.x"):
for ax0, ax1, ax2 in T.grid(1, 1, 1):
for ax3_0 in T.serial(1, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
for ax3_1 in T.thread_binding(1, thread="threadIdx.y"):
for ax3_2 in T.thread_binding(64, thread="threadIdx.x"):
for ax3_3 in T.vectorized(2):
with T.block("lv1637_shared"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n + ax1)
v2 = T.axis.spatial(1, ax2)
v3 = T.axis.spatial(128, ax3_0 * 128 + ax3_1 * 128 + ax3_2 * 2 + ax3_3)
T.reads(lv1637[v0, v1, v2, v3])
T.writes(lv1637_shared[v0, v1, v2, v3])
lv1637_shared[v0, v1, v2, v3] = lv1637[v0, v1, v2, v3]
for ax0_fused_ax1_fused_fused_2_init in range(1):
for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(2):
with T.block("NT_matmul_rf_init"):
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(128, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init)
v0 = T.axis.spatial(32, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2_init) // n)
v1 = T.axis.spatial(n, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2_init) % n)
T.reads()
T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1])
var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1] = T.float16(0)
for ax2_fused_u_fused_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax0, ax1, ax2_0, ax3 in T.grid(1, 1, 1, 2):
for ax2_1 in T.vectorized(1):
with T.block("lv1638_local"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n + ax1)
v2 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n + ax2_0 + ax2_1)
v3 = T.axis.spatial(128, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax3)
T.reads(lv1638[v0, v1, v2, v3])
T.writes(lv1638_local[v0, v1, v2, v3])
lv1638_local[v0, v1, v2, v3] = lv1638[v0, v1, v2, v3]
for ax0_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(1, 1):
for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(2):
with T.block("NT_matmul_rf_update"):
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(128, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1)
v0 = T.axis.spatial(32, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2) // n)
v1 = T.axis.spatial(n, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2) % n)
vax2_fused_u_fused_2, vax2_fused_u_fused_0 = T.axis.remap("RR", [ax2_fused_u_fused_2, ax2_fused_u_fused_0])
T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1], lv1637_shared[0, v0, 0, vax2_fused_u_fused_0 * 128 + vax2_fused_u_fused_2 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused], lv1638_local[0, v0, v1, vax2_fused_u_fused_0 * 128 + vax2_fused_u_fused_2 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused])
T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1])
var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1] = var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1] + lv1637_shared[0, v0, 0, vax2_fused_u_fused_0 * 128 + vax2_fused_u_fused_2 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused] * lv1638_local[0, v0, v1, vax2_fused_u_fused_0 * 128 + vax2_fused_u_fused_2 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused]
for ax2_ax3_fused_0 in T.thread_binding(1, thread="threadIdx.y"):
for ax0 in T.thread_binding(64, thread="threadIdx.x"):
for ax2_ax3_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax2_ax3_fused_1_1 in T.vectorized(1):
with T.block("NT_matmul_rf_init"):
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(64, ax0)
v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n)
v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n)
T.reads()
T.writes(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1])
var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1] = T.float16(0)
for ax1 in range(2):
with T.block("NT_matmul_rf_update"):
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n)
v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n)
T.reads(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1], var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, 0, v0, 0, v1])
T.writes(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1])
var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1] = var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1] + var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, 0, v0, 0, v1]
for ax1_ax2_fused_1 in range(1):
for ax1_ax2_fused_0 in T.thread_binding(1, thread="threadIdx.y"):
for ax0 in T.thread_binding(64, thread="threadIdx.x"):
with T.block("NT_matmul"):
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(64, ax0)
v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n)
v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n)
T.reads(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1])
T.writes(var_NT_matmul_intermediate_local[0, v0, 0, v1])
with T.init():
var_NT_matmul_intermediate_local[0, v0, 0, v1] = T.float16(0)
var_NT_matmul_intermediate_local[0, v0, 0, v1] = var_NT_matmul_intermediate_local[0, v0, 0, v1] + var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1]
for ax0_ax1_fused_0 in T.thread_binding(1, thread="threadIdx.y"):
for ax0_ax1_fused_1 in range(1):
with T.block("compute"):
v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n)
v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n)
T.reads(var_NT_matmul_intermediate_local[0, v0, 0, v1], lv1614[0, 0, 0, v1])
T.writes(var_compute_intermediate[0, v0, 0, v1])
var_compute_intermediate[0, v0, 0, v1] = T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_local[0, v0, 0, v1] * T.float16(0.088397790055248615), T.float16(-65504)), lv1614[0, 0, 0, v1]))
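# fused_fused_decode1_fused_NT_matmul5_cast2: dequantizes 4-bit weights packed
# eight to a uint32 in lv771, scaled per group of 32 by lv772, multiplies them
# against the (1, 1, 4096) activation lv3216, and casts the (1, 1, 32001)
# result to float32; a GEMV over the vocabulary dimension using rfactor and a
# cross-thread reduction over 8 threadIdx.x lanes.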
@T.prim_func(private=True)
def fused_fused_decode1_fused_NT_matmul5_cast2(lv771: T.Buffer((32001, 512), "uint32"), lv772: T.Buffer((32001, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 32001), "float32")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
# with T.block("root"):
var_NT_matmul_intermediate_local = T.alloc_buffer((1, 1, 32001), "float16", scope="local")
var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 32001), "float16", scope="local")
var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 32001), "float16", scope="local")
lv771_local = T.alloc_buffer((32001, 512), "uint32", scope="local")
lv3216_shared = T.alloc_buffer((1, 1, 4096), "float16", scope="shared")
for u_fused_ax0_fused_fused_0 in T.thread_binding(501, thread="blockIdx.x"):
for u_fused_ax0_fused_fused_1 in T.thread_binding(64, thread="threadIdx.y"):
for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"):
for ax0, ax1 in T.grid(1, 1):
for ax2_0 in T.serial(2, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
for ax2_1 in T.thread_binding(64, thread="threadIdx.y"):
for ax2_2 in T.thread_binding(8, thread="threadIdx.x"):
for ax2_3 in T.vectorized(4):
with T.block("lv3216_shared"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
v2 = T.axis.spatial(4096, ax2_0 * 2048 + ax2_1 * 32 + ax2_2 * 4 + ax2_3)
T.reads(lv3216[v0, v1, v2])
T.writes(lv3216_shared[v0, v1, v2])
lv3216_shared[v0, v1, v2] = lv3216[v0, v1, v2]
for u_fused_ax0_fused_fused_2_init in range(1):
for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2):
with T.block("NT_matmul_rf_init"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
T.where(u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init < 32001)
T.reads()
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0)
for ax1_0_fused_ax1_1_fused_0 in T.serial(64, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax0_0, ax1 in T.grid(1, 1):
for ax0_1 in T.vectorized(1):
with T.block("lv771_local"):
v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
T.where(u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 < 32001)
T.reads(lv771[v0, v1])
T.writes(lv771_local[v0, v1])
lv771_local[v0, v1] = lv771[v0, v1]
for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4):
for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2):
with T.block("NT_matmul_rf_update"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
T.where(u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2 < 32001)
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv3216_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv771_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv772[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv3216_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv771_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv772[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
for ax2_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
for ax0 in T.thread_binding(8, thread="threadIdx.x"):
for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax2_fused_1_1 in T.vectorized(1):
with T.block("NT_matmul_rf_init"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0)
v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
T.where(u_fused_ax0_fused_fused_0 * 64 + (ax2_fused_0 + (ax2_fused_1_0 + ax2_fused_1_1)) < 32001)
T.reads()
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0)
for ax1 in range(2):
with T.block("NT_matmul_rf_update"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
T.where(u_fused_ax0_fused_fused_0 * 64 + (ax2_fused_0 + (ax2_fused_1_0 + ax2_fused_1_1)) < 32001)
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
for ax1_fused_1 in range(1):
for ax1_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
for ax0 in T.thread_binding(8, thread="threadIdx.x"):
with T.block("NT_matmul"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0)
v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + ax1_fused_0 + ax1_fused_1)
T.where(u_fused_ax0_fused_fused_0 * 64 + (ax1_fused_0 + ax1_fused_1) < 32001)
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
T.writes(var_NT_matmul_intermediate_local[0, 0, v0])
with T.init():
var_NT_matmul_intermediate_local[0, 0, v0] = T.float16(0)
var_NT_matmul_intermediate_local[0, 0, v0] = var_NT_matmul_intermediate_local[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]
for ax0_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
for ax0_fused_1 in range(1):
with T.block("compute"):
v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + ax0_fused_0 + ax0_fused_1)
T.where(u_fused_ax0_fused_fused_0 * 64 + (ax0_fused_0 + ax0_fused_1) < 32001)
T.reads(var_NT_matmul_intermediate_local[0, 0, v0])
T.writes(p_output0_intermediate[0, 0, v0])
p_output0_intermediate[0, 0, v0] = T.Cast("float32", var_NT_matmul_intermediate_local[0, 0, v0])
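# fused_fused_decode1_take: embedding lookup for the single token id lv1611[0],
# dequantizing that row of the 4-bit table lv (scales lv1) on the fly into a
# (1, 4096) float16 vector.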
@T.prim_func(private=True)
def fused_fused_decode1_take(lv: T.Buffer((32001, 512), "uint32"), lv1: T.Buffer((32001, 128), "float16"), lv1611: T.Buffer((1,), "int32"), var_T_take_intermediate: T.Buffer((1, 4096), "float16")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"):
for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_take"):
v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + ax0_fused_1)
T.reads(lv[lv1611[0], v0 // 8], lv1611[0], lv1[lv1611[0], v0 // 32])
T.writes(var_T_take_intermediate[0, v0])
var_T_take_intermediate[0, v0] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv[lv1611[0], v0 // 8], T.Cast("uint32", v0 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv1[lv1611[0], v0 // 32]
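# fused_fused_decode1_take1: the same dequantize-and-take for a batch of n
# token ids, producing an (n, 4096) float16 embedding matrix.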
@T.prim_func(private=True)
def fused_fused_decode1_take1(lv775: T.Buffer((32001, 512), "uint32"), lv776: T.Buffer((32001, 128), "float16"), p_lv: T.handle, p_output0: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
lv = T.match_buffer(p_lv, (n,), "int32")
var_T_take_intermediate = T.match_buffer(p_output0, (n, 4096), "float16")
# with T.block("root"):
for ax0_ax1_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_take"):
v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 4096)
v1 = T.axis.spatial(4096, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 4096)
T.reads(lv775[lv[v0], v1 // 8], lv[v0], lv776[lv[v0], v1 // 32])
T.writes(var_T_take_intermediate[v0, v1])
var_T_take_intermediate[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv775[lv[v0], v1 // 8], T.Cast("uint32", v1 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv776[lv[v0], v1 // 32]
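# fused_fused_decode2_NT_matmul: dequantizes the (12288, 4096) 4-bit weight
# (lv779 / lv780) into shared memory and multiplies it with lv6 (1, n, 4096) to
# produce (1, n, 12288); plausibly the fused QKV projection on the prefill path
# (interpretation). Uses the same 32 x 64 tiling as the other prefill matmuls,
# with row padding masked on the final write.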
@T.prim_func(private=True)
def fused_fused_decode2_NT_matmul(lv779: T.Buffer((12288, 512), "uint32"), lv780: T.Buffer((12288, 128), "float16"), p_lv6: T.handle, p_output0: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
lv6 = T.match_buffer(p_lv6, (1, n, 4096), "float16")
var_NT_matmul_intermediate = T.match_buffer(p_output0, (1, n, 12288), "float16")
# with T.block("root"):
var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (n + 31) // 32 * 32, 12288), "float16", scope="local")
lv6_reindex_pad_shared = T.alloc_buffer((1, (n + 31) // 32 * 32, 4096), "float16", scope="shared")
p_output0_intermediate_reindex_shared = T.alloc_buffer((1, 12288, 4096), "float16", scope="shared")
for ax0_ax2_0_fused in T.thread_binding(192, thread="blockIdx.y"):
for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"):
for ax2_1 in T.thread_binding(1, thread="vthread.y"):
for ax1_1 in T.thread_binding(1, thread="vthread.x"):
for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax2_3_init, ax1_3_init in T.grid(4, 4):
with T.block("NT_matmul_init"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
v2 = T.axis.spatial(12288, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init)
T.reads()
T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0)
for ax3_0 in range(256):
for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(2):
for ax0_ax1_ax2_fused_3 in T.vectorized(2):
with T.block("lv6_reindex_pad_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
T.reads(lv6[v0, v1, v2])
T.writes(lv6_reindex_pad_shared[v0, v1, v2])
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
lv6_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv6[v0, v1, v2], T.float16(0))
for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(4):
for ax0_ax1_ax2_fused_3 in T.vectorized(2):
with T.block("p_output0_intermediate_reindex_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(12288, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
T.reads(lv779[v1, v2 // 8], lv780[v1, v2 // 32])
T.writes(p_output0_intermediate_reindex_shared[v0, v1, v2])
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
p_output0_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv779[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv780[v1, v2 // 32]
for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 4):
with T.block("NT_matmul_update"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
v2 = T.axis.spatial(12288, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3)
v3 = T.axis.reduce(4096, ax3_0 * 16 + ax3_1)
T.reads(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2], lv6_reindex_pad_shared[0, v1, v3], p_output0_intermediate_reindex_shared[0, v2, v3])
T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + lv6_reindex_pad_shared[0, v1, v3] * p_output0_intermediate_reindex_shared[0, v2, v3]
for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
for ax2_1_1 in T.vectorized(2):
with T.block("var_NT_matmul_intermediate_reindex_pad_local"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
v2 = T.axis.spatial(12288, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
T.reads(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
T.writes(var_NT_matmul_intermediate[0, v1, v2])
if v1 < n:
var_NT_matmul_intermediate[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]
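# fused_fused_decode2_NT_matmul6: single-token counterpart of the kernel above,
# a dequantizing GEMV from (1, 1, 4096) to (1, 1, 12288) whose rfactor
# reduction is finished across 8 threadIdx.x lanes.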
@T.prim_func(private=True)
def fused_fused_decode2_NT_matmul6(lv3: T.Buffer((12288, 512), "uint32"), lv4: T.Buffer((12288, 128), "float16"), lv1615: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 12288), "float16")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
# with T.block("root"):
var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 12288), "float16", scope="local")
var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 12288), "float16", scope="local")
lv3_local = T.alloc_buffer((12288, 512), "uint32", scope="local")
lv1615_shared = T.alloc_buffer((1, 1, 4096), "float16", scope="shared")
for u_fused_ax0_fused_fused_0 in T.thread_binding(192, thread="blockIdx.x"):
for u_fused_ax0_fused_fused_1 in T.thread_binding(64, thread="threadIdx.y"):
for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"):
for ax0, ax1 in T.grid(1, 1):
for ax2_0 in T.serial(2, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
for ax2_1 in T.thread_binding(64, thread="threadIdx.y"):
for ax2_2 in T.thread_binding(8, thread="threadIdx.x"):
for ax2_3 in T.vectorized(4):
with T.block("lv1615_shared"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
v2 = T.axis.spatial(4096, ax2_0 * 2048 + ax2_1 * 32 + ax2_2 * 4 + ax2_3)
T.reads(lv1615[v0, v1, v2])
T.writes(lv1615_shared[v0, v1, v2])
lv1615_shared[v0, v1, v2] = lv1615[v0, v1, v2]
for u_fused_ax0_fused_fused_2_init in range(1):
for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2):
with T.block("NT_matmul_rf_init"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
v0 = T.axis.spatial(12288, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
T.reads()
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0)
for ax1_0_fused_ax1_1_fused_0 in T.serial(64, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax0_0, ax1 in T.grid(1, 1):
for ax0_1 in T.vectorized(1):
with T.block("lv3_local"):
v0 = T.axis.spatial(12288, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
T.reads(lv3[v0, v1])
T.writes(lv3_local[v0, v1])
lv3_local[v0, v1] = lv3[v0, v1]
for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4):
for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2):
with T.block("NT_matmul_rf_update"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
v0 = T.axis.spatial(12288, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv1615_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv3_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv4[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv1615_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv3_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv4[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
for ax2_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
for ax0 in T.thread_binding(8, thread="threadIdx.x"):
for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax2_fused_1_1 in T.vectorized(1):
with T.block("NT_matmul_rf_init"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0)
v0 = T.axis.spatial(12288, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
T.reads()
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0)
for ax1 in range(2):
with T.block("NT_matmul_rf_update"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
v0 = T.axis.spatial(12288, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
for ax1_fused_1 in range(1):
for ax1_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
for ax0 in T.thread_binding(8, thread="threadIdx.x"):
with T.block("NT_matmul"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0)
v0 = T.axis.spatial(12288, u_fused_ax0_fused_fused_0 * 64 + ax1_fused_0 + ax1_fused_1)
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
T.writes(var_NT_matmul_intermediate[0, 0, v0])
with T.init():
var_NT_matmul_intermediate[0, 0, v0] = T.float16(0)
var_NT_matmul_intermediate[0, 0, v0] = var_NT_matmul_intermediate[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]
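# fused_fused_decode3_fused_NT_matmul2_add1: dequantizing (4096, 4096) NT
# matmul over lv41 (1, n, 4096) with the residual add of lv2 fused into the
# final write; a prefill-path projection plus skip connection (interpretation).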
@T.prim_func(private=True)
def fused_fused_decode3_fused_NT_matmul2_add1(lv784: T.Buffer((4096, 512), "uint32"), lv785: T.Buffer((4096, 128), "float16"), p_lv41: T.handle, p_lv2: T.handle, p_output0: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
lv41 = T.match_buffer(p_lv41, (1, n, 4096), "float16")
lv2 = T.match_buffer(p_lv2, (1, n, 4096), "float16")
p_output0_intermediate = T.match_buffer(p_output0, (1, n, 4096), "float16")
# with T.block("root"):
var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (n + 31) // 32 * 32, 4096), "float16", scope="local")
lv41_reindex_pad_shared = T.alloc_buffer((1, (n + 31) // 32 * 32, 4096), "float16", scope="shared")
p_output0_intermediate_reindex_shared = T.alloc_buffer((1, 4096, 4096), "float16", scope="shared")
for ax0_ax2_0_fused in T.thread_binding(64, thread="blockIdx.y"):
for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"):
for ax2_1 in T.thread_binding(1, thread="vthread.y"):
for ax1_1 in T.thread_binding(1, thread="vthread.x"):
for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax2_3_init, ax1_3_init in T.grid(4, 4):
with T.block("NT_matmul_init"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init)
T.reads()
T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0)
for ax3_0 in range(256):
for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(2):
for ax0_ax1_ax2_fused_3 in T.vectorized(2):
with T.block("lv41_reindex_pad_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
T.reads(lv41[v0, v1, v2])
T.writes(lv41_reindex_pad_shared[v0, v1, v2])
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
lv41_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv41[v0, v1, v2], T.float16(0))
for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(4):
for ax0_ax1_ax2_fused_3 in T.vectorized(2):
with T.block("p_output0_intermediate_reindex_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
T.reads(lv784[v1, v2 // 8], lv785[v1, v2 // 32])
T.writes(p_output0_intermediate_reindex_shared[v0, v1, v2])
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
p_output0_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv784[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv785[v1, v2 // 32]
for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 4):
with T.block("NT_matmul_update"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3)
v3 = T.axis.reduce(4096, ax3_0 * 16 + ax3_1)
T.reads(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2], lv41_reindex_pad_shared[0, v1, v3], p_output0_intermediate_reindex_shared[0, v2, v3])
T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + lv41_reindex_pad_shared[0, v1, v3] * p_output0_intermediate_reindex_shared[0, v2, v3]
for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
for ax2_1_1 in T.vectorized(2):
with T.block("var_NT_matmul_intermediate_reindex_pad_local"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
T.reads(lv2[0, v1, v2], var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
T.writes(p_output0_intermediate[0, v1, v2])
if v1 < n:
p_output0_intermediate[0, v1, v2] = lv2[0, v1, v2] + var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]
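# fused_fused_decode3_fused_NT_matmul8_add: single-token GEMV version of the
# kernel above; dequantizes lv15 / lv16, reduces over 4096 with rfactor, and
# adds lv1613 before writing the (1, 1, 4096) float16 output.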
@T.prim_func(private=True)
def fused_fused_decode3_fused_NT_matmul8_add(lv15: T.Buffer((4096, 512), "uint32"), lv16: T.Buffer((4096, 128), "float16"), lv14: T.Buffer((1, 1, 4096), "float16"), lv1613: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
# with T.block("root"):
var_NT_matmul_intermediate_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local")
var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16", scope="local")
var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 4096), "float16", scope="local")
lv15_local = T.alloc_buffer((4096, 512), "uint32", scope="local")
lv14_shared = T.alloc_buffer((1, 1, 4096), "float16", scope="shared")
for u_fused_ax0_fused_fused_0 in T.thread_binding(64, thread="blockIdx.x"):
for u_fused_ax0_fused_fused_1 in T.thread_binding(64, thread="threadIdx.y"):
for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"):
for ax0, ax1 in T.grid(1, 1):
for ax2_0 in T.serial(2, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
for ax2_1 in T.thread_binding(64, thread="threadIdx.y"):
for ax2_2 in T.thread_binding(8, thread="threadIdx.x"):
for ax2_3 in T.vectorized(4):
with T.block("lv14_shared"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
v2 = T.axis.spatial(4096, ax2_0 * 2048 + ax2_1 * 32 + ax2_2 * 4 + ax2_3)
T.reads(lv14[v0, v1, v2])
T.writes(lv14_shared[v0, v1, v2])
lv14_shared[v0, v1, v2] = lv14[v0, v1, v2]
for u_fused_ax0_fused_fused_2_init in range(1):
for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2):
with T.block("NT_matmul_rf_init"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
T.reads()
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0)
for ax1_0_fused_ax1_1_fused_0 in T.serial(64, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax0_0, ax1 in T.grid(1, 1):
for ax0_1 in T.vectorized(1):
with T.block("lv15_local"):
v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
T.reads(lv15[v0, v1])
T.writes(lv15_local[v0, v1])
lv15_local[v0, v1] = lv15[v0, v1]
for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4):
for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2):
with T.block("NT_matmul_rf_update"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv14_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv15_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv16[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv14_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv15_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv16[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
for ax2_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
for ax0 in T.thread_binding(8, thread="threadIdx.x"):
for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax2_fused_1_1 in T.vectorized(1):
with T.block("NT_matmul_rf_init"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0)
v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
T.reads()
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0)
for ax1 in range(2):
with T.block("NT_matmul_rf_update"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
for ax1_fused_1 in range(1):
for ax1_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
for ax0 in T.thread_binding(8, thread="threadIdx.x"):
with T.block("NT_matmul"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0)
v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax1_fused_0 + ax1_fused_1)
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
T.writes(var_NT_matmul_intermediate_local[0, 0, v0])
with T.init():
var_NT_matmul_intermediate_local[0, 0, v0] = T.float16(0)
var_NT_matmul_intermediate_local[0, 0, v0] = var_NT_matmul_intermediate_local[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]
for ax0_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
for ax0_fused_1 in range(1):
with T.block("T_add"):
v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax0_fused_0 + ax0_fused_1)
T.reads(lv1613[0, 0, v0], var_NT_matmul_intermediate_local[0, 0, v0])
T.writes(p_output0_intermediate[0, 0, v0])
p_output0_intermediate[0, 0, v0] = lv1613[0, 0, v0] + var_NT_matmul_intermediate_local[0, 0, v0]
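# fused_fused_decode4_NT_matmul3: dequantizing NT matmul of lv45 (1, n, 4096)
# with a (22016, 4096) 4-bit weight, producing (1, n, 22016); likely the fused
# MLP gate/up projection on the prefill path (interpretation).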
@T.prim_func(private=True)
def fused_fused_decode4_NT_matmul3(lv788: T.Buffer((22016, 512), "uint32"), lv789: T.Buffer((22016, 128), "float16"), p_lv45: T.handle, p_output0: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
lv45 = T.match_buffer(p_lv45, (1, n, 4096), "float16")
var_NT_matmul_intermediate = T.match_buffer(p_output0, (1, n, 22016), "float16")
# with T.block("root"):
var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (n + 31) // 32 * 32, 22016), "float16", scope="local")
lv45_reindex_pad_shared = T.alloc_buffer((1, (n + 31) // 32 * 32, 4096), "float16", scope="shared")
p_output0_intermediate_reindex_shared = T.alloc_buffer((1, 22016, 4096), "float16", scope="shared")
for ax0_ax2_0_fused in T.thread_binding(344, thread="blockIdx.y"):
for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"):
for ax2_1 in T.thread_binding(1, thread="vthread.y"):
for ax1_1 in T.thread_binding(1, thread="vthread.x"):
for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax2_3_init, ax1_3_init in T.grid(4, 4):
with T.block("NT_matmul_init"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
v2 = T.axis.spatial(22016, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init)
T.reads()
T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0)
for ax3_0 in range(256):
for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(2):
for ax0_ax1_ax2_fused_3 in T.vectorized(2):
with T.block("lv45_reindex_pad_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
T.reads(lv45[v0, v1, v2])
T.writes(lv45_reindex_pad_shared[v0, v1, v2])
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
lv45_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv45[v0, v1, v2], T.float16(0))
for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(4):
for ax0_ax1_ax2_fused_3 in T.vectorized(2):
with T.block("p_output0_intermediate_reindex_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(22016, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
T.reads(lv788[v1, v2 // 8], lv789[v1, v2 // 32])
T.writes(p_output0_intermediate_reindex_shared[v0, v1, v2])
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
p_output0_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv788[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv789[v1, v2 // 32]
for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 4):
with T.block("NT_matmul_update"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
v2 = T.axis.spatial(22016, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3)
v3 = T.axis.reduce(4096, ax3_0 * 16 + ax3_1)
T.reads(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2], lv45_reindex_pad_shared[0, v1, v3], p_output0_intermediate_reindex_shared[0, v2, v3])
T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + lv45_reindex_pad_shared[0, v1, v3] * p_output0_intermediate_reindex_shared[0, v2, v3]
for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
for ax2_1_1 in T.vectorized(2):
with T.block("var_NT_matmul_intermediate_reindex_pad_local"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
v2 = T.axis.spatial(22016, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
T.reads(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
T.writes(var_NT_matmul_intermediate[0, v1, v2])
if v1 < n:
var_NT_matmul_intermediate[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]
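# fused_fused_decode4_NT_matmul9: single-token GEMV counterpart of the kernel
# above, producing (1, 1, 22016) from the (1, 1, 4096) activation lv1654.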
@T.prim_func(private=True)
def fused_fused_decode4_NT_matmul9(lv19: T.Buffer((22016, 512), "uint32"), lv20: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
# with T.block("root"):
var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 22016), "float16", scope="local")
var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 22016), "float16", scope="local")
lv19_local = T.alloc_buffer((22016, 512), "uint32", scope="local")
lv1654_shared = T.alloc_buffer((1, 1, 4096), "float16", scope="shared")
for u_fused_ax0_fused_fused_0 in T.thread_binding(344, thread="blockIdx.x"):
for u_fused_ax0_fused_fused_1 in T.thread_binding(64, thread="threadIdx.y"):
for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"):
for ax0, ax1 in T.grid(1, 1):
for ax2_0 in T.serial(2, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
for ax2_1 in T.thread_binding(64, thread="threadIdx.y"):
for ax2_2 in T.thread_binding(8, thread="threadIdx.x"):
for ax2_3 in T.vectorized(4):
with T.block("lv1654_shared"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
v2 = T.axis.spatial(4096, ax2_0 * 2048 + ax2_1 * 32 + ax2_2 * 4 + ax2_3)
T.reads(lv1654[v0, v1, v2])
T.writes(lv1654_shared[v0, v1, v2])
lv1654_shared[v0, v1, v2] = lv1654[v0, v1, v2]
for u_fused_ax0_fused_fused_2_init in range(1):
for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2):
with T.block("NT_matmul_rf_init"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
T.reads()
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0)
for ax1_0_fused_ax1_1_fused_0 in T.serial(64, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax0_0, ax1 in T.grid(1, 1):
for ax0_1 in T.vectorized(1):
with T.block("lv19_local"):
v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
T.reads(lv19[v0, v1])
T.writes(lv19_local[v0, v1])
lv19_local[v0, v1] = lv19[v0, v1]
for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4):
for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2):
with T.block("NT_matmul_rf_update"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv1654_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv19_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv20[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv1654_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv19_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv20[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
for ax2_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
for ax0 in T.thread_binding(8, thread="threadIdx.x"):
for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax2_fused_1_1 in T.vectorized(1):
with T.block("NT_matmul_rf_init"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0)
v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
T.reads()
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0)
for ax1 in range(2):
with T.block("NT_matmul_rf_update"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
for ax1_fused_1 in range(1):
for ax1_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
for ax0 in T.thread_binding(8, thread="threadIdx.x"):
with T.block("NT_matmul"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0)
v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 64 + ax1_fused_0 + ax1_fused_1)
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
T.writes(var_NT_matmul_intermediate[0, 0, v0])
with T.init():
var_NT_matmul_intermediate[0, 0, v0] = T.float16(0)
var_NT_matmul_intermediate[0, 0, v0] = var_NT_matmul_intermediate[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]
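# Note: fused_fused_decode4_NT_matmul9 above is the single-token variant of the same
# quantized projection, a GEMV over the 4096-wide activation. The reduction is
# rfactored twice: first into 16 partial sums per output element (threadIdx.x lanes
# times the 2-wide vector lane), then into 8 across threads, before the final
# cross-thread sum is written to var_NT_matmul_intermediate.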
@T.prim_func(private=True)
def fused_fused_decode5_fused_NT_matmul10_add(lv23: T.Buffer((4096, 1376), "uint32"), lv24: T.Buffer((4096, 344), "float16"), lv22: T.Buffer((1, 1, 11008), "float16"), lv18: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
# with T.block("root"):
var_NT_matmul_intermediate_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local")
var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16", scope="local")
var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 4096), "float16", scope="local")
lv23_local = T.alloc_buffer((4096, 1376), "uint32", scope="local")
lv22_shared = T.alloc_buffer((1, 1, 11008), "float16", scope="shared")
for u_fused_ax0_fused_fused_0 in T.thread_binding(64, thread="blockIdx.x"):
for u_fused_ax0_fused_fused_1 in T.thread_binding(64, thread="threadIdx.y"):
for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"):
for ax0, ax1 in T.grid(1, 1):
for ax2_0 in T.serial(22, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
for ax2_1 in T.thread_binding(64, thread="threadIdx.y"):
for ax2_2 in T.thread_binding(8, thread="threadIdx.x"):
for ax2_3 in T.vectorized(1):
with T.block("lv22_shared"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
v2 = T.axis.spatial(11008, ax2_0 * 512 + ax2_1 * 8 + ax2_2 + ax2_3)
T.where((ax2_0 * 64 + ax2_1) * 8 + ax2_2 + ax2_3 < 11008)
T.reads(lv22[v0, v1, v2])
T.writes(lv22_shared[v0, v1, v2])
lv22_shared[v0, v1, v2] = lv22[v0, v1, v2]
for u_fused_ax0_fused_fused_2_init in range(1):
for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2):
with T.block("NT_matmul_rf_init"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
T.reads()
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0)
for ax1_0_fused_ax1_1_fused_0 in T.serial(172, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax0_0, ax1 in T.grid(1, 1):
for ax0_1 in T.vectorized(1):
with T.block("lv23_local"):
v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
v1 = T.axis.spatial(1376, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
T.reads(lv23[v0, v1])
T.writes(lv23_local[v0, v1])
lv23_local[v0, v1] = lv23[v0, v1]
for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4):
for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2):
with T.block("NT_matmul_rf_update"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv22_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv23_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv24[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv22_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv23_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv24[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32])
for ax2_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
for ax0 in T.thread_binding(8, thread="threadIdx.x"):
for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax2_fused_1_1 in T.vectorized(1):
with T.block("NT_matmul_rf_init"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0)
v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
T.reads()
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0)
for ax1 in range(2):
with T.block("NT_matmul_rf_update"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
for ax1_fused_1 in range(1):
for ax1_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
for ax0 in T.thread_binding(8, thread="threadIdx.x"):
with T.block("NT_matmul"):
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0)
v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax1_fused_0 + ax1_fused_1)
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0])
T.writes(var_NT_matmul_intermediate_local[0, 0, v0])
with T.init():
var_NT_matmul_intermediate_local[0, 0, v0] = T.float16(0)
var_NT_matmul_intermediate_local[0, 0, v0] = var_NT_matmul_intermediate_local[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]
for ax0_fused_0 in T.thread_binding(64, thread="threadIdx.y"):
for ax0_fused_1 in range(1):
with T.block("T_add"):
v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax0_fused_0 + ax0_fused_1)
T.reads(lv18[0, 0, v0], var_NT_matmul_intermediate_local[0, 0, v0])
T.writes(p_output0_intermediate[0, 0, v0])
p_output0_intermediate[0, 0, v0] = lv18[0, 0, v0] + var_NT_matmul_intermediate_local[0, 0, v0]
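# Note: the kernel above follows the same decode-time GEMV pattern for the
# 11008 -> 4096 projection and then fuses the residual connection, adding lv18
# elementwise in the trailing T_add block.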
@T.prim_func(private=True)
def fused_fused_decode5_fused_NT_matmul4_add1(lv792: T.Buffer((4096, 1376), "uint32"), lv793: T.Buffer((4096, 344), "float16"), p_lv791: T.handle, p_lv787: T.handle, p_output0: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
lv791 = T.match_buffer(p_lv791, (1, n, 11008), "float16")
lv787 = T.match_buffer(p_lv787, (1, n, 4096), "float16")
p_output0_intermediate = T.match_buffer(p_output0, (1, n, 4096), "float16")
# with T.block("root"):
var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (n + 31) // 32 * 32, 4096), "float16", scope="local")
lv791_reindex_pad_shared = T.alloc_buffer((1, (n + 31) // 32 * 32, 11008), "float16", scope="shared")
p_output0_intermediate_reindex_shared = T.alloc_buffer((1, 4096, 11008), "float16", scope="shared")
for ax0_ax2_0_fused in T.thread_binding(64, thread="blockIdx.y"):
for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"):
for ax2_1 in T.thread_binding(1, thread="vthread.y"):
for ax1_1 in T.thread_binding(1, thread="vthread.x"):
for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax2_3_init, ax1_3_init in T.grid(4, 4):
with T.block("NT_matmul_init"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init)
T.reads()
T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0)
for ax3_0 in range(688):
for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(2):
for ax0_ax1_ax2_fused_3 in T.vectorized(2):
with T.block("lv791_reindex_pad_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
v2 = T.axis.spatial(11008, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
T.reads(lv791[v0, v1, v2])
T.writes(lv791_reindex_pad_shared[v0, v1, v2])
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
lv791_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv791[v0, v1, v2], T.float16(0))
for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(4):
for ax0_ax1_ax2_fused_3 in T.vectorized(2):
with T.block("p_output0_intermediate_reindex_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
v2 = T.axis.spatial(11008, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
T.reads(lv792[v1, v2 // 8], lv793[v1, v2 // 32])
T.writes(p_output0_intermediate_reindex_shared[v0, v1, v2])
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
p_output0_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv792[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv793[v1, v2 // 32]
for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 4):
with T.block("NT_matmul_update"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3)
v3 = T.axis.reduce(11008, ax3_0 * 16 + ax3_1)
T.reads(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2], lv791_reindex_pad_shared[0, v1, v3], p_output0_intermediate_reindex_shared[0, v2, v3])
T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2])
var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + lv791_reindex_pad_shared[0, v1, v3] * p_output0_intermediate_reindex_shared[0, v2, v3]
for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
for ax2_1_1 in T.vectorized(2):
with T.block("var_NT_matmul_intermediate_reindex_pad_local"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
T.reads(lv787[0, v1, v2], var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
T.writes(p_output0_intermediate[0, v1, v2])
if v1 < n:
p_output0_intermediate[0, v1, v2] = lv787[0, v1, v2] + var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]
@T.prim_func(private=True)
def fused_min_max_triu_te_broadcast_to(p_output0: T.handle, n: T.int32):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
var_T_broadcast_to_intermediate = T.match_buffer(p_output0, (1, 1, n, n), "float16")
# with T.block("root"):
for ax0_ax1_fused_0 in T.thread_binding((n * n + 255) // 256, thread="blockIdx.x"):
for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_broadcast_to"):
v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // n)
v1 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % n)
T.where(ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 < n * n)
T.reads()
T.writes(var_T_broadcast_to_intermediate[0, 0, v0, v1])
var_T_broadcast_to_intermediate[0, 0, v0, v1] = T.Select(v0 < v1, T.float16(-65504), T.float16(65504))
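# Note: fused_min_max_triu_te_broadcast_to above materializes an n x n causal mask.
# T.Select(v0 < v1, -65504, 65504) places the most negative finite float16 on strictly
# future positions (column index greater than row index) and the most positive finite
# float16 elsewhere; presumably the attention kernels combine this with the score
# matrix (e.g. via an elementwise minimum) before softmax.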
@T.prim_func(private=True)
def fused_reshape2_squeeze(lv_1: T.Buffer((1, 1, 4096), "float16"), var_T_squeeze_intermediate: T.Buffer((1, 32, 128), "float16")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0_ax1_fused_0 in T.thread_binding(16, thread="blockIdx.x"):
for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_squeeze"):
v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 128)
v1 = T.axis.spatial(128, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 128)
T.reads(lv_1[0, 0, v0 * 128 + v1])
T.writes(var_T_squeeze_intermediate[0, v0, v1])
var_T_squeeze_intermediate[0, v0, v1] = lv_1[0, 0, v0 * 128 + v1]
@T.prim_func(private=True)
def fused_reshape2_transpose5(lv_0: T.Buffer((1, 1, 4096), "float16"), var_T_transpose_intermediate: T.Buffer((1, 32, 1, 128), "float16")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0_ax1_fused_0 in T.thread_binding(16, thread="blockIdx.x"):
for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_transpose"):
v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 128)
v1 = T.axis.spatial(128, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 128)
T.reads(lv_0[0, 0, v0 * 128 + v1])
T.writes(var_T_transpose_intermediate[0, v0, 0, v1])
var_T_transpose_intermediate[0, v0, 0, v1] = lv_0[0, 0, v0 * 128 + v1]
@T.prim_func(private=True)
def fused_softmax1_cast4(p_lv36: T.handle, p_output0: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n, m = T.int32(), T.int32()
lv36 = T.match_buffer(p_lv36, (1, 32, n, m))
var_compute_intermediate = T.match_buffer(p_output0, (1, 32, n, m), "float16")
# with T.block("root"):
T_softmax_maxelem_shared = T.alloc_buffer((1, 32, n), scope="shared")
T_softmax_expsum_shared = T.alloc_buffer((1, 32, n), scope="shared")
for ax0_ax1_fused in T.thread_binding(n * 32, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
for ax0, ax1, ax2_fused_0 in T.grid(1, 1, (m + 63) // 64):
for ax2_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
with T.block("T_softmax_maxelem"):
v0 = T.axis.spatial(32, ax0_ax1_fused // n + ax0)
v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax1)
v2 = T.axis.reduce(m, ax2_fused_0 * 64 + ax2_fused_1)
T.where(ax2_fused_0 * 64 + ax2_fused_1 < m)
T.reads(lv36[0, v0, v1, v2])
T.writes(T_softmax_maxelem_shared[0, v0, v1])
with T.init():
T_softmax_maxelem_shared[0, v0, v1] = T.float32(-3.4028234663852886e+38)
T_softmax_maxelem_shared[0, v0, v1] = T.max(T_softmax_maxelem_shared[0, v0, v1], lv36[0, v0, v1, v2])
for ax0, ax1, ax2_fused_0 in T.grid(1, 1, (m + 63) // 64):
for ax2_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
with T.block("T_softmax_expsum"):
v0 = T.axis.spatial(32, ax0_ax1_fused // n + ax0)
v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax1)
v2 = T.axis.reduce(m, ax2_fused_0 * 64 + ax2_fused_1)
T.where(ax2_fused_0 * 64 + ax2_fused_1 < m)
T.reads(lv36[0, v0, v1, v2], T_softmax_maxelem_shared[0, v0, v1])
T.writes(T_softmax_expsum_shared[0, v0, v1])
with T.init():
T_softmax_expsum_shared[0, v0, v1] = T.float32(0)
T_softmax_expsum_shared[0, v0, v1] = T_softmax_expsum_shared[0, v0, v1] + T.exp(lv36[0, v0, v1, v2] - T_softmax_maxelem_shared[0, v0, v1])
for ax2_0 in range((m + 63) // 64):
for ax2_1 in T.thread_binding(64, thread="threadIdx.x"):
with T.block("compute"):
v0 = T.axis.spatial(32, ax0_ax1_fused // n)
v1 = T.axis.spatial(n, ax0_ax1_fused % n)
v2 = T.axis.spatial(m, ax2_0 * 64 + ax2_1)
T.where(ax2_0 * 64 + ax2_1 < m)
T.reads(lv36[0, v0, v1, v2], T_softmax_maxelem_shared[0, v0, v1], T_softmax_expsum_shared[0, v0, v1])
T.writes(var_compute_intermediate[0, v0, v1, v2])
var_compute_intermediate[0, v0, v1, v2] = T.Cast("float16", T.exp(lv36[0, v0, v1, v2] - T_softmax_maxelem_shared[0, v0, v1]) / T_softmax_expsum_shared[0, v0, v1])
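# Note: fused_softmax1_cast4 above is a three-pass, numerically stable softmax over
# the last (length-m) axis: a cross-thread max reduction into T_softmax_maxelem_shared,
# an exp-sum reduction into T_softmax_expsum_shared, then a normalization pass that
# also casts the float32 result to float16. One thread block handles one
# (head, query row) pair, i.e. 32 * n blocks in total.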
@T.prim_func(private=True)
def fused_softmax_cast1(p_lv1645: T.handle, p_output0: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
lv1645 = T.match_buffer(p_lv1645, (1, 32, 1, n))
var_compute_intermediate = T.match_buffer(p_output0, (1, 32, 1, n), "float16")
# with T.block("root"):
T_softmax_maxelem_shared = T.alloc_buffer((1, 32, 1), scope="shared")
T_softmax_expsum_shared = T.alloc_buffer((1, 32, 1), scope="shared")
for ax0_fused in T.thread_binding(32, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
for ax0, ax1_fused_0 in T.grid(1, (n + 63) // 64):
for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
with T.block("T_softmax_maxelem"):
v0 = T.axis.spatial(32, ax0_fused + ax0)
v1 = T.axis.reduce(n, ax1_fused_0 * 64 + ax1_fused_1)
T.where(ax1_fused_0 * 64 + ax1_fused_1 < n)
T.reads(lv1645[0, v0, 0, v1])
T.writes(T_softmax_maxelem_shared[0, v0, 0])
with T.init():
T_softmax_maxelem_shared[0, v0, 0] = T.float32(-3.4028234663852886e+38)
T_softmax_maxelem_shared[0, v0, 0] = T.max(T_softmax_maxelem_shared[0, v0, 0], lv1645[0, v0, 0, v1])
for ax0, ax1_fused_0 in T.grid(1, (n + 63) // 64):
for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
with T.block("T_softmax_expsum"):
v0 = T.axis.spatial(32, ax0_fused + ax0)
v1 = T.axis.reduce(n, ax1_fused_0 * 64 + ax1_fused_1)
T.where(ax1_fused_0 * 64 + ax1_fused_1 < n)
T.reads(lv1645[0, v0, 0, v1], T_softmax_maxelem_shared[0, v0, 0])
T.writes(T_softmax_expsum_shared[0, v0, 0])
with T.init():
T_softmax_expsum_shared[0, v0, 0] = T.float32(0)
T_softmax_expsum_shared[0, v0, 0] = T_softmax_expsum_shared[0, v0, 0] + T.exp(lv1645[0, v0, 0, v1] - T_softmax_maxelem_shared[0, v0, 0])
for ax1_0 in range((n + 63) // 64):
for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
with T.block("compute"):
v0 = T.axis.spatial(32, ax0_fused)
v1 = T.axis.spatial(n, ax1_0 * 64 + ax1_1)
T.where(ax1_0 * 64 + ax1_1 < n)
T.reads(lv1645[0, v0, 0, v1], T_softmax_maxelem_shared[0, v0, 0], T_softmax_expsum_shared[0, v0, 0])
T.writes(var_compute_intermediate[0, v0, 0, v1])
var_compute_intermediate[0, v0, 0, v1] = T.Cast("float16", T.exp(lv1645[0, v0, 0, v1] - T_softmax_maxelem_shared[0, v0, 0]) / T_softmax_expsum_shared[0, v0, 0])
@T.prim_func(private=True)
def fused_split2_silu1_multiply1(p_lv3: T.handle, p_output0: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
lv3 = T.match_buffer(p_lv3, (1, n, 22016), "float16")
var_T_multiply_intermediate = T.match_buffer(p_output0, (1, n, 11008), "float16")
# with T.block("root"):
for ax0_ax1_fused_0 in T.thread_binding(n * 43, thread="blockIdx.x"):
for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_multiply_1"):
v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 11008)
v1 = T.axis.spatial(11008, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 11008)
T.reads(lv3[0, v0, v1:v1 + 11009])
T.writes(var_T_multiply_intermediate[0, v0, v1])
var_T_multiply_intermediate[0, v0, v1] = lv3[0, v0, v1] * T.sigmoid(lv3[0, v0, v1]) * lv3[0, v0, v1 + 11008]
@T.prim_func(private=True)
def fused_split_silu_multiply(lv164: T.Buffer((1, 1, 22016), "float16"), var_T_multiply_intermediate: T.Buffer((1, 1, 11008), "float16")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0_fused_0 in T.thread_binding(43, thread="blockIdx.x"):
for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_multiply_1"):
v0 = T.axis.spatial(11008, ax0_fused_0 * 256 + ax0_fused_1)
T.reads(lv164[0, 0, v0:v0 + 11009])
T.writes(var_T_multiply_intermediate[0, 0, v0])
var_T_multiply_intermediate[0, 0, v0] = lv164[0, 0, v0] * T.sigmoid(lv164[0, 0, v0]) * lv164[0, 0, v0 + 11008]
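# Note: the two kernels above implement the gated SiLU of the MLP. The 22016-wide
# input holds the gate and up projections back to back (2 x 11008 lanes), and each
# output lane is computed as
#   out[k] = x[k] * sigmoid(x[k]) * x[k + 11008]
# i.e. silu(gate) * up, with fused_split2_silu1_multiply1 covering the variable-length
# case and fused_split_silu_multiply the single-token case.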
@T.prim_func(private=True)
def fused_transpose7_reshape4(lv1648: T.Buffer((1, 32, 1, 128), "float16"), var_T_reshape_intermediate: T.Buffer((1, 1, 4096), "float16")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"):
for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_reshape"):
v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + ax0_fused_1)
T.reads(lv1648[0, v0 // 128, 0, v0 % 128])
T.writes(var_T_reshape_intermediate[0, 0, v0])
var_T_reshape_intermediate[0, 0, v0] = lv1648[0, v0 // 128, 0, v0 % 128]
@T.prim_func(private=True)
def matmul10(var_A: T.handle, var_B: T.handle, var_matmul: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n, m = T.int32(), T.int32()
A = T.match_buffer(var_A, (1, 32, n, m), "float16")
B = T.match_buffer(var_B, (1, 32, m, 128), "float16")
matmul = T.match_buffer(var_matmul, (1, 32, n, 128), "float16")
# with T.block("root"):
matmul_reindex_pad_local = T.alloc_buffer((32, (n + 31) // 32 * 32, 128), "float16", scope="local")
A_reindex_pad_shared = T.alloc_buffer((32, (n + 31) // 32 * 32, (m + 15) // 16 * 16), "float16", scope="shared")
B_reindex_pad_shared = T.alloc_buffer((32, 128, (m + 15) // 16 * 16), "float16", scope="shared")
for ax0_ax2_0_fused in T.thread_binding(64, thread="blockIdx.y"):
for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"):
for ax2_1 in T.thread_binding(1, thread="vthread.y"):
for ax1_1 in T.thread_binding(1, thread="vthread.x"):
for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax2_3_init, ax1_3_init in T.grid(4, 4):
with T.block("matmul_init"):
v0 = T.axis.spatial(32, ax0_ax2_0_fused // 2)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
v2 = T.axis.spatial(128, ax0_ax2_0_fused % 2 * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init)
T.reads()
T.writes(matmul_reindex_pad_local[v0, v1, v2])
matmul_reindex_pad_local[v0, v1, v2] = T.float16(0)
for ax3_0 in range((m + 15) // 16):
for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(2):
for ax0_ax1_ax2_fused_3 in T.vectorized(2):
with T.block("A_reindex_pad_shared"):
v0 = T.axis.spatial(32, ax0_ax2_0_fused // 2)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
v2 = T.axis.spatial((m + 15) // 16 * 16, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
T.reads(A[0, v0, v1, v2])
T.writes(A_reindex_pad_shared[v0, v1, v2])
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n and v2 < m, A[0, v0, v1, v2], T.float16(0))
for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(4):
for ax0_ax1_ax2_fused_3 in T.vectorized(2):
with T.block("B_reindex_pad_shared"):
v0 = T.axis.spatial(32, ax0_ax2_0_fused // 2)
v1 = T.axis.spatial(128, ax0_ax2_0_fused % 2 * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
v2 = T.axis.spatial((m + 15) // 16 * 16, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
T.reads(B[0, v0, v2, v1])
T.writes(B_reindex_pad_shared[v0, v1, v2])
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
B_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v2 < m, B[0, v0, v2, v1], T.float16(0))
for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 4):
with T.block("matmul_update"):
v0 = T.axis.spatial(32, ax0_ax2_0_fused // 2)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
v2 = T.axis.spatial(128, ax0_ax2_0_fused % 2 * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3)
v3 = T.axis.reduce((m + 15) // 16 * 16, ax3_0 * 16 + ax3_1)
T.reads(matmul_reindex_pad_local[v0, v1, v2], A_reindex_pad_shared[v0, v1, v3], B_reindex_pad_shared[v0, v2, v3])
T.writes(matmul_reindex_pad_local[v0, v1, v2])
matmul_reindex_pad_local[v0, v1, v2] = matmul_reindex_pad_local[v0, v1, v2] + A_reindex_pad_shared[v0, v1, v3] * B_reindex_pad_shared[v0, v2, v3]
for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
for ax2_1_1 in T.vectorized(2):
with T.block("matmul_reindex_pad_local"):
v0 = T.axis.spatial(32, ax0_ax2_0_fused // 2 + ax0)
v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
v2 = T.axis.spatial(128, ax0_ax2_0_fused % 2 * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
T.reads(matmul_reindex_pad_local[v0, v1, v2])
T.writes(matmul[0, v0, v1, v2])
if v1 < n:
matmul[0, v0, v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
@T.prim_func(private=True)
def matmul9(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((1, 32, 1, 128), "float16")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
A = T.match_buffer(var_A, (1, 32, 1, n), "float16")
B = T.match_buffer(var_B, (1, 32, n, 128), "float16")
# with T.block("root"):
matmul_rf_local = T.alloc_buffer((16, 1, 32, 1, 128), "float16", scope="local")
for ax0_ax1_fused_0 in T.thread_binding(256, thread="blockIdx.x"):
for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.x"):
for ax2_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
with T.block("matmul_rf_init"):
vax2_fused_1 = T.axis.spatial(16, ax2_fused_1)
v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1) // 128)
v1 = T.axis.spatial(128, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1) % 128)
T.reads()
T.writes(matmul_rf_local[vax2_fused_1, 0, v0, 0, v1])
matmul_rf_local[vax2_fused_1, 0, v0, 0, v1] = T.float16(0)
for ax2_fused_0, u in T.grid((n + 15) // 16, 1):
with T.block("matmul_rf_update"):
vax2_fused_1 = T.axis.spatial(16, ax2_fused_1)
v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1) // 128)
v1 = T.axis.spatial(128, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1) % 128)
vax2_fused_0 = T.axis.reduce((n + 15) // 16, ax2_fused_0)
T.where(ax2_fused_0 * 16 + ax2_fused_1 < n)
T.reads(matmul_rf_local[vax2_fused_1, 0, v0, 0, v1], A[0, v0, 0, vax2_fused_0 * 16 + vax2_fused_1], B[0, v0, vax2_fused_0 * 16 + vax2_fused_1, v1])
T.writes(matmul_rf_local[vax2_fused_1, 0, v0, 0, v1])
matmul_rf_local[vax2_fused_1, 0, v0, 0, v1] = matmul_rf_local[vax2_fused_1, 0, v0, 0, v1] + A[0, v0, 0, vax2_fused_0 * 16 + vax2_fused_1] * B[0, v0, vax2_fused_0 * 16 + vax2_fused_1, v1]
for ax1_ax2_fused in T.thread_binding(16, thread="threadIdx.x"):
for ax0 in T.thread_binding(16, thread="threadIdx.y"):
with T.block("matmul"):
vax2_fused_1 = T.axis.reduce(16, ax0)
v0 = T.axis.spatial(32, ax0_ax1_fused_0 // 8)
v1 = T.axis.spatial(128, ax0_ax1_fused_0 % 8 * 16 + ax1_ax2_fused)
T.reads(matmul_rf_local[vax2_fused_1, 0, v0, 0, v1])
T.writes(matmul[0, v0, 0, v1])
with T.init():
matmul[0, v0, 0, v1] = T.float16(0)
matmul[0, v0, 0, v1] = matmul[0, v0, 0, v1] + matmul_rf_local[vax2_fused_1, 0, v0, 0, v1]
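# Note: matmul9 above multiplies a (1, 32, 1, n) row by a (1, 32, n, 128) matrix,
# presumably attention scores against the cached values at decode time. The length-n
# reduction is rfactored 16 ways across threadIdx.y, with the ragged tail guarded by
# T.where(ax2_fused_0 * 16 + ax2_fused_1 < n), and the 16 partial sums are combined
# in the final "matmul" block.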
@T.prim_func(private=True)
def reshape(A: T.Buffer((1, 1), "int32"), T_reshape: T.Buffer((1,), "int32")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_reshape"):
v0 = T.axis.spatial(1, 0)
T.where(ax0_fused_0 * 256 + ax0_fused_1 < 1)
T.reads(A[0, 0])
T.writes(T_reshape[0])
T_reshape[0] = A[0, 0]
@T.prim_func(private=True)
def reshape1(A: T.Buffer((1, 4096), "float16"), T_reshape: T.Buffer((1, 1, 4096), "float16")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"):
for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_reshape"):
v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + ax0_fused_1)
T.reads(A[0, v0])
T.writes(T_reshape[0, 0, v0])
T_reshape[0, 0, v0] = A[0, v0]
@T.prim_func(private=True)
def reshape3(var_A: T.handle, var_T_reshape: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
A = T.match_buffer(var_A, (n, 32, 128), "float16")
T_reshape = T.match_buffer(var_T_reshape, (1, n, 32, 128), "float16")
# with T.block("root"):
for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_reshape"):
v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) // 4096)
v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 4096 // 128)
v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 128)
T.reads(A[v0, v1, v2])
T.writes(T_reshape[0, v0, v1, v2])
T_reshape[0, v0, v1, v2] = A[v0, v1, v2]
@T.prim_func(private=True)
def reshape5(var_A: T.handle, var_T_reshape: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
A = T.match_buffer(var_A, (1, n), "int32")
T_reshape = T.match_buffer(var_T_reshape, (n,), "int32")
# with T.block("root"):
for ax0_fused_0 in T.thread_binding((n + 255) // 256, thread="blockIdx.x"):
for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_reshape"):
v0 = T.axis.spatial(n, ax0_fused_0 * 256 + ax0_fused_1)
T.where(ax0_fused_0 * 256 + ax0_fused_1 < n)
T.reads(A[0, v0])
T.writes(T_reshape[v0])
T_reshape[v0] = A[0, v0]
@T.prim_func(private=True)
def reshape6(var_A: T.handle, var_T_reshape: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
A = T.match_buffer(var_A, (n, 4096), "float16")
T_reshape = T.match_buffer(var_T_reshape, (1, n, 4096), "float16")
# with T.block("root"):
for ax0_ax1_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_reshape"):
v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 4096)
v1 = T.axis.spatial(4096, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 4096)
T.reads(A[v0, v1])
T.writes(T_reshape[0, v0, v1])
T_reshape[0, v0, v1] = A[v0, v1]
@T.prim_func(private=True)
def reshape7(var_A: T.handle, var_T_reshape: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
A = T.match_buffer(var_A, (1, n, 4096), "float16")
T_reshape = T.match_buffer(var_T_reshape, (1, n, 32, 128), "float16")
# with T.block("root"):
for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_reshape"):
v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) // 4096)
v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 4096 // 128)
v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 128)
T.reads(A[0, v0, v1 * 128 + v2])
T.writes(T_reshape[0, v0, v1, v2])
T_reshape[0, v0, v1, v2] = A[0, v0, v1 * 128 + v2]
@T.prim_func(private=True)
def reshape8(var_A: T.handle, var_T_reshape: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
A = T.match_buffer(var_A, (1, n, 32, 128), "float16")
T_reshape = T.match_buffer(var_T_reshape, (1, n, 4096), "float16")
# with T.block("root"):
for ax0_ax1_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_reshape"):
v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 4096)
v1 = T.axis.spatial(4096, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 4096)
T.reads(A[0, v0, v1 // 128, v1 % 128])
T.writes(T_reshape[0, v0, v1])
T_reshape[0, v0, v1] = A[0, v0, v1 // 128, v1 % 128]
@T.prim_func(private=True)
def rms_norm(var_A: T.handle, B: T.Buffer((4096,), "float16"), var_rms_norm: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
A = T.match_buffer(var_A, (1, n, 4096), "float16")
rms_norm_1 = T.match_buffer(var_rms_norm, (1, n, 4096), "float16")
# with T.block("root"):
Ared_temp_shared = T.alloc_buffer((1, n), scope="shared")
Ared_temp_rf_local = T.alloc_buffer((64, 1, n), scope="local")
for ax0_fused in T.thread_binding(n, thread="blockIdx.x"):
for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
with T.block("Ared_temp_rf_init"):
vax1_fused_1, v0 = T.axis.remap("SS", [ax1_fused_1, ax0_fused])
T.reads()
T.writes(Ared_temp_rf_local[vax1_fused_1, 0, v0])
Ared_temp_rf_local[vax1_fused_1, 0, v0] = T.float32(0)
for ax1_fused_0, u in T.grid(64, 1):
with T.block("Ared_temp_rf_update"):
vax1_fused_1, v0, vax1_fused_0 = T.axis.remap("SSR", [ax1_fused_1, ax0_fused, ax1_fused_0])
T.reads(Ared_temp_rf_local[vax1_fused_1, 0, v0], A[0, v0, vax1_fused_0 * 64 + vax1_fused_1])
T.writes(Ared_temp_rf_local[vax1_fused_1, 0, v0])
Ared_temp_rf_local[vax1_fused_1, 0, v0] = Ared_temp_rf_local[vax1_fused_1, 0, v0] + T.Cast("float32", A[0, v0, vax1_fused_0 * 64 + vax1_fused_1]) * T.Cast("float32", A[0, v0, vax1_fused_0 * 64 + vax1_fused_1])
for ax1_fused in range(1):
for ax0 in T.thread_binding(64, thread="threadIdx.x"):
with T.block("Ared_temp"):
vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax0_fused])
T.reads(Ared_temp_rf_local[vax1_fused_1, 0, v0])
T.writes(Ared_temp_shared[0, v0])
with T.init():
Ared_temp_shared[0, v0] = T.float32(0)
Ared_temp_shared[0, v0] = Ared_temp_shared[0, v0] + Ared_temp_rf_local[vax1_fused_1, 0, v0]
for ax0_fused_0 in range(64):
for ax0_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
with T.block("rms_norm"):
v0 = T.axis.spatial(n, ax0_fused)
v1 = T.axis.spatial(4096, ax0_fused_0 * 64 + ax0_fused_1)
T.reads(B[v1], A[0, v0, v1], Ared_temp_shared[0, v0])
T.writes(rms_norm_1[0, v0, v1])
rms_norm_1[0, v0, v1] = T.Cast("float16", T.Cast("float32", B[v1]) * (T.Cast("float32", A[0, v0, v1]) / T.sqrt(Ared_temp_shared[0, v0] * T.float32(0.000244140625) + T.float32(1.0000000000000001e-05))))
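# Note: rms_norm above accumulates the sum of squared activations per row with an
# rfactor reduction (64 partial sums in Ared_temp_rf_local, then one shared total).
# The 0.000244140625 factor is 1 / 4096, so each element is divided by
# sqrt(mean(x^2) + 1e-05) and scaled by the weight vector B, with the arithmetic done
# in float32 and the result cast back to float16.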
@T.prim_func(private=True)
def rms_norm1(A: T.Buffer((1, 1, 4096), "float16"), B: T.Buffer((4096,), "float16"), rms_norm: T.Buffer((1, 1, 4096), "float16")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
# with T.block("root"):
Ared_temp_shared = T.alloc_buffer((1, 1), scope="shared")
Ared_temp_rf_local = T.alloc_buffer((64, 1, 1), scope="local")
for ax0_fused in T.thread_binding(1, thread="blockIdx.x"):
for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
with T.block("Ared_temp_rf_init"):
vax1_fused_1 = T.axis.spatial(64, ax1_fused_1)
v0 = T.axis.spatial(1, 0)
T.reads()
T.writes(Ared_temp_rf_local[vax1_fused_1, 0, 0])
Ared_temp_rf_local[vax1_fused_1, 0, 0] = T.float32(0)
for ax1_fused_0, u in T.grid(64, 1):
with T.block("Ared_temp_rf_update"):
vax1_fused_1 = T.axis.spatial(64, ax1_fused_1)
v0 = T.axis.spatial(1, 0)
vax1_fused_0 = T.axis.reduce(64, ax1_fused_0)
T.reads(Ared_temp_rf_local[vax1_fused_1, 0, 0], A[0, 0, vax1_fused_0 * 64 + vax1_fused_1])
T.writes(Ared_temp_rf_local[vax1_fused_1, 0, 0])
Ared_temp_rf_local[vax1_fused_1, 0, 0] = Ared_temp_rf_local[vax1_fused_1, 0, 0] + T.Cast("float32", A[0, 0, vax1_fused_0 * 64 + vax1_fused_1]) * T.Cast("float32", A[0, 0, vax1_fused_0 * 64 + vax1_fused_1])
for ax1_fused in range(1):
for ax0 in T.thread_binding(64, thread="threadIdx.x"):
with T.block("Ared_temp"):
vax1_fused_1 = T.axis.reduce(64, ax0)
v0 = T.axis.spatial(1, 0)
T.reads(Ared_temp_rf_local[vax1_fused_1, 0, 0])
T.writes(Ared_temp_shared[0, 0])
with T.init():
Ared_temp_shared[0, 0] = T.float32(0)
Ared_temp_shared[0, 0] = Ared_temp_shared[0, 0] + Ared_temp_rf_local[vax1_fused_1, 0, 0]
for ax0_fused_0 in range(64):
for ax0_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
with T.block("rms_norm"):
v0 = T.axis.spatial(4096, ax0_fused_0 * 64 + ax0_fused_1)
T.reads(B[v0], A[0, 0, v0], Ared_temp_shared[0, 0])
T.writes(rms_norm[0, 0, v0])
rms_norm[0, 0, v0] = T.Cast("float16", T.Cast("float32", B[v0]) * (T.Cast("float32", A[0, 0, v0]) / T.sqrt(Ared_temp_shared[0, 0] * T.float32(0.000244140625) + T.float32(1.0000000000000001e-05))))
@T.prim_func(private=True)
def rotary_embedding(var_A: T.handle, B: T.Buffer((2048, 128), "float16"), C: T.Buffer((2048, 128), "float16"), var_rotary: T.handle, m: T.int32):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
A = T.match_buffer(var_A, (1, n, 32, 128), "float16")
rotary = T.match_buffer(var_rotary, (1, n, 32, 128), "float16")
# with T.block("root"):
for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("rotary"):
v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) // 4096)
v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 4096 // 128)
v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 128)
T.reads(B[v0 + (m - n), v2], A[0, v0, v1, v2 + -64:v2 + -64 + 129], C[v0 + (m - n), v2])
T.writes(rotary[0, v0, v1, v2])
rotary[0, v0, v1, v2] = B[v0 + (m - n), v2] * A[0, v0, v1, v2] + C[v0 + (m - n), v2] * T.Select(64 <= v2, A[0, v0, v1, v2 + -64], A[0, v0, v1, v2 + 64] * T.float16(-1))
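# Note: rotary_embedding above applies rotary position embedding with half-rotation.
# B and C appear to hold the cos and sin tables (cf. split_rotary below); for lane v2
# of a 128-wide head at absolute position v0 + (m - n) the update is effectively
#   out = cos * x[v2] + sin * (x[v2 - 64] if v2 >= 64 else -x[v2 + 64])
# so each lane is paired with the lane 64 positions away in the same head.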
@T.prim_func(private=True)
def slice(var_A: T.handle, slice_1: T.Buffer((1, 1, 4096), "float16")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
A = T.match_buffer(var_A, (1, n, 4096), "float16")
# with T.block("root"):
for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"):
for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("slice"):
v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + ax0_fused_1)
T.reads(A[0, n - 1, v0])
T.writes(slice_1[0, 0, v0])
slice_1[0, 0, v0] = A[0, n - 1, v0]
@T.prim_func(private=True)
def slice1(A: T.Buffer((1, 1, 4096), "float16"), slice: T.Buffer((1, 1, 4096), "float16")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"):
for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("slice"):
v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + ax0_fused_1)
T.reads(A[0, 0, v0])
T.writes(slice[0, 0, v0])
slice[0, 0, v0] = A[0, 0, v0]
@T.prim_func(private=True)
def softmax2(A: T.Buffer((1, 1, 32001), "float32"), T_softmax_norm: T.Buffer((1, 1, 32001), "float32")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
# with T.block("root"):
T_softmax_maxelem_shared = T.alloc_buffer((1, 1), scope="shared")
T_softmax_expsum_shared = T.alloc_buffer((1, 1), scope="shared")
for ax0_fused in T.thread_binding(1, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
for ax0, ax1_fused_0 in T.grid(1, 501):
for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
with T.block("T_softmax_maxelem"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.reduce(32001, ax1_fused_0 * 64 + ax1_fused_1)
T.where(ax1_fused_0 * 64 + ax1_fused_1 < 32001)
T.reads(A[0, 0, v1])
T.writes(T_softmax_maxelem_shared[0, 0])
with T.init():
T_softmax_maxelem_shared[0, 0] = T.float32(-3.4028234663852886e+38)
T_softmax_maxelem_shared[0, 0] = T.max(T_softmax_maxelem_shared[0, 0], A[0, 0, v1])
for ax0, ax1_fused_0 in T.grid(1, 501):
for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
with T.block("T_softmax_expsum"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.reduce(32001, ax1_fused_0 * 64 + ax1_fused_1)
T.where(ax1_fused_0 * 64 + ax1_fused_1 < 32001)
T.reads(A[0, 0, v1], T_softmax_maxelem_shared[0, 0])
T.writes(T_softmax_expsum_shared[0, 0])
with T.init():
T_softmax_expsum_shared[0, 0] = T.float32(0)
T_softmax_expsum_shared[0, 0] = T_softmax_expsum_shared[0, 0] + T.exp(A[0, 0, v1] - T_softmax_maxelem_shared[0, 0])
for ax1_0 in range(501):
for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
with T.block("T_softmax_norm"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(32001, ax1_0 * 64 + ax1_1)
T.where(ax1_0 * 64 + ax1_1 < 32001)
T.reads(A[0, 0, v1], T_softmax_maxelem_shared[0, 0], T_softmax_expsum_shared[0, 0])
T.writes(T_softmax_norm[0, 0, v1])
T.block_attr({"axis": 2})
T_softmax_norm[0, 0, v1] = T.exp(A[0, 0, v1] - T_softmax_maxelem_shared[0, 0]) / T_softmax_expsum_shared[0, 0]
@T.prim_func(private=True)
def split1(var_A: T.handle, var_T_split: T.handle, var_T_split_1: T.handle, var_T_split_2: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
A = T.match_buffer(var_A, (1, n, 12288), "float16")
T_split = T.match_buffer(var_T_split, (1, n, 4096), "float16")
T_split_1 = T.match_buffer(var_T_split_1, (1, n, 4096), "float16")
T_split_2 = T.match_buffer(var_T_split_2, (1, n, 4096), "float16")
# with T.block("root"):
for ax0_ax1_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_split"):
v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 4096)
v1 = T.axis.spatial(4096, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 4096)
T.reads(A[0, v0, v1])
T.writes(T_split[0, v0, v1])
T_split[0, v0, v1] = A[0, v0, v1]
for ax0_ax1_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_split_1"):
v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 4096)
v1 = T.axis.spatial(4096, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 4096)
T.reads(A[0, v0, v1 + 4096])
T.writes(T_split_1[0, v0, v1])
T_split_1[0, v0, v1] = A[0, v0, v1 + 4096]
for ax0_ax1_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_split_2"):
v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 4096)
v1 = T.axis.spatial(4096, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 4096)
T.reads(A[0, v0, v1 + 8192])
T.writes(T_split_2[0, v0, v1])
T_split_2[0, v0, v1] = A[0, v0, v1 + 8192]
@T.prim_func
def split_rotary(A: T.Buffer((1, 1, 12288), "float16"), cos: T.Buffer((2048, 128), "float16"), sin: T.Buffer((2048, 128), "float16"), T_split: T.Buffer((1, 1, 4096), "float16"), T_split_1: T.Buffer((1, 1, 4096), "float16"), T_split_2: T.Buffer((1, 1, 4096), "float16"), n: T.int32):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"):
for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_split"):
v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + ax0_fused_1)
T.reads(A[0, 0, v0], A[0, 0, v0 + 4096], A[0, 0, v0 + 8192])
T.writes(T_split[0, 0, v0], T_split_1[0, 0, v0], T_split_2[0, 0, v0])
T_split[0, 0, v0] = cos[n - 1, v0 % 128] * A[0, 0, v0] + sin[n - 1, v0 % 128] * T.Select(64 <= v0 % 128, A[0, 0, v0 + -64], A[0, 0, v0 + 64] * T.float16(-1))
T_split_1[0, 0, v0] = cos[n - 1, v0 % 128] * A[0, 0, v0 + 4096] + sin[n - 1, v0 % 128] * T.Select(64 <= v0 % 128, A[0, 0, v0 + 4032], A[0, 0, v0 + 4160] * T.float16(-1))
T_split_2[0, 0, v0] = A[0, 0, v0 + 8192]
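# Note: split_rotary above fuses the decode-time split of the 12288-wide fused
# projection into three 4096-wide slices with the rotary embedding: the first two
# slices (presumably query and key) are rotated using the cos/sin tables at position
# n - 1, the newly generated token, while the third slice is copied through unchanged.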
@T.prim_func(private=True)
def squeeze1(var_A: T.handle, var_T_squeeze: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
A = T.match_buffer(var_A, (1, n, 32, 128), "float16")
T_squeeze = T.match_buffer(var_T_squeeze, (n, 32, 128), "float16")
# with T.block("root"):
for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_squeeze"):
v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) // 4096)
v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 4096 // 128)
v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 128)
T.reads(A[0, v0, v1, v2])
T.writes(T_squeeze[v0, v1, v2])
T_squeeze[v0, v1, v2] = A[0, v0, v1, v2]
@T.prim_func(private=True)
def transpose6(var_A: T.handle, var_T_transpose: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
A = T.match_buffer(var_A, (1, n, 32, 128), "float16")
T_transpose = T.match_buffer(var_T_transpose, (1, 32, n, 128), "float16")
# with T.block("root"):
for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_transpose"):
v0 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) // (128 * n))
v1 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % (128 * n) // 128)
v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 128)
T.reads(A[0, v1, v0, v2])
T.writes(T_transpose[0, v0, v1, v2])
T_transpose[0, v0, v1, v2] = A[0, v1, v0, v2]
@T.prim_func(private=True)
def transpose8(var_A: T.handle, var_T_transpose: T.handle):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
n = T.int32()
A = T.match_buffer(var_A, (1, 32, n, 128), "float16")
T_transpose = T.match_buffer(var_T_transpose, (1, n, 32, 128), "float16")
# with T.block("root"):
for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
with T.block("T_transpose"):
v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) // 4096)
v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 4096 // 128)
v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 128)
T.reads(A[0, v1, v0, v2])
T.writes(T_transpose[0, v0, v1, v2])
T_transpose[0, v0, v1, v2] = A[0, v1, v0, v2]
@R.function
def create_kv_cache() -> R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object):
R.func_attr({"tir_var_upper_bound": {"m": 2048, "n": 2048}})
with R.dataflow():
lv3221: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3222: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3223: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3224: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3225: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3226: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3227: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3228: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3229: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3230: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3231: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3232: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3233: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3234: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3235: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3236: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3237: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3238: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3239: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3240: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3241: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3242: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3243: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3244: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3245: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3246: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3247: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3248: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3249: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3250: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3251: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3252: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3253: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3254: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3255: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3256: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3257: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3258: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3259: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3260: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3261: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3262: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3263: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3264: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3265: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3266: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3267: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3268: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3269: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3270: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3271: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3272: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3273: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3274: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3275: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3276: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3277: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3278: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3279: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3280: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3281: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3282: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3283: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
lv3284: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
gv2: R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object) = lv3221, lv3222, lv3223, lv3224, lv3225, lv3226, lv3227, lv3228, lv3229, lv3230, lv3231, lv3232, lv3233, lv3234, lv3235, lv3236, lv3237, lv3238, lv3239, lv3240, lv3241, lv3242, lv3243, lv3244, lv3245, lv3246, lv3247, lv3248, lv3249, lv3250, lv3251, lv3252, lv3253, lv3254, lv3255, lv3256, lv3257, lv3258, lv3259, lv3260, lv3261, lv3262, lv3263, lv3264, lv3265, lv3266, lv3267, lv3268, lv3269, lv3270, lv3271, lv3272, lv3273, lv3274, lv3275, lv3276, lv3277, lv3278, lv3279, lv3280, lv3281, lv3282, lv3283, lv3284
R.output(gv2)
return gv2
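# decode: one autoregressive step. Given the newest token id, the current sequence
# length n, the KV caches, and the quantized weights, it runs all 32 transformer
# layers (appending this step's keys/values to the caches) and returns the
# (1, 1, 32001) float32 next-token logits together with the cache handles.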
@R.function
def decode(input_ids1: R.Tensor((1, 1), dtype="int32"), all_seq_len: R.Shape(["n"]), kv_cache: R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object), params: R.Tuple(R.Tensor((32001, 512), dtype="uint32"), R.Tensor((32001, 128), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), 
dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), 
R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), 
dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((32001, 512), dtype="uint32"), R.Tensor((32001, 128), dtype="float16"), R.Tensor((2048, 128), dtype="float16"), R.Tensor((2048, 128), dtype="float16"))) -> R.Tuple(R.Tensor((1, 1, 32001), dtype="float32"), R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, 
R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object)):
n = T.int64()
R.func_attr({"tir_var_upper_bound": {"m": 2048, "n": 2048}})
cls = Module
with R.dataflow():
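# Token embedding: dequantize the group-quantized table (params[0] packed uint32
# weights, params[1] float16 scales), gather the row for the single input token,
# and reshape it to the (1, 1, 4096) hidden state.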
lv1611 = R.call_tir(cls.reshape, (input_ids1,), out_sinfo=R.Tensor((1,), dtype="int32"))
lv: R.Tensor((32001, 512), dtype="uint32") = params[0]
lv1: R.Tensor((32001, 128), dtype="float16") = params[1]
lv1_1 = R.call_tir(cls.fused_fused_decode1_take, (lv, lv1, lv1611), out_sinfo=R.Tensor((1, 4096), dtype="float16"))
lv1613 = R.call_tir(cls.reshape1, (lv1_1,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
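# Decode-time attention "mask": a (1, 1, 1, n) row filled with 65504 (float16 max),
# which appears to leave all n cached positions unmasked for the single query token.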
lv1614 = R.call_tir(cls.full, R.tuple(), out_sinfo=R.Tensor((1, 1, 1, n), dtype="float16"))
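# ---- Layer 0 ----
# Input RMSNorm (params[10]), then the fused dequantize + NT-matmul QKV projection
# (params[2]/params[3], 12288 = 3 * 4096 outputs). split_rotary splits the result
# into Q, K, V and applies the rotary position embedding at position n - 1 using the
# precomputed cos/sin tables (params[325], params[326]).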
lv460: R.Tensor((4096,), dtype="float16") = params[10]
lv1615 = R.call_tir(cls.rms_norm1, (lv1613, lv460), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv3: R.Tensor((12288, 512), dtype="uint32") = params[2]
lv4: R.Tensor((12288, 128), dtype="float16") = params[3]
lv_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv3, lv4, lv1615), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv464: R.Tensor((2048, 128), dtype="float16") = params[325]
lv465: R.Tensor((2048, 128), dtype="float16") = params[326]
lv_2 = R.call_tir(cls.split_rotary, (lv_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
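# Reshape into heads: Q -> (1, 32, 1, 128) for attention; K and V -> (1, 32, 128),
# the per-step layout appended to the cache.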
lv6: R.Tensor((1, 1, 4096), dtype="float16") = lv_2[0]
lv7 = R.call_tir(cls.fused_reshape2_transpose5, (lv6,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv8: R.Tensor((1, 1, 4096), dtype="float16") = lv_2[1]
lv9 = R.call_tir(cls.fused_reshape2_squeeze, (lv8,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv10: R.Tensor((1, 1, 4096), dtype="float16") = lv_2[2]
lv11 = R.call_tir(cls.fused_reshape2_squeeze, (lv10,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
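# Append this step's K/V to the layer-0 caches (kv_cache[0], kv_cache[1]), view the
# full cached history as (n, 32, 128), and bring it into the (1, 32, n, 128) layout
# used by the attention matmuls.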
lv1629: R.Object = kv_cache[0]
lv1630: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1629, lv9, sinfo_args=(R.Object,))
lv1631: R.Object = kv_cache[1]
lv1632: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1631, lv11, sinfo_args=(R.Object,))
lv1633: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1630, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv1634: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1632, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv1635 = R.call_tir(cls.reshape3, (lv1633,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1636 = R.call_tir(cls.reshape3, (lv1634,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1638 = R.call_tir(cls.transpose6, (lv1635,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1639 = R.call_tir(cls.transpose6, (lv1636,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
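# Attention: scaled Q.K^T scores clamped against the mask (float32), softmax back to
# float16, then the weighted sum over the cached values, reshaped to (1, 1, 4096).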
lv12 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv7, lv1638, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv13 = R.call_tir(cls.fused_softmax_cast1, (lv12,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv1648 = R.call_tir(cls.matmul9, (lv13, lv1639), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv14 = R.call_tir(cls.fused_transpose7_reshape4, (lv1648,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
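# Output projection (params[4]/params[5], fused dequantize + matmul) with the
# residual added back onto the pre-attention hidden state lv1613.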
lv15: R.Tensor((4096, 512), dtype="uint32") = params[4]
lv16: R.Tensor((4096, 128), dtype="float16") = params[5]
lv_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv15, lv16, lv14, lv1613), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
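# Post-attention RMSNorm (params[11]), then the MLP: fused gate/up projection
# (params[6]/params[7], 22016 = 2 * 11008), SiLU-gated multiply, and the down
# projection (params[8]/params[9]) with the residual added back, giving the
# layer-0 output.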
lv469: R.Tensor((4096,), dtype="float16") = params[11]
lv1654 = R.call_tir(cls.rms_norm1, (lv_3, lv469), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv19: R.Tensor((22016, 512), dtype="uint32") = params[6]
lv20: R.Tensor((22016, 128), dtype="float16") = params[7]
lv1_2 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv19, lv20, lv1654), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv22 = R.call_tir(cls.fused_split_silu_multiply, (lv1_2,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv23: R.Tensor((4096, 1376), dtype="uint32") = params[8]
lv24: R.Tensor((4096, 344), dtype="float16") = params[9]
lv1_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv23, lv24, lv22, lv_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
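# Layers 1..31 repeat the same attention + MLP pattern, each consuming the next ten
# weight tensors in params and the next pair of kv_cache entries.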
lv476: R.Tensor((4096,), dtype="float16") = params[20]
lv1665 = R.call_tir(cls.rms_norm1, (lv1_3, lv476), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv27: R.Tensor((12288, 512), dtype="uint32") = params[12]
lv28: R.Tensor((12288, 128), dtype="float16") = params[13]
lv2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv27, lv28, lv1665), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv4_1 = R.call_tir(cls.split_rotary, (lv2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv30: R.Tensor((1, 1, 4096), dtype="float16") = lv4_1[0]
lv31 = R.call_tir(cls.fused_reshape2_transpose5, (lv30,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv32: R.Tensor((1, 1, 4096), dtype="float16") = lv4_1[1]
lv33 = R.call_tir(cls.fused_reshape2_squeeze, (lv32,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv34: R.Tensor((1, 1, 4096), dtype="float16") = lv4_1[2]
lv35 = R.call_tir(cls.fused_reshape2_squeeze, (lv34,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv1679: R.Object = kv_cache[2]
lv1680: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1679, lv33, sinfo_args=(R.Object,))
lv1681: R.Object = kv_cache[3]
lv1682: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1681, lv35, sinfo_args=(R.Object,))
lv1683: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1680, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv1684: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1682, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv1685 = R.call_tir(cls.reshape3, (lv1683,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1686 = R.call_tir(cls.reshape3, (lv1684,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1688 = R.call_tir(cls.transpose6, (lv1685,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1689 = R.call_tir(cls.transpose6, (lv1686,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv36 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv31, lv1688, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv37 = R.call_tir(cls.fused_softmax_cast1, (lv36,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv1698 = R.call_tir(cls.matmul9, (lv37, lv1689), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv38 = R.call_tir(cls.fused_transpose7_reshape4, (lv1698,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv39: R.Tensor((4096, 512), dtype="uint32") = params[14]
lv40: R.Tensor((4096, 128), dtype="float16") = params[15]
lv2_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv39, lv40, lv38, lv1_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv483: R.Tensor((4096,), dtype="float16") = params[21]
lv1704 = R.call_tir(cls.rms_norm1, (lv2_1, lv483), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv43: R.Tensor((22016, 512), dtype="uint32") = params[16]
lv44: R.Tensor((22016, 128), dtype="float16") = params[17]
lv3_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv43, lv44, lv1704), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv46 = R.call_tir(cls.fused_split_silu_multiply, (lv3_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv47: R.Tensor((4096, 1376), dtype="uint32") = params[18]
lv48: R.Tensor((4096, 344), dtype="float16") = params[19]
lv3_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv47, lv48, lv46, lv2_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv490: R.Tensor((4096,), dtype="float16") = params[30]
lv1715 = R.call_tir(cls.rms_norm1, (lv3_2, lv490), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv51: R.Tensor((12288, 512), dtype="uint32") = params[22]
lv52: R.Tensor((12288, 128), dtype="float16") = params[23]
lv4_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv51, lv52, lv1715), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv8_1 = R.call_tir(cls.split_rotary, (lv4_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv54: R.Tensor((1, 1, 4096), dtype="float16") = lv8_1[0]
lv55 = R.call_tir(cls.fused_reshape2_transpose5, (lv54,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv56: R.Tensor((1, 1, 4096), dtype="float16") = lv8_1[1]
lv57 = R.call_tir(cls.fused_reshape2_squeeze, (lv56,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv58: R.Tensor((1, 1, 4096), dtype="float16") = lv8_1[2]
lv59 = R.call_tir(cls.fused_reshape2_squeeze, (lv58,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv1729: R.Object = kv_cache[4]
lv1730: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1729, lv57, sinfo_args=(R.Object,))
lv1731: R.Object = kv_cache[5]
lv1732: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1731, lv59, sinfo_args=(R.Object,))
lv1733: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1730, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv1734: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1732, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv1735 = R.call_tir(cls.reshape3, (lv1733,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1736 = R.call_tir(cls.reshape3, (lv1734,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1738 = R.call_tir(cls.transpose6, (lv1735,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1739 = R.call_tir(cls.transpose6, (lv1736,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv60 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv55, lv1738, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv61 = R.call_tir(cls.fused_softmax_cast1, (lv60,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv1748 = R.call_tir(cls.matmul9, (lv61, lv1739), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv62 = R.call_tir(cls.fused_transpose7_reshape4, (lv1748,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv63: R.Tensor((4096, 512), dtype="uint32") = params[24]
lv64: R.Tensor((4096, 128), dtype="float16") = params[25]
lv4_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv63, lv64, lv62, lv3_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv497: R.Tensor((4096,), dtype="float16") = params[31]
lv1754 = R.call_tir(cls.rms_norm1, (lv4_3, lv497), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv67: R.Tensor((22016, 512), dtype="uint32") = params[26]
lv68: R.Tensor((22016, 128), dtype="float16") = params[27]
lv5 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv67, lv68, lv1754), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv70 = R.call_tir(cls.fused_split_silu_multiply, (lv5,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv71: R.Tensor((4096, 1376), dtype="uint32") = params[28]
lv72: R.Tensor((4096, 344), dtype="float16") = params[29]
lv5_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv71, lv72, lv70, lv4_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv504: R.Tensor((4096,), dtype="float16") = params[40]
lv1765 = R.call_tir(cls.rms_norm1, (lv5_1, lv504), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv75: R.Tensor((12288, 512), dtype="uint32") = params[32]
lv76: R.Tensor((12288, 128), dtype="float16") = params[33]
lv6_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv75, lv76, lv1765), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv12_1 = R.call_tir(cls.split_rotary, (lv6_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv78: R.Tensor((1, 1, 4096), dtype="float16") = lv12_1[0]
lv79 = R.call_tir(cls.fused_reshape2_transpose5, (lv78,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv80: R.Tensor((1, 1, 4096), dtype="float16") = lv12_1[1]
lv81 = R.call_tir(cls.fused_reshape2_squeeze, (lv80,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv82: R.Tensor((1, 1, 4096), dtype="float16") = lv12_1[2]
lv83 = R.call_tir(cls.fused_reshape2_squeeze, (lv82,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv1779: R.Object = kv_cache[6]
lv1780: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1779, lv81, sinfo_args=(R.Object,))
lv1781: R.Object = kv_cache[7]
lv1782: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1781, lv83, sinfo_args=(R.Object,))
lv1783: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1780, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv1784: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1782, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv1785 = R.call_tir(cls.reshape3, (lv1783,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1786 = R.call_tir(cls.reshape3, (lv1784,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1788 = R.call_tir(cls.transpose6, (lv1785,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1789 = R.call_tir(cls.transpose6, (lv1786,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv84 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv79, lv1788, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv85 = R.call_tir(cls.fused_softmax_cast1, (lv84,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv1798 = R.call_tir(cls.matmul9, (lv85, lv1789), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv86 = R.call_tir(cls.fused_transpose7_reshape4, (lv1798,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv87: R.Tensor((4096, 512), dtype="uint32") = params[34]
lv88: R.Tensor((4096, 128), dtype="float16") = params[35]
lv6_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv87, lv88, lv86, lv5_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv511: R.Tensor((4096,), dtype="float16") = params[41]
lv1804 = R.call_tir(cls.rms_norm1, (lv6_2, lv511), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv91: R.Tensor((22016, 512), dtype="uint32") = params[36]
lv92: R.Tensor((22016, 128), dtype="float16") = params[37]
lv7_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv91, lv92, lv1804), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv94 = R.call_tir(cls.fused_split_silu_multiply, (lv7_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv95: R.Tensor((4096, 1376), dtype="uint32") = params[38]
lv96: R.Tensor((4096, 344), dtype="float16") = params[39]
lv7_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv95, lv96, lv94, lv6_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv518: R.Tensor((4096,), dtype="float16") = params[50]
lv1815 = R.call_tir(cls.rms_norm1, (lv7_2, lv518), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv99: R.Tensor((12288, 512), dtype="uint32") = params[42]
lv100: R.Tensor((12288, 128), dtype="float16") = params[43]
lv8_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv99, lv100, lv1815), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv16_1 = R.call_tir(cls.split_rotary, (lv8_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv102: R.Tensor((1, 1, 4096), dtype="float16") = lv16_1[0]
lv103 = R.call_tir(cls.fused_reshape2_transpose5, (lv102,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv104: R.Tensor((1, 1, 4096), dtype="float16") = lv16_1[1]
lv105 = R.call_tir(cls.fused_reshape2_squeeze, (lv104,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv106: R.Tensor((1, 1, 4096), dtype="float16") = lv16_1[2]
lv107 = R.call_tir(cls.fused_reshape2_squeeze, (lv106,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv1829: R.Object = kv_cache[8]
lv1830: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1829, lv105, sinfo_args=(R.Object,))
lv1831: R.Object = kv_cache[9]
lv1832: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1831, lv107, sinfo_args=(R.Object,))
lv1833: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1830, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv1834: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1832, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv1835 = R.call_tir(cls.reshape3, (lv1833,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1836 = R.call_tir(cls.reshape3, (lv1834,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1838 = R.call_tir(cls.transpose6, (lv1835,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1839 = R.call_tir(cls.transpose6, (lv1836,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv108 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv103, lv1838, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv109 = R.call_tir(cls.fused_softmax_cast1, (lv108,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv1848 = R.call_tir(cls.matmul9, (lv109, lv1839), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv110 = R.call_tir(cls.fused_transpose7_reshape4, (lv1848,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv111: R.Tensor((4096, 512), dtype="uint32") = params[44]
lv112: R.Tensor((4096, 128), dtype="float16") = params[45]
lv8_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv111, lv112, lv110, lv7_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv525: R.Tensor((4096,), dtype="float16") = params[51]
lv1854 = R.call_tir(cls.rms_norm1, (lv8_3, lv525), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv115: R.Tensor((22016, 512), dtype="uint32") = params[46]
lv116: R.Tensor((22016, 128), dtype="float16") = params[47]
lv9_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv115, lv116, lv1854), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv118 = R.call_tir(cls.fused_split_silu_multiply, (lv9_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv119: R.Tensor((4096, 1376), dtype="uint32") = params[48]
lv120: R.Tensor((4096, 344), dtype="float16") = params[49]
lv9_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv119, lv120, lv118, lv8_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv532: R.Tensor((4096,), dtype="float16") = params[60]
lv1865 = R.call_tir(cls.rms_norm1, (lv9_2, lv532), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv123: R.Tensor((12288, 512), dtype="uint32") = params[52]
lv124: R.Tensor((12288, 128), dtype="float16") = params[53]
lv10_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv123, lv124, lv1865), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv20_1 = R.call_tir(cls.split_rotary, (lv10_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv126: R.Tensor((1, 1, 4096), dtype="float16") = lv20_1[0]
lv127 = R.call_tir(cls.fused_reshape2_transpose5, (lv126,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv128: R.Tensor((1, 1, 4096), dtype="float16") = lv20_1[1]
lv129 = R.call_tir(cls.fused_reshape2_squeeze, (lv128,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv130: R.Tensor((1, 1, 4096), dtype="float16") = lv20_1[2]
lv131 = R.call_tir(cls.fused_reshape2_squeeze, (lv130,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv1879: R.Object = kv_cache[10]
lv1880: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1879, lv129, sinfo_args=(R.Object,))
lv1881: R.Object = kv_cache[11]
lv1882: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1881, lv131, sinfo_args=(R.Object,))
lv1883: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1880, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv1884: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1882, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv1885 = R.call_tir(cls.reshape3, (lv1883,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1886 = R.call_tir(cls.reshape3, (lv1884,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1888 = R.call_tir(cls.transpose6, (lv1885,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1889 = R.call_tir(cls.transpose6, (lv1886,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv132 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv127, lv1888, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv133 = R.call_tir(cls.fused_softmax_cast1, (lv132,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv1898 = R.call_tir(cls.matmul9, (lv133, lv1889), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv134 = R.call_tir(cls.fused_transpose7_reshape4, (lv1898,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv135: R.Tensor((4096, 512), dtype="uint32") = params[54]
lv136: R.Tensor((4096, 128), dtype="float16") = params[55]
lv10_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv135, lv136, lv134, lv9_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv539: R.Tensor((4096,), dtype="float16") = params[61]
lv1904 = R.call_tir(cls.rms_norm1, (lv10_2, lv539), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv139: R.Tensor((22016, 512), dtype="uint32") = params[56]
lv140: R.Tensor((22016, 128), dtype="float16") = params[57]
lv11_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv139, lv140, lv1904), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv142 = R.call_tir(cls.fused_split_silu_multiply, (lv11_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv143: R.Tensor((4096, 1376), dtype="uint32") = params[58]
lv144: R.Tensor((4096, 344), dtype="float16") = params[59]
lv11_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv143, lv144, lv142, lv10_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv546: R.Tensor((4096,), dtype="float16") = params[70]
lv1915 = R.call_tir(cls.rms_norm1, (lv11_2, lv546), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv147: R.Tensor((12288, 512), dtype="uint32") = params[62]
lv148: R.Tensor((12288, 128), dtype="float16") = params[63]
lv12_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv147, lv148, lv1915), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv24_1 = R.call_tir(cls.split_rotary, (lv12_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv150: R.Tensor((1, 1, 4096), dtype="float16") = lv24_1[0]
lv151 = R.call_tir(cls.fused_reshape2_transpose5, (lv150,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv152: R.Tensor((1, 1, 4096), dtype="float16") = lv24_1[1]
lv153 = R.call_tir(cls.fused_reshape2_squeeze, (lv152,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv154: R.Tensor((1, 1, 4096), dtype="float16") = lv24_1[2]
lv155 = R.call_tir(cls.fused_reshape2_squeeze, (lv154,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv1929: R.Object = kv_cache[12]
lv1930: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1929, lv153, sinfo_args=(R.Object,))
lv1931: R.Object = kv_cache[13]
lv1932: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1931, lv155, sinfo_args=(R.Object,))
lv1933: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1930, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv1934: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1932, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv1935 = R.call_tir(cls.reshape3, (lv1933,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1936 = R.call_tir(cls.reshape3, (lv1934,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1938 = R.call_tir(cls.transpose6, (lv1935,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1939 = R.call_tir(cls.transpose6, (lv1936,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv156 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv151, lv1938, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv157 = R.call_tir(cls.fused_softmax_cast1, (lv156,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv1948 = R.call_tir(cls.matmul9, (lv157, lv1939), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv158 = R.call_tir(cls.fused_transpose7_reshape4, (lv1948,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv159: R.Tensor((4096, 512), dtype="uint32") = params[64]
lv160: R.Tensor((4096, 128), dtype="float16") = params[65]
lv12_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv159, lv160, lv158, lv11_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv553: R.Tensor((4096,), dtype="float16") = params[71]
lv1954 = R.call_tir(cls.rms_norm1, (lv12_3, lv553), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv163: R.Tensor((22016, 512), dtype="uint32") = params[66]
lv164: R.Tensor((22016, 128), dtype="float16") = params[67]
lv13_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv163, lv164, lv1954), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv166 = R.call_tir(cls.fused_split_silu_multiply, (lv13_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv167: R.Tensor((4096, 1376), dtype="uint32") = params[68]
lv168: R.Tensor((4096, 344), dtype="float16") = params[69]
lv13_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv167, lv168, lv166, lv12_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv560: R.Tensor((4096,), dtype="float16") = params[80]
lv1965 = R.call_tir(cls.rms_norm1, (lv13_2, lv560), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv171: R.Tensor((12288, 512), dtype="uint32") = params[72]
lv172: R.Tensor((12288, 128), dtype="float16") = params[73]
lv14_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv171, lv172, lv1965), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv28_1 = R.call_tir(cls.split_rotary, (lv14_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv174: R.Tensor((1, 1, 4096), dtype="float16") = lv28_1[0]
lv175 = R.call_tir(cls.fused_reshape2_transpose5, (lv174,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv176: R.Tensor((1, 1, 4096), dtype="float16") = lv28_1[1]
lv177 = R.call_tir(cls.fused_reshape2_squeeze, (lv176,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv178: R.Tensor((1, 1, 4096), dtype="float16") = lv28_1[2]
lv179 = R.call_tir(cls.fused_reshape2_squeeze, (lv178,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv1979: R.Object = kv_cache[14]
lv1980: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1979, lv177, sinfo_args=(R.Object,))
lv1981: R.Object = kv_cache[15]
lv1982: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1981, lv179, sinfo_args=(R.Object,))
lv1983: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1980, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv1984: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1982, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv1985 = R.call_tir(cls.reshape3, (lv1983,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1986 = R.call_tir(cls.reshape3, (lv1984,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1988 = R.call_tir(cls.transpose6, (lv1985,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1989 = R.call_tir(cls.transpose6, (lv1986,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv180 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv175, lv1988, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv181 = R.call_tir(cls.fused_softmax_cast1, (lv180,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv1998 = R.call_tir(cls.matmul9, (lv181, lv1989), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv182 = R.call_tir(cls.fused_transpose7_reshape4, (lv1998,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv183: R.Tensor((4096, 512), dtype="uint32") = params[74]
lv184: R.Tensor((4096, 128), dtype="float16") = params[75]
lv14_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv183, lv184, lv182, lv13_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv567: R.Tensor((4096,), dtype="float16") = params[81]
lv2004 = R.call_tir(cls.rms_norm1, (lv14_2, lv567), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv187: R.Tensor((22016, 512), dtype="uint32") = params[76]
lv188: R.Tensor((22016, 128), dtype="float16") = params[77]
lv15_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv187, lv188, lv2004), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv190 = R.call_tir(cls.fused_split_silu_multiply, (lv15_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv191: R.Tensor((4096, 1376), dtype="uint32") = params[78]
lv192: R.Tensor((4096, 344), dtype="float16") = params[79]
lv15_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv191, lv192, lv190, lv14_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv574: R.Tensor((4096,), dtype="float16") = params[90]
lv2015 = R.call_tir(cls.rms_norm1, (lv15_2, lv574), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv195: R.Tensor((12288, 512), dtype="uint32") = params[82]
lv196: R.Tensor((12288, 128), dtype="float16") = params[83]
lv16_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv195, lv196, lv2015), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv32_1 = R.call_tir(cls.split_rotary, (lv16_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv198: R.Tensor((1, 1, 4096), dtype="float16") = lv32_1[0]
lv199 = R.call_tir(cls.fused_reshape2_transpose5, (lv198,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv200: R.Tensor((1, 1, 4096), dtype="float16") = lv32_1[1]
lv201 = R.call_tir(cls.fused_reshape2_squeeze, (lv200,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv202: R.Tensor((1, 1, 4096), dtype="float16") = lv32_1[2]
lv203 = R.call_tir(cls.fused_reshape2_squeeze, (lv202,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv2029: R.Object = kv_cache[16]
lv2030: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2029, lv201, sinfo_args=(R.Object,))
lv2031: R.Object = kv_cache[17]
lv2032: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2031, lv203, sinfo_args=(R.Object,))
lv2033: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2030, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2034: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2032, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2035 = R.call_tir(cls.reshape3, (lv2033,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2036 = R.call_tir(cls.reshape3, (lv2034,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2038 = R.call_tir(cls.transpose6, (lv2035,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv2039 = R.call_tir(cls.transpose6, (lv2036,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv204 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv199, lv2038, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv205 = R.call_tir(cls.fused_softmax_cast1, (lv204,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv2048 = R.call_tir(cls.matmul9, (lv205, lv2039), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv206 = R.call_tir(cls.fused_transpose7_reshape4, (lv2048,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv207: R.Tensor((4096, 512), dtype="uint32") = params[84]
lv208: R.Tensor((4096, 128), dtype="float16") = params[85]
lv16_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv207, lv208, lv206, lv15_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv581: R.Tensor((4096,), dtype="float16") = params[91]
lv2054 = R.call_tir(cls.rms_norm1, (lv16_3, lv581), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv211: R.Tensor((22016, 512), dtype="uint32") = params[86]
lv212: R.Tensor((22016, 128), dtype="float16") = params[87]
lv17 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv211, lv212, lv2054), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv214 = R.call_tir(cls.fused_split_silu_multiply, (lv17,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv215: R.Tensor((4096, 1376), dtype="uint32") = params[88]
lv216: R.Tensor((4096, 344), dtype="float16") = params[89]
lv17_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv215, lv216, lv214, lv16_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
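# Transformer decoder layer (KV cache slots 18/19).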
lv588: R.Tensor((4096,), dtype="float16") = params[100]
lv2065 = R.call_tir(cls.rms_norm1, (lv17_1, lv588), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv219: R.Tensor((12288, 512), dtype="uint32") = params[92]
lv220: R.Tensor((12288, 128), dtype="float16") = params[93]
lv18 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv219, lv220, lv2065), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv36_1 = R.call_tir(cls.split_rotary, (lv18, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv222: R.Tensor((1, 1, 4096), dtype="float16") = lv36_1[0]
lv223 = R.call_tir(cls.fused_reshape2_transpose5, (lv222,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv224: R.Tensor((1, 1, 4096), dtype="float16") = lv36_1[1]
lv225 = R.call_tir(cls.fused_reshape2_squeeze, (lv224,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv226: R.Tensor((1, 1, 4096), dtype="float16") = lv36_1[2]
lv227 = R.call_tir(cls.fused_reshape2_squeeze, (lv226,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv2079: R.Object = kv_cache[18]
lv2080: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2079, lv225, sinfo_args=(R.Object,))
lv2081: R.Object = kv_cache[19]
lv2082: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2081, lv227, sinfo_args=(R.Object,))
lv2083: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2080, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2084: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2082, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2085 = R.call_tir(cls.reshape3, (lv2083,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2086 = R.call_tir(cls.reshape3, (lv2084,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2088 = R.call_tir(cls.transpose6, (lv2085,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv2089 = R.call_tir(cls.transpose6, (lv2086,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv228 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv223, lv2088, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv229 = R.call_tir(cls.fused_softmax_cast1, (lv228,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv2098 = R.call_tir(cls.matmul9, (lv229, lv2089), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv230 = R.call_tir(cls.fused_transpose7_reshape4, (lv2098,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv231: R.Tensor((4096, 512), dtype="uint32") = params[94]
lv232: R.Tensor((4096, 128), dtype="float16") = params[95]
lv18_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv231, lv232, lv230, lv17_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv595: R.Tensor((4096,), dtype="float16") = params[101]
lv2104 = R.call_tir(cls.rms_norm1, (lv18_1, lv595), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv235: R.Tensor((22016, 512), dtype="uint32") = params[96]
lv236: R.Tensor((22016, 128), dtype="float16") = params[97]
lv19_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv235, lv236, lv2104), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv238 = R.call_tir(cls.fused_split_silu_multiply, (lv19_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv239: R.Tensor((4096, 1376), dtype="uint32") = params[98]
lv240: R.Tensor((4096, 344), dtype="float16") = params[99]
lv19_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv239, lv240, lv238, lv18_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
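# Transformer decoder layer (KV cache slots 20/21).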
lv602: R.Tensor((4096,), dtype="float16") = params[110]
lv2115 = R.call_tir(cls.rms_norm1, (lv19_2, lv602), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv243: R.Tensor((12288, 512), dtype="uint32") = params[102]
lv244: R.Tensor((12288, 128), dtype="float16") = params[103]
lv20_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv243, lv244, lv2115), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv40_1 = R.call_tir(cls.split_rotary, (lv20_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv246: R.Tensor((1, 1, 4096), dtype="float16") = lv40_1[0]
lv247 = R.call_tir(cls.fused_reshape2_transpose5, (lv246,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv248: R.Tensor((1, 1, 4096), dtype="float16") = lv40_1[1]
lv249 = R.call_tir(cls.fused_reshape2_squeeze, (lv248,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv250: R.Tensor((1, 1, 4096), dtype="float16") = lv40_1[2]
lv251 = R.call_tir(cls.fused_reshape2_squeeze, (lv250,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv2129: R.Object = kv_cache[20]
lv2130: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2129, lv249, sinfo_args=(R.Object,))
lv2131: R.Object = kv_cache[21]
lv2132: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2131, lv251, sinfo_args=(R.Object,))
lv2133: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2130, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2134: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2132, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2135 = R.call_tir(cls.reshape3, (lv2133,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2136 = R.call_tir(cls.reshape3, (lv2134,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2138 = R.call_tir(cls.transpose6, (lv2135,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv2139 = R.call_tir(cls.transpose6, (lv2136,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv252 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv247, lv2138, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv253 = R.call_tir(cls.fused_softmax_cast1, (lv252,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv2148 = R.call_tir(cls.matmul9, (lv253, lv2139), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv254 = R.call_tir(cls.fused_transpose7_reshape4, (lv2148,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv255: R.Tensor((4096, 512), dtype="uint32") = params[104]
lv256: R.Tensor((4096, 128), dtype="float16") = params[105]
lv20_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv255, lv256, lv254, lv19_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv609: R.Tensor((4096,), dtype="float16") = params[111]
lv2154 = R.call_tir(cls.rms_norm1, (lv20_3, lv609), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv259: R.Tensor((22016, 512), dtype="uint32") = params[106]
lv260: R.Tensor((22016, 128), dtype="float16") = params[107]
lv21 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv259, lv260, lv2154), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv262 = R.call_tir(cls.fused_split_silu_multiply, (lv21,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv263: R.Tensor((4096, 1376), dtype="uint32") = params[108]
lv264: R.Tensor((4096, 344), dtype="float16") = params[109]
lv21_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv263, lv264, lv262, lv20_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
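# Transformer decoder layer (KV cache slots 22/23).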
lv616: R.Tensor((4096,), dtype="float16") = params[120]
lv2165 = R.call_tir(cls.rms_norm1, (lv21_1, lv616), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv267: R.Tensor((12288, 512), dtype="uint32") = params[112]
lv268: R.Tensor((12288, 128), dtype="float16") = params[113]
lv22_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv267, lv268, lv2165), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv44_1 = R.call_tir(cls.split_rotary, (lv22_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv270: R.Tensor((1, 1, 4096), dtype="float16") = lv44_1[0]
lv271 = R.call_tir(cls.fused_reshape2_transpose5, (lv270,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv272: R.Tensor((1, 1, 4096), dtype="float16") = lv44_1[1]
lv273 = R.call_tir(cls.fused_reshape2_squeeze, (lv272,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv274: R.Tensor((1, 1, 4096), dtype="float16") = lv44_1[2]
lv275 = R.call_tir(cls.fused_reshape2_squeeze, (lv274,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv2179: R.Object = kv_cache[22]
lv2180: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2179, lv273, sinfo_args=(R.Object,))
lv2181: R.Object = kv_cache[23]
lv2182: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2181, lv275, sinfo_args=(R.Object,))
lv2183: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2180, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2184: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2182, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2185 = R.call_tir(cls.reshape3, (lv2183,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2186 = R.call_tir(cls.reshape3, (lv2184,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2188 = R.call_tir(cls.transpose6, (lv2185,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv2189 = R.call_tir(cls.transpose6, (lv2186,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv276 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv271, lv2188, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv277 = R.call_tir(cls.fused_softmax_cast1, (lv276,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv2198 = R.call_tir(cls.matmul9, (lv277, lv2189), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv278 = R.call_tir(cls.fused_transpose7_reshape4, (lv2198,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv279: R.Tensor((4096, 512), dtype="uint32") = params[114]
lv280: R.Tensor((4096, 128), dtype="float16") = params[115]
lv22_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv279, lv280, lv278, lv21_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv623: R.Tensor((4096,), dtype="float16") = params[121]
lv2204 = R.call_tir(cls.rms_norm1, (lv22_2, lv623), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv283: R.Tensor((22016, 512), dtype="uint32") = params[116]
lv284: R.Tensor((22016, 128), dtype="float16") = params[117]
lv23_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv283, lv284, lv2204), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv286 = R.call_tir(cls.fused_split_silu_multiply, (lv23_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv287: R.Tensor((4096, 1376), dtype="uint32") = params[118]
lv288: R.Tensor((4096, 344), dtype="float16") = params[119]
lv23_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv287, lv288, lv286, lv22_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
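# Transformer decoder layer (KV cache slots 24/25).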
lv630: R.Tensor((4096,), dtype="float16") = params[130]
lv2215 = R.call_tir(cls.rms_norm1, (lv23_2, lv630), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv291: R.Tensor((12288, 512), dtype="uint32") = params[122]
lv292: R.Tensor((12288, 128), dtype="float16") = params[123]
lv24_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv291, lv292, lv2215), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv48_1 = R.call_tir(cls.split_rotary, (lv24_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv294: R.Tensor((1, 1, 4096), dtype="float16") = lv48_1[0]
lv295 = R.call_tir(cls.fused_reshape2_transpose5, (lv294,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv296: R.Tensor((1, 1, 4096), dtype="float16") = lv48_1[1]
lv297 = R.call_tir(cls.fused_reshape2_squeeze, (lv296,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv298: R.Tensor((1, 1, 4096), dtype="float16") = lv48_1[2]
lv299 = R.call_tir(cls.fused_reshape2_squeeze, (lv298,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv2229: R.Object = kv_cache[24]
lv2230: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2229, lv297, sinfo_args=(R.Object,))
lv2231: R.Object = kv_cache[25]
lv2232: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2231, lv299, sinfo_args=(R.Object,))
lv2233: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2230, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2234: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2232, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2235 = R.call_tir(cls.reshape3, (lv2233,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2236 = R.call_tir(cls.reshape3, (lv2234,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2238 = R.call_tir(cls.transpose6, (lv2235,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv2239 = R.call_tir(cls.transpose6, (lv2236,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv300 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv295, lv2238, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv301 = R.call_tir(cls.fused_softmax_cast1, (lv300,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv2248 = R.call_tir(cls.matmul9, (lv301, lv2239), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv302 = R.call_tir(cls.fused_transpose7_reshape4, (lv2248,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv303: R.Tensor((4096, 512), dtype="uint32") = params[124]
lv304: R.Tensor((4096, 128), dtype="float16") = params[125]
lv24_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv303, lv304, lv302, lv23_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv637: R.Tensor((4096,), dtype="float16") = params[131]
lv2254 = R.call_tir(cls.rms_norm1, (lv24_3, lv637), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv307: R.Tensor((22016, 512), dtype="uint32") = params[126]
lv308: R.Tensor((22016, 128), dtype="float16") = params[127]
lv25 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv307, lv308, lv2254), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv310 = R.call_tir(cls.fused_split_silu_multiply, (lv25,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv311: R.Tensor((4096, 1376), dtype="uint32") = params[128]
lv312: R.Tensor((4096, 344), dtype="float16") = params[129]
lv25_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv311, lv312, lv310, lv24_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
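# Transformer decoder layer (KV cache slots 26/27).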
lv644: R.Tensor((4096,), dtype="float16") = params[140]
lv2265 = R.call_tir(cls.rms_norm1, (lv25_1, lv644), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv315: R.Tensor((12288, 512), dtype="uint32") = params[132]
lv316: R.Tensor((12288, 128), dtype="float16") = params[133]
lv26 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv315, lv316, lv2265), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv52_1 = R.call_tir(cls.split_rotary, (lv26, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv318: R.Tensor((1, 1, 4096), dtype="float16") = lv52_1[0]
lv319 = R.call_tir(cls.fused_reshape2_transpose5, (lv318,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv320: R.Tensor((1, 1, 4096), dtype="float16") = lv52_1[1]
lv321 = R.call_tir(cls.fused_reshape2_squeeze, (lv320,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv322: R.Tensor((1, 1, 4096), dtype="float16") = lv52_1[2]
lv323 = R.call_tir(cls.fused_reshape2_squeeze, (lv322,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv2279: R.Object = kv_cache[26]
lv2280: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2279, lv321, sinfo_args=(R.Object,))
lv2281: R.Object = kv_cache[27]
lv2282: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2281, lv323, sinfo_args=(R.Object,))
lv2283: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2280, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2284: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2282, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2285 = R.call_tir(cls.reshape3, (lv2283,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2286 = R.call_tir(cls.reshape3, (lv2284,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2288 = R.call_tir(cls.transpose6, (lv2285,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv2289 = R.call_tir(cls.transpose6, (lv2286,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv324 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv319, lv2288, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv325 = R.call_tir(cls.fused_softmax_cast1, (lv324,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv2298 = R.call_tir(cls.matmul9, (lv325, lv2289), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv326 = R.call_tir(cls.fused_transpose7_reshape4, (lv2298,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv327: R.Tensor((4096, 512), dtype="uint32") = params[134]
lv328: R.Tensor((4096, 128), dtype="float16") = params[135]
lv26_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv327, lv328, lv326, lv25_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv651: R.Tensor((4096,), dtype="float16") = params[141]
lv2304 = R.call_tir(cls.rms_norm1, (lv26_1, lv651), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv331: R.Tensor((22016, 512), dtype="uint32") = params[136]
lv332: R.Tensor((22016, 128), dtype="float16") = params[137]
lv27_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv331, lv332, lv2304), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv334 = R.call_tir(cls.fused_split_silu_multiply, (lv27_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv335: R.Tensor((4096, 1376), dtype="uint32") = params[138]
lv336: R.Tensor((4096, 344), dtype="float16") = params[139]
lv27_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv335, lv336, lv334, lv26_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
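# Transformer decoder layer (KV cache slots 28/29).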
lv658: R.Tensor((4096,), dtype="float16") = params[150]
lv2315 = R.call_tir(cls.rms_norm1, (lv27_2, lv658), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv339: R.Tensor((12288, 512), dtype="uint32") = params[142]
lv340: R.Tensor((12288, 128), dtype="float16") = params[143]
lv28_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv339, lv340, lv2315), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv56_1 = R.call_tir(cls.split_rotary, (lv28_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv342: R.Tensor((1, 1, 4096), dtype="float16") = lv56_1[0]
lv343 = R.call_tir(cls.fused_reshape2_transpose5, (lv342,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv344: R.Tensor((1, 1, 4096), dtype="float16") = lv56_1[1]
lv345 = R.call_tir(cls.fused_reshape2_squeeze, (lv344,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv346: R.Tensor((1, 1, 4096), dtype="float16") = lv56_1[2]
lv347 = R.call_tir(cls.fused_reshape2_squeeze, (lv346,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv2329: R.Object = kv_cache[28]
lv2330: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2329, lv345, sinfo_args=(R.Object,))
lv2331: R.Object = kv_cache[29]
lv2332: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2331, lv347, sinfo_args=(R.Object,))
lv2333: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2330, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2334: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2332, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2335 = R.call_tir(cls.reshape3, (lv2333,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2336 = R.call_tir(cls.reshape3, (lv2334,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2338 = R.call_tir(cls.transpose6, (lv2335,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv2339 = R.call_tir(cls.transpose6, (lv2336,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv348 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv343, lv2338, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv349 = R.call_tir(cls.fused_softmax_cast1, (lv348,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv2348 = R.call_tir(cls.matmul9, (lv349, lv2339), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv350 = R.call_tir(cls.fused_transpose7_reshape4, (lv2348,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv351: R.Tensor((4096, 512), dtype="uint32") = params[144]
lv352: R.Tensor((4096, 128), dtype="float16") = params[145]
lv28_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv351, lv352, lv350, lv27_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv665: R.Tensor((4096,), dtype="float16") = params[151]
lv2354 = R.call_tir(cls.rms_norm1, (lv28_3, lv665), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv355: R.Tensor((22016, 512), dtype="uint32") = params[146]
lv356: R.Tensor((22016, 128), dtype="float16") = params[147]
lv29 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv355, lv356, lv2354), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv358 = R.call_tir(cls.fused_split_silu_multiply, (lv29,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv359: R.Tensor((4096, 1376), dtype="uint32") = params[148]
lv360: R.Tensor((4096, 344), dtype="float16") = params[149]
lv29_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv359, lv360, lv358, lv28_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
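# Transformer decoder layer (KV cache slots 30/31).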
lv672: R.Tensor((4096,), dtype="float16") = params[160]
lv2365 = R.call_tir(cls.rms_norm1, (lv29_1, lv672), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv363: R.Tensor((12288, 512), dtype="uint32") = params[152]
lv364: R.Tensor((12288, 128), dtype="float16") = params[153]
lv30_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv363, lv364, lv2365), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv60_1 = R.call_tir(cls.split_rotary, (lv30_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv366: R.Tensor((1, 1, 4096), dtype="float16") = lv60_1[0]
lv367 = R.call_tir(cls.fused_reshape2_transpose5, (lv366,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv368: R.Tensor((1, 1, 4096), dtype="float16") = lv60_1[1]
lv369 = R.call_tir(cls.fused_reshape2_squeeze, (lv368,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv370: R.Tensor((1, 1, 4096), dtype="float16") = lv60_1[2]
lv371 = R.call_tir(cls.fused_reshape2_squeeze, (lv370,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv2379: R.Object = kv_cache[30]
lv2380: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2379, lv369, sinfo_args=(R.Object,))
lv2381: R.Object = kv_cache[31]
lv2382: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2381, lv371, sinfo_args=(R.Object,))
lv2383: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2380, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2384: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2382, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2385 = R.call_tir(cls.reshape3, (lv2383,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2386 = R.call_tir(cls.reshape3, (lv2384,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2388 = R.call_tir(cls.transpose6, (lv2385,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv2389 = R.call_tir(cls.transpose6, (lv2386,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv372 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv367, lv2388, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv373 = R.call_tir(cls.fused_softmax_cast1, (lv372,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv2398 = R.call_tir(cls.matmul9, (lv373, lv2389), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv374 = R.call_tir(cls.fused_transpose7_reshape4, (lv2398,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv375: R.Tensor((4096, 512), dtype="uint32") = params[154]
lv376: R.Tensor((4096, 128), dtype="float16") = params[155]
lv30_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv375, lv376, lv374, lv29_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv679: R.Tensor((4096,), dtype="float16") = params[161]
lv2404 = R.call_tir(cls.rms_norm1, (lv30_2, lv679), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv379: R.Tensor((22016, 512), dtype="uint32") = params[156]
lv380: R.Tensor((22016, 128), dtype="float16") = params[157]
lv31_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv379, lv380, lv2404), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv382 = R.call_tir(cls.fused_split_silu_multiply, (lv31_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv383: R.Tensor((4096, 1376), dtype="uint32") = params[158]
lv384: R.Tensor((4096, 344), dtype="float16") = params[159]
lv31_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv383, lv384, lv382, lv30_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
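# Transformer decoder layer (KV cache slots 32/33).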
lv686: R.Tensor((4096,), dtype="float16") = params[170]
lv2415 = R.call_tir(cls.rms_norm1, (lv31_2, lv686), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv387: R.Tensor((12288, 512), dtype="uint32") = params[162]
lv388: R.Tensor((12288, 128), dtype="float16") = params[163]
lv32_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv387, lv388, lv2415), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv64_1 = R.call_tir(cls.split_rotary, (lv32_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv390: R.Tensor((1, 1, 4096), dtype="float16") = lv64_1[0]
lv391 = R.call_tir(cls.fused_reshape2_transpose5, (lv390,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv392: R.Tensor((1, 1, 4096), dtype="float16") = lv64_1[1]
lv393 = R.call_tir(cls.fused_reshape2_squeeze, (lv392,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv394: R.Tensor((1, 1, 4096), dtype="float16") = lv64_1[2]
lv395 = R.call_tir(cls.fused_reshape2_squeeze, (lv394,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv2429: R.Object = kv_cache[32]
lv2430: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2429, lv393, sinfo_args=(R.Object,))
lv2431: R.Object = kv_cache[33]
lv2432: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2431, lv395, sinfo_args=(R.Object,))
lv2433: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2430, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2434: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2432, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2435 = R.call_tir(cls.reshape3, (lv2433,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2436 = R.call_tir(cls.reshape3, (lv2434,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2438 = R.call_tir(cls.transpose6, (lv2435,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv2439 = R.call_tir(cls.transpose6, (lv2436,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv396 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv391, lv2438, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv397 = R.call_tir(cls.fused_softmax_cast1, (lv396,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv2448 = R.call_tir(cls.matmul9, (lv397, lv2439), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv398 = R.call_tir(cls.fused_transpose7_reshape4, (lv2448,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv399: R.Tensor((4096, 512), dtype="uint32") = params[164]
lv400: R.Tensor((4096, 128), dtype="float16") = params[165]
lv32_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv399, lv400, lv398, lv31_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv693: R.Tensor((4096,), dtype="float16") = params[171]
lv2454 = R.call_tir(cls.rms_norm1, (lv32_3, lv693), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv403: R.Tensor((22016, 512), dtype="uint32") = params[166]
lv404: R.Tensor((22016, 128), dtype="float16") = params[167]
lv33_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv403, lv404, lv2454), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv406 = R.call_tir(cls.fused_split_silu_multiply, (lv33_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv407: R.Tensor((4096, 1376), dtype="uint32") = params[168]
lv408: R.Tensor((4096, 344), dtype="float16") = params[169]
lv33_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv407, lv408, lv406, lv32_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
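# Transformer decoder layer (KV cache slots 34/35).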
lv700: R.Tensor((4096,), dtype="float16") = params[180]
lv2465 = R.call_tir(cls.rms_norm1, (lv33_2, lv700), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv411: R.Tensor((12288, 512), dtype="uint32") = params[172]
lv412: R.Tensor((12288, 128), dtype="float16") = params[173]
lv34_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv411, lv412, lv2465), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv68_1 = R.call_tir(cls.split_rotary, (lv34_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv414: R.Tensor((1, 1, 4096), dtype="float16") = lv68_1[0]
lv415 = R.call_tir(cls.fused_reshape2_transpose5, (lv414,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv416: R.Tensor((1, 1, 4096), dtype="float16") = lv68_1[1]
lv417 = R.call_tir(cls.fused_reshape2_squeeze, (lv416,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv418: R.Tensor((1, 1, 4096), dtype="float16") = lv68_1[2]
lv419 = R.call_tir(cls.fused_reshape2_squeeze, (lv418,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv2479: R.Object = kv_cache[34]
lv2480: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2479, lv417, sinfo_args=(R.Object,))
lv2481: R.Object = kv_cache[35]
lv2482: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2481, lv419, sinfo_args=(R.Object,))
lv2483: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2480, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2484: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2482, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2485 = R.call_tir(cls.reshape3, (lv2483,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2486 = R.call_tir(cls.reshape3, (lv2484,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2488 = R.call_tir(cls.transpose6, (lv2485,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv2489 = R.call_tir(cls.transpose6, (lv2486,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv420 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv415, lv2488, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv421 = R.call_tir(cls.fused_softmax_cast1, (lv420,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv2498 = R.call_tir(cls.matmul9, (lv421, lv2489), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv422 = R.call_tir(cls.fused_transpose7_reshape4, (lv2498,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv423: R.Tensor((4096, 512), dtype="uint32") = params[174]
lv424: R.Tensor((4096, 128), dtype="float16") = params[175]
lv34_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv423, lv424, lv422, lv33_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv707: R.Tensor((4096,), dtype="float16") = params[181]
lv2504 = R.call_tir(cls.rms_norm1, (lv34_2, lv707), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv427: R.Tensor((22016, 512), dtype="uint32") = params[176]
lv428: R.Tensor((22016, 128), dtype="float16") = params[177]
lv35_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv427, lv428, lv2504), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv430 = R.call_tir(cls.fused_split_silu_multiply, (lv35_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv431: R.Tensor((4096, 1376), dtype="uint32") = params[178]
lv432: R.Tensor((4096, 344), dtype="float16") = params[179]
lv35_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv431, lv432, lv430, lv34_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
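# Transformer decoder layer (KV cache slots 36/37).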
lv714: R.Tensor((4096,), dtype="float16") = params[190]
lv2515 = R.call_tir(cls.rms_norm1, (lv35_2, lv714), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv435: R.Tensor((12288, 512), dtype="uint32") = params[182]
lv436: R.Tensor((12288, 128), dtype="float16") = params[183]
lv36_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv435, lv436, lv2515), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv72_1 = R.call_tir(cls.split_rotary, (lv36_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv438: R.Tensor((1, 1, 4096), dtype="float16") = lv72_1[0]
lv439 = R.call_tir(cls.fused_reshape2_transpose5, (lv438,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv440: R.Tensor((1, 1, 4096), dtype="float16") = lv72_1[1]
lv441 = R.call_tir(cls.fused_reshape2_squeeze, (lv440,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv442: R.Tensor((1, 1, 4096), dtype="float16") = lv72_1[2]
lv443 = R.call_tir(cls.fused_reshape2_squeeze, (lv442,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv2529: R.Object = kv_cache[36]
lv2530: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2529, lv441, sinfo_args=(R.Object,))
lv2531: R.Object = kv_cache[37]
lv2532: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2531, lv443, sinfo_args=(R.Object,))
lv2533: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2530, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2534: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2532, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2535 = R.call_tir(cls.reshape3, (lv2533,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2536 = R.call_tir(cls.reshape3, (lv2534,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2538 = R.call_tir(cls.transpose6, (lv2535,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv2539 = R.call_tir(cls.transpose6, (lv2536,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv444 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv439, lv2538, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv445 = R.call_tir(cls.fused_softmax_cast1, (lv444,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv2548 = R.call_tir(cls.matmul9, (lv445, lv2539), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv446 = R.call_tir(cls.fused_transpose7_reshape4, (lv2548,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv447: R.Tensor((4096, 512), dtype="uint32") = params[184]
lv448: R.Tensor((4096, 128), dtype="float16") = params[185]
lv36_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv447, lv448, lv446, lv35_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv721: R.Tensor((4096,), dtype="float16") = params[191]
lv2554 = R.call_tir(cls.rms_norm1, (lv36_3, lv721), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv451: R.Tensor((22016, 512), dtype="uint32") = params[186]
lv452: R.Tensor((22016, 128), dtype="float16") = params[187]
lv37_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv451, lv452, lv2554), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv454 = R.call_tir(cls.fused_split_silu_multiply, (lv37_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv455: R.Tensor((4096, 1376), dtype="uint32") = params[188]
lv456: R.Tensor((4096, 344), dtype="float16") = params[189]
lv37_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv455, lv456, lv454, lv36_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
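# Transformer decoder layer (KV cache slots 38/39).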
lv728: R.Tensor((4096,), dtype="float16") = params[200]
lv2565 = R.call_tir(cls.rms_norm1, (lv37_2, lv728), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv459: R.Tensor((12288, 512), dtype="uint32") = params[192]
lv460_1: R.Tensor((12288, 128), dtype="float16") = params[193]
lv38_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv459, lv460_1, lv2565), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv76_1 = R.call_tir(cls.split_rotary, (lv38_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv462: R.Tensor((1, 1, 4096), dtype="float16") = lv76_1[0]
lv463 = R.call_tir(cls.fused_reshape2_transpose5, (lv462,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv464_1: R.Tensor((1, 1, 4096), dtype="float16") = lv76_1[1]
lv465_1 = R.call_tir(cls.fused_reshape2_squeeze, (lv464_1,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv466: R.Tensor((1, 1, 4096), dtype="float16") = lv76_1[2]
lv467 = R.call_tir(cls.fused_reshape2_squeeze, (lv466,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv2579: R.Object = kv_cache[38]
lv2580: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2579, lv465_1, sinfo_args=(R.Object,))
lv2581: R.Object = kv_cache[39]
lv2582: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2581, lv467, sinfo_args=(R.Object,))
lv2583: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2580, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2584: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2582, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2585 = R.call_tir(cls.reshape3, (lv2583,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2586 = R.call_tir(cls.reshape3, (lv2584,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2588 = R.call_tir(cls.transpose6, (lv2585,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv2589 = R.call_tir(cls.transpose6, (lv2586,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv468 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv463, lv2588, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv469_1 = R.call_tir(cls.fused_softmax_cast1, (lv468,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv2598 = R.call_tir(cls.matmul9, (lv469_1, lv2589), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv470 = R.call_tir(cls.fused_transpose7_reshape4, (lv2598,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv471: R.Tensor((4096, 512), dtype="uint32") = params[194]
lv472: R.Tensor((4096, 128), dtype="float16") = params[195]
lv38_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv471, lv472, lv470, lv37_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv735: R.Tensor((4096,), dtype="float16") = params[201]
lv2604 = R.call_tir(cls.rms_norm1, (lv38_2, lv735), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv475: R.Tensor((22016, 512), dtype="uint32") = params[196]
lv476_1: R.Tensor((22016, 128), dtype="float16") = params[197]
lv39_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv475, lv476_1, lv2604), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv478 = R.call_tir(cls.fused_split_silu_multiply, (lv39_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv479: R.Tensor((4096, 1376), dtype="uint32") = params[198]
lv480: R.Tensor((4096, 344), dtype="float16") = params[199]
lv39_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv479, lv480, lv478, lv38_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
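# Transformer decoder layer (KV cache slots 40/41).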
lv742: R.Tensor((4096,), dtype="float16") = params[210]
lv2615 = R.call_tir(cls.rms_norm1, (lv39_2, lv742), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv483_1: R.Tensor((12288, 512), dtype="uint32") = params[202]
lv484: R.Tensor((12288, 128), dtype="float16") = params[203]
lv40_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv483_1, lv484, lv2615), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv80_1 = R.call_tir(cls.split_rotary, (lv40_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv486: R.Tensor((1, 1, 4096), dtype="float16") = lv80_1[0]
lv487 = R.call_tir(cls.fused_reshape2_transpose5, (lv486,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv488: R.Tensor((1, 1, 4096), dtype="float16") = lv80_1[1]
lv489 = R.call_tir(cls.fused_reshape2_squeeze, (lv488,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv490_1: R.Tensor((1, 1, 4096), dtype="float16") = lv80_1[2]
lv491 = R.call_tir(cls.fused_reshape2_squeeze, (lv490_1,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv2629: R.Object = kv_cache[40]
lv2630: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2629, lv489, sinfo_args=(R.Object,))
lv2631: R.Object = kv_cache[41]
lv2632: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2631, lv491, sinfo_args=(R.Object,))
lv2633: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2630, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2634: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2632, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2635 = R.call_tir(cls.reshape3, (lv2633,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2636 = R.call_tir(cls.reshape3, (lv2634,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2638 = R.call_tir(cls.transpose6, (lv2635,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv2639 = R.call_tir(cls.transpose6, (lv2636,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv492 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv487, lv2638, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv493 = R.call_tir(cls.fused_softmax_cast1, (lv492,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv2648 = R.call_tir(cls.matmul9, (lv493, lv2639), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv494 = R.call_tir(cls.fused_transpose7_reshape4, (lv2648,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv495: R.Tensor((4096, 512), dtype="uint32") = params[204]
lv496: R.Tensor((4096, 128), dtype="float16") = params[205]
lv40_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv495, lv496, lv494, lv39_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv749: R.Tensor((4096,), dtype="float16") = params[211]
lv2654 = R.call_tir(cls.rms_norm1, (lv40_3, lv749), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv499: R.Tensor((22016, 512), dtype="uint32") = params[206]
lv500: R.Tensor((22016, 128), dtype="float16") = params[207]
lv41 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv499, lv500, lv2654), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv502 = R.call_tir(cls.fused_split_silu_multiply, (lv41,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv503: R.Tensor((4096, 1376), dtype="uint32") = params[208]
lv504_1: R.Tensor((4096, 344), dtype="float16") = params[209]
lv41_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv503, lv504_1, lv502, lv40_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
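# Transformer decoder layer (KV cache slots 42/43).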
lv756: R.Tensor((4096,), dtype="float16") = params[220]
lv2665 = R.call_tir(cls.rms_norm1, (lv41_1, lv756), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv507: R.Tensor((12288, 512), dtype="uint32") = params[212]
lv508: R.Tensor((12288, 128), dtype="float16") = params[213]
lv42 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv507, lv508, lv2665), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv84_1 = R.call_tir(cls.split_rotary, (lv42, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv510: R.Tensor((1, 1, 4096), dtype="float16") = lv84_1[0]
lv511_1 = R.call_tir(cls.fused_reshape2_transpose5, (lv510,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv512: R.Tensor((1, 1, 4096), dtype="float16") = lv84_1[1]
lv513 = R.call_tir(cls.fused_reshape2_squeeze, (lv512,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv514: R.Tensor((1, 1, 4096), dtype="float16") = lv84_1[2]
lv515 = R.call_tir(cls.fused_reshape2_squeeze, (lv514,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv2679: R.Object = kv_cache[42]
lv2680: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2679, lv513, sinfo_args=(R.Object,))
lv2681: R.Object = kv_cache[43]
lv2682: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2681, lv515, sinfo_args=(R.Object,))
lv2683: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2680, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2684: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2682, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2685 = R.call_tir(cls.reshape3, (lv2683,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2686 = R.call_tir(cls.reshape3, (lv2684,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2688 = R.call_tir(cls.transpose6, (lv2685,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv2689 = R.call_tir(cls.transpose6, (lv2686,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv516 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv511_1, lv2688, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv517 = R.call_tir(cls.fused_softmax_cast1, (lv516,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv2698 = R.call_tir(cls.matmul9, (lv517, lv2689), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv518_1 = R.call_tir(cls.fused_transpose7_reshape4, (lv2698,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv519: R.Tensor((4096, 512), dtype="uint32") = params[214]
lv520: R.Tensor((4096, 128), dtype="float16") = params[215]
lv42_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv519, lv520, lv518_1, lv41_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv763: R.Tensor((4096,), dtype="float16") = params[221]
lv2704 = R.call_tir(cls.rms_norm1, (lv42_1, lv763), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv523: R.Tensor((22016, 512), dtype="uint32") = params[216]
lv524: R.Tensor((22016, 128), dtype="float16") = params[217]
lv43_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv523, lv524, lv2704), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv526 = R.call_tir(cls.fused_split_silu_multiply, (lv43_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv527: R.Tensor((4096, 1376), dtype="uint32") = params[218]
lv528: R.Tensor((4096, 344), dtype="float16") = params[219]
lv43_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv527, lv528, lv526, lv42_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
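# Transformer decoder layer (KV cache slots 44/45).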
lv770: R.Tensor((4096,), dtype="float16") = params[230]
lv2715 = R.call_tir(cls.rms_norm1, (lv43_2, lv770), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv531: R.Tensor((12288, 512), dtype="uint32") = params[222]
lv532_1: R.Tensor((12288, 128), dtype="float16") = params[223]
lv44_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv531, lv532_1, lv2715), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv88_1 = R.call_tir(cls.split_rotary, (lv44_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv534: R.Tensor((1, 1, 4096), dtype="float16") = lv88_1[0]
lv535 = R.call_tir(cls.fused_reshape2_transpose5, (lv534,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv536: R.Tensor((1, 1, 4096), dtype="float16") = lv88_1[1]
lv537 = R.call_tir(cls.fused_reshape2_squeeze, (lv536,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv538: R.Tensor((1, 1, 4096), dtype="float16") = lv88_1[2]
lv539_1 = R.call_tir(cls.fused_reshape2_squeeze, (lv538,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv2729: R.Object = kv_cache[44]
lv2730: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2729, lv537, sinfo_args=(R.Object,))
lv2731: R.Object = kv_cache[45]
lv2732: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2731, lv539_1, sinfo_args=(R.Object,))
lv2733: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2730, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2734: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2732, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2735 = R.call_tir(cls.reshape3, (lv2733,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2736 = R.call_tir(cls.reshape3, (lv2734,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2738 = R.call_tir(cls.transpose6, (lv2735,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv2739 = R.call_tir(cls.transpose6, (lv2736,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv540 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv535, lv2738, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv541 = R.call_tir(cls.fused_softmax_cast1, (lv540,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv2748 = R.call_tir(cls.matmul9, (lv541, lv2739), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv542 = R.call_tir(cls.fused_transpose7_reshape4, (lv2748,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv543: R.Tensor((4096, 512), dtype="uint32") = params[224]
lv544: R.Tensor((4096, 128), dtype="float16") = params[225]
lv44_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv543, lv544, lv542, lv43_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv777: R.Tensor((4096,), dtype="float16") = params[231]
lv2754 = R.call_tir(cls.rms_norm1, (lv44_3, lv777), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv547: R.Tensor((22016, 512), dtype="uint32") = params[226]
lv548: R.Tensor((22016, 128), dtype="float16") = params[227]
lv45 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv547, lv548, lv2754), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv550 = R.call_tir(cls.fused_split_silu_multiply, (lv45,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv551: R.Tensor((4096, 1376), dtype="uint32") = params[228]
lv552: R.Tensor((4096, 344), dtype="float16") = params[229]
lv45_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv551, lv552, lv550, lv44_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
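# Transformer decoder layer (KV cache slots 46/47).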
lv784: R.Tensor((4096,), dtype="float16") = params[240]
lv2765 = R.call_tir(cls.rms_norm1, (lv45_1, lv784), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv555: R.Tensor((12288, 512), dtype="uint32") = params[232]
lv556: R.Tensor((12288, 128), dtype="float16") = params[233]
lv46_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv555, lv556, lv2765), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv92_1 = R.call_tir(cls.split_rotary, (lv46_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv558: R.Tensor((1, 1, 4096), dtype="float16") = lv92_1[0]
lv559 = R.call_tir(cls.fused_reshape2_transpose5, (lv558,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv560_1: R.Tensor((1, 1, 4096), dtype="float16") = lv92_1[1]
lv561 = R.call_tir(cls.fused_reshape2_squeeze, (lv560_1,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv562: R.Tensor((1, 1, 4096), dtype="float16") = lv92_1[2]
lv563 = R.call_tir(cls.fused_reshape2_squeeze, (lv562,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv2779: R.Object = kv_cache[46]
lv2780: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2779, lv561, sinfo_args=(R.Object,))
lv2781: R.Object = kv_cache[47]
lv2782: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2781, lv563, sinfo_args=(R.Object,))
lv2783: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2780, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2784: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2782, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2785 = R.call_tir(cls.reshape3, (lv2783,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2786 = R.call_tir(cls.reshape3, (lv2784,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2788 = R.call_tir(cls.transpose6, (lv2785,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv2789 = R.call_tir(cls.transpose6, (lv2786,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv564 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv559, lv2788, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv565 = R.call_tir(cls.fused_softmax_cast1, (lv564,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv2798 = R.call_tir(cls.matmul9, (lv565, lv2789), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv566 = R.call_tir(cls.fused_transpose7_reshape4, (lv2798,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv567_1: R.Tensor((4096, 512), dtype="uint32") = params[234]
lv568: R.Tensor((4096, 128), dtype="float16") = params[235]
lv46_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv567_1, lv568, lv566, lv45_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv791: R.Tensor((4096,), dtype="float16") = params[241]
lv2804 = R.call_tir(cls.rms_norm1, (lv46_2, lv791), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv571: R.Tensor((22016, 512), dtype="uint32") = params[236]
lv572: R.Tensor((22016, 128), dtype="float16") = params[237]
lv47_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv571, lv572, lv2804), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv574_1 = R.call_tir(cls.fused_split_silu_multiply, (lv47_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv575: R.Tensor((4096, 1376), dtype="uint32") = params[238]
lv576: R.Tensor((4096, 344), dtype="float16") = params[239]
lv47_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv575, lv576, lv574_1, lv46_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv798: R.Tensor((4096,), dtype="float16") = params[250]
lv2815 = R.call_tir(cls.rms_norm1, (lv47_2, lv798), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv579: R.Tensor((12288, 512), dtype="uint32") = params[242]
lv580: R.Tensor((12288, 128), dtype="float16") = params[243]
lv48_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv579, lv580, lv2815), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv96_1 = R.call_tir(cls.split_rotary, (lv48_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv582: R.Tensor((1, 1, 4096), dtype="float16") = lv96_1[0]
lv583 = R.call_tir(cls.fused_reshape2_transpose5, (lv582,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv584: R.Tensor((1, 1, 4096), dtype="float16") = lv96_1[1]
lv585 = R.call_tir(cls.fused_reshape2_squeeze, (lv584,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv586: R.Tensor((1, 1, 4096), dtype="float16") = lv96_1[2]
lv587 = R.call_tir(cls.fused_reshape2_squeeze, (lv586,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv2829: R.Object = kv_cache[48]
lv2830: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2829, lv585, sinfo_args=(R.Object,))
lv2831: R.Object = kv_cache[49]
lv2832: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2831, lv587, sinfo_args=(R.Object,))
lv2833: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2830, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2834: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2832, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2835 = R.call_tir(cls.reshape3, (lv2833,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2836 = R.call_tir(cls.reshape3, (lv2834,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2838 = R.call_tir(cls.transpose6, (lv2835,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv2839 = R.call_tir(cls.transpose6, (lv2836,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv588_1 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv583, lv2838, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv589 = R.call_tir(cls.fused_softmax_cast1, (lv588_1,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv2848 = R.call_tir(cls.matmul9, (lv589, lv2839), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv590 = R.call_tir(cls.fused_transpose7_reshape4, (lv2848,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv591: R.Tensor((4096, 512), dtype="uint32") = params[244]
lv592: R.Tensor((4096, 128), dtype="float16") = params[245]
lv48_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv591, lv592, lv590, lv47_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv805: R.Tensor((4096,), dtype="float16") = params[251]
lv2854 = R.call_tir(cls.rms_norm1, (lv48_3, lv805), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv595_1: R.Tensor((22016, 512), dtype="uint32") = params[246]
lv596: R.Tensor((22016, 128), dtype="float16") = params[247]
lv49 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv595_1, lv596, lv2854), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv598 = R.call_tir(cls.fused_split_silu_multiply, (lv49,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv599: R.Tensor((4096, 1376), dtype="uint32") = params[248]
lv600: R.Tensor((4096, 344), dtype="float16") = params[249]
lv49_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv599, lv600, lv598, lv48_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv812: R.Tensor((4096,), dtype="float16") = params[260]
lv2865 = R.call_tir(cls.rms_norm1, (lv49_1, lv812), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv603: R.Tensor((12288, 512), dtype="uint32") = params[252]
lv604: R.Tensor((12288, 128), dtype="float16") = params[253]
lv50 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv603, lv604, lv2865), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv100_1 = R.call_tir(cls.split_rotary, (lv50, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv606: R.Tensor((1, 1, 4096), dtype="float16") = lv100_1[0]
lv607 = R.call_tir(cls.fused_reshape2_transpose5, (lv606,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv608: R.Tensor((1, 1, 4096), dtype="float16") = lv100_1[1]
lv609_1 = R.call_tir(cls.fused_reshape2_squeeze, (lv608,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv610: R.Tensor((1, 1, 4096), dtype="float16") = lv100_1[2]
lv611 = R.call_tir(cls.fused_reshape2_squeeze, (lv610,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv2879: R.Object = kv_cache[50]
lv2880: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2879, lv609_1, sinfo_args=(R.Object,))
lv2881: R.Object = kv_cache[51]
lv2882: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2881, lv611, sinfo_args=(R.Object,))
lv2883: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2880, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2884: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2882, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2885 = R.call_tir(cls.reshape3, (lv2883,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2886 = R.call_tir(cls.reshape3, (lv2884,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2888 = R.call_tir(cls.transpose6, (lv2885,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv2889 = R.call_tir(cls.transpose6, (lv2886,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv612 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv607, lv2888, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv613 = R.call_tir(cls.fused_softmax_cast1, (lv612,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv2898 = R.call_tir(cls.matmul9, (lv613, lv2889), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv614 = R.call_tir(cls.fused_transpose7_reshape4, (lv2898,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv615: R.Tensor((4096, 512), dtype="uint32") = params[254]
lv616_1: R.Tensor((4096, 128), dtype="float16") = params[255]
lv50_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv615, lv616_1, lv614, lv49_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv819: R.Tensor((4096,), dtype="float16") = params[261]
lv2904 = R.call_tir(cls.rms_norm1, (lv50_1, lv819), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv619: R.Tensor((22016, 512), dtype="uint32") = params[256]
lv620: R.Tensor((22016, 128), dtype="float16") = params[257]
lv51_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv619, lv620, lv2904), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv622 = R.call_tir(cls.fused_split_silu_multiply, (lv51_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv623_1: R.Tensor((4096, 1376), dtype="uint32") = params[258]
lv624: R.Tensor((4096, 344), dtype="float16") = params[259]
lv51_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv623_1, lv624, lv622, lv50_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv826: R.Tensor((4096,), dtype="float16") = params[270]
lv2915 = R.call_tir(cls.rms_norm1, (lv51_2, lv826), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv627: R.Tensor((12288, 512), dtype="uint32") = params[262]
lv628: R.Tensor((12288, 128), dtype="float16") = params[263]
lv52_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv627, lv628, lv2915), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv104_1 = R.call_tir(cls.split_rotary, (lv52_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv630_1: R.Tensor((1, 1, 4096), dtype="float16") = lv104_1[0]
lv631 = R.call_tir(cls.fused_reshape2_transpose5, (lv630_1,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv632: R.Tensor((1, 1, 4096), dtype="float16") = lv104_1[1]
lv633 = R.call_tir(cls.fused_reshape2_squeeze, (lv632,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv634: R.Tensor((1, 1, 4096), dtype="float16") = lv104_1[2]
lv635 = R.call_tir(cls.fused_reshape2_squeeze, (lv634,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv2929: R.Object = kv_cache[52]
lv2930: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2929, lv633, sinfo_args=(R.Object,))
lv2931: R.Object = kv_cache[53]
lv2932: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2931, lv635, sinfo_args=(R.Object,))
lv2933: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2930, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2934: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2932, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2935 = R.call_tir(cls.reshape3, (lv2933,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2936 = R.call_tir(cls.reshape3, (lv2934,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2938 = R.call_tir(cls.transpose6, (lv2935,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv2939 = R.call_tir(cls.transpose6, (lv2936,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv636 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv631, lv2938, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv637_1 = R.call_tir(cls.fused_softmax_cast1, (lv636,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv2948 = R.call_tir(cls.matmul9, (lv637_1, lv2939), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv638 = R.call_tir(cls.fused_transpose7_reshape4, (lv2948,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv639: R.Tensor((4096, 512), dtype="uint32") = params[264]
lv640: R.Tensor((4096, 128), dtype="float16") = params[265]
lv52_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv639, lv640, lv638, lv51_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv833: R.Tensor((4096,), dtype="float16") = params[271]
lv2954 = R.call_tir(cls.rms_norm1, (lv52_3, lv833), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv643: R.Tensor((22016, 512), dtype="uint32") = params[266]
lv644_1: R.Tensor((22016, 128), dtype="float16") = params[267]
lv53 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv643, lv644_1, lv2954), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv646 = R.call_tir(cls.fused_split_silu_multiply, (lv53,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv647: R.Tensor((4096, 1376), dtype="uint32") = params[268]
lv648: R.Tensor((4096, 344), dtype="float16") = params[269]
lv53_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv647, lv648, lv646, lv52_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv840: R.Tensor((4096,), dtype="float16") = params[280]
lv2965 = R.call_tir(cls.rms_norm1, (lv53_1, lv840), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv651_1: R.Tensor((12288, 512), dtype="uint32") = params[272]
lv652: R.Tensor((12288, 128), dtype="float16") = params[273]
lv54_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv651_1, lv652, lv2965), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv108_1 = R.call_tir(cls.split_rotary, (lv54_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv654: R.Tensor((1, 1, 4096), dtype="float16") = lv108_1[0]
lv655 = R.call_tir(cls.fused_reshape2_transpose5, (lv654,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv656: R.Tensor((1, 1, 4096), dtype="float16") = lv108_1[1]
lv657 = R.call_tir(cls.fused_reshape2_squeeze, (lv656,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv658_1: R.Tensor((1, 1, 4096), dtype="float16") = lv108_1[2]
lv659 = R.call_tir(cls.fused_reshape2_squeeze, (lv658_1,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv2979: R.Object = kv_cache[54]
lv2980: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2979, lv657, sinfo_args=(R.Object,))
lv2981: R.Object = kv_cache[55]
lv2982: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2981, lv659, sinfo_args=(R.Object,))
lv2983: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2980, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2984: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2982, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv2985 = R.call_tir(cls.reshape3, (lv2983,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2986 = R.call_tir(cls.reshape3, (lv2984,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv2988 = R.call_tir(cls.transpose6, (lv2985,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv2989 = R.call_tir(cls.transpose6, (lv2986,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv660 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv655, lv2988, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv661 = R.call_tir(cls.fused_softmax_cast1, (lv660,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv2998 = R.call_tir(cls.matmul9, (lv661, lv2989), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv662 = R.call_tir(cls.fused_transpose7_reshape4, (lv2998,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv663: R.Tensor((4096, 512), dtype="uint32") = params[274]
lv664: R.Tensor((4096, 128), dtype="float16") = params[275]
lv54_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv663, lv664, lv662, lv53_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv847: R.Tensor((4096,), dtype="float16") = params[281]
lv3004 = R.call_tir(cls.rms_norm1, (lv54_2, lv847), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv667: R.Tensor((22016, 512), dtype="uint32") = params[276]
lv668: R.Tensor((22016, 128), dtype="float16") = params[277]
lv55_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv667, lv668, lv3004), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv670 = R.call_tir(cls.fused_split_silu_multiply, (lv55_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv671: R.Tensor((4096, 1376), dtype="uint32") = params[278]
lv672_1: R.Tensor((4096, 344), dtype="float16") = params[279]
lv55_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv671, lv672_1, lv670, lv54_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv854: R.Tensor((4096,), dtype="float16") = params[290]
lv3015 = R.call_tir(cls.rms_norm1, (lv55_2, lv854), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv675: R.Tensor((12288, 512), dtype="uint32") = params[282]
lv676: R.Tensor((12288, 128), dtype="float16") = params[283]
lv56_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv675, lv676, lv3015), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv112_1 = R.call_tir(cls.split_rotary, (lv56_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv678: R.Tensor((1, 1, 4096), dtype="float16") = lv112_1[0]
lv679_1 = R.call_tir(cls.fused_reshape2_transpose5, (lv678,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv680: R.Tensor((1, 1, 4096), dtype="float16") = lv112_1[1]
lv681 = R.call_tir(cls.fused_reshape2_squeeze, (lv680,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv682: R.Tensor((1, 1, 4096), dtype="float16") = lv112_1[2]
lv683 = R.call_tir(cls.fused_reshape2_squeeze, (lv682,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv3029: R.Object = kv_cache[56]
lv3030: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3029, lv681, sinfo_args=(R.Object,))
lv3031: R.Object = kv_cache[57]
lv3032: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3031, lv683, sinfo_args=(R.Object,))
lv3033: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3030, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv3034: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3032, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv3035 = R.call_tir(cls.reshape3, (lv3033,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv3036 = R.call_tir(cls.reshape3, (lv3034,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv3038 = R.call_tir(cls.transpose6, (lv3035,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv3039 = R.call_tir(cls.transpose6, (lv3036,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv684 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv679_1, lv3038, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv685 = R.call_tir(cls.fused_softmax_cast1, (lv684,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv3048 = R.call_tir(cls.matmul9, (lv685, lv3039), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv686_1 = R.call_tir(cls.fused_transpose7_reshape4, (lv3048,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv687: R.Tensor((4096, 512), dtype="uint32") = params[284]
lv688: R.Tensor((4096, 128), dtype="float16") = params[285]
lv56_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv687, lv688, lv686_1, lv55_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv861: R.Tensor((4096,), dtype="float16") = params[291]
lv3054 = R.call_tir(cls.rms_norm1, (lv56_3, lv861), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv691: R.Tensor((22016, 512), dtype="uint32") = params[286]
lv692: R.Tensor((22016, 128), dtype="float16") = params[287]
lv57_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv691, lv692, lv3054), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv694 = R.call_tir(cls.fused_split_silu_multiply, (lv57_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv695: R.Tensor((4096, 1376), dtype="uint32") = params[288]
lv696: R.Tensor((4096, 344), dtype="float16") = params[289]
lv57_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv695, lv696, lv694, lv56_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv868: R.Tensor((4096,), dtype="float16") = params[300]
lv3065 = R.call_tir(cls.rms_norm1, (lv57_2, lv868), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv699: R.Tensor((12288, 512), dtype="uint32") = params[292]
lv700_1: R.Tensor((12288, 128), dtype="float16") = params[293]
lv58_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv699, lv700_1, lv3065), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv116_1 = R.call_tir(cls.split_rotary, (lv58_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv702: R.Tensor((1, 1, 4096), dtype="float16") = lv116_1[0]
lv703 = R.call_tir(cls.fused_reshape2_transpose5, (lv702,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv704: R.Tensor((1, 1, 4096), dtype="float16") = lv116_1[1]
lv705 = R.call_tir(cls.fused_reshape2_squeeze, (lv704,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv706: R.Tensor((1, 1, 4096), dtype="float16") = lv116_1[2]
lv707_1 = R.call_tir(cls.fused_reshape2_squeeze, (lv706,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv3079: R.Object = kv_cache[58]
lv3080: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3079, lv705, sinfo_args=(R.Object,))
lv3081: R.Object = kv_cache[59]
lv3082: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3081, lv707_1, sinfo_args=(R.Object,))
lv3083: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3080, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv3084: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3082, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv3085 = R.call_tir(cls.reshape3, (lv3083,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv3086 = R.call_tir(cls.reshape3, (lv3084,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv3088 = R.call_tir(cls.transpose6, (lv3085,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv3089 = R.call_tir(cls.transpose6, (lv3086,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv708 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv703, lv3088, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv709 = R.call_tir(cls.fused_softmax_cast1, (lv708,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv3098 = R.call_tir(cls.matmul9, (lv709, lv3089), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv710 = R.call_tir(cls.fused_transpose7_reshape4, (lv3098,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv711: R.Tensor((4096, 512), dtype="uint32") = params[294]
lv712: R.Tensor((4096, 128), dtype="float16") = params[295]
lv58_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv711, lv712, lv710, lv57_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv875: R.Tensor((4096,), dtype="float16") = params[301]
lv3104 = R.call_tir(cls.rms_norm1, (lv58_2, lv875), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv715: R.Tensor((22016, 512), dtype="uint32") = params[296]
lv716: R.Tensor((22016, 128), dtype="float16") = params[297]
lv59_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv715, lv716, lv3104), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv718 = R.call_tir(cls.fused_split_silu_multiply, (lv59_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv719: R.Tensor((4096, 1376), dtype="uint32") = params[298]
lv720: R.Tensor((4096, 344), dtype="float16") = params[299]
lv59_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv719, lv720, lv718, lv58_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv882: R.Tensor((4096,), dtype="float16") = params[310]
lv3115 = R.call_tir(cls.rms_norm1, (lv59_2, lv882), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv723: R.Tensor((12288, 512), dtype="uint32") = params[302]
lv724: R.Tensor((12288, 128), dtype="float16") = params[303]
lv60_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv723, lv724, lv3115), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv120_1 = R.call_tir(cls.split_rotary, (lv60_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv726: R.Tensor((1, 1, 4096), dtype="float16") = lv120_1[0]
lv727 = R.call_tir(cls.fused_reshape2_transpose5, (lv726,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv728_1: R.Tensor((1, 1, 4096), dtype="float16") = lv120_1[1]
lv729 = R.call_tir(cls.fused_reshape2_squeeze, (lv728_1,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv730: R.Tensor((1, 1, 4096), dtype="float16") = lv120_1[2]
lv731 = R.call_tir(cls.fused_reshape2_squeeze, (lv730,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv3129: R.Object = kv_cache[60]
lv3130: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3129, lv729, sinfo_args=(R.Object,))
lv3131: R.Object = kv_cache[61]
lv3132: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3131, lv731, sinfo_args=(R.Object,))
lv3133: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3130, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv3134: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3132, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv3135 = R.call_tir(cls.reshape3, (lv3133,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv3136 = R.call_tir(cls.reshape3, (lv3134,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv3138 = R.call_tir(cls.transpose6, (lv3135,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv3139 = R.call_tir(cls.transpose6, (lv3136,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv732 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv727, lv3138, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv733 = R.call_tir(cls.fused_softmax_cast1, (lv732,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv3148 = R.call_tir(cls.matmul9, (lv733, lv3139), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv734 = R.call_tir(cls.fused_transpose7_reshape4, (lv3148,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv735_1: R.Tensor((4096, 512), dtype="uint32") = params[304]
lv736: R.Tensor((4096, 128), dtype="float16") = params[305]
lv60_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv735_1, lv736, lv734, lv59_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv889: R.Tensor((4096,), dtype="float16") = params[311]
lv3154 = R.call_tir(cls.rms_norm1, (lv60_3, lv889), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv739: R.Tensor((22016, 512), dtype="uint32") = params[306]
lv740: R.Tensor((22016, 128), dtype="float16") = params[307]
lv61_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv739, lv740, lv3154), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv742_1 = R.call_tir(cls.fused_split_silu_multiply, (lv61_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv743: R.Tensor((4096, 1376), dtype="uint32") = params[308]
lv744: R.Tensor((4096, 344), dtype="float16") = params[309]
lv61_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv743, lv744, lv742_1, lv60_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv896: R.Tensor((4096,), dtype="float16") = params[320]
lv3165 = R.call_tir(cls.rms_norm1, (lv61_2, lv896), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv747: R.Tensor((12288, 512), dtype="uint32") = params[312]
lv748: R.Tensor((12288, 128), dtype="float16") = params[313]
lv62_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv747, lv748, lv3165), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16"))
lv124_1 = R.call_tir(cls.split_rotary, (lv62_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n]))
lv750: R.Tensor((1, 1, 4096), dtype="float16") = lv124_1[0]
lv751 = R.call_tir(cls.fused_reshape2_transpose5, (lv750,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv752: R.Tensor((1, 1, 4096), dtype="float16") = lv124_1[1]
lv753 = R.call_tir(cls.fused_reshape2_squeeze, (lv752,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv754: R.Tensor((1, 1, 4096), dtype="float16") = lv124_1[2]
lv755 = R.call_tir(cls.fused_reshape2_squeeze, (lv754,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16"))
lv3179: R.Object = kv_cache[62]
lv3180: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3179, lv753, sinfo_args=(R.Object,))
lv3181: R.Object = kv_cache[63]
lv3182: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3181, lv755, sinfo_args=(R.Object,))
lv3183: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3180, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv3184: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3182, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),))
lv3185 = R.call_tir(cls.reshape3, (lv3183,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv3186 = R.call_tir(cls.reshape3, (lv3184,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv3188 = R.call_tir(cls.transpose6, (lv3185,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv3189 = R.call_tir(cls.transpose6, (lv3186,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv756_1 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv751, lv3188, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32"))
lv757 = R.call_tir(cls.fused_softmax_cast1, (lv756_1,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
lv3198 = R.call_tir(cls.matmul9, (lv757, lv3189), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
lv758 = R.call_tir(cls.fused_transpose7_reshape4, (lv3198,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv759: R.Tensor((4096, 512), dtype="uint32") = params[314]
lv760: R.Tensor((4096, 128), dtype="float16") = params[315]
lv62_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv759, lv760, lv758, lv61_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv903: R.Tensor((4096,), dtype="float16") = params[321]
lv3204 = R.call_tir(cls.rms_norm1, (lv62_2, lv903), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv763_1: R.Tensor((22016, 512), dtype="uint32") = params[316]
lv764: R.Tensor((22016, 128), dtype="float16") = params[317]
lv63_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv763_1, lv764, lv3204), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
lv766 = R.call_tir(cls.fused_split_silu_multiply, (lv63_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
lv767: R.Tensor((4096, 1376), dtype="uint32") = params[318]
lv768: R.Tensor((4096, 344), dtype="float16") = params[319]
lv63_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv767, lv768, lv766, lv62_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv910: R.Tensor((4096,), dtype="float16") = params[322]
lv3215 = R.call_tir(cls.rms_norm1, (lv63_2, lv910), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv3216 = R.call_tir(cls.slice1, (lv3215,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
lv771: R.Tensor((32001, 512), dtype="uint32") = params[323]
lv772: R.Tensor((32001, 128), dtype="float16") = params[324]
lv64_2 = R.call_tir(cls.fused_fused_decode1_fused_NT_matmul5_cast2, (lv771, lv772, lv3216), out_sinfo=R.Tensor((1, 1, 32001), dtype="float32"))
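# Head of the network: rms_norm1 with the final norm weight (params[322]), slice1 keeps
# the last position (effectively a no-op here, since the sequence length is already 1),
# and the dequantized lm_head matmul (params[323]/[324]) produces fp32 logits over the
# 32001-entry vocabulary.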
gv1: R.Tuple(R.Tensor((1, 1, 32001), dtype="float32"), R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object)) = lv64_2, (lv1630, lv1632, lv1680, lv1682, lv1730, lv1732, lv1780, lv1782, lv1830, lv1832, lv1880, lv1882, lv1930, lv1932, lv1980, lv1982, lv2030, lv2032, lv2080, lv2082, lv2130, lv2132, lv2180, lv2182, lv2230, lv2232, lv2280, lv2282, lv2330, lv2332, lv2380, lv2382, lv2430, lv2432, lv2480, lv2482, lv2530, lv2532, lv2580, lv2582, lv2630, lv2632, lv2680, lv2682, lv2730, lv2732, lv2780, lv2782, lv2830, lv2832, lv2880, lv2882, lv2930, lv2932, lv2980, lv2982, lv3030, lv3032, lv3080, lv3082, lv3130, lv3132, lv3180, lv3182)
R.output(gv1)
return gv1
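# The function returns the fp32 logits together with the 64 updated KV-cache handles
# (two per layer for 32 layers). A host loop would pick the next token from the logits
# and call back in with the grown cache. A minimal greedy pick, assuming the first
# element of the returned tuple has been bound to `out_logits` on the host
# (illustrative only, not part of this module):
#
#   next_id = int(np.argmax(out_logits.numpy()[0, 0]))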
@R.function
def get_metadata() -> R.Object:
R.func_attr({"tir_var_upper_bound": {"m": 2048, "n": 2048}})
return R.str("{\"model_name\": \"WizardMath-7B-V1.0\", \"max_window_size\": 2048, \"stop_tokens\": [2], \"add_prefix_space\": false}")
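# The metadata is a JSON string baked into the module. After building and loading the
# compiled library, a host can read it back roughly like this (file name and device are
# assumptions, not part of this module):
#
#   import json
#   import tvm
#   from tvm import relax
#
#   lib = tvm.runtime.load_module("wizardmath-7b-v1.0.so")   # hypothetical path
#   vm = relax.VirtualMachine(lib, tvm.cuda(0))               # assumes a CUDA build
#   meta = json.loads(str(vm["get_metadata"]()))
#   print(meta["model_name"], meta["max_window_size"], meta["stop_tokens"])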
@R.function
def prefill(input_ids: R.Tensor((1, "n"), dtype="int32"), all_seq_len: R.Shape(["m"]), kv_cache: R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object), params: R.Tuple(R.Tensor((32001, 512), dtype="uint32"), R.Tensor((32001, 128), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), 
dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), 
R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), 
dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((32001, 512), dtype="uint32"), R.Tensor((32001, 128), dtype="float16"), R.Tensor((2048, 128), dtype="float16"), R.Tensor((2048, 128), dtype="float16"))) -> R.Tuple(R.Tensor((1, 1, 32001), dtype="float32"), R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, 
R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object)):
n = T.int64()
m = T.int64()
R.func_attr({"tir_var_upper_bound": {"m": 2048, "n": 2048}})
cls = Module
with R.dataflow():
lv = R.call_tir(cls.reshape5, (input_ids,), out_sinfo=R.Tensor((n,), dtype="int32"))
lv775: R.Tensor((32001, 512), dtype="uint32") = params[0]
lv776: R.Tensor((32001, 128), dtype="float16") = params[1]
lv_1 = R.call_tir(cls.fused_fused_decode1_take1, (lv775, lv776, lv), out_sinfo=R.Tensor((n, 4096), dtype="float16"))
lv2 = R.call_tir(cls.reshape6, (lv_1,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
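# Prefill starts by flattening the (1, n) token ids, gathering their rows from the
# quantized embedding table (params[0]/params[1], dequantized inside
# fused_fused_decode1_take1), and reshaping to a (1, n, 4096) hidden-state tensor.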
lv778 = R.call_tir(cls.fused_min_max_triu_te_broadcast_to, R.tuple(), out_sinfo=R.Tensor((1, 1, n, n), dtype="float16"), tir_vars=R.shape([n]))
lv5 = R.call_tir(cls.extend_te, (lv778,), out_sinfo=R.Tensor((1, 1, n, m), dtype="float16"))
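# Attention mask for prefill: fused_min_max_triu_te_broadcast_to builds a (1, 1, n, n)
# causal mask for the n new tokens, and extend_te widens it to (1, 1, n, m) so the
# m - n already-cached positions are fully visible (the fp16 maximum, 65504, appears to
# be used as the "no mask" value, which the min/max clamp in the attention kernel
# passes through unchanged).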
lv3: R.Tensor((4096,), dtype="float16") = params[10]
lv6 = R.call_tir(cls.rms_norm, (lv2, lv3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv779: R.Tensor((12288, 512), dtype="uint32") = params[2]
lv780: R.Tensor((12288, 128), dtype="float16") = params[3]
lv65 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv779, lv780, lv6), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv9 = R.call_tir(cls.split1, (lv65,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv10: R.Tensor((1, n, 4096), dtype="float16") = lv9[0]
lv11 = R.call_tir(cls.reshape7, (lv10,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv12: R.Tensor((1, n, 4096), dtype="float16") = lv9[1]
lv13 = R.call_tir(cls.reshape7, (lv12,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv14: R.Tensor((1, n, 4096), dtype="float16") = lv9[2]
lv15 = R.call_tir(cls.reshape7, (lv14,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv7: R.Tensor((2048, 128), dtype="float16") = params[325]
lv8: R.Tensor((2048, 128), dtype="float16") = params[326]
lv16 = R.call_tir(cls.rotary_embedding, (lv11, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv17 = R.call_tir(cls.rotary_embedding, (lv13, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
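# Rotary position embedding applied to the query and key streams. params[325] and
# params[326] (each (2048, 128)) look like precomputed cos/sin tables sized for the
# 2048-token window; tir_vars=[m] passes the total sequence length, presumably so the
# n new tokens land at absolute positions m - n .. m - 1.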
lv18 = R.call_tir(cls.squeeze1, (lv17,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv19 = R.call_tir(cls.squeeze1, (lv15,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv20: R.Object = kv_cache[0]
lv21: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv20, lv18, sinfo_args=(R.Object,))
lv22: R.Object = kv_cache[1]
lv23: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv22, lv19, sinfo_args=(R.Object,))
lv24: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv21, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv25: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv23, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv26 = R.call_tir(cls.reshape3, (lv24,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv27 = R.call_tir(cls.reshape3, (lv25,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv28 = R.call_tir(cls.transpose6, (lv16,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv29 = R.call_tir(cls.transpose6, (lv26,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv30 = R.call_tir(cls.transpose6, (lv27,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
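# KV-cache handling during prefill: the rotated keys and the values for the n new
# tokens are squeezed to (n, 32, 128) and appended to kv_cache[0]/[1]; the views then
# return the whole m-token cache (prefix plus new tokens), reshaped and transposed to
# (1, 32, m, 128) for attention.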
lv782 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv28, lv29, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv783 = R.call_tir(cls.fused_softmax1_cast4, (lv782,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv39 = R.call_tir(cls.matmul10, (lv783, lv30), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv40 = R.call_tir(cls.transpose8, (lv39,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv41 = R.call_tir(cls.reshape8, (lv40,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
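# Prefill attention: scores of shape (1, 32, n, m) are computed in fp32 (scaled,
# clamped, and masked by lv5 inside the fused kernel), softmaxed, cast back to fp16,
# multiplied with the values, then transposed and flattened back to (1, n, 4096) ahead
# of the output projection.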
lv784: R.Tensor((4096, 512), dtype="uint32") = params[4]
lv785: R.Tensor((4096, 128), dtype="float16") = params[5]
lv64 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv784, lv785, lv41, lv2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv12_1: R.Tensor((4096,), dtype="float16") = params[11]
lv45 = R.call_tir(cls.rms_norm, (lv64, lv12_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv788: R.Tensor((22016, 512), dtype="uint32") = params[6]
lv789: R.Tensor((22016, 128), dtype="float16") = params[7]
lv66 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv788, lv789, lv45), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv791 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv66,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv792: R.Tensor((4096, 1376), dtype="uint32") = params[8]
lv793: R.Tensor((4096, 344), dtype="float16") = params[9]
lv65_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv792, lv793, lv791, lv64), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
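# First prefill layer complete (lv65_1 is the residual stream after attention and the
# SwiGLU MLP). The remaining layers below repeat this structure, using kv_cache[2k] and
# kv_cache[2k + 1] for layer k and the next group of quantized params.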
lv19_1: R.Tensor((4096,), dtype="float16") = params[20]
lv56 = R.call_tir(cls.rms_norm, (lv65_1, lv19_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv796: R.Tensor((12288, 512), dtype="uint32") = params[12]
lv797: R.Tensor((12288, 128), dtype="float16") = params[13]
lv67 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv796, lv797, lv56), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv59 = R.call_tir(cls.split1, (lv67,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv60: R.Tensor((1, n, 4096), dtype="float16") = lv59[0]
lv61 = R.call_tir(cls.reshape7, (lv60,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv62: R.Tensor((1, n, 4096), dtype="float16") = lv59[1]
lv63 = R.call_tir(cls.reshape7, (lv62,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv64_1: R.Tensor((1, n, 4096), dtype="float16") = lv59[2]
lv65_2 = R.call_tir(cls.reshape7, (lv64_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv66_1 = R.call_tir(cls.rotary_embedding, (lv61, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv67_1 = R.call_tir(cls.rotary_embedding, (lv63, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv68 = R.call_tir(cls.squeeze1, (lv67_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv69 = R.call_tir(cls.squeeze1, (lv65_2,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv70: R.Object = kv_cache[2]
lv71: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv70, lv68, sinfo_args=(R.Object,))
lv72: R.Object = kv_cache[3]
lv73: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv72, lv69, sinfo_args=(R.Object,))
lv74: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv71, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv75: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv73, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv76 = R.call_tir(cls.reshape3, (lv74,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv77 = R.call_tir(cls.reshape3, (lv75,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv78 = R.call_tir(cls.transpose6, (lv66_1,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv79 = R.call_tir(cls.transpose6, (lv76,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv80 = R.call_tir(cls.transpose6, (lv77,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv799 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv78, lv79, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv800 = R.call_tir(cls.fused_softmax1_cast4, (lv799,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv89 = R.call_tir(cls.matmul10, (lv800, lv80), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv90 = R.call_tir(cls.transpose8, (lv89,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv91 = R.call_tir(cls.reshape8, (lv90,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
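# Attention output projection (quantized, params[14..15]) with a residual add onto lv65_1.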
lv801: R.Tensor((4096, 512), dtype="uint32") = params[14]
lv802: R.Tensor((4096, 128), dtype="float16") = params[15]
lv66_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv801, lv802, lv91, lv65_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv26_1: R.Tensor((4096,), dtype="float16") = params[21]
lv95 = R.call_tir(cls.rms_norm, (lv66_2, lv26_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
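# SwiGLU feed-forward block (params[16..19]) on the normalized hidden state lv95, with a residual add onto lv66_2.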
lv805: R.Tensor((22016, 512), dtype="uint32") = params[16]
lv806: R.Tensor((22016, 128), dtype="float16") = params[17]
lv68_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv805, lv806, lv95), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv808 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv68_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv809: R.Tensor((4096, 1376), dtype="uint32") = params[18]
lv810: R.Tensor((4096, 344), dtype="float16") = params[19]
lv67_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv809, lv810, lv808, lv66_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
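# Next decoder layer (weights params[22..29], norm weights params[30..31], KV cache slots 4-5); same attention + SwiGLU pattern as above.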
lv33: R.Tensor((4096,), dtype="float16") = params[30]
lv106 = R.call_tir(cls.rms_norm, (lv67_2, lv33), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv813: R.Tensor((12288, 512), dtype="uint32") = params[22]
lv814: R.Tensor((12288, 128), dtype="float16") = params[23]
lv69_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv813, lv814, lv106), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv109 = R.call_tir(cls.split1, (lv69_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv110: R.Tensor((1, n, 4096), dtype="float16") = lv109[0]
lv111 = R.call_tir(cls.reshape7, (lv110,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv112: R.Tensor((1, n, 4096), dtype="float16") = lv109[1]
lv113 = R.call_tir(cls.reshape7, (lv112,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv114: R.Tensor((1, n, 4096), dtype="float16") = lv109[2]
lv115 = R.call_tir(cls.reshape7, (lv114,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv116 = R.call_tir(cls.rotary_embedding, (lv111, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv117 = R.call_tir(cls.rotary_embedding, (lv113, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv118 = R.call_tir(cls.squeeze1, (lv117,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv119 = R.call_tir(cls.squeeze1, (lv115,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv120: R.Object = kv_cache[4]
lv121: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv120, lv118, sinfo_args=(R.Object,))
lv122: R.Object = kv_cache[5]
lv123: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv122, lv119, sinfo_args=(R.Object,))
lv124: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv121, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv125: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv123, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv126 = R.call_tir(cls.reshape3, (lv124,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv127 = R.call_tir(cls.reshape3, (lv125,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv128 = R.call_tir(cls.transpose6, (lv116,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv129 = R.call_tir(cls.transpose6, (lv126,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv130 = R.call_tir(cls.transpose6, (lv127,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv816 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv128, lv129, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv817 = R.call_tir(cls.fused_softmax1_cast4, (lv816,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv139 = R.call_tir(cls.matmul10, (lv817, lv130), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv140 = R.call_tir(cls.transpose8, (lv139,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv141 = R.call_tir(cls.reshape8, (lv140,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv818: R.Tensor((4096, 512), dtype="uint32") = params[24]
lv819: R.Tensor((4096, 128), dtype="float16") = params[25]
lv68_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv818, lv819, lv141, lv67_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv40_1: R.Tensor((4096,), dtype="float16") = params[31]
lv145 = R.call_tir(cls.rms_norm, (lv68_2, lv40_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv822: R.Tensor((22016, 512), dtype="uint32") = params[26]
lv823: R.Tensor((22016, 128), dtype="float16") = params[27]
lv70_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv822, lv823, lv145), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv825 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv70_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv826: R.Tensor((4096, 1376), dtype="uint32") = params[28]
lv827: R.Tensor((4096, 344), dtype="float16") = params[29]
lv69_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv826, lv827, lv825, lv68_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
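# Next decoder layer (weights params[32..39], norm weights params[40..41], KV cache slots 6-7); same attention + SwiGLU pattern as above.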
lv47: R.Tensor((4096,), dtype="float16") = params[40]
lv156 = R.call_tir(cls.rms_norm, (lv69_2, lv47), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv830: R.Tensor((12288, 512), dtype="uint32") = params[32]
lv831: R.Tensor((12288, 128), dtype="float16") = params[33]
lv71_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv830, lv831, lv156), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv159 = R.call_tir(cls.split1, (lv71_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv160: R.Tensor((1, n, 4096), dtype="float16") = lv159[0]
lv161 = R.call_tir(cls.reshape7, (lv160,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv162: R.Tensor((1, n, 4096), dtype="float16") = lv159[1]
lv163 = R.call_tir(cls.reshape7, (lv162,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv164: R.Tensor((1, n, 4096), dtype="float16") = lv159[2]
lv165 = R.call_tir(cls.reshape7, (lv164,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv166 = R.call_tir(cls.rotary_embedding, (lv161, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv167 = R.call_tir(cls.rotary_embedding, (lv163, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv168 = R.call_tir(cls.squeeze1, (lv167,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv169 = R.call_tir(cls.squeeze1, (lv165,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv170: R.Object = kv_cache[6]
lv171: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv170, lv168, sinfo_args=(R.Object,))
lv172: R.Object = kv_cache[7]
lv173: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv172, lv169, sinfo_args=(R.Object,))
lv174: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv171, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv175: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv173, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv176 = R.call_tir(cls.reshape3, (lv174,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv177 = R.call_tir(cls.reshape3, (lv175,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv178 = R.call_tir(cls.transpose6, (lv166,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv179 = R.call_tir(cls.transpose6, (lv176,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv180 = R.call_tir(cls.transpose6, (lv177,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv833 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv178, lv179, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv834 = R.call_tir(cls.fused_softmax1_cast4, (lv833,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv189 = R.call_tir(cls.matmul10, (lv834, lv180), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv190 = R.call_tir(cls.transpose8, (lv189,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv191 = R.call_tir(cls.reshape8, (lv190,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv835: R.Tensor((4096, 512), dtype="uint32") = params[34]
lv836: R.Tensor((4096, 128), dtype="float16") = params[35]
lv70_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv835, lv836, lv191, lv69_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv54: R.Tensor((4096,), dtype="float16") = params[41]
lv195 = R.call_tir(cls.rms_norm, (lv70_2, lv54), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv839: R.Tensor((22016, 512), dtype="uint32") = params[36]
lv840: R.Tensor((22016, 128), dtype="float16") = params[37]
lv72_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv839, lv840, lv195), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv842 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv72_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv843: R.Tensor((4096, 1376), dtype="uint32") = params[38]
lv844: R.Tensor((4096, 344), dtype="float16") = params[39]
lv71_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv843, lv844, lv842, lv70_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
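# Next decoder layer (weights params[42..49], norm weights params[50..51], KV cache slots 8-9); same attention + SwiGLU pattern as above.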
lv61_1: R.Tensor((4096,), dtype="float16") = params[50]
lv206 = R.call_tir(cls.rms_norm, (lv71_2, lv61_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv847: R.Tensor((12288, 512), dtype="uint32") = params[42]
lv848: R.Tensor((12288, 128), dtype="float16") = params[43]
lv73_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv847, lv848, lv206), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv209 = R.call_tir(cls.split1, (lv73_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv210: R.Tensor((1, n, 4096), dtype="float16") = lv209[0]
lv211 = R.call_tir(cls.reshape7, (lv210,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv212: R.Tensor((1, n, 4096), dtype="float16") = lv209[1]
lv213 = R.call_tir(cls.reshape7, (lv212,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv214: R.Tensor((1, n, 4096), dtype="float16") = lv209[2]
lv215 = R.call_tir(cls.reshape7, (lv214,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv216 = R.call_tir(cls.rotary_embedding, (lv211, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv217 = R.call_tir(cls.rotary_embedding, (lv213, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv218 = R.call_tir(cls.squeeze1, (lv217,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv219 = R.call_tir(cls.squeeze1, (lv215,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv220: R.Object = kv_cache[8]
lv221: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv220, lv218, sinfo_args=(R.Object,))
lv222: R.Object = kv_cache[9]
lv223: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv222, lv219, sinfo_args=(R.Object,))
lv224: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv221, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv225: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv223, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv226 = R.call_tir(cls.reshape3, (lv224,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv227 = R.call_tir(cls.reshape3, (lv225,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv228 = R.call_tir(cls.transpose6, (lv216,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv229 = R.call_tir(cls.transpose6, (lv226,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv230 = R.call_tir(cls.transpose6, (lv227,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv850 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv228, lv229, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv851 = R.call_tir(cls.fused_softmax1_cast4, (lv850,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv239 = R.call_tir(cls.matmul10, (lv851, lv230), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv240 = R.call_tir(cls.transpose8, (lv239,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv241 = R.call_tir(cls.reshape8, (lv240,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv852: R.Tensor((4096, 512), dtype="uint32") = params[44]
lv853: R.Tensor((4096, 128), dtype="float16") = params[45]
lv72_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv852, lv853, lv241, lv71_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv68_3: R.Tensor((4096,), dtype="float16") = params[51]
lv245 = R.call_tir(cls.rms_norm, (lv72_2, lv68_3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv856: R.Tensor((22016, 512), dtype="uint32") = params[46]
lv857: R.Tensor((22016, 128), dtype="float16") = params[47]
lv74_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv856, lv857, lv245), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv859 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv74_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv860: R.Tensor((4096, 1376), dtype="uint32") = params[48]
lv861: R.Tensor((4096, 344), dtype="float16") = params[49]
lv73_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv860, lv861, lv859, lv72_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
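# Next decoder layer (weights params[52..59], norm weights params[60..61], KV cache slots 10-11); same attention + SwiGLU pattern as above.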
lv75_1: R.Tensor((4096,), dtype="float16") = params[60]
lv256 = R.call_tir(cls.rms_norm, (lv73_2, lv75_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv864: R.Tensor((12288, 512), dtype="uint32") = params[52]
lv865: R.Tensor((12288, 128), dtype="float16") = params[53]
lv75_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv864, lv865, lv256), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv259 = R.call_tir(cls.split1, (lv75_2,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv260: R.Tensor((1, n, 4096), dtype="float16") = lv259[0]
lv261 = R.call_tir(cls.reshape7, (lv260,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv262: R.Tensor((1, n, 4096), dtype="float16") = lv259[1]
lv263 = R.call_tir(cls.reshape7, (lv262,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv264: R.Tensor((1, n, 4096), dtype="float16") = lv259[2]
lv265 = R.call_tir(cls.reshape7, (lv264,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv266 = R.call_tir(cls.rotary_embedding, (lv261, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv267 = R.call_tir(cls.rotary_embedding, (lv263, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv268 = R.call_tir(cls.squeeze1, (lv267,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv269 = R.call_tir(cls.squeeze1, (lv265,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv270: R.Object = kv_cache[10]
lv271: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv270, lv268, sinfo_args=(R.Object,))
lv272: R.Object = kv_cache[11]
lv273: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv272, lv269, sinfo_args=(R.Object,))
lv274: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv271, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv275: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv273, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv276 = R.call_tir(cls.reshape3, (lv274,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv277 = R.call_tir(cls.reshape3, (lv275,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv278 = R.call_tir(cls.transpose6, (lv266,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv279 = R.call_tir(cls.transpose6, (lv276,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv280 = R.call_tir(cls.transpose6, (lv277,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv867 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv278, lv279, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv868 = R.call_tir(cls.fused_softmax1_cast4, (lv867,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv289 = R.call_tir(cls.matmul10, (lv868, lv280), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv290 = R.call_tir(cls.transpose8, (lv289,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv291 = R.call_tir(cls.reshape8, (lv290,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv869: R.Tensor((4096, 512), dtype="uint32") = params[54]
lv870: R.Tensor((4096, 128), dtype="float16") = params[55]
lv74_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv869, lv870, lv291, lv73_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv82: R.Tensor((4096,), dtype="float16") = params[61]
lv295 = R.call_tir(cls.rms_norm, (lv74_2, lv82), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv873: R.Tensor((22016, 512), dtype="uint32") = params[56]
lv874: R.Tensor((22016, 128), dtype="float16") = params[57]
lv76_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv873, lv874, lv295), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv876 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv76_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv877: R.Tensor((4096, 1376), dtype="uint32") = params[58]
lv878: R.Tensor((4096, 344), dtype="float16") = params[59]
lv75_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv877, lv878, lv876, lv74_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
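# Next decoder layer (weights params[62..69], norm weights params[70..71], KV cache slots 12-13); same attention + SwiGLU pattern as above.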
lv89_1: R.Tensor((4096,), dtype="float16") = params[70]
lv306 = R.call_tir(cls.rms_norm, (lv75_3, lv89_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv881: R.Tensor((12288, 512), dtype="uint32") = params[62]
lv882: R.Tensor((12288, 128), dtype="float16") = params[63]
lv77_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv881, lv882, lv306), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv309 = R.call_tir(cls.split1, (lv77_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv310: R.Tensor((1, n, 4096), dtype="float16") = lv309[0]
lv311 = R.call_tir(cls.reshape7, (lv310,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv312: R.Tensor((1, n, 4096), dtype="float16") = lv309[1]
lv313 = R.call_tir(cls.reshape7, (lv312,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv314: R.Tensor((1, n, 4096), dtype="float16") = lv309[2]
lv315 = R.call_tir(cls.reshape7, (lv314,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv316 = R.call_tir(cls.rotary_embedding, (lv311, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv317 = R.call_tir(cls.rotary_embedding, (lv313, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv318 = R.call_tir(cls.squeeze1, (lv317,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv319 = R.call_tir(cls.squeeze1, (lv315,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv320: R.Object = kv_cache[12]
lv321: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv320, lv318, sinfo_args=(R.Object,))
lv322: R.Object = kv_cache[13]
lv323: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv322, lv319, sinfo_args=(R.Object,))
lv324: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv321, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv325: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv323, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv326 = R.call_tir(cls.reshape3, (lv324,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv327 = R.call_tir(cls.reshape3, (lv325,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv328 = R.call_tir(cls.transpose6, (lv316,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv329 = R.call_tir(cls.transpose6, (lv326,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv330 = R.call_tir(cls.transpose6, (lv327,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv884 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv328, lv329, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv885 = R.call_tir(cls.fused_softmax1_cast4, (lv884,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv339 = R.call_tir(cls.matmul10, (lv885, lv330), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv340 = R.call_tir(cls.transpose8, (lv339,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv341 = R.call_tir(cls.reshape8, (lv340,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv886: R.Tensor((4096, 512), dtype="uint32") = params[64]
lv887: R.Tensor((4096, 128), dtype="float16") = params[65]
lv76_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv886, lv887, lv341, lv75_3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv96: R.Tensor((4096,), dtype="float16") = params[71]
lv345 = R.call_tir(cls.rms_norm, (lv76_2, lv96), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv890: R.Tensor((22016, 512), dtype="uint32") = params[66]
lv891: R.Tensor((22016, 128), dtype="float16") = params[67]
lv78_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv890, lv891, lv345), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv893 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv78_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv894: R.Tensor((4096, 1376), dtype="uint32") = params[68]
lv895: R.Tensor((4096, 344), dtype="float16") = params[69]
lv77_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv894, lv895, lv893, lv76_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
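# Next decoder layer (weights params[72..79], norm weights params[80..81], KV cache slots 14-15); same attention + SwiGLU pattern as above.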
lv103: R.Tensor((4096,), dtype="float16") = params[80]
lv356 = R.call_tir(cls.rms_norm, (lv77_2, lv103), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv898: R.Tensor((12288, 512), dtype="uint32") = params[72]
lv899: R.Tensor((12288, 128), dtype="float16") = params[73]
lv79_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv898, lv899, lv356), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv359 = R.call_tir(cls.split1, (lv79_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv360: R.Tensor((1, n, 4096), dtype="float16") = lv359[0]
lv361 = R.call_tir(cls.reshape7, (lv360,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv362: R.Tensor((1, n, 4096), dtype="float16") = lv359[1]
lv363 = R.call_tir(cls.reshape7, (lv362,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv364: R.Tensor((1, n, 4096), dtype="float16") = lv359[2]
lv365 = R.call_tir(cls.reshape7, (lv364,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv366 = R.call_tir(cls.rotary_embedding, (lv361, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv367 = R.call_tir(cls.rotary_embedding, (lv363, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv368 = R.call_tir(cls.squeeze1, (lv367,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv369 = R.call_tir(cls.squeeze1, (lv365,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv370: R.Object = kv_cache[14]
lv371: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv370, lv368, sinfo_args=(R.Object,))
lv372: R.Object = kv_cache[15]
lv373: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv372, lv369, sinfo_args=(R.Object,))
lv374: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv371, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv375: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv373, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv376 = R.call_tir(cls.reshape3, (lv374,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv377 = R.call_tir(cls.reshape3, (lv375,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv378 = R.call_tir(cls.transpose6, (lv366,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv379 = R.call_tir(cls.transpose6, (lv376,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv380 = R.call_tir(cls.transpose6, (lv377,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv901 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv378, lv379, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv902 = R.call_tir(cls.fused_softmax1_cast4, (lv901,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv389 = R.call_tir(cls.matmul10, (lv902, lv380), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv390 = R.call_tir(cls.transpose8, (lv389,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv391 = R.call_tir(cls.reshape8, (lv390,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv903: R.Tensor((4096, 512), dtype="uint32") = params[74]
lv904: R.Tensor((4096, 128), dtype="float16") = params[75]
lv78_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv903, lv904, lv391, lv77_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv110_1: R.Tensor((4096,), dtype="float16") = params[81]
lv395 = R.call_tir(cls.rms_norm, (lv78_2, lv110_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv907: R.Tensor((22016, 512), dtype="uint32") = params[76]
lv908: R.Tensor((22016, 128), dtype="float16") = params[77]
lv80_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv907, lv908, lv395), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv910 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv80_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv911: R.Tensor((4096, 1376), dtype="uint32") = params[78]
lv912: R.Tensor((4096, 344), dtype="float16") = params[79]
lv79_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv911, lv912, lv910, lv78_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
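# Next decoder layer (weights params[82..89], norm weights params[90..91], KV cache slots 16-17); same attention + SwiGLU pattern as above.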
lv117_1: R.Tensor((4096,), dtype="float16") = params[90]
lv406 = R.call_tir(cls.rms_norm, (lv79_2, lv117_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv915: R.Tensor((12288, 512), dtype="uint32") = params[82]
lv916: R.Tensor((12288, 128), dtype="float16") = params[83]
lv81 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv915, lv916, lv406), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv409 = R.call_tir(cls.split1, (lv81,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv410: R.Tensor((1, n, 4096), dtype="float16") = lv409[0]
lv411 = R.call_tir(cls.reshape7, (lv410,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv412: R.Tensor((1, n, 4096), dtype="float16") = lv409[1]
lv413 = R.call_tir(cls.reshape7, (lv412,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv414: R.Tensor((1, n, 4096), dtype="float16") = lv409[2]
lv415 = R.call_tir(cls.reshape7, (lv414,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv416 = R.call_tir(cls.rotary_embedding, (lv411, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv417 = R.call_tir(cls.rotary_embedding, (lv413, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv418 = R.call_tir(cls.squeeze1, (lv417,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv419 = R.call_tir(cls.squeeze1, (lv415,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv420: R.Object = kv_cache[16]
lv421: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv420, lv418, sinfo_args=(R.Object,))
lv422: R.Object = kv_cache[17]
lv423: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv422, lv419, sinfo_args=(R.Object,))
lv424: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv421, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv425: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv423, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv426 = R.call_tir(cls.reshape3, (lv424,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv427 = R.call_tir(cls.reshape3, (lv425,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv428 = R.call_tir(cls.transpose6, (lv416,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv429 = R.call_tir(cls.transpose6, (lv426,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv430 = R.call_tir(cls.transpose6, (lv427,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv918 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv428, lv429, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv919 = R.call_tir(cls.fused_softmax1_cast4, (lv918,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv439 = R.call_tir(cls.matmul10, (lv919, lv430), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv440 = R.call_tir(cls.transpose8, (lv439,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv441 = R.call_tir(cls.reshape8, (lv440,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv920: R.Tensor((4096, 512), dtype="uint32") = params[84]
lv921: R.Tensor((4096, 128), dtype="float16") = params[85]
lv80_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv920, lv921, lv441, lv79_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv124_1: R.Tensor((4096,), dtype="float16") = params[91]
lv445 = R.call_tir(cls.rms_norm, (lv80_2, lv124_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv924: R.Tensor((22016, 512), dtype="uint32") = params[86]
lv925: R.Tensor((22016, 128), dtype="float16") = params[87]
lv82_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv924, lv925, lv445), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv927 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv82_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv928: R.Tensor((4096, 1376), dtype="uint32") = params[88]
lv929: R.Tensor((4096, 344), dtype="float16") = params[89]
lv81_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv928, lv929, lv927, lv80_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
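# Next decoder layer (weights params[92..99], norm weights params[100..101], KV cache slots 18-19); same attention + SwiGLU pattern as above.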
lv131: R.Tensor((4096,), dtype="float16") = params[100]
lv456 = R.call_tir(cls.rms_norm, (lv81_1, lv131), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv932: R.Tensor((12288, 512), dtype="uint32") = params[92]
lv933: R.Tensor((12288, 128), dtype="float16") = params[93]
lv83 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv932, lv933, lv456), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv459 = R.call_tir(cls.split1, (lv83,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv460: R.Tensor((1, n, 4096), dtype="float16") = lv459[0]
lv461 = R.call_tir(cls.reshape7, (lv460,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv462: R.Tensor((1, n, 4096), dtype="float16") = lv459[1]
lv463 = R.call_tir(cls.reshape7, (lv462,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv464: R.Tensor((1, n, 4096), dtype="float16") = lv459[2]
lv465 = R.call_tir(cls.reshape7, (lv464,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv466 = R.call_tir(cls.rotary_embedding, (lv461, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv467 = R.call_tir(cls.rotary_embedding, (lv463, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv468 = R.call_tir(cls.squeeze1, (lv467,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv469 = R.call_tir(cls.squeeze1, (lv465,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv470: R.Object = kv_cache[18]
lv471: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv470, lv468, sinfo_args=(R.Object,))
lv472: R.Object = kv_cache[19]
lv473: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv472, lv469, sinfo_args=(R.Object,))
lv474: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv471, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv475: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv473, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv476 = R.call_tir(cls.reshape3, (lv474,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv477 = R.call_tir(cls.reshape3, (lv475,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv478 = R.call_tir(cls.transpose6, (lv466,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv479 = R.call_tir(cls.transpose6, (lv476,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv480 = R.call_tir(cls.transpose6, (lv477,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv935 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv478, lv479, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv936 = R.call_tir(cls.fused_softmax1_cast4, (lv935,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv489 = R.call_tir(cls.matmul10, (lv936, lv480), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv490 = R.call_tir(cls.transpose8, (lv489,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv491 = R.call_tir(cls.reshape8, (lv490,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv937: R.Tensor((4096, 512), dtype="uint32") = params[94]
lv938: R.Tensor((4096, 128), dtype="float16") = params[95]
lv82_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv937, lv938, lv491, lv81_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv138: R.Tensor((4096,), dtype="float16") = params[101]
lv495 = R.call_tir(cls.rms_norm, (lv82_2, lv138), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv941: R.Tensor((22016, 512), dtype="uint32") = params[96]
lv942: R.Tensor((22016, 128), dtype="float16") = params[97]
lv84 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv941, lv942, lv495), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv944 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv84,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv945: R.Tensor((4096, 1376), dtype="uint32") = params[98]
lv946: R.Tensor((4096, 344), dtype="float16") = params[99]
lv83_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv945, lv946, lv944, lv82_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
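# Next decoder layer (weights params[102..109], norm weights params[110..111], KV cache slots 20-21); same attention + SwiGLU pattern as above.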
lv145_1: R.Tensor((4096,), dtype="float16") = params[110]
lv506 = R.call_tir(cls.rms_norm, (lv83_1, lv145_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv949: R.Tensor((12288, 512), dtype="uint32") = params[102]
lv950: R.Tensor((12288, 128), dtype="float16") = params[103]
lv85 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv949, lv950, lv506), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv509 = R.call_tir(cls.split1, (lv85,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv510: R.Tensor((1, n, 4096), dtype="float16") = lv509[0]
lv511 = R.call_tir(cls.reshape7, (lv510,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv512: R.Tensor((1, n, 4096), dtype="float16") = lv509[1]
lv513 = R.call_tir(cls.reshape7, (lv512,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv514: R.Tensor((1, n, 4096), dtype="float16") = lv509[2]
lv515 = R.call_tir(cls.reshape7, (lv514,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv516 = R.call_tir(cls.rotary_embedding, (lv511, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv517 = R.call_tir(cls.rotary_embedding, (lv513, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv518 = R.call_tir(cls.squeeze1, (lv517,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv519 = R.call_tir(cls.squeeze1, (lv515,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv520: R.Object = kv_cache[20]
lv521: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv520, lv518, sinfo_args=(R.Object,))
lv522: R.Object = kv_cache[21]
lv523: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv522, lv519, sinfo_args=(R.Object,))
lv524: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv521, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv525: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv523, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv526 = R.call_tir(cls.reshape3, (lv524,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv527 = R.call_tir(cls.reshape3, (lv525,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv528 = R.call_tir(cls.transpose6, (lv516,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv529 = R.call_tir(cls.transpose6, (lv526,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv530 = R.call_tir(cls.transpose6, (lv527,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv952 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv528, lv529, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv953 = R.call_tir(cls.fused_softmax1_cast4, (lv952,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv539 = R.call_tir(cls.matmul10, (lv953, lv530), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv540 = R.call_tir(cls.transpose8, (lv539,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv541 = R.call_tir(cls.reshape8, (lv540,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv954: R.Tensor((4096, 512), dtype="uint32") = params[104]
lv955: R.Tensor((4096, 128), dtype="float16") = params[105]
lv84_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv954, lv955, lv541, lv83_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv152: R.Tensor((4096,), dtype="float16") = params[111]
lv545 = R.call_tir(cls.rms_norm, (lv84_1, lv152), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv958: R.Tensor((22016, 512), dtype="uint32") = params[106]
lv959: R.Tensor((22016, 128), dtype="float16") = params[107]
lv86 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv958, lv959, lv545), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv961 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv86,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv962: R.Tensor((4096, 1376), dtype="uint32") = params[108]
lv963: R.Tensor((4096, 344), dtype="float16") = params[109]
lv85_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv962, lv963, lv961, lv84_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
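# Next decoder layer (weights params[112..119], norm weights params[120..121], KV cache slots 22-23); same attention + SwiGLU pattern as above.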
lv159_1: R.Tensor((4096,), dtype="float16") = params[120]
lv556 = R.call_tir(cls.rms_norm, (lv85_1, lv159_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv966: R.Tensor((12288, 512), dtype="uint32") = params[112]
lv967: R.Tensor((12288, 128), dtype="float16") = params[113]
lv87 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv966, lv967, lv556), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv559 = R.call_tir(cls.split1, (lv87,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv560: R.Tensor((1, n, 4096), dtype="float16") = lv559[0]
lv561 = R.call_tir(cls.reshape7, (lv560,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv562: R.Tensor((1, n, 4096), dtype="float16") = lv559[1]
lv563 = R.call_tir(cls.reshape7, (lv562,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv564: R.Tensor((1, n, 4096), dtype="float16") = lv559[2]
lv565 = R.call_tir(cls.reshape7, (lv564,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv566 = R.call_tir(cls.rotary_embedding, (lv561, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv567 = R.call_tir(cls.rotary_embedding, (lv563, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv568 = R.call_tir(cls.squeeze1, (lv567,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv569 = R.call_tir(cls.squeeze1, (lv565,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv570: R.Object = kv_cache[22]
lv571: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv570, lv568, sinfo_args=(R.Object,))
lv572: R.Object = kv_cache[23]
lv573: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv572, lv569, sinfo_args=(R.Object,))
lv574: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv571, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv575: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv573, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv576 = R.call_tir(cls.reshape3, (lv574,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv577 = R.call_tir(cls.reshape3, (lv575,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv578 = R.call_tir(cls.transpose6, (lv566,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv579 = R.call_tir(cls.transpose6, (lv576,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv580 = R.call_tir(cls.transpose6, (lv577,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv969 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv578, lv579, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv970 = R.call_tir(cls.fused_softmax1_cast4, (lv969,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv589 = R.call_tir(cls.matmul10, (lv970, lv580), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv590 = R.call_tir(cls.transpose8, (lv589,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv591 = R.call_tir(cls.reshape8, (lv590,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv971: R.Tensor((4096, 512), dtype="uint32") = params[114]
lv972: R.Tensor((4096, 128), dtype="float16") = params[115]
lv86_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv971, lv972, lv591, lv85_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv166_1: R.Tensor((4096,), dtype="float16") = params[121]
lv595 = R.call_tir(cls.rms_norm, (lv86_1, lv166_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv975: R.Tensor((22016, 512), dtype="uint32") = params[116]
lv976: R.Tensor((22016, 128), dtype="float16") = params[117]
lv88 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv975, lv976, lv595), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv978 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv88,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv979: R.Tensor((4096, 1376), dtype="uint32") = params[118]
lv980: R.Tensor((4096, 344), dtype="float16") = params[119]
lv87_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv979, lv980, lv978, lv86_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
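# Next decoder layer (weights params[122..129], norm weights params[130..131], KV cache slots 24-25); same attention + SwiGLU pattern as above.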
lv173_1: R.Tensor((4096,), dtype="float16") = params[130]
lv606 = R.call_tir(cls.rms_norm, (lv87_1, lv173_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv983: R.Tensor((12288, 512), dtype="uint32") = params[122]
lv984: R.Tensor((12288, 128), dtype="float16") = params[123]
lv89_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv983, lv984, lv606), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv609 = R.call_tir(cls.split1, (lv89_2,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv610: R.Tensor((1, n, 4096), dtype="float16") = lv609[0]
lv611 = R.call_tir(cls.reshape7, (lv610,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv612: R.Tensor((1, n, 4096), dtype="float16") = lv609[1]
lv613 = R.call_tir(cls.reshape7, (lv612,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv614: R.Tensor((1, n, 4096), dtype="float16") = lv609[2]
lv615 = R.call_tir(cls.reshape7, (lv614,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv616 = R.call_tir(cls.rotary_embedding, (lv611, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv617 = R.call_tir(cls.rotary_embedding, (lv613, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv618 = R.call_tir(cls.squeeze1, (lv617,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv619 = R.call_tir(cls.squeeze1, (lv615,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv620: R.Object = kv_cache[24]
lv621: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv620, lv618, sinfo_args=(R.Object,))
lv622: R.Object = kv_cache[25]
lv623: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv622, lv619, sinfo_args=(R.Object,))
lv624: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv621, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv625: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv623, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv626 = R.call_tir(cls.reshape3, (lv624,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv627 = R.call_tir(cls.reshape3, (lv625,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv628 = R.call_tir(cls.transpose6, (lv616,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv629 = R.call_tir(cls.transpose6, (lv626,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv630 = R.call_tir(cls.transpose6, (lv627,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv986 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv628, lv629, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv987 = R.call_tir(cls.fused_softmax1_cast4, (lv986,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv639 = R.call_tir(cls.matmul10, (lv987, lv630), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv640 = R.call_tir(cls.transpose8, (lv639,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv641 = R.call_tir(cls.reshape8, (lv640,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv988: R.Tensor((4096, 512), dtype="uint32") = params[124]
lv989: R.Tensor((4096, 128), dtype="float16") = params[125]
lv88_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv988, lv989, lv641, lv87_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv180_1: R.Tensor((4096,), dtype="float16") = params[131]
lv645 = R.call_tir(cls.rms_norm, (lv88_1, lv180_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv992: R.Tensor((22016, 512), dtype="uint32") = params[126]
lv993: R.Tensor((22016, 128), dtype="float16") = params[127]
lv90_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv992, lv993, lv645), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv995 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv90_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv996: R.Tensor((4096, 1376), dtype="uint32") = params[128]
lv997: R.Tensor((4096, 344), dtype="float16") = params[129]
lv89_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv996, lv997, lv995, lv88_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
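# Next decoder layer (weights params[132..139], norm weights params[140..141], KV cache slots 26-27); same attention + SwiGLU pattern as above.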
lv187: R.Tensor((4096,), dtype="float16") = params[140]
lv656 = R.call_tir(cls.rms_norm, (lv89_3, lv187), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1000: R.Tensor((12288, 512), dtype="uint32") = params[132]
lv1001: R.Tensor((12288, 128), dtype="float16") = params[133]
lv91_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1000, lv1001, lv656), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv659 = R.call_tir(cls.split1, (lv91_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv660: R.Tensor((1, n, 4096), dtype="float16") = lv659[0]
lv661 = R.call_tir(cls.reshape7, (lv660,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv662: R.Tensor((1, n, 4096), dtype="float16") = lv659[1]
lv663 = R.call_tir(cls.reshape7, (lv662,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv664: R.Tensor((1, n, 4096), dtype="float16") = lv659[2]
lv665 = R.call_tir(cls.reshape7, (lv664,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv666 = R.call_tir(cls.rotary_embedding, (lv661, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv667 = R.call_tir(cls.rotary_embedding, (lv663, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv668 = R.call_tir(cls.squeeze1, (lv667,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv669 = R.call_tir(cls.squeeze1, (lv665,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv670: R.Object = kv_cache[26]
lv671: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv670, lv668, sinfo_args=(R.Object,))
lv672: R.Object = kv_cache[27]
lv673: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv672, lv669, sinfo_args=(R.Object,))
lv674: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv671, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv675: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv673, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv676 = R.call_tir(cls.reshape3, (lv674,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv677 = R.call_tir(cls.reshape3, (lv675,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv678 = R.call_tir(cls.transpose6, (lv666,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv679 = R.call_tir(cls.transpose6, (lv676,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv680 = R.call_tir(cls.transpose6, (lv677,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1003 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv678, lv679, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv1004 = R.call_tir(cls.fused_softmax1_cast4, (lv1003,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv689 = R.call_tir(cls.matmul10, (lv1004, lv680), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv690 = R.call_tir(cls.transpose8, (lv689,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv691 = R.call_tir(cls.reshape8, (lv690,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1005: R.Tensor((4096, 512), dtype="uint32") = params[134]
lv1006: R.Tensor((4096, 128), dtype="float16") = params[135]
lv90_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1005, lv1006, lv691, lv89_3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv194: R.Tensor((4096,), dtype="float16") = params[141]
lv695 = R.call_tir(cls.rms_norm, (lv90_2, lv194), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1009: R.Tensor((22016, 512), dtype="uint32") = params[136]
lv1010: R.Tensor((22016, 128), dtype="float16") = params[137]
lv92 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1009, lv1010, lv695), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv1012 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv92,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv1013: R.Tensor((4096, 1376), dtype="uint32") = params[138]
lv1014: R.Tensor((4096, 344), dtype="float16") = params[139]
lv91_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1013, lv1014, lv1012, lv90_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
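# Next decoder layer (weights params[142..149], norm weights params[150..151], KV cache slots 28-29); same attention + SwiGLU pattern as above.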
lv201: R.Tensor((4096,), dtype="float16") = params[150]
lv706 = R.call_tir(cls.rms_norm, (lv91_2, lv201), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1017: R.Tensor((12288, 512), dtype="uint32") = params[142]
lv1018: R.Tensor((12288, 128), dtype="float16") = params[143]
lv93 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1017, lv1018, lv706), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv709 = R.call_tir(cls.split1, (lv93,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv710: R.Tensor((1, n, 4096), dtype="float16") = lv709[0]
lv711 = R.call_tir(cls.reshape7, (lv710,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv712: R.Tensor((1, n, 4096), dtype="float16") = lv709[1]
lv713 = R.call_tir(cls.reshape7, (lv712,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv714: R.Tensor((1, n, 4096), dtype="float16") = lv709[2]
lv715 = R.call_tir(cls.reshape7, (lv714,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv716 = R.call_tir(cls.rotary_embedding, (lv711, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv717 = R.call_tir(cls.rotary_embedding, (lv713, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv718 = R.call_tir(cls.squeeze1, (lv717,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv719 = R.call_tir(cls.squeeze1, (lv715,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv720: R.Object = kv_cache[28]
lv721: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv720, lv718, sinfo_args=(R.Object,))
lv722: R.Object = kv_cache[29]
lv723: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv722, lv719, sinfo_args=(R.Object,))
lv724: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv721, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv725: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv723, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv726 = R.call_tir(cls.reshape3, (lv724,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv727 = R.call_tir(cls.reshape3, (lv725,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv728 = R.call_tir(cls.transpose6, (lv716,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv729 = R.call_tir(cls.transpose6, (lv726,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv730 = R.call_tir(cls.transpose6, (lv727,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1020 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv728, lv729, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv1021 = R.call_tir(cls.fused_softmax1_cast4, (lv1020,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv739 = R.call_tir(cls.matmul10, (lv1021, lv730), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv740 = R.call_tir(cls.transpose8, (lv739,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv741 = R.call_tir(cls.reshape8, (lv740,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1022: R.Tensor((4096, 512), dtype="uint32") = params[144]
lv1023: R.Tensor((4096, 128), dtype="float16") = params[145]
lv92_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1022, lv1023, lv741, lv91_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv208: R.Tensor((4096,), dtype="float16") = params[151]
lv745 = R.call_tir(cls.rms_norm, (lv92_1, lv208), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1026: R.Tensor((22016, 512), dtype="uint32") = params[146]
lv1027: R.Tensor((22016, 128), dtype="float16") = params[147]
lv94 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1026, lv1027, lv745), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv1029 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv94,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv1030: R.Tensor((4096, 1376), dtype="uint32") = params[148]
lv1031: R.Tensor((4096, 344), dtype="float16") = params[149]
lv93_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1030, lv1031, lv1029, lv92_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
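# Next decoder layer (weights params[152..159], norm weights params[160..161], KV cache slots 30-31); same attention + SwiGLU pattern as above.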
lv215_1: R.Tensor((4096,), dtype="float16") = params[160]
lv756 = R.call_tir(cls.rms_norm, (lv93_1, lv215_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1034: R.Tensor((12288, 512), dtype="uint32") = params[152]
lv1035: R.Tensor((12288, 128), dtype="float16") = params[153]
lv95_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1034, lv1035, lv756), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv759 = R.call_tir(cls.split1, (lv95_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv760: R.Tensor((1, n, 4096), dtype="float16") = lv759[0]
lv761 = R.call_tir(cls.reshape7, (lv760,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv762: R.Tensor((1, n, 4096), dtype="float16") = lv759[1]
lv763 = R.call_tir(cls.reshape7, (lv762,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv764: R.Tensor((1, n, 4096), dtype="float16") = lv759[2]
lv765 = R.call_tir(cls.reshape7, (lv764,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv766 = R.call_tir(cls.rotary_embedding, (lv761, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv767 = R.call_tir(cls.rotary_embedding, (lv763, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv768 = R.call_tir(cls.squeeze1, (lv767,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv769 = R.call_tir(cls.squeeze1, (lv765,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv770: R.Object = kv_cache[30]
lv771: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv770, lv768, sinfo_args=(R.Object,))
lv772: R.Object = kv_cache[31]
lv773: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv772, lv769, sinfo_args=(R.Object,))
lv774: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv771, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv775_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv773, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv776_1 = R.call_tir(cls.reshape3, (lv774,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv777 = R.call_tir(cls.reshape3, (lv775_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv778_1 = R.call_tir(cls.transpose6, (lv766,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv779_1 = R.call_tir(cls.transpose6, (lv776_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv780_1 = R.call_tir(cls.transpose6, (lv777,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1037 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv778_1, lv779_1, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv1038 = R.call_tir(cls.fused_softmax1_cast4, (lv1037,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv789_1 = R.call_tir(cls.matmul10, (lv1038, lv780_1), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv790 = R.call_tir(cls.transpose8, (lv789_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv791_1 = R.call_tir(cls.reshape8, (lv790,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1039: R.Tensor((4096, 512), dtype="uint32") = params[154]
lv1040: R.Tensor((4096, 128), dtype="float16") = params[155]
lv94_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1039, lv1040, lv791_1, lv93_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv222_1: R.Tensor((4096,), dtype="float16") = params[161]
lv795 = R.call_tir(cls.rms_norm, (lv94_1, lv222_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1043: R.Tensor((22016, 512), dtype="uint32") = params[156]
lv1044: R.Tensor((22016, 128), dtype="float16") = params[157]
lv96_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1043, lv1044, lv795), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv1046 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv96_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv1047: R.Tensor((4096, 1376), dtype="uint32") = params[158]
lv1048: R.Tensor((4096, 344), dtype="float16") = params[159]
lv95_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1047, lv1048, lv1046, lv94_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
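# ---- Decoder layer 16 (assumed): kv_cache slots 32-33, weights params[162]-[171]; same pattern as annotated above ----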
lv229_1: R.Tensor((4096,), dtype="float16") = params[170]
lv806_1 = R.call_tir(cls.rms_norm, (lv95_2, lv229_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1051: R.Tensor((12288, 512), dtype="uint32") = params[162]
lv1052: R.Tensor((12288, 128), dtype="float16") = params[163]
lv97 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1051, lv1052, lv806_1), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv809_1 = R.call_tir(cls.split1, (lv97,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv810_1: R.Tensor((1, n, 4096), dtype="float16") = lv809_1[0]
lv811 = R.call_tir(cls.reshape7, (lv810_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv812: R.Tensor((1, n, 4096), dtype="float16") = lv809_1[1]
lv813_1 = R.call_tir(cls.reshape7, (lv812,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv814_1: R.Tensor((1, n, 4096), dtype="float16") = lv809_1[2]
lv815 = R.call_tir(cls.reshape7, (lv814_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv816_1 = R.call_tir(cls.rotary_embedding, (lv811, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv817_1 = R.call_tir(cls.rotary_embedding, (lv813_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv818_1 = R.call_tir(cls.squeeze1, (lv817_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv819_1 = R.call_tir(cls.squeeze1, (lv815,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv820: R.Object = kv_cache[32]
lv821: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv820, lv818_1, sinfo_args=(R.Object,))
lv822_1: R.Object = kv_cache[33]
lv823_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv822_1, lv819_1, sinfo_args=(R.Object,))
lv824: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv821, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv825_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv823_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv826_1 = R.call_tir(cls.reshape3, (lv824,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv827_1 = R.call_tir(cls.reshape3, (lv825_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv828 = R.call_tir(cls.transpose6, (lv816_1,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv829 = R.call_tir(cls.transpose6, (lv826_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv830_1 = R.call_tir(cls.transpose6, (lv827_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1054 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv828, lv829, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv1055 = R.call_tir(cls.fused_softmax1_cast4, (lv1054,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv839_1 = R.call_tir(cls.matmul10, (lv1055, lv830_1), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv840_1 = R.call_tir(cls.transpose8, (lv839_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv841 = R.call_tir(cls.reshape8, (lv840_1,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1056: R.Tensor((4096, 512), dtype="uint32") = params[164]
lv1057: R.Tensor((4096, 128), dtype="float16") = params[165]
lv96_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1056, lv1057, lv841, lv95_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv236: R.Tensor((4096,), dtype="float16") = params[171]
lv845 = R.call_tir(cls.rms_norm, (lv96_2, lv236), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1060: R.Tensor((22016, 512), dtype="uint32") = params[166]
lv1061: R.Tensor((22016, 128), dtype="float16") = params[167]
lv98 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1060, lv1061, lv845), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv1063 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv98,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv1064: R.Tensor((4096, 1376), dtype="uint32") = params[168]
lv1065: R.Tensor((4096, 344), dtype="float16") = params[169]
lv97_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1064, lv1065, lv1063, lv96_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
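# ---- Decoder layer 17 (assumed): kv_cache slots 34-35, weights params[172]-[181]; same pattern as annotated above ----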
lv243: R.Tensor((4096,), dtype="float16") = params[180]
lv856_1 = R.call_tir(cls.rms_norm, (lv97_1, lv243), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1068: R.Tensor((12288, 512), dtype="uint32") = params[172]
lv1069: R.Tensor((12288, 128), dtype="float16") = params[173]
lv99 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1068, lv1069, lv856_1), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv859_1 = R.call_tir(cls.split1, (lv99,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv860_1: R.Tensor((1, n, 4096), dtype="float16") = lv859_1[0]
lv861_1 = R.call_tir(cls.reshape7, (lv860_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv862: R.Tensor((1, n, 4096), dtype="float16") = lv859_1[1]
lv863 = R.call_tir(cls.reshape7, (lv862,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv864_1: R.Tensor((1, n, 4096), dtype="float16") = lv859_1[2]
lv865_1 = R.call_tir(cls.reshape7, (lv864_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv866 = R.call_tir(cls.rotary_embedding, (lv861_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv867_1 = R.call_tir(cls.rotary_embedding, (lv863, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv868_1 = R.call_tir(cls.squeeze1, (lv867_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv869_1 = R.call_tir(cls.squeeze1, (lv865_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv870_1: R.Object = kv_cache[34]
lv871: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv870_1, lv868_1, sinfo_args=(R.Object,))
lv872: R.Object = kv_cache[35]
lv873_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv872, lv869_1, sinfo_args=(R.Object,))
lv874_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv871, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv875: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv873_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv876_1 = R.call_tir(cls.reshape3, (lv874_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv877_1 = R.call_tir(cls.reshape3, (lv875,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv878_1 = R.call_tir(cls.transpose6, (lv866,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv879 = R.call_tir(cls.transpose6, (lv876_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv880 = R.call_tir(cls.transpose6, (lv877_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1071 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv878_1, lv879, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv1072 = R.call_tir(cls.fused_softmax1_cast4, (lv1071,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv889 = R.call_tir(cls.matmul10, (lv1072, lv880), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv890_1 = R.call_tir(cls.transpose8, (lv889,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv891_1 = R.call_tir(cls.reshape8, (lv890_1,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1073: R.Tensor((4096, 512), dtype="uint32") = params[174]
lv1074: R.Tensor((4096, 128), dtype="float16") = params[175]
lv98_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1073, lv1074, lv891_1, lv97_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv250: R.Tensor((4096,), dtype="float16") = params[181]
lv895_1 = R.call_tir(cls.rms_norm, (lv98_1, lv250), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1077: R.Tensor((22016, 512), dtype="uint32") = params[176]
lv1078: R.Tensor((22016, 128), dtype="float16") = params[177]
lv100 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1077, lv1078, lv895_1), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv1080 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv100,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv1081: R.Tensor((4096, 1376), dtype="uint32") = params[178]
lv1082: R.Tensor((4096, 344), dtype="float16") = params[179]
lv99_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1081, lv1082, lv1080, lv98_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
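# ---- Decoder layer 18 (assumed): kv_cache slots 36-37, weights params[182]-[191]; same pattern as annotated above ----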
lv257: R.Tensor((4096,), dtype="float16") = params[190]
lv906 = R.call_tir(cls.rms_norm, (lv99_1, lv257), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1085: R.Tensor((12288, 512), dtype="uint32") = params[182]
lv1086: R.Tensor((12288, 128), dtype="float16") = params[183]
lv101 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1085, lv1086, lv906), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv909 = R.call_tir(cls.split1, (lv101,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv910_1: R.Tensor((1, n, 4096), dtype="float16") = lv909[0]
lv911_1 = R.call_tir(cls.reshape7, (lv910_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv912_1: R.Tensor((1, n, 4096), dtype="float16") = lv909[1]
lv913 = R.call_tir(cls.reshape7, (lv912_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv914: R.Tensor((1, n, 4096), dtype="float16") = lv909[2]
lv915_1 = R.call_tir(cls.reshape7, (lv914,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv916_1 = R.call_tir(cls.rotary_embedding, (lv911_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv917 = R.call_tir(cls.rotary_embedding, (lv913, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv918_1 = R.call_tir(cls.squeeze1, (lv917,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv919_1 = R.call_tir(cls.squeeze1, (lv915_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv920_1: R.Object = kv_cache[36]
lv921_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv920_1, lv918_1, sinfo_args=(R.Object,))
lv922: R.Object = kv_cache[37]
lv923: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv922, lv919_1, sinfo_args=(R.Object,))
lv924_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv921_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv925_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv923, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv926 = R.call_tir(cls.reshape3, (lv924_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv927_1 = R.call_tir(cls.reshape3, (lv925_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv928_1 = R.call_tir(cls.transpose6, (lv916_1,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv929_1 = R.call_tir(cls.transpose6, (lv926,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv930 = R.call_tir(cls.transpose6, (lv927_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1088 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv928_1, lv929_1, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv1089 = R.call_tir(cls.fused_softmax1_cast4, (lv1088,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv939 = R.call_tir(cls.matmul10, (lv1089, lv930), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv940 = R.call_tir(cls.transpose8, (lv939,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv941_1 = R.call_tir(cls.reshape8, (lv940,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1090: R.Tensor((4096, 512), dtype="uint32") = params[184]
lv1091: R.Tensor((4096, 128), dtype="float16") = params[185]
lv100_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1090, lv1091, lv941_1, lv99_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv264_1: R.Tensor((4096,), dtype="float16") = params[191]
lv945_1 = R.call_tir(cls.rms_norm, (lv100_1, lv264_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1094: R.Tensor((22016, 512), dtype="uint32") = params[186]
lv1095: R.Tensor((22016, 128), dtype="float16") = params[187]
lv102 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1094, lv1095, lv945_1), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv1097 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv102,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv1098: R.Tensor((4096, 1376), dtype="uint32") = params[188]
lv1099: R.Tensor((4096, 344), dtype="float16") = params[189]
lv101_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1098, lv1099, lv1097, lv100_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
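# ---- Decoder layer 19 (assumed): kv_cache slots 38-39, weights params[192]-[201]; same pattern as annotated above ----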
lv271_1: R.Tensor((4096,), dtype="float16") = params[200]
lv956 = R.call_tir(cls.rms_norm, (lv101_1, lv271_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1102: R.Tensor((12288, 512), dtype="uint32") = params[192]
lv1103: R.Tensor((12288, 128), dtype="float16") = params[193]
lv103_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1102, lv1103, lv956), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv959_1 = R.call_tir(cls.split1, (lv103_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv960: R.Tensor((1, n, 4096), dtype="float16") = lv959_1[0]
lv961_1 = R.call_tir(cls.reshape7, (lv960,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv962_1: R.Tensor((1, n, 4096), dtype="float16") = lv959_1[1]
lv963_1 = R.call_tir(cls.reshape7, (lv962_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv964: R.Tensor((1, n, 4096), dtype="float16") = lv959_1[2]
lv965 = R.call_tir(cls.reshape7, (lv964,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv966_1 = R.call_tir(cls.rotary_embedding, (lv961_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv967_1 = R.call_tir(cls.rotary_embedding, (lv963_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv968 = R.call_tir(cls.squeeze1, (lv967_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv969_1 = R.call_tir(cls.squeeze1, (lv965,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv970_1: R.Object = kv_cache[38]
lv971_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv970_1, lv968, sinfo_args=(R.Object,))
lv972_1: R.Object = kv_cache[39]
lv973: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv972_1, lv969_1, sinfo_args=(R.Object,))
lv974: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv971_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv975_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv973, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv976_1 = R.call_tir(cls.reshape3, (lv974,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv977 = R.call_tir(cls.reshape3, (lv975_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv978_1 = R.call_tir(cls.transpose6, (lv966_1,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv979_1 = R.call_tir(cls.transpose6, (lv976_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv980_1 = R.call_tir(cls.transpose6, (lv977,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1105 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv978_1, lv979_1, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv1106 = R.call_tir(cls.fused_softmax1_cast4, (lv1105,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv989_1 = R.call_tir(cls.matmul10, (lv1106, lv980_1), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv990 = R.call_tir(cls.transpose8, (lv989_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv991 = R.call_tir(cls.reshape8, (lv990,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1107: R.Tensor((4096, 512), dtype="uint32") = params[194]
lv1108: R.Tensor((4096, 128), dtype="float16") = params[195]
lv102_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1107, lv1108, lv991, lv101_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv278_1: R.Tensor((4096,), dtype="float16") = params[201]
lv995_1 = R.call_tir(cls.rms_norm, (lv102_1, lv278_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1111: R.Tensor((22016, 512), dtype="uint32") = params[196]
lv1112: R.Tensor((22016, 128), dtype="float16") = params[197]
lv104 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1111, lv1112, lv995_1), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv1114 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv104,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv1115: R.Tensor((4096, 1376), dtype="uint32") = params[198]
lv1116: R.Tensor((4096, 344), dtype="float16") = params[199]
lv103_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1115, lv1116, lv1114, lv102_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
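# ---- Decoder layer 20 (assumed): kv_cache slots 40-41, weights params[202]-[211]; same pattern as annotated above ----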
lv285: R.Tensor((4096,), dtype="float16") = params[210]
lv1006_1 = R.call_tir(cls.rms_norm, (lv103_2, lv285), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1119: R.Tensor((12288, 512), dtype="uint32") = params[202]
lv1120: R.Tensor((12288, 128), dtype="float16") = params[203]
lv105 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1119, lv1120, lv1006_1), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv1009_1 = R.call_tir(cls.split1, (lv105,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv1010_1: R.Tensor((1, n, 4096), dtype="float16") = lv1009_1[0]
lv1011 = R.call_tir(cls.reshape7, (lv1010_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1012_1: R.Tensor((1, n, 4096), dtype="float16") = lv1009_1[1]
lv1013_1 = R.call_tir(cls.reshape7, (lv1012_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1014_1: R.Tensor((1, n, 4096), dtype="float16") = lv1009_1[2]
lv1015 = R.call_tir(cls.reshape7, (lv1014_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1016 = R.call_tir(cls.rotary_embedding, (lv1011, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1017_1 = R.call_tir(cls.rotary_embedding, (lv1013_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1018_1 = R.call_tir(cls.squeeze1, (lv1017_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1019 = R.call_tir(cls.squeeze1, (lv1015,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1020_1: R.Object = kv_cache[40]
lv1021_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1020_1, lv1018_1, sinfo_args=(R.Object,))
lv1022_1: R.Object = kv_cache[41]
lv1023_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1022_1, lv1019, sinfo_args=(R.Object,))
lv1024: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1021_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1025: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1023_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1026_1 = R.call_tir(cls.reshape3, (lv1024,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1027_1 = R.call_tir(cls.reshape3, (lv1025,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1028 = R.call_tir(cls.transpose6, (lv1016,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1029_1 = R.call_tir(cls.transpose6, (lv1026_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1030_1 = R.call_tir(cls.transpose6, (lv1027_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1122 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1028, lv1029_1, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv1123 = R.call_tir(cls.fused_softmax1_cast4, (lv1122,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv1039_1 = R.call_tir(cls.matmul10, (lv1123, lv1030_1), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1040_1 = R.call_tir(cls.transpose8, (lv1039_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1041 = R.call_tir(cls.reshape8, (lv1040_1,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1124: R.Tensor((4096, 512), dtype="uint32") = params[204]
lv1125: R.Tensor((4096, 128), dtype="float16") = params[205]
lv104_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1124, lv1125, lv1041, lv103_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv292: R.Tensor((4096,), dtype="float16") = params[211]
lv1045 = R.call_tir(cls.rms_norm, (lv104_1, lv292), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1128: R.Tensor((22016, 512), dtype="uint32") = params[206]
lv1129: R.Tensor((22016, 128), dtype="float16") = params[207]
lv106_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1128, lv1129, lv1045), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv1131 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv106_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv1132: R.Tensor((4096, 1376), dtype="uint32") = params[208]
lv1133: R.Tensor((4096, 344), dtype="float16") = params[209]
lv105_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1132, lv1133, lv1131, lv104_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
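# ---- Decoder layer 21 (assumed): kv_cache slots 42-43, weights params[212]-[221]; same pattern as annotated above ----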
lv299: R.Tensor((4096,), dtype="float16") = params[220]
lv1056_1 = R.call_tir(cls.rms_norm, (lv105_1, lv299), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1136: R.Tensor((12288, 512), dtype="uint32") = params[212]
lv1137: R.Tensor((12288, 128), dtype="float16") = params[213]
lv107 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1136, lv1137, lv1056_1), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv1059 = R.call_tir(cls.split1, (lv107,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv1060_1: R.Tensor((1, n, 4096), dtype="float16") = lv1059[0]
lv1061_1 = R.call_tir(cls.reshape7, (lv1060_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1062: R.Tensor((1, n, 4096), dtype="float16") = lv1059[1]
lv1063_1 = R.call_tir(cls.reshape7, (lv1062,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1064_1: R.Tensor((1, n, 4096), dtype="float16") = lv1059[2]
lv1065_1 = R.call_tir(cls.reshape7, (lv1064_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1066 = R.call_tir(cls.rotary_embedding, (lv1061_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1067 = R.call_tir(cls.rotary_embedding, (lv1063_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1068_1 = R.call_tir(cls.squeeze1, (lv1067,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1069_1 = R.call_tir(cls.squeeze1, (lv1065_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1070: R.Object = kv_cache[42]
lv1071_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1070, lv1068_1, sinfo_args=(R.Object,))
lv1072_1: R.Object = kv_cache[43]
lv1073_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1072_1, lv1069_1, sinfo_args=(R.Object,))
lv1074_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1071_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1075: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1073_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1076 = R.call_tir(cls.reshape3, (lv1074_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1077_1 = R.call_tir(cls.reshape3, (lv1075,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1078_1 = R.call_tir(cls.transpose6, (lv1066,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1079 = R.call_tir(cls.transpose6, (lv1076,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1080_1 = R.call_tir(cls.transpose6, (lv1077_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1139 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1078_1, lv1079, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv1140 = R.call_tir(cls.fused_softmax1_cast4, (lv1139,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv1089_1 = R.call_tir(cls.matmul10, (lv1140, lv1080_1), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1090_1 = R.call_tir(cls.transpose8, (lv1089_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1091_1 = R.call_tir(cls.reshape8, (lv1090_1,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1141: R.Tensor((4096, 512), dtype="uint32") = params[214]
lv1142: R.Tensor((4096, 128), dtype="float16") = params[215]
lv106_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1141, lv1142, lv1091_1, lv105_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv306_1: R.Tensor((4096,), dtype="float16") = params[221]
lv1095_1 = R.call_tir(cls.rms_norm, (lv106_2, lv306_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1145: R.Tensor((22016, 512), dtype="uint32") = params[216]
lv1146: R.Tensor((22016, 128), dtype="float16") = params[217]
lv108 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1145, lv1146, lv1095_1), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv1148 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv108,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv1149: R.Tensor((4096, 1376), dtype="uint32") = params[218]
lv1150: R.Tensor((4096, 344), dtype="float16") = params[219]
lv107_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1149, lv1150, lv1148, lv106_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
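# ---- Decoder layer 22 (assumed): kv_cache slots 44-45, weights params[222]-[231]; same pattern as annotated above ----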
lv313_1: R.Tensor((4096,), dtype="float16") = params[230]
lv1106_1 = R.call_tir(cls.rms_norm, (lv107_1, lv313_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1153: R.Tensor((12288, 512), dtype="uint32") = params[222]
lv1154: R.Tensor((12288, 128), dtype="float16") = params[223]
lv109_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1153, lv1154, lv1106_1), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv1109 = R.call_tir(cls.split1, (lv109_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv1110: R.Tensor((1, n, 4096), dtype="float16") = lv1109[0]
lv1111_1 = R.call_tir(cls.reshape7, (lv1110,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1112_1: R.Tensor((1, n, 4096), dtype="float16") = lv1109[1]
lv1113 = R.call_tir(cls.reshape7, (lv1112_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1114_1: R.Tensor((1, n, 4096), dtype="float16") = lv1109[2]
lv1115_1 = R.call_tir(cls.reshape7, (lv1114_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1116_1 = R.call_tir(cls.rotary_embedding, (lv1111_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1117 = R.call_tir(cls.rotary_embedding, (lv1113, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1118 = R.call_tir(cls.squeeze1, (lv1117,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1119_1 = R.call_tir(cls.squeeze1, (lv1115_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1120_1: R.Object = kv_cache[44]
lv1121: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1120_1, lv1118, sinfo_args=(R.Object,))
lv1122_1: R.Object = kv_cache[45]
lv1123_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1122_1, lv1119_1, sinfo_args=(R.Object,))
lv1124_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1121, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1125_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1123_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1126 = R.call_tir(cls.reshape3, (lv1124_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1127 = R.call_tir(cls.reshape3, (lv1125_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1128_1 = R.call_tir(cls.transpose6, (lv1116_1,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1129_1 = R.call_tir(cls.transpose6, (lv1126,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1130 = R.call_tir(cls.transpose6, (lv1127,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1156 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1128_1, lv1129_1, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv1157 = R.call_tir(cls.fused_softmax1_cast4, (lv1156,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv1139_1 = R.call_tir(cls.matmul10, (lv1157, lv1130), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1140_1 = R.call_tir(cls.transpose8, (lv1139_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1141_1 = R.call_tir(cls.reshape8, (lv1140_1,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1158: R.Tensor((4096, 512), dtype="uint32") = params[224]
lv1159: R.Tensor((4096, 128), dtype="float16") = params[225]
lv108_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1158, lv1159, lv1141_1, lv107_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv320_1: R.Tensor((4096,), dtype="float16") = params[231]
lv1145_1 = R.call_tir(cls.rms_norm, (lv108_1, lv320_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1162: R.Tensor((22016, 512), dtype="uint32") = params[226]
lv1163: R.Tensor((22016, 128), dtype="float16") = params[227]
lv110_2 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1162, lv1163, lv1145_1), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv1165 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv110_2,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv1166: R.Tensor((4096, 1376), dtype="uint32") = params[228]
lv1167: R.Tensor((4096, 344), dtype="float16") = params[229]
lv109_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1166, lv1167, lv1165, lv108_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
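# ---- Decoder layer 23 (assumed): kv_cache slots 46-47, weights params[232]-[241]; same pattern as annotated above ----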
lv327_1: R.Tensor((4096,), dtype="float16") = params[240]
lv1156_1 = R.call_tir(cls.rms_norm, (lv109_2, lv327_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1170: R.Tensor((12288, 512), dtype="uint32") = params[232]
lv1171: R.Tensor((12288, 128), dtype="float16") = params[233]
lv111_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1170, lv1171, lv1156_1), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv1159_1 = R.call_tir(cls.split1, (lv111_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv1160: R.Tensor((1, n, 4096), dtype="float16") = lv1159_1[0]
lv1161 = R.call_tir(cls.reshape7, (lv1160,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1162_1: R.Tensor((1, n, 4096), dtype="float16") = lv1159_1[1]
lv1163_1 = R.call_tir(cls.reshape7, (lv1162_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1164: R.Tensor((1, n, 4096), dtype="float16") = lv1159_1[2]
lv1165_1 = R.call_tir(cls.reshape7, (lv1164,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1166_1 = R.call_tir(cls.rotary_embedding, (lv1161, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1167_1 = R.call_tir(cls.rotary_embedding, (lv1163_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1168 = R.call_tir(cls.squeeze1, (lv1167_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1169 = R.call_tir(cls.squeeze1, (lv1165_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1170_1: R.Object = kv_cache[46]
lv1171_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1170_1, lv1168, sinfo_args=(R.Object,))
lv1172: R.Object = kv_cache[47]
lv1173: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1172, lv1169, sinfo_args=(R.Object,))
lv1174: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1171_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1175: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1173, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1176 = R.call_tir(cls.reshape3, (lv1174,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1177 = R.call_tir(cls.reshape3, (lv1175,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1178 = R.call_tir(cls.transpose6, (lv1166_1,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1179 = R.call_tir(cls.transpose6, (lv1176,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1180 = R.call_tir(cls.transpose6, (lv1177,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1173_1 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1178, lv1179, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv1174_1 = R.call_tir(cls.fused_softmax1_cast4, (lv1173_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv1189 = R.call_tir(cls.matmul10, (lv1174_1, lv1180), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1190 = R.call_tir(cls.transpose8, (lv1189,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1191 = R.call_tir(cls.reshape8, (lv1190,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1175_1: R.Tensor((4096, 512), dtype="uint32") = params[234]
lv1176_1: R.Tensor((4096, 128), dtype="float16") = params[235]
lv110_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1175_1, lv1176_1, lv1191, lv109_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv334: R.Tensor((4096,), dtype="float16") = params[241]
lv1195 = R.call_tir(cls.rms_norm, (lv110_3, lv334), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1179_1: R.Tensor((22016, 512), dtype="uint32") = params[236]
lv1180_1: R.Tensor((22016, 128), dtype="float16") = params[237]
lv112_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1179_1, lv1180_1, lv1195), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv1182 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv112_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv1183: R.Tensor((4096, 1376), dtype="uint32") = params[238]
lv1184: R.Tensor((4096, 344), dtype="float16") = params[239]
lv111_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1183, lv1184, lv1182, lv110_3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
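# ---- Decoder layer 24 (assumed): kv_cache slots 48-49, weights params[242]-[251]; same pattern as annotated above ----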
lv341_1: R.Tensor((4096,), dtype="float16") = params[250]
lv1206 = R.call_tir(cls.rms_norm, (lv111_2, lv341_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1187: R.Tensor((12288, 512), dtype="uint32") = params[242]
lv1188: R.Tensor((12288, 128), dtype="float16") = params[243]
lv113_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1187, lv1188, lv1206), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv1209 = R.call_tir(cls.split1, (lv113_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv1210: R.Tensor((1, n, 4096), dtype="float16") = lv1209[0]
lv1211 = R.call_tir(cls.reshape7, (lv1210,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1212: R.Tensor((1, n, 4096), dtype="float16") = lv1209[1]
lv1213 = R.call_tir(cls.reshape7, (lv1212,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1214: R.Tensor((1, n, 4096), dtype="float16") = lv1209[2]
lv1215 = R.call_tir(cls.reshape7, (lv1214,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1216 = R.call_tir(cls.rotary_embedding, (lv1211, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1217 = R.call_tir(cls.rotary_embedding, (lv1213, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1218 = R.call_tir(cls.squeeze1, (lv1217,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1219 = R.call_tir(cls.squeeze1, (lv1215,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1220: R.Object = kv_cache[48]
lv1221: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1220, lv1218, sinfo_args=(R.Object,))
lv1222: R.Object = kv_cache[49]
lv1223: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1222, lv1219, sinfo_args=(R.Object,))
lv1224: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1221, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1225: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1223, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1226 = R.call_tir(cls.reshape3, (lv1224,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1227 = R.call_tir(cls.reshape3, (lv1225,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1228 = R.call_tir(cls.transpose6, (lv1216,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1229 = R.call_tir(cls.transpose6, (lv1226,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1230 = R.call_tir(cls.transpose6, (lv1227,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1190_1 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1228, lv1229, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv1191_1 = R.call_tir(cls.fused_softmax1_cast4, (lv1190_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv1239 = R.call_tir(cls.matmul10, (lv1191_1, lv1230), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1240 = R.call_tir(cls.transpose8, (lv1239,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1241 = R.call_tir(cls.reshape8, (lv1240,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1192: R.Tensor((4096, 512), dtype="uint32") = params[244]
lv1193: R.Tensor((4096, 128), dtype="float16") = params[245]
lv112_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1192, lv1193, lv1241, lv111_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv348: R.Tensor((4096,), dtype="float16") = params[251]
lv1245 = R.call_tir(cls.rms_norm, (lv112_2, lv348), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1196: R.Tensor((22016, 512), dtype="uint32") = params[246]
lv1197: R.Tensor((22016, 128), dtype="float16") = params[247]
lv114_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1196, lv1197, lv1245), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv1199 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv114_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv1200: R.Tensor((4096, 1376), dtype="uint32") = params[248]
lv1201: R.Tensor((4096, 344), dtype="float16") = params[249]
lv113_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1200, lv1201, lv1199, lv112_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
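# ---- Decoder layer 25 (assumed): kv_cache slots 50-51, weights params[252]-[261]; same pattern as annotated above ----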
lv355: R.Tensor((4096,), dtype="float16") = params[260]
lv1256 = R.call_tir(cls.rms_norm, (lv113_2, lv355), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1204: R.Tensor((12288, 512), dtype="uint32") = params[252]
lv1205: R.Tensor((12288, 128), dtype="float16") = params[253]
lv115_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1204, lv1205, lv1256), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv1259 = R.call_tir(cls.split1, (lv115_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv1260: R.Tensor((1, n, 4096), dtype="float16") = lv1259[0]
lv1261 = R.call_tir(cls.reshape7, (lv1260,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1262: R.Tensor((1, n, 4096), dtype="float16") = lv1259[1]
lv1263 = R.call_tir(cls.reshape7, (lv1262,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1264: R.Tensor((1, n, 4096), dtype="float16") = lv1259[2]
lv1265 = R.call_tir(cls.reshape7, (lv1264,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1266 = R.call_tir(cls.rotary_embedding, (lv1261, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1267 = R.call_tir(cls.rotary_embedding, (lv1263, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1268 = R.call_tir(cls.squeeze1, (lv1267,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1269 = R.call_tir(cls.squeeze1, (lv1265,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1270: R.Object = kv_cache[50]
lv1271: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1270, lv1268, sinfo_args=(R.Object,))
lv1272: R.Object = kv_cache[51]
lv1273: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1272, lv1269, sinfo_args=(R.Object,))
lv1274: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1271, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1275: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1273, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1276 = R.call_tir(cls.reshape3, (lv1274,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1277 = R.call_tir(cls.reshape3, (lv1275,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1278 = R.call_tir(cls.transpose6, (lv1266,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1279 = R.call_tir(cls.transpose6, (lv1276,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1280 = R.call_tir(cls.transpose6, (lv1277,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1207 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1278, lv1279, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv1208 = R.call_tir(cls.fused_softmax1_cast4, (lv1207,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv1289 = R.call_tir(cls.matmul10, (lv1208, lv1280), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1290 = R.call_tir(cls.transpose8, (lv1289,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1291 = R.call_tir(cls.reshape8, (lv1290,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1209_1: R.Tensor((4096, 512), dtype="uint32") = params[254]
lv1210_1: R.Tensor((4096, 128), dtype="float16") = params[255]
lv114_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1209_1, lv1210_1, lv1291, lv113_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv362_1: R.Tensor((4096,), dtype="float16") = params[261]
lv1295 = R.call_tir(cls.rms_norm, (lv114_2, lv362_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1213_1: R.Tensor((22016, 512), dtype="uint32") = params[256]
lv1214_1: R.Tensor((22016, 128), dtype="float16") = params[257]
lv116_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1213_1, lv1214_1, lv1295), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv1216_1 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv116_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv1217_1: R.Tensor((4096, 1376), dtype="uint32") = params[258]
lv1218_1: R.Tensor((4096, 344), dtype="float16") = params[259]
lv115_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1217_1, lv1218_1, lv1216_1, lv114_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
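# ---- Decoder layer 26 (assumed): kv_cache slots 52-53, weights params[262]-[271]; same pattern as annotated above ----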
lv369_1: R.Tensor((4096,), dtype="float16") = params[270]
lv1306 = R.call_tir(cls.rms_norm, (lv115_2, lv369_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1221_1: R.Tensor((12288, 512), dtype="uint32") = params[262]
lv1222_1: R.Tensor((12288, 128), dtype="float16") = params[263]
lv117_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1221_1, lv1222_1, lv1306), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv1309 = R.call_tir(cls.split1, (lv117_2,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv1310: R.Tensor((1, n, 4096), dtype="float16") = lv1309[0]
lv1311 = R.call_tir(cls.reshape7, (lv1310,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1312: R.Tensor((1, n, 4096), dtype="float16") = lv1309[1]
lv1313 = R.call_tir(cls.reshape7, (lv1312,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1314: R.Tensor((1, n, 4096), dtype="float16") = lv1309[2]
lv1315 = R.call_tir(cls.reshape7, (lv1314,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1316 = R.call_tir(cls.rotary_embedding, (lv1311, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1317 = R.call_tir(cls.rotary_embedding, (lv1313, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1318 = R.call_tir(cls.squeeze1, (lv1317,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1319 = R.call_tir(cls.squeeze1, (lv1315,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1320: R.Object = kv_cache[52]
lv1321: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1320, lv1318, sinfo_args=(R.Object,))
lv1322: R.Object = kv_cache[53]
lv1323: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1322, lv1319, sinfo_args=(R.Object,))
lv1324: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1321, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1325: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1323, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1326 = R.call_tir(cls.reshape3, (lv1324,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1327 = R.call_tir(cls.reshape3, (lv1325,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1328 = R.call_tir(cls.transpose6, (lv1316,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1329 = R.call_tir(cls.transpose6, (lv1326,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1330 = R.call_tir(cls.transpose6, (lv1327,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1224_1 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1328, lv1329, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv1225_1 = R.call_tir(cls.fused_softmax1_cast4, (lv1224_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv1339 = R.call_tir(cls.matmul10, (lv1225_1, lv1330), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1340 = R.call_tir(cls.transpose8, (lv1339,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1341 = R.call_tir(cls.reshape8, (lv1340,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1226_1: R.Tensor((4096, 512), dtype="uint32") = params[264]
lv1227_1: R.Tensor((4096, 128), dtype="float16") = params[265]
lv116_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1226_1, lv1227_1, lv1341, lv115_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv376_1: R.Tensor((4096,), dtype="float16") = params[271]
lv1345 = R.call_tir(cls.rms_norm, (lv116_2, lv376_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1230_1: R.Tensor((22016, 512), dtype="uint32") = params[266]
lv1231: R.Tensor((22016, 128), dtype="float16") = params[267]
lv118_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1230_1, lv1231, lv1345), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv1233 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv118_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv1234: R.Tensor((4096, 1376), dtype="uint32") = params[268]
lv1235: R.Tensor((4096, 344), dtype="float16") = params[269]
lv117_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1234, lv1235, lv1233, lv116_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
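# ---- Decoder layer 27 (assumed): kv_cache slots 54-55, weights params[272]-[281]; same pattern as annotated above ----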
lv383: R.Tensor((4096,), dtype="float16") = params[280]
lv1356 = R.call_tir(cls.rms_norm, (lv117_3, lv383), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1238: R.Tensor((12288, 512), dtype="uint32") = params[272]
lv1239_1: R.Tensor((12288, 128), dtype="float16") = params[273]
lv119_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1238, lv1239_1, lv1356), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv1359 = R.call_tir(cls.split1, (lv119_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv1360: R.Tensor((1, n, 4096), dtype="float16") = lv1359[0]
lv1361 = R.call_tir(cls.reshape7, (lv1360,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1362: R.Tensor((1, n, 4096), dtype="float16") = lv1359[1]
lv1363 = R.call_tir(cls.reshape7, (lv1362,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1364: R.Tensor((1, n, 4096), dtype="float16") = lv1359[2]
lv1365 = R.call_tir(cls.reshape7, (lv1364,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1366 = R.call_tir(cls.rotary_embedding, (lv1361, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1367 = R.call_tir(cls.rotary_embedding, (lv1363, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1368 = R.call_tir(cls.squeeze1, (lv1367,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1369 = R.call_tir(cls.squeeze1, (lv1365,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1370: R.Object = kv_cache[54]
lv1371: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1370, lv1368, sinfo_args=(R.Object,))
lv1372: R.Object = kv_cache[55]
lv1373: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1372, lv1369, sinfo_args=(R.Object,))
lv1374: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1371, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1375: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1373, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1376 = R.call_tir(cls.reshape3, (lv1374,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1377 = R.call_tir(cls.reshape3, (lv1375,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1378 = R.call_tir(cls.transpose6, (lv1366,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1379 = R.call_tir(cls.transpose6, (lv1376,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1380 = R.call_tir(cls.transpose6, (lv1377,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1241_1 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1378, lv1379, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv1242 = R.call_tir(cls.fused_softmax1_cast4, (lv1241_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv1389 = R.call_tir(cls.matmul10, (lv1242, lv1380), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1390 = R.call_tir(cls.transpose8, (lv1389,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1391 = R.call_tir(cls.reshape8, (lv1390,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1243: R.Tensor((4096, 512), dtype="uint32") = params[274]
lv1244: R.Tensor((4096, 128), dtype="float16") = params[275]
lv118_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1243, lv1244, lv1391, lv117_3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv390_1: R.Tensor((4096,), dtype="float16") = params[281]
lv1395 = R.call_tir(cls.rms_norm, (lv118_2, lv390_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1247: R.Tensor((22016, 512), dtype="uint32") = params[276]
lv1248: R.Tensor((22016, 128), dtype="float16") = params[277]
lv120_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1247, lv1248, lv1395), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv1250 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv120_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv1251: R.Tensor((4096, 1376), dtype="uint32") = params[278]
lv1252: R.Tensor((4096, 344), dtype="float16") = params[279]
lv119_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1251, lv1252, lv1250, lv118_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
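# ---- Decoder layer 28 (assumed): kv_cache slots 56-57, weights params[282]-[291]; same pattern as annotated above ----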
lv397: R.Tensor((4096,), dtype="float16") = params[290]
lv1406 = R.call_tir(cls.rms_norm, (lv119_2, lv397), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1255: R.Tensor((12288, 512), dtype="uint32") = params[282]
lv1256_1: R.Tensor((12288, 128), dtype="float16") = params[283]
lv121_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1255, lv1256_1, lv1406), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv1409 = R.call_tir(cls.split1, (lv121_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv1410: R.Tensor((1, n, 4096), dtype="float16") = lv1409[0]
lv1411 = R.call_tir(cls.reshape7, (lv1410,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1412: R.Tensor((1, n, 4096), dtype="float16") = lv1409[1]
lv1413 = R.call_tir(cls.reshape7, (lv1412,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1414: R.Tensor((1, n, 4096), dtype="float16") = lv1409[2]
lv1415 = R.call_tir(cls.reshape7, (lv1414,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1416 = R.call_tir(cls.rotary_embedding, (lv1411, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1417 = R.call_tir(cls.rotary_embedding, (lv1413, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1418 = R.call_tir(cls.squeeze1, (lv1417,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1419 = R.call_tir(cls.squeeze1, (lv1415,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1420: R.Object = kv_cache[56]
lv1421: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1420, lv1418, sinfo_args=(R.Object,))
lv1422: R.Object = kv_cache[57]
lv1423: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1422, lv1419, sinfo_args=(R.Object,))
lv1424: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1421, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1425: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1423, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1426 = R.call_tir(cls.reshape3, (lv1424,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1427 = R.call_tir(cls.reshape3, (lv1425,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1428 = R.call_tir(cls.transpose6, (lv1416,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1429 = R.call_tir(cls.transpose6, (lv1426,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1430 = R.call_tir(cls.transpose6, (lv1427,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1258 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1428, lv1429, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv1259_1 = R.call_tir(cls.fused_softmax1_cast4, (lv1258,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv1439 = R.call_tir(cls.matmul10, (lv1259_1, lv1430), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1440 = R.call_tir(cls.transpose8, (lv1439,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1441 = R.call_tir(cls.reshape8, (lv1440,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1260_1: R.Tensor((4096, 512), dtype="uint32") = params[284]
lv1261_1: R.Tensor((4096, 128), dtype="float16") = params[285]
lv120_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1260_1, lv1261_1, lv1441, lv119_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv404: R.Tensor((4096,), dtype="float16") = params[291]
lv1445 = R.call_tir(cls.rms_norm, (lv120_2, lv404), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1264_1: R.Tensor((22016, 512), dtype="uint32") = params[286]
lv1265_1: R.Tensor((22016, 128), dtype="float16") = params[287]
lv122_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1264_1, lv1265_1, lv1445), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv1267_1 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv122_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv1268_1: R.Tensor((4096, 1376), dtype="uint32") = params[288]
lv1269_1: R.Tensor((4096, 344), dtype="float16") = params[289]
lv121_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1268_1, lv1269_1, lv1267_1, lv120_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv411_1: R.Tensor((4096,), dtype="float16") = params[300]
lv1456 = R.call_tir(cls.rms_norm, (lv121_2, lv411_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1272_1: R.Tensor((12288, 512), dtype="uint32") = params[292]
lv1273_1: R.Tensor((12288, 128), dtype="float16") = params[293]
lv123_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1272_1, lv1273_1, lv1456), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv1459 = R.call_tir(cls.split1, (lv123_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv1460: R.Tensor((1, n, 4096), dtype="float16") = lv1459[0]
lv1461 = R.call_tir(cls.reshape7, (lv1460,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1462: R.Tensor((1, n, 4096), dtype="float16") = lv1459[1]
lv1463 = R.call_tir(cls.reshape7, (lv1462,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1464: R.Tensor((1, n, 4096), dtype="float16") = lv1459[2]
lv1465 = R.call_tir(cls.reshape7, (lv1464,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1466 = R.call_tir(cls.rotary_embedding, (lv1461, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1467 = R.call_tir(cls.rotary_embedding, (lv1463, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1468 = R.call_tir(cls.squeeze1, (lv1467,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1469 = R.call_tir(cls.squeeze1, (lv1465,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1470: R.Object = kv_cache[58]
lv1471: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1470, lv1468, sinfo_args=(R.Object,))
lv1472: R.Object = kv_cache[59]
lv1473: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1472, lv1469, sinfo_args=(R.Object,))
lv1474: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1471, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1475: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1473, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1476 = R.call_tir(cls.reshape3, (lv1474,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1477 = R.call_tir(cls.reshape3, (lv1475,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1478 = R.call_tir(cls.transpose6, (lv1466,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1479 = R.call_tir(cls.transpose6, (lv1476,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1480 = R.call_tir(cls.transpose6, (lv1477,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1275_1 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1478, lv1479, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv1276_1 = R.call_tir(cls.fused_softmax1_cast4, (lv1275_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv1489 = R.call_tir(cls.matmul10, (lv1276_1, lv1480), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1490 = R.call_tir(cls.transpose8, (lv1489,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1491 = R.call_tir(cls.reshape8, (lv1490,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1277_1: R.Tensor((4096, 512), dtype="uint32") = params[294]
lv1278_1: R.Tensor((4096, 128), dtype="float16") = params[295]
lv122_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1277_1, lv1278_1, lv1491, lv121_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv418_1: R.Tensor((4096,), dtype="float16") = params[301]
lv1495 = R.call_tir(cls.rms_norm, (lv122_2, lv418_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1281: R.Tensor((22016, 512), dtype="uint32") = params[296]
lv1282: R.Tensor((22016, 128), dtype="float16") = params[297]
lv124_2 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1281, lv1282, lv1495), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv1284 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv124_2,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv1285: R.Tensor((4096, 1376), dtype="uint32") = params[298]
lv1286: R.Tensor((4096, 344), dtype="float16") = params[299]
lv123_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1285, lv1286, lv1284, lv122_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv425_1: R.Tensor((4096,), dtype="float16") = params[310]
lv1506 = R.call_tir(cls.rms_norm, (lv123_2, lv425_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1289_1: R.Tensor((12288, 512), dtype="uint32") = params[302]
lv1290_1: R.Tensor((12288, 128), dtype="float16") = params[303]
lv125_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1289_1, lv1290_1, lv1506), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv1509 = R.call_tir(cls.split1, (lv125_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv1510: R.Tensor((1, n, 4096), dtype="float16") = lv1509[0]
lv1511 = R.call_tir(cls.reshape7, (lv1510,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1512: R.Tensor((1, n, 4096), dtype="float16") = lv1509[1]
lv1513 = R.call_tir(cls.reshape7, (lv1512,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1514: R.Tensor((1, n, 4096), dtype="float16") = lv1509[2]
lv1515 = R.call_tir(cls.reshape7, (lv1514,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1516 = R.call_tir(cls.rotary_embedding, (lv1511, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1517 = R.call_tir(cls.rotary_embedding, (lv1513, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1518 = R.call_tir(cls.squeeze1, (lv1517,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1519 = R.call_tir(cls.squeeze1, (lv1515,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1520: R.Object = kv_cache[60]
lv1521: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1520, lv1518, sinfo_args=(R.Object,))
lv1522: R.Object = kv_cache[61]
lv1523: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1522, lv1519, sinfo_args=(R.Object,))
lv1524: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1521, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1525: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1523, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1526 = R.call_tir(cls.reshape3, (lv1524,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1527 = R.call_tir(cls.reshape3, (lv1525,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1528 = R.call_tir(cls.transpose6, (lv1516,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1529 = R.call_tir(cls.transpose6, (lv1526,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1530 = R.call_tir(cls.transpose6, (lv1527,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1292 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1528, lv1529, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv1293 = R.call_tir(cls.fused_softmax1_cast4, (lv1292,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv1539 = R.call_tir(cls.matmul10, (lv1293, lv1530), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1540 = R.call_tir(cls.transpose8, (lv1539,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1541 = R.call_tir(cls.reshape8, (lv1540,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1294: R.Tensor((4096, 512), dtype="uint32") = params[304]
lv1295_1: R.Tensor((4096, 128), dtype="float16") = params[305]
lv124_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1294, lv1295_1, lv1541, lv123_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv432: R.Tensor((4096,), dtype="float16") = params[311]
lv1545 = R.call_tir(cls.rms_norm, (lv124_3, lv432), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1298: R.Tensor((22016, 512), dtype="uint32") = params[306]
lv1299: R.Tensor((22016, 128), dtype="float16") = params[307]
lv126_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1298, lv1299, lv1545), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv1301 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv126_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv1302: R.Tensor((4096, 1376), dtype="uint32") = params[308]
lv1303: R.Tensor((4096, 344), dtype="float16") = params[309]
lv125_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1302, lv1303, lv1301, lv124_3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv439_1: R.Tensor((4096,), dtype="float16") = params[320]
lv1556 = R.call_tir(cls.rms_norm, (lv125_2, lv439_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1306_1: R.Tensor((12288, 512), dtype="uint32") = params[312]
lv1307: R.Tensor((12288, 128), dtype="float16") = params[313]
lv127_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1306_1, lv1307, lv1556), out_sinfo=R.Tensor((1, n, 12288), dtype="float16"))
lv1559 = R.call_tir(cls.split1, (lv127_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")])
lv1560: R.Tensor((1, n, 4096), dtype="float16") = lv1559[0]
lv1561 = R.call_tir(cls.reshape7, (lv1560,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1562: R.Tensor((1, n, 4096), dtype="float16") = lv1559[1]
lv1563 = R.call_tir(cls.reshape7, (lv1562,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1564: R.Tensor((1, n, 4096), dtype="float16") = lv1559[2]
lv1565 = R.call_tir(cls.reshape7, (lv1564,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1566 = R.call_tir(cls.rotary_embedding, (lv1561, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1567 = R.call_tir(cls.rotary_embedding, (lv1563, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m]))
lv1568 = R.call_tir(cls.squeeze1, (lv1567,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1569 = R.call_tir(cls.squeeze1, (lv1565,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16"))
lv1570: R.Object = kv_cache[62]
lv1571: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1570, lv1568, sinfo_args=(R.Object,))
lv1572: R.Object = kv_cache[63]
lv1573: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1572, lv1569, sinfo_args=(R.Object,))
lv1574: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1571, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1575: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1573, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),))
lv1576 = R.call_tir(cls.reshape3, (lv1574,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1577 = R.call_tir(cls.reshape3, (lv1575,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16"))
lv1578 = R.call_tir(cls.transpose6, (lv1566,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1579 = R.call_tir(cls.transpose6, (lv1576,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1580 = R.call_tir(cls.transpose6, (lv1577,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16"))
lv1309_1 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1578, lv1579, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32"))
lv1310_1 = R.call_tir(cls.fused_softmax1_cast4, (lv1309_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16"))
lv1589 = R.call_tir(cls.matmul10, (lv1310_1, lv1580), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
lv1590 = R.call_tir(cls.transpose8, (lv1589,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"))
lv1591 = R.call_tir(cls.reshape8, (lv1590,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1311_1: R.Tensor((4096, 512), dtype="uint32") = params[314]
lv1312_1: R.Tensor((4096, 128), dtype="float16") = params[315]
lv126_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1311_1, lv1312_1, lv1591, lv125_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv446: R.Tensor((4096,), dtype="float16") = params[321]
lv1595 = R.call_tir(cls.rms_norm, (lv126_2, lv446), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1315_1: R.Tensor((22016, 512), dtype="uint32") = params[316]
lv1316_1: R.Tensor((22016, 128), dtype="float16") = params[317]
lv128_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1315_1, lv1316_1, lv1595), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
lv1318_1 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv128_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
lv1319_1: R.Tensor((4096, 1376), dtype="uint32") = params[318]
lv1320_1: R.Tensor((4096, 344), dtype="float16") = params[319]
lv127_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1319_1, lv1320_1, lv1318_1, lv126_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv453: R.Tensor((4096,), dtype="float16") = params[322]
lv1606 = R.call_tir(cls.rms_norm, (lv127_2, lv453), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
lv1607 = R.call_tir(cls.slice, (lv1606,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
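# After the last decoder layer, the rms_norm output is sliced down to the final token
# position (1, 1, 4096), so only that position is fed to the LM head below.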
lv1323_1: R.Tensor((32001, 512), dtype="uint32") = params[323]
lv1324_1: R.Tensor((32001, 128), dtype="float16") = params[324]
lv129_1 = R.call_tir(cls.fused_fused_decode1_fused_NT_matmul5_cast2, (lv1323_1, lv1324_1, lv1607), out_sinfo=R.Tensor((1, 1, 32001), dtype="float32"))
gv: R.Tuple(R.Tensor((1, 1, 32001), dtype="float32"), R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object)) = lv129_1, (lv21, lv23, lv71, lv73, lv121, lv123, lv171, lv173, lv221, lv223, lv271, lv273, lv321, lv323, lv371, lv373, lv421, lv423, lv471, lv473, lv521, lv523, lv571, lv573, lv621, lv623, lv671, lv673, lv721, lv723, lv771, lv773, lv821, lv823_1, lv871, lv873_1, lv921_1, lv923, lv971_1, lv973, lv1021_1, lv1023_1, lv1071_1, lv1073_1, lv1121, lv1123_1, lv1171_1, lv1173, lv1221, lv1223, lv1271, lv1273, lv1321, lv1323, lv1371, lv1373, lv1421, lv1423, lv1471, lv1473, lv1521, lv1523, lv1571, lv1573)
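# gv pairs the float32 logits over the 32001-entry vocabulary with the 64 updated
# KV-cache handles (two per decoder layer), letting the caller thread the cache into
# the next call.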
R.output(gv)
return gv
@R.function
def softmax_with_temperature(logits: R.Tensor((1, 1, 32001), dtype="float32"), temperature: R.Tensor((), dtype="float32")) -> R.Tensor((1, 1, 32001), dtype="float32"):
R.func_attr({"tir_var_upper_bound": {"m": 2048, "n": 2048}})
cls = Module
with R.dataflow():
lv3285 = R.call_tir(cls.divide2, (logits, temperature), out_sinfo=R.Tensor((1, 1, 32001), dtype="float32"))
lv3286 = R.call_tir(cls.softmax2, (lv3285,), out_sinfo=R.Tensor((1, 1, 32001), dtype="float32"))
gv3: R.Tensor((1, 1, 32001), dtype="float32") = lv3286
R.output(gv3)
return gv3
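# A minimal NumPy sketch (illustration only, not part of the compiled module) of what
# softmax_with_temperature computes through the divide2 and softmax2 kernels: the logits
# are divided element-wise by the scalar temperature, then a softmax is taken over the
# vocabulary axis. The reference below subtracts the row max for numerical stability,
# which the actual softmax2 kernel may or may not do; the function name is illustrative.
import numpy as np

def softmax_with_temperature_ref(logits: np.ndarray, temperature: float) -> np.ndarray:
    # logits: (1, 1, 32001) float32, temperature: scalar
    scaled = logits / temperature                          # cls.divide2
    shifted = scaled - scaled.max(axis=-1, keepdims=True)  # stability shift (assumed)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)           # cls.softmax2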