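The dump below is a TVMScript (TensorIR) IRModule of GPU-scheduled prim_funcs. A dump like this is itself executable Python: with TVM installed, uncommenting the three `from tvm.script import ...` lines at its head and running the file rebuilds `Module` as a `tvm.IRModule`, which can then be enumerated or pretty-printed. A minimal sketch under that assumption (none of the lines below are part of the dump itself):

# Assumes a TVM build with tvm.script available, and that the dump's three
# commented import lines have been uncommented so `Module` is defined here.
for gv in Module.functions:            # enumerate the scheduled prim_funcs
    print(gv.name_hint)
print(Module["divide2"].script())      # round-trip one function back to TVMScript text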
# from tvm.script import ir as I # from tvm.script import tir as T # from tvm.script import relax as R @I.ir_module class Module: @T.prim_func(private=True) def divide2(A: T.Buffer((1, 1, 32001), "float32"), B: T.Buffer((), "float32"), T_divide: T.Buffer((1, 1, 32001), "float32")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): for ax0_fused_0 in T.thread_binding(126, thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_divide"): v0 = T.axis.spatial(32001, ax0_fused_0 * 256 + ax0_fused_1) T.where(ax0_fused_0 * 256 + ax0_fused_1 < 32001) T.reads(A[0, 0, v0], B[()]) T.writes(T_divide[0, 0, v0]) T_divide[0, 0, v0] = A[0, 0, v0] / B[()] @T.prim_func(private=True) def extend_te(var_A: T.handle, var_concat_te: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() A = T.match_buffer(var_A, (1, 1, n, n), "float16") m = T.int32() concat_te = T.match_buffer(var_concat_te, (1, 1, n, m), "float16") # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding((n * m + 255) // 256, thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("concat_te"): v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % (m * n) // m) v1 = T.axis.spatial(m, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % m) T.where(ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 < n * m) T.reads(A[0, 0, v0, v1 + (n - m)]) T.writes(concat_te[0, 0, v0, v1]) concat_te[0, 0, v0, v1] = T.if_then_else(v1 < m - n, T.float16(65504), A[0, 0, v0, v1 + (n - m)]) @T.prim_func(private=True) def full(var_T_full: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() T_full = T.match_buffer(var_T_full, (1, 1, 1, n), "float16") # with T.block("root"): for ax0_fused_0 in T.thread_binding((n + 255) // 256, thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_full"): v0 = T.axis.spatial(n, ax0_fused_0 * 256 + ax0_fused_1) T.where(ax0_fused_0 * 256 + ax0_fused_1 < n) T.reads() T.writes(T_full[0, 0, 0, v0]) T_full[0, 0, 0, v0] = T.float16(65504) @T.prim_func(private=True) def fused_NT_matmul1_divide1_maximum1_minimum1_cast3(p_lv28: T.handle, p_lv29: T.handle, p_lv5: T.handle, p_output0: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() lv28 = T.match_buffer(p_lv28, (1, 32, n, 128), "float16") m = T.int32() lv29 = T.match_buffer(p_lv29, (1, 32, m, 128), "float16") lv5 = T.match_buffer(p_lv5, (1, 1, n, m), "float16") var_compute_intermediate = T.match_buffer(p_output0, (1, 32, n, m)) # with T.block("root"): var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((32, (n + 31) // 32 * 32, (m + 63) // 64 * 64), "float16", scope="local") lv28_reindex_pad_shared = T.alloc_buffer((32, (n + 31) // 32 * 32, 128), "float16", scope="shared") lv29_reindex_pad_shared = T.alloc_buffer((32, (m + 63) // 64 * 64, 128), "float16", scope="shared") for ax0_ax2_0_fused in T.thread_binding((m + 63) // 64 * 32, thread="blockIdx.y"): for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"): for ax2_1 in T.thread_binding(1, thread="vthread.y"): for ax1_1 in T.thread_binding(1, thread="vthread.x"): for ax2_2 in T.thread_binding(16, thread="threadIdx.y"): for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_3_init, ax1_3_init in T.grid(4, 4): with T.block("NT_matmul_init"): v0 = 
T.axis.spatial(32, ax0_ax2_0_fused // ((m + 63) // 64)) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init) v2 = T.axis.spatial((m + 63) // 64 * 64, ax0_ax2_0_fused % ((m + 63) // 64) * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init) T.reads() T.writes(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]) var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] = T.float16(0) for ax3_0 in range(8): for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(2): for ax0_ax1_ax2_fused_3 in T.vectorized(2): with T.block("lv28_reindex_pad_shared"): v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((m + 63) // 64)) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) v2 = T.axis.spatial(128, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) T.reads(lv28[0, v0, v1, v2]) T.writes(lv28_reindex_pad_shared[v0, v1, v2]) T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) lv28_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv28[0, v0, v1, v2], T.float16(0)) for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(4): for ax0_ax1_ax2_fused_3 in T.vectorized(2): with T.block("lv29_reindex_pad_shared"): v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((m + 63) // 64)) v1 = T.axis.spatial((m + 63) // 64 * 64, ax0_ax2_0_fused % ((m + 63) // 64) * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) v2 = T.axis.spatial(128, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) T.reads(lv29[0, v0, v1, v2]) T.writes(lv29_reindex_pad_shared[v0, v1, v2]) T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) lv29_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, lv29[0, v0, v1, v2], T.float16(0)) for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 4): with T.block("NT_matmul_update"): v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((m + 63) // 64)) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3) v2 = T.axis.spatial((m + 63) // 64 * 64, ax0_ax2_0_fused % ((m + 63) // 64) * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3) v3 = T.axis.reduce(128, ax3_0 * 16 + ax3_1) T.reads(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv28_reindex_pad_shared[v0, v1, v3], lv29_reindex_pad_shared[v0, v2, v3]) T.writes(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]) var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] + lv28_reindex_pad_shared[v0, v1, v3] * lv29_reindex_pad_shared[v0, v2, v3] for ax0, ax1, ax2_0 in T.grid(1, 4, 2): for ax2_1_1 in T.vectorized(2): with T.block("var_NT_matmul_intermediate_reindex_pad_local"): v0 = T.axis.spatial(32, ax0_ax2_0_fused // ((m + 63) // 64) + ax0) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1) v2 = T.axis.spatial((m + 63) // 64 * 64, ax0_ax2_0_fused % ((m + 63) // 64) * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1) T.reads(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv5[0, 0, v1, v2]) T.writes(var_compute_intermediate[0, v0, v1, v2]) if v1 < n and v2 < m: var_compute_intermediate[0, v0, v1, v2] = 
T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] * T.float16(0.088397790055248615), T.float16(-65504)), lv5[0, 0, v1, v2])) @T.prim_func(private=True) def fused_NT_matmul7_divide_maximum_minimum_cast(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p_lv1614: T.handle, p_output0: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() lv1638 = T.match_buffer(p_lv1638, (1, 32, n, 128), "float16") lv1614 = T.match_buffer(p_lv1614, (1, 1, 1, n), "float16") var_compute_intermediate = T.match_buffer(p_output0, (1, 32, 1, n)) # with T.block("root"): var_NT_matmul_intermediate_local = T.alloc_buffer((1, 32, 1, n), "float16", scope="local") var_NT_matmul_intermediate_rf_local = T.alloc_buffer((128, 1, 32, 1, n), "float16", scope="local") var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((64, 1, 32, 1, n), "float16", scope="local") lv1638_local = T.alloc_buffer((1, 32, n, 128), "float16", scope="local") lv1637_shared = T.alloc_buffer((1, 32, 1, 128), "float16", scope="shared") for ax0_fused_ax1_fused_fused_0 in T.thread_binding(n * 32, thread="blockIdx.x"): for ax0_fused_ax1_fused_fused_1 in T.thread_binding(1, thread="threadIdx.y"): for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in T.thread_binding(64, thread="threadIdx.x"): for ax0, ax1, ax2 in T.grid(1, 1, 1): for ax3_0 in T.serial(1, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}): for ax3_1 in T.thread_binding(1, thread="threadIdx.y"): for ax3_2 in T.thread_binding(64, thread="threadIdx.x"): for ax3_3 in T.vectorized(2): with T.block("lv1637_shared"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n + ax1) v2 = T.axis.spatial(1, ax2) v3 = T.axis.spatial(128, ax3_0 * 128 + ax3_1 * 128 + ax3_2 * 2 + ax3_3) T.reads(lv1637[v0, v1, v2, v3]) T.writes(lv1637_shared[v0, v1, v2, v3]) lv1637_shared[v0, v1, v2, v3] = lv1637[v0, v1, v2, v3] for ax0_fused_ax1_fused_fused_2_init in range(1): for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(2): with T.block("NT_matmul_rf_init"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(128, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init) v0 = T.axis.spatial(32, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2_init) // n) v1 = T.axis.spatial(n, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2_init) % n) T.reads() T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1]) var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1] = T.float16(0) for ax2_fused_u_fused_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax0, ax1, ax2_0, ax3 in T.grid(1, 1, 1, 2): for ax2_1 in T.vectorized(1): with T.block("lv1638_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n + ax1) v2 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n + ax2_0 + ax2_1) v3 = T.axis.spatial(128, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax3) T.reads(lv1638[v0, v1, v2, v3]) T.writes(lv1638_local[v0, v1, v2, v3]) lv1638_local[v0, v1, v2, v3] = lv1638[v0, v1, v2, v3] for ax0_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(1, 1): for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(2): with 
T.block("NT_matmul_rf_update"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(128, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1) v0 = T.axis.spatial(32, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2) // n) v1 = T.axis.spatial(n, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2) % n) vax2_fused_u_fused_2, vax2_fused_u_fused_0 = T.axis.remap("RR", [ax2_fused_u_fused_2, ax2_fused_u_fused_0]) T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1], lv1637_shared[0, v0, 0, vax2_fused_u_fused_0 * 128 + vax2_fused_u_fused_2 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused], lv1638_local[0, v0, v1, vax2_fused_u_fused_0 * 128 + vax2_fused_u_fused_2 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused]) T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1]) var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1] = var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1] + lv1637_shared[0, v0, 0, vax2_fused_u_fused_0 * 128 + vax2_fused_u_fused_2 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused] * lv1638_local[0, v0, v1, vax2_fused_u_fused_0 * 128 + vax2_fused_u_fused_2 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused] for ax2_ax3_fused_0 in T.thread_binding(1, thread="threadIdx.y"): for ax0 in T.thread_binding(64, thread="threadIdx.x"): for ax2_ax3_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_ax3_fused_1_1 in T.vectorized(1): with T.block("NT_matmul_rf_init"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(64, ax0) v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n) v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n) T.reads() T.writes(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1]) var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1] = T.float16(0) for ax1 in range(2): with T.block("NT_matmul_rf_update"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n) v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n) T.reads(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1], var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, 0, v0, 0, v1]) T.writes(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1]) var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1] = var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1] + var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, 0, v0, 0, v1] for ax1_ax2_fused_1 in range(1): for ax1_ax2_fused_0 in T.thread_binding(1, thread="threadIdx.y"): for ax0 in T.thread_binding(64, thread="threadIdx.x"): with T.block("NT_matmul"): vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(64, ax0) v0 = T.axis.spatial(32, 
ax0_fused_ax1_fused_fused_0 // n) v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n) T.reads(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1]) T.writes(var_NT_matmul_intermediate_local[0, v0, 0, v1]) with T.init(): var_NT_matmul_intermediate_local[0, v0, 0, v1] = T.float16(0) var_NT_matmul_intermediate_local[0, v0, 0, v1] = var_NT_matmul_intermediate_local[0, v0, 0, v1] + var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1] for ax0_ax1_fused_0 in T.thread_binding(1, thread="threadIdx.y"): for ax0_ax1_fused_1 in range(1): with T.block("compute"): v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n) v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n) T.reads(var_NT_matmul_intermediate_local[0, v0, 0, v1], lv1614[0, 0, 0, v1]) T.writes(var_compute_intermediate[0, v0, 0, v1]) var_compute_intermediate[0, v0, 0, v1] = T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_local[0, v0, 0, v1] * T.float16(0.088397790055248615), T.float16(-65504)), lv1614[0, 0, 0, v1])) @T.prim_func(private=True) def fused_fused_decode1_fused_NT_matmul5_cast2(lv771: T.Buffer((32001, 512), "uint32"), lv772: T.Buffer((32001, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 32001), "float32")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): var_NT_matmul_intermediate_local = T.alloc_buffer((1, 1, 32001), "float16", scope="local") var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 32001), "float16", scope="local") var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 32001), "float16", scope="local") lv771_local = T.alloc_buffer((32001, 512), "uint32", scope="local") lv3216_shared = T.alloc_buffer((1, 1, 4096), "float16", scope="shared") for u_fused_ax0_fused_fused_0 in T.thread_binding(501, thread="blockIdx.x"): for u_fused_ax0_fused_fused_1 in T.thread_binding(64, thread="threadIdx.y"): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"): for ax0, ax1 in T.grid(1, 1): for ax2_0 in T.serial(2, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}): for ax2_1 in T.thread_binding(64, thread="threadIdx.y"): for ax2_2 in T.thread_binding(8, thread="threadIdx.x"): for ax2_3 in T.vectorized(4): with T.block("lv3216_shared"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) v2 = T.axis.spatial(4096, ax2_0 * 2048 + ax2_1 * 32 + ax2_2 * 4 + ax2_3) T.reads(lv3216[v0, v1, v2]) T.writes(lv3216_shared[v0, v1, v2]) lv3216_shared[v0, v1, v2] = lv3216[v0, v1, v2] for u_fused_ax0_fused_fused_2_init in range(1): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2): with T.block("NT_matmul_rf_init"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init) v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) T.where(u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init < 32001) T.reads() T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0]) var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0) for ax1_0_fused_ax1_1_fused_0 in 
T.serial(64, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax0_0, ax1 in T.grid(1, 1): for ax0_1 in T.vectorized(1): with T.block("lv771_local"): v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1) v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1) T.where(u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 < 32001) T.reads(lv771[v0, v1]) T.writes(lv771_local[v0, v1]) lv771_local[v0, v1] = lv771[v0, v1] for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2): with T.block("NT_matmul_rf_update"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1) v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2]) T.where(u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2 < 32001) T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv3216_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv771_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv772[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32]) T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0]) var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv3216_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv771_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv772[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32]) for ax2_fused_0 in T.thread_binding(64, thread="threadIdx.y"): for ax0 in T.thread_binding(8, thread="threadIdx.x"): for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_fused_1_1 in T.vectorized(1): with 
T.block("NT_matmul_rf_init"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0) v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) T.where(u_fused_ax0_fused_fused_0 * 64 + (ax2_fused_0 + (ax2_fused_1_0 + ax2_fused_1_1)) < 32001) T.reads() T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0) for ax1 in range(2): with T.block("NT_matmul_rf_update"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) T.where(u_fused_ax0_fused_fused_0 * 64 + (ax2_fused_0 + (ax2_fused_1_0 + ax2_fused_1_1)) < 32001) T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]) T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0] for ax1_fused_1 in range(1): for ax1_fused_0 in T.thread_binding(64, thread="threadIdx.y"): for ax0 in T.thread_binding(8, thread="threadIdx.x"): with T.block("NT_matmul"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0) v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + ax1_fused_0 + ax1_fused_1) T.where(u_fused_ax0_fused_fused_0 * 64 + (ax1_fused_0 + ax1_fused_1) < 32001) T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) T.writes(var_NT_matmul_intermediate_local[0, 0, v0]) with T.init(): var_NT_matmul_intermediate_local[0, 0, v0] = T.float16(0) var_NT_matmul_intermediate_local[0, 0, v0] = var_NT_matmul_intermediate_local[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] for ax0_fused_0 in T.thread_binding(64, thread="threadIdx.y"): for ax0_fused_1 in range(1): with T.block("compute"): v0 = T.axis.spatial(32001, u_fused_ax0_fused_fused_0 * 64 + ax0_fused_0 + ax0_fused_1) T.where(u_fused_ax0_fused_fused_0 * 64 + (ax0_fused_0 + ax0_fused_1) < 32001) T.reads(var_NT_matmul_intermediate_local[0, 0, v0]) T.writes(p_output0_intermediate[0, 0, v0]) p_output0_intermediate[0, 0, v0] = T.Cast("float32", var_NT_matmul_intermediate_local[0, 0, v0]) @T.prim_func(private=True) def fused_fused_decode1_take(lv: T.Buffer((32001, 512), "uint32"), lv1: T.Buffer((32001, 128), "float16"), lv1611: T.Buffer((1,), "int32"), var_T_take_intermediate: T.Buffer((1, 4096), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"): for ax0_fused_1 in 
T.thread_binding(256, thread="threadIdx.x"): with T.block("T_take"): v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + ax0_fused_1) T.reads(lv[lv1611[0], v0 // 8], lv1611[0], lv1[lv1611[0], v0 // 32]) T.writes(var_T_take_intermediate[0, v0]) var_T_take_intermediate[0, v0] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv[lv1611[0], v0 // 8], T.Cast("uint32", v0 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv1[lv1611[0], v0 // 32] @T.prim_func(private=True) def fused_fused_decode1_take1(lv775: T.Buffer((32001, 512), "uint32"), lv776: T.Buffer((32001, 128), "float16"), p_lv: T.handle, p_output0: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() lv = T.match_buffer(p_lv, (n,), "int32") var_T_take_intermediate = T.match_buffer(p_output0, (n, 4096), "float16") # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_take"): v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 4096) v1 = T.axis.spatial(4096, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 4096) T.reads(lv775[lv[v0], v1 // 8], lv[v0], lv776[lv[v0], v1 // 32]) T.writes(var_T_take_intermediate[v0, v1]) var_T_take_intermediate[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv775[lv[v0], v1 // 8], T.Cast("uint32", v1 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv776[lv[v0], v1 // 32] @T.prim_func(private=True) def fused_fused_decode2_NT_matmul(lv779: T.Buffer((12288, 512), "uint32"), lv780: T.Buffer((12288, 128), "float16"), p_lv6: T.handle, p_output0: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() lv6 = T.match_buffer(p_lv6, (1, n, 4096), "float16") var_NT_matmul_intermediate = T.match_buffer(p_output0, (1, n, 12288), "float16") # with T.block("root"): var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (n + 31) // 32 * 32, 12288), "float16", scope="local") lv6_reindex_pad_shared = T.alloc_buffer((1, (n + 31) // 32 * 32, 4096), "float16", scope="shared") p_output0_intermediate_reindex_shared = T.alloc_buffer((1, 12288, 4096), "float16", scope="shared") for ax0_ax2_0_fused in T.thread_binding(192, thread="blockIdx.y"): for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"): for ax2_1 in T.thread_binding(1, thread="vthread.y"): for ax1_1 in T.thread_binding(1, thread="vthread.x"): for ax2_2 in T.thread_binding(16, thread="threadIdx.y"): for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_3_init, ax1_3_init in T.grid(4, 4): with T.block("NT_matmul_init"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init) v2 = T.axis.spatial(12288, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init) T.reads() T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2]) var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0) for ax3_0 in range(256): for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(2): for ax0_ax1_ax2_fused_3 in T.vectorized(2): with T.block("lv6_reindex_pad_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) v2 = 
T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) T.reads(lv6[v0, v1, v2]) T.writes(lv6_reindex_pad_shared[v0, v1, v2]) T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) lv6_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv6[v0, v1, v2], T.float16(0)) for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(4): for ax0_ax1_ax2_fused_3 in T.vectorized(2): with T.block("p_output0_intermediate_reindex_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(12288, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) T.reads(lv779[v1, v2 // 8], lv780[v1, v2 // 32]) T.writes(p_output0_intermediate_reindex_shared[v0, v1, v2]) T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) p_output0_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv779[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv780[v1, v2 // 32] for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 4): with T.block("NT_matmul_update"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3) v2 = T.axis.spatial(12288, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3) v3 = T.axis.reduce(4096, ax3_0 * 16 + ax3_1) T.reads(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2], lv6_reindex_pad_shared[0, v1, v3], p_output0_intermediate_reindex_shared[0, v2, v3]) T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2]) var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + lv6_reindex_pad_shared[0, v1, v3] * p_output0_intermediate_reindex_shared[0, v2, v3] for ax0, ax1, ax2_0 in T.grid(1, 4, 2): for ax2_1_1 in T.vectorized(2): with T.block("var_NT_matmul_intermediate_reindex_pad_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1) v2 = T.axis.spatial(12288, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1) T.reads(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]) T.writes(var_NT_matmul_intermediate[0, v1, v2]) if v1 < n: var_NT_matmul_intermediate[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] @T.prim_func(private=True) def fused_fused_decode2_NT_matmul6(lv3: T.Buffer((12288, 512), "uint32"), lv4: T.Buffer((12288, 128), "float16"), lv1615: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 12288), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 12288), "float16", scope="local") var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 12288), "float16", scope="local") lv3_local = T.alloc_buffer((12288, 512), "uint32", scope="local") lv1615_shared = T.alloc_buffer((1, 1, 4096), "float16", scope="shared") for u_fused_ax0_fused_fused_0 in T.thread_binding(192, thread="blockIdx.x"): for u_fused_ax0_fused_fused_1 in T.thread_binding(64, thread="threadIdx.y"): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"): for ax0, 
ax1 in T.grid(1, 1): for ax2_0 in T.serial(2, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}): for ax2_1 in T.thread_binding(64, thread="threadIdx.y"): for ax2_2 in T.thread_binding(8, thread="threadIdx.x"): for ax2_3 in T.vectorized(4): with T.block("lv1615_shared"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) v2 = T.axis.spatial(4096, ax2_0 * 2048 + ax2_1 * 32 + ax2_2 * 4 + ax2_3) T.reads(lv1615[v0, v1, v2]) T.writes(lv1615_shared[v0, v1, v2]) lv1615_shared[v0, v1, v2] = lv1615[v0, v1, v2] for u_fused_ax0_fused_fused_2_init in range(1): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2): with T.block("NT_matmul_rf_init"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init) v0 = T.axis.spatial(12288, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) T.reads() T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0]) var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0) for ax1_0_fused_ax1_1_fused_0 in T.serial(64, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax0_0, ax1 in T.grid(1, 1): for ax0_1 in T.vectorized(1): with T.block("lv3_local"): v0 = T.axis.spatial(12288, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1) v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1) T.reads(lv3[v0, v1]) T.writes(lv3_local[v0, v1]) lv3_local[v0, v1] = lv3[v0, v1] for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2): with T.block("NT_matmul_rf_update"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1) v0 = T.axis.spatial(12288, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2]) T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv1615_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv3_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv4[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32]) T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0]) var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv1615_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + 
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv3_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv4[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32]) for ax2_fused_0 in T.thread_binding(64, thread="threadIdx.y"): for ax0 in T.thread_binding(8, thread="threadIdx.x"): for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_fused_1_1 in T.vectorized(1): with T.block("NT_matmul_rf_init"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0) v0 = T.axis.spatial(12288, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) T.reads() T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0) for ax1 in range(2): with T.block("NT_matmul_rf_update"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial(12288, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]) T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0] for ax1_fused_1 in range(1): for ax1_fused_0 in T.thread_binding(64, thread="threadIdx.y"): for ax0 in T.thread_binding(8, thread="threadIdx.x"): with T.block("NT_matmul"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0) v0 = T.axis.spatial(12288, u_fused_ax0_fused_fused_0 * 64 + ax1_fused_0 + ax1_fused_1) T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) T.writes(var_NT_matmul_intermediate[0, 0, v0]) with T.init(): var_NT_matmul_intermediate[0, 0, v0] = T.float16(0) var_NT_matmul_intermediate[0, 0, v0] = var_NT_matmul_intermediate[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] @T.prim_func(private=True) def 
fused_fused_decode3_fused_NT_matmul2_add1(lv784: T.Buffer((4096, 512), "uint32"), lv785: T.Buffer((4096, 128), "float16"), p_lv41: T.handle, p_lv2: T.handle, p_output0: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() lv41 = T.match_buffer(p_lv41, (1, n, 4096), "float16") lv2 = T.match_buffer(p_lv2, (1, n, 4096), "float16") p_output0_intermediate = T.match_buffer(p_output0, (1, n, 4096), "float16") # with T.block("root"): var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (n + 31) // 32 * 32, 4096), "float16", scope="local") lv41_reindex_pad_shared = T.alloc_buffer((1, (n + 31) // 32 * 32, 4096), "float16", scope="shared") p_output0_intermediate_reindex_shared = T.alloc_buffer((1, 4096, 4096), "float16", scope="shared") for ax0_ax2_0_fused in T.thread_binding(64, thread="blockIdx.y"): for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"): for ax2_1 in T.thread_binding(1, thread="vthread.y"): for ax1_1 in T.thread_binding(1, thread="vthread.x"): for ax2_2 in T.thread_binding(16, thread="threadIdx.y"): for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_3_init, ax1_3_init in T.grid(4, 4): with T.block("NT_matmul_init"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init) v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init) T.reads() T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2]) var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0) for ax3_0 in range(256): for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(2): for ax0_ax1_ax2_fused_3 in T.vectorized(2): with T.block("lv41_reindex_pad_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) T.reads(lv41[v0, v1, v2]) T.writes(lv41_reindex_pad_shared[v0, v1, v2]) T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) lv41_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv41[v0, v1, v2], T.float16(0)) for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(4): for ax0_ax1_ax2_fused_3 in T.vectorized(2): with T.block("p_output0_intermediate_reindex_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) T.reads(lv784[v1, v2 // 8], lv785[v1, v2 // 32]) T.writes(p_output0_intermediate_reindex_shared[v0, v1, v2]) T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) p_output0_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv784[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv785[v1, v2 // 32] for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 4): with T.block("NT_matmul_update"): v0 = 
T.axis.spatial(1, 0) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3) v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3) v3 = T.axis.reduce(4096, ax3_0 * 16 + ax3_1) T.reads(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2], lv41_reindex_pad_shared[0, v1, v3], p_output0_intermediate_reindex_shared[0, v2, v3]) T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2]) var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + lv41_reindex_pad_shared[0, v1, v3] * p_output0_intermediate_reindex_shared[0, v2, v3] for ax0, ax1, ax2_0 in T.grid(1, 4, 2): for ax2_1_1 in T.vectorized(2): with T.block("var_NT_matmul_intermediate_reindex_pad_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1) v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1) T.reads(lv2[0, v1, v2], var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]) T.writes(p_output0_intermediate[0, v1, v2]) if v1 < n: p_output0_intermediate[0, v1, v2] = lv2[0, v1, v2] + var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] @T.prim_func(private=True) def fused_fused_decode3_fused_NT_matmul8_add(lv15: T.Buffer((4096, 512), "uint32"), lv16: T.Buffer((4096, 128), "float16"), lv14: T.Buffer((1, 1, 4096), "float16"), lv1613: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): var_NT_matmul_intermediate_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local") var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16", scope="local") var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 4096), "float16", scope="local") lv15_local = T.alloc_buffer((4096, 512), "uint32", scope="local") lv14_shared = T.alloc_buffer((1, 1, 4096), "float16", scope="shared") for u_fused_ax0_fused_fused_0 in T.thread_binding(64, thread="blockIdx.x"): for u_fused_ax0_fused_fused_1 in T.thread_binding(64, thread="threadIdx.y"): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"): for ax0, ax1 in T.grid(1, 1): for ax2_0 in T.serial(2, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}): for ax2_1 in T.thread_binding(64, thread="threadIdx.y"): for ax2_2 in T.thread_binding(8, thread="threadIdx.x"): for ax2_3 in T.vectorized(4): with T.block("lv14_shared"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) v2 = T.axis.spatial(4096, ax2_0 * 2048 + ax2_1 * 32 + ax2_2 * 4 + ax2_3) T.reads(lv14[v0, v1, v2]) T.writes(lv14_shared[v0, v1, v2]) lv14_shared[v0, v1, v2] = lv14[v0, v1, v2] for u_fused_ax0_fused_fused_2_init in range(1): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2): with T.block("NT_matmul_rf_init"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init) v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) T.reads() T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0]) var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0) for 
ax1_0_fused_ax1_1_fused_0 in T.serial(64, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax0_0, ax1 in T.grid(1, 1): for ax0_1 in T.vectorized(1): with T.block("lv15_local"): v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1) v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1) T.reads(lv15[v0, v1]) T.writes(lv15_local[v0, v1]) lv15_local[v0, v1] = lv15[v0, v1] for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2): with T.block("NT_matmul_rf_update"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1) v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2]) T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv14_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv15_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv16[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32]) T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0]) var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv14_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv15_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv16[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32]) for ax2_fused_0 in T.thread_binding(64, thread="threadIdx.y"): for ax0 in T.thread_binding(8, thread="threadIdx.x"): for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_fused_1_1 in T.vectorized(1): with T.block("NT_matmul_rf_init"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0) v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 
* 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) T.reads() T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0) for ax1 in range(2): with T.block("NT_matmul_rf_update"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]) T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0] for ax1_fused_1 in range(1): for ax1_fused_0 in T.thread_binding(64, thread="threadIdx.y"): for ax0 in T.thread_binding(8, thread="threadIdx.x"): with T.block("NT_matmul"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0) v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax1_fused_0 + ax1_fused_1) T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) T.writes(var_NT_matmul_intermediate_local[0, 0, v0]) with T.init(): var_NT_matmul_intermediate_local[0, 0, v0] = T.float16(0) var_NT_matmul_intermediate_local[0, 0, v0] = var_NT_matmul_intermediate_local[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] for ax0_fused_0 in T.thread_binding(64, thread="threadIdx.y"): for ax0_fused_1 in range(1): with T.block("T_add"): v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax0_fused_0 + ax0_fused_1) T.reads(lv1613[0, 0, v0], var_NT_matmul_intermediate_local[0, 0, v0]) T.writes(p_output0_intermediate[0, 0, v0]) p_output0_intermediate[0, 0, v0] = lv1613[0, 0, v0] + var_NT_matmul_intermediate_local[0, 0, v0] @T.prim_func(private=True) def fused_fused_decode4_NT_matmul3(lv788: T.Buffer((22016, 512), "uint32"), lv789: T.Buffer((22016, 128), "float16"), p_lv45: T.handle, p_output0: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() lv45 = T.match_buffer(p_lv45, (1, n, 4096), "float16") var_NT_matmul_intermediate = T.match_buffer(p_output0, (1, n, 22016), "float16") # with T.block("root"): var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (n + 31) // 32 * 32, 22016), "float16", scope="local") lv45_reindex_pad_shared = T.alloc_buffer((1, (n + 31) // 32 * 32, 4096), "float16", scope="shared") p_output0_intermediate_reindex_shared = T.alloc_buffer((1, 22016, 4096), "float16", scope="shared") for ax0_ax2_0_fused in T.thread_binding(344, thread="blockIdx.y"): for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"): for ax2_1 in T.thread_binding(1, 
thread="vthread.y"): for ax1_1 in T.thread_binding(1, thread="vthread.x"): for ax2_2 in T.thread_binding(16, thread="threadIdx.y"): for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_3_init, ax1_3_init in T.grid(4, 4): with T.block("NT_matmul_init"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init) v2 = T.axis.spatial(22016, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init) T.reads() T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2]) var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0) for ax3_0 in range(256): for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(2): for ax0_ax1_ax2_fused_3 in T.vectorized(2): with T.block("lv45_reindex_pad_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) T.reads(lv45[v0, v1, v2]) T.writes(lv45_reindex_pad_shared[v0, v1, v2]) T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) lv45_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv45[v0, v1, v2], T.float16(0)) for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(4): for ax0_ax1_ax2_fused_3 in T.vectorized(2): with T.block("p_output0_intermediate_reindex_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(22016, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) v2 = T.axis.spatial(4096, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) T.reads(lv788[v1, v2 // 8], lv789[v1, v2 // 32]) T.writes(p_output0_intermediate_reindex_shared[v0, v1, v2]) T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) p_output0_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv788[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv789[v1, v2 // 32] for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 4): with T.block("NT_matmul_update"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3) v2 = T.axis.spatial(22016, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3) v3 = T.axis.reduce(4096, ax3_0 * 16 + ax3_1) T.reads(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2], lv45_reindex_pad_shared[0, v1, v3], p_output0_intermediate_reindex_shared[0, v2, v3]) T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2]) var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + lv45_reindex_pad_shared[0, v1, v3] * p_output0_intermediate_reindex_shared[0, v2, v3] for ax0, ax1, ax2_0 in T.grid(1, 4, 2): for ax2_1_1 in T.vectorized(2): with T.block("var_NT_matmul_intermediate_reindex_pad_local"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1) v2 = T.axis.spatial(22016, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 
+ ax2_1_1) T.reads(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]) T.writes(var_NT_matmul_intermediate[0, v1, v2]) if v1 < n: var_NT_matmul_intermediate[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] @T.prim_func(private=True) def fused_fused_decode4_NT_matmul9(lv19: T.Buffer((22016, 512), "uint32"), lv20: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 22016), "float16", scope="local") var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 22016), "float16", scope="local") lv19_local = T.alloc_buffer((22016, 512), "uint32", scope="local") lv1654_shared = T.alloc_buffer((1, 1, 4096), "float16", scope="shared") for u_fused_ax0_fused_fused_0 in T.thread_binding(344, thread="blockIdx.x"): for u_fused_ax0_fused_fused_1 in T.thread_binding(64, thread="threadIdx.y"): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"): for ax0, ax1 in T.grid(1, 1): for ax2_0 in T.serial(2, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}): for ax2_1 in T.thread_binding(64, thread="threadIdx.y"): for ax2_2 in T.thread_binding(8, thread="threadIdx.x"): for ax2_3 in T.vectorized(4): with T.block("lv1654_shared"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) v2 = T.axis.spatial(4096, ax2_0 * 2048 + ax2_1 * 32 + ax2_2 * 4 + ax2_3) T.reads(lv1654[v0, v1, v2]) T.writes(lv1654_shared[v0, v1, v2]) lv1654_shared[v0, v1, v2] = lv1654[v0, v1, v2] for u_fused_ax0_fused_fused_2_init in range(1): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2): with T.block("NT_matmul_rf_init"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init) v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) T.reads() T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0]) var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0) for ax1_0_fused_ax1_1_fused_0 in T.serial(64, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax0_0, ax1 in T.grid(1, 1): for ax0_1 in T.vectorized(1): with T.block("lv19_local"): v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1) v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1) T.reads(lv19[v0, v1]) T.writes(lv19_local[v0, v1]) lv19_local[v0, v1] = lv19[v0, v1] for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2): with T.block("NT_matmul_rf_update"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1) v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = 
T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2]) T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv1654_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv19_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv20[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32]) T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0]) var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv1654_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv19_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv20[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32]) for ax2_fused_0 in T.thread_binding(64, thread="threadIdx.y"): for ax0 in T.thread_binding(8, thread="threadIdx.x"): for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_fused_1_1 in T.vectorized(1): with T.block("NT_matmul_rf_init"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0) v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) T.reads() T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0) for ax1 in range(2): with T.block("NT_matmul_rf_update"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]) T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) 
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0] for ax1_fused_1 in range(1): for ax1_fused_0 in T.thread_binding(64, thread="threadIdx.y"): for ax0 in T.thread_binding(8, thread="threadIdx.x"): with T.block("NT_matmul"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0) v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 64 + ax1_fused_0 + ax1_fused_1) T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) T.writes(var_NT_matmul_intermediate[0, 0, v0]) with T.init(): var_NT_matmul_intermediate[0, 0, v0] = T.float16(0) var_NT_matmul_intermediate[0, 0, v0] = var_NT_matmul_intermediate[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] @T.prim_func(private=True) def fused_fused_decode5_fused_NT_matmul10_add(lv23: T.Buffer((4096, 1376), "uint32"), lv24: T.Buffer((4096, 344), "float16"), lv22: T.Buffer((1, 1, 11008), "float16"), lv18: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): var_NT_matmul_intermediate_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local") var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16", scope="local") var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 4096), "float16", scope="local") lv23_local = T.alloc_buffer((4096, 1376), "uint32", scope="local") lv22_shared = T.alloc_buffer((1, 1, 11008), "float16", scope="shared") for u_fused_ax0_fused_fused_0 in T.thread_binding(64, thread="blockIdx.x"): for u_fused_ax0_fused_fused_1 in T.thread_binding(64, thread="threadIdx.y"): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"): for ax0, ax1 in T.grid(1, 1): for ax2_0 in T.serial(22, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}): for ax2_1 in T.thread_binding(64, thread="threadIdx.y"): for ax2_2 in T.thread_binding(8, thread="threadIdx.x"): for ax2_3 in T.vectorized(1): with T.block("lv22_shared"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) v2 = T.axis.spatial(11008, ax2_0 * 512 + ax2_1 * 8 + ax2_2 + ax2_3) T.where((ax2_0 * 64 + ax2_1) * 8 + ax2_2 + ax2_3 < 11008) T.reads(lv22[v0, v1, v2]) T.writes(lv22_shared[v0, v1, v2]) lv22_shared[v0, v1, v2] = lv22[v0, v1, v2] for u_fused_ax0_fused_fused_2_init in range(1): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2): with T.block("NT_matmul_rf_init"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init) v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) T.reads() T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0]) var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0) 
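# Main reduction loop of fused_fused_decode5_fused_NT_matmul10_add: 172 serial steps cover the 1376 packed uint32 columns of lv23; each uint32 holds eight 4-bit weights that are decoded on the fly (shift, mask with 0xF, subtract 7, scale by lv24) and multiplied against the lv22 activations staged in shared memory.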
for ax1_0_fused_ax1_1_fused_0 in T.serial(172, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax0_0, ax1 in T.grid(1, 1): for ax0_1 in T.vectorized(1): with T.block("lv23_local"): v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1) v1 = T.axis.spatial(1376, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1) T.reads(lv23[v0, v1]) T.writes(lv23_local[v0, v1]) lv23_local[v0, v1] = lv23[v0, v1] for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4): for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2): with T.block("NT_matmul_rf_update"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1) v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2]) T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv22_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv23_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv24[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32]) T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0]) var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv22_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv23_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv24[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32]) for ax2_fused_0 in T.thread_binding(64, thread="threadIdx.y"): for ax0 in T.thread_binding(8, thread="threadIdx.x"): for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_fused_1_1 in T.vectorized(1): with T.block("NT_matmul_rf_init"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0) v0 = T.axis.spatial(4096, 
u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) T.reads() T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0) for ax1 in range(2): with T.block("NT_matmul_rf_update"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]) T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0] for ax1_fused_1 in range(1): for ax1_fused_0 in T.thread_binding(64, thread="threadIdx.y"): for ax0 in T.thread_binding(8, thread="threadIdx.x"): with T.block("NT_matmul"): vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0) v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax1_fused_0 + ax1_fused_1) T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) T.writes(var_NT_matmul_intermediate_local[0, 0, v0]) with T.init(): var_NT_matmul_intermediate_local[0, 0, v0] = T.float16(0) var_NT_matmul_intermediate_local[0, 0, v0] = var_NT_matmul_intermediate_local[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] for ax0_fused_0 in T.thread_binding(64, thread="threadIdx.y"): for ax0_fused_1 in range(1): with T.block("T_add"): v0 = T.axis.spatial(4096, u_fused_ax0_fused_fused_0 * 64 + ax0_fused_0 + ax0_fused_1) T.reads(lv18[0, 0, v0], var_NT_matmul_intermediate_local[0, 0, v0]) T.writes(p_output0_intermediate[0, 0, v0]) p_output0_intermediate[0, 0, v0] = lv18[0, 0, v0] + var_NT_matmul_intermediate_local[0, 0, v0] @T.prim_func(private=True) def fused_fused_decode5_fused_NT_matmul4_add1(lv792: T.Buffer((4096, 1376), "uint32"), lv793: T.Buffer((4096, 344), "float16"), p_lv791: T.handle, p_lv787: T.handle, p_output0: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() lv791 = T.match_buffer(p_lv791, (1, n, 11008), "float16") lv787 = T.match_buffer(p_lv787, (1, n, 4096), "float16") p_output0_intermediate = T.match_buffer(p_output0, (1, n, 4096), "float16") # with T.block("root"): var_NT_matmul_intermediate_reindex_pad_local = T.alloc_buffer((1, (n + 31) // 32 * 32, 4096), "float16", scope="local") lv791_reindex_pad_shared = T.alloc_buffer((1, (n + 31) // 32 * 32, 11008), "float16", scope="shared") p_output0_intermediate_reindex_shared = T.alloc_buffer((1, 4096, 11008), "float16", scope="shared") for ax0_ax2_0_fused in T.thread_binding(64, 
thread="blockIdx.y"): for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"): for ax2_1 in T.thread_binding(1, thread="vthread.y"): for ax1_1 in T.thread_binding(1, thread="vthread.x"): for ax2_2 in T.thread_binding(16, thread="threadIdx.y"): for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_3_init, ax1_3_init in T.grid(4, 4): with T.block("NT_matmul_init"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init) v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init) T.reads() T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2]) var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = T.float16(0) for ax3_0 in range(688): for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(2): for ax0_ax1_ax2_fused_3 in T.vectorized(2): with T.block("lv791_reindex_pad_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) v2 = T.axis.spatial(11008, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) T.reads(lv791[v0, v1, v2]) T.writes(lv791_reindex_pad_shared[v0, v1, v2]) T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) lv791_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv791[v0, v1, v2], T.float16(0)) for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(4): for ax0_ax1_ax2_fused_3 in T.vectorized(2): with T.block("p_output0_intermediate_reindex_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) v2 = T.axis.spatial(11008, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) T.reads(lv792[v1, v2 // 8], lv793[v1, v2 // 32]) T.writes(p_output0_intermediate_reindex_shared[v0, v1, v2]) T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) p_output0_intermediate_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv792[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv793[v1, v2 // 32] for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 4): with T.block("NT_matmul_update"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3) v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3) v3 = T.axis.reduce(11008, ax3_0 * 16 + ax3_1) T.reads(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2], lv791_reindex_pad_shared[0, v1, v3], p_output0_intermediate_reindex_shared[0, v2, v3]) T.writes(var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2]) var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] = var_NT_matmul_intermediate_reindex_pad_local[0, v1, v2] + lv791_reindex_pad_shared[0, v1, v3] * p_output0_intermediate_reindex_shared[0, v2, v3] for ax0, ax1, ax2_0 in T.grid(1, 4, 2): for ax2_1_1 in T.vectorized(2): with T.block("var_NT_matmul_intermediate_reindex_pad_local"): v0 = T.axis.spatial(1, ax0) v1 = 
T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1) v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1) T.reads(lv787[0, v1, v2], var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]) T.writes(p_output0_intermediate[0, v1, v2]) if v1 < n: p_output0_intermediate[0, v1, v2] = lv787[0, v1, v2] + var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] @T.prim_func(private=True) def fused_min_max_triu_te_broadcast_to(p_output0: T.handle, n: T.int32): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) var_T_broadcast_to_intermediate = T.match_buffer(p_output0, (1, 1, n, n), "float16") # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding((n * n + 255) // 256, thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_broadcast_to"): v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // n) v1 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % n) T.where(ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 < n * n) T.reads() T.writes(var_T_broadcast_to_intermediate[0, 0, v0, v1]) var_T_broadcast_to_intermediate[0, 0, v0, v1] = T.Select(v0 < v1, T.float16(-65504), T.float16(65504)) @T.prim_func(private=True) def fused_reshape2_squeeze(lv_1: T.Buffer((1, 1, 4096), "float16"), var_T_squeeze_intermediate: T.Buffer((1, 32, 128), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding(16, thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_squeeze"): v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 128) v1 = T.axis.spatial(128, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 128) T.reads(lv_1[0, 0, v0 * 128 + v1]) T.writes(var_T_squeeze_intermediate[0, v0, v1]) var_T_squeeze_intermediate[0, v0, v1] = lv_1[0, 0, v0 * 128 + v1] @T.prim_func(private=True) def fused_reshape2_transpose5(lv_0: T.Buffer((1, 1, 4096), "float16"), var_T_transpose_intermediate: T.Buffer((1, 32, 1, 128), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding(16, thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_transpose"): v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 128) v1 = T.axis.spatial(128, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 128) T.reads(lv_0[0, 0, v0 * 128 + v1]) T.writes(var_T_transpose_intermediate[0, v0, 0, v1]) var_T_transpose_intermediate[0, v0, 0, v1] = lv_0[0, 0, v0 * 128 + v1] @T.prim_func(private=True) def fused_softmax1_cast4(p_lv36: T.handle, p_output0: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n, m = T.int32(), T.int32() lv36 = T.match_buffer(p_lv36, (1, 32, n, m)) var_compute_intermediate = T.match_buffer(p_output0, (1, 32, n, m), "float16") # with T.block("root"): T_softmax_maxelem_shared = T.alloc_buffer((1, 32, n), scope="shared") T_softmax_expsum_shared = T.alloc_buffer((1, 32, n), scope="shared") for ax0_ax1_fused in T.thread_binding(n * 32, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): for ax0, ax1, ax2_fused_0 in T.grid(1, 1, (m + 63) // 64): for ax2_fused_1 in T.thread_binding(64, thread="threadIdx.x"): with T.block("T_softmax_maxelem"): v0 = T.axis.spatial(32, ax0_ax1_fused // n + ax0) v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax1) v2 = T.axis.reduce(m, 
ax2_fused_0 * 64 + ax2_fused_1) T.where(ax2_fused_0 * 64 + ax2_fused_1 < m) T.reads(lv36[0, v0, v1, v2]) T.writes(T_softmax_maxelem_shared[0, v0, v1]) with T.init(): T_softmax_maxelem_shared[0, v0, v1] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem_shared[0, v0, v1] = T.max(T_softmax_maxelem_shared[0, v0, v1], lv36[0, v0, v1, v2]) for ax0, ax1, ax2_fused_0 in T.grid(1, 1, (m + 63) // 64): for ax2_fused_1 in T.thread_binding(64, thread="threadIdx.x"): with T.block("T_softmax_expsum"): v0 = T.axis.spatial(32, ax0_ax1_fused // n + ax0) v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax1) v2 = T.axis.reduce(m, ax2_fused_0 * 64 + ax2_fused_1) T.where(ax2_fused_0 * 64 + ax2_fused_1 < m) T.reads(lv36[0, v0, v1, v2], T_softmax_maxelem_shared[0, v0, v1]) T.writes(T_softmax_expsum_shared[0, v0, v1]) with T.init(): T_softmax_expsum_shared[0, v0, v1] = T.float32(0) T_softmax_expsum_shared[0, v0, v1] = T_softmax_expsum_shared[0, v0, v1] + T.exp(lv36[0, v0, v1, v2] - T_softmax_maxelem_shared[0, v0, v1]) for ax2_0 in range((m + 63) // 64): for ax2_1 in T.thread_binding(64, thread="threadIdx.x"): with T.block("compute"): v0 = T.axis.spatial(32, ax0_ax1_fused // n) v1 = T.axis.spatial(n, ax0_ax1_fused % n) v2 = T.axis.spatial(m, ax2_0 * 64 + ax2_1) T.where(ax2_0 * 64 + ax2_1 < m) T.reads(lv36[0, v0, v1, v2], T_softmax_maxelem_shared[0, v0, v1], T_softmax_expsum_shared[0, v0, v1]) T.writes(var_compute_intermediate[0, v0, v1, v2]) var_compute_intermediate[0, v0, v1, v2] = T.Cast("float16", T.exp(lv36[0, v0, v1, v2] - T_softmax_maxelem_shared[0, v0, v1]) / T_softmax_expsum_shared[0, v0, v1]) @T.prim_func(private=True) def fused_softmax_cast1(p_lv1645: T.handle, p_output0: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() lv1645 = T.match_buffer(p_lv1645, (1, 32, 1, n)) var_compute_intermediate = T.match_buffer(p_output0, (1, 32, 1, n), "float16") # with T.block("root"): T_softmax_maxelem_shared = T.alloc_buffer((1, 32, 1), scope="shared") T_softmax_expsum_shared = T.alloc_buffer((1, 32, 1), scope="shared") for ax0_fused in T.thread_binding(32, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): for ax0, ax1_fused_0 in T.grid(1, (n + 63) // 64): for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): with T.block("T_softmax_maxelem"): v0 = T.axis.spatial(32, ax0_fused + ax0) v1 = T.axis.reduce(n, ax1_fused_0 * 64 + ax1_fused_1) T.where(ax1_fused_0 * 64 + ax1_fused_1 < n) T.reads(lv1645[0, v0, 0, v1]) T.writes(T_softmax_maxelem_shared[0, v0, 0]) with T.init(): T_softmax_maxelem_shared[0, v0, 0] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem_shared[0, v0, 0] = T.max(T_softmax_maxelem_shared[0, v0, 0], lv1645[0, v0, 0, v1]) for ax0, ax1_fused_0 in T.grid(1, (n + 63) // 64): for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): with T.block("T_softmax_expsum"): v0 = T.axis.spatial(32, ax0_fused + ax0) v1 = T.axis.reduce(n, ax1_fused_0 * 64 + ax1_fused_1) T.where(ax1_fused_0 * 64 + ax1_fused_1 < n) T.reads(lv1645[0, v0, 0, v1], T_softmax_maxelem_shared[0, v0, 0]) T.writes(T_softmax_expsum_shared[0, v0, 0]) with T.init(): T_softmax_expsum_shared[0, v0, 0] = T.float32(0) T_softmax_expsum_shared[0, v0, 0] = T_softmax_expsum_shared[0, v0, 0] + T.exp(lv1645[0, v0, 0, v1] - T_softmax_maxelem_shared[0, v0, 0]) for ax1_0 in range((n + 63) // 64): for ax1_1 in T.thread_binding(64, thread="threadIdx.x"): with T.block("compute"): v0 = T.axis.spatial(32, ax0_fused) v1 = T.axis.spatial(n, ax1_0 * 64 + ax1_1) 
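# Final pass of fused_softmax_cast1: normalize with the shared row maximum and exp-sum computed above, then cast the float32 softmax values to float16.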
T.where(ax1_0 * 64 + ax1_1 < n) T.reads(lv1645[0, v0, 0, v1], T_softmax_maxelem_shared[0, v0, 0], T_softmax_expsum_shared[0, v0, 0]) T.writes(var_compute_intermediate[0, v0, 0, v1]) var_compute_intermediate[0, v0, 0, v1] = T.Cast("float16", T.exp(lv1645[0, v0, 0, v1] - T_softmax_maxelem_shared[0, v0, 0]) / T_softmax_expsum_shared[0, v0, 0]) @T.prim_func(private=True) def fused_split2_silu1_multiply1(p_lv3: T.handle, p_output0: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() lv3 = T.match_buffer(p_lv3, (1, n, 22016), "float16") var_T_multiply_intermediate = T.match_buffer(p_output0, (1, n, 11008), "float16") # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding(n * 43, thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_multiply_1"): v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 11008) v1 = T.axis.spatial(11008, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 11008) T.reads(lv3[0, v0, v1:v1 + 11009]) T.writes(var_T_multiply_intermediate[0, v0, v1]) var_T_multiply_intermediate[0, v0, v1] = lv3[0, v0, v1] * T.sigmoid(lv3[0, v0, v1]) * lv3[0, v0, v1 + 11008] @T.prim_func(private=True) def fused_split_silu_multiply(lv164: T.Buffer((1, 1, 22016), "float16"), var_T_multiply_intermediate: T.Buffer((1, 1, 11008), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): for ax0_fused_0 in T.thread_binding(43, thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_multiply_1"): v0 = T.axis.spatial(11008, ax0_fused_0 * 256 + ax0_fused_1) T.reads(lv164[0, 0, v0:v0 + 11009]) T.writes(var_T_multiply_intermediate[0, 0, v0]) var_T_multiply_intermediate[0, 0, v0] = lv164[0, 0, v0] * T.sigmoid(lv164[0, 0, v0]) * lv164[0, 0, v0 + 11008] @T.prim_func(private=True) def fused_transpose7_reshape4(lv1648: T.Buffer((1, 32, 1, 128), "float16"), var_T_reshape_intermediate: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + ax0_fused_1) T.reads(lv1648[0, v0 // 128, 0, v0 % 128]) T.writes(var_T_reshape_intermediate[0, 0, v0]) var_T_reshape_intermediate[0, 0, v0] = lv1648[0, v0 // 128, 0, v0 % 128] @T.prim_func(private=True) def matmul10(var_A: T.handle, var_B: T.handle, var_matmul: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n, m = T.int32(), T.int32() A = T.match_buffer(var_A, (1, 32, n, m), "float16") B = T.match_buffer(var_B, (1, 32, m, 128), "float16") matmul = T.match_buffer(var_matmul, (1, 32, n, 128), "float16") # with T.block("root"): matmul_reindex_pad_local = T.alloc_buffer((32, (n + 31) // 32 * 32, 128), "float16", scope="local") A_reindex_pad_shared = T.alloc_buffer((32, (n + 31) // 32 * 32, (m + 15) // 16 * 16), "float16", scope="shared") B_reindex_pad_shared = T.alloc_buffer((32, 128, (m + 15) // 16 * 16), "float16", scope="shared") for ax0_ax2_0_fused in T.thread_binding(64, thread="blockIdx.y"): for ax1_0 in T.thread_binding((n + 31) // 32, thread="blockIdx.x"): for ax2_1 in T.thread_binding(1, thread="vthread.y"): for ax1_1 in T.thread_binding(1, thread="vthread.x"): for ax2_2 in T.thread_binding(16, thread="threadIdx.y"): for ax1_2 in T.thread_binding(8, thread="threadIdx.x", 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_3_init, ax1_3_init in T.grid(4, 4): with T.block("matmul_init"): v0 = T.axis.spatial(32, ax0_ax2_0_fused // 2) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init) v2 = T.axis.spatial(128, ax0_ax2_0_fused % 2 * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_init) T.reads() T.writes(matmul_reindex_pad_local[v0, v1, v2]) matmul_reindex_pad_local[v0, v1, v2] = T.float16(0) for ax3_0 in range((m + 15) // 16): for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(2): for ax0_ax1_ax2_fused_3 in T.vectorized(2): with T.block("A_reindex_pad_shared"): v0 = T.axis.spatial(32, ax0_ax2_0_fused // 2) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) v2 = T.axis.spatial((m + 15) // 16 * 16, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) T.reads(A[0, v0, v1, v2]) T.writes(A_reindex_pad_shared[v0, v1, v2]) T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n and v2 < m, A[0, v0, v1, v2], T.float16(0)) for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): for ax0_ax1_ax2_fused_2 in range(4): for ax0_ax1_ax2_fused_3 in T.vectorized(2): with T.block("B_reindex_pad_shared"): v0 = T.axis.spatial(32, ax0_ax2_0_fused // 2) v1 = T.axis.spatial(128, ax0_ax2_0_fused % 2 * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) v2 = T.axis.spatial((m + 15) // 16 * 16, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) T.reads(B[0, v0, v2, v1]) T.writes(B_reindex_pad_shared[v0, v1, v2]) T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) B_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v2 < m, B[0, v0, v2, v1], T.float16(0)) for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 4): with T.block("matmul_update"): v0 = T.axis.spatial(32, ax0_ax2_0_fused // 2) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3) v2 = T.axis.spatial(128, ax0_ax2_0_fused % 2 * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3) v3 = T.axis.reduce((m + 15) // 16 * 16, ax3_0 * 16 + ax3_1) T.reads(matmul_reindex_pad_local[v0, v1, v2], A_reindex_pad_shared[v0, v1, v3], B_reindex_pad_shared[v0, v2, v3]) T.writes(matmul_reindex_pad_local[v0, v1, v2]) matmul_reindex_pad_local[v0, v1, v2] = matmul_reindex_pad_local[v0, v1, v2] + A_reindex_pad_shared[v0, v1, v3] * B_reindex_pad_shared[v0, v2, v3] for ax0, ax1, ax2_0 in T.grid(1, 4, 2): for ax2_1_1 in T.vectorized(2): with T.block("matmul_reindex_pad_local"): v0 = T.axis.spatial(32, ax0_ax2_0_fused // 2 + ax0) v1 = T.axis.spatial((n + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1) v2 = T.axis.spatial(128, ax0_ax2_0_fused % 2 * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1) T.reads(matmul_reindex_pad_local[v0, v1, v2]) T.writes(matmul[0, v0, v1, v2]) if v1 < n: matmul[0, v0, v1, v2] = matmul_reindex_pad_local[v0, v1, v2] @T.prim_func(private=True) def matmul9(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((1, 32, 1, 128), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() A = 
T.match_buffer(var_A, (1, 32, 1, n), "float16") B = T.match_buffer(var_B, (1, 32, n, 128), "float16") # with T.block("root"): matmul_rf_local = T.alloc_buffer((16, 1, 32, 1, 128), "float16", scope="local") for ax0_ax1_fused_0 in T.thread_binding(256, thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.x"): for ax2_fused_1 in T.thread_binding(16, thread="threadIdx.y"): with T.block("matmul_rf_init"): vax2_fused_1 = T.axis.spatial(16, ax2_fused_1) v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1) // 128) v1 = T.axis.spatial(128, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1) % 128) T.reads() T.writes(matmul_rf_local[vax2_fused_1, 0, v0, 0, v1]) matmul_rf_local[vax2_fused_1, 0, v0, 0, v1] = T.float16(0) for ax2_fused_0, u in T.grid((n + 15) // 16, 1): with T.block("matmul_rf_update"): vax2_fused_1 = T.axis.spatial(16, ax2_fused_1) v0 = T.axis.spatial(32, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1) // 128) v1 = T.axis.spatial(128, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1) % 128) vax2_fused_0 = T.axis.reduce((n + 15) // 16, ax2_fused_0) T.where(ax2_fused_0 * 16 + ax2_fused_1 < n) T.reads(matmul_rf_local[vax2_fused_1, 0, v0, 0, v1], A[0, v0, 0, vax2_fused_0 * 16 + vax2_fused_1], B[0, v0, vax2_fused_0 * 16 + vax2_fused_1, v1]) T.writes(matmul_rf_local[vax2_fused_1, 0, v0, 0, v1]) matmul_rf_local[vax2_fused_1, 0, v0, 0, v1] = matmul_rf_local[vax2_fused_1, 0, v0, 0, v1] + A[0, v0, 0, vax2_fused_0 * 16 + vax2_fused_1] * B[0, v0, vax2_fused_0 * 16 + vax2_fused_1, v1] for ax1_ax2_fused in T.thread_binding(16, thread="threadIdx.x"): for ax0 in T.thread_binding(16, thread="threadIdx.y"): with T.block("matmul"): vax2_fused_1 = T.axis.reduce(16, ax0) v0 = T.axis.spatial(32, ax0_ax1_fused_0 // 8) v1 = T.axis.spatial(128, ax0_ax1_fused_0 % 8 * 16 + ax1_ax2_fused) T.reads(matmul_rf_local[vax2_fused_1, 0, v0, 0, v1]) T.writes(matmul[0, v0, 0, v1]) with T.init(): matmul[0, v0, 0, v1] = T.float16(0) matmul[0, v0, 0, v1] = matmul[0, v0, 0, v1] + matmul_rf_local[vax2_fused_1, 0, v0, 0, v1] @T.prim_func(private=True) def reshape(A: T.Buffer((1, 1), "int32"), T_reshape: T.Buffer((1,), "int32")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): for ax0_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(1, 0) T.where(ax0_fused_0 * 256 + ax0_fused_1 < 1) T.reads(A[0, 0]) T.writes(T_reshape[0]) T_reshape[0] = A[0, 0] @T.prim_func(private=True) def reshape1(A: T.Buffer((1, 4096), "float16"), T_reshape: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + ax0_fused_1) T.reads(A[0, v0]) T.writes(T_reshape[0, 0, v0]) T_reshape[0, 0, v0] = A[0, v0] @T.prim_func(private=True) def reshape3(var_A: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() A = T.match_buffer(var_A, (n, 32, 128), "float16") T_reshape = T.match_buffer(var_T_reshape, (1, n, 32, 128), "float16") # with T.block("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 256 
+ ax0_ax1_ax2_fused_1) // 4096) v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 4096 // 128) v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 128) T.reads(A[v0, v1, v2]) T.writes(T_reshape[0, v0, v1, v2]) T_reshape[0, v0, v1, v2] = A[v0, v1, v2] @T.prim_func(private=True) def reshape5(var_A: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() A = T.match_buffer(var_A, (1, n), "int32") T_reshape = T.match_buffer(var_T_reshape, (n,), "int32") # with T.block("root"): for ax0_fused_0 in T.thread_binding((n + 255) // 256, thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(n, ax0_fused_0 * 256 + ax0_fused_1) T.where(ax0_fused_0 * 256 + ax0_fused_1 < n) T.reads(A[0, v0]) T.writes(T_reshape[v0]) T_reshape[v0] = A[0, v0] @T.prim_func(private=True) def reshape6(var_A: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() A = T.match_buffer(var_A, (n, 4096), "float16") T_reshape = T.match_buffer(var_T_reshape, (1, n, 4096), "float16") # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 4096) v1 = T.axis.spatial(4096, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 4096) T.reads(A[v0, v1]) T.writes(T_reshape[0, v0, v1]) T_reshape[0, v0, v1] = A[v0, v1] @T.prim_func(private=True) def reshape7(var_A: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() A = T.match_buffer(var_A, (1, n, 4096), "float16") T_reshape = T.match_buffer(var_T_reshape, (1, n, 32, 128), "float16") # with T.block("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) // 4096) v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 4096 // 128) v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 128) T.reads(A[0, v0, v1 * 128 + v2]) T.writes(T_reshape[0, v0, v1, v2]) T_reshape[0, v0, v1, v2] = A[0, v0, v1 * 128 + v2] @T.prim_func(private=True) def reshape8(var_A: T.handle, var_T_reshape: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() A = T.match_buffer(var_A, (1, n, 32, 128), "float16") T_reshape = T.match_buffer(var_T_reshape, (1, n, 4096), "float16") # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_reshape"): v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 4096) v1 = T.axis.spatial(4096, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 4096) T.reads(A[0, v0, v1 // 128, v1 % 128]) T.writes(T_reshape[0, v0, v1]) T_reshape[0, v0, v1] = A[0, v0, v1 // 128, v1 % 128] @T.prim_func(private=True) def rms_norm(var_A: T.handle, B: T.Buffer((4096,), "float16"), var_rms_norm: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() A = T.match_buffer(var_A, (1, n, 4096), "float16") rms_norm_1 = T.match_buffer(var_rms_norm, (1, n, 4096), "float16") # with 
T.block("root"): Ared_temp_shared = T.alloc_buffer((1, n), scope="shared") Ared_temp_rf_local = T.alloc_buffer((64, 1, n), scope="local") for ax0_fused in T.thread_binding(n, thread="blockIdx.x"): for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): with T.block("Ared_temp_rf_init"): vax1_fused_1, v0 = T.axis.remap("SS", [ax1_fused_1, ax0_fused]) T.reads() T.writes(Ared_temp_rf_local[vax1_fused_1, 0, v0]) Ared_temp_rf_local[vax1_fused_1, 0, v0] = T.float32(0) for ax1_fused_0, u in T.grid(64, 1): with T.block("Ared_temp_rf_update"): vax1_fused_1, v0, vax1_fused_0 = T.axis.remap("SSR", [ax1_fused_1, ax0_fused, ax1_fused_0]) T.reads(Ared_temp_rf_local[vax1_fused_1, 0, v0], A[0, v0, vax1_fused_0 * 64 + vax1_fused_1]) T.writes(Ared_temp_rf_local[vax1_fused_1, 0, v0]) Ared_temp_rf_local[vax1_fused_1, 0, v0] = Ared_temp_rf_local[vax1_fused_1, 0, v0] + T.Cast("float32", A[0, v0, vax1_fused_0 * 64 + vax1_fused_1]) * T.Cast("float32", A[0, v0, vax1_fused_0 * 64 + vax1_fused_1]) for ax1_fused in range(1): for ax0 in T.thread_binding(64, thread="threadIdx.x"): with T.block("Ared_temp"): vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax0_fused]) T.reads(Ared_temp_rf_local[vax1_fused_1, 0, v0]) T.writes(Ared_temp_shared[0, v0]) with T.init(): Ared_temp_shared[0, v0] = T.float32(0) Ared_temp_shared[0, v0] = Ared_temp_shared[0, v0] + Ared_temp_rf_local[vax1_fused_1, 0, v0] for ax0_fused_0 in range(64): for ax0_fused_1 in T.thread_binding(64, thread="threadIdx.x"): with T.block("rms_norm"): v0 = T.axis.spatial(n, ax0_fused) v1 = T.axis.spatial(4096, ax0_fused_0 * 64 + ax0_fused_1) T.reads(B[v1], A[0, v0, v1], Ared_temp_shared[0, v0]) T.writes(rms_norm_1[0, v0, v1]) rms_norm_1[0, v0, v1] = T.Cast("float16", T.Cast("float32", B[v1]) * (T.Cast("float32", A[0, v0, v1]) / T.sqrt(Ared_temp_shared[0, v0] * T.float32(0.000244140625) + T.float32(1.0000000000000001e-05)))) @T.prim_func(private=True) def rms_norm1(A: T.Buffer((1, 1, 4096), "float16"), B: T.Buffer((4096,), "float16"), rms_norm: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): Ared_temp_shared = T.alloc_buffer((1, 1), scope="shared") Ared_temp_rf_local = T.alloc_buffer((64, 1, 1), scope="local") for ax0_fused in T.thread_binding(1, thread="blockIdx.x"): for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): with T.block("Ared_temp_rf_init"): vax1_fused_1 = T.axis.spatial(64, ax1_fused_1) v0 = T.axis.spatial(1, 0) T.reads() T.writes(Ared_temp_rf_local[vax1_fused_1, 0, 0]) Ared_temp_rf_local[vax1_fused_1, 0, 0] = T.float32(0) for ax1_fused_0, u in T.grid(64, 1): with T.block("Ared_temp_rf_update"): vax1_fused_1 = T.axis.spatial(64, ax1_fused_1) v0 = T.axis.spatial(1, 0) vax1_fused_0 = T.axis.reduce(64, ax1_fused_0) T.reads(Ared_temp_rf_local[vax1_fused_1, 0, 0], A[0, 0, vax1_fused_0 * 64 + vax1_fused_1]) T.writes(Ared_temp_rf_local[vax1_fused_1, 0, 0]) Ared_temp_rf_local[vax1_fused_1, 0, 0] = Ared_temp_rf_local[vax1_fused_1, 0, 0] + T.Cast("float32", A[0, 0, vax1_fused_0 * 64 + vax1_fused_1]) * T.Cast("float32", A[0, 0, vax1_fused_0 * 64 + vax1_fused_1]) for ax1_fused in range(1): for ax0 in T.thread_binding(64, thread="threadIdx.x"): with T.block("Ared_temp"): vax1_fused_1 = T.axis.reduce(64, ax0) v0 = T.axis.spatial(1, 0) T.reads(Ared_temp_rf_local[vax1_fused_1, 0, 0]) T.writes(Ared_temp_shared[0, 0]) with 
T.init(): Ared_temp_shared[0, 0] = T.float32(0) Ared_temp_shared[0, 0] = Ared_temp_shared[0, 0] + Ared_temp_rf_local[vax1_fused_1, 0, 0] for ax0_fused_0 in range(64): for ax0_fused_1 in T.thread_binding(64, thread="threadIdx.x"): with T.block("rms_norm"): v0 = T.axis.spatial(4096, ax0_fused_0 * 64 + ax0_fused_1) T.reads(B[v0], A[0, 0, v0], Ared_temp_shared[0, 0]) T.writes(rms_norm[0, 0, v0]) rms_norm[0, 0, v0] = T.Cast("float16", T.Cast("float32", B[v0]) * (T.Cast("float32", A[0, 0, v0]) / T.sqrt(Ared_temp_shared[0, 0] * T.float32(0.000244140625) + T.float32(1.0000000000000001e-05)))) @T.prim_func(private=True) def rotary_embedding(var_A: T.handle, B: T.Buffer((2048, 128), "float16"), C: T.Buffer((2048, 128), "float16"), var_rotary: T.handle, m: T.int32): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() A = T.match_buffer(var_A, (1, n, 32, 128), "float16") rotary = T.match_buffer(var_rotary, (1, n, 32, 128), "float16") # with T.block("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("rotary"): v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) // 4096) v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 4096 // 128) v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 128) T.reads(B[v0 + (m - n), v2], A[0, v0, v1, v2 + -64:v2 + -64 + 129], C[v0 + (m - n), v2]) T.writes(rotary[0, v0, v1, v2]) rotary[0, v0, v1, v2] = B[v0 + (m - n), v2] * A[0, v0, v1, v2] + C[v0 + (m - n), v2] * T.Select(64 <= v2, A[0, v0, v1, v2 + -64], A[0, v0, v1, v2 + 64] * T.float16(-1)) @T.prim_func(private=True) def slice(var_A: T.handle, slice_1: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() A = T.match_buffer(var_A, (1, n, 4096), "float16") # with T.block("root"): for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("slice"): v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + ax0_fused_1) T.reads(A[0, n - 1, v0]) T.writes(slice_1[0, 0, v0]) slice_1[0, 0, v0] = A[0, n - 1, v0] @T.prim_func(private=True) def slice1(A: T.Buffer((1, 1, 4096), "float16"), slice: T.Buffer((1, 1, 4096), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("slice"): v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + ax0_fused_1) T.reads(A[0, 0, v0]) T.writes(slice[0, 0, v0]) slice[0, 0, v0] = A[0, 0, v0] @T.prim_func(private=True) def softmax2(A: T.Buffer((1, 1, 32001), "float32"), T_softmax_norm: T.Buffer((1, 1, 32001), "float32")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): T_softmax_maxelem_shared = T.alloc_buffer((1, 1), scope="shared") T_softmax_expsum_shared = T.alloc_buffer((1, 1), scope="shared") for ax0_fused in T.thread_binding(1, thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): for ax0, ax1_fused_0 in T.grid(1, 501): for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): with T.block("T_softmax_maxelem"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.reduce(32001, ax1_fused_0 * 64 + ax1_fused_1) T.where(ax1_fused_0 * 64 + ax1_fused_1 < 32001) T.reads(A[0, 0, v1]) 
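# softmax2 reduces the full 32001-entry logit vector: a shared-memory running maximum here, followed by the exp-sum and normalization passes below.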
T.writes(T_softmax_maxelem_shared[0, 0]) with T.init(): T_softmax_maxelem_shared[0, 0] = T.float32(-3.4028234663852886e+38) T_softmax_maxelem_shared[0, 0] = T.max(T_softmax_maxelem_shared[0, 0], A[0, 0, v1]) for ax0, ax1_fused_0 in T.grid(1, 501): for ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): with T.block("T_softmax_expsum"): v0 = T.axis.spatial(1, ax0) v1 = T.axis.reduce(32001, ax1_fused_0 * 64 + ax1_fused_1) T.where(ax1_fused_0 * 64 + ax1_fused_1 < 32001) T.reads(A[0, 0, v1], T_softmax_maxelem_shared[0, 0]) T.writes(T_softmax_expsum_shared[0, 0]) with T.init(): T_softmax_expsum_shared[0, 0] = T.float32(0) T_softmax_expsum_shared[0, 0] = T_softmax_expsum_shared[0, 0] + T.exp(A[0, 0, v1] - T_softmax_maxelem_shared[0, 0]) for ax1_0 in range(501): for ax1_1 in T.thread_binding(64, thread="threadIdx.x"): with T.block("T_softmax_norm"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(32001, ax1_0 * 64 + ax1_1) T.where(ax1_0 * 64 + ax1_1 < 32001) T.reads(A[0, 0, v1], T_softmax_maxelem_shared[0, 0], T_softmax_expsum_shared[0, 0]) T.writes(T_softmax_norm[0, 0, v1]) T.block_attr({"axis": 2}) T_softmax_norm[0, 0, v1] = T.exp(A[0, 0, v1] - T_softmax_maxelem_shared[0, 0]) / T_softmax_expsum_shared[0, 0] @T.prim_func(private=True) def split1(var_A: T.handle, var_T_split: T.handle, var_T_split_1: T.handle, var_T_split_2: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() A = T.match_buffer(var_A, (1, n, 12288), "float16") T_split = T.match_buffer(var_T_split, (1, n, 4096), "float16") T_split_1 = T.match_buffer(var_T_split_1, (1, n, 4096), "float16") T_split_2 = T.match_buffer(var_T_split_2, (1, n, 4096), "float16") # with T.block("root"): for ax0_ax1_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_split"): v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 4096) v1 = T.axis.spatial(4096, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 4096) T.reads(A[0, v0, v1]) T.writes(T_split[0, v0, v1]) T_split[0, v0, v1] = A[0, v0, v1] for ax0_ax1_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_split_1"): v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 4096) v1 = T.axis.spatial(4096, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 4096) T.reads(A[0, v0, v1 + 4096]) T.writes(T_split_1[0, v0, v1]) T_split_1[0, v0, v1] = A[0, v0, v1 + 4096] for ax0_ax1_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"): for ax0_ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_split_2"): v0 = T.axis.spatial(n, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) // 4096) v1 = T.axis.spatial(4096, (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1) % 4096) T.reads(A[0, v0, v1 + 8192]) T.writes(T_split_2[0, v0, v1]) T_split_2[0, v0, v1] = A[0, v0, v1 + 8192] @T.prim_func def split_rotary(A: T.Buffer((1, 1, 12288), "float16"), cos: T.Buffer((2048, 128), "float16"), sin: T.Buffer((2048, 128), "float16"), T_split: T.Buffer((1, 1, 4096), "float16"), T_split_1: T.Buffer((1, 1, 4096), "float16"), T_split_2: T.Buffer((1, 1, 4096), "float16"), n: T.int32): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_split"): v0 = T.axis.spatial(4096, ax0_fused_0 * 256 + ax0_fused_1) T.reads(A[0, 0, 
v0], A[0, 0, v0 + 4096], A[0, 0, v0 + 8192]) T.writes(T_split[0, 0, v0], T_split_1[0, 0, v0], T_split_2[0, 0, v0]) T_split[0, 0, v0] = cos[n - 1, v0 % 128] * A[0, 0, v0] + sin[n - 1, v0 % 128] * T.Select(64 <= v0 % 128, A[0, 0, v0 + -64], A[0, 0, v0 + 64] * T.float16(-1)) T_split_1[0, 0, v0] = cos[n - 1, v0 % 128] * A[0, 0, v0 + 4096] + sin[n - 1, v0 % 128] * T.Select(64 <= v0 % 128, A[0, 0, v0 + 4032], A[0, 0, v0 + 4160] * T.float16(-1)) T_split_2[0, 0, v0] = A[0, 0, v0 + 8192] @T.prim_func(private=True) def squeeze1(var_A: T.handle, var_T_squeeze: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() A = T.match_buffer(var_A, (1, n, 32, 128), "float16") T_squeeze = T.match_buffer(var_T_squeeze, (n, 32, 128), "float16") # with T.block("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_squeeze"): v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) // 4096) v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 4096 // 128) v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 128) T.reads(A[0, v0, v1, v2]) T.writes(T_squeeze[v0, v1, v2]) T_squeeze[v0, v1, v2] = A[0, v0, v1, v2] @T.prim_func(private=True) def transpose6(var_A: T.handle, var_T_transpose: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() A = T.match_buffer(var_A, (1, n, 32, 128), "float16") T_transpose = T.match_buffer(var_T_transpose, (1, 32, n, 128), "float16") # with T.block("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_transpose"): v0 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) // (128 * n)) v1 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % (128 * n) // 128) v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 128) T.reads(A[0, v1, v0, v2]) T.writes(T_transpose[0, v0, v1, v2]) T_transpose[0, v0, v1, v2] = A[0, v1, v0, v2] @T.prim_func(private=True) def transpose8(var_A: T.handle, var_T_transpose: T.handle): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) n = T.int32() A = T.match_buffer(var_A, (1, 32, n, 128), "float16") T_transpose = T.match_buffer(var_T_transpose, (1, n, 32, 128), "float16") # with T.block("root"): for ax0_ax1_ax2_fused_0 in T.thread_binding(n * 16, thread="blockIdx.x"): for ax0_ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("T_transpose"): v0 = T.axis.spatial(n, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) // 4096) v1 = T.axis.spatial(32, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 4096 // 128) v2 = T.axis.spatial(128, (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1) % 128) T.reads(A[0, v1, v0, v2]) T.writes(T_transpose[0, v0, v1, v2]) T_transpose[0, v0, v1, v2] = A[0, v1, v0, v2] @R.function def create_kv_cache() -> R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, 
R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object): R.func_attr({"tir_var_upper_bound": {"m": 2048, "n": 2048}}) with R.dataflow(): lv3221: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3222: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3223: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3224: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3225: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3226: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3227: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3228: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3229: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3230: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3231: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3232: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3233: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3234: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3235: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3236: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3237: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3238: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3239: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), 
R.prim_value(0), sinfo_args=(R.Object,)) lv3240: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3241: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3242: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3243: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3244: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3245: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3246: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3247: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3248: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3249: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3250: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3251: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3252: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3253: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3254: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3255: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3256: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3257: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3258: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,)) lv3259: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), 
sinfo_args=(R.Object,))
            lv3260: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3261: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3262: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3263: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3264: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3265: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3266: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3267: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3268: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3269: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3270: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3271: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3272: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3273: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3274: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3275: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3276: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3277: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3278: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3279: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3280: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3281: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3282: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3283: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            lv3284: R.Object = R.call_packed("vm.builtin.attention_kv_cache_create", metadata["relax.expr.Constant"][0], R.shape([2048, 32, 128]), R.prim_value(0), sinfo_args=(R.Object,))
            gv2: R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object) = lv3221, lv3222, lv3223, lv3224, lv3225, lv3226, lv3227, lv3228, lv3229, lv3230, lv3231, lv3232, lv3233, lv3234, lv3235, lv3236, lv3237, lv3238, lv3239, lv3240, lv3241, lv3242, lv3243, lv3244, lv3245, lv3246, lv3247, lv3248, lv3249, lv3250, lv3251, lv3252, lv3253, lv3254, lv3255, lv3256, lv3257, lv3258, lv3259, lv3260, lv3261, lv3262, lv3263, lv3264, lv3265, lv3266, lv3267, lv3268, lv3269, lv3270, lv3271, lv3272, lv3273, lv3274, lv3275, lv3276, lv3277, lv3278, lv3279, lv3280, lv3281, lv3282, lv3283, lv3284
            R.output(gv2)
        return gv2

    @R.function
    def decode(input_ids1: R.Tensor((1, 1), dtype="int32"), all_seq_len: R.Shape(["n"]), kv_cache: R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object), params: R.Tuple(R.Tensor((32001, 512), dtype="uint32"), R.Tensor((32001, 128), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"),
R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 
1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), 
dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), 
R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((32001, 512), dtype="uint32"), R.Tensor((32001, 128), dtype="float16"), R.Tensor((2048, 128), dtype="float16"), R.Tensor((2048, 128), dtype="float16"))) -> R.Tuple(R.Tensor((1, 1, 32001), dtype="float32"), R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object)): n = T.int64() R.func_attr({"tir_var_upper_bound": {"m": 2048, "n": 2048}}) cls = Module with R.dataflow(): lv1611 = R.call_tir(cls.reshape, (input_ids1,), out_sinfo=R.Tensor((1,), dtype="int32")) lv: R.Tensor((32001, 512), dtype="uint32") = params[0] lv1: R.Tensor((32001, 128), dtype="float16") = params[1] lv1_1 = R.call_tir(cls.fused_fused_decode1_take, (lv, lv1, lv1611), out_sinfo=R.Tensor((1, 4096), dtype="float16")) lv1613 = R.call_tir(cls.reshape1, (lv1_1,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv1614 = R.call_tir(cls.full, R.tuple(), out_sinfo=R.Tensor((1, 1, 1, n), dtype="float16")) lv460: R.Tensor((4096,), dtype="float16") = params[10] lv1615 = R.call_tir(cls.rms_norm1, (lv1613, lv460), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv3: R.Tensor((12288, 512), dtype="uint32") = params[2] lv4: R.Tensor((12288, 128), dtype="float16") = params[3] lv_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv3, lv4, lv1615), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv464: R.Tensor((2048, 128), dtype="float16") = params[325] lv465: R.Tensor((2048, 128), dtype="float16") = params[326] lv_2 = R.call_tir(cls.split_rotary, (lv_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv6: R.Tensor((1, 
1, 4096), dtype="float16") = lv_2[0] lv7 = R.call_tir(cls.fused_reshape2_transpose5, (lv6,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv8: R.Tensor((1, 1, 4096), dtype="float16") = lv_2[1] lv9 = R.call_tir(cls.fused_reshape2_squeeze, (lv8,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv10: R.Tensor((1, 1, 4096), dtype="float16") = lv_2[2] lv11 = R.call_tir(cls.fused_reshape2_squeeze, (lv10,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv1629: R.Object = kv_cache[0] lv1630: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1629, lv9, sinfo_args=(R.Object,)) lv1631: R.Object = kv_cache[1] lv1632: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1631, lv11, sinfo_args=(R.Object,)) lv1633: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1630, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv1634: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1632, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv1635 = R.call_tir(cls.reshape3, (lv1633,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1636 = R.call_tir(cls.reshape3, (lv1634,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1638 = R.call_tir(cls.transpose6, (lv1635,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1639 = R.call_tir(cls.transpose6, (lv1636,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv12 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv7, lv1638, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv13 = R.call_tir(cls.fused_softmax_cast1, (lv12,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv1648 = R.call_tir(cls.matmul9, (lv13, lv1639), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv14 = R.call_tir(cls.fused_transpose7_reshape4, (lv1648,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv15: R.Tensor((4096, 512), dtype="uint32") = params[4] lv16: R.Tensor((4096, 128), dtype="float16") = params[5] lv_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv15, lv16, lv14, lv1613), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv469: R.Tensor((4096,), dtype="float16") = params[11] lv1654 = R.call_tir(cls.rms_norm1, (lv_3, lv469), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv19: R.Tensor((22016, 512), dtype="uint32") = params[6] lv20: R.Tensor((22016, 128), dtype="float16") = params[7] lv1_2 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv19, lv20, lv1654), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv22 = R.call_tir(cls.fused_split_silu_multiply, (lv1_2,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv23: R.Tensor((4096, 1376), dtype="uint32") = params[8] lv24: R.Tensor((4096, 344), dtype="float16") = params[9] lv1_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv23, lv24, lv22, lv_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv476: R.Tensor((4096,), dtype="float16") = params[20] lv1665 = R.call_tir(cls.rms_norm1, (lv1_3, lv476), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv27: R.Tensor((12288, 512), dtype="uint32") = params[12] lv28: R.Tensor((12288, 128), dtype="float16") = params[13] lv2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv27, lv28, lv1665), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv4_1 = R.call_tir(cls.split_rotary, (lv2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), 
dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv30: R.Tensor((1, 1, 4096), dtype="float16") = lv4_1[0] lv31 = R.call_tir(cls.fused_reshape2_transpose5, (lv30,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv32: R.Tensor((1, 1, 4096), dtype="float16") = lv4_1[1] lv33 = R.call_tir(cls.fused_reshape2_squeeze, (lv32,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv34: R.Tensor((1, 1, 4096), dtype="float16") = lv4_1[2] lv35 = R.call_tir(cls.fused_reshape2_squeeze, (lv34,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv1679: R.Object = kv_cache[2] lv1680: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1679, lv33, sinfo_args=(R.Object,)) lv1681: R.Object = kv_cache[3] lv1682: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1681, lv35, sinfo_args=(R.Object,)) lv1683: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1680, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv1684: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1682, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv1685 = R.call_tir(cls.reshape3, (lv1683,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1686 = R.call_tir(cls.reshape3, (lv1684,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1688 = R.call_tir(cls.transpose6, (lv1685,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1689 = R.call_tir(cls.transpose6, (lv1686,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv36 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv31, lv1688, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv37 = R.call_tir(cls.fused_softmax_cast1, (lv36,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv1698 = R.call_tir(cls.matmul9, (lv37, lv1689), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv38 = R.call_tir(cls.fused_transpose7_reshape4, (lv1698,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv39: R.Tensor((4096, 512), dtype="uint32") = params[14] lv40: R.Tensor((4096, 128), dtype="float16") = params[15] lv2_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv39, lv40, lv38, lv1_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv483: R.Tensor((4096,), dtype="float16") = params[21] lv1704 = R.call_tir(cls.rms_norm1, (lv2_1, lv483), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv43: R.Tensor((22016, 512), dtype="uint32") = params[16] lv44: R.Tensor((22016, 128), dtype="float16") = params[17] lv3_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv43, lv44, lv1704), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv46 = R.call_tir(cls.fused_split_silu_multiply, (lv3_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv47: R.Tensor((4096, 1376), dtype="uint32") = params[18] lv48: R.Tensor((4096, 344), dtype="float16") = params[19] lv3_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv47, lv48, lv46, lv2_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv490: R.Tensor((4096,), dtype="float16") = params[30] lv1715 = R.call_tir(cls.rms_norm1, (lv3_2, lv490), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv51: R.Tensor((12288, 512), dtype="uint32") = params[22] lv52: R.Tensor((12288, 128), dtype="float16") = params[23] lv4_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv51, lv52, lv1715), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv8_1 = 
R.call_tir(cls.split_rotary, (lv4_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv54: R.Tensor((1, 1, 4096), dtype="float16") = lv8_1[0] lv55 = R.call_tir(cls.fused_reshape2_transpose5, (lv54,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv56: R.Tensor((1, 1, 4096), dtype="float16") = lv8_1[1] lv57 = R.call_tir(cls.fused_reshape2_squeeze, (lv56,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv58: R.Tensor((1, 1, 4096), dtype="float16") = lv8_1[2] lv59 = R.call_tir(cls.fused_reshape2_squeeze, (lv58,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv1729: R.Object = kv_cache[4] lv1730: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1729, lv57, sinfo_args=(R.Object,)) lv1731: R.Object = kv_cache[5] lv1732: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1731, lv59, sinfo_args=(R.Object,)) lv1733: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1730, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv1734: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1732, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv1735 = R.call_tir(cls.reshape3, (lv1733,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1736 = R.call_tir(cls.reshape3, (lv1734,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1738 = R.call_tir(cls.transpose6, (lv1735,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1739 = R.call_tir(cls.transpose6, (lv1736,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv60 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv55, lv1738, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv61 = R.call_tir(cls.fused_softmax_cast1, (lv60,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv1748 = R.call_tir(cls.matmul9, (lv61, lv1739), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv62 = R.call_tir(cls.fused_transpose7_reshape4, (lv1748,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv63: R.Tensor((4096, 512), dtype="uint32") = params[24] lv64: R.Tensor((4096, 128), dtype="float16") = params[25] lv4_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv63, lv64, lv62, lv3_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv497: R.Tensor((4096,), dtype="float16") = params[31] lv1754 = R.call_tir(cls.rms_norm1, (lv4_3, lv497), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv67: R.Tensor((22016, 512), dtype="uint32") = params[26] lv68: R.Tensor((22016, 128), dtype="float16") = params[27] lv5 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv67, lv68, lv1754), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv70 = R.call_tir(cls.fused_split_silu_multiply, (lv5,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv71: R.Tensor((4096, 1376), dtype="uint32") = params[28] lv72: R.Tensor((4096, 344), dtype="float16") = params[29] lv5_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv71, lv72, lv70, lv4_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv504: R.Tensor((4096,), dtype="float16") = params[40] lv1765 = R.call_tir(cls.rms_norm1, (lv5_1, lv504), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv75: R.Tensor((12288, 512), dtype="uint32") = params[32] lv76: R.Tensor((12288, 128), dtype="float16") = params[33] lv6_1 = 
R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv75, lv76, lv1765), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv12_1 = R.call_tir(cls.split_rotary, (lv6_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv78: R.Tensor((1, 1, 4096), dtype="float16") = lv12_1[0] lv79 = R.call_tir(cls.fused_reshape2_transpose5, (lv78,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv80: R.Tensor((1, 1, 4096), dtype="float16") = lv12_1[1] lv81 = R.call_tir(cls.fused_reshape2_squeeze, (lv80,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv82: R.Tensor((1, 1, 4096), dtype="float16") = lv12_1[2] lv83 = R.call_tir(cls.fused_reshape2_squeeze, (lv82,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv1779: R.Object = kv_cache[6] lv1780: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1779, lv81, sinfo_args=(R.Object,)) lv1781: R.Object = kv_cache[7] lv1782: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1781, lv83, sinfo_args=(R.Object,)) lv1783: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1780, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv1784: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1782, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv1785 = R.call_tir(cls.reshape3, (lv1783,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1786 = R.call_tir(cls.reshape3, (lv1784,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1788 = R.call_tir(cls.transpose6, (lv1785,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1789 = R.call_tir(cls.transpose6, (lv1786,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv84 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv79, lv1788, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv85 = R.call_tir(cls.fused_softmax_cast1, (lv84,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv1798 = R.call_tir(cls.matmul9, (lv85, lv1789), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv86 = R.call_tir(cls.fused_transpose7_reshape4, (lv1798,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv87: R.Tensor((4096, 512), dtype="uint32") = params[34] lv88: R.Tensor((4096, 128), dtype="float16") = params[35] lv6_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv87, lv88, lv86, lv5_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv511: R.Tensor((4096,), dtype="float16") = params[41] lv1804 = R.call_tir(cls.rms_norm1, (lv6_2, lv511), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv91: R.Tensor((22016, 512), dtype="uint32") = params[36] lv92: R.Tensor((22016, 128), dtype="float16") = params[37] lv7_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv91, lv92, lv1804), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv94 = R.call_tir(cls.fused_split_silu_multiply, (lv7_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv95: R.Tensor((4096, 1376), dtype="uint32") = params[38] lv96: R.Tensor((4096, 344), dtype="float16") = params[39] lv7_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv95, lv96, lv94, lv6_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv518: R.Tensor((4096,), dtype="float16") = params[50] lv1815 = R.call_tir(cls.rms_norm1, (lv7_2, lv518), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv99: 
R.Tensor((12288, 512), dtype="uint32") = params[42] lv100: R.Tensor((12288, 128), dtype="float16") = params[43] lv8_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv99, lv100, lv1815), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv16_1 = R.call_tir(cls.split_rotary, (lv8_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv102: R.Tensor((1, 1, 4096), dtype="float16") = lv16_1[0] lv103 = R.call_tir(cls.fused_reshape2_transpose5, (lv102,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv104: R.Tensor((1, 1, 4096), dtype="float16") = lv16_1[1] lv105 = R.call_tir(cls.fused_reshape2_squeeze, (lv104,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv106: R.Tensor((1, 1, 4096), dtype="float16") = lv16_1[2] lv107 = R.call_tir(cls.fused_reshape2_squeeze, (lv106,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv1829: R.Object = kv_cache[8] lv1830: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1829, lv105, sinfo_args=(R.Object,)) lv1831: R.Object = kv_cache[9] lv1832: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1831, lv107, sinfo_args=(R.Object,)) lv1833: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1830, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv1834: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1832, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv1835 = R.call_tir(cls.reshape3, (lv1833,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1836 = R.call_tir(cls.reshape3, (lv1834,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1838 = R.call_tir(cls.transpose6, (lv1835,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1839 = R.call_tir(cls.transpose6, (lv1836,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv108 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv103, lv1838, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv109 = R.call_tir(cls.fused_softmax_cast1, (lv108,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv1848 = R.call_tir(cls.matmul9, (lv109, lv1839), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv110 = R.call_tir(cls.fused_transpose7_reshape4, (lv1848,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv111: R.Tensor((4096, 512), dtype="uint32") = params[44] lv112: R.Tensor((4096, 128), dtype="float16") = params[45] lv8_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv111, lv112, lv110, lv7_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv525: R.Tensor((4096,), dtype="float16") = params[51] lv1854 = R.call_tir(cls.rms_norm1, (lv8_3, lv525), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv115: R.Tensor((22016, 512), dtype="uint32") = params[46] lv116: R.Tensor((22016, 128), dtype="float16") = params[47] lv9_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv115, lv116, lv1854), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv118 = R.call_tir(cls.fused_split_silu_multiply, (lv9_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv119: R.Tensor((4096, 1376), dtype="uint32") = params[48] lv120: R.Tensor((4096, 344), dtype="float16") = params[49] lv9_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv119, lv120, lv118, lv8_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv532: 
R.Tensor((4096,), dtype="float16") = params[60] lv1865 = R.call_tir(cls.rms_norm1, (lv9_2, lv532), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv123: R.Tensor((12288, 512), dtype="uint32") = params[52] lv124: R.Tensor((12288, 128), dtype="float16") = params[53] lv10_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv123, lv124, lv1865), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv20_1 = R.call_tir(cls.split_rotary, (lv10_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv126: R.Tensor((1, 1, 4096), dtype="float16") = lv20_1[0] lv127 = R.call_tir(cls.fused_reshape2_transpose5, (lv126,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv128: R.Tensor((1, 1, 4096), dtype="float16") = lv20_1[1] lv129 = R.call_tir(cls.fused_reshape2_squeeze, (lv128,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv130: R.Tensor((1, 1, 4096), dtype="float16") = lv20_1[2] lv131 = R.call_tir(cls.fused_reshape2_squeeze, (lv130,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv1879: R.Object = kv_cache[10] lv1880: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1879, lv129, sinfo_args=(R.Object,)) lv1881: R.Object = kv_cache[11] lv1882: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1881, lv131, sinfo_args=(R.Object,)) lv1883: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1880, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv1884: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1882, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv1885 = R.call_tir(cls.reshape3, (lv1883,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1886 = R.call_tir(cls.reshape3, (lv1884,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1888 = R.call_tir(cls.transpose6, (lv1885,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1889 = R.call_tir(cls.transpose6, (lv1886,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv132 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv127, lv1888, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv133 = R.call_tir(cls.fused_softmax_cast1, (lv132,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv1898 = R.call_tir(cls.matmul9, (lv133, lv1889), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv134 = R.call_tir(cls.fused_transpose7_reshape4, (lv1898,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv135: R.Tensor((4096, 512), dtype="uint32") = params[54] lv136: R.Tensor((4096, 128), dtype="float16") = params[55] lv10_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv135, lv136, lv134, lv9_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv539: R.Tensor((4096,), dtype="float16") = params[61] lv1904 = R.call_tir(cls.rms_norm1, (lv10_2, lv539), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv139: R.Tensor((22016, 512), dtype="uint32") = params[56] lv140: R.Tensor((22016, 128), dtype="float16") = params[57] lv11_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv139, lv140, lv1904), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv142 = R.call_tir(cls.fused_split_silu_multiply, (lv11_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv143: R.Tensor((4096, 1376), dtype="uint32") = params[58] lv144: R.Tensor((4096, 344), dtype="float16") = 
params[59] lv11_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv143, lv144, lv142, lv10_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv546: R.Tensor((4096,), dtype="float16") = params[70] lv1915 = R.call_tir(cls.rms_norm1, (lv11_2, lv546), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv147: R.Tensor((12288, 512), dtype="uint32") = params[62] lv148: R.Tensor((12288, 128), dtype="float16") = params[63] lv12_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv147, lv148, lv1915), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv24_1 = R.call_tir(cls.split_rotary, (lv12_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv150: R.Tensor((1, 1, 4096), dtype="float16") = lv24_1[0] lv151 = R.call_tir(cls.fused_reshape2_transpose5, (lv150,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv152: R.Tensor((1, 1, 4096), dtype="float16") = lv24_1[1] lv153 = R.call_tir(cls.fused_reshape2_squeeze, (lv152,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv154: R.Tensor((1, 1, 4096), dtype="float16") = lv24_1[2] lv155 = R.call_tir(cls.fused_reshape2_squeeze, (lv154,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv1929: R.Object = kv_cache[12] lv1930: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1929, lv153, sinfo_args=(R.Object,)) lv1931: R.Object = kv_cache[13] lv1932: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1931, lv155, sinfo_args=(R.Object,)) lv1933: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1930, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv1934: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1932, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv1935 = R.call_tir(cls.reshape3, (lv1933,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1936 = R.call_tir(cls.reshape3, (lv1934,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1938 = R.call_tir(cls.transpose6, (lv1935,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1939 = R.call_tir(cls.transpose6, (lv1936,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv156 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv151, lv1938, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv157 = R.call_tir(cls.fused_softmax_cast1, (lv156,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv1948 = R.call_tir(cls.matmul9, (lv157, lv1939), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv158 = R.call_tir(cls.fused_transpose7_reshape4, (lv1948,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv159: R.Tensor((4096, 512), dtype="uint32") = params[64] lv160: R.Tensor((4096, 128), dtype="float16") = params[65] lv12_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv159, lv160, lv158, lv11_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv553: R.Tensor((4096,), dtype="float16") = params[71] lv1954 = R.call_tir(cls.rms_norm1, (lv12_3, lv553), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv163: R.Tensor((22016, 512), dtype="uint32") = params[66] lv164: R.Tensor((22016, 128), dtype="float16") = params[67] lv13_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv163, lv164, lv1954), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv166 = R.call_tir(cls.fused_split_silu_multiply, 
(lv13_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv167: R.Tensor((4096, 1376), dtype="uint32") = params[68] lv168: R.Tensor((4096, 344), dtype="float16") = params[69] lv13_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv167, lv168, lv166, lv12_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv560: R.Tensor((4096,), dtype="float16") = params[80] lv1965 = R.call_tir(cls.rms_norm1, (lv13_2, lv560), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv171: R.Tensor((12288, 512), dtype="uint32") = params[72] lv172: R.Tensor((12288, 128), dtype="float16") = params[73] lv14_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv171, lv172, lv1965), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv28_1 = R.call_tir(cls.split_rotary, (lv14_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv174: R.Tensor((1, 1, 4096), dtype="float16") = lv28_1[0] lv175 = R.call_tir(cls.fused_reshape2_transpose5, (lv174,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv176: R.Tensor((1, 1, 4096), dtype="float16") = lv28_1[1] lv177 = R.call_tir(cls.fused_reshape2_squeeze, (lv176,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv178: R.Tensor((1, 1, 4096), dtype="float16") = lv28_1[2] lv179 = R.call_tir(cls.fused_reshape2_squeeze, (lv178,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv1979: R.Object = kv_cache[14] lv1980: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1979, lv177, sinfo_args=(R.Object,)) lv1981: R.Object = kv_cache[15] lv1982: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1981, lv179, sinfo_args=(R.Object,)) lv1983: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1980, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv1984: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1982, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv1985 = R.call_tir(cls.reshape3, (lv1983,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1986 = R.call_tir(cls.reshape3, (lv1984,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1988 = R.call_tir(cls.transpose6, (lv1985,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1989 = R.call_tir(cls.transpose6, (lv1986,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv180 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv175, lv1988, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv181 = R.call_tir(cls.fused_softmax_cast1, (lv180,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv1998 = R.call_tir(cls.matmul9, (lv181, lv1989), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv182 = R.call_tir(cls.fused_transpose7_reshape4, (lv1998,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv183: R.Tensor((4096, 512), dtype="uint32") = params[74] lv184: R.Tensor((4096, 128), dtype="float16") = params[75] lv14_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv183, lv184, lv182, lv13_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv567: R.Tensor((4096,), dtype="float16") = params[81] lv2004 = R.call_tir(cls.rms_norm1, (lv14_2, lv567), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv187: R.Tensor((22016, 512), dtype="uint32") = params[76] lv188: R.Tensor((22016, 128), dtype="float16") = params[77] lv15_1 = 
R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv187, lv188, lv2004), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv190 = R.call_tir(cls.fused_split_silu_multiply, (lv15_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv191: R.Tensor((4096, 1376), dtype="uint32") = params[78] lv192: R.Tensor((4096, 344), dtype="float16") = params[79] lv15_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv191, lv192, lv190, lv14_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv574: R.Tensor((4096,), dtype="float16") = params[90] lv2015 = R.call_tir(cls.rms_norm1, (lv15_2, lv574), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv195: R.Tensor((12288, 512), dtype="uint32") = params[82] lv196: R.Tensor((12288, 128), dtype="float16") = params[83] lv16_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv195, lv196, lv2015), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv32_1 = R.call_tir(cls.split_rotary, (lv16_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv198: R.Tensor((1, 1, 4096), dtype="float16") = lv32_1[0] lv199 = R.call_tir(cls.fused_reshape2_transpose5, (lv198,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv200: R.Tensor((1, 1, 4096), dtype="float16") = lv32_1[1] lv201 = R.call_tir(cls.fused_reshape2_squeeze, (lv200,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv202: R.Tensor((1, 1, 4096), dtype="float16") = lv32_1[2] lv203 = R.call_tir(cls.fused_reshape2_squeeze, (lv202,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv2029: R.Object = kv_cache[16] lv2030: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2029, lv201, sinfo_args=(R.Object,)) lv2031: R.Object = kv_cache[17] lv2032: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2031, lv203, sinfo_args=(R.Object,)) lv2033: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2030, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2034: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2032, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2035 = R.call_tir(cls.reshape3, (lv2033,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2036 = R.call_tir(cls.reshape3, (lv2034,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2038 = R.call_tir(cls.transpose6, (lv2035,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv2039 = R.call_tir(cls.transpose6, (lv2036,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv204 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv199, lv2038, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv205 = R.call_tir(cls.fused_softmax_cast1, (lv204,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv2048 = R.call_tir(cls.matmul9, (lv205, lv2039), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv206 = R.call_tir(cls.fused_transpose7_reshape4, (lv2048,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv207: R.Tensor((4096, 512), dtype="uint32") = params[84] lv208: R.Tensor((4096, 128), dtype="float16") = params[85] lv16_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv207, lv208, lv206, lv15_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv581: R.Tensor((4096,), dtype="float16") = params[91] lv2054 = R.call_tir(cls.rms_norm1, (lv16_3, lv581), 
out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv211: R.Tensor((22016, 512), dtype="uint32") = params[86] lv212: R.Tensor((22016, 128), dtype="float16") = params[87] lv17 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv211, lv212, lv2054), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv214 = R.call_tir(cls.fused_split_silu_multiply, (lv17,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv215: R.Tensor((4096, 1376), dtype="uint32") = params[88] lv216: R.Tensor((4096, 344), dtype="float16") = params[89] lv17_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv215, lv216, lv214, lv16_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv588: R.Tensor((4096,), dtype="float16") = params[100] lv2065 = R.call_tir(cls.rms_norm1, (lv17_1, lv588), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv219: R.Tensor((12288, 512), dtype="uint32") = params[92] lv220: R.Tensor((12288, 128), dtype="float16") = params[93] lv18 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv219, lv220, lv2065), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv36_1 = R.call_tir(cls.split_rotary, (lv18, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv222: R.Tensor((1, 1, 4096), dtype="float16") = lv36_1[0] lv223 = R.call_tir(cls.fused_reshape2_transpose5, (lv222,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv224: R.Tensor((1, 1, 4096), dtype="float16") = lv36_1[1] lv225 = R.call_tir(cls.fused_reshape2_squeeze, (lv224,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv226: R.Tensor((1, 1, 4096), dtype="float16") = lv36_1[2] lv227 = R.call_tir(cls.fused_reshape2_squeeze, (lv226,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv2079: R.Object = kv_cache[18] lv2080: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2079, lv225, sinfo_args=(R.Object,)) lv2081: R.Object = kv_cache[19] lv2082: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2081, lv227, sinfo_args=(R.Object,)) lv2083: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2080, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2084: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2082, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2085 = R.call_tir(cls.reshape3, (lv2083,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2086 = R.call_tir(cls.reshape3, (lv2084,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2088 = R.call_tir(cls.transpose6, (lv2085,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv2089 = R.call_tir(cls.transpose6, (lv2086,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv228 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv223, lv2088, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv229 = R.call_tir(cls.fused_softmax_cast1, (lv228,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv2098 = R.call_tir(cls.matmul9, (lv229, lv2089), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv230 = R.call_tir(cls.fused_transpose7_reshape4, (lv2098,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv231: R.Tensor((4096, 512), dtype="uint32") = params[94] lv232: R.Tensor((4096, 128), dtype="float16") = params[95] lv18_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv231, lv232, lv230, 
lv17_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv595: R.Tensor((4096,), dtype="float16") = params[101] lv2104 = R.call_tir(cls.rms_norm1, (lv18_1, lv595), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv235: R.Tensor((22016, 512), dtype="uint32") = params[96] lv236: R.Tensor((22016, 128), dtype="float16") = params[97] lv19_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv235, lv236, lv2104), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv238 = R.call_tir(cls.fused_split_silu_multiply, (lv19_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv239: R.Tensor((4096, 1376), dtype="uint32") = params[98] lv240: R.Tensor((4096, 344), dtype="float16") = params[99] lv19_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv239, lv240, lv238, lv18_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv602: R.Tensor((4096,), dtype="float16") = params[110] lv2115 = R.call_tir(cls.rms_norm1, (lv19_2, lv602), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv243: R.Tensor((12288, 512), dtype="uint32") = params[102] lv244: R.Tensor((12288, 128), dtype="float16") = params[103] lv20_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv243, lv244, lv2115), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv40_1 = R.call_tir(cls.split_rotary, (lv20_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv246: R.Tensor((1, 1, 4096), dtype="float16") = lv40_1[0] lv247 = R.call_tir(cls.fused_reshape2_transpose5, (lv246,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv248: R.Tensor((1, 1, 4096), dtype="float16") = lv40_1[1] lv249 = R.call_tir(cls.fused_reshape2_squeeze, (lv248,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv250: R.Tensor((1, 1, 4096), dtype="float16") = lv40_1[2] lv251 = R.call_tir(cls.fused_reshape2_squeeze, (lv250,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv2129: R.Object = kv_cache[20] lv2130: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2129, lv249, sinfo_args=(R.Object,)) lv2131: R.Object = kv_cache[21] lv2132: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2131, lv251, sinfo_args=(R.Object,)) lv2133: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2130, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2134: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2132, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2135 = R.call_tir(cls.reshape3, (lv2133,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2136 = R.call_tir(cls.reshape3, (lv2134,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2138 = R.call_tir(cls.transpose6, (lv2135,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv2139 = R.call_tir(cls.transpose6, (lv2136,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv252 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv247, lv2138, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv253 = R.call_tir(cls.fused_softmax_cast1, (lv252,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv2148 = R.call_tir(cls.matmul9, (lv253, lv2139), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv254 = R.call_tir(cls.fused_transpose7_reshape4, (lv2148,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv255: R.Tensor((4096, 512), 
dtype="uint32") = params[104] lv256: R.Tensor((4096, 128), dtype="float16") = params[105] lv20_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv255, lv256, lv254, lv19_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv609: R.Tensor((4096,), dtype="float16") = params[111] lv2154 = R.call_tir(cls.rms_norm1, (lv20_3, lv609), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv259: R.Tensor((22016, 512), dtype="uint32") = params[106] lv260: R.Tensor((22016, 128), dtype="float16") = params[107] lv21 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv259, lv260, lv2154), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv262 = R.call_tir(cls.fused_split_silu_multiply, (lv21,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv263: R.Tensor((4096, 1376), dtype="uint32") = params[108] lv264: R.Tensor((4096, 344), dtype="float16") = params[109] lv21_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv263, lv264, lv262, lv20_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv616: R.Tensor((4096,), dtype="float16") = params[120] lv2165 = R.call_tir(cls.rms_norm1, (lv21_1, lv616), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv267: R.Tensor((12288, 512), dtype="uint32") = params[112] lv268: R.Tensor((12288, 128), dtype="float16") = params[113] lv22_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv267, lv268, lv2165), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv44_1 = R.call_tir(cls.split_rotary, (lv22_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv270: R.Tensor((1, 1, 4096), dtype="float16") = lv44_1[0] lv271 = R.call_tir(cls.fused_reshape2_transpose5, (lv270,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv272: R.Tensor((1, 1, 4096), dtype="float16") = lv44_1[1] lv273 = R.call_tir(cls.fused_reshape2_squeeze, (lv272,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv274: R.Tensor((1, 1, 4096), dtype="float16") = lv44_1[2] lv275 = R.call_tir(cls.fused_reshape2_squeeze, (lv274,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv2179: R.Object = kv_cache[22] lv2180: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2179, lv273, sinfo_args=(R.Object,)) lv2181: R.Object = kv_cache[23] lv2182: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2181, lv275, sinfo_args=(R.Object,)) lv2183: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2180, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2184: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2182, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2185 = R.call_tir(cls.reshape3, (lv2183,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2186 = R.call_tir(cls.reshape3, (lv2184,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2188 = R.call_tir(cls.transpose6, (lv2185,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv2189 = R.call_tir(cls.transpose6, (lv2186,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv276 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv271, lv2188, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv277 = R.call_tir(cls.fused_softmax_cast1, (lv276,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv2198 = R.call_tir(cls.matmul9, (lv277, lv2189), 
out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv278 = R.call_tir(cls.fused_transpose7_reshape4, (lv2198,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv279: R.Tensor((4096, 512), dtype="uint32") = params[114] lv280: R.Tensor((4096, 128), dtype="float16") = params[115] lv22_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv279, lv280, lv278, lv21_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv623: R.Tensor((4096,), dtype="float16") = params[121] lv2204 = R.call_tir(cls.rms_norm1, (lv22_2, lv623), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv283: R.Tensor((22016, 512), dtype="uint32") = params[116] lv284: R.Tensor((22016, 128), dtype="float16") = params[117] lv23_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv283, lv284, lv2204), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv286 = R.call_tir(cls.fused_split_silu_multiply, (lv23_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv287: R.Tensor((4096, 1376), dtype="uint32") = params[118] lv288: R.Tensor((4096, 344), dtype="float16") = params[119] lv23_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv287, lv288, lv286, lv22_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv630: R.Tensor((4096,), dtype="float16") = params[130] lv2215 = R.call_tir(cls.rms_norm1, (lv23_2, lv630), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv291: R.Tensor((12288, 512), dtype="uint32") = params[122] lv292: R.Tensor((12288, 128), dtype="float16") = params[123] lv24_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv291, lv292, lv2215), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv48_1 = R.call_tir(cls.split_rotary, (lv24_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv294: R.Tensor((1, 1, 4096), dtype="float16") = lv48_1[0] lv295 = R.call_tir(cls.fused_reshape2_transpose5, (lv294,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv296: R.Tensor((1, 1, 4096), dtype="float16") = lv48_1[1] lv297 = R.call_tir(cls.fused_reshape2_squeeze, (lv296,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv298: R.Tensor((1, 1, 4096), dtype="float16") = lv48_1[2] lv299 = R.call_tir(cls.fused_reshape2_squeeze, (lv298,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv2229: R.Object = kv_cache[24] lv2230: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2229, lv297, sinfo_args=(R.Object,)) lv2231: R.Object = kv_cache[25] lv2232: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2231, lv299, sinfo_args=(R.Object,)) lv2233: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2230, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2234: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2232, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2235 = R.call_tir(cls.reshape3, (lv2233,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2236 = R.call_tir(cls.reshape3, (lv2234,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2238 = R.call_tir(cls.transpose6, (lv2235,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv2239 = R.call_tir(cls.transpose6, (lv2236,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv300 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv295, lv2238, lv1614), out_sinfo=R.Tensor((1, 32, 1, 
n), dtype="float32")) lv301 = R.call_tir(cls.fused_softmax_cast1, (lv300,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv2248 = R.call_tir(cls.matmul9, (lv301, lv2239), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv302 = R.call_tir(cls.fused_transpose7_reshape4, (lv2248,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv303: R.Tensor((4096, 512), dtype="uint32") = params[124] lv304: R.Tensor((4096, 128), dtype="float16") = params[125] lv24_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv303, lv304, lv302, lv23_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv637: R.Tensor((4096,), dtype="float16") = params[131] lv2254 = R.call_tir(cls.rms_norm1, (lv24_3, lv637), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv307: R.Tensor((22016, 512), dtype="uint32") = params[126] lv308: R.Tensor((22016, 128), dtype="float16") = params[127] lv25 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv307, lv308, lv2254), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv310 = R.call_tir(cls.fused_split_silu_multiply, (lv25,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv311: R.Tensor((4096, 1376), dtype="uint32") = params[128] lv312: R.Tensor((4096, 344), dtype="float16") = params[129] lv25_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv311, lv312, lv310, lv24_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv644: R.Tensor((4096,), dtype="float16") = params[140] lv2265 = R.call_tir(cls.rms_norm1, (lv25_1, lv644), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv315: R.Tensor((12288, 512), dtype="uint32") = params[132] lv316: R.Tensor((12288, 128), dtype="float16") = params[133] lv26 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv315, lv316, lv2265), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv52_1 = R.call_tir(cls.split_rotary, (lv26, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv318: R.Tensor((1, 1, 4096), dtype="float16") = lv52_1[0] lv319 = R.call_tir(cls.fused_reshape2_transpose5, (lv318,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv320: R.Tensor((1, 1, 4096), dtype="float16") = lv52_1[1] lv321 = R.call_tir(cls.fused_reshape2_squeeze, (lv320,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv322: R.Tensor((1, 1, 4096), dtype="float16") = lv52_1[2] lv323 = R.call_tir(cls.fused_reshape2_squeeze, (lv322,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv2279: R.Object = kv_cache[26] lv2280: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2279, lv321, sinfo_args=(R.Object,)) lv2281: R.Object = kv_cache[27] lv2282: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2281, lv323, sinfo_args=(R.Object,)) lv2283: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2280, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2284: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2282, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2285 = R.call_tir(cls.reshape3, (lv2283,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2286 = R.call_tir(cls.reshape3, (lv2284,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2288 = R.call_tir(cls.transpose6, (lv2285,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv2289 = R.call_tir(cls.transpose6, (lv2286,), 
out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv324 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv319, lv2288, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv325 = R.call_tir(cls.fused_softmax_cast1, (lv324,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv2298 = R.call_tir(cls.matmul9, (lv325, lv2289), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv326 = R.call_tir(cls.fused_transpose7_reshape4, (lv2298,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv327: R.Tensor((4096, 512), dtype="uint32") = params[134] lv328: R.Tensor((4096, 128), dtype="float16") = params[135] lv26_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv327, lv328, lv326, lv25_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv651: R.Tensor((4096,), dtype="float16") = params[141] lv2304 = R.call_tir(cls.rms_norm1, (lv26_1, lv651), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv331: R.Tensor((22016, 512), dtype="uint32") = params[136] lv332: R.Tensor((22016, 128), dtype="float16") = params[137] lv27_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv331, lv332, lv2304), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv334 = R.call_tir(cls.fused_split_silu_multiply, (lv27_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv335: R.Tensor((4096, 1376), dtype="uint32") = params[138] lv336: R.Tensor((4096, 344), dtype="float16") = params[139] lv27_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv335, lv336, lv334, lv26_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv658: R.Tensor((4096,), dtype="float16") = params[150] lv2315 = R.call_tir(cls.rms_norm1, (lv27_2, lv658), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv339: R.Tensor((12288, 512), dtype="uint32") = params[142] lv340: R.Tensor((12288, 128), dtype="float16") = params[143] lv28_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv339, lv340, lv2315), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv56_1 = R.call_tir(cls.split_rotary, (lv28_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv342: R.Tensor((1, 1, 4096), dtype="float16") = lv56_1[0] lv343 = R.call_tir(cls.fused_reshape2_transpose5, (lv342,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv344: R.Tensor((1, 1, 4096), dtype="float16") = lv56_1[1] lv345 = R.call_tir(cls.fused_reshape2_squeeze, (lv344,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv346: R.Tensor((1, 1, 4096), dtype="float16") = lv56_1[2] lv347 = R.call_tir(cls.fused_reshape2_squeeze, (lv346,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv2329: R.Object = kv_cache[28] lv2330: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2329, lv345, sinfo_args=(R.Object,)) lv2331: R.Object = kv_cache[29] lv2332: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2331, lv347, sinfo_args=(R.Object,)) lv2333: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2330, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2334: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2332, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2335 = R.call_tir(cls.reshape3, (lv2333,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2336 = R.call_tir(cls.reshape3, (lv2334,), out_sinfo=R.Tensor((1, 
n, 32, 128), dtype="float16")) lv2338 = R.call_tir(cls.transpose6, (lv2335,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv2339 = R.call_tir(cls.transpose6, (lv2336,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv348 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv343, lv2338, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv349 = R.call_tir(cls.fused_softmax_cast1, (lv348,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv2348 = R.call_tir(cls.matmul9, (lv349, lv2339), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv350 = R.call_tir(cls.fused_transpose7_reshape4, (lv2348,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv351: R.Tensor((4096, 512), dtype="uint32") = params[144] lv352: R.Tensor((4096, 128), dtype="float16") = params[145] lv28_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv351, lv352, lv350, lv27_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv665: R.Tensor((4096,), dtype="float16") = params[151] lv2354 = R.call_tir(cls.rms_norm1, (lv28_3, lv665), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv355: R.Tensor((22016, 512), dtype="uint32") = params[146] lv356: R.Tensor((22016, 128), dtype="float16") = params[147] lv29 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv355, lv356, lv2354), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv358 = R.call_tir(cls.fused_split_silu_multiply, (lv29,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv359: R.Tensor((4096, 1376), dtype="uint32") = params[148] lv360: R.Tensor((4096, 344), dtype="float16") = params[149] lv29_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv359, lv360, lv358, lv28_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv672: R.Tensor((4096,), dtype="float16") = params[160] lv2365 = R.call_tir(cls.rms_norm1, (lv29_1, lv672), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv363: R.Tensor((12288, 512), dtype="uint32") = params[152] lv364: R.Tensor((12288, 128), dtype="float16") = params[153] lv30_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv363, lv364, lv2365), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv60_1 = R.call_tir(cls.split_rotary, (lv30_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv366: R.Tensor((1, 1, 4096), dtype="float16") = lv60_1[0] lv367 = R.call_tir(cls.fused_reshape2_transpose5, (lv366,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv368: R.Tensor((1, 1, 4096), dtype="float16") = lv60_1[1] lv369 = R.call_tir(cls.fused_reshape2_squeeze, (lv368,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv370: R.Tensor((1, 1, 4096), dtype="float16") = lv60_1[2] lv371 = R.call_tir(cls.fused_reshape2_squeeze, (lv370,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv2379: R.Object = kv_cache[30] lv2380: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2379, lv369, sinfo_args=(R.Object,)) lv2381: R.Object = kv_cache[31] lv2382: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2381, lv371, sinfo_args=(R.Object,)) lv2383: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2380, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2384: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2382, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), 
dtype="float16"),)) lv2385 = R.call_tir(cls.reshape3, (lv2383,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2386 = R.call_tir(cls.reshape3, (lv2384,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2388 = R.call_tir(cls.transpose6, (lv2385,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv2389 = R.call_tir(cls.transpose6, (lv2386,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv372 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv367, lv2388, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv373 = R.call_tir(cls.fused_softmax_cast1, (lv372,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv2398 = R.call_tir(cls.matmul9, (lv373, lv2389), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv374 = R.call_tir(cls.fused_transpose7_reshape4, (lv2398,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv375: R.Tensor((4096, 512), dtype="uint32") = params[154] lv376: R.Tensor((4096, 128), dtype="float16") = params[155] lv30_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv375, lv376, lv374, lv29_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv679: R.Tensor((4096,), dtype="float16") = params[161] lv2404 = R.call_tir(cls.rms_norm1, (lv30_2, lv679), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv379: R.Tensor((22016, 512), dtype="uint32") = params[156] lv380: R.Tensor((22016, 128), dtype="float16") = params[157] lv31_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv379, lv380, lv2404), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv382 = R.call_tir(cls.fused_split_silu_multiply, (lv31_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv383: R.Tensor((4096, 1376), dtype="uint32") = params[158] lv384: R.Tensor((4096, 344), dtype="float16") = params[159] lv31_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv383, lv384, lv382, lv30_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv686: R.Tensor((4096,), dtype="float16") = params[170] lv2415 = R.call_tir(cls.rms_norm1, (lv31_2, lv686), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv387: R.Tensor((12288, 512), dtype="uint32") = params[162] lv388: R.Tensor((12288, 128), dtype="float16") = params[163] lv32_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv387, lv388, lv2415), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv64_1 = R.call_tir(cls.split_rotary, (lv32_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv390: R.Tensor((1, 1, 4096), dtype="float16") = lv64_1[0] lv391 = R.call_tir(cls.fused_reshape2_transpose5, (lv390,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv392: R.Tensor((1, 1, 4096), dtype="float16") = lv64_1[1] lv393 = R.call_tir(cls.fused_reshape2_squeeze, (lv392,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv394: R.Tensor((1, 1, 4096), dtype="float16") = lv64_1[2] lv395 = R.call_tir(cls.fused_reshape2_squeeze, (lv394,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv2429: R.Object = kv_cache[32] lv2430: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2429, lv393, sinfo_args=(R.Object,)) lv2431: R.Object = kv_cache[33] lv2432: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2431, lv395, sinfo_args=(R.Object,)) lv2433: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2430, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), 
dtype="float16"),)) lv2434: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2432, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2435 = R.call_tir(cls.reshape3, (lv2433,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2436 = R.call_tir(cls.reshape3, (lv2434,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2438 = R.call_tir(cls.transpose6, (lv2435,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv2439 = R.call_tir(cls.transpose6, (lv2436,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv396 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv391, lv2438, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv397 = R.call_tir(cls.fused_softmax_cast1, (lv396,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv2448 = R.call_tir(cls.matmul9, (lv397, lv2439), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv398 = R.call_tir(cls.fused_transpose7_reshape4, (lv2448,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv399: R.Tensor((4096, 512), dtype="uint32") = params[164] lv400: R.Tensor((4096, 128), dtype="float16") = params[165] lv32_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv399, lv400, lv398, lv31_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv693: R.Tensor((4096,), dtype="float16") = params[171] lv2454 = R.call_tir(cls.rms_norm1, (lv32_3, lv693), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv403: R.Tensor((22016, 512), dtype="uint32") = params[166] lv404: R.Tensor((22016, 128), dtype="float16") = params[167] lv33_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv403, lv404, lv2454), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv406 = R.call_tir(cls.fused_split_silu_multiply, (lv33_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv407: R.Tensor((4096, 1376), dtype="uint32") = params[168] lv408: R.Tensor((4096, 344), dtype="float16") = params[169] lv33_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv407, lv408, lv406, lv32_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv700: R.Tensor((4096,), dtype="float16") = params[180] lv2465 = R.call_tir(cls.rms_norm1, (lv33_2, lv700), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv411: R.Tensor((12288, 512), dtype="uint32") = params[172] lv412: R.Tensor((12288, 128), dtype="float16") = params[173] lv34_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv411, lv412, lv2465), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv68_1 = R.call_tir(cls.split_rotary, (lv34_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv414: R.Tensor((1, 1, 4096), dtype="float16") = lv68_1[0] lv415 = R.call_tir(cls.fused_reshape2_transpose5, (lv414,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv416: R.Tensor((1, 1, 4096), dtype="float16") = lv68_1[1] lv417 = R.call_tir(cls.fused_reshape2_squeeze, (lv416,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv418: R.Tensor((1, 1, 4096), dtype="float16") = lv68_1[2] lv419 = R.call_tir(cls.fused_reshape2_squeeze, (lv418,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv2479: R.Object = kv_cache[34] lv2480: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2479, lv417, sinfo_args=(R.Object,)) lv2481: R.Object = kv_cache[35] lv2482: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2481, lv419, 
sinfo_args=(R.Object,)) lv2483: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2480, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2484: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2482, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2485 = R.call_tir(cls.reshape3, (lv2483,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2486 = R.call_tir(cls.reshape3, (lv2484,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2488 = R.call_tir(cls.transpose6, (lv2485,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv2489 = R.call_tir(cls.transpose6, (lv2486,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv420 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv415, lv2488, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv421 = R.call_tir(cls.fused_softmax_cast1, (lv420,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv2498 = R.call_tir(cls.matmul9, (lv421, lv2489), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv422 = R.call_tir(cls.fused_transpose7_reshape4, (lv2498,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv423: R.Tensor((4096, 512), dtype="uint32") = params[174] lv424: R.Tensor((4096, 128), dtype="float16") = params[175] lv34_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv423, lv424, lv422, lv33_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv707: R.Tensor((4096,), dtype="float16") = params[181] lv2504 = R.call_tir(cls.rms_norm1, (lv34_2, lv707), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv427: R.Tensor((22016, 512), dtype="uint32") = params[176] lv428: R.Tensor((22016, 128), dtype="float16") = params[177] lv35_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv427, lv428, lv2504), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv430 = R.call_tir(cls.fused_split_silu_multiply, (lv35_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv431: R.Tensor((4096, 1376), dtype="uint32") = params[178] lv432: R.Tensor((4096, 344), dtype="float16") = params[179] lv35_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv431, lv432, lv430, lv34_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv714: R.Tensor((4096,), dtype="float16") = params[190] lv2515 = R.call_tir(cls.rms_norm1, (lv35_2, lv714), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv435: R.Tensor((12288, 512), dtype="uint32") = params[182] lv436: R.Tensor((12288, 128), dtype="float16") = params[183] lv36_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv435, lv436, lv2515), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv72_1 = R.call_tir(cls.split_rotary, (lv36_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv438: R.Tensor((1, 1, 4096), dtype="float16") = lv72_1[0] lv439 = R.call_tir(cls.fused_reshape2_transpose5, (lv438,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv440: R.Tensor((1, 1, 4096), dtype="float16") = lv72_1[1] lv441 = R.call_tir(cls.fused_reshape2_squeeze, (lv440,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv442: R.Tensor((1, 1, 4096), dtype="float16") = lv72_1[2] lv443 = R.call_tir(cls.fused_reshape2_squeeze, (lv442,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv2529: R.Object = kv_cache[36] lv2530: R.Object = 
R.call_packed("vm.builtin.attention_kv_cache_append", lv2529, lv441, sinfo_args=(R.Object,)) lv2531: R.Object = kv_cache[37] lv2532: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2531, lv443, sinfo_args=(R.Object,)) lv2533: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2530, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2534: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2532, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2535 = R.call_tir(cls.reshape3, (lv2533,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2536 = R.call_tir(cls.reshape3, (lv2534,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2538 = R.call_tir(cls.transpose6, (lv2535,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv2539 = R.call_tir(cls.transpose6, (lv2536,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv444 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv439, lv2538, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv445 = R.call_tir(cls.fused_softmax_cast1, (lv444,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv2548 = R.call_tir(cls.matmul9, (lv445, lv2539), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv446 = R.call_tir(cls.fused_transpose7_reshape4, (lv2548,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv447: R.Tensor((4096, 512), dtype="uint32") = params[184] lv448: R.Tensor((4096, 128), dtype="float16") = params[185] lv36_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv447, lv448, lv446, lv35_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv721: R.Tensor((4096,), dtype="float16") = params[191] lv2554 = R.call_tir(cls.rms_norm1, (lv36_3, lv721), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv451: R.Tensor((22016, 512), dtype="uint32") = params[186] lv452: R.Tensor((22016, 128), dtype="float16") = params[187] lv37_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv451, lv452, lv2554), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv454 = R.call_tir(cls.fused_split_silu_multiply, (lv37_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv455: R.Tensor((4096, 1376), dtype="uint32") = params[188] lv456: R.Tensor((4096, 344), dtype="float16") = params[189] lv37_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv455, lv456, lv454, lv36_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv728: R.Tensor((4096,), dtype="float16") = params[200] lv2565 = R.call_tir(cls.rms_norm1, (lv37_2, lv728), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv459: R.Tensor((12288, 512), dtype="uint32") = params[192] lv460_1: R.Tensor((12288, 128), dtype="float16") = params[193] lv38_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv459, lv460_1, lv2565), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv76_1 = R.call_tir(cls.split_rotary, (lv38_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv462: R.Tensor((1, 1, 4096), dtype="float16") = lv76_1[0] lv463 = R.call_tir(cls.fused_reshape2_transpose5, (lv462,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv464_1: R.Tensor((1, 1, 4096), dtype="float16") = lv76_1[1] lv465_1 = R.call_tir(cls.fused_reshape2_squeeze, (lv464_1,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv466: R.Tensor((1, 1, 4096), 
dtype="float16") = lv76_1[2] lv467 = R.call_tir(cls.fused_reshape2_squeeze, (lv466,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv2579: R.Object = kv_cache[38] lv2580: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2579, lv465_1, sinfo_args=(R.Object,)) lv2581: R.Object = kv_cache[39] lv2582: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2581, lv467, sinfo_args=(R.Object,)) lv2583: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2580, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2584: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2582, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2585 = R.call_tir(cls.reshape3, (lv2583,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2586 = R.call_tir(cls.reshape3, (lv2584,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2588 = R.call_tir(cls.transpose6, (lv2585,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv2589 = R.call_tir(cls.transpose6, (lv2586,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv468 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv463, lv2588, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv469_1 = R.call_tir(cls.fused_softmax_cast1, (lv468,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv2598 = R.call_tir(cls.matmul9, (lv469_1, lv2589), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv470 = R.call_tir(cls.fused_transpose7_reshape4, (lv2598,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv471: R.Tensor((4096, 512), dtype="uint32") = params[194] lv472: R.Tensor((4096, 128), dtype="float16") = params[195] lv38_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv471, lv472, lv470, lv37_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv735: R.Tensor((4096,), dtype="float16") = params[201] lv2604 = R.call_tir(cls.rms_norm1, (lv38_2, lv735), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv475: R.Tensor((22016, 512), dtype="uint32") = params[196] lv476_1: R.Tensor((22016, 128), dtype="float16") = params[197] lv39_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv475, lv476_1, lv2604), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv478 = R.call_tir(cls.fused_split_silu_multiply, (lv39_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv479: R.Tensor((4096, 1376), dtype="uint32") = params[198] lv480: R.Tensor((4096, 344), dtype="float16") = params[199] lv39_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv479, lv480, lv478, lv38_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv742: R.Tensor((4096,), dtype="float16") = params[210] lv2615 = R.call_tir(cls.rms_norm1, (lv39_2, lv742), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv483_1: R.Tensor((12288, 512), dtype="uint32") = params[202] lv484: R.Tensor((12288, 128), dtype="float16") = params[203] lv40_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv483_1, lv484, lv2615), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv80_1 = R.call_tir(cls.split_rotary, (lv40_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv486: R.Tensor((1, 1, 4096), dtype="float16") = lv80_1[0] lv487 = R.call_tir(cls.fused_reshape2_transpose5, (lv486,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) 
lv488: R.Tensor((1, 1, 4096), dtype="float16") = lv80_1[1] lv489 = R.call_tir(cls.fused_reshape2_squeeze, (lv488,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv490_1: R.Tensor((1, 1, 4096), dtype="float16") = lv80_1[2] lv491 = R.call_tir(cls.fused_reshape2_squeeze, (lv490_1,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv2629: R.Object = kv_cache[40] lv2630: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2629, lv489, sinfo_args=(R.Object,)) lv2631: R.Object = kv_cache[41] lv2632: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2631, lv491, sinfo_args=(R.Object,)) lv2633: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2630, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2634: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2632, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2635 = R.call_tir(cls.reshape3, (lv2633,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2636 = R.call_tir(cls.reshape3, (lv2634,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2638 = R.call_tir(cls.transpose6, (lv2635,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv2639 = R.call_tir(cls.transpose6, (lv2636,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv492 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv487, lv2638, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv493 = R.call_tir(cls.fused_softmax_cast1, (lv492,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv2648 = R.call_tir(cls.matmul9, (lv493, lv2639), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv494 = R.call_tir(cls.fused_transpose7_reshape4, (lv2648,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv495: R.Tensor((4096, 512), dtype="uint32") = params[204] lv496: R.Tensor((4096, 128), dtype="float16") = params[205] lv40_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv495, lv496, lv494, lv39_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv749: R.Tensor((4096,), dtype="float16") = params[211] lv2654 = R.call_tir(cls.rms_norm1, (lv40_3, lv749), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv499: R.Tensor((22016, 512), dtype="uint32") = params[206] lv500: R.Tensor((22016, 128), dtype="float16") = params[207] lv41 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv499, lv500, lv2654), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv502 = R.call_tir(cls.fused_split_silu_multiply, (lv41,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv503: R.Tensor((4096, 1376), dtype="uint32") = params[208] lv504_1: R.Tensor((4096, 344), dtype="float16") = params[209] lv41_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv503, lv504_1, lv502, lv40_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv756: R.Tensor((4096,), dtype="float16") = params[220] lv2665 = R.call_tir(cls.rms_norm1, (lv41_1, lv756), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv507: R.Tensor((12288, 512), dtype="uint32") = params[212] lv508: R.Tensor((12288, 128), dtype="float16") = params[213] lv42 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv507, lv508, lv2665), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv84_1 = R.call_tir(cls.split_rotary, (lv42, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], 
tir_vars=R.shape([n])) lv510: R.Tensor((1, 1, 4096), dtype="float16") = lv84_1[0] lv511_1 = R.call_tir(cls.fused_reshape2_transpose5, (lv510,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv512: R.Tensor((1, 1, 4096), dtype="float16") = lv84_1[1] lv513 = R.call_tir(cls.fused_reshape2_squeeze, (lv512,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv514: R.Tensor((1, 1, 4096), dtype="float16") = lv84_1[2] lv515 = R.call_tir(cls.fused_reshape2_squeeze, (lv514,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv2679: R.Object = kv_cache[42] lv2680: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2679, lv513, sinfo_args=(R.Object,)) lv2681: R.Object = kv_cache[43] lv2682: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2681, lv515, sinfo_args=(R.Object,)) lv2683: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2680, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2684: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2682, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2685 = R.call_tir(cls.reshape3, (lv2683,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2686 = R.call_tir(cls.reshape3, (lv2684,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2688 = R.call_tir(cls.transpose6, (lv2685,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv2689 = R.call_tir(cls.transpose6, (lv2686,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv516 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv511_1, lv2688, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv517 = R.call_tir(cls.fused_softmax_cast1, (lv516,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv2698 = R.call_tir(cls.matmul9, (lv517, lv2689), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv518_1 = R.call_tir(cls.fused_transpose7_reshape4, (lv2698,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv519: R.Tensor((4096, 512), dtype="uint32") = params[214] lv520: R.Tensor((4096, 128), dtype="float16") = params[215] lv42_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv519, lv520, lv518_1, lv41_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv763: R.Tensor((4096,), dtype="float16") = params[221] lv2704 = R.call_tir(cls.rms_norm1, (lv42_1, lv763), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv523: R.Tensor((22016, 512), dtype="uint32") = params[216] lv524: R.Tensor((22016, 128), dtype="float16") = params[217] lv43_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv523, lv524, lv2704), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv526 = R.call_tir(cls.fused_split_silu_multiply, (lv43_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv527: R.Tensor((4096, 1376), dtype="uint32") = params[218] lv528: R.Tensor((4096, 344), dtype="float16") = params[219] lv43_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv527, lv528, lv526, lv42_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv770: R.Tensor((4096,), dtype="float16") = params[230] lv2715 = R.call_tir(cls.rms_norm1, (lv43_2, lv770), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv531: R.Tensor((12288, 512), dtype="uint32") = params[222] lv532_1: R.Tensor((12288, 128), dtype="float16") = params[223] lv44_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv531, lv532_1, lv2715), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv88_1 
= R.call_tir(cls.split_rotary, (lv44_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv534: R.Tensor((1, 1, 4096), dtype="float16") = lv88_1[0] lv535 = R.call_tir(cls.fused_reshape2_transpose5, (lv534,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv536: R.Tensor((1, 1, 4096), dtype="float16") = lv88_1[1] lv537 = R.call_tir(cls.fused_reshape2_squeeze, (lv536,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv538: R.Tensor((1, 1, 4096), dtype="float16") = lv88_1[2] lv539_1 = R.call_tir(cls.fused_reshape2_squeeze, (lv538,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv2729: R.Object = kv_cache[44] lv2730: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2729, lv537, sinfo_args=(R.Object,)) lv2731: R.Object = kv_cache[45] lv2732: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2731, lv539_1, sinfo_args=(R.Object,)) lv2733: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2730, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2734: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2732, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2735 = R.call_tir(cls.reshape3, (lv2733,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2736 = R.call_tir(cls.reshape3, (lv2734,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2738 = R.call_tir(cls.transpose6, (lv2735,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv2739 = R.call_tir(cls.transpose6, (lv2736,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv540 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv535, lv2738, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv541 = R.call_tir(cls.fused_softmax_cast1, (lv540,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv2748 = R.call_tir(cls.matmul9, (lv541, lv2739), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv542 = R.call_tir(cls.fused_transpose7_reshape4, (lv2748,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv543: R.Tensor((4096, 512), dtype="uint32") = params[224] lv544: R.Tensor((4096, 128), dtype="float16") = params[225] lv44_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv543, lv544, lv542, lv43_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv777: R.Tensor((4096,), dtype="float16") = params[231] lv2754 = R.call_tir(cls.rms_norm1, (lv44_3, lv777), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv547: R.Tensor((22016, 512), dtype="uint32") = params[226] lv548: R.Tensor((22016, 128), dtype="float16") = params[227] lv45 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv547, lv548, lv2754), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv550 = R.call_tir(cls.fused_split_silu_multiply, (lv45,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv551: R.Tensor((4096, 1376), dtype="uint32") = params[228] lv552: R.Tensor((4096, 344), dtype="float16") = params[229] lv45_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv551, lv552, lv550, lv44_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv784: R.Tensor((4096,), dtype="float16") = params[240] lv2765 = R.call_tir(cls.rms_norm1, (lv45_1, lv784), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv555: R.Tensor((12288, 512), dtype="uint32") = params[232] lv556: 
R.Tensor((12288, 128), dtype="float16") = params[233] lv46_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv555, lv556, lv2765), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv92_1 = R.call_tir(cls.split_rotary, (lv46_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv558: R.Tensor((1, 1, 4096), dtype="float16") = lv92_1[0] lv559 = R.call_tir(cls.fused_reshape2_transpose5, (lv558,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv560_1: R.Tensor((1, 1, 4096), dtype="float16") = lv92_1[1] lv561 = R.call_tir(cls.fused_reshape2_squeeze, (lv560_1,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv562: R.Tensor((1, 1, 4096), dtype="float16") = lv92_1[2] lv563 = R.call_tir(cls.fused_reshape2_squeeze, (lv562,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv2779: R.Object = kv_cache[46] lv2780: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2779, lv561, sinfo_args=(R.Object,)) lv2781: R.Object = kv_cache[47] lv2782: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2781, lv563, sinfo_args=(R.Object,)) lv2783: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2780, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2784: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2782, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2785 = R.call_tir(cls.reshape3, (lv2783,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2786 = R.call_tir(cls.reshape3, (lv2784,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2788 = R.call_tir(cls.transpose6, (lv2785,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv2789 = R.call_tir(cls.transpose6, (lv2786,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv564 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv559, lv2788, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv565 = R.call_tir(cls.fused_softmax_cast1, (lv564,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv2798 = R.call_tir(cls.matmul9, (lv565, lv2789), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv566 = R.call_tir(cls.fused_transpose7_reshape4, (lv2798,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv567_1: R.Tensor((4096, 512), dtype="uint32") = params[234] lv568: R.Tensor((4096, 128), dtype="float16") = params[235] lv46_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv567_1, lv568, lv566, lv45_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv791: R.Tensor((4096,), dtype="float16") = params[241] lv2804 = R.call_tir(cls.rms_norm1, (lv46_2, lv791), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv571: R.Tensor((22016, 512), dtype="uint32") = params[236] lv572: R.Tensor((22016, 128), dtype="float16") = params[237] lv47_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv571, lv572, lv2804), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv574_1 = R.call_tir(cls.fused_split_silu_multiply, (lv47_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv575: R.Tensor((4096, 1376), dtype="uint32") = params[238] lv576: R.Tensor((4096, 344), dtype="float16") = params[239] lv47_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv575, lv576, lv574_1, lv46_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv798: R.Tensor((4096,), 
dtype="float16") = params[250] lv2815 = R.call_tir(cls.rms_norm1, (lv47_2, lv798), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv579: R.Tensor((12288, 512), dtype="uint32") = params[242] lv580: R.Tensor((12288, 128), dtype="float16") = params[243] lv48_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv579, lv580, lv2815), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv96_1 = R.call_tir(cls.split_rotary, (lv48_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv582: R.Tensor((1, 1, 4096), dtype="float16") = lv96_1[0] lv583 = R.call_tir(cls.fused_reshape2_transpose5, (lv582,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv584: R.Tensor((1, 1, 4096), dtype="float16") = lv96_1[1] lv585 = R.call_tir(cls.fused_reshape2_squeeze, (lv584,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv586: R.Tensor((1, 1, 4096), dtype="float16") = lv96_1[2] lv587 = R.call_tir(cls.fused_reshape2_squeeze, (lv586,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv2829: R.Object = kv_cache[48] lv2830: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2829, lv585, sinfo_args=(R.Object,)) lv2831: R.Object = kv_cache[49] lv2832: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2831, lv587, sinfo_args=(R.Object,)) lv2833: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2830, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2834: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2832, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2835 = R.call_tir(cls.reshape3, (lv2833,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2836 = R.call_tir(cls.reshape3, (lv2834,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2838 = R.call_tir(cls.transpose6, (lv2835,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv2839 = R.call_tir(cls.transpose6, (lv2836,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv588_1 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv583, lv2838, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv589 = R.call_tir(cls.fused_softmax_cast1, (lv588_1,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv2848 = R.call_tir(cls.matmul9, (lv589, lv2839), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv590 = R.call_tir(cls.fused_transpose7_reshape4, (lv2848,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv591: R.Tensor((4096, 512), dtype="uint32") = params[244] lv592: R.Tensor((4096, 128), dtype="float16") = params[245] lv48_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv591, lv592, lv590, lv47_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv805: R.Tensor((4096,), dtype="float16") = params[251] lv2854 = R.call_tir(cls.rms_norm1, (lv48_3, lv805), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv595_1: R.Tensor((22016, 512), dtype="uint32") = params[246] lv596: R.Tensor((22016, 128), dtype="float16") = params[247] lv49 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv595_1, lv596, lv2854), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv598 = R.call_tir(cls.fused_split_silu_multiply, (lv49,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv599: R.Tensor((4096, 1376), dtype="uint32") = params[248] lv600: R.Tensor((4096, 344), dtype="float16") = 
params[249] lv49_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv599, lv600, lv598, lv48_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv812: R.Tensor((4096,), dtype="float16") = params[260] lv2865 = R.call_tir(cls.rms_norm1, (lv49_1, lv812), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv603: R.Tensor((12288, 512), dtype="uint32") = params[252] lv604: R.Tensor((12288, 128), dtype="float16") = params[253] lv50 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv603, lv604, lv2865), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv100_1 = R.call_tir(cls.split_rotary, (lv50, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv606: R.Tensor((1, 1, 4096), dtype="float16") = lv100_1[0] lv607 = R.call_tir(cls.fused_reshape2_transpose5, (lv606,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv608: R.Tensor((1, 1, 4096), dtype="float16") = lv100_1[1] lv609_1 = R.call_tir(cls.fused_reshape2_squeeze, (lv608,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv610: R.Tensor((1, 1, 4096), dtype="float16") = lv100_1[2] lv611 = R.call_tir(cls.fused_reshape2_squeeze, (lv610,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv2879: R.Object = kv_cache[50] lv2880: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2879, lv609_1, sinfo_args=(R.Object,)) lv2881: R.Object = kv_cache[51] lv2882: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2881, lv611, sinfo_args=(R.Object,)) lv2883: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2880, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2884: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2882, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2885 = R.call_tir(cls.reshape3, (lv2883,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2886 = R.call_tir(cls.reshape3, (lv2884,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2888 = R.call_tir(cls.transpose6, (lv2885,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv2889 = R.call_tir(cls.transpose6, (lv2886,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv612 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv607, lv2888, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv613 = R.call_tir(cls.fused_softmax_cast1, (lv612,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv2898 = R.call_tir(cls.matmul9, (lv613, lv2889), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv614 = R.call_tir(cls.fused_transpose7_reshape4, (lv2898,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv615: R.Tensor((4096, 512), dtype="uint32") = params[254] lv616_1: R.Tensor((4096, 128), dtype="float16") = params[255] lv50_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv615, lv616_1, lv614, lv49_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv819: R.Tensor((4096,), dtype="float16") = params[261] lv2904 = R.call_tir(cls.rms_norm1, (lv50_1, lv819), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv619: R.Tensor((22016, 512), dtype="uint32") = params[256] lv620: R.Tensor((22016, 128), dtype="float16") = params[257] lv51_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv619, lv620, lv2904), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv622 = 
R.call_tir(cls.fused_split_silu_multiply, (lv51_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv623_1: R.Tensor((4096, 1376), dtype="uint32") = params[258] lv624: R.Tensor((4096, 344), dtype="float16") = params[259] lv51_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv623_1, lv624, lv622, lv50_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv826: R.Tensor((4096,), dtype="float16") = params[270] lv2915 = R.call_tir(cls.rms_norm1, (lv51_2, lv826), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv627: R.Tensor((12288, 512), dtype="uint32") = params[262] lv628: R.Tensor((12288, 128), dtype="float16") = params[263] lv52_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv627, lv628, lv2915), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv104_1 = R.call_tir(cls.split_rotary, (lv52_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv630_1: R.Tensor((1, 1, 4096), dtype="float16") = lv104_1[0] lv631 = R.call_tir(cls.fused_reshape2_transpose5, (lv630_1,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv632: R.Tensor((1, 1, 4096), dtype="float16") = lv104_1[1] lv633 = R.call_tir(cls.fused_reshape2_squeeze, (lv632,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv634: R.Tensor((1, 1, 4096), dtype="float16") = lv104_1[2] lv635 = R.call_tir(cls.fused_reshape2_squeeze, (lv634,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv2929: R.Object = kv_cache[52] lv2930: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2929, lv633, sinfo_args=(R.Object,)) lv2931: R.Object = kv_cache[53] lv2932: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2931, lv635, sinfo_args=(R.Object,)) lv2933: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2930, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2934: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2932, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2935 = R.call_tir(cls.reshape3, (lv2933,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2936 = R.call_tir(cls.reshape3, (lv2934,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2938 = R.call_tir(cls.transpose6, (lv2935,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv2939 = R.call_tir(cls.transpose6, (lv2936,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv636 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv631, lv2938, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv637_1 = R.call_tir(cls.fused_softmax_cast1, (lv636,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv2948 = R.call_tir(cls.matmul9, (lv637_1, lv2939), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv638 = R.call_tir(cls.fused_transpose7_reshape4, (lv2948,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv639: R.Tensor((4096, 512), dtype="uint32") = params[264] lv640: R.Tensor((4096, 128), dtype="float16") = params[265] lv52_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv639, lv640, lv638, lv51_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv833: R.Tensor((4096,), dtype="float16") = params[271] lv2954 = R.call_tir(cls.rms_norm1, (lv52_3, lv833), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv643: R.Tensor((22016, 512), dtype="uint32") = params[266] 
lv644_1: R.Tensor((22016, 128), dtype="float16") = params[267] lv53 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv643, lv644_1, lv2954), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv646 = R.call_tir(cls.fused_split_silu_multiply, (lv53,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv647: R.Tensor((4096, 1376), dtype="uint32") = params[268] lv648: R.Tensor((4096, 344), dtype="float16") = params[269] lv53_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv647, lv648, lv646, lv52_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv840: R.Tensor((4096,), dtype="float16") = params[280] lv2965 = R.call_tir(cls.rms_norm1, (lv53_1, lv840), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv651_1: R.Tensor((12288, 512), dtype="uint32") = params[272] lv652: R.Tensor((12288, 128), dtype="float16") = params[273] lv54_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv651_1, lv652, lv2965), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv108_1 = R.call_tir(cls.split_rotary, (lv54_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv654: R.Tensor((1, 1, 4096), dtype="float16") = lv108_1[0] lv655 = R.call_tir(cls.fused_reshape2_transpose5, (lv654,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv656: R.Tensor((1, 1, 4096), dtype="float16") = lv108_1[1] lv657 = R.call_tir(cls.fused_reshape2_squeeze, (lv656,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv658_1: R.Tensor((1, 1, 4096), dtype="float16") = lv108_1[2] lv659 = R.call_tir(cls.fused_reshape2_squeeze, (lv658_1,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv2979: R.Object = kv_cache[54] lv2980: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2979, lv657, sinfo_args=(R.Object,)) lv2981: R.Object = kv_cache[55] lv2982: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv2981, lv659, sinfo_args=(R.Object,)) lv2983: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2980, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2984: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv2982, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv2985 = R.call_tir(cls.reshape3, (lv2983,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2986 = R.call_tir(cls.reshape3, (lv2984,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv2988 = R.call_tir(cls.transpose6, (lv2985,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv2989 = R.call_tir(cls.transpose6, (lv2986,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv660 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv655, lv2988, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv661 = R.call_tir(cls.fused_softmax_cast1, (lv660,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv2998 = R.call_tir(cls.matmul9, (lv661, lv2989), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv662 = R.call_tir(cls.fused_transpose7_reshape4, (lv2998,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv663: R.Tensor((4096, 512), dtype="uint32") = params[274] lv664: R.Tensor((4096, 128), dtype="float16") = params[275] lv54_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv663, lv664, lv662, lv53_1), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv847: R.Tensor((4096,), 
dtype="float16") = params[281] lv3004 = R.call_tir(cls.rms_norm1, (lv54_2, lv847), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv667: R.Tensor((22016, 512), dtype="uint32") = params[276] lv668: R.Tensor((22016, 128), dtype="float16") = params[277] lv55_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv667, lv668, lv3004), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv670 = R.call_tir(cls.fused_split_silu_multiply, (lv55_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv671: R.Tensor((4096, 1376), dtype="uint32") = params[278] lv672_1: R.Tensor((4096, 344), dtype="float16") = params[279] lv55_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv671, lv672_1, lv670, lv54_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv854: R.Tensor((4096,), dtype="float16") = params[290] lv3015 = R.call_tir(cls.rms_norm1, (lv55_2, lv854), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv675: R.Tensor((12288, 512), dtype="uint32") = params[282] lv676: R.Tensor((12288, 128), dtype="float16") = params[283] lv56_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv675, lv676, lv3015), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv112_1 = R.call_tir(cls.split_rotary, (lv56_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv678: R.Tensor((1, 1, 4096), dtype="float16") = lv112_1[0] lv679_1 = R.call_tir(cls.fused_reshape2_transpose5, (lv678,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv680: R.Tensor((1, 1, 4096), dtype="float16") = lv112_1[1] lv681 = R.call_tir(cls.fused_reshape2_squeeze, (lv680,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv682: R.Tensor((1, 1, 4096), dtype="float16") = lv112_1[2] lv683 = R.call_tir(cls.fused_reshape2_squeeze, (lv682,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv3029: R.Object = kv_cache[56] lv3030: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3029, lv681, sinfo_args=(R.Object,)) lv3031: R.Object = kv_cache[57] lv3032: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3031, lv683, sinfo_args=(R.Object,)) lv3033: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3030, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv3034: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3032, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv3035 = R.call_tir(cls.reshape3, (lv3033,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv3036 = R.call_tir(cls.reshape3, (lv3034,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv3038 = R.call_tir(cls.transpose6, (lv3035,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv3039 = R.call_tir(cls.transpose6, (lv3036,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv684 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv679_1, lv3038, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv685 = R.call_tir(cls.fused_softmax_cast1, (lv684,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv3048 = R.call_tir(cls.matmul9, (lv685, lv3039), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv686_1 = R.call_tir(cls.fused_transpose7_reshape4, (lv3048,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv687: R.Tensor((4096, 512), dtype="uint32") = params[284] lv688: R.Tensor((4096, 128), 
dtype="float16") = params[285] lv56_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv687, lv688, lv686_1, lv55_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv861: R.Tensor((4096,), dtype="float16") = params[291] lv3054 = R.call_tir(cls.rms_norm1, (lv56_3, lv861), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv691: R.Tensor((22016, 512), dtype="uint32") = params[286] lv692: R.Tensor((22016, 128), dtype="float16") = params[287] lv57_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv691, lv692, lv3054), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv694 = R.call_tir(cls.fused_split_silu_multiply, (lv57_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv695: R.Tensor((4096, 1376), dtype="uint32") = params[288] lv696: R.Tensor((4096, 344), dtype="float16") = params[289] lv57_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv695, lv696, lv694, lv56_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv868: R.Tensor((4096,), dtype="float16") = params[300] lv3065 = R.call_tir(cls.rms_norm1, (lv57_2, lv868), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv699: R.Tensor((12288, 512), dtype="uint32") = params[292] lv700_1: R.Tensor((12288, 128), dtype="float16") = params[293] lv58_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv699, lv700_1, lv3065), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv116_1 = R.call_tir(cls.split_rotary, (lv58_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv702: R.Tensor((1, 1, 4096), dtype="float16") = lv116_1[0] lv703 = R.call_tir(cls.fused_reshape2_transpose5, (lv702,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv704: R.Tensor((1, 1, 4096), dtype="float16") = lv116_1[1] lv705 = R.call_tir(cls.fused_reshape2_squeeze, (lv704,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv706: R.Tensor((1, 1, 4096), dtype="float16") = lv116_1[2] lv707_1 = R.call_tir(cls.fused_reshape2_squeeze, (lv706,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv3079: R.Object = kv_cache[58] lv3080: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3079, lv705, sinfo_args=(R.Object,)) lv3081: R.Object = kv_cache[59] lv3082: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3081, lv707_1, sinfo_args=(R.Object,)) lv3083: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3080, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv3084: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3082, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv3085 = R.call_tir(cls.reshape3, (lv3083,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv3086 = R.call_tir(cls.reshape3, (lv3084,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv3088 = R.call_tir(cls.transpose6, (lv3085,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv3089 = R.call_tir(cls.transpose6, (lv3086,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv708 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv703, lv3088, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv709 = R.call_tir(cls.fused_softmax_cast1, (lv708,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv3098 = R.call_tir(cls.matmul9, (lv709, lv3089), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv710 = 
R.call_tir(cls.fused_transpose7_reshape4, (lv3098,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv711: R.Tensor((4096, 512), dtype="uint32") = params[294] lv712: R.Tensor((4096, 128), dtype="float16") = params[295] lv58_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv711, lv712, lv710, lv57_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv875: R.Tensor((4096,), dtype="float16") = params[301] lv3104 = R.call_tir(cls.rms_norm1, (lv58_2, lv875), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv715: R.Tensor((22016, 512), dtype="uint32") = params[296] lv716: R.Tensor((22016, 128), dtype="float16") = params[297] lv59_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv715, lv716, lv3104), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv718 = R.call_tir(cls.fused_split_silu_multiply, (lv59_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv719: R.Tensor((4096, 1376), dtype="uint32") = params[298] lv720: R.Tensor((4096, 344), dtype="float16") = params[299] lv59_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv719, lv720, lv718, lv58_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv882: R.Tensor((4096,), dtype="float16") = params[310] lv3115 = R.call_tir(cls.rms_norm1, (lv59_2, lv882), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv723: R.Tensor((12288, 512), dtype="uint32") = params[302] lv724: R.Tensor((12288, 128), dtype="float16") = params[303] lv60_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv723, lv724, lv3115), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv120_1 = R.call_tir(cls.split_rotary, (lv60_2, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv726: R.Tensor((1, 1, 4096), dtype="float16") = lv120_1[0] lv727 = R.call_tir(cls.fused_reshape2_transpose5, (lv726,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv728_1: R.Tensor((1, 1, 4096), dtype="float16") = lv120_1[1] lv729 = R.call_tir(cls.fused_reshape2_squeeze, (lv728_1,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv730: R.Tensor((1, 1, 4096), dtype="float16") = lv120_1[2] lv731 = R.call_tir(cls.fused_reshape2_squeeze, (lv730,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv3129: R.Object = kv_cache[60] lv3130: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3129, lv729, sinfo_args=(R.Object,)) lv3131: R.Object = kv_cache[61] lv3132: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3131, lv731, sinfo_args=(R.Object,)) lv3133: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3130, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv3134: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3132, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv3135 = R.call_tir(cls.reshape3, (lv3133,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv3136 = R.call_tir(cls.reshape3, (lv3134,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv3138 = R.call_tir(cls.transpose6, (lv3135,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv3139 = R.call_tir(cls.transpose6, (lv3136,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv732 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv727, lv3138, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv733 = 
R.call_tir(cls.fused_softmax_cast1, (lv732,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv3148 = R.call_tir(cls.matmul9, (lv733, lv3139), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv734 = R.call_tir(cls.fused_transpose7_reshape4, (lv3148,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv735_1: R.Tensor((4096, 512), dtype="uint32") = params[304] lv736: R.Tensor((4096, 128), dtype="float16") = params[305] lv60_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv735_1, lv736, lv734, lv59_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv889: R.Tensor((4096,), dtype="float16") = params[311] lv3154 = R.call_tir(cls.rms_norm1, (lv60_3, lv889), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv739: R.Tensor((22016, 512), dtype="uint32") = params[306] lv740: R.Tensor((22016, 128), dtype="float16") = params[307] lv61_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv739, lv740, lv3154), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv742_1 = R.call_tir(cls.fused_split_silu_multiply, (lv61_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv743: R.Tensor((4096, 1376), dtype="uint32") = params[308] lv744: R.Tensor((4096, 344), dtype="float16") = params[309] lv61_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv743, lv744, lv742_1, lv60_3), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv896: R.Tensor((4096,), dtype="float16") = params[320] lv3165 = R.call_tir(cls.rms_norm1, (lv61_2, lv896), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv747: R.Tensor((12288, 512), dtype="uint32") = params[312] lv748: R.Tensor((12288, 128), dtype="float16") = params[313] lv62_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul6, (lv747, lv748, lv3165), out_sinfo=R.Tensor((1, 1, 12288), dtype="float16")) lv124_1 = R.call_tir(cls.split_rotary, (lv62_1, lv464, lv465), out_sinfo=[R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")], tir_vars=R.shape([n])) lv750: R.Tensor((1, 1, 4096), dtype="float16") = lv124_1[0] lv751 = R.call_tir(cls.fused_reshape2_transpose5, (lv750,), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv752: R.Tensor((1, 1, 4096), dtype="float16") = lv124_1[1] lv753 = R.call_tir(cls.fused_reshape2_squeeze, (lv752,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv754: R.Tensor((1, 1, 4096), dtype="float16") = lv124_1[2] lv755 = R.call_tir(cls.fused_reshape2_squeeze, (lv754,), out_sinfo=R.Tensor((1, 32, 128), dtype="float16")) lv3179: R.Object = kv_cache[62] lv3180: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3179, lv753, sinfo_args=(R.Object,)) lv3181: R.Object = kv_cache[63] lv3182: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv3181, lv755, sinfo_args=(R.Object,)) lv3183: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3180, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv3184: R.Tensor((n, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv3182, R.shape([n, 32, 128]), sinfo_args=(R.Tensor((n, 32, 128), dtype="float16"),)) lv3185 = R.call_tir(cls.reshape3, (lv3183,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv3186 = R.call_tir(cls.reshape3, (lv3184,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv3188 = R.call_tir(cls.transpose6, (lv3185,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv3189 = R.call_tir(cls.transpose6, (lv3186,), 
out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv756_1 = R.call_tir(cls.fused_NT_matmul7_divide_maximum_minimum_cast, (lv751, lv3188, lv1614), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float32")) lv757 = R.call_tir(cls.fused_softmax_cast1, (lv756_1,), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16")) lv3198 = R.call_tir(cls.matmul9, (lv757, lv3189), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16")) lv758 = R.call_tir(cls.fused_transpose7_reshape4, (lv3198,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv759: R.Tensor((4096, 512), dtype="uint32") = params[314] lv760: R.Tensor((4096, 128), dtype="float16") = params[315] lv62_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul8_add, (lv759, lv760, lv758, lv61_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv903: R.Tensor((4096,), dtype="float16") = params[321] lv3204 = R.call_tir(cls.rms_norm1, (lv62_2, lv903), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv763_1: R.Tensor((22016, 512), dtype="uint32") = params[316] lv764: R.Tensor((22016, 128), dtype="float16") = params[317] lv63_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul9, (lv763_1, lv764, lv3204), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16")) lv766 = R.call_tir(cls.fused_split_silu_multiply, (lv63_1,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16")) lv767: R.Tensor((4096, 1376), dtype="uint32") = params[318] lv768: R.Tensor((4096, 344), dtype="float16") = params[319] lv63_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul10_add, (lv767, lv768, lv766, lv62_2), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv910: R.Tensor((4096,), dtype="float16") = params[322] lv3215 = R.call_tir(cls.rms_norm1, (lv63_2, lv910), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv3216 = R.call_tir(cls.slice1, (lv3215,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16")) lv771: R.Tensor((32001, 512), dtype="uint32") = params[323] lv772: R.Tensor((32001, 128), dtype="float16") = params[324] lv64_2 = R.call_tir(cls.fused_fused_decode1_fused_NT_matmul5_cast2, (lv771, lv772, lv3216), out_sinfo=R.Tensor((1, 1, 32001), dtype="float32")) gv1: R.Tuple(R.Tensor((1, 1, 32001), dtype="float32"), R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object)) = lv64_2, (lv1630, lv1632, lv1680, lv1682, lv1730, lv1732, lv1780, lv1782, lv1830, lv1832, lv1880, lv1882, lv1930, lv1932, lv1980, lv1982, lv2030, lv2032, lv2080, lv2082, lv2130, lv2132, lv2180, lv2182, lv2230, lv2232, lv2280, lv2282, lv2330, lv2332, lv2380, lv2382, lv2430, lv2432, lv2480, lv2482, lv2530, lv2532, lv2580, lv2582, lv2630, lv2632, lv2680, lv2682, lv2730, lv2732, lv2780, lv2782, lv2830, lv2832, lv2880, lv2882, lv2930, lv2932, lv2980, lv2982, lv3030, lv3032, lv3080, lv3082, lv3130, lv3132, lv3180, lv3182) R.output(gv1) return gv1 @R.function def get_metadata() -> R.Object: R.func_attr({"tir_var_upper_bound": {"m": 2048, "n": 2048}}) return R.str("{\"model_name\": \"WizardMath-7B-V1.0\", \"max_window_size\": 2048, 
\"stop_tokens\": [2], \"add_prefix_space\": false}") @R.function def prefill(input_ids: R.Tensor((1, "n"), dtype="int32"), all_seq_len: R.Shape(["m"]), kv_cache: R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object), params: R.Tuple(R.Tensor((32001, 512), dtype="uint32"), R.Tensor((32001, 128), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), 
R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 
344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), 
dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((12288, 512), dtype="uint32"), R.Tensor((12288, 128), dtype="float16"), R.Tensor((4096, 512), dtype="uint32"), R.Tensor((4096, 128), dtype="float16"), R.Tensor((22016, 512), dtype="uint32"), R.Tensor((22016, 128), dtype="float16"), R.Tensor((4096, 1376), dtype="uint32"), R.Tensor((4096, 344), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((4096,), dtype="float16"), R.Tensor((32001, 512), dtype="uint32"), R.Tensor((32001, 128), dtype="float16"), R.Tensor((2048, 128), dtype="float16"), R.Tensor((2048, 128), dtype="float16"))) -> R.Tuple(R.Tensor((1, 1, 32001), dtype="float32"), R.Tuple(R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, 
R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object)): n = T.int64() m = T.int64() R.func_attr({"tir_var_upper_bound": {"m": 2048, "n": 2048}}) cls = Module with R.dataflow(): lv = R.call_tir(cls.reshape5, (input_ids,), out_sinfo=R.Tensor((n,), dtype="int32")) lv775: R.Tensor((32001, 512), dtype="uint32") = params[0] lv776: R.Tensor((32001, 128), dtype="float16") = params[1] lv_1 = R.call_tir(cls.fused_fused_decode1_take1, (lv775, lv776, lv), out_sinfo=R.Tensor((n, 4096), dtype="float16")) lv2 = R.call_tir(cls.reshape6, (lv_1,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv778 = R.call_tir(cls.fused_min_max_triu_te_broadcast_to, R.tuple(), out_sinfo=R.Tensor((1, 1, n, n), dtype="float16"), tir_vars=R.shape([n])) lv5 = R.call_tir(cls.extend_te, (lv778,), out_sinfo=R.Tensor((1, 1, n, m), dtype="float16")) lv3: R.Tensor((4096,), dtype="float16") = params[10] lv6 = R.call_tir(cls.rms_norm, (lv2, lv3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv779: R.Tensor((12288, 512), dtype="uint32") = params[2] lv780: R.Tensor((12288, 128), dtype="float16") = params[3] lv65 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv779, lv780, lv6), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv9 = R.call_tir(cls.split1, (lv65,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv10: R.Tensor((1, n, 4096), dtype="float16") = lv9[0] lv11 = R.call_tir(cls.reshape7, (lv10,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv12: R.Tensor((1, n, 4096), dtype="float16") = lv9[1] lv13 = R.call_tir(cls.reshape7, (lv12,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv14: R.Tensor((1, n, 4096), dtype="float16") = lv9[2] lv15 = R.call_tir(cls.reshape7, (lv14,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv7: R.Tensor((2048, 128), dtype="float16") = params[325] lv8: R.Tensor((2048, 128), dtype="float16") = params[326] lv16 = R.call_tir(cls.rotary_embedding, (lv11, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv17 = R.call_tir(cls.rotary_embedding, (lv13, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv18 = R.call_tir(cls.squeeze1, (lv17,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv19 = R.call_tir(cls.squeeze1, (lv15,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv20: R.Object = kv_cache[0] lv21: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv20, lv18, sinfo_args=(R.Object,)) lv22: R.Object = kv_cache[1] lv23: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv22, lv19, sinfo_args=(R.Object,)) lv24: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv21, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv25: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv23, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv26 = R.call_tir(cls.reshape3, (lv24,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv27 = R.call_tir(cls.reshape3, (lv25,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv28 = R.call_tir(cls.transpose6, (lv16,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv29 = R.call_tir(cls.transpose6, (lv26,), out_sinfo=R.Tensor((1, 32, m, 128), 
dtype="float16")) lv30 = R.call_tir(cls.transpose6, (lv27,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv782 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv28, lv29, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv783 = R.call_tir(cls.fused_softmax1_cast4, (lv782,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv39 = R.call_tir(cls.matmul10, (lv783, lv30), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv40 = R.call_tir(cls.transpose8, (lv39,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv41 = R.call_tir(cls.reshape8, (lv40,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv784: R.Tensor((4096, 512), dtype="uint32") = params[4] lv785: R.Tensor((4096, 128), dtype="float16") = params[5] lv64 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv784, lv785, lv41, lv2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv12_1: R.Tensor((4096,), dtype="float16") = params[11] lv45 = R.call_tir(cls.rms_norm, (lv64, lv12_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv788: R.Tensor((22016, 512), dtype="uint32") = params[6] lv789: R.Tensor((22016, 128), dtype="float16") = params[7] lv66 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv788, lv789, lv45), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv791 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv66,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv792: R.Tensor((4096, 1376), dtype="uint32") = params[8] lv793: R.Tensor((4096, 344), dtype="float16") = params[9] lv65_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv792, lv793, lv791, lv64), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv19_1: R.Tensor((4096,), dtype="float16") = params[20] lv56 = R.call_tir(cls.rms_norm, (lv65_1, lv19_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv796: R.Tensor((12288, 512), dtype="uint32") = params[12] lv797: R.Tensor((12288, 128), dtype="float16") = params[13] lv67 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv796, lv797, lv56), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv59 = R.call_tir(cls.split1, (lv67,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv60: R.Tensor((1, n, 4096), dtype="float16") = lv59[0] lv61 = R.call_tir(cls.reshape7, (lv60,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv62: R.Tensor((1, n, 4096), dtype="float16") = lv59[1] lv63 = R.call_tir(cls.reshape7, (lv62,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv64_1: R.Tensor((1, n, 4096), dtype="float16") = lv59[2] lv65_2 = R.call_tir(cls.reshape7, (lv64_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv66_1 = R.call_tir(cls.rotary_embedding, (lv61, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv67_1 = R.call_tir(cls.rotary_embedding, (lv63, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv68 = R.call_tir(cls.squeeze1, (lv67_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv69 = R.call_tir(cls.squeeze1, (lv65_2,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv70: R.Object = kv_cache[2] lv71: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv70, lv68, sinfo_args=(R.Object,)) lv72: R.Object = kv_cache[3] lv73: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv72, lv69, sinfo_args=(R.Object,)) lv74: R.Tensor((m, 32, 128), dtype="float16") = 
R.call_packed("vm.builtin.attention_kv_cache_view", lv71, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv75: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv73, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv76 = R.call_tir(cls.reshape3, (lv74,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv77 = R.call_tir(cls.reshape3, (lv75,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv78 = R.call_tir(cls.transpose6, (lv66_1,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv79 = R.call_tir(cls.transpose6, (lv76,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv80 = R.call_tir(cls.transpose6, (lv77,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv799 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv78, lv79, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv800 = R.call_tir(cls.fused_softmax1_cast4, (lv799,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv89 = R.call_tir(cls.matmul10, (lv800, lv80), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv90 = R.call_tir(cls.transpose8, (lv89,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv91 = R.call_tir(cls.reshape8, (lv90,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv801: R.Tensor((4096, 512), dtype="uint32") = params[14] lv802: R.Tensor((4096, 128), dtype="float16") = params[15] lv66_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv801, lv802, lv91, lv65_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv26_1: R.Tensor((4096,), dtype="float16") = params[21] lv95 = R.call_tir(cls.rms_norm, (lv66_2, lv26_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv805: R.Tensor((22016, 512), dtype="uint32") = params[16] lv806: R.Tensor((22016, 128), dtype="float16") = params[17] lv68_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv805, lv806, lv95), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv808 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv68_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv809: R.Tensor((4096, 1376), dtype="uint32") = params[18] lv810: R.Tensor((4096, 344), dtype="float16") = params[19] lv67_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv809, lv810, lv808, lv66_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv33: R.Tensor((4096,), dtype="float16") = params[30] lv106 = R.call_tir(cls.rms_norm, (lv67_2, lv33), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv813: R.Tensor((12288, 512), dtype="uint32") = params[22] lv814: R.Tensor((12288, 128), dtype="float16") = params[23] lv69_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv813, lv814, lv106), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv109 = R.call_tir(cls.split1, (lv69_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv110: R.Tensor((1, n, 4096), dtype="float16") = lv109[0] lv111 = R.call_tir(cls.reshape7, (lv110,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv112: R.Tensor((1, n, 4096), dtype="float16") = lv109[1] lv113 = R.call_tir(cls.reshape7, (lv112,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv114: R.Tensor((1, n, 4096), dtype="float16") = lv109[2] lv115 = R.call_tir(cls.reshape7, (lv114,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv116 = R.call_tir(cls.rotary_embedding, (lv111, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), 
dtype="float16"), tir_vars=R.shape([m])) lv117 = R.call_tir(cls.rotary_embedding, (lv113, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv118 = R.call_tir(cls.squeeze1, (lv117,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv119 = R.call_tir(cls.squeeze1, (lv115,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv120: R.Object = kv_cache[4] lv121: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv120, lv118, sinfo_args=(R.Object,)) lv122: R.Object = kv_cache[5] lv123: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv122, lv119, sinfo_args=(R.Object,)) lv124: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv121, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv125: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv123, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv126 = R.call_tir(cls.reshape3, (lv124,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv127 = R.call_tir(cls.reshape3, (lv125,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv128 = R.call_tir(cls.transpose6, (lv116,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv129 = R.call_tir(cls.transpose6, (lv126,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv130 = R.call_tir(cls.transpose6, (lv127,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv816 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv128, lv129, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv817 = R.call_tir(cls.fused_softmax1_cast4, (lv816,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv139 = R.call_tir(cls.matmul10, (lv817, lv130), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv140 = R.call_tir(cls.transpose8, (lv139,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv141 = R.call_tir(cls.reshape8, (lv140,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv818: R.Tensor((4096, 512), dtype="uint32") = params[24] lv819: R.Tensor((4096, 128), dtype="float16") = params[25] lv68_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv818, lv819, lv141, lv67_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv40_1: R.Tensor((4096,), dtype="float16") = params[31] lv145 = R.call_tir(cls.rms_norm, (lv68_2, lv40_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv822: R.Tensor((22016, 512), dtype="uint32") = params[26] lv823: R.Tensor((22016, 128), dtype="float16") = params[27] lv70_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv822, lv823, lv145), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv825 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv70_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv826: R.Tensor((4096, 1376), dtype="uint32") = params[28] lv827: R.Tensor((4096, 344), dtype="float16") = params[29] lv69_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv826, lv827, lv825, lv68_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv47: R.Tensor((4096,), dtype="float16") = params[40] lv156 = R.call_tir(cls.rms_norm, (lv69_2, lv47), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv830: R.Tensor((12288, 512), dtype="uint32") = params[32] lv831: R.Tensor((12288, 128), dtype="float16") = params[33] lv71_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv830, lv831, lv156), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv159 = R.call_tir(cls.split1, 
(lv71_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv160: R.Tensor((1, n, 4096), dtype="float16") = lv159[0] lv161 = R.call_tir(cls.reshape7, (lv160,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv162: R.Tensor((1, n, 4096), dtype="float16") = lv159[1] lv163 = R.call_tir(cls.reshape7, (lv162,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv164: R.Tensor((1, n, 4096), dtype="float16") = lv159[2] lv165 = R.call_tir(cls.reshape7, (lv164,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv166 = R.call_tir(cls.rotary_embedding, (lv161, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv167 = R.call_tir(cls.rotary_embedding, (lv163, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv168 = R.call_tir(cls.squeeze1, (lv167,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv169 = R.call_tir(cls.squeeze1, (lv165,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv170: R.Object = kv_cache[6] lv171: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv170, lv168, sinfo_args=(R.Object,)) lv172: R.Object = kv_cache[7] lv173: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv172, lv169, sinfo_args=(R.Object,)) lv174: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv171, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv175: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv173, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv176 = R.call_tir(cls.reshape3, (lv174,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv177 = R.call_tir(cls.reshape3, (lv175,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv178 = R.call_tir(cls.transpose6, (lv166,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv179 = R.call_tir(cls.transpose6, (lv176,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv180 = R.call_tir(cls.transpose6, (lv177,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv833 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv178, lv179, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv834 = R.call_tir(cls.fused_softmax1_cast4, (lv833,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv189 = R.call_tir(cls.matmul10, (lv834, lv180), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv190 = R.call_tir(cls.transpose8, (lv189,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv191 = R.call_tir(cls.reshape8, (lv190,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv835: R.Tensor((4096, 512), dtype="uint32") = params[34] lv836: R.Tensor((4096, 128), dtype="float16") = params[35] lv70_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv835, lv836, lv191, lv69_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv54: R.Tensor((4096,), dtype="float16") = params[41] lv195 = R.call_tir(cls.rms_norm, (lv70_2, lv54), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv839: R.Tensor((22016, 512), dtype="uint32") = params[36] lv840: R.Tensor((22016, 128), dtype="float16") = params[37] lv72_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv839, lv840, lv195), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv842 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv72_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) 
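# Editorial note: prefill mirrors the per-layer structure of decode above, but runs over
# the whole prompt at once. Activations keep the dynamic sequence length n, e.g.
# (1, n, 4096) instead of (1, 1, 4096); Q/K/V come from split1 + reshape7 followed by
# rotary_embedding parameterized by the cached length m; and attention applies the causal
# mask lv5 (built from fused_min_max_triu_te_broadcast_to and extend_te) through
# fused_NT_matmul1_divide1_maximum1_minimum1_cast3. The kv_cache append/view protocol is
# the same as in decode: two R.Object slots per layer, viewed as (m, 32, 128) float16
# after each append.
#
# Illustrative driver sketch (an assumption, not part of this dump): once the module is
# built, these Relax functions would typically be invoked through the Relax VM, roughly
#   ex = tvm.relax.build(Module, target)
#   vm = tvm.relax.VirtualMachine(ex, device)
#   logits, kv_cache = vm["prefill"](input_ids, tvm.runtime.ShapeTuple([m]), kv_cache, params)
# The exact calling convention (shape-argument type, kv-cache construction) is handled by
# the MLC-LLM runtime; the lines above only indicate where these functions plug in.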
lv843: R.Tensor((4096, 1376), dtype="uint32") = params[38] lv844: R.Tensor((4096, 344), dtype="float16") = params[39] lv71_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv843, lv844, lv842, lv70_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv61_1: R.Tensor((4096,), dtype="float16") = params[50] lv206 = R.call_tir(cls.rms_norm, (lv71_2, lv61_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv847: R.Tensor((12288, 512), dtype="uint32") = params[42] lv848: R.Tensor((12288, 128), dtype="float16") = params[43] lv73_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv847, lv848, lv206), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv209 = R.call_tir(cls.split1, (lv73_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv210: R.Tensor((1, n, 4096), dtype="float16") = lv209[0] lv211 = R.call_tir(cls.reshape7, (lv210,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv212: R.Tensor((1, n, 4096), dtype="float16") = lv209[1] lv213 = R.call_tir(cls.reshape7, (lv212,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv214: R.Tensor((1, n, 4096), dtype="float16") = lv209[2] lv215 = R.call_tir(cls.reshape7, (lv214,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv216 = R.call_tir(cls.rotary_embedding, (lv211, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv217 = R.call_tir(cls.rotary_embedding, (lv213, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv218 = R.call_tir(cls.squeeze1, (lv217,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv219 = R.call_tir(cls.squeeze1, (lv215,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv220: R.Object = kv_cache[8] lv221: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv220, lv218, sinfo_args=(R.Object,)) lv222: R.Object = kv_cache[9] lv223: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv222, lv219, sinfo_args=(R.Object,)) lv224: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv221, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv225: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv223, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv226 = R.call_tir(cls.reshape3, (lv224,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv227 = R.call_tir(cls.reshape3, (lv225,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv228 = R.call_tir(cls.transpose6, (lv216,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv229 = R.call_tir(cls.transpose6, (lv226,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv230 = R.call_tir(cls.transpose6, (lv227,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv850 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv228, lv229, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv851 = R.call_tir(cls.fused_softmax1_cast4, (lv850,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv239 = R.call_tir(cls.matmul10, (lv851, lv230), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv240 = R.call_tir(cls.transpose8, (lv239,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv241 = R.call_tir(cls.reshape8, (lv240,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv852: R.Tensor((4096, 512), dtype="uint32") = params[44] lv853: R.Tensor((4096, 
128), dtype="float16") = params[45] lv72_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv852, lv853, lv241, lv71_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv68_3: R.Tensor((4096,), dtype="float16") = params[51] lv245 = R.call_tir(cls.rms_norm, (lv72_2, lv68_3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv856: R.Tensor((22016, 512), dtype="uint32") = params[46] lv857: R.Tensor((22016, 128), dtype="float16") = params[47] lv74_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv856, lv857, lv245), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv859 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv74_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv860: R.Tensor((4096, 1376), dtype="uint32") = params[48] lv861: R.Tensor((4096, 344), dtype="float16") = params[49] lv73_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv860, lv861, lv859, lv72_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv75_1: R.Tensor((4096,), dtype="float16") = params[60] lv256 = R.call_tir(cls.rms_norm, (lv73_2, lv75_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv864: R.Tensor((12288, 512), dtype="uint32") = params[52] lv865: R.Tensor((12288, 128), dtype="float16") = params[53] lv75_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv864, lv865, lv256), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv259 = R.call_tir(cls.split1, (lv75_2,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv260: R.Tensor((1, n, 4096), dtype="float16") = lv259[0] lv261 = R.call_tir(cls.reshape7, (lv260,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv262: R.Tensor((1, n, 4096), dtype="float16") = lv259[1] lv263 = R.call_tir(cls.reshape7, (lv262,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv264: R.Tensor((1, n, 4096), dtype="float16") = lv259[2] lv265 = R.call_tir(cls.reshape7, (lv264,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv266 = R.call_tir(cls.rotary_embedding, (lv261, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv267 = R.call_tir(cls.rotary_embedding, (lv263, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv268 = R.call_tir(cls.squeeze1, (lv267,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv269 = R.call_tir(cls.squeeze1, (lv265,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv270: R.Object = kv_cache[10] lv271: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv270, lv268, sinfo_args=(R.Object,)) lv272: R.Object = kv_cache[11] lv273: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv272, lv269, sinfo_args=(R.Object,)) lv274: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv271, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv275: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv273, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv276 = R.call_tir(cls.reshape3, (lv274,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv277 = R.call_tir(cls.reshape3, (lv275,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv278 = R.call_tir(cls.transpose6, (lv266,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv279 = R.call_tir(cls.transpose6, (lv276,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv280 = 
R.call_tir(cls.transpose6, (lv277,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv867 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv278, lv279, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv868 = R.call_tir(cls.fused_softmax1_cast4, (lv867,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv289 = R.call_tir(cls.matmul10, (lv868, lv280), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv290 = R.call_tir(cls.transpose8, (lv289,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv291 = R.call_tir(cls.reshape8, (lv290,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv869: R.Tensor((4096, 512), dtype="uint32") = params[54] lv870: R.Tensor((4096, 128), dtype="float16") = params[55] lv74_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv869, lv870, lv291, lv73_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv82: R.Tensor((4096,), dtype="float16") = params[61] lv295 = R.call_tir(cls.rms_norm, (lv74_2, lv82), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv873: R.Tensor((22016, 512), dtype="uint32") = params[56] lv874: R.Tensor((22016, 128), dtype="float16") = params[57] lv76_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv873, lv874, lv295), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv876 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv76_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv877: R.Tensor((4096, 1376), dtype="uint32") = params[58] lv878: R.Tensor((4096, 344), dtype="float16") = params[59] lv75_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv877, lv878, lv876, lv74_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv89_1: R.Tensor((4096,), dtype="float16") = params[70] lv306 = R.call_tir(cls.rms_norm, (lv75_3, lv89_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv881: R.Tensor((12288, 512), dtype="uint32") = params[62] lv882: R.Tensor((12288, 128), dtype="float16") = params[63] lv77_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv881, lv882, lv306), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv309 = R.call_tir(cls.split1, (lv77_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv310: R.Tensor((1, n, 4096), dtype="float16") = lv309[0] lv311 = R.call_tir(cls.reshape7, (lv310,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv312: R.Tensor((1, n, 4096), dtype="float16") = lv309[1] lv313 = R.call_tir(cls.reshape7, (lv312,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv314: R.Tensor((1, n, 4096), dtype="float16") = lv309[2] lv315 = R.call_tir(cls.reshape7, (lv314,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv316 = R.call_tir(cls.rotary_embedding, (lv311, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv317 = R.call_tir(cls.rotary_embedding, (lv313, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv318 = R.call_tir(cls.squeeze1, (lv317,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv319 = R.call_tir(cls.squeeze1, (lv315,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv320: R.Object = kv_cache[12] lv321: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv320, lv318, sinfo_args=(R.Object,)) lv322: R.Object = kv_cache[13] lv323: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv322, lv319, sinfo_args=(R.Object,)) lv324: R.Tensor((m, 32, 128), dtype="float16") = 
R.call_packed("vm.builtin.attention_kv_cache_view", lv321, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv325: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv323, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv326 = R.call_tir(cls.reshape3, (lv324,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv327 = R.call_tir(cls.reshape3, (lv325,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv328 = R.call_tir(cls.transpose6, (lv316,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv329 = R.call_tir(cls.transpose6, (lv326,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv330 = R.call_tir(cls.transpose6, (lv327,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv884 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv328, lv329, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv885 = R.call_tir(cls.fused_softmax1_cast4, (lv884,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv339 = R.call_tir(cls.matmul10, (lv885, lv330), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv340 = R.call_tir(cls.transpose8, (lv339,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv341 = R.call_tir(cls.reshape8, (lv340,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv886: R.Tensor((4096, 512), dtype="uint32") = params[64] lv887: R.Tensor((4096, 128), dtype="float16") = params[65] lv76_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv886, lv887, lv341, lv75_3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv96: R.Tensor((4096,), dtype="float16") = params[71] lv345 = R.call_tir(cls.rms_norm, (lv76_2, lv96), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv890: R.Tensor((22016, 512), dtype="uint32") = params[66] lv891: R.Tensor((22016, 128), dtype="float16") = params[67] lv78_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv890, lv891, lv345), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv893 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv78_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv894: R.Tensor((4096, 1376), dtype="uint32") = params[68] lv895: R.Tensor((4096, 344), dtype="float16") = params[69] lv77_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv894, lv895, lv893, lv76_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv103: R.Tensor((4096,), dtype="float16") = params[80] lv356 = R.call_tir(cls.rms_norm, (lv77_2, lv103), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv898: R.Tensor((12288, 512), dtype="uint32") = params[72] lv899: R.Tensor((12288, 128), dtype="float16") = params[73] lv79_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv898, lv899, lv356), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv359 = R.call_tir(cls.split1, (lv79_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv360: R.Tensor((1, n, 4096), dtype="float16") = lv359[0] lv361 = R.call_tir(cls.reshape7, (lv360,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv362: R.Tensor((1, n, 4096), dtype="float16") = lv359[1] lv363 = R.call_tir(cls.reshape7, (lv362,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv364: R.Tensor((1, n, 4096), dtype="float16") = lv359[2] lv365 = R.call_tir(cls.reshape7, (lv364,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv366 = R.call_tir(cls.rotary_embedding, (lv361, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 
128), dtype="float16"), tir_vars=R.shape([m])) lv367 = R.call_tir(cls.rotary_embedding, (lv363, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv368 = R.call_tir(cls.squeeze1, (lv367,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv369 = R.call_tir(cls.squeeze1, (lv365,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv370: R.Object = kv_cache[14] lv371: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv370, lv368, sinfo_args=(R.Object,)) lv372: R.Object = kv_cache[15] lv373: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv372, lv369, sinfo_args=(R.Object,)) lv374: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv371, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv375: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv373, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv376 = R.call_tir(cls.reshape3, (lv374,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv377 = R.call_tir(cls.reshape3, (lv375,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv378 = R.call_tir(cls.transpose6, (lv366,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv379 = R.call_tir(cls.transpose6, (lv376,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv380 = R.call_tir(cls.transpose6, (lv377,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv901 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv378, lv379, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv902 = R.call_tir(cls.fused_softmax1_cast4, (lv901,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv389 = R.call_tir(cls.matmul10, (lv902, lv380), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv390 = R.call_tir(cls.transpose8, (lv389,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv391 = R.call_tir(cls.reshape8, (lv390,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv903: R.Tensor((4096, 512), dtype="uint32") = params[74] lv904: R.Tensor((4096, 128), dtype="float16") = params[75] lv78_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv903, lv904, lv391, lv77_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv110_1: R.Tensor((4096,), dtype="float16") = params[81] lv395 = R.call_tir(cls.rms_norm, (lv78_2, lv110_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv907: R.Tensor((22016, 512), dtype="uint32") = params[76] lv908: R.Tensor((22016, 128), dtype="float16") = params[77] lv80_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv907, lv908, lv395), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv910 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv80_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv911: R.Tensor((4096, 1376), dtype="uint32") = params[78] lv912: R.Tensor((4096, 344), dtype="float16") = params[79] lv79_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv911, lv912, lv910, lv78_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv117_1: R.Tensor((4096,), dtype="float16") = params[90] lv406 = R.call_tir(cls.rms_norm, (lv79_2, lv117_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv915: R.Tensor((12288, 512), dtype="uint32") = params[82] lv916: R.Tensor((12288, 128), dtype="float16") = params[83] lv81 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv915, lv916, lv406), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv409 = 
R.call_tir(cls.split1, (lv81,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv410: R.Tensor((1, n, 4096), dtype="float16") = lv409[0] lv411 = R.call_tir(cls.reshape7, (lv410,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv412: R.Tensor((1, n, 4096), dtype="float16") = lv409[1] lv413 = R.call_tir(cls.reshape7, (lv412,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv414: R.Tensor((1, n, 4096), dtype="float16") = lv409[2] lv415 = R.call_tir(cls.reshape7, (lv414,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv416 = R.call_tir(cls.rotary_embedding, (lv411, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv417 = R.call_tir(cls.rotary_embedding, (lv413, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv418 = R.call_tir(cls.squeeze1, (lv417,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv419 = R.call_tir(cls.squeeze1, (lv415,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv420: R.Object = kv_cache[16] lv421: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv420, lv418, sinfo_args=(R.Object,)) lv422: R.Object = kv_cache[17] lv423: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv422, lv419, sinfo_args=(R.Object,)) lv424: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv421, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv425: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv423, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv426 = R.call_tir(cls.reshape3, (lv424,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv427 = R.call_tir(cls.reshape3, (lv425,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv428 = R.call_tir(cls.transpose6, (lv416,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv429 = R.call_tir(cls.transpose6, (lv426,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv430 = R.call_tir(cls.transpose6, (lv427,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv918 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv428, lv429, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv919 = R.call_tir(cls.fused_softmax1_cast4, (lv918,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv439 = R.call_tir(cls.matmul10, (lv919, lv430), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv440 = R.call_tir(cls.transpose8, (lv439,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv441 = R.call_tir(cls.reshape8, (lv440,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv920: R.Tensor((4096, 512), dtype="uint32") = params[84] lv921: R.Tensor((4096, 128), dtype="float16") = params[85] lv80_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv920, lv921, lv441, lv79_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv124_1: R.Tensor((4096,), dtype="float16") = params[91] lv445 = R.call_tir(cls.rms_norm, (lv80_2, lv124_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv924: R.Tensor((22016, 512), dtype="uint32") = params[86] lv925: R.Tensor((22016, 128), dtype="float16") = params[87] lv82_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv924, lv925, lv445), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv927 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv82_1,), out_sinfo=R.Tensor((1, n, 
11008), dtype="float16")) lv928: R.Tensor((4096, 1376), dtype="uint32") = params[88] lv929: R.Tensor((4096, 344), dtype="float16") = params[89] lv81_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv928, lv929, lv927, lv80_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv131: R.Tensor((4096,), dtype="float16") = params[100] lv456 = R.call_tir(cls.rms_norm, (lv81_1, lv131), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv932: R.Tensor((12288, 512), dtype="uint32") = params[92] lv933: R.Tensor((12288, 128), dtype="float16") = params[93] lv83 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv932, lv933, lv456), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv459 = R.call_tir(cls.split1, (lv83,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv460: R.Tensor((1, n, 4096), dtype="float16") = lv459[0] lv461 = R.call_tir(cls.reshape7, (lv460,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv462: R.Tensor((1, n, 4096), dtype="float16") = lv459[1] lv463 = R.call_tir(cls.reshape7, (lv462,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv464: R.Tensor((1, n, 4096), dtype="float16") = lv459[2] lv465 = R.call_tir(cls.reshape7, (lv464,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv466 = R.call_tir(cls.rotary_embedding, (lv461, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv467 = R.call_tir(cls.rotary_embedding, (lv463, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv468 = R.call_tir(cls.squeeze1, (lv467,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv469 = R.call_tir(cls.squeeze1, (lv465,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv470: R.Object = kv_cache[18] lv471: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv470, lv468, sinfo_args=(R.Object,)) lv472: R.Object = kv_cache[19] lv473: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv472, lv469, sinfo_args=(R.Object,)) lv474: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv471, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv475: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv473, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv476 = R.call_tir(cls.reshape3, (lv474,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv477 = R.call_tir(cls.reshape3, (lv475,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv478 = R.call_tir(cls.transpose6, (lv466,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv479 = R.call_tir(cls.transpose6, (lv476,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv480 = R.call_tir(cls.transpose6, (lv477,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv935 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv478, lv479, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv936 = R.call_tir(cls.fused_softmax1_cast4, (lv935,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv489 = R.call_tir(cls.matmul10, (lv936, lv480), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv490 = R.call_tir(cls.transpose8, (lv489,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv491 = R.call_tir(cls.reshape8, (lv490,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv937: R.Tensor((4096, 512), dtype="uint32") = params[94] 
lv938: R.Tensor((4096, 128), dtype="float16") = params[95] lv82_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv937, lv938, lv491, lv81_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv138: R.Tensor((4096,), dtype="float16") = params[101] lv495 = R.call_tir(cls.rms_norm, (lv82_2, lv138), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv941: R.Tensor((22016, 512), dtype="uint32") = params[96] lv942: R.Tensor((22016, 128), dtype="float16") = params[97] lv84 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv941, lv942, lv495), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv944 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv84,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv945: R.Tensor((4096, 1376), dtype="uint32") = params[98] lv946: R.Tensor((4096, 344), dtype="float16") = params[99] lv83_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv945, lv946, lv944, lv82_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv145_1: R.Tensor((4096,), dtype="float16") = params[110] lv506 = R.call_tir(cls.rms_norm, (lv83_1, lv145_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv949: R.Tensor((12288, 512), dtype="uint32") = params[102] lv950: R.Tensor((12288, 128), dtype="float16") = params[103] lv85 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv949, lv950, lv506), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv509 = R.call_tir(cls.split1, (lv85,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv510: R.Tensor((1, n, 4096), dtype="float16") = lv509[0] lv511 = R.call_tir(cls.reshape7, (lv510,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv512: R.Tensor((1, n, 4096), dtype="float16") = lv509[1] lv513 = R.call_tir(cls.reshape7, (lv512,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv514: R.Tensor((1, n, 4096), dtype="float16") = lv509[2] lv515 = R.call_tir(cls.reshape7, (lv514,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv516 = R.call_tir(cls.rotary_embedding, (lv511, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv517 = R.call_tir(cls.rotary_embedding, (lv513, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv518 = R.call_tir(cls.squeeze1, (lv517,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv519 = R.call_tir(cls.squeeze1, (lv515,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv520: R.Object = kv_cache[20] lv521: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv520, lv518, sinfo_args=(R.Object,)) lv522: R.Object = kv_cache[21] lv523: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv522, lv519, sinfo_args=(R.Object,)) lv524: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv521, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv525: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv523, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv526 = R.call_tir(cls.reshape3, (lv524,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv527 = R.call_tir(cls.reshape3, (lv525,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv528 = R.call_tir(cls.transpose6, (lv516,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv529 = R.call_tir(cls.transpose6, (lv526,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) 
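# Note (reader annotation): the cached K/V views of shape (1, m, 32, 128) are transposed here to (1, 32, m, 128)
# so the fused NT_matmul/divide/maximum/minimum/cast kernel can score them against the (1, 32, n, 128) query
# tensor; the resulting fp32 scores are softmaxed and cast back to fp16 before matmul10 with the value tensor.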
lv530 = R.call_tir(cls.transpose6, (lv527,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv952 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv528, lv529, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv953 = R.call_tir(cls.fused_softmax1_cast4, (lv952,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv539 = R.call_tir(cls.matmul10, (lv953, lv530), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv540 = R.call_tir(cls.transpose8, (lv539,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv541 = R.call_tir(cls.reshape8, (lv540,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv954: R.Tensor((4096, 512), dtype="uint32") = params[104] lv955: R.Tensor((4096, 128), dtype="float16") = params[105] lv84_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv954, lv955, lv541, lv83_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv152: R.Tensor((4096,), dtype="float16") = params[111] lv545 = R.call_tir(cls.rms_norm, (lv84_1, lv152), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv958: R.Tensor((22016, 512), dtype="uint32") = params[106] lv959: R.Tensor((22016, 128), dtype="float16") = params[107] lv86 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv958, lv959, lv545), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv961 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv86,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv962: R.Tensor((4096, 1376), dtype="uint32") = params[108] lv963: R.Tensor((4096, 344), dtype="float16") = params[109] lv85_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv962, lv963, lv961, lv84_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv159_1: R.Tensor((4096,), dtype="float16") = params[120] lv556 = R.call_tir(cls.rms_norm, (lv85_1, lv159_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv966: R.Tensor((12288, 512), dtype="uint32") = params[112] lv967: R.Tensor((12288, 128), dtype="float16") = params[113] lv87 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv966, lv967, lv556), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv559 = R.call_tir(cls.split1, (lv87,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv560: R.Tensor((1, n, 4096), dtype="float16") = lv559[0] lv561 = R.call_tir(cls.reshape7, (lv560,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv562: R.Tensor((1, n, 4096), dtype="float16") = lv559[1] lv563 = R.call_tir(cls.reshape7, (lv562,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv564: R.Tensor((1, n, 4096), dtype="float16") = lv559[2] lv565 = R.call_tir(cls.reshape7, (lv564,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv566 = R.call_tir(cls.rotary_embedding, (lv561, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv567 = R.call_tir(cls.rotary_embedding, (lv563, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv568 = R.call_tir(cls.squeeze1, (lv567,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv569 = R.call_tir(cls.squeeze1, (lv565,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv570: R.Object = kv_cache[22] lv571: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv570, lv568, sinfo_args=(R.Object,)) lv572: R.Object = kv_cache[23] lv573: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv572, lv569, sinfo_args=(R.Object,)) lv574: R.Tensor((m, 32, 128), dtype="float16") = 
R.call_packed("vm.builtin.attention_kv_cache_view", lv571, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv575: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv573, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv576 = R.call_tir(cls.reshape3, (lv574,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv577 = R.call_tir(cls.reshape3, (lv575,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv578 = R.call_tir(cls.transpose6, (lv566,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv579 = R.call_tir(cls.transpose6, (lv576,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv580 = R.call_tir(cls.transpose6, (lv577,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv969 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv578, lv579, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv970 = R.call_tir(cls.fused_softmax1_cast4, (lv969,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv589 = R.call_tir(cls.matmul10, (lv970, lv580), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv590 = R.call_tir(cls.transpose8, (lv589,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv591 = R.call_tir(cls.reshape8, (lv590,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv971: R.Tensor((4096, 512), dtype="uint32") = params[114] lv972: R.Tensor((4096, 128), dtype="float16") = params[115] lv86_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv971, lv972, lv591, lv85_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv166_1: R.Tensor((4096,), dtype="float16") = params[121] lv595 = R.call_tir(cls.rms_norm, (lv86_1, lv166_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv975: R.Tensor((22016, 512), dtype="uint32") = params[116] lv976: R.Tensor((22016, 128), dtype="float16") = params[117] lv88 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv975, lv976, lv595), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv978 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv88,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv979: R.Tensor((4096, 1376), dtype="uint32") = params[118] lv980: R.Tensor((4096, 344), dtype="float16") = params[119] lv87_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv979, lv980, lv978, lv86_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv173_1: R.Tensor((4096,), dtype="float16") = params[130] lv606 = R.call_tir(cls.rms_norm, (lv87_1, lv173_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv983: R.Tensor((12288, 512), dtype="uint32") = params[122] lv984: R.Tensor((12288, 128), dtype="float16") = params[123] lv89_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv983, lv984, lv606), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv609 = R.call_tir(cls.split1, (lv89_2,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv610: R.Tensor((1, n, 4096), dtype="float16") = lv609[0] lv611 = R.call_tir(cls.reshape7, (lv610,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv612: R.Tensor((1, n, 4096), dtype="float16") = lv609[1] lv613 = R.call_tir(cls.reshape7, (lv612,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv614: R.Tensor((1, n, 4096), dtype="float16") = lv609[2] lv615 = R.call_tir(cls.reshape7, (lv614,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv616 = R.call_tir(cls.rotary_embedding, (lv611, lv7, lv8), 
out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv617 = R.call_tir(cls.rotary_embedding, (lv613, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv618 = R.call_tir(cls.squeeze1, (lv617,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv619 = R.call_tir(cls.squeeze1, (lv615,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv620: R.Object = kv_cache[24] lv621: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv620, lv618, sinfo_args=(R.Object,)) lv622: R.Object = kv_cache[25] lv623: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv622, lv619, sinfo_args=(R.Object,)) lv624: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv621, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv625: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv623, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv626 = R.call_tir(cls.reshape3, (lv624,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv627 = R.call_tir(cls.reshape3, (lv625,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv628 = R.call_tir(cls.transpose6, (lv616,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv629 = R.call_tir(cls.transpose6, (lv626,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv630 = R.call_tir(cls.transpose6, (lv627,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv986 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv628, lv629, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv987 = R.call_tir(cls.fused_softmax1_cast4, (lv986,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv639 = R.call_tir(cls.matmul10, (lv987, lv630), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv640 = R.call_tir(cls.transpose8, (lv639,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv641 = R.call_tir(cls.reshape8, (lv640,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv988: R.Tensor((4096, 512), dtype="uint32") = params[124] lv989: R.Tensor((4096, 128), dtype="float16") = params[125] lv88_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv988, lv989, lv641, lv87_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv180_1: R.Tensor((4096,), dtype="float16") = params[131] lv645 = R.call_tir(cls.rms_norm, (lv88_1, lv180_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv992: R.Tensor((22016, 512), dtype="uint32") = params[126] lv993: R.Tensor((22016, 128), dtype="float16") = params[127] lv90_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv992, lv993, lv645), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv995 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv90_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv996: R.Tensor((4096, 1376), dtype="uint32") = params[128] lv997: R.Tensor((4096, 344), dtype="float16") = params[129] lv89_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv996, lv997, lv995, lv88_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv187: R.Tensor((4096,), dtype="float16") = params[140] lv656 = R.call_tir(cls.rms_norm, (lv89_3, lv187), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1000: R.Tensor((12288, 512), dtype="uint32") = params[132] lv1001: R.Tensor((12288, 128), dtype="float16") = params[133] lv91_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1000, lv1001, lv656), out_sinfo=R.Tensor((1, n, 
12288), dtype="float16")) lv659 = R.call_tir(cls.split1, (lv91_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv660: R.Tensor((1, n, 4096), dtype="float16") = lv659[0] lv661 = R.call_tir(cls.reshape7, (lv660,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv662: R.Tensor((1, n, 4096), dtype="float16") = lv659[1] lv663 = R.call_tir(cls.reshape7, (lv662,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv664: R.Tensor((1, n, 4096), dtype="float16") = lv659[2] lv665 = R.call_tir(cls.reshape7, (lv664,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv666 = R.call_tir(cls.rotary_embedding, (lv661, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv667 = R.call_tir(cls.rotary_embedding, (lv663, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv668 = R.call_tir(cls.squeeze1, (lv667,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv669 = R.call_tir(cls.squeeze1, (lv665,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv670: R.Object = kv_cache[26] lv671: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv670, lv668, sinfo_args=(R.Object,)) lv672: R.Object = kv_cache[27] lv673: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv672, lv669, sinfo_args=(R.Object,)) lv674: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv671, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv675: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv673, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv676 = R.call_tir(cls.reshape3, (lv674,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv677 = R.call_tir(cls.reshape3, (lv675,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv678 = R.call_tir(cls.transpose6, (lv666,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv679 = R.call_tir(cls.transpose6, (lv676,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv680 = R.call_tir(cls.transpose6, (lv677,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1003 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv678, lv679, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv1004 = R.call_tir(cls.fused_softmax1_cast4, (lv1003,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv689 = R.call_tir(cls.matmul10, (lv1004, lv680), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv690 = R.call_tir(cls.transpose8, (lv689,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv691 = R.call_tir(cls.reshape8, (lv690,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1005: R.Tensor((4096, 512), dtype="uint32") = params[134] lv1006: R.Tensor((4096, 128), dtype="float16") = params[135] lv90_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1005, lv1006, lv691, lv89_3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv194: R.Tensor((4096,), dtype="float16") = params[141] lv695 = R.call_tir(cls.rms_norm, (lv90_2, lv194), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1009: R.Tensor((22016, 512), dtype="uint32") = params[136] lv1010: R.Tensor((22016, 128), dtype="float16") = params[137] lv92 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1009, lv1010, lv695), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv1012 = 
R.call_tir(cls.fused_split2_silu1_multiply1, (lv92,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv1013: R.Tensor((4096, 1376), dtype="uint32") = params[138] lv1014: R.Tensor((4096, 344), dtype="float16") = params[139] lv91_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1013, lv1014, lv1012, lv90_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv201: R.Tensor((4096,), dtype="float16") = params[150] lv706 = R.call_tir(cls.rms_norm, (lv91_2, lv201), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1017: R.Tensor((12288, 512), dtype="uint32") = params[142] lv1018: R.Tensor((12288, 128), dtype="float16") = params[143] lv93 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1017, lv1018, lv706), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv709 = R.call_tir(cls.split1, (lv93,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv710: R.Tensor((1, n, 4096), dtype="float16") = lv709[0] lv711 = R.call_tir(cls.reshape7, (lv710,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv712: R.Tensor((1, n, 4096), dtype="float16") = lv709[1] lv713 = R.call_tir(cls.reshape7, (lv712,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv714: R.Tensor((1, n, 4096), dtype="float16") = lv709[2] lv715 = R.call_tir(cls.reshape7, (lv714,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv716 = R.call_tir(cls.rotary_embedding, (lv711, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv717 = R.call_tir(cls.rotary_embedding, (lv713, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv718 = R.call_tir(cls.squeeze1, (lv717,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv719 = R.call_tir(cls.squeeze1, (lv715,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv720: R.Object = kv_cache[28] lv721: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv720, lv718, sinfo_args=(R.Object,)) lv722: R.Object = kv_cache[29] lv723: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv722, lv719, sinfo_args=(R.Object,)) lv724: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv721, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv725: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv723, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv726 = R.call_tir(cls.reshape3, (lv724,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv727 = R.call_tir(cls.reshape3, (lv725,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv728 = R.call_tir(cls.transpose6, (lv716,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv729 = R.call_tir(cls.transpose6, (lv726,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv730 = R.call_tir(cls.transpose6, (lv727,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1020 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv728, lv729, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv1021 = R.call_tir(cls.fused_softmax1_cast4, (lv1020,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv739 = R.call_tir(cls.matmul10, (lv1021, lv730), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv740 = R.call_tir(cls.transpose8, (lv739,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv741 = R.call_tir(cls.reshape8, (lv740,), 
out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1022: R.Tensor((4096, 512), dtype="uint32") = params[144] lv1023: R.Tensor((4096, 128), dtype="float16") = params[145] lv92_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1022, lv1023, lv741, lv91_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv208: R.Tensor((4096,), dtype="float16") = params[151] lv745 = R.call_tir(cls.rms_norm, (lv92_1, lv208), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1026: R.Tensor((22016, 512), dtype="uint32") = params[146] lv1027: R.Tensor((22016, 128), dtype="float16") = params[147] lv94 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1026, lv1027, lv745), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv1029 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv94,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv1030: R.Tensor((4096, 1376), dtype="uint32") = params[148] lv1031: R.Tensor((4096, 344), dtype="float16") = params[149] lv93_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1030, lv1031, lv1029, lv92_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv215_1: R.Tensor((4096,), dtype="float16") = params[160] lv756 = R.call_tir(cls.rms_norm, (lv93_1, lv215_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1034: R.Tensor((12288, 512), dtype="uint32") = params[152] lv1035: R.Tensor((12288, 128), dtype="float16") = params[153] lv95_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1034, lv1035, lv756), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv759 = R.call_tir(cls.split1, (lv95_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv760: R.Tensor((1, n, 4096), dtype="float16") = lv759[0] lv761 = R.call_tir(cls.reshape7, (lv760,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv762: R.Tensor((1, n, 4096), dtype="float16") = lv759[1] lv763 = R.call_tir(cls.reshape7, (lv762,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv764: R.Tensor((1, n, 4096), dtype="float16") = lv759[2] lv765 = R.call_tir(cls.reshape7, (lv764,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv766 = R.call_tir(cls.rotary_embedding, (lv761, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv767 = R.call_tir(cls.rotary_embedding, (lv763, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv768 = R.call_tir(cls.squeeze1, (lv767,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv769 = R.call_tir(cls.squeeze1, (lv765,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv770: R.Object = kv_cache[30] lv771: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv770, lv768, sinfo_args=(R.Object,)) lv772: R.Object = kv_cache[31] lv773: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv772, lv769, sinfo_args=(R.Object,)) lv774: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv771, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv775_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv773, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv776_1 = R.call_tir(cls.reshape3, (lv774,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv777 = R.call_tir(cls.reshape3, (lv775_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv778_1 = R.call_tir(cls.transpose6, (lv766,), 
out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv779_1 = R.call_tir(cls.transpose6, (lv776_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv780_1 = R.call_tir(cls.transpose6, (lv777,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1037 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv778_1, lv779_1, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv1038 = R.call_tir(cls.fused_softmax1_cast4, (lv1037,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv789_1 = R.call_tir(cls.matmul10, (lv1038, lv780_1), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv790 = R.call_tir(cls.transpose8, (lv789_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv791_1 = R.call_tir(cls.reshape8, (lv790,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1039: R.Tensor((4096, 512), dtype="uint32") = params[154] lv1040: R.Tensor((4096, 128), dtype="float16") = params[155] lv94_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1039, lv1040, lv791_1, lv93_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv222_1: R.Tensor((4096,), dtype="float16") = params[161] lv795 = R.call_tir(cls.rms_norm, (lv94_1, lv222_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1043: R.Tensor((22016, 512), dtype="uint32") = params[156] lv1044: R.Tensor((22016, 128), dtype="float16") = params[157] lv96_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1043, lv1044, lv795), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv1046 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv96_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv1047: R.Tensor((4096, 1376), dtype="uint32") = params[158] lv1048: R.Tensor((4096, 344), dtype="float16") = params[159] lv95_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1047, lv1048, lv1046, lv94_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv229_1: R.Tensor((4096,), dtype="float16") = params[170] lv806_1 = R.call_tir(cls.rms_norm, (lv95_2, lv229_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1051: R.Tensor((12288, 512), dtype="uint32") = params[162] lv1052: R.Tensor((12288, 128), dtype="float16") = params[163] lv97 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1051, lv1052, lv806_1), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv809_1 = R.call_tir(cls.split1, (lv97,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv810_1: R.Tensor((1, n, 4096), dtype="float16") = lv809_1[0] lv811 = R.call_tir(cls.reshape7, (lv810_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv812: R.Tensor((1, n, 4096), dtype="float16") = lv809_1[1] lv813_1 = R.call_tir(cls.reshape7, (lv812,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv814_1: R.Tensor((1, n, 4096), dtype="float16") = lv809_1[2] lv815 = R.call_tir(cls.reshape7, (lv814_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv816_1 = R.call_tir(cls.rotary_embedding, (lv811, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv817_1 = R.call_tir(cls.rotary_embedding, (lv813_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv818_1 = R.call_tir(cls.squeeze1, (lv817_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv819_1 = R.call_tir(cls.squeeze1, (lv815,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv820: R.Object = kv_cache[32] lv821: R.Object = 
R.call_packed("vm.builtin.attention_kv_cache_append", lv820, lv818_1, sinfo_args=(R.Object,)) lv822_1: R.Object = kv_cache[33] lv823_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv822_1, lv819_1, sinfo_args=(R.Object,)) lv824: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv821, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv825_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv823_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv826_1 = R.call_tir(cls.reshape3, (lv824,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv827_1 = R.call_tir(cls.reshape3, (lv825_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv828 = R.call_tir(cls.transpose6, (lv816_1,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv829 = R.call_tir(cls.transpose6, (lv826_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv830_1 = R.call_tir(cls.transpose6, (lv827_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1054 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv828, lv829, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv1055 = R.call_tir(cls.fused_softmax1_cast4, (lv1054,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv839_1 = R.call_tir(cls.matmul10, (lv1055, lv830_1), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv840_1 = R.call_tir(cls.transpose8, (lv839_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv841 = R.call_tir(cls.reshape8, (lv840_1,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1056: R.Tensor((4096, 512), dtype="uint32") = params[164] lv1057: R.Tensor((4096, 128), dtype="float16") = params[165] lv96_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1056, lv1057, lv841, lv95_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv236: R.Tensor((4096,), dtype="float16") = params[171] lv845 = R.call_tir(cls.rms_norm, (lv96_2, lv236), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1060: R.Tensor((22016, 512), dtype="uint32") = params[166] lv1061: R.Tensor((22016, 128), dtype="float16") = params[167] lv98 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1060, lv1061, lv845), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv1063 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv98,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv1064: R.Tensor((4096, 1376), dtype="uint32") = params[168] lv1065: R.Tensor((4096, 344), dtype="float16") = params[169] lv97_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1064, lv1065, lv1063, lv96_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv243: R.Tensor((4096,), dtype="float16") = params[180] lv856_1 = R.call_tir(cls.rms_norm, (lv97_1, lv243), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1068: R.Tensor((12288, 512), dtype="uint32") = params[172] lv1069: R.Tensor((12288, 128), dtype="float16") = params[173] lv99 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1068, lv1069, lv856_1), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv859_1 = R.call_tir(cls.split1, (lv99,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv860_1: R.Tensor((1, n, 4096), dtype="float16") = lv859_1[0] lv861_1 = R.call_tir(cls.reshape7, (lv860_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv862: R.Tensor((1, n, 4096), 
dtype="float16") = lv859_1[1] lv863 = R.call_tir(cls.reshape7, (lv862,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv864_1: R.Tensor((1, n, 4096), dtype="float16") = lv859_1[2] lv865_1 = R.call_tir(cls.reshape7, (lv864_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv866 = R.call_tir(cls.rotary_embedding, (lv861_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv867_1 = R.call_tir(cls.rotary_embedding, (lv863, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv868_1 = R.call_tir(cls.squeeze1, (lv867_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv869_1 = R.call_tir(cls.squeeze1, (lv865_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv870_1: R.Object = kv_cache[34] lv871: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv870_1, lv868_1, sinfo_args=(R.Object,)) lv872: R.Object = kv_cache[35] lv873_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv872, lv869_1, sinfo_args=(R.Object,)) lv874_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv871, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv875: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv873_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv876_1 = R.call_tir(cls.reshape3, (lv874_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv877_1 = R.call_tir(cls.reshape3, (lv875,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv878_1 = R.call_tir(cls.transpose6, (lv866,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv879 = R.call_tir(cls.transpose6, (lv876_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv880 = R.call_tir(cls.transpose6, (lv877_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1071 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv878_1, lv879, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv1072 = R.call_tir(cls.fused_softmax1_cast4, (lv1071,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv889 = R.call_tir(cls.matmul10, (lv1072, lv880), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv890_1 = R.call_tir(cls.transpose8, (lv889,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv891_1 = R.call_tir(cls.reshape8, (lv890_1,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1073: R.Tensor((4096, 512), dtype="uint32") = params[174] lv1074: R.Tensor((4096, 128), dtype="float16") = params[175] lv98_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1073, lv1074, lv891_1, lv97_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv250: R.Tensor((4096,), dtype="float16") = params[181] lv895_1 = R.call_tir(cls.rms_norm, (lv98_1, lv250), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1077: R.Tensor((22016, 512), dtype="uint32") = params[176] lv1078: R.Tensor((22016, 128), dtype="float16") = params[177] lv100 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1077, lv1078, lv895_1), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv1080 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv100,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv1081: R.Tensor((4096, 1376), dtype="uint32") = params[178] lv1082: R.Tensor((4096, 344), dtype="float16") = params[179] lv99_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1081, lv1082, lv1080, lv98_1), out_sinfo=R.Tensor((1, n, 
4096), dtype="float16")) lv257: R.Tensor((4096,), dtype="float16") = params[190] lv906 = R.call_tir(cls.rms_norm, (lv99_1, lv257), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1085: R.Tensor((12288, 512), dtype="uint32") = params[182] lv1086: R.Tensor((12288, 128), dtype="float16") = params[183] lv101 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1085, lv1086, lv906), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv909 = R.call_tir(cls.split1, (lv101,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv910_1: R.Tensor((1, n, 4096), dtype="float16") = lv909[0] lv911_1 = R.call_tir(cls.reshape7, (lv910_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv912_1: R.Tensor((1, n, 4096), dtype="float16") = lv909[1] lv913 = R.call_tir(cls.reshape7, (lv912_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv914: R.Tensor((1, n, 4096), dtype="float16") = lv909[2] lv915_1 = R.call_tir(cls.reshape7, (lv914,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv916_1 = R.call_tir(cls.rotary_embedding, (lv911_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv917 = R.call_tir(cls.rotary_embedding, (lv913, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv918_1 = R.call_tir(cls.squeeze1, (lv917,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv919_1 = R.call_tir(cls.squeeze1, (lv915_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv920_1: R.Object = kv_cache[36] lv921_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv920_1, lv918_1, sinfo_args=(R.Object,)) lv922: R.Object = kv_cache[37] lv923: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv922, lv919_1, sinfo_args=(R.Object,)) lv924_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv921_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv925_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv923, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv926 = R.call_tir(cls.reshape3, (lv924_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv927_1 = R.call_tir(cls.reshape3, (lv925_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv928_1 = R.call_tir(cls.transpose6, (lv916_1,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv929_1 = R.call_tir(cls.transpose6, (lv926,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv930 = R.call_tir(cls.transpose6, (lv927_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1088 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv928_1, lv929_1, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv1089 = R.call_tir(cls.fused_softmax1_cast4, (lv1088,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv939 = R.call_tir(cls.matmul10, (lv1089, lv930), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv940 = R.call_tir(cls.transpose8, (lv939,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv941_1 = R.call_tir(cls.reshape8, (lv940,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1090: R.Tensor((4096, 512), dtype="uint32") = params[184] lv1091: R.Tensor((4096, 128), dtype="float16") = params[185] lv100_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1090, lv1091, lv941_1, lv99_1), out_sinfo=R.Tensor((1, n, 4096), 
dtype="float16")) lv264_1: R.Tensor((4096,), dtype="float16") = params[191] lv945_1 = R.call_tir(cls.rms_norm, (lv100_1, lv264_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1094: R.Tensor((22016, 512), dtype="uint32") = params[186] lv1095: R.Tensor((22016, 128), dtype="float16") = params[187] lv102 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1094, lv1095, lv945_1), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv1097 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv102,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv1098: R.Tensor((4096, 1376), dtype="uint32") = params[188] lv1099: R.Tensor((4096, 344), dtype="float16") = params[189] lv101_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1098, lv1099, lv1097, lv100_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv271_1: R.Tensor((4096,), dtype="float16") = params[200] lv956 = R.call_tir(cls.rms_norm, (lv101_1, lv271_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1102: R.Tensor((12288, 512), dtype="uint32") = params[192] lv1103: R.Tensor((12288, 128), dtype="float16") = params[193] lv103_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1102, lv1103, lv956), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv959_1 = R.call_tir(cls.split1, (lv103_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv960: R.Tensor((1, n, 4096), dtype="float16") = lv959_1[0] lv961_1 = R.call_tir(cls.reshape7, (lv960,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv962_1: R.Tensor((1, n, 4096), dtype="float16") = lv959_1[1] lv963_1 = R.call_tir(cls.reshape7, (lv962_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv964: R.Tensor((1, n, 4096), dtype="float16") = lv959_1[2] lv965 = R.call_tir(cls.reshape7, (lv964,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv966_1 = R.call_tir(cls.rotary_embedding, (lv961_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv967_1 = R.call_tir(cls.rotary_embedding, (lv963_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv968 = R.call_tir(cls.squeeze1, (lv967_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv969_1 = R.call_tir(cls.squeeze1, (lv965,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv970_1: R.Object = kv_cache[38] lv971_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv970_1, lv968, sinfo_args=(R.Object,)) lv972_1: R.Object = kv_cache[39] lv973: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv972_1, lv969_1, sinfo_args=(R.Object,)) lv974: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv971_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv975_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv973, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv976_1 = R.call_tir(cls.reshape3, (lv974,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv977 = R.call_tir(cls.reshape3, (lv975_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv978_1 = R.call_tir(cls.transpose6, (lv966_1,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv979_1 = R.call_tir(cls.transpose6, (lv976_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv980_1 = R.call_tir(cls.transpose6, (lv977,), out_sinfo=R.Tensor((1, 32, m, 128), 
dtype="float16")) lv1105 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv978_1, lv979_1, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv1106 = R.call_tir(cls.fused_softmax1_cast4, (lv1105,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv989_1 = R.call_tir(cls.matmul10, (lv1106, lv980_1), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv990 = R.call_tir(cls.transpose8, (lv989_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv991 = R.call_tir(cls.reshape8, (lv990,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1107: R.Tensor((4096, 512), dtype="uint32") = params[194] lv1108: R.Tensor((4096, 128), dtype="float16") = params[195] lv102_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1107, lv1108, lv991, lv101_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv278_1: R.Tensor((4096,), dtype="float16") = params[201] lv995_1 = R.call_tir(cls.rms_norm, (lv102_1, lv278_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1111: R.Tensor((22016, 512), dtype="uint32") = params[196] lv1112: R.Tensor((22016, 128), dtype="float16") = params[197] lv104 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1111, lv1112, lv995_1), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv1114 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv104,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv1115: R.Tensor((4096, 1376), dtype="uint32") = params[198] lv1116: R.Tensor((4096, 344), dtype="float16") = params[199] lv103_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1115, lv1116, lv1114, lv102_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv285: R.Tensor((4096,), dtype="float16") = params[210] lv1006_1 = R.call_tir(cls.rms_norm, (lv103_2, lv285), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1119: R.Tensor((12288, 512), dtype="uint32") = params[202] lv1120: R.Tensor((12288, 128), dtype="float16") = params[203] lv105 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1119, lv1120, lv1006_1), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv1009_1 = R.call_tir(cls.split1, (lv105,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv1010_1: R.Tensor((1, n, 4096), dtype="float16") = lv1009_1[0] lv1011 = R.call_tir(cls.reshape7, (lv1010_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1012_1: R.Tensor((1, n, 4096), dtype="float16") = lv1009_1[1] lv1013_1 = R.call_tir(cls.reshape7, (lv1012_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1014_1: R.Tensor((1, n, 4096), dtype="float16") = lv1009_1[2] lv1015 = R.call_tir(cls.reshape7, (lv1014_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1016 = R.call_tir(cls.rotary_embedding, (lv1011, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1017_1 = R.call_tir(cls.rotary_embedding, (lv1013_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1018_1 = R.call_tir(cls.squeeze1, (lv1017_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1019 = R.call_tir(cls.squeeze1, (lv1015,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1020_1: R.Object = kv_cache[40] lv1021_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1020_1, lv1018_1, sinfo_args=(R.Object,)) lv1022_1: R.Object = kv_cache[41] lv1023_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1022_1, lv1019, sinfo_args=(R.Object,)) lv1024: 
R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1021_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1025: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1023_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1026_1 = R.call_tir(cls.reshape3, (lv1024,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1027_1 = R.call_tir(cls.reshape3, (lv1025,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1028 = R.call_tir(cls.transpose6, (lv1016,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1029_1 = R.call_tir(cls.transpose6, (lv1026_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1030_1 = R.call_tir(cls.transpose6, (lv1027_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1122 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1028, lv1029_1, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv1123 = R.call_tir(cls.fused_softmax1_cast4, (lv1122,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv1039_1 = R.call_tir(cls.matmul10, (lv1123, lv1030_1), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1040_1 = R.call_tir(cls.transpose8, (lv1039_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1041 = R.call_tir(cls.reshape8, (lv1040_1,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1124: R.Tensor((4096, 512), dtype="uint32") = params[204] lv1125: R.Tensor((4096, 128), dtype="float16") = params[205] lv104_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1124, lv1125, lv1041, lv103_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv292: R.Tensor((4096,), dtype="float16") = params[211] lv1045 = R.call_tir(cls.rms_norm, (lv104_1, lv292), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1128: R.Tensor((22016, 512), dtype="uint32") = params[206] lv1129: R.Tensor((22016, 128), dtype="float16") = params[207] lv106_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1128, lv1129, lv1045), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv1131 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv106_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv1132: R.Tensor((4096, 1376), dtype="uint32") = params[208] lv1133: R.Tensor((4096, 344), dtype="float16") = params[209] lv105_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1132, lv1133, lv1131, lv104_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv299: R.Tensor((4096,), dtype="float16") = params[220] lv1056_1 = R.call_tir(cls.rms_norm, (lv105_1, lv299), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1136: R.Tensor((12288, 512), dtype="uint32") = params[212] lv1137: R.Tensor((12288, 128), dtype="float16") = params[213] lv107 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1136, lv1137, lv1056_1), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv1059 = R.call_tir(cls.split1, (lv107,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv1060_1: R.Tensor((1, n, 4096), dtype="float16") = lv1059[0] lv1061_1 = R.call_tir(cls.reshape7, (lv1060_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1062: R.Tensor((1, n, 4096), dtype="float16") = lv1059[1] lv1063_1 = R.call_tir(cls.reshape7, (lv1062,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1064_1: R.Tensor((1, n, 4096), dtype="float16") = lv1059[2] lv1065_1 = 
R.call_tir(cls.reshape7, (lv1064_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1066 = R.call_tir(cls.rotary_embedding, (lv1061_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1067 = R.call_tir(cls.rotary_embedding, (lv1063_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1068_1 = R.call_tir(cls.squeeze1, (lv1067,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1069_1 = R.call_tir(cls.squeeze1, (lv1065_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1070: R.Object = kv_cache[42] lv1071_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1070, lv1068_1, sinfo_args=(R.Object,)) lv1072_1: R.Object = kv_cache[43] lv1073_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1072_1, lv1069_1, sinfo_args=(R.Object,)) lv1074_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1071_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1075: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1073_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1076 = R.call_tir(cls.reshape3, (lv1074_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1077_1 = R.call_tir(cls.reshape3, (lv1075,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1078_1 = R.call_tir(cls.transpose6, (lv1066,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1079 = R.call_tir(cls.transpose6, (lv1076,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1080_1 = R.call_tir(cls.transpose6, (lv1077_1,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1139 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1078_1, lv1079, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv1140 = R.call_tir(cls.fused_softmax1_cast4, (lv1139,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv1089_1 = R.call_tir(cls.matmul10, (lv1140, lv1080_1), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1090_1 = R.call_tir(cls.transpose8, (lv1089_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1091_1 = R.call_tir(cls.reshape8, (lv1090_1,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1141: R.Tensor((4096, 512), dtype="uint32") = params[214] lv1142: R.Tensor((4096, 128), dtype="float16") = params[215] lv106_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1141, lv1142, lv1091_1, lv105_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv306_1: R.Tensor((4096,), dtype="float16") = params[221] lv1095_1 = R.call_tir(cls.rms_norm, (lv106_2, lv306_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1145: R.Tensor((22016, 512), dtype="uint32") = params[216] lv1146: R.Tensor((22016, 128), dtype="float16") = params[217] lv108 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1145, lv1146, lv1095_1), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv1148 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv108,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv1149: R.Tensor((4096, 1376), dtype="uint32") = params[218] lv1150: R.Tensor((4096, 344), dtype="float16") = params[219] lv107_1 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1149, lv1150, lv1148, lv106_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv313_1: R.Tensor((4096,), dtype="float16") = params[230] lv1106_1 = R.call_tir(cls.rms_norm, (lv107_1, lv313_1), 
out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1153: R.Tensor((12288, 512), dtype="uint32") = params[222] lv1154: R.Tensor((12288, 128), dtype="float16") = params[223] lv109_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1153, lv1154, lv1106_1), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv1109 = R.call_tir(cls.split1, (lv109_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv1110: R.Tensor((1, n, 4096), dtype="float16") = lv1109[0] lv1111_1 = R.call_tir(cls.reshape7, (lv1110,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1112_1: R.Tensor((1, n, 4096), dtype="float16") = lv1109[1] lv1113 = R.call_tir(cls.reshape7, (lv1112_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1114_1: R.Tensor((1, n, 4096), dtype="float16") = lv1109[2] lv1115_1 = R.call_tir(cls.reshape7, (lv1114_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1116_1 = R.call_tir(cls.rotary_embedding, (lv1111_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1117 = R.call_tir(cls.rotary_embedding, (lv1113, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1118 = R.call_tir(cls.squeeze1, (lv1117,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1119_1 = R.call_tir(cls.squeeze1, (lv1115_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1120_1: R.Object = kv_cache[44] lv1121: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1120_1, lv1118, sinfo_args=(R.Object,)) lv1122_1: R.Object = kv_cache[45] lv1123_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1122_1, lv1119_1, sinfo_args=(R.Object,)) lv1124_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1121, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1125_1: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1123_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1126 = R.call_tir(cls.reshape3, (lv1124_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1127 = R.call_tir(cls.reshape3, (lv1125_1,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1128_1 = R.call_tir(cls.transpose6, (lv1116_1,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1129_1 = R.call_tir(cls.transpose6, (lv1126,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1130 = R.call_tir(cls.transpose6, (lv1127,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1156 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1128_1, lv1129_1, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv1157 = R.call_tir(cls.fused_softmax1_cast4, (lv1156,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv1139_1 = R.call_tir(cls.matmul10, (lv1157, lv1130), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1140_1 = R.call_tir(cls.transpose8, (lv1139_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1141_1 = R.call_tir(cls.reshape8, (lv1140_1,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1158: R.Tensor((4096, 512), dtype="uint32") = params[224] lv1159: R.Tensor((4096, 128), dtype="float16") = params[225] lv108_1 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1158, lv1159, lv1141_1, lv107_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv320_1: R.Tensor((4096,), dtype="float16") = 
params[231] lv1145_1 = R.call_tir(cls.rms_norm, (lv108_1, lv320_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1162: R.Tensor((22016, 512), dtype="uint32") = params[226] lv1163: R.Tensor((22016, 128), dtype="float16") = params[227] lv110_2 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1162, lv1163, lv1145_1), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv1165 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv110_2,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv1166: R.Tensor((4096, 1376), dtype="uint32") = params[228] lv1167: R.Tensor((4096, 344), dtype="float16") = params[229] lv109_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1166, lv1167, lv1165, lv108_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv327_1: R.Tensor((4096,), dtype="float16") = params[240] lv1156_1 = R.call_tir(cls.rms_norm, (lv109_2, lv327_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1170: R.Tensor((12288, 512), dtype="uint32") = params[232] lv1171: R.Tensor((12288, 128), dtype="float16") = params[233] lv111_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1170, lv1171, lv1156_1), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv1159_1 = R.call_tir(cls.split1, (lv111_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv1160: R.Tensor((1, n, 4096), dtype="float16") = lv1159_1[0] lv1161 = R.call_tir(cls.reshape7, (lv1160,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1162_1: R.Tensor((1, n, 4096), dtype="float16") = lv1159_1[1] lv1163_1 = R.call_tir(cls.reshape7, (lv1162_1,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1164: R.Tensor((1, n, 4096), dtype="float16") = lv1159_1[2] lv1165_1 = R.call_tir(cls.reshape7, (lv1164,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1166_1 = R.call_tir(cls.rotary_embedding, (lv1161, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1167_1 = R.call_tir(cls.rotary_embedding, (lv1163_1, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1168 = R.call_tir(cls.squeeze1, (lv1167_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1169 = R.call_tir(cls.squeeze1, (lv1165_1,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1170_1: R.Object = kv_cache[46] lv1171_1: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1170_1, lv1168, sinfo_args=(R.Object,)) lv1172: R.Object = kv_cache[47] lv1173: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1172, lv1169, sinfo_args=(R.Object,)) lv1174: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1171_1, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1175: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1173, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1176 = R.call_tir(cls.reshape3, (lv1174,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1177 = R.call_tir(cls.reshape3, (lv1175,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1178 = R.call_tir(cls.transpose6, (lv1166_1,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1179 = R.call_tir(cls.transpose6, (lv1176,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1180 = R.call_tir(cls.transpose6, (lv1177,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1173_1 = 
R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1178, lv1179, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv1174_1 = R.call_tir(cls.fused_softmax1_cast4, (lv1173_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv1189 = R.call_tir(cls.matmul10, (lv1174_1, lv1180), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1190 = R.call_tir(cls.transpose8, (lv1189,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1191 = R.call_tir(cls.reshape8, (lv1190,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1175_1: R.Tensor((4096, 512), dtype="uint32") = params[234] lv1176_1: R.Tensor((4096, 128), dtype="float16") = params[235] lv110_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1175_1, lv1176_1, lv1191, lv109_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv334: R.Tensor((4096,), dtype="float16") = params[241] lv1195 = R.call_tir(cls.rms_norm, (lv110_3, lv334), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1179_1: R.Tensor((22016, 512), dtype="uint32") = params[236] lv1180_1: R.Tensor((22016, 128), dtype="float16") = params[237] lv112_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1179_1, lv1180_1, lv1195), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv1182 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv112_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv1183: R.Tensor((4096, 1376), dtype="uint32") = params[238] lv1184: R.Tensor((4096, 344), dtype="float16") = params[239] lv111_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1183, lv1184, lv1182, lv110_3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv341_1: R.Tensor((4096,), dtype="float16") = params[250] lv1206 = R.call_tir(cls.rms_norm, (lv111_2, lv341_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1187: R.Tensor((12288, 512), dtype="uint32") = params[242] lv1188: R.Tensor((12288, 128), dtype="float16") = params[243] lv113_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1187, lv1188, lv1206), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv1209 = R.call_tir(cls.split1, (lv113_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv1210: R.Tensor((1, n, 4096), dtype="float16") = lv1209[0] lv1211 = R.call_tir(cls.reshape7, (lv1210,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1212: R.Tensor((1, n, 4096), dtype="float16") = lv1209[1] lv1213 = R.call_tir(cls.reshape7, (lv1212,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1214: R.Tensor((1, n, 4096), dtype="float16") = lv1209[2] lv1215 = R.call_tir(cls.reshape7, (lv1214,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1216 = R.call_tir(cls.rotary_embedding, (lv1211, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1217 = R.call_tir(cls.rotary_embedding, (lv1213, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1218 = R.call_tir(cls.squeeze1, (lv1217,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1219 = R.call_tir(cls.squeeze1, (lv1215,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1220: R.Object = kv_cache[48] lv1221: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1220, lv1218, sinfo_args=(R.Object,)) lv1222: R.Object = kv_cache[49] lv1223: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1222, lv1219, sinfo_args=(R.Object,)) lv1224: R.Tensor((m, 32, 128), dtype="float16") = 
R.call_packed("vm.builtin.attention_kv_cache_view", lv1221, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1225: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1223, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1226 = R.call_tir(cls.reshape3, (lv1224,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1227 = R.call_tir(cls.reshape3, (lv1225,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1228 = R.call_tir(cls.transpose6, (lv1216,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1229 = R.call_tir(cls.transpose6, (lv1226,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1230 = R.call_tir(cls.transpose6, (lv1227,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1190_1 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1228, lv1229, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv1191_1 = R.call_tir(cls.fused_softmax1_cast4, (lv1190_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv1239 = R.call_tir(cls.matmul10, (lv1191_1, lv1230), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1240 = R.call_tir(cls.transpose8, (lv1239,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1241 = R.call_tir(cls.reshape8, (lv1240,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1192: R.Tensor((4096, 512), dtype="uint32") = params[244] lv1193: R.Tensor((4096, 128), dtype="float16") = params[245] lv112_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1192, lv1193, lv1241, lv111_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv348: R.Tensor((4096,), dtype="float16") = params[251] lv1245 = R.call_tir(cls.rms_norm, (lv112_2, lv348), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1196: R.Tensor((22016, 512), dtype="uint32") = params[246] lv1197: R.Tensor((22016, 128), dtype="float16") = params[247] lv114_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1196, lv1197, lv1245), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv1199 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv114_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv1200: R.Tensor((4096, 1376), dtype="uint32") = params[248] lv1201: R.Tensor((4096, 344), dtype="float16") = params[249] lv113_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1200, lv1201, lv1199, lv112_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv355: R.Tensor((4096,), dtype="float16") = params[260] lv1256 = R.call_tir(cls.rms_norm, (lv113_2, lv355), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1204: R.Tensor((12288, 512), dtype="uint32") = params[252] lv1205: R.Tensor((12288, 128), dtype="float16") = params[253] lv115_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1204, lv1205, lv1256), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv1259 = R.call_tir(cls.split1, (lv115_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv1260: R.Tensor((1, n, 4096), dtype="float16") = lv1259[0] lv1261 = R.call_tir(cls.reshape7, (lv1260,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1262: R.Tensor((1, n, 4096), dtype="float16") = lv1259[1] lv1263 = R.call_tir(cls.reshape7, (lv1262,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1264: R.Tensor((1, n, 4096), dtype="float16") = lv1259[2] lv1265 = R.call_tir(cls.reshape7, (lv1264,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) 
lv1266 = R.call_tir(cls.rotary_embedding, (lv1261, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1267 = R.call_tir(cls.rotary_embedding, (lv1263, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1268 = R.call_tir(cls.squeeze1, (lv1267,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1269 = R.call_tir(cls.squeeze1, (lv1265,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1270: R.Object = kv_cache[50] lv1271: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1270, lv1268, sinfo_args=(R.Object,)) lv1272: R.Object = kv_cache[51] lv1273: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1272, lv1269, sinfo_args=(R.Object,)) lv1274: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1271, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1275: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1273, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1276 = R.call_tir(cls.reshape3, (lv1274,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1277 = R.call_tir(cls.reshape3, (lv1275,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1278 = R.call_tir(cls.transpose6, (lv1266,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1279 = R.call_tir(cls.transpose6, (lv1276,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1280 = R.call_tir(cls.transpose6, (lv1277,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1207 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1278, lv1279, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv1208 = R.call_tir(cls.fused_softmax1_cast4, (lv1207,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv1289 = R.call_tir(cls.matmul10, (lv1208, lv1280), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1290 = R.call_tir(cls.transpose8, (lv1289,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1291 = R.call_tir(cls.reshape8, (lv1290,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1209_1: R.Tensor((4096, 512), dtype="uint32") = params[254] lv1210_1: R.Tensor((4096, 128), dtype="float16") = params[255] lv114_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1209_1, lv1210_1, lv1291, lv113_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv362_1: R.Tensor((4096,), dtype="float16") = params[261] lv1295 = R.call_tir(cls.rms_norm, (lv114_2, lv362_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1213_1: R.Tensor((22016, 512), dtype="uint32") = params[256] lv1214_1: R.Tensor((22016, 128), dtype="float16") = params[257] lv116_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1213_1, lv1214_1, lv1295), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv1216_1 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv116_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv1217_1: R.Tensor((4096, 1376), dtype="uint32") = params[258] lv1218_1: R.Tensor((4096, 344), dtype="float16") = params[259] lv115_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1217_1, lv1218_1, lv1216_1, lv114_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv369_1: R.Tensor((4096,), dtype="float16") = params[270] lv1306 = R.call_tir(cls.rms_norm, (lv115_2, lv369_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1221_1: R.Tensor((12288, 512), dtype="uint32") = params[262] lv1222_1: 
R.Tensor((12288, 128), dtype="float16") = params[263] lv117_2 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1221_1, lv1222_1, lv1306), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv1309 = R.call_tir(cls.split1, (lv117_2,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv1310: R.Tensor((1, n, 4096), dtype="float16") = lv1309[0] lv1311 = R.call_tir(cls.reshape7, (lv1310,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1312: R.Tensor((1, n, 4096), dtype="float16") = lv1309[1] lv1313 = R.call_tir(cls.reshape7, (lv1312,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1314: R.Tensor((1, n, 4096), dtype="float16") = lv1309[2] lv1315 = R.call_tir(cls.reshape7, (lv1314,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1316 = R.call_tir(cls.rotary_embedding, (lv1311, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1317 = R.call_tir(cls.rotary_embedding, (lv1313, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1318 = R.call_tir(cls.squeeze1, (lv1317,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1319 = R.call_tir(cls.squeeze1, (lv1315,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1320: R.Object = kv_cache[52] lv1321: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1320, lv1318, sinfo_args=(R.Object,)) lv1322: R.Object = kv_cache[53] lv1323: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1322, lv1319, sinfo_args=(R.Object,)) lv1324: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1321, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1325: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1323, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1326 = R.call_tir(cls.reshape3, (lv1324,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1327 = R.call_tir(cls.reshape3, (lv1325,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1328 = R.call_tir(cls.transpose6, (lv1316,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1329 = R.call_tir(cls.transpose6, (lv1326,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1330 = R.call_tir(cls.transpose6, (lv1327,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1224_1 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1328, lv1329, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv1225_1 = R.call_tir(cls.fused_softmax1_cast4, (lv1224_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv1339 = R.call_tir(cls.matmul10, (lv1225_1, lv1330), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1340 = R.call_tir(cls.transpose8, (lv1339,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1341 = R.call_tir(cls.reshape8, (lv1340,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1226_1: R.Tensor((4096, 512), dtype="uint32") = params[264] lv1227_1: R.Tensor((4096, 128), dtype="float16") = params[265] lv116_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1226_1, lv1227_1, lv1341, lv115_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv376_1: R.Tensor((4096,), dtype="float16") = params[271] lv1345 = R.call_tir(cls.rms_norm, (lv116_2, lv376_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1230_1: R.Tensor((22016, 512), dtype="uint32") = 
params[266] lv1231: R.Tensor((22016, 128), dtype="float16") = params[267] lv118_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1230_1, lv1231, lv1345), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv1233 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv118_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv1234: R.Tensor((4096, 1376), dtype="uint32") = params[268] lv1235: R.Tensor((4096, 344), dtype="float16") = params[269] lv117_3 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1234, lv1235, lv1233, lv116_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv383: R.Tensor((4096,), dtype="float16") = params[280] lv1356 = R.call_tir(cls.rms_norm, (lv117_3, lv383), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1238: R.Tensor((12288, 512), dtype="uint32") = params[272] lv1239_1: R.Tensor((12288, 128), dtype="float16") = params[273] lv119_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1238, lv1239_1, lv1356), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv1359 = R.call_tir(cls.split1, (lv119_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv1360: R.Tensor((1, n, 4096), dtype="float16") = lv1359[0] lv1361 = R.call_tir(cls.reshape7, (lv1360,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1362: R.Tensor((1, n, 4096), dtype="float16") = lv1359[1] lv1363 = R.call_tir(cls.reshape7, (lv1362,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1364: R.Tensor((1, n, 4096), dtype="float16") = lv1359[2] lv1365 = R.call_tir(cls.reshape7, (lv1364,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1366 = R.call_tir(cls.rotary_embedding, (lv1361, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1367 = R.call_tir(cls.rotary_embedding, (lv1363, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1368 = R.call_tir(cls.squeeze1, (lv1367,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1369 = R.call_tir(cls.squeeze1, (lv1365,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1370: R.Object = kv_cache[54] lv1371: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1370, lv1368, sinfo_args=(R.Object,)) lv1372: R.Object = kv_cache[55] lv1373: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1372, lv1369, sinfo_args=(R.Object,)) lv1374: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1371, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1375: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1373, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1376 = R.call_tir(cls.reshape3, (lv1374,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1377 = R.call_tir(cls.reshape3, (lv1375,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1378 = R.call_tir(cls.transpose6, (lv1366,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1379 = R.call_tir(cls.transpose6, (lv1376,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1380 = R.call_tir(cls.transpose6, (lv1377,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1241_1 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1378, lv1379, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv1242 = R.call_tir(cls.fused_softmax1_cast4, (lv1241_1,), out_sinfo=R.Tensor((1, 32, 
n, m), dtype="float16")) lv1389 = R.call_tir(cls.matmul10, (lv1242, lv1380), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1390 = R.call_tir(cls.transpose8, (lv1389,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1391 = R.call_tir(cls.reshape8, (lv1390,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1243: R.Tensor((4096, 512), dtype="uint32") = params[274] lv1244: R.Tensor((4096, 128), dtype="float16") = params[275] lv118_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1243, lv1244, lv1391, lv117_3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv390_1: R.Tensor((4096,), dtype="float16") = params[281] lv1395 = R.call_tir(cls.rms_norm, (lv118_2, lv390_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1247: R.Tensor((22016, 512), dtype="uint32") = params[276] lv1248: R.Tensor((22016, 128), dtype="float16") = params[277] lv120_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1247, lv1248, lv1395), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv1250 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv120_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv1251: R.Tensor((4096, 1376), dtype="uint32") = params[278] lv1252: R.Tensor((4096, 344), dtype="float16") = params[279] lv119_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1251, lv1252, lv1250, lv118_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv397: R.Tensor((4096,), dtype="float16") = params[290] lv1406 = R.call_tir(cls.rms_norm, (lv119_2, lv397), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1255: R.Tensor((12288, 512), dtype="uint32") = params[282] lv1256_1: R.Tensor((12288, 128), dtype="float16") = params[283] lv121_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1255, lv1256_1, lv1406), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv1409 = R.call_tir(cls.split1, (lv121_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv1410: R.Tensor((1, n, 4096), dtype="float16") = lv1409[0] lv1411 = R.call_tir(cls.reshape7, (lv1410,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1412: R.Tensor((1, n, 4096), dtype="float16") = lv1409[1] lv1413 = R.call_tir(cls.reshape7, (lv1412,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1414: R.Tensor((1, n, 4096), dtype="float16") = lv1409[2] lv1415 = R.call_tir(cls.reshape7, (lv1414,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1416 = R.call_tir(cls.rotary_embedding, (lv1411, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1417 = R.call_tir(cls.rotary_embedding, (lv1413, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1418 = R.call_tir(cls.squeeze1, (lv1417,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1419 = R.call_tir(cls.squeeze1, (lv1415,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1420: R.Object = kv_cache[56] lv1421: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1420, lv1418, sinfo_args=(R.Object,)) lv1422: R.Object = kv_cache[57] lv1423: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1422, lv1419, sinfo_args=(R.Object,)) lv1424: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1421, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1425: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1423, 
R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1426 = R.call_tir(cls.reshape3, (lv1424,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1427 = R.call_tir(cls.reshape3, (lv1425,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1428 = R.call_tir(cls.transpose6, (lv1416,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1429 = R.call_tir(cls.transpose6, (lv1426,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1430 = R.call_tir(cls.transpose6, (lv1427,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1258 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1428, lv1429, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv1259_1 = R.call_tir(cls.fused_softmax1_cast4, (lv1258,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv1439 = R.call_tir(cls.matmul10, (lv1259_1, lv1430), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1440 = R.call_tir(cls.transpose8, (lv1439,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1441 = R.call_tir(cls.reshape8, (lv1440,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1260_1: R.Tensor((4096, 512), dtype="uint32") = params[284] lv1261_1: R.Tensor((4096, 128), dtype="float16") = params[285] lv120_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1260_1, lv1261_1, lv1441, lv119_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv404: R.Tensor((4096,), dtype="float16") = params[291] lv1445 = R.call_tir(cls.rms_norm, (lv120_2, lv404), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1264_1: R.Tensor((22016, 512), dtype="uint32") = params[286] lv1265_1: R.Tensor((22016, 128), dtype="float16") = params[287] lv122_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1264_1, lv1265_1, lv1445), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv1267_1 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv122_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv1268_1: R.Tensor((4096, 1376), dtype="uint32") = params[288] lv1269_1: R.Tensor((4096, 344), dtype="float16") = params[289] lv121_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1268_1, lv1269_1, lv1267_1, lv120_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv411_1: R.Tensor((4096,), dtype="float16") = params[300] lv1456 = R.call_tir(cls.rms_norm, (lv121_2, lv411_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1272_1: R.Tensor((12288, 512), dtype="uint32") = params[292] lv1273_1: R.Tensor((12288, 128), dtype="float16") = params[293] lv123_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1272_1, lv1273_1, lv1456), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv1459 = R.call_tir(cls.split1, (lv123_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv1460: R.Tensor((1, n, 4096), dtype="float16") = lv1459[0] lv1461 = R.call_tir(cls.reshape7, (lv1460,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1462: R.Tensor((1, n, 4096), dtype="float16") = lv1459[1] lv1463 = R.call_tir(cls.reshape7, (lv1462,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1464: R.Tensor((1, n, 4096), dtype="float16") = lv1459[2] lv1465 = R.call_tir(cls.reshape7, (lv1464,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1466 = R.call_tir(cls.rotary_embedding, (lv1461, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1467 = R.call_tir(cls.rotary_embedding, (lv1463, lv7, lv8), 
out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1468 = R.call_tir(cls.squeeze1, (lv1467,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1469 = R.call_tir(cls.squeeze1, (lv1465,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1470: R.Object = kv_cache[58] lv1471: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1470, lv1468, sinfo_args=(R.Object,)) lv1472: R.Object = kv_cache[59] lv1473: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1472, lv1469, sinfo_args=(R.Object,)) lv1474: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1471, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1475: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1473, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1476 = R.call_tir(cls.reshape3, (lv1474,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1477 = R.call_tir(cls.reshape3, (lv1475,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1478 = R.call_tir(cls.transpose6, (lv1466,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1479 = R.call_tir(cls.transpose6, (lv1476,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1480 = R.call_tir(cls.transpose6, (lv1477,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1275_1 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1478, lv1479, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv1276_1 = R.call_tir(cls.fused_softmax1_cast4, (lv1275_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv1489 = R.call_tir(cls.matmul10, (lv1276_1, lv1480), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1490 = R.call_tir(cls.transpose8, (lv1489,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1491 = R.call_tir(cls.reshape8, (lv1490,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1277_1: R.Tensor((4096, 512), dtype="uint32") = params[294] lv1278_1: R.Tensor((4096, 128), dtype="float16") = params[295] lv122_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1277_1, lv1278_1, lv1491, lv121_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv418_1: R.Tensor((4096,), dtype="float16") = params[301] lv1495 = R.call_tir(cls.rms_norm, (lv122_2, lv418_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1281: R.Tensor((22016, 512), dtype="uint32") = params[296] lv1282: R.Tensor((22016, 128), dtype="float16") = params[297] lv124_2 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1281, lv1282, lv1495), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv1284 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv124_2,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv1285: R.Tensor((4096, 1376), dtype="uint32") = params[298] lv1286: R.Tensor((4096, 344), dtype="float16") = params[299] lv123_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1285, lv1286, lv1284, lv122_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv425_1: R.Tensor((4096,), dtype="float16") = params[310] lv1506 = R.call_tir(cls.rms_norm, (lv123_2, lv425_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1289_1: R.Tensor((12288, 512), dtype="uint32") = params[302] lv1290_1: R.Tensor((12288, 128), dtype="float16") = params[303] lv125_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1289_1, lv1290_1, lv1506), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv1509 = 
R.call_tir(cls.split1, (lv125_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv1510: R.Tensor((1, n, 4096), dtype="float16") = lv1509[0] lv1511 = R.call_tir(cls.reshape7, (lv1510,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1512: R.Tensor((1, n, 4096), dtype="float16") = lv1509[1] lv1513 = R.call_tir(cls.reshape7, (lv1512,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1514: R.Tensor((1, n, 4096), dtype="float16") = lv1509[2] lv1515 = R.call_tir(cls.reshape7, (lv1514,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1516 = R.call_tir(cls.rotary_embedding, (lv1511, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1517 = R.call_tir(cls.rotary_embedding, (lv1513, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1518 = R.call_tir(cls.squeeze1, (lv1517,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1519 = R.call_tir(cls.squeeze1, (lv1515,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1520: R.Object = kv_cache[60] lv1521: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1520, lv1518, sinfo_args=(R.Object,)) lv1522: R.Object = kv_cache[61] lv1523: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1522, lv1519, sinfo_args=(R.Object,)) lv1524: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1521, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1525: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1523, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1526 = R.call_tir(cls.reshape3, (lv1524,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1527 = R.call_tir(cls.reshape3, (lv1525,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1528 = R.call_tir(cls.transpose6, (lv1516,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1529 = R.call_tir(cls.transpose6, (lv1526,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1530 = R.call_tir(cls.transpose6, (lv1527,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1292 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1528, lv1529, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv1293 = R.call_tir(cls.fused_softmax1_cast4, (lv1292,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv1539 = R.call_tir(cls.matmul10, (lv1293, lv1530), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1540 = R.call_tir(cls.transpose8, (lv1539,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1541 = R.call_tir(cls.reshape8, (lv1540,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1294: R.Tensor((4096, 512), dtype="uint32") = params[304] lv1295_1: R.Tensor((4096, 128), dtype="float16") = params[305] lv124_3 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1294, lv1295_1, lv1541, lv123_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv432: R.Tensor((4096,), dtype="float16") = params[311] lv1545 = R.call_tir(cls.rms_norm, (lv124_3, lv432), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1298: R.Tensor((22016, 512), dtype="uint32") = params[306] lv1299: R.Tensor((22016, 128), dtype="float16") = params[307] lv126_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1298, lv1299, lv1545), out_sinfo=R.Tensor((1, n, 22016), dtype="float16")) lv1301 = 
R.call_tir(cls.fused_split2_silu1_multiply1, (lv126_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16")) lv1302: R.Tensor((4096, 1376), dtype="uint32") = params[308] lv1303: R.Tensor((4096, 344), dtype="float16") = params[309] lv125_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1302, lv1303, lv1301, lv124_3), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv439_1: R.Tensor((4096,), dtype="float16") = params[320] lv1556 = R.call_tir(cls.rms_norm, (lv125_2, lv439_1), out_sinfo=R.Tensor((1, n, 4096), dtype="float16")) lv1306_1: R.Tensor((12288, 512), dtype="uint32") = params[312] lv1307: R.Tensor((12288, 128), dtype="float16") = params[313] lv127_1 = R.call_tir(cls.fused_fused_decode2_NT_matmul, (lv1306_1, lv1307, lv1556), out_sinfo=R.Tensor((1, n, 12288), dtype="float16")) lv1559 = R.call_tir(cls.split1, (lv127_1,), out_sinfo=[R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16"), R.Tensor((1, n, 4096), dtype="float16")]) lv1560: R.Tensor((1, n, 4096), dtype="float16") = lv1559[0] lv1561 = R.call_tir(cls.reshape7, (lv1560,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1562: R.Tensor((1, n, 4096), dtype="float16") = lv1559[1] lv1563 = R.call_tir(cls.reshape7, (lv1562,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1564: R.Tensor((1, n, 4096), dtype="float16") = lv1559[2] lv1565 = R.call_tir(cls.reshape7, (lv1564,), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16")) lv1566 = R.call_tir(cls.rotary_embedding, (lv1561, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1567 = R.call_tir(cls.rotary_embedding, (lv1563, lv7, lv8), out_sinfo=R.Tensor((1, n, 32, 128), dtype="float16"), tir_vars=R.shape([m])) lv1568 = R.call_tir(cls.squeeze1, (lv1567,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1569 = R.call_tir(cls.squeeze1, (lv1565,), out_sinfo=R.Tensor((n, 32, 128), dtype="float16")) lv1570: R.Object = kv_cache[62] lv1571: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1570, lv1568, sinfo_args=(R.Object,)) lv1572: R.Object = kv_cache[63] lv1573: R.Object = R.call_packed("vm.builtin.attention_kv_cache_append", lv1572, lv1569, sinfo_args=(R.Object,)) lv1574: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1571, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1575: R.Tensor((m, 32, 128), dtype="float16") = R.call_packed("vm.builtin.attention_kv_cache_view", lv1573, R.shape([m, 32, 128]), sinfo_args=(R.Tensor((m, 32, 128), dtype="float16"),)) lv1576 = R.call_tir(cls.reshape3, (lv1574,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1577 = R.call_tir(cls.reshape3, (lv1575,), out_sinfo=R.Tensor((1, m, 32, 128), dtype="float16")) lv1578 = R.call_tir(cls.transpose6, (lv1566,), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1579 = R.call_tir(cls.transpose6, (lv1576,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1580 = R.call_tir(cls.transpose6, (lv1577,), out_sinfo=R.Tensor((1, 32, m, 128), dtype="float16")) lv1309_1 = R.call_tir(cls.fused_NT_matmul1_divide1_maximum1_minimum1_cast3, (lv1578, lv1579, lv5), out_sinfo=R.Tensor((1, 32, n, m), dtype="float32")) lv1310_1 = R.call_tir(cls.fused_softmax1_cast4, (lv1309_1,), out_sinfo=R.Tensor((1, 32, n, m), dtype="float16")) lv1589 = R.call_tir(cls.matmul10, (lv1310_1, lv1580), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16")) lv1590 = R.call_tir(cls.transpose8, (lv1589,), out_sinfo=R.Tensor((1, n, 32, 
128), dtype="float16"))
            lv1591 = R.call_tir(cls.reshape8, (lv1590,), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1311_1: R.Tensor((4096, 512), dtype="uint32") = params[314]
            lv1312_1: R.Tensor((4096, 128), dtype="float16") = params[315]
            lv126_2 = R.call_tir(cls.fused_fused_decode3_fused_NT_matmul2_add1, (lv1311_1, lv1312_1, lv1591, lv125_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv446: R.Tensor((4096,), dtype="float16") = params[321]
            lv1595 = R.call_tir(cls.rms_norm, (lv126_2, lv446), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1315_1: R.Tensor((22016, 512), dtype="uint32") = params[316]
            lv1316_1: R.Tensor((22016, 128), dtype="float16") = params[317]
            lv128_1 = R.call_tir(cls.fused_fused_decode4_NT_matmul3, (lv1315_1, lv1316_1, lv1595), out_sinfo=R.Tensor((1, n, 22016), dtype="float16"))
            lv1318_1 = R.call_tir(cls.fused_split2_silu1_multiply1, (lv128_1,), out_sinfo=R.Tensor((1, n, 11008), dtype="float16"))
            lv1319_1: R.Tensor((4096, 1376), dtype="uint32") = params[318]
            lv1320_1: R.Tensor((4096, 344), dtype="float16") = params[319]
            lv127_2 = R.call_tir(cls.fused_fused_decode5_fused_NT_matmul4_add1, (lv1319_1, lv1320_1, lv1318_1, lv126_2), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv453: R.Tensor((4096,), dtype="float16") = params[322]
            lv1606 = R.call_tir(cls.rms_norm, (lv127_2, lv453), out_sinfo=R.Tensor((1, n, 4096), dtype="float16"))
            lv1607 = R.call_tir(cls.slice, (lv1606,), out_sinfo=R.Tensor((1, 1, 4096), dtype="float16"))
            lv1323_1: R.Tensor((32001, 512), dtype="uint32") = params[323]
            lv1324_1: R.Tensor((32001, 128), dtype="float16") = params[324]
            lv129_1 = R.call_tir(cls.fused_fused_decode1_fused_NT_matmul5_cast2, (lv1323_1, lv1324_1, lv1607), out_sinfo=R.Tensor((1, 1, 32001), dtype="float32"))
            gv: R.Tuple(
                R.Tensor((1, 1, 32001), dtype="float32"),
                R.Tuple(
                    R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
                    R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
                    R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
                    R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
                    R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
                    R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
                    R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
                    R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object, R.Object,
                ),
            ) = lv129_1, (
                lv21, lv23, lv71, lv73, lv121, lv123, lv171, lv173,
                lv221, lv223, lv271, lv273, lv321, lv323, lv371, lv373,
                lv421, lv423, lv471, lv473, lv521, lv523, lv571, lv573,
                lv621, lv623, lv671, lv673, lv721, lv723, lv771, lv773,
                lv821, lv823_1, lv871, lv873_1, lv921_1, lv923, lv971_1, lv973,
                lv1021_1, lv1023_1, lv1071_1, lv1073_1, lv1121, lv1123_1, lv1171_1, lv1173,
                lv1221, lv1223, lv1271, lv1273, lv1321, lv1323, lv1371, lv1373,
                lv1421, lv1423, lv1471, lv1473, lv1521, lv1523, lv1571, lv1573,
            )
            R.output(gv)
        return gv

    @R.function
    def softmax_with_temperature(logits: R.Tensor((1, 1, 32001), dtype="float32"), temperature: R.Tensor((), dtype="float32")) -> R.Tensor((1, 1, 32001), dtype="float32"):
        R.func_attr({"tir_var_upper_bound": {"m": 2048, "n": 2048}})
        cls = Module
        with R.dataflow():
            lv3285 = R.call_tir(cls.divide2, (logits, temperature), out_sinfo=R.Tensor((1, 1, 32001), dtype="float32"))
            lv3286 = R.call_tir(cls.softmax2, (lv3285,), out_sinfo=R.Tensor((1, 1, 32001), dtype="float32"))
            gv3: R.Tensor((1, 1, 32001), dtype="float32") = lv3286
            R.output(gv3)
        return gv3
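
# ---------------------------------------------------------------------------
# Usage sketch (not part of the generated dump above). This is a minimal,
# hedged example of how an IRModule like `Module` could be compiled and its
# `softmax_with_temperature` function invoked through the Relax virtual
# machine. It assumes a recent TVM build with Relax support and an NVIDIA GPU
# (the TIR kernels above are scheduled with CUDA-style thread bindings). The
# target string, device index, random logits, and the 0.7 temperature below
# are illustrative assumptions, not values taken from the original module.

if __name__ == "__main__":
    import numpy as np
    import tvm
    from tvm import relax

    target = tvm.target.Target("cuda")      # assumption: CUDA GPU target
    dev = tvm.cuda(0)

    # Compile the Relax IRModule above and load it into the Relax VM.
    ex = relax.build(Module, target)
    vm = relax.VirtualMachine(ex, dev)

    # softmax_with_temperature expects (1, 1, 32001) float32 logits and a
    # scalar float32 temperature; both live on the GPU device.
    logits = tvm.nd.array(np.random.randn(1, 1, 32001).astype("float32"), dev)
    temperature = tvm.nd.array(np.array(0.7, dtype="float32"), dev)

    probs = vm["softmax_with_temperature"](logits, temperature)
    print(probs.numpy().shape)              # expected: (1, 1, 32001)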